Example #1
0
    def sample(self, mu, logvar, grad_fn=lambda x: 1, x_info=None):
        """Draw a reparameterized latent sample and score it.

        Standard-normal noise shaped like ``mu`` is scaled by
        ``exp(0.5 * logvar)`` and shifted by ``mu``. When ``self.has_flow`` is
        set, the sample is passed through ``self.q_dist.forward`` and the
        log-probability term it returns is added to log q(z).

        Args:
            mu: mean of the diagonal Gaussian.
            logvar: log-variance of the diagonal Gaussian.
            grad_fn: forwarded to ``self.q_dist.forward``; the default ignores
                its argument and returns the constant 1.
            x_info: forwarded to ``self.q_dist.forward``.

        Returns:
            ``(z, logpz, logqz)``: the (possibly flowed) sample, its
            ``log_normal`` score under zero mean / zero log-variance, and the
            accumulated log q(z).
        """
        noise = torch.FloatTensor(mu.size()).normal_().type(self.dtype)
        eps = Variable(noise)
        std = logvar.mul(0.5).exp_()
        z = eps.mul(std).add_(mu)
        logqz = log_normal(z, mu, logvar)

        if self.has_flow:
            z, flow_logprob = self.q_dist.forward(z, grad_fn, x_info)
            logqz += flow_logprob

        prior_params = Variable(torch.zeros(z.size()).type(self.dtype))
        logpz = log_normal(z, prior_params, prior_params)
        return z, logpz, logqz
Example #2
0
    def sample(self, mu, logvar, grad_fn=lambda x: 1, x_info=None):
        """Reparameterized draw from a Gaussian with mean ``mu`` and variance
        ``exp(logvar)``, optionally pushed through ``self.q_dist``.

        Args:
            mu: mean tensor; the noise is drawn with the same size.
            logvar: log-variance tensor.
            grad_fn: forwarded to ``self.q_dist.forward``. The default returns
                the constant 1 for any input (it is not the identity); the
                intent is to not use gradient information.
            x_info: forwarded to ``self.q_dist.forward``.

        Returns:
            ``(z, logpz, logqz)``: the sample, its ``log_normal`` score under
            zero mean / zero log-variance, and log q(z) including any flow
            log-probability term.
        """
        eps = Variable(torch.randn(mu.size()).type(self.dtype))
        sigma = logvar.mul(0.5).exp()
        z = eps.mul(sigma).add(mu)
        logqz = log_normal(z, mu, logvar)

        if self.has_flow:
            z, extra_logprob = self.q_dist.forward(z, grad_fn, x_info)
            logqz += extra_logprob

        zero_params = Variable(torch.zeros(z.size()).type(self.dtype))
        logpz = log_normal(z, zero_params, zero_params)
        return z, logpz, logqz
Example #3
0
    def log_joint(x_logits, x, z):
        """Return log p(x, z) as ``log p(x|z) + log p(z)``.

        The prior term is ``log_normal(z, 0, 0)`` built with ``model.dtype``;
        the likelihood term is ``log_likelihood(x_logits, x)``.
        """
        prior_stats = Variable(torch.zeros(z.size()).type(model.dtype))
        log_prior = log_normal(z, prior_stats, prior_stats)
        log_lik = log_likelihood(x_logits, x)
        return log_lik + log_prior
Example #4
0
    def log_f_i(z, data, t, log_likelihood_fn=log_bernoulli):
        """Unnormalized density for intermediate distribution `f_i`:
            f_i = p(z)^(1-t) p(x,z)^(t) = p(z) p(x|z)^t
        =>  log f_i = log p(z) + t * log p(x|z)

        Args:
            z: latent sample. The prior's zero mean / log-variance are sized
                from ``z`` itself, so any batch size is accepted.
            data: observations passed to ``log_likelihood_fn``.
            t: weight applied to the likelihood term.
            log_likelihood_fn: called as
                ``log_likelihood_fn(model.decode(z), data)``.

        Returns:
            ``log p(z) + t * log p(x|z)``.
        """
        # Size the prior parameters from z rather than the enclosing
        # (B, z_size) so a differently-shaped z does not mismatch.
        zeros = Variable(torch.zeros(z.size()).type(mdtype))
        log_prior = log_normal(z, zeros, zeros)
        log_likelihood = log_likelihood_fn(model.decode(z), data)

        # Out-of-place multiply: avoid mutating the tensor returned by
        # log_likelihood_fn, which autograd or the caller may still hold.
        return log_prior + log_likelihood.mul(t)
