Ejemplo n.º 1
0
def visualize_traverse(VAE,
                       data_loader,
                       limit=3,
                       inter=2 / 3,
                       loc=-1,
                       z_dim=64,
                       output_dir='traverse_result',
                       index_feature=(2, 20, 22, 26, 41)):
    """Traverse selected latent dimensions of one image and save frames/GIF.

    Item 0 of ``data_loader.dataset`` (tuple element ``[1]``) is encoded and
    its first ``z_dim`` encoder outputs form the latent code. Each dimension
    in ``index_feature`` (or only ``loc`` when ``loc != -1``) is set in turn
    to every value of ``torch.arange(-limit, limit + 0.1, inter)``, decoded,
    passed through a sigmoid and added to the input image. One JPEG grid per
    traversal step plus an animated GIF are written under ``<output_dir>/10``.

    Args:
        VAE: model exposing ``encode`` and ``decode``.
        data_loader: loader whose ``dataset`` is indexed via ``__getitem__``.
        limit: half-width of the traversal range.
        inter: step between consecutive traversal values.
        loc: latent index to traverse exclusively, or -1 to traverse every
            entry of ``index_feature``.
        z_dim: number of leading encoder outputs kept as the latent code.
        output_dir: root directory for the saved images.
        index_feature: latent indices to traverse.

    Raises:
        ValueError: if ``loc`` is not -1 and not one of ``index_feature``.
    """
    # Validate before doing any work: previously a ``loc`` outside
    # ``index_feature`` crashed in ``torch.cat`` and any valid ``loc`` crashed
    # in ``view`` because the frame count no longer matched.
    rows = [row for row in index_feature if loc == -1 or row == loc]
    if not rows:
        raise ValueError('loc={} is not one of the traversed features {}'
                         .format(loc, list(index_feature)))

    decoder = VAE.decode
    encoder = VAE.encode
    interpolation = torch.arange(-limit, limit + 0.1, inter)
    global_iter = 10

    fixed_idx = 0
    fixed_img = data_loader.dataset.__getitem__(fixed_idx)[1]
    fixed_img = fixed_img.to('cpu').unsqueeze(0)
    fixed_img_z = encoder(fixed_img)[:, :z_dim]

    Z = {'fixed_img': fixed_img_z}
    gifs = []
    for z_ori in Z.values():
        for row in rows:
            z = z_ori.clone()
            for val in interpolation:
                # The 4-index assignment requires a latent with >= 4 dims.
                z[:, row, :, :] = val
                sample = torch.sigmoid(decoder(z)).data
                # Overlay the decoded frame on the input image.
                gifs.append(fixed_img + sample)

    output_dir = os.path.join(output_dir, str(global_iter))
    mkdirs(output_dir)
    gifs = torch.cat(gifs)
    print("gif size is {}".format(gifs.shape))
    # Regroup frames as (key, step, traversed row, C, H, W). Assumes decoded
    # frames are single-channel 256x256 -- TODO confirm for other models.
    gifs = gifs.view(len(Z), len(rows), len(interpolation), 1, 256,
                     256).transpose(1, 2)

    for i, key in enumerate(Z):
        for j in range(len(interpolation)):
            save_image(tensor=gifs[i][j].cpu(),
                       filename=os.path.join(output_dir,
                                             '{}_{}.jpg'.format(key, j)),
                       nrow=z_dim,
                       pad_value=1)

        grid2gif(str(os.path.join(output_dir, key + '*.jpg')),
                 str(os.path.join(output_dir, key + '.gif')),
                 delay=10)
Ejemplo n.º 2
0
    def visualize_traverse(self, limit=3, inter=2 / 3, loc=-1):
        """Show latent traversals in visdom and optionally save frames/GIFs.

        Latent codes are built for a dataset-specific set of fixed images plus
        dataset item 0 (tuple element ``[1]``). For datasets without a
        fixed-image table, item 0 (element ``[0]``) and a uniform random code
        are used instead. Each latent dimension (or only ``loc``) is swept
        over ``torch.arange(-limit, limit + 0.1, inter)``. The sigmoid of every
        decoded frame is shown in visdom and, if ``self.output_save``, saved
        as JPEG grids plus a GIF per code. The network is put in eval mode for
        the duration and switched back to train mode afterwards, even if an
        error occurs.

        Args:
            limit: half-width of the traversal range.
            inter: step between consecutive traversal values.
            loc: latent index to traverse exclusively, or -1 for all.

        Raises:
            ValueError: if ``loc`` is not -1 and not in ``range(self.z_dim)``.
        """
        # Previously any loc != -1 crashed in ``view`` (frame count mismatch).
        rows = [row for row in range(self.z_dim) if loc == -1 or row == loc]
        if not rows:
            raise ValueError('loc={} is outside range(z_dim={})'.format(
                loc, self.z_dim))

        # dataset name -> (key for the item-0 code, [(key, dataset index)])
        fixed_tables = {
            'dsprites': ('random_img', [
                ('fixed_square', 87040),  # square
                ('fixed_ellipse', 332800),  # ellipse
                ('fixed_heart', 578560),  # heart
            ]),
            'celeba': ('random', [
                ('fixed_1', 70000),  # 'CelebA/img_align_celeba/191282.jpg'
                ('fixed_2', 143307),  # 'CelebA/img_align_celeba/143308.jpg'
                ('fixed_3', 101535),  # 'CelebA/img_align_celeba/101536.jpg'
                ('fixed_4', 70059),  # 'CelebA/img_align_celeba/070060.jpg'
            ]),
            '3dchairs': ('random', [
                ('fixed_1', 40919),  # 3DChairs/images/4682_image_052_p030_t232_r096.png
                ('fixed_2', 5172),  # 3DChairs/images/14657_image_020_p020_t232_r096.png
                ('fixed_3', 22330),  # 3DChairs/images/30099_image_052_p030_t232_r096.png
            ]),
        }

        decoder = self.VAE.decode
        encoder = self.VAE.encode

        def encode_item(idx, part=0):
            """Encode element ``part`` of dataset item ``idx`` (batch of 1)."""
            img = self.data_loader.dataset.__getitem__(idx)[part]
            img = img.to(self.device).unsqueeze(0)
            return encoder(img)[:, :self.z_dim]

        self.net_mode(train=False)
        try:
            interpolation = torch.arange(-limit, limit + 0.1, inter)
            random_img_z = encode_item(0, part=1)

            dataset = self.dataset.lower()
            if dataset in fixed_tables:
                random_key, fixed = fixed_tables[dataset]
                Z = {key: encode_item(idx) for key, idx in fixed}
                Z[random_key] = random_img_z
            else:
                Z = {'fixed_img': encode_item(0)}
                Z['random_img'] = random_img_z
                Z['random_z'] = torch.rand(1, self.z_dim, 1, 1,
                                           device=self.device)

            gifs = []
            for key, z_ori in Z.items():
                samples = []
                for row in rows:
                    z = z_ori.clone()
                    for val in interpolation:
                        z[:, row] = val
                        sample = torch.sigmoid(decoder(z)).data
                        samples.append(sample)
                        gifs.append(sample)
                samples = torch.cat(samples, dim=0).cpu()
                title = '{}_latent_traversal(iter:{})'.format(
                    key, self.global_iter)
                self.viz.images(samples,
                                env=self.name + '/traverse',
                                opts=dict(title=title),
                                nrow=len(interpolation))

            if self.output_save:
                output_dir = os.path.join(self.output_dir,
                                          str(self.global_iter))
                mkdirs(output_dir)
                gifs = torch.cat(gifs)
                # Regroup as (key, step, traversed row, C, 64, 64).
                gifs = gifs.view(len(Z), len(rows), len(interpolation),
                                 self.nc, 64, 64).transpose(1, 2)
                for i, key in enumerate(Z):
                    for j in range(len(interpolation)):
                        save_image(tensor=gifs[i][j].cpu(),
                                   filename=os.path.join(
                                       output_dir, '{}_{}.jpg'.format(key, j)),
                                   nrow=self.z_dim,
                                   pad_value=1)

                    grid2gif(str(os.path.join(output_dir, key + '*.jpg')),
                             str(os.path.join(output_dir, key + '.gif')),
                             delay=10)
        finally:
            self.net_mode(train=True)
