Пример #1
0
    def save_traverseA(self, iters, z_A, z_B, z_S, loc=-1):
        """Save latent-traversal images (and an animated GIF) for decoder A.

        Ten fixed training samples are encoded with encoder A.  Each private
        latent dimension of A is swept over ``linspace(-3, 3, zS_dim)`` while
        the remaining codes keep their encoded values; afterwards the shared
        code is replaced by every one-hot vector in turn.  One JPEG per
        traversal step is written to ``<output_dir_trvsl>/<iters>/train`` and
        the frames are then assembled into ``mnist_traverse.gif``.

        Args:
            iters: Current training iteration; used as the output sub-folder.
            z_A, z_B, z_S: Latent statistics; only their min/max are printed.
            loc: Index of the single private dimension to traverse, or -1
                (default) to traverse all ``zA_dim`` dimensions.
        """
        self.set_mode(train=False)

        decoderA = self.decoderA
        # The sweep deliberately has zS_dim steps so that the continuous rows
        # and the one-hot row below contain the same number of frames.
        interpolationA = torch.tensor(np.linspace(-3, 3, self.zS_dim))

        print('------------ traverse interpolation ------------')
        print('interpolationA: ', np.min(np.array(z_A)), np.max(np.array(z_A)))
        print('interpolationB: ', np.min(np.array(z_B)), np.max(np.array(z_B)))
        print('interpolationS: ', np.min(np.array(z_S)), np.max(np.array(z_S)))

        # The reference samples are always needed.  They used to be prepared
        # only when `self.record_file` was set, which made the rest of the
        # method fail with a NameError otherwise.
        fixed_idxs = [3246, 7000, 14305, 19000, 27444, 33100, 38000, 45231, 51000, 55121]
        num_samples = len(fixed_idxs)

        fixed_XA = []
        for idx in fixed_idxs:
            xA = self.data_loader.dataset.__getitem__(idx)[0]
            if self.use_cuda:
                xA = xA.cuda()
            fixed_XA.append(xA.unsqueeze(0))
        fixed_XA = torch.cat(fixed_XA, dim=0)

        def _tile(batch):
            # Concatenate the samples of a batch along dim 1 of each image and
            # add a leading axis so that frames can be concatenated on dim 0.
            return torch.cat(batch.unbind(0), dim=1).unsqueeze(0)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            zA_ori, _, _, cate_prob_infA = self.encoderA(fixed_XA)
            # The shared code is inferred from modality A alone.
            zS_ori = sample_gumbel_softmax(self.use_cuda, cate_prob_infA, train=False)

            rows = []  # one entry per traversed dimension, each with a frame axis

            # 1) sweep every (or only the requested) private dimension of A
            for row in range(self.zA_dim):
                if loc != -1 and row != loc:
                    continue
                zA = zA_ori.clone()
                frames = []
                for val in interpolationA:
                    zA[:, row] = val
                    frames.append(_tile(torch.sigmoid(decoderA(zA, zS_ori))))
                rows.append(torch.cat(frames, dim=0).unsqueeze(0))

            # 2) force the shared code to each one-hot category in turn
            one_hots = torch.eye(self.zS_dim)
            frames = []
            for i in range(self.zS_dim):
                zS = one_hots[i].repeat(num_samples, 1)
                if self.use_cuda:
                    zS = zS.cuda()
                frames.append(_tile(torch.sigmoid(decoderA(zA_ori, zS))))
            rows.append(torch.cat(frames, dim=0).unsqueeze(0))

            # (rows, frames, ...) e.g. torch.Size([11, 10, 1, 384, 32]) in the
            # original author's setup
            gifs = torch.cat(rows, dim=0)

        # save the generated files, also the animated gif
        out_dir = os.path.join(self.output_dir_trvsl, str(iters), 'train')
        mkdirs(self.output_dir_trvsl)
        mkdirs(out_dir)

        for j in range(len(interpolationA)):
            save_image(
                tensor=gifs[:, j].cpu(),
                filename=os.path.join(out_dir, '%03d.jpg' % (j)),
                nrow=1 + self.zA_dim + 1 + 1 + 1 + self.zB_dim,
                pad_value=1)

        # make animated gif from the frames written above
        grid2gif2(
            out_dir, str(os.path.join(out_dir, 'mnist_traverse' + '.gif')), delay=10
        )

        self.set_mode(train=True)
