def test_contraction_sanity(self):
     """Contract two nodes that share one bond, one of which also has a trace.

     ``left`` (4x5x2, all ones) is joined to ``right`` (3x2x3, all ones) along
     the size-2 axis, and ``right`` carries a self-loop over its two size-3
     axes. The surviving dangling axes have sizes 4 and 5, and every entry of
     the result is a sum of 2 * 3 = 6 ones.
     """
     net = network.TensorNetwork(backend="tensorflow")
     left = net.add_node(np.ones([4, 5, 2]))
     right = net.add_node(np.ones([3, 2, 3]))
     # Shared bond of dimension 2 between the two nodes.
     net.connect(left[2], right[1])
     # Trace edge: both ends live on ``right`` (dimension 3).
     net.connect(right[0], right[2])

     net = stochastic_contractor.stochastic(net, 2)
     net.check_correct()

     final_tensor = net.get_final_node().get_tensor()
     expected = np.full([4, 5], 6.0)
     self.assertAllClose(final_tensor, expected)
 def test_contraction_parallel_edges(self):
     """Fully contract a network with parallel edges and a trace to a scalar.

     ``a`` and ``b`` are joined by two parallel edges (dimensions 2 and 5),
     ``b`` has a size-3 trace edge, and ``c`` closes ``a``'s remaining size-4
     axis, so no edge is left dangling. With all-ones inputs the resulting
     scalar is 4 * 5 * 2 * 3 = 120.
     """
     # Pin the backend explicitly, matching the other tests in this file,
     # instead of relying on the library's default backend.
     net = network.TensorNetwork(backend="tensorflow")
     a = net.add_node(np.ones([4, 5, 2]))
     b = net.add_node(np.ones([3, 2, 3, 5]))
     c = net.add_node(np.ones([4]))
     net.connect(a[2], b[1])
     net.connect(b[0], b[2])  # trace edge on b
     net.connect(a[1], b[3])  # second a-b edge, parallel to a[2]-b[1]
     net.connect(a[0], c[0])
     net = stochastic_contractor.stochastic(net, 2)
     net.check_correct()
     res = net.get_final_node()
     self.assertAllClose(res.get_tensor(), 120)
 def test_contraction_disconnected(self):
     """Contract a network made of two components with no edges between them.

     Component one: a 4x5x2 and a 3x2x3 all-ones tensor sharing a size-2 bond,
     with a size-3 trace on the second node, giving a 4x5 result of sixes.
     Component two: a 3x4 and a 4x3 all-ones matrix joined on the size-4 axis,
     giving a 3x3 result of fours. Each component's result is located through
     a dangling edge captured before contraction; the test relies on
     ``edge.node1`` referring to the contracted node afterwards.
     """
     net = network.TensorNetwork(backend="tensorflow")

     # First component.
     first = net.add_node(np.ones([4, 5, 2]))
     second = net.add_node(np.ones([3, 2, 3]))
     probe_edge_1 = first[0]
     net.connect(first[2], second[1])
     net.connect(second[0], second[2])

     # Second component (independent of the first).
     third = net.add_node(np.ones([3, 4]))
     fourth = net.add_node(np.ones([4, 3]))
     probe_edge_2 = third[0]
     net.connect(third[1], fourth[0])

     net = stochastic_contractor.stochastic(net, 2)
     # The network is intentionally disconnected, so skip the connectivity check.
     net.check_correct(check_connected=False)

     first_result = probe_edge_1.node1
     second_result = probe_edge_2.node1
     self.assertAllClose(first_result.get_tensor(), np.full([4, 5], 6.0))
     self.assertAllClose(second_result.get_tensor(), np.full([3, 3], 4.0))