Code example #1
    def testTFDatasetFnInput_ShuffleBufferMustBeSet(self):
        def CreateDatasource(**kwargs):
            ds_params = datasource.TFDatasetFnInput.Params().Set(
                load_fn='LoadDataset',
                kwargs=dict(file_pattern=os.path.join(self.tmpdir, '*file_*')),
                **kwargs)
            ds = ds_params.Instantiate()
            ds.SetInputGenerator(TestInputGenerator.Params().Instantiate())
            ds.GetNext()

        with cluster_factory.SetRequireSequentialInputOrder(False):
            with self.assertRaisesRegex(ValueError,
                                        'shuffle_buffer_size must be set.'):
                CreateDatasource()

        # Setting shuffle_buffer_size works.
        with cluster_factory.SetRequireSequentialInputOrder(False):
            CreateDatasource(shuffle_buffer_size=1)

        # Setting require_sequential_input_order works.
        with cluster_factory.SetRequireSequentialInputOrder(True):
            CreateDatasource()

        # Sanity check that params are not persisting between calls.
        with cluster_factory.SetRequireSequentialInputOrder(False):
            with self.assertRaisesRegex(ValueError,
                                        'shuffle_buffer_size must be set.'):
                CreateDatasource()
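The test above pins down the contract: when the cluster does not require sequential input order, TFDatasetFnInput refuses to build without an explicit shuffle buffer. A minimal sketch of the valid configuration (the file pattern here is hypothetical):

    with cluster_factory.SetRequireSequentialInputOrder(False):
        ds_params = datasource.TFDatasetFnInput.Params().Set(
            load_fn='LoadDataset',
            kwargs=dict(file_pattern='/tmp/data/*file_*'),  # hypothetical pattern
            shuffle_buffer_size=100)  # required once sequential order is off
        ds = ds_params.Instantiate()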
Code example #2
 def setUp(self):
     super().setUp()
     # Ensure the global_step variable is created in the default graph.
     py_utils.GetOrCreateGlobalStepVar()
     cluster = cluster_factory.SetRequireSequentialInputOrder(True)
     cluster.params.in_unit_test = True
     cluster.__enter__()
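Note that this setUp enters the cluster context but never exits it, so the setting leaks past the test (the later setUp variants share this pattern). A cleanup-safe sketch of the same fixture, assuming only the standard unittest addCleanup API:

     def setUp(self):
         super().setUp()
         py_utils.GetOrCreateGlobalStepVar()
         cluster = cluster_factory.SetRequireSequentialInputOrder(True)
         cluster.params.in_unit_test = True
         cluster.__enter__()
         # Undo __enter__ once the test finishes so the cluster context
         # does not bleed into subsequent tests.
         self.addCleanup(cluster.__exit__, None, None, None)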
Code example #3
 def testTFDatasetBatchBySequenceLength(self):
     ds_params = datasource.TFDatasetFnInput.Params().Set(
         load_fn='LoadDataset',
         kwargs=dict(file_pattern=os.path.join(self.tmpdir, '*file_*')),
         shuffle_buffer_size=100)
     ds_params = datasource.TFDatasetBatchBySequenceLength.Params().Set(
         sub=ds_params,
         seqlen_fn='GetSequenceLength',
         input_shape_fn='_InputShape',
         input_padding_fn='_InputPaddingValue',
         bucket_upper_bound=[
             len(os.path.join(self.tmpdir, 'file_1')),
             len(os.path.join(self.tmpdir, 'longfile_1'))
         ],
         bucket_batch_limit=[8, 8])
     ds = ds_params.Instantiate()
     ds.SetInputGenerator(TestInputGenerator.Params().Instantiate())
     with self.session(), cluster_factory.SetRequireSequentialInputOrder(
             False):
         batch = ds.GetNext()
         seen = set()
         for _ in range(20):
             files = self.evaluate(batch.data)
             self.assertEqual(len(files), 8)
             seen.update(files)
             basenames = [os.path.basename(file) for file in files]
             # Batch contains different files of the same length.
             self.assertGreater(len(set(basenames)), 1)
             # But everything in the batch is the same length.
             self.assertLen(set([len(basename) for basename in basenames]),
                            1)
         # Longer than bucket_upper_bound[-1] is filtered out.
         longerfile = os.path.join(self.tmpdir, 'longerfile_1').encode()
         self.assertEqual(set(seen), set(self.files) - set([longerfile]))
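TFDatasetBatchBySequenceLength buckets elements by the length returned from seqlen_fn and batches each bucket separately; anything longer than the last upper bound is dropped, as the final assertion checks. Conceptually it resembles the plain tf.data transformation sketched below (an approximation, not lingvo's implementation; length_fn is a hypothetical function returning a scalar tf.int32):

    import tensorflow as tf

    def BucketBySequenceLength(dataset, length_fn, upper_bounds, batch_limits):
        # Drop anything longer than the last bucket, mirroring the
        # "filtered out" behavior asserted above.
        dataset = dataset.filter(lambda x: length_fn(x) <= upper_bounds[-1])
        return dataset.apply(
            tf.data.experimental.bucket_by_sequence_length(
                element_length_func=length_fn,
                # tf.data boundaries are exclusive, lingvo's upper bounds are
                # inclusive, hence the +1.
                bucket_boundaries=[b + 1 for b in upper_bounds],
                # One batch size per bucket, plus one for the overflow bucket
                # (already emptied by the filter above).
                bucket_batch_sizes=list(batch_limits) + [1]))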
Code example #4
File: test_utils.py  Project: tensorflow/lingvo
 def setUp(self):
     super().setUp()
     with contextlib.ExitStack() as stack:
         stack.enter_context(py_utils.VariableStore())
         # pop_all() transfers the entered contexts onto a fresh stack so
         # they outlive this `with` block; addCleanup then closes them after
         # the test, releasing the VariableStore.
         self.addCleanup(stack.pop_all().close)
     # Ensure the global_step variable is created in the default graph.
     py_utils.GetOrCreateGlobalStepVar()
     cluster = cluster_factory.SetRequireSequentialInputOrder(True)
     cluster.params.in_unit_test = True
     cluster.__enter__()
Code example #5
 def testTFDatasetFnInput(self):
   ds_params = datasource.TFDatasetFnInput.Params().Set(
       load_fn='LoadDataset',
       kwargs=dict(file_pattern=os.path.join(self.tmpdir, '*file_*')),
       shuffle_buffer_size=100)
   ds = ds_params.Instantiate()
   ds.SetInputGenerator(TestInputGenerator.Params().Instantiate())
   files = []
   with self.session(), cluster_factory.SetRequireSequentialInputOrder(False):
     batch = ds.GetNext()
     for _ in range(len(self.files) * 5):
       file, source_id = self.evaluate([batch.data, batch.source_id])
       self.assertEqual(0, source_id)
       self.assertIn(file, self.files)
       files.append(file)
   self.assertEqual(set(files), set(self.files))
   # Should not be produced in deterministic order.
   self.assertNotAllEqual(self.files * 5, files)
Code example #6
 def __init__(self):
     # TODO(llion): Find a more sensible fix.
     cluster_factory.SetRequireSequentialInputOrder(True).__enter__()
     checkpoint_glob = sorted(
         glob.glob(os.path.join(FLAGS.ckpt, "ckpt-*.index")),
         # Sort numerically by step; plain lexicographic order would put
         # e.g. ckpt-9 after ckpt-10000.
         key=lambda p: int(p.rsplit("ckpt-", 1)[-1][:-len(".index")]))
     if FLAGS.ckpt_limit == -1:
         self._ckpt = checkpoint_glob[-1].replace(".index", "")
     else:
         # Walk forward to the newest checkpoint whose step number does not
         # exceed FLAGS.ckpt_limit.
         last_ckpt = None
         for idx in checkpoint_glob:
             ckpt_base = idx.replace(".index", "")
             value = int(ckpt_base.split("/")[-1].replace("ckpt-", ""))
             if value > FLAGS.ckpt_limit:
                 break
             last_ckpt = ckpt_base
         assert last_ckpt is not None
         self._ckpt = last_ckpt
     self._decode_path = FLAGS.decode_dir or os.path.dirname(FLAGS.ckpt)
     sys.stderr.write("Using checkpoint: {}\n".format(self._ckpt))
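The TODO refers to the dangling __enter__ call, which is never paired with an __exit__. One more contained alternative (a sketch; the rest of __init__ is assumed unchanged) ties the cluster context to the object's lifetime with contextlib.ExitStack:

     import contextlib

     def __init__(self):
         # Hold the cluster context open for this object's lifetime instead
         # of leaking a bare __enter__; self._context.close() releases it.
         self._context = contextlib.ExitStack()
         self._context.enter_context(
             cluster_factory.SetRequireSequentialInputOrder(True))
         ...  # remainder of __init__ as above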
Code example #7
 def setUp(self):
     super().setUp()
     cluster_factory.SetRequireSequentialInputOrder(False).__enter__()
     self._variable_cache = {}
     # Test helper (defined elsewhere) that patches variable creation to
     # serve variables from the shared cache.
     _StubOutCreateVariable(self._variable_cache)