def change_decoding_strategy(self, decoding_cfg: DictConfig):
    """
    Rebuild the RNNT decoding helper and WER metric around a (possibly new) decoding config.

    Args:
        decoding_cfg: Decoding config to switch to, e.g. to move from greedy to beam decoding.
            When ``None`` is passed, the model's current ``cfg.decoding`` is reused.

    Side effects:
        Replaces ``self.decoding`` and ``self.wer``. If ``self.joint.fuse_loss_wer`` is set, it
        re-registers ``self.loss`` and the new ``self.wer`` on the joint. Finally, it stores
        ``decoding_cfg`` into ``self.cfg.decoding`` and logs the resulting config.
    """
    if decoding_cfg is None:
        logging.info("No `decoding_cfg` passed when changing decoding strategy, using internal config")
        # Fall back to the strategy the model is currently configured with.
        decoding_cfg = self.cfg.decoding

    self.decoding = RNNTDecoding(
        decoding_cfg=decoding_cfg,
        decoder=self.decoder,
        joint=self.joint,
        vocabulary=self.joint.vocabulary,
    )

    # The replacement metric inherits the options of the metric it replaces.
    previous_wer = self.wer
    self.wer = RNNTWER(
        decoding=self.decoding,
        batch_dim_index=previous_wer.batch_dim_index,
        use_cer=previous_wer.use_cer,
        log_prediction=previous_wer.log_prediction,
        dist_sync_on_step=True,
    )

    # A joint that fuses loss/WER computation must be handed the freshly built metric.
    if self.joint.fuse_loss_wer:
        self.joint.set_loss(self.loss)
        self.joint.set_wer(self.wer)

    # Persist the chosen strategy in the model config.
    with open_dict(self.cfg.decoding):
        self.cfg.decoding = decoding_cfg

    logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}")
