예제 #1
0
def save_nodes(nodes: List[AbstractNode], path: Union[str, BinaryIO]) -> None:
  """Save an iterable of nodes into hdf5 format.

  While the file is written, every node and edge is temporarily renamed to a
  unique name (`node<i>` / `edge<i>`). The original names are stored in the
  file (`node_names/names` and `edge_names/names`) and are restored on the
  in-memory objects afterwards, even if writing the file fails.

  Args:
    nodes: An iterable of connected nodes. All nodes have to connect within
      `nodes`.
    path: path to file where network is saved.

  Raises:
    ValueError: If some node in `nodes` is connected to a node that is not
      contained in `nodes`, or if a node appears more than once in `nodes`.
  """
  # Materialize once up front: `nodes` is traversed several times below and
  # the order must be identical on every pass (a one-shot iterable would be
  # exhausted by the first check).
  nodes = list(nodes)
  if reachable(nodes) > set(nodes):
    raise ValueError(
        "Some nodes in `nodes` are connected to nodes not contained in `nodes`."
        " Saving not possible.")
  if len(set(nodes)) < len(nodes):
    raise ValueError(
        'Some nodes in `nodes` appear more than once. This is not supported')
  edges = list(get_all_edges(nodes))

  # Original names, index-aligned with `nodes` and `edges`.
  old_node_names = [node.name for node in nodes]
  old_edge_names = [edge.name for edge in edges]

  try:
    # Generate unique names for nodes and edges for saving them.
    for n, node in enumerate(nodes):
      node.set_name('node{}'.format(n))

    for e, edge in enumerate(edges):
      edge.set_name('edge{}'.format(e))

    with h5py.File(path, 'w') as net_file:
      nodes_group = net_file.create_group('nodes')
      node_names_group = net_file.create_group('node_names')
      node_names_group.create_dataset(
          'names',
          dtype=string_type,
          data=np.array(old_node_names, dtype=object))

      edges_group = net_file.create_group('edges')
      edge_names_group = net_file.create_group('edge_names')
      edge_names_group.create_dataset(
          'names',
          dtype=string_type,
          data=np.array(old_edge_names, dtype=object))

      # Track which edges still have to be written in a *separate* list, so
      # that `edges` stays aligned with `old_edge_names` for the restore step.
      unsaved_edges = list(edges)
      for node in nodes:
        node_group = nodes_group.create_group(node.name)
        node._save_node(node_group)
        for edge in node.edges:
          # Each edge is written once, from the node that is its `node1`.
          if edge.node1 == node and edge in unsaved_edges:
            edge_group = edges_group.create_group(edge.name)
            edge._save_edge(edge_group)
            unsaved_edges.remove(edge)
  finally:
    # Name edges and nodes back to their original names, also on failure.
    for node, name in zip(nodes, old_node_names):
      node.set_name(name)

    for edge, name in zip(edges, old_edge_names):
      edge.set_name(name)
예제 #2
0
 def nodes(self) -> Set[BaseNode]:
     """Return the set of nodes belonging to this operator.

     The search is seeded with the nodes attached to the operator's output
     and input edges together with its reference nodes; the result is
     whatever `reachable` yields from that seed set.
     """
     boundary_edges = self.out_edges + self.in_edges
     seed_nodes = get_all_nodes(boundary_edges) | self.ref_nodes
     return reachable(seed_nodes)