Exemple #1
0
    def contract(self,
                 edge: network_components.Edge,
                 name: Optional[Text] = None) -> network_components.Node:
        """Contract one edge of the TensorNetwork into a single new node.

        Trace edges (both ends on the same node) are delegated to
        `_contract_trace`. Otherwise the two endpoint tensors are combined
        with `backend.tensordot` over the edge's two axes, the result is
        registered via `add_node`, and the old edge is removed with
        `_remove_edges`.

        Args:
          edge: The edge to contract next.
          name: Optional name for the node created by the contraction.

        Returns:
          The new node created after the contraction.

        Raises:
          ValueError: If the edge is not in use or its first node is not part
            of this network, or if the edge is dangling.
        """
        in_network = edge.is_being_used() and edge.node1 in self.nodes_set
        if not in_network:
            raise ValueError(
                "Attempting to contract edge '{}' that is not part of "
                "the network.".format(edge))
        if edge.is_dangling():
            raise ValueError("Attempting to contract dangling edge")
        first, second = edge.node1, edge.node2
        if first is second:
            return self._contract_trace(edge, name)
        contracted = self.backend.tensordot(first.tensor, second.tensor,
                                            [[edge.axis1], [edge.axis2]])
        result = self.add_node(contracted, name)
        self._remove_edges({edge}, first, second, result)
        return result
def test_edge_initialize_raises_error_faulty_arguments(double_node_edge):
  # Supplying node2 without axis2, or axis2 without node2, must be rejected.
  first = double_node_edge.node1
  second = double_node_edge.node2
  with pytest.raises(ValueError):
    Edge(name="edge", node1=first, node2=second, axis1=0)  # no axis2
  with pytest.raises(ValueError):
    Edge(name="edge", node1=first, axis1=0, axis2=0)  # no node2
def test_edge_magic_xor(double_node_edge):
  # `left ^ right` should join two dangling edges into one edge whose
  # endpoints are the operands' nodes, in operand order.
  left_node = double_node_edge.node1
  right_node = double_node_edge.node2
  left = Edge(name="edge1", node1=left_node, axis1=2)
  right = Edge(name="edge2", node1=right_node, axis1=2)
  joined = left ^ right
  assert joined.node1 == left_node
  assert joined.node2 == right_node
def fixture_double_node_edge(backend):
  net = tensornetwork.TensorNetwork(backend=backend)
  tensor = net.backend.convert_to_tensor(np.ones((1, 2, 2)))

  def make_node(node_name):
    # Each node gets its own fresh axis-name list.
    return Node(tensor=tensor, name=node_name, axis_names=["a", "b", "c"],
                network=net)

  node1 = make_node("test_node1")
  node2 = make_node("test_node2")
  net.connect(node1["b"], node2["b"])
  dangling = Edge(name="edge", node1=node1, axis1=0)
  connecting = Edge(name="edge", node1=node1, axis1=1, node2=node2, axis2=1)
  return DoubleNodeEdgeTensor(node1, node2, dangling, connecting, tensor)
