def change_vocabulary(
    self, new_tokenizer_dir: str, new_tokenizer_type: str, decoding_cfg: Optional[DictConfig] = None
):
    """
    Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning a pre-trained model.
    This method changes only the decoder and leaves the encoder and pre-processing modules unchanged. For example,
    you would use it if you want to use a pretrained encoder when fine-tuning on data in another language, or when
    you need the model to learn capitalization, punctuation and/or special characters.

    Args:
        new_tokenizer_dir: Directory path to the tokenizer.
        new_tokenizer_type: Type of tokenizer. Can be either `bpe` or `wpe`.
        decoding_cfg: A config for the decoder, which is optional. If the decoding type
            needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here.

    Returns: None
    """
    if not os.path.isdir(new_tokenizer_dir):
        raise NotADirectoryError(
            f'New tokenizer dir must be a non-empty path to a directory. But I got: {new_tokenizer_dir}'
        )

    if new_tokenizer_type.lower() not in ('bpe', 'wpe'):
        raise ValueError('New tokenizer type must be either `bpe` or `wpe`')

    tokenizer_cfg = OmegaConf.create({'dir': new_tokenizer_dir, 'type': new_tokenizer_type})

    # Setup the tokenizer
    self._setup_tokenizer(tokenizer_cfg)

    # Retrieve the vocabulary (a token -> id mapping) from the new tokenizer
    vocabulary = self.tokenizer.tokenizer.get_vocab()

    # Rebuild the joint network with the new vocabulary (the tokens, i.e. the dict keys)
    joint_config = self.joint.to_config_dict()
    new_joint_config = copy.deepcopy(joint_config)
    new_joint_config['vocabulary'] = ListConfig(list(vocabulary.keys()))
    new_joint_config['num_classes'] = len(vocabulary)
    del self.joint
    self.joint = EncDecRNNTBPEModel.from_config_dict(new_joint_config)

    # Rebuild the prediction network (decoder) with the new vocabulary size
    decoder_config = self.decoder.to_config_dict()
    new_decoder_config = copy.deepcopy(decoder_config)
    new_decoder_config.vocab_size = len(vocabulary)
    del self.decoder
    self.decoder = EncDecRNNTBPEModel.from_config_dict(new_decoder_config)

    del self.loss
    self.loss = RNNTLoss(num_classes=self.joint.num_classes_with_blank - 1)

    if decoding_cfg is None:
        # Assume the same decoding config as before
        decoding_cfg = self.cfg.decoding

    self.decoding = RNNTBPEDecoding(
        decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
    )

    self.wer = RNNTBPEWER(
        decoding=self.decoding,
        batch_dim_index=self.wer.batch_dim_index,
        use_cer=self.wer.use_cer,
        log_prediction=self.wer.log_prediction,
        dist_sync_on_step=True,
    )

    # Setup fused Joint step
    if self.joint.fuse_loss_wer:
        self.joint.set_loss(self.loss)
        self.joint.set_wer(self.wer)

    # Update config
    with open_dict(self.cfg.joint):
        self.cfg.joint = new_joint_config

    with open_dict(self.cfg.decoder):
        self.cfg.decoder = new_decoder_config

    with open_dict(self.cfg.decoding):
        self.cfg.decoding = decoding_cfg

    logging.info(f"Changed decoder to output to {self.joint.vocabulary} vocabulary.")
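# A minimal usage sketch for `change_vocabulary` above. The pretrained model
# name and the tokenizer directory are illustrative placeholders (assumptions),
# not values taken from this source.
import nemo.collections.asr as nemo_asr

asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
    model_name="stt_en_conformer_transducer_large"  # placeholder model name
)
# Point the decoder/joint at a tokenizer built for the fine-tuning corpus;
# the encoder and preprocessor are left untouched.
asr_model.change_vocabulary(
    new_tokenizer_dir="<path/to/new_tokenizer_dir>",  # placeholder path
    new_tokenizer_type="bpe",
)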
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
    # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
    # global_rank and local_rank are set by LightningModule in Lightning 1.2.0
    self.world_size = 1
    if trainer is not None:
        self.world_size = trainer.num_nodes * trainer.num_gpus

    super().__init__(cfg=cfg, trainer=trainer)

    # Initialize components
    self.preprocessor = EncDecRNNTModel.from_config_dict(self.cfg.preprocessor)
    self.encoder = EncDecRNNTModel.from_config_dict(self.cfg.encoder)

    # Update config values required by components dynamically
    with open_dict(self.cfg.decoder):
        self.cfg.decoder.vocab_size = len(self.cfg.labels)

    with open_dict(self.cfg.joint):
        self.cfg.joint.num_classes = len(self.cfg.labels)
        self.cfg.joint.vocabulary = self.cfg.labels
        self.cfg.joint.jointnet.encoder_hidden = self.cfg.model_defaults.enc_hidden
        self.cfg.joint.jointnet.pred_hidden = self.cfg.model_defaults.pred_hidden

    self.decoder = EncDecRNNTModel.from_config_dict(self.cfg.decoder)
    self.joint = EncDecRNNTModel.from_config_dict(self.cfg.joint)

    # Setup RNNT Loss
    loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get("loss", None))

    self.loss = RNNTLoss(
        num_classes=self.joint.num_classes_with_blank - 1, loss_name=loss_name, loss_kwargs=loss_kwargs
    )

    if hasattr(self.cfg, 'spec_augment') and self._cfg.spec_augment is not None:
        self.spec_augmentation = EncDecRNNTModel.from_config_dict(self.cfg.spec_augment)
    else:
        self.spec_augmentation = None

    # Setup decoding objects
    self.decoding = RNNTDecoding(
        decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary,
    )
    # Setup WER calculation
    self.wer = RNNTWER(
        decoding=self.decoding,
        batch_dim_index=0,
        use_cer=self._cfg.get('use_cer', False),
        log_prediction=self._cfg.get('log_prediction', True),
        dist_sync_on_step=True,
    )

    # Whether to compute loss during evaluation
    if 'compute_eval_loss' in self.cfg:
        self.compute_eval_loss = self.cfg.compute_eval_loss
    else:
        self.compute_eval_loss = True

    # Setup fused Joint step if flag is set
    if self.joint.fuse_loss_wer:
        self.joint.set_loss(self.loss)
        self.joint.set_wer(self.wer)

    # Setup optimization normalization, if provided in the config
    self.setup_optim_normalization()
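# A sketch of constructing the model from a Hydra/OmegaConf config, which is
# what this __init__ expects. The YAML path, the top-level `model` key, and
# the trainer settings are assumptions, not values from this source.
from omegaconf import OmegaConf
import pytorch_lightning as pl

cfg = OmegaConf.load("<path/to/config.yaml>")  # placeholder path
trainer = pl.Trainer(gpus=1, max_epochs=50)  # Lightning 1.2-era arguments
model = EncDecRNNTModel(cfg=cfg.model, trainer=trainer)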
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
    # Required loss function
    if not WARP_RNNT_AVAILABLE:
        raise ImportError(
            "Could not import `warprnnt_pytorch`.\n"
            "Please visit https://github.com/HawkAaron/warp-transducer "
            "and follow the steps in the readme to build and install the "
            "pytorch bindings for RNNT Loss, or use the provided docker "
            "container that supports RNN-T loss."
        )

    # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
    # global_rank and local_rank are set by LightningModule in Lightning 1.2.0
    self.world_size = 1
    if trainer is not None:
        self.world_size = trainer.num_nodes * trainer.num_gpus

    super().__init__(cfg=cfg, trainer=trainer)

    # Initialize components
    self.preprocessor = EncDecRNNTModel.from_config_dict(self.cfg.preprocessor)
    self.encoder = EncDecRNNTModel.from_config_dict(self.cfg.encoder)

    # Update config values required by components dynamically
    with open_dict(self.cfg.decoder):
        self.cfg.decoder.vocab_size = len(self.cfg.labels)

    with open_dict(self.cfg.joint):
        self.cfg.joint.num_classes = len(self.cfg.labels)
        self.cfg.joint.vocabulary = self.cfg.labels
        self.cfg.joint.jointnet.encoder_hidden = self.cfg.model_defaults.enc_hidden
        self.cfg.joint.jointnet.pred_hidden = self.cfg.model_defaults.pred_hidden

    self.decoder = EncDecRNNTModel.from_config_dict(self.cfg.decoder)
    self.joint = EncDecRNNTModel.from_config_dict(self.cfg.joint)

    self.loss = RNNTLoss(num_classes=self.joint.num_classes_with_blank - 1)

    if hasattr(self.cfg, 'spec_augment') and self._cfg.spec_augment is not None:
        self.spec_augmentation = EncDecRNNTModel.from_config_dict(self.cfg.spec_augment)
    else:
        self.spec_augmentation = None

    # Setup decoding objects
    self.decoding = RNNTDecoding(
        decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary,
    )
    # Setup WER calculation
    self.wer = RNNTWER(
        decoding=self.decoding,
        batch_dim_index=0,
        use_cer=self._cfg.get('use_cer', False),
        log_prediction=self._cfg.get('log_prediction', True),
        dist_sync_on_step=True,
    )

    # Whether to compute loss during evaluation
    if 'compute_eval_loss' in self.cfg:
        self.compute_eval_loss = self.cfg.compute_eval_loss
    else:
        self.compute_eval_loss = True

    # Setup fused Joint step if flag is set
    if self.joint.fuse_loss_wer:
        self.joint.set_loss(self.loss)
        self.joint.set_wer(self.wer)

    # Setting up the variational noise for the decoder
    if hasattr(self.cfg, 'variational_noise'):
        self._optim_variational_noise_std = self.cfg['variational_noise'].get('std', 0)
        self._optim_variational_noise_start = self.cfg['variational_noise'].get('start_step', 0)
    else:
        self._optim_variational_noise_std = 0
        self._optim_variational_noise_start = 0
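# For context: a sketch of how the WARP_RNNT_AVAILABLE flag tested above is
# typically defined at module import time. The exact guard in the original
# source may differ; this is an assumption, not a copy of it.
try:
    import warprnnt_pytorch  # noqa: F401

    WARP_RNNT_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    WARP_RNNT_AVAILABLE = False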
def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[DictConfig] = None):
    """
    Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning a pre-trained model.
    This method changes only the decoder and leaves the encoder and pre-processing modules unchanged. For example,
    you would use it if you want to use a pretrained encoder when fine-tuning on data in another language, or when
    you need the model to learn capitalization, punctuation and/or special characters.

    Args:
        new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \
            this is the target alphabet.
        decoding_cfg: A config for the decoder, which is optional. If the decoding type
            needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here.

    Returns: None
    """
    if self.joint.vocabulary == new_vocabulary:
        logging.warning(f"Old {self.joint.vocabulary} and new {new_vocabulary} match. Not changing anything.")
    else:
        if new_vocabulary is None or len(new_vocabulary) == 0:
            raise ValueError(f'New vocabulary must be non-empty list of chars. But I got: {new_vocabulary}')

        joint_config = self.joint.to_config_dict()
        new_joint_config = copy.deepcopy(joint_config)
        new_joint_config['vocabulary'] = new_vocabulary
        new_joint_config['num_classes'] = len(new_vocabulary)
        del self.joint
        self.joint = EncDecRNNTModel.from_config_dict(new_joint_config)

        decoder_config = self.decoder.to_config_dict()
        new_decoder_config = copy.deepcopy(decoder_config)
        new_decoder_config.vocab_size = len(new_vocabulary)
        del self.decoder
        self.decoder = EncDecRNNTModel.from_config_dict(new_decoder_config)

        del self.loss
        loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get('loss', None))
        self.loss = RNNTLoss(
            num_classes=self.joint.num_classes_with_blank - 1, loss_name=loss_name, loss_kwargs=loss_kwargs
        )

        if decoding_cfg is None:
            # Assume the same decoding config as before
            decoding_cfg = self.cfg.decoding

        self.decoding = RNNTDecoding(
            decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary,
        )

        self.wer = RNNTWER(
            decoding=self.decoding,
            batch_dim_index=self.wer.batch_dim_index,
            use_cer=self.wer.use_cer,
            log_prediction=self.wer.log_prediction,
            dist_sync_on_step=True,
        )

        # Setup fused Joint step
        if self.joint.fuse_loss_wer:
            self.joint.set_loss(self.loss)
            self.joint.set_wer(self.wer)

        # Update config
        with open_dict(self.cfg.joint):
            self.cfg.joint = new_joint_config

        with open_dict(self.cfg.decoder):
            self.cfg.decoder = new_decoder_config

        with open_dict(self.cfg.decoding):
            self.cfg.decoding = decoding_cfg

        # Propagate the new labels to any dataset configs present
        ds_keys = ['train_ds', 'validation_ds', 'test_ds']
        for key in ds_keys:
            if key in self.cfg:
                with open_dict(self.cfg[key]):
                    self.cfg[key]['labels'] = OmegaConf.create(new_vocabulary)

        logging.info(f"Changed decoder to output to {self.joint.vocabulary} vocabulary.")
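# An illustrative call for the character-based variant above, assuming
# `asr_model` is an existing EncDecRNNTModel instance (hypothetical here).
# The alphabet below is a toy example, not a recommendation from the source.
new_vocab = [" ", "a", "b", "c", "d", "'"]  # toy target alphabet
asr_model.change_vocabulary(new_vocabulary=new_vocab)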
def __init__(self, num_classes):
    super().__init__()
    # Wrap the transducer loss; `num_classes` excludes the blank label,
    # matching the `num_classes_with_blank - 1` convention used above.
    self.loss = RNNTLoss(num_classes=num_classes)
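# A hedged sketch of the enclosing module this __init__ could belong to; the
# class name `RNNTLossModule` is an assumption, since the original class
# definition is not shown in this source.
import torch.nn as nn

class RNNTLossModule(nn.Module):  # hypothetical name
    def __init__(self, num_classes):
        super().__init__()
        self.loss = RNNTLoss(num_classes=num_classes)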