Example #1
0
def calc_kl_divergence_lb_gauss_mixture(flags,
                                        index,
                                        mu1,
                                        logvar1,
                                        mus,
                                        logvars,
                                        norm_value=None):
    """Return ``-log`` of a weighted sum of Gaussian scaling factors.

    The weights are ``flags.alpha_modalities`` after ``reweight_weights``.
    Term 0 pairs weight 0 with ``calc_gaussian_scaling_factor`` called on
    ``(mu1, logvar1)`` only, i.e. without an explicit second distribution.
    Term ``k + 1`` pairs weight ``k + 1`` with the scaling factor between
    ``(mu1, logvar1)`` and ``(mus[k], logvars[k])``. The one exception is
    ``k == index``, where ``calc_gaussian_scaling_factor_self`` is used.

    Args:
        flags: config object; ``alpha_modalities`` and ``cuda`` are read.
        index: position in ``mus`` / ``logvars`` of the distribution being
            scored (it gets the "self" term).
        mu1: mean of the distribution being scored.
        logvar1: log-variance of the distribution being scored.
        mus: per-component means.
        logvars: per-component log-variances.
        norm_value: forwarded unchanged to the scaling-factor helpers.

    Returns:
        Tensor equal to ``-log(sum_j w_j * s_j)``.
    """
    pi_tensor = torch.Tensor([math.pi])
    mix_weights = torch.Tensor(flags.alpha_modalities)
    if flags.cuda:
        pi_tensor = pi_tensor.cuda()
        mix_weights = mix_weights.cuda()
    mix_weights = reweight_weights(mix_weights)

    # Term 0: the helper is called without a second distribution and so falls
    # back to its own default (presumably the prior -- confirm in the helper).
    denom = mix_weights[0] * calc_gaussian_scaling_factor(
        pi_tensor, mu1, logvar1, norm_value=norm_value)
    for comp, mu_comp in enumerate(mus):
        weight = mix_weights[comp + 1]
        if comp != index:
            denom += weight * calc_gaussian_scaling_factor(
                pi_tensor, mu1, logvar1, mu_comp, logvars[comp],
                norm_value=norm_value)
        else:
            # The scored distribution itself uses the dedicated "self" term.
            denom += weight * calc_gaussian_scaling_factor_self(
                pi_tensor, logvar1, norm_value=norm_value)
    return -torch.log(denom)
    def cond_generation_2a(self, latent_distribution_pairs, num_samples=None):
        """Generate samples conditioned on each entry of a set of latent pairs.

        For every entry of ``latent_distribution_pairs``:
        1. A zero-mean / zero-log-variance component is stacked in front of the
           entry's per-modality ``(mu, logvar)`` latents.
        2. The stack is fused with ``self.modality_fusion``. The weights are
           ``self.weights[0]`` followed by the entry's ``'weights'``,
           re-normalised with ``utils.reweight_weights``.
        3. A content embedding is sampled with ``utils.reparameterize``.
        4. The embedding is decoded together with one shared set of random
           style latents.

        Args:
            latent_distribution_pairs: dict mapping a pair key to a dict with
                ``'latents'`` (modality key -> indexable ``(mu, logvar)``) and
                ``'weights'`` (sequence of per-modality weights). The input is
                NOT modified.
            num_samples: number of samples to draw; defaults to
                ``self.flags.batch_size``.

        Returns:
            dict mapping each pair key to the output of
            ``self.generate_from_latents``.
        """
        if num_samples is None:
            num_samples = self.flags.batch_size

        device = self.flags.device
        mu0 = torch.zeros(1, num_samples, self.flags.class_dim).to(device)
        logvar0 = torch.zeros(1, num_samples, self.flags.class_dim).to(device)
        style_latents = self.get_random_styles(num_samples)
        cond_gen_2a = dict()
        for pair, ld_pair in latent_distribution_pairs.items():
            mu_list = [mu0]
            logvar_list = [logvar0]
            for latent in ld_pair['latents'].values():
                mu_list.append(latent[0].unsqueeze(0))
                logvar_list.append(latent[1].unsqueeze(0))
            mus = torch.cat(mu_list, dim=0)
            logvars = torch.cat(logvar_list, dim=0)
            # Build a new list: inserting into ld_pair['weights'] would mutate
            # the caller's data and grow it by one entry on every call.
            weights_pair = [self.weights[0]] + list(ld_pair['weights'])
            weights_pair = utils.reweight_weights(torch.Tensor(weights_pair))
            mu_joint, logvar_joint = self.modality_fusion(
                mus, logvars, weights_pair)
            c_emb = utils.reparameterize(mu_joint, logvar_joint)
            l_2a = {
                'content': c_emb,
                'style': style_latents
            }
            cond_gen_2a[pair] = self.generate_from_latents(l_2a)
        return cond_gen_2a
