Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
increase rank checking for RMA comms
authorAugustin Degomme <adegomme@users.noreply.github.com>
Sun, 4 Apr 2021 15:10:53 +0000 (17:10 +0200)
committerAugustin Degomme <adegomme@users.noreply.github.com>
Sun, 4 Apr 2021 15:10:53 +0000 (17:10 +0200)
src/smpi/bindings/smpi_pmpi_win.cpp
src/smpi/include/private.hpp
src/smpi/include/smpi_win.hpp
src/smpi/mpi/smpi_win.cpp

index e886154..9587422 100644 (file)
@@ -17,8 +17,7 @@ XBT_LOG_EXTERNAL_DEFAULT_CATEGORY(smpi_pmpi);
   CHECK_BUFFER(1, origin_addr, origin_count)\
   CHECK_COUNT(2, origin_count)\
   CHECK_TYPE(3, origin_datatype)\
-  CHECK_PROC(4, target_rank)\
-  CHECK_NEGATIVE(4, MPI_ERR_RANK, target_rank)\
+  CHECK_PROC_RMA(4, target_rank, win)\
   CHECK_COUNT(6, target_count)\
   CHECK_TYPE(7, target_datatype)
 
@@ -163,8 +162,8 @@ int PMPI_Win_fence( int assert,  MPI_Win win){
 
 int PMPI_Get( void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank,
               MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Win win){
-  CHECK_RMA
   CHECK_WIN(8, win)
+  CHECK_RMA
   CHECK_TARGET_DISP(5)
 
   int retval = 0;
@@ -190,8 +189,8 @@ int PMPI_Rget( void *origin_addr, int origin_count, MPI_Datatype origin_datatype
               MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Win win, MPI_Request* request){
   if(target_rank==MPI_PROC_NULL)
     *request = MPI_REQUEST_NULL;
-  CHECK_RMA
   CHECK_WIN(8, win)
+  CHECK_RMA
   CHECK_TARGET_DISP(5)
   CHECK_NULL(9, MPI_ERR_ARG, request)
 
@@ -218,8 +217,8 @@ int PMPI_Rget( void *origin_addr, int origin_count, MPI_Datatype origin_datatype
 
 int PMPI_Put(const void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank,
               MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Win win){
-  CHECK_RMA
   CHECK_WIN(8, win)
+  CHECK_RMA
   CHECK_TARGET_DISP(5)
 
   int retval = 0;
@@ -249,8 +248,8 @@ int PMPI_Rput(const void *origin_addr, int origin_count, MPI_Datatype origin_dat
               MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Win win, MPI_Request* request){
   if(target_rank==MPI_PROC_NULL)
     *request = MPI_REQUEST_NULL;
-  CHECK_RMA
   CHECK_WIN(8, win)
+  CHECK_RMA
   CHECK_TARGET_DISP(5)
   CHECK_NULL(9, MPI_ERR_ARG, request)
   int retval = 0;
@@ -278,9 +277,9 @@ int PMPI_Rput(const void *origin_addr, int origin_count, MPI_Datatype origin_dat
 
 int PMPI_Accumulate(const void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank,
               MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Op op, MPI_Win win){
+  CHECK_WIN(9, win)
   CHECK_RMA
   CHECK_MPI_NULL(8, MPI_OP_NULL, MPI_ERR_OP, op)
-  CHECK_WIN(9, win)
   CHECK_TARGET_DISP(5)
 
   int retval = 0;
@@ -307,9 +306,9 @@ int PMPI_Raccumulate(const void *origin_addr, int origin_count, MPI_Datatype ori
               MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Op op, MPI_Win win, MPI_Request* request){
   if(target_rank==MPI_PROC_NULL)
     *request = MPI_REQUEST_NULL;
+  CHECK_WIN(9, win)
   CHECK_RMA
   CHECK_MPI_NULL(8, MPI_OP_NULL, MPI_ERR_OP, op)
-  CHECK_WIN(9, win)
   CHECK_TARGET_DISP(5)
   CHECK_NULL(10, MPI_ERR_ARG, request)
 
@@ -346,12 +345,11 @@ MPI_Datatype target_datatype, MPI_Op op, MPI_Win win){
   CHECK_BUFFER(4, result_addr, result_count)
   CHECK_COUNT(5, result_count)
   CHECK_TYPE(6, result_datatype)
-  CHECK_PROC(7, target_rank)
-  CHECK_NEGATIVE(7, MPI_ERR_RANK, target_rank)
+  CHECK_WIN(12, win)
+  CHECK_PROC_RMA(7, target_rank, win)
   CHECK_COUNT(9, target_count)
   CHECK_TYPE(10, target_datatype)
   CHECK_MPI_NULL(11, MPI_OP_NULL, MPI_ERR_OP, op)
-  CHECK_WIN(12, win)
   CHECK_TARGET_DISP(8)
 
   int retval = 0;
@@ -388,12 +386,11 @@ MPI_Datatype target_datatype, MPI_Op op, MPI_Win win, MPI_Request* request){
   CHECK_BUFFER(4, result_addr, result_count)
   CHECK_COUNT(5, result_count)
   CHECK_TYPE(6, result_datatype)
-  CHECK_PROC(7, target_rank)
-  CHECK_NEGATIVE(7, MPI_ERR_RANK, target_rank)
+  CHECK_WIN(12, win)
+  CHECK_PROC_RMA(7, target_rank, win)
   CHECK_COUNT(9, target_count)
   CHECK_TYPE(10, target_datatype)
   CHECK_MPI_NULL(11, MPI_OP_NULL, MPI_ERR_OP, op)
-  CHECK_WIN(12, win)
   CHECK_TARGET_DISP(8)
   CHECK_NULL(10, MPI_ERR_ARG, request)
   int retval = 0;
@@ -429,9 +426,8 @@ int PMPI_Compare_and_swap(const void* origin_addr, void* compare_addr, void* res
   CHECK_NULL(2, MPI_ERR_BUFFER, compare_addr)
   CHECK_NULL(3, MPI_ERR_BUFFER, result_addr)
   CHECK_TYPE(4, datatype)
-  CHECK_PROC(5, target_rank)
-  CHECK_NEGATIVE(5, MPI_ERR_RANK, target_rank)
   CHECK_WIN(6, win)
+  CHECK_PROC_RMA(5, target_rank, win)
   CHECK_TARGET_DISP(6)
 
   int retval = 0;
@@ -501,8 +497,8 @@ int PMPI_Win_wait(MPI_Win win){
 }
 
 int PMPI_Win_lock(int lock_type, int rank, int assert, MPI_Win win){
-  CHECK_PROC(2, rank)
   CHECK_WIN(4, win)
+  CHECK_PROC_RMA(2, rank, win)
   int retval = 0;
   smpi_bench_end();
   if (lock_type != MPI_LOCK_EXCLUSIVE &&
@@ -519,8 +515,8 @@ int PMPI_Win_lock(int lock_type, int rank, int assert, MPI_Win win){
 }
 
 int PMPI_Win_unlock(int rank, MPI_Win win){
-  CHECK_PROC(1, rank)
   CHECK_WIN(2, win)
+  CHECK_PROC_RMA(1, rank, win)
   smpi_bench_end();
   int my_proc_id = simgrid::s4u::this_actor::get_pid();
   TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("Win_unlock"));
@@ -553,8 +549,8 @@ int PMPI_Win_unlock_all(MPI_Win win){
 }
 
 int PMPI_Win_flush(int rank, MPI_Win win){
-  CHECK_PROC(1, rank)
   CHECK_WIN(2, win)
+  CHECK_PROC_RMA(1, rank, win)
   smpi_bench_end();
   int my_proc_id = simgrid::s4u::this_actor::get_pid();
   TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("Win_flush"));
@@ -565,8 +561,9 @@ int PMPI_Win_flush(int rank, MPI_Win win){
 }
 
 int PMPI_Win_flush_local(int rank, MPI_Win win){
-  CHECK_PROC(1, rank)
-  CHECK_WIN(2, win)  smpi_bench_end();
+  CHECK_WIN(2, win)
+  CHECK_PROC_RMA(1, rank, win)
+  smpi_bench_end();
   int my_proc_id = simgrid::s4u::this_actor::get_pid();
   TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("Win_flush_local"));
   int retval = win->flush_local(rank);
index afd523a..322d4f7 100644 (file)
@@ -556,13 +556,10 @@ XBT_PRIVATE void private_execute_flops(double flops);
   CHECK_MPI_NULL((num), MPI_OP_NULL, MPI_ERR_OP, (op))                                                                 \
   CHECK_ARGS(((op)->allowed_types() && (((op)->allowed_types() & (type)->flags()) == 0)), MPI_ERR_OP,                \
              "%s: param %d op %s can't be applied to type %s", __func__, (num), _XBT_STRINGIFY(op), type->name());
-
 #define CHECK_ROOT(num)\
   CHECK_ARGS((root < 0 || root >= comm->size()), MPI_ERR_ROOT,                                                         \
              "%s: param %d root (=%d) cannot be negative or larger than communicator size (=%d)", __func__, (num),     \
              root, comm->size());
-#define CHECK_PROC(num,proc)                                                                                           \
-  CHECK_MPI_NULL((num), MPI_PROC_NULL, MPI_SUCCESS, (proc))
 #define CHECK_INFO(num,info)                                                                                           \
   CHECK_MPI_NULL((num), MPI_INFO_NULL, MPI_ERR_INFO, (info))
 #define CHECK_TAG(num,tag)                                                                                             \
@@ -577,7 +574,10 @@ XBT_PRIVATE void private_execute_flops(double flops);
 #define CHECK_WIN(num, win)                                                                                            \
   CHECK_MPI_NULL((num), MPI_WIN_NULL, MPI_ERR_WIN, (win))
 #define CHECK_RANK(num, rank, comm)                                                                                    \
-  CHECK_ARGS(((rank) >= (comm)->group()->size() || (rank) <0), MPI_ERR_RANK,                                           \
+  CHECK_ARGS(((rank) >= (comm)->size() || (rank) <0), MPI_ERR_RANK,                                                    \
              "%s: param %d %s (=%d) cannot be < 0 or > %d", __func__, (num), _XBT_STRINGIFY(rank),                     \
-             (rank), (comm)->group()->size() );
+             (rank), (comm)->size() );
+#define CHECK_PROC_RMA(num,proc,win)                                                                                   \
+  CHECK_MPI_NULL((num), MPI_PROC_NULL, MPI_SUCCESS, (proc))                                                            \
+  CHECK_RANK(num, proc, win->comm())
 #endif
index 69b05a3..5022f6e 100644 (file)
@@ -57,6 +57,7 @@ public:
   void get_group( MPI_Group* group);
   void set_name(const char* name);
   int rank() const;
+  MPI_Comm comm() const;
   int dynamic() const;
   int start(MPI_Group group, int assert);
   int post(MPI_Group group, int assert);
index b04f1b7..9850c2a 100644 (file)
@@ -128,6 +128,11 @@ int Win::rank() const
   return rank_;
 }
 
+MPI_Comm Win::comm() const
+{
+  return comm_;
+}
+
 MPI_Aint Win::size() const
 {
   return size_;