1 #include "colls_private.h"
8 int smpi_coll_tuned_reduce_scatter_gather(void *sendbuf, void *recvbuf,
9 int count, MPI_Datatype datatype,
10 MPI_Op op, int root, MPI_Comm comm)
13 int comm_size, rank, type_size, pof2, rem, newrank;
14 int mask, *cnts, *disps, i, j, send_idx = 0;
15 int recv_idx, last_idx = 0, newdst;
16 int dst, send_cnt, recv_cnt, newroot, newdst_tree_root;
17 int newroot_tree_root, new_count;
19 void *send_ptr, *recv_ptr, *tmp_buf;
28 rank = smpi_comm_rank(comm);
29 comm_size = smpi_comm_size(comm);
31 extent = smpi_datatype_get_extent(datatype);
32 type_size = smpi_datatype_size(datatype);
34 /* find nearest power-of-two less than or equal to comm_size */
36 while (pof2 <= comm_size)
40 if (count < comm_size) {
41 new_count = comm_size;
42 send_ptr = (void *) xbt_malloc(new_count * extent);
43 recv_ptr = (void *) xbt_malloc(new_count * extent);
44 tmp_buf = (void *) xbt_malloc(new_count * extent);
45 memcpy(send_ptr, sendbuf, extent * new_count);
48 smpi_mpi_sendrecv(send_ptr, new_count, datatype, rank, tag,
49 recv_ptr, new_count, datatype, rank, tag, comm, &status);
51 rem = comm_size - pof2;
55 smpi_mpi_send(recv_ptr, new_count, datatype, rank - 1, tag, comm);
58 smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
59 smpi_op_apply(op, tmp_buf, recv_ptr, &new_count, &datatype);
62 } else /* rank >= 2*rem */
65 cnts = (int *) xbt_malloc(pof2 * sizeof(int));
66 disps = (int *) xbt_malloc(pof2 * sizeof(int));
69 for (i = 0; i < (pof2 - 1); i++)
70 cnts[i] = new_count / pof2;
71 cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
74 for (i = 1; i < pof2; i++)
75 disps[i] = disps[i - 1] + cnts[i - 1];
78 send_idx = recv_idx = 0;
81 newdst = newrank ^ mask;
82 /* find real rank of dest */
83 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
85 send_cnt = recv_cnt = 0;
86 if (newrank < newdst) {
87 send_idx = recv_idx + pof2 / (mask * 2);
88 for (i = send_idx; i < last_idx; i++)
90 for (i = recv_idx; i < send_idx; i++)
93 recv_idx = send_idx + pof2 / (mask * 2);
94 for (i = send_idx; i < recv_idx; i++)
96 for (i = recv_idx; i < last_idx; i++)
100 /* Send data from recv_ptr. Recv into tmp_buf */
101 smpi_mpi_sendrecv((char *) recv_ptr +
102 disps[send_idx] * extent,
106 disps[recv_idx] * extent,
107 recv_cnt, datatype, dst, tag, comm, &status);
109 /* tmp_buf contains data received in this step.
110 recv_ptr contains data accumulated so far */
112 smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
113 (char *) recv_ptr + disps[recv_idx] * extent,
114 &recv_cnt, &datatype);
116 /* update send_idx for next iteration */
121 last_idx = recv_idx + pof2 / mask;
125 /* now do the gather to root */
127 if (root < 2 * rem) {
131 for (i = 0; i < (pof2 - 1); i++)
132 cnts[i] = new_count / pof2;
133 cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
136 for (i = 1; i < pof2; i++)
137 disps[i] = disps[i - 1] + cnts[i - 1];
139 smpi_mpi_recv(recv_ptr, cnts[0], datatype, 0, tag, comm, &status);
144 } else if (newrank == 0) {
145 smpi_mpi_send(recv_ptr, cnts[0], datatype, root, tag, comm);
152 newroot = root - rem;
157 while (mask < pof2) {
164 newdst = newrank ^ mask;
166 /* find real rank of dest */
167 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
169 if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
171 newdst_tree_root = newdst >> j;
172 newdst_tree_root <<= j;
174 newroot_tree_root = newroot >> j;
175 newroot_tree_root <<= j;
177 send_cnt = recv_cnt = 0;
178 if (newrank < newdst) {
179 /* update last_idx except on first iteration */
180 if (mask != pof2 / 2)
181 last_idx = last_idx + pof2 / (mask * 2);
183 recv_idx = send_idx + pof2 / (mask * 2);
184 for (i = send_idx; i < recv_idx; i++)
186 for (i = recv_idx; i < last_idx; i++)
189 recv_idx = send_idx - pof2 / (mask * 2);
190 for (i = send_idx; i < last_idx; i++)
192 for (i = recv_idx; i < send_idx; i++)
196 if (newdst_tree_root == newroot_tree_root) {
197 smpi_mpi_send((char *) recv_ptr +
198 disps[send_idx] * extent,
199 send_cnt, datatype, dst, tag, comm);
202 smpi_mpi_recv((char *) recv_ptr +
203 disps[recv_idx] * extent,
204 recv_cnt, datatype, dst, tag, comm, &status);
207 if (newrank > newdst)
214 memcpy(recvbuf, recv_ptr, extent * count);
220 else if (count >= comm_size) {
221 tmp_buf = (void *) xbt_malloc(count * extent);
223 //if ((rank != root))
224 smpi_mpi_sendrecv(sendbuf, count, datatype, rank, tag,
225 recvbuf, count, datatype, rank, tag, comm, &status);
227 rem = comm_size - pof2;
228 if (rank < 2 * rem) {
229 if (rank % 2 != 0) { /* odd */
230 smpi_mpi_send(recvbuf, count, datatype, rank - 1, tag, comm);
235 smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
236 smpi_op_apply(op, tmp_buf, recvbuf, &count, &datatype);
239 } else /* rank >= 2*rem */
240 newrank = rank - rem;
242 cnts = (int *) xbt_malloc(pof2 * sizeof(int));
243 disps = (int *) xbt_malloc(pof2 * sizeof(int));
246 for (i = 0; i < (pof2 - 1); i++)
247 cnts[i] = count / pof2;
248 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
251 for (i = 1; i < pof2; i++)
252 disps[i] = disps[i - 1] + cnts[i - 1];
255 send_idx = recv_idx = 0;
257 while (mask < pof2) {
258 newdst = newrank ^ mask;
259 /* find real rank of dest */
260 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
262 send_cnt = recv_cnt = 0;
263 if (newrank < newdst) {
264 send_idx = recv_idx + pof2 / (mask * 2);
265 for (i = send_idx; i < last_idx; i++)
267 for (i = recv_idx; i < send_idx; i++)
270 recv_idx = send_idx + pof2 / (mask * 2);
271 for (i = send_idx; i < recv_idx; i++)
273 for (i = recv_idx; i < last_idx; i++)
277 /* Send data from recvbuf. Recv into tmp_buf */
278 smpi_mpi_sendrecv((char *) recvbuf +
279 disps[send_idx] * extent,
283 disps[recv_idx] * extent,
284 recv_cnt, datatype, dst, tag, comm, &status);
286 /* tmp_buf contains data received in this step.
287 recvbuf contains data accumulated so far */
289 smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
290 (char *) recvbuf + disps[recv_idx] * extent,
291 &recv_cnt, &datatype);
293 /* update send_idx for next iteration */
298 last_idx = recv_idx + pof2 / mask;
302 /* now do the gather to root */
304 if (root < 2 * rem) {
306 if (rank == root) { /* recv */
307 for (i = 0; i < (pof2 - 1); i++)
308 cnts[i] = count / pof2;
309 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
312 for (i = 1; i < pof2; i++)
313 disps[i] = disps[i - 1] + cnts[i - 1];
315 smpi_mpi_recv(recvbuf, cnts[0], datatype, 0, tag, comm, &status);
320 } else if (newrank == 0) {
321 smpi_mpi_send(recvbuf, cnts[0], datatype, root, tag, comm);
328 newroot = root - rem;
333 while (mask < pof2) {
340 newdst = newrank ^ mask;
342 /* find real rank of dest */
343 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
345 if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
347 newdst_tree_root = newdst >> j;
348 newdst_tree_root <<= j;
350 newroot_tree_root = newroot >> j;
351 newroot_tree_root <<= j;
353 send_cnt = recv_cnt = 0;
354 if (newrank < newdst) {
355 /* update last_idx except on first iteration */
356 if (mask != pof2 / 2)
357 last_idx = last_idx + pof2 / (mask * 2);
359 recv_idx = send_idx + pof2 / (mask * 2);
360 for (i = send_idx; i < recv_idx; i++)
362 for (i = recv_idx; i < last_idx; i++)
365 recv_idx = send_idx - pof2 / (mask * 2);
366 for (i = send_idx; i < last_idx; i++)
368 for (i = recv_idx; i < send_idx; i++)
372 if (newdst_tree_root == newroot_tree_root) {
373 smpi_mpi_send((char *) recvbuf +
374 disps[send_idx] * extent,
375 send_cnt, datatype, dst, tag, comm);
378 smpi_mpi_recv((char *) recvbuf +
379 disps[recv_idx] * extent,
380 recv_cnt, datatype, dst, tag, comm, &status);
383 if (newrank > newdst)