Ejemplo n.º 3
0
    def save_traverse(self, iters, limb=-3, limu=3, inter=2 / 3, loc=-1):
        """Save latent-traversal frames and GIFs for fixed and random images.

        For the configured dataset, three fixed dataset items and one item
        drawn with ``np.random.randint(self.N)`` are encoded; the first encoder
        output is used as the latent code. Each latent dimension (or only
        ``loc``) is swept over ``torch.arange(limb, limu + 0.001, inter)``.
        For every step a JPEG grid is written with the input image leftmost,
        followed by one decoded frame per traversed dimension. An animated GIF
        per image is written to ``<self.output_dir_trvsl>/<iters>``. The
        networks are put in eval mode during the traversal and switched back
        to train mode afterwards, even if an error occurs.

        Args:
            iters: iteration number, used as the output sub-directory name.
            limb: lower end of the traversal range.
            limu: upper end of the traversal range.
            inter: step between consecutive traversal values.
            loc: latent index to traverse exclusively, or -1 for all.

        Raises:
            ValueError: if ``loc`` is not -1 and not in ``range(self.z_dim)``.
            NotImplementedError: if ``self.dataset`` has no fixed-image table.
        """
        # Previously any loc != -1 crashed in ``view`` (frame count mismatch).
        rows = [row for row in range(self.z_dim) if loc == -1 or row == loc]
        if not rows:
            raise ValueError('loc={} is outside range(z_dim={})'.format(
                loc, self.z_dim))

        # dataset name -> [(key, dataset index)]; the random image is added
        # afterwards under 'random_img'.
        # NOTE: a 3dchairs table (indices 40919, 5172, 22330) was prototyped
        # but is not enabled.
        fixed_tables = {
            'dsprites': [('fixed_square', 87040),  # square
                         ('fixed_ellipse', 332800),  # ellipse
                         ('fixed_heart', 578560)],  # heart
            'oval_dsprites': [('fixed1', 87040),  # oval1
                              ('fixed2', 220045),  # oval2
                              ('fixed3', 178560)],  # oval3
            '3dfaces': [('fixed1', 6245), ('fixed2', 10205),
                        ('fixed3', 68560)],
            'celeba': [('fixed1', 191281), ('fixed2', 143307),
                       ('fixed3', 101535)],
            'edinburgh_teapots': [('fixed1', 7040), ('fixed2', 32800),
                                  ('fixed3', 78560)],
        }
        dataset = self.dataset.lower()
        # Fail before touching the model mode or the RNG.
        if dataset not in fixed_tables:
            raise NotImplementedError(
                'traversal not configured for dataset {!r}'.format(
                    self.dataset))

        encoder = self.encoder
        decoder = self.decoder

        def encode_item(idx):
            """Load dataset item ``idx`` as a batch of 1; return (img, z_mu)."""
            img = self.data_loader.dataset.__getitem__(idx)[0]
            if self.use_cuda:
                img = img.cuda()
            img = img.unsqueeze(0)
            z_mu, _, _ = encoder(img)
            return img, z_mu

        self.set_mode(train=False)
        try:
            interpolation = torch.arange(limb, limu + 0.001, inter)

            random_img, random_img_zmu = encode_item(
                np.random.randint(self.N))
            IMG, Z = {}, {}
            for key, idx in fixed_tables[dataset]:
                IMG[key], Z[key] = encode_item(idx)
            IMG['random_img'] = random_img
            Z['random_img'] = random_img_zmu

            # do traversal and collect generated images
            gifs = []
            for z_ori in Z.values():
                for row in rows:
                    z = z_ori.clone()
                    for val in interpolation:
                        z[:, row] = val
                        gifs.append(torch.sigmoid(decoder(z)).data)

            # save the generated files, also the animated gifs
            out_dir = os.path.join(self.output_dir_trvsl, str(iters))
            mkdirs(self.output_dir_trvsl)
            mkdirs(out_dir)
            gifs = torch.cat(gifs)
            # Regroup as (key, step, traversed row, C, 64, 64).
            gifs = gifs.view(len(Z), len(rows), len(interpolation), self.nc,
                             64, 64).transpose(1, 2)
            for i, key in enumerate(Z):
                for j in range(len(interpolation)):
                    # input image leftmost, then one frame per traversed dim
                    frame = torch.cat([IMG[key], gifs[i][j]], dim=0)
                    save_image(tensor=frame.cpu(),
                               filename=os.path.join(out_dir,
                                                     '%s_%03d.jpg' % (key, j)),
                               nrow=1 + self.z_dim,
                               pad_value=1)
                # make animated gif
                grid2gif(out_dir,
                         key,
                         str(os.path.join(out_dir, key + '.gif')),
                         delay=10)
        finally:
            self.set_mode(train=True)
