def _shuffle_dataset(dataset, hparams, dataset_files): dataset_size = None shuffle_buffer_size = hparams["shuffle_buffer_size"] if hparams["shard_and_shuffle"]: if shuffle_buffer_size is None: raise ValueError( "Dataset hyperparameter 'shuffle_buffer_size' " "must not be `None` if 'shard_and_shuffle'=`True`.") dataset_size = count_file_lines(dataset_files) if shuffle_buffer_size >= dataset_size: raise ValueError( "Dataset size (%d) <= shuffle_buffer_size (%d). Set " "shuffle_and_shard to `False`." % (dataset_size, shuffle_buffer_size)) # TODO(zhiting): Use a different seed? dataset = dataset.apply( dsutils.random_shard_dataset(dataset_size, shuffle_buffer_size, hparams["seed"])) dataset = dataset.shuffle( shuffle_buffer_size + 16, # add a margin seed=hparams["seed"]) elif hparams["shuffle"]: if shuffle_buffer_size is None: dataset_size = count_file_lines(dataset_files) shuffle_buffer_size = dataset_size dataset = dataset.shuffle(shuffle_buffer_size, seed=hparams["seed"]) return dataset, dataset_size
def test_load_glove(self):
    """Tests :func:`data_utils.count_file_lines`.

    Checks that an empty file has 0 lines, and that passing a list of files
    (the empty file plus a 5-line file listed twice) returns the summed
    count.

    NOTE(review): the method name refers to `load_glove`, but the body only
    exercises `count_file_lines`; the name is kept unchanged so existing
    test selectors keep working -- consider renaming.
    """
    # Context managers close (and delete) the temp files even if an
    # assertion fails; the original left both handles open.
    with tempfile.NamedTemporaryFile(mode="w+") as file_1, \
            tempfile.NamedTemporaryFile(mode="w+") as file_2:
        num_lines = data_utils.count_file_lines(file_1.name)
        self.assertEqual(num_lines, 0)

        # Five 'x' lines joined with '\n' (no trailing newline).
        file_2.write('\n'.join(['x'] * 5))
        # Flush so the content is on disk before the file is read by name.
        file_2.flush()

        num_lines = data_utils.count_file_lines(
            [file_1.name, file_2.name, file_2.name])
        self.assertEqual(num_lines, 0 + 5 + 5)
def dataset_size(self):
    """Returns the number of data instances in the data files.

    Note that this is the total data count in the raw files, before any
    filtering and truncation.

    The count is computed with `count_file_lines` over
    `self._hparams.dataset.files` on first use and cached in
    `self._dataset_size`.
    """
    # Compare against None rather than truthiness so that a legitimately
    # empty dataset (count 0) is cached instead of re-counted on every call.
    # getattr tolerates the attribute not having been assigned yet.
    if getattr(self, "_dataset_size", None) is None:
        # pylint: disable=attribute-defined-outside-init
        self._dataset_size = count_file_lines(self._hparams.dataset.files)
    return self._dataset_size