Program Listing for File Mesh.hpp

Return to documentation for file (include/netuit/mesh/Mesh.hpp)

#pragma once
#ifndef NETUIT_MESH_MESH_HPP_INCLUDE
#define NETUIT_MESH_MESH_HPP_INCLUDE

#include <stddef.h>

#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include <mpi.h>

#include "../../uitsl/debug/safe_cast.hpp"
#include "../../uitsl/math/math_utils.hpp"
#include "../../uitsl/mpi/mpi_utils.hpp"
#include "../../uitsl/utility/assign_utils.hpp"

#include "../../uit/ducts/Duct.hpp"
#include "../../uit/setup/InterProcAddress.hpp"

#include "../assign/AssignIntegrated.hpp"
#include "../topology/Topology.hpp"

#include "MeshNode.hpp"
#include "MeshTopology.hpp"

namespace netuit {

namespace internal {

// Source of default mesh identifiers: each call to Generate() yields the next
// value of a single shared counter, starting from zero.
// The increment is a plain (non-atomic) operation on a static variable.
class MeshIDCounter {

  // next identifier to hand out
  inline static size_t counter = 0;

public:

  // Returns the identifier currently due, then advances the counter by one.
  static size_t Generate() {
    const size_t issued = counter;
    ++counter;
    return issued;
  }

};

} // namespace internal

template<typename ImplSpec>
class Mesh {

  using node_id_t = size_t;
  using edge_id_t = size_t;
  using node_t = MeshNode<ImplSpec>;

  size_t mesh_id;
  MPI_Comm comm;

  // node_id -> node
  internal::MeshTopology<ImplSpec> nodes;

  std::function<uitsl::thread_id_t(node_id_t)> thread_assignment;
  std::function<uitsl::proc_id_t(node_id_t)> proc_assignment;

  using back_end_t = typename ImplSpec::ProcBackEnd;
  std::shared_ptr<back_end_t> back_end;

  void InitializeInterThreadDucts() {
    for (auto& [node_id, node] : nodes) {
      InitializeInterThreadDucts(node_id, node);
    }
  }

  void InitializeInterThreadDucts(const node_id_t node_id, node_t & node) {
    // only need to iterate through inputs because this fixes outputs' ducts too
    for (auto& input : node.GetInputs()) InitializeInterThreadDuct(input);
  }

  void InitializeInterThreadDuct(netuit::MeshNodeInput<ImplSpec> & input) {

    const node_id_t inlet_node_id = nodes.GetOutputRegistry().at(
      input.GetEdgeID()
    );
    const uitsl::thread_id_t inlet_thread = thread_assignment(inlet_node_id);

    const node_id_t outlet_node_id = nodes.GetInputRegistry().at(
      input.GetEdgeID()
    );
    const uitsl::thread_id_t outlet_thread = thread_assignment(outlet_node_id);

    if (inlet_thread != outlet_thread) input.template EmplaceDuct<
      typename ImplSpec::ThreadDuct
    >();

  }

  void InitializeInterProcDucts() {
    for (auto& [node_id, node] : nodes) {
      InitializeInterProcDucts(node_id, node);
    }
  }

  void InitializeInterProcDucts(const node_id_t node_id, node_t& node) {

    for (auto & input : node.GetInputs()) InitializeInterProcDuct(input);

    for (auto & output : node.GetOutputs()) InitializeInterProcDuct(output);

  }

  void InitializeInterProcDuct(netuit::MeshNodeInput<ImplSpec>& input) {
    const node_id_t inlet_node_id = nodes.GetOutputRegistry().at(
      input.GetEdgeID()
    );
    const uitsl::proc_id_t inlet_proc_id = proc_assignment(inlet_node_id);

    const node_id_t outlet_node_id = nodes.GetInputRegistry().at(
      input.GetEdgeID()
    );
    const uitsl::proc_id_t outlet_proc_id = proc_assignment(outlet_node_id);

    static std::unordered_set<int> tag_checker;
    const int tag = uitsl::safe_cast<int>(
      uitsl::sidebyside_hash(mesh_id, input.GetEdgeID())
    );

    // assert that generated tags are unique
    emp_assert( tag_checker.insert(tag).second );

    const uit::InterProcAddress addr{
      outlet_proc_id,
      inlet_proc_id,
      thread_assignment(outlet_node_id),
      thread_assignment(inlet_node_id),
      tag,
      comm
    };

    if (inlet_proc_id != outlet_proc_id) input.template SplitDuct<
      typename ImplSpec::ProcOutletDuct
    >(addr, back_end);

  }

  void InitializeInterProcDuct(netuit::MeshNodeOutput<ImplSpec>& output) {
    const node_id_t inlet_node_id = nodes.GetOutputRegistry().at(
      output.GetEdgeID()
    );
    const uitsl::proc_id_t inlet_proc_id = proc_assignment(inlet_node_id);

    const node_id_t outlet_node_id = nodes.GetInputRegistry().at(
      output.GetEdgeID()
    );
    const uitsl::proc_id_t outlet_proc_id = proc_assignment(outlet_node_id);

    const uit::InterProcAddress addr{
      outlet_proc_id,
      inlet_proc_id,
      thread_assignment(outlet_node_id),
      thread_assignment(inlet_node_id),
      uitsl::safe_cast<int>(
        uitsl::sidebyside_hash(mesh_id, output.GetEdgeID())
      ),
      comm
    };

    if (inlet_proc_id != outlet_proc_id) output.template SplitDuct<
      typename ImplSpec::ProcInletDuct
    >(addr, back_end);


  }


public:

  Mesh(
    const Topology & topology,
    const std::function<uitsl::thread_id_t(node_id_t)> thread_assignment_
      =uitsl::AssignIntegrated<uitsl::thread_id_t>{},
    const std::function<uitsl::proc_id_t(node_id_t)> proc_assignment_
      =uitsl::AssignIntegrated<uitsl::proc_id_t>{},
    std::shared_ptr<back_end_t> back_end_=std::make_shared<back_end_t>(),
    const MPI_Comm comm_=MPI_COMM_WORLD,
    const size_t mesh_id_=internal::MeshIDCounter::Generate()
  )
  : mesh_id(mesh_id_)
  , comm(comm_)
  , nodes(topology, proc_assignment_, comm)
  , thread_assignment(thread_assignment_)
  , proc_assignment(proc_assignment_)
  , back_end(back_end_) {
    InitializeInterThreadDucts();
    InitializeInterProcDucts();
    back_end->Initialize();
  }

  // TODO rename GetNumNodes
  size_t GetNodeCount() const { return nodes.GetNodeCount(); }

  // TODO rename GetNumEdges
  size_t GetEdgeCount() const { return nodes.GetEdgeCount(); }

  using submesh_t = emp::vector<node_t>;

  submesh_t GetSubmesh(const uitsl::thread_id_t tid=0) const {
    return GetSubmesh(tid, uitsl::get_proc_id(comm));
  }

  submesh_t GetSubmesh(
    const uitsl::thread_id_t tid,
    const uitsl::proc_id_t pid
  ) const {
    submesh_t res;
    for (const auto& [node_id, node] : nodes) {
      if (
        thread_assignment(node_id) == tid
        && proc_assignment(node_id) == pid
      ) res.push_back(node);
    }
    return res;
  }

  std::string ToString() const {
    std::stringstream ss;
    ss << "TODO" << std::endl;
    return ss.str();
  }

};

} // namespace netuit

#endif // #ifndef NETUIT_MESH_MESH_HPP_INCLUDE