Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
6f448ae7d4411a20746daac0c9399134fb37b628
[simgrid.git] / src / smpi / colls / smpi_openmpi_selector.cpp
1 /* selector for collective algorithms based on openmpi's default coll_tuned_decision_fixed selector
2  * Updated 02/2022                                                          */
3
4 /* Copyright (c) 2009-2022. The SimGrid Team.
5  * All rights reserved.                                                     */
6
7 /* This program is free software; you can redistribute it and/or modify it
8  * under the terms of the license (GNU LGPL) which comes with this package. */
9
10 #include "colls_private.hpp"
11
12 #include <memory>
13
14 /* FIXME
15 add algos:
16 allreduce nonoverlapping, basic linear
17 alltoall linear_sync
18 bcast chain
19 */
20
21 namespace simgrid {
22 namespace smpi {
23
24 int allreduce__ompi(const void *sbuf, void *rbuf, int count,
25                     MPI_Datatype dtype, MPI_Op op, MPI_Comm comm)
26 {
27     size_t total_dsize = dtype->size() * (ptrdiff_t)count;
28     int communicator_size = comm->size();
29     int alg = 1;
30     int(*funcs[]) (const void*, void*, int, MPI_Datatype, MPI_Op, MPI_Comm)={
31         &allreduce__redbcast,
32         &allreduce__redbcast,
33         &allreduce__rdb,
34         &allreduce__lr,
35         &allreduce__ompi_ring_segmented,
36         &allreduce__rab_rdb
37     };
38
39     /** Algorithms:
40      *  {1, "basic_linear"},
41      *  {2, "nonoverlapping"},
42      *  {3, "recursive_doubling"},
43      *  {4, "ring"},
44      *  {5, "segmented_ring"},
45      *  {6, "rabenseifner"
46      *
47      * Currently, ring, segmented ring, and rabenseifner do not support
48      * non-commutative operations.
49      */
50     if ((op != MPI_OP_NULL) && not op->is_commutative()) {
51         if (communicator_size < 4) {
52             if (total_dsize < 131072) {
53                 alg = 3;
54             } else {
55                 alg = 1;
56             }
57         } else if (communicator_size < 8) {
58             alg = 3;
59         } else if (communicator_size < 16) {
60             if (total_dsize < 1048576) {
61                 alg = 3;
62             } else {
63                 alg = 2;
64             }
65         } else if (communicator_size < 128) {
66             alg = 3;
67         } else if (communicator_size < 256) {
68             if (total_dsize < 131072) {
69                 alg = 2;
70             } else if (total_dsize < 524288) {
71                 alg = 3;
72             } else {
73                 alg = 2;
74             }
75         } else if (communicator_size < 512) {
76             if (total_dsize < 4096) {
77                 alg = 2;
78             } else if (total_dsize < 524288) {
79                 alg = 3;
80             } else {
81                 alg = 2;
82             }
83         } else {
84             if (total_dsize < 2048) {
85                 alg = 2;
86             } else {
87                 alg = 3;
88             }
89         }
90     } else {
91         if (communicator_size < 4) {
92             if (total_dsize < 8) {
93                 alg = 4;
94             } else if (total_dsize < 4096) {
95                 alg = 3;
96             } else if (total_dsize < 8192) {
97                 alg = 4;
98             } else if (total_dsize < 16384) {
99                 alg = 3;
100             } else if (total_dsize < 65536) {
101                 alg = 4;
102             } else if (total_dsize < 262144) {
103                 alg = 5;
104             } else {
105                 alg = 6;
106             }
107         } else if (communicator_size < 8) {
108             if (total_dsize < 16) {
109                 alg = 4;
110             } else if (total_dsize < 8192) {
111                 alg = 3;
112             } else {
113                 alg = 6;
114             }
115         } else if (communicator_size < 16) {
116             if (total_dsize < 8192) {
117                 alg = 3;
118             } else {
119                 alg = 6;
120             }
121         } else if (communicator_size < 32) {
122             if (total_dsize < 64) {
123                 alg = 5;
124             } else if (total_dsize < 4096) {
125                 alg = 3;
126             } else {
127                 alg = 6;
128             }
129         } else if (communicator_size < 64) {
130             if (total_dsize < 128) {
131                 alg = 5;
132             } else {
133                 alg = 6;
134             }
135         } else if (communicator_size < 128) {
136             if (total_dsize < 262144) {
137                 alg = 3;
138             } else {
139                 alg = 6;
140             }
141         } else if (communicator_size < 256) {
142             if (total_dsize < 131072) {
143                 alg = 2;
144             } else if (total_dsize < 262144) {
145                 alg = 3;
146             } else {
147                 alg = 6;
148             }
149         } else if (communicator_size < 512) {
150             if (total_dsize < 4096) {
151                 alg = 2;
152             } else {
153                 alg = 6;
154             }
155         } else if (communicator_size < 2048) {
156             if (total_dsize < 2048) {
157                 alg = 2;
158             } else if (total_dsize < 16384) {
159                 alg = 3;
160             } else {
161                 alg = 6;
162             }
163         } else if (communicator_size < 4096) {
164             if (total_dsize < 2048) {
165                 alg = 2;
166             } else if (total_dsize < 4096) {
167                 alg = 5;
168             } else if (total_dsize < 16384) {
169                 alg = 3;
170             } else {
171                 alg = 6;
172             }
173         } else {
174             if (total_dsize < 2048) {
175                 alg = 2;
176             } else if (total_dsize < 16384) {
177                 alg = 5;
178             } else if (total_dsize < 32768) {
179                 alg = 3;
180             } else {
181                 alg = 6;
182             }
183         }
184     }
185     return funcs[alg-1](sbuf, rbuf, count, dtype, op, comm);
186 }
187
188
189
190 int alltoall__ompi(const void *sbuf, int scount,
191                    MPI_Datatype sdtype,
192                    void* rbuf, int rcount,
193                    MPI_Datatype rdtype,
194                    MPI_Comm comm)
195 {
196     int alg = 1;
197     size_t dsize, total_dsize;
198     int communicator_size = comm->size();
199
200     if (MPI_IN_PLACE != sbuf) {
201         dsize = sdtype->size();
202     } else {
203         dsize = rdtype->size();
204     }
205     total_dsize = dsize * (ptrdiff_t)scount;
206     int (*funcs[])(const void *, int, MPI_Datatype, void*, int, MPI_Datatype, MPI_Comm) = {
207         &alltoall__basic_linear,
208         &alltoall__pair,
209         &alltoall__bruck,
210         &alltoall__basic_linear,
211         &alltoall__basic_linear
212     };
213     /** Algorithms:
214      *  {1, "linear"},
215      *  {2, "pairwise"},
216      *  {3, "modified_bruck"},
217      *  {4, "linear_sync"},
218      *  {5, "two_proc"},
219      */
220     if (communicator_size == 2) {
221         if (total_dsize < 2) {
222             alg = 2;
223         } else if (total_dsize < 4) {
224             alg = 5;
225         } else if (total_dsize < 16) {
226             alg = 2;
227         } else if (total_dsize < 64) {
228             alg = 5;
229         } else if (total_dsize < 256) {
230             alg = 2;
231         } else if (total_dsize < 4096) {
232             alg = 5;
233         } else if (total_dsize < 32768) {
234             alg = 2;
235         } else if (total_dsize < 262144) {
236             alg = 4;
237         } else if (total_dsize < 1048576) {
238             alg = 5;
239         } else {
240             alg = 2;
241         }
242     } else if (communicator_size < 8) {
243         if (total_dsize < 8192) {
244             alg = 4;
245         } else if (total_dsize < 16384) {
246             alg = 1;
247         } else if (total_dsize < 65536) {
248             alg = 4;
249         } else if (total_dsize < 524288) {
250             alg = 1;
251         } else if (total_dsize < 1048576) {
252             alg = 2;
253         } else {
254             alg = 1;
255         }
256     } else if (communicator_size < 16) {
257         if (total_dsize < 262144) {
258             alg = 4;
259         } else {
260             alg = 1;
261         }
262     } else if (communicator_size < 32) {
263         if (total_dsize < 4) {
264             alg = 4;
265         } else if (total_dsize < 512) {
266             alg = 3;
267         } else if (total_dsize < 8192) {
268             alg = 4;
269         } else if (total_dsize < 32768) {
270             alg = 1;
271         } else if (total_dsize < 262144) {
272             alg = 4;
273         } else if (total_dsize < 524288) {
274             alg = 1;
275         } else {
276             alg = 4;
277         }
278     } else if (communicator_size < 64) {
279         if (total_dsize < 512) {
280             alg = 3;
281         } else if (total_dsize < 524288) {
282             alg = 1;
283         } else {
284             alg = 4;
285         }
286     } else if (communicator_size < 128) {
287         if (total_dsize < 1024) {
288             alg = 3;
289         } else if (total_dsize < 2048) {
290             alg = 1;
291         } else if (total_dsize < 4096) {
292             alg = 4;
293         } else if (total_dsize < 262144) {
294             alg = 1;
295         } else {
296             alg = 2;
297         }
298     } else if (communicator_size < 256) {
299         if (total_dsize < 1024) {
300             alg = 3;
301         } else if (total_dsize < 2048) {
302             alg = 4;
303         } else if (total_dsize < 262144) {
304             alg = 1;
305         } else {
306             alg = 2;
307         }
308     } else if (communicator_size < 512) {
309         if (total_dsize < 1024) {
310             alg = 3;
311         } else if (total_dsize < 8192) {
312             alg = 4;
313         } else if (total_dsize < 32768) {
314             alg = 1;
315         } else {
316             alg = 2;
317         }
318     } else if (communicator_size < 1024) {
319         if (total_dsize < 512) {
320             alg = 3;
321         } else if (total_dsize < 8192) {
322             alg = 4;
323         } else if (total_dsize < 16384) {
324             alg = 1;
325         } else if (total_dsize < 131072) {
326             alg = 4;
327         } else if (total_dsize < 262144) {
328             alg = 1;
329         } else {
330             alg = 2;
331         }
332     } else if (communicator_size < 2048) {
333         if (total_dsize < 512) {
334             alg = 3;
335         } else if (total_dsize < 1024) {
336             alg = 4;
337         } else if (total_dsize < 2048) {
338             alg = 1;
339         } else if (total_dsize < 16384) {
340             alg = 4;
341         } else if (total_dsize < 262144) {
342             alg = 1;
343         } else {
344             alg = 4;
345         }
346     } else if (communicator_size < 4096) {
347         if (total_dsize < 1024) {
348             alg = 3;
349         } else if (total_dsize < 4096) {
350             alg = 4;
351         } else if (total_dsize < 8192) {
352             alg = 1;
353         } else if (total_dsize < 131072) {
354             alg = 4;
355         } else {
356             alg = 1;
357         }
358     } else {
359         if (total_dsize < 2048) {
360             alg = 3;
361         } else if (total_dsize < 8192) {
362             alg = 4;
363         } else if (total_dsize < 16384) {
364             alg = 1;
365         } else if (total_dsize < 32768) {
366             alg = 4;
367         } else if (total_dsize < 65536) {
368             alg = 1;
369         } else {
370             alg = 4;
371         }
372     }
373
374     return funcs[alg-1](sbuf, scount, sdtype,
375                           rbuf, rcount, rdtype, comm);
376 }
377
378 int alltoallv__ompi(const void *sbuf, const int *scounts, const int *sdisps,
379                     MPI_Datatype sdtype,
380                     void *rbuf, const int *rcounts, const int *rdisps,
381                     MPI_Datatype rdtype,
382                     MPI_Comm  comm
383                     )
384 {
385     int communicator_size = comm->size();
386     int alg = 1;
387     int (*funcs[])(const void *, const int*, const int*, MPI_Datatype, void*, const int*, const int*, MPI_Datatype, MPI_Comm) = {
388         &alltoallv__ompi_basic_linear,
389         &alltoallv__pair
390     };
391    /** Algorithms:
392      *  {1, "basic_linear"},
393      *  {2, "pairwise"},
394      *
395      * We can only optimize based on com size
396      */
397     if (communicator_size < 4) {
398         alg = 2;
399     } else if (communicator_size < 64) {
400         alg = 1;
401     } else if (communicator_size < 128) {
402         alg = 2;
403     } else if (communicator_size < 256) {
404         alg = 1;
405     } else if (communicator_size < 1024) {
406         alg = 2;
407     } else {
408         alg = 1;
409     }
410     return funcs[alg-1](sbuf, scounts, sdisps, sdtype,
411                            rbuf, rcounts, rdisps,rdtype,
412                            comm);
413 }
414
415 int barrier__ompi(MPI_Comm  comm)
416 {
417     int communicator_size = comm->size();
418     int alg = 1;
419     int (*funcs[])(MPI_Comm) = {
420         &barrier__ompi_basic_linear,
421         &barrier__ompi_basic_linear,
422         &barrier__ompi_recursivedoubling,
423         &barrier__ompi_bruck,
424         &barrier__ompi_two_procs,
425         &barrier__ompi_tree
426     };
427     /** Algorithms:
428      *  {1, "linear"},
429      *  {2, "double_ring"},
430      *  {3, "recursive_doubling"},
431      *  {4, "bruck"},
432      *  {5, "two_proc"},
433      *  {6, "tree"},
434      *
435      * We can only optimize based on com size
436      */
437     if (communicator_size < 4) {
438         alg = 3;
439     } else if (communicator_size < 8) {
440         alg = 1;
441     } else if (communicator_size < 64) {
442         alg = 3;
443     } else if (communicator_size < 256) {
444         alg = 4;
445     } else if (communicator_size < 512) {
446         alg = 6;
447     } else if (communicator_size < 1024) {
448         alg = 4;
449     } else if (communicator_size < 4096) {
450         alg = 6;
451     } else {
452         alg = 4;
453     }
454
455     return funcs[alg-1](comm);
456 }
457
458 int bcast__ompi(void *buff, int count, MPI_Datatype datatype, int root, MPI_Comm  comm)
459 {
460     int alg = 1;
461     size_t total_dsize, dsize;
462
463     int communicator_size = comm->size();
464
465     dsize = datatype->size();
466     total_dsize = dsize * (unsigned long)count;
467     int (*funcs[])(void*, int, MPI_Datatype, int, MPI_Comm) = {
468         &bcast__NTSL,
469         &bcast__ompi_pipeline,
470         &bcast__ompi_pipeline,
471         &bcast__ompi_split_bintree,
472         &bcast__NTSB,
473         &bcast__binomial_tree,
474         &bcast__mvapich2_knomial_intra_node,
475         &bcast__scatter_rdb_allgather,
476         &bcast__scatter_LR_allgather,
477     };
478     /** Algorithms:
479      *  {1, "basic_linear"},
480      *  {2, "chain"},
481      *  {3, "pipeline"},
482      *  {4, "split_binary_tree"},
483      *  {5, "binary_tree"},
484      *  {6, "binomial"},
485      *  {7, "knomial"},
486      *  {8, "scatter_allgather"},
487      *  {9, "scatter_allgather_ring"},
488      */
489     if (communicator_size < 4) {
490         if (total_dsize < 32) {
491             alg = 3;
492         } else if (total_dsize < 256) {
493             alg = 5;
494         } else if (total_dsize < 512) {
495             alg = 3;
496         } else if (total_dsize < 1024) {
497             alg = 7;
498         } else if (total_dsize < 32768) {
499             alg = 1;
500         } else if (total_dsize < 131072) {
501             alg = 5;
502         } else if (total_dsize < 262144) {
503             alg = 2;
504         } else if (total_dsize < 524288) {
505             alg = 1;
506         } else if (total_dsize < 1048576) {
507             alg = 6;
508         } else {
509             alg = 5;
510         }
511     } else if (communicator_size < 8) {
512         if (total_dsize < 64) {
513             alg = 5;
514         } else if (total_dsize < 128) {
515             alg = 6;
516         } else if (total_dsize < 2048) {
517             alg = 5;
518         } else if (total_dsize < 8192) {
519             alg = 6;
520         } else if (total_dsize < 1048576) {
521             alg = 1;
522         } else {
523             alg = 2;
524         }
525     } else if (communicator_size < 16) {
526         if (total_dsize < 8) {
527             alg = 7;
528         } else if (total_dsize < 64) {
529             alg = 5;
530         } else if (total_dsize < 4096) {
531             alg = 7;
532         } else if (total_dsize < 16384) {
533             alg = 5;
534         } else if (total_dsize < 32768) {
535             alg = 6;
536         } else {
537             alg = 1;
538         }
539     } else if (communicator_size < 32) {
540         if (total_dsize < 4096) {
541             alg = 7;
542         } else if (total_dsize < 1048576) {
543             alg = 6;
544         } else {
545             alg = 8;
546         }
547     } else if (communicator_size < 64) {
548         if (total_dsize < 2048) {
549             alg = 6;
550         } else {
551             alg = 7;
552         }
553     } else if (communicator_size < 128) {
554         alg = 7;
555     } else if (communicator_size < 256) {
556         if (total_dsize < 2) {
557             alg = 6;
558         } else if (total_dsize < 16384) {
559             alg = 5;
560         } else if (total_dsize < 32768) {
561             alg = 1;
562         } else if (total_dsize < 65536) {
563             alg = 5;
564         } else {
565             alg = 7;
566         }
567     } else if (communicator_size < 1024) {
568         if (total_dsize < 16384) {
569             alg = 7;
570         } else if (total_dsize < 32768) {
571             alg = 4;
572         } else {
573             alg = 7;
574         }
575     } else if (communicator_size < 2048) {
576         if (total_dsize < 524288) {
577             alg = 7;
578         } else {
579             alg = 8;
580         }
581     } else if (communicator_size < 4096) {
582         if (total_dsize < 262144) {
583             alg = 7;
584         } else {
585             alg = 8;
586         }
587     } else {
588         if (total_dsize < 8192) {
589             alg = 7;
590         } else if (total_dsize < 16384) {
591             alg = 5;
592         } else if (total_dsize < 262144) {
593             alg = 7;
594         } else {
595             alg = 8;
596         }
597     }
598     return funcs[alg-1](buff, count, datatype, root, comm);
599 }
600
601 int reduce__ompi(const void *sendbuf, void *recvbuf,
602                  int count, MPI_Datatype  datatype,
603                  MPI_Op   op, int root,
604                  MPI_Comm   comm)
605 {
606     size_t total_dsize, dsize;
607     int alg = 1;
608     int communicator_size = comm->size();
609
610     dsize=datatype->size();
611     total_dsize = dsize * count;
612     int (*funcs[])(const void*, void*, int, MPI_Datatype, MPI_Op, int, MPI_Comm) = {
613         &reduce__ompi_basic_linear,
614         &reduce__ompi_chain,
615         &reduce__ompi_pipeline,
616         &reduce__ompi_binary,
617         &reduce__ompi_binomial,
618         &reduce__ompi_in_order_binary,
619         //&reduce__rab our rab can't be used with all datatypes
620         &reduce__ompi_basic_linear
621     };
622     /** Algorithms:
623      *  {1, "linear"},
624      *  {2, "chain"},
625      *  {3, "pipeline"},
626      *  {4, "binary"},
627      *  {5, "binomial"},
628      *  {6, "in-order_binary"},
629      *  {7, "rabenseifner"},
630      *
631      * Currently, only linear and in-order binary tree algorithms are
632      * capable of non commutative ops.
633      */
634      if ((op != MPI_OP_NULL) && not op->is_commutative()) {
635         if (communicator_size < 4) {
636             if (total_dsize < 8) {
637                 alg = 6;
638             } else {
639                 alg = 1;
640             }
641         } else if (communicator_size < 8) {
642             alg = 1;
643         } else if (communicator_size < 16) {
644             if (total_dsize < 1024) {
645                 alg = 6;
646             } else if (total_dsize < 8192) {
647                 alg = 1;
648             } else if (total_dsize < 16384) {
649                 alg = 6;
650             } else if (total_dsize < 262144) {
651                 alg = 1;
652             } else {
653                 alg = 6;
654             }
655         } else if (communicator_size < 128) {
656             alg = 6;
657         } else if (communicator_size < 256) {
658             if (total_dsize < 512) {
659                 alg = 6;
660             } else if (total_dsize < 1024) {
661                 alg = 1;
662             } else {
663                 alg = 6;
664             }
665         } else {
666             alg = 6;
667         }
668     } else {
669         if (communicator_size < 4) {
670             if (total_dsize < 8) {
671                 alg = 7;
672             } else if (total_dsize < 16) {
673                 alg = 4;
674             } else if (total_dsize < 32) {
675                 alg = 3;
676             } else if (total_dsize < 262144) {
677                 alg = 1;
678             } else if (total_dsize < 524288) {
679                 alg = 3;
680             } else if (total_dsize < 1048576) {
681                 alg = 2;
682             } else {
683                 alg = 3;
684             }
685         } else if (communicator_size < 8) {
686             if (total_dsize < 4096) {
687                 alg = 4;
688             } else if (total_dsize < 65536) {
689                 alg = 2;
690             } else if (total_dsize < 262144) {
691                 alg = 5;
692             } else if (total_dsize < 524288) {
693                 alg = 1;
694             } else if (total_dsize < 1048576) {
695                 alg = 5;
696             } else {
697                 alg = 1;
698             }
699         } else if (communicator_size < 16) {
700             if (total_dsize < 8192) {
701                 alg = 4;
702             } else {
703                 alg = 5;
704             }
705         } else if (communicator_size < 32) {
706             if (total_dsize < 4096) {
707                 alg = 4;
708             } else {
709                 alg = 5;
710             }
711         } else if (communicator_size < 256) {
712             alg = 5;
713         } else if (communicator_size < 512) {
714             if (total_dsize < 8192) {
715                 alg = 5;
716             } else if (total_dsize < 16384) {
717                 alg = 6;
718             } else {
719                 alg = 5;
720             }
721         } else if (communicator_size < 2048) {
722             alg = 5;
723         } else if (communicator_size < 4096) {
724             if (total_dsize < 512) {
725                 alg = 5;
726             } else if (total_dsize < 1024) {
727                 alg = 6;
728             } else if (total_dsize < 8192) {
729                 alg = 5;
730             } else if (total_dsize < 16384) {
731                 alg = 6;
732             } else {
733                 alg = 5;
734             }
735         } else {
736             if (total_dsize < 16) {
737                 alg = 5;
738             } else if (total_dsize < 32) {
739                 alg = 6;
740             } else if (total_dsize < 1024) {
741                 alg = 5;
742             } else if (total_dsize < 2048) {
743                 alg = 6;
744             } else if (total_dsize < 8192) {
745                 alg = 5;
746             } else if (total_dsize < 16384) {
747                 alg = 6;
748             } else {
749                 alg = 5;
750             }
751         }
752     }
753
754     return funcs[alg-1] (sendbuf, recvbuf, count, datatype, op, root, comm);
755 }
756
757 int reduce_scatter__ompi(const void *sbuf, void *rbuf,
758                          const int *rcounts,
759                          MPI_Datatype dtype,
760                          MPI_Op  op,
761                          MPI_Comm  comm
762                          )
763 {
764     size_t total_dsize, dsize;
765     int communicator_size = comm->size();
766     int alg = 1;
767     int zerocounts = 0;
768     dsize=dtype->size();
769     total_dsize = 0;
770     for (int i = 0; i < communicator_size; i++) {
771         total_dsize += rcounts[i];
772        // if (0 == rcounts[i]) {
773         //    zerocounts = 1;
774         //}
775     }
776     total_dsize *= dsize;
777     int (*funcs[])(const void*, void*, const int*, MPI_Datatype, MPI_Op, MPI_Comm) = {
778         &reduce_scatter__default,
779         &reduce_scatter__ompi_basic_recursivehalving,
780         &reduce_scatter__ompi_ring,
781         &reduce_scatter__ompi_butterfly,
782     };
783     /** Algorithms:
784      *  {1, "non-overlapping"},
785      *  {2, "recursive_halving"},
786      *  {3, "ring"},
787      *  {4, "butterfly"},
788      *
789      * Non commutative algorithm capability needs re-investigation.
790      * Defaulting to non overlapping for non commutative ops.
791      */
792     if (((op != MPI_OP_NULL) && not op->is_commutative()) || (zerocounts)) {
793         alg = 1;
794     } else {
795         if (communicator_size < 4) {
796             if (total_dsize < 65536) {
797                 alg = 3;
798             } else if (total_dsize < 131072) {
799                 alg = 4;
800             } else {
801                 alg = 3;
802             }
803         } else if (communicator_size < 8) {
804             if (total_dsize < 8) {
805                 alg = 1;
806             } else if (total_dsize < 262144) {
807                 alg = 2;
808             } else {
809                 alg = 3;
810             }
811         } else if (communicator_size < 32) {
812             if (total_dsize < 262144) {
813                 alg = 2;
814             } else {
815                 alg = 3;
816             }
817         } else if (communicator_size < 64) {
818             if (total_dsize < 64) {
819                 alg = 1;
820             } else if (total_dsize < 2048) {
821                 alg = 2;
822             } else if (total_dsize < 524288) {
823                 alg = 4;
824             } else {
825                 alg = 3;
826             }
827         } else if (communicator_size < 128) {
828             if (total_dsize < 256) {
829                 alg = 1;
830             } else if (total_dsize < 512) {
831                 alg = 2;
832             } else if (total_dsize < 2048) {
833                 alg = 4;
834             } else if (total_dsize < 4096) {
835                 alg = 2;
836             } else {
837                 alg = 4;
838             }
839         } else if (communicator_size < 256) {
840             if (total_dsize < 256) {
841                 alg = 1;
842             } else if (total_dsize < 512) {
843                 alg = 2;
844             } else {
845                 alg = 4;
846             }
847         } else if (communicator_size < 512) {
848             if (total_dsize < 256) {
849                 alg = 1;
850             } else if (total_dsize < 1024) {
851                 alg = 2;
852             } else {
853                 alg = 4;
854             }
855         } else if (communicator_size < 1024) {
856             if (total_dsize < 512) {
857                 alg = 1;
858             } else if (total_dsize < 2048) {
859                 alg = 2;
860             } else if (total_dsize < 8192) {
861                 alg = 4;
862             } else if (total_dsize < 16384) {
863                 alg = 2;
864             } else {
865                 alg = 4;
866             }
867         } else if (communicator_size < 2048) {
868             if (total_dsize < 512) {
869                 alg = 1;
870             } else if (total_dsize < 4096) {
871                 alg = 2;
872             } else if (total_dsize < 16384) {
873                 alg = 4;
874             } else if (total_dsize < 32768) {
875                 alg = 2;
876             } else {
877                 alg = 4;
878             }
879         } else if (communicator_size < 4096) {
880             if (total_dsize < 512) {
881                 alg = 1;
882             } else if (total_dsize < 4096) {
883                 alg = 2;
884             } else {
885                 alg = 4;
886             }
887         } else {
888             if (total_dsize < 1024) {
889                 alg = 1;
890             } else if (total_dsize < 8192) {
891                 alg = 2;
892             } else {
893                 alg = 4;
894             }
895         }
896     }
897
898     return funcs[alg-1] (sbuf, rbuf, rcounts, dtype, op, comm);
899 }
900
901 int allgather__ompi(const void *sbuf, int scount,
902                     MPI_Datatype sdtype,
903                     void* rbuf, int rcount,
904                     MPI_Datatype rdtype,
905                     MPI_Comm  comm
906                     )
907 {
908     int communicator_size;
909     size_t dsize, total_dsize;
910     int alg = 1;
911     communicator_size = comm->size();
912     if (MPI_IN_PLACE != sbuf) {
913         dsize = sdtype->size();
914     } else {
915         dsize = rdtype->size();
916     }
917     total_dsize = dsize * (ptrdiff_t)scount;
918     int (*funcs[])(const void*, int, MPI_Datatype, void*, int, MPI_Datatype, MPI_Comm) = {
919         &allgather__NTSLR_NB,
920         &allgather__bruck,
921         &allgather__rdb,
922         &allgather__ring,
923         &allgather__ompi_neighborexchange,
924         &allgather__pair
925     };
926     /** Algorithms:
927      *  {1, "linear"},
928      *  {2, "bruck"},
929      *  {3, "recursive_doubling"},
930      *  {4, "ring"},
931      *  {5, "neighbor"},
932      *  {6, "two_proc"}
933      */
934     if (communicator_size == 2) {
935         alg = 6;
936     } else if (communicator_size < 32) {
937         alg = 3;
938     } else if (communicator_size < 64) {
939         if (total_dsize < 1024) {
940             alg = 3;
941         } else if (total_dsize < 65536) {
942             alg = 5;
943         } else {
944             alg = 4;
945         }
946     } else if (communicator_size < 128) {
947         if (total_dsize < 512) {
948             alg = 3;
949         } else if (total_dsize < 65536) {
950             alg = 5;
951         } else {
952             alg = 4;
953         }
954     } else if (communicator_size < 256) {
955         if (total_dsize < 512) {
956             alg = 3;
957         } else if (total_dsize < 131072) {
958             alg = 5;
959         } else if (total_dsize < 524288) {
960             alg = 4;
961         } else if (total_dsize < 1048576) {
962             alg = 5;
963         } else {
964             alg = 4;
965         }
966     } else if (communicator_size < 512) {
967         if (total_dsize < 32) {
968             alg = 3;
969         } else if (total_dsize < 128) {
970             alg = 2;
971         } else if (total_dsize < 1024) {
972             alg = 3;
973         } else if (total_dsize < 131072) {
974             alg = 5;
975         } else if (total_dsize < 524288) {
976             alg = 4;
977         } else if (total_dsize < 1048576) {
978             alg = 5;
979         } else {
980             alg = 4;
981         }
982     } else if (communicator_size < 1024) {
983         if (total_dsize < 64) {
984             alg = 3;
985         } else if (total_dsize < 256) {
986             alg = 2;
987         } else if (total_dsize < 2048) {
988             alg = 3;
989         } else {
990             alg = 5;
991         }
992     } else if (communicator_size < 2048) {
993         if (total_dsize < 4) {
994             alg = 3;
995         } else if (total_dsize < 8) {
996             alg = 2;
997         } else if (total_dsize < 16) {
998             alg = 3;
999         } else if (total_dsize < 32) {
1000             alg = 2;
1001         } else if (total_dsize < 256) {
1002             alg = 3;
1003         } else if (total_dsize < 512) {
1004             alg = 2;
1005         } else if (total_dsize < 4096) {
1006             alg = 3;
1007         } else {
1008             alg = 5;
1009         }
1010     } else if (communicator_size < 4096) {
1011         if (total_dsize < 32) {
1012             alg = 2;
1013         } else if (total_dsize < 128) {
1014             alg = 3;
1015         } else if (total_dsize < 512) {
1016             alg = 2;
1017         } else if (total_dsize < 4096) {
1018             alg = 3;
1019         } else {
1020             alg = 5;
1021         }
1022     } else {
1023         if (total_dsize < 2) {
1024             alg = 3;
1025         } else if (total_dsize < 8) {
1026             alg = 2;
1027         } else if (total_dsize < 16) {
1028             alg = 3;
1029         } else if (total_dsize < 512) {
1030             alg = 2;
1031         } else if (total_dsize < 4096) {
1032             alg = 3;
1033         } else {
1034             alg = 5;
1035         }
1036     }
1037
1038     return funcs[alg-1](sbuf, scount, sdtype, rbuf, rcount, rdtype, comm);
1039
1040 }
1041
1042 int allgatherv__ompi(const void *sbuf, int scount,
1043                      MPI_Datatype sdtype,
1044                      void* rbuf, const int *rcounts,
1045                      const int *rdispls,
1046                      MPI_Datatype rdtype,
1047                      MPI_Comm  comm
1048                      )
1049 {
1050     int i;
1051     int communicator_size;
1052     size_t dsize, total_dsize;
1053     int alg = 1;
1054     communicator_size = comm->size();
1055     if (MPI_IN_PLACE != sbuf) {
1056         dsize = sdtype->size();
1057     } else {
1058         dsize = rdtype->size();
1059     }
1060
1061     total_dsize = 0;
1062     for (i = 0; i < communicator_size; i++) {
1063         total_dsize += dsize * rcounts[i];
1064     }
1065
1066     /* use the per-rank data size as basis, similar to allgather */
1067     size_t per_rank_dsize = total_dsize / communicator_size;
1068
1069     int (*funcs[])(const void*, int, MPI_Datatype, void*, const int*, const int*, MPI_Datatype, MPI_Comm) = {
1070         &allgatherv__GB,
1071         &allgatherv__ompi_bruck,
1072         &allgatherv__mpich_ring,
1073         &allgatherv__ompi_neighborexchange,
1074         &allgatherv__pair
1075     };
1076     /** Algorithms:
1077      *  {1, "default"},
1078      *  {2, "bruck"},
1079      *  {3, "ring"},
1080      *  {4, "neighbor"},
1081      *  {5, "two_proc"},
1082      */
1083     if (communicator_size == 2) {
1084         if (per_rank_dsize < 2048) {
1085             alg = 3;
1086         } else if (per_rank_dsize < 4096) {
1087             alg = 5;
1088         } else if (per_rank_dsize < 8192) {
1089             alg = 3;
1090         } else {
1091             alg = 5;
1092         }
1093     } else if (communicator_size < 8) {
1094         if (per_rank_dsize < 256) {
1095             alg = 1;
1096         } else if (per_rank_dsize < 4096) {
1097             alg = 4;
1098         } else if (per_rank_dsize < 8192) {
1099             alg = 3;
1100         } else if (per_rank_dsize < 16384) {
1101             alg = 4;
1102         } else if (per_rank_dsize < 262144) {
1103             alg = 2;
1104         } else {
1105             alg = 4;
1106         }
1107     } else if (communicator_size < 16) {
1108         if (per_rank_dsize < 1024) {
1109             alg = 1;
1110         } else {
1111             alg = 2;
1112         }
1113     } else if (communicator_size < 32) {
1114         if (per_rank_dsize < 128) {
1115             alg = 1;
1116         } else if (per_rank_dsize < 262144) {
1117             alg = 2;
1118         } else {
1119             alg = 3;
1120         }
1121     } else if (communicator_size < 64) {
1122         if (per_rank_dsize < 256) {
1123             alg = 1;
1124         } else if (per_rank_dsize < 8192) {
1125             alg = 2;
1126         } else {
1127             alg = 3;
1128         }
1129     } else if (communicator_size < 128) {
1130         if (per_rank_dsize < 256) {
1131             alg = 1;
1132         } else if (per_rank_dsize < 4096) {
1133             alg = 2;
1134         } else {
1135             alg = 3;
1136         }
1137     } else if (communicator_size < 256) {
1138         if (per_rank_dsize < 1024) {
1139             alg = 2;
1140         } else if (per_rank_dsize < 65536) {
1141             alg = 4;
1142         } else {
1143             alg = 3;
1144         }
1145     } else if (communicator_size < 512) {
1146         if (per_rank_dsize < 1024) {
1147             alg = 2;
1148         } else {
1149             alg = 3;
1150         }
1151     } else if (communicator_size < 1024) {
1152         if (per_rank_dsize < 512) {
1153             alg = 2;
1154         } else if (per_rank_dsize < 1024) {
1155             alg = 1;
1156         } else if (per_rank_dsize < 4096) {
1157             alg = 2;
1158         } else if (per_rank_dsize < 1048576) {
1159             alg = 4;
1160         } else {
1161             alg = 3;
1162         }
1163     } else {
1164         if (per_rank_dsize < 4096) {
1165             alg = 2;
1166         } else {
1167             alg = 4;
1168         }
1169     }
1170
1171     return funcs[alg-1](sbuf, scount, sdtype, rbuf, rcounts, rdispls, rdtype, comm);
1172 }
1173
1174 int gather__ompi(const void *sbuf, int scount,
1175                  MPI_Datatype sdtype,
1176                  void* rbuf, int rcount,
1177                  MPI_Datatype rdtype,
1178                  int root,
1179                  MPI_Comm  comm
1180                  )
1181 {
1182     int communicator_size, rank;
1183     size_t dsize, total_dsize;
1184     int alg = 1;
1185     communicator_size = comm->size();
1186     rank = comm->rank();
1187
1188     if (rank == root) {
1189         dsize = rdtype->size();
1190         total_dsize = dsize * rcount;
1191     } else {
1192         dsize = sdtype->size();
1193         total_dsize = dsize * scount;
1194     }
1195     int (*funcs[])(const void*, int, MPI_Datatype, void*, int, MPI_Datatype, int, MPI_Comm) = {
1196         &gather__ompi_basic_linear,
1197         &gather__ompi_binomial,
1198         &gather__ompi_linear_sync
1199     };
1200     /** Algorithms:
1201      *  {1, "basic_linear"},
1202      *  {2, "binomial"},
1203      *  {3, "linear_sync"},
1204      *
1205      * We do not make any rank specific checks since the params
1206      * should be uniform across ranks.
1207      */
1208     if (communicator_size < 4) {
1209         if (total_dsize < 2) {
1210             alg = 3;
1211         } else if (total_dsize < 4) {
1212             alg = 1;
1213         } else if (total_dsize < 32768) {
1214             alg = 2;
1215         } else if (total_dsize < 65536) {
1216             alg = 1;
1217         } else if (total_dsize < 131072) {
1218             alg = 2;
1219         } else {
1220             alg = 3;
1221         }
1222     } else if (communicator_size < 8) {
1223         if (total_dsize < 1024) {
1224             alg = 2;
1225         } else if (total_dsize < 8192) {
1226             alg = 1;
1227         } else if (total_dsize < 32768) {
1228             alg = 2;
1229         } else if (total_dsize < 262144) {
1230             alg = 1;
1231         } else {
1232             alg = 3;
1233         }
1234     } else if (communicator_size < 256) {
1235         alg = 2;
1236     } else if (communicator_size < 512) {
1237         if (total_dsize < 2048) {
1238             alg = 2;
1239         } else if (total_dsize < 8192) {
1240             alg = 1;
1241         } else {
1242             alg = 2;
1243         }
1244     } else {
1245         alg = 2;
1246     }
1247
1248     return funcs[alg-1](sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm);
1249 }
1250
1251
1252 int scatter__ompi(const void *sbuf, int scount,
1253                   MPI_Datatype sdtype,
1254                   void* rbuf, int rcount,
1255                   MPI_Datatype rdtype,
1256                   int root, MPI_Comm  comm
1257                   )
1258 {
1259     int communicator_size, rank;
1260     size_t dsize, total_dsize;
1261     int alg = 1;
1262
1263     communicator_size = comm->size();
1264     rank = comm->rank();
1265     if (root == rank) {
1266         dsize=sdtype->size();
1267         total_dsize = dsize * scount;
1268     } else {
1269         dsize=rdtype->size();
1270         total_dsize = dsize * rcount;
1271     }
1272     int (*funcs[])(const void*, int, MPI_Datatype, void*, int, MPI_Datatype, int, MPI_Comm) = {
1273         &scatter__ompi_basic_linear,
1274         &scatter__ompi_binomial,
1275         &scatter__ompi_linear_nb
1276     };
1277     /** Algorithms:
1278      *  {1, "basic_linear"},
1279      *  {2, "binomial"},
1280      *  {3, "linear_nb"},
1281      *
1282      * We do not make any rank specific checks since the params
1283      * should be uniform across ranks.
1284      */
1285     if (communicator_size < 4) {
1286         if (total_dsize < 2) {
1287             alg = 3;
1288         } else if (total_dsize < 131072) {
1289             alg = 1;
1290         } else if (total_dsize < 262144) {
1291             alg = 3;
1292         } else {
1293             alg = 1;
1294         }
1295     } else if (communicator_size < 8) {
1296         if (total_dsize < 2048) {
1297             alg = 2;
1298         } else if (total_dsize < 4096) {
1299             alg = 1;
1300         } else if (total_dsize < 8192) {
1301             alg = 2;
1302         } else if (total_dsize < 32768) {
1303             alg = 1;
1304         } else if (total_dsize < 1048576) {
1305             alg = 3;
1306         } else {
1307             alg = 1;
1308         }
1309     } else if (communicator_size < 16) {
1310         if (total_dsize < 16384) {
1311             alg = 2;
1312         } else if (total_dsize < 1048576) {
1313             alg = 3;
1314         } else {
1315             alg = 1;
1316         }
1317     } else if (communicator_size < 32) {
1318         if (total_dsize < 16384) {
1319             alg = 2;
1320         } else if (total_dsize < 32768) {
1321             alg = 1;
1322         } else {
1323             alg = 3;
1324         }
1325     } else if (communicator_size < 64) {
1326         if (total_dsize < 512) {
1327             alg = 2;
1328         } else if (total_dsize < 8192) {
1329             alg = 3;
1330         } else if (total_dsize < 16384) {
1331             alg = 2;
1332         } else {
1333             alg = 3;
1334         }
1335     } else {
1336         if (total_dsize < 512) {
1337             alg = 2;
1338         } else {
1339             alg = 3;
1340         }
1341     }
1342
1343     return funcs[alg-1](sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm);
1344 }
1345
1346 }
1347 }