def __init__(self, channel_num: int, z_dim: int, e_dim: int, beta: float,
             **kwargs):
    super().__init__()

    self.channel_num = channel_num
    self.z_dim = z_dim
    self.e_dim = e_dim
    self._beta_val = beta

    # Distributions
    self.normal = pxd.Normal(loc=torch.zeros(z_dim),
                             scale=torch.ones(z_dim), var=["e"])
    self.prior = pxd.Normal(loc=torch.zeros(z_dim),
                            scale=torch.ones(z_dim), var=["z"])
    self.decoder = Decoder(channel_num, z_dim)
    self.encoder = AVBEncoder(channel_num, z_dim, e_dim)
    self.distributions = [
        self.normal, self.prior, self.decoder, self.encoder
    ]

    # Loss
    self.ce = pxl.CrossEntropy(self.encoder, self.decoder)

    # Adversarial loss
    self.disc = AVBDiscriminator(channel_num, z_dim)
    self.adv_js = pxl.AdversarialJensenShannon(self.encoder, self.prior,
                                               self.disc)
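# Training-step sketch for the AVB model above. This is an illustration, not
# the class's actual method: `model` is a hypothetical instance, `x` a batch,
# and `optimizer` covers the encoder/decoder parameters. The adversarial JS
# loss approximates the KL between encoder and prior, and pixyz's
# AdversarialJensenShannon keeps its own discriminator optimizer, updated
# through its train() method.
def avb_train_step_sketch(model, x, optimizer):
    loss = (model.ce + model._beta_val * model.adv_js).mean().eval({"x": x})
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    d_loss = model.adv_js.train({"x": x})  # discriminator update
    return loss, d_loss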
def __init__(self, channel_num: int, z_dim: int, beta: float, c: float,
             lmd_od: float, lmd_d: float, dip_type: str, **kwargs):
    super().__init__()

    # Parameters
    self.channel_num = channel_num
    self.z_dim = z_dim
    self._beta_value = beta
    self._c_value = c
    self.lmd_od = lmd_od
    self.lmd_d = lmd_d

    # Distributions
    self.prior = pxd.Normal(loc=torch.zeros(z_dim),
                            scale=torch.ones(z_dim), var=["z"])
    self.decoder = Decoder(channel_num, z_dim)
    self.encoder = Encoder(channel_num, z_dim)
    self.distributions = [self.prior, self.decoder, self.encoder]

    # Loss class
    self.ce = pxl.CrossEntropy(self.encoder, self.decoder)
    _kl = pxl.KullbackLeibler(self.encoder, self.prior)
    _beta = pxl.Parameter("beta")
    _c = pxl.Parameter("c")
    self.kl = _beta * (_kl - _c).abs()
    self.dip = DipLoss(self.encoder, lmd_od, lmd_d, dip_type)
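# Loss-assembly sketch for the DIP-VAE __init__ above, an assumption based on
# the attribute names: `model` is a hypothetical instance and `x` a batch.
# pxl.Parameter placeholders ("beta", "c") read their values from the dict
# passed to eval().
def dip_vae_loss_sketch(model, x):
    loss_cls = (model.ce + model.kl + model.dip).mean()
    return loss_cls.eval({"x": x,
                          "beta": model._beta_value,
                          "c": model._c_value})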
def __init__(self, channel_num: int, z_dim: int, alpha: float, beta: float,
             gamma: float, **kwargs):
    super().__init__()

    # Parameters
    self.channel_num = channel_num
    self.z_dim = z_dim
    self._alpha_value = alpha
    self._beta_value = beta
    self._gamma_value = gamma

    # Distributions
    self.prior = pxd.Normal(
        loc=torch.zeros(z_dim), scale=torch.ones(z_dim), var=["z"])
    self.decoder = Decoder(channel_num, z_dim)
    self.encoder = Encoder(channel_num, z_dim)
    self.distributions = [self.prior, self.decoder, self.encoder]

    # Loss class
    self.ce = pxl.CrossEntropy(self.encoder, self.decoder)
    self.kl = pxl.KullbackLeibler(self.encoder, self.prior)
    self.alpha = pxl.Parameter("alpha")
    self.beta = pxl.Parameter("beta")
    self.gamma = pxl.Parameter("gamma")
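# Evaluation sketch: each pxl.Parameter("...") above is a placeholder whose
# value is supplied through the dict passed to eval()/train(). The exact way
# this class combines ce, kl, and the alpha/beta/gamma weights is not shown
# here, so the beta-weighted ELBO below is only an illustrative assumption.
def weighted_elbo_sketch(model, x):
    loss_cls = (model.ce + model.beta * model.kl).mean()
    return loss_cls.eval({"x": x,
                          "alpha": model._alpha_value,
                          "beta": model._beta_value,
                          "gamma": model._gamma_value})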
def test_diploss_i(self):
    p = pxd.Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
                   features_shape=[2])
    loss_cls = dvl.DipLoss(p, 10, 10, dip_type="i")

    # Check symbol
    print(loss_cls)

    # Evaluate
    self.assertGreaterEqual(loss_cls.eval(), 0)
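# A companion test for the type-II variant, mirroring test_diploss_i above
# (assumes dvl.DipLoss also accepts dip_type="ii", as in DIP-VAE-II).
def test_diploss_ii(self):
    p = pxd.Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
                   features_shape=[2])
    loss_cls = dvl.DipLoss(p, 10, 10, dip_type="ii")
    self.assertGreaterEqual(loss_cls.eval(), 0)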
def __init__(self, x_dim, z_dim, n_dim):
    self.x_dim = x_dim
    self.z_dim = z_dim
    self.n_dim = n_dim

    # Generative model
    self.prior = pxd.Normal(var=["z"], features_shape=[z_dim],
                            loc=torch.tensor(0.), scale=torch.tensor(1.))
    self.generator = Generator(z_dim, x_dim)
    self.posterior = Posterior(x_dim, self.generator)

    # Variational model
    self.vposterior = VarationalPosterior(z_dim, n_dim)
def __init__(self, x_dim, z_dim, h_dim):
    # Generative model
    self.prior = pxd.Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
                            var=["z"], features_shape=[z_dim])
    self.decoder = Generator(z_dim, h_dim, x_dim)

    # Variational model
    self.encoder = Inference(x_dim, h_dim, z_dim)

    # Loss
    ce = pxl.CrossEntropy(self.encoder, self.decoder)
    kl = pxl.KullbackLeibler(self.encoder, self.prior)
    loss = (ce + kl).mean()

    # Init
    super().__init__(loss, distributions=[self.encoder, self.decoder])
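# Usage sketch (assumes the class above subclasses pixyz.models.Model, whose
# train() takes a dict keyed by the distributions' variable names; the class
# name `VAE` and the dimensions are illustrative, not from the source).
def vae_usage_sketch(train_loader):
    model = VAE(x_dim=784, z_dim=64, h_dim=512)
    for x, _ in train_loader:
        loss = model.train({"x": x.view(x.size(0), -1)})
    return loss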
def __init__(self, channel_num: int, z_dim: int, c_dim: int,
             temperature: float, gamma_z: float, gamma_c: float,
             cap_z: float, cap_c: float, **kwargs):
    super().__init__()

    self.channel_num = channel_num
    self.z_dim = z_dim
    self.c_dim = c_dim
    self._gamma_z_value = gamma_z
    self._gamma_c_value = gamma_c
    self._cap_z_value = cap_z
    self._cap_c_value = cap_c

    # Distributions
    self.prior_z = pxd.Normal(
        loc=torch.zeros(z_dim), scale=torch.ones(z_dim), var=["z"])
    self.prior_c = pxd.Categorical(
        probs=torch.ones(c_dim, dtype=torch.float32) / c_dim, var=["c"])
    self.encoder_func = EncoderFunction(channel_num)
    self.encoder_z = ContinuousEncoder(z_dim)
    self.encoder_c = DiscreteEncoder(c_dim, temperature)
    self.decoder = JointDecoder(channel_num, z_dim, c_dim)
    self.distributions = [self.prior_z, self.prior_c, self.encoder_func,
                          self.encoder_z, self.encoder_c, self.decoder]

    # Loss
    self.ce = pxl.CrossEntropy(self.encoder_z * self.encoder_c, self.decoder)
    self.kl_z = pxl.KullbackLeibler(self.encoder_z, self.prior_z)
    self.kl_c = CategoricalKullbackLeibler(self.encoder_c, self.prior_c)

    # Coefficient for kl
    self.gamma_z = pxl.Parameter("gamma_z")
    self.gamma_c = pxl.Parameter("gamma_c")

    # Capacity
    self.cap_z = pxl.Parameter("cap_z")
    self.cap_c = pxl.Parameter("cap_c")
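# Sketch of the combined objective the attributes above suggest: each KL term
# is pulled toward a controlled capacity, matching the `beta * (kl - c).abs()`
# pattern used elsewhere in this repo. The exact combination is an assumption;
# `model` and `x` are hypothetical.
def joint_vae_loss_sketch(model, x):
    loss_cls = (model.ce
                + model.gamma_z * (model.kl_z - model.cap_z).abs()
                + model.gamma_c * (model.kl_c - model.cap_c).abs()).mean()
    return loss_cls.eval({"x": x,
                          "gamma_z": model._gamma_z_value,
                          "gamma_c": model._gamma_c_value,
                          "cap_z": model._cap_z_value,
                          "cap_c": model._cap_c_value})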
def __init__(self, channel_num: int, z_dim: int, c_dim: int, beta: float,
             **kwargs):
    super().__init__()

    # Parameters
    self.channel_num = channel_num
    self.z_dim = z_dim
    self.c_dim = c_dim
    self._beta_value = beta

    # Prior
    self.prior_z = pxd.Normal(loc=torch.zeros(z_dim),
                              scale=torch.ones(z_dim), var=["z"])
    self.prior_c = pxd.Categorical(
        probs=torch.ones(c_dim, dtype=torch.float32) / c_dim, var=["c"])

    # Encoder
    self.encoder_func = EncoderFunction(channel_num)
    self.encoder_z = ContinuousEncoder(z_dim)
    self.encoder_c = DiscreteEncoder(c_dim)

    # Decoder
    self.decoder = JointDecoder(channel_num, z_dim, c_dim)

    self.distributions = [
        self.prior_z, self.prior_c, self.encoder_func, self.encoder_z,
        self.encoder_c, self.decoder
    ]

    # Loss
    self.ce = pxl.CrossEntropy(self.encoder_z, self.decoder)
    self.beta = pxl.Parameter("beta")

    # Adversarial loss
    self.disc = Discriminator(z_dim)
    self.adv_js = pxl.AdversarialJensenShannon(self.encoder_z, self.prior_z,
                                               self.disc)
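# Sketch: combining the terms above into an adversarially regularized
# objective (assumed). beta is fed through the eval() dict because
# pxl.Parameter reads its value from the input, and the adversarial JS loss
# updates its own discriminator via train().
def adversarial_loss_sketch(model, x):
    loss_cls = (model.ce + model.beta * model.adv_js).mean()
    loss = loss_cls.eval({"x": x, "beta": model._beta_value})
    d_loss = model.adv_js.train({"x": x})  # discriminator update
    return loss, d_loss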
def main():
    # -------------------------------------------------------------------------
    # 1. Settings
    # -------------------------------------------------------------------------

    # Args
    args = init_args()

    # Settings
    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    torch.manual_seed(args.seed)

    # Tensorboard writer
    writer = tensorboard.SummaryWriter(args.logdir)

    # -------------------------------------------------------------------------
    # 2. Data
    # -------------------------------------------------------------------------

    # Loader
    batch_size = args.batch_size
    train_loader, test_loader = init_dataloader(
        root=args.data_root, cuda=use_cuda, batch_size=batch_size)

    # Sample data
    _x, _ = next(iter(test_loader))
    _x = _x.to(device)

    # Data dimension
    x_dim = _x.size(1)
    image_dim = int(x_dim ** 0.5)

    # Latent dimension
    z_dim = args.z_dim

    # Dummy latent variable
    z_sample = 0.5 * torch.randn(args.plot_num, z_dim).to(device)

    # -------------------------------------------------------------------------
    # 3. Pixyz classes
    # -------------------------------------------------------------------------

    # Distributions
    p = Generator(z_dim, x_dim).to(device)
    q = Inference(x_dim, z_dim).to(device)
    d = Discriminator(z_dim).to(device)
    q_shuffle = InferenceShuffleDim(q).to(device)
    prior = pxd.Normal(
        loc=torch.tensor(0.), scale=torch.tensor(1.), var=["z"],
        features_shape=[z_dim], name="p_prior").to(device)

    # Loss
    reconst = -p.log_prob().expectation(q)  # -E_q(z|x)[log p(x|z)]
    kl = pxl.KullbackLeibler(q, prior)
    tc = pxl.AdversarialKullbackLeibler(q, q_shuffle, d,
                                        optimizer=optim.Adam,
                                        optimizer_params={"lr": 1e-3})
    loss_cls = reconst.mean() + kl.mean() + 10 * tc

    # Model
    model = pxm.Model(loss_cls, distributions=[p, q], optimizer=optim.Adam,
                      optimizer_params={"lr": 1e-3})

    # -------------------------------------------------------------------------
    # 4. Training
    # -------------------------------------------------------------------------

    for epoch in range(1, args.epochs + 1):
        # Training
        train_loss, train_d_loss = data_loop(train_loader, model, tc, device,
                                             train_mode=True)
        test_loss, test_d_loss = data_loop(test_loader, model, tc, device,
                                           train_mode=False)

        # Sample data
        recon = plot_reconstruction(
            _x[:args.plot_recon], q, p, image_dim, image_dim)
        sample = plot_image_from_latent(z_sample, p, image_dim, image_dim)

        # Log
        writer.add_scalar("train_loss", train_loss.item(), epoch)
        writer.add_scalar("test_loss", test_loss.item(), epoch)
        writer.add_scalar("train_d_loss", train_d_loss.item(), epoch)
        writer.add_scalar("test_d_loss", test_d_loss.item(), epoch)
        writer.add_images("image_reconstruction", recon, epoch)
        writer.add_images("image_from_latent", sample, epoch)

    writer.close()
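# Sketch of the data_loop helper used above (the signature is assumed from
# the call sites): pixyz's Model exposes train()/test(), and the adversarial
# loss exposes its own train()/test() pair that update or evaluate the
# discriminator.
def data_loop_sketch(loader, model, adv_loss, device, train_mode=True):
    total_loss = 0
    total_d_loss = 0
    for x, _ in loader:
        x = x.to(device)
        if train_mode:
            total_loss += model.train({"x": x})
            total_d_loss += adv_loss.train({"x": x})
        else:
            total_loss += model.test({"x": x})
            total_d_loss += adv_loss.test({"x": x})
    return total_loss / len(loader), total_d_loss / len(loader)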
def main():
    # -------------------------------------------------------------------------
    # 1. Settings
    # -------------------------------------------------------------------------

    # Args
    args = init_args()

    # Settings
    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    torch.manual_seed(args.seed)

    # Tensorboard writer
    writer = tensorboard.SummaryWriter(args.logdir)

    # -------------------------------------------------------------------------
    # 2. Data
    # -------------------------------------------------------------------------

    # Loader
    train_loader, test_loader = init_dataloader(root=args.data_root,
                                                cuda=use_cuda,
                                                batch_size=args.batch_size)

    # Sample data for comparison
    _x, _ = next(iter(test_loader))
    _x = _x.to(device)

    # Data dimension
    x_dim = _x.shape[1]
    image_dim = int(x_dim ** 0.5)

    # Latent dimension
    z_dim = args.z_dim

    # Latent data for visualization
    z_sample = 0.5 * torch.randn(args.plot_dim, z_dim).to(device)

    # -------------------------------------------------------------------------
    # 3. Model
    # -------------------------------------------------------------------------

    # Distributions
    p = Generator(z_dim, x_dim).to(device)
    q = Inference(x_dim, z_dim).to(device)
    prior = pxd.Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
                       var=["z"], name="p_prior").to(device)

    # Model
    model = pxm.VAE(q, p, regularizer=pxl.KullbackLeibler(q, prior),
                    optimizer=optim.Adam, optimizer_params={"lr": 1e-3})

    # -------------------------------------------------------------------------
    # 4. Training
    # -------------------------------------------------------------------------

    for epoch in range(1, args.epochs + 1):
        train_loss = data_loop(train_loader, model, device, train_mode=True)
        test_loss = data_loop(test_loader, model, device, train_mode=False)

        recon = plot_reconstruction(_x[:args.plot_recon], q, p, image_dim,
                                    image_dim)
        sample = plot_image_from_latent(z_sample, p, image_dim, image_dim)

        writer.add_scalar("train_loss", train_loss.item(), epoch)
        writer.add_scalar("test_loss", test_loss.item(), epoch)
        writer.add_images("image_reconstruction", recon, epoch)
        writer.add_images("image_from_latent", sample, epoch)

    writer.close()
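# Sketch of the plot_reconstruction helper called above (assumed): encode a
# batch with q, decode the sampled latents with p's mean, and stack originals
# next to reconstructions for writer.add_images(). Uses pixyz's standard
# sample() and sample_mean() API.
def plot_reconstruction_sketch(x, q, p, height, width):
    with torch.no_grad():
        z = q.sample({"x": x}, return_all=False)  # {"z": tensor}
        recon = p.sample_mean(z)                  # mean of p(x|z)
    comparison = torch.cat([x.view(-1, 1, height, width),
                            recon.view(-1, 1, height, width)])
    return comparison.cpu()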