Example #3
0
def calc_kl_divergence_ub_gauss_mixture(flags,
                                        index,
                                        mu1,
                                        logvar1,
                                        mus,
                                        logvars,
                                        entropy,
                                        norm_value=None):
    """Return ``log(nom) - log(denom) + entropy`` for a Gaussian vs. a mixture.

    The two terms are defined as follows:
    * ``nom`` is ``calc_gaussian_scaling_factor_self`` of ``logvar1``.
    * ``denom = sum_j w_j * min(exp(KL_j), 1e5)``.

    The weights ``w`` are ``flags.alpha_modalities`` after
    ``reweight_weights``. ``KL_0`` comes from ``calc_kl_divergence`` called
    on ``(mu1, logvar1)`` only. ``KL_{k+1}`` is taken against
    ``(mus[k], logvars[k])``. The component at ``index`` is the scored
    distribution itself and contributes ``w_{k+1} * 1``.

    ``torch.clamp`` is used for the cap so the value keeps its device and its
    autograd graph. Re-wrapping it in ``torch.Tensor([...])`` would detach it,
    move it to the CPU, and only work for scalar KL values.

    Args:
        flags: config object; ``alpha_modalities`` and ``cuda`` are read.
        index: position in ``mus`` / ``logvars`` of the scored distribution.
        mu1: mean of the scored distribution.
        logvar1: log-variance of the scored distribution.
        mus: per-component means.
        logvars: per-component log-variances.
        entropy: value added to the result unchanged.
        norm_value: forwarded to the KL / scaling-factor helpers.

    Returns:
        Tensor with the bound value.
    """
    import logging
    logger = logging.getLogger(__name__)

    max_exp_kl = 100000  # cap on exp(KL) so large divergences stay finite
    PI = torch.Tensor([math.pi])
    w_modalities = torch.Tensor(flags.alpha_modalities)
    if flags.cuda:
        PI = PI.cuda()
        w_modalities = w_modalities.cuda()
    w_modalities = reweight_weights(w_modalities)

    nom = calc_gaussian_scaling_factor_self(PI, logvar1, norm_value=norm_value)
    kl_div = calc_kl_divergence(mu1, logvar1, norm_value=norm_value)
    # Lazy %-formatting: the tensor is only stringified when DEBUG is enabled.
    logger.debug('kl div uniform: %s', kl_div)
    denom = w_modalities[0] * torch.clamp(kl_div.exp(), max=max_exp_kl)
    for k in range(len(mus)):
        if index == k:
            # KL of a distribution with itself is 0, so exp(KL) == 1.
            denom = denom + w_modalities[k + 1]
        else:
            kl_div = calc_kl_divergence(mu1,
                                        logvar1,
                                        mus[k],
                                        logvars[k],
                                        norm_value=norm_value)
            logger.debug('kl div %s: %s', k, kl_div)
            denom = denom + w_modalities[k + 1] * torch.clamp(
                kl_div.exp(), max=max_exp_kl)
    ub = torch.log(nom) - torch.log(denom) + entropy
    return ub
