Exemple #1
0
    def fit(self, training_loader, validation_loader, epochs, device):
        """Train the VAE for ``epochs`` epochs, evaluating after each epoch.

        Every training step's total loss, NLL term, KL term and annealing
        weight (as returned by ``self.train_step``) are shown on a progress
        bar and logged to ``self.writer`` under a global step counter. After
        each epoch, ``self.eval_metric`` is run on both loaders; the resulting
        ELBO and perplexity values are printed and logged under the epoch
        index.

        Args:
            training_loader: iterable of batches used for ``train_step`` and
                for the end-of-epoch training metrics.
            validation_loader: iterable of batches used only for evaluation.
            epochs: number of passes over ``training_loader``.
            device: device forwarded to ``wrap_data`` and ``eval_metric``.
        """
        total_step = 0
        for _epoc in range(epochs):
            # eval_metric() switches the model to eval mode and does not
            # switch it back, so re-enable training mode (dropout, etc.) at
            # the start of every epoch. Calling train() again is harmless if
            # the model is already in training mode.
            self.model.train()

            with tqdm(total=len(training_loader), desc=f"{_epoc:2.0f}", ncols=100) as pbar:
                for batch_data in training_loader:
                    batch_data = wrap_data(batch_data, self.data_keys, device=device)

                    loss, loss_nll, loss_kl, weight = self.train_step(batch_data)

                    pbar.set_postfix({"loss": f"{loss:.3f}", "nll": f"{loss_nll:.2f}",
                                      "kl": f"{loss_kl:.2f}", "weight": f"{weight:.3f}"})
                    pbar.update()

                    total_step += 1
                    self.writer.add_scalar("training_loss", loss, total_step)
                    self.writer.add_scalar("training_loss_nll", loss_nll, total_step)
                    self.writer.add_scalar("training_loss_kl", loss_kl, total_step)
                    self.writer.add_scalar("annealing_weight", weight, total_step)
            # The context manager has closed the progress bar; evaluate and
            # report the epoch metrics with a plain print.
            training_elbo, training_ppl = self.eval_metric(training_loader, device)
            validation_elbo, validation_ppl = self.eval_metric(validation_loader, device)

            print(f"evaluation: ppl: {training_ppl:.1f}, val_ppl: {validation_ppl:.1f}, elbo: {training_elbo:.1f} val elbo: {validation_elbo:.1f}")
            self.writer.add_scalar("ppl/train", training_ppl, _epoc)
            self.writer.add_scalar("ppl/val", validation_ppl, _epoc)
            self.writer.add_scalar("elbo/train", training_elbo, _epoc)
            self.writer.add_scalar("elbo/val", validation_elbo, _epoc)
Exemple #2
0
    def eval_metric(self, data_loader, device):
        """Compute the mean ELBO loss and the perplexity over ``data_loader``.

        The model is put into eval mode (and left there). For each batch, the
        loss is ``loss_nll + loss_kl * self.annealer.weight`` from
        ``self.loss``, and the token NLL is ``nll_loss(..., reduction="mean")``
        of the logits against the labels. Both are averaged over batches,
        weighted by batch size. Autograd is disabled because no gradients are
        needed.

        Args:
            data_loader: iterable of batches to evaluate.
            device: device forwarded to ``wrap_data``.

        Returns:
            tuple: ``(mean_elbo_loss, perplexity)``, where perplexity is
            ``np.exp`` of the weighted mean token NLL.

        Raises:
            ValueError: if ``data_loader`` yields no samples.
        """
        self.model.eval()

        total_ppl = 0.
        total_elbo = 0.
        total_samples = 0
        weight = self.annealer.weight
        # Skip building autograd graphs during evaluation to save memory/time.
        with torch.no_grad():
            for batch_data in data_loader:
                batch_data = wrap_data(batch_data, self.data_keys, device=device)
                label = batch_data["label"]
                # assumes label is (seq_len, batch), i.e. batch on dim 1 — TODO confirm
                batch_size = label.shape[1]

                logits, mean, logvar = self.model.batch_forward(batch_data)
                loss_nll, loss_kl = self.loss(logits, label, mean, logvar)

                loss = loss_nll + loss_kl * weight

                # Mean token-level NLL of this batch (reduction="mean"). The
                # class dim is moved to dim 1 as nll_loss requires; assumes
                # `logits` already holds log-probabilities — TODO confirm.
                nll = torch.nn.functional.nll_loss(logits.permute(0, 2, 1),
                                                   label,
                                                   reduction="mean")

                total_elbo += loss.item() * batch_size
                total_ppl += nll.item() * batch_size
                total_samples += batch_size

        if total_samples == 0:
            raise ValueError("data_loader yielded no samples")
        return total_elbo / total_samples, np.exp(total_ppl / total_samples)
