import numpy as np
import torch
import torch.distributions as tdist


def train(epoch):
    pace = args.pace
    for i in range(4):
        models[i].train()
        discriminators[i].train()
        # halve all learning rates every 20 epochs
        if epoch % 20 == 0:
            for param_group1 in optimizers[i].param_groups:
                param_group1['lr'] = 0.5 * param_group1['lr']
            for param_group1 in optimizerGs[i].param_groups:
                param_group1['lr'] = 0.5 * param_group1['lr']
            for param_group1 in optimizerDs[i].param_groups:
                param_group1['lr'] = 0.5 * param_group1['lr']

    # define aggregation weights: uniform here; tbs[i] / denominator would
    # instead weight each site by its batch size
    w = dict()
    denominator = np.sum(np.array(tbs))
    for i in range(4):
        w[i] = 0.25  # tbs[i] / denominator

    loss_all = dict()
    lossG_all = dict()
    lossD_all = dict()
    num_data = dict()
    num_dataG = dict()
    num_dataD = dict()
    for i in range(4):
        loss_all[i] = 0
        lossG_all[i] = 0
        lossD_all[i] = 0
        num_data[i] = EPS
        num_dataG[i] = EPS
        num_dataD[i] = EPS

    count = 0
    for t in range(args.nsteps):
        fs = []
        # optimize each site's classifier
        for i in range(4):
            optimizers[i].zero_grad()
            a, b = next(data_iters[i])
            num_data[i] += b.size(0)
            a = a.to(device)
            b = b.to(device)
            output = models[i](a)
            loss = celoss(output, b)
            loss_all[i] += loss.item() * b.size(0)
            loss.backward(retain_graph=True)
            optimizers[i].step()
            fs.append(models[i].encoder(a))

        # optimize alignment: perturb each site's features with small
        # Gaussian noise before passing them to the discriminators
        noises = []
        for i in range(4):
            nn = tdist.Normal(torch.tensor([0.0]),
                              0.001 * torch.std(fs[i].detach().cpu()))
            noises.append(nn.sample(fs[i].size()).squeeze().to(device))
        for i in range(4):
            for j in range(4):
                if i != j:
                    optimizerDs[i].zero_grad()
                    optimizerGs[i].zero_grad()
                    optimizerGs[j].zero_grad()
                    d1 = discriminators[i](fs[i] + noises[i])
                    d2 = discriminators[i](fs[j] + noises[j])
                    num_dataG[i] += d1.size(0)
                    num_dataD[i] += d1.size(0)
                    lossD = advDloss(d1, d2)
                    lossG = advGloss(d1, d2)
                    lossD_all[i] += lossD.item() * d1.size(0)
                    lossG_all[i] += lossG.item() * d1.size(0)
                    lossG_all[j] += lossG.item() * d2.size(0)
                    lossD = 0.1 * lossD
                    # train the discriminators only after a warm-up period
                    if epoch >= 5:
                        lossD.backward(retain_graph=True)
                        optimizerDs[i].step()
                    lossG.backward(retain_graph=True)
                    optimizerGs[i].step()
                    optimizerGs[j].step()
                    writer.add_histogram(
                        'Hist/hist_' + site[i] + '2' + site[j] + '_source',
                        d1, epoch * args.nsteps + t)
                    writer.add_histogram(
                        'Hist/hist_' + site[i] + '2' + site[j] + '_target',
                        d2, epoch * args.nsteps + t)

        count += 1
        # federated aggregation every `pace` steps and at the last step
        if count % pace == 0 or t == args.nsteps - 1:
            with torch.no_grad():
                for key in model.state_dict().keys():
                    if models[0].state_dict()[key].dtype == torch.int64:
                        # integer buffers (e.g. BatchNorm step counters)
                        # cannot be averaged; copy them from the first site
                        model.state_dict()[key].data.copy_(
                            models[0].state_dict()[key])
                    else:
                        temp = torch.zeros_like(model.state_dict()[key])
                        # add Gaussian or Laplacian noise to each site's
                        # weights before averaging
                        for s in range(4):
                            if args.type == 'G':
                                nn = tdist.Normal(
                                    torch.tensor([0.0]),
                                    args.noise * torch.std(
                                        models[s].state_dict()
                                        [key].detach().cpu()))
                            else:
                                nn = tdist.Laplace(
                                    torch.tensor([0.0]),
                                    args.noise * torch.std(
                                        models[s].state_dict()
                                        [key].detach().cpu()))
                            noise = nn.sample(
                                models[s].state_dict()[key].size()).squeeze()
                            noise = noise.to(device)
                            temp += w[s] * (models[s].state_dict()[key] +
                                            noise)
                        # update the global model
                        model.state_dict()[key].data.copy_(temp)
                        # update the local models
                        for s in range(4):
                            models[s].state_dict()[key].data.copy_(
                                model.state_dict()[key])
    return loss_all, lossG_all, lossD_all, num_data, num_dataG, num_dataD
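# The adversarial losses advDloss / advGloss are defined elsewhere in the
# repository; the minimal sketch below only illustrates the call signature
# train() relies on (discriminator scores d1 for the source-site features,
# d2 for the target-site features). It assumes the scores are raw logits
# and uses a standard GAN discriminator loss plus a "domain confusion"
# generator loss; the repository's actual formulation may differ.
import torch
import torch.nn.functional as F


def advDloss(d1, d2):
    # discriminator: score source features as 1 and target features as 0
    return (F.binary_cross_entropy_with_logits(d1, torch.ones_like(d1)) +
            F.binary_cross_entropy_with_logits(d2, torch.zeros_like(d2)))


def advGloss(d1, d2):
    # generators (encoders): drive both scores toward 0.5 so that the
    # discriminator cannot tell the two feature distributions apart
    return (F.binary_cross_entropy_with_logits(d1, torch.full_like(d1, 0.5)) +
            F.binary_cross_entropy_with_logits(d2, torch.full_like(d2, 0.5)))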
def train(epoch):
    pace = args.pace
    for i in range(4):
        models[i].train()
        models_local[i].train()
        # halve both learning rates every 20 epochs
        if epoch % 20 == 0:
            for param_group1 in optimizers[i].param_groups:
                param_group1['lr'] = 0.5 * param_group1['lr']
            for param_group1 in optimizers_local[i].param_groups:
                param_group1['lr'] = 0.5 * param_group1['lr']

    # define aggregation weights: uniform here; tbs[i] / denominator would
    # instead weight each site by its batch size
    w = dict()
    denominator = np.sum(np.array(tbs))
    for i in range(4):
        w[i] = 0.25  # tbs[i] / denominator

    loss_all = dict()
    loss_lc = dict()
    num_data = dict()
    for i in range(4):
        loss_all[i] = 0
        loss_lc[i] = 0
        num_data[i] = 0

    count = 0
    for t in range(args.nsteps):
        for i in range(4):
            optimizers[i].zero_grad()
            # a fresh iterator each step; with a shuffling loader this
            # draws a random batch
            a, b = next(iter(train_loaders[i]))
            num_data[i] += b.size(0)
            a = a.to(device)
            b = b.to(device)
            # local (site-specific) branch
            outlocal = models_local[i](a)
            loss_local = nnloss(outlocal, b)
            loss_local.backward(retain_graph=True)
            loss_lc[i] += loss_local.item() * b.size(0)
            optimizers_local[i].step()
            # global branch, conditioned on the local output
            output, _ = models[i](a, outlocal)
            loss = nnloss(output, b)
            loss.backward()
            loss_all[i] += loss.item() * b.size(0)
            optimizers[i].step()

        count += 1
        # federated aggregation every `pace` steps and at the last step;
        # only the classifier is shared across sites
        if count % pace == 0 or t == args.nsteps - 1:
            with torch.no_grad():
                for key in model.classifier.state_dict().keys():
                    if models[0].classifier.state_dict(
                    )[key].dtype == torch.int64:
                        # integer buffers cannot be averaged; copy them
                        # from the first site
                        model.classifier.state_dict()[key].data.copy_(
                            models[0].classifier.state_dict()[key])
                    else:
                        temp = torch.zeros_like(
                            model.classifier.state_dict()[key])
                        # add Gaussian noise to each site's classifier
                        # weights before averaging
                        for s in range(4):
                            nn = tdist.Normal(
                                torch.tensor([0.0]),
                                args.noise * torch.std(
                                    models[s].classifier.state_dict(
                                    )[key].detach().cpu()))
                            noise = nn.sample(
                                models[s].classifier.state_dict()
                                [key].size()).squeeze()
                            noise = noise.to(device)
                            temp += w[s] * (
                                models[s].classifier.state_dict()[key] +
                                noise)
                        # update the global model
                        model.classifier.state_dict()[key].data.copy_(temp)
                        # only the classifier gets updated locally
                        for s in range(4):
                            models[s].classifier.state_dict(
                            )[key].data.copy_(
                                model.classifier.state_dict()[key])
    return (loss_all[0] / num_data[0], loss_all[1] / num_data[1],
            loss_all[2] / num_data[2], loss_all[3] / num_data[3],
            loss_lc[0] / num_data[0], loss_lc[1] / num_data[1],
            loss_lc[2] / num_data[2], loss_lc[3] / num_data[3])
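# A hypothetical driver for the train() variant above, included only to make
# its eight-value return signature concrete. The epoch count args.epochs is
# assumed, and the `site` tag list is borrowed from the first train()
# function; neither is guaranteed to match the repository's actual entry
# point.
for epoch in range(args.epochs):
    losses = train(epoch)
    for i in range(4):
        writer.add_scalar('Loss/global_' + site[i], losses[i], epoch)
        writer.add_scalar('Loss/local_' + site[i], losses[4 + i], epoch)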
# class HarmOsc3D(NeuralWF):
#     def __init__(self, model, nelec, ndim):
#         NeuralWF.__init__(self, model, nelec, ndim)
#
#     def nuclear_potential(self, pos):
#         return torch.sum(0.5 * pos**2, 1)
#
#     def electronic_potential(self, pos):
#         return 0
#
# wf = HarmOsc3D(model=WaveNet, nelec=1, ndim=3)

# neural wave function for H2O (O at the origin, H along y and z)
# in a dzp basis with a (2, 2) active space
wf = NEURAL_PYSCF_WF(atom='O 0 0 0; H 0 1 0; H 0 0 1',
                     basis='dzp',
                     active_space=(2, 2))

# Metropolis sampler for the walker positions
sampler = Metropolis(nwalkers=64,
                     nstep=10,
                     step_size=3,
                     nelec=wf.nelec,
                     ndim=3,
                     domain={'min': -5, 'max': 5})

nn = NN4PYSCF(wf=wf, sampler=sampler)

# sample walker positions and wrap them in a dataset / loader
pos = nn.sample()
dataset = QMC_DataSet(pos)
dataloader = DataLoader(dataset, batch_size=nn.batchsize)

# variance-minimization objective for the variational optimization
qmc_loss = QMCLoss(nn.wf, method='variance')
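# A hypothetical optimization loop over the objects built above, sketching
# how the pieces could fit together. The optimizer choice, the epoch count,
# the per-epoch walker resampling, and the assumptions that nn.wf exposes
# torch parameters and that qmc_loss is callable on a batch of walker
# positions are all illustrative, not the repository's actual driver code.
import torch

opt = torch.optim.Adam(nn.wf.parameters(), lr=5e-3)
for epoch in range(10):
    for pos_batch in dataloader:
        opt.zero_grad()
        loss = qmc_loss(pos_batch)  # assumed to return a differentiable scalar
        loss.backward()
        opt.step()
    # refresh the walkers with the updated wave function
    dataloader = DataLoader(QMC_DataSet(nn.sample()),
                            batch_size=nn.batchsize)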