Ejemplo n.º 4
0
    def save_refined_traverse(self, iters):
        """Save hand-tuned latent traversals for preselected images.

        Each configured latent dimension is swept linearly between two
        entries of the grid ``torch.arange(-16.0, 16.001, 0.2)`` (the same grid
        as in "solver_test.py"). The grid index pairs were picked per image.

        * ``celeba``: for every configured (image, dimension) an 11-frame
          strip is saved as ``all_<id>_z<dim+1>.jpg``.
        * other configured datasets: every latent dimension of one image is
          swept in 30 steps (unconfigured dimensions use ``[-3, 3]``). Per-step
          grids with and without the input image, plus two GIFs, are saved.

        Output goes to ``<self.output_dir_trvsl>/<iters>``.

        Args:
            iters: iteration number, used as the output sub-directory name.

        Raises:
            NotImplementedError: if ``self.dataset`` has no configuration.
        """
        encoder = self.encoder
        decoder = self.decoder

        grid = torch.arange(-16.0, 16.001, 0.2).detach().numpy()

        # dataset -> [(image ID, {latent dim: (lo grid idx, hi grid idx)})]
        # Alternative CelebA settings (other dims and image IDs) were tried
        # earlier and removed as dead code.
        refined = {
            'oval_dsprites': [
                (87040, {0: (58, 98), 1: (72, 98), 4: (58, 97),
                         5: (62, 95), 8: (68, 93)}),
            ],
            '3dfaces': [
                # e.g. vary z_0 from grid[76] to grid[87]
                (6245, {0: (76, 87), 3: (68, 90), 5: (79, 82), 7: (58, 80)}),
            ],
            'edinburgh_teapots': [
                (46203, {2: (72, 82), 3: (60, 100), 4: (68, 96),
                         6: (74, 98), 7: (73, 91), 8: (75, 94),
                         9: (64, 81)}),
            ],
            # Dims written as (1-based - 1) to match the z%02d file names.
            'celeba': [
                (4195, {2 - 1: (98, 120), 7 - 1: (56, 89),
                        12 - 1: (72, 112), 20 - 1: (40, 84)}),
                (95070, {2 - 1: (72, 119), 7 - 1: (46, 75),
                         12 - 1: (73, 115), 20 - 1: (17, 56)}),
            ],
        }
        dataset = self.dataset.lower()
        if dataset not in refined:
            raise NotImplementedError(
                'refined traversal not configured for dataset {!r}'.format(
                    self.dataset))

        def grid_span(pair, num_vars):
            """Return ``num_vars`` float32 values from grid[lo] to grid[hi]."""
            lo, hi = pair
            return torch.tensor(np.linspace(grid[lo], grid[hi], num_vars),
                                dtype=torch.float32)

        def encode(idx):
            """Load dataset item ``idx`` as a batch of 1; return (img, z_mu)."""
            img = self.data_loader.dataset.__getitem__(idx)[0]
            if self.use_cuda:
                img = img.cuda()
            img = img.unsqueeze(0)
            z_mu, _, _ = encoder(img)
            return img, z_mu

        def sweep(z_ori, dim, values):
            """Decode ``z_ori`` with dimension ``dim`` set to each value."""
            z = z_ori.clone()
            frames = []
            for val in values:
                z[:, dim] = val
                frames.append(torch.sigmoid(decoder(z)).data)
            return frames

        out_dir = os.path.join(self.output_dir_trvsl, str(iters))

        if dataset == 'celeba':
            num_vars = 11
            # Created once here rather than once per (image, dimension).
            mkdirs(self.output_dir_trvsl)
            mkdirs(out_dir)
            for idx, ranges in refined[dataset]:
                _, z_ori = encode(idx)
                for dim, pair in ranges.items():
                    strip = torch.cat(sweep(z_ori, dim,
                                            grid_span(pair, num_vars)))
                    strip = strip.view(num_vars, self.nc, 64, 64)
                    save_image(tensor=strip.cpu(),
                               filename=os.path.join(
                                   out_dir,
                                   'all_%d_z%02d.jpg' % (idx, dim + 1)),
                               nrow=num_vars,
                               pad_value=1)
            return

        num_vars = 30  # number of variations in each dim
        idx, ranges = refined[dataset][0]
        default_span = torch.tensor(np.linspace(-3.0, 3.0, num_vars),
                                    dtype=torch.float32)
        interpolation = [default_span] * self.z_dim
        for dim, pair in ranges.items():
            interpolation[dim] = grid_span(pair, num_vars)

        img, z_ori = encode(idx)

        # do traversal and collect generated images
        gifs = []
        for row in range(self.z_dim):
            gifs.extend(sweep(z_ori, row, interpolation[row]))

        # save the generated files, also the animated gifs
        mkdirs(self.output_dir_trvsl)
        mkdirs(out_dir)
        gifs = torch.cat(gifs)
        # Regroup as (1, step, latent dim, C, 64, 64).
        gifs = gifs.view(1, self.z_dim, num_vars, self.nc, 64,
                         64).transpose(1, 2)
        for j in range(num_vars):
            traversal = gifs[0][j]  # no leftmost input image
            with_input = torch.cat([img, traversal], dim=0)  # input leftmost
            save_image(tensor=with_input.cpu(),
                       filename=os.path.join(out_dir,
                                             '%d_%03d.jpg' % (idx, j)),
                       nrow=1 + self.z_dim,
                       pad_value=1)
            save_image(tensor=traversal.cpu(),
                       filename=os.path.join(out_dir,
                                             'nox_%d_%03d.jpg' % (idx, j)),
                       nrow=self.z_dim,
                       pad_value=1)

        # make animated gifs
        grid2gif(out_dir,
                 str(idx),
                 str(os.path.join(out_dir, str(idx) + '.gif')),
                 delay=10)
        grid2gif(out_dir,
                 'nox_' + str(idx),
                 str(os.path.join(out_dir, 'nox_' + str(idx) + '.gif')),
                 delay=10)
