Exemplo n.º 1
0
    def calculate_loss(self, x, beta=1., average=False):
        """Return the negative ELBO of a two-level (z1, z2) VAE for ``x``.

        :param x: input batch, passed unchanged to ``self.forward``
        :param beta: weight applied to the KL term (KL warm-up factor)
        :param average: if True, ``torch.mean`` is applied to loss, RE and KL
        :return: ``(loss, RE, KL)`` with ``loss = -RE + beta * KL``
        :raises Exception: if ``self.args.input_type`` is not 'binary',
            'gray' or 'continuous'
        """
        outputs = self.forward(x)
        (x_mean, x_logvar,
         z1_q, z1_q_mean, z1_q_logvar,
         z2_q, z2_q_mean, z2_q_logvar,
         z1_p_mean, z1_p_logvar) = outputs

        # Reconstruction term. The logistic branch is negated, presumably
        # because log_Logistic_256 returns a negative log-likelihood -- verify.
        kind = self.args.input_type
        if kind == 'binary':
            RE = log_Bernoulli(x, x_mean, dim=1)
        elif kind in ('gray', 'continuous'):
            RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
        else:
            raise Exception('Wrong input type!')

        # KL term estimated from the latent values returned by forward():
        # log q(z1) + log q(z2) - log p(z1) - log p(z2)
        log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
        log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
        log_p_z2 = self.log_p_z2(z2_q)
        log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
        KL = -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

        loss = -RE + beta * KL
        if not average:
            return loss, RE, KL
        return torch.mean(loss), torch.mean(RE), torch.mean(KL)
Exemplo n.º 2
0
    def calculate_loss(self, x, beta=1., average=False):
        '''
        Negative evidence lower bound for a VAE with two stochastic layers.

        :param x: input image(s)
        :param beta: a hyperparam for warmup (multiplies the KL term)
        :param average: whether to average loss, RE and KL over the batch
        :return: tuple ``(loss, RE, KL)``, ``loss = -RE + beta * KL``
        '''
        (x_mean, x_logvar,
         z1_q, z1_q_mean, z1_q_logvar,
         z2_q, z2_q_mean, z2_q_logvar,
         z1_p_mean, z1_p_logvar) = self.forward(x)

        # likelihood of the data under the decoder output
        input_type = self.args.input_type
        if input_type == 'binary':
            RE = log_Bernoulli(x, x_mean, dim=1)
        elif input_type in ('gray', 'continuous'):
            # negated: log_Logistic_256 appears to return a negative
            # log-likelihood (TODO confirm against its definition)
            RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
        else:
            raise Exception('Wrong input type!')

        # prior / posterior log-densities for both latent layers
        log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
        log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
        log_p_z2 = self.log_p_z2(z2_q)
        log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
        KL = -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

        loss = -RE + beta * KL

        if average:
            loss, RE, KL = torch.mean(loss), torch.mean(RE), torch.mean(KL)

        return loss, RE, KL
Exemplo n.º 3
0
    def calculate_loss(self, x, beta=1., average=False):
        """Return the negative ELBO of a two-level (z1, z2) VAE for ``x``.

        :param x: input batch, passed unchanged to ``self.forward``
        :param beta: weight applied to the KL term (KL warm-up factor)
        :param average: if True, ``torch.mean`` is applied to loss, RE and KL
        :return: ``(loss, RE, KL)`` with ``loss = -RE + beta * KL``
        :raises Exception: if ``self.args.input_type`` is not 'binary',
            'gray' or 'continuous'
        """
        # pass through VAE
        x_mean, x_logvar, z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = self.forward(
            x)

        # RE
        if self.args.input_type == 'binary':
            RE = log_Bernoulli(x, x_mean, dim=1)
        elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
            # (a leftover debug print of the tensor shapes used to run here on
            # every batch; it was removed to keep training output clean)
            RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
        else:
            raise Exception('Wrong input type!')

        # KL: log q(z1) + log q(z2) - log p(z1) - log p(z2)
        log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
        log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
        log_p_z2 = self.log_p_z2(z2_q)
        log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
        KL = -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

        # full loss
        loss = -RE + beta * KL

        if average:
            loss = torch.mean(loss)
            RE = torch.mean(RE)
            KL = torch.mean(KL)

        return loss, RE, KL
Exemplo n.º 4
0
    def calculate_lower_bound(self, X_full):
        """Average the negative ELBO (-RE + KL) over every row of ``X_full``.

        Rows are processed in mini-batches of 100. The final batch may be
        smaller, so every example contributes to the sum that is divided by
        ``X_full.size(0)``.

        :param X_full: data tensor, one example per row
        :return: float, summed per-batch (-RE + KL) divided by the row count
        """
        MB = 100
        N = X_full.size(0)
        lower_bound = 0.

        # Step through all rows. The old ``range(N / MB)`` raised TypeError on
        # Python 3, and a floor-count loop would skip the trailing partial
        # batch while still dividing by N below.
        for start in range(0, N, MB):
            x = X_full[start:start + MB].view(-1,
                                              np.prod(self.args.input_size))

            # pass through VAE
            x_mean, x_logvar, z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = self.forward(
                x)

            # RE
            RE = log_Bernoulli(x, x_mean)

            # KL, summed over the batch
            log_p_z2 = self.log_p_z2(z2_q)
            log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
            log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
            log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
            KL = -torch.sum(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

            lower_bound += (-RE + KL).cpu().data[0]

        return lower_bound / N
Exemplo n.º 5
0
def train_vae_vampprior_2level(epoch, args, train_loader, model, optimizer):
    """Run one training epoch of the two-level VampPrior VAE.

    The KL term is scaled by a warm-up factor ``beta = (epoch - 1) / warmup``
    capped at 1 (``beta = 1`` when ``args.warmup == 0``).

    :return: ``(model, train_loss, train_re, train_kl)``. The three metrics
        are means over batches of per-example values; the reported KL already
        includes the ``beta`` factor.
    """
    train_loss = train_re = train_kl = 0
    model.train()

    # KL warm-up factor
    if args.warmup == 0:
        beta = 1.
    else:
        beta = min(1. * (epoch - 1) / args.warmup, 1.)
    print('beta: {}'.format(beta))

    for data, target in train_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)

        # optionally resample binary pixels each pass (dynamic binarization)
        x = torch.bernoulli(data) if args.dynamic_binarization else data

        optimizer.zero_grad()
        (x_mean, x_logvar,
         z1_q, z1_q_mean, z1_q_logvar,
         z2_q, z2_q_mean, z2_q_logvar,
         z1_p_mean, z1_p_logvar) = model.forward(x)

        RE = log_Bernoulli(x, x_mean)

        log_p_z2 = model.log_p_z2(z2_q)
        log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
        log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
        log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
        KL = beta * (-torch.sum(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2))

        batch_size = data.size(0)
        loss = (-RE + KL) / batch_size
        loss.backward()
        optimizer.step()

        train_loss += loss.data[0]
        train_re += (-RE / batch_size).data[0]
        train_kl += (KL / batch_size).data[0]

    # each accumulated value is already per-example; average over batches
    n_batches = len(train_loader)
    train_loss /= n_batches
    train_re /= n_batches
    train_kl /= n_batches

    return model, train_loss, train_re, train_kl
