Exemple #1
0
    def _remove_edges(self, edges: Set[network_components.Edge],
                      node1: network_components.BaseNode,
                      node2: network_components.BaseNode,
                      new_node: network_components.BaseNode) -> None:
        """Collapse a set of edges shared by two nodes in the network.

        The edge collapse itself is delegated to
        `network_components._remove_edges`. Afterwards, `node1` and `node2`
        are dropped from this network's node set (if still present) and
        disabled (if not already disabled).

        The nodes that currently share the edges in `edges` must be supplied
        as `node1` and `node2`. The ordering of `node1` and `node2` must match
        the axis ordering of `new_node` (as determined by the contraction
        procedure).

        Args:
          edges: The edges to contract.
          node1: The old node that supplies the first edges of `new_node`.
          node2: The old node that supplies the last edges of `new_node`.
          new_node: The new node that represents the contraction of the two
            old nodes.

        Raises:
          ValueError: Presumably raised by `network_components._remove_edges`
            if an edge isn't in the network (not checked in this method).
        """
        network_components._remove_edges(edges, node1, node2, new_node)

        old_nodes = (node1, node2)
        # First detach both old nodes from the network's bookkeeping...
        for old_node in old_nodes:
            if old_node in self.nodes_set:
                self.nodes_set.remove(old_node)
        # ...then disable each one that is still active. Previously the second
        # check re-tested `node1`, so `node2` was never disabled once `node1`
        # had been. The `is_disabled` guard is kept, presumably because
        # `disable()` is not safe to call twice.
        for old_node in old_nodes:
            if not old_node.is_disabled:
                old_node.disable()
Exemple #2
0
    def remove_node(
        self, node: network_components.BaseNode
    ) -> Tuple[Dict[Text, network_components.Edge], Dict[
            int, network_components.Edge]]:
        """Remove a node from the network.

        Every edge of `node` that is neither dangling nor a trace edge is
        disconnected. `node` is then removed from the network's node set and
        disabled.

        Args:
          node: The node to be removed.

        Returns:
          broken_edges_by_name: A dictionary mapping `node`'s axis names to
            the newly broken edges.
          broken_edges_by_axis: A dictionary mapping `node`'s axis numbers
            to the newly broken edges.

        Raises:
          ValueError: If the node isn't in the network.
        """
        if node not in self:
            raise ValueError("Node '{}' is not in the network.".format(node))
        broken_edges_by_name = {}
        broken_edges_by_axis = {}
        for i, name in enumerate(node.axis_names):
            # Re-read node[i] on each iteration: earlier disconnects may have
            # replaced edges on `node`.
            edge = node[i]
            # Dangling and trace edges are left untouched; only the remaining
            # edges are disconnected.
            if edge.is_dangling() or edge.is_trace():
                continue
            edge1, edge2 = self.disconnect(edge)
            # Of the two halves returned by `disconnect`, keep the one whose
            # `node1` is not `node` (i.e. the half left on the other node).
            new_broken_edge = edge1 if edge1.node1 is not node else edge2
            broken_edges_by_axis[i] = new_broken_edge
            broken_edges_by_name[name] = new_broken_edge

        self.nodes_set.remove(node)
        node.disable()
        return broken_edges_by_name, broken_edges_by_axis
Exemple #3
0
  def outer_product(self,
                    node1: network_components.BaseNode,
                    node2: network_components.BaseNode,
                    name: Optional[Text] = None) -> network_components.BaseNode:
    """Compute the outer product of two nodes and add it to the network.

    The two nodes' edges and axes are concatenated, so the resulting shape is
    the concatenation of the input shapes. For example, if `a` has shape
    (2, 3) and `b` has shape (4, 5, 6), then `net.outer_product(a, b)` has
    shape (2, 3, 4, 5, 6).

    Both input nodes are removed from the network (if present) and disabled
    (if not already disabled).

    Args:
      node1: The first node. Its axes end up on the left side of the new node.
      node2: The second node. Its axes end up on the right side of the new
        node.
      name: Optional name to give the newly created node.

    Returns:
      The new node, whose shape is `node1.shape + node2.shape`.
    """
    product = network_components.outer_product(
        node1, node2, name, axis_names=None)
    new_node = self.add_node(product)

    consumed = (node1, node2)
    # The inputs are consumed by the product: unregister them first...
    for old in consumed:
      if old in self.nodes_set:
        self.nodes_set.remove(old)
    # ...then disable whichever of them is still active.
    for old in consumed:
      if not old.is_disabled:
        old.disable()

    return new_node
