Beispiel #1
0
def run():
    """Build datasets and the model, train it, then evaluate on the test set.

    All paths and hyper-parameters are read from the module-level ``config``.
    Training is delegated to ``train`` (which receives both the raw model and
    its ``torch.nn.DataParallel`` wrapper) and evaluation to ``test``.
    """
    utils.set_logger(config.log_path)

    train_dataset = MTDataset(config.train_data_path)
    dev_dataset = MTDataset(config.dev_data_path)
    test_dataset = MTDataset(config.test_data_path)

    logging.info("-------- Dataset Build! --------")
    train_dataloader = DataLoader(train_dataset,
                                  shuffle=True,
                                  batch_size=config.batch_size,
                                  collate_fn=train_dataset.collate_fn)
    dev_dataloader = DataLoader(dev_dataset,
                                shuffle=False,
                                batch_size=config.batch_size,
                                collate_fn=dev_dataset.collate_fn)
    test_dataloader = DataLoader(test_dataset,
                                 shuffle=False,
                                 batch_size=config.batch_size,
                                 collate_fn=test_dataset.collate_fn)

    logging.info("-------- Get Dataloader! --------")
    # Initialise the model.
    model = make_model(config.src_vocab_size, config.tgt_vocab_size,
                       config.n_layers, config.d_model, config.d_ff,
                       config.n_heads, config.dropout)
    model_par = torch.nn.DataParallel(model)

    # Loss function. Both branches must ignore the same padding token, so the
    # plain cross-entropy branch reads it from config instead of hard-coding 0.
    if config.use_smoothing:
        criterion = LabelSmoothing(size=config.tgt_vocab_size,
                                   padding_idx=config.padding_idx,
                                   smoothing=0.1)
        criterion.cuda()
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=config.padding_idx,
                                              reduction='sum')

    # Optimizer: Noam schedule (via get_std_opt) or plain AdamW.
    if config.use_noamopt:
        optimizer = get_std_opt(model)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

    # Train, then evaluate on the held-out test set.
    train(train_dataloader, dev_dataloader, model, model_par, criterion,
          optimizer)
    test(test_dataloader, model, criterion)