Example #5
0
    def _sample(self, z0, grad_fn, x_info):
        """Transform ``z0`` with ``self.n_flows`` auxiliary-variable flow steps.

        1. A network (``self.qv_weights``) maps ``[z0, x_info]`` to the mean and
           log-variance of an auxiliary variable v0, which is sampled by
           reparameterization and scored as ``logqv0``.
        2. ``self._norm_flow`` is applied ``self.n_flows`` times to ``(z, v)``,
           summing the returned log-determinants.
        3. A reverse network (``self.rv_weights``) maps ``[zT, x_info]`` to the
           mean and log-variance of r(vT | x, zT), scored as ``logrvT``.

        Args:
            z0: initial latent sample; its first dimension is the batch size.
            grad_fn: forwarded to every ``self._norm_flow`` call.
            x_info: conditioning tensor concatenated with z along dim 1.

        Returns:
            ``(zT, logprob)`` with ``logprob = logqv0 - logdetsum - logrvT``,
            a tensor of size ``(B,)``.
        """
        B = z0.size(0)
        z_size = self.z_size
        act_func = F.elu
        qv_weights, rv_weights, params = self.qv_weights, self.rv_weights, self.params

        def _mlp_stats(weights, inp):
            # act_func after every layer except the last; the output is split
            # in half into (mean, logvar).
            out = inp
            for i in range(len(weights) - 1):
                out = act_func(weights[i](out))
            out = weights[-1](out)
            return out, out[:, :z_size], out[:, z_size:]

        # q(v0 | x, z0)
        out, mean_v0, logvar_v0 = _mlp_stats(qv_weights,
                                             torch.cat((z0, x_info), dim=1))

        eps = Variable(torch.randn(B, z_size).type(type(out.data)))
        v0 = eps.mul(logvar_v0.mul(0.5).exp_()) + mean_v0
        logqv0 = log_normal(v0, mean_v0, logvar_v0)

        zT, vT = z0, v0
        logdetsum = 0.
        for i in range(self.n_flows):
            zT, vT, logdet = self._norm_flow(params[i], zT, vT, grad_fn,
                                             x_info)
            logdetsum += logdet

        # reverse model, r(vT|x,zT)
        _, mean_vT, logvar_vT = _mlp_stats(rv_weights,
                                           torch.cat((zT, x_info), dim=1))
        logrvT = log_normal(vT, mean_vT, logvar_vT)

        assert logqv0.size() == (B, )
        # With zero flow steps logdetsum is still the float 0. and has no
        # .size(); only check its shape once a flow step produced a tensor.
        if self.n_flows > 0:
            assert logdetsum.size() == (B, )
        assert logrvT.size() == (B, )

        logprob = logqv0 - logdetsum - logrvT

        return zT, logprob
Example #6
0
    def sample(mean_v0, logvar_v0):
        """Draw z through an auxiliary variable and a flow; return ``(zT, logq)``.

        Steps:
          1. v0 is drawn by reparameterization from ``(mean_v0, logvar_v0)``.
          2. The ``qz_weights`` network maps v0 to the mean / log-variance of
             z0, which is drawn the same way.
          3. ``norm_flow`` is applied ``n_flows`` times to ``(z, v)``, summing
             the returned log-determinants.
          4. The ``rv_weights`` network maps zT to the mean / log-variance of
             the reverse model r(vT | x, zT).

        ``logq = logqz0 + logqv0 - logdetsum - logrvT``.
        """
        batch = mean_v0.size()[0]

        def draw(mean, logvar):
            # mean + eps * exp(logvar / 2), with eps standard normal.
            noise = Variable(
                torch.FloatTensor(batch, z_size).normal_().type(model.dtype))
            return noise.mul(logvar.mul(0.5).exp_()) + mean

        def mlp_stats(weights, hidden):
            # act_func after every layer but the last; split output in half.
            for idx in range(len(weights) - 1):
                hidden = act_func(weights[idx](hidden))
            hidden = weights[-1](hidden)
            return hidden[:, :z_size], hidden[:, z_size:]

        v0 = draw(mean_v0, logvar_v0)
        logqv0 = log_normal(v0, mean_v0, logvar_v0)

        mean_z0, logvar_z0 = mlp_stats(qz_weights, v0)
        z0 = draw(mean_z0, logvar_z0)
        logqz0 = log_normal(z0, mean_z0, logvar_z0)

        z_cur, v_cur = z0, v0
        logdetsum = 0.
        for step in range(n_flows):
            z_cur, v_cur, step_logdet = norm_flow(params[step], z_cur, v_cur)
            logdetsum += step_logdet

        # reverse model r(vT | x, zT)
        mean_vT, logvar_vT = mlp_stats(rv_weights, z_cur)
        logrvT = log_normal(v_cur, mean_vT, logvar_vT)

        logq = logqz0 + logqv0 - logdetsum - logrvT
        return z_cur, logq
