def fit(self, training_loader, validation_loader, epochs, device):
    total_step = 0
    for _epoc in range(epochs):
        with tqdm(total=len(training_loader), desc=f"{_epoc:2.0f}", ncols=100) as pbar:
            for i_bt, batch_data in enumerate(training_loader):
                batch_data = wrap_data(batch_data, self.data_keys, device=device)
                loss, loss_nll, loss_kl, weight = self.train_step(batch_data)
                pbar.set_postfix({"loss": f"{loss:.3f}", "nll": f"{loss_nll:.2f}",
                                  "kl": f"{loss_kl:.2f}", "weight": f"{weight:.3f}"})
                pbar.update()
                total_step += 1
                self.writer.add_scalar("training_loss", loss, total_step)
                self.writer.add_scalar("training_loss_nll", loss_nll, total_step)
                self.writer.add_scalar("training_loss_kl", loss_kl, total_step)
                self.writer.add_scalar("annealing_weight", weight, total_step)
        training_elbo, training_ppl = self.eval_metric(training_loader, device)
        validation_elbo, validation_ppl = self.eval_metric(validation_loader, device)
        print(f"evaluation: ppl: {training_ppl:.1f}, val_ppl: {validation_ppl:.1f}, "
              f"elbo: {training_elbo:.1f}, val_elbo: {validation_elbo:.1f}")
        self.writer.add_scalar("ppl/train", training_ppl, _epoc)
        self.writer.add_scalar("ppl/val", validation_ppl, _epoc)
        self.writer.add_scalar("elbo/train", training_elbo, _epoc)
        self.writer.add_scalar("elbo/val", validation_elbo, _epoc)
    return
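# --- Hedged sketch, not part of the original file ---
# fit() logs the "annealing_weight" returned by train_step(); the schedule
# itself lives in self.annealer, which is configured elsewhere with
# anneal_function, anneal_k and anneal_x0. A common logistic/linear schedule
# consistent with those names might look like the class below; the class
# name KLAnnealer and its interface are illustrative assumptions.
import numpy as np

class KLAnnealer:
    def __init__(self, anneal_function, k, x0):
        self.anneal_function = anneal_function  # "logistic" or "linear"
        self.k = k    # slope of the logistic schedule
        self.x0 = x0  # step at which the logistic weight reaches 0.5 / the linear weight reaches 1.0
        self.step = 0

    @property
    def weight(self):
        if self.anneal_function == "logistic":
            return float(1.0 / (1.0 + np.exp(-self.k * (self.step - self.x0))))
        if self.anneal_function == "linear":
            return min(1.0, self.step / self.x0)
        return 1.0

    def update(self):
        # assumed to be called once per training step
        self.step += 1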
def eval_metric(self, data_loader, device):
    self.model.eval()
    total_ppl = 0
    total_elbo = 0
    total_samples = 0
    weight = self.annealer.weight
    for i_bt, batch_data in enumerate(data_loader):
        batch_data = wrap_data(batch_data, self.data_keys, device=device)
        label = batch_data["label"]
        batch_size = label.shape[1]
        logits, mean, logvar = self.model.batch_forward(batch_data)
        loss_nll, loss_kl = self.loss(logits, label, mean, logvar)
        loss = loss_nll + loss_kl * weight
        # the NLL term inside self.loss uses reduction="none", so compute the
        # per-token mean NLL separately here for the perplexity estimate
        ppl = torch.nn.functional.nll_loss(logits.permute(0, 2, 1), label,
                                           reduction="mean")
        total_elbo += loss.item() * batch_size
        total_ppl += ppl.item() * batch_size
        total_samples += batch_size
    return total_elbo / total_samples, np.exp(total_ppl / total_samples)
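# --- Hedged sketch, not part of the original file ---
# eval_metric() relies on self.loss returning the reconstruction NLL and the
# KL term separately. Under the usual VAE assumptions (logits are already
# log-probabilities shaped (seq_len, batch, num_class), label is
# (seq_len, batch), and the posterior is a diagonal Gaussian against a
# standard normal prior), a minimal version could be:
import torch

def elbo_terms(logits, label, mean, logvar):
    batch_size = label.shape[1]
    # token NLL summed over the sequence, averaged over the batch
    nll = torch.nn.functional.nll_loss(
        logits.permute(0, 2, 1), label, reduction="sum") / batch_size
    # analytic KL( N(mean, exp(logvar)) || N(0, I) ), averaged over the batch
    kl = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp()) / batch_size
    return nll, kl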
def main(args):
    traj, n_traj, translator = traj_prepare(
        TRAJ_PATH, CODE_PATH, NUM_SAMPLES,
        use_cols=args.begin_index + np.arange(args.num_per_day) * args.time_interval,
        add_idx=ADD_IDX)
    NUM_CLASS = n_traj
    num_training_sample = int(NUM_SAMPLES * 0.8)
    ix_sos = translator.trans_code2ix(["0/SOS"])[0]
    # prepend the SOS column, then shift by one position to build input/label pairs
    traj = np.concatenate([np.ones((traj.shape[0], 1), dtype=int), traj], axis=1)
    training_data, validation_data = traj[:num_training_sample, :-1], traj[num_training_sample:, :-1]
    training_label, validation_label = traj[:num_training_sample, 1:], traj[num_training_sample:, 1:]
    training_loader = torch.utils.data.DataLoader(
        TrajDatasetGeneral([training_data, training_label]),
        batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=False)
    validation_loader = torch.utils.data.DataLoader(
        TrajDatasetGeneral([validation_data, validation_label]),
        batch_size=2000, shuffle=True, num_workers=4, drop_last=False)
    seq_vae = SeqVAE(num_class=NUM_CLASS, embedding_size=args.embed_size,
                     embedding_drop_out=args.emb_dropout, hid_size=args.hid_size,
                     hid_layers=args.hid_layer, latent_z_size=args.latent_z_size,
                     word_dropout_rate=args.word_dropout, gru_drop_out=args.gru_dropout,
                     anneal_k=args.anneal_k, anneal_x0=args.anneal_x0,
                     anneal_function=args.anneal_func, data_keys=DATA_KEYS,
                     learning_rate=args.lr, writer_comment="SentenceVAE",
                     sos_idx=ix_sos, unk_idx=None)
    test_batch = wrap_data(next(iter(training_loader)), DATA_KEYS)
    model_summary = torchsummaryX.summary(seq_vae.model, test_batch['data'])
    seq_vae.model.to(DEVICE)
    seq_vae.fit(training_loader=training_loader, validation_loader=validation_loader,
                epochs=args.epochs, device=DEVICE)
    return
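# --- Hedged sketch, not part of the original file ---
# main() and train() read a number of attributes from `args`; a parser
# consistent with those attribute names might look like the one below.
# Default values are purely illustrative.
import argparse

def build_parser():
    parser = argparse.ArgumentParser(description="Sequence-VAE trajectory training")
    parser.add_argument("--begin_index", type=int, default=0)
    parser.add_argument("--num_per_day", type=int, default=48)
    parser.add_argument("--time_interval", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--embed_size", type=int, default=128)
    parser.add_argument("--emb_dropout", type=float, default=0.1)
    parser.add_argument("--hid_size", type=int, default=256)
    parser.add_argument("--hid_layer", type=int, default=1)
    parser.add_argument("--latent_z_size", type=int, default=32)
    parser.add_argument("--word_dropout", type=float, default=0.0)
    parser.add_argument("--gru_dropout", type=float, default=0.0)
    parser.add_argument("--anneal_k", type=float, default=0.0025)
    parser.add_argument("--anneal_x0", type=int, default=2500)
    parser.add_argument("--anneal_func", type=str, default="logistic")
    return parser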
def eval_metric(self, data_loader, device):
    self.model.eval()
    total_ppl = 0
    total_elbo = 0
    total_samples = 0
    for i_bt, batch_data in enumerate(data_loader):
        batch_data = wrap_data(batch_data, self.data_keys, device=device)
        data, cond, label = batch_data["data"], batch_data["cond"], batch_data["label"]
        batch_size = label.shape[1]
        logits, mean, logvar, classify_logits = self.model.forward(
            input_data=data, cond=cond.T, z=None, batch_first_cond=True)
        loss = self.disentangled_loss(logits, label, mean, logvar,
                                      classify_logits, cond.T, batchfirst_cond=True)
        # the NLL term inside the loss uses reduction="none", so compute the
        # per-token mean NLL separately here for the perplexity estimate
        ppl = torch.nn.functional.nll_loss(logits.permute(0, 2, 1), label,
                                           reduction="mean")
        total_elbo += loss.item() * batch_size
        total_ppl += ppl.item() * batch_size
        total_samples += batch_size
    return total_elbo / total_samples, np.exp(total_ppl / total_samples)
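# --- Hedged sketch, not part of the original file ---
# self.disentangled_loss is assumed to combine the ELBO terms with a
# classification penalty that makes part of the latent code predict the
# condition (the one-hot job vector). Shapes below are assumptions:
# logits (seq_len, batch, num_class), label (seq_len, batch),
# classify_logits and cond (batch, cond_size).
import torch

def disentangled_loss_sketch(logits, label, mean, logvar, classify_logits, cond,
                             kl_weight=1.0, class_weight=1.0):
    batch_size = label.shape[1]
    nll = torch.nn.functional.nll_loss(
        logits.permute(0, 2, 1), label, reduction="sum") / batch_size
    kl = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp()) / batch_size
    # cond is one-hot, so recover class indices for cross_entropy
    classify = torch.nn.functional.cross_entropy(classify_logits, cond.argmax(dim=-1))
    return nll + kl_weight * kl + class_weight * classify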
def fit(self, training_loader, validation_loader, epochs, device, optimizer,
        criterion, teach_forcing=0.):
    # TODO: pass the optimizer and criterion in a cleaner way
    # TODO: wire the teacher-forcing ratio through to train_step
    total_step = 0
    for _epoc in range(epochs):
        total_loss = 0.
        num_sample = 0.
        with tqdm(total=len(training_loader), desc=f"{_epoc:2.0f}") as pbar:
            for i_bt, batch_data in enumerate(training_loader):
                batch_data = wrap_data(batch_data, self.data_keys, device=device)
                batch_size = batch_data["data"].shape[1]
                loss = self.train_step(batch_data, optimizer=optimizer, criterion=criterion)
                pbar.set_postfix({"loss": f"{loss:.3f}"})
                pbar.update()
                total_loss += loss
                num_sample += batch_size
                total_step += 1
                self.writer.add_scalar("training_loss", loss, total_step)
            training_ppl = self.eval_ppl(training_loader, criterion=criterion, device=device)
            validation_ppl = self.eval_ppl(validation_loader, criterion=criterion, device=device)
            pbar.set_postfix({"ppl": f"{training_ppl:.1f}",
                              "val_ppl": f"{validation_ppl:.1f}"})
        self.writer.add_scalar("ppl/train", training_ppl, _epoc)
        self.writer.add_scalar("ppl/val", validation_ppl, _epoc)
    return
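# --- Hedged sketch, not part of the original file ---
# The teach_forcing argument above is not passed to train_step yet (see the
# TODOs); the intended behaviour is presumably standard teacher forcing: at
# each decoding step, feed the ground-truth token with probability
# `teaching_ratio`, otherwise feed the model's own previous prediction.
# Names below are illustrative.
import random

def choose_next_input(ground_truth_t, predicted_logits_t, teaching_ratio):
    # ground_truth_t: (batch,) true tokens; predicted_logits_t: (batch, num_class)
    if random.random() < teaching_ratio:
        return ground_truth_t
    return predicted_logits_t.argmax(dim=-1)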
def train(args):
    traj, n_traj, translator = traj_prepare(
        TRAJ_PATH, CODE_PATH, NUM_SAMPLES,
        use_cols=args.begin_index + np.arange(args.num_per_day) * args.time_interval,
        add_idx=ADD_IDX)
    cond, n_categ_cond = cond_prepare(COND_PATH, NUM_SAMPLES)
    cond = cond[:, 1]  # only use job here
    n_categ_cond = n_categ_cond[1]
    cond = trans_one_hot(cond.flatten(), n_categ_cond).astype(np.float32)
    NUM_CLASS = n_traj  # add SOS
    num_training_sample = int(NUM_SAMPLES * 0.8)
    ix_sos = translator.trans_code2ix(["0/SOS"])[0]
    traj = np.concatenate([np.ones((traj.shape[0], 1), dtype=int), traj], axis=1)
    training_data, validation_data = traj[:num_training_sample, :-1], traj[num_training_sample:, :-1]
    training_label, validation_label = traj[:num_training_sample, 1:], traj[num_training_sample:, 1:]
    training_cond, validation_cond = cond[:num_training_sample], cond[num_training_sample:]
    training_loader = torch.utils.data.DataLoader(
        TrajDatasetGeneral([training_data, training_label, training_cond]),
        batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=False)
    validation_loader = torch.utils.data.DataLoader(
        TrajDatasetGeneral([validation_data, validation_label, validation_cond]),
        batch_size=2000, shuffle=False, num_workers=4, drop_last=False)
    seq_vae = DisentangledVAE(num_class=NUM_CLASS, embedding_size=args.embed_size,
                              embedding_drop_out=args.emb_dropout, hid_size=args.hid_size,
                              hid_layers=args.hid_layer, latent_z_size=args.latent_z_size,
                              cond_size=n_categ_cond, word_dropout_rate=args.word_dropout,
                              gru_drop_out=args.gru_dropout, anneal_k=args.anneal_k,
                              anneal_x0=args.anneal_x0, anneal_function=args.anneal_func,
                              data_keys=DATA_KEYS, learning_rate=args.lr,
                              sos_idx=ix_sos, unk_idx=None)
    test_batch = wrap_data(next(iter(training_loader)), DATA_KEYS)
    _ = torchsummaryX.summary(seq_vae.model, test_batch['data'], test_batch['cond'], None, False)
    seq_vae.model.to(DEVICE)
    seq_vae.fit(training_loader=training_loader, validation_loader=validation_loader,
                epochs=args.epochs, device=DEVICE)
    return seq_vae, translator
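# --- Hedged sketch, not part of the original file ---
# trans_one_hot is assumed to map an integer category vector to a
# (num_samples, num_categories) one-hot matrix, as used on `cond` above;
# the function name below carries a _sketch suffix to mark it as illustrative.
import numpy as np

def trans_one_hot_sketch(labels, num_categories):
    one_hot = np.zeros((labels.shape[0], num_categories))
    one_hot[np.arange(labels.shape[0]), labels.astype(int)] = 1.0
    return one_hot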
def eval_ppl(self, data_loader, criterion, device=None):
    total_loss = 0.
    num_sample = 0.
    for i_bt, batch_data in enumerate(data_loader):
        batch_data = wrap_data(batch_data, self.data_keys, device=device)
        batch_size = batch_data["data"].shape[1]
        output, hn = self.forward_packed_data(batch_data, h0=None, teaching_ratio=0.)
        label = batch_data["label"]
        output = output.permute(0, 2, 1)
        loss = criterion(output, label)
        total_loss += loss.item() * batch_size
        num_sample += batch_size
    return np.exp(total_loss / num_sample)
    embedding_size=args.embed_size, learning_rate=args.lr,
    data_keys=DATA_KEYS, writer_comment="Single_RNN")
test_dt = training_loader.dataset[0:10][0].T
test_dt = torch.tensor(test_dt)
model_summary = torchsummaryX.summary(sing_gru, test_dt)
#%%
sing_gru.to(DEVICE)
sing_gru.fit(training_loader=training_loader, validation_loader=validation_loader,
             epochs=args.epochs, device=DEVICE,
             optimizer=sing_gru.optimizer, criterion=sing_gru.criterion)
log_dir = sing_gru.writer.log_dir
sing_gru.writer.close()
torch.save(sing_gru, f"{log_dir}/model.pt")
#%%
data = training_loader.dataset.get_sample(15)
begin_time = int(args.num_per_day * 0.35)
data[0] = data[0][:, :begin_time]
data = wrap_data(data, keys=DATA_KEYS, device=DEVICE)
sampled_traj = sing_gru.sample_traj(data,
                                    predict_length=args.num_per_day - args.begin_index,
                                    sample_method="multinomial")
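# --- Hedged sketch, not part of the original file ---
# sample_traj with sample_method="multinomial" is assumed to roll the GRU
# forward from the observed prefix and draw each next token from the
# predicted categorical distribution. The per-step sampling under that
# assumption (names are illustrative):
import torch

def sample_next_token(log_probs, sample_method="multinomial"):
    # log_probs: (batch, num_class) log-probabilities for the next time step
    if sample_method == "multinomial":
        return torch.multinomial(log_probs.exp(), num_samples=1).squeeze(-1)
    # fall back to greedy decoding
    return log_probs.argmax(dim=-1)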