def continuous_eval(experiment_dir):
  """Evaluates each checkpoint yielded for `experiment_dir`.

  Checkpoints come from `trainer_lib.next_checkpoint` called with
  `timeout_mins=-1` (presumably "wait indefinitely" for new checkpoints --
  confirm against trainer_lib). For every checkpoint the saved hparams are
  reloaded, the batch size is overridden with FLAGS.eval_batch_size, a fresh
  estimator is built, and it is evaluated on FLAGS.eval_files pinned to that
  checkpoint.

  Args:
    experiment_dir: Directory holding the saved hparams and the checkpoints.
  """
  checkpoint_stream = trainer_lib.next_checkpoint(
      experiment_dir, timeout_mins=-1)
  for checkpoint in checkpoint_stream:
    # Hyperparameters are reloaded for every checkpoint.
    hparams = seq2act_estimator.load_hparams(experiment_dir)
    hparams.set_hparam("batch_size", FLAGS.eval_batch_size)

    source = input_utils.DataSource.from_str(FLAGS.eval_data_source)
    input_fn = seq2act_estimator.create_input_fn(
        FLAGS.eval_files,
        hparams.batch_size,
        -1,
        2,
        source,
        max_range=hparams.max_span,
        max_dom_pos=hparams.max_dom_pos,
        max_pixel_pos=hparams.max_pixel_pos,
        mean_synthetic_length=hparams.mean_synthetic_length,
        stddev_synthetic_length=hparams.stddev_synthetic_length,
        load_extra=True,
        load_screen=hparams.load_screen,
        # DOM distance inputs are requested only for the GCN screen encoder.
        load_dom_dist=hparams.screen_encoder == "gcn",
    )

    model = create_estimator(
        experiment_dir, hparams, decode_length=FLAGS.decode_length)
    model.evaluate(
        input_fn=input_fn,
        steps=FLAGS.eval_steps,
        checkpoint_path=checkpoint,
        name=FLAGS.eval_name,
    )
def prediction(experiment_dir, raw_texts=None):
  """Prints model predictions for a few text commands at each checkpoint.

  For every checkpoint yielded by `trainer_lib.next_checkpoint` (called with
  `timeout_mins=-1`, presumably "wait indefinitely" -- confirm against
  trainer_lib), the saved hparams are reloaded, the batch size is overridden
  with FLAGS.eval_batch_size, an estimator is built, and its predictions for
  `raw_texts` at that checkpoint are printed one per line.

  Args:
    experiment_dir: Directory holding the saved hparams and the checkpoints.
    raw_texts: Optional sequence of strings fed to the model as the
      "input_refs" feature. Defaults to three built-in demo commands.
  """
  if raw_texts is None:
    raw_texts = [
        "Click button OK", "Navigate to settings", "Open app drawer"
    ]
  # `np.str` was only an alias of the builtin `str` and was removed in
  # NumPy 1.24; `str` yields the same unicode array.
  input_refs = np.array(raw_texts).astype(str)

  for ckpt_path in trainer_lib.next_checkpoint(experiment_dir,
                                               timeout_mins=-1):
    hparams = seq2act_estimator.load_hparams(experiment_dir)
    hparams.set_hparam("batch_size", FLAGS.eval_batch_size)
    predict_input_fn = tf.estimator.inputs.numpy_input_fn(
        {"input_refs": input_refs}, shuffle=False, batch_size=1)
    estimator = create_estimator(experiment_dir, hparams,
                                 decode_length=FLAGS.decode_length)
    print("\nSTART PREDICTION\n")
    # Pin prediction to the checkpoint being iterated; with no
    # checkpoint_path, Estimator.predict uses the latest checkpoint in its
    # model_dir instead.
    results = estimator.predict(input_fn=predict_input_fn,
                                checkpoint_path=ckpt_path)
    # `predict` returns a lazy generator; iterating it runs inference.
    for result in results:
      print(result)
    print("\nEND PREDICTION\n")
def train(experiment_dir):
  """Trains the model for FLAGS.train_steps steps.

  Hyperparameters are loaded from FLAGS.hparam_file when it is set, otherwise
  created with `seq2act_estimator.create_hparams()`, and then saved to
  `experiment_dir`. FLAGS.train_file_list, FLAGS.train_source_list and
  FLAGS.train_batch_sizes are comma-separated lists. When more than one file
  is given, a hybrid input pipeline is built from all of them (the lists are
  presumably paired by position -- confirm against create_hybrid_input_fn).
  Otherwise a single-source pipeline is built from the first file and source.

  Args:
    experiment_dir: Directory passed to `create_estimator` (presumably the
      model_dir) and where the hyperparameters are saved.
  """
  if FLAGS.hparam_file:
    hparams = seq2act_estimator.load_hparams(FLAGS.hparam_file)
  else:
    hparams = seq2act_estimator.create_hparams()

  estimator = create_estimator(experiment_dir, hparams)
  seq2act_estimator.save_hyperparams(hparams, experiment_dir)

  train_file_list = FLAGS.train_file_list.split(",")
  train_source_list = FLAGS.train_source_list.split(",")
  train_batch_sizes = FLAGS.train_batch_sizes.split(",")
  print("* xm_train", train_file_list, train_source_list, train_batch_sizes)

  # DOM distance inputs are requested only for the GCN screen encoder.
  load_dom_dist = (hparams.screen_encoder == "gcn")

  if len(train_file_list) > 1:
    train_input_fn = seq2act_estimator.create_hybrid_input_fn(
        train_file_list,
        [input_utils.DataSource.from_str(s) for s in train_source_list],
        # A concrete list rather than `map(...)`: on Python 3 `map` is a
        # one-shot lazy iterator with no len(), which breaks any consumer
        # that indexes it or iterates it more than once.
        [int(size) for size in train_batch_sizes],
        max_range=hparams.max_span,
        max_dom_pos=hparams.max_dom_pos,
        max_pixel_pos=hparams.max_pixel_pos,
        mean_synthetic_length=hparams.mean_synthetic_length,
        stddev_synthetic_length=hparams.stddev_synthetic_length,
        batch_size=hparams.batch_size,
        boost_input=FLAGS.boost_input,
        load_screen=hparams.load_screen,
        buffer_size=FLAGS.shuffle_size,
        shuffle_size=FLAGS.shuffle_size,
        load_dom_dist=load_dom_dist)
  else:
    train_input_fn = seq2act_estimator.create_input_fn(
        train_file_list[0],
        hparams.batch_size,
        -1,
        -1,
        input_utils.DataSource.from_str(train_source_list[0]),
        max_range=hparams.max_span,
        max_dom_pos=hparams.max_dom_pos,
        max_pixel_pos=hparams.max_pixel_pos,
        mean_synthetic_length=hparams.mean_synthetic_length,
        stddev_synthetic_length=hparams.stddev_synthetic_length,
        load_extra=False,
        load_screen=hparams.load_screen,
        buffer_size=FLAGS.shuffle_size,
        shuffle_size=FLAGS.shuffle_size,
        load_dom_dist=load_dom_dist)

  estimator.train(input_fn=train_input_fn, steps=FLAGS.train_steps)