def test_contract_between_output_order(backend):
    """contract_between must validate and honour ``output_edge_order``.

    An order naming an edge that is being contracted away, or an edge that
    belongs to an unrelated node, raises ValueError and leaves the network
    intact. A valid order fixes the axis layout of the resulting node.
    """
    a_val = np.ones((2, 3, 4, 5))
    b_val = np.ones((3, 5, 4, 2))
    c_val = np.ones((2, 2))
    a, b, c = [tn.Node(val, backend=backend) for val in (a_val, b_val, c_val)]

    # Three shared edges between a and b; a[2] and b[2] stay dangling.
    for left, right in ((a[0], b[3]), (b[1], a[3]), (a[1], b[0])):
        tn.connect(left, right)

    invalid_orders = (
        [a[2], b[2], a[0]],  # a[0] is one of the contracted edges
        [a[2], b[2], c[0]],  # c[0] belongs to neither a nor b
    )
    for bad_order in invalid_orders:
        with pytest.raises(ValueError):
            tn.contract_between(
                a, b, name="New Node", output_edge_order=bad_order)
        tn.check_correct({a, b, c}, check_connections=False)

    result = tn.contract_between(
        a, b, name="New Node", output_edge_order=[b[2], a[2]])
    tn.check_correct({c, result}, check_connections=False)

    # Dense reference: move the free leg to the front, flatten the three
    # contracted legs (in matching order) into one axis of size 30.
    a_matrix = np.reshape(np.transpose(a_val, (2, 1, 0, 3)), (4, 30))
    b_matrix = np.reshape(np.transpose(b_val, (2, 0, 3, 1)), (4, 30))
    expected = np.matmul(b_matrix, a_matrix.T)
    np.testing.assert_allclose(result.tensor, expected)
    assert result.name == "New Node"
def getOverlap(psi1Orig: List[tn.Node], psi2Orig: List[tn.Node]):
    """Contract two matrix-product states into their scalar overlap.

    Both inputs are copied first (``psi2`` via ``copyState(..., conj=True)``),
    so the caller's nodes are never connected or consumed. The copies are
    joined site by site: leg 1 of every site is paired between the two states,
    leg 0 of the first site and leg 2 of the last site are paired directly, and
    the inner bonds are absorbed left to right. The code therefore treats each
    site as (left bond, physical, right bond) — presumably the module-wide MPS
    convention; confirm against copyState's callers.

    A single-site state is handled explicitly. Its three legs are paired in one
    step, instead of re-connecting the physical legs as the multi-site path
    would.

    Args:
        psi1Orig: Site nodes of the first (ket) state.
        psi2Orig: Site nodes of the second state; conjugated before contraction.

    Returns:
        The backend tensor of the fully contracted network (a scalar tensor).

    Raises:
        IndexError: if the states are empty.
    """
    psi1 = copyState(psi1Orig)
    psi2 = copyState(psi2Orig, conj=True)
    last = len(psi1) - 1

    if last == 0:
        # One site only: both boundary bonds and the physical leg close at once.
        for axis in range(3):
            psi1[0][axis] ^ psi2[0][axis]
        contracted = tn.contract_between(psi1[0], psi2[0], name='contracted')
    else:
        # Close the left boundary bond and the first physical leg.
        psi1[0][0] ^ psi2[0][0]
        psi1[0][1] ^ psi2[0][1]
        contracted = tn.contract_between(psi1[0], psi2[0], name='contracted')
        # `contracted` now carries [psi1 right bond, psi2 right bond].
        for i in range(1, last + 1):
            psi1[i][1] ^ psi2[i][1]
            if i == last:
                # The right boundary bond closes on the final site.
                psi1[i][2] ^ psi2[i][2]
            contracted[0] ^ psi1[i][0]
            contracted[1] ^ psi2[i][0]
            contracted = tn.contract_between(
                tn.contract_between(contracted, psi1[i]), psi2[i])

    result = contracted.tensor
    tn.remove_node(contracted)
    removeState(psi1)
    removeState(psi2)
    return result
