def main(_):
  """Evaluate every checkpoint produced for the configured T2T problem.

  Waits for checkpoints in hparams.model_dir (up to FLAGS.eval_timeout_mins)
  and runs estimator.evaluate on each one as it appears, logging the metrics.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  hparams = trainer_lib.create_hparams(
      FLAGS.hparams_set,
      FLAGS.hparams,
      data_dir=FLAGS.data_dir,
      problem_name=FLAGS.problem)

  # Choose the dataset split: the test set when requested, otherwise let the
  # problem fall back to its default (dev) split.
  if FLAGS.eval_use_test_set:
    split = "test"
  else:
    split = None
  eval_input_fn = hparams.problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams,
      dataset_kwargs={"dataset_split": split})

  run_config = t2t_trainer.create_run_config(hparams)
  # The summary hook in tf.estimator.EstimatorSpec requires
  # hparams.model_dir to be set.
  hparams.add_hparam("model_dir", run_config.model_dir)

  estimator = trainer_lib.create_estimator(
      FLAGS.model, hparams, run_config, use_tpu=FLAGS.use_tpu)

  # Evaluate each checkpoint as it is written, until the timeout expires.
  for checkpoint in trainer_lib.next_checkpoint(hparams.model_dir,
                                                FLAGS.eval_timeout_mins):
    metrics = estimator.evaluate(
        eval_input_fn, steps=FLAGS.eval_steps, checkpoint_path=checkpoint)
    tf.logging.info(metrics)
def continuous_eval(experiment_dir):
  """Evaluate until checkpoints stop being produced."""
  for ckpt_path in trainer_lib.next_checkpoint(experiment_dir, timeout_mins=-1):
    # Reload hparams each round so mid-run hparams edits are picked up,
    # then force the eval batch size from the flag.
    hp = seq2act_estimator.load_hparams(experiment_dir)
    hp.set_hparam("batch_size", FLAGS.eval_batch_size)

    data_source = input_utils.DataSource.from_str(FLAGS.eval_data_source)
    input_fn = seq2act_estimator.create_input_fn(
        FLAGS.eval_files, hp.batch_size, -1, 2, data_source,
        max_range=hp.max_span,
        max_dom_pos=hp.max_dom_pos,
        max_pixel_pos=hp.max_pixel_pos,
        mean_synthetic_length=hp.mean_synthetic_length,
        stddev_synthetic_length=hp.stddev_synthetic_length,
        load_extra=True,
        load_screen=hp.load_screen,
        # DOM distances are only needed by the GCN screen encoder.
        load_dom_dist=(hp.screen_encoder == "gcn"))

    model = create_estimator(experiment_dir, hp,
                             decode_length=FLAGS.decode_length)
    model.evaluate(input_fn=input_fn,
                   steps=FLAGS.eval_steps,
                   checkpoint_path=ckpt_path,
                   name=FLAGS.eval_name)
def prediction(experiment_dir):
  """Run inference on a fixed set of raw command strings for each checkpoint.

  Iterates over checkpoints in `experiment_dir` (until none are produced for
  the timeout window) and prints the model's predictions for a small,
  hard-coded list of natural-language commands.
  """
  for ckpt_path in trainer_lib.next_checkpoint(experiment_dir, timeout_mins=-1):
    hparams = seq2act_estimator.load_hparams(experiment_dir)
    hparams.set_hparam("batch_size", FLAGS.eval_batch_size)

    # Small smoke-test inputs; presumably these match the "input_refs"
    # feature the model expects — TODO confirm against the serving schema.
    raw_texts = [
        "Click button OK",
        "Navigate to settings",
        "Open app drawer",
    ]
    # np.str was a deprecated alias for the builtin str and was removed in
    # NumPy 1.24, so cast with plain str instead.
    predict_input_fn = tf.estimator.inputs.numpy_input_fn(
        {"input_refs": np.array(raw_texts).astype(str)},
        shuffle=False,
        batch_size=1)

    estimator = create_estimator(experiment_dir, hparams,
                                 decode_length=FLAGS.decode_length)

    print("\nSTART PREDICTION\n")
    # Pass checkpoint_path explicitly: without it, estimator.predict always
    # loads the latest checkpoint and the per-checkpoint loop is pointless.
    results = estimator.predict(input_fn=predict_input_fn,
                                checkpoint_path=ckpt_path)
    # `results` is a generator; iterate to materialize each prediction.
    for result in results:
      print(result)
    print("\nEND PREDICTION\n")
def main(_):
  """Evaluate every checkpoint and write per-checkpoint metrics to results_dir.

  For each checkpoint, runs estimator.evaluate on the split selected by
  FLAGS.task_direction, then writes a tab-separated results table and a
  "checklist" marker file under FLAGS.results_dir.

  Raises:
    ValueError: if FLAGS.task_direction is unknown, FLAGS.results_dir is
      unset, or no checkpoints are found.
  """
  if FLAGS.results_dir:
    print("\n\n\n\n\nresults_dir = {}\n\n\n\n\n".format(FLAGS.results_dir))
  print(FLAGS.results_dir and FLAGS.results_dir[:5] == "gs://")

  # Fail fast: validate the output location before any expensive setup.
  if FLAGS.results_dir:
    results_dir = FLAGS.results_dir
  else:
    raise ValueError("results_dir not defined")

  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  hparams = trainer_lib.create_hparams(
      FLAGS.hparams_set,
      FLAGS.hparams,
      data_dir=FLAGS.data_dir,
      problem_name=FLAGS.problem)

  # All non-NORMAL directions share the same behavior (use the split from
  # FLAGS and a custom checkpoint iterator), so group them instead of
  # repeating four identical elif branches.
  split_directions = (
      problem.TaskDirections.Q12,
      problem.TaskDirections.Q8,
      problem.TaskDirections.INTERPOLATE,
      problem.TaskDirections.EXTRAPOLATE,
  )
  if FLAGS.task_direction == problem.TaskDirections.NORMAL:
    dataset_split = "test" if FLAGS.eval_use_test_set else None
  elif FLAGS.task_direction in split_directions:
    dataset_split = FLAGS.dataset_split
  else:
    raise ValueError("Found unknown task_direction which is ",
                     FLAGS.task_direction)

  dataset_kwargs = {"dataset_split": dataset_split}
  eval_input_fn = hparams.problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams, dataset_kwargs=dataset_kwargs)

  config = t2t_trainer.create_run_config(hparams)
  # summary-hook in tf.estimator.EstimatorSpec requires
  # hparams.model_dir to be set.
  hparams.add_hparam("model_dir", config.model_dir)
  estimator = trainer_lib.create_estimator(
      FLAGS.model, hparams, config, use_tpu=FLAGS.use_tpu)

  # NORMAL waits for new checkpoints; the split directions enumerate an
  # explicit set of checkpoints via my_chkpt_iter. No else branch needed:
  # unknown directions were already rejected above.
  if FLAGS.task_direction == problem.TaskDirections.NORMAL:
    ckpt_iter = trainer_lib.next_checkpoint(hparams.model_dir,
                                            FLAGS.eval_timeout_mins)
  else:
    ckpt_iter = my_chkpt_iter(hparams.model_dir)

  results_all_ckpts = []
  for ckpt_path in ckpt_iter:
    results = estimator.evaluate(
        eval_input_fn, steps=FLAGS.eval_steps, checkpoint_path=ckpt_path)
    results_all_ckpts.append(results)
    tf.logging.info(results)

  # Surface an actionable error instead of an opaque IndexError below.
  if not results_all_ckpts:
    raise ValueError("No checkpoints were evaluated in " + hparams.model_dir)

  def build_line(items, labels=False):
    """Join one row of values (or their shortened metric labels) with tabs."""
    items = [str(i) for i in items]
    if labels:
      # Metric names look like "metrics-<problem>/<name>"; keep the tail.
      items = [i.split("/")[-1] for i in items]
    return "\t".join(items) + "\n"

  # Column order is taken from the first checkpoint's result dict and reused
  # for every row so the table stays aligned.
  category_names = list(results_all_ckpts[0].keys())

  # Write to bucket (file_io supports gs:// paths).
  with file_io.FileIO(
      results_dir + "/eval_" + FLAGS.dataset_split + "_results.txt",
      "w") as results_file:
    results_file.write(build_line(category_names, labels=True))
    for r in results_all_ckpts:
      results_file.write(build_line([r[k] for k in category_names]))
  # Marker file recording which split was evaluated.
  with file_io.FileIO(results_dir + "/checklist", "w") as checklist_file:
    checklist_file.write(FLAGS.dataset_split + "\n")