Пример #1
0
def load_model(config, num_train_steps, label_list):
    """Build a ``BertQueryNER`` model and its ``BertAdam`` optimizer.

    Args:
        config: model config passed to ``BertQueryNER``; it also supplies
            ``learning_rate``, ``warmup_proportion`` and ``clip_grad`` for the
            optimizer.
        num_train_steps: total number of optimization steps (``t_total`` of
            the BertAdam warmup schedule).
        label_list: unused; kept so existing callers keep working.

    Returns:
        ``(model, optimizer, device, n_gpu)``. The model is moved to ``device``
        and wrapped in ``torch.nn.DataParallel`` when more than one GPU is
        visible. ``n_gpu`` is 0 on a CPU-only machine.
    """
    # Fall back to CPU instead of failing later on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    model = BertQueryNER(config)
    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare the optimizer: biases and LayerNorm parameters get no weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=config.learning_rate,
                         warmup=config.warmup_proportion,
                         t_total=num_train_steps,
                         max_grad_norm=config.clip_grad)

    return model, optimizer, device, n_gpu
    def __init__(self, args: argparse.Namespace):
        """Initialize the model, loss functions and span-F1 metric from ``args``.

        NOTE(review): this definition sits after ``load_model``'s ``return``
        statement, so it is never executed. It mirrors
        ``BertLabeling.__init__``.

        Args:
            args: an ``argparse.Namespace`` when training. Otherwise it is a
                mapping of hyper-parameters (eval mode), converted to a
                namedtuple so the same attribute access works.

        Raises:
            ValueError: if ``weight_start + weight_end + weight_span`` is zero.
        """
        super().__init__()
        if isinstance(args, argparse.Namespace):
            self.save_hyperparameters(args)
            self.args = args
        else:
            # eval mode: hyper-parameters arrive as a mapping
            TmpArgs = namedtuple("tmp_args", field_names=list(args.keys()))
            self.args = args = TmpArgs(**args)

        self.bert_dir = args.bert_config_dir
        self.data_dir = self.args.data_dir

        bert_config = get_auto_config(
            bert_config_dir=args.bert_config_dir,
            hidden_dropout_prob=args.bert_dropout,
            attention_probs_dropout_prob=args.bert_dropout,
            mrc_dropout=args.mrc_dropout,
        )

        self.model = BertQueryNER(config=bert_config)
        # A Namespace exposes its fields through __dict__; the eval-mode
        # namedtuple prints its fields directly.
        logging.info(
            str(args.__dict__ if isinstance(args, argparse.Namespace) else args))
        self.loss_type = args.loss_type
        if self.loss_type == "bce":
            self.bce_loss = BCEWithLogitsLoss(reduction="none")
        else:
            self.dice_loss = DiceLoss(with_logits=True,
                                      smooth=args.dice_smooth)
        # TODO(yuxian): the match loss is computed over O(n^2) start/end pairs,
        # so its loss rate should probably be tuned separately.
        weight_sum = args.weight_start + args.weight_end + args.weight_span
        if weight_sum == 0:
            raise ValueError(
                "weight_start + weight_end + weight_span must be non-zero")
        # Normalize the three loss weights so they sum to 1.
        self.weight_start = args.weight_start / weight_sum
        self.weight_end = args.weight_end / weight_sum
        self.weight_span = args.weight_span / weight_sum
        self.flat_ner = args.flat
        self.span_f1 = QuerySpanF1(flat=self.flat_ner)
        self.chinese = args.chinese
        self.optimizer = args.optimizer
        self.span_loss_candidates = args.span_loss_candidates
