Program Listing for File Mesh.hpp

Return to documentation for file (include/netuit/mesh/Mesh.hpp)

#pragma once
#ifndef NETUIT_MESH_MESH_HPP_INCLUDE
#define NETUIT_MESH_MESH_HPP_INCLUDE

#include <algorithm>
#include <cassert>
#include <functional>
#include <memory>
#include <ratio>
#include <sstream>
#include <stddef.h>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <mpi.h>

#include "../../uitsl/debug/safe_cast.hpp"
#include "../../uitsl/math/math_utils.hpp"
#include "../../uitsl/mpi/mpi_init_utils.hpp"
#include "../../uitsl/parallel/thread_utils.hpp"
#include "../../uitsl/utility/assign_utils.hpp"

#include "../../uit/ducts/Duct.hpp"
#include "../../uit/setup/InterProcAddress.hpp"
#include "../../uit/spouts/wrappers/impl/RoundTripCounterAddr.hpp"
#include "../../uit/spouts/wrappers/impl/round_trip_touch_counter.hpp"

#include "../assign/AssignIntegrated.hpp"
#include "../topology/Topology.hpp"

#include "MeshNode.hpp"
#include "MeshTopology.hpp"

namespace netuit {

namespace internal {

/// Hands out sequential mesh ids, starting from zero.
/// The counter is a plain (unsynchronized) size_t shared program-wide.
class MeshIDCounter {

  // id that the next call to Generate() will return
  static inline size_t counter = 0;

public:

  /// Returns the current counter value, then advances the counter by one.
  static size_t Generate() {
    const size_t issued = counter;
    ++counter;
    return issued;
  }

  /// Sets the counter back to zero.
  static void Reset() { counter = size_t{}; }

  /// Returns the current counter value without advancing it.
  static size_t Get() { return counter; }

};

} // namespace internal

template<typename ImplSpec>
class Mesh {

  using node_id_t = size_t;
  using edge_id_t = size_t;
  using node_t = MeshNode<ImplSpec>;

  size_t mesh_id;
  MPI_Comm comm;

  // node_id -> node
  internal::MeshTopology<ImplSpec> nodes;

  std::function<uitsl::thread_id_t(node_id_t)> thread_assignment;
  std::function<uitsl::proc_id_t(node_id_t)> proc_assignment;

  using back_end_t = typename ImplSpec::ProcBackEnd;
  std::shared_ptr<back_end_t> back_end;

  void InitializeInterThreadDucts() {
    for (auto& [node_id, node] : nodes) {
      InitializeInterThreadDucts(node_id, node);
    }
  }

  void InitializeInterThreadDucts(const node_id_t node_id, node_t & node) {
    // only need to iterate through inputs because this fixes outputs' ducts too
    for (auto& input : node.GetInputs()) InitializeInterThreadDuct(input);
  }

  void InitializeInterThreadDuct(netuit::MeshNodeInput<ImplSpec> & input) {

    const node_id_t inlet_node_id = nodes.GetOutputRegistry().at(
      input.GetEdgeID()
    );
    const uitsl::thread_id_t inlet_thread = thread_assignment(inlet_node_id);

    const node_id_t outlet_node_id = nodes.GetInputRegistry().at(
      input.GetEdgeID()
    );
    const uitsl::thread_id_t outlet_thread = thread_assignment(outlet_node_id);

    if (inlet_thread != outlet_thread) input.template EmplaceDuct<
      typename ImplSpec::ThreadDuct
    >();

  }

  void InitializeInterProcDucts() {
    for (auto& [node_id, node] : nodes) {
      InitializeInterProcDucts(node_id, node);
    }
  }

  void InitializeInterProcDucts(const node_id_t node_id, node_t& node) {

    for (auto & input : node.GetInputs()) InitializeInterProcDuct(input);

    for (auto & output : node.GetOutputs()) InitializeInterProcDuct(output);

  }

