@staticmethod
def mask_latent_loss(q_zm_0_k, zm_0_k, zm_k_k, ldj_k, prior_lstm=None,
                     prior_linear=None, debug=False):
    num_steps = len(zm_k_k)
    batch_size = zm_k_k[0].size(0)
    latent_dim = zm_k_k[0].size(1)

    # -- Determine prior --
    if prior_lstm is not None and prior_linear is not None:
        # zm_seq shape: (att_steps-2, batch_size, ldim)
        # Do not need the last element in z_k
        zm_seq = torch.cat(
            [zm.view(1, batch_size, -1) for zm in zm_k_k[:-1]], dim=0)
        # lstm_out shape: (att_steps-2, batch_size, state_size)
        # Note: recurrent state is handled internally by LSTM
        lstm_out, _ = prior_lstm(zm_seq)
        # linear_out shape: (att_steps-2, batch_size, 2*ldim)
        linear_out = prior_linear(lstm_out)
        linear_out = torch.chunk(linear_out, 2, dim=2)
        mu_raw = torch.tanh(linear_out[0])
        sigma_raw = B.to_prior_sigma(linear_out[1])
        # Split into K steps, shape: (att_steps-2)*[1, batch_size, ldim]
        mu_k = torch.split(mu_raw, 1, dim=0)
        sigma_k = torch.split(sigma_raw, 1, dim=0)
        # Use standard Normal as prior for first step
        p_zm_k = [Normal(0, 1)]
        # Autoregressive prior for later steps
        for mean, std in zip(mu_k, sigma_k):
            # Remember to remove unit dimension at dim=0
            p_zm_k += [Normal(mean.view(batch_size, latent_dim),
                              std.view(batch_size, latent_dim))]
        # Sanity checks
        if debug:
            assert zm_seq.size(0) == num_steps - 1
    else:
        p_zm_k = num_steps * [Normal(0, 1)]

    # -- Compute KL using Monte Carlo samples for every step k --
    kl_m_k = []
    for step, p_zm in enumerate(p_zm_k):
        log_q = q_zm_0_k[step].log_prob(zm_0_k[step]).sum(dim=1)
        log_p = p_zm.log_prob(zm_k_k[step]).sum(dim=1)
        kld = log_q - log_p
        if ldj_k is not None:
            ldj = ldj_k[step].sum(dim=1)
            kld = kld - ldj
        kl_m_k.append(kld)

    # -- Sanity checks --
    if debug:
        assert len(p_zm_k) == num_steps
        assert len(kl_m_k) == num_steps

    return kl_m_k, p_zm_k
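
# A minimal sanity-check sketch, not part of the original model: the loss
# above is a per-step Monte-Carlo estimate of KL(q || p) (plus a log-det-
# Jacobian correction when a normalising flow is used). For the no-flow,
# standard-Normal-prior case, the estimate should agree with the analytic KL
# from torch.distributions given enough samples. All names and shapes below
# are illustrative assumptions.
def _mc_vs_analytic_kl_demo(batch_size=4, latent_dim=8, num_samples=10000):
    q = Normal(torch.randn(batch_size, latent_dim),
               torch.rand(batch_size, latent_dim) + 0.5)
    p = Normal(torch.zeros(batch_size, latent_dim),
               torch.ones(batch_size, latent_dim))
    z = q.sample([num_samples])  # [num_samples, batch_size, latent_dim]
    # Monte-Carlo KL: E_q[log q(z) - log p(z)], summed over latent dims
    mc_kl = (q.log_prob(z) - p.log_prob(z)).sum(dim=2).mean(dim=0)
    # Analytic KL between the two diagonal Gaussians, summed over latent dims
    analytic_kl = torch.distributions.kl_divergence(q, p).sum(dim=1)
    return mc_kl, analytic_kl  # should be close for large num_samples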
def sample(self, batch_size, K_steps=None):
    K_steps = self.K_steps if K_steps is None else K_steps

    # --- Mask ---
    # Sample latents
    if self.autoreg_prior:
        zm_k = [Normal(0, 1).sample([batch_size, self.ldim])]
        state = None
        for k in range(1, self.att_steps):
            # TODO(martin): reuse code from forward method?
            lstm_out, state = self.prior_lstm(
                zm_k[-1].view(1, batch_size, -1), state)
            linear_out = self.prior_linear(lstm_out)
            mu = linear_out[0, :, :self.ldim]
            sigma = B.to_prior_sigma(linear_out[0, :, self.ldim:])
            p_zm = Normal(mu.view([batch_size, self.ldim]),
                          sigma.view([batch_size, self.ldim]))
            zm_k.append(p_zm.sample())
    else:
        p_zm = Normal(0, 1)
        zm_k = [p_zm.sample([batch_size, self.ldim])
                for _ in range(self.att_steps)]
    # Decode latents into masks
    log_m_k, log_s_k, out_k = self.att_process.masks_from_zm_k(
        zm_k, self.img_size)
    # UGLY: Need to correct the last mask for the one-stage model w/o softmax
    # OR when running the two-stage model for an additional step
    if len(log_m_k) == self.K_steps+1:
        del log_m_k[-1]
        log_m_k[self.K_steps-1] = log_s_k[self.K_steps-1]
    # Sanity checks. This function is not called at every training step, so
    # the assert statements do not cause a big slow-down in training time
    assert len(zm_k) == self.K_steps
    assert len(log_m_k) == self.K_steps
    if self.two_stage:
        assert out_k[0].size(1) == 0
    else:
        # assert out_k[0].size(1) == 3
        assert out_k[0].size(1) == 0
    misc.check_log_masks(log_m_k)

    # --- Component appearance ---
    if self.two_stage:
        # Sample latents
        if self.comp_prior:
            zc_k = []
            for zm in zm_k:
                mlp_out = torch.chunk(self.prior_mlp(zm), 2, dim=1)
                mu = torch.tanh(mlp_out[0])
                sigma = B.to_prior_sigma(mlp_out[1])
                zc_k.append(Normal(mu, sigma).sample())
            # if not self.softmax_attention:
            #     zc_k.append(Normal(0, 1).sample(
            #         [batch_size, self.comp_vae.ldim]))
        else:
            zc_k = [Normal(0, 1).sample([batch_size, self.comp_vae.ldim])
                    for _ in range(K_steps)]
        # Decode latents into components
        zc_batch = torch.cat(zc_k, dim=0)
        x_batch = self.comp_vae.decode(zc_batch)
        x_k = list(torch.chunk(x_batch, self.K_steps, dim=0))
    else:
        # x_k = out_k
        zm_batched = torch.cat(zm_k, dim=0)
        x_batched = self.decoder(zm_batched)
        x_k = torch.chunk(x_batched, self.K_steps, dim=0)
        if self.pixel_bound:
            x_k = [torch.sigmoid(x) for x in x_k]
    # Sanity checks
    assert len(x_k) == self.K_steps
    assert len(log_m_k) == self.K_steps
    if self.two_stage:
        assert len(zc_k) == self.K_steps

    # --- Reconstruct image ---
    x_stack = torch.stack(x_k, dim=4)
    m_stack = torch.stack(log_m_k, dim=4).exp()
    generated_image = (m_stack * x_stack).sum(dim=4)

    # Stats
    stats = AttrDict(x_k=x_k, log_m_k=log_m_k, log_s_k=log_s_k,
                     mx_k=[x*m.exp() for x, m in zip(x_k, log_m_k)])

    return generated_image, stats
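
# Hypothetical usage sketch, not part of the original file: how sample()
# might be called to generate a grid of scenes from a trained model. The
# helper name, the `model` handle, and the torchvision dependency are
# assumptions for illustration only.
def _save_generated_grid(model, batch_size=16, path='samples.png'):
    import torchvision
    model.eval()
    with torch.no_grad():
        generated_image, stats = model.sample(batch_size)
    # generated_image: [batch_size, 3, H, W]; stats.x_k and stats.log_m_k
    # hold the per-slot appearances and log-masks used to compose the scenes
    torchvision.utils.save_image(generated_image, path, nrow=4)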
def forward(self, x):
    """
    Performs a forward pass through the model.

    Args:
        x (torch.Tensor): input images [batch size, 3, dim, dim]

    Returns:
        recon: reconstructed images [N, 3, H, W]
        losses: AttrDict of loss terms ('err', 'kl_m_k', and, for the
            two-stage model, 'kl_l_k')
        stats: AttrDict of reconstructions, masks, and masked components
        att_stats: statistics from the attention (mask) process, including
            the prior parameters 'pmu_k' and 'psigma_k'
        comp_stats: statistics from the component VAE, or None for the
            one-stage model
    """
    # --- Predict segmentation masks ---
    log_m_k, log_s_k, att_stats = self.att_process(x, self.att_steps)
    # UGLY: Need to correct the last mask for the one-stage model w/o softmax
    # OR when running the two-stage model for an additional step
    if len(log_m_k) == self.K_steps+1:
        del log_m_k[-1]
        log_m_k[self.K_steps-1] = log_s_k[self.K_steps-1]
    if self.debug or not self.training:
        assert len(log_m_k) == self.K_steps

    # --- Reconstruct components ---
    if self.two_stage:
        x_r_k, comp_stats = self.comp_vae(x, log_m_k)
    else:
        # x_r_k = [x[:, 1:, :, :] for x in att_stats.x_k]
        z_batched = torch.cat(att_stats.z_k, dim=0)
        x_r_batched = self.decoder(z_batched)
        x_r_k = torch.chunk(x_r_batched, self.K_steps, dim=0)
        if self.pixel_bound:
            x_r_k = [torch.sigmoid(x) for x in x_r_k]
        comp_stats = None

    # --- Reconstruct input image by marginalising (aka summing) ---
    x_r_stack = torch.stack(x_r_k, dim=4)
    m_stack = torch.stack(log_m_k, dim=4).exp()
    recon = (m_stack * x_r_stack).sum(dim=4)

    # --- Loss terms ---
    losses = AttrDict()

    # -- Reconstruction loss
    losses['err'] = self.x_loss(x, log_m_k, x_r_k, self.std)

    # -- Attention mask KL
    # Using normalising flow, arbitrary posterior
    if 'zm_0_k' in att_stats and 'zm_k_k' in att_stats:
        q_zm_0_k = [Normal(m, s)
                    for m, s in zip(att_stats.mu_k, att_stats.sigma_k)]
        zm_0_k = att_stats.z_0_k
        zm_k_k = att_stats.z_k_k  # TODO(martin): variable name not ideal
        ldj_k = att_stats.ldj_k
    # No flow, Gaussian posterior
    else:
        q_zm_0_k = [Normal(m, s)
                    for m, s in zip(att_stats.mu_k, att_stats.sigma_k)]
        zm_0_k = att_stats.z_k
        zm_k_k = att_stats.z_k
        ldj_k = None
    # Compute loss
    losses['kl_m_k'], p_zm_k = self.mask_latent_loss(
        q_zm_0_k, zm_0_k, zm_k_k, ldj_k,
        self.prior_lstm, self.prior_linear,
        debug=self.debug or not self.training)
    att_stats['pmu_k'] = [p_zm.mean for p_zm in p_zm_k]
    att_stats['psigma_k'] = [p_zm.scale for p_zm in p_zm_k]
    # Sanity checks
    if self.debug or not self.training:
        assert len(zm_k_k) == self.K_steps
        assert len(zm_0_k) == self.K_steps
        if ldj_k is not None:
            assert len(ldj_k) == self.K_steps

    # -- Component KL
    if self.two_stage:
        if self.comp_prior:
            losses['kl_l_k'] = []
            comp_stats['pmu_k'], comp_stats['psigma_k'] = [], []
            for step, zl in enumerate(comp_stats.z_k):
                mlp_out = self.prior_mlp(zm_k_k[step])
                mlp_out = torch.chunk(mlp_out, 2, dim=1)
                mu = torch.tanh(mlp_out[0])
                sigma = B.to_prior_sigma(mlp_out[1])
                p_zl = Normal(mu, sigma)
                comp_stats['pmu_k'].append(mu)
                comp_stats['psigma_k'].append(sigma)
                q_zl = Normal(comp_stats.mu_k[step], comp_stats.sigma_k[step])
                kld = (q_zl.log_prob(zl) - p_zl.log_prob(zl)).sum(dim=1)
                losses['kl_l_k'].append(kld)
            # Sanity checks
            if self.debug or not self.training:
                assert len(comp_stats.z_k) == self.K_steps
                assert len(comp_stats['pmu_k']) == self.K_steps
                assert len(comp_stats['psigma_k']) == self.K_steps
                assert len(losses['kl_l_k']) == self.K_steps
        else:
            raise NotImplementedError

    # Tracking
    stats = AttrDict(
        recon=recon, log_m_k=log_m_k, log_s_k=log_s_k, x_r_k=x_r_k,
        mx_r_k=[x*logm.exp() for x, logm in zip(x_r_k, log_m_k)])

    # Sanity check that masks sum to one if in debug mode
    if self.debug or not self.training:
        assert len(log_m_k) == self.K_steps
        misc.check_log_masks(log_m_k)

    return recon, losses, stats, att_stats, comp_stats
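
# Illustrative training-step sketch, an assumption rather than the original
# repository's training loop (which may, e.g., use constrained optimisation
# such as GECO instead of a fixed beta). It only shows how the per-step loss
# terms returned by forward() could be reduced to a single scalar objective.
def _training_step(model, x, optimiser, beta=1.0):
    recon, losses, stats, att_stats, comp_stats = model(x)
    # Reconstruction error, averaged over the batch
    err = losses['err'].mean()
    # Mask-latent KL: stack the K per-step terms, sum over steps, average
    kl_m = torch.stack(losses['kl_m_k'], dim=1).sum(dim=1).mean()
    # Component-latent KL, only present for the two-stage model
    kl_l = 0.0
    if 'kl_l_k' in losses:
        kl_l = torch.stack(losses['kl_l_k'], dim=1).sum(dim=1).mean()
    loss = err + beta * (kl_m + kl_l)
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    return loss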