def __init__(self, args):
    super().__init__()
    self.args = args
    self.graph_vocab = GraphVocab(args.graph_vocab_file)
    if args.encoder_type == 'bertology':
        # args.encoder_type selects the encoder type (BERTology/Transformer, etc.)
        # args.bertology_type selects the concrete BERT variant (bert/xlnet/roberta, etc.)
        self.encoder = BERTologyEncoder(no_cuda=not args.cuda,
                                        bertology=args.bertology_type,
                                        bertology_path=args.saved_model_path,
                                        bertology_word_select_mode=args.bertology_word_select,
                                        bertology_output_mode=args.bertology_output_mode,
                                        max_seq_len=args.max_seq_len,
                                        bertology_after=args.bertology_after,
                                        after_layers=args.after_layers,
                                        after_dropout=args.after_dropout)
    elif args.encoder_type in ['lstm', 'gru']:
        self.encoder = None  # Not supported yet  # TODO
    elif args.encoder_type == 'transformer':
        self.encoder = None  # Not supported yet  # TODO
    if args.direct_biaffine:
        self.unlabeled_biaffine = DirectBiaffineScorer(args.encoder_output_dim,
                                                       args.encoder_output_dim,
                                                       1,
                                                       pairwise=True)
        self.labeled_biaffine = DirectBiaffineScorer(args.encoder_output_dim,
                                                     args.encoder_output_dim,
                                                     len(self.graph_vocab.get_labels()),
                                                     pairwise=True)
    else:
        self.unlabeled_biaffine = DeepBiaffineScorer(args.encoder_output_dim,
                                                     args.encoder_output_dim,
                                                     args.biaffine_hidden_dim,
                                                     1,
                                                     pairwise=True,
                                                     dropout=args.biaffine_dropout)
        self.labeled_biaffine = DeepBiaffineScorer(args.encoder_output_dim,
                                                   args.encoder_output_dim,
                                                   args.biaffine_hidden_dim,
                                                   len(self.graph_vocab.get_labels()),
                                                   pairwise=True,
                                                   dropout=args.biaffine_dropout)
    # self.dropout = nn.Dropout(args.dropout)
    if args.learned_loss_ratio:
        self.label_loss_ratio = nn.Parameter(torch.Tensor([0.5]))
    else:
        self.label_loss_ratio = args.label_loss_ratio
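
# NOTE (illustrative sketch, not part of this excerpt): the model's forward() is not shown
# above. A minimal sketch of how the encoder output is typically fed to the two biaffine
# scorers, assuming the encoder returns word-level vectors of shape
# (batch, max_seq_len, encoder_output_dim) and that the scorers return pairwise scores of
# shape (batch, L, L, out_dim); the call signatures below are assumptions, not the repo's API.
def forward(self, inputs):
    hidden = self.encoder(inputs)                                          # (B, L, D), assumed signature
    unlabeled_scores = self.unlabeled_biaffine(hidden, hidden).squeeze(3)  # (B, L, L) arc scores
    labeled_scores = self.labeled_biaffine(hidden, hidden)                 # (B, L, L, num_labels) label scores
    return unlabeled_scores, labeled_scores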

def load_bertology_input(args):
    assert (pathlib.Path(args.saved_model_path) / 'vocab.txt').exists()
    tokenizer = load_bert_tokenizer(args.saved_model_path, args.bertology_type)
    vocab = GraphVocab(args.graph_vocab_file)
    if args.run_mode == 'train':
        tokenizer.save_pretrained(args.output_model_dir)
    if args.run_mode in ['dev', 'inference']:
        dataset, conllu_file = load_and_cache_examples(args,
                                                       args.input_conllu_path,
                                                       vocab,
                                                       tokenizer,
                                                       training=False)
        data_loader = get_data_loader(dataset,
                                      batch_size=args.eval_batch_size,
                                      evaluation=True)
        return data_loader, conllu_file
    elif args.run_mode == 'train':
        train_dataset, train_conllu_file = load_and_cache_examples(
            args, os.path.join(args.data_dir, args.train_file), vocab, tokenizer, training=True)
        dev_dataset, dev_conllu_file = load_and_cache_examples(
            args, os.path.join(args.data_dir, args.dev_file), vocab, tokenizer, training=False)
        train_data_loader = get_data_loader(train_dataset,
                                            batch_size=args.train_batch_size,
                                            evaluation=False)
        dev_data_loader = get_data_loader(dev_dataset,
                                          batch_size=args.eval_batch_size,
                                          evaluation=True)
        return train_data_loader, train_conllu_file, dev_data_loader, dev_conllu_file
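
# Illustrative usage sketch (not from the repo): load_bertology_input returns two values
# in 'dev'/'inference' mode and four in 'train' mode, as defined above. `args` is assumed
# to already carry the attributes the function reads (run_mode, saved_model_path,
# bertology_type, graph_vocab_file, data_dir, train_file, dev_file, input_conllu_path,
# batch sizes, output_model_dir).
if args.run_mode == 'train':
    train_loader, train_conllu, dev_loader, dev_conllu = load_bertology_input(args)
else:  # 'dev' or 'inference'
    eval_loader, eval_conllu = load_bertology_input(args)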

class BiaffineDependencyTrainer(metaclass=ABCMeta):
    def __init__(self, args, model):
        self.model = model
        self.optimizer = self.optim_scheduler = None
        self.graph_vocab = GraphVocab(args.graph_vocab_file)
        self.args = args
        self.logger = get_logger(args.log_name)

    @abstractmethod
    def _unpack_batch(self, args, batch):
        """
        Split a batch into the encoder inputs, the word mask, the sentence lengths, and the dep ids.
        :param args: configuration arguments
        :param batch: a single input batch, a TensorDataset (or torchtext dataset) whose elements can be accessed by index
        :return: a tuple: [1] inputs, a dict; [2] word mask; [3] sentence lengths, a Python list; [4] dep ids
        """
        raise NotImplementedError('must implement in sub class')

    def _custom_train_operations(self, epoch):
        """
        Some models may need custom operations during training; for example, BERT-style models
        may dynamically freeze certain layers while training.
        To support such operations without breaking the generality of BiaffineDependencyTrainer,
        we add this hook, which subclasses of BiaffineDependencyTrainer may override.
        Note that this method is called once at the beginning of every training epoch.
        By default it does nothing.
        :return:
        """
        pass

    def _update_and_predict(self, unlabeled_scores, labeled_scores, unlabeled_target, labeled_target,
                            word_pad_mask, label_loss_ratio=None, sentence_lengths=None,
                            calc_loss=True, update=True, calc_prediction=False):
        """
        For a single input batch: compute the loss, back-propagate, and compute predictions.
        :param word_pad_mask: word-level mask, 1 for PAD, 0 for real input
        :return:
        """
        weights = torch.ones(word_pad_mask.size(0),
                             self.args.max_seq_len,
                             self.args.max_seq_len,
                             dtype=unlabeled_scores.dtype,
                             device=unlabeled_scores.device)
        # Set the weight of PAD positions to 0 and keep 1 elsewhere
        weights = weights.masked_fill(word_pad_mask.unsqueeze(1), 0)
        weights = weights.masked_fill(word_pad_mask.unsqueeze(2), 0)
        # words_num counts the number of words in the batch
        # torch.eq(word_pad_mask, False) gives the word mask
        words_num = torch.sum(torch.eq(word_pad_mask, False)).item()
        if calc_loss:
            assert label_loss_ratio
            assert unlabeled_target is not None and labeled_target is not None
            dep_arc_loss_func = nn.BCEWithLogitsLoss(weight=weights, reduction='sum')
            dep_arc_loss = dep_arc_loss_func(unlabeled_scores, unlabeled_target)
            dep_label_loss_func = nn.CrossEntropyLoss(ignore_index=-1, reduction='sum')
            dependency_mask = labeled_target.eq(0)
            labeled_target = labeled_target.masked_fill(dependency_mask, -1)
            labeled_scores = labeled_scores.contiguous().view(-1, len(self.graph_vocab.get_labels()))
            dep_label_loss = dep_label_loss_func(labeled_scores, labeled_target.view(-1))
            loss = 2 * ((1 - label_loss_ratio) * dep_arc_loss + label_loss_ratio * dep_label_loss)
            if self.args.average_loss_by_words_num:
                loss = loss / words_num
            if self.args.scale_loss:
                loss = loss * self.args.loss_scaling_ratio
            if self.args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if update:
                loss.backward()
                if self.args.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
                self.optimizer.step()
                if self.optim_scheduler:
                    self.optim_scheduler.step()  # Update learning rate schedule
                self.model.zero_grad()
            loss = loss.detach().cpu().item()
        else:
            loss = None
        if calc_prediction:
            assert sentence_lengths
            weights = weights.unsqueeze(3)
            head_probs = torch.sigmoid(unlabeled_scores).unsqueeze(3)
            label_probs = torch.softmax(labeled_scores, dim=3)
            batch_probs = head_probs * label_probs * weights
            batch_probs = batch_probs.detach().cpu().numpy()
            # debug_print(batch_probs)
            sem_graph = sdp_decoder(batch_probs, sentence_lengths)
            sem_sents = parse_semgraph(sem_graph, sentence_lengths)
            batch_prediction = self.graph_vocab.parse_to_sent_batch(sem_sents)
        else:
            batch_prediction = None
        return loss, batch_prediction

    def train(self, train_data_loader, dev_data_loader=None, dev_CoNLLU_file=None):
        self.optimizer, self.optim_scheduler = get_optimizer(self.args, self.model)
        global_step = 0
        best_result = BestResult()
        self.model.zero_grad()
        set_seed(self.args)  # Added here for reproducibility (even between python 2 and 3)
        train_stop = False
        summary_writer = SummaryWriter(log_dir=self.args.summary_dir)
        for epoch in range(1, self.args.max_train_epochs + 1):
            epoch_ave_loss = 0
            train_data_loader = tqdm(train_data_loader, desc=f'Training epoch {epoch}')
            # Some models need custom operations while training; by default this does nothing.
            # See the _custom_train_operations implementations in the subclasses.
            self._custom_train_operations(epoch)
            for step, batch in enumerate(train_data_loader):
                batch = tuple(t.to(self.args.device) for t in batch)
                self.model.train()
                # debug_print(batch)
                # word_mask: word-level, 1 = real input, 0 = PAD
                inputs, word_mask, _, dep_ids = self._unpack_batch(self.args, batch)
                # word_pad_mask: word-level, 1 = PAD, 0 = real input
                word_pad_mask = torch.eq(word_mask, 0)
                unlabeled_scores, labeled_scores = self.model(inputs)
                labeled_target = dep_ids
                unlabeled_target = labeled_target.ge(1).to(unlabeled_scores.dtype)
                # Calc loss and update:
                loss, _ = self._update_and_predict(
                    unlabeled_scores,
                    labeled_scores,
                    unlabeled_target,
                    labeled_target,
                    word_pad_mask,
                    label_loss_ratio=self.model.label_loss_ratio
                    if not self.args.parallel_train else self.model.module.label_loss_ratio,
                    calc_loss=True,
                    update=True,
                    calc_prediction=False)
                global_step += 1
                if loss is not None:
                    epoch_ave_loss += loss
                if global_step % self.args.eval_interval == 0:
                    summary_writer.add_scalar('loss/train', loss, global_step)
                    # Log the learning rate of each parameter group
                    for i, param_group in enumerate(self.optimizer.param_groups):
                        summary_writer.add_scalar(f'lr/group_{i}', param_group['lr'], global_step)
                    if dev_data_loader:
                        UAS, LAS = self.dev(dev_data_loader, dev_CoNLLU_file)
                        summary_writer.add_scalar('metrics/uas', UAS, global_step)
                        summary_writer.add_scalar('metrics/las', LAS, global_step)
                        if best_result.is_new_record(LAS=LAS, UAS=UAS, global_step=global_step):
                            self.logger.info(f"\n## NEW BEST RESULT in epoch {epoch} ##")
                            self.logger.info('\n' + str(best_result))
                            # Save the best model:
                            if hasattr(self.model, 'module'):
                                # Multi-GPU: the model is wrapped by torch.nn.DataParallel
                                self.model.module.save_pretrained(self.args.output_model_dir)
                            else:
                                self.model.save_pretrained(self.args.output_model_dir)
                    if self.args.early_stop and global_step - best_result.best_LAS_step > self.args.early_stop_steps:
                        self.logger.info(f'\n## Early stop in step:{global_step} ##')
                        train_stop = True
                        break
            if train_stop:
                break
            # print(f'\n- Epoch {epoch} average loss : {epoch_ave_loss / len(train_data_loader)}')
            summary_writer.add_scalar('epoch_loss', epoch_ave_loss / len(train_data_loader), epoch)
        with open(self.args.dev_result_path, 'w', encoding='utf-8') as f:
            f.write(str(best_result) + '\n')
        self.logger.info("\n## BEST RESULT in Training ##")
        self.logger.info('\n' + str(best_result))
        summary_writer.close()

    def dev(self, dev_data_loader, dev_CoNLLU_file, input_conllu_path=None, output_conllu_path=None):
        assert isinstance(dev_CoNLLU_file, CoNLLFile)
        if input_conllu_path is None:
            input_conllu_path = os.path.join(self.args.data_dir, self.args.dev_file)
        if output_conllu_path is None:
            output_conllu_path = self.args.dev_output_path
        dev_data_loader = tqdm(dev_data_loader, desc='Evaluation')
        predictions = []
        for step, batch in enumerate(dev_data_loader):
            self.model.eval()
            batch = tuple(t.to(self.args.device) for t in batch)
            inputs, word_mask, sent_lens, dep_ids = self._unpack_batch(self.args, batch)
            word_mask = torch.eq(word_mask, 0)
            unlabeled_scores, labeled_scores = self.model(inputs)
            try:
                with torch.no_grad():
                    _, batch_prediction = self._update_and_predict(
                        unlabeled_scores,
                        labeled_scores,
                        None,
                        None,
                        word_mask,
                        label_loss_ratio=self.model.label_loss_ratio
                        if not self.args.parallel_train else self.model.module.label_loss_ratio,
                        sentence_lengths=sent_lens,
                        calc_loss=False,
                        update=False,
                        calc_prediction=True)
            except Exception as e:
                for b in batch:
                    print(b.shape)
                raise e
            predictions += batch_prediction
            # batch_sent_lens += sent_lens
        dev_CoNLLU_file.set(['deps'], [dep for sent in predictions for dep in sent])
        dev_CoNLLU_file.write_conll(output_conllu_path)
        UAS, LAS = sdp_scorer.score(output_conllu_path, input_conllu_path)
        return UAS, LAS

    def inference(self, inference_data_loader, inference_CoNLLU_file, output_conllu_path):
        inference_data_loader = tqdm(inference_data_loader, desc='Inference')
        predictions = []
        for step, batch in enumerate(inference_data_loader):
            self.model.eval()
            inputs, word_mask, sent_lens, _ = self._unpack_batch(self.args, batch)
            word_mask = torch.eq(word_mask, 0)
            unlabeled_scores, labeled_scores = self.model(inputs)
            with torch.no_grad():
                _, batch_prediction = self._update_and_predict(
                    unlabeled_scores,
                    labeled_scores,
                    None,
                    None,
                    word_mask,
                    label_loss_ratio=self.model.label_loss_ratio
                    if not self.args.parallel_train else self.model.module.label_loss_ratio,
                    sentence_lengths=sent_lens,
                    calc_loss=False,
                    update=False,
                    calc_prediction=True)
            predictions += batch_prediction
        inference_CoNLLU_file.set(['deps'], [dep for sent in predictions for dep in sent])
        inference_CoNLLU_file.write_conll(output_conllu_path)
        return predictions
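
# Hypothetical subclass sketch (not the repository's actual trainer): shows one way a
# concrete subclass could satisfy the _unpack_batch contract described in the abstract
# docstring above. The slot order inside `batch` is an assumption inferred from the
# BERTology demo below (input_ids, attention_mask, token_type_ids, start_pos, ...).
class BERTologyDependencyTrainerSketch(BiaffineDependencyTrainer):
    def _unpack_batch(self, args, batch):
        inputs = {
            'input_ids': batch[0],
            'attention_mask': batch[1],
            'token_type_ids': batch[2] if args.encoder_type in ['bertology', 'xlnet'] else None,
            'start_pos': batch[3],
        }
        word_mask = batch[4]                  # word-level mask: 1 = real word, 0 = PAD (assumed slot)
        sentence_lengths = batch[5].tolist()  # Python list, as the abstract docstring requires
        dep_ids = batch[6]                    # gold dependency arc/label ids (assumed slot)
        return inputs, word_mask, sentence_lengths, dep_ids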

# for Biaffine
self.biaffine_hidden_dim = 300
self.biaffine_dropout = 0.1
# for loss:
self.learned_loss_ratio = True
self.label_loss_ratio = 0.5

args = Args()
if args.encoder_type == 'bertology':
    args.encoder_output_dim = 768
tokenizer = load_bert_tokenizer('/home/liangs/disk/data/bertology-base-chinese', 'bertology')
vocab = GraphVocab('../dataset/graph_vocab.txt')
dataset, CoNLLU_file = load_and_cache_examples(args, vocab, tokenizer)
data_loader = get_data_loader(dataset, batch_size=2, evaluation=True)
# bertology = BERTTypeEncoder(no_cuda=True, bert_path=args.bert_path)
model = BiaffineDependencyModel(args)
print(model)
for batch in data_loader:
    inputs = {
        'input_ids': batch[0],
        'attention_mask': batch[1],
        'token_type_ids': batch[2] if args.encoder_type in ['bertology', 'xlnet'] else None,
        'start_pos': batch[3],

self.data_dir = '/home/liangs/codes/doing_codes/CSDP_Biaffine_Parser_lhy/CSDP_Biaffine_Parser_lhy/dataset'
self.train_file = 'test.conllu'
self.max_seq_len = 10
self.encoder_type = 'bertology'
self.root_representation = 'unused'
self.device = 'cpu'
self.skip_too_long_input = False

args = Args()
# print(f'{}')
tokenizer = load_bert_tokenizer('/home/liangs/disk/data/bertology-base-chinese', 'bertology')
# print(tokenizer)
# print(tokenizer.vocab['[unused1]'])
vocab = GraphVocab('/home/liangs/codes/doing_codes/CSDP_Biaffine_Parser/dataset/graph_vocab.txt')
dataset, CoNLLU_file, _, _, _, _ = load_and_cache_examples(args,
                                                           vocab,
                                                           tokenizer,
                                                           train=True,
                                                           dev=False,
                                                           test=False)
data_loader = get_data_loader(dataset, 1, evaluation=True)
# Original input sentences
print(CoNLLU_file.get(['word'], as_sentences=True))
for batch in data_loader:
    # input ids:
    # print(batch[0])
    # input mask:
    # print(batch[1])