Example #1
    def mask_latent_loss(q_zm_0_k,
                         zm_0_k,
                         zm_k_k,
                         ldj_k,
                         prior_lstm=None,
                         prior_linear=None,
                         debug=False):
        num_steps = len(zm_k_k)
        batch_size = zm_k_k[0].size(0)
        latent_dim = zm_k_k[0].size(1)

        # -- Determine prior --
        if prior_lstm is not None and prior_linear is not None:
            # zm_seq shape: (att_steps-2, batch_size, ldim)
            # The last element of zm_k_k is not needed: the prior for
            # step k is conditioned on z_{k-1}
            zm_seq = torch.cat(
                [zm.view(1, batch_size, -1) for zm in zm_k_k[:-1]], dim=0)
            # lstm_out shape: (att_steps-2, batch_size, state_size)
            # Note: recurrent state is handled internally by LSTM
            lstm_out, _ = prior_lstm(zm_seq)
            # linear_out shape: (att_steps-2, batch_size, 2*ldim)
            linear_out = prior_linear(lstm_out)
            linear_out = torch.chunk(linear_out, 2, dim=2)
            mu_raw = torch.tanh(linear_out[0])
            sigma_raw = B.to_prior_sigma(linear_out[1])
            # Split into K steps, shape: (att_steps-2)*[1, batch_size, ldim]
            mu_k = torch.split(mu_raw, 1, dim=0)
            sigma_k = torch.split(sigma_raw, 1, dim=0)
            # Use standard Normal as prior for first step
            p_zm_k = [Normal(0, 1)]
            # Autoregressive prior for later steps
            for mean, std in zip(mu_k, sigma_k):
                # Remember to remove unit dimension at dim=0
                p_zm_k += [
                    Normal(mean.view(batch_size, latent_dim),
                           std.view(batch_size, latent_dim))
                ]
            # Sanity checks
            if debug:
                assert zm_seq.size(0) == num_steps - 1
        else:
            p_zm_k = num_steps * [Normal(0, 1)]

        # -- Compute KL using Monte Carlo samples for every step k --
        kl_m_k = []
        for step, p_zm in enumerate(p_zm_k):
            log_q = q_zm_0_k[step].log_prob(zm_0_k[step]).sum(dim=1)
            log_p = p_zm.log_prob(zm_k_k[step]).sum(dim=1)
            kld = log_q - log_p
            if ldj_k is not None:
                ldj = ldj_k[step].sum(dim=1)
                kld = kld - ldj
            kl_m_k.append(kld)

        # -- Sanity check --
        if debug:
            assert len(p_zm_k) == num_steps
            assert len(kl_m_k) == num_steps

        return kl_m_k, p_zm_k
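
The function above estimates the per-step mask-latent KL by Monte Carlo: log q minus log p evaluated on the posterior samples, with the flow log-det-Jacobian subtracted when a normalising flow is used. Below is a minimal usage sketch for the no-flow case with a standard Normal prior; the sizes are made up and the function is called directly rather than as a method, purely for illustration:

    import torch
    from torch.distributions import Normal

    # Illustrative sizes (assumptions, not taken from any model config).
    K, batch_size, latent_dim = 4, 8, 16

    # Per-step Gaussian posteriors and their samples. Without a normalising
    # flow, the pre-flow and post-flow samples coincide and ldj_k is None.
    mu_k = [torch.zeros(batch_size, latent_dim) for _ in range(K)]
    sigma_k = [torch.ones(batch_size, latent_dim) for _ in range(K)]
    q_zm_0_k = [Normal(m, s) for m, s in zip(mu_k, sigma_k)]
    zm_k = [q.rsample() for q in q_zm_0_k]

    # With prior_lstm/prior_linear left as None, the prior defaults to a
    # standard Normal at every step.
    kl_m_k, p_zm_k = mask_latent_loss(q_zm_0_k, zm_k, zm_k, ldj_k=None)
    print(len(kl_m_k), kl_m_k[0].shape)  # 4, torch.Size([8])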
Example #2
    def sample(self, batch_size, K_steps=None):
        K_steps = self.K_steps if K_steps is None else K_steps
        # --- Mask ---
        # Sample latents
        if self.autoreg_prior:
            zm_k = [Normal(0, 1).sample([batch_size, self.ldim])]
            state = None
            for k in range(1, self.att_steps):
                # TODO(martin) reuse code from forward method?
                lstm_out, state = self.prior_lstm(
                    zm_k[-1].view(1, batch_size, -1), state)
                linear_out = self.prior_linear(lstm_out)
                mu = linear_out[0, :, :self.ldim]
                sigma = B.to_prior_sigma(linear_out[0, :, self.ldim:])
                p_zm = Normal(mu.view([batch_size, self.ldim]),
                              sigma.view([batch_size, self.ldim]))
                zm_k.append(p_zm.sample())
        else:
            p_zm = Normal(0, 1)
            zm_k = [p_zm.sample([batch_size, self.ldim])
                    for _ in range(self.att_steps)]
        # Decode latents into masks
        log_m_k, log_s_k, out_k = self.att_process.masks_from_zm_k(
            zm_k, self.img_size)
        # UGLY: Need to correct the last mask for the one-stage model w/o
        # softmax OR when running the two-stage model for an additional step
        if len(log_m_k) == self.K_steps+1:
            del log_m_k[-1]
            log_m_k[self.K_steps-1] = log_s_k[self.K_steps-1]
        # Sanity checks. This function is not called at every training step, so
        # the assert statements do not cause a big slowdown in total training time
        assert len(zm_k) == self.K_steps
        assert len(log_m_k) == self.K_steps
        if self.two_stage:
            assert out_k[0].size(1) == 0
        else:
            # assert out_k[0].size(1) == 3
            assert out_k[0].size(1) == 0
        misc.check_log_masks(log_m_k)

        # --- Component appearance ---
        if self.two_stage:
            # Sample latents
            if self.comp_prior:
                zc_k = []
                for zm in zm_k:
                    mlp_out = torch.chunk(self.prior_mlp(zm), 2, dim=1)
                    mu = torch.tanh(mlp_out[0])
                    sigma = B.to_prior_sigma(mlp_out[1])
                    zc_k.append(Normal(mu, sigma).sample())
                # if not self.softmax_attention:
                #     zc_k.append(Normal(0, 1).sample(
                #         [batch_size, self.comp_vae.ldim]))
            else:
                zc_k = [Normal(0, 1).sample([batch_size, self.comp_vae.ldim])
                        for _ in range(K_steps)]
            #  Decode latents into components
            zc_batch = torch.cat(zc_k, dim=0)
            x_batch = self.comp_vae.decode(zc_batch)
            x_k = list(torch.chunk(x_batch, self.K_steps, dim=0))
        else:
            # x_k = out_k
            zm_batched = torch.cat(zm_k, dim=0)
            x_batched = self.decoder(zm_batched)
            x_k = torch.chunk(x_batched, self.K_steps, dim=0)
            if self.pixel_bound:
                x_k = [torch.sigmoid(x) for x in x_k]
        # Sanity check
        assert len(x_k) == self.K_steps
        assert len(log_m_k) == self.K_steps
        if self.two_stage:
            assert len(zc_k) == self.K_steps

        # --- Reconstruct image ---
        x_stack = torch.stack(x_k, dim=4)
        m_stack = torch.stack(log_m_k, dim=4).exp()
        generated_image = (m_stack * x_stack).sum(dim=4)

        # Stats
        stats = AttrDict(x_k=x_k, log_m_k=log_m_k, log_s_k=log_s_k,
                         mx_k=[x*m.exp() for x, m in zip(x_k, log_m_k)])
        return generated_image, stats
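
The autoregressive branch above rolls the prior forward one step at a time, feeding the previously sampled latent into prior_lstm and splitting the prior_linear output into a mean and a scale. A self-contained sketch of that rollout follows; the sizes are assumptions and nn.Softplus merely stands in for B.to_prior_sigma, which is assumed to map raw outputs to a positive scale:

    import torch
    from torch import nn
    from torch.distributions import Normal

    # Illustrative sizes and freshly initialised stand-in modules.
    batch_size, ldim, state_size, K_steps = 8, 16, 32, 4
    prior_lstm = nn.LSTM(ldim, state_size)
    prior_linear = nn.Linear(state_size, 2 * ldim)
    to_sigma = nn.Softplus()  # stand-in for B.to_prior_sigma

    # First latent from a standard Normal, then condition step k on step k-1.
    zm_k = [Normal(0, 1).sample([batch_size, ldim])]
    state = None
    for k in range(1, K_steps):
        lstm_out, state = prior_lstm(zm_k[-1].view(1, batch_size, -1), state)
        linear_out = prior_linear(lstm_out)
        mu = linear_out[0, :, :ldim]
        sigma = to_sigma(linear_out[0, :, ldim:])
        zm_k.append(Normal(mu, sigma).sample())

    print(len(zm_k), zm_k[-1].shape)  # 4, torch.Size([8, 16])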
Example #3
    def forward(self, x):
        """
        Performs a forward pass in the model.

        Args:
          x (torch.Tensor): input images [batch size, 3, dim, dim]

        Returns:
          recon (torch.Tensor): reconstructed images [N, 3, H, W]
          losses (AttrDict): loss terms ('err', 'kl_m_k' and, for the
            two-stage model, 'kl_l_k')
          stats (AttrDict): reconstruction, masks, and masked components
            for tracking
          att_stats (AttrDict): attention/mask statistics, including the
            prior parameters 'pmu_k' and 'psigma_k'
          comp_stats (AttrDict or None): component VAE statistics
            (None for the one-stage model)
        """

        # --- Predict segmentation masks ---
        log_m_k, log_s_k, att_stats = self.att_process(x, self.att_steps)
        # UGLY: Need to correct the last mask for the one-stage model w/o
        # softmax OR when running the two-stage model for an additional step
        if len(log_m_k) == self.K_steps+1:
            del log_m_k[-1]
            log_m_k[self.K_steps-1] = log_s_k[self.K_steps-1]
        if self.debug or not self.training:
            assert len(log_m_k) == self.K_steps

        # --- Reconstruct components ---
        if self.two_stage:
            x_r_k, comp_stats = self.comp_vae(x, log_m_k)
        else:
            # x_r_k = [x[:, 1:, :, :] for x in att_stats.x_k]
            z_batched = torch.cat(att_stats.z_k, dim=0)
            x_r_batched = self.decoder(z_batched)
            x_r_k = torch.chunk(x_r_batched, self.K_steps, dim=0)
            if self.pixel_bound:
                x_r_k = [torch.sigmoid(x) for x in x_r_k]
            comp_stats = None

        # --- Reconstruct input image by marginalising (aka summing) ---
        x_r_stack = torch.stack(x_r_k, dim=4)
        m_stack = torch.stack(log_m_k, dim=4).exp()
        recon = (m_stack * x_r_stack).sum(dim=4)

        # --- Loss terms ---
        losses = AttrDict()
        # -- Reconstruction loss
        losses['err'] = self.x_loss(x, log_m_k, x_r_k, self.std)
        # -- Attention mask KL
        # Using normalising flow, arbitrary posterior
        if 'zm_0_k' in att_stats and 'zm_k_k' in att_stats:
            q_zm_0_k = [Normal(m, s) for m, s in
                        zip(att_stats.mu_k, att_stats.sigma_k)]
            zm_0_k = att_stats.z_0_k
            zm_k_k = att_stats.z_k_k  # TODO(martin) variable name not ideal
            ldj_k = att_stats.ldj_k
        # No flow, Gaussian posterior
        else:
            q_zm_0_k = [Normal(m, s) for m, s in
                        zip(att_stats.mu_k, att_stats.sigma_k)]
            zm_0_k = att_stats.z_k
            zm_k_k = att_stats.z_k
            ldj_k = None
        # Compute loss
        losses['kl_m_k'], p_zm_k = self.mask_latent_loss(
            q_zm_0_k, zm_0_k, zm_k_k, ldj_k, self.prior_lstm, self.prior_linear,
            debug=self.debug or not self.training)
        att_stats['pmu_k'] = [p_zm.mean for p_zm in p_zm_k]
        att_stats['psigma_k'] = [p_zm.scale for p_zm in p_zm_k]
        # Sanity checks
        if self.debug or not self.training:
            assert len(zm_k_k) == self.K_steps
            assert len(zm_0_k) == self.K_steps
            if ldj_k is not None:
                assert len(ldj_k) == self.K_steps
        # -- Component KL
        if self.two_stage:
            if self.comp_prior:
                losses['kl_l_k'] = []
                comp_stats['pmu_k'], comp_stats['psigma_k'] = [], []
                for step, zl in enumerate(comp_stats.z_k):
                    mlp_out = self.prior_mlp(zm_k_k[step])
                    mlp_out = torch.chunk(mlp_out, 2, dim=1)
                    mu = torch.tanh(mlp_out[0])
                    sigma = B.to_prior_sigma(mlp_out[1])
                    p_zl = Normal(mu, sigma)
                    comp_stats['pmu_k'].append(mu)
                    comp_stats['psigma_k'].append(sigma)
                    q_zl = Normal(comp_stats.mu_k[step], comp_stats.sigma_k[step])
                    kld = (q_zl.log_prob(zl) - p_zl.log_prob(zl)).sum(dim=1)
                    losses['kl_l_k'].append(kld)
                # Sanity checks
                if self.debug or not self.training:
                    assert len(comp_stats.z_k) == self.K_steps
                    assert len(comp_stats['pmu_k']) == self.K_steps
                    assert len(comp_stats['psigma_k']) == self.K_steps
                    assert len(losses['kl_l_k']) == self.K_steps
            else:
                raise NotImplementedError

        # Tracking
        stats = AttrDict(
            recon=recon, log_m_k=log_m_k, log_s_k=log_s_k, x_r_k=x_r_k,
            mx_r_k=[x*logm.exp() for x, logm in zip(x_r_k, log_m_k)])

        # Sanity check that masks sum to one if in debug mode
        if self.debug or not self.training:
            assert len(log_m_k) == self.K_steps
            misc.check_log_masks(log_m_k)

        return recon, losses, stats, att_stats, comp_stats
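
For reference, the final reconstruction step in forward() (and in sample() above) marginalises over the K slots by weighting each component with its exponentiated log-mask and summing. A minimal sketch with made-up shapes and softmax-normalised masks:

    import torch

    # Illustrative shapes: K components of shape [N, 3, H, W] and matching
    # log-masks of shape [N, 1, H, W] that sum to one after exponentiation.
    N, K, H, W = 2, 4, 32, 32
    x_r_k = [torch.rand(N, 3, H, W) for _ in range(K)]
    logits = torch.randn(N, K, 1, H, W)
    log_m_k = list(torch.log_softmax(logits, dim=1).unbind(dim=1))

    # Marginalise over the K slots: recon = sum_k m_k * x_k, as in forward().
    x_r_stack = torch.stack(x_r_k, dim=4)        # [N, 3, H, W, K]
    m_stack = torch.stack(log_m_k, dim=4).exp()  # [N, 1, H, W, K]
    recon = (m_stack * x_r_stack).sum(dim=4)     # [N, 3, H, W]
    print(recon.shape)  # torch.Size([2, 3, 32, 32])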