From 1329c0171c05c43b652a4c542851f844d7d10eed Mon Sep 17 00:00:00 2001 From: Arnaud Giersch Date: Mon, 31 May 2021 15:14:10 +0200 Subject: [PATCH] Handle duplicated datatypes within predefined MPI_Op. --- src/smpi/mpi/smpi_op.cpp | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/smpi/mpi/smpi_op.cpp b/src/smpi/mpi/smpi_op.cpp index 5ec47aa8eb..afc3264188 100644 --- a/src/smpi/mpi/smpi_op.cpp +++ b/src/smpi/mpi/smpi_op.cpp @@ -45,8 +45,13 @@ XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_op, smpi, "Logging specific to SMPI (op)"); } \ } +#define APPLY_BEGIN_OP_LOOP() \ + MPI_Datatype datatype_base = *datatype; \ + while (datatype_base->duplicated_datatype() != MPI_DATATYPE_NULL) \ + datatype_base = datatype_base->duplicated_datatype(); + #define APPLY_OP_LOOP(dtype, type, op) \ - if (*datatype == (dtype)) { \ + if (datatype_base == (dtype)) { \ APPLY_FUNC(a, b, length, type, op) \ } else @@ -121,6 +126,7 @@ APPLY_OP_LOOP(MPI_COMPLEX32, double_double,op) static void max_func(void *a, void *b, int *length, MPI_Datatype * datatype) { + APPLY_BEGIN_OP_LOOP() APPLY_BASIC_OP_LOOP(MAX_OP) APPLY_FLOAT_OP_LOOP(MAX_OP) APPLY_END_OP_LOOP(MAX_OP) @@ -128,6 +134,7 @@ static void max_func(void *a, void *b, int *length, MPI_Datatype * datatype) static void min_func(void *a, void *b, int *length, MPI_Datatype * datatype) { + APPLY_BEGIN_OP_LOOP() APPLY_BASIC_OP_LOOP(MIN_OP) APPLY_FLOAT_OP_LOOP(MIN_OP) APPLY_END_OP_LOOP(MIN_OP) @@ -135,6 +142,7 @@ static void min_func(void *a, void *b, int *length, MPI_Datatype * datatype) static void sum_func(void *a, void *b, int *length, MPI_Datatype * datatype) { + APPLY_BEGIN_OP_LOOP() APPLY_BASIC_OP_LOOP(SUM_OP) APPLY_FLOAT_OP_LOOP(SUM_OP) APPLY_COMPLEX_OP_LOOP(SUM_OP) @@ -144,6 +152,7 @@ static void sum_func(void *a, void *b, int *length, MPI_Datatype * datatype) static void prod_func(void *a, void *b, int *length, MPI_Datatype * datatype) { + APPLY_BEGIN_OP_LOOP() APPLY_BASIC_OP_LOOP(PROD_OP) 
APPLY_FLOAT_OP_LOOP(PROD_OP) APPLY_COMPLEX_OP_LOOP(PROD_OP) @@ -153,6 +162,7 @@ static void prod_func(void *a, void *b, int *length, MPI_Datatype * datatype) static void land_func(void *a, void *b, int *length, MPI_Datatype * datatype) { + APPLY_BEGIN_OP_LOOP() APPLY_BASIC_OP_LOOP(LAND_OP) APPLY_FLOAT_OP_LOOP(LAND_OP) APPLY_BOOL_OP_LOOP(LAND_OP) @@ -161,6 +171,7 @@ static void land_func(void *a, void *b, int *length, MPI_Datatype * datatype) static void lor_func(void *a, void *b, int *length, MPI_Datatype * datatype) { + APPLY_BEGIN_OP_LOOP() APPLY_BASIC_OP_LOOP(LOR_OP) APPLY_FLOAT_OP_LOOP(LOR_OP) APPLY_BOOL_OP_LOOP(LOR_OP) @@ -169,6 +180,7 @@ static void lor_func(void *a, void *b, int *length, MPI_Datatype * datatype) static void lxor_func(void *a, void *b, int *length, MPI_Datatype * datatype) { + APPLY_BEGIN_OP_LOOP() APPLY_BASIC_OP_LOOP(LXOR_OP) APPLY_FLOAT_OP_LOOP(LXOR_OP) APPLY_BOOL_OP_LOOP(LXOR_OP) @@ -177,6 +189,7 @@ static void lxor_func(void *a, void *b, int *length, MPI_Datatype * datatype) static void band_func(void *a, void *b, int *length, MPI_Datatype * datatype) { + APPLY_BEGIN_OP_LOOP() APPLY_BASIC_OP_LOOP(BAND_OP) APPLY_BOOL_OP_LOOP(BAND_OP) APPLY_BYTE_OP_LOOP(BAND_OP) @@ -185,6 +198,7 @@ static void band_func(void *a, void *b, int *length, MPI_Datatype * datatype) static void bor_func(void *a, void *b, int *length, MPI_Datatype * datatype) { + APPLY_BEGIN_OP_LOOP() APPLY_BASIC_OP_LOOP(BOR_OP) APPLY_BOOL_OP_LOOP(BOR_OP) APPLY_BYTE_OP_LOOP(BOR_OP) @@ -193,6 +207,7 @@ static void bor_func(void *a, void *b, int *length, MPI_Datatype * datatype) static void bxor_func(void *a, void *b, int *length, MPI_Datatype * datatype) { + APPLY_BEGIN_OP_LOOP() APPLY_BASIC_OP_LOOP(BXOR_OP) APPLY_BOOL_OP_LOOP(BXOR_OP) APPLY_BYTE_OP_LOOP(BXOR_OP) @@ -201,12 +216,14 @@ static void bxor_func(void *a, void *b, int *length, MPI_Datatype * datatype) static void minloc_func(void *a, void *b, int *length, MPI_Datatype * datatype) { + APPLY_BEGIN_OP_LOOP() 
APPLY_PAIR_OP_LOOP(MINLOC_OP) APPLY_END_OP_LOOP(MINLOC_OP) } static void maxloc_func(void *a, void *b, int *length, MPI_Datatype * datatype) { + APPLY_BEGIN_OP_LOOP() APPLY_PAIR_OP_LOOP(MAXLOC_OP) APPLY_END_OP_LOOP(MAXLOC_OP) } @@ -256,10 +273,10 @@ void Op::apply(const void* invec, void* inoutvec, const int* len, MPI_Datatype d } if (not smpi_process()->replaying() && *len > 0) { + XBT_DEBUG("Applying operation of length %d from %p and from/to %p", *len, invec, inoutvec); if (not is_fortran_op_) this->func_(const_cast<void*>(invec), inoutvec, const_cast<int*>(len), &datatype); else{ - XBT_DEBUG("Applying operation of length %d from %p and from/to %p", *len, invec, inoutvec); int tmp = datatype->c2f(); /* Unfortunately, the C and Fortran version of the MPI standard do not agree on the type here, thus the reinterpret_cast. */ -- 2.20.1