Example 1
def train_electra(DS_model,
                  dloader,
                  lr=1e-4,
                  epoch=100,
                  log_interval=20,
                  parallel=parallel):
    # Fine-tune an ELECTRA model that computes its own loss when `labels`
    # is passed; `device`, `batch_size` and `parallel` are module globals.
    DS_model.to(device)
    DS_model.train()
    model_optimizer = optim.Adam(DS_model.parameters(), lr=lr)
    if parallel:
        DS_model = torch.nn.DataParallel(DS_model)

    # No external criterion: the model returns the loss together with the scores.
    iteration = 0
    total_loss = []
    for ep in range(epoch):

        t0 = time.time()
        epoch_loss = 0
        epoch_cases = 0
        for batch_idx, sample in enumerate(dloader):
            model_optimizer.zero_grad()
            loss = 0

            src = sample['src_token']
            trg = sample['trg']
            att_mask = sample['mask_padding']
            origin_len = sample['origin_seq_length']

            bs = len(src)

            src = src.long().to(device)
            trg = trg.long().to(device)
            att_mask = att_mask.float().to(device)
            origin_len = origin_len.to(device)

            outputs = DS_model(input_ids=src,
                               attention_mask=att_mask,
                               labels=trg)

            loss, scores = outputs[:2]

            loss.sum().backward()
            model_optimizer.step()

            with torch.no_grad():
                epoch_loss += loss.sum().item() * bs
                epoch_cases += bs

            if iteration % log_interval == 0:
                # Log samples seen, % of the epoch done, projected epoch time
                # in minutes, and the current batch loss.
                print(
                    'Ep:{} [{} ({:.0f}%)/ ep_time:{:.0f}min] L:{:.4f}'.format(
                        ep, batch_idx * batch_size,
                        100. * batch_idx / len(dloader), (time.time() - t0) *
                        len(dloader) / (60 * (batch_idx + 1)),
                        loss.sum().item()))

            if iteration % 400 == 0:
                checkpoint_file = '../checkpoint_electra'
                save_checkpoint(checkpoint_file, 'electra_task1.pth', DS_model,
                                model_optimizer, parallel)

            iteration += 1
        if ep % 1 == 0:
            checkpoint_file = '../checkpoint_electra'
            save_checkpoint(checkpoint_file, 'electra_task1.pth', DS_model,
                            model_optimizer, parallel)

            print('======= epoch:%i ========' % ep)

        print('++ Ep Time: {:.1f} Secs ++'.format(time.time() - t0))
        total_loss.append(float(epoch_loss / epoch_cases))
        pd_total_loss = pd.DataFrame(total_loss)
        pd_total_loss.to_csv('./total_loss_finetune_electra.csv', sep=',')
    print(total_loss)
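These snippets come from a larger training script; the sketch below shows a minimal version of the setup they all assume (standard imports plus the module-level globals `device`, `batch_size` and `parallel`). The project-specific `save_checkpoint` helper and `tokenize_alphabets` object are not reproduced here.

import time

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

# Assumed module-level globals used by the training functions in these examples.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 16                      # should match the DataLoader batch size
parallel = torch.cuda.device_count() > 1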
Example 2
def finetune_BERT(DS_model,
                  dloader,
                  lr=1e-4,
                  epoch=10,
                  log_interval=20,
                  parallel=parallel):
    # Fine-tune a BERT-based classifier with an explicit cross-entropy loss.
    DS_model.to(device)
    model_optimizer = optim.Adam(DS_model.parameters(), lr=lr)
    if parallel:
        DS_model = torch.nn.DataParallel(DS_model)
    DS_model.train()

    # Token-level cross-entropy; positions labelled -1 (padding) are ignored.
    criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)
    iteration = 0
    total_loss = []
    for ep in range(epoch):
        DS_model.train()

        t0 = time.time()
        epoch_loss = 0
        epoch_cases = 0
        for batch_idx, sample in enumerate(dloader):
            model_optimizer.zero_grad()
            loss = 0

            src = sample['src_token']
            trg = sample['trg']
            att_mask = sample['mask_padding']
            origin_len = sample['origin_seq_length']

            bs = len(src)

            src = src.float().to(device)
            trg = trg.float().to(device)
            att_mask = att_mask.float().to(device)
            origin_len = origin_len.to(device)

            pred_prop = DS_model(input_ids=src.long(), attention_mask=att_mask)
            loss = criterion(pred_prop, trg.view(-1).contiguous().long())

            loss.backward()
            model_optimizer.step()

            with torch.no_grad():
                epoch_loss += loss.item() * bs
                epoch_cases += bs

            if iteration % log_interval == 0:
                # Log samples seen, % of the epoch done, projected epoch time
                # in minutes, and the current batch loss.
                print(
                    'Ep:{} [{} ({:.0f}%)/ ep_time:{:.0f}min] L:{:.4f}'.format(
                        ep, batch_idx * batch_size,
                        100. * batch_idx / len(dloader), (time.time() - t0) *
                        len(dloader) / (60 * (batch_idx + 1)), loss.item()))

            if iteration % 400 == 0:
                save_checkpoint(checkpoint_DIR, checkpoint_path, DS_model,
                                model_optimizer, parallel)

            iteration += 1
        if ep % 1 == 0:
            save_checkpoint(checkpoint_DIR, checkpoint_path, DS_model,
                            model_optimizer, parallel)
            #            test_alphaBert(DS_model,D2S_valloader,
            #                           is_clean_up=True, ep=ep,train=True)

            print('======= epoch:%i ========' % ep)

        print('++ Ep Time: {:.1f} Secs ++'.format(time.time() - t0))
        total_loss.append(float(epoch_loss / epoch_cases))
        pd_total_loss = pd.DataFrame(total_loss)
        pd_total_loss.to_csv('./iou_pic/bert/total_loss_finetune.csv', sep=',')
    print(total_loss)


#finetune_BERT(bert_model,bert_dataloader,lr=1e-5,epoch=200,log_interval=3,parallel=parallel)

# test_BERT(bert_model,bert_dataset_valloader, is_clean_up=True, ep='f1_bert_val',mean_max='max',rouge=True)
# test_BERT(bert_model,bert_dataset_testloader,threshold=0.45, is_clean_up=True, ep='f1_bert_test',mean_max='max',rouge=True)
# test_BERT(bert_model,D2S_cyy_testloader,threshold=0.45, is_clean_up=True, ep='f1_bert_cyy',mean_max='max',rouge=True)
# test_BERT(bert_model,D2S_lin_testloader,threshold=0.45, is_clean_up=True, ep='f1_bert_cyy',mean_max='max',rouge=True)
# test_BERT(bert_model,D2S_all_testloader,threshold=0.45, is_clean_up=True, ep='f1_bert_cyy',mean_max='max',rouge=True)
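Every example calls a project-specific `save_checkpoint` helper that is not shown here. A minimal sketch of the five-argument form used in Examples 1, 2, 5, 6 and 7 could look like the following; the file layout and dictionary keys are assumptions, not the original implementation.

import os
import torch

def save_checkpoint(checkpoint_dir, filename, model, optimizer, parallel=False):
    # Hypothetical helper: unwrap nn.DataParallel so the saved keys stay
    # portable, then store model and optimizer state in a single file.
    os.makedirs(checkpoint_dir, exist_ok=True)
    state = model.module.state_dict() if parallel else model.state_dict()
    torch.save({'model_state_dict': state,
                'optimizer_state_dict': optimizer.state_dict()},
               os.path.join(checkpoint_dir, filename))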
Example 3
def train_alphaBert_stage1(TS_model,
                           dloader,
                           lr=1e-4,
                           epoch=10,
                           log_interval=20,
                           cloze_fix=True,
                           use_amp=False,
                           parallel=True):
    # Stage-1 (cloze / masked-character) pretraining of alphaBERT.
    TS_model.to(device)
    model_optimizer = optim.Adam(TS_model.parameters(), lr=lr)

    if parallel:
        TS_model = torch.nn.DataParallel(TS_model)

    TS_model.train()

    # Character-level cross-entropy; -1 marks unmasked positions to ignore.
    criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)
    iteration = 0
    total_loss = []
    out_pred_res = []
    out_pred_test = []
    for ep in range(epoch):
        t0 = time.time()
        epoch_loss = 0
        epoch_cases = 0
        for batch_idx, sample in enumerate(dloader):
            model_optimizer.zero_grad()
            loss = 0

            src = sample['src_token']
            trg = sample['trg']
            att_mask = sample['mask_padding']
            origin_len = sample['origin_seq_length']

            bs, max_len = src.shape

            #            src, err_cloze = make_cloze(src,
            #                                        max_len,
            #                                        device=device,
            #                                        percent=0.15,
            #                                        fix=cloze_fix)

            src = src.float().to(device)
            trg = trg.long().to(device)
            att_mask = att_mask.float().to(device)
            origin_len = origin_len.to(device)

            prediction_scores, pred_prop = TS_model(input_ids=src,
                                                    attention_mask=att_mask)

            # prediction_scores: (bs, max_len, 84) character logits.
            loss = criterion(
                prediction_scores.view(-1, 84).contiguous(),
                trg.view(-1).contiguous())

            if use_amp:
                with amp.scale_loss(loss, model_optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            model_optimizer.step()

            with torch.no_grad():
                epoch_loss += loss.item() * bs
                epoch_cases += bs

                if iteration % log_interval == 0:
                    print('Ep:{} [{} ({:.0f}%)/ ep_time:{:.0f}min] L:{:.4f}'.
                          format(ep, batch_idx * batch_size,
                                 100. * batch_idx / len(dloader),
                                 (time.time() - t0) * len(dloader) /
                                 (60 * (batch_idx + 1)), loss.item()))

                if iteration % 400 == 0:
                    save_checkpoint('DS_pretrain.pth', TS_model,
                                    model_optimizer)
                    a_ = tokenize_alphabets.convert_idx2str(
                        src[0][:origin_len[0]])
                    print(a_)
                    print(' ******** ******** ******** ')
                    _, show_pred = torch.max(prediction_scores[0], dim=1)
                    err_cloze_ = trg[0] > -1
                    src[0][err_cloze_] = show_pred[err_cloze_].float()
                    b_ = tokenize_alphabets.convert_idx2str(
                        src[0][:origin_len[0]])
                    print(b_)
                    print(' ******** ******** ******** ')
                    src[0][err_cloze_] = trg[0][err_cloze_].float()
                    c_ = tokenize_alphabets.convert_idx2str(
                        src[0][:origin_len[0]])
                    print(c_)

                    out_pred_res.append((ep, a_, b_, c_, err_cloze_))
                    out_pd_res = pd.DataFrame(out_pred_res)
                    out_pd_res.to_csv('./result/out_pred_train.csv', sep=',')

                if iteration % 999 == 0:
                    print(' ===== Show the Test of Pretrain ===== ')
                    test_res = test_alphaBert_stage1(TS_model,
                                                     stage1_dataloader_test)
                    print(' ===== Show the Test of Pretrain ===== ')

                    out_pred_test.append((ep, *test_res))
                    out_pd_test = pd.DataFrame(out_pred_test)
                    out_pd_test.to_csv('./result/out_pred_test.csv', sep=',')

            iteration += 1
        if ep % 1 == 0:
            save_checkpoint('DS_pretrain.pth', TS_model, model_optimizer)

            print('======= epoch:%i ========' % ep)

        print('++ Ep Time: {:.1f} Secs ++'.format(time.time() - t0))
        total_loss.append(float(epoch_loss / epoch_cases))
        pd_total_loss = pd.DataFrame(total_loss)
        pd_total_loss.to_csv('./result/total_loss_pretrain.csv', sep=',')
    print(total_loss)
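A plausible invocation of the stage-1 pretraining loop, assuming `alphabert_model` and `stage1_dataloader` are the model and DataLoader built elsewhere in the project (both names are placeholders):

train_alphaBert_stage1(alphabert_model,
                       stage1_dataloader,
                       lr=1e-4,
                       epoch=50,
                       log_interval=20,
                       use_amp=False,
                       parallel=True)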
Example 4
def train_alphaBert(DS_model, dloader, lr=1e-4, epoch=10, log_interval=20):
    # Fine-tune alphaBERT with the project-specific summarisation loss.
    DS_model.to(device)
    model_optimizer = optim.Adam(DS_model.parameters(), lr=lr)
    DS_model = torch.nn.DataParallel(DS_model)
    DS_model.train()

    criterion = alphabert_loss_v02.Alphabert_loss(device=device)
    iteration = 0
    total_loss = []
    for ep in range(epoch):
        DS_model.train()

        t0 = time.time()
        epoch_loss = 0
        epoch_cases = 0
        for batch_idx, sample in enumerate(dloader):
            model_optimizer.zero_grad()
            loss = 0

            src = sample['src_token']
            trg = sample['trg']
            att_mask = sample['mask_padding']
            origin_len = sample['origin_seq_length']

            bs = len(src)

            src = src.float().to(device)
            trg = trg.float().to(device)
            att_mask = att_mask.float().to(device)
            origin_len = origin_len.to(device)

            pred_prop = DS_model(input_ids=src, attention_mask=att_mask)
            loss = criterion(pred_prop[0], trg, origin_len)

            loss.backward()
            model_optimizer.step()

            with torch.no_grad():
                epoch_loss += loss.item() * bs
                epoch_cases += bs

            if iteration % log_interval == 0:
                # Log samples seen, % of the epoch done, projected epoch time
                # in minutes, and the current batch loss.
                print(
                    'Ep:{} [{} ({:.0f}%)/ ep_time:{:.0f}min] L:{:.4f}'.format(
                        ep, batch_idx * batch_size,
                        100. * batch_idx / len(dloader), (time.time() - t0) *
                        len(dloader) / (60 * (batch_idx + 1)), loss.item()))

            if iteration % 400 == 0:
                save_checkpoint('d2s.pth', DS_model, model_optimizer)

            iteration += 1
        if ep % 1 == 0:
            save_checkpoint('d2s.pth', DS_model, model_optimizer)
            #            test_alphaBert(DS_model,D2S_valloader,
            #                           is_clean_up=True, ep=ep,train=True)

            print('======= epoch:%i ========' % ep)


        print('++ Ep Time: {:.1f} Secs ++'.format(time.time() - t0))
        total_loss.append(float(epoch_loss / epoch_cases))
        pd_total_loss = pd.DataFrame(total_loss)
        pd_total_loss.to_csv('./iou_pic/total_loss_finetune.csv', sep=',')
    print(total_loss)
Example 5
def train_alphaBert_head(TS_model,
                         dloader,
                         lr=1e-4,
                         epoch=10,
                         log_interval=20,
                         cloze_fix=True,
                         use_amp=False,
                         parallel=True):
    global checkpoint_file
    # Fine-tune the token-selection ("finehead") head of alphaBERT,
    # optionally with amp mixed precision (use_amp).
    TS_model.to(device)
    model_optimizer = optim.Adam(TS_model.parameters(), lr=lr)
    if use_amp:
        TS_model, model_optimizer = amp.initialize(TS_model,
                                                   model_optimizer,
                                                   opt_level="O1")
    if parallel:
        TS_model = torch.nn.DataParallel(TS_model)

    TS_model.train()

    criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)
    iteration = 0
    total_loss = []
    out_pred_res = []
    out_pred_test = []
    for ep in range(epoch):
        t0 = time.time()
        epoch_loss = 0
        epoch_cases = 0
        for batch_idx, sample in enumerate(dloader):
            model_optimizer.zero_grad()
            loss = 0

            src = sample['src_token']
            trg = sample['trg']
            att_mask = sample['mask_padding']
            origin_len = sample['origin_seq_length']

            bs, max_len = src.shape

            src = src.float().to(device)
            trg = trg.long().to(device)
            att_mask = att_mask.float().to(device)
            origin_len = origin_len.to(device)

            head_outputs, = TS_model(input_ids=src,
                                     attention_mask=att_mask,
                                     out='finehead')
            # Split the flattened targets by class so the two losses can be
            # computed separately and summed; -1 (padding) falls in neither mask.
            trg_view = trg.view(-1).contiguous()
            trg_mask0 = trg_view == 0
            trg_mask1 = trg_view == 1

            loss0 = criterion(head_outputs[trg_mask0], trg_view[trg_mask0])
            loss1 = criterion(head_outputs[trg_mask1], trg_view[trg_mask1])

            loss = loss0 + loss1

            if use_amp:
                with amp.scale_loss(loss, model_optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            model_optimizer.step()

            with torch.no_grad():
                epoch_loss += loss.item() * bs
                epoch_cases += bs

                if iteration % log_interval == 0:
                    print(
                        'Ep:{} [{} ({:.0f}%)/ ep_time:{:.0f}min] L:{:.4f} L0:{:.4f} L1:{:.4f}'
                        .format(ep, batch_idx * batch_size,
                                100. * batch_idx / len(dloader),
                                (time.time() - t0) * len(dloader) /
                                (60 * (batch_idx + 1)), loss.item(),
                                loss0.item(), loss1.item()))

                if iteration % 400 == 0:
                    save_checkpoint(checkpoint_file,
                                    'd2s_head.pth',
                                    TS_model,
                                    model_optimizer,
                                    parallel=parallel)

                    # Per-token class probabilities for the first sample.
                    show_pred = nn.Softmax(dim=-1)(head_outputs.view(
                        bs, max_len, -1)[0])
                    head_mask = src[0] == tokenize_alphabets.alphabet2idx['|']

                    a_ = tokenize_alphabets.convert_idx2str_head(
                        src[0][:origin_len[0]], head_mask, show_pred[:, 1],
                        trg[0])
                    print(' ******** ******** ******** ')
                    print(a_)
                    print(' ******** ******** ******** ')

                    out_pred_res.append((ep, a_))
                    out_pd_res = pd.DataFrame(out_pred_res)
                    out_pd_res.to_csv('./result/out_fine_head_train.csv',
                                      sep=',')


#                if iteration % 999 == 0:
#                    print(' ===== Show the Test of Pretrain ===== ')
#                    test_res = test_alphaBert_stage1_head(TS_model,stage1_head_dataloader_test)
#                    print(' ===== Show the Test of Pretrain ===== ')
#
#                    out_pred_test.append((ep,*test_res))
#                    out_pd_test = pd.DataFrame(out_pred_test)
#                    out_pd_test.to_csv('./result/out_fine_head_test.csv', sep=',')

            iteration += 1
        if ep % 1 == 0:
            save_checkpoint(checkpoint_file,
                            'd2s_head.pth',
                            TS_model,
                            model_optimizer,
                            parallel=parallel)

            print('======= epoch:%i ========' % ep)

        print('++ Ep Time: {:.1f} Secs ++'.format(time.time() - t0))
        total_loss.append(float(epoch_loss / epoch_cases))
        pd_total_loss = pd.DataFrame(total_loss)
        pd_total_loss.to_csv('./result/total_loss_finetune_head.csv', sep=',')
    print(total_loss)
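Similarly, a plausible call for the head fine-tuning loop; `use_amp=True` relies on the `amp` module used above (apex-style `amp.initialize`/`amp.scale_loss`), and the global `checkpoint_file` must be set beforehand. Model and loader names are placeholders.

checkpoint_file = './checkpoint_d2s'
train_alphaBert_head(alphabert_model,
                     head_dataloader,
                     lr=1e-5,
                     epoch=30,
                     use_amp=True,
                     parallel=True)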
Example 6
def train_lstm_stage1(TS_model,
                      dloader,
                      lr=1e-4,
                      epoch=10,
                      log_interval=20,
                      cloze_fix=True,
                      parallel=True):
    global checkpoint_file
    # Stage-1 (cloze) pretraining of an LSTM baseline.
    TS_model.to(device)
    model_optimizer = optim.Adam(TS_model.parameters(), lr=lr)
    if parallel:
        TS_model = torch.nn.DataParallel(TS_model)
    TS_model.train()
    # Character-level cross-entropy; -1 marks unmasked positions to ignore.
    criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)
    iteration = 0
    total_loss = []
    out_pred_res = []
    out_pred_test = []
    for ep in range(epoch):
        t0 = time.time()
        epoch_loss = 0
        epoch_cases = 0
        for batch_idx, sample in enumerate(dloader):
            model_optimizer.zero_grad()
            loss = 0

            src = sample['src_token']
            trg = sample['trg']
            att_mask = sample['mask_padding']
            origin_len = sample['origin_seq_length']
            bs, max_len = src.shape

            src = src.float().to(device)
            trg = trg.long().to(device)
            att_mask = att_mask.float().to(device)
            origin_len = origin_len.to(device)

            # prediction_scores: (bs, max_len, 100) character logits.
            _, prediction_scores = TS_model(x=src, x_lengths=origin_len)
            loss = criterion(
                prediction_scores.view(-1, 100).contiguous(),
                trg.view(-1).contiguous())
            loss.backward()
            model_optimizer.step()

            with torch.no_grad():
                epoch_loss += loss.item() * bs
                epoch_cases += bs

                if iteration % log_interval == 0:
                    print('Ep:{} [{} ({:.0f}%)/ ep_time:{:.0f}min] L:{:.4f}'.
                          format(ep, batch_idx * batch_size,
                                 100. * batch_idx / len(dloader),
                                 (time.time() - t0) * len(dloader) /
                                 (60 * (batch_idx + 1)), loss.item()))

                if iteration % 400 == 0:
                    checkpoint_file = './checkpoint_lstm'
                    save_checkpoint(checkpoint_file, 'lstm_pretrain.pth',
                                    TS_model, model_optimizer, parallel)
                    a_ = tokenize_alphabets.convert_idx2str(
                        src[0][:origin_len[0]])
                    print(a_)
                    print(' ******** ******** ******** ')
                    _, show_pred = torch.max(prediction_scores[0], dim=1)
                    err_cloze_ = trg[0] > -1
                    src[0][err_cloze_] = show_pred[err_cloze_].float()
                    b_ = tokenize_alphabets.convert_idx2str(
                        src[0][:origin_len[0]])
                    print(b_)
                    print(' ******** ******** ******** ')
                    src[0][err_cloze_] = trg[0][err_cloze_].float()
                    c_ = tokenize_alphabets.convert_idx2str(
                        src[0][:origin_len[0]])
                    print(c_)

                    out_pred_res.append((ep, a_, b_, c_, err_cloze_))
                    out_pd_res = pd.DataFrame(out_pred_res)
                    out_pd_res.to_csv('./out_pred_train.csv', sep=',')

            iteration += 1
        if ep % 1 == 0:
            checkpoint_file = './checkpoint_lstm'
            save_checkpoint(checkpoint_file, 'lstm_pretrain.pth', TS_model,
                            model_optimizer, parallel)

            print('======= epoch:%i ========' % ep)

        print('++ Ep Time: {:.1f} Secs ++'.format(time.time() - t0))
        total_loss.append(float(epoch_loss / epoch_cases))
        pd_total_loss = pd.DataFrame(total_loss)
        pd_total_loss.to_csv('./result/total_loss_pretrain.csv', sep=',')
    print(total_loss)
Example 7
def train_alphaBert(DS_model,
                    dloader,
                    lr=1e-4,
                    epoch=10,
                    log_interval=20,
                    lkahead=False):
    global checkpoint_file
    DS_model.to(device)
    # Ranger optimizer (RAdam + Lookahead) is used in place of plain Adam.
    model_optimizer = Ranger(DS_model.parameters(), lr=lr)
    DS_model = torch.nn.DataParallel(DS_model)
    DS_model.train()
    # The `lkahead` flag is unused here; Lookahead is already part of Ranger.
    criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)
    iteration = 0
    total_loss = []
    for ep in range(epoch):
        DS_model.train()

        t0 = time.time()
        epoch_loss = 0
        epoch_cases = 0
        for batch_idx, sample in enumerate(dloader):
            model_optimizer.zero_grad()
            loss = 0

            src = sample['src_token']
            trg = sample['trg']
            att_mask = sample['mask_padding']
            origin_len = sample['origin_seq_length']

            bs = len(src)

            src = src.float().to(device)
            trg = trg.long().to(device)
            att_mask = att_mask.float().to(device)
            origin_len = origin_len.to(device)

            pred_prop, = DS_model(input_ids=src,
                                  attention_mask=att_mask,
                                  out='finehead')

            trg_view = trg.view(-1).contiguous()

            # Plain cross-entropy over all positions; a class-weighted variant
            # (separate class-0 / class-1 losses, weighted 0.2 / 0.8) was also tried.
            loss = criterion(pred_prop, trg_view)

            loss.backward()
            model_optimizer.step()

            with torch.no_grad():
                epoch_loss += loss.item() * bs
                epoch_cases += bs

            if iteration % log_interval == 0:
                # Log samples seen, % of the epoch done, projected epoch time
                # in minutes, and the current batch loss.
                print(
                    'Ep:{} [{} ({:.0f}%)/ ep_time:{:.0f}min] L:{:.4f}'.format(
                        ep, batch_idx * batch_size,
                        100. * batch_idx / len(dloader), (time.time() - t0) *
                        len(dloader) / (60 * (batch_idx + 1)), loss.item()))

            if iteration % 400 == 0:
                save_checkpoint(checkpoint_file,
                                'd2s_total.pth',
                                DS_model,
                                model_optimizer,
                                parallel=parallel)
                print(
                    tokenize_alphabets.convert_idx2str(src[0][:origin_len[0]]))
            iteration += 1
        if ep % 1 == 0:
            save_checkpoint(checkpoint_file,
                            'd2s_total.pth',
                            DS_model,
                            model_optimizer,
                            parallel=parallel)
            #            test_alphaBert(DS_model,D2S_valloader,
            #                           is_clean_up=True, ep=ep,train=True)

            print('======= epoch:%i ========' % ep)


        print('++ Ep Time: {:.1f} Secs ++'.format(time.time() - t0))
        total_loss.append(float(epoch_loss / epoch_cases))
        pd_total_loss = pd.DataFrame(total_loss)
        pd_total_loss.to_csv('./iou_pic/total_loss_finetune.csv', sep=',')
    print(total_loss)