Example #1
    def test_make_chained_transformation(self):
        """Tests :func:`texar.tf.data.make_chained_transformation`
        """
        original_data = np.arange(0, 10)
        dataset = tf.data.Dataset.from_tensor_slices(original_data)

        def _tran_a(data):
            return data + 100

        def _tran_b(data):
            return data + 1000

        def _tran_c(data):
            return data + 10000

        chained_tran = dsutils.make_chained_transformation(
            [_tran_a, _tran_b, _tran_c])
        dataset = dataset.map(chained_tran)

        iterator = dataset.make_one_shot_iterator()
        elem = iterator.get_next()
        with self.test_session() as sess:
            data_ = []
            while True:
                try:
                    data_.append(sess.run(elem))
                except tf.errors.OutOfRangeError:
                    break
            self.assertEqual(len(data_), len(original_data))
            # The three chained transformations add 100 + 1000 + 10000 = 11100
            # to every element, so subtracting 11100 recovers the original data.
            data_ = [elem_ - 11100 for elem_ in data_]
            self.assertEqual(data_, original_data.tolist())
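
For intuition, make_chained_transformation composes its argument callables so that the output of each feeds the next. Below is a minimal plain-Python sketch of that composition idea, not necessarily Texar's exact implementation:

    def make_chained_transformation_sketch(tran_fns):
        # Compose the callables left to right: each transformation's
        # output becomes the next one's input.
        def _chained(data):
            for fn in tran_fns:
                data = fn(data)
            return data
        return _chained

    # With the three transformations above, every element gains
    # 100 + 1000 + 10000 = 11100, which is why the test subtracts 11100.
    chained = make_chained_transformation_sketch(
        [lambda x: x + 100, lambda x: x + 1000, lambda x: x + 10000])
    assert chained(0) == 11100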
Example #2
    def _make_processor(dataset_hparams,
                        data_spec,
                        chained=True,
                        name_prefix=None):
        # Create data decoder
        max_seq_length = None
        if dataset_hparams["length_filter_mode"] == "truncate":
            max_seq_length = dataset_hparams["max_seq_length"]

        if not dataset_hparams["variable_utterance"]:
            decoder = TextDataDecoder(
                delimiter=dataset_hparams["delimiter"],
                bos_token=dataset_hparams["bos_token"],
                eos_token=dataset_hparams["eos_token"],
                max_seq_length=max_seq_length,
                token_to_id_map=data_spec.vocab.token_to_id_map)
        else:
            decoder = VarUttTextDataDecoder(
                sentence_delimiter=dataset_hparams["utterance_delimiter"],
                delimiter=dataset_hparams["delimiter"],
                bos_token=dataset_hparams["bos_token"],
                eos_token=dataset_hparams["eos_token"],
                max_seq_length=max_seq_length,
                max_utterance_cnt=dataset_hparams["max_utterance_cnt"],
                token_to_id_map=data_spec.vocab.token_to_id_map)

        # Create other transformations
        data_spec.add_spec(decoder=decoder)
        other_trans = MonoTextData._make_other_transformations(
            dataset_hparams["other_transformations"], data_spec)
        if name_prefix:
            other_trans.append(dsutils.name_prefix_fn(name_prefix))

        data_spec.add_spec(name_prefix=name_prefix)

        if chained:
            chained_tran = dsutils.make_chained_transformation(
                [decoder] + other_trans)
            return chained_tran, data_spec
        else:
            return decoder, other_trans, data_spec
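
For orientation, these are the hparams keys this processor actually reads, collected into a hypothetical minimal configuration. The values are placeholders, not Texar defaults:

    # Hypothetical minimal hparams; only the keys read by
    # _make_processor above are shown, with placeholder values.
    dataset_hparams = {
        "length_filter_mode": "truncate",  # any other value leaves max_seq_length unset
        "max_seq_length": 50,
        "variable_utterance": False,  # True selects VarUttTextDataDecoder and also
                                      # reads "utterance_delimiter"/"max_utterance_cnt"
        "delimiter": " ",
        "bos_token": "<BOS>",
        "eos_token": "<EOS>",
        "other_transformations": [],
    }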
Example #3
    def _make_processor(dataset_hparams, data_spec, chained=True,
                        name_prefix=None):
        # Create data decoder
        decoder = TFRecordDataNumpyDecoder(
            feature_original_types=dataset_hparams.feature_original_types,
            feature_convert_types=dataset_hparams.feature_convert_types,
            image_options=dataset_hparams.image_options,
            numpy_options=dataset_hparams.numpy_options)
        # Create other transformations
        data_spec.add_spec(decoder=decoder)
        # pylint: disable=protected-access
        other_trans = MonoTextData._make_other_transformations(
            dataset_hparams["other_transformations"], data_spec)

        data_spec.add_spec(name_prefix=name_prefix)

        if chained:
            chained_tran = dsutils.make_chained_transformation(
                [decoder] + other_trans)
            return chained_tran, data_spec
        else:
            return decoder, other_trans, data_spec
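
When chained=False, the decoder and the extra transformations come back separately and can be applied as individual Dataset.map steps. A hedged usage sketch follows; the dataset, dataset_hparams, and data_spec objects are assumed to exist in the caller:

    # Hypothetical caller-side use of the non-chained variant.
    decoder, other_trans, data_spec = _make_processor(
        dataset_hparams, data_spec, chained=False)
    dataset = dataset.map(decoder)  # decode raw TFRecord examples first
    for tran in other_trans:        # then apply each remaining transformation
        dataset = dataset.map(tran)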
Example #4
    def _make_processor(dataset_hparams,
                        data_spec,
                        chained=True,
                        name_prefix=None):
        # Create data decoder
        decoder = ScalarDataDecoder(
            ScalarData._get_dtype(dataset_hparams["data_type"]),
            data_name=name_prefix)
        # Create other transformations
        data_spec.add_spec(decoder=decoder)
        # pylint: disable=protected-access
        other_trans = MonoTextData._make_other_transformations(
            dataset_hparams["other_transformations"], data_spec)

        data_spec.add_spec(name_prefix=name_prefix)

        if chained:
            chained_tran = dsutils.make_chained_transformation(
                [decoder] + other_trans)
            return chained_tran, data_spec
        else:
            return decoder, other_trans, data_spec
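
As with the text variant, the scalar processor consults only two hparams keys. A hypothetical minimal configuration, with keys taken from the code above and placeholder values:

    # Hypothetical minimal hparams for the scalar processor.
    dataset_hparams = {
        "data_type": "int",  # mapped to a TF dtype via ScalarData._get_dtype
        "other_transformations": [],
    }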