Code example #1
    def _run_and_test(self, hparams):
        # Construct database
        record_data = RecordData(hparams)
        iterator = DataIterator(record_data)

        def _prod(lst):
            res = 1
            for i in lst:
                res *= i
            return res

        for idx, data_batch in enumerate(iterator):
            self.assertEqual(set(data_batch.keys()),
                             set(record_data.list_items()))

            # Check data consistency
            for key in self._unconvert_features:
                value = data_batch[key][0]
                self.assertEqual(value, self._dataset_valid[key][idx])
            self.assertEqual(list(data_batch['shape'][0]),
                             list(self._dataset_valid['shape'][idx]))

            # Check data type conversion
            for key, item in self._feature_convert_types.items():
                dtype = get_numpy_dtype(item)
                value = data_batch[key][0]
                if dtype is np.str_:
                    self.assertIsInstance(value, str)
                elif dtype is np.bytes_:
                    self.assertIsInstance(value, bytes)
                else:
                    if isinstance(value, torch.Tensor):
                        value_dtype = get_numpy_dtype(value.dtype)
                    else:
                        value_dtype = value.dtype
                    dtype_matched = np.issubdtype(value_dtype, dtype)
                    self.assertTrue(dtype_matched)

            # Check image decoding and resize
            if hparams["dataset"].get("image_options"):
                image_options = hparams["dataset"].get("image_options")
                if isinstance(image_options, dict):
                    image_options = [image_options]
                for image_option_feature in image_options:
                    image_key = image_option_feature.get("image_feature_name")
                    if image_key is None:
                        continue
                    image_gen = data_batch[image_key][0]
                    image_valid_shape = self._dataset_valid["shape"][idx]
                    resize_height = image_option_feature.get("resize_height")
                    resize_width = image_option_feature.get("resize_width")
                    if resize_height and resize_width:
                        self.assertEqual(
                            image_gen.shape[0] * image_gen.shape[1],
                            resize_height * resize_width)
                    else:
                        self.assertEqual(_prod(image_gen.shape),
                                         _prod(image_valid_shape))
Code example #2
    def setUp(self):
        # Create test data
        vocab_list = ['This', 'is', 'a', 'word', '词']
        vocab_file = tempfile.NamedTemporaryFile()
        vocab_file.write('\n'.join(vocab_list).encode("utf-8"))
        vocab_file.flush()
        self._vocab_file = vocab_file
        self._vocab_size = len(vocab_list)

        text_0 = ['This is a sentence from source .', '词 词 。 source']
        text_0_file = tempfile.NamedTemporaryFile()
        text_0_file.write('\n'.join(text_0).encode("utf-8"))
        text_0_file.flush()
        self._text_0_file = text_0_file

        text_1 = ['This is a sentence from target .', '词 词 。 target']
        text_1_file = tempfile.NamedTemporaryFile()
        text_1_file.write('\n'.join(text_1).encode("utf-8"))
        text_1_file.flush()
        self._text_1_file = text_1_file

        text_2 = [
            'This is a sentence from dialog . ||| dialog ',
            '词 词 。 ||| 词 dialog'
        ]
        text_2_file = tempfile.NamedTemporaryFile()
        text_2_file.write('\n'.join(text_2).encode("utf-8"))
        text_2_file.flush()
        self._text_2_file = text_2_file

        int_3 = [0, 1]
        int_3_file = tempfile.NamedTemporaryFile()
        int_3_file.write(('\n'.join([str(_) for _ in int_3])).encode("utf-8"))
        int_3_file.flush()
        self._int_3_file = int_3_file

        self._tfrecord_filepath = os.path.join(tempfile.mkdtemp(),
                                               'test.tfrecord')
        self._feature_original_types = {
            'number1': ('tf.int64', 'FixedLenFeature'),
            'number2': ('tf.int64', 'FixedLenFeature'),
            'text': ('tf.string', 'FixedLenFeature')
        }

        features = [{
            "number1": 128,
            "number2": 512,
            "text": "This is a sentence for TFRecord 词 词 。"
        }, {
            "number1": 128,
            "number2": 512,
            "text": "This is a another sentence for TFRecord 词 词 。"
        }]
        # Prepare Validation data
        with RecordData.writer(self._tfrecord_filepath,
                               self._feature_original_types) as writer:
            for feature in features:
                writer.write(feature)

        # Construct database
        self._hparams = {
            "num_epochs":
            1,
            "batch_size":
            1,
            "datasets": [
                {  # dataset 0
                    "files": [self._text_0_file.name],
                    "vocab_file": self._vocab_file.name,
                    "bos_token": "",
                    "data_name": "0"
                },
                {  # dataset 1
                    "files": [self._text_1_file.name],
                    "vocab_share_with": 0,
                    "eos_token": "<TARGET_EOS>",
                    "data_name": "1"
                },
                {  # dataset 2
                    "files": [self._text_2_file.name],
                    "vocab_file": self._vocab_file.name,
                    "processing_share_with": 0,
                    # TODO(avinash) - Add it back once feature is added
                    "variable_utterance": False,
                    "data_name": "2"
                },
                {  # dataset 3
                    "files": self._int_3_file.name,
                    "data_type": "int",
                    "data_name": "label"
                },
                {  # dataset 4
                    "files": self._tfrecord_filepath,
                    "feature_original_types": self._feature_original_types,
                    "feature_convert_types": {
                        'number2': 'tf.float32',
                    },
                    "num_shards": 2,
                    "shard_id": 1,
                    "data_type": "record",
                    "data_name": "4"
                }
            ]
        }
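
These hyperparameters would typically be consumed through `MultiAlignedData` (see code example #4) together with a `DataIterator`. The following is a hypothetical companion test, sketched only for illustration; the batch keys assume the `connect_name` prefixing used in example #4 (e.g. "0_text", "1_text_ids", "label").

    def test_iterate(self):
        # Hypothetical sketch, not part of the original suite.
        data = MultiAlignedData(self._hparams)
        iterator = DataIterator(data)
        for batch in iterator:
            # Each text field is prefixed with its dataset's "data_name" via
            # `connect_name` (e.g. "0_text", "1_text_ids"), while the scalar
            # dataset surfaces directly under its data_name, "label".
            self.assertIn("0_text", batch.keys())
            self.assertIn("1_text_ids", batch.keys())
            self.assertIn("label", batch.keys())
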
Code example #3
    def setUp(self):
        # Create test data
        self._test_dir = tempfile.mkdtemp()

        cat_in_snow = maybe_download(
            'https://storage.googleapis.com/download.tensorflow.org/'
            'example_images/320px-Felis_catus-cat_on_snow.jpg', self._test_dir,
            'cat_0.jpg')
        williamsburg_bridge = maybe_download(
            'https://storage.googleapis.com/download.tensorflow.org/'
            'example_images/194px-New_East_River_Bridge_from_Brooklyn_'
            'det.4a09796u.jpg', self._test_dir, 'bridge_0.jpg')

        _feature_types = {
            'height': ('tf.int64', 'FixedLenFeature', 1),
            'width': ('tf.int64', 'FixedLenFeature', 1),
            'label': ('tf.int64', 'stacked_tensor', 1),
            'shape': (np.int64, 'VarLenFeature'),
            'image_raw': (bytes, 'stacked_tensor'),
            'variable1': (np.str_, 'FixedLenFeature'),
            'variable2': ('tf.int64', 'FixedLenFeature'),
        }
        self._feature_convert_types = {
            'variable1': 'tf.float32',
            'variable2': 'tf.string',
        }
        _image_options = {}
        self._unconvert_features = ['height', 'width', 'label']

        self._dataset_valid = {
            'height': [],
            'width': [],
            'shape': [],
            'label': [],
            'image_raw': [],
            'variable1': [],
            'variable2': [],
        }
        _toy_image_labels_valid = {
            cat_in_snow: 0,
            williamsburg_bridge: 1,
        }
        _toy_image_shapes = {
            cat_in_snow: (213, 320, 3),
            williamsburg_bridge: (239, 194),
        }
        _record_filepath = os.path.join(self._test_dir, 'test.pkl')

        # Prepare Validation data
        with RecordData.writer(_record_filepath, _feature_types) as writer:
            for image_path, label in _toy_image_labels_valid.items():
                with open(image_path, 'rb') as fid:
                    image_data = fid.read()
                image_shape = _toy_image_shapes[image_path]

                # _construct_dataset_valid("", shape, label)
                single_data = {
                    'height': image_shape[0],
                    'width': image_shape[1],
                    'shape': image_shape,
                    'label': label,
                    'image_raw': image_data,
                    'variable1': "1234567890",
                    'variable2': int(9876543210),
                }
                for key, value in single_data.items():
                    self._dataset_valid[key].append(value)
                writer.write(single_data)

        self._hparams = {
            "num_epochs": 1,
            "batch_size": 1,
            "shuffle": False,
            "dataset": {
                "files": _record_filepath,
                "feature_original_types": _feature_types,
                "feature_convert_types": self._feature_convert_types,
                "image_options": [_image_options],
            }
        }
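
A hypothetical companion test (sketch only, not from the original suite) showing how the records written in `setUp` could be read back through `RecordData` and `DataIterator`, mirroring the checks in code example #1.

    def test_load_record_data(self):
        # Hypothetical sketch: read the pickled records written in setUp.
        record_data = RecordData(self._hparams)
        iterator = DataIterator(record_data)
        for idx, batch in enumerate(iterator):
            self.assertEqual(set(batch.keys()),
                             set(record_data.list_items()))
            self.assertEqual(batch["label"][0],
                             self._dataset_valid["label"][idx])
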
Code example #4
    def __init__(self, hparams, device: Optional[torch.device] = None):
        self._hparams = HParams(hparams, self.default_hparams())
        # Defaultizes hyperparameters of each dataset
        datasets_hparams = self._hparams.datasets
        defaultized_datasets_hparams = []
        for hparams_i in datasets_hparams:
            data_type = hparams_i.get("data_type", None)
            defaultized_ds_hpms = HParams(hparams_i,
                                          _default_dataset_hparams(data_type))
            defaultized_datasets_hparams.append(defaultized_ds_hpms)
        self._hparams.datasets = defaultized_datasets_hparams

        self._vocab = self.make_vocab(self._hparams.datasets)
        self._embedding = self.make_embedding(self._hparams.datasets,
                                              self._vocab)

        dummy_source = SequenceDataSource[Any]([])
        name_prefix: List[str] = []
        self._names: List[Dict[str, Any]] = []
        sources: List[DataSource] = []
        filters: List[Optional[Callable[[str], bool]]] = []
        self._databases: List[DataBase] = []
        for idx, hparams_i in enumerate(self._hparams.datasets):
            data_type = _DataType(hparams_i.data_type)
            source_i: DataSource

            if _is_text_data(data_type):
                source_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type,
                    delimiter=hparams_i.delimiter)
                sources.append(source_i)
                if ((hparams_i.length_filter_mode
                     == _LengthFilterMode.DISCARD.value)
                        and hparams_i.max_seq_length is not None):

                    def _get_filter(max_seq_length):
                        return lambda x: len(x) <= max_seq_length

                    filters.append(_get_filter(hparams_i.max_seq_length))
                else:
                    filters.append(None)

                self._names.append({
                    field: connect_name(hparams_i.data_name, field)
                    for field in ["text", "text_ids", "length"]
                })

                dataset_hparams = dict_fetch(
                    hparams_i,
                    MonoTextData.default_hparams()["dataset"])
                dataset_hparams["data_name"] = None
                self._databases.append(
                    MonoTextData(hparams={"dataset": dataset_hparams},
                                 device=device,
                                 vocab=self._vocab[idx],
                                 embedding=self._embedding[idx],
                                 data_source=dummy_source))
            elif _is_scalar_data(data_type):
                source_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type)
                sources.append(source_i)
                filters.append(None)
                self._names.append({"data": hparams_i.data_name})

                dataset_hparams = dict_fetch(
                    hparams_i,
                    ScalarData.default_hparams()["dataset"])
                dataset_hparams["data_name"] = "data"
                self._databases.append(
                    ScalarData(hparams={"dataset": dataset_hparams},
                               device=device,
                               data_source=dummy_source))
            elif _is_record_data(data_type):
                source_i = PickleDataSource(file_paths=hparams_i.files)
                sources.append(source_i)
                self._names.append({
                    name: connect_name(hparams_i.data_name, name)
                    for name in hparams_i.feature_original_types.keys()
                })
                filters.append(None)

                dataset_hparams = dict_fetch(
                    hparams_i,
                    RecordData.default_hparams()["dataset"])
                self._databases.append(
                    RecordData(hparams={"dataset": dataset_hparams},
                               device=device,
                               data_source=dummy_source))
            else:
                raise ValueError(f"Unknown data type: {hparams_i.data_type}")

            # check for duplicate names
            for i in range(1, len(name_prefix)):
                if name_prefix[i] in name_prefix[:i]:
                    raise ValueError(f"Duplicate data name: {name_prefix[i]}")

            name_prefix.append(hparams_i["data_name"])

        self._name_to_id = {v: k for k, v in enumerate(name_prefix)}

        data_source: DataSource = ZipDataSource(*sources)

        if any(filters):

            def filter_fn(data):
                return all(
                    fn(data) for fn, data in zip(filters, data)
                    if fn is not None)

            data_source = FilterDataSource(data_source, filter_fn=filter_fn)
        super().__init__(data_source, self._hparams, device)
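
The second `__init__` below appears to be a locally modified variant of the same constructor, instrumented with a debug print and hardcoded data-file bookkeeping. Both versions build their length filters through the `_get_filter` factory so that each lambda captures its own `max_seq_length` instead of the loop variable's final value. A standalone sketch of that pattern (hypothetical names, not library code):

# Sketch of the closure-based filter pattern used in both constructors.
filters = []
for max_len in [5, 10]:
    def _get_filter(max_seq_length):
        return lambda tokens: len(tokens) <= max_seq_length
    filters.append(_get_filter(max_len))

def filter_fn(example):
    # `example` is one zipped item: a tuple with one element per data source.
    return all(fn(part) for fn, part in zip(filters, example) if fn is not None)

print(filter_fn((["a"] * 4, ["b"] * 3)))   # True: both parts within limits
print(filter_fn((["a"] * 4, ["b"] * 12)))  # False: second part exceeds 10
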
    def __init__(self, hparams, device: Optional[torch.device] = None):
        print("Using local texar")
        self._hparams = HParams(hparams, self.default_hparams())
        # Defaultizes hyperparameters of each dataset
        datasets_hparams = self._hparams.datasets
        defaultized_datasets_hparams = []
        for hparams_i in datasets_hparams:
            data_type = hparams_i.get("data_type", None)
            #print("data_type:", data_type)
            defaultized_ds_hpms = HParams(hparams_i,
                                          _default_dataset_hparams(data_type))
            defaultized_datasets_hparams.append(defaultized_ds_hpms)
        self._hparams.datasets = defaultized_datasets_hparams

        #print("will make_vocab")
        self._vocab = self.make_vocab(self._hparams.datasets)
        #print("will make_embedding")
        self._embedding = self.make_embedding(self._hparams.datasets,
                                              self._vocab)

        dummy_source = SequenceDataSource[Any]([])
        name_prefix: List[str] = []
        self._names: List[Dict[str, Any]] = []
        sources: List[DataSource] = []
        filters: List[Optional[Callable[[str], bool]]] = []
        self._databases: List[DatasetBase] = []
        for idx, hparams_i in enumerate(self._hparams.datasets):
            data_type = hparams_i.data_type
            source_i: DataSource

            if _is_text_data(data_type):
                #print("will TextLineDataSource")
                source_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type,
                    delimiter=hparams_i.delimiter)
                sources.append(source_i)
                if ((hparams_i.length_filter_mode
                     == _LengthFilterMode.DISCARD.value)
                        and hparams_i.max_seq_length is not None):

                    def _get_filter(max_seq_length):
                        return lambda x: len(x) <= max_seq_length

                    filters.append(_get_filter(hparams_i.max_seq_length))
                else:
                    filters.append(None)

                self._names.append({
                    field: connect_name(hparams_i.data_name, field)
                    for field in ["text", "text_ids", "length"]
                })

                dataset_hparams = dict_fetch(
                    hparams_i,
                    MonoTextData.default_hparams()["dataset"])
                dataset_hparams["data_name"] = None
                self._databases.append(
                    MonoTextData(hparams={"dataset": dataset_hparams},
                                 device=device,
                                 vocab=self._vocab[idx],
                                 embedding=self._embedding[idx],
                                 data_source=dummy_source))
            elif _is_scalar_data(data_type):
                source_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type)
                sources.append(source_i)
                filters.append(None)
                self._names.append({"data": hparams_i.data_name})

                dataset_hparams = dict_fetch(
                    hparams_i,
                    ScalarData.default_hparams()["dataset"])
                dataset_hparams["data_name"] = "data"
                self._databases.append(
                    ScalarData(hparams={"dataset": dataset_hparams},
                               device=device,
                               data_source=dummy_source))
            elif _is_record_data(data_type):
                source_i = PickleDataSource(file_paths=hparams_i.files)
                sources.append(source_i)
                # TODO: Only check `feature_types` when we finally remove
                #   `feature_original_types`.
                feature_types = (hparams_i.feature_types
                                 or hparams_i.feature_original_types)
                self._names.append({
                    name: connect_name(hparams_i.data_name, name)
                    for name in feature_types.keys()
                })
                filters.append(None)

                dataset_hparams = dict_fetch(
                    hparams_i,
                    RecordData.default_hparams()["dataset"])
                self._databases.append(
                    RecordData(hparams={"dataset": dataset_hparams},
                               device=device,
                               data_source=dummy_source))
            else:
                raise ValueError(f"Unknown data type: {hparams_i.data_type}")

            # check for duplicate names
            for i in range(1, len(name_prefix)):
                if name_prefix[i] in name_prefix[:i]:
                    raise ValueError(f"Duplicate data name: {name_prefix[i]}")

            name_prefix.append(hparams_i["data_name"])

        self._name_to_id = {v: k for k, v in enumerate(name_prefix)}
        self._processed_cache = []
        self._datafile_id = 0  # for training from multiple files
        self._index_at_beginning_of_this_dataset = 0
        self._datafile_prefix = hparams_i.files
        # Hardcoded for now; ideally this would come from `hparams_i.datafile_num`.
        self._datafile_num = 1

        data_source: DataSource = ZipDataSource(*sources)

        if any(filters):

            def filter_fn(data):
                return all(
                    fn(data) for fn, data in zip(filters, data)
                    if fn is not None)

            data_source = FilterDataSource(data_source, filter_fn=filter_fn)
        #print("data init derive done")
        super(MultiAlignedData, self).__init__(data_source, self._hparams,
                                               device)
        #self._dataset_size = 3000000
        #self._dataset_size = 6400000
        #self._dataset_size = 16000000
        #self._dataset_size = 3802215
        #self._dataset_size = 1250000
        #self._dataset_size = 3000
        self._dataset_size = 834229