Пример #2
0
    def save_recon(self, iters):
        """Save reconstruction grids for both modalities.

        Sixty training samples (six consecutive indices starting at each of
        ten fixed anchors) are encoded by both encoders.  Each modality is
        reconstructed twice: once with the shared code inferred from that
        modality alone and once with the product-of-experts shared code.  The
        images are written to ``reconA_<iters>.jpg`` and ``reconB_<iters>.jpg``
        inside ``self.output_dir_recon``, interleaved per sample as
        ``original, single-modal recon, joint recon, white spacer``.

        Args:
            iters: Current training iteration; used in the output file names.
        """
        self.set_mode(train=False)

        mkdirs(self.output_dir_recon)

        fixed_idxs = [3246, 7000, 14305, 19000, 27444, 33100, 38000, 45231, 51000, 55121]
        fixed_idxs60 = [idx + offset for idx in fixed_idxs for offset in range(6)]

        XA, XB = [], []
        for idx in fixed_idxs60:
            xA, xB = self.data_loader.dataset.__getitem__(idx)[0:2]
            if self.use_cuda:
                xA = xA.cuda()
                xB = xB.cuda()
            XA.append(xA)
            XB.append(xB)
        XA = torch.stack(XA)
        XB = torch.stack(XB)

        # Pure inference: skip autograd bookkeeping.
        with torch.no_grad():
            muA_infA, stdA_infA, _, cate_prob_infA = self.encoderA(XA)

            # zB, zS = encB(xB)
            muB_infB, stdB_infB, _, cate_prob_infB = self.encoderB(XB)

            # zS = encAB(xA,xB) via POE
            cate_prob_POE = torch.exp(
                torch.log(torch.tensor(1 / 10)) + torch.log(cate_prob_infA) + torch.log(cate_prob_infB))

            # encoder samples (same sampling order as during training)
            ZA_infA = sample_gaussian(self.use_cuda, muA_infA, stdA_infA)
            ZB_infB = sample_gaussian(self.use_cuda, muB_infB, stdB_infB)
            ZS_POE = sample_gumbel_softmax(self.use_cuda, cate_prob_POE, train=False)

            # encoder samples (for cross-modal prediction)
            ZS_infA = sample_gumbel_softmax(self.use_cuda, cate_prob_infA, train=False)
            ZS_infB = sample_gumbel_softmax(self.use_cuda, cate_prob_infB, train=False)

            # reconstructed samples (given joint modal observation)
            XA_POE_recon = torch.sigmoid(self.decoderA(ZA_infA, ZS_POE))
            XB_POE_recon = torch.sigmoid(self.decoderB(ZB_infB, ZS_POE))

            # reconstructed samples (given single modal observation)
            XA_infA_recon = torch.sigmoid(self.decoderA(ZA_infA, ZS_infA))
            XB_infB_recon = torch.sigmoid(self.decoderB(ZB_infB, ZS_infB))

        def _save_comparison(original, recon_single, recon_joint, tag):
            # Write one grid in which every sample contributes four adjacent
            # tiles: original, single-modal recon, joint recon, white spacer.
            n = original.shape[0]

            spacer = torch.ones(original.shape)
            if self.use_cuda:
                spacer = spacer.cuda()

            # Reorder the block-wise concatenation [orig*n, single*n, ...]
            # into the sample-wise order [orig0, single0, joint0, spacer0, ...].
            perm = torch.arange(0, 4 * n).view(4, n).transpose(1, 0)
            perm = perm.contiguous().view(-1)

            merged = torch.cat([original, recon_single, recon_joint, spacer], dim=0)
            merged = merged[perm, :].cpu()

            fname = os.path.join(self.output_dir_recon, 'recon%s_%s.jpg' % (tag, iters))
            save_image(
                tensor=merged, filename=fname, nrow=4 * int(np.sqrt(n)),
                pad_value=1
            )

        _save_comparison(XA, XA_infA_recon, XA_POE_recon, 'A')
        _save_comparison(XB, XB_infB_recon, XB_POE_recon, 'B')

        self.set_mode(train=True)
