Example #1
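This variant builds the full monolingual text pipeline: it creates the vocabulary and embedding, then shuffles, processes, batches (with optional length bucketing via padded shapes), and prefetches the dataset. It appears to come from Texar's MonoTextData class.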
    def _make_data(self):
        dataset_hparams = self._hparams.dataset

        # Create vocab and embedding
        self._vocab = self.make_vocab(dataset_hparams)
        self._embedding = self.make_embedding(
            dataset_hparams["embedding_init"], self._vocab.token_to_id_map_py)

        # Create and shuffle dataset
        dataset = self._make_mono_text_dataset(dataset_hparams)
        dataset, dataset_size = self._shuffle_dataset(
            dataset, self._hparams, self._hparams.dataset.files)
        self._dataset_size = dataset_size

        # Processing
        data_spec = dsutils._DataSpec(dataset=dataset,
                                      dataset_size=self._dataset_size,
                                      vocab=self._vocab,
                                      embedding=self._embedding)
        dataset, data_spec = self._process_dataset(dataset, self._hparams,
                                                   data_spec)
        self._data_spec = data_spec
        self._decoder = data_spec.decoder

        # Batching
        length_fn = self._make_bucket_length_fn()
        padded_shapes = self._make_padded_shapes(dataset, self._decoder)
        dataset = self._make_batch(dataset, self._hparams, length_fn,
                                   padded_shapes)

        # Prefetching
        if self._hparams.prefetch_buffer_size > 0:
            dataset = dataset.prefetch(self._hparams.prefetch_buffer_size)

        self._dataset = dataset
Example #2
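This variant skips vocabulary and embedding creation and reuses MonoTextData._make_mono_text_dataset to read the raw records, so the _DataSpec carries only the dataset and its size, and batching needs no bucketing or padded shapes. It is presumably a scalar/label data class such as Texar's ScalarData.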
    def _make_data(self):
        dataset_hparams = self._hparams.dataset

        # Create and shuffle dataset
        dataset = MonoTextData._make_mono_text_dataset(dataset_hparams)
        dataset, dataset_size = self._shuffle_dataset(
            dataset, self._hparams, self._hparams.dataset.files)
        self._dataset_size = dataset_size

        # Processing
        # pylint: disable=protected-access
        data_spec = dsutils._DataSpec(dataset=dataset,
                                      dataset_size=self._dataset_size)
        dataset, data_spec = self._process_dataset(dataset, self._hparams,
                                                   data_spec)
        self._data_spec = data_spec
        self._decoder = data_spec.decoder  # pylint: disable=no-member

        # Batching
        dataset = self._make_batch(dataset, self._hparams)

        # Prefetching
        if self._hparams.prefetch_buffer_size > 0:
            dataset = dataset.prefetch(self._hparams.prefetch_buffer_size)

        self._dataset = dataset
Example #3
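This variant reads records with _read_TFRecord_data and, when both num_shards and shard_id are set, calls dataset.shard() so each worker consumes a disjoint slice of the data; the rest of the pipeline (shuffle, process, batch, prefetch) is unchanged. It is presumably Texar's TFRecordData.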
    def _make_data(self):
        dataset = self._read_TFRecord_data()
        # Create and shuffle dataset
        num_shards = self._hparams.dataset.num_shards
        shard_id = self._hparams.dataset.shard_id
        if num_shards is not None and shard_id is not None:
            dataset = dataset.shard(num_shards, shard_id)
        dataset, dataset_size = self._shuffle_dataset(
            dataset, self._hparams, self._hparams.dataset.files)
        self._dataset_size = dataset_size

        # Processing
        # pylint: disable=protected-access
        data_spec = dsutils._DataSpec(dataset=dataset,
                                      dataset_size=self._dataset_size)
        dataset, data_spec = self._process_dataset(dataset, self._hparams,
                                                   data_spec)
        self._data_spec = data_spec
        self._decoder = data_spec.decoder  # pylint: disable=no-member

        # Batching
        dataset = self._make_batch(dataset, self._hparams)

        # Prefetching
        if self._hparams.prefetch_buffer_size > 0:
            dataset = dataset.prefetch(self._hparams.prefetch_buffer_size)

        self._dataset = dataset
Example #4
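This variant handles several datasets at once: make_vocab and make_embedding receive the whole self._hparams.datasets list, and shuffling is keyed off the files of the first dataset. It is presumably Texar's MultiAlignedData.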
    def _make_data(self):
        self._vocab = self.make_vocab(self._hparams.datasets)
        self._embedding = self.make_embedding(self._hparams.datasets,
                                              self._vocab)

        # Create dataset
        dataset = self._make_dataset()
        dataset, dataset_size = self._shuffle_dataset(
            dataset, self._hparams, self._hparams.datasets[0].files)
        self._dataset_size = dataset_size

        # Processing
        data_spec = dsutils._DataSpec(dataset=dataset,
                                      dataset_size=self._dataset_size,
                                      vocab=self._vocab,
                                      embedding=self._embedding)
        dataset, data_spec = self._process_dataset(dataset, self._hparams,
                                                   data_spec)
        self._data_spec = data_spec
        self._decoder = data_spec.decoder

        # Batching
        length_fn = self._make_bucket_length_fn()
        padded_shapes = self._make_padded_shapes(dataset, self._decoder)
        dataset = self._make_batch(dataset, self._hparams, length_fn,
                                   padded_shapes)

        # Prefetching
        if self._hparams.prefetch_buffer_size > 0:
            dataset = dataset.prefetch(self._hparams.prefetch_buffer_size)

        self._dataset = dataset
Example #5
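This variant builds separate source and target vocabularies and embeddings, rejects the inconsistent combination vocab_share=False with embedding_init_share=True, and keeps one decoder per side so padded shapes can be computed for both. It is presumably Texar's PairedTextData.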
    def _make_data(self):
        self._src_vocab, self._tgt_vocab = self.make_vocab(
            self._hparams.source_dataset, self._hparams.target_dataset)

        tgt_hparams = self._hparams.target_dataset
        if not tgt_hparams.vocab_share and tgt_hparams.embedding_init_share:
            raise ValueError("embedding_init can be shared only when vocab "
                             "is shared. Got `vocab_share=False, "
                             "emb_init_share=True`.")
        self._src_embedding, self._tgt_embedding = self.make_embedding(
            self._hparams.source_dataset.embedding_init,
            self._src_vocab.token_to_id_map_py,
            self._hparams.target_dataset.embedding_init,
            self._tgt_vocab.token_to_id_map_py,
            self._hparams.target_dataset.embedding_init_share)

        # Create dataset
        dataset = self._make_dataset()
        dataset, dataset_size = self._shuffle_dataset(
            dataset, self._hparams, self._hparams.source_dataset.files)
        self._dataset_size = dataset_size

        # Processing
        data_spec = dsutils._DataSpec(
            dataset=dataset,
            dataset_size=self._dataset_size,
            vocab=[self._src_vocab, self._tgt_vocab],
            embedding=[self._src_embedding, self._tgt_embedding])
        dataset, data_spec = self._process_dataset(dataset, self._hparams,
                                                   data_spec)
        self._data_spec = data_spec
        self._decoder = data_spec.decoder
        self._src_decoder = data_spec.decoder[0]
        self._tgt_decoder = data_spec.decoder[1]

        # Batching
        length_fn = self._make_bucket_length_fn()
        padded_shapes = self._make_padded_shapes(dataset, self._src_decoder,
                                                 self._tgt_decoder)
        dataset = self._make_batch(dataset, self._hparams, length_fn,
                                   padded_shapes)

        # Prefetching
        if self._hparams.prefetch_buffer_size > 0:
            dataset = dataset.prefetch(self._hparams.prefetch_buffer_size)

        self._dataset = dataset
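
All five variants share one tf.data skeleton: create the dataset, shuffle it, map per-example processing over it, batch, and prefetch. Below is a minimal self-contained sketch of that shape using the public tf.data API; the names make_pipeline and process_fn and the default buffer sizes are illustrative assumptions, not Texar's API.

import tensorflow as tf

def make_pipeline(files, process_fn, batch_size,
                  shuffle_buffer=10000, prefetch_buffer=1):
    # Create: one element per line of the input file(s).
    dataset = tf.data.TextLineDataset(files)
    # Shuffle before the more expensive processing steps.
    dataset = dataset.shuffle(shuffle_buffer)
    # Processing: tokenization, decoding, etc.
    dataset = dataset.map(process_fn)
    # Batching: pad variable-length examples to a common length per batch.
    dataset = dataset.padded_batch(batch_size, padded_shapes=[None])
    # Prefetching: overlap preprocessing with consumption, guarded as above.
    if prefetch_buffer > 0:
        dataset = dataset.prefetch(prefetch_buffer)
    return dataset

# Usage: split each line into string tokens (token-to-id lookup omitted).
dataset = make_pipeline("data.txt", tf.strings.split, batch_size=32)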