def stateEnergy(psi: List[tn.Node], H: HOp):
    # Accumulate the energy of the MPS `psi` under H, one term at a time.
    # Each term acts on a fresh copy of psi, and that copy is overlapped with the
    # original psi via bops.getOverlap. Returns the summed overlaps
    # (int 0 if there are no terms).
    E = 0
    # Single-site terms: apply H.singles[i] to the physical leg of site i.
    for i in range(len(psi)):
        psiCopy = bops.copyState(psi)
        single_i = bops.copyState([H.singles[i]])[0]
        # Contract the site's leg 1 with the operator's leg 0. Then permute
        # [0, 2, 1] to move the operator's remaining leg into the middle slot.
        # This presumably assumes tn.contract orders the result's legs as
        # site legs then operator legs — confirm.
        psiCopy[i] = bops.permute(tn.contract(psiCopy[i][1] ^ single_i[0], name=('site' + str(i))), [0, 2, 1])
        E += bops.getOverlap(psiCopy, psi)
        bops.removeState(psiCopy)
        tn.remove_node(single_i)
    # Two-site terms: H.l2r[i] acts on site i and H.r2l[i+1] on site i+1.
    # The two operator halves are joined through their leg 2. Presumably they
    # form a decomposition of a nearest-neighbour operator — confirm against HOp.
    for i in range(len(psi) - 1):
        psiCopy = bops.copyState(psi)
        r2l = bops.copyState([H.r2l[i+1]])[0]
        l2r = bops.copyState([H.l2r[i]])[0]
        psiCopy[i][2] ^ psiCopy[i+1][0]
        psiCopy[i][1] ^ l2r[0]
        r2l[0] ^ psiCopy[i+1][1]
        l2r[2] ^ r2l[2]
        # M: the two sites with the operator applied, merged into one node.
        M = tn.contract_between(psiCopy[i], \
            tn.contract_between(l2r, tn.contract_between(r2l, psiCopy[i+1])))
        # Skip the term when <M|M> (M contracted with its conjugate over all
        # four legs) is exactly zero. Presumably assignNewSiteTensors cannot
        # split a zero tensor — confirm.
        if bops.multiContraction(M, M, '0123', '0123*').tensor != 0:
            # Split M back into sites i and i+1. `te` (the second return value,
            # presumably a truncation error) is discarded.
            [psiCopy, te] = bops.assignNewSiteTensors(psiCopy, i, M, '>>')
            E += bops.getOverlap(psiCopy, psi)
        bops.removeState(psiCopy)
        tn.remove_node(r2l)
        tn.remove_node(l2r)
    return E
def test_contract_between_no_outer_product_value_error(backend):
    """Contracting nodes that share no edge raises unless outer products are allowed."""
    left = tn.Node(np.ones((2, 3, 4)), backend=backend)
    right = tn.Node(np.ones((5, 6, 7)), backend=backend)
    with pytest.raises(ValueError):
        tn.contract_between(left, right)
def gf_cost_func_v1(B, AL, AR):
    """Sandwich the sum of two physical-leg projectors between B and conj(B).

    Every tensor is wrapped with legs named ("p", "L", "R"). For A in (AL, AR),
    a matrix on the physical leg is formed by contracting A with conj(A) over
    "L" and "R". The two matrices are added. The sum is then contracted with B
    on one side and conj(B) on the other, with B's "L" and "R" legs also joined
    to conj(B), so every leg is closed.

    Returns:
        The backend tensor of the fully contracted network.
    """
    def physical_projector(tensor):
        # Contract a site with its conjugate over both virtual legs.
        ket = tn.Node(tensor, axis_names=["p", "L", "R"])
        bra = tn.conj(ket)
        ket["L"] ^ bra["L"]
        ket["R"] ^ bra["R"]
        return tn.contract_between(
            ket, bra,
            output_edge_order=[bra["p"], ket["p"]],
            axis_names=["p_in", "p_out"])

    projector_sum = tn.Node(
        physical_projector(AL).get_tensor() + physical_projector(AR).get_tensor(),
        axis_names=["p_in", "p_out"])

    b_ket = tn.Node(B, axis_names=["p", "L", "R"])
    b_bra = tn.conj(b_ket)
    b_ket["L"] ^ b_bra["L"]
    b_ket["R"] ^ b_bra["R"]
    b_ket["p"] ^ projector_sum["p_in"]
    b_bra["p"] ^ projector_sum["p_out"]

    return (b_ket @ projector_sum @ b_bra).get_tensor()
