def rotate_psi(nn_state, basis, space, unitaries, psi=None):
    r"""A function that rotates the reconstructed wavefunction to a different
    basis.

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: WaveFunction
    :param basis: The basis to rotate the wavefunction to. Indexed per site;
                  sites labelled ``"Z"`` are left unrotated.
    :type basis: str
    :param space: The hilbert space of the system. Read as ``space[x, j]``
                  for configuration ``x`` and site ``j``.
    :type space: torch.Tensor
    :param unitaries: A dictionary of (2x2) unitary operators.
    :type unitaries: dict
    :param psi: A wavefunction that the user can input to override the neural
                network state's wavefunction.
    :type psi: torch.Tensor

    :returns: A wavefunction in a new basis.
    :rtype: torch.Tensor
    """
    N = nn_state.num_visible
    device = nn_state.device

    # The rotated sites and the sub-space they span depend only on `basis`,
    # so they are built once rather than once per output configuration.
    nontrivial_sites = [j for j in range(N) if basis[j] != "Z"]
    trivial_sites = [j for j in range(N) if basis[j] == "Z"]
    num_nontrivial_U = len(nontrivial_sites)
    sub_state = nn_state.generate_hilbert_space(num_nontrivial_U)

    v = torch.zeros(N, dtype=torch.double, device=device)
    psi_r = torch.zeros(2, 1 << N, dtype=torch.double, device=device)

    for x in range(1 << N):
        Upsi = torch.zeros(2, dtype=torch.double, device=device)

        # Unrotated sites keep the value of the target configuration for
        # every term of the inner sum, so they are written once per x.
        for j in trivial_sites:
            v[j] = space[x, j]

        for xp in range(1 << num_nontrivial_U):
            # Rotated sites run over every configuration of the sub-space.
            for cnt, site in enumerate(nontrivial_sites):
                v[site] = sub_state[xp][cnt]

            # Product of the single-site unitary matrix elements, starting
            # from the two-component value [1, 0].
            U = torch.tensor([1.0, 0.0], dtype=torch.double, device=device)
            for site in nontrivial_sites:
                tmp = unitaries[basis[site]]
                tmp = tmp[:, int(space[x][site]), int(v[site])].to(device)
                U = cplx.scalar_mult(U, tmp)

            if psi is None:
                Upsi += cplx.scalar_mult(U, nn_state.psi(v).squeeze())
            else:
                # Read the entries of v as a binary number, first site being
                # the most significant bit, to pick the column of `psi`.
                index = 0
                for k in range(len(v)):
                    index = (index << 1) | int(v[k].item())
                Upsi += cplx.scalar_mult(U, psi[:, index])

        psi_r[:, x] = Upsi
    return psi_r
def test_scalar_mult_overwrite_fail(self):
    """``scalar_mult`` must raise when ``out`` is one of its own inputs."""
    factor = torch.tensor([2, 3], dtype=torch.double)
    operand = torch.tensor([[1, 2], [3, 4]], dtype=torch.double)

    # First alias the vector operand, then the scalar operand.
    for aliased in (operand, factor):
        with self.assertRaises(RuntimeError):
            cplx.scalar_mult(factor, operand, out=aliased)
def test_scalar_mult_overwrite(self):
    """``scalar_mult`` must write its result into a separate ``out`` buffer."""
    factor = torch.tensor([2, 3], dtype=torch.double)
    operand = torch.tensor([[1, 2], [3, 4]], dtype=torch.double)
    expected = torch.tensor([[-7, -8], [9, 14]], dtype=torch.double)

    # Pre-allocated destination with the operand's shape and dtype.
    buffer = torch.zeros_like(operand)
    cplx.scalar_mult(factor, operand, out=buffer)

    self.assertTensorsEqual(
        buffer,
        expected,
        msg="Scalar * Vector multiplication with 'out' parameter failed!",
    )
