Example #1
    def __init__(self, temp, latent_num, latent_dim):
        super(Model, self).__init__()
        if not isinstance(temp, torch.Tensor):
            temp = torch.tensor(temp)
        self.__temp = temp
        self.latent_num = latent_num
        self.latent_dim = latent_dim
        self.encoder = Encoder(latent_num=latent_num, latent_dim=latent_dim)
        self.decoder = Decoder(latent_num=latent_num, latent_dim=latent_dim)
        if 'ExpTDModel' in str(self.__class__):
            self.prior = ExpRelaxedCategorical(temp, probs=torch.ones(latent_dim).cuda())
        else:
            self.prior = dist.RelaxedOneHotCategorical(temp, probs=torch.ones(latent_dim).cuda())
        self.initialize()

        self.softmax = nn.Softmax(dim=-1)
Example #2
    def __init__(self,
                 temperature,
                 probs=None,
                 logits=None,
                 validate_args=None):
        base_dist = ExpRelaxedCategorical(temperature, probs, logits)
        # Do *not* call the constructor of RelaxedOneHotCategorical, as it has hard-coded
        # the unstable ExpTransform. Call its parent class directly.
        super(RelaxedOneHotCategorical,
              self).__init__(base_dist,
                             StableExpTransform(),
                             validate_args=validate_args)
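
The StableExpTransform used above is defined elsewhere in that project. Below is a minimal sketch of what such a transform could look like, assuming the intent is simply to keep the inverse away from log(0); only the class name comes from the original, the implementation is an illustration.

from torch.distributions.transforms import ExpTransform

class StableExpTransform(ExpTransform):
    """ExpTransform whose inverse clamps its argument away from zero."""

    def __init__(self, eps=1e-12, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.eps = eps

    def _inverse(self, y):
        # Clamp before taking the log so the inverse never returns -inf.
        return y.clamp(min=self.eps).log()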
Example #3
# Imports assumed by this snippet (ExpRelaxedCategorical is taken to be the
# torch.distributions version); rows and cols (the subplot grid size) are
# assumed to be defined earlier in the original script.
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.distributions.relaxed_categorical import ExpRelaxedCategorical

col = 0
row = 0
ax = plt.subplot2grid((rows, cols), (row, col),
                      frameon=False,
                      colspan=1,
                      rowspan=1)

n_cats = 2

# needsoftmax_mixtureweight = torch.tensor(np.ones(n_cats), requires_grad=True)
# needsoftmax_mixtureweight = torch.tensor([], requires_grad=True)
# weights = torch.softmax(needsoftmax_mixtureweight, dim=0).float()
theta = .99
weights = torch.tensor([1 - theta, theta], requires_grad=True).float()
cat = ExpRelaxedCategorical(probs=weights, temperature=torch.tensor([1.]))

val = .5
val2 = -2.5
val3 = -11
cmap = 'Blues'
alpha = 1.
xlimits = [val3, val]
ylimits = [val2, val]
numticks = 51
x = np.linspace(*xlimits, num=numticks)
y = np.linspace(*ylimits, num=numticks)
X, Y = np.meshgrid(x, y)
aaa = np.concatenate([np.atleast_2d(X.ravel()), np.atleast_2d(Y.ravel())]).T
aaa = torch.tensor(aaa).float()
logprob = cat.log_prob(aaa)
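
A possible continuation of this example: the original presumably visualizes the density on the grid it just built, but the plotting calls below are an assumption, not part of the snippet.

# Reshape the per-point log probabilities back onto the grid and draw filled contours.
prob = torch.exp(logprob).detach().numpy().reshape(X.shape)
ax.contourf(X, Y, prob, cmap=cmap, alpha=alpha)
ax.set_xlim(xlimits)
ax.set_ylim(ylimits)
plt.show()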
Example #4
def get_log_pz_qz_prodzi_qzCx(latent_sample,
                              latent_dist,
                              n_data,
                              is_mss=True,
                              mi=False):
    """
    Calculates log densities

    Parameters
    ----------
    latent_sample: torch.Tensor or np.ndarray or float
        Value at which to compute the density. (batch size, latent dim)

    latent_dist: torch.Tensor or np.ndarray or float
        statisitc for dist. Each of statistics has size of (batch size, latent dim).
        For guassian, latent_dist = (Mean, logVar)
        For gumbel_softmax, latent_dist = alpha(prob. of categorical variable)
    """
    batch_size, hidden_dim = latent_sample['cont'].shape

    # calculate log q(z|x)
    log_q_ziCx_cont = log_density_gaussian(latent_sample['cont'],
                                           *(latent_dist['cont']))  #64,10
    if 'disc' in latent_sample.keys():
        log_q_ziCx_disc = log_density_categorical(latent_sample['disc'],
                                                  latent_dist['disc'])  #64
        log_q_ziCx = torch.cat(
            (log_q_ziCx_cont, log_q_ziCx_disc.unsqueeze(-1)), dim=1)  # 64,11
    else:
        log_q_ziCx = log_q_ziCx_cont
    log_q_zCx = log_q_ziCx.sum(
        1)  #64   sum across logP(z_i). i.e, \prod P(z_i | x_i)

    # calculate log p(z)
    zeros = torch.zeros_like(latent_sample['cont'])  # mean and log var is 0
    log_pzi_cont = log_density_gaussian(
        latent_sample['cont'], zeros,
        zeros)  # sum across logP(z_i). i.e, \prod P(z_i)
    if 'disc' in latent_sample.keys():
        unif_logits = torch.log(
            torch.ones_like(latent_sample['disc']) * 1 /
            latent_sample['disc'].shape[1])
        relaxedCate = ExpRelaxedCategorical(torch.tensor(.67),
                                            logits=unif_logits)
        log_pzi_disc = log_density_categorical(
            latent_sample['disc'],
            relaxedCate)  # sum across logP(z_i). i.e, \prod P(z_i)
        log_pzi = torch.cat((log_pzi_cont, log_pzi_disc.unsqueeze(-1)), dim=1)
    else:
        log_pzi = log_pzi_cont
    log_pz = log_pzi.sum(1)

    # compute log q(z) ~= log 1/(NM) sum_m=1^M q(z|x_m) = - log(MN) + logsumexp_m(q(z|x_m))
    mat_log_qzi_cont = matrix_log_density(
        latent_sample['cont'], *(latent_dist['cont'])
    )  #(256,256,10): only (n,n,10) is the result of correct pair of (latent sample, m, s).
    # (n, n, 10) --> first n = number of given samples (batch), second n = Monte Carlo samples, 10 = latent dim.

    if 'disc' in latent_sample.keys():
        batch_dim = 1
        latent_sample_disc = latent_sample['disc'].unsqueeze(0).unsqueeze(
            batch_dim + 1).transpose(batch_dim, 0)
        mat_log_qzi_disc = log_density_categorical(
            latent_sample_disc,
            latent_dist['disc']).transpose(1, batch_dim + 1)  #(64,64,1)
        mat_log_qzi = torch.cat((mat_log_qzi_cont, mat_log_qzi_disc), dim=2)
    else:
        mat_log_qzi = mat_log_qzi_cont

    if is_mss:
        # use stratification
        log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(
            latent_sample['cont'].device)
        mat_log_qzi = mat_log_qzi + log_iw_mat.view(batch_size, batch_size, 1)

    log_qz = torch.logsumexp(mat_log_qzi.sum(2), dim=1,
                             keepdim=False) - math.log(batch_size * n_data)
    # mat_log_qz.sum(2): sum across logP(z_i). i.e, \prod P(z_i|x) ==> (256,256) : joint dist of zi|x = z|x
    # logsumexp = sum across all possible pair of (m, s) for each of latent sample : from z|x -> z
    log_qzi = torch.logsumexp(mat_log_qzi, dim=1, keepdim=False) - math.log(
        batch_size * n_data)
    log_prod_qzi = log_qzi.sum(1)
    mi_zi_x = (log_q_ziCx - log_qzi).sum(dim=0) / batch_size

    # logsumexp = sum across all possible pair of (m, s) for each of latent sample => (256,10): zi|x -> zi
    # and then logsum across z_i => 256: \prod zi
    #############
    #q(z=l|x,y=k)
    mi_zi_y = torch.tensor([.0] * 11)
    if mi:

        # first = mat_qzi.sum(1) * np.log(1000)
        # second = (mat_qzi * mat_log_qzi).sum(1)
        # third = 10 * np.log(100)
        # fourth = 10 * log_qzi
        # mi_zi_y = ((first + second - third - fourth) / 1000).sum(0) / batch_size
        for k in range(10):
            # mat_log_qzi_cont = matrix_log_density(latent_sample['cont'][k*100:(k+1)*100], *(latent_dist['cont']))
            # mat_log_qzi_disc = log_density_categorical(latent_sample_disc[k*100:(k+1)*100], latent_dist['disc']).transpose(1, batch_dim + 1)
            mat_log_qzi_k = torch.cat(
                (mat_log_qzi_cont[k * 10:(k + 1) * 10, k * 10:(k + 1) * 10, :],
                 mat_log_qzi_disc[k * 10:(k + 1) * 10,
                                  k * 10:(k + 1) * 10, :]),
                dim=2)
            # mat_log_qzi_k = torch.cat((mat_log_qzi_cont[k*100:(k+1)*100, k*100:(k+1)*100, :], mat_log_qzi_disc[k*100:(k+1)*100, k*100:(k+1)*100, :]), dim=2)
            mat_qzi = torch.exp(mat_log_qzi_k)  #100,100,11
            star = mat_qzi.sum(1).mean(0)  #  100, 11(x, marginalize out) ->
            mi_zi_y += star * (np.log(10) + torch.log(star) - log_qzi.mean(0))
        mi_zi_y = mi_zi_y / batch_size
    #############

    return log_pz, log_qz, log_prod_qzi, log_q_zCx, mi_zi_y
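
Both this function and the one in the next example rely on helpers (log_density_gaussian, log_density_categorical, matrix_log_density, log_importance_weight_matrix) that are defined elsewhere. Minimal sketches of the two density helpers are given below, with signatures inferred from the call sites; these are assumptions, not the original implementations.

import math
import torch

def log_density_gaussian(x, mu, logvar):
    # Element-wise log N(x; mu, exp(logvar)); output shape (batch size, latent dim).
    normalization = -0.5 * (math.log(2 * math.pi) + logvar)
    return normalization - 0.5 * (x - mu).pow(2) / logvar.exp()

def log_density_categorical(sample, dist):
    # Log density of a relaxed one-hot sample under `dist` (e.g. an
    # ExpRelaxedCategorical instance); output shape (batch size,).
    return dist.log_prob(sample)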
Example #5
def get_log_pz_qz_prodzi_qzCx(latent_sample,
                              latent_dist,
                              n_data,
                              use_cuda,
                              is_mss=True):
    """
    Calculates log densities

    Parameters
    ----------
    latent_sample: torch.Tensor or np.ndarray or float
        Value at which to compute the density. (batch size, latent dim)

    latent_dist: torch.Tensor or np.ndarray or float
        statisitc for dist. Each of statistics has size of (batch size, latent dim).
        For guassian, latent_dist = (Mean, logVar)
        For gumbel_softmax, latent_dist = alpha(prob. of categorical variable)
    """
    batch_size, hidden_dim = latent_sample['cont'].shape

    # calculate log q(z|x)
    log_q_ziCx_cont = log_density_gaussian(latent_sample['cont'],
                                           *(latent_dist['cont']))  #64,10
    log_q_ziCx_disc = log_density_categorical(latent_sample['disc'],
                                              latent_dist['disc'])  #64

    log_q_ziCx = torch.cat((log_q_ziCx_cont, log_q_ziCx_disc.unsqueeze(-1)),
                           dim=1)  # 64,11
    log_q_zCx = log_q_ziCx.sum(
        1)  #64   sum across logP(z_i). i.e, \prod P(z_i | x_i)

    # calculate log p(z)
    zeros = torch.zeros_like(latent_sample['cont'])  # mean and log var is 0
    log_pzi_cont = log_density_gaussian(
        latent_sample['cont'], zeros,
        zeros)  # sum across logP(z_i). i.e, \prod P(z_i)

    unif_logits = torch.log(
        torch.ones_like(latent_sample['disc']) * 1 /
        latent_sample['disc'].shape[1])
    relaxedCate = ExpRelaxedCategorical(torch.tensor(.67), logits=unif_logits)

    log_pzi_dist = log_density_categorical(
        latent_sample['disc'],
        relaxedCate)  # sum across logP(z_i). i.e, \prod P(z_i)

    log_pzi = torch.cat((log_pzi_cont, log_pzi_dist.unsqueeze(-1)), dim=1)
    log_pz = log_pzi.sum(1)

    # compute log q(z) ~= log 1/(NM) sum_m=1^M q(z|x_m) = - log(MN) + logsumexp_m(q(z|x_m))
    mat_log_qzi_cont = matrix_log_density(
        latent_sample['cont'], *(latent_dist['cont'])
    )  #(256,256,10): only (n,n,10) is the result of correct pair of (latent sample, m, s).
    # (n, n, 10) --> first n = number of given samples (batch), second n = Monte Carlo samples, 10 = latent dim.
    batch_dim = 1
    latent_sample_disc = latent_sample['disc'].unsqueeze(0).unsqueeze(
        batch_dim + 1).transpose(batch_dim, 0)
    mat_log_qzi_disc = log_density_categorical(latent_sample_disc,
                                               latent_dist['disc']).transpose(
                                                   1,
                                                   batch_dim + 1)  #(64,64,1)

    mat_log_qzi = torch.cat((mat_log_qzi_cont, mat_log_qzi_disc), dim=2)
    if is_mss:
        # use stratification
        log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(
            latent_sample['cont'].device)
        mat_log_qzi = mat_log_qzi + log_iw_mat.view(batch_size, batch_size, 1)

    log_qz = torch.logsumexp(mat_log_qzi.sum(2), dim=1,
                             keepdim=False) - math.log(batch_size * n_data)
    # mat_log_qz.sum(2): sum across logP(z_i). i.e, \prod P(z_i|x) ==> (256,256) : joint dist of zi|x = z|x
    # logsumexp = sum across all possible pair of (m, s) for each of latent sample : from z|x -> z
    log_prod_qzi = (torch.logsumexp(mat_log_qzi, dim=1, keepdim=False) -
                    math.log(batch_size * n_data)).sum(1)
    # logsumexp = sum across all possible pair of (m, s) for each of latent sample => (256,10): zi|x -> zi
    # and then logsum across z_i => 256: \prod zi

    return log_pz, log_qz, log_prod_qzi, log_q_zCx
Example #6
class Model(nn.Module):
    def __init__(self, temp, latent_num, latent_dim):
        super(Model, self).__init__()
        if not isinstance(temp, torch.Tensor):
            temp = torch.tensor(temp)
        self.__temp = temp
        self.latent_num = latent_num
        self.latent_dim = latent_dim
        self.encoder = Encoder(latent_num=latent_num, latent_dim=latent_dim)
        self.decoder = Decoder(latent_num=latent_num, latent_dim=latent_dim)
        if 'ExpTDModel' in str(self.__class__):
            self.prior = ExpRelaxedCategorical(temp, probs=torch.ones(latent_dim).cuda())
        else:
            self.prior = dist.RelaxedOneHotCategorical(temp, probs=torch.ones(latent_dim).cuda())
        self.initialize()

        self.softmax = nn.Softmax(dim=-1)

    @property
    def temp(self):
        return self.__temp

    @temp.setter
    def temp(self, value):
        self.__temp = value
        if 'ExpTDModel' in str(self.__class__):
            self.prior = ExpRelaxedCategorical(value, probs=torch.ones(self.latent_dim).cuda())
        else:
            self.prior = dist.RelaxedOneHotCategorical(value, probs=torch.ones(self.latent_dim).cuda())

    def initialize(self):
        for param in self.parameters():
            if len(param.shape) > 2:
                nn.init.xavier_uniform_(param)

    def encode(self, x):
        return self.encoder(x)
    
    def decode(self, z):
        return torch.sigmoid(self.decoder(z))

    def forward(self, x):
        log_alpha = self.encode(x)
        z, v_dist = self.sample(log_alpha, self.temp)
        if 'ExpTDModel' in str(self.__class__):
            x_recon = self.decode(z.exp().view(-1, self.latent_num*self.latent_dim))
        else:
            x_recon = self.decode(z.view(-1, self.latent_num*self.latent_dim))
        return z, x_recon, v_dist

    def approximate_loss(self, x, x_recon, v_dist, eps=1e-3):
        """ KL-divergence follows Eric Jang's trick
        """
        log_alpha = v_dist.logits
        bce = F.binary_cross_entropy(x_recon, x.view(-1, 784), reduction='sum')
        num_class = torch.tensor(self.latent_dim).float()
        probs = torch.softmax(log_alpha, dim=-1) # alpha_i / alpha_sum
        kl = torch.sum(probs * (num_class * (probs + eps)).log(), dim=-1).sum()
        return bce, kl

    def loss(self, x, x_recon, z, v_dist):
        """ Monte-Carlo estimate KL-divergence
        """
        bce = F.binary_cross_entropy(x_recon, x.view(-1, 784), reduction='sum')
        n_batch = x.shape[0]
        prior = self.prior.expand(torch.Size([n_batch, self.latent_num]))
        kl = (v_dist.log_prob(z) - prior.log_prob(z)).sum()
        return bce, kl

    def sample(self, log_alpha, temp):
        raise NotImplementedError("sample() must be implemented by a subclass")
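
For reference, here is a self-contained sketch of the Monte-Carlo KL estimate computed in loss() above, using illustrative shapes only; the batch size of 64, latent_num=20, and latent_dim=10 are assumptions, not values from the original.

import torch
from torch.distributions.relaxed_categorical import ExpRelaxedCategorical

temp = torch.tensor(0.5)
n_batch, latent_num, latent_dim = 64, 20, 10

log_alpha = torch.randn(n_batch, latent_num, latent_dim)   # stands in for the encoder output
v_dist = ExpRelaxedCategorical(temp, logits=log_alpha)     # variational posterior q(z|x)
z = v_dist.rsample()                                       # reparameterized log-space sample

prior = ExpRelaxedCategorical(temp, probs=torch.ones(latent_dim))
prior = prior.expand(torch.Size([n_batch, latent_num]))    # uniform prior p(z), broadcast over the batch
kl = (v_dist.log_prob(z) - prior.log_prob(z)).sum()        # Monte-Carlo estimate of the KL term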
Example #7
    def sample(self, log_alpha, temp):
        v_dist = ExpRelaxedCategorical(temp, logits=log_alpha)
        log_concrete = v_dist.rsample()
        return log_concrete, v_dist
Example #8
    def temp(self, value):
        self.__temp = value
        if 'ExpTDModel' in str(self.__class__):
            self.prior = ExpRelaxedCategorical(value, probs=torch.ones(self.latent_dim).cuda())
        else:
            self.prior = dist.RelaxedOneHotCategorical(value, probs=torch.ones(self.latent_dim).cuda())