def test_remove_after_flatten(backend):
    """Smoke test: remove a node after all edges have been flattened.

    Builds two 2x2 nodes, connects both of their axes pairwise, calls
    `tn.flatten_all_edges` on the pair and then `tn.remove_node` on the
    first node. There are no assertions; the test passes as long as no
    call raises.
    """
    first = tn.Node(np.ones((2, 2)), backend=backend)
    second = tn.Node(np.ones((2, 2)), backend=backend)

    # Connect axis 0 to axis 0, then axis 1 to axis 1, so the two nodes
    # end up sharing two edges.
    for axis in (0, 1):
        tn.connect(first[axis], second[axis])

    tn.flatten_all_edges({first, second})
    tn.remove_node(first)
def test_flatten_all_edges(backend):
    """Check which edges survive a call to `tn.flatten_all_edges`.

    Three nodes are wired up with five edges:
      * two edges connecting the first node to itself (axes 0-1 and 4-5),
      * two edges between the first and second node (axes 2-0 and 3-1),
      * one edge between the second and third node (axes 2-0).

    After flattening, the returned collection must contain exactly three
    edges; none of the four self/shared edges created here may appear in
    it, while the single edge to the third node must.
    """
    big = tn.Node(np.ones((3, 3, 5, 6, 2, 2)), backend=backend)
    middle = tn.Node(np.ones((5, 6, 7)), backend=backend)
    tail = tn.Node(np.ones((7, )), backend=backend)

    # Edges whose two endpoints both live on `big`.
    self_edge_a = tn.connect(big[0], big[1])
    self_edge_b = tn.connect(big[4], big[5])
    # Two edges between the same pair of nodes (`big` and `middle`).
    shared_edge_a = tn.connect(big[2], middle[0])
    shared_edge_b = tn.connect(big[3], middle[1])
    # The only edge between `middle` and `tail`.
    single_edge = tn.connect(middle[2], tail[0])

    result = tn.flatten_all_edges({big, middle, tail})
    tn.check_correct({big, middle, tail})

    assert len(result) == 3
    # Checked in the same order as they were created.
    for replaced in (self_edge_a, self_edge_b, shared_edge_a, shared_edge_b):
        assert replaced not in result
    assert single_edge in result