def test_estimator(self, hparams_string, batch_size=1):
    """Smoke-test that the Builder yields an estimator that can be evaluated.

    Creates a fresh model directory under the test tmpdir, builds an estimator
    from fake image data, runs a single evaluation step and asserts that the
    reported loss is strictly positive.

    Args:
      hparams_string: Forwarded to `Builder.hparams` as `hparams_string`.
      batch_size: Batch size used for both the "train" and "test" input
        functions requested from the data provider.
    """
    seed = 42

    # Start from an empty model directory on every run.
    model_dir = os.path.join(
        flags.FLAGS.test_tmpdir, "AdanetImproveNasBuilderTest")
    if tf.gfile.Exists(model_dir):
        tf.gfile.DeleteRecursively(model_dir)
    tf.gfile.MkDir(model_dir)

    provider = fake_data.FakeImageProvider(seed=seed)
    builder = adanet_improve_nas.Builder()
    params = builder.hparams(
        default_batch_size=3, hparams_string=hparams_string)
    config = tf.estimator.RunConfig(tf_random_seed=seed, model_dir=model_dir)

    # Both partitions are requested, in "train" then "test" order, but only
    # the "test" input_fn is consumed below.
    # NOTE(review): the "train" input_fn is never used; it is still built here
    # in case constructing it has side effects -- confirm before removing.
    partition_modes = (
        ("train", tf.estimator.ModeKeys.TRAIN),
        ("test", tf.estimator.ModeKeys.EVAL),
    )
    input_fns = {
        partition: provider.get_input_fn(
            partition, mode, batch_size=batch_size)
        for partition, mode in partition_modes
    }

    estimator = builder.estimator(
        data_provider=provider,
        run_config=config,
        hparams=params,
        train_steps=10,
        seed=seed)
    metrics = estimator.evaluate(input_fn=input_fns["test"], steps=1)

    self.assertGreater(metrics["loss"], 0.0)
def main(argv):
    """Build an estimator from command-line flags, then train and evaluate it.

    Reads `FLAGS.batch_size`, `FLAGS.hparams`, `FLAGS.dataset`,
    `FLAGS.train_steps` and `FLAGS.eval_steps`.

    Args:
      argv: Unused; deleted immediately.

    Raises:
      ValueError: If `FLAGS.dataset` is not "cifar10", "cifar100" or "fake".
    """
    del argv

    run_config = make_run_config()
    builder = adanet_improve_nas.Builder()
    hparams = builder.hparams(FLAGS.batch_size, FLAGS.hparams)

    tf.logging.info("Running Experiment with HParams: %s", hparams)

    # Factories are wrapped in lambdas so that only the selected provider is
    # looked up and constructed.
    provider_factories = {
        "cifar10": lambda: cifar10.Provider(),
        "cifar100": lambda: cifar100.Provider(),
        "fake": lambda: fake_data.FakeImageProvider(
            num_examples=10, num_classes=10, image_dim=32, channels=3,
            seed=42),
    }
    make_provider = provider_factories.get(FLAGS.dataset)
    if make_provider is None:
        raise ValueError("Invalid dataset")
    provider = make_provider()

    estimator = builder.estimator(
        data_provider=provider,
        run_config=run_config,
        hparams=hparams,
        train_steps=FLAGS.train_steps)

    train_input_fn = provider.get_input_fn(
        partition="train",
        mode=tf.estimator.ModeKeys.TRAIN,
        batch_size=FLAGS.batch_size)
    train_spec = tf.estimator.TrainSpec(
        input_fn=train_input_fn, max_steps=FLAGS.train_steps)

    eval_input_fn = provider.get_input_fn(
        partition="test",
        mode=tf.estimator.ModeKeys.EVAL,
        batch_size=FLAGS.batch_size)
    eval_spec = tf.estimator.EvalSpec(
        input_fn=eval_input_fn,
        steps=FLAGS.eval_steps,
        start_delay_secs=10,
        throttle_secs=1800)

    tf.logging.info("Training!")
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    tf.logging.info("Done training!")