def __init__(self, channel_num: int, z_dim: int, beta: float, c: float,
             lmd_od: float, lmd_d: float, dip_type: str, **kwargs):
    """Build distributions and loss terms for the capacity-/DIP-regularized model.

    Args:
        channel_num: Number of image channels for the encoder/decoder.
        z_dim: Dimensionality of the latent variable ``z``.
        beta: Weight applied to the capacity-constrained KL term.
        c: Capacity value subtracted from the KL before taking ``abs()``.
        lmd_od: Off-diagonal weight forwarded to ``DipLoss``.
        lmd_d: Diagonal weight forwarded to ``DipLoss``.
        dip_type: DIP variant selector forwarded to ``DipLoss``.
        **kwargs: Ignored; accepted for configuration-dict compatibility.
    """
    super().__init__()

    # Hyper-parameters
    self.channel_num = channel_num
    self.z_dim = z_dim
    self._beta_value = beta
    self._c_value = c
    self.lmd_od = lmd_od
    self.lmd_d = lmd_d

    # Distributions: standard-normal prior p(z), decoder p(x|z), encoder q(z|x)
    self.prior = pxd.Normal(
        loc=torch.zeros(z_dim), scale=torch.ones(z_dim), var=["z"])
    self.decoder = Decoder(channel_num, z_dim)
    self.encoder = Encoder(channel_num, z_dim)
    self.distributions = [self.prior, self.decoder, self.encoder]

    # Loss terms: reconstruction cross-entropy, capacity-constrained KL
    # of the form beta * |KL(q(z|x) || p(z)) - C|, plus the DIP regularizer.
    self.ce = pxl.CrossEntropy(self.encoder, self.decoder)
    divergence = pxl.KullbackLeibler(self.encoder, self.prior)
    beta_param = pxl.Parameter("beta")
    capacity_param = pxl.Parameter("c")
    self.kl = beta_param * (divergence - capacity_param).abs()
    self.dip = DipLoss(self.encoder, lmd_od, lmd_d, dip_type)
def __init__(self, channel_num: int, z_dim: int, alpha: float, beta: float,
             gamma: float, **kwargs):
    """Build distributions and loss terms weighted by alpha/beta/gamma.

    Args:
        channel_num: Number of image channels for the encoder/decoder.
        z_dim: Dimensionality of the latent variable ``z``.
        alpha: Coefficient exposed as the pixyz parameter ``"alpha"``.
        beta: Coefficient exposed as the pixyz parameter ``"beta"``.
        gamma: Coefficient exposed as the pixyz parameter ``"gamma"``.
        **kwargs: Ignored; accepted for configuration-dict compatibility.
    """
    super().__init__()

    # Hyper-parameters
    self.channel_num = channel_num
    self.z_dim = z_dim
    self._alpha_value = alpha
    self._beta_value = beta
    self._gamma_value = gamma

    # Distributions: standard-normal prior p(z), decoder p(x|z), encoder q(z|x)
    self.prior = pxd.Normal(
        loc=torch.zeros(z_dim), scale=torch.ones(z_dim), var=["z"])
    self.decoder = Decoder(channel_num, z_dim)
    self.encoder = Encoder(channel_num, z_dim)
    self.distributions = [self.prior, self.decoder, self.encoder]

    # Loss terms: reconstruction cross-entropy and the KL divergence,
    # with the scalar coefficients wrapped as pixyz Parameters so the
    # training loop can feed their values at loss-evaluation time.
    self.ce = pxl.CrossEntropy(self.encoder, self.decoder)
    self.kl = pxl.KullbackLeibler(self.encoder, self.prior)
    self.alpha = pxl.Parameter("alpha")
    self.beta = pxl.Parameter("beta")
    self.gamma = pxl.Parameter("gamma")