Exemplo n.º 6
0
    def calculate_likelihood(self, X, dir, mode='test', S=5000):
        """Estimate the negative log-likelihood of ``X`` by importance sampling.

        For each example, copies of the input are pushed through the model in
        batches of at most ``MB = 100``. Every copy yields one log importance
        weight ``RE - KL``, and the per-example estimate is the log-mean-exp of
        all those weights.

        :param X: data tensor, one example per row (each row is expanded with
            ``expand(S, x_single.size(1))``, so rows must be 1-D)
        :param dir: directory forwarded to ``plot_histogram``
        :param mode: label forwarded to ``plot_histogram``
        :param S: number of importance samples per example. Values above 100
            are used in whole batches of 100, i.e. ``(S // 100) * 100`` samples.
        :return: mean over examples of the negative log-likelihood estimate
        """
        N_test = X.size(0)

        likelihood_test = []

        MB = 100

        # R batches of S copies each. Floor division keeps R an int: with
        # true division (Python 3) ``range(0, R)`` raises TypeError.
        if S <= MB:
            R = 1
        else:
            R = int(S // MB)
            S = MB

        for j in range(N_test):
            if j % 100 == 0:
                print('{:.2f}%'.format(j / (1. * N_test) * 100))
            # Take x*
            x_single = X[j].unsqueeze(0)

            a = []
            for r in range(0, R):
                # replicate the example S times
                x = x_single.expand(S, x_single.size(1))

                # pass through VAE
                x_mean, x_logvar, z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = self.forward(
                    x)

                # RE
                RE = log_Bernoulli(x, x_mean, dim=1)

                # KL
                log_p_z2 = self.log_p_z2(z2_q)
                log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
                log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
                log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
                KL = -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

                # one log importance weight per copy
                a_tmp = (RE - KL)

                a.append(a_tmp.cpu().data.numpy())

            # flatten all R * S weights into a column and log-sum-exp them
            a = np.asarray(a)
            a = np.reshape(a, (a.shape[0] * a.shape[1], 1))
            likelihood_x = logsumexp(a)
            # Normalize by the total number of weights (R * S). Subtracting
            # log(S) after S was reset to MB inflated the estimate by log(R).
            likelihood_test.append(likelihood_x - np.log(len(a)))

        likelihood_test = np.array(likelihood_test)

        plot_histogram(-likelihood_test, dir, mode)

        return -np.mean(likelihood_test)
Exemplo n.º 7
0
    def log_p_z2(self, z2):
        """Log-density of ``z2`` under a VampPrior (uniform mixture of C Gaussians).

        Each of the C = ``self.args.number_components`` components is the
        encoder ``self.q_z2`` applied to one pseudo-input produced by
        ``self.means(self.idle_input)``; mixture weights are uniform (1/C).

        NOTE(review): ``.squeeze(2)`` and ``a_max.expand(MB, C)`` assume that
        reductions keep the reduced dimension (old PyTorch ``keepdim=True``
        behaviour). With current ``keepdim=False`` semantics these lines fail;
        confirm against the PyTorch version this code targets.
        """
        # z2: MB x M (MB = z2.size(0), M = z2.size(1))
        # X: pseudo-inputs, presumably C x D -- verify
        MB = z2.size(0)
        C = self.args.number_components
        M = z2.size(1)

        # pseudo-inputs from the learnable means layer
        X = self.means(self.idle_input)

        # component means / log-variances obtained by encoding the pseudo-inputs
        z_p_mean, z_p_logvar = self.q_z2(X)  # C x M

        # explicitly tile samples and components to MB x C x M
        z_expand = z2.unsqueeze(1).expand(MB, C, M)
        means = z_p_mean.unsqueeze(0).expand(MB, C, M)
        logvars = z_p_logvar.unsqueeze(0).expand(MB, C, M)

        # log N(z | component c) + log(1/C) for every sample/component pair
        a = log_Normal_diag(z_expand, means, logvars,
                            dim=2).squeeze(2) - math.log(C)  # MB x C
        a_max, _ = torch.max(a, 1)  # MB x 1
        # calculate log-sum-exp (shifted by the row max for numerical stability)
        log_p_z = a_max + torch.log(
            torch.sum(torch.exp(a - a_max.expand(MB, C)), 1))  # MB x 1

        return log_p_z
Exemplo n.º 8
0
    def calculate_loss(self, x, beta=1., average=False):
        '''
        Negative ELBO of a single-latent VAE.

        :param x: input image(s), passed to ``self.forward``
        :param beta: a hyperparam for warmup (weight of the KL term)
        :param average: when True, loss, RE and KL are reduced with ``torch.mean``
        :return: ``(loss, RE, KL)`` where ``loss = -RE + beta * KL``
        '''
        x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x)

        input_type = self.args.input_type
        if input_type == 'binary':
            RE = log_Bernoulli(x, x_mean, dim=1)
        elif input_type in ('gray', 'continuous'):
            # negated: log_Logistic_256 presumably returns a negative
            # log-likelihood -- verify against its definition
            RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
        else:
            raise Exception('Wrong input type!')

        # single-term KL estimate: log q(z|x) - log p(z)
        log_prior = self.log_p_z(z_q)
        log_posterior = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
        KL = -(log_prior - log_posterior)

        loss = - RE + beta * KL

        if not average:
            return loss, RE, KL
        return torch.mean(loss), torch.mean(RE), torch.mean(KL)
Exemplo n.º 9
0
    def calculate_loss(self, x, beta=1., average=False):
        '''
        Negative ELBO with optional FI / MI / adversarial terms.

        :param x: input image(s); flattened to ``np.prod(self.args.input_size)``
            features per row before the forward pass
        :param beta: a hyperparam for warmup (weight on the KL term)
        :param average: whether to average loss or not (also averages RE, KL,
            FI and exp(MI))
        :return: tuple ``(loss, RE, KL, FI, MI)``. When ``average`` is True the
            last element is ``torch.mean(torch.exp(MI))``, not MI itself.
        :raises Exception: if ``self.args.input_type`` is not 'binary', 'gray'
            or 'continuous'
        '''
        # pass through VAE (inputs are flattened first)
        x = x.view(-1, np.prod(self.args.input_size))
        x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x)

        # RE
        if self.args.input_type == 'binary':
            RE = log_Bernoulli(x, x_mean, dim=1)
        elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
            RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
        else:
            raise Exception('Wrong input type!')

        # KL
        log_p_z = self.log_p_z(z_q)
        log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
        KL = -(log_p_z - log_q_z)
        # Debug output: only element [0][0] of the encoder mean / log-variance
        # is checked (via self.isnan), so NaNs elsewhere go unreported.
        if self.isnan(z_q_mean.data[0][0]):
            print("mean:")
            print(z_q_mean)
        if self.isnan(z_q_logvar.data[0][0]):
            print("var:")
            print(z_q_logvar)

        loss = -RE + beta * KL

        #FI
        # self.FI(x) is evaluated in both branches; when args.FI is not True
        # the result is zeroed in place (note: NaN * 0. is still NaN).
        if self.args.FI is True:
            FI, gamma = self.FI(x)
            #loss -= torch.mean(FI * gamma, dim = 1)
        else:
            FI, gamma = self.FI(x)
            FI *= 0.
        #FI = (torch.mean((torch.log(2*torch.pow(torch.exp( z_q_logvar ),2) + 1) - 2 * z_q_logvar)) - self.args.M ).abs()
        #FI = (torch.mean((1/torch.exp( z_q_logvar ) + 1/(2*torch.pow( torch.exp( z_q_logvar ), 2 )))) - self.args.M ).abs()

        # MI
        # As with FI, self.MI(x) always runs; when disabled it is multiplied
        # by 0., so exp(MI) is normally 1 in the adversarial term below.
        if self.args.MI is True:
            MI = self.MI(x)
            #loss += self.args.ksi * (MI -self.args.M).abs()
        else:
            MI = self.MI(x) * 0.

        # adversarial term: ksi * |exp(MI) - FI|
        if self.args.adv is True:
            loss += self.args.ksi * (torch.exp(MI) - FI).abs()

        if average:
            loss = torch.mean(loss)
            RE = torch.mean(RE)
            KL = torch.mean(KL)
            FI = torch.mean(FI)
            MI = torch.mean(torch.exp(MI))

        return loss, RE, KL, FI, MI