def rotated_gradient(self, basis, sites, sample):
    r"""Computes the gradients rotated into the measurement basis.

    :param basis: The bases in which the measurement is made; forwarded to
                  ``init_gradient``.
    :param sites: Indices of the rotated sites. Must expose ``.size`` and be
                  usable as an index into ``sample``.
    :param sample: The measured configuration.
    :type sample: torch.Tensor

    :returns: A two-element list: the real part of the amplitude-RBM term and
              the negated imaginary part of the phase-RBM term, each divided
              by the rotated wavefunction accumulated in ``Upsi``.
    :rtype: list[torch.Tensor, torch.Tensor]
    """
    # init_gradient also hands back a scratch configuration; it is not needed
    # because a fresh one is cloned from the sample for every term below.
    Upsi, _, Us, rotated_grad = self.init_gradient(basis, sites)

    int_sample = sample[sites].round().int().cpu().numpy()
    ints_size = np.arange(sites.size)
    # Rounding does not depend on the loop variable; do it once.
    rounded_sample = sample.round()

    grad_size = (
        self.num_visible * self.num_hidden + self.num_hidden + self.num_visible
    )

    # Reusable buffers, overwritten on every iteration.
    Upsi_v = torch.zeros_like(Upsi, device=self.device)
    Z = torch.zeros(grad_size, dtype=torch.double, device=self.device)
    Z2 = torch.zeros((2, grad_size), dtype=torch.double, device=self.device)
    U = torch.tensor([1.0, 1.0], dtype=torch.double, device=self.device)

    for x in range(2**sites.size):
        # overwrite rotated elements
        vp = rounded_sample.clone()
        vp[sites] = self.subspace_vector(x, size=sites.size)
        int_vp = vp[sites].int().cpu().numpy()
        all_Us = Us[ints_size, :, int_sample, int_vp]

        # Gradient from the rotation: product over sites of the selected
        # unitary entries, combined as column 0 + 1j * column 1.
        Ut = np.prod(all_Us[:, 0] + (1j * all_Us[:, 1]))
        U[0] = Ut.real
        U[1] = Ut.imag

        cplx.scalar_mult(U, self.psi(vp), out=Upsi_v)
        Upsi += Upsi_v

        # Gradient on the current configuration
        grad_vp0 = self.rbm_am.effective_energy_gradient(vp)
        grad_vp1 = self.rbm_ph.effective_energy_gradient(vp)

        # Z supplies a zero second component for make_complex; Z2 receives
        # the product and is consumed by `+=` before it is reused.
        rotated_grad[0] += cplx.scalar_mult(
            Upsi_v, cplx.make_complex(grad_vp0, Z), out=Z2
        )
        rotated_grad[1] += cplx.scalar_mult(
            Upsi_v, cplx.make_complex(grad_vp1, Z), out=Z2
        )

    grad = [
        cplx.scalar_divide(rotated_grad[0], Upsi)[0, :],  # Real
        -cplx.scalar_divide(rotated_grad[1], Upsi)[1, :],  # Imaginary
    ]
    return grad
