Exemplo n.º 1
0
def transpose(node: AbstractNode,
              permutation: Sequence[Union[Text, int]],
              name: Optional[Text] = None,
              axis_names: Optional[List[Text]] = None) -> AbstractNode:
    """Transpose `node`.

    A new node is built on `node.tensor` and then `reorder_axes(perm)` is
    applied to it, where `perm` is `permutation` resolved to integer axis
    positions of `node`.

    Args:
      node: A `AbstractNode`.
      permutation: A list of int or str. The permutation of the axes; string
        entries are resolved with `node.get_axis_number`.
      name: Optional name to give the new node.
      axis_names: Optional list of names for the axes of the new node. They
        are attached before `reorder_axes(perm)` runs, so they label the axes
        in the axis order of `node` (NOTE(review): `reorder_axes` presumably
        permutes the names along with the axes — confirm). If empty or
        `None`, `node.axis_names` is used.

    Returns:
      A new node. The transpose of `node`.

    Raises:
      AttributeError: If `node` has no `backend` attribute, or if
        `node` has no tensor.
      ValueError: If either `permutation` is not the same as expected or
        if you try to permute with a trace edge.
    """

    if not hasattr(node, 'backend'):
        raise AttributeError('Node {} of type {} has no `backend`'.format(
            node, type(node)))

    # Resolve string axis names against the *original* node.
    perm = [node.get_axis_number(p) for p in permutation]
    if not axis_names:
        axis_names = node.axis_names

    # Use the resolved names (caller-supplied or inherited); passing
    # `node.axis_names` here would silently discard the caller's `axis_names`.
    new_node = Node(node.tensor,
                    name=name,
                    axis_names=axis_names,
                    backend=node.backend)
    return new_node.reorder_axes(perm)
Exemplo n.º 2
0
def redirect_edge(edge: Edge, new_node: AbstractNode,
                  old_node: AbstractNode) -> None:
    """Move the endpoint(s) of `edge` that sit on `old_node` over to `new_node`.

    `edge` is re-attached to `new_node` at the axis it occupied on
    `old_node` (for a trace edge, at both of its axes), and `old_node` is
    given a freshly constructed `Edge` at that axis (or axes) in its place.
    Both nodes are modified in place via `add_edge(..., True)`.

    Args:
      edge: The edge to redirect.
      new_node: The node that takes over the endpoint(s).
      old_node: The node currently holding the endpoint(s).

    Returns:
      None

    Raises:
      ValueError: If `edge` does not point to `old_node`.
    """

    def _not_pointing() -> ValueError:
        # Built lazily so the edge/node are only stringified on failure.
        return ValueError(f"edge {edge} is not pointing "
                          f"to old_node {old_node}")

    if edge.is_trace():
        # Both endpoints live on the same node, so both move together.
        if edge.node1 is not old_node:
            raise _not_pointing()
        edge.node1 = new_node
        edge.node2 = new_node
        first_axis, second_axis = edge.axis1, edge.axis2
        for ax in (first_axis, second_axis):
            new_node.add_edge(edge, ax, True)
        replacement = Edge(old_node, first_axis, None, old_node, second_axis)
        for ax in (first_axis, second_axis):
            old_node.add_edge(replacement, ax, True)
        return

    if edge.is_dangling():
        # A dangling edge only has a first endpoint to check.
        if edge.node1 is not old_node:
            raise _not_pointing()
        edge.node1 = new_node
        axis = edge.axis1
    elif edge.node1 is old_node:
        edge.node1 = new_node
        axis = edge.axis1
    elif edge.node2 is old_node:
        edge.node2 = new_node
        axis = edge.axis2
    else:
        raise _not_pointing()

    new_node.add_edge(edge, axis, True)
    old_node.add_edge(Edge(old_node, axis), axis, True)
