Example #1
def main():
    args = parse_args()

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logger.info(get_commandline_args())

    utt_text_speaker = consolidate_utt_info(scp=None,
                                            text=args.text_file,
                                            utt2spk=args.utt2spk_file)

    with kaldiio.ReadHelper(
            args.rspecifier,
            segments=args.segments) as reader, file_writer_helper(
                args.wspecifier,
                filetype=args.archive_format,
                compress=args.compress,
                compression_method=args.compression_method,
                sample_frequency=args.sample_frequency,
                transform=Transformation(args.feature_config)) as writer:
        for utt_id, (rate, wave) in tqdm.tqdm(reader,
                                              miniters=100,
                                              maxinterval=30):
            utt_dict = {"x": wave, "rate": rate}
            utt_dict.update(utt_text_speaker.get(utt_id, {}))
            try:
                writer[utt_id] = utt_dict
            except Exception as e:
                logger.warning(
                    f"Failed to process utterance {utt_id} with exception:\n{str(e)}"
                )
                continue
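Example #1 streams raw audio from a Kaldi-style wav.scp (optionally clipped by a segments file) and dumps each waveform, together with any text/speaker info from consolidate_utt_info, into feature archives. The reading side is plain kaldiio, a public library; the fragment below is a minimal sketch of just that part, with the file paths being placeholder assumptions rather than values taken from the example.

import kaldiio

# Placeholder paths; Example #1 receives the rspecifier and segments via argparse.
with kaldiio.ReadHelper("scp:data/train/wav.scp",
                        segments="data/train/segments") as reader:
    for utt_id, (rate, wave) in reader:
        # rate is the sampling frequency, wave a 1-D numpy array of samples
        print(utt_id, rate, wave.shape)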
Example #2
def main():
    args = parse_args()

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logger.info(get_commandline_args())

    cmvn = Transformation([{"type": "cmvn",
                            "stats": args.stats_file,
                            "cmvn_type": args.cmvn_type,
                            "norm_means": args.norm_means,
                            "norm_vars": args.norm_vars,
                            "utt2spk": args.utt2spk,
                            "reverse": args.reverse}])

    with file_writer_helper(
        args.wspecifier,
        filetype=args.out_filetype,
        compress=args.compress,
        compression_method=args.compression_method,
    ) as writer:
        for utt, data in file_reader_helper(args.rspecifier, args.in_filetype,
                                            transform=cmvn, return_dict=True):
            writer[utt] = data
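Example #2 wraps CMVN in a one-step Transformation and then copies every utterance from the input archive to the output archive through it. The config is just a list with a single dict; filled in with concrete (made-up) values instead of the argparse ones, it might look like this:

# Hypothetical concrete values; Example #2 takes these from the command line.
cmvn_conf = [{"type": "cmvn",
              "stats": "data/train/cmvn.ark",
              "cmvn_type": "global",
              "norm_means": True,
              "norm_vars": False,
              "utt2spk": None,
              "reverse": False}]
cmvn = Transformation(cmvn_conf)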
Example #3
    def __init__(self, wspecifier: str, sample_frequency: Union[float, None],
                 transform: Union[Transformation, None],
                 max_hr_per_file: Union[float, None]):
        self.spec_dict = parse_wspecifier(wspecifier)
        self.ark = self.spec_dict["ark"]
        self.input_hz = sample_frequency
        self.transform = Transformation() if transform is None else transform
        self.output_hz = self.input_hz
        for fn in self.transform.functions:
            if hasattr(fn, "sample_frequency"):
                self.output_hz = fn.sample_frequency
            if hasattr(fn, "frame_shift_ms"):
                self.output_hz = 1000 / fn.frame_shift_ms

        if "scp" in self.spec_dict:
            self.writer_scp = open(self.spec_dict["scp"],
                                   "w",
                                   encoding="utf-8")
        else:
            self.writer_scp = None

        self.writer = None
        self.file_idx = 1
        self.cur_file_sec = 0
        self.max_file_sec = max_hr_per_file * 3600 if max_hr_per_file else None
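The loop over self.transform.functions decides the writer's output rate: a resampling step overrides it with that step's sample_frequency, while a framing step overrides it with 1000 / frame_shift_ms (frames per second). A quick worked example, assuming a typical 10 ms frame shift and a one-hour file limit:

frame_shift_ms = 10                      # assumed value for illustration
output_hz = 1000 / frame_shift_ms        # -> 100.0 frames per second

max_hr_per_file = 1.0                    # assumed value for illustration
max_file_sec = max_hr_per_file * 3600    # -> roll to a new archive after 3600 s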
Example #4
def main():
    parser = get_parser()
    args = parser.parse_args()

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logger.info(get_commandline_args())

    if args.preprocess_conf is not None:
        preprocessing = Transformation(args.preprocess_conf)
        logger.info("Apply preprocessing: {}".format(preprocessing))
    else:
        preprocessing = None

    for utt, shape in file_reader_helper(args.rspecifier,
                                         args.filetype,
                                         return_shape=True,
                                         transform=preprocessing):
        shape_str = ",".join(map(str, shape))  # shape is a tuple of ints
        args.out.write("{} {}\n".format(utt, shape_str))
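Example #4 writes one line per utterance: the utterance id followed by its (possibly transformed) feature shape, with the dimensions joined by commas. A tiny illustration of that formatting, using a made-up utterance id and an assumed 83-dimensional feature matrix:

shape = (438, 83)                              # hypothetical (frames, feature_dim)
shape_str = ",".join(map(str, shape))
print("{} {}".format("utt-0001", shape_str))   # -> utt-0001 438,83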
Example #5
    def __init__(self, datasets: Union[str, List[str]],
                 task="asr", precomputed_feats_type="raw",
                 text_filename="text",
                 transform_conf: Union[str, List[Dict[str, Any]]] = None,
                 batch_size=1, max_len=None, train=False, shuffle=False,
                 num_replicas=None, rank=None, ensure_equal_parts=True,
                 num_workers=0, data_cache_mb=2048,
                 spmodel=None, token_list=None):
        """
        :param datasets: a list of strings specifying which datasets to load.
            Each dataset should be formatted as `<dataset_name>/<split_name>`,
            e.g. `wsj/train_si284`, `swbd/rt03`, `librispeech/dev-other`, etc.
        :param task: either "asr" or "tts"
        :param precomputed_feats_type: "raw", "fbank", or "fbank_pitch". Tells
            the data loader where to look for archive files, and what sort of
            pre-computed features have been dumped to those archives.
        :param text_filename: name of the file associating each utterance to its
            transcription. Can be useful to override if you want to transcribe
            (for example) phones instead of characters.
        :param transform_conf: either the filename of a `yaml` file specifying a
            transformation, or a `List[Dict[str, Any]]` specifying that
            transformation. `None` means no data transformation.
        :param batch_size: batch size
        :param train: whether the dataset's transform should be applied in
            training mode
        :param shuffle: whether to shuffle the dataset
        :param max_len: the maximum average utterance length of a batch (after
            any data transformation is applied). The sum of utterance lengths
            in the batch is restricted to be less than `batch_size * max_len`.
        :param num_replicas: the number of distributed copies of the dataset
            being used, e.g. for training a `DistributedDataParallel` model.
            If you are running multiple jobs in parallel, and want each job to
            work on a different subset of the dataset, set this parameter to the
            total number of jobs. You probably don't need to specify this
            manually if you are using torch.distributed.
        :param rank: the index (amongst all distributed workers) of the current
            worker, e.g. for training a `DistributedDataParallel` model. If you
            are running multiple jobs in parallel, and want each job to work on
            a different subset of the dataset, set this parameter to the index
            of the current job. You probably don't need to specify this manually
            if you are using torch.distributed.
        :param ensure_equal_parts: Whether to ensure that all parallel processes
            receive the same number of batches to process. This should always be
            enabled if you are training with `DistributedDataParallel`, but you
            may wish to disable it if you are evaluating your model and wish to
            have each utterance processed exactly once.
        :param num_workers: the number of parallel processes to use to
            apply the data transformation (specified by `transform_conf`) to
            the data being loaded. If `None`, the value used is
            `max(1, math.ceil(os.cpu_count() / num_replicas) - 1)`. Note that there are
            some known issues with this feature! If you are running into
            hanging/deadlock with `num_workers > 0`, change `num_workers`
            to `0` before trying anything else.
        :param data_cache_mb: the number of megabytes the cache (for
            pre-fetching archive files into memory) can contain.
        :param spmodel: the path to a `sentencepiece` model to use to tokenize
            the text. Can be trained with BPE, word, or char.
        :param token_list: the path to a list of `sentencepiece` tokens to use
            to tokenize the text. The indices in `token_list` override those
            in `spmodel`. By default, we will search the directory containing
            `spmodel` for a `tokens.txt` and use it if found.
        """
        # For using the next() syntax
        self.iter = None
        self.current_position = 0

        # Initialize parameters for distributed data loading
        is_dist = dist.is_available() and dist.is_initialized()
        if num_replicas is None:
            num_replicas = dist.get_world_size() if is_dist else 1
        self.num_replicas = num_replicas
        if rank is None:
            rank = dist.get_rank() if is_dist else 0
        self.rank = rank
        if rank >= num_replicas:
            raise ValueError("You must specify rank < n_replicas. "
                             f"(Got rank={rank}, n_replicas={num_replicas}).")

        # Build a sentencepiece tokenizer if desired
        self.tokenizer = None
        if spmodel is not None:
            if token_list is None:
                token_list = os.path.join(os.path.dirname(spmodel), "tokens.txt")
            elif isinstance(token_list, list):
                if not all(isinstance(item, str) for item in token_list):
                    token_list = None
            if isinstance(token_list, (str, Path)) and not os.path.isfile(token_list):
                token_list = None
            self.tokenizer = SentencepieceTokenizer(spmodel, token_list)

        # Get all the datasets & their sub-datasets, and validate them
        datasets = [datasets] if isinstance(datasets, str) else datasets
        assert len(datasets) > 0, "Cannot load 0 datasets"
        scp_files, reader_class = get_dataset_scps(datasets, task, precomputed_feats_type)

        try:
            self.aux_utt_info = {}
            for scp in scp_files:
                scp_dir = os.path.dirname(scp)
                self.aux_utt_info.update(consolidate_utt_info(
                    text=os.path.join(scp_dir, text_filename),
                    utt2spk=os.path.join(scp_dir, "utt2spk"),
                    utt2num_frames=os.path.join(scp_dir, "utt2num_frames")))
        except FileNotFoundError:
            # Fall back to no auxiliary info if text/utt2spk/utt2num_frames
            # files are missing (re-reading the same paths would fail again).
            self.aux_utt_info = {}

        # Initialize the transform & determine how it re-samples inputs
        transform = Transformation(transform_conf, precomputed_feats_type)
        input_hz, output_hz = None, None
        for fn in transform.functions:
            if hasattr(fn, "sample_frequency"):
                if input_hz is None:
                    input_hz = fn.sample_frequency
                output_hz = fn.sample_frequency
            if hasattr(fn, "frame_shift_ms"):
                if input_hz is None:
                    input_hz = 1000 / fn.frame_shift_ms
                output_hz = 1000 / fn.frame_shift_ms

        # Get the approximate length of each utterance, post-transformation
        ratio = output_hz / input_hz if input_hz and output_hz else 1.0
        utt2len = {k: v["length"] * ratio for k, v in self.aux_utt_info.items()}

        # Combine all the relevant SCP files into a single temporary file, and
        # use this temporary file to initialize the actual dataset
        with NamedTemporaryFile(delete=True) as tmpfile:
            with open(tmpfile.name, "ab") as dst:
                for scp in scp_files:
                    with open(scp, "rb") as src:
                        shutil.copyfileobj(src, dst)

            dataset = reader_class(
                f"scp:{tmpfile.name}", return_dict=True, train=train,
                shuffle=shuffle, n_parts=num_replicas, i_part=rank,
                transform=transform, pre_fetch_next_epoch=True,
                num_workers=num_workers, data_cache_mb=data_cache_mb,
                batch_size=batch_size, max_len=max_len,
                utt2len=utt2len, ensure_equal_parts=ensure_equal_parts)

            super().__init__(dataset, batch_size=1, shuffle=False,
                             collate_fn=self.collate_fn)
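The snippet above only shows the constructor, not the class name, so the sketch below just calls it SpeechDataLoader as a stand-in; the keyword arguments, however, are the ones documented in the docstring. A minimal sketch of how such a loader might be built and iterated:

# Hypothetical class name; only the constructor parameters come from the example.
loader = SpeechDataLoader(["wsj/train_si284", "librispeech/dev-other"],
                          task="asr",
                          precomputed_feats_type="fbank_pitch",
                          transform_conf="conf/train_transform.yaml",
                          batch_size=32, max_len=1500,
                          train=True, shuffle=True,
                          spmodel="data/lang/bpe.model")
for batch in loader:
    ...  # each element is assembled by self.collate_fn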
Example #6
    def __init__(self,
                 rspecifier,
                 return_shape=False,
                 return_dict=False,
                 transform: Transformation = None,
                 train=False,
                 batch_size: int = None,
                 max_len: int = None,
                 utt2len=None,
                 ensure_equal_parts=True,
                 pre_fetch_next_epoch=False,
                 data_cache_mb=2048,
                 num_workers: int = 1,
                 n_parts=1,
                 i_part=0,
                 shuffle=False,
                 seed=0):
        if ":" not in rspecifier:
            raise ValueError(
                f'Give "rspecifier" such as "scp:some.scp: {rspecifier}"')
        self.ark_or_scp, self.filepath = rspecifier.split(":", maxsplit=1)
        if self.ark_or_scp not in ["scp", "ark"]:
            raise ValueError(
                f"Filetype must be scp or ark, but got {self.ark_or_scp}")
        elif self.ark_or_scp == "scp":
            self._full_scp_dict = parse_scp_file(self.filepath)
        else:
            self._full_scp_dict = None
            if pre_fetch_next_epoch:
                raise ValueError(
                    f"Cannot pre-fetch next epoch if reading file directly. "
                    f"rspecifier={rspecifier}, pre_fetch_next_epoch=True. "
                    f"Change rspecifier to scp, or pre_fetch_next_epoch=False."
                )
            if shuffle:
                raise ValueError(
                    f"Cannot shuffle data if reading file directly. "
                    f"rspecifier={rspecifier}, shuffle=True. "
                    f"Change rspecifier to scp, or shuffle=False.")
            if max_len is not None:
                raise ValueError(
                    f"Cannot enforce dynamic batching if rreading file directly. "
                    f"rspecifier={rspecifier}, max_len={max_len}. "
                    f"Change rspecifier to scp, or max_len=None.")

        # For loading data from SCP
        self._n_parts = n_parts
        self._i_part = i_part
        self._shuffle = shuffle
        self._batch_size = batch_size
        self._max_len = max_len
        self._utt2len = utt2len
        self._ensure_equal_parts = ensure_equal_parts
        self._bszs = None

        # For determining the data format to return
        self.return_dict = return_dict
        self.return_shape = return_shape and not return_dict
        self.train = train
        self.transform = Transformation() if transform is None else transform

        # Set up an actor pool to apply the transform if needed
        if self.transform.is_null() or (num_workers is not None
                                        and num_workers < 1):
            self.num_workers, self.actor_pool = 0, None
        else:
            ncpu = os.cpu_count() or 1
            if num_workers is None:
                num_workers = max(1, math.ceil(ncpu / self._n_parts) - 1)
            self.num_workers = num_workers

            global _ray_refs
            _ray_refs += 1
            if not ray.is_initialized():
                ray.init(num_gpus=0,
                         include_dashboard=False,
                         ignore_reinit_error=True)
            actors = [
                TransformActor.remote(transform, train)
                for _ in range(num_workers)
            ]
            self.actor_pool = ray.util.ActorPool(actors)

        # For pre-fetching and caching hdf5/ark files in memory.
        self.pre_fetch_next_epoch = pre_fetch_next_epoch
        self.queue = FileQueue(max_size=data_cache_mb * (2**20),
                               read_file=self.get_file_dict,
                               get_file_size=self.file_dict_size)
        self.files_loaded = None
        self.thread_pool = ThreadPoolExecutor(max_workers=1)
        self.seed = seed

        # Make sure that the data loader is shut down in case of premature exits
        atexit.register(_weakref_close, weakref.ref(self))
        atexit.register(ray.shutdown)
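When num_workers is None and the transform is non-trivial, the reader defaults to max(1, ceil(cpu_count / n_parts) - 1) Ray actors. A worked example with assumed numbers, say a 16-core machine split into 4 parallel parts:

import math

ncpu, n_parts = 16, 4                    # assumed values for illustration
num_workers = max(1, math.ceil(ncpu / n_parts) - 1)
print(num_workers)                       # -> 3 transform actors for this part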
Example #7
class BaseReader(IterableDataset):
    """Uses a .scp file to read data from .h5 files. Pre-fetches .h5 files into
    dictionaries stored in memory for efficient access."""
    def __init__(self,
                 rspecifier,
                 return_shape=False,
                 return_dict=False,
                 transform: Transformation = None,
                 train=False,
                 batch_size: int = None,
                 max_len: int = None,
                 utt2len=None,
                 ensure_equal_parts=True,
                 pre_fetch_next_epoch=False,
                 data_cache_mb=2048,
                 num_workers: int = 1,
                 n_parts=1,
                 i_part=0,
                 shuffle=False,
                 seed=0):
        if ":" not in rspecifier:
            raise ValueError(
                f'Give "rspecifier" such as "scp:some.scp: {rspecifier}"')
        self.ark_or_scp, self.filepath = rspecifier.split(":", maxsplit=1)
        if self.ark_or_scp not in ["scp", "ark"]:
            raise ValueError(
                f"Filetype must be scp or ark, but got {self.ark_or_scp}")
        elif self.ark_or_scp == "scp":
            self._full_scp_dict = parse_scp_file(self.filepath)
        else:
            self._full_scp_dict = None
            if pre_fetch_next_epoch:
                raise ValueError(
                    f"Cannot pre-fetch next epoch if reading file directly. "
                    f"rspecifier={rspecifier}, pre_fetch_next_epoch=True. "
                    f"Change rspecifier to scp, or pre_fetch_next_epoch=False."
                )
            if shuffle:
                raise ValueError(
                    f"Cannot shuffle data if reading file directly. "
                    f"rspecifier={rspecifier}, shuffle=True. "
                    f"Change rspecifier to scp, or shuffle=False.")
            if max_len is not None:
                raise ValueError(
                    f"Cannot enforce dynamic batching if rreading file directly. "
                    f"rspecifier={rspecifier}, max_len={max_len}. "
                    f"Change rspecifier to scp, or max_len=None.")

        # For loading data from SCP
        self._n_parts = n_parts
        self._i_part = i_part
        self._shuffle = shuffle
        self._batch_size = batch_size
        self._max_len = max_len
        self._utt2len = utt2len
        self._ensure_equal_parts = ensure_equal_parts
        self._bszs = None

        # For determining the data format to return
        self.return_dict = return_dict
        self.return_shape = return_shape and not return_dict
        self.train = train
        self.transform = Transformation() if transform is None else transform

        # Set up an actor pool to apply the transform if needed
        if self.transform.is_null() or (num_workers is not None
                                        and num_workers < 1):
            self.num_workers, self.actor_pool = 0, None
        else:
            ncpu = os.cpu_count() or 1
            if num_workers is None:
                num_workers = max(1, math.ceil(ncpu / self._n_parts) - 1)
            self.num_workers = num_workers

            global _ray_refs
            _ray_refs += 1
            if not ray.is_initialized():
                ray.init(num_gpus=0,
                         include_dashboard=False,
                         ignore_reinit_error=True)
            actors = [
                TransformActor.remote(transform, train)
                for _ in range(num_workers)
            ]
            self.actor_pool = ray.util.ActorPool(actors)

        # For pre-fetching and caching hdf5/ark files in memory.
        self.pre_fetch_next_epoch = pre_fetch_next_epoch
        self.queue = FileQueue(max_size=data_cache_mb * (2**20),
                               read_file=self.get_file_dict,
                               get_file_size=self.file_dict_size)
        self.files_loaded = None
        self.thread_pool = ThreadPoolExecutor(max_workers=1)
        self.seed = seed

        # Make sure that the data loader is shut down in case of premature exits
        atexit.register(_weakref_close, weakref.ref(self))
        atexit.register(ray.shutdown)

    @abstractmethod
    def get_file_dict(self, path: str, uttid_locs: List[Tuple[str, str]] = None) \
            -> Dict[str, Dict[str, Any]]:
        """Gets a dict of items associated with the file at the given path.
        file_dict[uttid] contains all the data associated with uttid."""
        raise NotImplementedError

    def __del__(self):
        """Shut down the thread pool."""
        if self.actor_pool is not None:
            global _ray_refs
            _ray_refs -= 1
            if _ray_refs == 0:
                ray.shutdown()
        self.thread_pool.shutdown(wait=False)

    def __len__(self):
        if self.ark_or_scp == "ark":
            raise RuntimeError(
                "Cannot get length of Reader reading directly from ark/h5")
        return len(self._bszs)

    @property
    def num_utts(self):
        if self.ark_or_scp == "ark":
            raise RuntimeError(
                "Cannot get number of utterances in Reader reading directly from ark/h5"
            )
        return sum(self._bszs)

    def close(self):
        """Empties the queue and suspends loading any files not yet loaded."""
        self.queue.clear()  # this will stop any pending queue.put()
        files_future = self.files_loaded
        if files_future is not None and not files_future.cancelled():
            self.files_loaded.cancel()
            while not (files_future.done() or files_future.cancelled()):
                pass
        assert self.queue.empty()
        self.queue.open()
        self._fetch_seed = self._seed

    @property
    def seed(self):
        return self._seed

    @seed.setter
    def seed(self, seed):
        """Sets the random seed & cancels the current file-loading job if we are
        currently (or have already) pre-fetched data for the wrong seed."""
        self._seed = seed
        if not hasattr(self, "_fetch_seed") or (seed != self._fetch_seed
                                                and self.shuffle):
            self.close()
            self._fetch_seed = seed
            scp_dict, bszs = self.get_scp_dict_and_bszs()
            self._bszs = bszs
            if self.pre_fetch_next_epoch:
                self.load_files(scp_dict)
        else:
            self._fetch_seed = seed

    @property
    def shuffle(self):
        return self._shuffle

    def get_scp_dict_and_bszs(self):
        """Shuffles the SCP dict and gets the relevant split to load from."""
        batch_size = 1 if self._batch_size is None else self._batch_size
        scp_dict, bszs = split_scp_dict(
            self._full_scp_dict,
            n_parts=self._n_parts,
            randomize=self.shuffle,
            seed=self._fetch_seed,
            batch_size=batch_size,
            max_len=self._max_len,
            utt2len=self._utt2len,
            ensure_equal_parts=self._ensure_equal_parts)[self._i_part]
        if self._bszs is None:
            self._bszs = bszs
        return scp_dict, bszs

    def load_files(self, scp_dict=None):
        """Schedules a background thread to read the relevant file dicts from
        either the SCP dict given, or the one from self.get_scp_dict_and_bszs().
        The thread finishes as soon as it fails to load a file (which happens
        if queue.clear() is called) or once it's done loading all files."""
        def wrapper(d):
            for path, uttid_locs in d.items():
                if not self.queue.put(path, uttid_locs):
                    break

        if scp_dict is None:
            scp_dict, _ = self.get_scp_dict_and_bszs()
        self.files_loaded = self.thread_pool.submit(wrapper, scp_dict)

    @staticmethod
    def file_dict_size(file_dict):
        """file_dict[utt_id] is a dict containing a numpy array (under key
        "x") and other metadata. We primarily care about the size of the
        numpy arrays in the file_dict, so this is the size we return."""
        return sum(map(lambda utt: utt["x"].nbytes, file_dict.values()))

    def datadict_to_output(self, datadict):
        """Converts the dict retrieved from the archive into the desired form."""
        if self.return_dict:
            return datadict
        if self.return_shape:
            return datadict["x"].shape
        return datadict["x"]

    def output_iterator(self, file_dict):
        """Output iterator which applies self.transform to all the signals in
        file_dict, and returns them (in desired format) along w/ their keys."""
        if self.actor_pool is not None:
            it = self.actor_pool.map(lambda a, v: a.apply.remote(v),
                                     file_dict.items())
        else:
            it = map(
                lambda kd: apply_transform(self.transform, kd, self.train),
                file_dict.items())
        return zip(file_dict.keys(), map(self.datadict_to_output, it))

    def __iter__(self):
        """Loads the data one .h5 archive at a time. Uses a queue to pre-fetch
        the contents of each file into memory (while not exceeding our maximum
        allowed cache size). Yields data points one at a time."""
        if self.ark_or_scp == "scp":
            # Fetch data for this epoch in the background (if not pre-fetched)
            scp_dict, bszs = self.get_scp_dict_and_bszs()
            self._bszs = bszs
            if (self.files_loaded is None or self.files_loaded.done()
                    or self.files_loaded.cancelled()) and self.queue.empty():
                self.load_files(scp_dict)

            i_batch, batch = 0, []
            for j_file, expected_path in enumerate(scp_dict):
                n_from_file = 0
                path, file_dict = self.queue.get(expected_path)

                # Pre-fetch the next epoch as soon as we're done with this one
                if self.pre_fetch_next_epoch and self._fetch_seed == self._seed \
                        and self.files_loaded.done():
                    self._fetch_seed += 1
                    self.load_files()

                output_iterator = self.output_iterator(file_dict)
                if self._batch_size is None:
                    yield from output_iterator
                else:
                    try:
                        for i_batch in range(i_batch, len(bszs)):
                            while len(batch) < bszs[i_batch]:
                                batch.append(next(output_iterator))
                                n_from_file = n_from_file + 1
                            if i_batch == len(bszs) - 1 and j_file == len(
                                    scp_dict) - 1:
                                raise StopIteration
                            else:
                                assert len(batch) == bszs[i_batch], \
                                    f"Expected batch {i_batch} to have size {bszs[i_batch]}, " \
                                    f"but got a batch of size {len(batch)} instead"
                                yield batch
                                batch = []

                    except StopIteration:
                        assert n_from_file == len(file_dict), \
                            f"Expected to get {len(file_dict)} utts from " \
                            f"{os.path.basename(path)}, but got {n_from_file}."
                        logger.debug(
                            f"FINISHED FILE {j_file+1}/{len(scp_dict)}: "
                            f"{os.path.basename(path)}")
                        if i_batch == len(bszs) - 1 and j_file == len(
                                scp_dict) - 1:
                            assert len(batch) == bszs[i_batch], \
                                f"Expected batch {i_batch} to have size {bszs[i_batch]}, " \
                                f"but got a batch of size {len(batch)} instead"
                            yield batch
                        continue

        else:  # self.ark_or_scp == "ark"
            if self.filepath == "-":  # Reading hdf5 from stdin requires h5py>=2.9
                filepath = io.BytesIO(sys.stdin.buffer.read())
            else:
                filepath = self.filepath
            file_dict = self.get_file_dict(filepath)
            output_iterator = self.output_iterator(file_dict)

            if self._batch_size is None:
                yield from output_iterator
            else:
                done = False
                while not done:
                    batch = []
                    try:
                        while len(batch) < self._batch_size:
                            batch.append(next(output_iterator))
                        yield batch
                    except StopIteration:
                        done = True
                        yield batch
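BaseReader is abstract (get_file_dict must be supplied by a subclass), so the sketch below assumes some hypothetical HDF5Reader subclass exists; the iteration contract itself follows from __iter__ above: with no batch_size each item is a single (utt_id, output) pair, and with a batch_size the reader yields lists of such pairs whose sizes come from the pre-computed _bszs.

# Hypothetical subclass name; only the constructor arguments and the iteration
# behaviour are taken from the example.
reader = HDF5Reader("scp:dump/train/feats.scp",
                    return_dict=True, batch_size=16,
                    shuffle=True, pre_fetch_next_epoch=True)
for epoch in range(2):
    reader.seed = epoch                  # re-shuffles and pre-fetches this epoch
    for batch in reader:                 # a list of (utt_id, data_dict) pairs
        ...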
Example #8
def main():
    args = parse_args()

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO if args.verbose else logging.WARN,
                        format=logfmt)
    logger.info(get_commandline_args())

    if args.cmvn_type == "speaker":
        logger.info("Performing as speaker CMVN mode")
        utt2spk_dict = {}
        with open(args.spk2utt) as f:
            for line in f:
                spk, utts = line.rstrip().split(None, maxsplit=1)
                for utt in utts.split():
                    utt2spk_dict[utt] = spk

        def utt2spk(x):
            return utt2spk_dict[x]

    else:
        logger.info(f"Performing as {args.cmvn_type} CMVN mode")
        if args.spk2utt is not None:
            logger.warning(
                f"spk2utt is not used for {args.cmvn_type} CMVN mode")

        if args.cmvn_type == "utterance":

            def utt2spk(x):
                return x

        else:  # args.cmvn_type == "global"

            def utt2spk(x):
                return None

    if args.preprocess_conf is not None:
        preprocessing = Transformation(args.preprocess_conf)
        logger.info("Apply preprocessing: {}".format(preprocessing))
    else:
        preprocessing = None

    # Calculate stats for each "speaker"
    cmvn_stats, n = {}, 0
    for utt, matrix in tqdm.tqdm(
            file_reader_helper(args.rspecifier,
                               args.filetype,
                               transform=preprocessing)):
        # Initialize stats the first time a speaker is seen; accumulate otherwise
        spk = utt2spk(utt)
        spk_stats = CMVNStats(count=matrix.shape[0],
                              sum=matrix.sum(axis=0),
                              sum_squares=(matrix**2).sum(axis=0))
        if spk not in cmvn_stats:
            cmvn_stats[spk] = spk_stats
        else:
            cmvn_stats[spk] += spk_stats
        n += 1
    logger.info(f"Processed {n} utterances")
    assert n > 0, n

    write_cmvn_stats(args.wfilename, args.cmvn_type, cmvn_stats)
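Example #8 only accumulates per-"speaker" zeroth/first/second-order statistics (count, sum, sum of squares); turning them into the mean and variance that CMVN actually applies is a small final step. A minimal sketch with numpy, using made-up accumulated values:

import numpy as np

# Hypothetical accumulated statistics for one speaker.
count = 100
feat_sum = np.array([50.0, 200.0])
sum_squares = np.array([30.0, 500.0])

mean = feat_sum / count
var = sum_squares / count - mean ** 2    # E[x^2] - E[x]^2
print(mean, var)                         # means [0.5, 2.0]; variances [0.05, 1.0]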