def main(config):
    """Main method for predicting BERT-based NER model on CoNLL-formatted test data.

    Restores the training-time configuration and tag vocabulary from the checkpoint
    metadata, rebuilds the tagger, loads the saved parameters and logs the
    seqeval F1 score on ``config.test_path``.

    Args:
        config: parsed command-line namespace; the fields read here are
            load_checkpoint_prefix, gpu, test_path, seq_len and batch_size.
    """
    train_config, tag_vocab = load_metadata(config.load_checkpoint_prefix)

    ctx = get_context(config.gpu)
    bert_model, text_vocab = get_bert_model(train_config.bert_model, train_config.cased, ctx,
                                            train_config.dropout_prob)

    # only the test split is needed at predict time, hence train/dev paths are None;
    # the tag vocabulary must come from training so indices line up with the checkpoint
    dataset = BERTTaggingDataset(text_vocab, None, None, config.test_path,
                                 config.seq_len, train_config.cased, tag_vocab=tag_vocab)

    test_data_loader = dataset.get_test_data_loader(config.batch_size)

    net = BERTTagger(bert_model, dataset.num_tag_types, train_config.dropout_prob)
    model_filename = _find_model_file_from_checkpoint(config.load_checkpoint_prefix)
    net.load_parameters(model_filename, ctx=ctx)
    net.hybridize(static_alloc=True)
    # FIX: removed the SoftmaxCrossEntropyLoss that was built here but never used —
    # no loss is computed on the prediction path.

    # TODO(bikestra): make it not redundant between train and predict
    def evaluate(data_loader):
        """Eval function"""
        predictions = []

        for batch_id, data in enumerate(data_loader):
            logging.info('evaluating on batch index: %d/%d', batch_id, len(data_loader))
            text_ids, token_types, valid_length, tag_ids, _ = \
                [x.astype('float32').as_in_context(ctx) for x in data]
            out = net(text_ids, token_types, valid_length)

            # convert results to numpy arrays for easier access
            np_text_ids = text_ids.astype('int32').asnumpy()
            np_pred_tags = out.argmax(axis=-1).asnumpy()
            np_valid_length = valid_length.astype('int32').asnumpy()
            np_true_tags = tag_ids.asnumpy()

            predictions += convert_arrays_to_text(text_vocab, dataset.tag_vocab, np_text_ids,
                                                  np_true_tags, np_pred_tags, np_valid_length)

        all_true_tags = [[entry.true_tag for entry in entries] for entries in predictions]
        all_pred_tags = [[entry.pred_tag for entry in entries] for entries in predictions]
        seqeval_f1 = seqeval.metrics.f1_score(all_true_tags, all_pred_tags)
        return seqeval_f1

    test_f1 = evaluate(test_data_loader)
    # FIX: lazy %-style logging args (consistent with the rest of the file) instead of
    # eagerly formatting with str.format inside the logging call
    logging.info('test f1: %.3f', test_f1)
