Example #1
    def forward_transducer(self, eouts, elens, ys):
        """Compute Transducer loss.

        Args:
            eouts (FloatTensor): `[B, T, enc_n_units]`
            elens (IntTensor): `[B]`
            ys (list): length `B`; each element is a list of length `[L]`
        Returns:
            loss (FloatTensor): `[1]`

        """
        # Prepend <sos> (same index as <eos>) to form the prediction network inputs
        _ys = [
            np2tensor(np.fromiter(y, dtype=np.int64), eouts.device) for y in ys
        ]
        ylens = np2tensor(np.fromiter([y.size(0) for y in _ys],
                                      dtype=np.int32))
        eos = eouts.new_zeros((1, ), dtype=torch.int64).fill_(self.eos)
        ys_in = pad_list([torch.cat([eos, y], dim=0) for y in _ys],
                         self.pad)  # `[B, L+1]`
        ys_out = pad_list(_ys, self.blank)  # `[B, L]`

        # Update prediction network
        ys_emb = self.dropout_emb(self.embed(ys_in))
        dout, _ = self.recurrency(ys_emb, None)

        # Compute output distribution
        logits = self.joint(eouts, dout)  # `[B, T, L+1, vocab]`

        # Compute Transducer loss
        log_probs = torch.log_softmax(logits, dim=-1)
        assert log_probs.size(2) == ys_out.size(1) + 1
        if self.device_id >= 0:
            ys_out = ys_out.to(eouts.device)
            elens = elens.to(eouts.device)
            ylens = ylens.to(eouts.device)
            import warp_rnnt
            loss = warp_rnnt.rnnt_loss(log_probs,
                                       ys_out.int(),
                                       elens,
                                       ylens,
                                       average_frames=False,
                                       reduction='mean',
                                       gather=False)
        else:
            import warprnnt_pytorch
            self.warprnnt_loss = warprnnt_pytorch.RNNTLoss()
            loss = self.warprnnt_loss(log_probs, ys_out.int(), elens, ylens)
            # NOTE: the Transducer loss has already been normalized by the batch size
            # NOTE: index 0 is reserved for blank in warprnnt_pytorch

        return loss
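
For reference, warp_rnnt.rnnt_loss expects log-softmax outputs of shape `[B, T, L+1, vocab]` with blank at its default index 0, int32 labels of shape `[B, L]`, and int32 length vectors, all on a CUDA device. The snippet below is a minimal shape check under those assumptions; every size in it is hypothetical.

import torch
import warp_rnnt

B, T, L, V = 2, 50, 10, 30  # batch, frames, label length, vocab size incl. blank (hypothetical)
log_probs = torch.randn(B, T, L + 1, V, device='cuda').log_softmax(dim=-1)
ys_out = torch.randint(1, V, (B, L), dtype=torch.int32, device='cuda')  # targets exclude blank (index 0)
elens = torch.full((B,), T, dtype=torch.int32, device='cuda')
ylens = torch.full((B,), L, dtype=torch.int32, device='cuda')

loss = warp_rnnt.rnnt_loss(log_probs, ys_out, elens, ylens,
                           average_frames=False, reduction='mean', gather=False)
print(loss.item())  # scalar; already averaged over the batch
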
Example #2
    def __init__(self,
                 vocab: Vocabulary,
                 input_size: int,
                 hidden_size: int,
                 loss_ratio: float = 1.0,
                 recurrency: nn.LSTM = None,
                 num_layers: int = None,
                 remove_sos: bool = True,
                 remove_eos: bool = False,
                 target_embedder: Embedding = None,
                 target_embedding_dim: int = None,
                 target_namespace: str = "tokens",
                 slow_decode: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(RNNTLayer, self).__init__(vocab, regularizer)
        import warprnnt_pytorch
        self.loss_ratio = loss_ratio
        self._remove_sos = remove_sos
        self._remove_eos = remove_eos
        self._slow_decode = slow_decode
        self._target_namespace = target_namespace
        self._num_classes = self.vocab.get_vocab_size(target_namespace)
        self._pad_index = self.vocab.get_token_index(DEFAULT_PADDING_TOKEN,
                                                     self._target_namespace)
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)

        self._loss = warprnnt_pytorch.RNNTLoss(blank=self._pad_index,
                                               reduction='mean')
        self._recurrency = recurrency or \
            nn.LSTM(input_size=target_embedding_dim,
                    hidden_size=hidden_size,
                    num_layers=num_layers,
                    batch_first=True)

        self._target_embedder = target_embedder or Embedding(
            self._num_classes, target_embedding_dim)
        self.w_enc = nn.Linear(input_size, hidden_size, bias=True)
        self.w_dec = nn.Linear(input_size, hidden_size, bias=False)
        self._proj = nn.Linear(hidden_size, self._num_classes)

        exclude_indices = {self._pad_index, self._end_index, self._start_index}
        self._wer: Metric = WER(exclude_indices=exclude_indices)
        self._bleu: Metric = BLEU(exclude_indices=exclude_indices)
        self._dal = Average()

        initializer(self)
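
The constructor above only builds the modules; how w_enc, w_dec, and _proj are combined is not part of this excerpt. The sketch below shows the usual broadcast-add joint-network combination, assuming dec_out matches the input dimension of w_dec; treat it as an illustration, not the layer's actual forward pass.

import torch

def joint_sketch(layer: 'RNNTLayer', enc_out: torch.Tensor, dec_out: torch.Tensor) -> torch.Tensor:
    # enc_out: [B, T, input_size]; dec_out: [B, U, D], where D must equal w_dec's input size (assumed)
    f = layer.w_enc(enc_out).unsqueeze(2)  # [B, T, 1, hidden_size]
    g = layer.w_dec(dec_out).unsqueeze(1)  # [B, 1, U, hidden_size]
    return layer._proj(torch.tanh(f + g))  # [B, T, U, num_classes]
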
Example #3
    def __init__(self, num_classes, reduction='mean_batch'):
        """
        RNN-T Loss function based on https://github.com/HawkAaron/warp-transducer.

        Note:
            Requires the pytorch bindings to be installed prior to calling this class.

        Warning:
            If GPU memory is exhausted while computing RNNTLoss, it may cause a core dump at the
            CUDA level with the following error message.

            ```
                ...
                costs = costs.to(acts.device)
            RuntimeError: CUDA error: an illegal memory access was encountered
            terminate called after throwing an instance of 'c10::Error'
            ```

            Please kill all remaining Python processes after this point, and use a smaller batch size
            for the train, validation, and test sets so that CUDA memory is not exhausted.

        Args:
            num_classes: Number of target classes for the joint network to predict.
                (Excluding the RNN-T blank token).

            reduction: Type of reduction to perform on the loss. Possible values are `mean_batch`,
                `mean`, `sum`, or None.
                None returns a torch vector comprising the individual loss values of the batch.
        """
        super(RNNTLoss, self).__init__()

        if not WARP_RNNT_AVAILABLE:
            raise ImportError(
                "Could not import `warprnnt_pytorch`.\n"
                "Please visit https://github.com/HawkAaron/warp-transducer "
                "and follow the steps in the readme to build and install the "
                "pytorch bindings for RNNT Loss, or use the provided docker "
                "container that supports RNN-T loss.")

        if reduction not in [None, 'mean', 'sum', 'mean_batch']:
            raise ValueError(
                '`reduction` must be one of [None, mean, sum, mean_batch]')

        self._blank = num_classes
        self.reduction = reduction
        self._loss = warprnnt.RNNTLoss(blank=self._blank, reduction='none')
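
Note that the blank token is mapped to index num_classes here, so the joint network must emit num_classes + 1 scores per (t, u) position. A small shape sketch with hypothetical numbers:

num_classes = 28                     # target vocabulary size, excluding blank (hypothetical)
vocab_with_blank = num_classes + 1   # required joint-network output dimension
# acts:    [B, T, U+1, vocab_with_blank] logits from the joint network
# targets: [B, U] with values in [0, num_classes); blank = num_classes
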
Example #4
def resolve_rnnt_loss(loss_name: str,
                      blank_idx: int,
                      loss_kwargs: dict = None) -> torch.nn.Module:
    loss_function_names = list(RNNT_LOSS_RESOLVER.keys())

    if loss_name not in loss_function_names:
        raise ValueError(
            f"Provided `loss_name` {loss_name} not in list of available RNNT losses \n"
            f"{loss_function_names}")

    all_available_losses = {
        name: config
        for name, config in RNNT_LOSS_RESOLVER.items() if config.is_available
    }

    loss_config = RNNT_LOSS_RESOLVER[loss_name]  # type: RNNTLossConfig

    # Re-raise import error with installation message
    if not loss_config.is_available:
        msg = (
            f"Installed RNNT losses are : {list(all_available_losses.keys())}.\n"
            f"****************************************************************\n"
            f"To install the selected loss function, please follow the steps below:\n"
            f"{loss_config.installation_msg}")
        raise ImportError(msg)

    # Resolve loss functions sequentially
    loss_kwargs = {} if loss_kwargs is None else loss_kwargs

    if loss_name == 'warprnnt':
        loss_func = warprnnt.RNNTLoss(blank=blank_idx, reduction='none')
        _warn_unused_additional_kwargs(loss_name, loss_kwargs)

    elif loss_name == 'warprnnt_numba':
        loss_func = RNNTLossNumba(blank=blank_idx, reduction='none')
        _warn_unused_additional_kwargs(loss_name, loss_kwargs)

    else:
        raise ValueError(
            f"Invalid value of `loss_name`: {loss_name}. Allowed loss names are :"
            f"{loss_function_names}")

    return loss_func
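
A hypothetical call, assuming 'warprnnt_numba' is registered and available in RNNT_LOSS_RESOLVER:

vocab_size = 28  # hypothetical number of non-blank tokens
loss_fn = resolve_rnnt_loss('warprnnt_numba', blank_idx=vocab_size)
# Returns an RNNTLossNumba module constructed with reduction='none'.
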
Example #5
def resolve_rnnt_loss(loss_name: str,
                      blank_idx: int,
                      loss_kwargs: dict = None) -> torch.nn.Module:
    loss_function_names = list(RNNT_LOSS_RESOLVER.keys())

    if loss_name not in loss_function_names:
        raise ValueError(
            f"Provided `loss_name` {loss_name} not in list of available RNNT losses \n"
            f"{loss_function_names}")

    all_available_losses = {
        name: config
        for name, config in RNNT_LOSS_RESOLVER.items() if config.is_available
    }

    loss_config = RNNT_LOSS_RESOLVER[loss_name]  # type: RNNTLossConfig

    # Re-raise import error with installation message
    if not loss_config.is_available:
        msg = (
            f"Installed RNNT losses are : {list(all_available_losses.keys())}.\n"
            f"****************************************************************\n"
            f"To install the selected loss function, please follow the steps below:\n"
            f"{loss_config.installation_msg}")
        raise ImportError(msg)

    # Library version check
    if loss_config.min_version is not None:
        ver_matched, msg = model_utils.check_lib_version(
            loss_config.lib_name,
            checked_version=loss_config.min_version,
            operator=operator.ge)

        if ver_matched is False:
            msg = (
                f"{msg}\n"
                f"****************************************************************\n"
                f"To update the selected loss function, please follow the steps below:\n"
                f"{loss_config.installation_msg}")
            raise RuntimeError(msg)

    # Resolve loss functions sequentially
    loss_kwargs = {} if loss_kwargs is None else loss_kwargs

    if isinstance(loss_kwargs, DictConfig):
        loss_kwargs = OmegaConf.to_container(loss_kwargs, resolve=True)

    # Get actual loss name for `default`
    if loss_name == 'default':
        loss_name = loss_config.loss_name
    """
    Resolve RNNT loss functions
    """
    if loss_name == 'warprnnt':
        loss_func = warprnnt.RNNTLoss(blank=blank_idx, reduction='none')
        _warn_unused_additional_kwargs(loss_name, loss_kwargs)

    elif loss_name == 'warprnnt_numba':
        fastemit_lambda = loss_kwargs.pop('fastemit_lambda', 0.0)
        clamp = loss_kwargs.pop('clamp', -1.0)
        loss_func = RNNTLossNumba(blank=blank_idx,
                                  reduction='none',
                                  fastemit_lambda=fastemit_lambda,
                                  clamp=clamp)
        _warn_unused_additional_kwargs(loss_name, loss_kwargs)

    else:
        raise ValueError(
            f"Invalid value of `loss_name`: {loss_name}. Allowed loss names are :"
            f"{loss_function_names}")

    return loss_func
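
This extended resolver also consumes per-loss keyword arguments; for 'warprnnt_numba' it pops fastemit_lambda and clamp as shown above. A hypothetical call with those options:

loss_fn = resolve_rnnt_loss('warprnnt_numba',
                            blank_idx=28,  # hypothetical blank index (= number of non-blank tokens)
                            loss_kwargs={'fastemit_lambda': 0.001, 'clamp': -1.0})
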
Example #6
    def __init__(self,
                 eos,
                 unk,
                 pad,
                 blank,
                 enc_n_units,
                 rnn_type,
                 n_units,
                 n_projs,
                 n_layers,
                 residual,
                 bottleneck_dim,
                 emb_dim,
                 vocab,
                 tie_embedding=False,
                 dropout=0.0,
                 dropout_emb=0.0,
                 lsm_prob=0.0,
                 ctc_weight=0.0,
                 ctc_lsm_prob=0.0,
                 ctc_fc_list=[],
                 lm_init=None,
                 lmobj_weight=0.0,
                 share_lm_softmax=False,
                 global_weight=1.0,
                 mtl_per_batch=False,
                 param_init=0.1,
                 start_pointing=False,
                 end_pointing=True):

        super(RNNTransducer, self).__init__()
        logger = logging.getLogger('training')

        self.eos = eos
        self.unk = unk
        self.pad = pad
        self.blank = blank
        self.vocab = vocab
        self.rnn_type = rnn_type
        assert rnn_type in ['lstm_transducer', 'gru_transducer']
        self.enc_n_units = enc_n_units
        self.dec_n_units = n_units
        self.n_projs = n_projs
        self.n_layers = n_layers
        self.residual = residual
        self.lsm_prob = lsm_prob
        self.ctc_weight = ctc_weight
        self.lmobj_weight = lmobj_weight
        self.share_lm_softmax = share_lm_softmax
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch

        # VAD
        self.start_pointing = start_pointing
        self.end_pointing = end_pointing

        # for cache
        self.prev_spk = ''
        self.lmstate_final = None
        self.state_cache = OrderedDict()

        if ctc_weight > 0:
            self.ctc = CTC(eos=eos,
                           blank=blank,
                           enc_n_units=enc_n_units,
                           vocab=vocab,
                           dropout=dropout,
                           lsm_prob=ctc_lsm_prob,
                           fc_list=ctc_fc_list,
                           param_init=param_init)

        if ctc_weight < global_weight:
            import warprnnt_pytorch
            self.warprnnt_loss = warprnnt_pytorch.RNNTLoss()

            # for MTL with LM objective
            if lmobj_weight > 0:
                if share_lm_softmax:
                    self.output_lmobj = self.output  # share parameters
                else:
                    self.output_lmobj = Linear(n_units, vocab)

            # Prediction network
            self.fast_impl = False
            rnn = nn.LSTM if rnn_type == 'lstm_transducer' else nn.GRU
            if n_projs == 0 and not residual:
                self.fast_impl = True
                self.rnn = rnn(emb_dim, n_units, n_layers,
                               bias=True,
                               batch_first=True,
                               dropout=dropout,
                               bidirectional=False)
                # NOTE: pytorch introduces a dropout layer on the outputs of each layer EXCEPT the last layer
                dec_idim = n_units
                self.dropout_top = nn.Dropout(p=dropout)
            else:
                self.rnn = nn.ModuleList()
                self.dropout = nn.ModuleList([nn.Dropout(p=dropout) for _ in range(n_layers)])
                if n_projs > 0:
                    self.proj = nn.ModuleList([Linear(n_units, n_projs) for _ in range(n_layers)])
                dec_idim = emb_dim
                for l in range(n_layers):
                    self.rnn += [rnn(dec_idim, n_units, 1,
                                     bias=True,
                                     batch_first=True,
                                     dropout=0,
                                     bidirectional=False)]
                    dec_idim = n_projs if n_projs > 0 else n_units

            self.embed = Embedding(vocab, emb_dim,
                                   dropout=dropout_emb,
                                   ignore_index=pad)

            self.w_enc = Linear(enc_n_units, bottleneck_dim, bias=True)
            self.w_dec = Linear(dec_idim, bottleneck_dim, bias=False)
            self.output = Linear(bottleneck_dim, vocab)

        # Initialize parameters
        self.reset_parameters(param_init)

        # Initialize the prediction network with a pre-trained LM
        if lm_init is not None:
            assert lm_init.vocab == vocab
            assert lm_init.n_units == n_units
            assert lm_init.n_projs == n_projs
            assert lm_init.n_layers == n_layers
            assert lm_init.residual == residual

            param_dict = dict(lm_init.named_parameters())
            for n, p in self.named_parameters():
                if n in param_dict.keys() and p.size() == param_dict[n].size():
                    if 'output' in n:
                        continue
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s' % n)
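
The recurrency() method that chains these modules is not part of this excerpt; the following is a hypothetical sketch of how self.rnn, self.proj, and self.dropout are typically applied in the non-fast_impl path of the prediction network.

import torch

def recurrency_sketch(dec: 'RNNTransducer', ys_emb: torch.Tensor) -> torch.Tensor:
    # ys_emb: [B, L, emb_dim] embedded target prefixes
    out = ys_emb
    for l in range(dec.n_layers):
        out, _ = dec.rnn[l](out)                  # [B, L, n_units]
        if dec.n_projs > 0:
            out = torch.tanh(dec.proj[l](out))    # [B, L, n_projs]
        out = dec.dropout[l](out)
    return out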