Пример #3
0
    def save_synth_cross_modal(self, iters, z_A_stat, z_B_stat, train=True, howmany=3):
        """Save cross-modal synthesis grids (A -> B and B -> A).

        Ten fixed samples are encoded; the shared code inferred from one
        modality is combined with random private codes of the other modality
        and decoded.  Private codes are standard-normal noise shifted by the
        mean (over axis 0) of the supplied latent statistics.  Two JPEGs are
        written to ``self.output_dir_synth``; their names get an ``eval_``
        prefix when ``train`` is False.

        Args:
            iters: Current training iteration; used in the output file names.
            z_A_stat: Private-latent statistics of A; averaged over axis 0.
            z_B_stat: Private-latent statistics of B; averaged over axis 0.
            train: Use the training loader and its fixed indices when True,
                otherwise the test loader.
            howmany: Number of random syntheses generated per input sample.
        """
        self.set_mode(train=False)

        if train:
            data_loader = self.data_loader
            fixed_idxs = [3246, 7001, 14308, 19000, 27447, 33103, 38002, 45232, 51000, 55125]
        else:
            data_loader = self.test_data_loader
            fixed_idxs = [2, 982, 2300, 3400, 4500, 5500, 6500, 7500, 8500, 9500]
        n = len(fixed_idxs)

        fixed_XA, fixed_XB = [], []
        for idx in fixed_idxs:
            xA, xB = data_loader.dataset.__getitem__(idx)[0:2]
            if self.use_cuda:
                xA = xA.cuda()
                xB = xB.cuda()
            fixed_XA.append(xA)
            fixed_XB.append(xB)
        fixed_XA = torch.stack(fixed_XA)
        fixed_XB = torch.stack(fixed_XB)

        # The statistics do not change between syntheses, so compute the
        # offsets once instead of once per loop iteration.
        z_A_mean = torch.Tensor(np.mean(np.array(z_A_stat), 0))
        z_B_mean = torch.Tensor(np.mean(np.array(z_B_stat), 0))

        def _as_3ch(batch):
            # Replicate every (squeezed) image three times along a new
            # leading channel axis so A images can sit in the same grid as B.
            return torch.stack([torch.stack([img.squeeze()] * 3) for img in batch])

        fixed_XA_3ch = _as_3ch(fixed_XA)

        # white spacer tiles that separate consecutive samples in the grid
        WS = torch.ones(fixed_XA_3ch.shape)
        if self.use_cuda:
            WS = WS.cuda()

        # Reorder block-wise [given*n, synth_1*n, ..., spacer*n] into
        # sample-wise [given_0, synth_1_0, ..., spacer_0, given_1, ...].
        perm = torch.arange(0, (howmany + 2) * n).view(howmany + 2, n).transpose(1, 0)
        perm = perm.contiguous().view(-1)

        mkdirs(self.output_dir_synth)
        prefix = '' if train else 'eval_'

        def _save_grid(tiles, direction):
            merged = torch.cat(tiles + [WS], dim=0)
            merged = merged[perm, :].cpu()
            fname = os.path.join(
                self.output_dir_synth,
                '%ssynth_cross_modal_%s_%s.jpg' % (prefix, direction, iters)
            )
            save_image(
                tensor=merged, filename=fname, nrow=(howmany + 2) * int(np.sqrt(n)),
                pad_value=1
            )

        # Pure inference: skip autograd bookkeeping.
        with torch.no_grad():
            _, _, _, cate_prob_infA = self.encoderA(fixed_XA)

            # zB, zS = encB(xB)
            _, _, _, cate_prob_infB = self.encoderB(fixed_XB)

            ZS_infA = sample_gumbel_softmax(self.use_cuda, cate_prob_infA, train=False)
            ZS_infB = sample_gumbel_softmax(self.use_cuda, cate_prob_infB, train=False)

            if self.use_cuda:
                ZS_infA = ZS_infA.cuda()
                ZS_infB = ZS_infB.cuda()

            ######## 1) generate xB from given xA (A2B) ########
            tiles = [fixed_XA_3ch]
            for _ in range(howmany):
                ZB = torch.randn(n, self.zB_dim) + z_B_mean
                if self.use_cuda:
                    ZB = ZB.cuda()
                tiles.append(torch.sigmoid(self.decoderB(ZB, ZS_infA)))  # given XA
            _save_grid(tiles, 'A2B')

            ######## 2) generate xA from given xB (B2A) ########
            tiles = [fixed_XB]
            for _ in range(howmany):
                ZA = torch.randn(n, self.zA_dim) + z_A_mean
                if self.use_cuda:
                    ZA = ZA.cuda()
                XA_synth = torch.sigmoid(self.decoderA(ZA, ZS_infB))  # given XB
                tiles.append(_as_3ch(XA_synth))
            _save_grid(tiles, 'B2A')

        self.set_mode(train=True)