Exemplo n.º 10
0
    def log_p_z2(self, z2):
        """Log-density of ``z2`` under the VampPrior.

        The prior is a uniform mixture of ``self.args.number_components``
        diagonal Gaussians. Each component is obtained by encoding one
        pseudo-input from ``self.means(self.idle_input)`` with ``self.q_z2``.

        :param z2: latent samples (MB x M per the original shape notes)
        :return: log-prior, one value per row of ``z2``
        """
        n_components = self.args.number_components

        pseudo_inputs = self.means(self.idle_input)
        comp_mean, comp_logvar = self.q_z2(pseudo_inputs)  # C x M

        # broadcast samples (MB x 1 x M) against components (1 x C x M)
        log_comp = log_Normal_diag(z2.unsqueeze(1), comp_mean.unsqueeze(0),
                                   comp_logvar.unsqueeze(0), dim=2)
        log_weighted = log_comp - math.log(n_components)  # MB x C

        # numerically stable log-sum-exp over the component axis
        peak = torch.max(log_weighted, 1)[0]  # MB
        shifted = torch.exp(log_weighted - peak.unsqueeze(1))
        return (peak + torch.log(torch.sum(shifted, 1)))  # MB
Exemplo n.º 11
0
    def calculate_loss(self, x, beta=1., average=False):
        '''
        Negative ELBO for inputs modelled with a softmax likelihood.

        :param x: input image(s), passed to ``self.forward``
        :param beta: a hyperparam for warmup (weight of the KL term)
        :param average: when True, loss, RE and KL are reduced with ``torch.mean``
        :return: ``(loss, RE, KL)`` where ``loss = -RE + beta * KL``
        '''
        x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x)

        # Despite the name, RE holds the softmax log-likelihood term
        # (higher is better), hence the minus sign in the loss.
        RE = log_Softmax(x, x_mean, dim=1)

        KL = -(self.log_p_z(z_q) - log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1))

        loss = -RE + beta * KL

        if average:
            loss, RE, KL = torch.mean(loss), torch.mean(RE), torch.mean(KL)

        return loss, RE, KL
Exemplo n.º 12
0
    def calculate_loss(self, x, beta=1., average=False):
        '''
        Negative ELBO supporting binary, multinomial and gray/continuous inputs.

        :param x: input image(s), passed to ``self.forward``
        :param beta: a hyperparam for warmup (weight of the KL term)
        :param average: when True, loss, RE and KL are reduced with ``torch.mean``
        :return: ``(loss, RE, KL)`` where ``loss = -RE + beta * KL``
        :raises Exception: for any other ``self.args.input_type``
        '''
        x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x)

        # RE is a log-likelihood term (not an error), so it is subtracted below
        input_type = self.args.input_type
        if input_type == 'binary':
            RE = log_Bernoulli(x, x_mean, dim=1)
        elif input_type == 'multinomial':
            RE = log_Softmax(x, x_mean, dim=1)
        elif input_type in ('gray', 'continuous'):
            # negated: log_Logistic_256 presumably returns a negative
            # log-likelihood -- verify
            RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
        else:
            raise Exception('Wrong input type!')

        log_prior = self.log_p_z(z_q)
        log_posterior = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
        KL = -(log_prior - log_posterior)

        loss = -RE + beta * KL

        if not average:
            return loss, RE, KL
        return torch.mean(loss), torch.mean(RE), torch.mean(KL)
Exemplo n.º 13
0
    def log_p_z2(self, z2):
        """Log-prior of ``z2``: standard normal or VampPrior.

        For 'vampprior', the pseudo-inputs from ``self.means(self.idle_input)``
        are reshaped to ``(-1, *input_size[:3])``, encoded with ``self.q_z2``
        into C = ``number_components`` Gaussians, and mixed uniformly.

        :param z2: latent samples (MB x M per the original shape notes)
        :return: log-prior, one value per row of ``z2``
        :raises Exception: if ``self.args.prior`` is neither 'standard' nor
            'vampprior'
        """
        prior = self.args.prior
        if prior == 'standard':
            return log_Normal_standard(z2, dim=1)
        if prior != 'vampprior':
            raise Exception('Wrong name of the prior!')

        n_components = self.args.number_components
        size = self.args.input_size
        pseudo_inputs = self.means(self.idle_input).view(-1, size[0], size[1], size[2])

        comp_mean, comp_logvar = self.q_z2(pseudo_inputs)  # C x M

        # samples (MB x 1 x M) vs components (1 x C x M) -> MB x C
        log_comp = log_Normal_diag(z2.unsqueeze(1), comp_mean.unsqueeze(0),
                                   comp_logvar.unsqueeze(0), dim=2)
        log_weighted = log_comp - math.log(n_components)

        # stable log-sum-exp over components
        peak = torch.max(log_weighted, 1)[0]  # MB
        total = torch.sum(torch.exp(log_weighted - peak.unsqueeze(1)), 1)
        return (peak + torch.log(total))  # MB
Exemplo n.º 14
0
    def log_p_z(self, z):
        """Log-prior of ``z``: standard normal or VampPrior.

        For 'vampprior', the pseudo-inputs from ``self.means(self.idle_input)``
        are reshaped to ``(C, *input_size[:3])``, encoded with ``self.q_z``
        into C = ``number_components`` Gaussians, and mixed uniformly.

        :param z: latent samples, MB x M
        :return: log-prior, one value per row of ``z``
        :raises Exception: if ``self.args.prior`` is neither 'standard' nor
            'vampprior'
        """
        if self.args.prior == 'standard':
            log_prior = log_Normal_standard(z, dim=1)

        elif self.args.prior == 'vampprior':
            # z - MB x M
            C = self.args.number_components

            # pseudo-inputs from the learnable means layer
            X = self.means(self.idle_input)

            # Reshape the pseudo-inputs before encoding. The leading size comes
            # from args.input_size[0] (presumably the channel count) instead of
            # a hard-coded 1, which only worked for single-channel inputs.
            X = X.view(X.shape[0], self.args.input_size[0],
                       self.args.input_size[1], self.args.input_size[2])
            z_p_mean, z_p_logvar = self.q_z(X)  # C x M

            # broadcast samples (MB x 1 x M) against components (1 x C x M)
            z_expand = z.unsqueeze(1)
            means = z_p_mean.unsqueeze(0)
            logvars = z_p_logvar.unsqueeze(0)

            a = log_Normal_diag(z_expand, means, logvars, dim=2) - math.log(C)  # MB x C
            a_max, _ = torch.max(a, 1)  # MB

            # calculate log-sum-exp (shifted by the row max for stability)
            log_prior = a_max + torch.log(torch.sum(torch.exp(a - a_max.unsqueeze(1)), 1))  # MB

        else:
            raise Exception('Wrong name of the prior!')

        return log_prior
Exemplo n.º 15
0
    def log_p_z2(self, z2):
        """Log-prior of ``z2``: standard normal or VampPrior.

        For 'vampprior', each of the C = ``number_components`` mixture
        components is ``self.q_z2`` applied to one pseudo-input from
        ``self.means(self.idle_input)``; weights are uniform.

        :param z2: latent samples (MB x M per the original shape notes)
        :return: log-prior, one value per row of ``z2``
        :raises Exception: if ``self.args.prior`` is not recognised
        """
        prior = self.args.prior
        if prior == 'standard':
            return log_Normal_standard(z2, dim=1)
        if prior != 'vampprior':
            raise Exception('Wrong name of the prior!')

        n_components = self.args.number_components
        pseudo_inputs = self.means(self.idle_input)
        comp_mean, comp_logvar = self.q_z2(pseudo_inputs)  # C x M

        log_comp = log_Normal_diag(z2.unsqueeze(1), comp_mean.unsqueeze(0),
                                   comp_logvar.unsqueeze(0), dim=2)
        log_weighted = log_comp - math.log(n_components)  # MB x C

        # log-sum-exp with max subtraction for numerical stability
        peak = torch.max(log_weighted, 1)[0]  # MB
        total = torch.sum(torch.exp(log_weighted - peak.unsqueeze(1)), 1)
        return (peak + torch.log(total))  # MB