Exemple #5
0
def copy(nodes: Iterable[AbstractNode],
         conjugate: bool = False) -> Tuple[dict, dict]:
    """Copy the given nodes and their edges.

    Returns a pair of dictionaries that map each original node and each
    original edge to its copy. When an edge joins a node that is being
    copied to one that is not, the copied edge is left dangling on the
    copied side.

    Args:
      nodes: An Iterable (usually a `list` or `set`) of nodes.
      conjugate: Whether each copied node should be conjugated (useful for
        calculating norms and reduced density matrices). Forwarded to
        `node.copy`.

    Returns:
      A tuple `(node_dict, edge_dict)`:
        node_dict: Maps every original node to its copy.
        edge_dict: Maps every original edge to its copy.
    """
    node_dict = {node: node.copy(conjugate) for node in nodes}
    edge_dict = {}
    for edge in get_all_edges(nodes):
        src = edge.node1
        src_axis = src.get_axis_number(edge.axis1)
        if edge.is_dangling() or edge.node2 not in node_dict:
            # Only node1's side of the edge exists among the copies.
            copied = Edge(node_dict[src], src_axis, edge.name)
            node_dict[src].add_edge(copied, src_axis)
        else:
            dst = edge.node2
            dst_axis = dst.get_axis_number(edge.axis2)
            if src not in node_dict:
                # Only node2's side of the edge exists among the copies.
                copied = Edge(node_dict[dst], dst_axis, edge.name)
                node_dict[dst].add_edge(copied, dst_axis)
            else:
                copied = Edge(node_dict[src], src_axis, edge.name,
                              node_dict[dst], dst_axis)
                # NOTE(review): trace edges are not attached to the copy
                # here; presumably `node.copy` already recreates them, so the
                # edge_dict entry for a trace edge is a detached Edge —
                # confirm.
                if not edge.is_trace():
                    node_dict[dst].add_edge(copied, dst_axis)
                    node_dict[src].add_edge(copied, src_axis)
        edge_dict[edge] = copied
    return node_dict, edge_dict
def fixture_double_node_edge(backend):
    tensor = np.ones((1, 2, 2))

    def make_node(node_name):
        # Each node gets its own fresh axis-name list.
        return Node(tensor=tensor,
                    name=node_name,
                    axis_names=["a", "b", "c"],
                    backend=backend)

    node1 = make_node("test_node1")
    node2 = make_node("test_node2")
    tn.connect(node1["b"], node2["b"])
    dangling = Edge(name="edge", node1=node1, axis1=0)
    connecting = Edge(name="edge", node1=node1, axis1=1, node2=node2, axis2=1)
    return DoubleNodeEdgeTensor(node1, node2, dangling, connecting, tensor)
Exemple #7
0
    def disconnect(
        self,
        edge: network_components.Edge,
        dangling_edge_name_1: Optional[Text] = None,
        dangling_edge_name_2: Optional[Text] = None
    ) -> List[network_components.Edge]:
        """Break an edge into two dangling edges.

        Each endpoint node receives a new dangling edge on the axis the old
        edge occupied, and the old edge is removed from `edge_order`.

        Args:
          edge: The edge to break.
          dangling_edge_name_1: Optional name for the dangling edge attached
            to `edge.node1`; passed through `_new_edge_name`.
          dangling_edge_name_2: Optional name for the dangling edge attached
            to `edge.node2`; passed through `_new_edge_name`.

        Returns:
          A list `[dangling_edge_1, dangling_edge_2]` of the new edges.

        Raises:
          ValueError: If `edge` is already dangling.
        """
        if edge.is_dangling():
            raise ValueError(
                "Attempted to break a dangling edge '{}'.".format(edge))
        left, right = edge.node1, edge.node2
        name_1 = self._new_edge_name(dangling_edge_name_1)
        name_2 = self._new_edge_name(dangling_edge_name_2)
        halves = [
            network_components.Edge(name_1, left, edge.axis1),
            network_components.Edge(name_2, right, edge.axis2),
        ]
        # Third positional argument is passed as True (presumably `override`)
        # so the existing edge slot can be replaced.
        left.add_edge(halves[0], edge.axis1, True)
        right.add_edge(halves[1], edge.axis2, True)
        self.edge_order.remove(edge)
        return halves