def rotate_psi(self, basis, unitary_dict, vis):
    r"""Rotates the wavefunction of ``self.nn_state`` into another basis.

    :param basis: Per-site basis labels; sites labelled ``"Z"`` are left
                  unrotated.
    :param unitary_dict: Maps a basis label to its unitary, indexed as
                         ``U[:, row, col]``.
    :type unitary_dict: dict
    :param vis: The visible configurations, read as ``vis[x, j]`` for
                configuration ``x`` and site ``j``.
    :type vis: torch.Tensor

    :returns: A double tensor of shape ``(2, 2**N)``, one column per
              configuration, where ``N = self.nn_state.num_visible``.
    :rtype: torch.Tensor
    """
    N = self.nn_state.num_visible
    device = self.nn_state.device

    # Compare labels by value: `is not "Z"` tests object identity, which is
    # only accidentally true for interned strings and fails for equal labels
    # that come from, e.g., a NumPy array or a file.
    nontrivial_sites = [j for j in range(N) if basis[j] != "Z"]
    trivial_sites = [j for j in range(N) if basis[j] == "Z"]
    num_nontrivial_U = len(nontrivial_sites)
    # Depends only on `basis`, so it is built once rather than per x.
    sub_state = self.nn_state.generate_hilbert_space(num_nontrivial_U)

    v = torch.zeros(N, dtype=torch.double, device=device)
    psi_r = torch.zeros(2, 1 << N, dtype=torch.double, device=device)

    for x in range(1 << N):
        Upsi = torch.zeros(2, dtype=torch.double, device=device)

        # Unrotated sites are fixed by the target configuration.
        for j in trivial_sites:
            v[j] = vis[x, j]

        for xp in range(1 << num_nontrivial_U):
            # Rotated sites run over every configuration of the sub-space.
            for cnt, site in enumerate(nontrivial_sites):
                v[site] = sub_state[xp][cnt]

            # Product of the single-site unitary matrix elements, starting
            # from the two-component value [1, 0].
            U = torch.tensor([1.0, 0.0], dtype=torch.double, device=device)
            for site in nontrivial_sites:
                tmp = unitary_dict[basis[site]]
                tmp = tmp[:, int(vis[x][site]), int(v[site])]
                U = cplx.scalar_mult(U, tmp)

            Upsi += cplx.scalar_mult(U, self.nn_state.psi(v))

        psi_r[:, x] = Upsi
    return psi_r
def test_scalar_matrix_mult(self):
    """``scalar_mult`` of a scalar with a matrix gives the expected product."""
    factor = torch.tensor([2, 3])
    operand = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    expected = torch.tensor([[[-13, -14], [-15, -16]], [[13, 18], [23, 28]]])

    product = cplx.scalar_mult(factor, operand)

    self.assertTensorsEqual(
        product,
        expected,
        msg="Scalar * Matrix multiplication failed!",
    )
def test_scalar_vector_mult(self):
    """``scalar_mult`` of a scalar with a vector gives the expected product."""
    factor = torch.tensor([2, 3], dtype=torch.double)
    operand = torch.tensor([[1, 2], [3, 4]], dtype=torch.double)
    expected = torch.tensor([[-7, -8], [9, 14]], dtype=torch.double)

    product = cplx.scalar_mult(factor, operand)

    self.assertTensorsEqual(
        product,
        expected,
        msg="Scalar * Vector multiplication failed!",
    )
def ph_grads(self, v):
    r"""Computes the gradients of the phase RBM for given input states

    The result is the phase RBM's ``gamma_grad`` term multiplied by
    ``cplx.I``, plus the phase part of ``pi_grad``; both are evaluated with
    ``v`` as each of the two input states and ``expand=True``.

    :param v: The first input state, :math:`\sigma`
    :type v: torch.Tensor

    :returns: The gradients of all phase RBM parameters
    :rtype: torch.Tensor
    """
    gamma_minus = self.rbm_ph.gamma_grad(v, v, eta=-1, expand=True)
    # need to multiply Gamma- by i
    rotated_gamma = cplx.scalar_mult(gamma_minus, cplx.I)
    pi_term = self.pi_grad(v, v, phase=True, expand=True)
    return rotated_gamma + pi_term