Exemplo n.º 16
0
    def calculate_lower_bound(self, X_full):
        """Average the negative ELBO (-RE + KL) over every row of ``X_full``.

        Rows are processed in mini-batches of 500 against a standard-normal
        prior. The final batch may be smaller, so every example contributes
        to the sum that is divided by ``X_full.size(0)``.

        :param X_full: data tensor, one example per row
        :return: float, summed per-batch (-RE + KL) divided by the row count
        """
        MB = 500
        N = X_full.size(0)
        lower_bound = 0.

        # range(0, N, MB) also visits the trailing partial batch; iterating
        # int(N / MB) times skipped those rows while still dividing by N,
        # which biased the average.
        for start in range(0, N, MB):
            x = X_full[start:start + MB].view(-1, np.prod(self.input_size))

            # pass through VAE
            x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x)
            x_mean = x_mean.double()
            # RE
            RE = log_Bernoulli(x, x_mean)

            # KL against a standard-normal prior, summed over the batch
            log_p_z = log_Normal_standard(z_q, dim=1)
            log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
            KL = -torch.sum(log_p_z - log_q_z)

            lower_bound += (-RE + KL).cpu().item()

        return lower_bound / N
Exemplo n.º 17
0
 def vae_re(x, x_decoded_mean):
     """Return minus the batch-averaged reconstruction log-likelihood.

     Per-dimension log-likelihoods are summed over axis 1 and averaged over
     axis 0. ``config``, ``K`` and ``x_decoded_log_var`` come from the
     enclosing scope.

     NOTE(review): the 'gray' branch adds ``0.5 * D * log(2*pi)``; whether that
     sign is right depends on which constant ``log_Normal_diag`` omits -- verify.

     :raises ValueError: for any ``config.data_type`` other than 'binary'/'gray'
     """
     data_type = config.data_type
     if data_type == 'binary':
         per_dim = log_Bernoulli(x, x_decoded_mean)
         RE = K.mean(K.sum(per_dim, axis=1), axis=0)
     elif data_type == 'gray':
         per_dim = log_Normal_diag(x, x_decoded_mean, x_decoded_log_var)
         gauss_const = 0.5*config.original_dim*np.log(2*np.pi)
         RE = K.mean(K.sum(per_dim, axis=1), axis=0) + gauss_const
     else:
         raise ValueError
     return -RE
Exemplo n.º 18
0
    def calculate_loss(self, x, y=None, beta=1., average=False, head=None):
        '''
        Negative ELBO, optionally plus a weighted classification loss.

        :param x: input image(s)
        :param y: optional targets; 1-D tensors are treated as class indices
            (``F.nll_loss``), others as per-class weights/probabilities
        :param beta: a hyperparam for warmup (weight of the KL term)
        :param average: whether to average loss, RE and KL over the batch
        :param head: unused by this implementation
        :return: ``(loss, RE, KL, x_mean)`` when ``y`` is None, otherwise
            ``(loss, RE, KL, CE, x_mean)`` with ``Lambda * CE`` added to loss
        '''
        x_mean, x_logvar, z_q, z_q_mean, z_q_logvar, y_hat = self.forward(x)

        input_type = self.args.input_type
        if input_type == 'binary':
            RE = log_Bernoulli(x, x_mean, dim=1)
        elif input_type in ['gray', 'color']:
            # negative L1 distance used in place of a likelihood
            RE = -(x - x_mean).abs().sum(dim=1)
        else:
            raise Exception('Wrong input type!')

        log_prior = self.log_p_z(z_q)
        log_posterior = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
        KL = -(log_prior - log_posterior)

        loss = -RE + beta * KL

        if average:
            loss, RE, KL = torch.mean(loss), torch.mean(RE), torch.mean(KL)

        if y is None:
            return loss, RE, KL, x_mean

        # classification term on the predicted class probabilities
        log_y_hat = torch.log(y_hat)
        if len(y.shape) == 1:
            CE = F.nll_loss(log_y_hat, y)
        else:
            CE = -(y * log_y_hat).mean()

        loss += self.args.Lambda * CE

        if average:
            CE = torch.mean(CE)

        return loss, RE, KL, CE, x_mean
Exemplo n.º 19
0
    def vae_loss(x, x_decoded_mean):
        """Keras loss: minus the reconstruction term plus the (optionally warmed-up) KL.

        Relies on names from the enclosing scope: ``config``, ``K``,
        ``x_decoded_log_var``, ``z_0``, ``z_mean``, ``z_log_var``, ``z`` and
        ``vae``. The KL uses the last flow output
        ``z[str(config.number_of_flows)]`` under a standard-normal prior.

        :raises ValueError: for an unknown ``config.data_type``
        :raises Exception: for an unknown ``config.regularization``
        """
        # reconstruction part
        data_type = config.data_type
        if data_type == 'binary':
            re_terms = log_Bernoulli(x, x_decoded_mean)
            RE = K.mean(K.sum(re_terms, axis=1), axis=0)
        elif data_type == 'gray':
            re_terms = log_Normal_diag(x, x_decoded_mean, x_decoded_log_var)
            RE = K.mean(K.sum(re_terms, axis=1), axis=0) + 0.5*config.original_dim*np.log(2*np.pi)
        else:
            raise ValueError

        # KL part: log q(z_0) - log p(z_T)
        log_q = log_Normal_diag(z_0, z_mean, z_log_var)
        log_p = log_Normal_standard(z[str(config.number_of_flows)])
        kl = log_q - log_p

        regularization = config.regularization
        if regularization not in ('none', 'warmup'):
            raise Exception('wrong regularization name')
        KL = K.mean(K.sum(kl, axis=1), axis=0)
        if regularization == 'warmup':
            KL = KL * vae.beta

        return -RE + KL
Exemplo n.º 20
0
    def forward(self, HR_feat, LR, down_ref):
        """Encode ``HR_feat``, compute the batch-summed KL, then decode.

        :param HR_feat: features passed to ``self.encode``
        :param LR: passed to ``self.decode`` together with ``down_ref``
        :param down_ref: passed to ``self.decode``
        :return: ``(Denoise_LR, SR, KL)``; the first two come from
            ``self.decode`` and KL is ``sum(log q(z) - log p(z))``
        """
        mu, logvar = self.encode(HR_feat)
        z = self.reparameterize(mu, logvar, flag=0)

        log_prior = self.log_p_z(z, self.idle_input, self.number_component,
                                 self.prior)
        log_posterior = log_Normal_diag(z, mu, logvar, dim=1)
        KL = torch.sum(-(log_prior - log_posterior))

        denoised_lr, sr = self.decode(LR, down_ref, z)
        return denoised_lr, sr, KL
