# Example 1
def main(_):
    """Evaluate every checkpoint in hparams.model_dir until the timeout expires."""
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    hparams = trainer_lib.create_hparams(
        FLAGS.hparams_set,
        FLAGS.hparams,
        data_dir=FLAGS.data_dir,
        problem_name=FLAGS.problem)

    # Pick the test split only when explicitly requested; None keeps the
    # problem's default evaluation split.
    split = "test" if FLAGS.eval_use_test_set else None
    eval_input_fn = hparams.problem.make_estimator_input_fn(
        tf.estimator.ModeKeys.EVAL,
        hparams,
        dataset_kwargs={"dataset_split": split})

    run_config = t2t_trainer.create_run_config(hparams)

    # summary-hook in tf.estimator.EstimatorSpec requires
    # hparams.model_dir to be set.
    hparams.add_hparam("model_dir", run_config.model_dir)

    estimator = trainer_lib.create_estimator(
        FLAGS.model, hparams, run_config, use_tpu=FLAGS.use_tpu)

    # Evaluate each checkpoint as it appears, logging the metric dict.
    for checkpoint in trainer_lib.next_checkpoint(hparams.model_dir,
                                                  FLAGS.eval_timeout_mins):
        metrics = estimator.evaluate(eval_input_fn,
                                     steps=FLAGS.eval_steps,
                                     checkpoint_path=checkpoint)
        tf.logging.info(metrics)
def continuous_eval(experiment_dir):
    """Evaluate until checkpoints stop being produced."""
    checkpoints = trainer_lib.next_checkpoint(experiment_dir, timeout_mins=-1)
    for checkpoint_path in checkpoints:
        # hparams are reloaded from disk on every round, so edits made while
        # this loop runs are picked up for the next checkpoint.
        hparams = seq2act_estimator.load_hparams(experiment_dir)
        hparams.set_hparam("batch_size", FLAGS.eval_batch_size)
        input_fn = seq2act_estimator.create_input_fn(
            FLAGS.eval_files,
            hparams.batch_size,
            -1,
            2,
            input_utils.DataSource.from_str(FLAGS.eval_data_source),
            max_range=hparams.max_span,
            max_dom_pos=hparams.max_dom_pos,
            max_pixel_pos=hparams.max_pixel_pos,
            mean_synthetic_length=hparams.mean_synthetic_length,
            stddev_synthetic_length=hparams.stddev_synthetic_length,
            load_extra=True,
            load_screen=hparams.load_screen,
            load_dom_dist=(hparams.screen_encoder == "gcn"))
        evaluator = create_estimator(experiment_dir,
                                     hparams,
                                     decode_length=FLAGS.decode_length)
        evaluator.evaluate(input_fn=input_fn,
                           steps=FLAGS.eval_steps,
                           checkpoint_path=checkpoint_path,
                           name=FLAGS.eval_name)
# Example 3
def prediction(experiment_dir):
    """Run prediction on a fixed set of raw texts for each checkpoint.

    Iterates checkpoints in `experiment_dir` (timeout_mins=-1, so the loop
    ends once no new checkpoint appears), feeds a hard-coded list of
    instruction strings through the estimator, and prints every result.

    Args:
      experiment_dir: Directory containing saved hparams and checkpoints.
    """
    for ckpt_path in trainer_lib.next_checkpoint(experiment_dir,
                                                 timeout_mins=-1):
        hparams = seq2act_estimator.load_hparams(experiment_dir)
        hparams.set_hparam("batch_size", FLAGS.eval_batch_size)
        # NOTE(review): this input_fn is never consumed below (prediction
        # uses predict_input_fn); kept because create_input_fn mirrors
        # continuous_eval — confirm whether it can be removed entirely.
        eval_input_fn = seq2act_estimator.create_input_fn(
            FLAGS.eval_files,
            hparams.batch_size,
            -1,
            2,
            input_utils.DataSource.from_str(FLAGS.eval_data_source),
            max_range=hparams.max_span,
            max_dom_pos=hparams.max_dom_pos,
            max_pixel_pos=hparams.max_pixel_pos,
            mean_synthetic_length=hparams.mean_synthetic_length,
            stddev_synthetic_length=hparams.stddev_synthetic_length,
            load_extra=True,
            load_screen=hparams.load_screen,
            load_dom_dist=(hparams.screen_encoder == "gcn"))

        raw_texts = [
            "Click button OK", "Navigate to settings", "Open app drawer"
        ]
        # np.str was a deprecated alias for the builtin str and is removed
        # in NumPy >= 1.24; astype(str) is the equivalent, portable spelling.
        predict_input_fn = tf.estimator.inputs.numpy_input_fn(
            {"input_refs": np.array(raw_texts).astype(str)},
            shuffle=False,
            batch_size=1)

        estimator = create_estimator(experiment_dir,
                                     hparams,
                                     decode_length=FLAGS.decode_length)

        print("\nSTART PREDICTION\n")

        # Pin prediction to the checkpoint being iterated; without this,
        # predict() silently loads the latest checkpoint every round.
        results = estimator.predict(
            input_fn=predict_input_fn,
            checkpoint_path=ckpt_path)

        print(results)
        for result in results:
            print(result)
        print("\nEND PREDICTION\n")