Example #4
0
    def __init__(self, flags):
        """Build the bimodal (image / text) CelebA VAE.

        The constructor sets up the following:
        * Image and text encoders and decoders, moved to ``flags.device``.
        * The per-modality likelihoods.
        * The reconstruction weights.
        * The re-normalised mixture weights from ``flags.alpha_modalities``.

        It then selects the fusion and joint-divergence methods from the
        ``flags.modality_moe`` / ``modality_jsd`` / ``modality_poe`` switches.

        Args:
            flags: experiment configuration.
        """
        super(VAEbimodalCelebA, self).__init__()
        self.flags = flags
        self.encoder_img = EncoderImg(flags)
        self.encoder_text = EncoderText(flags)
        self.decoder_img = DecoderImg(flags)
        self.decoder_text = DecoderText(flags)
        self.lhood_celeba = utils.get_likelihood(flags.likelihood_m1)
        self.lhood_text = utils.get_likelihood(flags.likelihood_m2)
        self.encoder_img = self.encoder_img.to(flags.device)
        self.decoder_img = self.decoder_img.to(flags.device)
        self.encoder_text = self.encoder_text.to(flags.device)
        self.decoder_text = self.decoder_text.to(flags.device)

        # The text reconstruction weight is the ratio of image pixels to text
        # length (presumably to balance the two terms); weights are not
        # normalised.
        d_size_m1 = flags.img_size * flags.img_size
        d_size_m2 = flags.len_sequence
        self.rec_w1 = 1.0
        self.rec_w2 = d_size_m1 / d_size_m2

        weights = utils.reweight_weights(torch.Tensor(flags.alpha_modalities))
        self.weights = weights.to(flags.device)
        if flags.modality_moe or flags.modality_jsd:
            self.modality_fusion = self.moe_fusion
            # JSD takes precedence when both switches are set.
            if flags.modality_jsd:
                self.calc_joint_divergence = self.divergence_jsd
            else:
                self.calc_joint_divergence = self.divergence_moe
        elif flags.modality_poe:
            self.modality_fusion = self.poe_fusion
            self.calc_joint_divergence = self.divergence_poe
    def moe_fusion(self, mus, logvars, weights=None):
        """Fuse unimodal posteriors by mixture-of-experts component selection.

        Weight 0 is set to zero on a copy of the weights, which are then
        re-normalised with ``utils.reweight_weights``. The result is passed to
        ``utils.mixture_component_selection``.

        Args:
            mus: stacked means.
            logvars: stacked log-variances.
            weights: mixture weights tensor; defaults to ``self.weights``.
                Neither it nor ``self.weights`` is modified.

        Returns:
            ``[mu_moe, logvar_moe]`` from ``utils.mixture_component_selection``.
        """
        if weights is None:
            weights = self.weights
        # Copy before zeroing: an in-place write would permanently alter
        # self.weights (or the caller's tensor) across calls.
        weights = weights.clone()
        weights[0] = 0.0
        weights = utils.reweight_weights(weights)
        num_samples = mus[0].shape[0]
        mu_moe, logvar_moe = utils.mixture_component_selection(
            self.flags, mus, logvars, weights, num_samples)
        return [mu_moe, logvar_moe]
Example #6
0
 def moe_fusion(self, mus, logvars, weights=None):
     """Fuse unimodal posteriors by mixture-of-experts component selection.

     A copy of the weights (``self.weights`` when ``weights`` is None) has
     entry 0 set to zero. The copy is re-normalised with
     ``utils.reweight_weights`` and handed to
     ``utils.mixture_component_selection``. The original weights are left
     untouched.

     Returns:
         ``[mu, logvar]`` as produced by ``utils.mixture_component_selection``.
     """
     source = self.weights if weights is None else weights
     masked = source.clone()
     masked[0] = 0.0
     masked = utils.reweight_weights(masked)
     fused_mu, fused_logvar = utils.mixture_component_selection(
         self.flags, mus, logvars, masked)
     return [fused_mu, fused_logvar]
Example #7
0
 def divergence_jsd(self, mus, logvars, weights=None):
     """Joint divergence for the JSD objective.

     A copy of the weights (``self.weights`` when ``weights`` is None) is
     re-normalised with ``utils.reweight_weights``. It is then passed to
     ``calc_alphaJSD_modalities`` together with
     ``normalization=self.flags.batch_size``.

     Returns:
         dict with keys ``'joint_divergence'``, ``'individual_divs'`` and
         ``'dyn_prior'``, holding the first three results of
         ``calc_alphaJSD_modalities`` in that order.
     """
     base = self.weights if weights is None else weights
     normalized = utils.reweight_weights(base.clone())
     measures = calc_alphaJSD_modalities(self.flags,
                                         mus,
                                         logvars,
                                         normalized,
                                         normalization=self.flags.batch_size)
     return {
         'joint_divergence': measures[0],
         'individual_divs': measures[1],
         'dyn_prior': measures[2],
     }