def rotate_psi(nn_state, basis, unitaries, psi=None):
    r"""Rotates the reconstructed wavefunction into a different basis.

    :param nn_state: The neural network state. Must provide ``num_visible``,
                     ``device``, ``space``, ``psi`` and
                     ``generate_Hilbert_space``.
    :param basis: Per-site basis labels; sites labelled ``'Z'`` are left
                  unrotated.
    :param unitaries: Maps a basis label to its unitary, indexed as
                      ``U[:, row, col]``.
    :type unitaries: dict
    :param psi: Optional wavefunction used instead of ``nn_state.psi``. Its
                columns are addressed by reading a configuration as a binary
                number, first site being the most significant bit.
    :type psi: torch.Tensor

    :returns: A double tensor of shape ``(2, 2**N)``, one column per
              configuration, where ``N = nn_state.num_visible``.
    :rtype: torch.Tensor
    """
    N = nn_state.num_visible
    device = nn_state.device

    # The rotated sites and the sub-space they span depend only on `basis`,
    # so they are built once rather than once per output configuration.
    nontrivial_sites = [j for j in range(N) if basis[j] != 'Z']
    trivial_sites = [j for j in range(N) if basis[j] == 'Z']
    num_nontrivial_U = len(nontrivial_sites)
    sub_state = nn_state.generate_Hilbert_space(num_nontrivial_U)

    v = torch.zeros(N, dtype=torch.double, device=device)
    psi_r = torch.zeros(2, 1 << N, dtype=torch.double, device=device)

    for x in range(1 << N):
        Upsi = torch.zeros(2, dtype=torch.double, device=device)

        # Unrotated sites are fixed by the target configuration.
        for j in trivial_sites:
            v[j] = nn_state.space[x, j]

        for xp in range(1 << num_nontrivial_U):
            # Rotated sites run over every configuration of the sub-space.
            for cnt, site in enumerate(nontrivial_sites):
                v[site] = sub_state[xp][cnt]

            # Product of the single-site unitary matrix elements, starting
            # from the two-component value [1, 0].
            U = torch.tensor([1., 0.], dtype=torch.double, device=device)
            for site in nontrivial_sites:
                tmp = unitaries[basis[site]][
                    :, int(nn_state.space[x][site]), int(v[site])
                ].to(device)
                U = cplx.scalar_mult(U, tmp)

            if psi is None:
                Upsi += cplx.scalar_mult(U, nn_state.psi(v))
            else:
                index = 0
                for k in range(len(v)):
                    index = (index << 1) | int(v[k].item())
                Upsi += cplx.scalar_mult(U, psi[:, index])

        psi_r[:, x] = Upsi
    return psi_r
def ph_grads(self, v):
    r"""Computes the gradients of the phase RBM for given input states

    The phase RBM's unreduced effective-energy gradient is wrapped with
    ``cplx.make_complex`` and multiplied by ``cplx.I``.

    :param v: The input state, :math:`\sigma`
    :type v: torch.Tensor

    :returns: The gradients of all phase RBM parameters
    :rtype: torch.Tensor
    """
    energy_grad = self.rbm_ph.effective_energy_gradient(v, reduce=False)
    complex_grad = cplx.make_complex(energy_grad)
    # need to multiply phase gradient by i
    return cplx.scalar_mult(complex_grad, cplx.I)
def ph_grads(self, v, vp):
    r"""Computes the gradients of the phase RBM for given input states

    The result is ``GammaM_grad(v, vp)`` of the phase RBM multiplied by the
    two-component tensor ``[0, 1]``, plus ``Pi_grad_ph(v, vp)``.

    :param v: The first input state, :math:`\sigma`
    :type v: torch.Tensor
    :param vp: The second input state, :math:`\sigma'`
    :type vp: torch.Tensor

    :returns: The gradients of all phase RBM parameters
    :rtype: torch.Tensor
    """
    gamma_minus = self.rbm_ph.GammaM_grad(v, vp)
    # need to multiply Gamma- by i
    # NOTE(review): torch.Tensor([0, 1]) uses the default tensor type and
    # device - confirm it matches what GammaM_grad returns.
    rotated_gamma = cplx.scalar_mult(gamma_minus, torch.Tensor([0, 1]))
    pi_term = self.Pi_grad_ph(v, vp)
    return rotated_gamma + pi_term