Exemple #8
0
    def _contract_trace(
            self,
            edge: network_components.Edge,
            name: Optional[Text] = None) -> network_components.Node:
        """Contract a trace edge (both ends on the same node) in the network.

        The tensor is transposed so that the two traced axes come last, with
        the remaining axes keeping their relative order. The result is then
        passed to `backend.trace`, registered with `add_node`, and the old
        edge is removed with `_remove_trace_edge`.

        Args:
          edge: The trace edge to contract.
          name: Name to give to the new node. If None, a name will
            automatically be generated.

        Returns:
          new_node: The new node created after the contraction.

        Raises:
          ValueError: When `edge` is a dangling edge, or when it connects two
            different nodes.
        """
        if edge.is_dangling():
            raise ValueError(
                "Attempted to contract dangling edge '{}'".format(edge))
        if edge.node1 is not edge.node2:
            raise ValueError(
                "Can not take trace of edge '{}'. This edge connects to "
                "two different nodes: '{}' and '{}'".format(
                    edge, edge.node1, edge.node2))
        axes = sorted([edge.axis1, edge.axis2])
        dims = len(edge.node1.tensor.shape)
        # Keep the untouched axes in order and move the traced pair to the
        # end. NOTE(review): this relies on `backend.trace` tracing over the
        # last two axes — confirm against the backend implementation.
        permutation = sorted(set(range(dims)) - set(axes)) + axes
        new_tensor = self.backend.trace(
            self.backend.transpose(edge.node1.tensor, perm=permutation))
        new_node = self.add_node(new_tensor, name)
        self._remove_trace_edge(edge, new_node)
        return new_node
def find_parallel(
        edge: network_components.Edge
) -> Tuple[Set[network_components.Edge], int]:
    """Find every edge shared by the two nodes that `edge` connects.

    Args:
      edge: A non-dangling edge between two different nodes. Only the
        dangling case is checked; passing a trace edge is not rejected.

    Returns:
      parallel_edges: The set of edges of `edge.node1` whose endpoints are
        exactly `{edge.node1, edge.node2}` (this includes `edge` itself).
      parallel_dim: Product of the sizes of those edges. Sizes that are
        `None` are skipped.

    Raises:
      ValueError: If `edge` is dangling.
    """
    if edge.is_dangling():
        raise ValueError(
            "Cannot find parallel edges for dangling edge {}".format(edge))
    endpoints = {edge.node1, edge.node2}
    shared = set()
    total_dim = 1
    for candidate in edge.node1.edges:
        if set(candidate.get_nodes()) != endpoints:
            continue
        shared.add(candidate)
        size = list(candidate.node1.get_tensor().shape)[candidate.axis1]
        # Unknown dimensions (None) do not contribute to the product.
        if size is not None:
            total_dim *= size
    return shared, total_dim
Exemple #10
0
    def _remove_edge(self, edge: network_components.Edge,
                     new_node: network_components.Node) -> None:
        """Collapse an edge in the network.

        Collapses `edge` and re-points every other edge of its two endpoint
        nodes at `new_node`. On `new_node`, node1's remaining axes come first
        (in order), followed by node2's remaining axes. Both old nodes are
        then removed from `nodes_set`. Trace edges (both ends on the same
        node) are handed off to `_remove_trace_edge` instead.

        Args:
          edge: The edge to contract.
          new_node: The new node that represents the contraction of the two
            old nodes.

        Raises:
          ValueError: If `edge` is a dangling edge.
          KeyError: From `set.remove`, if an endpoint is not in `nodes_set`.
        """
        if edge.is_dangling():
            raise ValueError(
                "Attempted to remove dangling edge '{}'.".format(edge))
        if edge.node1 is edge.node2:
            # A trace edge touches a single node; the two-node collapse below
            # does not apply. Without this return it would re-process the
            # same edges and remove the same node from `nodes_set` twice.
            self._remove_trace_edge(edge, new_node)
            return
        # Collapse the nodes into a new node and remove the edge.
        node1 = edge.node1
        node2 = edge.node2
        node1_edges = edge.node1.edges[:]
        node2_edges = edge.node2.edges[:]
        node1_axis = edge.axis1
        node2_axis = edge.axis2
        # node2's axes are placed after node1's surviving axes on new_node.
        num_added_front_edges = len(node1_edges) - 1
        # Redefine all other edges.
        for i, tmp_edge in enumerate(node1_edges[:node1_axis]):
            tmp_edge.update_axis(old_axis=i,
                                 old_node=node1,
                                 new_axis=i,
                                 new_node=new_node)
        for i, tmp_edge in enumerate(node1_edges[node1_axis + 1:]):
            tmp_edge.update_axis(old_axis=i + node1_axis + 1,
                                 old_node=node1,
                                 new_axis=i + node1_axis,
                                 new_node=new_node)
        for i, tmp_edge in enumerate(node2_edges[:node2_axis]):
            tmp_edge.update_axis(old_axis=i,
                                 old_node=node2,
                                 new_axis=i + num_added_front_edges,
                                 new_node=new_node)
        for i, tmp_edge in enumerate(node2_edges[node2_axis + 1:]):
            tmp_edge.update_axis(old_axis=i + node2_axis + 1,
                                 old_node=node2,
                                 new_axis=i + node2_axis +
                                 num_added_front_edges,
                                 new_node=new_node)

        node1_edges.pop(node1_axis)
        node2_edges.pop(node2_axis)
        new_edges = node1_edges + node2_edges
        for i, e in enumerate(new_edges):
            new_node.add_edge(e, i)

        # Remove nodes
        self.nodes_set.remove(node1)
        self.nodes_set.remove(node2)