Ejemplo n.º 5
0
    def viz_traverse(self, limit=3, inter=2 / 3, loc=-1):
        """Show latent traversals in visdom and optionally save frames/GIFs.

        Latent codes come from a dataset item at an index drawn from
        ``random.randint(1, len(dataset) - 1)``. For ``dsprites`` (any case),
        three fixed shape images are added. Otherwise item 0 and a uniform
        random code are added. Each latent dimension (or only ``loc``) is
        swept over ``torch.arange(-limit, limit + 0.1, inter)``. Results go to
        visdom when ``self.viz_on`` and to JPEG grids/GIFs when
        ``self.save_output``. The network is put in eval mode and switched
        back to train mode afterwards, even if an error occurs.

        Args:
            limit: half-width of the traversal range.
            inter: step between consecutive traversal values.
            loc: latent index to traverse exclusively, or -1 for all.

        Raises:
            ValueError: if ``loc`` is not -1 and not in ``range(self.z_dim)``.
        """
        import random

        # Previously any loc != -1 crashed in ``view`` (frame count mismatch).
        rows = [row for row in range(self.z_dim) if loc == -1 or row == loc]
        if not rows:
            raise ValueError('loc={} is outside range(z_dim={})'.format(
                loc, self.z_dim))

        decoder = self.net.decoder
        encoder = self.net.encoder

        def encode_item(idx):
            """Encode dataset item ``idx`` as a batch of one image."""
            img = self.data_loader.dataset.__getitem__(idx)
            img = Variable(cuda(img, self.use_cuda),
                           volatile=True).unsqueeze(0)
            return encoder(img)[:, :self.z_dim]

        self.net_mode(train=False)
        try:
            interpolation = torch.arange(-limit, limit + 0.1, inter)

            n_dsets = len(self.data_loader.dataset)
            random_img_z = encode_item(random.randint(1, n_dsets - 1))
            random_z = Variable(cuda(torch.rand(1, self.z_dim), self.use_cuda),
                                volatile=True)

            # Case-insensitive match on the dataset name.
            if self.dataset.lower() == 'dsprites':
                Z = {
                    'fixed_square': encode_item(87040),  # square
                    'fixed_ellipse': encode_item(332800),  # ellipse
                    'fixed_heart': encode_item(578560),  # heart
                    'random_img': random_img_z
                }
            else:
                Z = {
                    'fixed_img': encode_item(0),
                    'random_img': random_img_z,
                    'random_z': random_z
                }

            gifs = []
            for key, z_ori in Z.items():
                samples = []
                for row in rows:
                    z = z_ori.clone()
                    for val in interpolation:
                        z[:, row] = val
                        sample = F.sigmoid(decoder(z)).data
                        samples.append(sample)
                        gifs.append(sample)
                samples = torch.cat(samples, dim=0).cpu()
                title = '{}_latent_traversal(iter:{})'.format(
                    key, self.global_iter)

                if self.viz_on:
                    self.viz.images(samples,
                                    env=self.viz_name + '_traverse',
                                    opts=dict(title=title),
                                    nrow=len(interpolation))

            if self.save_output:
                output_dir = os.path.join(self.output_dir,
                                          str(self.global_iter))
                os.makedirs(output_dir, exist_ok=True)
                gifs = torch.cat(gifs)
                # Regroup as (key, step, traversed row, C, 64, 64).
                gifs = gifs.view(len(Z), len(rows), len(interpolation),
                                 self.nc, 64, 64).transpose(1, 2)
                for i, key in enumerate(Z):
                    for j in range(len(interpolation)):
                        save_image(tensor=gifs[i][j].cpu(),
                                   filename=os.path.join(
                                       output_dir, '{}_{}.jpg'.format(key, j)),
                                   nrow=self.z_dim,
                                   pad_value=1)

                    grid2gif(os.path.join(output_dir, key + '*.jpg'),
                             os.path.join(output_dir, key + '.gif'),
                             delay=10)
        finally:
            self.net_mode(train=True)
