Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Update copyright lines for 2023.
[simgrid.git] / src / smpi / colls / smpi_mvapich2_selector.cpp
1 /* selector for collective algorithms based on mvapich decision logic */
2
3 /* Copyright (c) 2009-2023. The SimGrid Team.
4  * All rights reserved.                                                     */
5
6 /* This program is free software; you can redistribute it and/or modify it
7  * under the terms of the license (GNU LGPL) which comes with this package. */
8
9 #include "colls_private.hpp"
10
11 #include "smpi_mvapich2_selector_stampede.hpp"
12
13 namespace simgrid::smpi {
14
15 int alltoall__mvapich2( const void *sendbuf, int sendcount,
16                         MPI_Datatype sendtype,
17                         void* recvbuf, int recvcount,
18                         MPI_Datatype recvtype,
19                         MPI_Comm comm)
20 {
21
22   if (mv2_alltoall_table_ppn_conf == nullptr)
23     init_mv2_alltoall_tables_stampede();
24
25   int sendtype_size, recvtype_size, comm_size;
26   int mpi_errno=MPI_SUCCESS;
27   int range = 0;
28   int range_threshold = 0;
29   int conf_index = 0;
30   comm_size =  comm->size();
31
32   sendtype_size=sendtype->size();
33   recvtype_size=recvtype->size();
34   long nbytes = sendtype_size * sendcount;
35
36   /* check if safe to use partial subscription mode */
37
38   /* Search for the corresponding system size inside the tuning table */
39   while ((range < (mv2_size_alltoall_tuning_table[conf_index] - 1)) &&
40       (comm_size > mv2_alltoall_thresholds_table[conf_index][range].numproc)) {
41       range++;
42   }
43   /* Search for corresponding inter-leader function */
44   while ((range_threshold < (mv2_alltoall_thresholds_table[conf_index][range].size_table - 1))
45       && (nbytes >
46   mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max)
47   && (mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max != -1)) {
48       range_threshold++;
49   }
50   MV2_Alltoall_function = mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold]
51                                                                                       .MV2_pt_Alltoall_function;
52
53   if(sendbuf != MPI_IN_PLACE) {
54       mpi_errno = MV2_Alltoall_function(sendbuf, sendcount, sendtype,
55           recvbuf, recvcount, recvtype,
56           comm);
57   } else {
58       range_threshold = 0;
59       if(nbytes <
60           mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].min
61           ||nbytes > mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].max
62       ) {
63         unsigned char* tmp_buf = smpi_get_tmp_sendbuffer(comm_size * recvcount * recvtype_size);
64         Datatype::copy(recvbuf, comm_size * recvcount, recvtype, tmp_buf, comm_size * recvcount, recvtype);
65
66         mpi_errno = MV2_Alltoall_function(tmp_buf, recvcount, recvtype, recvbuf, recvcount, recvtype, comm);
67         smpi_free_tmp_buffer(tmp_buf);
68       } else {
69           mpi_errno = MPIR_Alltoall_inplace_MV2(sendbuf, sendcount, sendtype,
70               recvbuf, recvcount, recvtype,
71               comm );
72       }
73   }
74
75
76   return (mpi_errno);
77 }
78
79 int allgather__mvapich2(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
80     void *recvbuf, int recvcount, MPI_Datatype recvtype,
81     MPI_Comm comm)
82 {
83
84   int mpi_errno = MPI_SUCCESS;
85   long nbytes = 0, comm_size, recvtype_size;
86   int range = 0;
87   bool partial_sub_ok = false;
88   int conf_index = 0;
89   int range_threshold = 0;
90   MPI_Comm shmem_comm;
91   //MPI_Comm *shmem_commptr=NULL;
92   /* Get the size of the communicator */
93   comm_size = comm->size();
94   recvtype_size=recvtype->size();
95   nbytes = recvtype_size * recvcount;
96
97   if (mv2_allgather_table_ppn_conf == nullptr)
98     init_mv2_allgather_tables_stampede();
99
100   if(comm->get_leaders_comm()==MPI_COMM_NULL){
101     comm->init_smp();
102   }
103
104   if (comm->is_uniform()){
105     shmem_comm = comm->get_intra_comm();
106     int local_size = shmem_comm->size();
107     int i          = 0;
108     if (mv2_allgather_table_ppn_conf[0] == -1) {
109       // Indicating user defined tuning
110       conf_index = 0;
111       goto conf_check_end;
112     }
113     do {
114       if (local_size == mv2_allgather_table_ppn_conf[i]) {
115         conf_index = i;
116         partial_sub_ok = true;
117         break;
118       }
119       i++;
120     } while(i < mv2_allgather_num_ppn_conf);
121   }
122   conf_check_end:
123   if (not partial_sub_ok) {
124     conf_index = 0;
125   }
126
127   /* Search for the corresponding system size inside the tuning table */
128   while ((range < (mv2_size_allgather_tuning_table[conf_index] - 1)) &&
129       (comm_size >
130   mv2_allgather_thresholds_table[conf_index][range].numproc)) {
131       range++;
132   }
133   /* Search for corresponding inter-leader function */
134   while ((range_threshold <
135       (mv2_allgather_thresholds_table[conf_index][range].size_inter_table - 1))
136       && (nbytes > mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
137       && (mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max !=
138           -1)) {
139       range_threshold++;
140   }
141
142   /* Set inter-leader pt */
143   MV2_Allgatherction =
144       mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].
145       MV2_pt_Allgatherction;
146
147   bool is_two_level = mv2_allgather_thresholds_table[conf_index][range].two_level[range_threshold];
148
149   /* intracommunicator */
150   if (is_two_level) {
151     if (partial_sub_ok) {
152       if (comm->is_blocked()){
153       mpi_errno = MPIR_2lvl_Allgather_MV2(sendbuf, sendcount, sendtype,
154                             recvbuf, recvcount, recvtype,
155                             comm);
156       }else{
157       mpi_errno = allgather__mpich(sendbuf, sendcount, sendtype,
158                             recvbuf, recvcount, recvtype,
159                             comm);
160       }
161     } else {
162       mpi_errno = MPIR_Allgather_RD_MV2(sendbuf, sendcount, sendtype,
163           recvbuf, recvcount, recvtype,
164           comm);
165     }
166   } else if(MV2_Allgatherction == &MPIR_Allgather_Bruck_MV2
167       || MV2_Allgatherction == &MPIR_Allgather_RD_MV2
168       || MV2_Allgatherction == &MPIR_Allgather_Ring_MV2) {
169       mpi_errno = MV2_Allgatherction(sendbuf, sendcount, sendtype,
170           recvbuf, recvcount, recvtype,
171           comm);
172   }else{
173       return MPI_ERR_OTHER;
174   }
175
176   return mpi_errno;
177 }
178
179 int gather__mvapich2(const void *sendbuf,
180     int sendcnt,
181     MPI_Datatype sendtype,
182     void *recvbuf,
183     int recvcnt,
184     MPI_Datatype recvtype,
185     int root, MPI_Comm  comm)
186 {
187   if (mv2_gather_thresholds_table == nullptr)
188     init_mv2_gather_tables_stampede();
189
190   int mpi_errno = MPI_SUCCESS;
191   int range = 0;
192   int range_threshold = 0;
193   int range_intra_threshold = 0;
194   long nbytes = 0;
195   int comm_size = comm->size();
196   int rank      = comm->rank();
197
198   if (rank == root) {
199     int recvtype_size = recvtype->size();
200     nbytes            = recvcnt * recvtype_size;
201   } else {
202     int sendtype_size = sendtype->size();
203     nbytes            = sendcnt * sendtype_size;
204   }
205
206   /* Search for the corresponding system size inside the tuning table */
207   while ((range < (mv2_size_gather_tuning_table - 1)) &&
208       (comm_size > mv2_gather_thresholds_table[range].numproc)) {
209       range++;
210   }
211   /* Search for corresponding inter-leader function */
212   while ((range_threshold < (mv2_gather_thresholds_table[range].size_inter_table - 1))
213       && (nbytes >
214   mv2_gather_thresholds_table[range].inter_leader[range_threshold].max)
215   && (mv2_gather_thresholds_table[range].inter_leader[range_threshold].max !=
216       -1)) {
217       range_threshold++;
218   }
219
220   /* Search for corresponding intra node function */
221   while ((range_intra_threshold < (mv2_gather_thresholds_table[range].size_intra_table - 1))
222       && (nbytes >
223   mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max)
224   && (mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max !=
225       -1)) {
226       range_intra_threshold++;
227   }
228
229     if (comm->is_blocked() ) {
230         // Set intra-node function pt for gather_two_level
231         MV2_Gather_intra_node_function =
232                               mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].
233                               MV2_pt_Gather_function;
234         //Set inter-leader pt
235         MV2_Gather_inter_leader_function =
236                               mv2_gather_thresholds_table[range].inter_leader[range_threshold].
237                               MV2_pt_Gather_function;
238         // We call Gather function
239         mpi_errno =
240             MV2_Gather_inter_leader_function(sendbuf, sendcnt, sendtype, recvbuf, recvcnt,
241                                              recvtype, root, comm);
242
243     } else {
244   // Indeed, direct (non SMP-aware)gather is MPICH one
245   mpi_errno = gather__mpich(sendbuf, sendcnt, sendtype,
246       recvbuf, recvcnt, recvtype,
247       root, comm);
248   }
249
250   return mpi_errno;
251 }
252
253 int allgatherv__mvapich2(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
254     void *recvbuf, const int *recvcounts, const int *displs,
255     MPI_Datatype recvtype, MPI_Comm  comm )
256 {
257   int mpi_errno = MPI_SUCCESS;
258   int range = 0, comm_size, total_count, recvtype_size, i;
259   int range_threshold = 0;
260   long nbytes = 0;
261
262   if (mv2_allgatherv_thresholds_table == nullptr)
263     init_mv2_allgatherv_tables_stampede();
264
265   comm_size = comm->size();
266   total_count = 0;
267   for (i = 0; i < comm_size; i++)
268     total_count += recvcounts[i];
269
270   recvtype_size=recvtype->size();
271   nbytes = total_count * recvtype_size;
272
273   /* Search for the corresponding system size inside the tuning table */
274   while ((range < (mv2_size_allgatherv_tuning_table - 1)) &&
275       (comm_size > mv2_allgatherv_thresholds_table[range].numproc)) {
276       range++;
277   }
278   /* Search for corresponding inter-leader function */
279   while ((range_threshold < (mv2_allgatherv_thresholds_table[range].size_inter_table - 1))
280       && (nbytes >
281   comm_size * mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max)
282   && (mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max !=
283       -1)) {
284       range_threshold++;
285   }
286   /* Set inter-leader pt */
287   MV2_Allgatherv_function =
288       mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].
289       MV2_pt_Allgatherv_function;
290
291   if (MV2_Allgatherv_function == &MPIR_Allgatherv_Rec_Doubling_MV2)
292     {
293     if (not(comm_size & (comm_size - 1))) {
294       mpi_errno =
295           MPIR_Allgatherv_Rec_Doubling_MV2(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
296         } else {
297             mpi_errno =
298                 MPIR_Allgatherv_Bruck_MV2(sendbuf, sendcount,
299                     sendtype, recvbuf,
300                     recvcounts, displs,
301                     recvtype, comm);
302         }
303     } else {
304         mpi_errno =
305             MV2_Allgatherv_function(sendbuf, sendcount, sendtype,
306                 recvbuf, recvcounts, displs,
307                 recvtype, comm);
308     }
309
310   return mpi_errno;
311 }
312
313
314
315 int allreduce__mvapich2(const void *sendbuf,
316     void *recvbuf,
317     int count,
318     MPI_Datatype datatype,
319     MPI_Op op, MPI_Comm comm)
320 {
321
322   int mpi_errno = MPI_SUCCESS;
323   //int rank = 0,
324   int comm_size = 0;
325
326   comm_size = comm->size();
327   //rank = comm->rank();
328
329   if (count == 0) {
330       return MPI_SUCCESS;
331   }
332
333   if (mv2_allreduce_thresholds_table == nullptr)
334     init_mv2_allreduce_tables_stampede();
335
336   /* check if multiple threads are calling this collective function */
337
338   MPI_Aint sendtype_size = 0;
339   long nbytes = 0;
340   MPI_Aint true_lb, true_extent;
341
342   sendtype_size=datatype->size();
343   nbytes = count * sendtype_size;
344
345   datatype->extent(&true_lb, &true_extent);
346   bool is_commutative = op->is_commutative();
347
348   {
349     int range = 0, range_threshold = 0, range_threshold_intra = 0;
350     bool is_two_level = false;
351
352     /* Search for the corresponding system size inside the tuning table */
353     while ((range < (mv2_size_allreduce_tuning_table - 1)) &&
354         (comm_size > mv2_allreduce_thresholds_table[range].numproc)) {
355         range++;
356     }
357     /* Search for corresponding inter-leader function */
358     /* skip mcast pointers if mcast is not available */
359     if (not mv2_allreduce_thresholds_table[range].mcast_enabled) {
360         while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
361             && ((mv2_allreduce_thresholds_table[range].
362                 inter_leader[range_threshold].MV2_pt_Allreducection
363                 == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2) ||
364                 (mv2_allreduce_thresholds_table[range].
365                     inter_leader[range_threshold].MV2_pt_Allreducection
366                     == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)
367             )) {
368             range_threshold++;
369         }
370     }
371     while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
372         && (nbytes >
373     mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max)
374     && (mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max != -1)) {
375         range_threshold++;
376     }
377     if (mv2_allreduce_thresholds_table[range].is_two_level_allreduce[range_threshold]) {
378       is_two_level = true;
379     }
380     /* Search for corresponding intra-node function */
381     while ((range_threshold_intra <
382         (mv2_allreduce_thresholds_table[range].size_intra_table - 1))
383         && (nbytes >
384     mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max)
385     && (mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max !=
386         -1)) {
387         range_threshold_intra++;
388     }
389
390     MV2_Allreducection = mv2_allreduce_thresholds_table[range].inter_leader[range_threshold]
391                                                                                 .MV2_pt_Allreducection;
392
393     MV2_Allreduce_intra_function = mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra]
394                                                                                     .MV2_pt_Allreducection;
395
396     /* check if mcast is ready, otherwise replace mcast with other algorithm */
397     if((MV2_Allreducection == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2)||
398         (MV2_Allreducection == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)){
399         {
400           MV2_Allreducection = &MPIR_Allreduce_pt2pt_rd_MV2;
401         }
402         if (not is_two_level) {
403             MV2_Allreducection = &MPIR_Allreduce_pt2pt_rd_MV2;
404         }
405     }
406
407     if (is_two_level) {
408       // check if shm is ready, if not use other algorithm first
409       if (is_commutative) {
410           if(comm->get_leaders_comm()==MPI_COMM_NULL){
411             comm->init_smp();
412           }
413           mpi_errno = MPIR_Allreduce_two_level_MV2(sendbuf, recvbuf, count,
414                                                      datatype, op, comm);
415       } else {
416         mpi_errno = MPIR_Allreduce_pt2pt_rd_MV2(sendbuf, recvbuf, count,
417             datatype, op, comm);
418       }
419     } else {
420         mpi_errno = MV2_Allreducection(sendbuf, recvbuf, count,
421             datatype, op, comm);
422     }
423   }
424
425   //comm->ch.intra_node_done=0;
426
427   return (mpi_errno);
428
429
430 }
431
432
433 int alltoallv__mvapich2(const void *sbuf, const int *scounts, const int *sdisps,
434     MPI_Datatype sdtype,
435     void *rbuf, const int *rcounts, const int *rdisps,
436     MPI_Datatype rdtype,
437     MPI_Comm  comm
438 )
439 {
440
441   if (sbuf == MPI_IN_PLACE) {
442       return alltoallv__ompi_basic_linear(sbuf, scounts, sdisps, sdtype,
443                                           rbuf, rcounts, rdisps, rdtype,
444                                           comm);
445   } else     /* For starters, just keep the original algorithm. */
446   return alltoallv__ring(sbuf, scounts, sdisps, sdtype,
447                          rbuf, rcounts, rdisps, rdtype,
448                          comm);
449 }
450
451
452 int barrier__mvapich2(MPI_Comm  comm)
453 {
454   return barrier__mvapich2_pair(comm);
455 }
456
457
458
459
460 int bcast__mvapich2(void *buffer,
461                     int count,
462                     MPI_Datatype datatype,
463                     int root, MPI_Comm comm)
464 {
465     int mpi_errno = MPI_SUCCESS;
466     int comm_size/*, rank*/;
467     bool two_level_bcast      = true;
468     long nbytes = 0;
469     int range = 0;
470     int range_threshold = 0;
471     int range_threshold_intra = 0;
472     MPI_Aint type_size;
473     //, position;
474     // unsigned char *tmp_buf = NULL;
475     MPI_Comm shmem_comm;
476     //MPID_Datatype *dtp;
477
478     if (count == 0)
479         return MPI_SUCCESS;
480     if(comm->get_leaders_comm()==MPI_COMM_NULL){
481       comm->init_smp();
482     }
483     if (not mv2_bcast_thresholds_table)
484       init_mv2_bcast_tables_stampede();
485     comm_size = comm->size();
486     //rank = comm->rank();
487
488     // bool is_contig = true;
489 /*    if (HANDLE_GET_KIND(datatype) == HANDLE_KIND_BUILTIN)*/
490 /*        is_contig = true;*/
491 /*    else {*/
492 /*        MPID_Datatype_get_ptr(datatype, dtp);*/
493 /*        is_contig = dtp->is_contig;*/
494 /*    }*/
495
496     // bool is_homogeneous = true;
497
498     /* MPI_Type_size() might not give the accurate size of the packed
499      * datatype for heterogeneous systems (because of padding, encoding,
500      * etc). On the other hand, MPI_Pack_size() can become very
501      * expensive, depending on the implementation, especially for
502      * heterogeneous systems. We want to use MPI_Type_size() wherever
503      * possible, and MPI_Pack_size() in other places.
504      */
505     //if (is_homogeneous) {
506         type_size=datatype->size();
507
508    /* } else {
509         MPIR_Pack_size_impl(1, datatype, &type_size);
510     }*/
511     nbytes =  (count) * (type_size);
512
513     /* Search for the corresponding system size inside the tuning table */
514     while ((range < (mv2_size_bcast_tuning_table - 1)) &&
515            (comm_size > mv2_bcast_thresholds_table[range].numproc)) {
516         range++;
517     }
518     /* Search for corresponding inter-leader function */
519     while ((range_threshold < (mv2_bcast_thresholds_table[range].size_inter_table - 1))
520            && (nbytes >
521                mv2_bcast_thresholds_table[range].inter_leader[range_threshold].max)
522            && (mv2_bcast_thresholds_table[range].inter_leader[range_threshold].max != -1)) {
523         range_threshold++;
524     }
525
526     /* Search for corresponding intra-node function */
527     while ((range_threshold_intra <
528             (mv2_bcast_thresholds_table[range].size_intra_table - 1))
529            && (nbytes >
530                mv2_bcast_thresholds_table[range].intra_node[range_threshold_intra].max)
531            && (mv2_bcast_thresholds_table[range].intra_node[range_threshold_intra].max !=
532                -1)) {
533         range_threshold_intra++;
534     }
535
536     MV2_Bcast_function =
537         mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
538         MV2_pt_Bcast_function;
539
540     MV2_Bcast_intra_node_function =
541         mv2_bcast_thresholds_table[range].
542         intra_node[range_threshold_intra].MV2_pt_Bcast_function;
543
544 /*    if (mv2_user_bcast_intra == NULL && */
545 /*            MV2_Bcast_intra_node_function == &MPIR_Knomial_Bcast_intra_node_MV2) {*/
546 /*            MV2_Bcast_intra_node_function = &MPIR_Shmem_Bcast_MV2;*/
547 /*    }*/
548
549     if (mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
550         zcpy_pipelined_knomial_factor != -1) {
551         zcpy_knomial_factor =
552             mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
553             zcpy_pipelined_knomial_factor;
554     }
555
556     if (mv2_pipelined_zcpy_knomial_factor != -1) {
557         zcpy_knomial_factor = mv2_pipelined_zcpy_knomial_factor;
558     }
559
560     if (MV2_Bcast_intra_node_function == nullptr) {
561       /* if tuning table do not have any intra selection, set func pointer to
562       ** default one for mcast intra node */
563       MV2_Bcast_intra_node_function = &MPIR_Shmem_Bcast_MV2;
564     }
565
566     /* Set value of pipeline segment size */
567     bcast_segment_size = mv2_bcast_thresholds_table[range].bcast_segment_size;
568
569     /* Set value of inter node knomial factor */
570     mv2_inter_node_knomial_factor = mv2_bcast_thresholds_table[range].inter_node_knomial_factor;
571
572     /* Set value of intra node knomial factor */
573     mv2_intra_node_knomial_factor = mv2_bcast_thresholds_table[range].intra_node_knomial_factor;
574
575     /* Check if we will use a two level algorithm or not */
576     two_level_bcast =
577 #if defined(_MCST_SUPPORT_)
578         mv2_bcast_thresholds_table[range].is_two_level_bcast[range_threshold]
579         || comm->ch.is_mcast_ok;
580 #else
581         mv2_bcast_thresholds_table[range].is_two_level_bcast[range_threshold];
582 #endif
583     if (two_level_bcast) {
584        // if (not is_contig || not is_homogeneous) {
585 //   tmp_buf = smpi_get_tmp_sendbuffer(nbytes);
586
587 /*            position = 0;*/
588 /*            if (rank == root) {*/
589 /*                mpi_errno =*/
590 /*                    MPIR_Pack_impl(buffer, count, datatype, tmp_buf, nbytes, &position);*/
591 /*                if (mpi_errno)*/
592 /*                    MPIU_ERR_POP(mpi_errno);*/
593 /*            }*/
594 // }
595 #ifdef CHANNEL_MRAIL_GEN2
596         if ((mv2_enable_zcpy_bcast == 1) &&
597               (&MPIR_Pipelined_Bcast_Zcpy_MV2 == MV2_Bcast_function)) {
598           // if (not is_contig || not is_homogeneous) {
599           //   mpi_errno = MPIR_Pipelined_Bcast_Zcpy_MV2(tmp_buf, nbytes, MPI_BYTE, root, comm);
600           // } else {
601                 mpi_errno = MPIR_Pipelined_Bcast_Zcpy_MV2(buffer, count, datatype,
602                                                  root, comm);
603           // }
604         } else
605 #endif /* defined(CHANNEL_MRAIL_GEN2) */
606         {
607             shmem_comm = comm->get_intra_comm();
608             // if (not is_contig || not is_homogeneous) {
609             //   MPIR_Bcast_tune_inter_node_helper_MV2(tmp_buf, nbytes, MPI_BYTE, root, comm);
610             // } else {
611               MPIR_Bcast_tune_inter_node_helper_MV2(buffer, count, datatype, root, comm);
612             // }
613
614             /* We are now done with the inter-node phase */
615
616
617                     root = INTRA_NODE_ROOT;
618
619                     // if (not is_contig || not is_homogeneous) {
620                     //       mpi_errno = MV2_Bcast_intra_node_function(tmp_buf, nbytes, MPI_BYTE, root, shmem_comm);
621                     // } else {
622                     mpi_errno = MV2_Bcast_intra_node_function(buffer, count,
623                                                               datatype, root, shmem_comm);
624
625                     // }
626         }
627         /*        if (not is_contig || not is_homogeneous) {*/
628         /*            if (rank != root) {*/
629         /*                position = 0;*/
630         /*                mpi_errno = MPIR_Unpack_impl(tmp_buf, nbytes, &position, buffer,*/
631         /*                                             count, datatype);*/
632         /*            }*/
633         /*        }*/
634     } else {
635         /* We use Knomial for intra node */
636         MV2_Bcast_intra_node_function = &MPIR_Knomial_Bcast_intra_node_MV2;
637 /*        if (mv2_enable_shmem_bcast == 0) {*/
638             /* Fall back to non-tuned version */
639 /*            MPIR_Bcast_intra_MV2(buffer, count, datatype, root, comm);*/
640 /*        } else {*/
641             mpi_errno = MV2_Bcast_function(buffer, count, datatype, root,
642                                            comm);
643
644 /*        }*/
645     }
646
647
648     return mpi_errno;
649
650 }
651
652
653
654 int reduce__mvapich2(const void *sendbuf,
655     void *recvbuf,
656     int count,
657     MPI_Datatype datatype,
658     MPI_Op op, int root, MPI_Comm comm)
659 {
660   if (mv2_reduce_thresholds_table == nullptr)
661     init_mv2_reduce_tables_stampede();
662
663   int mpi_errno = MPI_SUCCESS;
664   int range = 0;
665   int range_threshold = 0;
666   int range_intra_threshold = 0;
667   int pof2;
668   int comm_size = 0;
669   long nbytes = 0;
670   int sendtype_size;
671   bool is_two_level = false;
672
673   comm_size = comm->size();
674   sendtype_size=datatype->size();
675   nbytes = count * sendtype_size;
676
677   if (count == 0)
678     return MPI_SUCCESS;
679
680   bool is_commutative = (op == MPI_OP_NULL || op->is_commutative());
681
682   /* find nearest power-of-two less than or equal to comm_size */
683   for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
684   pof2 >>=1;
685
686
687   /* Search for the corresponding system size inside the tuning table */
688   while ((range < (mv2_size_reduce_tuning_table - 1)) &&
689       (comm_size > mv2_reduce_thresholds_table[range].numproc)) {
690       range++;
691   }
692   /* Search for corresponding inter-leader function */
693   while ((range_threshold < (mv2_reduce_thresholds_table[range].size_inter_table - 1))
694       && (nbytes >
695   mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max)
696   && (mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max !=
697       -1)) {
698       range_threshold++;
699   }
700
701   /* Search for corresponding intra node function */
702   while ((range_intra_threshold < (mv2_reduce_thresholds_table[range].size_intra_table - 1))
703       && (nbytes >
704   mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max)
705   && (mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max !=
706       -1)) {
707       range_intra_threshold++;
708   }
709
710   /* Set intra-node function pt for reduce_two_level */
711   MV2_Reduce_intra_function =
712       mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].
713       MV2_pt_Reduce_function;
714   /* Set inter-leader pt */
715   MV2_Reduce_function =
716       mv2_reduce_thresholds_table[range].inter_leader[range_threshold].
717       MV2_pt_Reduce_function;
718
719   if(mv2_reduce_intra_knomial_factor<0)
720     {
721       mv2_reduce_intra_knomial_factor = mv2_reduce_thresholds_table[range].intra_k_degree;
722     }
723   if(mv2_reduce_inter_knomial_factor<0)
724     {
725       mv2_reduce_inter_knomial_factor = mv2_reduce_thresholds_table[range].inter_k_degree;
726     }
727   if (mv2_reduce_thresholds_table[range].is_two_level_reduce[range_threshold]) {
728     is_two_level = true;
729   }
730   /* We call Reduce function */
731   if (is_two_level) {
732     if (is_commutative) {
733          if(comm->get_leaders_comm()==MPI_COMM_NULL){
734            comm->init_smp();
735          }
736          mpi_errno = MPIR_Reduce_two_level_helper_MV2(sendbuf, recvbuf, count,
737                                            datatype, op, root, comm);
738     } else {
739       mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
740           datatype, op, root, comm);
741     }
742     } else if(MV2_Reduce_function == &MPIR_Reduce_inter_knomial_wrapper_MV2 ){
743         if (is_commutative)
744           {
745             mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count,
746                 datatype, op, root, comm);
747           } else {
748               mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
749                   datatype, op, root, comm);
750           }
751     } else if(MV2_Reduce_function == &MPIR_Reduce_redscat_gather_MV2){
752         if (/*(HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) &&*/ (count >= pof2))
753           {
754             mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count,
755                 datatype, op, root, comm);
756           } else {
757               mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
758                   datatype, op, root, comm);
759           }
760     } else {
761         mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count,
762             datatype, op, root, comm);
763     }
764
765
766   return mpi_errno;
767
768 }
769
770
771 int reduce_scatter__mvapich2(const void *sendbuf, void *recvbuf, const int *recvcnts,
772     MPI_Datatype datatype, MPI_Op op,
773     MPI_Comm comm)
774 {
775   int mpi_errno = MPI_SUCCESS;
776   int i = 0, comm_size = comm->size(), total_count = 0, type_size =
777       0, nbytes = 0;
778   int* disps          = new int[comm_size];
779
780   if (mv2_red_scat_thresholds_table == nullptr)
781     init_mv2_reduce_scatter_tables_stampede();
782
783   bool is_commutative = (op == MPI_OP_NULL || op->is_commutative());
784   for (i = 0; i < comm_size; i++) {
785       disps[i] = total_count;
786       total_count += recvcnts[i];
787   }
788
789   type_size=datatype->size();
790   nbytes = total_count * type_size;
791
792   if (is_commutative) {
793     int range           = 0;
794     int range_threshold = 0;
795
796       /* Search for the corresponding system size inside the tuning table */
797       while ((range < (mv2_size_red_scat_tuning_table - 1)) &&
798           (comm_size > mv2_red_scat_thresholds_table[range].numproc)) {
799           range++;
800       }
801       /* Search for corresponding inter-leader function */
802       while ((range_threshold < (mv2_red_scat_thresholds_table[range].size_inter_table - 1))
803           && (nbytes >
804       mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max)
805       && (mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max !=
806           -1)) {
807           range_threshold++;
808       }
809
810       /* Set inter-leader pt */
811       MV2_Red_scat_function =
812           mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].
813           MV2_pt_Red_scat_function;
814
815       mpi_errno = MV2_Red_scat_function(sendbuf, recvbuf,
816           recvcnts, datatype,
817           op, comm);
818   } else {
819       bool is_block_regular = true;
820       for (i = 0; i < (comm_size - 1); ++i) {
821           if (recvcnts[i] != recvcnts[i+1]) {
822               is_block_regular = false;
823               break;
824           }
825       }
826       int pof2 = 1;
827       while (pof2 < comm_size) pof2 <<= 1;
828       if (pof2 == comm_size && is_block_regular) {
829           /* noncommutative, pof2 size, and block regular */
830           MPIR_Reduce_scatter_non_comm_MV2(sendbuf, recvbuf,
831               recvcnts, datatype,
832               op, comm);
833       }
834       mpi_errno =  reduce_scatter__mpich_rdb(sendbuf, recvbuf,
835                                              recvcnts, datatype,
836                                              op, comm);
837   }
838   delete[] disps;
839   return mpi_errno;
840
841 }
842
843
844
845 int scatter__mvapich2(const void *sendbuf,
846     int sendcnt,
847     MPI_Datatype sendtype,
848     void *recvbuf,
849     int recvcnt,
850     MPI_Datatype recvtype,
851     int root, MPI_Comm comm)
852 {
853   int range = 0, range_threshold = 0, range_threshold_intra = 0;
854   int mpi_errno = MPI_SUCCESS;
855   //   int mpi_errno_ret = MPI_SUCCESS;
856   int rank, nbytes, comm_size;
857   bool partial_sub_ok = false;
858   int conf_index = 0;
859      MPI_Comm shmem_comm;
860   //    MPID_Comm *shmem_commptr=NULL;
861      if (mv2_scatter_thresholds_table == nullptr)
862        init_mv2_scatter_tables_stampede();
863
864      if (comm->get_leaders_comm() == MPI_COMM_NULL) {
865        comm->init_smp();
866      }
867
868   comm_size = comm->size();
869
870   rank = comm->rank();
871
872   if (rank == root) {
873     int sendtype_size = sendtype->size();
874     nbytes            = sendcnt * sendtype_size;
875   } else {
876     int recvtype_size = recvtype->size();
877     nbytes            = recvcnt * recvtype_size;
878   }
879
880     // check if safe to use partial subscription mode
881     if (comm->is_uniform()) {
882
883         shmem_comm = comm->get_intra_comm();
884         if (mv2_scatter_table_ppn_conf[0] == -1) {
885             // Indicating user defined tuning
886             conf_index = 0;
887         }else{
888           int local_size = shmem_comm->size();
889           int i          = 0;
890             do {
891                 if (local_size == mv2_scatter_table_ppn_conf[i]) {
892                     conf_index = i;
893                     partial_sub_ok = true;
894                     break;
895                 }
896                 i++;
897             } while(i < mv2_scatter_num_ppn_conf);
898         }
899     }
900
901   if (not partial_sub_ok) {
902       conf_index = 0;
903   }
904
905   /* Search for the corresponding system size inside the tuning table */
906   while ((range < (mv2_size_scatter_tuning_table[conf_index] - 1)) &&
907       (comm_size > mv2_scatter_thresholds_table[conf_index][range].numproc)) {
908       range++;
909   }
910   /* Search for corresponding inter-leader function */
911   while ((range_threshold < (mv2_scatter_thresholds_table[conf_index][range].size_inter_table - 1))
912       && (nbytes >
913   mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
914   && (mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max != -1)) {
915       range_threshold++;
916   }
917
918   /* Search for corresponding intra-node function */
919   while ((range_threshold_intra <
920       (mv2_scatter_thresholds_table[conf_index][range].size_intra_table - 1))
921       && (nbytes >
922   mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max)
923   && (mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max !=
924       -1)) {
925       range_threshold_intra++;
926   }
927
928   MV2_Scatter_function = mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold]
929                                                                                       .MV2_pt_Scatter_function;
930
931   if(MV2_Scatter_function == &MPIR_Scatter_mcst_wrap_MV2) {
932 #if defined(_MCST_SUPPORT_)
933       if(comm->ch.is_mcast_ok == 1
934           && mv2_use_mcast_scatter == 1
935           && comm->ch.shmem_coll_ok == 1) {
936           MV2_Scatter_function = &MPIR_Scatter_mcst_MV2;
937       } else
938 #endif /*#if defined(_MCST_SUPPORT_) */
939         {
940         if (mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1].MV2_pt_Scatter_function !=
941             nullptr) {
942           MV2_Scatter_function =
943               mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1].MV2_pt_Scatter_function;
944         } else {
945           /* Fallback! */
946           MV2_Scatter_function = &MPIR_Scatter_MV2_Binomial;
947         }
948         }
949   }
950
951   if( (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Direct) ||
952       (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Binomial)) {
953        if( comm->is_blocked()) {
954              MV2_Scatter_intra_function = mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra]
955                                 .MV2_pt_Scatter_function;
956
957              mpi_errno =
958                    MV2_Scatter_function(sendbuf, sendcnt, sendtype,
959                                         recvbuf, recvcnt, recvtype, root,
960                                         comm);
961          } else {
962       mpi_errno = MPIR_Scatter_MV2_Binomial(sendbuf, sendcnt, sendtype,
963           recvbuf, recvcnt, recvtype, root,
964           comm);
965
966       }
967   } else {
968       mpi_errno = MV2_Scatter_function(sendbuf, sendcnt, sendtype,
969           recvbuf, recvcnt, recvtype, root,
970           comm);
971   }
972   return (mpi_errno);
973 }
974
975 } // namespace simgrid::smpi
976
977 void smpi_coll_cleanup_mvapich2()
978 {
979   if (mv2_alltoall_thresholds_table)
980     delete[] mv2_alltoall_thresholds_table[0];
981   delete[] mv2_alltoall_thresholds_table;
982   delete[] mv2_size_alltoall_tuning_table;
983   delete[] mv2_alltoall_table_ppn_conf;
984
985   delete[] mv2_gather_thresholds_table;
986   if (mv2_allgather_thresholds_table)
987     delete[] mv2_allgather_thresholds_table[0];
988   delete[] mv2_size_allgather_tuning_table;
989   delete[] mv2_allgather_table_ppn_conf;
990   delete[] mv2_allgather_thresholds_table;
991
992   delete[] mv2_allgatherv_thresholds_table;
993   delete[] mv2_reduce_thresholds_table;
994   delete[] mv2_red_scat_thresholds_table;
995   delete[] mv2_allreduce_thresholds_table;
996   delete[] mv2_bcast_thresholds_table;
997   if (mv2_scatter_thresholds_table)
998     delete[] mv2_scatter_thresholds_table[0];
999   delete[] mv2_scatter_thresholds_table;
1000   delete[] mv2_size_scatter_tuning_table;
1001   delete[] mv2_scatter_table_ppn_conf;
1002 }