def test_edge_load(backend, tmp_path, double_node_edge):
  # Round-trip an edge through an HDF5 group and check every field survives.
  saved = double_node_edge.edge12

  with h5py.File(tmp_path / 'edge', 'w') as edge_file:
    group = edge_file.create_group('edge_data')
    for key, value in (('signature', saved.signature), ('name', saved.name),
                       ('node1', saved.node1.name), ('node2', saved.node2.name),
                       ('axis1', saved.axis1), ('axis2', saved.axis2)):
      group.create_dataset(key, data=value)

    ones = np.ones((1, 2, 2))
    # Distinct tensors (2 * ones vs ones) presumably let the tensor checks
    # tell the two nodes apart.
    first = Node(tensor=2 * ones, name="test_node1",
                 axis_names=["a", "b", "c"], backend=backend)
    second = Node(tensor=ones, name="test_node2",
                  axis_names=["a", "b", "c"], backend=backend)
    loaded = Edge._load_edge(group, {first.name: first, second.name: second})

    assert loaded.name == saved.name
    assert loaded.signature == saved.signature
    assert loaded.node1.name == saved.node1.name
    assert loaded.node2.name == saved.node2.name
    assert loaded.axis1 == saved.axis1
    assert loaded.axis2 == saved.axis2
    np.testing.assert_allclose(loaded.node1.tensor, first.tensor)
    np.testing.assert_allclose(loaded.node2.tensor, second.tensor)
Exemple #12
0
def redirect_edge(edge: Edge, new_node: AbstractNode,
                  old_node: AbstractNode) -> None:
    """Redirect `edge` from `old_node` to `new_node`.

    Both nodes are updated in place. `edge` is attached to `new_node` on the
    same axis (or axes, for a trace edge) it occupied on `old_node`, and
    `old_node` receives a freshly created Edge in its place.

    Args:
      edge: The edge to move.
      new_node: The node that should own `edge` afterwards.
      old_node: The node `edge` currently points to.

    Returns:
      None

    Raises:
      ValueError: If `edge` does not point to `old_node`.
    """
    if edge.is_trace():
        if edge.node1 is not old_node:
            raise ValueError(f"edge {edge} is not pointing "
                             f"to old_node {old_node}")
        edge.node1 = new_node
        edge.node2 = new_node
        first_axis, second_axis = edge.axis1, edge.axis2
        new_node.add_edge(edge, first_axis, True)
        new_node.add_edge(edge, second_axis, True)
        # old_node keeps a self-loop of its own on the same two axes.
        replacement = Edge(old_node, first_axis, None, old_node, second_axis)
        old_node.add_edge(replacement, first_axis, True)
        old_node.add_edge(replacement, second_axis, True)
        return

    if edge.is_dangling():
        if edge.node1 is not old_node:
            raise ValueError(f"edge {edge} is not pointing "
                             f"to old_node {old_node}")
        edge.node1 = new_node
        slot = edge.axis1
    elif edge.node1 is old_node:
        edge.node1 = new_node
        slot = edge.axis1
    elif edge.node2 is old_node:
        edge.node2 = new_node
        slot = edge.axis2
    else:
        raise ValueError(f"edge {edge} is not pointing "
                         f"to old_node {old_node}")
    new_node.add_edge(edge, slot, True)
    # old_node gets a new dangling edge on the vacated axis.
    old_node.add_edge(Edge(old_node, slot), slot, True)
