예제 #1
0
파일: main.py 프로젝트: augustdemi/jcvae
def get_paired_data(paired_cnt, seed):
    """Pick a class-balanced, reproducibly ordered subset of the training set.

    The training split of ``DIGIT`` is grouped by its ten integer labels
    (0-9); ``paired_cnt / 10`` dataset positions are sampled from every
    label, the pooled positions are shuffled, and the corresponding items
    are fetched and stacked.

    Args:
        paired_cnt: total number of items to return; must be a multiple of 10.
        seed: value passed to ``random.seed`` before every sampling step and
            before the final shuffle, so the same seed yields the same subset.

    Returns:
        ``(imgs, labels)``: two tensors stacked along dim 0, one entry per
        selected item, in the shuffled order.
    """
    loader = torch.utils.data.DataLoader(DIGIT('./data', train=True),
                                         batch_size=args.batch_size,
                                         shuffle=False)
    dataset = loader.dataset
    all_labels = dataset.label

    per_class = int(paired_cnt / 10)
    assert per_class == paired_cnt / 10

    # Bucket every dataset position under its class label.
    buckets = {digit: [] for digit in range(10)}
    for position in range(len(all_labels)):
        buckets[int(all_labels[position])].append(position)

    # The generator is re-seeded before each draw, exactly once per label.
    chosen = []
    for digit in range(10):
        random.seed(seed)
        chosen += random.sample(buckets[digit], per_class)
    random.seed(seed)
    random.shuffle(chosen)

    # Each dataset item is expected to unpack into an (image, label) pair.
    samples = [dataset[position] for position in chosen]
    imgs = torch.stack([img for img, _ in samples], dim=0)
    labels = torch.stack([torch.tensor(label) for _, label in samples], dim=0)
    return imgs, labels
