Exemplo n.º 1
0
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray]
    ) -> List[Tuple[Optional[str], List[str], List[int], Union[
            Hypothesis, ExtTransHypothesis, TransHypothesis], ]]:
        """Decode one utterance and return its n-best hypotheses.

        Args:
            speech: Input speech signal. A batch axis is prepended here, so a
                single un-batched signal is expected.
        Returns:
            A list holding at most ``self.nbest`` tuples of
            ``(text, token, token_int, hyp)``. ``text`` is ``None`` when no
            tokenizer is configured.

        """
        assert check_argument_types()

        # Accept raw numpy audio as well as tensors
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        # data: (Nsamples,) -> (1, Nsamples), cast to the configured dtype
        target_dtype = getattr(torch, self.dtype)
        speech = speech.unsqueeze(0).to(target_dtype)
        # lengths: (1,)
        lengths = speech.new_full(
            [1], dtype=torch.long, fill_value=speech.size(1))

        # a. To device
        batch = to_device(
            {"speech": speech, "speech_lengths": lengths},
            device=self.device,
        )

        # b. Forward Encoder (some encoders return a tuple; keep its head)
        enc, _ = self.asr_model.encode(**batch)
        if isinstance(enc, tuple):
            enc = enc[0]
        assert len(enc) == 1, len(enc)
        encoded = enc[0]

        # c. Hand the encoder output to whichever beam search is configured
        if self.beam_search_transducer:
            nbest_hyps = self.beam_search_transducer(encoded)
        else:
            nbest_hyps = self.beam_search(
                x=encoded,
                maxlenratio=self.maxlenratio,
                minlenratio=self.minlenratio,
            )

        results = []
        for hyp in nbest_hyps[:self.nbest]:
            assert isinstance(hyp, (Hypothesis, TransHypothesis)), type(hyp)

            # Drop the leading symbol always; the trailing one is dropped
            # only when the model does not use a transducer decoder.
            last_pos = None if self.asr_model.use_transducer_decoder else -1
            trimmed = hyp.yseq[1:last_pos]
            if not isinstance(hyp.yseq, list):
                trimmed = trimmed.tolist()

            # remove blank symbol id, which is assumed to be 0
            token_int = [tok_id for tok_id in trimmed if tok_id != 0]

            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)

            text = (self.tokenizer.tokens2text(token)
                    if self.tokenizer is not None else None)
            results.append((text, token, token_int, hyp))

        assert check_return_type(results)
        return results
Exemplo n.º 2
0
    def train_one_epoch_curriculum(
        cls,
        model: torch.nn.Module,
        iterator: CurriculumIterFactory,
        tasks: List,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        reporter: SubReporter,
        curriculum_generator: AbsCurriculumGenerator,
        summary_writer: Optional[SummaryWriter],
        options: TrainerOptions,
        distributed_option: DistributedOption,
        iepoch: int
    ) -> bool:
        """Run one curriculum-learning epoch.

        On every iteration the curriculum generator picks a task index, one
        batch is drawn from that task, the loss is measured before and after
        a training step, and the two losses are fed back to the generator.

        ``options.gain_type`` selects which batch the two losses use:

        * ``'PG'``: the batch that is trained on.
        * ``'SPG'``: a second batch drawn from the same task.

        Args:
            model: Model to train.
            iterator: Factory providing ``num_iters_per_epoch`` and
                ``refill_task(k)``.
            tasks: One iterable per task; each is wrapped with ``iter()``.
            optimizers: Forwarded to ``cls.train_one_batch``.
            schedulers: Forwarded to ``cls.train_one_batch``.
            scaler: Forwarded to the loss/training helpers.
            reporter: Receives per-step statistics and produces log messages.
            curriculum_generator: Chooses tasks and receives policy updates.
            summary_writer: Optional tensorboard writer for periodic logging.
            options: Trainer options (gain type, refill policy, logging, ...).
            distributed_option: Only ``distributed`` is read here.
            iepoch: Index of the current epoch.

        Returns:
            The tuple ``(all_steps_are_invalid, iterator, tasks)`` where
            ``tasks`` is the list of (possibly refilled) task iterators.
            NOTE: the ``-> bool`` annotation does not match this tuple; it is
            kept as-is so no extra typing name is needed at definition time.

        Raises:
            ValueError: If ``options.gain_type`` is neither ``'PG'`` nor
                ``'SPG'`` (previously this surfaced as an unbound-variable
                error on ``loss1``).
        """
        assert check_argument_types()

        grad_noise = options.grad_noise
        accum_grad = options.accum_grad
        grad_clip = options.grad_clip
        grad_clip_type = options.grad_clip_type
        log_interval = options.log_interval
        no_forward_run = options.no_forward_run
        ngpu = options.ngpu
        use_wandb = options.use_wandb
        distributed = distributed_option.distributed
        device = "cuda" if ngpu > 0 else "cpu"

        if log_interval is None:
            try:
                log_interval = max(len(iterator) // 20, 10)
            except TypeError:
                log_interval = 100

        all_steps_are_invalid = True
        # [For distributed] Because iteration counts are not always equals between
        # processes, send stop-flag to the other processes if iterator is finished
        iterator_stop = torch.tensor(0).to(device)

        start_time = time.perf_counter()
        tasks = [iter(it) for it in tasks]

        # Sentinel returned by fetch_batch() when a task has run dry and
        # refilling is disabled.
        exhausted = object()

        def fetch_batch(task_ind):
            """Return the next batch of a task, refilling it if allowed."""
            try:
                # Built-in next() works for every iterator, whereas the
                # ``.next()`` method only exists on some iterator classes.
                _, fetched = next(tasks[task_ind])
            except StopIteration:
                if not options.refill_task:
                    return exhausted
                logging.info(f"Refilled task {task_ind}.")
                tasks[task_ind] = iter(iterator.refill_task(task_ind))
                _, fetched = next(tasks[task_ind])
            return fetched

        def mark_exhausted(task_ind):
            """Report an exhausted task; return True if none are left."""
            curriculum_generator.report_exhausted_task(task_ind)
            logging.info(f"Task {task_ind} is exhausted.")
            return curriculum_generator.all_exhausted()

        iiter = 0
        # Reset the exhausted tasks list
        curriculum_generator.reset_exhausted()

        while iiter < iterator.num_iters_per_epoch:
            iiter += 1

            k = curriculum_generator.get_next_task_ind(iiter=iiter, iepoch=iepoch)

            batch = fetch_batch(k)
            if batch is exhausted:
                if mark_exhausted(k):
                    break
                # This iteration produced no batch: do not count it.
                iiter -= 1
                continue

            assert isinstance(batch, dict), type(batch)
            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            if no_forward_run:
                all_steps_are_invalid = False
                continue

            if options.gain_type == 'PG':
                batch = to_device(batch, device)
                # Calculate loss before training on the batch
                loss1 = cls.get_loss_eval_mode(
                    batch,
                    model,
                    scaler,
                    ngpu,
                    distributed,
                    reporter,
                    iiter,
                    accum_grad,
                )

                # NOTE(review): the flag is overwritten on every step, so it
                # reflects only the most recent train_one_batch() result -
                # confirm this is intended.
                all_steps_are_invalid = cls.train_one_batch(
                    batch,
                    model,
                    scaler,
                    ngpu,
                    distributed,
                    reporter,
                    iiter,
                    accum_grad,
                    grad_noise,
                    grad_clip,
                    grad_clip_type,
                    optimizers,
                    schedulers,
                    start_time,
                )
                loss2 = cls.get_loss_eval_mode(
                    batch,
                    model,
                    scaler,
                    ngpu,
                    distributed,
                    reporter,
                    iiter,
                    accum_grad,
                )
            elif options.gain_type == 'SPG':
                # Sample second batch for evaluation
                batch_eval = fetch_batch(k)
                if batch_eval is exhausted:
                    # Without an evaluation batch the gain cannot be
                    # measured; treat it like an exhausted task instead of
                    # using an unbound or stale ``batch_eval``.
                    if mark_exhausted(k):
                        break
                    iiter -= 1
                    continue

                batch_eval_gpu = to_device(batch_eval, device)
                loss1 = cls.get_loss_eval_mode(
                    batch_eval_gpu,
                    model,
                    scaler,
                    ngpu,
                    distributed,
                    reporter,
                    iiter,
                    accum_grad,
                )
                # Drop the device copy before the training step.
                del batch_eval_gpu

                batch = to_device(batch, device)
                all_steps_are_invalid = cls.train_one_batch(
                    batch,
                    model,
                    scaler,
                    ngpu,
                    distributed,
                    reporter,
                    iiter,
                    accum_grad,
                    grad_noise,
                    grad_clip,
                    grad_clip_type,
                    optimizers,
                    schedulers,
                    start_time,
                )

                batch_eval = to_device(batch_eval, device)
                loss2 = cls.get_loss_eval_mode(
                    batch_eval,
                    model,
                    scaler,
                    ngpu,
                    distributed,
                    reporter,
                    iiter,
                    accum_grad,
                )
            else:
                raise ValueError(
                    f"Unsupported gain_type: {options.gain_type!r} "
                    "(expected 'PG' or 'SPG')")

            loss_before = loss1.item()
            loss_after = loss2.item()
            # Skip the policy update when either loss is infinite.
            if not (np.isinf(loss_before) or np.isinf(loss_after)):
                curriculum_generator.update_policy(
                    iepoch=iepoch,
                    iiter=iiter,
                    num_iters=iterator.num_iters_per_epoch,
                    k=k,
                    losses=(loss_before, loss_after),
                    batch_lens=batch['speech_lengths'].detach().cpu().numpy(),
                    algo=options.curriculum_algo
                )

            start_time = time.perf_counter()

            # NOTE(kamo): Call log_message() after next()
            reporter.next()
            if iiter % log_interval == 0:
                logging.info(reporter.log_message(-log_interval))
                if summary_writer is not None:
                    reporter.tensorboard_add_scalar(summary_writer, -log_interval)
                if use_wandb:
                    reporter.wandb_log()

            torch.cuda.empty_cache()

        else:
            # Loop ended without ``break``: tell the other processes to stop.
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
        logging.info(f"Finished epoch {iepoch}")
        return all_steps_are_invalid, iterator, tasks
Exemplo n.º 3
0
def collect_stats(
    model: AbsESPnetModel,
    train_iter: DataLoader
    and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    valid_iter: DataLoader
    and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    output_dir: Path,
    ngpu: Optional[int],
    log_interval: Optional[int],
    write_collected_feats: bool,
) -> None:
    """Dump shape files and feature statistics for the train/valid sets.

    For each of the two iterators this writes, under ``output_dir/<mode>``:

    * ``<name>_shape`` entries for every non-length batch field,
    * ``<key>_stats.npz`` holding count, sum and square sum of every output
      of ``model.collect_feats``,
    * ``batch_keys`` and ``stats_keys`` listing the names involved,
    * optionally the collected features themselves as npy files.

    """
    assert check_argument_types()

    npy_scp_writers = {}
    for mode, itr in (("train", train_iter), ("valid", valid_iter)):
        # Once derived, the interval is reused for the second iterator too.
        if log_interval is None:
            try:
                log_interval = max(len(itr) // 20, 10)
            except TypeError:
                log_interval = 100

        sum_dict = defaultdict(int)
        sq_dict = defaultdict(int)
        count_dict = defaultdict(int)

        mode_dir = output_dir / mode
        with DatadirWriter(mode_dir) as datadir_writer:
            for iiter, (keys, batch) in enumerate(itr, 1):
                batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")

                # 1. Write shape file
                for name in batch:
                    if name.endswith("_lengths"):
                        continue
                    lengths_name = f"{name}_lengths"
                    has_lengths = lengths_name in batch
                    for i, (utt_key, sample) in enumerate(
                            zip(keys, batch[name])):
                        if has_lengths:
                            # Strip the padded tail before reading the shape
                            sample = sample[:int(batch[lengths_name][i])]
                        datadir_writer[f"{name}_shape"][utt_key] = ",".join(
                            str(dim) for dim in sample.shape)

                # 2. Extract feats
                if ngpu <= 1:
                    feats = model.collect_feats(**batch)
                else:
                    # Note that data_parallel can parallelize only "forward()"
                    feats = data_parallel(
                        ForwardAdaptor(model, "collect_feats"),
                        (),
                        range(ngpu),
                        module_kwargs=batch,
                    )

                # 3. Calculate sum and square sum
                for feat_name, feat in feats.items():
                    feat_lengths_name = f"{feat_name}_lengths"
                    for i, (uttid, seq) in enumerate(
                            zip(keys, feat.cpu().numpy())):
                        if feat_lengths_name in feats:
                            # Truncate zero-padding region
                            # seq: (Length, Dim, ...)
                            seq = seq[:feats[feat_lengths_name][i]]
                        else:
                            # seq: (Dim, ...) -> (1, Dim, ...)
                            seq = seq[None]
                        # Accumulate value, its square, and count
                        sum_dict[feat_name] += seq.sum(0)
                        sq_dict[feat_name] += (seq**2).sum(0)
                        count_dict[feat_name] += len(seq)

                        # 4. [Option] Write derived features as npy format file.
                        if not write_collected_feats:
                            continue
                        writer_id = (feat_name, mode)
                        # Create the writer lazily on first use
                        if writer_id not in npy_scp_writers:
                            p = output_dir / mode / "collect_feats"
                            npy_scp_writers[writer_id] = NpyScpWriter(
                                p / f"data_{feat_name}",
                                p / f"{feat_name}.scp")
                        # Save array as npy file
                        npy_scp_writers[writer_id][uttid] = seq

                if iiter % log_interval == 0:
                    logging.info(f"Niter: {iiter}")

        for key in sum_dict:
            np.savez(
                output_dir / mode / f"{key}_stats.npz",
                count=count_dict[key],
                sum=sum_dict[key],
                sum_square=sq_dict[key],
            )

        # batch_keys and stats_keys are used by aggregate_stats_dirs.py
        batch_key_names = (
            name for name in batch if not name.endswith("_lengths"))
        with (output_dir / mode / "batch_keys").open(
                "w", encoding="utf-8") as f:
            f.write("\n".join(batch_key_names) + "\n")
        with (output_dir / mode / "stats_keys").open(
                "w", encoding="utf-8") as f:
            f.write("\n".join(sum_dict) + "\n")
Exemplo n.º 4
0
    def apply_frontend(self,
                       speech: torch.Tensor,
                       prev_states=None,
                       is_final: bool = False):
        """Extract features from one chunk of a streamed waveform.

        Args:
            speech: New waveform samples along dim 0.
            prev_states: State returned by the previous call (a dict with
                key ``"waveform_buffer"``), or ``None`` for the first chunk.
            is_final: Whether this is the last chunk of the stream.

        Returns:
            ``(feats, feats_lengths, next_states)``. ``feats`` and
            ``feats_lengths`` are ``None`` when too few samples are buffered
            and the stream is not final. ``next_states`` is ``None`` once
            ``is_final`` is set, otherwise a dict holding the samples to
            prepend to the next chunk.
        """
        if prev_states is not None:
            buf = prev_states["waveform_buffer"]
            speech = torch.cat([buf, speech], dim=0)

        if speech.size(0) <= self.win_length:
            if not is_final:
                # Not enough samples yet: keep buffering.
                return None, None, {"waveform_buffer": speech.clone()}
            # Last chunk: zero-pad up to one full window. The padding is
            # created on the same device as ``speech`` so torch.cat also
            # works for non-CPU input.
            pad = torch.zeros(self.win_length - speech.size(0),
                              dtype=speech.dtype,
                              device=speech.device)
            speech = torch.cat([speech, pad], dim=0)

        # Samples shared between two consecutive analysis windows.
        overlap = self.win_length - self.hop_length
        if is_final:
            speech_to_process = speech
            waveform_buffer = None
        else:
            n_frames, n_residual = divmod(
                speech.size(0) - overlap, self.hop_length)
            # Process only whole hops; carry the overlap plus the leftover
            # samples over to the next call.
            speech_to_process = speech.narrow(
                0, 0, overlap + n_frames * self.hop_length)
            waveform_buffer = speech.narrow(
                0,
                speech.size(0) - overlap - n_residual,
                overlap + n_residual,
            ).clone()

        # data: (Nsamples,) -> (1, Nsamples)
        speech_to_process = speech_to_process.unsqueeze(0).to(
            getattr(torch, self.dtype))
        # lengths: (1,)
        lengths = speech_to_process.new_full(
            [1], dtype=torch.long, fill_value=speech_to_process.size(1))
        batch = {"speech": speech_to_process, "speech_lengths": lengths}

        # a. To device
        batch = to_device(batch, device=self.device)

        feats, feats_lengths = self.asr_model._extract_feats(**batch)
        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(
                feats, feats_lengths)

        # Trimming: drop ``n_trim`` frames at the head when an earlier chunk
        # exists and at the tail when more chunks will follow.
        # NOTE(review): presumably these are the frames affected by the
        # chunk boundary - confirm against the frontend's framing.
        n_trim = math.ceil(math.ceil(self.win_length / self.hop_length) / 2)
        head_trim = n_trim if prev_states is not None else 0
        tail_trim = 0 if is_final else n_trim
        if head_trim or tail_trim:
            feats = feats.narrow(
                1, head_trim, feats.size(1) - head_trim - tail_trim)

        feats_lengths = feats.new_full([1],
                                       dtype=torch.long,
                                       fill_value=feats.size(1))

        next_states = (None if is_final else
                       {"waveform_buffer": waveform_buffer})
        return feats, feats_lengths, next_states
Exemplo n.º 5
0
    def plot_attention(
        cls,
        model: torch.nn.Module,
        output_dir: Optional[Path],
        summary_writer: Optional[SummaryWriter],
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        reporter: SubReporter,
        options: TrainerOptions,
    ) -> None:
        """Render the model's attention weights as images.

        For every batch, all attentions are computed with
        ``calculate_all_attentions`` and each one is drawn as a figure that
        is saved to ``output_dir/<id>/<name>.<epoch>ep.png`` and/or added to
        ``summary_writer``, depending on which of the two is given.

        Raises:
            RuntimeError: If an attention array is 1-dimensional or has more
                than 3 dimensions.
        """
        assert check_argument_types()
        import matplotlib

        # Select the non-interactive backend before pyplot is imported.
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        from matplotlib.ticker import MaxNLocator

        no_forward_run = options.no_forward_run
        device = "cuda" if options.ngpu > 0 else "cpu"

        model.eval()
        for ids, batch in iterator:
            assert isinstance(batch, dict), type(batch)
            assert len(next(iter(batch.values()))) == len(ids), (
                len(next(iter(batch.values()))),
                len(ids),
            )
            batch = to_device(batch, device)
            if no_forward_run:
                continue

            # 1. Forwarding model and gathering all attentions
            #    calculate_all_attentions() uses single gpu only.
            att_dict = calculate_all_attentions(model, batch)

            # 2. Plot attentions: This part is slow due to matplotlib
            for name, attentions in att_dict.items():
                assert len(attentions) == len(ids), (len(attentions),
                                                     len(ids))
                for utt_id, att in zip(ids, attentions):

                    if isinstance(att, torch.Tensor):
                        att = att.detach().cpu().numpy()

                    # A 2-D array becomes a stack holding a single panel.
                    if att.ndim == 2:
                        att = att[None]
                    elif att.ndim == 1 or att.ndim > 3:
                        raise RuntimeError(
                            f"Must be 2 or 3 dimension: {att.ndim}")

                    n_panels = len(att)
                    w, h = plt.figaspect(1.0 / n_panels)
                    fig = plt.Figure(figsize=(w * 1.3, h * 1.3))
                    axes = fig.subplots(1, n_panels)
                    if n_panels == 1:
                        # subplots() gives back a bare Axes for one panel
                        axes = [axes]

                    for ax, panel in zip(axes, att):
                        ax.imshow(panel.astype(np.float32), aspect="auto")
                        ax.set_title(f"{name}_{utt_id}")
                        ax.set_xlabel("Input")
                        ax.set_ylabel("Output")
                        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                        ax.yaxis.set_major_locator(MaxNLocator(integer=True))

                    if output_dir is not None:
                        png_path = (output_dir / utt_id /
                                    f"{name}.{reporter.get_epoch()}ep.png")
                        png_path.parent.mkdir(parents=True, exist_ok=True)
                        fig.savefig(png_path)

                    if summary_writer is not None:
                        summary_writer.add_figure(f"{name}_{utt_id}", fig,
                                                  reporter.get_epoch())
            reporter.next()
Exemplo n.º 6
0
    def train_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        reporter: SubReporter,
        options: TrainerOptions,
    ) -> bool:
        """Train the model for one epoch with a single optimizer.

        Args:
            model: Model to train; called as ``model(**batch)`` and expected
                to return ``(loss, stats, weight)``.
            iterator: Yields ``(ids, batch)`` pairs; only ``batch`` is used.
            optimizers: Must contain exactly one optimizer.
            schedulers: The first entry is stepped after every optimizer
                step when it is an ``AbsBatchStepScheduler``.
            reporter: Collects statistics and timing information.
            options: Trainer options (gradient accumulation, clipping, ...).

        Returns:
            ``True`` if no optimizer step was applied during the epoch (all
            gradient norms were non-finite), ``False`` otherwise or when
            ``options.no_forward_run`` is set.

        Raises:
            ImportError: If an apex ``train_dtype`` is requested but apex is
                not installed.
        """
        assert check_argument_types()

        # Note(kamo): assumes one optimizer
        assert cls.num_optimizers == 1, cls.num_optimizers
        assert len(optimizers) == 1, len(optimizers)
        optimizer = optimizers[0]
        scheduler = schedulers[0]

        grad_noise = options.grad_noise
        accum_grad = options.accum_grad
        grad_clip = options.grad_clip
        log_interval = options.log_interval
        no_forward_run = options.no_forward_run
        ngpu = options.ngpu
        distributed = isinstance(model,
                                 torch.nn.parallel.DistributedDataParallel)
        use_apex = options.train_dtype in ("O0", "O1", "O2", "O3")
        device = "cuda" if ngpu > 0 else "cpu"

        if log_interval is None:
            try:
                log_interval = max(len(iterator) // 20, 10)
            except TypeError:
                log_interval = 100

        model.train()
        all_steps_are_invalid = True
        # [For distributed] Because iteration counts are not always equals between
        # processes, send stop-flag to the other processes if iterator is finished
        iterator_stop = torch.tensor(0).to(device)

        start_time = time.perf_counter()
        for iiter, (_, batch) in enumerate(
                reporter.measure_iter_time(iterator, "iter_time"), 1):
            assert isinstance(batch, dict), type(batch)

            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, device)
            if no_forward_run:
                all_steps_are_invalid = False
                reporter.register({})
                continue

            with reporter.measure_time("forward_time"):
                loss, stats, weight = model(**batch)
            if ngpu > 1 or distributed:
                # Apply weighted averaging for loss and stats
                loss = (loss * weight.type(loss.dtype)).sum()

                # if distributed, this method can also apply all_reduce()
                stats, weight = recursive_average(stats, weight, distributed)

                # Now weight is summation over all workers
                loss /= weight
            if distributed:
                # NOTE(kamo): Multiply world_size because DistributedDataParallel
                # automatically normalizes the gradient by world_size.
                loss *= torch.distributed.get_world_size()

            reporter.register(stats, weight)

            loss /= accum_grad
            with reporter.measure_time("backward_time"):
                if use_apex:
                    try:
                        from apex import amp
                    except ImportError:
                        logging.error(
                            "You need to install apex. "
                            "See https://github.com/NVIDIA/apex#linux")
                        # Without apex, ``amp`` is unbound below; fail with
                        # the real cause instead of an UnboundLocalError.
                        raise

                    with amp.scale_loss(loss, optimizers) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

            if iiter % accum_grad == 0:
                # gradient noise injection
                if grad_noise:
                    add_gradient_noise(
                        model,
                        reporter.get_total_count(),
                        duration=100,
                        eta=1.0,
                        scale_factor=0.55,
                    )

                # compute the gradient norm to check if it is normal or not
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), grad_clip)

                # clip_grad_norm_ returns a float on old torch releases and
                # a tensor (possibly on CUDA) on newer ones. np.isfinite
                # cannot consume a CUDA tensor, so check with torch.
                if not torch.isfinite(torch.as_tensor(grad_norm)):
                    logging.warning(
                        f"The grad norm is {grad_norm}. Skipping updating the model."
                    )
                else:
                    all_steps_are_invalid = False
                    with reporter.measure_time("optim_step_time"):
                        optimizer.step()
                    if isinstance(scheduler, AbsBatchStepScheduler):
                        scheduler.step()
                optimizer.zero_grad()

                # Register lr and train/load time[sec/step],
                # where step refers to accum_grad * mini-batch
                reporter.register(
                    dict(
                        {
                            f"lr_{i}": pg["lr"]
                            for i, pg in enumerate(optimizer.param_groups)
                            if "lr" in pg
                        },
                        train_time=time.perf_counter() - start_time,
                    ),
                    # Suppress to increment the internal counter.
                    not_increment_count=True,
                )
                start_time = time.perf_counter()

            if iiter % log_interval == 0:
                logging.info(reporter.log_message())

        else:
            # Iterator ran out without ``break``: tell other processes to stop.
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)

        return all_steps_are_invalid
