def forward(self, x, y, opt, step, writer=None):
        """Run one optimisation step of the flownet-conditioned point flow.

        ``self.flownet(x)`` produces the conditioning input for
        ``self.point_cnf``, which transforms ``y``. The negative
        log-likelihood is minimised with ``opt``.

        Args:
            x: Conditioning input forwarded to ``self.flownet``; ``x.size(0)``
                is used as the batch size.
            y: Target tensor; ``y.size(1)`` sizes the zero log-det buffer
                (presumably (batch, num_points, dim) — TODO confirm).
            opt: Optimizer whose gradients are zeroed and which is stepped once.
            step: Unused; kept for interface compatibility.
            writer: Unused; kept for interface compatibility.

        Returns:
            The negated log-likelihood summed over the batch, divided by the
            leading dimension of the transformed ``y``.

        Raises:
            ValueError: if ``self.logprob_type`` is not "Laplace" or "Normal".
        """
        # Pick the base log-density before doing any work. An unsupported
        # setting then fails fast with a clear error, instead of an
        # UnboundLocalError on ``log_py`` after the forward pass.
        if self.logprob_type == "Laplace":
            base_logprob = standard_laplace_logprob
        elif self.logprob_type == "Normal":
            base_logprob = standard_normal_logprob
        else:
            raise ValueError(
                f"unsupported logprob_type {self.logprob_type!r}; "
                "expected 'Laplace' or 'Normal'")

        opt.zero_grad()
        batch_size = x.size(0)
        z = self.flownet(x)

        # Loss: log p(x) = log p(y) - delta_log_py (change of variables).
        y, delta_log_py = self.point_cnf(
            y, z,
            torch.zeros(batch_size, y.size(1), 1).to(y))
        log_py = base_logprob(y).view(batch_size, -1).sum(1, keepdim=True)
        delta_log_py = delta_log_py.view(batch_size, y.size(1), 1).sum(1)
        log_px = log_py - delta_log_py

        loss = -log_px.mean()

        loss.backward()
        opt.step()
        recon = -log_px.sum()
        recon_nats = recon / float(y.size(0))
        return recon_nats
    def forward(self, hist, nbrs, masks, op_mask, y, opt):
        """Run one optimisation step of the hypernetwork-conditioned flow.

        ``self.hyper(hist, nbrs, masks)`` produces the target-network weights
        that condition ``self.point_cnf``. The masked negative log-likelihood
        of ``y`` is minimised with ``opt``.

        Args:
            hist: Forwarded to ``self.hyper`` unchanged.
            nbrs: Forwarded to ``self.hyper`` unchanged.
            masks: Forwarded to ``self.hyper`` unchanged.
            op_mask: Multiplied element-wise into the per-sample
                log-likelihood before the loss is taken.
            y: Target tensor; ``y.size(0)`` is the batch size and ``y.size(1)``
                sizes the zero log-det buffer.
            opt: Optimizer whose gradients are zeroed and which is stepped once.

        Returns:
            The masked log-likelihood, negated and summed over the batch, then
            divided by the leading dimension of the transformed ``y``.

        Raises:
            ValueError: if ``self.logprob_type`` is not "Laplace" or "Normal".
        """
        # Resolve the base log-density up front. Otherwise an unknown type
        # leaves ``log_py`` unbound and only fails after the forward pass.
        if self.logprob_type == "Laplace":
            base_logprob = standard_laplace_logprob
        elif self.logprob_type == "Normal":
            base_logprob = standard_normal_logprob
        else:
            raise ValueError(
                f"unsupported logprob_type {self.logprob_type!r}; "
                "expected 'Laplace' or 'Normal'")

        opt.zero_grad()
        batch_size = y.size(0)
        target_networks_weights = self.hyper(hist, nbrs, masks)

        # Loss: log p(x) = log p(y) - delta_log_py (change of variables).
        y, delta_log_py = self.point_cnf(
            y, target_networks_weights,
            torch.zeros(batch_size, y.size(1), 1).to(y))
        log_py = base_logprob(y).view(batch_size, -1).sum(1, keepdim=True)
        delta_log_py = delta_log_py.view(batch_size, y.size(1), 1).sum(1)
        log_px = log_py - delta_log_py
        log_px = log_px * op_mask
        # NOTE(review): the mean divides by every element, including entries
        # zeroed by op_mask; confirm this normalisation is intended.
        loss = -log_px.mean()

        loss.backward()
        opt.step()
        recon = -log_px.sum()
        recon_nats = recon / float(y.size(0))
        return recon_nats
    def get_logprob(self, hist, nbrs, masks, y_in):
        """Score ``y_in`` under the hypernetwork-conditioned flow.

        No optimisation step is taken.

        Args:
            hist: Forwarded to ``self.hyper`` unchanged.
            nbrs: Forwarded to ``self.hyper`` unchanged.
            masks: Forwarded to ``self.hyper`` unchanged.
            y_in: Tensor to score; ``y_in.size(0)`` is the batch size and
                ``y_in.size(1)`` sizes the zero log-det buffer.

        Returns:
            ``(log_py, log_px)``. ``log_py`` is the base log-density of the
            transformed sample, summed per batch item with shape
            (batch_size, 1). ``log_px`` is ``log_py`` minus the summed log-det
            term.

        Raises:
            ValueError: if ``self.logprob_type`` is not "Laplace" or "Normal".
        """
        # Resolve the base log-density first so an unknown type fails clearly
        # instead of leaving ``log_py`` unbound.
        if self.logprob_type == "Laplace":
            base_logprob = standard_laplace_logprob
        elif self.logprob_type == "Normal":
            base_logprob = standard_normal_logprob
        else:
            raise ValueError(
                f"unsupported logprob_type {self.logprob_type!r}; "
                "expected 'Laplace' or 'Normal'")

        batch_size = y_in.size(0)
        target_networks_weights = self.hyper(hist, nbrs, masks)

        y, delta_log_py = self.point_cnf(
            y_in, target_networks_weights,
            torch.zeros(batch_size, y_in.size(1), 1).to(y_in))
        log_py = base_logprob(y).view(batch_size, -1).sum(1, keepdim=True)

        delta_log_py = delta_log_py.view(batch_size, y.size(1), 1).sum(1)
        log_px = log_py - delta_log_py

        return log_py, log_px
    def get_logprob(self, x, y_in):
        """Score ``y_in`` under the flownet-conditioned flow.

        Both per-batch and finer-grained log-likelihoods are returned. No
        optimisation step is taken.

        Args:
            x: Conditioning input forwarded to ``self.flownet``; ``x.size(0)``
                is used as the batch size.
            y_in: Tensor to score; ``y_in.size(1)`` sizes the zero log-det
                buffer (presumably (batch, num_points, dim) — TODO confirm).

        Returns:
            ``(log_py, log_px, (batch_log_py, batch_log_px))``.
            ``log_py`` and ``log_px`` are summed per batch item with shape
            (batch_size, 1). ``batch_log_py`` and ``batch_log_px`` are summed
            over dim 2 only, so they keep the first two dimensions.

        Raises:
            ValueError: if ``self.logprob_type`` is not "Laplace" or "Normal".
        """
        # Resolve the base log-density first so an unknown type fails clearly
        # instead of leaving ``log_py`` unbound.
        if self.logprob_type == "Laplace":
            base_logprob = standard_laplace_logprob
        elif self.logprob_type == "Normal":
            base_logprob = standard_normal_logprob
        else:
            raise ValueError(
                f"unsupported logprob_type {self.logprob_type!r}; "
                "expected 'Laplace' or 'Normal'")

        batch_size = x.size(0)
        w = self.flownet(x)

        y, delta_log_py = self.point_cnf(
            y_in, w,
            torch.zeros(batch_size, y_in.size(1), 1).to(y_in))
        log_py = base_logprob(y)

        # Finer-grained scores: reduce only over dim 2.
        batch_log_py = log_py.sum(dim=2)
        batch_log_px = batch_log_py - delta_log_py.sum(dim=2)
        # Per-batch-item scores.
        log_py = log_py.view(batch_size, -1).sum(1, keepdim=True)
        delta_log_py = delta_log_py.view(batch_size, y.size(1), 1).sum(1)
        log_px = log_py - delta_log_py

        return log_py, log_px, (batch_log_py, batch_log_px)
