From a9c055488291a864d81058c3619607b652495871 Mon Sep 17 00:00:00 2001 From: Augustin Degomme Date: Mon, 9 Dec 2019 18:43:42 +0100 Subject: [PATCH] Unify input checking using shared macros. Avoid repeating code. --- src/smpi/bindings/smpi_pmpi_coll.cpp | 32 --- src/smpi/bindings/smpi_pmpi_file.cpp | 130 +++++---- src/smpi/bindings/smpi_pmpi_request.cpp | 335 ++++++++---------------- src/smpi/include/private.hpp | 41 +++ 4 files changed, 213 insertions(+), 325 deletions(-) diff --git a/src/smpi/bindings/smpi_pmpi_coll.cpp b/src/smpi/bindings/smpi_pmpi_coll.cpp index 446501d769..0a97fa032b 100644 --- a/src/smpi/bindings/smpi_pmpi_coll.cpp +++ b/src/smpi/bindings/smpi_pmpi_coll.cpp @@ -13,38 +13,6 @@ XBT_LOG_EXTERNAL_DEFAULT_CATEGORY(smpi_pmpi); -#define CHECK_ARGS(test, errcode, ...) \ - if (test) { \ - XBT_WARN(__VA_ARGS__); \ - return (errcode); \ - } - -#define CHECK_COMM(num)\ - CHECK_ARGS(comm == MPI_COMM_NULL, MPI_ERR_COMM,\ - "%s: param %d communicator cannot be MPI_COMM_NULL", __func__, num); -#define CHECK_REQUEST(num)\ - CHECK_ARGS(request == nullptr, MPI_ERR_ARG,\ - "%s: param %d request cannot be NULL",__func__, num); -#define CHECK_BUFFER(num,buf,count)\ - CHECK_ARGS(buf == nullptr && count > 0, MPI_ERR_BUFFER,\ - "%s: param %d %s cannot be NULL if %s > 0",__func__, num, #buf, #count); -#define CHECK_COUNT(num,count)\ - CHECK_ARGS(count < 0, MPI_ERR_COUNT,\ - "%s: param %d %s cannot be negative", __func__, num, #count); -#define CHECK_TYPE(num, datatype)\ - CHECK_ARGS((datatype == MPI_DATATYPE_NULL|| not datatype->is_valid()), MPI_ERR_TYPE,\ - "%s: param %d %s cannot be MPI_DATATYPE_NULL or invalid", __func__, num, #datatype); -#define CHECK_OP(num)\ - CHECK_ARGS(op == MPI_OP_NULL, MPI_ERR_OP,\ - "%s: param %d op cannot be MPI_OP_NULL or invalid", __func__, num); -#define CHECK_ROOT(num)\ - CHECK_ARGS((root < 0 || root >= comm->size()), MPI_ERR_ROOT,\ - "%s: param %d root (=%d) cannot be negative or larger than communicator size (=%d)", __func__, num, root,\ - comm->size()); -#define CHECK_NULL(num,err,buf)\ - CHECK_ARGS(buf == nullptr, err,\ - "%s: param %d %s cannot be NULL", __func__, num, #buf); - static const void* smpi_get_in_place_buf(const void* inplacebuf, const void* otherbuf,std::unique_ptr& tmp_sendbuf, int count, MPI_Datatype datatype){ if (inplacebuf == MPI_IN_PLACE) { tmp_sendbuf.reset(new unsigned char[count * datatype->get_extent()]); diff --git a/src/smpi/bindings/smpi_pmpi_file.cpp b/src/smpi/bindings/smpi_pmpi_file.cpp index d8a1ceb03c..32e175092c 100644 --- a/src/smpi/bindings/smpi_pmpi_file.cpp +++ b/src/smpi/bindings/smpi_pmpi_file.cpp @@ -8,6 +8,8 @@ #include "smpi_file.hpp" #include "smpi_datatype.hpp" +XBT_LOG_EXTERNAL_DEFAULT_CATEGORY(smpi_pmpi); + extern MPI_Errhandler SMPI_default_File_Errhandler; int PMPI_File_open(MPI_Comm comm, const char *filename, int amode, MPI_Info info, MPI_File *fh){ @@ -38,21 +40,11 @@ int PMPI_File_close(MPI_File *fh){ smpi_bench_begin(); return ret; } -#define CHECK_FILE(fh) \ - if ((fh) == MPI_FILE_NULL) \ - return MPI_ERR_FILE; -#define CHECK_BUFFER(buf, count) \ - if ((buf) == nullptr && (count) > 0) \ - return MPI_ERR_BUFFER; -#define CHECK_COUNT(count) \ - if ((count) < 0) \ - return MPI_ERR_COUNT; + + #define CHECK_OFFSET(offset) \ if ((offset) < 0) \ return MPI_ERR_DISP; -#define CHECK_DATATYPE(datatype, count) \ - if ((datatype) == MPI_DATATYPE_NULL && (count) > 0) \ - return MPI_ERR_TYPE; #define CHECK_STATUS(status) \ if ((status) == nullptr) \ return MPI_ERR_ARG; @@ -70,7 +62,7 @@ int PMPI_File_close(MPI_File *fh){ } int PMPI_File_seek(MPI_File fh, MPI_Offset offset, int whence){ - CHECK_FILE(fh) + CHECK_FILE(1, fh) smpi_bench_end(); int ret = fh->seek(offset,whence); smpi_bench_begin(); @@ -78,7 +70,7 @@ int PMPI_File_seek(MPI_File fh, MPI_Offset offset, int whence){ } int PMPI_File_seek_shared(MPI_File fh, MPI_Offset offset, int whence){ - CHECK_FILE(fh) + CHECK_FILE(1, fh) smpi_bench_end(); int ret = fh->seek_shared(offset,whence); smpi_bench_begin(); @@ -95,7 +87,7 @@ int PMPI_File_get_position(MPI_File fh, MPI_Offset* offset){ } int PMPI_File_get_position_shared(MPI_File fh, MPI_Offset* offset){ - CHECK_FILE(fh) + CHECK_FILE(1, fh) if (offset==nullptr) return MPI_ERR_DISP; smpi_bench_end(); @@ -105,10 +97,10 @@ int PMPI_File_get_position_shared(MPI_File fh, MPI_Offset* offset){ } int PMPI_File_read(MPI_File fh, void *buf, int count,MPI_Datatype datatype, MPI_Status *status){ - CHECK_FILE(fh) - CHECK_BUFFER(buf, count) - CHECK_COUNT(count) - CHECK_DATATYPE(datatype, count) + CHECK_FILE(1, fh) + CHECK_BUFFER(2, buf, count) + CHECK_COUNT(3, count) + CHECK_TYPE(4, datatype) CHECK_STATUS(status) CHECK_FLAGS(fh) PASS_ZEROCOUNT(count) @@ -122,10 +114,10 @@ int PMPI_File_read(MPI_File fh, void *buf, int count,MPI_Datatype datatype, MPI_ } int PMPI_File_read_shared(MPI_File fh, void *buf, int count,MPI_Datatype datatype, MPI_Status *status){ - CHECK_FILE(fh) - CHECK_BUFFER(buf, count) - CHECK_COUNT(count) - CHECK_DATATYPE(datatype, count) + CHECK_FILE(1, fh) + CHECK_BUFFER(2, buf, count) + CHECK_COUNT(3, count) + CHECK_TYPE(4, datatype) CHECK_STATUS(status) CHECK_FLAGS(fh) PASS_ZEROCOUNT(count) @@ -140,10 +132,10 @@ int PMPI_File_read_shared(MPI_File fh, void *buf, int count,MPI_Datatype datatyp } int PMPI_File_write(MPI_File fh, const void *buf, int count,MPI_Datatype datatype, MPI_Status *status){ - CHECK_FILE(fh) - CHECK_BUFFER(buf, count) - CHECK_COUNT(count) - CHECK_DATATYPE(datatype, count) + CHECK_FILE(1, fh) + CHECK_BUFFER(2, buf, count) + CHECK_COUNT(3, count) + CHECK_TYPE(4, datatype) CHECK_STATUS(status) CHECK_FLAGS(fh) CHECK_RDONLY(fh) @@ -158,10 +150,10 @@ int PMPI_File_write(MPI_File fh, const void *buf, int count,MPI_Datatype datatyp } int PMPI_File_write_shared(MPI_File fh, const void *buf, int count,MPI_Datatype datatype, MPI_Status *status){ - CHECK_FILE(fh) - CHECK_BUFFER(buf, count) - CHECK_COUNT(count) - CHECK_DATATYPE(datatype, count) + CHECK_FILE(1, fh) + CHECK_BUFFER(2, buf, count) + CHECK_COUNT(3, count) + CHECK_TYPE(4, datatype) CHECK_STATUS(status) CHECK_FLAGS(fh) CHECK_RDONLY(fh) @@ -177,10 +169,10 @@ int PMPI_File_write_shared(MPI_File fh, const void *buf, int count,MPI_Datatype } int PMPI_File_read_all(MPI_File fh, void *buf, int count,MPI_Datatype datatype, MPI_Status *status){ - CHECK_FILE(fh) - CHECK_BUFFER(buf, count) - CHECK_COUNT(count) - CHECK_DATATYPE(datatype, count) + CHECK_FILE(1, fh) + CHECK_BUFFER(2, buf, count) + CHECK_COUNT(3, count) + CHECK_TYPE(4, datatype) CHECK_STATUS(status) CHECK_FLAGS(fh) smpi_bench_end(); @@ -193,10 +185,10 @@ int PMPI_File_read_all(MPI_File fh, void *buf, int count,MPI_Datatype datatype, } int PMPI_File_read_ordered(MPI_File fh, void *buf, int count,MPI_Datatype datatype, MPI_Status *status){ - CHECK_FILE(fh) - CHECK_BUFFER(buf, count) - CHECK_COUNT(count) - CHECK_DATATYPE(datatype, count) + CHECK_FILE(1, fh) + CHECK_BUFFER(2, buf, count) + CHECK_COUNT(3, count) + CHECK_TYPE(4, datatype) CHECK_STATUS(status) CHECK_FLAGS(fh) smpi_bench_end(); @@ -210,10 +202,10 @@ int PMPI_File_read_ordered(MPI_File fh, void *buf, int count,MPI_Datatype dataty } int PMPI_File_write_all(MPI_File fh, const void *buf, int count,MPI_Datatype datatype, MPI_Status *status){ - CHECK_FILE(fh) - CHECK_BUFFER(buf, count) - CHECK_COUNT(count) - CHECK_DATATYPE(datatype, count) + CHECK_FILE(1, fh) + CHECK_BUFFER(2, buf, count) + CHECK_COUNT(3, count) + CHECK_TYPE(4, datatype) CHECK_STATUS(status) CHECK_FLAGS(fh) CHECK_RDONLY(fh) @@ -227,10 +219,10 @@ int PMPI_File_write_all(MPI_File fh, const void *buf, int count,MPI_Datatype dat } int PMPI_File_write_ordered(MPI_File fh, const void *buf, int count,MPI_Datatype datatype, MPI_Status *status){ - CHECK_FILE(fh) - CHECK_BUFFER(buf, count) - CHECK_COUNT(count) - CHECK_DATATYPE(datatype, count) + CHECK_FILE(1, fh) + CHECK_BUFFER(2, buf, count) + CHECK_COUNT(3, count) + CHECK_TYPE(4, datatype) CHECK_STATUS(status) CHECK_FLAGS(fh) CHECK_RDONLY(fh) @@ -245,11 +237,11 @@ int PMPI_File_write_ordered(MPI_File fh, const void *buf, int count,MPI_Datatype } int PMPI_File_read_at(MPI_File fh, MPI_Offset offset, void *buf, int count,MPI_Datatype datatype, MPI_Status *status){ - CHECK_FILE(fh) - CHECK_BUFFER(buf, count) + CHECK_FILE(1, fh) + CHECK_BUFFER(2, buf, count) CHECK_OFFSET(offset) - CHECK_COUNT(count) - CHECK_DATATYPE(datatype, count) + CHECK_COUNT(3, count) + CHECK_TYPE(4, datatype) CHECK_STATUS(status) CHECK_FLAGS(fh) PASS_ZEROCOUNT(count); @@ -266,11 +258,11 @@ int PMPI_File_read_at(MPI_File fh, MPI_Offset offset, void *buf, int count,MPI_D } int PMPI_File_read_at_all(MPI_File fh, MPI_Offset offset, void *buf, int count,MPI_Datatype datatype, MPI_Status *status){ - CHECK_FILE(fh) - CHECK_BUFFER(buf, count) + CHECK_FILE(1, fh) + CHECK_BUFFER(2, buf, count) CHECK_OFFSET(offset) - CHECK_COUNT(count) - CHECK_DATATYPE(datatype, count) + CHECK_COUNT(3, count) + CHECK_TYPE(4, datatype) CHECK_STATUS(status) CHECK_FLAGS(fh) smpi_bench_end(); @@ -287,11 +279,11 @@ int PMPI_File_read_at_all(MPI_File fh, MPI_Offset offset, void *buf, int count,M } int PMPI_File_write_at(MPI_File fh, MPI_Offset offset, const void *buf, int count,MPI_Datatype datatype, MPI_Status *status){ - CHECK_FILE(fh) - CHECK_BUFFER(buf, count) + CHECK_FILE(1, fh) + CHECK_BUFFER(2, buf, count) CHECK_OFFSET(offset) - CHECK_COUNT(count) - CHECK_DATATYPE(datatype, count) + CHECK_COUNT(4, count) + CHECK_TYPE(5, datatype) CHECK_STATUS(status) CHECK_FLAGS(fh) CHECK_RDONLY(fh) @@ -309,11 +301,11 @@ int PMPI_File_write_at(MPI_File fh, MPI_Offset offset, const void *buf, int coun } int PMPI_File_write_at_all(MPI_File fh, MPI_Offset offset, const void *buf, int count,MPI_Datatype datatype, MPI_Status *status){ - CHECK_FILE(fh) - CHECK_BUFFER(buf, count) + CHECK_FILE(1, fh) + CHECK_BUFFER(2, buf, count) CHECK_OFFSET(offset) - CHECK_COUNT(count) - CHECK_DATATYPE(datatype, count) + CHECK_COUNT(4, count) + CHECK_TYPE(5, datatype) CHECK_STATUS(status) CHECK_FLAGS(fh) CHECK_RDONLY(fh) @@ -341,42 +333,42 @@ int PMPI_File_delete(const char *filename, MPI_Info info){ int PMPI_File_get_info(MPI_File fh, MPI_Info* info) { - CHECK_FILE(fh) + CHECK_FILE(1, fh) *info = fh->info(); return MPI_SUCCESS; } int PMPI_File_set_info(MPI_File fh, MPI_Info info) { - CHECK_FILE(fh) + CHECK_FILE(1, fh) fh->set_info(info); return MPI_SUCCESS; } int PMPI_File_get_size(MPI_File fh, MPI_Offset* size) { - CHECK_FILE(fh) + CHECK_FILE(1, fh) *size = fh->size(); return MPI_SUCCESS; } int PMPI_File_get_amode(MPI_File fh, int* amode) { - CHECK_FILE(fh) + CHECK_FILE(1, fh) *amode = fh->flags(); return MPI_SUCCESS; } int PMPI_File_get_group(MPI_File fh, MPI_Group* group) { - CHECK_FILE(fh) + CHECK_FILE(1, fh) *group = fh->comm()->group(); return MPI_SUCCESS; } int PMPI_File_sync(MPI_File fh) { - CHECK_FILE(fh) + CHECK_FILE(1, fh) fh->sync(); return MPI_SUCCESS; } diff --git a/src/smpi/bindings/smpi_pmpi_request.cpp b/src/smpi/bindings/smpi_pmpi_request.cpp index 62e4822d57..3caec9503d 100644 --- a/src/smpi/bindings/smpi_pmpi_request.cpp +++ b/src/smpi/bindings/smpi_pmpi_request.cpp @@ -18,52 +18,49 @@ static int getPid(MPI_Comm comm, int id) return (actor == nullptr) ? MPI_UNDEFINED : actor->get_pid(); } +#define CHECK_SEND_INPUTS\ + CHECK_BUFFER(1, buf, count)\ + CHECK_COUNT(2, count)\ + CHECK_TYPE(3, datatype)\ + CHECK_PROC(4, dst)\ + CHECK_TAG(5, tag)\ + CHECK_COMM(6)\ + +#define CHECK_ISEND_INPUTS\ + CHECK_REQUEST(7)\ + *request = MPI_REQUEST_NULL;\ + CHECK_SEND_INPUTS + +#define CHECK_IRECV_INPUTS\ + CHECK_REQUEST(7)\ + *request = MPI_REQUEST_NULL;\ + CHECK_BUFFER(1, buf, count)\ + CHECK_COUNT(2, count)\ + CHECK_TYPE(3, datatype)\ + CHECK_PROC(4, src)\ + CHECK_TAG(5, tag)\ + CHECK_COMM(6) /* PMPI User level calls */ int PMPI_Send_init(const void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request * request) { - int retval = 0; + CHECK_ISEND_INPUTS smpi_bench_end(); - if (request == nullptr) { - retval = MPI_ERR_ARG; - } else if (comm == MPI_COMM_NULL) { - retval = MPI_ERR_COMM; - } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { - retval = MPI_ERR_TYPE; - } else if (dst == MPI_PROC_NULL) { - retval = MPI_SUCCESS; - } else { - *request = simgrid::smpi::Request::send_init(buf, count, datatype, dst, tag, comm); - retval = MPI_SUCCESS; - } + *request = simgrid::smpi::Request::send_init(buf, count, datatype, dst, tag, comm); smpi_bench_begin(); - if (retval != MPI_SUCCESS && request != nullptr) - *request = MPI_REQUEST_NULL; - return retval; + + return MPI_SUCCESS; } int PMPI_Recv_init(void *buf, int count, MPI_Datatype datatype, int src, int tag, MPI_Comm comm, MPI_Request * request) { - int retval = 0; + CHECK_IRECV_INPUTS smpi_bench_end(); - if (request == nullptr) { - retval = MPI_ERR_ARG; - } else if (comm == MPI_COMM_NULL) { - retval = MPI_ERR_COMM; - } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { - retval = MPI_ERR_TYPE; - } else if (src == MPI_PROC_NULL) { - retval = MPI_SUCCESS; - } else { - *request = simgrid::smpi::Request::recv_init(buf, count, datatype, src, tag, comm); - retval = MPI_SUCCESS; - } + *request = simgrid::smpi::Request::recv_init(buf, count, datatype, src, tag, comm); smpi_bench_begin(); - if (retval != MPI_SUCCESS && request != nullptr) - *request = MPI_REQUEST_NULL; - return retval; + return MPI_SUCCESS; } int PMPI_Rsend_init(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, @@ -74,24 +71,14 @@ int PMPI_Rsend_init(const void* buf, int count, MPI_Datatype datatype, int dst, int PMPI_Ssend_init(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request* request) { - int retval = 0; + CHECK_ISEND_INPUTS + int retval = 0; smpi_bench_end(); - if (request == nullptr) { - retval = MPI_ERR_ARG; - } else if (comm == MPI_COMM_NULL) { - retval = MPI_ERR_COMM; - } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { - retval = MPI_ERR_TYPE; - } else if (dst == MPI_PROC_NULL) { - retval = MPI_SUCCESS; - } else { - *request = simgrid::smpi::Request::ssend_init(buf, count, datatype, dst, tag, comm); - retval = MPI_SUCCESS; - } + *request = simgrid::smpi::Request::ssend_init(buf, count, datatype, dst, tag, comm); + retval = MPI_SUCCESS; + smpi_bench_begin(); - if (retval != MPI_SUCCESS && request != nullptr) - *request = MPI_REQUEST_NULL; return retval; } @@ -105,7 +92,8 @@ int PMPI_Start(MPI_Request * request) int retval = 0; smpi_bench_end(); - if (request == nullptr || *request == MPI_REQUEST_NULL) { + CHECK_REQUEST(1) + if ( *request == MPI_REQUEST_NULL) { retval = MPI_ERR_REQUEST; } else { MPI_Request req = *request; @@ -183,25 +171,12 @@ int PMPI_Request_free(MPI_Request * request) int PMPI_Irecv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MPI_Comm comm, MPI_Request * request) { - int retval = 0; + CHECK_IRECV_INPUTS smpi_bench_end(); - - if (request == nullptr) { - retval = MPI_ERR_ARG; - } else if (comm == MPI_COMM_NULL) { - retval = MPI_ERR_COMM; - } else if (src == MPI_PROC_NULL) { - *request = MPI_REQUEST_NULL; - retval = MPI_SUCCESS; - } else if (src!=MPI_ANY_SOURCE && (src >= comm->group()->size() || src <0)){ + int retval = 0; + if (src!=MPI_ANY_SOURCE && (src >= comm->group()->size() || src <0)){ retval = MPI_ERR_RANK; - } else if ((count < 0) || (buf==nullptr && count > 0)) { - retval = MPI_ERR_COUNT; - } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { - retval = MPI_ERR_TYPE; - } else if(tag<0 && tag != MPI_ANY_TAG){ - retval = MPI_ERR_TAG; } else { int my_proc_id = simgrid::s4u::this_actor::get_pid(); @@ -217,32 +192,18 @@ int PMPI_Irecv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MP } smpi_bench_begin(); - if (retval != MPI_SUCCESS && request != nullptr) - *request = MPI_REQUEST_NULL; return retval; } int PMPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request * request) { - int retval = 0; + CHECK_ISEND_INPUTS smpi_bench_end(); - if (request == nullptr) { - retval = MPI_ERR_ARG; - } else if (comm == MPI_COMM_NULL) { - retval = MPI_ERR_COMM; - } else if (dst == MPI_PROC_NULL) { - *request = MPI_REQUEST_NULL; - retval = MPI_SUCCESS; - } else if (dst >= comm->group()->size() || dst <0){ + int retval = 0; + if (dst >= comm->group()->size() || dst <0){ retval = MPI_ERR_RANK; - } else if ((count < 0) || (buf==nullptr && count > 0)) { - retval = MPI_ERR_COUNT; - } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { - retval = MPI_ERR_TYPE; - } else if(tag<0 && tag != MPI_ANY_TAG){ - retval = MPI_ERR_TAG; } else { int my_proc_id = simgrid::s4u::this_actor::get_pid(); int trace_dst = getPid(comm, dst); @@ -260,8 +221,7 @@ int PMPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dst, int t } smpi_bench_begin(); - if (retval != MPI_SUCCESS && request!=nullptr) - *request = MPI_REQUEST_NULL; + return retval; } @@ -273,24 +233,12 @@ int PMPI_Irsend(const void* buf, int count, MPI_Datatype datatype, int dst, int int PMPI_Issend(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request* request) { - int retval = 0; + CHECK_ISEND_INPUTS smpi_bench_end(); - if (request == nullptr) { - retval = MPI_ERR_ARG; - } else if (comm == MPI_COMM_NULL) { - retval = MPI_ERR_COMM; - } else if (dst == MPI_PROC_NULL) { - *request = MPI_REQUEST_NULL; - retval = MPI_SUCCESS; - } else if (dst >= comm->group()->size() || dst <0){ + int retval = 0; + if (dst >= comm->group()->size() || dst <0){ retval = MPI_ERR_RANK; - } else if ((count < 0)|| (buf==nullptr && count > 0)) { - retval = MPI_ERR_COUNT; - } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { - retval = MPI_ERR_TYPE; - } else if(tag<0 && tag != MPI_ANY_TAG){ - retval = MPI_ERR_TAG; } else { int my_proc_id = simgrid::s4u::this_actor::get_pid(); int trace_dst = getPid(comm, dst); @@ -307,8 +255,6 @@ int PMPI_Issend(const void* buf, int count, MPI_Datatype datatype, int dst, int } smpi_bench_begin(); - if (retval != MPI_SUCCESS && request!=nullptr) - *request = MPI_REQUEST_NULL; return retval; } @@ -316,10 +262,14 @@ int PMPI_Recv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MPI { int retval = 0; + CHECK_BUFFER(1, buf, count) + CHECK_COUNT(2, count) + CHECK_TYPE(3, datatype) + CHECK_TAG(5, tag) + CHECK_COMM(6) + smpi_bench_end(); - if (comm == MPI_COMM_NULL) { - retval = MPI_ERR_COMM; - } else if (src == MPI_PROC_NULL) { + if (src == MPI_PROC_NULL) { if(status != MPI_STATUS_IGNORE){ simgrid::smpi::Status::empty(status); status->MPI_SOURCE = MPI_PROC_NULL; @@ -327,12 +277,6 @@ int PMPI_Recv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MPI retval = MPI_SUCCESS; } else if (src!=MPI_ANY_SOURCE && (src >= comm->group()->size() || src <0)){ retval = MPI_ERR_RANK; - } else if ((count < 0) || (buf==nullptr && count > 0)) { - retval = MPI_ERR_COUNT; - } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { - retval = MPI_ERR_TYPE; - } else if(tag<0 && tag != MPI_ANY_TAG){ - retval = MPI_ERR_TAG; } else { int my_proc_id = simgrid::s4u::this_actor::get_pid(); TRACE_smpi_comm_in(my_proc_id, __func__, @@ -362,22 +306,12 @@ int PMPI_Recv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MPI int PMPI_Send(const void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm) { - int retval = 0; + CHECK_SEND_INPUTS smpi_bench_end(); - - if (comm == MPI_COMM_NULL) { - retval = MPI_ERR_COMM; - } else if (dst == MPI_PROC_NULL) { - retval = MPI_SUCCESS; - } else if (dst >= comm->group()->size() || dst <0){ + int retval = 0; + if (dst >= comm->group()->size() || dst <0){ retval = MPI_ERR_RANK; - } else if ((count < 0) || (buf == nullptr && count > 0)) { - retval = MPI_ERR_COUNT; - } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { - retval = MPI_ERR_TYPE; - } else if(tag < 0 && tag != MPI_ANY_TAG){ - retval = MPI_ERR_TAG; } else { int my_proc_id = simgrid::s4u::this_actor::get_pid(); int dst_traced = getPid(comm, dst); @@ -406,22 +340,12 @@ int PMPI_Rsend(const void* buf, int count, MPI_Datatype datatype, int dst, int t int PMPI_Bsend(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm) { - int retval = 0; + CHECK_SEND_INPUTS smpi_bench_end(); - - if (comm == MPI_COMM_NULL) { - retval = MPI_ERR_COMM; - } else if (dst == MPI_PROC_NULL) { - retval = MPI_SUCCESS; - } else if (dst >= comm->group()->size() || dst <0){ + int retval = 0; + if (dst >= comm->group()->size() || dst <0){ retval = MPI_ERR_RANK; - } else if ((count < 0) || (buf == nullptr && count > 0)) { - retval = MPI_ERR_COUNT; - } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { - retval = MPI_ERR_TYPE; - } else if(tag < 0 && tag != MPI_ANY_TAG){ - retval = MPI_ERR_TAG; } else { int my_proc_id = simgrid::s4u::this_actor::get_pid(); int dst_traced = getPid(comm, dst); @@ -451,24 +375,12 @@ int PMPI_Bsend(const void* buf, int count, MPI_Datatype datatype, int dst, int t int PMPI_Ibsend(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request* request) { - int retval = 0; + CHECK_ISEND_INPUTS smpi_bench_end(); - if (request == nullptr) { - retval = MPI_ERR_ARG; - } else if (comm == MPI_COMM_NULL) { - retval = MPI_ERR_COMM; - } else if (dst == MPI_PROC_NULL) { - *request = MPI_REQUEST_NULL; - retval = MPI_SUCCESS; - } else if (dst >= comm->group()->size() || dst <0){ + int retval = 0; + if (dst >= comm->group()->size() || dst <0){ retval = MPI_ERR_RANK; - } else if ((count < 0) || (buf==nullptr && count > 0)) { - retval = MPI_ERR_COUNT; - } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { - retval = MPI_ERR_TYPE; - } else if(tag<0 && tag != MPI_ANY_TAG){ - retval = MPI_ERR_TAG; } else { int my_proc_id = simgrid::s4u::this_actor::get_pid(); int trace_dst = getPid(comm, dst); @@ -499,51 +411,31 @@ int PMPI_Ibsend(const void* buf, int count, MPI_Datatype datatype, int dst, int int PMPI_Bsend_init(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request* request) { - int retval = 0; + CHECK_ISEND_INPUTS smpi_bench_end(); - if (request == nullptr) { - retval = MPI_ERR_ARG; - } else if (comm == MPI_COMM_NULL) { - retval = MPI_ERR_COMM; - } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { - retval = MPI_ERR_TYPE; - } else if (dst == MPI_PROC_NULL) { - retval = MPI_SUCCESS; + int retval = 0; + int bsend_buf_size = 0; + void* bsend_buf = nullptr; + smpi_process()->bsend_buffer(&bsend_buf, &bsend_buf_size); + if( bsend_buf==nullptr || bsend_buf_size < datatype->get_extent() * count + MPI_BSEND_OVERHEAD ) { + retval = MPI_ERR_BUFFER; } else { - int bsend_buf_size = 0; - void* bsend_buf = nullptr; - smpi_process()->bsend_buffer(&bsend_buf, &bsend_buf_size); - if( bsend_buf==nullptr || bsend_buf_size < datatype->get_extent() * count + MPI_BSEND_OVERHEAD ) { - retval = MPI_ERR_BUFFER; - } else { - *request = simgrid::smpi::Request::bsend_init(buf, count, datatype, dst, tag, comm); - retval = MPI_SUCCESS; - } + *request = simgrid::smpi::Request::bsend_init(buf, count, datatype, dst, tag, comm); + retval = MPI_SUCCESS; } smpi_bench_begin(); - if (retval != MPI_SUCCESS && request != nullptr) - *request = MPI_REQUEST_NULL; return retval; } -int PMPI_Ssend(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm) { - int retval = 0; +int PMPI_Ssend(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm) +{ + CHECK_SEND_INPUTS smpi_bench_end(); - - if (comm == MPI_COMM_NULL) { - retval = MPI_ERR_COMM; - } else if (dst == MPI_PROC_NULL) { - retval = MPI_SUCCESS; - } else if (dst >= comm->group()->size() || dst <0){ + int retval = 0; + if (dst >= comm->group()->size() || dst <0){ retval = MPI_ERR_RANK; - } else if ((count < 0) || (buf==nullptr && count > 0)) { - retval = MPI_ERR_COUNT; - } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { - retval = MPI_ERR_TYPE; - } else if(tag<0 && tag != MPI_ANY_TAG){ - retval = MPI_ERR_TAG; } else { int my_proc_id = simgrid::s4u::this_actor::get_pid(); int dst_traced = getPid(comm, dst); @@ -569,12 +461,16 @@ int PMPI_Sendrecv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, int int retval = 0; smpi_bench_end(); - - if (comm == MPI_COMM_NULL) { - retval = MPI_ERR_COMM; - } else if (not sendtype->is_valid() || not recvtype->is_valid()) { - retval = MPI_ERR_TYPE; - } else if (src == MPI_PROC_NULL) { + CHECK_BUFFER(1, sendbuf, sendcount) + CHECK_COUNT(2, sendcount) + CHECK_TYPE(3, sendtype) + CHECK_TAG(5, sendtag) + CHECK_BUFFER(6, recvbuf, recvcount) + CHECK_COUNT(7, recvcount) + CHECK_TYPE(8, recvtype) + CHECK_TAG(10, recvtag) + CHECK_COMM(11) + if (src == MPI_PROC_NULL) { if(status!=MPI_STATUS_IGNORE){ simgrid::smpi::Status::empty(status); status->MPI_SOURCE = MPI_PROC_NULL; @@ -582,17 +478,12 @@ int PMPI_Sendrecv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, int if(dst != MPI_PROC_NULL) simgrid::smpi::Request::send(sendbuf, sendcount, sendtype, dst, sendtag, comm); retval = MPI_SUCCESS; - }else if (dst == MPI_PROC_NULL){ + } else if (dst == MPI_PROC_NULL){ simgrid::smpi::Request::recv(recvbuf, recvcount, recvtype, src, recvtag, comm, status); retval = MPI_SUCCESS; - }else if (dst >= comm->group()->size() || dst <0 || + } else if (dst >= comm->group()->size() || dst <0 || (src!=MPI_ANY_SOURCE && (src >= comm->group()->size() || src <0))){ retval = MPI_ERR_RANK; - } else if ((sendcount < 0 || recvcount<0) || - (sendbuf==nullptr && sendcount > 0) || (recvbuf==nullptr && recvcount>0)) { - retval = MPI_ERR_COUNT; - } else if((sendtag<0 && sendtag != MPI_ANY_TAG)||(recvtag<0 && recvtag != MPI_ANY_TAG)){ - retval = MPI_ERR_TAG; } else { int my_proc_id = simgrid::s4u::this_actor::get_pid(); int dst_traced = getPid(comm, dst); @@ -627,19 +518,17 @@ int PMPI_Sendrecv_replace(void* buf, int count, MPI_Datatype datatype, int dst, MPI_Comm comm, MPI_Status* status) { int retval = 0; - if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { - return MPI_ERR_TYPE; - } else if (count < 0) { - return MPI_ERR_COUNT; - } else { - int size = datatype->get_extent() * count; - void* recvbuf = xbt_new0(char, size); - retval = MPI_Sendrecv(buf, count, datatype, dst, sendtag, recvbuf, count, datatype, src, recvtag, comm, status); - if(retval==MPI_SUCCESS){ - simgrid::smpi::Datatype::copy(recvbuf, count, datatype, buf, count, datatype); - } - xbt_free(recvbuf); + CHECK_BUFFER(1, buf, count) + CHECK_COUNT(2, count) + CHECK_TYPE(3, datatype) + + int size = datatype->get_extent() * count; + void* recvbuf = xbt_new0(char, size); + retval = MPI_Sendrecv(buf, count, datatype, dst, sendtag, recvbuf, count, datatype, src, recvtag, comm, status); + if(retval==MPI_SUCCESS){ + simgrid::smpi::Datatype::copy(recvbuf, count, datatype, buf, count, datatype); } + xbt_free(recvbuf); return retval; } @@ -659,7 +548,6 @@ int PMPI_Test(MPI_Request * request, int *flag, MPI_Status * status) int my_proc_id = ((*request)->comm() != MPI_COMM_NULL) ? simgrid::s4u::this_actor::get_pid() : -1; TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("test")); - retval = simgrid::smpi::Request::test(request,status, flag); TRACE_smpi_comm_out(my_proc_id); @@ -671,7 +559,7 @@ int PMPI_Test(MPI_Request * request, int *flag, MPI_Status * status) int PMPI_Testany(int count, MPI_Request requests[], int *index, int *flag, MPI_Status * status) { int retval = 0; - + CHECK_COUNT(1, count) smpi_bench_end(); if (index == nullptr || flag == nullptr) { retval = MPI_ERR_ARG; @@ -688,7 +576,7 @@ int PMPI_Testany(int count, MPI_Request requests[], int *index, int *flag, MPI_S int PMPI_Testall(int count, MPI_Request* requests, int* flag, MPI_Status* statuses) { int retval = 0; - + CHECK_COUNT(1, count) smpi_bench_end(); if (flag == nullptr) { retval = MPI_ERR_ARG; @@ -705,7 +593,7 @@ int PMPI_Testall(int count, MPI_Request* requests, int* flag, MPI_Status* status int PMPI_Testsome(int incount, MPI_Request requests[], int* outcount, int* indices, MPI_Status status[]) { int retval = 0; - + CHECK_COUNT(1, incount) smpi_bench_end(); if (outcount == nullptr) { retval = MPI_ERR_ARG; @@ -723,9 +611,9 @@ int PMPI_Probe(int source, int tag, MPI_Comm comm, MPI_Status* status) { int retval = 0; smpi_bench_end(); - if (comm == MPI_COMM_NULL) { - retval = MPI_ERR_COMM; - } else if (source == MPI_PROC_NULL) { + CHECK_COMM(6) + CHECK_TAG(2, tag) + if (source == MPI_PROC_NULL) { if (status != MPI_STATUS_IGNORE){ simgrid::smpi::Status::empty(status); status->MPI_SOURCE = MPI_PROC_NULL; @@ -742,11 +630,10 @@ int PMPI_Probe(int source, int tag, MPI_Comm comm, MPI_Status* status) { int PMPI_Iprobe(int source, int tag, MPI_Comm comm, int* flag, MPI_Status* status) { int retval = 0; smpi_bench_end(); - + CHECK_COMM(6) + CHECK_TAG(2, tag) if (flag == nullptr) { retval = MPI_ERR_ARG; - } else if (comm == MPI_COMM_NULL) { - retval = MPI_ERR_COMM; } else if (source == MPI_PROC_NULL) { *flag=true; if (status != MPI_STATUS_IGNORE){ @@ -787,9 +674,8 @@ int PMPI_Wait(MPI_Request * request, MPI_Status * status) simgrid::smpi::Status::empty(status); - if (request == nullptr) { - retval = MPI_ERR_ARG; - } else if (*request == MPI_REQUEST_NULL) { + CHECK_REQUEST(1) + if (*request == MPI_REQUEST_NULL) { retval = MPI_SUCCESS; } else { // for tracing, save the handle which might get overridden before we can use the helper on it @@ -858,7 +744,7 @@ int PMPI_Waitany(int count, MPI_Request requests[], int *index, MPI_Status * sta int PMPI_Waitall(int count, MPI_Request requests[], MPI_Status status[]) { smpi_bench_end(); - + CHECK_COUNT(1, count) // for tracing, save the handles which might get overridden before we can use the helper on it std::vector savedreqs(requests, requests + count); for (MPI_Request& req : savedreqs) { @@ -889,7 +775,7 @@ int PMPI_Waitall(int count, MPI_Request requests[], MPI_Status status[]) int PMPI_Waitsome(int incount, MPI_Request requests[], int *outcount, int *indices, MPI_Status status[]) { int retval = 0; - + CHECK_COUNT(1, incount) smpi_bench_end(); if (outcount == nullptr) { retval = MPI_ERR_ARG; @@ -906,6 +792,7 @@ int PMPI_Cancel(MPI_Request* request) int retval = 0; smpi_bench_end(); + CHECK_REQUEST(1) if (*request == MPI_REQUEST_NULL) { retval = MPI_ERR_REQUEST; } else { diff --git a/src/smpi/include/private.hpp b/src/smpi/include/private.hpp index 5238bb508f..f8cb15f449 100644 --- a/src/smpi/include/private.hpp +++ b/src/smpi/include/private.hpp @@ -498,4 +498,45 @@ XBT_PUBLIC smpi_trace_call_location_t* smpi_trace_get_call_location(); XBT_PRIVATE void private_execute_flops(double flops); + +#define CHECK_ARGS(test, errcode, ...) \ + if (test) { \ + XBT_WARN(__VA_ARGS__); \ + return (errcode); \ + } + +#define CHECK_COMM(num) \ + CHECK_ARGS(comm == MPI_COMM_NULL, MPI_ERR_COMM, \ + "%s: param %d communicator cannot be MPI_COMM_NULL", __func__, num); +#define CHECK_REQUEST(num) \ + CHECK_ARGS(request == nullptr, MPI_ERR_REQUEST, \ + "%s: param %d request cannot be NULL",__func__, num); +#define CHECK_BUFFER(num,buf,count) \ + CHECK_ARGS(buf == nullptr && count > 0, MPI_ERR_BUFFER, \ + "%s: param %d %s cannot be NULL if %s > 0",__func__, num, #buf, #count); +#define CHECK_COUNT(num,count) \ + CHECK_ARGS(count < 0, MPI_ERR_COUNT, \ + "%s: param %d %s cannot be negative", __func__, num, #count); +#define CHECK_TYPE(num, datatype) \ + CHECK_ARGS((datatype == MPI_DATATYPE_NULL|| not datatype->is_valid()), MPI_ERR_TYPE, \ + "%s: param %d %s cannot be MPI_DATATYPE_NULL or invalid", __func__, num, #datatype); +#define CHECK_OP(num) \ + CHECK_ARGS(op == MPI_OP_NULL, MPI_ERR_OP, \ + "%s: param %d op cannot be MPI_OP_NULL or invalid", __func__, num); +#define CHECK_ROOT(num)\ + CHECK_ARGS((root < 0 || root >= comm->size()), MPI_ERR_ROOT, \ + "%s: param %d root (=%d) cannot be negative or larger than communicator size (=%d)", __func__, num, root, \ + comm->size()); +#define CHECK_NULL(num,err,buf) \ + CHECK_ARGS(buf == nullptr, err, \ + "%s: param %d %s cannot be NULL", __func__, num, #buf); +#define CHECK_PROC(num,proc) \ + CHECK_ARGS(proc == MPI_PROC_NULL, MPI_SUCCESS, \ + "%s: param %d %s cannot be MPI_PROC_NULL", __func__, num, #proc); +#define CHECK_TAG(num,tag) \ + CHECK_ARGS((tag<0 && tag != MPI_ANY_TAG), MPI_ERR_TAG, \ + "%s: param %d %s cannot be negative", __func__, num, #tag); +#define CHECK_FILE(num, fh) \ + CHECK_ARGS(fh == MPI_FILE_NULL, MPI_ERR_FILE, \ + "%s: param %d %s cannot be MPI_PROC_NULL", __func__, num, #fh); #endif -- 2.20.1