Example #1
0
    def __init__(self, channel_num: int, z_dim: int, beta: float, c: float,
                 lmd_od: float, lmd_d: float, dip_type: str, **kwargs):
        """Build the prior, encoder/decoder pair and all loss terms.

        Args:
            channel_num: Number of image channels, forwarded to ``Encoder``
                and ``Decoder``.
            z_dim: Dimensionality of the latent variable ``z``.
            beta: Raw value stored in ``_beta_value`` for the ``"beta"``
                loss parameter.
            c: Raw value stored in ``_c_value`` for the ``"c"`` loss
                parameter.
            lmd_od: Coefficient stored on the instance and forwarded to
                ``DipLoss``.
            lmd_d: Coefficient stored on the instance and forwarded to
                ``DipLoss``.
            dip_type: Variant selector forwarded to ``DipLoss``.
            **kwargs: Accepted for call-site compatibility; unused here.
        """
        super().__init__()

        # Hyper-parameters. beta and c are kept as plain values; the
        # pxl.Parameter objects below only carry the names "beta" / "c"
        # (presumably bound to these values at loss-evaluation time —
        # TODO confirm against the caller).
        self.channel_num = channel_num
        self.z_dim = z_dim
        self._beta_value = beta
        self._c_value = c
        self.lmd_od = lmd_od
        self.lmd_d = lmd_d

        # N(0, I) prior over z, then the decoder and encoder networks.
        zeros, ones = torch.zeros(z_dim), torch.ones(z_dim)
        self.prior = pxd.Normal(loc=zeros, scale=ones, var=["z"])
        self.decoder = Decoder(channel_num, z_dim)
        self.encoder = Encoder(channel_num, z_dim)
        self.distributions = [self.prior, self.decoder, self.encoder]

        # Reconstruction term.
        self.ce = pxl.CrossEntropy(self.encoder, self.decoder)

        # KL term shaped as beta * |KL(encoder || prior) - c|.
        kl_div = pxl.KullbackLeibler(self.encoder, self.prior)
        beta_coef = pxl.Parameter("beta")
        capacity = pxl.Parameter("c")
        self.kl = beta_coef * (kl_div - capacity).abs()

        # DipLoss regulariser built on the encoder.
        self.dip = DipLoss(self.encoder, lmd_od, lmd_d, dip_type)
Example #2
0
    def __init__(self, channel_num: int, z_dim: int, e_dim: int, beta: float,
                 **kwargs):
        """Build the distributions, reconstruction loss and adversarial loss.

        Args:
            channel_num: Number of image channels, forwarded to the decoder,
                the encoder and the discriminator.
            z_dim: Dimensionality of the latent variable ``z``.
            e_dim: Dimensionality of the auxiliary noise ``e`` that is fed to
                ``AVBEncoder`` (which is constructed with this size).
            beta: Stored as ``_beta_val``; not used inside this constructor.
            **kwargs: Accepted for call-site compatibility; unused here.
        """
        super().__init__()

        self.channel_num = channel_num
        self.z_dim = z_dim
        self.e_dim = e_dim
        self._beta_val = beta

        # Distributions
        # Noise source for the encoder's "e" input. It is sized by e_dim, the
        # noise width AVBEncoder is built with. Sizing it by z_dim only worked
        # by accident when e_dim == z_dim.
        self.normal = pxd.Normal(loc=torch.zeros(e_dim),
                                 scale=torch.ones(e_dim),
                                 var=["e"])
        # N(0, I) prior over the latent z.
        self.prior = pxd.Normal(loc=torch.zeros(z_dim),
                                scale=torch.ones(z_dim),
                                var=["z"])
        self.decoder = Decoder(channel_num, z_dim)
        self.encoder = AVBEncoder(channel_num, z_dim, e_dim)
        self.distributions = [
            self.normal, self.prior, self.decoder, self.encoder
        ]

        # Loss: reconstruction term under the encoder.
        self.ce = pxl.CrossEntropy(self.encoder, self.decoder)

        # Adversarial loss: Jensen-Shannon estimate between the encoder and
        # the prior, using a learned discriminator.
        self.disc = AVBDiscriminator(channel_num, z_dim)
        self.adv_js = pxl.AdversarialJensenShannon(self.encoder, self.prior,
                                                   self.disc)