Exemplo n.º 21
0
    def calculate_loss(self,
                       x,
                       beta=1.,
                       average=False,
                       head=0,
                       use_mixw_cor=False):
        '''
        Negative ELBO for a multi-head VAE.

        :param x: input image(s)
        :param beta: a hyperparam for warmup (weight of the KL term)
        :param average: whether to average loss, RE and KL over the batch
        :param head: head index; forward() gets it clamped to the last
            index of ``self.q_z_layers``, log_p_z gets it unclamped
        :param use_mixw_cor: forwarded to ``self.log_p_z`` (non-GMM priors)
        :return: ``(loss, RE, KL, x_mean)``
        '''
        last_head = len(self.q_z_layers) - 1
        x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(
            x, head=min(head, last_head))

        input_type = self.args.input_type
        if input_type == 'binary':
            RE = log_Bernoulli(x, x_mean, dim=1)
        elif input_type in ['gray', 'color']:
            # negative L1 distance used in place of a likelihood
            RE = -(x - x_mean).abs().sum(dim=1)
        else:
            raise Exception('Wrong input type!')

        # the GMM prior takes posterior parameters, the others a head index
        if self.args.prior == 'GMM':
            log_prior = self.log_p_z(z_q, mu=z_q_mean, logvar=z_q_logvar)
        else:
            log_prior = self.log_p_z(z_q, head=head, use_mixw_cor=use_mixw_cor)
        log_posterior = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
        KL = -(log_prior - log_posterior)

        loss = -RE + beta * KL

        if average:
            loss, RE, KL = torch.mean(loss), torch.mean(RE), torch.mean(KL)

        return loss, RE, KL, x_mean
Exemplo n.º 22
0
    def MI(self, x):
        """Penalty ``|mean(log(q(z|x_mean) + 1e-10)) - M|`` on re-encoded reconstructions.

        ``x`` is reconstructed with ``self.forward``, the reconstruction is
        encoded again with ``self.q_z``, one latent is sampled and its
        log-density under that encoder is evaluated (``average=True``).
        ``self.args.M`` is the target value. If the result is NaN (checked
        via ``self.isnan``) the log-densities are printed.
        """
        x_mean, _, _, _, _ = self.forward(x)
        enc_mean, enc_logvar, _ = self.q_z(x_mean)
        z_sample = self.reparameterize(enc_mean, enc_logvar)
        log_q_z = log_Normal_diag(z_sample,
                                  enc_mean,
                                  enc_logvar,
                                  average=True,
                                  dim=1)

        # small offset keeps the log finite when exp(log_q_z) underflows
        epsilon = 1e-10
        density = torch.add(torch.exp(log_q_z), epsilon)
        mi_loss = (torch.mean(torch.log(density)) - self.args.M).abs()

        if self.isnan(mi_loss.data[0]):
            print(log_q_z)

        return mi_loss
Exemplo n.º 23
0
    def log_p_z(self, z, mu=None, logvar=None, head=0, use_mixw_cor=False):
        """Log-prior of ``z`` for the configured ``self.args.prior``.

        Supported priors:
          * 'standard'        -- standard normal.
          * 'vampprior'       -- uniform mixture over C = number_components
            Gaussians from encoding ``self.means[head](self.idle_input)``
            (reshaped to 3x64x64 for celeba).
          * 'vampprior_short' -- components from ``self.q_z_mean`` /
            ``self.q_z_logvar`` applied to the means layer evaluated on an
            identity matrix; mixing weights optionally come from
            ``self.mixingw`` (and may be rescaled by ``self.mixingw_c``).
          * 'GMM'             -- ``log_mog_full`` or ``log_mog_diag``.

        ``mu`` and ``logvar`` are accepted but not used here.

        :param z: latent samples, MB x M
        :return: log-prior, one value per row of ``z``
        :raises Exception: for any other prior name
        """
        prior = self.args.prior

        if prior == 'standard':
            return log_Normal_standard(z, dim=1)

        if prior == 'GMM':
            if self.GMM.covariance_type == 'full':
                return log_mog_full(z, self.mus, self.sigs, self.pis,
                                    self.icovs_ten, self.det_terms)
            return log_mog_diag(z, self.mus, self.sigs, self.pis)

        if prior == 'vampprior':
            C = self.args.number_components

            pseudo_inputs = self.means[head](self.idle_input)
            if self.args.dataset_name == 'celeba':
                pseudo_inputs = pseudo_inputs.reshape(pseudo_inputs.size(0), 3, 64, 64)

            comp_mean, comp_logvar = self.q_z(pseudo_inputs, head=head)  # C x M

            a = log_Normal_diag(z.unsqueeze(1), comp_mean.unsqueeze(0),
                                comp_logvar.unsqueeze(0), dim=2) - math.log(C)  # MB x C

        elif prior == 'vampprior_short':
            C = self.args.number_components

            # one pseudo-input per column of an identity matrix
            means_net = self.means[head] if self.args.separate_means else self.means
            K = means_net.linear.weight.size(1)
            eye = torch.eye(K, K)
            if self.args.cuda:
                eye = eye.cuda()
            pseudo_inputs = means_net(eye)

            comp_mean = self.q_z_mean(pseudo_inputs)
            comp_logvar = self.q_z_logvar(pseudo_inputs)

            z_rows = z.unsqueeze(1)
            mean_rows = comp_mean.unsqueeze(0)
            logvar_rows = comp_logvar.unsqueeze(0)

            # mixing weights are not yet handled specially for separate means
            if self.args.use_vampmixingw:
                pis = self.mixingw(eye).t()
                if use_mixw_cor:
                    correction = torch.from_numpy(self.mixingw_c)
                    if self.args.cuda:
                        correction = correction.cuda()
                    pis = pis * correction.unsqueeze(0)
                a = log_Normal_diag(z_rows, mean_rows, logvar_rows,
                                    dim=2) + torch.log(pis)  # MB x C
            else:
                a = log_Normal_diag(z_rows, mean_rows, logvar_rows,
                                    dim=2) - math.log(C)  # MB x C

        else:
            raise Exception('invalid prior!')

        # shared tail for both VampPrior variants: stable log-sum-exp over C
        peak = torch.max(a, 1)[0]  # MB
        return peak + torch.log(torch.sum(torch.exp(a - peak.unsqueeze(1)), 1))
Exemplo n.º 24
0
    def log_p_z(self, z, mu=None, logvar=None):
        """Log-prior of ``z`` for the configured ``self.args.prior``.

        Supported priors:
          * 'standard'        -- standard normal.
          * 'vampprior'       -- uniform mixture over C = number_components
            Gaussians from encoding ``self.means(self.idle_input)``
            (reshaped to 3x64x64 for celeba).
          * 'vampprior_short' -- components from ``self.q_z_mean`` /
            ``self.q_z_logvar`` applied to the means layer evaluated on an
            identity matrix, optionally weighted by ``self.mixingw``.

        ``mu`` and ``logvar`` are accepted but not used here.

        :param z: latent samples, MB x M
        :return: log-prior, one value per row of ``z``
        :raises Exception: for any other prior name
        """
        prior = self.args.prior

        if prior == 'standard':
            return log_Normal_standard(z, dim=1)

        if prior == 'vampprior':
            C = self.args.number_components

            pseudo_inputs = self.means(self.idle_input)
            if self.args.dataset_name == 'celeba':
                pseudo_inputs = pseudo_inputs.reshape(pseudo_inputs.size(0), 3, 64, 64)

            comp_mean, comp_logvar = self.q_z(pseudo_inputs)  # C x M

            a = log_Normal_diag(z.unsqueeze(1), comp_mean.unsqueeze(0),
                                comp_logvar.unsqueeze(0), dim=2) - math.log(C)  # MB x C

        elif prior == 'vampprior_short':
            C = self.args.number_components

            # one pseudo-input per column of an identity matrix
            K = self.means.linear.weight.size(1)
            eye = torch.eye(K, K)
            if self.args.cuda:
                eye = eye.cuda()
            pseudo_inputs = self.means(eye)

            comp_mean = self.q_z_mean(pseudo_inputs)
            comp_logvar = self.q_z_logvar(pseudo_inputs)

            z_rows = z.unsqueeze(1)
            mean_rows = comp_mean.unsqueeze(0)
            logvar_rows = comp_logvar.unsqueeze(0)

            if self.args.use_vampmixingw:
                pis = self.mixingw(eye).t()
                a = log_Normal_diag(z_rows, mean_rows, logvar_rows,
                                    dim=2) + torch.log(pis)  # MB x C
            else:
                a = log_Normal_diag(z_rows, mean_rows, logvar_rows,
                                    dim=2) - math.log(C)  # MB x C

        else:
            raise Exception('invalid prior!')

        # shared tail for both VampPrior variants: stable log-sum-exp over C
        peak = torch.max(a, 1)[0]  # MB
        return peak + torch.log(torch.sum(torch.exp(a - peak.unsqueeze(1)), 1))