def _ev_mps(self, obs_nodes, wires):
    r"""Expectation value of observables on specified wires using a MPS representation.

    Args:
        obs_nodes (Sequence[tn.Node]): the observables as TensorNetwork Nodes
        wires (Sequence[Sequence[int]]): measured subsystems for each observable

    Returns:
        complex: expectation value :math:`\expect{A} = \bra{\psi}A\ket{\psi}`

    Raises:
        NotImplementedError: if an observable acts on more than two wires, or
            on two wires that are not nearest neighbours. In the general
            branch every wire group is checked before any edge is connected,
            so a rejected request does not leave ``self.mps.nodes`` wired to
            observable nodes.
    """
    if any(len(wires_seq) > 2 for wires_seq in wires):
        raise NotImplementedError(
            "Multi-wire measurement only supported for nearest-neighbour wire pairs."
        )

    if len(obs_nodes) == 1 and len(wires[0]) == 1:
        # TODO: can measure multiple local expectation values at once,
        # but this would require change of `expval` behaviour and
        # refactor of `execute` logic from parent class
        expval = self.mps.measure_local_operator(obs_nodes, wires[0])[0]
    else:
        # Validate before mutating: the connection loop below attaches edges to
        # self.mps.nodes, and raising halfway would leave earlier observables
        # connected into the MPS.
        for _, wire_seq in zip(obs_nodes, wires):
            if len(wire_seq) == 2 and abs(wire_seq[0] - wire_seq[1]) > 1:
                raise NotImplementedError(
                    "Multi-wire measurement only supported for nearest-neighbour wire pairs."
                )

        conj_nodes = [tn.conj(node) for node in self.mps.nodes]
        meas_wires = []
        # connect measured bra and ket nodes with observables
        for obs_node, wire_seq in zip(obs_nodes, wires):
            # Observable legs: the first len(wire_seq) attach to the bra (conj)
            # sites, the next len(wire_seq) attach to the ket sites.
            offset = len(wire_seq)
            for idx, wire in enumerate(wire_seq):
                tn.connect(conj_nodes[wire][1], obs_node[idx])
                tn.connect(obs_node[offset + idx], self.mps.nodes[wire][1])
            meas_wires.extend(wire_seq)

        for wire in range(self.num_wires):
            # connect unmeasured ket nodes with bra nodes
            if wire not in meas_wires:
                tn.connect(conj_nodes[wire][1], self.mps.nodes[wire][1])
            # connect local nodes of MPS (not connected by default in tn)
            if wire != self.num_wires - 1:
                tn.connect(self.mps.nodes[wire][2], self.mps.nodes[wire + 1][0])
                tn.connect(conj_nodes[wire][2], conj_nodes[wire + 1][0])

        # contract MPS bonds first
        bra_node = conj_nodes[0]
        ket_node = self.mps.nodes[0]
        for wire in range(self.num_wires - 1):
            bra_node = tn.contract_between(bra_node, conj_nodes[wire + 1])
            ket_node = tn.contract_between(ket_node, self.mps.nodes[wire + 1])

        # contract observables into ket
        for obs_node in obs_nodes:
            ket_node = tn.contract_between(obs_node, ket_node)

        # contract bra into observables/ket
        expval_node = tn.contract_between(bra_node, ket_node)

        # remove dangling singleton edges
        expval = self._squeeze(expval_node.tensor)
    return expval
def test_split_node_rq_unitarity_float(backend):
    """The Q factor of a real RQ split contracts to the identity on either leg."""
    a = tn.Node(np.random.rand(3, 3), backend=backend)
    _, q = tn.split_node_rq(a, [a[0]], [a[1]])

    # Contract Q with conj(Q) over their second legs.
    q_copy = tn.Node(q.tensor, backend=backend)
    q_conj = tn.linalg.node_linalg.conj(q)
    q_copy[1] ^ q_conj[1]
    over_second_leg = tn.contract_between(q_copy, q_conj)

    # Contract two plain copies of Q over their first legs (real case).
    first = tn.Node(q.tensor, backend=backend)
    second = tn.Node(q.tensor, backend=backend)
    second[0] ^ first[0]
    over_first_leg = tn.contract_between(first, second)

    np.testing.assert_almost_equal(over_second_leg.tensor, np.eye(3))
    np.testing.assert_almost_equal(over_first_leg.tensor, np.eye(3))
def test_svd_consistency_symmetric_real_matrix(backend):
    """Re-contracting the two SVD factors of a real matrix reproduces it."""
    matrix = np.array(
        [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 3.0, 2.0]], dtype=np.float64)
    node = tn.Node(matrix, backend=backend)
    left, right, _ = tn.split_node(node, [node[0]], [node[1]])
    rebuilt = tn.contract_between(left, right).tensor
    np.testing.assert_allclose(rebuilt, matrix, rtol=1e-6)
