def load_model(config, num_train_steps, label_list):
    """Build a BertQueryNER model and its BertAdam optimizer.

    Args:
        config: model configuration; must expose ``learning_rate``,
            ``warmup_proportion`` and ``clip_grad``.
        num_train_steps: total number of optimizer steps, passed as
            ``t_total`` so the warmup schedule can be computed.
        label_list: unused here; kept for interface compatibility with
            callers.

    Returns:
        Tuple of ``(model, optimizer, device, n_gpu)``.
    """
    # Fix: the original did torch.device("cuda") unconditionally, which
    # crashes on CPU-only machines. Fall back to CPU when CUDA is absent.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    n_gpu = torch.cuda.device_count() if use_cuda else 0
    model = BertQueryNER(config, )
    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer: standard BERT recipe — no weight decay on bias
    # and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=config.learning_rate,
                         warmup=config.warmup_proportion,
                         t_total=num_train_steps,
                         max_grad_norm=config.clip_grad)
    return model, optimizer, device, n_gpu
def __init__(self, args: argparse.Namespace):
    """Initialize a model, tokenizer and config.

    Accepts either a fully-parsed ``argparse.Namespace`` (training) or a
    plain dict of saved hyper-parameters (eval mode); the dict is wrapped
    in a namedtuple so attribute access works either way.
    """
    super().__init__()
    if isinstance(args, argparse.Namespace):
        self.save_hyperparameters(args)
        self.args = args
    else:  # eval mode: args is a dict of saved hyper-parameters
        TmpArgs = namedtuple("tmp_args", field_names=list(args.keys()))
        self.args = args = TmpArgs(**args)

    self.bert_dir = args.bert_config_dir
    self.data_dir = self.args.data_dir

    # Same dropout value is applied to hidden states and attention probs.
    bert_config = get_auto_config(
        bert_config_dir=args.bert_config_dir,
        hidden_dropout_prob=args.bert_dropout,
        attention_probs_dropout_prob=args.bert_dropout,
        mrc_dropout=args.mrc_dropout,
    )

    self.model = BertQueryNER(config=bert_config)
    # logging.info(str(self.model))
    logging.info(
        str(args.__dict__ if isinstance(args, argparse.ArgumentParser
                                        ) else args))
    # self.ce_loss = CrossEntropyLoss(reduction="none")
    self.loss_type = args.loss_type
    # self.loss_type = "bce"
    # Only one of the two loss objects is created, selected by loss_type.
    if self.loss_type == "bce":
        self.bce_loss = BCEWithLogitsLoss(reduction="none")
    else:
        self.dice_loss = DiceLoss(with_logits=True, smooth=args.dice_smooth)
    # todo(yuxian): since the match loss is O(n^2) in sequence length, its
    # loss rate may need special tuning relative to start/end losses
    weight_sum = args.weight_start + args.weight_end + args.weight_span
    # Normalize the three loss weights so they sum to 1.
    self.weight_start = args.weight_start / weight_sum
    self.weight_end = args.weight_end / weight_sum
    self.weight_span = args.weight_span / weight_sum
    self.flat_ner = args.flat
    self.span_f1 = QuerySpanF1(flat=self.flat_ner)
    self.chinese = args.chinese
    self.optimizer = args.optimizer
    self.span_loss_candidates = args.span_loss_candidates
# Training-script setup: normalize loss weights, build tokenizer/model and
# the train/dev dataloaders.
# NOTE(review): `args` and `weight_sum` must be defined before this chunk —
# presumably by an argument parser above; confirm against the full file.
args.weight_start = args.weight_start / weight_sum
args.weight_end = args.weight_end / weight_sum
args.weight_span = args.weight_span / weight_sum
bert_path = args.bert_config_dir
json_path = args.data_dir
is_chinese = True
vocab_file = os.path.join(bert_path, "vocab.txt")
tokenizer = BertWordPieceTokenizer(vocab_file=vocab_file)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Same dropout for hidden states and attention probabilities.
bert_config = BertQueryNerConfig.from_pretrained(
    args.bert_config_dir,
    hidden_dropout_prob=args.bert_dropout,
    attention_probs_dropout_prob=args.bert_dropout,
    mrc_dropout=args.mrc_dropout)
model = BertQueryNER.from_pretrained(args.bert_config_dir,
                                     config=bert_config).to(device)
