Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add Gather SMP collective from MVAPICH2
[simgrid.git] / src / smpi / colls / smpi_mvapich2_selector.c
1 /* selector for collective algorithms based on mvapich decision logic */
2
3 /* Copyright (c) 2009-2010, 2013-2014. The SimGrid Team.
4  * All rights reserved.                                                     */
5
6 /* This program is xbt_free software; you can redistribute it and/or modify it
7  * under the terms of the license (GNU LGPL) which comes with this package. */
8
9 #include "colls_private.h"
10
11 #include "smpi_mvapich2_selector_stampede.h"
12
13
14
15 int smpi_coll_tuned_alltoall_mvapich2( void *sendbuf, int sendcount, 
16     MPI_Datatype sendtype,
17     void* recvbuf, int recvcount,
18     MPI_Datatype recvtype,
19     MPI_Comm comm)
20 {
21
22   if(mv2_alltoall_table_ppn_conf==NULL)
23     init_mv2_alltoall_tables_stampede();
24
25   int sendtype_size, recvtype_size, nbytes, comm_size;
26   char * tmp_buf = NULL;
27   int mpi_errno=MPI_SUCCESS;
28   int range = 0;
29   int range_threshold = 0;
30   int conf_index = 0;
31   comm_size =  smpi_comm_size(comm);
32
33   sendtype_size=smpi_datatype_size(sendtype);
34   recvtype_size=smpi_datatype_size(recvtype);
35   nbytes = sendtype_size * sendcount;
36
37   /* check if safe to use partial subscription mode */
38
39   /* Search for the corresponding system size inside the tuning table */
40   while ((range < (mv2_size_alltoall_tuning_table[conf_index] - 1)) &&
41       (comm_size > mv2_alltoall_thresholds_table[conf_index][range].numproc)) {
42       range++;
43   }
44   /* Search for corresponding inter-leader function */
45   while ((range_threshold < (mv2_alltoall_thresholds_table[conf_index][range].size_table - 1))
46       && (nbytes >
47   mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max)
48   && (mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max != -1)) {
49       range_threshold++;
50   }
51   MV2_Alltoall_function = mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold]
52                                                                                       .MV2_pt_Alltoall_function;
53
54   if(sendbuf != MPI_IN_PLACE) {
55       mpi_errno = MV2_Alltoall_function(sendbuf, sendcount, sendtype,
56           recvbuf, recvcount, recvtype,
57           comm);
58   } else {
59       range_threshold = 0;
60       if(nbytes <
61           mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].min
62           ||nbytes > mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].max
63       ) {
64           tmp_buf = (char *)xbt_malloc( comm_size * recvcount * recvtype_size );
65           mpi_errno = smpi_datatype_copy((char *)recvbuf,
66               comm_size*recvcount, recvtype,
67               (char *)tmp_buf,
68               comm_size*recvcount, recvtype);
69
70           mpi_errno = MV2_Alltoall_function(tmp_buf, recvcount, recvtype,
71               recvbuf, recvcount, recvtype,
72               comm );
73           xbt_free(tmp_buf);
74       } else {
75           mpi_errno = MPIR_Alltoall_inplace_MV2(sendbuf, sendcount, sendtype,
76               recvbuf, recvcount, recvtype,
77               comm );
78       }
79   }
80
81
82   return (mpi_errno);
83 }
84
85
86
87 int smpi_coll_tuned_allgather_mvapich2(void *sendbuf, int sendcount, MPI_Datatype sendtype,
88     void *recvbuf, int recvcount, MPI_Datatype recvtype,
89     MPI_Comm comm)
90 {
91
92   int mpi_errno = MPI_SUCCESS;
93   int nbytes = 0, comm_size, recvtype_size;
94   int range = 0;
95   //int partial_sub_ok = 0;
96   int conf_index = 0;
97   int range_threshold = 0;
98   int is_two_level = 0;
99   //int local_size = -1;
100   //MPI_Comm shmem_comm;
101   //MPI_Comm *shmem_commptr=NULL;
102   /* Get the size of the communicator */
103   comm_size = smpi_comm_size(comm);
104   recvtype_size=smpi_datatype_size(recvtype);
105   nbytes = recvtype_size * recvcount;
106
107   if(mv2_allgather_table_ppn_conf==NULL)
108     init_mv2_allgather_tables_stampede();
109
110   //int i;
111   /* check if safe to use partial subscription mode */
112   /*  if (comm->ch.shmem_coll_ok == 1 && comm->ch.is_uniform) {
113
114         shmem_comm = comm->ch.shmem_comm;
115         MPID_Comm_get_ptr(shmem_comm, shmem_commptr);
116         local_size = shmem_commptr->local_size;
117         i = 0;
118         if (mv2_allgather_table_ppn_conf[0] == -1) {
119             // Indicating user defined tuning
120             conf_index = 0;
121             goto conf_check_end;
122         }
123         do {
124             if (local_size == mv2_allgather_table_ppn_conf[i]) {
125                 conf_index = i;
126                 partial_sub_ok = 1;
127                 break;
128             }
129             i++;
130         } while(i < mv2_allgather_num_ppn_conf);
131     }
132
133   conf_check_end:
134     if (partial_sub_ok != 1) {
135         conf_index = 0;
136     }*/
137   /* Search for the corresponding system size inside the tuning table */
138   while ((range < (mv2_size_allgather_tuning_table[conf_index] - 1)) &&
139       (comm_size >
140   mv2_allgather_thresholds_table[conf_index][range].numproc)) {
141       range++;
142   }
143   /* Search for corresponding inter-leader function */
144   while ((range_threshold <
145       (mv2_allgather_thresholds_table[conf_index][range].size_inter_table - 1))
146       && (nbytes > mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
147       && (mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max !=
148           -1)) {
149       range_threshold++;
150   }
151
152   /* Set inter-leader pt */
153   MV2_Allgather_function =
154       mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].
155       MV2_pt_Allgather_function;
156
157   is_two_level =  mv2_allgather_thresholds_table[conf_index][range].two_level[range_threshold];
158
159   /* intracommunicator */
160   if(is_two_level ==1){
161
162       /*       if(comm->ch.shmem_coll_ok == 1){
163             MPIR_T_PVAR_COUNTER_INC(MV2, mv2_num_shmem_coll_calls, 1);
164            if (1 == comm->ch.is_blocked) {
165                 mpi_errno = MPIR_2lvl_Allgather_MV2(sendbuf, sendcount, sendtype,
166                                                     recvbuf, recvcount, recvtype,
167                                                     comm, errflag);
168            }
169            else {
170                mpi_errno = MPIR_Allgather_intra(sendbuf, sendcount, sendtype,
171                                                 recvbuf, recvcount, recvtype,
172                                                 comm, errflag);
173            }
174         } else {*/
175       mpi_errno = MPIR_Allgather_RD_MV2(sendbuf, sendcount, sendtype,
176           recvbuf, recvcount, recvtype,
177           comm);
178       //     }
179   } else if(MV2_Allgather_function == &MPIR_Allgather_Bruck_MV2
180       || MV2_Allgather_function == &MPIR_Allgather_RD_MV2
181       || MV2_Allgather_function == &MPIR_Allgather_Ring_MV2) {
182       mpi_errno = MV2_Allgather_function(sendbuf, sendcount, sendtype,
183           recvbuf, recvcount, recvtype,
184           comm);
185   }else{
186       return MPI_ERR_OTHER;
187   }
188
189   return mpi_errno;
190 }
191
192
193 int smpi_coll_tuned_gather_mvapich2(void *sendbuf,
194     int sendcnt,
195     MPI_Datatype sendtype,
196     void *recvbuf,
197     int recvcnt,
198     MPI_Datatype recvtype,
199     int root, MPI_Comm  comm)
200 {
201   if(mv2_gather_thresholds_table==NULL)
202     init_mv2_gather_tables_stampede();
203
204   int mpi_errno = MPI_SUCCESS;
205   int range = 0;
206   int range_threshold = 0;
207   int range_intra_threshold = 0;
208   int nbytes = 0;
209   int comm_size = 0;
210   int recvtype_size, sendtype_size;
211   int rank = -1;
212   comm_size = smpi_comm_size(comm);
213   rank = smpi_comm_rank(comm);
214
215   if (rank == root) {
216       recvtype_size=smpi_datatype_size(recvtype);
217       nbytes = recvcnt * recvtype_size;
218   } else {
219       sendtype_size=smpi_datatype_size(sendtype);
220       nbytes = sendcnt * sendtype_size;
221   }
222
223   /* Search for the corresponding system size inside the tuning table */
224   while ((range < (mv2_size_gather_tuning_table - 1)) &&
225       (comm_size > mv2_gather_thresholds_table[range].numproc)) {
226       range++;
227   }
228   /* Search for corresponding inter-leader function */
229   while ((range_threshold < (mv2_gather_thresholds_table[range].size_inter_table - 1))
230       && (nbytes >
231   mv2_gather_thresholds_table[range].inter_leader[range_threshold].max)
232   && (mv2_gather_thresholds_table[range].inter_leader[range_threshold].max !=
233       -1)) {
234       range_threshold++;
235   }
236
237   /* Search for corresponding intra node function */
238   while ((range_intra_threshold < (mv2_gather_thresholds_table[range].size_intra_table - 1))
239       && (nbytes >
240   mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max)
241   && (mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max !=
242       -1)) {
243       range_intra_threshold++;
244   }
245   
246     if (smpi_comm_is_blocked(comm) ) {
247         // Set intra-node function pt for gather_two_level 
248         MV2_Gather_intra_node_function = 
249                               mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].
250                               MV2_pt_Gather_function;
251         //Set inter-leader pt 
252         MV2_Gather_inter_leader_function =
253                               mv2_gather_thresholds_table[range].inter_leader[range_threshold].
254                               MV2_pt_Gather_function;
255         // We call Gather function 
256         mpi_errno =
257             MV2_Gather_inter_leader_function(sendbuf, sendcnt, sendtype, recvbuf, recvcnt,
258                                              recvtype, root, comm);
259
260     } else {
261   // Indeed, direct (non SMP-aware)gather is MPICH one
262   mpi_errno = smpi_coll_tuned_gather_mpich(sendbuf, sendcnt, sendtype,
263       recvbuf, recvcnt, recvtype,
264       root, comm);
265   }
266
267   return mpi_errno;
268 }
269
270
271 int smpi_coll_tuned_allgatherv_mvapich2(void *sendbuf, int sendcount, MPI_Datatype sendtype,
272     void *recvbuf, int *recvcounts, int *displs,
273     MPI_Datatype recvtype, MPI_Comm  comm )
274 {
275   int mpi_errno = MPI_SUCCESS;
276   int range = 0, comm_size, total_count, recvtype_size, i;
277   int range_threshold = 0;
278   int nbytes = 0;
279
280   if(mv2_allgatherv_thresholds_table==NULL)
281     init_mv2_allgatherv_tables_stampede();
282
283   comm_size = smpi_comm_size(comm);
284   total_count = 0;
285   for (i = 0; i < comm_size; i++)
286     total_count += recvcounts[i];
287
288   recvtype_size=smpi_datatype_size(recvtype);
289   nbytes = total_count * recvtype_size;
290
291   /* Search for the corresponding system size inside the tuning table */
292   while ((range < (mv2_size_allgatherv_tuning_table - 1)) &&
293       (comm_size > mv2_allgatherv_thresholds_table[range].numproc)) {
294       range++;
295   }
296   /* Search for corresponding inter-leader function */
297   while ((range_threshold < (mv2_allgatherv_thresholds_table[range].size_inter_table - 1))
298       && (nbytes >
299   comm_size * mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max)
300   && (mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max !=
301       -1)) {
302       range_threshold++;
303   }
304   /* Set inter-leader pt */
305   MV2_Allgatherv_function =
306       mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].
307       MV2_pt_Allgatherv_function;
308
309   if (MV2_Allgatherv_function == &MPIR_Allgatherv_Rec_Doubling_MV2)
310     {
311       if(!(comm_size & (comm_size - 1)))
312         {
313           mpi_errno =
314               MPIR_Allgatherv_Rec_Doubling_MV2(sendbuf, sendcount,
315                   sendtype, recvbuf,
316                   recvcounts, displs,
317                   recvtype, comm);
318         } else {
319             mpi_errno =
320                 MPIR_Allgatherv_Bruck_MV2(sendbuf, sendcount,
321                     sendtype, recvbuf,
322                     recvcounts, displs,
323                     recvtype, comm);
324         }
325     } else {
326         mpi_errno =
327             MV2_Allgatherv_function(sendbuf, sendcount, sendtype,
328                 recvbuf, recvcounts, displs,
329                 recvtype, comm);
330     }
331
332   return mpi_errno;
333 }
334
335
336
337 int smpi_coll_tuned_allreduce_mvapich2(void *sendbuf,
338     void *recvbuf,
339     int count,
340     MPI_Datatype datatype,
341     MPI_Op op, MPI_Comm comm)
342 {
343
344   int mpi_errno = MPI_SUCCESS;
345   //int rank = 0,
346   int comm_size = 0;
347
348   comm_size = smpi_comm_size(comm);
349   //rank = smpi_comm_rank(comm);
350
351   if (count == 0) {
352       return MPI_SUCCESS;
353   }
354
355   if (mv2_allreduce_thresholds_table == NULL)
356     init_mv2_allreduce_tables_stampede();
357
358   /* check if multiple threads are calling this collective function */
359
360   MPI_Aint sendtype_size = 0;
361   int nbytes = 0;
362   int range = 0, range_threshold = 0, range_threshold_intra = 0;
363   int is_two_level = 0;
364   //int is_commutative = 0;
365   MPI_Aint true_lb, true_extent;
366
367   sendtype_size=smpi_datatype_size(datatype);
368   nbytes = count * sendtype_size;
369
370   smpi_datatype_extent(datatype, &true_lb, &true_extent);
371   //MPI_Op *op_ptr;
372   //is_commutative = smpi_op_is_commute(op);
373
374   {
375     /* Search for the corresponding system size inside the tuning table */
376     while ((range < (mv2_size_allreduce_tuning_table - 1)) &&
377         (comm_size > mv2_allreduce_thresholds_table[range].numproc)) {
378         range++;
379     }
380     /* Search for corresponding inter-leader function */
381     /* skip mcast poiters if mcast is not available */
382     if(mv2_allreduce_thresholds_table[range].mcast_enabled != 1){
383         while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
384             && ((mv2_allreduce_thresholds_table[range].
385                 inter_leader[range_threshold].MV2_pt_Allreduce_function
386                 == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2) ||
387                 (mv2_allreduce_thresholds_table[range].
388                     inter_leader[range_threshold].MV2_pt_Allreduce_function
389                     == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)
390             )) {
391             range_threshold++;
392         }
393     }
394     while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
395         && (nbytes >
396     mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max)
397     && (mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max != -1)) {
398         range_threshold++;
399     }
400     if(mv2_allreduce_thresholds_table[range].is_two_level_allreduce[range_threshold] == 1){
401         is_two_level = 1;
402     }
403     /* Search for corresponding intra-node function */
404     while ((range_threshold_intra <
405         (mv2_allreduce_thresholds_table[range].size_intra_table - 1))
406         && (nbytes >
407     mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max)
408     && (mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max !=
409         -1)) {
410         range_threshold_intra++;
411     }
412
413     MV2_Allreduce_function = mv2_allreduce_thresholds_table[range].inter_leader[range_threshold]
414                                                                                 .MV2_pt_Allreduce_function;
415
416     MV2_Allreduce_intra_function = mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra]
417                                                                                     .MV2_pt_Allreduce_function;
418
419     /* check if mcast is ready, otherwise replace mcast with other algorithm */
420     if((MV2_Allreduce_function == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2)||
421         (MV2_Allreduce_function == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)){
422         {
423           MV2_Allreduce_function = &MPIR_Allreduce_pt2pt_rd_MV2;
424         }
425         if(is_two_level != 1) {
426             MV2_Allreduce_function = &MPIR_Allreduce_pt2pt_rd_MV2;
427         }
428     }
429
430     if(is_two_level == 1){
431         // check if shm is ready, if not use other algorithm first
432         /*if ((comm->ch.shmem_coll_ok == 1)
433                     && (mv2_enable_shmem_allreduce)
434                     && (is_commutative)
435                     && (mv2_enable_shmem_collectives)) {
436                     mpi_errno = MPIR_Allreduce_two_level_MV2(sendbuf, recvbuf, count,
437                                                      datatype, op, comm);
438                 } else {*/
439         mpi_errno = MPIR_Allreduce_pt2pt_rd_MV2(sendbuf, recvbuf, count,
440             datatype, op, comm);
441         // }
442     } else {
443         mpi_errno = MV2_Allreduce_function(sendbuf, recvbuf, count,
444             datatype, op, comm);
445     }
446   }
447
448   //comm->ch.intra_node_done=0;
449
450   return (mpi_errno);
451
452
453 }
454
455
456 int smpi_coll_tuned_alltoallv_mvapich2(void *sbuf, int *scounts, int *sdisps,
457     MPI_Datatype sdtype,
458     void *rbuf, int *rcounts, int *rdisps,
459     MPI_Datatype rdtype,
460     MPI_Comm  comm
461 )
462 {
463
464   if (sbuf == MPI_IN_PLACE) {
465       return smpi_coll_tuned_alltoallv_ompi_basic_linear(sbuf, scounts, sdisps, sdtype,
466           rbuf, rcounts, rdisps,rdtype,
467           comm);
468   } else     /* For starters, just keep the original algorithm. */
469   return smpi_coll_tuned_alltoallv_ring(sbuf, scounts, sdisps, sdtype,
470       rbuf, rcounts, rdisps,rdtype,
471       comm);
472 }
473
474
475 int smpi_coll_tuned_barrier_mvapich2(MPI_Comm  comm)
476 {   
477   return smpi_coll_tuned_barrier_mvapich2_pair(comm);
478 }
479
480
481
482
483 int smpi_coll_tuned_bcast_mvapich2(void *buffer,
484     int count,
485     MPI_Datatype datatype,
486     int root, MPI_Comm comm)
487 {
488
489   //TODO : Bcast really needs intra/inter phases in mvapich. Default to mpich if not available
490   return smpi_coll_tuned_bcast_mpich(buffer, count, datatype, root, comm);
491
492 }
493
494
495
496 int smpi_coll_tuned_reduce_mvapich2( void *sendbuf,
497     void *recvbuf,
498     int count,
499     MPI_Datatype datatype,
500     MPI_Op op, int root, MPI_Comm comm)
501 {
502   if(mv2_reduce_thresholds_table == NULL)
503     init_mv2_reduce_tables_stampede();
504
505   int mpi_errno = MPI_SUCCESS;
506   int range = 0;
507   int range_threshold = 0;
508   int range_intra_threshold = 0;
509   int is_commutative, pof2;
510   int comm_size = 0;
511   int nbytes = 0;
512   int sendtype_size;
513   int is_two_level = 0;
514
515   comm_size = smpi_comm_size(comm);
516   sendtype_size=smpi_datatype_size(datatype);
517   nbytes = count * sendtype_size;
518
519   if (count == 0)
520     return MPI_SUCCESS;
521
522   is_commutative = smpi_op_is_commute(op);
523
524   /* find nearest power-of-two less than or equal to comm_size */
525   for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
526   pof2 >>=1;
527
528
529   /* Search for the corresponding system size inside the tuning table */
530   while ((range < (mv2_size_reduce_tuning_table - 1)) &&
531       (comm_size > mv2_reduce_thresholds_table[range].numproc)) {
532       range++;
533   }
534   /* Search for corresponding inter-leader function */
535   while ((range_threshold < (mv2_reduce_thresholds_table[range].size_inter_table - 1))
536       && (nbytes >
537   mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max)
538   && (mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max !=
539       -1)) {
540       range_threshold++;
541   }
542
543   /* Search for corresponding intra node function */
544   while ((range_intra_threshold < (mv2_reduce_thresholds_table[range].size_intra_table - 1))
545       && (nbytes >
546   mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max)
547   && (mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max !=
548       -1)) {
549       range_intra_threshold++;
550   }
551
552   /* Set intra-node function pt for reduce_two_level */
553   MV2_Reduce_intra_function =
554       mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].
555       MV2_pt_Reduce_function;
556   /* Set inter-leader pt */
557   MV2_Reduce_function =
558       mv2_reduce_thresholds_table[range].inter_leader[range_threshold].
559       MV2_pt_Reduce_function;
560
561   if(mv2_reduce_intra_knomial_factor<0)
562     {
563       mv2_reduce_intra_knomial_factor = mv2_reduce_thresholds_table[range].intra_k_degree;
564     }
565   if(mv2_reduce_inter_knomial_factor<0)
566     {
567       mv2_reduce_inter_knomial_factor = mv2_reduce_thresholds_table[range].inter_k_degree;
568     }
569   if(mv2_reduce_thresholds_table[range].is_two_level_reduce[range_threshold] == 1){
570       is_two_level = 1;
571   }
572   /* We call Reduce function */
573   if(is_two_level == 1)
574     {
575       /* if (comm->ch.shmem_coll_ok == 1
576             && is_commutative == 1) {
577             mpi_errno = MPIR_Reduce_two_level_helper_MV2(sendbuf, recvbuf, count, 
578                                            datatype, op, root, comm, errflag);
579         } else {*/
580       mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
581           datatype, op, root, comm);
582       //}
583     } else if(MV2_Reduce_function == &MPIR_Reduce_inter_knomial_wrapper_MV2 ){
584         if(is_commutative ==1)
585           {
586             mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count, 
587                 datatype, op, root, comm);
588           } else {
589               mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
590                   datatype, op, root, comm);
591           }
592     } else if(MV2_Reduce_function == &MPIR_Reduce_redscat_gather_MV2){
593         if (/*(HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) &&*/ (count >= pof2))
594           {
595             mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count, 
596                 datatype, op, root, comm);
597           } else {
598               mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
599                   datatype, op, root, comm);
600           }
601     } else {
602         mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count, 
603             datatype, op, root, comm);
604     }
605
606
607   return mpi_errno;
608
609 }
610
611
612 int smpi_coll_tuned_reduce_scatter_mvapich2(void *sendbuf, void *recvbuf, int *recvcnts,
613     MPI_Datatype datatype, MPI_Op op,
614     MPI_Comm comm)
615 {
616   int mpi_errno = MPI_SUCCESS;
617   int i = 0, comm_size = smpi_comm_size(comm), total_count = 0, type_size =
618       0, nbytes = 0;
619   int range = 0;
620   int range_threshold = 0;
621   int is_commutative = 0;
622   int *disps = xbt_malloc(comm_size * sizeof (int));
623
624   if(mv2_red_scat_thresholds_table==NULL)
625     init_mv2_reduce_scatter_tables_stampede();
626
627   is_commutative=smpi_op_is_commute(op);
628   for (i = 0; i < comm_size; i++) {
629       disps[i] = total_count;
630       total_count += recvcnts[i];
631   }
632
633   type_size=smpi_datatype_size(datatype);
634   nbytes = total_count * type_size;
635
636   if (is_commutative) {
637
638       /* Search for the corresponding system size inside the tuning table */
639       while ((range < (mv2_size_red_scat_tuning_table - 1)) &&
640           (comm_size > mv2_red_scat_thresholds_table[range].numproc)) {
641           range++;
642       }
643       /* Search for corresponding inter-leader function */
644       while ((range_threshold < (mv2_red_scat_thresholds_table[range].size_inter_table - 1))
645           && (nbytes >
646       mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max)
647       && (mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max !=
648           -1)) {
649           range_threshold++;
650       }
651
652       /* Set inter-leader pt */
653       MV2_Red_scat_function =
654           mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].
655           MV2_pt_Red_scat_function;
656
657       mpi_errno = MV2_Red_scat_function(sendbuf, recvbuf,
658           recvcnts, datatype,
659           op, comm);
660   } else {
661       int is_block_regular = 1;
662       for (i = 0; i < (comm_size - 1); ++i) {
663           if (recvcnts[i] != recvcnts[i+1]) {
664               is_block_regular = 0;
665               break;
666           }
667       }
668       int pof2 = 1;
669       while (pof2 < comm_size) pof2 <<= 1;
670       if (pof2 == comm_size && is_block_regular) {
671           /* noncommutative, pof2 size, and block regular */
672           mpi_errno = MPIR_Reduce_scatter_non_comm_MV2(sendbuf, recvbuf,
673               recvcnts, datatype,
674               op, comm);
675       }
676       mpi_errno =  smpi_coll_tuned_reduce_scatter_mpich_rdb(sendbuf, recvbuf,
677           recvcnts, datatype,
678           op, comm);
679   }
680
681   return mpi_errno;
682
683 }
684
685
686
687 int smpi_coll_tuned_scatter_mvapich2(void *sendbuf,
688     int sendcnt,
689     MPI_Datatype sendtype,
690     void *recvbuf,
691     int recvcnt,
692     MPI_Datatype recvtype,
693     int root, MPI_Comm comm_ptr)
694 {
695   int range = 0, range_threshold = 0, range_threshold_intra = 0;
696   int mpi_errno = MPI_SUCCESS;
697   //   int mpi_errno_ret = MPI_SUCCESS;
698   int rank, nbytes, comm_size;
699   int recvtype_size, sendtype_size;
700   int partial_sub_ok = 0;
701   int conf_index = 0;
702   //  int local_size = -1;
703   //  int i;
704   //   MPI_Comm shmem_comm;
705   //    MPID_Comm *shmem_commptr=NULL;
706   if(mv2_scatter_thresholds_table==NULL)
707     init_mv2_scatter_tables_stampede();
708
709   comm_size = smpi_comm_size(comm_ptr);
710
711   rank = smpi_comm_rank(comm_ptr);
712
713   if (rank == root) {
714       sendtype_size=smpi_datatype_size(sendtype);
715       nbytes = sendcnt * sendtype_size;
716   } else {
717       recvtype_size=smpi_datatype_size(recvtype);
718       nbytes = recvcnt * recvtype_size;
719   }
720   /*
721     // check if safe to use partial subscription mode 
722     if (comm_ptr->ch.shmem_coll_ok == 1 && comm_ptr->ch.is_uniform) {
723
724         shmem_comm = comm_ptr->ch.shmem_comm;
725         MPID_Comm_get_ptr(shmem_comm, shmem_commptr);
726         local_size = shmem_commptr->local_size;
727         i = 0;
728         if (mv2_scatter_table_ppn_conf[0] == -1) {
729             // Indicating user defined tuning 
730             conf_index = 0;
731             goto conf_check_end;
732         }
733         do {
734             if (local_size == mv2_scatter_table_ppn_conf[i]) {
735                 conf_index = i;
736                 partial_sub_ok = 1;
737                 break;
738             }
739             i++;
740         } while(i < mv2_scatter_num_ppn_conf);
741     }
742    */
743   if (partial_sub_ok != 1) {
744       conf_index = 0;
745   }
746
747   /* Search for the corresponding system size inside the tuning table */
748   while ((range < (mv2_size_scatter_tuning_table[conf_index] - 1)) &&
749       (comm_size > mv2_scatter_thresholds_table[conf_index][range].numproc)) {
750       range++;
751   }
752   /* Search for corresponding inter-leader function */
753   while ((range_threshold < (mv2_scatter_thresholds_table[conf_index][range].size_inter_table - 1))
754       && (nbytes >
755   mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
756   && (mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max != -1)) {
757       range_threshold++;
758   }
759
760   /* Search for corresponding intra-node function */
761   while ((range_threshold_intra <
762       (mv2_scatter_thresholds_table[conf_index][range].size_intra_table - 1))
763       && (nbytes >
764   mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max)
765   && (mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max !=
766       -1)) {
767       range_threshold_intra++;
768   }
769
770   MV2_Scatter_function = mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold]
771                                                                                       .MV2_pt_Scatter_function;
772
773   if(MV2_Scatter_function == &MPIR_Scatter_mcst_wrap_MV2) {
774 #if defined(_MCST_SUPPORT_)
775       if(comm_ptr->ch.is_mcast_ok == 1
776           && mv2_use_mcast_scatter == 1
777           && comm_ptr->ch.shmem_coll_ok == 1) {
778           MV2_Scatter_function = &MPIR_Scatter_mcst_MV2;
779       } else
780 #endif /*#if defined(_MCST_SUPPORT_) */
781         {
782           if(mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1].
783               MV2_pt_Scatter_function != NULL) {
784               MV2_Scatter_function = mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1]
785                                                                                                   .MV2_pt_Scatter_function;
786           } else {
787               /* Fallback! */
788               MV2_Scatter_function = &MPIR_Scatter_MV2_Binomial;
789           }
790         } 
791   }
792
793   if( (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Direct) ||
794       (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Binomial)) {
795       /* if( comm_ptr->ch.shmem_coll_ok == 1 &&
796              comm_ptr->ch.is_global_block == 1 ) {
797              MV2_Scatter_intra_function = mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra]
798                                 .MV2_pt_Scatter_function;
799
800              mpi_errno =
801                    MV2_Scatter_function(sendbuf, sendcnt, sendtype,
802                                         recvbuf, recvcnt, recvtype, root,
803                                         comm_ptr);
804          } else {*/
805       mpi_errno = MPIR_Scatter_MV2_Binomial(sendbuf, sendcnt, sendtype,
806           recvbuf, recvcnt, recvtype, root,
807           comm_ptr);
808
809       //}
810   } else {
811       mpi_errno = MV2_Scatter_function(sendbuf, sendcnt, sendtype,
812           recvbuf, recvcnt, recvtype, root,
813           comm_ptr);
814   }
815   return (mpi_errno);
816 }
817