def apply_op(psi, op, n1, pbc=False):
    """Apply a local operator to a wavefunction.

    `psi` is a dense tensor with one axis per lattice site, so its rank is the
    number of sites `N`. `op` has `2*k` axes: the first `k` are outputs and the
    last `k` are inputs, which are contracted against sites `n1 .. n1 + k - 1`
    of `psi` (with `0 <= n1 < N`).

    Args:
        psi: An `N`-dimensional tensor representing the wavefunction.
        op: Tensor with `2 * k` dimensions. The operator to apply.
        n1: Index of the leftmost site the operator acts on.
        pbc: If `True`, site `N` wraps around to site `0` (periodic boundary
            conditions); otherwise site `N-1` has no right neighbour.

    Returns:
        psi_final: The result of applying `op` to `psi`, with the same axis
        layout as `psi`.
    """
    psi_node = tensornetwork.Node(psi, backend="tensorflow")
    # _apply_op_network wires `op` onto the affected site edges and hands back
    # the resulting open-edge list, which fixes the output axis order.
    open_edges, op_node = _apply_op_network(
        psi_node.get_all_edges(), op, n1, pbc)
    result = tensornetwork.contract_between(
        op_node, psi_node, output_edge_order=open_edges)
    return result.tensor
def f(x: tf.Tensor, nodes: List[Node], output_dim: int, exp_base: int,
      num_nodes: int, use_bias: bool, bias_var: tf.Tensor) -> tf.Tensor:
    """Contract `x` through a cascade of weight nodes and flatten the result.

    `x` is reshaped to `num_nodes` axes of size `exp_base` followed by one axis
    of size `output_dim`. Each weight node in turn absorbs the current last two
    open axes of the state through its legs 0 and 1:

        xxxxxxxxx
        | | | |
        | | 11111
        | |   |
        | 22222
        |   |
        33333
          |
          |

    Returns:
        A 1-D tensor, with `bias_var` added when `use_bias` is true.
    """
    target_shape = (exp_base, ) * num_nodes + (output_dim, )
    state = tn.Node(tf.reshape(x, target_shape), name='xnode', backend="tensorflow")

    for idx in range(num_nodes):
        weight = tn.Node(nodes[idx], name=f'node_{idx}', backend="tensorflow")
        tn.connect(state.edges[-1], weight[0])
        tn.connect(state.edges[-2], weight[1])
        state = tn.contract_between(state, weight)

    flat = tf.reshape(state.tensor, (-1, ))
    return flat + bias_var if use_bias else flat
def multiContraction(node1: tn.Node, node2: tn.Node, edges1, edges2, nodeName=None,
                     cleanOr1=False, cleanOr2=False, isDiag1=False, isDiag2=False) -> tn.Node:
    """Contract copies of two nodes over the listed axes.

    Args:
        node1, node2: Nodes to contract. If either is None, None is returned.
        edges1, edges2: Axis indices, one digit per axis (e.g. '0123').
            The i-th axis of edges1 is joined to the i-th axis of edges2.
            A trailing '*' means that node is conjugated before contracting.
        nodeName: Name given to the contracted node.
        cleanOr1, cleanOr2: Remove the original node1 / node2 once copied.
        isDiag1, isDiag2: Treat that node's tensor as a diagonal. If both are
            diagonal, the tensors are multiplied elementwise. If one is
            diagonal, the contraction is delegated to contractDiag on the
            first axis listed for the other node.

    Returns:
        The contracted node, or None when an input node is None.
    """
    if node1 is None or node2 is None:
        return None

    def working_copy(node, spec):
        # Strip a trailing '*' and conjugate the copy when it is present.
        if spec[-1] == '*':
            return copyState([node], conj=True)[0], spec[:-1]
        return copyState([node])[0], spec

    copy1, edges1 = working_copy(node1, edges1)
    copy2, edges2 = working_copy(node2, edges2)

    if cleanOr1:
        tn.remove_node(node1)
    if cleanOr2:
        tn.remove_node(node2)

    if isDiag1 and isDiag2:
        return tn.Node(copy1.tensor * copy2.tensor)
    if isDiag1:
        return contractDiag(copy2, copy1.tensor, int(edges2[0]))
    if isDiag2:
        return contractDiag(copy1, copy2.tensor, int(edges1[0]))

    for pos, axis1 in enumerate(edges1):
        copy1[int(axis1)] ^ copy2[int(edges2[pos])]
    return tn.contract_between(copy1, copy2, name=nodeName)
