def train():
    start_epoch = 0
    checkpoint_path = 'BEST_checkpoint.tar'
    best_loss = float('inf')
    epochs_since_improvement = 0

    if os.path.exists(checkpoint_path):
        print('load checkpoint...')
        checkpoint = torch.load(checkpoint_path)
        start_epoch = checkpoint['epoch'] + 1
        model = checkpoint['model']
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        optimizer = checkpoint['optimizer']
    else:
        print('train from beginning...')
        encoder = Encoder(vocab_size, hid_dim, n_layers, n_heads, pf_dim,
                          EncoderLayer, SelfAttention, PositionwiseFeedforward,
                          drop_out, device)
        decoder = Decoder(vocab_size, hid_dim, n_layers, n_heads, pf_dim,
                          DecoderLayer, SelfAttention, PositionwiseFeedforward,
                          drop_out, device)
        model = Seq2Seq(encoder, decoder, pad_idx, device).to(device)
        optimizer = NoamOpt(hid_dim, 1, 2000, torch.optim.Adam(model.parameters()))
        # Xavier-initialize only when training from scratch, so a loaded
        # checkpoint is not overwritten.
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    train_dataset = AiChallenger2017Dataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               collate_fn=pad_collate,
                                               shuffle=True,
                                               num_workers=num_workers)
    valid_dataset = AiChallenger2017Dataset('valid')
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=batch_size,
                                               collate_fn=pad_collate,
                                               shuffle=True,
                                               num_workers=num_workers)
    print('train size', len(train_dataset), 'valid size', len(valid_dataset))

    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

    for i in range(start_epoch, epoch):
        train_loss = train_epoch(model, train_loader, optimizer, criterion)
        valid_loss = value_epoch(model, valid_loader, criterion)
        print('epoch', i, 'avg train loss', train_loss, 'avg valid loss', valid_loss)

        if valid_loss < best_loss:
            best_loss = valid_loss
            epochs_since_improvement = 0
            save_checkpoint(i, epochs_since_improvement, model, optimizer, best_loss)
        else:
            epochs_since_improvement += 1
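# The save_checkpoint helper is not defined in this snippet. A minimal sketch,
# assuming it mirrors the keys the resume branch above reads ('epoch', 'model',
# 'epochs_since_improvement', 'optimizer'); the 'best_loss' key and the default
# file name are assumptions, not confirmed by the source.
def save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss,
                    checkpoint_path='BEST_checkpoint.tar'):
    state = {
        'epoch': epoch,                                        # read back as checkpoint['epoch']
        'epochs_since_improvement': epochs_since_improvement,
        'model': model,                                        # whole module, as loaded above
        'optimizer': optimizer,                                # whole optimizer wrapper
        'best_loss': best_loss,                                # assumed; not read back above
    }
    torch.save(state, checkpoint_path)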
def train_iters(ae_model, dis_model):
    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    train_data_loader.create_batches(args.train_file_list, args.train_label_list,
                                     if_shuffle=True)
    add_log("Start train process.")
    ae_model.train()
    dis_model.train()

    ae_optimizer = NoamOpt(
        ae_model.src_embed[0].d_model, 1, 2000,
        torch.optim.Adam(ae_model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    dis_optimizer = torch.optim.Adam(dis_model.parameters(), lr=0.0001)

    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size, padding_idx=args.id_pad, smoothing=0.1))
    dis_criterion = nn.BCELoss(reduction='mean')  # size_average=True is deprecated

    for epoch in range(200):
        print('-' * 94)
        epoch_start_time = time.time()
        for it in range(train_data_loader.num_batch):
            batch_sentences, tensor_labels, \
                tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
                tensor_tgt_mask, tensor_ntokens = train_data_loader.next_batch()
            # For debug
            # print(batch_sentences[0])
            # print(tensor_src[0])
            # print(tensor_src_mask[0])
            # print("tensor_src_mask", tensor_src_mask.size())
            # print(tensor_tgt[0])
            # print(tensor_tgt_y[0])
            # print(tensor_tgt_mask[0])
            # print(batch_ntokens)

            # Forward pass
            latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                           tensor_src_mask, tensor_tgt_mask)
            # print(latent.size())  # (batch_size, max_src_seq, d_model)
            # print(out.size())     # (batch_size, max_tgt_seq, vocab_size)

            # Loss calculation and autoencoder update
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / tensor_ntokens.data
            # loss_all = loss_rec + loss_dis
            ae_optimizer.optimizer.zero_grad()
            loss_rec.backward()
            ae_optimizer.step()

            # Classifier update on a detached copy of the latent code
            dis_lop = dis_model.forward(to_var(latent.clone()))
            loss_dis = dis_criterion(dis_lop, tensor_labels)
            dis_optimizer.zero_grad()
            loss_dis.backward()
            dis_optimizer.step()

            if it % 200 == 0:
                add_log(
                    '| epoch {:3d} | {:5d}/{:5d} batches | rec loss {:5.4f} | dis loss {:5.4f} |'
                    .format(epoch, it, train_data_loader.num_batch, loss_rec, loss_dis))
                print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
                generator_text = ae_model.greedy_decode(
                    latent, max_len=args.max_sequence_length, start_id=args.id_bos)
                print(id2text_sentence(generator_text[0], args.id_to_word))

        add_log('| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))

        # Save model
        torch.save(ae_model.state_dict(),
                   args.current_save_path + 'ae_model_params.pkl')
        torch.save(dis_model.state_dict(),
                   args.current_save_path + 'dis_model_params.pkl')
    return
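# NoamOpt is referenced above but not defined in this snippet. A minimal sketch in
# the style of the Annotated Transformer's warmup schedule, consistent with the
# ae_optimizer.optimizer.zero_grad() / ae_optimizer.step() calls above; the actual
# implementation used by this code may differ.
class NoamOpt:
    """Optimizer wrapper implementing the warmup learning-rate schedule from
    "Attention Is All You Need": the rate rises for `warmup` steps, then decays."""

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer  # exposed so callers can do .optimizer.zero_grad()
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        # Set the current learning rate on every param group, then step the inner optimizer.
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        # lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** (-0.5) *
                              min(step ** (-0.5), step * self.warmup ** (-1.5)))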
def train_iters(ae_model, dis_model):
    if args.use_albert:
        tokenizer = BertTokenizer.from_pretrained("clue/albert_chinese_tiny",
                                                  do_lower_case=True)
    elif args.use_tiny_bert:
        tokenizer = AutoTokenizer.from_pretrained(
            "google/bert_uncased_L-2_H-256_A-4", do_lower_case=True)
    elif args.use_distil_bert:
        tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-uncased', do_lower_case=True)
    # tokenizer = BertTokenizer.from_pretrained(args.PRETRAINED_MODEL_NAME, do_lower_case=True)
    tokenizer.add_tokens('[EOS]')
    bos_id = tokenizer.convert_tokens_to_ids(['[CLS]'])[0]
    ae_model.bert_encoder.resize_token_embeddings(len(tokenizer))
    # print("[CLS] ID: ", bos_id)

    print("Load trainData...")
    if args.load_trainData and os.path.exists('./{}_trainData.pkl'.format(args.task)):
        with open('./{}_trainData.pkl'.format(args.task), 'rb') as f:
            trainData = pickle.load(f)
    else:
        trainData = TextDataset(batch_size=args.batch_size,
                                id_bos='[CLS]',
                                id_eos='[EOS]',
                                id_unk='[UNK]',
                                max_sequence_length=args.max_sequence_length,
                                vocab_size=0,
                                file_list=args.train_file_list,
                                label_list=args.train_label_list,
                                tokenizer=tokenizer)
        with open('./{}_trainData.pkl'.format(args.task), 'wb') as f:
            pickle.dump(trainData, f)

    add_log("Start train process.")
    ae_model.train()
    dis_model.train()
    ae_model.to(device)
    dis_model.to(device)

    ''' Fixing or distilling BERT encoder '''
    if args.fix_first_6:
        print("Try fixing first 6 BERT layers")
        for layer in range(6):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters():
                param.requires_grad = False
    elif args.fix_last_6:
        print("Try fixing last 6 BERT layers")
        for layer in range(6, 12):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters():
                param.requires_grad = False
    if args.distill_2:
        print("Get result from layer 2")
        for layer in range(2, 12):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters():
                param.requires_grad = False

    ae_optimizer = NoamOpt(
        ae_model.d_model, 1, 2000,
        torch.optim.Adam(ae_model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    dis_optimizer = torch.optim.Adam(dis_model.parameters(), lr=0.0001)

    # ae_criterion = get_cuda(LabelSmoothing(size=args.vocab_size, padding_idx=args.id_pad, smoothing=0.1))
    ae_criterion = LabelSmoothing(size=ae_model.bert_encoder.config.vocab_size,
                                  padding_idx=0, smoothing=0.1).to(device)
    dis_criterion = nn.BCELoss(reduction='mean')

    history = {'train': []}
    for epoch in range(args.epochs):
        print('-' * 94)
        epoch_start_time = time.time()
        total_rec_loss = 0
        total_dis_loss = 0

        train_data_loader = DataLoader(trainData,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       collate_fn=trainData.collate_fn,
                                       num_workers=4)
        num_batch = len(train_data_loader)
        trange = tqdm(enumerate(train_data_loader), total=num_batch,
                      desc='Training', file=sys.stdout, position=0, leave=True)
        for it, data in trange:
            (batch_sentences, tensor_labels, tensor_src, tensor_src_mask,
             tensor_tgt, tensor_tgt_y, tensor_tgt_mask, tensor_ntokens) = data
            tensor_labels = tensor_labels.to(device)
            tensor_src = tensor_src.to(device)
            tensor_tgt = tensor_tgt.to(device)
            tensor_tgt_y = tensor_tgt_y.to(device)
            tensor_src_mask = tensor_src_mask.to(device)
            tensor_tgt_mask = tensor_tgt_mask.to(device)

            # Forward pass
            latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                           tensor_src_mask, tensor_tgt_mask)

            # Reconstruction loss and autoencoder update
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / tensor_ntokens.data
            ae_optimizer.optimizer.zero_grad()
            loss_rec.backward()
            ae_optimizer.step()

            # Detach the latent so the classifier update does not backprop into the autoencoder
            latent = latent.detach()
            next_latent = latent.to(device)

            # Classifier
            dis_lop = dis_model.forward(next_latent)
            loss_dis = dis_criterion(dis_lop, tensor_labels)
            dis_optimizer.zero_grad()
            loss_dis.backward()
            dis_optimizer.step()

            total_rec_loss += loss_rec.item()
            total_dis_loss += loss_dis.item()
            trange.set_postfix(total_rec_loss=total_rec_loss / (it + 1),
                               total_dis_loss=total_dis_loss / (it + 1))

            if it % 100 == 0:
                add_log(
                    '| epoch {:3d} | {:5d}/{:5d} batches | rec loss {:5.4f} | dis loss {:5.4f} |'
                    .format(epoch, it, num_batch, loss_rec, loss_dis))
                print(id2text_sentence(tensor_tgt_y[0], tokenizer, args.task))
                generator_text = ae_model.greedy_decode(
                    latent, max_len=args.max_sequence_length, start_id=bos_id)
                print(id2text_sentence(generator_text[0], tokenizer, args.task))

        # Save model
        # torch.save(ae_model.state_dict(), args.current_save_path / 'ae_model_params.pkl')
        # torch.save(dis_model.state_dict(), args.current_save_path / 'dis_model_params.pkl')
        history['train'].append({
            'epoch': epoch,
            'total_rec_loss': total_rec_loss / len(trange),
            'total_dis_loss': total_dis_loss / len(trange)
        })
        add_log('| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))

    # Save model
    torch.save(ae_model.state_dict(), args.current_save_path / 'ae_model_params.pkl')
    torch.save(dis_model.state_dict(), args.current_save_path / 'dis_model_params.pkl')
    print("Save in ", args.current_save_path)
    return
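# LabelSmoothing is likewise assumed rather than defined here. A sketch in the
# Annotated Transformer style: it expects log-probabilities of shape (N, vocab)
# and integer targets of shape (N,), and it sums the KL divergence, which is why
# the reconstruction loss above is divided by tensor_ntokens.
class LabelSmoothing(nn.Module):
    """Label-smoothed KL-divergence loss over log-probabilities."""

    def __init__(self, size, padding_idx, smoothing=0.0):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')  # summed, hence the / ntokens above
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size

    def forward(self, x, target):
        # x: (N, size) log-probabilities; target: (N,) token ids
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        true_dist.fill_(self.smoothing / (self.size - 2))   # spread mass over non-target tokens
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0                   # never predict padding
        mask = torch.nonzero(target.data == self.padding_idx, as_tuple=False)
        if mask.numel() > 0:
            true_dist.index_fill_(0, mask.squeeze(-1), 0.0)  # zero out padded positions
        return self.criterion(x, true_dist)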
        print('<<<<< Evaluate loss: %f' % dev_loss)
        # If the model's loss on the dev set this epoch beats the best recorded
        # loss, save the current model and update the best loss.
        if dev_loss < best_dev_loss:
            torch.save(model.state_dict(), SAVE_FILE)
            best_dev_loss = dev_loss
            print('****** Save model done... ******')
        print()


if __name__ == '__main__':
    print('Preparing data')
    data = PrepareData(TRAIN_FILE, DEV_FILE)

    print('>>> Start training')
    train_start = time.time()
    # Loss function
    criterion = LabelSmoothing(TGT_VOCAB, padding_idx=0, smoothing=0.0)
    # Optimizer
    optimizer = NoamOpt(
        D_MODEL, 1, 2000,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

    train(data, model, criterion, optimizer)
    print(f'<<< Training finished, took {time.time() - train_start:.4f} s')

    # Evaluate on the test set
    # print('Start testing')
    # from test import evaluate_test
    # evaluate_test(data, model)
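# For the commented-out test step, the best weights written to SAVE_FILE above can
# be restored before evaluation. A minimal sketch, assuming the same global `model`
# and `SAVE_FILE` names; kept commented out, like the test step itself:
# state_dict = torch.load(SAVE_FILE, map_location='cpu')  # or the training device
# model.load_state_dict(state_dict)
# model.eval()
# evaluate_test(data, model)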
######################################################################################
# End of hyper parameters
######################################################################################
if __name__ == '__main__':
    preparation(args)

    ae_model = get_cuda(make_model(d_vocab=args.vocab_size,
                                   N=args.num_layers_AE,
                                   d_model=args.transformer_model_size,
                                   latent_size=args.latent_size,
                                   gpu=args.gpu,
                                   d_ff=args.transformer_ff_size),
                        args.gpu)
    dis_model = get_cuda(Classifier(1, args), args.gpu)

    ae_optimizer = NoamOpt(
        ae_model.src_embed[0].d_model, 1, 2000,
        torch.optim.Adam(ae_model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    dis_optimizer = torch.optim.Adam(dis_model.parameters(), lr=0.0001)

    if args.load_model:
        # Load models' params from checkpoint
        ae_model.load_state_dict(
            torch.load(args.current_save_path + '/{}_ae_model_params.pkl'.format(args.load_iter),
                       map_location=device))
        dis_model.load_state_dict(
            torch.load(args.current_save_path + '/{}_dis_model_params.pkl'.format(args.load_iter),
                       map_location=device))
        start = args.load_iter + 1
    else:
        start = 0

    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size,
        gpu=args.gpu