log = Logger(os.path.join(args.output_dir, "all.log"), level='debug')
log.logger.info('开始训练')
train_json_path = os.path.join(json_path, 'mrc-ner.train')
dev_json_path = os.path.join(json_path, 'mrc-ner.dev')
train_dataset = MRCNERDataset(json_path=train_json_path,
                              tokenizer=tokenizer,
                              is_chinese=is_chinese)
train_dataloader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              collate_fn=collate_to_max_length,
                              shuffle=True)
# NOTE(review): this chunk is truncated — the MRCNERDataset call below is
# cut off mid-arguments in the source; the remainder is not visible here.
dev_dataset = MRCNERDataset(json_path=dev_json_path,
# Inference-script setup: load the best-F1 checkpoint and build the test
# dataloader (batch size 1, no shuffling).
args = get_argparse().parse_args()
bert_path = args.bert_config_dir
json_path = args.data_dir
is_chinese = True
vocab_file = os.path.join(bert_path, "vocab.txt")
tokenizer = BertWordPieceTokenizer(vocab_file=vocab_file)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Model/config are loaded from the saved best checkpoint, not bert_path.
output_dir = os.path.join(args.output_dir, "best_f1_checkpoint")
bert_config = BertQueryNerConfig.from_pretrained(
    output_dir,
    hidden_dropout_prob=args.bert_dropout,
    attention_probs_dropout_prob=args.bert_dropout,
    mrc_dropout=args.mrc_dropout)
model = BertQueryNER.from_pretrained(output_dir,
                                     config=bert_config).to(device)
model.eval()
test_json_path = os.path.join(json_path, 'mrc-ner.test')
test_dataset = MRCNERDataset_test(json_path=test_json_path,
                                  tokenizer=tokenizer,
                                  is_chinese=is_chinese)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
)
# Raw test JSON is also loaded directly, e.g. for aligning predictions
# back to the original samples.
all_test_data = json.load(open(test_json_path, encoding="utf-8"))
print(len(all_test_data))
def main(json_path=''):
    """Train BertQueryNER with a dice-style MRC loss and per-epoch validation.

    Args are parsed from a JSON file (``json_path`` argument, or a single
    ``*.json`` CLI argument) or from the command line via HfArgumentParser.
    Each epoch trains over ``train_dataloader``, evaluates span F1 on
    ``eval_dataloader``, and saves a checkpoint named with timestamp + F1.

    NOTE(review): this function requires CUDA (``model.cuda()`` and
    ``batch[i].cuda()`` are unconditional).
    """
    parser = HfArgumentParser((CustomizeArguments, TrainingArguments))
    if json_path:
        custom_args, training_args = parser.parse_json_file(json_file=json_path)
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        custom_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        custom_args, training_args = parser.parse_args_into_dataclasses()

    # Log both to stdout and to the file given in the custom arguments.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout),
                  logging.FileHandler(custom_args.log_file_path)],
    )
    # Only the main process logs at INFO; workers are silenced to WARN.
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
    logger.info(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu} "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info('Description: {}'.format(custom_args.description))
    if json_path:
        logger.info('json file path is : {}'.format(json_path))
        logger.info('json file args are: \n' + open(json_path, 'r').read())

    # last_checkpoint = None
    # if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
    #     last_checkpoint = get_last_checkpoint(training_args.output_dir)
    #     if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
    #         raise ValueError(
    #             f"Output directory ({training_args.output_dir}) already exists and is not empty. "
    #             "Use --overwrite_output_dir to overcome."
    #         )
    #     elif last_checkpoint is not None:
    #         logger.info(
    #             f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
    #             "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
    #         )

    set_seed(training_args.seed)

    config = BertQueryNerConfig.from_pretrained(
        custom_args.config_name_or_path if custom_args.config_name_or_path else custom_args.model_name_or_path,
        # num_labels=custom_args.num_labels
    )
    model = BertQueryNER.from_pretrained(
        custom_args.model_name_or_path,
        config=config
    )
    tokenizer = BertTokenizer.from_pretrained(
        custom_args.tokenizer_name_or_path if custom_args.tokenizer_name_or_path else custom_args.model_name_or_path,
    )
    # data = pd.read_pickle(custom_args.pickle_data_path)
    # # df_train = pd.read_pickle(custom_args.train_pickle_data_path)
    # # df_eval = pd.read_pickle(custom_args.eval_pickle_data_path)
    # train_dataloader, eval_dataloader = gen_dataloader(
    #     df=data,
    #     # df_train=df_train,
    #     # df_eval=df_eval,
    #     tokenizer=tokenizer,
    #     per_device_train_batch_size=training_args.per_device_train_batch_size,
    #     per_device_eval_batch_size=training_args.per_device_eval_batch_size,
    #     test_size=custom_args.test_size,
    #     max_length=custom_args.max_length,
    # )
    # NOTE(review): batch sizes 64/32 are hard-coded here instead of coming
    # from training_args — confirm this is intentional.
    train_dataloader = get_dataloader('train', 64)
    eval_dataloader = get_dataloader('test', 32)

    extra_loss = BCEWithLogitsLoss(reduction="none")
    extra_dice_loss = MRCDiceLoss(with_logits=True)
    # device = training_args.device if torch.cuda.is_available() else 'cpu'
    model = nn.DataParallel(model)
    model = model.cuda()
    total_bt = time.time()

    optimizer = AdamW(model.parameters(), lr = 1e-5, eps = 1e-8 )
    total_steps = len(train_dataloader) * training_args.num_train_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps = 5,
                                                num_training_steps = total_steps)

    # Normalize the three loss weights so they sum to 1.
    weight_sum = custom_args.weight_start + custom_args.weight_end + custom_args.weight_span
    weight_start = custom_args.weight_start / weight_sum
    weight_end = custom_args.weight_end / weight_sum
    weight_span = custom_args.weight_span / weight_sum
    # fgm = FGM(model)

    for e in range(training_args.num_train_epochs):
        logger.info('============= Epoch {:} / {:} =============='.format(e + 1, training_args.num_train_epochs))
        logger.info('Training...')
        bt = time.time()
        total_train_loss = 0
        model.train()
        for step, batch in enumerate(train_dataloader):
            # break
            # Progress log every 50 batches with running average loss.
            if step % 50 == 0 and not step == 0:
                elapsed = format_time(time.time() - bt)
                logger.info(' Batch {:>5,} of {:>5,}. Elapsed: {:}. loss: {}'.format(step, len(train_dataloader), elapsed, total_train_loss/step))
            # Batch layout: (input_ids, token_type_ids, start_labels,
            # end_labels, start_label_mask, end_label_mask, match_labels,
            # label_idx) — assumed from the unpacking below; confirm against
            # the dataset's collate function.
            input_ids = batch[0].cuda()
            token_type_ids = batch[1].cuda()
            start_labels = batch[2].cuda()
            end_labels = batch[3].cuda()
            start_label_mask = batch[4].cuda()
            end_label_mask = batch[5].cuda()
            match_labels = batch[6].cuda()
            # sample_idx = batch[7].cuda()
            label_idx = batch[7].cuda()
            # Non-padding positions (pad id 0) attend; others are masked.
            attention_mask = (input_ids != 0).long()
            model.zero_grad()
            start_logits, end_logits, span_logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
            start_loss, end_loss, match_loss = compute_loss(
                _loss=extra_dice_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                span_logits=span_logits,
                start_labels=start_labels,
                end_labels=end_labels,
                match_labels=match_labels,
                start_label_mask=start_label_mask,
                end_label_mask=end_label_mask
            )
            # Weighted combination of the three MRC losses.
            loss = weight_start * start_loss + weight_end * end_loss + weight_span * match_loss
            # loss = output.loss
            # logits = output.logits
            total_train_loss += loss.item()
            loss.backward()
            # fgm.attack(epsilon=1.2)
            # output_adv = model(
            #     input_ids=input_ids,
            #     attention_mask=attention_mask,
            #     labels=labels
            # )
            # loss_adv = output_adv.loss
            # loss_adv.backward()
            # fgm.restore()
            optimizer.step()
            scheduler.step()
            # if step % 50 == 0 and step != 0:
            #     break
        avg_train_loss = total_train_loss / len(train_dataloader)
        training_time = format_time(time.time() - bt)
        logger.info('Average training loss: {0:.2f}'.format(avg_train_loss))
        logger.info('Training epcoh took: {:}'.format(training_time))

        logger.info('Running Validation...')
        bt = time.time()
        model.eval()
        total_eval_loss = 0
        total_eval_f1 = 0
        total_eval_acc = 0
        total_eval_p = []
        total_eval_l = []
        for batch in eval_dataloader:
            input_ids = batch[0].cuda()
            token_type_ids = batch[1].cuda()
            start_labels = batch[2].cuda()
            end_labels = batch[3].cuda()
            start_label_mask = batch[4].cuda()
            end_label_mask = batch[5].cuda()
            match_labels = batch[6].cuda()
            # sample_idx = batch[7].cuda()
            label_idx = batch[7].cuda()
            attention_mask = (input_ids != 0).long()
            # NOTE(review): the flattened source is ambiguous about whether
            # compute_loss below was also under no_grad; only the forward
            # pass is kept inside here.
            with torch.no_grad():
                start_logits, end_logits, span_logits = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids
                )
            start_loss, end_loss, match_loss = compute_loss(
                _loss=extra_dice_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                span_logits=span_logits,
                start_labels=start_labels,
                end_labels=end_labels,
                match_labels=match_labels,
                start_label_mask=start_label_mask,
                end_label_mask=end_label_mask
            )
            loss = weight_start * start_loss + weight_end * end_loss + weight_span * match_loss
            total_eval_loss += loss.item()
            # Positive logits count as predicted start/end positions.
            start_preds, end_preds = start_logits > 0, end_logits > 0
            eval_f1 = query_span_f1(start_preds, end_preds, span_logits,
                                    start_label_mask, end_label_mask, match_labels)
            # logger.info('eval_f1 : {}'.format(eval_f1))
            total_eval_f1 += eval_f1
            # break
        # logger.info(f'\n{classification_report(total_eval_p, total_eval_l, zero_division=1)}')
        avg_val_f1 = total_eval_f1 / len(eval_dataloader)
        # avg_val_acc = total_eval_acc / len(eval_dataloader)
        logger.info('F1: {0:.2f}'.format(avg_val_f1))
        # logger.info('Acc: {0:.2f}'.format(avg_val_acc))
        avg_val_loss = total_eval_loss / len(eval_dataloader)
        validation_time = format_time(time.time() - bt)
        logger.info('Validation Loss: {0:.2f}'.format(avg_val_loss))
        logger.info('Validation took: {:}'.format(validation_time))

        # Checkpoint name embeds timestamp and (F1 * 100) as an int.
        current_ckpt = training_args.output_dir + '/bert-' + datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + '-f1_' + str(int(avg_val_f1*100)) + '.pth'
        logger.info('Start to save checkpoint named {}'.format(current_ckpt))
        # Deploy mode saves the whole module; otherwise only the state dict.
        # (.module unwraps DataParallel.)
        if custom_args.deploy is True:
            logger.info('>>>>>>>>>>>> saving the model <<<<<<<<<<<<<<')
            torch.save(model.module, current_ckpt)
        else:
            logger.info('>>>>>>>>>>>> saving the state_dict of model <<<<<<<<<<<<<')
            torch.save(model.module.state_dict(), current_ckpt)