def f(x: tf.Tensor, nodes: List[Node], num_nodes: int, num_legs: int,
      leg_dim: int, use_bias: bool, bias_var: tf.Tensor) -> tf.Tensor:
    """Push `x` through a staircase of two-leg weight nodes.

    `x` is reshaped to `num_legs` axes of size `leg_dim`. Weight node `i`
    absorbs open legs `i % num_legs` and `(i + 1) % num_legs` through its legs
    0 and 1. Its legs 2 and 3 then take those positions:

        | | | |
        | | 3333
        | | | |
        | 2222 |
        | | | |
        1111 | |
        | | | |
        xxxxxxxxxx

    Note: the final reshape reads `self.output_dim`. `self` is not a
    parameter, so this function relies on an enclosing method's scope.
    """
    x_node = tn.Node(tf.reshape(x, (leg_dim, ) * num_legs),
                     name='xnode', backend="tensorflow")
    open_legs = list(x_node.edges)  # independent copy; updated as nodes attach

    for i in range(num_nodes):
        lo, hi = i % num_legs, (i + 1) % num_legs
        weight = tn.Node(nodes[i], name=f'node_{i}', backend="tensorflow")
        tn.connect(open_legs[lo], weight[0])
        tn.connect(open_legs[hi], weight[1])
        open_legs[lo], open_legs[hi] = weight[2], weight[3]
        x_node = tn.contract_between(x_node, weight)

    flat = tf.reshape(x_node.tensor, (self.output_dim, ))
    return flat + bias_var if use_bias else flat
def update_dis(hamiltonian, state, isometry, disentangler):
    """Compute a new disentangler from its environment.

    The environment from `env_dis` is split by a full SVD between its
    ("bl", "br") and ("tl", "tr") legs. The singular values are dropped by
    joining U directly to V^dagger. The complex conjugate of that product is
    returned, laid out in the original (bl, br, tl, tr) leg order.

    Args:
      hamiltonian: The hamiltonian (rank-6 tensor) at the bottom of the MERA layer.
      state: The 3-site reduced state (rank-6 tensor) at the top of the MERA layer.
      isometry: The binary-MERA isometry tensor (rank 3).
      disentangler: The binary-MERA disentangler tensor (rank 4).

    Returns:
      The updated disentangler.
    """
    env = env_dis(hamiltonian, state, isometry, disentangler)
    env_node = tensornetwork.Node(
        env, axis_names=["bl", "br", "tl", "tr"], backend="jax")
    legs = [env_node[name] for name in ("bl", "br", "tl", "tr")]

    u, _, vh, _ = tensornetwork.split_node_full_svd(
        env_node, legs[:2], legs[2:],
        left_edge_name="s1", right_edge_name="s2")

    # Detach both sides of the singular-value node and bridge U to V^dagger.
    u["s1"].disconnect()
    vh["s2"].disconnect()
    tensornetwork.connect(u["s1"], vh["s2"])

    product = tensornetwork.contract_between(u, vh, output_edge_order=legs)
    return np.conj(product.get_tensor())
def f(x: tf.Tensor, nodes: List[Node], num_nodes: int, use_bias: bool,
      bias_var: tf.Tensor) -> tf.Tensor:
    """Thread one open leg of `x` through a chain of weight nodes.

    Leg 0 of `x` is joined to leg 0 of the first weight node. Each following
    node's leg 0 is joined to leg 2 of the node before it:

        | | | |
        | | 33333
        | |   |
        | 22222
        |   |
        11111
          |
        xxxxxxx

    Returns:
        The contracted tensor flattened to 1-D, plus `bias_var` when
        `use_bias` is true.
    """
    state = tn.Node(x, name='xnode', backend="tensorflow")
    dangling = state[0]

    for step in range(num_nodes):
        weight = tn.Node(nodes[step], name=f'node_{step}', backend="tensorflow")
        tn.connect(dangling, weight[0])
        dangling = weight[2]
        state = tn.contract_between(state, weight)

    flat = tf.reshape(state.tensor, (-1,))
    return flat + bias_var if use_bias else flat
