class VAE(nn.Module):
    """Convolutional variational auto-encoder with a Gaussian latent prior.

    Wires together an ``Encoder`` that maps images to the parameters of a
    Gaussian latent distribution and a ``Decoder`` that reconstructs images
    from latent samples drawn via the prior's reparameterized ``sample``.
    """

    def __init__(self, img_params, model_params, latent_params):
        super(VAE, self).__init__()
        # Image geometry.
        img_channels = img_params['image_dim']
        img_size = img_params['image_size']
        # Backbone hyper-parameters.
        downsamples = model_params['n_downsample']
        base_dim = model_params['dim']
        res_blocks = model_params['n_res']
        norm_type = model_params['norm']
        activation = model_params['activ']
        padding = model_params['pad_type']
        mlp_layers = model_params['n_mlp']
        mlp_width = model_params['mlp_dim']

        self.latent_dim = latent_params['latent_dim']
        self.prior = Gaussian(self.latent_dim)
        self.encoder = Encoder(downsamples, res_blocks, mlp_layers, img_size,
                               img_channels, base_dim, mlp_width, self.latent_dim,
                               norm_type, activation, padding)
        # Spatial size of the feature map the decoder's conv stack starts from.
        bottleneck_size = img_size // (2 ** downsamples)
        self.decoder = Decoder(downsamples, res_blocks, mlp_layers, self.latent_dim,
                               mlp_width, bottleneck_size, base_dim, img_channels,
                               norm_type, activation, padding)

    def forward(self, x):
        """Encode, sample the latent, decode.

        Returns (reconstruction, activated latent distribution, latent samples).
        """
        raw_distr = self.encoder(x)
        activated_distr = self.prior.activate(raw_distr)
        z = self.prior.sample(activated_distr)
        return self.decoder(z), activated_distr, z
class CatVAE(nn.Module):
    """Auto-encoder with a joint continuous (Gaussian) + categorical latent code.

    The encoder emits one flat vector per image: the last ``categorical_dim``
    units are treated as categorical logits and the remaining units as the
    parameters of a Gaussian. Training-time categorical sampling goes through
    a Gumbel-softmax relaxation so gradients flow.
    """

    def __init__(self, img_params, model_params, latent_params):
        super(CatVAE, self).__init__()
        # Image geometry.
        img_channels = img_params['image_dim']
        img_size = img_params['image_size']
        # Backbone hyper-parameters.
        downsamples = model_params['n_downsample']
        base_dim = model_params['dim']
        res_blocks = model_params['n_res']
        norm_type = model_params['norm']
        activation = model_params['activ']
        padding = model_params['pad_type']
        mlp_layers = model_params['n_mlp']
        mlp_width = model_params['mlp_dim']

        # NOTE(review): 'continious' is the (misspelled) key callers pass in
        # latent_params; kept verbatim to stay interface-compatible.
        self.continious_dim = latent_params['continious']
        self.prior_cont = Gaussian(self.continious_dim)
        self.categorical_dim = latent_params['categorical']
        self.prior_catg = Categorical(self.categorical_dim)
        self.gumbel = Gumbel(self.categorical_dim)

        self.encoder = CatEncoder(downsamples, res_blocks, mlp_layers, img_size,
                                  img_channels, base_dim, mlp_width, latent_params,
                                  norm_type, activation, padding)
        # Spatial size of the feature map the decoder's conv stack starts from.
        bottleneck_size = img_size // (2 ** downsamples)
        # Decoder consumes the concatenated [continuous | categorical] code.
        decoder_inp_dim = self.continious_dim + self.categorical_dim
        self.decoder = Decoder(downsamples, res_blocks, mlp_layers, decoder_inp_dim,
                               mlp_width, bottleneck_size, base_dim, img_channels,
                               norm_type, activation, padding)

    def forward(self, x, tempr):
        """Training pass with Gumbel-softmax temperature ``tempr``.

        Returns (reconstruction, concatenated latent samples,
        activated categorical distribution, activated continuous distribution);
        the activated distributions feed the KL terms of the loss.
        """
        joint_distr = self.encoder(x)

        # Categorical slice: last `categorical_dim` units are logits.
        catg_logits = joint_distr[:, -self.categorical_dim:]
        categorical_distr_act = self.prior_catg.activate(catg_logits)  # for KL
        # Relaxed sample used for the reconstruction path.
        catg_samples = self.gumbel.gumbel_softmax_sample(catg_logits, tempr)

        # Continuous slice: everything before the categorical logits.
        cont_params = joint_distr[:, :-self.categorical_dim]
        continious_distr_act = self.prior_cont.activate(cont_params)
        cont_samples = self.prior_cont.sample(continious_distr_act)

        # Full latent code = [continuous | categorical].
        full_samples = torch.cat([cont_samples, catg_samples], 1)
        recons = self.decoder(full_samples)
        return recons, full_samples, categorical_distr_act, continious_distr_act

    def encode_decode(self, x, tempr=0.4, hard_catg=True):
        """Deterministic inference pass: no sampling noise.

        Uses the Gaussian mean for the continuous part and either a hard
        one-hot (``hard_catg=True``) or the softmax probabilities for the
        categorical part. Returns (reconstruction, full latent code).
        """
        joint_distr = self.encoder(x)

        catg_logits = joint_distr[:, -self.categorical_dim:]
        if hard_catg:
            # Hard one-hot vector from the logits.
            catg_samples = self.prior_catg.logits_to_onehot(catg_logits)
        else:
            # Smoothed one-hot via softmax probabilities.
            catg_samples = self.prior_catg.activate(catg_logits)['prob']

        cont_params = joint_distr[:, :-self.categorical_dim]
        cont_act = self.prior_cont.activate(cont_params)
        cont_samples = cont_act['mean']

        full_samples = torch.cat([cont_samples, catg_samples], 1)
        recons = self.decoder(full_samples)
        return recons, full_samples

    def sample_full_prior(self, batch_size, device='cuda:0'):
        """Draw ``batch_size`` latent codes from both priors and concatenate."""
        cont_samples = self.prior_cont.sample_prior(batch_size, device=device)
        catg_samples = self.prior_catg.sample_prior(batch_size, device=device)
        return torch.cat([cont_samples, catg_samples], 1)