Example #1
    def _make_processor(dataset_hparams, data_spec, name_prefix):
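        # Build one processing function per sub-dataset, then merge them into
        # a single transformation applied to each zipped (multi-aligned)
        # example.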
        processors = []
        for i, hparams_i in enumerate(dataset_hparams):
            data_spec_i = data_spec.get_ith_data_spec(i)

            data_type = hparams_i["data_type"]
            if _is_text_data(data_type):
                tgt_proc_hparams = hparams_i
                proc_shr = hparams_i["processing_share_with"]
                if proc_shr is not None:
                    tgt_proc_hparams = copy.copy(dataset_hparams[proc_shr])
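                    # item assignment can raise TypeError on HParams-like
                    # objects; fall back to attribute assignment in that case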
                    try:
                        tgt_proc_hparams["variable_utterance"] = \
                                hparams_i["variable_utterance"]
                    except TypeError:
                        tgt_proc_hparams.variable_utterance = \
                                hparams_i["variable_utterance"]

                processor, data_spec_i = MonoTextData._make_processor(
                    tgt_proc_hparams, data_spec_i)
            elif _is_scalar_data(data_type):
                processor, data_spec_i = ScalarData._make_processor(
                    hparams_i, data_spec_i, name_prefix='')
            else:
                raise ValueError("Unsupported data type: %s" % data_type)

            processors.append(processor)
            data_spec.set_ith_data_spec(i, data_spec_i, len(dataset_hparams))

        tran_fn = dsutils.make_combined_transformation(processors,
                                                       name_prefix=name_prefix)

        data_spec.add_spec(name_prefix=name_prefix)

        return tran_fn, data_spec
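
# Illustrative sketch only, separate from the class methods above: a rough
# guess at what the combined transformation amounts to, assuming each
# per-dataset processor maps its own field of a zipped example to a dict of
# outputs. The function and variable names below are hypothetical.
def _combined_transform_sketch(processors, zipped_example):
    output = {}
    for processor, field in zip(processors, zipped_example):
        # each processor handles one aligned field and returns a dict; the
        # library version presumably also prefixes every output key with the
        # owning dataset's data_name
        output.update(processor(field))
    return output
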
    def __init__(self, hparams, device: Optional[torch.device] = None):
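        # Resolve hparams, build per-dataset vocabularies and embeddings, then
        # create one data source per dataset and zip them into a single
        # aligned source.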
        self._hparams = HParams(hparams, self.default_hparams())
        # Fill in default hyperparameters for each dataset
        datasets_hparams = self._hparams.datasets
        defaultized_datasets_hparams = []
        for hparams_i in datasets_hparams:
            data_type = hparams_i.get("data_type", None)
            defaultized_ds_hpms = HParams(hparams_i,
                                          _default_dataset_hparams(data_type))
            defaultized_datasets_hparams.append(defaultized_ds_hpms)
        self._hparams.datasets = defaultized_datasets_hparams

        self._vocab = self.make_vocab(self._hparams.datasets)
        self._embedding = self.make_embedding(self._hparams.datasets,
                                              self._vocab)

        # Per-dataset bookkeeping: output field names, raw data sources,
        # optional length filters, and the wrapped single-dataset databases.
        name_prefix: List[str] = []
        self._names: List[Dict[str, Any]] = []
        datasources: List[DataSource] = []
        filters: List[Optional[Callable[[str], bool]]] = []
        self._databases: List[DataBase] = []
        for idx, hparams_i in enumerate(self._hparams.datasets):
            dtype = hparams_i.data_type
            datasource_i: DataSource

            if _is_text_data(dtype):
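                # text field: stream raw lines, optionally filter by length,
                # and delegate processing to a wrapped MonoTextData instance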
                datasource_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type)
                datasources.append(datasource_i)
                if (hparams_i["length_filter_mode"] ==
                    _LengthFilterMode.DISCARD.value) and \
                        (hparams_i["max_seq_length"] is not None):

                    def _get_filter(delimiter, max_seq_length):
                        return lambda x: len(x.split(delimiter)) <= \
                               max_seq_length

                    filters.append(
                        _get_filter(hparams_i["delimiter"],
                                    hparams_i["max_seq_length"]))
                else:
                    filters.append(None)

                self._names.append({
                    "text": connect_name(hparams_i["data_name"], "text"),
                    "text_ids": connect_name(
                        hparams_i["data_name"], "text_ids"),
                    "length": connect_name(hparams_i["data_name"], "length"),
                })

                text_hparams = MonoTextData.default_hparams()
                for key in text_hparams["dataset"].keys():
                    if key in hparams_i:
                        text_hparams["dataset"][key] = hparams_i[key]
                # data_name is cleared here; MultiAlignedData's collate
                # function handles prepending the per-dataset name prefix
                text_hparams["dataset"]["data_name"] = None

                self._databases.append(
                    MonoTextData._construct(hparams=text_hparams,
                                            device=device,
                                            vocab=self._vocab[idx],
                                            embedding=self._embedding[idx]))
            elif _is_scalar_data(dtype):
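                # scalar field (e.g. an int/float label per line): no length
                # filtering; processing is delegated to ScalarData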
                datasource_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type)
                datasources.append(datasource_i)
                filters.append(None)
                self._names.append({"label": hparams_i["data_name"]})
                scalar_hparams = ScalarData.default_hparams()
                scalar_hparams.update({"dataset": hparams_i.todict()})
                self._databases.append(
                    ScalarData._construct(hparams=scalar_hparams,
                                          device=device))
            elif _is_record_data(dtype):
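                # record field: pickled feature dicts read by PickleDataSource
                # and processed by a wrapped RecordData instance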
                datasource_i = PickleDataSource(file_paths=hparams_i.files)
                datasources.append(datasource_i)
                self._names.append({
                    feature_type: connect_name(hparams_i["data_name"],
                                               feature_type)
                    for feature_type in
                    hparams_i.feature_original_types.keys()
                })
                filters.append(None)
                record_hparams = RecordData.default_hparams()
                for key in record_hparams["dataset"].keys():
                    if key in hparams_i:
                        record_hparams["dataset"][key] = hparams_i[key]
                self._databases.append(
                    RecordData._construct(hparams=record_hparams))
            else:
                raise ValueError("Unknown data type: %s" % hparams_i.data_type)

            # check for duplicate data names before registering this one
            if hparams_i["data_name"] in name_prefix:
                raise ValueError("Data name duplicated: %s" %
                                 hparams_i["data_name"])

            name_prefix.append(hparams_i["data_name"])

        # map each dataset's data_name to its index in the datasets list
        self._name_to_id = {v: k for k, v in enumerate(name_prefix)}

        datasource: DataSource = ZipDataSource(*datasources)

        # keep an example only if every non-None per-dataset filter accepts
        # the corresponding field
        if any(filters):

            def filter_fn(data):
                return all(filters[idx](data_i)
                           for idx, data_i in enumerate(data)
                           if filters[idx] is not None)

            datasource = FilterDataSource(datasource, filter_fn=filter_fn)
        super().__init__(datasource, self._hparams, device)
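
# Illustrative usage sketch, not part of the original example: how hparams for
# two aligned fields (one text, one integer label) might be declared. The file
# paths, vocab path, data names, and batch size below are hypothetical
# placeholders.
example_hparams = {
    "datasets": [
        {   # a text field read line-by-line from a (hypothetical) file
            "files": "source.txt",
            "vocab_file": "vocab.txt",
            "data_name": "x",
        },
        {   # an integer label aligned line-by-line with the text above
            "files": "labels.txt",
            "data_type": "int",
            "data_name": "label",
        },
    ],
    "batch_size": 64,
}
# data = MultiAlignedData(example_hparams)  # assumes the enclosing class
#                                           # definition is in scope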