Пример #4
0
    def train(self):
        """Run the training loop from the loaded checkpoint up to ``self.max_iter``.

        Builds the train/test ``DIGIT`` data loaders, then for every iteration
        encodes a mini-batch of both modalities, computes KL/capacity and
        reconstruction terms, and takes one optimizer step on
        ``loss_recon_infB + capacity_loss_infB`` (the remaining terms are
        computed for logging only).  At the configured intervals it prints
        and records the losses, saves checkpoints, writes reconstruction,
        traversal and cross-modal synthesis images, and feeds the visdom
        line plots.
        """
        self.set_mode(train=True)

        # prepare dataloader (iterable)
        print('Start loading data...')
        dset = DIGIT('./data', train=True)
        self.data_loader = torch.utils.data.DataLoader(dset, batch_size=self.batch_size, shuffle=True)
        test_dset = DIGIT('./data', train=False)
        self.test_data_loader = torch.utils.data.DataLoader(test_dset, batch_size=self.batch_size, shuffle=True)
        print('test: ', len(test_dset))
        self.N = len(self.data_loader.dataset)
        print('...done')

        # iterators from dataloader
        # NOTE(review): iterator2 is never consumed below; it is kept so the
        # loader/RNG interaction stays exactly as before.
        iterator1 = iter(self.data_loader)
        iterator2 = iter(self.data_loader)

        iter_per_epoch = min(len(iterator1), len(iterator2))

        start_iter = self.ckpt_load_iter + 1
        epoch = int(start_iter / iter_per_epoch)

        # Most recent latent statistics returned by get_stat().  Kept across
        # iterations so the evaluation branch can use them even when the
        # image-saving branch has not run in the same iteration.
        latent_stats = None

        for iteration in range(start_iter, self.max_iter + 1):

            # reset data iterators for each epoch
            if iteration % iter_per_epoch == 0:
                print('==== epoch %d done ====' % epoch)
                epoch += 1
                iterator1 = iter(self.data_loader)
                iterator2 = iter(self.data_loader)

            # ============================================
            #          TRAIN THE VAE (ENC & DEC)
            # ============================================

            # sample a mini-batch (the sample indices are not used here)
            XA, XB, _ = next(iterator1)  # (n x C x H x W)

            if self.use_cuda:
                XA = XA.cuda()
                XB = XB.cuda()

            # zA, zS = encA(xA)
            muA_infA, stdA_infA, logvarA_infA, cate_prob_infA = self.encoderA(XA)

            # zB, zS = encB(xB)
            muB_infB, stdB_infB, logvarB_infB, cate_prob_infB = self.encoderB(XB)

            # zS = encAB(xA,xB) via POE
            cate_prob_POE = torch.exp(
                torch.log(torch.tensor(1 / 10)) + torch.log(cate_prob_infA) + torch.log(cate_prob_infB))

            # kl losses
            # A
            latent_dist_infA = {'cont': (muA_infA, logvarA_infA), 'disc': [cate_prob_infA]}
            (kl_cont_loss_infA, kl_disc_loss_infA, cont_capacity_loss_infA, disc_capacity_loss_infA) = kl_loss_function(
                self.use_cuda, iteration, latent_dist_infA)

            loss_kl_infA = kl_cont_loss_infA + kl_disc_loss_infA
            capacity_loss_infA = cont_capacity_loss_infA + disc_capacity_loss_infA

            # B
            latent_dist_infB = {'cont': (muB_infB, logvarB_infB), 'disc': [cate_prob_infB]}
            (kl_cont_loss_infB, kl_disc_loss_infB, cont_capacity_loss_infB, disc_capacity_loss_infB) = kl_loss_function(
                self.use_cuda, iteration, latent_dist_infB, cont_capacity=[0.0, 5.0, 50000, 100.0] , disc_capacity=[0.0, 10.0, 50000, 100.0])

            loss_kl_infB = kl_cont_loss_infB + kl_disc_loss_infB
            capacity_loss_infB = cont_capacity_loss_infB + disc_capacity_loss_infB

            # only the capacity term of modality B is optimized
            loss_capa = capacity_loss_infB

            # encoder samples (for training)
            ZA_infA = sample_gaussian(self.use_cuda, muA_infA, stdA_infA)
            ZB_infB = sample_gaussian(self.use_cuda, muB_infB, stdB_infB)
            ZS_POE = sample_gumbel_softmax(self.use_cuda, cate_prob_POE)

            # encoder samples (for cross-modal prediction)
            ZS_infA = sample_gumbel_softmax(self.use_cuda, cate_prob_infA)
            ZS_infB = sample_gumbel_softmax(self.use_cuda, cate_prob_infB)

            # reconstructed samples (given joint modal observation)
            XA_POE_recon = self.decoderA(ZA_infA, ZS_POE)
            XB_POE_recon = self.decoderB(ZB_infB, ZS_POE)

            # reconstructed samples (given single modal observation)
            XA_infA_recon = self.decoderA(ZA_infA, ZS_infA)
            XB_infB_recon = self.decoderB(ZB_infB, ZS_infB)

            loss_recon_infA = reconstruction_loss(XA, torch.sigmoid(XA_infA_recon), distribution="bernoulli")
            loss_recon_infB = reconstruction_loss(XB, torch.sigmoid(XB_infB_recon), distribution="bernoulli")
            loss_recon_POE = \
                F.l1_loss(torch.sigmoid(XA_POE_recon), XA, reduction='sum').div(XA.size(0)) + \
                F.l1_loss(torch.sigmoid(XB_POE_recon), XB, reduction='sum').div(XB.size(0))

            # only the single-modal reconstruction of B is optimized
            loss_recon = loss_recon_infB

            # total loss for vae
            vae_loss = loss_recon + loss_capa

            # update vae
            self.optim_vae.zero_grad()
            vae_loss.backward()
            self.optim_vae.step()

            # print the losses
            if iteration % self.print_iter == 0:
                prn_str = ( \
                                      '[iter %d (epoch %d)] vae_loss: %.3f ' + \
                                      '(recon: %.3f, capa: %.3f)\n' + \
                                      '    rec_infA = %.3f, rec_infB = %.3f, rec_POE = %.3f\n' + \
                                      '    kl_infA = %.3f, kl_infB = %.3f' + \
                                      '    cont_capacity_loss_infA = %.3f, disc_capacity_loss_infA = %.3f\n' + \
                                      '    cont_capacity_loss_infB = %.3f, disc_capacity_loss_infB = %.3f\n'
                          ) % \
                          (iteration, epoch,
                           vae_loss.item(), loss_recon.item(), loss_capa.item(),
                           # rec_POE reports the joint (POE) reconstruction
                           # loss, not the optimized loss_recon
                           loss_recon_infA.item(), loss_recon_infB.item(), loss_recon_POE.item(),
                           loss_kl_infA.item(), loss_kl_infB.item(),
                           cont_capacity_loss_infA.item(), disc_capacity_loss_infA.item(),
                           cont_capacity_loss_infB.item(), disc_capacity_loss_infB.item(),
                           )
                print(prn_str)
                if self.record_file:
                    # context manager closes the file even if the write fails
                    with open(self.record_file, 'a') as record:
                        record.write('%s\n' % (prn_str,))

            # save model parameters
            if iteration % self.ckpt_save_iter == 0:
                self.save_checkpoint(iteration)

            # save output images (recon, synth, etc.)
            if iteration % self.output_save_iter == 0:
                # 1) save the recon images
                self.save_recon(iteration)

                latent_stats = self.get_stat()
                z_A, z_B, z_S = latent_stats

                # 2) cross-modal synthesis on training data is disabled:
                # self.save_synth_cross_modal(iteration, z_A, z_B, howmany=3)

                # 3) save the latent traversed images
                self.save_traverseB(iteration, z_A, z_B, z_S)

            if iteration % self.eval_metrics_iter == 0:
                # The statistics used to be bound only by the branch above,
                # which raised a NameError when this branch fired first.
                # Compute them on demand; otherwise reuse the latest ones.
                if latent_stats is None:
                    latent_stats = self.get_stat()
                z_A, z_B, _ = latent_stats
                self.save_synth_cross_modal(iteration, z_A, z_B, train=False, howmany=3)

            # (visdom) insert current line stats
            if self.viz_on and (iteration % self.viz_ll_iter == 0):
                self.line_gather.insert(iter=iteration,
                                        recon_both=loss_recon_POE.item(),
                                        recon_A=loss_recon_infA.item(),
                                        recon_B=loss_recon_infB.item(),
                                        kl_A=loss_kl_infA.item(),
                                        kl_B=loss_kl_infB.item(),
                                        cont_capacity_loss_infA=cont_capacity_loss_infA.item(),
                                        disc_capacity_loss_infA=disc_capacity_loss_infA.item(),
                                        cont_capacity_loss_infB=cont_capacity_loss_infB.item(),
                                        disc_capacity_loss_infB=disc_capacity_loss_infB.item()
                                        )

            # (visdom) visualize line stats (then flush out)
            if self.viz_on and (iteration % self.viz_la_iter == 0):
                self.visualize_line()
                self.line_gather.flush()