Example #1
0
def individual_iwaes(qz_xs, px_zs, zss, x):
    """Per-decoder IWAE estimates for a multi-modal VAE.

    ``px_zs`` is a grid indexed ``[encoder][decoder]``; it is transposed so
    that each iteration handles one decoder ``d``.  For that decoder, the log
    importance weights of every encoder's samples ``zss[e]`` are concatenated
    and reduced with ``log_mean_exp``; the result is summed to a scalar.

    The proposal density is the uniform mixture of all encoders
    (``log_mean_exp`` over the stacked ``q.log_prob`` values).

    Relies on a module-level ``model`` (prior ``model.pz``) and on
    ``log_mean_exp`` being in scope.

    Returns:
        list with one scalar tensor per decoder.
    """
    per_decoder = []
    for dec_idx, decoders_for_d in enumerate(np.array(px_zs).T):
        weights = []
        for enc_idx, px_z in enumerate(decoders_for_d):
            z = zss[enc_idx]
            log_lik = px_z.log_prob(x[dec_idx]).view(
                *px_z.batch_shape[:2], -1).sum(-1)
            log_prior = model.pz(*model.pz_params).log_prob(z).sum(-1)
            log_proposal = log_mean_exp(
                torch.stack([q.log_prob(z).sum(-1) for q in qz_xs]))
            weights.append(log_lik + log_prior - log_proposal)
        per_decoder.append(log_mean_exp(torch.cat(weights)).sum())
    return per_decoder
Example #2
0
def log_bernoulli_norm_flow_marginal_estimate(recon_x_mu, x, zk, z0, z0_mu,
                                              z0_logvar, log_abs_det_jacobian):
    """Negative batch-mean importance-sampled estimate of log p(x) for a
    normalizing-flow VAE with a Bernoulli likelihood.

    Shapes follow from the ``view`` calls below: ``z0``/``zk``/``z0_mu``/
    ``z0_logvar`` are (batch, n_samples, z_dim), ``recon_x_mu`` is
    (batch, n_samples, input_dim), ``x`` is (batch, input_dim) and
    ``log_abs_det_jacobian`` is (batch, n_samples).

    The flow density is obtained by change of variables,
    log q(z_k|x) = log q(z_0|x) - log_abs_det_jacobian (assumed to be the
    summed log|det| of the flow -- confirm against the flow module).
    """
    batch_size, n_samples, z_dim = z0.size()
    input_dim = x.size(1)
    rows = batch_size * n_samples

    # One copy of x per importance sample, flattened alongside the latents.
    x_rows = x.unsqueeze(1).repeat(1, n_samples, 1).view(rows, input_dim)
    z0_rows = z0.view(rows, z_dim)
    zk_rows = zk.view(rows, z_dim)
    mu_rows = z0_mu.view(rows, z_dim)
    logvar_rows = z0_logvar.view(rows, z_dim)
    ladj_rows = log_abs_det_jacobian.view(rows)
    recon_rows = recon_x_mu.view(rows, input_dim)

    log_px_zk = bernoulli_log_pdf(x_rows, recon_rows)
    log_qzk_x = gaussian_log_pdf(z0_rows, mu_rows, logvar_rows) - ladj_rows
    log_pzk = unit_gaussian_log_pdf(zk_rows)

    log_w = (log_px_zk + log_pzk - log_qzk_x).view(batch_size, n_samples)
    return -torch.mean(log_mean_exp(log_w, dim=1))
Example #3
0
def iwae(qz_x, px_z, zs, x):
    r"""IWAE estimate for log p_\theta(x) -- fully vectorised.

    The log importance weight is
    log p(z) + llik_scaling * log p(x|z) - log q(z|x); it is reduced with
    ``log_mean_exp`` and summed.

    Relies on a module-level ``model`` (prior ``model.pz`` and
    ``model.llik_scaling``) and on ``log_mean_exp`` being in scope.
    """
    # Raw docstring: in a plain string "\theta" would be a TAB + "heta".
    lpz = model.pz(*model.pz_params).log_prob(zs).sum(-1)
    lpx_z = px_z.log_prob(x).view(*px_z.batch_shape[:2],
                                  -1) * model.llik_scaling
    lqz_x = qz_x.log_prob(zs).sum(-1)
    return log_mean_exp(lpz + lpx_z.sum(-1) - lqz_x).sum()
Example #4
0
def m_iwae(qz_xs, px_zs, zss, x):
    r"""IWAE estimate for log p_\theta(x) for multi-modal vae -- fully vectorised.

    For each encoder ``r`` the log weight of its samples ``zss[r]`` is
    log p(z) + sum_d llik_scaling_d * log p(x_d|z) - log q_mix(z), where
    q_mix is the uniform mixture of all encoders' posteriors.  Weights of
    all encoders are concatenated along dim 0 before the final
    ``log_mean_exp`` and sum.

    Relies on a module-level ``model`` and on ``log_mean_exp``.
    """
    lws = []
    for r, _ in enumerate(qz_xs):
        lpz = model.pz(*model.pz_params).log_prob(zss[r]).sum(-1)
        # ``q`` (not ``qz_x``) so the name no longer shadows the loop target.
        lqz_x = log_mean_exp(
            torch.stack([q.log_prob(zss[r]).sum(-1) for q in qz_xs]))
        lpx_z = [
            px_z.log_prob(x[d]).view(*px_z.batch_shape[:2], -1).mul(
                model.vaes[d].llik_scaling).sum(-1)
            for d, px_z in enumerate(px_zs[r])
        ]
        lpx_z = torch.stack(lpx_z).sum(0)
        lw = lpz + lpx_z - lqz_x
        lws.append(lw)
    return log_mean_exp(torch.cat(lws)).sum()
Example #5
0
def log_logistic_volume_flow_marginal_estimate(recon_x_mu, recon_x_logvar, x,
                                               zk, z0, z0_mu, z0_logvar):
    """Negative batch-mean importance-sampled estimate of log p(x) for a
    volume-preserving flow VAE with a discretised-logistic likelihood.

    Because the flow is volume preserving, no Jacobian term is subtracted:
    log q(z_k|x) is taken to equal log q(z_0|x).  The prior is evaluated at
    ``zk`` and the Gaussian posterior at ``z0``.

    Shapes follow from the ``view`` calls: latents are
    (batch, n_samples, z_dim), reconstructions (batch, n_samples, input_dim),
    ``x`` is (batch, input_dim).
    """
    batch_size, n_samples, z_dim = z0.size()
    input_dim = x.size(1)
    rows = batch_size * n_samples

    x_rows = x.unsqueeze(1).repeat(1, n_samples, 1).view(rows, input_dim)
    z0_rows = z0.view(rows, z_dim)
    zk_rows = zk.view(rows, z_dim)
    mu_rows = z0_mu.view(rows, z_dim)
    logvar_rows = z0_logvar.view(rows, z_dim)
    recon_mu_rows = recon_x_mu.view(rows, input_dim)
    recon_logvar_rows = recon_x_logvar.view(rows, input_dim)

    log_px_zk = logistic_256_log_pdf(x_rows, recon_mu_rows, recon_logvar_rows)
    # Volume preserving: log|det J| == 0, so q(z_k|x) == q(z_0|x).
    log_qzk_x = gaussian_log_pdf(z0_rows, mu_rows, logvar_rows)
    log_pzk = unit_gaussian_log_pdf(zk_rows)

    log_w = (log_px_zk + log_pzk - log_qzk_x).view(batch_size, n_samples)
    return -torch.mean(log_mean_exp(log_w, dim=1))
Example #6
0
def m_iwae(model, x, K=1):
    r"""Computes iwae estimate for log p_\theta(x) for multi-modal vae.

    ``x`` is a sequence of per-modality tensors.  Each is split into
    micro-batches of size ``compute_microbatch_split(x, K)``, the
    micro-batches of all modalities are zipped together and scored with
    ``_m_iwae``, and the resulting log-weights are concatenated on dim 1
    (the batch dim) before ``log_mean_exp`` and a final sum.
    """
    S = compute_microbatch_split(x, K)
    x_split = zip(*[_x.split(S) for _x in x])
    lw = [_m_iwae(model, _x, K) for _x in x_split]
    lw = torch.cat(lw, 1)  # concat on batch
    return log_mean_exp(lw).sum()
