Ejemplo n.º 1
0
 def __init__(self, cfg, act):
     """Build the convolutional encoder of the MONet-style component VAE.

     The network takes an input with ``nin + 1`` channels (presumably the
     image concatenated with one attention/mask channel -- TODO confirm
     against the caller). It applies four 3x3 stride-2 convolutions, flattens
     the result, and passes it through a two-layer MLP that outputs
     ``2 * cfg.comp_ldim`` values (presumably the location and scale
     parameters of the latent posterior -- TODO confirm).

     Args:
         cfg: Configuration namespace. Reads ``comp_enc_channels``,
             ``comp_ldim``, ``img_size`` and, optionally,
             ``input_channels`` (defaults to 3).
         act: Activation module inserted after every hidden layer. The same
             instance is reused at every position, so it should be a
             stateless activation such as ``nn.ELU()``.
     """
     super(MONetCompEncoder, self).__init__()
     nin = getattr(cfg, 'input_channels', 3)
     c = cfg.comp_enc_channels
     self.ldim = cfg.comp_ldim
     # Spatial side length after the four Conv2d(kernel=3, stride=2,
     # padding=1) layers: each maps s -> (s - 1) // 2 + 1 == ceil(s / 2).
     # Plain `img_size // 16` would under-count whenever img_size is not a
     # multiple of 16, giving the first Linear layer the wrong input size.
     side = cfg.img_size
     for _ in range(4):
         side = (side - 1) // 2 + 1
     nin_mlp = 2 * c * side**2
     nhid_mlp = max(256, 2 * self.ldim)
     self.module = Seq(
         nn.Conv2d(nin + 1, c, 3, 2, 1), act,
         nn.Conv2d(c, c, 3, 2, 1), act,
         nn.Conv2d(c, 2 * c, 3, 2, 1), act,
         nn.Conv2d(2 * c, 2 * c, 3, 2, 1), act,
         B.Flatten(),
         nn.Linear(nin_mlp, nhid_mlp), act,
         nn.Linear(nhid_mlp, 2 * self.ldim))
