]> AND Public Git Repository - simgrid.git/blobdiff - src/smpi/bindings/smpi_pmpi_request.cpp
Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
use previous buffer check feature in MPI checks, to crash when a buffer overflow...
[simgrid.git] / src / smpi / bindings / smpi_pmpi_request.cpp
index ddc0a152eb3a0877a93b2962ebf3d9657bffb54f..e75d904147de3acbc892e94119a919faa4249bb2 100644 (file)
@@ -1,4 +1,4 @@
-/* Copyright (c) 2007-2020. The SimGrid Team. All rights reserved.          */
+/* Copyright (c) 2007-2021. The SimGrid Team. All rights reserved.          */
 
 /* This program is free software; you can redistribute it and/or modify it
  * under the terms of the license (GNU LGPL) which comes with this package. */
@@ -19,13 +19,14 @@ static int getPid(MPI_Comm comm, int id)
 }
 
 #define CHECK_SEND_INPUTS\
-  CHECK_BUFFER(1, buf, count)\
+  SET_BUF1(buf)\
   CHECK_COUNT(2, count)\
   CHECK_TYPE(3, datatype)\
-  CHECK_PROC(4, dst)\
-  CHECK_RANK(4, dst, comm)\
-  CHECK_TAG(5, tag)\
+  CHECK_BUFFER(1, buf, count, datatype)\
   CHECK_COMM(6)\
+  if(dst!= MPI_PROC_NULL)\
+    CHECK_RANK(4, dst, comm)\
+  CHECK_TAG(5, tag)
 
 #define CHECK_ISEND_INPUTS\
   CHECK_REQUEST(7)\
@@ -33,16 +34,16 @@ static int getPid(MPI_Comm comm, int id)
   CHECK_SEND_INPUTS
   
 #define CHECK_IRECV_INPUTS\
+  SET_BUF1(buf)\
   CHECK_REQUEST(7)\
   *request = MPI_REQUEST_NULL;\
-  CHECK_BUFFER(1, buf, count)\
   CHECK_COUNT(2, count)\
   CHECK_TYPE(3, datatype)\
-  CHECK_PROC(4, src)\
-  if(src!=MPI_ANY_SOURCE)\
+  CHECK_BUFFER(1, buf, count, datatype)\
+  CHECK_COMM(6)\
+  if(src!=MPI_ANY_SOURCE && src!=MPI_PROC_NULL)\
     CHECK_RANK(4, src, comm)\
-  CHECK_TAG(5, tag)\
-  CHECK_COMM(6)
+  CHECK_TAG(5, tag)
 /* PMPI User level calls */
 
 int PMPI_Send_init(const void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request * request)
@@ -95,7 +96,7 @@ int PMPI_Start(MPI_Request * request)
   int retval = 0;
 
   smpi_bench_end();