Example #7
0
def weighted_bernoulli_elbo_loss(recon_x_mu, x, z, z_mu, z_logvar):
    r"""Importance weighted evidence lower bound (negated, batch-averaged).

    @param recon_x_mu: torch.Tensor (batch size x # samples x |input_dim|)
                       reconstructed means on bernoulli
    @param x: torch.Tensor (batch size x |input_dim|)
                 original observed data
    @param z: torch.Tensor (batch_size x # samples x z dim)
              samples drawn from variational distribution
    @param z_mu: torch.Tensor (batch_size x # samples x z dim)
                 means of variational distribution
    @param z_logvar: torch.Tensor (batch_size x # samples x z dim)
                     log-variance of variational distribution
    @return: scalar torch.Tensor equal to
             -mean_b log (1/K) sum_k exp(log p(x|z_k) + log p(z_k) - log q(z_k|x))
    """
    n_samples = recon_x_mu.size(1)

    log_ws = []
    # ``range`` rather than the Python-2-only ``xrange`` (NameError on Py3).
    for i in range(n_samples):
        log_p_x_given_z = bernoulli_log_pdf(x, recon_x_mu[:, i])
        log_q_z_given_x = gaussian_log_pdf(z[:, i], z_mu[:, i], z_logvar[:, i])
        log_p_z = unit_gaussian_log_pdf(z[:, i])

        log_ws_i = log_p_x_given_z + log_p_z - log_q_z_given_x
        log_ws.append(log_ws_i.unsqueeze(1))

    # (batch, n_samples) -> log-mean-exp over samples -> (batch,)
    log_ws = torch.cat(log_ws, dim=1)
    log_ws = log_mean_exp(log_ws, dim=1)
    BOUND = -torch.mean(log_ws)

    return BOUND
Example #8
0
def iwae(model, x, K):
    r"""Computes an importance-weighted ELBO estimate for log p_\theta(x).

    Iterates over the batch as necessary: ``x`` is split into micro-batches
    of size ``compute_microbatch_split(x, K)``, each is scored by ``_iwae``,
    and the log-weights are concatenated on dim 1 (the batch dim) before
    ``log_mean_exp`` and a final sum.
    """
    S = compute_microbatch_split(x, K)
    lw = torch.cat([_iwae(model, _x, K) for _x in x.split(S)],
                   1)  # concat on batch
    return log_mean_exp(lw).sum()
Example #9
0
def kld_inc(pz, qz_x):
    """Monte Carlo estimate of KL(p(z) || q(z)), with q(z) built from the
    batch posteriors ``qz_x`` via ``log_mean_exp`` over dim 1.

    B prior draws are taken; each is scored under the prior and under every
    posterior in the batch.  The returned scalar is the average of
    log p(z) - log q(z) over the draws (computed through the same
    mean/expand/sum/divide sequence as before).

    NOTE(review): the ``squeeze(-1)`` and ``expand(B, B, D)`` suggest ``pz``
    has batch shape (1, D), so draws are (B, 1, D) -- confirm.
    """
    batch, dim = qz_x.loc.shape
    prior_draws = pz.rsample(torch.Size([batch]))
    log_prior = pz.log_prob(prior_draws).sum(-1).squeeze(-1)
    tiled_draws = prior_draws.expand(batch, batch, dim)
    log_agg = log_mean_exp(qz_x.log_prob(tiled_draws).sum(-1), dim=1)
    gap = log_prior - log_agg
    gap = gap.mean(0, keepdim=True).expand(1, batch)
    return gap.mean(0).sum() / batch
Example #10
0
 def get_grad(x, multiply=1):
     """Backpropagate the negative importance-weighted bound for ``x``.

     ``x`` is repeated ``multiply`` times along dim 0, scored by the
     external ``ELBO``, passed to ``reinforce`` (with the global ``iib``),
     and the per-example elbos are combined with ``utils.log_mean_exp``
     over the copies.  Gradients are accumulated via ``backward()``.

     Returns:
         the scalar loss as a numpy array.
     """
     batch = x.size(0)
     tiled = x.repeat([multiply, 1])
     elbo, q = ELBO(tiled)
     reinforce(elbo, q, idb=None, iib=iib)
     # repeat() groups rows by copy, so view(multiply, batch) then move
     # the copies to dim 1 before reducing over them.
     per_example = elbo.view(multiply, batch).permute(1, 0)
     bound = utils.log_mean_exp(per_example, 1)
     loss = (-bound).mean()
     loss.backward()
     return loss.data.cpu().numpy()
Example #11
0
def m_iwae_looser(model, x, K=1):
    r"""Computes iwae estimate for log p_\theta(x) for multi-modal vae.

    This version is the looser bound---with the average over modalities
    outside the log.  Micro-batch results from ``_m_iwae_looser`` are
    concatenated on dim 2 (batch); ``log_mean_exp`` then reduces dim 1
    (samples), ``mean(0)`` averages over modalities and the result is
    summed over the batch.
    """
    S = compute_microbatch_split(x, K)
    x_split = zip(*[_x.split(S) for _x in x])
    lw = [_m_iwae_looser(model, _x, K) for _x in x_split]
    lw = torch.cat(lw, 2)  # concat on batch
    return log_mean_exp(lw, dim=1).mean(0).sum()
Example #12
0
    def log_marginal_likelihood_estimate(self, x, num_samples, srng):
        """Importance-sampled estimate of log p(x), one value per row of x.

        Every row of ``x`` is repeated ``num_samples`` times along axis 0
        (presumably consecutive copies, like ``numpy.repeat``), posterior
        samples are drawn for the repeated batch, and their log weights are
        reshaped to (rows, num_samples) and reduced with ``log_mean_exp``
        over axis 1.
        """
        n_rows = x.shape[0]
        tiled = t_repeat(x, num_samples, axis=0)
        draws = self.q_samplesIx_srng(tiled, srng)
        weights = T.reshape(self.log_weightsIq_samples(draws),
                            (n_rows, num_samples))
        return log_mean_exp(weights, axis=1)
Example #13
0
    def log_marginal_likelihood_estimate(self, x, num_samples, srng):
        """Per-row importance-sampled log p(x).

        ``x`` is tiled ``num_samples`` times on axis 0 with ``t_repeat``,
        samples are drawn from q given the tiled batch, and the resulting
        log weights, arranged as (x.shape[0], num_samples), are averaged in
        probability space via ``log_mean_exp`` over axis 1.
        """
        count = x.shape[0]
        samples = self.q_samplesIx_srng(t_repeat(x, num_samples, axis=0), srng)
        log_weights = self.log_weightsIq_samples(samples)
        grid = T.reshape(log_weights, (count, num_samples))
        estimate = log_mean_exp(grid, axis=1)
        return estimate
Example #14
0
    def log_likelihood_estimate(self, recon_x, x_tile, Z, mu, logsig):
        """Importance-sampled log-likelihood estimate and a weighted loss.

        ``recon_x`` and ``x_tile`` are 5-D (summed over dims 2-4 below);
        dim 1 is presumably the importance-sample dim -- TODO confirm.

        Returns:
            (-lle, -loss) where ``lle`` is the batch mean of
            log_mean_exp(log_ws, dim=1) and ``loss`` contracts the
            self-normalised weights with that same quantity.
        """
        # Bernoulli log-likelihood.  NOTE(review): torch.log gives -inf if
        # recon_x reaches exactly 0 or 1.
        bce = x_tile * torch.log(recon_x) + (1. - x_tile) * torch.log(1 -
                                                                      recon_x)
        log_p_x_z = torch.sum(torch.sum(torch.sum(bce, dim=4), dim=3), dim=2)

        log_q_z_x = log_likelihood_samples_mean_sigma(Z, mu, logsig, dim=2)
        log_p_z = prior_z(Z, dim=2)

        log_ws = log_p_x_z - log_q_z_x + log_p_z
        # Max-shift before exp for numerical stability, then normalise.
        log_ws_minus_max = log_ws - torch.max(log_ws, dim=1, keepdim=True)[0]

        ws = torch.exp(log_ws_minus_max)
        normalized_ws = ws / torch.sum(ws, dim=1, keepdim=True)
        # Computed once; previously evaluated twice for ``loss`` and ``lle``.
        log_marginal = log_mean_exp(log_ws, dim=1)
        loss = torch.sum(
            torch.matmul(normalized_ws.transpose(1, 0), log_marginal))
        lle = torch.mean(torch.squeeze(log_marginal), dim=0)

        return -lle, -loss
Example #15
0
def cross_iwaes(qz_xs, px_zs, zss, x):
    """IWAE estimate for every (encoder, decoder) pair of a multi-modal VAE.

    ``px_zs[e][d]`` is decoder ``d`` applied to samples ``zss[e]`` of
    encoder ``e``.  Unlike the mixture variants, the proposal here is the
    single encoder ``qz_xs[e]``.

    Relies on a module-level ``model`` and ``log_mean_exp``.

    Returns:
        nested list ``[e][d]`` of scalar tensors.
    """
    grid = []
    for enc, decoders in enumerate(px_zs):
        z = zss[enc]
        log_prior = model.pz(*model.pz_params).log_prob(z).sum(-1)
        log_post = qz_xs[enc].log_prob(z).sum(-1)
        log_liks = [
            p.log_prob(x[dec]).view(*p.batch_shape[:2], -1).sum(-1)
            for dec, p in enumerate(decoders)
        ]
        row = [log_mean_exp(ll + log_prior - log_post).sum() for ll in log_liks]
        grid.append(row)
    return grid
