コード例 #1
0
 def _check_dimensions(self, partition):
     """Asserts the batch shapes produced by the cifar100 TRAIN input_fn.

     Builds an input_fn for `partition` with seed=4 and batch_size=3, then
     checks that:
       * the returned features dict contains `cifar100.FEATURES`;
       * the features tensor evaluates to shape (3, 32, 32, 3);
       * the labels tensor evaluates to shape (3, 1).

     Args:
       partition: Dataset partition forwarded to `get_input_fn`
         (e.g. "train" or "test").
     """
     batch_size = 3
     cifar_provider = cifar100.Provider(seed=4)
     make_batch = cifar_provider.get_input_fn(
         partition,
         tf.contrib.learn.ModeKeys.TRAIN,
         batch_size=batch_size)
     features_dict, labels = make_batch()
     self.assertIn(cifar100.FEATURES, features_dict)
     images = features_dict[cifar100.FEATURES]

     init_op = tf.group(
         tf.global_variables_initializer(),
         tf.local_variables_initializer())
     with self.test_session() as session:
         session.run(init_op)
         # Features and labels are evaluated in separate run() calls; only
         # their shapes are compared, not their contents.
         image_shape = session.run(images).shape
         self.assertEqual((batch_size, 32, 32, 3), image_shape)
         label_shape = session.run(labels).shape
         self.assertEqual((batch_size, 1), label_shape)
コード例 #2
0
    def test_no_preprocess(self):
        """Checks raw (preprocess=False) train-batch output for seed=4.

        Expects element [0][0][0] of the evaluated "x" batch to equal
        [220, 25, 47] and the evaluated labels to equal [[47], [5], [52]].
        """
        cifar_provider = cifar100.Provider(seed=4)
        train_input_fn = cifar_provider.get_input_fn(
            "train",
            tf.contrib.learn.ModeKeys.TRAIN,
            batch_size=3,
            preprocess=False)
        features, labels = train_input_fn()

        init_op = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer())
        with self.test_session() as session:
            session.run(init_op)
            # NOTE(review): images and labels are fetched in separate
            # session.run calls; if the input_fn is iterator-backed, each call
            # may draw a new batch, so the labels checked below are presumably
            # not paired with these images -- confirm.
            images = session.run(features["x"])
            self.assertAllEqual([220, 25, 47], images[0][0][0])
            label_values = session.run(labels)
            self.assertAllEqual([[47], [5], [52]], label_values)
コード例 #3
0
ファイル: trainer.py プロジェクト: zyqqing/adanet
def _make_data_provider(dataset):
  """Returns the data provider for the named dataset.

  Args:
    dataset: Dataset name; one of "cifar10", "cifar100" or "fake".

  Returns:
    A provider exposing `get_input_fn(partition=..., mode=..., batch_size=...)`,
    as used by `main`.

  Raises:
    ValueError: If `dataset` is not one of the supported names. The message
      includes the offending value and the accepted choices.
  """
  if dataset == "cifar10":
    return cifar10.Provider()
  if dataset == "cifar100":
    return cifar100.Provider()
  if dataset == "fake":
    # Small synthetic dataset; image_dim=32 and channels=3 presumably mirror
    # the CIFAR input size so the same model can be exercised quickly.
    return fake_data.FakeImageProvider(
        num_examples=10,
        num_classes=10,
        image_dim=32,
        channels=3,
        seed=42)
  raise ValueError(
      "Invalid dataset: {!r}; expected one of 'cifar10', 'cifar100', "
      "'fake'.".format(dataset))


def main(argv):
  """Trains and evaluates the adanet_improve_nas estimator.

  Builds the run config and hyperparameters from FLAGS, selects the data
  provider named by FLAGS.dataset, and runs tf.estimator.train_and_evaluate.
  Training uses the "train" partition and evaluation uses the "test"
  partition.

  Args:
    argv: Command-line arguments; unused.

  Raises:
    ValueError: If FLAGS.dataset is not "cifar10", "cifar100" or "fake".
  """
  del argv  # Unused.

  run_config = make_run_config()
  estimator_builder = adanet_improve_nas.Builder()
  hparams = estimator_builder.hparams(FLAGS.batch_size, FLAGS.hparams)

  tf.logging.info("Running Experiment with HParams: %s", hparams)
  data_provider = _make_data_provider(FLAGS.dataset)

  estimator = estimator_builder.estimator(
      data_provider=data_provider,
      run_config=run_config,
      hparams=hparams,
      train_steps=FLAGS.train_steps)

  train_spec = tf.estimator.TrainSpec(
      input_fn=data_provider.get_input_fn(
          partition="train",
          mode=tf.estimator.ModeKeys.TRAIN,
          batch_size=FLAGS.batch_size),
      max_steps=FLAGS.train_steps
  )

  # Evaluation starts no sooner than 10s in and is re-run at most once every
  # 1800s (30 minutes).
  eval_spec = tf.estimator.EvalSpec(
      input_fn=data_provider.get_input_fn(
          partition="test",
          mode=tf.estimator.ModeKeys.EVAL,
          batch_size=FLAGS.batch_size),
      steps=FLAGS.eval_steps,
      start_delay_secs=10,
      throttle_secs=1800
  )

  tf.logging.info("Training!")
  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
  tf.logging.info("Done training!")
コード例 #4
0
    def test_basic_preprocess(self):
        """Checks train-batch output with params_string="augmentation=basic".

        With seed=4, batch_size=3 and preprocess=True, expects the evaluated
        "x" batch to have shape (3, 32, 32, 3), its element [0, 0, 0] to equal
        [0, 0, 0], its element [0, -1, 0, 0] to be 0.0 to 3 decimal places,
        and the evaluated labels to equal [[47], [5], [52]].
        """
        augmented_provider = cifar100.Provider(
            params_string="augmentation=basic", seed=4)
        train_input_fn = augmented_provider.get_input_fn(
            "train",
            tf.contrib.learn.ModeKeys.TRAIN,
            batch_size=3,
            preprocess=True)
        features, labels = train_input_fn()

        init_op = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer())
        with self.test_session() as session:
            session.run(init_op)
            images = session.run(features["x"])
            self.assertEqual((3, 32, 32, 3), images.shape)
            first_image = images[0]
            self.assertAllEqual([0, 0, 0], first_image[0, 0])
            self.assertAlmostEqual(0.0, first_image[-1, 0, 0], places=3)
            # NOTE(review): labels are fetched in a separate session.run call
            # from the images above; if the input_fn is iterator-backed this
            # may be a different batch -- confirm.
            self.assertAllEqual([[47], [5], [52]], session.run(labels))