Example #8
0
 def divergence_moe(self, mus, logvars, weights=None):
     """Joint divergence for the mixture-of-experts objective.

     A copy of the weights (``self.weights`` when ``weights`` is None) has
     entry 0 set to zero. The copy is re-normalised with
     ``utils.reweight_weights`` and passed to ``calc_group_divergence_moe``
     together with ``normalization=self.flags.batch_size``.

     Returns:
         dict with ``'joint_divergence'`` and ``'individual_divs'`` taken from
         the first two results, and ``'dyn_prior'`` set to None.
     """
     base = self.weights if weights is None else weights
     masked = base.clone()
     masked[0] = 0.0
     masked = utils.reweight_weights(masked)
     measures = calc_group_divergence_moe(self.flags,
                                          mus,
                                          logvars,
                                          masked,
                                          normalization=self.flags.batch_size)
     return {
         'joint_divergence': measures[0],
         'individual_divs': measures[1],
         'dyn_prior': None,
     }
Example #9
0
def calc_alphaJSD_modalities_mixture(m1_mu, m1_logvar, m2_mu, m2_logvar,
                                     flags):
    """Weighted sum of per-modality bound estimates for two Gaussian modalities.

    For each modality ``k`` the function:
    1. Computes the entropy with ``calc_entropy_gauss``.
    2. Computes the lower and upper bounds with
       ``calc_kl_divergence_lb_gauss_mixture`` and
       ``calc_kl_divergence_ub_gauss_mixture``.
    3. Uses the average of the two bounds as that modality's estimate.

    All helpers are called with ``norm_value=flags.batch_size``. The estimates
    are weighted by ``flags.alpha_modalities[1:]`` (re-normalised with
    ``reweight_weights``) and summed.

    Intermediate bound values are emitted with ``logging.debug`` rather than
    ``print``. This avoids flooding stdout and avoids stringifying (and hence
    device-syncing) tensors on every call.

    Args:
        m1_mu: mean of modality 1.
        m1_logvar: log-variance of modality 1.
        m2_mu: mean of modality 2.
        m2_logvar: log-variance of modality 2.
        flags: config object; ``alpha_modalities``, ``cuda`` and
            ``batch_size`` are read.

    Returns:
        ``(summed_klds, klds, entropies_mixture)``:
        * ``summed_klds``: the weighted sum of the estimates.
        * ``klds``: a 2-element tensor with the per-modality estimates.
        * ``entropies_mixture``: a 2-element tensor with the per-modality
          entropies.
    """
    import logging
    logger = logging.getLogger(__name__)

    klds = torch.zeros(2)
    entropies_mixture = torch.zeros(2)
    w_modalities = torch.Tensor(flags.alpha_modalities[1:])
    if flags.cuda:
        w_modalities = w_modalities.cuda()
        klds = klds.cuda()
        entropies_mixture = entropies_mixture.cuda()
    w_modalities = reweight_weights(w_modalities)

    mus = [m1_mu, m2_mu]
    logvars = [m1_logvar, m2_logvar]
    for k, (mu_k, logvar_k) in enumerate(zip(mus, logvars)):
        ent = calc_entropy_gauss(flags,
                                 logvar_k,
                                 norm_value=flags.batch_size)
        kld_lb = calc_kl_divergence_lb_gauss_mixture(
            flags,
            k,
            mu_k,
            logvar_k,
            mus,
            logvars,
            norm_value=flags.batch_size)
        logger.debug('kld_lb: %s', kld_lb)
        kld_ub = calc_kl_divergence_ub_gauss_mixture(
            flags,
            k,
            mu_k,
            logvar_k,
            mus,
            logvars,
            ent,
            norm_value=flags.batch_size)
        logger.debug('kld_ub: %s', kld_ub)
        entropies_mixture[k] = ent.clone()
        # Midpoint of the lower and upper bound is used as the estimate.
        klds[k] = 0.5 * (kld_lb + kld_ub)
    summed_klds = (w_modalities * klds).sum()
    return summed_klds, klds, entropies_mixture