Exemplo n.º 7
0
def calc_perplexity(
    output_dir: str,
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    log_base: Optional[float],
    allow_variable_data_keys: bool,
):
    """Evaluate a language model and write perplexity statistics.

    Per-utterance perplexities and token counts are written through a
    ``DatadirWriter`` opened on ``output_dir`` (entries ``utt2ppl`` and
    ``utt2ntokens``).  The corpus-level perplexity and the logarithm base
    are written to ``output_dir/ppl`` and ``output_dir/base``.

    Args:
        output_dir: Directory receiving all outputs.
        batch_size: Mini-batch size handed to the streaming iterator.
        dtype: Name of a torch dtype, resolved with ``getattr(torch, dtype)``.
        ngpu: ``>= 1`` selects the "cuda" device; ``> 1`` additionally
            routes the forward pass through ``data_parallel``.
        seed: Value passed to ``set_all_random_seed``.
        num_workers: Worker count handed to the streaming iterator.
        log_level: Level passed to ``logging.basicConfig``.
        data_path_and_name_and_type: Data specification handed to the
            streaming iterator.
        key_file: Optional key file handed to the streaming iterator.
        train_config: Training config used to rebuild the model.
        model_file: Model parameter file used to rebuild the model.
        log_base: Base in which perplexity is reported; ``None`` means
            the natural base ``e``.
        allow_variable_data_keys: Forwarded to the streaming iterator.
    """
    assert check_argument_types()
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    device = "cuda" if ngpu >= 1 else "cpu"

    def _perplexity(nll_sum, ntokens):
        # Perplexity of ``nll_sum`` (natural-log NLL) over ``ntokens`` tokens,
        # expressed in ``log_base`` when one is given.
        if log_base is None:
            return np.exp(nll_sum / ntokens)
        return log_base ** (nll_sum / ntokens / np.log(log_base))

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build LM
    model, train_args = LMTask.build_model_from_file(train_config, model_file, device)
    # Wrap the model so that model.nll() can be run data-parallel
    wrapped_model = ForwardAdaptor(model, "nll")
    wrapped_model.to(dtype=getattr(torch, dtype)).eval()
    logging.info(f"Model:\n{model}")

    # 3. Build data-iterator
    loader = LMTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=LMTask.build_preprocess_fn(train_args, False),
        collate_fn=LMTask.build_collate_fn(train_args, False),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 4. Start for-loop
    with DatadirWriter(output_dir) as writer:
        total_nll = 0.0
        total_ntokens = 0
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            n_utts = len(next(iter(batch.values())))
            assert len(keys) == n_utts, f"{len(keys)} != {n_utts}"

            with torch.no_grad():
                batch = to_device(batch, device)
                if ngpu > 1:
                    nll, lengths = data_parallel(
                        wrapped_model, (), range(ngpu), module_kwargs=batch
                    )
                else:
                    # NOTE(kamo): data_parallel also should work with ngpu=1,
                    # but for debuggability it's better to keep this block.
                    nll, lengths = wrapped_model(**batch)

            assert n_utts == len(nll) == len(lengths), (
                n_utts,
                len(nll),
                len(lengths),
            )
            # nll: (B, L) -> (B,)
            utt_nll = nll.detach().cpu().numpy().sum(1)
            # lengths: (B,)
            utt_ntokens = lengths.detach().cpu().numpy()
            total_nll += utt_nll.sum()
            total_ntokens += utt_ntokens.sum()

            # Write PPL of each utt for debugging or analysis
            for key, one_nll, ntoken in zip(keys, utt_nll, utt_ntokens):
                utt_ppl = _perplexity(one_nll, ntoken)
                writer["utt2ppl"][key] = str(utt_ppl)
                writer["utt2ntokens"][key] = str(ntoken)

        ppl = _perplexity(total_nll, total_ntokens)
        reported_base = np.e if log_base is None else log_base

        out_root = Path(output_dir)
        with (out_root / "ppl").open("w", encoding="utf-8") as f:
            f.write(f"{ppl}\n")
        with (out_root / "base").open("w", encoding="utf-8") as f:
            f.write(f"{reported_base}\n")
        logging.info(f"PPL={ppl}")
