Ejemplo n.º 1
0
    def __init__(self, channel_num: int, z_dim: int, beta: float, c: float,
                 lmd_od: float, lmd_d: float, dip_type: str, **kwargs):
        """Store hyper-parameters and build the distributions and loss terms.

        Args:
            channel_num: Forwarded to ``Decoder`` and ``Encoder``.
            z_dim: Length of the ``loc``/``scale`` tensors of the prior over
                ``z``; also forwarded to ``Decoder`` and ``Encoder``.
            beta: Kept as ``_beta_value``. The KL term built here only
                references the symbolic ``pxl.Parameter("beta")``.
            c: Kept as ``_c_value``. The KL term built here only references
                the symbolic ``pxl.Parameter("c")``.
            lmd_od: Stored on the instance and forwarded to ``DipLoss``.
            lmd_d: Stored on the instance and forwarded to ``DipLoss``.
            dip_type: Forwarded to ``DipLoss``; not stored on the instance.
            **kwargs: Accepted but not used by this constructor.
        """
        super().__init__()

        # Hyper-parameters
        self.channel_num, self.z_dim = channel_num, z_dim
        self._beta_value, self._c_value = beta, c
        self.lmd_od, self.lmd_d = lmd_od, lmd_d

        # Prior over "z": normal with zero loc and unit scale
        zero_loc = torch.zeros(z_dim)
        unit_scale = torch.ones(z_dim)
        self.prior = pxd.Normal(loc=zero_loc, scale=unit_scale, var=["z"])

        # Decoder / encoder pair
        self.decoder = Decoder(channel_num, z_dim)
        self.encoder = Encoder(channel_num, z_dim)
        self.distributions = [self.prior, self.decoder, self.encoder]

        # Cross-entropy term between encoder and decoder
        self.ce = pxl.CrossEntropy(self.encoder, self.decoder)

        # KL term: beta * |KL(encoder, prior) - c|, with symbolic beta and c
        kl_term = pxl.KullbackLeibler(self.encoder, self.prior)
        beta_param = pxl.Parameter("beta")
        c_param = pxl.Parameter("c")
        kl_gap = (kl_term - c_param).abs()
        self.kl = beta_param * kl_gap

        # DIP term on the encoder
        self.dip = DipLoss(self.encoder, lmd_od, lmd_d, dip_type)
Ejemplo n.º 2
0
    def __init__(self, channel_num: int, z_dim: int, alpha: float, beta: float,
                 gamma: float, **kwargs):
        """Store hyper-parameters and build the distributions and loss terms.

        Args:
            channel_num: Forwarded to ``Decoder`` and ``Encoder``.
            z_dim: Length of the ``loc``/``scale`` tensors of the prior over
                ``z``; also forwarded to ``Decoder`` and ``Encoder``.
            alpha: Kept as ``_alpha_value``; ``self.alpha`` is the symbolic
                ``pxl.Parameter("alpha")``.
            beta: Kept as ``_beta_value``; ``self.beta`` is the symbolic
                ``pxl.Parameter("beta")``.
            gamma: Kept as ``_gamma_value``; ``self.gamma`` is the symbolic
                ``pxl.Parameter("gamma")``.
            **kwargs: Accepted but not used by this constructor.
        """
        super().__init__()

        # Hyper-parameters
        self.channel_num, self.z_dim = channel_num, z_dim
        self._alpha_value, self._beta_value, self._gamma_value = (
            alpha, beta, gamma)

        # Prior over "z": normal with zero loc and unit scale
        zero_loc = torch.zeros(z_dim)
        unit_scale = torch.ones(z_dim)
        self.prior = pxd.Normal(loc=zero_loc, scale=unit_scale, var=["z"])

        # Decoder / encoder pair
        self.decoder = Decoder(channel_num, z_dim)
        self.encoder = Encoder(channel_num, z_dim)
        self.distributions = [self.prior, self.decoder, self.encoder]

        # Loss terms
        self.ce = pxl.CrossEntropy(self.encoder, self.decoder)
        self.kl = pxl.KullbackLeibler(self.encoder, self.prior)

        # Symbolic coefficients
        self.alpha = pxl.Parameter("alpha")
        self.beta = pxl.Parameter("beta")
        self.gamma = pxl.Parameter("gamma")
