def train_iters(ae_model, dis_model):
    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    train_data_loader.create_batches(args.train_file_list,
                                     args.train_label_list,
                                     if_shuffle=True)
    add_log("Start train process.")
    ae_model.train()
    dis_model.train()

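    # NoamOpt wraps Adam with the usual Transformer warmup schedule: the learning
    # rate rises linearly for the first 2000 steps, then decays as step^-0.5,
    # scaled by d_model^-0.5, so the wrapped Adam starts at lr=0.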
    ae_optimizer = NoamOpt(
        ae_model.src_embed[0].d_model, 1, 2000,
        torch.optim.Adam(ae_model.parameters(),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
    dis_optimizer = torch.optim.Adam(dis_model.parameters(), lr=0.0001)

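    # Reconstruction uses label smoothing (0.1 of the probability mass spread over
    # non-target tokens, padding ignored); the discriminator uses plain binary
    # cross entropy against the style labels.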
    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1))
    dis_criterion = nn.BCELoss(reduction='mean')  # size_average is deprecated; reduction='mean' is equivalent

    for epoch in range(200):
        print('-' * 94)
        epoch_start_time = time.time()
        for it in range(train_data_loader.num_batch):
            batch_sentences, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, tensor_ntokens = train_data_loader.next_batch()

            # For debug
            # print(batch_sentences[0])
            # print(tensor_src[0])
            # print(tensor_src_mask[0])
            # print("tensor_src_mask", tensor_src_mask.size())
            # print(tensor_tgt[0])
            # print(tensor_tgt_y[0])
            # print(tensor_tgt_mask[0])
            # print(batch_ntokens)

            # Forward pass
            latent, out = ae_model(tensor_src, tensor_tgt,
                                   tensor_src_mask, tensor_tgt_mask)
            # print(latent.size())  # (batch_size, max_src_seq, d_model)
            # print(out.size())  # (batch_size, max_tgt_seq, vocab_size)

            # Reconstruction loss: label-smoothed cross entropy over the flattened
            # decoder output, normalized by the number of target tokens in the batch.
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / tensor_ntokens.data

            # loss_all = loss_rec + loss_dis

            ae_optimizer.optimizer.zero_grad()
            loss_rec.backward()
            ae_optimizer.step()

            # Classifier
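            # (only dis_model's parameters are updated in this step; dis_optimizer
            # was built from dis_model.parameters() alone, so the autoencoder stays
            # untouched)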
            dis_lop = dis_model(to_var(latent.clone()))

            loss_dis = dis_criterion(dis_lop, tensor_labels)

            dis_optimizer.zero_grad()
            loss_dis.backward()
            dis_optimizer.step()

            if it % 200 == 0:
                add_log(
                    '| epoch {:3d} | {:5d}/{:5d} batches | rec loss {:5.4f} | dis loss {:5.4f} |'
                    .format(epoch, it, train_data_loader.num_batch,
                            loss_rec.item(), loss_dis.item()))

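                # Print the reference sentence followed by a greedy re-decoding of
                # the latent code, as a quick qualitative check of reconstruction.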
                print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
                generator_text = ae_model.greedy_decode(
                    latent,
                    max_len=args.max_sequence_length,
                    start_id=args.id_bos)
                print(id2text_sentence(generator_text[0], args.id_to_word))

        add_log('| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))
        # Save model
        torch.save(ae_model.state_dict(),
                   os.path.join(args.current_save_path, 'ae_model_params.pkl'))
        torch.save(dis_model.state_dict(),
                   os.path.join(args.current_save_path, 'dis_model_params.pkl'))
    return


# BERT-based variant of train_iters; this later definition supersedes the one above.
def train_iters(ae_model, dis_model):
    if args.use_albert:
        tokenizer = BertTokenizer.from_pretrained("clue/albert_chinese_tiny",
                                                  do_lower_case=True)
    elif args.use_tiny_bert:
        tokenizer = AutoTokenizer.from_pretrained(
            "google/bert_uncased_L-2_H-256_A-4", do_lower_case=True)
    elif args.use_distil_bert:
        tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-uncased', do_lower_case=True)
    else:
        # fall back to the configured default so `tokenizer` is always defined
        tokenizer = BertTokenizer.from_pretrained(args.PRETRAINED_MODEL_NAME,
                                                  do_lower_case=True)
    tokenizer.add_tokens('[EOS]')
    bos_id = tokenizer.convert_tokens_to_ids(['[CLS]'])[0]

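    # [EOS] was added to the vocabulary above, so the encoder's token embedding
    # matrix has to be resized to the new tokenizer length.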
    ae_model.bert_encoder.resize_token_embeddings(len(tokenizer))
    #print("[CLS] ID: ", bos_id)

    print("Load trainData...")
    if args.load_trainData and os.path.exists('./{}_trainData.pkl'.format(
            args.task)):
        with open('./{}_trainData.pkl'.format(args.task), 'rb') as f:
            trainData = pickle.load(f)
    else:
        trainData = TextDataset(batch_size=args.batch_size,
                                id_bos='[CLS]',
                                id_eos='[EOS]',
                                id_unk='[UNK]',
                                max_sequence_length=args.max_sequence_length,
                                vocab_size=0,
                                file_list=args.train_file_list,
                                label_list=args.train_label_list,
                                tokenizer=tokenizer)
        with open('./{}_trainData.pkl'.format(args.task), 'wb') as f:
            pickle.dump(trainData, f)

    add_log("Start train process.")

    ae_model.train()
    dis_model.train()
    ae_model.to(device)
    dis_model.to(device)
    # Freeze (or distill) parts of the BERT encoder: parameters with
    # requires_grad=False keep their pretrained weights and are skipped by the
    # optimizer updates below.
    if args.fix_first_6:
        print("Try fixing first 6 bertlayers")
        for layer in range(6):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters():
                param.requires_grad = False
    elif args.fix_last_6:
        print("Try fixing last 6 bertlayers")
        for layer in range(6, 12):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters():
                param.requires_grad = False

    if args.distill_2:
        print("Get result from layer 2")
        for layer in range(2, 12):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters():
                param.requires_grad = False

    ae_optimizer = NoamOpt(
        ae_model.d_model, 1, 2000,
        torch.optim.Adam(ae_model.parameters(),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
    dis_optimizer = torch.optim.Adam(dis_model.parameters(), lr=0.0001)

    #ae_criterion = get_cuda(LabelSmoothing(size=args.vocab_size, padding_idx=args.id_pad, smoothing=0.1))
    ae_criterion = LabelSmoothing(size=ae_model.bert_encoder.config.vocab_size,
                                  padding_idx=0,
                                  smoothing=0.1).to(device)
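    # nn.BCELoss expects probabilities in [0, 1], so dis_model is assumed to end
    # with a sigmoid over its output.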
    dis_criterion = nn.BCELoss(reduction='mean')

    history = {'train': []}

    for epoch in range(args.epochs):
        print('-' * 94)
        epoch_start_time = time.time()
        total_rec_loss = 0
        total_dis_loss = 0

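        # A fresh DataLoader is built each epoch: batches are reshuffled and
        # assembled into tensors by the dataset's own collate_fn.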
        train_data_loader = DataLoader(trainData,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       collate_fn=trainData.collate_fn,
                                       num_workers=4)
        num_batch = len(train_data_loader)
        trange = tqdm(enumerate(train_data_loader),
                      total=num_batch,
                      desc='Training',
                      file=sys.stdout,
                      position=0,
                      leave=True)

        for it, data in trange:
            batch_sentences, tensor_labels, tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, tensor_tgt_mask, tensor_ntokens = data

            tensor_labels = tensor_labels.to(device)
            tensor_src = tensor_src.to(device)
            tensor_tgt = tensor_tgt.to(device)
            tensor_tgt_y = tensor_tgt_y.to(device)
            tensor_src_mask = tensor_src_mask.to(device)
            tensor_tgt_mask = tensor_tgt_mask.to(device)

            # Forward pass
            latent, out = ae_model(tensor_src, tensor_tgt,
                                   tensor_src_mask, tensor_tgt_mask)

            # Reconstruction loss: label-smoothed cross entropy over the flattened
            # decoder output, normalized by the number of target tokens in the batch.
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / tensor_ntokens.data

            ae_optimizer.optimizer.zero_grad()
            loss_rec.backward()
            ae_optimizer.step()

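            # Detach the latent before the discriminator step so loss_dis cannot
            # backpropagate into the autoencoder (its graph was already freed by
            # loss_rec.backward()).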
            latent = latent.detach()
            next_latent = latent.to(device)

            # Classifier
            dis_lop = dis_model(next_latent)
            loss_dis = dis_criterion(dis_lop, tensor_labels)

            dis_optimizer.zero_grad()
            loss_dis.backward()
            dis_optimizer.step()

            total_rec_loss += loss_rec.item()
            total_dis_loss += loss_dis.item()

            trange.set_postfix(total_rec_loss=total_rec_loss / (it + 1),
                               total_dis_loss=total_dis_loss / (it + 1))

            if it % 100 == 0:
                add_log(
                    '| epoch {:3d} | {:5d}/{:5d} batches | rec loss {:5.4f} | dis loss {:5.4f} |'
                    .format(epoch, it, num_batch, loss_rec.item(),
                            loss_dis.item()))

                print(id2text_sentence(tensor_tgt_y[0], tokenizer, args.task))
                generator_text = ae_model.greedy_decode(
                    latent, max_len=args.max_sequence_length, start_id=bos_id)
                print(id2text_sentence(generator_text[0], tokenizer,
                                       args.task))

                # Save model
                #torch.save(ae_model.state_dict(), args.current_save_path / 'ae_model_params.pkl')
                #torch.save(dis_model.state_dict(), args.current_save_path / 'dis_model_params.pkl')

        history['train'].append({
            'epoch': epoch,
            'total_rec_loss': total_rec_loss / len(trange),
            'total_dis_loss': total_dis_loss / len(trange)
        })

        add_log('| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))
        # Save model
        torch.save(ae_model.state_dict(),
                   args.current_save_path / 'ae_model_params.pkl')
        torch.save(dis_model.state_dict(),
                   args.current_save_path / 'dis_model_params.pkl')

    print("Save in ", args.current_save_path)
    return