Example #10
0
def get_synergy_dist(flags, flows, num_imp_samples):
    """Run flow mixture-component selection over all modalities' content flows.

    The function stacks the following rows:
    * A zero placeholder row, which gets mixture weight 0.
    * The ``['content'][2]`` entry of every flow in ``flows``.

    It passes the stack to ``utils.flow_mixture_component_selection`` with
    weights ``[0.0, alpha_1, ..., alpha_M]`` (re-normalised), where
    ``M == len(flows)``. Previously the weights were hard-coded for exactly
    two flows, which mismatched the number of stacked rows for any other
    count. The result is unchanged for two flows.

    Args:
        flags: config; ``batch_size``, ``class_dim``, ``device`` and
            ``alpha_modalities`` are read.
        flows: dict of per-modality flow outputs, consumed in iteration order.
        num_imp_samples: number of importance samples per batch element.

    Returns:
        The result of ``utils.flow_mixture_component_selection``.
    """
    num_samples = flags.batch_size * num_imp_samples
    placeholder = torch.zeros(1, num_samples, flags.class_dim).to(flags.device)
    # Concatenate once instead of re-allocating the whole stack per flow.
    flow_reps = torch.cat(
        [placeholder] +
        [flow['content'][2].unsqueeze(0) for flow in flows.values()])
    # Weights are matched to flows by position (k-th flow <-> alpha[k + 1]),
    # so this is only correct if ``flows`` is ordered like the modalities or
    # all modalities are equally weighted.
    alphas = [flags.alpha_modalities[k + 1] for k in range(len(flows))]
    weights_mixture_selection = utils.reweight_weights(
        torch.Tensor([0.0] + alphas))
    weights_mixture_selection = weights_mixture_selection.to(flags.device)
    flow_emb_moe = utils.flow_mixture_component_selection(
        flags,
        flow_reps,
        weights_mixture_selection,
        num_samples=num_samples)
    return flow_emb_moe
    def __init__(self, flags):
        """Build the trimodal (MNIST / SVHN / text) VAE.

        The constructor sets up the following:
        * Per-modality encoders and decoders, moved to ``flags.device``.
        * The per-modality likelihoods.
        * The reconstruction weights.
        * The re-normalised mixture weights from ``flags.alpha_modalities``.

        It then selects the fusion and joint-divergence methods from the
        ``flags.modality_moe`` / ``modality_jsd`` / ``modality_poe`` switches.

        Args:
            flags: experiment configuration.
        """
        super(VAEtrimodalSVHNMNIST, self).__init__()
        self.num_modalities = 3
        self.flags = flags
        self.encoder_svhn = EncoderSVHN(flags)
        self.encoder_mnist = EncoderImg(flags)
        self.encoder_text = EncoderText(flags)
        self.decoder_mnist = DecoderImg(flags)
        self.decoder_svhn = DecoderSVHN(flags)
        self.decoder_text = DecoderText(flags)
        self.encoder_mnist = self.encoder_mnist.to(flags.device)
        self.encoder_svhn = self.encoder_svhn.to(flags.device)
        self.encoder_text = self.encoder_text.to(flags.device)
        self.decoder_mnist = self.decoder_mnist.to(flags.device)
        self.decoder_svhn = self.decoder_svhn.to(flags.device)
        self.decoder_text = self.decoder_text.to(flags.device)
        self.lhood_mnist = utils.get_likelihood(flags.likelihood_m1)
        self.lhood_svhn = utils.get_likelihood(flags.likelihood_m2)
        self.lhood_text = utils.get_likelihood(flags.likelihood_m3)

        # Reconstruction weights are relative to SVHN (modality 2, weight 1.0):
        # each other modality is scaled by SVHN's size over its own size.
        # The weights are not normalised.
        d_size_m1 = flags.img_size_mnist * flags.img_size_mnist
        d_size_m2 = 3 * flags.img_size_svhn * flags.img_size_svhn
        d_size_m3 = flags.len_sequence
        self.rec_w1 = d_size_m2 / d_size_m1
        self.rec_w2 = 1.0
        self.rec_w3 = d_size_m2 / d_size_m3

        weights = utils.reweight_weights(torch.Tensor(flags.alpha_modalities))
        self.weights = weights.to(flags.device)
        if flags.modality_moe or flags.modality_jsd:
            self.modality_fusion = self.moe_fusion
            # MoE divergence by default; JSD overrides it when requested.
            self.calc_joint_divergence = self.divergence_moe
            if flags.modality_jsd:
                self.calc_joint_divergence = self.divergence_jsd
        elif flags.modality_poe:
            self.modality_fusion = self.poe_fusion
            self.calc_joint_divergence = self.divergence_poe