Ejemplo n.º 1
0
    def calculate_lower_bound(self, X_full, batch_size=500):
        """Return the per-sample average of ``-RE + KL`` (the negative ELBO) over ``X_full``.

        The data are processed in chunks of ``batch_size`` rows. The last chunk
        may be smaller, so every sample contributes to the sum that is divided
        by ``X_full.size(0)``.

        Args:
            X_full: tensor whose dim 0 indexes samples; each chunk is flattened
                to ``prod(self.args.input_size)`` columns before ``self.forward``.
            batch_size: rows per forward pass (defaults to 500, the value that
                was previously hard-coded).

        Returns:
            Python float: sum over chunks of ``-RE + KL`` divided by the number
            of samples. ``RE`` is whatever ``log_Bernoulli(x, x_mean)`` returns
            (presumably summed over the chunk -- TODO confirm). ``KL`` is
            ``sum(log q(z_0) - log p(z_T))`` over the chunk.

        Requires PyTorch >= 0.4 (uses ``Tensor.item()``).
        """
        n_samples = X_full.size(0)
        n_features = int(np.prod(self.args.input_size))
        lower_bound = 0.

        # range(start, stop, step) also visits the trailing partial chunk and
        # always yields ints (a float bound breaks range() under Python 3)
        for start in range(0, n_samples, batch_size):
            x = X_full[start:start + batch_size].view(-1, n_features)

            # pass through VAE
            x_mean, x_logvar, z_0, z_T, z_q_mean, z_q_logvar = self.forward(x)

            # reconstruction log-likelihood
            RE = log_Bernoulli(x, x_mean)

            # KL estimate: sum of log q(z_0) - log p(z_T)
            log_p_z = log_Normal_standard(z_T, dim=1)
            log_q_z = log_Normal_diag(z_0, z_q_mean, z_q_logvar, dim=1)
            KL = -torch.sum(log_p_z - log_q_z)

            # .item() extracts the scalar; .data[0] raises on 0-dim tensors
            lower_bound += (-RE + KL).item()

        return lower_bound / n_samples
Ejemplo n.º 2
0
    def log_p_z2(self, z2):
        """Log-prior of the second-level latent ``z2`` (``MB x M``).

        ``self.args.prior`` selects the prior:

        * ``'standard'``: ``log_Normal_standard(z2, dim=1)``.
        * ``'vampprior'``: uniform mixture over ``self.args.number_components``
          Gaussians. The pseudo-inputs ``self.means(self.idle_input)`` are
          encoded by ``self.q_z2`` into per-component means/log-variances
          (``C x M``). The mixture log-density is evaluated with a
          max-shifted log-sum-exp.

        Raises:
            Exception: for any other prior name.
        """
        prior_name = self.args.prior
        if prior_name == 'standard':
            return log_Normal_standard(z2, dim=1)
        if prior_name != 'vampprior':
            raise Exception('Wrong name of the prior!')

        n_components = self.args.number_components

        # per-component Gaussian parameters from the pseudo-inputs (C x M)
        pseudo_inputs = self.means(self.idle_input)
        comp_mean, comp_logvar = self.q_z2(pseudo_inputs)

        # broadcast z2 (MB x 1 x M) against the components (1 x C x M)
        log_weighted = log_Normal_diag(
            z2.unsqueeze(1), comp_mean.unsqueeze(0), comp_logvar.unsqueeze(0), dim=2
        ) - math.log(n_components)  # MB x C

        # stable log-sum-exp over the component axis
        row_peak, _ = torch.max(log_weighted, 1)  # MB
        rescaled = torch.exp(log_weighted - row_peak.unsqueeze(1))
        return row_peak + torch.log(torch.sum(rescaled, 1))  # MB
