def forward(self, x): """ Args: x (torch.Tensor): Input images [batch size, 3, dim, dim] """ # --- Predict segmentation masks --- log_m_k, log_s_k, att_stats = self.att_process(x, self.K_steps - 1) # --- Reconstruct components --- x_m_r_k, comp_stats = self.comp_vae(x, log_m_k) # Split into appearances and mask prior x_r_k = [item[:, :3, :, :] for item in x_m_r_k] m_r_logits_k = [item[:, 3:, :, :] for item in x_m_r_k] # Apply pixelbound if self.pixel_bound: x_r_k = [torch.sigmoid(item) for item in x_r_k] # --- Reconstruct input image by marginalising (aka summing) --- x_r_stack = torch.stack(x_r_k, dim=4) m_stack = torch.stack(log_m_k, dim=4).exp() recon = (m_stack * x_r_stack).sum(dim=4) # --- Reconstruct masks --- log_m_r_stack = self.get_mask_recon_stack(m_r_logits_k, self.prior_mode, log=True) log_m_r_k = torch.split(log_m_r_stack, 1, dim=4) log_m_r_k = [m[:, :, :, :, 0] for m in log_m_r_k] # --- Loss terms --- losses = AttrDict() # -- Reconstruction loss losses['err'] = Genesis.x_loss(x, log_m_k, x_r_k, self.std) # -- Attention mask KL losses['kl_m'] = self.kl_m_loss(log_m_k=log_m_k, log_m_r_k=log_m_r_k) # -- Component KL q_z_k = [ Normal(m, s) for m, s in zip(comp_stats.mu_k, comp_stats.sigma_k) ] kl_l_k = misc.get_kl(comp_stats.z_k, q_z_k, len(q_z_k) * [Normal(0, 1)], self.mckl) losses['kl_l_k'] = [kld.sum(1) for kld in kl_l_k] # Track quantities of interest stats = AttrDict( recon=recon, log_m_k=log_m_k, log_s_k=log_s_k, x_r_k=x_r_k, log_m_r_k=log_m_r_k, mx_r_k=[x * logm.exp() for x, logm in zip(x_r_k, log_m_k)]) # Sanity check that masks sum to one if in debug mode if self.debug: assert len(log_m_k) == self.K_steps assert len(log_m_r_k) == self.K_steps misc.check_log_masks(log_m_k) misc.check_log_masks(log_m_r_k) return recon, losses, stats, att_stats, comp_stats
def sample(self, batch_size, K_steps=None):
    K_steps = self.K_steps if K_steps is None else K_steps

    # --- Mask ---
    # Sample latents
    if self.autoreg_prior:
        zm_k = [Normal(0, 1).sample([batch_size, self.ldim])]
        state = None
        # TODO(martin): reuse code from forward method?
        for _ in range(1, self.att_steps):
            lstm_out, state = self.prior_lstm(
                zm_k[-1].view(1, batch_size, -1), state)
            linear_out = self.prior_linear(lstm_out)
            mu = linear_out[0, :, :self.ldim]
            sigma = B.to_prior_sigma(linear_out[0, :, self.ldim:])
            p_zm = Normal(mu.view([batch_size, self.ldim]),
                          sigma.view([batch_size, self.ldim]))
            zm_k.append(p_zm.sample())
    else:
        p_zm = Normal(0, 1)
        zm_k = [p_zm.sample([batch_size, self.ldim])
                for _ in range(self.att_steps)]

    # Decode latents into masks
    log_m_k, log_s_k, out_k = self.att_process.masks_from_zm_k(
        zm_k, self.img_size)
    # UGLY: Correct the last mask for the one-stage model w/o softmax, or
    # when running the two-stage model for an additional step
    if len(log_m_k) == self.K_steps + 1:
        del log_m_k[-1]
        log_m_k[self.K_steps - 1] = log_s_k[self.K_steps - 1]
    # Sanity checks. This function is not called at every training step, so
    # these assert statements do not cause a big slowdown in total training
    # time
    assert len(zm_k) == self.K_steps
    assert len(log_m_k) == self.K_steps
    # Holds for both the one-stage and the two-stage model
    assert out_k[0].size(1) == 0
    misc.check_log_masks(log_m_k)

    # --- Component appearance ---
    if self.two_stage:
        # Sample latents
        if self.comp_prior:
            zc_k = []
            for zm in zm_k:
                mlp_out = torch.chunk(self.prior_mlp(zm), 2, dim=1)
                mu = torch.tanh(mlp_out[0])
                sigma = B.to_prior_sigma(mlp_out[1])
                zc_k.append(Normal(mu, sigma).sample())
        else:
            zc_k = [Normal(0, 1).sample([batch_size, self.comp_vae.ldim])
                    for _ in range(K_steps)]
        # Decode latents into components
        zc_batch = torch.cat(zc_k, dim=0)
        x_batch = self.comp_vae.decode(zc_batch)
        x_k = list(torch.chunk(x_batch, self.K_steps, dim=0))
    else:
        zm_batched = torch.cat(zm_k, dim=0)
        x_batched = self.decoder(zm_batched)
        x_k = torch.chunk(x_batched, self.K_steps, dim=0)
    if self.pixel_bound:
        x_k = [torch.sigmoid(x) for x in x_k]

    # Sanity checks
    assert len(x_k) == self.K_steps
    assert len(log_m_k) == self.K_steps
    if self.two_stage:
        assert len(zc_k) == self.K_steps

    # --- Reconstruct image ---
    x_stack = torch.stack(x_k, dim=4)
    m_stack = torch.stack(log_m_k, dim=4).exp()
    generated_image = (m_stack * x_stack).sum(dim=4)

    # Stats
    stats = AttrDict(x_k=x_k, log_m_k=log_m_k, log_s_k=log_s_k,
                     mx_k=[x * m.exp() for x, m in zip(x_k, log_m_k)])

    return generated_image, stats
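
# A standalone sketch of the autoregressive mask-latent prior rolled out in
# sample() above. The LSTM and linear head here are hypothetical stand-ins
# for self.prior_lstm and self.prior_linear (sizes are assumptions), and a
# softplus maps the second half of the linear output to a positive sigma in
# place of B.to_prior_sigma.
def _autoreg_prior_sketch(batch_size=4, ldim=16, steps=5):
    import torch
    import torch.nn as nn
    from torch.distributions import Normal
    prior_lstm = nn.LSTM(ldim, 256)          # hypothetical hidden size
    prior_linear = nn.Linear(256, 2 * ldim)  # predicts [mu, pre-sigma]
    # First latent from the unit Gaussian, the rest from the rollout
    zm_k = [Normal(0, 1).sample([batch_size, ldim])]
    state = None
    for _ in range(1, steps):
        lstm_out, state = prior_lstm(zm_k[-1].view(1, batch_size, -1), state)
        linear_out = prior_linear(lstm_out)
        mu = linear_out[0, :, :ldim]
        sigma = nn.functional.softplus(linear_out[0, :, ldim:]) + 1e-5
        zm_k.append(Normal(mu, sigma).sample())
    return zm_k  # list of `steps` latents, each [batch_size, ldim]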
def forward(self, x): """ Performs a forward pass in the model. Args: x (torch.Tensor): input images [batch size, 3, dim, dim] Returns: recon: reconstructed images [N, 3, H, W] losses: stats: att_stats: comp_stats: """ # --- Predict segmentation masks --- log_m_k, log_s_k, att_stats = self.att_process(x, self.att_steps) # UGLY: Need to correct the last mask for one stage model wo/ softmax # OR when running the two stage model for an additional step if len(log_m_k) == self.K_steps+1: del log_m_k[-1] log_m_k[self.K_steps-1] = log_s_k[self.K_steps-1] if self.debug or not self.training: assert len(log_m_k) == self.K_steps # --- Reconstruct components --- if self.two_stage: x_r_k, comp_stats = self.comp_vae(x, log_m_k) else: # x_r_k = [x[:, 1:, :, :] for x in att_stats.x_k] z_batched = torch.cat(att_stats.z_k, dim=0) x_r_batched = self.decoder(z_batched) x_r_k = torch.chunk(x_r_batched, self.K_steps, dim=0) if self.pixel_bound: x_r_k = [torch.sigmoid(x) for x in x_r_k] comp_stats = None # --- Reconstruct input image by marginalising (aka summing) --- x_r_stack = torch.stack(x_r_k, dim=4) m_stack = torch.stack(log_m_k, dim=4).exp() recon = (m_stack * x_r_stack).sum(dim=4) # --- Loss terms --- losses = AttrDict() # -- Reconstruction loss losses['err'] = self.x_loss(x, log_m_k, x_r_k, self.std) # -- Attention mask KL # Using normalising flow, arbitrary posterior if 'zm_0_k' in att_stats and 'zm_k_k' in att_stats: q_zm_0_k = [Normal(m, s) for m, s in zip(att_stats.mu_k, att_stats.sigma_k)] zm_0_k = att_stats.z_0_k zm_k_k = att_stats.z_k_k #TODO(martin) variable name not ideal ldj_k = att_stats.ldj_k # No flow, Gaussian posterior else: q_zm_0_k = [Normal(m, s) for m, s in zip(att_stats.mu_k, att_stats.sigma_k)] zm_0_k = att_stats.z_k zm_k_k = att_stats.z_k ldj_k = None # Compute loss losses['kl_m_k'], p_zm_k = self.mask_latent_loss( q_zm_0_k, zm_0_k, zm_k_k, ldj_k, self.prior_lstm, self.prior_linear, debug=self.debug or not self.training) att_stats['pmu_k'] = [p_zm.mean for p_zm in p_zm_k] att_stats['psigma_k'] = [p_zm.scale for p_zm in p_zm_k] # Sanity checks if self.debug or not self.training: assert len(zm_k_k) == self.K_steps assert len(zm_0_k) == self.K_steps if ldj_k is not None: assert len(ldj_k) == self.K_steps # -- Component KL if self.two_stage: if self.comp_prior: losses['kl_l_k'] = [] comp_stats['pmu_k'], comp_stats['psigma_k'] = [], [] for step, zl in enumerate(comp_stats.z_k): mlp_out = self.prior_mlp(zm_k_k[step]) mlp_out = torch.chunk(mlp_out, 2, dim=1) mu = torch.tanh(mlp_out[0]) sigma = B.to_prior_sigma(mlp_out[1]) p_zl = Normal(mu, sigma) comp_stats['pmu_k'].append(mu) comp_stats['psigma_k'].append(sigma) q_zl = Normal(comp_stats.mu_k[step], comp_stats.sigma_k[step]) kld = (q_zl.log_prob(zl) - p_zl.log_prob(zl)).sum(dim=1) losses['kl_l_k'].append(kld) # Sanity checks if self.debug or not self.training: assert len(comp_stats.z_k) == self.K_steps assert len(comp_stats['pmu_k']) == self.K_steps assert len(comp_stats['psigma_k']) == self.K_steps assert len(losses['kl_l_k']) == self.K_steps else: raise NotImplementedError # Tracking stats = AttrDict( recon=recon, log_m_k=log_m_k, log_s_k=log_s_k, x_r_k=x_r_k, mx_r_k=[x*logm.exp() for x, logm in zip(x_r_k, log_m_k)]) # Sanity check that masks sum to one if in debug mode if self.debug or not self.training: assert len(log_m_k) == self.K_steps misc.check_log_masks(log_m_k) return recon, losses, stats, att_stats, comp_stats