Example #1
    def setUp(self) -> None:
        self.source = TextLineDataSource(
            '../../Downloads/en-es.bicleaner07.txt.gz',
            compression_type='gzip')
        self.source.__iter__ = wrap_progress(  # type: ignore
            self.source.__iter__)
        self.num_workers = 3
        self.batch_size = 64
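The wrap_progress helper used in Example #1 is not defined in the snippet; below is a minimal sketch of what such a wrapper might look like, assuming a tqdm-based progress bar (the name and behavior are assumptions, not part of the original test):

from typing import Callable, Iterator, TypeVar

from tqdm import tqdm

T = TypeVar('T')


def wrap_progress(iter_fn: Callable[[], Iterator[T]]) -> Callable[[], Iterator[T]]:
    # Wrap a bound __iter__ so that every yielded example advances a progress bar.
    def wrapped() -> Iterator[T]:
        yield from tqdm(iter_fn())
    return wrapped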
Example #2
    def __init__(self, hparams, device: Optional[torch.device] = None):
        self._hparams = HParams(hparams, self.default_hparams())
        if self._hparams.dataset.variable_utterance:
            raise NotImplementedError

        # Create vocabulary
        self._bos_token = self._hparams.dataset.bos_token
        self._eos_token = self._hparams.dataset.eos_token
        self._other_transforms = self._hparams.dataset.other_transformations
        bos = utils.default_str(self._bos_token, SpecialTokens.BOS)
        eos = utils.default_str(self._eos_token, SpecialTokens.EOS)
        self._vocab = Vocab(self._hparams.dataset.vocab_file,
                            bos_token=bos,
                            eos_token=eos)

        # Create embedding
        self._embedding = self.make_embedding(
            self._hparams.dataset.embedding_init,
            self._vocab.token_to_id_map_py)

        self._delimiter = self._hparams.dataset.delimiter
        self._max_seq_length = self._hparams.dataset.max_seq_length
        self._length_filter_mode = _LengthFilterMode(
            self._hparams.dataset.length_filter_mode)
        self._pad_length = self._max_seq_length
        if self._pad_length is not None:
            self._pad_length += sum(
                int(x != '') for x in [self._bos_token, self._eos_token])

        if (self._length_filter_mode is _LengthFilterMode.DISCARD
                and self._max_seq_length is not None):
            data_source = TextLineDataSource(
                self._hparams.dataset.files,
                compression_type=self._hparams.dataset.compression_type,
                delimiter=self._delimiter,
                max_length=self._max_seq_length)
        else:
            data_source = TextLineDataSource(
                self._hparams.dataset.files,
                compression_type=self._hparams.dataset.compression_type)

        super().__init__(data_source, hparams, device=device)
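Example #2 above is the constructor of a mono-text dataset built on top of TextLineDataSource. Assuming these snippets come from the Texar-PyTorch (texar.torch) data API, a minimal usage sketch with placeholder file names and hyperparameter values:

import texar.torch as tx

# 'sent.txt' and 'vocab.txt' are placeholder paths: one sentence / one token per line.
hparams = {
    'dataset': {'files': 'sent.txt', 'vocab_file': 'vocab.txt'},
    'batch_size': 64,
}
data = tx.data.MonoTextData(hparams)
for batch in tx.data.DataIterator(data):
    # Batches expose the fields set up in the constructor: raw text,
    # token ids and sequence lengths.
    texts, ids, lengths = batch['text'], batch['text_ids'], batch['length']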
Example #3
    def __init__(self, hparams, device: Optional[torch.device] = None):
        self._hparams = HParams(hparams, self.default_hparams())
        self._other_transforms = self._hparams.dataset.other_transformations
        data_type = self._hparams.dataset["data_type"]
        self._typecast_func: Union[Type[int], Type[float]]
        if data_type == "int":
            self._typecast_func = int
            self._to_data_type = np.int32
        elif data_type == "float":
            self._typecast_func = float
            self._to_data_type = np.float32
        else:
            raise ValueError(
                "Incorrect 'data_type'. Currently 'int' and "
                "'float' are supported. Received {}".format(data_type))
        data_source = TextLineDataSource(
            self._hparams.dataset.files,
            compression_type=self._hparams.dataset.compression_type)
        super().__init__(data_source, hparams, device=device)
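Example #3 is the analogous constructor for scalar (int or float) values read line by line. A sketch of how such a dataset might be instantiated, again with placeholder paths and assuming the batch exposes the column under its data_name:

import texar.torch as tx

# 'label.txt' is a placeholder file containing one integer per line.
hparams = {
    'dataset': {'files': 'label.txt', 'data_type': 'int', 'data_name': 'label'},
    'batch_size': 64,
}
data = tx.data.ScalarData(hparams)
for batch in tx.data.DataIterator(data):
    labels = batch['label']  # cast to int32, as selected in the constructor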
Example #4
    def __init__(self, hparams, device: Optional[torch.device] = None):
        self._hparams = HParams(hparams, self.default_hparams())
        # Fill in default hyperparameters for each sub-dataset
        datasets_hparams = self._hparams.datasets
        defaultized_datasets_hparams = []
        for hparams_i in datasets_hparams:
            data_type = hparams_i.get("data_type", None)
            defaultized_ds_hpms = HParams(hparams_i,
                                          _default_dataset_hparams(data_type))
            defaultized_datasets_hparams.append(defaultized_ds_hpms)
        self._hparams.datasets = defaultized_datasets_hparams

        self._vocab = self.make_vocab(self._hparams.datasets)
        self._embedding = self.make_embedding(self._hparams.datasets,
                                              self._vocab)

        name_prefix: List[str] = []
        self._names: List[Dict[str, Any]] = []
        datasources: List[DataSource] = []
        filters: List[Optional[Callable[[str], bool]]] = []
        self._databases: List[DataBase] = []
        for idx, hparams_i in enumerate(self._hparams.datasets):
            dtype = hparams_i.data_type
            datasource_i: DataSource

            if _is_text_data(dtype):
                datasource_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type)
                datasources.append(datasource_i)
                if (hparams_i["length_filter_mode"] ==
                    _LengthFilterMode.DISCARD.value) and \
                        (hparams_i["max_seq_length"] is not None):

                    def _get_filter(delimiter, max_seq_length):
                        return lambda x: len(x.split(delimiter)) <= \
                               max_seq_length

                    filters.append(
                        _get_filter(hparams_i["delimiter"],
                                    hparams_i["max_seq_length"]))
                else:
                    filters.append(None)

                self._names.append({
                    "text":
                    connect_name(hparams_i["data_name"], "text"),
                    "text_ids":
                    connect_name(hparams_i["data_name"], "text_ids"),
                    "length":
                    connect_name(hparams_i["data_name"], "length")
                })

                text_hparams = MonoTextData.default_hparams()
                for key in text_hparams["dataset"].keys():
                    if key in hparams_i:
                        text_hparams["dataset"][key] = hparams_i[key]
                # handle prepend logic in MultiAlignedData collate function
                text_hparams["dataset"]["data_name"] = None

                self._databases.append(
                    MonoTextData._construct(hparams=text_hparams,
                                            device=device,
                                            vocab=self._vocab[idx],
                                            embedding=self._embedding[idx]))
            elif _is_scalar_data(dtype):
                datasource_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type)
                datasources.append(datasource_i)
                filters.append(None)
                self._names.append({"label": hparams_i["data_name"]})
                scalar_hparams = ScalarData.default_hparams()
                scalar_hparams.update({"dataset": hparams_i.todict()})
                self._databases.append(
                    ScalarData._construct(hparams=scalar_hparams,
                                          device=device))
            elif _is_record_data(dtype):
                datasource_i = PickleDataSource(file_paths=hparams_i.files)
                datasources.append(datasource_i)
                self._names.append({
                    feature_type: connect_name(hparams_i["data_name"],
                                               feature_type)
                    for feature_type in
                    hparams_i.feature_original_types.keys()
                })
                filters.append(None)
                record_hparams = RecordData.default_hparams()
                for key in record_hparams["dataset"].keys():
                    if key in hparams_i:
                        record_hparams["dataset"][key] = hparams_i[key]
                self._databases.append(
                    RecordData._construct(hparams=record_hparams))
            else:
                raise ValueError("Unknown data type: %s" % hparams_i.data_type)

            # check for duplicate names
            for i in range(1, len(name_prefix)):
                if name_prefix[i] in name_prefix[:i]:
                    raise ValueError("Data name duplicated: %s" %
                                     name_prefix[i])

            name_prefix.append(hparams_i["data_name"])

        self._name_to_id = {v: k for k, v in enumerate(name_prefix)}

        datasource: DataSource
        datasource = ZipDataSource(*datasources)

        if any(filters):

            def filter_fn(data):
                return all(filters[idx](data_i)
                           for idx, data_i in enumerate(data)
                           if filters[idx] is not None)

            datasource = FilterDataSource(datasource, filter_fn=filter_fn)
        super().__init__(datasource, self._hparams, device)
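Example #4 zips one data source per sub-dataset and applies the per-dataset length filters jointly. A sketch of the corresponding hyperparameter layout, with placeholder files and data names, assuming batch fields are prefixed with each sub-dataset's data_name:

import texar.torch as tx

# Two aligned columns: a text field 'x' and an integer label 'y' (placeholder files).
hparams = {
    'datasets': [
        {'files': 'x.txt', 'vocab_file': 'vocab.txt', 'data_name': 'x'},
        {'files': 'y.txt', 'data_type': 'int', 'data_name': 'y'},
    ],
    'batch_size': 32,
}
data = tx.data.MultiAlignedData(hparams)
for batch in tx.data.DataIterator(data):
    x_ids = batch['x_text_ids']  # text fields: 'x_text', 'x_text_ids', 'x_length'
    y = batch['y']               # scalar fields keep their data_name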
Example #5
    def __init__(self, hparams, device: Optional[torch.device] = None):
        self._hparams = HParams(hparams, self.default_hparams())

        src_hparams = self.hparams.source_dataset
        tgt_hparams = self.hparams.target_dataset

        # create vocabulary
        self._src_bos_token = src_hparams["bos_token"]
        self._src_eos_token = src_hparams["eos_token"]
        self._src_transforms = src_hparams["other_transformations"]
        self._src_vocab = Vocab(src_hparams.vocab_file,
                                bos_token=src_hparams.bos_token,
                                eos_token=src_hparams.eos_token)

        if tgt_hparams["processing_share"]:
            self._tgt_bos_token = src_hparams["bos_token"]
            self._tgt_eos_token = src_hparams["eos_token"]
        else:
            self._tgt_bos_token = tgt_hparams["bos_token"]
            self._tgt_eos_token = tgt_hparams["eos_token"]
        tgt_bos_token = utils.default_str(self._tgt_bos_token,
                                          SpecialTokens.BOS)
        tgt_eos_token = utils.default_str(self._tgt_eos_token,
                                          SpecialTokens.EOS)
        if tgt_hparams["vocab_share"]:
            if tgt_bos_token == self._src_vocab.bos_token and \
                    tgt_eos_token == self._src_vocab.eos_token:
                self._tgt_vocab = self._src_vocab
            else:
                self._tgt_vocab = Vocab(src_hparams["vocab_file"],
                                        bos_token=tgt_bos_token,
                                        eos_token=tgt_eos_token)
        else:
            self._tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                                    bos_token=tgt_bos_token,
                                    eos_token=tgt_eos_token)

        # create embeddings
        self._src_embedding = MonoTextData.make_embedding(
            src_hparams.embedding_init, self._src_vocab.token_to_id_map_py)

        if self._hparams.target_dataset.embedding_init_share:
            self._tgt_embedding = self._src_embedding
        else:
            tgt_emb_file = tgt_hparams.embedding_init["file"]
            self._tgt_embedding = None
            if tgt_emb_file is not None and tgt_emb_file != "":
                self._tgt_embedding = MonoTextData.make_embedding(
                    tgt_hparams.embedding_init,
                    self._tgt_vocab.token_to_id_map_py)

        # create data source
        self._src_delimiter = src_hparams.delimiter
        self._src_max_seq_length = src_hparams.max_seq_length
        self._src_length_filter_mode = _LengthFilterMode(
            src_hparams.length_filter_mode)
        self._src_pad_length = self._src_max_seq_length
        if self._src_pad_length is not None:
            self._src_pad_length += sum(int(x is not None and x != '')
                                        for x in [src_hparams.bos_token,
                                                  src_hparams.eos_token])

        src_data_source = TextLineDataSource(
            src_hparams.files,
            compression_type=src_hparams.compression_type)

        self._tgt_transforms = tgt_hparams["other_transformations"]
        self._tgt_delimiter = tgt_hparams.delimiter
        self._tgt_max_seq_length = tgt_hparams.max_seq_length
        self._tgt_length_filter_mode = _LengthFilterMode(
            tgt_hparams.length_filter_mode)
        self._tgt_pad_length = self._tgt_max_seq_length
        if self._tgt_pad_length is not None:
            self._tgt_pad_length += sum(int(x is not None and x != '')
                                        for x in [tgt_hparams.bos_token,
                                                  tgt_hparams.eos_token])

        tgt_data_source = TextLineDataSource(
            tgt_hparams.files,
            compression_type=tgt_hparams.compression_type)

        data_source: DataSource[Tuple[str, str]]
        data_source = ZipDataSource(  # type: ignore
            src_data_source, tgt_data_source)
        if (self._src_length_filter_mode is _LengthFilterMode.DISCARD and
            self._src_max_seq_length is not None) or \
                (self._tgt_length_filter_mode is _LengthFilterMode.DISCARD and
                 self._tgt_max_seq_length is not None):
            max_source_length = self._src_max_seq_length if \
                self._src_max_seq_length is not None else np.inf
            max_tgt_length = self._tgt_max_seq_length if \
                self._tgt_max_seq_length is not None else np.inf

            def filter_fn(raw_example):
                return len(raw_example[0].split(self._src_delimiter)) \
                       <= max_source_length and \
                       len(raw_example[1].split(self._tgt_delimiter)) \
                       <= max_tgt_length

            data_source = FilterDataSource(data_source, filter_fn)

        super().__init__(data_source, hparams, device=device)
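Example #5 zips a source and a target TextLineDataSource and optionally filters pairs by length. A minimal usage sketch with placeholder parallel-corpus files, assuming the default 'source'/'target' field prefixes:

import texar.torch as tx

# Placeholder parallel corpus: one sentence per line, aligned across the two files.
hparams = {
    'source_dataset': {'files': 'src.txt', 'vocab_file': 'vocab.src'},
    'target_dataset': {'files': 'tgt.txt', 'vocab_file': 'vocab.tgt'},
    'batch_size': 32,
}
data = tx.data.PairedTextData(hparams)
for batch in tx.data.DataIterator(data):
    src_ids = batch['source_text_ids']
    tgt_ids = batch['target_text_ids']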