def __init__(self, args):
        self.use_cuda = args.cuda and torch.cuda.is_available()
        self.max_epoch = args.max_epoch
        self.global_epoch = 0
        self.global_iter = 0

        self.z_dim = args.z_dim
        self.z_var = args.z_var
        self.z_sigma = math.sqrt(args.z_var)
        self._lambda = args.reg_weight
        self.lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.lr_schedules = {30: 2, 50: 5, 100: 10}

        if args.dataset.lower() == 'celeba':
            self.nc = 3
            self.decoder_dist = 'gaussian'
        else:
            raise NotImplementedError

        net = WAE
        self.net = cuda(net(self.z_dim, self.nc), self.use_cuda)
        self.optim = optim.Adam(self.net.parameters(),
                                lr=self.lr,
                                betas=(self.beta1, self.beta2))

        self.gather = DataGather()
        self.viz_name = args.viz_name
        self.viz_port = args.viz_port
        self.viz_on = args.viz_on
        if self.viz_on:
            self.viz = visdom.Visdom(env=self.viz_name + '_lines',
                                     port=self.viz_port)
            self.win_recon = None
            self.win_mmd = None
            self.win_mu = None
            self.win_var = None

        self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.viz_name)
        if not self.ckpt_dir.exists():
            self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.ckpt_name = args.ckpt_name
        if self.ckpt_name is not None:
            self.load_checkpoint(self.ckpt_name)

        self.save_output = args.save_output
        self.output_dir = Path(args.output_dir).joinpath(args.viz_name)
        if not self.output_dir.exists():
            self.output_dir.mkdir(parents=True, exist_ok=True)

        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)
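A minimal sketch of how an epoch-keyed schedule like `lr_schedules = {30: 2, 50: 5, 100: 10}` above is typically consumed (divide the base learning rate by the divisor of the largest milestone reached); the helper name `apply_lr_schedule` is an assumption for illustration, not part of this example:

def apply_lr_schedule(optimizer, base_lr, lr_schedules, epoch):
    # hypothetical helper: pick the divisor of the largest milestone reached
    lr = base_lr
    for milestone, divider in sorted(lr_schedules.items()):
        if epoch >= milestone:
            lr = base_lr / divider
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr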
Example #2
    def __init__(self, args):
        # Misc
        use_cuda = args.cuda and torch.cuda.is_available()
        self.device = 'cuda' if use_cuda else 'cpu'
        self.name = args.name
        self.max_iter = int(args.max_iter)
        self.print_iter = args.print_iter
        self.global_iter = 0
        self.pbar = tqdm(total=self.max_iter)

        # Data
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)

        # Networks & Optimizers
        self.z_dim = args.z_dim
        self.gamma = args.gamma
        self.etaS = args.etaS
        self.etaH = args.etaH
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE

        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D
        self.lr_r = args.lr_r
        self.beta1_r = args.beta1_r
        self.beta2_r = args.beta2_r
        ones = torch.Tensor(np.ones([self.z_dim])*0.5).to(self.device)  # create a tensor of custom relevance weights; for convenience every dimension is initialized to 0.5
        self.r = torch.nn.Parameter(ones)
        if args.dataset == 'dsprites':
            self.VAE = RF_VAE1(self.z_dim).to(self.device)
            self.nc = 1
        else:
            self.VAE = RF_VAE2(self.z_dim).to(self.device)
            self.nc = 3
        self.optim_VAE = optim.Adam(self.VAE.parameters(), lr=self.lr_VAE,
                                    betas=(self.beta1_VAE, self.beta2_VAE))

        self.D = Discriminator(self.z_dim).to(self.device)
        self.optim_D = optim.Adam(self.D.parameters(), lr=self.lr_D,
                                  betas=(self.beta1_D, self.beta2_D))
        self.optim_r = optim.Adam([self.r], lr=self.lr_r,
                                  betas=(self.beta1_r, self.beta2_r))
        self.nets = [self.VAE, self.D]

        # Visdom
        self.viz_on = args.viz_on
        self.win_id = dict(D_z='win_D_z', recon='win_recon', kld='win_kld', acc='win_acc', r_distribute='r_distribute')
        self.line_gather = DataGather('iter', 'soft_D_z', 'soft_D_z_pperm', 'recon', 'kld', 'acc', 'r_distribute')
        self.image_gather = DataGather('true', 'recon')
        if self.viz_on:
            self.viz_port = args.viz_port
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter
            self.viz_ra_iter = args.viz_ra_iter
            self.viz_ta_iter = args.viz_ta_iter

        # Checkpoint
        self.ckpt_dir = os.path.join(args.ckpt_dir, args.name)
        self.ckpt_save_iter = args.ckpt_save_iter
        mkdirs(self.ckpt_dir)
        if args.ckpt_load:
            self.load_checkpoint(args.ckpt_load)

        # Output(latent traverse GIF)
        self.output_dir = os.path.join(args.output_dir, args.name)
        self.output_save = args.output_save
        mkdirs(self.output_dir)
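The relevance vector `r` above takes a detour through numpy; an equivalent, purely-torch construction (a sketch, with `z_dim` assumed):

import torch

z_dim = 10  # assumed latent size
r = torch.nn.Parameter(torch.full((z_dim,), 0.5))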
Example #3
class Solver(object):
    def __init__(self, args):
        # Misc
        use_cuda = args.cuda and torch.cuda.is_available()
        self.device = 'cuda' if use_cuda else 'cpu'
        self.name = args.name
        self.max_iter = int(args.max_iter)
        self.print_iter = args.print_iter
        self.global_iter = 0
        self.pbar = tqdm(total=self.max_iter)

        # Data
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)

        # Networks & Optimizers
        self.z_dim = args.z_dim
        self.gamma = args.gamma
        self.etaS = args.etaS
        self.etaH = args.etaH
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE

        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D
        self.lr_r = args.lr_r
        self.beta1_r = args.beta1_r
        self.beta2_r = args.beta2_r
        ones = torch.Tensor(np.ones([self.z_dim])*0.5).to(self.device)  # create a tensor of custom relevance weights; for convenience every dimension is initialized to 0.5
        self.r = torch.nn.Parameter(ones)
        if args.dataset == 'dsprites':
            self.VAE = RF_VAE1(self.z_dim).to(self.device)
            self.nc = 1
        else:
            self.VAE = RF_VAE2(self.z_dim).to(self.device)
            self.nc = 3
        self.optim_VAE = optim.Adam(self.VAE.parameters(), lr=self.lr_VAE,
                                    betas=(self.beta1_VAE, self.beta2_VAE))

        self.D = Discriminator(self.z_dim).to(self.device)
        self.optim_D = optim.Adam(self.D.parameters(), lr=self.lr_D,
                                  betas=(self.beta1_D, self.beta2_D))
        self.optim_r = optim.Adam([self.r], lr=self.lr_r,
                                  betas=(self.beta1_r, self.beta2_r))
        self.nets = [self.VAE, self.D]

        # Visdom
        self.viz_on = args.viz_on
        self.win_id = dict(D_z='win_D_z', recon='win_recon', kld='win_kld', acc='win_acc', r_distribute='r_distribute')
        self.line_gather = DataGather('iter', 'soft_D_z', 'soft_D_z_pperm', 'recon', 'kld', 'acc', 'r_distribute')
        self.image_gather = DataGather('true', 'recon')
        if self.viz_on:
            self.viz_port = args.viz_port
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter
            self.viz_ra_iter = args.viz_ra_iter
            self.viz_ta_iter = args.viz_ta_iter

        # Checkpoint
        self.ckpt_dir = os.path.join(args.ckpt_dir, args.name)
        self.ckpt_save_iter = args.ckpt_save_iter
        mkdirs(self.ckpt_dir)
        if args.ckpt_load:
            self.load_checkpoint(args.ckpt_load)

        # Output(latent traverse GIF)
        self.output_dir = os.path.join(args.output_dir, args.name)
        self.output_save = args.output_save
        mkdirs(self.output_dir)

    def train(self):
        self.net_mode(train=True)

        ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
        zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)

        out = False
        while not out:
            for x_true1, x_true2 in self.data_loader:
                self.global_iter += 1
                self.pbar.update(1)

                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_kld = kl_divergence(mu, logvar, self.r)
                H_r = entropy(self.r)

                D_z = self.D(self.r*z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                vae_loss = vae_recon_loss + vae_kld + self.gamma*vae_tc_loss + self.etaS*self.r.abs().sum() + self.etaH*H_r

                # two backward passes over the same retained graph: the first
                # updates the VAE weights, the second the relevance vector r
                # (retain_graph also keeps D_z's graph alive for the
                # discriminator update below)
                self.optim_VAE.zero_grad()
                vae_loss.backward(retain_graph=True)
                self.optim_VAE.step()

                self.optim_r.zero_grad()
                vae_loss.backward(retain_graph=True)
                self.optim_r.step()

                x_true2 = x_true2.to(self.device)
                z_prime = self.VAE(x_true2, no_dec=True)
                z_pperm = permute_dims(z_prime).detach()
                D_z_pperm = self.D(self.r*z_pperm)
                D_tc_loss = 0.5*(F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones))

                self.optim_D.zero_grad()
                D_tc_loss.backward()
                self.optim_D.step()


                if self.global_iter%self.print_iter == 0:
                    self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                        self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item()))

                if self.global_iter%self.ckpt_save_iter == 0:
                    self.save_checkpoint(self.global_iter)

                if self.viz_on and (self.global_iter%self.viz_ll_iter == 0):
                    soft_D_z = F.softmax(D_z, 1)[:, :1].detach()
                    soft_D_z_pperm = F.softmax(D_z_pperm, 1)[:, :1].detach()
                    D_acc = ((soft_D_z >= 0.5).sum() + (soft_D_z_pperm < 0.5).sum()).float()
                    D_acc /= 2*self.batch_size
                    self.line_gather.insert(iter=self.global_iter,
                                            soft_D_z=soft_D_z.mean().item(),
                                            soft_D_z_pperm=soft_D_z_pperm.mean().item(),
                                            recon=vae_recon_loss.item(),
                                            kld=vae_kld.item(),
                                            acc=D_acc.item(),
                                            r_distribute=self.r.data.cpu())

                if self.viz_on and (self.global_iter%self.viz_la_iter == 0):
                    self.visualize_line()
                    self.line_gather.flush()

                if self.viz_on and (self.global_iter%self.viz_ra_iter == 0):
                    self.image_gather.insert(true=x_true1.data.cpu(),
                                             recon=torch.sigmoid(x_recon).data.cpu())
                    self.visualize_recon()
                    self.image_gather.flush()

                if self.viz_on and (self.global_iter%self.viz_ta_iter == 0):
                    if self.dataset.lower() == '3dchairs':
                        self.visualize_traverse(limit=2, inter=0.5)
                    else:
                        self.visualize_traverse(limit=3, inter=2/3)

                if self.global_iter >= self.max_iter:
                    out = True
                    break

        self.pbar.write("[Training Finished]")
        self.pbar.close()

    def visualize_recon(self):
        data = self.image_gather.data
        true_image = data['true'][0]
        recon_image = data['recon'][0]

        true_image = make_grid(true_image)
        recon_image = make_grid(recon_image)
        sample = torch.stack([true_image, recon_image], dim=0)
        self.viz.images(sample, env=self.name+'/recon_image',
                        opts=dict(title=str(self.global_iter)))

    def visualize_line(self):
        data = self.line_gather.data
        iters = torch.Tensor(data['iter'])
        recon = torch.Tensor(data['recon'])
        kld = torch.Tensor(data['kld'])
        D_acc = torch.Tensor(data['acc'])
        soft_D_z = torch.Tensor(data['soft_D_z'])
        soft_D_z_pperm = torch.Tensor(data['soft_D_z_pperm'])
        r_distribute = data['r_distribute'][-1]
        soft_D_zs = torch.stack([soft_D_z, soft_D_z_pperm], -1)
        def plot_line(win_key, Y, opts):
            # create the window on the first call, append on later calls
            exists = self.viz.win_exists(env=self.name + '/lines',
                                         win=self.win_id[win_key])
            self.viz.line(X=iters, Y=Y,
                          env=self.name + '/lines',
                          win=self.win_id[win_key],
                          update='append' if exists else None,
                          opts=opts)

        plot_line('D_z', soft_D_zs,
                  dict(xlabel='iteration', ylabel='D(.)',
                       legend=['D(z)', 'D(z_perm)']))
        plot_line('recon', recon,
                  dict(xlabel='iteration', ylabel='reconstruction loss'))
        plot_line('acc', D_acc,
                  dict(xlabel='iteration', ylabel='discriminator accuracy'))
        plot_line('kld', kld,
                  dict(xlabel='iteration', ylabel='kl divergence'))
        if self.viz.win_exists(env=self.name + '/lines', win=self.win_id['r_distribute']):
            self.viz.close(win=self.win_id['r_distribute'], env=self.name + '/lines')
        self.viz.bar(X=r_distribute,
                     env=self.name + '/lines',
                     win=self.win_id['r_distribute'],
                     opts=dict(xlabel='dimension',
                               ylabel='relevance score'))

    def visualize_traverse(self, limit=3, inter=2/3, loc=-1):
        self.net_mode(train=False)

        decoder = self.VAE.decode
        encoder = self.VAE.encode
        interpolation = torch.arange(-limit, limit+0.1, inter)

        random_img = self.data_loader.dataset.__getitem__(0)[1]
        random_img = random_img.to(self.device).unsqueeze(0)
        random_img_z = encoder(random_img)[:, :self.z_dim]

        if self.dataset.lower() == 'dsprites':
            fixed_idx1 = 87040 # square
            fixed_idx2 = 332800 # ellipse
            fixed_idx3 = 578560 # heart

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]

            Z = {'fixed_square':fixed_img_z1, 'fixed_ellipse':fixed_img_z2,
                 'fixed_heart':fixed_img_z3, 'random_img':random_img_z}

        elif self.dataset.lower() == 'celeba':
            fixed_idx1 = 191281 # 'CelebA/img_align_celeba/191282.jpg'
            fixed_idx2 = 143307 # 'CelebA/img_align_celeba/143308.jpg'
            fixed_idx3 = 101535 # 'CelebA/img_align_celeba/101536.jpg'
            fixed_idx4 = 70059  # 'CelebA/img_align_celeba/070060.jpg'

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]

            fixed_img4 = self.data_loader.dataset.__getitem__(fixed_idx4)[0]
            fixed_img4 = fixed_img4.to(self.device).unsqueeze(0)
            fixed_img_z4 = encoder(fixed_img4)[:, :self.z_dim]

            Z = {'fixed_1':fixed_img_z1, 'fixed_2':fixed_img_z2,
                 'fixed_3':fixed_img_z3, 'fixed_4':fixed_img_z4,
                 'random':random_img_z}

        elif self.dataset.lower() == '3dchairs':
            fixed_idx1 = 40919 # 3DChairs/images/4682_image_052_p030_t232_r096.png
            fixed_idx2 = 5172  # 3DChairs/images/14657_image_020_p020_t232_r096.png
            fixed_idx3 = 22330 # 3DChairs/images/30099_image_052_p030_t232_r096.png

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]

            Z = {'fixed_1':fixed_img_z1, 'fixed_2':fixed_img_z2,
                 'fixed_3':fixed_img_z3, 'random':random_img_z}
        else:
            fixed_idx = 0
            fixed_img = self.data_loader.dataset.__getitem__(fixed_idx)[0]
            fixed_img = fixed_img.to(self.device).unsqueeze(0)
            fixed_img_z = encoder(fixed_img)[:, :self.z_dim]

            random_z = torch.rand(1, self.z_dim, 1, 1, device=self.device)

            Z = {'fixed_img':fixed_img_z, 'random_img':random_img_z, 'random_z':random_z}

        gifs = []
        for key in Z:
            z_ori = Z[key]
            samples = []
            for row in range(self.z_dim):
                if loc != -1 and row != loc:
                    continue
                z = z_ori.clone()
                for val in interpolation:
                    z[:, row] = val
                    sample = torch.sigmoid(decoder(z)).data
                    samples.append(sample)
                    gifs.append(sample)
            samples = torch.cat(samples, dim=0).cpu()
            title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter)
            self.viz.images(samples, env=self.name+'/traverse',
                            opts=dict(title=title), nrow=len(interpolation))

        if self.output_save:
            output_dir = os.path.join(self.output_dir, str(self.global_iter))
            mkdirs(output_dir)
            gifs = torch.cat(gifs)
            gifs = gifs.view(len(Z), self.z_dim, len(interpolation), self.nc, 64, 64).transpose(1, 2)
            for i, key in enumerate(Z.keys()):
                name_str = ''  # reset per key so each GIF only collects its own frames
                for j, val in enumerate(interpolation):
                    save_image(tensor=gifs[i][j].cpu(),
                               filename=os.path.join(output_dir, '{}_{}.jpg'.format(key, j)),
                               nrow=self.z_dim, pad_value=1)
                    name_str = name_str + '{}_{}.jpg '.format(key, j)

                grid2gif(name_str,
                         str(os.path.join(output_dir, key+'.gif')),output_dir,delay=10)

        self.net_mode(train=True)


    def net_mode(self, train):
        if not isinstance(train, bool):
            raise ValueError('Only bool type is supported. True|False')

        for net in self.nets:
            if train:
                net.train()
            else:
                net.eval()

    def save_checkpoint(self, ckptname='last', verbose=True):
        model_states = {'D':self.D.state_dict(),
                        'VAE':self.VAE.state_dict(),
                        'r':self.r}
        optim_states = {'optim_D':self.optim_D.state_dict(),
                        'optim_VAE':self.optim_VAE.state_dict(),
                        'optim_r':self.optim_r.state_dict()}
        states = {'iter':self.global_iter,
                  'model_states':model_states,
                  'optim_states':optim_states}

        filepath = os.path.join(self.ckpt_dir, str(ckptname))
        with open(filepath, 'wb+') as f:
            torch.save(states, f)
        if verbose:
            self.pbar.write("=> saved checkpoint '{}' (iter {})".format(filepath, self.global_iter))

    def load_checkpoint(self, ckptname='last', verbose=True):
        if ckptname == 'last':
            ckpts = os.listdir(self.ckpt_dir)
            if not ckpts:
                if verbose:
                    self.pbar.write("=> no checkpoint found")
                return

            ckpts = [int(ckpt) for ckpt in ckpts]
            ckpts.sort(reverse=True)
            ckptname = str(ckpts[0])

        filepath = os.path.join(self.ckpt_dir, ckptname)
        if os.path.isfile(filepath):
            with open(filepath, 'rb') as f:
                checkpoint = torch.load(f)

            self.global_iter = checkpoint['iter']
            self.VAE.load_state_dict(checkpoint['model_states']['VAE'])
            self.D.load_state_dict(checkpoint['model_states']['D'])
            self.r = checkpoint['model_states']['r']
            self.optim_VAE.load_state_dict(checkpoint['optim_states']['optim_VAE'])
            self.optim_D.load_state_dict(checkpoint['optim_states']['optim_D'])
            self.optim_r.load_state_dict(checkpoint['optim_states']['optim_r'])
            self.pbar.update(self.global_iter)
            if verbose:
                self.pbar.write("=> loaded checkpoint '{} (iter {})'".format(filepath, self.global_iter))
        else:
            if verbose:
                self.pbar.write("=> no checkpoint found at '{}'".format(filepath))
Example #4
    def __init__(self, args):

        self.args = args

        self.name = '%s_lr_%s_a_%s_r_%s_k_%s' % \
                    (args.dataset_name, args.lr_VAE, args.alpha, args.gamma, args.k_fold)

        self.device = args.device
        self.temp = 0.66
        self.dt = 0.4
        self.eps = 1e-9
        self.alpha = args.alpha
        self.gamma = args.gamma

        self.max_iter = int(args.max_iter)

        # do it every specified iters
        self.print_iter = args.print_iter
        self.ckpt_save_iter = args.ckpt_save_iter
        self.output_save_iter = args.output_save_iter

        # data info
        args.dataset_dir = os.path.join(args.dataset_dir, str(args.k_fold))

        self.dataset_dir = args.dataset_dir
        self.dataset_name = args.dataset_name

        # self.N = self.latent_values.shape[0]
        # self.eval_metrics_iter = args.eval_metrics_iter

        # networks and optimizers
        self.batch_size = args.batch_size
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE
        print(args.desc)


        # set run id
        self.run_id = args.run_id

        # finalize name
        self.name = self.name + '_run_' + str(self.run_id)

        # records (text file to store console outputs)
        self.record_file = 'records/%s.txt' % self.name

        # checkpoints
        self.ckpt_dir = os.path.join("ckpts", self.name)

        # outputs
        self.output_dir_recon = os.path.join("outputs", self.name + '_recon')

        #### create a new model or load a previously saved model

        self.ckpt_load_iter = args.ckpt_load_iter

        self.obs_len = args.obs_len
        self.pred_len = args.pred_len


        # visdom setup
        self.viz_on = args.viz_on
        if self.viz_on:
            self.win_id = dict(
                map_loss='win_map_loss', test_map_loss='win_test_map_loss'
            )
            self.line_gather = DataGather(
                'iter', 'loss', 'test_loss'
            )

            import visdom

            self.viz_port = args.viz_port  # port number, eg, 8097
            self.viz = visdom.Visdom(port=self.viz_port, env=self.name)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter

            self.viz_init()

        # create dirs: "records", "ckpts", "outputs" (if not exist)
        mkdirs("records");
        mkdirs("ckpts");
        mkdirs("outputs")

        if self.ckpt_load_iter == 0 or args.dataset_name == 'all':  # create a new model
            # self.encoder = Encoder(
            #     fc_hidden_dim=args.hidden_dim,
            #     output_dim=args.latent_dim,
            #     drop_out=args.dropout_map).to(self.device)
            #
            # self.decoder = Decoder(
            #     fc_hidden_dim=args.hidden_dim,
            #     input_dim=args.latent_dim).to(self.device)

            num_filters = [32, 32, 32, 64, 64, 32, 32]
            # input = env + 8 past + lg / output = env + sg(including lg)
            self.sg_unet = Unet(input_channels=1, num_classes=1, num_filters=num_filters,
                             apply_last_layer=True, padding=True).to(self.device)


        else:  # load a previously saved model
            print('Loading saved models (iter: %d)...' % self.ckpt_load_iter)
            self.load_checkpoint()
            print('...done')


        # get VAE parameters
        # vae_params = \
        #     list(self.encoder.parameters()) + \
        #     list(self.decoder.parameters())
        vae_params = \
            list(self.sg_unet.parameters())

        # create optimizers
        self.optim_vae = optim.Adam(
            vae_params,
            lr=self.lr_VAE,
            betas=[self.beta1_VAE, self.beta2_VAE]
        )

        # prepare dataloader (iterable)
        print('Start loading data...')
        if self.ckpt_load_iter != self.max_iter:
            print("Initializing train dataset")
            _, self.train_loader = data_loader(self.args, args.dataset_dir, 'train', shuffle=True)
            print("Initializing val dataset")
            self.args.batch_size = 1  # evaluate one sequence at a time
            _, self.val_loader = data_loader(self.args, args.dataset_dir, 'test', shuffle=False)
            self.args.batch_size = self.batch_size  # self.args aliases args, so restore from the copy saved above

            print('There are {} iterations per epoch'.format(
                len(self.train_loader.dataset) // args.batch_size))
        print('...done')
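The `mkdirs` helper these examples rely on is not shown; it is typically a thin wrapper over os.makedirs that tolerates existing directories. A sketch:

import os

def mkdirs(path):
    # create the directory (and any parents) if it does not already exist
    if not os.path.exists(path):
        os.makedirs(path)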
Example #5
class Solver(object):

    ####
    def __init__(self, args):

        self.args = args

        self.name = '%s_lr_%s_a_%s_r_%s_k_%s' % \
                    (args.dataset_name, args.lr_VAE, args.alpha, args.gamma, args.k_fold)

        self.device = args.device
        self.temp = 0.66
        self.dt = 0.4
        self.eps = 1e-9
        self.alpha = args.alpha
        self.gamma = args.gamma

        self.max_iter = int(args.max_iter)

        # do it every specified iters
        self.print_iter = args.print_iter
        self.ckpt_save_iter = args.ckpt_save_iter
        self.output_save_iter = args.output_save_iter

        # data info
        args.dataset_dir = os.path.join(args.dataset_dir, str(args.k_fold))

        self.dataset_dir = args.dataset_dir
        self.dataset_name = args.dataset_name

        # self.N = self.latent_values.shape[0]
        # self.eval_metrics_iter = args.eval_metrics_iter

        # networks and optimizers
        self.batch_size = args.batch_size
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE
        print(args.desc)


        # set run id
        self.run_id = args.run_id

        # finalize name
        self.name = self.name + '_run_' + str(self.run_id)

        # records (text file to store console outputs)
        self.record_file = 'records/%s.txt' % self.name

        # checkpoints
        self.ckpt_dir = os.path.join("ckpts", self.name)

        # outputs
        self.output_dir_recon = os.path.join("outputs", self.name + '_recon')

        #### create a new model or load a previously saved model

        self.ckpt_load_iter = args.ckpt_load_iter

        self.obs_len = args.obs_len
        self.pred_len = args.pred_len


        # visdom setup
        self.viz_on = args.viz_on
        if self.viz_on:
            self.win_id = dict(
                map_loss='win_map_loss', test_map_loss='win_test_map_loss'
            )
            self.line_gather = DataGather(
                'iter', 'loss', 'test_loss'
            )

            import visdom

            self.viz_port = args.viz_port  # port number, eg, 8097
            self.viz = visdom.Visdom(port=self.viz_port, env=self.name)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter

            self.viz_init()

        # create dirs: "records", "ckpts", "outputs" (if not exist)
        mkdirs("records");
        mkdirs("ckpts");
        mkdirs("outputs")

        if self.ckpt_load_iter == 0 or args.dataset_name == 'all':  # create a new model
            # self.encoder = Encoder(
            #     fc_hidden_dim=args.hidden_dim,
            #     output_dim=args.latent_dim,
            #     drop_out=args.dropout_map).to(self.device)
            #
            # self.decoder = Decoder(
            #     fc_hidden_dim=args.hidden_dim,
            #     input_dim=args.latent_dim).to(self.device)

            num_filters = [32, 32, 32, 64, 64, 32, 32]
            # input = env + 8 past + lg / output = env + sg(including lg)
            self.sg_unet = Unet(input_channels=1, num_classes=1, num_filters=num_filters,
                             apply_last_layer=True, padding=True).to(self.device)


        else:  # load a previously saved model
            print('Loading saved models (iter: %d)...' % self.ckpt_load_iter)
            self.load_checkpoint()
            print('...done')


        # get VAE parameters
        # vae_params = \
        #     list(self.encoder.parameters()) + \
        #     list(self.decoder.parameters())
        vae_params = \
            list(self.sg_unet.parameters())

        # create optimizers
        self.optim_vae = optim.Adam(
            vae_params,
            lr=self.lr_VAE,
            betas=[self.beta1_VAE, self.beta2_VAE]
        )

        # prepare dataloader (iterable)
        print('Start loading data...')
        if self.ckpt_load_iter != self.max_iter:
            print("Initializing train dataset")
            _, self.train_loader = data_loader(self.args, args.dataset_dir, 'train', shuffle=True)
            print("Initializing val dataset")
            self.args.batch_size = 1  # evaluate one sequence at a time
            _, self.val_loader = data_loader(self.args, args.dataset_dir, 'test', shuffle=False)
            self.args.batch_size = self.batch_size  # self.args aliases args, so restore from the copy saved above

            print('There are {} iterations per epoch'.format(
                len(self.train_loader.dataset) // args.batch_size))
        print('...done')


    def preprocess_map(self, local_map, aug=False):
        local_map = torch.from_numpy(local_map).float().to(self.device)

        if aug:
            all_heatmaps = []
            for h in local_map:
                # local_map is already a float tensor on the device (see above),
                # so no per-item conversion is needed; rotate each map by a
                # random multiple of 90 degrees
                degree = np.random.choice([0, 90, 180, -90])
                all_heatmaps.append(
                    transforms.Compose([
                        transforms.RandomRotation(degrees=(degree, degree))
                    ])(h)
                )
            all_heatmaps = torch.stack(all_heatmaps)
        else:
            all_heatmaps = local_map
        return all_heatmaps



    ####
    def train(self):
        self.set_mode(train=True)
        torch.autograd.set_detect_anomaly(True)
        data_loader = self.train_loader
        self.N = len(data_loader.dataset)
        iterator = iter(data_loader)

        iter_per_epoch = len(iterator)
        start_iter = self.ckpt_load_iter + 1
        epoch = int(start_iter / iter_per_epoch)

        for iteration in range(start_iter, self.max_iter + 1):

            # reset data iterators for each epoch
            if iteration % iter_per_epoch == 0:
                print('==== epoch %d done ====' % epoch)
                epoch += 1
                iterator = iter(data_loader)

            # ============================================
            #          TRAIN THE VAE (ENC & DEC)
            # ============================================


            (obs_traj, fut_traj, obs_traj_st, fut_vel_st, seq_start_end,
             obs_frames, pred_frames, map_path, inv_h_t,
             local_map, local_ic, local_homo) = next(iterator)

            sampled_local_map = []
            for s, e in seq_start_end:
                rng = list(range(s,e))
                random.shuffle(rng)
                sampled_local_map.append(local_map[rng[:2]])

            sampled_local_map = np.concatenate(sampled_local_map)

            batch_size = sampled_local_map.shape[0]

            local_map = self.preprocess_map(sampled_local_map, aug=True)

            recon_local_map = self.sg_unet.forward(local_map)
            recon_local_map = torch.sigmoid(recon_local_map)

            # despite its name, this is a batch-averaged MSE reconstruction loss
            focal_loss = F.mse_loss(recon_local_map, local_map).sum().div(batch_size)

            self.optim_vae.zero_grad()
            focal_loss.backward()
            self.optim_vae.step()


            # save model parameters
            if (iteration % (iter_per_epoch*10) == 0):
                self.save_checkpoint(epoch)

            # (visdom) insert current line stats
            if iteration == iter_per_epoch or (self.viz_on and (iteration % (iter_per_epoch * 10) == 0)):
                test_recon_map_loss = self.test()
                self.line_gather.insert(iter=epoch,
                                        loss=focal_loss.item(),
                                        test_loss=test_recon_map_loss.item())
                prn_str = '[iter_%d (epoch_%d)] loss: %.3f\n' % (iteration, epoch, focal_loss.item())

                print(prn_str)
                self.visualize_line()
                self.line_gather.flush()


    def test(self):
        self.set_mode(train=False)
        loss=0
        b = 0
        with torch.no_grad():
            for abatch in self.val_loader:
                b += 1

                (obs_traj, fut_traj, obs_traj_st, fut_vel_st, seq_start_end,
                 obs_frames, pred_frames, map_path, inv_h_t,
                 local_map, local_ic, local_homo) = abatch
                batch_size = obs_traj.size(1)  # =sum(seq_start_end[:,1] - seq_start_end[:,0])
                local_map = self.preprocess_map(local_map, aug=False)

                recon_local_map = self.sg_unet.forward(local_map)
                recon_local_map = torch.sigmoid(recon_local_map)

                focal_loss = F.mse_loss(recon_local_map, local_map).sum().div(batch_size)

                loss += focal_loss
        self.set_mode(train=True)
        return loss.div(b)

    ####


    def make_feat(self, test_loader):
        from sklearn.manifold import TSNE
        # from data.trajectories import seq_collate

        # from data.macro_trajectories import TrajectoryDataset
        # from torch.utils.data import DataLoader

        # test_dset = TrajectoryDataset('../datasets/large_real/Trajectories', data_split='test', device=self.device)
        # test_loader = DataLoader(dataset=test_dset, batch_size=1,
        #                              shuffle=True, num_workers=0)

        self.set_mode(train=False)
        with torch.no_grad():
            test_enc_feat = []
            total_scenario = []
            b = 0
            for batch in test_loader:
                b+=1
                if len(test_enc_feat) > 0 and np.concatenate(test_enc_feat).shape[0] > 1000:
                    break
                (obs_traj, fut_traj, obs_traj_st, fut_vel_st, seq_start_end,
                 obs_frames, fut_frames, map_path, inv_h_t,
                 local_map, local_ic, local_homo) = batch

                rng = list(range(len(local_map)))
                random.shuffle(rng)
                sampling_idx = rng[:32]
                local_map1 = local_map[sampling_idx]
                local_map1 = self.preprocess_map(local_map1, aug=False)

                self.sg_unet.forward(local_map1)
                test_enc_feat.append(self.sg_unet.enc_feat.view(len(local_map1), -1).detach().cpu().numpy())

                for m in map_path[sampling_idx]:
                    total_scenario.append(int(m.split('/')[-1].split('.')[0]))


            import matplotlib.pyplot as plt
            test_enc_feat = np.concatenate(test_enc_feat)
            print(test_enc_feat.shape)

            # tsne = TSNE(n_components=2, random_state=0)
            # tsne_feat = tsne.fit_transform(test_enc_feat)
            all_feat = np.concatenate([test_enc_feat, np.expand_dims(np.array(total_scenario),1)], 1)

            np.save('large_tsne_r10_k0_tr.npy', all_feat)
            print('done')

            '''
            import pandas as  pd
            df = pd.read_csv('C:\dataset\large_real/large_5_bs1.csv')
            data = np.array(df)


            # all_feat = np.load('large_tsne_ae1_tr.npy')
            all_feat_tr = np.load('large_tsne_lg_k0_tr.npy')
            all_feat_te = np.load('large_tsne_lg_k0_te.npy')
            # tsne_faet = np.concatenate([all_feat[:,:2], all_feat_te[:,:2]])
            all_feat = np.concatenate([all_feat_tr[:,:-3], all_feat_te[:,:-3]])
            tsne = TSNE(n_components=2, random_state=0, perplexity=30)
            tsne_feat = tsne.fit_transform(all_feat)


            # tsne_faet = all_feat_tr[:,:-3]
            # obst_ratio = all_feat_tr[:,-3]
            # curv = all_feat_tr[:,-2]
            # scenario = all_feat_tr[:,-1]

            tsne_faet = all_feat_tr[:,:-3]
            obst_ratio = all_feat_tr[:,-3]
            curv = all_feat_tr[:,-2]
            scenario =  np.concatenate([all_feat_tr[:,-1], all_feat_te[:,-1]])
            labels = scenario //10

            labels = obst_ratio*100 //10
            # labels = curv*100 //10

            target_names = ['Training', 'Test']
            colors = np.array(['blue', 'red'])
            labels= np.array(df['0.5']) // 10
            labels= np.array(df['# agent']) //10
            labels= np.array(df['curvature'])*100 //10
            labels= np.array(df['map ratio'])*100 //10


            ## k fold labels
            k=0
            labels = scenario //10
            for i in range(len(labels)):
                if labels[i] in range(k*3,(k+1)*3):
                    labels[i] = 0
                else:
                    labels[i] = 1



            # colors = ['red', 'magenta', 'lightgreen', 'slateblue', 'blue', 'darkgreen', 'darkorange',
            #           'gray', 'purple', 'turquoise', 'midnightblue', 'olive', 'black', 'pink', 'burlywood',
            #           'yellow']

            colors = np.array(['gray','pink', 'orange', 'magenta', 'darkgreen', 'cyan', 'blue', 'red', 'lightgreen', 'olive', 'burlywood', 'purple'])
            target_names = np.unique(labels)

            fig = plt.figure(figsize=(5,4))
            fig.tight_layout()

            # labels = np.concatenate([np.zeros(len(all_feat_tr)), np.ones(len(all_feat_te))])
            target_names = ['Training', 'Test']
            colors = np.array(['blue', 'red'])

            for color, i, target_name in zip(colors, np.unique(labels), target_names):
                plt.scatter(tsne_feat[labels == i, 0], tsne_feat[labels == i, 1], alpha=.5, color=color,
                            label=str(target_name), s=10)
            fig.axes[0]._get_axis_list()[0].set_visible(False)
            fig.axes[0]._get_axis_list()[1].set_visible(False)
            plt.legend(loc=0, shadow=False, scatterpoints=1)
            '''

    ####
    def viz_init(self):
        self.viz.close(env=self.name, win=self.win_id['test_map_loss'])
        self.viz.close(env=self.name, win=self.win_id['map_loss'])

    ####
    def visualize_line(self):

        # prepare data to plot
        data = self.line_gather.data
        iters = torch.Tensor(data['iter'])
        test_map_loss = torch.Tensor(data['test_loss'])
        map_loss = torch.Tensor(data['loss'])

        self.viz.line(
            X=iters, Y=map_loss, env=self.name,
            win=self.win_id['map_loss'], update='append',
            opts=dict(xlabel='iter', ylabel='loss',
                      title='Recon. map loss')
        )


        self.viz.line(
            X=iters, Y=test_map_loss, env=self.name,
            win=self.win_id['test_map_loss'], update='append',
            opts=dict(xlabel='iter', ylabel='test_loss',
                      title='Recon. map loss - Test'),
        )


    #
    #
    # def set_mode(self, train=True):
    #
    #     if train:
    #         self.encoder.train()
    #         self.decoder.train()
    #     else:
    #         self.encoder.eval()
    #         self.decoder.eval()
    #
    # ####
    # def save_checkpoint(self, iteration):
    #
    #     encoder_path = os.path.join(
    #         self.ckpt_dir,
    #         'iter_%s_encoder.pt' % iteration
    #     )
    #     decoder_path = os.path.join(
    #         self.ckpt_dir,
    #         'iter_%s_decoder.pt' % iteration
    #     )
    #
    #
    #     mkdirs(self.ckpt_dir)
    #
    #     torch.save(self.encoder, encoder_path)
    #     torch.save(self.decoder, decoder_path)
    ####


    def set_mode(self, train=True):

        if train:
            self.sg_unet.train()
        else:
            self.sg_unet.eval()

    ####
    def save_checkpoint(self, iteration):

        sg_unet_path = os.path.join(
            self.ckpt_dir,
            'iter_%s_sg_unet.pt' % iteration
        )
        mkdirs(self.ckpt_dir)
        torch.save(self.sg_unet, sg_unet_path)


    ####
    def load_checkpoint(self):
        sg_unet_path = os.path.join(
            self.ckpt_dir,
            'iter_%s_sg_unet.pt' % self.ckpt_load_iter
        )

        if self.device == 'cuda':
            # NOTE: the path constructed above is overwritten by hard-coded
            # checkpoint paths below; only the last assignment takes effect
            sg_unet_path = 'ckpts/large.map.ae_lr_0.0001_a_0.25_r_2.0_run_8/iter_100_sg_unet.pt'
            sg_unet_path = 'ckpts/large.map.ae_lr_0.0001_a_0.25_r_2.0_k_0_run_9/iter_100_sg_unet.pt'
            print('>>>>>>>>>>> load: ', sg_unet_path)

            self.sg_unet = torch.load(sg_unet_path)
        else:
            sg_unet_path = 'ckpts/large.map.ae_lr_0.0001_a_0.25_r_2.0_run_8/iter_100_sg_unet.pt'
            sg_unet_path = 'ckpts/large.map.ae_lr_0.0001_a_0.25_r_2.0_k_1_run_10/iter_200_sg_unet.pt'
            # sg_unet_path = 'ckpts/large.map.ae_lr_0.0001_a_0.25_r_2.0_k_0_run_10/iter_200_sg_unet.pt'
            sg_unet_path = 'ckpts/large.map.ae_lr_0.0001_a_0.25_r_2.0_k_0_run_9/iter_40_sg_unet.pt'
            # sg_unet_path = 'd:\crowd\mcrowd\ckpts\mapae.path_lr_0.001_a_0.25_r_2.0_run_2/iter_3360_sg_unet.pt'
            self.sg_unet = torch.load(sg_unet_path, map_location='cpu')
        ####

    #
    # def load_checkpoint(self):
    #
    #     encoder_path = os.path.join(
    #         self.ckpt_dir,
    #         'iter_%s_encoder.pt' % self.ckpt_load_iter
    #     )
    #     decoder_path = os.path.join(
    #         self.ckpt_dir,
    #         'iter_%s_decoder.pt' % self.ckpt_load_iter
    #     )
    #
    #     if self.device == 'cuda':
    #         self.encoder = torch.load(encoder_path)
    #         self.decoder = torch.load(decoder_path)
    #     else:
    #         self.encoder = torch.load(encoder_path, map_location='cpu')
    #         self.decoder = torch.load(decoder_path, map_location='cpu')
    #
    # def load_map_weights(self, map_path):
    #     if self.device == 'cuda':
    #         loaded_map_w = torch.load(map_path)
    #     else:
    #         loaded_map_w = torch.load(map_path, map_location='cpu')
    #     self.encoder.conv1.weight = loaded_map_w.map_net.conv1.weight
    #     self.encoder.conv2.weight = loaded_map_w.map_net.conv2.weight
    #     self.encoder.conv3.weight = loaded_map_w.map_net.conv3.weight
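The 90-degree augmentation in preprocess_map above builds a one-off RandomRotation transform per map; torchvision's functional API expresses the same idea more directly. A sketch, assuming `local_map` is a float tensor of shape (N, C, H, W) and torchvision >= 0.8 (tensor inputs):

import numpy as np
import torch
import torchvision.transforms.functional as TF

def augment_rotations(local_map):
    # rotate every map by a random multiple of 90 degrees
    out = []
    for h in local_map:
        degree = float(np.random.choice([0, 90, 180, -90]))
        out.append(TF.rotate(h, degree))
    return torch.stack(out)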
Example #6
    def __init__(self, args):
        self.args = args

        self.name = '%s_map_pred_len_%s_zS_%s_dr_mlp_%s_dr_rnn_%s_dr_map_%s_enc_h_dim_%s_dec_h_dim_%s_mlp_dim_%s_emb_dim_%s_lr_%s_klw_%s_map_%s' % \
                    (args.dataset_name, args.pred_len, args.zS_dim, args.dropout_mlp, args.dropout_rnn, args.dropout_map, args.encoder_h_dim,
                     args.decoder_h_dim, args.mlp_dim, args.emb_dim, args.lr_VAE, args.kl_weight, args.map_size)

        # to be appended by run_id

        # self.use_cuda = args.cuda and torch.cuda.is_available()
        self.device = args.device
        self.temp = 0.66
        self.eps = 1e-9
        self.kl_weight = args.kl_weight

        self.max_iter = int(args.max_iter)

        # do it every specified iters
        self.print_iter = args.print_iter
        self.ckpt_save_iter = args.ckpt_save_iter
        self.output_save_iter = args.output_save_iter

        # data info
        self.dataset_dir = args.dataset_dir
        self.dataset_name = args.dataset_name

        # self.N = self.latent_values.shape[0]
        # self.eval_metrics_iter = args.eval_metrics_iter

        # networks and optimizers
        self.batch_size = args.batch_size
        self.zS_dim = args.zS_dim
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE
        print(args.desc)

        # visdom setup
        self.viz_on = args.viz_on
        if self.viz_on:
            self.win_id = dict(recon='win_recon',
                               loss_kl='win_loss_kl',
                               loss_recon='win_loss_recon',
                               total_loss='win_total_loss',
                               ade_min='win_ade_min',
                               fde_min='win_fde_min',
                               ade_avg='win_ade_avg',
                               fde_avg='win_fde_avg',
                               ade_std='win_ade_std',
                               fde_std='win_fde_std',
                               test_loss_recon='win_test_loss_recon',
                               test_loss_kl='win_test_loss_kl',
                               test_total_loss='win_test_total_loss')
            self.line_gather = DataGather('iter', 'loss_recon', 'loss_kl',
                                          'total_loss', 'ade_min', 'fde_min',
                                          'ade_avg', 'fde_avg', 'ade_std',
                                          'fde_std', 'test_loss_recon',
                                          'test_loss_kl', 'test_total_loss')

            import visdom

            self.viz_port = args.viz_port  # port number, eg, 8097
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter

            self.viz_init()

        # create dirs: "records", "ckpts", "outputs" (if not exist)
        mkdirs("records")
        mkdirs("ckpts")
        mkdirs("outputs")

        # set run id
        if args.run_id < 0:  # create a new id
            k = 0
            rfname = os.path.join("records", self.name + '_run_0.txt')
            while os.path.exists(rfname):
                k += 1
                rfname = os.path.join("records", self.name + '_run_%d.txt' % k)
            self.run_id = k
        else:  # user-provided id
            self.run_id = args.run_id

        # finalize name
        self.name = self.name + '_run_' + str(self.run_id)

        # records (text file to store console outputs)
        self.record_file = 'records/%s.txt' % self.name

        # checkpoints
        self.ckpt_dir = os.path.join("ckpts", self.name)

        # outputs
        self.output_dir_recon = os.path.join("outputs", self.name + '_recon')
        # dir for reconstructed images
        self.output_dir_synth = os.path.join("outputs", self.name + '_synth')
        # dir for synthesized images
        self.output_dir_trvsl = os.path.join("outputs", self.name + '_trvsl')

        #### create a new model or load a previously saved model

        self.ckpt_load_iter = args.ckpt_load_iter

        self.obs_len = args.obs_len
        self.pred_len = args.pred_len
        self.num_layers = args.num_layers
        self.decoder_h_dim = args.decoder_h_dim

        if self.ckpt_load_iter == 0 or args.dataset_name == 'all':  # create a new model
            self.encoderMx = Encoder(args.zS_dim,
                                     enc_h_dim=args.encoder_h_dim,
                                     mlp_dim=args.mlp_dim,
                                     emb_dim=args.emb_dim,
                                     map_size=args.map_size,
                                     batch_norm=args.batch_norm,
                                     num_layers=args.num_layers,
                                     dropout_mlp=args.dropout_mlp,
                                     dropout_rnn=args.dropout_rnn,
                                     dropout_map=args.dropout_map).to(
                                         self.device)
            self.encoderMy = EncoderY(args.zS_dim,
                                      enc_h_dim=args.encoder_h_dim,
                                      mlp_dim=args.mlp_dim,
                                      emb_dim=args.emb_dim,
                                      map_size=args.map_size,
                                      num_layers=args.num_layers,
                                      dropout_rnn=args.dropout_rnn,
                                      dropout_map=args.dropout_map,
                                      device=self.device).to(self.device)
            self.decoderMy = Decoder(args.pred_len,
                                     dec_h_dim=self.decoder_h_dim,
                                     enc_h_dim=args.encoder_h_dim,
                                     mlp_dim=args.mlp_dim,
                                     z_dim=args.zS_dim,
                                     num_layers=args.num_layers,
                                     device=args.device,
                                     dropout_rnn=args.dropout_rnn).to(
                                         self.device)

        else:  # load a previously saved model
            print('Loading saved models (iter: %d)...' % self.ckpt_load_iter)
            self.load_checkpoint()
            print('...done')

        # get VAE parameters
        vae_params = \
            list(self.encoderMx.parameters()) + \
            list(self.encoderMy.parameters()) + \
            list(self.decoderMy.parameters())

        # create optimizers
        self.optim_vae = optim.Adam(vae_params,
                                    lr=self.lr_VAE,
                                    betas=[self.beta1_VAE, self.beta2_VAE])

        ######## map
        # self.map = imageio.imread('D:\crowd\ewap_dataset\seq_' + self.dataset_name + '/map.png')
        # h = np.loadtxt('D:\crowd\ewap_dataset\seq_' + self.dataset_name + '\H.txt')
        # self.inv_h_t = np.linalg.pinv(np.transpose(h))
        self.map_size = args.map_size
        ######################################
        # prepare dataloader (iterable)
        print('Start loading data...')
        train_path = os.path.join(self.dataset_dir, self.dataset_name, 'train')
        val_path = os.path.join(self.dataset_dir, self.dataset_name, 'test')

        # long_dtype, float_dtype = get_dtypes(args)

        print("Initializing train dataset")
        if self.dataset_name == 'eth':
            self.args.pixel_distance = 5  # for hotel
        else:
            self.args.pixel_distance = 3  # for eth
        _, self.train_loader = data_loader(self.args, train_path)
        print("Initializing val dataset")
        if self.dataset_name == 'eth':
            self.args.pixel_distance = 3
        else:
            self.args.pixel_distance = 5
        _, self.val_loader = data_loader(self.args, val_path)
        # self.val_loader = self.train_loader

        print('There are {} iterations per epoch'.format(
            len(self.train_loader.dataset) // args.batch_size))
        print('...done')
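The run-id allocation in the __init__ above (scan records/ for the first unused suffix) as a small standalone helper; a sketch, with the function name assumed:

import os

def next_run_id(name, records_dir='records'):
    # smallest k such that records/<name>_run_k.txt does not exist yet
    k = 0
    while os.path.exists(os.path.join(records_dir, '%s_run_%d.txt' % (name, k))):
        k += 1
    return k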
class Solver(object):
    def __init__(self, args):
        # Misc
        use_cuda = args.cuda and torch.cuda.is_available()
        self.device = 'cuda' if use_cuda else 'cpu'
        self.name = args.name
        self.max_iter = int(args.max_iter)
        self.print_iter = args.print_iter
        self.global_iter = 0
        self.global_iter_cls = 0
        self.pbar = tqdm(total=self.max_iter)
        self.pbar_cls = tqdm(total=self.max_iter)

        # Data
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.eval_batch_size = args.eval_batch_size
        self.data_loader = return_data(args, 0)
        self.data_loader_eval = return_data(args, 2)

        # Networks & Optimizers
        self.z_dim = args.z_dim
        self.gamma = args.gamma
        self.beta = args.beta

        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE

        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D
        self.alpha = args.alpha
        self.beta = args.beta
        self.grl = args.grl

        self.lr_cls = args.lr_cls
        self.beta1_cls = args.beta1_D
        self.beta2_cls = args.beta2_D

        if args.dataset == 'dsprites':
            self.VAE = FactorVAE1(self.z_dim).to(self.device)
            self.nc = 1
        else:
            self.VAE = FactorVAE2(self.z_dim).to(self.device)
            self.nc = 3
        self.optim_VAE = optim.Adam(self.VAE.parameters(),
                                    lr=self.lr_VAE,
                                    betas=(self.beta1_VAE, self.beta2_VAE))

        # auxiliary classifiers; these call .cuda() directly and therefore
        # assume a GPU, unlike the device-aware modules above
        self.pacls = classifier(30, 2).cuda()
        self.revcls = classifier(30, 2).cuda()
        self.tcls = classifier(30, 2).cuda()
        self.trevcls = classifier(30, 2).cuda()

        self.targetcls = classifier(59, 2).cuda()
        self.pa_target = classifier(30, 2).cuda()
        self.target_pa = paclassifier(1, 1).cuda()
        self.pa_pa = classifier(30, 2).cuda()

        self.D = Discriminator(self.z_dim).to(self.device)
        self.optim_D = optim.Adam(self.D.parameters(),
                                  lr=self.lr_D,
                                  betas=(self.beta1_D, self.beta2_D))

        self.optim_pacls = optim.Adam(self.pacls.parameters(), lr=self.lr_D)

        self.optim_revcls = optim.Adam(self.revcls.parameters(), lr=self.lr_D)

        self.optim_tcls = optim.Adam(self.tcls.parameters(), lr=self.lr_D)
        self.optim_trevcls = optim.Adam(self.trevcls.parameters(),
                                        lr=self.lr_D)

        self.optim_cls = optim.Adam(self.targetcls.parameters(),
                                    lr=self.lr_cls)
        self.optim_pa_target = optim.Adam(self.pa_target.parameters(),
                                          lr=self.lr_cls)
        self.optim_target_pa = optim.Adam(self.target_pa.parameters(),
                                          lr=self.lr_cls)
        self.optim_pa_pa = optim.Adam(self.pa_pa.parameters(), lr=self.lr_cls)

        self.nets = [
            self.VAE, self.D, self.pacls, self.targetcls, self.revcls,
            self.pa_target, self.tcls, self.trevcls
        ]

        # Visdom
        self.viz_on = args.viz_on
        self.win_id = dict(D_z='win_D_z',
                           recon='win_recon',
                           kld='win_kld',
                           acc='win_acc')
        self.line_gather = DataGather('iter', 'soft_D_z', 'soft_D_z_pperm',
                                      'recon', 'kld', 'acc')
        self.image_gather = DataGather('true', 'recon')
        if self.viz_on:
            self.viz_port = args.viz_port
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter
            self.viz_ra_iter = args.viz_ra_iter
            self.viz_ta_iter = args.viz_ta_iter
            if not self.viz.win_exists(env=self.name + '/lines',
                                       win=self.win_id['D_z']):
                self.viz_init()

        # Checkpoint
        self.ckpt_dir = os.path.join(args.ckpt_dir, args.name)
        self.ckpt_save_iter = args.ckpt_save_iter
        mkdirs(self.ckpt_dir + "/cls")
        mkdirs(self.ckpt_dir + "/vae")

        if args.ckpt_load:
            self.load_checkpoint(args.ckpt_load)

        # Output(latent traverse GIF)
        self.output_dir = os.path.join(args.output_dir, args.name)
        self.output_save = args.output_save
        mkdirs(self.output_dir)

    def train(self):
        self.net_mode(train=True)

        ones = torch.ones(self.batch_size,
                          dtype=torch.long,
                          device=self.device)
        zeros = torch.zeros(self.batch_size,
                            dtype=torch.long,
                            device=self.device)

        # each step of this loop is a full pass over the data loader
        for i_num in range(self.max_iter - self.global_iter):
            total_pa_num = 0
            total_pa_correct_num = 0
            total_male_num = 0
            total_male_correct = 0
            total_female_num = 0
            total_female_correct = 0

            total_rev_num = 0
            total_rev_correct_num = 0

            total_t_num = 0
            total_t_correct_num = 0
            total_t_rev_num = 0
            total_t_rev_correct_num = 0

            for i, (x_true1, x_true2, heavy_makeup,
                    male) in enumerate(self.data_loader):

                heavy_makeup = heavy_makeup.to(self.device)
                male = male.to(self.device)
                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_kld = kl_divergence(mu, logvar)
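                # `recon_loss` and `kl_divergence` are the standard VAE terms;
                # minimal sketches, assuming a Bernoulli decoder that outputs
                # logits and a factorized Gaussian posterior:
                #
                #   def recon_loss(x, x_recon):
                #       return F.binary_cross_entropy_with_logits(
                #           x_recon, x, reduction='sum').div(x.size(0))
                #
                #   def kl_divergence(mu, logvar):
                #       return (-0.5 * (1 + logvar - mu.pow(2)
                #                       - logvar.exp()).sum(1)).mean()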

                D_z = self.D(z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()
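                # FactorVAE density-ratio trick: with a 2-logit discriminator,
                # D_z[:, 0] - D_z[:, 1] estimates log(D(z) / (1 - D(z))), an
                # approximation of the total correlation of q(z)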

                # adversarial (gradient-reversed) probe on the protected half
                # of z (the last 30 dims)
                z_reverse = grad_reverse(z.split(30, 1)[-1])
                reverse_output = self.revcls(z_reverse)
                output = self.pacls(z.split(30, 1)[-1])

                # symmetric probes on the target half (the first 30 dims)
                z_t_reverse = grad_reverse(z.split(30, 1)[0])
                t_reverse_output = self.trevcls(z_t_reverse)
                t_output = self.tcls(z.split(30, 1)[0])
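                # A gradient-reversal layer like `grad_reverse` above is
                # typically a custom autograd Function; a minimal sketch
                # (`lambd` is a hypothetical scaling coefficient):
                #
                #   class GradReverse(torch.autograd.Function):
                #       @staticmethod
                #       def forward(ctx, x, lambd=1.0):
                #           ctx.lambd = lambd
                #           return x.view_as(x)
                #
                #       @staticmethod
                #       def backward(ctx, grad_output):
                #           return -ctx.lambd * grad_output, None
                #
                #   def grad_reverse(x, lambd=1.0):
                #       return GradReverse.apply(x, lambd)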


                rev_correct = (
                    reverse_output.argmax(1) == heavy_makeup).sum().float()
                rev_num = heavy_makeup.size(0)

                pa_correct = (output.argmax(1) == male).sum().float()
                pa_num = male.size(0)

                t_correct = (t_output.argmax(1) == heavy_makeup).sum().float()
                t_num = heavy_makeup.size(0)

                t_rev_correct = (
                    t_reverse_output.argmax(1) == male).sum().float()
                t_rev_num = male.size(0)

                total_pa_correct_num += pa_correct
                total_pa_num += pa_num

                total_rev_correct_num += rev_correct
                total_rev_num += rev_num

                total_t_correct_num += t_correct
                total_t_num += t_num

                total_t_rev_correct_num += t_rev_correct
                total_t_rev_num += t_rev_num

                total_male_num += (male == 1).sum()
                total_female_num += (male == 0).sum()

                total_male_correct += ((output.argmax(1) == male) *
                                       (male == 1)).sum()

                total_female_correct += ((output.argmax(1) == male) *
                                         (male == 0)).sum()
                pa_cls = F.cross_entropy(output, male)
                rev_cls = F.cross_entropy(reverse_output, heavy_makeup)

                t_pa_cls = F.cross_entropy(t_output, heavy_makeup)
                t_rev_cls = F.cross_entropy(t_reverse_output, male)

                # NOTE: only reconstruction + KL are backpropagated below, so
                # the classifier/GRL losses above are monitoring-only here
                # (the grl term is left disabled)
                vae_loss = vae_recon_loss + self.beta * vae_kld
                # + self.grl*rev_cls

                self.optim_VAE.zero_grad()
                self.optim_pacls.zero_grad()
                self.optim_revcls.zero_grad()
                self.optim_tcls.zero_grad()
                self.optim_trevcls.zero_grad()

                vae_loss.backward(retain_graph=True)

                self.optim_VAE.step()
                self.optim_pacls.step()
                self.optim_revcls.step()
                self.optim_tcls.step()
                self.optim_trevcls.step()

                x_true2 = x_true2.to(self.device)
                z_prime = self.VAE(x_true2, no_dec=True)
                z_pperm = permute_dims(z_prime).detach()
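                # `permute_dims` shuffles each latent dimension independently
                # across the batch, yielding samples from the product of
                # marginals prod_j q(z_j); a minimal sketch:
                #
                #   def permute_dims(z):
                #       B = z.size(0)
                #       return torch.cat(
                #           [zj[torch.randperm(B)] for zj in z.split(1, 1)], 1)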

                D_z_pperm = self.D(z_pperm)
                #D_tc_loss = 0.5*(F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones))
                D_tc_loss = (F.cross_entropy(D_z, zeros) +
                             F.cross_entropy(D_z_pperm, ones))

                self.optim_D.zero_grad()
                # NOTE: the discriminator update is disabled; without the
                # backward call, optim_D.step() below is a no-op
                # D_tc_loss.backward()
                self.optim_D.step()

            self.pbar.update(1)
            self.global_iter += 1
            pa_acc = float(total_pa_correct_num) / float(total_pa_num)
            rev_acc = float(total_rev_correct_num) / float(total_rev_num)

            t_acc = float(total_t_correct_num) / float(total_t_num)
            t_rev_acc = float(total_t_rev_correct_num) / float(total_t_rev_num)

            male_acc = float(total_male_correct) / float(total_male_num)
            female_acc = float(total_female_correct) / float(total_female_num)

            if self.global_iter % self.print_iter == 0:
                self.pbar.write(
                    '[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f} pa_cls_loss:{:.3f} pa_acc:{:.3f} m_acc:{:.3f}  f_acc:{:.3f} rev_acc:{:.3f} t_acc:{:.3f} t_rev_acc:{:.3f}'
                    .format(self.global_iter, vae_recon_loss.item(),
                            vae_kld.item(), vae_tc_loss.item(),
                            D_tc_loss.item(), pa_cls.item(), pa_acc, male_acc,
                            female_acc, rev_acc, t_acc, t_rev_acc))

            if self.global_iter % self.ckpt_save_iter == 0:
                self.save_checkpoint(self.global_iter)

            if self.viz_on and (self.global_iter % self.viz_ll_iter == 0):
                soft_D_z = F.softmax(D_z, 1)[:, :1].detach()
                soft_D_z_pperm = F.softmax(D_z_pperm, 1)[:, :1].detach()
                D_acc = ((soft_D_z >= 0.5).sum() +
                         (soft_D_z_pperm < 0.5).sum()).float()
                D_acc /= 2 * self.batch_size
                self.line_gather.insert(
                    iter=self.global_iter,
                    soft_D_z=soft_D_z.mean().item(),
                    soft_D_z_pperm=soft_D_z_pperm.mean().item(),
                    recon=vae_recon_loss.item(),
                    kld=vae_kld.item(),
                    acc=D_acc.item())
            if self.viz_on and (self.global_iter % self.viz_la_iter == 0):
                self.visualize_line()
                self.line_gather.flush()

            if self.viz_on and (self.global_iter % self.viz_ra_iter == 0):
                self.image_gather.insert(true=x_true1.data.cpu(),
                                         recon=torch.sigmoid(x_recon).data.cpu())
                self.visualize_recon()
                self.image_gather.flush()

            if self.viz_on and (self.global_iter % self.viz_ta_iter == 0):
                if self.dataset.lower() == '3dchairs':
                    self.visualize_traverse(limit=2, inter=0.5)
                else:
                    self.visualize_traverse(limit=3, inter=2 / 3)

        self.pbar.write("[Training Finished]")
        self.pbar.close()

        self.train_cls()

    def train_cls(self):
        self.net_mode(train=True)

        # freeze the pre-trained VAE; only the classifiers are updated here
        for param in self.VAE.parameters():
            param.requires_grad = False

        for i_num in range(80):  # fixed number of classifier training epochs

            for i, (x_true1, x_true2, heavy_makeup,
                    male) in enumerate(self.data_loader):

                male = male.to(self.device)
                heavy_makeup = heavy_makeup.to(self.device)

                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_kld = kl_divergence(mu, logvar)
                D_z = self.D(z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                # drop the dims in dim_list from z before the target classifier
                dim_list = [24]
                # dim_list = [18, 47]
                keep_dims = [d for d in range(z.size(1)) if d not in dim_list]
                target_z = z[:, keep_dims]

                target = self.targetcls(target_z)
                pa_target = self.pa_target(z.split(30, 1)[-1])

                target_pa = self.target_pa(z.split(1, 1)[0])
                pa_pa = self.pa_pa(z.split(30, 1)[-1])

                target_cls = F.cross_entropy(target, heavy_makeup)
                # auxiliary probe losses (disabled):
                # pa_target_cls = F.cross_entropy(pa_target, heavy_makeup)
                # target_pa_cls = F.cross_entropy(target_pa, male)
                # pa_pa_cls = F.cross_entropy(pa_pa, male)

                # not optimized in this stage (the VAE is frozen); its
                # components are only logged below
                vae_loss = vae_recon_loss + vae_kld + self.gamma * vae_tc_loss

                target_loss = target_cls

                self.optim_cls.zero_grad()
                # (the pa_target / target_pa / pa_pa optimizers stay disabled)
                target_loss.backward()
                self.optim_cls.step()

            self.global_iter_cls += 1
            self.pbar_cls.update(1)

            if self.global_iter_cls % self.print_iter == 0:
                acc = ((target.argmax(1) == heavy_makeup).sum().float() /
                       len(x_true1)).item()
                pa_acc = ((pa_target.argmax(1) == heavy_makeup).sum().float() /
                          len(x_true1)).item()
                self.pbar_cls.write(
                    '[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} target_loss:{:.3f} accuracy:{:.3f}'
                    .format(self.global_iter_cls, vae_recon_loss.item(),
                            vae_kld.item(), vae_tc_loss.item(),
                            target_loss.item(), acc))

            #if self.global_iter_cls%self.ckpt_save_iter == 0:
            #    self.save_checkpoint_cls(self.global_iter_cls)

            self.val()

        self.pbar_cls.write("[Classifier Training Finished]")
        self.pbar_cls.close()

    '''
    # Alternative train_cls (disabled): sweeps every latent dimension with a
    # one-dim probe (self.target_pa) and records the best protected-attribute
    # accuracy per dimension.
    # NOTE: init_weight/init_bias below alias the live parameters, so this
    # "reset" does not actually restore the initial weights unless they are
    # cloned first.

    def train_cls(self):
        self.net_mode(train=True)

        init_weight = self.target_pa.fc.weight
        init_bias = self.target_pa.fc.bias
        total_max_value = []
        for dim in range(60):
            self.target_pa.fc.weight = init_weight
            self.target_pa.fc.bias = init_bias

            max_value = 0
            for i_num in range(1):
                total_target_pa_true = 0
                total_target_pa_num = 0
                for i, (x_true1, x_true2, heavy_makeup,
                        male) in enumerate(self.data_loader):
                    for param in self.VAE.parameters():
                        param.requires_grad = False

                    male = male.to(self.device)
                    heavy_makeup = heavy_makeup.to(self.device)
                    x_true1 = x_true1.to(self.device)

                    x_recon, mu, logvar, z = self.VAE(x_true1)
                    vae_recon_loss = recon_loss(x_true1, x_recon)
                    vae_kld = kl_divergence(mu, logvar)
                    D_z = self.D(z)
                    vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                    target = self.targetcls(z.split(59, 1)[0])
                    pa_target = self.pa_target(z.split(30, 1)[-1])
                    target_pa = self.target_pa(z.split(1, 1)[dim])
                    pa_pa = self.pa_pa(z.split(30, 1)[-1])

                    target_pa_cls = F.binary_cross_entropy(
                        target_pa.squeeze(1), male.type_as(target_pa))
                    target_cls = F.cross_entropy(target, heavy_makeup)
                    pa_target_cls = F.cross_entropy(pa_target, heavy_makeup)
                    pa_pa_cls = F.cross_entropy(pa_pa, male)

                    vae_loss = vae_recon_loss + vae_kld + self.gamma * vae_tc_loss
                    target_loss = (target_cls + pa_target_cls +
                                   target_pa_cls + pa_pa_cls)

                    total_target_pa_true += (
                        (target_pa.squeeze(1) > 0.5).type(
                            torch.LongTensor).cuda() == male).sum().float()
                    total_target_pa_num += len(male)

                    self.optim_cls.zero_grad()
                    self.optim_pa_target.zero_grad()
                    self.optim_target_pa.zero_grad()
                    self.optim_pa_pa.zero_grad()

                    target_loss.backward()
                    self.optim_pa_target.step()
                    self.optim_cls.step()
                    self.optim_target_pa.step()
                    self.optim_pa_pa.step()

                self.global_iter_cls += 1
                self.pbar_cls.update(1)

                if self.global_iter_cls % self.print_iter == 0:
                    acc = ((target.argmax(1) == heavy_makeup).sum().float() /
                           len(x_true1)).item()
                    self.pbar_cls.write(
                        '[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} '
                        'vae_tc_loss:{:.3f} target_loss:{:.3f} accuracy:{:.3f}'
                        .format(self.global_iter_cls, vae_recon_loss.item(),
                                vae_kld.item(), vae_tc_loss.item(),
                                target_loss.item(), acc))

                total_target_pa_acc = total_target_pa_true / total_target_pa_num
                print(total_target_pa_acc)
                if total_target_pa_acc > max_value:
                    max_value = total_target_pa_acc

            total_max_value.append(max_value)

        self.pbar_cls.write("[Classifier Training Finished]")
        self.pbar_cls.close()

        total_max_index = torch.from_numpy(np.array(total_max_value)).topk(5)
        print(total_max_index)
    '''

    def val(self):
        total_true = 0
        total_num = 0
        total_male_heavy = 0
        total_male_nonheavy = 0
        total_female_heavy = 0
        total_female_nonheavy = 0
        total_male_heavy_num = 0
        total_male_nonheavy_num = 0
        total_female_heavy_num = 0
        total_female_nonheavy_num = 0
        total_pa_num = 0
        total_pa_true = 0

        total_target_pa_num = 0
        total_target_pa_true = 0
        total_pa_pa_num = 0
        total_pa_pa_true = 0

        total_male = 0
        total_female = 0
        total_male_pred = 0
        total_female_pred = 0

        num_batches = 0

        # running per-pixel sums of decoded images (NHWC) for the optional
        # mean-image plots in the commented block below
        recon_tsum = np.zeros((self.batch_size, 64, 64, 3))
        recon_csum = np.zeros((self.batch_size, 64, 64, 3))
        recon_psum = np.zeros((self.batch_size, 64, 64, 3))
        recon_sum = np.zeros((self.batch_size, 64, 64, 3))
        origin_sum = np.zeros((self.batch_size, 64, 64, 3))

        for i, (x_true1, x_true2, heavy_makeup,
                male) in enumerate(self.data_loader_eval):

            male = male.to(self.device)
            heavy_makeup = heavy_makeup.to(self.device)

            x_true1 = x_true1.to(self.device)
            x_recon, mu, logvar, z = self.VAE(x_true1)

            dim_list = [24]
            # dim_list = [18, 47]
            keep_dims = [d for d in range(z.size(1)) if d not in dim_list]
            target_z = z[:, keep_dims]

            target = self.targetcls(target_z)
            pa_target = self.pa_target(z.split(30, 1)[-1])

            target_pa = self.target_pa(z.split(1, 1)[0])
            pa_pa = self.pa_pa(z.split(30, 1)[-1])

            # decode each half of z separately: the target half z_t and the
            # protected half z_p, each zero-padded back to the full z_dim
            z_t = z.split(30, 1)[0].unsqueeze(2).unsqueeze(2)
            z_p = z.split(30, 1)[-1].unsqueeze(2).unsqueeze(2)
            noise = torch.zeros(z.size(0), 30, 1, 1, device=self.device)
            z_t = torch.cat([z_t, noise], 1)
            z_p = torch.cat([z_p, noise], 1)

            recon_t = torch.sigmoid(self.VAE.decode(z_t))
            recon_p = torch.sigmoid(self.VAE.decode(z_p))

            recon_tsum += recon_t.permute(0, 2, 3, 1).cpu().detach().numpy()
            recon_psum += recon_p.permute(0, 2, 3, 1).cpu().detach().numpy()
            origin_sum += x_true1.permute(0, 2, 3, 1).cpu().detach().numpy()
            recon_sum += torch.sigmoid(x_recon).permute(
                0, 2, 3, 1).cpu().detach().numpy()

            num_batches += 1

            male_heavy = ((target.argmax(1) == 1) * (heavy_makeup == 1) *
                          (male == 1)).sum()
            male_heavy_num = ((heavy_makeup == 1) * (male == 1)).sum()

            male_nonheavy = ((target.argmax(1) == 0) * (heavy_makeup == 0) *
                             (male == 1)).sum()
            male_nonheavy_num = ((heavy_makeup == 0) * (male == 1)).sum()

            female_heavy = ((target.argmax(1) == 1) * (heavy_makeup == 1) *
                            (male == 0)).sum()
            female_heavy_num = ((heavy_makeup == 1) * (male == 0)).sum()

            female_nonheavy = ((target.argmax(1) == 0) * (heavy_makeup == 0) *
                               (male == 0)).sum()
            female_nonheavy_num = ((heavy_makeup == 0) * (male == 0)).sum()

            total_male_heavy += male_heavy
            total_male_nonheavy += male_nonheavy
            total_female_heavy += female_heavy
            total_female_nonheavy += female_nonheavy
            total_male_heavy_num += male_heavy_num
            total_male_nonheavy_num += male_nonheavy_num
            total_female_heavy_num += female_heavy_num
            total_female_nonheavy_num += female_nonheavy_num

            total_pa_true += (
                pa_target.argmax(1) == heavy_makeup).sum().float()

            total_pa_num += len(heavy_makeup)

            total_target_pa_true += (target_pa.argmax(1) == male).sum().float()

            total_target_pa_num += len(male)

            total_pa_pa_true += (pa_pa.argmax(1) == male).sum().float()

            total_pa_pa_num += len(male)

            total_true += (target.argmax(1) == heavy_makeup).sum().float()
            total_num += len(x_true1)

            total_male += (male == 1).sum()
            total_female += (male == 0).sum()

            total_male_pred += ((target.argmax(1) == 1) * (male == 1)).sum()
            total_female_pred += ((target.argmax(1) == 1) * (male == 0)).sum()

        male_heavy_acc = (total_male_heavy.float() /
                          total_male_heavy_num.float())
        male_nonheavy_acc = (total_male_nonheavy.float() /
                             total_male_nonheavy_num.float())
        female_heavy_acc = (total_female_heavy.float() /
                            total_female_heavy_num.float())
        female_nonheavy_acc = (total_female_nonheavy.float() /
                               total_female_nonheavy_num.float())
        '''
        # optional: save per-pixel mean images (uncomment to use)
        plt.imshow(origin_sum.mean(0) / num_batches)
        plt.savefig('./figure/origin/origin' + str(i) + '.png')

        plt.imshow(recon_sum.mean(0) / num_batches)
        plt.savefig('./figure/recon/recon' + str(i) + '.png')

        plt.imshow(recon_tsum.mean(0) / num_batches)
        plt.savefig('./figure/target/target' + str(i) + '.png')

        plt.imshow(recon_psum.mean(0) / num_batches)
        plt.savefig('./figure/protected/protected' + str(i) + '.png')
        '''

        print(total_male_heavy_num.item(), total_male_nonheavy_num.item(),
              total_female_heavy_num.item(), total_female_nonheavy_num.item())

        print("\nmale_heavy: ", male_heavy_acc.item(), "\tfemale_heavy: ",
              female_heavy_acc.item())
        print("male_nonheavy: ", male_nonheavy_acc.item(),
              "\tfemale_nonheavy: ", female_nonheavy_acc.item())

        print("Male_prob:", float(total_male_pred) / float(total_male))
        print("feMale_prob:", float(total_female_pred) / float(total_female))
        #import pdb;pdb.set_trace()
        print("DP:", (float(total_male_pred) / float(total_male) -
                      float(total_female_pred) / float(total_female)))

        print("eoo(1):", male_heavy_acc.item() - female_heavy_acc.item())
        print("eoo(0):", male_nonheavy_acc.item() - female_nonheavy_acc.item())

        total_acc = total_true / total_num
        total_pa_acc = total_pa_true / total_pa_num

        total_target_pa_acc = total_target_pa_true / total_target_pa_num
        total_pa_pa_acc = total_pa_pa_true / total_pa_pa_num

        print("target->target Accuracy: ", total_acc.item())
        print("PA->target Accuracy: ", total_pa_acc.item())
        print("target->PA Accuracy: ", total_target_pa_acc.item())
        print("PA->PA Accuracy: ", total_pa_pa_acc.item())

    def visualize_recon(self):
        data = self.image_gather.data
        true_image = data['true'][0]
        recon_image = data['recon'][0]

        true_image = make_grid(true_image)
        recon_image = make_grid(recon_image)
        sample = torch.stack([true_image, recon_image], dim=0)
        self.viz.images(sample,
                        env=self.name + '/recon_image',
                        opts=dict(title=str(self.global_iter)))

    def visualize_line(self):
        data = self.line_gather.data
        iters = torch.Tensor(data['iter'])
        recon = torch.Tensor(data['recon'])
        kld = torch.Tensor(data['kld'])
        D_acc = torch.Tensor(data['acc'])
        soft_D_z = torch.Tensor(data['soft_D_z'])
        soft_D_z_pperm = torch.Tensor(data['soft_D_z_pperm'])
        soft_D_zs = torch.stack([soft_D_z, soft_D_z_pperm], -1)

        self.viz.line(X=iters,
                      Y=soft_D_zs,
                      env=self.name + '/lines',
                      win=self.win_id['D_z'],
                      update='append',
                      opts=dict(xlabel='iteration',
                                ylabel='D(.)',
                                legend=['D(z)', 'D(z_perm)']))
        self.viz.line(X=iters,
                      Y=recon,
                      env=self.name + '/lines',
                      win=self.win_id['recon'],
                      update='append',
                      opts=dict(
                          xlabel='iteration',
                          ylabel='reconstruction loss',
                      ))
        self.viz.line(X=iters,
                      Y=D_acc,
                      env=self.name + '/lines',
                      win=self.win_id['acc'],
                      update='append',
                      opts=dict(
                          xlabel='iteration',
                          ylabel='discriminator accuracy',
                      ))
        self.viz.line(X=iters,
                      Y=kld,
                      env=self.name + '/lines',
                      win=self.win_id['kld'],
                      update='append',
                      opts=dict(
                          xlabel='iteration',
                          ylabel='kl divergence',
                      ))

    def visualize_traverse(self, limit=3, inter=2 / 3, loc=-1):
        self.net_mode(train=False)

        decoder = self.VAE.decode
        encoder = self.VAE.encode
        interpolation = torch.arange(-limit, limit + 0.1, inter)

        def encode_image(idx, pos=0):
            # encode element `pos` of dataset sample `idx` into latent space
            img = self.data_loader.dataset[idx][pos]
            img = img.to(self.device).unsqueeze(0)
            return encoder(img)[:, :self.z_dim]

        random_img_z = encode_image(0, pos=1)

        if self.dataset.lower() == 'dsprites':
            Z = {
                'fixed_square': encode_image(87040),
                'fixed_ellipse': encode_image(332800),
                'fixed_heart': encode_image(578560),
                'random_img': random_img_z
            }

        elif self.dataset.lower() == 'celeba':
            # 191282.jpg, 143308.jpg, 101536.jpg, 070060.jpg
            Z = {
                'fixed_1': encode_image(70000),
                'fixed_2': encode_image(143307),
                'fixed_3': encode_image(101535),
                'fixed_4': encode_image(70059),
                'random': random_img_z
            }

        elif self.dataset.lower() == '3dchairs':
            # 4682_image_052_p030_t232_r096.png,
            # 14657_image_020_p020_t232_r096.png,
            # 30099_image_052_p030_t232_r096.png
            Z = {
                'fixed_1': encode_image(40919),
                'fixed_2': encode_image(5172),
                'fixed_3': encode_image(22330),
                'random': random_img_z
            }
        else:
            random_z = torch.rand(1, self.z_dim, 1, 1, device=self.device)
            Z = {
                'fixed_img': encode_image(0),
                'random_img': random_img_z,
                'random_z': random_z
            }

        gifs = []
        for key in Z:
            z_ori = Z[key]
            samples = []
            for row in range(self.z_dim):
                if loc != -1 and row != loc:
                    continue
                z = z_ori.clone()
                for val in interpolation:
                    z[:, row] = val
                    sample = torch.sigmoid(decoder(z)).data
                    samples.append(sample)
                    gifs.append(sample)
            samples = torch.cat(samples, dim=0).cpu()
            title = '{}_latent_traversal(iter:{})'.format(
                key, self.global_iter)
            self.viz.images(samples,
                            env=self.name + '/traverse',
                            opts=dict(title=title),
                            nrow=len(interpolation))

        if self.output_save:
            output_dir = os.path.join(self.output_dir, str(self.global_iter))
            mkdirs(output_dir)
            gifs = torch.cat(gifs)
            gifs = gifs.view(len(Z), self.z_dim, len(interpolation), self.nc,
                             64, 64).transpose(1, 2)
            for i, key in enumerate(Z.keys()):
                for j, val in enumerate(interpolation):
                    save_image(tensor=gifs[i][j].cpu(),
                               filename=os.path.join(
                                   output_dir, '{}_{}.jpg'.format(key, j)),
                               nrow=self.z_dim,
                               pad_value=1)

                grid2gif(str(os.path.join(output_dir, key + '*.jpg')),
                         str(os.path.join(output_dir, key + '.gif')),
                         delay=10)

        self.net_mode(train=True)

    def viz_init(self):
        zero_init = torch.zeros([1])
        self.viz.line(X=zero_init,
                      Y=torch.stack([zero_init, zero_init], -1),
                      env=self.name + '/lines',
                      win=self.win_id['D_z'],
                      opts=dict(xlabel='iteration',
                                ylabel='D(.)',
                                legend=['D(z)', 'D(z_perm)']))
        self.viz.line(X=zero_init,
                      Y=zero_init,
                      env=self.name + '/lines',
                      win=self.win_id['recon'],
                      opts=dict(
                          xlabel='iteration',
                          ylabel='reconstruction loss',
                      ))
        self.viz.line(X=zero_init,
                      Y=zero_init,
                      env=self.name + '/lines',
                      win=self.win_id['acc'],
                      opts=dict(
                          xlabel='iteration',
                          ylabel='discriminator accuracy',
                      ))
        self.viz.line(X=zero_init,
                      Y=zero_init,
                      env=self.name + '/lines',
                      win=self.win_id['kld'],
                      opts=dict(
                          xlabel='iteration',
                          ylabel='kl divergence',
                      ))

    def net_mode(self, train):
        if not isinstance(train, bool):
            raise ValueError('train must be a bool: True or False')

        for net in self.nets:
            if train:
                net.train()
            else:
                net.eval()

    def save_checkpoint(self, ckptname='last', verbose=True):
        model_states = {
            'D': self.D.state_dict(),
            'VAE': self.VAE.state_dict(),
            'PACLS': self.pacls.state_dict(),
            'REVCLS': self.revcls.state_dict(),
            'T_CLS': self.tcls.state_dict(),
            'T_REVCLS': self.trevcls.state_dict()
        }
        optim_states = {
            'optim_D': self.optim_D.state_dict(),
            'optim_VAE': self.optim_VAE.state_dict(),
            'optim_PACLS': self.optim_pacls.state_dict(),
            'optim_REVCLS': self.optim_revcls.state_dict(),
            'optim_TCLS': self.optim_tcls.state_dict(),
            'optim_TREVCLS': self.optim_trevcls.state_dict()
        }
        states = {
            'iter': self.global_iter,
            'model_states': model_states,
            'optim_states': optim_states
        }
        filepath = os.path.join(self.ckpt_dir + "/vae", str(ckptname))
        with open(filepath, 'wb+') as f:
            torch.save(states, f)
        if verbose:
            self.pbar.write("=> saved checkpoint '{}' (iter {})".format(
                filepath, self.global_iter))

    def save_checkpoint_cls(self, ckptname='last', verbose=True):
        model_states = {
            'D': self.D.state_dict(),
            'VAE': self.VAE.state_dict(),
            'PACLS': self.pacls.state_dict(),
            'REVCLS': self.revcls.state_dict(),
            'TCLS': self.targetcls.state_dict(),
            'VALCLS': self.pa_target.state_dict(),
            'T_CLS': self.tcls.state_dict(),
            'T_REVCLS': self.trevcls.state_dict()
        }
        optim_states = {
            'optim_D': self.optim_D.state_dict(),
            'optim_VAE': self.optim_VAE.state_dict(),
            'optim_Tcls': self.optim_cls.state_dict(),
            'optim_PACLS': self.optim_pacls.state_dict(),
            'optim_REVCLS': self.optim_revcls.state_dict(),
            'optim_TCLS': self.optim_tcls.state_dict(),
            'optim_TREVCLS': self.optim_trevcls.state_dict(),
            'optim_VALCLS': self.optim_pa_target.state_dict()
        }
        states = {
            'iter': self.global_iter_cls,
            'model_states': model_states,
            'optim_states': optim_states
        }

        filepath = os.path.join(self.ckpt_dir + "/cls", str(ckptname))
        with open(filepath, 'wb+') as f:
            torch.save(states, f)
        if verbose:
            self.pbar.write("=> saved checkpoint '{}' (iter {})".format(
                filepath, self.global_iter_cls))

    def load_checkpoint(self, ckptname='last', verbose=True):
        if ckptname == 'last':
            ckpts = os.listdir(os.path.join(self.ckpt_dir, 'vae'))
            if not ckpts:
                if verbose:
                    self.pbar.write("=> no checkpoint found")
                return

            # checkpoints are named by iteration number; pick the newest
            ckpts = [int(ckpt) for ckpt in ckpts]
            ckpts.sort(reverse=True)
            ckptname = str(ckpts[0])

        filepath = os.path.join(self.ckpt_dir, 'vae', ckptname)
        if os.path.isfile(filepath):
            with open(filepath, 'rb') as f:
                checkpoint = torch.load(f)

            self.global_iter = checkpoint['iter']
            self.VAE.load_state_dict(checkpoint['model_states']['VAE'])
            self.D.load_state_dict(checkpoint['model_states']['D'])
            self.pacls.load_state_dict(checkpoint['model_states']['PACLS'])
            self.revcls.load_state_dict(checkpoint['model_states']['REVCLS'])
            self.tcls.load_state_dict(checkpoint['model_states']['T_CLS'])
            self.trevcls.load_state_dict(
                checkpoint['model_states']['T_REVCLS'])

            self.optim_VAE.load_state_dict(
                checkpoint['optim_states']['optim_VAE'])
            self.optim_D.load_state_dict(checkpoint['optim_states']['optim_D'])
            self.optim_pacls.load_state_dict(
                checkpoint['optim_states']['optim_PACLS'])
            self.optim_revcls.load_state_dict(
                checkpoint['optim_states']['optim_REVCLS'])
            self.optim_tcls.load_state_dict(
                checkpoint['optim_states']['optim_TCLS'])
            self.optim_trevcls.load_state_dict(
                checkpoint['optim_states']['optim_TREVCLS'])

            self.pbar.update(self.global_iter)
            if verbose:
                self.pbar.write("=> loaded checkpoint '{} (iter {})'".format(
                    filepath, self.global_iter))
        else:
            if verbose:
                self.pbar.write(
                    "=> no checkpoint found at '{}'".format(filepath))

    def load_checkpoint_cls(self, ckptname='last', verbose=True):
        if ckptname == 'last':
            ckpts = os.listdir(os.path.join(self.ckpt_dir, 'cls'))
            if not ckpts:
                if verbose:
                    self.pbar.write("=> no checkpoint found")
                return

            ckpts = [int(ckpt) for ckpt in ckpts]
            ckpts.sort(reverse=True)
            ckptname = str(ckpts[0])

        filepath = os.path.join(self.ckpt_dir, 'cls', ckptname)
        if os.path.isfile(filepath):
            with open(filepath, 'rb') as f:
                checkpoint = torch.load(f)

            self.global_iter_cls = checkpoint['iter']
            self.VAE.load_state_dict(checkpoint['model_states']['VAE'])
            self.D.load_state_dict(checkpoint['model_states']['D'])
            self.pacls.load_state_dict(checkpoint['model_states']['PACLS'])
            self.revcls.load_state_dict(checkpoint['model_states']['REVCLS'])
            self.targetcls.load_state_dict(checkpoint['model_states']['TCLS'])
            self.pa_target.load_state_dict(
                checkpoint['model_states']['VALCLS'])
            self.tcls.load_state_dict(checkpoint['model_states']['T_CLS'])
            self.trevcls.load_state_dict(
                checkpoint['model_states']['T_REVCLS'])

            self.optim_VAE.load_state_dict(
                checkpoint['optim_states']['optim_VAE'])
            self.optim_D.load_state_dict(checkpoint['optim_states']['optim_D'])
            self.optim_pacls.load_state_dict(
                checkpoint['optim_states']['optim_PACLS'])
            self.optim_revcls.load_state_dict(
                checkpoint['optim_states']['optim_REVCLS'])
            self.optim_pa_target.load_state_dict(
                checkpoint['optim_states']['optim_VALCLS'])
            self.optim_tcls.load_state_dict(
                checkpoint['optim_states']['optim_TCLS'])
            self.optim_trevcls.load_state_dict(
                checkpoint['optim_states']['optim_TREVCLS'])
            self.pbar.update(self.global_iter_cls)
            if verbose:
                self.pbar.write("=> loaded checkpoint '{} (iter {})'".format(
                    filepath, self.global_iter_cls))
        else:
            if verbose:
                self.pbar.write(
                    "=> no checkpoint found at '{}'".format(filepath))
Example #8
0
class Solver(object):

    ####
    def __init__(self, args):

        self.args = args

        self.name = ( '%s_gamma_%s_zDim_%s' + \
            '_lrVAE_%s_lrD_%s_rseed_%s' ) % \
            ( args.dataset, args.gamma, args.z_dim,
              args.lr_VAE, args.lr_D, args.rseed )
        # to be appended by run_id

        self.use_cuda = args.cuda and torch.cuda.is_available()

        self.max_iter = int(args.max_iter)

        # intervals (in iterations) for printing, checkpointing, saving outputs
        self.print_iter = args.print_iter
        self.ckpt_save_iter = args.ckpt_save_iter
        self.output_save_iter = args.output_save_iter

        # data info
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        if args.dataset.endswith('dsprites'):
            self.nc = 1
        elif args.dataset == '3dfaces':
            self.nc = 1
        else:
            self.nc = 3

        # evaluation metrics need groundtruth factors; default to False so the
        # attribute always exists (it is enabled per-dataset below)
        self.eval_metrics = False

        # groundtruth factor labels (available for the dsprites variants,
        # 3dfaces, and edinburgh_teapots)
        if self.dataset == 'dsprites':

            # latent factor = (color, shape, scale, orient, pos-x, pos-y)
            #   color = {1} (1)
            #   shape = {1=square, 2=oval, 3=heart} (3)
            #   scale = {0.5, 0.6, ..., 1.0} (6)
            #   orient = {2*pi*(k/39)}_{k=0}^39 (40)
            #   pos-x = {k/31}_{k=0}^31 (32)
            #   pos-y = {k/31}_{k=0}^31 (32)
            # (number of variations = 1*3*6*40*32*32 = 737280)

            latent_values = np.load(os.path.join(self.dset_dir,
                                                 'dsprites-dataset',
                                                 'latents_values.npy'),
                                    encoding='latin1')
            self.latent_values = latent_values[:, [1, 2, 3, 4, 5]]
            # latent values (actual values);(737280 x 5)
            latent_classes = np.load(os.path.join(self.dset_dir,
                                                  'dsprites-dataset',
                                                  'latents_classes.npy'),
                                     encoding='latin1')
            self.latent_classes = latent_classes[:, [1, 2, 3, 4, 5]]
            # classes ({0,1,...,K}-valued); (737280 x 5)
            self.latent_sizes = np.array([3, 6, 40, 32, 32])
            self.N = self.latent_values.shape[0]
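            # With these sizes, a row of factor classes maps to its flat
            # dataset index by mixed-radix encoding; a sketch (the names below
            # are illustrative, not attributes of this class):
            #
            #   bases = np.concatenate(
            #       (np.cumprod(self.latent_sizes[::-1])[::-1][1:], [1]))
            #   flat_idx = self.latent_classes.dot(bases)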

            if args.eval_metrics:
                self.eval_metrics = True
                self.eval_metrics_iter = args.eval_metrics_iter

        # groundtruth factor labels
        elif self.dataset == 'oval_dsprites':

            latent_classes = np.load(os.path.join(self.dset_dir,
                                                  'dsprites-dataset',
                                                  'latents_classes.npy'),
                                     encoding='latin1')
            idx = np.where(latent_classes[:, 1] == 1)[0]  # "oval" shape only
            self.latent_classes = latent_classes[idx, :]
            self.latent_classes = self.latent_classes[:, [2, 3, 4, 5]]
            # classes ({0,1,...,K}-valued); (245760 x 4)
            latent_values = np.load(os.path.join(self.dset_dir,
                                                 'dsprites-dataset',
                                                 'latents_values.npy'),
                                    encoding='latin1')
            self.latent_values = latent_values[idx, :]
            self.latent_values = self.latent_values[:, [2, 3, 4, 5]]
            # latent values (actual values);(245760 x 4)

            self.latent_sizes = np.array([6, 40, 32, 32])
            self.N = self.latent_values.shape[0]

            if args.eval_metrics:
                self.eval_metrics = True
                self.eval_metrics_iter = args.eval_metrics_iter

        # groundtruth factor labels
        elif self.dataset == '3dfaces':

            # latent factor = (id, azimuth, elevation, lighting)
            #   id = {0,1,...,49} (50)
            #   azimuth = {-1.0,-0.9,...,0.9,1.0} (21)
            #   elevation = {-1.0,-0.8,...,0.8,1.0} (11)
            #   lighting = {-1.0,-0.8,...,0.8,1.0} (11)
            # (number of variations = 50*21*11*11 = 127050)

            latent_classes, latent_values = np.load(
                os.path.join(self.dset_dir,
                             '3d_faces/rtqichen/gt_factor_labels.npy'))
            self.latent_values = latent_values
            # latent values (actual values);(127050 x 4)
            self.latent_classes = latent_classes
            # classes ({0,1,...,K}-valued); (127050 x 4)
            self.latent_sizes = np.array([50, 21, 11, 11])
            self.N = self.latent_values.shape[0]

            if args.eval_metrics:
                self.eval_metrics = True
                self.eval_metrics_iter = args.eval_metrics_iter

        elif self.dataset == 'celeba':

            self.N = 202599
            self.eval_metrics = False

        elif self.dataset == 'edinburgh_teapots':

            # latent factor = (azimuth, elevation, R, G, B)
            #   azimuth = [0, 2*pi]
            #   elevation = [0, pi/2]
            #   R, G, B = [0,1]
            #
            #   "latent_values" = original (real) factor values
            #   "latent_classes" = equal binning into K=10 classes
            #
            # (refer to "data/edinburgh_teapots/my_make_split_data.py")

            K = 10
            val_ranges = [2 * np.pi, np.pi / 2, 1, 1, 1]
            bins = []
            for j in range(5):
                bins.append(np.linspace(0, val_ranges[j], K + 1))

            latent_values = np.load(
                os.path.join(self.dset_dir, 'edinburgh_teapots',
                             'gtfs_tr.npz'))['data']
            latent_values = np.concatenate(
                (latent_values,
                 np.load(
                     os.path.join(self.dset_dir, 'edinburgh_teapots',
                                  'gtfs_va.npz'))['data']),
                axis=0)
            latent_values = np.concatenate(
                (latent_values,
                 np.load(
                     os.path.join(self.dset_dir, 'edinburgh_teapots',
                                  'gtfs_te.npz'))['data']),
                axis=0)
            self.latent_values = latent_values

            latent_classes = np.zeros(latent_values.shape)
            for j in range(5):
                latent_classes[:, j] = np.digitize(latent_values[:, j],
                                                   bins[j])
            self.latent_classes = latent_classes - 1  # {0,...,K-1}-valued
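            # note: np.digitize returns bins 1..K for in-range values (hence
            # the -1); a value exactly equal to the top of its range would get
            # bin K+1, so this assumes values lie strictly below the maxima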

            self.latent_sizes = K * np.ones(5, 'int64')
            self.N = self.latent_values.shape[0]

            if args.eval_metrics:
                self.eval_metrics = True
                self.eval_metrics_iter = args.eval_metrics_iter

        # networks and optimizers
        self.batch_size = args.batch_size
        self.z_dim = args.z_dim
        self.gamma = args.gamma
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE
        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D

        # visdom setup
        self.viz_on = args.viz_on
        if self.viz_on:

            self.win_id = dict(DZ='win_DZ',
                               recon='win_recon',
                               kl='win_kl',
                               kl_alpha='win_kl_alpha')
            self.line_gather = DataGather('iter', 'p_DZ', 'p_DZ_perm', 'recon',
                                          'kl', 'kl_alpha')

            if self.eval_metrics:
                self.win_id['metrics'] = 'win_metrics'

            import visdom

            self.viz_port = args.viz_port  # port number, eg, 8097
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter

            self.viz_init()

        # create dirs: "records", "ckpts", "outputs" (if not exist)
        mkdirs("records")
        mkdirs("ckpts")
        mkdirs("outputs")

        # set run id
        if args.run_id < 0:  # create a new id
            k = 0
            rfname = os.path.join("records", self.name + '_run_0.txt')
            while os.path.exists(rfname):
                k += 1
                rfname = os.path.join("records", self.name + '_run_%d.txt' % k)
            self.run_id = k
        else:  # user-provided id
            self.run_id = args.run_id

        # finalize name
        self.name = self.name + '_run_' + str(self.run_id)

        # records (text file to store console outputs)
        self.record_file = 'records/%s.txt' % self.name

        # checkpoints
        self.ckpt_dir = os.path.join("ckpts", self.name)

        # outputs
        self.output_dir_recon = os.path.join("outputs", self.name + '_recon')
        # dir for reconstructed images
        self.output_dir_synth = os.path.join("outputs", self.name + '_synth')
        # dir for synthesized images
        self.output_dir_trvsl = os.path.join("outputs", self.name + '_trvsl')
        # dir for latent traversed images

        #### create a new model or load a previously saved model

        self.ckpt_load_iter = args.ckpt_load_iter

        if self.ckpt_load_iter == 0:  # create a new model

            # create a vae model
            if args.dataset.endswith('dsprites'):
                self.encoder = Encoder1(self.z_dim)
                self.decoder = Decoder1(self.z_dim)
            elif args.dataset == '3dfaces':
                self.encoder = Encoder3(self.z_dim)
                self.decoder = Decoder3(self.z_dim)
            elif args.dataset == 'celeba':
                self.encoder = Encoder4(self.z_dim)
                self.decoder = Decoder4(self.z_dim)
            elif args.dataset.endswith('teapots'):
                # self.encoder = Encoder4(self.z_dim)
                # self.decoder = Decoder4(self.z_dim)
                self.encoder = Encoder_ResNet(self.z_dim)
                self.decoder = Decoder_ResNet(self.z_dim)
            else:
                raise NotImplementedError  # no encoder/decoder for this dataset

            # create a prior alpha model
            self.prior_alpha = PriorAlphaParams(self.z_dim)

            # create a posterior alpha model
            self.post_alpha = PostAlphaParams(self.z_dim)

            # create a discriminator model
            self.D = Discriminator(self.z_dim)

        else:  # load a previously saved model

            print('Loading saved models (iter: %d)...' % self.ckpt_load_iter)
            self.load_checkpoint()
            print('...done')

        if self.use_cuda:
            print('Models moved to GPU...')
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            self.prior_alpha = self.prior_alpha.cuda()
            self.post_alpha = self.post_alpha.cuda()
            self.D = self.D.cuda()
            print('...done')

        # get VAE parameters
        vae_params = list(self.encoder.parameters()) + \
            list(self.decoder.parameters()) + \
            list(self.prior_alpha.parameters()) + \
            list(self.post_alpha.parameters())

        # get discriminator parameters
        dis_params = list(self.D.parameters())

        # create optimizers
        self.optim_vae = optim.Adam(vae_params,
                                    lr=self.lr_VAE,
                                    betas=[self.beta1_VAE, self.beta2_VAE])
        self.optim_dis = optim.Adam(dis_params,
                                    lr=self.lr_D,
                                    betas=[self.beta1_D, self.beta2_D])

    ####
    def train(self):

        self.set_mode(train=True)

        ones = torch.ones(self.batch_size, dtype=torch.long)
        zeros = torch.zeros(self.batch_size, dtype=torch.long)
        if self.use_cuda:
            ones = ones.cuda()
            zeros = zeros.cuda()

        # prepare dataloader (iterable)
        print('Start loading data...')
        self.data_loader = create_dataloader(self.args)
        print('...done')

        # iterators from dataloader
        iterator1 = iter(self.data_loader)
        iterator2 = iter(self.data_loader)

        iter_per_epoch = min(len(iterator1), len(iterator2))

        start_iter = self.ckpt_load_iter + 1
        epoch = start_iter // iter_per_epoch

        for iteration in range(start_iter, self.max_iter + 1):

            # reset data iterators for each epoch
            if iteration % iter_per_epoch == 0:
                print('==== epoch %d done ====' % epoch)
                epoch += 1
                iterator1 = iter(self.data_loader)
                iterator2 = iter(self.data_loader)

            #============================================
            #          TRAIN THE VAE (ENC & DEC)
            #============================================

            # sample a mini-batch
            X, ids = next(iterator1)  # (n x C x H x W)
            if self.use_cuda:
                X = X.cuda()

            # enc(X)
            mu, std, logvar = self.encoder(X)

            # prior alpha params
            a, b = self.prior_alpha()

            # posterior alpha params
            ah, bh = self.post_alpha()

            # kl loss
            kls = 0.5 * ( \
                  (ah/bh)*(mu**2+std**2) - 1.0 + \
                  bh.log() - ah.digamma() - logvar )  # (n x z_dim)
            loss_kl = kls.sum(1).mean()

            # kl loss on alpha
            kls_alpha = ( \
                (ah-a)*ah.digamma() - ah.lgamma() + a.lgamma() + \
                a*(bh.log()-b.log()) + (ah/bh)*(b-bh) )  # z_dim-dim
            loss_kl_alpha = kls_alpha.sum() / self.N
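            # These match the closed forms for Gamma(shape, rate) precision
            # variables alpha_j: with q(alpha_j) = Gamma(ah_j, bh_j) we have
            # E_q[alpha_j] = ah_j/bh_j and E_q[log alpha_j] = digamma(ah_j) - log(bh_j),
            # so `kls` is E_{q(alpha)} KL( N(mu, std^2) || N(0, 1/alpha) ) per dim,
            # and `kls_alpha` is KL( Gamma(ah, bh) || Gamma(a, b) ), scaled by 1/N
            # to amortize the single global alpha-KL over the dataset.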

            # reparam'ed samples
            Eps = torch.randn_like(mu)  # same shape, dtype, and device as mu
            Z = mu + Eps * std

            # dec(Z)
            X_recon = self.decoder(Z)

            # recon loss
            loss_recon = F.binary_cross_entropy_with_logits(
                X_recon, X, reduction='sum').div(X.size(0))

            # dis(Z)
            DZ = self.D(Z)

            # tc loss
            loss_tc = (DZ[:, 0] - DZ[:, 1]).mean()
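            # DZ holds two logits per sample; with the discriminator trained
            # below (class 0 = q(z), class 1 = dim-wise permuted z),
            # DZ[:, 0] - DZ[:, 1] estimates log q(z) - log prod_j q(z_j), so
            # loss_tc is a density-ratio estimate of the total correlation.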

            # total loss for vae
            vae_loss = loss_recon + loss_kl + loss_kl_alpha + \
                       self.gamma*loss_tc

            # update vae
            self.optim_vae.zero_grad()
            vae_loss.backward()
            self.optim_vae.step()

            #============================================
            #          TRAIN THE DISCRIMINATOR
            #============================================

            # sample a mini-batch
            X2, ids = next(iterator2)  # (n x C x H x W)
            if self.use_cuda:
                X2 = X2.cuda()

            # enc(X2)
            mu, std, _ = self.encoder(X2)

            # reparam'ed samples
            Eps = torch.randn_like(mu)  # same shape, dtype, and device as mu
            Z = mu + Eps * std

            # dis(Z)
            DZ = self.D(Z)

            # dim-wise permuted Z over the mini-batch
            perm_Z = []
            for zj in Z.split(1, 1):
                idx = torch.randperm(Z.size(0))
                perm_zj = zj[idx]
                perm_Z.append(perm_zj)
            Z_perm = torch.cat(perm_Z, 1)
            Z_perm = Z_perm.detach()
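            # A vectorized equivalent of the loop above (a sketch, not part of
            # the original code): draw one independent batch permutation per
            # latent dimension and gather along the batch axis.
            #   B, d = Z.size()
            #   idx = torch.stack([torch.randperm(B) for _ in range(d)], dim=1)
            #   Z_perm = Z.gather(0, idx.to(Z.device)).detach()  # (B x d)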

            # dis(Z_perm)
            DZ_perm = self.D(Z_perm)

            # discriminator loss
            dis_loss = 0.5 * (F.cross_entropy(DZ, zeros) +
                              F.cross_entropy(DZ_perm, ones))
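            # Z_perm approximates samples from prod_j q(z_j) (each dimension
            # permuted independently across the batch), so this cross-entropy
            # (class 0 for Z, class 1 for Z_perm) is the standard density-ratio
            # objective; optim_dis below only updates D's parameters.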

            # update discriminator
            self.optim_dis.zero_grad()
            dis_loss.backward()
            self.optim_dis.step()

            ##########################################

            # print the losses
            if iteration % self.print_iter == 0:
                prn_str = ( '[iter %d (epoch %d)] vae_loss: %.3f | ' + \
                    'dis_loss: %.3f\n    ' + \
                    '(recon: %.3f, kl: %.3f, kl_alpha: %.3f, tc: %.3f)' \
                  ) % \
                  ( iteration, epoch, vae_loss.item(), dis_loss.item(),
                    loss_recon.item(), loss_kl.item(), loss_kl_alpha.item(),
                    loss_tc.item() )
                prn_str += '\n    a = {}'.format(
                    a.detach().cpu().numpy().round(2))
                prn_str += '\n    b = {}'.format(
                    b.detach().cpu().numpy().round(2))
                prn_str += '\n    ah = {}'.format(
                    ah.detach().cpu().numpy().round(2))
                prn_str += '\n    bh = {}'.format(
                    bh.detach().cpu().numpy().round(2))
                print(prn_str)
                if self.record_file:
                    record = open(self.record_file, 'a')
                    record.write('%s\n' % (prn_str, ))
                    record.close()

            # save model parameters
            if iteration % self.ckpt_save_iter == 0:
                self.save_checkpoint(iteration)

            # save output images (recon, synth, etc.)
            if iteration % self.output_save_iter == 0:

                # 1) save the recon images
                self.save_recon(iteration, X, torch.sigmoid(X_recon).data)

                # 2) save the synth images
                self.save_synth(iteration, howmany=100)

                # 3) save the latent traversed images
                if self.dataset.lower() == '3dchairs':
                    self.save_traverse(iteration, limb=-2, limu=2, inter=0.5)
                else:
                    self.save_traverse(iteration, limb=-3, limu=3, inter=0.1)

            # (visdom) insert current line stats
            if self.viz_on and (iteration % self.viz_ll_iter == 0):

                # compute discriminator accuracy
                p_DZ = F.softmax(DZ, 1)[:, 0].detach()
                p_DZ_perm = F.softmax(DZ_perm, 1)[:, 0].detach()

                # insert line stats
                self.line_gather.insert(iter=iteration,
                                        p_DZ=p_DZ.mean().item(),
                                        p_DZ_perm=p_DZ_perm.mean().item(),
                                        recon=loss_recon.item(),
                                        kl=loss_kl.item(),
                                        kl_alpha=loss_kl_alpha.item())

            # (visdom) visualize line stats (then flush out)
            if self.viz_on and (iteration % self.viz_la_iter == 0):
                self.visualize_line()
                self.line_gather.flush()

            # evaluate metrics
            if self.eval_metrics and (iteration % self.eval_metrics_iter == 0):

                metric1, _ = self.eval_disentangle_metric1()
                metric2, _ = self.eval_disentangle_metric2()

                prn_str = ( '********\n[iter %d (epoch %d)] ' + \
                  'metric1 = %.4f, metric2 = %.4f\n********' ) % \
                  (iteration, epoch, metric1, metric2)
                print(prn_str)
                if self.record_file:
                    record = open(self.record_file, 'a')
                    record.write('%s\n' % (prn_str, ))
                    record.close()

                # (visdom) visualize metrics
                if self.viz_on:
                    self.visualize_line_metrics(iteration, metric1, metric2)

    ####
    def eval_disentangle_metric1(self):

        # some hyperparams
        num_pairs = 800  # number of data pairs (d,y) for majority-vote classification
        bs = 50  # batch size
        nsamps_per_factor = 100  # samples per factor
        nsamps_agn_factor = 5000  # factor-agnostic samples

        self.set_mode(train=False)

        # 1) estimate variances of latent points factor agnostic

        dl = DataLoader(self.data_loader.dataset,
                        batch_size=bs,
                        shuffle=True,
                        num_workers=self.args.num_workers,
                        pin_memory=True)
        iterator = iter(dl)

        M = []
        for ib in range(int(nsamps_agn_factor / bs)):

            # sample a mini-batch
            Xb, _ = next(iterator)  # (bs x C x H x W)
            if self.use_cuda:
                Xb = Xb.cuda()

            # enc(Xb)
            mub, _, _ = self.encoder(Xb)  # (bs x z_dim)

            M.append(mub.cpu().detach().numpy())

        M = np.concatenate(M, 0)

        # estimate the sample variance of latent points for each dim
        vars_agn_factor = np.var(M, 0)

        # 2) estimate dim-wise vars of latent points with "one factor fixed"

        factor_ids = range(0, len(self.latent_sizes))  # true factor ids
        vars_per_factor = np.zeros([num_pairs, self.z_dim])
        true_factor_ids = np.zeros(num_pairs, int)  # true factor ids (class labels)

        # prepare data pairs for majority-vote classification
        i = 0
        for j in factor_ids:  # for each factor

            # repeat num_pairs/num_factors times
            for r in range(int(num_pairs / len(factor_ids))):

                # a true factor (id and class value) to fix
                fac_id = j
                fac_class = np.random.randint(self.latent_sizes[fac_id])

                # randomly select images (with the fixed factor)
                indices = np.where(self.latent_classes[:,
                                                       fac_id] == fac_class)[0]
                np.random.shuffle(indices)
                idx = indices[:nsamps_per_factor]
                M = []
                for ib in range(int(nsamps_per_factor / bs)):
                    Xb, _ = dl.dataset[idx[(ib * bs):(ib + 1) * bs]]
                    if Xb.shape[0] < 1:  # no more samples
                        continue
                    if self.use_cuda:
                        Xb = Xb.cuda()
                    mub, _, _ = self.encoder(Xb)  # (bs x z_dim)
                    M.append(mub.cpu().detach().numpy())
                M = np.concatenate(M, 0)

                # estimate sample var and mean of latent points for each dim
                if M.shape[0] >= 2:
                    vars_per_factor[i, :] = np.var(M, 0)
                else:  # not enough samples to estimate variance
                    vars_per_factor[i, :] = 0.0

                # true factor id (will become the class label)
                true_factor_ids[i] = fac_id

                i += 1

        # 3) evaluate majority vote classification accuracy

        # inputs in the paired data for classification
        smallest_var_dims = np.argmin(vars_per_factor /
                                      (vars_agn_factor + 1e-20),
                                      axis=1)

        # contingency table
        C = np.zeros([self.z_dim, len(factor_ids)])
        for i in range(num_pairs):
            C[smallest_var_dims[i], true_factor_ids[i]] += 1

        num_errs = 0  # number of misclassification errors of the majority-vote classifier
        for k in range(self.z_dim):
            num_errs += np.sum(C[k, :]) - np.max(C[k, :])

        metric1 = (num_pairs - num_errs) / num_pairs  # metric = accuracy

        self.set_mode(train=True)

        return metric1, C
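
    # The accuracy above can be read directly off the contingency table C:
    # each latent dimension votes for its most frequent factor, and every
    # off-argmax vote is an error. A one-line sketch (hypothetical helper;
    # assumes C.sum() == num_pairs):
    #   def majority_vote_accuracy(C):
    #       return C.max(axis=1).sum() / C.sum()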

    ####
    def eval_disentangle_metric2(self):

        # some hyperparams
        num_pairs = 800  # number of data pairs (d,y) for majority-vote classification
        bs = 50  # batch size
        nsamps_per_factor = 100  # samples per factor
        nsamps_agn_factor = 5000  # factor-agnostic samples

        self.set_mode(train=False)

        # 1) estimate variances of latent points factor agnostic

        dl = DataLoader(self.data_loader.dataset,
                        batch_size=bs,
                        shuffle=True,
                        num_workers=self.args.num_workers,
                        pin_memory=True)
        iterator = iter(dl)

        M = []
        for ib in range(int(nsamps_agn_factor / bs)):

            # sample a mini-batch
            Xb, _ = next(iterator)  # (bs x C x H x W)
            if self.use_cuda:
                Xb = Xb.cuda()

            # enc(Xb)
            mub, _, _ = self.encoder(Xb)  # (bs x z_dim)

            M.append(mub.cpu().detach().numpy())

        M = np.concatenate(M, 0)

        # estimate the sample variance of latent points for each dim
        vars_agn_factor = np.var(M, 0)

        # 2) estimate dim-wise vars of latent points with "one factor varied"

        factor_ids = range(0, len(self.latent_sizes))  # true factor ids
        vars_per_factor = np.zeros([num_pairs, self.z_dim])
        true_factor_ids = np.zeros(num_pairs, int)  # true factor ids (class labels)

        # prepare data pairs for majority-vote classification
        i = 0
        for j in factor_ids:  # for each factor

            # repeat num_pairs/num_factors times
            for r in range(int(num_pairs / len(factor_ids))):

                # randomly choose true factors (id's and class values) to fix
                fac_ids = list(np.setdiff1d(factor_ids, j))
                fac_classes = \
                  [ np.random.randint(self.latent_sizes[k]) for k in fac_ids ]

                # randomly select images (with the other factors fixed)
                if len(fac_ids) > 1:
                    indices = np.where(
                        np.sum(self.latent_classes[:, fac_ids] == fac_classes,
                               1) == len(fac_ids))[0]
                else:
                    indices = np.where(
                        self.latent_classes[:, fac_ids] == fac_classes)[0]
                np.random.shuffle(indices)
                idx = indices[:nsamps_per_factor]
                M = []
                for ib in range(int(nsamps_per_factor / bs)):
                    Xb, _ = dl.dataset[idx[(ib * bs):(ib + 1) * bs]]
                    if Xb.shape[0] < 1:  # no more samples
                        continue
                    if self.use_cuda:
                        Xb = Xb.cuda()
                    mub, _, _ = self.encoder(Xb)  # (bs x z_dim)
                    M.append(mub.cpu().detach().numpy())
                M = np.concatenate(M, 0)

                # estimate sample var and mean of latent points for each dim
                if M.shape[0] >= 2:
                    vars_per_factor[i, :] = np.var(M, 0)
                else:  # not enough samples to estimate variance
                    vars_per_factor[i, :] = 0.0

                # true factor id (will become the class label)
                true_factor_ids[i] = j

                i += 1

        # 3) evaluate majority vote classification accuracy

        # inputs in the paired data for classification
        largest_var_dims = np.argmax(vars_per_factor /
                                     (vars_agn_factor + 1e-20),
                                     axis=1)

        # contingency table
        C = np.zeros([self.z_dim, len(factor_ids)])
        for i in range(num_pairs):
            C[largest_var_dims[i], true_factor_ids[i]] += 1

        num_errs = 0  # number of misclassification errors of the majority-vote classifier
        for k in range(self.z_dim):
            num_errs += np.sum(C[k, :]) - np.max(C[k, :])

        metric2 = (num_pairs - num_errs) / num_pairs  # metric = accuracy

        self.set_mode(train=True)

        return metric2, C

    ####
    def save_recon(self, iters, true_images, recon_images):

        # make a merge of true and recon, eg,
        #   merged[0,...] = true[0,...],
        #   merged[1,...] = recon[0,...],
        #   merged[2,...] = true[1,...],
        #   merged[3,...] = recon[1,...], ...

        n = true_images.shape[0]
        perm = torch.arange(0, 2 * n).view(2, n).transpose(1, 0)
        perm = perm.contiguous().view(-1)
        merged = torch.cat([true_images, recon_images], dim=0)
        merged = merged[perm, :].cpu()
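        # e.g., for n = 2: perm = [0, 2, 1, 3], so rows alternate
        # true[0], recon[0], true[1], recon[1]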

        # save the results as image
        fname = os.path.join(self.output_dir_recon, 'recon_%s.jpg' % iters)
        mkdirs(self.output_dir_recon)
        save_image(merged, fname,
                   nrow=2 * int(np.sqrt(n)),
                   pad_value=1)

    ####
    def save_synth(self, iters, howmany=100):

        self.set_mode(train=False)

        decoder = self.decoder

        Z = torch.randn(howmany, self.z_dim)
        if self.use_cuda:
            Z = Z.cuda()

        # do synthesis
        X = torch.sigmoid(decoder(Z)).data.cpu()

        # save the results as image
        fname = os.path.join(self.output_dir_synth, 'synth_%s.jpg' % iters)
        mkdirs(self.output_dir_synth)
        save_image(X, fname,
                   nrow=int(np.sqrt(howmany)),
                   pad_value=1)

        self.set_mode(train=True)

    ####
    def save_traverse(self, iters, limb=-3, limu=3, inter=2 / 3, loc=-1):

        self.set_mode(train=False)

        encoder = self.encoder
        decoder = self.decoder
        interpolation = torch.arange(limb, limu + 0.001, inter)

        i = np.random.randint(self.N)
        random_img = self.data_loader.dataset.__getitem__(i)[0]
        if self.use_cuda:
            random_img = random_img.cuda()
        random_img = random_img.unsqueeze(0)
        random_img_zmu, _, _ = encoder(random_img)

        if self.dataset.lower() == 'dsprites':

            fixed_idx1 = 87040  # square
            fixed_idx2 = 332800  # ellipse
            fixed_idx3 = 578560  # heart

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            if self.use_cuda:
                fixed_img1 = fixed_img1.cuda()
            fixed_img1 = fixed_img1.unsqueeze(0)
            fixed_img_zmu1, _, _ = encoder(fixed_img1)

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            if self.use_cuda:
                fixed_img2 = fixed_img2.cuda()
            fixed_img2 = fixed_img2.unsqueeze(0)
            fixed_img_zmu2, _, _ = encoder(fixed_img2)

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            if self.use_cuda:
                fixed_img3 = fixed_img3.cuda()
            fixed_img3 = fixed_img3.unsqueeze(0)
            fixed_img_zmu3, _, _ = encoder(fixed_img3)

            IMG = {
                'fixed_square': fixed_img1,
                'fixed_ellipse': fixed_img2,
                'fixed_heart': fixed_img3,
                'random_img': random_img
            }

            Z = {
                'fixed_square': fixed_img_zmu1,
                'fixed_ellipse': fixed_img_zmu2,
                'fixed_heart': fixed_img_zmu3,
                'random_img': random_img_zmu
            }

        elif self.dataset.lower() == 'oval_dsprites':

            fixed_idx1 = 87040  # oval1
            fixed_idx2 = 220045  # oval2
            fixed_idx3 = 178560  # oval3

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            if self.use_cuda:
                fixed_img1 = fixed_img1.cuda()
            fixed_img1 = fixed_img1.unsqueeze(0)
            fixed_img_zmu1, _, _ = encoder(fixed_img1)

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            if self.use_cuda:
                fixed_img2 = fixed_img2.cuda()
            fixed_img2 = fixed_img2.unsqueeze(0)
            fixed_img_zmu2, _, _ = encoder(fixed_img2)

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            if self.use_cuda:
                fixed_img3 = fixed_img3.cuda()
            fixed_img3 = fixed_img3.unsqueeze(0)
            fixed_img_zmu3, _, _ = encoder(fixed_img3)

            IMG = {
                'fixed1': fixed_img1,
                'fixed2': fixed_img2,
                'fixed3': fixed_img3,
                'random_img': random_img
            }

            Z = {
                'fixed1': fixed_img_zmu1,
                'fixed2': fixed_img_zmu2,
                'fixed3': fixed_img_zmu3,
                'random_img': random_img_zmu
            }

        elif self.dataset.lower() == '3dfaces':

            fixed_idx1 = 6245
            fixed_idx2 = 10205
            fixed_idx3 = 68560

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            if self.use_cuda:
                fixed_img1 = fixed_img1.cuda()
            fixed_img1 = fixed_img1.unsqueeze(0)
            fixed_img_zmu1, _, _ = encoder(fixed_img1)

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            if self.use_cuda:
                fixed_img2 = fixed_img2.cuda()
            fixed_img2 = fixed_img2.unsqueeze(0)
            fixed_img_zmu2, _, _ = encoder(fixed_img2)

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            if self.use_cuda:
                fixed_img3 = fixed_img3.cuda()
            fixed_img3 = fixed_img3.unsqueeze(0)
            fixed_img_zmu3, _, _ = encoder(fixed_img3)

            IMG = {
                'fixed1': fixed_img1,
                'fixed2': fixed_img2,
                'fixed3': fixed_img3,
                'random_img': random_img
            }

            Z = {
                'fixed1': fixed_img_zmu1,
                'fixed2': fixed_img_zmu2,
                'fixed3': fixed_img_zmu3,
                'random_img': random_img_zmu
            }

        elif self.dataset.lower() == 'celeba':

            fixed_idx1 = 191281
            fixed_idx2 = 143307
            fixed_idx3 = 101535

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            if self.use_cuda:
                fixed_img1 = fixed_img1.cuda()
            fixed_img1 = fixed_img1.unsqueeze(0)
            fixed_img_zmu1, _, _ = encoder(fixed_img1)

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            if self.use_cuda:
                fixed_img2 = fixed_img2.cuda()
            fixed_img2 = fixed_img2.unsqueeze(0)
            fixed_img_zmu2, _, _ = encoder(fixed_img2)

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            if self.use_cuda:
                fixed_img3 = fixed_img3.cuda()
            fixed_img3 = fixed_img3.unsqueeze(0)
            fixed_img_zmu3, _, _ = encoder(fixed_img3)

            IMG = {
                'fixed1': fixed_img1,
                'fixed2': fixed_img2,
                'fixed3': fixed_img3,
                'random_img': random_img
            }

            Z = {
                'fixed1': fixed_img_zmu1,
                'fixed2': fixed_img_zmu2,
                'fixed3': fixed_img_zmu3,
                'random_img': random_img_zmu
            }

        elif self.dataset.lower() == 'edinburgh_teapots':

            fixed_idx1 = 7040
            fixed_idx2 = 32800
            fixed_idx3 = 78560

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            if self.use_cuda:
                fixed_img1 = fixed_img1.cuda()
            fixed_img1 = fixed_img1.unsqueeze(0)
            fixed_img_zmu1, _, _ = encoder(fixed_img1)

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            if self.use_cuda:
                fixed_img2 = fixed_img2.cuda()
            fixed_img2 = fixed_img2.unsqueeze(0)
            fixed_img_zmu2, _, _ = encoder(fixed_img2)

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            if self.use_cuda:
                fixed_img3 = fixed_img3.cuda()
            fixed_img3 = fixed_img3.unsqueeze(0)
            fixed_img_zmu3, _, _ = encoder(fixed_img3)

            IMG = {
                'fixed1': fixed_img1,
                'fixed2': fixed_img2,
                'fixed3': fixed_img3,
                'random_img': random_img
            }

            Z = {
                'fixed1': fixed_img_zmu1,
                'fixed2': fixed_img_zmu2,
                'fixed3': fixed_img_zmu3,
                'random_img': random_img_zmu
            }

#        elif self.dataset.lower() == '3dchairs':
#
#            fixed_idx1 = 40919 # 3DChairs/images/4682_image_052_p030_t232_r096.png
#            fixed_idx2 = 5172  # 3DChairs/images/14657_image_020_p020_t232_r096.png
#            fixed_idx3 = 22330 # 3DChairs/images/30099_image_052_p030_t232_r096.png
#
#            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
#            fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
#            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]
#
#            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
#            fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
#            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]
#
#            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
#            fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
#            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]
#
#            Z = {'fixed_1':fixed_img_z1, 'fixed_2':fixed_img_z2,
#                 'fixed_3':fixed_img_z3, 'random':random_img_zmu}
#
        else:

            raise NotImplementedError

        # do traversal and collect generated images
        gifs = []
        for key in Z:
            z_ori = Z[key]
            for row in range(self.z_dim):
                if loc != -1 and row != loc:
                    continue
                z = z_ori.clone()
                for val in interpolation:
                    z[:, row] = val
                    sample = torch.sigmoid(decoder(z)).data
                    gifs.append(sample)

        # save the generated files, also the animated gifs
        out_dir = os.path.join(self.output_dir_trvsl, str(iters))
        mkdirs(self.output_dir_trvsl)
        mkdirs(out_dir)
        gifs = torch.cat(gifs)
        gifs = gifs.view(len(Z), self.z_dim, len(interpolation), self.nc, 64,
                         64).transpose(1, 2)
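        # after this reshape/transpose, gifs[i][j] stacks, for source image i
        # and interpolation step j, one decoded 64x64 image per traversed
        # latent (z_dim rows), matching the nrow=1+z_dim grid saved below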
        for i, key in enumerate(Z.keys()):
            for j, val in enumerate(interpolation):
                I = torch.cat([IMG[key], gifs[i][j]], dim=0)
                save_image(I.cpu(),
                           os.path.join(out_dir, '%s_%03d.jpg' % (key, j)),
                           nrow=1 + self.z_dim,
                           pad_value=1)
            # make animated gif
            grid2gif(out_dir,
                     key,
                     str(os.path.join(out_dir, key + '.gif')),
                     delay=10)

        self.set_mode(train=True)

    ####
    def viz_init(self):

        self.viz.close(env=self.name + '/lines', win=self.win_id['DZ'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['recon'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['kl'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['kl_alpha'])

        if self.eval_metrics:
            self.viz.close(env=self.name + '/lines',
                           win=self.win_id['metrics'])

    ####
    def visualize_line(self):

        # prepare data to plot
        data = self.line_gather.data
        iters = torch.Tensor(data['iter'])
        recon = torch.Tensor(data['recon'])
        kl = torch.Tensor(data['kl'])
        kl_alpha = torch.Tensor(data['kl_alpha'])

        p_DZ = torch.Tensor(data['p_DZ'])
        p_DZ_perm = torch.Tensor(data['p_DZ_perm'])
        p_DZs = torch.stack([p_DZ, p_DZ_perm], -1)  # (#items x 2)

        self.viz.line(X=iters,
                      Y=p_DZs,
                      env=self.name + '/lines',
                      win=self.win_id['DZ'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='D(z)',
                                title='Discriminator-Z',
                                legend=[
                                    'D(z)',
                                    'D(z_perm)',
                                ]))

        self.viz.line(X=iters,
                      Y=recon,
                      env=self.name + '/lines',
                      win=self.win_id['recon'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='recon loss',
                                title='Reconstruction'))

        self.viz.line(X=iters,
                      Y=kl,
                      env=self.name + '/lines',
                      win=self.win_id['kl'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='E_q(alpha) E_x[KL(q(z|x)||p(z|alpha))]',
                                title='KL divergence'))

        self.viz.line(X=iters,
                      Y=kl_alpha,
                      env=self.name + '/lines',
                      win=self.win_id['kl_alpha'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='KL(q(alpha)||p(alpha)) / N',
                                title='KL divergence on alpha'))

    ####
    def visualize_line_metrics(self, iters, metric1, metric2):

        # prepare data to plot
        iters = torch.tensor([iters], dtype=torch.int64).detach()
        metric1 = torch.tensor([metric1])
        metric2 = torch.tensor([metric2])
        metrics = torch.stack([metric1.detach(), metric2.detach()], -1)

        self.viz.line(X=iters,
                      Y=metrics,
                      env=self.name + '/lines',
                      win=self.win_id['metrics'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='metrics',
                                title='Disentanglement metrics',
                                legend=['metric1', 'metric2']))

    ####
    def set_mode(self, train=True):

        if train:
            self.encoder.train()
            self.decoder.train()
            self.D.train()
        else:
            self.encoder.eval()
            self.decoder.eval()
            self.D.eval()

    ####
    def save_checkpoint(self, iteration):

        encoder_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_encoder.pt' % iteration)
        decoder_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_decoder.pt' % iteration)
        prior_alpha_path = os.path.join(self.ckpt_dir,
                                        'iter_%s_prior_alpha.pt' % iteration)
        post_alpha_path = os.path.join(self.ckpt_dir,
                                       'iter_%s_post_alpha.pt' % iteration)
        D_path = os.path.join(self.ckpt_dir, 'iter_%s_D.pt' % iteration)

        mkdirs(self.ckpt_dir)

        torch.save(self.encoder, encoder_path)
        torch.save(self.decoder, decoder_path)
        torch.save(self.prior_alpha, prior_alpha_path)
        torch.save(self.post_alpha, post_alpha_path)
        torch.save(self.D, D_path)

    ####
    def load_checkpoint(self):

        encoder_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_encoder.pt' % self.ckpt_load_iter)
        decoder_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_decoder.pt' % self.ckpt_load_iter)
        prior_alpha_path = os.path.join(
            self.ckpt_dir, 'iter_%s_prior_alpha.pt' % self.ckpt_load_iter)
        post_alpha_path = os.path.join(
            self.ckpt_dir, 'iter_%s_post_alpha.pt' % self.ckpt_load_iter)
        D_path = os.path.join(self.ckpt_dir,
                              'iter_%s_D.pt' % self.ckpt_load_iter)

        if self.use_cuda:
            self.encoder = torch.load(encoder_path)
            self.decoder = torch.load(decoder_path)
            self.prior_alpha = torch.load(prior_alpha_path)
            self.post_alpha = torch.load(post_alpha_path)
            self.D = torch.load(D_path)
        else:
            self.encoder = torch.load(encoder_path, map_location='cpu')
            self.decoder = torch.load(decoder_path, map_location='cpu')
            self.prior_alpha = torch.load(prior_alpha_path, map_location='cpu')
            self.post_alpha = torch.load(post_alpha_path, map_location='cpu')
            self.D = torch.load(D_path, map_location='cpu')
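
# Note: this example pickles whole nn.Module objects, which ties checkpoints to
# the exact class definitions and module paths at save time. A more portable
# sketch (assuming the same modules are re-instantiated before loading) keeps
# state_dicts instead:
#
#   def save_checkpoint_state(self, iteration):
#       mkdirs(self.ckpt_dir)
#       path = os.path.join(self.ckpt_dir, 'iter_%s_states.pt' % iteration)
#       torch.save({'encoder': self.encoder.state_dict(),
#                   'decoder': self.decoder.state_dict(),
#                   'prior_alpha': self.prior_alpha.state_dict(),
#                   'post_alpha': self.post_alpha.state_dict(),
#                   'D': self.D.state_dict()}, path)
#
#   def load_checkpoint_state(self):
#       path = os.path.join(self.ckpt_dir,
#                           'iter_%s_states.pt' % self.ckpt_load_iter)
#       states = torch.load(path, map_location='cpu')
#       self.encoder.load_state_dict(states['encoder'])
#       self.decoder.load_state_dict(states['decoder'])
#       self.prior_alpha.load_state_dict(states['prior_alpha'])
#       self.post_alpha.load_state_dict(states['post_alpha'])
#       self.D.load_state_dict(states['D'])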
Example #9
    def __init__(self, args):
        # Misc
        use_cuda = args.cuda and torch.cuda.is_available()
        self.device = 'cuda' if use_cuda else 'cpu'
        self.name = args.name
        self.max_iter = int(args.max_iter)
        self.print_iter = args.print_iter
        self.global_iter = 0
        self.pbar = tqdm(total=self.max_iter)

        # Data
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader, self.data = return_data(args)

        # Networks & Optimizers
        self.z_dim = args.z_dim
        self.gamma = args.gamma

        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE

        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D

        if args.dataset == 'dsprites':
            self.VAE = FactorVAE1(self.z_dim).to(self.device)
            self.nc = 1
        else:
            self.VAE = FactorVAE2(self.z_dim).to(self.device)
            self.nc = 3
        self.optim_VAE = optim.Adam(self.VAE.parameters(), lr=self.lr_VAE,
                                    betas=(self.beta1_VAE, self.beta2_VAE))

        self.D = Discriminator(self.z_dim).to(self.device)
        self.optim_D = optim.Adam(self.D.parameters(), lr=self.lr_D,
                                  betas=(self.beta1_D, self.beta2_D))

        self.nets = [self.VAE, self.D]

        # Visdom
        self.viz_on = args.viz_on
        self.win_id = dict(D_z='win_D_z', recon='win_recon', kld='win_kld', acc='win_acc')
        self.line_gather = DataGather('iter', 'soft_D_z', 'soft_D_z_pperm', 'recon', 'kld', 'acc')
        self.image_gather = DataGather('true', 'recon')
        if self.viz_on:
            self.viz_port = args.viz_port
            self.viz = visdom.Visdom(log_to_filename='./logging.log', offline=True)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter
            self.viz_ra_iter = args.viz_ra_iter
            self.viz_ta_iter = args.viz_ta_iter
            if not self.viz.win_exists(env=self.name+'/lines', win=self.win_id['D_z']):
                self.viz_init()

        # Checkpoint
        self.ckpt_dir = os.path.join(args.ckpt_dir, args.name)
        self.ckpt_save_iter = args.ckpt_save_iter
        mkdirs(self.ckpt_dir)
        if args.ckpt_load:
            self.load_checkpoint(args.ckpt_load)

        # Output(latent traverse GIF)
        self.output_dir = os.path.join(args.output_dir, args.name)
        self.output_save = args.output_save
        mkdirs(self.output_dir)
Example #10
class Trainer(object):
    def __init__(self, args):
        self.use_cuda = args.cuda and torch.cuda.is_available()
        self.max_epoch = args.max_epoch
        self.global_epoch = 0
        self.global_iter = 0

        self.z_dim = args.z_dim
        self.z_var = args.z_var
        self.z_sigma = math.sqrt(args.z_var)
        self.prior_dist = torch.distributions.Normal(
            torch.zeros(self.z_dim),
            torch.ones(self.z_dim) * self.z_sigma)
        self._lambda = args.reg_weight
        self.lr = args.lr
        self.lr_D = args.lr_D
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.lr_schedules = {30: 2, 50: 5, 100: 10}

        if args.dataset.lower() == 'celeba':
            self.nc = 3
            self.decoder_dist = 'gaussian'
        else:
            self.nc = 1
            self.decoder_dist = 'gaussian'
            # raise NotImplementedError

        self.net = cuda(WAE(self.z_dim, self.nc), self.use_cuda)
        self.optim = optim.Adam(self.net.parameters(),
                                lr=self.lr,
                                betas=(self.beta1, self.beta2))

        self.D = cuda(Adversary(self.z_dim), self.use_cuda)
        self.optim_D = optim.Adam(self.D.parameters(),
                                  lr=self.lr_D,
                                  betas=(self.beta1, self.beta2))

        self.gather = DataGather()
        self.viz_name = args.viz_name
        self.viz_port = args.viz_port
        self.viz_on = args.viz_on
        if self.viz_on:
            self.viz = visdom.Visdom(env=self.viz_name + '_lines',
                                     port=self.viz_port)
            self.win_recon = None
            self.win_QD = None
            self.win_D = None
            self.win_mu = None
            self.win_var = None
        else:
            self.viz = None
            self.win_recon = None
            self.win_QD = None
            self.win_D = None
            self.win_mu = None
            self.win_var = None

        self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.viz_name)
        if not self.ckpt_dir.exists():
            self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.ckpt_name = args.ckpt_name
        if self.ckpt_name is not None:
            self.load_checkpoint(self.ckpt_name)

        self.save_output = args.save_output
        self.output_dir = Path(args.output_dir).joinpath(args.viz_name)
        if not self.output_dir.exists():
            self.output_dir.mkdir(parents=True, exist_ok=True)

        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)

    def train(self):
        self.net.train()

        ones = Variable(cuda(torch.ones(self.batch_size, 1), self.use_cuda))
        zeros = Variable(cuda(torch.zeros(self.batch_size, 1), self.use_cuda))

        iters_per_epoch = len(self.data_loader)
        max_iter = self.max_epoch * iters_per_epoch
        with tqdm(total=max_iter) as pbar:
            pbar.update(self.global_iter)
            out = False
            while not out:
                for x in self.data_loader:
                    #x,label = x
                    pbar.update(1)
                    self.global_iter += 1
                    if self.global_iter % iters_per_epoch == 0:
                        self.global_epoch += 1
                    self.optim = multistep_lr_decay(self.optim,
                                                    self.global_epoch,
                                                    self.lr_schedules)
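                    # lr_schedules presumably maps epoch -> divisor for
                    # multistep_lr_decay (an assumption about that helper):
                    # with {30: 2, 50: 5, 100: 10}, the base lr is divided by
                    # 2 from epoch 30, by 5 from 50, and by 10 from 100.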

                    x = Variable(cuda(x, self.use_cuda))
                    x_recon, z_tilde = self.net(x)
                    z = self.sample_z(template=z_tilde, sigma=self.z_sigma)
                    log_p_z = log_density_igaussian(z, self.z_var).view(-1, 1)

                    #D_z = self.D(z) + log_p_z.view(-1, 1)
                    #D_z_tilde = self.D(z_tilde) + log_p_z.view(-1, 1)
                    D_z = self.D(z)
                    D_z_tilde = self.D(z_tilde)
                    D_loss = F.binary_cross_entropy_with_logits(D_z+log_p_z, ones) + \
                             F.binary_cross_entropy_with_logits(D_z_tilde+log_p_z, zeros)
                    total_D_loss = self._lambda * D_loss
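                    # the logits are shifted by log p(z) before the cross-
                    # entropy, so D only has to model the part of the log-density
                    # ratio not already explained by the known Gaussian prior
                    # (cf. the commented-out variant above)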

                    self.optim_D.zero_grad()
                    total_D_loss.backward(retain_graph=True)
                    self.optim_D.step()

                    recon_loss = F.mse_loss(
                        x_recon, x, reduction='sum').div(self.batch_size)
                    Q_loss = F.binary_cross_entropy_with_logits(
                        D_z_tilde + log_p_z, ones)
                    total_AE_loss = recon_loss + self._lambda * Q_loss

                    self.optim.zero_grad()
                    total_AE_loss.backward()
                    self.optim.step()

                    if self.global_iter % 10 == 0:
                        self.gather.insert(
                            iter=self.global_iter,
                            D_z=torch.sigmoid(D_z).mean().detach().data,
                            D_z_tilde=torch.sigmoid(
                                D_z_tilde).mean().detach().data,
                            mu=z.mean(0).data,
                            var=z.var(0).data,
                            recon_loss=recon_loss.data,
                            Q_loss=Q_loss.data,
                            D_loss=D_loss.data)

                    if self.global_iter % 50 == 0:
                        self.save_reconstruction()
                        if self.viz:
                            self.gather.insert(images=x.data)
                            self.gather.insert(images=x_recon.data)
                            self.viz_reconstruction()
                            self.viz_lines()
                            self.sample_x_from_z(n_sample=100)
                            self.gather.flush()
                            self.save_checkpoint('last')
                            pbar.write(
                                '[{}] recon_loss:{:.3f} Q_loss:{:.3f} D_loss:{:.3f}'
                                .format(self.global_iter, recon_loss.item(),
                                        Q_loss.item(), D_loss.item()))
                            pbar.write('D_z:{:.3f} D_z_tilde:{:.3f}'.format(
                                torch.sigmoid(D_z).mean().item(),
                                torch.sigmoid(D_z_tilde).mean().item()))

                    if self.global_iter % 2000 == 0:
                        self.save_checkpoint(str(self.global_iter))

                    if self.global_iter >= max_iter:
                        out = True
                        break

            pbar.write("[Training Finished]")

    def viz_reconstruction(self):
        self.net.eval()
        x = self.gather.data['images'][0][:100]
        x = make_grid(x, normalize=True, nrow=10)
        x_recon = torch.sigmoid(self.gather.data['images'][1][:100])
        x_recon = make_grid(x_recon, normalize=True, nrow=10)
        images = torch.stack([x, x_recon], dim=0).cpu()
        if self.viz:
            self.viz.images(images,
                            env=self.viz_name + '_reconstruction',
                            opts=dict(title=str(self.global_iter)),
                            nrow=2)
        self.net.train()

    def save_reconstruction(self):
        self.net.eval()
        import numpy as np
        for item in self.data_loader:
            x = Variable(cuda(item, self.use_cuda))
            x_recon, z_tilde = self.net(x)
            x_recon = x_recon.data[:5]
            x = x.data[:5]
            #x_grid = make_grid(x, normalize=True, nrow=10)
            #x_recon = F.sigmoid(x_recon)
            #x_grid_recon = make_grid(x_recon, normalize=True, nrow=10)
            #images = torch.stack([x_grid, x_grid_recon], dim=0).cpu()
            images = torch.stack([x, x_recon], dim=0).cpu()
            np.save('reconstruction.npy', images.numpy())
            break
        self.net.train()

    def viz_lines(self):
        self.net.eval()
        recon_losses = torch.stack(self.gather.data['recon_loss']).cpu()
        Q_losses = torch.stack(self.gather.data['Q_loss']).cpu()
        D_losses = torch.stack(self.gather.data['D_loss']).cpu()
        QD_losses = torch.stack([Q_losses, D_losses], -1)  # (#items x 2)
        D_zs = torch.stack(self.gather.data['D_z']).cpu()
        D_z_tildes = torch.stack(self.gather.data['D_z_tilde']).cpu()
        Ds = torch.stack([D_zs, D_z_tildes], -1)  # (#items x 2)
        mus = torch.stack(self.gather.data['mu']).cpu()
        vars = torch.stack(self.gather.data['var']).cpu()
        iters = torch.Tensor(self.gather.data['iter'])

        legend_z = []
        for z_j in range(self.z_dim):
            legend_z.append('z_{}'.format(z_j))

        legend_QD = ['Q_loss', 'D_loss']
        legend_D = ['D(z)', 'D(z_tilde)']

        if self.win_recon is None:
            self.win_recon = self.viz.line(X=iters,
                                           Y=recon_losses,
                                           env=self.viz_name + '_lines',
                                           opts=dict(
                                               width=400,
                                               height=400,
                                               xlabel='iteration',
                                               title='reconstruction loss',
                                           ))
        else:
            self.win_recon = self.viz.line(X=iters,
                                           Y=recon_losses,
                                           env=self.viz_name + '_lines',
                                           win=self.win_recon,
                                           update='append',
                                           opts=dict(
                                               width=400,
                                               height=400,
                                               xlabel='iteration',
                                               title='reconstruction loss',
                                           ))

        if self.win_QD is None:
            self.win_QD = self.viz.line(X=iters,
                                        Y=QD_losses,
                                        env=self.viz_name + '_lines',
                                        opts=dict(
                                            width=400,
                                            height=400,
                                            legend=legend_QD,
                                            xlabel='iteration',
                                            title='Q&D Losses',
                                        ))
        else:
            self.win_QD = self.viz.line(X=iters,
                                        Y=QD_losses,
                                        env=self.viz_name + '_lines',
                                        win=self.win_QD,
                                        update='append',
                                        opts=dict(
                                            width=400,
                                            height=400,
                                            legend=legend_QD,
                                            xlabel='iteration',
                                            title='Q&D Losses',
                                        ))

        if self.win_D is None:
            self.win_D = self.viz.line(X=iters,
                                       Y=Ds,
                                       env=self.viz_name + '_lines',
                                       opts=dict(
                                           width=400,
                                           height=400,
                                           legend=legend_D,
                                           xlabel='iteration',
                                           title='D(.)',
                                       ))
        else:
            self.win_D = self.viz.line(X=iters,
                                       Y=Ds,
                                       env=self.viz_name + '_lines',
                                       win=self.win_D,
                                       update='append',
                                       opts=dict(
                                           width=400,
                                           height=400,
                                           legend=legend_D,
                                           xlabel='iteration',
                                           title='D(.)',
                                       ))

        if self.win_mu is None:
            self.win_mu = self.viz.line(X=iters,
                                        Y=mus,
                                        env=self.viz_name + '_lines',
                                        opts=dict(
                                            width=400,
                                            height=400,
                                            legend=legend_z,
                                            xlabel='iteration',
                                            title='posterior mean',
                                        ))
        else:
            self.win_mu = self.viz.line(X=iters,
                                        Y=mus,
                                        env=self.viz_name + '_lines',
                                        win=self.win_mu,
                                        update='append',
                                        opts=dict(
                                            width=400,
                                            height=400,
                                            legend=legend_z,
                                            xlabel='iteration',
                                            title='posterior mean',
                                        ))

        if self.win_var is None:
            self.win_var = self.viz.line(X=iters,
                                         Y=vars,
                                         env=self.viz_name + '_lines',
                                         opts=dict(
                                             width=400,
                                             height=400,
                                             legend=legend_z,
                                             xlabel='iteration',
                                             title='posterior variance',
                                         ))
        else:
            self.win_var = self.viz.line(X=iters,
                                         Y=vars,
                                         env=self.viz_name + '_lines',
                                         win=self.win_var,
                                         update='append',
                                         opts=dict(
                                             width=400,
                                             height=400,
                                             legend=legend_z,
                                             xlabel='iteration',
                                             title='posterior variance',
                                         ))
        self.net.train()

    def sample_z(self, n_sample=None, dim=None, sigma=None, template=None):
        if n_sample is None:
            n_sample = self.batch_size
        if dim is None:
            dim = self.z_dim
        if sigma is None:
            sigma = self.z_sigma

        if template is not None:
            z = sigma * Variable(template.data.new(template.size()).normal_())
        else:
            z = sigma * torch.randn(n_sample, dim)
            z = Variable(cuda(z, self.use_cuda))

        return z

    def sample_x_from_z(self, n_sample):
        self.net.eval()
        z = self.sample_z(n_sample=n_sample, sigma=self.z_sigma)
        x_gen = torch.sigmoid(self.net._decode(z)[:100]).data.cpu()
        x_gen = make_grid(x_gen, normalize=True, nrow=10)
        self.viz.images(x_gen,
                        env=self.viz_name + '_sampling_from_random_z',
                        opts=dict(title=str(self.global_iter)))
        self.net.train()

    def save_checkpoint(self, filename, silent=True):
        model_states = {
            'net': self.net.state_dict(),
            'D': self.D.state_dict(),
        }
        optim_states = {
            'optim': self.optim.state_dict(),
            'optim_D': self.optim_D.state_dict()
        }
        win_states = {
            'recon': self.win_recon,
            'QD': self.win_QD,
            'D': self.win_D,
            'mu': self.win_mu,
            'var': self.win_var,
        }
        states = {
            'iter': self.global_iter,
            'epoch': self.global_epoch,
            'win_states': win_states,
            'model_states': model_states,
            'optim_states': optim_states
        }

        file_path = self.ckpt_dir.joinpath(filename)
        torch.save(states, file_path.open('wb+'))
        if not silent:
            print("=> saved checkpoint '{}' (iter {})".format(
                file_path, self.global_iter))

    def load_checkpoint(self, filename, silent=False):
        file_path = self.ckpt_dir.joinpath(filename)
        print(file_path)
        if file_path.is_file():
            checkpoint = torch.load(file_path.open('rb'))
            self.global_iter = checkpoint['iter']
            self.global_epoch = checkpoint['epoch']
            self.win_recon = checkpoint['win_states']['recon']
            self.win_QD = checkpoint['win_states']['QD']
            self.win_D = checkpoint['win_states']['D']
            self.win_var = checkpoint['win_states']['var']
            self.win_mu = checkpoint['win_states']['mu']
            self.net.load_state_dict(checkpoint['model_states']['net'])
            self.optim.load_state_dict(checkpoint['optim_states']['optim'])
            self.D.load_state_dict(checkpoint['model_states']['D'])
            self.optim_D.load_state_dict(checkpoint['optim_states']['optim_D'])
            if not silent:
                print("=> loaded checkpoint '{} (iter {})'".format(
                    file_path, self.global_iter))
        else:
            if not silent:
                print("=> no checkpoint found at '{}'".format(file_path))
Example #11
    def __init__(self, args):

        self.args = args
        args.num_sg = args.load_e
        self.name = '%s_bs%s_zD_%s_dr_mlp_%s_dr_rnn_%s_enc_hD_%s_dec_hD_%s_mlpD_%s_lr_%s_klw_%s_ll_prior_w_%s_zfb_%s_scale_%s_num_sg_%s_' \
                    'ctxtD_%s_coll_th_%s_w_coll_%s_beta_%s_lr_e_%s_k_%s' % \
                    (args.dataset_name, args.batch_size, args.zS_dim, args.dropout_mlp, args.dropout_rnn, args.encoder_h_dim,
                     args.decoder_h_dim, args.mlp_dim, args.lr_VAE, args.kl_weight,
                     args.ll_prior_w, args.fb, args.scale, args.num_sg, args.context_dim, args.coll_th, args.w_coll, args.beta, args.lr_e, args.k_fold)

        # to be appended by run_id

        # self.use_cuda = args.cuda and torch.cuda.is_available()
        self.device = args.device
        self.temp = 1.99
        self.dt = 0.4
        self.eps = 1e-9
        self.ll_prior_w = args.ll_prior_w
        self.sg_idx = np.array(range(12))
        self.sg_idx = np.flip(11 - self.sg_idx[::(12 // args.num_sg)])
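        # e.g., with args.num_sg = 3: range(12)[::4] = [0, 4, 8],
        # 11 - [0, 4, 8] = [11, 7, 3], and the flip yields sg_idx = [3, 7, 11],
        # i.e. evenly spaced sub-goal timesteps ending at the last prediction step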

        self.coll_th = args.coll_th
        self.beta = args.beta
        self.context_dim = args.context_dim
        self.w_coll = args.w_coll

        self.z_fb = args.fb
        self.scale = args.scale

        self.kl_weight = args.kl_weight
        self.lg_kl_weight = args.lg_kl_weight

        self.max_iter = int(args.max_iter)

        # do it every specified iters
        self.print_iter = args.print_iter
        self.ckpt_save_iter = args.ckpt_save_iter
        self.output_save_iter = args.output_save_iter

        # data info
        args.dataset_dir = os.path.join(args.dataset_dir, str(args.k_fold))

        self.dataset_dir = args.dataset_dir
        self.dataset_name = args.dataset_name

        # self.N = self.latent_values.shape[0]
        # self.eval_metrics_iter = args.eval_metrics_iter

        # networks and optimizers
        self.batch_size = args.batch_size
        self.zS_dim = args.zS_dim
        self.w_dim = args.w_dim
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE
        print(args.desc)

        # create dirs: "records", "ckpts", "outputs" (if not exist)
        mkdirs("records")
        mkdirs("ckpts")
        mkdirs("outputs")

        # set run id
        if args.run_id < 0:  # create a new id
            k = 0
            rfname = os.path.join("records", self.name + '_run_0.txt')
            while os.path.exists(rfname):
                k += 1
                rfname = os.path.join("records", self.name + '_run_%d.txt' % k)
            self.run_id = k
        else:  # user-provided id
            self.run_id = args.run_id

        # finalize name
        self.name = self.name + '_run_' + str(self.run_id)

        # checkpoints
        self.ckpt_dir = os.path.join("ckpts", self.name)

        # visdom setup
        self.viz_on = args.viz_on
        if self.viz_on:
            self.win_id = dict(recon='win_recon',
                               loss_kl='win_loss_kl',
                               loss_recon='win_loss_recon',
                               ade_min='win_ade_min',
                               fde_min='win_fde_min',
                               ade_avg='win_ade_avg',
                               fde_avg='win_fde_avg',
                               ade_std='win_ade_std',
                               fde_std='win_fde_std',
                               test_loss_recon='win_test_loss_recon',
                               test_loss_kl='win_test_loss_kl',
                               loss_recon_prior='win_loss_recon_prior',
                               loss_coll='win_loss_coll',
                               test_loss_coll='win_test_loss_coll',
                               test_total_coll='win_test_total_coll',
                               total_coll='win_total_coll')
            self.line_gather = DataGather(
                'iter', 'loss_recon', 'loss_kl', 'loss_recon_prior', 'ade_min',
                'fde_min', 'ade_avg', 'fde_avg', 'ade_std', 'fde_std',
                'test_loss_recon', 'test_loss_kl', 'test_loss_coll',
                'loss_coll', 'test_total_coll', 'total_coll')

            self.viz_port = args.viz_port  # port number, eg, 8097
            self.viz = visdom.Visdom(port=self.viz_port, env=self.name)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter

            self.viz_init()
        #### create a new model or load a previously saved model

        self.ckpt_load_iter = args.ckpt_load_iter

        self.obs_len = 8
        self.pred_len = 12
        self.num_layers = args.num_layers
        self.decoder_h_dim = args.decoder_h_dim

        if self.ckpt_load_iter == 0 or args.dataset_name == 'all':  # create a new model
            lg_cvae_path = 'large.lgcvae_enc_block_1_fcomb_block_2_wD_10_lr_0.0001_lg_klw_1.0_a_0.25_r_2.0_fb_5.0_anneal_e_10_load_e_3_run_4'
            lg_cvae_path = os.path.join('ckpts', lg_cvae_path,
                                        'iter_150_lg_cvae.pt')
            if self.device == 'cuda':
                self.lg_cvae = torch.load(lg_cvae_path)

            self.encoderMx = EncoderX(args.zS_dim,
                                      enc_h_dim=args.encoder_h_dim,
                                      mlp_dim=args.mlp_dim,
                                      map_mlp_dim=args.map_mlp_dim,
                                      map_feat_dim=args.map_feat_dim,
                                      num_layers=args.num_layers,
                                      dropout_mlp=args.dropout_mlp,
                                      dropout_rnn=args.dropout_rnn,
                                      device=self.device).to(self.device)
            self.encoderMy = EncoderY(args.zS_dim,
                                      enc_h_dim=args.encoder_h_dim,
                                      mlp_dim=args.mlp_dim,
                                      num_layers=args.num_layers,
                                      dropout_mlp=args.dropout_mlp,
                                      dropout_rnn=args.dropout_rnn,
                                      device=self.device).to(self.device)
            self.decoderMy = Decoder(args.pred_len,
                                     dec_h_dim=self.decoder_h_dim,
                                     enc_h_dim=args.encoder_h_dim,
                                     mlp_dim=args.mlp_dim,
                                     z_dim=args.zS_dim,
                                     num_layers=args.num_layers,
                                     device=args.device,
                                     dropout_rnn=args.dropout_rnn,
                                     scale=args.scale,
                                     dt=self.dt,
                                     context_dim=args.context_dim).to(
                                         self.device)

        else:  # load a previously saved model
            print('Loading saved models (iter: %d)...' % self.ckpt_load_iter)
            self.load_checkpoint()
            print('...done')

        # get VAE parameters
        vae_params = \
            list(self.encoderMx.parameters()) + \
            list(self.encoderMy.parameters()) + \
            list(self.decoderMy.parameters())
        # create optimizers
        self.optim_vae = optim.Adam(vae_params,
                                    lr=self.lr_VAE,
                                    betas=[self.beta1_VAE, self.beta2_VAE])

        self.scheduler = optim.lr_scheduler.LambdaLR(
            optimizer=self.optim_vae, lr_lambda=lambda epoch: args.lr_e**epoch)

        print('Start loading data...')

        if self.ckpt_load_iter != self.max_iter:
            print("Initializing train dataset")
            _, self.train_loader = data_loader(self.args,
                                               args.dataset_dir,
                                               'train',
                                               shuffle=True)
            print("Initializing val dataset")
            _, self.val_loader = data_loader(self.args,
                                             args.dataset_dir,
                                             'val',
                                             shuffle=True)

            print('There are {} iterations per epoch'.format(
                len(self.train_loader.dataset) / args.batch_size))
        print('...done')
Example #12
class Solver(object):

    ####
    def __init__(self, args):

        self.args = args
        args.num_sg = args.load_e  # num_sg is overridden by load_e
        self.name = '%s_bs%s_zD_%s_dr_mlp_%s_dr_rnn_%s_enc_hD_%s_dec_hD_%s_mlpD_%s_lr_%s_klw_%s_ll_prior_w_%s_zfb_%s_scale_%s_num_sg_%s_' \
                    'ctxtD_%s_coll_th_%s_w_coll_%s_beta_%s_lr_e_%s_k_%s' % \
                    (args.dataset_name, args.batch_size, args.zS_dim, args.dropout_mlp, args.dropout_rnn, args.encoder_h_dim,
                     args.decoder_h_dim, args.mlp_dim, args.lr_VAE, args.kl_weight,
                     args.ll_prior_w, args.fb, args.scale, args.num_sg, args.context_dim, args.coll_th, args.w_coll, args.beta, args.lr_e, args.k_fold)

        # to be appended by run_id

        # self.use_cuda = args.cuda and torch.cuda.is_available()
        self.device = args.device
        self.temp = 1.99
        self.dt = 0.4
        self.eps = 1e-9
        self.ll_prior_w = args.ll_prior_w
        self.sg_idx = np.array(range(12))
        self.sg_idx = np.flip(11 - self.sg_idx[::(12 // args.num_sg)])
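        # The two lines above pick num_sg evenly spaced sub-goal steps that end
        # at the final prediction step, e.g. num_sg=3: range(12)[::4] = [0, 4, 8],
        # then 11 - [0, 4, 8] = [11, 7, 3], and np.flip gives [3, 7, 11].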

        self.coll_th = args.coll_th
        self.beta = args.beta
        self.context_dim = args.context_dim
        self.w_coll = args.w_coll

        self.z_fb = args.fb
        self.scale = args.scale

        self.kl_weight = args.kl_weight
        self.lg_kl_weight = args.lg_kl_weight

        self.max_iter = int(args.max_iter)

        # do it every specified iters
        self.print_iter = args.print_iter
        self.ckpt_save_iter = args.ckpt_save_iter
        self.output_save_iter = args.output_save_iter

        # data info
        args.dataset_dir = os.path.join(args.dataset_dir, str(args.k_fold))

        self.dataset_dir = args.dataset_dir
        self.dataset_name = args.dataset_name

        # self.N = self.latent_values.shape[0]
        # self.eval_metrics_iter = args.eval_metrics_iter

        # networks and optimizers
        self.batch_size = args.batch_size
        self.zS_dim = args.zS_dim
        self.w_dim = args.w_dim
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE
        print(args.desc)

        # create dirs: "records", "ckpts", "outputs" (if not exist)
        mkdirs("records")
        mkdirs("ckpts")
        mkdirs("outputs")

        # set run id
        if args.run_id < 0:  # create a new id
            k = 0
            rfname = os.path.join("records", self.name + '_run_0.txt')
            while os.path.exists(rfname):
                k += 1
                rfname = os.path.join("records", self.name + '_run_%d.txt' % k)
            self.run_id = k
        else:  # user-provided id
            self.run_id = args.run_id

        # finalize name
        self.name = self.name + '_run_' + str(self.run_id)

        # checkpoints
        self.ckpt_dir = os.path.join("ckpts", self.name)

        # visdom setup
        self.viz_on = args.viz_on
        if self.viz_on:
            self.win_id = dict(recon='win_recon',
                               loss_kl='win_loss_kl',
                               loss_recon='win_loss_recon',
                               ade_min='win_ade_min',
                               fde_min='win_fde_min',
                               ade_avg='win_ade_avg',
                               fde_avg='win_fde_avg',
                               ade_std='win_ade_std',
                               fde_std='win_fde_std',
                               test_loss_recon='win_test_loss_recon',
                               test_loss_kl='win_test_loss_kl',
                               loss_recon_prior='win_loss_recon_prior',
                               loss_coll='win_loss_coll',
                               test_loss_coll='win_test_loss_coll',
                               test_total_coll='win_test_total_coll',
                               total_coll='win_total_coll')
            self.line_gather = DataGather(
                'iter', 'loss_recon', 'loss_kl', 'loss_recon_prior', 'ade_min',
                'fde_min', 'ade_avg', 'fde_avg', 'ade_std', 'fde_std',
                'test_loss_recon', 'test_loss_kl', 'test_loss_coll',
                'loss_coll', 'test_total_coll', 'total_coll')

            self.viz_port = args.viz_port  # port number, e.g., 8097
            self.viz = visdom.Visdom(port=self.viz_port, env=self.name)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter

            self.viz_init()
        #### create a new model or load a previously saved model

        self.ckpt_load_iter = args.ckpt_load_iter

        self.obs_len = 8
        self.pred_len = 12
        self.num_layers = args.num_layers
        self.decoder_h_dim = args.decoder_h_dim

        if self.ckpt_load_iter == 0 or args.dataset_name == 'all':  # create a new model
            lg_cvae_path = 'large.lgcvae_enc_block_1_fcomb_block_2_wD_10_lr_0.0001_lg_klw_1.0_a_0.25_r_2.0_fb_5.0_anneal_e_10_load_e_3_run_4'
            lg_cvae_path = os.path.join('ckpts', lg_cvae_path,
                                        'iter_150_lg_cvae.pt')
            if self.device == 'cuda':
                self.lg_cvae = torch.load(lg_cvae_path)
            else:
                self.lg_cvae = torch.load(lg_cvae_path, map_location='cpu')

            self.encoderMx = EncoderX(args.zS_dim,
                                      enc_h_dim=args.encoder_h_dim,
                                      mlp_dim=args.mlp_dim,
                                      map_mlp_dim=args.map_mlp_dim,
                                      map_feat_dim=args.map_feat_dim,
                                      num_layers=args.num_layers,
                                      dropout_mlp=args.dropout_mlp,
                                      dropout_rnn=args.dropout_rnn,
                                      device=self.device).to(self.device)
            self.encoderMy = EncoderY(args.zS_dim,
                                      enc_h_dim=args.encoder_h_dim,
                                      mlp_dim=args.mlp_dim,
                                      num_layers=args.num_layers,
                                      dropout_mlp=args.dropout_mlp,
                                      dropout_rnn=args.dropout_rnn,
                                      device=self.device).to(self.device)
            self.decoderMy = Decoder(args.pred_len,
                                     dec_h_dim=self.decoder_h_dim,
                                     enc_h_dim=args.encoder_h_dim,
                                     mlp_dim=args.mlp_dim,
                                     z_dim=args.zS_dim,
                                     num_layers=args.num_layers,
                                     device=args.device,
                                     dropout_rnn=args.dropout_rnn,
                                     scale=args.scale,
                                     dt=self.dt,
                                     context_dim=args.context_dim).to(
                                         self.device)

        else:  # load a previously saved model
            print('Loading saved models (iter: %d)...' % self.ckpt_load_iter)
            self.load_checkpoint()
            print('...done')

        # get VAE parameters
        vae_params = \
            list(self.encoderMx.parameters()) + \
            list(self.encoderMy.parameters()) + \
            list(self.decoderMy.parameters())
        # create optimizers
        self.optim_vae = optim.Adam(vae_params,
                                    lr=self.lr_VAE,
                                    betas=(self.beta1_VAE, self.beta2_VAE))

        self.scheduler = optim.lr_scheduler.LambdaLR(
            optimizer=self.optim_vae, lr_lambda=lambda epoch: args.lr_e**epoch)
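        # LambdaLR multiplies the base lr by lr_e**k after k scheduler steps;
        # train() steps it every 10 epochs until the lr hits the 5e-4 floor.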

        print('Start loading data...')

        if self.ckpt_load_iter != self.max_iter:
            print("Initializing train dataset")
            _, self.train_loader = data_loader(self.args,
                                               args.dataset_dir,
                                               'train',
                                               shuffle=True)
            print("Initializing val dataset")
            _, self.val_loader = data_loader(self.args,
                                             args.dataset_dir,
                                             'val',
                                             shuffle=True)

            print('There are {} iterations per epoch'.format(
                len(self.train_loader.dataset) // args.batch_size))
        print('...done')

    def unet_down_forward_test(self):
        # Debug helper: smoke-tests the lg_cvae U-Net encoder on a dummy batch
        # of 120 two-channel 256x256 maps.
        aa = torch.zeros((120, 2, 256, 256)).to(self.device)
        self.lg_cvae.unet.down_forward(aa)

    ####
    def train(self):
        self.set_mode(train=True)
        data_loader = self.train_loader

        self.N = len(data_loader.dataset)
        iterator = iter(data_loader)

        iter_per_epoch = len(iterator)
        start_iter = self.ckpt_load_iter + 1
        epoch = int(start_iter / iter_per_epoch) + 1

        e_coll_loss = 0
        e_total_coll = 0

        for iteration in range(start_iter, self.max_iter + 1):

            # reset data iterators for each epoch
            if iteration % iter_per_epoch == 0:
                # print(iteration)
                print('==== epoch %d done ====' % epoch)
                if epoch % 10 == 0:
                    if self.optim_vae.param_groups[0]['lr'] > 5e-4:
                        self.scheduler.step()
                    else:
                        self.optim_vae.param_groups[0]['lr'] = 5e-4
                print("lr: ", self.optim_vae.param_groups[0]['lr'],
                      ' // w_coll: ', self.w_coll)
                print('e_coll_loss: ', e_coll_loss, ' // e_total_coll: ',
                      e_total_coll)

                epoch += 1
                iterator = iter(data_loader)
                prev_e_coll_loss = e_coll_loss
                prev_e_total_coll = e_total_coll
                e_coll_loss = 0
                e_total_coll = 0

            # ============================================
            #          TRAIN THE VAE (ENC & DEC)
            # ============================================

            (obs_traj, fut_traj, obs_traj_st, fut_vel_st, seq_start_end,
             obs_frames, fut_frames, map_path, inv_h_t, local_map, local_ic,
             local_homo) = next(iterator)
            batch_size = obs_traj.size(1)  # = sum(seq_start_end[:, 1] - seq_start_end[:, 0])

            #-------- trajectories --------
            (hx, mux, log_varx) \
                = self.encoderMx(obs_traj_st, seq_start_end, train=True)


            (muy, log_vary) \
                = self.encoderMy(obs_traj_st[-1], fut_vel_st, seq_start_end, hx, train=True)

            p_dist = Normal(
                mux, torch.clamp(torch.sqrt(torch.exp(log_varx)), min=1e-8))
            q_dist = Normal(
                muy, torch.clamp(torch.sqrt(torch.exp(log_vary)), min=1e-8))
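            # Both Normals use std = sqrt(exp(log_var)), clamped at 1e-8 so the
            # distributions stay valid if the predicted variance underflows.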

            # TF, goals, z~posterior
            fut_rel_pos_dist_tf_post = self.decoderMy(
                seq_start_end,
                obs_traj_st[-1],
                obs_traj[-1, :, :2],
                hx,
                q_dist.rsample(),
                fut_traj[list(self.sg_idx), :, :2].permute(1, 0, 2),  # goal
                self.sg_idx,
                fut_vel_st,  # TF
                train=True)

            # NO TF, predicted goals, z~prior
            fut_rel_pos_dist_prior = self.decoderMy(
                seq_start_end,
                obs_traj_st[-1],
                obs_traj[-1, :, :2],
                hx,
                p_dist.rsample(),
                fut_traj[list(self.sg_idx), :, :2].permute(1, 0, 2),  # goal
                self.sg_idx,
                train=True)

            ll_tf_post = fut_rel_pos_dist_tf_post.log_prob(
                fut_vel_st).sum().div(batch_size)
            ll_prior = fut_rel_pos_dist_prior.log_prob(fut_vel_st).sum().div(
                batch_size)

            loss_kl = kl_divergence(q_dist, p_dist)
            loss_kl = torch.clamp(loss_kl, min=self.z_fb).sum().div(batch_size)
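            # Free bits: the clamp above floors each latent dimension's KL at
            # z_fb, so dimensions already below the floor get no gradient
            # pressure (a guard against posterior collapse).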
            # print('log_likelihood:', loglikelihood.item(), ' kl:', loss_kl.item())

            loglikelihood = ll_tf_post + self.ll_prior_w * ll_prior
            traj_elbo = loglikelihood - self.kl_weight * loss_kl

            coll_loss = torch.tensor(0.0).to(self.device)
            total_coll = 0
            n_scene = 0

            if self.w_coll > 0:
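                # Differentiable collision penalty: integrate sampled velocities
                # into trajectories, then for each scene and timestep penalize
                # every agent pair with sigmoid(-(d - coll_th) * beta), which
                # approaches 1 as the pair distance d drops below coll_th.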
                pred_fut_traj = integrate_samples(
                    fut_rel_pos_dist_prior.rsample() * self.scale,
                    obs_traj[-1, :, :2],
                    dt=self.dt)

                pred_fut_traj_post = integrate_samples(
                    fut_rel_pos_dist_tf_post.rsample() * self.scale,
                    obs_traj[-1, :, :2],
                    dt=self.dt)
                for s, e in seq_start_end:
                    n_scene += 1
                    num_ped = e - s
                    if num_ped == 1:
                        continue
                    for t in range(self.pred_len):
                        ## prior
                        curr1 = pred_fut_traj[t, s:e].repeat(num_ped, 1)
                        curr2 = self.repeat(pred_fut_traj[t, s:e], num_ped)
                        dist = torch.norm(curr1 - curr2, dim=1)
                        dist = dist.reshape(num_ped, num_ped)
                        diff_agent_dist = dist[torch.where(dist > 0)]
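                        # Each unordered pair appears twice in the num_ped x
                        # num_ped distance matrix, hence the division by 2 in
                        # the collision counts below.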
                        coll_loss += (torch.sigmoid(
                            -(diff_agent_dist - self.coll_th) *
                            self.beta)).sum()
                        total_coll += (
                            len(torch.where(diff_agent_dist < 0.5)[0]) / 2)
                        ## posterior
                        curr1_post = pred_fut_traj_post[t, s:e].repeat(
                            num_ped, 1)
                        curr2_post = self.repeat(pred_fut_traj_post[t, s:e],
                                                 num_ped)
                        dist_post = torch.norm(curr1_post - curr2_post, dim=1)
                        dist_post = dist_post.reshape(num_ped, num_ped)
                        diff_agent_dist_post = dist_post[torch.where(
                            dist_post > 0)]
                        coll_loss += (torch.sigmoid(
                            -(diff_agent_dist_post - self.coll_th) *
                            self.beta)).sum()
                        total_coll += (
                            len(torch.where(diff_agent_dist_post < 0.5)[0]) /
                            2)

            coll_loss = coll_loss.div(batch_size)
            total_coll = total_coll / batch_size

            loss = -traj_elbo + self.w_coll * coll_loss
            e_coll_loss += coll_loss.item()
            e_total_coll += total_coll

            self.optim_vae.zero_grad()
            loss.backward()
            self.optim_vae.step()

            # save model parameters
            if epoch > 100 and (iteration % (iter_per_epoch * 20) == 0):
                self.save_checkpoint(epoch)

            # (visdom) insert current line stats
            if epoch > 100:
                if iteration == iter_per_epoch or (
                        self.viz_on and (iteration %
                                         (iter_per_epoch * 20) == 0)):
                    ade_min, fde_min, \
                    ade_avg, fde_avg, \
                    ade_std, fde_std, \
                    test_loss_recon, test_loss_kl, test_loss_coll, test_total_coll = self.evaluate_dist(self.val_loader, loss=True)
                    self.line_gather.insert(
                        iter=epoch,
                        ade_min=ade_min,
                        fde_min=fde_min,
                        ade_avg=ade_avg,
                        fde_avg=fde_avg,
                        ade_std=ade_std,
                        fde_std=fde_std,
                        loss_recon=-ll_tf_post.item(),
                        loss_recon_prior=-ll_prior.item(),
                        loss_kl=loss_kl.item(),
                        loss_coll=prev_e_coll_loss,
                        total_coll=prev_e_total_coll,
                        test_loss_recon=test_loss_recon.item(),
                        test_loss_kl=test_loss_kl.item(),
                        test_loss_coll=test_loss_coll.item(),
                        test_total_coll=test_total_coll)
                    prn_str = ('[iter_%d (epoch_%d)] vae_loss: %.3f ' + \
                                  '(recon: %.3f, kl: %.3f)\n' + \
                                  'ADE min: %.2f, FDE min: %.2f, ADE avg: %.2f, FDE avg: %.2f\n'
                              ) % \
                              (iteration, epoch,
                               loss.item(), -loglikelihood.item(), loss_kl.item(),
                               ade_min, fde_min, ade_avg, fde_avg
                               )

                    print(prn_str)
                    self.visualize_line()
                    self.line_gather.flush()

    def repeat(self, tensor, num_reps):
        """
        Inputs:
        -tensor: 2D tensor of any shape
        -num_reps: Number of times to repeat each row
        Outputs:
        -repeat_tensor: Repeat each row such that: R1, R1, R2, R2
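        Example: rows [r1, r2] with num_reps=2 give [r1, r1, r2, r2]
        (whereas Tensor.repeat would give [r1, r2, r1, r2]).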
        """
        col_len = tensor.size(1)
        tensor = tensor.unsqueeze(dim=1).repeat(1, num_reps, 1)
        tensor = tensor.view(-1, col_len)
        return tensor

    def evaluate_dist(self, data_loader, loss=False):
        self.set_mode(train=False)
        total_traj = 0

        loss_recon = loss_kl = 0
        coll_loss = 0
        total_coll = 0
        n_scene = 0

        all_ade = []
        all_fde = []

        with torch.no_grad():
            b = 0
            for batch in data_loader:
                b += 1
                (obs_traj, fut_traj, obs_traj_st, fut_vel_st, seq_start_end,
                 obs_frames, fut_frames, map_path, inv_h_t, local_map,
                 local_ic, local_homo) = batch
                batch_size = fut_traj.size(1)
                total_traj += fut_traj.size(1)

                # -------- trajectories --------
                (hx, mux, log_varx) \
                    = self.encoderMx(obs_traj_st, seq_start_end)
                p_dist = Normal(
                    mux, torch.clamp(torch.sqrt(torch.exp(log_varx)),
                                     min=1e-8))

                fut_rel_pos_dist20 = []
                for _ in range(4):  # 4 prior samples (the name suggests an earlier 20-sample eval)
                    # NO TF, pred_goals, z~prior
                    fut_rel_pos_dist_prior = self.decoderMy(
                        seq_start_end,
                        obs_traj_st[-1],
                        obs_traj[-1, :, :2],
                        hx,
                        p_dist.rsample(),
                        fut_traj[list(self.sg_idx), :, :2].permute(1, 0,
                                                                   2),  # goal
                        self.sg_idx,
                    )
                    fut_rel_pos_dist20.append(fut_rel_pos_dist_prior)

                if loss:
                    (muy, log_vary) \
                        = self.encoderMy(obs_traj_st[-1], fut_vel_st, seq_start_end, hx, train=False)
                    q_dist = Normal(
                        muy,
                        torch.clamp(torch.sqrt(torch.exp(log_vary)), min=1e-8))

                    loss_recon -= fut_rel_pos_dist_prior.log_prob(
                        fut_vel_st).sum().div(batch_size)
                    kld = kl_divergence(q_dist, p_dist).sum().div(batch_size)
                    loss_kl += kld

                    pred_fut_traj = integrate_samples(
                        fut_rel_pos_dist_prior.rsample() * self.scale,
                        obs_traj[-1, :, :2],
                        dt=self.dt)
                    for s, e in seq_start_end:
                        n_scene += 1
                        num_ped = e - s
                        if num_ped == 1:
                            continue
                        seq_traj = pred_fut_traj[:, s:e]
                        for i in range(len(seq_traj)):
                            curr1 = seq_traj[i].repeat(num_ped, 1)
                            curr2 = self.repeat(seq_traj[i], num_ped)
                            dist = torch.norm(curr1 - curr2, dim=1)
                            dist = dist.reshape(num_ped, num_ped)
                            diff_agent_dist = dist[torch.where(dist > 0)]
                            if len(diff_agent_dist) > 0:
                                # diff_agent_dist[torch.where(diff_agent_dist > self.coll_th)] += self.beta
                                coll_loss += (torch.sigmoid(
                                    -(diff_agent_dist - self.coll_th) *
                                    self.beta)).sum().div(batch_size)
                                total_coll += (len(
                                    torch.where(diff_agent_dist < 0.5)[0]) /
                                               2) / batch_size

                ade, fde = [], []
                for dist in fut_rel_pos_dist20:
                    pred_fut_traj = integrate_samples(dist.rsample() *
                                                      self.scale,
                                                      obs_traj[-1, :, :2],
                                                      dt=self.dt)
                    ade.append(
                        displacement_error(pred_fut_traj,
                                           fut_traj[:, :, :2],
                                           mode='raw'))
                    fde.append(
                        final_displacement_error(pred_fut_traj[-1],
                                                 fut_traj[-1, :, :2],
                                                 mode='raw'))
                all_ade.append(torch.stack(ade))
                all_fde.append(torch.stack(fde))

            all_ade = torch.cat(all_ade, dim=1).cpu().numpy()
            all_fde = torch.cat(all_fde, dim=1).cpu().numpy()
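            # Best-of-K (min), mean, and std over the 4 prior samples drawn
            # above; ADE is additionally averaged over the pred_len steps.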

            ade_min = np.min(all_ade, axis=0).mean() / self.pred_len
            fde_min = np.min(all_fde, axis=0).mean()
            ade_avg = np.mean(all_ade, axis=0).mean() / self.pred_len
            fde_avg = np.mean(all_fde, axis=0).mean()
            ade_std = np.std(all_ade, axis=0).mean() / self.pred_len
            fde_std = np.std(all_fde, axis=0).mean()

        self.set_mode(train=True)
        if loss:
            return ade_min, fde_min, \
                   ade_avg, fde_avg, \
                   ade_std, fde_std, \
                   loss_recon/b, loss_kl/b, coll_loss/b, total_coll
        else:
            return ade_min, fde_min, \
                   ade_avg, fde_avg, \
                   ade_std, fde_std

    def collision_stat(self, data_loader):
        self.set_mode(train=False)

        total_coll1 = 0
        total_coll2 = 0
        total_coll3 = 0
        total_coll4 = 0
        total_coll5 = 0
        total_coll6 = 0
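        # total_coll1..6 count agent-pair timesteps whose ground-truth pairwise
        # distance falls below 0.05/0.1/0.2/0.3/0.4/0.5, respectively.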
        n_scene = 0
        total_ped = []
        e_ped = []
        avg_dist = 0
        min_dist = 10000
        n_agent = 0

        with torch.no_grad():
            b = 0
            for batch in data_loader:
                b += 1
                (obs_traj, fut_traj, obs_traj_st, fut_vel_st, seq_start_end,
                 obs_frames, fut_frames, map_path, inv_h_t, local_map,
                 local_ic, local_homo) = batch
                for s, e in seq_start_end:
                    n_scene += 1
                    num_ped = e - s
                    total_ped.append(num_ped)
                    if num_ped == 1:
                        continue
                    e_ped.append(num_ped)

                    seq_traj = fut_traj[:, s:e, :2]
                    for i in range(len(seq_traj)):
                        curr1 = seq_traj[i].repeat(num_ped, 1)
                        curr2 = self.repeat(seq_traj[i], num_ped)
                        dist = torch.sqrt(torch.pow(curr1 - curr2,
                                                    2).sum(1)).cpu().numpy()
                        dist = dist.reshape(num_ped, num_ped)
                        diff_agent_idx = np.triu_indices(num_ped, k=1)
                        diff_agent_dist = dist[diff_agent_idx]
                        avg_dist += diff_agent_dist.sum()
                        min_dist = min(min_dist, diff_agent_dist.min())
                        n_agent += len(diff_agent_dist)
                        total_coll1 += (diff_agent_dist < 0.05).sum()
                        total_coll2 += (diff_agent_dist < 0.1).sum()
                        total_coll3 += (diff_agent_dist < 0.2).sum()
                        total_coll4 += (diff_agent_dist < 0.3).sum()
                        total_coll5 += (diff_agent_dist < 0.4).sum()
                        total_coll6 += (diff_agent_dist < 0.5).sum()
        print('total_coll1: ', total_coll1)
        print('total_coll2: ', total_coll2)
        print('total_coll3: ', total_coll3)
        print('total_coll4: ', total_coll4)
        print('total_coll5: ', total_coll5)
        print('total_coll6: ', total_coll6)
        print('n_scene: ', n_scene)
        print('scenes with >1 ped: ', len(e_ped))
        print('total scenes: ', len(total_ped))
        print('avg_dist:', avg_dist / n_agent)
        print('min_dist:', min_dist)
        print('mean peds per multi-ped scene:', np.array(e_ped).mean())
        print('mean peds per scene:', np.array(total_ped).mean())

    ####
    def viz_init(self):
        self.viz.close(env=self.name, win=self.win_id['loss_recon'])
        self.viz.close(env=self.name, win=self.win_id['loss_recon_prior'])
        self.viz.close(env=self.name, win=self.win_id['loss_kl'])
        self.viz.close(env=self.name, win=self.win_id['test_loss_recon'])
        self.viz.close(env=self.name, win=self.win_id['test_loss_kl'])

        self.viz.close(env=self.name, win=self.win_id['ade_min'])
        self.viz.close(env=self.name, win=self.win_id['fde_min'])
        self.viz.close(env=self.name, win=self.win_id['ade_avg'])
        self.viz.close(env=self.name, win=self.win_id['fde_avg'])
        self.viz.close(env=self.name, win=self.win_id['ade_std'])
        self.viz.close(env=self.name, win=self.win_id['fde_std'])

    ####
    def visualize_line(self):

        # prepare data to plot
        data = self.line_gather.data
        iters = torch.Tensor(data['iter'])
        loss_recon = torch.Tensor(data['loss_recon'])
        loss_recon_prior = torch.Tensor(data['loss_recon_prior'])
        loss_kl = torch.Tensor(data['loss_kl'])
        ade_min = torch.Tensor(data['ade_min'])
        fde_min = torch.Tensor(data['fde_min'])
        ade_avg = torch.Tensor(data['ade_avg'])
        fde_avg = torch.Tensor(data['fde_avg'])
        ade_std = torch.Tensor(data['ade_std'])
        fde_std = torch.Tensor(data['fde_std'])
        test_loss_recon = torch.Tensor(data['test_loss_recon'])
        test_loss_kl = torch.Tensor(data['test_loss_kl'])
        test_loss_coll = torch.Tensor(data['test_loss_coll'])
        loss_coll = torch.Tensor(data['loss_coll'])
        total_coll = torch.Tensor(data['total_coll'])
        test_total_coll = torch.Tensor(data['test_total_coll'])

        self.viz.line(X=iters,
                      Y=total_coll,
                      env=self.name,
                      win=self.win_id['total_coll'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='total_coll',
                                title='total_coll'))

        self.viz.line(X=iters,
                      Y=test_total_coll,
                      env=self.name,
                      win=self.win_id['test_total_coll'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='test_total_coll',
                                title='test_total_coll'))

        self.viz.line(X=iters,
                      Y=test_loss_coll,
                      env=self.name,
                      win=self.win_id['test_loss_coll'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='test_loss_coll',
                                title='test_loss_coll'))

        self.viz.line(X=iters,
                      Y=loss_coll,
                      env=self.name,
                      win=self.win_id['loss_coll'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='loss_coll',
                                title='loss_coll'))
        self.viz.line(X=iters,
                      Y=loss_recon,
                      env=self.name,
                      win=self.win_id['loss_recon'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='-loglikelihood',
                                title='Recon. loss of predicted future traj'))

        self.viz.line(X=iters,
                      Y=loss_recon_prior,
                      env=self.name,
                      win=self.win_id['loss_recon_prior'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='-loglikelihood',
                                title='Recon. loss - prior'))

        self.viz.line(
            X=iters,
            Y=loss_kl,
            env=self.name,
            win=self.win_id['loss_kl'],
            update='append',
            opts=dict(xlabel='iter',
                      ylabel='kl divergence',
                      title='KL div. btw posterior and c. prior'),
        )

        self.viz.line(X=iters,
                      Y=test_loss_recon,
                      env=self.name,
                      win=self.win_id['test_loss_recon'],
                      update='append',
                      opts=dict(
                          xlabel='iter',
                          ylabel='-loglikelihood',
                          title='Test Recon. loss of predicted future traj'))

        self.viz.line(
            X=iters,
            Y=test_loss_kl,
            env=self.name,
            win=self.win_id['test_loss_kl'],
            update='append',
            opts=dict(xlabel='iter',
                      ylabel='kl divergence',
                      title='Test KL div. btw posterior and c. prior'),
        )

        self.viz.line(
            X=iters,
            Y=ade_min,
            env=self.name,
            win=self.win_id['ade_min'],
            update='append',
            opts=dict(xlabel='iter', ylabel='ade', title='ADE min'),
        )
        self.viz.line(
            X=iters,
            Y=fde_min,
            env=self.name,
            win=self.win_id['fde_min'],
            update='append',
            opts=dict(xlabel='iter', ylabel='fde', title='FDE min'),
        )
        self.viz.line(
            X=iters,
            Y=ade_avg,
            env=self.name,
            win=self.win_id['ade_avg'],
            update='append',
            opts=dict(xlabel='iter', ylabel='ade', title='ADE avg'),
        )

        self.viz.line(
            X=iters,
            Y=fde_avg,
            env=self.name,
            win=self.win_id['fde_avg'],
            update='append',
            opts=dict(xlabel='iter', ylabel='fde', title='FDE avg'),
        )
        self.viz.line(
            X=iters,
            Y=ade_std,
            env=self.name,
            win=self.win_id['ade_std'],
            update='append',
            opts=dict(xlabel='iter', ylabel='ade std', title='ADE std'),
        )

        self.viz.line(
            X=iters,
            Y=fde_std,
            env=self.name,
            win=self.win_id['fde_std'],
            update='append',
            opts=dict(xlabel='iter', ylabel='fde std', title='FDE std'),
        )

    def set_mode(self, train=True):

        if train:
            self.encoderMx.train()
            self.encoderMy.train()
            self.decoderMy.train()
        else:
            self.encoderMx.eval()
            self.encoderMy.eval()
            self.decoderMy.eval()

    ####
    def save_checkpoint(self, iteration):

        encoderMx_path = os.path.join(self.ckpt_dir,
                                      'iter_%s_encoderMx.pt' % iteration)
        encoderMy_path = os.path.join(self.ckpt_dir,
                                      'iter_%s_encoderMy.pt' % iteration)
        decoderMy_path = os.path.join(self.ckpt_dir,
                                      'iter_%s_decoderMy.pt' % iteration)
        lg_cvae_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_lg_cvae.pt' % iteration)
        sg_unet_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_sg_unet.pt' % iteration)
        # lg_cvae_path / sg_unet_path are built but those models are not
        # re-saved below; only the trajectory encoders/decoder are checkpointed.
        mkdirs(self.ckpt_dir)

        torch.save(self.encoderMx, encoderMx_path)
        torch.save(self.encoderMy, encoderMy_path)
        torch.save(self.decoderMy, decoderMy_path)

    ####
    def load_checkpoint(self):

        encoderMx_path = os.path.join(
            self.ckpt_dir, 'iter_%s_encoderMx.pt' % self.ckpt_load_iter)
        encoderMy_path = os.path.join(
            self.ckpt_dir, 'iter_%s_encoderMy.pt' % self.ckpt_load_iter)
        decoderMy_path = os.path.join(
            self.ckpt_dir, 'iter_%s_decoderMy.pt' % self.ckpt_load_iter)
        lg_cvae_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_lg_cvae.pt' % self.ckpt_load_iter)
        sg_unet_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_sg_unet.pt' % self.ckpt_load_iter)

        if self.device == 'cuda':
            self.encoderMx = torch.load(encoderMx_path)
            self.encoderMy = torch.load(encoderMy_path)
            self.decoderMy = torch.load(decoderMy_path)
        else:
            self.encoderMx = torch.load(encoderMx_path, map_location='cpu')
            self.encoderMy = torch.load(encoderMy_path, map_location='cpu')
            self.decoderMy = torch.load(decoderMy_path, map_location='cpu')
Example #13
def main(args):

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    pbar = tqdm(total=args.epochs)
    image_gather = DataGather('true', 'recon')

    dataset = get_celeba_selected_dataset()
    data_loader = DataLoader(dataset=dataset,
                             batch_size=args.batch_size,
                             shuffle=True)

    lr_vae = args.lr_vae
    lr_D = args.lr_D
    vae = CelebaFactorVAE(args.z_dim, args.num_labels).to(device)
    optim_vae = torch.optim.Adam(vae.parameters(), lr=args.lr_vae)

    D = Discriminator(args.z_dim, args.num_labels).to(device)
    optim_D = torch.optim.Adam(D.parameters(), lr=args.lr_D, betas=(0.5, 0.9))

    # Checkpoint
    ckpt_dir = os.path.join(args.ckpt_dir, args.name)
    mkdirs(ckpt_dir)
    start_epoch = 0
    if args.ckpt_load:
        load_checkpoint(pbar, ckpt_dir, D, vae, optim_D, optim_vae, lr_vae,
                        lr_D)
        #optim_D.param_groups[0]['lr'] = 0.00001#lr_D
        #optim_vae.param_groups[0]['lr'] = 0.00001#lr_vae
        print("confirming lr after loading checkpoint: ",
              optim_vae.param_groups[0]['lr'])

    # Output
    output_dir = os.path.join(args.output_dir, args.name)
    mkdirs(output_dir)

    ones = torch.ones(args.batch_size, dtype=torch.long, device=device)
    zeros = torch.zeros(args.batch_size, dtype=torch.long, device=device)

    for epoch in range(start_epoch, args.epochs):
        pbar.update(1)

        for iteration, (x, y, x2, y2) in enumerate(data_loader):

            x, y, x2, y2 = x.to(device), y.to(device), x2.to(device), y2.to(
                device)

            recon_x, mean, log_var, z = vae(x, y)

            if z.shape[0] != args.batch_size:
                print("skipped an incomplete batch in epoch {}, iteration {}!".format(
                    epoch, iteration))
                continue

            D_z = D(z)
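            # D emits two logits per sample; the gap D_z[:, 0] - D_z[:, 1]
            # estimates log q(z) / prod_j q(z_j), i.e. the total correlation.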

            vae_recon_loss = recon_loss(x, recon_x) * args.recon_weight
            vae_kld = kl_divergence(mean, log_var)
            vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean() * args.gamma
            vae_loss = vae_recon_loss + vae_tc_loss  #+ vae_kld

            optim_vae.zero_grad()
            vae_loss.backward(retain_graph=True)
            # Step the VAE before backpropagating the discriminator loss:
            # D_tc_loss also sends gradients into the VAE through D_z (the
            # graph is retained), and those must not leak into this update.
            optim_vae.step()

            # permute_dims shuffles each latent dimension independently across
            # the batch, yielding samples from the product of the marginals.
            z_prime = vae(x2, y2, no_dec=True)
            z_pperm = permute_dims(z_prime).detach()
            D_z_pperm = D(z_pperm)
            D_tc_loss = 0.5 * (F.cross_entropy(D_z, zeros) +
                               F.cross_entropy(D_z_pperm, ones))

            optim_D.zero_grad()
            D_tc_loss.backward()
            optim_D.step()

            if iteration % args.print_iter == 0:
                pbar.write(
                    '[epoch {}/{}, iter {}/{}] vae_recon_loss:{:.4f} vae_kld:{:.4f} vae_tc_loss:{:.4f} D_tc_loss:{:.4f}'
                    .format(epoch, args.epochs, iteration,
                            len(data_loader) - 1, vae_recon_loss.item(),
                            vae_kld.item(), vae_tc_loss.item(),
                            D_tc_loss.item()))

            if iteration % args.output_iter == 0 and iteration != 0:
                output_dir = os.path.join(
                    args.output_dir,
                    args.name)  #, "{}.{}".format(epoch, iteration))
                mkdirs(output_dir)

                #reconstruction
                #image_gather.insert(true=x.data.cpu(), recon=torch.sigmoid(recon_x).data.cpu())
                #data = image_gather.data
                #true_image = data['true'][0]
                #recon_image = data['recon'][0]
                #true_image = make_grid(true_image)
                #recon_image = make_grid(recon_image)
                #sample = torch.stack([true_image, recon_image], dim=0)
                #save_image(tensor=sample.cpu(), fp=os.path.join(output_dir, "recon.jpg"))
                #image_gather.flush()

                #inference given num_labels = 10
                c = torch.randint(low=0, high=2,
                                  size=(10, 10)).to(device)  # populated with 0s and 1s
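                # Sample z ~ U(0, 1) (as written; a Gaussian prior would be the
                # usual choice), append the binary code c, and reshape to
                # (N, z_dim + num_labels, 1, 1) for the conv decoder.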
                z_inf = torch.rand([c.size(0), args.z_dim]).to(device)
                #print("shapes: ",z_inf.shape, c.shape)
                #c = c.reshape(-1,args.num_labels,1,1)
                z_inf = torch.cat((z_inf, c), dim=1)
                z_inf = z_inf.reshape(-1, args.num_labels + args.z_dim, 1, 1)
                x = vae.decode(z_inf)

                plt.figure(figsize=(10, 20))
                for p in range(args.num_labels):
                    plt.subplot(5, 2, p + 1)  #row, col, index starting from 1
                    plt.text(0,
                             0,
                             "c={}".format(c[p]),
                             color='black',
                             backgroundcolor='white',
                             fontsize=10)

                    img = x[p].view(3, 218, 178)
                    image = img.permute(1, 2, 0)  # CHW -> HWC for imshow
                    plt.imshow(
                        (image.cpu().data.numpy() * 255).astype(np.uint8))
                    plt.axis('off')

                plt.savefig(os.path.join(
                    output_dir, "E{:d}||{:d}.png".format(epoch, iteration)),
                            dpi=300)
                plt.clf()
                plt.close('all')

        if epoch % 8 == 0:  # decay every 8 epochs (note: also fires after epoch 0)
            optim_vae.param_groups[0]['lr'] /= 10
            optim_D.param_groups[0]['lr'] /= 10
            print("\nnew learning rate at epoch {} is {}!".format(
                epoch, optim_vae.param_groups[0]['lr']))

        if epoch % args.ckpt_iter_epoch == 0:
            save_checkpoint(pbar, epoch, D, vae, optim_D, optim_vae, ckpt_dir,
                            epoch)

    pbar.write("[Training Finished]")
    pbar.close()
Example #14
class Solver(object):
    def __init__(self, args):
        # Misc
        use_cuda = args.cuda and torch.cuda.is_available()
        self.device = 'cuda' if use_cuda else 'cpu'
        self.name = args.name
        self.max_iter = int(args.max_iter)
        self.print_iter = args.print_iter
        self.global_iter = 0
        self.test_count = 0
        self.pbar = tqdm(total=self.max_iter)

        # Data
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)
    
        # Networks & Optimizers
        self.z_dim = args.z_dim
        self.gamma = args.gamma

        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE

        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D

        if args.dataset == 'dsprites':
            self.VAE = FactorVAE1(self.z_dim).to(self.device)
            self.nc = 1
        
        elif args.dataset == 'mnist':
            self.VAE = Custom_FactorVAE2(self.z_dim).to(self.device)
            self.nc = 3

        elif args.dataset == 'load_mnist':
            self.VAE = Custom_FactorVAE2(self.z_dim).to(self.device)
            self.nc = 3
       
        elif args.dataset == 'glove/numpy_vector/300d_wiki.npy': 
            self.VAE = Glove_FactorVAE1(self.z_dim).to(self.device)
            self.nc = 3
        else:
            self.VAE = FactorVAE2(self.z_dim).to(self.device)
            self.nc = 3
        self.optim_VAE = optim.Adam(self.VAE.parameters(), lr=self.lr_VAE,
                                    betas=(self.beta1_VAE, self.beta2_VAE))

        self.D = Discriminator(self.z_dim).to(self.device)
        self.optim_D = optim.Adam(self.D.parameters(), lr=self.lr_D,
                                  betas=(self.beta1_D, self.beta2_D))

        self.nets = [self.VAE, self.D]

        # Visdom
        self.viz_on = args.viz_on
        self.win_id = dict(D_z='win_D_z', recon='win_recon', kld='win_kld', acc='win_acc')
        self.line_gather = DataGather('iter', 'soft_D_z', 'soft_D_z_pperm', 'recon', 'kld', 'acc')
        self.image_gather = DataGather('true', 'recon')
        if self.viz_on:
            self.viz_port = args.viz_port
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter
            self.viz_ra_iter = args.viz_ra_iter
            self.viz_ta_iter = args.viz_ta_iter
            if not self.viz.win_exists(env=self.name+'/lines', win=self.win_id['D_z']):
                self.viz_init()

        # Checkpoint
        self.ckpt_dir = os.path.join(args.ckpt_dir, args.name)
        self.ckpt_save_iter = args.ckpt_save_iter
        mkdirs(self.ckpt_dir)
        if args.ckpt_load:
            self.load_checkpoint(args.ckpt_load)

        # Output(latent traverse GIF)
        self.output_dir = os.path.join(args.output_dir, args.name)
        self.output_save = args.output_save
        mkdirs(self.output_dir)


    def custom_loss(self, x):  # the loss uses cross-entropy; some variants use MSE instead
        # loss adapted from https://tips-memo.com/vae-pytorch#i-7 and
        # http://aidiary.hatenablog.com/entry/20180228/1519828344
        mean, var = self.VAE._encoder(x)
        #KL = -0.5 * torch.mean(torch.sum(1 + torch.log(var) - mean**2 - var))  # original; the mean is dubious, but the value barely changes, so it should still work
        #KL = 0.5 * torch.sum(torch.exp(var) + mean**2 - 1. - var)
        KL = -0.5 * torch.sum(1 + var - mean.pow(2) - var.exp())
        # summed because the KL is computed per latent dimension
        #print("KL: " + str(KL))
        z = self.VAE._sample_z(mean, var)
        y = self.VAE._decoder(z)
        #delta = 1e-8
        #reconstruction = torch.mean(torch.sum(x * torch.log(y + delta) + (1 - x) * torch.log(1 - y + delta)))
        #reconstruction = F.binary_cross_entropy(y, x.view(-1, 784), reduction='sum')
        reconstruction = F.binary_cross_entropy(y, x, reduction='sum')
        # cross-entropy maximizes the log-likelihood; with binary targets the
        # (1 - x) and (1 - y) terms are all that is needed
        # http://aidiary.hatenablog.com/entry/20180228/1519828344 (reference)
        #print("reconstruction: " + str(reconstruction))
        #lower_bound = [-KL, reconstruction]
        # both terms should be minimized: cross-entropy is already a negative
        # log-likelihood, and the KL sign is flipped so it is minimized too;
        # the return turns the objective into a loss function
        #return -sum(lower_bound)
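        # i.e. the negative ELBO: KL(q(z|x) || N(0, I)) plus reconstruction BCE
        # (assuming the encoder's second output is the log-variance).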
        return KL + reconstruction


    def train(self):
        self.net_mode(train=True)

        ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
        zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)

        out = False
        while not out:
            for x_true1, x_true2 in self.data_loader:  # data is read in here?
                self.global_iter += 1
                self.pbar.update(1)
                if self.dataset == 'mnist':
                    x_true1 = x_true1.view(x_true1.shape[0], -1)
                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                x = x_true1.view(x_true1.shape[0], -1) #custom

                #vae_recon_loss = self.custom_loss(x) / self.batch_size #custom
                vae_recon_loss = recon_loss(x, x_recon)  # reconstruction error (cross-entropy)
                vae_kld = kl_divergence(mu, logvar)
                D_z = self.D(z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()  # presumably the discriminator-based total-correlation term

                vae_loss = vae_recon_loss + vae_kld + self.gamma*vae_tc_loss
                #vae_loss = vae_recon_loss + self.gamma*vae_tc_loss 
                self.optim_VAE.zero_grad()
                vae_loss.backward(retain_graph=True)
                self.optim_VAE.step()
                x_true2 = x_true2.to(self.device)
                #x_true2 = x_true2.view(x_true2.shape[0], -1)
                z_prime = self.VAE(x_true2, no_dec=True)  # with no_dec=True this returns only the latent encoding?
                z_pperm = permute_dims(z_prime).detach()
                D_z_pperm = self.D(z_pperm)
                D_tc_loss = 0.5*(F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones))  # like a GAN discriminator: real vs. fake,
                # hence the 0/1 targets in `zeros` and `ones`

                self.optim_D.zero_grad()
                D_tc_loss.backward()
                self.optim_D.step()

                #if self.global_iter%self.print_iter == 0:
                #    self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                #        self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item()))
                if self.test_count % 547 == 0:  # 547 is presumably the iterations per epoch
                    #self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                        #self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item()))
                    self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                        self.global_iter, vae_recon_loss.item(), vae_tc_loss.item(), D_tc_loss.item()))  
                    self.test_count = 0
                
                if self.global_iter%self.ckpt_save_iter == 0:
                    self.save_checkpoint(self.global_iter)

                if self.viz_on and (self.global_iter%self.viz_ll_iter == 0):
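                    # Class-0 softmax prob = P(z judged a "true" posterior
                    # sample); D_acc counts true z scored >= 0.5 plus permuted z
                    # scored < 0.5, over 2 * batch_size samples.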
                    soft_D_z = F.softmax(D_z, 1)[:, :1].detach()
                    soft_D_z_pperm = F.softmax(D_z_pperm, 1)[:, :1].detach()
                    D_acc = ((soft_D_z >= 0.5).sum() + (soft_D_z_pperm < 0.5).sum()).float()
                    D_acc /= 2*self.batch_size
                    self.line_gather.insert(iter=self.global_iter,
                                            soft_D_z=soft_D_z.mean().item(),
                                            soft_D_z_pperm=soft_D_z_pperm.mean().item(),
                                            recon=vae_recon_loss.item(),
                                            #kld=vae_kld.item(),
                                            acc=D_acc.item())

                if self.viz_on and (self.global_iter%self.viz_la_iter == 0):
                    self.visualize_line()
                    self.line_gather.flush()

                if self.viz_on and (self.global_iter%self.viz_ra_iter == 0):
                    self.image_gather.insert(true=x_true1.data.cpu(),
                                             recon=torch.sigmoid(x_recon).data.cpu())
                    self.visualize_recon()
                    self.image_gather.flush()

                if self.viz_on and (self.global_iter%self.viz_ta_iter == 0):
                    if self.dataset.lower() == '3dchairs':
                        self.visualize_traverse(limit=2, inter=0.5)
                    else:
                        #self.visualize_traverse(limit=3, inter=2/3)
                        print("ignore")

                if self.global_iter >= self.max_iter:
                    out = True
                    break
                self.test_count += 1

        self.pbar.write("[Training Finished]")
        torch.save(self.VAE.state_dict(), "model1/0531_128_2_gamma2.pth")
        self.pbar.close()

    def visualize_recon(self):
        data = self.image_gather.data
        true_image = data['true'][0]
        recon_image = data['recon'][0]

        true_image = make_grid(true_image)
        recon_image = make_grid(recon_image)
        sample = torch.stack([true_image, recon_image], dim=0)
        self.viz.images(sample, env=self.name+'/recon_image',
                        opts=dict(title=str(self.global_iter)))

    def visualize_line(self):
        data = self.line_gather.data
        iters = torch.Tensor(data['iter'])
        recon = torch.Tensor(data['recon'])
        kld = torch.Tensor(data['kld'])
        D_acc = torch.Tensor(data['acc'])
        soft_D_z = torch.Tensor(data['soft_D_z'])
        soft_D_z_pperm = torch.Tensor(data['soft_D_z_pperm'])
        soft_D_zs = torch.stack([soft_D_z, soft_D_z_pperm], -1)

        self.viz.line(X=iters,
                      Y=soft_D_zs,
                      env=self.name+'/lines',
                      win=self.win_id['D_z'],
                      update='append',
                      opts=dict(
                        xlabel='iteration',
                        ylabel='D(.)',
                        legend=['D(z)', 'D(z_perm)']))
        self.viz.line(X=iters,
                      Y=recon,
                      env=self.name+'/lines',
                      win=self.win_id['recon'],
                      update='append',
                      opts=dict(
                        xlabel='iteration',
                        ylabel='reconstruction loss',))
        self.viz.line(X=iters,Y=D_acc,
                      env=self.name+'/lines',
                      win=self.win_id['acc'],
                      update='append',
                      opts=dict(
                        xlabel='iteration',
                        ylabel='discriminator accuracy',))
        '''
        self.viz.line(X=iters,
                      Y=kld,
                      env=self.name+'/lines',
                      win=self.win_id['kld'],
                      update='append',
                      opts=dict(
                        xlabel='iteration',
                        ylabel='kl divergence',))
        '''

    def visualize_traverse(self, limit=3, inter=2/3, loc=-1):
        self.net_mode(train=False)

        decoder = self.VAE.decode
        encoder = self.VAE.encode
        interpolation = torch.arange(-limit, limit+0.1, inter)
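        # Latent traversal: sweep one latent dimension at a time over
        # [-limit, limit] in steps of `inter`, holding the other dimensions of
        # a reference z fixed, and decode each point.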
        random_img = self.data_loader.dataset.__getitem__(0)[1]
        random_img = random_img.to(self.device).unsqueeze(0)
        random_img_z = encoder(random_img)[:, :self.z_dim]

        if self.dataset.lower() == 'dsprites':
            fixed_idx1 = 87040 # square
            fixed_idx2 = 332800 # ellipse
            fixed_idx3 = 578560 # heart

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]

            Z = {'fixed_square':fixed_img_z1, 'fixed_ellipse':fixed_img_z2,
                 'fixed_heart':fixed_img_z3, 'random_img':random_img_z}

        elif self.dataset.lower() == 'celeba':
            fixed_idx1 = 191281 # 'CelebA/img_align_celeba/191282.jpg'
            fixed_idx2 = 143307 # 'CelebA/img_align_celeba/143308.jpg'
            fixed_idx3 = 101535 # 'CelebA/img_align_celeba/101536.jpg'
            fixed_idx4 = 70059  # 'CelebA/img_align_celeba/070060.jpg'

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]

            fixed_img4 = self.data_loader.dataset.__getitem__(fixed_idx4)[0]
            fixed_img4 = fixed_img4.to(self.device).unsqueeze(0)
            fixed_img_z4 = encoder(fixed_img4)[:, :self.z_dim]

            Z = {'fixed_1':fixed_img_z1, 'fixed_2':fixed_img_z2,
                 'fixed_3':fixed_img_z3, 'fixed_4':fixed_img_z4,
                 'random':random_img_z}

        elif self.dataset.lower() == '3dchairs':
            fixed_idx1 = 40919 # 3DChairs/images/4682_image_052_p030_t232_r096.png
            fixed_idx2 = 5172  # 3DChairs/images/14657_image_020_p020_t232_r096.png
            fixed_idx3 = 22330 # 3DChairs/images/30099_image_052_p030_t232_r096.png

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]

            Z = {'fixed_1':fixed_img_z1, 'fixed_2':fixed_img_z2,
                 'fixed_3':fixed_img_z3, 'random':random_img_z}
        else:
            fixed_idx = 0
            fixed_img = self.data_loader.dataset.__getitem__(fixed_idx)[0]
            fixed_img = fixed_img.to(self.device).unsqueeze(0)
            fixed_img_z = encoder(fixed_img)[:, :self.z_dim]

            random_z = torch.rand(1, self.z_dim, 1, 1, device=self.device)

            Z = {'fixed_img':fixed_img_z, 'random_img':random_img_z, 'random_z':random_z}

        gifs = []
        for key in Z:
            z_ori = Z[key]
            samples = []
            for row in range(self.z_dim):
                if loc != -1 and row != loc:
                    continue
                z = z_ori.clone()
                for val in interpolation:
                    z[:, row] = val
                    sample = torch.sigmoid(decoder(z)).data
                    samples.append(sample)
                    gifs.append(sample)
            samples = torch.cat(samples, dim=0).cpu()
            title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter)
            self.viz.images(samples, env=self.name+'/traverse',
                            opts=dict(title=title), nrow=len(interpolation))

        if self.output_save:
            output_dir = os.path.join(self.output_dir, str(self.global_iter))
            mkdirs(output_dir)
            gifs = torch.cat(gifs)
            gifs = gifs.view(len(Z), self.z_dim, len(interpolation), self.nc, 64, 64).transpose(1, 2)
            for i, key in enumerate(Z.keys()):
                for j, val in enumerate(interpolation):
                    save_image(tensor=gifs[i][j].cpu(),
                               filename=os.path.join(output_dir, '{}_{}.jpg'.format(key, j)),
                               nrow=self.z_dim, pad_value=1)

                grid2gif(str(os.path.join(output_dir, key+'*.jpg')),
                         str(os.path.join(output_dir, key+'.gif')), delay=10)

        self.net_mode(train=True)
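
    # --- illustrative sketch (not part of the original solver) --------------
    # A minimal restatement of the traversal loop above, assuming only that
    # `decoder` maps a latent code to image logits: clone a code, overwrite one
    # dimension with each value of a sweep, and decode.
    @staticmethod
    def _traverse_dim_sketch(decoder, z, dim, values):
        import torch
        frames = []
        for val in values:
            z_mod = z.clone()
            z_mod[:, dim] = val
            frames.append(torch.sigmoid(decoder(z_mod)).detach())
        return torch.cat(frames, dim=0)  # (len(values) * batch, C, H, W)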

    def viz_init(self):
        zero_init = torch.zeros([1])
        self.viz.line(X=zero_init,
                      Y=torch.stack([zero_init, zero_init], -1),
                      env=self.name+'/lines',
                      win=self.win_id['D_z'],
                      opts=dict(
                        xlabel='iteration',
                        ylabel='D(.)',
                        legend=['D(z)', 'D(z_perm)']))
        self.viz.line(X=zero_init,
                      Y=zero_init,
                      env=self.name+'/lines',
                      win=self.win_id['recon'],
                      opts=dict(
                        xlabel='iteration',
                        ylabel='reconstruction loss',))
        self.viz.line(X=zero_init,
                      Y=zero_init,
                      env=self.name+'/lines',
                      win=self.win_id['acc'],
                      opts=dict(
                        xlabel='iteration',
                        ylabel='discriminator accuracy',))
        self.viz.line(X=zero_init,
                      Y=zero_init,
                      env=self.name+'/lines',
                      win=self.win_id['kld'],
                      opts=dict(
                        xlabel='iteration',
                        ylabel='kl divergence',))
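
    # Note: the zero-valued placeholder lines above exist so that the later
    # update='append' calls in visualize_line have windows to extend (some
    # visdom versions will not create a window on 'append'). The 'kld' window
    # is initialized here even though its append call in visualize_line is
    # currently commented out.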

    def net_mode(self, train):
        if not isinstance(train, bool):
            raise ValueError('Only bool type is supported. True|False')

        for net in self.nets:
            if train:
                net.train()
            else:
                net.eval()

    def save_checkpoint(self, ckptname='last', verbose=True):
        model_states = {'D':self.D.state_dict(),
                        'VAE':self.VAE.state_dict()}
        optim_states = {'optim_D':self.optim_D.state_dict(),
                        'optim_VAE':self.optim_VAE.state_dict()}
        states = {'iter':self.global_iter,
                  'model_states':model_states,
                  'optim_states':optim_states}

        filepath = os.path.join(self.ckpt_dir, str(ckptname))
        with open(filepath, 'wb+') as f:
            torch.save(states, f)
        if verbose:
            self.pbar.write("=> saved checkpoint '{}' (iter {})".format(filepath, self.global_iter))

    def load_checkpoint(self, ckptname='last', verbose=True):
        if ckptname == 'last':
            ckpts = os.listdir(self.ckpt_dir)
            if not ckpts:
                if verbose:
                    self.pbar.write("=> no checkpoint found")
                return

            ckpts = [int(ckpt) for ckpt in ckpts]
            ckpts.sort(reverse=True)
            ckptname = str(ckpts[0])

        filepath = os.path.join(self.ckpt_dir, ckptname)
        if os.path.isfile(filepath):
            with open(filepath, 'rb') as f:
                checkpoint = torch.load(f)

            self.global_iter = checkpoint['iter']
            self.VAE.load_state_dict(checkpoint['model_states']['VAE'])
            self.D.load_state_dict(checkpoint['model_states']['D'])
            self.optim_VAE.load_state_dict(checkpoint['optim_states']['optim_VAE'])
            self.optim_D.load_state_dict(checkpoint['optim_states']['optim_D'])
            self.pbar.update(self.global_iter)
            if verbose:
                self.pbar.write("=> loaded checkpoint '{} (iter {})'".format(filepath, self.global_iter))
        else:
            if verbose:
                self.pbar.write("=> no checkpoint found at '{}'".format(filepath))
    
    def senzai_view(self, z, label):
        plt.figure(figsize=(10, 10))
        plt.scatter(z[:, 0], z[:, 1], marker='.', c=label, cmap=pylab.cm.jet)
        plt.colorbar()
        plt.grid()
        plt.title('oza_FVAE_2dimension')
        plt.savefig('FVAE0531_128_2_gamma2_senzai.png')

    def load_model(self):
        self.VAE.load_state_dict(torch.load("model1/0531_128_2_gamma2.pth", map_location=self.device))
        for data, label in self.data_loader:
            data = data.to(self.device)
            data = data.view(data.shape[0], -1)
            label = label.detach().numpy()
            break
        
        n = 10
        # Variable(..., volatile=True) is long deprecated; inference under
        # torch.no_grad() is the modern equivalent.
        with torch.no_grad():
            x_recon, mu, logvar, z = self.VAE(data)
        z = z.cpu().numpy()
        data = data.cpu().numpy()
        x_recon = x_recon.cpu().numpy()

        # Below: compute the per-label variance of the latent codes
        '''
        sum = 0
        for i in range(10):
            tmp  = np.where(label == i)
            print(np.var(z[tmp]))
            sum += np.var(z[tmp])
        sum /= 10
        print("Ave var: " + str(sum))
        quit()
        # End of debug block; normally remove this
        
        plt.figure(figsize=(10, 10))
        plt.scatter(z[:, 0], z[:, 1], marker='.', c=label, cmap=pylab.cm.jet)
        plt.colorbar()
        plt.grid()
        plt.savefig('FVAE0528_128_2_senzai.png')
        '''
        if self.z_dim == 2:
            self.senzai_view(z, label)
        plt.figure(figsize=(12, 6))
        for i in range(n):
            ax = plt.subplot(3, n, i+1)
            if i == 1:
                plt.title('Original MNIST')
            plt.imshow(data[i].reshape(28, 28))
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            ax = plt.subplot(3, n, i+1+n)
            if i == 1:
                plt.title('FVAE_Reconstruction MNIST(20dim)')
            plt.imshow(x_recon[i].reshape(28, 28))
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
        plt.savefig("FVAE0531_128_20_gamma2_recon.png")
        plt.show()
        plt.close()
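
# --- illustrative sketch (module level, not part of the classes here) -------
# A cleaned-up version of the disabled per-label variance check inside
# load_model: average the variance of the latent codes within each label.
# Assumes z is an (N, z_dim) array and label an (N,) integer array.
def per_label_latent_variance_sketch(z, label, n_classes=10):
    import numpy as np
    total = 0.0
    for i in range(n_classes):
        idx = np.where(label == i)
        total += np.var(z[idx])
    return total / n_classes
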
class Trainer(object):
    def __init__(self, args):
        self.use_cuda = args.cuda and torch.cuda.is_available()
        self.max_epoch = args.max_epoch
        self.global_epoch = 0
        self.global_iter = 0

        self.z_dim = args.z_dim
        self.z_var = args.z_var
        self.z_sigma = math.sqrt(args.z_var)
        self._lambda = args.reg_weight
        self.lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.lr_schedules = {30: 2, 50: 5, 100: 10}

        if args.dataset.lower() == 'celeba':
            self.nc = 3
            self.decoder_dist = 'gaussian'
        else:
            raise NotImplementedError

        net = WAE
        self.net = cuda(net(self.z_dim, self.nc), self.use_cuda)
        self.optim = optim.Adam(self.net.parameters(),
                                lr=self.lr,
                                betas=(self.beta1, self.beta2))

        self.gather = DataGather()
        self.viz_name = args.viz_name
        self.viz_port = args.viz_port
        self.viz_on = args.viz_on
        if self.viz_on:
            self.viz = visdom.Visdom(env=self.viz_name + '_lines',
                                     port=self.viz_port)
            self.win_recon = None
            self.win_mmd = None
            self.win_mu = None
            self.win_var = None

        self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.viz_name)
        if not self.ckpt_dir.exists():
            self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.ckpt_name = args.ckpt_name
        if self.ckpt_name is not None:
            self.load_checkpoint(self.ckpt_name)

        self.save_output = args.save_output
        self.output_dir = Path(args.output_dir).joinpath(args.viz_name)
        if not self.output_dir.exists():
            self.output_dir.mkdir(parents=True, exist_ok=True)

        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)

    def train(self):
        self.net.train()

        iters_per_epoch = len(self.data_loader)
        max_iter = self.max_epoch * iters_per_epoch
        with tqdm(total=max_iter) as pbar:
            pbar.update(self.global_iter)
            out = False
            while not out:
                for x in self.data_loader:
                    pbar.update(1)
                    self.global_iter += 1
                    if self.global_iter % iters_per_epoch == 0:
                        self.global_epoch += 1
                    self.optim = multistep_lr_decay(self.optim,
                                                    self.global_epoch,
                                                    self.lr_schedules)

                    x = Variable(cuda(x, self.use_cuda))
                    x_recon, z_tilde = self.net(x)
                    z = self.sample_z(template=z_tilde, sigma=self.z_sigma)

                    recon_loss = F.mse_loss(
                        x_recon, x, reduction='sum').div(self.batch_size)
                    mmd_loss = mmd(z_tilde, z, z_var=self.z_var)
                    total_loss = recon_loss + self._lambda * mmd_loss

                    self.optim.zero_grad()
                    total_loss.backward()
                    self.optim.step()

                    if self.global_iter % 1000 == 0:
                        self.gather.insert(
                            iter=self.global_iter,
                            mu=z.mean(0).data,
                            var=z.var(0).data,
                            recon_loss=recon_loss.data,
                            mmd_loss=mmd_loss.data,
                        )

                    if self.global_iter % 5000 == 0:
                        self.gather.insert(images=x.data)
                        self.gather.insert(images=x_recon.data)
                        self.viz_reconstruction()
                        self.viz_lines()
                        self.sample_x_from_z(n_sample=100)
                        self.gather.flush()
                        self.save_checkpoint('last')
                        pbar.write(
                            '[{}] total_loss:{:.3f} recon_loss:{:.3f} mmd_loss:{:.3f}'
                            .format(self.global_iter, total_loss.item(),
                                    recon_loss.item(), mmd_loss.item()))

                    if self.global_iter % 20000 == 0:
                        self.save_checkpoint(str(self.global_iter))

                    if self.global_iter >= max_iter:
                        out = True
                        break

            pbar.write("[Training Finished]")

    def viz_reconstruction(self):
        self.net.eval()
        x = self.gather.data['images'][0][:100]
        x = make_grid(x, normalize=True, nrow=10)
        x_recon = torch.sigmoid(self.gather.data['images'][1][:100])
        x_recon = make_grid(x_recon, normalize=True, nrow=10)
        images = torch.stack([x, x_recon], dim=0).cpu()
        self.viz.images(images,
                        env=self.viz_name + '_reconstruction',
                        opts=dict(title=str(self.global_iter)),
                        nrow=2)
        self.net.train()

    def viz_lines(self):
        self.net.eval()
        recon_losses = torch.stack(self.gather.data['recon_loss']).cpu()
        mmd_losses = torch.stack(self.gather.data['mmd_loss']).cpu()
        mus = torch.stack(self.gather.data['mu']).cpu()
        vars = torch.stack(self.gather.data['var']).cpu()
        iters = torch.Tensor(self.gather.data['iter'])

        legend = []
        for z_j in range(self.z_dim):
            legend.append('z_{}'.format(z_j))

        if self.win_recon is None:
            self.win_recon = self.viz.line(X=iters,
                                           Y=recon_losses,
                                           env=self.viz_name + '_lines',
                                           opts=dict(
                                               width=400,
                                               height=400,
                                               xlabel='iteration',
                                               title='reconstruction loss',
                                           ))
        else:
            self.win_recon = self.viz.line(X=iters,
                                           Y=recon_losses,
                                           env=self.viz_name + '_lines',
                                           win=self.win_recon,
                                           update='append',
                                           opts=dict(
                                               width=400,
                                               height=400,
                                               xlabel='iteration',
                                               title='reconstruction loss',
                                           ))

        if self.win_mmd is None:
            self.win_mmd = self.viz.line(X=iters,
                                         Y=mmd_losses,
                                         env=self.viz_name + '_lines',
                                         opts=dict(
                                             width=400,
                                             height=400,
                                             xlabel='iteration',
                                             title='maximum mean discrepancy',
                                         ))
        else:
            self.win_mmd = self.viz.line(X=iters,
                                         Y=mmd_losses,
                                         env=self.viz_name + '_lines',
                                         win=self.win_mmd,
                                         update='append',
                                         opts=dict(
                                             width=400,
                                             height=400,
                                             xlabel='iteration',
                                             title='maximum mean discrepancy',
                                         ))

        if self.win_mu is None:
            self.win_mu = self.viz.line(X=iters,
                                        Y=mus,
                                        env=self.viz_name + '_lines',
                                        opts=dict(
                                            width=400,
                                            height=400,
                                            legend=legend,
                                            xlabel='iteration',
                                            title='posterior mean',
                                        ))
        else:
            self.win_mu = self.viz.line(X=iters,
                                        Y=mus,
                                        env=self.viz_name + '_lines',
                                        win=self.win_mu,
                                        update='append',
                                        opts=dict(
                                            width=400,
                                            height=400,
                                            legend=legend,
                                            xlabel='iteration',
                                            title='posterior mean',
                                        ))

        if self.win_var is None:
            self.win_var = self.viz.line(X=iters,
                                         Y=vars,
                                         env=self.viz_name + '_lines',
                                         opts=dict(
                                             width=400,
                                             height=400,
                                             legend=legend,
                                             xlabel='iteration',
                                             title='posterior variance',
                                         ))
        else:
            self.win_var = self.viz.line(X=iters,
                                         Y=vars,
                                         env=self.viz_name + '_lines',
                                         win=self.win_var,
                                         update='append',
                                         opts=dict(
                                             width=400,
                                             height=400,
                                             legend=legend,
                                             xlabel='iteration',
                                             title='posterior variance',
                                         ))
        self.net.train()

    def sample_z(self, n_sample=None, dim=None, sigma=None, template=None):
        if n_sample is None:
            n_sample = self.batch_size
        if dim is None:
            dim = self.z_dim
        if sigma is None:
            sigma = self.z_sigma

        if template is not None:
            z = sigma * Variable(template.data.new(template.size()).normal_())
        else:
            z = sigma * torch.randn(n_sample, dim)
            z = Variable(cuda(z, self.use_cuda))

        return z
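
    # Modern-PyTorch equivalent (sketch): Variable has been a no-op since 0.4,
    # so sampling from the scaled Gaussian prior reduces to one randn call on
    # the right device.
    def _sample_z_modern_sketch(self, n_sample, dim, sigma):
        import torch
        device = 'cuda' if self.use_cuda else 'cpu'
        return sigma * torch.randn(n_sample, dim, device=device)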

    def sample_x_from_z(self, n_sample):
        self.net.eval()
        z = self.sample_z(n_sample=n_sample, sigma=self.z_sigma)
        x_gen = torch.sigmoid(self.net._decode(z)[:100]).data.cpu()
        x_gen = make_grid(x_gen, normalize=True, nrow=10)
        self.viz.images(x_gen,
                        env=self.viz_name + '_sampling_from_random_z',
                        opts=dict(title=str(self.global_iter)))
        self.net.train()

    def save_checkpoint(self, filename, silent=True):
        model_states = {
            'net': self.net.state_dict(),
        }
        optim_states = {
            'optim': self.optim.state_dict(),
        }
        win_states = {
            'recon': self.win_recon,
            'mmd': self.win_mmd,
            'mu': self.win_mu,
            'var': self.win_var,
        }
        states = {
            'iter': self.global_iter,
            'epoch': self.global_epoch,
            'win_states': win_states,
            'model_states': model_states,
            'optim_states': optim_states
        }

        file_path = self.ckpt_dir.joinpath(filename)
        with file_path.open('wb+') as f:
            torch.save(states, f)
        if not silent:
            print("=> saved checkpoint '{}' (iter {})".format(
                file_path, self.global_iter))

    def load_checkpoint(self, filename, silent=False):
        file_path = self.ckpt_dir.joinpath(filename)
        if file_path.is_file():
            with file_path.open('rb') as f:
                checkpoint = torch.load(f)
            self.global_iter = checkpoint['iter']
            self.global_epoch = checkpoint['epoch']
            self.win_recon = checkpoint['win_states']['recon']
            self.win_mmd = checkpoint['win_states']['mmd']
            self.win_var = checkpoint['win_states']['var']
            self.win_mu = checkpoint['win_states']['mu']
            self.net.load_state_dict(checkpoint['model_states']['net'])
            self.optim.load_state_dict(checkpoint['optim_states']['optim'])
            if not silent:
                print("=> loaded checkpoint '{} (iter {})'".format(
                    file_path, self.global_iter))
        else:
            if not silent:
                print("=> no checkpoint found at '{}'".format(file_path))
Example #16
class Solver(object):

    ####
    def __init__(self, args):

        self.args = args

        self.name = '%s_lamkl_%s_zA_%s_zB_%s_zS_%s_HYPER_beta1_%s_beta2_%s_beta3_%s' % \
                    (
                        args.dataset, args.lamkl, args.zA_dim, args.zB_dim, args.zS_dim, args.beta1, args.beta2,
                        args.beta3)
        # to be appended by run_id

        self.use_cuda = args.cuda and torch.cuda.is_available()

        self.max_iter = int(args.max_iter)

        # do it every specified iters
        self.print_iter = args.print_iter
        self.ckpt_save_iter = args.ckpt_save_iter
        self.output_save_iter = args.output_save_iter

        # data info
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.nc = 3

        # self.N = self.latent_values.shape[0]
        self.eval_metrics_iter = args.eval_metrics_iter

        # networks and optimizers
        self.batch_size = args.batch_size
        self.zA_dim = args.zA_dim
        self.zB_dim = args.zB_dim
        self.zS_dim = args.zS_dim
        self.lamkl = args.lamkl
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE

        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D

        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.beta3 = args.beta3
        self.is_mss = args.is_mss

        # visdom setup
        self.viz_on = args.viz_on
        if self.viz_on:
            self.win_id = dict(
                recon='win_recon', kl='win_kl', capa='win_capa'
            )
            self.line_gather = DataGather(
                'iter', 'recon_both', 'recon_A', 'recon_B',
                'kl_A', 'kl_B', 
                'cont_capacity_loss_infA', 'disc_capacity_loss_infA', 'cont_capacity_loss_infB', 'disc_capacity_loss_infB'
            )

            # if self.eval_metrics:
            #     self.win_id['metrics'] = 'win_metrics'

            import visdom

            self.viz_port = args.viz_port  # port number, eg, 8097
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter

            self.viz_init()

        # create dirs: "records", "ckpts", "outputs" (if not exist)
        mkdirs("records");
        mkdirs("ckpts");
        mkdirs("outputs")

        # set run id
        if args.run_id < 0:  # create a new id
            k = 0
            rfname = os.path.join("records", self.name + '_run_0.txt')
            while os.path.exists(rfname):
                k += 1
                rfname = os.path.join("records", self.name + '_run_%d.txt' % k)
            self.run_id = k
        else:  # user-provided id
            self.run_id = args.run_id

        # finalize name
        self.name = self.name + '_run_' + str(self.run_id)

        # records (text file to store console outputs)
        self.record_file = 'records/%s.txt' % self.name

        # checkpoints
        self.ckpt_dir = os.path.join("ckpts", self.name)

        # outputs
        self.output_dir_recon = os.path.join("outputs", self.name + '_recon')
        # dir for reconstructed images
        self.output_dir_synth = os.path.join("outputs", self.name + '_synth')
        # dir for synthesized images
        self.output_dir_trvsl = os.path.join("outputs", self.name + '_trvsl')

        #### create a new model or load a previously saved model

        self.ckpt_load_iter = args.ckpt_load_iter
        self.n_pts = args.n_pts
        self.n_data = args.n_data

        if self.ckpt_load_iter == 0:  # create a new model
            self.encoderA = EncoderA(self.zA_dim, self.zS_dim)
            self.encoderB = EncoderA(self.zB_dim, self.zS_dim)
            self.decoderA = DecoderA(self.zA_dim, self.zS_dim)
            self.decoderB = DecoderA(self.zB_dim, self.zS_dim)

        else:  # load a previously saved model

            print('Loading saved models (iter: %d)...' % self.ckpt_load_iter)
            self.load_checkpoint()
            print('...done')

        if self.use_cuda:
            print('Models moved to GPU...')
            self.encoderA = self.encoderA.cuda()
            self.encoderB = self.encoderB.cuda()
            self.decoderA = self.decoderA.cuda()
            self.decoderB = self.decoderB.cuda()
            print('...done')

        # get VAE parameters
        vae_params = \
            list(self.encoderA.parameters()) + \
            list(self.encoderB.parameters()) + \
            list(self.decoderA.parameters()) + \
            list(self.decoderB.parameters())


        # create optimizers
        self.optim_vae = optim.Adam(
            vae_params,
            lr=self.lr_VAE,
            betas=[self.beta1_VAE, self.beta2_VAE]
        )

    ####
    def train(self):

        self.set_mode(train=True)

        # prepare dataloader (iterable)
        print('Start loading data...')
        dset = DIGIT('./data', train=True)
        self.data_loader = torch.utils.data.DataLoader(dset, batch_size=self.batch_size, shuffle=True)
        test_dset = DIGIT('./data', train=False)
        self.test_data_loader = torch.utils.data.DataLoader(test_dset, batch_size=self.batch_size, shuffle=True)
        print('test: ', len(test_dset))
        self.N = len(self.data_loader.dataset)
        print('...done')

        # iterators from dataloader
        iterator1 = iter(self.data_loader)
        iterator2 = iter(self.data_loader)

        iter_per_epoch = min(len(iterator1), len(iterator2))

        start_iter = self.ckpt_load_iter + 1
        epoch = int(start_iter / iter_per_epoch)

        for iteration in range(start_iter, self.max_iter + 1):

            # reset data iterators for each epoch
            if iteration % iter_per_epoch == 0:
                print('==== epoch %d done ====' % epoch)
                epoch += 1
                iterator1 = iter(self.data_loader)
                iterator2 = iter(self.data_loader)

            # ============================================
            #          TRAIN THE VAE (ENC & DEC)
            # ============================================

            # sample a mini-batch
            XA, XB, index = next(iterator1)  # (n x C x H x W)

            index = index.cpu().detach().numpy()
            if self.use_cuda:
                XA = XA.cuda()
                XB = XB.cuda()

            # zA, zS = encA(xA)
            muA_infA, stdA_infA, logvarA_infA, cate_prob_infA = self.encoderA(XA)

            # zB, zS = encB(xB)
            muB_infB, stdB_infB, logvarB_infB, cate_prob_infB = self.encoderB(XB)

            # read current values

            # zS = encAB(xA,xB) via POE
            cate_prob_POE = torch.exp(
                torch.log(torch.tensor(1 / 10)) + torch.log(cate_prob_infA) + torch.log(cate_prob_infB))

            # latent_dist = {'cont': (muA_infA, logvarA_infA), 'disc': [cate_prob_infA]}
            # (kl_cont_loss, kl_disc_loss, cont_capacity_loss, disc_capacity_loss) = kl_loss_function(self.use_cuda, iteration, latent_dist)

            # kl losses
            #A
            latent_dist_infA = {'cont': (muA_infA, logvarA_infA), 'disc': [cate_prob_infA]}
            (kl_cont_loss_infA, kl_disc_loss_infA, cont_capacity_loss_infA, disc_capacity_loss_infA) = kl_loss_function(
                self.use_cuda, iteration, latent_dist_infA)

            loss_kl_infA = kl_cont_loss_infA + kl_disc_loss_infA
            capacity_loss_infA = cont_capacity_loss_infA + disc_capacity_loss_infA

            #B
            latent_dist_infB = {'cont': (muB_infB, logvarB_infB), 'disc': [cate_prob_infB]}
            (kl_cont_loss_infB, kl_disc_loss_infB, cont_capacity_loss_infB, disc_capacity_loss_infB) = kl_loss_function(
                self.use_cuda, iteration, latent_dist_infB, cont_capacity=[0.0, 5.0, 50000, 100.0], disc_capacity=[0.0, 10.0, 50000, 100.0])

            loss_kl_infB = kl_cont_loss_infB + kl_disc_loss_infB
            capacity_loss_infB = cont_capacity_loss_infB + disc_capacity_loss_infB


            loss_capa = capacity_loss_infB

            # encoder samples (for training)
            ZA_infA = sample_gaussian(self.use_cuda, muA_infA, stdA_infA)
            ZB_infB = sample_gaussian(self.use_cuda, muB_infB, stdB_infB)
            ZS_POE = sample_gumbel_softmax(self.use_cuda, cate_prob_POE)

            # encoder samples (for cross-modal prediction)
            ZS_infA = sample_gumbel_softmax(self.use_cuda, cate_prob_infA)
            ZS_infB = sample_gumbel_softmax(self.use_cuda, cate_prob_infB)

            # reconstructed samples (given joint modal observation)
            XA_POE_recon = self.decoderA(ZA_infA, ZS_POE)
            XB_POE_recon = self.decoderB(ZB_infB, ZS_POE)

            # reconstructed samples (given single modal observation)
            XA_infA_recon = self.decoderA(ZA_infA, ZS_infA)
            XB_infB_recon = self.decoderB(ZB_infB, ZS_infB)

            # loss_recon_infA = F.l1_loss(torch.sigmoid(XA_infA_recon), XA, reduction='sum').div(XA.size(0))
            loss_recon_infA = reconstruction_loss(XA, torch.sigmoid(XA_infA_recon), distribution="bernoulli")
            #
            loss_recon_infB = reconstruction_loss(XB, torch.sigmoid(XB_infB_recon), distribution="bernoulli")
            #
            loss_recon_POE = \
                F.l1_loss(torch.sigmoid(XA_POE_recon), XA, reduction='sum').div(XA.size(0)) + \
                F.l1_loss(torch.sigmoid(XB_POE_recon), XB, reduction='sum').div(XB.size(0))
            #

            loss_recon = loss_recon_infB

            # total loss for vae
            vae_loss = loss_recon + loss_capa

            # update vae
            self.optim_vae.zero_grad()
            vae_loss.backward()
            self.optim_vae.step()



            # print the losses
            if iteration % self.print_iter == 0:
                prn_str = ( \
                                      '[iter %d (epoch %d)] vae_loss: %.3f ' + \
                                      '(recon: %.3f, capa: %.3f)\n' + \
                                      '    rec_infA = %.3f, rec_infB = %.3f, rec_POE = %.3f\n' + \
                                      '    kl_infA = %.3f, kl_infB = %.3f\n' + \
                                      '    cont_capacity_loss_infA = %.3f, disc_capacity_loss_infA = %.3f\n' + \
                                      '    cont_capacity_loss_infB = %.3f, disc_capacity_loss_infB = %.3f\n'
                          ) % \
                          (iteration, epoch,
                           vae_loss.item(), loss_recon.item(), loss_capa.item(),
                           loss_recon_infA.item(), loss_recon_infB.item(), loss_recon_POE.item(),
                           loss_kl_infA.item(), loss_kl_infB.item(),
                           cont_capacity_loss_infA.item(), disc_capacity_loss_infA.item(),
                           cont_capacity_loss_infB.item(), disc_capacity_loss_infB.item(),
                           )
                print(prn_str)
                if self.record_file:
                    record = open(self.record_file, 'a')
                    record.write('%s\n' % (prn_str,))
                    record.close()

            # save model parameters
            if iteration % self.ckpt_save_iter == 0:
                self.save_checkpoint(iteration)

            # save output images (recon, synth, etc.)
            if iteration % self.output_save_iter == 0:
                # self.save_embedding(iteration, index, muA_infA, muB_infB, muS_infA, muS_infB, muS_POE)

                # 1) save the recon images
                self.save_recon(iteration)

                # self.save_recon2(iteration, index, XA, XB,
                #     torch.sigmoid(XA_infA_recon).data,
                #     torch.sigmoid(XB_infB_recon).data,
                #     torch.sigmoid(XA_POE_recon).data,
                #     torch.sigmoid(XB_POE_recon).data,
                #     muA_infA, muB_infB, muS_infA, muS_infB, muS_POE,
                #     logalpha, logalphaA, logalphaB
                # )
                z_A, z_B, z_S = self.get_stat()

                #
                #
                #
                # # 2) save the pure-synthesis images
                # # self.save_synth_pure( iteration, howmany=100 )
                # #
                # # 3) save the cross-modal-synthesis images
                # self.save_synth_cross_modal(iteration, z_A, z_B, howmany=3)
                #
                # # 4) save the latent traversed images
                self.save_traverseB(iteration, z_A, z_B, z_S)

                # self.get_loglike(logalpha, logalphaA, logalphaB)

                # # 3) save the latent traversed images
                # if self.dataset.lower() == '3dchairs':
                #     self.save_traverse(iteration, limb=-2, limu=2, inter=0.5)
                # else:
                #     self.save_traverse(iteration, limb=-3, limu=3, inter=0.1)

            if iteration % self.eval_metrics_iter == 0:
                # NOTE: z_A and z_B are only bound after the output-save block
                # above has run at least once; otherwise this raises NameError
                self.save_synth_cross_modal(iteration, z_A, z_B, train=False, howmany=3)

            # (visdom) insert current line stats
            if self.viz_on and (iteration % self.viz_ll_iter == 0):
                self.line_gather.insert(iter=iteration,
                                        recon_both=loss_recon_POE.item(),
                                        recon_A=loss_recon_infA.item(),
                                        recon_B=loss_recon_infB.item(),
                                        kl_A=loss_kl_infA.item(),
                                        kl_B=loss_kl_infB.item(),
                                        cont_capacity_loss_infA=cont_capacity_loss_infA.item(),
                                        disc_capacity_loss_infA=disc_capacity_loss_infA.item(),
                                        cont_capacity_loss_infB=cont_capacity_loss_infB.item(),
                                        disc_capacity_loss_infB=disc_capacity_loss_infB.item()
                                        )

            # (visdom) visualize line stats (then flush out)
            if self.viz_on and (iteration % self.viz_la_iter == 0):
                self.visualize_line()
                self.line_gather.flush()

            # evaluate metrics
            # if self.eval_metrics and (iteration % self.eval_metrics_iter == 0):
            #
            #     metric1, _ = self.eval_disentangle_metric1()
            #     metric2, _ = self.eval_disentangle_metric2()
            #
            #     prn_str = ( '********\n[iter %d (epoch %d)] ' + \
            #       'metric1 = %.4f, metric2 = %.4f\n********' ) % \
            #       (iteration, epoch, metric1, metric2)
            #     print(prn_str)
            #     if self.record_file:
            #         record = open(self.record_file, 'a')
            #         record.write('%s\n' % (prn_str,))
            #         record.close()
            #
            #     # (visdom) visulaize metrics
            #     if self.viz_on:
            #         self.visualize_line_metrics(iteration, metric1, metric2)
            #
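
    # --- illustrative sketch (not part of the original Solver) --------------
    # The shared code above is fused with a categorical product-of-experts,
    # exp(log(1/10) + log pA + log pB), left unnormalized in train. A normalized
    # variant (the uniform prior is a constant shift, so it cancels):
    @staticmethod
    def _categorical_poe_sketch(prob_a, prob_b, eps=1e-12):
        import torch
        log_poe = torch.log(prob_a + eps) + torch.log(prob_b + eps)
        log_poe = log_poe - torch.logsumexp(log_poe, dim=-1, keepdim=True)
        return log_poe.exp()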


    ####
    def eval_disentangle_metric1(self):

        # some hyperparams
        num_pairs = 800  # number of data pairs (d, y) for majority-vote classification
        bs = 50  # batch size
        nsamps_per_factor = 100  # samples per factor
        nsamps_agn_factor = 5000  # factor-agnostic samples

        self.set_mode(train=False)

        # 1) estimate variances of latent points factor agnostic

        dl = DataLoader(
            self.data_loader.dataset, batch_size=bs,
            shuffle=True, num_workers=self.args.num_workers, pin_memory=True)
        iterator = iter(dl)

        M = []
        for ib in range(int(nsamps_agn_factor / bs)):

            # sample a mini-batch
            XAb, XBb, _, _, _ = next(iterator)  # (bs x C x H x W)
            if self.use_cuda:
                XAb = XAb.cuda()
                XBb = XBb.cuda()

            # z = encA(xA)
            mu_infA, _, logvar_infA = self.encoderA(XAb)

            # z = encB(xB)
            mu_infB, _, logvar_infB = self.encoderB(XBb)

            # z = encAB(xA,xB) via POE
            mu_POE, _, _ = apply_poe(
                self.use_cuda, mu_infA, logvar_infA, mu_infB, logvar_infB,
            )

            mub = mu_POE

            M.append(mub.cpu().detach().numpy())

        M = np.concatenate(M, 0)

        # estimate sample variance and mean of latent points for each dim
        vars_agn_factor = np.var(M, 0)

        # 2) estimate dim-wise vars of latent points with "one factor fixed"

        factor_ids = range(0, len(self.latent_sizes))  # true factor ids
        vars_per_factor = np.zeros([num_pairs, self.z_dim])
        true_factor_ids = np.zeros(num_pairs, int)  # true factor ids

        # prepare data pairs for majority-vote classification
        i = 0
        for j in factor_ids:  # for each factor

            # repeat num_pairs/num_factors times
            for r in range(int(num_pairs / len(factor_ids))):

                # a true factor (id and class value) to fix
                fac_id = j
                fac_class = np.random.randint(self.latent_sizes[fac_id])

                # randomly select images (with the fixed factor)
                indices = np.where(
                    self.latent_classes[:, fac_id] == fac_class)[0]
                np.random.shuffle(indices)
                idx = indices[:nsamps_per_factor]
                M = []
                for ib in range(int(nsamps_per_factor / bs)):
                    XAb, XBb, _, _, _ = dl.dataset[idx[(ib * bs):(ib + 1) * bs]]
                    if XAb.shape[0] < 1:  # no more samples
                        continue
                    if self.use_cuda:
                        XAb = XAb.cuda()
                        XBb = XBb.cuda()
                    mu_infA, _, logvar_infA = self.encoderA(XAb)
                    mu_infB, _, logvar_infB = self.encoderB(XBb)
                    mu_POE, _, _ = apply_poe(self.use_cuda,
                                             mu_infA, logvar_infA, mu_infB, logvar_infB,
                                             )
                    mub = mu_POE
                    M.append(mub.cpu().detach().numpy())
                M = np.concatenate(M, 0)

                # estimate sample var and mean of latent points for each dim
                if M.shape[0] >= 2:
                    vars_per_factor[i, :] = np.var(M, 0)
                else:  # not enough samples to estimate variance
                    vars_per_factor[i, :] = 0.0

                # true factor id (will become the class label)
                true_factor_ids[i] = fac_id

                i += 1

        # 3) evaluate majority vote classification accuracy

        # inputs in the paired data for classification
        smallest_var_dims = np.argmin(
            vars_per_factor / (vars_agn_factor + 1e-20), axis=1)

        # contingency table
        C = np.zeros([self.z_dim, len(factor_ids)])
        for i in range(num_pairs):
            C[smallest_var_dims[i], true_factor_ids[i]] += 1

        num_errs = 0  # number of misclassification errors of the majority-vote classifier
        for k in range(self.z_dim):
            num_errs += np.sum(C[k, :]) - np.max(C[k, :])

        metric1 = (num_pairs - num_errs) / num_pairs  # metric = accuracy

        self.set_mode(train=True)

        return metric1, C
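
    # --- illustrative sketch -------------------------------------------------
    # The accuracy above equals the sum of per-row maxima of the contingency
    # table divided by the number of pairs; a compact numpy restatement, where
    # `dims` holds the chosen latent dim and `factors` the true factor id:
    @staticmethod
    def _majority_vote_accuracy_sketch(dims, factors, z_dim, n_factors):
        import numpy as np
        C = np.zeros([z_dim, n_factors])
        for d, f in zip(dims, factors):
            C[d, f] += 1
        return C.max(axis=1).sum() / len(dims)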

    ####
    def eval_disentangle_metric2(self):

        # some hyperparams
        num_pairs = 800  # number of data pairs (d, y) for majority-vote classification
        bs = 50  # batch size
        nsamps_per_factor = 100  # samples per factor
        nsamps_agn_factor = 5000  # factor-agnostic samples

        self.set_mode(train=False)

        # 1) estimate variances of latent points factor agnostic

        dl = DataLoader(
            self.data_loader.dataset, batch_size=bs,
            shuffle=True, num_workers=self.args.num_workers, pin_memory=True)
        iterator = iter(dl)

        M = []
        for ib in range(int(nsamps_agn_factor / bs)):

            # sample a mini-batch
            XAb, XBb, _, _, _ = next(iterator)  # (bs x C x H x W)
            if self.use_cuda:
                XAb = XAb.cuda()
                XBb = XBb.cuda()

            # z = encA(xA)
            mu_infA, _, logvar_infA = self.encoderA(XAb)

            # z = encB(xB)
            mu_infB, _, logvar_infB = self.encoderB(XBb)

            # z = encAB(xA,xB) via POE
            mu_POE, _, _ = apply_poe(
                self.use_cuda, mu_infA, logvar_infA, mu_infB, logvar_infB,
            )

            mub = mu_POE

            M.append(mub.cpu().detach().numpy())

        M = np.concatenate(M, 0)

        # estimate sample variance and mean of latent points for each dim
        vars_agn_factor = np.var(M, 0)

        # 2) estimate dim-wise vars of latent points with "one factor varied"

        factor_ids = range(0, len(self.latent_sizes))  # true factor ids
        vars_per_factor = np.zeros([num_pairs, self.z_dim])
        true_factor_ids = np.zeros(num_pairs, int)  # true factor ids

        # prepare data pairs for majority-vote classification
        i = 0
        for j in factor_ids:  # for each factor

            # repeat num_pairs/num_factors times
            for r in range(int(num_pairs / len(factor_ids))):

                # randomly choose true factors (id's and class values) to fix
                fac_ids = list(np.setdiff1d(factor_ids, j))
                fac_classes = \
                    [np.random.randint(self.latent_sizes[k]) for k in fac_ids]

                # randomly select images (with the other factors fixed)
                if len(fac_ids) > 1:
                    indices = np.where(
                        np.sum(self.latent_classes[:, fac_ids] == fac_classes, 1)
                        == len(fac_ids)
                    )[0]
                else:
                    indices = np.where(
                        self.latent_classes[:, fac_ids] == fac_classes
                    )[0]
                np.random.shuffle(indices)
                idx = indices[:nsamps_per_factor]
                M = []
                for ib in range(int(nsamps_per_factor / bs)):
                    XAb, XBb, _, _, _ = dl.dataset[idx[(ib * bs):(ib + 1) * bs]]
                    if XAb.shape[0] < 1:  # no more samples
                        continue
                    if self.use_cuda:
                        XAb = XAb.cuda()
                        XBb = XBb.cuda()
                    mu_infA, _, logvar_infA = self.encoderA(XAb)
                    mu_infB, _, logvar_infB = self.encoderB(XBb)
                    mu_POE, _, _ = apply_poe(self.use_cuda,
                                             mu_infA, logvar_infA, mu_infB, logvar_infB,
                                             )
                    mub = mu_POE
                    M.append(mub.cpu().detach().numpy())
                M = np.concatenate(M, 0)

                # estimate sample var and mean of latent points for each dim
                if M.shape[0] >= 2:
                    vars_per_factor[i, :] = np.var(M, 0)
                else:  # not enough samples to estimate variance
                    vars_per_factor[i, :] = 0.0

                # true factor id (will become the class label)
                true_factor_ids[i] = j

                i += 1

        # 3) evaluate majority vote classification accuracy

        # inputs in the paired data for classification
        largest_var_dims = np.argmax(
            vars_per_factor / (vars_agn_factor + 1e-20), axis=1)

        # contingency table
        C = np.zeros([self.z_dim, len(factor_ids)])
        for i in range(num_pairs):
            C[largest_var_dims[i], true_factor_ids[i]] += 1

        num_errs = 0  # number of misclassification errors of the majority-vote classifier
        for k in range(self.z_dim):
            num_errs += np.sum(C[k, :]) - np.max(C[k, :])

        metric2 = (num_pairs - num_errs) / num_pairs  # metric = accuracy

        self.set_mode(train=True)

        return metric2, C

    def save_recon(self, iters):
        self.set_mode(train=False)

        mkdirs(self.output_dir_recon)

        fixed_idxs = [3246, 7000, 14305, 19000, 27444, 33100, 38000, 45231, 51000, 55121]

        fixed_idxs60 = []
        for idx in fixed_idxs:
            for i in range(6):
                fixed_idxs60.append(idx + i)

        XA = [0] * len(fixed_idxs60)
        XB = [0] * len(fixed_idxs60)

        for i, idx in enumerate(fixed_idxs60):
            XA[i], XB[i] = \
                self.data_loader.dataset.__getitem__(idx)[0:2]

            if self.use_cuda:
                XA[i] = XA[i].cuda()
                XB[i] = XB[i].cuda()

        XA = torch.stack(XA)
        XB = torch.stack(XB)

        muA_infA, stdA_infA, logvarA_infA, cate_prob_infA = self.encoderA(XA)

        # zB, zS = encB(xB)
        muB_infB, stdB_infB, logvarB_infB, cate_prob_infB = self.encoderB(XB)

        # zS = encAB(xA,xB) via POE
        cate_prob_POE = torch.exp(
            torch.log(torch.tensor(1 / 10)) + torch.log(cate_prob_infA) + torch.log(cate_prob_infB))

        # encoder samples (for training)
        ZA_infA = sample_gaussian(self.use_cuda, muA_infA, stdA_infA)
        ZB_infB = sample_gaussian(self.use_cuda, muB_infB, stdB_infB)
        ZS_POE = sample_gumbel_softmax(self.use_cuda, cate_prob_POE, train=False)

        # encoder samples (for cross-modal prediction)
        ZS_infA = sample_gumbel_softmax(self.use_cuda, cate_prob_infA, train=False)
        ZS_infB = sample_gumbel_softmax(self.use_cuda, cate_prob_infB, train=False)

        # reconstructed samples (given joint modal observation)
        XA_POE_recon = torch.sigmoid(self.decoderA(ZA_infA, ZS_POE))
        XB_POE_recon = torch.sigmoid(self.decoderB(ZB_infB, ZS_POE))

        # reconstructed samples (given single modal observation)
        XA_infA_recon = torch.sigmoid(self.decoderA(ZA_infA, ZS_infA))
        XB_infB_recon = torch.sigmoid(self.decoderB(ZB_infB, ZS_infB))

        WS = torch.ones(XA.shape)
        if self.use_cuda:
            WS = WS.cuda()

        n = XA.shape[0]
        perm = torch.arange(0, 4 * n).view(4, n).transpose(1, 0)
        perm = perm.contiguous().view(-1)

        ## img
        # merged = torch.cat(
        #     [ XA, XB, XA_infA_recon, XB_infB_recon,
        #       XA_POE_recon, XB_POE_recon, WS ], dim=0
        # )
        merged = torch.cat(
            [XA, XA_infA_recon, XA_POE_recon, WS], dim=0
        )
        merged = merged[perm, :].cpu()

        # save the results as image
        fname = os.path.join(self.output_dir_recon, 'reconA_%s.jpg' % iters)
        mkdirs(self.output_dir_recon)
        save_image(
            tensor=merged, filename=fname, nrow=4 * int(np.sqrt(n)),
            pad_value=1
        )

        WS = torch.ones(XB.shape)
        if self.use_cuda:
            WS = WS.cuda()

        n = XB.shape[0]
        perm = torch.arange(0, 4 * n).view(4, n).transpose(1, 0)
        perm = perm.contiguous().view(-1)

        ## ingr
        merged = torch.cat(
            [XB, XB_infB_recon, XB_POE_recon, WS], dim=0
        )
        merged = merged[perm, :].cpu()

        # save the results as image
        fname = os.path.join(self.output_dir_recon, 'reconB_%s.jpg' % iters)
        mkdirs(self.output_dir_recon)
        save_image(
            tensor=merged, filename=fname, nrow=4 * int(np.sqrt(n)),
            pad_value=1
        )
        self.set_mode(train=True)
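
    # --- illustrative sketch -------------------------------------------------
    # The arange/view/transpose trick above interleaves g groups of n rows, so
    # inputs, reconstructions, and padding alternate per sample in the grid:
    @staticmethod
    def _interleave_rows_sketch(tensors):
        import torch
        g, n = len(tensors), tensors[0].size(0)
        perm = torch.arange(0, g * n).view(g, n).transpose(1, 0).contiguous().view(-1)
        return torch.cat(tensors, dim=0)[perm]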

    ####
    def save_synth_pure(self, iters, howmany=100):

        self.set_mode(train=False)

        decoderA = self.decoderA
        decoderB = self.decoderB

        Z = torch.randn(howmany, self.z_dim)
        if self.use_cuda:
            Z = Z.cuda()

        # do synthesis
        XA = torch.sigmoid(decoderA(Z)).data
        XB = torch.sigmoid(decoderB(Z)).data

        WS = torch.ones(XA.shape)
        if self.use_cuda:
            WS = WS.cuda()

        perm = torch.arange(0, 3 * howmany).view(3, howmany).transpose(1, 0)
        perm = perm.contiguous().view(-1)
        merged = torch.cat([XA, XB, WS], dim=0)
        merged = merged[perm, :].cpu()

        # save the results as image
        fname = os.path.join(
            self.output_dir_synth, 'synth_pure_%s.jpg' % iters
        )
        mkdirs(self.output_dir_synth)
        save_image(
            tensor=merged, filename=fname, nrow=3 * int(np.sqrt(howmany)),
            pad_value=1
        )

        self.set_mode(train=True)

    ####
    def save_synth_cross_modal(self, iters, z_A_stat, z_B_stat, train=True, howmany=3):

        self.set_mode(train=False)

        if train:
            data_loader = self.data_loader
            fixed_idxs = [3246, 7001, 14308, 19000, 27447, 33103, 38002, 45232, 51000, 55125]
        else:
            data_loader = self.test_data_loader
            fixed_idxs = [2, 982, 2300, 3400, 4500, 5500, 6500, 7500, 8500, 9500]

        fixed_XA = [0] * len(fixed_idxs)
        fixed_XB = [0] * len(fixed_idxs)

        for i, idx in enumerate(fixed_idxs):

            fixed_XA[i], fixed_XB[i] = \
                data_loader.dataset.__getitem__(idx)[0:2]

            if self.use_cuda:
                fixed_XA[i] = fixed_XA[i].cuda()
                fixed_XB[i] = fixed_XB[i].cuda()

        fixed_XA = torch.stack(fixed_XA)
        fixed_XB = torch.stack(fixed_XB)

        _, _, _, cate_prob_infA = self.encoderA(fixed_XA)

        # zB, zS = encB(xB)
        _, _, _, cate_prob_infB = self.encoderB(fixed_XB)

        ZS_infA = sample_gumbel_softmax(self.use_cuda, cate_prob_infA, train=False)
        ZS_infB = sample_gumbel_softmax(self.use_cuda, cate_prob_infB, train=False)

        if self.use_cuda:
            ZS_infA = ZS_infA.cuda()
            ZS_infB = ZS_infB.cuda()

        decoderA = self.decoderA
        decoderB = self.decoderB

        # mkdirs(os.path.join(self.output_dir_synth, str(iters)))

        fixed_XA_3ch = []
        for i in range(len(fixed_XA)):
            each_XA = fixed_XA[i].clone().squeeze()
            fixed_XA_3ch.append(torch.stack([each_XA, each_XA, each_XA]))

        fixed_XA_3ch = torch.stack(fixed_XA_3ch)

        WS = torch.ones(fixed_XA_3ch.shape)
        if self.use_cuda:
            WS = WS.cuda()

        n = len(fixed_idxs)

        perm = torch.arange(0, (howmany + 2) * n).view(howmany + 2, n).transpose(1, 0)
        perm = perm.contiguous().view(-1)

        ######## 1) generate xB from given xA (A2B) ########

        merged = torch.cat([fixed_XA_3ch], dim=0)
        for k in range(howmany):
            # z_B_stat = np.array(z_B_stat)
            # z_B_stat_mean = np.mean(z_B_stat, 0)
            # ZB = torch.Tensor(z_B_stat_mean)
            # ZB_list = []
            # for _ in range(n):
            #     ZB_list.append(ZB)
            # ZB = torch.stack(ZB_list)

            ZB = torch.randn(n, self.zB_dim)
            z_B_stat = np.array(z_B_stat)
            z_B_stat_mean = np.mean(z_B_stat, 0)
            ZB = ZB + torch.Tensor(z_B_stat_mean)

            if self.use_cuda:
                ZB = ZB.cuda()
            XB_synth = torch.sigmoid(decoderB(ZB, ZS_infA))  # given XA
            # merged = torch.cat([merged, fixed_XA_3ch], dim=0)
            merged = torch.cat([merged, XB_synth], dim=0)
        merged = torch.cat([merged, WS], dim=0)
        merged = merged[perm, :].cpu()

        # save the results as image
        if train:
            fname = os.path.join(
                self.output_dir_synth,
                'synth_cross_modal_A2B_%s.jpg' % iters
            )
        else:
            fname = os.path.join(
                self.output_dir_synth,
                'eval_synth_cross_modal_A2B_%s.jpg' % iters
            )
        mkdirs(self.output_dir_synth)
        save_image(
            tensor=merged, filename=fname, nrow=(howmany + 2) * int(np.sqrt(n)),
            pad_value=1
        )

        ######## 2) generate xA from given xB (B2A) ########
        merged = torch.cat([fixed_XB], dim=0)
        for k in range(howmany):
            # z_A_stat = np.array(z_A_stat)
            # z_A_stat_mean = np.mean(z_A_stat, 0)
            # ZA = torch.Tensor(z_A_stat_mean)
            # ZA_list = []
            # for _ in range(n):
            #     ZA_list.append(ZA)
            # ZA = torch.stack(ZA_list)

            ZA = torch.randn(n, self.zA_dim)
            z_A_stat = np.array(z_A_stat)
            z_A_stat_mean = np.mean(z_A_stat, 0)
            ZA = ZA + torch.Tensor(z_A_stat_mean)

            if self.use_cuda:
                ZA = ZA.cuda()
            XA_synth = torch.sigmoid(decoderA(ZA, ZS_infB))  # given XB

            XA_synth_3ch = []
            for i in range(len(XA_synth)):
                each_XA = XA_synth[i].clone().squeeze()
                XA_synth_3ch.append(torch.stack([each_XA, each_XA, each_XA]))

            # merged = torch.cat([merged, fixed_XB[:,:,2:30, 2:30]], dim=0)
            merged = torch.cat([merged, torch.stack(XA_synth_3ch)], dim=0)
        merged = torch.cat([merged, WS], dim=0)
        merged = merged[perm, :].cpu()

        # save the results as image
        if train:
            fname = os.path.join(
                self.output_dir_synth,
                'synth_cross_modal_B2A_%s.jpg' % iters
            )
        else:
            fname = os.path.join(
                self.output_dir_synth,
                'eval_synth_cross_modal_B2A_%s.jpg' % iters
            )
        mkdirs(self.output_dir_synth)
        save_image(
            tensor=merged, filename=fname, nrow=(howmany + 2) * int(np.sqrt(n)),
            pad_value=1
        )

        self.set_mode(train=True)
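
The helper sample_gumbel_softmax used above is not shown in this example. A minimal sketch of what such a helper typically does (the name and signature are simplified assumptions; the real helper also takes a use_cuda flag): at train time it draws a relaxed one-hot sample via Gumbel noise, and at eval time (train=False, as called here) it returns a deterministic one-hot at the argmax.

import torch
import torch.nn.functional as F

def sample_gumbel_softmax_sketch(probs, tau=0.67, train=True, eps=1e-20):
    # train: draw a relaxed (differentiable) sample via the Gumbel-softmax trick
    if train:
        u = torch.rand_like(probs)
        gumbel = -torch.log(-torch.log(u + eps) + eps)
        return F.softmax((torch.log(probs + eps) + gumbel) / tau, dim=-1)
    # eval (train=False, as called above): deterministic one-hot at the argmax
    idx = probs.argmax(dim=-1, keepdim=True)
    return torch.zeros_like(probs).scatter_(-1, idx, 1.0)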

    def get_stat(self):
        encoderA = self.encoderA
        encoderB = self.encoderB

        z_A, z_B, z_S = [], [], []
        for _ in range(10000):
            rand_i = np.random.randint(self.N)
            random_XA, random_XB = self.data_loader.dataset.__getitem__(rand_i)[0:2]
            if self.use_cuda:
                random_XA = random_XA.cuda()
                random_XB = random_XB.cuda()
            random_XA = random_XA.unsqueeze(0)
            random_XB = random_XB.unsqueeze(0)

            muA_infA, stdA_infA, logvarA_infA, cate_prob_infA = self.encoderA(random_XA)

            # zB, zS = encB(xB)
            muB_infB, stdB_infB, logvarB_infB, cate_prob_infB = self.encoderB(random_XB)
            cate_prob_POE = torch.exp(
                torch.log(torch.tensor(1 / 10)) + torch.log(cate_prob_infA) + torch.log(cate_prob_infB))

            z_A.append(muA_infA.cpu().detach().numpy()[0])
            z_B.append(muB_infB.cpu().detach().numpy()[0])
            z_S.append(cate_prob_POE.cpu().detach().numpy()[0])
        return z_A, z_B, z_S
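
get_stat fuses the two categorical heads with a uniform prior via a product of experts, computed in log space. A self-contained sketch of that computation with toy inputs (K = 10 matches the 1/10 prior above); note that a proper PoE posterior also renormalizes, which the expression in get_stat leaves implicit:

import torch

K = 10  # number of shared categories, matching the 1/10 uniform prior above
p_A = torch.softmax(torch.randn(1, K), dim=-1)  # stand-in for cate_prob_infA
p_B = torch.softmax(torch.randn(1, K), dim=-1)  # stand-in for cate_prob_infB

# product of experts with a uniform prior, done in log space for stability
log_poe = torch.log(torch.tensor(1.0 / K)) + torch.log(p_A) + torch.log(p_B)
p_poe = torch.exp(log_poe)
p_poe = p_poe / p_poe.sum(dim=-1, keepdim=True)  # renormalize to a distribution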


    def save_traverseA(self, iters, z_A, z_B, z_S, loc=-1):

        self.set_mode(train=False)

        encoderA = self.encoderA
        encoderB = self.encoderB
        decoderA = self.decoderA
        decoderB = self.decoderB
        interpolationA = torch.tensor(np.linspace(-3, 3, self.zS_dim))
        interpolationB = torch.tensor(np.linspace(-3, 3, self.zS_dim))
        interpolationS = torch.tensor(np.linspace(-3, 3, self.zS_dim))

        print('------------ traverse interpolation ------------')
        print('interpolationA: ', np.min(np.array(z_A)), np.max(np.array(z_A)))
        print('interpolationB: ', np.min(np.array(z_B)), np.max(np.array(z_B)))
        print('interpolationS: ', np.min(np.array(z_S)), np.max(np.array(z_S)))

        if self.record_file:
            ####
            fixed_idxs = [3246, 7000, 14305, 19000, 27444, 33100, 38000, 45231, 51000, 55121]

            fixed_XA = [0] * len(fixed_idxs)
            fixed_XB = [0] * len(fixed_idxs)

            for i, idx in enumerate(fixed_idxs):

                fixed_XA[i], fixed_XB[i] = \
                    self.data_loader.dataset.__getitem__(idx)[0:2]
                if self.use_cuda:
                    fixed_XA[i] = fixed_XA[i].cuda()
                    fixed_XB[i] = fixed_XB[i].cuda()
                fixed_XA[i] = fixed_XA[i].unsqueeze(0)
                fixed_XB[i] = fixed_XB[i].unsqueeze(0)

            fixed_XA = torch.cat(fixed_XA, dim=0)
            fixed_XB = torch.cat(fixed_XB, dim=0)

            fixed_zmuA, _, _, cate_prob_infA = encoderA(fixed_XA)

            # zB, zS = encB(xB)
            fixed_zmuB, _, _, cate_prob_infB = encoderB(fixed_XB)

            # zS = encAB(xA,xB) via POE
            fixed_cate_probS = torch.exp(
                torch.log(torch.tensor(1 / 10)) + torch.log(cate_prob_infA) + torch.log(cate_prob_infB))


            # fixed_zS = sample_gumbel_softmax(self.use_cuda, fixed_cate_probS, train=False)
            # NOTE: the shared code is sampled from encoder A's probabilities here;
            # the PoE probabilities computed just above are not used
            fixed_zS = sample_gumbel_softmax(self.use_cuda, cate_prob_infA, train=False)


            saving_shape = torch.cat([fixed_XA[i] for i in range(fixed_XA.shape[0])], dim=1).shape

        ####

        WS = torch.ones(saving_shape)
        if self.use_cuda:
            WS = WS.cuda()

        # do traversal and collect generated images
        gifs = []

        zA_ori, zB_ori, zS_ori = fixed_zmuA, fixed_zmuB, fixed_zS

        tempA = [] # zA_dim + zS_dim , num_trv, 1, 32*num_samples, 32
        for row in range(self.zA_dim):
            if loc != -1 and row != loc:
                continue
            zA = zA_ori.clone()

            temp = []
            for val in interpolationA:
                zA[:, row] = val
                sampleA = torch.sigmoid(decoderA(zA, zS_ori)).data
                temp.append((torch.cat([sampleA[i] for i in range(sampleA.shape[0])], dim=1)).unsqueeze(0))

            tempA.append(torch.cat(temp, dim=0).unsqueeze(0)) # torch.cat(temp, dim=0) = num_trv, 1, 32*num_samples, 32

        temp = []
        for i in range(self.zS_dim):
            zS = np.zeros((1, self.zS_dim))
            zS[0, i % self.zS_dim] = 1.
            zS = torch.Tensor(zS)
            zS = torch.cat([zS] * len(fixed_idxs), dim=0)

            if self.use_cuda:
                zS = zS.cuda()

            sampleA = torch.sigmoid(decoderA(zA_ori, zS)).data
            temp.append((torch.cat([sampleA[i] for i in range(sampleA.shape[0])], dim=1)).unsqueeze(0))
        tempA.append(torch.cat(temp, dim=0).unsqueeze(0))
        gifs = torch.cat(tempA, dim=0) #torch.Size([11, 10, 1, 384, 32])


        # save the generated files, also the animated gifs
        out_dir = os.path.join(self.output_dir_trvsl, str(iters), 'train')
        mkdirs(self.output_dir_trvsl)
        mkdirs(out_dir)

        for j, val in enumerate(interpolationA):
            # I = torch.cat([IMG[key], gifs[:][j]], dim=0)
            I = gifs[:,j]
            save_image(
                tensor=I.cpu(),
                filename=os.path.join(out_dir, '%03d.jpg' % (j)),
                nrow=1 + self.zA_dim + 1 + 1 + 1 + self.zB_dim,
                pad_value=1)
            # make animated gif
        grid2gif2(
            out_dir, str(os.path.join(out_dir, 'mnist_traverse' + '.gif')), delay=10
        )

        self.set_mode(train=True)
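
save_traverseA follows the standard traversal pattern: clone the inferred latent, overwrite one coordinate with each value on a grid, decode, and collect the frames. A minimal sketch of that pattern in isolation (the decoder here is a single-latent placeholder; the decoders above additionally take the shared code zS):

import torch

def traverse_one_dim(decoder, z, dim, values):
    # decode copies of z with coordinate `dim` swept over the grid `values`
    frames = []
    for val in values:
        z_t = z.clone()
        z_t[:, dim] = val
        frames.append(torch.sigmoid(decoder(z_t)).detach())
    return frames

# usage sketch: frames = traverse_one_dim(dec, z, dim=0, values=torch.linspace(-3, 3, 10))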



    ###
    def save_traverseB(self, iters, z_A, z_B, z_S, loc=-1):

        self.set_mode(train=False)

        encoderA = self.encoderA
        encoderB = self.encoderB
        decoderB = self.decoderB
        interpolationA = torch.tensor(np.linspace(-3, 3, self.zS_dim))

        print('------------ traverse interpolation ------------')
        print('interpolationA: ', np.min(np.array(z_A)), np.max(np.array(z_A)))
        print('interpolationB: ', np.min(np.array(z_B)), np.max(np.array(z_B)))
        print('interpolationS: ', np.min(np.array(z_S)), np.max(np.array(z_S)))

        if self.record_file:
            ####
            fixed_idxs = [3246, 7000, 14305, 19000, 27444, 33100, 38000, 45231, 51000, 55121]

            fixed_XA = [0] * len(fixed_idxs)
            fixed_XB = [0] * len(fixed_idxs)

            for i, idx in enumerate(fixed_idxs):

                fixed_XA[i], fixed_XB[i] = \
                    self.data_loader.dataset.__getitem__(idx)[0:2]
                if self.use_cuda:
                    fixed_XA[i] = fixed_XA[i].cuda()
                    fixed_XB[i] = fixed_XB[i].cuda()
                fixed_XA[i] = fixed_XA[i].unsqueeze(0)
                fixed_XB[i] = fixed_XB[i].unsqueeze(0)

            fixed_XA = torch.cat(fixed_XA, dim=0)
            fixed_XB = torch.cat(fixed_XB, dim=0)

            fixed_zmuA, _, _, cate_prob_infA = encoderA(fixed_XA)

            # zB, zS = encB(xB)
            fixed_zmuB, _, _, cate_prob_infB = encoderB(fixed_XB)


            # fixed_zS = sample_gumbel_softmax(self.use_cuda, fixed_cate_probS, train=False)
            fixed_zS = sample_gumbel_softmax(self.use_cuda, cate_prob_infB, train=False)


            saving_shape = torch.cat([fixed_XA[i] for i in range(fixed_XA.shape[0])], dim=1).shape

        ####

        WS = torch.ones(saving_shape)
        if self.use_cuda:
            WS = WS.cuda()

        # do traversal and collect generated images
        gifs = []

        zA_ori, zB_ori, zS_ori = fixed_zmuA, fixed_zmuB, fixed_zS

        tempB = [] # zA_dim + zS_dim , num_trv, 1, 32*num_samples, 32
        for row in range(self.zB_dim):
            if loc != -1 and row != loc:
                continue
            zB = zB_ori.clone()

            temp = []
            for val in interpolationA:
                zB[:, row] = val
                sampleB = torch.sigmoid(decoderB(zB, zS_ori)).data
                temp.append((torch.cat([sampleB[i] for i in range(sampleB.shape[0])], dim=1)).unsqueeze(0))

            tempB.append(torch.cat(temp, dim=0).unsqueeze(0)) # torch.cat(temp, dim=0) = num_trv, 1, 32*num_samples, 32

        temp = []
        for i in range(self.zS_dim):
            zS = np.zeros((1, self.zS_dim))
            zS[0, i % self.zS_dim] = 1.
            zS = torch.Tensor(zS)
            zS = torch.cat([zS] * len(fixed_idxs), dim=0)

            if self.use_cuda:
                zS = zS.cuda()

            sampleB = torch.sigmoid(decoderB(zB_ori, zS)).data
            temp.append((torch.cat([sampleB[i] for i in range(sampleB.shape[0])], dim=1)).unsqueeze(0))
        tempB.append(torch.cat(temp, dim=0).unsqueeze(0))
        gifs = torch.cat(tempB, dim=0) #torch.Size([11, 10, 1, 384, 32])


        # save the generated files, also the animated gifs
        out_dir = os.path.join(self.output_dir_trvsl, str(iters), 'train')
        mkdirs(self.output_dir_trvsl)
        mkdirs(out_dir)

        for j, val in enumerate(interpolationA):
            # I = torch.cat([IMG[key], gifs[:][j]], dim=0)
            I = gifs[:,j]
            save_image(
                tensor=I.cpu(),
                filename=os.path.join(out_dir, '%03d.jpg' % (j)),
                nrow=1 + self.zA_dim + 1 + 1 + 1 + self.zB_dim,
                pad_value=1)
            # make animated gif
        grid2gif2(
            out_dir, str(os.path.join(out_dir, 'fmnist_traverse' + '.gif')), delay=10
        )

        self.set_mode(train=True)

    ####
    def viz_init(self):

        self.viz.close(env=self.name + '/lines', win=self.win_id['recon'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['kl'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['capa'])

        # if self.eval_metrics:
        #     self.viz.close(env=self.name+'/lines', win=self.win_id['metrics'])

    ####
    def visualize_line(self):

        # prepare data to plot
        data = self.line_gather.data
        iters = torch.Tensor(data['iter'])
        recon_both = torch.Tensor(data['recon_both'])
        recon_A = torch.Tensor(data['recon_A'])
        recon_B = torch.Tensor(data['recon_B'])
        kl_A = torch.Tensor(data['kl_A'])
        kl_B = torch.Tensor(data['kl_B'])

        cont_capacity_loss_infA = torch.Tensor(data['cont_capacity_loss_infA'])
        disc_capacity_loss_infA = torch.Tensor(data['disc_capacity_loss_infA'])
        cont_capacity_loss_infB = torch.Tensor(data['cont_capacity_loss_infB'])
        disc_capacity_loss_infB = torch.Tensor(data['disc_capacity_loss_infB'])


        recons = torch.stack(
            [recon_both.detach(), recon_A.detach(), recon_B.detach()], -1
        )
        kls = torch.stack(
            [kl_A.detach(), kl_B.detach()], -1
        )

        each_capa = torch.stack(
            [cont_capacity_loss_infA.detach(), disc_capacity_loss_infA.detach(),
             cont_capacity_loss_infB.detach(), disc_capacity_loss_infB.detach()],
            -1
        )

        self.viz.line(
            X=iters, Y=recons, env=self.name + '/lines',
            win=self.win_id['recon'], update='append',
            opts=dict(xlabel='iter', ylabel='recon losses',
                      title='Recon Losses', legend=['both', 'A', 'B'])
        )

        self.viz.line(
            X=iters, Y=kls, env=self.name + '/lines',
            win=self.win_id['kl'], update='append',
            opts=dict(xlabel='iter', ylabel='kl losses',
                      title='KL Losses', legend=['A', 'B']),
        )

        self.viz.line(
            X=iters, Y=each_capa, env=self.name + '/lines',
            win=self.win_id['capa'], update='append',
            opts=dict(xlabel='iter', ylabel='logalpha',
                      title='Capacity loss', legend=['cont_capaA', 'disc_capaA', 'cont_capaB', 'disc_capaB']),
        )


    ####
    def visualize_line_metrics(self, iters, metric1, metric2):

        # prepare data to plot
        iters = torch.tensor([iters], dtype=torch.int64).detach()
        metric1 = torch.tensor([metric1])
        metric2 = torch.tensor([metric2])
        metrics = torch.stack([metric1.detach(), metric2.detach()], -1)

        self.viz.line(
            X=iters, Y=metrics, env=self.name + '/lines',
            win=self.win_id['metrics'], update='append',
            opts=dict(xlabel='iter', ylabel='metrics',
                      title='Disentanglement metrics',
                      legend=['metric1', 'metric2'])
        )

    def set_mode(self, train=True):

        if train:
            self.encoderA.train()
            self.encoderB.train()
            self.decoderA.train()
            self.decoderB.train()
        else:
            self.encoderA.eval()
            self.encoderB.eval()
            self.decoderA.eval()
            self.decoderB.eval()
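
An equivalent, more compact set_mode (a sketch under the same attribute names) loops over the four networks and uses nn.Module.train's boolean argument, for which eval() is shorthand:

def set_mode(self, train=True):
    # nn.Module.train(mode) toggles dropout/batch-norm recursively; eval() is train(False)
    for net in (self.encoderA, self.encoderB, self.decoderA, self.decoderB):
        net.train(train)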

    ####
    def save_checkpoint(self, iteration):

        encoderA_path = os.path.join(
            self.ckpt_dir,
            'iter_%s_encoderA.pt' % iteration
        )
        encoderB_path = os.path.join(
            self.ckpt_dir,
            'iter_%s_encoderB.pt' % iteration
        )
        decoderA_path = os.path.join(
            self.ckpt_dir,
            'iter_%s_decoderA.pt' % iteration
        )
        decoderB_path = os.path.join(
            self.ckpt_dir,
            'iter_%s_decoderB.pt' % iteration
        )


        mkdirs(self.ckpt_dir)

        torch.save(self.encoderA, encoderA_path)
        torch.save(self.encoderB, encoderB_path)
        torch.save(self.decoderA, decoderA_path)
        torch.save(self.decoderB, decoderB_path)

    ####
    def load_checkpoint(self):

        encoderA_path = os.path.join(
            self.ckpt_dir,
            'iter_%s_encoderA.pt' % self.ckpt_load_iter
        )
        encoderB_path = os.path.join(
            self.ckpt_dir,
            'iter_%s_encoderB.pt' % self.ckpt_load_iter
        )
        decoderA_path = os.path.join(
            self.ckpt_dir,
            'iter_%s_decoderA.pt' % self.ckpt_load_iter
        )
        decoderB_path = os.path.join(
            self.ckpt_dir,
            'iter_%s_decoderB.pt' % self.ckpt_load_iter
        )

        if self.use_cuda:
            self.encoderA = torch.load(encoderA_path)
            self.encoderB = torch.load(encoderB_path)
            self.decoderA = torch.load(decoderA_path)
            self.decoderB = torch.load(decoderB_path)
        else:
            self.encoderA = torch.load(encoderA_path, map_location='cpu')
            self.encoderB = torch.load(encoderB_path, map_location='cpu')
            self.decoderA = torch.load(decoderA_path, map_location='cpu')
            self.decoderB = torch.load(decoderB_path, map_location='cpu')
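
Note that save_checkpoint/load_checkpoint above pickle whole module objects, which ties the checkpoint files to the exact class definitions and import paths. A common, more portable alternative (a sketch, not what this example does) is to save and restore state_dicts:

import torch

def save_state_dicts(path, **modules):
    # store only parameter tensors, keyed by module name
    torch.save({name: m.state_dict() for name, m in modules.items()}, path)

def load_state_dicts(path, device, **modules):
    ckpt = torch.load(path, map_location=device)
    for name, m in modules.items():
        m.load_state_dict(ckpt[name])

# usage sketch:
# save_state_dicts('iter_1000.pt', encoderA=encA, decoderA=decA)
# load_state_dicts('iter_1000.pt', 'cpu', encoderA=encA, decoderA=decA)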
Example #17
    def __init__(self, args):

        self.args = args

        self.name = '%s_lamkl_%s_zA_%s_zB_%s_zS_%s_HYPER_beta1_%s_beta2_%s_beta3_%s' % \
                    (
                        args.dataset, args.lamkl, args.zA_dim, args.zB_dim, args.zS_dim, args.beta1, args.beta2,
                        args.beta3)
        # to be appended by run_id

        self.use_cuda = args.cuda and torch.cuda.is_available()

        self.max_iter = int(args.max_iter)

        # do it every specified iters
        self.print_iter = args.print_iter
        self.ckpt_save_iter = args.ckpt_save_iter
        self.output_save_iter = args.output_save_iter

        # data info
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.nc = 3

        # self.N = self.latent_values.shape[0]
        self.eval_metrics_iter = args.eval_metrics_iter

        # networks and optimizers
        self.batch_size = args.batch_size
        self.zA_dim = args.zA_dim
        self.zB_dim = args.zB_dim
        self.zS_dim = args.zS_dim
        self.lamkl = args.lamkl
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE

        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D

        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.beta3 = args.beta3
        self.is_mss = args.is_mss

        # visdom setup
        self.viz_on = args.viz_on
        if self.viz_on:
            self.win_id = dict(
                recon='win_recon', kl='win_kl', capa='win_capa'
            )
            self.line_gather = DataGather(
                'iter', 'recon_both', 'recon_A', 'recon_B',
                'kl_A', 'kl_B', 
                'cont_capacity_loss_infA', 'disc_capacity_loss_infA', 'cont_capacity_loss_infB', 'disc_capacity_loss_infB'
            )

            # if self.eval_metrics:
            #     self.win_id['metrics'] = 'win_metrics'

            import visdom

            self.viz_port = args.viz_port  # port number, eg, 8097
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter

            self.viz_init()

        # create dirs: "records", "ckpts", "outputs" (if not exist)
        mkdirs("records");
        mkdirs("ckpts");
        mkdirs("outputs")

        # set run id
        if args.run_id < 0:  # create a new id
            k = 0
            rfname = os.path.join("records", self.name + '_run_0.txt')
            while os.path.exists(rfname):
                k += 1
                rfname = os.path.join("records", self.name + '_run_%d.txt' % k)
            self.run_id = k
        else:  # user-provided id
            self.run_id = args.run_id

        # finalize name
        self.name = self.name + '_run_' + str(self.run_id)

        # records (text file to store console outputs)
        self.record_file = 'records/%s.txt' % self.name

        # checkpoints
        self.ckpt_dir = os.path.join("ckpts", self.name)

        # outputs
        self.output_dir_recon = os.path.join("outputs", self.name + '_recon')
        # dir for reconstructed images
        self.output_dir_synth = os.path.join("outputs", self.name + '_synth')
        # dir for synthesized images
        self.output_dir_trvsl = os.path.join("outputs", self.name + '_trvsl')

        #### create a new model or load a previously saved model

        self.ckpt_load_iter = args.ckpt_load_iter
        self.n_pts = args.n_pts
        self.n_data = args.n_data

        if self.ckpt_load_iter == 0:  # create a new model
            self.encoderA = EncoderA(self.zA_dim, self.zS_dim)
            self.encoderB = EncoderA(self.zB_dim, self.zS_dim)  # modality B reuses A's architecture
            self.decoderA = DecoderA(self.zA_dim, self.zS_dim)
            self.decoderB = DecoderA(self.zB_dim, self.zS_dim)  # modality B reuses A's architecture

        else:  # load a previously saved model

            print('Loading saved models (iter: %d)...' % self.ckpt_load_iter)
            self.load_checkpoint()
            print('...done')

        if self.use_cuda:
            print('Models moved to GPU...')
            self.encoderA = self.encoderA.cuda()
            self.encoderB = self.encoderB.cuda()
            self.decoderA = self.decoderA.cuda()
            self.decoderB = self.decoderB.cuda()
            print('...done')

        # get VAE parameters
        vae_params = \
            list(self.encoderA.parameters()) + \
            list(self.encoderB.parameters()) + \
            list(self.decoderA.parameters()) + \
            list(self.decoderB.parameters())


        # create optimizers
        self.optim_vae = optim.Adam(
            vae_params,
            lr=self.lr_VAE,
            betas=[self.beta1_VAE, self.beta2_VAE]
        )
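
Concatenating parameter lists, as above, is the usual way to drive several modules with a single optimizer; itertools.chain gives the same result without materializing an intermediate list (a sketch with assumed module names and hyperparameters):

import itertools
from torch import optim

vae_params = itertools.chain(encoderA.parameters(), encoderB.parameters(),
                             decoderA.parameters(), decoderB.parameters())
optim_vae = optim.Adam(vae_params, lr=1e-4, betas=(0.9, 0.999))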
Example #18
    def __init__(self, args):

        self.args = args

        self.name = ( '%s_gamma_%s_zDim_%s' + \
            '_lrVAE_%s_lrD_%s_rseed_%s' ) % \
            ( args.dataset, args.gamma, args.z_dim,
              args.lr_VAE, args.lr_D, args.rseed )
        # to be appended by run_id

        self.use_cuda = args.cuda and torch.cuda.is_available()

        self.max_iter = int(args.max_iter)

        # do it every specified iters
        self.print_iter = args.print_iter
        self.ckpt_save_iter = args.ckpt_save_iter
        self.output_save_iter = args.output_save_iter

        # data info
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        if args.dataset.endswith('dsprites'):
            self.nc = 1
        elif args.dataset == '3dfaces':
            self.nc = 1
        else:
            self.nc = 3

        # groundtruth factor labels (only available for "dsprites")
        if self.dataset == 'dsprites':

            # latent factor = (color, shape, scale, orient, pos-x, pos-y)
            #   color = {1} (1)
            #   shape = {1=square, 2=oval, 3=heart} (3)
            #   scale = {0.5, 0.6, ..., 1.0} (6)
            #   orient = {2*pi*(k/39)}_{k=0}^39 (40)
            #   pos-x = {k/31}_{k=0}^31 (32)
            #   pos-y = {k/31}_{k=0}^31 (32)
            # (number of variations = 1*3*6*40*32*32 = 737280)

            latent_values = np.load(os.path.join(self.dset_dir,
                                                 'dsprites-dataset',
                                                 'latents_values.npy'),
                                    encoding='latin1')
            self.latent_values = latent_values[:, [1, 2, 3, 4, 5]]
            # latent values (actual values);(737280 x 5)
            latent_classes = np.load(os.path.join(self.dset_dir,
                                                  'dsprites-dataset',
                                                  'latents_classes.npy'),
                                     encoding='latin1')
            self.latent_classes = latent_classes[:, [1, 2, 3, 4, 5]]
            # classes ({0,1,...,K}-valued); (737280 x 5)
            self.latent_sizes = np.array([3, 6, 40, 32, 32])
            self.N = self.latent_values.shape[0]

            if args.eval_metrics:
                self.eval_metrics = True
                self.eval_metrics_iter = args.eval_metrics_iter

        # groundtruth factor labels
        elif self.dataset == 'oval_dsprites':

            latent_classes = np.load(os.path.join(self.dset_dir,
                                                  'dsprites-dataset',
                                                  'latents_classes.npy'),
                                     encoding='latin1')
            idx = np.where(latent_classes[:, 1] == 1)[0]  # "oval" shape only
            self.latent_classes = latent_classes[idx, :]
            self.latent_classes = self.latent_classes[:, [2, 3, 4, 5]]
            # classes ({0,1,...,K}-valued); (245760 x 4)
            latent_values = np.load(os.path.join(self.dset_dir,
                                                 'dsprites-dataset',
                                                 'latents_values.npy'),
                                    encoding='latin1')
            self.latent_values = latent_values[idx, :]
            self.latent_values = self.latent_values[:, [2, 3, 4, 5]]
            # latent values (actual values);(245760 x 4)

            self.latent_sizes = np.array([6, 40, 32, 32])
            self.N = self.latent_values.shape[0]

            if args.eval_metrics:
                self.eval_metrics = True
                self.eval_metrics_iter = args.eval_metrics_iter

        # groundtruth factor labels
        elif self.dataset == '3dfaces':

            # latent factor = (id, azimuth, elevation, lighting)
            #   id = {0,1,...,49} (50)
            #   azimuth = {-1.0,-0.9,...,0.9,1.0} (21)
            #   elevation = {-1.0,-0.8,...,0.8,1.0} (11)
            #   lighting = {-1.0,-0.8,...,0.8,1.0} (11)
            # (number of variations = 50*21*11*11 = 127050)

            latent_classes, latent_values = np.load(
                os.path.join(self.dset_dir,
                             '3d_faces/rtqichen/gt_factor_labels.npy'))
            self.latent_values = latent_values
            # latent values (actual values);(127050 x 4)
            self.latent_classes = latent_classes
            # classes ({0,1,...,K}-valued); (127050 x 4)
            self.latent_sizes = np.array([50, 21, 11, 11])
            self.N = self.latent_values.shape[0]

            if args.eval_metrics:
                self.eval_metrics = True
                self.eval_metrics_iter = args.eval_metrics_iter

        elif self.dataset == 'celeba':

            self.N = 202599
            self.eval_metrics = False

        elif self.dataset == 'edinburgh_teapots':

            # latent factor = (azimuth, elevation, R, G, B)
            #   azimuth = [0, 2*pi]
            #   elevation = [0, pi/2]
            #   R, G, B = [0,1]
            #
            #   "latent_values" = original (real) factor values
            #   "latent_classes" = equal binning into K=10 classes
            #
            # (refer to "data/edinburgh_teapots/my_make_split_data.py")

            K = 10
            val_ranges = [2 * np.pi, np.pi / 2, 1, 1, 1]
            bins = []
            for j in range(5):
                bins.append(np.linspace(0, val_ranges[j], K + 1))

            latent_values = np.load(
                os.path.join(self.dset_dir, 'edinburgh_teapots',
                             'gtfs_tr.npz'))['data']
            latent_values = np.concatenate(
                (latent_values,
                 np.load(
                     os.path.join(self.dset_dir, 'edinburgh_teapots',
                                  'gtfs_va.npz'))['data']),
                axis=0)
            latent_values = np.concatenate(
                (latent_values,
                 np.load(
                     os.path.join(self.dset_dir, 'edinburgh_teapots',
                                  'gtfs_te.npz'))['data']),
                axis=0)
            self.latent_values = latent_values

            latent_classes = np.zeros(latent_values.shape)
            for j in range(5):
                latent_classes[:, j] = np.digitize(latent_values[:, j],
                                                   bins[j])
            self.latent_classes = latent_classes - 1  # {0,...,K-1}-valued

            self.latent_sizes = K * np.ones(5, 'int64')
            self.N = self.latent_values.shape[0]

            if args.eval_metrics:
                self.eval_metrics = True
                self.eval_metrics_iter = args.eval_metrics_iter

        # networks and optimizers
        self.batch_size = args.batch_size
        self.z_dim = args.z_dim
        self.gamma = args.gamma
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE
        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D

        # visdom setup
        self.viz_on = args.viz_on
        if self.viz_on:

            self.win_id = dict(DZ='win_DZ',
                               recon='win_recon',
                               kl='win_kl',
                               kl_alpha='win_kl_alpha')
            self.line_gather = DataGather('iter', 'p_DZ', 'p_DZ_perm', 'recon',
                                          'kl', 'kl_alpha')

            if self.eval_metrics:
                self.win_id['metrics'] = 'win_metrics'

            import visdom

            self.viz_port = args.viz_port  # port number, eg, 8097
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter

            self.viz_init()

        # create dirs: "records", "ckpts", "outputs" (if not exist)
        mkdirs("records")
        mkdirs("ckpts")
        mkdirs("outputs")

        # set run id
        if args.run_id < 0:  # create a new id
            k = 0
            rfname = os.path.join("records", self.name + '_run_0.txt')
            while os.path.exists(rfname):
                k += 1
                rfname = os.path.join("records", self.name + '_run_%d.txt' % k)
            self.run_id = k
        else:  # user-provided id
            self.run_id = args.run_id

        # finalize name
        self.name = self.name + '_run_' + str(self.run_id)

        # records (text file to store console outputs)
        self.record_file = 'records/%s.txt' % self.name

        # checkpoints
        self.ckpt_dir = os.path.join("ckpts", self.name)

        # outputs
        self.output_dir_recon = os.path.join("outputs", self.name + '_recon')
        # dir for reconstructed images
        self.output_dir_synth = os.path.join("outputs", self.name + '_synth')
        # dir for synthesized images
        self.output_dir_trvsl = os.path.join("outputs", self.name + '_trvsl')
        # dir for latent traversed images

        #### create a new model or load a previously saved model

        self.ckpt_load_iter = args.ckpt_load_iter

        if self.ckpt_load_iter == 0:  # create a new model

            # create a vae model
            if args.dataset.endswith('dsprites'):
                self.encoder = Encoder1(self.z_dim)
                self.decoder = Decoder1(self.z_dim)
            elif args.dataset == '3dfaces':
                self.encoder = Encoder3(self.z_dim)
                self.decoder = Decoder3(self.z_dim)
            elif args.dataset == 'celeba':
                self.encoder = Encoder4(self.z_dim)
                self.decoder = Decoder4(self.z_dim)
            elif args.dataset.endswith('teapots'):
                # self.encoder = Encoder4(self.z_dim)
                # self.decoder = Decoder4(self.z_dim)
                self.encoder = Encoder_ResNet(self.z_dim)
                self.decoder = Decoder_ResNet(self.z_dim)
            else:
                pass  #self.VAE = FactorVAE2(self.z_dim)

            # create a prior alpha model
            self.prior_alpha = PriorAlphaParams(self.z_dim)

            # create a posterior alpha model
            self.post_alpha = PostAlphaParams(self.z_dim)

            # create a discriminator model
            self.D = Discriminator(self.z_dim)

        else:  # load a previously saved model

            print('Loading saved models (iter: %d)...' % self.ckpt_load_iter)
            self.load_checkpoint()
            print('...done')

        if self.use_cuda:
            print('Models moved to GPU...')
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            self.prior_alpha = self.prior_alpha.cuda()
            self.post_alpha = self.post_alpha.cuda()
            self.D = self.D.cuda()
            print('...done')

        # get VAE parameters
        vae_params = list(self.encoder.parameters()) + \
            list(self.decoder.parameters()) + \
            list(self.prior_alpha.parameters()) + \
            list(self.post_alpha.parameters())

        # get discriminator parameters
        dis_params = list(self.D.parameters())

        # create optimizers
        self.optim_vae = optim.Adam(vae_params,
                                    lr=self.lr_VAE,
                                    betas=[self.beta1_VAE, self.beta2_VAE])
        self.optim_dis = optim.Adam(dis_params,
                                    lr=self.lr_D,
                                    betas=[self.beta1_D, self.beta2_D])
Example #19
    def __init__(self, args):

        self.args = args

        self.name = ( '%s_etaS_%s_etaH_%s_lamklMin_%s_lamklMax_%s' + \
                      '_gamma_%s_zDim_%s' ) % \
            ( args.dataset, args.etaS, args.etaH, \
              args.lamklMin, args.lamklMax, args.gamma, args.z_dim )
        # to be appended by run_id

        self.use_cuda = args.cuda and torch.cuda.is_available()

        self.max_iter = int(args.max_iter)

        # do it every specified iters
        self.print_iter = args.print_iter
        self.ckpt_save_iter = args.ckpt_save_iter
        self.output_save_iter = args.output_save_iter

        # data info
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        if args.dataset.endswith('dsprites'):
            self.nc = 1
        else:
            self.nc = 3

        # groundtruth factor labels (only available for "dsprites")
        if self.dataset == 'dsprites':

            # latent factor = (color, shape, scale, orient, pos-x, pos-y)
            #   color = {1} (1)
            #   shape = {1=square, 2=oval, 3=heart} (3)
            #   scale = {0.5, 0.6, ..., 1.0} (6)
            #   orient = {2*pi*(k/39)}_{k=0}^39 (40)
            #   pos-x = {k/31}_{k=0}^31 (32)
            #   pos-y = {k/31}_{k=0}^31 (32)
            # (number of variations = 1*3*6*40*32*32 = 737280)

            latent_values = np.load(os.path.join(self.dset_dir,
                                                 'dsprites-dataset',
                                                 'latents_values.npy'),
                                    encoding='latin1')
            self.latent_values = latent_values[:, [1, 2, 3, 4, 5]]
            # latent values (actual values);(737280 x 5)
            latent_classes = np.load(os.path.join(self.dset_dir,
                                                  'dsprites-dataset',
                                                  'latents_classes.npy'),
                                     encoding='latin1')
            self.latent_classes = latent_classes[:, [1, 2, 3, 4, 5]]
            # classes ({0,1,...,K}-valued); (737280 x 5)
            self.latent_sizes = np.array([3, 6, 40, 32, 32])
            self.N = self.latent_values.shape[0]

            if args.eval_metrics:
                self.eval_metrics = True
                self.eval_metrics_iter = args.eval_metrics_iter

        # groundtruth factor labels
        elif self.dataset == 'oval_dsprites':

            latent_classes = np.load(os.path.join(self.dset_dir,
                                                  'dsprites-dataset',
                                                  'latents_classes.npy'),
                                     encoding='latin1')
            idx = np.where(latent_classes[:, 1] == 1)[0]  # "oval" shape only
            self.latent_classes = latent_classes[idx, :]
            self.latent_classes = self.latent_classes[:, [2, 3, 4, 5]]
            # classes ({0,1,...,K}-valued); (245760 x 4)
            latent_values = np.load(os.path.join(self.dset_dir,
                                                 'dsprites-dataset',
                                                 'latents_values.npy'),
                                    encoding='latin1')
            self.latent_values = latent_values[idx, :]
            self.latent_values = self.latent_values[:, [2, 3, 4, 5]]
            # latent values (actual values);(245760 x 4)

            self.latent_sizes = np.array([6, 40, 32, 32])
            self.N = self.latent_values.shape[0]

            if args.eval_metrics:
                self.eval_metrics = True
                self.eval_metrics_iter = args.eval_metrics_iter

        # networks and optimizers
        self.batch_size = args.batch_size
        self.z_dim = args.z_dim
        self.etaS = args.etaS
        self.etaH = args.etaH
        self.lamklMin = args.lamklMin
        self.lamklMax = args.lamklMax
        self.gamma = args.gamma
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE
        #        self.lr_rvec = args.lr_rvec
        #        self.beta1_rvec = args.beta1_rvec
        #        self.beta2_rvec = args.beta2_rvec
        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D

        # visdom setup
        self.viz_on = args.viz_on
        if self.viz_on:

            self.win_id = dict(DZ='win_DZ',
                               recon='win_recon',
                               kl='win_kl',
                               rvS='win_rvS',
                               rvH='win_rvH')
            self.line_gather = DataGather('iter', 'p_DZ', 'p_DZ_perm', 'recon',
                                          'kl', 'rvS', 'rvH')

            if self.eval_metrics:
                self.win_id['metrics'] = 'win_metrics'

            import visdom

            self.viz_port = args.viz_port  # port number, eg, 8097
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter

            self.viz_init()

        # create dirs: "records", "ckpts", "outputs" (if not exist)
        mkdirs("records")
        mkdirs("ckpts")
        mkdirs("outputs")

        # set run id
        if args.run_id < 0:  # create a new id
            k = 0
            rfname = os.path.join("records", self.name + '_run_0.txt')
            while os.path.exists(rfname):
                k += 1
                rfname = os.path.join("records", self.name + '_run_%d.txt' % k)
            self.run_id = k
        else:  # user-provided id
            self.run_id = args.run_id

        # finalize name
        self.name = self.name + '_run_' + str(self.run_id)

        # records (text file to store console outputs)
        self.record_file = 'records/%s.txt' % self.name

        # checkpoints
        self.ckpt_dir = os.path.join("ckpts", self.name)

        # outputs
        self.output_dir_recon = os.path.join("outputs", self.name + '_recon')
        # dir for reconstructed images
        self.output_dir_synth = os.path.join("outputs", self.name + '_synth')
        # dir for synthesized images
        self.output_dir_trvsl = os.path.join("outputs", self.name + '_trvsl')
        # dir for latent traversed images

        #### create a new model or load a previously saved model

        self.ckpt_load_iter = args.ckpt_load_iter

        if self.ckpt_load_iter == 0:  # create a new model

            # create a vae model
            if args.dataset.endswith('dsprites'):
                self.encoder = Encoder1(self.z_dim)
                self.decoder = Decoder1(self.z_dim)
            else:
                pass  #self.VAE = FactorVAE2(self.z_dim)

            # create a relevance vector
            self.rvec = RelevanceVector(self.z_dim)

            # create a discriminator model
            self.D = Discriminator(self.z_dim)

        else:  # load a previously saved model

            print('Loading saved models (iter: %d)...' % self.ckpt_load_iter)
            self.load_checkpoint()
            print('...done')

        if self.use_cuda:
            print('Models moved to GPU...')
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            self.rvec = self.rvec.cuda()
            self.D = self.D.cuda()
            print('...done')

        # get VAE parameters (and rv parameters)
        vae_params = list(self.encoder.parameters()) + \
          list(self.decoder.parameters()) + list(self.rvec.parameters())

        # get discriminator parameters
        dis_params = list(self.D.parameters())

        # create optimizers
        self.optim_vae = optim.Adam(vae_params,
                                    lr=self.lr_VAE,
                                    betas=[self.beta1_VAE, self.beta2_VAE])
        self.optim_dis = optim.Adam(dis_params,
                                    lr=self.lr_D,
                                    betas=[self.beta1_D, self.beta2_D])
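
Both dsprites branches above store latent_sizes so that ground-truth factor classes can be mapped to flat dataset indices when evaluating disentanglement metrics. A sketch of that standard mixed-radix index arithmetic (the conversion itself is not shown in this example):

import numpy as np

latent_sizes = np.array([3, 6, 40, 32, 32])  # per-factor class counts (dsprites)
# bases for mixed-radix conversion: index = sum(class_i * base_i)
bases = np.concatenate((latent_sizes[::-1].cumprod()[::-1][1:], [1]))

classes = np.array([2, 5, 0, 31, 31])        # one sample's factor classes
flat_index = int(classes @ bases)            # row index into the image array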
Example #20
    def __init__(self, args):
        # Misc
        use_cuda = args.cuda and torch.cuda.is_available()
        self.device = 'cuda' if use_cuda else 'cpu'
        self.name = args.name
        self.max_iter = int(args.max_iter)
        self.print_iter = args.print_iter
        self.global_iter = 0
        self.global_iter_cls = 0
        self.pbar = tqdm(total=self.max_iter)
        self.pbar_cls = tqdm(total=self.max_iter)

        # Data
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.eval_batch_size = args.eval_batch_size
        self.data_loader = return_data(args, 0)
        self.data_loader_eval = return_data(args, 2)

        # Networks & Optimizers
        self.z_dim = args.z_dim
        self.gamma = args.gamma
        self.beta = args.beta

        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE

        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D
        self.alpha = args.alpha
        self.grl = args.grl

        self.lr_cls = args.lr_cls
        self.beta1_cls = args.beta1_D
        self.beta2_cls = args.beta2_D

        if args.dataset == 'dsprites':
            self.VAE = FactorVAE1(self.z_dim).to(self.device)
            self.nc = 1
        else:
            self.VAE = FactorVAE2(self.z_dim).to(self.device)
            self.nc = 3
        self.optim_VAE = optim.Adam(self.VAE.parameters(),
                                    lr=self.lr_VAE,
                                    betas=(self.beta1_VAE, self.beta2_VAE))

        self.pacls = classifier(30, 2).cuda()
        self.revcls = classifier(30, 2).cuda()
        self.tcls = classifier(30, 2).cuda()
        self.trevcls = classifier(30, 2).cuda()

        self.targetcls = classifier(59, 2).cuda()
        self.pa_target = classifier(30, 2).cuda()
        self.target_pa = paclassifier(1, 1).cuda()
        self.pa_pa = classifier(30, 2).cuda()

        self.D = Discriminator(self.z_dim).to(self.device)
        self.optim_D = optim.Adam(self.D.parameters(),
                                  lr=self.lr_D,
                                  betas=(self.beta1_D, self.beta2_D))

        self.optim_pacls = optim.Adam(self.pacls.parameters(), lr=self.lr_D)

        self.optim_revcls = optim.Adam(self.revcls.parameters(), lr=self.lr_D)

        self.optim_tcls = optim.Adam(self.tcls.parameters(), lr=self.lr_D)
        self.optim_trevcls = optim.Adam(self.trevcls.parameters(),
                                        lr=self.lr_D)

        self.optim_cls = optim.Adam(self.targetcls.parameters(),
                                    lr=self.lr_cls)
        self.optim_pa_target = optim.Adam(self.pa_target.parameters(),
                                          lr=self.lr_cls)
        self.optim_target_pa = optim.Adam(self.target_pa.parameters(),
                                          lr=self.lr_cls)
        self.optim_pa_pa = optim.Adam(self.pa_pa.parameters(), lr=self.lr_cls)

        self.nets = [
            self.VAE, self.D, self.pacls, self.targetcls, self.revcls,
            self.pa_target, self.tcls, self.trevcls
        ]

        # Visdom
        self.viz_on = args.viz_on
        self.win_id = dict(D_z='win_D_z',
                           recon='win_recon',
                           kld='win_kld',
                           acc='win_acc')
        self.line_gather = DataGather('iter', 'soft_D_z', 'soft_D_z_pperm',
                                      'recon', 'kld', 'acc')
        self.image_gather = DataGather('true', 'recon')
        if self.viz_on:
            self.viz_port = args.viz_port
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter
            self.viz_ra_iter = args.viz_ra_iter
            self.viz_ta_iter = args.viz_ta_iter
            if not self.viz.win_exists(env=self.name + '/lines',
                                       win=self.win_id['D_z']):
                self.viz_init()

        # Checkpoint
        self.ckpt_dir = os.path.join(args.ckpt_dir, args.name)
        self.ckpt_save_iter = args.ckpt_save_iter
        mkdirs(self.ckpt_dir + "/cls")
        mkdirs(self.ckpt_dir + "/vae")

        if args.ckpt_load:

            self.load_checkpoint(args.ckpt_load)

        # Output(latent traverse GIF)
        self.output_dir = os.path.join(args.output_dir, args.name)
        self.output_save = args.output_save
        mkdirs(self.output_dir)
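
DataGather, used for line plots throughout these examples, is a small dict-of-lists accumulator. Its implementation is not shown here; a minimal sketch consistent with the insert()/flush()/.data usage above (the actual class may differ in detail):

class DataGather(object):
    # collect per-key logs; .data maps each key to a list of inserted values
    def __init__(self, *keys):
        self.keys = keys
        self.data = {k: [] for k in keys}

    def insert(self, **kwargs):
        for k, v in kwargs.items():
            self.data[k].append(v)

    def flush(self):
        self.data = {k: [] for k in self.keys}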
Example #21
class Solver(object):

    ####
    def __init__(self, args):
        self.args = args

        self.name = '%s_map_pred_len_%s_zS_%s_dr_mlp_%s_dr_rnn_%s_dr_map_%s_enc_h_dim_%s_dec_h_dim_%s_mlp_dim_%s_emb_dim_%s_lr_%s_klw_%s_map_%s' % \
                    (args.dataset_name, args.pred_len, args.zS_dim, args.dropout_mlp, args.dropout_rnn, args.dropout_map, args.encoder_h_dim,
                     args.decoder_h_dim, args.mlp_dim, args.emb_dim, args.lr_VAE, args.kl_weight, args.map_size)

        # to be appended by run_id

        # self.use_cuda = args.cuda and torch.cuda.is_available()
        self.device = args.device
        self.temp = 0.66
        self.eps = 1e-9
        self.kl_weight = args.kl_weight

        self.max_iter = int(args.max_iter)

        # do it every specified iters
        self.print_iter = args.print_iter
        self.ckpt_save_iter = args.ckpt_save_iter
        self.output_save_iter = args.output_save_iter

        # data info
        self.dataset_dir = args.dataset_dir
        self.dataset_name = args.dataset_name

        # self.N = self.latent_values.shape[0]
        # self.eval_metrics_iter = args.eval_metrics_iter

        # networks and optimizers
        self.batch_size = args.batch_size
        self.zS_dim = args.zS_dim
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE
        print(args.desc)

        # visdom setup
        self.viz_on = args.viz_on
        if self.viz_on:
            self.win_id = dict(recon='win_recon',
                               loss_kl='win_loss_kl',
                               loss_recon='win_loss_recon',
                               total_loss='win_total_loss',
                               ade_min='win_ade_min',
                               fde_min='win_fde_min',
                               ade_avg='win_ade_avg',
                               fde_avg='win_fde_avg',
                               ade_std='win_ade_std',
                               fde_std='win_fde_std',
                               test_loss_recon='win_test_loss_recon',
                               test_loss_kl='win_test_loss_kl',
                               test_total_loss='win_test_total_loss')
            self.line_gather = DataGather('iter', 'loss_recon', 'loss_kl',
                                          'total_loss', 'ade_min', 'fde_min',
                                          'ade_avg', 'fde_avg', 'ade_std',
                                          'fde_std', 'test_loss_recon',
                                          'test_loss_kl', 'test_total_loss')

            import visdom

            self.viz_port = args.viz_port  # port number, eg, 8097
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter

            self.viz_init()

        # create dirs: "records", "ckpts", "outputs" (if not exist)
        mkdirs("records")
        mkdirs("ckpts")
        mkdirs("outputs")

        # set run id
        if args.run_id < 0:  # create a new id
            k = 0
            rfname = os.path.join("records", self.name + '_run_0.txt')
            while os.path.exists(rfname):
                k += 1
                rfname = os.path.join("records", self.name + '_run_%d.txt' % k)
            self.run_id = k
        else:  # user-provided id
            self.run_id = args.run_id

        # finalize name
        self.name = self.name + '_run_' + str(self.run_id)

        # records (text file to store console outputs)
        self.record_file = 'records/%s.txt' % self.name

        # checkpoints
        self.ckpt_dir = os.path.join("ckpts", self.name)

        # outputs
        self.output_dir_recon = os.path.join("outputs", self.name + '_recon')
        # dir for reconstructed images
        self.output_dir_synth = os.path.join("outputs", self.name + '_synth')
        # dir for synthesized images
        self.output_dir_trvsl = os.path.join("outputs", self.name + '_trvsl')

        #### create a new model or load a previously saved model

        self.ckpt_load_iter = args.ckpt_load_iter

        self.obs_len = args.obs_len
        self.pred_len = args.pred_len
        self.num_layers = args.num_layers
        self.decoder_h_dim = args.decoder_h_dim

        if self.ckpt_load_iter == 0 or args.dataset_name == 'all':  # create a new model
            self.encoderMx = Encoder(args.zS_dim,
                                     enc_h_dim=args.encoder_h_dim,
                                     mlp_dim=args.mlp_dim,
                                     emb_dim=args.emb_dim,
                                     map_size=args.map_size,
                                     batch_norm=args.batch_norm,
                                     num_layers=args.num_layers,
                                     dropout_mlp=args.dropout_mlp,
                                     dropout_rnn=args.dropout_rnn,
                                     dropout_map=args.dropout_map).to(
                                         self.device)
            self.encoderMy = EncoderY(args.zS_dim,
                                      enc_h_dim=args.encoder_h_dim,
                                      mlp_dim=args.mlp_dim,
                                      emb_dim=args.emb_dim,
                                      map_size=args.map_size,
                                      num_layers=args.num_layers,
                                      dropout_rnn=args.dropout_rnn,
                                      dropout_map=args.dropout_map,
                                      device=self.device).to(self.device)
            self.decoderMy = Decoder(args.pred_len,
                                     dec_h_dim=self.decoder_h_dim,
                                     enc_h_dim=args.encoder_h_dim,
                                     mlp_dim=args.mlp_dim,
                                     z_dim=args.zS_dim,
                                     num_layers=args.num_layers,
                                     device=args.device,
                                     dropout_rnn=args.dropout_rnn).to(
                                         self.device)

        else:  # load a previously saved model
            print('Loading saved models (iter: %d)...' % self.ckpt_load_iter)
            self.load_checkpoint()
            print('...done')

        # get VAE parameters
        vae_params = \
            list(self.encoderMx.parameters()) + \
            list(self.encoderMy.parameters()) + \
            list(self.decoderMy.parameters())

        # create optimizers
        self.optim_vae = optim.Adam(vae_params,
                                    lr=self.lr_VAE,
                                    betas=[self.beta1_VAE, self.beta2_VAE])

        ######## map
        # self.map = imageio.imread('D:\crowd\ewap_dataset\seq_' + self.dataset_name + '/map.png')
        # h = np.loadtxt('D:\crowd\ewap_dataset\seq_' + self.dataset_name + '\H.txt')
        # self.inv_h_t = np.linalg.pinv(np.transpose(h))
        self.map_size = args.map_size
        ######################################
        # prepare dataloader (iterable)
        print('Start loading data...')
        train_path = os.path.join(self.dataset_dir, self.dataset_name, 'train')
        val_path = os.path.join(self.dataset_dir, self.dataset_name, 'test')

        # long_dtype, float_dtype = get_dtypes(args)

        print("Initializing train dataset")
        if self.dataset_name == 'eth':
            self.args.pixel_distance = 5  # for hotel
        else:
            self.args.pixel_distance = 3  # for eth
        _, self.train_loader = data_loader(self.args, train_path)
        print("Initializing val dataset")
        if self.dataset_name == 'eth':
            self.args.pixel_distance = 3
        else:
            self.args.pixel_distance = 5
        _, self.val_loader = data_loader(self.args, val_path)
        # self.val_loader = self.train_loader

        print('There are {} iterations per epoch'.format(
            len(self.train_loader.dataset) / args.batch_size))
        print('...done')

    ####
    def train(self):
        self.set_mode(train=True)

        data_loader = self.train_loader
        self.N = len(data_loader.dataset)

        # iterators from dataloader
        iterator = iter(data_loader)

        iter_per_epoch = len(iterator)

        start_iter = self.ckpt_load_iter + 1
        epoch = int(start_iter / iter_per_epoch)

        for iteration in range(start_iter, self.max_iter + 1):

            # reset data iterators for each epoch
            if iteration % iter_per_epoch == 0:
                print('==== epoch %d done ====' % epoch)
                epoch += 1
                iterator = iter(data_loader)

            # ============================================
            #          TRAIN THE VAE (ENC & DEC)
            # ============================================

            # sample a mini-batch
            (obs_traj, fut_traj, seq_start_end, obs_frames, fut_frames,
             past_obst, fut_obst) = next(iterator)
            batch = fut_traj.size(1)

            (last_past_map_feat, encX_h_feat,
             logitX) = self.encoderMx(past_obst, seq_start_end, train=True)

            (fut_map_emb, encY_h_feat, logitY) \
                = self.encoderMy(past_obst[-1], fut_obst, seq_start_end, encX_h_feat, train=True)

            p_dist = discrete(logits=logitX)
            q_dist = discrete(logits=logitY)
            relaxed_q_dist = concrete(logits=logitY, temperature=self.temp)

            fut_map_mean = self.decoderMy(last_past_map_feat, encX_h_feat,
                                          relaxed_q_dist.rsample(),
                                          fut_map_emb)
            fut_map_mean = fut_map_mean.view(fut_obst.shape[0],
                                             fut_obst.shape[1], -1,
                                             fut_map_mean.shape[2],
                                             fut_map_mean.shape[3])
            loglikelihood = (torch.log(fut_map_mean + self.eps) * fut_obst +
                             torch.log(1 - fut_map_mean + self.eps) *
                             (1 - fut_obst)).sum().div(batch)

            loss_kl = kl_divergence(q_dist, p_dist).sum().div(batch)
            loss_kl = torch.clamp(loss_kl, min=0.07)
            # print('log_likelihood:', loglikelihood.item(), ' kl:', loss_kl.item())

            elbo = loglikelihood - self.kl_weight * loss_kl
            vae_loss = -elbo

            self.optim_vae.zero_grad()
            vae_loss.backward()
            self.optim_vae.step()

            # save model parameters
            if iteration % self.ckpt_save_iter == 0:
                self.save_checkpoint(iteration)

            # (visdom) insert current line stats
            if self.viz_on and (iteration % self.viz_ll_iter == 0):
                test_loss_recon, test_loss_kl, test_vae_loss = self.test()
                self.line_gather.insert(
                    iter=iteration,
                    loss_recon=-loglikelihood.item(),
                    loss_kl=loss_kl.item(),
                    total_loss=vae_loss.item(),
                    test_loss_recon=-test_loss_recon.item(),
                    test_loss_kl=test_loss_kl.item(),
                    test_total_loss=test_vae_loss.item(),
                )
                prn_str = ('[iter_%d (epoch_%d)] vae_loss: %.3f ' + \
                           '(recon: %.3f, kl: %.3f)\n'
                           ) % \
                          (iteration, epoch,
                           vae_loss.item(), -loglikelihood.item(), loss_kl.item()
                           )

                print(prn_str)

            # (visdom) visualize line stats (then flush out)
            if self.viz_on and (iteration % self.viz_la_iter == 0):
                self.visualize_line()
                self.line_gather.flush()
            if (iteration % self.output_save_iter == 0):
                self.recon(self.val_loader)
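
Two details of the loss above are worth noting. The clamp on loss_kl acts like a free-bits-style floor that keeps the KL term from collapsing below 0.07, and the hand-written reconstruction term is a Bernoulli log-likelihood, i.e. the negative of binary cross-entropy. A quick self-contained check of that equivalence with toy tensors (the eps guard mirrors self.eps above):

import torch
import torch.nn.functional as F

eps = 1e-9
p = torch.rand(4, 1, 8, 8).clamp(1e-3, 1 - 1e-3)  # predicted occupancy maps
x = torch.randint(0, 2, (4, 1, 8, 8)).float()     # binary targets

manual = (torch.log(p + eps) * x + torch.log(1 - p + eps) * (1 - x)).sum()
bce = F.binary_cross_entropy(p, x, reduction='sum')
print(torch.allclose(manual, -bce, atol=1e-3))    # True: log-likelihood = -BCE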

    def test(self):
        self.set_mode(train=False)
        all_loglikelihood = 0
        all_loss_kl = 0
        all_vae_loss = 0
        b = 0
        with torch.no_grad():
            for abatch in self.val_loader:
                b += 1

                # sample a mini-batch
                (obs_traj, fut_traj, seq_start_end, obs_frames, fut_frames,
                 past_obst, fut_obst) = abatch
                batch = fut_traj.size(1)

                (last_past_map_feat, encX_h_feat,
                 logitX) = self.encoderMx(past_obst, seq_start_end)

                (_, _, logitY) = self.encoderMy(
                    past_obst[-1], fut_obst, seq_start_end, encX_h_feat)

                p_dist = discrete(logits=logitX)
                q_dist = discrete(logits=logitY)
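                # at test time the decoder is driven by a sample from the
                # relaxed *prior*, since the future obstacle maps would not
                # be available; fut_map_emb is likewise omitted below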
                relaxed_p_dist = concrete(logits=logitX, temperature=self.temp)

                fut_map_mean = self.decoderMy(last_past_map_feat, encX_h_feat,
                                              relaxed_p_dist.rsample())
                fut_map_mean = fut_map_mean.view(fut_obst.shape[0],
                                                 fut_obst.shape[1], -1,
                                                 fut_map_mean.shape[2],
                                                 fut_map_mean.shape[3])
                loglikelihood = (
                    torch.log(fut_map_mean + self.eps) * fut_obst +
                    torch.log(1 - fut_map_mean + self.eps) *
                    (1 - fut_obst)).sum().div(batch)

                loss_kl = kl_divergence(q_dist, p_dist).sum().div(batch)
                loss_kl = torch.clamp(loss_kl, min=0.07)
                elbo = loglikelihood - self.kl_weight * loss_kl
                vae_loss = -elbo
                all_loglikelihood += loglikelihood
                all_loss_kl += loss_kl
                all_vae_loss += vae_loss
        self.set_mode(train=True)
        # average the per-batch means over the number of batches
        return (all_loglikelihood.div(b), all_loss_kl.div(b),
                all_vae_loss.div(b))

    def recon(self, data_loader):
        self.set_mode(train=False)
        with torch.no_grad():
            fixed_idxs = range(5)

            from data.obstacles import seq_collate
            data = [data_loader.dataset[idx] for idx in fixed_idxs]

            (obs_traj, fut_traj, seq_start_end, obs_frames, fut_frames,
             past_obst, fut_obst) = seq_collate(data)

            (last_past_map_feat, encX_h_feat,
             logitX) = self.encoderMx(past_obst, seq_start_end)
            (fut_map_emb, _, logitY) = self.encoderMy(past_obst[-1], fut_obst,
                                                      seq_start_end,
                                                      encX_h_feat)

            relaxed_p_dist = concrete(logits=logitX, temperature=self.temp)
            relaxed_q_dist = concrete(logits=logitY, temperature=self.temp)

            prior_fut_map_mean = self.decoderMy(last_past_map_feat,
                                                encX_h_feat,
                                                relaxed_p_dist.rsample())

            posterior_fut_map_mean = self.decoderMy(
                last_past_map_feat,
                encX_h_feat,
                relaxed_q_dist.rsample(),
                fut_map_emb,
            )

            prior_fut_map_mean = prior_fut_map_mean.view(
                fut_obst.shape[0], fut_obst.shape[1], -1,
                prior_fut_map_mean.shape[2], prior_fut_map_mean.shape[3])
            posterior_fut_map_mean = posterior_fut_map_mean.view(
                fut_obst.shape[0], fut_obst.shape[1], -1,
                posterior_fut_map_mean.shape[2],
                posterior_fut_map_mean.shape[3])

            out_dir = os.path.join('./output', self.name,
                                   str(self.ckpt_load_iter))
            mkdirs(out_dir)
            for i in range(fut_obst.shape[1]):
                save_image(prior_fut_map_mean[:, i],
                           os.path.join(out_dir,
                                        'prior_recon_img%d.png' % i),
                           nrow=self.pred_len,
                           pad_value=1)
                save_image(posterior_fut_map_mean[:, i],
                           os.path.join(out_dir,
                                        'posterior_recon_img%d.png' % i),
                           nrow=self.pred_len,
                           pad_value=1)
                save_image(fut_obst[:, i],
                           os.path.join(out_dir, 'gt_img%d.png' % i),
                           nrow=self.pred_len,
                           pad_value=1)

        self.set_mode(train=True)

    ####
    def viz_init(self):
        self.viz.close(env=self.name + '/lines', win=self.win_id['loss_recon'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['loss_kl'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['total_loss'])
        self.viz.close(env=self.name + '/lines',
                       win=self.win_id['test_loss_recon'])
        self.viz.close(env=self.name + '/lines',
                       win=self.win_id['test_loss_kl'])
        self.viz.close(env=self.name + '/lines',
                       win=self.win_id['test_total_loss'])

    ####
    def visualize_line(self):

        # prepare data to plot
        data = self.line_gather.data
        iters = torch.Tensor(data['iter'])
        loss_recon = torch.Tensor(data['loss_recon'])
        loss_kl = torch.Tensor(data['loss_kl'])
        total_loss = torch.Tensor(data['total_loss'])
        test_loss_recon = torch.Tensor(data['test_loss_recon'])
        test_loss_kl = torch.Tensor(data['test_loss_kl'])
        test_total_loss = torch.Tensor(data['test_total_loss'])

        self.viz.line(X=iters,
                      Y=loss_recon,
                      env=self.name + '/lines',
                      win=self.win_id['loss_recon'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='-loglikelihood',
                                title='Recon. loss of predicted future traj'))

        self.viz.line(
            X=iters,
            Y=loss_kl,
            env=self.name + '/lines',
            win=self.win_id['loss_kl'],
            update='append',
            opts=dict(xlabel='iter',
                      ylabel='kl divergence',
                      title='KL div. btw posterior and c. prior'),
        )

        self.viz.line(
            X=iters,
            Y=total_loss,
            env=self.name + '/lines',
            win=self.win_id['total_loss'],
            update='append',
            opts=dict(xlabel='iter', ylabel='vae loss', title='VAE loss'),
        )

        self.viz.line(X=iters,
                      Y=test_loss_recon,
                      env=self.name + '/lines',
                      win=self.win_id['test_loss_recon'],
                      update='append',
                      opts=dict(
                          xlabel='iter',
                          ylabel='-loglikelihood',
                          title='Test Recon. loss of predicted future traj'))

        self.viz.line(
            X=iters,
            Y=test_loss_kl,
            env=self.name + '/lines',
            win=self.win_id['test_loss_kl'],
            update='append',
            opts=dict(xlabel='iter',
                      ylabel='kl divergence',
                      title='Test KL div. btw posterior and c. prior'),
        )

        self.viz.line(
            X=iters,
            Y=test_total_loss,
            env=self.name + '/lines',
            win=self.win_id['test_total_loss'],
            update='append',
            opts=dict(xlabel='iter', ylabel='vae loss', title='Test VAE loss'),
        )

    def set_mode(self, train=True):

        if train:
            self.encoderMx.train()
            self.encoderMy.train()
            self.decoderMy.train()
        else:
            self.encoderMx.eval()
            self.encoderMy.eval()
            self.decoderMy.eval()

    ####
    def save_checkpoint(self, iteration):

        encoderMx_path = os.path.join(self.ckpt_dir,
                                      'iter_%s_encoderMx.pt' % iteration)
        encoderMy_path = os.path.join(self.ckpt_dir,
                                      'iter_%s_encoderMy.pt' % iteration)
        decoderMy_path = os.path.join(self.ckpt_dir,
                                      'iter_%s_decoderMy.pt' % iteration)

        mkdirs(self.ckpt_dir)

        torch.save(self.encoderMx, encoderMx_path)
        torch.save(self.encoderMy, encoderMy_path)
        torch.save(self.decoderMy, decoderMy_path)

    ####
    def load_checkpoint(self):

        encoderMx_path = os.path.join(
            self.ckpt_dir, 'iter_%s_encoderMx.pt' % self.ckpt_load_iter)
        encoderMy_path = os.path.join(
            self.ckpt_dir, 'iter_%s_encoderMy.pt' % self.ckpt_load_iter)
        decoderMy_path = os.path.join(
            self.ckpt_dir, 'iter_%s_decoderMy.pt' % self.ckpt_load_iter)

        if self.device == 'cuda':
            self.encoderMx = torch.load(encoderMx_path)
            self.encoderMy = torch.load(encoderMy_path)
            self.decoderMy = torch.load(decoderMy_path)
        else:
            self.encoderMx = torch.load(encoderMx_path, map_location='cpu')
            self.encoderMy = torch.load(encoderMy_path, map_location='cpu')
            self.decoderMy = torch.load(decoderMy_path, map_location='cpu')
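
    ####
    # The two methods above pickle the entire modules, which ties a
    # checkpoint to the exact source tree that defined them. A
    # state_dict-based variant is sketched below as a more portable
    # alternative (an assumption of this sketch, not part of the original
    # trainer: the three networks are already constructed before loading,
    # e.g. in __init__).
    def save_checkpoint_state_dict(self, iteration):
        mkdirs(self.ckpt_dir)
        path = os.path.join(self.ckpt_dir, 'iter_%s_state.pt' % iteration)
        torch.save({'encoderMx': self.encoderMx.state_dict(),
                    'encoderMy': self.encoderMy.state_dict(),
                    'decoderMy': self.decoderMy.state_dict()}, path)

    def load_checkpoint_state_dict(self):
        path = os.path.join(self.ckpt_dir,
                            'iter_%s_state.pt' % self.ckpt_load_iter)
        ckpt = torch.load(path, map_location=self.device)
        self.encoderMx.load_state_dict(ckpt['encoderMx'])
        self.encoderMy.load_state_dict(ckpt['encoderMy'])
        self.decoderMy.load_state_dict(ckpt['decoderMy'])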