Example no. 1
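# The two methods below mirror internals of texar.torch.data (``RecordData``
# and ``MultiAlignedData``); helpers such as ``HParams``, ``connect_name``,
# and the ``DataSource`` classes are assumed to come from that package.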
    @classmethod
    def _construct(cls, hparams):
        # Alternative constructor: allocate the instance without running
        # ``__init__``, then fill in its state field by field.
        record_data = cls.__new__(cls)
        record_data._hparams = HParams(hparams, record_data.default_hparams())

        # Parse ``feature_original_types`` into per-feature specs.
        feature_types = record_data._hparams.dataset.feature_original_types
        record_data._features = _convert_feature_hparams(feature_types)

        convert_types = record_data._hparams.dataset.feature_convert_types
        record_data._convert_types = {
            key: get_numpy_dtype(value)
            for key, value in convert_types.items()
        }
        # Override the declared dtype of any feature that has a conversion.
        for key, dtype in record_data._convert_types.items():
            record_data._features[key] = \
                record_data._features[key]._replace(dtype=dtype)

        image_options = record_data._hparams.dataset.image_options
        if isinstance(image_options, HParams):
            # A single options object is allowed; normalize it to a list.
            image_options = [image_options]
        record_data._image_transforms = {}
        for options in image_options:
            key = options.get('image_feature_name')
            if key is None or key not in record_data._features:
                continue
            record_data._image_transforms[key] = _create_image_transform(
                options.get('resize_height'), options.get('resize_width'),
                options.get('resize_method') or 'bilinear')

        record_data._other_transforms = \
            record_data._hparams.dataset.other_transformations

        data_name = record_data._hparams.dataset.data_name
        record_data._items = {
            key: connect_name(data_name, key)
            for key in record_data._features
        }

        # Empty placeholder source: instances built via ``_construct`` are
        # used for processing/collation, with data supplied externally.
        data_source = SequenceDataSource([])

        super(RecordData, record_data).__init__(data_source, hparams)
        return record_data
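
# ``_construct`` above is an alternative constructor: ``cls.__new__(cls)``
# allocates an instance without running ``__init__``, attributes are filled
# in by hand, and only the base-class initializer runs at the end. A minimal
# self-contained sketch of the same idiom; every name below is illustrative
# and not part of the original code:

class _Base:
    def __init__(self, source, hparams):
        self.source = source
        self.hparams = hparams


class _Derived(_Base):
    @classmethod
    def _construct(cls, hparams):
        obj = cls.__new__(cls)                  # allocate, skip __init__
        obj.extra = hparams.get("extra")        # set subclass state first
        super(_Derived, obj).__init__([], hparams)  # then init the base
        return obj


_d = _Derived._construct({"extra": 1})
assert _d.extra == 1 and _d.source == []
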
    def __init__(self, hparams, device: Optional[torch.device] = None):
        # Build one sub-database per aligned dataset, then zip all of their
        # sources into a single joint data source.
        self._hparams = HParams(hparams, self.default_hparams())
        # Fill in default hyperparameters for each dataset.
        datasets_hparams = self._hparams.datasets
        defaultized_datasets_hparams = []
        for hparams_i in datasets_hparams:
            data_type = hparams_i.get("data_type", None)
            defaultized_ds_hpms = HParams(hparams_i,
                                          _default_dataset_hparams(data_type))
            defaultized_datasets_hparams.append(defaultized_ds_hpms)
        self._hparams.datasets = defaultized_datasets_hparams

        self._vocab = self.make_vocab(self._hparams.datasets)
        self._embedding = self.make_embedding(self._hparams.datasets,
                                              self._vocab)

        name_prefix: List[str] = []
        self._names: List[Dict[str, Any]] = []
        datasources: List[DataSource] = []
        filters: List[Optional[Callable[[str], bool]]] = []
        self._databases: List[DataBase] = []
        for idx, hparams_i in enumerate(self._hparams.datasets):
            dtype = hparams_i.data_type
            datasource_i: DataSource

            if _is_text_data(dtype):
                datasource_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type)
                datasources.append(datasource_i)
                if (hparams_i["length_filter_mode"] ==
                    _LengthFilterMode.DISCARD.value) and \
                        (hparams_i["max_seq_length"] is not None):

                    def _get_filter(delimiter, max_seq_length):
                        return lambda x: len(x.split(delimiter)) <= \
                               max_seq_length

                    filters.append(
                        _get_filter(hparams_i["delimiter"],
                                    hparams_i["max_seq_length"]))
                else:
                    filters.append(None)

                self._names.append({
                    "text": connect_name(hparams_i["data_name"], "text"),
                    "text_ids": connect_name(hparams_i["data_name"],
                                             "text_ids"),
                    "length": connect_name(hparams_i["data_name"], "length"),
                })

                text_hparams = MonoTextData.default_hparams()
                for key in text_hparams["dataset"].keys():
                    if key in hparams_i:
                        text_hparams["dataset"][key] = hparams_i[key]
                # ``data_name`` prefixing is handled in MultiAlignedData's
                # collate function, so drop it for the sub-database.
                text_hparams["dataset"]["data_name"] = None

                self._databases.append(
                    MonoTextData._construct(hparams=text_hparams,
                                            device=device,
                                            vocab=self._vocab[idx],
                                            embedding=self._embedding[idx]))
            elif _is_scalar_data(dtype):
                datasource_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type)
                datasources.append(datasource_i)
                filters.append(None)
                self._names.append({"label": hparams_i["data_name"]})
                scalar_hparams = ScalarData.default_hparams()
                scalar_hparams.update({"dataset": hparams_i.todict()})
                self._databases.append(
                    ScalarData._construct(hparams=scalar_hparams,
                                          device=device))
            elif _is_record_data(dtype):
                datasource_i = PickleDataSource(file_paths=hparams_i.files)
                datasources.append(datasource_i)
                self._names.append({
                    feature_type: connect_name(hparams_i["data_name"],
                                               feature_type)
                    for feature_type in
                    hparams_i.feature_original_types.keys()
                })
                filters.append(None)
                record_hparams = RecordData.default_hparams()
                for key in record_hparams["dataset"].keys():
                    if key in hparams_i:
                        record_hparams["dataset"][key] = hparams_i[key]
                self._databases.append(
                    RecordData._construct(hparams=record_hparams))
            else:
                raise ValueError("Unknown data type: %s" % hparams_i.data_type)

            # Reject duplicate data names before registering this one.
            if hparams_i["data_name"] in name_prefix:
                raise ValueError("Data name duplicated: %s" %
                                 hparams_i["data_name"])

            name_prefix.append(hparams_i["data_name"])

        self._name_to_id = {v: k for k, v in enumerate(name_prefix)}

        datasource: DataSource = ZipDataSource(*datasources)

        if any(filters):

            def filter_fn(data):
                # Keep an example only if every defined per-dataset filter
                # accepts its corresponding component.
                return all(filters[idx](data_i)
                           for idx, data_i in enumerate(data)
                           if filters[idx] is not None)

            datasource = FilterDataSource(datasource, filter_fn=filter_fn)
        super().__init__(datasource, self._hparams, device)
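
# A minimal usage sketch for the reader above, assuming the public
# texar.torch.data API; file paths and field names are illustrative:

from texar.torch.data import DataIterator, MultiAlignedData

hparams = {
    "batch_size": 32,
    "datasets": [
        # Text dataset: yields "src_text", "src_text_ids", "src_length".
        {"files": "source.txt", "vocab_file": "vocab.txt",
         "data_name": "src"},
        # Scalar dataset: yields "label".
        {"files": "labels.txt", "data_type": "int", "data_name": "label"},
    ],
}
data = MultiAlignedData(hparams)
for batch in DataIterator(data):
    src_ids = batch["src_text_ids"]  # names joined by ``connect_name``
    labels = batch["label"]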