Пример #3
0
    # Normalize the start/end/span loss weights in place on ``args``.
    # NOTE(review): weight_sum is defined outside this view; presumably it is
    # weight_start + weight_end + weight_span — confirm.
    args.weight_start = args.weight_start / weight_sum
    args.weight_end = args.weight_end / weight_sum
    args.weight_span = args.weight_span / weight_sum

    bert_path = args.bert_config_dir
    json_path = args.data_dir
    # Hard-coded: the data is always treated as Chinese here.
    is_chinese = True
    # WordPiece tokenizer built from vocab.txt in the BERT config directory.
    vocab_file = os.path.join(bert_path, "vocab.txt")
    tokenizer = BertWordPieceTokenizer(vocab_file=vocab_file)
    # Use the first GPU when available, otherwise run on CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Load the config with the CLI dropout overrides, then the pretrained
    # weights from the same directory, and move the model to ``device``.
    bert_config = BertQueryNerConfig.from_pretrained(
        args.bert_config_dir,
        hidden_dropout_prob=args.bert_dropout,
        attention_probs_dropout_prob=args.bert_dropout,
        mrc_dropout=args.mrc_dropout)
    model = BertQueryNER.from_pretrained(args.bert_config_dir,
                                         config=bert_config).to(device)

    # Log to <output_dir>/all.log. The message below means "start training".
    log = Logger(os.path.join(args.output_dir, "all.log"), level='debug')
    log.logger.info('开始训练')

    # Train/dev splits live in <data_dir>/mrc-ner.train and <data_dir>/mrc-ner.dev.
    train_json_path = os.path.join(json_path, 'mrc-ner.train')
    dev_json_path = os.path.join(json_path, 'mrc-ner.dev')

    # Shuffled training loader. collate_to_max_length presumably pads each
    # batch to its longest sample — confirm.
    train_dataset = MRCNERDataset(json_path=train_json_path,
                                  tokenizer=tokenizer,
                                  is_chinese=is_chinese)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  collate_fn=collate_to_max_length,
                                  shuffle=True)
    dev_dataset = MRCNERDataset(json_path=dev_json_path,
Пример #4
0
    # Evaluate the best-F1 checkpoint on the test split.
    args = get_argparse().parse_args()
    bert_path = args.bert_config_dir
    json_path = args.data_dir
    is_chinese = True  # hard-coded: the test data is treated as Chinese
    vocab_file = os.path.join(bert_path, "vocab.txt")
    tokenizer = BertWordPieceTokenizer(vocab_file=vocab_file)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Both the config and the weights come from the saved best-F1 checkpoint.
    output_dir = os.path.join(args.output_dir, "best_f1_checkpoint")

    bert_config = BertQueryNerConfig.from_pretrained(
        output_dir,
        hidden_dropout_prob=args.bert_dropout,
        attention_probs_dropout_prob=args.bert_dropout,
        mrc_dropout=args.mrc_dropout)
    model = BertQueryNER.from_pretrained(output_dir,
                                         config=bert_config).to(device)
    model.eval()

    test_json_path = os.path.join(json_path, 'mrc-ner.test')
    test_dataset = MRCNERDataset_test(json_path=test_json_path,
                                      tokenizer=tokenizer,
                                      is_chinese=is_chinese)
    test_dataloader = DataLoader(test_dataset, batch_size=1)

    # Load the raw test records; the context manager closes the file handle
    # as soon as parsing is done.
    with open(test_json_path, encoding="utf-8") as test_file:
        all_test_data = json.load(test_file)

    print(len(all_test_data))
Пример #5
0
def _forward_batch(model, batch, loss_fn, weights):
    """Run one MRC-NER batch through ``model`` and compute its weighted loss.

    Args:
        model: the CUDA-resident (DataParallel-wrapped) model.
        batch: indexable batch whose items 0..6 are input_ids, token_type_ids,
            start_labels, end_labels, start_label_mask, end_label_mask and
            match_labels, in that order.
        loss_fn: loss object forwarded to ``compute_loss`` as ``_loss``.
        weights: ``(weight_start, weight_end, weight_span)``.

    Returns:
        ``(loss, start_logits, end_logits, span_logits, start_label_mask,
        end_label_mask, match_labels)``. The masks and labels are the CUDA
        copies.
    """
    (input_ids, token_type_ids, start_labels, end_labels,
     start_label_mask, end_label_mask, match_labels) = (
        batch[i].cuda() for i in range(7))

    # Token id 0 is treated as padding.
    attention_mask = (input_ids != 0).long()

    start_logits, end_logits, span_logits = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )

    start_loss, end_loss, match_loss = compute_loss(
        _loss=loss_fn,
        start_logits=start_logits,
        end_logits=end_logits,
        span_logits=span_logits,
        start_labels=start_labels,
        end_labels=end_labels,
        match_labels=match_labels,
        start_label_mask=start_label_mask,
        end_label_mask=end_label_mask,
    )

    weight_start, weight_end, weight_span = weights
    loss = weight_start * start_loss + weight_end * end_loss + weight_span * match_loss
    return (loss, start_logits, end_logits, span_logits,
            start_label_mask, end_label_mask, match_labels)