  void InitializeInterProcDuct(netuit::MeshNodeInput<ImplSpec>& input) {
    const node_id_t inlet_node_id = nodes.GetOutputRegistry().at(
      input.GetEdgeID()
    );
    const uitsl::proc_id_t inlet_proc_id = proc_assignment(inlet_node_id);

    const node_id_t outlet_node_id = nodes.GetInputRegistry().at(
      input.GetEdgeID()
    );
    const uitsl::proc_id_t outlet_proc_id = proc_assignment(outlet_node_id);

    static std::unordered_set<int> tag_checker;
    const int tag = uitsl::safe_cast<int>(
      uitsl::sidebyside_hash<std::ratio<3, 4>>(mesh_id, input.GetEdgeID())
    );

    const uit::InterProcAddress addr{
      outlet_proc_id,
      inlet_proc_id,
      thread_assignment(outlet_node_id),
      thread_assignment(inlet_node_id),
      tag,
      comm
    };

    if (inlet_proc_id != outlet_proc_id) {
      input.template SplitDuct<
        typename ImplSpec::ProcOutletDuct
      >(addr, back_end);
      // assert that generated tags are unique
      assert( tag_checker.insert(tag).second );
    }

  }

  void InitializeInterProcDuct(netuit::MeshNodeOutput<ImplSpec>& output) {
    const node_id_t inlet_node_id = nodes.GetOutputRegistry().at(
      output.GetEdgeID()
    );
    const uitsl::proc_id_t inlet_proc_id = proc_assignment(inlet_node_id);

    const node_id_t outlet_node_id = nodes.GetInputRegistry().at(
      output.GetEdgeID()
    );
    const uitsl::proc_id_t outlet_proc_id = proc_assignment(outlet_node_id);

    const uit::InterProcAddress addr{
      outlet_proc_id,
      inlet_proc_id,
      thread_assignment(outlet_node_id),
      thread_assignment(inlet_node_id),
      uitsl::safe_cast<int>(
        uitsl::sidebyside_hash<std::ratio<3, 4>>(mesh_id, output.GetEdgeID())
      ),
      comm
    };

    if (inlet_proc_id != outlet_proc_id) output.template SplitDuct<
      typename ImplSpec::ProcInletDuct
    >(addr, back_end);

  }

  // solely for instrumentation purposes
  void RegisterDuctTargets() {
    for (auto& [node_id, node] : nodes) RegisterDuctTargets(node_id, node);
  }

  // solely for instrumentation purposes
  void RegisterDuctTargets(const node_id_t node_id, node_t& node) {
    for (const auto & input : node.GetInputs()) RegisterDuctTarget(input);
    for (const auto & output : node.GetOutputs()) RegisterDuctTarget(output);
  }

  // solely for instrumentation purposes
  void RegisterDuctTarget(const netuit::MeshNodeOutput<ImplSpec>& output) {
    output.RegisterEdgeID( output.GetEdgeID() );
    output.RegisterMeshID( mesh_id );
    {
      const node_id_t inlet_node_id = nodes.GetOutputRegistry().at(
        output.GetEdgeID()
      );
      const uitsl::proc_id_t inlet_proc_id = proc_assignment(inlet_node_id);
      const uitsl::thread_id_t inlet_thread_id = thread_assignment(
        inlet_node_id
      );

      output.RegisterInletProc( inlet_proc_id );
      output.RegisterInletThread( inlet_thread_id );
      output.RegisterInletNodeID( inlet_node_id );
    }

    {
      const node_id_t outlet_node_id = nodes.GetInputRegistry().at(
        output.GetEdgeID()
      );
      const uitsl::proc_id_t outlet_proc_id = proc_assignment(outlet_node_id);
      const uitsl::thread_id_t outlet_thread_id = thread_assignment(
        outlet_node_id
      );

      output.RegisterOutletProc( outlet_proc_id );
      output.RegisterOutletThread( outlet_thread_id );
      output.RegisterOutletNodeID( outlet_node_id );
    }

    {
      const auto addr = uit::impl::RoundTripCounterAddr{
        mesh_id,
        nodes.GetOutputRegistry().at( output.GetEdgeID() ),
        nodes.GetInputRegistry().at( output.GetEdgeID() )
      };
      uit::impl::round_trip_touch_counter[ addr ];
      assert( uit::impl::round_trip_touch_counter.count(addr) == 1 );
    }
  }