Example #16
0
    def log_likelihood_estimate(self, recon_x, x, Z, mu, logsig):
        """Negative batch-mean importance-sampled log-likelihood.

        ``x`` is a 4-D batch (N, C, H, W); it is tiled to
        (N, self.num_sam, C, H, W) to line up with ``recon_x``, scored under a
        Bernoulli likelihood, combined with log q(Z|x) and log p(Z), and
        reduced with ``log_mean_exp`` over the sample dim (1).
        """
        # The unpack enforces a 4-D input; the sizes themselves are unused.
        _n, _c, _h, _w = x.shape
        x_tile = x.repeat(self.num_sam, 1, 1, 1, 1).permute(1, 0, 2, 3, 4)

        log_on = x_tile * torch.log(recon_x)
        log_off = (1. - x_tile) * torch.log(1 - recon_x)
        bce = log_on + log_off
        log_p_x_z = bce.sum(dim=4).sum(dim=3).sum(dim=2)

        log_q_z_x = log_likelihood_samples_mean_sigma(Z, mu, logsig, dim=2)
        log_p_z = prior_z(Z, dim=2)
        log_ws = log_p_x_z - log_q_z_x + log_p_z

        per_example = torch.squeeze(log_mean_exp(log_ws, dim=1))
        return -torch.mean(per_example, dim=0)
Example #17
0
File: vae.py Project: lim0606/BDMC
  def forward(self, x, k=1, warmup_const=1.):
    """Importance-weighted bound using ``k`` samples per input.

    Returns:
      (bound, mean log p(x|z), mean log p(z), mean log q(z|x)), where the
      bound is the batch mean of log_mean_exp over the k copies of
      log p(x|z) + log p(z) - warmup_const * log q(z|x).
    """
    x_rep = x.repeat(k, 1)
    mu, logvar = self.encode(x_rep)
    z, logpz, logqz = self.sample(mu, logvar)
    _, x_logits = self.decode(z)

    logpx = utils.log_bernoulli(x_logits, x_rep)
    log_w = logpx + logpz - warmup_const * logqz

    # Tensor.repeat stacks the k copies along dim 0, so view(k, -1) puts
    # one copy per row; transposing gives one input per row.
    per_input = log_w.view(k, -1).transpose(0, 1)
    bound = torch.mean(utils.log_mean_exp(per_input))

    return bound, torch.mean(logpx), torch.mean(logpz), torch.mean(logqz)
Example #18
0
def weighted_gaussian_elbo_loss(recon_x_mu, recon_x_logvar, x, z, z_mu,
                                z_logvar):
    """Negated, batch-averaged importance weighted evidence lower bound.

    NOTE: despite the name, the likelihood used is ``logistic_256_log_pdf``
    (a discretised logistic), parameterised by ``recon_x_mu`` and
    ``recon_x_logvar``.  Dim 1 of the reconstruction/latent tensors indexes
    the importance samples.
    """
    n_samples = recon_x_mu.size(1)

    log_ws = []
    # ``range`` rather than the Python-2-only ``xrange`` (NameError on Py3).
    for i in range(n_samples):
        log_p_x_given_z = logistic_256_log_pdf(x, recon_x_mu[:, i],
                                               recon_x_logvar[:, i])
        log_q_z_given_x = gaussian_log_pdf(z[:, i], z_mu[:, i], z_logvar[:, i])
        log_p_z = unit_gaussian_log_pdf(z[:, i])

        log_ws_i = log_p_x_given_z + log_p_z - log_q_z_given_x
        log_ws.append(log_ws_i.unsqueeze(1))

    log_ws = torch.cat(log_ws, dim=1)
    log_ws = log_mean_exp(log_ws, dim=1)
    BOUND = -torch.mean(log_ws)

    return BOUND
Example #19
0
def _m_dreg(model, x, K=1):
    """DERG estimate for log p_\theta(x) for multi-modal vae -- fully vectorised"""
    qz_xs, px_zs, zss = model(x, K)
    qz_xs_ = [
        vae.qz_x(*[p.detach() for p in vae.qz_x_params]) for vae in model.vaes
    ]
    lws = []
    for r, vae in enumerate(model.vaes):
        lpz = model.pz(*model.pz_params).log_prob(zss[r]).sum(-1)
        lqz_x = log_mean_exp(
            torch.stack([qz_x_.log_prob(zss[r]).sum(-1) for qz_x_ in qz_xs_]))
        lpx_z = [
            px_z.log_prob(x[d]).view(*px_z.batch_shape[:2], -1).mul(
                model.vaes[d].llik_scaling).sum(-1)
            for d, px_z in enumerate(px_zs[r])
        ]
        lpx_z = torch.stack(lpx_z).sum(0)
        lw = lpz + lpx_z - lqz_x
        lws.append(lw)
    return torch.cat(lws), torch.cat(zss)
    def get_label_marginal(self, y, n_samples=100):
        """Negative batch-mean importance-sampled estimate of log p(y).

        Draws ``n_samples`` latents from q(z|y) (``self.inference`` called
        with only the label) and averages the weights
        p(y|z) p(z) / q(z|y) in log space.

        NOTE: despite the name, the returned value is *negated* (a loss).
        """
        z_mu, z_logvar = self.inference(None, None, y)

        log_w = []
        # ``range`` rather than the Python-2-only ``xrange`` (NameError on Py3).
        for _ in range(n_samples):
            z_i = self.reparameterize(z_mu, z_logvar)
            y_out_i = self.label_decoder(z_i)

            log_p_y_given_z_i = bernoulli_log_pdf(y, y_out_i)
            log_q_z_given_y_i = gaussian_log_pdf(z_i, z_mu, z_logvar)
            log_p_z_i = unit_gaussian_log_pdf(z_i)

            log_w_i = log_p_y_given_z_i + log_p_z_i - log_q_z_given_y_i
            log_w.append(log_w_i.unsqueeze(1))

        log_w = torch.cat(log_w, dim=1)
        log_p_y = log_mean_exp(log_w, dim=1)
        log_p_y = -torch.mean(log_p_y)

        return log_p_y
Example #21
0
def _m_iwae_looser(model, x, K=1):
    """IWAE estimate for log p_\theta(x) for multi-modal vae -- fully vectorised
    This version is the looser bound---with the average over modalities outside the log
    """
    qz_xs, px_zs, zss = model(x, K)
    lws = []
    for r, qz_x in enumerate(qz_xs):
        lpz = model.pz(*model.pz_params).log_prob(zss[r]).sum(-1)
        lqz_x = log_mean_exp(
            torch.stack([qz_x.log_prob(zss[r]).sum(-1) for qz_x in qz_xs]))
        lpx_z = [
            px_z.log_prob(x[d]).view(*px_z.batch_shape[:2], -1).mul(
                model.vaes[d].llik_scaling).sum(-1)
            for d, px_z in enumerate(px_zs[r])
        ]
        lpx_z = torch.stack(lpx_z).sum(0)
        lw = lpz + lpx_z - lqz_x
        lws.append(lw)
    return torch.stack(
        lws)  # (n_modality * n_samples) x batch_size, batch_size
Example #22
0
    def log_Z(self, n_betas=100, n_runs=100, n_gibbs_steps=5):
        """Estimate the log partition function with Annealed Importance Sampling.

        Only implemented for a 2-layer binary Boltzmann machine.  AIS runs on
        the state space x = {h_1}, with v and h_2 summed out analytically.
        For a reasonable estimate ``n_betas`` should be 10000 or more.

        Parameters
        ----------
        n_betas : >1 int
            Number of intermediate distributions (step size 1 / n_betas).
        n_runs : positive int
            Number of AIS runs.
        n_gibbs_steps : positive int
            Number of Gibbs steps per transition.

        Returns
        -------
        log_mean, (log_low, log_high) : float
            `log_mean` = log(Z_mean)
            `log_low`  = log(Z_mean - std(Z))
            `log_high` = log(Z_mean + std(Z))
        values : (`n_runs`,) np.ndarray
            All per-run estimates.
        """
        assert self.n_layers_ == 2
        for layer in [self._v_layer] + self._h_layers:
            assert isinstance(layer, BernoulliLayer)

        self._log_Z = tf.get_collection('log_Z')[0]
        feed = self._make_tf_feed_dict(delta_beta=1. / n_betas,
                                       n_ais_runs=n_runs,
                                       n_gibbs_steps=n_gibbs_steps)
        values = self._tf_session.run(self._log_Z, feed_dict=feed)

        # Mean and std of Z computed in log space from the per-run log Z's.
        log_mean = log_mean_exp(values)
        log_std = log_std_exp(values, log_mean_exp_x=log_mean)
        log_high = log_sum_exp([log_std, log_mean])
        log_low = log_diff_exp([log_std, log_mean])[0]
        return log_mean, (log_low, log_high), values
    def get_program_marginal(self, seq, length, n_samples=100):
        """Negative batch-mean importance-sampled estimate of log p(seq).

        Latents are drawn from q(z|seq) (``self.inference`` with no label).
        The sequence likelihood scores ``seq[:, 1:]`` against the decoder
        logits ``[:, :-1]`` (next-token prediction).

        NOTE: the returned value is *negated* (a loss).
        """
        z_mu, z_logvar = self.inference(seq, length, None)

        log_w = []
        # ``range`` rather than the Python-2-only ``xrange`` (NameError on Py3).
        for _ in range(n_samples):
            z_i = self.reparameterize(z_mu, z_logvar)
            seq_logits_i = self.program_decoder(z_i, seq, length)

            # probability of text is product of probabilities of each word
            log_p_x_given_z_i = categorical_program_log_pdf(seq[:, 1:], seq_logits_i[:, :-1])
            log_q_z_given_x_i = gaussian_log_pdf(z_i, z_mu, z_logvar)
            log_p_z_i = unit_gaussian_log_pdf(z_i)

            log_w_i = log_p_x_given_z_i + log_p_z_i - log_q_z_given_x_i
            log_w.append(log_w_i.unsqueeze(1))

        log_w = torch.cat(log_w, dim=1)
        log_p_x = log_mean_exp(log_w, dim=1)
        log_p_x = -torch.mean(log_p_x)

        return log_p_x
    def get_joint_marginal(self, seq, length, label, n_samples=100):
        """Negative batch-mean importance-sampled estimate of log p(seq, label).

        Latents are drawn from q(z|seq, label); the weight combines the
        program likelihood (``seq[:, 1:]`` vs logits ``[:, :-1]``), the
        Bernoulli label likelihood, the unit-Gaussian prior and the
        Gaussian posterior.

        NOTE: the returned value is *negated* (a loss).
        """
        z_mu, z_logvar = self.inference(seq, length, label)

        log_w = []
        # ``range`` rather than the Python-2-only ``xrange`` (NameError on Py3).
        for _ in range(n_samples):
            z_i = self.reparameterize(z_mu, z_logvar)
            x_logits_i = self.program_decoder(z_i, seq, length)
            y_out_i = self.label_decoder(z_i)

            log_p_x_given_z_i = categorical_program_log_pdf(seq[:, 1:], x_logits_i[:, :-1])
            log_p_y_given_z_i = bernoulli_log_pdf(label, y_out_i)
            log_q_z_given_x_y_i = gaussian_log_pdf(z_i, z_mu, z_logvar)
            log_p_z_i = unit_gaussian_log_pdf(z_i)

            log_w_i = log_p_x_given_z_i + log_p_y_given_z_i + log_p_z_i - log_q_z_given_x_y_i
            log_w.append(log_w_i.unsqueeze(1))

        log_w = torch.cat(log_w, dim=1)
        log_p_x_y = log_mean_exp(log_w, dim=1)
        log_p_x_y = -torch.mean(log_p_x_y)

        return log_p_x_y
