1 /* Copyright (c) 2013-2014. The SimGrid Team.
2 * All rights reserved. */
4 /* This program is free software; you can redistribute it and/or modify it
5 * under the terms of the license (GNU LGPL) which comes with this package. */
7 #include "colls_private.h"
14 int smpi_coll_tuned_reduce_scatter_gather(void *sendbuf, void *recvbuf,
15 int count, MPI_Datatype datatype,
16 MPI_Op op, int root, MPI_Comm comm)
19 int comm_size, rank, pof2, rem, newrank;
20 int mask, *cnts, *disps, i, j, send_idx = 0;
21 int recv_idx, last_idx = 0, newdst;
22 int dst, send_cnt, recv_cnt, newroot, newdst_tree_root;
23 int newroot_tree_root, new_count;
24 int tag = COLL_TAG_REDUCE;
25 void *send_ptr, *recv_ptr, *tmp_buf;
34 rank = smpi_comm_rank(comm);
35 comm_size = smpi_comm_size(comm);
37 extent = smpi_datatype_get_extent(datatype);
39 /* find nearest power-of-two less than or equal to comm_size */
41 while (pof2 <= comm_size)
45 if (count < comm_size) {
46 new_count = comm_size;
47 send_ptr = (void *) xbt_malloc(new_count * extent);
48 recv_ptr = (void *) xbt_malloc(new_count * extent);
49 tmp_buf = (void *) xbt_malloc(new_count * extent);
50 memcpy(send_ptr, sendbuf, extent * count);
53 smpi_mpi_sendrecv(send_ptr, new_count, datatype, rank, tag,
54 recv_ptr, new_count, datatype, rank, tag, comm, &status);
56 rem = comm_size - pof2;
60 smpi_mpi_send(recv_ptr, new_count, datatype, rank - 1, tag, comm);
63 smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
64 smpi_op_apply(op, tmp_buf, recv_ptr, &new_count, &datatype);
67 } else /* rank >= 2*rem */
70 cnts = (int *) xbt_malloc(pof2 * sizeof(int));
71 disps = (int *) xbt_malloc(pof2 * sizeof(int));
74 for (i = 0; i < (pof2 - 1); i++)
75 cnts[i] = new_count / pof2;
76 cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
79 for (i = 1; i < pof2; i++)
80 disps[i] = disps[i - 1] + cnts[i - 1];
83 send_idx = recv_idx = 0;
86 newdst = newrank ^ mask;
87 /* find real rank of dest */
88 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
90 send_cnt = recv_cnt = 0;
91 if (newrank < newdst) {
92 send_idx = recv_idx + pof2 / (mask * 2);
93 for (i = send_idx; i < last_idx; i++)
95 for (i = recv_idx; i < send_idx; i++)
98 recv_idx = send_idx + pof2 / (mask * 2);
99 for (i = send_idx; i < recv_idx; i++)
101 for (i = recv_idx; i < last_idx; i++)
105 /* Send data from recvbuf. Recv into tmp_buf */
106 smpi_mpi_sendrecv((char *) recv_ptr +
107 disps[send_idx] * extent,
111 disps[recv_idx] * extent,
112 recv_cnt, datatype, dst, tag, comm, &status);
114 /* tmp_buf contains data received in this step.
115 recvbuf contains data accumulated so far */
117 smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
118 (char *) recv_ptr + disps[recv_idx] * extent,
119 &recv_cnt, &datatype);
121 /* update send_idx for next iteration */
126 last_idx = recv_idx + pof2 / mask;
130 /* now do the gather to root */
132 if (root < 2 * rem) {
136 for (i = 0; i < (pof2 - 1); i++)
137 cnts[i] = new_count / pof2;
138 cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
141 for (i = 1; i < pof2; i++)
142 disps[i] = disps[i - 1] + cnts[i - 1];
144 smpi_mpi_recv(recv_ptr, cnts[0], datatype, 0, tag, comm, &status);
149 } else if (newrank == 0) {
150 smpi_mpi_send(recv_ptr, cnts[0], datatype, root, tag, comm);
157 newroot = root - rem;
162 while (mask < pof2) {
169 newdst = newrank ^ mask;
171 /* find real rank of dest */
172 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
174 if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
176 newdst_tree_root = newdst >> j;
177 newdst_tree_root <<= j;
179 newroot_tree_root = newroot >> j;
180 newroot_tree_root <<= j;
182 send_cnt = recv_cnt = 0;
183 if (newrank < newdst) {
184 /* update last_idx except on first iteration */
185 if (mask != pof2 / 2)
186 last_idx = last_idx + pof2 / (mask * 2);
188 recv_idx = send_idx + pof2 / (mask * 2);
189 for (i = send_idx; i < recv_idx; i++)
191 for (i = recv_idx; i < last_idx; i++)
194 recv_idx = send_idx - pof2 / (mask * 2);
195 for (i = send_idx; i < last_idx; i++)
197 for (i = recv_idx; i < send_idx; i++)
201 if (newdst_tree_root == newroot_tree_root) {
202 smpi_mpi_send((char *) recv_ptr +
203 disps[send_idx] * extent,
204 send_cnt, datatype, dst, tag, comm);
207 smpi_mpi_recv((char *) recv_ptr +
208 disps[recv_idx] * extent,
209 recv_cnt, datatype, dst, tag, comm, &status);
212 if (newrank > newdst)
219 memcpy(recvbuf, recv_ptr, extent * count);
225 else /* (count >= comm_size) */ {
226 tmp_buf = (void *) xbt_malloc(count * extent);
228 //if ((rank != root))
229 smpi_mpi_sendrecv(sendbuf, count, datatype, rank, tag,
230 recvbuf, count, datatype, rank, tag, comm, &status);
232 rem = comm_size - pof2;
233 if (rank < 2 * rem) {
234 if (rank % 2 != 0) { /* odd */
235 smpi_mpi_send(recvbuf, count, datatype, rank - 1, tag, comm);
240 smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
241 smpi_op_apply(op, tmp_buf, recvbuf, &count, &datatype);
244 } else /* rank >= 2*rem */
245 newrank = rank - rem;
247 cnts = (int *) xbt_malloc(pof2 * sizeof(int));
248 disps = (int *) xbt_malloc(pof2 * sizeof(int));
251 for (i = 0; i < (pof2 - 1); i++)
252 cnts[i] = count / pof2;
253 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
256 for (i = 1; i < pof2; i++)
257 disps[i] = disps[i - 1] + cnts[i - 1];
260 send_idx = recv_idx = 0;
262 while (mask < pof2) {
263 newdst = newrank ^ mask;
264 /* find real rank of dest */
265 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
267 send_cnt = recv_cnt = 0;
268 if (newrank < newdst) {
269 send_idx = recv_idx + pof2 / (mask * 2);
270 for (i = send_idx; i < last_idx; i++)
272 for (i = recv_idx; i < send_idx; i++)
275 recv_idx = send_idx + pof2 / (mask * 2);
276 for (i = send_idx; i < recv_idx; i++)
278 for (i = recv_idx; i < last_idx; i++)
282 /* Send data from recvbuf. Recv into tmp_buf */
283 smpi_mpi_sendrecv((char *) recvbuf +
284 disps[send_idx] * extent,
288 disps[recv_idx] * extent,
289 recv_cnt, datatype, dst, tag, comm, &status);
291 /* tmp_buf contains data received in this step.
292 recvbuf contains data accumulated so far */
294 smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
295 (char *) recvbuf + disps[recv_idx] * extent,
296 &recv_cnt, &datatype);
298 /* update send_idx for next iteration */
303 last_idx = recv_idx + pof2 / mask;
307 /* now do the gather to root */
309 if (root < 2 * rem) {
311 if (rank == root) { /* recv */
312 for (i = 0; i < (pof2 - 1); i++)
313 cnts[i] = count / pof2;
314 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
317 for (i = 1; i < pof2; i++)
318 disps[i] = disps[i - 1] + cnts[i - 1];
320 smpi_mpi_recv(recvbuf, cnts[0], datatype, 0, tag, comm, &status);
325 } else if (newrank == 0) {
326 smpi_mpi_send(recvbuf, cnts[0], datatype, root, tag, comm);
333 newroot = root - rem;
338 while (mask < pof2) {
345 newdst = newrank ^ mask;
347 /* find real rank of dest */
348 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
350 if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
352 newdst_tree_root = newdst >> j;
353 newdst_tree_root <<= j;
355 newroot_tree_root = newroot >> j;
356 newroot_tree_root <<= j;
358 send_cnt = recv_cnt = 0;
359 if (newrank < newdst) {
360 /* update last_idx except on first iteration */
361 if (mask != pof2 / 2)
362 last_idx = last_idx + pof2 / (mask * 2);
364 recv_idx = send_idx + pof2 / (mask * 2);
365 for (i = send_idx; i < recv_idx; i++)
367 for (i = recv_idx; i < last_idx; i++)
370 recv_idx = send_idx - pof2 / (mask * 2);
371 for (i = send_idx; i < last_idx; i++)
373 for (i = recv_idx; i < send_idx; i++)
377 if (newdst_tree_root == newroot_tree_root) {
378 smpi_mpi_send((char *) recvbuf +
379 disps[send_idx] * extent,
380 send_cnt, datatype, dst, tag, comm);
383 smpi_mpi_recv((char *) recvbuf +
384 disps[recv_idx] * extent,
385 recv_cnt, datatype, dst, tag, comm, &status);
388 if (newrank > newdst)