Example #3
0
    def __init__(self, channel_num: int, z_dim: int, alpha: float, beta: float,
                 gamma: float, **kwargs):
        """Create the prior, encoder/decoder pair and the loss building blocks.

        Args:
            channel_num: Image channel count handed to ``Encoder``/``Decoder``.
            z_dim: Size of the latent vector ``z``.
            alpha: Raw value kept in ``_alpha_value`` for the ``"alpha"``
                parameter.
            beta: Raw value kept in ``_beta_value`` for the ``"beta"``
                parameter.
            gamma: Raw value kept in ``_gamma_value`` for the ``"gamma"``
                parameter.
            **kwargs: Accepted for call-site compatibility; unused here.
        """
        super().__init__()

        # Architecture sizes
        self.channel_num, self.z_dim = channel_num, z_dim

        # Raw loss weights; the pxl.Parameter objects below only carry names.
        self._alpha_value = alpha
        self._beta_value = beta
        self._gamma_value = gamma

        # N(0, I) prior over z, followed by the decoder and encoder.
        unit_loc = torch.zeros(z_dim)
        unit_scale = torch.ones(z_dim)
        self.prior = pxd.Normal(loc=unit_loc, scale=unit_scale, var=["z"])
        self.decoder = Decoder(channel_num, z_dim)
        self.encoder = Encoder(channel_num, z_dim)
        self.distributions = [self.prior, self.decoder, self.encoder]

        # Loss pieces; they are not combined with the weights in this
        # constructor.
        self.ce = pxl.CrossEntropy(self.encoder, self.decoder)
        self.kl = pxl.KullbackLeibler(self.encoder, self.prior)
        self.alpha, self.beta, self.gamma = (
            pxl.Parameter(weight_name)
            for weight_name in ("alpha", "beta", "gamma"))
Example #4
0
def load_dmm_model(x_dim, t_max, device, args):
    """Assemble the DMM networks, its sequence loss and the training model.

    Args:
        x_dim: Dimensionality of one observation ``x``.
        t_max: Number of time steps the iterative loss unrolls over.
        device: Device every network is moved to.
        args: Namespace providing ``h_dim``, ``hidden_dim``, ``z_dim``,
            ``learning_rate``, ``beta1``, ``beta2``, ``weight_decay`` and
            ``clip_grad_norm``.

    Returns:
        ``(dmm, generate_from_prior, decoder)``: the ``pxm.Model``, the
        product distribution ``prior * decoder`` used for sampling, and the
        decoder itself.
    """
    h_dim, hidden_dim, z_dim = args.h_dim, args.hidden_dim, args.z_dim

    # Networks. Construction order is kept fixed because it may determine
    # how the global RNG is consumed during parameter initialisation.
    prior_net = Prior(z_dim, hidden_dim).to(device)
    generator = Generator(z_dim, hidden_dim, x_dim).to(device)
    inference = Inference(z_dim, h_dim).to(device)
    rnn_net = RNN(x_dim, h_dim).to(device)

    # Ancestral sampler: prior followed by the generator.
    sampler = prior_net * generator

    # One-step loss (reconstruction + KL), unrolled over the sequence; at
    # each step "z" is carried forward as "z_prev".
    per_step = (pxl.CrossEntropy(inference, generator)
                + pxl.KullbackLeibler(inference, prior_net))
    sequence_loss = pxl.IterativeLoss(per_step,
                                      max_iter=t_max,
                                      series_var=["x", "h"],
                                      update_value={"z": "z_prev"})
    objective = sequence_loss.expectation(rnn_net).mean()

    adam_params = {
        "lr": args.learning_rate,
        "betas": (args.beta1, args.beta2),
        "weight_decay": args.weight_decay,
    }
    dmm = pxm.Model(objective,
                    distributions=[rnn_net, inference, generator, prior_net],
                    optimizer=optim.Adam,
                    optimizer_params=adam_params,
                    clip_grad_norm=args.clip_grad_norm)

    return dmm, sampler, generator
Example #5
0
    def __init__(self, x_dim, z_dim, h_dim):
        """Build generative and inference networks and pass the loss upward.

        Args:
            x_dim: Dimensionality of the observation ``x``.
            z_dim: Dimensionality of the latent ``z``.
            h_dim: Hidden size passed to both ``Generator`` and ``Inference``.
        """
        # Generative side: prior with scalar loc/scale and
        # features_shape=[z_dim], plus the decoder.
        self.prior = pxd.Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
                                var=["z"], features_shape=[z_dim])
        self.decoder = Generator(z_dim, h_dim, x_dim)

        # Inference side.
        self.encoder = Inference(x_dim, h_dim, z_dim)

        # Objective: CE + KL (the negative ELBO), reduced with .mean().
        objective = (pxl.CrossEntropy(self.encoder, self.decoder)
                     + pxl.KullbackLeibler(self.encoder, self.prior)).mean()

        # Only the encoder and decoder are handed to the base class; the
        # prior is not in this list.
        super().__init__(objective,
                         distributions=[self.encoder, self.decoder])
