Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Check types matching in MPI communications.
authorAugustin Degomme <adegomme@gmail.com>
Wed, 26 May 2021 14:49:56 +0000 (16:49 +0200)
committerAugustin Degomme <adegomme@gmail.com>
Wed, 26 May 2021 14:50:12 +0000 (16:50 +0200)
should work with basic ones and duplicated ones.
todo: derived types matching, which is more tricky.

src/smpi/include/smpi_request.hpp
src/smpi/mpi/smpi_request.cpp

index b36aa2e..1449c54 100644 (file)
@@ -40,6 +40,7 @@ class Request : public F2C {
   aid_t real_src_;
   int real_tag_;
   bool truncated_;
+  bool unmatched_types_;
   size_t real_size_;
   MPI_Comm comm_;
   simgrid::kernel::activity::ActivityImplPtr action_;
@@ -51,6 +52,7 @@ class Request : public F2C {
   std::unique_ptr<smpi_mpi_generalized_request_funcs_t> generalized_funcs;
   std::vector<MPI_Request> nbc_requests_;
   static bool match_common(MPI_Request req, MPI_Request sender, MPI_Request receiver);
+  static bool match_types(MPI_Datatype stype, MPI_Datatype rtype);
 
 public:
   Request() = default;
index 5591e51..5ba3993 100644 (file)
@@ -57,6 +57,7 @@ Request::Request(const void* buf, int count, MPI_Datatype datatype, aid_t src, a
   detached_sender_ = nullptr;
   real_src_        = 0;
   truncated_       = false;
+  unmatched_types_ = false;
   real_size_       = 0;
   real_tag_        = 0;
   if (flags & MPI_REQ_PERSISTENT)
@@ -99,6 +100,23 @@ void Request::unref(MPI_Request* request)
   }
 }
 
+bool Request::match_types(MPI_Datatype stype, MPI_Datatype rtype){
+  bool match = false;
+  if ((stype == rtype) ||
+     //byte and packed always match with anything
+     (stype == MPI_PACKED || rtype == MPI_PACKED || stype == MPI_BYTE || rtype == MPI_BYTE) ||
+     //complex datatypes - we don't properly match these yet, as it would mean checking each subtype recursively.
+     (stype->flags() & DT_FLAG_DERIVED || rtype->flags() & DT_FLAG_DERIVED) ||
+     //duplicated datatypes, check if underlying is ok
+     (stype->duplicated_datatype()!=MPI_DATATYPE_NULL && match_types(stype->duplicated_datatype(), rtype)) ||
+     (rtype->duplicated_datatype()!=MPI_DATATYPE_NULL && match_types(stype, rtype->duplicated_datatype())))
+    match = true;
+  if (!match)
+    XBT_WARN("Mismatched datatypes : sending %s and receiving %s", stype->name().c_str(), rtype->name().c_str());
+  return match;
+}
+
+
 bool Request::match_common(MPI_Request req, MPI_Request sender, MPI_Request receiver)
 {
   xbt_assert(sender, "Cannot match against null sender");
@@ -125,6 +143,10 @@ bool Request::match_common(MPI_Request req, MPI_Request sender, MPI_Request rece
         receiver->real_size_=sender->real_size_;
       }
     }
+    //0-sized datatypes/counts should not interfere and match
+    if ( sender->real_size_ != 0 && receiver->real_size_ != 0 &&
+         !match_types(sender->type_, receiver->type_))
+      receiver->unmatched_types_ = true;
     if (sender->detached_)
       receiver->detached_sender_ = sender; // tie the sender to the receiver, as it is detached and has to be freed in
                                            // the receiver
@@ -961,17 +983,22 @@ void Request::finish_wait(MPI_Request* request, MPI_Status * status)
     req->action_ = nullptr;
   req->flags_ |= MPI_REQ_FINISHED;
 
-  if (req->truncated_) {
+  if (req->truncated_ || req->unmatched_types_) {
     char error_string[MPI_MAX_ERROR_STRING];
     int error_size;
-    PMPI_Error_string(MPI_ERR_TRUNCATE, error_string, &error_size);
+    int errkind;
+    if(req->truncated_ )
+      errkind = MPI_ERR_TRUNCATE;
+    else
+      errkind = MPI_ERR_TYPE;
+    PMPI_Error_string(errkind, error_string, &error_size);
     MPI_Errhandler err = (req->comm_) ? (req->comm_)->errhandler() : MPI_ERRHANDLER_NULL;
     if (err == MPI_ERRHANDLER_NULL || err == MPI_ERRORS_RETURN)
       XBT_WARN("recv - returned %.*s instead of MPI_SUCCESS", error_size, error_string);
     else if (err == MPI_ERRORS_ARE_FATAL)
       xbt_die("recv - returned %.*s instead of MPI_SUCCESS", error_size, error_string);
     else
-      err->call((req->comm_), MPI_ERR_TRUNCATE);
+      err->call((req->comm_), errkind);
     if (err != MPI_ERRHANDLER_NULL)
       simgrid::smpi::Errhandler::unref(err);
     MC_assert(not MC_is_active()); /* Only fail in MC mode */