Example #1
    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input images [batch size, 3, dim, dim]
        """
        # --- Predict segmentation masks ---
        log_m_k, log_s_k, att_stats = self.att_process(x, self.K_steps - 1)
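        # log_m_k holds the K log attention masks (these sum to one in pixel
        # space, see the debug check below) and log_s_k the log scopes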

        # --- Reconstruct components ---
        x_m_r_k, comp_stats = self.comp_vae(x, log_m_k)
        # Split into appearances and mask prior
        x_r_k = [item[:, :3, :, :] for item in x_m_r_k]
        m_r_logits_k = [item[:, 3:, :, :] for item in x_m_r_k]
        # Apply pixel bound (sigmoid) to the appearances
        if self.pixel_bound:
            x_r_k = [torch.sigmoid(item) for item in x_r_k]

        # --- Reconstruct input image by marginalising (aka summing) ---
        x_r_stack = torch.stack(x_r_k, dim=4)
        m_stack = torch.stack(log_m_k, dim=4).exp()
        recon = (m_stack * x_r_stack).sum(dim=4)
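        # The reconstruction is a pixel-wise mixture: sum over k of
        # m_k * x_r_k, with mixing weights given by the exponentiated masks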

        # --- Reconstruct masks ---
        log_m_r_stack = self.get_mask_recon_stack(m_r_logits_k,
                                                  self.prior_mode,
                                                  log=True)
        log_m_r_k = torch.split(log_m_r_stack, 1, dim=4)
        log_m_r_k = [m[:, :, :, :, 0] for m in log_m_r_k]

        # --- Loss terms ---
        losses = AttrDict()
        # -- Reconstruction loss
        losses['err'] = Genesis.x_loss(x, log_m_k, x_r_k, self.std)
        # -- Attention mask KL
        losses['kl_m'] = self.kl_m_loss(log_m_k=log_m_k, log_m_r_k=log_m_r_k)
        # -- Component KL
        q_z_k = [
            Normal(m, s) for m, s in zip(comp_stats.mu_k, comp_stats.sigma_k)
        ]
        kl_l_k = misc.get_kl(comp_stats.z_k, q_z_k,
                             len(q_z_k) * [Normal(0, 1)], self.mckl)
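        # Per-step KL between each component posterior and a unit Gaussian
        # prior (self.mckl presumably toggles a Monte Carlo estimate)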
        losses['kl_l_k'] = [kld.sum(1) for kld in kl_l_k]

        # Track quantities of interest
        stats = AttrDict(
            recon=recon,
            log_m_k=log_m_k,
            log_s_k=log_s_k,
            x_r_k=x_r_k,
            log_m_r_k=log_m_r_k,
            mx_r_k=[x * logm.exp() for x, logm in zip(x_r_k, log_m_k)])

        # Sanity check that masks sum to one if in debug mode
        if self.debug:
            assert len(log_m_k) == self.K_steps
            assert len(log_m_r_k) == self.K_steps
            misc.check_log_masks(log_m_k)
            misc.check_log_masks(log_m_r_k)

        return recon, losses, stats, att_stats, comp_stats
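
For context, here is a minimal sketch of how this forward pass might be driven
from a training loop. Everything below is illustrative: the model construction
is elided (the constructor is not shown above), the loss weighting is a plain
sum rather than whatever weighting the original training code uses, and
`model` and the input shape are assumptions.

    # Illustrative usage only; assumes `model` is an instance of the module
    # whose forward pass is defined above.
    import torch

    x = torch.rand(8, 3, 64, 64)  # assumed input resolution
    recon, losses, stats, att_stats, comp_stats = model(x)
    # Naive total loss: reconstruction error plus both KL terms
    loss = losses.err.mean() + losses.kl_m.mean() \
        + sum(kld.mean() for kld in losses.kl_l_k)
    loss.backward()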
Example #2
    def sample(self, batch_size, K_steps=None):
        K_steps = self.K_steps if K_steps is None else K_steps
        # --- Mask ---
        # Sample latents
        if self.autoreg_prior:
            zm_k = [Normal(0, 1).sample([batch_size, self.ldim])]
            state = None
            for k in range(1, self.att_steps):
                # TODO(martin) reuse code from forward method?
                lstm_out, state = self.prior_lstm(
                    zm_k[-1].view(1, batch_size, -1), state)
                linear_out = self.prior_linear(lstm_out)
                mu = linear_out[0, :, :self.ldim]
                sigma = B.to_prior_sigma(linear_out[0, :, self.ldim:])
                p_zm = Normal(mu.view([batch_size, self.ldim]),
                              sigma.view([batch_size, self.ldim]))
                zm_k.append(p_zm.sample())
        else:
            p_zm = Normal(0, 1)
            zm_k = [p_zm.sample([batch_size, self.ldim])
                    for _ in range(self.att_steps)]
        # Decode latents into masks
        log_m_k, log_s_k, out_k = self.att_process.masks_from_zm_k(
            zm_k, self.img_size)
        # UGLY: Need to correct the last mask for the one-stage model w/o
        # softmax, or when running the two-stage model for an additional step:
        # drop the extra mask and use the remaining scope as the final mask so
        # that the masks still sum to one
        if len(log_m_k) == self.K_steps+1:
            del log_m_k[-1]
            log_m_k[self.K_steps-1] = log_s_k[self.K_steps-1]
        # Sanity checks. This function is not called at every training step,
        # so the assert statements do not noticeably slow down training
        assert len(zm_k) == self.K_steps
        assert len(log_m_k) == self.K_steps
        if self.two_stage:
            assert out_k[0].size(1) == 0
        else:
            assert out_k[0].size(1) == 0
        misc.check_log_masks(log_m_k)

        # --- Component appearance ---
        if self.two_stage:
            # Sample latents
            if self.comp_prior:
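                # Conditional prior: an MLP maps each mask latent zm to the
                # mean and scale of p(zc | zm) for the component latent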
                zc_k = []
                for zm in zm_k:
                    mlp_out = torch.chunk(self.prior_mlp(zm), 2, dim=1)
                    mu = torch.tanh(mlp_out[0])
                    sigma = B.to_prior_sigma(mlp_out[1])
                    zc_k.append(Normal(mu, sigma).sample())
                # if not self.softmax_attention:
                #     zc_k.append(Normal(0, 1).sample(
                #         [batch_size, self.comp_vae.ldim]))
            else:
                zc_k = [Normal(0, 1).sample([batch_size, self.comp_vae.ldim])
                        for _ in range(K_steps)]
            # Decode latents into components
            zc_batch = torch.cat(zc_k, dim=0)
            x_batch = self.comp_vae.decode(zc_batch)
            x_k = list(torch.chunk(x_batch, self.K_steps, dim=0))
        else:
            # x_k = out_k
            zm_batched = torch.cat(zm_k, dim=0)
            x_batched = self.decoder(zm_batched)
            x_k = torch.chunk(x_batched, self.K_steps, dim=0)
            if self.pixel_bound:
                x_k = [torch.sigmoid(x) for x in x_k]
        # Sanity check
        assert len(x_k) == self.K_steps
        assert len(log_m_k) == self.K_steps
        if self.two_stage:
            assert len(zc_k) == self.K_steps

        # --- Reconstruct image ---
        x_stack = torch.stack(x_k, dim=4)
        m_stack = torch.stack(log_m_k, dim=4).exp()
        generated_image = (m_stack * x_stack).sum(dim=4)

        # Stats
        stats = AttrDict(x_k=x_k, log_m_k=log_m_k, log_s_k=log_s_k,
                         mx_k=[x*m.exp() for x, m in zip(x_k, log_m_k)])
        return generated_image, stats
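
A minimal usage sketch for unconditional generation (illustrative; `model` and
the batch size are assumptions, and gradients are disabled since sampling is
typically done at evaluation time):

    # Illustrative usage only; assumes a trained `model` instance.
    import torch

    with torch.no_grad():
        generated_image, gen_stats = model.sample(batch_size=16)
    # gen_stats.mx_k contains the masked per-component appearances m_k * x_k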
Example #3
    def forward(self, x):
        """
        Performs a forward pass in the model.

        Args:
          x (torch.Tensor): input images [batch size, 3, dim, dim]

        Returns:
          recon: reconstructed images [N, 3, H, W]
          losses: 
          stats: 
          att_stats: 
          comp_stats: 
        """

        # --- Predict segmentation masks ---
        log_m_k, log_s_k, att_stats = self.att_process(x, self.att_steps)
        # UGLY: Need to correct the last mask for the one-stage model w/o
        # softmax, or when running the two-stage model for an additional step:
        # drop the extra mask and use the remaining scope as the final mask so
        # that the masks still sum to one
        if len(log_m_k) == self.K_steps+1:
            del log_m_k[-1]
            log_m_k[self.K_steps-1] = log_s_k[self.K_steps-1]
        if self.debug or not self.training:
            assert len(log_m_k) == self.K_steps

        # --- Reconstruct components ---
        if self.two_stage:
            x_r_k, comp_stats = self.comp_vae(x, log_m_k)
        else:
            # x_r_k = [x[:, 1:, :, :] for x in att_stats.x_k]
            z_batched = torch.cat(att_stats.z_k, dim=0)
            x_r_batched = self.decoder(z_batched)
            x_r_k = torch.chunk(x_r_batched, self.K_steps, dim=0)
            if self.pixel_bound:
                x_r_k = [torch.sigmoid(x) for x in x_r_k]
            comp_stats = None

        # --- Reconstruct input image by marginalising (aka summing) ---
        x_r_stack = torch.stack(x_r_k, dim=4)
        m_stack = torch.stack(log_m_k, dim=4).exp()
        recon = (m_stack * x_r_stack).sum(dim=4)

        # --- Loss terms ---
        losses = AttrDict()
        # -- Reconstruction loss
        losses['err'] = self.x_loss(x, log_m_k, x_r_k, self.std)
        # -- Attention mask KL
        # The posterior parameters are the same in both cases
        q_zm_0_k = [Normal(m, s) for m, s in
                    zip(att_stats.mu_k, att_stats.sigma_k)]
        # Using a normalising flow, arbitrary posterior
        if 'z_0_k' in att_stats and 'z_k_k' in att_stats:
            zm_0_k = att_stats.z_0_k
            zm_k_k = att_stats.z_k_k  # TODO(martin) variable name not ideal
            ldj_k = att_stats.ldj_k
        # No flow, Gaussian posterior
        else:
            zm_0_k = att_stats.z_k
            zm_k_k = att_stats.z_k
            ldj_k = None
        # Compute loss
        losses['kl_m_k'], p_zm_k = self.mask_latent_loss(
            q_zm_0_k, zm_0_k, zm_k_k, ldj_k, self.prior_lstm, self.prior_linear,
            debug=self.debug or not self.training)
        att_stats['pmu_k'] = [p_zm.mean for p_zm in p_zm_k]
        att_stats['psigma_k'] = [p_zm.scale for p_zm in p_zm_k]
        # Sanity checks
        if self.debug or not self.training:
            assert len(zm_k_k) == self.K_steps
            assert len(zm_0_k) == self.K_steps
            if ldj_k is not None:
                assert len(ldj_k) == self.K_steps
        # -- Component KL
        if self.two_stage:
            if self.comp_prior:
                losses['kl_l_k'] = []
                comp_stats['pmu_k'], comp_stats['psigma_k'] = [], []
                for step, zl in enumerate(comp_stats.z_k):
                    mlp_out = self.prior_mlp(zm_k_k[step])
                    mlp_out = torch.chunk(mlp_out, 2, dim=1)
                    mu = torch.tanh(mlp_out[0])
                    sigma = B.to_prior_sigma(mlp_out[1])
                    p_zl = Normal(mu, sigma)
                    comp_stats['pmu_k'].append(mu)
                    comp_stats['psigma_k'].append(sigma)
                    q_zl = Normal(comp_stats.mu_k[step], comp_stats.sigma_k[step])
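                    # Single-sample Monte Carlo estimate of the KL between the
                    # component posterior and its conditional prior, summed
                    # over latent dimensions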
                    kld = (q_zl.log_prob(zl) - p_zl.log_prob(zl)).sum(dim=1)
                    losses['kl_l_k'].append(kld)
                # Sanity checks
                if self.debug or not self.training:
                    assert len(comp_stats.z_k) == self.K_steps
                    assert len(comp_stats['pmu_k']) == self.K_steps
                    assert len(comp_stats['psigma_k']) == self.K_steps
                    assert len(losses['kl_l_k']) == self.K_steps
            else:
                raise NotImplementedError

        # Tracking
        stats = AttrDict(
            recon=recon, log_m_k=log_m_k, log_s_k=log_s_k, x_r_k=x_r_k,
            mx_r_k=[x*logm.exp() for x, logm in zip(x_r_k, log_m_k)])

        # Sanity check that masks sum to one (in debug mode or at eval time)
        if self.debug or not self.training:
            assert len(log_m_k) == self.K_steps
            misc.check_log_masks(log_m_k)

        return recon, losses, stats, att_stats, comp_stats
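
All three examples rely on misc.check_log_masks to validate the attention
masks. Its implementation is not shown above; below is a minimal sketch of
what such a helper might look like, assuming it only verifies that the
exponentiated log-masks sum to one at every pixel (the function name comes
from the calls above, while the body and tolerance are assumptions):

    import torch

    def check_log_masks(log_m_k, tolerance=1e-3):
        """Assert that a list of log-space masks sums to one per pixel."""
        # Exponentiate and sum the masks over the K components
        mask_sum = torch.stack([m.exp() for m in log_m_k], dim=0).sum(dim=0)
        # The mixing weights at every pixel should form a valid distribution
        max_deviation = (mask_sum - 1.0).abs().max().item()
        assert max_deviation < tolerance, \
            f"Masks do not sum to one (max deviation {max_deviation:.2e})"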