Example #25
0
def log_bernoulli_marginal_estimate(recon_x_mu, x, z, z_mu, z_logvar):
    r"""Negative batch-mean importance-sampled estimate of log p(x).

    NOTE: this is not the objective that should be directly optimized.

    @param recon_x_mu: torch.Tensor (batch size x # samples x input_dim)
                       reconstructed means on bernoulli
    @param x: torch.Tensor (batch size x input_dim)
              original observed data
    @param z: torch.Tensor (batch_size x # samples x z dim)
              samples drawn from variational distribution
    @param z_mu: torch.Tensor (batch_size x # samples x z dim)
                 means of variational distribution
    @param z_logvar: torch.Tensor (batch_size x # samples x z dim)
                     log-variance of variational distribution
    """
    batch_size, n_samples, z_dim = z.size()
    input_dim = x.size(1)
    rows = batch_size * n_samples

    x_rows = x.unsqueeze(1).repeat(1, n_samples, 1).view(rows, input_dim)
    z_rows = z.view(rows, z_dim)
    mu_rows = z_mu.view(rows, z_dim)
    logvar_rows = z_logvar.view(rows, z_dim)
    recon_rows = recon_x_mu.view(rows, input_dim)

    log_px_z = bernoulli_log_pdf(x_rows, recon_rows)
    log_qz_x = gaussian_log_pdf(z_rows, mu_rows, logvar_rows)
    log_pz = unit_gaussian_log_pdf(z_rows)

    log_weight = (log_px_z + log_pz - log_qz_x).view(batch_size, n_samples)
    # log(mean(exp(log_weight))) over samples normalises the weights.
    return -torch.mean(log_mean_exp(log_weight, dim=1))
Example #26
0
def objective(vae_model,
              c_model,
              device,
              x,
              mask,
              types,
              label,
              beta,
              nu,
              K=10,
              kappa=1.0,
              components=False):
    r"""Computes E_{p(x)}[ELBO_{\alpha,\beta}] combined with a supervised loss.

    Per example, the objective (to be minimised) is
        sum_types BCE(pred, multi_hot) - log p(x|z) + mean KL(q(z|x)||p(z))
        + kappa * KL(p(z)||q(z)) + sum_i nu_i * KL(q(z|x)||N_i)
    where q(z) is the batch aggregate posterior and N_i is the Normal built
    from ``vae_model.pz_params[0][i]`` / ``vae_model.pz_params[1][i]``.

    Args:
        vae_model: VAE returning ``(qz_x, px_z, zs)`` for ``(x, K)``.
        c_model: classifier called as ``c_model(zs.squeeze(0), device)``.
        device: device for the per-type KL term.
        x: input batch.
        mask: per-example type indices (``mask == i`` selects examples of type i).
        types: ignored -- overwritten with ``vae_model.disease_types``.
        label, beta: unused.
        nu: per-type weights for the per-type KL term.
        K: number of posterior samples.
        kappa: weight of the KL(p(z)||q(z)) term.
        components: if true, also return the individual terms.

    Returns:
        ``obj.mean()``, or with ``components`` a tuple of the mean objective,
        its terms, and numpy copies of predictions and multi-hot labels.
    """
    # NOTE(review): the ``types`` argument is intentionally overwritten to
    # keep existing behaviour; callers' value is ignored.
    types = vae_model.disease_types
    qz_x, px_z, zs = vae_model(x, K)

    # --- supervised multi-label loss ---------------------------------------
    pred = c_model(zs.squeeze(0), device)
    bce = torch.nn.BCELoss(reduction='none')
    # Build the target on pred's device; a CPU target against a GPU pred
    # makes BCELoss fail with a device mismatch.
    onehot_label = torch.zeros((x.size(0), types),
                               dtype=torch.float32,
                               device=pred.device)
    for i in range(len(x)):
        for idx in mask[i]:
            onehot_label[i][idx] = 1
    supervised_loss = bce(pred, onehot_label)

    # --- reconstruction and KL(q(z|x) || p(z)) -----------------------------
    lpx_z = px_z.log_prob(x).view(*px_z.batch_shape[:2], -1).sum(-1)
    pz = vae_model.pz(*vae_model.pz_params)
    kld = kl_divergence(qz_x, pz, samples=zs).sum(-1)

    # --- KL(p(z) || q(z)), q(z) = aggregate posterior over the batch -------
    B, D = qz_x.loc.shape
    _zs = pz.rsample(torch.Size([B]))
    lpz = pz.log_prob(_zs).sum(-1).squeeze(-1)

    _zs_expand = _zs.expand(B, B, D)
    lqz = qz_x.log_prob(_zs_expand).sum(-1)  # B x B
    lqz = log_mean_exp(lqz, dim=1)

    inc_kld = lpz - lqz
    inc_kld = inc_kld.mean(0, keepdim=True).expand(1, B)
    inc_kld = inc_kld.mean(0).sum() / B

    # --- per-type KL(q(z|x) || N_i) ----------------------------------------
    _zs = qz_x.rsample(torch.Size([B]))  # B x B x D (samples x batch x D)
    lqz = qz_x.log_prob(_zs)  # B x B x D
    kld2 = []
    for i in range(types):
        tmp_mask = torch.sum(mask == i, dim=-1).bool()
        if torch.sum(tmp_mask) == 0:
            continue
        tmp_zs = _zs[:, tmp_mask, :]  # B x k x D
        lnj = dist.Normal(vae_model.pz_params[0][i],
                          vae_model.pz_params[1][i]).log_prob(tmp_zs)
        # Monte Carlo KL uses log q(z) - log N_i(z); subtracting the raw
        # samples (tmp_zs) would not estimate a divergence.
        _kl = (lqz[:, tmp_mask, :] - lnj).mean(0)  # k x D
        kld2.append((_kl * nu[i]).mean(0))
    if kld2:
        kld2 = torch.sum(torch.stack(kld2).to(device))
    else:
        # No type present in ``mask``: torch.stack([]) would raise.
        kld2 = torch.zeros((), device=device)

    obj = supervised_loss.sum(
        dim=-1) - lpx_z + kld.mean() + kappa * inc_kld + kld2
    return obj.mean() if not components else (
        obj.mean(), supervised_loss.mean(dim=0), lpx_z.mean(), kld.mean(),
        inc_kld.mean(), kld2.mean(), pred.cpu().detach().numpy(),
        onehot_label.cpu().detach().numpy())
