def _construct(cls, hparams):
    # Bypass ``__init__`` and build the instance field by field; the
    # base-class initializer is invoked at the end with an empty source.
    record_data = cls.__new__(cls)
    record_data._hparams = HParams(hparams, record_data.default_hparams())

    # Parse the declared feature schema.
    feature_types = record_data._hparams.dataset.feature_original_types
    record_data._features = _convert_feature_hparams(feature_types)

    # Optional dtype conversions, applied on top of the original types.
    convert_types = record_data._hparams.dataset.feature_convert_types
    record_data._convert_types = {key: get_numpy_dtype(value)
                                  for key, value in convert_types.items()}
    for key, dtype in record_data._convert_types.items():
        record_data._features[key] = \
            record_data._features[key]._replace(dtype=dtype)

    # Build one image transform per configured image feature. A single
    # ``image_options`` entry may be given as a bare HParams object.
    image_options = record_data._hparams.dataset.image_options
    if isinstance(image_options, HParams):
        image_options = [image_options]
    record_data._image_transforms = {}
    for options in image_options:
        key = options.get('image_feature_name')
        if key is None or key not in record_data._features:
            continue
        record_data._image_transforms[key] = _create_image_transform(
            options.get('resize_height'), options.get('resize_width'),
            options.get('resize_method') or 'bilinear')

    record_data._other_transforms = \
        record_data._hparams.dataset.other_transformations

    # Prefix every feature name with the dataset's ``data_name``.
    data_name = record_data._hparams.dataset.data_name
    record_data._items = {
        key: connect_name(data_name, key)
        for key in record_data._features.keys()
    }

    # The caller supplies examples later, so start from an empty source.
    data_source = SequenceDataSource([])
    super(RecordData, record_data).__init__(data_source, hparams)

    return record_data
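# A minimal sketch of the ``dataset`` hyperparameters that ``_construct``
# consumes, kept as a comment so this module stays importable. The file
# path, feature names, and type strings below are hypothetical
# placeholders, not library defaults:
#
#     hparams = {
#         'dataset': {
#             'files': 'images.pkl',
#             'feature_original_types': {
#                 'image_raw': ['bytes', 'stacked_tensor'],
#                 'label': ['int64', 'stacked_tensor'],
#             },
#             'feature_convert_types': {'label': 'torch.float32'},
#             'image_options': {
#                 'image_feature_name': 'image_raw',
#                 'resize_height': 224,
#                 'resize_width': 224,
#             },
#             'data_name': 'img',
#         },
#         'batch_size': 1,
#     }
#     data = RecordData(hparams)
#
# With ``data_name='img'`` the resulting batch fields are prefixed via
# ``connect_name``, e.g. ``img_image_raw`` and ``img_label``.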
def __init__(self, hparams, device: Optional[torch.device] = None):
    self._hparams = HParams(hparams, self.default_hparams())

    # Defaultize the hyperparameters of each sub-dataset according to
    # its data type.
    datasets_hparams = self._hparams.datasets
    defaultized_datasets_hparams = []
    for hparams_i in datasets_hparams:
        data_type = hparams_i.get("data_type", None)
        defaultized_ds_hpms = HParams(
            hparams_i, _default_dataset_hparams(data_type))
        defaultized_datasets_hparams.append(defaultized_ds_hpms)
    self._hparams.datasets = defaultized_datasets_hparams

    self._vocab = self.make_vocab(self._hparams.datasets)
    self._embedding = self.make_embedding(
        self._hparams.datasets, self._vocab)

    name_prefix: List[str] = []
    self._names: List[Dict[str, Any]] = []
    datasources: List[DataSource] = []
    filters: List[Optional[Callable[[str], bool]]] = []
    self._databases: List[DataBase] = []
    for idx, hparams_i in enumerate(self._hparams.datasets):
        dtype = hparams_i.data_type
        datasource_i: DataSource

        if _is_text_data(dtype):
            datasource_i = TextLineDataSource(
                hparams_i.files,
                compression_type=hparams_i.compression_type)
            datasources.append(datasource_i)
            if (hparams_i["length_filter_mode"] ==
                    _LengthFilterMode.DISCARD.value) and \
                    (hparams_i["max_seq_length"] is not None):

                # Use a factory so the lambda closes over the current
                # values rather than the loop variables.
                def _get_filter(delimiter, max_seq_length):
                    return lambda x: \
                        len(x.split(delimiter)) <= max_seq_length

                filters.append(_get_filter(
                    hparams_i["delimiter"],
                    hparams_i["max_seq_length"]))
            else:
                filters.append(None)

            self._names.append({
                "text": connect_name(hparams_i["data_name"], "text"),
                "text_ids": connect_name(
                    hparams_i["data_name"], "text_ids"),
                "length": connect_name(hparams_i["data_name"], "length"),
            })

            text_hparams = MonoTextData.default_hparams()
            for key in text_hparams["dataset"].keys():
                if key in hparams_i:
                    text_hparams["dataset"][key] = hparams_i[key]
            # Unset the name here; the prepend logic is handled in the
            # MultiAlignedData collate function instead.
            text_hparams["dataset"]["data_name"] = None
            self._databases.append(MonoTextData._construct(
                hparams=text_hparams, device=device,
                vocab=self._vocab[idx],
                embedding=self._embedding[idx]))

        elif _is_scalar_data(dtype):
            datasource_i = TextLineDataSource(
                hparams_i.files,
                compression_type=hparams_i.compression_type)
            datasources.append(datasource_i)
            filters.append(None)
            self._names.append({"label": hparams_i["data_name"]})
            scalar_hparams = ScalarData.default_hparams()
            scalar_hparams.update({"dataset": hparams_i.todict()})
            self._databases.append(ScalarData._construct(
                hparams=scalar_hparams, device=device))

        elif _is_record_data(dtype):
            datasource_i = PickleDataSource(file_paths=hparams_i.files)
            datasources.append(datasource_i)
            self._names.append({
                feature_type: connect_name(
                    hparams_i["data_name"], feature_type)
                for feature_type in
                hparams_i.feature_original_types.keys()
            })
            filters.append(None)
            record_hparams = RecordData.default_hparams()
            for key in record_hparams["dataset"].keys():
                if key in hparams_i:
                    record_hparams["dataset"][key] = hparams_i[key]
            self._databases.append(
                RecordData._construct(hparams=record_hparams))

        else:
            raise ValueError(
                "Unknown data type: %s" % hparams_i.data_type)

        # Reject duplicate data names before recording the current one.
        if hparams_i["data_name"] in name_prefix:
            raise ValueError(
                "Data name duplicated: %s" % hparams_i["data_name"])
        name_prefix.append(hparams_i["data_name"])

    self._name_to_id = {v: k for k, v in enumerate(name_prefix)}

    datasource: DataSource = ZipDataSource(*datasources)

    if any(filters):
        # An example is kept only if every active per-dataset filter
        # accepts its corresponding field.
        def filter_fn(data):
            return all(filters[idx](data_i)
                       for idx, data_i in enumerate(data)
                       if filters[idx] is not None)

        datasource = FilterDataSource(datasource, filter_fn=filter_fn)

    super().__init__(datasource, self._hparams, device)
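# Hedged usage sketch for ``__init__`` (hypothetical file names): each
# entry in ``datasets`` becomes one zipped field, so all files must be
# aligned line-for-line (or record-for-record):
#
#     hparams = {
#         'datasets': [
#             {'files': 'source.txt', 'vocab_file': 'vocab.txt',
#              'data_name': 'x'},
#             {'files': 'labels.txt', 'data_type': 'int',
#              'data_name': 'y'},
#         ],
#         'batch_size': 16,
#     }
#     data = MultiAlignedData(hparams)
#
# Text fields come back under prefixed names ('x_text', 'x_text_ids',
# 'x_length'); the scalar dataset is exposed directly as 'y'.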