def test_edge_load(backend, tmp_path, double_node_edge):
    """Write an edge's fields to an HDF5 group and check `_load_edge` restores them.

    The reloaded edge is attached to two freshly built nodes whose names match
    the ones recorded in the group; its name, signature, endpoint names, axes
    and endpoint tensors are then compared against the source edge / nodes.
    """
    original_edge = double_node_edge.edge12
    with h5py.File(tmp_path / 'edge', 'w') as h5_handle:
        group = h5_handle.create_group('edge_data')
        fields = (
            ('signature', original_edge.signature),
            ('name', original_edge.name),
            ('node1', original_edge.node1.name),
            ('node2', original_edge.node2.name),
            ('axis1', original_edge.axis1),
            ('axis2', original_edge.axis2),
        )
        for key, value in fields:
            group.create_dataset(key, data=value)

        base = np.ones((1, 2, 2))
        # Distinct tensors (2*base vs base) so a swapped endpoint would be caught.
        first = Node(
            tensor=2 * base,
            name="test_node1",
            axis_names=list("abc"),
            backend=backend)
        second = Node(
            tensor=base,
            name="test_node2",
            axis_names=list("abc"),
            backend=backend)
        lookup = {first.name: first, second.name: second}
        reloaded = Edge._load_edge(group, lookup)

        for attr in ('name', 'signature'):
            assert getattr(reloaded, attr) == getattr(original_edge, attr)
        for end in ('node1', 'node2'):
            assert getattr(reloaded, end).name == getattr(original_edge, end).name
        for attr in ('axis1', 'axis2'):
            assert getattr(reloaded, attr) == getattr(original_edge, attr)
        for end, expected in (('node1', first), ('node2', second)):
            np.testing.assert_allclose(getattr(reloaded, end).tensor,
                                       expected.tensor)
def load(path: str):
    """Rebuild a `TensorNetwork` from an HDF5 file.

    The network backend is read from the file's "backend" dataset. Every group
    under "nodes" is loaded into the network through the component class named
    by its 'type' dataset. Every group under "edges" is then attached to those
    nodes, which are looked up by name.

    Args:
        path: path to file where network is saved.

    Returns:
        The reconstructed `TensorNetwork`.
    """
    with h5py.File(path, 'r') as h5_in:
        network = TensorNetwork(backend=h5_in["backend"][()])
        node_keys = list(h5_in["nodes"].keys())
        edge_keys = list(h5_in["edges"].keys())

        for key in node_keys:
            group = h5_in["nodes/" + key]
            component = get_component(group['type'][()])
            component._load_node(network, group)

        # Edges are resolved against the nodes loaded above, so nodes go first.
        by_name = {node.name: node for node in network.nodes_set}
        for key in edge_keys:
            Edge._load_edge(h5_in["edges/" + key], by_name)
    return network
def load_nodes(path: str) -> List[BaseNode]:
    """Read a collection of nodes, and the edges joining them, from an HDF5 file.

    Nodes are loaded without an owning network (`net=None`), and edges are
    attached to them by name. The freshly loaded edges and nodes are then
    looked up under the keys 'edge<i>' / 'node<i>', where i is the position in
    the "edge_names" / "node_names" 'names' datasets. Each one is renamed to
    the value stored at that position.

    Args:
        path: path to file where network is saved.

    Returns:
        A list of the loaded `Node` objects.
    """
    loaded_nodes = []
    loaded_edges = []
    with h5py.File(path, 'r') as h5_in:

        def _placeholder_map(group_key, template):
            # Map 'node0'/'edge0'/... to the names recorded at that index.
            stored = h5_in[group_key]['names'][()]
            return {template.format(i): label for i, label in enumerate(stored)}

        node_keys = list(h5_in["nodes"].keys())
        node_rename = _placeholder_map("node_names", 'node{}')
        edge_rename = _placeholder_map("edge_names", 'edge{}')
        edge_keys = list(h5_in["edges"].keys())

        for key in node_keys:
            group = h5_in["nodes/" + key]
            component = get_component(group['type'][()])
            loaded_nodes.append(component._load_node(net=None, node_data=group))

        by_name = {node.name: node for node in loaded_nodes}
        for key in edge_keys:
            loaded_edges.append(Edge._load_edge(h5_in["edges/" + key], by_name))

    for edge in loaded_edges:
        edge.set_name(edge_rename[edge.name])
    for node in loaded_nodes:
        node.set_name(node_rename[node.name])
    return loaded_nodes