예제 #2
0
    def train(self):
        """Run the training loop from the loaded checkpoint up to ``self.max_iter``.

        Builds shuffled train / test loaders over ``DIGIT('./data')`` and
        stores them on ``self.data_loader`` / ``self.test_data_loader``
        (``self.N`` is set to the training-set size). Each iteration then:

        1. draws a mini-batch ``(XA, XB, index)`` and encodes both modalities;
        2. computes KL / capacity losses and reconstruction losses for the
           single-modality and joint (product-of-experts) paths;
        3. back-propagates ``vae_loss`` through ``self.optim_vae``;
        4. periodically prints / records losses, saves checkpoints and output
           images, runs cross-modal synthesis, and feeds the visdom gatherer.

        Only the modality-B terms (``loss_recon_infB`` and
        ``capacity_loss_infB``) make up ``vae_loss``; the modality-A and joint
        terms are computed for logging and visualisation only.
        """

        self.set_mode(train=True)

        # prepare dataloader (iterable)
        print('Start loading data...')
        dset = DIGIT('./data', train=True)
        self.data_loader = torch.utils.data.DataLoader(dset, batch_size=self.batch_size, shuffle=True)
        test_dset = DIGIT('./data', train=False)
        self.test_data_loader = torch.utils.data.DataLoader(test_dset, batch_size=self.batch_size, shuffle=True)
        print('test: ', len(test_dset))
        self.N = len(self.data_loader.dataset)
        print('...done')

        # iterators from dataloader
        # NOTE: iterator2 is never drawn from in this method; it is still
        # created so the loader is iterated exactly as before.
        iterator1 = iter(self.data_loader)
        iterator2 = iter(self.data_loader)

        iter_per_epoch = min(len(iterator1), len(iterator2))

        start_iter = self.ckpt_load_iter + 1
        epoch = int(start_iter / iter_per_epoch)

        for iteration in range(start_iter, self.max_iter + 1):

            # Set once get_stat() has been called during this iteration, so
            # the evaluation step below never reads missing or stale values.
            stats_fresh = False

            # reset data iterators for each epoch
            if iteration % iter_per_epoch == 0:
                print('==== epoch %d done ====' % epoch)
                epoch += 1
                iterator1 = iter(self.data_loader)
                iterator2 = iter(self.data_loader)

            # ============================================
            #          TRAIN THE VAE (ENC & DEC)
            # ============================================

            # sample a mini-batch; the third element (sample index) is unused
            XA, XB, _ = next(iterator1)  # (n x C x H x W)

            if self.use_cuda:
                XA = XA.cuda()
                XB = XB.cuda()

            # zA, zS = encA(xA)
            muA_infA, stdA_infA, logvarA_infA, cate_prob_infA = self.encoderA(XA)

            # zB, zS = encB(xB)
            muB_infB, stdB_infB, logvarB_infB, cate_prob_infB = self.encoderB(XB)

            # zS = encAB(xA,xB) via POE: the two categorical posteriors are
            # multiplied together with a constant 1/10 factor (computed in
            # log space); the product is not renormalised here.
            cate_prob_POE = torch.exp(
                torch.log(torch.tensor(1 / 10)) + torch.log(cate_prob_infA) + torch.log(cate_prob_infB))

            # kl losses
            # A (kl_loss_function's default capacity schedule)
            latent_dist_infA = {'cont': (muA_infA, logvarA_infA), 'disc': [cate_prob_infA]}
            (kl_cont_loss_infA, kl_disc_loss_infA, cont_capacity_loss_infA, disc_capacity_loss_infA) = kl_loss_function(
                self.use_cuda, iteration, latent_dist_infA)

            loss_kl_infA = kl_cont_loss_infA + kl_disc_loss_infA
            capacity_loss_infA = cont_capacity_loss_infA + disc_capacity_loss_infA

            # B (explicit capacity settings; presumably
            # [min, max, num_iters, gamma] -- confirm against kl_loss_function)
            latent_dist_infB = {'cont': (muB_infB, logvarB_infB), 'disc': [cate_prob_infB]}
            (kl_cont_loss_infB, kl_disc_loss_infB, cont_capacity_loss_infB, disc_capacity_loss_infB) = kl_loss_function(
                self.use_cuda, iteration, latent_dist_infB, cont_capacity=[0.0, 5.0, 50000, 100.0] , disc_capacity=[0.0, 10.0, 50000, 100.0])

            loss_kl_infB = kl_cont_loss_infB + kl_disc_loss_infB
            capacity_loss_infB = cont_capacity_loss_infB + disc_capacity_loss_infB

            # only the modality-B capacity term is optimised
            loss_capa = capacity_loss_infB

            # encoder samples (for training)
            ZA_infA = sample_gaussian(self.use_cuda, muA_infA, stdA_infA)
            ZB_infB = sample_gaussian(self.use_cuda, muB_infB, stdB_infB)
            ZS_POE = sample_gumbel_softmax(self.use_cuda, cate_prob_POE)

            # encoder samples (for cross-modal prediction)
            ZS_infA = sample_gumbel_softmax(self.use_cuda, cate_prob_infA)
            ZS_infB = sample_gumbel_softmax(self.use_cuda, cate_prob_infB)

            # reconstructed samples (given joint modal observation)
            XA_POE_recon = self.decoderA(ZA_infA, ZS_POE)
            XB_POE_recon = self.decoderB(ZB_infB, ZS_POE)

            # reconstructed samples (given single modal observation)
            XA_infA_recon = self.decoderA(ZA_infA, ZS_infA)
            XB_infB_recon = self.decoderB(ZB_infB, ZS_infB)

            loss_recon_infA = reconstruction_loss(XA, torch.sigmoid(XA_infA_recon), distribution="bernoulli")
            loss_recon_infB = reconstruction_loss(XB, torch.sigmoid(XB_infB_recon), distribution="bernoulli")
            # joint reconstruction: summed L1 error averaged over the batch
            loss_recon_POE = \
                F.l1_loss(torch.sigmoid(XA_POE_recon), XA, reduction='sum').div(XA.size(0)) + \
                F.l1_loss(torch.sigmoid(XB_POE_recon), XB, reduction='sum').div(XB.size(0))

            # only the modality-B reconstruction term is optimised
            loss_recon = loss_recon_infB

            # total loss for vae
            vae_loss = loss_recon + loss_capa

            # update vae
            self.optim_vae.zero_grad()
            vae_loss.backward()
            self.optim_vae.step()

            # print the losses
            if iteration % self.print_iter == 0:
                prn_str = (
                    '[iter %d (epoch %d)] vae_loss: %.3f '
                    '(recon: %.3f, capa: %.3f)\n'
                    '    rec_infA = %.3f, rec_infB = %.3f, rec_POE = %.3f\n'
                    '    kl_infA = %.3f, kl_infB = %.3f'
                    '    cont_capacity_loss_infA = %.3f, disc_capacity_loss_infA = %.3f\n'
                    '    cont_capacity_loss_infB = %.3f, disc_capacity_loss_infB = %.3f\n'
                ) % (
                    iteration, epoch,
                    vae_loss.item(), loss_recon.item(), loss_capa.item(),
                    # rec_POE reports the joint loss, not the optimised one
                    loss_recon_infA.item(), loss_recon_infB.item(), loss_recon_POE.item(),
                    loss_kl_infA.item(), loss_kl_infB.item(),
                    cont_capacity_loss_infA.item(), disc_capacity_loss_infA.item(),
                    cont_capacity_loss_infB.item(), disc_capacity_loss_infB.item(),
                )
                print(prn_str)
                if self.record_file:
                    # context manager closes the file even if the write fails
                    with open(self.record_file, 'a') as record:
                        record.write('%s\n' % (prn_str,))

            # save model parameters
            if iteration % self.ckpt_save_iter == 0:
                self.save_checkpoint(iteration)

            # save output images (recon, synth, etc.)
            if iteration % self.output_save_iter == 0:
                # 1) save the recon images
                self.save_recon(iteration)

                z_A, z_B, z_S = self.get_stat()
                stats_fresh = True

                # 2) save the latent traversed images
                self.save_traverseB(iteration, z_A, z_B, z_S)

            if iteration % self.eval_metrics_iter == 0:
                # z_A / z_B are otherwise only assigned in the output-saving
                # branch above; compute them here when that branch did not
                # run in this iteration.
                if not stats_fresh:
                    z_A, z_B, z_S = self.get_stat()
                self.save_synth_cross_modal(iteration, z_A, z_B, train=False, howmany=3)

            # (visdom) insert current line stats
            if self.viz_on and (iteration % self.viz_ll_iter == 0):
                self.line_gather.insert(iter=iteration,
                                        recon_both=loss_recon_POE.item(),
                                        recon_A=loss_recon_infA.item(),
                                        recon_B=loss_recon_infB.item(),
                                        kl_A=loss_kl_infA.item(),
                                        kl_B=loss_kl_infB.item(),
                                        cont_capacity_loss_infA=cont_capacity_loss_infA.item(),
                                        disc_capacity_loss_infA=disc_capacity_loss_infA.item(),
                                        cont_capacity_loss_infB=cont_capacity_loss_infB.item(),
                                        disc_capacity_loss_infB=disc_capacity_loss_infB.item()
                                        )

            # (visdom) visualize line stats (then flush out)
            if self.viz_on and (iteration % self.viz_la_iter == 0):
                self.visualize_line()
                self.line_gather.flush()