Example #27
0
File: gqn.py Project: soudia/snp
    def evaluate(self, C, XY, A=None, params={}):
        """Compute evaluation metrics for a batch of episodes.

        With ``n_timesteps = len(C)`` and ``n_episodes = len(C[0])`` this
        collects, over ``K`` posterior samples:
          * the average KL and reconstruction NLL, overall and per step
            (their sum is reported under "elbo" / "elbo_t");
          * an importance-weighted NLL ("iwae_nll", "iwae_nll_t") from the
            recorded per-sample log p(y|x,z) and log p(z)/q(z) terms;
          * when ``eval_mse`` is true, a generation NLL ("gen_nll",
            "gen_nll_t") from samples of ``self.convdraw.generate``.
        All recorded tensors are detached.

        Args:
            C: context, indexed ``C[t][b][IMAGES]`` / ``C[t][b][QUERIES]``.
            XY: targets, indexed the same way.
            A: actions ``A[t][b]``; only read when ``self.n_actions > 0``.
            params: optional overrides for "beta", "std", "K", "eval_nll"
                and "eval_mse".  It is only read, so the shared ``{}``
                default is not mutated.

        Returns:
            dict with "info_scalar" (floats) and "info_temporal"
            (per-timestep lists).

        NOTE(review): "beta" is parsed but never used in this method.
        """

        # parse params
        beta = params["beta"] if "beta" in params else 1.0
        std = params["std"] if "std" in params else 1.0
        K = params["K"] if "K" in params else 50
        eval_nll = params["eval_nll"] if "eval_nll" in params else True
        eval_mse = params["eval_mse"] if "eval_mse" in params else True

        # init prelims
        n_episodes = len(C[0])
        n_timesteps = len(C)

        # loss adders (zero tensors created on the model's device/dtype)
        info_kl_adder = NormalizedAdder(next(self.parameters()).new_zeros(1))
        info_kl_t_adder = NormalizedAdderList(
            next(self.parameters()).new_zeros(1), n_timesteps)

        info_recon_nll_adder = NormalizedAdder(
            next(self.parameters()).new_zeros(1))
        info_recon_nll_t_adder = NormalizedAdderList(
            next(self.parameters()).new_zeros(1), n_timesteps)

        info_gen_nll_adder = NormalizedAdder(
            next(self.parameters()).new_zeros(1))
        info_gen_nll_t_adder = NormalizedAdderList(
            next(self.parameters()).new_zeros(1), n_timesteps)

        # Per-sample log terms for the IWAE estimate, (n_timesteps, n_episodes, K).
        iwae_info = {
            "log_p_y_giv_xz__per_t":
            next(self.parameters()).new_zeros((n_timesteps, n_episodes, K)),
            "log_pz/qz__per_t":
            next(self.parameters()).new_zeros((n_timesteps, n_episodes, K)),
        }

        # init reps
        AA = []

        # expand actions: real actions if the model uses them, otherwise a
        # scaled timestep index as a stand-in action.
        for t in range(n_timesteps):
            if self.n_actions > 0:
                a_t = torch.cat([a_t_b.unsqueeze(0) for a_t_b in A[t]], dim=0)
            else:
                a_t = next(self.parameters()).new_ones((n_episodes, 1)) * (
                    t / 50.
                )  # hardcode 50. to have same scale for different data sets
            AA += [a_t]
        if self.n_actions > 0:
            AA = self.action_encoder(AA)

        # append action to C (on a copy, so the caller's C is untouched)
        CA = recursive_clone_structure(C)
        for t in range(len(CA)):
            for b in range(len(CA[t])):
                if CA[t][b][QUERIES] is not None:
                    n_X_t_b = len(CA[t][b][QUERIES])
                    CA[t][b][QUERIES] = torch.cat([
                        CA[t][b][QUERIES], AA[t][b].unsqueeze(0).repeat(
                            n_X_t_b, 1)
                    ],
                                                  dim=1)

        # append action to X (on a copy of XY)
        XYA = recursive_clone_structure(XY)
        for t in range(len(XYA)):
            for b in range(len(XYA[t])):
                n_X_t_b = len(XYA[t][b][QUERIES])
                XYA[t][b][QUERIES] = torch.cat([
                    XYA[t][b][QUERIES], AA[t][b].unsqueeze(0).repeat(
                        n_X_t_b, 1)
                ],
                                               dim=1)

        # compute representations
        # R_CA[t]: cumulative (running-sum) context representation up to t.
        # R_CXYA: sum of all context and target representations (for inference).
        R_CA = []
        R_CXYA = 0
        for t in range(n_timesteps):
            # compute context representation and record cumulative context
            # NOTE(review): ``b`` is left over from the XYA loop above (the
            # last episode index of the final timestep), so only that
            # episode's entries are checked here -- confirm this is intended.
            R_CA_t = None
            if CA[t][b][IMAGES] is not None and CA[t][b][QUERIES] is not None:
                R_CA_t = self.repnet(CA[t])

            if R_CA_t is not None:
                R_CA = R_CA + [R_CA[t - 1] + R_CA_t] if len(R_CA) > 0 else [
                    R_CA_t
                ]
            else:
                # NOTE(review): IndexError if t == 0 has no context (R_CA empty).
                R_CA = R_CA + [R_CA[t - 1]]

            # compute target representation (same leaked ``b`` as above)
            R_XYA_t = None
            if XYA[t][b][IMAGES] is not None and XYA[t][b][QUERIES] is not None:
                R_XYA_t = self.repnet(XYA[t])

            # add context and target for inference
            if R_CA_t is not None:
                R_CXYA += R_CA_t
            if R_XYA_t is not None:
                R_CXYA += R_XYA_t

        # K posterior samples: record KL, recon NLL and the per-sample IWAE terms.
        if eval_nll:
            for k in range(K):
                SS = {
                    "gqn": {},
                }
                for t in range(n_timesteps):
                    response = self.convdraw(R_CA[t], R_CXYA)

                    # record kl
                    info_kl_adder.append(response["kl"].detach())
                    info_kl_t_adder[t].append(response["kl"].detach())

                    # record log p/q
                    iwae_info["log_pz/qz__per_t"][
                        t, :, k] = response["log_pz/qz_batchwise"].detach()

                    # record states
                    SS["gqn"][t] = {'z_t': response["z_t"]}

                # emission
                emission_gqn = self.emission(SS["gqn"], XYA, std=std)

                # record recon nll
                info_recon_nll_adder.append(emission_gqn["nll"].detach())
                info_recon_nll_t_adder.append_list([
                    emission_gqn["nll_per_t"][t].detach()
                    for t in emission_gqn["nll_per_t"]
                ])

                # record log p_y_giv_xz (negated per-episode NLL)
                for t in emission_gqn["nll_per_t_per_b"]:
                    for b in emission_gqn["nll_per_t_per_b"][t]:
                        iwae_info["log_p_y_giv_xz__per_t"][
                            t, b,
                            k] = -emission_gqn["nll_per_t_per_b"][t][b].detach(
                            )

        # basic metrics ("elbo" here is kl + recon_nll, i.e. a negative ELBO)
        ret = {
            "info_scalar": {
                "kl":
                info_kl_adder.mean().detach().item(),
                "recon_nll":
                info_recon_nll_adder.mean().detach().item(),
                "elbo":
                info_kl_adder.mean().detach().item() +
                info_recon_nll_adder.mean().detach().item(),
            },
            "info_temporal": {
                "kl_t":
                [item.detach().item() for item in info_kl_t_adder.mean_list()],
                "recon_nll_t": [
                    item.detach().item()
                    for item in info_recon_nll_t_adder.mean_list()
                ],
            },
        }

        # elbo per timestep
        ret["info_temporal"]["elbo_t"] = [
            (ret["info_temporal"]["kl_t"][t] +
             ret["info_temporal"]["recon_nll_t"][t])
            for t in range(n_timesteps)
        ]

        # iwae nll: log terms are summed over timesteps (dim 0) giving
        # (n_episodes, K); log_mean_exp is assumed to reduce the K dim by
        # default -- TODO confirm.
        ret["info_scalar"]["iwae_nll"] = -torch.mean(
            log_mean_exp(
                torch.sum(iwae_info["log_p_y_giv_xz__per_t"], dim=0) + torch.
                sum(iwae_info["log_pz/qz__per_t"], dim=0))).detach().item()
        ret["info_temporal"]["iwae_nll_t"] = [
            -torch.mean(
                log_mean_exp(
                    iwae_info["log_p_y_giv_xz__per_t"][t] +
                    iwae_info["log_pz/qz__per_t"][t])).detach().item()
            for t in range(n_timesteps)
        ]

        # gen nll: decode from prior samples (context only, no targets)
        if eval_mse:
            for k in range(K):
                SS = {
                    "gqn": {},
                }

                for t in range(n_timesteps):
                    response = self.convdraw.generate(R_CA[t])
                    SS["gqn"][t] = {'z_t': response["z_t"]}

                # emission
                emission_gqn = self.emission(SS["gqn"], XYA, std=std)

                # record nll
                info_gen_nll_adder.append(emission_gqn["nll"].detach())
                info_gen_nll_t_adder.append_list([
                    emission_gqn["nll_per_t"][t].detach()
                    for t in emission_gqn["nll_per_t"]
                ])

        ret["info_scalar"]["gen_nll"] = info_gen_nll_adder.mean().detach(
        ).item()
        ret["info_temporal"]["gen_nll_t"] = [
            item.detach().item() for item in info_gen_nll_t_adder.mean_list()
        ]

        return ret