Ejemplo n.º 6
0
    def save_traverse_new(self,
                          iters,
                          num_reps,
                          limb=-3,
                          limu=3,
                          inter=2 / 3,
                          loc=-1):
        """Save latent traversals of the two-modality (A/B) model.

        For every fixed and random dataset item, the posterior mean is
        obtained by combining both encoders with ``apply_poe``.  Each
        traversed latent dimension is then swept over
        ``torch.arange(limb, limu + 0.001, inter)``.  Both decoders are run,
        and their sigmoid outputs are concatenated along dim 2.  One JPEG
        grid per traversal step and one animated GIF per item are written
        to ``<output_dir_trvsl>/<iters>/``.

        Args:
            iters: Iteration number, used as the output sub-directory name.
            num_reps: Number of random dataset items traversed in addition
                to the fixed ones.
            limb: Lower bound of the traversal values.
            limu: Upper bound of the traversal values.
            inter: Step between consecutive traversal values.
            loc: If not -1, traverse only latent dimension ``loc``.

        Raises:
            NotImplementedError: If ``self.dataset`` is not
                ``'idaz_elli_3df'``.
            ValueError: If ``loc`` selects no latent dimension.
        """
        encoderA = self.encoderA
        encoderB = self.encoderB
        decoderA = self.decoderA
        decoderB = self.decoderB
        interpolation = torch.arange(limb, limu + 0.001, inter)

        # Latent dimensions actually traversed.  The reshape below is sized
        # by this count, so loc != -1 works (sizing by self.z_dim made the
        # view fail whenever a single dimension was selected).
        rows = [r for r in range(self.z_dim) if loc == -1 or r == loc]
        if not rows:
            raise ValueError(
                'loc must be -1 or a latent index in [0, {}), got {}'.format(
                    self.z_dim, loc))

        def encode_item(idx):
            # Load the (A, B) pair of dataset item ``idx`` and add a batch
            # dim.  Return both images plus their PoE posterior mean.
            XA, XB = self.data_loader.dataset.__getitem__(idx)[0:2]
            if self.use_cuda:
                XA = XA.cuda()
                XB = XB.cuda()
            XA = XA.unsqueeze(0)
            XB = XB.unsqueeze(0)
            mu_infA, _, logvar_infA = encoderA(XA)
            mu_infB, _, logvar_infB = encoderB(XB)
            zmu, _, _ = apply_poe(self.use_cuda, mu_infA, logvar_infA,
                                  mu_infB, logvar_infB)
            return XA, XB, zmu

        # A private generator seeded with 123 yields the same IDs as
        # np.random.seed(123) followed by np.random.randint.  Unlike that
        # pair, it does not reset numpy's global RNG for the rest of
        # training.
        rii = np.random.RandomState(123).randint(self.N, size=num_reps)
        #--#
        prn_str = '(TRAVERSAL) random image IDs = {}'.format(rii)
        print(prn_str)
        self.dump_to_record(prn_str)
        #--#

        # Pure inference: no autograd graph is needed for the traversals.
        with torch.no_grad():
            random_items = [encode_item(i2) for i2 in rii]

            if self.dataset.lower() == 'idaz_elli_3df':
                fixed_idxs = [10306, 7246, 21440]
                fixed_items = [encode_item(idx) for idx in fixed_idxs]
            else:
                raise NotImplementedError

            # Reference images show A stacked on top of B (dim 2), matching
            # the layout of the generated samples below.
            IMG = {}
            Z = {}
            for i, (XA, XB, zmu) in enumerate(fixed_items):
                IMG['fixed' + str(i)] = torch.cat([XA, XB], dim=2)
                Z['fixed' + str(i)] = zmu
            for i, (XA, XB, zmu) in enumerate(random_items):
                IMG['random' + str(i)] = torch.cat([XA, XB], dim=2)
                Z['random' + str(i)] = zmu

            # Blank (all-ones) spacer tile shown right after the reference.
            WS = torch.ones(IMG['fixed1'].shape)
            if self.use_cuda:
                WS = WS.cuda()

            # For each item: one spacer per traversal step, then
            # len(interpolation) samples for every traversed dimension.
            gifs = []
            for key in Z:
                z_ori = Z[key]
                gifs.extend([WS] * len(interpolation))
                for row in rows:
                    z = z_ori.clone()
                    for val in interpolation:
                        z[:, row] = val
                        sampleA = torch.sigmoid(decoderA(z))
                        sampleB = torch.sigmoid(decoderB(z))
                        gifs.append(torch.cat([sampleA, sampleB], dim=2))

        # Save the generated frames, plus an animated GIF per item.
        out_dir = os.path.join(self.output_dir_trvsl, str(iters))
        mkdirs(self.output_dir_trvsl)
        mkdirs(out_dir)
        gifs = torch.cat(gifs)
        # (item, 1 + dims, step, C, H, W) -> (item, step, 1 + dims, C, H, W)
        gifs = gifs.view(len(Z), 1 + len(rows), len(interpolation), self.nc,
                         2 * 64, 64).transpose(1, 2)
        for i, key in enumerate(Z):
            for j in range(len(interpolation)):
                I = torch.cat([IMG[key], gifs[i][j]], dim=0)
                save_image(tensor=I.cpu(),
                           filename=os.path.join(out_dir,
                                                 '%s_%03d.jpg' % (key, j)),
                           nrow=1 + 1 + len(rows),
                           pad_value=1)
            # make animated gif
            grid2gif(out_dir,
                     key,
                     str(os.path.join(out_dir, key + '.gif')),
                     delay=10)
