Пример #1
0
    def generate(self, n_samples):
        """Sample ``n_samples`` images from the AIR generative model (AIR paper, figure 3 left).

        The generative model draws n ~ Geom(rho) digits {y_i_att} (28 x 28 in the paper,
        ``self.A`` x ``self.A`` here). It scales and shifts them according to
        z_i_where ~ N(0, Sigma) using spatial transformers. It then sums the results {y_i}
        into one image (50 x 50 in the paper, ``self.H`` x ``self.W`` here).
        Each digit is obtained by first sampling a latent code z_i_what from its prior
        (``self.p_z_what``; N(0, 1) in the paper). That code is then propagated through the
        decoder network of a variational autoencoder. The learnable parameters theta of the
        generative model are the parameters of this decoder network.

        Args:
            n_samples: number of images to generate.

        Returns:
            Tensor of shape (n_samples, self.C, self.H, self.W): the summed canvas.
        """
        # Number of objects per image. torch's Geometric counts failures before the first
        # success, so with success probability (1 - z_pres_prior) each additional object is
        # kept with probability z_pres_prior. The count is capped at max_steps.
        num_objects = D.Geometric(1 - self.z_pres_prior).sample((n_samples, ))
        num_objects = num_objects.clamp_(0, self.max_steps)
        device = num_objects.device

        # Presence mask: step_mask[b, t] == 1 iff t < num_objects[b]. For example:
        #   num_objects = [1, 4, 2, 0], max_steps = 5  ->
        #   [[1,0,0,0,0],
        #    [1,1,1,1,0],
        #    [1,1,0,0,0],
        #    [0,0,0,0,0]]
        # Every image gets an object drawn at each step, but only images whose mask
        # entry is 1 actually receive it on their canvas.
        step_ids = torch.arange(self.max_steps).float().to(device)
        step_ids = step_ids.expand(n_samples, self.max_steps)
        step_mask = (step_ids < num_objects.view(-1, 1)).float().to(device)

        # Empty canvas that the decoded objects are accumulated onto.
        canvas = torch.zeros(n_samples, self.C, self.H, self.W).to(device)
        canvas_shape = (n_samples, self.C, self.H, self.W)

        # Only iterate up to the largest sampled object count in the batch.
        n_steps = int(num_objects.max().item())
        for step in range(n_steps):
            # Draw latent codes from their priors (z_what first, then z_where).
            z_what = self.p_z_what.sample((n_samples, ))
            z_where = self.p_z_where.sample((n_samples, ))

            # Decode to an A x A glimpse, then place it on the canvas with the spatial
            # transformer (inverse=True; presumably glimpse -> canvas coordinates).
            logits = self.decoder(z_what).view(n_samples, self.C, self.A, self.A)
            y_att = torch.sigmoid(logits + self.decoder_bias)
            y = stn(y_att,
                    z_where,
                    canvas_shape,
                    inverse=True,
                    box_attn_window_color=step)

            # Add this step's object only where the presence mask is on.
            canvas = canvas + y * step_mask[:, step].view(-1, 1, 1, 1)
        return canvas
Пример #2
0
 def p_z_pres(self):
     """Return a Geometric distribution parameterized by ``1 - self.z_pres_prob``.

     torch's Geometric counts failures before the first success. With success
     probability ``1 - self.z_pres_prob``, each additional count therefore occurs
     with probability ``self.z_pres_prob``. Presumably this serves as the prior on
     the number of present objects; confirm against callers.
     """
     success_prob = 1 - self.z_pres_prob
     return D.Geometric(probs=success_prob)