Exemplo n.º 3
0
def split_node_full_svd(
    node: AbstractNode,
    left_edges: List[Edge],
    right_edges: List[Edge],
    max_singular_values: Optional[int] = None,
    max_truncation_err: Optional[float] = None,
    relative: Optional[bool] = False,
    left_name: Optional[Text] = None,
    middle_name: Optional[Text] = None,
    right_name: Optional[Text] = None,
    left_edge_name: Optional[Text] = None,
    right_edge_name: Optional[Text] = None,
) -> Tuple[AbstractNode, AbstractNode, AbstractNode, Tensor]:
    """Split `node` into three nodes with a full singular value decomposition.

    The tensor of `node` is flattened so that `left_edges` form the rows and
    `right_edges` the columns of a matrix :math:`M`, with SVD
    :math:`M = U S V^*`. The left node holds :math:`U`, the middle node holds
    the diagonal matrix of singular values (ordered largest to smallest,
    built with `backend.diagflat`), and the right node holds :math:`V^*`.

    Truncation is applied when `max_singular_values` or `max_truncation_err`
    is not `None`:

    * The truncation error is the 2-norm of the vector of truncated singular
      values. With only `max_truncation_err` set, as many singular values as
      possible are truncated while keeping
      `norm(truncated_singular_values) <= max_truncation_err`. With
      `relative=True`, `max_truncation_err` is relative to the largest
      singular value.
    * With only `max_singular_values` set,
      `min(max_singular_values, number_of_singular_values)` values are kept,
      i.e. `max(0, number_of_singular_values - max_singular_values)` are
      truncated.
    * With both set, `max_singular_values` takes priority: the truncation
      error may exceed `max_truncation_err` if required to satisfy
      `max_singular_values`.

    Axis names of `node` are propagated to the new nodes only when
    `node.axis_names`, `left_edge_name` and `right_edge_name` are all truthy;
    otherwise the new nodes are built with `axis_names=None`.

    Args:
      node: The node you want to split.
      left_edges: The edges you want connected to the new left node.
      right_edges: The edges you want connected to the new right node.
      max_singular_values: The maximum number of singular values to keep.
      max_truncation_err: The maximum allowed truncation error.
      relative: Multiply `max_truncation_err` with the largest singular value.
      left_name: The name of the new left node. If `None`, a name will be
        generated automatically.
      middle_name: The name of the new center node. If `None`, a name will be
        generated automatically.
      right_name: The name of the new right node. If `None`, a name will be
        generated automatically.
      left_edge_name: The name of the new `Edge` connecting the new left node
        (:math:`U`) and the new central node (:math:`S`). If `None`, a name
        will be generated automatically.
      right_edge_name: The name of the new `Edge` connecting the new central
        node (:math:`S`) and the new right node (:math:`V^*`). If `None`, a
        name will be generated automatically.

    Returns:
      A tuple containing:
        left_node:
          A new node connected to all of the `left_edges`.
          Its underlying tensor is :math:`U`.
        singular_values_node:
          A new node with 2 edges, connecting `left_node` and `right_node`.
          Its underlying tensor is :math:`S`.
        right_node:
          A new node connected to all of the `right_edges`.
          Its underlying tensor is :math:`V^*`.
        truncated_singular_values:
          The vector of truncated singular values.

    Raises:
      AttributeError: If `node` has no backend attribute
    """

    if not hasattr(node, 'backend'):
        raise AttributeError('Node {} of type {} has no `backend`'.format(
            node, type(node)))

    def _axis_on_node(edge):
        # Position of `edge` on `node`, whichever end of the edge `node` is.
        return edge.axis1 if edge.node1 is node else edge.axis2

    left_axis_names = center_axis_names = right_axis_names = None
    if node.axis_names and left_edge_name and right_edge_name:
        left_axis_names = [
            node.axis_names[_axis_on_node(e)] for e in left_edges
        ] + [left_edge_name]
        right_axis_names = [right_edge_name] + [
            node.axis_names[_axis_on_node(e)] for e in right_edges
        ]
        center_axis_names = [left_edge_name, right_edge_name]

    backend = node.backend
    matrix = node.tensor_from_edge_order(left_edges + right_edges)
    u, s, vh, trun_vals = backend.svd(matrix,
                                      len(left_edges),
                                      max_singular_values,
                                      max_truncation_err,
                                      relative=relative)

    left_node = Node(u, name=left_name, axis_names=left_axis_names,
                     backend=backend)
    singular_values_node = Node(backend.diagflat(s), name=middle_name,
                                axis_names=center_axis_names, backend=backend)
    right_node = Node(vh, name=right_name, axis_names=right_axis_names,
                      backend=backend)

    # Snapshot the old positions before `update_axis` rewires anything.
    old_left_axes = [_axis_on_node(e) for e in left_edges]
    for new_axis, (edge, old_axis) in enumerate(zip(left_edges,
                                                    old_left_axes)):
        left_node.add_edge(edge, new_axis)
        edge.update_axis(old_axis, node, new_axis, left_node)

    # Axis 0 of the right node is the bond to the singular values, so the
    # original right edges are placed starting at axis 1.
    old_right_axes = [_axis_on_node(e) for e in right_edges]
    for new_axis, (edge, old_axis) in enumerate(zip(right_edges,
                                                    old_right_axes),
                                                start=1):
        right_node.add_edge(edge, new_axis)
        edge.update_axis(old_axis, node, new_axis, right_node)

    connect(left_node.edges[-1],
            singular_values_node.edges[0],
            name=left_edge_name)
    connect(singular_values_node.edges[1],
            right_node.edges[0],
            name=right_edge_name)
    node.fresh_edges(node.axis_names)
    return left_node, singular_values_node, right_node, trun_vals
