Example #1
    def test_estimator(self, hparams_string, batch_size=1):
        """Structural smoke test: Builder's Estimator constructs and evaluates.

        Args:
          hparams_string: Hparam-override string forwarded to the builder's
            `hparams` call.
          batch_size: Batch size for the fake train/eval input functions.
        """
        seed = 42

        # Start from an empty model directory so repeated runs don't interfere.
        model_dir = os.path.join(flags.FLAGS.test_tmpdir,
                                 "AdanetImproveNasBuilderTest")
        if tf.gfile.Exists(model_dir):
            tf.gfile.DeleteRecursively(model_dir)
        tf.gfile.MkDir(model_dir)

        provider = fake_data.FakeImageProvider(seed=seed)
        builder = adanet_improve_nas.Builder()
        params = builder.hparams(default_batch_size=3,
                                 hparams_string=hparams_string)
        config = tf.estimator.RunConfig(tf_random_seed=seed,
                                        model_dir=model_dir)

        # The train input_fn is built for structural parity but not consumed.
        _ = provider.get_input_fn("train",
                                  tf.estimator.ModeKeys.TRAIN,
                                  batch_size=batch_size)
        eval_input_fn = provider.get_input_fn("test",
                                              tf.estimator.ModeKeys.EVAL,
                                              batch_size=batch_size)

        model = builder.estimator(data_provider=provider,
                                  run_config=config,
                                  hparams=params,
                                  train_steps=10,
                                  seed=seed)
        metrics = model.evaluate(input_fn=eval_input_fn, steps=1)

        # A positive loss is the minimal signal that evaluation actually ran.
        self.assertGreater(metrics["loss"], 0.0)
Example #2
File: trainer.py — Project: zyqqing/adanet
def main(argv):
  """Train and evaluate an AdaNet NAS estimator on the flag-selected dataset."""
  del argv  # Unused by this entry point.

  run_config = make_run_config()
  builder = adanet_improve_nas.Builder()
  hparams = builder.hparams(FLAGS.batch_size, FLAGS.hparams)
  tf.logging.info("Running Experiment with HParams: %s", hparams)

  # Map the dataset flag to a zero-argument provider factory.
  provider_factories = {
      "cifar10": cifar10.Provider,
      "cifar100": cifar100.Provider,
      "fake": lambda: fake_data.FakeImageProvider(
          num_examples=10,
          num_classes=10,
          image_dim=32,
          channels=3,
          seed=42),
  }
  factory = provider_factories.get(FLAGS.dataset)
  if factory is None:
    raise ValueError("Invalid dataset")
  data_provider = factory()

  estimator = builder.estimator(
      data_provider=data_provider,
      run_config=run_config,
      hparams=hparams,
      train_steps=FLAGS.train_steps)

  train_input_fn = data_provider.get_input_fn(
      partition="train",
      mode=tf.estimator.ModeKeys.TRAIN,
      batch_size=FLAGS.batch_size)
  eval_input_fn = data_provider.get_input_fn(
      partition="test",
      mode=tf.estimator.ModeKeys.EVAL,
      batch_size=FLAGS.batch_size)

  train_spec = tf.estimator.TrainSpec(
      input_fn=train_input_fn,
      max_steps=FLAGS.train_steps)
  eval_spec = tf.estimator.EvalSpec(
      input_fn=eval_input_fn,
      steps=FLAGS.eval_steps,
      start_delay_secs=10,
      throttle_secs=1800)

  tf.logging.info("Training!")
  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
  tf.logging.info("Done training!")