def main(config):
    """Main method for predicting BERT-based NER model on CoNLL-formatted test data.

    Loads the tag vocabulary and training-time configuration from the
    checkpoint metadata, restores the trained ``BERTTagger`` parameters, runs
    inference over the test split, and logs the seqeval F1 score.
    """
    train_config, tag_vocab = load_metadata(config.load_checkpoint_prefix)

    ctx = get_context(config.gpu)
    bert_model, text_vocab = get_bert_model(train_config.bert_model, train_config.cased, ctx,
                                            train_config.dropout_prob)

    # train/dev paths are None: only the test split is needed for prediction
    dataset = BERTTaggingDataset(text_vocab, None, None, config.test_path,
                                 config.seq_len, train_config.cased, tag_vocab=tag_vocab)

    test_data_loader = dataset.get_test_data_loader(config.batch_size)

    net = BERTTagger(bert_model, dataset.num_tag_types, train_config.dropout_prob)
    model_filename = _find_model_file_from_checkpoint(config.load_checkpoint_prefix)
    net.load_parameters(model_filename, ctx=ctx)

    net.hybridize(static_alloc=True)

    # NOTE: no loss function is needed here — prediction only computes tag
    # scores and the seqeval F1; the unused SoftmaxCrossEntropyLoss was removed.

    # TODO(bikestra): make it not redundant between train and predict
    def evaluate(data_loader):
        """Compute the seqeval F1 score of `net` over all batches of `data_loader`."""
        predictions = []

        for batch_id, data in enumerate(data_loader):
            logging.info('evaluating on batch index: %d/%d', batch_id, len(data_loader))
            text_ids, token_types, valid_length, tag_ids, _ = \
                [x.astype('float32').as_in_context(ctx) for x in data]
            out = net(text_ids, token_types, valid_length)

            # convert results to numpy arrays for easier access
            np_text_ids = text_ids.astype('int32').asnumpy()
            np_pred_tags = out.argmax(axis=-1).asnumpy()
            np_valid_length = valid_length.astype('int32').asnumpy()
            np_true_tags = tag_ids.asnumpy()

            predictions += convert_arrays_to_text(text_vocab, dataset.tag_vocab, np_text_ids,
                                                  np_true_tags, np_pred_tags, np_valid_length)

        # seqeval expects one list of tag strings per sentence
        all_true_tags = [[entry.true_tag for entry in entries] for entries in predictions]
        all_pred_tags = [[entry.pred_tag for entry in entries] for entries in predictions]
        return seqeval.metrics.f1_score(all_true_tags, all_pred_tags)

    test_f1 = evaluate(test_data_loader)
    # lazy %-formatting, consistent with the other logging calls in this file
    logging.info('test f1: %.3f', test_f1)
