def continuous_eval(experiment_dir):
    """Evaluate each checkpoint yielded for `experiment_dir`.

    Checkpoint paths come from `trainer_lib.next_checkpoint` called with
    `timeout_mins=-1`; how long it keeps waiting for new checkpoints is
    decided there (presumably -1 disables the timeout -- confirm against
    trainer_lib). For every checkpoint the saved hyper-parameters are
    reloaded, their batch size is overridden with `FLAGS.eval_batch_size`,
    and a freshly built estimator evaluates exactly that checkpoint for
    `FLAGS.eval_steps` steps under the name `FLAGS.eval_name`.

    Args:
      experiment_dir: Directory holding the checkpoints and the saved
        hyper-parameters.
    """
    checkpoint_stream = trainer_lib.next_checkpoint(experiment_dir,
                                                    timeout_mins=-1)
    for checkpoint in checkpoint_stream:
        # Hyper-parameters are re-read for every checkpoint.
        eval_hparams = seq2act_estimator.load_hparams(experiment_dir)
        eval_hparams.set_hparam("batch_size", FLAGS.eval_batch_size)

        source = input_utils.DataSource.from_str(FLAGS.eval_data_source)
        wants_dom_dist = eval_hparams.screen_encoder == "gcn"
        input_fn = seq2act_estimator.create_input_fn(
            FLAGS.eval_files,
            eval_hparams.batch_size,
            # NOTE(review): the meaning of these two positional values is
            # defined by create_input_fn, which is not visible here.
            -1,
            2,
            source,
            max_range=eval_hparams.max_span,
            max_dom_pos=eval_hparams.max_dom_pos,
            max_pixel_pos=eval_hparams.max_pixel_pos,
            mean_synthetic_length=eval_hparams.mean_synthetic_length,
            stddev_synthetic_length=eval_hparams.stddev_synthetic_length,
            load_extra=True,
            load_screen=eval_hparams.load_screen,
            load_dom_dist=wants_dom_dist)

        evaluator = create_estimator(experiment_dir,
                                     eval_hparams,
                                     decode_length=FLAGS.decode_length)
        evaluator.evaluate(input_fn=input_fn,
                           steps=FLAGS.eval_steps,
                           checkpoint_path=checkpoint,
                           name=FLAGS.eval_name)


# Example #2
def prediction(experiment_dir, raw_texts=None):
    """Print model predictions for raw instruction strings, per checkpoint.

    For every checkpoint yielded by `trainer_lib.next_checkpoint` (called
    with `timeout_mins=-1`), the saved hyper-parameters are reloaded (batch
    size overridden with `FLAGS.eval_batch_size`), an estimator is built and
    `estimator.predict` is run against that specific checkpoint. The texts
    are fed one string per batch under the feature key "input_refs", and
    every prediction is printed between START/END markers.

    Args:
      experiment_dir: Directory holding the checkpoints and the saved
        hyper-parameters.
      raw_texts: Optional iterable of instruction strings to predict on.
        Defaults to three built-in sample instructions.
    """
    if raw_texts is None:
        raw_texts = [
            "Click button OK", "Navigate to settings", "Open app drawer"
        ]
    # The inputs do not depend on the checkpoint, so build them once.
    # dtype=str replaces the former `np.str` alias, removed in NumPy 1.24.
    input_refs = np.array(list(raw_texts), dtype=str)
    predict_input_fn = tf.estimator.inputs.numpy_input_fn(
        {"input_refs": input_refs},
        shuffle=False,
        batch_size=1)

    for ckpt_path in trainer_lib.next_checkpoint(experiment_dir,
                                                 timeout_mins=-1):
        hparams = seq2act_estimator.load_hparams(experiment_dir)
        hparams.set_hparam("batch_size", FLAGS.eval_batch_size)
        estimator = create_estimator(experiment_dir,
                                     hparams,
                                     decode_length=FLAGS.decode_length)

        print("\nSTART PREDICTION\n")
        # Pin prediction to the checkpoint that triggered this iteration;
        # with checkpoint_path=None, predict() would use the latest one.
        results = estimator.predict(input_fn=predict_input_fn,
                                    checkpoint_path=ckpt_path)
        for result in results:
            print(result)
        print("\nEND PREDICTION\n")


def train(experiment_dir):
    """Train the model for `FLAGS.train_steps` steps.

    Hyper-parameters are loaded from `FLAGS.hparam_file` when it is set,
    otherwise created with `seq2act_estimator.create_hparams()`, and are
    saved to `experiment_dir` before training. `FLAGS.train_file_list`,
    `FLAGS.train_source_list` and `FLAGS.train_batch_sizes` are
    comma-separated lists. More than one training file selects
    `create_hybrid_input_fn` (one data source and batch size per file).
    A single file selects `create_input_fn`, and `train_batch_sizes` is
    then not used.

    Args:
      experiment_dir: Directory passed to `create_estimator` (presumably its
        model_dir) and where the hyper-parameters are saved.
    """
    if FLAGS.hparam_file:
        hparams = seq2act_estimator.load_hparams(FLAGS.hparam_file)
    else:
        hparams = seq2act_estimator.create_hparams()

    estimator = create_estimator(experiment_dir, hparams)
    seq2act_estimator.save_hyperparams(hparams, experiment_dir)
    train_file_list = FLAGS.train_file_list.split(",")
    train_source_list = FLAGS.train_source_list.split(",")
    train_batch_sizes = FLAGS.train_batch_sizes.split(",")
    print("* xm_train", train_file_list, train_source_list, train_batch_sizes)
    if len(train_file_list) > 1:
        # A list, not a lazy `map`: a Python 3 map object is single-pass and
        # would be silently empty if the input_fn iterates it again.
        batch_size_list = [int(size) for size in train_batch_sizes]
        train_input_fn = seq2act_estimator.create_hybrid_input_fn(
            train_file_list,
            [input_utils.DataSource.from_str(s) for s in train_source_list],
            batch_size_list,
            max_range=hparams.max_span,
            max_dom_pos=hparams.max_dom_pos,
            max_pixel_pos=hparams.max_pixel_pos,
            mean_synthetic_length=hparams.mean_synthetic_length,
            stddev_synthetic_length=hparams.stddev_synthetic_length,
            batch_size=hparams.batch_size,
            boost_input=FLAGS.boost_input,
            load_screen=hparams.load_screen,
            buffer_size=FLAGS.shuffle_size,
            shuffle_size=FLAGS.shuffle_size,
            load_dom_dist=(hparams.screen_encoder == "gcn"))
    else:
        train_input_fn = seq2act_estimator.create_input_fn(
            train_file_list[0],
            hparams.batch_size,
            # NOTE(review): positional values whose meaning is defined by
            # create_input_fn (not visible here); kept as-is.
            -1,
            -1,
            input_utils.DataSource.from_str(train_source_list[0]),
            max_range=hparams.max_span,
            max_dom_pos=hparams.max_dom_pos,
            max_pixel_pos=hparams.max_pixel_pos,
            mean_synthetic_length=hparams.mean_synthetic_length,
            stddev_synthetic_length=hparams.stddev_synthetic_length,
            load_extra=False,
            load_screen=hparams.load_screen,
            buffer_size=FLAGS.shuffle_size,
            shuffle_size=FLAGS.shuffle_size,
            load_dom_dist=(hparams.screen_encoder == "gcn"))
    estimator.train(input_fn=train_input_fn, steps=FLAGS.train_steps)