Ejemplo n.º 7
0
    def visualize_traverse(self, limit=3, inter=2 / 3, loc=-1):
        """Save latent traversals for four fixed dataset items.

        Items 100, 33, 78 and 0 are encoded, using element ``[1]`` of each
        dataset entry, and then reparameterized.  Each traversed latent
        dimension is swept over ``torch.arange(-limit, limit + 0.1, inter)``
        and decoded.  When ``self.output_save`` is set, one JPEG grid per
        step and one GIF per item are written to
        ``<model_save_dir>/<num_epochs>/``.

        The model is put in eval mode for the duration.  It is always
        switched back to training mode afterwards, even on error.

        Args:
            limit: Half-width of the traversal range.
            inter: Step between consecutive traversal values.
            loc: If not -1, traverse only latent dimension ``loc``.

        Raises:
            ValueError: If ``loc`` selects no latent dimension.
        """
        # Latent dimensions actually traversed; the reshape below uses this
        # count so that loc != -1 does not break the view.
        rows = [r for r in range(self.latent_dim) if loc == -1 or r == loc]
        if not rows:
            raise ValueError(
                'loc must be -1 or a latent index in [0, {}), got {}'.format(
                    self.latent_dim, loc))

        self.model.eval()
        try:
            decoder = self.model.decode
            encoder = self.model.encode
            reparametrize = self.model.reparameterize
            interpolation = torch.arange(-limit, limit + 0.1, inter)

            # (output key, dataset index).  Each key's suffix matches its
            # source (random_img_1 <- item 100, random_img_2 <- item 33).
            sources = (('random_img_1', 100), ('random_img_2', 33),
                       ('random_img_3', 78), ('random_img', 0))

            with torch.no_grad():
                Z = {}
                for key, idx in sources:
                    img = self.dataloader.dataset.__getitem__(idx)[1]
                    img = img.to(self.device).unsqueeze(0)
                    mu, logvar = encoder(img)
                    Z[key] = reparametrize(mu, logvar)

                gifs = []
                for z_ori in Z.values():
                    for row in rows:
                        z = z_ori.clone()
                        for val in interpolation:
                            z[:, row] = val
                            gifs.append(torch.sigmoid(decoder(z)))

            if self.output_save:
                output_dir = os.path.join(self.model_save_dir,
                                          str(self.num_epochs))
                mkdirs(output_dir)
                # (item, dim, step, 3, 8, 8) -> (item, step, dim, 3, 8, 8)
                gifs = torch.cat(gifs).view(len(Z), len(rows),
                                            len(interpolation), 3, 8,
                                            8).transpose(1, 2)
                for i, key in enumerate(Z):
                    for j in range(len(interpolation)):
                        save_image(tensor=gifs[i][j].cpu(),
                                   filename=os.path.join(
                                       output_dir, '{}_{}.jpg'.format(key, j)),
                                   nrow=len(rows),
                                   pad_value=1)

                    grid2gif(str(os.path.join(output_dir, key + '*.jpg')),
                             str(os.path.join(output_dir, key + '.gif')),
                             delay=10)
        finally:
            # nn.Module.train() toggles the mode.  Calling the module itself
            # (self.model(train=True)) would run a forward pass instead.
            self.model.train()
Ejemplo n.º 8
0
    def viz_traverse(self,
                     limit=3,
                     inter=2.0 / 3,
                     loc=-1):
        """Traverse a label-conditioned random code and visualise/save it.

        The code is ``[torch.rand(1, 10) | one-hot label]``, with the label
        fixed to index 4 of 10.  Each traversed latent dimension is swept
        over ``torch.arange(-limit, limit + 0.1, inter)`` and decoded with
        ``self.net.decoder``.  Results are sent to visdom when
        ``self.viz_on`` and written to ``<output_dir>/<global_iter>/`` when
        ``self.save_output``.

        Args:
            limit: Half-width of the traversal range.
            inter: Step between consecutive traversal values.
            loc: If not -1, traverse only latent dimension ``loc``.

        Raises:
            ValueError: If ``loc`` selects no latent dimension.
        """
        # Latent dimensions actually traversed; the reshape below uses this
        # count so that loc != -1 does not break the view.
        rows = [r for r in range(self.z_dim) if loc == -1 or r == loc]
        if not rows:
            raise ValueError(
                'loc must be -1 or a latent index in [0, {}), got {}'.format(
                    self.z_dim, loc))

        self.net_mode(train=False)

        decoder = self.net.decoder
        interpolation = torch.arange(-limit, limit + 0.1, inter)

        random_label = [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]
        random_label = torch.FloatTensor([random_label])
        random_label = Variable(cuda(random_label, self.use_cuda))

        random_z_feat = Variable(cuda(torch.rand(1, 10), self.use_cuda),
                                 volatile=True)
        # NOTE(review): the code is 10 + 10 = 20 columns wide and the
        # traversal writes columns up to self.z_dim - 1, so this assumes
        # z_dim <= 20.
        random_z = torch.cat([random_z_feat, random_label], 1)
        Z = {'random_z': random_z}

        gifs = []
        for key, z_ori in Z.items():
            samples = []
            for row in rows:
                z = z_ori.clone()
                for val in interpolation:
                    z[:, row] = val
                    sample = F.sigmoid(decoder(z)).data
                    samples.append(sample)
                    gifs.append(sample)

            if self.viz_on:
                title = '{}_latent_traversal(iter:{})'.format(
                    key, self.global_iter)
                self.viz.images(torch.cat(samples, dim=0).cpu(),
                                env=self.viz_name + '_traverse',
                                opts=dict(title=title),
                                nrow=len(interpolation))

        if self.save_output:
            output_dir = os.path.join(self.output_dir, str(self.global_iter))
            # exist_ok: re-running at the same global_iter must not crash.
            os.makedirs(output_dir, exist_ok=True)
            gifs = torch.cat(gifs)
            # (code, dim, step, C, 28, 28) -> (code, step, dim, C, 28, 28)
            gifs = gifs.view(len(Z), len(rows), len(interpolation), self.nc,
                             28, 28).transpose(1, 2)
            for i, key in enumerate(Z):
                for j in range(len(interpolation)):
                    save_image(tensor=gifs[i][j].cpu(),
                               filename=os.path.join(
                                   output_dir, '{}_{}.jpg'.format(key, j)),
                               nrow=len(rows),
                               pad_value=1)

                grid2gif(os.path.join(output_dir, key + '*.jpg'),
                         os.path.join(output_dir, key + '.gif'),
                         delay=10)

        self.net_mode(train=True)
