def VRL_train(five_tuple,
              k,
              data,
              H,
              H_core,
              omega,
              lr=0.00001,
              epochs=int(1e4),
              lam=1):
    """Optimise ``data`` with SGD on a tensor-network energy plus a penalty.

    Every epoch rebuilds the networks from scratch. For each ``i`` in
    ``range(k)``:

    * ``H_core[i]`` is refreshed with ``tensor_completion``. The result is
      written back into ``H_core``, so the caller's list is modified.
    * The core is replicated.
    * Its dangling edges are attached to a ``K_Spin`` built from ``data``.

    Each network is then contracted, and its value is subtracted from
    ``energy``. The optimised objective is ``energy + lam * regularity``,
    where ``regularity = sum_j (1 - data[j*a:(j+1)*a].sum(0))**2`` for
    ``j`` in ``range(s)``.

    Args:
        five_tuple: ``(s, a, P_mat, R_vec, gamma)``. Only ``s`` and ``a`` are
            used here.
        k: Number of cores / spins processed per epoch (indexes ``H``,
            ``H_core`` and ``omega``).
        data: The tensor optimised by ``optim.SGD``. NOTE(review): it must
            be a leaf tensor with ``requires_grad=True``; confirm at callers.
        H: Per-``i`` argument to ``tensor_completion``.
        H_core: Per-``i`` argument to ``tensor_completion``; updated in place.
        omega: Per-``i`` argument to ``tensor_completion``.
        lr: SGD learning rate. Momentum is 0.9 and weight decay is 5e-4.
        epochs: Number of optimisation steps.
        lam: Weight of the regularity penalty.

    Returns:
        ``(spins, energy_history)``: the ``K_Spin`` objects built in the final
        epoch, and the per-epoch energies detached from the autograd graph.
    """
    s, a, P_mat, R_vec, gamma = five_tuple
    energy_history = []
    Pbar = pkbar.Pbar(name='progress', target=epochs)
    op = optim.SGD([data], lr=lr, momentum=0.9, weight_decay=5e-4)

    for e in range(epochs):
        spins = []
        core = []
        edge = []
        energy = 0
        regularity = 0

        op.zero_grad()

        # Refresh every core, replicate it, and gather the dangling edges of
        # the replica.
        for i in range(k):
            H_core[i] = tensor_completion(H_core[i], H[i], omega[i])
            core.append(tn.replicate_nodes(H_core[i]))
            edge.append([])
            for c in core[i]:
                edge[i] += c.get_all_dangling()

        # Build spin i and attach its qubits[j][0] edges (j <= i) to dangling
        # edge j of core i. `^` is TensorNetwork's edge-connect operator.
        for i in range(k):
            spins.append(K_Spin(s, a, i + 1, data=data, softmax=True))
            for j in range(i + 1):
                edge[i][j] ^ spins[i].qubits[j][0]

        for i in range(k):
            energy -= contractors.branch(tn.reachable(core[i]),
                                         nbranch=1).get_tensor()
        # Record a detached copy. Appending the graph-attached tensor would
        # keep every epoch's autograd history referenced for the whole run.
        energy_history.append(
            energy.detach() if torch.is_tensor(energy) else energy)

        for j in range(s):
            regularity += (1 - torch.sum(data[j * a:(j + 1) * a], 0))**2
        target = energy + lam * regularity

        target.backward()
        op.step()
        Pbar.update(e)

    return spins, energy_history
