예제 #1
0
def main(args):
    """Run the single pipeline stage selected by the flags on ``args``.

    ``args.generate_submission`` is checked first and bypasses trainer
    construction entirely.  Otherwise an ``AppTrainer`` is built and the
    first flag that is set among ``do_eda``, ``do_train``, ``do_eval`` and
    ``do_predict`` (checked in that order) decides what runs.  When none
    of them is set, only the trainer is constructed.
    """
    if args.generate_submission:
        generate_submission(args)
        return

    trainer = AppTrainer(args, ner_labels)

    if args.do_eda:
        show_ner_datainfo(
            ner_labels,
            train_data_generator, args.train_file,
            test_data_generator, args.test_file,
        )
        return

    if args.do_train:
        train_split, val_split = load_train_val_examples(args)
        trainer.train(args, train_split, val_split)
        return

    if args.do_eval:
        # Only the second element of the train/val pair is evaluated.
        _unused_train, held_out = load_train_val_examples(args)
        ner_model = load_model(args)
        trainer.evaluate(args, ner_model, held_out)
        return

    if args.do_predict:
        unseen = load_test_examples(args)
        ner_model = load_model(args)
        trainer.predict(args, ner_model, unseen)
        # Predictions are read back from the trainer after predict().
        save_ner_preds(args, trainer.pred_results, unseen)
예제 #2
0
 def do_predict(args):
     """Predict on the test examples and save the NER predictions.

     ``args.model_path`` is overwritten with ``args.best_model_path``
     before the model is loaded.  Relies on the module-level ``trainer``.

     Returns:
         The two values produced by ``save_ner_preds``, in order: the
         reviews file and the category-mentions file.
     """
     args.model_path = args.best_model_path

     samples = load_test_examples(args)
     ner_model = load_model(args)
     trainer.predict(args, ner_model, samples)

     # Unpack explicitly so a result that is not a pair still fails here.
     reviews_out, mentions_out = save_ner_preds(args, trainer.pred_results,
                                                samples)
     return reviews_out, mentions_out
예제 #3
0
 def do_predict(args):
     """Run prediction on the test examples; nothing is returned.

     ``args.model_path`` is overwritten with ``args.best_model_path``
     before the model is loaded.  Relies on the module-level ``trainer``.
     """
     args.model_path = args.best_model_path

     samples = load_test_examples(args)
     ner_model = load_model(args)
     trainer.predict(args, ner_model, samples)
예제 #4
0
 def do_eval(args):
     """Evaluate the model on examples read from the evaluation files.

     ``args.model_path`` is overwritten with ``args.best_model_path``
     before the model is loaded.  The examples come from the module-level
     ``eval_text_file`` and ``eval_bio_file``; the module-level ``trainer``
     performs the evaluation.
     """
     args.model_path = args.best_model_path

     held_out = load_eval_examples(eval_text_file, eval_bio_file)
     ner_model = load_model(args)
     trainer.evaluate(args, ner_model, held_out)
예제 #5
0
 def do_eval(args):
     """Evaluate the model on the validation part of the train/val split.

     ``args.model_path`` is overwritten with ``args.best_model_path``
     before the model is loaded.  Relies on the module-level ``trainer``.
     """
     args.model_path = args.best_model_path

     # The training half of the pair is loaded but not used here.
     _unused_train, held_out = load_train_val_examples(args)
     ner_model = load_model(args)
     trainer.evaluate(args, ner_model, held_out)
예제 #6
0
def main(args):
    """Run one utility mode, or the train / eval / predict pipeline.

    After ``init_theta(args)``, the utility flags are checked in this
    order: ``do_eda``, ``do_merge``, ``merge_multi``,
    ``generate_submission``, ``fix_results``.  The first one that is set
    runs its handler and the function returns.

    Otherwise the train/validation examples are loaded, labels are
    initialised, a ``Trainer`` is built, and each of ``do_train``,
    ``do_eval`` and ``do_predict`` that is set runs, in that order.  The
    three phase flags are independent ``if`` checks, so several phases can
    run in a single call.

    Relies on the module-level ``seg_len``, ``seg_backoff``,
    ``test_base_file`` and ``ner_labels``.

    Args:
        args: Options object exposing the flags above as attributes; it is
            also forwarded to every helper that is called.
    """
    init_theta(args)

    # One-shot utility modes: each runs alone and skips the pipeline below.
    if args.do_eda:
        eda(args)
        return

    if args.do_merge:
        merge_all_tops()
        return

    if args.merge_multi:
        merge_multi(args)
        return

    if args.generate_submission:
        generate_submission(args)
        return

    if args.fix_results:
        fix_results_file(args)
        return

    # Loaded unconditionally, whichever phases are enabled; both the train
    # and the eval phase below consume this split.
    train_examples, val_examples = load_train_val_examples(
        args, seg_len=seg_len, seg_backoff=seg_backoff)
    init_labels(args, ner_labels)

    trainer = Trainer(args)

    # --------------- train phase ---------------
    if args.do_train:
        trainer.train(args, train_examples, val_examples)

    # --------------- eval phase ---------------
    if args.do_eval:
        model = load_model(args)
        trainer.evaluate(args, model, val_examples)

    # --------------- predict phase ---------------
    if args.do_predict:
        test_examples = load_test_examples(args,
                                           test_base_file,
                                           seg_len=seg_len,
                                           seg_backoff=seg_backoff)
        model = load_model(args)
        trainer.predict(args, model, test_examples)

        # Predictions are read back from the trainer after predict().
        save_predict_results(args, trainer.pred_results, get_result_file(args),
                             test_examples)