def test_node_add_edge_raises_error_mismatch_rank(single_node_edge):
  node = single_node_edge.node
  # A negative axis is rejected ...
  with pytest.raises(ValueError):
    node.add_edge(single_node_edge.edge, axis=-1)
  # ... and so is an axis beyond the node's rank.
  extra = Edge(name="edge", node1=node, axis1=0)
  with pytest.raises(ValueError):
    node.add_edge(extra, axis=3)
def fixture_single_node_edge(backend):
    tensor = np.ones((1, 2, 2))
    node = Node(tensor=tensor, name="test_node", axis_names=["a", "b", "c"],
                backend=backend)
    dangling = Edge(name="edge", node1=node, axis1=0)
    return SingleNodeEdgeTensor(node, dangling, tensor)
def fixture_single_node_edge(backend):
  net = tensornetwork.TensorNetwork(backend=backend)
  tensor = net.backend.convert_to_tensor(np.ones((1, 2, 2)))
  node = Node(tensor=tensor, name="test_node", axis_names=["a", "b", "c"],
              network=net)
  dangling = Edge(name="edge", node1=node, axis1=0)
  return SingleNodeEdgeTensor(node, dangling, tensor)
def test_node_reorder_edges_raise_error_wrong_edges(single_node_edge):
  node = single_node_edge.node
  own_edges = [node[0], node[1], node[2]]
  stranger = Edge(name="edge", node1=node, axis1=0)
  # Passing only a subset of the node's edges must fail ...
  with pytest.raises(ValueError) as exc_info:
    node.reorder_edges(own_edges[:1])
  assert "Missing edges that belong to node found:" in str(exc_info.value)
  # ... as must passing an edge that is not one of the node's current edges.
  with pytest.raises(ValueError) as exc_info:
    node.reorder_edges(own_edges + [stranger])
  assert "Additional edges that do not belong to node found:" in str(
      exc_info.value)
Exemple #17
0
def load(path: str):
    """Load a tensor network from disk.

    Args:
      path: Path to the HDF5 file the network was saved to.

    Returns:
      The reconstructed `TensorNetwork`.
    """
    with h5py.File(path, 'r') as net_file:
        net = TensorNetwork(backend=net_file["backend"][()])
        node_keys = list(net_file["nodes"].keys())
        edge_keys = list(net_file["edges"].keys())

        # Nodes are loaded first; edges are then resolved by node name.
        for key in node_keys:
            node_group = net_file["nodes/" + key]
            get_component(node_group['type'][()])._load_node(net, node_group)

        name_to_node = {node.name: node for node in net.nodes_set}
        for key in edge_keys:
            Edge._load_edge(net_file["edges/" + key], name_to_node)
    return net
Exemple #18
0
  def contract_parallel(
      self, edge: network_components.Edge) -> network_components.Node:
    """Contract every edge parallel to `edge`.

    Implemented by calling `contract_between` on the two nodes that `edge`
    connects.

    Args:
      edge: The edge whose endpoint nodes should be contracted together.

    Returns:
      The new node created after contraction.

    Raises:
      ValueError: If `edge` is dangling.
    """
    if not edge.is_dangling():
      return self.contract_between(edge.node1, edge.node2)
    raise ValueError("Attempted to contract dangling edge: '{}'".format(edge))