Exemplo n.º 25
0
def evaluate_vae(args, model, train_loader, data_loader, epoch, dir, mode):
    '''
    Evaluate a one-level VAE (standard-normal prior, Bernoulli likelihood).

    Every batch of ``data_loader`` contributes to the running negative ELBO,
    reconstruction error (RE) and KL term. In ``'validation'`` mode the second
    batch is used for reconstruction plots (plus real images on epoch 1). In
    ``'test'`` mode the full test tensor and the stochastically binarized
    training tensor are also used for:
    - reconstruction and generation plots;
    - manifold and scatter plots when ``args.z1_size == 2``;
    - the lower bound on test and training data;
    - the test log-likelihood.

    :param args: experiment config (reads ``cuda`` and ``z1_size``)
    :param model: the VAE to evaluate
    :param train_loader: loader whose ``dataset.data_tensor`` is used as the
        training data in test mode
    :param data_loader: loader to evaluate
    :param epoch: current epoch number (passed to the plotting helpers)
    :param dir: output directory for plots
    :param mode: ``'validation'`` or ``'test'``
    :return: ``(loss, re, kl)`` averaged over every evaluated example. In test
        mode these are followed by ``log_likelihood_test, log_likelihood_train,
        elbo_test, elbo_train`` (``log_likelihood_train`` is currently always
        ``0.``).
    '''
    # running totals over all evaluated examples
    evaluate_loss = 0.
    evaluate_re = 0.
    evaluate_kl = 0.
    n_examples = 0
    # set model to evaluation mode
    model.eval()

    # evaluate
    for batch_idx, (data, target) in enumerate(data_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)

        x = data
        # forward pass
        x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = model.forward(x)

        # RE: no `dim` is given, so presumably a total over the whole batch
        RE = log_Bernoulli(x, x_mean)

        # KL: single-sample estimate, summed over the whole batch
        log_p_z = log_Normal_standard(z_q, dim=1)
        log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
        KL = - torch.sum(log_p_z - log_q_z)

        # Accumulate batch totals and divide by the example count at the end.
        # Averaging per-batch means instead would over-weight a smaller last batch.
        evaluate_loss += (-RE + KL).data[0]
        evaluate_re += (-RE).data[0]
        evaluate_kl += KL.data[0]
        n_examples += data.size(0)

        # print N digits
        if batch_idx == 1 and mode == 'validation':
            plot_reconstruction(args, x_mean, epoch, dir, size_x=3, size_y=3)
            if epoch == 1:
                # VISUALIZATION: plot real images
                plot_real(args, data[:26], dir + 'reconstruction/', size_x=3, size_y=3)

    if mode == 'test':
        # load all data
        test_data = Variable(data_loader.dataset.data_tensor)
        test_target = Variable(data_loader.dataset.target_tensor)
        full_data = Variable(train_loader.dataset.data_tensor)

        if args.cuda:
            test_data, test_target, full_data = test_data.cuda(), test_target.cuda(), full_data.cuda()

        # stochastic binarization of the training data
        full_data = torch.bernoulli(full_data)

        # VISUALIZATION: plot real images
        plot_real(args, test_data, dir, size_x=5, size_y=5)

        # VISUALIZATION: plot reconstructions
        z_mean_recon, z_logvar_recon = model.q_z(test_data)
        z_recon = model.reparameterize(z_mean_recon, z_logvar_recon)
        samples, _ = model.p_x(z_recon)

        plot_reconstruction(args, samples, epoch, dir, size_x=5, size_y=5)

        # VISUALIZATION: plot generations from z ~ N(0, I)
        z_sample_rand = Variable(torch.normal( torch.from_numpy( np.zeros((25,args.z1_size)) ).float(), 1. ) )
        if args.cuda:
            z_sample_rand = z_sample_rand.cuda()

        samples_rand, _ = model.p_x(z_sample_rand)

        plot_generation(args, samples_rand, dir, size_x=5, size_y=5)

        if args.z1_size == 2:
            # VISUALIZATION: plot low-dimensional manifold
            plot_manifold(model, args, dir)

            # VISUALIZATION: plot scatter-plot
            plot_scatter(model, test_data, test_target, dir)

        # CALCULATE lower-bound (test data)
        t_ll_s = time.time()
        elbo_test = model.calculate_lower_bound(test_data)
        t_ll_e = time.time()
        print('Lower-bound time: {:.2f}'.format(t_ll_e - t_ll_s))

        # CALCULATE lower-bound (binarized training data)
        t_ll_s = time.time()
        elbo_train = model.calculate_lower_bound(full_data)
        t_ll_e = time.time()
        print('Lower-bound time: {:.2f}'.format(t_ll_e - t_ll_s))

        # CALCULATE log-likelihood (test data)
        t_ll_s = time.time()
        log_likelihood_test = model.calculate_likelihood(test_data, dir, mode='test')
        t_ll_e = time.time()
        print('Log_likelihood time: {:.2f}'.format(t_ll_e - t_ll_s))

        # CALCULATE log-likelihood (training data) -- disabled, reported as 0.
        t_ll_s = time.time()
        log_likelihood_train = 0. #model.calculate_likelihood(full_data, dir, mode='train')
        t_ll_e = time.time()
        print('Log_likelihood time: {:.2f}'.format(t_ll_e - t_ll_s))

    # per-example averages over every evaluated example
    evaluate_loss /= n_examples
    evaluate_re /= n_examples
    evaluate_kl /= n_examples
    if mode == 'test':
        return evaluate_loss, evaluate_re, evaluate_kl, log_likelihood_test, log_likelihood_train, elbo_test, elbo_train
    else:
        return evaluate_loss, evaluate_re, evaluate_kl
