Example No. 1
def train():
    start_epoch = 0
    checkpoint_path = 'BEST_checkpoint.tar'
    best_loss = float('inf')
    epochs_since_improvement = 0

    if os.path.exists(checkpoint_path):
        print('load checkpoint...')
        checkpoint = torch.load(checkpoint_path)
        start_epoch = checkpoint['epoch'] + 1
        model = checkpoint['model']
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        optimizer = checkpoint['optimizer']
    else:
        print('train from beginning...')
        encoder = Encoder(vocab_size, hid_dim, n_layers, n_heads, pf_dim, EncoderLayer, SelfAttention, PositionwiseFeedforward, drop_out, device)
        decoder = Decoder(vocab_size, hid_dim, n_layers, n_heads, pf_dim, DecoderLayer, SelfAttention, PositionwiseFeedforward, drop_out, device)

        model = Seq2Seq(encoder, decoder, pad_idx, device).to(device)
        optimizer = NoamOpt(hid_dim, 1, 2000, torch.optim.Adam(model.parameters()))

    train_dataset = AiChallenger2017Dataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, collate_fn=pad_collate,
                                               shuffle=True, num_workers=num_workers)
    valid_dataset = AiChallenger2017Dataset('valid')
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, collate_fn=pad_collate,
                                               shuffle=True, num_workers=num_workers)
    print('train size', len(train_dataset), 'valid size', len(valid_dataset))

    # Initialize parameters with Xavier uniform only when training from scratch;
    # otherwise the weights just loaded from the checkpoint would be overwritten.
    if start_epoch == 0:
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

    for i in range(start_epoch, epoch):
        train_loss = train_epoch(model, train_loader, optimizer, criterion)
        valid_loss = value_epoch(model, valid_loader, criterion)
        print('epoch', i, 'avg train loss', train_loss, 'avg valid loss', valid_loss)
        if valid_loss < best_loss:
            best_loss = valid_loss
            epochs_since_improvement = 0
            save_checkpoint(i, epochs_since_improvement, model, optimizer, best_loss)
        else:
            epochs_since_improvement += 1
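
The save_checkpoint helper is not shown in this example. A minimal sketch of it, consistent with the call above and with the keys the resume branch reads back ('epoch', 'model', 'epochs_since_improvement', 'optimizer'), could look like the following; the 'best_loss' key name is an assumption, since the resume branch never reads the stored loss, and the default path mirrors checkpoint_path above.

import torch

def save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss,
                    path='BEST_checkpoint.tar'):
    # Save whole objects (not state_dicts) to mirror how the resume branch unpacks them.
    state = {
        'epoch': epoch,
        'epochs_since_improvement': epochs_since_improvement,
        'model': model,
        'optimizer': optimizer,
        'best_loss': best_loss,  # assumed key; not read back in the snippet above
    }
    torch.save(state, path)
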
def train_iters(ae_model, dis_model):
    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    train_data_loader.create_batches(args.train_file_list,
                                     args.train_label_list,
                                     if_shuffle=True)
    add_log("Start train process.")
    ae_model.train()
    dis_model.train()

    ae_optimizer = NoamOpt(
        ae_model.src_embed[0].d_model, 1, 2000,
        torch.optim.Adam(ae_model.parameters(),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
    dis_optimizer = torch.optim.Adam(dis_model.parameters(), lr=0.0001)

    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1))
    dis_criterion = nn.BCELoss(reduction='mean')

    for epoch in range(200):
        print('-' * 94)
        epoch_start_time = time.time()
        for it in range(train_data_loader.num_batch):
            batch_sentences, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, tensor_ntokens = train_data_loader.next_batch()

            # For debug
            # print(batch_sentences[0])
            # print(tensor_src[0])
            # print(tensor_src_mask[0])
            # print("tensor_src_mask", tensor_src_mask.size())
            # print(tensor_tgt[0])
            # print(tensor_tgt_y[0])
            # print(tensor_tgt_mask[0])
            # print(batch_ntokens)

            # Forward pass
            latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                           tensor_src_mask, tensor_tgt_mask)
            # print(latent.size())  # (batch_size, max_src_seq, d_model)
            # print(out.size())  # (batch_size, max_tgt_seq, vocab_size)

            # Loss calculation
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / tensor_ntokens.data

            # loss_all = loss_rec + loss_dis

            ae_optimizer.optimizer.zero_grad()
            loss_rec.backward()
            ae_optimizer.step()

            # Classifier
            # Detach the latent code so the classifier update does not backprop into the autoencoder graph.
            dis_lop = dis_model.forward(to_var(latent.clone().detach()))

            loss_dis = dis_criterion(dis_lop, tensor_labels)

            dis_optimizer.zero_grad()
            loss_dis.backward()
            dis_optimizer.step()

            if it % 200 == 0:
                add_log(
                    '| epoch {:3d} | {:5d}/{:5d} batches | rec loss {:5.4f} | dis loss {:5.4f} |'
                    .format(epoch, it, train_data_loader.num_batch, loss_rec,
                            loss_dis))

                print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
                generator_text = ae_model.greedy_decode(
                    latent,
                    max_len=args.max_sequence_length,
                    start_id=args.id_bos)
                print(id2text_sentence(generator_text[0], args.id_to_word))

        add_log('| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))
        # Save model
        torch.save(ae_model.state_dict(),
                   args.current_save_path + 'ae_model_params.pkl')
        torch.save(dis_model.state_dict(),
                   args.current_save_path + 'dis_model_params.pkl')
    return
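
Both of these training loops wrap Adam in NoamOpt(model_size, factor, warmup, optimizer), which is why the code calls ae_optimizer.optimizer.zero_grad() to reach the inner optimizer while ae_optimizer.step() also updates the learning rate. A minimal sketch of that wrapper, assuming it follows the standard Noam schedule from the Annotated Transformer (lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)):

class NoamOpt:
    """Optimizer wrapper that applies the Noam learning-rate schedule."""

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self.model_size = model_size
        self.factor = factor
        self.warmup = warmup
        self._step = 0
        self._rate = 0

    def step(self):
        # Increase the step count, set the new learning rate, then step the inner optimizer.
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        # Linear warmup followed by inverse-square-root decay.
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** (-0.5)
                              * min(step ** (-0.5), step * self.warmup ** (-1.5)))
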
def train_iters(ae_model, dis_model):
    if args.use_albert:
        tokenizer = BertTokenizer.from_pretrained("clue/albert_chinese_tiny",
                                                  do_lower_case=True)
    elif args.use_tiny_bert:
        tokenizer = AutoTokenizer.from_pretrained(
            "google/bert_uncased_L-2_H-256_A-4", do_lower_case=True)
    elif args.use_distil_bert:
        tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-uncased', do_lower_case=True)
    else:
        # Fall back to the configured pretrained tokenizer so `tokenizer` is always defined.
        tokenizer = BertTokenizer.from_pretrained(args.PRETRAINED_MODEL_NAME,
                                                  do_lower_case=True)
    tokenizer.add_tokens('[EOS]')
    bos_id = tokenizer.convert_tokens_to_ids(['[CLS]'])[0]

    ae_model.bert_encoder.resize_token_embeddings(len(tokenizer))
    #print("[CLS] ID: ", bos_id)

    print("Load trainData...")
    if args.load_trainData and os.path.exists('./{}_trainData.pkl'.format(
            args.task)):
        with open('./{}_trainData.pkl'.format(args.task), 'rb') as f:
            trainData = pickle.load(f)
    else:
        trainData = TextDataset(batch_size=args.batch_size,
                                id_bos='[CLS]',
                                id_eos='[EOS]',
                                id_unk='[UNK]',
                                max_sequence_length=args.max_sequence_length,
                                vocab_size=0,
                                file_list=args.train_file_list,
                                label_list=args.train_label_list,
                                tokenizer=tokenizer)
        with open('./{}_trainData.pkl'.format(args.task), 'wb') as f:
            pickle.dump(trainData, f)

    add_log("Start train process.")

    ae_model.train()
    dis_model.train()
    ae_model.to(device)
    dis_model.to(device)
    '''
    Fixing or distilling BERT encoder
    '''
    if args.fix_first_6:
        print("Try fixing first 6 bertlayers")
        for layer in range(6):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters():
                param.requires_grad = False
    elif args.fix_last_6:
        print("Try fixing last 6 bertlayers")
        for layer in range(6, 12):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters():
                param.requires_grad = False

    if args.distill_2:
        print("Get result from layer 2")
        for layer in range(2, 12):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters():
                param.requires_grad = False

    ae_optimizer = NoamOpt(
        ae_model.d_model, 1, 2000,
        torch.optim.Adam(ae_model.parameters(),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
    dis_optimizer = torch.optim.Adam(dis_model.parameters(), lr=0.0001)

    #ae_criterion = get_cuda(LabelSmoothing(size=args.vocab_size, padding_idx=args.id_pad, smoothing=0.1))
    ae_criterion = LabelSmoothing(size=ae_model.bert_encoder.config.vocab_size,
                                  padding_idx=0,
                                  smoothing=0.1).to(device)
    dis_criterion = nn.BCELoss(reduction='mean')

    history = {'train': []}

    for epoch in range(args.epochs):
        print('-' * 94)
        epoch_start_time = time.time()
        total_rec_loss = 0
        total_dis_loss = 0

        train_data_loader = DataLoader(trainData,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       collate_fn=trainData.collate_fn,
                                       num_workers=4)
        num_batch = len(train_data_loader)
        trange = tqdm(enumerate(train_data_loader),
                      total=num_batch,
                      desc='Training',
                      file=sys.stdout,
                      position=0,
                      leave=True)

        for it, data in trange:
            batch_sentences, tensor_labels, tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, tensor_tgt_mask, tensor_ntokens = data

            tensor_labels = tensor_labels.to(device)
            tensor_src = tensor_src.to(device)
            tensor_tgt = tensor_tgt.to(device)
            tensor_tgt_y = tensor_tgt_y.to(device)
            tensor_src_mask = tensor_src_mask.to(device)
            tensor_tgt_mask = tensor_tgt_mask.to(device)

            # Forward pass
            latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                           tensor_src_mask, tensor_tgt_mask)

            # Loss calculation
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / tensor_ntokens.data

            ae_optimizer.optimizer.zero_grad()
            loss_rec.backward()
            ae_optimizer.step()

            latent = latent.detach()
            next_latent = latent.to(device)

            # Classifier
            dis_lop = dis_model.forward(next_latent)
            loss_dis = dis_criterion(dis_lop, tensor_labels)

            dis_optimizer.zero_grad()
            loss_dis.backward()
            dis_optimizer.step()

            total_rec_loss += loss_rec.item()
            total_dis_loss += loss_dis.item()

            trange.set_postfix(total_rec_loss=total_rec_loss / (it + 1),
                               total_dis_loss=total_dis_loss / (it + 1))

            if it % 100 == 0:
                add_log(
                    '| epoch {:3d} | {:5d}/{:5d} batches | rec loss {:5.4f} | dis loss {:5.4f} |'
                    .format(epoch, it, num_batch, loss_rec, loss_dis))

                print(id2text_sentence(tensor_tgt_y[0], tokenizer, args.task))
                generator_text = ae_model.greedy_decode(
                    latent, max_len=args.max_sequence_length, start_id=bos_id)
                print(id2text_sentence(generator_text[0], tokenizer,
                                       args.task))

                # Save model
                #torch.save(ae_model.state_dict(), args.current_save_path / 'ae_model_params.pkl')
                #torch.save(dis_model.state_dict(), args.current_save_path / 'dis_model_params.pkl')

        history['train'].append({
            'epoch': epoch,
            'total_rec_loss': total_rec_loss / len(trange),
            'total_dis_loss': total_dis_loss / len(trange)
        })

        add_log('| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))
        # Save model
        torch.save(ae_model.state_dict(),
                   args.current_save_path / 'ae_model_params.pkl')
        torch.save(dis_model.state_dict(),
                   args.current_save_path / 'dis_model_params.pkl')

    print("Save in ", args.current_save_path)
    return
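
The reconstruction criterion in both versions is LabelSmoothing(size, padding_idx, smoothing), and its output is divided by tensor_ntokens, which matches the Annotated Transformer formulation where the criterion returns a loss summed over tokens. A minimal sketch under that assumption (the repo's actual class may differ in details):

import torch
import torch.nn as nn

class LabelSmoothing(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution."""

    def __init__(self, size, padding_idx, smoothing=0.0):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size  # vocabulary size

    def forward(self, x, target):
        # x: (n_tokens, vocab_size) log-probabilities; target: (n_tokens,) gold token ids
        assert x.size(1) == self.size
        true_dist = x.detach().clone()
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        pad_positions = torch.nonzero(target == self.padding_idx, as_tuple=False)
        if pad_positions.numel() > 0:
            # Zero out rows that correspond to padding tokens.
            true_dist.index_fill_(0, pad_positions.squeeze(-1), 0.0)
        return self.criterion(x, true_dist)
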
Example No. 4
        print('<<<<< Evaluate loss: %f' % dev_loss)
        # If the current epoch's loss on the dev set is better than the best loss recorded so far, save the model and update the best loss
        if dev_loss < best_dev_loss:
            torch.save(model.state_dict(), SAVE_FILE)
            best_dev_loss = dev_loss
            print('****** Save model done... ******')
        print()


if __name__ == '__main__':
    print('Preparing data')
    data = PrepareData(TRAIN_FILE, DEV_FILE)

    print('>>> Start training')
    train_start = time.time()
    # Loss function
    criterion = LabelSmoothing(TGT_VOCAB, padding_idx=0, smoothing=0.0)
    # Optimizer
    optimizer = NoamOpt(
        D_MODEL, 1, 2000,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98),
                         eps=1e-9))

    train(data, model, criterion, optimizer)
    print(f'<<< Training finished, elapsed time {time.time() - train_start:.4f}s')

    # Evaluate on the test set
    # print('Start testing')
    # from test import evaluate_test
    # evaluate_test(data, model)
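
Example No. 4 checkpoints only model.state_dict(), so using the saved weights later means rebuilding the model first and then loading the parameters into it. A minimal sketch of that reload before evaluation, assuming the SAVE_FILE constant and the model object from the snippet above:

import torch

# Reload the best weights saved during training, then switch to eval mode.
# map_location='cpu' is an assumption so the load also works on a CPU-only machine.
state_dict = torch.load(SAVE_FILE, map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
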
Example No. 5
######################################################################################
#  End of hyper parameters
######################################################################################

if __name__ == '__main__':
    preparation(args)

    ae_model = get_cuda(make_model(d_vocab=args.vocab_size,
                                   N=args.num_layers_AE,
                                   d_model=args.transformer_model_size,
                                   latent_size=args.latent_size,
                                   gpu=args.gpu, 
                                   d_ff=args.transformer_ff_size), args.gpu)

    dis_model = get_cuda(Classifier(1, args), args.gpu)
    ae_optimizer = NoamOpt(ae_model.src_embed[0].d_model, 1, 2000,
                           torch.optim.Adam(ae_model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    dis_optimizer = torch.optim.Adam(dis_model.parameters(), lr=0.0001)
    if args.load_model:
        # Load models' params from checkpoint
        ae_model.load_state_dict(torch.load(args.current_save_path + '/{}_ae_model_params.pkl'.format(args.load_iter), map_location=device))
        dis_model.load_state_dict(torch.load(args.current_save_path + '/{}_dis_model_params.pkl'.format(args.load_iter), map_location=device))
       
        start = args.load_iter + 1
    else:
        start = 0

    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size, id_bos=args.id_bos,
        id_eos=args.id_eos, id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length, vocab_size=args.vocab_size,
        gpu=args.gpu