class BertLabeling(pl.LightningModule):
    """MLM Trainer.

    LightningModule that trains BertQueryNER with three losses (start
    position, end position, span match) and evaluates with QuerySpanF1.
    """

    def __init__(self, args: argparse.Namespace):
        """Initialize a model, tokenizer and config.

        Accepts either an ``argparse.Namespace`` (training) or a plain
        dict of saved hyper-parameters (eval mode), normalised to an
        attribute-style object.
        """
        super().__init__()
        if isinstance(args, argparse.Namespace):
            self.save_hyperparameters(args)
            self.args = args
        else:  # eval mode: args is a dict of saved hyper-parameters
            TmpArgs = namedtuple("tmp_args", field_names=list(args.keys()))
            self.args = args = TmpArgs(**args)

        self.bert_dir = args.bert_config_dir
        self.data_dir = self.args.data_dir

        # Same dropout value for hidden states and attention probabilities.
        bert_config = get_auto_config(
            bert_config_dir=args.bert_config_dir,
            hidden_dropout_prob=args.bert_dropout,
            attention_probs_dropout_prob=args.bert_dropout,
            mrc_dropout=args.mrc_dropout,
        )

        self.model = BertQueryNER(config=bert_config)
        # logging.info(str(self.model))
        logging.info(
            str(args.__dict__ if isinstance(args, argparse.ArgumentParser
                                            ) else args))
        # self.ce_loss = CrossEntropyLoss(reduction="none")
        self.loss_type = args.loss_type
        # self.loss_type = "bce"
        # Only one of the two loss objects exists, selected by loss_type.
        if self.loss_type == "bce":
            self.bce_loss = BCEWithLogitsLoss(reduction="none")
        else:
            self.dice_loss = DiceLoss(with_logits=True, smooth=args.dice_smooth)
        # todo(yuxian): since the match loss is O(n^2) in sequence length,
        # its loss rate may need special tuning relative to start/end losses
        weight_sum = args.weight_start + args.weight_end + args.weight_span
        # Normalize the three loss weights so they sum to 1.
        self.weight_start = args.weight_start / weight_sum
        self.weight_end = args.weight_end / weight_sum
        self.weight_span = args.weight_span / weight_sum
        self.flat_ner = args.flat
        self.span_f1 = QuerySpanF1(flat=self.flat_ner)
        self.chinese = args.chinese
        self.optimizer = args.optimizer
        self.span_loss_candidates = args.span_loss_candidates

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Add model-specific CLI options to ``parent_parser`` and return
        the extended parser."""
        parser = argparse.ArgumentParser(parents=[parent_parser],
                                         add_help=False)
        parser.add_argument("--mrc_dropout",
                            type=float,
                            default=0.1,
                            help="mrc dropout rate")
        parser.add_argument("--bert_dropout",
                            type=float,
                            default=0.1,
                            help="bert dropout rate")
        parser.add_argument("--weight_start", type=float, default=1.0)
        parser.add_argument("--weight_end", type=float, default=1.0)
        parser.add_argument("--weight_span", type=float, default=1.0)
        parser.add_argument("--flat", action="store_true", help="is flat ner")
        parser.add_argument(
            "--span_loss_candidates",
            choices=["all", "pred_and_gold", "gold"],
            default="all",
            help="Candidates used to compute span loss",
        )
        parser.add_argument("--chinese",
                            action="store_true",
                            help="is chinese dataset")
        parser.add_argument("--loss_type",
                            choices=["bce", "dice"],
                            default="bce",
                            help="loss type")
        parser.add_argument("--optimizer",
                            choices=["adamw", "sgd"],
                            default="adamw",
                            help="loss type")
        parser.add_argument("--dice_smooth",
                            type=float,
                            default=1e-8,
                            help="smooth value of dice loss")
        parser.add_argument(
            "--final_div_factor",
            type=float,
            default=1e4,
            help="final div factor of linear decay scheduler",
        )
        return parser

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        # No weight decay for bias / LayerNorm weights (standard BERT recipe).
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        if self.optimizer == "adamw":
            optimizer = AdamW(
                optimizer_grouped_parameters,
                betas=(0.9, 0.98),  # according to RoBERTa paper
                lr=self.args.lr,
                eps=self.args.adam_epsilon,
            )
        else:
            optimizer = SGD(optimizer_grouped_parameters,
                            lr=self.args.lr,
                            momentum=0.9)
        # Count requested GPUs from the (possibly comma-separated) gpus arg.
        num_gpus = len(
            [x for x in str(self.args.gpus).split(",") if x.strip()])
        # Total number of scheduler steps over all epochs.
        t_total = (len(self.train_dataloader()) //
                   (self.args.accumulate_grad_batches * num_gpus) +
                   1) * self.args.max_epochs
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.args.lr,
            pct_start=float(self.args.warmup_steps / t_total),
            final_div_factor=self.args.final_div_factor,
            total_steps=t_total,
            anneal_strategy="linear",
        )
        # Scheduler is stepped per optimizer step, not per epoch.
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Run the underlying BertQueryNER model; returns its outputs
        (start/end/span logits — see training_step usage)."""
        return self.model(input_ids,
                          attention_mask=attention_mask,
                          token_type_ids=token_type_ids)

    def compute_loss(
        self,
        start_logits,
        end_logits,
        span_logits,
        start_labels,
        end_labels,
        match_labels,
        start_label_mask,
        end_label_mask,
    ):
        """Compute masked start, end and span-match losses.

        Positions outside the label masks are excluded; the span-match
        mask additionally keeps only upper-triangular (start <= end)
        candidate pairs, optionally restricted by span_loss_candidates.
        Returns ``(start_loss, end_loss, match_loss)``.
        """
        batch_size, seq_len = start_logits.size()

        start_float_label_mask = start_label_mask.view(-1).float()
        end_float_label_mask = end_label_mask.view(-1).float()
        # Build a [bsz, seq, seq] mask of valid (start, end) pairs.
        match_label_row_mask = (start_label_mask.bool().unsqueeze(-1).expand(
            -1, -1, seq_len))
        match_label_col_mask = (end_label_mask.bool().unsqueeze(-2).expand(
            -1, seq_len, -1))
        match_label_mask = match_label_row_mask & match_label_col_mask
        match_label_mask = torch.triu(match_label_mask,
                                      0)  # start should be less equal to end

        if self.span_loss_candidates == "all":
            # naive mask
            float_match_label_mask = match_label_mask.view(batch_size,
                                                           -1).float()
        else:
            # use only pred or golden start/end to compute match loss
            start_preds = start_logits > 0
            end_preds = end_logits > 0
            if self.span_loss_candidates == "gold":
                match_candidates = (start_labels.unsqueeze(-1).expand(
                    -1, -1, seq_len) > 0) & (end_labels.unsqueeze(-2).expand(
                        -1, seq_len, -1) > 0)
            else:
                # "pred_and_gold": union of predicted and gold span pairs.
                match_candidates = torch.logical_or(
                    (start_preds.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_preds.unsqueeze(-2).expand(-1, seq_len, -1)),
                    (start_labels.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_labels.unsqueeze(-2).expand(-1, seq_len, -1)),
                )
            match_label_mask = match_label_mask & match_candidates
            float_match_label_mask = match_label_mask.view(batch_size,
                                                           -1).float()

        if self.loss_type == "bce":
            # Per-element BCE, masked, then averaged over valid positions.
            start_loss = self.bce_loss(start_logits.view(-1),
                                       start_labels.view(-1).float())
            start_loss = (start_loss * start_float_label_mask
                          ).sum() / start_float_label_mask.sum()
            end_loss = self.bce_loss(end_logits.view(-1),
                                     end_labels.view(-1).float())
            end_loss = (end_loss * end_float_label_mask
                        ).sum() / end_float_label_mask.sum()
            match_loss = self.bce_loss(
                span_logits.view(batch_size, -1),
                match_labels.view(batch_size, -1).float(),
            )
            match_loss = match_loss * float_match_label_mask
            # 1e-10 guards against an all-masked (empty) candidate set.
            match_loss = match_loss.sum() / (float_match_label_mask.sum() +
                                             1e-10)
        else:
            # Dice loss handles masking internally via its mask argument.
            start_loss = self.dice_loss(start_logits, start_labels.float(),
                                        start_float_label_mask)
            end_loss = self.dice_loss(end_logits, end_labels.float(),
                                      end_float_label_mask)
            match_loss = self.dice_loss(span_logits, match_labels.float(),
                                        float_match_label_mask)
        return start_loss, end_loss, match_loss

    def training_step(self, batch, batch_idx):
        """One training step: forward, compute the weighted sum of the
        three losses, and return it with tensorboard logs."""
        tf_board_logs = {
            "lr": self.trainer.optimizers[0].param_groups[0]["lr"]
        }
        (
            tokens,
            token_type_ids,
            start_labels,
            end_labels,
            start_label_mask,
            end_label_mask,
            match_labels,
            sample_idx,
            label_idx,
        ) = batch

        # num_tasks * [bsz, length, num_labels]
        # Non-padding tokens (id != 0) form the attention mask.
        attention_mask = (tokens != 0).long()
        start_logits, end_logits, span_logits = self(tokens, attention_mask,
                                                     token_type_ids)

        start_loss, end_loss, match_loss = self.compute_loss(
            start_logits=start_logits,
            end_logits=end_logits,
            span_logits=span_logits,
            start_labels=start_labels,
            end_labels=end_labels,
            match_labels=match_labels,
            start_label_mask=start_label_mask,
            end_label_mask=end_label_mask,
        )

        # Weighted combination; weights were normalised to sum to 1.
        total_loss = (self.weight_start * start_loss +
                      self.weight_end * end_loss +
                      self.weight_span * match_loss)

        tf_board_logs[f"train_loss"] = total_loss
        tf_board_logs[f"start_loss"] = start_loss
        tf_board_logs[f"end_loss"] = end_loss
        tf_board_logs[f"match_loss"] = match_loss

        return {"loss": total_loss, "log": tf_board_logs}

    def validation_step(self, batch, batch_idx):
        """One validation step: loss components plus span-F1 counting
        stats for aggregation in validation_epoch_end."""
        output = {}

        (
            tokens,
            token_type_ids,
            start_labels,
            end_labels,
            start_label_mask,
            end_label_mask,
            match_labels,
            sample_idx,
            label_idx,
        ) = batch

        attention_mask = (tokens != 0).long()
        start_logits, end_logits, span_logits = self(tokens, attention_mask,
                                                     token_type_ids)

        start_loss, end_loss, match_loss = self.compute_loss(
            start_logits=start_logits,
            end_logits=end_logits,
            span_logits=span_logits,
            start_labels=start_labels,
            end_labels=end_labels,
            match_labels=match_labels,
            start_label_mask=start_label_mask,
            end_label_mask=end_label_mask,
        )

        total_loss = (self.weight_start * start_loss +
                      self.weight_end * end_loss +
                      self.weight_span * match_loss)

        output[f"val_loss"] = total_loss
        output[f"start_loss"] = start_loss
        output[f"end_loss"] = end_loss
        output[f"match_loss"] = match_loss

        # Positive logits count as predicted start/end positions.
        start_preds, end_preds = start_logits > 0, end_logits > 0
        span_f1_stats = self.span_f1(
            start_preds=start_preds,
            end_preds=end_preds,
            match_logits=span_logits,
            start_label_mask=start_label_mask,
            end_label_mask=end_label_mask,
            match_labels=match_labels,
        )
        output["span_f1_stats"] = span_f1_stats

        return output

    def validation_epoch_end(self, outputs):
        """Aggregate per-step losses and TP/FP/FN counts into epoch-level
        precision, recall and F1."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        tensorboard_logs = {"val_loss": avg_loss}

        # Sum per-step (tp, fp, fn) counts over the whole epoch.
        all_counts = torch.stack([x[f"span_f1_stats"]
                                  for x in outputs]).sum(0)
        span_tp, span_fp, span_fn = all_counts
        # 1e-10 avoids division by zero when there are no spans.
        span_recall = span_tp / (span_tp + span_fn + 1e-10)
        span_precision = span_tp / (span_tp + span_fp + 1e-10)
        span_f1 = (span_precision * span_recall * 2 /
                   (span_recall + span_precision + 1e-10))
        tensorboard_logs[f"span_precision"] = span_precision
        tensorboard_logs[f"span_recall"] = span_recall
        tensorboard_logs[f"span_f1"] = span_f1

        return {"val_loss": avg_loss, "log": tensorboard_logs}

    def test_step(self, batch, batch_idx):
        """Testing reuses the validation step verbatim."""
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs) -> Dict[str, Dict[str, Tensor]]:
        """Testing reuses the validation epoch aggregation verbatim."""
        return self.validation_epoch_end(outputs)

    def train_dataloader(self) -> DataLoader:
        return self.get_dataloader("train")
        # return self.get_dataloader("dev", 100)

    def val_dataloader(self):
        return self.get_dataloader("dev")

    def test_dataloader(self):
        return self.get_dataloader("test")
        # return self.get_dataloader("dev")

    def get_dataloader(self, prefix="train", limit: int = None) -> DataLoader:
        """get training dataloader"""
        """
        load_mmap_dataset
        """
        # Dataset file is <data_dir>/mrc-ner.<prefix>; only the "train"
        # split is shuffled. NOTE(review): vocab_path below is unused.
        json_path = os.path.join(self.data_dir, f"mrc-ner.{prefix}")
        vocab_path = os.path.join(self.bert_dir, "vocab.txt")
        dataset = MRCNERDataset(
            json_path=json_path,
            tokenizer=AutoTokenizer.from_pretrained(self.args.bert_config_dir),
            max_length=self.args.max_length,
            is_chinese=self.chinese,
            pad_to_maxlen=False,
        )

        if limit is not None:
            # Truncate the dataset for quick debugging runs.
            dataset = TruncateDataset(dataset, limit)

        dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.args.batch_size,
            num_workers=self.args.workers,
            shuffle=True if prefix == "train" else False,
            collate_fn=collate_to_max_length,
        )

        return dataloader