def test_flatten_edges_between(backend):
    """Flatten three shared edges between two nodes and check both become size-60 vectors.

    Each node's three axes (sizes 3, 4, 5) are all joined to the other node, so
    after flattening each tensor should be a rank-1 array of 3 * 4 * 5 = 60 ones.
    """
    node_a = tn.Node(np.ones((3, 4, 5)), backend=backend)
    node_b = tn.Node(np.ones((5, 4, 3)), backend=backend)
    # Join each axis of node_a to the axis of node_b with the same dimension.
    for axis_a, axis_b in ((0, 2), (1, 1), (2, 0)):
        tn.connect(node_a[axis_a], node_b[axis_b])
    tn.flatten_edges_between(node_a, node_b)
    tn.check_correct({node_a, node_b})
    expected = np.ones((60, ))
    for node in (node_a, node_b):
        np.testing.assert_allclose(node.tensor, expected)
def f(x, n):
    """Contract two pytorch nodes over axes 0-2 after flattening the shared edges.

    Both nodes wrap the same slice ``x[..., :n]`` (last axis truncated to ``n``).
    Assumes ``x`` has exactly three axes so every edge is connected (TODO confirm).
    Returns the tensor of the contracted result.
    """
    truncated = x[..., :n]
    left = Node(truncated, backend="pytorch")
    right = Node(truncated, backend="pytorch")
    for axis in range(3):
        connect(left[axis], right[axis])
    merged_edge = flatten_edges_between(left, right)
    return contract(merged_edge).get_tensor()
def test_flatten_edges_between_no_edges(backend):
    """flatten_edges_between returns None when the two nodes share no edges."""
    # Two independent length-3 vectors that are never connected.
    first = tn.Node(np.ones((3,)), backend=backend)
    second = tn.Node(np.ones((3,)), backend=backend)
    result = tn.flatten_edges_between(first, second)
    assert result is None
def f(x, n):
    """Trace a tensorflow node over two self-loops after flattening them into one edge.

    The node wraps ``x[..., :n]`` (last axis truncated to ``n``). Axis 0 is looped
    to axis 2 and axis 1 to axis 3; presumably ``x`` is rank 4 (TODO confirm).
    Returns the tensor of the contracted result.
    """
    truncated = x[..., :n]
    node = Node(truncated, backend="tensorflow")
    for first_axis, second_axis in ((0, 2), (1, 3)):
        connect(node[first_axis], node[second_axis])
    trace_edge = flatten_edges_between(node, node)
    return contract(trace_edge).get_tensor()