Example #6
0
    def __init__(self, channel_num: int, z_dim: int, c_dim: int,
                 temperature: float, gamma_z: float, gamma_c: float,
                 cap_z: float, cap_c: float, **kwargs):
        """Build priors, encoders, decoder and loss terms for z and c latents.

        Args:
            channel_num: Image channel count for the encoder body and decoder.
            z_dim: Size of the continuous latent ``z``.
            c_dim: Number of categories of the discrete latent ``c``.
            temperature: Forwarded to ``DiscreteEncoder``.
            gamma_z: Raw value kept in ``_gamma_z_value`` (``"gamma_z"``).
            gamma_c: Raw value kept in ``_gamma_c_value`` (``"gamma_c"``).
            cap_z: Raw value kept in ``_cap_z_value`` (``"cap_z"``).
            cap_c: Raw value kept in ``_cap_c_value`` (``"cap_c"``).
            **kwargs: Accepted for call-site compatibility; unused here.
        """
        super().__init__()

        # Sizes
        self.channel_num = channel_num
        self.z_dim = z_dim
        self.c_dim = c_dim

        # Raw hyper-parameter values for the named loss parameters below
        self._gamma_z_value = gamma_z
        self._gamma_c_value = gamma_c
        self._cap_z_value = cap_z
        self._cap_c_value = cap_c

        # Priors: N(0, I) over z and a uniform categorical over c
        self.prior_z = pxd.Normal(loc=torch.zeros(z_dim),
                                  scale=torch.ones(z_dim),
                                  var=["z"])
        uniform_probs = torch.ones(c_dim, dtype=torch.float32) / c_dim
        self.prior_c = pxd.Categorical(probs=uniform_probs, var=["c"])

        # Encoder body plus one head per latent (presumably both heads consume
        # EncoderFunction's output — TODO confirm), then the joint decoder.
        self.encoder_func = EncoderFunction(channel_num)
        self.encoder_z = ContinuousEncoder(z_dim)
        self.encoder_c = DiscreteEncoder(c_dim, temperature)
        self.decoder = JointDecoder(channel_num, z_dim, c_dim)

        self.distributions = [
            self.prior_z, self.prior_c,
            self.encoder_func, self.encoder_z, self.encoder_c,
            self.decoder,
        ]

        # Reconstruction under the product of both encoder heads
        joint_posterior = self.encoder_z * self.encoder_c
        self.ce = pxl.CrossEntropy(joint_posterior, self.decoder)

        # One KL term per latent
        self.kl_z = pxl.KullbackLeibler(self.encoder_z, self.prior_z)
        self.kl_c = CategoricalKullbackLeibler(self.encoder_c, self.prior_c)

        # Named placeholders: KL weights ...
        self.gamma_z = pxl.Parameter("gamma_z")
        self.gamma_c = pxl.Parameter("gamma_c")
        # ... and capacities
        self.cap_z = pxl.Parameter("cap_z")
        self.cap_c = pxl.Parameter("cap_c")
Example #7
0
    def __init__(self, channel_num: int, z_dim: int, c_dim: int, beta: float,
                 **kwargs):
        """Build priors, encoders, decoder, reconstruction and adversarial loss.

        Args:
            channel_num: Image channel count for the encoder body and decoder.
            z_dim: Size of the continuous latent ``z``.
            c_dim: Number of categories of the discrete latent ``c``.
            beta: Raw value kept in ``_beta_value`` for the ``"beta"``
                parameter.
            **kwargs: Accepted for call-site compatibility; unused here.
        """
        super().__init__()

        # Sizes and raw hyper-parameter
        self.channel_num = channel_num
        self.z_dim = z_dim
        self.c_dim = c_dim
        self._beta_value = beta

        # Priors: N(0, I) over z and a uniform categorical over c
        zeros_z, ones_z = torch.zeros(z_dim), torch.ones(z_dim)
        self.prior_z = pxd.Normal(loc=zeros_z, scale=ones_z, var=["z"])
        self.prior_c = pxd.Categorical(
            probs=torch.ones(c_dim, dtype=torch.float32) / c_dim, var=["c"])

        # Encoder body and the two latent heads
        self.encoder_func = EncoderFunction(channel_num)
        self.encoder_z = ContinuousEncoder(z_dim)
        self.encoder_c = DiscreteEncoder(c_dim)

        # Decoder over both latents
        self.decoder = JointDecoder(channel_num, z_dim, c_dim)

        self.distributions = [
            self.prior_z,
            self.prior_c,
            self.encoder_func,
            self.encoder_z,
            self.encoder_c,
            self.decoder,
        ]

        # Reconstruction term.
        # NOTE(review): the expectation is taken only under encoder_z, while
        # the decoder is a JointDecoder built with c_dim. Confirm that "c" is
        # supplied to the decoder by the caller.
        self.ce = pxl.CrossEntropy(self.encoder_z, self.decoder)
        self.beta = pxl.Parameter("beta")

        # Adversarial Jensen-Shannon term between encoder_z and prior_z
        self.disc = Discriminator(z_dim)
        self.adv_js = pxl.AdversarialJensenShannon(
            self.encoder_z, self.prior_z, self.disc)