Exemple #19
0
    def _remove_trace_edge(self, edge: network_components.Edge,
                           new_node: network_components.Node) -> None:
        """Collapse a trace edge.

        Re-points every remaining edge of the traced node at `new_node`,
        shifting its axis indices down past the two removed axes. It then
        attaches those edges to `new_node` in order and removes the old node
        from `nodes_set`.

        Args:
          edge: The trace edge to contract (both ends on the same node).
          new_node: The new node created after contraction.

        Returns:
          None.

        Raises:
          ValueError: If `edge` is dangling or is not a trace edge.
        """
        if edge.is_dangling():
            raise ValueError(
                "Attempted to remove dangling edge '{}'.".format(edge))
        if edge.node1 is not edge.node2:
            raise ValueError("Edge '{}' is not a trace edge.".format(edge))
        axes = sorted([edge.axis1, edge.axis2])
        node_edges = edge.node1.edges[:]
        # Drop the two traced axes. After the first pop every later index
        # shifts left by one, hence `axes[1] - 1`.
        node_edges.pop(axes[0])
        node_edges.pop(axes[1] - 1)
        seen_edges = set()
        for tmp_edge in node_edges:
            # An edge can occupy two slots of this node (e.g. another trace
            # edge); make sure each edge is re-pointed only once.
            if tmp_edge in seen_edges:
                continue
            else:
                seen_edges.add(tmp_edge)
            if tmp_edge.node1 is edge.node1:
                # Each axis index drops by the number of removed axes
                # below it.
                to_reduce = 0
                to_reduce += 1 if tmp_edge.axis1 > axes[0] else 0
                to_reduce += 1 if tmp_edge.axis1 > axes[1] else 0
                tmp_edge.axis1 -= to_reduce
                tmp_edge.node1 = new_node
            if tmp_edge.node2 is edge.node1:
                to_reduce = 0
                to_reduce += 1 if tmp_edge.axis2 > axes[0] else 0
                to_reduce += 1 if tmp_edge.axis2 > axes[1] else 0
                tmp_edge.axis2 -= to_reduce
                tmp_edge.node2 = new_node
        # Update edges for the new node.
        for i, e in enumerate(node_edges):
            new_node.add_edge(e, i)
        self.nodes_set.remove(edge.node1)
def copy(nodes: Iterable[BaseNode],
         conjugate: bool = False) -> Tuple[dict, dict]:
    """Copy the given nodes and their edges.

    This will return a tuple of dictionaries linking original nodes/edges
    to their copies. If nodes A and B are connected but only A is passed in
    to be copied, the edge between them becomes a dangling edge on the copy
    of A.

    Args:
      nodes: An `Iterable` (usually a `List` or `Set`) of `Nodes`.
      conjugate: Boolean. Whether to conjugate all of the nodes
        (useful for calculating norms and reduced density matrices).

    Returns:
      A tuple containing:
        node_dict: A dictionary mapping the nodes to their copies.
        edge_dict: A dictionary mapping the edges to their copies.
    """
    #TODO: add support for copying CopyTensor
    node_dict = {
        node: Node(node.backend.conj(node.tensor) if conjugate else node.tensor,
                   name=node.name,
                   axis_names=node.axis_names,
                   backend=node.backend.name)
        for node in nodes
    }
    edge_dict = {}
    for edge in get_all_edges(nodes):
        node1 = edge.node1
        axis1 = node1.get_axis_number(edge.axis1)
        # Dangling edge, or node2 is not being copied: keep only node1's
        # side, as a dangling edge on node1's copy. (Previously this case
        # raised KeyError on `node_dict[node2]`.)
        if edge.is_dangling() or edge.node2 not in node_dict:
            new_edge = Edge(node_dict[node1], axis1, edge.name)
            node_dict[node1].add_edge(new_edge, axis1)
            edge_dict[edge] = new_edge
            continue

        node2 = edge.node2
        axis2 = node2.get_axis_number(edge.axis2)
        # node2 is copied but node1 is not: dangling edge on node2's copy.
        if node1 not in node_dict:
            new_edge = Edge(node_dict[node2], axis2, edge.name)
            node_dict[node2].add_edge(new_edge, axis2)
            edge_dict[edge] = new_edge
            continue

        # Both endpoints are copied: rebuild the full edge between copies.
        new_edge = Edge(node_dict[node1], axis1, edge.name, node_dict[node2],
                        axis2)
        new_edge.set_signature(edge.signature)
        node_dict[node1].add_edge(new_edge, axis1)
        node_dict[node2].add_edge(new_edge, axis2)
        edge_dict[edge] = new_edge
    return node_dict, edge_dict
