1 /* selector for collective algorithms based on mvapich decision logic */
3 /* Copyright (c) 2009-2022. The SimGrid Team.
4 * All rights reserved. */
6 /* This program is free software; you can redistribute it and/or modify it
7 * under the terms of the license (GNU LGPL) which comes with this package. */
9 #include "colls_private.hpp"
11 #include "smpi_mvapich2_selector_stampede.hpp"
// MVAPICH2 alltoall selector: picks an algorithm from the Stampede tuning
// tables based on communicator size and per-destination message size.
// NOTE(review): this chunk is elided — braces, the parameter-list tail and
// several statements are not visible here; comments describe visible code only.
17 int alltoall__mvapich2( const void *sendbuf, int sendcount,
18 MPI_Datatype sendtype,
19 void* recvbuf, int recvcount,
20 MPI_Datatype recvtype,
// Lazily build the tuning tables on first use.
24 if (mv2_alltoall_table_ppn_conf == nullptr)
25 init_mv2_alltoall_tables_stampede();
27 int sendtype_size, recvtype_size, comm_size;
28 int mpi_errno=MPI_SUCCESS;
30 int range_threshold = 0;
32 comm_size = comm->size();
34 sendtype_size=sendtype->size();
35 recvtype_size=recvtype->size();
// Bytes sent to each destination; drives the size-based table lookup.
36 long nbytes = sendtype_size * sendcount;
38 /* check if safe to use partial subscription mode */
40 /* Search for the corresponding system size inside the tuning table */
41 while ((range < (mv2_size_alltoall_tuning_table[conf_index] - 1)) &&
42 (comm_size > mv2_alltoall_thresholds_table[conf_index][range].numproc)) {
45 /* Search for corresponding inter-leader function */
// Walk the per-range algorithm table until nbytes falls under an entry's
// max threshold; a max of -1 marks an unbounded last entry.
46 while ((range_threshold < (mv2_alltoall_thresholds_table[conf_index][range].size_table - 1))
48 mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max)
49 && (mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max != -1)) {
// Algorithm selected for this (range, range_threshold) cell.
52 MV2_Alltoall_function = mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold]
53 .MV2_pt_Alltoall_function;
// Out-of-place case: call the selected algorithm directly.
55 if(sendbuf != MPI_IN_PLACE) {
56 mpi_errno = MV2_Alltoall_function(sendbuf, sendcount, sendtype,
57 recvbuf, recvcount, recvtype,
// MPI_IN_PLACE case: if nbytes is outside the in-place algorithm's
// [min, max] window, stage recvbuf into a temporary buffer and run the
// selected out-of-place algorithm on it instead.
62 mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].min
63 ||nbytes > mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].max
65 unsigned char* tmp_buf = smpi_get_tmp_sendbuffer(comm_size * recvcount * recvtype_size);
66 Datatype::copy(recvbuf, comm_size * recvcount, recvtype, tmp_buf, comm_size * recvcount, recvtype);
68 mpi_errno = MV2_Alltoall_function(tmp_buf, recvcount, recvtype, recvbuf, recvcount, recvtype, comm);
69 smpi_free_tmp_buffer(tmp_buf);
// Otherwise use the dedicated in-place implementation.
71 mpi_errno = MPIR_Alltoall_inplace_MV2(sendbuf, sendcount, sendtype,
72 recvbuf, recvcount, recvtype,
// MVAPICH2 allgather selector: depending on the Stampede tuning tables and
// the process layout, dispatches to a two-level (SMP-aware) allgather, the
// plain MPICH allgather, or one of the flat MV2 algorithms (Bruck/RD/Ring).
81 int allgather__mvapich2(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
82 void *recvbuf, int recvcount, MPI_Datatype recvtype,
86 int mpi_errno = MPI_SUCCESS;
87 long nbytes = 0, comm_size, recvtype_size;
89 bool partial_sub_ok = false;
91 int range_threshold = 0;
93 //MPI_Comm *shmem_commptr=NULL;
94 /* Get the size of the communicator */
95 comm_size = comm->size();
96 recvtype_size=recvtype->size();
// Per-process contribution size; drives the size-based table lookup.
97 nbytes = recvtype_size * recvcount;
// Lazily build the tuning tables on first use.
99 if (mv2_allgather_table_ppn_conf == nullptr)
100 init_mv2_allgather_tables_stampede();
102 if(comm->get_leaders_comm()==MPI_COMM_NULL){
// Partial subscription is only considered when all nodes host the same
// number of processes; look for a ppn configuration matching local_size.
106 if (comm->is_uniform()){
107 shmem_comm = comm->get_intra_comm();
108 int local_size = shmem_comm->size();
110 if (mv2_allgather_table_ppn_conf[0] == -1) {
111 // Indicating user defined tuning
116 if (local_size == mv2_allgather_table_ppn_conf[i]) {
118 partial_sub_ok = true;
122 } while(i < mv2_allgather_num_ppn_conf);
125 if (not partial_sub_ok) {
129 /* Search for the corresponding system size inside the tuning table */
130 while ((range < (mv2_size_allgather_tuning_table[conf_index] - 1)) &&
132 mv2_allgather_thresholds_table[conf_index][range].numproc)) {
135 /* Search for corresponding inter-leader function */
// A max of -1 marks an unbounded last entry of the range.
136 while ((range_threshold <
137 (mv2_allgather_thresholds_table[conf_index][range].size_inter_table - 1))
138 && (nbytes > mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
139 && (mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max !=
144 /* Set inter-leader pt */
146 mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].
147 MV2_pt_Allgatherction;
149 bool is_two_level = mv2_allgather_thresholds_table[conf_index][range].two_level[range_threshold];
151 /* intracommunicator */
// Two-level path requires a blocked layout (ranks of each node contiguous);
// otherwise fall back to the plain MPICH allgather.
153 if (partial_sub_ok) {
154 if (comm->is_blocked()){
155 mpi_errno = MPIR_2lvl_Allgather_MV2(sendbuf, sendcount, sendtype,
156 recvbuf, recvcount, recvtype,
159 mpi_errno = allgather__mpich(sendbuf, sendcount, sendtype,
160 recvbuf, recvcount, recvtype,
164 mpi_errno = MPIR_Allgather_RD_MV2(sendbuf, sendcount, sendtype,
165 recvbuf, recvcount, recvtype,
// Flat algorithms selected directly from the tuning table.
168 } else if(MV2_Allgatherction == &MPIR_Allgather_Bruck_MV2
169 || MV2_Allgatherction == &MPIR_Allgather_RD_MV2
170 || MV2_Allgatherction == &MPIR_Allgather_Ring_MV2) {
171 mpi_errno = MV2_Allgatherction(sendbuf, sendcount, sendtype,
172 recvbuf, recvcount, recvtype,
// No usable algorithm found for this configuration.
175 return MPI_ERR_OTHER;
// MVAPICH2 gather selector: picks intra-node and inter-leader gather
// functions from the tuning table. Uses the SMP-aware two-level gather when
// the communicator layout is blocked, and the MPICH gather otherwise.
181 int gather__mvapich2(const void *sendbuf,
183 MPI_Datatype sendtype,
186 MPI_Datatype recvtype,
187 int root, MPI_Comm comm)
// Lazily build the tuning table on first use.
189 if (mv2_gather_thresholds_table == nullptr)
190 init_mv2_gather_tables_stampede();
192 int mpi_errno = MPI_SUCCESS;
194 int range_threshold = 0;
195 int range_intra_threshold = 0;
197 int comm_size = comm->size();
198 int rank = comm->rank();
// Message size for the threshold lookup. NOTE(review): nbytes is assigned
// from recvcnt and then from sendcnt on the visible lines; presumably the
// elided lines pick one or the other depending on rank == root — confirm
// against the full source.
201 int recvtype_size = recvtype->size();
202 nbytes = recvcnt * recvtype_size;
204 int sendtype_size = sendtype->size();
205 nbytes = sendcnt * sendtype_size;
208 /* Search for the corresponding system size inside the tuning table */
209 while ((range < (mv2_size_gather_tuning_table - 1)) &&
210 (comm_size > mv2_gather_thresholds_table[range].numproc)) {
213 /* Search for corresponding inter-leader function */
// A max of -1 marks an unbounded last entry of the range.
214 while ((range_threshold < (mv2_gather_thresholds_table[range].size_inter_table - 1))
216 mv2_gather_thresholds_table[range].inter_leader[range_threshold].max)
217 && (mv2_gather_thresholds_table[range].inter_leader[range_threshold].max !=
222 /* Search for corresponding intra node function */
223 while ((range_intra_threshold < (mv2_gather_thresholds_table[range].size_intra_table - 1))
225 mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max)
226 && (mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max !=
228 range_intra_threshold++;
// Two-level gather needs a blocked layout (node-local ranks contiguous).
231 if (comm->is_blocked() ) {
232 // Set intra-node function pt for gather_two_level
233 MV2_Gather_intra_node_function =
234 mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].
235 MV2_pt_Gather_function;
236 //Set inter-leader pt
237 MV2_Gather_inter_leader_function =
238 mv2_gather_thresholds_table[range].inter_leader[range_threshold].
239 MV2_pt_Gather_function;
240 // We call Gather function
242 MV2_Gather_inter_leader_function(sendbuf, sendcnt, sendtype, recvbuf, recvcnt,
243 recvtype, root, comm);
246 // Indeed, direct (non SMP-aware)gather is MPICH one
247 mpi_errno = gather__mpich(sendbuf, sendcnt, sendtype,
248 recvbuf, recvcnt, recvtype,
// MVAPICH2 allgatherv selector: picks an inter-leader algorithm from the
// tuning table (thresholds scaled by comm_size). Recursive doubling is only
// run when comm_size is a power of two; Bruck is used otherwise.
255 int allgatherv__mvapich2(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
256 void *recvbuf, const int *recvcounts, const int *displs,
257 MPI_Datatype recvtype, MPI_Comm comm )
259 int mpi_errno = MPI_SUCCESS;
260 int range = 0, comm_size, total_count, recvtype_size, i;
261 int range_threshold = 0;
// Lazily build the tuning table on first use.
264 if (mv2_allgatherv_thresholds_table == nullptr)
265 init_mv2_allgatherv_tables_stampede();
267 comm_size = comm->size();
// Total number of received elements drives the size-based selection.
// NOTE(review): no visible line initializes total_count — presumably an
// elided line sets it to 0 before this loop; confirm against the full file.
269 for (i = 0; i < comm_size; i++)
270 total_count += recvcounts[i];
272 recvtype_size=recvtype->size();
273 nbytes = total_count * recvtype_size;
275 /* Search for the corresponding system size inside the tuning table */
276 while ((range < (mv2_size_allgatherv_tuning_table - 1)) &&
277 (comm_size > mv2_allgatherv_thresholds_table[range].numproc)) {
280 /* Search for corresponding inter-leader function */
// The stored max is a per-process threshold, hence multiplied by comm_size;
// a max of -1 marks an unbounded last entry.
281 while ((range_threshold < (mv2_allgatherv_thresholds_table[range].size_inter_table - 1))
283 comm_size * mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max)
284 && (mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max !=
288 /* Set inter-leader pt */
289 MV2_Allgatherv_function =
290 mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].
291 MV2_pt_Allgatherv_function;
293 if (MV2_Allgatherv_function == &MPIR_Allgatherv_Rec_Doubling_MV2)
// (comm_size & (comm_size - 1)) == 0 <=> comm_size is a power of two.
295 if (not(comm_size & (comm_size - 1))) {
297 MPIR_Allgatherv_Rec_Doubling_MV2(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
300 MPIR_Allgatherv_Bruck_MV2(sendbuf, sendcount,
// Any other selected algorithm is called directly.
307 MV2_Allgatherv_function(sendbuf, sendcount, sendtype,
308 recvbuf, recvcounts, displs,
// MVAPICH2 allreduce selector: searches the tuning table for an inter-leader
// algorithm (skipping multicast entries when mcast is unavailable) and, for
// the two-level scheme, an intra-node algorithm. Falls back to point-to-point
// recursive doubling when the selected algorithm's preconditions fail.
317 int allreduce__mvapich2(const void *sendbuf,
320 MPI_Datatype datatype,
321 MPI_Op op, MPI_Comm comm)
324 int mpi_errno = MPI_SUCCESS;
328 comm_size = comm->size();
329 //rank = comm->rank();
// Lazily build the tuning table on first use.
335 if (mv2_allreduce_thresholds_table == nullptr)
336 init_mv2_allreduce_tables_stampede();
338 /* check if multiple threads are calling this collective function */
340 MPI_Aint sendtype_size = 0;
342 MPI_Aint true_lb, true_extent;
344 sendtype_size=datatype->size();
345 nbytes = count * sendtype_size;
347 datatype->extent(&true_lb, &true_extent);
// The two-level (shared-memory) variant is only used for commutative ops.
348 bool is_commutative = op->is_commutative();
351 int range = 0, range_threshold = 0, range_threshold_intra = 0;
352 bool is_two_level = false;
354 /* Search for the corresponding system size inside the tuning table */
355 while ((range < (mv2_size_allreduce_tuning_table - 1)) &&
356 (comm_size > mv2_allreduce_thresholds_table[range].numproc)) {
359 /* Search for corresponding inter-leader function */
360 /* skip mcast pointers if mcast is not available */
361 if (not mv2_allreduce_thresholds_table[range].mcast_enabled) {
362 while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
363 && ((mv2_allreduce_thresholds_table[range].
364 inter_leader[range_threshold].MV2_pt_Allreducection
365 == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2) ||
366 (mv2_allreduce_thresholds_table[range].
367 inter_leader[range_threshold].MV2_pt_Allreducection
368 == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)
// Regular size-based search; a max of -1 marks an unbounded last entry.
373 while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
375 mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max)
376 && (mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max != -1)) {
379 if (mv2_allreduce_thresholds_table[range].is_two_level_allreduce[range_threshold]) {
382 /* Search for corresponding intra-node function */
383 while ((range_threshold_intra <
384 (mv2_allreduce_thresholds_table[range].size_intra_table - 1))
386 mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max)
387 && (mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max !=
389 range_threshold_intra++;
// Selected inter-leader and intra-node function pointers.
392 MV2_Allreducection = mv2_allreduce_thresholds_table[range].inter_leader[range_threshold]
393 .MV2_pt_Allreducection;
395 MV2_Allreduce_intra_function = mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra]
396 .MV2_pt_Allreducection;
398 /* check if mcast is ready, otherwise replace mcast with other algorithm */
399 if((MV2_Allreducection == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2)||
400 (MV2_Allreducection == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)){
402 MV2_Allreducection = &MPIR_Allreduce_pt2pt_rd_MV2;
404 if (not is_two_level) {
405 MV2_Allreducection = &MPIR_Allreduce_pt2pt_rd_MV2;
410 // check if shm is ready, if not use other algorithm first
// Two-level run also requires the leaders communicator to exist.
411 if (is_commutative) {
412 if(comm->get_leaders_comm()==MPI_COMM_NULL){
415 mpi_errno = MPIR_Allreduce_two_level_MV2(sendbuf, recvbuf, count,
418 mpi_errno = MPIR_Allreduce_pt2pt_rd_MV2(sendbuf, recvbuf, count,
422 mpi_errno = MV2_Allreducection(sendbuf, recvbuf, count,
427 //comm->ch.intra_node_done=0;
// MVAPICH2 alltoallv selector: no tuned alltoallv table exists, so the
// MPI_IN_PLACE case is routed to the basic linear algorithm and every other
// case to the ring algorithm.
435 int alltoallv__mvapich2(const void *sbuf, const int *scounts, const int *sdisps,
437 void *rbuf, const int *rcounts, const int *rdisps,
443 if (sbuf == MPI_IN_PLACE) {
444 return alltoallv__ompi_basic_linear(sbuf, scounts, sdisps, sdtype,
445 rbuf, rcounts, rdisps, rdtype,
447 } else /* For starters, just keep the original algorithm. */
448 return alltoallv__ring(sbuf, scounts, sdisps, sdtype,
449 rbuf, rcounts, rdisps, rdtype,
// MVAPICH2 barrier selector: unconditionally delegates to the
// pairwise-exchange barrier implementation.
454 int barrier__mvapich2(MPI_Comm comm)
456 return barrier__mvapich2_pair(comm);
// MVAPICH2 bcast selector: looks up inter-leader and intra-node bcast
// functions plus knomial/pipeline parameters in the tuning table, then runs
// either a two-level (inter-node phase followed by an intra-node phase) or a
// flat broadcast. Commented-out code for non-contiguous / heterogeneous
// datatypes is kept from the upstream MVAPICH2 sources.
462 int bcast__mvapich2(void *buffer,
464 MPI_Datatype datatype,
465 int root, MPI_Comm comm)
467 int mpi_errno = MPI_SUCCESS;
468 int comm_size/*, rank*/;
469 bool two_level_bcast = true;
472 int range_threshold = 0;
473 int range_threshold_intra = 0;
476 // unsigned char *tmp_buf = NULL;
478 //MPID_Datatype *dtp;
// Make sure the leaders communicator exists before any two-level phase.
482 if(comm->get_leaders_comm()==MPI_COMM_NULL){
// Lazily build the tuning table on first use.
485 if (not mv2_bcast_thresholds_table)
486 init_mv2_bcast_tables_stampede();
487 comm_size = comm->size();
488 //rank = comm->rank();
490 // bool is_contig = true;
491 /* if (HANDLE_GET_KIND(datatype) == HANDLE_KIND_BUILTIN)*/
492 /* is_contig = true;*/
494 /* MPID_Datatype_get_ptr(datatype, dtp);*/
495 /* is_contig = dtp->is_contig;*/
498 // bool is_homogeneous = true;
500 /* MPI_Type_size() might not give the accurate size of the packed
501 * datatype for heterogeneous systems (because of padding, encoding,
502 * etc). On the other hand, MPI_Pack_size() can become very
503 * expensive, depending on the implementation, especially for
504 * heterogeneous systems. We want to use MPI_Type_size() wherever
505 * possible, and MPI_Pack_size() in other places.
507 //if (is_homogeneous) {
508 type_size=datatype->size();
511 MPIR_Pack_size_impl(1, datatype, &type_size);
// Total message size drives the size-based table lookup.
513 nbytes = (count) * (type_size);
515 /* Search for the corresponding system size inside the tuning table */
516 while ((range < (mv2_size_bcast_tuning_table - 1)) &&
517 (comm_size > mv2_bcast_thresholds_table[range].numproc)) {
520 /* Search for corresponding inter-leader function */
// A max of -1 marks an unbounded last entry of the range.
521 while ((range_threshold < (mv2_bcast_thresholds_table[range].size_inter_table - 1))
523 mv2_bcast_thresholds_table[range].inter_leader[range_threshold].max)
524 && (mv2_bcast_thresholds_table[range].inter_leader[range_threshold].max != -1)) {
528 /* Search for corresponding intra-node function */
529 while ((range_threshold_intra <
530 (mv2_bcast_thresholds_table[range].size_intra_table - 1))
532 mv2_bcast_thresholds_table[range].intra_node[range_threshold_intra].max)
533 && (mv2_bcast_thresholds_table[range].intra_node[range_threshold_intra].max !=
535 range_threshold_intra++;
// Selected inter-leader and intra-node function pointers.
539 mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
540 MV2_pt_Bcast_function;
542 MV2_Bcast_intra_node_function =
543 mv2_bcast_thresholds_table[range].
544 intra_node[range_threshold_intra].MV2_pt_Bcast_function;
546 /* if (mv2_user_bcast_intra == NULL && */
547 /* MV2_Bcast_intra_node_function == &MPIR_Knomial_Bcast_intra_node_MV2) {*/
548 /* MV2_Bcast_intra_node_function = &MPIR_Shmem_Bcast_MV2;*/
// Zero-copy pipelined knomial factor: the table entry takes precedence, a
// user-provided global overrides it afterwards.
551 if (mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
552 zcpy_pipelined_knomial_factor != -1) {
553 zcpy_knomial_factor =
554 mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
555 zcpy_pipelined_knomial_factor;
558 if (mv2_pipelined_zcpy_knomial_factor != -1) {
559 zcpy_knomial_factor = mv2_pipelined_zcpy_knomial_factor;
562 if (MV2_Bcast_intra_node_function == nullptr) {
563 /* if tuning table do not have any intra selection, set func pointer to
564 ** default one for mcast intra node */
565 MV2_Bcast_intra_node_function = &MPIR_Shmem_Bcast_MV2;
568 /* Set value of pipeline segment size */
569 bcast_segment_size = mv2_bcast_thresholds_table[range].bcast_segment_size;
571 /* Set value of inter node knomial factor */
572 mv2_inter_node_knomial_factor = mv2_bcast_thresholds_table[range].inter_node_knomial_factor;
574 /* Set value of intra node knomial factor */
575 mv2_intra_node_knomial_factor = mv2_bcast_thresholds_table[range].intra_node_knomial_factor;
577 /* Check if we will use a two level algorithm or not */
579 #if defined(_MCST_SUPPORT_)
580 mv2_bcast_thresholds_table[range].is_two_level_bcast[range_threshold]
581 || comm->ch.is_mcast_ok;
583 mv2_bcast_thresholds_table[range].is_two_level_bcast[range_threshold];
585 if (two_level_bcast) {
586 // if (not is_contig || not is_homogeneous) {
587 // tmp_buf = smpi_get_tmp_sendbuffer(nbytes);
590 /* if (rank == root) {*/
592 /* MPIR_Pack_impl(buffer, count, datatype, tmp_buf, nbytes, &position);*/
594 /* MPIU_ERR_POP(mpi_errno);*/
// Zero-copy pipelined broadcast handles both phases itself.
597 #ifdef CHANNEL_MRAIL_GEN2
598 if ((mv2_enable_zcpy_bcast == 1) &&
599 (&MPIR_Pipelined_Bcast_Zcpy_MV2 == MV2_Bcast_function)) {
600 // if (not is_contig || not is_homogeneous) {
601 // mpi_errno = MPIR_Pipelined_Bcast_Zcpy_MV2(tmp_buf, nbytes, MPI_BYTE, root, comm);
603 mpi_errno = MPIR_Pipelined_Bcast_Zcpy_MV2(buffer, count, datatype,
607 #endif /* defined(CHANNEL_MRAIL_GEN2) */
// Two-level: inter-node phase first, then an intra-node broadcast rooted at
// the node-local root.
609 shmem_comm = comm->get_intra_comm();
610 // if (not is_contig || not is_homogeneous) {
611 // MPIR_Bcast_tune_inter_node_helper_MV2(tmp_buf, nbytes, MPI_BYTE, root, comm);
613 MPIR_Bcast_tune_inter_node_helper_MV2(buffer, count, datatype, root, comm);
616 /* We are now done with the inter-node phase */
619 root = INTRA_NODE_ROOT;
621 // if (not is_contig || not is_homogeneous) {
622 // mpi_errno = MV2_Bcast_intra_node_function(tmp_buf, nbytes, MPI_BYTE, root, shmem_comm);
624 mpi_errno = MV2_Bcast_intra_node_function(buffer, count,
625 datatype, root, shmem_comm);
629 /* if (not is_contig || not is_homogeneous) {*/
630 /* if (rank != root) {*/
632 /* mpi_errno = MPIR_Unpack_impl(tmp_buf, nbytes, &position, buffer,*/
633 /* count, datatype);*/
// Flat (non two-level) path.
637 /* We use Knomial for intra node */
638 MV2_Bcast_intra_node_function = &MPIR_Knomial_Bcast_intra_node_MV2;
639 /* if (mv2_enable_shmem_bcast == 0) {*/
640 /* Fall back to non-tuned version */
641 /* MPIR_Bcast_intra_MV2(buffer, count, datatype, root, comm);*/
643 mpi_errno = MV2_Bcast_function(buffer, count, datatype, root,
// MVAPICH2 reduce selector: picks intra-node and inter-leader reduce
// functions plus knomial degrees from the tuning table, then dispatches to a
// two-level, inter-knomial, redscat-gather, or binomial reduce depending on
// op commutativity and the selected function's preconditions.
656 int reduce__mvapich2(const void *sendbuf,
659 MPI_Datatype datatype,
660 MPI_Op op, int root, MPI_Comm comm)
// Lazily build the tuning table on first use.
662 if (mv2_reduce_thresholds_table == nullptr)
663 init_mv2_reduce_tables_stampede();
665 int mpi_errno = MPI_SUCCESS;
667 int range_threshold = 0;
668 int range_intra_threshold = 0;
673 bool is_two_level = false;
675 comm_size = comm->size();
676 sendtype_size=datatype->size();
677 nbytes = count * sendtype_size;
// MPI_OP_NULL is treated as commutative here.
682 bool is_commutative = (op == MPI_OP_NULL || op->is_commutative());
684 /* find nearest power-of-two less than or equal to comm_size */
// NOTE(review): this loop exits with pof2 strictly greater than comm_size;
// presumably an elided line halves it afterwards (pof2 >>= 1) to match the
// comment above — confirm against the full source.
685 for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
689 /* Search for the corresponding system size inside the tuning table */
690 while ((range < (mv2_size_reduce_tuning_table - 1)) &&
691 (comm_size > mv2_reduce_thresholds_table[range].numproc)) {
694 /* Search for corresponding inter-leader function */
// A max of -1 marks an unbounded last entry of the range.
695 while ((range_threshold < (mv2_reduce_thresholds_table[range].size_inter_table - 1))
697 mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max)
698 && (mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max !=
703 /* Search for corresponding intra node function */
704 while ((range_intra_threshold < (mv2_reduce_thresholds_table[range].size_intra_table - 1))
706 mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max)
707 && (mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max !=
709 range_intra_threshold++;
712 /* Set intra-node function pt for reduce_two_level */
713 MV2_Reduce_intra_function =
714 mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].
715 MV2_pt_Reduce_function;
716 /* Set inter-leader pt */
717 MV2_Reduce_function =
718 mv2_reduce_thresholds_table[range].inter_leader[range_threshold].
719 MV2_pt_Reduce_function;
// Knomial degrees: table values only apply when no user override (< 0).
721 if(mv2_reduce_intra_knomial_factor<0)
723 mv2_reduce_intra_knomial_factor = mv2_reduce_thresholds_table[range].intra_k_degree;
725 if(mv2_reduce_inter_knomial_factor<0)
727 mv2_reduce_inter_knomial_factor = mv2_reduce_thresholds_table[range].inter_k_degree;
729 if (mv2_reduce_thresholds_table[range].is_two_level_reduce[range_threshold]) {
732 /* We call Reduce function */
// Two-level reduce needs a commutative op and an existing leaders
// communicator; otherwise use the binomial reduce.
734 if (is_commutative) {
735 if(comm->get_leaders_comm()==MPI_COMM_NULL){
738 mpi_errno = MPIR_Reduce_two_level_helper_MV2(sendbuf, recvbuf, count,
739 datatype, op, root, comm);
741 mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
742 datatype, op, root, comm);
744 } else if(MV2_Reduce_function == &MPIR_Reduce_inter_knomial_wrapper_MV2 ){
747 mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count,
748 datatype, op, root, comm);
750 mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
751 datatype, op, root, comm);
// Redscat-gather requires at least pof2 elements per process.
753 } else if(MV2_Reduce_function == &MPIR_Reduce_redscat_gather_MV2){
754 if (/*(HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) &&*/ (count >= pof2))
756 mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count,
757 datatype, op, root, comm);
759 mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
760 datatype, op, root, comm);
763 mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count,
764 datatype, op, root, comm);
// MVAPICH2 reduce_scatter selector: commutative ops use the tuned
// inter-leader function; non-commutative ops use the specialized
// non-commutative algorithm for power-of-two, block-regular layouts and the
// MPICH recursive-doubling fallback otherwise.
773 int reduce_scatter__mvapich2(const void *sendbuf, void *recvbuf, const int *recvcnts,
774 MPI_Datatype datatype, MPI_Op op,
777 int mpi_errno = MPI_SUCCESS;
778 int i = 0, comm_size = comm->size(), total_count = 0, type_size =
// Displacement of each rank's block inside the concatenated result.
// NOTE(review): no delete[] for disps is visible in this chunk — presumably
// it is released on an elided line before returning; confirm.
780 int* disps = new int[comm_size];
// Lazily build the tuning table on first use.
782 if (mv2_red_scat_thresholds_table == nullptr)
783 init_mv2_reduce_scatter_tables_stampede();
// MPI_OP_NULL is treated as commutative here.
785 bool is_commutative = (op == MPI_OP_NULL || op->is_commutative());
786 for (i = 0; i < comm_size; i++) {
787 disps[i] = total_count;
788 total_count += recvcnts[i];
791 type_size=datatype->size();
792 nbytes = total_count * type_size;
794 if (is_commutative) {
796 int range_threshold = 0;
798 /* Search for the corresponding system size inside the tuning table */
799 while ((range < (mv2_size_red_scat_tuning_table - 1)) &&
800 (comm_size > mv2_red_scat_thresholds_table[range].numproc)) {
803 /* Search for corresponding inter-leader function */
// A max of -1 marks an unbounded last entry of the range.
804 while ((range_threshold < (mv2_red_scat_thresholds_table[range].size_inter_table - 1))
806 mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max)
807 && (mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max !=
812 /* Set inter-leader pt */
813 MV2_Red_scat_function =
814 mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].
815 MV2_pt_Red_scat_function;
817 mpi_errno = MV2_Red_scat_function(sendbuf, recvbuf,
// Non-commutative path: detect whether all blocks have the same count.
821 bool is_block_regular = true;
822 for (i = 0; i < (comm_size - 1); ++i) {
823 if (recvcnts[i] != recvcnts[i+1]) {
824 is_block_regular = false;
// Smallest power of two >= comm_size.
829 while (pof2 < comm_size) pof2 <<= 1;
830 if (pof2 == comm_size && is_block_regular) {
831 /* noncommutative, pof2 size, and block regular */
832 MPIR_Reduce_scatter_non_comm_MV2(sendbuf, recvbuf,
836 mpi_errno = reduce_scatter__mpich_rdb(sendbuf, recvbuf,
// MVAPICH2 scatter selector: picks inter-leader and intra-node scatter
// functions from the per-ppn Stampede tuning tables. Multicast entries are
// replaced by the next table entry (or binomial) when mcast is unavailable,
// and the two-level variants additionally require a blocked process layout.
847 int scatter__mvapich2(const void *sendbuf,
849 MPI_Datatype sendtype,
852 MPI_Datatype recvtype,
853 int root, MPI_Comm comm)
855 int range = 0, range_threshold = 0, range_threshold_intra = 0;
856 int mpi_errno = MPI_SUCCESS;
857 // int mpi_errno_ret = MPI_SUCCESS;
858 int rank, nbytes, comm_size;
859 bool partial_sub_ok = false;
862 // MPID_Comm *shmem_commptr=NULL;
// Lazily build the tuning tables on first use.
863 if (mv2_scatter_thresholds_table == nullptr)
864 init_mv2_scatter_tables_stampede();
// Make sure the leaders communicator exists before any two-level phase.
866 if (comm->get_leaders_comm() == MPI_COMM_NULL) {
870 comm_size = comm->size();
// Message size for the threshold lookup. NOTE(review): nbytes is assigned
// from sendcnt and then from recvcnt on the visible lines; presumably the
// elided lines pick one or the other depending on rank == root — confirm
// against the full source.
875 int sendtype_size = sendtype->size();
876 nbytes = sendcnt * sendtype_size;
878 int recvtype_size = recvtype->size();
879 nbytes = recvcnt * recvtype_size;
882 // check if safe to use partial subscription mode
// Only considered on uniform layouts; look for a matching ppn configuration.
883 if (comm->is_uniform()) {
885 shmem_comm = comm->get_intra_comm();
886 if (mv2_scatter_table_ppn_conf[0] == -1) {
887 // Indicating user defined tuning
890 int local_size = shmem_comm->size();
893 if (local_size == mv2_scatter_table_ppn_conf[i]) {
895 partial_sub_ok = true;
899 } while(i < mv2_scatter_num_ppn_conf);
903 if (not partial_sub_ok) {
907 /* Search for the corresponding system size inside the tuning table */
908 while ((range < (mv2_size_scatter_tuning_table[conf_index] - 1)) &&
909 (comm_size > mv2_scatter_thresholds_table[conf_index][range].numproc)) {
912 /* Search for corresponding inter-leader function */
// A max of -1 marks an unbounded last entry of the range.
913 while ((range_threshold < (mv2_scatter_thresholds_table[conf_index][range].size_inter_table - 1))
915 mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
916 && (mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max != -1)) {
920 /* Search for corresponding intra-node function */
921 while ((range_threshold_intra <
922 (mv2_scatter_thresholds_table[conf_index][range].size_intra_table - 1))
924 mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max)
925 && (mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max !=
927 range_threshold_intra++;
930 MV2_Scatter_function = mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold]
931 .MV2_pt_Scatter_function;
// If the table selected the multicast wrapper but mcast cannot be used,
// substitute the next table entry or fall back to binomial.
933 if(MV2_Scatter_function == &MPIR_Scatter_mcst_wrap_MV2) {
934 #if defined(_MCST_SUPPORT_)
935 if(comm->ch.is_mcast_ok == 1
936 && mv2_use_mcast_scatter == 1
937 && comm->ch.shmem_coll_ok == 1) {
938 MV2_Scatter_function = &MPIR_Scatter_mcst_MV2;
940 #endif /*#if defined(_MCST_SUPPORT_) */
942 if (mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1].MV2_pt_Scatter_function !=
944 MV2_Scatter_function =
945 mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1].MV2_pt_Scatter_function;
948 MV2_Scatter_function = &MPIR_Scatter_MV2_Binomial;
// Two-level scatter needs a blocked layout; otherwise use binomial.
953 if( (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Direct) ||
954 (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Binomial)) {
955 if( comm->is_blocked()) {
956 MV2_Scatter_intra_function = mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra]
957 .MV2_pt_Scatter_function;
960 MV2_Scatter_function(sendbuf, sendcnt, sendtype,
961 recvbuf, recvcnt, recvtype, root,
964 mpi_errno = MPIR_Scatter_MV2_Binomial(sendbuf, sendcnt, sendtype,
965 recvbuf, recvcnt, recvtype, root,
970 mpi_errno = MV2_Scatter_function(sendbuf, sendcnt, sendtype,
971 recvbuf, recvcnt, recvtype, root,
// Releases all lazily-initialized MVAPICH2 tuning tables. For the
// per-ppn-configuration tables (alltoall, allgather, scatter) the delete[]
// of element 0 before the pointer table suggests the rows share one backing
// allocation reachable through it — matches the allocation scheme in the
// stampede header (NOTE(review): confirm there).
980 void smpi_coll_cleanup_mvapich2()
982 if (mv2_alltoall_thresholds_table)
983 delete[] mv2_alltoall_thresholds_table[0];
984 delete[] mv2_alltoall_thresholds_table;
985 delete[] mv2_size_alltoall_tuning_table;
986 delete[] mv2_alltoall_table_ppn_conf;
988 delete[] mv2_gather_thresholds_table;
989 if (mv2_allgather_thresholds_table)
990 delete[] mv2_allgather_thresholds_table[0];
991 delete[] mv2_size_allgather_tuning_table;
992 delete[] mv2_allgather_table_ppn_conf;
993 delete[] mv2_allgather_thresholds_table;
995 delete[] mv2_allgatherv_thresholds_table;
996 delete[] mv2_reduce_thresholds_table;
997 delete[] mv2_red_scat_thresholds_table;
998 delete[] mv2_allreduce_thresholds_table;
999 delete[] mv2_bcast_thresholds_table;
1000 if (mv2_scatter_thresholds_table)
1001 delete[] mv2_scatter_thresholds_table[0];
1002 delete[] mv2_scatter_thresholds_table;
1003 delete[] mv2_size_scatter_tuning_table;
1004 delete[] mv2_scatter_table_ppn_conf;