Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
SMPI: add utility to keep the order of collective calls performed by each process...
author Augustin Degomme <adegomme@users.noreply.github.com>
Sun, 20 Mar 2022 20:40:34 +0000 (21:40 +0100)
committer Augustin Degomme <adegomme@users.noreply.github.com>
Sun, 20 Mar 2022 20:40:34 +0000 (21:40 +0100)
It is not activated by default and requires the --cfg=smpi/pedantic:true option, as it may currently store too much data in memory.
For each communicator, we maintain a vector of the collective calls encountered so far.
Each process stores the number of collective calls it has performed in each communicator.
At each new collective call, we compare this number to the size of the corresponding vector:
if we are the first process to reach this call, we append it to the vector;
otherwise, we compare its name to the one at the corresponding position in the vector, and report a mismatch if they differ.
Kudos to mquinson and MBI for the idea.

src/smpi/include/smpi_comm.hpp
src/smpi/include/smpi_utils.hpp
src/smpi/internals/smpi_utils.cpp
src/smpi/mpi/smpi_comm.cpp

index 87ff9b2..82bf722 100644 (file)
@@ -42,6 +42,9 @@ class Comm : public F2C, public Keyval{
 
   std::unordered_map<std::string, unsigned int> sent_messages_;
   std::unordered_map<std::string, unsigned int> recv_messages_;
+  unsigned int collectives_count_ = 0;
+  unsigned int* collectives_counts_ = nullptr; //for MPI_COMM_WORLD only
+
 
 public:
   static std::unordered_map<int, smpi_key_elem> keyvals_;
@@ -97,6 +100,9 @@ public:
   void increment_sent_messages_count(int src, int dst, int tag);
   unsigned int get_received_messages_count(int src, int dst, int tag);
   void increment_received_messages_count(int src, int dst, int tag);
+  unsigned int get_collectives_count();
+  void increment_collectives_count();
+
 };
 
 } // namespace smpi
index a531494..1125ad1 100644 (file)
@@ -8,6 +8,7 @@
 #include <xbt/base.h>
 
 #include "smpi_f2c.hpp"
+#include "smpi_comm.hpp"
 
 #include <cstddef>
 #include <string>
@@ -35,6 +36,7 @@ XBT_PUBLIC void set_current_handle(F2C* handle);
 XBT_PUBLIC void set_current_buffer(int i, const char* name, const void* handle);
 XBT_PUBLIC size_t get_buffer_size(const void* ptr);
 XBT_PUBLIC void account_free(const void* ptr);
+XBT_PUBLIC int check_collectives_ordering(MPI_Comm comm, std::string call);
 
 } // namespace utils
 } // namespace smpi
index 837cedf..686bc14 100644 (file)
@@ -11,6 +11,7 @@
 #include "src/surf/xml/platf.hpp"
 #include "xbt/file.hpp"
 #include "xbt/log.h"
+#include "xbt/ex.h"
 #include "xbt/parse_units.hpp"
 #include "xbt/sysdep.h"
 #include <algorithm>
@@ -48,6 +49,8 @@ current_buffer_metadata_t current_buffer2;
 
 std::unordered_map<const void*, alloc_metadata_t> allocs;
 
+std::unordered_map<int, std::vector<std::string>> collective_calls;
+
 std::vector<s_smpi_factor_t> parse_factor(const std::string& smpi_coef_string)
 {
   std::vector<s_smpi_factor_t> smpi_factor;
@@ -345,6 +348,30 @@ void account_free(const void* ptr){
   }
 }
 
+int check_collectives_ordering(MPI_Comm comm, std::string call){
+  if(_smpi_cfg_pedantic){
+    unsigned int count = comm->get_collectives_count();
+    comm->increment_collectives_count();
+    auto vec = collective_calls.find(comm->id());
+    if (vec == collective_calls.end()) {
+      collective_calls.emplace(comm->id(), std::vector<std::string>{call});
+    }else{
+      //are we the first ? add the call
+      if (vec->second.size() == (count)){
+        vec->second.push_back(call);
+      } else if (vec->second.size() > count){
+        if (vec->second[count] != call){
+          XBT_WARN("Collective communication mismatch. For process %ld, expected %s, got %s", simgrid::s4u::this_actor::get_pid(), vec->second[count].c_str(), call.c_str());
+          return MPI_ERR_OTHER;
+        }
+      } else {
+        THROW_IMPOSSIBLE;
+      }
+    }
+  }
+  return MPI_SUCCESS;
+}
+
 }
 }
 } // namespace simgrid
index b482e21..84f604d 100644 (file)
@@ -365,6 +365,8 @@ void Comm::unref(Comm* comm){
       delete[] comm->errhandlers_;
     } else if (comm->errhandler_ != MPI_ERRHANDLER_NULL)
       simgrid::smpi::Errhandler::unref(comm->errhandler_);
+    if(comm->collectives_counts_!=nullptr)
+      delete[] comm->collectives_counts_;
   }
   Group::unref(comm->group_);
   if(comm->refcount_==0)
@@ -650,5 +652,31 @@ void Comm::increment_received_messages_count(int src, int dst, int tag)
   recv_messages_[hash_message(src, dst, tag)]++;
 }
 
+unsigned int Comm::get_collectives_count()
+{
+  if (this==MPI_COMM_UNINITIALIZED){
+    return smpi_process()->comm_world()->get_collectives_count();
+  }else if(this == MPI_COMM_WORLD || this == smpi_process()->comm_world()){
+    if(collectives_counts_==nullptr)
+      collectives_counts_=new unsigned int[this->size()]{0};
+    return collectives_counts_[this->rank()];
+  }else{
+    return collectives_count_;
+  }
+}
+
+void Comm::increment_collectives_count()
+{
+   if (this==MPI_COMM_UNINITIALIZED){
+    smpi_process()->comm_world()->increment_collectives_count();
+  }else if (this == MPI_COMM_WORLD || this == smpi_process()->comm_world()){
+    if(collectives_counts_==nullptr)
+      collectives_counts_=new unsigned int[this->size()]{0};
+    collectives_counts_[this->rank()]++;
+  }else{
+    collectives_count_++;
+  }
+}
+
 } // namespace smpi
 } // namespace simgrid