# Example #2
def main(config):
    """Main method for training BERT-based NER model.

    Fine-tunes a ``BERTTagger`` on the CoNLL-formatted train split, evaluates
    on the train and dev splits after every epoch, and — whenever the dev F1
    improves — evaluates on the test split and checkpoints the parameters.
    """
    # provide random seed for every RNG we use so runs are reproducible
    np.random.seed(config.seed)
    random.seed(config.seed)
    mx.random.seed(config.seed)

    ctx = get_context(config.gpu)

    logging.info('Loading BERT model...')
    bert_model, text_vocab = get_bert_model(config.bert_model, config.cased,
                                            ctx, config.dropout_prob)

    dataset = BERTTaggingDataset(text_vocab, config.train_path,
                                 config.dev_path, config.test_path,
                                 config.seq_len, config.cased)

    train_data_loader = dataset.get_train_data_loader(config.batch_size)
    dev_data_loader = dataset.get_dev_data_loader(config.batch_size)
    test_data_loader = dataset.get_test_data_loader(config.batch_size)

    net = BERTTagger(bert_model, dataset.num_tag_types, config.dropout_prob)
    # only the classification head is freshly initialized; the BERT encoder
    # weights come pre-trained from get_bert_model
    net.tag_classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)
    net.hybridize(static_alloc=True)

    loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    loss_function.hybridize(static_alloc=True)

    # step size adaptation, adopted from: https://github.com/dmlc/gluon-nlp/blob/
    # 87d36e3cc7c615f93732d01048cf7ce3b3b09eb7/scripts/bert/finetune_classifier.py#L348-L351
    step_size = config.batch_size
    num_train_steps = int(
        len(dataset.train_inputs) / step_size * config.num_epochs)
    num_warmup_steps = int(num_train_steps * config.warmup_ratio)

    optimizer_params = {'learning_rate': config.learning_rate}
    trainer = mx.gluon.Trainer(net.collect_params(), config.optimizer,
                               optimizer_params)

    # collect differentiable parameters
    logging.info('Collect params...')
    # do not apply weight decay on LayerNorm and bias terms
    for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    params = [p for p in net.collect_params().values() if p.grad_req != 'null']

    if config.save_checkpoint_prefix is not None:
        logging.info('dumping metadata...')
        dump_metadata(config, tag_vocab=dataset.tag_vocab)

    def train(data_loader, start_step_num):
        """Run one training epoch; return the updated global step counter."""
        step_num = start_step_num
        logging.info('current starting step num: %d', step_num)
        for batch_id, (_, _, _, tag_ids, flag_nonnull_tag, out) in \
                enumerate(attach_prediction(data_loader, net, ctx, is_train=True)):
            logging.info('training on batch index: %d/%d', batch_id,
                         len(data_loader))

            # learning-rate schedule: linear warmup for the first
            # num_warmup_steps steps, then linear decay to zero
            step_num += 1
            if step_num < num_warmup_steps:
                new_lr = config.learning_rate * step_num / num_warmup_steps
            else:
                offset = ((step_num - num_warmup_steps) *
                          config.learning_rate /
                          (num_train_steps - num_warmup_steps))
                new_lr = config.learning_rate - offset
            trainer.set_learning_rate(new_lr)

            with mx.autograd.record():
                # flag_nonnull_tag masks out padding/null positions from the loss
                loss_value = loss_function(
                    out, tag_ids, flag_nonnull_tag.expand_dims(axis=2)).mean()

            loss_value.backward()
            nlp.utils.clip_grad_global_norm(params, 1)
            trainer.step(1)

            pred_tags = out.argmax(axis=-1)
            logging.info('loss_value: %.6f', loss_value.asscalar())

            # token-level accuracy over non-null positions only
            num_tag_preds = flag_nonnull_tag.sum().asscalar()
            logging.info(
                'accuracy: %.6f',
                (((pred_tags == tag_ids) * flag_nonnull_tag).sum().asscalar() /
                 num_tag_preds))
        return step_num

    def evaluate(data_loader):
        """Compute the seqeval F1 score of `net` over all batches of `data_loader`."""
        predictions = []

        for batch_id, (text_ids, _, valid_length, tag_ids, _, out) in \
                enumerate(attach_prediction(data_loader, net, ctx, is_train=False)):
            logging.info('evaluating on batch index: %d/%d', batch_id,
                         len(data_loader))

            # convert results to numpy arrays for easier access
            np_text_ids = text_ids.astype('int32').asnumpy()
            np_pred_tags = out.argmax(axis=-1).asnumpy()
            np_valid_length = valid_length.astype('int32').asnumpy()
            np_true_tags = tag_ids.asnumpy()

            predictions += convert_arrays_to_text(text_vocab,
                                                  dataset.tag_vocab,
                                                  np_text_ids, np_true_tags,
                                                  np_pred_tags,
                                                  np_valid_length)

        # seqeval expects one list of tag strings per sentence
        all_true_tags = [[entry.true_tag for entry in entries]
                         for entries in predictions]
        all_pred_tags = [[entry.pred_tag for entry in entries]
                         for entries in predictions]
        return seqeval.metrics.f1_score(all_true_tags, all_pred_tags)

    best_dev_f1 = 0.0
    last_test_f1 = 0.0
    best_epoch = -1

    last_epoch_step_num = 0
    for epoch_index in range(config.num_epochs):
        last_epoch_step_num = train(train_data_loader, last_epoch_step_num)
        train_f1 = evaluate(train_data_loader)
        logging.info('train f1: %.3f', train_f1)
        dev_f1 = evaluate(dev_data_loader)
        logging.info('dev f1: %.3f, previous best dev f1: %.3f', dev_f1,
                     best_dev_f1)
        if dev_f1 > best_dev_f1:
            best_dev_f1 = dev_f1
            best_epoch = epoch_index
            logging.info('update the best dev f1 to be: %.3f', best_dev_f1)
            # only score the test set when dev improves, so last_test_f1 always
            # corresponds to the best dev epoch
            test_f1 = evaluate(test_data_loader)
            logging.info('test f1: %.3f', test_f1)
            last_test_f1 = test_f1

            # save params
            # NOTE(review): assumes save_checkpoint_prefix is not None here,
            # matching the metadata dump above — verify callers always set it
            params_file = config.save_checkpoint_prefix + '_{:03d}.params'.format(
                epoch_index)
            logging.info('saving current checkpoint to: %s', params_file)
            net.save_parameters(params_file)

        logging.info('current best epoch: %d', best_epoch)

    logging.info('best epoch: %d, best dev f1: %.3f, test f1 at that epoch: %.3f',
                 best_epoch, best_dev_f1, last_test_f1)