def rotated_gradient(self, basis, sites, sample):
    r"""Computes the gradients rotated into the measurement basis.

    Two equivalent strategies are used. When the number of rotated sites
    exceeds ``self.max_size``, or a truthy ``debug_gradient_rotation``
    attribute is set, the rotated sub-space is visited one configuration at a
    time. Otherwise all ``2 ** sites.size`` configurations are built and
    processed in a single batch.

    :param basis: The bases in which the measurement is made; forwarded to
                  ``init_gradient``.
    :param sites: Indices of the rotated sites. Must expose ``.size`` and be
                  usable as an index into ``sample``.
    :param sample: The measured configuration.
    :type sample: torch.Tensor

    :returns: A two-element list: the real part of the amplitude-RBM term and
              the negated imaginary part of the phase-RBM term, each divided
              by the rotated wavefunction ``Upsi``.
    :rtype: list[torch.Tensor, torch.Tensor]
    """
    # init_gradient also hands back a scratch configuration; it is not needed
    # because both branches build their own.
    Upsi, _, Us, rotated_grad = self.init_gradient(basis, sites)
    int_sample = sample[sites].round().int().cpu().numpy()
    ints_size = np.arange(sites.size)

    # if the number of rotated sites is too large, fallback to loop
    # since memory may be unable to store the entire expanded set of
    # visible states
    use_loop = sites.size > self.max_size or getattr(
        self, "debug_gradient_rotation", False
    )

    if use_loop:
        grad_size = (
            self.num_visible * self.num_hidden + self.num_hidden + self.num_visible
        )
        # Rounding does not depend on the loop variable; do it once.
        rounded_sample = sample.round()

        # Reusable buffers, overwritten on every iteration.
        Upsi_v = torch.zeros_like(Upsi, device=self.device)
        Z = torch.zeros(grad_size, dtype=torch.double, device=self.device)
        Z2 = torch.zeros((2, grad_size), dtype=torch.double, device=self.device)
        U = torch.tensor([1.0, 1.0], dtype=torch.double, device=self.device)

        for x in range(2 ** sites.size):
            # overwrite rotated elements
            vp = rounded_sample.clone()
            vp[sites] = self.subspace_vector(x, size=sites.size)
            int_vp = vp[sites].int().cpu().numpy()
            all_Us = Us[ints_size, :, int_sample, int_vp]

            # Gradient from the rotation
            Ut = np.prod(all_Us[:, 0] + (1j * all_Us[:, 1]))
            U[0] = Ut.real
            U[1] = Ut.imag

            cplx.scalar_mult(U, self.psi(vp), out=Upsi_v)
            Upsi += Upsi_v

            # Gradient on the current configuration
            grad_vp0 = self.rbm_am.effective_energy_gradient(vp)
            grad_vp1 = self.rbm_ph.effective_energy_gradient(vp)

            # Z2 receives each product and is consumed by `+=` before reuse.
            rotated_grad[0] += cplx.scalar_mult(
                Upsi_v, cplx.make_complex(grad_vp0, Z), out=Z2
            )
            rotated_grad[1] += cplx.scalar_mult(
                Upsi_v, cplx.make_complex(grad_vp1, Z), out=Z2
            )
    else:
        # One row per configuration of the rotated sub-space.
        vp = sample.round().clone().unsqueeze(0).repeat(2 ** sites.size, 1)
        # overwrite rotated elements
        vp[:, sites] = self.generate_hilbert_space(size=sites.size)
        vp = vp.contiguous()

        int_vp = vp[:, sites].long().cpu().numpy()
        all_Us = Us[ints_size, :, int_sample, int_vp]

        # Product over axis 1 of (component 0 + 1j * component 1).
        Ut = np.prod(all_Us[..., 0] + (1j * all_Us[..., 1]), axis=1)
        U = (
            cplx.make_complex(torch.tensor(Ut.real), torch.tensor(Ut.imag))
            .to(vp)
            .contiguous()
        )

        Upsi_v = cplx.scalar_mult(U, self.psi(vp).detach())
        Upsi = torch.sum(Upsi_v, dim=1)

        grad_vp0 = self.rbm_am.effective_energy_gradient(vp, reduce=False)
        grad_vp1 = self.rbm_ph.effective_energy_gradient(vp, reduce=False)

        # since grad_vp0/1 are real, can just treat the scalar multiplication
        # and addition as a matrix multiplication
        torch.matmul(Upsi_v, grad_vp0, out=rotated_grad[0])
        torch.matmul(Upsi_v, grad_vp1, out=rotated_grad[1])

    grad = [
        cplx.scalar_divide(rotated_grad[0], Upsi)[0, :],  # Real
        -cplx.scalar_divide(rotated_grad[1], Upsi)[1, :],  # Imaginary
    ]
    return grad
