def testTFDatasetFnInput_ShuffleBufferMustBeSet(self):
  """Checks that shuffle_buffer_size is mandatory for non-sequential input."""

  def _BuildAndPull(**overrides):
    # Build the datasource with the given overrides and pull one batch so
    # that parameter validation actually runs.
    params = datasource.TFDatasetFnInput.Params().Set(
        load_fn='LoadDataset',
        kwargs=dict(file_pattern=os.path.join(self.tmpdir, '*file_*')),
        **overrides)
    source = params.Instantiate()
    source.SetInputGenerator(TestInputGenerator.Params().Instantiate())
    source.GetNext()

  with cluster_factory.SetRequireSequentialInputOrder(False):
    with self.assertRaisesRegex(ValueError,
                                'shuffle_buffer_size must be set.'):
      _BuildAndPull()

  # Setting shuffle_buffer_size works.
  with cluster_factory.SetRequireSequentialInputOrder(False):
    _BuildAndPull(shuffle_buffer_size=1)

  # Setting require_sequential_input_order works.
  with cluster_factory.SetRequireSequentialInputOrder(True):
    _BuildAndPull()

  # Sanity check that params are not persisting between calls.
  with cluster_factory.SetRequireSequentialInputOrder(False):
    with self.assertRaisesRegex(ValueError,
                                'shuffle_buffer_size must be set.'):
      _BuildAndPull()
def setUp(self):
  """Creates the global step and enters a sequential-input cluster context."""
  super().setUp()
  # Ensure the global_step variable is created in the default graph.
  py_utils.GetOrCreateGlobalStepVar()
  cluster = cluster_factory.SetRequireSequentialInputOrder(True)
  cluster.params.in_unit_test = True
  cluster.__enter__()
  # Fix: the context manager was entered but never exited, so cluster state
  # leaked past each test. Exit it at teardown.
  self.addCleanup(cluster.__exit__, None, None, None)
def testTFDatasetBatchBySequenceLength(self):
  """Batches same-length files and drops files above the last bucket bound."""
  inner_params = datasource.TFDatasetFnInput.Params().Set(
      load_fn='LoadDataset',
      kwargs=dict(file_pattern=os.path.join(self.tmpdir, '*file_*')),
      shuffle_buffer_size=100)
  bucketed_params = datasource.TFDatasetBatchBySequenceLength.Params().Set(
      sub=inner_params,
      seqlen_fn='GetSequenceLength',
      input_shape_fn='_InputShape',
      input_padding_fn='_InputPaddingValue',
      bucket_upper_bound=[
          len(os.path.join(self.tmpdir, 'file_1')),
          len(os.path.join(self.tmpdir, 'longfile_1')),
      ],
      bucket_batch_limit=[8, 8])
  ds = bucketed_params.Instantiate()
  ds.SetInputGenerator(TestInputGenerator.Params().Instantiate())

  with self.session(), cluster_factory.SetRequireSequentialInputOrder(
      False):
    batch = ds.GetNext()
    seen = set()
    for _ in range(20):
      files = self.evaluate(batch.data)
      self.assertEqual(len(files), 8)
      seen.update(files)
      basenames = [os.path.basename(f) for f in files]
      # Batch contains different files of the same length.
      self.assertGreater(len(set(basenames)), 1)
      # But everything in the batch is the same length.
      self.assertLen({len(b) for b in basenames}, 1)

    # Longer than bucket_upper_bound[-1] is filtered out.
    longerfile = os.path.join(self.tmpdir, 'longerfile_1').encode()
    self.assertEqual(set(seen), set(self.files) - set([longerfile]))
def setUp(self):
  """Sets up a VariableStore and a sequential-input cluster context."""
  super().setUp()
  stack = contextlib.ExitStack()
  # Close the stack at teardown so every entered context is exited and no
  # state leaks between tests.
  self.addCleanup(stack.close)
  stack.enter_context(py_utils.VariableStore())
  # Ensure the global_step variable is created in the default graph.
  py_utils.GetOrCreateGlobalStepVar()
  cluster = cluster_factory.SetRequireSequentialInputOrder(True)
  cluster.params.in_unit_test = True
  # Fix: the cluster context was entered via a bare __enter__() with no
  # matching __exit__; entering it through the stack guarantees cleanup.
  stack.enter_context(cluster)
def testTFDatasetFnInput(self):
  """Reads every file repeatedly and checks the order is non-deterministic."""
  params = datasource.TFDatasetFnInput.Params().Set(
      load_fn='LoadDataset',
      kwargs=dict(file_pattern=os.path.join(self.tmpdir, '*file_*')),
      shuffle_buffer_size=100)
  ds = params.Instantiate()
  ds.SetInputGenerator(TestInputGenerator.Params().Instantiate())

  seen = []
  with self.session(), cluster_factory.SetRequireSequentialInputOrder(False):
    batch = ds.GetNext()
    for _ in range(len(self.files) * 5):
      file, source_id = self.evaluate([batch.data, batch.source_id])
      self.assertEqual(0, source_id)
      self.assertIn(file, self.files)
      seen.append(file)
  # Every file is eventually produced.
  self.assertEqual(set(seen), set(self.files))
  # Should not be produced in deterministic order.
  self.assertNotAllEqual(self.files * 5, seen)
def __init__(self):
  """Resolves the checkpoint to decode and the output directory from FLAGS."""
  # TODO(llion): Find a more sensible fix.
  cluster_factory.SetRequireSequentialInputOrder(True).__enter__()
  index_files = sorted(
      glob.glob(os.path.join(FLAGS.ckpt, "ckpt-*.index")))
  if FLAGS.ckpt_limit == -1:
    # No limit: take the newest checkpoint.
    self._ckpt = index_files[-1].replace(".index", "")
  else:
    # Take the newest checkpoint whose step number does not exceed the
    # limit; the glob is sorted, so the last one seen before the break wins.
    chosen = None
    for index_file in index_files:
      base = index_file.replace(".index", "")
      step = int(base.split("/")[-1].replace("ckpt-", ""))
      if step > FLAGS.ckpt_limit:
        break
      chosen = base
    assert chosen is not None
    self._ckpt = chosen
  # Fall back to the checkpoint's own directory when no decode_dir is given.
  self._decode_path = FLAGS.decode_dir or os.path.dirname(FLAGS.ckpt)
  sys.stderr.write("Using checkpoint: {}\n".format(self._ckpt))
def setUp(self):
  """Enters a non-sequential-input cluster context and stubs variables."""
  super().setUp()
  cluster = cluster_factory.SetRequireSequentialInputOrder(False)
  cluster.__enter__()
  # Fix: the context manager was entered and never exited, leaking cluster
  # state into subsequent tests. Exit it at teardown.
  self.addCleanup(cluster.__exit__, None, None, None)
  # Per-test variable cache used by the CreateVariable stub.
  self._variable_cache = {}
  _StubOutCreateVariable(self._variable_cache)