Ejemplo n.º 2
0
    def __init__(self, cfg):
        """Build the Genesis model from a configuration object.

        Registers the attention VAE core, the attention process (for more
        than one step), either a component VAE (two-stage model) or a
        broadcast decoder (single-stage model), the optional learned priors,
        and a per-step pixel standard-deviation buffer.

        Args:
            cfg: Configuration namespace. Attributes read directly here
                include ``K_steps``, ``img_size``, ``two_stage``,
                ``autoreg_prior``, ``attention_latents``, ``pixel_bound``,
                ``debug``, ``montecarlo_kl``, ``enc_norm``, ``dec_norm``,
                ``pixel_std1`` and ``pixel_std2``. ``comp_prior``,
                ``comp_ldim``, ``comp_dec_channels`` and ``comp_dec_layers``
                are only read when the option that needs them is enabled.
                ``input_channels`` (default 3) and ``comp_symmetric``
                (default False) are optional.
                NOTE: ``cfg`` is mutated in place: ``comp_symmetric`` is set
                to False when it is missing.

        Raises:
            ValueError: if ``cfg.montecarlo_kl`` is not enabled, or if the
                single-stage model (``two_stage`` false) is requested with
                ``K_steps <= 1``.
        """
        super(Genesis, self).__init__()
        # --- Configuration ---
        # Data dependent config
        self.K_steps = cfg.K_steps
        self.img_size = cfg.img_size
        # Model config
        self.two_stage = cfg.two_stage
        self.autoreg_prior = cfg.autoreg_prior
        # The component prior is forced off unless this is a two-stage model
        # with more than one step, regardless of what cfg requests.
        self.comp_prior = False
        if self.two_stage and self.K_steps > 1:
            self.comp_prior = cfg.comp_prior
        self.ldim = cfg.attention_latents
        self.pixel_bound = cfg.pixel_bound
        # Default config for backwards compatibility. This mutates cfg in
        # place, presumably so that later consumers of cfg (e.g.
        # ComponentVAE) also see the default -- TODO confirm.
        if not hasattr(cfg, 'comp_symmetric'):
            cfg.comp_symmetric = False
        # Sanity checks. These are raised explicitly rather than via `assert`
        # so they still fire when Python runs with -O.
        self.debug = cfg.debug
        if not cfg.montecarlo_kl:
            # ALWAYS use MC for estimating KL
            raise ValueError(
                "Genesis requires cfg.montecarlo_kl to be enabled "
                "(KL terms are always estimated by Monte Carlo).")
        if not self.two_stage and self.K_steps <= 1:
            # Checked before any submodule is built so an invalid config
            # fails fast.
            raise ValueError(
                "The single-stage Genesis model requires cfg.K_steps > 1, "
                f"got {self.K_steps}.")

        # --- Modules ---
        input_channels = getattr(cfg, 'input_channels', 3)
        # - Attention core
        # Built unconditionally: besides feeding the attention process, its
        # last_kernel_size is reused by the symmetric component VAE below.
        att_nin = input_channels
        att_nout = 1
        att_core = sylvester.VAE(self.ldim,
                                 [att_nin, cfg.img_size, cfg.img_size],
                                 att_nout, cfg.enc_norm, cfg.dec_norm)
        # - Attention process (only needed when there is more than one step)
        if self.K_steps > 1:
            self.att_steps = self.K_steps
            self.att_process = seq_att.LatentSBP(att_core)
        # - Component VAE
        if self.two_stage:
            self.comp_vae = ComponentVAE(nout=input_channels,
                                         cfg=cfg,
                                         act=nn.ELU())
            if cfg.comp_symmetric:
                # Replace the default encoder/decoder with networks built by
                # the same sylvester builders (and kernel size) as the
                # attention core.
                self.comp_vae.encoder_module = nn.Sequential(
                    sylvester.build_gc_encoder(
                        [input_channels + 1, 32, 32, 64, 64],
                        [32, 32, 64, 64, 64], [1, 2, 1, 2, 1],
                        2 * cfg.comp_ldim,
                        att_core.last_kernel_size,
                        hn=cfg.enc_norm,
                        gn=cfg.enc_norm), B.Flatten())
                self.comp_vae.decoder_module = nn.Sequential(
                    B.UnFlatten(),
                    sylvester.build_gc_decoder([64, 64, 32, 32, 32],
                                               [64, 32, 32, 32, 32],
                                               [1, 2, 1, 2, 1],
                                               cfg.comp_ldim,
                                               att_core.last_kernel_size,
                                               hn=cfg.dec_norm,
                                               gn=cfg.dec_norm),
                    nn.Conv2d(32, input_channels, 1))
        else:
            # Single-stage model: decode the attention latents directly.
            self.decoder = decoders.BroadcastDecoder(
                in_chnls=self.ldim,
                out_chnls=input_channels,
                h_chnls=cfg.comp_dec_channels,
                num_layers=cfg.comp_dec_layers,
                img_dim=self.img_size,
                act=nn.ELU())

        # --- Priors ---
        # Optional: Autoregressive prior
        if self.autoreg_prior and self.K_steps > 1:
            self.prior_lstm = nn.LSTM(self.ldim, 256)
            self.prior_linear = nn.Linear(256, 2 * self.ldim)
        # Optional: Component prior - only relevant for two stage model
        if self.comp_prior and self.two_stage and self.K_steps > 1:
            self.prior_mlp = nn.Sequential(nn.Linear(self.ldim, 256), nn.ELU(),
                                           nn.Linear(256, 256), nn.ELU(),
                                           nn.Linear(256, 2 * cfg.comp_ldim))

        # --- Output pixel distribution ---
        # Per-step pixel std of shape (1, 1, 1, 1, K_steps): pixel_std1 for
        # the first step, pixel_std2 for every remaining step. Registered as
        # a buffer so it moves with the module and is saved in state_dict.
        std = cfg.pixel_std2 * torch.ones(1, 1, 1, 1, self.K_steps)
        std[0, 0, 0, 0, 0] = cfg.pixel_std1  # first step
        self.register_buffer('std', std)