def find_parallel(
        edge: network_components.Edge
) -> Tuple[Set[network_components.Edge], int]:
    """Finds all edges shared between the nodes connected with the given edge.

    An edge counts as parallel to `edge` when it connects exactly the same set
    of nodes.

    Args:
      edge: A non-dangling edge between two different nodes.

    Returns:
      parallel_edges: Edges that are parallel to the given edge.
      parallel_dim: Product of sizes of all parallel edges. Axes whose size is
        reported as `None` are left out of the product.

    Raises:
      ValueError: If `edge` is dangling.
    """
    if edge.is_dangling():
        raise ValueError(
            "Cannot find parallel edges for dangling edge {}".format(edge))
    nodes = {edge.node1, edge.node2}
    parallel_dim = 1
    parallel_edges = set()
    for e in edge.node1.edges:
        # A trace edge occupies two axes of `node1` and is therefore listed
        # twice; count each edge's size only once.
        if e in parallel_edges:
            continue
        if set(e.get_nodes()) == nodes:
            parallel_edges.add(e)
            edge_size = list(e.node1.get_tensor().shape)[e.axis1]
            if edge_size is not None:
                parallel_dim *= edge_size
    return parallel_edges, parallel_dim
示例#2
0
    def _contract_trace(
            self,
            edge: network_components.Edge,
            name: Optional[Text] = None) -> network_components.Node:
        """Contract a trace edge (an edge whose two ends are on the same node).

        The two traced axes are moved to the end of the node's tensor with a
        transpose, and the result is passed to `self.backend.trace` (assumed
        to trace over the last two axes). The resulting tensor is added as a
        new node, and `_remove_trace_edge` rewires the remaining edges onto it.

        Args:
          edge: The trace edge to contract.
          name: Name to give to the new node. If None, a name will
            automatically be generated.

        Returns:
          new_node: The new node created after the contraction.

        Raises:
          ValueError: When `edge` is dangling, or when it connects two
            different nodes (i.e. it is not a trace edge).
        """
        if edge.is_dangling():
            raise ValueError(
                "Attempted to contract dangling edge '{}'".format(edge))
        if edge.node1 is not edge.node2:
            raise ValueError(
                "Can not take trace of edge '{}'. This edge connects to "
                "two different nodes: '{}' and '{}'".format(
                    edge, edge.node1, edge.node2))
        axes = sorted([edge.axis1, edge.axis2])
        dims = len(edge.node1.tensor.shape)
        # Keep all untraced axes in their original order and put the two
        # traced axes last.
        permutation = sorted(set(range(dims)) - set(axes)) + axes
        new_tensor = self.backend.trace(
            self.backend.transpose(edge.node1.tensor, perm=permutation))
        new_node = self.add_node(new_tensor, name)
        self._remove_trace_edge(edge, new_node)
        return new_node
示例#3
0
    def contract(self,
                 edge: network_components.Edge,
                 name: Optional[Text] = None) -> network_components.Node:
        """Contract an edge connecting two nodes in the TensorNetwork.

        Trace edges (both ends on the same node) are delegated to
        `_contract_trace`. Otherwise, the two node tensors are combined with
        `self.backend.tensordot` over the edge's two axes, and the result is
        added to the network as a new node.

        Args:
          edge: The edge to contract next.
          name: Optional name of the new node created.

        Returns:
          new_node: The new node created after the contraction.

        Raises:
          ValueError: When the edge is not part of the network (it is not in
            use, e.g. because it was already contracted, or its first node is
            not in `nodes_set`), or when it is a dangling edge.
        """
        if not edge.is_being_used() or edge.node1 not in self.nodes_set:
            raise ValueError(
                "Attempting to contract edge '{}' that is not part of "
                "the network.".format(edge))
        if edge.is_dangling():
            # Name the offending edge, as the other contraction errors do.
            raise ValueError(
                "Attempting to contract dangling edge '{}'".format(edge))
        if edge.node1 is edge.node2:
            return self._contract_trace(edge, name)
        new_tensor = self.backend.tensordot(edge.node1.tensor,
                                            edge.node2.tensor,
                                            [[edge.axis1], [edge.axis2]])
        new_node = self.add_node(new_tensor, name)
        self._remove_edges({edge}, edge.node1, edge.node2, new_node)
        return new_node
