Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
MPI_Iallreduce
authorAugustin Degomme <adegomme@users.noreply.github.com>
Mon, 1 Apr 2019 21:43:26 +0000 (23:43 +0200)
committerAugustin Degomme <adegomme@users.noreply.github.com>
Mon, 1 Apr 2019 22:46:45 +0000 (00:46 +0200)
src/smpi/bindings/smpi_pmpi_coll.cpp
src/smpi/colls/smpi_nbc_impl.cpp
src/smpi/mpi/smpi_request.cpp
teshsuite/smpi/mpich3-test/coll/nonblocking.c
teshsuite/smpi/mpich3-test/coll/nonblocking2.c

index bd0ac32..cc86aec 100644 (file)
@@ -455,11 +455,7 @@ int PMPI_Iallreduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype dataty
     retval = MPI_ERR_TYPE;
   } else if (op == MPI_OP_NULL) {
     retval = MPI_ERR_OP;
-  } else if (request != MPI_REQUEST_IGNORED) {
-    xbt_die("Iallreduce is not yet implemented. WIP");
-    retval = MPI_ERR_ARG;
   } else {
-
     char* sendtmpbuf = static_cast<char*>(sendbuf);
     if( sendbuf == MPI_IN_PLACE ) {
       sendtmpbuf = static_cast<char*>(xbt_malloc(count*datatype->get_extent()));
@@ -472,10 +468,10 @@ int PMPI_Iallreduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype dataty
                                                       datatype->is_replayable() ? count : count * datatype->size(), -1,
                                                       simgrid::smpi::Datatype::encode(datatype), ""));
 
-//    if(request == MPI_REQUEST_IGNORED)
+    if(request == MPI_REQUEST_IGNORED)
       simgrid::smpi::Colls::allreduce(sendtmpbuf, recvbuf, count, datatype, op, comm);
-//    else
-//      simgrid::smpi::Colls::iallreduce(sendtmpbuf, recvbuf, count, datatype, op, comm, request);
+    else
+      simgrid::smpi::Colls::iallreduce(sendtmpbuf, recvbuf, count, datatype, op, comm, request);
 
     if( sendbuf == MPI_IN_PLACE )
       xbt_free(sendtmpbuf);
index 0cb25f0..8e9b16d 100644 (file)
@@ -501,5 +501,40 @@ int Colls::ireduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatyp
   }
   return MPI_SUCCESS;
 }
+
+int Colls::iallreduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
+                      MPI_Op op, MPI_Comm comm, MPI_Request* request)
+{
+
+  const int system_tag = COLL_TAG_ALLREDUCE;
+  MPI_Aint lb = 0;
+  MPI_Aint dataext = 0;
+  MPI_Request *requests;
+
+  int rank = comm->rank();
+  int size = comm->size();
+  (*request) = new Request( recvbuf, count, datatype,
+                         rank,rank, COLL_TAG_ALLREDUCE, comm, MPI_REQ_PERSISTENT, op);
+  // FIXME: check for errors
+  datatype->extent(&lb, &dataext);
+  // Local copy from self
+  Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);
+  // Send/Recv buffers to/from others;
+  requests = new MPI_Request[2 * (size - 1)];
+  int index = 0;
+  for (int other = 0; other < size; other++) {
+    if(other != rank) {
+      requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag,comm);
+      index++;
+      requests[index] = Request::irecv_init(smpi_get_tmp_sendbuffer(count * dataext), count, datatype,
+                                        other, system_tag, comm);
+      index++;
+    }
+  }
+  Request::startall(2 * (size - 1), requests);
+  (*request)->set_nbc_requests(requests, 2 * (size - 1));
+  return MPI_SUCCESS;
+}
+
 }
 }
index 5cf3a49..47b37fe 100644 (file)
@@ -871,11 +871,13 @@ int Request::wait(MPI_Request * request, MPI_Status * status)
         void * buf=(*request)->nbc_requests_[i]->buf_;
         if((*request)->old_type_->flags() & DT_FLAG_DERIVED)
           buf=(*request)->nbc_requests_[i]->old_buf_;