Exemplo n.º 8
0
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray]
    ) -> List[Tuple[Optional[str], List[str], List[int], float]]:
        """Decode a single utterance with the k2 decoding graph.

        Args:
            speech: Input speech data; a batch axis is prepended before
                it is fed to the encoder.
        Returns:
            List of (text, token, token_int, score) tuples.  ``text`` is
            None when no tokenizer is configured.

        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        # data: (Nsamples,) -> (1, Nsamples)
        model_dtype = getattr(torch, self.dtype)
        speech = speech.unsqueeze(0).to(model_dtype)
        # lengths: (1,)
        lengths = speech.new_full(
            [1], dtype=torch.long, fill_value=speech.size(1))

        # a. To device
        batch = to_device(
            {"speech": speech, "speech_lengths": lengths}, device=self.device)

        # b. Forward Encoder
        # enc: [N, T, C]
        enc, _ = self.asr_model.encode(**batch)
        assert len(enc) == 1, len(enc)

        # logp_encoder_output: [N, T, C]
        ctc_logits = self.asr_model.ctc.ctc_lo(enc)
        logp_encoder_output = torch.nn.functional.log_softmax(ctc_logits, dim=2)

        # TODO(Liyong Guo): Support batch decoding.
        # Following statements only support batch_size == 1
        num_frames = enc.shape[1]
        supervision_segments = torch.tensor(
            [[0, 0, num_frames]], dtype=torch.int32)
        indices = torch.tensor([0])

        dense_fsa_vec = k2.DenseFsaVec(
            logp_encoder_output, supervision_segments)

        # Fixed pruning settings for the intersection
        search_beam = 20.0
        min_active_states = 30
        max_active_states = 10000
        lattices = k2.intersect_dense_pruned(
            self.decode_graph,
            dense_fsa_vec,
            search_beam,
            self.output_beam_size,
            min_active_states,
            max_active_states,
        )

        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        path_scores = best_paths.get_tot_scores(
            use_double_scores=True, log_semiring=False)
        scores = path_scores.tolist()

        hyps = get_texts(best_paths, indices)
        # TODO(Liyong Guo): Support batch decoding. now batch_size == 1.
        assert len(scores) == 1
        assert len(scores) == len(hyps)

        results = []
        for token_int, score in zip(hyps, scores):
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            # Text is only available when a tokenizer is configured
            text = (None if self.tokenizer is None else
                    self.tokenizer.tokens2text(token))
            results.append((text, token, token_int, score))

        assert check_return_type(results)
        return results
Exemplo n.º 9
0
    def __call__(self,
                 speech: Union[torch.Tensor, np.ndarray],
                 fs: int = 8000) -> List[torch.Tensor]:
        """Run speaker diarization on a batch of recordings.

        When ``self.segmenting`` is set and the input is longer than
        ``self.segment_size * fs`` samples, the input is processed in
        fixed-size segments (the last one zero-padded) and the per-segment
        predictions are concatenated along the time axis.

        Args:
            speech: Input speech data (Batch, Nsamples [, Channels])
            fs: sample rate
        Returns:
            Sigmoid-activated speaker predictions as a NumPy array (the
            result of ``.cpu().numpy()``), with the batch on axis 0 and
            the speakers on axis 2.
            NOTE(review): the return annotation still says
            ``List[torch.Tensor]``; it is left untouched to keep the
            signature unchanged.

        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.as_tensor(speech)

        assert speech.dim() > 1, speech.size()
        batch_size = speech.size(0)
        speech = speech.to(getattr(torch, self.dtype))
        # lengths: (B,)
        lengths = speech.new_full([batch_size],
                                  dtype=torch.long,
                                  fill_value=speech.size(1))

        # a. To device
        speech = to_device(speech, device=self.device)
        lengths = to_device(lengths, device=self.device)

        if self.segmenting and lengths[0] > self.segment_size * fs:
            # Segment-wise speaker diarization
            num_segments = int(
                np.ceil(speech.size(1) / (self.segment_size * fs)))
            T = int(self.segment_size * fs)
            pad_shape = speech[:, :T].shape
            diarized_wavs = []
            range_ = trange if self.show_progressbar else range
            for i in range_(num_segments):
                st = int(i * self.segment_size * fs)
                en = st + T
                if en >= lengths[0]:
                    # en - st < T (last segment): zero-pad it up to T samples
                    en = lengths[0]
                    speech_seg = speech.new_zeros(pad_shape)
                    speech_seg[:, :en - st] = speech[:, st:en]
                else:
                    speech_seg = speech[:, st:en]  # B x T [x C]

                # Every segment (including the padded last one) is passed
                # with length T.
                lengths_seg = speech.new_full([batch_size],
                                              dtype=torch.long,
                                              fill_value=T)
                # b. Diarization Forward
                # List[torch.Tensor(B, T, num_spks)]
                diarized_wavs.append(
                    self._forward_diarization(speech_seg, lengths_seg))
            # Determine maximum estimated number of speakers among the segments
            max_len = max([x.size(2) for x in diarized_wavs])
            # pad tensors in diarized_wavs with "float('-inf')" to have same size
            diarized_wavs = [
                torch.nn.functional.pad(x,
                                        (0, max_len - x.size(2)), "constant",
                                        float("-inf")) for x in diarized_wavs
            ]
            spk_prediction = torch.cat(diarized_wavs, dim=1)
        else:
            # b. Diarization Forward
            spk_prediction = self._forward_diarization(speech, lengths)

        if self.num_spk is not None:
            assert spk_prediction.size(2) == self.num_spk, (
                spk_prediction.size(2),
                self.num_spk,
            )
        assert spk_prediction.size(0) == batch_size, (
            spk_prediction.size(0),
            batch_size,
        )
        # Sigmoid computed in NumPy on the CPU copy
        spk_prediction = spk_prediction.cpu().numpy()
        spk_prediction = 1 / (1 + np.exp(-spk_prediction))

        return spk_prediction

    def _forward_diarization(self, speech: torch.Tensor,
                             lengths: torch.Tensor) -> torch.Tensor:
        """Encode one chunk of speech and return raw speaker predictions.

        Shared by the segmented and non-segmented paths of ``__call__``.
        """
        encoder_out, encoder_out_lens = self.diar_model.encode(speech, lengths)

        # SA-EEND: no attractor module, the decoder emits the predictions
        if self.diar_model.attractor is None:
            assert self.num_spk is not None, 'Argument "num_spk" must be specified'
            return self.diar_model.decoder(encoder_out, encoder_out_lens)

        # EEND-EDA
        # if num_spk is specified, use that number; otherwise 15 is the
        # upper bound number for estimation
        num_attractors = self.num_spk if self.num_spk is not None else 15
        # new_zeros allocates on encoder_out's device with its dtype, so the
        # attractor module does not receive a CPU tensor while the model
        # runs elsewhere.
        zeros = encoder_out.new_zeros(
            (encoder_out.size(0), num_attractors + 1, encoder_out.size(2)))
        attractor, att_prob = self.diar_model.attractor(
            encoder_out, encoder_out_lens, zeros)

        if self.num_spk is not None:
            num_spk = self.num_spk
        else:
            # find the first att_prob[i] < 0; when no entry is negative the
            # loop ends with the last index
            att_prob = torch.squeeze(att_prob)
            for pred_num_spk in range(len(att_prob)):
                if att_prob[pred_num_spk].item() < 0:
                    break
            num_spk = pred_num_spk

        return torch.bmm(encoder_out,
                         attractor[:, :num_spk, :].permute(0, 2, 1))