Ejemplo n.º 3
0
    def __init__(self, channel_num: int, z_dim: int, c_dim: int,
                 temperature: float, gamma_z: float, gamma_c: float,
                 cap_z: float, cap_c: float, **kwargs):
        """Store hyper-parameters and build the distributions and loss terms.

        Args:
            channel_num: Forwarded to ``EncoderFunction`` and
                ``JointDecoder``.
            z_dim: Length of the ``loc``/``scale`` tensors of the prior over
                ``z``; also forwarded to ``ContinuousEncoder`` and
                ``JointDecoder``.
            c_dim: Number of entries in the uniform ``probs`` of the prior
                over ``c``; also forwarded to ``DiscreteEncoder`` and
                ``JointDecoder``.
            temperature: Forwarded to ``DiscreteEncoder``; not stored on the
                instance.
            gamma_z: Kept as ``_gamma_z_value``; ``self.gamma_z`` is the
                symbolic ``pxl.Parameter("gamma_z")``.
            gamma_c: Kept as ``_gamma_c_value``; ``self.gamma_c`` is the
                symbolic ``pxl.Parameter("gamma_c")``.
            cap_z: Kept as ``_cap_z_value``; ``self.cap_z`` is the symbolic
                ``pxl.Parameter("cap_z")``.
            cap_c: Kept as ``_cap_c_value``; ``self.cap_c`` is the symbolic
                ``pxl.Parameter("cap_c")``.
            **kwargs: Accepted but not used by this constructor.
        """
        super().__init__()

        # Hyper-parameters
        self.channel_num, self.z_dim, self.c_dim = channel_num, z_dim, c_dim
        self._gamma_z_value, self._gamma_c_value = gamma_z, gamma_c
        self._cap_z_value, self._cap_c_value = cap_z, cap_c

        # Prior over "z": normal with zero loc and unit scale
        zero_loc = torch.zeros(z_dim)
        unit_scale = torch.ones(z_dim)
        self.prior_z = pxd.Normal(loc=zero_loc, scale=unit_scale, var=["z"])

        # Prior over "c": categorical with equal probability per entry
        uniform_probs = torch.ones(c_dim, dtype=torch.float32) / c_dim
        self.prior_c = pxd.Categorical(probs=uniform_probs, var=["c"])

        # Encoders and decoder
        self.encoder_func = EncoderFunction(channel_num)
        self.encoder_z = ContinuousEncoder(z_dim)
        self.encoder_c = DiscreteEncoder(c_dim, temperature)
        self.decoder = JointDecoder(channel_num, z_dim, c_dim)

        self.distributions = [
            self.prior_z,
            self.prior_c,
            self.encoder_func,
            self.encoder_z,
            self.encoder_c,
            self.decoder,
        ]

        # Cross-entropy uses the product of both encoders
        joint_encoder = self.encoder_z * self.encoder_c
        self.ce = pxl.CrossEntropy(joint_encoder, self.decoder)

        # One KL term per latent
        self.kl_z = pxl.KullbackLeibler(self.encoder_z, self.prior_z)
        self.kl_c = CategoricalKullbackLeibler(self.encoder_c, self.prior_c)

        # Symbolic KL coefficients
        self.gamma_z = pxl.Parameter("gamma_z")
        self.gamma_c = pxl.Parameter("gamma_c")

        # Symbolic capacities
        self.cap_z = pxl.Parameter("cap_z")
        self.cap_c = pxl.Parameter("cap_c")
Ejemplo n.º 4
0
    def __init__(self, channel_num: int, z_dim: int, c_dim: int, beta: float,
                 **kwargs):
        """Store hyper-parameters and build the distributions and loss terms.

        Args:
            channel_num: Forwarded to ``EncoderFunction`` and
                ``JointDecoder``.
            z_dim: Length of the ``loc``/``scale`` tensors of the prior over
                ``z``; also forwarded to ``ContinuousEncoder``,
                ``JointDecoder`` and ``Discriminator``.
            c_dim: Number of entries in the uniform ``probs`` of the prior
                over ``c``; also forwarded to ``DiscreteEncoder`` and
                ``JointDecoder``.
            beta: Kept as ``_beta_value``; ``self.beta`` is the symbolic
                ``pxl.Parameter("beta")``.
            **kwargs: Accepted but not used by this constructor.
        """
        super().__init__()

        # Hyper-parameters
        self.channel_num, self.z_dim, self.c_dim = channel_num, z_dim, c_dim
        self._beta_value = beta

        # Prior over "z": normal with zero loc and unit scale
        zero_loc = torch.zeros(z_dim)
        unit_scale = torch.ones(z_dim)
        self.prior_z = pxd.Normal(loc=zero_loc, scale=unit_scale, var=["z"])

        # Prior over "c": categorical with equal probability per entry
        uniform_probs = torch.ones(c_dim, dtype=torch.float32) / c_dim
        self.prior_c = pxd.Categorical(probs=uniform_probs, var=["c"])

        # Encoders
        self.encoder_func = EncoderFunction(channel_num)
        self.encoder_z = ContinuousEncoder(z_dim)
        self.encoder_c = DiscreteEncoder(c_dim)

        # Decoder
        self.decoder = JointDecoder(channel_num, z_dim, c_dim)

        self.distributions = [
            self.prior_z,
            self.prior_c,
            self.encoder_func,
            self.encoder_z,
            self.encoder_c,
            self.decoder,
        ]

        # Cross-entropy term (built from encoder_z only) and its coefficient
        self.ce = pxl.CrossEntropy(self.encoder_z, self.decoder)
        self.beta = pxl.Parameter("beta")

        # Adversarial Jensen-Shannon term between encoder_z and prior_z
        self.disc = Discriminator(z_dim)
        self.adv_js = pxl.AdversarialJensenShannon(
            self.encoder_z, self.prior_z, self.disc)