def test_corpus(self, return_value=False):
    """Evaluate tagging accuracy over the test corpus; optionally return the score."""
    with torch.no_grad():
        self.eval()
        info = {"acc": 0}
        total = 0
        ct = Clock(self.utils.test_step_num)
        for step, (batch_x, batch_y) in enumerate(self.utils.test_data_generator()):
            elmo_x, x_lens = self.utils.elmo(batch_x)
            (elmo_x, x_lens, batch_y), _ind = sort_by([elmo_x, x_lens, batch_y], piv=1)
            label = self.utils.sents2idx(batch_y, len_fixed=False)
            target = torch.from_numpy(label).to(self.device) if self.device != "cpu" \
                else torch.from_numpy(label)
            pred = self.forward(elmo_x, x_lens)[:, 1:-1]
            ans = torch.argmax(pred, dim=-1).cpu().numpy()
            # Score only the non-padding positions (label id 0 is padding);
            # move the indices to CPU so they can index the numpy arrays.
            idx = (target > 0).nonzero().cpu().numpy()
            acc = accuracy_score(label[idx[:, 0], idx[:, 1]],
                                 ans[idx[:, 0], idx[:, 1]])
            info["acc"] += acc
            total += 1
            target.cpu()
            ct.flush()
    self.train()
    print("test_acc:", info["acc"] / total)
    if return_value:
        return info["acc"] / total
    return
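# NOTE: `sort_by` (used throughout this file) is an external helper that is not
# defined here. A minimal sketch of what it is assumed to do, judging from the
# call sites above: sort every item of `li` along the batch dimension by the
# `piv`-th item (typically the sequence lengths), in descending order as
# pack_padded_sequence expects, and also return the permutation. The name
# `sort_by_sketch` and all details below are assumptions, not the project's code.
import numpy as np
import torch

def sort_by_sketch(li, piv=1):
    piv_vals = li[piv]
    lengths = piv_vals.detach().cpu().numpy() if torch.is_tensor(piv_vals) else np.asarray(piv_vals)
    order = np.argsort(-lengths.reshape(-1))              # descending by length
    sorted_li = []
    for item in li:
        if torch.is_tensor(item):
            sorted_li.append(item[torch.as_tensor(order, device=item.device)])
        elif isinstance(item, np.ndarray):
            sorted_li.append(item[order])
        else:
            sorted_li.append([item[k] for k in order])    # plain Python lists (e.g. raw sentences)
    return sorted_li, order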
def pretrain_disc(self, num_epochs=100):
    """Pretrain the discriminator to separate Y-domain batches (real) from X-domain batches (fake)."""
    # self.D.to(self.D.device)
    # self.G.to(self.G.device)
    # self.R.to(self.R.device)
    X_datagen = self.utils.data_generator("X")
    Y_datagen = self.utils.data_generator("Y")
    for epoch in range(num_epochs):
        d_steps = self.utils.train_step_num
        d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)" % (epoch, num_epochs))
        for step, X_data, Y_data in zip(range(d_steps),
                                        data_gen(X_datagen, self.utils.sents2idx),
                                        data_gen(Y_datagen, self.utils.sents2idx)):
            # 1. Train D on real + fake
            # if epoch == 0:
            #     break
            self.D.zero_grad()

            # 1A: Train D on real
            d_real_pred = self.D(self.embed(Y_data.src.to(device)))
            d_real_error = self.criterion(
                d_real_pred,
                torch.ones((d_real_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # ones = true

            # 1B: Train D on fake
            d_fake_pred = self.D(self.embed(X_data.src.to(device)))
            d_fake_error = self.criterion(
                d_fake_pred,
                torch.zeros((d_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # zeros = fake

            (d_fake_error + d_real_error).backward()
            self.D_opt.step()  # only optimizes D's parameters, using the gradients stored by backward()
            d_ct.flush(info={"D_loss": d_fake_error.item()})
    torch.save(self.D.state_dict(), "model_disc_pretrain.ckpt")
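# NOTE: `data_gen` and the batch objects it yields (fields .src, .trg, .src_mask,
# .trg_mask, .trg_y, .ntokens) are not defined in this file. They are assumed to
# follow the Annotated Transformer's `Batch` convention, sketched below; the
# project's actual implementation may differ.
import torch

def subsequent_mask_sketch(size):
    "Mask out positions after the current one for autoregressive decoding."
    return torch.triu(torch.ones(1, size, size, dtype=torch.uint8), diagonal=1) == 0

class BatchSketch:
    "Hold a batch of padded index sequences together with the masks the models above expect."
    def __init__(self, src, trg=None, pad=0):
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if trg is not None:
            self.trg = trg[:, :-1]                    # decoder input (shifted right)
            self.trg_y = trg[:, 1:]                   # decoder target
            self.trg_mask = self.make_std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).sum()  # number of non-padding target tokens

    @staticmethod
    def make_std_mask(tgt, pad):
        "Hide both padding and future positions."
        tgt_mask = (tgt != pad).unsqueeze(-2)
        return tgt_mask & subsequent_mask_sketch(tgt.size(-1)).type_as(tgt_mask)

def data_gen_sketch(raw_generator, sents2idx, pad=0):
    "Wrap raw sentence batches as index tensors; src and trg are assumed to carry the same sentence here."
    for raw_batch in raw_generator:
        idx = torch.from_numpy(sents2idx(raw_batch)).long()
        yield BatchSketch(idx, idx, pad=pad)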
def run_DA(self, big=0, lim=50):
    """Compare CBOW and Skip-gram similarity graphs for CNA sentences, writing LaTeX tables and a JSON log."""
    li = load_cna()[big:lim]
    ct = Clock(len(li), 5, 'DA')
    log = []
    for i in li:
        # CBOW pass
        temp_dict1 = {'sentence': ''.join(i)}
        self.sg = 0
        cm = self.sent_sim(i)
        temp_dict1['graph'] = cm.tolist()
        temp_dict1['sg'] = self.sg
        tag = self.tag_mat(cm)
        log.append(temp_dict1)
        self.texfile.write('\\textbf{CBOW} \\\\ \n')
        self.write_tex(i, tag, cm)
        # plot_confusion_matrix(cm, i, tag, sv=True, sg=self.sg)

        # Skip-gram pass
        temp_dict2 = {'sentence': ''.join(i)}
        self.sg = 1
        cm = self.sent_sim(i)
        temp_dict2['graph'] = cm.tolist()
        temp_dict2['sg'] = self.sg
        tag = self.tag_mat(cm)
        log.append(temp_dict2)
        self.texfile.write('\\textbf{Skip-gram} \\\\ \n')
        self.write_tex(i, tag, cm)
        # plot_confusion_matrix(cm, i, tag, sv=True, sg=self.sg)
        ct.flush()
    with open('log.json', 'w') as fp:
        json.dump({'data': log}, fp)
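# NOTE: `Clock` is the project's progress meter and is defined elsewhere. A
# minimal sketch assuming the interface used in this file -- Clock(total,
# interval=1, title="") plus flush(info=None) called once per step; the real
# class may print differently.
import time

class ClockSketch:
    def __init__(self, total, interval=1, title=""):
        self.total = total
        self.interval = interval      # print every `interval` flush() calls
        self.title = title
        self.step = 0
        self.start = time.time()

    def flush(self, info=None):
        "Advance one step and print progress plus any metrics passed in `info`."
        self.step += 1
        if self.step % self.interval != 0 and self.step != self.total:
            return
        elapsed = time.time() - self.start
        extras = " ".join("%s: %s" % (k, v) for k, v in (info or {}).items())
        print("%s [%d/%d] %.1fs %s" % (self.title, self.step, self.total, elapsed, extras))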
def pretrain(self, num_epochs=1):
    """Pretrain the encoder and generator by reconstructing ELMo representations with an MSE loss."""
    self.G.to(self.G.device)
    self.encoder.to(self.encoder.device)
    real_datagen = self.utils.data_generator("train")
    test_datagen = self.utils.data_generator("test")
    for epoch in range(num_epochs):
        ct = Clock(self.utils.train_step_num, title="Pretrain(%d/%d)" % (epoch, num_epochs))
        for real_data in real_datagen:
            self.G.zero_grad()
            g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data)
            gen_input = self.encoder(g_org_data, g_data_seqlen)
            g_fake_data = self.G(g_org_data, g_data_seqlen, hidden=gen_input)
            loss = self.mse(g_fake_data, torch.from_numpy(g_org_data).to(self.G.device))
            loss.backward()
            self.G.optimizer.step()  # Only optimizes G's parameters
            self.encoder.optimizer.step()
            ct.flush(info={"G_loss": loss.item()})

        # Print a couple of reconstructed test batches after every epoch.
        with torch.no_grad():
            for _, real_data in zip(range(2), test_datagen):
                g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data)
                [g_org_data, g_data_seqlen], _ind = sort_by([g_org_data, g_data_seqlen], piv=1)
                g_mask_data, g_data_seqlen, g_mask_label = \
                    self.utils.elmo2mask(g_org_data, g_data_seqlen, mask_rate=epoch / num_epochs)
                gen_input = self.encoder(g_org_data, g_data_seqlen, sort=False)
                g_fake_data = self.G(g_mask_data, g_data_seqlen, hidden=gen_input, sort=False)
                gen_sents = self.invelmo.test(g_fake_data.cpu().numpy(), g_data_seqlen)
                for i, j in zip(real_data, gen_sents):
                    print("=" * 50)
                    print(' '.join(i))
                    print("---")
                    print(' '.join(j))
                    print("=" * 50)
    torch.save(self.G.state_dict(), "pretrain_model.ckpt")
def pretrain_run_epoch(data_iter, model, loss_compute, train_step_num):
    "Standard Training and Logging Function"
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    ct = Clock(train_step_num)
    for i, batch in enumerate(data_iter):
        out = model.forward(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        ct.flush(info={"loss": loss / batch.ntokens.float().to(device)})
    return total_loss / total_tokens.float().to(device)
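# NOTE: the `loss_compute` callables passed into the epoch functions here are
# assumed to follow the Annotated Transformer's SimpleLossCompute pattern
# (SimpleLossCompute is referenced elsewhere in this project): project the
# decoder output, compute the criterion normalized by the token count, step the
# optimizer, and return the un-normalized loss. A sketch under that assumption:
class SimpleLossComputeSketch:
    def __init__(self, generator, criterion, opt=None):
        self.generator = generator        # final projection to vocabulary logits (may be None)
        self.criterion = criterion
        self.opt = opt

    def __call__(self, x, y, norm):
        if self.generator is not None:
            x = self.generator(x)
        loss = self.criterion(x.contiguous().view(-1, x.size(-1)),
                              y.contiguous().view(-1)) / norm.float()
        loss.backward()
        if self.opt is not None:
            self.opt.step()
            self.opt.optimizer.zero_grad()    # the opt here is assumed to be a NoamOpt-style wrapper
        return loss.item() * norm.float()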
def pretrain_without_tf(data_iter, model, loss_compute, train_step_num, semi=False):
    "Real Training without Teacher Forcing"
    total_tokens = 0
    total_loss = 0
    tokens = 0
    ct = Clock(train_step_num)
    for i, batch in enumerate(data_iter):
        if not semi:
            out = model.self_forward(batch.src, batch.src_mask, max_len=15 + 2)
        else:
            out = model.semi_forward(batch.src, batch.src_mask, max_len=15 + 2)
        loss = loss_compute(out, model.tgt_embed[0](batch.trg), batch.ntokens)
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        ct.flush(info={"loss": loss / batch.ntokens.float().to(device)})
    return total_loss / total_tokens.float().to(device)
def pretrain_run_epoch(data_iter, model, loss_compute, train_step_num, tf_rate=0.7):
    "Standard Training and Logging Function"
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    ct = Clock(train_step_num)
    for i, batch in enumerate(data_iter):
        if random.random() <= tf_rate:
            out = model.forward(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
        else:
            out = model.semi_forward(batch.src, batch.src_mask, max_len=15 + 2)
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        if i % 50 == 1:
            elapsed = time.time() - start
            ct.flush(info={"loss": loss / batch.ntokens.float().to(device),
                           "tok/sec": tokens.float().to(device) / elapsed})
            # print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
            #       (i, loss / batch.ntokens.float().to(device), tokens.float().to(device) / elapsed))
            start = time.time()
            tokens = 0
        else:
            ct.flush(info={"loss": loss / batch.ntokens.float().to(device)})
    return total_loss / total_tokens.float().to(device)
def pretrain_run_epoch(data_iter, model, loss_compute, train_step_num, embedding_layer):
    "Standard Training and Logging Function"
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    ct = Clock(train_step_num)
    embedding_layer.to(device)
    model.to(device)
    for i, batch in enumerate(data_iter):
        batch.to(device)
        out = model.forward(embedding_layer(batch.src.to(device)),
                            embedding_layer(batch.trg.to(device)),
                            batch.src_mask, batch.trg_mask)
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        batch.to("cpu")
        if i % 50 == 1:
            elapsed = time.time() - start
            ct.flush(info={"loss": loss / batch.ntokens.float().to(device),
                           "tok/sec": tokens.float().to(device) / elapsed})
            # print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
            #       (i, loss / batch.ntokens.float().to(device), tokens.float().to(device) / elapsed))
            start = time.time()
            tokens = 0
        else:
            ct.flush(info={"loss": loss / batch.ntokens.float().to(device)})
    return total_loss / total_tokens.float().to(device)
def train_model(self, num_epochs=100, d_steps=20, g_steps=80, g_scale=1.0, r_scale=1.0):
    # self.D.to(self.D.device)
    # self.G.to(self.G.device)
    # self.R.to(self.R.device)
    for i, batch in enumerate(data_gen(utils.data_generator("X"), utils.sents2idx)):
        X_test_batch = batch
        break
    for i, batch in enumerate(data_gen(utils.data_generator("Y"), utils.sents2idx)):
        Y_test_batch = batch
        break
    X_datagen = self.utils.data_generator("X")
    Y_datagen = self.utils.data_generator("Y")
    for epoch in range(num_epochs):
        d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)" % (epoch, num_epochs))
        if epoch > 0:
            for i, X_data, Y_data in zip(range(d_steps),
                                         data_gen(X_datagen, self.utils.sents2idx),
                                         data_gen(Y_datagen, self.utils.sents2idx)):
                # 1. Train D on real + fake
                # if epoch == 0:
                #     break
                self.D.zero_grad()

                # 1A: Train D on real
                d_real_pred = self.D(self.embed(Y_data.src.to(device)))
                d_real_error = self.criterion(
                    d_real_pred,
                    torch.ones((d_real_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # ones = true

                # 1B: Train D on fake
                self.G.to(device)
                d_fake_data = backward_decode(self.G, self.embed, X_data.src, X_data.src_mask,
                                              max_len=self.utils.max_len,
                                              return_term=0).detach()  # detach to avoid training G on these labels
                d_fake_pred = self.D(d_fake_data)
                d_fake_error = self.criterion(
                    d_fake_pred,
                    torch.zeros((d_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # zeros = fake

                (d_fake_error + d_real_error).backward()
                self.D_opt.step()  # only optimizes D's parameters, using the gradients stored by backward()
                d_ct.flush(info={"D_loss": d_fake_error.item()})

        g_ct = Clock(g_steps, title="Train Generator(%d/%d)" % (epoch, num_epochs))
        r_ct = Clock(g_steps, title="Train Reconstructor(%d/%d)" % (epoch, num_epochs))
        if epoch > 0:
            for i, X_data in zip(range(g_steps), data_gen(X_datagen, self.utils.sents2idx)):
                # 2. Train G on D's response (but DO NOT train D on these labels)
                self.G.zero_grad()
                g_fake_data = backward_decode(self.G, self.embed, X_data.src, X_data.src_mask,
                                              max_len=self.utils.max_len, return_term=0)
                dg_fake_pred = self.D(g_fake_data)
                g_error = self.criterion(
                    dg_fake_pred,
                    torch.ones((dg_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device)
                )  # we want to fool, so pretend it's all genuine
                g_error.backward(retain_graph=True)
                self.G_opt.step()  # Only optimizes G's parameters
                self.G.zero_grad()
                g_ct.flush(info={"G_loss": g_error.item()})

                # 3. Train the reconstructor
                # way_3
                out = self.R.forward(g_fake_data, embedding_layer(X_data.trg.to(device)),
                                     None, X_data.trg_mask)
                r_loss = self.r_loss_compute(out, X_data.trg_y, X_data.ntokens)
                # way_2
                # r_reco_data = prob_backward(self.R, self.embed, g_fake_data, None, max_len=self.utils.max_len, raw=True)
                # x_orgi_data = X_data.src[:, 1:]
                # r_loss = SimpleLossCompute(None, criterion, self.R_opt)(r_reco_data, x_orgi_data, X_data.ntokens)
                # way_1
                # viewed_num = r_reco_data.shape[0] * r_reco_data.shape[1]
                # r_error = r_scale * self.cosloss(r_reco_data.float().view(-1, self.embed.weight.shape[1]),
                #                                  x_orgi_data.float().view(-1, self.embed.weight.shape[1]),
                #                                  torch.ones(viewed_num, dtype=torch.float32).to(device))
                self.G_opt.step()
                self.G_opt.optimizer.zero_grad()
                r_ct.flush(info={"G_loss": g_error.item(),
                                 "R_loss": r_loss / X_data.ntokens.float().to(device)})

        # Decode a fixed test batch from each domain and print original / generated / reconstructed sentences.
        with torch.no_grad():
            x_cont, x_ys = backward_decode(model, self.embed, X_test_batch.src, X_test_batch.src_mask,
                                           max_len=25, start_symbol=2)
            x = utils.idx2sent(x_ys)
            y_cont, y_ys = backward_decode(model, self.embed, Y_test_batch.src, Y_test_batch.src_mask,
                                           max_len=25, start_symbol=2)
            y = utils.idx2sent(y_ys)
            r_x = utils.idx2sent(backward_decode(self.R, self.embed, x_cont, None,
                                                 max_len=self.utils.max_len, raw=True, return_term=1))
            r_y = utils.idx2sent(backward_decode(self.R, self.embed, y_cont, None,
                                                 max_len=self.utils.max_len, raw=True, return_term=1))
            for i, j, l in zip(X_test_batch.src, x, r_x):
                print("===")
                k = utils.idx2sent([i])[0]
                print("ORG:", " ".join(k[:k.index('<eos>') + 1]))
                print("--")
                print("GEN:", " ".join(j[:j.index('<eos>') + 1] if '<eos>' in j else j))
                print("--")
                print("REC:", " ".join(l[:l.index('<eos>') + 1] if '<eos>' in l else l))
                print("===")
            print("=====")
            for i, j, l in zip(Y_test_batch.src, y, r_y):
                print("===")
                k = utils.idx2sent([i])[0]
                print("ORG:", " ".join(k[:k.index('<eos>') + 1]))
                print("--")
                print("GEN:", " ".join(j[:j.index('<eos>') + 1] if '<eos>' in j else j))
                print("--")
                print("REC:", " ".join(l[:l.index('<eos>') + 1] if '<eos>' in l else l))
                print("===")
def train_model(self, num_epochs=1, step_to_save_model=1000, check_point=False):
    """Train the tagger, logging loss/accuracy curves and periodically saving checkpoints."""
    self.to(self.device)
    self.train()
    max_acc = 0.0
    for epoch in range(num_epochs):
        ct = Clock(self.utils.train_step_num, title="Epoch(%d/%d)" % (epoch + 1, num_epochs))
        His_loss = History(title="Loss", xlabel="step", ylabel="loss", item_name=["train_loss"])
        His_acc = History(title="Acc", xlabel="step", ylabel="accuracy", item_name=["train_acc"])
        for step, (batch_x, batch_y) in enumerate(self.utils.data_generator()):
            elmo_x, x_lens = self.utils.elmo(batch_x, max_len=self.utils.max_len)
            (elmo_x, x_lens, batch_y), _ind = sort_by([elmo_x, x_lens, batch_y], piv=1)
            label = self.utils.sents2idx(batch_y)
            target = torch.from_numpy(label).to(self.device) if self.device != "cpu" \
                else torch.from_numpy(label)
            pred = self.forward(elmo_x, x_lens)[:, 1:-1]
            # Pad the predictions with zeros if the model output is shorter than the target.
            if pred.shape[1] < target.shape[1]:
                p1d = (pred.shape[0], target.shape[1] - pred.shape[1], pred.shape[2])
                to_pad = torch.zeros(*p1d).to(self.device)
                pred = torch.cat([pred, to_pad], dim=1)
            try:
                loss = self.criterion(pred.transpose(1, 2), target)
            except Exception:
                print(batch_x)
                print(batch_y)
                sys.exit()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            ans = torch.argmax(pred, dim=-1).cpu().numpy()
            # Score only the non-padding positions (label id 0 is padding).
            idx = (target > 0).nonzero().cpu().numpy()
            acc = accuracy_score(label[idx[:, 0], idx[:, 1]], ans[idx[:, 0], idx[:, 1]])
            info_dict = {"loss": loss, "ppl": math.exp(loss), "acc": acc}
            ct.flush(info=info_dict)
            His_loss.append_history(0, (step, loss))
            His_acc.append_history(0, (step, acc))
            target.cpu()
            if (step + 1) % step_to_save_model == 0:
                torch.save(self.state_dict(), os.path.join(self.save_path, 'model.ckpt'))
                His_loss.plot(os.path.join(self.save_path, "loss_plot"))
                His_acc.plot(os.path.join(self.save_path, "acc_plot"))
        # acc, f1 = self.test_corpus(return_value=True)
        if check_point:
            if acc > max_acc:
                path = os.path.join(self.save_path, 'model.ckpt')
                print("Checkpoint: acc grew from %4f to %4f, saving model to %s"
                      % (max_acc, acc, path))
                torch.save(self.state_dict(), path)
                max_acc = acc
            else:
                print("Checkpoint: acc did not grow from %4f to %4f, model not saved."
                      % (max_acc, acc))
        else:
            path = os.path.join(self.save_path, 'model.ckpt')
            torch.save(self.state_dict(), path)
    His_loss.plot(os.path.join(
        self.save_path,
        "Loss_" + self.utils.zhuyin_data_path.split('/')[-1] + "_%d" % num_epochs))
    His_acc.plot(os.path.join(
        self.save_path,
        "Acc_" + self.utils.zhuyin_data_path.split('/')[-1] + "_%d" % num_epochs))
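# NOTE: `History` is the project's training-curve logger, defined elsewhere. A
# minimal sketch assuming the interface used above -- History(title, xlabel,
# ylabel, item_name), append_history(series_index, (x, y)), plot(path=None):
import matplotlib
matplotlib.use("Agg")                 # assume a headless training machine
import matplotlib.pyplot as plt

class HistorySketch:
    def __init__(self, title="", xlabel="", ylabel="", item_name=None):
        self.title, self.xlabel, self.ylabel = title, xlabel, ylabel
        self.item_name = item_name or []
        self.series = [[] for _ in self.item_name]    # one (x, y) list per named curve

    def append_history(self, idx, point):
        self.series[idx].append(point)

    def plot(self, path=None):
        plt.figure()
        for name, pts in zip(self.item_name, self.series):
            if pts:
                xs, ys = zip(*pts)
                plt.plot(xs, ys, label=name)
        plt.title(self.title)
        plt.xlabel(self.xlabel)
        plt.ylabel(self.ylabel)
        plt.legend()
        plt.savefig((path or self.title) + ".png")
        plt.close()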
def train_model(self, train_file, save_model_name, num_epochs=10, prefix=['a', 'b']):
    self.to(self.device)
    train_loader, test_loader = self.get_dataloader(train_file, shuffle=True, prefix=prefix)
    # Train the model
    total_step = len(train_loader)
    His = History(title="TrainingCurve", xlabel="step", ylabel="f1-score",
                  item_name=["train_Nf", "train_Na", "test_Nf", "test_Na"])
    step_idx = 0
    for epoch in range(num_epochs):
        self.train()
        ct = Clock(len(train_loader), title="Epoch %d/%d" % (epoch + 1, num_epochs))
        ac_loss = 0
        num = 0
        f1ab = 0
        f1cd = 0
        for i, li in enumerate(train_loader):
            li = self.expand_data(li)
            [arcs, seqs, seq_len, label], ind = sort_by(li, piv=2)
            [arcs, seqs, seq_len, label] = [arcs.to(self.device), seqs.to(self.device),
                                            seq_len.to(self.device), label.to(self.device)]
            # this_bs = arcs.shape[0]
            # Forward pass
            root, graph = self.forward(seqs, seq_len)
            ans = self.arcs_expand(arcs, graph.shape[-1], label)
            root_ans = (ans.sum(-1) > 0).float()
            root = root.squeeze()
            # ans = torch.zeros_like(graph)
            # for i in range(len(arcs)):
            #     ans[i, arcs[i][0], arcs[i][1]] = int(label[i] == 1)
            # loss_1 = self.criterion(root[(label==1).nonzero().squeeze(1)],
            #                         arcs[:, 0][(label==1).nonzero().squeeze(1)].unsqueeze(1))
            if not ((root_ans >= 0).all() & (root_ans <= 1).all()).item():
                print(root_ans)
            if not ((root >= 0).all() & (root <= 1).all()).item():
                print(root)
            loss_1 = self.bce(root, root_ans)
            loss_2 = 0.3 * self.multi_target_loss(graph, ans)
            # return loss_1, loss_2
            # print(lossnum)
            loss = loss_1 + loss_2
            ac_loss += loss.item()
            num += 1
            # if dev:
            #     return root, graph, [arcs, seqs, seq_len, label], ans
            # total, nf_corr, na_corr = self.acc(arcs, root, graph, label)
            info_dict = self.multi_acc(arcs, root, graph, label)
            f1ab += f1_score(info_dict['a'], info_dict['b'])
            f1cd += f1_score(info_dict['c'], info_dict['d'])
            # info_dict = {'loss': ac_loss/num, 'accNf': train_nf_corr/train_total, 'accNa': train_na_corr/train_total}
            ct.flush(info=info_dict)
            step_idx += 1
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        His.append_history(0, (step_idx, f1ab / num))
        His.append_history(1, (step_idx, f1cd / num))
        with torch.no_grad():
            self.eval()
            # nf_correct = 0
            # na_correct = 0
            # test_total = 0
            for i, li in enumerate(test_loader):
                li = self.expand_data(li)
                [arcs, seqs, seq_len, label], ind = sort_by(li, piv=2)
                [arcs, seqs, seq_len, label] = [arcs.to(self.device), seqs.to(self.device),
                                                seq_len.to(self.device), label.to(self.device)]
                root, graph = self.forward(seqs, seq_len)
                root = root.squeeze()
                info_dict = self.multi_acc(arcs, root, graph, label)
                ct.flush(info=info_dict)
                # t, f, a = self.test(arcs, seqs, seq_len, label)
                # test_total += t
                # nf_correct += f
                # na_correct += a
            # info_dict = {'val_accNf': nf_correct/test_total, 'val_accNa': na_correct/test_total}
            # ct.flush(info={'loss': ac_loss/num, 'val_accNf': nf_correct/test_total, 'val_accNa': na_correct/test_total})
            His.append_history(2, (step_idx, f1_score(info_dict['a'], info_dict['b'])))
            His.append_history(3, (step_idx, f1_score(info_dict['c'], info_dict['d'])))
        # Save the model checkpoint
        torch.save(self.state_dict(), save_model_name)
    His.plot()
p = Parser(batch_size=args.batch_size)
thres = [args.threshold, args.threshold_2] if args.threshold_2 is not None else args.threshold
if args.mode == 'print' or args.out_file is None:
    out_file = sys.stdout
else:
    out_file = open(args.out_file, 'w')
if args.mode in ['print', 'write']:
    prefix = args.prefix.split("_")
    assert len(prefix) == 2
    p.load_model(args.model_name)
    sents = []
    poss = []
    f = open(args.in_file)
    total_num = int(os.popen("wc -l %s" % args.in_file).read().split(' ')[0])
    ct = Clock(total_num, title="===> Parsing File %s with %d lines" % (args.in_file, total_num))
    for i in f:
        if args.timing == "True" and args.mode != 'print':
            ct.flush()
        sent_list = i.strip().split(' ')
        if len(sent_list) < 2:
            print("TOO SHORT ==>", ' '.join(sent_list), file=out_file)
            continue
        elif len(sent_list) > args.max_len:
            print("TOO LONG ==>", ' '.join(sent_list), file=out_file)
            continue
        sents.append(sent_list)
        if len(sents) >= p.batch_size:
            p.evaluate(sents, prefix=prefix, threshold=thres, file=out_file,
                       print_out=(out_file is not None), show_score=(args.show_score == "True"))
            sents = []
    if len(sents) > 0:
        # Flush the final partial batch the same way as a full batch.
        p.evaluate(sents, prefix=prefix, threshold=thres, file=out_file,
                   print_out=(out_file is not None), show_score=(args.show_score == "True"))
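# NOTE: the `args` namespace used above is built by an argparse setup that is not
# shown in this excerpt. A sketch of the flags it would need -- only the attribute
# names are taken from the code above; types and defaults are guesses:
import argparse

def build_arg_parser_sketch():
    ap = argparse.ArgumentParser(description="Parse a tokenized text file with a trained Parser.")
    ap.add_argument("--mode", choices=["print", "write"], default="print")
    ap.add_argument("--in_file", required=True, help="one space-tokenized sentence per line")
    ap.add_argument("--out_file", default=None)
    ap.add_argument("--model_name", default="model.ckpt")
    ap.add_argument("--prefix", default="a_b", help="two labels joined by '_'")
    ap.add_argument("--batch_size", type=int, default=32)
    ap.add_argument("--max_len", type=int, default=50)
    ap.add_argument("--threshold", type=float, default=0.5)
    ap.add_argument("--threshold_2", type=float, default=None)
    ap.add_argument("--timing", default="False", help="string flag, compared to the literal 'True' above")
    ap.add_argument("--show_score", default="False", help="string flag, compared to the literal 'True' above")
    return ap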
def train_model(self, num_epochs=100, d_steps=10, g_steps=10):
    self.D.to(self.D.device)
    self.G.to(self.G.device)
    self.encoder.to(self.encoder.device)
    real_datagen = self.utils.data_generator("train")
    test_datagen = self.utils.data_generator("test")
    for epoch in range(num_epochs):
        d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)" % (epoch, num_epochs))
        for d_step, real_data in zip(range(d_steps), real_datagen):
            # 1. Train D on real+fake
            self.D.zero_grad()

            # 1A: Train D on real
            d_org_data, d_data_seqlen = self.utils.raw2elmo(real_data)
            d_mask_data, d_data_seqlen, d_mask_label = \
                self.utils.elmo2mask(d_org_data, d_data_seqlen, mask_rate=epoch / num_epochs)
            d_real_pred = self.D(d_org_data, d_data_seqlen)
            d_real_error = self.criterion(
                d_real_pred.transpose(1, 2),
                torch.ones(d_mask_label.shape, dtype=torch.int64).to(self.D.device))  # ones = true
            d_real_error.backward()  # compute/store gradients, but don't change params
            self.D.optimizer.step()

            # 1B: Train D on fake
            d_gen_input = self.encoder(d_org_data, d_data_seqlen)
            d_fake_data = self.G(d_mask_data, d_data_seqlen,
                                 hidden=d_gen_input).detach()  # detach to avoid training G on these labels
            d_fake_pred = self.D(d_fake_data, d_data_seqlen, numpy=False)
            d_fake_error = self.criterion(
                d_fake_pred.transpose(1, 2),
                torch.from_numpy(d_mask_label).to(self.D.device))  # zeros = fake
            d_fake_error.backward()
            self.D.optimizer.step()  # Only optimizes D's parameters; changes based on stored gradients from backward()
            d_ct.flush(info={"D_loss": d_fake_error.item()})

        g_ct = Clock(g_steps, title="Train Generator(%d/%d)" % (epoch, num_epochs))
        for g_step, real_data in zip(range(g_steps), real_datagen):
            # 2. Train G on D's response (but DO NOT train D on these labels)
            self.G.zero_grad()
            g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data)
            g_mask_data, g_data_seqlen, g_mask_label = \
                self.utils.elmo2mask(g_org_data, g_data_seqlen, mask_rate=epoch / num_epochs)
            gen_input = self.encoder(g_org_data, g_data_seqlen)
            g_fake_data = self.G(g_mask_data, g_data_seqlen, hidden=gen_input)
            dg_fake_pred = self.D(g_fake_data, g_data_seqlen, numpy=False)
            g_error = self.criterion(
                dg_fake_pred.transpose(1, 2),
                torch.ones(g_mask_label.shape, dtype=torch.int64).to(self.D.device)
            )  # we want to fool, so pretend it's all genuine
            g_error.backward()
            self.G.optimizer.step()  # Only optimizes G's parameters
            self.encoder.optimizer.step()
            g_ct.flush(info={"G_loss": g_error.item()})

        with torch.no_grad():
            for _, real_data in zip(range(2), test_datagen):
                g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data)
                [g_org_data, g_data_seqlen], _ind = sort_by([g_org_data, g_data_seqlen], piv=1)
                g_mask_data, g_data_seqlen, g_mask_label = \
                    self.utils.elmo2mask(g_org_data, g_data_seqlen, mask_rate=epoch / num_epochs)
                gen_input = self.encoder(g_org_data, g_data_seqlen, sort=False)
                g_fake_data = self.G(g_mask_data, g_data_seqlen, hidden=gen_input, sort=False)
                gen_sents = self.invelmo.test(g_fake_data.cpu().numpy(), g_data_seqlen)
                for i, j in zip(real_data, gen_sents):
                    print("=" * 50)
                    print(' '.join(i))
                    print("---")
                    print(' '.join(j))
                    print("=" * 50)
    self.save_model()