Example #1
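# Assumes module-level imports not shown in this snippet: os, tensorflow as tf,
# and the project-local hooks, optimizer, and common_io modules, plus the
# FLAGS, MODELS, and DATASETS globals.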
def experiment_fn(run_config, hparams):
    estimator = tf.estimator.Estimator(
        model_fn=optimizer.make_model_fn(MODELS[FLAGS.model].model,
                                         FLAGS.num_gpus),
        config=run_config,
        params=hparams)
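    # hooks is a project-local module; from the names, these log training
    # throughput and any tensors registered under the "batch_logging" and
    # "logging" graph collections.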
    train_hooks = [
        hooks.ExamplesPerSecondHook(batch_size=hparams.batch_size,
                                    every_n_iter=FLAGS.save_summary_steps),
        hooks.LoggingTensorHook(collection="batch_logging",
                                every_n_iter=FLAGS.save_summary_steps,
                                batch=True),
        hooks.LoggingTensorHook(collection="logging",
                                every_n_iter=FLAGS.save_summary_steps,
                                batch=False)
    ]
    eval_hooks = [
        hooks.SummarySaverHook(every_n_iter=FLAGS.save_summary_steps,
                               output_dir=os.path.join(run_config.model_dir,
                                                       "eval"))
    ]
    experiment = tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=common_io.make_input_fn(
            DATASETS[FLAGS.dataset],
            tf.estimator.ModeKeys.TRAIN,
            hparams,
            num_epochs=FLAGS.num_epochs,
            shuffle_batches=FLAGS.shuffle_batches,
            num_threads=FLAGS.num_reader_threads),
        eval_input_fn=common_io.make_input_fn(
            DATASETS[FLAGS.dataset],
            tf.estimator.ModeKeys.EVAL,
            hparams,
            num_epochs=FLAGS.num_epochs,
            shuffle_batches=FLAGS.shuffle_batches,
            num_threads=FLAGS.num_reader_threads),
        eval_steps=None,
        min_eval_frequency=FLAGS.eval_frequency,
        eval_hooks=eval_hooks)
    experiment.extend_train_hooks(train_hooks)
    return experiment
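
This experiment_fn has the (run_config, hparams) signature that tf.contrib.learn.learn_runner.run expects in TF 1.x. A minimal driver sketch, assuming run_config and hparams are constructed as in the examples below:

from tensorflow.contrib.learn import learn_runner

learn_runner.run(
    experiment_fn=experiment_fn,    # builds the Experiment above
    run_config=run_config,          # tf.contrib.learn.RunConfig (see Example #2)
    hparams=hparams,                # tf.contrib.training.HParams
    schedule="train_and_evaluate")  # alternate training with evaluation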
Example #2
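# Assumes module-level imports not shown: os, tensorflow as tf, and the
# project-local hparams, optimizer, and common_io modules, plus FLAGS, MODELS,
# and DATASETS.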
def main(unused_argv):
    if FLAGS.output_dir:
        model_dir = FLAGS.output_dir
    else:
        raise NotImplementedError

    DATASETS[FLAGS.dataset].prepare()

    # Let ops without a GPU kernel fall back to CPU, and grow GPU memory on
    # demand instead of reserving it all up front.
    session_config = tf.ConfigProto()
    session_config.allow_soft_placement = True
    session_config.gpu_options.allow_growth = True
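    # save_checkpoints_secs=None turns off the time-based default so that
    # checkpointing is purely step-driven.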
    run_config = tf.contrib.learn.RunConfig(
        model_dir=model_dir,
        save_summary_steps=FLAGS.save_summary_steps,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        save_checkpoints_secs=None,
        session_config=session_config)

    params = hparams.get_params(MODELS[FLAGS.model], DATASETS[FLAGS.dataset],
                                FLAGS.hparams)
    estimator = tf.estimator.Estimator(
        model_fn=optimizer.make_model_fn(MODELS[FLAGS.model].model,
                                         FLAGS.num_gpus),
        config=run_config,
        params=params)

    y = estimator.predict(input_fn=common_io.make_input_fn(
        DATASETS[FLAGS.dataset],
        tf.estimator.ModeKeys.PREDICT,
        params,
        num_epochs=1,
        shuffle_batches=False,
        num_threads=FLAGS.num_reader_threads))
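    # Estimator.predict returns a generator: predictions are produced lazily
    # as the loop below consumes them.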

    print("fname,label")
    words = DATASETS[FLAGS.dataset].WORDS
    for fname, p in zip(DATASETS[FLAGS.dataset].TEST_LIST, y):
        print(fname, words[p["predictions"]], sep=',')
Example #3
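# Assumes module-level imports not shown: os, numpy as np, tensorflow as tf,
# and the project-local config, model, dataset, optimizer, io_utils,
# misc_utils, and hooks modules, plus FLAGS.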
def main(unused_argv):
  cfg = config.get_config(FLAGS.config, FLAGS.override)
  print("Configuration loaded: ")
  print(cfg)

  if not cfg.experiment:
    if FLAGS.config:
      cfg.experiment = os.path.splitext(os.path.basename(FLAGS.config))[0]
    else:
      cfg.experiment = "default"

  model_dir = config.get_model_dir(cfg)
  if not os.path.exists(model_dir):
    os.makedirs(model_dir)
  
  tf.logging.set_verbosity(tf.logging.INFO)

  session_config = tf.ConfigProto()
  session_config.allow_soft_placement = True
  session_config.gpu_options.allow_growth = True  # pylint: disable=no-member

  run_config = tf.estimator.RunConfig(
    model_dir=model_dir,
    save_summary_steps=cfg.save_summary_steps,
    save_checkpoints_steps=cfg.save_checkpoints_steps,
    save_checkpoints_secs=None,
    session_config=session_config)
  
  # Resolve the model, dataset, and hyperparameters named in the config, then
  # let the dataset prepare any inputs it needs.
  m = model.ModelFactory.create(cfg.model)
  d = dataset.DatasetFactory.create(cfg.dataset)
  hp = config.get_params(m, d, cfg.hparams)
  d.prepare(hp)

  estimator = tf.estimator.Estimator(
    model_fn=optimizer.make_model_fn(
      m.model, d.process, cfg.num_gpus, cfg.gpu_id, 
      hp.weight_averaging_decay),
    config=run_config, 
    params=hp)

  def eval_input_fn(eval_set):  
    return io_utils.make_input_fn(
      d, eval_set, tf.estimator.ModeKeys.EVAL, hp,
      num_epochs=1,
      num_threads=cfg.num_reader_threads,
      prefetch_buffer_size=cfg.prefetch_buffer_size)

  def _predict():
    results = {}
    for eval_set in cfg.eval_sets:
      result_iterator = estimator.predict(
        input_fn=eval_input_fn(eval_set))
      result = {}
      for item in result_iterator:
        for k, v in item.items():
          result.setdefault(k, []).append(np.array(v).tolist())
      results[eval_set] = result
    return results

  def _eval():
    results = {}
    for eval_set in cfg.eval_sets:
      metrics = estimator.evaluate(
        input_fn=eval_input_fn(eval_set),
        hooks=[
          hooks.SummarySaverHook(
            every_n_iter=cfg.save_summary_steps,
            output_dir=os.path.join(run_config.model_dir, "eval_" + eval_set))],
        name=eval_set)
      results[eval_set] = metrics
      print(metrics)
    return results

  def _train():
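    # Refuse to reuse a model_dir created with different hyperparameters:
    # params.json is written on the first run and must match exactly afterwards
    # unless --overwrite_params is set.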
    params_path = os.path.join(model_dir, "params.json")
    if os.path.exists(params_path) and not FLAGS.overwrite_params:
      with open(params_path, "r") as fp:
        if not fp.read() == str(hp):
          raise RuntimeError("Mismatching parameters found.")
    else:
      with open(params_path, "w") as fp:
        fp.write(str(hp))

    train_sets = (
      cfg.train_sets.to_dict() 
      if isinstance(cfg.train_sets, misc_utils.Tuple) 
      else cfg.train_sets)      
    estimator.train(
      input_fn=io_utils.make_input_fn(
        d, train_sets, 
        tf.estimator.ModeKeys.TRAIN, hp,
        num_epochs=cfg.num_epochs,
        shuffle_batches=cfg.shuffle_batches,
        num_threads=cfg.num_reader_threads,
        prefetch_buffer_size=cfg.prefetch_buffer_size),
      hooks=[
        hooks.ExamplesPerSecondHook(
          batch_size=hp.batch_size,
          every_n_iter=cfg.save_summary_steps),
        hooks.LoggingTensorHook(
          collection="batch_logging",
          every_n_iter=cfg.save_summary_steps,
          batch=True),
        hooks.LoggingTensorHook(
          collection="logging",
          every_n_iter=cfg.save_summary_steps,
          batch=False),
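        # Checkpoints are saved on a step schedule; the BestCheckpointKeeper
        # listener presumably re-runs _eval after each save and keeps the
        # checkpoint that scores best under the configured compare_fn/metric.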
        tf.train.CheckpointSaverHook(
          model_dir,
          save_steps=cfg.save_checkpoints_steps,
          listeners=[
            hooks.BestCheckpointKeeper(
              model_dir,
              eval_fn=_eval,
              eval_set=cfg.checkpoint_selector.eval_set, 
              eval_metric=cfg.checkpoint_selector.eval_metric, 
              compare_fn=cfg.checkpoint_selector.compare_fn)])])

  def _export():
    serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
        m.get_features(hp))

    estimator.export_saved_model(
      os.path.join(model_dir, "export"), 
      serving_input_fn)
    
    # Rebuild the inference graph in a fresh default graph, purely to report
    # the feature/prediction tensor names and dump a pbtxt for inspection.
    tf.reset_default_graph()
    with tf.Session() as sess:
      features = serving_input_fn().features
      predictions = m.model(features, None, tf.estimator.ModeKeys.PREDICT, hp)[0]
      print("Features", {k: v.name for k, v in features.items()})
      print("Predictions", {k: v.name for k, v in predictions.items()})
      tf.train.write_graph(sess.graph_def, model_dir, 'graph_eval.pbtxt')
 
  if cfg.mode == "train":
    _train()
    if FLAGS.results:
      results = _eval()
      with open(FLAGS.results, "w") as f:
        f.write(misc_utils.serialize_json(results))
  elif cfg.mode in ("eval", "predict"):
    results = _eval() if cfg.mode == "eval" else _predict()
    if FLAGS.results:
      with open(FLAGS.results, "w") as f:
        f.write(misc_utils.serialize_json(results))
  elif cfg.mode == "export":
    _export()
  else:
    print("Unrecognized mode", cfg.mode)
  print("Done.")