Exemple #3
0
def main(args):
    """Prepare trajectory data, build a ``SeqVAE`` and train it.

    The first ``int(NUM_SAMPLES * 0.8)`` trajectories form the training split
    and the rest the validation split. A start column is prepended to every
    trajectory. The model input is the sequence without its last step, and
    the label is the sequence without its first step.

    Args:
        args: parsed options. Read here: ``begin_index``, ``num_per_day``,
            ``time_interval``, ``batch_size``, ``epochs``, ``lr`` and the
            ``SeqVAE`` hyper-parameters (``embed_size``, ``emb_dropout``,
            ``hid_size``, ``hid_layer``, ``latent_z_size``, ``word_dropout``,
            ``gru_dropout``, ``anneal_k``, ``anneal_x0``, ``anneal_func``).
    """
    traj, n_traj, translator = traj_prepare(
        TRAJ_PATH,
        CODE_PATH,
        NUM_SAMPLES,
        use_cols=args.begin_index +
        np.arange(args.num_per_day) * args.time_interval,
        add_idx=ADD_IDX)
    NUM_CLASS = n_traj

    num_training_sample = int(NUM_SAMPLES * 0.8)
    # trans_code2ix is given a list of codes; as elsewhere in this file, take
    # element [0] so sos_idx is the scalar SOS index rather than a sequence.
    ix_sos = translator.trans_code2ix(["0/SOS"])[0]
    # NOTE(review): the start column is hard-coded to 1; presumably this
    # equals ix_sos — confirm, otherwise the data and the model disagree on
    # the SOS token.
    traj = np.concatenate([np.ones((traj.shape[0], 1), dtype=int), traj],
                          axis=1)
    training_data, validation_data = traj[:num_training_sample, :-1], traj[
        num_training_sample:, :-1]
    training_label, validation_label = traj[:num_training_sample,
                                            1:], traj[num_training_sample:, 1:]

    training_loader = torch.utils.data.DataLoader(
        TrajDatasetGeneral([training_data, training_label]),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        drop_last=False)
    # Evaluation does not need shuffling; a fixed order keeps it deterministic.
    validation_loader = torch.utils.data.DataLoader(
        TrajDatasetGeneral([validation_data, validation_label]),
        batch_size=2000,
        shuffle=False,
        num_workers=4,
        drop_last=False)

    seq_vae = SeqVAE(num_class=NUM_CLASS,
                     embedding_size=args.embed_size,
                     embedding_drop_out=args.emb_dropout,
                     hid_size=args.hid_size,
                     hid_layers=args.hid_layer,
                     latent_z_size=args.latent_z_size,
                     word_dropout_rate=args.word_dropout,
                     gru_drop_out=args.gru_dropout,
                     anneal_k=args.anneal_k,
                     anneal_x0=args.anneal_x0,
                     anneal_function=args.anneal_func,
                     data_keys=DATA_KEYS,
                     learning_rate=args.lr,
                     writer_comment="SentenceVAE",
                     sos_idx=ix_sos,
                     unk_idx=None)

    # Run one training batch through torchsummaryX to print a model summary.
    test_batch = wrap_data(next(iter(training_loader)), DATA_KEYS)
    torchsummaryX.summary(seq_vae.model, test_batch['data'])

    seq_vae.model.to(DEVICE)

    seq_vae.fit(training_loader=training_loader,
                validation_loader=validation_loader,
                epochs=args.epochs,
                device=DEVICE)

    return
    def eval_metric(self, data_loader, device):
        """Compute the mean disentangled loss and the perplexity over ``data_loader``.

        The model is put into eval mode (and left there). For each batch, the
        model is run on ``data`` with the transposed ``cond``. The loss is
        ``self.disentangled_loss`` on the outputs, and the token NLL is
        ``nll_loss(..., reduction="mean")`` of the logits against the labels.
        Both are averaged over batches, weighted by batch size. Autograd is
        disabled because no gradients are needed.

        Args:
            data_loader: iterable of batches with ``"data"``, ``"cond"`` and
                ``"label"`` entries after ``wrap_data``.
            device: device forwarded to ``wrap_data``.

        Returns:
            tuple: ``(mean_loss, perplexity)``, where perplexity is ``np.exp``
            of the weighted mean token NLL.

        Raises:
            ValueError: if ``data_loader`` yields no samples.
        """
        self.model.eval()

        total_ppl = 0.
        total_elbo = 0.
        total_samples = 0
        # Skip building autograd graphs during evaluation to save memory/time.
        with torch.no_grad():
            for batch_data in data_loader:
                batch_data = wrap_data(batch_data, self.data_keys, device=device)
                data, cond, label = batch_data["data"], batch_data[
                    "cond"], batch_data["label"]
                # assumes label is (seq_len, batch), i.e. batch on dim 1 — TODO confirm
                batch_size = label.shape[1]

                # Call the module itself (not .forward) so registered hooks run.
                logits, mean, logvar, classify_logits = self.model(
                    input_data=data, cond=cond.T, z=None, batch_first_cond=True)
                loss = self.disentangled_loss(logits,
                                              label,
                                              mean,
                                              logvar,
                                              classify_logits,
                                              cond.T,
                                              batchfirst_cond=True)

                # Mean token-level NLL of this batch (reduction="mean"). The
                # class dim is moved to dim 1 as nll_loss requires; assumes
                # `logits` already holds log-probabilities — TODO confirm.
                nll = torch.nn.functional.nll_loss(logits.permute(0, 2, 1),
                                                   label,
                                                   reduction="mean")

                total_elbo += loss.item() * batch_size
                total_ppl += nll.item() * batch_size
                total_samples += batch_size

        if total_samples == 0:
            raise ValueError("data_loader yielded no samples")
        return total_elbo / total_samples, np.exp(total_ppl / total_samples)
