From 1409071f3f26af4879d31ed7e5fdb3df445638fc Mon Sep 17 00:00:00 2001 From: Augustin Degomme Date: Sun, 30 Jan 2022 21:55:21 +0100 Subject: [PATCH] SMPI: enforce MPI message ordering. Fix #100 (hopefully) This adds an id for each message, representing the number of messages sent from one process to another with the same tag. On the receiver side, the number of received messages from each source/tag must be kept the same way and is compared to the message id. If the count is not identical, the message is early and overtook another message, so don't match with it yet. Hopefully this does not cause too much memory overhead, but it can be quite a lot in some cases. TODO: clean up the way probes are handled --- src/smpi/include/smpi_comm.hpp | 7 ++++++ src/smpi/include/smpi_request.hpp | 1 + src/smpi/mpi/smpi_comm.cpp | 26 +++++++++++++++++++++- src/smpi/mpi/smpi_request.cpp | 37 +++++++++++++++++++++++++++---- 4 files changed, 66 insertions(+), 5 deletions(-) diff --git a/src/smpi/include/smpi_comm.hpp b/src/smpi/include/smpi_comm.hpp index 7ef18c9912..d8dd8e85d9 100644 --- a/src/smpi/include/smpi_comm.hpp +++ b/src/smpi/include/smpi_comm.hpp @@ -40,6 +40,8 @@ class Comm : public F2C, public Keyval{ MPI_Errhandler errhandler_ = _smpi_cfg_default_errhandler_is_error ? 
MPI_ERRORS_ARE_FATAL : MPI_ERRORS_RETURN;; MPI_Errhandler* errhandlers_ = nullptr; //for MPI_COMM_WORLD only + std::unordered_map sent_messages_; + std::unordered_map recv_messages_; public: static std::unordered_map keyvals_; static int keyval_id_; @@ -90,6 +92,11 @@ public: void remove_rma_win(MPI_Win win); void finish_rma_calls() const; MPI_Comm split_type(int type, int key, const Info* info); + unsigned int get_sent_messages_count(int src, int dst, int tag); + void increment_sent_messages_count(int src, int dst, int tag); + unsigned int get_received_messages_count(int src, int dst, int tag); + void increment_received_messages_count(int src, int dst, int tag); + }; } // namespace smpi diff --git a/src/smpi/include/smpi_request.hpp b/src/smpi/include/smpi_request.hpp index e3a30898a3..a91b6ab948 100644 --- a/src/smpi/include/smpi_request.hpp +++ b/src/smpi/include/smpi_request.hpp @@ -48,6 +48,7 @@ class Request : public F2C { bool detached_; MPI_Request detached_sender_; int refcount_; + unsigned int message_id_; MPI_Op op_; std::unique_ptr generalized_funcs; std::vector nbc_requests_; diff --git a/src/smpi/mpi/smpi_comm.cpp b/src/smpi/mpi/smpi_comm.cpp index 193d88b8ef..683a64a350 100644 --- a/src/smpi/mpi/smpi_comm.cpp +++ b/src/smpi/mpi/smpi_comm.cpp @@ -360,7 +360,7 @@ void Comm::unref(Comm* comm){ simgrid::smpi::Info::unref(comm->info_); if(comm->errhandlers_!=nullptr){ for (int i=0; isize(); i++) - if (comm->errhandlers_[i]!=MPI_ERRHANDLER_NULL) + if (comm->errhandlers_[i]!=MPI_ERRHANDLER_NULL) simgrid::smpi::Errhandler::unref(comm->errhandlers_[i]); delete[] comm->errhandlers_; } else if (comm->errhandler_ != MPI_ERRHANDLER_NULL) @@ -630,5 +630,29 @@ MPI_Comm Comm::split_type(int type, int /*key*/, const Info*) } } +static inline std::string hash_message(int src, int dst, int tag){ + return std::to_string(tag) + '_' + std::to_string(src) + '_' + std::to_string(dst); +} + +unsigned int Comm::get_sent_messages_count(int src, int dst, int tag) +{ + return 
sent_messages_[hash_message(src, dst, tag)]; +} + +void Comm::increment_sent_messages_count(int src, int dst, int tag) +{ + sent_messages_[hash_message(src, dst, tag)]++; +} + +unsigned int Comm::get_received_messages_count(int src, int dst, int tag) +{ + return recv_messages_[hash_message(src, dst, tag)]; +} + +void Comm::increment_received_messages_count(int src, int dst, int tag) +{ + recv_messages_[hash_message(src, dst, tag)]++; +} + } // namespace smpi } // namespace simgrid diff --git a/src/smpi/mpi/smpi_request.cpp b/src/smpi/mpi/smpi_request.cpp index 2b54eacf68..07b98cb826 100644 --- a/src/smpi/mpi/smpi_request.cpp +++ b/src/smpi/mpi/smpi_request.cpp @@ -69,6 +69,7 @@ Request::Request(const void* buf, int count, MPI_Datatype datatype, aid_t src, a refcount_ = 1; else refcount_ = 0; + message_id_ = 0; init_buffer(count); this->add_f(); } @@ -146,8 +147,6 @@ bool Request::match_common(MPI_Request req, MPI_Request sender, MPI_Request rece if (receiver->real_size_ < sender->real_size_){ XBT_DEBUG("Truncating message - should not happen: receiver size : %zu < sender size : %zu", receiver->real_size_, sender->real_size_); receiver->truncated_ = true; - } else if (receiver->real_size_ > sender->real_size_){ - receiver->real_size_=sender->real_size_; } } //0-sized datatypes/counts should not interfere and match @@ -186,7 +185,24 @@ bool Request::match_recv(void* a, void* b, simgrid::kernel::activity::CommImpl*) { auto ref = static_cast(a); auto req = static_cast(b); - return match_common(req, req, ref); + bool match = match_common(req, req, ref); + if (match && (ref->comm_ != MPI_COMM_UNINITIALIZED) && !ref->comm_->is_smp_comm()){ + if (ref->comm_->get_received_messages_count(ref->comm_->group()->rank(req->src_), ref->comm_->group()->rank(req->dst_), req->tag_) == req->message_id_ ){ + if (((ref->flags_ & MPI_REQ_PROBE) == 0 ) && ((req->flags_ & MPI_REQ_PROBE) == 0)){ + XBT_DEBUG("increasing count in comm %p, which was %u from pid %ld, to pid %ld with tag %d", 
ref->comm_, ref->comm_->get_received_messages_count(ref->comm_->group()->rank(req->src_), ref->comm_->group()->rank(req->dst_), req->tag_), req->src_, req->dst_, req->tag_); + ref->comm_->increment_received_messages_count(ref->comm_->group()->rank(req->src_), ref->comm_->group()->rank(req->dst_), req->tag_); + if (ref->real_size_ > req->real_size_){ + ref->real_size_=req->real_size_; + } + } + } else { + match = false; + req->flags_ &= ~MPI_REQ_MATCHED; + ref->detached_sender_=nullptr; + XBT_DEBUG("Refusing to match message, as its ID is not the one I expect. in comm %p, %u != %u, from pid %ld to pid %ld, with tag %d",ref->comm_, ref->comm_->get_received_messages_count(ref->comm_->group()->rank(req->src_), ref->comm_->group()->rank(req->dst_), req->tag_), req->message_id_ , req->src_, req->dst_, req->tag_); + } + } + return match; } bool Request::match_send(void* a, void* b, simgrid::kernel::activity::CommImpl*) @@ -444,6 +460,9 @@ void Request::start() if (smpi_cfg_async_small_thresh() != 0 || (flags_ & MPI_REQ_RMA) != 0) mut->lock(); + bool is_probe = ((flags_ & MPI_REQ_PROBE) != 0); + flags_ |= MPI_REQ_PROBE; + if (smpi_cfg_async_small_thresh() == 0 && (flags_ & MPI_REQ_RMA) == 0) { mailbox = process->mailbox(); } else if (((flags_ & MPI_REQ_RMA) != 0) || static_cast(size_) < smpi_cfg_async_small_thresh()) { @@ -463,7 +482,7 @@ void Request::start() mailbox = process->mailbox_small(); } } else { - XBT_DEBUG("yes there was something for us in the large mailbox"); + XBT_DEBUG("yes there was something for us in the small mailbox"); } } else { mailbox = process->mailbox_small(); @@ -477,6 +496,8 @@ void Request::start() XBT_DEBUG("yes there was something for us in the small mailbox"); } } + if(!is_probe) + flags_ &= ~MPI_REQ_PROBE; action_ = simcall_comm_irecv( process->get_actor()->get_impl(), mailbox->get_impl(), buf_, &real_size_, &match_recv, @@ -492,6 +513,9 @@ void Request::start() TRACE_smpi_send(src_, src_, dst_, tag_, size_); this->print_request("New 
send"); + message_id_=comm_->get_sent_messages_count(comm_->group()->rank(src_), comm_->group()->rank(dst_), tag_); + comm_->increment_sent_messages_count(comm_->group()->rank(src_), comm_->group()->rank(dst_), tag_); + void* buf = buf_; if ((flags_ & MPI_REQ_SSEND) == 0 && ((flags_ & MPI_REQ_RMA) != 0 || (flags_ & MPI_REQ_BSEND) != 0 || @@ -541,6 +565,9 @@ void Request::start() if (not(smpi_cfg_async_small_thresh() != 0 || (flags_ & MPI_REQ_RMA) != 0)) { mailbox = process->mailbox(); } else if (((flags_ & MPI_REQ_RMA) != 0) || static_cast(size_) < smpi_cfg_async_small_thresh()) { // eager mode + bool is_probe = ((flags_ & MPI_REQ_PROBE) != 0); + flags_ |= MPI_REQ_PROBE; + mailbox = process->mailbox(); XBT_DEBUG("Is there a corresponding recv already posted in the large mailbox %s?", mailbox->get_cname()); simgrid::kernel::activity::ActivityImplPtr action = mailbox->iprobe(1, &match_send, static_cast(this)); @@ -562,6 +589,8 @@ void Request::start() } else { XBT_DEBUG("Yes there was something for us in the large mailbox"); } + if(!is_probe) + flags_ &= ~MPI_REQ_PROBE; } else { mailbox = process->mailbox(); XBT_DEBUG("Send request %p is in the large mailbox %s (buf: %p)", this, mailbox->get_cname(), buf_); -- 2.20.1