def test_apply_two_site_gate(backend_dtype_values):
    """BaseMPS.apply_two_site_gate matches a dense tensordot reference."""
    backend, dtype = backend_dtype_values[0], backend_dtype_values[1]
    D, d, N = 10, 2, 10
    tensors = ([get_random_np((1, d, D), dtype)]
               + [get_random_np((D, d, D), dtype) for _ in range(N - 2)]
               + [get_random_np((D, d, 1), dtype)])
    mps = BaseMPS(tensors, center_position=0, backend=backend)
    gate = get_random_np((d, ) * 4, dtype)

    before_5, before_6 = mps.tensors[5], mps.tensors[6]
    mps.apply_two_site_gate(gate, 5, 6)

    # Dense reference: merge the two sites and contract the gate's input legs
    # (2, 3) with the physical legs (1, 2) of the merged pair.
    merged = np.tensordot(before_5, before_6, ([2], [0]))
    expected = np.transpose(
        np.tensordot(merged, gate, ([1, 2], [2, 3])), (0, 2, 3, 1))

    left = tn.Node(mps.tensors[5], backend=backend)
    right = tn.Node(mps.tensors[6], backend=backend)
    left[2] ^ right[0]
    layout = [left[0], left[1], right[1], right[2]]
    joined = tn.contract_between(left, right)
    joined.reorder_edges(layout)
    np.testing.assert_allclose(joined.tensor, expected)
def test_contract_between_trace(backend):
    """A self-contraction over a trace edge honours custom axis names."""
    node = tn.Node(np.ones((2, 3, 2, 4)), backend=backend)
    tn.connect(node[0], node[2])
    traced = tn.contract_between(node, node, axis_names=["1", "3"])
    assert traced.shape == (3, 4)
    assert traced.axis_names == ["1", "3"]
def test_contract_between_outer_product_no_value_error(backend):
    """With allow_outer_product=True, unconnected nodes form an outer product."""
    left = tn.Node(np.ones((2, 3, 4)), backend=backend)
    right = tn.Node(np.ones((5, 6, 7)), backend=backend)
    product = tn.contract_between(left, right, allow_outer_product=True)
    assert product.shape == (2, 3, 4, 5, 6, 7)
def _ev_exact(self, obs_nodes, obs_wires):
    r"""Expectation value of observables on specified wires using an exact representation.

    Args:
        obs_nodes (Sequence[tn.Node]): the observables as TensorNetwork Nodes
        obs_wires (Sequence[Wires]): measured wires for each observable

    Returns:
        complex: expectation value :math:`\expect{A} = \bra{\psi}A\ket{\psi}`
    """
    self._contract_premeasurement_network()
    ket = self._contracted_state_node
    bra = tn.conj(ket, name="Bra")

    all_device_wires = Wires(range(self.num_wires))
    measured = []

    # Observable legs follow [output_1, ..., output_k, input_1, ..., input_k]:
    # inputs attach to the ket (A|psi>), outputs to the bra (<psi|A).
    for obs_node, wires in zip(obs_nodes, obs_wires):
        device_wires = self.map_wires(wires)  # consecutive device labels
        measured.append(device_wires)
        width = len(device_wires)
        for pos, label in enumerate(device_wires.labels):
            tn.connect(obs_node[width + pos], ket[label])
            tn.connect(bra[label], obs_node[pos])

    measured = Wires(measured)
    # Wires no observable touches are traced out directly between bra and ket.
    for w in Wires.unique_wires([all_device_wires, measured]).labels:
        tn.connect(bra[w], ket[w])

    # Everything is connected now, so the final contraction is a scalar.
    ket_with_obs = ket
    for obs_node in obs_nodes:
        ket_with_obs = tn.contract_between(obs_node, ket_with_obs)
    return tn.contract_between(bra, ket_with_obs).tensor
def f(x, n):
    """Fully contract the first `n` entries of `x`'s last axis with themselves.

    Two nodes are built on the same slice `x[..., :n]`. Their axes 0, 1 and 2
    are paired up and contracted, and the result is returned as a tensor.
    """
    sliced = x[..., :n]
    left = Node(sliced, backend="pytorch")
    right = Node(sliced, backend="pytorch")
    for axis in range(3):
        connect(left[axis], right[axis])
    return contract_between(left, right).get_tensor()
def test_contract_between_trace_edges(backend):
    """Self-contracting a matrix over a loop edge yields its trace."""
    matrix = np.ones((3, 3))
    expected = np.trace(matrix)
    node = tn.Node(matrix, backend=backend)
    tn.connect(node[0], node[1])
    traced = tn.contract_between(node, node)
    tn.check_correct({traced})
    np.testing.assert_allclose(traced.tensor, expected)
def test_contract_between_trace_output_edge_order(backend):
    """A trace contraction honours both output_edge_order and axis_names."""
    node = tn.Node(np.ones((2, 3, 2, 4)), backend=backend)
    tn.connect(node[0], node[2])
    traced = tn.contract_between(
        node, node,
        output_edge_order=[node[3], node[1]],
        axis_names=["3", "1"])
    assert traced.shape == (4, 3)
    assert traced.axis_names == ["3", "1"]