def train_iters(ae_model, dis_model):
    """Train the autoencoder and a latent-space discriminator for 200 epochs.

    Each batch does two updates:

    1. ``ae_model`` is updated with a label-smoothed reconstruction loss,
       divided by the batch token count. It uses a NoamOpt-wrapped Adam.
    2. ``dis_model`` is updated with a BCE loss that predicts
       ``tensor_labels`` from the encoder latent.

    Both models' ``state_dict``s are saved to ``args.current_save_path``
    after every epoch.

    NOTE(review): a later ``def train_iters`` in this file rebinds this name,
    so this version is shadowed once the module is fully imported.
    """
    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    train_data_loader.create_batches(args.train_file_list,
                                     args.train_label_list,
                                     if_shuffle=True)
    add_log("Start train process.")
    ae_model.train()
    dis_model.train()

    ae_optimizer = NoamOpt(
        ae_model.src_embed[0].d_model, 1, 2000,
        torch.optim.Adam(ae_model.parameters(),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
    dis_optimizer = torch.optim.Adam(dis_model.parameters(), lr=0.0001)

    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1))
    # `size_average=True` is deprecated; reduction='mean' is its exact
    # replacement.
    dis_criterion = nn.BCELoss(reduction='mean')

    for epoch in range(200):
        print('-' * 94)
        epoch_start_time = time.time()
        for it in range(train_data_loader.num_batch):
            batch_sentences, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, tensor_ntokens = train_data_loader.next_batch()

            # Forward pass through the autoencoder.
            latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                           tensor_src_mask, tensor_tgt_mask)

            # Reconstruction loss, normalised by the number of target tokens.
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / tensor_ntokens.data

            ae_optimizer.optimizer.zero_grad()
            loss_rec.backward()
            ae_optimizer.step()

            # Discriminator step on the latent.
            # NOTE(review): loss_rec.backward() has already freed the
            # autoencoder graph, so this relies on `to_var` cutting `latent`
            # from it (e.g. by detaching) — confirm against to_var.
            dis_lop = dis_model.forward(to_var(latent.clone()))

            loss_dis = dis_criterion(dis_lop, tensor_labels)

            dis_optimizer.zero_grad()
            loss_dis.backward()
            dis_optimizer.step()

            if it % 200 == 0:
                add_log(
                    '| epoch {:3d} | {:5d}/{:5d} batches | rec loss {:5.4f} | dis loss {:5.4f} |'
                    .format(epoch, it, train_data_loader.num_batch,
                            loss_rec.item(), loss_dis.item()))

                print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
                # Sample decode for monitoring only; no gradients are needed.
                with torch.no_grad():
                    generator_text = ae_model.greedy_decode(
                        latent,
                        max_len=args.max_sequence_length,
                        start_id=args.id_bos)
                print(id2text_sentence(generator_text[0], args.id_to_word))

        add_log('| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))
        # Save both models at the end of every epoch.
        torch.save(ae_model.state_dict(),
                   args.current_save_path + 'ae_model_params.pkl')
        torch.save(dis_model.state_dict(),
                   args.current_save_path + 'dis_model_params.pkl')
def train_iters(ae_model, dis_model):
    """Train a BERT-encoder autoencoder and a latent discriminator jointly.

    Setup:

    - A tokenizer is selected from ``args`` and given an ``[EOS]`` token.
    - The encoder embeddings are resized to the new vocabulary.
    - The training set is loaded from (or cached to) ``./{task}_trainData.pkl``.

    For ``args.epochs`` epochs, each batch does two updates:

    1. ``ae_model`` is updated with a label-smoothed reconstruction loss,
       using a NoamOpt-wrapped Adam.
    2. ``dis_model`` is updated with a BCE loss on the detached latent.

    Both ``state_dict``s are saved to ``args.current_save_path`` after every
    epoch.

    Returns:
        dict: ``{'train': [...]}`` with one entry per epoch holding the mean
        reconstruction and discriminator losses.

    Raises:
        ValueError: if none of ``args.use_albert``, ``args.use_tiny_bert`` or
            ``args.use_distil_bert`` is set.
    """
    if args.use_albert:
        tokenizer = BertTokenizer.from_pretrained("clue/albert_chinese_tiny",
                                                  do_lower_case=True)
    elif args.use_tiny_bert:
        tokenizer = AutoTokenizer.from_pretrained(
            "google/bert_uncased_L-2_H-256_A-4", do_lower_case=True)
    elif args.use_distil_bert:
        tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-uncased', do_lower_case=True)
    else:
        # Previously fell through to an UnboundLocalError on `tokenizer`.
        raise ValueError("No tokenizer selected: set one of args.use_albert, "
                         "args.use_tiny_bert or args.use_distil_bert.")
    tokenizer.add_tokens('[EOS]')
    bos_id = tokenizer.convert_tokens_to_ids(['[CLS]'])[0]

    # Grow the embedding matrix so it covers the newly added [EOS] token.
    ae_model.bert_encoder.resize_token_embeddings(len(tokenizer))

    print("Load trainData...")
    cache_path = './{}_trainData.pkl'.format(args.task)
    if args.load_trainData and os.path.exists(cache_path):
        # SECURITY: pickle.load can execute arbitrary code; only load cache
        # files that this script wrote itself.
        with open(cache_path, 'rb') as f:
            trainData = pickle.load(f)
    else:
        trainData = TextDataset(batch_size=args.batch_size,
                                id_bos='[CLS]',
                                id_eos='[EOS]',
                                id_unk='[UNK]',
                                max_sequence_length=args.max_sequence_length,
                                vocab_size=0,
                                file_list=args.train_file_list,
                                label_list=args.train_label_list,
                                tokenizer=tokenizer)
        with open(cache_path, 'wb') as f:
            pickle.dump(trainData, f)

    add_log("Start train process.")

    ae_model.train()
    dis_model.train()
    ae_model.to(device)
    dis_model.to(device)

    # Optionally freeze parts of the BERT encoder (assumes 12 encoder layers).
    if args.fix_first_6:
        print("Try fixing first 6 bertlayers")
        for layer in range(6):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters(
            ):
                param.requires_grad = False
    elif args.fix_last_6:
        print("Try fixing last 6 bertlayers")
        for layer in range(6, 12):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters(
            ):
                param.requires_grad = False

    if args.distill_2:
        print("Get result from layer 2")
        for layer in range(2, 12):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters(
            ):
                param.requires_grad = False

    ae_optimizer = NoamOpt(
        ae_model.d_model, 1, 2000,
        torch.optim.Adam(ae_model.parameters(),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
    dis_optimizer = torch.optim.Adam(dis_model.parameters(), lr=0.0001)

    ae_criterion = LabelSmoothing(size=ae_model.bert_encoder.config.vocab_size,
                                  padding_idx=0,
                                  smoothing=0.1).to(device)
    dis_criterion = nn.BCELoss(reduction='mean')

    history = {'train': []}

    # Built once: iterating a DataLoader again re-shuffles it each epoch.
    train_data_loader = DataLoader(trainData,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   collate_fn=trainData.collate_fn,
                                   num_workers=4)
    num_batch = len(train_data_loader)

    for epoch in range(args.epochs):
        print('-' * 94)
        epoch_start_time = time.time()
        total_rec_loss = 0
        total_dis_loss = 0

        trange = tqdm(enumerate(train_data_loader),
                      total=num_batch,
                      desc='Training',
                      file=sys.stdout,
                      position=0,
                      leave=True)

        for it, data in trange:
            (_, tensor_labels, tensor_src, tensor_src_mask, tensor_tgt,
             tensor_tgt_y, tensor_tgt_mask, tensor_ntokens) = data

            tensor_labels = tensor_labels.to(device)
            tensor_src = tensor_src.to(device)
            tensor_tgt = tensor_tgt.to(device)
            tensor_tgt_y = tensor_tgt_y.to(device)
            tensor_src_mask = tensor_src_mask.to(device)
            tensor_tgt_mask = tensor_tgt_mask.to(device)
            # The token count divides a loss on `device`; keep it there too.
            tensor_ntokens = tensor_ntokens.to(device)

            # Forward pass
            latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                           tensor_src_mask, tensor_tgt_mask)

            # Reconstruction loss, normalised by the number of target tokens.
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / tensor_ntokens.data

            ae_optimizer.optimizer.zero_grad()
            loss_rec.backward()
            ae_optimizer.step()

            # The discriminator trains on a latent cut from the autoencoder
            # graph, so its loss does not push gradients into ae_model.
            latent = latent.detach()

            dis_lop = dis_model.forward(latent)
            loss_dis = dis_criterion(dis_lop, tensor_labels)

            dis_optimizer.zero_grad()
            loss_dis.backward()
            dis_optimizer.step()

            total_rec_loss += loss_rec.item()
            total_dis_loss += loss_dis.item()

            trange.set_postfix(total_rec_loss=total_rec_loss / (it + 1),
                               total_dis_loss=total_dis_loss / (it + 1))

            if it % 100 == 0:
                add_log(
                    '| epoch {:3d} | {:5d}/{:5d} batches | rec loss {:5.4f} | dis loss {:5.4f} |'
                    .format(epoch, it, num_batch, loss_rec.item(),
                            loss_dis.item()))

                print(id2text_sentence(tensor_tgt_y[0], tokenizer, args.task))
                generator_text = ae_model.greedy_decode(
                    latent, max_len=args.max_sequence_length, start_id=bos_id)
                print(id2text_sentence(generator_text[0], tokenizer,
                                       args.task))

        history['train'].append({
            'epoch': epoch,
            'total_rec_loss': total_rec_loss / len(trange),
            'total_dis_loss': total_dis_loss / len(trange)
        })

        add_log('| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))
        # Save both models at the end of every epoch.
        torch.save(ae_model.state_dict(),
                   args.current_save_path / 'ae_model_params.pkl')
        torch.save(dis_model.state_dict(),
                   args.current_save_path / 'dis_model_params.pkl')

    print("Save in ", args.current_save_path)
    return history
Beispiel #4
0
        print('<<<<< Evaluate loss: %f' % dev_loss)
        # 如果当前epoch的模型在dev集上的loss优于之前记录的最优loss则保存当前模型,并更新最优loss值
        if dev_loss < best_dev_loss:
            torch.save(model.state_dict(), SAVE_FILE)
            best_dev_loss = dev_loss
            print('****** Save model done... ******')
        print()


if __name__ == '__main__':
    # Script entry point: prepare the data, train the model and report the
    # elapsed training time.
    # NOTE(review): `model`, TRAIN_FILE, DEV_FILE, TGT_VOCAB and D_MODEL are
    # not defined in the visible code; presumably module-level globals — confirm.
    print('处理数据')  # "Processing data"
    data = PrepareData(TRAIN_FILE, DEV_FILE)

    print('>>>开始训练')  # ">>> Start training"
    train_start = time.time()
    # Loss function (smoothing=0.0, presumably disabling label smoothing).
    criterion = LabelSmoothing(TGT_VOCAB, padding_idx=0, smoothing=0.0)
    # Optimizer: Adam wrapped in NoamOpt. lr=0 here; presumably NoamOpt sets
    # the learning rate on each step — confirm against NoamOpt.
    optimizer = NoamOpt(
        D_MODEL, 1, 2000,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98),
                         eps=1e-9))

    train(data, model, criterion, optimizer)
    # Prints "<<< Training finished, elapsed time ... seconds".
    print(f'<<<训练结束, 花费时间 {time.time() - train_start:.4f}秒')

    # Evaluate on the test dataset (currently disabled).
    # print('开始测试')
    # from test import evaluate_test
    # evaluate_test(data, model)
def val(ae_model, dis_model, eval_data_loader, epoch, args):
    """Evaluate the autoencoder and style discriminator on one validation pass.

    For every batch of ``eval_data_loader``:

    - The encoder latent is scored by ``dis_model`` via ``ae_model.getSim``.
    - The style embedding of the true label is added to the latent before
      decoding.
    - The reconstruction loss, discriminator BCE loss and discriminator
      accuracy (threshold 0.5) are recorded.

    Both models are left in eval mode on return.

    Args:
        ae_model: autoencoder exposing ``getLatent``, ``getSim`` and
            ``getOutput``.
        dis_model: style discriminator applied to the similarity scores.
        eval_data_loader: loader exposing ``num_batch`` and ``next_batch()``.
        epoch: epoch index, used only for logging.
        args: namespace providing ``vocab_size``, ``id_pad`` and ``gpu``.

    Returns:
        tuple: ``(mean reconstruction loss, mean discriminator loss, mean
        discriminator accuracy)`` over all batches.
    """
    print("Transformer Validation process....")
    ae_model.eval()
    # The discriminator must be in eval mode too, so that train/eval-dependent
    # layers (e.g. dropout) behave deterministically during validation.
    dis_model.eval()

    print('-' * 94)
    epoch_start = time.time()

    loss_ae = list()
    loss_dis = list()
    acc = list()

    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1), args.gpu)
    # `size_average=True` is deprecated; reduction='mean' is equivalent.
    dis_criterion = torch.nn.BCELoss(reduction='mean')
    # Loop-invariant targets for thresholding the discriminator output.
    one = get_cuda(torch.tensor(1), args.gpu)
    zero = get_cuda(torch.tensor(0), args.gpu)

    # Nothing is optimised here, so do not build autograd graphs.
    with torch.no_grad():
        for it in range(eval_data_loader.num_batch):
            # ---- load data ----
            batch_sentences, tensor_labels, \
                tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
                tensor_tgt_mask, tensor_ntokens = eval_data_loader.next_batch()
            # Flatten labels to 1-D; `.squeeze()` would yield a 0-d tensor
            # (and a non-list `.tolist()`) when the batch size is 1.
            flat_labels = tensor_labels.reshape(-1)

            latent = ae_model.getLatent(tensor_src, tensor_src_mask)
            style, similarity = ae_model.getSim(latent)
            dis_out = dis_model.forward(similarity)
            style_pred = torch.where(dis_out > 0.5, one, zero)
            style_pred = style_pred.reshape(style_pred.size(0))
            # Pick, per example, the style embedding of its true label.
            style_emb = get_cuda(
                style.clone()[torch.arange(style.size(0)),
                              flat_labels.long()], args.gpu)

            add_latent = latent + style_emb
            out = ae_model.getOutput(add_latent, tensor_tgt, tensor_tgt_mask)
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / tensor_ntokens.data

            loss_style = dis_criterion(dis_out, tensor_labels)

            pred = style_pred.to('cpu').detach().tolist()
            true = flat_labels.to('cpu').tolist()

            acc.append(accuracy_score(true, pred))
            loss_ae.append(loss_rec.item())
            loss_dis.append(loss_style.item())

            if it % 200 == 0:
                print(
                    '| epoch {:3d} | {:5d}/{:5d} batches |\n| rec loss {:5.4f} | dis loss {:5.4f} |\n'
                    .format(epoch, it, eval_data_loader.num_batch,
                            loss_rec.item(), loss_style.item()))

    print('| end of epoch {:3d} | time: {:5.2f}s |'.format(
        epoch, (time.time() - epoch_start)))

    return np.mean(loss_ae), np.mean(loss_dis), np.mean(acc)