Exemplo n.º 26
0
def evaluate_vae_vampprior_2level(args, model, train_loader, data_loader, epoch, dir, mode):
    '''
    Evaluate a two-level (z1, z2) VAE with a VampPrior on z2 and a Bernoulli likelihood.

    Every batch of ``data_loader`` contributes to the running negative ELBO,
    reconstruction error (RE) and KL term. That KL is log q(z1|.) + log q(z2|.)
    minus log p(z1|z2) + log p(z2). In ``'validation'`` mode the second batch
    is used for reconstruction plots (plus real images on epoch 1). In
    ``'test'`` mode the full test tensor and the stochastically binarized
    training tensor are also used for:
    - reconstruction plots;
    - generation plots from the first 25 pseudo-inputs;
    - manifold and scatter plots when both latent sizes are 2;
    - the lower bound on test and training data;
    - the test log-likelihood.

    :param args: experiment config (reads ``cuda``, ``z1_size``, ``z2_size``)
    :param model: the two-level VAE to evaluate
    :param train_loader: loader whose ``dataset.data_tensor`` is used as the
        training data in test mode
    :param data_loader: loader to evaluate
    :param epoch: current epoch number (passed to the plotting helpers)
    :param dir: output directory for plots
    :param mode: ``'validation'`` or ``'test'``
    :return: ``(loss, re, kl)`` averaged over every evaluated example. In test
        mode these are followed by ``log_likelihood_test, log_likelihood_train,
        elbo_test, elbo_train`` (``log_likelihood_train`` is currently always
        ``0.``).
    '''
    # running totals over all evaluated examples
    evaluate_loss = 0.
    evaluate_re = 0.
    evaluate_kl = 0.
    n_examples = 0
    # set model to evaluation mode
    model.eval()

    # evaluate
    for batch_idx, (data, target) in enumerate(data_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)

        x = data
        # forward pass
        x_mean, x_logvar, z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = model.forward(
            x)

        # RE: no `dim` is given, so presumably a total over the whole batch
        RE = log_Bernoulli(x, x_mean)

        # KL: single-sample estimate, summed over the whole batch
        log_p_z2 = model.log_p_z2(z2_q)
        log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
        log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
        log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
        KL = - torch.sum(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

        # Accumulate batch totals and divide by the example count at the end.
        # Averaging per-batch means instead would over-weight a smaller last batch.
        evaluate_loss += (-RE + KL).data[0]
        evaluate_re += (-RE).data[0]
        evaluate_kl += KL.data[0]
        n_examples += data.size(0)

        # print N digits
        if batch_idx == 1 and mode == 'validation':
            plot_reconstruction(args, x_mean, epoch, dir, size_x=3, size_y=3)
            if epoch == 1:
                # VISUALIZATION: plot real images
                plot_real(args, data[:26], dir + 'reconstruction/', size_x=3, size_y=3)

    if mode == 'test':
        # load all data
        test_data = Variable(data_loader.dataset.data_tensor)
        test_target = Variable(data_loader.dataset.target_tensor)
        full_data = Variable(train_loader.dataset.data_tensor)

        if args.cuda:
            test_data, test_target, full_data = test_data.cuda(), test_target.cuda(), full_data.cuda()

        # stochastic binarization of the training data
        full_data = torch.bernoulli(full_data)

        # VISUALIZATION: plot real images
        plot_real(args, test_data, dir, size_x=5, size_y=5)

        # VISUALIZATION: plot reconstructions (z2 ~ q(z2|x), then z1 ~ q(z1|x, z2))
        z2_mean_recon, z2_logvar_recon = model.q_z2(test_data)
        z2_recon = model.reparameterize(z2_mean_recon, z2_logvar_recon)
        z1_mean_recon, z1_logvar_recon = model.q_z1(test_data, z2_recon)
        z1_recon = model.reparameterize(z1_mean_recon, z1_logvar_recon)
        samples, _ = model.p_x(z1_recon, z2_recon)

        plot_reconstruction(args, samples, epoch, dir, size_x=5, size_y=5)


        # VISUALIZATION: plot generations, seeding z2 from the first 25 pseudo-inputs
        means = model.means(model.idle_input)[0:25]

        z2_sample_gen_mean, z2_sample_gen_logvar = model.q_z2(means)
        z2_sample_rand = model.reparameterize(z2_sample_gen_mean, z2_sample_gen_logvar)

        z1_mean_rand, z1_logvar_rand = model.p_z1(z2_sample_rand)
        z1_sample_rand = model.reparameterize(z1_mean_rand, z1_logvar_rand)
        samples_rand, _ = model.p_x(z1_sample_rand, z2_sample_rand)

        plot_generation(args, samples_rand, dir, size_x=5, size_y=5)

        if args.z1_size == 2 and args.z2_size == 2:
            # VISUALIZATION: plot low-dimensional manifold
            from utils.visual_evaluation import plot_manifold2
            plot_manifold2(model, args, dir)

            # VISUALIZATION: plot scatter-plot
            from utils.visual_evaluation import plot_scatter2
            z2_mean_recon, z2_logvar_recon = model.q_z2(test_data)
            z2_recon = model.reparameterize(z2_mean_recon, z2_logvar_recon)

            plot_scatter2(model, z2_recon, test_target, dir, name='scatter2D_z2.png')

            z1_mean_recon, z1_logvar_recon = model.q_z1(test_data, z2_recon)
            z1_recon = model.reparameterize(z1_mean_recon, z1_logvar_recon)
            plot_scatter2(model, z1_recon, test_target, dir, name='scatter2D_z1_1.png')

            # NOTE(review): identical to the block above apart from a fresh
            # reparameterized sample; confirm whether a different view of z1 was intended.
            z1_mean_recon, z1_logvar_recon = model.q_z1(test_data, z2_recon)
            z1_recon = model.reparameterize(z1_mean_recon, z1_logvar_recon)
            plot_scatter2(model, z1_recon, test_target, dir, name='scatter2D_z1_2.png')

        # CALCULATE lower-bound (test data)
        t_ll_s = time.time()
        elbo_test = model.calculate_lower_bound(test_data)
        t_ll_e = time.time()
        print('Lower-bound time: {:.2f}'.format(t_ll_e - t_ll_s))

        # CALCULATE lower-bound (binarized training data)
        t_ll_s = time.time()
        elbo_train = model.calculate_lower_bound(full_data)
        t_ll_e = time.time()
        print('Lower-bound time: {:.2f}'.format(t_ll_e - t_ll_s))

        # CALCULATE log-likelihood (test data)
        t_ll_s = time.time()
        log_likelihood_test = model.calculate_likelihood(test_data, dir, mode='test')
        t_ll_e = time.time()
        print('Log_likelihood time: {:.2f}'.format(t_ll_e - t_ll_s))

        # CALCULATE log-likelihood (training data) -- disabled, reported as 0.
        t_ll_s = time.time()
        log_likelihood_train = 0.  # model.calculate_likelihood(full_data, dir, mode='train')
        t_ll_e = time.time()
        print('Log_likelihood time: {:.2f}'.format(t_ll_e - t_ll_s))

    # per-example averages over every evaluated example
    evaluate_loss /= n_examples
    evaluate_re /= n_examples
    evaluate_kl /= n_examples

    if mode == 'test':
        return evaluate_loss, evaluate_re, evaluate_kl, log_likelihood_test, log_likelihood_train, elbo_test, elbo_train
    else:
        return evaluate_loss, evaluate_re, evaluate_kl
Exemplo n.º 27
0
    def calculate_loss(self, x, beta=1., average=False):
        '''
        MIM-style loss for the two-level (z1, z2) model.

        The objective is
        -1/2 * [log p(x|z) + log p(z) + beta * (log q(z|x) + log q(x))].
        It is first evaluated on encoder samples z ~ q(z|x). When
        ``self.p_samp`` is set and ``beta >= 1``, the same objective is also
        evaluated on samples drawn via ``self.generate_x``. That term is
        batch-averaged, scaled by ``beta`` and added.

        :param x: input image(s)
        :param beta: warm-up weight. It scales log p(z2), the q-terms and the
            model-sampled term.
        :param average: whether to average loss, RE and KL over the batch
        :return: ``(loss, RE, KL)``. RE = log p(x|z) and
            KL = -(log p(z1|z2) + log p(z2) - log q(z1) - log q(z2)), both
            from the encoder-sampled pass (unweighted by ``beta``).
        :raises ValueError: if ``args.q_x_prior`` is not "marginal" or "vampprior"
        :raises Exception: if ``args.input_type`` is not supported
        '''
        # Validate up front; an unknown value used to leave ``log_q_x`` unbound.
        q_x_prior = self.args.q_x_prior
        if q_x_prior not in ("marginal", "vampprior"):
            raise ValueError('Wrong q_x_prior: {}'.format(q_x_prior))

        # pass through the model with encoder samples z ~ q(z|x)
        x_mean, x_logvar, z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = self.forward(
            x)

        # log p(x|z)
        if self.args.input_type == 'binary':
            log_p_x_given_z = log_Bernoulli(x, x_mean, dim=1)
        elif self.args.input_type == 'gray' or self.args.input_type == 'continuous':
            log_p_x_given_z = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
        else:
            raise Exception('Wrong input type!')

        # log p(z) = log p(z1|z2) + beta * log p(z2)
        log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
        log_p_z2 = self.log_p_z2(z2_q)
        log_p_z = log_p_z1 + beta * log_p_z2

        # log q(z|x) = log q(z1|.) + log q(z2|.)
        log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
        log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
        log_q_z_given_x = log_q_z1 + log_q_z2

        # log q(x)
        if q_x_prior == "marginal":
            # z ~ q(z|x): one-sample importance estimate of the model marginal
            # (using the beta-weighted log p(z))
            log_q_x = log_p_x_given_z + log_p_z - log_q_z_given_x
        else:
            # "vampprior": q(x) is a VampPrior mixture over p(x|u)
            log_q_x = self.log_q_x_vampprior(x)

        RE = log_p_x_given_z
        KL = -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

        # MIM loss (encoder-sampled part)
        loss = -0.5 * (log_p_x_given_z + log_p_z + beta *
                       (log_q_z_given_x + log_q_x))

        if average:
            loss = torch.mean(loss)
            RE = torch.mean(RE)
            KL = torch.mean(KL)

        # symmetric sampling: also evaluate the objective on samples from the model
        if self.p_samp and (beta >= 1.0):

            # No z is supplied, so generate_x draws z itself
            # (NOTE(review): presumably from the prior -- confirm).
            z1_q, z2_q, x = self.generate_x(
                N=x.shape[0],
                return_z=True,
                z1=None,
                z2=None,
            )

            # discrete samples should have no gradients
            if self.args.input_type == 'binary':
                x = x.detach()

            # reshape the flat samples to the encoder's input shape
            x_shape = (-1, ) + tuple(self.args.input_size)
            x = x.view(x_shape)

            # parameters of q(z2|x), evaluated at the model samples
            z2_q_mean, z2_q_logvar = self.q_z2(x)
            # parameters of q(z1|x, z2)
            z1_q_mean, z1_q_logvar = self.q_z1(x, z2_q)
            # parameters of p(z1|z2)
            z1_p_mean, z1_p_logvar = self.p_z1(z2_q)
            # decoder parameters of p(x|z1, z2)
            x_mean, x_logvar = self.p_x(z1_q, z2_q)

            # flatten back for the per-example log-densities
            x = x.view((x.shape[0], -1))

            # log p(x|z)  (input_type was already validated above)
            if self.args.input_type == 'binary':
                log_p_x_given_z = log_Bernoulli(x, x_mean, dim=1)
            else:
                log_p_x_given_z = -log_Logistic_256(x, x_mean, x_logvar, dim=1)

            # log p(z) = log p(z1|z2) + beta * log p(z2)
            log_p_z1 = log_Normal_diag(z1_q, z1_p_mean, z1_p_logvar, dim=1)
            log_p_z2 = self.log_p_z2(z2_q)
            log_p_z = log_p_z1 + beta * log_p_z2

            # log q(z|x) = log q(z1|x, z2) + log q(z2|x)
            log_q_z1 = log_Normal_diag(z1_q, z1_q_mean, z1_q_logvar, dim=1)
            log_q_z2 = log_Normal_diag(z2_q, z2_q_mean, z2_q_logvar, dim=1)
            log_q_z_given_x = log_q_z1 + log_q_z2

            if q_x_prior == "marginal":
                # If z ~ p(z), then p(x) = E_p(z)[p(x|z)], so log p(x|z)
                # is a one-sample estimate of the marginal.
                log_q_x = log_p_x_given_z
            else:
                # "vampprior": q(x) is a VampPrior mixture over p(x|u)
                log_q_x = self.log_q_x_vampprior(x)

            loss_p = -0.5 * (log_p_x_given_z + log_p_z + beta *
                             (log_q_z_given_x + log_q_x))

            # REINFORCE: the binary samples carry no gradient. Add a
            # zero-valued surrogate whose gradient is loss_p * grad log p(x|z).
            if self.args.input_type == 'binary':
                loss_p = loss_p + loss_p.detach() * log_p_x_given_z - (
                    loss_p * log_p_x_given_z).detach()

            # add the batch-averaged model-sampled MIM term
            loss += beta * loss_p.mean()

        return loss, RE, KL
Exemplo n.º 28
0
    def calculate_loss(self, x, beta=1., average=False):
        '''
        Negative ELBO with a beta-weighted KL, plus optional FI / MI statistics.

        :param x: input image(s)
        :param beta: weight on the KL term (used for warm-up)
        :param average: if True, reduce every returned quantity to a batch mean
            (FI is reduced as exp(mean(FI)))
        :return: tuple ``(loss, RE, KL, FI, MI)``. FI and MI are 1-element
            zero placeholders when ``args.FI`` / ``args.MI`` are not ``True``.
        '''
        opts = self.args
        x_mean, x_logvar, z_q, z_q_mean, z_q_logvar = self.forward(x)

        # reconstruction term
        if opts.input_type == 'binary':
            RE = log_Bernoulli(x, x_mean, dim=1)
        elif opts.input_type in ('gray', 'continuous'):
            RE = -log_Logistic_256(x, x_mean, x_logvar, dim=1)
        else:
            raise Exception('Wrong input type!')

        # single-sample KL estimate: log q(z|x) - log p(z)
        log_p_z = self.log_p_z(z_q)
        log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
        KL = -(log_p_z - log_q_z)

        # Debugging aid: dump an encoder output when its first entry is NaN.
        for tag, stat in (("mean:", z_q_mean), ("var:", z_q_logvar)):
            if self.isnan(stat.data[0][0]):
                print(tag)
                print(stat)

        def zero_stat():
            # placeholder returned for a disabled statistic
            placeholder = Variable(torch.zeros(1), requires_grad=False)
            return placeholder.cuda() if opts.cuda else placeholder

        fi_enabled = opts.FI is True
        FI, gamma = self.FI(x) if fi_enabled else (zero_stat(), None)
        MI = self.MI(x) if opts.MI is True else zero_stat()

        loss = -RE + beta * KL
        if fi_enabled:
            # FI contribution, weighted by the gamma returned from self.FI
            loss -= torch.mean(FI * gamma, dim=1)

        if average:
            loss = torch.mean(loss)
            RE = torch.mean(RE)
            KL = torch.mean(KL)
            FI = torch.mean(torch.exp(torch.mean(FI)))
            MI = torch.mean(MI)

        return loss, RE, KL, FI, MI
Exemplo n.º 29
0
 def vae_kl(x, x_decoded_mean):
     """Batch-mean of the per-example summed log q(z_0) - log N(z_K; 0, I).

     ``x`` and ``x_decoded_mean`` are not used; presumably they exist only to
     match the Keras loss-function signature. All tensors (z_0, z_mean,
     z_log_var, z, config) come from the enclosing scope. No log-det-Jacobian
     term appears here; NOTE(review): confirm it is accounted for elsewhere.
     """
     # density of the initial sample under the encoder's diagonal Gaussian
     log_q_z0 = log_Normal_diag(z_0, z_mean, z_log_var)
     # standard-normal density at the output of the last flow step
     z_last = z[str(config.number_of_flows)]
     log_p_zk = log_Normal_standard(z_last)
     per_example = K.sum(log_q_z0 - log_p_zk, axis=1)
     return K.mean(per_example, axis=0)