Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Fix bug with immediate conflict detection
[simgrid.git] / src / mc / explo / udpor / ExtensionSetCalculator.cpp
index eaacbcfd144793660aaf32b530d5f7b4aa3b43fc..f1e76f4d95461b162f503af5f328b0744b9ed360 100644 (file)
@@ -26,7 +26,8 @@ EventSet ExtensionSetCalculator::partially_extend(const Configuration& C, Unfold
 
   const static HandlerMap handlers =
       HandlerMap{{Action::COMM_ASYNC_RECV, &ExtensionSetCalculator::partially_extend_CommRecv},
-                 {Action::COMM_ASYNC_SEND, &ExtensionSetCalculator::partially_extend_CommSend}};
+                 {Action::COMM_ASYNC_SEND, &ExtensionSetCalculator::partially_extend_CommSend},
+                 {Action::COMM_WAIT, &ExtensionSetCalculator::partially_extend_CommWait}};
 
   if (const auto handler = handlers.find(action->type_); handler != handlers.end()) {
     return handler->second(C, U, std::move(action));
@@ -114,7 +115,7 @@ EventSet ExtensionSetCalculator::partially_extend_CommRecv(const Configuration&
   // Com contains a matching c' = AsyncReceive(m, _) with a
   for (const auto e : C) {
     const bool transition_type_check = [&]() {
-      if (const auto* async_recv = dynamic_cast<const CommSendTransition*>(e->get_transition());
+      if (const auto* async_recv = dynamic_cast<const CommRecvTransition*>(e->get_transition());
           async_recv != nullptr && async_recv->get_mailbox() == recv_mailbox) {
         return true;
       }
@@ -143,13 +144,26 @@ EventSet ExtensionSetCalculator::partially_extend_CommWait(const Configuration&
   EventSet exC;
 
   const auto wait_action   = std::static_pointer_cast<CommWaitTransition>(std::move(action));
+  const auto wait_comm     = wait_action->get_comm();
   const auto pre_event_a_C = C.pre_event(wait_action->aid_);
 
   // Determine the _issuer_ of the communication of the `CommWait` event
   // in `C`. The issuer of the `CommWait` in `C` is the event in `C`
   // whose transition is the `CommRecv` or `CommSend` whose resulting
   // communication this `CommWait` waits on
-  const auto issuer = std::find_if(C.begin(), C.end(), [=](const UnfoldingEvent* e) { return false; });
+  const auto issuer = std::find_if(C.begin(), C.end(), [&](const UnfoldingEvent* e) {
+    if (const CommRecvTransition* e_issuer_receive = dynamic_cast<const CommRecvTransition*>(e->get_transition());
+        e_issuer_receive != nullptr) {
+      return e_issuer_receive->aid_ == wait_action->aid_ && wait_comm == e_issuer_receive->get_comm();
+    }
+
+    if (const CommSendTransition* e_issuer_send = dynamic_cast<const CommSendTransition*>(e->get_transition());
+        e_issuer_send != nullptr) {
+      return e_issuer_send->aid_ == wait_action->aid_ && wait_comm == e_issuer_send->get_comm();
+    }
+
+    return false;
+  });
   xbt_assert(issuer != C.end(),
              "Invariant violation! A (supposedly) enabled `CommWait` transition "
              "waiting on commiunication %lu should not be enabled: the receive/send "
@@ -159,6 +173,7 @@ EventSet ExtensionSetCalculator::partially_extend_CommWait(const Configuration&
              "a bug in SimGrid's UDPOR implementation",
              wait_action->get_comm(), wait_action->to_string(false).c_str());
   const UnfoldingEvent* e_issuer = *issuer;
+  const History e_issuer_history(e_issuer);
 
   // 1. if `a` is enabled at state(config({preEvt(a,C)})), then
   // create `e' := <a, config({preEvt(a,C)})>` and add `e'` to `ex(C)`
@@ -178,19 +193,63 @@ EventSet ExtensionSetCalculator::partially_extend_CommWait(const Configuration&
       // as needed to reach the receive/send number that is `issuer`.
       // ...
       // ...
-      if (e_issuer->get_transition()->type_ == Transition::Type::COMM_ASYNC_RECV) {
-
-        const unsigned send_position    = 0;
-        const unsigned receive_position = 0;
-        if (send_position == receive_position) {
+      if (const CommRecvTransition* e_issuer_receive =
+              dynamic_cast<const CommRecvTransition*>(e_issuer->get_transition());
+          e_issuer_receive != nullptr) {
+
+        const unsigned issuer_mailbox = e_issuer_receive->get_mailbox();
+
+        // Check from the config -> how many sends have there been
+        const unsigned send_position =
+            std::count_if(config_pre_event.begin(), config_pre_event.end(), [=](const auto e) {
+              const CommSendTransition* e_send = dynamic_cast<const CommSendTransition*>(e->get_transition());
+              if (e_send != nullptr) {
+                return e_send->get_mailbox() == issuer_mailbox;
+              }
+              return false;
+            });
+
+        // Check from e_issuer -> what place is the issuer in?
+        const unsigned receive_position =
+            std::count_if(e_issuer_history.begin(), e_issuer_history.end(), [=](const auto e) {
+              const CommRecvTransition* e_receive = dynamic_cast<const CommRecvTransition*>(e->get_transition());
+              if (e_receive != nullptr) {
+                return e_receive->get_mailbox() == issuer_mailbox;
+              }
+              return false;
+            });
+
+        if (send_position >= receive_position) {
           exC.insert(U->discover_event(EventSet({unwrapped_pre_event}), wait_action));
         }
 
-      } else if (e_issuer->get_transition()->type_ == Transition::Type::COMM_ASYNC_SEND) {
-
-        const unsigned send_position    = 0;
-        const unsigned receive_position = 0;
-        if (send_position == receive_position) {
+      } else if (const CommSendTransition* e_issuer_send =
+                     dynamic_cast<const CommSendTransition*>(e_issuer->get_transition());
+                 e_issuer_send != nullptr) {
+
+        const unsigned issuer_mailbox = e_issuer_send->get_mailbox();
+
+        // Check from e_issuer -> what place is the issuer in?
+        const unsigned send_position =
+            std::count_if(e_issuer_history.begin(), e_issuer_history.end(), [=](const auto e) {
+              const CommSendTransition* e_send = dynamic_cast<const CommSendTransition*>(e->get_transition());
+              if (e_send != nullptr) {
+                return e_send->get_mailbox() == issuer_mailbox;
+              }
+              return false;
+            });
+
+        // Check from the config -> how many receives have there been
+        const unsigned receive_position =
+            std::count_if(config_pre_event.begin(), config_pre_event.end(), [=](const auto e) {
+              const CommRecvTransition* e_receive = dynamic_cast<const CommRecvTransition*>(e->get_transition());
+              if (e_receive != nullptr) {
+                return e_receive->get_mailbox() == issuer_mailbox;
+              }
+              return false;
+            });
+
+        if (send_position <= receive_position) {
           exC.insert(U->discover_event(EventSet({unwrapped_pre_event}), wait_action));
         }
 
@@ -219,6 +278,12 @@ EventSet ExtensionSetCalculator::partially_extend_CommWait(const Configuration&
         continue;
       }
 
+      const auto issuer_mailbox        = e_issuer_send->get_mailbox();
+      const CommRecvTransition* e_recv = dynamic_cast<const CommRecvTransition*>(e->get_transition());
+      if (e_recv->get_mailbox() != issuer_mailbox) {
+        continue;
+      }
+
       // If the `issuer` is not in `config(K)`, this implies that
       // `WaitAny()` is always disabled in `config(K)`; hence, it
       // is independent of any transition in `config(K)` (according
@@ -229,17 +294,33 @@ EventSet ExtensionSetCalculator::partially_extend_CommWait(const Configuration&
         continue;
       }
 
-      // std::count_if(config_K.begin(), config_K.end(), [](const auto e) { return false; });
-
       // TODO: Compute the send and receive positions
-      const unsigned send_position = 0;
 
-      const unsigned receive_position = 0;
+      // What send # is the issuer
+      const unsigned send_position = std::count_if(e_issuer_history.begin(), e_issuer_history.end(), [=](const auto e) {
+        const CommSendTransition* e_send = dynamic_cast<const CommSendTransition*>(e->get_transition());
+        if (e_send != nullptr) {
+          return e_send->get_mailbox() == issuer_mailbox;
+        }
+        return false;
+      });
+
+      // What receive # is the event `e`?
+      const unsigned receive_position = std::count_if(config_K.begin(), config_K.end(), [=](const auto e) {
+        const CommRecvTransition* e_receive = dynamic_cast<const CommRecvTransition*>(e->get_transition());
+        if (e_receive != nullptr) {
+          return e_receive->get_mailbox() == issuer_mailbox;
+        }
+        return false;
+      });
+
       if (send_position == receive_position) {
         exC.insert(U->discover_event(std::move(K), wait_action));
       }
 
-    } else if (e_issuer->get_transition()->type_ == Transition::Type::COMM_ASYNC_RECV) {
+    } else if (const CommRecvTransition* e_issuer_recv =
+                   dynamic_cast<const CommRecvTransition*>(e_issuer->get_transition());
+               e_issuer_recv != nullptr) {
 
       // If the provider of the communication for `CommWait` is a
       // `CommRecv(m)`, then we only care about `e` if `λ(e) == `CommSend(m)`.
@@ -250,6 +331,12 @@ EventSet ExtensionSetCalculator::partially_extend_CommWait(const Configuration&
         continue;
       }
 
+      const auto issuer_mailbox        = e_issuer_recv->get_mailbox();
+      const CommSendTransition* e_send = dynamic_cast<const CommSendTransition*>(e->get_transition());
+      if (e_send->get_mailbox() != issuer_mailbox) {
+        continue;
+      }
+
       // If the `issuer` is not in `config(K)`, this implies that
       // `WaitAny()` is always disabled in `config(K)`; hence, it
       // is independent of any transition in `config(K)` (according
@@ -260,9 +347,25 @@ EventSet ExtensionSetCalculator::partially_extend_CommWait(const Configuration&
         continue;
       }
 
-      // TODO: Compute the send and receive positions
-      const unsigned send_position    = 0;
-      const unsigned receive_position = 0;
+      // What send # is the event `e`?
+      const unsigned send_position = std::count_if(config_K.begin(), config_K.end(), [=](const auto e) {
+        const CommSendTransition* e_send = dynamic_cast<const CommSendTransition*>(e->get_transition());
+        if (e_send != nullptr) {
+          return e_send->get_mailbox() == issuer_mailbox;
+        }
+        return false;
+      });
+
+      // What receive # is the issuer
+      const unsigned receive_position =
+          std::count_if(e_issuer_history.begin(), e_issuer_history.end(), [=](const auto e) {
+            const CommRecvTransition* e_receive = dynamic_cast<const CommRecvTransition*>(e->get_transition());
+            if (e_receive != nullptr) {
+              return e_receive->get_mailbox() == issuer_mailbox;
+            }
+            return false;
+          });
+
       if (send_position == receive_position) {
         exC.insert(U->discover_event(std::move(K), wait_action));
       }