# Example #5
# 0
    def forward(self, x, y, opt, step, writer=None):
        """Take one optimiser step on the hypernetwork-conditioned point flow.

        ``self.hyper(x)`` yields the weights that condition ``self.point_cnf``.
        The transformed sample is scored with ``standard_normal_logprob``, and
        the mean negative log-likelihood is minimised with ``opt``.
        ``step`` and ``writer`` are accepted but not used.

        Returns:
            The negated log-likelihood summed over the batch, divided by the
            leading dimension of the transformed ``y``.
        """
        opt.zero_grad()
        n = x.size(0)
        weights = self.hyper(x)

        # Start the flow with a zero log-determinant accumulator.
        zero_logdet = torch.zeros(n, y.size(1), 1).to(y)
        y, delta_log_py = self.point_cnf(y, weights, zero_logdet)

        base = standard_normal_logprob(y).view(n, -1).sum(1, keepdim=True)
        logdet = delta_log_py.view(n, y.size(1), 1).sum(1)
        log_px = base - logdet

        nll = -log_px.mean()
        nll.backward()
        opt.step()

        return -log_px.sum() / float(y.size(0))
# Example #6
# 0
    def forward(self, x, step, writer=None):
        """Compute the weighted training objective for a batch of point sets.

        The encoder yields ``(z_mu, z_sigma)``; ``z_sigma`` is treated as
        ln(sigma), per the original comment. The objective sums three weighted
        terms: the negative posterior entropy, the negative prior log-density
        of ``z`` and the negative reconstruction log-likelihood of ``x`` under
        ``self.point_rsf``. No backward pass is run here.

        Args:
            x: Input batch, described by the original comment as (n, l, c).
            step: Step index passed to ``writer``.
            writer: Optional scalar logger exposing ``add_scalar``.

        Returns:
            ``(stats, loss)``. ``stats`` holds ``'entropy'`` (a Python number),
            ``'prior_nats'`` and ``'recon_nats'``; ``loss`` is the weighted
            objective.
        """
        n_batch = x.size(0)
        n_points = x.size(1)

        # Posterior sample z and its entropy H[Q(z|X)].
        z_mu, z_sigma = self.encoder(x)
        if self.use_deterministic_encoder:
            # "+ 0 * z_sigma" presumably keeps z_sigma in the autograd graph.
            z = z_mu + 0 * z_sigma
            entropy = torch.zeros(n_batch).to(z)
        else:
            z = self.reparameterize_gaussian(z_mu, z_sigma)
            entropy = self.gaussian_entropy(z_sigma)

        # Prior log-density log P(z), optionally through the latent flow.
        if not self.use_latent_flow:
            log_pz = torch.zeros(n_batch, 1).to(z)
        else:
            w, delta_log_pw = self.latent_rsf(z, torch.zeros(n_batch, 1).to(z))
            log_pw = standard_normal_logprob(w).view(n_batch, -1).sum(
                1, keepdim=True)
            log_pz = log_pw - delta_log_pw.view(n_batch, 1)

        # Reconstruction log-likelihood; point_rsf receives x only, not z.
        y, delta_log_py = self.point_rsf(
            x, torch.zeros(n_batch, n_points, 1).to(x))
        log_py = standard_normal_logprob(y).view(n_batch, -1).sum(
            1, keepdim=True)
        log_px = log_py - delta_log_py.view(n_batch, n_points, 1).sum(1)

        # Weighted objective, summed in entropy, prior, recon order.
        loss = (-entropy.mean() * self.entropy_weight
                + -log_pz.mean() * self.prior_weight
                + -log_px.mean() * self.recon_weight)

        # Logging scalars, passed through reduce_tensor when distributed.
        reduce = reduce_tensor if self.distributed else (lambda t: t)
        entropy_log = reduce(entropy.mean())
        recon = reduce(-log_px.mean())
        prior = reduce(-log_pz.mean())

        recon_nats = recon / float(x.size(1) * x.size(2))
        prior_nats = prior / float(self.zdim)

        if writer is not None:
            for tag, value in (('train/entropy', entropy_log),
                               ('train/prior', prior),
                               ('train/prior(nats)', prior_nats),
                               ('train/recon', recon),
                               ('train/recon(nats)', recon_nats)):
                writer.add_scalar(tag, value, step)

        if isinstance(entropy_log, float):
            entropy_value = entropy_log
        else:
            entropy_value = entropy_log.cpu().detach().item()

        stats = {
            'entropy': entropy_value,
            'prior_nats': prior_nats,
            'recon_nats': recon_nats,
        }
        return stats, loss