def main(json_path=''):
    """Fine-tune ``BertQueryNER`` with a manual loop, saving a checkpoint per epoch.

    Arguments are read with ``HfArgumentParser``, in this order of precedence:
    from ``json_path`` if given, else from a single ``.json`` CLI argument,
    else from the command line. After each epoch the model is validated, and
    a checkpoint is written to
    ``<output_dir>/bert-<timestamp>-f1_<int(F1*100)>.pth``. The checkpoint is
    the whole module when ``custom_args.deploy is True``, otherwise its
    state_dict.

    Args:
        json_path: optional path to a JSON file holding all arguments.
    """
    parser = HfArgumentParser((CustomizeArguments, TrainingArguments))

    if json_path:
        custom_args, training_args = parser.parse_json_file(json_file=json_path)
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        custom_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        custom_args, training_args = parser.parse_args_into_dataclasses()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout), logging.FileHandler(custom_args.log_file_path)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    logger.info(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu} "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    logger.info('Description: {}'.format(custom_args.description))
    if json_path:
        logger.info('json file path is : {}'.format(json_path))
        # Close the file right after reading it instead of leaking the handle.
        with open(json_path, 'r') as json_file:
            logger.info('json file args are: \n' + json_file.read())

    set_seed(training_args.seed)

    config = BertQueryNerConfig.from_pretrained(
        custom_args.config_name_or_path or custom_args.model_name_or_path,
    )

    model = BertQueryNER.from_pretrained(
        custom_args.model_name_or_path,
        config=config
    )

    # NOTE(review): the tokenizer is loaded but not used below (the
    # dataloaders come from get_dataloader); kept in case loading it is relied
    # on as a sanity check.
    tokenizer = BertTokenizer.from_pretrained(
        custom_args.tokenizer_name_or_path or custom_args.model_name_or_path,
    )

    # NOTE(review): batch sizes are hard-coded and ignore
    # training_args.per_device_*_batch_size.
    train_dataloader = get_dataloader('train', 64)
    eval_dataloader = get_dataloader('test', 32)
    extra_dice_loss = MRCDiceLoss(with_logits=True)

    model = nn.DataParallel(model)
    model = model.cuda()

    # NOTE(review): lr and warmup are hard-coded rather than taken from training_args.
    optimizer = AdamW(model.parameters(), lr=1e-5, eps=1e-8)

    # TrainingArguments.num_train_epochs is a float (e.g. 3.0), and range()
    # needs an int. Use the same whole epoch count for the LR schedule.
    num_epochs = int(training_args.num_train_epochs)
    total_steps = len(train_dataloader) * num_epochs

    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=5,
                                                num_training_steps=total_steps)

    # Normalize the three loss weights so they sum to 1.
    weight_sum = custom_args.weight_start + custom_args.weight_end + custom_args.weight_span
    weights = (custom_args.weight_start / weight_sum,
               custom_args.weight_end / weight_sum,
               custom_args.weight_span / weight_sum)

    for epoch in range(num_epochs):

        logger.info('============= Epoch {:} / {:} =============='.format(epoch + 1, num_epochs))
        logger.info('Training...')

        bt = time.time()
        total_train_loss = 0
        model.train()

        for step, batch in enumerate(train_dataloader):
            if step and step % 50 == 0:
                elapsed = format_time(time.time() - bt)
                logger.info('  Batch {:>5,}  of  {:>5,}.   Elapsed: {:}.    loss: {}'.format(
                    step, len(train_dataloader), elapsed, total_train_loss / step))

            model.zero_grad()
            loss = _forward_batch(model, batch, extra_dice_loss, weights)[0]
            total_train_loss += loss.item()
            loss.backward()
            optimizer.step()
            scheduler.step()

        # max(..., 1) guards against an empty dataloader.
        avg_train_loss = total_train_loss / max(len(train_dataloader), 1)
        training_time = format_time(time.time() - bt)
        logger.info('Average training loss: {0:.2f}'.format(avg_train_loss))
        logger.info('Training epcoh took: {:}'.format(training_time))

        logger.info('Running Validation...')
        bt = time.time()
        model.eval()

        total_eval_loss = 0
        total_eval_f1 = 0

        for batch in eval_dataloader:
            with torch.no_grad():
                (loss, start_logits, end_logits, span_logits,
                 start_label_mask, end_label_mask, match_labels) = _forward_batch(
                    model, batch, extra_dice_loss, weights)

            total_eval_loss += loss.item()
            # A positive logit means "predicted start/end here".
            start_preds, end_preds = start_logits > 0, end_logits > 0
            # NOTE(review): this averages per-batch F1 rather than pooling
            # TP/FP/FN over the whole split.
            total_eval_f1 += query_span_f1(start_preds, end_preds, span_logits,
                                           start_label_mask, end_label_mask, match_labels)

        num_eval_batches = max(len(eval_dataloader), 1)
        avg_val_f1 = total_eval_f1 / num_eval_batches
        logger.info('F1: {0:.2f}'.format(avg_val_f1))

        avg_val_loss = total_eval_loss / num_eval_batches
        validation_time = format_time(time.time() - bt)

        logger.info('Validation Loss: {0:.2f}'.format(avg_val_loss))
        logger.info('Validation took: {:}'.format(validation_time))

        timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        current_ckpt = os.path.join(
            training_args.output_dir,
            'bert-' + timestamp + '-f1_' + str(int(avg_val_f1 * 100)) + '.pth')
        logger.info('Start to save checkpoint named {}'.format(current_ckpt))
        if custom_args.deploy is True:
            logger.info('>>>>>>>>>>>> saving the model <<<<<<<<<<<<<<')
            torch.save(model.module, current_ckpt)
        else:
            logger.info('>>>>>>>>>>>> saving the state_dict of model <<<<<<<<<<<<<')
            torch.save(model.module.state_dict(), current_ckpt)
