Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Update copyright lines for 2023.
[simgrid.git] / src / smpi / colls / reduce_scatter / reduce_scatter-ompi.cpp
1 /* Copyright (c) 2013-2023. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
9  *                         University Research and Technology
10  *                         Corporation.  All rights reserved.
11  * Copyright (c) 2004-2012 The University of Tennessee and The University
12  *                         of Tennessee Research Foundation.  All rights
13  *                         reserved.
14  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
15  *                         University of Stuttgart.  All rights reserved.
16  * Copyright (c) 2004-2005 The Regents of the University of California.
17  *                         All rights reserved.
18  * Copyright (c) 2008      Sun Microsystems, Inc.  All rights reserved.
19  * Copyright (c) 2009      University of Houston. All rights reserved.
20  *
21  * Additional copyrights may follow
22  */
23
24 #include "../coll_tuned_topo.hpp"
25 #include "../colls_private.hpp"
26
27 /*
28  * Recursive-halving function is (*mostly*) copied from the BASIC coll module.
29  * I have removed the part which handles "large" message sizes
30  * (non-overlapping version of reduce_Scatter).
31  */
32
33 /* copied function (with appropriate renaming) starts here */
34
35 /*
36  *  reduce_scatter_ompi_basic_recursivehalving
37  *
38  *  Function:   - reduce scatter implementation using recursive-halving
39  *                algorithm
40  *  Accepts:    - same as MPI_Reduce_scatter()
41  *  Returns:    - MPI_SUCCESS or error code
42  *  Limitation: - Works only for commutative operations.
43  */
44 namespace simgrid::smpi {
45 int reduce_scatter__ompi_basic_recursivehalving(const void *sbuf,
46                                                 void *rbuf,
47                                                 const int *rcounts,
48                                                 MPI_Datatype dtype,
49                                                 MPI_Op op,
50                                                 MPI_Comm comm
51                                                 )
52 {
53     int i, rank, size, count, err = MPI_SUCCESS;
54     int tmp_size = 1, remain = 0, tmp_rank;
55     ptrdiff_t true_lb, true_extent, lb, extent, buf_size;
56     unsigned char *result_buf = nullptr, *result_buf_free = nullptr;
57
58     /* Initialize */
59     rank = comm->rank();
60     size = comm->size();
61
62     XBT_DEBUG("coll:tuned:reduce_scatter_ompi_basic_recursivehalving, rank %d", rank);
63     if ((op != MPI_OP_NULL && not op->is_commutative()))
64       throw std::invalid_argument(
65           "reduce_scatter ompi_basic_recursivehalving can only be used for commutative operations!");
66
67     /* Find displacements and the like */
68     int* disps = new int[size];
69
70     disps[0] = 0;
71     for (i = 0; i < (size - 1); ++i) {
72         disps[i + 1] = disps[i] + rcounts[i];
73     }
74     count = disps[size - 1] + rcounts[size - 1];
75
76     /* short cut the trivial case */
77     if (0 == count) {
78       delete[] disps;
79       return MPI_SUCCESS;
80     }
81
82     /* get datatype information */
83     dtype->extent(&lb, &extent);
84     dtype->extent(&true_lb, &true_extent);
85     buf_size = true_extent + (ptrdiff_t)(count - 1) * extent;
86
87     /* Handle MPI_IN_PLACE */
88     if (MPI_IN_PLACE == sbuf) {
89         sbuf = rbuf;
90     }
91
92     /* Allocate temporary receive buffer. */
93     unsigned char* recv_buf_free = smpi_get_tmp_recvbuffer(buf_size);
94     unsigned char* recv_buf      = recv_buf_free - lb;
95     if (nullptr == recv_buf_free) {
96       err = MPI_ERR_OTHER;
97       goto cleanup;
98     }
99
100     /* allocate temporary buffer for results */
101     result_buf_free = smpi_get_tmp_sendbuffer(buf_size);
102     result_buf = result_buf_free - lb;
103
104     /* copy local buffer into the temporary results */
105     err =Datatype::copy(sbuf, count, dtype, result_buf, count, dtype);
106     if (MPI_SUCCESS != err) goto cleanup;
107
108     /* figure out power of two mapping: grow until larger than
109        comm size, then go back one, to get the largest power of
110        two less than comm size */
111     while (tmp_size <= size) tmp_size <<= 1;
112     tmp_size >>= 1;
113     remain = size - tmp_size;
114
115     /* If comm size is not a power of two, have the first "remain"
116        procs with an even rank send to rank + 1, leaving a power of
117        two procs to do the rest of the algorithm */
118     if (rank < 2 * remain) {
119         if ((rank & 1) == 0) {
120             Request::send(result_buf, count, dtype, rank + 1,
121                                     COLL_TAG_REDUCE_SCATTER,
122                                     comm);
123             /* we don't participate from here on out */
124             tmp_rank = -1;
125         } else {
126             Request::recv(recv_buf, count, dtype, rank - 1,
127                                     COLL_TAG_REDUCE_SCATTER,
128                                     comm, MPI_STATUS_IGNORE);
129
130             /* integrate their results into our temp results */
131             if(op!=MPI_OP_NULL) op->apply( recv_buf, result_buf, &count, dtype);
132
133             /* adjust rank to be the bottom "remain" ranks */
134             tmp_rank = rank / 2;
135         }
136     } else {
137         /* just need to adjust rank to show that the bottom "even
138            remain" ranks dropped out */
139         tmp_rank = rank - remain;
140     }
141
142     /* For ranks not kicked out by the above code, perform the
143        recursive halving */
144     if (tmp_rank >= 0) {
145         int mask, send_index, recv_index, last_index;
146
147         /* recalculate disps and rcounts to account for the
148            special "remainder" processes that are no longer doing
149            anything */
150         int* tmp_rcounts = new int[tmp_size];
151         int* tmp_disps   = new int[tmp_size];
152
153         for (i = 0 ; i < tmp_size ; ++i) {
154             if (i < remain) {
155                 /* need to include old neighbor as well */
156                 tmp_rcounts[i] = rcounts[i * 2 + 1] + rcounts[i * 2];
157             } else {
158                 tmp_rcounts[i] = rcounts[i + remain];
159             }
160         }
161
162         tmp_disps[0] = 0;
163         for (i = 0; i < tmp_size - 1; ++i) {
164             tmp_disps[i + 1] = tmp_disps[i] + tmp_rcounts[i];
165         }
166
167         /* do the recursive halving communication.  Don't use the
168            dimension information on the communicator because I
169            think the information is invalidated by our "shrinking"
170            of the communicator */
171         mask = tmp_size >> 1;
172         send_index = recv_index = 0;
173         last_index = tmp_size;
174         while (mask > 0) {
175             int tmp_peer, peer, send_count, recv_count;
176             MPI_Request request;
177
178             tmp_peer = tmp_rank ^ mask;
179             peer = (tmp_peer < remain) ? tmp_peer * 2 + 1 : tmp_peer + remain;
180
181             /* figure out if we're sending, receiving, or both */
182             send_count = recv_count = 0;
183             if (tmp_rank < tmp_peer) {
184                 send_index = recv_index + mask;
185                 for (i = send_index ; i < last_index ; ++i) {
186                     send_count += tmp_rcounts[i];
187                 }
188                 for (i = recv_index ; i < send_index ; ++i) {
189                     recv_count += tmp_rcounts[i];
190                 }
191             } else {
192                 recv_index = send_index + mask;
193                 for (i = send_index ; i < recv_index ; ++i) {
194                     send_count += tmp_rcounts[i];
195                 }
196                 for (i = recv_index ; i < last_index ; ++i) {
197                     recv_count += tmp_rcounts[i];
198                 }
199             }
200
201             /* actual data transfer.  Send from result_buf,
202                receive into recv_buf */
203             if (send_count > 0 && recv_count != 0) {
204                 request=Request::irecv(recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
205                                          recv_count, dtype, peer,
206                                          COLL_TAG_REDUCE_SCATTER,
207                                          comm);
208                 if (MPI_SUCCESS != err) {
209                   delete[] tmp_rcounts;
210                   delete[] tmp_disps;
211                   goto cleanup;
212                 }
213             }
214             if (recv_count > 0 && send_count != 0) {
215                 Request::send(result_buf + (ptrdiff_t)tmp_disps[send_index] * extent,
216                                         send_count, dtype, peer,
217                                         COLL_TAG_REDUCE_SCATTER,
218                                         comm);
219                 if (MPI_SUCCESS != err) {
220                   delete[] tmp_rcounts;
221                   delete[] tmp_disps;
222                   goto cleanup;
223                 }
224             }
225             if (send_count > 0 && recv_count != 0) {
226                 Request::wait(&request, MPI_STATUS_IGNORE);
227             }
228
229             /* if we received something on this step, push it into
230                the results buffer */
231             if (recv_count > 0) {
232                 if(op!=MPI_OP_NULL) op->apply(
233                                recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
234                                result_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
235                                &recv_count, dtype);
236             }
237
238             /* update for next iteration */
239             send_index = recv_index;
240             last_index = recv_index + mask;
241             mask >>= 1;
242         }
243
244         /* copy local results from results buffer into real receive buffer */
245         if (0 != rcounts[rank]) {
246             err = Datatype::copy(result_buf + disps[rank] * extent,
247                                        rcounts[rank], dtype,
248                                        rbuf, rcounts[rank], dtype);
249             if (MPI_SUCCESS != err) {
250               delete[] tmp_rcounts;
251               delete[] tmp_disps;
252               goto cleanup;
253             }
254         }
255
256         delete[] tmp_rcounts;
257         delete[] tmp_disps;
258     }
259
260     /* Now fix up the non-power of two case, by having the odd
261        procs send the even procs the proper results */
262     if (rank < (2 * remain)) {
263         if ((rank & 1) == 0) {
264             if (rcounts[rank]) {
265                 Request::recv(rbuf, rcounts[rank], dtype, rank + 1,
266                                         COLL_TAG_REDUCE_SCATTER,
267                                         comm, MPI_STATUS_IGNORE);
268             }
269         } else {
270             if (rcounts[rank - 1]) {
271                 Request::send(result_buf + disps[rank - 1] * extent,
272                                         rcounts[rank - 1], dtype, rank - 1,
273                                         COLL_TAG_REDUCE_SCATTER,
274                                         comm);
275             }
276         }
277     }
278
279  cleanup:
280     delete[] disps;
281     if (nullptr != recv_buf_free)
282       smpi_free_tmp_buffer(recv_buf_free);
283     if (nullptr != result_buf_free)
284       smpi_free_tmp_buffer(result_buf_free);
285
286     return err;
287 }
288
289 /* copied function (with appropriate renaming) ends here */
290
291
292 /*
293  *   Coll_reduce_scatter_ompi_ring::reduce_scatter
294  *
295  *   Function:       Ring algorithm for reduce_scatter operation
296  *   Accepts:        Same as MPI_Reduce_scatter()
297  *   Returns:        MPI_SUCCESS or error code
298  *
299  *   Description:    Implements ring algorithm for reduce_scatter:
300  *                   the block sizes defined in rcounts are exchanged and
301  8                    updated until they reach proper destination.
302  *                   Algorithm requires 2 * max(rcounts) extra buffering
303  *
304  *   Limitations:    The algorithm DOES NOT preserve order of operations so it
305  *                   can be used only for commutative operations.
306  *         Example on 5 nodes:
307  *         Initial state
308  *   #      0              1             2              3             4
309  *        [00]           [10]   ->     [20]           [30]           [40]
310  *        [01]           [11]          [21]  ->       [31]           [41]
311  *        [02]           [12]          [22]           [32]  ->       [42]
312  *    ->  [03]           [13]          [23]           [33]           [43] --> ..
313  *        [04]  ->       [14]          [24]           [34]           [44]
314  *
315  *        COMPUTATION PHASE
316  *         Step 0: rank r sends block (r-1) to rank (r+1) and
317  *                 receives block (r+1) from rank (r-1) [with wraparound].
318  *   #      0              1             2              3             4
319  *        [00]           [10]        [10+20]   ->     [30]           [40]
320  *        [01]           [11]          [21]          [21+31]  ->     [41]
321  *    ->  [02]           [12]          [22]           [32]         [32+42] -->..
322  *      [43+03] ->       [13]          [23]           [33]           [43]
323  *        [04]         [04+14]  ->     [24]           [34]           [44]
324  *
325  *         Step 1:
326  *   #      0              1             2              3             4
327  *        [00]           [10]        [10+20]       [10+20+30] ->     [40]
328  *    ->  [01]           [11]          [21]          [21+31]      [21+31+41] ->
329  *     [32+42+02] ->     [12]          [22]           [32]         [32+42]
330  *        [03]        [43+03+13] ->    [23]           [33]           [43]
331  *        [04]         [04+14]      [04+14+24]  ->    [34]           [44]
332  *
333  *         Step 2:
334  *   #      0              1             2              3             4
335  *     -> [00]           [10]        [10+20]       [10+20+30]   [10+20+30+40] ->
336  *   [21+31+41+01]->     [11]          [21]          [21+31]      [21+31+41]
337  *     [32+42+02]   [32+42+02+12]->    [22]           [32]         [32+42]
338  *        [03]        [43+03+13]   [43+03+13+23]->    [33]           [43]
339  *        [04]         [04+14]      [04+14+24]    [04+14+24+34] ->   [44]
340  *
341  *         Step 3:
342  *   #      0             1              2              3             4
343  * [10+20+30+40+00]     [10]         [10+20]       [10+20+30]   [10+20+30+40]
344  *  [21+31+41+01] [21+31+41+01+11]     [21]          [21+31]      [21+31+41]
345  *    [32+42+02]   [32+42+02+12] [32+42+02+12+22]     [32]         [32+42]
346  *       [03]        [43+03+13]    [43+03+13+23] [43+03+13+23+33]    [43]
347  *       [04]         [04+14]       [04+14+24]    [04+14+24+34] [04+14+24+34+44]
348  *    DONE :)
349  *
350  */
351 int reduce_scatter__ompi_ring(const void *sbuf, void *rbuf, const int *rcounts,
352                               MPI_Datatype dtype,
353                               MPI_Op op,
354                               MPI_Comm comm
355                               )
356 {
357     int ret, line, rank, size, i, k, recv_from, send_to, total_count, max_block_count;
358     int inbi;
359     unsigned char *tmpsend = nullptr, *tmprecv = nullptr, *accumbuf = nullptr, *accumbuf_free = nullptr;
360     unsigned char *inbuf_free[2] = {nullptr, nullptr}, *inbuf[2] = {nullptr, nullptr};
361     ptrdiff_t true_lb, true_extent, lb, extent, max_real_segsize;
362     MPI_Request reqs[2] = {nullptr, nullptr};
363
364     size = comm->size();
365     rank = comm->rank();
366
367     XBT_DEBUG(  "coll:tuned:reduce_scatter_ompi_ring rank %d, size %d",
368                  rank, size);
369
370     /* Determine the maximum number of elements per node,
371        corresponding block size, and displacements array.
372     */
373     int* displs = new int[size];
374
375     displs[0] = 0;
376     total_count = rcounts[0];
377     max_block_count = rcounts[0];
378     for (i = 1; i < size; i++) {
379         displs[i] = total_count;
380         total_count += rcounts[i];
381         if (max_block_count < rcounts[i]) max_block_count = rcounts[i];
382     }
383
384     /* Special case for size == 1 */
385     if (1 == size) {
386         if (MPI_IN_PLACE != sbuf) {
387             ret = Datatype::copy((char*)sbuf, total_count, dtype, (char*)rbuf, total_count, dtype);
388             if (ret < 0) { line = __LINE__; goto error_hndl; }
389         }
390         delete[] displs;
391         return MPI_SUCCESS;
392     }
393
394     /* Allocate and initialize temporary buffers, we need:
395        - a temporary buffer to perform reduction (size total_count) since
396        rbuf can be of rcounts[rank] size.
397        - up to two temporary buffers used for communication/computation overlap.
398     */
399     dtype->extent(&lb, &extent);
400     dtype->extent(&true_lb, &true_extent);
401
402     max_real_segsize = true_extent + (ptrdiff_t)(max_block_count - 1) * extent;
403
404     accumbuf_free = smpi_get_tmp_recvbuffer(true_extent + (ptrdiff_t)(total_count - 1) * extent);
405     if (nullptr == accumbuf_free) {
406       ret  = -1;
407       line = __LINE__;
408       goto error_hndl;
409     }
410     accumbuf = accumbuf_free - lb;
411
412     inbuf_free[0] = smpi_get_tmp_sendbuffer(max_real_segsize);
413     if (nullptr == inbuf_free[0]) {
414       ret  = -1;
415       line = __LINE__;
416       goto error_hndl;
417     }
418     inbuf[0] = inbuf_free[0] - lb;
419     if (size > 2) {
420       inbuf_free[1] = smpi_get_tmp_sendbuffer(max_real_segsize);
421       if (nullptr == inbuf_free[1]) {
422         ret  = -1;
423         line = __LINE__;
424         goto error_hndl;
425       }
426       inbuf[1] = inbuf_free[1] - lb;
427     }
428
429     /* Handle MPI_IN_PLACE for size > 1 */
430     if (MPI_IN_PLACE == sbuf) {
431         sbuf = rbuf;
432     }
433
434     ret = Datatype::copy((char*)sbuf, total_count, dtype, accumbuf, total_count, dtype);
435     if (ret < 0) { line = __LINE__; goto error_hndl; }
436
437     /* Computation loop */
438
439     /*
440        For each of the remote nodes:
441        - post irecv for block (r-2) from (r-1) with wrap around
442        - send block (r-1) to (r+1)
443        - in loop for every step k = 2 .. n
444        - post irecv for block (r - 1 + n - k) % n
445        - wait on block (r + n - k) % n to arrive
446        - compute on block (r + n - k ) % n
447        - send block (r + n - k) % n
448        - wait on block (r)
449        - compute on block (r)
450        - copy block (r) to rbuf
451        Note that we must be careful when computing the beginning of buffers and
452        for send operations and computation we must compute the exact block size.
453     */
454     send_to = (rank + 1) % size;
455     recv_from = (rank + size - 1) % size;
456
457     inbi = 0;
458     /* Initialize first receive from the neighbor on the left */
459     reqs[inbi]=Request::irecv(inbuf[inbi], max_block_count, dtype, recv_from,
460                              COLL_TAG_REDUCE_SCATTER, comm
461                              );
462     tmpsend = accumbuf + (ptrdiff_t)displs[recv_from] * extent;
463     Request::send(tmpsend, rcounts[recv_from], dtype, send_to,
464                             COLL_TAG_REDUCE_SCATTER,
465                              comm);
466
467     for (k = 2; k < size; k++) {
468         const int prevblock = (rank + size - k) % size;
469
470         inbi = inbi ^ 0x1;
471
472         /* Post irecv for the current block */
473         reqs[inbi]=Request::irecv(inbuf[inbi], max_block_count, dtype, recv_from,
474                                  COLL_TAG_REDUCE_SCATTER, comm
475                                  );
476
477         /* Wait on previous block to arrive */
478         Request::wait(&reqs[inbi ^ 0x1], MPI_STATUS_IGNORE);
479
480         /* Apply operation on previous block: result goes to rbuf
481            rbuf[prevblock] = inbuf[inbi ^ 0x1] (op) rbuf[prevblock]
482         */
483         tmprecv = accumbuf + (ptrdiff_t)displs[prevblock] * extent;
484         if (op != MPI_OP_NULL)
485           op->apply(inbuf[inbi ^ 0x1], tmprecv, &rcounts[prevblock], dtype);
486
487         /* send previous block to send_to */
488         Request::send(tmprecv, rcounts[prevblock], dtype, send_to,
489                                 COLL_TAG_REDUCE_SCATTER,
490                                  comm);
491     }
492
493     /* Wait on the last block to arrive */
494     Request::wait(&reqs[inbi], MPI_STATUS_IGNORE);
495
496     /* Apply operation on the last block (my block)
497        rbuf[rank] = inbuf[inbi] (op) rbuf[rank] */
498     tmprecv = accumbuf + (ptrdiff_t)displs[rank] * extent;
499     if (op != MPI_OP_NULL)
500       op->apply(inbuf[inbi], tmprecv, &rcounts[rank], dtype);
501
502     /* Copy result from tmprecv to rbuf */
503     ret = Datatype::copy(tmprecv, rcounts[rank], dtype, (char*)rbuf, rcounts[rank], dtype);
504     if (ret < 0) { line = __LINE__; goto error_hndl; }
505
506     delete[] displs;
507     if (nullptr != accumbuf_free)
508       smpi_free_tmp_buffer(accumbuf_free);
509     if (nullptr != inbuf_free[0])
510       smpi_free_tmp_buffer(inbuf_free[0]);
511     if (nullptr != inbuf_free[1])
512       smpi_free_tmp_buffer(inbuf_free[1]);
513
514     return MPI_SUCCESS;
515
516  error_hndl:
517     XBT_DEBUG( "%s:%4d\tRank %d Error occurred %d\n",
518                  __FILE__, line, rank, ret);
519     delete[] displs;
520     if (nullptr != accumbuf_free)
521       smpi_free_tmp_buffer(accumbuf_free);
522     if (nullptr != inbuf_free[0])
523       smpi_free_tmp_buffer(inbuf_free[0]);
524     if (nullptr != inbuf_free[1])
525       smpi_free_tmp_buffer(inbuf_free[1]);
526     return ret;
527 }
528
529 static int ompi_sum_counts(const int *counts, int *displs, int nprocs_rem, int lo, int hi)
530 {
531     /* Adjust lo and hi for taking into account blocks of excluded processes */
532     lo = (lo < nprocs_rem) ? lo * 2 : lo + nprocs_rem;
533     hi = (hi < nprocs_rem) ? hi * 2 + 1 : hi + nprocs_rem;
534     return displs[hi] + counts[hi] - displs[lo];
535 }
536
537 /*
538  * ompi_mirror_perm: Returns mirror permutation of nbits low-order bits
539  *                   of x [*].
540  * [*] Warren Jr., Henry S. Hacker's Delight (2ed). 2013.
541  *     Chapter 7. Rearranging Bits and Bytes.
542  */
543 static unsigned int ompi_mirror_perm(unsigned int x, int nbits)
544 {
545     x = (((x & 0xaaaaaaaa) >> 1) | ((x & 0x55555555) << 1));
546     x = (((x & 0xcccccccc) >> 2) | ((x & 0x33333333) << 2));
547     x = (((x & 0xf0f0f0f0) >> 4) | ((x & 0x0f0f0f0f) << 4));
548     x = (((x & 0xff00ff00) >> 8) | ((x & 0x00ff00ff) << 8));
549     x = ((x >> 16) | (x << 16));
550     return x >> (sizeof(x) * 8 - nbits);
551 }
552 /*
553  * ompi_coll_base_reduce_scatter_intra_butterfly
554  *
555  * Function:  Butterfly algorithm for reduce_scatter
556  * Accepts:   Same as MPI_Reduce_scatter
557  * Returns:   MPI_SUCCESS or error code
558  *
559  * Description:  Implements butterfly algorithm for MPI_Reduce_scatter [*].
560  *               The algorithm can be used both by commutative and non-commutative
561  *               operations, for power-of-two and non-power-of-two number of processes.
562  *
563  * [*] J.L. Traff. An improved Algorithm for (non-commutative) Reduce-scatter
564  *     with an Application // Proc. of EuroPVM/MPI, 2005. -- pp. 129-137.
565  *
566  * Time complexity: O(m\lambda + log(p)\alpha + m\beta + m\gamma),
567  *   where m = sum of rcounts[], p = comm_size
568  * Memory requirements (per process): 2 * m * typesize + comm_size
569  *
570  * Example: comm_size=6, nprocs_pof2=4, nprocs_rem=2, rcounts[]=1, sbuf=[0,1,...,5]
571  * Step 1. Reduce the number of processes to 4
572  * rank 0: [0|1|2|3|4|5]: send to 1: vrank -1
573  * rank 1: [0|1|2|3|4|5]: recv from 0, op: vrank 0: [0|2|4|6|8|10]
574  * rank 2: [0|1|2|3|4|5]: send to 3: vrank -1
575  * rank 3: [0|1|2|3|4|5]: recv from 2, op: vrank 1: [0|2|4|6|8|10]
576  * rank 4: [0|1|2|3|4|5]: vrank 2: [0|1|2|3|4|5]
577  * rank 5: [0|1|2|3|4|5]: vrank 3: [0|1|2|3|4|5]
578  *
579  * Step 2. Butterfly. Buffer of 6 elements is divided into 4 blocks.
580  * Round 1 (mask=1, nblocks=2)
581  * 0: vrank -1
582  * 1: vrank  0 [0 2|4 6|8|10]: exch with 1: send [2,3], recv [0,1]: [0 4|8 12|*|*]
583  * 2: vrank -1
584  * 3: vrank  1 [0 2|4 6|8|10]: exch with 0: send [0,1], recv [2,3]: [**|**|16|20]
585  * 4: vrank  2 [0 1|2 3|4|5] : exch with 3: send [2,3], recv [0,1]: [0 2|4 6|*|*]
586  * 5: vrank  3 [0 1|2 3|4|5] : exch with 2: send [0,1], recv [2,3]: [**|**|8|10]
587  *
588  * Round 2 (mask=2, nblocks=1)
589  * 0: vrank -1
590  * 1: vrank  0 [0 4|8 12|*|*]: exch with 2: send [1], recv [0]: [0 6|**|*|*]
591  * 2: vrank -1
592  * 3: vrank  1 [**|**|16|20] : exch with 3: send [3], recv [2]: [**|**|24|*]
593  * 4: vrank  2 [0 2|4 6|*|*] : exch with 0: send [0], recv [1]: [**|12 18|*|*]
594  * 5: vrank  3 [**|**|8|10]  : exch with 1: send [2], recv [3]: [**|**|*|30]
595  *
596  * Step 3. Exchange with remote process according to a mirror permutation:
597  *         mperm(0)=0, mperm(1)=2, mperm(2)=1, mperm(3)=3
598  * 0: vrank -1: recv "0" from process 0
599  * 1: vrank  0 [0 6|**|*|*]: send "0" to 0, copy "6" to rbuf (mperm(0)=0)
600  * 2: vrank -1: recv result "12" from process 4
601  * 3: vrank  1 [**|**|24|*]
602  * 4: vrank  2 [**|12 18|*|*]: send "12" to 2, send "18" to 3, recv "24" from 3
603  * 5: vrank  3 [**|**|*|30]: copy "30" to rbuf (mperm(3)=3)
604  */
605 int reduce_scatter__ompi_butterfly(
606     const void *sbuf, void *rbuf, const int *rcounts, MPI_Datatype dtype,
607     MPI_Op op, MPI_Comm comm)
608 {
609     char *tmpbuf[2] = {NULL, NULL}, *psend, *precv;
610     int *displs = NULL, index;
611     ptrdiff_t span, gap, totalcount, extent;
612     int err = MPI_SUCCESS;
613     int comm_size = comm->size();
614     int rank = comm->rank();
615     int vrank = -1;
616     int nprocs_rem = 0;
617
618     XBT_DEBUG("coll:base:reduce_scatter_intra_butterfly: rank %d/%d",
619                  rank, comm_size);
620     if (comm_size < 2)
621         return MPI_SUCCESS;
622
623     displs = (int*)malloc(sizeof(*displs) * comm_size);
624     if (NULL == displs) {
625         err = MPI_ERR_OTHER;
626         goto cleanup_and_return;
627     }
628     displs[0] = 0;
629     for (int i = 1; i < comm_size; i++) {
630         displs[i] = displs[i - 1] + rcounts[i - 1];
631     }
632     totalcount = displs[comm_size - 1] + rcounts[comm_size - 1];
633     dtype->extent(&gap, &extent);
634     span = extent * totalcount;
635     tmpbuf[0] = (char*)malloc(span);
636     tmpbuf[1] = (char*)malloc(span);
637     if (NULL == tmpbuf[0] || NULL == tmpbuf[1]) {
638         err = MPI_ERR_OTHER;
639         goto cleanup_and_return;
640     }
641     psend = tmpbuf[0] - gap;
642     precv = tmpbuf[1] - gap;
643
644     if (sbuf != MPI_IN_PLACE) {
645         err = Datatype::copy(sbuf, totalcount, dtype, psend, totalcount, dtype);
646         if (MPI_SUCCESS != err) { goto cleanup_and_return; }
647     } else {
648         err = Datatype::copy(rbuf, totalcount, dtype, psend, totalcount, dtype);
649         if (MPI_SUCCESS != err) { goto cleanup_and_return; }
650     }
651
652     /*
653      * Step 1. Reduce the number of processes to the nearest lower power of two
654      * p' = 2^{\floor{\log_2 p}} by removing r = p - p' processes.
655      * In the first 2r processes (ranks 0 to 2r - 1), all the even ranks send
656      * the input vector to their neighbor (rank + 1) and all the odd ranks recv
657      * the input vector and perform local reduction.
658      * The odd ranks (0 to 2r - 1) contain the reduction with the input
659      * vector on their neighbors (the even ranks). The first r odd
660      * processes and the p - 2r last processes are renumbered from
661      * 0 to 2^{\floor{\log_2 p}} - 1. Even ranks do not participate in the
662      * rest of the algorithm.
663      */
664
665     /* Find nearest power-of-two less than or equal to comm_size */
666     int nprocs_pof2, size;
667     for( nprocs_pof2 = 1, size = comm_size; size > 0; size >>= 1, nprocs_pof2 <<= 1 );
668     nprocs_pof2 >>= 1;
669
670     nprocs_rem = comm_size - nprocs_pof2;
671     int log2_size;
672     for (log2_size = 0, size = 1; size < nprocs_pof2; ++log2_size, size <<= 1);
673
674     if (rank < 2 * nprocs_rem) {
675         if ((rank % 2) == 0) {
676             /* Even process */
677             Request::send(psend, totalcount, dtype, rank + 1,
678                                     COLL_TAG_REDUCE_SCATTER, comm);
679             /* This process does not participate in the rest of the algorithm */
680             vrank = -1;
681         } else {
682             /* Odd process */
683             Request::recv(precv, totalcount, dtype, rank - 1,
684                                     COLL_TAG_REDUCE_SCATTER, comm, MPI_STATUS_IGNORE);
685             op->apply(precv, psend, (int*)&totalcount, dtype);
686             /* Adjust rank to be the bottom "remain" ranks */
687             vrank = rank / 2;
688         }
689     } else {
690         /* Adjust rank to show that the bottom "even remain" ranks dropped out */
691         vrank = rank - nprocs_rem;
692     }
693
694     if (vrank != -1) {
695         /*
696          * Now, psend vector of size totalcount is divided into nprocs_pof2 blocks:
697          * block 0:   rcounts[0] and rcounts[1] -- for process 0 and 1
698          * block 1:   rcounts[2] and rcounts[3] -- for process 2 and 3
699          * ...
700          * block r-1: rcounts[2*(r-1)] and rcounts[2*(r-1)+1]
701          * block r:   rcounts[r+r]
702          * block r+1: rcounts[r+r+1]
703          * ...
704          * block nprocs_pof2 - 1: rcounts[r+nprocs_pof2-1]
705          */
706         int nblocks = nprocs_pof2, send_index = 0, recv_index = 0;
707         for (int mask = 1; mask < nprocs_pof2; mask <<= 1) {
708             int vpeer = vrank ^ mask;
709             int peer = (vpeer < nprocs_rem) ? vpeer * 2 + 1 : vpeer + nprocs_rem;
710
711             nblocks /= 2;
712             if ((vrank & mask) == 0) {
713                 /* Send the upper half of reduction buffer, recv the lower half */
714                 send_index += nblocks;
715             } else {
716                 /* Send the upper half of reduction buffer, recv the lower half */
717                 recv_index += nblocks;
718             }
719
720             /* Send blocks: [send_index, send_index + nblocks - 1] */
721             int send_count = ompi_sum_counts(rcounts, displs, nprocs_rem,
722                                              send_index, send_index + nblocks - 1);
723             index = (send_index < nprocs_rem) ? 2 * send_index : nprocs_rem + send_index;
724             ptrdiff_t sdispl = displs[index];
725
726             /* Recv blocks: [recv_index, recv_index + nblocks - 1] */
727             int recv_count = ompi_sum_counts(rcounts, displs, nprocs_rem,
728                                              recv_index, recv_index + nblocks - 1);
729             index = (recv_index < nprocs_rem) ? 2 * recv_index : nprocs_rem + recv_index;
730             ptrdiff_t rdispl = displs[index];
731
732             Request::sendrecv(psend + (ptrdiff_t)sdispl * extent, send_count,
733                                           dtype, peer, COLL_TAG_REDUCE_SCATTER,
734                                           precv + (ptrdiff_t)rdispl * extent, recv_count,
735                                           dtype, peer, COLL_TAG_REDUCE_SCATTER,
736                                           comm, MPI_STATUS_IGNORE);
737
738             if (vrank < vpeer) {
739                 /* precv = psend <op> precv */
740                 op->apply(psend + (ptrdiff_t)rdispl * extent,
741                                precv + (ptrdiff_t)rdispl * extent, &recv_count, dtype);
742                 char *p = psend;
743                 psend = precv;
744                 precv = p;
745             } else {
746                 /* psend = precv <op> psend */
747                 op->apply(precv + (ptrdiff_t)rdispl * extent,
748                                psend + (ptrdiff_t)rdispl * extent, &recv_count, dtype);
749             }
750             send_index = recv_index;
751         }
752         /*
753          * psend points to the result block [send_index]
754          * Exchange results with remote process according to a mirror permutation.
755          */
756         int vpeer = ompi_mirror_perm(vrank, log2_size);
757         int peer = (vpeer < nprocs_rem) ? vpeer * 2 + 1 : vpeer + nprocs_rem;
758         index = (send_index < nprocs_rem) ? 2 * send_index : nprocs_rem + send_index;
759
760         if (vpeer < nprocs_rem) {
761             /*
762              * Process has two blocks: for excluded process and own.
763              * Send the first block to excluded process.
764              */
765             Request::send(psend + (ptrdiff_t)displs[index] * extent,
766                                     rcounts[index], dtype, peer - 1,
767                                     COLL_TAG_REDUCE_SCATTER,
768                                     comm);
769         }
770
771         /* If process has two blocks, then send the second block (own block) */
772         if (vpeer < nprocs_rem)
773             index++;
774         if (vpeer != vrank) {
775             Request::sendrecv(psend + (ptrdiff_t)displs[index] * extent,
776                                           rcounts[index], dtype, peer,
777                                           COLL_TAG_REDUCE_SCATTER,
778                                           rbuf, rcounts[rank], dtype, peer,
779                                           COLL_TAG_REDUCE_SCATTER,
780                                           comm, MPI_STATUS_IGNORE);
781         } else {
782             err = Datatype::copy(psend + (ptrdiff_t)displs[rank] * extent, rcounts[rank], dtype,
783                                  rbuf, rcounts[rank], dtype);
784             if (MPI_SUCCESS != err) { goto cleanup_and_return; }
785         }
786
787     } else {
788         /* Excluded process: receive result */
789         int vpeer = ompi_mirror_perm((rank + 1) / 2, log2_size);
790         int peer = (vpeer < nprocs_rem) ? vpeer * 2 + 1 : vpeer + nprocs_rem;
791         Request::recv(rbuf, rcounts[rank], dtype, peer,
792                                 COLL_TAG_REDUCE_SCATTER, comm,
793                                 MPI_STATUS_IGNORE);
794     }
795
796 cleanup_and_return:
797     if (displs)
798         free(displs);
799     if (tmpbuf[0])
800         free(tmpbuf[0]);
801     if (tmpbuf[1])
802         free(tmpbuf[1]);
803     return err;
804 }
805 } // namespace simgrid::smpi