Example #8
0
def main():
    """Train and evaluate the sequential latent-variable model with logging.

    Parses settings with ``init_args()``, seeds torch and selects the device,
    and builds the train/test loaders. It then assembles the prior,
    generator, inference and RNN networks with an iterative CE + KL loss.
    Each epoch runs a training pass, an evaluation pass and a sampling step,
    writing the losses and sampled images to TensorBoard.

    The ``SummaryWriter`` is closed in a ``finally`` clause, so it is
    released even if data loading, model construction or training raises
    (including ``KeyboardInterrupt``).
    """
    # -------------------------------------------------------------------------
    # 1. Settings
    # -------------------------------------------------------------------------

    # Args
    args = init_args()

    # Settings
    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    torch.manual_seed(args.seed)

    # Tensorboard writer. Everything after this point runs inside try/finally
    # so the writer is always closed.
    writer = tensorboard.SummaryWriter(args.logdir)
    try:
        # ---------------------------------------------------------------------
        # 2. Data
        # ---------------------------------------------------------------------

        # Loader
        batch_size = args.batch_size
        train_loader, test_loader = init_dataloader(root=args.data_root,
                                                    cuda=use_cuda,
                                                    batch_size=batch_size)

        # Data dimensions; assumes dataset.data is laid out as
        # (num_sequences, x_dim, t_max) — TODO confirm against init_dataloader.
        x_dim = train_loader.dataset.data.shape[1]
        t_max = train_loader.dataset.data.shape[2]

        # ---------------------------------------------------------------------
        # 3. Model
        # ---------------------------------------------------------------------

        # Latent dimension
        h_dim = args.h_dim
        hidden_dim = args.hidden_dim
        z_dim = args.z_dim

        # Distributions
        prior = Prior(z_dim, hidden_dim).to(device)
        decoder = Generator(z_dim, hidden_dim, x_dim).to(device)
        encoder = Inference(z_dim, h_dim).to(device)
        rnn = RNN(x_dim, h_dim).to(device)

        # Sampler
        generate_from_prior = prior * decoder

        # Loss: per-step CE + KL unrolled over t_max steps, with "z" carried
        # forward as "z_prev".
        ce = pxl.CrossEntropy(encoder, decoder)
        kl = pxl.KullbackLeibler(encoder, prior)
        _loss = pxl.IterativeLoss(ce + kl,
                                  max_iter=t_max,
                                  series_var=["x", "h"],
                                  update_value={"z": "z_prev"})
        loss = _loss.expectation(rnn).mean()

        # Model
        model = pxm.Model(loss,
                          distributions=[rnn, encoder, decoder, prior],
                          optimizer=optim.Adam,
                          optimizer_params={"lr": 1e-3},
                          clip_grad_value=10)

        # ---------------------------------------------------------------------
        # 4. Training
        # ---------------------------------------------------------------------

        for epoch in range(1, args.epochs + 1):
            # One training pass, then one evaluation pass on the test split
            train_loss = data_loop(train_loader,
                                   model,
                                   z_dim,
                                   device,
                                   train_mode=True)
            test_loss = data_loop(test_loader,
                                  model,
                                  z_dim,
                                  device,
                                  train_mode=False)

            # Sample data
            sample = plot_image_from_latent(generate_from_prior, decoder,
                                            batch_size, z_dim, t_max, device)

            # Log
            writer.add_scalar("train_loss", train_loss.item(), epoch)
            writer.add_scalar("test_loss", test_loss.item(), epoch)
            writer.add_images("image_from_latent", sample, epoch)
    finally:
        writer.close()