Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
5a762008ab10afe02f2c9aaa3fa81776071dda62
[simgrid.git] / src / smpi / colls / smpi_openmpi_selector.cpp
1 /* selector for collective algorithms based on openmpi's default coll_tuned_decision_fixed selector
2  * Updated 02/2022                                                          */
3
4 /* Copyright (c) 2009-2022. The SimGrid Team.
5  * All rights reserved.                                                     */
6
7 /* This program is free software; you can redistribute it and/or modify it
8  * under the terms of the license (GNU LGPL) which comes with this package. */
9
10 #include "colls_private.hpp"
11
12 #include <memory>
13
14 /* FIXME
15 add algos:
16 allreduce nonoverlapping, basic linear
17 alltoall linear_sync
18 bcast chain
19 scatter linear_nb
20 */
21
22 namespace simgrid {
23 namespace smpi {
24
25 int allreduce__ompi(const void *sbuf, void *rbuf, int count,
26                     MPI_Datatype dtype, MPI_Op op, MPI_Comm comm)
27 {
28     size_t total_dsize = dtype->size() * (ptrdiff_t)count;
29     int communicator_size = comm->size();
30     int alg = 1;
31     int(*funcs[]) (const void*, void*, int, MPI_Datatype, MPI_Op, MPI_Comm)={
32         &allreduce__redbcast,
33         &allreduce__redbcast,
34         &allreduce__rdb,
35         &allreduce__lr,
36         &allreduce__ompi_ring_segmented,
37         &allreduce__rab_rdb
38     };
39
40     /** Algorithms:
41      *  {1, "basic_linear"},
42      *  {2, "nonoverlapping"},
43      *  {3, "recursive_doubling"},
44      *  {4, "ring"},
45      *  {5, "segmented_ring"},
46      *  {6, "rabenseifner"
47      *
48      * Currently, ring, segmented ring, and rabenseifner do not support
49      * non-commutative operations.
50      */
51     if ((op != MPI_OP_NULL) && not op->is_commutative()) {
52         if (communicator_size < 4) {
53             if (total_dsize < 131072) {
54                 alg = 3;
55             } else {
56                 alg = 1;
57             }
58         } else if (communicator_size < 8) {
59             alg = 3;
60         } else if (communicator_size < 16) {
61             if (total_dsize < 1048576) {
62                 alg = 3;
63             } else {
64                 alg = 2;
65             }
66         } else if (communicator_size < 128) {
67             alg = 3;
68         } else if (communicator_size < 256) {
69             if (total_dsize < 131072) {
70                 alg = 2;
71             } else if (total_dsize < 524288) {
72                 alg = 3;
73             } else {
74                 alg = 2;
75             }
76         } else if (communicator_size < 512) {
77             if (total_dsize < 4096) {
78                 alg = 2;
79             } else if (total_dsize < 524288) {
80                 alg = 3;
81             } else {
82                 alg = 2;
83             }
84         } else {
85             if (total_dsize < 2048) {
86                 alg = 2;
87             } else {
88                 alg = 3;
89             }
90         }
91     } else {
92         if (communicator_size < 4) {
93             if (total_dsize < 8) {
94                 alg = 4;
95             } else if (total_dsize < 4096) {
96                 alg = 3;
97             } else if (total_dsize < 8192) {
98                 alg = 4;
99             } else if (total_dsize < 16384) {
100                 alg = 3;
101             } else if (total_dsize < 65536) {
102                 alg = 4;
103             } else if (total_dsize < 262144) {
104                 alg = 5;
105             } else {
106                 alg = 6;
107             }
108         } else if (communicator_size < 8) {
109             if (total_dsize < 16) {
110                 alg = 4;
111             } else if (total_dsize < 8192) {
112                 alg = 3;
113             } else {
114                 alg = 6;
115             }
116         } else if (communicator_size < 16) {
117             if (total_dsize < 8192) {
118                 alg = 3;
119             } else {
120                 alg = 6;
121             }
122         } else if (communicator_size < 32) {
123             if (total_dsize < 64) {
124                 alg = 5;
125             } else if (total_dsize < 4096) {
126                 alg = 3;
127             } else {
128                 alg = 6;
129             }
130         } else if (communicator_size < 64) {
131             if (total_dsize < 128) {
132                 alg = 5;
133             } else {
134                 alg = 6;
135             }
136         } else if (communicator_size < 128) {
137             if (total_dsize < 262144) {
138                 alg = 3;
139             } else {
140                 alg = 6;
141             }
142         } else if (communicator_size < 256) {
143             if (total_dsize < 131072) {
144                 alg = 2;
145             } else if (total_dsize < 262144) {
146                 alg = 3;
147             } else {
148                 alg = 6;
149             }
150         } else if (communicator_size < 512) {
151             if (total_dsize < 4096) {
152                 alg = 2;
153             } else {
154                 alg = 6;
155             }
156         } else if (communicator_size < 2048) {
157             if (total_dsize < 2048) {
158                 alg = 2;
159             } else if (total_dsize < 16384) {
160                 alg = 3;
161             } else {
162                 alg = 6;
163             }
164         } else if (communicator_size < 4096) {
165             if (total_dsize < 2048) {
166                 alg = 2;
167             } else if (total_dsize < 4096) {
168                 alg = 5;
169             } else if (total_dsize < 16384) {
170                 alg = 3;
171             } else {
172                 alg = 6;
173             }
174         } else {
175             if (total_dsize < 2048) {
176                 alg = 2;
177             } else if (total_dsize < 16384) {
178                 alg = 5;
179             } else if (total_dsize < 32768) {
180                 alg = 3;
181             } else {
182                 alg = 6;
183             }
184         }
185     }
186     return funcs[alg-1](sbuf, rbuf, count, dtype, op, comm);
187 }
188
189
190
191 int alltoall__ompi(const void *sbuf, int scount,
192                    MPI_Datatype sdtype,
193                    void* rbuf, int rcount,
194                    MPI_Datatype rdtype,
195                    MPI_Comm comm)
196 {
197     int alg = 1;
198     size_t dsize, total_dsize;
199     int communicator_size = comm->size();
200
201     if (MPI_IN_PLACE != sbuf) {
202         dsize = sdtype->size();
203     } else {
204         dsize = rdtype->size();
205     }
206     total_dsize = dsize * (ptrdiff_t)scount;
207     int (*funcs[])(const void *, int, MPI_Datatype, void*, int, MPI_Datatype, MPI_Comm) = {
208         &alltoall__basic_linear,
209         &alltoall__pair,
210         &alltoall__bruck,
211         &alltoall__basic_linear,
212         &alltoall__basic_linear
213     };
214     /** Algorithms:
215      *  {1, "linear"},
216      *  {2, "pairwise"},
217      *  {3, "modified_bruck"},
218      *  {4, "linear_sync"},
219      *  {5, "two_proc"},
220      */
221     if (communicator_size == 2) {
222         if (total_dsize < 2) {
223             alg = 2;
224         } else if (total_dsize < 4) {
225             alg = 5;
226         } else if (total_dsize < 16) {
227             alg = 2;
228         } else if (total_dsize < 64) {
229             alg = 5;
230         } else if (total_dsize < 256) {
231             alg = 2;
232         } else if (total_dsize < 4096) {
233             alg = 5;
234         } else if (total_dsize < 32768) {
235             alg = 2;
236         } else if (total_dsize < 262144) {
237             alg = 4;
238         } else if (total_dsize < 1048576) {
239             alg = 5;
240         } else {
241             alg = 2;
242         }
243     } else if (communicator_size < 8) {
244         if (total_dsize < 8192) {
245             alg = 4;
246         } else if (total_dsize < 16384) {
247             alg = 1;
248         } else if (total_dsize < 65536) {
249             alg = 4;
250         } else if (total_dsize < 524288) {
251             alg = 1;
252         } else if (total_dsize < 1048576) {
253             alg = 2;
254         } else {
255             alg = 1;
256         }
257     } else if (communicator_size < 16) {
258         if (total_dsize < 262144) {
259             alg = 4;
260         } else {
261             alg = 1;
262         }
263     } else if (communicator_size < 32) {
264         if (total_dsize < 4) {
265             alg = 4;
266         } else if (total_dsize < 512) {
267             alg = 3;
268         } else if (total_dsize < 8192) {
269             alg = 4;
270         } else if (total_dsize < 32768) {
271             alg = 1;
272         } else if (total_dsize < 262144) {
273             alg = 4;
274         } else if (total_dsize < 524288) {
275             alg = 1;
276         } else {
277             alg = 4;
278         }
279     } else if (communicator_size < 64) {
280         if (total_dsize < 512) {
281             alg = 3;
282         } else if (total_dsize < 524288) {
283             alg = 1;
284         } else {
285             alg = 4;
286         }
287     } else if (communicator_size < 128) {
288         if (total_dsize < 1024) {
289             alg = 3;
290         } else if (total_dsize < 2048) {
291             alg = 1;
292         } else if (total_dsize < 4096) {
293             alg = 4;
294         } else if (total_dsize < 262144) {
295             alg = 1;
296         } else {
297             alg = 2;
298         }
299     } else if (communicator_size < 256) {
300         if (total_dsize < 1024) {
301             alg = 3;
302         } else if (total_dsize < 2048) {
303             alg = 4;
304         } else if (total_dsize < 262144) {
305             alg = 1;
306         } else {
307             alg = 2;
308         }
309     } else if (communicator_size < 512) {
310         if (total_dsize < 1024) {
311             alg = 3;
312         } else if (total_dsize < 8192) {
313             alg = 4;
314         } else if (total_dsize < 32768) {
315             alg = 1;
316         } else {
317             alg = 2;
318         }
319     } else if (communicator_size < 1024) {
320         if (total_dsize < 512) {
321             alg = 3;
322         } else if (total_dsize < 8192) {
323             alg = 4;
324         } else if (total_dsize < 16384) {
325             alg = 1;
326         } else if (total_dsize < 131072) {
327             alg = 4;
328         } else if (total_dsize < 262144) {
329             alg = 1;
330         } else {
331             alg = 2;
332         }
333     } else if (communicator_size < 2048) {
334         if (total_dsize < 512) {
335             alg = 3;
336         } else if (total_dsize < 1024) {
337             alg = 4;
338         } else if (total_dsize < 2048) {
339             alg = 1;
340         } else if (total_dsize < 16384) {
341             alg = 4;
342         } else if (total_dsize < 262144) {
343             alg = 1;
344         } else {
345             alg = 4;
346         }
347     } else if (communicator_size < 4096) {
348         if (total_dsize < 1024) {
349             alg = 3;
350         } else if (total_dsize < 4096) {
351             alg = 4;
352         } else if (total_dsize < 8192) {
353             alg = 1;
354         } else if (total_dsize < 131072) {
355             alg = 4;
356         } else {
357             alg = 1;
358         }
359     } else {
360         if (total_dsize < 2048) {
361             alg = 3;
362         } else if (total_dsize < 8192) {
363             alg = 4;
364         } else if (total_dsize < 16384) {
365             alg = 1;
366         } else if (total_dsize < 32768) {
367             alg = 4;
368         } else if (total_dsize < 65536) {
369             alg = 1;
370         } else {
371             alg = 4;
372         }
373     }
374
375     return funcs[alg-1](sbuf, scount, sdtype,
376                           rbuf, rcount, rdtype, comm);
377 }
378
379 int alltoallv__ompi(const void *sbuf, const int *scounts, const int *sdisps,
380                     MPI_Datatype sdtype,
381                     void *rbuf, const int *rcounts, const int *rdisps,
382                     MPI_Datatype rdtype,
383                     MPI_Comm  comm
384                     )
385 {
386     int communicator_size = comm->size();
387     int alg = 1;
388     int (*funcs[])(const void *, const int*, const int*, MPI_Datatype, void*, const int*, const int*, MPI_Datatype, MPI_Comm) = {
389         &alltoallv__ompi_basic_linear,
390         &alltoallv__pair
391     };
392    /** Algorithms:
393      *  {1, "basic_linear"},
394      *  {2, "pairwise"},
395      *
396      * We can only optimize based on com size
397      */
398     if (communicator_size < 4) {
399         alg = 2;
400     } else if (communicator_size < 64) {
401         alg = 1;
402     } else if (communicator_size < 128) {
403         alg = 2;
404     } else if (communicator_size < 256) {
405         alg = 1;
406     } else if (communicator_size < 1024) {
407         alg = 2;
408     } else {
409         alg = 1;
410     }
411     return funcs[alg-1](sbuf, scounts, sdisps, sdtype,
412                            rbuf, rcounts, rdisps,rdtype,
413                            comm);
414 }
415
416 int barrier__ompi(MPI_Comm  comm)
417 {
418     int communicator_size = comm->size();
419     int alg = 1;
420     int (*funcs[])(MPI_Comm) = {
421         &barrier__ompi_basic_linear,
422         &barrier__ompi_basic_linear,
423         &barrier__ompi_recursivedoubling,
424         &barrier__ompi_bruck,
425         &barrier__ompi_two_procs,
426         &barrier__ompi_tree
427     };
428     /** Algorithms:
429      *  {1, "linear"},
430      *  {2, "double_ring"},
431      *  {3, "recursive_doubling"},
432      *  {4, "bruck"},
433      *  {5, "two_proc"},
434      *  {6, "tree"},
435      *
436      * We can only optimize based on com size
437      */
438     if (communicator_size < 4) {
439         alg = 3;
440     } else if (communicator_size < 8) {
441         alg = 1;
442     } else if (communicator_size < 64) {
443         alg = 3;
444     } else if (communicator_size < 256) {
445         alg = 4;
446     } else if (communicator_size < 512) {
447         alg = 6;
448     } else if (communicator_size < 1024) {
449         alg = 4;
450     } else if (communicator_size < 4096) {
451         alg = 6;
452     } else {
453         alg = 4;
454     }
455
456     return funcs[alg-1](comm);
457 }
458
459 int bcast__ompi(void *buff, int count, MPI_Datatype datatype, int root, MPI_Comm  comm)
460 {
461     int alg = 1;
462     size_t total_dsize, dsize;
463
464     int communicator_size = comm->size();
465
466     dsize = datatype->size();
467     total_dsize = dsize * (unsigned long)count;
468     int (*funcs[])(void*, int, MPI_Datatype, int, MPI_Comm) = {
469         &bcast__NTSL,
470         &bcast__ompi_pipeline,
471         &bcast__ompi_pipeline,
472         &bcast__ompi_split_bintree,
473         &bcast__NTSB,
474         &bcast__binomial_tree,
475         &bcast__mvapich2_knomial_intra_node,
476         &bcast__scatter_rdb_allgather,
477         &bcast__scatter_LR_allgather,
478     };
479     /** Algorithms:
480      *  {1, "basic_linear"},
481      *  {2, "chain"},
482      *  {3, "pipeline"},
483      *  {4, "split_binary_tree"},
484      *  {5, "binary_tree"},
485      *  {6, "binomial"},
486      *  {7, "knomial"},
487      *  {8, "scatter_allgather"},
488      *  {9, "scatter_allgather_ring"},
489      */
490     if (communicator_size < 4) {
491         if (total_dsize < 32) {
492             alg = 3;
493         } else if (total_dsize < 256) {
494             alg = 5;
495         } else if (total_dsize < 512) {
496             alg = 3;
497         } else if (total_dsize < 1024) {
498             alg = 7;
499         } else if (total_dsize < 32768) {
500             alg = 1;
501         } else if (total_dsize < 131072) {
502             alg = 5;
503         } else if (total_dsize < 262144) {
504             alg = 2;
505         } else if (total_dsize < 524288) {
506             alg = 1;
507         } else if (total_dsize < 1048576) {
508             alg = 6;
509         } else {
510             alg = 5;
511         }
512     } else if (communicator_size < 8) {
513         if (total_dsize < 64) {
514             alg = 5;
515         } else if (total_dsize < 128) {
516             alg = 6;
517         } else if (total_dsize < 2048) {
518             alg = 5;
519         } else if (total_dsize < 8192) {
520             alg = 6;
521         } else if (total_dsize < 1048576) {
522             alg = 1;
523         } else {
524             alg = 2;
525         }
526     } else if (communicator_size < 16) {
527         if (total_dsize < 8) {
528             alg = 7;
529         } else if (total_dsize < 64) {
530             alg = 5;
531         } else if (total_dsize < 4096) {
532             alg = 7;
533         } else if (total_dsize < 16384) {
534             alg = 5;
535         } else if (total_dsize < 32768) {
536             alg = 6;
537         } else {
538             alg = 1;
539         }
540     } else if (communicator_size < 32) {
541         if (total_dsize < 4096) {
542             alg = 7;
543         } else if (total_dsize < 1048576) {
544             alg = 6;
545         } else {
546             alg = 8;
547         }
548     } else if (communicator_size < 64) {
549         if (total_dsize < 2048) {
550             alg = 6;
551         } else {
552             alg = 7;
553         }
554     } else if (communicator_size < 128) {
555         alg = 7;
556     } else if (communicator_size < 256) {
557         if (total_dsize < 2) {
558             alg = 6;
559         } else if (total_dsize < 16384) {
560             alg = 5;
561         } else if (total_dsize < 32768) {
562             alg = 1;
563         } else if (total_dsize < 65536) {
564             alg = 5;
565         } else {
566             alg = 7;
567         }
568     } else if (communicator_size < 1024) {
569         if (total_dsize < 16384) {
570             alg = 7;
571         } else if (total_dsize < 32768) {
572             alg = 4;
573         } else {
574             alg = 7;
575         }
576     } else if (communicator_size < 2048) {
577         if (total_dsize < 524288) {
578             alg = 7;
579         } else {
580             alg = 8;
581         }
582     } else if (communicator_size < 4096) {
583         if (total_dsize < 262144) {
584             alg = 7;
585         } else {
586             alg = 8;
587         }
588     } else {
589         if (total_dsize < 8192) {
590             alg = 7;
591         } else if (total_dsize < 16384) {
592             alg = 5;
593         } else if (total_dsize < 262144) {
594             alg = 7;
595         } else {
596             alg = 8;
597         }
598     }
599     return funcs[alg-1](buff, count, datatype, root, comm);
600 }
601
602 int reduce__ompi(const void *sendbuf, void *recvbuf,
603                  int count, MPI_Datatype  datatype,
604                  MPI_Op   op, int root,
605                  MPI_Comm   comm)
606 {
607     size_t total_dsize, dsize;
608     int alg = 1;
609     int communicator_size = comm->size();
610
611     dsize=datatype->size();
612     total_dsize = dsize * count;
613     int (*funcs[])(const void*, void*, int, MPI_Datatype, MPI_Op, int, MPI_Comm) = {
614         &reduce__ompi_basic_linear,
615         &reduce__ompi_chain,
616         &reduce__ompi_pipeline,
617         &reduce__ompi_binary,
618         &reduce__ompi_binomial,
619         &reduce__ompi_in_order_binary,
620         //&reduce__rab our rab can't be used with all datatypes
621         &reduce__ompi_basic_linear
622     };
623     /** Algorithms:
624      *  {1, "linear"},
625      *  {2, "chain"},
626      *  {3, "pipeline"},
627      *  {4, "binary"},
628      *  {5, "binomial"},
629      *  {6, "in-order_binary"},
630      *  {7, "rabenseifner"},
631      *
632      * Currently, only linear and in-order binary tree algorithms are
633      * capable of non commutative ops.
634      */
635      if ((op != MPI_OP_NULL) && not op->is_commutative()) {
636         if (communicator_size < 4) {
637             if (total_dsize < 8) {
638                 alg = 6;
639             } else {
640                 alg = 1;
641             }
642         } else if (communicator_size < 8) {
643             alg = 1;
644         } else if (communicator_size < 16) {
645             if (total_dsize < 1024) {
646                 alg = 6;
647             } else if (total_dsize < 8192) {
648                 alg = 1;
649             } else if (total_dsize < 16384) {
650                 alg = 6;
651             } else if (total_dsize < 262144) {
652                 alg = 1;
653             } else {
654                 alg = 6;
655             }
656         } else if (communicator_size < 128) {
657             alg = 6;
658         } else if (communicator_size < 256) {
659             if (total_dsize < 512) {
660                 alg = 6;
661             } else if (total_dsize < 1024) {
662                 alg = 1;
663             } else {
664                 alg = 6;
665             }
666         } else {
667             alg = 6;
668         }
669     } else {
670         if (communicator_size < 4) {
671             if (total_dsize < 8) {
672                 alg = 7;
673             } else if (total_dsize < 16) {
674                 alg = 4;
675             } else if (total_dsize < 32) {
676                 alg = 3;
677             } else if (total_dsize < 262144) {
678                 alg = 1;
679             } else if (total_dsize < 524288) {
680                 alg = 3;
681             } else if (total_dsize < 1048576) {
682                 alg = 2;
683             } else {
684                 alg = 3;
685             }
686         } else if (communicator_size < 8) {
687             if (total_dsize < 4096) {
688                 alg = 4;
689             } else if (total_dsize < 65536) {
690                 alg = 2;
691             } else if (total_dsize < 262144) {
692                 alg = 5;
693             } else if (total_dsize < 524288) {
694                 alg = 1;
695             } else if (total_dsize < 1048576) {
696                 alg = 5;
697             } else {
698                 alg = 1;
699             }
700         } else if (communicator_size < 16) {
701             if (total_dsize < 8192) {
702                 alg = 4;
703             } else {
704                 alg = 5;
705             }
706         } else if (communicator_size < 32) {
707             if (total_dsize < 4096) {
708                 alg = 4;
709             } else {
710                 alg = 5;
711             }
712         } else if (communicator_size < 256) {
713             alg = 5;
714         } else if (communicator_size < 512) {
715             if (total_dsize < 8192) {
716                 alg = 5;
717             } else if (total_dsize < 16384) {
718                 alg = 6;
719             } else {
720                 alg = 5;
721             }
722         } else if (communicator_size < 2048) {
723             alg = 5;
724         } else if (communicator_size < 4096) {
725             if (total_dsize < 512) {
726                 alg = 5;
727             } else if (total_dsize < 1024) {
728                 alg = 6;
729             } else if (total_dsize < 8192) {
730                 alg = 5;
731             } else if (total_dsize < 16384) {
732                 alg = 6;
733             } else {
734                 alg = 5;
735             }
736         } else {
737             if (total_dsize < 16) {
738                 alg = 5;
739             } else if (total_dsize < 32) {
740                 alg = 6;
741             } else if (total_dsize < 1024) {
742                 alg = 5;
743             } else if (total_dsize < 2048) {
744                 alg = 6;
745             } else if (total_dsize < 8192) {
746                 alg = 5;
747             } else if (total_dsize < 16384) {
748                 alg = 6;
749             } else {
750                 alg = 5;
751             }
752         }
753     }
754
755     return funcs[alg-1] (sendbuf, recvbuf, count, datatype, op, root, comm);
756 }
757
758 int reduce_scatter__ompi(const void *sbuf, void *rbuf,
759                          const int *rcounts,
760                          MPI_Datatype dtype,
761                          MPI_Op  op,
762                          MPI_Comm  comm
763                          )
764 {
765     size_t total_dsize, dsize;
766     int communicator_size = comm->size();
767     int alg = 1;
768     int zerocounts = 0;
769     dsize=dtype->size();
770     total_dsize = 0;
771     for (int i = 0; i < communicator_size; i++) {
772         total_dsize += rcounts[i];
773        // if (0 == rcounts[i]) {
774         //    zerocounts = 1;
775         //}
776     }
777     total_dsize *= dsize;
778     int (*funcs[])(const void*, void*, const int*, MPI_Datatype, MPI_Op, MPI_Comm) = {
779         &reduce_scatter__default,
780         &reduce_scatter__ompi_basic_recursivehalving,
781         &reduce_scatter__ompi_ring,
782         &reduce_scatter__ompi_butterfly,
783     };
784     /** Algorithms:
785      *  {1, "non-overlapping"},
786      *  {2, "recursive_halving"},
787      *  {3, "ring"},
788      *  {4, "butterfly"},
789      *
790      * Non commutative algorithm capability needs re-investigation.
791      * Defaulting to non overlapping for non commutative ops.
792      */
793     if (((op != MPI_OP_NULL) && not op->is_commutative()) || (zerocounts)) {
794         alg = 1;
795     } else {
796         if (communicator_size < 4) {
797             if (total_dsize < 65536) {
798                 alg = 3;
799             } else if (total_dsize < 131072) {
800                 alg = 4;
801             } else {
802                 alg = 3;
803             }
804         } else if (communicator_size < 8) {
805             if (total_dsize < 8) {
806                 alg = 1;
807             } else if (total_dsize < 262144) {
808                 alg = 2;
809             } else {
810                 alg = 3;
811             }
812         } else if (communicator_size < 32) {
813             if (total_dsize < 262144) {
814                 alg = 2;
815             } else {
816                 alg = 3;
817             }
818         } else if (communicator_size < 64) {
819             if (total_dsize < 64) {
820                 alg = 1;
821             } else if (total_dsize < 2048) {
822                 alg = 2;
823             } else if (total_dsize < 524288) {
824                 alg = 4;
825             } else {
826                 alg = 3;
827             }
828         } else if (communicator_size < 128) {
829             if (total_dsize < 256) {
830                 alg = 1;
831             } else if (total_dsize < 512) {
832                 alg = 2;
833             } else if (total_dsize < 2048) {
834                 alg = 4;
835             } else if (total_dsize < 4096) {
836                 alg = 2;
837             } else {
838                 alg = 4;
839             }
840         } else if (communicator_size < 256) {
841             if (total_dsize < 256) {
842                 alg = 1;
843             } else if (total_dsize < 512) {
844                 alg = 2;
845             } else {
846                 alg = 4;
847             }
848         } else if (communicator_size < 512) {
849             if (total_dsize < 256) {
850                 alg = 1;
851             } else if (total_dsize < 1024) {
852                 alg = 2;
853             } else {
854                 alg = 4;
855             }
856         } else if (communicator_size < 1024) {
857             if (total_dsize < 512) {
858                 alg = 1;
859             } else if (total_dsize < 2048) {
860                 alg = 2;
861             } else if (total_dsize < 8192) {
862                 alg = 4;
863             } else if (total_dsize < 16384) {
864                 alg = 2;
865             } else {
866                 alg = 4;
867             }
868         } else if (communicator_size < 2048) {
869             if (total_dsize < 512) {
870                 alg = 1;
871             } else if (total_dsize < 4096) {
872                 alg = 2;
873             } else if (total_dsize < 16384) {
874                 alg = 4;
875             } else if (total_dsize < 32768) {
876                 alg = 2;
877             } else {
878                 alg = 4;
879             }
880         } else if (communicator_size < 4096) {
881             if (total_dsize < 512) {
882                 alg = 1;
883             } else if (total_dsize < 4096) {
884                 alg = 2;
885             } else {
886                 alg = 4;
887             }
888         } else {
889             if (total_dsize < 1024) {
890                 alg = 1;
891             } else if (total_dsize < 8192) {
892                 alg = 2;
893             } else {
894                 alg = 4;
895             }
896         }
897     }
898
899     return funcs[alg-1] (sbuf, rbuf, rcounts, dtype, op, comm);
900 }
901
902 int allgather__ompi(const void *sbuf, int scount,
903                     MPI_Datatype sdtype,
904                     void* rbuf, int rcount,
905                     MPI_Datatype rdtype,
906                     MPI_Comm  comm
907                     )
908 {
909     int communicator_size;
910     size_t dsize, total_dsize;
911     int alg = 1;
912     communicator_size = comm->size();
913     if (MPI_IN_PLACE != sbuf) {
914         dsize = sdtype->size();
915     } else {
916         dsize = rdtype->size();
917     }
918     total_dsize = dsize * (ptrdiff_t)scount;
919     int (*funcs[])(const void*, int, MPI_Datatype, void*, int, MPI_Datatype, MPI_Comm) = {
920         &allgather__NTSLR_NB,
921         &allgather__bruck,
922         &allgather__rdb,
923         &allgather__ring,
924         &allgather__ompi_neighborexchange,
925         &allgather__pair
926     };
927     /** Algorithms:
928      *  {1, "linear"},
929      *  {2, "bruck"},
930      *  {3, "recursive_doubling"},
931      *  {4, "ring"},
932      *  {5, "neighbor"},
933      *  {6, "two_proc"}
934      */
935     if (communicator_size == 2) {
936         alg = 6;
937     } else if (communicator_size < 32) {
938         alg = 3;
939     } else if (communicator_size < 64) {
940         if (total_dsize < 1024) {
941             alg = 3;
942         } else if (total_dsize < 65536) {
943             alg = 5;
944         } else {
945             alg = 4;
946         }
947     } else if (communicator_size < 128) {
948         if (total_dsize < 512) {
949             alg = 3;
950         } else if (total_dsize < 65536) {
951             alg = 5;
952         } else {
953             alg = 4;
954         }
955     } else if (communicator_size < 256) {
956         if (total_dsize < 512) {
957             alg = 3;
958         } else if (total_dsize < 131072) {
959             alg = 5;
960         } else if (total_dsize < 524288) {
961             alg = 4;
962         } else if (total_dsize < 1048576) {
963             alg = 5;
964         } else {
965             alg = 4;
966         }
967     } else if (communicator_size < 512) {
968         if (total_dsize < 32) {
969             alg = 3;
970         } else if (total_dsize < 128) {
971             alg = 2;
972         } else if (total_dsize < 1024) {
973             alg = 3;
974         } else if (total_dsize < 131072) {
975             alg = 5;
976         } else if (total_dsize < 524288) {
977             alg = 4;
978         } else if (total_dsize < 1048576) {
979             alg = 5;
980         } else {
981             alg = 4;
982         }
983     } else if (communicator_size < 1024) {
984         if (total_dsize < 64) {
985             alg = 3;
986         } else if (total_dsize < 256) {
987             alg = 2;
988         } else if (total_dsize < 2048) {
989             alg = 3;
990         } else {
991             alg = 5;
992         }
993     } else if (communicator_size < 2048) {
994         if (total_dsize < 4) {
995             alg = 3;
996         } else if (total_dsize < 8) {
997             alg = 2;
998         } else if (total_dsize < 16) {
999             alg = 3;
1000         } else if (total_dsize < 32) {
1001             alg = 2;
1002         } else if (total_dsize < 256) {
1003             alg = 3;
1004         } else if (total_dsize < 512) {
1005             alg = 2;
1006         } else if (total_dsize < 4096) {
1007             alg = 3;
1008         } else {
1009             alg = 5;
1010         }
1011     } else if (communicator_size < 4096) {
1012         if (total_dsize < 32) {
1013             alg = 2;
1014         } else if (total_dsize < 128) {
1015             alg = 3;
1016         } else if (total_dsize < 512) {
1017             alg = 2;
1018         } else if (total_dsize < 4096) {
1019             alg = 3;
1020         } else {
1021             alg = 5;
1022         }
1023     } else {
1024         if (total_dsize < 2) {
1025             alg = 3;
1026         } else if (total_dsize < 8) {
1027             alg = 2;
1028         } else if (total_dsize < 16) {
1029             alg = 3;
1030         } else if (total_dsize < 512) {
1031             alg = 2;
1032         } else if (total_dsize < 4096) {
1033             alg = 3;
1034         } else {
1035             alg = 5;
1036         }
1037     }
1038
1039     return funcs[alg-1](sbuf, scount, sdtype, rbuf, rcount, rdtype, comm);
1040
1041 }
1042
1043 int allgatherv__ompi(const void *sbuf, int scount,
1044                      MPI_Datatype sdtype,
1045                      void* rbuf, const int *rcounts,
1046                      const int *rdispls,
1047                      MPI_Datatype rdtype,
1048                      MPI_Comm  comm
1049                      )
1050 {
1051     int i;
1052     int communicator_size;
1053     size_t dsize, total_dsize;
1054     int alg = 1;
1055     communicator_size = comm->size();
1056     if (MPI_IN_PLACE != sbuf) {
1057         dsize = sdtype->size();
1058     } else {
1059         dsize = rdtype->size();
1060     }
1061
1062     total_dsize = 0;
1063     for (i = 0; i < communicator_size; i++) {
1064         total_dsize += dsize * rcounts[i];
1065     }
1066
1067     /* use the per-rank data size as basis, similar to allgather */
1068     size_t per_rank_dsize = total_dsize / communicator_size;
1069
1070     int (*funcs[])(const void*, int, MPI_Datatype, void*, const int*, const int*, MPI_Datatype, MPI_Comm) = {
1071         &allgatherv__GB,
1072         &allgatherv__ompi_bruck,
1073         &allgatherv__mpich_ring,
1074         &allgatherv__ompi_neighborexchange,
1075         &allgatherv__pair
1076     };
1077     /** Algorithms:
1078      *  {1, "default"},
1079      *  {2, "bruck"},
1080      *  {3, "ring"},
1081      *  {4, "neighbor"},
1082      *  {5, "two_proc"},
1083      */
1084     if (communicator_size == 2) {
1085         if (per_rank_dsize < 2048) {
1086             alg = 3;
1087         } else if (per_rank_dsize < 4096) {
1088             alg = 5;
1089         } else if (per_rank_dsize < 8192) {
1090             alg = 3;
1091         } else {
1092             alg = 5;
1093         }
1094     } else if (communicator_size < 8) {
1095         if (per_rank_dsize < 256) {
1096             alg = 1;
1097         } else if (per_rank_dsize < 4096) {
1098             alg = 4;
1099         } else if (per_rank_dsize < 8192) {
1100             alg = 3;
1101         } else if (per_rank_dsize < 16384) {
1102             alg = 4;
1103         } else if (per_rank_dsize < 262144) {
1104             alg = 2;
1105         } else {
1106             alg = 4;
1107         }
1108     } else if (communicator_size < 16) {
1109         if (per_rank_dsize < 1024) {
1110             alg = 1;
1111         } else {
1112             alg = 2;
1113         }
1114     } else if (communicator_size < 32) {
1115         if (per_rank_dsize < 128) {
1116             alg = 1;
1117         } else if (per_rank_dsize < 262144) {
1118             alg = 2;
1119         } else {
1120             alg = 3;
1121         }
1122     } else if (communicator_size < 64) {
1123         if (per_rank_dsize < 256) {
1124             alg = 1;
1125         } else if (per_rank_dsize < 8192) {
1126             alg = 2;
1127         } else {
1128             alg = 3;
1129         }
1130     } else if (communicator_size < 128) {
1131         if (per_rank_dsize < 256) {
1132             alg = 1;
1133         } else if (per_rank_dsize < 4096) {
1134             alg = 2;
1135         } else {
1136             alg = 3;
1137         }
1138     } else if (communicator_size < 256) {
1139         if (per_rank_dsize < 1024) {
1140             alg = 2;
1141         } else if (per_rank_dsize < 65536) {
1142             alg = 4;
1143         } else {
1144             alg = 3;
1145         }
1146     } else if (communicator_size < 512) {
1147         if (per_rank_dsize < 1024) {
1148             alg = 2;
1149         } else {
1150             alg = 3;
1151         }
1152     } else if (communicator_size < 1024) {
1153         if (per_rank_dsize < 512) {
1154             alg = 2;
1155         } else if (per_rank_dsize < 1024) {
1156             alg = 1;
1157         } else if (per_rank_dsize < 4096) {
1158             alg = 2;
1159         } else if (per_rank_dsize < 1048576) {
1160             alg = 4;
1161         } else {
1162             alg = 3;
1163         }
1164     } else {
1165         if (per_rank_dsize < 4096) {
1166             alg = 2;
1167         } else {
1168             alg = 4;
1169         }
1170     }
1171
1172     return funcs[alg-1](sbuf, scount, sdtype, rbuf, rcounts, rdispls, rdtype, comm);
1173 }
1174
1175 int gather__ompi(const void *sbuf, int scount,
1176                  MPI_Datatype sdtype,
1177                  void* rbuf, int rcount,
1178                  MPI_Datatype rdtype,
1179                  int root,
1180                  MPI_Comm  comm
1181                  )
1182 {
1183     int communicator_size, rank;
1184     size_t dsize, total_dsize;
1185     int alg = 1;
1186     communicator_size = comm->size();
1187     rank = comm->rank();
1188
1189     if (rank == root) {
1190         dsize = rdtype->size();
1191         total_dsize = dsize * rcount;
1192     } else {
1193         dsize = sdtype->size();
1194         total_dsize = dsize * scount;
1195     }
1196     int (*funcs[])(const void*, int, MPI_Datatype, void*, int, MPI_Datatype, int, MPI_Comm) = {
1197         &gather__ompi_basic_linear,
1198         &gather__ompi_binomial,
1199         &gather__ompi_linear_sync
1200     };
1201     /** Algorithms:
1202      *  {1, "basic_linear"},
1203      *  {2, "binomial"},
1204      *  {3, "linear_sync"},
1205      *
1206      * We do not make any rank specific checks since the params
1207      * should be uniform across ranks.
1208      */
1209     if (communicator_size < 4) {
1210         if (total_dsize < 2) {
1211             alg = 3;
1212         } else if (total_dsize < 4) {
1213             alg = 1;
1214         } else if (total_dsize < 32768) {
1215             alg = 2;
1216         } else if (total_dsize < 65536) {
1217             alg = 1;
1218         } else if (total_dsize < 131072) {
1219             alg = 2;
1220         } else {
1221             alg = 3;
1222         }
1223     } else if (communicator_size < 8) {
1224         if (total_dsize < 1024) {
1225             alg = 2;
1226         } else if (total_dsize < 8192) {
1227             alg = 1;
1228         } else if (total_dsize < 32768) {
1229             alg = 2;
1230         } else if (total_dsize < 262144) {
1231             alg = 1;
1232         } else {
1233             alg = 3;
1234         }
1235     } else if (communicator_size < 256) {
1236         alg = 2;
1237     } else if (communicator_size < 512) {
1238         if (total_dsize < 2048) {
1239             alg = 2;
1240         } else if (total_dsize < 8192) {
1241             alg = 1;
1242         } else {
1243             alg = 2;
1244         }
1245     } else {
1246         alg = 2;
1247     }
1248
1249     return funcs[alg-1](sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm);
1250 }
1251
1252
1253 int scatter__ompi(const void *sbuf, int scount,
1254                   MPI_Datatype sdtype,
1255                   void* rbuf, int rcount,
1256                   MPI_Datatype rdtype,
1257                   int root, MPI_Comm  comm
1258                   )
1259 {
1260     int communicator_size, rank;
1261     size_t dsize, total_dsize;
1262     int alg = 1;
1263
1264     communicator_size = comm->size();
1265     rank = comm->rank();
1266     if (root == rank) {
1267         dsize=sdtype->size();
1268         total_dsize = dsize * scount;
1269     } else {
1270         dsize=rdtype->size();
1271         total_dsize = dsize * rcount;
1272     }
1273     int (*funcs[])(const void*, int, MPI_Datatype, void*, int, MPI_Datatype, int, MPI_Comm) = {
1274         &scatter__ompi_basic_linear,
1275         &scatter__ompi_binomial,
1276         &scatter__ompi_basic_linear
1277     };
1278     /** Algorithms:
1279      *  {1, "basic_linear"},
1280      *  {2, "binomial"},
1281      *  {3, "linear_nb"},
1282      *
1283      * We do not make any rank specific checks since the params
1284      * should be uniform across ranks.
1285      */
1286     if (communicator_size < 4) {
1287         if (total_dsize < 2) {
1288             alg = 3;
1289         } else if (total_dsize < 131072) {
1290             alg = 1;
1291         } else if (total_dsize < 262144) {
1292             alg = 3;
1293         } else {
1294             alg = 1;
1295         }
1296     } else if (communicator_size < 8) {
1297         if (total_dsize < 2048) {
1298             alg = 2;
1299         } else if (total_dsize < 4096) {
1300             alg = 1;
1301         } else if (total_dsize < 8192) {
1302             alg = 2;
1303         } else if (total_dsize < 32768) {
1304             alg = 1;
1305         } else if (total_dsize < 1048576) {
1306             alg = 3;
1307         } else {
1308             alg = 1;
1309         }
1310     } else if (communicator_size < 16) {
1311         if (total_dsize < 16384) {
1312             alg = 2;
1313         } else if (total_dsize < 1048576) {
1314             alg = 3;
1315         } else {
1316             alg = 1;
1317         }
1318     } else if (communicator_size < 32) {
1319         if (total_dsize < 16384) {
1320             alg = 2;
1321         } else if (total_dsize < 32768) {
1322             alg = 1;
1323         } else {
1324             alg = 3;
1325         }
1326     } else if (communicator_size < 64) {
1327         if (total_dsize < 512) {
1328             alg = 2;
1329         } else if (total_dsize < 8192) {
1330             alg = 3;
1331         } else if (total_dsize < 16384) {
1332             alg = 2;
1333         } else {
1334             alg = 3;
1335         }
1336     } else {
1337         if (total_dsize < 512) {
1338             alg = 2;
1339         } else {
1340             alg = 3;
1341         }
1342     }
1343
1344     return funcs[alg-1](sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm);
1345 }
1346
1347 }
1348 }