def test_contraction_sanity(self):
    """Contract a two-node network and compare against the analytic result.

    The first node has shape (4, 5, 2) and the second (3, 2, 3), both
    filled with ones. Axis 2 of the first node is joined to axis 1 of the
    second (dimension 2), and axes 0 and 2 of the second node are joined
    to each other (dimension 3). Every entry of the remaining 4x5 tensor
    is therefore expected to equal 2 * 3 = 6.
    """
    tn = network.TensorNetwork(backend="tensorflow")
    left = tn.add_node(np.ones((4, 5, 2)))
    right = tn.add_node(np.ones((3, 2, 3)))

    # Edge shared between the two nodes.
    tn.connect(left[2], right[1])
    # Edge joining two axes of the same node.
    tn.connect(right[0], right[2])

    contracted = stochastic_contractor.stochastic(tn, 2)
    contracted.check_correct()

    expected = 6 * np.ones((4, 5))
    final_node = contracted.get_final_node()
    self.assertAllClose(final_node.get_tensor(), expected)
def test_contraction_parallel_edges(self):
    """Contract a network where two nodes share more than one edge.

    Nodes ``a`` (4, 5, 2) and ``b`` (3, 2, 3, 5) are joined by two edges
    (``a[2]``-``b[1]`` and ``a[1]``-``b[3]``), ``b`` additionally has
    axes 0 and 2 joined to each other, and ``a[0]`` is joined to the
    single axis of ``c`` (4,). All tensors are filled with ones and no
    axis is left dangling, so the result is expected to be the scalar
    4 * 5 * 2 * 3 = 120.
    """
    # Pin the backend explicitly, as the other contraction tests in this
    # file do, so this test does not rely on whatever backend is used
    # when none is passed.
    net = network.TensorNetwork(backend="tensorflow")
    a = net.add_node(np.ones([4, 5, 2]))
    b = net.add_node(np.ones([3, 2, 3, 5]))
    c = net.add_node(np.ones([4]))

    net.connect(a[2], b[1])
    net.connect(b[0], b[2])  # joins two axes of the same node
    net.connect(a[1], b[3])  # second edge between a and b
    net.connect(a[0], c[0])

    net = stochastic_contractor.stochastic(net, 2)
    net.check_correct()
    res = net.get_final_node()
    self.assertAllClose(res.get_tensor(), 120)
def test_contraction_disconnected(self):
    """Contract a network made of two components that share no edge.

    Component one: nodes of shape (4, 5, 2) and (3, 2, 3), joined on a
    dimension-2 edge, with axes 0 and 2 of the second node joined to each
    other (dimension 3); the expected result is 6 * ones((4, 5)).
    Component two: nodes of shape (3, 4) and (4, 3) joined on a
    dimension-4 edge; the expected result is 4 * ones((3, 3)).

    One dangling edge from each component is captured before contraction
    and used afterwards to look up the node it is attached to.
    """
    tn = network.TensorNetwork(backend="tensorflow")

    # First component.
    first_a = tn.add_node(np.ones((4, 5, 2)))
    first_b = tn.add_node(np.ones((3, 2, 3)))
    dangling_first = first_a[0]
    tn.connect(first_a[2], first_b[1])
    tn.connect(first_b[0], first_b[2])

    # Second component, not connected to the first one.
    second_a = tn.add_node(np.ones((3, 4)))
    second_b = tn.add_node(np.ones((4, 3)))
    dangling_second = second_a[0]
    tn.connect(second_a[1], second_b[0])

    tn = stochastic_contractor.stochastic(tn, 2)
    # The network is disconnected on purpose, so skip that check.
    tn.check_correct(check_connected=False)

    result_first = dangling_first.node1
    result_second = dangling_second.node1
    expected_first = 6 * np.ones((4, 5))
    expected_second = 4 * np.ones((3, 3))
    self.assertAllClose(result_first.get_tensor(), expected_first)
    self.assertAllClose(result_second.get_tensor(), expected_second)