def test_corpus(self, return_value=False):
     with torch.no_grad():
         self.eval()
         info = {"acc": 0}
         total = 0
         ct = Clock(self.utils.test_step_num)
         for step, (batch_x,
                    batch_y) in enumerate(self.utils.test_data_generator()):
             elmo_x, x_lens = self.utils.elmo(batch_x)
             (elmo_x, x_lens,
              batch_y), _ind = sort_by([elmo_x, x_lens, batch_y], piv=1)
             label = self.utils.sents2idx(batch_y, len_fixed=False)
             target = torch.from_numpy(label).to(
                 self.device) if self.device != "cpu" else torch.from_numpy(
                     label)
             pred = self.forward(elmo_x, x_lens)[:, 1:-1]
             ans = torch.argmax(pred, dim=-1).cpu().numpy()
             idx = (target > 0).nonzero()
             acc = accuracy_score(label[idx[:, 0], idx[:, 1]],
                                  ans[idx[:, 0], idx[:, 1]])
             info["acc"] += acc
             total += 1
             target.cpu()
             ct.flush()
         self.train()
         print("test_acc:", info["acc"] / total)
         if return_value:
             return info["acc"] / total
         else:
             return
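# Hedged sketch (not from the original class): test_corpus above scores only
# non-padding positions, i.e. it gathers the (row, col) indices where target > 0
# and feeds just those token labels/predictions to sklearn's accuracy_score.
# Minimal standalone illustration of that masking, with made-up toy data:
import numpy as np
import torch
from sklearn.metrics import accuracy_score

label = np.array([[3, 5, 0], [2, 0, 0]])                # 0 assumed to be the padding index
ans = np.array([[3, 4, 9], [2, 7, 1]])                  # argmax predictions; padded slots are noise
idx = (torch.from_numpy(label) > 0).nonzero().numpy()   # (row, col) pairs of real tokens
acc = accuracy_score(label[idx[:, 0], idx[:, 1]], ans[idx[:, 0], idx[:, 1]])
print(acc)  # 2 of 3 real tokens match -> 0.666...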
    def pretrain_disc(self, num_epochs=100):
        # self.D.to(self.D.device)
        # self.G.to(self.G.device)
        # self.R.to(self.R.device)
        X_datagen = self.utils.data_generator("X")
        Y_datagen = self.utils.data_generator("Y")
        for epoch in range(num_epochs):
            d_steps = self.utils.train_step_num
            d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)"%(epoch, num_epochs))
            for step, X_data, Y_data in zip(range(d_steps), data_gen(X_datagen, self.utils.sents2idx), data_gen(Y_datagen, self.utils.sents2idx)):
                # 1. Train D on real+fake
                # if epoch == 0:
                #     break
                self.D.zero_grad()
        
                #  1A: Train D on real
                d_real_pred = self.D(self.embed(Y_data.src.to(device)))
                d_real_error = self.criterion(d_real_pred, torch.ones((d_real_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # ones = true

                #  1B: Train D on fake
                d_fake_pred = self.D(self.embed(X_data.src.to(device)))
                d_fake_error = self.criterion(d_fake_pred, torch.zeros((d_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # zeros = fake
                (d_fake_error + d_real_error).backward()
                self.D_opt.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()
                d_ct.flush(info={"D_loss": d_fake_error.item()})
        torch.save(self.D.state_dict(), "model_disc_pretrain.ckpt")
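# Hedged sketch (assumption about self.criterion, which this snippet does not define):
# pretrain_disc treats discriminator training as 2-class classification, labelling
# real (Y) batches 1 and generated/fake (X) batches 0. With a cross-entropy
# criterion the label construction mirrors the lines above:
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
d_real_pred = torch.randn(8, 2, requires_grad=True)   # stand-in discriminator logits, 8 real samples
d_fake_pred = torch.randn(8, 2, requires_grad=True)   # stand-in logits for 8 fake samples
real_labels = torch.ones(d_real_pred.shape[0], dtype=torch.int64)    # ones = true
fake_labels = torch.zeros(d_fake_pred.shape[0], dtype=torch.int64)   # zeros = fake
(criterion(d_fake_pred, fake_labels) + criterion(d_real_pred, real_labels)).backward()
# ...followed by D_opt.step(), exactly as in the loop above.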
 def run_DA(self, big=0, lim=50):
     li = load_cna()[big:lim]
     ct = Clock(len(li), 5, 'DA')
     log = []
     for i in li:
         temp_dict1 = {'sentence': ''.join(i)}
         self.sg = 0
         cm = self.sent_sim(i)
         temp_dict1['graph'] = cm.tolist()
         temp_dict1['sg'] = self.sg
         tag = self.tag_mat(cm)
         log.append(temp_dict1)
         self.texfile.write('\\textbf{CBOW} \\\\ \n')
         self.write_tex(i, tag, cm)
         # plot_confusion_matrix(cm, i, tag, sv=True, sg=self.sg)
         temp_dict2 = {'sentence': ''.join(i)}
         self.sg = 1
         cm = self.sent_sim(i)
         temp_dict2['graph'] = cm.tolist()
         temp_dict2['sg'] = self.sg
         tag = self.tag_mat(cm)
         log.append(temp_dict2)
         self.texfile.write('\\textbf{Skip-gram} \\\\ \n')
         self.write_tex(i, tag, cm)
         # plot_confusion_matrix(cm, i, tag, sv=True, sg=self.sg)
         ct.flush()
     with open('log.json', 'w') as fp:
         json.dump({'data': log}, fp)
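# Hedged sketch: run_DA above dumps log.json as {"data": [...]}, where each entry keeps
# the joined sentence, the similarity matrix as nested lists ('graph'), and which
# word2vec flavour produced it ('sg': 0 = CBOW, 1 = Skip-gram). Reading it back:
import json

with open('log.json') as fp:
    entries = json.load(fp)['data']
for e in entries[:2]:
    print(e['sg'], e['sentence'], len(e['graph']))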
    def pretrain(self, num_epochs=1):
        self.G.to(self.G.device)
        self.encoder.to(self.encoder.device)
        real_datagen = self.utils.data_generator("train")
        test_datagen = self.utils.data_generator("test")
        for epoch in range(num_epochs):
            ct = Clock(self.utils.train_step_num,
                       title="Pretrain(%d/%d)" % (epoch, num_epochs))
            for real_data in real_datagen:
                # 2. Train G on D's response (but DO NOT train D on these labels)
                self.G.zero_grad()

                g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data)

                gen_input = self.encoder(g_org_data, g_data_seqlen)
                g_fake_data = self.G(g_org_data,
                                     g_data_seqlen,
                                     hidden=gen_input)
                loss = self.mse(g_fake_data,
                                torch.from_numpy(g_org_data).to(self.G.device))

                loss.backward()
                self.G.optimizer.step()  # Only optimizes G's parameters
                self.encoder.optimizer.step()
                ct.flush(info={"G_loss": loss.item()})

            with torch.no_grad():
                for _, real_data in zip(range(2), test_datagen):
                    g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data)
                    [g_org_data, g_data_seqlen
                     ], _ind = sort_by([g_org_data, g_data_seqlen], piv=1)
                    g_mask_data, g_data_seqlen, g_mask_label = \
                        self.utils.elmo2mask(g_org_data, g_data_seqlen, mask_rate=epoch/num_epochs)
                    gen_input = self.encoder(g_org_data,
                                             g_data_seqlen,
                                             sort=False)
                    g_fake_data = self.G(g_mask_data,
                                         g_data_seqlen,
                                         hidden=gen_input,
                                         sort=False)

                    gen_sents = self.invelmo.test(g_fake_data.cpu().numpy(),
                                                  g_data_seqlen)
                    for i, j in zip(real_data, gen_sents):
                        print("=" * 50)
                        print(' '.join(i))
                        print("---")
                        print(' '.join(j))
                        print("=" * 50)
            torch.save(self.G.state_dict(), "pretrain_model.ckpt")
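# Hedged sketch (toy shapes, not the real config): pretrain above fits G as an
# autoencoder over ELMo vectors, using a plain MSE between the generated sequence
# and the original embeddings (g_org_data is a numpy array, hence from_numpy):
import torch
import torch.nn as nn

mse = nn.MSELoss()
g_org_data = torch.randn(4, 10, 16).numpy()               # stand-in ELMo batch (batch, time, dim)
g_fake_data = torch.randn(4, 10, 16, requires_grad=True)  # stand-in generator output
loss = mse(g_fake_data, torch.from_numpy(g_org_data))
loss.backward()  # would be followed by G.optimizer.step() and encoder.optimizer.step() as above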
def pretrain_run_epoch(data_iter, model, loss_compute, train_step_num):
    "Standard Training and Logging Function"
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    ct = Clock(train_step_num)
    for i, batch in enumerate(data_iter):
        out = model.forward(batch.src, batch.trg, 
                            batch.src_mask, batch.trg_mask)
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        ct.flush(info={"loss":loss / batch.ntokens.float().to(device)})
    return total_loss / total_tokens.float().to(device)
def pretrain_without_tf(data_iter, model, loss_compute, train_step_num, semi=False):
    "Real Training without Teacher Forcing"
    total_tokens = 0
    total_loss = 0
    tokens = 0
    ct = Clock(train_step_num)
    for i, batch in enumerate(data_iter):
        if not semi:
            out = model.self_forward(batch.src, 
                                batch.src_mask, max_len=15 + 2)
        else:
            out = model.semi_forward(batch.src, 
                                batch.src_mask, max_len=15 + 2)
        loss = loss_compute(out, model.tgt_embed[0](batch.trg), batch.ntokens)
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        ct.flush(info={"loss":loss / batch.ntokens.float().to(device)})
    return total_loss / total_tokens.float().to(device)
def pretrain_run_epoch(data_iter, model, loss_compute, train_step_num, tf_rate=0.7):
    "Standard Training and Logging Function"
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    ct = Clock(train_step_num)
    for i, batch in enumerate(data_iter):
        if random.random() <= tf_rate:
            out = model.forward(batch.src, batch.trg, 
                                batch.src_mask, batch.trg_mask)
        else:
            out = model.semi_forward(batch.src,
                                batch.src_mask, max_len=15+2)
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        if i % 50 == 1:
            elapsed = time.time() - start
            ct.flush(info={"loss":loss / batch.ntokens.float().to(device), "tok/sec":tokens.float().to(device) / elapsed})
            # print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
            #         (i, loss / batch.ntokens.float().to(device), tokens.float().to(device) / elapsed))
            start = time.time()
            tokens = 0
        else:
            ct.flush(info={"loss":loss / batch.ntokens.float().to(device)})
    return total_loss / total_tokens.float().to(device)
def pretrain_run_epoch(data_iter, model, loss_compute, train_step_num, embedding_layer):
    "Standard Training and Logging Function"
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    ct = Clock(train_step_num)
    embedding_layer.to(device)
    model.to(device)
    for i, batch in enumerate(data_iter):
        batch.to(device)
        out = model.forward(embedding_layer(batch.src.to(device)), embedding_layer(batch.trg.to(device)), 
                batch.src_mask, batch.trg_mask)
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        batch.to("cpu")
        if i % 50 == 1:
            elapsed = time.time() - start
            ct.flush(info={"loss":loss / batch.ntokens.float().to(device), "tok/sec":tokens.float().to(device) / elapsed})
            # print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
            #         (i, loss / batch.ntokens.float().to(device), tokens.float().to(device) / elapsed))
            start = time.time()
            tokens = 0
        else:
            ct.flush(info={"loss":loss / batch.ntokens.float().to(device)})
    return total_loss / total_tokens.float().to(device)
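# Hedged sketch (assumption about loss_compute): the epoch-loop functions above
# appear to accumulate a loss that is summed over target tokens and divide by the
# total token count at the end, i.e. they report a per-token average; exponentiating
# such an average gives a perplexity-style number (compare the "ppl" field further down).
import math
import torch

total_loss = torch.tensor(230.2)    # sum of token losses over the epoch
total_tokens = torch.tensor(100)    # number of target tokens seen
per_token_loss = total_loss / total_tokens.float()
print(per_token_loss.item(), math.exp(per_token_loss.item()))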
    def train_model(self, num_epochs=100, d_steps=20, g_steps=80, g_scale=1.0, r_scale=1.0):
        # self.D.to(self.D.device)
        # self.G.to(self.G.device)
        # self.R.to(self.R.device)
        for i, batch in enumerate(data_gen(self.utils.data_generator("X"), self.utils.sents2idx)):
            X_test_batch = batch
            break

        for i, batch in enumerate(data_gen(self.utils.data_generator("Y"), self.utils.sents2idx)):
            Y_test_batch = batch
            break
        X_datagen = self.utils.data_generator("X")
        Y_datagen = self.utils.data_generator("Y")
        for epoch in range(num_epochs):
            d_ct = Clock(d_steps, title="Train Discriminator(%d/%d)" % (epoch, num_epochs))
            if epoch>0:
                for i, X_data, Y_data in zip(range(d_steps), data_gen(X_datagen, self.utils.sents2idx), data_gen(Y_datagen, self.utils.sents2idx)):
                    # 1. Train D on real+fake
                    # if epoch == 0:
                    #     break
                    self.D.zero_grad()
            
                    #  1A: Train D on real
                    d_real_pred = self.D(self.embed(Y_data.src.to(device)))
                    d_real_error = self.criterion(d_real_pred, torch.ones((d_real_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # ones = true

                    #  1B: Train D on fake
                    self.G.to(device)
                    d_fake_data = backward_decode(self.G, self.embed, X_data.src, X_data.src_mask, max_len=self.utils.max_len, return_term=0).detach()  # detach to avoid training G on these labels
                    d_fake_pred = self.D(d_fake_data)
                    d_fake_error = self.criterion(d_fake_pred, torch.zeros((d_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # zeros = fake
                    (d_fake_error + d_real_error).backward()
                    self.D_opt.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()
                    d_ct.flush(info={"D_loss":d_fake_error.item()})

            g_ct = Clock(g_steps, title="Train Generator(%d/%d)"%(epoch, num_epochs))
            r_ct = Clock(g_steps, title="Train Reconstructor(%d/%d)" % (epoch, num_epochs))
            if epoch>0:
                for i, X_data in zip(range(g_steps), data_gen(X_datagen, self.utils.sents2idx)):
                    # 2. Train G on D's response (but DO NOT train D on these labels)
                    self.G.zero_grad()
                    g_fake_data = backward_decode(self.G, self.embed, X_data.src, X_data.src_mask, max_len=self.utils.max_len, return_term=0)
                    dg_fake_pred = self.D(g_fake_data)
                    g_error = self.criterion(dg_fake_pred, torch.ones((dg_fake_pred.shape[0],), dtype=torch.int64).to(self.D.device))  # we want to fool, so pretend it's all genuine
            
                    g_error.backward(retain_graph=True)
                    self.G_opt.step()  # Only optimizes G's parameters
                    self.G.zero_grad()
                    g_ct.flush(info={"G_loss": g_error.item()})

                    # 3. reconstructor
                    # way_3
                    out = self.R.forward(g_fake_data, self.embed(X_data.trg.to(device)),
                                         None, X_data.trg_mask)
                    r_loss = self.r_loss_compute(out, X_data.trg_y, X_data.ntokens)
                    # way_2
                    # r_reco_data = prob_backward(self.R, self.embed, g_fake_data, None, max_len=self.utils.max_len, raw=True)
                    # x_orgi_data = X_data.src[:, 1:]
                    # r_loss = SimpleLossCompute(None, criterion, self.R_opt)(r_reco_data, x_orgi_data, X_data.ntokens)
                    # way_1
                    # viewed_num = r_reco_data.shape[0]*r_reco_data.shape[1]
                    # r_error = r_scale*self.cosloss(r_reco_data.float().view(-1, self.embed.weight.shape[1]), x_orgi_data.float().view(-1, self.embed.weight.shape[1]), torch.ones(viewed_num, dtype=torch.float32).to(device))
                    self.G_opt.step()
                    self.G_opt.optimizer.zero_grad()
                    r_ct.flush(info={"G_loss": g_error.item(),
                    "R_loss": r_loss / X_data.ntokens.float().to(device)})
                
            with torch.no_grad():
                x_cont, x_ys = backward_decode(self.G, self.embed, X_test_batch.src, X_test_batch.src_mask, max_len=25, start_symbol=2)
                x = self.utils.idx2sent(x_ys)
                y_cont, y_ys = backward_decode(self.G, self.embed, Y_test_batch.src, Y_test_batch.src_mask, max_len=25, start_symbol=2)
                y = self.utils.idx2sent(y_ys)
                r_x = self.utils.idx2sent(backward_decode(self.R, self.embed, x_cont, None, max_len=self.utils.max_len, raw=True, return_term=1))
                r_y = self.utils.idx2sent(backward_decode(self.R, self.embed, y_cont, None, max_len=self.utils.max_len, raw=True, return_term=1))

                for i,j,l in zip(X_test_batch.src, x, r_x):
                    print("===")
                    k = self.utils.idx2sent([i])[0]
                    print("ORG:", " ".join(k[:k.index('<eos>')+1]))
                    print("--")
                    print("GEN:", " ".join(j[:j.index('<eos>')+1] if '<eos>' in j else j))
                    print("--")
                    print("REC:", " ".join(l[:l.index('<eos>')+1] if '<eos>' in l else l))
                    print("===")
                print("=====")
                for i, j, l in zip(Y_test_batch.src, y, r_y):
                    print("===")
                    k = self.utils.idx2sent([i])[0]
                    print("ORG:", " ".join(k[:k.index('<eos>')+1]))
                    print("--")
                    print("GEN:", " ".join(j[:j.index('<eos>')+1] if '<eos>' in j else j))
                    print("--")
                    print("REC:", " ".join(l[:l.index('<eos>')+1] if '<eos>' in l else l))
                    print("===")
 def train_model(self,
                 num_epochs=1,
                 step_to_save_model=1000,
                 check_point=False):
     self.to(self.device)
     self.train()
     max_acc = 0.0
     for epoch in range(num_epochs):
         ct = Clock(self.utils.train_step_num,
                    title="Epoch(%d/%d)" % (epoch + 1, num_epochs))
         His_loss = History(title="Loss",
                            xlabel="step",
                            ylabel="loss",
                            item_name=["train_loss"])
         His_acc = History(title="Acc",
                           xlabel="step",
                           ylabel="accuracy",
                           item_name=["train_acc"])
         for step, (batch_x,
                    batch_y) in enumerate(self.utils.data_generator()):
             elmo_x, x_lens = self.utils.elmo(batch_x,
                                              max_len=self.utils.max_len)
             (elmo_x, x_lens,
              batch_y), _ind = sort_by([elmo_x, x_lens, batch_y], piv=1)
             label = self.utils.sents2idx(batch_y)
             target = torch.from_numpy(label).to(
                 self.device) if self.device != "cpu" else torch.from_numpy(
                     label)
             pred = self.forward(elmo_x, x_lens)[:, 1:-1]
             if pred.shape[1] < target.shape[1]:
                 p1d = (pred.shape[0], target.shape[1] - pred.shape[1],
                        pred.shape[2])
                 to_pad = torch.zeros(*p1d).to(self.device)
                 pred = torch.cat([pred, to_pad], dim=1)
             try:
                 loss = self.criterion(pred.transpose(1, 2), target)
             except:
                 print(batch_x)
                 print(batch_y)
                 sys.exit()
             self.optimizer.zero_grad()
             loss.backward()
             self.optimizer.step()
             ans = torch.argmax(pred, dim=-1).cpu().numpy()
             idx = (target > 0).nonzero()
             acc = accuracy_score(label[idx[:, 0], idx[:, 1]],
                                  ans[idx[:, 0], idx[:, 1]])
             info_dict = {"loss": loss, "ppl": math.exp(loss), "acc": acc}
             ct.flush(info=info_dict)
             His_loss.append_history(0, (step, loss))
             His_acc.append_history(0, (step, acc))
             target.cpu()
             if (step + 1) % step_to_save_model == 0:
                 torch.save(self.state_dict(),
                            os.path.join(self.save_path, 'model.ckpt'))
                 His_loss.plot(os.path.join(self.save_path, "loss_plot"))
                 His_acc.plot(os.path.join(self.save_path, "acc_plot"))
          # acc, f1 = self.test_corpus(return_value=True)  # note: with this commented out, `acc` below is the accuracy of the last training batch
         if check_point:
             if acc > max_acc:
                 path = os.path.join(self.save_path, 'model.ckpt')
                 print(
                     "Checkpoint: acc grow from %4f to %4f, save model to %s"
                     % (max_acc, acc, path))
                 torch.save(self.state_dict(), path)
                 max_acc = acc
             else:
                 print(
                     "Checkpoint: acc not grow from %4f to %4f, model not save."
                     % (max_acc, acc))
         else:
             path = os.path.join(self.save_path, 'model.ckpt')
             torch.save(self.state_dict(), path)
     His_loss.plot(
         os.path.join(
             self.save_path,
             "Loss_" + self.utils.zhuyin_data_path.split('/')[-1] +
             "_%d" % num_epochs))
     His_acc.plot(
         os.path.join(
             self.save_path,
             "Acc_" + self.utils.zhuyin_data_path.split('/')[-1] +
             "_%d" % num_epochs))
    def train_model(self, train_file, save_model_name, num_epochs=10, prefix=['a', 'b']):
        self.to(self.device)
        train_loader, test_loader = self.get_dataloader(train_file, shuffle=True, prefix=prefix)
        # Train the model
        total_step = len(train_loader)
        His = History(title="TrainingCurve", xlabel="step", ylabel="f1-score", item_name=["train_Nf", "train_Na", "test_Nf", "test_Na"])
        step_idx = 0
        for epoch in range(num_epochs):
            self.train()
            ct = Clock(len(train_loader),title="Epoch %d/%d"%(epoch+1, num_epochs))
            ac_loss = 0
            num = 0
            f1ab = 0
            f1cd = 0

            for i, li in enumerate(train_loader):
                li = self.expand_data(li)
                [arcs, seqs, seq_len, label], ind = sort_by(li, piv=2)
                [arcs, seqs, seq_len, label] = [arcs.to(self.device), seqs.to(self.device), seq_len.to(self.device), label.to(self.device)]
                # this_bs = arcs.shape[0]
                # Forward pass
                root, graph = self.forward(seqs, seq_len)
                ans = self.arcs_expand(arcs, graph.shape[-1], label)
                root_ans = (ans.sum(-1) > 0).float()
                root = root.squeeze()
                # ans = torch.zeros_like(graph)
                # for i in range(len(arcs)):
                #     ans[i,arcs[i][0],arcs[i][1]] = int(label[i]==1)
                # loss_1 = self.criterion(root[(label==1).nonzero().squeeze(1)], arcs[:, 0][(label==1).nonzero().squeeze(1)].unsqueeze(1))
                if not ((root_ans >= 0).all() & (root_ans <= 1).all()).item():
                    print(root_ans)
                if not ((root >= 0).all() & (root <= 1).all()).item():
                    print(root)
                loss_1 = self.bce(root, root_ans)
                loss_2 = 0.3 * self.multi_target_loss(graph, ans)
                # return loss_1, loss_2
                # print(lossnum)
                loss = loss_1 + loss_2
                ac_loss += loss.item()
                num += 1
                # if dev:
                #     return root, graph, [arcs, seqs, seq_len, label], ans
                # total, nf_corr, na_corr = self.acc(arcs, root, graph, label)
                info_dict = self.multi_acc(arcs, root, graph, label)
                f1ab += f1_score(info_dict['a'], info_dict['b'])
                f1cd += f1_score(info_dict['c'], info_dict['d'])
                # info_dict = {'loss':ac_loss/num, 'accNf':train_nf_corr/train_total, 'accNa':train_na_corr/train_total}
                ct.flush(info=info_dict)
                step_idx += 1
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            His.append_history(0, (step_idx, f1ab/num))
            His.append_history(1, (step_idx, f1cd/num))
                
            
            with torch.no_grad():
                self.eval()
                # nf_correct = 0
                # na_correct = 0
                # test_total = 0
                for i, li in enumerate(test_loader):
                    li = self.expand_data(li)
                    [arcs, seqs, seq_len, label], ind = sort_by(li, piv=2)
                    [arcs, seqs, seq_len, label] = [arcs.to(self.device), seqs.to(self.device), seq_len.to(self.device), label.to(self.device)]
                    root, graph = self.forward(seqs, seq_len)
                    root = root.squeeze()
                    info_dict = self.multi_acc(arcs, root, graph, label)
                    ct.flush(info=info_dict)
                    # t, f, a = self.test(arcs, seqs, seq_len, label)
                    # test_total += t
                    # nf_correct += f
                    # na_correct += a
                # info_dict = {'val_accNf':nf_correct/test_total, 'val_accNa':na_correct/test_total}
                # ct.flush(info={'loss':ac_loss/num, 'val_accNf':nf_correct/test_total, 'val_accNa':na_correct/test_total})
                
                His.append_history(2, (step_idx, f1_score(info_dict['a'], info_dict['b'])))
                His.append_history(3, (step_idx, f1_score(info_dict['c'], info_dict['d'])))

        # Save the model checkpoint
        torch.save(self.state_dict(), save_model_name)
        His.plot()
 p = Parser(batch_size=args.batch_size)
 thres = [args.threshold, args.threshold_2] if args.threshold_2 is not None else args.threshold
 if args.mode == 'print' or args.out_file is None:
     out_file = sys.stdout
 else:
     out_file = open(args.out_file, 'w')
 
 if args.mode in ['print', 'write']:
     prefix = args.prefix.split("_")
     assert len(prefix) == 2
     p.load_model(args.model_name)
     sents = []
     poss = []
     f = open(args.in_file)
     total_num = int(os.popen("wc -l %s" % args.in_file).read().split(' ')[0])
     ct = Clock(total_num, title="===> Parsing File %s with %d lines"%(args.in_file, total_num))
     for i in f:
         if args.timing == "True" and args.mode != 'print':
             ct.flush()
         sent_list = i.strip().split(' ')
         if len(sent_list) < 2:
             print("TOO SHORT ==>", ' '.join(sent_list), file=out_file)
             continue
         elif len(sent_list)>args.max_len:
             print("TOO LONG ==>", ' '.join(sent_list), file=out_file)
             continue
         sents.append(sent_list)
         if len(sents) >= p.batch_size:
             p.evaluate(sents, prefix=prefix, threshold=thres, file=out_file, print_out=(out_file is not None), show_score=(args.show_score=="True"))
             sents = []
     if len(sents) > 0:
         # flush the final partial batch left over after the loop
         p.evaluate(sents, prefix=prefix, threshold=thres, file=out_file, print_out=(out_file is not None), show_score=(args.show_score == "True"))
    def train_model(self, num_epochs=100, d_steps=10, g_steps=10):
        self.D.to(self.D.device)
        self.G.to(self.G.device)
        self.encoder.to(self.encoder.device)
        real_datagen = self.utils.data_generator("train")
        test_datagen = self.utils.data_generator("test")
        for epoch in range(num_epochs):
            d_ct = Clock(d_steps,
                         title="Train Discriminator(%d/%d)" %
                         (epoch, num_epochs))
            for d_step, real_data in zip(range(d_steps), real_datagen):
                # 1. Train D on real+fake
                self.D.zero_grad()

                #  1A: Train D on real
                d_org_data, d_data_seqlen = self.utils.raw2elmo(real_data)
                d_mask_data, d_data_seqlen, d_mask_label = \
                    self.utils.elmo2mask(d_org_data, d_data_seqlen, mask_rate=epoch/num_epochs)
                d_real_pred = self.D(d_org_data, d_data_seqlen)
                d_real_error = self.criterion(
                    d_real_pred.transpose(1, 2),
                    torch.ones(d_mask_label.shape, dtype=torch.int64).to(
                        self.D.device))  # ones = true
                d_real_error.backward()  # accumulate gradients from the real batch
                self.D.optimizer.step()

                #  1B: Train D on fake
                d_gen_input = self.encoder(d_org_data, d_data_seqlen)
                d_fake_data = self.G(
                    d_mask_data, d_data_seqlen, hidden=d_gen_input).detach(
                    )  # detach to avoid training G on these labels
                d_fake_pred = self.D(d_fake_data, d_data_seqlen, numpy=False)
                d_fake_error = self.criterion(
                    d_fake_pred.transpose(1, 2),
                    torch.from_numpy(d_mask_label).to(
                        self.D.device))  # zeros = fake
                d_fake_error.backward()
                self.D.optimizer.step()  # Only optimizes D's parameters; changes based on stored gradients from backward()
                d_ct.flush(info={"D_loss": d_fake_error.item()})

            g_ct = Clock(g_steps,
                         title="Train Generator(%d/%d)" % (epoch, num_epochs))
            for g_step, real_data in zip(range(g_steps), real_datagen):
                # 2. Train G on D's response (but DO NOT train D on these labels)
                self.G.zero_grad()

                g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data)
                g_mask_data, g_data_seqlen, g_mask_label = \
                    self.utils.elmo2mask(g_org_data, g_data_seqlen, mask_rate=epoch/num_epochs)
                gen_input = self.encoder(g_org_data, g_data_seqlen)
                g_fake_data = self.G(g_mask_data,
                                     g_data_seqlen,
                                     hidden=gen_input)
                dg_fake_pred = self.D(g_fake_data, g_data_seqlen, numpy=False)
                g_error = self.criterion(
                    dg_fake_pred.transpose(1, 2),
                    torch.ones(g_mask_label.shape,
                               dtype=torch.int64).to(self.D.device)
                )  # we want to fool, so pretend it's all genuine

                g_error.backward()
                self.G.optimizer.step()  # Only optimizes G's parameters
                self.encoder.optimizer.step()
                g_ct.flush(info={"G_loss": g_error.item()})

            with torch.no_grad():
                for _, real_data in zip(range(2), test_datagen):
                    g_org_data, g_data_seqlen = self.utils.raw2elmo(real_data)
                    [g_org_data, g_data_seqlen
                     ], _ind = sort_by([g_org_data, g_data_seqlen], piv=1)
                    g_mask_data, g_data_seqlen, g_mask_label = \
                        self.utils.elmo2mask(g_org_data, g_data_seqlen, mask_rate=epoch/num_epochs)
                    gen_input = self.encoder(g_org_data,
                                             g_data_seqlen,
                                             sort=False)
                    g_fake_data = self.G(g_mask_data,
                                         g_data_seqlen,
                                         hidden=gen_input,
                                         sort=False)

                    gen_sents = self.invelmo.test(g_fake_data.cpu().numpy(),
                                                  g_data_seqlen)
                    for i, j in zip(real_data, gen_sents):
                        print("=" * 50)
                        print(' '.join(i))
                        print("---")
                        print(' '.join(j))
                        print("=" * 50)
            self.save_model()
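# Hedged sketch: elmo2mask above is always called with mask_rate=epoch/num_epochs,
# which reads as a linear masking curriculum: almost nothing is masked in the first
# epoch and nearly everything by the last. The schedule itself is simply:
num_epochs = 100
schedule = [epoch / num_epochs for epoch in range(num_epochs)]
print(schedule[0], schedule[50], schedule[-1])   # 0.0, 0.5, 0.99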