Program Listing for File MeshTopology.hpp

Return to documentation for file (include/netuit/mesh/MeshTopology.hpp)

#pragma once
#ifndef NETUIT_MESH_MESHTOPOLOGY_HPP_INCLUDE
#define NETUIT_MESH_MESHTOPOLOGY_HPP_INCLUDE

#include <algorithm>
#include <functional>
#include <iterator>
#include <set>
#include <sstream>
#include <stddef.h>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>

#include <mpi.h>

#include "../../uitsl/mpi/mpi_utils.hpp"
#include "../../uitsl/utility/assign_utils.hpp"

#include "../../uit/ducts/Duct.hpp"
#include "../../uit/fixtures/Conduit.hpp"

#include "../topology/Topology.hpp"

#include "MeshNode.hpp"

namespace netuit {
namespace internal {

template<typename ImplSpec>
class MeshTopology {

  using node_id_t = size_t;
  using edge_id_t = size_t;
  using node_t = MeshNode<ImplSpec>;
  using node_lookup_t = std::unordered_map<node_id_t, node_t>;

  // node_id -> node
  node_lookup_t nodes;

  // ordered by edge_id
  std::set<edge_id_t> edge_registry;
  // edge_id -> node_id
  std::unordered_map<edge_id_t, node_id_t> input_registry;
  std::unordered_map<edge_id_t, node_id_t> output_registry;

  void InitializeRegistries(const netuit::Topology& topology) {
    for (node_id_t node_id = 0; node_id < topology.GetSize(); ++node_id) {
      const netuit::TopoNode& topo_node = topology[node_id];
      RegisterNodeInputs(node_id, topo_node);
      RegisterNodeOutputs(node_id, topo_node);
    }
  }

  void RegisterNodeInputs(
    const node_id_t node_id, const netuit::TopoNode& topo_node
  ) {
    for (const netuit::TopoNodeInput & input : topo_node.GetInputs()) {
      emp_assert(input_registry.count(input.GetEdgeID()) == 0);
      edge_registry.insert(input.GetEdgeID());
      input_registry[input.GetEdgeID()] = node_id;
    }
  }

  void RegisterNodeOutputs(
    const node_id_t node_id, const netuit::TopoNode& topo_node
  ) {
    for (const netuit::TopoNodeOutput& output : topo_node.GetOutputs()) {
      emp_assert(output_registry.count(output.GetEdgeID()) == 0);
      edge_registry.insert(output.GetEdgeID());
      output_registry[output.GetEdgeID()] = node_id;
    }
  }

  void InitializeNodes(
    const netuit::Topology& topology,
    const std::function<uitsl::proc_id_t(node_id_t)> proc_assignment,
    const MPI_Comm& comm
  ) {

    // ensures that we include relevant nodes that don't have any edges
    for (node_id_t node_id = 0; node_id < topology.GetSize(); ++node_id) {
      if (proc_assignment(node_id) == uitsl::get_proc_id(comm)) {
        InitializeNode(node_id);
      }
    }

  }

  void InitializeNode(const node_id_t node_id){
    if (nodes.count(node_id) == 0) nodes.emplace(
      std::piecewise_construct,
      std::forward_as_tuple(node_id),
      std::forward_as_tuple(node_id)
    );
  }

  void InitializeEdges(
    const netuit::Topology& topology,
    const std::function<uitsl::proc_id_t(node_id_t)> proc_assignment,
    const MPI_Comm& comm
  ) {
    for (edge_id_t edge : edge_registry) {
      const node_id_t input_id = input_registry.at(edge);
      const node_id_t output_id = output_registry.at(edge);
      // only construct infrastructure relevant to this proc
      // (but do need nodes that are connected to nodes on this proc)
      if (
        proc_assignment(input_id) == uitsl::get_proc_id(comm)
        || proc_assignment(output_id) == uitsl::get_proc_id(comm)
      ) {
        uit::Conduit<ImplSpec> conduit;

        InitializeNode(input_id);
        nodes.at(input_id).AddInput(
          MeshNodeInput<ImplSpec>{conduit.GetOutlet(), edge}
        );

        InitializeNode(output_id);
        nodes.at(output_id).AddOutput(
          MeshNodeOutput<ImplSpec>{conduit.GetInlet(), edge}
        );
      }
    }
  }

public:

  using value_type = typename node_lookup_t::value_type;

  MeshTopology(
    const netuit::Topology & topology,
    const std::function<uitsl::proc_id_t(node_id_t)> proc_assignment
      =uitsl::AssignIntegrated<uitsl::proc_id_t>{},
    const MPI_Comm comm=MPI_COMM_WORLD
  ) {
    InitializeRegistries(topology);
    InitializeNodes(topology, proc_assignment, comm);
    InitializeEdges(topology, proc_assignment, comm);

    // ensure that input, output registries have same keys as edge registry
    emp_assert(
      edge_registry == [this](){
        std::set<edge_id_t> res;
        std::transform(
          std::begin(input_registry),
          std::end(input_registry),
          std::inserter(res, std::begin(res)),
          [](const auto & kv){ return kv.first; }
        );
        return res;
      }()
    );
    emp_assert(
      edge_registry == [this](){
        std::set<edge_id_t> res;
        std::transform(
          std::begin(output_registry),
          std::end(output_registry),
          std::inserter(res, std::begin(res)),
          [](const auto & kv){ return kv.first; }
        );
        return res;
      }()
    );

  }

  size_t GetNodeCount() const { return nodes.size(); }

  size_t GetEdgeCount() const { return edge_registry.size(); }

  typename node_lookup_t::iterator begin() { return std::begin(nodes); }

  typename node_lookup_t::iterator end() { return std::end(nodes); }

  typename node_lookup_t::const_iterator begin() const {
    return std::begin(nodes);
  }

  typename node_lookup_t::const_iterator end() const {
    return std::end(nodes);
  }

  const std::set<edge_id_t>& GetEdgeRegistry() const { return edge_registry; }

  const std::unordered_map<edge_id_t, node_id_t>& GetInputRegistry() const {
    return input_registry;
  }

  const std::unordered_map<edge_id_t, node_id_t>& GetOutputRegistry() const {
    return output_registry;
  }


  std::string ToString() const {
    std::stringstream ss;
    ss << "TODO" << std::endl;
    return ss.str();
  }

};

} // namespace internal
} // namespace netuit

#endif // #ifndef NETUIT_MESH_MESHTOPOLOGY_HPP_INCLUDE