Beispiel #6
0
def sedat_train(args, ae_model, f, deb):
    """Jointly train the autoencoder, the classifier ``f`` and the debiaser ``deb``.

    For every batch, the latent ``z`` from ``ae_model`` is scored by ``f``.
    Depending on the score, each example goes one of two ways:

    - Examples that cross ``args.sedat_threshold`` are passed through ``deb``
      and penalised with ``alpha * LossSedat + beta * sum(f(z))``. The test
      is ``>=`` when ``args.positive_label == 0`` and ``<`` otherwise.
    - All remaining examples contribute a label-smoothed reconstruction loss.

    The autoencoder is then updated on the sum of the losses actually computed
    for the batch. Models are saved to ``args.current_save_path`` after every
    epoch, and per-epoch statistics are saved at the end.

    Per the original notes, ``z`` is presumably
    ``(n_batch, batch_size, seq_length, latent_size)`` — confirm against
    ``ae_model.forward``. The function returns ``None``.
    """
    # TODO: find a metric to monitor the evolution of training, mainly for
    # the deb model.
    lambda_ = args.sedat_threshold
    alpha, beta = [float(coef) for coef in args.sedat_alpha_beta.split(",")]
    # Optionally train only on negative examples.
    only_on_negative_example = args.sedat_only_on_negative_example
    penalty = args.penalty
    type_penalty = args.type_penalty

    assert penalty in ["lasso", "ridge"]
    assert type_penalty in ["last", "group"]

    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    file_list = [args.train_data_file]
    if os.path.exists(args.val_data_file):
        file_list.append(args.val_data_file)
    train_data_loader.create_batches(args,
                                     file_list,
                                     if_shuffle=True,
                                     n_samples=args.train_n_samples)

    add_log(args, "Start train process.")

    ae_model.train()
    f.train()
    deb.train()

    ae_optimizer = get_optimizer(parameters=ae_model.parameters(),
                                 s=args.ae_optimizer,
                                 noamopt=args.ae_noamopt)
    dis_optimizer = get_optimizer(parameters=f.parameters(),
                                  s=args.dis_optimizer)
    deb_optimizer = get_optimizer(parameters=deb.parameters(),
                                  s=args.dis_optimizer)

    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1), args)
    # `size_average=True` is deprecated; reduction='mean' is equivalent.
    dis_criterion = nn.BCELoss(reduction='mean')
    deb_criterion = LossSedat(penalty=penalty)

    stats = []
    for epoch in range(args.max_epochs):
        print('-' * 94)
        epoch_start_time = time.time()

        loss_ae, n_words_ae, xe_loss_ae, n_valid_ae = 0, 0, 0, 0
        loss_clf, total_clf, n_valid_clf = 0, 0, 0
        for it in range(train_data_loader.num_batch):
            _, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, _ = train_data_loader.next_batch()

            if only_on_negative_example:
                negative_examples = ~(tensor_labels.squeeze()
                                      == args.positive_label)
                tensor_labels = tensor_labels[negative_examples].squeeze(0)
                tensor_src = tensor_src[negative_examples].squeeze(0)
                tensor_src_mask = tensor_src_mask[negative_examples].squeeze(0)
                tensor_src_attn_mask = tensor_src_attn_mask[
                    negative_examples].squeeze(0)
                tensor_tgt_y = tensor_tgt_y[negative_examples].squeeze(0)
                tensor_tgt = tensor_tgt[negative_examples].squeeze(0)
                tensor_tgt_mask = tensor_tgt_mask[negative_examples].squeeze(0)
                # Nothing left to train on in this batch.
                if not negative_examples.any():
                    continue

            # ---- forward ----
            z, out, z_list = ae_model.forward(tensor_src,
                                              tensor_tgt,
                                              tensor_src_mask,
                                              tensor_src_attn_mask,
                                              tensor_tgt_mask,
                                              return_intermediate=True)
            y_hat = f.forward(z)

            # ---- 1) classifier update ----
            # NOTE(review): dis_optimizer.step() (and deb_optimizer.step()
            # below) modify parameters in place while their graphs are
            # reused by the joint backward; recent PyTorch versions may
            # raise "modified by an inplace operation" — confirm.
            loss_dis = dis_criterion(y_hat, tensor_labels)
            dis_optimizer.zero_grad()
            loss_dis.backward(retain_graph=True)
            dis_optimizer.step()

            dis_lop = f.forward(z)
            t_c = tensor_labels.view(-1).size(0)
            n_v = (dis_lop.round().int() == tensor_labels).sum().item()
            loss_clf += loss_dis.item()
            total_clf += t_c
            n_valid_clf += n_v
            clf_acc = 100. * n_v / (t_c + eps)
            avg_clf_acc = 100. * n_valid_clf / (total_clf + eps)
            avg_clf_loss = loss_clf / (it + 1)

            # Examples whose classifier score crosses the threshold go to
            # the debiaser; the rest are used for reconstruction.
            if args.positive_label == 0:
                mask_deb = y_hat.squeeze() >= lambda_
            else:
                mask_deb = y_hat.squeeze() < lambda_
            has_deb = mask_deb.any()
            has_rec = (~mask_deb).any()

            # ---- 2) debiaser update (if f(z) crosses lambda) ----
            if has_deb:
                y_hat_deb = y_hat[mask_deb]
                if type_penalty == "last":
                    z_deb = z[mask_deb].squeeze(
                        0) if args.batch_size == 1 else z[mask_deb]
                elif type_penalty == "group":
                    # TODO: unit test for batch_size = 1
                    z_deb = z_list[-1][mask_deb]
                z_prime, z_prime_list = deb(z_deb,
                                            mask=None,
                                            return_intermediate=True)
                if type_penalty == "last":
                    z_prime = torch.sum(ae_model.sigmoid(z_prime), dim=1)
                    loss_deb = alpha * deb_criterion(
                        z_deb, z_prime,
                        is_list=False) + beta * y_hat_deb.sum()
                elif type_penalty == "group":
                    z_deb_list = [z_[mask_deb] for z_ in z_list]
                    loss_deb = alpha * deb_criterion(
                        z_deb_list, z_prime_list,
                        is_list=True) + beta * y_hat_deb.sum()

                deb_optimizer.zero_grad()
                loss_deb.backward(retain_graph=True)
                deb_optimizer.step()
            else:
                # Placeholder for logging only; never back-propagated.
                loss_deb = torch.tensor(float("nan"))

            # ---- reconstruction loss on the non-debiased examples ----
            if has_rec:
                out_ = out[~mask_deb]
                tensor_tgt_y_ = tensor_tgt_y[~mask_deb]
                tensor_ntokens = (tensor_tgt_y_ != 0).data.sum().float()
                loss_rec = ae_criterion(
                    out_.contiguous().view(-1, out_.size(-1)),
                    tensor_tgt_y_.contiguous().view(-1)) / (
                        tensor_ntokens.data + eps)
            else:
                # Placeholder for logging only; never back-propagated.
                loss_rec = torch.tensor(float("nan"))

            # ---- 3) joint autoencoder update ----
            # Only add losses actually computed for this batch: adding a NaN
            # placeholder would turn every gradient into NaN.
            loss_total = loss_dis
            if has_deb:
                loss_total = loss_total + loss_deb
            if has_rec:
                loss_total = loss_total + loss_rec
            ae_optimizer.zero_grad()
            loss_total.backward()
            ae_optimizer.step()

            # Autoencoder statistics, accumulated exactly once per batch.
            n_v, n_w = get_n_v_w(tensor_tgt_y, out)
            x_e = loss_rec.item() * n_w
            loss_ae += loss_rec.item()
            n_words_ae += n_w
            xe_loss_ae += x_e
            n_valid_ae += n_v
            ae_acc = 100. * n_v / (n_w + eps)
            avg_ae_acc = 100. * n_valid_ae / (n_words_ae + eps)
            avg_ae_loss = loss_ae / (it + 1)
            ae_ppl = np.exp(x_e / (n_w + eps))
            avg_ae_ppl = np.exp(xe_loss_ae / (n_words_ae + eps))

            if it % args.log_interval == 0:
                add_log(args, "")
                add_log(
                    args, 'epoch {:3d} | {:5d}/{:5d} batches |'.format(
                        epoch, it, train_data_loader.num_batch))
                add_log(
                    args,
                    'Train : rec acc {:5.4f} | rec loss {:5.4f} | ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
                    .format(ae_acc, loss_rec.item(), ae_ppl, clf_acc,
                            loss_dis.item()))
                add_log(
                    args,
                    'Train : avg : rec acc {:5.4f} | rec loss {:5.4f} | ppl {:5.4f} |  dis acc {:5.4f} | diss loss {:5.4f} |'
                    .format(avg_ae_acc, avg_ae_loss, avg_ae_ppl,
                            avg_clf_acc, avg_clf_loss))

                add_log(
                    args, "input : %s" %
                    id2text_sentence(tensor_tgt_y[0], args.id_to_word))
                generator_text = ae_model.greedy_decode(
                    z,
                    max_len=args.max_sequence_length,
                    start_id=args.id_bos)
                add_log(
                    args, "gen : %s" %
                    id2text_sentence(generator_text[0], args.id_to_word))
                if has_deb:
                    generator_text_prime = ae_model.greedy_decode(
                        z_prime,
                        max_len=args.max_sequence_length,
                        start_id=args.id_bos)

                    add_log(
                        args, "deb : %s" % id2text_sentence(
                            generator_text_prime[0], args.id_to_word))

        # NOTE(review): a batch with no reconstruction examples contributes
        # NaN to loss_ae / xe_loss_ae, making the epoch's ae loss/ppl NaN.
        s = {}
        L = train_data_loader.num_batch + eps
        s["train_ae_loss"] = loss_ae / L
        s["train_ae_acc"] = 100. * n_valid_ae / (n_words_ae + eps)
        s["train_ae_ppl"] = np.exp(xe_loss_ae / (n_words_ae + eps))
        s["train_clf_loss"] = loss_clf / L
        s["train_clf_acc"] = 100. * n_valid_clf / (total_clf + eps)
        stats.append(s)

        add_log(args, "")
        add_log(
            args, '| end of epoch {:3d} | time: {:5.2f}s |'.format(
                epoch, (time.time() - epoch_start_time)))

        add_log(
            args,
            '| rec acc {:5.4f} | rec loss {:5.4f} | rec ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
            .format(s["train_ae_acc"], s["train_ae_loss"], s["train_ae_ppl"],
                    s["train_clf_acc"], s["train_clf_loss"]))

        # Save all three models at the end of every epoch.
        torch.save(
            ae_model.state_dict(),
            os.path.join(args.current_save_path, 'ae_model_params_deb.pkl'))
        torch.save(
            f.state_dict(),
            os.path.join(args.current_save_path, 'dis_model_params_deb.pkl'))
        torch.save(
            deb.state_dict(),
            os.path.join(args.current_save_path, 'deb_model_params_deb.pkl'))

    add_log(args, "Saving training statistics %s ..." % args.current_save_path)
    torch.save(stats,
               os.path.join(args.current_save_path, 'stats_train_deb.pkl'))
