Example #1
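# Download entry point: parses the CLI arguments and fetches the model data.
# Assumes `logging`, `util`, `get_parser` and `download_data` are imported or
# defined elsewhere in this module.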
def main():
    parser = get_parser()
    args = parser.parse_args()
    if args.verbose:
        # logger.setLevel(logging.INFO)
        logging.basicConfig(level=logging.INFO)
    # TODO: argparse for config file
    if args.datapath is None:
        config = util.initialize_from_env(model_name='final')
    else:
        config = {'datapath': args.datapath}
    download_data(config)
Example #2
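# Constructor of the Predictor class (the enclosing `class` statement is not
# shown). Assumes `logging`, `logger`, `util`, `tf` (TensorFlow, v1 API) and
# `cm` (the coreference model module) are available at module level.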
    def __init__(self, model_name='final', config=None, verbose=False):
        if verbose:
            logger.setLevel(logging.INFO)

        if config:
            self.config = config
        else:
            # if no configuration is provided, try to get a default config.
            self.config = util.initialize_from_env(model_name=model_name)

        # Clear the TensorFlow graph context:
        tf.compat.v1.reset_default_graph()
        self.session = tf.compat.v1.Session()

        try:
            self.model = cm.CorefModel(self.config)
            self.model.restore(self.session)
        except ValueError:
            raise Exception("Trying to reload the model while the previous " +
                            "session hasn't been ended. Close the existing " +
                            "session with predictor.end_session()")
Example #3
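# Constructor of a Stanza processor that wraps e2edutch. Assumes `Path`
# (pathlib), `util`, `download_data` and `Predictor` are imported at module
# level; `config` is the Stanza processor configuration dict.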
    def __init__(self, config, pipeline, use_gpu):
        # Make e2edutch follow Stanza's GPU settings:
        # set the environment value for GPU, so that initialize_from_env picks it up.
        # if use_gpu:
        #    os.environ['GPU'] = ' '.join(tf.config.experimental.list_physical_devices('GPU'))
        # else:
        #    if 'GPU' in os.environ:
        #        os.environ.pop('GPU')

        self.e2econfig = util.initialize_from_env(model_name='final')

        # Override datapath and log_root:
        # store e2edata with the Stanza resources, i.e. a 'stanza_resources/nl/coref' directory
        self.e2econfig['datapath'] = Path(config['model_path']).parent
        self.e2econfig['log_root'] = Path(config['model_path']).parent

        # Download data files if not present
        download_data(self.e2econfig)

        # Start and stop a session to cache all models
        predictor = Predictor(config=self.e2econfig)
        predictor.end_session()
Example #4
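# Evaluation script. Assumes `argparse`, `logging`, `util` and `cm` (the
# coreference model module) are imported elsewhere in this file.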
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('config')
    parser.add_argument('--cfg_file',
                        type=str,
                        default=None,
                        help="config file")
    parser.add_argument('--model_cfg_file',
                        type=str,
                        default=None,
                        help="model config file")
    parser.add_argument('-v', '--verbose', action='store_true')
    return parser


if __name__ == "__main__":
    args = get_parser().parse_args()
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    config = util.initialize_from_env(
        args.config, args.cfg_file, args.model_cfg_file)
    model = cm.CorefModel(config)
    with tf.Session() as session:
        model.restore(session)
        model.evaluate(session, official_stdout=True)
Example #5
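# Training loop. Assumes `os`, `time`, `logging`, `tf` (TensorFlow, v1 API),
# `util`, `cm` and `get_parser` are available at module level.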
def main(args=None):
    args = get_parser().parse_args(args)
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    config = util.initialize_from_env(args.config,
                                      args.cfg_file,
                                      args.model_cfg_file)
    # Overwrite train and eval file if specified
    if args.train is not None:
        config['train_path'] = args.train
    if args.eval is not None:
        config['eval_path'] = args.eval
    if args.eval_conll is not None:
        config['conll_eval_path'] = args.eval_conll

    report_frequency = config["report_frequency"]
    eval_frequency = config["eval_frequency"]

    model = cm.CorefModel(config)
    saver = tf.train.Saver()

    log_dir = os.path.join(config['log_root'], config['log_dir'])
    writer = tf.summary.FileWriter(log_dir, flush_secs=20)

    max_f1 = 0

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        model.start_enqueue_thread(session)
        accumulated_loss = 0.0

        ckpt = tf.train.get_checkpoint_state(log_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Restoring from: {}".format(ckpt.model_checkpoint_path))
            saver.restore(session, ckpt.model_checkpoint_path)

        initial_time = time.time()
        while True:
            tf_loss, tf_global_step, _ = session.run(
                [model.loss, model.global_step, model.train_op])
            accumulated_loss += tf_loss

            if tf_global_step % report_frequency == 0:
                total_time = time.time() - initial_time
                steps_per_second = tf_global_step / total_time

                average_loss = accumulated_loss / report_frequency
                print("[{}] loss={:.2f}, steps/s={:.2f}"
                      .format(tf_global_step,
                              average_loss,
                              steps_per_second))
                writer.add_summary(util.make_summary(
                    {"loss": average_loss}), tf_global_step)
                accumulated_loss = 0.0

            if tf_global_step % eval_frequency == 0:
                saver.save(session, os.path.join(log_dir, "model"),
                           global_step=tf_global_step)
                eval_summary, eval_f1 = model.evaluate(session)

                if eval_f1 > max_f1:
                    max_f1 = eval_f1
                    util.copy_checkpoint(
                        os.path.join(log_dir, "model-{}".format(tf_global_step)),
                        os.path.join(log_dir, "model.max.ckpt"))

                writer.add_summary(eval_summary, tf_global_step)
                writer.add_summary(util.make_summary(
                    {"max_eval_f1": max_f1}), tf_global_step)

                print("[{}] evaL_f1={:.2f}, max_f1={:.2f}".format(
                    tf_global_step, eval_f1, max_f1))
Example #6
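# Prediction CLI. Assumes `os`, `io`, `json`, `collections`, `logging`,
# `logger`, `util`, `conll`, `naf`, `minimize`, `read_jsonlines`, `Predictor`
# and `get_parser` are available at module level.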
def main(args=None):
    parser = get_parser()
    args = parser.parse_args(args)
    if args.verbose:
        logger.setLevel(logging.INFO)

    # Input file in .jsonlines format or .conll.
    input_filename = args.input_filename

    ext_input = os.path.splitext(input_filename)[-1]
    if ext_input not in ['.conll', '.jsonlines', '.txt', '.naf']:
        raise Exception(
            'Input file should be .naf, .conll, .txt or .jsonlines, but is {}.'
            .format(ext_input))

    if ext_input == '.conll':
        labels = collections.defaultdict(set)
        stats = collections.defaultdict(int)
        docs = minimize.minimize_partition(input_filename, labels, stats,
                                           args.word_col)
    elif ext_input == '.jsonlines':
        docs = read_jsonlines(input_filename)
    elif ext_input == '.naf':
        naf_obj = naf.get_naf(input_filename)
        jsonlines_obj, term_ids, tok_ids = naf.get_jsonlines(naf_obj)
        docs = [jsonlines_obj]
    else:
        with open(input_filename) as f:
            text = f.read()
        docs = [util.create_example(text)]

    output_file = args.output_file

    config = util.initialize_from_env(model_name=args.model,
                                      cfg_file=args.cfg_file,
                                      model_cfg_file=args.model_cfg_file)
    predictor = Predictor(config=config)

    sentences = {}
    predictions = {}
    for example_num, example in enumerate(docs):
        example["predicted_clusters"] = predictor.predict(example)
        if args.format_out == 'jsonlines':
            output_file.write(json.dumps(example))
            output_file.write("\n")
        else:
            predictions[example['doc_key']] = example["predicted_clusters"]
            sentences[example['doc_key']] = example["sentences"]
        if example_num % 100 == 0:
            logger.info("Decoded {} examples.".format(example_num + 1))
    if args.format_out == 'conll':
        conll.output_conll(output_file, sentences, predictions)
    elif args.format_out == 'naf':
        # Check number of docs - what to do if multiple?
        # Create naf obj if input format was not naf
        if ext_input != '.naf':
            # TODO: add linguistic processing layers for terms and tokens
            logger.warning('Outputting NAF when input was not naf, ' +
                           'no dependency information available')
            for doc_key in sentences:
                naf_obj, term_ids = naf.get_naf_from_sentences(
                    sentences[doc_key])
                naf_obj = naf.create_coref_layer(naf_obj, predictions[doc_key],
                                                 term_ids)
                naf_obj = naf.add_linguistic_processors(naf_obj)
                buffer = io.BytesIO()
                naf_obj.dump(buffer)
                output_file.write(buffer.getvalue().decode('utf-8'))
                # TODO: make separate outputs?
                # TODO: use dependency information from conll?
        else:
            # We only have one input doc
            naf_obj = naf.create_coref_layer(naf_obj,
                                             example["predicted_clusters"],
                                             term_ids)
            naf_obj = naf.add_linguistic_processors(naf_obj)
            buffer = io.BytesIO()
            naf_obj.dump(buffer)
            output_file.write(buffer.getvalue().decode('utf-8'))