def __init__(self, cfg, act):
    """Build the MONet component encoder.

    Architecture: four stride-2 conv layers (halving spatial dims each
    time, i.e. 16x total reduction), a flatten, and a two-layer MLP head.

    Args:
        cfg: config object; reads `input_channels` (optional, default 3),
            `comp_enc_channels`, `comp_ldim` and `img_size`.
        act: activation module instance, inserted after each conv/linear.
    """
    super(MONetCompEncoder, self).__init__()
    # Older configs may lack `input_channels`; default to 3 (RGB).
    nin = getattr(cfg, 'input_channels', 3)
    c = cfg.comp_enc_channels
    self.ldim = cfg.comp_ldim
    # After four stride-2 convs the feature map is (img_size // 16)^2
    # spatially with 2*c channels, which fixes the MLP input size.
    nin_mlp = 2 * c * (cfg.img_size // 16) ** 2
    nhid_mlp = max(256, 2 * self.ldim)
    # NOTE(review): the extra +1 input channel presumably carries the
    # attention mask concatenated to the image — confirm against caller.
    self.module = Seq(
        nn.Conv2d(nin + 1, c, 3, 2, 1), act,
        nn.Conv2d(c, c, 3, 2, 1), act,
        nn.Conv2d(c, 2 * c, 3, 2, 1), act,
        nn.Conv2d(2 * c, 2 * c, 3, 2, 1), act,
        B.Flatten(),
        nn.Linear(nin_mlp, nhid_mlp), act,
        # 2 * ldim outputs — presumably mean and (log-)scale of the
        # component latent; verify against the sampling code.
        nn.Linear(nhid_mlp, 2 * self.ldim))
def __init__(self, cfg):
    """Construct the Genesis model from a config object.

    Wires up: a Sylvester-flow attention VAE wrapped in a latent
    stick-breaking attention process, either a two-stage component VAE
    or a broadcast decoder, optional autoregressive / component latent
    priors, and a fixed per-step std buffer for the pixel likelihood.
    """
    super(Genesis, self).__init__()

    # --- Configuration ---
    # Data-dependent config
    self.K_steps = cfg.K_steps
    self.img_size = cfg.img_size
    # Model config
    self.two_stage = cfg.two_stage
    self.autoreg_prior = cfg.autoreg_prior
    # A component prior is only meaningful for a multi-step two-stage model.
    self.comp_prior = (
        cfg.comp_prior if (self.two_stage and self.K_steps > 1) else False)
    self.ldim = cfg.attention_latents
    self.pixel_bound = cfg.pixel_bound
    # Backwards compatibility: older configs lack `comp_symmetric`.
    if not hasattr(cfg, 'comp_symmetric'):
        cfg.comp_symmetric = False
    # Sanity checks
    self.debug = cfg.debug
    assert cfg.montecarlo_kl == True  # ALWAYS use MC for estimating KL

    # --- Modules ---
    # Older configs may lack `input_channels`; default to 3 (RGB).
    num_channels = getattr(cfg, 'input_channels', 3)
    # - Attention core: VAE over [C, H, W] images with a single output channel
    att_core = sylvester.VAE(
        self.ldim, [num_channels, cfg.img_size, cfg.img_size], 1,
        cfg.enc_norm, cfg.dec_norm)
    # - Attention process
    if self.K_steps > 1:
        self.att_steps = self.K_steps
        self.att_process = seq_att.LatentSBP(att_core)
    # - Component VAE (two-stage) or plain broadcast decoder (one-stage)
    if self.two_stage:
        self.comp_vae = ComponentVAE(nout=num_channels, cfg=cfg, act=nn.ELU())
        if cfg.comp_symmetric:
            # Replace the component VAE's encoder/decoder with gated-conv
            # stacks built the same way as the attention core's.
            self.comp_vae.encoder_module = nn.Sequential(
                sylvester.build_gc_encoder(
                    [num_channels + 1, 32, 32, 64, 64],
                    [32, 32, 64, 64, 64],
                    [1, 2, 1, 2, 1],
                    2 * cfg.comp_ldim,
                    att_core.last_kernel_size,
                    hn=cfg.enc_norm, gn=cfg.enc_norm),
                B.Flatten())
            self.comp_vae.decoder_module = nn.Sequential(
                B.UnFlatten(),
                sylvester.build_gc_decoder(
                    [64, 64, 32, 32, 32],
                    [64, 32, 32, 32, 32],
                    [1, 2, 1, 2, 1],
                    cfg.comp_ldim,
                    att_core.last_kernel_size,
                    hn=cfg.dec_norm, gn=cfg.dec_norm),
                nn.Conv2d(32, num_channels, 1))
    else:
        assert self.K_steps > 1
        self.decoder = decoders.BroadcastDecoder(
            in_chnls=self.ldim,
            out_chnls=num_channels,
            h_chnls=cfg.comp_dec_channels,
            num_layers=cfg.comp_dec_layers,
            img_dim=self.img_size,
            act=nn.ELU())

    # --- Priors ---
    # Optional: autoregressive prior over the attention latents.
    if self.autoreg_prior and self.K_steps > 1:
        self.prior_lstm = nn.LSTM(self.ldim, 256)
        self.prior_linear = nn.Linear(256, 2 * self.ldim)
    # Optional: component prior — only relevant for the two-stage model.
    if self.comp_prior and self.two_stage and self.K_steps > 1:
        self.prior_mlp = nn.Sequential(
            nn.Linear(self.ldim, 256), nn.ELU(),
            nn.Linear(256, 256), nn.ELU(),
            nn.Linear(256, 2 * cfg.comp_ldim))

    # --- Output pixel distribution ---
    # Fixed per-step std, shaped [1, 1, 1, 1, K]; the first step gets its
    # own value. Registered as a buffer so it moves with the module's
    # device/dtype without being a learnable parameter.
    std = cfg.pixel_std2 * torch.ones(1, 1, 1, 1, self.K_steps)
    std[0, 0, 0, 0, 0] = cfg.pixel_std1  # first step
    self.register_buffer('std', std)