Exemple #21
0
def nodes_from_json(
        json_str: str
) -> Tuple[List[AbstractNode], Dict[str, Tuple[Edge, ...]]]:
    """Create a tensor network from a JSON string representation.

    Args:
      json_str: A string representing a JSON serialized tensor network.

    Returns:
      A tuple `(nodes, edge_binding)`:
        nodes: A list of the nodes making up the tensor network, in the order
          they appear under the JSON `'nodes'` key.
        edge_binding: A dictionary of {str -> (edge, ...)} bindings. All
          values are (variable-length) tuples of Edges. Keys whose id list is
          empty are omitted.
    """
    network_dict = json.loads(json_str)

    nodes = []
    node_ids = {}
    for n in network_dict['nodes']:
        node = Node.from_serial_dict(n['attributes'])
        nodes.append(node)
        node_ids[n['id']] = node

    edge_lookup = {}
    for e in network_dict['edges']:
        # Unknown node ids map to None; the edge is not added to a node that
        # is missing.
        e_nodes = [node_ids.get(n_id) for n_id in e['node_ids']]
        axes = e['attributes']['axes']
        edge = Edge(node1=e_nodes[0],
                    axis1=axes[0],
                    node2=e_nodes[1],
                    axis2=axes[1],
                    name=e['attributes']['name'])
        edge_lookup[e['id']] = edge
        for node, axis in zip(e_nodes, axes):
            if node is not None:
                node.add_edge(edge, axis, override=True)

    edge_binding = {}
    for key, edge_ids in network_dict.get('edge_binding', {}).items():
        # Build each tuple in one pass instead of re-concatenating per id.
        bound = tuple(edge_lookup[e_id] for e_id in edge_ids)
        if bound:
            edge_binding[key] = bound

    return nodes, edge_binding
Exemple #22
0
def load_nodes(path: str) -> List[BaseNode]:
    """Load a set of nodes from disk.

    Args:
      path: Path to the HDF5 file where the network is saved.

    Returns:
      A list of the loaded `Node` objects, with their saved names restored.
    """
    loaded_nodes = []
    loaded_edges = []
    with h5py.File(path, 'r') as net_file:
        node_keys = list(net_file["nodes"].keys())
        # Presumably nodes/edges are loaded under placeholder names
        # 'node<i>' / 'edge<i>', with the real names stored by index in the
        # 'names' datasets — confirm against the matching save routine.
        real_node_names = {
            f'node{i}': stored
            for i, stored in enumerate(net_file["node_names"]['names'][()])
        }
        real_edge_names = {
            f'edge{i}': stored
            for i, stored in enumerate(net_file["edge_names"]['names'][()])
        }
        edge_keys = list(net_file["edges"].keys())
        for key in node_keys:
            node_data = net_file["nodes/" + key]
            node_cls = get_component(node_data['type'][()])
            loaded_nodes.append(
                node_cls._load_node(net=None, node_data=node_data))
        by_name = {node.name: node for node in loaded_nodes}
        for key in edge_keys:
            loaded_edges.append(
                Edge._load_edge(net_file["edges/" + key], by_name))

    # Rename only after loading: edges were resolved against the names the
    # nodes had at load time.
    for edge in loaded_edges:
        edge.set_name(real_edge_names[edge.name])
    for node in loaded_nodes:
        node.set_name(real_node_names[node.name])
    return loaded_nodes
