Example #1
class VAE(Model):
    def __init__(self, args):
        super(VAE, self).__init__(args)

        Khid = 100  # hidden width used throughout the encoder/decoder
        arc = 'conv'  # architecture: 'ff' (gated MLP) or 'conv' (gated CNN)
        self.Khid = Khid
        self.arc = arc

        if arc == 'ff':
            # encoder: q(z | x)
            self.q_z_layers = nn.ModuleList([
                nn.Sequential(GatedDense(np.prod(self.args.input_size), Khid),
                              GatedDense(Khid, Khid))
            ])

            self.q_z_mean = Linear(Khid, self.args.z1_size)
            self.q_z_logvar = NonLinear(Khid,
                                        self.args.z1_size,
                                        activation=nn.Hardtanh(min_val=-6.,
                                                               max_val=2.))

            # decoder: p(x | z)
            self.p_x_layers = nn.Sequential(
                GatedDense(self.args.z1_size, Khid), GatedDense(Khid, Khid))

            #if self.args.input_type == 'binary':
            self.p_x_mean = nn.ModuleList([
                GatedDense(Khid,
                           np.prod(self.args.input_size),
                           activation=nn.Sigmoid())
            ])
        elif arc == 'conv':
            act = None
            self.q_z_layers = nn.ModuleList([
                nn.Sequential(
                    GatedConv2d(self.args.input_size[0],
                                32,
                                3,
                                2,
                                1,
                                activation=act),
                    GatedConv2d(32, 32, 3, 2, 1, activation=act),
                    GatedConv2d(32, Khid, 7, 1, 0, activation=act),
                )
            ])

            self.q_z_mean = GatedDense(Khid, self.args.z1_size)
            self.q_z_logvar = GatedDense(Khid,
                                         self.args.z1_size,
                                         activation=nn.Hardtanh(min_val=-6.,
                                                                max_val=2.))

            self.p_x_layers = nn.Sequential(
                GatedConvTranspose2d(self.args.z1_size,
                                     32,
                                     7,
                                     1,
                                     0,
                                     activation=act),
                GatedConvTranspose2d(32, 32, 3, 2, 1, activation=act),
                GatedConvTranspose2d(32,
                                     self.args.input_size[0],
                                     4,
                                     2,
                                     0,
                                     activation=nn.Sigmoid()))

            #if self.args.input_type == 'binary':
            #self.p_x_mean = nn.ModuleList([GatedDense(Khid, np.prod(self.args.input_size),
            #                                         activation=nn.Sigmoid())])

        #elif self.args.input_type in ['gray', 'continuous', 'color']:
        #    self.p_x_mean = NonLinear(300, np.prod(self.args.input_size), activation=nn.Sigmoid())
        #    self.p_x_logvar = NonLinear(300, np.prod(self.args.input_size), activation=nn.Hardtanh(min_val=-4.5,max_val=0))
        self.mixingw_c = np.ones(self.args.number_components)

        self.apply(net_init)
        # weights initialization
        #if isinstance(m, nn.Linear): # or isinstance(m, nn.ConvTranspose2d):
        #    he_init(m)

        # add pseudo-inputs if VampPrior
        if self.args.prior in ['vampprior', 'vampprior_short']:
            self.add_pseudoinputs()
        elif self.args.prior == 'GMM':
            self.initialize_GMMparams(Kmog=10, mode='random')

    def add_head(self, input_size):
        # the new encoder head must output the hidden width Khid expected by
        # q_z_mean / q_z_logvar
        q_z_layers_new = nn.Sequential(GatedDense(np.prod(input_size), self.Khid),
                                       GatedDense(self.Khid, self.Khid))
        self.q_z_layers.append(q_z_layers_new)

        #elif self.args.input_type in ['gray', 'continuous', 'color']:
        #p_x_mean_new = NonLinear(300, np.prod(input_size), activation=nn.Sigmoid())

        #logvar_layers = [self.p_x_logvar]
        #self.p_x_logvar_new = NonLinear(300, np.prod(input_size), activation=nn.Hardtanh(min_val=-4.5,max_val=0))
        #logvar_layers.append(self.p_x_logvar_new)
        #self.p_x_logvar = nn.ModuleList(logvar_layers)

        p_x_mean_new = NonLinear(self.Khid,
                                 np.prod(input_size),
                                 activation=nn.Sigmoid())
        self.p_x_mean.append(p_x_mean_new)

        if self.args.prior == 'vampprior':
            nonlinearity = nn.Hardtanh(min_val=0.0, max_val=1.0)
            us_new = NonLinear(self.args.number_components,
                               np.prod(self.args.input_size),
                               bias=False,
                               activation=nonlinearity)
            self.means.append(us_new)
        elif self.args.prior == 'vampprior_joint':
            nonlinearity = None  #nn.Hardtanh(min_val=0.0, max_val=1.0)

            us_new = NonLinear(2 * self.args.number_components,
                               300,
                               bias=False,
                               activation=nonlinearity)
            oldweights = self.means.linear.weight.data
            us_new.linear.weight.data[:, :self.args.number_components] = oldweights
            self.means = us_new

            # fix the idle input size also
            self.idle_input = torch.eye(2 * self.args.number_components, 2 *
                                        self.args.number_components).cuda()

        self.extra_head = True

    def restart_latent_space(self):
        nonlinearity = nn.Hardtanh(min_val=0.0, max_val=1.0)
        self.means = NonLinear(self.args.number_components_init,
                               self.Khid,
                               bias=False,
                               activation=nonlinearity)

        if self.args.use_vampmixingw:
            self.mixingw = NonLinear(self.args.number_components_init,
                                     1,
                                     bias=False,
                                     activation=nn.Softmax(dim=0))

    def merge_latent(self):
        # always to be called after separate_latent() or add_latent_cap()
        nonlinearity = None  #nn.Hardtanh(min_val=0.0, max_val=1.0)

        prev_weights = self.means[0].linear.weight.data
        last_prev_weights = self.means[1].linear.weight.data
        all_old = torch.cat([prev_weights, last_prev_weights], dim=1)

        self.means[0] = NonLinear(all_old.size(1),
                                  self.Khid,
                                  bias=False,
                                  activation=nonlinearity).cuda()
        self.means[0].linear.weight.data = all_old

    def separate_latent(self):
        # always to be called after merge_latent()
        nonlinearity = None  #nn.Hardtanh(min_val=0.0, max_val=1.0)

        number_components_init = self.args.number_components_init
        number_components_prev = self.args.number_components - number_components_init

        prev_components = copy.deepcopy(
            self.means[0].linear.weight.data[:, :number_components_prev])

        last_prev_components = copy.deepcopy(
            self.means[0].linear.weight.data[:, number_components_prev:])

        self.means[0] = NonLinear(number_components_prev,
                                  self.Khid,
                                  bias=False,
                                  activation=nonlinearity).cuda()
        self.means[1] = NonLinear(number_components_init,
                                  self.Khid,
                                  bias=False,
                                  activation=nonlinearity).cuda()
        self.means[0].linear.weight.data = prev_components
        self.means[1].linear.weight.data = last_prev_components

    def add_latent_cap(self, dg):
        if self.args.prior == 'vampprior_short':
            nonlinearity = None  #nn.Hardtanh(min_val=0.0, max_val=1.0)

            number_components_prev = self.args.number_components
            self.args.number_components = (dg + 2) * self.args.number_components_init

            if self.args.separate_means:
                add_number_components = self.args.number_components - number_components_prev
                # set the idle inputs
                #self.idle_input = torch.eye(self.args.number_components,
                #                            self.args.number_components).cuda()
                #self.idle_input1 = torch.eye(add_number_components,
                #                             add_number_components).cuda()
                #self.idle_input2 = torch.eye(number_components_prev,
                #                             number_components_prev).cuda()

                us_new = NonLinear(add_number_components,
                                   self.Khid,
                                   bias=False,
                                   activation=nonlinearity).cuda()
                #us_new.linear.weight.data = 0*torch.randn(300, add_number_components).cuda()
                if dg == 0:
                    self.means.append(us_new)
                else:
                    # do not call self.merge_latent() here; merging is handled separately for the validation evaluation (in main_mnist.py)
                    self.means[1] = us_new

            else:
                self.idle_input = torch.eye(self.args.number_components,
                                            self.args.number_components)
                us_new = NonLinear(self.args.number_components,
                                   self.Khid,
                                   bias=False,
                                   activation=nonlinearity)
                if self.args.cuda:
                    self.idle_input = self.idle_input.cuda()
                    us_new = us_new.cuda()

                if not self.args.restart_means:
                    oldweights = self.means.linear.weight.data
                    us_new.linear.weight.data[:, :number_components_prev] = oldweights
                self.means = us_new

                if self.args.use_vampmixingw:
                    self.mixingw = NonLinear(self.args.number_components,
                                             1,
                                             bias=False,
                                             activation=nn.Softmax(dim=0))

    def reconstruct_means(self, head=0):
        if self.args.separate_means:
            K = self.means[head].linear.weight.size(1)
            eye = torch.eye(K, K)
            if self.args.cuda:
                eye = eye.cuda()
            X = self.means[head](eye)
        else:
            K = self.means.linear.weight.size(1)
            eye = torch.eye(K, K)
            if self.args.cuda:
                eye = eye.cuda()

            X = self.means(eye)
        z_mean = self.q_z_mean(X)

        recons, _ = self.p_x(z_mean, head=0)

        return recons
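
    # Hedged sketch (illustration only, not part of the original model): because
    # `means` is a bias-free linear layer, feeding it an identity matrix (as in
    # reconstruct_means above) simply reads out the K learned pseudo-input vectors,
    # i.e. the columns of its weight matrix. `_pseudo_inputs_sketch` is a
    # hypothetical helper added only to show that equivalence.
    @staticmethod
    def _pseudo_inputs_sketch(means_linear):
        # means_linear: an nn.Linear(K, D, bias=False)
        K = means_linear.weight.size(1)
        eye = torch.eye(K, K, device=means_linear.weight.device)
        via_forward = means_linear(eye)        # K x D, one pseudo-input per row
        via_weights = means_linear.weight.t()  # the same read-out of the columns
        assert torch.allclose(via_forward, via_weights)
        return via_forward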

    def balance_mixingw(self,
                        classifier,
                        dg,
                        perm=torch.arange(10),
                        dont_balance=False,
                        vis=None):
        # functions that are related:
        # training.train_classifier
        # evaluate.

        means = self.reconstruct_means()
        yhat_means_soft = F.softmax(classifier.forward(means), dim=1)
        pis = self.mixingw(self.idle_input).squeeze()

        if self.args.number_components == 1:
            curr_per_class_weight = yhat_means_soft
        else:
            curr_per_class_weight = torch.matmul(pis, yhat_means_soft)

        yhat_means = torch.argmax(classifier.forward(means), dim=1)

        # to numpy:
        if self.args.cuda:
            curr_per_class_weight = curr_per_class_weight.detach().cpu().numpy()
        else:
            curr_per_class_weight = curr_per_class_weight.detach().numpy()

        print('\ncurrent per class cluster assignment:')
        print(np.round(curr_per_class_weight, 2))
        print('\n')

        if dont_balance:
            return yhat_means, curr_per_class_weight

        mixingw_c = torch.zeros(self.args.number_components, 1).squeeze()
        ones = torch.ones(self.args.number_components).squeeze()

        if self.args.cuda:
            mixingw_c = mixingw_c.cuda()
            ones = ones.cuda()

        for d in perm[:(dg + 1)]:
            mask = (yhat_means == int(d.item()))
            pis = self.mixingw(self.idle_input).squeeze()
            pis_select = torch.masked_select(pis, mask)
            sm = pis_select.sum()

            # correct the mixing weights
            mixingw_c = mixingw_c + ones * (mask.float()) / (sm * (dg + 1))

        post_per_class_weight = torch.matmul(mixingw_c * pis, yhat_means_soft)

        if self.args.cuda:
            post_per_class_weight = post_per_class_weight.detach().cpu().numpy()
        else:
            post_per_class_weight = post_per_class_weight.detach().numpy()

        print('\npost per class cluster assignment:')
        print(np.round(post_per_class_weight, 2))
        print('\n')

        self.mixingw_c = mixingw_c.data.cpu().numpy()
        return yhat_means, curr_per_class_weight, post_per_class_weight
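
    # Hedged sketch (illustration only): the balancing above multiplies each
    # cluster's mixing weight by 1 / (mass of its predicted class * number of seen
    # classes), so every class seen so far receives equal total probability. A
    # minimal numpy version with hard cluster-to-class assignments; the helper
    # name is hypothetical.
    @staticmethod
    def _balance_sketch(pis, yhat, seen_classes):
        # pis: (K,) mixing weights summing to 1; yhat: (K,) predicted class per cluster
        correction = np.zeros_like(pis)
        for c in seen_classes:
            mask = (yhat == c)
            correction[mask] = 1.0 / (pis[mask].sum() * len(seen_classes))
        # each seen class now sums to 1 / len(seen_classes)
        return correction * pis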

    def compute_class_entropy(self, classifier, dg, perm):
        means = self.reconstruct_means()
        yhat_means = F.softmax(classifier.forward(means), dim=1)

        pis = self.mixingw(self.idle_input)
        ws = torch.matmul(yhat_means.t(), pis).squeeze()
        dist = ws[perm[:(dg + 1)].long()]

        eps = 1e-30
        nent = (dist * torch.log(dist + eps)).sum()

        return nent, ws

    def calculate_loss(self,
                       x,
                       beta=1.,
                       average=False,
                       head=0,
                       use_mixw_cor=False):
        '''
        :param x: input image(s)
        :param beta: a hyperparam for warmup
        :param average: whether to average loss or not
        :return: value of a loss function
        '''

        # pass through VAE
        fw_head = min(head, len(self.q_z_layers) - 1)

        x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(
            x, head=fw_head)

        # RE
        if self.args.input_type == 'binary':
            RE = log_Bernoulli(x, x_mean, dim=1)
        elif self.args.input_type in ['gray', 'color']:
            #RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
            RE = -(x - x_mean).abs().sum(dim=1)

        #elif self.args.input_type == 'color':
        #    RE = -log_Normal_diag(x, x_mean, x_logvar, dim=1)
        else:
            raise Exception('Wrong input type!')

        # KL
        if self.args.prior == 'GMM':
            log_p_z = self.log_p_z(z_q, mu=z_q_mean, logvar=z_q_logvar)
        else:
            log_p_z = self.log_p_z(z_q, head=head, use_mixw_cor=use_mixw_cor)
        log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
        KL = -(log_p_z - log_q_z)

        loss = -RE + beta * KL

        if average:
            loss = torch.mean(loss)
            RE = torch.mean(RE)
            KL = torch.mean(KL)

        return loss, RE, KL, x_mean

    def calculate_likelihood(self,
                             X,
                             dir,
                             mode='test',
                             S=5000,
                             MB=100,
                             use_mixw_cor=False):
        # set auxiliary variables for number of training and test sets
        N_test = X.size(0)

        # init list
        likelihood_test = []

        if S <= MB:
            R = 1
        else:
            R = S // MB
            S = MB

        for j in range(N_test):
            if j % 100 == 0:
                print('{:.2f}%'.format(j / (1. * N_test) * 100))
            # Take x*
            x_single = X[j].unsqueeze(0)

            a = []
            for r in range(0, int(R)):
                # repeat x* S times to draw S importance samples in one batch
                if self.args.dataset_name == 'celeba':
                    x = x_single.expand(S, x_single.size(1), x_single.size(2),
                                        x_single.size(3))
                else:
                    x = x_single.expand(S, x_single.size(1))

                a_tmp, _, _, _ = self.calculate_loss(x,
                                                     use_mixw_cor=use_mixw_cor)

                a.append(-a_tmp.cpu().data.numpy())

            # calculate max
            a = np.asarray(a)
            a = np.reshape(a, (a.shape[0] * a.shape[1], 1))
            likelihood_x = logsumexp(a)
            likelihood_test.append(likelihood_x - np.log(len(a)))

        likelihood_test = np.array(likelihood_test)

        #plot_histogram(-likelihood_test, dir, mode)

        return -np.mean(likelihood_test)
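
    # Hedged sketch (illustration only): the estimator above is the standard
    # importance-sampling estimate log p(x) ~= logsumexp_s(log w_s) - log S, where
    # log w_s = -loss_s is the per-sample log importance weight. A self-contained
    # numpy version of the final reduction (hypothetical helper name):
    @staticmethod
    def _iw_loglik_sketch(neg_losses):
        # neg_losses: (S,) log importance weights collected for a single input x
        a = np.asarray(neg_losses, dtype=np.float64)
        a_max = a.max()
        return a_max + np.log(np.exp(a - a_max).sum()) - np.log(len(a))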

    def calculate_lower_bound(self, X_full, MB=100, use_mixw_cor=False):
        # CALCULATE LOWER BOUND:
        lower_bound = 0.
        RE_all = 0.
        KL_all = 0.

        # drop the remainder when the last batch would have size 1
        remainder = X_full.size(0) % MB
        if remainder == 1:
            X_full = X_full[:(X_full.size(0) - remainder)]

        I = int(math.ceil(X_full.size(0) / MB))

        for i in range(I):
            if not self.args.dataset_name == 'celeba':
                x = X_full[i * MB:(i + 1) * MB].view(
                    -1, np.prod(self.args.input_size))
            else:
                x = X_full[i * MB:(i + 1) * MB]

            loss, RE, KL, _ = self.calculate_loss(x,
                                                  average=True,
                                                  use_mixw_cor=use_mixw_cor)

            RE_all += RE.cpu().item()
            KL_all += KL.cpu().item()
            lower_bound += loss.cpu().item()

        lower_bound /= I

        return lower_bound

    # ADDITIONAL METHODS
    def generate_x(self, N=25, head=0, replay=False):
        if self.args.prior == 'standard':
            z_sample_rand = Variable(
                torch.FloatTensor(N, self.args.z1_size).normal_())
            if self.args.cuda:
                z_sample_rand = z_sample_rand.cuda()

            samples_rand, _ = self.p_x(z_sample_rand, head=head)
        elif self.args.prior == 'vampprior':
            # all pseudo-inputs are decoded below; a random subset of N samples is kept

            means = self.means[head](self.idle_input)
            if self.args.dataset_name == 'celeba':
                means = means.reshape(means.size(0), 3, 64, 64)
            z_sample_gen_mean, z_sample_gen_logvar = self.q_z(means, head=head)
            z_sample_rand = self.reparameterize(z_sample_gen_mean,
                                                z_sample_gen_logvar)

            samples_rand, _ = self.p_x(z_sample_rand, head=head)

            # do a random permutation to see a more representative sample set

            randperm = torch.randperm(samples_rand.size(0))
            samples_rand = samples_rand[randperm][:N]

        elif self.args.prior == 'vampprior_short':
            if self.args.use_vampmixingw:
                pis = self.mixingw(self.idle_input).squeeze()
            else:
                pis = torch.ones(
                    self.args.number_components) / self.args.number_components

            if self.args.use_mixingw_correction and replay:
                if self.args.cuda:
                    pis = torch.from_numpy(self.mixingw_c).cuda() * pis
                else:
                    pis = torch.from_numpy(self.mixingw_c) * pis

            if self.args.number_components == 1:
                clsts = np.random.choice(range(self.args.number_components),
                                         N,
                                         p=[1])
            else:
                clsts = np.random.choice(range(self.args.number_components),
                                         N,
                                         p=pis.data.cpu().numpy())

            if self.args.separate_means:
                K = self.means[head].linear.weight.size(1)
                eye = torch.eye(K, K)
                if self.args.cuda:
                    eye = eye.cuda()

                means = self.means[0](eye)[clsts, :]
            else:
                K = self.means.linear.weight.size(1)
                eye = torch.eye(K, K)
                if self.args.cuda:
                    eye = eye.cuda()

                means = self.means(eye)[clsts, :]

            # in the separated-means case generation always uses the first head, so merge the means (merge_latent) before generating
            z_sample_gen_mean, z_sample_gen_logvar = self.q_z_mean(
                means), self.q_z_logvar(means)
            z_sample_rand = self.reparameterize(z_sample_gen_mean,
                                                z_sample_gen_logvar)

            samples_rand, _ = self.p_x(z_sample_rand, head=head)

            # do a random permutation to see a more representative sample set
            #randperm = torch.randperm(samples_rand.size(0))
            #samples_rand = samples_rand[randperm][:N]
        elif self.args.prior == 'GMM':
            if self.GMM.covariance_type == 'diag':
                clsts = np.random.choice(range(self.Kmog),
                                         N,
                                         p=self.pis.data.cpu().numpy())
                mus = self.mus[clsts, :]
                randn = torch.randn(mus.size())
                if next(self.parameters()).is_cuda:
                    randn = randn.cuda()

                zs = mus + (self.sigs[clsts, :].sqrt()) * randn
            elif self.GMM.covariance_type == 'full':
                # factor each covariance as cov = U S U^T and keep A = U sqrt(S),
                # so that A @ A.T reproduces cov
                Us = []
                for cov in self.sigs:
                    U, S, _ = torch.svd(cov)
                    Us.append(U.mm(S.sqrt().diag()).unsqueeze(0))
                Us = torch.cat(Us, dim=0)

                noise = torch.randn(N, self.mus.size(1), 1).cuda()

                clsts = (torch.from_numpy(
                    np.random.choice(range(self.mus.size(0)),
                                     size=N,
                                     p=self.pis.detach().cpu().numpy())).type(
                                         torch.LongTensor)).cuda()
                Us_zs = torch.index_select(Us, dim=0, index=clsts)
                means_zs = torch.index_select(self.mus, dim=0, index=clsts)

                zs = torch.matmul(Us_zs, noise).squeeze() + means_zs
            samples_rand, _ = self.p_x(zs)

        return samples_rand
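
    # Hedged sketch (illustration only): the 'full' covariance branch above draws
    # z = mu_k + U_k sqrt(S_k) eps, using the SVD cov_k = U_k S_k U_k^T of a
    # symmetric PSD covariance, so that cov(z) = cov_k. A standalone version for a
    # single component (hypothetical helper name):
    @staticmethod
    def _sample_full_gaussian_sketch(mu, cov, n):
        # mu: (D,) mean, cov: (D, D) symmetric PSD covariance, n: number of samples
        U, S, _ = torch.svd(cov)
        A = U.mm(S.sqrt().diag())  # A @ A.T reproduces cov
        eps = torch.randn(n, mu.size(0), device=mu.device)
        return mu.unsqueeze(0) + eps.mm(A.t())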

    def reconstruct_x(self, x):
        x_mean, _, _, _, _ = self.forward(x)
        return x_mean

    # THE MODEL: VARIATIONAL POSTERIOR
    def q_z(self, x, head=0):
        if self.arc == 'conv':
            if self.args.dataset_name in [
                    'omniglot_char', 'dynamic_mnist', 'fashion_mnist',
                    'mnist_plus_fmnist'
            ]:
                x = x.reshape(-1, 1, 28, 28)
            else:
                raise NameError('unknown dataset {}: the conv encoder expects 1x28x28 inputs'.format(self.args.dataset_name))

        x = self.q_z_layers[head](x)

        x = x.squeeze()

        z_q_mean = self.q_z_mean(x)
        z_q_logvar = self.q_z_logvar(x)
        return z_q_mean, z_q_logvar

    # THE MODEL: GENERATIVE DISTRIBUTION
    def p_x(self, z, head=0):

        if self.arc == 'conv':
            z = z.unsqueeze(-1).unsqueeze(-1)
            x_mean = self.p_x_layers(z)
        else:
            z = self.p_x_layers(z)
            x_mean = self.p_x_mean[head](z)

        x_logvar = 0.

        x_mean = x_mean.reshape(x_mean.size(0), -1)
        if self.args.dataset_name == 'celeba' and torch.is_tensor(x_logvar):
            x_logvar = x_logvar.reshape(x_logvar.size(0), -1)
        return x_mean, x_logvar

    # the prior
    def log_p_z(self, z, mu=None, logvar=None, head=0, use_mixw_cor=False):
        if self.args.prior == 'standard':
            log_prior = log_Normal_standard(z, dim=1)

        elif self.args.prior == 'vampprior':
            # z - MB x M
            C = self.args.number_components

            # calculate params
            X = self.means[head](self.idle_input)
            if self.args.dataset_name == 'celeba':
                X = X.reshape(X.size(0), 3, 64, 64)

            # calculate params for given data
            z_p_mean, z_p_logvar = self.q_z(X, head=head)  # C x M

            # expand z
            z_expand = z.unsqueeze(1)
            means = z_p_mean.unsqueeze(0)
            logvars = z_p_logvar.unsqueeze(0)

            a = log_Normal_diag(z_expand, means, logvars, dim=2) - math.log(
                C)  # MB x C
            a_max, _ = torch.max(a, 1)  # MB x 1

            # calculate log-sum-exp
            log_prior = a_max + torch.log(
                torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB x 1
        elif self.args.prior == 'vampprior_short':
            C = self.args.number_components

            # calculate params
            if self.args.separate_means:
                K = self.means[head].linear.weight.size(1)
                eye = torch.eye(K, K)
                if self.args.cuda:
                    eye = eye.cuda()

                X = self.means[head](eye)
            else:
                K = self.means.linear.weight.size(1)
                eye = torch.eye(K, K)
                if self.args.cuda:
                    eye = eye.cuda()

                X = self.means(eye)

            z_p_mean = self.q_z_mean(X)
            z_p_logvar = self.q_z_logvar(X)

            # expand z
            z_expand = z.unsqueeze(1)
            means = z_p_mean.unsqueeze(0)
            logvars = z_p_logvar.unsqueeze(0)

            # mixing weights are not yet implemented for the separated-means case
            if self.args.use_vampmixingw:
                pis = self.mixingw(eye).t()
                if use_mixw_cor:
                    if self.args.cuda:
                        pis = pis * torch.from_numpy(
                            self.mixingw_c).cuda().unsqueeze(0)
                    else:
                        pis = pis * torch.from_numpy(
                            self.mixingw_c).unsqueeze(0)
                eps = 1e-30
                a = log_Normal_diag(z_expand, means, logvars,
                                    dim=2) + torch.log(pis + eps)  # MB x C
            else:
                a = log_Normal_diag(z_expand, means, logvars,
                                    dim=2) - math.log(C)  # MB x C

            a_max, _ = torch.max(a, 1)  # MB x 1
            # calculate log-sum-exp
            log_prior = a_max + torch.log(
                torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB x 1

        elif self.args.prior == 'GMM':
            if self.GMM.covariance_type == 'full':
                #z_np = z.data.cpu().numpy()
                #lls = self.GMM.score_samples(z_np)
                #log_prior = torch.from_numpy(lls).float().cuda() + math.log(2*math.pi)*(0.5*z.size(1))
                log_prior = log_mog_full(z, self.mus, self.sigs, self.pis,
                                         self.icovs_ten, self.det_terms)

            else:
                log_prior = log_mog_diag(z, self.mus, self.sigs, self.pis)
        else:
            raise Exception('invalid prior!')

        return log_prior
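
    # Hedged sketch (illustration only): every mixture branch above evaluates the
    # prior with the numerically stable log-sum-exp identity
    # log sum_c exp(a_c) = a_max + log sum_c exp(a_c - a_max).
    # A minimal torch version of that reduction (hypothetical helper name):
    @staticmethod
    def _logsumexp_sketch(a):
        # a: MB x C per-component log densities, already including the log mixing weights
        a_max, _ = torch.max(a, 1)  # MB
        return a_max + torch.log(torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB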

    # THE MODEL: FORWARD PASS
    def forward(self, x, head=0):
        # z ~ q(z | x)
        z_q_mean, z_q_logvar = self.q_z(x, head=head)
        z_q = self.reparameterize(z_q_mean, z_q_logvar)

        # x_mean = p(x|z)
        x_mean, x_logvar = self.p_x(z_q, head=head)

        return x_mean, x_logvar, z_q, z_q_mean, z_q_logvar
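
    # Hedged sketch (illustration only): `reparameterize` is inherited from the
    # Model base class and is not defined in this file; it is assumed to implement
    # the standard reparameterization trick z = mu + exp(0.5 * logvar) * eps with
    # eps ~ N(0, I), as sketched below (hypothetical helper name):
    @staticmethod
    def _reparameterize_sketch(mu, logvar):
        eps = torch.randn_like(mu)
        return mu + torch.exp(0.5 * logvar) * eps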

    def get_embeddings(self, train_loader, cuda=True, flatten=True):
        # get hhats for all batches
        if self.args.dataset_name == 'celeba':
            nbatches = 300
        else:
            nbatches = 1000

        all_hhats = []
        for i, (data, _) in enumerate(it.islice(train_loader, 0, nbatches, 1)):
            if cuda:
                data = data.cuda()

            if self.args.dataset_name != 'celeba':
                data = data.view(-1, np.prod(self.args.input_size))

            mu, logvar = self.q_z(data)
            hhat = torch.randn_like(mu) * (0.5 * logvar).exp() + mu

            all_hhats.append(hhat.data.squeeze())
            print('processing batch {}'.format(i))

        return torch.cat(all_hhats, dim=0)

    def fit_GMM(self, train_loader, Kmog, cov_type='diag', model_name=None):
        self.args.prior = 'GMM'
        self.Kmog = Kmog
        hhat = self.get_embeddings(train_loader)

        path = 'gmm_params/'
        os.makedirs(path + self.args.dataset_name, exist_ok=True)

        gmm_path = path + model_name + 'gmm_' + cov_type + '.pk'
        if os.path.exists(gmm_path):
            self.GMM = pickle.load(open(gmm_path, 'rb'))
        else:
            self.GMM = mix.GaussianMixture(n_components=Kmog,
                                           verbose=1,
                                           n_init=10,
                                           max_iter=200,
                                           covariance_type=cov_type,
                                           warm_start=True)
            self.GMM.fit(hhat.cpu().numpy())
            pickle.dump(self.GMM, open(gmm_path, 'wb'))

        # then initialize the GMM, and change self.args.prior
        self.mus = nn.Parameter(torch.from_numpy(self.GMM.means_).float())
        self.pis = nn.Parameter(torch.from_numpy(self.GMM.weights_).float())
        if cov_type == 'diag':
            self.sigs = nn.Parameter(
                torch.from_numpy(self.GMM.covariances_).float())
        else:
            self.sigs = torch.from_numpy(self.GMM.covariances_).float().cuda()
            self.icovs_ten = (torch.zeros(self.pis.size(0), self.mus.size(1),
                                          self.mus.size(1))).cuda()
            cov_eps = 1e-10
            for k in range(self.pis.size(0)):
                self.icovs_ten[k, :, :] = torch.inverse(
                    self.sigs[k, :, :] +
                    torch.eye(self.mus.size(1)).cuda() * cov_eps)

            self.det_terms = torch.Tensor([
                0.5 * (cov_eps + torch.svd(cov.squeeze())[1]).log().sum()
                for cov in self.sigs
            ]).cuda()

        self.Kmog = self.pis.size(0)
        self.args.prior = 'GMM'
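

# Hedged sketch (illustration only): `log_mog_diag` used in log_p_z is imported
# from elsewhere and not defined in this file. A self-contained version of the
# diagonal-covariance mixture log density it is assumed to compute, using the
# same log-sum-exp pattern as log_p_z (hypothetical function name):
def log_mog_diag_sketch(z, mus, sigs, pis):
    # z: MB x D, mus / sigs: K x D (sigs are variances), pis: (K,) mixing weights
    z = z.unsqueeze(1)  # MB x 1 x D
    log_n = -0.5 * (math.log(2 * math.pi) + sigs.log() +
                    (z - mus) ** 2 / sigs).sum(dim=2)  # MB x K
    a = log_n + pis.log()  # MB x K
    a_max, _ = a.max(dim=1)
    return a_max + (a - a_max.unsqueeze(1)).exp().sum(dim=1).log()  # MB
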
Example #2
class SSVAE(Model):
    def __init__(self, args):
        super(SSVAE, self).__init__(args)

        assert self.args.prior != 'GMM'
        assert self.args.prior != 'vampprior'
        assert not self.args.separate_means

        # encoder: q(z | x)
        self.q_z_layers = nn.Sequential(
            GatedDense(np.prod(self.args.input_size), 300),
            GatedDense(300, 300))

        self.q_z_mean = Linear(300, self.args.z1_size)
        self.q_z_logvar = NonLinear(300,
                                    self.args.z1_size,
                                    activation=nn.Hardtanh(min_val=-6.,
                                                           max_val=2.))

        # decoder: p(x | z)
        self.p_x_layers = nn.Sequential(GatedDense(self.args.z1_size, 300),
                                        GatedDense(300, 300))

        #if self.args.input_type == 'binary':
        self.p_x_mean = NonLinear(300,
                                  np.prod(self.args.input_size),
                                  activation=nn.Sigmoid())
        #elif self.args.input_type in ['gray', 'continuous', 'color']:
        #    self.p_x_mean = NonLinear(300, np.prod(self.args.input_size), activation=nn.Sigmoid())
        #    self.p_x_logvar = NonLinear(300, np.prod(self.args.input_size), activation=nn.Hardtanh(min_val=-4.5,max_val=0))
        self.mixingw_c = np.ones(self.args.number_components)

        #self.semi_supervisor = nn.Sequential(Linear(args.z1_size, args.num_classes),
        #                                         nn.Softmax())
        self.semi_supervisor = nn.Sequential(
            GatedDense(args.z1_size, args.z1_size), nn.Dropout(0.5),
            GatedDense(args.z1_size, args.num_classes), nn.Softmax(dim=1))

        # weights initialization
        for m in self.modules():
            if isinstance(m, nn.Linear):
                he_init(m)

        # add pseudo-inputs if VampPrior
        if self.args.prior in ['vampprior', 'vampprior_short']:
            self.add_pseudoinputs()

    def restart_latent_space(self):
        nonlinearity = nn.Hardtanh(min_val=0.0, max_val=1.0)
        self.means = NonLinear(self.args.number_components_init,
                               300,
                               bias=False,
                               activation=nonlinearity)

        if self.args.use_vampmixingw:
            self.mixingw = NonLinear(self.args.number_components_init,
                                     1,
                                     bias=False,
                                     activation=nn.Softmax(dim=0))

    def merge_latent(self):
        # always to be called after separate_latent() or add_latent_cap()
        nonlinearity = None  #nn.Hardtanh(min_val=0.0, max_val=1.0)

        prev_weights = self.means[0].linear.weight.data
        last_prev_weights = self.means[1].linear.weight.data
        all_old = torch.cat([prev_weights, last_prev_weights], dim=1)

        self.means[0] = NonLinear(all_old.size(1),
                                  300,
                                  bias=False,
                                  activation=nonlinearity).cuda()
        self.means[0].linear.weight.data = all_old

    def separate_latent(self):
        # always to be called after merge_latent()
        nonlinearity = None  #nn.Hardtanh(min_val=0.0, max_val=1.0)

        number_components_init = self.args.number_components_init
        number_components_prev = self.args.number_components - number_components_init

        prev_components = copy.deepcopy(
            self.means[0].linear.weight.data[:, :number_components_prev])

        last_prev_components = copy.deepcopy(
            self.means[0].linear.weight.data[:, number_components_prev:])

        self.means[0] = NonLinear(number_components_prev,
                                  300,
                                  bias=False,
                                  activation=nonlinearity).cuda()
        self.means[1] = NonLinear(number_components_init,
                                  300,
                                  bias=False,
                                  activation=nonlinearity).cuda()
        self.means[0].linear.weight.data = prev_components
        self.means[1].linear.weight.data = last_prev_components

    def add_latent_cap(self, dg):
        if self.args.prior == 'vampprior_short':
            nonlinearity = None  #nn.Hardtanh(min_val=0.0, max_val=1.0)

            number_components_prev = self.args.number_components
            self.args.number_components = (dg + 2) * self.args.number_components_init

            if self.args.separate_means:
                add_number_components = self.args.number_components - number_components_prev
                # set the idle inputs
                #self.idle_input = torch.eye(self.args.number_components,
                #                            self.args.number_components).cuda()
                #self.idle_input1 = torch.eye(add_number_components,
                #                             add_number_components).cuda()
                #self.idle_input2 = torch.eye(number_components_prev,
                #                             number_components_prev).cuda()

                us_new = NonLinear(add_number_components,
                                   300,
                                   bias=False,
                                   activation=nonlinearity).cuda()
                #us_new.linear.weight.data = 0*torch.randn(300, add_number_components).cuda()
                if dg == 0:
                    self.means.append(us_new)
                else:
                    # do not call self.merge_latent() here; merging is handled separately for the validation evaluation (in main_mnist.py)
                    self.means[1] = us_new

            else:
                self.idle_input = torch.eye(
                    self.args.number_components,
                    self.args.number_components).cuda()

                us_new = NonLinear(self.args.number_components,
                                   300,
                                   bias=False,
                                   activation=nonlinearity).cuda()
                if not self.args.restart_means:
                    oldweights = self.means.linear.weight.data
                    us_new.linear.weight.data[:, :number_components_prev] = oldweights
                self.means = us_new

                if self.args.use_vampmixingw:
                    self.mixingw = NonLinear(self.args.number_components,
                                             1,
                                             bias=False,
                                             activation=nn.Softmax(dim=0))

    def reconstruct_means(self, head=None):
        K = self.means.linear.weight.size(1)
        eye = torch.eye(K, K)
        if self.args.cuda: eye = eye.cuda()

        X = self.means(eye)
        z_mean = self.q_z_mean(X)

        recons, _ = self.p_x(z_mean)

        return recons

    def balance_mixingw(self, dg, perm, dont_balance=False, vis=None):

        # get means
        K = self.means.linear.weight.size(1)
        eye = torch.eye(K, K)
        if self.args.cuda:
            eye = eye.cuda()
        X = self.means(eye)
        z_mean = self.q_z_mean(X)
        yhat_means = self.semi_supervisor(z_mean)

        pis = self.mixingw(self.idle_input).squeeze()

        # to numpy:
        if self.args.cuda:
            y_hat_means = yhat_means.detach().cpu().numpy()
            pis = pis.detach().cpu().numpy()
        else:
            y_hat_means = yhat_means.detach().numpy()
            pis = pis.detach().numpy()
        perm = perm.numpy()
        curr_per_class_weight = np.matmul(pis, y_hat_means)

        print('\ncurrent per class cluster assignment:')
        print(np.round(curr_per_class_weight, 2))
        print('\n')

        if dont_balance:
            return yhat_means, curr_per_class_weight

        mixingw_c = np.zeros(self.args.number_components)
        per_class_scaling = np.zeros(self.args.num_classes)
        for d in range(dg + 1):
            idx = int(perm[d])
            per_class_scaling[idx] = 1 / curr_per_class_weight[idx] / (dg + 1)

        self.mixingw_c = np.matmul(y_hat_means, per_class_scaling)

        print('\nrebalanced per class cluster assignment:')
        post_per_class_weight = np.matmul(self.mixingw_c * pis, y_hat_means)
        print(np.round(post_per_class_weight, 2))
        print('\n')

        return yhat_means, curr_per_class_weight, post_per_class_weight

    def calculate_loss(self, x, y=None, beta=1., average=False, head=None):
        '''
        :param x: input image(s)
        :param beta: a hyperparam for warmup
        :param average: whether to average loss or not
        :return: value of a loss function
        '''

        x_mean, x_logvar, z_q, z_q_mean, z_q_logvar, y_hat = self.forward(x)

        # RE
        if self.args.input_type == 'binary':
            RE = log_Bernoulli(x, x_mean, dim=1)
        elif self.args.input_type in ['gray', 'color']:
            #RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
            RE = -(x - x_mean).abs().sum(dim=1)

        #elif self.args.input_type == 'color':
        #    RE = -log_Normal_diag(x, x_mean, x_logvar, dim=1)
        else:
            raise Exception('Wrong input type!')

        # KL
        log_p_z = self.log_p_z(z_q)
        log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
        KL = -(log_p_z - log_q_z)

        # loss
        loss = -RE + beta * KL  #+ self.args.Lambda * CE

        if average:
            loss = torch.mean(loss)
            RE = torch.mean(RE)
            KL = torch.mean(KL)

        if y is None:
            return loss, RE, KL, x_mean

        # CE
        if len(y.shape) == 1:
            CE = F.nll_loss(torch.log(y_hat), y)
        else:
            CE = -(y * torch.log(y_hat)).mean()

        # loss
        loss += self.args.Lambda * CE

        if average:
            CE = torch.mean(CE)

        return loss, RE, KL, CE, x_mean
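
    # Hedged sketch (illustration only): the classification term above uses
    # F.nll_loss on log-probabilities for hard integer labels and a soft
    # cross-entropy -(y * log y_hat) when y is itself a distribution (e.g. the
    # soft labels produced during replay). A standalone version with a small
    # epsilon for numerical safety (hypothetical helper name):
    @staticmethod
    def _classification_loss_sketch(y_hat, y, eps=1e-30):
        # y_hat: MB x num_classes probabilities; y: (MB,) int labels or MB x num_classes soft labels
        if len(y.shape) == 1:
            return F.nll_loss(torch.log(y_hat + eps), y)
        return -(y * torch.log(y_hat + eps)).mean()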

    def calculate_lower_bound(self, X_full, MB=100):
        # CALCULATE LOWER BOUND:
        lower_bound = 0.
        RE_all = 0.
        KL_all = 0.

        # drop the remainder when the last batch would have size 1
        remainder = X_full.size(0) % MB
        if remainder == 1:
            X_full = X_full[:(X_full.size(0) - remainder)]

        I = int(math.ceil(X_full.size(0) / MB))

        for i in range(I):
            if not self.args.dataset_name == 'celeba':
                x = X_full[i * MB:(i + 1) * MB].view(
                    -1, np.prod(self.args.input_size))
            else:
                x = X_full[i * MB:(i + 1) * MB]

            loss, RE, KL, _ = self.calculate_loss(x, average=True)

            #RE_all += RE.item()
            #KL_all += KL.item()
            lower_bound += loss.cpu().item()

        lower_bound /= I

        return lower_bound

    def calculate_likelihood(self, X, dir, mode='test', S=5000, MB=100):

        # set auxiliary variables for number of training and test sets
        N_test = X.size(0)

        # init list
        likelihood_test = []

        if S <= MB:
            R = 1
        else:
            R = S // MB
            S = MB

        for j in range(N_test):
            if j % 100 == 0:
                print('{:.2f}%'.format(j / (1. * N_test) * 100))
            # Take x*
            x_single = X[j].unsqueeze(0)

            a = []
            for r in range(0, int(R)):
                # repeat x* S times to draw S importance samples in one batch
                if self.args.dataset_name == 'celeba':
                    x = x_single.expand(S, x_single.size(1), x_single.size(2),
                                        x_single.size(3))
                else:
                    x = x_single.expand(S, x_single.size(1))

                a_tmp, _, _, _ = self.calculate_loss(x)

                a.append(-a_tmp.cpu().data.numpy())

            # calculate max
            a = np.asarray(a)
            a = np.reshape(a, (a.shape[0] * a.shape[1], 1))
            likelihood_x = logsumexp(a)
            likelihood_test.append(likelihood_x - np.log(len(a)))

        likelihood_test = np.array(likelihood_test)

        #plot_histogram(-likelihood_test, dir, mode)

        return -np.mean(likelihood_test)

    def calculate_accuracy(self, X_full, y_full, MB=100):
        # CALCULATE ACCURACY:
        acc = 0.

        # drop the remainder when the last batch would have size 1
        remainder = X_full.size(0) % MB
        if remainder == 1:
            X_full = X_full[:(X_full.size(0) - remainder)]
            y_full = y_full[:(y_full.size(0) - remainder)]

        I = int(math.ceil(X_full.size(0) / MB))

        for i in range(I):
            if not self.args.dataset_name == 'celeba':
                x = X_full[i * MB:(i + 1) * MB].view(
                    -1, np.prod(self.args.input_size))
                y = y_full[i * MB:(i + 1) * MB]
            else:
                x = X_full[i * MB:(i + 1) * MB]
                y = y_full[i * MB:(i + 1) * MB]

            _, _, _, _, _, y_hat = self.forward(x)
            _, predicted = torch.max(y_hat.data, 1)

            correct = (predicted == y).sum().item()

            acc += correct

        # note: acc is the mean number of correct predictions per minibatch of
        # size MB, not a fraction in [0, 1]
        acc /= I

        return acc
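
    # Hedged sketch (illustration only): calculate_accuracy above returns the mean
    # number of correct predictions per minibatch of size MB; to report a fraction
    # in [0, 1] instead, divide the correct count by the number of evaluated
    # examples, as in this standalone helper (hypothetical name):
    @staticmethod
    def _accuracy_fraction_sketch(y_hat, y):
        # y_hat: N x num_classes scores; y: (N,) integer labels
        predicted = torch.argmax(y_hat, dim=1)
        return (predicted == y).float().mean().item()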

    # ADDITIONAL METHODS
    def generate_x(self, N=25, replay=False):

        if self.args.prior == 'standard':
            z_sample_rand = Variable(
                torch.FloatTensor(N, self.args.z1_size).normal_())
            if self.args.cuda:
                z_sample_rand = z_sample_rand.cuda()

            samples_rand, _ = self.p_x(z_sample_rand)
        elif self.args.prior == 'vampprior':
            clsts = np.random.choice(range(self.Kmog),
                                     N,
                                     p=self.pis.data.cpu().numpy())

            means = self.means(self.idle_input)
            if self.args.dataset_name == 'celeba':
                means = means.reshape(means.size(0), 3, 64, 64)
            z_sample_gen_mean, z_sample_gen_logvar = self.q_z(means)
            z_sample_rand = self.reparameterize(z_sample_gen_mean,
                                                z_sample_gen_logvar)

            samples_rand, _ = self.p_x(z_sample_rand)

            # do a random permutation to see a more representative sample set

            randperm = torch.randperm(samples_rand.size(0))
            samples_rand = samples_rand[randperm][:N]

        elif self.args.prior == 'vampprior_short':
            if self.args.use_vampmixingw:
                pis = self.mixingw(self.idle_input).squeeze()
            else:
                pis = torch.ones(
                    self.args.number_components) / self.args.number_components

            if self.args.use_mixingw_correction and replay:
                mixingw_c = torch.from_numpy(self.mixingw_c)
                if self.args.cuda:
                    mixingw_c = mixingw_c.type(torch.cuda.FloatTensor)
                else:
                    mixingw_c = mixingw_c.type(torch.FloatTensor)
                pis = mixingw_c * pis

            clsts = np.random.choice(range(self.args.number_components),
                                     N,
                                     p=pis.data.cpu().numpy())

            K = self.means.linear.weight.size(1)
            eye = torch.eye(K, K)
            if self.args.cuda: eye = eye.cuda()

            means = self.means(eye)[clsts, :]

            # SSVAE does not support separated means (see the asserts in __init__); a single set of pseudo-inputs is used
            z_sample_gen_mean, z_sample_gen_logvar = self.q_z_mean(
                means), self.q_z_logvar(means)
            z_sample_rand = self.reparameterize(z_sample_gen_mean,
                                                z_sample_gen_logvar)

            samples_rand, _ = self.p_x(z_sample_rand)

            # do a random permutation to see a more representative sample set
            #randperm = torch.randperm(samples_rand.size(0))
            #samples_rand = samples_rand[randperm][:N]

        # generate soft labels:
        y_rand = self.semi_supervisor(z_sample_rand)

        return samples_rand, y_rand
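
    # Hedged usage sketch (illustration only): for generative replay, the samples
    # and soft labels returned above can be fed straight back into calculate_loss,
    # which then takes the soft cross-entropy branch. Assuming a trained model
    # `ssvae`:
    #
    #   x_replay, y_replay = ssvae.generate_x(N=100, replay=True)
    #   loss, RE, KL, CE, _ = ssvae.calculate_loss(x_replay, y=y_replay, average=True)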

    def reconstruct_x(self, x):
        x_mean, _, _, _, _, _ = self.forward(x)
        return x_mean

    # THE MODEL: VARIATIONAL POSTERIOR
    def q_z(self, x):
        x = self.q_z_layers(x)

        x = x.squeeze()

        z_q_mean = self.q_z_mean(x)
        z_q_logvar = self.q_z_logvar(x)
        return z_q_mean, z_q_logvar

    # THE MODEL: GENERATIVE DISTRIBUTION
    def p_x(self, z):

        if self.args.dataset_name == 'celeba':
            z = z.unsqueeze(-1).unsqueeze(-1)
        z = self.p_x_layers(z)

        x_mean = self.p_x_mean(z)

        #if self.args.input_type == 'binary':
        x_logvar = 0.
        #else:
        #x_mean = torch.clamp(x_mean, min=0.+1./512., max=1.-1./512.)

        #x_logvar = self.p_x_logvar[head](z)

        x_mean = x_mean.reshape(x_mean.size(0), -1)
        if self.args.dataset_name == 'celeba' and torch.is_tensor(x_logvar):
            x_logvar = x_logvar.reshape(x_logvar.size(0), -1)
        return x_mean, x_logvar

    # the prior
    def log_p_z(self, z, mu=None, logvar=None):
        if self.args.prior == 'standard':
            log_prior = log_Normal_standard(z, dim=1)

        elif self.args.prior == 'vampprior':
            # z - MB x M
            C = self.args.number_components

            # calculate params
            X = self.means(self.idle_input)
            if self.args.dataset_name == 'celeba':
                X = X.reshape(X.size(0), 3, 64, 64)

            # calculate params for given data
            z_p_mean, z_p_logvar = self.q_z(X)  # C x M

            # expand z
            z_expand = z.unsqueeze(1)
            means = z_p_mean.unsqueeze(0)
            logvars = z_p_logvar.unsqueeze(0)

            a = log_Normal_diag(z_expand, means, logvars, dim=2) - math.log(
                C)  # MB x C
            a_max, _ = torch.max(a, 1)  # MB x 1

            # calculate log-sum-exp
            log_prior = a_max + torch.log(
                torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB x 1
        elif self.args.prior == 'vampprior_short':
            C = self.args.number_components

            K = self.means.linear.weight.size(1)
            eye = torch.eye(K, K)
            if self.args.cuda:
                eye = eye.cuda()

            X = self.means(eye)

            z_p_mean = self.q_z_mean(X)
            z_p_logvar = self.q_z_logvar(X)

            # expand z
            z_expand = z.unsqueeze(1)
            means = z_p_mean.unsqueeze(0)
            logvars = z_p_logvar.unsqueeze(0)

            # mixing weights are not yet implemented for the separated-means case
            if self.args.use_vampmixingw:
                pis = self.mixingw(eye).t()
                eps = 1e-30
                a = log_Normal_diag(z_expand, means, logvars,
                                    dim=2) + torch.log(pis + eps)  # MB x C
            else:
                a = log_Normal_diag(z_expand, means, logvars,
                                    dim=2) - math.log(C)  # MB x C

            a_max, _ = torch.max(a, 1)  # MB x 1
            # calculate log-sum-exp
            log_prior = a_max + torch.log(
                torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB x 1

        else:
            raise Exception('invalid prior!')

        return log_prior

    # THE MODEL: FORWARD PASS
    def forward(self, x):
        # z ~ q(z | x)
        z_q_mean, z_q_logvar = self.q_z(x)
        z_q = self.reparameterize(z_q_mean, z_q_logvar)

        # x_mean = p(x|z)
        x_mean, x_logvar = self.p_x(z_q)

        y_hat = self.semi_supervisor(z_q_mean)
        return x_mean, x_logvar, z_q, z_q_mean, z_q_logvar, y_hat

    def get_embeddings(self, train_loader, cuda=True, flatten=True):
        # get hhats for all batches
        if self.args.dataset_name == 'celeba':
            nbatches = 300
        else:
            nbatches = 1000

        all_hhats = []
        for i, (data, _) in enumerate(it.islice(train_loader, 0, nbatches, 1)):
            if cuda:
                data = data.cuda()

            if self.args.dataset_name != 'celeba':
                data = data.view(-1, np.prod(self.args.input_size))

            mu, logvar = self.q_z(data)
            hhat = torch.randn_like(mu) * (0.5 * logvar).exp() + mu

            all_hhats.append(hhat.data.squeeze())
            print('processing batch {}'.format(i))

        return torch.cat(all_hhats, dim=0)