Exemple #4
0
  def contract_between(
      self,
      node1: network_components.BaseNode,
      node2: network_components.BaseNode,
      name: Optional[Text] = None,
      allow_outer_product: bool = False,
      output_edge_order: Optional[Sequence[network_components.Edge]] = None,
  ) -> network_components.BaseNode:
    """Contract every edge shared by two nodes into a single new node.

    The new node is added to the network. Both input nodes are removed from
    the network (if present) and disabled (if not already disabled).

    Args:
      node1: The first node.
      node2: The second node.
      name: Name to give to the newly created node.
      allow_outer_product: Optional boolean. If the two nodes share no edges
        and this is `True`, their outer product is returned instead;
        otherwise a `ValueError` is raised.
      output_edge_order: Optional sequence of Edges. When not `None`, it must
        contain all edges belonging to, but not shared by, `node1` and
        `node2`. The axes of the new node are permuted (if necessary) to
        match this ordering.

    Returns:
      The newly created node.

    Raises:
      ValueError: If no edges are found between node1 and node2 and
        `allow_outer_product` is `False`.
    """
    contracted = network_components.contract_between(
        node1, node2, name, allow_outer_product, output_edge_order,
        axis_names=None)
    new_node = self.add_node(contracted)

    pair = (node1, node2)
    # Unregister both consumed nodes, then disable those still active.
    for old in pair:
      if old in self.nodes_set:
        self.nodes_set.remove(old)
    for old in pair:
      if not old.is_disabled:
        old.disable()

    return new_node
Exemple #5
0
    def split_node_rq(
        self,
        node: network_components.BaseNode,
        left_edges: List[network_components.Edge],
        right_edges: List[network_components.Edge],
        left_name: Optional[Text] = None,
        right_name: Optional[Text] = None,
        edge_name: Optional[Text] = None,
    ) -> Tuple[network_components.BaseNode, network_components.BaseNode]:
        """Split a `Node` using RQ (reversed QR) decomposition.

        Let M be the matrix created by flattening left_edges and right_edges
        into 2 axes. Let :math:`QR = M^*` be the QR Decomposition of
        :math:`M^*`. This will split the network into 2 nodes. The left
        node's tensor will be :math:`R^*` (a lower triangular matrix) and the
        right node's tensor will be :math:`Q^*` (an orthonormal matrix).

        Both new nodes are added to this network. The original `node` is
        removed from the network and disabled.

        Args:
          node: The node you want to split.
          left_edges: The edges you want connected to the new left node.
          right_edges: The edges you want connected to the new right node.
          left_name: The name of the new left node. If `None`, a name will be
            generated automatically.
          right_name: The name of the new right node. If `None`, a name will
            be generated automatically.
          edge_name: The name of the new `Edge` connecting the new left and
            right node. If `None`, a name will be generated automatically.

        Returns:
          A tuple containing:
            left_node:
              A new node created that connects to all of the `left_edges`.
              Its underlying tensor is :math:`R^*`
            right_node:
              A new node created that connects to all of the `right_edges`.
              Its underlying tensor is :math:`Q^*`
        """
        r, q = network_operations.split_node_rq(node, left_edges, right_edges,
                                                left_name, right_name,
                                                edge_name)
        # The triangular factor goes on the left, the orthonormal one on the
        # right, matching the description above.
        left_node = self.add_node(r)
        right_node = self.add_node(q)

        # `node` has been replaced by its two factors.
        self.nodes_set.remove(node)
        node.disable()
        return left_node, right_node