예제 #3
0
파일: main.py 프로젝트: augustdemi/jcvae

if args.viz_on:
    # Visdom window ids, keyed by the plot they belong to.
    WIN_ID = {
        'llA': 'win_llA',
        'llB': 'win_llB',
        'test_acc': 'win_test_acc',
        'total_losses': 'win_total_losses',
    }
    # Gatherer holding one series per named statistic.
    LINE_GATHER = probtorch.util.DataGather(
        'epoch',
        'recon_A', 'recon_B',
        'recon_poeA', 'recon_poeB',
        'recon_crA', 'recon_crB',
        'total_loss', 'test_total_loss',
        'test_acc',
    )
    VIZ = visdom.Visdom(port=args.viz_port)
    viz_init()

# Un-shuffled loaders over the train / test splits of DIGIT.
train_data = torch.utils.data.DataLoader(DIGIT('./data', train=True),
                                         batch_size=args.batch_size,
                                         shuffle=False)
test_data = torch.utils.data.DataLoader(DIGIT('./data', train=False),
                                        batch_size=args.batch_size,
                                        shuffle=False)

# len() of a DataLoader is the number of mini-batches, not the number of samples.
train_data_size = len(train_data)

# (N - 1) / (B - 1) with N = dataset size and B = batch size. Both constants
# now use the sample count of their own split; BIAS_TRAIN previously used the
# batch count in train_data_size, unlike BIAS_TEST.
# NOTE(review): presumably the mini-batch bias correction for the objective -- confirm.
BIAS_TRAIN = (len(train_data.dataset) - 1) / (args.batch_size - 1)
BIAS_TEST = (len(test_data.dataset) - 1) / (args.batch_size - 1)


def cuda_tensors(obj):
    for attr in dir(obj):
        value = getattr(obj, attr)
예제 #4
0
파일: main.py 프로젝트: augustdemi/jcvae
                  title='Total Loss', legend=['train_loss', 'test_loss'])
    )


if args.viz_on:
    # Visdom window ids, keyed by the plot they belong to.
    WIN_ID = {
        'llA': 'win_llA',
        'llB': 'win_llB',
        'test_acc': 'win_test_acc',
        'total_losses': 'win_total_losses',
    }
    # Gatherer holding one series per named statistic.
    LINE_GATHER = probtorch.util.DataGather(
        'epoch',
        'recon_A',
        'recon_B',
        'recon_poeA',
        'recon_poeB',
        'recon_crA',
        'recon_crB',
        'total_loss',
        'test_total_loss',
        'test_acc',
    )
    VIZ = visdom.Visdom(port=args.viz_port)
    viz_init()

# Un-shuffled loaders over the train / test splits of DIGIT.
train_data = torch.utils.data.DataLoader(DIGIT('./data', train=True), batch_size=args.batch_size, shuffle=False)
test_data = torch.utils.data.DataLoader(DIGIT('./data', train=False), batch_size=args.batch_size, shuffle=False)

# len() of a DataLoader is the number of mini-batches, not the number of samples.
train_data_size = len(train_data)

# (N - 1) / (B - 1) with N = dataset size and B = batch size. BIAS_TRAIN was a
# copy of the BIAS_TEST expression (it read test_data); it now uses the
# training split.
# NOTE(review): presumably the mini-batch bias correction for the objective -- confirm.
BIAS_TRAIN = (len(train_data.dataset) - 1) / (args.batch_size - 1)
BIAS_TEST = (len(test_data.dataset) - 1) / (args.batch_size - 1)


def cuda_tensors(obj):
    """Replace every tensor-valued attribute of ``obj`` with its ``.cuda()`` copy.

    All names reported by ``dir(obj)`` are inspected; attributes that are not
    ``torch.Tensor`` instances are left untouched. The object is modified in
    place and nothing is returned.
    """
    for name in dir(obj):
        member = getattr(obj, name)
        if not isinstance(member, torch.Tensor):
            continue
        setattr(obj, name, member.cuda())