def total_energy(self, q_list, l_list):
    """Return the summed pair energy for every sample in the batch.

    Args:
        q_list: positions; indexed as [nsample, nparticle, dim].
        l_list: box information forwarded unchanged to ``delta_pbc``.

    Returns:
        One energy per sample: ``self.paired_energy`` evaluated on the
        reduced pair distances, summed over both particle axes and halved.
    """
    nsample, nparticle, dim = q_list.shape[:3]

    # pairwise displacements; presumably [nsample, nparticle, nparticle, dim]
    delta = delta_pbc(q_list, l_list)

    # keep nparticle - 1 partners per particle (presumably dropping i == j)
    pair_idx = get_paired_distance_indices.get_indices(delta.shape)
    pairs = get_paired_distance_indices.reduce(delta, pair_idx).view(
        [nsample, nparticle, nparticle - 1, dim]
    )

    dist = (pairs * pairs).sum(dim=-1).sqrt()  # [nsample, nparticle, nparticle - 1]
    pair_energy = self.paired_energy(dist)

    # each pair shows up from both particles' side, hence the factor 0.5
    return torch.sum(pair_energy, dim=(1, 2)) * 0.5
def paired_distance_reduced(q, npar):
    """Return reduced pair displacements and distances using a unit box.

    The box tensor passed to ``delta_pbc`` is all ones, created with torch's
    default dtype and then moved with ``mydevice.load``.

    Args:
        q: positions indexed as [nsamples, nparticle, DIM].
        npar: number of particles; expected to equal ``q.shape[1]`` so the
            reshape below is consistent.

    Returns:
        (dq_reshape, dd) where dq_reshape is [nsamples, npar, npar - 1, DIM]
        and dd is its Euclidean norm over the last axis,
        [nsamples, npar, npar - 1].
    """
    unit_box = mydevice.load(torch.ones(q.shape))

    delta = delta_pbc(q, unit_box)  # [nsamples, nparticle, nparticle, DIM]
    keep = get_paired_distance_indices.get_indices(delta.shape)

    nsamples, dim = q.shape[0], q.shape[2]
    dq_reshape = get_paired_distance_indices.reduce(delta, keep).view(
        nsamples, npar, npar - 1, dim
    )

    dd = torch.sqrt(torch.sum(dq_reshape * dq_reshape, dim=-1))
    return dq_reshape, dd
def prepare_input(self, q_list, p_list, l_list, tau):
    """Build the flattened per-pair input rows ``[dq, dp, tau / 2]``.

    Args:
        q_list: positions, [nsamples, nparticle, DIM].
        p_list: momenta, passed to ``delta_state``.
        l_list: box information forwarded to ``delta_pbc``.
        tau: time step; every row receives ``0.5 * tau`` as its last feature.

    Returns:
        Tensor of shape [nsamples * nparticle * nparticle, 2 * DIM + 1].
    """
    nsamples, nparticle, DIM = q_list.shape
    npairs = nsamples * nparticle * nparticle

    # [nsamples, nparticle, nparticle, DIM] -> [npairs, DIM]
    dq = delta_pbc(q_list, l_list).reshape(npairs, DIM)
    # delta_state output is expected to have the same layout as dq
    dp = delta_state(p_list).reshape(npairs, DIM)

    # one column holding half the time step, placed on the working device
    half_tau = mydevice.load(torch.zeros([npairs, 1]) + 0.5 * tau)

    return torch.cat((dq, dp, half_tau), dim=-1)
# Script fragment: builds random positions and per-sample box sizes, then
# evaluates pair terms with explicit Python loops (labelled "slow" in the
# debug print below), apparently as a reference for a vectorised version.
torch.set_default_dtype(torch.float64)
nsample = 20
nparticle = 2
dim = 2
q_list = torch.rand([nsample, nparticle, dim], requires_grad=True)
# box lengths in [nparticle**2, nparticle**2 + 1), shape [nsample, dim]
l_list = torch.rand([nsample, dim]) + nparticle * nparticle
# broadcast to one copy per particle: [nsample, nparticle, dim]
l_list = torch.unsqueeze(l_list, dim=1)
l_list = torch.repeat_interleave(l_list, nparticle, dim=1)
#print('q_list ',q_list)
#print('l_list ',l_list)
dq = delta_pbc(q_list, l_list) # shape [nsample,nparticle,nparticle,dim]
dr = torch.sqrt(torch.sum(dq * dq, dim=-1)) # shape [nsample,nparticle,nparticle]
#print('dr slow ',dr)
e_list = []
for s in range(nsample):
    e = 0.0
    for p1 in range(nparticle):
        for p2 in range(nparticle):
            # skip self-pairs; every ordered pair (p1, p2) with p1 != p2 is visited
            if p1 != p2:
                r = dr[s][p1][p2]
                # 1e-10 guards against division by zero at r == 0
                e6 = 1 / (r**6 + 1e-10)
                e12 = 1 / (r**12 + 1e-10)
                # the debug print suggests the pair term is 4*(e12-e6)
                # NOTE(review): as visible here, the term is never added to `e`
                # and `e` is never appended to `e_list`; the rest of this loop
                # appears to be missing from this view -- confirm.
                #print('r ',r,' add to ',4*(e12-e6))
def prepare_q_input(self, pwnet_id, q_list, p_list, l_list):
    """Return the pairwise displacements of ``q_list`` after ``make_correct_shape``.

    ``pwnet_id`` and ``p_list`` are accepted but not used by this method.
    """
    return self.make_correct_shape(delta_pbc(q_list, l_list))