def find_parallel(
    edge: network_components.Edge
) -> Tuple[Set[network_components.Edge], int]:
  """Finds all edges shared between the nodes connected with the given edge.

  Args:
    edge: A non-dangling edge between two different nodes.

  Returns:
    parallel_edges: Edges that are parallel to the given edge.
    parallel_dim: Product of sizes of all parallel edges. Each edge is counted
      once; sizes reported as None by the tensor shape are skipped.

  Raises:
    ValueError: If `edge` is dangling.
  """
  if edge.is_dangling():
    raise ValueError(
        "Cannot find parallel edges for dangling edge {}".format(edge))
  nodes = {edge.node1, edge.node2}
  parallel_dim = 1
  parallel_edges = set()
  for e in edge.node1.edges:
    # A trace edge occupies two axes of node1 and is therefore yielded twice
    # here; skip repeats so its size is not multiplied into parallel_dim twice.
    if set(e.get_nodes()) != nodes or e in parallel_edges:
      continue
    parallel_edges.add(e)
    edge_size = list(e.node1.get_tensor().shape)[e.axis1]
    if edge_size is not None:
      parallel_dim *= edge_size
  return parallel_edges, parallel_dim
def _contract_trace(
    self,
    edge: network_components.Edge,
    name: Optional[Text] = None) -> network_components.Node:
  """Contract a trace edge (an edge whose both ends are on the same node).

  Args:
    edge: The trace edge to contract.
    name: Name to give to the new node. If None, a name will automatically be
      generated.

  Returns:
    new_node: The new node created after the contraction.

  Raises:
    ValueError: When `edge` is dangling or connects two different nodes.
  """
  if edge.is_dangling():
    raise ValueError(
        "Attempted to contract dangling edge '{}'".format(edge))
  if edge.node1 is not edge.node2:
    raise ValueError(
        "Can not take trace of edge '{}'. This edge connects to "
        "two different nodes: '{}' and '{}'".format(
            edge, edge.node1, edge.node2))
  axes = sorted([edge.axis1, edge.axis2])
  dims = len(edge.node1.tensor.shape)
  # Keep all untraced axes in their original order and move the two traced
  # axes to the end. NOTE(review): this assumes backend.trace contracts the
  # last two axes -- confirm against the backend implementation.
  permutation = sorted(set(range(dims)) - set(axes)) + axes
  new_tensor = self.backend.trace(
      self.backend.transpose(edge.node1.tensor, perm=permutation))
  new_node = self.add_node(new_tensor, name)
  self._remove_trace_edge(edge, new_node)
  return new_node
def contract(self,
             edge: network_components.Edge,
             name: Optional[Text] = None) -> network_components.Node:
  """Contract a single edge of the TensorNetwork into a new node.

  Trace edges (both ends on one node) are delegated to `_contract_trace`;
  otherwise the two endpoint tensors are combined with `backend.tensordot`
  over the edge's axes.

  Args:
    edge: The edge to contract.
    name: Optional name for the resulting node.

  Returns:
    new_node: The node produced by the contraction.

  Raises:
    ValueError: If the edge is not part of this network (not in use, or its
      first node is not in `nodes_set`), or if the edge is dangling.
  """
  belongs_to_network = (edge.is_being_used() and
                        edge.node1 in self.nodes_set)
  if not belongs_to_network:
    raise ValueError(
        "Attempting to contract edge '{}' that is not part of "
        "the network.".format(edge))
  if edge.is_dangling():
    raise ValueError("Attempting to contract dangling edge")
  if edge.node1 is edge.node2:
    return self._contract_trace(edge, name)
  contracted_axes = [[edge.axis1], [edge.axis2]]
  result_tensor = self.backend.tensordot(edge.node1.tensor,
                                         edge.node2.tensor,
                                         contracted_axes)
  result_node = self.add_node(result_tensor, name)
  self._remove_edges({edge}, edge.node1, edge.node2, result_node)
  return result_node