def pi_grad(self, v, vp, phase=False, expand=False):
    r"""Calculates the gradient of the :math:`\Pi` matrix with
    respect to the amplitude RBM parameters for two input states

    :param v: One of the visible states, :math:`\sigma`
    :type v: torch.Tensor
    :param vp: The other visible state, :math:`\sigma'`
    :type vp: torch.Tensor
    :param phase: Whether to compute the gradients for the phase RBM (`True`)
                  or the amplitude RBM (`False`)
    :type phase: bool
    :param expand: If `True`, every row of `v` is combined with every row of
                   `vp`, so the result has leading batch dimensions
                   `(v.shape[0], vp.shape[0], ...)`. If `False`, `v` and `vp`
                   are combined elementwise.
    :type expand: bool

    :returns: The matrix element of the gradient given by
              :math:`\langle\sigma|\nabla_\lambda\Pi|\sigma'\rangle`,
              built with `cplx.make_complex` from the concatenated
              W, U, visible-bias, hidden-bias and aux-bias parts (in that
              order along the last dimension)
    :rtype: torch.Tensor
    """
    # Remember whether either input was a single (1-D) state so the added
    # batch dimension can be removed again before returning.
    unsqueezed = v.dim() < 2 or vp.dim() < 2
    v = (v.unsqueeze(0) if v.dim() < 2 else v).to(self.rbm_am.weights_W)
    vp = (vp.unsqueeze(0) if vp.dim() < 2 else vp).to(
        self.rbm_am.weights_W)

    if expand:
        # Broadcast v along dim 1 and vp along dim 0 to form all pairs.
        arg_real = 0.5 * (F.linear(v, self.rbm_am.weights_U,
                                   self.rbm_am.aux_bias).unsqueeze_(1) +
                          F.linear(vp, self.rbm_am.weights_U,
                                   self.rbm_am.aux_bias).unsqueeze_(0))
        arg_imag = 0.5 * (
            F.linear(v, self.rbm_ph.weights_U).unsqueeze_(1) -
            F.linear(vp, self.rbm_ph.weights_U).unsqueeze_(0))
    else:
        arg_real = self.rbm_am.mixing_term(v + vp)
        arg_imag = self.rbm_ph.mixing_term(v - vp)

    sig = cplx.sigmoid(arg_real, arg_imag)

    # Leading (batch) dimensions of every gradient piece.
    batch_sizes = ((v.shape[0], vp.shape[0], *v.shape[1:-1])
                   if expand else (*v.shape[:-1], ))

    # The W, visible-bias and hidden-bias parts are all zeros; they are only
    # expanded to the batch shape so they can be concatenated below.
    W_grad = torch.zeros_like(self.rbm_am.weights_W).expand(
        *batch_sizes, -1, -1)
    vb_grad = torch.zeros_like(self.rbm_am.visible_bias).expand(
        *batch_sizes, -1)
    hb_grad = torch.zeros_like(self.rbm_am.hidden_bias).expand(
        *batch_sizes, -1)

    if phase:
        # Phase RBM: difference of the states, sigmoid multiplied by
        # cplx.I, and a zero aux-bias part.
        temp = (v.unsqueeze(1) - vp.unsqueeze(0)) if expand else (v - vp)
        sig = cplx.scalar_mult(sig, cplx.I)

        ab_grad_real = torch.zeros_like(self.rbm_ph.aux_bias).expand(
            *batch_sizes, -1)
        ab_grad_imag = ab_grad_real.clone()
    else:
        # Amplitude RBM: sum of the states; the aux-bias part is the
        # sigmoid itself.
        temp = (v.unsqueeze(1) + vp.unsqueeze(0)) if expand else (v + vp)

        ab_grad_real = cplx.real(sig)
        ab_grad_imag = cplx.imag(sig)

    # Outer product of sig (index j) with temp (index k); the leading "c"
    # index of sig is carried through unchanged.
    U_grad = 0.5 * torch.einsum("c...j,...k->c...jk", sig, temp)
    U_grad_real = cplx.real(U_grad)
    U_grad_imag = cplx.imag(U_grad)

    # Flatten the matrix-shaped parts so everything can be concatenated
    # along the last dimension.
    vec_real = [
        W_grad.view(*batch_sizes, -1),
        U_grad_real.view(*batch_sizes, -1),
        vb_grad,
        hb_grad,
        ab_grad_real,
    ]
    vec_imag = [
        W_grad.view(*batch_sizes, -1).clone(),
        U_grad_imag.view(*batch_sizes, -1),
        vb_grad.clone(),
        hb_grad.clone(),
        ab_grad_imag,
    ]

    # Drop the batch dimension that was added for a 1-D input.
    if unsqueezed and not expand:
        vec_real = [grad.squeeze_(0) for grad in vec_real]
        vec_imag = [grad.squeeze_(0) for grad in vec_imag]

    return cplx.make_complex(torch.cat(vec_real, dim=-1),
                             torch.cat(vec_imag, dim=-1))