Ejemplo n.º 9
0
    def save_traverse_new(self,
                          iters,
                          num_reps,
                          limb=-3,
                          limu=3,
                          inter=2 / 3,
                          loc=-1):
        """Save latent traversals of fixed and random dataset images.

        Each item is encoded to its posterior mean, which is the first
        output of ``self.encoder``.  Every traversed latent dimension is
        then swept and decoded.  The output, written to
        ``<output_dir_trvsl>/<iters>/``, is one JPEG grid per traversal
        step (reference image first) and one animated GIF per item.

        Args:
            iters: Iteration number, used as the output sub-directory name.
            num_reps: Number of random dataset items traversed in addition
                to the fixed ones.
            limb: Lower bound of the traversal values.
            limu: Upper bound of the traversal values.
            inter: Step between consecutive traversal values.
            loc: If not -1, traverse only latent dimension ``loc``.

        Raises:
            NotImplementedError: If ``self.dataset`` is neither
                ``'dsprites'`` nor ``'oval_dsprites'``.
            ValueError: If ``loc`` selects no latent dimension.
        """
        encoder = self.encoder
        decoder = self.decoder
        interpolation = torch.arange(limb, limu + 0.001, inter)

        # Latent dimensions actually traversed; the reshape below uses this
        # count so that loc != -1 does not break the view.
        rows = [r for r in range(self.z_dim) if loc == -1 or r == loc]
        if not rows:
            raise ValueError(
                'loc must be -1 or a latent index in [0, {}), got {}'.format(
                    self.z_dim, loc))

        # (output key, dataset index) of the fixed reference images.  Other
        # datasets (e.g. celeba, 3dchairs) have no fixed indices here.
        fixed_sets = {
            'dsprites': (('fixed_square', 87040),
                         ('fixed_ellipse', 332800),
                         ('fixed_heart', 578560)),
            'oval_dsprites': (('fixed1', 87040),
                              ('fixed2', 220045),
                              ('fixed3', 178560)),
        }

        def encode_item(idx):
            # Load dataset item ``idx`` and add a batch dim.  Return the
            # image together with its encoder mean.
            img = self.data_loader.dataset.__getitem__(idx)[0]
            if self.use_cuda:
                img = img.cuda()
            img = img.unsqueeze(0)
            zmu, _, _ = encoder(img)
            return img, zmu

        def traversal_values(row):
            # ONLY FOR RFVAE-LEARN-AGAIN: dimensions 0, 7 and 8 use rescaled
            # ranges.  Every variant has len(interpolation) steps, so the
            # grid reshape below stays valid.
            if row == 0:
                return 4.0 * interpolation + 2.0
            if row in (7, 8):
                return 2.0 * interpolation
            return interpolation

        # Pure inference: no autograd graph is needed for the traversals.
        with torch.no_grad():
            rii = np.random.randint(self.N, size=num_reps)
            random_items = [encode_item(i2) for i2 in rii]

            fixed = fixed_sets.get(self.dataset.lower())
            if fixed is None:
                raise NotImplementedError

            IMG = {}
            Z = {}
            for key, idx in fixed:
                IMG[key], Z[key] = encode_item(idx)
            for i, (img, zmu) in enumerate(random_items):
                IMG['random_img' + str(i)] = img
                Z['random_img' + str(i)] = zmu

            # do traversal and collect generated images
            gifs = []
            for z_ori in Z.values():
                for row in rows:
                    z = z_ori.clone()
                    for val in traversal_values(row):
                        z[:, row] = val
                        gifs.append(torch.sigmoid(decoder(z)))

        # save the generated files, also the animated gifs
        out_dir = os.path.join(self.output_dir_trvsl, str(iters))
        mkdirs(self.output_dir_trvsl)
        mkdirs(out_dir)
        gifs = torch.cat(gifs)
        # (item, dim, step, C, 64, 64) -> (item, step, dim, C, 64, 64)
        gifs = gifs.view(len(Z), len(rows), len(interpolation), self.nc, 64,
                         64).transpose(1, 2)
        for i, key in enumerate(Z):
            for j in range(len(interpolation)):
                I = torch.cat([IMG[key], gifs[i][j]], dim=0)
                save_image(tensor=I.cpu(),
                           filename=os.path.join(out_dir,
                                                 '%s_%03d.jpg' % (key, j)),
                           nrow=1 + len(rows),
                           pad_value=1)
            # make animated gif
            grid2gif(
                out_dir,
                key,
                str(os.path.join(out_dir, key + '.gif')),
                duration=0.01  #### ONLY FOR RFVAE-LEARN-AGAIN
            )
