Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Fix previous fix.
authorAugustin Degomme <adegomme@users.noreply.github.com>
Fri, 7 May 2021 13:45:16 +0000 (15:45 +0200)
committerAugustin Degomme <adegomme@users.noreply.github.com>
Fri, 7 May 2021 13:45:32 +0000 (15:45 +0200)
It was turning test into wait for non blocking collectives, which is not what we wanted.

src/smpi/include/smpi_request.hpp
src/smpi/mpi/smpi_request.cpp

index 8027533..19bc383 100644 (file)
@@ -73,7 +73,7 @@ public:
   void init_buffer(int count);
   void ref();
   void set_nbc_requests(MPI_Request* reqs, int size);
-  static int finish_nbc_requests(MPI_Request* req);
+  static int finish_nbc_requests(MPI_Request* req, int test);
   int get_nbc_requests_size() const;
   MPI_Request* get_nbc_requests() const;
   static void finish_wait(MPI_Request* request, MPI_Status* status);
index b782759..a5f4ea8 100644 (file)
@@ -604,6 +604,11 @@ int Request::test(MPI_Request * request, MPI_Status * status, int* flag) {
 
   Status::empty(status);
   *flag = 1;
+
+  if ((*request)->flags_ & MPI_REQ_NBC){
+    *flag = finish_nbc_requests(request, 1);
+  }
+
   if (((*request)->flags_ & (MPI_REQ_PREPARED | MPI_REQ_FINISHED)) == 0) {
     if ((*request)->action_ != nullptr && ((*request)->flags_ & MPI_REQ_CANCELLED) == 0){
       try{
@@ -720,7 +725,11 @@ int Request::testany(int count, MPI_Request requests[], int *index, int* flag, M
         ret=(requests[*index]->generalized_funcs)->query_fn((requests[*index]->generalized_funcs)->extra_state, mystatus);
       }
 
-        if (requests[*index] != MPI_REQUEST_NULL && (requests[*index]->flags_ & MPI_REQ_NON_PERSISTENT)) 
+      if (requests[*index] != MPI_REQUEST_NULL && requests[*index]->flags_ & MPI_REQ_NBC){
+        *flag = finish_nbc_requests(&requests[*index] , 1);
+      }
+
+      if (requests[*index] != MPI_REQUEST_NULL && (requests[*index]->flags_ & MPI_REQ_NON_PERSISTENT))
           requests[*index] = MPI_REQUEST_NULL;
         XBT_DEBUG("Testany - returning with index %d", *index);
         *flag=1;
@@ -752,7 +761,6 @@ int Request::testall(int count, MPI_Request requests[], int* outflag, MPI_Status
       int ret = test(&requests[i], pstat, &flag);
       if (flag){
         flag=0;
-        requests[i]=MPI_REQUEST_NULL;
       }else{
         *outflag=0;
       }
@@ -844,28 +852,38 @@ void Request::iprobe(int source, int tag, MPI_Comm comm, int* flag, MPI_Status*
   xbt_assert(request == MPI_REQUEST_NULL);
 }
 
-int Request::finish_nbc_requests(MPI_Request* request){
-  int ret = waitall((*request)->nbc_requests_size_, (*request)->nbc_requests_, MPI_STATUSES_IGNORE);
-  XBT_DEBUG("finish non blocking collective request with %d sub-requests", (*request)->nbc_requests_size_);
-  for (int i = 0; i < (*request)->nbc_requests_size_; i++) {
-    if((*request)->buf_!=nullptr && (*request)->nbc_requests_[i]!=MPI_REQUEST_NULL){//reduce case
-      void * buf=(*request)->nbc_requests_[i]->buf_;
-      if((*request)->old_type_->flags() & DT_FLAG_DERIVED)
-        buf=(*request)->nbc_requests_[i]->old_buf_;
-      if((*request)->nbc_requests_[i]->flags_ & MPI_REQ_RECV ){
-        if((*request)->op_!=MPI_OP_NULL){
-          int count=(*request)->size_/ (*request)->old_type_->size();
-          (*request)->op_->apply(buf, (*request)->buf_, &count, (*request)->old_type_);
+int Request::finish_nbc_requests(MPI_Request* request, int test){
+  int flag = 1;
+  int ret = 0;
+  if(test == 0)
+    ret = waitall((*request)->nbc_requests_size_, (*request)->nbc_requests_, MPI_STATUSES_IGNORE);
+  else{
+    ret = testall((*request)->nbc_requests_size_, (*request)->nbc_requests_, &flag, MPI_STATUSES_IGNORE);
+  }
+  if(ret!=MPI_SUCCESS)
+    xbt_die("Failure when waiting on non blocking collective sub-requests");
+  if(flag == 1){
+    XBT_DEBUG("Finishing non blocking collective request with %d sub-requests", (*request)->nbc_requests_size_);
+    for (int i = 0; i < (*request)->nbc_requests_size_; i++) {
+      if((*request)->buf_!=nullptr && (*request)->nbc_requests_[i]!=MPI_REQUEST_NULL){//reduce case
+        void * buf=(*request)->nbc_requests_[i]->buf_;
+        if((*request)->old_type_->flags() & DT_FLAG_DERIVED)
+          buf=(*request)->nbc_requests_[i]->old_buf_;
+        if((*request)->nbc_requests_[i]->flags_ & MPI_REQ_RECV ){
+          if((*request)->op_!=MPI_OP_NULL){
+            int count=(*request)->size_/ (*request)->old_type_->size();
+            (*request)->op_->apply(buf, (*request)->buf_, &count, (*request)->old_type_);
+          }
+          smpi_free_tmp_buffer(static_cast<unsigned char*>(buf));
         }
-        smpi_free_tmp_buffer(static_cast<unsigned char*>(buf));
       }
+      if((*request)->nbc_requests_[i]!=MPI_REQUEST_NULL)
+        Request::unref(&((*request)->nbc_requests_[i]));
     }
-    if((*request)->nbc_requests_[i]!=MPI_REQUEST_NULL)
-      Request::unref(&((*request)->nbc_requests_[i]));
+    delete[] (*request)->nbc_requests_;
+    (*request)->nbc_requests_size_=0;
   }
-  delete[] (*request)->nbc_requests_;
-  (*request)->nbc_requests_size_=0;
-  return ret;
+  return flag;
 }
 
 void Request::finish_wait(MPI_Request* request, MPI_Status * status)
@@ -881,12 +899,6 @@ void Request::finish_wait(MPI_Request* request, MPI_Status * status)
     return;
   }
 
-  if ((*request)->flags() & MPI_REQ_NBC){
-    int ret = finish_nbc_requests(request);
-    if (ret != MPI_SUCCESS)
-      xbt_die("error when finishing non blocking collective requests");
-  }
-
   if ((req->flags_ & (MPI_REQ_PREPARED | MPI_REQ_GENERALIZED | MPI_REQ_FINISHED)) == 0) {
     if (status != MPI_STATUS_IGNORE) {
       if (req->src_== MPI_PROC_NULL || req->dst_== MPI_PROC_NULL){
@@ -1023,6 +1035,9 @@ int Request::wait(MPI_Request * request, MPI_Status * status)
   if ((*request)->truncated_)
     ret = MPI_ERR_TRUNCATE;
 
+  if ((*request)->flags_ & MPI_REQ_NBC)
+    finish_nbc_requests(request, 0);
+
   finish_wait(request, status); // may invalidate *request
   if (*request != MPI_REQUEST_NULL && (((*request)->flags_ & MPI_REQ_NON_PERSISTENT) != 0))
     *request = MPI_REQUEST_NULL;
@@ -1049,6 +1064,8 @@ int Request::waitany(int count, MPI_Request requests[], MPI_Status * status)
           // This is a finished detached request, let's return this one
           comms.clear(); // don't do the waitany call afterwards
           index = i;
+          if (requests[index] != MPI_REQUEST_NULL && (requests[index])->flags_ & MPI_REQ_NBC)
+            finish_nbc_requests(&requests[index], 0);
           finish_wait(&requests[i], status); // cleanup if refcount = 0
           if (requests[i] != MPI_REQUEST_NULL && (requests[i]->flags_ & MPI_REQ_NON_PERSISTENT))
             requests[i] = MPI_REQUEST_NULL; // set to null
@@ -1080,6 +1097,7 @@ int Request::waitany(int count, MPI_Request requests[], MPI_Status * status)
     }
   }
 
+
   if (index==MPI_UNDEFINED)
     Status::empty(status);