def __init__(self, channel_num: int, z_dim: int, c_dim: int,
             temperature: float, gamma_z: float, gamma_c: float,
             cap_z: float, cap_c: float, **kwargs):
    """Build the joint continuous (z) + discrete (c) latent model.

    Args:
        channel_num: Number of image channels for the encoder/decoder.
        z_dim: Dimensionality of the continuous latent ``z``.
        c_dim: Number of categories of the discrete latent ``c``.
        temperature: Relaxation temperature for the discrete encoder.
        gamma_z: Coefficient exposed as the pixyz parameter ``"gamma_z"``.
        gamma_c: Coefficient exposed as the pixyz parameter ``"gamma_c"``.
        cap_z: Capacity exposed as the pixyz parameter ``"cap_z"``.
        cap_c: Capacity exposed as the pixyz parameter ``"cap_c"``.
        **kwargs: Ignored; accepted for configuration-dict compatibility.
    """
    super().__init__()

    # Hyper-parameters
    self.channel_num = channel_num
    self.z_dim = z_dim
    self.c_dim = c_dim
    self._gamma_z_value = gamma_z
    self._gamma_c_value = gamma_c
    self._cap_z_value = cap_z
    self._cap_c_value = cap_c

    # Priors: standard normal over z, uniform categorical over c.
    self.prior_z = pxd.Normal(
        loc=torch.zeros(z_dim), scale=torch.ones(z_dim), var=["z"])
    self.prior_c = pxd.Categorical(
        probs=torch.ones(c_dim, dtype=torch.float32) / c_dim, var=["c"])

    # Encoders (shared feature extractor + per-latent heads) and joint decoder.
    self.encoder_func = EncoderFunction(channel_num)
    self.encoder_z = ContinuousEncoder(z_dim)
    self.encoder_c = DiscreteEncoder(c_dim, temperature)
    self.decoder = JointDecoder(channel_num, z_dim, c_dim)
    self.distributions = [
        self.prior_z, self.prior_c, self.encoder_func,
        self.encoder_z, self.encoder_c, self.decoder,
    ]

    # Loss terms: reconstruction cross-entropy under the product posterior
    # q(z|.) q(c|.), plus one KL term per latent.
    self.ce = pxl.CrossEntropy(
        self.encoder_z * self.encoder_c, self.decoder)
    self.kl_z = pxl.KullbackLeibler(self.encoder_z, self.prior_z)
    self.kl_c = CategoricalKullbackLeibler(self.encoder_c, self.prior_c)

    # KL coefficients, fed at loss-evaluation time.
    self.gamma_z = pxl.Parameter("gamma_z")
    self.gamma_c = pxl.Parameter("gamma_c")

    # Capacity targets, fed at loss-evaluation time.
    self.cap_z = pxl.Parameter("cap_z")
    self.cap_c = pxl.Parameter("cap_c")
def __init__(self, channel_num: int, z_dim: int, c_dim: int, beta: float,
             **kwargs):
    """Build the joint-latent model with an adversarial JS divergence term.

    Args:
        channel_num: Number of image channels for the encoder/decoder.
        z_dim: Dimensionality of the continuous latent ``z``.
        c_dim: Number of categories of the discrete latent ``c``.
        beta: Coefficient exposed as the pixyz parameter ``"beta"``.
        **kwargs: Ignored; accepted for configuration-dict compatibility.
    """
    super().__init__()

    # Hyper-parameters
    self.channel_num = channel_num
    self.z_dim = z_dim
    self.c_dim = c_dim
    self._beta_value = beta

    # Priors: standard normal over z, uniform categorical over c.
    self.prior_z = pxd.Normal(
        loc=torch.zeros(z_dim), scale=torch.ones(z_dim), var=["z"])
    self.prior_c = pxd.Categorical(
        probs=torch.ones(c_dim, dtype=torch.float32) / c_dim, var=["c"])

    # Encoders (shared feature extractor + per-latent heads) and joint decoder.
    self.encoder_func = EncoderFunction(channel_num)
    self.encoder_z = ContinuousEncoder(z_dim)
    self.encoder_c = DiscreteEncoder(c_dim)
    self.decoder = JointDecoder(channel_num, z_dim, c_dim)
    self.distributions = [
        self.prior_z, self.prior_c, self.encoder_func,
        self.encoder_z, self.encoder_c, self.decoder,
    ]

    # Reconstruction loss and its coefficient.
    # NOTE(review): the CE pairs encoder_z alone with the *joint* decoder
    # p(x|z, c); presumably the discrete code c is supplied elsewhere in
    # the training step — verify against the loss-evaluation code.
    self.ce = pxl.CrossEntropy(self.encoder_z, self.decoder)
    self.beta = pxl.Parameter("beta")

    # Adversarial Jensen-Shannon term: a discriminator on z distinguishes
    # encoder samples from prior samples.
    self.disc = Discriminator(z_dim)
    self.adv_js = pxl.AdversarialJensenShannon(
        self.encoder_z, self.prior_z, self.disc)