def testReturnDatasetFromInputFn(self):
  """Train twice from an input_fn that returns a `Dataset` directly.

  Each training run lasts two steps and installs a freshly built iterator
  saver hook. After each run the values read back from the model directory
  must equal (2, 1) and then (4, 3).
  """

  def _input_fn():
    return dataset_ops.Dataset.range(10)

  model = estimator.Estimator(model_fn=self._model_fn)

  # NOTE(review): the pairs are presumably (global step, last element
  # consumed), so the second pair shows the input position carried over
  # between runs -- confirm against _model_fn / _read_vars.
  for expected_vars in ((2, 1), (4, 3)):
    saver_hook = self._build_iterator_saver_hook(model)
    model.train(_input_fn, steps=2, hooks=[saver_hook])
    observed_vars = self._read_vars(model.model_dir)
    self.assertSequenceEqual(observed_vars, expected_vars)
def testBuildIteratorInInputFn(self):
  """Train twice from an input_fn that builds its own one-shot iterator.

  The input_fn returns the iterator's `get_next()` result rather than the
  dataset itself. Each training run lasts two steps and installs a freshly
  built iterator saver hook. After each run the values read back from the
  model directory must equal (2, 1) and then (4, 3).
  """

  def _input_fn():
    return dataset_ops.Dataset.range(10).make_one_shot_iterator().get_next()

  model = estimator.Estimator(model_fn=self._model_fn)

  # NOTE(review): the pairs are presumably (global step, last element
  # consumed), so the second pair shows the input position carried over
  # between runs -- confirm against _model_fn / _read_vars.
  for expected_vars in ((2, 1), (4, 3)):
    saver_hook = self._build_iterator_saver_hook(model)
    model.train(_input_fn, steps=2, hooks=[saver_hook])
    observed_vars = self._read_vars(model.model_dir)
    self.assertSequenceEqual(observed_vars, expected_vars)