Beispiel #7
0
def _pretrain_loader(args, data_file, n_samples):
    """Build a shuffled ``non_pair_data_loader`` batched over ``data_file``."""
    loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    loader.create_batches(args, [data_file],
                          if_shuffle=True,
                          n_samples=n_samples)
    return loader


def pretrain(args, ae_model, dis_model):
    """Pre-train the autoencoder and discriminator with early stopping.

    Each epoch runs ``train_step`` on the training file, then ``eval_step`` on
    the validation file.

    - When the validation metric selected by ``settings`` improves, both
      models' ``state_dict``s are saved.
    - Training stops once the metric has not improved for more than
      ``decrease_counts_max`` epochs.
    - If ``args.test_data_file`` exists, the models are then evaluated on it.

    Args:
        args: experiment namespace (paths, sizes, optimizer names, epochs).
        ae_model: autoencoder being pre-trained.
        dis_model: discriminator being pre-trained.

    Returns:
        tuple: ``(stats, s_test)``. ``stats`` is a list of per-epoch dicts
        merging train and eval scores. ``s_test`` is the test-set eval dict,
        or ``None`` when there is no test file.
    """
    train_data_loader = _pretrain_loader(args, args.train_data_file,
                                         args.train_n_samples)
    val_data_loader = _pretrain_loader(args, args.val_data_file,
                                       args.valid_n_samples)

    ae_model.train()
    dis_model.train()

    ae_optimizer = get_optimizer(parameters=ae_model.parameters(),
                                 s=args.ae_optimizer,
                                 noamopt=args.ae_noamopt)
    dis_optimizer = get_optimizer(parameters=dis_model.parameters(),
                                  s=args.dis_optimizer)

    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1), args)
    # `size_average=True` is deprecated; reduction='mean' is equivalent.
    dis_criterion = nn.BCELoss(reduction='mean')

    # All metric names a stopping / best-model criterion may refer to.
    possib = [
        "%s_%s" % (i, j) for i, j in itertools.product(
            ["train", "eval"],
            ["ae_loss", "ae_acc", "ae_ppl", "clf_loss", "clf_acc"])
    ]
    stopping_criterion, best_criterion, decrease_counts, decrease_counts_max = settings(
        args, possib)
    metric, biggest = stopping_criterion
    # Flip the sign so that "larger is better" holds for the comparison below.
    factor = 1 if biggest else -1

    stats = []

    add_log(args, "Start train process.")
    for epoch in range(args.max_epochs):
        print('-' * 94)
        add_log(args, "")
        s_train = train_step(args, train_data_loader, ae_model, dis_model,
                             ae_optimizer, dis_optimizer, ae_criterion,
                             dis_criterion, epoch)
        add_log(args, "")
        s_eval = eval_step(args, val_data_loader, ae_model, dis_model,
                           ae_criterion, dis_criterion)
        scores = {**s_train, **s_eval}
        stats.append(scores)
        add_log(args, "")
        if factor * scores[metric] > factor * best_criterion:
            best_criterion = scores[metric]
            add_log(args, "New best validation score: %f" % best_criterion)
            decrease_counts = 0
            # Save the best model so far.
            add_log(args, "Saving model to %s ..." % args.current_save_path)
            torch.save(
                ae_model.state_dict(),
                os.path.join(args.current_save_path, 'ae_model_params.pkl'))
            torch.save(
                dis_model.state_dict(),
                os.path.join(args.current_save_path, 'dis_model_params.pkl'))
        else:
            add_log(
                args, "Not a better validation score (%i / %i)." %
                (decrease_counts, decrease_counts_max))
            decrease_counts += 1
        if decrease_counts > decrease_counts_max:
            add_log(
                args,
                "Stopping criterion has been below its best value for more "
                "than %i epochs. Ending the experiment..." %
                decrease_counts_max)
            break

    s_test = None
    if os.path.exists(args.test_data_file):
        add_log(args, "")
        # NOTE(review): this evaluates the weights from the final epoch, not
        # the best checkpoint saved above — confirm that is intended.
        test_data_loader = _pretrain_loader(args, args.test_data_file,
                                            args.test_n_samples)
        s = eval_step(args, test_data_loader, ae_model, dis_model,
                      ae_criterion, dis_criterion)
        add_log(
            args,
            'Test | rec acc {:5.4f} | rec loss {:5.4f} | rec ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
            .format(s["eval_ae_acc"], s["eval_ae_loss"], s["eval_ae_ppl"],
                    s["eval_clf_acc"], s["eval_clf_loss"]))
        s_test = s
    add_log(args, "")
    add_log(args, "Saving training statistics %s ..." % args.current_save_path)
    torch.save(stats,
               os.path.join(args.current_save_path, 'stats_train_eval.pkl'))
    if s_test is not None:
        torch.save(s_test, os.path.join(args.current_save_path,
                                        'stat_test.pkl'))
    return stats, s_test