def main(config):
    """Main method for training BERT-based NER model.

    Fine-tunes a BERT tagger on CoNLL-formatted data, evaluating seqeval F1 on
    train/dev each epoch and on test whenever the dev F1 improves; optionally
    checkpoints parameters per epoch when ``config.save_checkpoint_prefix`` is set.

    Args:
        config: parsed command-line namespace; the fields read here include seed,
            gpu, bert_model, cased, dropout_prob, train/dev/test paths, seq_len,
            batch_size, num_epochs, warmup_ratio, learning_rate, optimizer and
            save_checkpoint_prefix.
    """
    # provide random seed for every RNG we use so runs are reproducible
    np.random.seed(config.seed)
    random.seed(config.seed)
    mx.random.seed(config.seed)

    ctx = get_context(config.gpu)

    logging.info('Loading BERT model...')
    bert_model, text_vocab = get_bert_model(config.bert_model, config.cased, ctx,
                                            config.dropout_prob)

    dataset = BERTTaggingDataset(text_vocab, config.train_path, config.dev_path,
                                 config.test_path, config.seq_len, config.cased)

    train_data_loader = dataset.get_train_data_loader(config.batch_size)
    dev_data_loader = dataset.get_dev_data_loader(config.batch_size)
    test_data_loader = dataset.get_test_data_loader(config.batch_size)

    net = BERTTagger(bert_model, dataset.num_tag_types, config.dropout_prob)
    # only the freshly added classification head needs initialization; the BERT
    # backbone already carries pre-trained weights
    net.tag_classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)
    net.hybridize(static_alloc=True)

    loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    loss_function.hybridize(static_alloc=True)

    # step size adaptation, adopted from: https://github.com/dmlc/gluon-nlp/blob/
    # 87d36e3cc7c615f93732d01048cf7ce3b3b09eb7/scripts/bert/finetune_classifier.py#L348-L351
    step_size = config.batch_size
    num_train_steps = int(len(dataset.train_inputs) / step_size * config.num_epochs)
    num_warmup_steps = int(num_train_steps * config.warmup_ratio)

    optimizer_params = {'learning_rate': config.learning_rate}
    trainer = mx.gluon.Trainer(net.collect_params(), config.optimizer, optimizer_params)

    # collect differentiable parameters
    logging.info('Collect params...')
    # do not apply weight decay on LayerNorm and bias terms
    for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    params = [p for p in net.collect_params().values() if p.grad_req != 'null']

    if config.save_checkpoint_prefix is not None:
        logging.info('dumping metadata...')
        dump_metadata(config, tag_vocab=dataset.tag_vocab)

    def train(data_loader, start_step_num):
        """Training loop."""
        step_num = start_step_num
        logging.info('current starting step num: %d', step_num)
        for batch_id, (_, _, _, tag_ids, flag_nonnull_tag, out) in \
                enumerate(attach_prediction(data_loader, net, ctx, is_train=True)):
            logging.info('training on batch index: %d/%d', batch_id, len(data_loader))

            # step size adjustments: linear warmup followed by linear decay
            step_num += 1
            if step_num < num_warmup_steps:
                new_lr = config.learning_rate * step_num / num_warmup_steps
            else:
                offset = ((step_num - num_warmup_steps) * config.learning_rate /
                          (num_train_steps - num_warmup_steps))
                new_lr = config.learning_rate - offset
            trainer.set_learning_rate(new_lr)

            with mx.autograd.record():
                # flag_nonnull_tag masks out padding/subword positions from the loss
                loss_value = loss_function(out, tag_ids,
                                           flag_nonnull_tag.expand_dims(axis=2)).mean()

            loss_value.backward()
            nlp.utils.clip_grad_global_norm(params, 1)
            trainer.step(1)

            pred_tags = out.argmax(axis=-1)
            logging.info('loss_value: %6f', loss_value.asscalar())

            num_tag_preds = flag_nonnull_tag.sum().asscalar()
            logging.info(
                'accuracy: %6f',
                (((pred_tags == tag_ids) * flag_nonnull_tag).sum().asscalar() / num_tag_preds))
        return step_num

    def evaluate(data_loader):
        """Eval loop."""
        predictions = []

        for batch_id, (text_ids, _, valid_length, tag_ids, _, out) in \
                enumerate(attach_prediction(data_loader, net, ctx, is_train=False)):
            logging.info('evaluating on batch index: %d/%d', batch_id, len(data_loader))

            # convert results to numpy arrays for easier access
            np_text_ids = text_ids.astype('int32').asnumpy()
            np_pred_tags = out.argmax(axis=-1).asnumpy()
            np_valid_length = valid_length.astype('int32').asnumpy()
            np_true_tags = tag_ids.asnumpy()

            predictions += convert_arrays_to_text(text_vocab, dataset.tag_vocab, np_text_ids,
                                                  np_true_tags, np_pred_tags, np_valid_length)

        all_true_tags = [[entry.true_tag for entry in entries] for entries in predictions]
        all_pred_tags = [[entry.pred_tag for entry in entries] for entries in predictions]
        seqeval_f1 = seqeval.metrics.f1_score(all_true_tags, all_pred_tags)
        return seqeval_f1

    best_dev_f1 = 0.0
    last_test_f1 = 0.0
    best_epoch = -1
    last_epoch_step_num = 0
    for epoch_index in range(config.num_epochs):
        last_epoch_step_num = train(train_data_loader, last_epoch_step_num)
        train_f1 = evaluate(train_data_loader)
        logging.info('train f1: %3f', train_f1)
        dev_f1 = evaluate(dev_data_loader)
        logging.info('dev f1: %3f, previous best dev f1: %3f', dev_f1, best_dev_f1)
        if dev_f1 > best_dev_f1:
            best_dev_f1 = dev_f1
            best_epoch = epoch_index
            logging.info('update the best dev f1 to be: %3f', best_dev_f1)
            # only run test evaluation when dev improves; last_test_f1 therefore
            # tracks the test score at the best-dev epoch
            test_f1 = evaluate(test_data_loader)
            logging.info('test f1: %3f', test_f1)
            last_test_f1 = test_f1
        # FIX: guard checkpointing the same way dump_metadata is guarded above —
        # the original concatenated None + str and raised TypeError when
        # save_checkpoint_prefix was not provided
        if config.save_checkpoint_prefix is not None:
            # save params
            params_file = config.save_checkpoint_prefix + '_{:03d}.params'.format(epoch_index)
            logging.info('saving current checkpoint to: %s', params_file)
            net.save_parameters(params_file)
        logging.info('current best epoch: %d', best_epoch)

    # FIX: typo in the summary message ('tha' -> 'that')
    logging.info('best epoch: %d, best dev f1: %3f, test f1 at that epoch: %3f',
                 best_epoch, best_dev_f1, last_test_f1)