Exemplo n.º 10
0
    def __call__(
        self,
        text: Union[str, torch.Tensor, np.ndarray],
        speech: Union[torch.Tensor, np.ndarray] = None,
        durations: Union[torch.Tensor, np.ndarray] = None,
        spembs: Union[torch.Tensor, np.ndarray] = None,
        sids: Union[torch.Tensor, np.ndarray] = None,
        lids: Union[torch.Tensor, np.ndarray] = None,
        decode_conf: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run text-to-speech.

        Args:
            text: Input text; a ``str`` is first run through
                ``self.preprocess_fn``.
            speech: Optional input, required when ``self.use_speech``.
            durations: Optional input forwarded to the model when given.
            spembs: Optional input, required when ``self.use_spembs``.
            sids: Optional input, required when ``self.use_sids``.
            lids: Optional input, required when ``self.use_lids``.
            decode_conf: Per-call overrides layered on a copy of
                ``self.decode_conf``.
        Returns:
            The dict returned by ``self.model.inference``, extended with
            "duration"/"focus_rate" when "att_w" is present and with "wav"
            when a vocoder is configured.
        Raises:
            RuntimeError: If an input required by the model is missing.
        """
        assert check_argument_types()

        # check inputs
        if speech is None and self.use_speech:
            raise RuntimeError("Missing required argument: 'speech'")
        if sids is None and self.use_sids:
            raise RuntimeError("Missing required argument: 'sids'")
        if lids is None and self.use_lids:
            raise RuntimeError("Missing required argument: 'lids'")
        if spembs is None and self.use_spembs:
            raise RuntimeError("Missing required argument: 'spembs'")

        # prepare batch
        if isinstance(text, str):
            text = self.preprocess_fn("<dummy>", dict(text=text))["text"]
        optional_inputs = {
            "speech": speech,
            "durations": durations,
            "spembs": spembs,
            "sids": sids,
            "lids": lids,
        }
        batch = {"text": text}
        # Only inputs that were actually supplied are forwarded to the model
        batch.update(
            (name, value)
            for name, value in optional_inputs.items()
            if value is not None
        )
        batch = to_device(batch, self.device)

        # overwrite the decode configs if provided
        if decode_conf is None:
            cfg = self.decode_conf
        else:
            cfg = self.decode_conf.copy()
            cfg.update(decode_conf)

        # inference
        if self.always_fix_seed:
            set_all_random_seed(self.seed)
        output_dict = self.model.inference(**batch, **cfg)

        # calculate additional metrics
        att_w = output_dict.get("att_w")
        if att_w is not None:
            duration, focus_rate = self.duration_calculator(att_w)
            output_dict.update(duration=duration, focus_rate=focus_rate)

        # apply vocoder (mel-to-wav)
        if self.vocoder is not None:
            # Prefer the denormalized features when the model provides them
            denorm_feat = output_dict.get("feat_gen_denorm")
            input_feat = (denorm_feat if denorm_feat is not None else
                          output_dict["feat_gen"])
            output_dict.update(wav=self.vocoder(input_feat))

        return output_dict
Exemplo n.º 11
0
    def __call__(
        self, src_text: Union[torch.Tensor, np.ndarray]
    ) -> List[Tuple[Optional[str], List[str], List[int], Hypothesis]]:
        """Translate one source sequence with beam search.

        Args:
            src_text: Input text data; a batch axis is prepended and the
                values are cast to ``torch.long`` before encoding.
        Returns:
            Up to ``self.nbest`` tuples of (text, token, token_int, hyp).
            ``text`` is None when no tokenizer is configured.

        """
        assert check_argument_types()

        # numpy input is converted to a tensor first
        if isinstance(src_text, np.ndarray):
            src_text = torch.tensor(src_text)

        # data: (Nsamples,) -> (1, Nsamples)
        src_text = src_text.unsqueeze(0).to(torch.long)
        # lengths: (1,)
        lengths = src_text.new_full(
            [1], dtype=torch.long, fill_value=src_text.size(1))

        # a. To device
        batch = to_device(
            {"src_text": src_text, "src_text_lengths": lengths},
            device=self.device,
        )

        # b. Forward Encoder
        enc, _ = self.mt_model.encode(**batch)
        assert len(enc) == 1, len(enc)

        # c. Passed the encoder result and the beam search
        all_hyps = self.beam_search(
            x=enc[0],
            maxlenratio=self.maxlenratio,
            minlenratio=self.minlenratio,
        )

        results = []
        for hyp in all_hyps[:self.nbest]:
            assert isinstance(hyp, Hypothesis), type(hyp)

            # remove sos/eos and get results
            # token_int = hyp.yseq[1:-1].tolist()
            # TODO(sdalmia): check why the above line doesn't work
            # sos, eos and the blank symbol id (assumed to be 0) are dropped
            # wherever they occur in the sequence.
            token_int = [
                tid for tid in hyp.yseq.tolist()
                if tid != self.mt_model.sos
                and tid != self.mt_model.eos
                and tid != 0
            ]

            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)

            text = (None if self.tokenizer is None else
                    self.tokenizer.tokens2text(token))
            results.append((text, token, token_int, hyp))

        assert check_return_type(results)
        return results
Exemplo n.º 12
0
    def __call__(
        self, batch: Dict[str, Union[torch.Tensor, np.ndarray]]
    ) -> List[Tuple[Optional[str], List[str], List[int], float]]:
        """Decode a batch of utterances with a k2 decoding graph.

        Runs the encoder and the CTC output layer, intersects the resulting
        log-probabilities with ``self.decode_graph`` and then either takes
        the shortest path of each lattice or, when
        ``self.use_nbest_rescoring`` is set, rescores n-best paths with the
        attention decoder and the neural LM.

        Args:
            batch: Input speech data and corresponding lengths, under the
                keys "speech" and "speech_lengths".  numpy arrays are
                converted to tensors in place, i.e. the caller's dict is
                modified.
        Returns:
            List of (text, token, token_int, score) tuples, one per
            decoded hypothesis.

        """
        assert check_argument_types()

        if isinstance(batch["speech"], np.ndarray):
            batch["speech"] = torch.tensor(batch["speech"])
        if isinstance(batch["speech_lengths"], np.ndarray):
            batch["speech_lengths"] = torch.tensor(batch["speech_lengths"])

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        # enc: [N, T, C]
        enc, encoder_out_lens = self.asr_model.encode(**batch)

        # logp_encoder_output: [N, T, C]
        logp_encoder_output = torch.nn.functional.log_softmax(
            self.asr_model.ctc.ctc_lo(enc), dim=2
        )

        # It may be useful to tune blank_bias.
        # The valid range of blank_bias is [-inf, 0]
        # Index 0 of the last axis is treated as the blank symbol here.
        logp_encoder_output[:, :, 0] += self.blank_bias

        # Build one (sequence_idx, start_frame, num_frames) row per utterance.
        batch_size = encoder_out_lens.size(0)
        sequence_idx = torch.arange(0, batch_size).unsqueeze(0).t().to(torch.int32)
        start_frame = torch.zeros([batch_size], dtype=torch.int32).unsqueeze(0).t()
        num_frames = encoder_out_lens.cpu().unsqueeze(0).t().to(torch.int32)
        supervision_segments = torch.cat([sequence_idx, start_frame, num_frames], dim=1)

        supervision_segments = supervision_segments.to(torch.int32)

        # An introduction to DenseFsaVec:
        # https://k2-fsa.github.io/k2/core_concepts/index.html#dense-fsa-vector
        # It could be viewed as an fsa-type logp_encoder_output,
        # whose weights on the arcs are initialized with logp_encoder_output.
        # The goal of converting tensor-type to fsa-type is to use
        # fsa related functions in k2, e.g. k2.intersect_dense_pruned below
        dense_fsa_vec = k2.DenseFsaVec(logp_encoder_output, supervision_segments)

        # The term "intersect" is similar to "compose" in k2.
        # The differences are:
        # for "compose" functions, the composition involves
        # matching output label of a.fsa and input label of b.fsa
        # while for "intersect" functions, the composition involves
        # matching input label of a.fsa and input label of b.fsa
        # Actually, in compose functions, b.fsa is inverted and then
        # a.fsa and inv_b.fsa are intersected together.
        # For difference between compose and intersect:
        # https://github.com/k2-fsa/k2/blob/master/k2/python/k2/fsa_algo.py#L308
        # For definition of k2.intersect_dense_pruned:
        # https://github.com/k2-fsa/k2/blob/master/k2/python/k2/autograd.py#L648
        lattices = k2.intersect_dense_pruned(
            self.decode_graph,
            dense_fsa_vec,
            self.search_beam_size,
            self.output_beam_size,
            self.min_active_states,
            self.max_active_states,
        )

        # lattices.scores is the sum of decode_graph.scores (a.k.a. lm weight) and
        # dense_fsa_vec.scores (a.k.a. am weight) on related arcs.
        # For a ctc decoding graph, lattices.scores only stores the am weight
        # since the decoder graph only defines the ctc topology and
        # has no lm weight on its arcs.
        # While for 3-gram decoding, whose graph is converted from language models,
        # lattices.scores contains both am weights and lm weights
        #
        # It may be useful to tune lattices.scores
        # The valid range of lattice_weight is [0, inf)
        # The lattice_weight will affect the search of k2.random_paths
        lattices.scores *= self.lattice_weight

        results = []
        if self.use_nbest_rescoring:
            (
                am_scores,
                lm_scores,
                token_ids,
                new2old,
                path_to_seq_map,
                seq_to_path_splits,
            ) = nbest_am_lm_scores(
                lattices, self.num_paths, self.device, self.nbest_batch_size
            )

            # Right-pad every n-best token sequence with ignore_id up to the
            # longest one so they can be stacked into a single tensor.
            ys_pad_lens = torch.tensor([len(hyp) for hyp in token_ids]).to(self.device)
            max_token_length = max(ys_pad_lens)
            ys_pad_list = []
            for hyp in token_ids:
                ys_pad_list.append(
                    torch.cat(
                        [
                            torch.tensor(hyp, dtype=torch.long),
                            torch.tensor(
                                [self.asr_model.ignore_id]
                                * (max_token_length.item() - len(hyp)),
                                dtype=torch.long,
                            ),
                        ]
                    )
                )

            ys_pad = (
                torch.stack(ys_pad_list).to(torch.long).to(self.device)
            )  # [batch, max_token_length]

            # Repeat each utterance's encoder output once per n-best path.
            encoder_out = enc.index_select(0, path_to_seq_map.to(torch.long)).to(
                self.device
            )  # [batch, T, dim]
            encoder_out_lens = encoder_out_lens.index_select(
                0, path_to_seq_map.to(torch.long)
            ).to(
                self.device
            )  # [batch]

            # Scores are negated negative log-likelihoods.
            decoder_scores = -self.asr_model.batchify_nll(
                encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, self.nll_batch_size
            )

            # padded_value for nnlm is 0
            # ys_pad is rewritten in place, so this must stay after the
            # decoder scoring above, which still needs ignore_id padding.
            ys_pad[ys_pad == self.asr_model.ignore_id] = 0
            nnlm_nll, x_lengths = self.lm.batchify_nll(
                ys_pad, ys_pad_lens, self.nll_batch_size
            )
            nnlm_scores = -nnlm_nll.sum(dim=1)

            # Weighted combination of acoustic, decoder and neural-LM scores.
            batch_tot_scores = (
                self.am_weight * am_scores
                + self.decoder_weight * decoder_scores
                + self.nnlm_weight * nnlm_scores
            )
            # Split the flat per-path scores back into one chunk per utterance.
            split_size = indices_to_split_size(
                seq_to_path_splits.tolist(), total_elements=batch_tot_scores.size(0)
            )
            batch_tot_scores = torch.split(
                batch_tot_scores,
                split_size,
            )

            hyps = []
            scores = []
            # Offset of the current chunk inside the flat token_ids list.
            processed_seqs = 0
            for tot_scores in batch_tot_scores:
                if tot_scores.nelement() == 0:
                    # the last element by torch.tensor_split may be empty
                    # e.g.
                    # torch.tensor_split(torch.tensor([1,2,3,4]), torch.tensor([2,4]))
                    # (tensor([1, 2]), tensor([3, 4]), tensor([], dtype=torch.int64))
                    # NOTE(review): the example refers to torch.tensor_split
                    # while the code above calls torch.split - confirm that
                    # an empty chunk can still occur here.
                    break
                best_seq_idx = processed_seqs + torch.argmax(tot_scores)

                assert best_seq_idx < len(token_ids)
                best_token_seqs = token_ids[best_seq_idx]
                processed_seqs += tot_scores.nelement()
                hyps.append(best_token_seqs)
                scores.append(tot_scores.max().item())

            assert len(hyps) == len(split_size)
        else:
            # Without rescoring, take the single best path of each lattice.
            best_paths = k2.shortest_path(lattices, use_double_scores=True)
            scores = best_paths.get_tot_scores(
                use_double_scores=True, log_semiring=False
            ).tolist()
            hyps = get_texts(best_paths)

        assert len(scores) == len(hyps)

        for token_int, score in zip(hyps, scores):
            # For decoding methods nbest_rescoring and ctc_decoding
            # hyps stores token_index, which is lattice.labels.

            # convert token_id to text with self.tokenizer
            token = self.converter.ids2tokens(token_int)
            assert self.tokenizer is not None
            text = self.tokenizer.tokens2text(token)
            results.append((text, token, token_int, score))

        assert check_return_type(results)
        return results
Exemplo n.º 13
0
def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    threshold: float,
    minlenratio: float,
    maxlenratio: float,
    use_att_constraint: bool,
    backward_window: int,
    forward_window: int,
    allow_variable_data_keys: bool,
    vocoder_conf: dict,
):
    """Perform TTS model decoding.

    For every utterance the generated features are written in kaldi
    ark/scp format under ``output_dir/norm`` and ``output_dir/denorm``;
    when the vocoder configuration is sufficient, a waveform is also
    written to ``output_dir/wav/<key>.wav``.

    Args:
        output_dir: Directory receiving all outputs.
        batch_size: Must be 1; batch decoding is not implemented.
        dtype: Name of a torch dtype, resolved with ``getattr(torch, dtype)``.
        ngpu: 0 for CPU, 1 for a single GPU.
        seed: Value passed to ``set_all_random_seed``.
        num_workers: Worker count handed to the streaming iterator.
        log_level: Level passed to ``logging.basicConfig``.
        data_path_and_name_and_type: Data specification handed to the
            streaming iterator.
        key_file: Optional key file handed to the streaming iterator.
        train_config: Training config used to rebuild the model.
        model_file: Model parameter file used to rebuild the model.
        threshold: Forwarded to ``tts.inference``.
        minlenratio: Forwarded to ``tts.inference``.
        maxlenratio: Forwarded to ``tts.inference``; also used for the
            "maximum length reached" warning.
        use_att_constraint: Accepted but not forwarded by this function.
        backward_window: Accepted but not forwarded by this function.
        forward_window: Accepted but not forwarded by this function.
        allow_variable_data_keys: Forwarded to the streaming iterator.
        vocoder_conf: Settings for ``Spectrogram2Waveform``; the vocoder
            is only built when "n_fft", "n_shift" and "fs" are all present.
            The caller's dict is not modified.

    Raises:
        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1``.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    device = "cuda" if ngpu >= 1 else "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build model
    model, train_args = TTSTask.build_model_from_file(train_config, model_file,
                                                      device)
    model.to(dtype=getattr(torch, dtype)).eval()
    tts = model.tts
    normalize = model.normalize
    logging.info(f"Normalization:\n{normalize}")
    logging.info(f"TTS:\n{tts}")

    # 3. Build data-iterator
    loader = TTSTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=TTSTask.build_preprocess_fn(train_args, False),
        collate_fn=TTSTask.build_collate_fn(train_args),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 4. Build converter from spectrogram to waveform
    # Work on a copy so that the caller's dict is left untouched.
    vocoder_conf = dict(vocoder_conf)
    if model.feats_extract is not None:
        vocoder_conf.update(model.feats_extract.get_parameters())
    if all(name in vocoder_conf for name in ("n_fft", "n_shift", "fs")):
        spc2wav = Spectrogram2Waveform(**vocoder_conf)
        logging.info(f"Vocoder: {spc2wav}")
    else:
        spc2wav = None
        logging.info(
            "Vocoder is not used because vocoder_conf is not sufficient")

    # 5. Start for-loop
    output_dir = Path(output_dir)
    (output_dir / "norm").mkdir(parents=True, exist_ok=True)
    (output_dir / "denorm").mkdir(parents=True, exist_ok=True)
    (output_dir / "wav").mkdir(parents=True, exist_ok=True)

    # FIXME(kamo): I think we shouldn't depend on kaldi-format any more.
    #  How about numpy or HDF5?
    #  >>> with NpyScpWriter() as f:
    norm_spec = "ark,scp:{o}.ark,{o}.scp".format(o=output_dir / "norm/feats")
    denorm_spec = "ark,scp:{o}.ark,{o}.scp".format(
        o=output_dir / "denorm/feats")
    with kaldiio.WriteHelper(norm_spec) as f, \
            kaldiio.WriteHelper(denorm_spec) as g:
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            batch = to_device(batch, device)

            key = keys[0]
            # Change to single sequence and remove *_length
            # because inference() requires 1-seq, not mini-batch.
            _data = {
                k: v[0]
                for k, v in batch.items() if not k.endswith("_lengths")
            }
            start_time = time.perf_counter()

            # TODO(kamo): Now att_ws is not used.
            outs, probs, att_ws = tts.inference(
                **_data,
                threshold=threshold,
                maxlenratio=maxlenratio,
                minlenratio=minlenratio,
            )
            outs_denorm = normalize.inverse(outs[None])[0][0]
            insize = next(iter(_data.values())).size(0)
            # perf_counter() returns seconds: convert to milliseconds, then
            # divide by the number of generated frames.
            elapsed_sec = time.perf_counter() - start_time
            logging.info("inference speed = {} msec / frame.".format(
                elapsed_sec * 1000 / int(outs.size(0))))
            logging.info(f"{key} (size:{insize}->{outs.size(0)})")
            if outs.size(0) == insize * maxlenratio:
                logging.warning(
                    f"output length reaches maximum length ({key}).")
            f[key] = outs.cpu().numpy()
            g[key] = outs_denorm.cpu().numpy()

            # TODO(kamo): Write scp
            if spc2wav is not None:
                wav = spc2wav(outs_denorm.cpu().numpy())
                sf.write(f"{output_dir}/wav/{key}.wav", wav, spc2wav.fs,
                         "PCM_16")
Exemplo n.º 14
0
    def recognize(self, audiofile: Union[Path, str, bytes]) -> Result:
        """Recognize one audio input and collect intermediate model outputs.

        Args:
            audiofile: Path to an audio file (``str`` or ``Path``) or the
                raw bytes of an audio file.  The audio is loaded with
                librosa at 16 kHz.
        Returns:
            A ``Result`` filled with the audio samples, recognized text,
            encoder output, attention weights, tokens, CTC posteriors and
            the frontend features.
        Raises:
            ValueError: If ``audiofile`` is none of the supported types.
        """
        result = Result()

        # Path objects are advertised by the signature, so they are accepted
        # alongside plain strings; str() keeps the loader input a plain path.
        if isinstance(audiofile, (str, Path)):
            audio_samples, _ = librosa.load(str(audiofile), sr=16000)
        elif isinstance(audiofile, bytes):
            audio_samples, _ = librosa.core.load(io.BytesIO(audiofile),
                                                 sr=16000)
        else:
            raise ValueError("Failed to load audio file")

        result.audio_samples = copy.deepcopy(audio_samples)

        # the model input is a torch.Tensor
        if isinstance(audio_samples, np.ndarray):
            audio_samples = torch.tensor(audio_samples)
        audio_samples = audio_samples.unsqueeze(0).to(torch.float32)

        lengths = audio_samples.new_full([1],
                                         dtype=torch.long,
                                         fill_value=audio_samples.size(1))
        batch = {"speech": audio_samples, "speech_lengths": lengths}
        batch = to_device(batch, device=self.device)

        # model encoder
        enc, _ = self.model.encode(**batch)

        # model decoder
        nbest_hyps = self.beam_search(x=enc[0])

        # Only the best hypothesis
        best_hyps = nbest_hyps[0]

        # Convert the training token ids to text
        # (first and last entries of yseq are dropped, then id 0 is removed)
        token_int = best_hyps.yseq[1:-1].tolist()
        token_int = list(filter(lambda x: x != 0, token_int))
        token = self.converter.ids2tokens(token_int)
        text = self.tokenizer.tokens2text(token)

        # Fill the result object
        result.text = text
        result.encoded_vector = enc[0]  # [0] removes the batch dimension

        # compute all attention matrices
        text_tensor = torch.Tensor(token_int).unsqueeze(0).to(torch.long)
        batch["text"] = text_tensor
        batch["text_lengths"] = text_tensor.new_full(
            [1], dtype=torch.long, fill_value=text_tensor.size(1))

        result.attention_weights = calculate_all_attentions(self.model, batch)
        result.tokens_txt = token

        # CTC posteriors
        # NOTE(review): enc already carries a batch dimension (see enc[0]
        # above), so unsqueeze(0) adds a second leading axis - confirm that
        # ctc.log_softmax normalizes over the intended axis for this shape.
        logp = self.model.ctc.log_softmax(enc.unsqueeze(0))[0]
        result.ctc_posteriors = logp.exp_().numpy()
        result.tokens_int = best_hyps.yseq
        result.mel_features, _ = self.frontend(audiofile, normalize=False)
        return result
Exemplo n.º 15
0
    def __call__(
        self, batch: Dict[str, Union[torch.Tensor, np.ndarray]]
    ) -> List[Tuple[Optional[str], List[str], List[int], float]]:
        """Decode a mini-batch through the CTC output and the decoding graph.

        Args:
            batch: Mapping holding the "speech" and "speech_lengths" entries;
                NumPy arrays are replaced by tensors inside this mapping.
        Returns:
            One ``(text, token, token_int, score)`` tuple per decoded path;
            ``text`` is ``None`` when no tokenizer is configured.

        """
        assert check_argument_types()

        # Accept NumPy input by converting the two entries in place
        for name in ("speech", "speech_lengths"):
            if isinstance(batch[name], np.ndarray):
                batch[name] = torch.tensor(batch[name])

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        # enc: [N, T, C]
        enc, encoder_out_lens = self.asr_model.encode(**batch)

        # logp_encoder_output: [N, T, C]
        ctc_logits = self.asr_model.ctc.ctc_lo(enc)
        logp_encoder_output = torch.nn.functional.log_softmax(ctc_logits,
                                                              dim=2)

        # c. One int32 row per utterance: (sequence index, 0, frame count)
        batch_size = encoder_out_lens.size(0)
        sequence_idx = torch.arange(0, batch_size).to(
            torch.int32).unsqueeze(1)
        start_frame = torch.zeros([batch_size, 1], dtype=torch.int32)
        num_frames = encoder_out_lens.cpu().to(torch.int32).unsqueeze(1)
        supervision_segments = torch.cat(
            [sequence_idx, start_frame, num_frames], dim=1)

        dense_fsa_vec = k2.DenseFsaVec(logp_encoder_output,
                                       supervision_segments)

        # d. Pruned intersection with the decoding graph
        # NOTE(review): 20.0 / 30 / 10000 are presumably the search beam and
        # the min/max active states -- confirm against the k2 API.
        lattices = k2.intersect_dense_pruned(self.decode_graph, dense_fsa_vec,
                                             20.0, self.output_beam_size, 30,
                                             10000)

        # e. Best path per utterance together with its total score
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        tot_scores = best_paths.get_tot_scores(use_double_scores=True,
                                               log_semiring=False)
        scores = tot_scores.tolist()

        hyps = get_texts(best_paths)
        assert len(scores) == len(hyps)

        results = []
        has_tokenizer = self.tokenizer is not None
        for ids, path_score in zip(hyps, scores):
            # Change integer-ids to tokens
            tokens = self.converter.ids2tokens(ids)
            text = self.tokenizer.tokens2text(tokens) if has_tokenizer else None
            results.append((text, tokens, ids, path_score))

        assert check_return_type(results)
        return results
Exemplo n.º 16
0
def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    fs: int,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    enh_train_config: str,
    enh_model_file: str,
    allow_variable_data_keys: bool,
    normalize_output_wav: bool,
):
    """Run the enhancement model over a data set and write one wav per speaker.

    For every speaker index ``i`` the outputs go to
    ``{output_dir}/wavs/{i + 1}`` with the index file
    ``{output_dir}/spk{i + 1}.scp``.

    Args:
        output_dir: Directory receiving the wav files and scp files.
        batch_size: Mini-batch size; only 1 is supported.
        dtype: Data type name handed to the data iterator.
        fs: Sampling rate stored together with every written waveform.
        ngpu: Number of GPUs; "cuda" is used when >= 1, otherwise "cpu".
        seed: Random seed.
        num_workers: Number of workers of the data iterator.
        log_level: Logging level.
        data_path_and_name_and_type: Input data specification.
        key_file: Optional key file handed to the data iterator.
        enh_train_config: Training config of the enhancement model.
        enh_model_file: Parameter file of the enhancement model.
        allow_variable_data_keys: Handed to the data iterator.
        normalize_output_wav: Scale every output to a peak of 0.9.

    Raises:
        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1``.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build Enh model
    enh_model, enh_train_args = EnhancementTask.build_model_from_file(
        enh_train_config, enh_model_file, device)
    enh_model.eval()

    num_spk = enh_model.num_spk

    # 3. Build data-iterator
    loader = EnhancementTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=EnhancementTask.build_preprocess_fn(
            enh_train_args, False),
        collate_fn=EnhancementTask.build_collate_fn(enh_train_args),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 4. Start for-loop
    # The writers are closed in ``finally`` so that they are released even
    # when the model or the loader raises in the middle of the data set.
    writers = []
    try:
        for i in range(num_spk):
            writers.append(
                SoundScpWriter(f"{output_dir}/wavs/{i + 1}",
                               f"{output_dir}/spk{i + 1}.scp"))

        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"

            with torch.no_grad():
                # a. To device
                batch = to_device(batch, device)
                # b. Forward Enhancement Frontend
                waves, _, _ = enh_model.enh_model.forward_rawwav(
                    batch["speech_mix"], batch["speech_mix_lengths"])
                assert len(waves[0]) == batch_size, len(waves[0])

            # FIXME(Chenda): will be incorrect when
            #  batch size is not 1 or multi-channel case
            if normalize_output_wav:
                waves = [
                    (w / abs(w).max(dim=1, keepdim=True)[0] *
                     0.9).T.cpu().numpy() for w in waves
                ]  # list[(sample,batch)]
            else:
                waves = [w.T.cpu().numpy() for w in waves]
            # Only supporting batch_size==1, hence keys[0]
            for (i, w) in enumerate(waves):
                writers[i][keys[0]] = fs, w
    finally:
        for writer in writers:
            writer.close()
Exemplo n.º 17
0
def inference(
    output_dir: str,
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    dtype: str,
    beam_size: int,
    ngpu: int,
    seed: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    nbest: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    asr_train_config: str,
    asr_model_file: str,
    lm_train_config: Optional[str],
    lm_file: Optional[str],
    word_lm_train_config: Optional[str],
    word_lm_file: Optional[str],
    blank_symbol: str,
    token_type: Optional[str],
    bpemodel: Optional[str],
    allow_variable_data_keys: bool,
):
    """Decode a data set with beam search and write the n-best results.

    For the n-th hypothesis of every utterance the "token", "token_int" and
    "score" entries (plus "text" when a tokenizer is available) are written
    under ``{n}best_recog`` in ``output_dir``.

    Args:
        output_dir: Directory handed to ``DatadirWriter``.
        maxlenratio: Handed to the beam search.
        minlenratio: Handed to the beam search.
        batch_size: Mini-batch size; only 1 is supported.
        dtype: Name of the torch dtype used for the search and the scorers.
        beam_size: Beam size of the search.
        ngpu: Number of GPUs; "cuda" is used when >= 1, otherwise "cpu".
        seed: Random seed.
        ctc_weight: Weight of the CTC scorer; the decoder gets
            ``1.0 - ctc_weight``.
        lm_weight: Weight of the language-model scorer.
        penalty: Weight of the length-bonus scorer.
        nbest: Maximum number of hypotheses written per utterance.
        num_workers: Number of workers of the data iterator.
        log_level: Logging level.
        data_path_and_name_and_type: Input data specification.
        key_file: Optional key file handed to the data iterator.
        asr_train_config: Training config of the ASR model.
        asr_model_file: Parameter file of the ASR model.
        lm_train_config: Training config of the LM; no LM is used if None.
        lm_file: Parameter file of the LM.
        word_lm_train_config: Must be None (word LM is not implemented).
        word_lm_file: Unused.
        blank_symbol: Unused; blank is assumed to be id 0.
        token_type: Token type; falls back to the training config if None.
        bpemodel: BPE model; falls back to the training config if None.
        allow_variable_data_keys: Handed to the data iterator.

    Raises:
        NotImplementedError: If ``batch_size > 1``, ``ngpu > 1`` or a word
            LM config is given.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build ASR model
    scorers = {}
    asr_model, asr_train_args = ASRTask.build_model_from_file(
        asr_train_config, asr_model_file, device
    )
    asr_model.eval()

    decoder = asr_model.decoder
    ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
    token_list = asr_model.token_list
    scorers.update(
        decoder=decoder, ctc=ctc, length_bonus=LengthBonus(len(token_list)),
    )

    # 3. Build Language model
    if lm_train_config is not None:
        lm, lm_train_args = LMTask.build_model_from_file(
            lm_train_config, lm_file, device
        )
        scorers["lm"] = lm.lm

    # 4. Build BeamSearch object
    weights = dict(
        decoder=1.0 - ctc_weight, ctc=ctc_weight, lm=lm_weight, length_bonus=penalty,
    )
    beam_search = BeamSearch(
        beam_size=beam_size,
        weights=weights,
        scorers=scorers,
        sos=asr_model.sos,
        eos=asr_model.eos,
        vocab_size=len(token_list),
        token_list=token_list,
    )
    beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
    for scorer in scorers.values():
        if isinstance(scorer, torch.nn.Module):
            scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
    logging.info(f"Beam_search: {beam_search}")
    logging.info(f"Decoding device={device}, dtype={dtype}")

    # 5. Build data-iterator
    loader, _, _ = ASRTask.build_non_sorted_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=ASRTask.build_preprocess_fn(asr_train_args, False),
        collate_fn=ASRTask.build_collate_fn(asr_train_args),
        allow_variable_data_keys=allow_variable_data_keys,
    )

    # 6. [Optional] Build Text converter: e.g. bpe-sym -> Text
    if token_type is None:
        token_type = asr_train_args.token_type
    if bpemodel is None:
        bpemodel = asr_train_args.bpemodel

    if token_type is None:
        tokenizer = None
    elif token_type == "bpe":
        if bpemodel is not None:
            tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
        else:
            tokenizer = None
    else:
        tokenizer = build_tokenizer(token_type=token_type)
    converter = TokenIDConverter(token_list=token_list)
    logging.info(f"Text tokenizer: {tokenizer}")

    # 7 .Start for-loop
    # FIXME(kamo): The output format should be discussed about
    with DatadirWriter(output_dir) as writer:
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"

            with torch.no_grad():
                # a. To device
                batch = to_device(batch, device)

                # b. Forward Encoder
                enc, _ = asr_model.encode(**batch)
                assert len(enc) == batch_size, len(enc)

                # c. Passed the encoder result and the beam search
                nbest_hyps = beam_search(
                    x=enc[0], maxlenratio=maxlenratio, minlenratio=minlenratio
                )
                nbest_hyps = nbest_hyps[:nbest]

            # Only supporting batch_size==1
            key = keys[0]
            # Iterate over the hypotheses actually returned: the search may
            # yield fewer than ``nbest`` of them, and indexing up to ``nbest``
            # would then raise IndexError.
            for n, hyp in enumerate(nbest_hyps, 1):
                assert isinstance(hyp, Hypothesis), type(hyp)

                # remove sos/eos and get results
                token_int = hyp.yseq[1:-1].tolist()

                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != 0, token_int))

                # Change integer-ids to tokens
                token = converter.ids2tokens(token_int)

                # Create a directory: outdir/{n}best_recog
                ibest_writer = writer[f"{n}best_recog"]

                # Write the result to each files
                ibest_writer["token"][key] = " ".join(token)
                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                ibest_writer["score"][key] = str(hyp.score)

                if tokenizer is not None:
                    text = tokenizer.tokens2text(token)
                    ibest_writer["text"][key] = text