class BertLabeling(pl.LightningModule):
    """Lightning module that trains and evaluates ``BertQueryNER`` for MRC-style NER.

    The objective is a weighted sum of start-position, end-position and
    span-match losses (BCE or Dice, selected by ``--loss_type``). Validation
    reports span precision, recall and F1 computed from ``QuerySpanF1`` counts.
    """

    def __init__(self, args: argparse.Namespace):
        """Initialize the model, loss functions and span-F1 metric from ``args``.

        Args:
            args: an ``argparse.Namespace`` when training. Otherwise it is a
                mapping of hyper-parameters (eval mode), converted to a
                namedtuple so the same attribute access works.

        Raises:
            ValueError: if ``weight_start + weight_end + weight_span`` is zero.
        """
        super().__init__()
        if isinstance(args, argparse.Namespace):
            self.save_hyperparameters(args)
            self.args = args
        else:
            # eval mode: hyper-parameters arrive as a mapping
            TmpArgs = namedtuple("tmp_args", field_names=list(args.keys()))
            self.args = args = TmpArgs(**args)

        self.bert_dir = args.bert_config_dir
        self.data_dir = self.args.data_dir

        bert_config = get_auto_config(
            bert_config_dir=args.bert_config_dir,
            hidden_dropout_prob=args.bert_dropout,
            attention_probs_dropout_prob=args.bert_dropout,
            mrc_dropout=args.mrc_dropout,
        )

        self.model = BertQueryNER(config=bert_config)
        # A Namespace exposes its fields through __dict__; the eval-mode
        # namedtuple prints its fields directly.
        logging.info(
            str(args.__dict__ if isinstance(args, argparse.Namespace) else args))
        self.loss_type = args.loss_type
        if self.loss_type == "bce":
            self.bce_loss = BCEWithLogitsLoss(reduction="none")
        else:
            self.dice_loss = DiceLoss(with_logits=True,
                                      smooth=args.dice_smooth)
        # TODO(yuxian): the match loss is computed over O(n^2) start/end pairs,
        # so its loss rate should probably be tuned separately.
        weight_sum = args.weight_start + args.weight_end + args.weight_span
        if weight_sum == 0:
            raise ValueError(
                "weight_start + weight_end + weight_span must be non-zero")
        # Normalize the three loss weights so they sum to 1.
        self.weight_start = args.weight_start / weight_sum
        self.weight_end = args.weight_end / weight_sum
        self.weight_span = args.weight_span / weight_sum
        self.flat_ner = args.flat
        self.span_f1 = QuerySpanF1(flat=self.flat_ner)
        self.chinese = args.chinese
        self.optimizer = args.optimizer
        self.span_loss_candidates = args.span_loss_candidates

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser that extends ``parent_parser`` with this module's options."""
        parser = argparse.ArgumentParser(parents=[parent_parser],
                                         add_help=False)
        parser.add_argument("--mrc_dropout",
                            type=float,
                            default=0.1,
                            help="mrc dropout rate")
        parser.add_argument("--bert_dropout",
                            type=float,
                            default=0.1,
                            help="bert dropout rate")
        parser.add_argument("--weight_start", type=float, default=1.0)
        parser.add_argument("--weight_end", type=float, default=1.0)
        parser.add_argument("--weight_span", type=float, default=1.0)
        parser.add_argument("--flat", action="store_true", help="is flat ner")
        parser.add_argument(
            "--span_loss_candidates",
            choices=["all", "pred_and_gold", "gold"],
            default="all",
            help="Candidates used to compute span loss",
        )
        parser.add_argument("--chinese",
                            action="store_true",
                            help="is chinese dataset")
        parser.add_argument("--loss_type",
                            choices=["bce", "dice"],
                            default="bce",
                            help="loss type")
        parser.add_argument("--optimizer",
                            choices=["adamw", "sgd"],
                            default="adamw",
                            help="optimizer type")
        parser.add_argument("--dice_smooth",
                            type=float,
                            default=1e-8,
                            help="smooth value of dice loss")
        parser.add_argument(
            "--final_div_factor",
            type=float,
            default=1e4,
            help="final div factor of linear decay scheduler",
        )
        return parser

    def configure_optimizers(self):
        """Create the optimizer and a per-step OneCycleLR schedule.

        Biases and LayerNorm weights are excluded from weight decay.
        ``--optimizer`` selects AdamW (betas 0.9/0.98) or SGD with momentum 0.9.

        Returns:
            ``([optimizer], [{"scheduler": scheduler, "interval": "step"}])``.
        """
        no_decay = ["bias", "LayerNorm.weight"]
        named_params = list(self.model.named_parameters())
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in named_params
                           if not any(nd in n for nd in no_decay)],
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": [p for n, p in named_params
                           if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        if self.optimizer == "adamw":
            optimizer = AdamW(
                optimizer_grouped_parameters,
                betas=(0.9, 0.98),  # according to RoBERTa paper
                lr=self.args.lr,
                eps=self.args.adam_epsilon,
            )
        else:
            optimizer = SGD(optimizer_grouped_parameters,
                            lr=self.args.lr,
                            momentum=0.9)
        # Count the comma-separated device ids in ``--gpus`` (e.g. "0,1" -> 2).
        # NOTE(review): an integer count such as gpus=4 is counted as a single
        # entry here — confirm the expected format.
        num_gpus = len(
            [x for x in str(self.args.gpus).split(",") if x.strip()])
        # An empty ``--gpus`` value yields 0 entries; count it as one device
        # so the step computation below does not divide by zero.
        num_gpus = max(num_gpus, 1)
        t_total = (len(self.train_dataloader()) //
                   (self.args.accumulate_grad_batches * num_gpus) +
                   1) * self.args.max_epochs
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.args.lr,
            pct_start=float(self.args.warmup_steps / t_total),
            final_div_factor=self.args.final_div_factor,
            total_steps=t_total,
            anneal_strategy="linear",
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return ``(start_logits, end_logits, span_logits)`` from the wrapped model."""
        return self.model(input_ids,
                          attention_mask=attention_mask,
                          token_type_ids=token_type_ids)

    def compute_loss(
        self,
        start_logits,
        end_logits,
        span_logits,
        start_labels,
        end_labels,
        match_labels,
        start_label_mask,
        end_label_mask,
    ):
        """Compute the start, end and span-match losses for one batch.

        ``start_logits`` is 2-D ``(batch_size, seq_len)``. The span mask is
        built as ``(batch_size, seq_len, seq_len)`` and flattened per example,
        so ``span_logits`` and ``match_labels`` are presumably
        ``(batch_size, seq_len, seq_len)`` — TODO confirm.

        Candidate spans always have ``start <= end``, and both positions must
        pass their label masks. ``span_loss_candidates`` further restricts the
        candidates to gold spans ("gold") or to predicted-or-gold spans
        ("pred_and_gold").

        Returns:
            ``(start_loss, end_loss, match_loss)``.
        """
        batch_size, seq_len = start_logits.size()

        start_float_label_mask = start_label_mask.view(-1).float()
        end_float_label_mask = end_label_mask.view(-1).float()
        # Pair (i, j) is valid when start position i and end position j are valid.
        match_label_row_mask = (start_label_mask.bool().unsqueeze(-1).expand(
            -1, -1, seq_len))
        match_label_col_mask = (end_label_mask.bool().unsqueeze(-2).expand(
            -1, seq_len, -1))
        match_label_mask = match_label_row_mask & match_label_col_mask
        # Keep the upper triangle: a span's start must be <= its end.
        match_label_mask = torch.triu(match_label_mask, 0)

        if self.span_loss_candidates == "all":
            # naive mask
            float_match_label_mask = match_label_mask.view(batch_size,
                                                           -1).float()
        else:
            # Use only predicted or gold start/end pairs to compute the match loss.
            if self.span_loss_candidates == "gold":
                match_candidates = (start_labels.unsqueeze(-1).expand(
                    -1, -1, seq_len) > 0) & (end_labels.unsqueeze(-2).expand(
                        -1, seq_len, -1) > 0)
            else:
                start_preds = start_logits > 0
                end_preds = end_logits > 0
                match_candidates = torch.logical_or(
                    (start_preds.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_preds.unsqueeze(-2).expand(-1, seq_len, -1)),
                    (start_labels.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_labels.unsqueeze(-2).expand(-1, seq_len, -1)),
                )
            match_label_mask = match_label_mask & match_candidates
            float_match_label_mask = match_label_mask.view(batch_size,
                                                           -1).float()
        if self.loss_type == "bce":
            # Per-position BCE averaged over the masked positions only.
            start_loss = self.bce_loss(start_logits.view(-1),
                                       start_labels.view(-1).float())
            start_loss = (start_loss * start_float_label_mask
                          ).sum() / start_float_label_mask.sum()
            end_loss = self.bce_loss(end_logits.view(-1),
                                     end_labels.view(-1).float())
            end_loss = (end_loss * end_float_label_mask
                        ).sum() / end_float_label_mask.sum()
            match_loss = self.bce_loss(
                span_logits.view(batch_size, -1),
                match_labels.view(batch_size, -1).float(),
            )
            match_loss = match_loss * float_match_label_mask
            # The epsilon avoids 0/0 when no span candidate survives the mask.
            match_loss = match_loss.sum() / (float_match_label_mask.sum() +
                                             1e-10)
        else:
            start_loss = self.dice_loss(start_logits, start_labels.float(),
                                        start_float_label_mask)
            end_loss = self.dice_loss(end_logits, end_labels.float(),
                                      end_float_label_mask)
            match_loss = self.dice_loss(span_logits, match_labels.float(),
                                        float_match_label_mask)

        return start_loss, end_loss, match_loss

    def training_step(self, batch, batch_idx):
        """Compute the weighted training loss and the values logged for it."""
        tf_board_logs = {
            "lr": self.trainer.optimizers[0].param_groups[0]["lr"]
        }
        (
            tokens,
            token_type_ids,
            start_labels,
            end_labels,
            start_label_mask,
            end_label_mask,
            match_labels,
            sample_idx,
            label_idx,
        ) = batch

        # Token id 0 is treated as padding.
        attention_mask = (tokens != 0).long()
        start_logits, end_logits, span_logits = self(tokens, attention_mask,
                                                     token_type_ids)

        start_loss, end_loss, match_loss = self.compute_loss(
            start_logits=start_logits,
            end_logits=end_logits,
            span_logits=span_logits,
            start_labels=start_labels,
            end_labels=end_labels,
            match_labels=match_labels,
            start_label_mask=start_label_mask,
            end_label_mask=end_label_mask,
        )

        total_loss = (self.weight_start * start_loss +
                      self.weight_end * end_loss +
                      self.weight_span * match_loss)

        tf_board_logs["train_loss"] = total_loss
        tf_board_logs["start_loss"] = start_loss
        tf_board_logs["end_loss"] = end_loss
        tf_board_logs["match_loss"] = match_loss

        return {"loss": total_loss, "log": tf_board_logs}

    def validation_step(self, batch, batch_idx):
        """Compute validation losses and span TP/FP/FN statistics for one batch."""
        output = {}

        (
            tokens,
            token_type_ids,
            start_labels,
            end_labels,
            start_label_mask,
            end_label_mask,
            match_labels,
            sample_idx,
            label_idx,
        ) = batch

        attention_mask = (tokens != 0).long()

        start_logits, end_logits, span_logits = self(tokens, attention_mask,
                                                     token_type_ids)

        start_loss, end_loss, match_loss = self.compute_loss(
            start_logits=start_logits,
            end_logits=end_logits,
            span_logits=span_logits,
            start_labels=start_labels,
            end_labels=end_labels,
            match_labels=match_labels,
            start_label_mask=start_label_mask,
            end_label_mask=end_label_mask,
        )

        total_loss = (self.weight_start * start_loss +
                      self.weight_end * end_loss +
                      self.weight_span * match_loss)

        output["val_loss"] = total_loss
        output["start_loss"] = start_loss
        output["end_loss"] = end_loss
        output["match_loss"] = match_loss

        # A positive logit means "predicted start/end here".
        start_preds, end_preds = start_logits > 0, end_logits > 0
        span_f1_stats = self.span_f1(
            start_preds=start_preds,
            end_preds=end_preds,
            match_logits=span_logits,
            start_label_mask=start_label_mask,
            end_label_mask=end_label_mask,
            match_labels=match_labels,
        )
        output["span_f1_stats"] = span_f1_stats

        return output

    def validation_epoch_end(self, outputs):
        """Aggregate per-batch outputs into mean loss and span precision/recall/F1."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        tensorboard_logs = {"val_loss": avg_loss}

        # Pool TP/FP/FN over all batches before computing the metrics.
        all_counts = torch.stack([x["span_f1_stats"] for x in outputs]).sum(0)
        span_tp, span_fp, span_fn = all_counts
        span_recall = span_tp / (span_tp + span_fn + 1e-10)
        span_precision = span_tp / (span_tp + span_fp + 1e-10)
        span_f1 = (span_precision * span_recall * 2 /
                   (span_recall + span_precision + 1e-10))
        tensorboard_logs["span_precision"] = span_precision
        tensorboard_logs["span_recall"] = span_recall
        tensorboard_logs["span_f1"] = span_f1

        return {"val_loss": avg_loss, "log": tensorboard_logs}

    def test_step(self, batch, batch_idx):
        """Evaluate a test batch exactly like a validation batch."""
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs) -> Dict[str, Dict[str, Tensor]]:
        """Aggregate test outputs exactly like validation outputs."""
        return self.validation_epoch_end(outputs)

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled loader over the training split."""
        return self.get_dataloader("train")

    def val_dataloader(self):
        """Return the loader over the dev split."""
        return self.get_dataloader("dev")

    def test_dataloader(self):
        """Return the loader over the test split."""
        return self.get_dataloader("test")

    def get_dataloader(self, prefix="train", limit: int = None) -> DataLoader:
        """Build a DataLoader over ``<data_dir>/mrc-ner.<prefix>``.

        Args:
            prefix: split name ("train", "dev" or "test"). Only "train" is
                shuffled.
            limit: if given, the dataset is wrapped in ``TruncateDataset``
                (presumably keeping at most ``limit`` samples).
        """
        json_path = os.path.join(self.data_dir, f"mrc-ner.{prefix}")
        dataset = MRCNERDataset(
            json_path=json_path,
            tokenizer=AutoTokenizer.from_pretrained(self.args.bert_config_dir),
            max_length=self.args.max_length,
            is_chinese=self.chinese,
            pad_to_maxlen=False,
        )

        if limit is not None:
            dataset = TruncateDataset(dataset, limit)

        return DataLoader(
            dataset=dataset,
            batch_size=self.args.batch_size,
            num_workers=self.args.workers,
            shuffle=prefix == "train",
            collate_fn=collate_to_max_length,
        )