def test_edge_set_name_throws_type_error(single_node_edge, name):
  # `name` is presumably parametrized with invalid (non-string) values;
  # renaming an edge with one of them must raise TypeError.
  target = Edge(axis1=0, node1=single_node_edge.node)
  with pytest.raises(TypeError):
    target.set_name(name)
Exemple #24
0
def copy(nodes: Iterable[BaseNode],
         conjugate: bool = False) -> Tuple[dict, dict]:
  """Copy the given nodes and their edges.

  Returns a pair of dictionaries mapping every original node and edge to its
  copy. If nodes A and B are connected but only A is passed in to be copied,
  the edge between them becomes a dangling edge on A's copy.

  Args:
    nodes: An `Iterable` (usually a `List` or `Set`) of `Nodes`.
    conjugate: Whether to conjugate all of the nodes (useful for calculating
      norms and reduced density matrices).

  Returns:
    A tuple `(node_dict, edge_dict)`:
      node_dict: Maps each original node to its copy.
      edge_dict: Maps each original edge to its copy.
  """
  #TODO: add support for copying CopyTensor
  node_dict = {}
  for node in nodes:
    tensor = node.backend.conj(node.tensor) if conjugate else node.tensor
    node_dict[node] = Node(
        tensor,
        name=node.name,
        axis_names=node.axis_names,
        backend=node.backend)

  edge_dict = {}
  for edge in get_all_edges(nodes):
    src = edge.node1
    src_axis = src.get_axis_number(edge.axis1)
    if edge.is_dangling() or edge.node2 not in node_dict:
      # Only node1's side of the edge exists among the copies.
      copied = Edge(node_dict[src], src_axis, edge.name)
      node_dict[src].add_edge(copied, src_axis)
    else:
      dst = edge.node2
      dst_axis = dst.get_axis_number(edge.axis2)
      if src not in node_dict:
        # Only node2's side of the edge exists among the copies.
        copied = Edge(node_dict[dst], dst_axis, edge.name)
        node_dict[dst].add_edge(copied, dst_axis)
      else:
        # Both endpoints are copied: rebuild the full edge.
        copied = Edge(node_dict[src], src_axis, edge.name, node_dict[dst],
                      dst_axis)
        copied.set_signature(edge.signature)
        node_dict[dst].add_edge(copied, dst_axis)
        node_dict[src].add_edge(copied, src_axis)
    edge_dict[edge] = copied

  return node_dict, edge_dict
def test_edge_signature_setter_disabled_throws_error(single_node_edge):
  # Once an edge is marked disabled, assigning a signature must be refused.
  disabled_edge = Edge(node1=single_node_edge.node, axis1=0)
  disabled_edge.is_disabled = True
  with pytest.raises(ValueError):
    disabled_edge.signature = "signature"
def test_edge_is_being_used_false(single_node_edge):
  # An Edge that was created but never added to its node is not in use.
  unattached = Edge(name="edge", node1=single_node_edge.node, axis1=0)
  assert not unattached.is_being_used()
def test_edge_is_trace_true(single_node_edge):
  # An edge whose two ends sit on the same node is a trace edge.
  same_node = single_node_edge.node
  loop = Edge(name="edge", node1=same_node, axis1=1, node2=same_node, axis2=2)
  assert loop.is_trace()
def test_edge_name_throws_type_error(single_node_edge, name):
  # `name` is presumably parametrized with invalid values; constructing an
  # Edge with one of them must raise TypeError.
  with pytest.raises(TypeError):
    Edge(name=name, node1=single_node_edge.node, axis1=0)
def test_edge_node1_throws_value_error(single_node_edge):
  # Clearing the private node reference makes the `node1` property raise.
  edge = Edge(name="edge", node1=single_node_edge.node, axis1=0)
  edge._node1 = None
  with pytest.raises(
      ValueError, match="node1 for edge 'edge' no longer exists."):
    edge.node1