Exemplo n.º 18
0
    def __call__(
        self, speech_mix: Union[torch.Tensor, np.ndarray], fs: int = 8000
    ) -> List[torch.Tensor]:
        """Inference

        Long inputs are optionally processed segment by segment
        (``self.segmenting``) and the enhanced segments are stitched back
        together with overlap-and-add.

        Args:
            speech_mix: Input speech data (Batch, Nsamples [, Channels])
            fs: sample rate
        Returns:
            [separated_audio1, separated_audio2, ...]
            NOTE: every entry is converted with ``.cpu().numpy()`` before
            being returned, i.e. the list holds NumPy arrays.

        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech_mix, np.ndarray):
            speech_mix = torch.as_tensor(speech_mix)

        assert speech_mix.dim() > 1, speech_mix.size()
        batch_size = speech_mix.size(0)
        speech_mix = speech_mix.to(getattr(torch, self.dtype))
        # lengths: (B,) -- every item is assumed to span the full Nsamples
        lengths = speech_mix.new_full(
            [batch_size], dtype=torch.long, fill_value=speech_mix.size(1)
        )

        # a. To device
        speech_mix = to_device(speech_mix, device=self.device)
        lengths = to_device(lengths, device=self.device)

        if self.segmenting and lengths[0] > self.segment_size * fs:
            # Segment-wise speech enhancement/separation
            overlap_length = int(np.round(fs * (self.segment_size - self.hop_size)))
            num_segments = int(
                np.ceil((speech_mix.size(1) - overlap_length) / (self.hop_size * fs))
            )
            # T: segment length in samples; t: valid length of the segment
            # processed last (it is reused by the stitching loop below)
            t = T = int(self.segment_size * fs)
            pad_shape = speech_mix[:, :T].shape
            enh_waves = []
            range_ = trange if self.show_progressbar else range
            for i in range_(num_segments):
                st = int(i * self.hop_size * fs)
                en = st + T
                if en >= lengths[0]:
                    # en - st < T (last segment): zero-pad up to T samples
                    en = lengths[0]
                    speech_seg = speech_mix.new_zeros(pad_shape)
                    t = en - st
                    speech_seg[:, :t] = speech_mix[:, st:en]
                else:
                    t = T
                    speech_seg = speech_mix[:, st:en]  # B x T [x C]

                lengths_seg = speech_mix.new_full(
                    [batch_size], dtype=torch.long, fill_value=T
                )
                # b. Enhancement/Separation Forward
                feats, f_lens = self.enh_model.encoder(speech_seg, lengths_seg)
                feats, _, _ = self.enh_model.separator(feats, f_lens)
                processed_wav = [
                    self.enh_model.decoder(f, lengths_seg)[0] for f in feats
                ]
                if speech_seg.dim() > 2:
                    # multi-channel speech
                    speech_seg_ = speech_seg[:, self.ref_channel]
                else:
                    speech_seg_ = speech_seg

                if self.normalize_segment_scale:
                    # normalize the energy of each separated stream
                    # to match the input energy
                    processed_wav = [
                        self.normalize_scale(w, speech_seg_) for w in processed_wav
                    ]
                # List[torch.Tensor(num_spk, B, T)]
                enh_waves.append(torch.stack(processed_wav, dim=0))

            # c. Stitch the enhanced segments together
            waves = enh_waves[0]
            for i in range(1, num_segments):
                # permutation between separated streams in last and current segments
                perm = self.cal_permumation(
                    waves[:, :, -overlap_length:],
                    enh_waves[i][:, :, :overlap_length],
                    criterion="si_snr",
                )
                # repermute separated streams in current segment
                for batch in range(batch_size):
                    enh_waves[i][:, batch] = enh_waves[i][perm[batch], batch]

                if i == num_segments - 1:
                    # drop the zero-padded tail of the last segment
                    enh_waves[i][:, :, t:] = 0
                    enh_waves_res_i = enh_waves[i][:, :, overlap_length:t]
                else:
                    enh_waves_res_i = enh_waves[i][:, :, overlap_length:]

                # overlap-and-add (average over the overlapped part)
                waves[:, :, -overlap_length:] = (
                    waves[:, :, -overlap_length:] + enh_waves[i][:, :, :overlap_length]
                ) / 2
                # concatenate the residual parts of the later segment
                waves = torch.cat([waves, enh_waves_res_i], dim=2)
            # ensure the stitched length is same as input
            assert waves.size(2) == speech_mix.size(1), (waves.shape, speech_mix.shape)
            waves = torch.unbind(waves, dim=0)
        else:
            # b. Enhancement/Separation Forward
            feats, f_lens = self.enh_model.encoder(speech_mix, lengths)
            feats, _, _ = self.enh_model.separator(feats, f_lens)
            waves = [self.enh_model.decoder(f, lengths)[0] for f in feats]

        # Report both values on failure (the message used to be the boolean
        # result of the comparison itself).
        assert len(waves) == self.num_spk, (len(waves), self.num_spk)
        assert len(waves[0]) == batch_size, (len(waves[0]), batch_size)
        if self.normalize_output_wav:
            # scale every stream to a peak amplitude of 0.9
            waves = [
                (w / abs(w).max(dim=1, keepdim=True)[0] * 0.9).cpu().numpy()
                for w in waves
            ]  # list[(batch, sample)]
        else:
            waves = [w.cpu().numpy() for w in waves]

        return waves
Exemplo n.º 19
0
def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    threshold: float,
    minlenratio: float,
    maxlenratio: float,
    use_att_constraint: bool,
    backward_window: int,
    forward_window: int,
    allow_variable_data_keys: bool,
    vocoder_conf: dict,
):
    """Perform TTS model decoding.

    For every utterance the normalized and denormalized features, an
    attention-weight plot and a stop-probability plot are written under
    ``output_dir``; a waveform is written too when a vocoder can be built.

    Args:
        output_dir: Output directory ("norm", "denorm", "wav", "att_ws" and
            "probs" sub-directories are created).
        batch_size: Mini-batch size; only 1 is supported.
        dtype: Name of the torch dtype the model is cast to.
        ngpu: Number of GPUs; "cuda" is used when >= 1, otherwise "cpu".
        seed: Random seed.
        num_workers: Number of workers of the data iterator.
        log_level: Logging level.
        data_path_and_name_and_type: Input data specification.
        key_file: Optional key file handed to the data iterator.
        train_config: Training config of the TTS model.
        model_file: Parameter file of the TTS model.
        threshold: Decoding option handed to ``tts.inference``.
        minlenratio: Decoding option handed to ``tts.inference``.
        maxlenratio: Decoding option handed to ``tts.inference``.
        use_att_constraint: Only handed over for ``Tacotron2`` models.
        backward_window: Only handed over for ``Tacotron2`` models.
        forward_window: Only handed over for ``Tacotron2`` models.
        allow_variable_data_keys: Handed to the data iterator.
        vocoder_conf: Vocoder options; updated in place with the parameters
            of the feature extractor when the model has one.

    Raises:
        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1``.
        RuntimeError: If the attention weights are neither 2- nor
            4-dimensional.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build model
    model, train_args = TTSTask.build_model_from_file(train_config, model_file,
                                                      device)
    model.to(dtype=getattr(torch, dtype)).eval()
    tts = model.tts
    normalize = model.normalize
    logging.info(f"Normalization:\n{normalize}")
    logging.info(f"TTS:\n{tts}")

    # 3. Build data-iterator
    loader = TTSTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=TTSTask.build_preprocess_fn(train_args, False),
        collate_fn=TTSTask.build_collate_fn(train_args),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 4. Build converter from spectrogram to waveform
    if model.feats_extract is not None:
        vocoder_conf.update(model.feats_extract.get_parameters())
    if "n_fft" in vocoder_conf and "n_shift" in vocoder_conf and "fs" in vocoder_conf:
        spc2wav = Spectrogram2Waveform(**vocoder_conf)
        logging.info(f"Vocoder: {spc2wav}")
    else:
        spc2wav = None
        logging.info(
            "Vocoder is not used because vocoder_conf is not sufficient")

    # 5. Start for-loop
    output_dir = Path(output_dir)
    (output_dir / "norm").mkdir(parents=True, exist_ok=True)
    (output_dir / "denorm").mkdir(parents=True, exist_ok=True)
    (output_dir / "wav").mkdir(parents=True, exist_ok=True)
    (output_dir / "att_ws").mkdir(parents=True, exist_ok=True)
    (output_dir / "probs").mkdir(parents=True, exist_ok=True)

    with NpyScpWriter(
            output_dir / "norm",
            output_dir / "norm/feats.scp",
    ) as f, NpyScpWriter(output_dir / "denorm",
                         output_dir / "denorm/feats.scp") as g:
        for idx, (keys, batch) in enumerate(loader, 1):
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            batch = to_device(batch, device)

            key = keys[0]
            # Change to single sequence and remove *_length
            # because inference() requires 1-seq, not mini-batch.
            _data = {
                k: v[0]
                for k, v in batch.items() if not k.endswith("_lengths")
            }
            start_time = time.perf_counter()

            _decode_conf = {
                "threshold": threshold,
                "maxlenratio": maxlenratio,
                "minlenratio": minlenratio,
            }
            if isinstance(tts, Tacotron2):
                _decode_conf.update({
                    "use_att_constraint": use_att_constraint,
                    "forward_window": forward_window,
                    "backward_window": backward_window,
                })
            outs, probs, att_ws = tts.inference(**_data, **_decode_conf)
            insize = next(iter(_data.values())).size(0) + 1
            logging.info("inference speed = {:.1f} frames / sec.".format(
                int(outs.size(0)) / (time.perf_counter() - start_time)))
            logging.info(f"{key} (size:{insize}->{outs.size(0)})")
            if outs.size(0) == insize * maxlenratio:
                logging.warning(
                    f"output length reaches maximum length ({key}).")
            f[key] = outs.cpu().numpy()

            # NOTE: normalize.inverse is in-place operation
            outs_denorm = normalize.inverse(outs[None])[0][0]
            g[key] = outs_denorm.cpu().numpy()

            # Lazy load to avoid the backend error
            matplotlib.use("Agg")
            import matplotlib.pyplot as plt
            from matplotlib.ticker import MaxNLocator

            # Plot attention weight
            att_ws = att_ws.cpu().numpy()

            if att_ws.ndim == 2:
                att_ws = att_ws[None][None]
            elif att_ws.ndim != 4:
                raise RuntimeError(f"Must be 2 or 4 dimension: {att_ws.ndim}")

            w, h = plt.figaspect(att_ws.shape[0] / att_ws.shape[1])
            fig = plt.Figure(figsize=(
                w * 1.3 * min(att_ws.shape[0], 2.5),
                h * 1.3 * min(att_ws.shape[1], 2.5),
            ))
            fig.suptitle(f"{key}")
            # squeeze=False always yields a 2-D grid of axes, so the nested
            # iteration below also works for 1xN and Nx1 layouts (the default
            # squeezing collapses those to a 1-D array).
            axes = fig.subplots(att_ws.shape[0], att_ws.shape[1],
                                squeeze=False)
            for ax, att_w in zip(axes, att_ws):
                for ax_, att_w_ in zip(ax, att_w):
                    ax_.imshow(att_w_.astype(np.float32), aspect="auto")
                    ax_.set_xlabel("Input")
                    ax_.set_ylabel("Output")
                    ax_.xaxis.set_major_locator(MaxNLocator(integer=True))
                    ax_.yaxis.set_major_locator(MaxNLocator(integer=True))

            fig.tight_layout(rect=[0, 0.03, 1, 0.95])
            fig.savefig(output_dir / f"att_ws/{key}.png")
            fig.clf()

            # Plot stop token prediction
            probs = probs.cpu().numpy()

            fig = plt.Figure()
            ax = fig.add_subplot(1, 1, 1)
            ax.plot(probs)
            ax.set_title(f"{key}")
            ax.set_xlabel("Output")
            ax.set_ylabel("Stop probability")
            ax.set_ylim(0, 1)
            ax.grid(which="both")

            fig.tight_layout()
            fig.savefig(output_dir / f"probs/{key}.png")
            fig.clf()

            # TODO(kamo): Write scp
            if spc2wav is not None:
                wav = spc2wav(outs_denorm.cpu().numpy())
                sf.write(f"{output_dir}/wav/{key}.wav", wav, spc2wav.fs,
                         "PCM_16")
