def test_edge_load(backend, tmp_path, double_node_edge):
    """Round-trip one edge through an HDF5 group and ``Edge._load_edge``.

    Every field of ``double_node_edge.edge12`` is written as its own dataset
    inside an ``edge_data`` group. The group is then reloaded against two
    freshly built nodes looked up by name. The reloaded edge must reproduce
    each stored field. The tensor checks confirm that each endpoint is the
    node carrying the matching data (the two tensors differ by a factor of 2).
    """
    original = double_node_edge.edge12

    with h5py.File(tmp_path / 'edge', 'w') as h5:
        group = h5.create_group('edge_data')
        # Insertion order matches the order in which the datasets are created.
        stored_fields = {
            'signature': original.signature,
            'name': original.name,
            'node1': original.node1.name,
            'node2': original.node2.name,
            'axis1': original.axis1,
            'axis2': original.axis2,
        }
        for key, value in stored_fields.items():
            group.create_dataset(key, data=value)

        ones = np.ones((1, 2, 2))
        first = Node(
            tensor=2 * ones,
            name="test_node1",
            axis_names=list("abc"),
            backend=backend)
        second = Node(
            tensor=ones,
            name="test_node2",
            axis_names=list("abc"),
            backend=backend)
        by_name = {candidate.name: candidate for candidate in (first, second)}

        restored = Edge._load_edge(group, by_name)

        assert restored.name == original.name
        assert restored.signature == original.signature
        assert restored.node1.name == original.node1.name
        assert restored.node2.name == original.node2.name
        assert restored.axis1 == original.axis1
        assert restored.axis2 == original.axis2
        for loaded_node, built_node in ((restored.node1, first),
                                        (restored.node2, second)):
            np.testing.assert_allclose(loaded_node.tensor, built_node.tensor)
示例#2
0
def load(path: str):
    """Read a ``TensorNetwork`` back from an HDF5 file written to disk.

    The file must contain a ``backend`` dataset plus ``nodes`` and ``edges``
    groups. A new network is built with the stored backend. Each entry under
    ``nodes`` is then rebuilt through the component class named by its
    ``type`` dataset, via that class's ``_load_node``. Finally, each entry
    under ``edges`` is reattached to the nodes by ``Edge._load_edge``.

    Args:
      path: path to file where network is saved.

    Returns:
      The reconstructed ``TensorNetwork``.
    """
    with h5py.File(path, 'r') as h5:
        network = TensorNetwork(backend=h5["backend"][()])
        node_keys = list(h5["nodes"].keys())
        edge_keys = list(h5["edges"].keys())

        for key in node_keys:
            payload = h5["nodes/" + key]
            component = get_component(payload['type'][()])
            # The return value is discarded: the code relies on _load_node
            # registering the node with `network` (it is looked up via
            # network.nodes_set below) -- presumably that is its contract.
            component._load_node(network, payload)

        by_name = {}
        for node in network.nodes_set:
            by_name[node.name] = node

        for key in edge_keys:
            Edge._load_edge(h5["edges/" + key], by_name)
    return network
示例#3
0
def load_nodes(path: str) -> List[BaseNode]:
    """Read a standalone collection of nodes (no owning network) from HDF5.

    Nodes and edges are rebuilt with whatever names are stored in their own
    groups. Those names are expected to be placeholders of the form
    ``node{i}`` / ``edge{i}``. After loading, each placeholder is replaced by
    the i-th entry of ``node_names/names`` or ``edge_names/names``
    respectively.

    Args:
      path: path to file where network is saved.

    Returns:
      A list of the loaded node objects, in the order their keys appear
      under the file's ``nodes`` group.
    """
    loaded_nodes = []
    loaded_edges = []
    with h5py.File(path, 'r') as h5:
        node_keys = list(h5["nodes"].keys())
        real_node_names = dict(
            (f'node{i}', name)
            for i, name in enumerate(h5["node_names"]['names'][()]))
        real_edge_names = dict(
            (f'edge{i}', name)
            for i, name in enumerate(h5["edge_names"]['names'][()]))
        edge_keys = list(h5["edges"].keys())

        for key in node_keys:
            payload = h5["nodes/" + key]
            node_cls = get_component(payload['type'][()])
            loaded_nodes.append(node_cls._load_node(net=None, node_data=payload))

        # Endpoints are resolved by their (still placeholder) names.
        by_name = {node.name: node for node in loaded_nodes}
        loaded_edges.extend(
            Edge._load_edge(h5["edges/" + key], by_name) for key in edge_keys)

    # Placeholder names are only swapped for the real ones once everything is
    # loaded; the lookups above depend on the placeholder names.
    for edge in loaded_edges:
        edge.set_name(real_edge_names[edge.name])
    for node in loaded_nodes:
        node.set_name(real_node_names[node.name])

    return loaded_nodes