Example #1
0
  def testReturnDatasetFromInputFn(self):
    """Trains twice with an input_fn that returns a `Dataset` directly.

    Each run trains for two steps with a freshly built iterator saver hook,
    then compares the values read back from the model directory.
    """
    # NOTE(review): the (2, 1) -> (4, 3) progression presumably means the
    # second run resumes the input pipeline instead of restarting it --
    # confirm against `_read_vars` and `_model_fn`.
    expected_after_each_run = ((2, 1), (4, 3))

    def _input_fn():
      return dataset_ops.Dataset.range(10)

    est = estimator.Estimator(model_fn=self._model_fn)

    for expected_vars in expected_after_each_run:
      saver_hook = self._build_iterator_saver_hook(est)
      est.train(_input_fn, steps=2, hooks=[saver_hook])
      self.assertSequenceEqual(self._read_vars(est.model_dir), expected_vars)
Example #2
0
  def testBuildIteratorInInputFn(self):
    """Trains twice with an input_fn that builds its own one-shot iterator.

    The input_fn returns the iterator's `get_next()` result rather than the
    dataset. Each run trains for two steps with a freshly built iterator
    saver hook, then compares the values read back from the model directory.
    """

    def _input_fn():
      # Build the iterator inside the input_fn and hand back its next element.
      return dataset_ops.Dataset.range(10).make_one_shot_iterator().get_next()

    est = estimator.Estimator(model_fn=self._model_fn)

    # NOTE(review): the second tuple presumably shows the input pipeline
    # continuing from where the first run stopped -- confirm against
    # `_read_vars` and `_model_fn`.
    for expected_vars in ((2, 1), (4, 3)):
      saver_hook = self._build_iterator_saver_hook(est)
      est.train(_input_fn, steps=2, hooks=[saver_hook])
      self.assertSequenceEqual(self._read_vars(est.model_dir), expected_vars)