Exemplo n.º 20
0
def test_to_device(obj):
    """Smoke test: moving ``obj`` to the CPU must not raise."""
    to_device(obj, device="cpu")
Exemplo n.º 21
0
    def train_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        reporter: SubReporter,
        summary_writer: Optional[SummaryWriter],
        options: GANTrainerOptions,
        distributed_option: DistributedOption,
    ) -> bool:
        """Train one epoch.

        Every mini-batch is forwarded twice, once per turn ("generator" and
        "discriminator", ordered by ``options.generator_first``); each turn
        runs backward, optional gradient clipping and the optimizer step of
        the optimizer selected by the ``optim_idx`` the model returns.

        Args:
            model: Model returning a dict with "loss", "stats", "weight" and
                optionally "optim_idx".
            iterator: Yields ``(keys, batch)`` pairs.
            optimizers: Optimizers, indexed by ``optim_idx``.
            schedulers: Schedulers paired with ``optimizers``.
            scaler: Gradient scaler; mixed precision is used when not None.
            reporter: Receives statistics and timing information.
            summary_writer: Optional tensorboard writer.
            options: Trainer options.
            distributed_option: Distributed-training options.

        Returns:
            True if no step was valid (every gradient norm was non-finite).

        Raises:
            NotImplementedError: If ``accum_grad > 1`` or ``grad_noise`` is
                requested.
            RuntimeError: If the model output is not a dict or ``optim_idx``
                is malformed.
        """
        assert check_argument_types()

        grad_noise = options.grad_noise
        accum_grad = options.accum_grad
        grad_clip = options.grad_clip
        grad_clip_type = options.grad_clip_type
        log_interval = options.log_interval
        no_forward_run = options.no_forward_run
        ngpu = options.ngpu
        use_wandb = options.use_wandb
        generator_first = options.generator_first
        distributed = distributed_option.distributed

        # Check unavailable options
        # TODO(kan-bayashi): Support the use of these options
        if accum_grad > 1:
            raise NotImplementedError(
                "accum_grad > 1 is not supported in GAN-based training."
            )
        if grad_noise:
            raise NotImplementedError(
                "grad_noise is not supported in GAN-based training."
            )

        if log_interval is None:
            try:
                log_interval = max(len(iterator) // 20, 10)
            except TypeError:
                # The iterator has no length
                log_interval = 100

        model.train()
        all_steps_are_invalid = True
        # [For distributed] Because iteration counts are not always equals between
        # processes, send stop-flag to the other processes if iterator is finished
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")

        start_time = time.perf_counter()
        for iiter, (_, batch) in enumerate(
            reporter.measure_iter_time(iterator, "iter_time"), 1
        ):
            assert isinstance(batch, dict), type(batch)

            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                all_steps_are_invalid = False
                continue

            turn_start_time = time.perf_counter()
            if generator_first:
                turns = ["generator", "discriminator"]
            else:
                turns = ["discriminator", "generator"]
            for turn in turns:
                with autocast(scaler is not None):
                    with reporter.measure_time(f"{turn}_forward_time"):
                        retval = model(forward_generator=turn == "generator", **batch)

                        # Note(kamo):
                        # Only the dict type is supported for the returned
                        # value from the model
                        if isinstance(retval, dict):
                            loss = retval["loss"]
                            stats = retval["stats"]
                            weight = retval["weight"]
                            optim_idx = retval.get("optim_idx")
                            # Normalize a tensor-valued optim_idx to int
                            if optim_idx is not None and not isinstance(optim_idx, int):
                                if not isinstance(optim_idx, torch.Tensor):
                                    raise RuntimeError(
                                        "optim_idx must be int or 1dim torch.Tensor, "
                                        f"but got {type(optim_idx)}"
                                    )
                                if optim_idx.dim() >= 2:
                                    raise RuntimeError(
                                        "optim_idx must be int or 1dim torch.Tensor, "
                                        f"but got {optim_idx.dim()}dim tensor"
                                    )
                                if optim_idx.dim() == 1:
                                    for v in optim_idx:
                                        if v != optim_idx[0]:
                                            raise RuntimeError(
                                                "optim_idx must be 1dim tensor "
                                                "having same values for all entries"
                                            )
                                    optim_idx = optim_idx[0].item()
                                else:
                                    optim_idx = optim_idx.item()

                        # Any other type (e.g. tuple or list) is rejected
                        else:
                            raise RuntimeError("model output must be dict.")

                    stats = {k: v for k, v in stats.items() if v is not None}
                    if ngpu > 1 or distributed:
                        # Apply weighted averaging for loss and stats
                        loss = (loss * weight.type(loss.dtype)).sum()

                        # if distributed, this method can also apply all_reduce()
                        stats, weight = recursive_average(stats, weight, distributed)

                        # Now weight is summation over all workers
                        loss /= weight

                    if distributed:
                        # NOTE(kamo): Multiply world_size since DistributedDataParallel
                        # automatically normalizes the gradient by world_size.
                        loss *= torch.distributed.get_world_size()

                reporter.register(stats, weight)

                with reporter.measure_time(f"{turn}_backward_time"):
                    if scaler is not None:
                        # Scales loss.  Calls backward() on scaled loss
                        # to create scaled gradients.
                        # Backward passes under autocast are not recommended.
                        # Backward ops run in the same dtype autocast chose
                        # for corresponding forward ops.
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()

                if scaler is not None:
                    # Unscales the gradients of optimizer's assigned params in-place
                    for iopt, optimizer in enumerate(optimizers):
                        if optim_idx is not None and iopt != optim_idx:
                            continue
                        scaler.unscale_(optimizer)

                # TODO(kan-bayashi): Compute grad norm without clipping
                grad_norm = None
                if grad_clip > 0.0:
                    # compute the gradient norm to check if it is normal or not
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(),
                        max_norm=grad_clip,
                        norm_type=grad_clip_type,
                    )
                    # PyTorch<=1.4, clip_grad_norm_ returns float value
                    if not isinstance(grad_norm, torch.Tensor):
                        grad_norm = torch.tensor(grad_norm)

                if grad_norm is None or torch.isfinite(grad_norm):
                    all_steps_are_invalid = False
                    with reporter.measure_time(f"{turn}_optim_step_time"):
                        for iopt, (optimizer, scheduler) in enumerate(
                            zip(optimizers, schedulers)
                        ):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            if scaler is not None:
                                # scaler.step() first unscales the gradients of
                                # the optimizer's assigned params.
                                scaler.step(optimizer)
                                # Updates the scale for next iteration.
                                scaler.update()
                            else:
                                optimizer.step()
                            if isinstance(scheduler, AbsBatchStepScheduler):
                                scheduler.step()
                else:
                    logging.warning(
                        f"The grad norm is {grad_norm}. " "Skipping updating the model."
                    )
                    # Must invoke scaler.update() if unscale_() is used in the
                    # iteration to avoid the following error:
                    #   RuntimeError: unscale_() has already been called
                    #   on this optimizer since the last update().
                    # Note that if the gradient has inf/nan values,
                    # scaler.step skips optimizer.step().
                    if scaler is not None:
                        for iopt, optimizer in enumerate(optimizers):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            scaler.step(optimizer)
                            scaler.update()

                for iopt, optimizer in enumerate(optimizers):
                    # NOTE(kan-bayashi): In the case of GAN, we need to clear
                    #   the gradient of both optimizers after every update.
                    optimizer.zero_grad()

                # Register lr and train/load time[sec/step],
                # where step refers to accum_grad * mini-batch
                # optim_idx may be None (the model did not select an
                # optimizer); report every optimizer then instead of failing
                # on optimizers[None]. Keys are unchanged when it is set.
                if optim_idx is None:
                    reported_optimizers = list(enumerate(optimizers))
                else:
                    reported_optimizers = [(optim_idx, optimizers[optim_idx])]
                reporter.register(
                    {
                        f"optim{iopt}_lr{i}": pg["lr"]
                        for iopt, optimizer in reported_optimizers
                        for i, pg in enumerate(optimizer.param_groups)
                        if "lr" in pg
                    },
                )
                reporter.register(
                    {f"{turn}_train_time": time.perf_counter() - turn_start_time}
                )
                turn_start_time = time.perf_counter()

            reporter.register({"train_time": time.perf_counter() - start_time})
            start_time = time.perf_counter()

            # NOTE(kamo): Call log_message() after next()
            reporter.next()
            if iiter % log_interval == 0:
                logging.info(reporter.log_message(-log_interval))
                if summary_writer is not None:
                    reporter.tensorboard_add_scalar(summary_writer, -log_interval)
                if use_wandb:
                    reporter.wandb_log()

        else:
            # The iterator finished without ``break``: raise the stop flag
            # for the other processes.
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)

        return all_steps_are_invalid