def ev(self, obs_nodes, wires):
    r"""Expectation value of observables on specified wires.

    Args:
        obs_nodes (Sequence[tn.Node]): the observables as tensornetwork Nodes
        wires (Sequence[Sequence[int]]): measured subsystems for each observable

    Returns:
        float: expectation value :math:`\expect{A} = \bra{\psi}A\ket{\psi}`.
        The real part is returned; a RuntimeWarning is issued when the
        imaginary part exceeds ``tolerance``.
    """
    all_wires = tuple(range(self.num_wires))
    ket = self._add_node(self._state, wires=all_wires, name="Ket")
    bra = self._add_node(tn.conj(ket), wires=all_wires, name="Bra")
    meas_wires = []

    # Build <psi|A|psi>. Observable legs are ordered
    # [output_1, ..., output_k, input_1, ..., input_k]: inputs connect to the
    # ket (A|psi>), outputs to the bra (<psi|A).
    for obs_node, obs_wires in zip(obs_nodes, wires):
        meas_wires.extend(obs_wires)
        width = len(obs_wires)
        for pos, w in enumerate(obs_wires):
            self._add_edge(obs_node, width + pos, ket, w)
            self._add_edge(bra, w, obs_node, pos)

    # Unmeasured wires: join bra directly to ket (|psi[w]|**2).
    for w in set(all_wires) - set(meas_wires):
        self._add_edge(bra, w, ket, w)

    # Every edge is now connected, so the last contraction yields a scalar.
    ket_with_obs = ket
    for obs_node in obs_nodes:
        ket_with_obs = tn.contract_between(obs_node, ket_with_obs)
    expval = tn.contract_between(bra, ket_with_obs).tensor

    if np.abs(expval.imag) > tolerance:
        warnings.warn(
            "Nonvanishing imaginary part {} in expectation value.".format(
                expval.imag),
            RuntimeWarning,
        )
    return expval.real
def test_contract_between_trace_edges(dtype, num_charges):
    """Tracing a block-symmetric matrix matches the trace of its dense form."""
    sym_matrix = get_random_symmetric(
        (50, 50), [False, True], num_charges, dtype=dtype)
    expected = np.trace(sym_matrix.todense())
    node = tn.Node(sym_matrix, backend='symmetric')
    tn.connect(node[0], node[1])
    traced = tn.contract_between(node, node)
    tn.check_correct({traced})
    np.testing.assert_allclose(traced.tensor.todense(), expected)
def test_contract_between_outer_product_no_value_error(backend):
    """An allowed outer product keeps every axis and applies the given names."""
    left = tn.Node(np.ones((2, 3, 4)), backend=backend)
    right = tn.Node(np.ones((5, 6, 7)), backend=backend)
    names = [f"{prefix}{k}" for prefix in "ab" for k in range(3)]
    product = tn.contract_between(
        left, right, allow_outer_product=True, axis_names=names)
    assert product.shape == (2, 3, 4, 5, 6, 7)
    assert product.axis_names == names
def test_split_node_rq_unitarity_complex(backend):
    """The Q factor of a complex RQ split contracts with conj(Q) to the identity."""
    if backend == "pytorch":
        pytest.skip("Complex numbers currently not supported in PyTorch")
    if backend == "jax":
        pytest.skip("Complex QR crashes jax")

    matrix = np.random.rand(3, 3) + 1j * np.random.rand(3, 3)
    a = tn.Node(matrix, backend=backend)
    _, q = tn.split_node_rq(a, [a[0]], [a[1]])

    # Q against conj(Q) over the second leg.
    q_copy = tn.Node(q.tensor, backend=backend)
    q_conj = tn.linalg.node_linalg.conj(q)
    q_copy[1] ^ q_conj[1]
    over_second_leg = tn.contract_between(q_copy, q_conj)

    # Q against conj(Q) over the first leg.
    q_copy = tn.Node(q.tensor, backend=backend)
    q_conj = tn.linalg.node_linalg.conj(q)
    q_conj[0] ^ q_copy[0]
    over_first_leg = tn.contract_between(q_copy, q_conj)

    np.testing.assert_almost_equal(over_second_leg.tensor, np.eye(3))
    np.testing.assert_almost_equal(over_first_leg.tensor, np.eye(3))
