def test_copy_node_get_partners_with_trace(backend):
    """Check `get_partners` on a copy node that has a self-connected edge pair.

    Edges 0 and 1 of the copy node are connected to each other, while edges
    2 and 3 are connected to axes 0 and 1 of a regular node. The expected
    result lists only that regular node, mapped to the set of its axes
    {0, 1}.
    """
    copy_node = CopyNode(4, 2, backend=backend)
    partner = Node(np.random.rand(2, 2), backend=backend, name="node2")

    # Connect the copy node's first two edges to each other (the "trace").
    tn.connect(copy_node[0], copy_node[1])
    # Attach the remaining two edges to the regular node.
    tn.connect(copy_node[2], partner[0])
    tn.connect(copy_node[3], partner[1])

    expected = {partner: {0, 1}}
    assert copy_node.get_partners() == expected
def contract_copy_node(
        self,
        copy_node: network_components.CopyNode,
        name: Optional[Text] = None) -> network_components.BaseNode:
    """Contract every edge attached to the given copy node.

    The contraction itself is delegated to
    `network_components.contract_copy_node`; this method registers the
    resulting node with the network and retires the copy node together
    with the nodes it was connected to.

    Args:
      copy_node: Copy tensor node to be contracted.
      name: Name of the new node created.

    Returns:
      New node representing contracted tensor.

    Raises:
      ValueError: If copy_node has dangling edge(s).
    """
    # The partners are looked up before the contraction is performed,
    # and the same collection is reused for the clean-up below.
    neighbours = copy_node.get_partners()
    contracted = network_components.contract_copy_node(copy_node, name)
    result = self.add_node(contracted)

    # Remove the partner nodes from the network's bookkeeping first ...
    for neighbour in neighbours:
        if neighbour in self.nodes_set:
            self.nodes_set.remove(neighbour)
    # ... then disable any partner that has not been disabled already.
    for neighbour in neighbours:
        if neighbour.is_disabled:
            continue
        neighbour.disable()

    # Finally retire the copy node itself.
    self.nodes_set.remove(copy_node)
    copy_node.disable()
    return result
def contract_copy_node(
        self,
        copy_node: network_components.CopyNode,
        name: Optional[Text] = None) -> network_components.Node:
    """Contract every edge attached to the given copy node.

    The contracted tensor is computed by the copy node and added to the
    network as a new node. Each edge of each partner node is then either
    dropped from `edge_order` (when it touches the copy node) or re-pointed
    at the new node on the next free axis.

    Args:
      copy_node: Copy tensor node to be contracted.
      name: Name of the new node created.

    Returns:
      New node representing contracted tensor.

    Raises:
      ValueError: If copy_node has dangling edge(s).
    """
    contracted_tensor = copy_node.compute_contracted_tensor()
    merged = self.add_node(contracted_tensor, name)

    neighbours = copy_node.get_partners()
    next_axis = 0  # Next unassigned axis on the merged node.
    for neighbour in neighbours:
        for edge in neighbour.edges:
            if edge.node1 is copy_node or edge.node2 is copy_node:
                # Edge touches the copy node: drop it from the ordering.
                self.edge_order.remove(edge)
            else:
                # Work out which end of the edge sits on this neighbour,
                # then move that end over to the merged node.
                if edge.node1 is neighbour:
                    source_axis = edge.axis1
                else:
                    source_axis = edge.axis2
                edge.update_axis(
                    old_node=neighbour,
                    old_axis=source_axis,
                    new_node=merged,
                    new_axis=next_axis)
                merged.add_edge(edge, next_axis)
                next_axis += 1
        self.nodes_set.remove(neighbour)
    # Sanity check: one axis of the contracted tensor per re-pointed edge.
    assert len(contracted_tensor.shape) == next_axis

    self.nodes_set.remove(copy_node)
    self.nodes_set.add(merged)
    return merged