// Program Listing for File RdmaWindowManager.hpp
// Return to documentation for file (include/uitsl/distributed/RdmaWindowManager.hpp)
#pragma once
#ifndef UITSL_DISTRIBUTED_RDMAWINDOWMANAGER_HPP_INCLUDE
#define UITSL_DISTRIBUTED_RDMAWINDOWMANAGER_HPP_INCLUDE
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <mutex>
#include <set>
#include <sstream>
#include <stddef.h>
#include <string>
#include <thread>
#include <unordered_map>
#include <mpi.h>
#include "../mpi/group_utils.hpp"
#include "../mpi/mpi_utils.hpp"
#include "RdmaWindow.hpp"
namespace uitsl {
// TODO is it possible to have a separate window/communicator
// between each pair of procs?
class RdmaWindowManager {
std::unordered_map<proc_id_t, RdmaWindow> windows{};
mutable std::mutex mutex;
std::set<proc_id_t> GetSortedRanks() {
std::set<proc_id_t> res;
std::transform(
std::begin(windows),
std::end(windows),
std::inserter(res, std::begin(res)),
[](const auto & kv_pair){ return kv_pair.first; }
);
return res;
}
bool IsInitialized() {
return std::any_of(
std::begin(windows),
std::end(windows),
[](const auto & key_value){
const auto & [rank, window] = key_value;
return window.IsInitialized();
}
);
}
public:
~RdmaWindowManager() {
// sort ranks to prevent deadlock
for (proc_id_t rank : GetSortedRanks()) {
windows.erase(rank);
}
}
// TODO cache line alignment?
size_t Acquire(
const proc_id_t rank,
const emp::vector<std::byte>& initial_bytes
) {
// make this call thread safe
const std::lock_guard guard{mutex};
emp_assert( !IsInitialized() );
return windows[rank].Acquire(initial_bytes);
}
std::byte *GetBytes(const proc_id_t rank, const size_t byte_offset) {
emp_assert( IsInitialized() );
emp_assert( windows.count(rank) );
return windows.at(rank).GetBytes(byte_offset);
}
const MPI_Win& GetWindow(const proc_id_t rank) {
emp_assert( IsInitialized() );
emp_assert( windows.count(rank) );
return windows.at(rank).GetWindow();
}
void LockExclusive(const proc_id_t rank) {
emp_assert( IsInitialized() );
emp_assert( windows.count(rank) );
return windows.at(rank).LockExclusive();
}
void LockShared(const proc_id_t rank) {
emp_assert( IsInitialized() );
emp_assert( windows.count(rank) );
return windows.at(rank).LockShared();
}
void Unlock(const proc_id_t rank) {
emp_assert( IsInitialized() );
emp_assert( windows.count(rank) );
return windows.at(rank).Unlock();
}
void Put(
const proc_id_t rank,
const std::byte *origin_addr,
const size_t num_bytes,
const MPI_Aint target_disp
) {
emp_assert( IsInitialized() );
emp_assert( windows.count(rank) );
windows.at(rank).Put(origin_addr, num_bytes, target_disp);
}
void Rput(
const proc_id_t rank,
const std::byte *origin_addr,
const size_t num_bytes,
const MPI_Aint target_disp,
MPI_Request *request
) {
emp_assert( IsInitialized() );
emp_assert( windows.count(rank) );
emp_assert( uitsl::test_null(*request) );
windows.at(rank).Rput(origin_addr, num_bytes, target_disp, request);
}
template<typename T>
void Accumulate(
const proc_id_t rank,
const std::byte *origin_addr,
const size_t num_bytes,
const MPI_Aint target_disp
) {
emp_assert( IsInitialized() );
emp_assert( windows.count(rank) );
windows.at(rank).Accumulate<T>(origin_addr, num_bytes, target_disp);
}
template<typename T>
void Raccumulate(
const proc_id_t rank,
const std::byte *origin_addr,
const size_t num_bytes,
const MPI_Aint target_disp,
MPI_Request *request
) {
emp_assert( IsInitialized() );
emp_assert( windows.count(rank) );
emp_assert( uitsl::test_null(*request) );
windows.at(
rank
).Raccumulate<T>(origin_addr, num_bytes, target_disp, request);
}
void Initialize(MPI_Comm comm=MPI_COMM_WORLD) {
emp_assert(!IsInitialized());
// sort ranks to prevent deadlock
for (proc_id_t rank : GetSortedRanks()) {
MPI_Comm dyad{
uitsl::group_to_comm(
uitsl::make_group(
{rank, uitsl::get_rank(comm)},
uitsl::comm_to_group(comm)
),
comm
)
};
windows.at(rank).Initialize(
uitsl::translate_comm_rank(rank, comm, dyad),
dyad
);
}
// ensure that RputDucts have received target offsets
UITSL_Barrier(comm);
emp_assert(windows.empty() || IsInitialized());
}
std::string ToString() {
std::stringstream ss;
ss << uitsl::format_member("windows.size()", windows.size()) << std::endl;
for (proc_id_t rank : GetSortedRanks()) {
ss << uitsl::format_member("rank", rank) << std::endl;
ss << uitsl::format_member("window", windows.at(rank).ToString()) << std::endl;
}
return ss.str();
}
};
} // namespace uitsl
#endif // #ifndef UITSL_DISTRIBUTED_RDMAWINDOWMANAGER_HPP_INCLUDE