# Imports assumed by the snippets in this section (pixyz's distribution,
# loss, and model modules, plus PyTorch):
import torch
from torch import optim
from torch.utils import tensorboard

import pixyz.distributions as pxd
import pixyz.losses as pxl
import pixyz.models as pxm


def __init__(self, channel_num: int, z_dim: int, beta: float, c: float,
             lmd_od: float, lmd_d: float, dip_type: str, **kwargs):
    super().__init__()

    # Parameters
    self.channel_num = channel_num
    self.z_dim = z_dim
    self._beta_value = beta
    self._c_value = c
    self.lmd_od = lmd_od
    self.lmd_d = lmd_d

    # Distributions
    self.prior = pxd.Normal(loc=torch.zeros(z_dim),
                            scale=torch.ones(z_dim), var=["z"])
    self.decoder = Decoder(channel_num, z_dim)
    self.encoder = Encoder(channel_num, z_dim)
    self.distributions = [self.prior, self.decoder, self.encoder]

    # Loss class
    self.ce = pxl.CrossEntropy(self.encoder, self.decoder)
    _kl = pxl.KullbackLeibler(self.encoder, self.prior)
    _beta = pxl.Parameter("beta")
    _c = pxl.Parameter("c")
    self.kl = _beta * (_kl - _c).abs()
    self.dip = DipLoss(self.encoder, lmd_od, lmd_d, dip_type)
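# Hypothetical usage sketch for the constructor above. The class name
# `DIPVAE`, the argument values, and the way the three loss terms are summed
# are assumptions for illustration; the pixyz Parameters ("beta", "c") are
# bound when the loss is evaluated.
model = DIPVAE(channel_num=1, z_dim=10, beta=4.0, c=0.0,
               lmd_od=10.0, lmd_d=100.0, dip_type="i")
total_loss = (model.ce + model.kl + model.dip).mean()
value = total_loss.eval({"x": x_batch,  # x_batch: a placeholder data batch
                         "beta": model._beta_value,
                         "c": model._c_value})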
def __init__(self, channel_num: int, z_dim: int, e_dim: int, beta: float,
             **kwargs):
    super().__init__()

    # Parameters
    self.channel_num = channel_num
    self.z_dim = z_dim
    self.e_dim = e_dim
    self._beta_val = beta

    # Distributions
    # Noise distribution for the adversarial encoder; sized e_dim to match
    # the noise input expected by AVBEncoder
    self.normal = pxd.Normal(loc=torch.zeros(e_dim),
                             scale=torch.ones(e_dim), var=["e"])
    self.prior = pxd.Normal(loc=torch.zeros(z_dim),
                            scale=torch.ones(z_dim), var=["z"])
    self.decoder = Decoder(channel_num, z_dim)
    self.encoder = AVBEncoder(channel_num, z_dim, e_dim)
    self.distributions = [
        self.normal, self.prior, self.decoder, self.encoder
    ]

    # Loss
    self.ce = pxl.CrossEntropy(self.encoder, self.decoder)

    # Adversarial loss
    self.disc = AVBDiscriminator(channel_num, z_dim)
    self.adv_js = pxl.AdversarialJensenShannon(self.encoder, self.prior,
                                               self.disc)
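# Hypothetical two-step update for the AVB model above (`model` and
# `x_batch` are placeholders). pixyz's AdversarialJensenShannon owns its own
# discriminator optimizer: evaluating the loss gives the generator-side
# objective, while .train() updates the discriminator.
avb_loss = (model.ce + model.adv_js).mean()
gen_value = avb_loss.eval({"x": x_batch})        # encoder/decoder objective
disc_value = model.adv_js.train({"x": x_batch})  # discriminator update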
def __init__(self, channel_num: int, z_dim: int, alpha: float, beta: float,
             gamma: float, **kwargs):
    super().__init__()

    # Parameters
    self.channel_num = channel_num
    self.z_dim = z_dim
    self._alpha_value = alpha
    self._beta_value = beta
    self._gamma_value = gamma

    # Distributions
    self.prior = pxd.Normal(loc=torch.zeros(z_dim),
                            scale=torch.ones(z_dim), var=["z"])
    self.decoder = Decoder(channel_num, z_dim)
    self.encoder = Encoder(channel_num, z_dim)
    self.distributions = [self.prior, self.decoder, self.encoder]

    # Loss class
    self.ce = pxl.CrossEntropy(self.encoder, self.decoder)
    self.kl = pxl.KullbackLeibler(self.encoder, self.prior)
    self.alpha = pxl.Parameter("alpha")
    self.beta = pxl.Parameter("beta")
    self.gamma = pxl.Parameter("gamma")
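# Sketch of how the Parameter placeholders above are bound at evaluation
# time (only beta is exercised here; the full alpha/beta/gamma-weighted
# objective is assembled outside this constructor):
loss = (model.ce + model.beta * model.kl).mean()
value = loss.eval({"x": x_batch, "beta": model._beta_value})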
def load_dmm_model(x_dim, t_max, device, args):

    # Latent dimensions
    h_dim = args.h_dim
    hidden_dim = args.hidden_dim
    z_dim = args.z_dim

    # Distributions
    prior = Prior(z_dim, hidden_dim).to(device)
    decoder = Generator(z_dim, hidden_dim, x_dim).to(device)
    encoder = Inference(z_dim, h_dim).to(device)
    rnn = RNN(x_dim, h_dim).to(device)

    # Sampler
    generate_from_prior = prior * decoder

    # Loss
    ce = pxl.CrossEntropy(encoder, decoder)
    kl = pxl.KullbackLeibler(encoder, prior)
    step_loss = ce + kl
    _loss = pxl.IterativeLoss(step_loss, max_iter=t_max,
                              series_var=["x", "h"],
                              update_value={"z": "z_prev"})
    loss = _loss.expectation(rnn).mean()

    # Model
    dmm = pxm.Model(loss, distributions=[rnn, encoder, decoder, prior],
                    optimizer=optim.Adam,
                    optimizer_params={"lr": args.learning_rate,
                                      "betas": (args.beta1, args.beta2),
                                      "weight_decay": args.weight_decay},
                    clip_grad_norm=args.clip_grad_norm)

    return dmm, generate_from_prior, decoder
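# Hypothetical call site for load_dmm_model. The transpose assumes batches
# arrive as (batch, t_max, x_dim) and that IterativeLoss consumes time-major
# sequences; `batch` is a placeholder tensor.
dmm, generate_from_prior, decoder = load_dmm_model(x_dim, t_max, device, args)
x = batch.transpose(0, 1).to(device)                  # (t_max, batch, x_dim)
z_prev = torch.zeros(x.size(1), args.z_dim, device=device)
train_loss = dmm.train({"x": x, "z_prev": z_prev})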
def __init__(self, x_dim, z_dim, h_dim):

    # Generative model
    self.prior = pxd.Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
                            var=["z"], features_shape=[z_dim])
    self.decoder = Generator(z_dim, h_dim, x_dim)

    # Variational model
    self.encoder = Inference(x_dim, h_dim, z_dim)

    # Loss
    ce = pxl.CrossEntropy(self.encoder, self.decoder)
    kl = pxl.KullbackLeibler(self.encoder, self.prior)
    loss = (ce + kl).mean()

    # Init
    super().__init__(loss, distributions=[self.encoder, self.decoder])
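# Minimal usage sketch, assuming the class above subclasses pixyz's
# pxm.Model (the final super().__init__ call suggests it). The class name
# `VAE`, the dimensions, and the flattening step are illustrative.
model = VAE(x_dim=784, z_dim=64, h_dim=512)
for x, _ in train_loader:                    # train_loader: a placeholder
    loss = model.train({"x": x.view(x.size(0), -1)})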
def __init__(self, channel_num: int, z_dim: int, c_dim: int,
             temperature: float, gamma_z: float, gamma_c: float,
             cap_z: float, cap_c: float, **kwargs):
    super().__init__()

    # Parameters
    self.channel_num = channel_num
    self.z_dim = z_dim
    self.c_dim = c_dim
    self._gamma_z_value = gamma_z
    self._gamma_c_value = gamma_c
    self._cap_z_value = cap_z
    self._cap_c_value = cap_c

    # Distributions
    self.prior_z = pxd.Normal(loc=torch.zeros(z_dim),
                              scale=torch.ones(z_dim), var=["z"])
    self.prior_c = pxd.Categorical(
        probs=torch.ones(c_dim, dtype=torch.float32) / c_dim, var=["c"])
    self.encoder_func = EncoderFunction(channel_num)
    self.encoder_z = ContinuousEncoder(z_dim)
    self.encoder_c = DiscreteEncoder(c_dim, temperature)
    self.decoder = JointDecoder(channel_num, z_dim, c_dim)
    self.distributions = [self.prior_z, self.prior_c, self.encoder_func,
                          self.encoder_z, self.encoder_c, self.decoder]

    # Loss
    self.ce = pxl.CrossEntropy(self.encoder_z * self.encoder_c,
                               self.decoder)
    self.kl_z = pxl.KullbackLeibler(self.encoder_z, self.prior_z)
    self.kl_c = CategoricalKullbackLeibler(self.encoder_c, self.prior_c)

    # Coefficients for KL
    self.gamma_z = pxl.Parameter("gamma_z")
    self.gamma_c = pxl.Parameter("gamma_c")

    # Capacity
    self.cap_z = pxl.Parameter("cap_z")
    self.cap_c = pxl.Parameter("cap_c")
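# Hypothetical assembly of a capacity-controlled objective for the joint
# model above, mirroring the gamma * |KL - C| form used in the first
# constructor in this section. Taking the expectation over encoder_func is
# an assumption, since both encoders condition on its features.
kl_term = (model.gamma_z * (model.kl_z - model.cap_z).abs()
           + model.gamma_c * (model.kl_c - model.cap_c).abs())
total = (model.ce + kl_term).expectation(model.encoder_func).mean()
value = total.eval({"x": x_batch,
                    "gamma_z": model._gamma_z_value,
                    "gamma_c": model._gamma_c_value,
                    "cap_z": model._cap_z_value,
                    "cap_c": model._cap_c_value})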
def __init__(self, channel_num: int, z_dim: int, c_dim: int, beta: float,
             **kwargs):
    super().__init__()

    # Parameters
    self.channel_num = channel_num
    self.z_dim = z_dim
    self.c_dim = c_dim
    self._beta_value = beta

    # Prior
    self.prior_z = pxd.Normal(loc=torch.zeros(z_dim),
                              scale=torch.ones(z_dim), var=["z"])
    self.prior_c = pxd.Categorical(
        probs=torch.ones(c_dim, dtype=torch.float32) / c_dim, var=["c"])

    # Encoder
    self.encoder_func = EncoderFunction(channel_num)
    self.encoder_z = ContinuousEncoder(z_dim)
    self.encoder_c = DiscreteEncoder(c_dim)

    # Decoder
    self.decoder = JointDecoder(channel_num, z_dim, c_dim)

    self.distributions = [
        self.prior_z, self.prior_c, self.encoder_func,
        self.encoder_z, self.encoder_c, self.decoder
    ]

    # Loss
    self.ce = pxl.CrossEntropy(self.encoder_z, self.decoder)
    self.beta = pxl.Parameter("beta")

    # Adversarial loss
    self.disc = Discriminator(z_dim)
    self.adv_js = pxl.AdversarialJensenShannon(self.encoder_z, self.prior_z,
                                               self.disc)
def main():
    # -------------------------------------------------------------------------
    # 1. Settings
    # -------------------------------------------------------------------------

    # Args
    args = init_args()

    # Settings
    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    torch.manual_seed(args.seed)

    # Tensorboard writer
    writer = tensorboard.SummaryWriter(args.logdir)

    # -------------------------------------------------------------------------
    # 2. Data
    # -------------------------------------------------------------------------

    # Loader
    batch_size = args.batch_size
    train_loader, test_loader = init_dataloader(root=args.data_root,
                                                cuda=use_cuda,
                                                batch_size=batch_size)

    # Data dimension
    x_dim = train_loader.dataset.data.shape[1]
    t_max = train_loader.dataset.data.shape[2]

    # -------------------------------------------------------------------------
    # 3. Model
    # -------------------------------------------------------------------------

    # Latent dimension
    h_dim = args.h_dim
    hidden_dim = args.hidden_dim
    z_dim = args.z_dim

    # Distributions
    prior = Prior(z_dim, hidden_dim).to(device)
    decoder = Generator(z_dim, hidden_dim, x_dim).to(device)
    encoder = Inference(z_dim, h_dim).to(device)
    rnn = RNN(x_dim, h_dim).to(device)

    # Sampler
    generate_from_prior = prior * decoder

    # Loss
    ce = pxl.CrossEntropy(encoder, decoder)
    kl = pxl.KullbackLeibler(encoder, prior)
    _loss = pxl.IterativeLoss(ce + kl, max_iter=t_max,
                              series_var=["x", "h"],
                              update_value={"z": "z_prev"})
    loss = _loss.expectation(rnn).mean()

    # Model
    model = pxm.Model(loss, distributions=[rnn, encoder, decoder, prior],
                      optimizer=optim.Adam,
                      optimizer_params={"lr": 1e-3},
                      clip_grad_value=10)

    # -------------------------------------------------------------------------
    # 4. Training
    # -------------------------------------------------------------------------

    for epoch in range(1, args.epochs + 1):
        # Training
        train_loss = data_loop(train_loader, model, z_dim, device,
                               train_mode=True)
        test_loss = data_loop(test_loader, model, z_dim, device,
                              train_mode=False)

        # Sample data
        sample = plot_image_from_latent(generate_from_prior, decoder,
                                        batch_size, z_dim, t_max, device)

        # Log
        writer.add_scalar("train_loss", train_loss.item(), epoch)
        writer.add_scalar("test_loss", test_loss.item(), epoch)
        writer.add_images("image_from_latent", sample, epoch)

    writer.close()
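# A minimal sketch of the `data_loop` helper referenced above (hypothetical;
# assumes (data, label) batches shaped (batch, t_max, x_dim) and pixyz's
# Model.train/Model.test API):
def data_loop(loader, model, z_dim, device, train_mode=True):
    total_loss = 0.0
    for data, _ in loader:
        x = data.transpose(0, 1).to(device)   # time-major: (t_max, batch, x_dim)
        z_prev = torch.zeros(x.size(1), z_dim, device=device)
        if train_mode:
            total_loss += model.train({"x": x, "z_prev": z_prev})
        else:
            total_loss += model.test({"x": x, "z_prev": z_prev})
    return total_loss / len(loader)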