Example #28
0
def _summary_images(images, args, channels, clamp=True):
    """Return the first ``args.log.num_summary_img`` images for a summary grid.

    The tensor is moved to CPU and detached, optionally clamped to [0, 1],
    and reshaped to ``(-1, channels, args.data.img_h, args.data.img_w)``.
    """
    images = images.cpu().detach()[:args.log.num_summary_img]
    if clamp:
        images = images.clamp(0, 1)
    return images.view(-1, channels, args.data.img_h, args.data.img_w)


def _add_grid(writer, tag, images, args, global_step):
    """Tile ``images`` into a single grid and write it to ``writer`` under ``tag``."""
    grid_image = make_grid(images,
                           args.log.num_img_per_row,
                           normalize=False,
                           pad_value=1)
    writer.add_image(tag, grid_image, global_step)


def _bbox_overlay(imgs, log, bs, args, phase_only_display_pres):
    """Paint the boxes produced by ``visualize`` over the first summary images.

    Pixels where the summed box intensity exceeds 0.5 are replaced by the box
    colour; all other pixels keep the input image. ``expand(-1, 3, -1, -1)``
    requires ``imgs`` to have 1 or 3 channels.
    """
    n = args.log.num_summary_img
    num_cells = args.arch.num_cell**2

    def _latent(key):
        # (bs, num_cell**2, -1) view of one latent, restricted to the summary images
        return log[key].view(bs, num_cells, -1)[:n].cpu().detach()

    bbox = visualize(imgs[:n].cpu(),
                     _latent('z_pres'),
                     _latent('z_where_scale'),
                     _latent('z_where_shift'),
                     only_bbox=True,
                     phase_only_display_pres=phase_only_display_pres)
    bbox = bbox.view(n, -1, 3, args.data.img_h,
                     args.data.img_w).sum(1).clamp(0.0, 1.0)

    overlay = imgs[:n].cpu().expand(-1, 3, -1, -1).contiguous()
    mask = bbox.sum(dim=1, keepdim=True).expand(-1, 3, -1, -1) > 0.5
    overlay[mask] = bbox[mask]
    return overlay


def _log_histograms(writer, model, log, global_step):
    """Write parameter/gradient histograms and histograms of the model's ``log`` values."""
    for name, param in model.named_parameters():
        writer.add_histogram('param/' + name,
                             param.cpu().detach().numpy(),
                             global_step)
        if param.grad is None:
            continue
        grad = param.grad.cpu().detach()
        writer.add_histogram('grad/' + name, grad, global_step)
        # the gradient std is only logged for parameters that are not 1-D
        if len(param.size()) != 1:
            writer.add_scalar('grad_std/' + name + '.grad',
                              grad.std().item(), global_step)
        writer.add_scalar('grad_mean/' + name + '.grad',
                          grad.mean().item(), global_step)

    for key, value in log.items():
        if value is None:
            continue
        if key == 'importance_map_full_res_norm':
            # only the strictly positive entries are histogrammed for this key
            value = value[value > 0]
        writer.add_histogram('inside_value/' + key,
                             value.cpu().detach().numpy(),
                             global_step)


def _log_train_images(writer, imgs, pa_recon, log, args, global_step):
    """Write input, reconstruction, bounding-box and q(g)-reconstruction grids."""
    channels = args.data.inp_channel
    bs = imgs.size(0)

    def add(tag, images):
        _add_grid(writer, tag, images, args, global_step)

    add('train/1-image', _summary_images(imgs, args, channels, clamp=False))
    add('train/2-reconstruction_overall',
        _summary_images(pa_recon[0], args, channels))
    if args.arch.phase_background:
        add('train/3-reconstruction-fg',
            _summary_images(pa_recon[1], args, channels))
        add('train/4-reconstruction-alpha',
            _summary_images(pa_recon[2], args, 1))
        add('train/5-reconstruction-bg',
            _summary_images(pa_recon[3], args, channels))

    add('train/6-bbox', _bbox_overlay(imgs, log, bs, args, False))
    add('train/6a-bbox-white', _bbox_overlay(imgs, log, bs, args, True))

    add('train/7-reconstruction_from_q_g',
        _summary_images(log['recon_from_q_g'], args, channels))
    if args.arch.phase_background:
        add('train/8-recon_from_q_g-fg',
            _summary_images(log['recon_from_q_g_fg'], args, channels))
        add('train/9-recon_from_q_g-alpha',
            _summary_images(log['recon_from_q_g_alpha'], args, 1))
        add('train/a-background_from_q_g',
            _summary_images(log['recon_from_q_g_bg'], args, channels))


def _log_generation(writer, model, num_gpu, args, global_step):
    """Draw one sample from the model in eval mode and write its image grids."""
    with torch.no_grad():
        model.eval()
        if num_gpu > 1:
            sample = model.module.sample()[0]
        else:
            sample = model.sample()[0]
        model.train()

    tags = ['generation/1-image']
    if args.arch.phase_background:
        tags += ['generation/2-fg', 'generation/3-alpha', 'generation/4-bg']
    for idx, tag in enumerate(tags):
        _add_grid(writer, tag, sample[idx].cpu().detach().clamp(0, 1), args,
                  global_step)


def _estimate_nll(model, loader, args, device):
    """Estimate log-likelihood, ELBO and KL for every batch of ``loader``.

    The log-likelihood combines ``args.log.nll_num_sample`` forward passes per
    batch with ``log_mean_exp`` (an importance-weighted estimate). The ELBO and
    KL use only the first of those passes.

    Returns:
        (ll_all, elbo_all, kl_all): CPU tensors concatenated over batches
        along dim 0.
    """
    elbo_list, kl_list, ll_list = [], [], []
    with torch.no_grad():
        args.log.phase_log = False
        for sample in loader:
            imgs = sample.to(device)

            ll_sample_list = []
            for i in range(args.log.nll_num_sample):
                _, log_like, kl, log_imp, _, _, _ = model(imgs)
                (_, _, _, _, _, kl_pres, kl_where, kl_depth, kl_what,
                 kl_global_all, kl_bg) = kl
                (log_imp_pres, log_imp_depth, log_imp_what, log_imp_where,
                 log_imp_bg, log_imp_g) = log_imp

                ll_sample_list.append(
                    (log_like + log_imp_pres + log_imp_depth + log_imp_what +
                     log_imp_where + log_imp_bg + log_imp_g).cpu())
                # Only use one sample for the ELBO / KL.
                if i == 0:
                    kl_global = kl_global_all.sum(dim=1)
                    elbo_list.append((log_like - kl_pres - kl_where -
                                      kl_depth - kl_what - kl_bg -
                                      kl_global).cpu())
                    kl_list.append((kl_pres + kl_where + kl_depth + kl_what +
                                    kl_bg + kl_global).cpu())
            ll_list.append(
                log_mean_exp(torch.stack(ll_sample_list, dim=1), dim=1))

        return (torch.cat(ll_list, dim=0), torch.cat(elbo_list, dim=0),
                torch.cat(kl_list, dim=0))


def _log_nll(model, loader, split, epoch, writer, args, device):
    """Evaluate ``model`` on ``loader`` and log ``{split}/ll|elbo|kl`` at ``epoch``."""
    print(f'{split} nll at the end of epoch {epoch}')

    model.eval()
    args.log.phase_nll = True

    ll_all, elbo_all, kl_all = _estimate_nll(model, loader, args, device)
    writer.add_scalar(f'{split}/ll', ll_all.mean(0).item(), global_step=epoch)
    writer.add_scalar(f'{split}/elbo',
                      elbo_all.mean(0).item(),
                      global_step=epoch)
    writer.add_scalar(f'{split}/kl', kl_all.mean(0).item(), global_step=epoch)

    args.log.phase_nll = False
    model.train()


