1 #include "colls_private.h"
/* Pipeline segment size, in bytes, for smpi_coll_tuned_bcast_SMP_binary.
 * The broadcast divides this by the datatype extent to get the number of
 * elements per segment; messages of at most one segment are sent whole,
 * larger ones are pipelined segment by segment. Default: 8 KiB. */
int bcast_SMP_binary_segment_byte = 8 * 1024;
8 int smpi_coll_tuned_bcast_SMP_binary(void *buf, int count,
9 MPI_Datatype datatype, int root,
12 int tag = COLL_TAG_BCAST;
15 MPI_Request *request_array;
16 MPI_Status *status_array;
20 extent = smpi_datatype_get_extent(datatype);
22 rank = smpi_comm_rank(comm);
23 size = smpi_comm_size(comm);
25 int segment = bcast_SMP_binary_segment_byte / extent;
26 int pipe_length = count / segment;
27 int remainder = count % segment;
29 int to_intra_left = (rank / NUM_CORE) * NUM_CORE + (rank % NUM_CORE) * 2 + 1;
30 int to_intra_right = (rank / NUM_CORE) * NUM_CORE + (rank % NUM_CORE) * 2 + 2;
31 int to_inter_left = ((rank / NUM_CORE) * 2 + 1) * NUM_CORE;
32 int to_inter_right = ((rank / NUM_CORE) * 2 + 2) * NUM_CORE;
33 int from_inter = (((rank / NUM_CORE) - 1) / 2) * NUM_CORE;
34 int from_intra = (rank / NUM_CORE) * NUM_CORE + ((rank % NUM_CORE) - 1) / 2;
35 int increment = segment * extent;
37 int base = (rank / NUM_CORE) * NUM_CORE;
38 int num_core = NUM_CORE;
39 if (((rank / NUM_CORE) * NUM_CORE) == ((size / NUM_CORE) * NUM_CORE))
40 num_core = size - (rank / NUM_CORE) * NUM_CORE;
42 // if root is not zero send to rank zero first
45 smpi_mpi_send(buf, count, datatype, 0, tag, comm);
47 smpi_mpi_recv(buf, count, datatype, root, tag, comm, &status);
49 // when a message is smaller than a block size => no pipeline
50 if (count <= segment) {
51 // case ROOT-of-each-SMP
52 if (rank % NUM_CORE == 0) {
55 //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
56 if (to_inter_left < size)
57 smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
58 if (to_inter_right < size)
59 smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
60 if ((to_intra_left - base) < num_core)
61 smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
62 if ((to_intra_right - base) < num_core)
63 smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
65 // case LEAVES ROOT-of-eash-SMP
66 else if (to_inter_left >= size) {
67 //printf("node %d from %d\n",rank,from_inter);
68 request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
69 smpi_mpi_wait(&request, &status);
70 if ((to_intra_left - base) < num_core)
71 smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
72 if ((to_intra_right - base) < num_core)
73 smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
75 // case INTERMEDIAT ROOT-of-each-SMP
77 //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
78 request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
79 smpi_mpi_wait(&request, &status);
80 smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
81 if (to_inter_right < size)
82 smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
83 if ((to_intra_left - base) < num_core)
84 smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
85 if ((to_intra_right - base) < num_core)
86 smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
89 // case non ROOT-of-each-SMP
92 if ((to_intra_left - base) >= num_core) {
93 request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
94 smpi_mpi_wait(&request, &status);
98 request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
99 smpi_mpi_wait(&request, &status);
100 smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
101 if ((to_intra_right - base) < num_core)
102 smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
112 (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
114 (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
116 // case ROOT-of-each-SMP
117 if (rank % NUM_CORE == 0) {
120 for (i = 0; i < pipe_length; i++) {
121 //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
122 if (to_inter_left < size)
123 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
124 to_inter_left, (tag + i), comm);
125 if (to_inter_right < size)
126 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
127 to_inter_right, (tag + i), comm);
128 if ((to_intra_left - base) < num_core)
129 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
130 to_intra_left, (tag + i), comm);
131 if ((to_intra_right - base) < num_core)
132 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
133 to_intra_right, (tag + i), comm);
136 // case LEAVES ROOT-of-eash-SMP
137 else if (to_inter_left >= size) {
138 //printf("node %d from %d\n",rank,from_inter);
139 for (i = 0; i < pipe_length; i++) {
140 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
141 from_inter, (tag + i), comm);
143 for (i = 0; i < pipe_length; i++) {
144 smpi_mpi_wait(&request_array[i], &status);
145 if ((to_intra_left - base) < num_core)
146 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
147 to_intra_left, (tag + i), comm);
148 if ((to_intra_right - base) < num_core)
149 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
150 to_intra_right, (tag + i), comm);
153 // case INTERMEDIAT ROOT-of-each-SMP
155 //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
156 for (i = 0; i < pipe_length; i++) {
157 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
158 from_inter, (tag + i), comm);
160 for (i = 0; i < pipe_length; i++) {
161 smpi_mpi_wait(&request_array[i], &status);
162 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
163 to_inter_left, (tag + i), comm);
164 if (to_inter_right < size)
165 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
166 to_inter_right, (tag + i), comm);
167 if ((to_intra_left - base) < num_core)
168 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
169 to_intra_left, (tag + i), comm);
170 if ((to_intra_right - base) < num_core)
171 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
172 to_intra_right, (tag + i), comm);
176 // case non-ROOT-of-each-SMP
179 if ((to_intra_left - base) >= num_core) {
180 for (i = 0; i < pipe_length; i++) {
181 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
182 from_intra, (tag + i), comm);
184 smpi_mpi_waitall((pipe_length), request_array, status_array);
188 for (i = 0; i < pipe_length; i++) {
189 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
190 from_intra, (tag + i), comm);
192 for (i = 0; i < pipe_length; i++) {
193 smpi_mpi_wait(&request_array[i], &status);
194 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
195 to_intra_left, (tag + i), comm);
196 if ((to_intra_right - base) < num_core)
197 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
198 to_intra_right, (tag + i), comm);
207 // when count is not divisible by block size, use default BCAST for the remainder
208 if ((remainder != 0) && (count > segment)) {
209 XBT_WARN("MPI_bcast_SMP_binary use default MPI_bcast.");
210 smpi_mpi_bcast((char *) buf + (pipe_length * increment), remainder, datatype,