Program Listing for File Gatherer.hpp

Return to documentation for file (include/uitsl/concurrent/Gatherer.hpp)

#pragma once
#ifndef UITSL_CONCURRENT_GATHERER_HPP_INCLUDE
#define UITSL_CONCURRENT_GATHERER_HPP_INCLUDE

#include <algorithm>
#include <iterator>
#include <numeric>
#include <thread>
#include <utility>

#include <mpi.h>

#include "../../../third-party/Empirical/source/base/optional.h"
#include "../../../third-party/Empirical/source/base/vector.h"

#include "../containers/safe/deque.hpp"
#include "../mpi/audited_routines.hpp"
#include "../mpi/mpi_utils.hpp"

namespace uitsl {

/**
 * Accumulates items locally on each process (via Put) and collects them onto
 * a single root process with collective MPI gather operations (via Gather).
 *
 * Gather performs collective calls on `comm`, so every process in `comm` must
 * call it together with the same `root`.
 *
 * @tparam T item type; `mpi_type` must be an MPI datatype describing T.
 */
template<typename T>
class Gatherer {

  // items contributed by this process through Put()
  uitsl::safe::deque<T> items;

  // TODO use template metaprogramming to automatically deduce this
  MPI_Datatype mpi_type;
  MPI_Comm comm;

  /**
   * Collect the number of locally held items of every process onto `root`.
   *
   * @param root rank (within `comm`) that receives the counts.
   * @return on root, element i is the item count contributed by rank i of
   *   `comm`; on other processes the buffer is not written by the gather
   *   and stays zero-initialized.
   */
  emp::vector<int> GatherCounts(const int root) {

    const int count = static_cast<int>( items.size() );

    // size the receive buffer by the communicator the gather actually runs
    // on -- sizing it by MPI_COMM_WORLD could under-allocate when `comm`
    // is larger than the world communicator
    emp::vector<int> res( get_nprocs(comm) );

    UITSL_Gather(
      &count, // const void *sendbuf,
      1, // int sendcount,
      MPI_INT, // MPI_Datatype sendtype,
      res.data(), // void *recvbuf,
      1, // int recvcount,
      MPI_INT, // MPI_Datatype recvtype,
      root, // int root,
      comm // MPI_Comm comm
    );

    return res;

  }

public:

  /**
   * @param mpi_type_ MPI datatype used to transmit items of type T.
   * @param comm_ communicator over which items are gathered.
   */
  Gatherer(
    MPI_Datatype mpi_type_,
    MPI_Comm comm_=MPI_COMM_WORLD
  )
  : mpi_type(mpi_type_)
  , comm(comm_)
  { ; }

  /// Queue an item to be contributed by this process at the next Gather().
  void Put(const T& item) { items.push_back(item); }

  /**
   * Collectively gather every process's Put() items onto `root`.
   *
   * Items are ordered by contributing rank, and within a rank by the order
   * they were Put(). Locally held items are not removed, so a later call
   * contributes them again.
   *
   * @param root rank (within `comm`) that receives the gathered items.
   * @return on root, all gathered items; on every other process, nullopt.
   */
  emp::optional<emp::vector<T>> Gather(const int root=0) {

    // only meaningful on root; MPI ignores receive-side args elsewhere
    const emp::vector<int> counts = GatherCounts(root);

    // displacements[i] is where rank i's contribution starts in the receive
    // buffer; the trailing element is the total number of items
    emp::vector<int> displacements{0};
    std::partial_sum(
      std::begin(counts),
      std::end(counts),
      std::back_inserter(displacements)
    );

    // initialize buffer to hold contributed items from all processes
    const size_t num_items = static_cast<size_t>( displacements.back() );
    emp::vector<T> res(num_items);

    // copy queued items into contiguous storage for MPI
    emp::vector<T> send_buffer( std::begin(items), std::end(items) );

    // do gather, contributed items are only delivered to root process
    UITSL_Gatherv(
      send_buffer.data(), // const void *sendbuf
      static_cast<int>( send_buffer.size() ), // int sendcount
      mpi_type, // MPI_Datatype sendtype
      res.data(), // void *recvbuf
      counts.data(), // const int *recvcounts
      displacements.data(), // const int *displs
      mpi_type, // MPI_Datatype recvtype
      root, // int root
      comm // MPI_Comm comm
    );

    // only the root process holds gathered items
    if ( get_rank(comm) != root ) return std::nullopt;
    return emp::optional<emp::vector<T>>{ std::move(res) };

  }

};

} // namespace uitsl

#endif // #ifndef UITSL_CONCURRENT_GATHERER_HPP_INCLUDE