Exemple #5
0
    def fit(self,
            training_loader,
            validation_loader,
            epochs,
            device,
            optimizer,
            criterion,
            teach_forcing=0.):
        """Train for ``epochs`` epochs and report perplexity after each one.

        Each batch is passed to ``self.train_step`` with ``optimizer`` and
        ``criterion``. The returned loss is shown on the progress bar and
        logged to ``self.writer`` as ``training_loss`` under a global step
        counter. At the end of every epoch, ``self.eval_ppl`` is run on both
        loaders. The perplexities are shown on the progress bar and logged as
        ``ppl/train`` / ``ppl/val`` under the epoch index.

        Args:
            training_loader: iterable of training batches.
            validation_loader: iterable of validation batches.
            epochs: number of passes over ``training_loader``.
            device: device forwarded to ``wrap_data`` and ``eval_ppl``.
            optimizer: optimizer forwarded to ``train_step``.
            criterion: loss forwarded to ``train_step`` and ``eval_ppl``.
            teach_forcing: accepted for interface compatibility but not used
                yet (see TODO).
        """
        # TODO not proper transfer of optimizer and criterion
        # TODO teach forcing ratio (``teach_forcing`` is currently ignored)
        total_step = 0
        for _epoc in range(epochs):
            with tqdm(total=len(training_loader),
                      desc=f"{_epoc:2.0f}") as pbar:
                for batch_data in training_loader:
                    batch_data = wrap_data(batch_data,
                                           self.data_keys,
                                           device=device)

                    loss = self.train_step(batch_data,
                                           optimizer=optimizer,
                                           criterion=criterion)

                    pbar.set_postfix({"loss": f"{loss:.3f}"})
                    pbar.update()

                    total_step += 1
                    self.writer.add_scalar("training_loss", loss, total_step)

                # Evaluate while the bar is still open so its final postfix
                # shows this epoch's perplexities.
                training_ppl = self.eval_ppl(training_loader,
                                             criterion=criterion,
                                             device=device)
                validation_ppl = self.eval_ppl(validation_loader,
                                               criterion=criterion,
                                               device=device)
                pbar.set_postfix({
                    "ppl": f"{training_ppl:.1f}",
                    "val_ppl": f"{validation_ppl:.1f}"
                })
                self.writer.add_scalar("ppl/train", training_ppl, _epoc)
                self.writer.add_scalar("ppl/val", validation_ppl, _epoc)
