Ejemplo n.º 1
0
def test_remove_after_flatten(backend):
    """Check that a node can be removed after its edges were flattened.

    Two nodes holding 2x2 all-ones tensors are connected along both of
    their axes, every edge in the pair is flattened, and the first node
    is then removed.

    Args:
        backend: Backend identifier forwarded to ``tn.Node``
            (presumably supplied by a pytest fixture; confirm in conftest).
    """
    left = tn.Node(np.ones((2, 2)), backend=backend)
    right = tn.Node(np.ones((2, 2)), backend=backend)
    # Join the nodes on axis 0 and on axis 1, giving two edges between them.
    for axis in range(2):
        tn.connect(left[axis], right[axis])
    tn.flatten_all_edges({left, right})
    # No assertion here: the test passes as long as removal does not raise.
    tn.remove_node(left)
Ejemplo n.º 2
0
def test_flatten_all_edges(backend):
    """Check which edges survive ``tn.flatten_all_edges`` on a 3-node network.

    The network contains two self-connections on the first node, two
    edges between the first and second node, and a single edge between
    the second and third node. After flattening, the test expects exactly
    three edges: none of the four original self/parallel edges may be in
    the result, while the lone edge must still be present.

    Args:
        backend: Backend identifier forwarded to ``tn.Node``
            (presumably supplied by a pytest fixture; confirm in conftest).
    """
    node_a = tn.Node(np.ones((3, 3, 5, 6, 2, 2)), backend=backend)
    node_b = tn.Node(np.ones((5, 6, 7)), backend=backend)
    node_c = tn.Node(np.ones((7, )), backend=backend)

    # Self-connections on node_a: axes (0, 1) and axes (4, 5).
    first_trace = tn.connect(node_a[0], node_a[1])
    second_trace = tn.connect(node_a[4], node_a[5])
    # Two edges running side by side between node_a and node_b.
    first_shared = tn.connect(node_a[2], node_b[0])
    second_shared = tn.connect(node_a[3], node_b[1])
    # The only edge between node_b and node_c.
    lone_edge = tn.connect(node_b[2], node_c[0])

    network = (node_a, node_b, node_c)
    # A fresh set is built for each call, as two separate set literals would.
    flattened = tn.flatten_all_edges(set(network))
    tn.check_correct(set(network))

    assert len(flattened) == 3
    # Every edge that had a sibling between the same nodes must be gone.
    for replaced in (first_trace, second_trace, first_shared, second_shared):
        assert replaced not in flattened
    assert lone_edge in flattened