def _check_dimensions(self, partition):
  """Checks that features and labels have the expected shapes."""
  provider = cifar100.Provider(seed=4)
  input_fn = provider.get_input_fn(
      partition, tf.contrib.learn.ModeKeys.TRAIN, batch_size=3)
  data, labels = input_fn()
  self.assertIn(cifar100.FEATURES, data)
  features = data[cifar100.FEATURES]
  init = tf.group(tf.global_variables_initializer(),
                  tf.local_variables_initializer())
  with self.test_session() as sess:
    sess.run(init)
    self.assertEqual((3, 32, 32, 3), sess.run(features).shape)
    self.assertEqual((3, 1), sess.run(labels).shape)
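# Illustrative usage only: a helper like _check_dimensions above is typically
# shared by per-partition test methods along these lines. The method names
# below are assumptions, not necessarily those used in the real test class.
def test_train_dimensions(self):
  self._check_dimensions("train")

def test_test_dimensions(self):
  self._check_dimensions("test")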
def test_no_preprocess(self):
  provider = cifar100.Provider(seed=4)
  input_fn = provider.get_input_fn(
      "train",
      tf.contrib.learn.ModeKeys.TRAIN,
      batch_size=3,
      preprocess=False)
  data, label = input_fn()
  init = tf.group(tf.global_variables_initializer(),
                  tf.local_variables_initializer())
  with self.test_session() as sess:
    sess.run(init)
    # Without preprocessing, pixels keep their raw [0, 255] values.
    self.assertAllEqual([220, 25, 47], sess.run(data["x"])[0][0][0])
    self.assertAllEqual([[47], [5], [52]], sess.run(label))
def main(argv):
  del argv  # Unused.

  run_config = make_run_config()
  estimator_builder = adanet_improve_nas.Builder()
  hparams = estimator_builder.hparams(FLAGS.batch_size, FLAGS.hparams)
  tf.logging.info("Running Experiment with HParams: %s", hparams)

  # Pick the data provider for the requested dataset.
  if FLAGS.dataset == "cifar10":
    data_provider = cifar10.Provider()
  elif FLAGS.dataset == "cifar100":
    data_provider = cifar100.Provider()
  elif FLAGS.dataset == "fake":
    data_provider = fake_data.FakeImageProvider(
        num_examples=10, num_classes=10, image_dim=32, channels=3, seed=42)
  else:
    raise ValueError("Invalid dataset")

  estimator = estimator_builder.estimator(
      data_provider=data_provider,
      run_config=run_config,
      hparams=hparams,
      train_steps=FLAGS.train_steps)

  train_spec = tf.estimator.TrainSpec(
      input_fn=data_provider.get_input_fn(
          partition="train",
          mode=tf.estimator.ModeKeys.TRAIN,
          batch_size=FLAGS.batch_size),
      max_steps=FLAGS.train_steps)

  eval_spec = tf.estimator.EvalSpec(
      input_fn=data_provider.get_input_fn(
          partition="test",
          mode=tf.estimator.ModeKeys.EVAL,
          batch_size=FLAGS.batch_size),
      steps=FLAGS.eval_steps,
      start_delay_secs=10,
      throttle_secs=1800)

  tf.logging.info("Training!")
  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
  tf.logging.info("Done training!")
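# Hypothetical sketch only: main() above calls make_run_config(), which is
# defined elsewhere in the real trainer. A minimal version might build a
# tf.estimator.RunConfig along these lines; the FLAGS.model_dir flag and the
# checkpoint settings here are assumptions for illustration, not the actual
# implementation.
def make_run_config():
  """Builds a RunConfig for the Estimator; values here are illustrative."""
  return tf.estimator.RunConfig(
      model_dir=FLAGS.model_dir,
      save_checkpoints_secs=600,
      keep_checkpoint_max=3)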
def test_basic_preprocess(self):
  provider = cifar100.Provider(params_string="augmentation=basic", seed=4)
  input_fn = provider.get_input_fn(
      "train",
      tf.contrib.learn.ModeKeys.TRAIN,
      batch_size=3,
      preprocess=True)
  data, label = input_fn()
  init = tf.group(tf.global_variables_initializer(),
                  tf.local_variables_initializer())
  with self.test_session() as sess:
    sess.run(init)
    data_result = sess.run(data["x"])
    self.assertEqual((3, 32, 32, 3), data_result.shape)
    self.assertAllEqual([0, 0, 0], data_result[0, 0, 0])
    self.assertAlmostEqual(0.0, data_result[0, -1, 0, 0], places=3)
    self.assertAllEqual([[47], [5], [52]], sess.run(label))
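# Hypothetical sketch (not the actual cifar100.Provider code): "basic"
# augmentation for CIFAR-style images is commonly implemented as zero-padding
# followed by a random 32x32 crop and a random horizontal flip, which is
# consistent with the test above observing [0, 0, 0] at a border pixel. The
# real provider's preprocessing may differ.
def _basic_augmentation_sketch(image):
  """Applies pad-crop-flip augmentation to a single [32, 32, 3] image."""
  image = tf.pad(image, [[4, 4], [4, 4], [0, 0]])  # Zero-pad to 40x40.
  image = tf.random_crop(image, [32, 32, 3])       # Random 32x32 crop.
  image = tf.image.random_flip_left_right(image)   # Random horizontal flip.
  return image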