Example #1
    def __init__(self, args):
        # Misc
        use_cuda = args.cuda and torch.cuda.is_available()
        self.device = 'cuda' if use_cuda else 'cpu'
        self.name = args.name
        self.max_iter = int(args.max_iter)
        self.print_iter = args.print_iter
        self.global_iter = 0
        self.pbar = tqdm(total=self.max_iter)

        # Data
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)

        # Networks & Optimizers
        self.z_dim = args.z_dim
        self.gamma = args.gamma

        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE

        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D

        if args.dataset == 'dsprites':
            self.VAE = FactorVAE1(self.z_dim).to(self.device)
            self.nc = 1
        else:
            self.VAE = FactorVAE2(self.z_dim).to(self.device)
            self.nc = 3
        self.optim_VAE = optim.Adam(self.VAE.parameters(),
                                    lr=self.lr_VAE,
                                    betas=(self.beta1_VAE, self.beta2_VAE))

        self.D = Discriminator(self.z_dim).to(self.device)
        self.optim_D = optim.Adam(self.D.parameters(),
                                  lr=self.lr_D,
                                  betas=(self.beta1_D, self.beta2_D))

        self.nets = [self.VAE, self.D]

        # Visdom
        self.viz_on = args.viz_on
        self.win_id = dict(D_z='win_D_z',
                           recon='win_recon',
                           kld='win_kld',
                           acc='win_acc')
        self.line_gather = DataGather('iter', 'soft_D_z', 'soft_D_z_pperm',
                                      'recon', 'kld', 'acc')
        self.image_gather = DataGather('true', 'recon')
        if self.viz_on:
            self.viz_port = args.viz_port
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter
            self.viz_ra_iter = args.viz_ra_iter
            self.viz_ta_iter = args.viz_ta_iter
            if not self.viz.win_exists(env=self.name + '/lines',
                                       win=self.win_id['D_z']):
                self.viz_init()

        # Checkpoint
        self.ckpt_dir = os.path.join(args.ckpt_dir, args.name)
        self.ckpt_save_iter = args.ckpt_save_iter
        mkdirs(self.ckpt_dir)
        if args.ckpt_load:
            self.load_checkpoint(args.ckpt_load)

        # Output(latent traverse GIF)
        self.output_dir = os.path.join(args.output_dir, args.name)
        self.output_save = args.output_save
        mkdirs(self.output_dir)
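A minimal instantiation sketch for the solver above. The class name Solver and the concrete argument values are assumptions; the field names mirror exactly what this __init__ reads.

from argparse import Namespace

# Hypothetical argument set; every field below is read by the __init__ above.
args = Namespace(cuda=True, name='factorvae_dsprites', max_iter=3e5,
                 print_iter=500, dset_dir='data', dataset='dsprites',
                 batch_size=64, z_dim=10, gamma=6.4,
                 lr_VAE=1e-4, beta1_VAE=0.9, beta2_VAE=0.999,
                 lr_D=1e-4, beta1_D=0.5, beta2_D=0.9,
                 viz_on=False, viz_port=8097, viz_ll_iter=1000,
                 viz_la_iter=5000, viz_ra_iter=10000, viz_ta_iter=10000,
                 ckpt_dir='checkpoints', ckpt_save_iter=10000, ckpt_load=None,
                 output_dir='outputs', output_save=True)
solver = Solver(args)  # 'Solver' is an assumed name for the enclosing class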
Example #2
    def __init__(self, args):
        self.data_loader = return_data(args)
        
        self.model_name = args.model
        self.z_dim = args.z_dim
        self.nc = args.number_channel
        # Earlier per-dataset model-loading branches (MNIST / dSprites / 3D
        # shapes) are superseded by the dynamic architecture import below.
        assert os.path.exists('./{}_model_{}.pk'.format(self.model_name, args.dataset)), 'no model found!'
        mod = __import__('models.beta_vae_archs',
                         fromlist=['Beta_VAE_{}'.format(args.arch)])
        self.model = nn.DataParallel(
            getattr(mod, 'Beta_VAE_{}'.format(args.arch))(
                z_dim=args.z_dim, n_cls=args.n_cls,
                computes_std=args.computes_std, nc=args.number_channel,
                output_type=args.output_type).cuda(),
            device_ids=range(1))

        dim = args.image_size
        self.model.load_state_dict(torch.load('./{}_model_{}.pk'.format(args.model, args.dataset)))

        self.output_dir = args.output_dir
        if args.name != 'main':
            self.todays_date = args.name
        else:
            self.todays_date = datetime.datetime.today()
        self.output_dir = os.path.join(self.output_dir, str(self.todays_date))
        if not os.path.exists(self.output_dir):
            mkdirs(self.output_dir)

        self.device = 'cuda'
        self.dataset = args.dataset
        self.output_save = args.output_save
        if args.dataset == "mnist":
            self.width = 28
            self.height = 28
        
        else:
            self.width = 64
            self.height = 64
        if os.path.exists('./AVG_KLDs_VAR_means.pth'):
            AVG_KLDs_var_means = torch.load('./AVG_KLDs_VAR_means.pth')
            self.AVG_KLDs = AVG_KLDs_var_means['AVG_KLDs']
            self.VAR_means = AVG_KLDs_var_means['VAR_means']

        else:
            self.model.module.eval()

            print('computing average of KLDs and variance of means of latent units.')
            N = len(self.data_loader)
            dataset_size = N * args.batch_size
            AVG_KLD = torch.zeros(self.model.module.z_dim).cpu()
            means = []
            for i, (x, y) in tqdm(enumerate(self.data_loader, 0),
                                  total=len(self.data_loader), smoothing=0.9):
                # `async` is a reserved word in Python 3.7+; use non_blocking instead
                x = x.cuda(non_blocking=True)

                params = self.model.module.encoder(x).view(x.size(0), self.z_dim, 2)
                mu, logstd_var = params.select(-1, 0), params.select(-1, 1)
                KLDs = self.model.module.kld_unit_guassians_per_sample(mu, logstd_var)
                AVG_KLD += KLDs.sum(0).cpu().detach()
                means.append(mu.cpu().detach())

            self.AVG_KLDs = AVG_KLD / dataset_size
            self.VAR_means = torch.std(torch.cat(means), dim=0).pow(2)
            self.sorted_idx_KLDs = torch.argsort(self.AVG_KLDs, descending=True)
            # sort by the variance of the means, not the KLDs again
            self.sorted_idx_VAR_means = torch.argsort(self.VAR_means, descending=True)
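The else branch computes the statistics but never writes them to disk, so the os.path.exists check above can never hit the cache on a later run. A one-call sketch, assuming the same file name and keys:

# Hedged sketch: persist the freshly computed statistics for reuse.
torch.save({'AVG_KLDs': self.AVG_KLDs, 'VAR_means': self.VAR_means},
           './AVG_KLDs_VAR_means.pth')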
Example #3
    def __init__(self, args):
        self.use_cuda = args.cuda and torch.cuda.is_available()
        self.max_epoch = args.max_epoch
        self.global_epoch = 0
        self.global_iter = 0

        self.z_dim = args.z_dim
        self.z_var = args.z_var
        self.z_sigma = math.sqrt(args.z_var)
        self.prior_dist = torch.distributions.Normal(
            torch.zeros(self.z_dim),
            torch.ones(self.z_dim) * self.z_sigma)
        self._lambda = args.reg_weight
        self.lr = args.lr
        self.lr_D = args.lr_D
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.lr_schedules = {30: 2, 50: 5, 100: 10}

        if args.dataset.lower() == 'celeba':
            self.nc = 3
            self.decoder_dist = 'gaussian'
        else:
            self.nc = 1
            self.decoder_dist = 'gaussian'
            # raise NotImplementedError

        self.net = cuda(WAE(self.z_dim, self.nc), self.use_cuda)
        self.optim = optim.Adam(self.net.parameters(),
                                lr=self.lr,
                                betas=(self.beta1, self.beta2))

        self.D = cuda(Adversary(self.z_dim), self.use_cuda)
        self.optim_D = optim.Adam(self.D.parameters(),
                                  lr=self.lr_D,
                                  betas=(self.beta1, self.beta2))

        self.gather = DataGather()
        self.viz_name = args.viz_name
        self.viz_port = args.viz_port
        self.viz_on = args.viz_on
        self.viz = None
        self.win_recon = None
        self.win_QD = None
        self.win_D = None
        self.win_mu = None
        self.win_var = None
        if self.viz_on:
            self.viz = visdom.Visdom(env=self.viz_name + '_lines',
                                     port=self.viz_port)

        self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.viz_name)
        if not self.ckpt_dir.exists():
            self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.ckpt_name = args.ckpt_name
        if self.ckpt_name is not None:
            self.load_checkpoint(self.ckpt_name)

        self.save_output = args.save_output
        self.output_dir = Path(args.output_dir).joinpath(args.viz_name)
        if not self.output_dir.exists():
            self.output_dir.mkdir(parents=True, exist_ok=True)

        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)
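self.lr_schedules = {30: 2, 50: 5, 100: 10} is only stored in this excerpt; a plausible reading is epoch -> divisor of the base learning rates. A hedged sketch of how it might be applied (the method name is an assumption):

# Hedged sketch: divide the base learning rates at the scheduled epochs.
def schedule_lr(self):
    divider = self.lr_schedules.get(self.global_epoch)
    if divider is not None:
        for group in self.optim.param_groups:
            group['lr'] = self.lr / divider
        for group in self.optim_D.param_groups:
            group['lr'] = self.lr_D / divider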
Example #4
dim = args.image_size

mod = __import__('models.beta_vae_archs',
                 fromlist=['Beta_VAE_{}'.format(args.arch)])
model = nn.DataParallel(
    getattr(mod, 'Beta_VAE_{}'.format(args.arch))(
        z_dim=args.z_dim,
        n_cls=args.n_cls,
        output_size=(dim, dim),
        hyper_params=hyp_params,  # hyp_params is not defined in this excerpt
        computes_std=args.computes_std,
        nc=args.number_channel,
        output_type=args.output_type).cuda(),
    device_ids=range(1))

data_loader = return_data(args)
pre_train_optimizer = Adam(itertools.chain(model.module.encoder.parameters(),
                                           model.module.decoder.parameters()),
                           lr=1e-3)
train_optimizer = Adam(model.module.parameters(), lr=2e-3)
lr_scheduler = StepLR(train_optimizer, step_size=10, gamma=0.95)
if os.path.exists('./model_{}.pk'.format(args.dataset)):
    model.load_state_dict(torch.load('./model_{}.pk'.format(args.dataset)))


def pre_train(data_loader):
    loss = []
    for x, y in data_loader:
        pre_train_optimizer.zero_grad()
        x = x.view((-1, 1, dim, dim)).cuda()
        mu, _ = model.module.encoder(x)
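The pre_train body is cut off in this excerpt. A hedged sketch of the driver loop implied by the objects above; train_one_epoch is a hypothetical stand-in for the missing training body:

# Hedged sketch of the driver loop; 'train_one_epoch' is hypothetical.
for epoch in range(50):
    # train_one_epoch(model, data_loader, train_optimizer)  # hypothetical
    lr_scheduler.step()  # multiply train_optimizer's lr by 0.95 every 10 epochs
    torch.save(model.state_dict(), './model_{}.pk'.format(args.dataset))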
Example #5
    def __init__(self, args):
        self.use_cuda = args.cuda and torch.cuda.is_available()
        self.max_iter = args.max_iter
        self.global_iter = 0

        self.z_dim = args.z_dim
        self.beta = args.beta
        self.gamma = args.gamma
        self.C_max = args.C_max
        self.C_stop_iter = args.C_stop_iter
        self.objective = args.objective
        self.model = args.model
        self.lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2

        if args.dataset.lower() == 'dsprites':
            self.nc = 1
            self.decoder_dist = 'bernoulli'
        elif args.dataset.lower() == '3dchairs':
            self.nc = 3
            self.decoder_dist = 'gaussian'
        elif args.dataset.lower() == 'celeba':
            self.nc = 3
            self.decoder_dist = 'gaussian'
        else:
            raise NotImplementedError

        if args.model == 'H':
            net = BetaVAE_H
        elif args.model == 'B':
            net = BetaVAE_B
        else:
            raise NotImplementedError('only support model H or B')

        self.net = cuda(net(self.z_dim, self.nc), self.use_cuda)
        self.optim = optim.Adam(self.net.parameters(),
                                lr=self.lr,
                                betas=(self.beta1, self.beta2))

        self.viz_name = args.viz_name
        self.viz_port = args.viz_port
        self.viz_on = args.viz_on
        self.win_recon = None
        self.win_kld = None
        self.win_mu = None
        self.win_var = None
        if self.viz_on:
            self.viz = visdom.Visdom(port=self.viz_port)

        self.ckpt_dir = os.path.join(args.ckpt_dir, args.viz_name)
        if not os.path.exists(self.ckpt_dir):
            os.makedirs(self.ckpt_dir, exist_ok=True)
        self.ckpt_name = args.ckpt_name
        if self.ckpt_name is not None:
            self.load_checkpoint(self.ckpt_name)

        self.save_output = args.save_output
        self.output_dir = os.path.join(args.output_dir, args.viz_name)
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir, exist_ok=True)

        self.gather_step = args.gather_step
        self.display_step = args.display_step
        self.save_step = args.save_step

        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)

        self.gather = DataGather()
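args.objective selects between the two standard β-VAE trainings: 'H' (Higgins et al., fixed β) and 'B' (Burgess et al., annealed capacity C). A minimal sketch of the corresponding loss, assuming recon_loss and total_kld are scalar tensors; beta_vae_objective is a hypothetical helper name.

# Minimal sketch of the two objectives; names mirror the attributes set above.
def beta_vae_objective(self, recon_loss, total_kld):
    if self.objective == 'H':
        return recon_loss + self.beta * total_kld
    elif self.objective == 'B':
        # anneal the capacity C linearly from 0 to C_max over C_stop_iter steps
        C = min(self.C_max / self.C_stop_iter * self.global_iter, self.C_max)
        return recon_loss + self.gamma * (total_kld - C).abs()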
Example #6
    def __init__(self, args):
        self.use_cuda = args.cuda and torch.cuda.is_available()
        self.max_iter = args.max_iter
        self.global_iter = 0

        self.z_dim = args.z_dim
        self.beta = args.beta
        self.gamma = args.gamma
        self.C_max = args.C_max
        self.C_stop_iter = args.C_stop_iter
        self.objective = args.objective
        self.model = args.model
        self.lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.l1 = 10
        self.l2 = 1

        if args.dataset.lower() == 'mnist':
            self.nc = 1
            self.decoder_dist = 'bernoulli'
        else:
            # this solver is only wired up for MNIST; fail early instead of
            # crashing later on an undefined self.nc
            raise NotImplementedError

        net = BetaVAE_mnist_mod

        self.net = cuda(net(self.z_dim, self.nc), self.use_cuda)
        self.optim = optim.Adam(self.net.parameters(),
                                lr=self.lr,
                                betas=(self.beta1, self.beta2))

        self.viz_name = args.viz_name
        self.viz_port = args.viz_port
        self.viz_on = args.viz_on
        self.win_recon = None
        self.win_kld = None
        self.win_mu = None
        self.win_var = None
        if self.viz_on:
            self.viz = visdom.Visdom(port=self.viz_port)

        self.ckpt_dir = os.path.join(args.ckpt_dir, args.viz_name)
        if not os.path.exists(self.ckpt_dir):
            os.makedirs(self.ckpt_dir)
        self.ckpt_name = args.ckpt_name
        if self.ckpt_name is not None:
            self.load_checkpoint(self.ckpt_name)
        # self.intervene()
        self.save_output = args.save_output
        self.output_dir = os.path.join(args.output_dir, args.viz_name)
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        self.gather_step = args.gather_step
        self.display_step = args.display_step
        self.save_step = args.save_step

        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)

        self.gather = DataGather()
Example #7
    def __init__(self, args):
        self.data_loader = return_data(args)
        
        self.model_name = args.model
        self.z_dim = [args.con_latent_dims, args.disc_latent_dims]
        dim = args.image_size
        mod = __import__('models.joint_b_vae_archs',
                         fromlist=['Joint_VAE_{}'.format(args.arch)])
        self.model = nn.DataParallel(
            getattr(mod, 'Joint_VAE_{}'.format(args.arch))(
                z_dim=args.z_dim, n_cls=args.n_cls,
                computes_std=args.computes_std, nc=args.number_channel,
                output_size=(dim, dim),
                output_type=args.output_type).cuda(),
            device_ids=range(1))
        self.nc = 3
        self.output_dir = args.output_dir
        self.todays_date = datetime.datetime.today()
        self.output_dir = os.path.join(self.output_dir, str(self.todays_date))
        if not os.path.exists(self.output_dir):
            mkdirs(self.output_dir)

        self.device = 'cuda'
        self.dataset = args.dataset
        self.output_save = args.output_save
        if args.dataset == "mnist":
            self.width = 32
            self.height = 32
        
        else:
            self.width = 64
            self.height = 64
        if os.path.exists('./AVG_KLDs_VAR_means.pth'):
            AVG_KLDs_var_means = torch.load('./AVG_KLDs_VAR_means.pth')
            self.AVG_KLDs = AVG_KLDs_var_means['AVG_KLDs']
            self.VAR_means = AVG_KLDs_var_means['VAR_means']

        else:
            with torch.no_grad():
                self.model.module.eval()

                print('computing average of KLDs and variance of means of latent units,')
                print('and average of KLDs of the discrete latent units.')
                N = len(self.data_loader)
                dataset_size = N * args.batch_size
                AVG_KLD = torch.zeros(self.model.module.z_dim[0]).cpu()
                means = []
                AVG_KLDs_disc = [torch.Tensor([0]) for i in range(len(self.z_dim[1]))]
                for i, (x, y) in tqdm(enumerate(self.data_loader, 0),
                                      total=len(self.data_loader), smoothing=0.9):
                    # `async` is a reserved word in Python 3.7+; use non_blocking instead
                    x = x.cuda(non_blocking=True)

                    x_hat, mu, logstd_var, z, alphas, rep_as = self.model(x)

                    loss, recon, KLD_total_mean, KLDs, KLD_catigorical_total_mean, KLDs_catigorical = \
                        self.model.module.losses(x, x_hat, mu, logstd_var, z, alphas, rep_as, 'H')

                    for ii in range(len(KLDs_catigorical)):
                        AVG_KLDs_disc[ii] += KLDs_catigorical[ii]

                    KLDs = self.model.module.kld_unit_guassians_per_sample(mu, logstd_var)
                    AVG_KLD += KLDs.sum(0).cpu().detach()
                    means.append(mu.cpu().detach())
                for ii in range(len(AVG_KLDs_disc)):
                    AVG_KLDs_disc[ii] = args.batch_size * AVG_KLDs_disc[ii] / dataset_size
                self.AVG_KLDs_disc = torch.cat(AVG_KLDs_disc, dim=0)
                self.sorted_idx_KLDs_disc = self.AVG_KLDs_disc.argsort(0)

                self.AVG_KLDs = AVG_KLD / dataset_size
                self.VAR_means = torch.std(torch.cat(means), dim=0).pow(2)
                self.sorted_idx_KLDs = torch.argsort(self.AVG_KLDs, descending=True)
                # sort by the variance of the means, not the KLDs again
                self.sorted_idx_VAR_means = torch.argsort(self.VAR_means, descending=True)
        mkdirs(os.path.join(self.output_dir, "images"))
Example #8
    def __init__(self, args):
        self.args = args

        # Misc
        use_cuda = args.cuda and torch.cuda.is_available()
        self.device = 'cuda' if use_cuda else 'cpu'
        self.name = args.name
        self.max_iter = int(args.max_iter)
        self.print_iter = args.print_iter
        self.global_iter = 0
        self.pbar = tqdm(total=self.max_iter)

        # Data
        assert args.dataset == 'dsprites', 'Only dSprites is implemented'
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader, self.dataset = return_data(args)

        # Networks & Optimizers
        self.z_dim = args.z_dim
        self.gamma = args.gamma

        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE

        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D

        # Disentanglement score
        self.L = args.L
        self.vote_count = args.vote_count
        self.dis_score = args.dis_score
        self.dis_batch_size = args.dis_batch_size

        # Models and optimizers
        self.VAE = FactorVAE(self.z_dim).to(self.device)
        self.nc = 1

        self.optim_VAE = optim.Adam(self.VAE.parameters(),
                                    lr=self.lr_VAE,
                                    betas=(self.beta1_VAE, self.beta2_VAE))

        self.D = Discriminator(self.z_dim).to(self.device)
        self.optim_D = optim.Adam(self.D.parameters(),
                                  lr=self.lr_D,
                                  betas=(self.beta1_D, self.beta2_D))

        self.nets = [self.VAE, self.D]

        # Attention Disentanglement loss
        self.ad_loss = args.ad_loss
        self.lamb = args.lamb
        if self.ad_loss:
            self.gcam = GradCAM(self.VAE.encode, args.target_layer,
                                self.device, args.image_size)
            self.pick2 = True

        # Checkpoint
        self.ckpt_dir = os.path.join(args.ckpt_dir,
                                     args.name + '_' + str(args.seed))
        self.ckpt_save_iter = args.ckpt_save_iter
        if self.max_iter >= args.ckpt_save_iter:
            mkdirs(self.ckpt_dir)
        if args.ckpt_load:
            self.load_checkpoint(args.ckpt_load)

        # Results
        self.results_dir = os.path.join(args.results_dir,
                                        args.name + '_' + str(args.seed))
        self.results_save = args.results_save

        self.outputs = {
            'vae_recon_loss': [],
            'vae_kld': [],
            'vae_tc_loss': [],
            'D_tc_loss': [],
            'ad_loss': [],
            'dis_score': [],
            'iteration': []
        }
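The outputs dict above is a plain log buffer; a hedged sketch of the update step that would fill it during training (the method name is an assumption):

# Hedged sketch: append one row of scalars to the log buffer above.
def log_outputs(self, **scalars):
    self.outputs['iteration'].append(self.global_iter)
    for key, value in scalars.items():
        self.outputs[key].append(float(value))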
Example #9
    def __init__(self, parameters):
        self.use_cuda = parameters['use_cuda'] and torch.cuda.is_available()
        self.max_iter = parameters['max_iter']
        self.global_iter = 0
        if parameters['model'] == 'G':
            self.split_z_dim = parameters['split_z_dim']
            self.z_dim = sum(self.split_z_dim)
        else:
            self.z_dim = parameters['z_dim']
        self.beta = parameters['beta']
        self.objective = parameters['objective']
        self.model = parameters['model']
        self.lr = parameters['lr']
        self.beta1 = parameters['beta1']
        self.beta2 = parameters['beta2']
        if parameters['dataset'].lower() == 'dsprites':
            self.nc = 1
            self.decoder_dist = 'bernoulli'
        elif parameters['dataset'].lower() == '3dchairs':
            self.nc = 3
            self.decoder_dist = 'gaussian'
        elif parameters['dataset'].lower() == 'celeba':
            self.nc = 3
            self.decoder_dist = 'gaussian'
        else:
            raise NotImplementedError

        if parameters['model'] == 'H':
            net = BetaVAE_H
        elif parameters['model'] == 'B':
            net = BetaVAE_B
        elif parameters['model'] == 'G':
            net = Split_BetaVAE_G
        else:
            raise NotImplementedError('only supports model H, B or G')

        if parameters['model'] == 'G':
            self.net = cuda(net(self.split_z_dim, self.nc), self.use_cuda)
        else:
            self.net = cuda(net(self.z_dim, self.nc), self.use_cuda)
        self.optim = optim.Adam(self.net.parameters(),
                                lr=self.lr,
                                betas=(self.beta1, self.beta2))
        self.viz_name = parameters['project_name']
        self.viz_on = parameters['viz_on']
        self.tensorboard_dir = os.path.join(parameters['tensorboard_dir'],
                                            self.viz_name)
        self.ckpt_dir = os.path.join(parameters['ckpt_dir'],
                                     parameters['project_name'])
        if not os.path.exists(self.ckpt_dir):
            os.makedirs(self.ckpt_dir, exist_ok=True)
        self.ckpt_name = parameters['ckpt_name']
        if self.ckpt_name is not None:
            self.load_checkpoint(self.ckpt_name)
        if self.viz_on:
            self.viz = SummaryWriter(
                os.path.join(parameters['tensorboard_dir'], self.viz_name))
        self.save_output = parameters['save_output']
        self.output_dir = os.path.join(parameters['output_dir'],
                                       parameters['project_name'])
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir, exist_ok=True)

        self.gather_step = parameters['gather_step']
        self.display_step = parameters['display_step']
        self.save_step = parameters['save_step']

        self.dset_dir = parameters['dset_dir']
        self.dataset = parameters['dataset']
        self.batch_size = parameters['batch_size']
        self.data_loader = return_data(parameters)

        self.gather = DataGather()
Example #10
def mutual_info_metric_shapes(vae, args):
    dataset_loader = return_data(args)

    N = len(dataset_loader.dataset)  # number of data samples
    K = vae.module.z_dim  # number of latent variables
    nparams = 2

    print('Computing q(z|x) distributions.')
    qz_params = torch.Tensor(N, K, nparams)

    n = 0
    for xs, _ in dataset_loader:
        batch_size = xs.size(0)
        xs = xs.view(batch_size, 1, 64, 64).cuda()
        qz_params[n:n + batch_size] = vae.module.encoder(xs).view(
            batch_size, vae.module.z_dim, nparams).data
        n += batch_size

    if not vae.module.computes_std:
        qz_params[:, :, 1] = qz_params[:, :, 1] * 0.5
    qz_params = qz_params.view(3, 6, 40, 32, 32, K, nparams).cuda()
    mu, logstd_var = qz_params.select(-1, 0), qz_params.select(-1, 1)

    qz_samples = vae.module.reparam(mu, logstd_var)

    print('Estimating marginal entropies.')
    # marginal entropies
    marginal_entropies = estimate_entropies(
        qz_samples.view(N, K).transpose(0, 1), qz_params.view(N, K, nparams))

    marginal_entropies = marginal_entropies.cpu()
    cond_entropies = torch.zeros(4, K)

    print('Estimating conditional entropies for scale.')
    for i in range(6):
        qz_samples_scale = qz_samples[:, i, :, :, :, :].contiguous()
        qz_params_scale = qz_params[:, i, :, :, :, :].contiguous()

        cond_entropies_i = estimate_entropies(
            qz_samples_scale.view(N // 6, K).transpose(0, 1),
            qz_params_scale.view(N // 6, K, nparams))

        cond_entropies[0] += cond_entropies_i.cpu() / 6

    print('Estimating conditional entropies for orientation.')
    for i in range(40):
        qz_samples_scale = qz_samples[:, :, i, :, :, :].contiguous()
        qz_params_scale = qz_params[:, :, i, :, :, :].contiguous()

        cond_entropies_i = estimate_entropies(
            qz_samples_scale.view(N // 40, K).transpose(0, 1),
            qz_params_scale.view(N // 40, K, nparams))

        cond_entropies[1] += cond_entropies_i.cpu() / 40

    print('Estimating conditional entropies for pos x.')
    for i in range(32):
        qz_samples_scale = qz_samples[:, :, :, i, :, :].contiguous()
        qz_params_scale = qz_params[:, :, :, i, :, :].contiguous()

        cond_entropies_i = estimate_entropies(
            qz_samples_scale.view(N // 32, K).transpose(0, 1),
            qz_params_scale.view(N // 32, K, nparams))

        cond_entropies[2] += cond_entropies_i.cpu() / 32

    print('Estimating conditional entropies for pos y.')
    for i in range(32):
        qz_samples_scale = qz_samples[:, :, :, :, i, :].contiguous()
        qz_params_scale = qz_params[:, :, :, :, i, :].contiguous()

        cond_entropies_i = estimate_entropies(
            qz_samples_scale.view(N // 32, K).transpose(0, 1),
            qz_params_scale.view(N // 32, K, nparams))

        cond_entropies[3] += cond_entropies_i.cpu() / 32

    metric = compute_metric_shapes(marginal_entropies, cond_entropies)
    return metric, marginal_entropies, cond_entropies
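The view(3, 6, 40, 32, 32, K, nparams) above matches the dSprites factor grid: 3 shapes x 6 scales x 40 orientations x 32 x 32 positions = 737,280 images. A hedged invocation sketch:

# Hedged usage sketch; assumes 'vae' is the DataParallel-wrapped model used above.
metric, marginal_entropies, cond_entropies = mutual_info_metric_shapes(vae, args)
print('metric: {:.4f}'.format(float(metric)))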
Example #11
    def __init__(self, args):
        self.use_cuda = args.cuda and torch.cuda.is_available()
        self.max_iter = args.max_iter
        self.global_iter = args.global_iter

        self.z_dim = args.z_dim
        self.y_dim = args.y_dim

        self.gamma = args.gamma  # MLP
        self.alpha = args.alpha  # DIS
        self.beta = args.beta    # KL
        self.rec = args.rec      # REC

        self.image_size = args.image_size
        self.lr = args.lr

        if args.dataset.lower() == 'dsprites' or args.dataset.lower() == 'faces':
            self.nc = 1
            self.decoder_dist = 'bernoulli'

            self.netE = cuda(ICP_Encoder(z_dim=self.z_dim, y_dim=self.y_dim, nc=self.nc), self.use_cuda)
            self.netG_Less = cuda(ICP_Decoder(dim=self.z_dim, nc=self.nc), self.use_cuda)
            self.netG_More = cuda(ICP_Decoder(dim=self.y_dim, nc=self.nc), self.use_cuda)
            self.netG_Total = cuda(ICP_Decoder(dim=self.z_dim + self.y_dim, nc=self.nc), self.use_cuda)
            self.netD = cuda(Discriminator(y_dim=self.y_dim), self.use_cuda)
            self.netZ2Y = cuda(MLP(s_dim=self.z_dim, t_dim=self.y_dim), self.use_cuda)
            self.netY2Z = cuda(MLP(s_dim=self.y_dim, t_dim=self.z_dim), self.use_cuda)

        elif args.dataset.lower() == 'celeba':
            self.nc = 3
            self.decoder_dist = 'gaussian'

            self.netE = cuda(ICP_Encoder_Big(z_dim=self.z_dim, y_dim=self.y_dim, nc=self.nc), self.use_cuda)
            self.netG_Less = cuda(ICP_Decoder_Big(dim=self.z_dim, nc=self.nc), self.use_cuda)
            self.netG_More = cuda(ICP_Decoder_Big(dim=self.y_dim, nc=self.nc), self.use_cuda)
            self.netG_Total = cuda(ICP_Decoder_Big(dim=self.z_dim + self.y_dim, nc=self.nc), self.use_cuda)
            self.netD = cuda(Discriminator(y_dim=self.y_dim), self.use_cuda)
            self.netZ2Y = cuda(MLP(s_dim=self.z_dim, t_dim=self.y_dim), self.use_cuda)
            self.netY2Z = cuda(MLP(s_dim=self.y_dim, t_dim=self.z_dim), self.use_cuda)
        else:
            raise NotImplementedError

        self.optimizerG = optim.Adam([{'params': self.netE.parameters()},
                                      {'params': self.netG_Less.parameters()},
                                      {'params': self.netG_More.parameters()},
                                      {'params': self.netG_Total.parameters()}],
                                     lr=self.lr, betas=(0.9, 0.999))
        self.optimizerD = optim.Adam(self.netD.parameters(), lr=self.lr, betas=(0.9, 0.999))
        self.optimizerMLP = optim.Adam([{'params': self.netZ2Y.parameters()},
                                        {'params': self.netY2Z.parameters()}],
                                       lr=self.lr, betas=(0.9, 0.999))


        self.save_name = args.save_name
        self.ckpt_dir = os.path.join('./checkpoints/', args.save_name)
        if not os.path.exists(self.ckpt_dir):
            os.makedirs(self.ckpt_dir, exist_ok=True)
        self.ckpt_name = args.ckpt_name
        if args.train and self.ckpt_name is not None:
            self.load_checkpoint(self.ckpt_name)
        self.save_output = args.save_output
        self.output_dir = os.path.join('./trainvis/', args.save_name)
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir, exist_ok=True)

        self.display_step = args.display_step
        self.save_step = args.save_step

        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)
Example #12
def get_samples_F_VAE_metric_disc(model, L, num_votes, num_test_votes, args):
    dataset_loader = return_data(args)

    N = len(dataset_loader.dataset)  # number of data samples
    # number of raw encoder outputs: mean and log-std per continuous latent,
    # plus one logit per discrete category
    K = 2 * model.module.num_latent_dims + model.module.num_latent_dims_disc

    qz_params = torch.Tensor(N, K).cuda()

    n = 0
    with torch.no_grad():
        for xs, _ in dataset_loader:
            batch_size = xs.shape[0]
            qz_params[n:n + batch_size] = model.module.encoder(xs.cuda())
            n += batch_size

    mu = qz_params[:, 0:model.module.z_dim[0]]
    logstd_var = qz_params[:, model.module.z_dim[0]:2 * model.module.z_dim[0]]

    d = 2 * model.module.z_dim[0]  # starting index of the discrete vars
    alphas = []
    rep_as = 0
    # initialize here so the code below also runs when there are no discrete dims
    accepted_disc = []
    accepted_disc_dims = []
    if len(model.module.z_dim[1]) > 0:
        for dims in model.module.z_dim[1]:
            alphas.append(
                torch.softmax(qz_params[:, d:d + dims], dim=-1).cuda())
            d += dims
        rep_as = model.module.reparam_disc(alphas)

        for index_disc, alpha in enumerate(alphas):
            uniform_params = torch.ones_like(alpha) / alpha.shape[1]
            kld = kl_divergence(Categorical(alpha), Categorical(uniform_params))
            if kld.mean() > 1e-2:
                # normalize
                alpha = alpha / torch.sqrt(
                    compute_disc_embds_var(alpha, [alpha.shape[1]])[0].float())

                accepted_disc.append(alpha)
                accepted_disc_dims.append(alpha.shape[-1])

    KLDs = model.module.kld_unit_guassians_per_sample(mu, logstd_var)
    KLDs = KLDs.sum(0) / N
    # discarding latent dimensions with small KLD
    idx = torch.where(KLDs > 5e-2)[0]
    mu = mu[:, idx]
    num_con_dims = mu.shape[1]
    list_num_con_disc = [alpha.shape[-1] for alpha in accepted_disc]
    list_samples = []
    list_test_samples = []
    #append the accepted disc latent dims
    if len(accepted_disc) > 1:
        accepted_disc = torch.cat(accepted_disc, dim=-1)
    elif len(accepted_disc) == 1:
        accepted_disc = accepted_disc[0]

    mu = mu / torch.std(mu, axis=0)
    # discrete embeddings have already been normalized
    if len(accepted_disc) != 0:
        mu = torch.cat([mu.cuda(), accepted_disc.cuda()], dim=-1)
    K = mu.shape[1]

    if args.dataset == 'dsprites':
        # 5 is the number of generative factors
        num_votes_per_factor = num_votes / 5
        num_samples_per_factor = int(num_votes_per_factor * L)

        num_test_votes_per_factor = num_test_votes / 5
        num_test_samples_per_factor = int(num_test_votes_per_factor * L)
        mu = mu.view(3, 6, 40, 32, 32, K)

        #####################
        # SHAPE FIXED
        #####################
        unused = []  # the unused indices
        for _ in range(3):
            unused.append(ops.choice(0, 6 * 40 * 32 * 32))

        shape_fixed, unused = ops.get_fixed_factor_samples(
            mu, 0, num_samples_per_factor, K, L, 3, unused)
        list_samples.append(shape_fixed)
        del shape_fixed
        shape_fixed, _ = ops.get_fixed_factor_samples(
            mu, 0, num_test_samples_per_factor, K, L, 3, unused)
        list_test_samples.append(shape_fixed)
        del shape_fixed
        ###############################

        ################################
        # SCALE FIXED
        ################################
        unused = []  # the unused indices
        for _ in range(6):
            unused.append(ops.choice(0, 3 * 40 * 32 * 32))

        scale_fixed, unused = ops.get_fixed_factor_samples(
            mu, 1, num_samples_per_factor, K, L, 6, unused)
        list_samples.append(scale_fixed)
        del scale_fixed
        scale_fixed, _ = ops.get_fixed_factor_samples(
            mu, 1, num_test_samples_per_factor, K, L, 6, unused)
        list_test_samples.append(scale_fixed)
        del scale_fixed
        ################################

        ################################
        # ORIENTATION FIXED
        ################################
        unused = []  # the unused indices
        for _ in range(40):
            unused.append(ops.choice(0, 3 * 6 * 32 * 32))
        orientation_fixed, unused = ops.get_fixed_factor_samples(
            mu, 2, num_samples_per_factor, K, L, 40, unused)
        list_samples.append(orientation_fixed)
        del orientation_fixed
        orientation_fixed, _ = ops.get_fixed_factor_samples(
            mu, 2, num_test_samples_per_factor, K, L, 40, unused)
        list_test_samples.append(orientation_fixed)
        del orientation_fixed

        #################################

        #################################
        # COORDINATE ON X-AXIS FIXED
        #################################
        unused = []  # the unused indices
        for _ in range(32):
            unused.append(ops.choice(0, 3 * 6 * 40 * 32))

        posx_fixed, unused = ops.get_fixed_factor_samples(
            mu, 3, num_samples_per_factor, K, L, 32, unused)
        list_samples.append(posx_fixed)
        del posx_fixed
        posx_fixed, _ = ops.get_fixed_factor_samples(
            mu, 3, num_test_samples_per_factor, K, L, 32, unused)
        list_test_samples.append(posx_fixed)
        del posx_fixed
        ###############################

        #################################
        # COORDINATE ON Y-AXIS FIXED
        #################################
        unused = []  # the unused indices
        for _ in range(32):
            unused.append(ops.choice(0, 3 * 6 * 40 * 32))

        posy_fixed, unused = ops.get_fixed_factor_samples(
            mu, 4, num_samples_per_factor, K, L, 32, unused)
        list_samples.append(posy_fixed)
        del posy_fixed
        posy_fixed, _ = ops.get_fixed_factor_samples(
            mu, 4, num_test_samples_per_factor, K, L, 32, unused)
        list_test_samples.append(posy_fixed)
        del posy_fixed
        ###############################
    else:
        num_votes_per_factor = num_votes / 6
        num_samples_per_factor = int(num_votes_per_factor * L)

        num_test_votes_per_factor = num_test_votes / 6
        num_test_samples_per_factor = int(num_test_votes_per_factor * L)
        mu = mu.view(10, 10, 10, 8, 4, 15, K)
        factors_variations = [10, 10, 10, 8, 4, 15]
        total_multi = 10 * 10 * 10 * 8 * 4 * 15
        num_rest = [total_multi // vari for vari in factors_variations]
        for fac_id, vari in enumerate(factors_variations):
            unused = []  # the unused indices
            for _ in range(vari):
                unused.append(ops.choice(0, num_rest[fac_id]))
            genfac_fixed, unused = ops.get_fixed_factor_samples(
                mu, fac_id, num_samples_per_factor, K, L, vari, unused)
            list_samples.append(genfac_fixed)
            del genfac_fixed
            genfac_fixed, _ = ops.get_fixed_factor_samples(
                mu, fac_id, num_test_samples_per_factor, K, L, vari, unused)
            list_test_samples.append(genfac_fixed)
            del genfac_fixed

    return list_samples, list_test_samples, num_con_dims, accepted_disc_dims
Example #13
def get_samples_F_VAE_metric_v3(model, L, num_votes, num_test_votes, args):
    """This is the variant used in the paper."""
    dataset_loader = return_data(args)

    N = len(dataset_loader.dataset)  # number of data samples
    K = args.z_dim  # number of latent variables

    nparams = 2
    qz_params = torch.Tensor(N, K, nparams)

    n = 0
    with torch.no_grad():
        for xs, _ in dataset_loader:
            batch_size = xs.shape[0]
            qz_params[n:n + batch_size] = model.module.encoder(xs.cuda()).view(
                batch_size, model.module.z_dim, nparams).data
            n += batch_size

    mu, logstd_var = qz_params.select(-1, 0), qz_params.select(-1, 1)
    z = model.module.reparam(mu, logstd_var)
    KLDs = model.module.kld_unit_guassians_per_sample(mu, logstd_var).mean(0)

    # discarding latent dimensions with small KLD

    idx = torch.where(KLDs > 1)[0]
    print('Mean KL-divergence of units')
    # print(KLDs)

    std = torch.std(mu, axis=0)
    mu = mu / std
    # global_var = torch.var(mu,axis = 0)
    # print("variances: ")
    # print(global_var)
    # idx = torch.where(global_var>0.005)[0]
    mu = mu[:, idx]
    K = mu.shape[1]
    print('There are {} active units'.format(K))
    list_samples = []
    list_test_samples = []
    # global_var = global_var[idx]

    if args.dataset == 'dsprites':

        # 5 is the number of generative factors
        num_votes_per_factor = num_votes / 5
        num_samples_per_factor = int(num_votes_per_factor * L)

        num_test_votes_per_factor = num_test_votes / 5
        num_test_samples_per_factor = int(num_test_votes_per_factor * L)

        mu = mu.view(3, 6, 40, 32, 32, K)

        #####################
        # SHAPE FIXED
        #####################
        unused = []  # the unused indices
        for _ in range(3):
            unused.append(ops.choice(0, 6 * 40 * 32 * 32))

        shape_fixed, unused = ops.get_fixed_factor_samples(
            mu, 0, num_samples_per_factor, K, L, 3, unused)
        list_samples.append(shape_fixed)
        del shape_fixed
        shape_fixed, _ = ops.get_fixed_factor_samples(
            mu, 0, num_test_samples_per_factor, K, L, 3, unused)
        list_test_samples.append(shape_fixed)
        del shape_fixed
        ###############################

        ################################
        # SCALE FIXED
        ################################
        unused = []  # the unused indices
        for _ in range(6):
            unused.append(ops.choice(0, 3 * 40 * 32 * 32))

        scale_fixed, unused = ops.get_fixed_factor_samples(
            mu, 1, num_samples_per_factor, K, L, 6, unused)
        list_samples.append(scale_fixed)
        del scale_fixed
        scale_fixed, _ = ops.get_fixed_factor_samples(
            mu, 1, num_test_samples_per_factor, K, L, 6, unused)
        list_test_samples.append(scale_fixed)
        del scale_fixed
        ################################

        ################################
        # ORIENTATION FIXED
        ################################
        unused = []  # the unused indices
        for _ in range(40):
            unused.append(ops.choice(0, 3 * 6 * 32 * 32))
        orientation_fixed, unused = ops.get_fixed_factor_samples(
            mu, 2, num_samples_per_factor, K, L, 40, unused)
        list_samples.append(orientation_fixed)
        del orientation_fixed
        orientation_fixed, _ = ops.get_fixed_factor_samples(
            mu, 2, num_test_samples_per_factor, K, L, 40, unused)
        list_test_samples.append(orientation_fixed)
        del orientation_fixed

        #################################

        #################################
        # COORDINATE ON X-AXIS FIXED
        #################################
        unused = []  # the unused indices
        for _ in range(32):
            unused.append(ops.choice(0, 3 * 6 * 40 * 32))

        posx_fixed, unused = ops.get_fixed_factor_samples(
            mu, 3, num_samples_per_factor, K, L, 32, unused)
        list_samples.append(posx_fixed)
        del posx_fixed
        posx_fixed, _ = ops.get_fixed_factor_samples(
            mu, 3, num_test_samples_per_factor, K, L, 32, unused)
        list_test_samples.append(posx_fixed)
        del posx_fixed
        ###############################

        #################################
        # COORDINATE ON Y-AXIS FIXED
        #################################
        unused = []  # the unused indices
        for _ in range(32):
            unused.append(ops.choice(0, 3 * 6 * 40 * 32))

        posy_fixed, unused = ops.get_fixed_factor_samples(
            mu, 4, num_samples_per_factor, K, L, 32, unused)
        list_samples.append(posy_fixed)
        del posy_fixed
        posy_fixed, _ = ops.get_fixed_factor_samples(
            mu, 4, num_test_samples_per_factor, K, L, 32, unused)
        list_test_samples.append(posy_fixed)
        del posy_fixed
        ###############################

    else:
        num_votes_per_factor = num_votes / 6
        num_samples_per_factor = int(num_votes_per_factor * L)

        num_test_votes_per_factor = num_test_votes / 6
        num_test_samples_per_factor = int(num_test_votes_per_factor * L)
        mu = mu.view(10, 10, 10, 8, 4, 15, K)
        factors_variations = [10, 10, 10, 8, 4, 15]
        total_multi = 10 * 10 * 10 * 8 * 4 * 15
        num_rest = [total_multi // vari for vari in factors_variations]
        for fac_id, vari in enumerate(factors_variations):
            unused = []  # the unused indices
            for _ in range(vari):
                unused.append(ops.choice(0, num_rest[fac_id]))
            genfac_fixed, unused = ops.get_fixed_factor_samples(
                mu, fac_id, num_samples_per_factor, K, L, vari, unused)
            list_samples.append(genfac_fixed)
            del genfac_fixed
            genfac_fixed, _ = ops.get_fixed_factor_samples(
                mu, fac_id, num_test_samples_per_factor, K, L, vari, unused)
            list_test_samples.append(genfac_fixed)
            del genfac_fixed

    return list_samples, list_test_samples
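The lists returned above are meant for the FactorVAE majority-vote classifier (Kim & Mnih, 2018): every block of L rows shares one fixed factor, the most informative latent is the one with the smallest empirical variance, and the votes build the classifier. A minimal self-contained sketch of that scoring step (the function name is an assumption):

import torch

# Hedged sketch of the majority-vote scoring that consumes the samples above.
def majority_vote_score(list_samples, list_test_samples, L):
    num_factors = len(list_samples)
    K = list_samples[0].shape[1]
    votes = torch.zeros(K, num_factors)
    for factor, samples in enumerate(list_samples):
        for start in range(0, samples.shape[0], L):
            d = samples[start:start + L].var(0).argmin()  # least-varying latent
            votes[d, factor] += 1
    classifier = votes.argmax(1)  # latent dim -> predicted generative factor
    correct = total = 0
    for factor, samples in enumerate(list_test_samples):
        for start in range(0, samples.shape[0], L):
            d = samples[start:start + L].var(0).argmin()
            correct += int(classifier[d] == factor)
            total += 1
    return correct / total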
Example #14
def get_samples_F_VAE_metric_v2(model, L, num_votes, args, used_samples=None):
    """This is the variant used in the paper."""
    dataset_loader = return_data(args)

    N = len(dataset_loader.dataset)  # number of data samples
    K = args.z_dim  # number of latent variables

    nparams = 2
    qz_params = torch.Tensor(N, K, nparams)

    n = 0
    with torch.no_grad():
        for xs, _ in dataset_loader:
            batch_size = xs.shape[0]
            qz_params[n:n + batch_size] = model.module.encoder(xs.cuda()).view(
                batch_size, model.module.z_dim, nparams).data
            n += batch_size

    mu, logstd_var = qz_params.select(-1, 0), qz_params.select(-1, 1)
    z = model.module.reparam(mu, logstd_var)
    KLDs = model.module.kld_unit_guassians_per_sample(mu, logstd_var).mean(0)

    # discarding latent dimensions with small KLD

    # idx = torch.where(KLDs>1e-2)[0]
    global_var = torch.var(mu, axis=0)
    idx = torch.where(global_var > 5e-2)[0]
    mu = mu[:, idx]
    K = mu.shape[1]
    list_samples = []
    global_var = global_var[idx]
    if args.dataset == 'dsprites':
        if used_samples is None:
            used_samples = []
            factors = [3, 6, 40, 32, 32]
            for f in factors:
                used_samples.append([0 for _ in range(f)])

        # 5 is the number of generative factors
        num_votes_per_factor = num_votes / 5
        num_samples_per_factor = int(num_votes_per_factor * L)
        mu = mu.view(3, 6, 40, 32, 32, K)
        shape_fixed = torch.zeros((num_samples_per_factor, K))
        for idx in range(0, num_samples_per_factor, L):
            fixed = torch.randint(0, 3, (1, ))

            shape_fixed[idx:idx + L] = mu[fixed, :, :, :, :].view(
                6 * 40 * 32 * 32, K)[torch.randint(0, 6 * 40 * 32 * 32,
                                                   (L, )), :]

        list_samples.append(shape_fixed)
        del shape_fixed

        scale_fixed = torch.zeros((num_samples_per_factor, K))
        for idx in range(0, num_samples_per_factor, L):
            fixed = torch.randint(0, 6, (1, ))

            scale_fixed[idx:idx + L] = mu[:, fixed, :, :, :].view(
                3 * 40 * 32 * 32, K)[torch.randint(0, 3 * 40 * 32 * 32,
                                                   (L, )), :]
        list_samples.append(scale_fixed)
        del scale_fixed

        orientation_fixed = torch.zeros((num_samples_per_factor, K))
        for idx in range(0, num_samples_per_factor, L):
            fixed = torch.randint(0, 40, (1, ))

            orientation_fixed[idx:idx + L] = mu[:, :, fixed, :, :].view(
                3 * 6 * 32 * 32, K)[torch.randint(0, 3 * 6 * 32 * 32,
                                                  (L, )), :]

        list_samples.append(orientation_fixed)
        del orientation_fixed

        posx_fixed = torch.zeros((num_samples_per_factor, K))
        for idx in range(0, num_samples_per_factor, L):
            fixed = torch.randint(0, 32, (1, ))

            posx_fixed[idx:idx + L] = mu[:, :, :, fixed, :].view(
                3 * 6 * 40 * 32, K)[torch.randint(0, 3 * 6 * 40 * 32,
                                                  (L, )), :]

        list_samples.append(posx_fixed)
        del posx_fixed

        posy_fixed = torch.zeros((num_samples_per_factor, K))
        for idx in range(0, num_samples_per_factor, L):
            # mirror the other factors: draw a random fixed value for pos-y
            fixed = torch.randint(0, 32, (1, ))
            posy_fixed[idx:idx + L] = mu[:, :, :, :, fixed].view(
                3 * 6 * 40 * 32, K)[torch.randint(0, 3 * 6 * 40 * 32,
                                                  (L, )), :]

        list_samples.append(posy_fixed)
        del posy_fixed
    else:
        pass

    return list_samples, global_var, used_samples
Example #15
def get_samples_F_VAE_metric(model, args):
    dataset_loader = return_data(args)

    N = len(dataset_loader.dataset)  # number of data samples
    K = args.z_dim  # number of latent variables

    nparams = 2
    qz_params = torch.Tensor(N, K, nparams)

    n = 0
    with torch.no_grad():
        for xs, _ in dataset_loader:
            batch_size = xs.shape[0]
            qz_params[n:n + batch_size] = model.module.encoder(xs.cuda()).view(
                batch_size, model.module.z_dim, nparams).data
            n += batch_size

    mu, logstd_var = qz_params.select(-1, 0), qz_params.select(-1, 1)
    z = model.module.reparam(mu, logstd_var)
    KLDs = model.module.kld_unit_guassians_per_sample(mu, logstd_var).mean(0)

    # discarding latent dimensions with small KLD
    idx = torch.where(KLDs > 1e-2)[0]
    mu = mu[:, idx]
    K = mu.shape[1]
    list_samples = []
    if args.dataset == 'dsprites':
        mu = mu.view(3, 6, 40, 32, 32, K)
        fixed = torch.randint(0, 3, (1, ))
        shape_fixed = mu[fixed, :, :, :, :].view(6 * 40 * 32 * 32, K)
        list_samples.append(shape_fixed)
        del shape_fixed

        fixed = torch.randint(0, 6, (1, ))
        scale_fixed = mu[:, fixed, :, :, :].view(3 * 40 * 32 * 32, K)
        list_samples.append(scale_fixed)
        del scale_fixed

        fixed = torch.randint(0, 40, (1, ))
        orientation_fixed = mu[:, :, fixed, :, :].view(3 * 6 * 32 * 32, K)
        list_samples.append(orientation_fixed)
        del orientation_fixed

        fixed = torch.randint(0, 32, (1, ))
        posx_fixed = mu[:, :, :, fixed, :].view(3 * 6 * 40 * 32, K)
        list_samples.append(posx_fixed)
        del posx_fixed

        fixed = torch.randint(0, 32, (1, ))
        posy_fixed = mu[:, :, :, :, fixed].view(3 * 6 * 40 * 32, K)
        list_samples.append(posy_fixed)
        del posy_fixed
    else:
        pass

    return list_samples
Example #16
def mutual_info_metric_faces(vae, args):
    dataset_loader = return_data(args)

    N = len(dataset_loader.dataset)  # number of data samples
    K = vae.module.z_dim  # number of latent variables
    nparams = 2

    print('Computing q(z|x) distributions.')
    qz_params = torch.Tensor(N, K, nparams)

    n = 0
    for xs, _ in dataset_loader:
        batch_size = xs.size(0)  # .size(0) is a plain int; it cannot be moved to CUDA
        xs = xs.view(batch_size, 1, 64, 64).cuda()
        qz_params[n:n + batch_size] = vae.module.encoder.forward(xs).view(
            batch_size, vae.module.z_dim, nparams).data
        n += batch_size

    if not vae.module.computes_std:
        qz_params[:, :, 1] = qz_params[:, :, 1] * 0.5
    qz_params = qz_params.view(50, 21, 11, 11, K, nparams).cuda()
    mu, logstd_var = qz_params.select(-1, 0), qz_params.select(-1, 1)

    qz_samples = vae.module.reparam(mu, logstd_var)

    print('Estimating marginal entropies.')
    # marginal entropies
    marginal_entropies = estimate_entropies(
        qz_samples.view(N, K).transpose(0, 1), qz_params.view(N, K, nparams))

    marginal_entropies = marginal_entropies.cpu()
    cond_entropies = torch.zeros(3, K)

    print('Estimating conditional entropies for azimuth.')
    for i in range(21):
        qz_samples_pose_az = qz_samples[:, i, :, :, :].contiguous()
        qz_params_pose_az = qz_params[:, i, :, :, :].contiguous()

        cond_entropies_i = estimate_entropies(
            qz_samples_pose_az.view(N // 21, K).transpose(0, 1),
            qz_params_pose_az.view(N // 21, K, nparams))

        cond_entropies[0] += cond_entropies_i.cpu() / 21

    print('Estimating conditional entropies for elevation.')
    for i in range(11):
        qz_samples_pose_el = qz_samples[:, :, i, :, :].contiguous()
        qz_params_pose_el = qz_params[:, :, i, :, :].contiguous()

        cond_entropies_i = estimate_entropies(
            qz_samples_pose_el.view(N // 11, K).transpose(0, 1),
            qz_params_pose_el.view(N // 11, K, nparams))

        cond_entropies[1] += cond_entropies_i.cpu() / 11

    print('Estimating conditional entropies for lighting.')
    for i in range(11):
        qz_samples_lighting = qz_samples[:, :, :, i, :].contiguous()
        qz_params_lighting = qz_params[:, :, :, i, :].contiguous()

        cond_entropies_i = estimate_entropies(
            qz_samples_lighting.view(N // 11, K).transpose(0, 1),
            qz_params_lighting.view(N // 11, K, nparams))

        cond_entropies[2] += cond_entropies_i.cpu() / 11

    metric = compute_metric_faces(marginal_entropies, cond_entropies)
    return metric, marginal_entropies, cond_entropies