def test_copy_node_load(tmp_path, backend):
    """Round-trip a CopyNode's metadata through an HDF5 group and reload it.

    Writes the node's signature, backend name, copy-node dtype name, name,
    shape, axis names and edge names into a 'node_data' group. It then checks
    that ``CopyNode._load_node`` rebuilds a node whose attributes match the
    original.
    """
    original = tn.CopyNode(
        rank=4,
        dimension=3,
        name='copier',
        axis_names=list(map(str, range(4))),
        backend=backend)

    with h5py.File(tmp_path / 'node', 'w') as h5_file:
        group = h5_file.create_group('node_data')
        group.create_dataset('signature', data=original.signature)
        group.create_dataset('backend', data=original.backend.name)
        # Store the dtype by its canonical numpy name so it can be re-parsed.
        dtype_name = np.dtype(original.copy_node_dtype).name
        group.create_dataset('copy_node_dtype', data=dtype_name)
        group.create_dataset('name', data=original.name)
        group.create_dataset('shape', data=original.shape)

        # Variable-length strings need an object array plus the string dtype.
        axis_array = np.array(original.axis_names, dtype=object)
        group.create_dataset('axis_names', data=axis_array, dtype=string_type)
        edge_array = np.array(
            [edge.name for edge in original.edges], dtype=object)
        group.create_dataset('edges', data=edge_array, dtype=string_type)

        restored = CopyNode._load_node(node_data=h5_file["node_data/"])

        assert restored.name == original.name
        assert restored.signature == original.signature
        assert set(restored.axis_names) == set(original.axis_names)
        restored_edges = set(edge.name for edge in restored.edges)
        original_edges = set(edge.name for edge in original.edges)
        assert restored_edges == original_edges
        assert restored.get_dimension(axis=1) == original.get_dimension(axis=1)
        assert restored.get_rank() == original.get_rank()
        assert restored.shape == original.shape
        assert restored.copy_node_dtype == original.copy_node_dtype
def test_copy_node_load_with_network(tmp_path, backend):
    """Round-trip a network-owned CopyNode's metadata through HDF5 and reload it.

    This test covers the ``TensorNetwork.add_copy_node`` /
    ``CopyNode._load_node(net, data)`` API. It was previously also named
    ``test_copy_node_load``. A second ``def`` with an existing name rebinds
    the module attribute, so pytest only collected this definition and the
    earlier test of that name never ran. The distinct name lets both run.
    """
    net = tensornetwork.TensorNetwork(backend)
    node = net.add_copy_node(rank=4, dimension=3, name='copier')
    with h5py.File(tmp_path / 'node', 'w') as node_file:
        node_group = node_file.create_group('node_data')
        node_group.create_dataset('signature', data=node.signature)
        node_group.create_dataset('name', data=node.name)
        node_group.create_dataset('shape', data=node.shape)
        # Variable-length strings are written from object arrays.
        node_group.create_dataset(
            'axis_names',
            data=np.array(node.axis_names, dtype=object),
            dtype=string_type)
        node_group.create_dataset(
            'edges',
            data=np.array([edge.name for edge in node.edges], dtype=object),
            dtype=string_type)

        # Load into a fresh network built on the same backend as the original.
        net = tensornetwork.TensorNetwork(backend=node.network.backend.name)
        loaded_node = CopyNode._load_node(net, node_file["node_data/"])

        assert loaded_node.name == node.name
        assert loaded_node.signature == node.signature
        assert set(loaded_node.axis_names) == set(node.axis_names)
        assert (set(edge.name for edge in loaded_node.edges) == set(
            edge.name for edge in node.edges))
        assert loaded_node.get_dimension(axis=1) == node.get_dimension(axis=1)
        assert loaded_node.get_rank() == node.get_rank()
        assert loaded_node.shape == node.shape