def main(unused_argv):
    """Train the dual-encoder model with periodic validation.

    Creates the hyperparameters, wraps ``dual_encoder_model`` in a model
    function, builds a ``tf.contrib.learn.Estimator`` rooted at ``MODEL_DIR``
    and calls ``fit`` with a ``ValidationMonitor`` attached.

    Args:
        unused_argv: Ignored; present only to match the entry-point signature.
    """
    # Session options: enable the allow_growth GPU option.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    params = create_hparams()
    dual_encoder_fn = model.create_model_fn(
        params, model_impl=dual_encoder_model)

    run_config = tf.contrib.learn.RunConfig(session_config=session_config)
    estimator = tf.contrib.learn.Estimator(
        model_fn=dual_encoder_fn, model_dir=MODEL_DIR, config=run_config)

    # Training data: batch size from the hyperparameters, epoch count from
    # the command line.
    train_input = inputs.create_input_fn(
        mode=tf.contrib.learn.ModeKeys.TRAIN,
        input_files=[TRAIN_FILE],
        batch_size=params.batch_size,
        num_epochs=FLAGS.num_epochs)

    # Validation data: a single pass over the validation file.
    validation_input = inputs.create_input_fn(
        mode=tf.contrib.learn.ModeKeys.EVAL,
        input_files=[VALIDATION_FILE],
        batch_size=params.eval_batch_size,
        num_epochs=1)

    validation_metrics = metrics.create_evaluation_metrics()

    # Evaluate on the validation set every FLAGS.eval_every steps.
    validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
        input_fn=validation_input,
        every_n_steps=FLAGS.eval_every,
        metrics=validation_metrics)

    # No explicit step limit is passed (steps=None).
    estimator.fit(
        input_fn=train_input, steps=None, monitors=[validation_monitor])
예제 #2
0
def main(unused_argv):
    """Train and evaluate an estimator configured from command-line flags.

    Marks the mandatory flags as required, builds the model function from a
    small parameter dict, creates the train/eval input functions and hands
    everything to ``tf.estimator.train_and_evaluate``.

    Args:
        unused_argv: Ignored; present only to match the entry-point signature.
    """
    # NOTE(review): the required-flag validators are registered here, inside
    # main. If main runs after the flags have already been parsed, confirm
    # that these checks still fire for missing flags.
    for required_flag in ('model_dir',
                          'label_map_path',
                          'train_file_pattern',
                          'eval_file_pattern'):
        flags.mark_flag_as_required(required_flag)

    estimator_config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)

    # Parameters forwarded to the model function; the fine-tuning checkpoint
    # is only included when one was supplied.
    model_params = {
        "label_map_path": FLAGS.label_map_path,
        "num_classes": FLAGS.num_classes,
    }
    if FLAGS.finetune_ckpt:
        model_params["finetune_ckpt"] = FLAGS.finetune_ckpt

    estimator = tf.estimator.Estimator(
        model_fn=create_model_fn(model_params), config=estimator_config)

    train_input_fn = create_input_fn(
        FLAGS.train_file_pattern, True, FLAGS.image_size, FLAGS.batch_size)
    eval_input_fn = create_input_fn(
        FLAGS.eval_file_pattern, False, FLAGS.image_size)
    prediction_input_fn = create_prediction_input_fn()

    train_spec, eval_spec = create_train_and_eval_specs(
        train_input_fn,
        eval_input_fn,
        prediction_input_fn,
        FLAGS.num_train_steps)

    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