def main():
    """Train GNM with the configuration returned by ``get_config()``.

    Steps:
    * Seed torch, CUDA and numpy, create the model/summary directories and
      dump the config to ``config.yaml``.
    * Build train/test/val loaders for the ``mnist`` or ``blender`` dataset.
    * Optimise ``-(log_like - beta-weighted KLs)`` with RMSprop, clipping the
      gradient norm at ``args.train.cp``.
    * Log to TensorBoard every ``args.log.print_step_freq`` steps.
    * Estimate the validation NLL every ``args.log.compute_nll_freq`` epochs,
      and the test NLL every ten times that.
    * Save checkpoints every ``args.log.save_epoch_freq`` epochs (except
      epoch 0) and once more after training.

    Raises:
        NotImplementedError: if ``args.data.dataset`` is not supported.
    """
    # Debugging aids (disabled):
    #   torch.autograd.set_detect_anomaly(True)
    #   os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    torch.backends.cudnn.benchmark = True

    args = get_config()[0]

    torch.manual_seed(args.train.seed)
    torch.cuda.manual_seed(args.train.seed)
    torch.cuda.manual_seed_all(args.train.seed)
    np.random.seed(args.train.seed)

    model_dir = os.path.join(args.model_dir, args.exp_name)
    summary_dir = os.path.join(args.summary_dir, args.exp_name)
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(summary_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.train.num_gpu = torch.cuda.device_count()
    with open(os.path.join(summary_dir, 'config.yaml'), 'w') as f:
        yaml.dump(args, f)

    if args.data.dataset == 'mnist':
        dataset_cls = MultiMNIST
    elif args.data.dataset == 'blender':
        dataset_cls = Blender
    else:
        # `NotImplemented` is a constant, not an exception: raising it only
        # produced a confusing TypeError.
        raise NotImplementedError(
            f'unsupported dataset: {args.data.dataset!r}')
    train_data = dataset_cls(args, mode='train')
    test_data = dataset_cls(args, mode='test')
    val_data = dataset_cls(args, mode='val')

    train_loader = DataLoader(train_data,
                              batch_size=args.train.batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=6)
    num_train = len(train_data)

    # Evaluation loaders use a 4x larger batch and a fixed order.
    test_loader = DataLoader(test_data,
                             batch_size=args.train.batch_size * 4,
                             shuffle=False,
                             drop_last=True,
                             num_workers=6)
    val_loader = DataLoader(val_data,
                            batch_size=args.train.batch_size * 4,
                            shuffle=False,
                            drop_last=True,
                            num_workers=6)

    model = GNM(args)
    model.to(device)
    num_gpu = 1
    if device.type == 'cuda' and torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        num_gpu = torch.cuda.device_count()
        model = nn.DataParallel(model)
    model.train()

    optimizer = torch.optim.RMSprop(model.parameters(), lr=args.train.lr)

    global_step = 0
    if args.last_ckpt:
        global_step, args.train.start_epoch = \
            load_ckpt(model, optimizer, args.last_ckpt, device)

    args.train.global_step = global_step
    args.log.phase_log = False

    writer = SummaryWriter(summary_dir)

    end_time = time.time()

    for epoch in range(int(args.train.start_epoch), args.train.epoch):

        local_count = 0
        last_count = 0
        for sample in train_loader:

            imgs = sample.to(device)

            # annealing uses the step count *before* this iteration's increment
            hyperparam_anneal(args, global_step)
            global_step += 1

            phase_log = global_step % args.log.print_step_freq == 0 or global_step == 1
            args.train.global_step = global_step
            args.log.phase_log = phase_log

            pa_recon, log_like, kl, _, _, _, log = model(imgs)

            (aux_kl_pres, aux_kl_where, aux_kl_depth, aux_kl_what, aux_kl_bg,
             kl_pres, kl_where, kl_depth, kl_what, kl_global_all, kl_bg) = kl

            # Unweighted ("raw") KL terms, averaged over dim 0.
            aux_kl_pres_raw = aux_kl_pres.mean(dim=0)
            aux_kl_where_raw = aux_kl_where.mean(dim=0)
            aux_kl_depth_raw = aux_kl_depth.mean(dim=0)
            aux_kl_what_raw = aux_kl_what.mean(dim=0)
            aux_kl_bg_raw = aux_kl_bg.mean(dim=0)
            kl_pres_raw = kl_pres.mean(dim=0)
            kl_where_raw = kl_where.mean(dim=0)
            kl_depth_raw = kl_depth.mean(dim=0)
            kl_what_raw = kl_what.mean(dim=0)
            kl_bg_raw = kl_bg.mean(dim=0)

            log_like = log_like.mean(dim=0)

            # Beta-weighted KL terms used in the training loss.
            aux_kl_pres = aux_kl_pres_raw * args.train.beta_aux_pres
            aux_kl_where = aux_kl_where_raw * args.train.beta_aux_where
            aux_kl_depth = aux_kl_depth_raw * args.train.beta_aux_depth
            aux_kl_what = aux_kl_what_raw * args.train.beta_aux_what
            aux_kl_bg = aux_kl_bg_raw * args.train.beta_aux_bg
            kl_pres = kl_pres_raw * args.train.beta_pres
            kl_where = kl_where_raw * args.train.beta_where
            kl_depth = kl_depth_raw * args.train.beta_depth
            kl_what = kl_what_raw * args.train.beta_what
            kl_bg = kl_bg_raw * args.train.beta_bg

            kl_global_raw = kl_global_all.sum(dim=-1).mean(dim=0)
            kl_global = kl_global_raw * args.train.beta_global

            total_loss = -(log_like - kl_pres - kl_where - kl_depth - kl_what -
                           kl_bg - kl_global - aux_kl_pres - aux_kl_where -
                           aux_kl_depth - aux_kl_what - aux_kl_bg)

            optimizer.zero_grad()
            total_loss.backward()

            clip_grad_norm_(model.parameters(), args.train.cp)
            optimizer.step()

            local_count += imgs.size(0)
            if not phase_log:
                continue

            time_inter = time.time() - end_time
            count_inter = local_count - last_count
            print_schedule(global_step, epoch, local_count, count_inter,
                           num_train, total_loss, time_inter)
            end_time = time.time()

            _log_histograms(writer, model, log, global_step)
            _log_train_images(writer, imgs, pa_recon, log, args, global_step)

            scalars = [
                ('train/total_loss', total_loss),
                ('train/log_like', log_like),
                ('train/What_KL', kl_what),
                ('train/bg_KL', kl_bg),
                ('train/Where_KL', kl_where),
                ('train/Pres_KL', kl_pres),
                ('train/Depth_KL', kl_depth),
                ('train/kl_global', kl_global),
                ('train/What_KL_raw', kl_what_raw),
                ('train/bg_KL_raw', kl_bg_raw),
                ('train/Where_KL_raw', kl_where_raw),
                ('train/Pres_KL_raw', kl_pres_raw),
                ('train/Depth_KL_raw', kl_depth_raw),
                ('train/aux_What_KL', aux_kl_what),
                ('train/aux_bg_KL', aux_kl_bg),
                ('train/aux_Where_KL', aux_kl_where),
                ('train/aux_Pres_KL', aux_kl_pres),
                ('train/aux_Depth_KL', aux_kl_depth),
                ('train/aux_What_KL_raw', aux_kl_what_raw),
                ('train/aux_bg_KL_raw', aux_kl_bg_raw),
                ('train/aux_Where_KL_raw', aux_kl_where_raw),
                ('train/aux_Pres_KL_raw', aux_kl_pres_raw),
                ('train/aux_Depth_KL_raw', aux_kl_depth_raw),
                ('train/kl_global_raw', kl_global_raw),
            ]
            for tag, value in scalars:
                writer.add_scalar(tag, value.item(), global_step=global_step)

            writer.add_scalar('train/tau_pres',
                              args.train.tau_pres,
                              global_step=global_step)
            for i in range(args.arch.draw_step):
                writer.add_scalar(f'train/kl_global_raw_step_{i}',
                                  kl_global_all[:, i].mean().item(),
                                  global_step=global_step)

            writer.add_scalar('train/log_prob_x_given_g',
                              log['log_prob_x_given_g'].mean(0).item(),
                              global_step=global_step)

            # ELBO from the unweighted KL terms (auxiliary KLs are excluded).
            elbo = (log_like.item() - kl_pres_raw.item() -
                    kl_where_raw.item() - kl_depth_raw.item() -
                    kl_what_raw.item() - kl_bg_raw.item() -
                    kl_global_raw.item())
            writer.add_scalar('train/elbo', elbo, global_step=global_step)

            _log_generation(writer, model, num_gpu, args, global_step)

            last_count = local_count

        # Held-out NLL, only for logging; the final ll should be computed
        # using 100 particles.
        if epoch % args.log.compute_nll_freq == 0:
            _log_nll(model, val_loader, 'val', epoch, writer, args, device)

        if epoch % (args.log.compute_nll_freq * 10) == 0:
            _log_nll(model, test_loader, 'test', epoch, writer, args, device)

        if epoch % args.log.save_epoch_freq == 0 and epoch != 0:
            save_ckpt(model_dir, model, optimizer, global_step, epoch,
                      local_count, args.train.batch_size, num_train)

    save_ckpt(model_dir, model, optimizer, global_step, epoch, local_count,
              args.train.batch_size, num_train)
Example #29
0
def ais_trajectory(model, loader, mode='forward', schedule=np.linspace(0., 1., 500), n_sample=100):
    """Compute annealed importance sampling (AIS) trajectories for each batch.

    Usable for *both* the forward and the reverse chain in bidirectional Monte
    Carlo (default: forward chain with a linear schedule).

    Args:
        model (vae.VAE): VAE model. It must expose ``decoder`` and a latent size
            as ``hps.z_size``, ``z_dim`` or ``zdim`` (tried in that order).
        loader (iterator): yields pairs ``(x, z)``. The second component (a
            posterior sample or a label) is only read in ``'backward'`` mode.
        mode (str): ``'forward'`` or ``'backward'`` chain.
        schedule (list or 1D np.ndarray): temperature schedule ``t`` for the
            intermediate densities ``p(z) p(x|z)^t``. The forward chain uses
            increasing values; the backward chain uses decreasing values.
        n_sample (int): number of importance samples, i.e. the number of
            parallel chains per datapoint.

    Returns:
        list of torch.Tensor: one tensor of log importance weights per batch,
        combined over the ``n_sample`` chains with ``log_mean_exp``. In
        backward mode the values are negated.

    Raises:
        AssertionError: if ``mode`` is neither ``'forward'`` nor ``'backward'``.
    """

    assert mode == 'forward' or mode == 'backward', 'Should have either forward/backward mode'

    def log_f_i(z, data, t, log_likelihood_fn=log_bernoulli):
        """Unnormalized log density of the intermediate distribution `f_i`:
            f_i = p(z)^(1-t) p(x,z)^(t) = p(z) p(x|z)^t
        =>  log f_i = log p(z) + t * log p(x|z)
        """
        # Zero mean and zero log-variance. This assumes log_normal takes
        # (x, mean, logvar) -- TODO confirm.
        zeros = torch.zeros_like(z)
        log_prior = log_normal(z, zeros, zeros)
        log_likelihood = log_likelihood_fn(model.decoder(z), data)

        return log_prior + log_likelihood.mul_(t)

    def _latent_size():
        # Different VAE implementations store the latent size under different
        # attribute names; only a missing attribute triggers the fallback.
        try:
            return model.hps.z_size
        except AttributeError:
            pass
        try:
            return model.z_dim
        except AttributeError:
            return model.zdim

    z_size = _latent_size()

    _time = time.time()
    logws = []  # for output

    print('In %s mode' % mode)

    for batch, post_z in loader:

        B = batch.size(0) * n_sample
        batch = safe_repeat(batch, n_sample)
        dtype, device = batch.dtype, batch.device
        # batch of step sizes, one for each chain
        epsilon = torch.ones(B, dtype=dtype, device=device).mul_(0.01)
        # accept/reject history for tuning step size
        accept_hist = torch.zeros(B, dtype=dtype, device=device)
        # accumulated log importance weight, one per chain
        logw = torch.zeros(B, dtype=dtype, device=device)

        # initial sample of z
        if mode == 'forward':
            current_z = torch.randn(B, z_size, dtype=dtype, device=device)
            current_z.requires_grad = True
        else:
            # `.device` is an attribute, not a method: the previous
            # `.device(batch.device)` call always raised a TypeError.
            current_z = safe_repeat(post_z, n_sample).to(
                device=device, dtype=dtype).detach().requires_grad_()

        for j, (t0, t1) in tqdm(enumerate(zip(schedule[:-1], schedule[1:]), 1)):
            # Update log importance weight: log f_{t1}(z) - log f_{t0}(z).
            # No autograd graph is needed for this bookkeeping.
            with torch.no_grad():
                log_int_1 = log_f_i(current_z, batch, t0)
                log_int_2 = log_f_i(current_z, batch, t1)
                logw.add_(log_int_2 - log_int_1)

            del log_int_1, log_int_2

            # resample momentum
            current_v = torch.randn(current_z.size(), dtype=dtype, device=device)

            def U(z):
                # potential energy at the current temperature t1
                return -log_f_i(z, batch, t1)

            def grad_U(z):
                # U(z) has one entry per chain, so grad_outputs is mandatory
                grad_outputs = torch.ones(B, dtype=dtype, device=device)
                grad = torchgrad(U(z), z, grad_outputs=grad_outputs)[0]
                # clip element-wise by value (not by norm)
                grad = torch.clamp(grad, -B*z_size*100, B*z_size*100)
                grad.requires_grad = True
                return grad

            def normalized_kinetic(v):
                zeros = torch.zeros(B, z_size, dtype=dtype, device=device)
                # this is superior to the unnormalized version
                return -log_normal(v, zeros, zeros)

            z, v = hmc_trajectory(current_z, current_v, U, grad_U, epsilon)

            # accept-reject step
            current_z, epsilon, accept_hist = accept_reject(current_z, current_v,
                                                            z, v,
                                                            epsilon,
                                                            accept_hist, j,
                                                            U, K=normalized_kinetic)

        # IWAE lower bound. The view assumes safe_repeat tiles the batch
        # n_sample times along dim 0 -- TODO confirm.
        logw = log_mean_exp(logw.view(n_sample, -1).transpose(0, 1))
        print(logw.size())
        if mode == 'backward':
            logw = -logw

        logws.append(logw.detach())

        print('Time elapse %.4f, last batch stats %.4f' % (time.time() - _time, logw.mean().item()))
        _time = time.time()

    return logws
Example #30
0
def ais_trajectory(model,
                   loader,
                   forward=True,
                   schedule=np.linspace(0., 1., 500),
                   n_sample=100,
                   log_likelihood_fn=None):
  """Compute annealed importance sampling trajectories for a batch of data.

  Could be used for *both* forward and reverse chain in BDMC.

  The intermediate (unnormalized) distributions are
      f_t(z) = p(z) p(x|z)^t,
  so the log importance weight accumulated between temperatures `t0` and
  `t1` at the current sample `z` is
      log f_{t1}(z) - log f_{t0}(z) = (t1 - t0) * log p(x|z)
  (the prior term cancels).

  Args:
    model (vae.VAE): VAE model; must expose `latent_dim` and `decode(z)`
    loader (iterator): iterator that returns pairs, with first component
      being `x`, second would be `z` (used as the chain start when
      `forward=False`) or label (ignored when `forward=True`)
    forward (boolean): indicate forward/backward chain
    schedule (list or 1D np.ndarray): temperature schedule, i.e. `p(z)p(x|z)^t`;
      each consecutive pair `(schedule[j-1], schedule[j])` is one transition.
      The default array is only read, never modified.
    n_sample (int): number of importance samples per data point
    log_likelihood_fn (callable, optional): `fn(decoder_output, x)` returning
      the per-sample log p(x|z); defaults to `utils.log_bernoulli`

  Returns:
    A list with one tensor per batch, holding the log-mean-exp over the
    `n_sample` chains of the log importance weights (negated when
    `forward=False`).
  """
  if log_likelihood_fn is None:
    log_likelihood_fn = utils.log_bernoulli

  def log_f_i(z, data, t):
    """Unnormalized density for intermediate distribution `f_i`:
        f_i = p(z)^(1-t) p(x,z)^(t) = p(z) p(x|z)^t
    =>  log f_i = log p(z) + t * log p(x|z)
    """
    # prior parameters: zero mean and zero log-variance, shaped like `z`
    zeros = torch.zeros_like(z)
    log_prior = utils.log_normal(z, zeros, zeros)
    log_likelihood = log_likelihood_fn(model.decode(z), data)
    # out-of-place: never mutate a tensor that is part of the autograd graph
    return log_prior + log_likelihood * float(t)

  logws = []
  for batch, post_z in loader:
    B = batch.size(0) * n_sample
    batch = batch.cuda()
    batch = utils.safe_repeat(batch, n_sample)
    device = batch.device

    # per-chain HMC step size, acceptance history and log importance weight;
    # allocated directly on the device instead of CPU-then-copy
    epsilon = torch.full((B,), 0.01, device=device)
    accept_hist = torch.zeros(B, device=device)
    logw = torch.zeros(B, device=device)

    # initial sample of z (randn drawn on CPU then moved, keeping the
    # same RNG stream as before for seeded reproducibility)
    if forward:
      current_z = torch.randn(B, model.latent_dim).cuda()
    else:
      current_z = utils.safe_repeat(post_z, n_sample).cuda()
    current_z = current_z.requires_grad_()

    for j, (t0, t1) in tqdm(enumerate(zip(schedule[:-1], schedule[1:]), 1)):
      # update log importance weight:
      #   log f_{t1}(z) - log f_{t0}(z) = (t1 - t0) * log p(x|z)
      # One decoder pass, no autograd graph, and no float32 cancellation
      # between two large log-densities.
      with torch.no_grad():
        log_likelihood = log_likelihood_fn(model.decode(current_z), batch)
        logw += float(t1 - t0) * log_likelihood

      # resample velocity
      current_v = torch.randn(current_z.size()).cuda()

      def U(z):
        return -log_f_i(z, batch, t1)

      def grad_U(z):
        energy = U(z)
        # grad_outputs is required because `energy` holds one value per chain
        grad = torchgrad(energy, z, grad_outputs=torch.ones_like(energy))[0]
        # elementwise clip by value (not by norm) to guard against blow-ups
        max_ = B * model.latent_dim * 100.
        grad = torch.clamp(grad, -max_, max_)
        grad.requires_grad_()
        return grad

      def normalized_kinetic(v):
        # standard-normal kinetic energy (includes the normalizing constant)
        zeros = torch.zeros_like(v)
        return -utils.log_normal(v, zeros, zeros)

      z, v = hmc.hmc_trajectory(current_z, current_v, U, grad_U, epsilon)
      current_z, epsilon, accept_hist = hmc.accept_reject(
          current_z, current_v,
          z, v,
          epsilon,
          accept_hist, j,
          U, K=normalized_kinetic)

    # assumes safe_repeat tiles the whole batch n_sample times, so row k of
    # view(n_sample, -1) is the k-th replica of the batch -- TODO confirm
    logw = utils.log_mean_exp(logw.view(n_sample, -1).transpose(0, 1))
    if not forward:
      logw = -logw
    logws.append(logw.data)
    print('Last batch stats %.4f' % (logw.mean().cpu().data.numpy()))

  return logws