# Example #7
# 0
    def forward(self,
                x,
                x_noisy,
                std_in,
                opt,
                step=None,
                writer=None,
                init=False,
                valid=False):
        """Evaluate, and in training mode optimise, the latent-glow objective.

        The encoder embeds ``x``. The latent is scored through
        ``self.latent_glow``, and ``x_noisy`` is scored by ``self.point_AF``
        conditioned on that latent. The loss is the unweighted sum of the
        negative entropy, negative prior and negative reconstruction terms.

        Args:
            x: Clean input batch. It feeds the encoder and normalises
                ``recon_nats`` by ``x.size(1) * x.size(2)``.
            x_noisy: Input scored by ``self.point_AF``.
            std_in: Passed through to ``self.point_AF`` unchanged.
            opt: Optimizer. Its gradients are always zeroed; it is stepped only
                when neither ``init`` nor ``valid`` is set.
            step: Step index passed to ``writer``.
            writer: Optional scalar logger; not used when ``valid`` is set.
            init: When truthy, skip the backward pass and optimiser step.
            valid: When truthy, skip backward, optimiser step and logging.

        Returns:
            dict with ``'entropy'``, ``'prior_nats'``, ``'recon_nats'``,
            ``'prior'``, ``'recon'`` and ``'loss'`` (the latter a Python
            number).
        """
        opt.zero_grad()
        n_batch, n_points = x.size(0), x.size(1)

        # Posterior sample z and its entropy H[Q(z|X)].
        z_mu, z_sigma = self.encoder(x)
        if not self.use_deterministic_encoder:
            z = self.reparameterize_gaussian(z_mu, z_sigma)
            entropy = self.gaussian_entropy(z_sigma)
        else:
            z = z_mu + 0 * z_sigma
            entropy = torch.zeros(n_batch).to(z)

        # Prior log-density log P(z) through the latent glow.
        w, delta_log_pw = self.latent_glow(z)
        log_pw = standard_normal_logprob(w).view(n_batch, -1).sum(
            1, keepdim=True)
        log_pz = log_pw - delta_log_pw.view(n_batch, 1)

        # Reconstruction log-likelihood log P(X|z). The added term is zero for
        # finite log_pz; it ties log_pz into the conditioning latent's graph.
        z_cond = z.view(*z.size()) + (log_pz * 0.).mean()
        y, delta_log_py = self.point_AF(x_noisy, std_in, z_cond)
        log_py = standard_normal_logprob(y).view(n_batch, -1).sum(
            1, keepdim=True)
        log_px = log_py - delta_log_py.view(n_batch, n_points, 1).sum(1)

        entropy_term = -entropy.mean()
        prior_term = -log_pz.mean()
        recon_term = -log_px.mean()
        loss = entropy_term + prior_term + recon_term
        if not (init or valid):
            loss.backward()
            opt.step()

        # Logging scalars, passed through reduce_tensor when distributed.
        reduce = reduce_tensor if self.distributed else (lambda t: t)
        loss = reduce(loss.mean())
        entropy_log = reduce(entropy.mean())
        recon = reduce(-log_px.mean())
        prior = reduce(-log_pz.mean())

        recon_nats = recon / float(x.size(1) * x.size(2))
        prior_nats = prior / float(self.zdim)

        if writer is not None and not valid:
            for tag, value in (('train/entropy', entropy_log),
                               ('train/prior', prior),
                               ('train/prior(nats)', prior_nats),
                               ('train/recon', recon),
                               ('train/recon(nats)', recon_nats)):
                writer.add_scalar(tag, value, step)
            writer.add_scalar('train/loss', loss.item(), step)

        entropy_value = (entropy_log if isinstance(entropy_log, float)
                         else entropy_log.cpu().detach().item())
        return {
            'entropy': entropy_value,
            'prior_nats': prior_nats,
            'recon_nats': recon_nats,
            'prior': prior,
            'recon': recon,
            'loss': loss.item(),
        }