From 49ddf0d3ee554c79c94509e22d37400c1b2b5f33 Mon Sep 17 00:00:00 2001 From: Augustin Degomme Date: Sun, 4 Apr 2021 17:10:53 +0200 Subject: [PATCH] increase rank checking for RMA comms --- src/smpi/bindings/smpi_pmpi_win.cpp | 39 +++++++++++++---------------- src/smpi/include/private.hpp | 10 ++++---- src/smpi/include/smpi_win.hpp | 1 + src/smpi/mpi/smpi_win.cpp | 5 ++++ 4 files changed, 29 insertions(+), 26 deletions(-) diff --git a/src/smpi/bindings/smpi_pmpi_win.cpp b/src/smpi/bindings/smpi_pmpi_win.cpp index e8861546c0..95874227ce 100644 --- a/src/smpi/bindings/smpi_pmpi_win.cpp +++ b/src/smpi/bindings/smpi_pmpi_win.cpp @@ -17,8 +17,7 @@ XBT_LOG_EXTERNAL_DEFAULT_CATEGORY(smpi_pmpi); CHECK_BUFFER(1, origin_addr, origin_count)\ CHECK_COUNT(2, origin_count)\ CHECK_TYPE(3, origin_datatype)\ - CHECK_PROC(4, target_rank)\ - CHECK_NEGATIVE(4, MPI_ERR_RANK, target_rank)\ + CHECK_PROC_RMA(4, target_rank, win)\ CHECK_COUNT(6, target_count)\ CHECK_TYPE(7, target_datatype) @@ -163,8 +162,8 @@ int PMPI_Win_fence( int assert, MPI_Win win){ int PMPI_Get( void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank, MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Win win){ - CHECK_RMA CHECK_WIN(8, win) + CHECK_RMA CHECK_TARGET_DISP(5) int retval = 0; @@ -190,8 +189,8 @@ int PMPI_Rget( void *origin_addr, int origin_count, MPI_Datatype origin_datatype MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Win win, MPI_Request* request){ if(target_rank==MPI_PROC_NULL) *request = MPI_REQUEST_NULL; - CHECK_RMA CHECK_WIN(8, win) + CHECK_RMA CHECK_TARGET_DISP(5) CHECK_NULL(9, MPI_ERR_ARG, request) @@ -218,8 +217,8 @@ int PMPI_Rget( void *origin_addr, int origin_count, MPI_Datatype origin_datatype int PMPI_Put(const void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank, MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Win win){ - CHECK_RMA CHECK_WIN(8, win) + CHECK_RMA CHECK_TARGET_DISP(5) int retval = 0; @@ -249,8 +248,8 @@ int PMPI_Rput(const void *origin_addr, int origin_count, MPI_Datatype origin_dat MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Win win, MPI_Request* request){ if(target_rank==MPI_PROC_NULL) *request = MPI_REQUEST_NULL; - CHECK_RMA CHECK_WIN(8, win) + CHECK_RMA CHECK_TARGET_DISP(5) CHECK_NULL(9, MPI_ERR_ARG, request) int retval = 0; @@ -278,9 +277,9 @@ int PMPI_Rput(const void *origin_addr, int origin_count, MPI_Datatype origin_dat int PMPI_Accumulate(const void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank, MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Op op, MPI_Win win){ + CHECK_WIN(9, win) CHECK_RMA CHECK_MPI_NULL(8, MPI_OP_NULL, MPI_ERR_OP, op) - CHECK_WIN(9, win) CHECK_TARGET_DISP(5) int retval = 0; @@ -307,9 +306,9 @@ int PMPI_Raccumulate(const void *origin_addr, int origin_count, MPI_Datatype ori MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Op op, MPI_Win win, MPI_Request* request){ if(target_rank==MPI_PROC_NULL) *request = MPI_REQUEST_NULL; + CHECK_WIN(9, win) CHECK_RMA CHECK_MPI_NULL(8, MPI_OP_NULL, MPI_ERR_OP, op) - CHECK_WIN(9, win) CHECK_TARGET_DISP(5) CHECK_NULL(10, MPI_ERR_ARG, request) @@ -346,12 +345,11 @@ MPI_Datatype target_datatype, MPI_Op op, MPI_Win win){ CHECK_BUFFER(4, result_addr, result_count) CHECK_COUNT(5, result_count) CHECK_TYPE(6, result_datatype) - CHECK_PROC(7, target_rank) - CHECK_NEGATIVE(7, MPI_ERR_RANK, target_rank) + CHECK_WIN(12, win) + CHECK_PROC_RMA(7, target_rank, win) CHECK_COUNT(9, target_count) CHECK_TYPE(10, target_datatype) CHECK_MPI_NULL(11, MPI_OP_NULL, MPI_ERR_OP, op) - CHECK_WIN(12, win) CHECK_TARGET_DISP(8) int retval = 0; @@ -388,12 +386,11 @@ MPI_Datatype target_datatype, MPI_Op op, MPI_Win win, MPI_Request* request){ CHECK_BUFFER(4, result_addr, result_count) CHECK_COUNT(5, result_count) CHECK_TYPE(6, result_datatype) - CHECK_PROC(7, target_rank) - CHECK_NEGATIVE(7, MPI_ERR_RANK, target_rank) + CHECK_WIN(12, win) + CHECK_PROC_RMA(7, target_rank, win) CHECK_COUNT(9, target_count) CHECK_TYPE(10, target_datatype) CHECK_MPI_NULL(11, MPI_OP_NULL, MPI_ERR_OP, op) - CHECK_WIN(12, win) CHECK_TARGET_DISP(8) CHECK_NULL(10, MPI_ERR_ARG, request) int retval = 0; @@ -429,9 +426,8 @@ int PMPI_Compare_and_swap(const void* origin_addr, void* compare_addr, void* res CHECK_NULL(2, MPI_ERR_BUFFER, compare_addr) CHECK_NULL(3, MPI_ERR_BUFFER, result_addr) CHECK_TYPE(4, datatype) - CHECK_PROC(5, target_rank) - CHECK_NEGATIVE(5, MPI_ERR_RANK, target_rank) CHECK_WIN(6, win) + CHECK_PROC_RMA(5, target_rank, win) CHECK_TARGET_DISP(6) int retval = 0; @@ -501,8 +497,8 @@ int PMPI_Win_wait(MPI_Win win){ } int PMPI_Win_lock(int lock_type, int rank, int assert, MPI_Win win){ - CHECK_PROC(2, rank) CHECK_WIN(4, win) + CHECK_PROC_RMA(2, rank, win) int retval = 0; smpi_bench_end(); if (lock_type != MPI_LOCK_EXCLUSIVE && @@ -519,8 +515,8 @@ int PMPI_Win_lock(int lock_type, int rank, int assert, MPI_Win win){ } int PMPI_Win_unlock(int rank, MPI_Win win){ - CHECK_PROC(1, rank) CHECK_WIN(2, win) + CHECK_PROC_RMA(1, rank, win) smpi_bench_end(); int my_proc_id = simgrid::s4u::this_actor::get_pid(); TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("Win_unlock")); @@ -553,8 +549,8 @@ int PMPI_Win_unlock_all(MPI_Win win){ } int PMPI_Win_flush(int rank, MPI_Win win){ - CHECK_PROC(1, rank) CHECK_WIN(2, win) + CHECK_PROC_RMA(1, rank, win) smpi_bench_end(); int my_proc_id = simgrid::s4u::this_actor::get_pid(); TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("Win_flush")); @@ -565,8 +561,9 @@ int PMPI_Win_flush(int rank, MPI_Win win){ } int PMPI_Win_flush_local(int rank, MPI_Win win){ - CHECK_PROC(1, rank) - CHECK_WIN(2, win) smpi_bench_end(); + CHECK_WIN(2, win) + CHECK_PROC_RMA(1, rank, win) + smpi_bench_end(); int my_proc_id = simgrid::s4u::this_actor::get_pid(); TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("Win_flush_local")); int retval = win->flush_local(rank); diff --git a/src/smpi/include/private.hpp b/src/smpi/include/private.hpp index afd523a439..322d4f7328 100644 --- a/src/smpi/include/private.hpp +++ b/src/smpi/include/private.hpp @@ -556,13 +556,10 @@ XBT_PRIVATE void private_execute_flops(double flops); CHECK_MPI_NULL((num), MPI_OP_NULL, MPI_ERR_OP, (op)) \ CHECK_ARGS(((op)->allowed_types() && (((op)->allowed_types() & (type)->flags()) == 0)), MPI_ERR_OP, \ "%s: param %d op %s can't be applied to type %s", __func__, (num), _XBT_STRINGIFY(op), type->name()); - #define CHECK_ROOT(num)\ CHECK_ARGS((root < 0 || root >= comm->size()), MPI_ERR_ROOT, \ "%s: param %d root (=%d) cannot be negative or larger than communicator size (=%d)", __func__, (num), \ root, comm->size()); -#define CHECK_PROC(num,proc) \ - CHECK_MPI_NULL((num), MPI_PROC_NULL, MPI_SUCCESS, (proc)) #define CHECK_INFO(num,info) \ CHECK_MPI_NULL((num), MPI_INFO_NULL, MPI_ERR_INFO, (info)) #define CHECK_TAG(num,tag) \ @@ -577,7 +574,10 @@ XBT_PRIVATE void private_execute_flops(double flops); #define CHECK_WIN(num, win) \ CHECK_MPI_NULL((num), MPI_WIN_NULL, MPI_ERR_WIN, (win)) #define CHECK_RANK(num, rank, comm) \ - CHECK_ARGS(((rank) >= (comm)->group()->size() || (rank) <0), MPI_ERR_RANK, \ + CHECK_ARGS(((rank) >= (comm)->size() || (rank) <0), MPI_ERR_RANK, \ "%s: param %d %s (=%d) cannot be < 0 or > %d", __func__, (num), _XBT_STRINGIFY(rank), \ - (rank), (comm)->group()->size() ); + (rank), (comm)->size() ); +#define CHECK_PROC_RMA(num,proc,win) \ + CHECK_MPI_NULL((num), MPI_PROC_NULL, MPI_SUCCESS, (proc)) \ + CHECK_RANK(num, proc, win->comm()) #endif diff --git a/src/smpi/include/smpi_win.hpp b/src/smpi/include/smpi_win.hpp index 69b05a32b7..5022f6eeb4 100644 --- a/src/smpi/include/smpi_win.hpp +++ b/src/smpi/include/smpi_win.hpp @@ -57,6 +57,7 @@ public: void get_group( MPI_Group* group); void set_name(const char* name); int rank() const; + MPI_Comm comm() const; int dynamic() const; int start(MPI_Group group, int assert); int post(MPI_Group group, int assert); diff --git a/src/smpi/mpi/smpi_win.cpp b/src/smpi/mpi/smpi_win.cpp index b04f1b7c3f..9850c2a2f2 100644 --- a/src/smpi/mpi/smpi_win.cpp +++ b/src/smpi/mpi/smpi_win.cpp @@ -128,6 +128,11 @@ int Win::rank() const return rank_; } +MPI_Comm Win::comm() const +{ + return comm_; +} + MPI_Aint Win::size() const { return size_; -- 2.20.1