def disconnect(
    self,
    edge: network_components.Edge,
    dangling_edge_name_1: Optional[Text] = None,
    dangling_edge_name_2: Optional[Text] = None
) -> List[network_components.Edge]:
  """Split a connected edge into a pair of dangling edges.

  Each endpoint of `edge` receives a fresh dangling edge at the axis the old
  edge occupied, and `edge` is removed from `self.edge_order`.

  Args:
    edge: The connected edge to split.
    dangling_edge_name_1: Optional name for the dangling edge placed on
      `edge.node1`; passed through `self._new_edge_name`.
    dangling_edge_name_2: Optional name for the dangling edge placed on
      `edge.node2`; passed through `self._new_edge_name`.

  Returns:
    A two-element list `[dangling_edge_1, dangling_edge_2]` holding the new
    edges attached to `edge.node1` and `edge.node2` respectively.

  Raises:
    ValueError: If `edge` is already dangling.
  """
  if edge.is_dangling():
    raise ValueError(
        "Attempted to break a dangling edge '{}'.".format(edge))
  left_node, right_node = edge.node1, edge.node2
  left_name = self._new_edge_name(dangling_edge_name_1)
  right_name = self._new_edge_name(dangling_edge_name_2)
  left_edge = network_components.Edge(left_name, left_node, edge.axis1)
  right_edge = network_components.Edge(right_name, right_node, edge.axis2)
  left_node.add_edge(left_edge, edge.axis1, True)
  right_node.add_edge(right_edge, edge.axis2, True)
  self.edge_order.remove(edge)
  return [left_edge, right_edge]
def _remove_edge(self, edge: network_components.Edge,
                 new_node: network_components.Node) -> None:
  """Collapse an edge in the network.

  Collapses an edge and updates the rest of the network. Every other edge of
  the two endpoint nodes is re-pointed at `new_node`. The remaining axes are
  numbered with node1's axes first (in order), then node2's axes. Both old
  nodes are removed from `nodes_set`.

  Args:
    edge: The edge to contract.
    new_node: The new node that represents the contraction of the two old
      nodes.

  Raises:
    ValueError: If `edge` is a dangling edge.
  """
  # Assert that the edge isn't a dangling edge.
  if edge.is_dangling():
    raise ValueError(
        "Attempted to remove dangling edge '{}'.".format(edge))
  if edge.node1 is edge.node2:
    # `_remove_trace_edge` rewires the remaining edges and removes the node
    # from `nodes_set` itself. Falling through would rewire them a second
    # time and remove the same node twice (KeyError).
    self._remove_trace_edge(edge, new_node)
    return
  # Collapse the nodes into a new node and remove the edge.
  node1 = edge.node1
  node2 = edge.node2
  node1_edges = edge.node1.edges[:]
  node2_edges = edge.node2.edges[:]
  node1_axis = edge.axis1
  node2_axis = edge.axis2
  # Redefine all other edges.
  # node1 contributes all of its axes except the contracted one, so node2's
  # surviving axes are offset by that many positions in new_node.
  num_added_front_edges = len(node1_edges) - 1
  for i, tmp_edge in enumerate(node1_edges[:node1_axis]):
    tmp_edge.update_axis(
        old_axis=i, old_node=node1, new_axis=i, new_node=new_node)
  for i, tmp_edge in enumerate(node1_edges[node1_axis + 1:]):
    # Axes after the contracted one shift down by one.
    tmp_edge.update_axis(
        old_axis=i + node1_axis + 1,
        old_node=node1,
        new_axis=i + node1_axis,
        new_node=new_node)
  for i, tmp_edge in enumerate(node2_edges[:node2_axis]):
    tmp_edge.update_axis(
        old_axis=i,
        old_node=node2,
        new_axis=i + num_added_front_edges,
        new_node=new_node)
  for i, tmp_edge in enumerate(node2_edges[node2_axis + 1:]):
    tmp_edge.update_axis(
        old_axis=i + node2_axis + 1,
        old_node=node2,
        new_axis=i + node2_axis + num_added_front_edges,
        new_node=new_node)

  node1_edges.pop(node1_axis)
  node2_edges.pop(node2_axis)
  new_edges = node1_edges + node2_edges
  for i, e in enumerate(new_edges):
    new_node.add_edge(e, i)

  # Remove nodes
  self.nodes_set.remove(node1)
  self.nodes_set.remove(node2)
