/* Copyright (c) 2013-2019. The SimGrid Team.
 * All rights reserved. */
/* This program is free software; you can redistribute it and/or modify it
 * under the terms of the license (GNU LGPL) which comes with this package. */
/* (C) 2001 by Argonne National Laboratory.
 * See COPYRIGHT in top-level directory. */
/* Copyright (c) 2001-2014, The Ohio State University. All rights
 * reserved.
 *
 * This file is part of the MVAPICH2 software package developed by the
 * team members of The Ohio State University's Network-Based Computing
 * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
 *
 * For detailed copyright and licensing information, please refer to the
 * copyright file COPYRIGHT in the top level MVAPICH2 directory. */
#include "../colls_private.hpp"

#include <algorithm> // std::max for temporary-buffer sizing
29 int Coll_allreduce_mvapich2_rs::allreduce(const void *sendbuf,
32 MPI_Datatype datatype,
33 MPI_Op op, MPI_Comm comm)
35 int mpi_errno = MPI_SUCCESS;
37 int mask, pof2, i, send_idx, recv_idx, last_idx, send_cnt;
38 int dst, is_commutative, rem, newdst, recv_cnt;
39 MPI_Aint true_lb, true_extent, extent;
47 int comm_size = comm->size();
48 int rank = comm->rank();
50 is_commutative = (op==MPI_OP_NULL || op->is_commutative());
52 /* need to allocate temporary buffer to store incoming data */
53 datatype->extent(&true_lb, &true_extent);
54 extent = datatype->get_extent();
56 unsigned char* tmp_buf_free = smpi_get_tmp_recvbuffer(count * std::max(extent, true_extent));
58 /* adjust for potential negative lower bound in datatype */
59 unsigned char* tmp_buf = tmp_buf_free - true_lb;
61 /* copy local data into recvbuf */
62 if (sendbuf != MPI_IN_PLACE) {
64 Datatype::copy(sendbuf, count, datatype, recvbuf, count,
68 /* find nearest power-of-two less than or equal to comm_size */
69 for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
72 rem = comm_size - pof2;
74 /* In the non-power-of-two case, all even-numbered
75 processes of rank < 2*rem send their data to
76 (rank+1). These even-numbered processes no longer
77 participate in the algorithm until the very end. The
78 remaining processes form a nice power-of-two. */
83 Request::send(recvbuf, count, datatype, rank + 1,
84 COLL_TAG_ALLREDUCE, comm);
86 /* temporarily set the rank to -1 so that this
87 process does not pariticipate in recursive
92 Request::recv(tmp_buf, count, datatype, rank - 1,
93 COLL_TAG_ALLREDUCE, comm,
95 /* do the reduction on received data. since the
96 ordering is right, it doesn't matter whether
97 the operation is commutative or not. */
98 if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
102 } else { /* rank >= 2*rem */
103 newrank = rank - rem;
106 /* If op is user-defined or count is less than pof2, use
107 recursive doubling algorithm. Otherwise do a reduce-scatter
108 followed by allgather. (If op is user-defined,
109 derived datatypes are allowed and the user could pass basic
110 datatypes on one process and derived on another as long as
111 the type maps are the same. Breaking up derived
112 datatypes to do the reduce-scatter is tricky, therefore
113 using recursive doubling in that case.) */
116 if (/*(HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) ||*/ (count < pof2)) { /* use recursive doubling */
118 while (mask < pof2) {
119 newdst = newrank ^ mask;
120 /* find real rank of dest */
121 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
123 /* Send the most current data, which is in recvbuf. Recv
125 Request::sendrecv(recvbuf, count, datatype,
126 dst, COLL_TAG_ALLREDUCE,
127 tmp_buf, count, datatype, dst,
128 COLL_TAG_ALLREDUCE, comm,
131 /* tmp_buf contains data received in this step.
132 recvbuf contains data accumulated so far */
134 if (is_commutative || (dst < rank)) {
135 /* op is commutative OR the order is already right */
136 if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
138 /* op is noncommutative and the order is not right */
139 if(op!=MPI_OP_NULL) op->apply( recvbuf, tmp_buf, &count, datatype);
140 /* copy result back into recvbuf */
141 mpi_errno = Datatype::copy(tmp_buf, count, datatype,
142 recvbuf, count, datatype);
148 /* do a reduce-scatter followed by allgather */
150 /* for the reduce-scatter, calculate the count that
151 each process receives and the displacement within
153 int* cnts = new int[pof2];
154 int* disps = new int[pof2];
156 for (i = 0; i < (pof2 - 1); i++) {
157 cnts[i] = count / pof2;
159 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
162 for (i = 1; i < pof2; i++) {
163 disps[i] = disps[i - 1] + cnts[i - 1];
167 send_idx = recv_idx = 0;
169 while (mask < pof2) {
170 newdst = newrank ^ mask;
171 /* find real rank of dest */
172 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
174 send_cnt = recv_cnt = 0;
175 if (newrank < newdst) {
176 send_idx = recv_idx + pof2 / (mask * 2);
177 for (i = send_idx; i < last_idx; i++)
179 for (i = recv_idx; i < send_idx; i++)
182 recv_idx = send_idx + pof2 / (mask * 2);
183 for (i = send_idx; i < recv_idx; i++)
185 for (i = recv_idx; i < last_idx; i++)
189 /* Send data from recvbuf. Recv into tmp_buf */
190 Request::sendrecv(static_cast<char*>(recvbuf) + disps[send_idx] * extent, send_cnt, datatype, dst,
191 COLL_TAG_ALLREDUCE, tmp_buf + disps[recv_idx] * extent, recv_cnt, datatype, dst,
192 COLL_TAG_ALLREDUCE, comm, MPI_STATUS_IGNORE);
194 /* tmp_buf contains data received in this step.
195 recvbuf contains data accumulated so far */
197 /* This algorithm is used only for predefined ops
198 and predefined ops are always commutative. */
200 if (op != MPI_OP_NULL)
201 op->apply(tmp_buf + disps[recv_idx] * extent, static_cast<char*>(recvbuf) + disps[recv_idx] * extent,
202 &recv_cnt, datatype);
204 /* update send_idx for next iteration */
208 /* update last_idx, but not in last iteration
209 because the value is needed in the allgather
212 last_idx = recv_idx + pof2 / mask;
215 /* now do the allgather */
219 newdst = newrank ^ mask;
220 /* find real rank of dest */
221 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
223 send_cnt = recv_cnt = 0;
224 if (newrank < newdst) {
225 /* update last_idx except on first iteration */
226 if (mask != pof2 / 2) {
227 last_idx = last_idx + pof2 / (mask * 2);
230 recv_idx = send_idx + pof2 / (mask * 2);
231 for (i = send_idx; i < recv_idx; i++) {
234 for (i = recv_idx; i < last_idx; i++) {
238 recv_idx = send_idx - pof2 / (mask * 2);
239 for (i = send_idx; i < last_idx; i++) {
242 for (i = recv_idx; i < send_idx; i++) {
247 Request::sendrecv((char *) recvbuf +
248 disps[send_idx] * extent,
250 dst, COLL_TAG_ALLREDUCE,
252 disps[recv_idx] * extent,
253 recv_cnt, datatype, dst,
254 COLL_TAG_ALLREDUCE, comm,
256 if (newrank > newdst) {
267 /* In the non-power-of-two case, all odd-numbered
268 processes of rank < 2*rem send the result to
269 (rank-1), the ranks who didn't participate above. */
270 if (rank < 2 * rem) {
271 if (rank % 2) { /* odd */
272 Request::send(recvbuf, count,
274 COLL_TAG_ALLREDUCE, comm);
276 Request::recv(recvbuf, count,
278 COLL_TAG_ALLREDUCE, comm,
282 smpi_free_tmp_buffer(tmp_buf_free);