Example #7
0
 def U(z):
     """Potential energy: -log p(x|z) - log p(z), i.e. -log p(x, z).

     ``log p(x|z)`` is ``log_bernoulli(self.decode(z), x)`` with ``self`` and
     ``x`` taken from the enclosing scope; ``log p(z)`` is ``log_normal`` with
     zero mean and zero log-variance.
     """
     from torch.autograd import Variable  # prior params must match z's kind

     logpx = log_bernoulli(self.decode(z), x)
     # Pass the prior's mean / log-variance explicitly, as the other prior
     # terms in this module do, instead of relying on log_normal(z)
     # accepting a single argument. z.data.new keeps z's dtype and device.
     zeros = Variable(z.data.new(z.size()).zero_())
     logpz = log_normal(z, zeros, zeros)
     return -logpx - logpz  # energy as -log p(x, z)
def optimize_local_gaussian(log_likelihood,
                            model,
                            data_var,
                            k=100,
                            check_every=100,
                            sentinel_thres=10,
                            debug=False,
                            lr=1e-3):
    """Fit a free diagonal-Gaussian q(z) per replicated example and score it.

    data_var should be (cuda) variable.

    ``data_var`` is replicated ``k`` times with ``safe_repeat``. A mean and a
    log-variance of size ``(B * k, model.z_size)``, initialised to zero, are
    optimised with Adam on the negative single-sample ELBO
    ``-(log p(x|z) + log p(z) - log q(z))``.

    Early stopping: every ``check_every`` steps the average loss of the last
    window is compared to the best window average. If it fails to improve
    for more than ``sentinel_thres`` consecutive windows, optimisation stops.

    Args:
        log_likelihood: called as ``log_likelihood(model.decode(z), data)``.
        model: must provide ``z_size``, ``dtype`` and ``decode``.
        data_var: batch of observations; its first dimension is B.
        k: number of replicas / importance samples per example.
        check_every: steps between early-stopping checks.
        sentinel_thres: number of non-improving windows tolerated.
        debug: if true, write progress to stderr at each check.
        lr: Adam learning rate (default matches the previous fixed value).

    Returns:
        ``(vae_elbo, iwae_elbo)``: the mean ELBO over all B * k samples, and
        the mean of ``log_mean_exp`` over the k samples of each example.
    """

    B = data_var.size()[0]
    z_size = model.z_size

    data_var = safe_repeat(data_var, k)
    zeros = Variable(torch.zeros(B * k, z_size).type(model.dtype))
    mean = Variable(torch.zeros(B * k, z_size).type(model.dtype),
                    requires_grad=True)
    logvar = Variable(torch.zeros(B * k, z_size).type(model.dtype),
                      requires_grad=True)

    optimizer = optim.Adam([mean, logvar], lr=lr)
    # Start at +inf so the first window always counts as an improvement,
    # however large the loss is (a finite magic bound could be unreachable).
    best_avg, sentinel, prev_seq = float('inf'), 0, []

    # perform local opt
    time_ = time.time()
    for epoch in range(1, 999999):

        eps = Variable(
            torch.FloatTensor(mean.size()).normal_().type(model.dtype))
        z = eps.mul(logvar.mul(0.5).exp_()).add_(mean)
        x_logits = model.decode(z)

        logpz = log_normal(z, zeros, zeros)
        logqz = log_normal(z, mean, logvar)
        logpx = log_likelihood(x_logits, data_var)

        optimizer.zero_grad()
        loss = -torch.mean(logpx + logpz - logqz)
        loss_np = loss.data.cpu().numpy()
        loss.backward()
        optimizer.step()

        prev_seq.append(loss_np)
        if epoch % check_every == 0:
            last_avg = np.mean(prev_seq)
            if debug:  # debugging helper
                sys.stderr.write(
                    'Epoch %d, time elapse %.4f, last avg %.4f, prev best %.4f\n' %
                    (epoch, time.time() - time_, -last_avg, -best_avg)
                )
            if last_avg < best_avg:
                sentinel, best_avg = 0, last_avg
            else:
                sentinel += 1
            if sentinel > sentinel_thres:
                break
            prev_seq = []
            time_ = time.time()

    # evaluation
    eps = Variable(
        torch.FloatTensor(B * k, z_size).normal_().type(model.dtype))
    z = eps.mul(logvar.mul(0.5).exp_()).add_(mean)

    logpz = log_normal(z, zeros, zeros)
    logqz = log_normal(z, mean, logvar)
    logpx = log_likelihood(model.decode(z), data_var)
    elbo = logpx + logpz - logqz

    vae_elbo = torch.mean(elbo)
    # Group the B * k ELBOs into k rows, then average in log space per example.
    iwae_elbo = torch.mean(log_mean_exp(elbo.view(k, -1).transpose(0, 1)))

    return vae_elbo.data[0], iwae_elbo.data[0]
Example #9
0
 def normalized_kinetic(v):
     """Kinetic energy as -log_normal(v, 0, 0).

     The zero mean / log-variance are sized from ``v`` rather than the
     enclosing ``(B, z_size)``, so momenta of any batch size are accepted.
     """
     zeros = Variable(torch.zeros(v.size()).type(mdtype))
     # this is superior to the unnormalized version
     return -log_normal(v, zeros, zeros)