def test_batchfy_hyp():
    vocab_size = 5
    eos = -1
    # simplest beam search
    beam = BatchBeamSearch(
        beam_size=3,
        vocab_size=vocab_size,
        weights={"a": 0.5, "b": 0.5},
        scorers={"a": LengthBonus(vocab_size), "b": LengthBonus(vocab_size)},
        pre_beam_score_key="a",
        sos=eos,
        eos=eos,
    )
    hs = [
        Hypothesis(
            yseq=torch.tensor([0, 1, 2]),
            score=torch.tensor(0.15),
            scores={"a": torch.tensor(0.1), "b": torch.tensor(0.2)},
            states={"a": 1, "b": 2},
        ),
        Hypothesis(
            yseq=torch.tensor([0, 1]),
            score=torch.tensor(0.1),
            scores={"a": torch.tensor(0.0), "b": torch.tensor(0.2)},
            states={"a": 3, "b": 4},
        ),
    ]
    bs = beam.batchfy(hs)
    assert torch.all(bs.yseq == torch.tensor([[0, 1, 2], [0, 1, eos]]))
    assert torch.all(bs.score == torch.tensor([0.15, 0.1]))
    assert torch.all(bs.scores["a"] == torch.tensor([0.1, 0.0]))
    assert torch.all(bs.scores["b"] == torch.tensor([0.2, 0.2]))
    assert bs.states["a"] == [1, 3]
    assert bs.states["b"] == [2, 4]

    us = beam.unbatchfy(bs)
    for i in range(len(hs)):
        assert us[i].yseq.tolist() == hs[i].yseq.tolist()
        assert us[i].score == hs[i].score
        assert us[i].scores == hs[i].scores
        assert us[i].states == hs[i].states
def init_hyp(self, x: torch.Tensor) -> BatchHypothesis:
    """Get an initial hypothesis data.

    Args:
        x (torch.Tensor): The encoder output feature

    Returns:
        Hypothesis: The initial hypothesis.

    """
    init_states = dict()
    init_scores = dict()
    for k, d in self.scorers.items():
        init_states[k] = d.batch_init_state(x)
        init_scores[k] = 0.0
    return self.batchfy(
        [
            Hypothesis(
                score=0.0,
                scores=init_scores,
                states=init_states,
                yseq=torch.tensor([self.sos], device=x.device),
            )
        ]
    )
def forward(self, enc_out: torch.Tensor) -> List[Hypothesis]:
    """Perform Mask-CTC inference"""
    # greedy ctc outputs
    enc_out = enc_out.unsqueeze(0)
    ctc_probs, ctc_ids = torch.exp(self.ctc.log_softmax(enc_out)).max(dim=-1)
    y_hat = torch.stack([x[0] for x in groupby(ctc_ids[0])])
    y_idx = torch.nonzero(y_hat != 0).squeeze(-1)

    logging.info("ctc:{}".format(self.ids2text(y_hat[y_idx].tolist())))

    # calculate token-level ctc probabilities by taking
    # the maximum probability of consecutive frames with
    # the same ctc symbols
    probs_hat = []
    cnt = 0
    for i, y in enumerate(y_hat.tolist()):
        probs_hat.append(-1)
        while cnt < ctc_ids.shape[1] and y == ctc_ids[0][cnt]:
            if probs_hat[i] < ctc_probs[0][cnt]:
                probs_hat[i] = ctc_probs[0][cnt].item()
            cnt += 1
    probs_hat = torch.from_numpy(numpy.array(probs_hat))

    # mask ctc outputs based on ctc probabilities
    p_thres = self.threshold_probability
    mask_idx = torch.nonzero(probs_hat[y_idx] < p_thres).squeeze(-1)
    confident_idx = torch.nonzero(probs_hat[y_idx] >= p_thres).squeeze(-1)
    mask_num = len(mask_idx)

    y_in = torch.zeros(1, len(y_idx), dtype=torch.long) + self.mask_token
    y_in[0][confident_idx] = y_hat[y_idx][confident_idx]

    logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))

    # iterative decoding
    if not mask_num == 0:
        K = self.n_iterations
        num_iter = K if mask_num >= K and K > 0 else mask_num

        for t in range(num_iter - 1):
            pred, _ = self.mlm(enc_out, [enc_out.size(1)], y_in, [y_in.size(1)])
            pred_score, pred_id = pred[0][mask_idx].max(dim=-1)
            cand = torch.topk(pred_score, mask_num // num_iter, -1)[1]
            y_in[0][mask_idx[cand]] = pred_id[cand]
            mask_idx = torch.nonzero(y_in[0] == self.mask_token).squeeze(-1)

            logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))

        # predict leftover masks (|masks| < mask_num // num_iter)
        pred, _ = self.mlm(enc_out, [enc_out.size(1)], y_in, [y_in.size(1)])
        y_in[0][mask_idx] = pred[0][mask_idx].argmax(dim=-1)

        logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))

    # pad with mask tokens to ensure compatibility with sos/eos tokens
    yseq = torch.tensor(
        [self.mask_token] + y_in.tolist()[0] + [self.mask_token], device=y_in.device
    )

    return Hypothesis(yseq=yseq)
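A minimal sketch (toy values, not part of ESPnet) of the token-level probability collapse in the loop above: consecutive frames with the same greedy CTC id are grouped, and each collapsed token keeps the largest frame probability. The tensors below are made up for illustration only.

import torch
from itertools import groupby

ctc_probs = torch.tensor([[0.9, 0.6, 0.8, 0.7, 0.95]])  # per-frame max probability
ctc_ids = torch.tensor([[3, 3, 0, 5, 5]])                # per-frame argmax id (0 = blank)

# collapse repeated frame ids into tokens, as in the forward() above
y_hat = torch.stack([x[0] for x in groupby(ctc_ids[0])])  # -> tensor([3, 0, 5])

probs_hat = []
cnt = 0
for i, y in enumerate(y_hat.tolist()):
    probs_hat.append(-1)
    while cnt < ctc_ids.shape[1] and y == ctc_ids[0][cnt]:
        if probs_hat[i] < ctc_probs[0][cnt]:
            probs_hat[i] = ctc_probs[0][cnt].item()
        cnt += 1

print(probs_hat)  # ~[0.9, 0.8, 0.95]: token 3 keeps max(0.9, 0.6), token 5 keeps max(0.7, 0.95)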
def _select(self, hyps: BatchHypothesis, i: int) -> Hypothesis:
    return Hypothesis(
        yseq=hyps.yseq[i, : hyps.length[i]],
        score=hyps.score[i],
        scores={k: v[i] for k, v in hyps.scores.items()},
        states={
            k: self.scorers[k].select_state(v, i) for k, v in hyps.states.items()
        },
    )
def unbatchfy(self, batch_hyps: BatchHypothesis) -> List[Hypothesis]:
    """Revert batch to list."""
    return [
        Hypothesis(
            yseq=batch_hyps.yseq[i][: batch_hyps.length[i]],
            score=batch_hyps.score[i],
            scores={k: batch_hyps.scores[k][i] for k in self.scorers},
            states={
                k: v.select_state(batch_hyps.states[k], i)
                for k, v in self.scorers.items()
            },
        )
        for i in range(len(batch_hyps.length))
    ]
def extend(self, x: torch.Tensor, hyps: Hypothesis) -> List[Hypothesis]:
    """Extend probabilities and states with more encoded chunks.

    Args:
        x (torch.Tensor): The extended encoder output feature
        hyps (Hypothesis): Current list of hypothesis

    Returns:
        Hypothesis: The extended hypothesis

    """
    for k, d in self.scorers.items():
        if hasattr(d, "extend_prob"):
            d.extend_prob(x)
        if hasattr(d, "extend_state"):
            hyps.states[k] = d.extend_state(hyps.states[k])
def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    asr_train_config: str,
    asr_model_file: str,
    model_tag: Optional[str],
    token_type: Optional[str],
    bpemodel: Optional[str],
    allow_variable_data_keys: bool,
    maskctc_n_iterations: int,
    maskctc_threshold_probability: float,
):
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        batch_size=batch_size,
        dtype=dtype,
        maskctc_n_iterations=maskctc_n_iterations,
        maskctc_threshold_probability=maskctc_threshold_probability,
    )
    speech2text = Speech2Text.from_pretrained(
        model_tag=model_tag,
        **speech2text_kwargs,
    )

    # 3. Build data-iterator
    loader = ASRTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
        collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 7. Start for-loop
    with DatadirWriter(output_dir) as writer:
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["<space>"], [2], hyp]]

            # Only supporting batch_size==1
            key = keys[0]
            (text, token, token_int, hyp) = results[0]

            # Create a directory: outdir/{n}best_recog
            ibest_writer = writer["1best_recog"]

            # Write the result to each file
            ibest_writer["token"][key] = " ".join(token)
            ibest_writer["token_int"][key] = " ".join(map(str, token_int))
            ibest_writer["score"][key] = str(hyp.score)

            if text is not None:
                ibest_writer["text"][key] = text
def inference(
    output_dir: str,
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    dtype: str,
    beam_size: int,
    ngpu: int,
    seed: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    nbest: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    asr_train_config: str,
    asr_model_file: str,
    lm_train_config: Optional[str],
    lm_file: Optional[str],
    word_lm_train_config: Optional[str],
    word_lm_file: Optional[str],
    token_type: Optional[str],
    bpemodel: Optional[str],
    allow_variable_data_keys: bool,
    sim_chunk_length: int,
    disable_repetition_detection: bool,
    encoded_feat_length_limit: int,
    decoder_text_length_limit: int,
):
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text = Speech2TextStreaming(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        penalty=penalty,
        nbest=nbest,
        disable_repetition_detection=disable_repetition_detection,
        decoder_text_length_limit=decoder_text_length_limit,
        encoded_feat_length_limit=encoded_feat_length_limit,
    )

    # 3. Build data-iterator
    loader = ASRTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
        collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 7. Start for-loop
    # FIXME(kamo): The output format should be discussed about
    with DatadirWriter(output_dir) as writer:
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            assert len(batch.keys()) == 1

            try:
                if sim_chunk_length == 0:
                    # N-best list of (text, token, token_int, hyp_object)
                    results = speech2text(**batch)
                else:
                    speech = batch["speech"]
                    for i in range(len(speech) // sim_chunk_length):
                        speech2text(
                            speech=speech[i * sim_chunk_length : (i + 1) * sim_chunk_length],
                            is_final=False,
                        )
                    results = speech2text(
                        speech[(i + 1) * sim_chunk_length : len(speech)], is_final=True
                    )
            except TooShortUttError as e:
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["<space>"], [2], hyp]] * nbest

            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                ibest_writer = writer[f"{n}best_recog"]

                # Write the result to each file
                ibest_writer["token"][key] = " ".join(token)
                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                ibest_writer["score"][key] = str(hyp.score)

                if text is not None:
                    ibest_writer["text"][key] = text
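A toy illustration (made-up array, not ESPnet code) of how `sim_chunk_length` slices an utterance in the streaming loop above: every full chunk is fed with `is_final=False`, and the leftover tail triggers the final decoding pass with `is_final=True`.

import numpy as np

speech = np.arange(10)      # stand-in for a waveform of 10 samples
sim_chunk_length = 3

for i in range(len(speech) // sim_chunk_length):
    chunk = speech[i * sim_chunk_length : (i + 1) * sim_chunk_length]
    print("chunk", chunk)   # [0 1 2], [3 4 5], [6 7 8] -> fed with is_final=False
print("final", speech[(i + 1) * sim_chunk_length :])  # [9] -> fed with is_final=True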
def search(self, running_hyps: BatchHypothesis, x: torch.Tensor) -> BatchHypothesis:
    """Search new tokens for running hypotheses and encoded speech x.

    Args:
        running_hyps (BatchHypothesis): Running hypotheses on beam
        x (torch.Tensor): Encoded speech feature (T, D)

    Returns:
        BatchHypothesis: Best sorted hypotheses

    """
    n_batch = len(running_hyps)
    part_ids = None  # no pre-beam
    # batch scoring
    weighted_scores = torch.zeros(
        n_batch, self.n_vocab, dtype=x.dtype, device=x.device
    )
    scores, states = self.score_full(running_hyps, x.expand(n_batch, *x.shape))
    for k in self.full_scorers:
        weighted_scores += self.weights[k] * scores[k]
    # partial scoring
    if self.do_pre_beam:
        pre_beam_scores = (
            weighted_scores
            if self.pre_beam_score_key == "full"
            else scores[self.pre_beam_score_key]
        )
        part_ids = torch.topk(pre_beam_scores, self.pre_beam_size, dim=-1)[1]
    # NOTE(takaaki-hori): Unlike BeamSearch, we assume that score_partial returns
    # full-size score matrices, which has non-zero scores for part_ids and zeros
    # for others.
    part_scores, part_states = self.score_partial(running_hyps, part_ids, x)
    for k in self.part_scorers:
        weighted_scores += self.weights[k] * part_scores[k]
    # add previous hyp scores
    weighted_scores += running_hyps.score.to(
        dtype=x.dtype, device=x.device
    ).unsqueeze(1)

    # TODO(karita): do not use list. use batch instead
    # see also https://github.com/espnet/espnet/pull/1402#discussion_r354561029
    # update hyps
    best_hyps = []
    prev_hyps = self.unbatchfy(running_hyps)
    for (
        full_prev_hyp_id,
        full_new_token_id,
        part_prev_hyp_id,
        part_new_token_id,
    ) in zip(*self.batch_beam(weighted_scores, part_ids)):
        prev_hyp = prev_hyps[full_prev_hyp_id]
        best_hyps.append(
            Hypothesis(
                score=weighted_scores[full_prev_hyp_id, full_new_token_id],
                yseq=self.append_token(prev_hyp.yseq, full_new_token_id),
                scores=self.merge_scores(
                    prev_hyp.scores,
                    {k: v[full_prev_hyp_id] for k, v in scores.items()},
                    full_new_token_id,
                    {k: v[part_prev_hyp_id] for k, v in part_scores.items()},
                    part_new_token_id,
                ),
                states=self.merge_states(
                    {
                        k: self.full_scorers[k].select_state(v, full_prev_hyp_id)
                        for k, v in states.items()
                    },
                    {
                        k: self.part_scorers[k].select_state(
                            v, part_prev_hyp_id, part_new_token_id
                        )
                        for k, v in part_states.items()
                    },
                    part_new_token_id,
                ),
            )
        )
    return self.batchfy(best_hyps)
def search(self, running_hyps: BatchHypothesis, x: torch.Tensor) -> BatchHypothesis:
    """Search new tokens for running hypotheses and encoded speech x.

    Args:
        running_hyps (BatchHypothesis): Running hypotheses on beam
        x (torch.Tensor): Encoded speech feature (T, D)

    Returns:
        BatchHypothesis: Best sorted hypotheses

    """
    n_batch = len(running_hyps)

    # batch scoring
    scores, states = self.score_full(running_hyps, x.expand(n_batch, *x.shape))
    if self.do_pre_beam:
        part_ids = torch.topk(
            scores[self.pre_beam_score_key], self.pre_beam_size, dim=-1
        )[1]
    else:
        part_ids = torch.arange(self.n_vocab, device=x.device).expand(
            n_batch, self.n_vocab
        )
    part_scores, part_states = self.score_partial(running_hyps, part_ids, x)

    # weighted sum scores
    weighted_scores = torch.zeros(
        n_batch, self.n_vocab, dtype=x.dtype, device=x.device
    )
    for k in self.full_scorers:
        weighted_scores += self.weights[k] * scores[k]
    for k in self.part_scorers:
        weighted_scores[part_ids] += self.weights[k] * part_scores[k]
    weighted_scores += running_hyps.score.to(
        dtype=x.dtype, device=x.device
    ).unsqueeze(1)

    # TODO(karita): do not use list. use batch instead
    # see also https://github.com/espnet/espnet/pull/1402#discussion_r354561029
    # update hyps
    best_hyps = []
    prev_hyps = self.unbatchfy(running_hyps)
    for (
        full_prev_hyp_id,
        full_new_token_id,
        part_prev_hyp_id,
        part_new_token_id,
    ) in zip(*self.batch_beam(weighted_scores, part_ids)):
        prev_hyp = prev_hyps[full_prev_hyp_id]
        best_hyps.append(
            Hypothesis(
                score=weighted_scores[full_prev_hyp_id, full_new_token_id],
                yseq=self.append_token(prev_hyp.yseq, full_new_token_id),
                scores=self.merge_scores(
                    prev_hyp.scores,
                    {k: v[full_prev_hyp_id] for k, v in scores.items()},
                    full_new_token_id,
                    {k: v[part_prev_hyp_id] for k, v in part_scores.items()},
                    part_new_token_id,
                ),
                states=self.merge_states(
                    {
                        k: self.full_scorers[k].select_state(v, full_prev_hyp_id)
                        for k, v in states.items()
                    },
                    {
                        k: self.part_scorers[k].select_state(v, part_prev_hyp_id)
                        for k, v in part_states.items()
                    },
                    part_new_token_id,
                ),
            )
        )
    return self.batchfy(best_hyps)
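To make the pre-beam step concrete, here is a toy sketch (made-up scores, not ESPnet code) of what `torch.topk(scores[self.pre_beam_score_key], self.pre_beam_size, dim=-1)[1]` produces: one row of candidate token ids per running hypothesis, to which the partial scorers are then restricted.

import torch

pre_beam_size = 2
# pretend pre-beam scores for 2 running hypotheses over a vocabulary of 4 tokens
scores = torch.tensor([[0.1, 0.7, 0.2, 0.9],
                       [0.5, 0.4, 0.8, 0.1]])  # shape (n_batch, n_vocab)

part_ids = torch.topk(scores, pre_beam_size, dim=-1)[1]
print(part_ids)  # tensor([[3, 1], [2, 0]]) -- per-hypothesis candidate token ids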