def get_wer_rnnt(self, prediction: str, reference: str, batch_dim_index: int, test_wer_bpe: bool):
    """
    Score one hypothesis against one reference with an RNNT WER metric driven by a mocked decoder.

    The decoding object is a ``Mock``. Its ``rnnt_decoder_predictions_tensor`` always returns
    ``([prediction], None)``, so the metric never runs a real decoder.

    Args:
        prediction: Hypothesis text returned by the mocked decoder.
        reference: Ground-truth text, turned into the target tensor.
        batch_dim_index: Forwarded to the metric. When it is > 0, dims 0 and 1 of the target
            tensor are transposed in place.
        test_wer_bpe: If true, use ``RNNTBPEWER`` with a copy of ``self.char_tokenizer``.
            Otherwise use ``RNNTWER`` with ``self.vocabulary``.

    Returns:
        The first value produced by ``wer.compute()``, as a Python scalar.
    """
    predictions_mock = Mock(return_value=([prediction], None))

    if test_wer_bpe:
        decoding = Mock(
            blank_id=self.char_tokenizer.tokenizer.vocab_size,
            tokenizer=deepcopy(self.char_tokenizer),
            rnnt_decoder_predictions_tensor=predictions_mock,
            decode_tokens_to_str=self.char_tokenizer.ids_to_text,
        )
        metric_cls = RNNTBPEWER
    else:
        decoding = Mock(
            blank_id=len(self.vocabulary),
            labels_map=self.vocabulary.copy(),
            rnnt_decoder_predictions_tensor=predictions_mock,
            decode_tokens_to_str=self.decode_token_to_str_with_vocabulary_mock,
        )
        metric_cls = RNNTWER
    wer = metric_cls(decoding, batch_dim_index=batch_dim_index, use_cer=False)

    targets = self.__reference_string_to_tensor(reference, test_wer_bpe)
    if wer.batch_dim_index > 0:
        targets.transpose_(0, 1)

    wer(
        encoder_output=None,
        encoded_lengths=None,
        targets=targets,
        target_lengths=torch.tensor([len(reference)]),
    )
    wer_value, _, _ = wer.compute()
    return wer_value.detach().cpu().item()
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
    """
    Build the RNNT model components described by ``cfg``.

    The constructor proceeds in this order:

    1. Computes ``self.world_size``.
    2. Calls the parent constructor.
    3. Fills in the decoder/joint config entries derived from ``cfg.labels`` and
       ``cfg.model_defaults``.
    4. Instantiates the preprocessor, encoder, decoder, joint, RNNT loss, optional spec
       augmentation, decoding helper and WER metric.
    5. Calls ``self.setup_optim_normalization()``.

    Args:
        cfg: Model configuration.
        trainer: Optional trainer. When given, ``world_size`` is
            ``trainer.num_nodes * trainer.num_gpus``; otherwise it is 1.
    """
    # Total number of GPU workers, used for IterableDataset partitioning where applicable.
    # Global_rank and local_rank are set by LightningModule in Lightning 1.2.0.
    # NOTE(review): assigned before the parent constructor runs, presumably because setup code
    # inside it reads world_size — keep this ordering.
    self.world_size = 1 if trainer is None else trainer.num_nodes * trainer.num_gpus

    super().__init__(cfg=cfg, trainer=trainer)

    self.preprocessor = EncDecRNNTModel.from_config_dict(self.cfg.preprocessor)
    self.encoder = EncDecRNNTModel.from_config_dict(self.cfg.encoder)

    # Config values that depend on the label set / model defaults are filled in dynamically.
    labels = self.cfg.labels
    model_defaults = self.cfg.model_defaults

    decoder_cfg = self.cfg.decoder
    with open_dict(decoder_cfg):
        decoder_cfg.vocab_size = len(labels)

    joint_cfg = self.cfg.joint
    with open_dict(joint_cfg):
        joint_cfg.num_classes = len(labels)
        joint_cfg.vocabulary = labels
        joint_cfg.jointnet.encoder_hidden = model_defaults.enc_hidden
        joint_cfg.jointnet.pred_hidden = model_defaults.pred_hidden

    self.decoder = EncDecRNNTModel.from_config_dict(decoder_cfg)
    self.joint = EncDecRNNTModel.from_config_dict(joint_cfg)

    # RNNT loss: its class count is the joint's class count minus the blank.
    loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get("loss", None))
    self.loss = RNNTLoss(
        num_classes=self.joint.num_classes_with_blank - 1,
        loss_name=loss_name,
        loss_kwargs=loss_kwargs,
    )

    # Spec augmentation is optional: absent or null config means no augmentation module.
    has_spec_augment = hasattr(self.cfg, 'spec_augment') and self._cfg.spec_augment is not None
    self.spec_augmentation = (
        EncDecRNNTModel.from_config_dict(self.cfg.spec_augment) if has_spec_augment else None
    )

    self.decoding = RNNTDecoding(
        decoding_cfg=self.cfg.decoding,
        decoder=self.decoder,
        joint=self.joint,
        vocabulary=self.joint.vocabulary,
    )

    self.wer = RNNTWER(
        decoding=self.decoding,
        batch_dim_index=0,
        use_cer=self._cfg.get('use_cer', False),
        log_prediction=self._cfg.get('log_prediction', True),
        dist_sync_on_step=True,
    )

    # Loss is computed during evaluation unless the config says otherwise.
    self.compute_eval_loss = self.cfg.compute_eval_loss if 'compute_eval_loss' in self.cfg else True

    # A joint that fuses loss/WER computation needs both objects registered on it.
    if self.joint.fuse_loss_wer:
        self.joint.set_loss(self.loss)
        self.joint.set_wer(self.wer)

    self.setup_optim_normalization()