-        if((*request)->op_!=MPI_OP_NULL){
-          int count=(*request)->size_/ (*request)->old_type_->size();
-          (*request)->op_->apply(buf, (*request)->buf_, &count, (*request)->old_type_);
+        if((*request)->nbc_requests_[i]->flags_ & MPI_REQ_RECV ){
+          if((*request)->op_!=MPI_OP_NULL){
+            int count=(*request)->size_/ (*request)->old_type_->size();
+            (*request)->op_->apply(buf, (*request)->buf_, &count, (*request)->old_type_);
+          }
+          smpi_free_tmp_buffer(buf);
         }
-        smpi_free_tmp_buffer(buf);
       }
       if((*request)->nbc_requests_[i]!=MPI_REQUEST_NULL)
         Request::unref(&((*request)->nbc_requests_[i]));
index 0038da1..94938e5 100644 (file)
@@ -162,11 +162,11 @@ int main(int argc, char **argv)
         MPI_Ireduce(sbuf, rbuf, NUM_INTS, MPI_INT, MPI_SUM, 0, comm, &req);
     MPI_Wait(&req, MPI_STATUS_IGNORE);
 
-/*    MPI_Iallreduce(sbuf, rbuf, NUM_INTS, MPI_INT, MPI_SUM, comm, &req);*/
-/*    MPI_Wait(&req, MPI_STATUS_IGNORE);*/
+    MPI_Iallreduce(sbuf, rbuf, NUM_INTS, MPI_INT, MPI_SUM, comm, &req);
+    MPI_Wait(&req, MPI_STATUS_IGNORE);
 
-/*    MPI_Iallreduce(MPI_IN_PLACE, rbuf, NUM_INTS, MPI_INT, MPI_SUM, comm, &req);*/
-/*    MPI_Wait(&req, MPI_STATUS_IGNORE);*/
+    MPI_Iallreduce(MPI_IN_PLACE, rbuf, NUM_INTS, MPI_INT, MPI_SUM, comm, &req);
+    MPI_Wait(&req, MPI_STATUS_IGNORE);
 
 /*    MPI_Ireduce_scatter(sbuf, rbuf, rcounts, MPI_INT, MPI_SUM, comm, &req);*/
 /*    MPI_Wait(&req, MPI_STATUS_IGNORE);*/
index 72d774f..f266f54 100644 (file)
@@ -117,7 +117,7 @@ int main(int argc, char **argv)
     if (rank == 0) {
         for (i = 0; i < COUNT; ++i) {
             if (recvbuf[i] != ((size * (size - 1) / 2) + (i * size)))
-                printf("aa got recvbuf[%d]=%d, expected %d\n", i, recvbuf[i],
+                printf("got recvbuf[%d]=%d, expected %d\n", i, recvbuf[i],
                        ((size * (size - 1) / 2) + (i * size)));
             my_assert(recvbuf[i] == ((size * (size - 1) / 2) + (i * size)));
         }
@@ -145,18 +145,18 @@ int main(int argc, char **argv)
     }
 
     /* MPI_Iallreduce */
-/*    for (i = 0; i < COUNT; ++i) {*/
-/*        buf[i] = rank + i;*/
-/*        recvbuf[i] = 0xdeadbeef;*/
-/*    }*/
-/*    MPI_Iallreduce(buf, recvbuf, COUNT, MPI_INT, MPI_SUM, MPI_COMM_WORLD, &req);*/
-/*    MPI_Wait(&req, MPI_STATUS_IGNORE);*/
-/*    for (i = 0; i < COUNT; ++i) {*/
-/*        if (recvbuf[i] != ((size * (size - 1) / 2) + (i * size)))*/
-/*            printf("got recvbuf[%d]=%d, expected %d\n", i, recvbuf[i],*/
-/*                   ((size * (size - 1) / 2) + (i * size)));*/
-/*        my_assert(recvbuf[i] == ((size * (size - 1) / 2) + (i * size)));*/
-/*    }*/
+    for (i = 0; i < COUNT; ++i) {
+        buf[i] = rank + i;
+        recvbuf[i] = 0xdeadbeef;
+    }
+    MPI_Iallreduce(buf, recvbuf, COUNT, MPI_INT, MPI_SUM, MPI_COMM_WORLD, &req);
+    MPI_Wait(&req, MPI_STATUS_IGNORE);
+    for (i = 0; i < COUNT; ++i) {
+        if (recvbuf[i] != ((size * (size - 1) / 2) + (i * size)))
+            printf("got recvbuf[%d]=%d, expected %d\n", i, recvbuf[i],
+                   ((size * (size - 1) / 2) + (i * size)));
+        my_assert(recvbuf[i] == ((size * (size - 1) / 2) + (i * size)));
+    }
 
     /* MPI_Ialltoallv (a weak test, neither irregular nor sparse) */
     for (i = 0; i < size; ++i) {