def _make_processor(dataset_hparams, data_spec, name_prefix):
    # Build one processing function per dataset and record its data spec.
    processors = []
    for i, hparams_i in enumerate(dataset_hparams):
        data_spec_i = data_spec.get_ith_data_spec(i)

        data_type = hparams_i["data_type"]
        if _is_text_data(data_type):
            tgt_proc_hparams = hparams_i
            proc_shr = hparams_i["processing_share_with"]
            if proc_shr is not None:
                # Share processing hyperparameters with another text dataset,
                # but keep this dataset's own `variable_utterance` setting.
                tgt_proc_hparams = copy.copy(dataset_hparams[proc_shr])
                try:
                    tgt_proc_hparams["variable_utterance"] = \
                        hparams_i["variable_utterance"]
                except TypeError:
                    # The copied hyperparameters may not support item
                    # assignment; fall back to attribute assignment.
                    tgt_proc_hparams.variable_utterance = \
                        hparams_i["variable_utterance"]

            processor, data_spec_i = MonoTextData._make_processor(
                tgt_proc_hparams, data_spec_i)
        elif _is_scalar_data(data_type):
            processor, data_spec_i = ScalarData._make_processor(
                hparams_i, data_spec_i, name_prefix='')
        else:
            raise ValueError("Unsupported data type: %s" % data_type)

        processors.append(processor)
        data_spec.set_ith_data_spec(i, data_spec_i, len(dataset_hparams))

    # Combine the per-dataset processors into a single transformation over
    # zipped examples, with field names prefixed by the dataset names.
    tran_fn = dsutils.make_combined_transformation(
        processors, name_prefix=name_prefix)
    data_spec.add_spec(name_prefix=name_prefix)

    return tran_fn, data_spec
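# Illustrative sketch: conceptually, the combined transformation returned
# above applies the i-th processor to the i-th element of each zipped
# example. The helper below is a minimal standalone sketch of that idea,
# not the actual `dsutils.make_combined_transformation` implementation,
# and its names are hypothetical:
#
#     def make_combined_transformation_sketch(processors):
#         def transform(example_tuple):
#             # Apply each dataset's processor to its own field.
#             return tuple(proc(elem)
#                          for proc, elem in zip(processors, example_tuple))
#         return transform
#
# A text dataset can also reuse another dataset's processing settings via
# `processing_share_with`, e.g. (illustrative hyperparameters only):
#
#     dataset_hparams = [
#         {"data_type": "text", "processing_share_with": None, ...},
#         {"data_type": "text", "processing_share_with": 0, ...},
#     ]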
def __init__(self, hparams, device: Optional[torch.device] = None):
    self._hparams = HParams(hparams, self.default_hparams())

    # Fill in default hyperparameters for each dataset.
    datasets_hparams = self._hparams.datasets
    defaultized_datasets_hparams = []
    for hparams_i in datasets_hparams:
        data_type = hparams_i.get("data_type", None)
        defaultized_ds_hpms = HParams(
            hparams_i, _default_dataset_hparams(data_type))
        defaultized_datasets_hparams.append(defaultized_ds_hpms)
    self._hparams.datasets = defaultized_datasets_hparams

    self._vocab = self.make_vocab(self._hparams.datasets)
    self._embedding = self.make_embedding(
        self._hparams.datasets, self._vocab)

    name_prefix: List[str] = []
    self._names: List[Dict[str, Any]] = []
    datasources: List[DataSource] = []
    filters: List[Optional[Callable[[str], bool]]] = []
    self._databases: List[DataBase] = []

    for idx, hparams_i in enumerate(self._hparams.datasets):
        dtype = hparams_i.data_type
        datasource_i: DataSource

        if _is_text_data(dtype):
            datasource_i = TextLineDataSource(
                hparams_i.files,
                compression_type=hparams_i.compression_type)
            datasources.append(datasource_i)

            # Optionally discard examples longer than `max_seq_length`.
            if (hparams_i["length_filter_mode"] ==
                    _LengthFilterMode.DISCARD.value) and \
                    (hparams_i["max_seq_length"] is not None):

                def _get_filter(delimiter, max_seq_length):
                    return lambda x: \
                        len(x.split(delimiter)) <= max_seq_length

                filters.append(_get_filter(
                    hparams_i["delimiter"],
                    hparams_i["max_seq_length"]))
            else:
                filters.append(None)

            self._names.append({
                "text": connect_name(hparams_i["data_name"], "text"),
                "text_ids": connect_name(
                    hparams_i["data_name"], "text_ids"),
                "length": connect_name(hparams_i["data_name"], "length"),
            })

            text_hparams = MonoTextData.default_hparams()
            for key in text_hparams["dataset"].keys():
                if key in hparams_i:
                    text_hparams["dataset"][key] = hparams_i[key]
            # The prepend logic is handled in the MultiAlignedData collate
            # function, so clear the per-dataset data name here.
            text_hparams["dataset"]["data_name"] = None
            self._databases.append(MonoTextData._construct(
                hparams=text_hparams, device=device,
                vocab=self._vocab[idx],
                embedding=self._embedding[idx]))

        elif _is_scalar_data(dtype):
            datasource_i = TextLineDataSource(
                hparams_i.files,
                compression_type=hparams_i.compression_type)
            datasources.append(datasource_i)
            filters.append(None)
            self._names.append({"label": hparams_i["data_name"]})

            scalar_hparams = ScalarData.default_hparams()
            scalar_hparams.update({"dataset": hparams_i.todict()})
            self._databases.append(ScalarData._construct(
                hparams=scalar_hparams, device=device))

        elif _is_record_data(dtype):
            datasource_i = PickleDataSource(file_paths=hparams_i.files)
            datasources.append(datasource_i)
            self._names.append({
                feature_type: connect_name(
                    hparams_i["data_name"], feature_type)
                for feature_type in
                hparams_i.feature_original_types.keys()
            })
            filters.append(None)

            record_hparams = RecordData.default_hparams()
            for key in record_hparams["dataset"].keys():
                if key in hparams_i:
                    record_hparams["dataset"][key] = hparams_i[key]
            self._databases.append(
                RecordData._construct(hparams=record_hparams))

        else:
            raise ValueError("Unknown data type: %s" % hparams_i.data_type)

        # Check for duplicate data names.
        if hparams_i["data_name"] in name_prefix:
            raise ValueError(
                "Data name duplicated: %s" % hparams_i["data_name"])
        name_prefix.append(hparams_i["data_name"])

    self._name_to_id = {v: k for k, v in enumerate(name_prefix)}

    # Zip the per-dataset sources and, if any length filters were created,
    # wrap the zipped source so that every filter is applied to its own
    # element of each zipped example.
    datasource: DataSource = ZipDataSource(*datasources)
    if any(filters):
        def filter_fn(data):
            return all(filters[idx](data_i)
                       for idx, data_i in enumerate(data)
                       if filters[idx] is not None)

        datasource = FilterDataSource(datasource, filter_fn=filter_fn)

    super().__init__(datasource, self._hparams, device)
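# Illustrative usage sketch, assuming the usual Texar-PyTorch package layout;
# the file paths and the data names "x" and "label" below are hypothetical
# placeholders, not values from this module:
#
#     import texar.torch as tx
#
#     hparams = {
#         "datasets": [
#             # Text dataset: yields "x_text", "x_text_ids", "x_length".
#             {"files": "source.txt", "vocab_file": "vocab.txt",
#              "data_name": "x"},
#             # Scalar dataset: yields an integer "label" field.
#             {"files": "labels.txt", "data_type": "int",
#              "data_name": "label"},
#         ],
#         "batch_size": 32,
#     }
#     data = tx.data.MultiAlignedData(hparams)
#     iterator = tx.data.DataIterator(data)
#     for batch in iterator:
#         ...  # e.g. batch["x_text_ids"], batch["label"]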