Exemple #6
0
def train(args):
    """Build a job-conditioned ``DisentangledVAE`` on the trajectory data and fit it.

    Trajectories come from ``traj_prepare``, restricted to the columns
    ``args.begin_index + k * args.time_interval`` for ``k < args.num_per_day``.
    The condition is column 1 (the job) of ``cond_prepare``'s output,
    one-hot encoded as float32. The first ``int(NUM_SAMPLES * 0.8)`` rows form
    the training split and the remaining rows the validation split.

    Returns:
        tuple: the fitted ``DisentangledVAE`` wrapper and the ``translator``
        returned by ``traj_prepare``.
    """
    day_cols = args.begin_index + np.arange(args.num_per_day) * args.time_interval
    traj, n_traj, translator = traj_prepare(TRAJ_PATH, CODE_PATH, NUM_SAMPLES,
                                            use_cols=day_cols, add_idx=ADD_IDX)

    # Only the job column (index 1) is used as the condition.
    raw_cond, categ_counts = cond_prepare(COND_PATH, NUM_SAMPLES)
    n_jobs = categ_counts[1]
    job_onehot = trans_one_hot(raw_cond[:, 1].flatten(), n_jobs).astype(np.float32)

    n_train = int(NUM_SAMPLES * 0.8)
    sos_ix = translator.trans_code2ix(["0/SOS"])[0]

    # Prepend a start column (value 1). Inputs then drop each sequence's last
    # step and labels drop its first step.
    start_col = np.ones((traj.shape[0], 1), dtype=int)
    traj = np.concatenate([start_col, traj], axis=1)
    inputs, labels = traj[:, :-1], traj[:, 1:]

    train_loader = torch.utils.data.DataLoader(
        TrajDatasetGeneral([inputs[:n_train], labels[:n_train], job_onehot[:n_train]]),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        drop_last=False,
    )
    val_loader = torch.utils.data.DataLoader(
        TrajDatasetGeneral([inputs[n_train:], labels[n_train:], job_onehot[n_train:]]),
        batch_size=2000,
        shuffle=False,
        num_workers=4,
        drop_last=False,
    )

    seq_vae = DisentangledVAE(
        num_class=n_traj,
        embedding_size=args.embed_size,
        embedding_drop_out=args.emb_dropout,
        hid_size=args.hid_size,
        hid_layers=args.hid_layer,
        latent_z_size=args.latent_z_size,
        cond_size=n_jobs,
        word_dropout_rate=args.word_dropout,
        gru_drop_out=args.gru_dropout,
        anneal_k=args.anneal_k,
        anneal_x0=args.anneal_x0,
        anneal_function=args.anneal_func,
        data_keys=DATA_KEYS,
        learning_rate=args.lr,
        sos_idx=sos_ix,
        unk_idx=None,
    )

    # Print a model summary using one training batch.
    probe = wrap_data(next(iter(train_loader)), DATA_KEYS)
    torchsummaryX.summary(seq_vae.model, probe['data'], probe['cond'], None, False)

    seq_vae.model.to(DEVICE)
    seq_vae.fit(
        training_loader=train_loader,
        validation_loader=val_loader,
        epochs=args.epochs,
        device=DEVICE,
    )
    return seq_vae, translator