Ejemplo n.º 2
0
def binary_mera_energy(hamiltonian, state, isometry, disentangler,
                       backend="jax"):
  """Computes the energy using a layer of uniform binary MERA.

  The network is built and fully contracted twice, once per placement of the
  hamiltonian:

  * 'left': the hamiltonian acts on un_l[0], un_l[1] and un_r[0], and
    un_r[1] is traced against its conjugate.
  * 'right': the hamiltonian acts on un_l[1], un_r[0] and un_r[1], and
    un_l[0] is traced against its conjugate.

  Both networks are closed with the state at the top, and the two scalars
  are averaged.

  Args:
    hamiltonian: The hamiltonian (rank-6 tensor) defined at the bottom of the
      MERA layer.
    state: The 3-site reduced state (rank-6 tensor) defined at the top of the
      MERA layer.
    isometry: The isometry tensor (rank 3) of the binary MERA.
    disentangler: The disentangler tensor (rank 4) of the binary MERA.
    backend: TensorNetwork backend used for every node. Defaults to "jax",
      the value that was previously hard-coded, so existing callers behave
      exactly as before.

  Returns:
    The energy, 0.5 * (E_left + E_right).
  """
  out = []
  for dirn in ('left', 'right'):
    # Fresh nodes per placement, because the two placements wire them
    # differently.
    iso_l = tensornetwork.Node(isometry, backend=backend)
    iso_c = tensornetwork.Node(isometry, backend=backend)
    iso_r = tensornetwork.Node(isometry, backend=backend)

    iso_l_con = tensornetwork.conj(iso_l)
    iso_c_con = tensornetwork.conj(iso_c)
    iso_r_con = tensornetwork.conj(iso_r)

    op = tensornetwork.Node(hamiltonian, backend=backend)
    rho = tensornetwork.Node(state, backend=backend)

    un_l = tensornetwork.Node(disentangler, backend=backend)
    un_l_con = tensornetwork.conj(un_l)

    un_r = tensornetwork.Node(disentangler, backend=backend)
    un_r_con = tensornetwork.conj(un_r)

    # Wiring common to both placements. Every leg appears exactly once, so
    # each edge captured here is still dangling when it is connected below.
    edges = [
        # state <-> isometries (ket side)
        (iso_l[2], rho[0]), (iso_c[2], rho[1]), (iso_r[2], rho[2]),
        # isometries <-> disentanglers; outer isometry legs traced
        (iso_l[0], iso_l_con[0]),
        (iso_l[1], un_l[2]), (iso_c[0], un_l[3]),
        (iso_c[1], un_r[2]), (iso_r[0], un_r[3]),
        (iso_r[1], iso_r_con[1]),
        # conjugate disentanglers <-> conjugate isometries
        (un_l_con[2], iso_l_con[1]), (un_l_con[3], iso_c_con[0]),
        (un_r_con[2], iso_c_con[1]), (un_r_con[3], iso_r_con[0]),
        # state <-> isometries (bra side)
        (iso_l_con[2], rho[3]), (iso_c_con[2], rho[4]),
        (iso_r_con[2], rho[5]),
    ]

    if dirn == 'right':
      edges += [
          (un_l[0], un_l_con[0]),
          (un_l[1], op[3]), (un_r[0], op[4]), (un_r[1], op[5]),
          (op[0], un_l_con[1]), (op[1], un_r_con[0]), (op[2], un_r_con[1]),
      ]
    else:  # dirn == 'left'
      edges += [
          (un_l[0], op[3]), (un_l[1], op[4]), (un_r[0], op[5]),
          (un_r[1], un_r_con[1]),
          (op[0], un_l_con[0]), (op[1], un_l_con[1]), (op[2], un_r_con[0]),
      ]

    for edge_a, edge_b in edges:
      tensornetwork.connect(edge_a, edge_b)

    # FIXME: Check that this is giving us a good path!
    out.append(
        contractors.branch(tensornetwork.reachable(rho),
                           nbranch=2).get_tensor())

  return 0.5 * sum(out)
Ejemplo n.º 3
0
# Connect the isometries, disentanglers and operator to one another and to
# their conjugates. NOTE(review): the nodes used here (iso_*, un_*, op and the
# *_con copies) are presumably created earlier in the script.
tn.connect(iso_l[0], iso_l_con[0])
tn.connect(iso_l[1], un_l[2])
tn.connect(iso_c[0], un_l[3])
tn.connect(iso_c[1], un_r[2])
tn.connect(iso_r[0], un_r[3])
tn.connect(iso_r[1], iso_r_con[1])
tn.connect(un_l[0], un_l_con[0])
tn.connect(un_l[1], op[3])
tn.connect(un_r[0], op[4])
tn.connect(un_r[1], op[5])
tn.connect(op[0], un_l_con[1])
tn.connect(op[1], un_r_con[0])
tn.connect(op[2], un_r_con[1])
tn.connect(un_l_con[2], iso_l_con[1])
tn.connect(un_l_con[3], iso_c_con[0])
tn.connect(un_r_con[2], iso_c_con[1])
tn.connect(un_r_con[3], iso_r_con[0])

# The isometries' third legs (and their conjugates') are left open. They form
# the legs of the result, in this order.
output_edge_order = [
    iso_l_con[2], iso_c_con[2], iso_r_con[2], iso_l[2], iso_c[2], iso_r[2]
]

# Let the branch contractor search for a contraction order, then contract.
# Time it with perf_counter: it is a monotonic, high-resolution clock, so the
# interval is not skewed by system clock adjustments the way time.time() is.
t0 = time.perf_counter()
T2 = contractors.branch(tn.reachable(op),
                        output_edge_order=output_edge_order).get_tensor()
print("tn.contractors: time to contract = ", time.perf_counter() - t0)