Ejemplo n.º 3
0
    def log_p_z(self, z):
        """Return the log-prior of ``z`` (``MB x M``) under ``self.args.prior``.

        * ``'standard'``: ``log_Normal_standard(z, dim=1)``.
        * ``'vampprior'``: uniform mixture over ``self.args.number_components``
          components.
          - The pseudo-inputs ``self.means(self.idle_input)`` are reshaped to
            ``(n, input_size[0], input_size[1], input_size[2])`` images.
          - They are encoded by ``self.q_z`` into per-component means and
            log-variances.
          - The mixture log-density is evaluated with a max-shifted
            log-sum-exp.

        Returns:
            One log-prior value per row of ``z`` for the VampPrior; whatever
            ``log_Normal_standard`` returns for the standard prior.

        Raises:
            Exception: for any other prior name.
        """
        if self.args.prior == 'standard':
            log_prior = log_Normal_standard(z, dim=1)

        elif self.args.prior == 'vampprior':
            # z - MB x M
            C = self.args.number_components

            # pseudo-inputs (flat)
            X = self.means(self.idle_input)

            # reshape to images before encoding. The channel count comes from
            # input_size[0] rather than being fixed at 1, so multi-channel
            # inputs reshape correctly; single-channel behavior is unchanged.
            X = X.view(X.shape[0], self.args.input_size[0],
                       self.args.input_size[1], self.args.input_size[2])
            z_p_mean, z_p_logvar = self.q_z(X)  # C x M

            # expand z against the components: (MB x 1 x M) vs (1 x C x M)
            z_expand = z.unsqueeze(1)
            means = z_p_mean.unsqueeze(0)
            logvars = z_p_logvar.unsqueeze(0)

            a = log_Normal_diag(z_expand, means, logvars, dim=2) - math.log(C)  # MB x C
            a_max, _ = torch.max(a, 1)  # MB (torch.max without keepdim drops dim 1)

            # numerically stable log-sum-exp over the components
            log_prior = a_max + torch.log(torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB

        else:
            raise Exception('Wrong name of the prior!')

        return log_prior
Ejemplo n.º 4
0
    def log_p_z2(self, z2):
        """Log-prior of the second-level latent ``z2`` (``MB x M``).

        * ``'standard'``: ``log_Normal_standard(z2, dim=1)``.
        * ``'vampprior'``: uniform mixture over ``self.args.number_components``
          components.
          - The pseudo-inputs ``self.means(self.idle_input)`` are viewed as
            ``(-1, input_size[0], input_size[1], input_size[2])``.
          - They are encoded by ``self.q_z2`` into per-component means and
            log-variances (``C x M``).
          - The mixture log-density is a max-shifted log-sum-exp.

        Raises:
            Exception: for any other value of ``self.args.prior``.
        """
        if self.args.prior == 'standard':
            return log_Normal_standard(z2, dim=1)
        if self.args.prior != 'vampprior':
            raise Exception('Wrong name of the prior!')

        n_components = self.args.number_components
        image_shape = (self.args.input_size[0], self.args.input_size[1], self.args.input_size[2])

        # encode the pseudo-inputs, reshaped to images, into component parameters
        pseudo_images = self.means(self.idle_input).view(-1, *image_shape)
        comp_mean, comp_logvar = self.q_z2(pseudo_images)  # C x M

        # MB x C log-densities, each shifted by the uniform weight log(1/C)
        log_weighted = log_Normal_diag(
            z2.unsqueeze(1), comp_mean.unsqueeze(0), comp_logvar.unsqueeze(0), dim=2
        ) - math.log(n_components)

        # log-sum-exp over components, stabilized by the per-row maximum
        row_peak, _ = torch.max(log_weighted, 1)  # MB
        summed = torch.sum(torch.exp(log_weighted - row_peak.unsqueeze(1)), 1)
        return row_peak + torch.log(summed)  # MB
Ejemplo n.º 5
0
    def log_p_z(self, z, idle, c, prior):
        """Log-prior of ``z``; only the ``'standard'`` prior is supported.

        ``idle`` and ``c`` are accepted but not used by this implementation.

        Raises:
            Exception: if ``prior`` is anything other than ``'standard'``.
        """
        if prior != 'standard':
            raise Exception('Wrong name of the prior!')
        return log_Normal_standard(z, dim=1)
Ejemplo n.º 6
0
def train_vae(epoch, args, train_loader, model, optimizer):
    """Train ``model`` for one epoch and return the epoch-averaged losses.

    The KL term is scaled by a warm-up factor ``beta``:
    ``beta = min(1, (epoch - 1) / args.warmup)``, or ``1`` when
    ``args.warmup == 0``.

    With ``args.dynamic_binarization`` the model input is
    ``torch.bernoulli(data)``. The reconstruction term is still evaluated
    against the original ``data``.

    Args:
        epoch: 1-based epoch number (drives the warm-up factor).
        args: run configuration (reads ``warmup``, ``cuda``,
            ``dynamic_binarization``).
        train_loader: iterable of ``(data, target)`` minibatches; ``target``
            is ignored.
        model: VAE whose ``forward`` returns
            ``(x_mean, x_logvar, z_q, z_q_mean, z_q_logvar)``.
        optimizer: optimizer over the model parameters.

    Returns:
        ``(model, train_loss, train_re, train_kl)``. Each loss is the mean
        over minibatches of the per-sample value.

    Requires PyTorch >= 0.4 (uses ``Tensor.item()``).
    """
    train_loss = 0.
    train_re = 0.
    train_kl = 0.
    # set model in training mode
    model.train()

    # KL warm-up factor, clamped at 1
    if args.warmup == 0:
        beta = 1.
    else:
        beta = min(1., 1. * (epoch - 1) / args.warmup)
    print('beta: {}'.format(beta))

    for data, _ in train_loader:
        if args.cuda:
            data = data.cuda()
        # dynamic binarization: feed a fresh Bernoulli sample of the batch
        x = torch.bernoulli(data) if args.dynamic_binarization else data

        optimizer.zero_grad()
        x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = model.forward(x)

        # reconstruction log-likelihood, evaluated against the un-binarized data
        RE = log_Bernoulli(data, x_mean, average=False)
        # sum over the batch of log q(z_q) - log p(z_q), scaled by beta
        log_p_z = log_Normal_standard(z_q, dim=1)
        log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
        KL = beta * (-torch.sum(log_p_z - log_q_z))

        batch_size = data.size(0)
        loss = (-RE + KL) / batch_size
        loss.backward()
        optimizer.step()

        # .item() instead of .data[0], which raises on 0-dim tensors
        train_loss += loss.item()
        train_re += (-RE / batch_size).item()
        train_kl += (KL / batch_size).item()

    # each accumulated term is already a per-sample average for its batch
    n_batches = len(train_loader)
    train_loss /= n_batches
    train_re /= n_batches
    train_kl /= n_batches

    return model, train_loss, train_re, train_kl
Ejemplo n.º 7
0
    def calculate_likelihood(self, X, dir, mode='test', S=5000):
        """Estimate the negative log-likelihood of ``X`` by importance sampling.

        For each row ``x`` of ``X`` the model is run on copies of ``x``. The
        copies are processed in ``R`` rounds of at most 500 each. The
        log-weights ``RE + log p(z_T) - log q(z_0)`` are combined as
        ``logsumexp(weights) - log(number of weights)``. A histogram of the
        per-row negative log-likelihoods is drawn with ``plot_histogram``.

        Args:
            X: 2-D tensor, one sample per row. ``X[j].unsqueeze(0)`` is
                expanded to ``(S, X.size(1))``.
            dir: output location passed to ``plot_histogram``.
            mode: label passed to ``plot_histogram``.
            S: requested number of importance samples per row. When
                ``S > 500`` it is run as ``S // 500`` rounds of 500, so a
                value that is not a multiple of 500 is rounded down.

        Returns:
            The mean over rows of the estimated negative log-likelihood.
        """
        N_test = X.size(0)
        likelihood_test = []

        MB = 500

        if S <= MB:
            R = 1
        else:
            # floor division: `S / MB` is a float under Python 3, and
            # range(R) below would raise TypeError
            R = S // MB
            S = MB

        for j in range(N_test):
            if j % 100 == 0:
                print('{:.2f}%'.format(j / (1. * N_test) * 100))
            # Take x* and replicate it S times per round
            x_single = X[j].unsqueeze(0)

            a = []
            for _ in range(R):
                x = x_single.expand(S, x_single.size(1))

                # pass through VAE
                x_mean, x_logvar, z_0, z_T, z_q_mean, z_q_logvar = self.forward(x)

                # per-sample reconstruction log-likelihood
                RE = log_Bernoulli(x, x_mean, dim=1)

                # per-sample log q(z_0) - log p(z_T)
                log_p_z = log_Normal_standard(z_T, dim=1)
                log_q_z = log_Normal_diag(z_0, z_q_mean, z_q_logvar, dim=1)
                KL = -(log_p_z - log_q_z)

                # importance log-weights
                a_tmp = (RE - KL)
                a.append(a_tmp.cpu().data.numpy())

            # flatten all R*S log-weights and average them in log space;
            # logsumexp is assumed to reduce the whole array to a scalar -- TODO confirm its source
            a = np.asarray(a)
            a = np.reshape(a, (a.shape[0] * a.shape[1], 1))
            likelihood_x = logsumexp(a)
            likelihood_test.append(likelihood_x - np.log(len(a)))

        likelihood_test = np.array(likelihood_test)

        plot_histogram(-likelihood_test, dir, mode)

        return -np.mean(likelihood_test)
Ejemplo n.º 8
0
    def vae_loss(x, x_decoded_mean):
        """Keras loss: ``-RE + KL``, each term summed over axis 1 and averaged over axis 0.

        Reads from the enclosing scope: ``config`` (``data_type``,
        ``original_dim``, ``number_of_flows``, ``regularization``),
        ``x_decoded_log_var``, ``z_0``, ``z_mean``, ``z_log_var``, ``z`` and
        ``vae`` (for ``vae.beta``).

        Raises:
            ValueError: if ``config.data_type`` is neither ``'binary'`` nor ``'gray'``.
            Exception: if ``config.regularization`` is neither ``'none'`` nor ``'warmup'``.
        """
        # RE part
        if config.data_type == 'binary':
            re_loss = log_Bernoulli(x, x_decoded_mean)
            RE = K.mean(K.sum(re_loss, axis=1), axis=0)
        elif config.data_type == 'gray':
            re_loss = log_Normal_diag(x, x_decoded_mean, x_decoded_log_var)
            # NOTE(review): this ADDS 0.5*D*log(2*pi). If log_Normal_diag omits
            # the Gaussian normalizing constant, the exact log-likelihood would
            # subtract it instead -- confirm against log_Normal_diag.
            RE = K.mean(K.sum(re_loss, axis=1), axis=0) + 0.5*config.original_dim*np.log(2*np.pi)
        else:
            raise ValueError('wrong data_type name: {!r}'.format(config.data_type))

        # KL part: log q(z_0) - log p(z_T), with z_T the output of the last flow step
        log_q = log_Normal_diag(z_0, z_mean, z_log_var)
        log_p = log_Normal_standard( z[str(config.number_of_flows)] )
        kl =  log_q - log_p

        if config.regularization == 'none':
            KL = K.mean( K.sum( kl, axis=1 ), axis=0 )
        elif config.regularization == 'warmup':
            # KL annealing: scale by the warm-up coefficient stored on the model
            KL = K.mean( K.sum( kl, axis=1 ), axis=0 ) * vae.beta
        else:
            raise Exception('wrong regularization name')

        return -RE + KL
Ejemplo n.º 9
0
    def log_p_z(self, z):
        """Log-density of ``z`` under the standard normal prior, via ``log_Normal_standard(z, dim=1)``."""
        return log_Normal_standard(z, dim=1)
Ejemplo n.º 10
0
 def vae_kl(x, x_decoded_mean):
     """Keras metric: ``log q(z_0) - log p(z_T)`` summed over axis 1, averaged over axis 0.

     ``x`` and ``x_decoded_mean`` are not used. They presumably exist only to
     match the Keras ``(y_true, y_pred)`` signature. ``z_0``, ``z_mean``,
     ``z_log_var``, ``z`` and ``config`` come from the enclosing scope.
     """
     log_q = log_Normal_diag(z_0, z_mean, z_log_var)
     final_z = z[str(config.number_of_flows)]
     kl_terms = log_q - log_Normal_standard(final_z)
     per_sample = K.sum(kl_terms, axis=1)
     return K.mean(per_sample, axis=0)
Ejemplo n.º 11
0
    def log_p_z(self, z, mu=None, logvar=None, head=0, use_mixw_cor=False):
        """Return the log-prior of ``z`` under the prior named by ``self.args.prior``.

        Supported names:

        * ``'standard'``: ``log_Normal_standard(z, dim=1)``.
        * ``'vampprior'``: uniform mixture of ``self.args.number_components``
          Gaussians.
          - Their parameters are ``self.q_z(X, head=head)``, where
            ``X = self.means[head](self.idle_input)``.
          - ``X`` is reshaped to ``3 x 64 x 64`` images when
            ``self.args.dataset_name == 'celeba'``.
        * ``'vampprior_short'``: mixture with component parameters
          ``self.q_z_mean(X)`` / ``self.q_z_logvar(X)``.
          - ``X`` is an identity matrix passed through ``self.means[head]``
            (with ``self.args.separate_means``) or through ``self.means``.
          - Mixing weights are uniform unless ``self.args.use_vampmixingw`` is
            set. In that case they are ``self.mixingw(eye).t()``, optionally
            multiplied by ``self.mixingw_c`` when ``use_mixw_cor`` is true.
        * ``'GMM'``: ``log_mog_full`` when ``self.GMM.covariance_type == 'full'``,
          otherwise ``log_mog_diag``.

        Mixture densities are combined with a max-shifted log-sum-exp.

        Args:
            z: latent codes (``MB x M`` per the shape comments below).
            mu, logvar: unused; kept for interface compatibility.
            head: index selecting the pseudo-input module in ``self.means``
                (and passed to ``self.q_z``).
            use_mixw_cor: multiply the learned mixing weights by
                ``self.mixingw_c`` (only with ``use_vampmixingw``).

        Returns:
            One log-prior value per row of ``z`` for the VampPrior variants;
            whatever the called helper returns otherwise.

        Raises:
            Exception: if ``self.args.prior`` is not a supported name.
        """
        prior = self.args.prior
        if prior == 'standard':
            log_prior = log_Normal_standard(z, dim=1)

        elif prior == 'vampprior':
            # z - MB x M
            C = self.args.number_components

            # pseudo-inputs for the selected head
            X = self.means[head](self.idle_input)
            if self.args.dataset_name == 'celeba':
                X = X.reshape(X.size(0), 3, 64, 64)

            # component parameters from the pseudo-inputs
            z_p_mean, z_p_logvar = self.q_z(X, head=head)  # C x M

            # expand z against the components: (MB x 1 x M) vs (1 x C x M)
            z_expand = z.unsqueeze(1)
            means = z_p_mean.unsqueeze(0)
            logvars = z_p_logvar.unsqueeze(0)

            a = log_Normal_diag(z_expand, means, logvars, dim=2) - math.log(
                C)  # MB x C
            a_max, _ = torch.max(a, 1)  # MB

            # numerically stable log-sum-exp over the components
            log_prior = a_max + torch.log(
                torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB

        elif prior == 'vampprior_short':
            C = self.args.number_components

            # pick the pseudo-input module once and feed it an identity matrix
            # sized to the input width of its `linear` layer
            means_module = self.means[head] if self.args.separate_means else self.means
            K = means_module.linear.weight.size(1)
            eye = torch.eye(K, K)
            if self.args.cuda:
                eye = eye.cuda()
            X = means_module(eye)

            z_p_mean = self.q_z_mean(X)
            z_p_logvar = self.q_z_logvar(X)

            # expand z against the components
            z_expand = z.unsqueeze(1)
            means = z_p_mean.unsqueeze(0)
            logvars = z_p_logvar.unsqueeze(0)

            # NOTE: mixing weights are not yet specialized for the separate-means case
            if self.args.use_vampmixingw:
                pis = self.mixingw(eye).t()
                if use_mixw_cor:
                    mixw_cor = torch.from_numpy(self.mixingw_c)
                    if self.args.cuda:
                        mixw_cor = mixw_cor.cuda()
                    pis = pis * mixw_cor.unsqueeze(0)
                # eps keeps the log finite when a mixing weight is exactly 0
                eps = 1e-30
                a = log_Normal_diag(z_expand, means, logvars,
                                    dim=2) + torch.log(pis + eps)  # MB x C
            else:
                a = log_Normal_diag(z_expand, means, logvars,
                                    dim=2) - math.log(C)  # MB x C

            a_max, _ = torch.max(a, 1)  # MB
            # numerically stable log-sum-exp over the components
            log_prior = a_max + torch.log(
                torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB

        elif prior == 'GMM':
            if self.GMM.covariance_type == 'full':
                #z_np = z.data.cpu().numpy()
                #lls = self.GMM.score_samples(z_np)
                #log_prior = torch.from_numpy(lls).float().cuda() + math.log(2*math.pi)*(0.5*z.size(1))
                log_prior = log_mog_full(z, self.mus, self.sigs, self.pis,
                                         self.icovs_ten, self.det_terms)

            else:
                log_prior = log_mog_diag(z, self.mus, self.sigs, self.pis)
        else:
            raise Exception('invalid prior!')

        return log_prior
Ejemplo n.º 12
0
def evaluate_vae(args, model, train_loader, data_loader, epoch, dir, mode):
    """Evaluate the VAE on ``data_loader``; in test mode also run the full test suite.

    Every mode does the following for each minibatch:
      - accumulate the per-sample negative ELBO ``(-RE + KL) / batch``, the
        reconstruction term and the KL term;
      - after the loop, divide each accumulator by the number of minibatches.

    ``mode == 'validation'``: reconstructions of minibatch index 1 are plotted.
    Real images are also plotted when ``epoch == 1``.

    ``mode == 'test'`` additionally:
      - plots real images, reconstructions and random generations;
      - for a 2-D latent space, plots the manifold and a scatter plot;
      - computes ``model.calculate_lower_bound`` on the test data and on a
        ``torch.bernoulli``-binarized copy of the training data;
      - computes ``model.calculate_likelihood`` on the test data.

    Args:
        args: run configuration (reads ``cuda`` and ``z1_size``).
        model: VAE exposing ``forward``, ``q_z``, ``reparameterize``, ``p_x``,
            ``calculate_lower_bound`` and ``calculate_likelihood``.
        train_loader: loader whose ``dataset.data_tensor`` supplies the
            training data (test mode only).
        data_loader: loader being evaluated.
        epoch: current epoch number (passed to ``plot_reconstruction``).
        dir: output path prefix for the plots.
        mode: ``'validation'``, ``'test'`` or another label (no extra work).

    Returns:
        ``(evaluate_loss, evaluate_re, evaluate_kl)``. In test mode it also
        returns ``log_likelihood_test, log_likelihood_train, elbo_test,
        elbo_train``. ``log_likelihood_train`` is always ``0.`` because its
        computation is commented out.
    """
    # NOTE(review): `Variable(..., volatile=True)`, `dataset.data_tensor` and
    # `.data[0]` on scalar losses are PyTorch 0.3-era APIs. On newer releases
    # `volatile` is ignored, `TensorDataset.data_tensor` no longer exists and
    # `.data[0]` on a 0-dim tensor raises -- confirm the pinned torch version.
    # set loss to 0
    evaluate_loss = 0
    evaluate_re = 0
    evaluate_kl = 0
    # set model to evaluation mode
    model.eval()

    # evaluate
    for batch_idx, (data, target) in enumerate(data_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)

        # evaluation uses the data as-is (no dynamic binarization)
        x = data
        # forward pass
        x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = model.forward(x)
        # loss function
        # RE: reconstruction log-likelihood
        RE = log_Bernoulli(x, x_mean)

        # KL: sum over the batch of log q(z_q) - log p(z_q)
        log_p_z = log_Normal_standard(z_q, dim=1)
        log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
        KL = - torch.sum(log_p_z - log_q_z)

        # accumulate per-sample averages for this batch
        evaluate_loss += ((-RE + KL) / data.size(0)).data[0]
        evaluate_re += (-RE / data.size(0)).data[0]
        evaluate_kl += (KL / data.size(0)).data[0]

        # plot reconstructions of one minibatch (validation only)
        if batch_idx == 1 and mode == 'validation':
            plot_reconstruction(args, x_mean, epoch, dir, size_x=3, size_y=3)
            if epoch == 1:
                # VISUALIZATION: plot real images
                plot_real(args, data[:26], dir + 'reconstruction/', size_x=3, size_y=3)

    if mode == 'test':
        # load all data
        test_data = Variable(data_loader.dataset.data_tensor)
        test_target = Variable(data_loader.dataset.target_tensor)
        full_data = Variable(train_loader.dataset.data_tensor)

        if args.cuda:
            test_data, test_target, full_data = test_data.cuda(), test_target.cuda(), full_data.cuda()

        # binarize the training data by sampling (consumes RNG state)
        full_data = torch.bernoulli(full_data)

        # VISUALIZATION: plot real images
        plot_real(args, test_data, dir, size_x=5, size_y=5)

        # VISUALIZATION: plot reconstructions (encode, sample z, decode)
        z_mean_recon, z_logvar_recon = model.q_z(test_data)
        z_recon = model.reparameterize(z_mean_recon, z_logvar_recon)
        samples, _ = model.p_x(z_recon)

        plot_reconstruction(args, samples, epoch, dir, size_x=5, size_y=5)

        # VISUALIZATION: plot generations from 25 standard-normal latent samples
        z_sample_rand = Variable(torch.normal( torch.from_numpy( np.zeros((25,args.z1_size)) ).float(), 1. ) )
        if args.cuda:
            z_sample_rand = z_sample_rand.cuda()

        samples_rand, _ = model.p_x(z_sample_rand)

        plot_generation(args, samples_rand, dir, size_x=5, size_y=5)

        if args.z1_size == 2:
            # VISUALIZATION: plot low-dimensional manifold
            plot_manifold(model, args, dir)

            # VISUALIZATION: plot scatter-plot
            plot_scatter(model, test_data, test_target, dir)

        # CALCULATE lower-bound on the test data
        t_ll_s = time.time()
        elbo_test = model.calculate_lower_bound(test_data)
        t_ll_e = time.time()
        print('Lower-bound time: {:.2f}'.format(t_ll_e - t_ll_s))

        # CALCULATE lower-bound on the binarized training data
        t_ll_s = time.time()
        elbo_train = model.calculate_lower_bound(full_data)
        t_ll_e = time.time()
        print('Lower-bound time: {:.2f}'.format(t_ll_e - t_ll_s))

        # CALCULATE log-likelihood on the test data (importance sampling)
        t_ll_s = time.time()
        log_likelihood_test = model.calculate_likelihood(test_data, dir, mode='test')
        t_ll_e = time.time()
        print('Log_likelihood time: {:.2f}'.format(t_ll_e - t_ll_s))

        # log-likelihood on the training data is disabled (always 0.); the
        # commented-out call shows the original computation, so the timing
        # printed below measures nothing
        t_ll_s = time.time()
        log_likelihood_train = 0. #model.calculate_likelihood(full_data, dir, mode='train')
        t_ll_e = time.time()
        print('Log_likelihood time: {:.2f}'.format(t_ll_e - t_ll_s))

    # calculate final loss
    evaluate_loss /= len(data_loader)  # loss function already averages over batch size
    evaluate_re /= len(data_loader)  # re already averages over batch size
    evaluate_kl /= len(data_loader)  # kl already averages over batch size
    if mode == 'test':
        return evaluate_loss, evaluate_re, evaluate_kl, log_likelihood_test, log_likelihood_train, elbo_test, elbo_train
    else:
        return evaluate_loss, evaluate_re, evaluate_kl
Ejemplo n.º 13
0
    def log_p_z(self, z, mu=None, logvar=None):
        """Return the log-prior of ``z`` under the prior named by ``self.args.prior``.

        * ``'standard'``: ``log_Normal_standard(z, dim=1)``.
        * ``'vampprior'``: uniform mixture of ``self.args.number_components``
          Gaussians.
          - Their parameters are ``self.q_z(X)``, where
            ``X = self.means(self.idle_input)``.
          - ``X`` is reshaped to ``3 x 64 x 64`` images for
            ``dataset_name == 'celeba'``.
        * ``'vampprior_short'``: mixture with component parameters
          ``self.q_z_mean(X)`` / ``self.q_z_logvar(X)``.
          - ``X`` is ``self.means`` applied to an identity matrix.
          - Mixing weights are uniform, or ``self.mixingw(eye).t()`` when
            ``self.args.use_vampmixingw`` is set.

        Mixture densities are combined with a max-shifted log-sum-exp.

        Args:
            z: latent codes (``MB x M`` per the shape comments below).
            mu, logvar: unused; kept for interface compatibility.

        Returns:
            One log-prior value per row of ``z`` for the VampPrior variants;
            whatever ``log_Normal_standard`` returns for the standard prior.

        Raises:
            Exception: if ``self.args.prior`` is not a supported name.
        """
        prior = self.args.prior
        if prior == 'standard':
            log_prior = log_Normal_standard(z, dim=1)

        elif prior == 'vampprior':
            # z - MB x M
            C = self.args.number_components

            # pseudo-inputs
            X = self.means(self.idle_input)
            if self.args.dataset_name == 'celeba':
                X = X.reshape(X.size(0), 3, 64, 64)

            # component parameters from the pseudo-inputs
            z_p_mean, z_p_logvar = self.q_z(X)  # C x M

            # expand z against the components: (MB x 1 x M) vs (1 x C x M)
            z_expand = z.unsqueeze(1)
            means = z_p_mean.unsqueeze(0)
            logvars = z_p_logvar.unsqueeze(0)

            a = log_Normal_diag(z_expand, means, logvars, dim=2) - math.log(
                C)  # MB x C
            a_max, _ = torch.max(a, 1)  # MB

            # numerically stable log-sum-exp over the components
            log_prior = a_max + torch.log(
                torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB

        elif prior == 'vampprior_short':
            C = self.args.number_components

            # identity matrix sized to the input width of the means' linear layer
            K = self.means.linear.weight.size(1)
            eye = torch.eye(K, K)
            if self.args.cuda:
                eye = eye.cuda()

            X = self.means(eye)

            z_p_mean = self.q_z_mean(X)
            z_p_logvar = self.q_z_logvar(X)

            # expand z against the components
            z_expand = z.unsqueeze(1)
            means = z_p_mean.unsqueeze(0)
            logvars = z_p_logvar.unsqueeze(0)

            # NOTE: mixing weights are not yet specialized for the separate-means case
            if self.args.use_vampmixingw:
                pis = self.mixingw(eye).t()
                # eps keeps the log finite when a mixing weight is exactly 0
                eps = 1e-30
                a = log_Normal_diag(z_expand, means, logvars,
                                    dim=2) + torch.log(pis + eps)  # MB x C
            else:
                a = log_Normal_diag(z_expand, means, logvars,
                                    dim=2) - math.log(C)  # MB x C

            a_max, _ = torch.max(a, 1)  # MB
            # numerically stable log-sum-exp over the components
            log_prior = a_max + torch.log(
                torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB

        else:
            raise Exception('invalid prior!')

        return log_prior