Exemple #7
0
    def eval_ppl(self, data_loader, criterion, device=None):
        """Return the model's perplexity on ``data_loader``.

        Perplexity is ``np.exp`` of the batch-size-weighted mean of
        ``criterion(output, label)`` over all batches. Outputs come from
        ``self.forward_packed_data`` with ``h0=None`` and teacher forcing
        disabled (``teaching_ratio=0.``). Autograd is disabled because no
        gradients are needed.

        Args:
            data_loader: iterable of batches with ``"data"`` and ``"label"``
                entries after ``wrap_data``.
            criterion: loss applied to ``(output, label)``; presumably
                mean-reduced, since it is re-weighted by batch size.
            device: device forwarded to ``wrap_data``.

        Returns:
            float: the perplexity.

        Raises:
            ValueError: if ``data_loader`` yields no samples.
        """
        total_loss = 0.
        num_sample = 0.

        # Skip building autograd graphs during evaluation to save memory/time.
        with torch.no_grad():
            for batch_data in data_loader:
                batch_data = wrap_data(batch_data, self.data_keys, device=device)
                # assumes data is (seq_len, batch, ...), i.e. batch on dim 1 — TODO confirm
                batch_size = batch_data["data"].shape[1]

                output, _ = self.forward_packed_data(batch_data,
                                                     h0=None,
                                                     teaching_ratio=0.)
                label = batch_data["label"]
                # Move the class dimension to dim 1, presumably the (N, C, ...)
                # layout the criterion (e.g. an NLL-style loss) expects.
                loss = criterion(output.permute(0, 2, 1), label)

                total_loss += loss.item() * batch_size
                num_sample += batch_size

        if num_sample == 0:
            raise ValueError("data_loader yielded no samples")
        return np.exp(total_loss / num_sample)
Exemple #8
0
                     embedding_size=args.embed_size,
                     learning_rate=args.lr,
                     data_keys=DATA_KEYS,
                     writer_comment="Single_RNN")

# Build a small probe input from the first 10 dataset entries to print a model
# summary. Assumes slicing TrajDatasetGeneral returns a sequence whose [0] is
# the input array, and that .T yields the layout the model expects — TODO confirm.
test_dt = training_loader.dataset[0:10][0].T
test_dt = torch.tensor(test_dt)
model_summary = torchsummaryX.summary(sing_gru, test_dt)
#%%
# Train on DEVICE with the optimizer/criterion stored on the model itself.
sing_gru.to(DEVICE)

sing_gru.fit(training_loader=training_loader,
             validation_loader=validation_loader,
             epochs=args.epochs,
             device=DEVICE,
             optimizer=sing_gru.optimizer,
             criterion=sing_gru.criterion)
# Save the model next to the TensorBoard logs. The log dir is read before the
# writer is closed.
log_dir = sing_gru.writer.log_dir
sing_gru.writer.close()
# NOTE(review): torch.save on the object pickles the whole instance (including
# attributes such as the closed `writer`), and loading it requires the class
# to be importable. Saving sing_gru.state_dict() is usually more portable — consider.
torch.save(sing_gru, f"{log_dir}/model.pt")
#%%

# Sample continuations: truncate data[0] to its first begin_time columns
# (35% of num_per_day) and let the model predict the rest.
# Assumes get_sample(15) returns a mutable sequence whose [0] is
# (sample, time) — TODO confirm.
data = training_loader.dataset.get_sample(15)
begin_time = int(args.num_per_day * 0.35)
data[0] = data[0][:, :begin_time]
data = wrap_data(data, keys=DATA_KEYS, device=DEVICE)
# NOTE(review): predict_length subtracts args.begin_index (used elsewhere as a
# raw-column offset in traj_prepare's use_cols), not begin_time. If the intent
# is to fill the remainder of the day, this likely should be
# args.num_per_day - begin_time — confirm against sample_traj's contract.
sampled_traj = sing_gru.sample_traj(data,
                                    predict_length=args.num_per_day -
                                    args.begin_index,
                                    sample_method="multinomial")