예제 #3
0
def main(_):
    """Launch training with periodic validation.

    Builds the configuration, wraps ``selected_model`` in a model function,
    creates a ``tf.contrib.learn.Estimator`` rooted at ``MODEL_DIR`` and runs
    ``fit`` with a ``ValidationMonitor`` attached.
    """
    # Hyperparameters / configuration for this run.
    hparams = config.create_config()

    # Model function wrapping the selected model implementation.
    selected_model_fn = model.create_model_fn(
        hparams, model_impl=selected_model)

    # Estimator run settings, passed through to RunConfig unchanged.
    checkpoint_interval_secs = 60 * 2
    run_config = tf.contrib.learn.RunConfig(
        gpu_memory_fraction=0.25,
        save_checkpoints_secs=checkpoint_interval_secs,
        keep_checkpoint_max=1,
        log_device_placement=False)

    estimator = tf.contrib.learn.Estimator(
        model_fn=selected_model_fn,
        model_dir=MODEL_DIR,
        config=run_config)

    # Input function over the training file.
    train_input = inputs.create_input_fn(
        mode=tf.contrib.learn.ModeKeys.TRAIN,
        input_files=[TRAIN_FILE],
        batch_size=hparams.batch_size,
        num_epochs=FLAGS.num_epochs,
        params=hparams)

    # Input function over the validation file (single epoch).
    validation_input = inputs.create_input_fn(
        mode=tf.contrib.learn.ModeKeys.EVAL,
        input_files=[VALIDATION_FILE],
        batch_size=hparams.eval_batch_size,
        num_epochs=1,
        params=hparams)

    # Metrics computed during validation.
    validation_metrics = metrics.create_evaluation_metrics()

    # Run validation every FLAGS.eval_every steps.
    validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
        input_fn=validation_input,
        every_n_steps=FLAGS.eval_every,
        metrics=validation_metrics)

    # Launch training; no explicit step limit is passed (steps=None).
    estimator.fit(
        input_fn=train_input, steps=None, monitors=[validation_monitor])
예제 #4
0
def main(unused_argv):
  """Train the dual-encoder model, evaluating via a custom EveryN monitor.

  Args:
    unused_argv: Ignored; present only to match the entry-point signature.
  """
  # To resume training from an earlier run (and its hyperparameters), point
  # MODEL_DIR at that run's folder before the estimator is built.
  params = hyperparameters.create_hparams()

  dual_encoder_fn = model.create_model_fn(
    params, model_impl=dual_encoder_model)

  run_config = tf.contrib.learn.RunConfig()
  estimator = tf.contrib.learn.Estimator(
    model_fn=dual_encoder_fn, model_dir=MODEL_DIR, config=run_config)

  train_input = inputs.create_input_fn(
    mode=tf.contrib.learn.ModeKeys.TRAIN,
    input_files=[TRAIN_FILE],
    batch_size=params.batch_size,
    num_epochs=FLAGS.num_epochs)

  validation_input = inputs.create_input_fn(
    mode=tf.contrib.learn.ModeKeys.EVAL,
    input_files=[VALIDATION_FILE],
    batch_size=params.eval_batch_size,
    num_epochs=1)

  validation_metrics = metrics.create_evaluation_metrics()

  class EvaluationMonitor(tf.contrib.learn.monitors.EveryN):
    """EveryN monitor that evaluates the estimator on the validation input."""

    def every_n_step_end(self, step, outputs):
      # Uses the validation input and metrics captured from main's scope;
      # the evaluation result is not returned.
      self._estimator.evaluate(
        input_fn=validation_input,
        metrics=validation_metrics,
        steps=None)

  validation_monitor = EvaluationMonitor(
    every_n_steps=FLAGS.eval_every, first_n_steps=-1)

  estimator.fit(
    input_fn=train_input, steps=None, monitors=[validation_monitor])
예제 #5
0
# Command-line flags for the test-set evaluation script.
tf.flags.DEFINE_string("test_file", "./data/test.tfrecords",
                       "Path of test data in TFRecords format")
tf.flags.DEFINE_string("model_dir", None,
                       "Directory to load model checkpoints from")
tf.flags.DEFINE_integer("loglevel", 20, "Tensorflow log level")
tf.flags.DEFINE_integer("test_batch_size", 16, "Batch size for testing")
FLAGS = tf.flags.FLAGS