Exemplo n.º 22
0
def test_to_device_cuda():
    """Check that to_device places a tensor nested in a dict/list on cuda:0."""
    nested = {"a": [torch.tensor([0, 1])]}
    moved = to_device(nested, "cuda")
    moved_tensor = moved["a"][0]
    assert moved_tensor.device == torch.device("cuda:0")
Exemplo n.º 23
0
    def train_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        reporter: SubReporter,
        summary_writer: Optional[SummaryWriter],
        options: TrainerOptions,
    ) -> bool:
        """Run one training epoch with exactly one optimizer.

        Args:
            model: Model called as ``model(**batch)``; it must return
                ``(loss, stats, weight)``.
            iterator: Yields ``(ids, batch)`` pairs where ``batch`` is a dict.
            optimizers: Sequence holding exactly one optimizer.
            schedulers: Sequence whose first entry is the matching scheduler
                (stepped per update only if it is an AbsBatchStepScheduler).
            scaler: Gradient scaler; mixed precision is enabled iff not None.
            reporter: Collects statistics and timing for this epoch.
            summary_writer: Optional tensorboard writer for periodic scalars.
            options: Training options (grad_noise, accum_grad, grad_clip,
                grad_clip_type, log_interval, no_forward_run, ngpu).

        Returns:
            True if no batch produced a valid step, i.e. the forward pass was
            never skipped via ``no_forward_run`` and every gradient norm
            checked at an accumulation boundary was non-finite.

        """
        assert check_argument_types()

        # Note(kamo): assumes one optimizer
        assert cls.num_optimizers == 1, cls.num_optimizers
        assert len(optimizers) == 1, len(optimizers)
        optimizer, scheduler = optimizers[0], schedulers[0]

        grad_noise = options.grad_noise
        accum_grad = options.accum_grad
        grad_clip = options.grad_clip
        grad_clip_type = options.grad_clip_type
        log_interval = options.log_interval
        no_forward_run = options.no_forward_run
        ngpu = options.ngpu

        distributed = isinstance(model,
                                 torch.nn.parallel.DistributedDataParallel)
        use_amp = scaler is not None
        device = "cuda" if ngpu > 0 else "cpu"

        if log_interval is None:
            # Iterators without a length fall back to a fixed interval.
            try:
                log_interval = max(len(iterator) // 20, 10)
            except TypeError:
                log_interval = 100

        model.train()
        all_steps_are_invalid = True
        # [For distributed] Workers may see different numbers of batches, so a
        # worker that runs out of data raises this flag for the others.
        iterator_stop = torch.tensor(0).to(device)

        start_time = time.perf_counter()
        timed_batches = reporter.measure_iter_time(iterator, "iter_time")
        for step_idx, (_, batch) in enumerate(timed_batches, 1):
            assert isinstance(batch, dict), type(batch)

            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, device)
            if no_forward_run:
                all_steps_are_invalid = False
                continue

            with autocast(use_amp):
                with reporter.measure_time("forward_time"):
                    loss, stats, weight = model(**batch)
                stats = {
                    name: value
                    for name, value in stats.items() if value is not None
                }
                if ngpu > 1 or distributed:
                    # Weighted average of the loss and stats; when
                    # distributed, recursive_average can also all_reduce.
                    loss = (loss * weight.type(loss.dtype)).sum()
                    stats, weight = recursive_average(stats, weight,
                                                      distributed)
                    # weight is now the sum over all workers
                    loss /= weight
                if distributed:
                    # NOTE(kamo): DistributedDataParallel divides gradients by
                    # world_size, so compensate for it here.
                    loss *= torch.distributed.get_world_size()

                loss /= accum_grad

            reporter.register(stats, weight)

            with reporter.measure_time("backward_time"):
                if use_amp:
                    # Backward on the scaled loss yields scaled gradients.
                    # It runs outside autocast, as recommended.
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

            if step_idx % accum_grad == 0:
                if use_amp:
                    # Bring the gradients back to their true scale in-place
                    scaler.unscale_(optimizer)

                if grad_noise:
                    # gradient noise injection
                    add_gradient_noise(
                        model,
                        reporter.get_total_count(),
                        duration=100,
                        eta=1.0,
                        scale_factor=0.55,
                    )

                # Clip and obtain the norm, which tells us if the step is sane
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=grad_clip,
                    norm_type=grad_clip_type,
                )
                # PyTorch<=1.4 returns a plain float here
                if not isinstance(grad_norm, torch.Tensor):
                    grad_norm = torch.tensor(grad_norm)

                if torch.isfinite(grad_norm):
                    all_steps_are_invalid = False
                    with reporter.measure_time("optim_step_time"):
                        if use_amp:
                            # step() applies the update unless inf/nan were
                            # found; update() adapts the scale afterwards.
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            optimizer.step()
                    if isinstance(scheduler, AbsBatchStepScheduler):
                        scheduler.step()
                else:
                    logging.warning(
                        f"The grad norm is {grad_norm}. Skipping updating the model."
                    )
                    # unscale_() was already called in this iteration, so
                    # scaler.update() must follow to avoid
                    #   RuntimeError: unscale_() has already been called
                    #   on this optimizer since the last update().
                    # scaler.step skips optimizer.step() on inf/nan gradients.
                    if use_amp:
                        scaler.step(optimizer)
                        scaler.update()
                optimizer.zero_grad()

                # Report lr and train time[sec/step], where one step covers
                # accum_grad mini-batches
                step_report = {
                    f"lr_{i}": pg["lr"]
                    for i, pg in enumerate(optimizer.param_groups)
                    if "lr" in pg
                }
                step_report["train_time"] = time.perf_counter() - start_time
                reporter.register(step_report)
                start_time = time.perf_counter()

            # NOTE(kamo): Call log_message() after next()
            reporter.next()
            if step_idx % log_interval == 0:
                logging.info(reporter.log_message(-log_interval))
                if summary_writer is not None:
                    reporter.tensorboard_add_scalar(summary_writer,
                                                    -log_interval)

        else:
            # Data exhausted without an external stop: notify other workers.
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)

        return all_steps_are_invalid