-  CHECK_REQUEST(1)
+  CHECK_REQUEST_VALID(1)
   if ( *request == MPI_REQUEST_NULL) {
     retval = MPI_ERR_REQUEST;
   } else {
@@ -164,6 +165,7 @@ int PMPI_Request_free(MPI_Request * request)
 
   smpi_bench_end();
   if (*request != MPI_REQUEST_NULL) {
+    (*request)->mark_as_deleted();
     simgrid::smpi::Request::unref(request);
     *request = MPI_REQUEST_NULL;
     retval = MPI_SUCCESS;
@@ -237,10 +239,10 @@ int PMPI_Issend(const void* buf, int count, MPI_Datatype datatype, int dst, int
 int PMPI_Recv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MPI_Comm comm, MPI_Status * status)
 {
   int retval = 0;
-
-  CHECK_BUFFER(1, buf, count)
+  SET_BUF1(buf)
   CHECK_COUNT(2, count)
   CHECK_TYPE(3, datatype)
+  CHECK_BUFFER(1, buf, count, datatype)
   CHECK_TAG(5, tag)
   CHECK_COMM(6)
 
@@ -260,8 +262,7 @@ int PMPI_Recv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MPI
                                                        datatype->is_replayable() ? count : count * datatype->size(),
                                                        tag, simgrid::smpi::Datatype::encode(datatype)));
 
-    simgrid::smpi::Request::recv(buf, count, datatype, src, tag, comm, status);
-    retval = MPI_SUCCESS;
+    retval = simgrid::smpi::Request::recv(buf, count, datatype, src, tag, comm, status);
 
     // the src may not have been known at the beginning of the recv (MPI_ANY_SOURCE)
     int src_traced=0;
@@ -396,13 +397,15 @@ int PMPI_Sendrecv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, int
                   int recvcount, MPI_Datatype recvtype, int src, int recvtag, MPI_Comm comm, MPI_Status* status)
 {
   int retval = 0;
-  CHECK_BUFFER(1, sendbuf, sendcount)
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
   CHECK_COUNT(2, sendcount)
   CHECK_TYPE(3, sendtype)
   CHECK_TAG(5, sendtag)
-  CHECK_BUFFER(6, recvbuf, recvcount)
   CHECK_COUNT(7, recvcount)
   CHECK_TYPE(8, recvtype)
+  CHECK_BUFFER(1, sendbuf, sendcount, sendtype)
+  CHECK_BUFFER(6, recvbuf, recvcount, recvtype)
   CHECK_TAG(10, recvtag)
   CHECK_COMM(11)
   smpi_bench_end();
@@ -416,8 +419,7 @@ int PMPI_Sendrecv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, int
       simgrid::smpi::Request::send(sendbuf, sendcount, sendtype, dst, sendtag, comm);
     retval = MPI_SUCCESS;
   } else if (dst == MPI_PROC_NULL){
-    simgrid::smpi::Request::recv(recvbuf, recvcount, recvtype, src, recvtag, comm, status);
-    retval = MPI_SUCCESS;
+    retval = simgrid::smpi::Request::recv(recvbuf, recvcount, recvtype, src, recvtag, comm, status);
   } else if (dst >= comm->group()->size() || dst <0 ||
       (src!=MPI_ANY_SOURCE && (src >= comm->group()->size() || src <0))){
     retval = MPI_ERR_RANK;
@@ -427,8 +429,8 @@ int PMPI_Sendrecv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, int
     int src_traced         = getPid(comm, src);
 
     // FIXME: Hack the way to trace this one
-    std::vector<int>* dst_hack = new std::vector<int>();
-    std::vector<int>* src_hack = new std::vector<int>();
+    auto dst_hack = std::make_shared<std::vector<int>>();
+    auto src_hack = std::make_shared<std::vector<int>>();
     dst_hack->push_back(dst_traced);
     src_hack->push_back(src_traced);
     TRACE_smpi_comm_in(my_proc_id, __func__,
@@ -455,18 +457,19 @@ int PMPI_Sendrecv_replace(void* buf, int count, MPI_Datatype datatype, int dst,
                           MPI_Comm comm, MPI_Status* status)
 {
   int retval = 0;
-  CHECK_BUFFER(1, buf, count)
+  SET_BUF1(buf)
   CHECK_COUNT(2, count)
   CHECK_TYPE(3, datatype)
+  CHECK_BUFFER(1, buf, count, datatype)
 
   int size = datatype->get_extent() * count;
   xbt_assert(size > 0);
-  void* recvbuf = xbt_new0(char, size);
-  retval = MPI_Sendrecv(buf, count, datatype, dst, sendtag, recvbuf, count, datatype, src, recvtag, comm, status);
+  std::vector<char> recvbuf(size);
+  retval =
+      MPI_Sendrecv(buf, count, datatype, dst, sendtag, recvbuf.data(), count, datatype, src, recvtag, comm, status);
   if(retval==MPI_SUCCESS){
-    simgrid::smpi::Datatype::copy(recvbuf, count, datatype, buf, count, datatype);
+    simgrid::smpi::Datatype::copy(recvbuf.data(), count, datatype, buf, count, datatype);
   }
-  xbt_free(recvbuf);
   return retval;
 }
 
@@ -550,6 +553,8 @@ int PMPI_Probe(int source, int tag, MPI_Comm comm, MPI_Status* status) {
   smpi_bench_end();
 
   CHECK_COMM(6)
+  if(source!=MPI_ANY_SOURCE && source!=MPI_PROC_NULL)\
+    CHECK_RANK(1, source, comm)
   CHECK_TAG(2, tag)
   if (source == MPI_PROC_NULL) {
     if (status != MPI_STATUS_IGNORE){
@@ -569,6 +574,8 @@ int PMPI_Iprobe(int source, int tag, MPI_Comm comm, int* flag, MPI_Status* statu
   int retval = 0;
   smpi_bench_end();
   CHECK_COMM(6)
+  if(source!=MPI_ANY_SOURCE && source!=MPI_PROC_NULL)\
+    CHECK_RANK(1, source, comm)
   CHECK_TAG(2, tag)
   if (flag == nullptr) {
     retval = MPI_ERR_ARG;
@@ -729,7 +736,7 @@ int PMPI_Cancel(MPI_Request* request)
   int retval = 0;
 
   smpi_bench_end();
-  CHECK_REQUEST(1)
+  CHECK_REQUEST_VALID(1)
   if (*request == MPI_REQUEST_NULL) {
     retval = MPI_ERR_REQUEST;
   } else {
@@ -762,7 +769,15 @@ int PMPI_Status_set_elements(MPI_Status* status, MPI_Datatype datatype, int coun
     return MPI_ERR_ARG;
   }
   simgrid::smpi::Status::set_elements(status,datatype, count);
-  return MPI_SUCCESS;  
+  return MPI_SUCCESS;
+}
+
+int PMPI_Status_set_elements_x(MPI_Status* status, MPI_Datatype datatype, MPI_Count count){
+  if(status==MPI_STATUS_IGNORE){
+    return MPI_ERR_ARG;
+  }
+  simgrid::smpi::Status::set_elements(status,datatype, static_cast<int>(count));
+  return MPI_SUCCESS;
 }
 
 int PMPI_Grequest_start( MPI_Grequest_query_function *query_fn, MPI_Grequest_free_function *free_fn, MPI_Grequest_cancel_function *cancel_fn, void *extra_state, MPI_Request *request){
@@ -778,7 +793,7 @@ int PMPI_Request_get_status( MPI_Request request, int *flag, MPI_Status *status)
     *flag=1;
     simgrid::smpi::Status::empty(status);
     return MPI_SUCCESS;
-  } else if (flag==NULL || status ==NULL){
+  } else if (flag == nullptr) {
     return MPI_ERR_ARG;
   }
   return simgrid::smpi::Request::get_status(request,flag,status);