1 /* selector for collective algorithms based on mvapich decision logic */
3 /* Copyright (c) 2009-2023. The SimGrid Team.
4 * All rights reserved. */
6 /* This program is free software; you can redistribute it and/or modify it
7 * under the terms of the license (GNU LGPL) which comes with this package. */
9 #include "colls_private.hpp"
11 #include "smpi_mvapich2_selector_stampede.hpp"
13 namespace simgrid::smpi {
/* Selector for MPI_Alltoall: walks the MVAPICH2 "Stampede" tuning tables to pick
 * the algorithm bracket matching this communicator size and per-peer message size.
 * NOTE(review): this copy of the file looks truncated — the end of the parameter
 * list (MPI_Comm comm), loop increments, `else` glue and closing braces appear to
 * be missing; verify against the upstream SimGrid file before building. */
int alltoall__mvapich2( const void *sendbuf, int sendcount,
MPI_Datatype sendtype,
void* recvbuf, int recvcount,
MPI_Datatype recvtype,
/* Lazily build the tuning tables on first use. */
if (mv2_alltoall_table_ppn_conf == nullptr)
init_mv2_alltoall_tables_stampede();
int sendtype_size, recvtype_size, comm_size;
int mpi_errno=MPI_SUCCESS;
int range_threshold = 0;
comm_size = comm->size();
sendtype_size=sendtype->size();
recvtype_size=recvtype->size();
/* Bytes sent to each peer; used to pick the message-size bracket below. */
long nbytes = sendtype_size * sendcount;
/* check if safe to use partial subscription mode */
/* Search for the corresponding system size inside the tuning table */
while ((range < (mv2_size_alltoall_tuning_table[conf_index] - 1)) &&
(comm_size > mv2_alltoall_thresholds_table[conf_index][range].numproc)) {
/* Search for corresponding inter-leader function */
while ((range_threshold < (mv2_alltoall_thresholds_table[conf_index][range].size_table - 1))
mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max)
&& (mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max != -1)) {
/* Algorithm selected from the table for the out-of-place case. */
MV2_Alltoall_function = mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold]
.MV2_pt_Alltoall_function;
if(sendbuf != MPI_IN_PLACE) {
mpi_errno = MV2_Alltoall_function(sendbuf, sendcount, sendtype,
recvbuf, recvcount, recvtype,
/* In-place path: when nbytes falls OUTSIDE the in-place bracket [min, max],
 * copy recvbuf into a temporary send buffer and reuse the selected
 * out-of-place algorithm ... */
mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].min
||nbytes > mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].max
unsigned char* tmp_buf = smpi_get_tmp_sendbuffer(comm_size * recvcount * recvtype_size);
Datatype::copy(recvbuf, comm_size * recvcount, recvtype, tmp_buf, comm_size * recvcount, recvtype);
mpi_errno = MV2_Alltoall_function(tmp_buf, recvcount, recvtype, recvbuf, recvcount, recvtype, comm);
smpi_free_tmp_buffer(tmp_buf);
/* ... otherwise use the dedicated in-place algorithm. */
mpi_errno = MPIR_Alltoall_inplace_MV2(sendbuf, sendcount, sendtype,
recvbuf, recvcount, recvtype,
/* Selector for MPI_Allgather: first tries "partial subscription" mode (a tuning
 * table specialized for the number of processes per node), then picks the
 * inter-leader algorithm whose size bracket contains nbytes.
 * NOTE(review): this copy looks truncated — the `do { ... } while` header over the
 * ppn configurations, loop increments, `else` branches and closing braces appear
 * to be missing; verify against upstream before building. */
int allgather__mvapich2(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
void *recvbuf, int recvcount, MPI_Datatype recvtype,
int mpi_errno = MPI_SUCCESS;
long nbytes = 0, comm_size, recvtype_size;
bool partial_sub_ok = false;
int range_threshold = 0;
//MPI_Comm *shmem_commptr=NULL;
/* Get the size of the communicator */
comm_size = comm->size();
recvtype_size=recvtype->size();
/* Bytes received from each peer; drives the bracket search below. */
nbytes = recvtype_size * recvcount;
/* Build the tuning tables on first use. */
if (mv2_allgather_table_ppn_conf == nullptr)
init_mv2_allgather_tables_stampede();
if(comm->get_leaders_comm()==MPI_COMM_NULL){
/* Partial subscription is only attempted when every node hosts the same
 * number of processes. */
if (comm->is_uniform()){
shmem_comm = comm->get_intra_comm();
int local_size = shmem_comm->size();
if (mv2_allgather_table_ppn_conf[0] == -1) {
// Indicating user defined tuning
/* Match the local (per-node) size against the known ppn configurations. */
if (local_size == mv2_allgather_table_ppn_conf[i]) {
partial_sub_ok = true;
} while(i < mv2_allgather_num_ppn_conf);
if (not partial_sub_ok) {
/* Search for the corresponding system size inside the tuning table */
while ((range < (mv2_size_allgather_tuning_table[conf_index] - 1)) &&
mv2_allgather_thresholds_table[conf_index][range].numproc)) {
/* Search for corresponding inter-leader function */
while ((range_threshold <
(mv2_allgather_thresholds_table[conf_index][range].size_inter_table - 1))
&& (nbytes > mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
&& (mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max !=
/* Set inter-leader pt */
mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].
MV2_pt_Allgatherction;
bool is_two_level = mv2_allgather_thresholds_table[conf_index][range].two_level[range_threshold];
/* intracommunicator */
/* Two-level (SMP-aware) path is only usable with partial subscription and a
 * blocked process layout; otherwise fall back to MPICH / RD variants. */
if (partial_sub_ok) {
if (comm->is_blocked()){
mpi_errno = MPIR_2lvl_Allgather_MV2(sendbuf, sendcount, sendtype,
recvbuf, recvcount, recvtype,
mpi_errno = allgather__mpich(sendbuf, sendcount, sendtype,
recvbuf, recvcount, recvtype,
mpi_errno = MPIR_Allgather_RD_MV2(sendbuf, sendcount, sendtype,
recvbuf, recvcount, recvtype,
} else if(MV2_Allgatherction == &MPIR_Allgather_Bruck_MV2
|| MV2_Allgatherction == &MPIR_Allgather_RD_MV2
|| MV2_Allgatherction == &MPIR_Allgather_Ring_MV2) {
mpi_errno = MV2_Allgatherction(sendbuf, sendcount, sendtype,
recvbuf, recvcount, recvtype,
/* No applicable algorithm: report failure to the caller. */
return MPI_ERR_OTHER;
/* Selector for MPI_Gather: chooses an intra-node and an inter-leader gather
 * implementation from the tuning table; uses the two-level (SMP-aware) pair when
 * the process layout is blocked, otherwise falls back to the MPICH gather.
 * NOTE(review): truncated copy — some parameter lines (sendcnt/recvcnt/buffers),
 * nbytes branch on rank vs root, loop increments and closing braces appear to be
 * missing; verify against upstream before building. */
int gather__mvapich2(const void *sendbuf,
MPI_Datatype sendtype,
MPI_Datatype recvtype,
int root, MPI_Comm comm)
/* Build the tuning table on first use. */
if (mv2_gather_thresholds_table == nullptr)
init_mv2_gather_tables_stampede();
int mpi_errno = MPI_SUCCESS;
int range_threshold = 0;
int range_intra_threshold = 0;
int comm_size = comm->size();
int rank = comm->rank();
/* nbytes is computed from the receive side (root) or the send side (others);
 * the rank==root test is presumably on a missing line — TODO confirm. */
int recvtype_size = recvtype->size();
nbytes = recvcnt * recvtype_size;
int sendtype_size = sendtype->size();
nbytes = sendcnt * sendtype_size;
/* Search for the corresponding system size inside the tuning table */
while ((range < (mv2_size_gather_tuning_table - 1)) &&
(comm_size > mv2_gather_thresholds_table[range].numproc)) {
/* Search for corresponding inter-leader function */
while ((range_threshold < (mv2_gather_thresholds_table[range].size_inter_table - 1))
mv2_gather_thresholds_table[range].inter_leader[range_threshold].max)
&& (mv2_gather_thresholds_table[range].inter_leader[range_threshold].max !=
/* Search for corresponding intra node function */
while ((range_intra_threshold < (mv2_gather_thresholds_table[range].size_intra_table - 1))
mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max)
&& (mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max !=
range_intra_threshold++;
if (comm->is_blocked() ) {
// Set intra-node function pt for gather_two_level
MV2_Gather_intra_node_function =
mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].
MV2_pt_Gather_function;
//Set inter-leader pt
MV2_Gather_inter_leader_function =
mv2_gather_thresholds_table[range].inter_leader[range_threshold].
MV2_pt_Gather_function;
// We call Gather function
MV2_Gather_inter_leader_function(sendbuf, sendcnt, sendtype, recvbuf, recvcnt,
recvtype, root, comm);
// Indeed, direct (non SMP-aware)gather is MPICH one
mpi_errno = gather__mpich(sendbuf, sendcnt, sendtype,
recvbuf, recvcnt, recvtype,
/* Selector for MPI_Allgatherv: totals the receive counts to size the message,
 * picks the inter-leader algorithm from the tuning table, and special-cases
 * recursive doubling (only valid for power-of-two communicator sizes, with Bruck
 * as the non-power-of-two fallback).
 * NOTE(review): truncated copy — `total_count` is read without a visible
 * initialization (upstream presumably zeroes it on a missing line), and loop
 * increments / `else` glue / closing braces appear to be missing; verify against
 * upstream before building. */
int allgatherv__mvapich2(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
void *recvbuf, const int *recvcounts, const int *displs,
MPI_Datatype recvtype, MPI_Comm comm )
int mpi_errno = MPI_SUCCESS;
int range = 0, comm_size, total_count, recvtype_size, i;
int range_threshold = 0;
/* Build the tuning table on first use. */
if (mv2_allgatherv_thresholds_table == nullptr)
init_mv2_allgatherv_tables_stampede();
comm_size = comm->size();
/* Sum of all per-rank receive counts -> total message footprint. */
for (i = 0; i < comm_size; i++)
total_count += recvcounts[i];
recvtype_size=recvtype->size();
nbytes = total_count * recvtype_size;
/* Search for the corresponding system size inside the tuning table */
while ((range < (mv2_size_allgatherv_tuning_table - 1)) &&
(comm_size > mv2_allgatherv_thresholds_table[range].numproc)) {
/* Search for corresponding inter-leader function */
/* Note: the table thresholds are per-process, hence the comm_size factor. */
while ((range_threshold < (mv2_allgatherv_thresholds_table[range].size_inter_table - 1))
comm_size * mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max)
&& (mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max !=
/* Set inter-leader pt */
MV2_Allgatherv_function =
mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].
MV2_pt_Allgatherv_function;
if (MV2_Allgatherv_function == &MPIR_Allgatherv_Rec_Doubling_MV2)
/* Recursive doubling requires a power-of-two communicator size. */
if (not(comm_size & (comm_size - 1))) {
MPIR_Allgatherv_Rec_Doubling_MV2(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
MPIR_Allgatherv_Bruck_MV2(sendbuf, sendcount,
MV2_Allgatherv_function(sendbuf, sendcount, sendtype,
recvbuf, recvcounts, displs,
/* Selector for MPI_Allreduce: picks inter-leader and intra-node reduce functions
 * from the tuning table, skipping multicast-based entries when mcast is not
 * available, and demotes to point-to-point recursive doubling when the operation
 * is non-commutative or the two-level path is not applicable.
 * NOTE(review): truncated copy — recvbuf/count parameters, loop increments,
 * `else` glue and closing braces appear to be missing; verify against upstream
 * before building. */
int allreduce__mvapich2(const void *sendbuf,
MPI_Datatype datatype,
MPI_Op op, MPI_Comm comm)
int mpi_errno = MPI_SUCCESS;
comm_size = comm->size();
//rank = comm->rank();
/* Build the tuning table on first use. */
if (mv2_allreduce_thresholds_table == nullptr)
init_mv2_allreduce_tables_stampede();
/* check if multiple threads are calling this collective function */
MPI_Aint sendtype_size = 0;
MPI_Aint true_lb, true_extent;
sendtype_size=datatype->size();
nbytes = count * sendtype_size;
datatype->extent(&true_lb, &true_extent);
/* Non-commutative ops cannot use the two-level / shared-memory path below. */
bool is_commutative = op->is_commutative();
int range = 0, range_threshold = 0, range_threshold_intra = 0;
bool is_two_level = false;
/* Search for the corresponding system size inside the tuning table */
while ((range < (mv2_size_allreduce_tuning_table - 1)) &&
(comm_size > mv2_allreduce_thresholds_table[range].numproc)) {
/* Search for corresponding inter-leader function */
/* skip mcast pointers if mcast is not available */
if (not mv2_allreduce_thresholds_table[range].mcast_enabled) {
while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
&& ((mv2_allreduce_thresholds_table[range].
inter_leader[range_threshold].MV2_pt_Allreducection
== &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2) ||
(mv2_allreduce_thresholds_table[range].
inter_leader[range_threshold].MV2_pt_Allreducection
== &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)
/* Now pick the bracket whose [.. max] contains nbytes. */
while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max)
&& (mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max != -1)) {
if (mv2_allreduce_thresholds_table[range].is_two_level_allreduce[range_threshold]) {
/* Search for corresponding intra-node function */
while ((range_threshold_intra <
(mv2_allreduce_thresholds_table[range].size_intra_table - 1))
mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max)
&& (mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max !=
range_threshold_intra++;
MV2_Allreducection = mv2_allreduce_thresholds_table[range].inter_leader[range_threshold]
.MV2_pt_Allreducection;
MV2_Allreduce_intra_function = mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra]
.MV2_pt_Allreducection;
/* check if mcast is ready, otherwise replace mcast with other algorithm */
if((MV2_Allreducection == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2)||
(MV2_Allreducection == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)){
MV2_Allreducection = &MPIR_Allreduce_pt2pt_rd_MV2;
if (not is_two_level) {
MV2_Allreducection = &MPIR_Allreduce_pt2pt_rd_MV2;
// check if shm is ready, if not use other algorithm first
if (is_commutative) {
if(comm->get_leaders_comm()==MPI_COMM_NULL){
mpi_errno = MPIR_Allreduce_two_level_MV2(sendbuf, recvbuf, count,
mpi_errno = MPIR_Allreduce_pt2pt_rd_MV2(sendbuf, recvbuf, count,
mpi_errno = MV2_Allreducection(sendbuf, recvbuf, count,
//comm->ch.intra_node_done=0;
433 int alltoallv__mvapich2(const void *sbuf, const int *scounts, const int *sdisps,
435 void *rbuf, const int *rcounts, const int *rdisps,
441 if (sbuf == MPI_IN_PLACE) {
442 return alltoallv__ompi_basic_linear(sbuf, scounts, sdisps, sdtype,
443 rbuf, rcounts, rdisps, rdtype,
445 } else /* For starters, just keep the original algorithm. */
446 return alltoallv__ring(sbuf, scounts, sdisps, sdtype,
447 rbuf, rcounts, rdisps, rdtype,
452 int barrier__mvapich2(MPI_Comm comm)
454 return barrier__mvapich2_pair(comm);
/* Selector for MPI_Bcast: picks inter-leader and intra-node bcast functions and
 * knomial/pipeline tuning knobs from the tuning table, then runs either a
 * two-level (inter-node then intra-node) broadcast or a flat one.
 * NOTE(review): truncated copy — count parameter, loop increments, several
 * `else`/brace lines and the trailing error handling appear to be missing;
 * verify against upstream before building.  The large commented-out regions
 * (is_contig/is_homogeneous pack/unpack handling) are inherited from the
 * original MVAPICH2 code and intentionally disabled in SimGrid. */
int bcast__mvapich2(void *buffer,
MPI_Datatype datatype,
int root, MPI_Comm comm)
int mpi_errno = MPI_SUCCESS;
int comm_size/*, rank*/;
bool two_level_bcast = true;
int range_threshold = 0;
int range_threshold_intra = 0;
// unsigned char *tmp_buf = NULL;
//MPID_Datatype *dtp;
if(comm->get_leaders_comm()==MPI_COMM_NULL){
/* Build the tuning table on first use. */
if (not mv2_bcast_thresholds_table)
init_mv2_bcast_tables_stampede();
comm_size = comm->size();
//rank = comm->rank();
// bool is_contig = true;
/* if (HANDLE_GET_KIND(datatype) == HANDLE_KIND_BUILTIN)*/
/* is_contig = true;*/
/* MPID_Datatype_get_ptr(datatype, dtp);*/
/* is_contig = dtp->is_contig;*/
// bool is_homogeneous = true;
/* MPI_Type_size() might not give the accurate size of the packed
* datatype for heterogeneous systems (because of padding, encoding,
* etc). On the other hand, MPI_Pack_size() can become very
* expensive, depending on the implementation, especially for
* heterogeneous systems. We want to use MPI_Type_size() wherever
* possible, and MPI_Pack_size() in other places.
//if (is_homogeneous) {
type_size=datatype->size();
MPIR_Pack_size_impl(1, datatype, &type_size);
/* Total broadcast payload in bytes; drives the bracket search below. */
nbytes = (count) * (type_size);
/* Search for the corresponding system size inside the tuning table */
while ((range < (mv2_size_bcast_tuning_table - 1)) &&
(comm_size > mv2_bcast_thresholds_table[range].numproc)) {
/* Search for corresponding inter-leader function */
while ((range_threshold < (mv2_bcast_thresholds_table[range].size_inter_table - 1))
mv2_bcast_thresholds_table[range].inter_leader[range_threshold].max)
&& (mv2_bcast_thresholds_table[range].inter_leader[range_threshold].max != -1)) {
/* Search for corresponding intra-node function */
while ((range_threshold_intra <
(mv2_bcast_thresholds_table[range].size_intra_table - 1))
mv2_bcast_thresholds_table[range].intra_node[range_threshold_intra].max)
&& (mv2_bcast_thresholds_table[range].intra_node[range_threshold_intra].max !=
range_threshold_intra++;
mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
MV2_pt_Bcast_function;
MV2_Bcast_intra_node_function =
mv2_bcast_thresholds_table[range].
intra_node[range_threshold_intra].MV2_pt_Bcast_function;
/* if (mv2_user_bcast_intra == NULL && */
/* MV2_Bcast_intra_node_function == &MPIR_Knomial_Bcast_intra_node_MV2) {*/
/* MV2_Bcast_intra_node_function = &MPIR_Shmem_Bcast_MV2;*/
/* Zero-copy pipelined knomial factor: table entry wins, then the global
 * override if set. */
if (mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
zcpy_pipelined_knomial_factor != -1) {
zcpy_knomial_factor =
mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
zcpy_pipelined_knomial_factor;
if (mv2_pipelined_zcpy_knomial_factor != -1) {
zcpy_knomial_factor = mv2_pipelined_zcpy_knomial_factor;
if (MV2_Bcast_intra_node_function == nullptr) {
/* if tuning table do not have any intra selection, set func pointer to
** default one for mcast intra node */
MV2_Bcast_intra_node_function = &MPIR_Shmem_Bcast_MV2;
/* Set value of pipeline segment size */
bcast_segment_size = mv2_bcast_thresholds_table[range].bcast_segment_size;
/* Set value of inter node knomial factor */
mv2_inter_node_knomial_factor = mv2_bcast_thresholds_table[range].inter_node_knomial_factor;
/* Set value of intra node knomial factor */
mv2_intra_node_knomial_factor = mv2_bcast_thresholds_table[range].intra_node_knomial_factor;
/* Check if we will use a two level algorithm or not */
#if defined(_MCST_SUPPORT_)
mv2_bcast_thresholds_table[range].is_two_level_bcast[range_threshold]
|| comm->ch.is_mcast_ok;
mv2_bcast_thresholds_table[range].is_two_level_bcast[range_threshold];
if (two_level_bcast) {
// if (not is_contig || not is_homogeneous) {
// tmp_buf = smpi_get_tmp_sendbuffer(nbytes);
/* if (rank == root) {*/
/* MPIR_Pack_impl(buffer, count, datatype, tmp_buf, nbytes, &position);*/
/* MPIU_ERR_POP(mpi_errno);*/
#ifdef CHANNEL_MRAIL_GEN2
/* Zero-copy pipelined broadcast handles both phases itself. */
if ((mv2_enable_zcpy_bcast == 1) &&
(&MPIR_Pipelined_Bcast_Zcpy_MV2 == MV2_Bcast_function)) {
// if (not is_contig || not is_homogeneous) {
// mpi_errno = MPIR_Pipelined_Bcast_Zcpy_MV2(tmp_buf, nbytes, MPI_BYTE, root, comm);
mpi_errno = MPIR_Pipelined_Bcast_Zcpy_MV2(buffer, count, datatype,
#endif /* defined(CHANNEL_MRAIL_GEN2) */
/* Phase 1: broadcast between node leaders. */
shmem_comm = comm->get_intra_comm();
// if (not is_contig || not is_homogeneous) {
// MPIR_Bcast_tune_inter_node_helper_MV2(tmp_buf, nbytes, MPI_BYTE, root, comm);
MPIR_Bcast_tune_inter_node_helper_MV2(buffer, count, datatype, root, comm);
/* We are now done with the inter-node phase */
/* Phase 2: broadcast inside each node, rooted at the node leader. */
root = INTRA_NODE_ROOT;
// if (not is_contig || not is_homogeneous) {
// mpi_errno = MV2_Bcast_intra_node_function(tmp_buf, nbytes, MPI_BYTE, root, shmem_comm);
mpi_errno = MV2_Bcast_intra_node_function(buffer, count,
datatype, root, shmem_comm);
/* if (not is_contig || not is_homogeneous) {*/
/* if (rank != root) {*/
/* mpi_errno = MPIR_Unpack_impl(tmp_buf, nbytes, &position, buffer,*/
/* count, datatype);*/
/* We use Knomial for intra node */
MV2_Bcast_intra_node_function = &MPIR_Knomial_Bcast_intra_node_MV2;
/* if (mv2_enable_shmem_bcast == 0) {*/
/* Fall back to non-tuned version */
/* MPIR_Bcast_intra_MV2(buffer, count, datatype, root, comm);*/
/* Flat (single-level) broadcast. */
mpi_errno = MV2_Bcast_function(buffer, count, datatype, root,
/* Selector for MPI_Reduce: picks intra-node and inter-leader reduce functions
 * plus knomial degrees from the tuning table; commutative ops may use the
 * two-level helper, non-commutative ops and unsupported cases fall back to the
 * binomial algorithm.
 * NOTE(review): truncated copy — recvbuf/count parameters, loop increments and
 * braces appear to be missing; the pof2 loop overshoots comm_size, so upstream
 * presumably halves it afterwards (`pof2 >>= 1`) on a missing line — confirm. */
int reduce__mvapich2(const void *sendbuf,
MPI_Datatype datatype,
MPI_Op op, int root, MPI_Comm comm)
/* Build the tuning table on first use. */
if (mv2_reduce_thresholds_table == nullptr)
init_mv2_reduce_tables_stampede();
int mpi_errno = MPI_SUCCESS;
int range_threshold = 0;
int range_intra_threshold = 0;
bool is_two_level = false;
comm_size = comm->size();
sendtype_size=datatype->size();
nbytes = count * sendtype_size;
/* MPI_OP_NULL is treated as commutative (no reduction reordering hazard). */
bool is_commutative = (op == MPI_OP_NULL || op->is_commutative());
/* find nearest power-of-two less than or equal to comm_size */
for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
/* Search for the corresponding system size inside the tuning table */
while ((range < (mv2_size_reduce_tuning_table - 1)) &&
(comm_size > mv2_reduce_thresholds_table[range].numproc)) {
/* Search for corresponding inter-leader function */
while ((range_threshold < (mv2_reduce_thresholds_table[range].size_inter_table - 1))
mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max)
&& (mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max !=
/* Search for corresponding intra node function */
while ((range_intra_threshold < (mv2_reduce_thresholds_table[range].size_intra_table - 1))
mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max)
&& (mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max !=
range_intra_threshold++;
/* Set intra-node function pt for reduce_two_level */
MV2_Reduce_intra_function =
mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].
MV2_pt_Reduce_function;
/* Set inter-leader pt */
MV2_Reduce_function =
mv2_reduce_thresholds_table[range].inter_leader[range_threshold].
MV2_pt_Reduce_function;
/* Knomial degrees: a negative global means "use the table's value". */
if(mv2_reduce_intra_knomial_factor<0)
mv2_reduce_intra_knomial_factor = mv2_reduce_thresholds_table[range].intra_k_degree;
if(mv2_reduce_inter_knomial_factor<0)
mv2_reduce_inter_knomial_factor = mv2_reduce_thresholds_table[range].inter_k_degree;
if (mv2_reduce_thresholds_table[range].is_two_level_reduce[range_threshold]) {
/* We call Reduce function */
if (is_commutative) {
if(comm->get_leaders_comm()==MPI_COMM_NULL){
mpi_errno = MPIR_Reduce_two_level_helper_MV2(sendbuf, recvbuf, count,
datatype, op, root, comm);
mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
datatype, op, root, comm);
/* Knomial wrapper requires a commutative op; else binomial fallback. */
} else if(MV2_Reduce_function == &MPIR_Reduce_inter_knomial_wrapper_MV2 ){
mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count,
datatype, op, root, comm);
mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
datatype, op, root, comm);
/* Reduce-scatter + gather needs count >= pof2; else binomial fallback. */
} else if(MV2_Reduce_function == &MPIR_Reduce_redscat_gather_MV2){
if (/*(HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) &&*/ (count >= pof2))
mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count,
datatype, op, root, comm);
mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
datatype, op, root, comm);
mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count,
datatype, op, root, comm);
/* Selector for MPI_Reduce_scatter: commutative ops go through the tuned table;
 * non-commutative ops use the pairwise/non-comm algorithm when comm_size is a
 * power of two and all blocks are equal, otherwise the MPICH recursive-doubling
 * fallback.
 * NOTE(review): truncated copy — the comm parameter, several closing braces and,
 * presumably, the trailing `delete[] disps;` are missing (as shown, `disps`
 * would leak) — verify against upstream before building. */
int reduce_scatter__mvapich2(const void *sendbuf, void *recvbuf, const int *recvcnts,
MPI_Datatype datatype, MPI_Op op,
int mpi_errno = MPI_SUCCESS;
int i = 0, comm_size = comm->size(), total_count = 0, type_size =
/* Per-rank displacements derived from the receive counts. */
int* disps = new int[comm_size];
/* Build the tuning table on first use. */
if (mv2_red_scat_thresholds_table == nullptr)
init_mv2_reduce_scatter_tables_stampede();
/* MPI_OP_NULL is treated as commutative. */
bool is_commutative = (op == MPI_OP_NULL || op->is_commutative());
for (i = 0; i < comm_size; i++) {
disps[i] = total_count;
total_count += recvcnts[i];
type_size=datatype->size();
nbytes = total_count * type_size;
if (is_commutative) {
int range_threshold = 0;
/* Search for the corresponding system size inside the tuning table */
while ((range < (mv2_size_red_scat_tuning_table - 1)) &&
(comm_size > mv2_red_scat_thresholds_table[range].numproc)) {
/* Search for corresponding inter-leader function */
while ((range_threshold < (mv2_red_scat_thresholds_table[range].size_inter_table - 1))
mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max)
&& (mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max !=
/* Set inter-leader pt */
MV2_Red_scat_function =
mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].
MV2_pt_Red_scat_function;
mpi_errno = MV2_Red_scat_function(sendbuf, recvbuf,
/* Non-commutative path: check whether all receive blocks are equal. */
bool is_block_regular = true;
for (i = 0; i < (comm_size - 1); ++i) {
if (recvcnts[i] != recvcnts[i+1]) {
is_block_regular = false;
/* Smallest power of two >= comm_size. */
while (pof2 < comm_size) pof2 <<= 1;
if (pof2 == comm_size && is_block_regular) {
/* noncommutative, pof2 size, and block regular */
MPIR_Reduce_scatter_non_comm_MV2(sendbuf, recvbuf,
mpi_errno = reduce_scatter__mpich_rdb(sendbuf, recvbuf,
/* Selector for MPI_Scatter: tries partial-subscription (per-ppn) tuning tables,
 * picks inter-leader and intra-node scatter functions, with special handling for
 * the mcast wrapper and a binomial fallback; two-level variants additionally
 * need a blocked process layout.
 * NOTE(review): truncated copy — sendcnt/recvcnt/buffer parameter lines, the
 * rank==root nbytes branch, the `do {` opener, loop increments and many braces
 * appear to be missing; verify against upstream before building. */
int scatter__mvapich2(const void *sendbuf,
MPI_Datatype sendtype,
MPI_Datatype recvtype,
int root, MPI_Comm comm)
int range = 0, range_threshold = 0, range_threshold_intra = 0;
int mpi_errno = MPI_SUCCESS;
// int mpi_errno_ret = MPI_SUCCESS;
int rank, nbytes, comm_size;
bool partial_sub_ok = false;
// MPID_Comm *shmem_commptr=NULL;
/* Build the tuning tables on first use. */
if (mv2_scatter_thresholds_table == nullptr)
init_mv2_scatter_tables_stampede();
if (comm->get_leaders_comm() == MPI_COMM_NULL) {
comm_size = comm->size();
/* nbytes from send side (root) or receive side (others); the rank test is
 * presumably on a missing line — TODO confirm. */
int sendtype_size = sendtype->size();
nbytes = sendcnt * sendtype_size;
int recvtype_size = recvtype->size();
nbytes = recvcnt * recvtype_size;
// check if safe to use partial subscription mode
if (comm->is_uniform()) {
shmem_comm = comm->get_intra_comm();
if (mv2_scatter_table_ppn_conf[0] == -1) {
// Indicating user defined tuning
/* Match the local (per-node) size against the known ppn configurations. */
int local_size = shmem_comm->size();
if (local_size == mv2_scatter_table_ppn_conf[i]) {
partial_sub_ok = true;
} while(i < mv2_scatter_num_ppn_conf);
if (not partial_sub_ok) {
/* Search for the corresponding system size inside the tuning table */
while ((range < (mv2_size_scatter_tuning_table[conf_index] - 1)) &&
(comm_size > mv2_scatter_thresholds_table[conf_index][range].numproc)) {
/* Search for corresponding inter-leader function */
while ((range_threshold < (mv2_scatter_thresholds_table[conf_index][range].size_inter_table - 1))
mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
&& (mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max != -1)) {
/* Search for corresponding intra-node function */
while ((range_threshold_intra <
(mv2_scatter_thresholds_table[conf_index][range].size_intra_table - 1))
mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max)
&& (mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max !=
range_threshold_intra++;
MV2_Scatter_function = mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold]
.MV2_pt_Scatter_function;
/* The mcast wrapper is only usable when multicast support is compiled in
 * and ready; otherwise take the next table entry or binomial. */
if(MV2_Scatter_function == &MPIR_Scatter_mcst_wrap_MV2) {
#if defined(_MCST_SUPPORT_)
if(comm->ch.is_mcast_ok == 1
&& mv2_use_mcast_scatter == 1
&& comm->ch.shmem_coll_ok == 1) {
MV2_Scatter_function = &MPIR_Scatter_mcst_MV2;
#endif /*#if defined(_MCST_SUPPORT_) */
if (mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1].MV2_pt_Scatter_function !=
MV2_Scatter_function =
mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1].MV2_pt_Scatter_function;
MV2_Scatter_function = &MPIR_Scatter_MV2_Binomial;
/* Two-level variants also need a blocked layout; else binomial fallback. */
if( (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Direct) ||
(MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Binomial)) {
if( comm->is_blocked()) {
MV2_Scatter_intra_function = mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra]
.MV2_pt_Scatter_function;
MV2_Scatter_function(sendbuf, sendcnt, sendtype,
recvbuf, recvcnt, recvtype, root,
mpi_errno = MPIR_Scatter_MV2_Binomial(sendbuf, sendcnt, sendtype,
recvbuf, recvcnt, recvtype, root,
mpi_errno = MV2_Scatter_function(sendbuf, sendcnt, sendtype,
recvbuf, recvcnt, recvtype, root,
975 } // namespace simgrid::smpi
977 void smpi_coll_cleanup_mvapich2()
979 if (mv2_alltoall_thresholds_table)
980 delete[] mv2_alltoall_thresholds_table[0];
981 delete[] mv2_alltoall_thresholds_table;
982 delete[] mv2_size_alltoall_tuning_table;
983 delete[] mv2_alltoall_table_ppn_conf;
985 delete[] mv2_gather_thresholds_table;
986 if (mv2_allgather_thresholds_table)
987 delete[] mv2_allgather_thresholds_table[0];
988 delete[] mv2_size_allgather_tuning_table;
989 delete[] mv2_allgather_table_ppn_conf;
990 delete[] mv2_allgather_thresholds_table;
992 delete[] mv2_allgatherv_thresholds_table;
993 delete[] mv2_reduce_thresholds_table;
994 delete[] mv2_red_scat_thresholds_table;
995 delete[] mv2_allreduce_thresholds_table;
996 delete[] mv2_bcast_thresholds_table;
997 if (mv2_scatter_thresholds_table)
998 delete[] mv2_scatter_thresholds_table[0];
999 delete[] mv2_scatter_thresholds_table;
1000 delete[] mv2_size_scatter_tuning_table;
1001 delete[] mv2_scatter_table_ppn_conf;