예제 #1
0
    def _shuffle_dataset(dataset, hparams, dataset_files):
        """Optionally shards and/or shuffles `dataset` according to `hparams`.

        Two mutually exclusive modes are checked, in this order:

        - ``hparams["shard_and_shuffle"]``: ``dsutils.random_shard_dataset``
          is applied (given the dataset size, ``shuffle_buffer_size`` and the
          seed), then the result is shuffled with a buffer of
          ``shuffle_buffer_size + 16``. This mode requires
          ``shuffle_buffer_size`` to be set and strictly smaller than the
          line count of `dataset_files`.
        - ``hparams["shuffle"]``: the dataset is shuffled with a buffer of
          ``shuffle_buffer_size``; if that is ``None``, the buffer size
          defaults to the total line count of `dataset_files`.

        If neither flag is set, `dataset` is returned unchanged.

        Args:
            dataset: The dataset to transform. Must provide ``apply`` and
                ``shuffle`` methods (presumably a ``tf.data.Dataset`` --
                confirm against callers).
            hparams: Mapping providing the keys ``"shuffle_buffer_size"``,
                ``"shard_and_shuffle"``, ``"shuffle"`` and ``"seed"``.
            dataset_files: File path(s) handed to ``count_file_lines`` when
                the dataset size is needed.

        Returns:
            A tuple ``(dataset, dataset_size)``. ``dataset_size`` is the line
            count of `dataset_files` when it had to be computed, else ``None``.

        Raises:
            ValueError: If ``shard_and_shuffle`` is enabled and
                ``shuffle_buffer_size`` is ``None`` or is not smaller than
                the dataset size.
        """
        dataset_size = None
        shuffle_buffer_size = hparams["shuffle_buffer_size"]
        if hparams["shard_and_shuffle"]:
            if shuffle_buffer_size is None:
                raise ValueError(
                    "Dataset hyperparameter 'shuffle_buffer_size' "
                    "must not be `None` if 'shard_and_shuffle'=`True`.")
            dataset_size = count_file_lines(dataset_files)
            if shuffle_buffer_size >= dataset_size:
                # The message must name the real hyperparameter key,
                # 'shard_and_shuffle', so users can act on it.
                raise ValueError(
                    "Dataset size (%d) <= shuffle_buffer_size (%d). Set "
                    "shard_and_shuffle to `False`." %
                    (dataset_size, shuffle_buffer_size))
            # TODO(zhiting): Use a different seed?
            dataset = dataset.apply(
                dsutils.random_shard_dataset(dataset_size, shuffle_buffer_size,
                                             hparams["seed"]))
            dataset = dataset.shuffle(
                shuffle_buffer_size + 16,  # add a margin
                seed=hparams["seed"])
        elif hparams["shuffle"]:
            if shuffle_buffer_size is None:
                # No explicit buffer size: use one as large as the total
                # line count of the data files.
                dataset_size = count_file_lines(dataset_files)
                shuffle_buffer_size = dataset_size
            dataset = dataset.shuffle(shuffle_buffer_size,
                                      seed=hparams["seed"])

        return dataset, dataset_size
    def test_load_glove(self):
        """Tests the `count_file_lines` function.

        Note: despite this method's name, it exercises
        ``data_utils.count_file_lines`` (not a GloVe loader). It checks an
        empty file and a list of paths that includes a repeated file.
        """
        # Context managers guarantee the temporary files are closed (and
        # deleted) even if an assertion fails.
        with tempfile.NamedTemporaryFile(mode="w+") as file_1:
            # An empty file contains zero lines.
            num_lines = data_utils.count_file_lines(file_1.name)
            self.assertEqual(num_lines, 0)

            with tempfile.NamedTemporaryFile(mode="w+") as file_2:
                # Five 'x' lines joined by '\n'. The last line has no trailing
                # newline and is still expected to be counted.
                file_2.write('\n'.join(['x'] * 5))
                # Flush so the content is visible when the file is read by name.
                file_2.flush()
                # A list of paths is accepted and per-file counts are summed;
                # a repeated path is counted again.
                num_lines = data_utils.count_file_lines(
                    [file_1.name, file_2.name, file_2.name])
                self.assertEqual(num_lines, 0 + 5 + 5)
예제 #3
0
    def dataset_size(self):
        """Returns the number of data instances in the data files.

        Note that this is the total data count in the raw files, before any
        filtering and truncation.

        The value is computed lazily with ``count_file_lines`` over
        ``self._hparams.dataset.files`` and cached in ``self._dataset_size``.
        A cached size of ``0`` (empty data files) is a valid result and is
        not recomputed.

        Returns:
            The cached ``self._dataset_size``, computed first if unset.
        """
        # Compare against None instead of testing truthiness so that an
        # empty dataset (size 0) is cached rather than re-counted on every
        # call. `getattr` also covers an instance where `_dataset_size` was
        # never assigned.
        if getattr(self, "_dataset_size", None) is None:
            # pylint: disable=attribute-defined-outside-init
            self._dataset_size = count_file_lines(self._hparams.dataset.files)
        return self._dataset_size