def rotated_gradient(self, basis, sites, sample):
    r"""Computes the gradients rotated into the measurement basis

    :param basis: The bases in which the measurement is made
    :type basis: numpy.ndarray
    :param sites: The sites where the measurements are not made
                  in the computational basis
    :type sites: numpy.ndarray
    :param sample: The measurement (either 0 or 1)
    :type sample: torch.Tensor

    :returns: A list of two tensors, representing the rotated gradients
              of the amplitude and phase RBMS
    :rtype: list[torch.Tensor, torch.Tensor]
    """
    # Only the unitaries and the gradient accumulators returned by
    # init_gradient are used; every other piece is rebuilt below.
    _, _, _, Us, _, rotated_grad = self.init_gradient(basis, sites)

    int_sample = sample[sites].round().int().cpu().numpy()
    ints_size = np.arange(sites.size)
    # Rounding does not depend on the loop variables; do it once.
    rounded_sample = sample.round()

    # Reusable buffers, overwritten on every inner iteration.
    U_ = torch.tensor([1.0, 1.0], dtype=torch.double, device=self.device)
    UrhoU = torch.zeros(2, dtype=torch.double, device=self.device)
    UrhoU_ = torch.zeros_like(UrhoU)

    grad_size = (
        self.num_visible * self.num_hidden
        + self.num_visible * self.num_aux
        + self.num_visible
        + self.num_hidden
        + self.num_aux
    )
    Z2 = torch.zeros((2, grad_size), dtype=torch.double, device=self.device)

    num_configs = 2**sites.size
    for x in range(num_configs):
        v = rounded_sample.clone()
        v[sites] = self.subspace_vector(x, sites.size)
        int_v = v[sites].int().cpu().numpy()
        all_Us = Us[ints_size, :, int_sample, int_v]

        # This factor depends only on x, so it is computed once per outer
        # iteration instead of once per (x, y) pair.
        U_left = np.prod(all_Us[:, 0] + (1j * all_Us[:, 1]))

        for y in range(num_configs):
            vp = rounded_sample.clone()
            vp[sites] = self.subspace_vector(y, sites.size)
            int_vp = vp[sites].int().cpu().numpy()
            all_Us_dag = Us[ints_size, :, int_sample, int_vp]

            # The second factor is the complex conjugate of the same
            # unitary entries, selected by vp.
            Ut = U_left * np.prod(
                np.conj(all_Us_dag[:, 0] + (1j * all_Us_dag[:, 1]))
            )
            U_[0] = Ut.real
            U_[1] = Ut.imag

            cplx.scalar_mult(U_, self.rhoRBM_tilde(v, vp), out=UrhoU_)
            UrhoU += UrhoU_

            grad0 = self.am_grads(v, vp)
            grad1 = self.ph_grads(v, vp)

            # Z2 receives each product and is consumed by `+=` before reuse.
            rotated_grad[0] += cplx.scalar_mult(UrhoU_, grad0, out=Z2)
            rotated_grad[1] += cplx.scalar_mult(UrhoU_, grad1, out=Z2)

    grad = [
        cplx.scalar_divide(rotated_grad[0], UrhoU),
        cplx.scalar_divide(rotated_grad[1], UrhoU),
    ]
    return grad