# Example 4
def main(_):
    """Evaluate all checkpoints for a task direction and write metrics to files.

    Writes one tab-separated results file (header of metric names, one row
    per checkpoint) plus a `checklist` marker file under FLAGS.results_dir.

    Raises:
      ValueError: if results_dir is unset, the task_direction is unknown,
        or no checkpoint was produced.
    """
    # Fail fast: the original validated results_dir only AFTER the expensive
    # evaluation loop had already completed.
    if not FLAGS.results_dir:
        raise ValueError("results_dir not defined")
    results_dir = FLAGS.results_dir
    print("\n\n\n\n\nresults_dir = {}\n\n\n\n\n".format(FLAGS.results_dir))
    print(FLAGS.results_dir and FLAGS.results_dir[:5] == "gs://")

    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    hparams = trainer_lib.create_hparams(FLAGS.hparams_set,
                                         FLAGS.hparams,
                                         data_dir=FLAGS.data_dir,
                                         problem_name=FLAGS.problem)

    # Every non-NORMAL direction uses the explicitly requested split and a
    # custom checkpoint iterator; collapse the duplicated elif chains.
    split_directions = (problem.TaskDirections.Q12,
                        problem.TaskDirections.Q8,
                        problem.TaskDirections.INTERPOLATE,
                        problem.TaskDirections.EXTRAPOLATE)
    if FLAGS.task_direction == problem.TaskDirections.NORMAL:
        dataset_split = "test" if FLAGS.eval_use_test_set else None
    elif FLAGS.task_direction in split_directions:
        dataset_split = FLAGS.dataset_split
    else:
        raise ValueError("Found unknown task_direction which is ",
                         FLAGS.task_direction)

    dataset_kwargs = {"dataset_split": dataset_split}
    eval_input_fn = hparams.problem.make_estimator_input_fn(
        tf.estimator.ModeKeys.EVAL, hparams, dataset_kwargs=dataset_kwargs)
    config = t2t_trainer.create_run_config(hparams)

    # summary-hook in tf.estimator.EstimatorSpec requires
    # hparams.model_dir to be set.
    hparams.add_hparam("model_dir", config.model_dir)

    estimator = trainer_lib.create_estimator(FLAGS.model,
                                             hparams,
                                             config,
                                             use_tpu=FLAGS.use_tpu)

    if FLAGS.task_direction == problem.TaskDirections.NORMAL:
        ckpt_iter = trainer_lib.next_checkpoint(hparams.model_dir,
                                                FLAGS.eval_timeout_mins)
    else:
        # split_directions were already validated above.
        ckpt_iter = my_chkpt_iter(hparams.model_dir)

    results_all_ckpts = []
    for ckpt_path in ckpt_iter:
        results = estimator.evaluate(eval_input_fn,
                                     steps=FLAGS.eval_steps,
                                     checkpoint_path=ckpt_path)
        results_all_ckpts.append(results)
        tf.logging.info(results)

    # Guard: indexing results_all_ckpts[0] below would otherwise raise a
    # bare IndexError when no checkpoint was found.
    if not results_all_ckpts:
        raise ValueError("No checkpoints were evaluated; nothing to write.")

    def build_line(items, labels=False):
        """Join items into one tab-separated line; labels keep only basenames."""
        items = map(str, items)
        if labels:
            return "\t".join([i.split("/")[-1] for i in items]) + "\n"
        else:
            return "\t".join(items) + "\n"

    # Metric names from the first checkpoint's result dict define the columns.
    category_names = results_all_ckpts[0].keys()
    # Write to bucket
    with file_io.FileIO(
            results_dir + "/eval_" + FLAGS.dataset_split + "_results.txt",
            "w") as results_file:
        results_file.write(build_line(category_names, labels=True))
        for r in results_all_ckpts:
            results_file.write(build_line([r[k] for k in category_names]))
    with file_io.FileIO(results_dir + "/checklist", "w") as checklist_file:
        checklist_file.write(FLAGS.dataset_split + "\n")