def test_svd_consistency(backend):
    """Re-contracting the SVD factors of a complex matrix reproduces it."""
    if backend == "pytorch":
        pytest.skip("Complex numbers currently not supported in PyTorch")
    matrix = np.array(
        [[1.0, 2.0j, 3.0, 4.0], [5.0, 6.0 + 1.0j, 3.0j, 2.0 + 1.0j]],
        dtype=np.complex64)
    node = tn.Node(matrix, backend=backend)
    left, right, _ = tn.split_node(node, [node[0]], [node[1]])
    rebuilt = tn.contract_between(left, right).tensor
    np.testing.assert_allclose(rebuilt, matrix, rtol=1e-6)
def apply(self, operation, wires, par):
    """Contract the matrix of `operation` onto the state at `wires`.

    The operator matrix is reshaped to `2 * len(wires)` axes of size 2. Its
    trailing half is connected to the state's legs on `wires`. Its leading half
    replaces those entries in `self._free_edges`, which then sets the output
    edge order of the new `self._state`.
    """
    width = len(wires)
    matrix = np.reshape(
        self._get_operator_matrix(operation, par), [2] * (2 * width))
    op_node = self._add_node(matrix, wires=wires, name=operation)

    for pos, w in enumerate(wires):
        # operator input leg -> current state leg on wire w
        self._add_edge(op_node, width + pos, self._state, w)
        # operator output leg becomes the open edge for wire w
        self._free_edges[w] = op_node[pos]

    # TODO: can be smarter here about collecting contractions?
    self._state = tn.contract_between(
        op_node, self._state, output_edge_order=self._free_edges)
def test_svd_consistency(dtype, num_charges):
    """SVD factors of a symmetric tensor re-contract to its data and charges."""
    np.random.seed(111)
    original = get_random((20, 20), num_charges, dtype)
    node = tn.Node(original, backend='symmetric')
    left, right, _ = tn.split_node(node, [node[0]], [node[1]])
    rebuilt = tn.contract_between(left, right).tensor

    np.testing.assert_allclose(rebuilt.data, original.data, rtol=1e-6)
    charges_match = [
        charge_equal(rebuilt._charges[k], original._charges[k])
        for k in range(len(original._charges))
    ]
    assert np.all(charges_match)
def test_split_edges_standard_contract_between(backend):
    """Splitting a shared edge leaves the contraction value unchanged."""
    a = tn.Node(np.random.randn(6, 3, 5), name="A", backend=backend)
    b = tn.Node(np.random.randn(2, 4, 6, 3), name="B", backend=backend)
    edge_to_split = tn.connect(a[0], b[2], "Edge_1_1")
    tn.connect(a[1], b[3], "Edge_1_2")  # stays a plain edge

    # Reference value from an untouched copy of the network.
    copies, _ = tn.copy({a, b})
    reference = copies[a] @ copies[b]

    tn.split_edge(edge_to_split, (2, 1, 3))  # 2 * 1 * 3 == 6
    tn.check_correct({a, b})
    after_split = tn.contract_between(a, b)
    np.testing.assert_allclose(reference.tensor, after_split.tensor)
def svdTruncation(node: tn.Node, leftEdges: List[tn.Edge], rightEdges: List[tn.Edge],
                  dir: str, maxBondDim=1024, leftName='U', rightName='V', edgeName=None):
    """Split `node` by truncated SVD into two factors.

    The singular values are absorbed into the right factor when `dir == '>>'`
    and into the left factor otherwise. `edgeName` names the left SVD bond in
    the '>>' case and the right one otherwise. Both returned factors are fresh
    copies (via copyState), and the raw U, S, V nodes are removed.

    Returns:
        [left, right, truncErr], where truncErr is the truncation error
        returned by tn.split_node_full_svd.
    """
    maxBondDim = getAppropriateMaxBondDim(maxBondDim, leftEdges, rightEdges)
    absorb_right = dir == '>>'
    left_edge_name, right_edge_name = (edgeName, None) if absorb_right else (None, edgeName)

    U, S, V, truncErr = tn.split_node_full_svd(
        node, leftEdges, rightEdges,
        max_singular_values=maxBondDim,
        left_name=leftName, right_name=rightName,
        left_edge_name=left_edge_name, right_edge_name=right_edge_name)

    if absorb_right:
        left = copyState([U])[0]
        right = copyState([tn.contract_between(S, V, name=V.name)])[0]
    else:
        left = copyState([tn.contract_between(U, S, name=U.name)])[0]
        right = copyState([V])[0]

    for factor in (U, S, V):
        tn.remove_node(factor)
    return [left, right, truncErr]