def gradient(self, basis, sample):
    r"""Compute the gradient of a sample, measured in a given basis

    :param basis: A set of basis, (i.e.vector of strings)
    :type basis: np.array
    :param sample: The measured configuration, indexed per site.

    :returns: A two-element list ``[amplitude_part, phase_part]``. When every
              site is measured in ``'Z'`` the first entry is the amplitude
              RBM's effective-energy gradient of ``sample`` and the second is
              the float ``0.0``. Otherwise the entries are the real part of
              the rotated amplitude term and the negated imaginary part of
              the rotated phase term.
    :rtype: list
    """
    num_visible = self.num_visible

    # Read where the unitary rotations are applied
    rotated_sites = [j for j in range(num_visible) if basis[j] != 'Z']

    # If the basis is the reference one ('ZZZ..Z')
    if not rotated_sites:
        return [
            self.rbm_am.effective_energy_gradient(sample),  # Real
            0.0,  # Imaginary
        ]

    num_U = len(rotated_sites)  # Number of 1-local unitary rotations
    rotated_set = set(rotated_sites)

    # Unrotated sites keep the measured value for every term of the sum, so
    # they are written once here rather than inside the loop.
    vp = torch.zeros(num_visible, dtype=torch.double, device=self.device)
    for j in range(num_visible):
        if j not in rotated_set:
            vp[j] = sample[j]

    # Row index into each unitary and the unitary itself are fixed by the
    # sample and the basis, so they are looked up once.
    sample_bits = [int(sample[site]) for site in rotated_sites]
    site_unitaries = [self.unitary_dict[basis[site]] for site in rotated_sites]

    rotated_grad = [
        torch.zeros(2, self.rbm_am.num_pars, dtype=torch.double,
                    device=self.device),
        torch.zeros(2, self.rbm_ph.num_pars, dtype=torch.double,
                    device=self.device),
    ]
    Upsi = torch.zeros(2, dtype=torch.double, device=self.device)

    # Sum over the full subspace where the rotation are applied
    sub_space = self.generate_Hilbert_space(num_U)
    for x in range(1 << num_U):
        # Rotated sites take the values of the x-th sub-space configuration.
        for cnt, site in enumerate(rotated_sites):
            vp[site] = sub_space[x][cnt]

        # Product of the matrix elements of the unitaries
        U = torch.tensor([1., 0.], dtype=torch.double, device=self.device)
        for site, bit, unitary in zip(rotated_sites, sample_bits,
                                      site_unitaries):
            tmp = unitary[:, bit, int(vp[site])]
            U = cplx.scalar_mult(U, tmp.to(self.device))

        # Gradient on the current configuration
        grad_vp = [
            self.rbm_am.effective_energy_gradient(vp),
            self.rbm_ph.effective_energy_gradient(vp),
        ]

        # NN state rotated in this bases
        Upsi_v = cplx.scalar_mult(U, self.psi(vp))
        Upsi += Upsi_v

        # The gradients are passed to make_complex with a zero second part.
        rotated_grad[0] += cplx.scalar_mult(
            Upsi_v,
            cplx.make_complex(grad_vp[0], torch.zeros_like(grad_vp[0])))
        rotated_grad[1] += cplx.scalar_mult(
            Upsi_v,
            cplx.make_complex(grad_vp[1], torch.zeros_like(grad_vp[1])))

    return [
        cplx.scalar_divide(rotated_grad[0], Upsi)[0, :],
        -cplx.scalar_divide(rotated_grad[1], Upsi)[1, :],
    ]