Exemplo n.º 24
0
    def train_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        reporter: SubReporter,
        summary_writer,
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> bool:
        """Train the model for one epoch, supporting several optimizers.

        The model may return either a ``(loss, stats, weight)`` sequence or a
        dict with the keys ``"loss"``, ``"stats"``, ``"weight"`` and,
        optionally, ``"optim_idx"``.  When ``optim_idx`` is given, only the
        optimizer/scheduler at that index is stepped for the batch; otherwise
        all of them are stepped.

        Args:
            model: Model called as ``model(**batch)``; the utterance ids are
                added to the batch under the key ``"utt_id"``.
            iterator: Yields ``(utt_id, batch)`` pairs where ``batch`` is a dict.
            optimizers: Optimizers, indexed by ``optim_idx``.
            schedulers: Schedulers paired with ``optimizers``; only
                AbsBatchStepScheduler instances are stepped here.
            scaler: Gradient scaler; mixed precision is enabled iff not None.
            reporter: Collects statistics and timing for this epoch.
            summary_writer: Optional tensorboard writer (scalars and graph).
            options: Training options (grad_noise, accum_grad, grad_clip,
                grad_clip_type, log_interval, no_forward_run, ngpu,
                use_wandb, create_graph_in_tensorboard).
            distributed_option: Its ``distributed`` flag enables the
                cross-process synchronisation.

        Returns:
            True if no batch produced a valid step, i.e. the forward pass was
            never skipped via ``no_forward_run`` and every gradient norm
            checked at an accumulation boundary was non-finite.

        Raises:
            RuntimeError: If ``optim_idx`` is neither an int nor a tensor with
                at most one dimension holding a single repeated value.

        """
        assert check_argument_types()

        grad_noise = options.grad_noise
        accum_grad = options.accum_grad
        grad_clip = options.grad_clip
        grad_clip_type = options.grad_clip_type
        log_interval = options.log_interval
        no_forward_run = options.no_forward_run
        ngpu = options.ngpu
        use_wandb = options.use_wandb
        create_graph_in_tensorboard = options.create_graph_in_tensorboard
        distributed = distributed_option.distributed

        if log_interval is None:
            # Iterators without a length fall back to a fixed interval.
            try:
                log_interval = max(len(iterator) // 20, 10)
            except TypeError:
                log_interval = 100

        model.train()
        all_steps_are_invalid = True
        # [For distributed] Because iteration counts are not always equals between
        # processes, send stop-flag to the other processes if iterator is finished
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")

        start_time = time.perf_counter()
        for iiter, (utt_id, batch) in enumerate(
                reporter.measure_iter_time(iterator, "iter_time"), 1):
            assert isinstance(batch, dict), type(batch)

            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch["utt_id"] = utt_id

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                all_steps_are_invalid = False
                continue

            if (create_graph_in_tensorboard and iiter == 1
                    and summary_writer is not None):
                # In distributed mode trace the wrapped module rather than the
                # wrapper itself (presumably DistributedDataParallel).
                # A missing ``module`` attribute yields None and is reported
                # below instead of raising AttributeError.
                if distributed:
                    _model = getattr(model, "module", None)
                else:
                    _model = model

                if _model is not None:
                    try:
                        _args = kwargs2args(_model.forward, batch)
                    except (ValueError, TypeError):
                        logging.warning(
                            "inpect.signature() is failed for the model. "
                            "The graph can't be added for tensorboard.")
                    else:
                        # Best effort: a failed trace must not stop training.
                        try:
                            summary_writer.add_graph(
                                _model, _args, use_strict_trace=False)
                        except Exception:
                            logging.warning(
                                "summary_writer.add_graph() "
                                "is failed for the model. "
                                "The graph can't be added for tensorboard."
                            )
                        del _args
                else:
                    logging.warning(
                        "model.module is not found (This should be a bug.)"
                    )
                del _model

            with autocast(scaler is not None):
                with reporter.measure_time("forward_time"):
                    retval = model(**batch)

                    # Note(kamo):
                    # Supporting two patterns for the returned value from the model
                    #   a. dict type
                    if isinstance(retval, dict):
                        loss = retval["loss"]
                        stats = retval["stats"]
                        weight = retval["weight"]
                        optim_idx = retval.get("optim_idx")
                        if optim_idx is not None and not isinstance(
                                optim_idx, int):
                            # Normalise a tensor-valued optim_idx to an int.
                            if not isinstance(optim_idx, torch.Tensor):
                                raise RuntimeError(
                                    "optim_idx must be int or 1dim torch.Tensor, "
                                    f"but got {type(optim_idx)}")
                            if optim_idx.dim() >= 2:
                                raise RuntimeError(
                                    "optim_idx must be int or 1dim torch.Tensor, "
                                    f"but got {optim_idx.dim()}dim tensor")
                            if optim_idx.dim() == 1:
                                for v in optim_idx:
                                    if v != optim_idx[0]:
                                        raise RuntimeError(
                                            "optim_idx must be 1dim tensor "
                                            "having same values for all entries"
                                        )
                                optim_idx = optim_idx[0].item()
                            else:
                                optim_idx = optim_idx.item()

                    #   b. tuple or list type
                    else:
                        loss, stats, weight = retval
                        optim_idx = None

                stats = {k: v for k, v in stats.items() if v is not None}
                if ngpu > 1 or distributed:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()

                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight,
                                                      distributed)

                    # Now weight is summation over all workers
                    loss /= weight
                if distributed:
                    # NOTE(kamo): Multiply world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= torch.distributed.get_world_size()

                loss /= accum_grad

            reporter.register(stats, weight)

            with reporter.measure_time("backward_time"):
                if scaler is not None:
                    # Scales loss.  Calls backward() on scaled loss
                    # to create scaled gradients.
                    # Backward passes under autocast are not recommended.
                    # Backward ops run in the same dtype autocast chose
                    # for corresponding forward ops.
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

            if iiter % accum_grad == 0:
                if scaler is not None:
                    # Unscales the gradients of optimizer's assigned params in-place
                    for iopt, optimizer in enumerate(optimizers):
                        if optim_idx is not None and iopt != optim_idx:
                            continue
                        scaler.unscale_(optimizer)

                # gradient noise injection
                if grad_noise:
                    add_gradient_noise(
                        model,
                        reporter.get_total_count(),
                        duration=100,
                        eta=1.0,
                        scale_factor=0.55,
                    )

                # compute the gradient norm to check if it is normal or not
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=grad_clip,
                    norm_type=grad_clip_type,
                )
                # PyTorch<=1.4, clip_grad_norm_ returns float value
                if not isinstance(grad_norm, torch.Tensor):
                    grad_norm = torch.tensor(grad_norm)

                if not torch.isfinite(grad_norm):
                    logging.warning(
                        f"The grad norm is {grad_norm}. Skipping updating the model."
                    )

                    # Must invoke scaler.update() if unscale_() is used in the iteration
                    # to avoid the following error:
                    #   RuntimeError: unscale_() has already been called
                    #   on this optimizer since the last update().
                    # Note that if the gradient has inf/nan values,
                    # scaler.step skips optimizer.step().
                    if scaler is not None:
                        scaler_stepped = False
                        for iopt, optimizer in enumerate(optimizers):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            scaler.step(optimizer)
                            scaler_stepped = True
                        # update() resets the scaler's per-optimizer state, so
                        # it must run once, after every selected optimizer
                        # has been stepped.
                        if scaler_stepped:
                            scaler.update()

                else:
                    all_steps_are_invalid = False
                    with reporter.measure_time("optim_step_time"):
                        scaler_stepped = False
                        for iopt, (optimizer, scheduler) in enumerate(
                                zip(optimizers, schedulers)):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            if scaler is not None:
                                # scaler.step() skips optimizer.step() if the
                                # gradients contain inf/nan values.
                                scaler.step(optimizer)
                                scaler_stepped = True
                            else:
                                optimizer.step()
                            if isinstance(scheduler, AbsBatchStepScheduler):
                                scheduler.step()
                        # Updates the scale for next iteration.  Calling it
                        # between two scaler.step() calls would make the second
                        # one unscale its gradients a second time.
                        if scaler_stepped:
                            scaler.update()

                # Clear the gradients of every optimizer, not only the selected
                # one: backward() may also have accumulated gradients in
                # parameters owned by the other optimizers (e.g. GAN training),
                # and those must not leak into their next update.
                for optimizer in optimizers:
                    optimizer.zero_grad()

                # Register lr and train/load time[sec/step],
                # where step refers to accum_grad * mini-batch
                reporter.register(
                    dict(
                        {
                            f"optim{i}_lr{j}": pg["lr"]
                            for i, optimizer in enumerate(optimizers)
                            for j, pg in enumerate(optimizer.param_groups)
                            if "lr" in pg
                        },
                        train_time=time.perf_counter() - start_time,
                    ), )
                start_time = time.perf_counter()

            # NOTE(kamo): Call log_message() after next()
            reporter.next()
            if iiter % log_interval == 0:
                logging.info(reporter.log_message(-log_interval))
                if summary_writer is not None:
                    reporter.tensorboard_add_scalar(summary_writer,
                                                    -log_interval)
                if use_wandb:
                    reporter.wandb_log()

        else:
            # Data exhausted without an external stop: notify other workers.
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
        return all_steps_are_invalid
Exemplo n.º 25
0
    def __call__(self,
                 speech: Union[torch.Tensor, np.ndarray],
                 fs: int = 8000) -> np.ndarray:
        """Run speaker diarization on a batch of waveforms.

        Args:
            speech: Input speech data (Batch, Nsamples [, Channels])
            fs: Sample rate; together with ``self.segment_size`` it gives the
                segment length in samples.
        Returns:
            Sigmoid of the decoder output as a numpy array whose first
            dimension is Batch and whose third dimension is ``self.num_spk``.
            With segment-wise processing the per-segment outputs are
            concatenated along the second dimension.

        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.as_tensor(speech)

        assert speech.dim() > 1, speech.size()
        batch_size = speech.size(0)
        # Plain int, so the comparisons and slices below never need to read
        # a value back from a (possibly CUDA) tensor.
        num_samples = speech.size(1)
        speech = speech.to(getattr(torch, self.dtype))
        # lengths: (B,)
        lengths = speech.new_full([batch_size],
                                  dtype=torch.long,
                                  fill_value=num_samples)

        # a. To device
        speech = to_device(speech, device=self.device)
        lengths = to_device(lengths, device=self.device)

        if self.segmenting and num_samples > self.segment_size * fs:
            # Segment-wise speaker diarization
            num_segments = int(
                np.ceil(num_samples / (self.segment_size * fs)))
            seg_len = int(self.segment_size * fs)
            pad_shape = speech[:, :seg_len].shape
            diarized_wavs = []
            range_ = trange if self.show_progressbar else range
            for i in range_(num_segments):
                st = int(i * self.segment_size * fs)
                en = st + seg_len
                if en >= num_samples:
                    # Last segment: copy the remaining samples into a
                    # zero-initialised buffer of full segment length.
                    en = num_samples
                    speech_seg = speech.new_zeros(pad_shape)
                    speech_seg[:, :en - st] = speech[:, st:en]
                else:
                    speech_seg = speech[:, st:en]  # B x T [x C]

                # Every segment, including the padded last one, is passed
                # with the full segment length.
                lengths_seg = speech.new_full([batch_size],
                                              dtype=torch.long,
                                              fill_value=seg_len)
                # b. Diarization Forward
                encoder_out, encoder_out_lens = self.diar_model.encode(
                    speech_seg, lengths_seg)
                spk_prediction = self.diar_model.decoder(
                    encoder_out, encoder_out_lens)

                # List[torch.Tensor(B, T, num_spks)]
                diarized_wavs.append(spk_prediction)

            spk_prediction = torch.cat(diarized_wavs, dim=1)
        else:
            # b. Diarization Forward
            encoder_out, encoder_out_lens = self.diar_model.encode(
                speech, lengths)
            spk_prediction = self.diar_model.decoder(encoder_out,
                                                     encoder_out_lens)

        assert spk_prediction.size(2) == self.num_spk, (
            spk_prediction.size(2),
            self.num_spk,
        )
        assert spk_prediction.size(0) == batch_size, (
            spk_prediction.size(0),
            batch_size,
        )
        # Apply the sigmoid on the CPU copy of the decoder output.
        spk_prediction = spk_prediction.cpu().numpy()
        spk_prediction = 1 / (1 + np.exp(-spk_prediction))

        return spk_prediction
Exemplo n.º 26
0
    def apply_frontend(
            self,
            speech: torch.Tensor,
            is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward frontend.

        Samples kept from the previous call (``self.frontend_cache``) are
        prepended to ``speech``.  For a non-final chunk only whole frames are
        processed and the trailing samples are cached for the next call; the
        frames at the chunk borders are trimmed from the features.

        Args:
            speech: Speech data. (S)
            is_final: Whether speech corresponds to the final (or only) chunk of data.

        Returns:
            feats: Features sequence. (1, T_in, F)
            feats_lengths: Features sequence length. (1,)

        """
        if self.frontend_cache is not None:
            speech = torch.cat(
                [self.frontend_cache["waveform_buffer"], speech], dim=0)

        if is_final:
            if self.streaming and speech.size(0) < self.last_chunk_length:
                # new_zeros keeps the padding on the same device (and dtype)
                # as speech, so the concatenation also works off-CPU.
                pad = speech.new_zeros(self.last_chunk_length -
                                       speech.size(0))
                speech = torch.cat([speech, pad], dim=0)

            speech_to_process = speech
            waveform_buffer = None
        else:
            # Window size minus hop: the samples needed in addition to
            # n_frames * hop_length to form n_frames windows.
            overlap = self.frontend_window_size - self.hop_length
            n_frames, n_residual = divmod(
                speech.size(0) - overlap, self.hop_length)

            speech_to_process = speech.narrow(
                0, 0, overlap + n_frames * self.hop_length)

            # Keep the tail (overlap plus unprocessed residual) for next call.
            waveform_buffer = speech.narrow(
                0,
                speech.size(0) - overlap - n_residual,
                overlap + n_residual,
            ).clone()

        speech_to_process = speech_to_process.unsqueeze(0).to(
            getattr(torch, self.dtype))
        lengths = speech_to_process.new_full(
            [1], dtype=torch.long, fill_value=speech_to_process.size(1))
        batch = {"speech": speech_to_process, "speech_lengths": lengths}
        batch = to_device(batch, device=self.device)

        feats, feats_lengths = self.asr_model._extract_feats(**batch)
        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(
                feats, feats_lengths)

        has_cache = self.frontend_cache is not None
        if has_cache or not is_final:
            # Number of border frames to drop on one side.
            trim = math.ceil(
                math.ceil(self.frontend_window_size / self.hop_length) / 2)
            # Leading frames are dropped when a cached buffer was prepended;
            # trailing frames are dropped unless this is the last chunk.
            left_trim = trim if has_cache else 0
            right_trim = 0 if is_final else trim
            feats = feats.narrow(
                1, left_trim, feats.size(1) - left_trim - right_trim)

        feats_lengths = feats.new_full([1],
                                       dtype=torch.long,
                                       fill_value=feats.size(1))

        if is_final:
            self.frontend_cache = None
        else:
            self.frontend_cache = {"waveform_buffer": waveform_buffer}

        return feats, feats_lengths