def test_evaluate_result():
    """Evaluate prediction result."""
    h_params = get_debug_hparams()._replace(do_train=False, do_eval=False, do_predict=True)
    result = run(h_params)
    evaluate_ner_pred_result(h_params, result)
def test_train_evaluate():
    reduce_output()
    validate_debug_params()
    hparams = get_debug_hparams()
    hparams = hparams._replace(output_dir=str(get_project_root() / 'tmp' / 'test_my_classifier_train_eval'))
    _, estimator = get_model_fn_and_estimator(hparams)
    train_evaluate(hparams, estimator)
def test_train():
    reduce_output()
    validate_debug_params()
    hparams = get_debug_hparams()
    hparams = hparams._replace(output_dir=str(get_project_root() / 'tmp' / 'test_my_classifier_train'))
    _, estimator = get_model_fn_and_estimator(hparams)
    train(hparams, estimator, max_steps=2)
    clean_folder(hparams.output_dir)
def test_convert_examples_to_features():
    hparams = get_debug_hparams()
    data_filename = get_data_filename(hparams)
    examples = get_examples(data_filename, SetType.train)
    features = convert_examples_to_features(
        examples=examples,
        label_list=get_unique_intents(data_filename),
        max_seq_length=hparams.max_seq_length,
        tokenizer=get_tokenizer(hparams))
    print(examples[2].__dict__)
    # {'guid': 'train-2',
    #  'text_a': 'Is it worth upgrading from 12.04 LTS to 13.04',
    #  'text_b': None,
    #  'label': 'Make Update'}
    print(features[2].__dict__)
    # {'input_ids': [101, 2003, 2009, 4276, 25925, 2013, 2260, 1012, 5840, 8318,
    #                2015, 2000, 2410, 1012, 5840, 102, 0, 0, ..., 0],
    #  'input_mask': [1, 1, ..., 1, 0, 0, ..., 0],
    #  'segment_ids': [0, 0, ..., 0],
    #  'label_id': 1}
    # Both `input_ids` and `input_mask` are zero-padded out to `max_seq_length`;
    # the mask holds a 1 for each of the 16 real WordPiece ids and 0 for padding,
    # and `segment_ids` is all zeros because `text_b` is None (single sentence).
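# The padding pattern printed above can be reproduced in isolation. A minimal
# sketch: token ids are padded with zeros up to `max_seq_length`, and
# `input_mask` marks which positions hold real tokens. The helper below is
# hypothetical, for illustration only; it is not part of this module.
def _pad_example(token_ids, max_seq_length):
    """Zero-pad token ids and build the matching mask (illustration only)."""
    pad_length = max_seq_length - len(token_ids)
    input_ids = token_ids + [0] * pad_length
    input_mask = [1] * len(token_ids) + [0] * pad_length
    return input_ids, input_mask

# Example: passing the 16 WordPiece ids printed above yields a mask with 16
# leading 1s followed by zeros out to `max_seq_length`, matching the output
# of `convert_examples_to_features`.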
def main():
    hparams = get_debug_hparams()
    tf.gfile.MakeDirs(str(hparams.output_dir))
    model_fn, estimator = get_model_fn_and_estimator(hparams)
    if hparams.do_train:
        train(hparams, estimator)
    if hparams.do_eval:
        evaluate(hparams, estimator)
    # note that predictions are non-deterministic
    if hparams.do_predict:
        predict(hparams)
def test_eval():
    """Smoke test"""
    h_params = get_debug_hparams()._replace(do_train=False, do_eval=True, do_predict=False)
    run(h_params)
def validate_debug_params():
    hparams = get_debug_hparams()
    assert 2 == hparams.num_train_steps
    assert 1 == hparams.save_checkpoints_steps
    assert 1 == hparams.train_batch_size
    assert 1 == hparams.save_summary_steps
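# The `_replace` calls throughout these tests imply that the hyperparameters
# object is a NamedTuple. A minimal sketch of the shape `get_debug_hparams`
# appears to return: the field names come from the code above, the defaults
# mirror the assertions in `validate_debug_params`, and everything else is an
# assumption, not the actual improv.config definition.
#
#     from typing import NamedTuple
#
#     class DebugHParams(NamedTuple):
#         num_train_steps: int = 2
#         save_checkpoints_steps: int = 1
#         train_batch_size: int = 1
#         save_summary_steps: int = 1
#         do_train: bool = True
#         do_eval: bool = True
#         do_predict: bool = True
#         output_dir: str = 'tmp'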
        predict_file)

    tf.logging.info("***** Running prediction *****")
    tf.logging.info("  Num examples = %d", len(predict_examples))
    tf.logging.info("  Batch size = %d", params.predict_batch_size)

    predict_drop_remainder = params.use_tpu
    predict_input_fn = file_based_input_fn_builder(
        input_file=predict_file,
        seq_length=params.max_seq_length,
        is_training=False,
        drop_remainder=predict_drop_remainder)

    result: Iterable[np.ndarray] = estimator.predict(input_fn=predict_input_fn)
    # used for label_list[max_class]; this might be wrong
    label_list = get_intents(data_filename)
    y_pred = convert_result_pred(result, label_list)
    print('f1 score: {}'.format(
        get_rounded_f1(params.data_dir.parent / 'askubuntu.tsv', y_pred,
                       average='micro')))
    return y_pred


if __name__ == '__main__':
    from improv.config import get_debug_hparams

    # train_and_evaluate(get_debug_hparams())
    predict(get_debug_hparams())
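# Hypothetical sketch of what `convert_result_pred` above likely does, based
# on the `label_list[max_class]` remark: take the arg-max over each predicted
# probability vector and map it back to an intent label. This is an assumed
# reading, not the actual implementation.
#
#     import numpy as np
#
#     def convert_result_pred(result, label_list):
#         return [label_list[int(np.argmax(probabilities))]
#                 for probabilities in result]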