def test_copy_node_get_partners_with_trace(backend):
  """A copy node's trace (self-loop) edge must not appear in get_partners()."""
  hub = CopyNode(4, 2, backend=backend)
  other = Node(np.random.rand(2, 2), backend=backend, name="node2")
  # Connect axes 0 and 1 of the copy node to each other: a trace edge.
  tn.connect(hub[0], hub[1])
  # Wire the remaining copy-node axes 2 and 3 to axes 0 and 1 of `other`.
  for hub_axis, other_axis in ((2, 0), (3, 1)):
    tn.connect(hub[hub_axis], other[other_axis])
  # Only `other` is expected as a partner; the trace edge contributes nothing.
  expected_partners = {other: {0, 1}}
  assert hub.get_partners() == expected_partners
Example #2
0
    def contract_copy_node(
            self,
            copy_node: network_components.CopyNode,
            name: Optional[Text] = None) -> network_components.BaseNode:
        """Contract all edges incident on the given copy node.

        The tensor contraction itself is delegated to
        ``network_components.contract_copy_node``. This method then updates
        the network's bookkeeping:
        - the resulting node is registered through ``add_node``;
        - every partner of the copy node is removed from ``nodes_set`` (when
          present) and disabled (unless already disabled);
        - the copy node itself is removed from ``nodes_set`` and disabled.

        Args:
          copy_node: Copy tensor node to be contracted.
          name: Name of the new node created.

        Returns:
          New node representing contracted tensor.

        Raises:
          ValueError: If copy_node has dangling edge(s).
        """
        # Snapshot the partners before contracting; presumably the
        # contraction rewires the copy node's edges (TODO confirm).
        neighbours = copy_node.get_partners()
        contracted = network_components.contract_copy_node(copy_node, name)
        result_node = self.add_node(contracted)

        # Pass 1: stop tracking every partner this network still holds.
        for neighbour in neighbours:
            self.nodes_set.discard(neighbour)
        # Pass 2: disable each partner that is not disabled yet.
        for neighbour in (n for n in neighbours if not n.is_disabled):
            neighbour.disable()

        # Finally retire the copy node itself.
        self.nodes_set.remove(copy_node)
        copy_node.disable()
        return result_node
Example #3
0
    def contract_copy_node(
            self,
            copy_node: network_components.CopyNode,
            name: Optional[Text] = None) -> network_components.Node:
        """Contract all edges incident on the given copy node.

        The copy node and all of its partners are replaced in ``nodes_set``
        by a single new node holding ``copy_node.compute_contracted_tensor()``.
        The handling of each partner edge depends on where it points:
        - edges between a partner and the copy node are dropped from
          ``self.edge_order``;
        - every other partner edge is re-attached to the new node, on axes
          numbered consecutively in partner / edge iteration order.

        Args:
          copy_node: Copy tensor node to be contracted.
          name: Name of the new node created.

        Returns:
          New node representing contracted tensor.

        Raises:
          ValueError: If copy_node has dangling edge(s).
        """
        contracted_tensor = copy_node.compute_contracted_tensor()
        result_node = self.add_node(contracted_tensor, name)

        # NOTE(review): only partner edges are visited below, so any edge
        # joining the copy node to itself (a trace) is never removed from
        # self.edge_order -- confirm that is intended.
        next_axis = 0
        for neighbour in copy_node.get_partners():
            for edge in neighbour.edges:
                if edge.node1 is copy_node or edge.node2 is copy_node:
                    # Consumed by the contraction: forget this edge.
                    self.edge_order.remove(edge)
                else:
                    # Move the neighbour's end of the edge onto the next free
                    # axis of the result node. An edge between two partners
                    # is visited once per end, so both ends get moved.
                    if edge.node1 is neighbour:
                        source_axis = edge.axis1
                    else:
                        source_axis = edge.axis2
                    edge.update_axis(old_node=neighbour,
                                     old_axis=source_axis,
                                     new_node=result_node,
                                     new_axis=next_axis)
                    result_node.add_edge(edge, next_axis)
                    next_axis += 1
            self.nodes_set.remove(neighbour)
        # Internal consistency check: one tensor axis per re-attached edge end.
        assert len(contracted_tensor.shape) == next_axis

        self.nodes_set.remove(copy_node)
        # add_node may already register the node; re-adding is a no-op for a
        # set (assumes nodes_set is a set -- TODO confirm).
        self.nodes_set.add(result_node)
        return result_node