  // solely for instrumentation purposes
  void RegisterDuctTarget(const netuit::MeshNodeInput<ImplSpec>& input) {
    input.RegisterEdgeID( input.GetEdgeID() );
    input.RegisterMeshID( mesh_id );
    {
      const node_id_t inlet_node_id = nodes.GetOutputRegistry().at(
        input.GetEdgeID()
      );
      const uitsl::proc_id_t inlet_proc_id = proc_assignment(inlet_node_id);
      const uitsl::thread_id_t inlet_thread_id = thread_assignment(
          inlet_node_id
      );

      input.RegisterInletProc( inlet_proc_id );
      input.RegisterInletThread( inlet_thread_id );
      input.RegisterInletNodeID( inlet_node_id );
    }

    {
      const node_id_t outlet_node_id = nodes.GetInputRegistry().at(
        input.GetEdgeID()
      );
      const uitsl::proc_id_t outlet_proc_id = proc_assignment(outlet_node_id);
      const uitsl::thread_id_t outlet_thread_id = thread_assignment(
          outlet_node_id
      );

      input.RegisterOutletProc( outlet_proc_id );
      input.RegisterOutletThread( outlet_thread_id );
      input.RegisterOutletNodeID( outlet_node_id );
    }

    {
      const auto addr = uit::impl::RoundTripCounterAddr{
        mesh_id,
        nodes.GetInputRegistry().at( input.GetEdgeID() ),
        nodes.GetOutputRegistry().at( input.GetEdgeID() )
      };
      uit::impl::round_trip_touch_counter[ addr ];
      assert( uit::impl::round_trip_touch_counter.count(addr) == 1 );
    }
  }


public:

  Mesh(
    const Topology & topology,
    const std::function<uitsl::thread_id_t(node_id_t)> thread_assignment_
      =uitsl::AssignIntegrated<uitsl::thread_id_t>{},
    const std::function<uitsl::proc_id_t(node_id_t)> proc_assignment_
      =uitsl::AssignIntegrated<uitsl::proc_id_t>{},
    std::shared_ptr<back_end_t> back_end_=std::make_shared<back_end_t>(),
    const MPI_Comm comm_=MPI_COMM_WORLD,
    const size_t mesh_id_=internal::MeshIDCounter::Generate()
  )
  : mesh_id(mesh_id_)
  , comm(comm_)
  , nodes(topology, proc_assignment_, comm)
  , thread_assignment(thread_assignment_)
  , proc_assignment(proc_assignment_)
  , back_end(back_end_) {
    InitializeInterThreadDucts();
    InitializeInterProcDucts();
    RegisterDuctTargets();
    back_end->Initialize();
  }

  // TODO rename GetNumNodes
  size_t GetNodeCount() const { return nodes.GetNodeCount(); }

  // TODO rename GetNumEdges
  size_t GetEdgeCount() const { return nodes.GetEdgeCount(); }

  using submesh_t = std::vector<node_t>;

  submesh_t GetSubmesh(const uitsl::thread_id_t tid=0) const {
    return GetSubmesh(tid, uitsl::get_proc_id(comm));
  }

  submesh_t GetSubmesh(
    const uitsl::thread_id_t tid,
    const uitsl::proc_id_t pid
  ) const {
    submesh_t res;
    for (const auto& [node_id, node] : nodes) {
      if (
        thread_assignment(node_id) == tid
        && proc_assignment(node_id) == pid
      ) res.push_back(node);
    }
    return res;
  }

  std::string ToString() const {
    std::stringstream ss;
    for (const auto& [node_id, node] : nodes) {
      ss << uitsl::format_member(
        "node id", node_id
      );
      ss << uitsl::format_member(
        "proc assignment", proc_assignment(node_id)
      );
      ss << uitsl::format_member(
        "thread assignment", thread_assignment(node_id)
      );
      ss << uitsl::format_member(
        "node", node.ToString()
      );
      ss << '\n';
    }
    return ss.str();
  }

};

} // namespace netuit

#endif // #ifndef NETUIT_MESH_MESH_HPP_INCLUDE