def test_save_load_nodes(backend, tmp_path):
  """Round-trip nodes through `tn.save_nodes` / `tn.load_nodes`.

  Four rank-4 nodes are saved and loaded back. The test checks that each
  node's name, axis names, backend name and tensor survive the round trip,
  and that the node count is preserved. It then checks that the connections
  made before saving still contract to the same results after loading:
  an edge between nodes 0 and 1, and a trace edge on node 2.
  """
  nodes = [
      Node(
          np.random.rand(2, 2, 2, 2),
          backend=backend,
          name='Node{}'.format(n),
          axis_names=['node{}_{}'.format(n, axis) for axis in range(1, 5)])
      for n in range(4)
  ]

  # Edge between two distinct nodes.
  nodes[0][0] ^ nodes[1][1]
  # Edge whose two ends are both on node 2 (a trace edge).
  nodes[2][1] ^ nodes[2][2]

  path = tmp_path / 'test_file_save_nodes'
  tn.save_nodes(nodes, path)
  loaded_nodes = tn.load_nodes(path)

  # Without an explicit count check, surplus loaded nodes would pass silently
  # and missing ones would only surface as an IndexError.
  assert len(loaded_nodes) == len(nodes)
  for node, loaded in zip(nodes, loaded_nodes):
    assert node.name == loaded.name
    assert node.axis_names == loaded.axis_names
    assert node.backend.name == loaded.backend.name
    np.testing.assert_allclose(node.tensor, loaded.tensor)

  # The restored edge between nodes 0 and 1 must contract to the same result.
  res = nodes[0] @ nodes[1]
  loaded_res = loaded_nodes[0] @ loaded_nodes[1]
  np.testing.assert_allclose(res.tensor, loaded_res.tensor)

  # The restored trace edge on node 2 must contract to the same result.
  trace = tn.contract_trace_edges(nodes[2])
  loaded_trace = tn.contract_trace_edges(loaded_nodes[2])
  np.testing.assert_allclose(trace.tensor, loaded_trace.tensor)
# Example #2 (votes: 0)
def delete_traces_no_complaint(N):
    """
    Contract every trace edge of a node, doing nothing if it has none.

    A trace edge is an "internal" edge: both of its ends sit on the same
    node. `tn.contract_trace_edges` is only called when at least one such
    edge exists. A node without trace edges is returned untouched instead
    of being passed on.

    Args:
        N(tensornetwork.Node): Tensornetwork Node

    Returns:
        tensornetwork.Node: `N` itself when it has no trace edges,
                            otherwise whatever
                            `tn.contract_trace_edges(N)` returns.
    """
    # any() stops at the first trace edge, exactly like the original loop/break.
    if not any(edge.is_trace() for edge in N.edges):
        return N
    return tn.contract_trace_edges(N)
# Example #3 (votes: 0)
def test_contract_trace_edges(backend):
    """Contracting trace edges on a node that has none raises ValueError."""
    shape = (3, 3, 3)
    # This node is never connected to anything, so it has no trace edges.
    unconnected = tn.Node(np.random.rand(*shape), backend=backend)
    with pytest.raises(ValueError):
        tn.contract_trace_edges(unconnected)
# Example #4 (votes: 0)
def test_contract_trace_edges(dtype, num_charges):
    """Symmetric backend: contracting a node with no trace edges raises ValueError."""
    # NOTE(review): a function with this same name appears earlier in this
    # listing. If both end up in one module, this later definition shadows the
    # earlier one and pytest collects only this one.
    np.random.seed(10)
    data = get_random((3, 3, 3), num_charges=num_charges, dtype=dtype)
    # The node is never connected, so it has no trace edges to contract.
    unconnected = tn.Node(data, backend='symmetric')
    with pytest.raises(ValueError):
        tn.contract_trace_edges(unconnected)