def generate(self, n_samples: int):
    """Sample `n_samples` images from the generative model.

    AIR paper, figure 3 (left): the generative model draws n ∼ Geom(ρ) digits
    {y_i_att} of size 28 × 28 (two shown), scales and shifts them according to
    z_i_where ∼ N(0, Σ) using spatial transformers, and sums the results {y_i}
    to form a 50 × 50 image. Each digit is obtained by first sampling a latent
    code z_i_what from the prior z_i_what ∼ N(0, 1) and propagating it through
    the decoder network of a variational autoencoder. The learnable parameters
    θ of the generative model are the parameters of this decoder network.

    In this implementation the glimpse size is `self.A` and the canvas size is
    `self.H` x `self.W` with `self.C` channels.

    Args:
        n_samples: number of images to draw. 0 yields an empty batch.

    Returns:
        Tensor of shape (n_samples, self.C, self.H, self.W): the sum over
        steps of the transformed object renderings, each zeroed for the
        samples whose object count is already exhausted.
    """
    # sample z_pres ~ Geom(rho) -- the number of objects present in each image,
    # capped at max_steps
    z_pres = D.Geometric(1 - self.z_pres_prior).sample(
        (n_samples,)).clamp_(0, self.max_steps)
    device = z_pres.device

    # compute a mask from z_pres: row b has ones in its first z_pres[b] columns, e.g.
    #   z_pres = [1,4,2,0]
    #   mask = [[1,0,0,0,0],
    #           [1,1,1,1,0],
    #           [1,1,0,0,0],
    #           [0,0,0,0,0]]
    # so step i contributes to an image only while i < z_pres for that image.
    # The index tensor is created directly on the target device.
    step_idx = torch.arange(self.max_steps, dtype=torch.float, device=device)
    z_pres_mask = (step_idx.expand(n_samples, self.max_steps)
                   < z_pres.view(-1, 1)).float()

    # initialize image canvas
    x = torch.zeros(n_samples, self.C, self.H, self.W, device=device)

    # An empty batch has no maximum (Tensor.max() raises on zero elements);
    # treat it as "zero objects" so n_samples == 0 returns an empty canvas.
    n_objects = int(z_pres.max().item()) if z_pres.numel() > 0 else 0

    # generate objects, up to the largest count sampled via z_pres
    for i in range(n_objects):
        # sample priors for the whole batch (masked out below where unused)
        z_what = self.p_z_what.sample((n_samples,))
        z_where = self.p_z_where.sample((n_samples,))

        # propagate through the decoder, then scale and shift y_att according
        # to z_where using the spatial transformer
        y_att = torch.sigmoid(
            self.decoder(z_what).view(n_samples, self.C, self.A, self.A)
            + self.decoder_bias)
        # NOTE(review): stn is defined outside this block; box_attn_window_color
        # presumably selects a per-step box colour -- confirm against stn.
        y = stn(y_att, z_where, (n_samples, self.C, self.H, self.W),
                inverse=True, box_attn_window_color=i)

        # apply mask and sum results towards final image
        x = x + y * z_pres_mask[:, i].view(-1, 1, 1, 1)

    return x
def p_z_pres(self):
    """Return the prior over z_pres as a `D.Geometric` distribution.

    The distribution is parameterised with `probs = 1 - self.z_pres_prob`.
    """
    # NOTE(review): generate() reads self.z_pres_prior while this reads
    # self.z_pres_prob -- confirm both attributes exist and that the
    # difference is intended.
    success_prob = 1 - self.z_pres_prob
    return D.Geometric(probs=success_prob)