示例#4
0
    def disconnect(
        self,
        edge: network_components.Edge,
        dangling_edge_name_1: Optional[Text] = None,
        dangling_edge_name_2: Optional[Text] = None
    ) -> List[network_components.Edge]:
        """Split a connected edge into two new dangling edges.

        Each endpoint of `edge` gets a fresh dangling edge on the same axis
        the original edge occupied. The original edge is then dropped from
        `self.edge_order`.

        Args:
          edge: The connected edge to split.
          dangling_edge_name_1: Optional name for the dangling edge placed on
            `edge.node1`. It is passed through `self._new_edge_name`.
          dangling_edge_name_2: Optional name for the dangling edge placed on
            `edge.node2`. It is passed through `self._new_edge_name`.

        Returns:
          A two-element list `[dangling_edge_1, dangling_edge_2]`, attached to
          `edge.node1` and `edge.node2` respectively.

        Raises:
          ValueError: If `edge` is already dangling.
        """
        if edge.is_dangling():
            raise ValueError(
                "Attempted to break a dangling edge '{}'.".format(edge))
        head_node, head_axis = edge.node1, edge.axis1
        tail_node, tail_axis = edge.node2, edge.axis2
        head_name = self._new_edge_name(dangling_edge_name_1)
        tail_name = self._new_edge_name(dangling_edge_name_2)
        head_edge = network_components.Edge(head_name, head_node, head_axis)
        tail_edge = network_components.Edge(tail_name, tail_node, tail_axis)
        # Override whatever edge currently occupies each axis.
        head_node.add_edge(head_edge, head_axis, True)
        tail_node.add_edge(tail_edge, tail_axis, True)
        self.edge_order.remove(edge)
        return [head_edge, tail_edge]
示例#5
0
    def _remove_edge(self, edge: network_components.Edge,
                     new_node: network_components.Node) -> None:
        """Collapse an edge in the network.

        Every other edge of the two endpoint nodes is rewired onto `new_node`,
        and both endpoint nodes are removed from `self.nodes_set`. On
        `new_node`, the surviving axes of `edge.node1` come first, in order,
        followed by the surviving axes of `edge.node2`. Trace edges are handed
        off entirely to `_remove_trace_edge`.

        Args:
          edge: The edge to contract.
          new_node: The new node that represents the contraction of the two old
            nodes.

        Raises:
          ValueError: If `edge` is a dangling edge.
        """
        # Reject dangling edges: there is nothing to collapse.
        if edge.is_dangling():
            raise ValueError(
                "Attempted to remove dangling edge '{}'.".format(edge))
        if edge.node1 is edge.node2:
            # A trace edge has a single node, and `_remove_trace_edge` already
            # rewires its edges and removes it from `nodes_set`. Falling
            # through would try to remove that same node again.
            self._remove_trace_edge(edge, new_node)
            return
        # Collapse the nodes into a new node and remove the edge.
        node1 = edge.node1
        node2 = edge.node2
        node1_edges = edge.node1.edges[:]
        node2_edges = edge.node2.edges[:]
        node1_axis = edge.axis1
        node2_axis = edge.axis2
        # node1 contributes all of its axes except the contracted one, and
        # those come first on `new_node`; node2's axes are offset past them.
        num_added_front_edges = len(node1_edges) - 1
        # Redefine all other edges. On each side, axes after the contracted
        # one shift down by one.
        for i, tmp_edge in enumerate(node1_edges[:node1_axis]):
            tmp_edge.update_axis(old_axis=i,
                                 old_node=node1,
                                 new_axis=i,
                                 new_node=new_node)
        for i, tmp_edge in enumerate(node1_edges[node1_axis + 1:]):
            tmp_edge.update_axis(old_axis=i + node1_axis + 1,
                                 old_node=node1,
                                 new_axis=i + node1_axis,
                                 new_node=new_node)
        for i, tmp_edge in enumerate(node2_edges[:node2_axis]):
            tmp_edge.update_axis(old_axis=i,
                                 old_node=node2,
                                 new_axis=i + num_added_front_edges,
                                 new_node=new_node)
        for i, tmp_edge in enumerate(node2_edges[node2_axis + 1:]):
            tmp_edge.update_axis(old_axis=i + node2_axis + 1,
                                 old_node=node2,
                                 new_axis=i + node2_axis +
                                 num_added_front_edges,
                                 new_node=new_node)

        # Drop the contracted edge from both lists and attach the rest, in
        # order, to the new node.
        node1_edges.pop(node1_axis)
        node2_edges.pop(node2_axis)
        new_edges = node1_edges + node2_edges
        for i, e in enumerate(new_edges):
            new_node.add_edge(e, i)

        # Remove nodes
        self.nodes_set.remove(node1)
        self.nodes_set.remove(node2)