def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[DictConfig] = None):
    """
    Swap the output vocabulary of the RNNT model, e.g. when fine-tuning a pre-trained model.

    Only the joint, decoder, loss, decoding helper and WER metric are rebuilt. The encoder and
    preprocessor are left untouched, so a pretrained encoder can be reused for another language
    or for targets with capitalization, punctuation or special characters.

    Args:
        new_vocabulary: The new target alphabet. Must be non-empty; ``None`` or an empty list
            raises ``ValueError``. If it equals the current joint vocabulary, a warning is logged
            and nothing (including the decoding config) is changed.
        decoding_cfg: Optional decoding config to use with the new vocabulary, e.g. to switch
            from greedy to beam decoding. Defaults to the current ``cfg.decoding``.

    Returns:
        None

    Raises:
        ValueError: If ``new_vocabulary`` is ``None`` or empty.
    """
    if self.joint.vocabulary == new_vocabulary:
        logging.warning(
            f"Old {self.joint.vocabulary} and new {new_vocabulary} match. Not changing anything."
        )
        return

    if new_vocabulary is None or len(new_vocabulary) == 0:
        raise ValueError(f'New vocabulary must be non-empty list of chars. But I got: {new_vocabulary}')

    vocab_size = len(new_vocabulary)

    # Joint: same config, new vocabulary and class count.
    new_joint_config = copy.deepcopy(self.joint.to_config_dict())
    new_joint_config['vocabulary'] = new_vocabulary
    new_joint_config['num_classes'] = vocab_size
    del self.joint
    self.joint = EncDecRNNTModel.from_config_dict(new_joint_config)

    # Decoder: same config, new vocabulary size.
    new_decoder_config = copy.deepcopy(self.decoder.to_config_dict())
    new_decoder_config.vocab_size = vocab_size
    del self.decoder
    self.decoder = EncDecRNNTModel.from_config_dict(new_decoder_config)

    # The loss is sized from the rebuilt joint, so it is recreated too.
    del self.loss
    loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get('loss', None))
    self.loss = RNNTLoss(
        num_classes=self.joint.num_classes_with_blank - 1,
        loss_name=loss_name,
        loss_kwargs=loss_kwargs,
    )

    if decoding_cfg is None:
        # Keep the currently configured decoding strategy.
        decoding_cfg = self.cfg.decoding

    self.decoding = RNNTDecoding(
        decoding_cfg=decoding_cfg,
        decoder=self.decoder,
        joint=self.joint,
        vocabulary=self.joint.vocabulary,
    )

    # The replacement metric inherits the options of the metric it replaces.
    previous_wer = self.wer
    self.wer = RNNTWER(
        decoding=self.decoding,
        batch_dim_index=previous_wer.batch_dim_index,
        use_cer=previous_wer.use_cer,
        log_prediction=previous_wer.log_prediction,
        dist_sync_on_step=True,
    )

    # A joint that fuses loss/WER computation must be handed the new objects.
    if self.joint.fuse_loss_wer:
        self.joint.set_loss(self.loss)
        self.joint.set_wer(self.wer)

    # Mirror every change into the model config.
    with open_dict(self.cfg.joint):
        self.cfg.joint = new_joint_config

    with open_dict(self.cfg.decoder):
        self.cfg.decoder = new_decoder_config

    with open_dict(self.cfg.decoding):
        self.cfg.decoding = decoding_cfg

    # Dataset configs that are present get their label list replaced as well.
    for ds_key in ('train_ds', 'validation_ds', 'test_ds'):
        if ds_key not in self.cfg:
            continue
        with open_dict(self.cfg[ds_key]):
            self.cfg[ds_key]['labels'] = OmegaConf.create(new_vocabulary)

    logging.info(
        f"Changed decoder to output to {self.joint.vocabulary} vocabulary."
    )
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
    """
    Build the RNNT model components described by ``cfg``.

    The constructor proceeds in this order:

    1. Verifies that the ``warprnnt_pytorch`` RNNT loss bindings are importable.
    2. Computes ``self.world_size``.
    3. Calls the parent constructor.
    4. Fills in the decoder/joint config entries derived from ``cfg.labels`` and
       ``cfg.model_defaults``.
    5. Instantiates the preprocessor, encoder, decoder, joint, RNNT loss, optional spec
       augmentation, decoding helper and WER metric.
    6. Reads the optional variational-noise settings.

    Args:
        cfg: Model configuration.
        trainer: Optional trainer. When given, ``world_size`` is
            ``trainer.num_nodes * trainer.num_gpus``; otherwise it is 1.

    Raises:
        ImportError: If ``WARP_RNNT_AVAILABLE`` is false.
    """
    # Required loss function
    if not WARP_RNNT_AVAILABLE:
        raise ImportError(
            "Could not import `warprnnt_pytorch`.\n"
            "Please visit https://github.com/HawkAaron/warp-transducer "
            "and follow the steps in the readme to build and install the "
            "pytorch bindings for RNNT Loss, or use the provided docker "
            "container that supports RNN-T loss."
        )

    # Total number of GPU workers, used for IterableDataset partitioning where applicable.
    # Global_rank and local_rank are set by LightningModule in Lightning 1.2.0.
    # NOTE(review): assigned before the parent constructor runs, presumably because setup code
    # inside it reads world_size — keep this ordering.
    self.world_size = 1
    if trainer is not None:
        self.world_size = trainer.num_nodes * trainer.num_gpus

    super().__init__(cfg=cfg, trainer=trainer)

    # Initialize components
    self.preprocessor = EncDecRNNTModel.from_config_dict(self.cfg.preprocessor)
    self.encoder = EncDecRNNTModel.from_config_dict(self.cfg.encoder)

    # Update config values required by components dynamically
    with open_dict(self.cfg.decoder):
        self.cfg.decoder.vocab_size = len(self.cfg.labels)

    with open_dict(self.cfg.joint):
        self.cfg.joint.num_classes = len(self.cfg.labels)
        self.cfg.joint.vocabulary = self.cfg.labels
        self.cfg.joint.jointnet.encoder_hidden = self.cfg.model_defaults.enc_hidden
        self.cfg.joint.jointnet.pred_hidden = self.cfg.model_defaults.pred_hidden

    self.decoder = EncDecRNNTModel.from_config_dict(self.cfg.decoder)
    self.joint = EncDecRNNTModel.from_config_dict(self.cfg.joint)

    # The loss's class count is the joint's class count minus the blank.
    self.loss = RNNTLoss(num_classes=self.joint.num_classes_with_blank - 1)

    # Spec augmentation is optional: absent or null config means no augmentation module.
    if hasattr(self.cfg, 'spec_augment') and self._cfg.spec_augment is not None:
        self.spec_augmentation = EncDecRNNTModel.from_config_dict(self.cfg.spec_augment)
    else:
        self.spec_augmentation = None

    # Setup decoding objects
    self.decoding = RNNTDecoding(
        decoding_cfg=self.cfg.decoding,
        decoder=self.decoder,
        joint=self.joint,
        vocabulary=self.joint.vocabulary,
    )

    # Setup WER calculation
    self.wer = RNNTWER(
        decoding=self.decoding,
        batch_dim_index=0,
        use_cer=self._cfg.get('use_cer', False),
        log_prediction=self._cfg.get('log_prediction', True),
        dist_sync_on_step=True,
    )

    # Whether to compute loss during evaluation
    if 'compute_eval_loss' in self.cfg:
        self.compute_eval_loss = self.cfg.compute_eval_loss
    else:
        self.compute_eval_loss = True

    # Setup fused Joint step if flag is set
    if self.joint.fuse_loss_wer:
        self.joint.set_loss(self.loss)
        self.joint.set_wer(self.wer)

    # Variational-noise settings for the decoder.
    # A `variational_noise` entry can exist with a null value, either written as `null` in the
    # config or as a missing key read from a non-struct DictConfig. hasattr() is then True while
    # the value is None, and `.get` would raise AttributeError. Guard exactly like spec_augment
    # above and fall back to std 0 / start step 0.
    if hasattr(self.cfg, 'variational_noise') and self.cfg.variational_noise is not None:
        self._optim_variational_noise_std = self.cfg['variational_noise'].get('std', 0)
        self._optim_variational_noise_start = self.cfg['variational_noise'].get('start_step', 0)
    else:
        self._optim_variational_noise_std = 0
        self._optim_variational_noise_start = 0