Exemplo n.º 4
0
def split_node_rq(
    node: AbstractNode,
    left_edges: List[Edge],
    right_edges: List[Edge],
    left_name: Optional[Text] = None,
    right_name: Optional[Text] = None,
    edge_name: Optional[Text] = None,
) -> Tuple[AbstractNode, AbstractNode]:
    """Split `node` in two using an RQ (reversed QR) decomposition.

    The tensor of `node` is flattened so that `left_edges` form the rows and
    `right_edges` the columns of a matrix :math:`M`. With
    :math:`QR = M^*` the QR decomposition of :math:`M^*`, the left node's
    tensor is :math:`R^*` (lower triangular) and the right node's tensor is
    :math:`Q^*` (orthonormal). The factorization itself is delegated to
    `backend.rq`.

    Axis names of `node` are propagated to the new nodes only when both
    `node.axis_names` and `edge_name` are truthy; otherwise the new nodes are
    built with `axis_names=None`.

    Args:
      node: The node you want to split.
      left_edges: The edges you want connected to the new left node.
      right_edges: The edges you want connected to the new right node.
      left_name: The name of the new left node. If `None`, a name will be
        generated automatically.
      right_name: The name of the new right node. If `None`, a name will be
        generated automatically.
      edge_name: The name of the new `Edge` connecting the new left and
        right node. If `None`, a name will be generated automatically.

    Returns:
      A tuple containing:
        left_node:
          A new node connected to all of the `left_edges`.
          Its underlying tensor is :math:`R^*`.
        right_node:
          A new node connected to all of the `right_edges`.
          Its underlying tensor is :math:`Q^*`.

    Raises:
      AttributeError: If `node` has no backend attribute
    """

    if not hasattr(node, 'backend'):
        raise AttributeError('Node {} of type {} has no `backend`'.format(
            node, type(node)))

    def _axis_on_node(edge):
        # Position of `edge` on `node`, whichever end of the edge `node` is.
        return edge.axis1 if edge.node1 is node else edge.axis2

    if not (node.axis_names and edge_name):
        left_axis_names = right_axis_names = None
    else:
        left_axis_names = [
            node.axis_names[_axis_on_node(e)] for e in left_edges
        ]
        left_axis_names.append(edge_name)
        right_axis_names = [edge_name]
        right_axis_names.extend(
            node.axis_names[_axis_on_node(e)] for e in right_edges)

    backend = node.backend
    matrix = node.tensor_from_edge_order(left_edges + right_edges)
    r, q = backend.rq(matrix, len(left_edges))

    left_node = Node(r, name=left_name, axis_names=left_axis_names,
                     backend=backend)
    # Snapshot the old positions before `update_axis` rewires anything.
    old_left_axes = [_axis_on_node(e) for e in left_edges]
    for new_axis, (edge, old_axis) in enumerate(zip(left_edges,
                                                    old_left_axes)):
        left_node.add_edge(edge, new_axis)
        edge.update_axis(old_axis, node, new_axis, left_node)

    right_node = Node(q, name=right_name, axis_names=right_axis_names,
                      backend=backend)
    # Axis 0 of the right node is the new bond, so the original right edges
    # are placed starting at axis 1.
    old_right_axes = [_axis_on_node(e) for e in right_edges]
    for new_axis, (edge, old_axis) in enumerate(zip(right_edges,
                                                    old_right_axes),
                                                start=1):
        right_node.add_edge(edge, new_axis)
        edge.update_axis(old_axis, node, new_axis, right_node)

    connect(left_node.edges[-1], right_node.edges[0], name=edge_name)
    node.fresh_edges(node.axis_names)
    return left_node, right_node