Ejemplo n.º 10
0
    def viz_traverse(self, current_iter, limit=3, inter=2 / 3, loc=-1):
        """Traverse latent codes of dataset images and log/save the results.

        Encodes a random dataset item plus fixed items: three dSprites
        shapes for ``'dsprites'``, otherwise item 0 and a ``torch.rand``
        code.  Each traversed latent dimension is swept over
        ``torch.arange(-limit, limit + 0.1, inter)`` and decoded.  When
        ``self.viz_on`` is set, the first 100 frames per code are logged to
        TensorBoard.  When ``self.save_output`` is set, grids and GIFs are
        written to ``self.params['info']``.

        Args:
            current_iter: Global step used for the TensorBoard images.
            limit: Half-width of the traversal range.
            inter: Step between consecutive traversal values.
            loc: If not -1, traverse only latent dimension ``loc``.

        Raises:
            ValueError: If ``loc`` selects no latent dimension.
        """
        import random

        z_dim = self.params.z_dim
        # Latent dimensions actually traversed; the reshape below uses this
        # count so that loc != -1 does not break the view.
        rows = [r for r in range(z_dim) if loc == -1 or r == loc]
        if not rows:
            raise ValueError(
                'loc must be -1 or a latent index in [0, {}), got {}'.format(
                    z_dim, loc))

        decoder = self.net.decoder
        encoder = self.net.encoder
        interpolation = torch.arange(-limit, limit + 0.1, inter)

        def encode_item(idx):
            # Dataset items are (image, label) pairs; only the image is
            # encoded.  The dsprites branch previously skipped this unpack
            # and fed the whole tuple to the encoder.
            img, _ = self.data_loader.dataset.__getitem__(idx)
            img = img.unsqueeze(0).to(device)
            return encoder(img)[:, :z_dim]

        self.net_mode(train=False)
        try:
            # torch.no_grad replaces Variable(volatile=True), which no longer
            # disables autograd in current PyTorch.
            with torch.no_grad():
                n_dsets = len(self.data_loader.dataset)
                random_img_z = encode_item(random.randint(1, n_dsets - 1))
                random_z = torch.rand(1, z_dim).to(device)

                if self.dataset == 'dsprites':
                    Z = {
                        'fixed_square': encode_item(87040),
                        'fixed_ellipse': encode_item(332800),
                        'fixed_heart': encode_item(578560),
                        'random_img': random_img_z
                    }
                else:
                    Z = {
                        'fixed_img': encode_item(0),
                        'random_img': random_img_z,
                        'random_z': random_z
                    }

                gifs = []
                for key, z_ori in Z.items():
                    samples = []
                    for row in rows:
                        z = z_ori.clone()
                        for val in interpolation:
                            z[:, row] = val
                            sample = torch.sigmoid(decoder(z))
                            samples.append(sample)
                            gifs.append(sample)

                    if self.viz_on:
                        grid_src = torch.cat(samples[:100], dim=0).cpu()
                        images = make_grid(grid_src,
                                           nrow=10,
                                           padding=2,
                                           normalize=True)
                        self.writer.add_image('Traverse/%s' % key, images,
                                              current_iter)

            if self.save_output:
                output_dir = self.params['info']
                gifs = torch.cat(gifs)
                # (code, dim, step, C, H, W) -> (code, step, dim, C, H, W)
                gifs = gifs.view(len(Z), len(rows), len(interpolation),
                                 self.params.nb_channels,
                                 self.params.image_size,
                                 self.params.image_size).transpose(1, 2)
                for i, key in enumerate(Z):
                    for j in range(len(interpolation)):
                        save_image(tensor=gifs[i][j].cpu(),
                                   fp=os.path.join(output_dir,
                                                   '{}_{}.jpg'.format(key, j)),
                                   nrow=len(rows),
                                   pad_value=1)

                    grid2gif(os.path.join(output_dir, key + '*.jpg'),
                             os.path.join(output_dir, key + '.gif'),
                             delay=10)
        finally:
            # Always return the network to training mode, even on error.
            self.net_mode(train=True)
Ejemplo n.º 11
0
    def viz_traverse(self, limit=3, inter=2/3, loc=-1):
        """Traverse the z part of (z, y) codes and save grids/GIFs.

        ``self.netE`` returns a 4-tuple; elements 1 and 3 are used as the z
        and y codes.  For each encoded item, every traversed z dimension is
        swept over ``torch.arange(-limit, limit + 0.1, inter)``.  The sweep
        is decoded by ``self.netG_Total`` from ``cat([z, y], dim=1)``, with
        a sigmoid applied only for the ``'faces'`` dataset.  When
        ``self.save_output`` is set, results go to
        ``<output_dir>/<global_iter>/``.

        Args:
            limit: Half-width of the traversal range.
            inter: Step between consecutive traversal values.
            loc: If not -1, traverse only latent dimension ``loc``.

        Raises:
            ValueError: If ``loc`` selects no latent dimension.
        """
        # Latent dimensions actually traversed; the reshape below uses this
        # count so that loc != -1 does not break the view.
        rows = [r for r in range(self.z_dim) if loc == -1 or r == loc]
        if not rows:
            raise ValueError(
                'loc must be -1 or a latent index in [0, {}), got {}'.format(
                    self.z_dim, loc))

        self.net_mode(train=False)

        decoder = self.netG_Total
        encoder = self.netE
        interpolation = torch.arange(-limit, limit + 0.1, inter)

        def encode_item(idx):
            # Load dataset item ``idx`` and add a batch dim.  Return its
            # (z, y) codes (outputs 1 and 3 of the encoder).
            img = self.data_loader.dataset.__getitem__(idx)
            img = cuda(img, self.use_cuda).unsqueeze(0)
            _, z_code, _, y_code = encoder(img)
            return z_code, y_code

        n_dsets = len(self.data_loader.dataset)
        random_img_z, random_img_y = encode_item(
            random.randint(1, n_dsets - 1))

        if self.dataset == 'dsprites':
            fixed = (('fixed_square', 87040),  # square
                     ('fixed_ellipse', 332800),  # ellipse
                     ('fixed_heart', 578560))  # heart
        else:
            fixed = (('fixed_img', 0),)

        Z = {}
        Y = {}
        for key, idx in fixed:
            Z[key], Y[key] = encode_item(idx)
        Z['random_img'] = random_img_z
        Y['random_img'] = random_img_y

        gifs = []
        for key, z_ori in Z.items():
            y = Y[key]
            for row in rows:
                z = z_ori.clone()
                for val in interpolation:
                    z[:, row] = val
                    # `.data` (previously mistyped `.date`) detaches the
                    # decoded sample.
                    sample = decoder(torch.cat([z, y], dim=1)).data
                    if self.dataset == 'faces':
                        sample = torch.sigmoid(sample)
                    gifs.append(sample)

        if self.save_output:
            output_dir = os.path.join(self.output_dir, str(self.global_iter))
            os.makedirs(output_dir, exist_ok=True)
            gifs = torch.cat(gifs)
            # (item, dim, step, C, H, W) -> (item, step, dim, C, H, W)
            gifs = gifs.view(len(Z), len(rows), len(interpolation), self.nc,
                             self.image_size, self.image_size).transpose(1, 2)
            for i, key in enumerate(Z):
                for j in range(len(interpolation)):
                    save_image(tensor=gifs[i][j].cpu(),
                               filename=os.path.join(output_dir,
                                                     '{}_{}.jpg'.format(key, j)),
                               nrow=len(rows), pad_value=1)

                grid2gif(os.path.join(output_dir, key + '*.jpg'),
                         os.path.join(output_dir, key + '.gif'), delay=10)

        self.net_mode(train=True)