示例#1
0
def contract_trace_edges(
    net: network.TensorNetwork,
    none_value: int = 1
) -> Tuple[network.TensorNetwork, Dict[network_components.BaseNode, int], Dict[
        network_components.BaseNode, int]]:
    """Contracts trace edges and calculates tensor sizes for every node.

    Tensor size is defined as the product of the sizes of each of the node's
    non-trace edges (axes).

    Args:
      net: TensorNetwork to contract all the trace edges of.
      none_value: The value that None dimensions contribute to the tensor size.
        Unit (default) means that None dimensions are neglected.

    Returns:
      A tuple containing:
        net:
          Given TensorNetwork with all its trace edges contracted.
        node_sizes:
          Map from nodes that have no None dimension to their total size.
        node_sizes_none:
          Map from nodes that have at least one None dimension to
          their size.
      Each processed node appears in exactly one of the two maps. For a node
      that had trace edges, the key is the node returned by the last trace
      contraction rather than the original node.
    """
    # Keep node sizes in memory for cost calculation
    node_sizes, node_sizes_none = dict(), dict()
    # Copy the node set: trace contractions replace nodes in net.nodes_set.
    initial_node_set = set(net.nodes_set)
    for node in initial_node_set:
        trace_edges, flag_none, total_dim = set(), False, 1
        new_node = node
        # Snapshot the original edges and dims so that iteration is not
        # affected by the trace contractions performed below.
        node_edges = list(node.edges)
        node_dims = list(node.get_tensor().shape)
        for edge, dim in zip(node_edges, node_dims):
            if edge.is_disabled:
                # Edge has already been contracted; skip it.
                continue
            if edge.node1 is edge.node2:
                if edge not in trace_edges:
                    # Contract trace edge, keeping the original node's name.
                    new_node = net.contract(edge, name=node.name)
                    trace_edges.add(edge)
            elif dim is None:
                total_dim *= none_value
                flag_none = True
            else:
                total_dim *= dim
        # Record the size exactly once, after every edge has been visited and
        # keyed by the final node, so no partial or stale entries remain.
        if flag_none:
            node_sizes_none[new_node] = total_dim
        else:
            node_sizes[new_node] = total_dim
    return net, node_sizes, node_sizes_none
def contract_trace_edges(
    net: network.TensorNetwork,
    none_value: int = 1
) -> Tuple[network.TensorNetwork, Dict[network_components.Node, int], Dict[
        network_components.Node, int]]:
    """Contracts trace edges and calculates tensor sizes for every node.

    Tensor size is defined as the product of the sizes of each of the node's
    non-trace edges (axes).

    Args:
      net: TensorNetwork to contract all the trace edges of.
      none_value: The value that None dimensions contribute to the tensor size.
        Unit (default) means that None dimensions are neglected.

    Returns:
      A tuple containing:
        net:
          Given TensorNetwork with all its trace edges contracted.
        node_sizes:
          Map from nodes that have no None dimension to their total size.
        node_sizes_none:
          Map from nodes that have at least one None dimension to
          their size.
      Each processed node appears in exactly one of the two maps. For a node
      that had trace edges, the key is the node returned by the last trace
      contraction rather than the original node.
    """
    # Keep node sizes in memory for cost calculation
    node_sizes, node_sizes_none = dict(), dict()
    # Copy the node set: trace contractions replace nodes in net.nodes_set.
    initial_node_set = set(net.nodes_set)
    for node in initial_node_set:
        trace_edges, flag_none, total_dim = set(), False, 1
        new_node = node
        # Snapshot the original edges and dims before any contraction so the
        # loop never iterates a list that net.contract may rewire.
        node_edges = list(node.edges)
        node_dims = list(node.get_tensor().shape)
        for edge, dim in zip(node_edges, node_dims):
            if edge.node1 is edge.node2:
                # A trace edge occupies two axes; contract it only once.
                if edge not in trace_edges:
                    new_node = net.contract(edge)
                    trace_edges.add(edge)
            elif dim is None:
                total_dim *= none_value
                flag_none = True
            else:
                total_dim *= dim
        # Record the size exactly once, after every edge has been visited and
        # keyed by the final node, so no partial or stale entries remain.
        if flag_none:
            node_sizes_none[new_node] = total_dim
        else:
            node_sizes[new_node] = total_dim
    return net, node_sizes, node_sizes_none