def redirect_edge(edge: Edge, new_node: AbstractNode,
                  old_node: AbstractNode) -> None:
  """Move `edge` from `old_node` onto `new_node`.

  `edge` is attached to `new_node` at the axis (or, for a trace edge, both
  axes) it occupied on `old_node`. `old_node` receives a freshly constructed
  Edge in those same slots, so it is left with no reference to `edge`.

  Args:
    edge: The Edge to move.
    new_node: The node that will own `edge` afterwards.
    old_node: The node `edge` currently points to.

  Returns:
    None

  Raises:
    ValueError: If `edge` is not attached to `old_node`.
  """
  if edge.is_trace():
    # Both ends live on the same node, so both must be moved together.
    if edge.node1 is not old_node:
      raise ValueError(f"edge {edge} is not pointing "
                       f"to old_node {old_node}")
    edge.node1 = new_node
    edge.node2 = new_node
    first_axis, second_axis = edge.axis1, edge.axis2
    for slot in (first_axis, second_axis):
      new_node.add_edge(edge, slot, True)
    replacement = Edge(old_node, first_axis, None, old_node, second_axis)
    for slot in (first_axis, second_axis):
      old_node.add_edge(replacement, slot, True)
    return

  dangling = edge.is_dangling()
  if edge.node1 is old_node:
    edge.node1 = new_node
    axis = edge.axis1
  elif not dangling and edge.node2 is old_node:
    edge.node2 = new_node
    axis = edge.axis2
  else:
    raise ValueError(f"edge {edge} is not pointing "
                     f"to old_node {old_node}")
  new_node.add_edge(edge, axis, True)
  old_node.add_edge(Edge(old_node, axis), axis, True)
def contract_parallel(
    self, edge: network_components.Edge) -> network_components.Node:
  """Contract every edge shared by the two endpoints of `edge`.

  Thin wrapper that forwards both endpoint nodes to `contract_between`.

  Args:
    edge: Any non-dangling edge between the two nodes to merge.

  Returns:
    The node produced by `contract_between`.

  Raises:
    ValueError: If `edge` is dangling.
  """
  if not edge.is_dangling():
    return self.contract_between(edge.node1, edge.node2)
  raise ValueError("Attempted to contract dangling edge: '{}'".format(edge))
def _remove_trace_edge(self, edge: network_components.Edge,
                       new_node: network_components.Node) -> None:
  """Collapse a trace edge.

  Removes both axes of the trace edge from its node. Every other edge of that
  node is re-pointed at `new_node`, with axis indices renumbered to close the
  two gaps. The old node is removed from `nodes_set`.

  Args:
    edge: The trace edge to contract.
    new_node: The new node created after contraction.

  Raises:
    ValueError: If edge is dangling or is not a trace edge.
  """
  if edge.is_dangling():
    raise ValueError(
        "Attempted to remove dangling edge '{}'.".format(edge))
  if edge.node1 is not edge.node2:
    raise ValueError("Edge '{}' is not a trace edge.".format(edge))
  old_node = edge.node1
  axes = sorted([edge.axis1, edge.axis2])
  node_edges = old_node.edges[:]
  # Pop the smaller axis first; the larger index then shifts down by one.
  node_edges.pop(axes[0])
  node_edges.pop(axes[1] - 1)

  def shifted(axis: int) -> int:
    # Each traced axis below `axis` disappears, moving `axis` down by one.
    return axis - sum(1 for traced in axes if axis > traced)

  # Another trace edge on the same node appears twice in node_edges; remap it
  # only once so its axes are not shifted twice.
  seen_edges = set()
  for tmp_edge in node_edges:
    if tmp_edge in seen_edges:
      continue
    seen_edges.add(tmp_edge)
    if tmp_edge.node1 is old_node:
      tmp_edge.axis1 = shifted(tmp_edge.axis1)
      tmp_edge.node1 = new_node
    if tmp_edge.node2 is old_node:
      tmp_edge.axis2 = shifted(tmp_edge.axis2)
      tmp_edge.node2 = new_node
  # Update edges for the new node.
  for i, e in enumerate(node_edges):
    new_node.add_edge(e, i)
  self.nodes_set.remove(old_node)