示例#6
0
def redirect_edge(edge: Edge, new_node: AbstractNode,
                  old_node: AbstractNode) -> None:
    """Redirect `edge` from `old_node` to `new_node`.

    Both nodes are updated. `edge` is attached to `new_node` on the same
    axis (both axes, for a trace edge). `old_node` receives a freshly
    constructed `Edge` on that axis (those axes) in place of `edge`: for a
    trace edge it is built as `Edge(old_node, axis1, None, old_node, axis2)`;
    otherwise as `Edge(old_node, axis)`.

    Args:
      edge: The edge to move.
      new_node: The node `edge` should point to afterwards.
      old_node: The node `edge` currently points to.

    Returns:
      None

    Raises:
      ValueError: if `edge` does not point to `old_node`.
    """
    if edge.is_trace():
        # Both ends sit on the same node, so they move together.
        if edge.node1 is not old_node:
            raise ValueError(f"edge {edge} is not pointing "
                             f"to old_node {old_node}")
        edge.node1 = new_node
        edge.node2 = new_node
        first_axis, second_axis = edge.axis1, edge.axis2
        for ax in (first_axis, second_axis):
            new_node.add_edge(edge, ax, True)
        replacement = Edge(old_node, first_axis, None, old_node, second_axis)
        for ax in (first_axis, second_axis):
            old_node.add_edge(replacement, ax, True)
        return

    dangling = edge.is_dangling()
    # A dangling edge only has a first end; a connected edge may point to
    # `old_node` from either end.
    if edge.node1 is old_node:
        edge.node1 = new_node
        axis = edge.axis1
    elif not dangling and edge.node2 is old_node:
        edge.node2 = new_node
        axis = edge.axis2
    else:
        raise ValueError(f"edge {edge} is not pointing "
                         f"to old_node {old_node}")
    new_node.add_edge(edge, axis, True)
    old_node.add_edge(Edge(old_node, axis), axis, True)
示例#7
0
  def contract_parallel(
      self, edge: network_components.Edge) -> network_components.Node:
    """Contract every edge that runs parallel to `edge`.

    Delegates to `contract_between` with the two nodes `edge` connects. That
    method is expected to contract all edges those two nodes share.

    Args:
      edge: The (non-dangling) edge whose endpoint nodes are contracted.

    Returns:
      The node returned by `contract_between`.

    Raises:
      ValueError: If `edge` is dangling.
    """
    if not edge.is_dangling():
      return self.contract_between(edge.node1, edge.node2)
    raise ValueError("Attempted to contract dangling edge: '{}'".format(edge))
示例#8
0
    def _remove_trace_edge(self, edge: network_components.Edge,
                           new_node: network_components.Node) -> None:
        """Collapse a trace edge.

        Every edge of the traced node except `edge` itself is moved onto
        `new_node`. Each axis index is shifted down by the number of traced
        axes that preceded it. The old node is then removed from
        `self.nodes_set`. Nothing is returned.

        Args:
          edge: The trace edge to contract.
          new_node: The new node created after contraction.

        Raises:
          ValueError: If `edge` is dangling or is not a trace edge.
        """
        if edge.is_dangling():
            raise ValueError(
                "Attempted to remove dangling edge '{}'.".format(edge))
        if edge.node1 is not edge.node2:
            raise ValueError("Edge '{}' is not a trace edge.".format(edge))
        old_node = edge.node1
        axes = sorted([edge.axis1, edge.axis2])
        node_edges = old_node.edges[:]
        node_edges.pop(axes[0])
        # The first pop shifted everything after axes[0] down by one.
        node_edges.pop(axes[1] - 1)
        seen_edges = set()
        for tmp_edge in node_edges:
            # Other trace edges on this node are listed twice; rewire once.
            if tmp_edge in seen_edges:
                continue
            seen_edges.add(tmp_edge)
            if tmp_edge.node1 is old_node:
                # Subtract one for every removed axis that came before.
                tmp_edge.axis1 -= sum(tmp_edge.axis1 > a for a in axes)
                tmp_edge.node1 = new_node
            if tmp_edge.node2 is old_node:
                tmp_edge.axis2 -= sum(tmp_edge.axis2 > a for a in axes)
                tmp_edge.node2 = new_node
        # Update edges for the new node.
        for i, e in enumerate(node_edges):
            new_node.add_edge(e, i)
        self.nodes_set.remove(old_node)