Example #1
import torch
from torch.nn import BCEWithLogitsLoss


def compute_loss(start_logits, end_logits, span_logits, start_labels,
                 end_labels, match_labels, start_label_mask, end_label_mask):
    """Masked BCE losses for start, end, and span-match logits."""
    bce_loss = BCEWithLogitsLoss(reduction="none")

    batch_size, seq_len = start_logits.size()

    start_float_label_mask = start_label_mask.view(-1).float()
    end_float_label_mask = end_label_mask.view(-1).float()
    match_label_row_mask = start_label_mask.bool().unsqueeze(-1).expand(
        -1, -1, seq_len)
    match_label_col_mask = end_label_mask.bool().unsqueeze(-2).expand(
        -1, seq_len, -1)
    match_label_mask = match_label_row_mask & match_label_col_mask
    match_label_mask = torch.triu(match_label_mask,
                                  0)  # start index should be <= end index

    float_match_label_mask = match_label_mask.view(batch_size, -1).float()

    start_loss = bce_loss(start_logits.view(-1), start_labels.view(-1).float())
    start_loss = (start_loss *
                  start_float_label_mask).sum() / start_float_label_mask.sum()
    end_loss = bce_loss(end_logits.view(-1), end_labels.view(-1).float())
    end_loss = (end_loss *
                end_float_label_mask).sum() / end_float_label_mask.sum()
    match_loss = bce_loss(span_logits.view(batch_size, -1),
                          match_labels.view(batch_size, -1).float())
    match_loss = match_loss * float_match_label_mask
    match_loss = match_loss.sum() / (float_match_label_mask.sum() + 1e-10)

    return start_loss, end_loss, match_loss
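
A minimal smoke test for compute_loss (not from the source; all shapes and
values below are hypothetical). Run it after the definition above:

# Hypothetical smoke test: random logits/labels with batch_size=2, seq_len=8.
batch_size, seq_len = 2, 8
start_logits = torch.randn(batch_size, seq_len)
end_logits = torch.randn(batch_size, seq_len)
span_logits = torch.randn(batch_size, seq_len, seq_len)
start_labels = torch.randint(0, 2, (batch_size, seq_len))
end_labels = torch.randint(0, 2, (batch_size, seq_len))
match_labels = torch.randint(0, 2, (batch_size, seq_len, seq_len))
start_label_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
end_label_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

start_loss, end_loss, match_loss = compute_loss(
    start_logits, end_logits, span_logits, start_labels, end_labels,
    match_labels, start_label_mask, end_label_mask)
print(start_loss.item(), end_loss.item(), match_loss.item())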
Example #2
    def __init__(self, args: argparse.Namespace):
        """Initialize a model, tokenizer and config."""
        super().__init__()
        if isinstance(args, argparse.Namespace):
            self.save_hyperparameters(args)
            self.args = args
        else:
            # eval mode
            TmpArgs = namedtuple("tmp_args", field_names=list(args.keys()))
            self.args = args = TmpArgs(**args)

        self.bert_pretrained_model = args.bert_model
        self.data_dir = self.args.data_dir

        # bert_config = BertQueryNerConfig.from_pretrained(args.bert_config_dir,
        #                                                  hidden_dropout_prob=args.bert_dropout,
        #                                                  attention_probs_dropout_prob=args.bert_dropout,
        #                                                  mrc_dropout=args.mrc_dropout)

        phobert_config = PhobertQueryNerConfig.from_pretrained(
            args.bert_model,
            hidden_dropout_prob=args.bert_dropout,
            attention_probs_dropout_prob=args.bert_dropout,
            type_vocab_size=1,
            mrc_dropout=args.mrc_dropout)

        self.model = PhoBertQueryNER.from_pretrained(args.bert_model,
                                                     config=phobert_config)
        if args.freeze_bert:
            self.model.roberta.requires_grad_(False)
        self.tokenizer = AutoTokenizer.from_pretrained(args.bert_model)
        logging.info(str(self.model))
        logging.info(
            str(args.__dict__ if isinstance(args, argparse.Namespace) else args))
        # self.ce_loss = CrossEntropyLoss(reduction="none")
        self.loss_type = args.loss_type
        # self.loss_type = "bce"
        if self.loss_type == "bce":
            self.bce_loss = BCEWithLogitsLoss(reduction="none")
        else:
            self.dice_loss = DiceLoss(with_logits=True,
                                      smooth=args.dice_smooth)
        # todo(yuxian): since the match loss scales as n^2, its loss weight should be specially adjusted
        weight_sum = args.weight_start + args.weight_end + args.weight_span
        self.weight_start = args.weight_start / weight_sum
        self.weight_end = args.weight_end / weight_sum
        self.weight_span = args.weight_span / weight_sum
        self.flat_ner = args.flat
        self.span_f1 = QuerySpanF1(flat=self.flat_ner)
        self.chinese = args.chinese
        self.optimizer = args.optimizer
        self.span_loss_candidates = args.span_loss_candidates
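
The else branch above wraps a plain dict in a namedtuple so attribute access
still works at eval time. A self-contained illustration of that trick (the
keys here are hypothetical):

from collections import namedtuple

cfg = {"loss_type": "bce", "weight_start": 1.0}
TmpArgs = namedtuple("tmp_args", field_names=list(cfg.keys()))
args = TmpArgs(**cfg)
assert args.loss_type == "bce"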
Example #3
    def __init__(
        self,
        span_loss_candidates='all',
        loss_type='bce',
        dice_smooth=1e-8,
    ):
        super(CustomAdaptiveLoss, self).__init__()
        self.loss_type = loss_type
        self.span_loss_candidates = span_loss_candidates

        if self.loss_type == "bce":
            self.bce_loss = BCEWithLogitsLoss(reduction="none")
        else:
            self.dice_loss = DiceLoss(with_logits=True, smooth=dice_smooth)

        self.log_vars = nn.Parameter(torch.zeros(2))
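
Only __init__ of CustomAdaptiveLoss is shown. One plausible use of log_vars,
sketched here as an assumption rather than the source's code, is
homoscedastic-uncertainty weighting of two task losses (Kendall et al.):

    def weighted_sum(self, loss_a, loss_b):
        # Hypothetical helper: scale each loss by its learned precision
        # exp(-log_var) and add log_var as a regularizer.
        precision = torch.exp(-self.log_vars)
        return (precision[0] * loss_a + self.log_vars[0]
                + precision[1] * loss_b + self.log_vars[1])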
Example #4
    def __init__(self, args: argparse.Namespace):
        """Initialize a model, tokenizer and config."""
        super().__init__()
        if isinstance(args, argparse.Namespace):
            self.save_hyperparameters(args)
            self.args = args
        else:
            # eval mode
            TmpArgs = namedtuple("tmp_args", field_names=list(args.keys()))
            self.args = args = TmpArgs(**args)

        self.bert_dir = args.bert_config_dir
        self.data_dir = self.args.data_dir

        bert_config = get_auto_config(
            bert_config_dir=args.bert_config_dir,
            hidden_dropout_prob=args.bert_dropout,
            attention_probs_dropout_prob=args.bert_dropout,
            mrc_dropout=args.mrc_dropout,
        )

        self.model = BertQueryNER(config=bert_config)
        # logging.info(str(self.model))
        logging.info(
            str(args.__dict__ if isinstance(args, argparse.Namespace) else args))
        # self.ce_loss = CrossEntropyLoss(reduction="none")
        self.loss_type = args.loss_type
        # self.loss_type = "bce"
        if self.loss_type == "bce":
            self.bce_loss = BCEWithLogitsLoss(reduction="none")
        else:
            self.dice_loss = DiceLoss(with_logits=True,
                                      smooth=args.dice_smooth)
        # todo(yuxian): since the match loss scales as n^2, its loss weight should be specially adjusted
        weight_sum = args.weight_start + args.weight_end + args.weight_span
        self.weight_start = args.weight_start / weight_sum
        self.weight_end = args.weight_end / weight_sum
        self.weight_span = args.weight_span / weight_sum
        self.flat_ner = args.flat
        self.span_f1 = QuerySpanF1(flat=self.flat_ner)
        self.chinese = args.chinese
        self.optimizer = args.optimizer
        self.span_loss_candidates = args.span_loss_candidates
Example #5
    def __init__(
        self,
        weight_start=1.,
        weight_end=1.,
        weight_span=1.,
        span_loss_candidates='all',
        loss_type='bce',
        dice_smooth=1e-8,
    ):
        super(CustomLoss, self).__init__()
        weight_sum = weight_start + weight_end + weight_span
        self.weight_start = weight_start / weight_sum
        self.weight_end = weight_end / weight_sum
        self.weight_span = weight_span / weight_sum
        self.span_loss_candidates = span_loss_candidates
        self.loss_type = loss_type

        if self.loss_type == "bce":
            self.bce_loss = BCEWithLogitsLoss(reduction="none")
        else:
            self.dice_loss = DiceLoss(with_logits=True, smooth=dice_smooth)
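
Only __init__ is shown here as well; a minimal sketch (an assumption, not the
source's code) of how the normalized weights would combine the three losses:

    def combine(self, start_loss, end_loss, match_loss):
        # Hypothetical helper: the weights were normalized to sum to 1 above.
        return (self.weight_start * start_loss
                + self.weight_end * end_loss
                + self.weight_span * match_loss)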
Example #6
        # [batch, seq_len, seq_len, hidden]
        start_extend = sequence_heatmap.unsqueeze(2).expand(
            -1, -1, seq_len, -1)
        # [batch, seq_len, seq_len, hidden]
        end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
        # [batch, seq_len, seq_len, hidden*2]
        span_matrix = torch.cat([start_extend, end_extend], 3)
        # [batch, seq_len, seq_len]
        span_logits = self.span_embedding(span_matrix).squeeze(-1)

        return start_logits, end_logits, span_logits
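
A standalone shape check of the span-matrix construction above (illustrative
only; batch, sequence, and hidden sizes are made up):

import torch

batch, seq_len, hidden = 2, 5, 4
sequence_heatmap = torch.randn(batch, seq_len, hidden)
start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
span_matrix = torch.cat([start_extend, end_extend], 3)
assert span_matrix.shape == (2, 5, 5, 8)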


span_loss_candidates = "all"
loss_type = "bce"
bce_loss = BCEWithLogitsLoss(reduction="none")
dice_smooth = 1e-8
dice_loss = ""

#dice_loss = DiceLoss(with_logits=True, smooth=dice_smooth)


def compute_loss(start_logits, end_logits, span_logits, start_labels,
                 end_labels, match_labels, start_label_mask, end_label_mask):
    batch_size, seq_len = start_logits.size()

    start_float_label_mask = start_label_mask.view(-1).float()
    end_float_label_mask = end_label_mask.view(-1).float()
    match_label_row_mask = start_label_mask.bool().unsqueeze(-1).expand(
        -1, -1, seq_len)
    match_label_col_mask = end_label_mask.bool().unsqueeze(-2).expand(
        -1, seq_len, -1)
Example #7
def main(json_path=''):

    parser = HfArgumentParser((CustomizeArguments, TrainingArguments))

    if json_path:
        custom_args, training_args = parser.parse_json_file(json_file=json_path)
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        custom_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        custom_args, training_args = parser.parse_args_into_dataclasses()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout), logging.FileHandler(custom_args.log_file_path)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    logger.info(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu} "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    logger.info('Description: {}'.format(custom_args.description))
    if json_path:
        logger.info('json file path is : {}'.format(json_path))
        with open(json_path, 'r') as f:
            logger.info('json file args are: \n' + f.read())

    # last_checkpoint = None
    # if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
    #     last_checkpoint = get_last_checkpoint(training_args.output_dir)
    #     if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
    #         raise ValueError(
    #             f"Output directory ({training_args.output_dir}) already exists and is not empty. "
    #             "Use --overwrite_output_dir to overcome."
    #         )
    #     elif last_checkpoint is not None:
    #         logger.info(
    #             f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
    #             "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
    #         )

    set_seed(training_args.seed)

    config = BertQueryNerConfig.from_pretrained(
        custom_args.config_name_or_path if custom_args.config_name_or_path else custom_args.model_name_or_path,
        # num_labels=custom_args.num_labels
    )

    model = BertQueryNER.from_pretrained(
        custom_args.model_name_or_path,
        config=config
    )

    tokenizer = BertTokenizer.from_pretrained(
        custom_args.tokenizer_name_or_path if custom_args.tokenizer_name_or_path else custom_args.model_name_or_path,
    )

    # data = pd.read_pickle(custom_args.pickle_data_path)
    # # df_train = pd.read_pickle(custom_args.train_pickle_data_path)
    # # df_eval = pd.read_pickle(custom_args.eval_pickle_data_path)
    # train_dataloader, eval_dataloader = gen_dataloader(
    #     df=data,
    #     # df_train=df_train,
    #     # df_eval=df_eval,
    #     tokenizer=tokenizer,
    #     per_device_train_batch_size=training_args.per_device_train_batch_size,
    #     per_device_eval_batch_size=training_args.per_device_eval_batch_size,
    #     test_size=custom_args.test_size,
    #     max_length=custom_args.max_length,
    # )

    train_dataloader = get_dataloader('train', 64)
    eval_dataloader = get_dataloader('test', 32)
    extra_loss = BCEWithLogitsLoss(reduction="none")
    extra_dice_loss = MRCDiceLoss(with_logits=True)

    # device = training_args.device if torch.cuda.is_available() else 'cpu'

    model = nn.DataParallel(model)
    model = model.cuda()
    total_bt = time.time()

    optimizer = AdamW(model.parameters(), lr=1e-5, eps=1e-8)

    total_steps = int(len(train_dataloader) * training_args.num_train_epochs)

    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=5, num_training_steps=total_steps)

    weight_sum = custom_args.weight_start + custom_args.weight_end + custom_args.weight_span
    weight_start = custom_args.weight_start / weight_sum
    weight_end = custom_args.weight_end / weight_sum
    weight_span = custom_args.weight_span / weight_sum
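    # The three weights now sum to 1; e.g. equal inputs yield 1/3 each.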
    # fgm = FGM(model)

    for e in range(int(training_args.num_train_epochs)):

        logger.info('============= Epoch {:} / {:} =============='.format(e + 1, training_args.num_train_epochs))
        logger.info('Training...')

        bt = time.time()
        total_train_loss = 0
        model.train()

        for step, batch in enumerate(train_dataloader):
            # break

            if step % 50 == 0 and step != 0:
                elapsed = format_time(time.time() - bt)
                logger.info('  Batch {:>5,}  of  {:>5,}.   Elapsed: {:}.    loss: {}'.format(step, len(train_dataloader), elapsed, total_train_loss/step))

            input_ids = batch[0].cuda()
            token_type_ids = batch[1].cuda()
            start_labels = batch[2].cuda()
            end_labels = batch[3].cuda()
            start_label_mask = batch[4].cuda()
            end_label_mask = batch[5].cuda()
            match_labels = batch[6].cuda()
            # sample_idx = batch[7].cuda()
            label_idx = batch[7].cuda()

            attention_mask = (input_ids != 0).long()

            model.zero_grad()

            start_logits, end_logits, span_logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )

            start_loss, end_loss, match_loss = compute_loss(
                _loss=extra_dice_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                span_logits=span_logits,
                start_labels=start_labels,
                end_labels=end_labels,
                match_labels=match_labels,
                start_label_mask=start_label_mask,
                end_label_mask=end_label_mask,
            )

            loss = weight_start * start_loss + weight_end * end_loss + weight_span * match_loss

            # loss = output.loss
            # logits = output.logits
            total_train_loss += loss.item()

            loss.backward()

            # fgm.attack(epsilon=1.2)
            # output_adv = model(
            #     input_ids=input_ids,
            #     attention_mask=attention_mask,
            #     labels=labels
            # )

            # loss_adv = output_adv.loss
            # loss_adv.backward()
            # fgm.restore()

            optimizer.step()
            scheduler.step()
            # if step % 50 == 0 and step != 0:
            #     break

        avg_train_loss = total_train_loss / len(train_dataloader)
        training_time = format_time(time.time() - bt)
        logger.info('Average training loss: {0:.2f}'.format(avg_train_loss))
        logger.info('Training epoch took: {:}'.format(training_time))

        logger.info('Running Validation...')
        bt = time.time()
        model.eval()

        total_eval_loss = 0
        total_eval_f1 = 0
        total_eval_acc = 0
        total_eval_p = []
        total_eval_l = []

        for batch in eval_dataloader:

            input_ids = batch[0].cuda()
            token_type_ids = batch[1].cuda()
            start_labels = batch[2].cuda()
            end_labels = batch[3].cuda()
            start_label_mask = batch[4].cuda()
            end_label_mask = batch[5].cuda()
            match_labels = batch[6].cuda()
            # sample_idx = batch[7].cuda()
            label_idx = batch[7].cuda()

            attention_mask = (input_ids != 0).long()

            with torch.no_grad():
                start_logits, end_logits, span_logits = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids
                )

            start_loss, end_loss, match_loss = compute_loss(
                _loss=extra_dice_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                span_logits=span_logits,
                start_labels=start_labels,
                end_labels=end_labels,
                match_labels=match_labels,
                start_label_mask=start_label_mask,
                end_label_mask=end_label_mask,
            )

            loss = weight_start * start_loss + weight_end * end_loss + weight_span * match_loss

            total_eval_loss += loss.item()
            start_preds, end_preds = start_logits > 0, end_logits > 0
            eval_f1 = query_span_f1(start_preds, end_preds, span_logits, start_label_mask, end_label_mask, match_labels)
            # logger.info('eval_f1 : {}'.format(eval_f1))
            total_eval_f1 += eval_f1
            # break


        # logger.info(f'\n{classification_report(total_eval_p, total_eval_l, zero_division=1)}')
        avg_val_f1 = total_eval_f1 / len(eval_dataloader)
        # avg_val_acc = total_eval_acc / len(eval_dataloader)
        logger.info('F1: {0:.2f}'.format(avg_val_f1))
        # logger.info('Acc: {0:.2f}'.format(avg_val_acc))

        avg_val_loss = total_eval_loss / len(eval_dataloader)
        validation_time = format_time(time.time() - bt)

        logger.info('Validation Loss: {0:.2f}'.format(avg_val_loss))
        logger.info('Validation took: {:}'.format(validation_time))

        current_ckpt = (training_args.output_dir + '/bert-'
                        + datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
                        + '-f1_' + str(int(avg_val_f1 * 100)) + '.pth')
        logger.info('Start to save checkpoint named {}'.format(current_ckpt))
        if custom_args.deploy:
            logger.info('>>>>>>>>>>>> saving the model <<<<<<<<<<<<<<')
            torch.save(model.module, current_ckpt)
        else:
            logger.info('>>>>>>>>>>>> saving the state_dict of model <<<<<<<<<<<<<')
            torch.save(model.module.state_dict(), current_ckpt)
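
A hedged restore sketch (not in the source) mirroring the two save branches
above: torch.save(model.module, path) serializes the full module, while
torch.save(model.module.state_dict(), path) stores only the weights.

import torch

def load_checkpoint(path, deploy, model_template=None):
    # Hypothetical helper; `model_template` must be a freshly constructed
    # model of the same architecture when only a state_dict was saved.
    if deploy:
        return torch.load(path)
    model_template.load_state_dict(torch.load(path))
    return model_template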
Example #8
    def __init__(self, args: argparse.Namespace):
        """Initialize a model, tokenizer and config."""
        super().__init__()
        if isinstance(args, argparse.Namespace):
            self.save_hyperparameters(args)
            self.args = args
        else:
            # eval mode
            TmpArgs = namedtuple("tmp_args", field_names=list(args.keys()))
            self.args = args = TmpArgs(**args)

        # self.bert_dir = args.bert_config_dir
        self.data_dir = self.args.data_dir
        self.bert_config_dir = BERT_DIR[self.args.model]
        self.tokenizer = BertWordPieceTokenizer(vocab=self.bert_config_dir +
                                                '/vocab.txt')

        if self.args.model == 'BERTMRC':
            bert_config = BertQueryNerConfig.from_pretrained(
                self.bert_config_dir,
                hidden_dropout_prob=args.bert_dropout,
                attention_probs_dropout_prob=args.bert_dropout,
                mrc_dropout=args.mrc_dropout)

            self.model = BERTModel[self.args.model].from_pretrained(
                self.bert_config_dir, config=bert_config)
        else:
            # self.tokenizer = AutoTokenizer.from_pretrained(self.bert_config_dir, do_lower_case=True)
            # self.model = BertForQuestionAnswering.from_pretrained(self.bert_config_dir)
            self.model = BERTModel[self.args.model](self.bert_config_dir,
                                                    self.args)
        logging.info(str(self.model))
        logging.info(
            str(args.__dict__ if isinstance(args, argparse.Namespace) else args))
        # self.ce_loss = CrossEntropyLoss(reduction="none")
        self.loss_type = args.loss_type
        # self.loss_type = "bce"
        if self.loss_type == "bce":
            self.bce_loss = BCEWithLogitsLoss(reduction="none")
        else:
            self.dice_loss = DiceLoss(with_logits=True,
                                      smooth=args.dice_smooth)
        # todo(yuxian): since the match loss scales as n^2, its loss weight should be specially adjusted
        weight_sum = args.weight_start + args.weight_end + args.weight_span
        self.weight_start = args.weight_start / weight_sum
        self.weight_end = args.weight_end / weight_sum
        self.weight_span = args.weight_span / weight_sum
        self.flat_ner = args.flat
        self.span_f1 = QuerySpanF1(flat=self.flat_ner)
        self.chinese = args.chinese
        self.optimizer = args.optimizer
        self.span_loss_candidates = args.span_loss_candidates
        self.dataset_train, self.dataset_valid, self.dataset_test = get_dataloader(
            args.tgt_domain,
            args.n_samples,
            args.batch_size,
            self.tokenizer,
            query_type=self.args.query_type)