# A checkpoint directory is mandatory: stop early with a clear message
# instead of continuing without one.
if not FLAGS.model_dir:
    print("You must specify a model directory")
    sys.exit(1)

tf.logging.set_verbosity(FLAGS.loglevel)

if __name__ == "__main__":
    # Bind the hyperparameters to their own name so the imported `hparams`
    # module is not overwritten by the object it creates.
    hyper_params = hparams.create_hparams()
    model_fn = model.create_model_fn(hyper_params,
                                     model_impl=dual_encoder_model)
    estimator = tf.contrib.learn.Estimator(
        model_fn=model_fn,
        model_dir=FLAGS.model_dir,
        config=tf.contrib.learn.RunConfig())

    # Single pass over the test file in EVAL mode.
    input_fn_test = inputs.create_input_fn(
        mode=tf.contrib.learn.ModeKeys.EVAL,
        input_files=[FLAGS.test_file],
        batch_size=FLAGS.test_batch_size,
        num_epochs=1)

    eval_metrics = metrics.create_evaluation_metrics()

    # No explicit step limit is passed (steps=None); the returned value is
    # not used here.
    estimator.evaluate(input_fn=input_fn_test,
                       steps=None,
                       metrics=eval_metrics)
예제 #6
0
파일: train.py 프로젝트: iamgroot42/dstc7
def main(unused_argv):
    """Train the dual-encoder model, or write test-set predictions to a file.

    When ``FLAGS.infer_mode`` is set, predictions for ``TEST_FILE`` are
    written to ``FLAGS.test_out`` as one comma-separated line per prediction.
    Otherwise the model is trained with periodic validation.

    Args:
        unused_argv: Ignored; present only to match the entry-point signature.
    """
    # Session options: enable the allow_growth GPU option.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    params = create_hparams()
    print("\n\nModel hyperparameters", params)

    dual_encoder_fn = model.create_model_fn(
        params, model_impl=dual_encoder_model)

    run_config = tf.contrib.learn.RunConfig(session_config=session_config)
    estimator = tf.contrib.learn.Estimator(
        model_fn=dual_encoder_fn, model_dir=MODEL_DIR, config=run_config)

    # Testing mode: predict and dump the scores, then stop.
    if FLAGS.infer_mode:
        infer_input = inputs.create_input_fn(
            mode=tf.contrib.learn.ModeKeys.INFER,
            input_files=[TEST_FILE],
            batch_size=params.eval_batch_size,
            num_epochs=1,
            has_dssm=params.dssm,
            has_lcs=params.lcs,
            randomize=False)

        predictions = estimator.predict(input_fn=infer_input)
        with open(FLAGS.test_out, 'w') as out_file:
            # Each prediction becomes one line of 15-decimal values; the
            # running count is printed after every written line.
            for count, prediction in enumerate(predictions, start=1):
                line = ",".join([("%.15f" % score) for score in prediction])
                out_file.write(line + "\n")
                print(count)
        return

    # Training mode.
    train_input = inputs.create_input_fn(
        mode=tf.contrib.learn.ModeKeys.TRAIN,
        input_files=[TRAIN_FILE],
        batch_size=params.batch_size,
        num_epochs=FLAGS.num_epochs,
        has_dssm=params.dssm,
        has_lcs=params.lcs,
    )

    validation_input = inputs.create_input_fn(
        mode=tf.contrib.learn.ModeKeys.EVAL,
        input_files=[VALIDATION_FILE],
        batch_size=params.eval_batch_size,
        num_epochs=1,
        has_dssm=params.dssm,
        has_lcs=params.lcs,
    )

    validation_metrics = metrics.create_evaluation_metrics()

    # Run validation every FLAGS.eval_every steps.
    validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
        input_fn=validation_input,
        every_n_steps=FLAGS.eval_every,
        metrics=validation_metrics,
    )

    estimator.fit(
        input_fn=train_input, steps=None, monitors=[validation_monitor])