Exemple #6
0
    def split_node_full_svd(
        self,
        node: network_components.BaseNode,
        left_edges: List[network_components.Edge],
        right_edges: List[network_components.Edge],
        max_singular_values: Optional[int] = None,
        max_truncation_err: Optional[float] = None,
        left_name: Optional[Text] = None,
        middle_name: Optional[Text] = None,
        right_name: Optional[Text] = None,
        left_edge_name: Optional[Text] = None,
        right_edge_name: Optional[Text] = None,
    ) -> Tuple[network_components.BaseNode, network_components.BaseNode,
               network_components.BaseNode, Tensor]:
        """Split a node into three by a full singular value decomposition.

        Let M be the matrix created by flattening left_edges and right_edges
        into 2 axes, and let :math:`U S V^* = M` be its Singular Value
        Decomposition.

        The left-most new node holds the :math:`U` tensor of the SVD. The
        middle node holds the diagonal matrix of singular values, ordered
        largest to smallest. The right-most node holds the :math:`V^*`
        tensor. All three nodes are added to this network. The original
        `node` is removed from the network and disabled.

        The decomposition is truncated if `max_singular_values` or
        `max_truncation_err` is not `None`.

        The truncation error is the 2-norm of the vector of truncated singular
        values. If only `max_truncation_err` is set, as many singular values
        are truncated as possible while maintaining
        `norm(truncated_singular_values) <= max_truncation_err`.

        If only `max_singular_values` is set, the number of singular values
        kept is `min(max_singular_values, number_of_singular_values)`, so that
        `max(0, number_of_singular_values - max_singular_values)` are
        truncated.

        If both are set, `max_singular_values` takes priority: the truncation
        error may exceed `max_truncation_err` if that is required to satisfy
        `max_singular_values`.

        Args:
          node: The node you want to split.
          left_edges: The edges you want connected to the new left node.
          right_edges: The edges you want connected to the new right node.
          max_singular_values: The maximum number of singular values to keep.
          max_truncation_err: The maximum allowed truncation error.
          left_name: The name of the new left node. If `None`, a name will be
            generated automatically.
          middle_name: The name of the new center node. If `None`, a name will
            be generated automatically.
          right_name: The name of the new right node. If `None`, a name will
            be generated automatically.
          left_edge_name: The name of the new left `Edge` connecting the new
            left node (`U`) and the new central node (`S`). If `None`, a name
            will be generated automatically.
          right_edge_name: The name of the new right `Edge` connecting the new
            central node (`S`) and the new right node (`V^*`). If `None`, a
            name will be generated automatically.

        Returns:
          A tuple containing:
            left_node:
              A new node created that connects to all of the `left_edges`.
              Its underlying tensor is :math:`U`
            singular_values_node:
              A new node that has 2 edges connecting `left_node` and
              `right_node`. Its underlying tensor is :math:`S`
            right_node:
              A new node created that connects to all of the `right_edges`.
              Its underlying tensor is :math:`V^*`
            truncated_singular_values:
              The vector of truncated singular values.
        """
        decomposition = network_operations.split_node_full_svd(
            node, left_edges, right_edges, max_singular_values,
            max_truncation_err, left_name, middle_name, right_name,
            left_edge_name, right_edge_name)
        u_factor, s_factor, v_factor, truncated = decomposition

        # Register the three factors with this network, in U, S, V^* order.
        left_node, singular_values_node, right_node = [
            self.add_node(factor) for factor in (u_factor, s_factor, v_factor)
        ]

        # The original node has been replaced by its factors.
        self.nodes_set.remove(node)
        node.disable()
        return left_node, singular_values_node, right_node, truncated
Exemple #7
0
    def split_node(
        self,
        node: network_components.BaseNode,
        left_edges: List[network_components.Edge],
        right_edges: List[network_components.Edge],
        max_singular_values: Optional[int] = None,
        max_truncation_err: Optional[float] = None,
        left_name: Optional[Text] = None,
        right_name: Optional[Text] = None,
        edge_name: Optional[Text] = None,
    ) -> Tuple[network_components.BaseNode, network_components.BaseNode,
               Tensor]:
        """Split a `Node` in two using Singular Value Decomposition.

        Let M be the matrix created by flattening left_edges and right_edges
        into 2 axes, and let :math:`U S V^* = M` be its Singular Value
        Decomposition. The node is split into 2 nodes: the left node's tensor
        is :math:`U \\sqrt{S}` and the right node's tensor is
        :math:`\\sqrt{S} V^*`, where :math:`V^*` is the adjoint of
        :math:`V`. Both new nodes are added to this network. The original
        `node` is removed from the network and disabled.

        The decomposition is truncated if `max_singular_values` or
        `max_truncation_err` is not `None`.

        The truncation error is the 2-norm of the vector of truncated singular
        values. If only `max_truncation_err` is set, as many singular values
        are truncated as possible while maintaining
        `norm(truncated_singular_values) <= max_truncation_err`.

        If only `max_singular_values` is set, the number of singular values
        kept is `min(max_singular_values, number_of_singular_values)`, so that
        `max(0, number_of_singular_values - max_singular_values)` are
        truncated.

        If both are set, `max_singular_values` takes priority: the truncation
        error may exceed `max_truncation_err` if that is required to satisfy
        `max_singular_values`.

        Args:
          node: The node you want to split.
          left_edges: The edges you want connected to the new left node.
          right_edges: The edges you want connected to the new right node.
          max_singular_values: The maximum number of singular values to keep.
          max_truncation_err: The maximum allowed truncation error.
          left_name: The name of the new left node. If `None`, a name will be
            generated automatically.
          right_name: The name of the new right node. If `None`, a name will
            be generated automatically.
          edge_name: The name of the new `Edge` connecting the new left and
            right node. If `None`, a name will be generated automatically.

        Returns:
          A tuple containing:
            left_node:
              A new node created that connects to all of the `left_edges`.
              Its underlying tensor is :math:`U \\sqrt{S}`
            right_node:
              A new node created that connects to all of the `right_edges`.
              Its underlying tensor is :math:`\\sqrt{S} V^*`
            truncated_singular_values:
              The vector of truncated singular values.
        """
        halves = network_operations.split_node(
            node, left_edges, right_edges, max_singular_values,
            max_truncation_err, left_name, right_name, edge_name)
        left_half, right_half, truncated = halves

        # Register both halves with this network, left first.
        left_node, right_node = [
            self.add_node(half) for half in (left_half, right_half)
        ]

        # The original node has been replaced by its two halves.
        self.nodes_set.remove(node)
        node.disable()
        return left_node, right_node, truncated