Example #1
    # assumes the enclosing script provides args, the per-site models /
    # optimizers / discriminators, data_iters, writer, device, EPS, and the
    # usual imports (numpy as np, torch, torch.distributions as tdist)
    def train(epoch):
        pace = args.pace
        for i in range(4):
            models[i].train()
            # halve the learning rate every 20 epochs
            if epoch % 20 == 0:
                for param_group in optimizers[i].param_groups:
                    param_group['lr'] = 0.5 * param_group['lr']
                for param_group in optimizerGs[i].param_groups:
                    param_group['lr'] = 0.5 * param_group['lr']

            discriminators[i].train()
            if epoch % 20 == 0:
                for param_group in optimizerDs[i].param_groups:
                    param_group['lr'] = 0.5 * param_group['lr']

        # define aggregation weights (uniform across the four sites;
        # tbs[i] / denominator would weight by local batch size instead)
        w = dict()
        denominator = np.sum(np.array(tbs))
        for i in range(4):
            w[i] = 0.25  # tbs[i] / denominator

        loss_all = dict()
        lossG_all = dict()
        lossD_all = dict()
        num_data = dict()
        num_dataG = dict()
        num_dataD = dict()
        for i in range(4):
            loss_all[i] = 0
            lossG_all[i] = 0
            lossD_all[i] = 0
            # start the sample counters at EPS to avoid division by zero
            num_data[i] = EPS
            num_dataG[i] = EPS
            num_dataD[i] = EPS

        count = 0
        for t in range(args.nsteps):
            fs = []

            # optimize classifier

            for i in range(4):
                optimizers[i].zero_grad()
                a, b = next(data_iters[i])
                num_data[i] += b.size(0)
                a = a.to(device)
                b = b.to(device)
                output = models[i](a)
                loss = celoss(output, b)
                loss_all[i] += loss.item() * b.size(0)
                loss.backward(retain_graph=True)
                optimizers[i].step()

                # cache the encoder features for the alignment step below
                fs.append(models[i].encoder(a))

            # optimize alignment: adversarially match feature distributions
            # across the four sites

            noises = []
            for i in range(4):
                # add small Gaussian noise to the features before they are
                # scored by the discriminators
                noise_dist = tdist.Normal(
                    torch.tensor([0.0]),
                    0.001 * torch.std(fs[i].detach().cpu()))
                noises.append(
                    noise_dist.sample(fs[i].size()).squeeze().to(device))

            for i in range(4):
                for j in range(4):
                    if i != j:
                        optimizerDs[i].zero_grad()
                        optimizerGs[i].zero_grad()
                        optimizerGs[j].zero_grad()

                        d1 = discriminators[i](fs[i] + noises[i])
                        d2 = discriminators[i](fs[j] + noises[j])
                        num_dataG[i] += d1.size(0)
                        num_dataD[i] += d1.size(0)
                        lossD = advDloss(d1, d2)
                        lossG = advGloss(d1, d2)
                        lossD_all[i] += lossD.item() * d1.size(0)
                        lossG_all[i] += lossG.item() * d1.size(0)
                        lossG_all[j] += lossG.item() * d2.size(0)
                        lossD = 0.1 * lossD  # down-weight the discriminator
                        # adversarial updates start after a 5-epoch warm-up
                        if epoch >= 5:
                            lossD.backward(retain_graph=True)
                            optimizerDs[i].step()
                            lossG.backward(retain_graph=True)
                            optimizerGs[i].step()
                            optimizerGs[j].step()
                        writer.add_histogram(
                            f'Hist/hist_{site[i]}2{site[j]}_source',
                            d1, epoch * args.nsteps + t)
                        writer.add_histogram(
                            f'Hist/hist_{site[i]}2{site[j]}_target',
                            d2, epoch * args.nsteps + t)

            count += 1
            if count % pace == 0 or t == args.nsteps - 1:
                with torch.no_grad():
                    for key in model.state_dict().keys():
                        # integer buffers (e.g. BatchNorm's num_batches_tracked)
                        # cannot carry noise; copy them over directly
                        if models[0].state_dict()[key].dtype == torch.int64:
                            model.state_dict()[key].data.copy_(
                                models[0].state_dict()[key])
                        else:
                            temp = torch.zeros_like(model.state_dict()[key])
                            # add privacy-preserving noise to each local copy
                            for s in range(4):
                                if args.type == 'G':
                                    noise_dist = tdist.Normal(
                                        torch.tensor([0.0]),
                                        args.noise *
                                        torch.std(models[s].state_dict()
                                                  [key].detach().cpu()))
                                else:
                                    noise_dist = tdist.Laplace(
                                        torch.tensor([0.0]),
                                        args.noise *
                                        torch.std(models[s].state_dict()
                                                  [key].detach().cpu()))
                                noise = noise_dist.sample(
                                    models[s].state_dict()[key].size()).squeeze()
                                noise = noise.to(device)
                                temp += w[s] * (models[s].state_dict()[key] +
                                                noise)
                            # update global model
                            model.state_dict()[key].data.copy_(temp)
                            # update local models
                            for s in range(4):
                                models[s].state_dict()[key].data.copy_(
                                    model.state_dict()[key])

        return loss_all, lossG_all, lossD_all, num_data, num_dataG, num_dataD
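
celoss, advDloss, and advGloss are used above but never defined in the snippet; celoss is presumably a standard cross-entropy criterion. For the adversarial pair, the following is a minimal sketch of one common choice (logit-based GAN losses), not necessarily the definitions used in the original script:

    import torch
    import torch.nn.functional as F

    def advDloss(d1, d2):
        # discriminator: score own-site features (d1) as real and
        # other-site features (d2) as fake
        return (F.binary_cross_entropy_with_logits(d1, torch.ones_like(d1)) +
                F.binary_cross_entropy_with_logits(d2, torch.zeros_like(d2)))

    def advGloss(d1, d2):
        # encoders: push the scores toward the opposite labels so the
        # discriminator cannot tell the two sites apart
        return (F.binary_cross_entropy_with_logits(d1, torch.zeros_like(d1)) +
                F.binary_cross_entropy_with_logits(d2, torch.ones_like(d2)))
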
Example #2
    def train(epoch):
        pace = args.pace
        for i in range(4):
            models[i].train()
            models_local[i].train()
            # halve the learning rate every 20 epochs
            if epoch % 20 == 0:
                for param_group in optimizers[i].param_groups:
                    param_group['lr'] = 0.5 * param_group['lr']
                for param_group in optimizers_local[i].param_groups:
                    param_group['lr'] = 0.5 * param_group['lr']

        # define aggregation weights (uniform; tbs[i] / denominator would
        # weight by local batch size instead)
        w = dict()
        denominator = np.sum(np.array(tbs))
        for i in range(4):
            w[i] = 0.25  # tbs[i] / denominator
        loss_all = dict()
        loss_lc = dict()
        num_data = dict()
        for i in range(4):
            loss_all[i] = 0
            loss_lc[i] = 0
            num_data[i] = 0
        count = 0
        for t in range(args.nsteps):
            for i in range(4):
                optimizers[i].zero_grad()
                # next(iter(...)) builds a fresh iterator each step, so this
                # draws a new first batch only if the loader shuffles
                a, b = next(iter(train_loaders[i]))
                num_data[i] += b.size(0)
                a = a.to(device)
                b = b.to(device)
                outlocal = models_local[i](a)
                loss_local = nnloss(outlocal, b)
                loss_local.backward(retain_graph=True)
                loss_lc[i] += loss_local.item() * b.size(0)
                optimizers_local[i].step()

                # the shared model additionally conditions on the local output
                output, _ = models[i](a, outlocal)
                loss = nnloss(output, b)
                loss.backward()
                loss_all[i] += loss.item() * b.size(0)
                optimizers[i].step()
            count += 1

            if count % pace == 0 or t == args.nsteps - 1:
                with torch.no_grad():
                    for key in model.classifier.state_dict().keys():
                        if models[0].classifier.state_dict(
                        )[key].dtype == torch.int64:
                            model.classifier.state_dict()[key].data.copy_(
                                models[0].classifier.state_dict()[key])
                        else:
                            temp = torch.zeros_like(
                                model.classifier.state_dict()[key])
                            # add privacy-preserving noise to each local copy
                            for s in range(4):
                                noise_dist = tdist.Normal(
                                    torch.tensor([0.0]),
                                    args.noise *
                                    torch.std(models[s].classifier.state_dict(
                                    )[key].detach().cpu()))
                                noise = noise_dist.sample(
                                    models[s].classifier.state_dict()
                                    [key].size()).squeeze()
                                noise = noise.to(device)
                                temp += w[s] * (
                                    models[s].classifier.state_dict()[key] +
                                    noise)
                            # update global model
                            model.classifier.state_dict()[key].data.copy_(temp)
                            # only classifier get updated
                            for s in range(4):
                                models[s].classifier.state_dict(
                                )[key].data.copy_(
                                    model.classifier.state_dict()[key])

        return (loss_all[0] / num_data[0], loss_all[1] / num_data[1],
                loss_all[2] / num_data[2], loss_all[3] / num_data[3],
                loss_lc[0] / num_data[0], loss_lc[1] / num_data[1],
                loss_lc[2] / num_data[2], loss_lc[3] / num_data[3])
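
Both examples inline the same noisy federated-averaging step. As a self-contained sketch, that pattern can be factored into a helper like the one below; the names and the Gaussian-only noise are illustrative assumptions, not part of the original code:

    import torch
    import torch.distributions as tdist

    def aggregate(global_module, local_modules, weights, noise_scale, device):
        # weighted average of noise-perturbed local parameters, copied into
        # the global module and then synced back to every local module
        with torch.no_grad():
            global_sd = global_module.state_dict()
            for key in global_sd.keys():
                if global_sd[key].dtype == torch.int64:
                    # integer buffers are copied from the first site unchanged
                    global_sd[key].data.copy_(
                        local_modules[0].state_dict()[key])
                    continue
                temp = torch.zeros_like(global_sd[key])
                for s, local in enumerate(local_modules):
                    param = local.state_dict()[key]
                    dist = tdist.Normal(
                        torch.tensor([0.0]),
                        noise_scale * torch.std(param.detach().cpu()))
                    # squeeze(-1) drops only the trailing sample dimension
                    noise = dist.sample(param.size()).squeeze(-1).to(device)
                    temp += weights[s] * (param + noise)
                global_sd[key].data.copy_(temp)
                for local in local_modules:
                    local.state_dict()[key].data.copy_(global_sd[key])
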
Example #3
    # Alternative toy system, kept commented out for reference:
    # class HarmOsc3D(NeuralWF):
    #     def __init__(self, model, nelec, ndim):
    #         NeuralWF.__init__(self, model, nelec, ndim)

    #     def nuclear_potential(self, pos):
    #         return torch.sum(0.5 * pos**2, 1)

    #     def electronic_potential(self, pos):
    #         return 0
    # wf = HarmOsc3D(model=WaveNet, nelec=1, ndim=3)

    # neural wavefunction for a water molecule (DZP basis, (2, 2) active space)
    wf = NEURAL_PYSCF_WF(atom='O 0 0 0; H 0 1 0; H 0 0 1',
                         basis='dzp',
                         active_space=(2, 2))

    sampler = Metropolis(nwalkers=64,
                         nstep=10,
                         step_size=3,
                         nelec=wf.nelec,
                         ndim=3,
                         domain={
                             'min': -5,
                             'max': 5
                         })

    nn = NN4PYSCF(wf=wf, sampler=sampler)

    # draw walker positions, wrap them in a dataset/loader, and build the loss
    pos = nn.sample()
    dataset = QMC_DataSet(pos)
    dataloader = DataLoader(dataset, batch_size=nn.batchsize)
    qmc_loss = QMCLoss(nn.wf, method='variance')
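
The snippet stops after constructing the loss. A hypothetical continuation (none of this is in the original; the Adam optimizer, the qmc_loss(batch) call signature, and nn.wf.parameters() are all assumptions) might wire the pieces into an optimization loop like this:

    import torch

    opt = torch.optim.Adam(nn.wf.parameters(), lr=1e-3)

    for _ in range(10):
        for batch in dataloader:
            opt.zero_grad()
            loss = qmc_loss(batch)  # e.g. variance of the local energy
            loss.backward()
            opt.step()
        # resample walker positions from the updated wavefunction
        pos = nn.sample()
        dataloader = DataLoader(QMC_DataSet(pos), batch_size=nn.batchsize)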