Example 1
    def change_vocabulary(self,
                          new_tokenizer_dir: str,
                          new_tokenizer_type: str,
                          decoding_cfg: Optional[DictConfig] = None):
        """
        Changes the vocabulary used during the RNNT decoding process. Use this method when fine-tuning from a pre-trained model.
        This method changes only the decoder and leaves the encoder and pre-processing modules unchanged. For example, you
        would use it when fine-tuning on data in another language with a pretrained encoder, or when the model needs to
        learn capitalization, punctuation and/or special characters.

        Args:
            new_tokenizer_dir: Directory path to tokenizer.
            new_tokenizer_type: Type of tokenizer. Can be either `bpe` or `wpe`.
            decoding_cfg: A config for the decoder, which is optional. If the decoding type
                needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here.

        Returns: None

        """
        if not os.path.isdir(new_tokenizer_dir):
            raise NotADirectoryError(
                f'New tokenizer dir must be a non-empty path to a directory. Got: {new_tokenizer_dir}'
            )

        if new_tokenizer_type.lower() not in ('bpe', 'wpe'):
            raise ValueError(
                f'New tokenizer type must be either `bpe` or `wpe`. Got: {new_tokenizer_type}')

        tokenizer_cfg = OmegaConf.create({
            'dir': new_tokenizer_dir,
            'type': new_tokenizer_type
        })

        # Setup the tokenizer
        self._setup_tokenizer(tokenizer_cfg)

        # Fetch the vocabulary (token -> id mapping) from the newly set up tokenizer
        vocabulary = self.tokenizer.tokenizer.get_vocab()

        joint_config = self.joint.to_config_dict()
        new_joint_config = copy.deepcopy(joint_config)
        new_joint_config['vocabulary'] = ListConfig(list(vocabulary.keys()))  # token strings, not ids
        new_joint_config['num_classes'] = len(vocabulary)
        del self.joint
        self.joint = EncDecRNNTBPEModel.from_config_dict(new_joint_config)

        decoder_config = self.decoder.to_config_dict()
        new_decoder_config = copy.deepcopy(decoder_config)
        new_decoder_config.vocab_size = len(vocabulary)
        del self.decoder
        self.decoder = EncDecRNNTBPEModel.from_config_dict(new_decoder_config)

        del self.loss
        self.loss = RNNTLoss(num_classes=self.joint.num_classes_with_blank - 1)

        if decoding_cfg is None:
            # Assume same decoding config as before
            decoding_cfg = self.cfg.decoding

        self.decoding = RNNTBPEDecoding(
            decoding_cfg=decoding_cfg,
            decoder=self.decoder,
            joint=self.joint,
            tokenizer=self.tokenizer,
        )

        self.wer = RNNTBPEWER(
            decoding=self.decoding,
            batch_dim_index=self.wer.batch_dim_index,
            use_cer=self.wer.use_cer,
            log_prediction=self.wer.log_prediction,
            dist_sync_on_step=True,
        )

        # Setup fused Joint step
        if self.joint.fuse_loss_wer:
            self.joint.set_loss(self.loss)
            self.joint.set_wer(self.wer)

        # Update config
        with open_dict(self.cfg.joint):
            self.cfg.joint = new_joint_config

        with open_dict(self.cfg.decoder):
            self.cfg.decoder = new_decoder_config

        with open_dict(self.cfg.decoding):
            self.cfg.decoding = decoding_cfg

        logging.info(
            f"Changed decoder to output to {self.joint.vocabulary} vocabulary."
        )
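
For reference, a minimal usage sketch of this method follows. The checkpoint name and tokenizer directory are illustrative assumptions, not values taken from the example above.

    import nemo.collections.asr as nemo_asr

    # Hypothetical checkpoint name and tokenizer directory -- substitute your own.
    model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
        model_name="stt_en_conformer_transducer_large")

    # Swap the decoder/joint to a new BPE tokenizer while keeping the encoder weights.
    model.change_vocabulary(new_tokenizer_dir="/path/to/new_tokenizer",
                            new_tokenizer_type="bpe")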
Example 2
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        # Global_rank and local_rank are set by LightningModule in Lightning 1.2.0
        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_gpus

        super().__init__(cfg=cfg, trainer=trainer)

        # Initialize components
        self.preprocessor = EncDecRNNTModel.from_config_dict(
            self.cfg.preprocessor)
        self.encoder = EncDecRNNTModel.from_config_dict(self.cfg.encoder)

        # Update config values required by components dynamically
        with open_dict(self.cfg.decoder):
            self.cfg.decoder.vocab_size = len(self.cfg.labels)

        with open_dict(self.cfg.joint):
            self.cfg.joint.num_classes = len(self.cfg.labels)
            self.cfg.joint.vocabulary = self.cfg.labels
            self.cfg.joint.jointnet.encoder_hidden = self.cfg.model_defaults.enc_hidden
            self.cfg.joint.jointnet.pred_hidden = self.cfg.model_defaults.pred_hidden

        self.decoder = EncDecRNNTModel.from_config_dict(self.cfg.decoder)
        self.joint = EncDecRNNTModel.from_config_dict(self.cfg.joint)

        # Setup RNNT Loss
        loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(
            self.cfg.get("loss", None))

        self.loss = RNNTLoss(num_classes=self.joint.num_classes_with_blank - 1,
                             loss_name=loss_name,
                             loss_kwargs=loss_kwargs)

        if hasattr(self.cfg,
                   'spec_augment') and self.cfg.spec_augment is not None:
            self.spec_augmentation = EncDecRNNTModel.from_config_dict(
                self.cfg.spec_augment)
        else:
            self.spec_augmentation = None

        # Setup decoding objects
        self.decoding = RNNTDecoding(
            decoding_cfg=self.cfg.decoding,
            decoder=self.decoder,
            joint=self.joint,
            vocabulary=self.joint.vocabulary,
        )
        # Setup WER calculation
        self.wer = RNNTWER(
            decoding=self.decoding,
            batch_dim_index=0,
            use_cer=self.cfg.get('use_cer', False),
            log_prediction=self.cfg.get('log_prediction', True),
            dist_sync_on_step=True,
        )

        # Whether to compute loss during evaluation
        if 'compute_eval_loss' in self.cfg:
            self.compute_eval_loss = self.cfg.compute_eval_loss
        else:
            self.compute_eval_loss = True

        # Setup fused Joint step if flag is set
        if self.joint.fuse_loss_wer:
            self.joint.set_loss(self.loss)
            self.joint.set_wer(self.wer)

        self.setup_optim_normalization()
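
The constructor above reads a number of config fields before instantiating submodules. The skeleton below sketches that shape; the field names are taken from the code, but every value (including the `_target_` placeholders) is an assumption, not a working configuration.

    from omegaconf import OmegaConf

    # Minimal shape of the config consumed by the constructor above.
    # Placeholder values only; '_target_' entries must point to real classes.
    cfg = OmegaConf.create({
        'labels': [' ', 'a', 'b', 'c'],        # target alphabet
        'model_defaults': {'enc_hidden': 512, 'pred_hidden': 320},
        'preprocessor': {'_target_': '...'},   # built via from_config_dict
        'encoder': {'_target_': '...'},
        'decoder': {'_target_': '...', 'vocab_size': -1},   # filled in dynamically
        'joint': {'_target_': '...', 'num_classes': -1,
                  'vocabulary': [], 'jointnet': {}},
        'decoding': {'strategy': 'greedy'},
        'loss': None,            # optional; parsed by extract_rnnt_loss_cfg
        'spec_augment': None,    # optional
        'use_cer': False,        # optional WER/CER switches
        'log_prediction': True,
    })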
Example 3
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Required loss function
        if not WARP_RNNT_AVAILABLE:
            raise ImportError(
                "Could not import `warprnnt_pytorch`.\n"
                "Please visit https://github.com/HawkAaron/warp-transducer "
                "and follow the steps in the readme to build and install the "
                "pytorch bindings for RNNT Loss, or use the provided docker "
                "container that supports RNN-T loss.")

        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        # Global_rank and local_rank are set by LightningModule in Lightning 1.2.0
        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_gpus

        super().__init__(cfg=cfg, trainer=trainer)

        # Initialize components
        self.preprocessor = EncDecRNNTModel.from_config_dict(
            self.cfg.preprocessor)
        self.encoder = EncDecRNNTModel.from_config_dict(self.cfg.encoder)

        # Update config values required by components dynamically
        with open_dict(self.cfg.decoder):
            self.cfg.decoder.vocab_size = len(self.cfg.labels)

        with open_dict(self.cfg.joint):
            self.cfg.joint.num_classes = len(self.cfg.labels)
            self.cfg.joint.vocabulary = self.cfg.labels
            self.cfg.joint.jointnet.encoder_hidden = self.cfg.model_defaults.enc_hidden
            self.cfg.joint.jointnet.pred_hidden = self.cfg.model_defaults.pred_hidden

        self.decoder = EncDecRNNTModel.from_config_dict(self.cfg.decoder)
        self.joint = EncDecRNNTModel.from_config_dict(self.cfg.joint)
        self.loss = RNNTLoss(num_classes=self.joint.num_classes_with_blank - 1)

        if hasattr(self.cfg,
                   'spec_augment') and self.cfg.spec_augment is not None:
            self.spec_augmentation = EncDecRNNTModel.from_config_dict(
                self.cfg.spec_augment)
        else:
            self.spec_augmentation = None

        # Setup decoding objects
        self.decoding = RNNTDecoding(
            decoding_cfg=self.cfg.decoding,
            decoder=self.decoder,
            joint=self.joint,
            vocabulary=self.joint.vocabulary,
        )
        # Setup WER calculation
        self.wer = RNNTWER(
            decoding=self.decoding,
            batch_dim_index=0,
            use_cer=self.cfg.get('use_cer', False),
            log_prediction=self.cfg.get('log_prediction', True),
            dist_sync_on_step=True,
        )

        # Whether to compute loss during evaluation
        if 'compute_eval_loss' in self.cfg:
            self.compute_eval_loss = self.cfg.compute_eval_loss
        else:
            self.compute_eval_loss = True

        # Setup fused Joint step if flag is set
        if self.joint.fuse_loss_wer:
            self.joint.set_loss(self.loss)
            self.joint.set_wer(self.wer)

        # Setup variational noise for the decoder
        if hasattr(self.cfg, 'variational_noise'):
            self._optim_variational_noise_std = self.cfg[
                'variational_noise'].get('std', 0)
            self._optim_variational_noise_start = self.cfg[
                'variational_noise'].get('start_step', 0)
        else:
            self._optim_variational_noise_std = 0
            self._optim_variational_noise_start = 0
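
The two noise attributes set above are typically consumed in a gradient hook on the same model class. The method below is a sketch of that assumed usage (perturbing decoder gradients with Gaussian noise once training passes the start step), not code taken from the example.

    import torch

    def on_after_backward(self):
        super().on_after_backward()
        if (self._optim_variational_noise_std > 0
                and self.global_step >= self._optim_variational_noise_start):
            for param in self.decoder.parameters():
                if param.grad is not None:
                    # Gaussian noise with the configured standard deviation.
                    noise = torch.randn_like(param.grad) * self._optim_variational_noise_std
                    param.grad.data.add_(noise)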
Example 4
    def change_vocabulary(self,
                          new_vocabulary: List[str],
                          decoding_cfg: Optional[DictConfig] = None):
        """
        Changes the vocabulary used during the RNNT decoding process. Use this method when fine-tuning a pre-trained model.
        This method changes only the decoder and leaves the encoder and pre-processing modules unchanged. For example, you
        would use it when fine-tuning on data in another language with a pretrained encoder, or when the model needs to
        learn capitalization, punctuation and/or special characters.

        Args:
            new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \
                this is the target alphabet.
            decoding_cfg: A config for the decoder, which is optional. If the decoding type
                needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here.

        Returns: None

        """
        if self.joint.vocabulary == new_vocabulary:
            logging.warning(
                f"Old {self.joint.vocabulary} and new {new_vocabulary} match. Not changing anything."
            )
        else:
            if new_vocabulary is None or len(new_vocabulary) == 0:
                raise ValueError(
                    f'New vocabulary must be a non-empty list of characters. Got: {new_vocabulary}'
                )

            joint_config = self.joint.to_config_dict()
            new_joint_config = copy.deepcopy(joint_config)
            new_joint_config['vocabulary'] = new_vocabulary
            new_joint_config['num_classes'] = len(new_vocabulary)
            del self.joint
            self.joint = EncDecRNNTModel.from_config_dict(new_joint_config)

            decoder_config = self.decoder.to_config_dict()
            new_decoder_config = copy.deepcopy(decoder_config)
            new_decoder_config.vocab_size = len(new_vocabulary)
            del self.decoder
            self.decoder = EncDecRNNTModel.from_config_dict(new_decoder_config)

            del self.loss
            loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(
                self.cfg.get('loss', None))
            self.loss = RNNTLoss(
                num_classes=self.joint.num_classes_with_blank - 1,
                loss_name=loss_name,
                loss_kwargs=loss_kwargs)

            if decoding_cfg is None:
                # Assume same decoding config as before
                decoding_cfg = self.cfg.decoding

            self.decoding = RNNTDecoding(
                decoding_cfg=decoding_cfg,
                decoder=self.decoder,
                joint=self.joint,
                vocabulary=self.joint.vocabulary,
            )

            self.wer = RNNTWER(
                decoding=self.decoding,
                batch_dim_index=self.wer.batch_dim_index,
                use_cer=self.wer.use_cer,
                log_prediction=self.wer.log_prediction,
                dist_sync_on_step=True,
            )

            # Setup fused Joint step
            if self.joint.fuse_loss_wer:
                self.joint.set_loss(self.loss)
                self.joint.set_wer(self.wer)

            # Update config
            with open_dict(self.cfg.joint):
                self.cfg.joint = new_joint_config

            with open_dict(self.cfg.decoder):
                self.cfg.decoder = new_decoder_config

            with open_dict(self.cfg.decoding):
                self.cfg.decoding = decoding_cfg

            ds_keys = ['train_ds', 'validation_ds', 'test_ds']
            for key in ds_keys:
                if key in self.cfg:
                    with open_dict(self.cfg[key]):
                        self.cfg[key]['labels'] = OmegaConf.create(
                            new_vocabulary)

            logging.info(
                f"Changed decoder to output to {self.joint.vocabulary} vocabulary."
            )
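
A minimal usage sketch for this character-based variant follows; the checkpoint path and alphabet are illustrative assumptions.

    import nemo.collections.asr as nemo_asr

    # Hypothetical character-based transducer checkpoint -- substitute your own.
    model = nemo_asr.models.EncDecRNNTModel.restore_from("/path/to/model.nemo")

    # New target alphabet (illustrative); the encoder is kept, decoder/joint are rebuilt.
    model.change_vocabulary(new_vocabulary=[' ', 'a', 'b', 'c', "'"])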
Example 5
    def __init__(self, num_classes):
        super().__init__()
        self.loss = RNNTLoss(num_classes=num_classes)
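
For context, a sketch of how such a loss instance is typically invoked on a transducer lattice of shape [B, T, U+1, V+1]. The keyword names follow NeMo's RNNTLoss wrapper, and the snippet assumes a working RNNT loss backend is installed (see the import check in Example 3).

    import torch
    from nemo.collections.asr.losses.rnnt import RNNTLoss

    # Dummy sizes: batch of 2, 10 encoder frames, 5 target tokens, 28-symbol vocabulary.
    B, T, U, V = 2, 10, 5, 28
    loss_fn = RNNTLoss(num_classes=V)  # num_classes excludes the blank symbol

    log_probs = torch.randn(B, T, U + 1, V + 1, requires_grad=True)
    targets = torch.randint(0, V, (B, U), dtype=torch.int32)
    input_lengths = torch.full((B,), T, dtype=torch.int32)
    target_lengths = torch.full((B,), U, dtype=torch.int32)

    loss = loss_fn(log_probs=log_probs, targets=targets,
                   input_lengths=input_lengths, target_lengths=target_lengths)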