def contract_trace_edges(
        net: network.TensorNetwork,
        none_value: int = 1,
) -> Tuple[network.TensorNetwork, Dict[network_components.BaseNode, int],
           Dict[network_components.BaseNode, int]]:
    """Contract every trace edge in `net` and record the size of each node.

    A node's size is the product of the dimensions of its remaining
    (non-trace) axes; trace edges are contracted away and do not count.

    Args:
      net: TensorNetwork whose trace edges should be contracted.
      none_value: Factor used in place of a None dimension when computing a
        size. The default of 1 effectively ignores None dimensions.

    Returns:
      A tuple ``(net, node_sizes, node_sizes_none)`` where
        net: the same TensorNetwork, now with its trace edges contracted;
        node_sizes: size of every resulting node that has no None dimension;
        node_sizes_none: size of every resulting node that has at least one
          None dimension.
    """
    sizes_known = {}
    sizes_with_none = {}
    # Iterate over a copy of the node set: contracting may change
    # `net.nodes_set` while we loop.
    for original in set(net.nodes_set):
        current = original
        size = 1
        saw_none = False
        already_traced = set()
        # Capture the edges (and the shape) up front so iteration keeps
        # following the original edges even after a contraction rewires them.
        edges_snapshot = list(original.edges)
        shape_snapshot = list(original.get_tensor().shape)
        for axis_edge, axis_dim in zip(edges_snapshot, shape_snapshot):
            if axis_edge.is_disabled:
                # The edge has already been contracted; nothing to count.
                continue
            if axis_edge.node1 is axis_edge.node2:
                # Trace edge: contract it once, keeping the node's name.
                if axis_edge not in already_traced:
                    current = net.contract(axis_edge, name=original.name)
                    already_traced.add(axis_edge)
                continue
            if axis_dim is None:
                saw_none = True
                size *= none_value
            else:
                size *= axis_dim
        target = sizes_with_none if saw_none else sizes_known
        target[current] = size
    return net, sizes_known, sizes_with_none
def contract_trace_edges(
        net: network.TensorNetwork,
        none_value: int = 1,
) -> Tuple[network.TensorNetwork, Dict[network_components.Node, int],
           Dict[network_components.Node, int]]:
    """Contracts trace edges and calculates tensor sizes for every node.

    Tensor size is defined as the product of the sizes of each of the node's
    edges (axes). Trace edges (edges whose two ends attach to the same node)
    are contracted and therefore do not contribute to the size.

    Args:
      net: TensorNetwork to contract all the trace edges of.
      none_value: The value that None dimensions contribute to the tensor size.
        Unit (default) means that None dimensions are neglected.

    Returns:
      A tuple containing:
        net: Given TensorNetwork with all its trace edges contracted.
        node_sizes: Map from nodes in the network to their total size.
        node_sizes_none: Map from nodes that have at least one None dimension
          to their size.
    """
    # NOTE(review): another `contract_trace_edges` is defined earlier in this
    # file; being later, this definition is the one that takes effect at
    # import time -- confirm which one is intended to be kept.

    # Keep node sizes in memory for cost calculation.
    node_sizes, node_sizes_none = {}, {}
    # Iterate over a copy: contracting may change `net.nodes_set`.
    initial_node_set = set(net.nodes_set)
    for node in initial_node_set:
        trace_edges, flag_none, total_dim = set(), False, 1
        new_node = node
        # Snapshot edges and shape before contracting anything, so the loop
        # keeps walking the node's original edges even after a contraction
        # rewires or replaces them.
        node_edges = list(node.edges)
        node_dims = list(node.get_tensor().shape)
        for edge, dim in zip(node_edges, node_dims):
            if edge.is_disabled:
                # Edge has already been contracted (presumably via its other
                # axis on this node); check this before touching
                # node1/node2 of a contracted edge.
                continue
            if edge.node1 is edge.node2:
                if edge not in trace_edges:
                    # Contract the trace edge, keeping the original node name.
                    new_node = net.contract(edge, name=node.name)
                    trace_edges.add(edge)
                continue
            if dim is None:
                total_dim *= none_value
                flag_none = True
            else:
                total_dim *= dim
        if flag_none:
            node_sizes_none[new_node] = total_dim
        else:
            node_sizes[new_node] = total_dim
    return net, node_sizes, node_sizes_none