Exemplo n.º 5
0
def split_node(
    node: AbstractNode,
    left_edges: List[Edge],
    right_edges: List[Edge],
    max_singular_values: Optional[int] = None,
    max_truncation_err: Optional[float] = None,
    relative: Optional[bool] = False,
    left_name: Optional[Text] = None,
    right_name: Optional[Text] = None,
    edge_name: Optional[Text] = None,
) -> Tuple[AbstractNode, AbstractNode, Tensor]:
    """Split `node` in two using a singular value decomposition.

    The tensor of `node` is flattened so that `left_edges` form the rows and
    `right_edges` the columns of a matrix :math:`M`, with SVD
    :math:`M = U S V^*`. The square root of the singular values is absorbed
    into both halves: the left node's tensor is :math:`U \\sqrt{S}` and the
    right node's tensor is :math:`\\sqrt{S} V^*`, where :math:`V^*` is the
    adjoint of :math:`V`.

    Truncation is applied when `max_singular_values` or `max_truncation_err`
    is not `None`:

    * The truncation error is the 2-norm of the vector of truncated singular
      values. With only `max_truncation_err` set, as many singular values as
      possible are truncated while keeping
      `norm(truncated_singular_values) <= max_truncation_err`. With
      `relative=True`, `max_truncation_err` is relative to the largest
      singular value.
    * With only `max_singular_values` set,
      `min(max_singular_values, number_of_singular_values)` values are kept,
      i.e. `max(0, number_of_singular_values - max_singular_values)` are
      truncated.
    * With both set, `max_singular_values` takes priority: the truncation
      error may exceed `max_truncation_err` if required to satisfy
      `max_singular_values`.

    Axis names of `node` are propagated to the new nodes only when both
    `node.axis_names` and `edge_name` are truthy; otherwise the new nodes are
    built with `axis_names=None`.

    Args:
      node: The node you want to split.
      left_edges: The edges you want connected to the new left node.
      right_edges: The edges you want connected to the new right node.
      max_singular_values: The maximum number of singular values to keep.
      max_truncation_err: The maximum allowed truncation error.
      relative: Multiply `max_truncation_err` with the largest singular value.
      left_name: The name of the new left node. If `None`, a name will be
        generated automatically.
      right_name: The name of the new right node. If `None`, a name will be
        generated automatically.
      edge_name: The name of the new `Edge` connecting the new left and
        right node. If `None`, a name will be generated automatically.
        The new axis will get the same name as the edge.

    Returns:
      A tuple containing:
        left_node:
          A new node connected to all of the `left_edges`.
          Its underlying tensor is :math:`U \\sqrt{S}`.
        right_node:
          A new node connected to all of the `right_edges`.
          Its underlying tensor is :math:`\\sqrt{S} V^*`.
        truncated_singular_values:
          The vector of truncated singular values.

    Raises:
      AttributeError: If `node` has no backend attribute
    """

    if not hasattr(node, 'backend'):
        raise AttributeError('Node {} of type {} has no `backend`'.format(
            node, type(node)))

    def _axis_on_node(edge):
        # Position of `edge` on `node`, whichever end of the edge `node` is.
        return edge.axis1 if edge.node1 is node else edge.axis2

    if not (node.axis_names and edge_name):
        left_axis_names = right_axis_names = None
    else:
        left_axis_names = [
            node.axis_names[_axis_on_node(e)] for e in left_edges
        ]
        left_axis_names.append(edge_name)
        right_axis_names = [edge_name]
        right_axis_names.extend(
            node.axis_names[_axis_on_node(e)] for e in right_edges)

    backend = node.backend
    matrix = node.tensor_from_edge_order(left_edges + right_edges)
    u, s, vh, trun_vals = backend.svd(matrix,
                                      len(left_edges),
                                      max_singular_values,
                                      max_truncation_err,
                                      relative=relative)

    # Split sqrt(S) evenly between the two factors.
    root_s = backend.sqrt(s)
    left_tensor = backend.broadcast_right_multiplication(u, root_s)
    right_tensor = backend.broadcast_left_multiplication(root_s, vh)

    left_node = Node(left_tensor, name=left_name, axis_names=left_axis_names,
                     backend=backend)
    # Snapshot the old positions before `update_axis` rewires anything.
    old_left_axes = [_axis_on_node(e) for e in left_edges]
    for new_axis, (edge, old_axis) in enumerate(zip(left_edges,
                                                    old_left_axes)):
        left_node.add_edge(edge, new_axis)
        edge.update_axis(old_axis, node, new_axis, left_node)

    right_node = Node(right_tensor, name=right_name,
                      axis_names=right_axis_names, backend=backend)
    # Axis 0 of the right node is the new bond, so the original right edges
    # are placed starting at axis 1.
    old_right_axes = [_axis_on_node(e) for e in right_edges]
    for new_axis, (edge, old_axis) in enumerate(zip(right_edges,
                                                    old_right_axes),
                                                start=1):
        right_node.add_edge(edge, new_axis)
        edge.update_axis(old_axis, node, new_axis, right_node)

    connect(left_node.edges[-1], right_node.edges[0], name=edge_name)
    node.fresh_edges(node.axis_names)
    return left_node, right_node, trun_vals