def dip_conv_vae(args):
    from models.beta import beta_celeb_encoder, beta_celeb_decoder
    encoder, decoder = beta_celeb_encoder(args), beta_celeb_decoder(args)
    dip_type = 'dip_vae_i' if args.model == 'dip_conv_vae_i' else 'dip_vae_ii'
    return DipVAE(encoder, decoder, args.beta, args.lambda_d, args.lambda_od,
                  dip_type, args.capacity, args.capacity_leadin, args.xav_init)
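# Hedged sketch (not the repo's DipVAE internals): the covariance penalty
# that lambda_d / lambda_od weight in DIP-VAE-I, computed on the posterior
# means mu of shape (batch, latents). The helper name is illustrative.
def _dip_penalty_sketch(mu, lambda_d, lambda_od):
    centered = mu - mu.mean(dim=0, keepdim=True)
    cov = centered.t() @ centered / mu.size(0)  # Cov[mu] over the batch
    diag = torch.diagonal(cov)
    off_diag = cov - torch.diag(diag)
    # Push diagonal entries toward 1 and off-diagonal entries toward 0.
    return lambda_d * ((diag - 1.0) ** 2).sum() + lambda_od * (off_diag ** 2).sum()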
def __init__(self, args):
    super().__init__(beta_celeb_encoder(args), beta_celeb_decoder(args),
                     args.beta, args.capacity, args.capacity_leadin)
    if args.xav_init:
        # Xavier-initialise every conv / linear layer in both networks.
        for m in list(self.encoder.modules()) + list(self.decoder.modules()):
            if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
                torch.nn.init.xavier_uniform_(m.weight)
    self.uneven_reg_maxval = args.uneven_reg_maxval
    self.uneven_reg_lambda = args.uneven_reg_lambda
    self.uneven_reg_encoder_lambda = args.uneven_reg_encoder_lambda
    self.reg_type = args.reg_type
    self.orth_lambda = args.orth_lambda
    if self.reg_type in ('cumax_ada', 'monoconst_ada'):
        # Learnable per-latent weights for the adaptive regulariser.
        self.ada_logits = nn.Parameter(torch.ones(args.latents), requires_grad=True)
    self.lpips_lambda = args.lpips_lambda
    if self.lpips_lambda > 0:
        self.lpips_fn = lpips.LPIPS(net='alex').cuda()
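# Illustrative only (not from the repo): an argparse-style namespace with
# every field the constructor above reads; the values are placeholders,
# not tuned settings.
def _example_args():
    from types import SimpleNamespace
    return SimpleNamespace(
        beta=4.0, capacity=0.0, capacity_leadin=None, xav_init=True,
        latents=10, uneven_reg_maxval=10, uneven_reg_lambda=0.0,
        uneven_reg_encoder_lambda=0.0, reg_type='cumax_ada', orth_lambda=0.0,
        lpips_lambda=0.0,  # keep 0 to skip building the CUDA LPIPS network
    )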
def factor_conv_vae(args):
    from models.beta import beta_celeb_encoder, beta_celeb_decoder
    encoder, decoder = beta_celeb_encoder(args), beta_celeb_decoder(args)
    return FactorVAE(encoder, decoder, args.beta, args.latents, args.capacity,
                     args.capacity_leadin, args.factor_vae_gamma, args.xav_init)
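# Hedged sketch of the permute-dims trick behind FactorVAE's
# factor_vae_gamma term (the repo's FactorVAE class encapsulates its own
# discriminator; this helper is illustrative only).
def _permute_dims_sketch(z):
    # Shuffle each latent dimension independently across the batch, so the
    # result approximates samples from the product of marginals q(z_j).
    b, d = z.size()
    return torch.stack([z[torch.randperm(b, device=z.device), j]
                        for j in range(d)], dim=1)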
def dimvar_vae_64(args):
    from models.beta import beta_celeb_encoder, beta_celeb_decoder
    encoder, decoder = beta_celeb_encoder(args), beta_celeb_decoder(args)
    return DimVarVAE(encoder, decoder, args)
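# Illustrative dispatch, assuming args.model names one of the factories
# above. Only 'dip_conv_vae_i'/'dip_conv_vae_ii' are confirmed by the code;
# the other keys are guessed from the factory names.
_MODEL_FACTORIES = {
    'dip_conv_vae_i': dip_conv_vae,
    'dip_conv_vae_ii': dip_conv_vae,
    'factor_conv_vae': factor_conv_vae,
    'dimvar_vae_64': dimvar_vae_64,
}

def build_model(args):
    return _MODEL_FACTORIES[args.model](args)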