Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add factorisation for strategy and use dynamic_cast instead of static_cast
[simgrid.git] / src / mc / api / strategy / MaxMatchComm.hpp
index 4cb5891..a2f1d60 100644 (file)
@@ -26,7 +26,7 @@ class MaxMatchComm : public Strategy {
 public:
   void copy_from(const Strategy* strategy) override
   {
-    const MaxMatchComm* cast_strategy = static_cast<MaxMatchComm const*>(strategy);
+    const auto* cast_strategy = dynamic_cast<const MaxMatchComm*>(strategy); // nullptr if not a MaxMatchComm
     xbt_assert(cast_strategy != nullptr);
     for (auto& [id, val] : cast_strategy->mailbox_)
       mailbox_[id] = val;
@@ -41,78 +41,76 @@ public:
   MaxMatchComm()                     = default;
   ~MaxMatchComm() override           = default;
 
-  std::pair<aid_t, int> next_transition() const override
+  /**
+   * @brief Select the candidate actor whose current transition has the lowest score.
+   *
+   * Candidates are the enabled, not-done actors of actors_to_run_; when @p must_be_todo is true,
+   * actors that are not marked todo are skipped too. A candidate starts at value_of_state_, then:
+   *  - a CommRecv (resp. CommSend) on a mailbox with a waiting send (resp. recv) gets -1,
+   *  - a CommRecv or CommSend without such a waiting partner gets +1,
+   *  - any other transition keeps value_of_state_.
+   * On ties, the first candidate in the iteration order of actors_to_run_ wins.
+   *
+   * @param must_be_todo whether to restrict the candidates to actors marked todo
+   * @return (aid, score) of the chosen actor, or (-1, value_of_state_ + 2) when there is no candidate
+   */
+  std::pair<aid_t, int> best_transition(bool must_be_todo) const override
   {
-    std::pair<aid_t, int> if_no_match = std::make_pair(-1, 0);
+    // Sentinel strictly above the highest reachable score (value_of_state_ + 1): the first
+    // candidate always replaces it, so an aid of -1 means that no candidate was found.
+    std::pair<aid_t, int> min_found = std::make_pair(-1, value_of_state_ + 2);
     for (auto const& [aid, actor] : actors_to_run_) {
-      if (not actor.is_todo() || not actor.is_enabled() || actor.is_done())
-        continue;
-
-      const Transition* transition = actor.get_transition(actor.get_times_considered());
-
-      const CommRecvTransition* cast_recv = static_cast<CommRecvTransition const*>(transition);
-      if (cast_recv != nullptr and mailbox_.count(cast_recv->get_mailbox()) > 0 and
-          mailbox_.at(cast_recv->get_mailbox()) > 0)
-        return std::make_pair(aid, value_of_state_ - 1); // This means we have waiting send corresponding to this recv
-
-      const CommSendTransition* cast_send = static_cast<CommSendTransition const*>(transition);
-      if (cast_send != nullptr and mailbox_.count(cast_send->get_mailbox()) > 0 and
-          mailbox_.at(cast_send->get_mailbox()) < 0)
-        return std::make_pair(aid, value_of_state_ - 1); // This means we have waiting recv corresponding to this send
-
-      if (if_no_match.first == -1)
-        if_no_match = std::make_pair(aid, value_of_state_);
+      if ((not actor.is_todo() && must_be_todo) || not actor.is_enabled() || actor.is_done())
+        continue;
+
+      int aid_value                = value_of_state_;
+      const Transition* transition = actor.get_transition(actor.get_times_considered()).get();
+
+      // mailbox_ convention: a positive value means sends are waiting on that mailbox,
+      // a negative value means receives are waiting on it.
+      const CommRecvTransition* cast_recv = dynamic_cast<CommRecvTransition const*>(transition);
+      if (cast_recv != nullptr) {
+        auto const mbox = mailbox_.find(cast_recv->get_mailbox());
+        if (mbox != mailbox_.end() and mbox->second > 0)
+          aid_value--; // This means we have waiting send corresponding to this recv
+        else
+          aid_value++; // No waiting send to match: penalize this recv
+      }
+
+      const CommSendTransition* cast_send = dynamic_cast<CommSendTransition const*>(transition);
+      if (cast_send != nullptr) {
+        auto const mbox = mailbox_.find(cast_send->get_mailbox());
+        if (mbox != mailbox_.end() and mbox->second < 0)
+          aid_value--; // This means we have waiting recv corresponding to this send
+        else
+          aid_value++; // No waiting recv to match: penalize this send
+      }
+
+      if (aid_value < min_found.second) // strict '<': on ties the earlier candidate is kept
+        min_found = std::make_pair(aid, aid_value);
     }
-    return if_no_match;
+    return min_found;
   }
 
+  /**
+   * @brief Store the type of actor @p aid's current transition in last_transition_ and, for a
+   *        CommRecv or CommSend, its mailbox in last_mailbox_ (left unchanged otherwise).
+   * @param app not used by this strategy
+   */
   void execute_next(aid_t aid, RemoteApp& app) override
   {
-    const Transition* transition = actors_to_run_.at(aid).get_transition(actors_to_run_.at(aid).get_times_considered());
+    const auto& actor            = actors_to_run_.at(aid);
+    const Transition* transition = actor.get_transition(actor.get_times_considered()).get();
     last_transition_             = transition->type_;
 
-    const CommRecvTransition* cast_recv = static_cast<CommRecvTransition const*>(transition);
+    const CommRecvTransition* cast_recv = dynamic_cast<CommRecvTransition const*>(transition);
     if (cast_recv != nullptr)
       last_mailbox_ = cast_recv->get_mailbox();
 
-    const CommSendTransition* cast_send = static_cast<CommSendTransition const*>(transition);
+    const CommSendTransition* cast_send = dynamic_cast<CommSendTransition const*>(transition);
     if (cast_send != nullptr)
       last_mailbox_ = cast_send->get_mailbox();
   }
-
-  void consider_best() override
-  {
-    for (auto& [aid, actor] : actors_to_run_)
-      if (actor.is_todo())
-        return;
-
-    for (auto& [aid, actor] : actors_to_run_) {
-      if (not actor.is_enabled() || actor.is_done())
-        continue;
-
-      const Transition* transition = actor.get_transition(actor.get_times_considered());
-
-      const CommRecvTransition* cast_recv = static_cast<CommRecvTransition const*>(transition);
-      if (cast_recv != nullptr and mailbox_.count(cast_recv->get_mailbox()) > 0 and
-          mailbox_.at(cast_recv->get_mailbox()) > 0) {
-        actor.mark_todo();
-        return;
-      }
-
-      const CommSendTransition* cast_send = static_cast<CommSendTransition const*>(transition);
-      if (cast_send != nullptr and mailbox_.count(cast_send->get_mailbox()) > 0 and
-          mailbox_.at(cast_send->get_mailbox()) < 0) {
-        actor.mark_todo();
-        return;
-      }
-    }
-    for (auto& [_, actor] : actors_to_run_) {
-      if (actor.is_enabled() and not actor.is_done()) {
-        actor.mark_todo();
-        return;
-      }
-    }
-  }
 };
 
 } // namespace simgrid::mc