def forward_transducer(self, eouts, elens, ys):
    """Compute Transducer loss.

    Args:
        eouts (FloatTensor): `[B, T, enc_n_units]`
        elens (IntTensor): `[B]`
        ys (list): length `B`, each of which contains a list of size `[L]`
    Returns:
        loss (FloatTensor): `[1]`

    """
    # Prepend <eos> (reused as <sos>) to each target sequence
    _ys = [np2tensor(np.fromiter(y, dtype=np.int64), eouts.device) for y in ys]
    ylens = np2tensor(np.fromiter([y.size(0) for y in _ys], dtype=np.int32))
    eos = eouts.new_zeros((1,), dtype=torch.int64).fill_(self.eos)
    ys_in = pad_list([torch.cat([eos, y], dim=0) for y in _ys], self.pad)  # `[B, L+1]`
    ys_out = pad_list(_ys, self.blank)  # `[B, L]`

    # Update prediction network
    ys_emb = self.dropout_emb(self.embed(ys_in))
    dout, _ = self.recurrency(ys_emb, None)

    # Compute output distribution
    logits = self.joint(eouts, dout)  # `[B, T, L+1, vocab]`

    # Compute Transducer loss
    log_probs = torch.log_softmax(logits, dim=-1)
    assert log_probs.size(2) == ys_out.size(1) + 1
    if self.device_id >= 0:
        ys_out = ys_out.to(eouts.device)
        elens = elens.to(eouts.device)
        ylens = ylens.to(eouts.device)
        import warp_rnnt
        loss = warp_rnnt.rnnt_loss(log_probs, ys_out.int(), elens, ylens,
                                   average_frames=False,
                                   reduction='mean',
                                   gather=False)
    else:
        import warprnnt_pytorch
        self.warprnnt_loss = warprnnt_pytorch.RNNTLoss()
        loss = self.warprnnt_loss(log_probs, ys_out.int(), elens, ylens)
        # NOTE: Transducer loss has already been normalized by bs
        # NOTE: index 0 is reserved for blank in warprnnt_pytorch
    return loss
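# A minimal, self-contained sketch (not from the repository above) of the call
# signature that forward_transducer relies on: warprnnt_pytorch.RNNTLoss takes
# acts of shape `[B, T, U+1, V]`, int32 labels `[B, U]` in which blank (index 0)
# never appears, and int32 frame/label lengths. log_softmax is idempotent, so
# passing log-probs as above is safe. All sizes here are illustrative.
import torch
import warprnnt_pytorch

B, T, U, V = 2, 50, 10, 32
logits = torch.randn(B, T, U + 1, V, requires_grad=True)
labels = torch.randint(1, V, (B, U), dtype=torch.int32)  # 0 is reserved for blank
frame_lens = torch.full((B,), T, dtype=torch.int32)
label_lens = torch.full((B,), U, dtype=torch.int32)

rnnt_loss = warprnnt_pytorch.RNNTLoss()  # blank=0, reduction='mean' by default
loss = rnnt_loss(torch.log_softmax(logits, dim=-1), labels, frame_lens, label_lens)
loss.backward()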
def __init__(self,
             vocab: Vocabulary,
             input_size: int,
             hidden_size: int,
             loss_ratio: float = 1.0,
             recurrency: nn.LSTM = None,
             num_layers: int = None,
             remove_sos: bool = True,
             remove_eos: bool = False,
             target_embedder: Embedding = None,
             target_embedding_dim: int = None,
             target_namespace: str = "tokens",
             slow_decode: bool = False,
             initializer: InitializerApplicator = InitializerApplicator(),
             regularizer: Optional[RegularizerApplicator] = None) -> None:
    super(RNNTLayer, self).__init__(vocab, regularizer)
    import warprnnt_pytorch

    self.loss_ratio = loss_ratio
    self._remove_sos = remove_sos
    self._remove_eos = remove_eos
    self._slow_decode = slow_decode
    self._target_namespace = target_namespace
    self._num_classes = self.vocab.get_vocab_size(target_namespace)
    self._pad_index = self.vocab.get_token_index(DEFAULT_PADDING_TOKEN, self._target_namespace)
    self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
    self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)

    # The padding index doubles as the RNN-T blank label here
    self._loss = warprnnt_pytorch.RNNTLoss(blank=self._pad_index, reduction='mean')

    self._recurrency = recurrency or nn.LSTM(input_size=target_embedding_dim,
                                             hidden_size=hidden_size,
                                             num_layers=num_layers,
                                             batch_first=True)
    self._target_embedder = target_embedder or Embedding(self._num_classes, target_embedding_dim)

    self.w_enc = nn.Linear(input_size, hidden_size, bias=True)
    # The prediction network outputs `hidden_size` features, so the decoder-side
    # projection must consume `hidden_size` (not `input_size` as originally written)
    self.w_dec = nn.Linear(hidden_size, hidden_size, bias=False)
    self._proj = nn.Linear(hidden_size, self._num_classes)

    exclude_indices = {self._pad_index, self._end_index, self._start_index}
    self._wer: Metric = WER(exclude_indices=exclude_indices)
    self._bleu: Metric = BLEU(exclude_indices=exclude_indices)
    self._dal = Average()

    initializer(self)
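# Hedged sketch (this method is not part of the class above) of how an additive
# RNN-T joint network typically combines the w_enc, w_dec and _proj layers from
# __init__: the two projections are broadcast-added over the time and label
# axes before the final vocabulary projection. Shape names are illustrative.
import torch

def _joint_sketch(self, enc_out: torch.Tensor, dec_out: torch.Tensor) -> torch.Tensor:
    # enc_out: [B, T, input_size], dec_out: [B, U+1, hidden_size]
    enc = self.w_enc(enc_out).unsqueeze(2)     # [B, T, 1, hidden_size]
    dec = self.w_dec(dec_out).unsqueeze(1)     # [B, 1, U+1, hidden_size]
    return self._proj(torch.tanh(enc + dec))   # [B, T, U+1, num_classes]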
def __init__(self, num_classes, reduction='mean_batch'):
    """
    RNN-T Loss function based on https://github.com/HawkAaron/warp-transducer.

    Note:
        Requires the pytorch bindings to be installed prior to calling this class.

    Warning:
        In the case that GPU memory is exhausted while computing RNNTLoss, it might cause
        a core dump at the cuda level with the following error message:

        ```
            ...
            costs = costs.to(acts.device)
        RuntimeError: CUDA error: an illegal memory access was encountered
        terminate called after throwing an instance of 'c10::Error'
        ```

        Please kill all remaining python processes after this point, and use a smaller batch size
        for train, validation and test sets so that CUDA memory is not exhausted.

    Args:
        num_classes: Number of target classes for the joint network to predict.
            (Excluding the RNN-T blank token).
        reduction: Type of reduction to perform on the loss. Possible values are
            `mean_batch`, `mean`, `sum` or None. None will return a torch vector comprising
            the individual loss values of the batch.
    """
    super(RNNTLoss, self).__init__()

    if not WARP_RNNT_AVAILABLE:
        raise ImportError(
            "Could not import `warprnnt_pytorch`.\n"
            "Please visit https://github.com/HawkAaron/warp-transducer "
            "and follow the steps in the readme to build and install the "
            "pytorch bindings for RNNT Loss, or use the provided docker "
            "container that supports RNN-T loss.")

    if reduction not in [None, 'mean', 'sum', 'mean_batch']:
        raise ValueError('`reduction` must be one of [None, mean, sum, mean_batch]')

    self._blank = num_classes  # blank is assigned the last index
    self.reduction = reduction
    # Request per-sample losses; `self.reduction` is applied afterwards
    self._loss = warprnnt.RNNTLoss(blank=self._blank, reduction='none')
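# A hedged sketch of how the per-sample losses requested above
# (warprnnt.RNNTLoss(reduction='none') returns a `[B]` vector) could be turned
# into the reductions the docstring advertises. This forward is an assumption
# for illustration, not the library's actual implementation.
def forward(self, log_probs, targets, input_lengths, target_lengths):
    loss = self._loss(log_probs, targets, input_lengths, target_lengths)  # [B]
    if self.reduction == 'mean_batch':
        loss = loss.mean()                         # average over utterances
    elif self.reduction == 'mean':
        loss = loss.sum() / target_lengths.sum()   # average over target tokens
    elif self.reduction == 'sum':
        loss = loss.sum()
    return loss                                    # reduction=None -> `[B]` vector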
def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) -> torch.nn.Module:
    loss_function_names = list(RNNT_LOSS_RESOLVER.keys())

    if loss_name not in loss_function_names:
        raise ValueError(
            f"Provided `loss_name` {loss_name} not in list of available RNNT losses \n"
            f"{loss_function_names}")

    all_available_losses = {
        name: config for name, config in RNNT_LOSS_RESOLVER.items() if config.is_available
    }

    loss_config = RNNT_LOSS_RESOLVER[loss_name]  # type: RNNTLossConfig

    # Re-raise import error with installation message
    if not loss_config.is_available:
        msg = (f"Installed RNNT losses are : {list(all_available_losses.keys())}.\n"
               f"****************************************************************\n"
               f"To install the selected loss function, please follow the steps below:\n"
               f"{loss_config.installation_msg}")
        raise ImportError(msg)

    # Resolve loss functions sequentially
    loss_kwargs = {} if loss_kwargs is None else loss_kwargs

    if loss_name == 'warprnnt':
        loss_func = warprnnt.RNNTLoss(blank=blank_idx, reduction='none')
        _warn_unused_additional_kwargs(loss_name, loss_kwargs)

    elif loss_name == 'warprnnt_numba':
        loss_func = RNNTLossNumba(blank=blank_idx, reduction='none')
        _warn_unused_additional_kwargs(loss_name, loss_kwargs)

    else:
        raise ValueError(
            f"Invalid value of `loss_name`: {loss_name}. Allowed loss names are :"
            f"{loss_function_names}")

    return loss_func
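# Hedged sketch of the registry shape the resolver above reads. The attribute
# names (loss_name, lib_name, is_available, installation_msg) mirror the fields
# accessed above; the dataclass layout and the entry contents are illustrative
# assumptions, not the library's definitions.
from dataclasses import dataclass

@dataclass
class RNNTLossConfig:
    loss_name: str
    lib_name: str
    is_available: bool = False
    installation_msg: str = ""

RNNT_LOSS_RESOLVER = {
    'warprnnt': RNNTLossConfig(
        loss_name='warprnnt',
        lib_name='warprnnt_pytorch',
        is_available=WARP_RNNT_AVAILABLE,
        installation_msg='Build the pytorch bindings from '
                         'https://github.com/HawkAaron/warp-transducer'),
}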
def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) -> torch.nn.Module:
    loss_function_names = list(RNNT_LOSS_RESOLVER.keys())

    if loss_name not in loss_function_names:
        raise ValueError(
            f"Provided `loss_name` {loss_name} not in list of available RNNT losses \n"
            f"{loss_function_names}")

    all_available_losses = {
        name: config for name, config in RNNT_LOSS_RESOLVER.items() if config.is_available
    }

    loss_config = RNNT_LOSS_RESOLVER[loss_name]  # type: RNNTLossConfig

    # Re-raise import error with installation message
    if not loss_config.is_available:
        msg = (f"Installed RNNT losses are : {list(all_available_losses.keys())}.\n"
               f"****************************************************************\n"
               f"To install the selected loss function, please follow the steps below:\n"
               f"{loss_config.installation_msg}")
        raise ImportError(msg)

    # Library version check
    if loss_config.min_version is not None:
        ver_matched, msg = model_utils.check_lib_version(
            loss_config.lib_name, checked_version=loss_config.min_version, operator=operator.ge)

        if ver_matched is False:
            msg = (f"{msg}\n"
                   f"****************************************************************\n"
                   f"To update the selected loss function, please follow the steps below:\n"
                   f"{loss_config.installation_msg}")
            raise RuntimeError(msg)

    # Resolve loss functions sequentially
    loss_kwargs = {} if loss_kwargs is None else loss_kwargs

    if isinstance(loss_kwargs, DictConfig):
        loss_kwargs = OmegaConf.to_container(loss_kwargs, resolve=True)

    # Get actual loss name for `default`
    if loss_name == 'default':
        loss_name = loss_config.loss_name

    # Resolve RNNT loss functions
    if loss_name == 'warprnnt':
        loss_func = warprnnt.RNNTLoss(blank=blank_idx, reduction='none')
        _warn_unused_additional_kwargs(loss_name, loss_kwargs)

    elif loss_name == 'warprnnt_numba':
        fastemit_lambda = loss_kwargs.pop('fastemit_lambda', 0.0)
        clamp = loss_kwargs.pop('clamp', -1.0)
        loss_func = RNNTLossNumba(blank=blank_idx, reduction='none',
                                  fastemit_lambda=fastemit_lambda, clamp=clamp)
        _warn_unused_additional_kwargs(loss_name, loss_kwargs)

    else:
        raise ValueError(
            f"Invalid value of `loss_name`: {loss_name}. Allowed loss names are :"
            f"{loss_function_names}")

    return loss_func
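# Hypothetical usage sketch for the resolver above: request the Numba-based
# loss with FastEmit regularization. The kwargs shown are exactly the ones the
# `warprnnt_numba` branch pops; the numeric values are illustrative.
loss_fn = resolve_rnnt_loss(
    loss_name='warprnnt_numba',
    blank_idx=0,
    loss_kwargs={'fastemit_lambda': 0.001, 'clamp': -1.0})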
def __init__(self, eos, unk, pad, blank, enc_n_units, rnn_type,
             n_units, n_projs, n_layers, residual, bottleneck_dim,
             emb_dim, vocab, tie_embedding=False,
             dropout=0.0, dropout_emb=0.0, lsm_prob=0.0,
             ctc_weight=0.0, ctc_lsm_prob=0.0, ctc_fc_list=[],
             lm_init=None, lmobj_weight=0.0, share_lm_softmax=False,
             global_weight=1.0, mtl_per_batch=False, param_init=0.1,
             start_pointing=False, end_pointing=True):

    super(RNNTransducer, self).__init__()
    logger = logging.getLogger('training')

    self.eos = eos
    self.unk = unk
    self.pad = pad
    self.blank = blank
    self.vocab = vocab
    self.rnn_type = rnn_type
    assert rnn_type in ['lstm_transducer', 'gru_transducer']
    self.enc_n_units = enc_n_units
    self.dec_n_units = n_units
    self.n_projs = n_projs
    self.n_layers = n_layers
    self.residual = residual
    self.lsm_prob = lsm_prob
    self.ctc_weight = ctc_weight
    self.lmobj_weight = lmobj_weight
    self.share_lm_softmax = share_lm_softmax
    self.global_weight = global_weight
    self.mtl_per_batch = mtl_per_batch

    # VAD
    self.start_pointing = start_pointing
    self.end_pointing = end_pointing

    # for cache
    self.prev_spk = ''
    self.lmstate_final = None
    self.state_cache = OrderedDict()

    if ctc_weight > 0:
        self.ctc = CTC(eos=eos,
                       blank=blank,
                       enc_n_units=enc_n_units,
                       vocab=vocab,
                       dropout=dropout,
                       lsm_prob=ctc_lsm_prob,
                       fc_list=ctc_fc_list,
                       param_init=param_init)

    if ctc_weight < global_weight:
        import warprnnt_pytorch
        self.warprnnt_loss = warprnnt_pytorch.RNNTLoss()

    # Prediction network
    self.fast_impl = False
    rnn = nn.LSTM if rnn_type == 'lstm_transducer' else nn.GRU
    if n_projs == 0 and not residual:
        self.fast_impl = True
        self.rnn = rnn(emb_dim, n_units, n_layers,
                       bias=True,
                       batch_first=True,
                       dropout=dropout,
                       bidirectional=False)
        # NOTE: PyTorch applies dropout to the outputs of each layer EXCEPT the last one
        dec_idim = n_units
        self.dropout_top = nn.Dropout(p=dropout)
    else:
        self.rnn = nn.ModuleList()
        self.dropout = nn.ModuleList([nn.Dropout(p=dropout) for _ in range(n_layers)])
        if n_projs > 0:
            # project each layer's n_units outputs down to n_projs
            # (the flattened original read `Linear(dec_idim, n_projs)`, but
            # dec_idim is not defined until the line below)
            self.proj = nn.ModuleList([Linear(n_units, n_projs) for _ in range(n_layers)])
        dec_idim = emb_dim
        for _ in range(n_layers):
            self.rnn.append(rnn(dec_idim, n_units, 1,
                                bias=True,
                                batch_first=True,
                                dropout=0,
                                bidirectional=False))
            dec_idim = n_projs if n_projs > 0 else n_units

    self.embed = Embedding(vocab, emb_dim,
                           dropout=dropout_emb,
                           ignore_index=pad)

    self.w_enc = Linear(enc_n_units, bottleneck_dim, bias=True)
    self.w_dec = Linear(dec_idim, bottleneck_dim, bias=False)
    self.output = Linear(bottleneck_dim, vocab)

    # for MTL with LM objective
    # (placed after self.output is defined so that sharing the softmax works)
    if lmobj_weight > 0:
        if share_lm_softmax:
            self.output_lmobj = self.output  # share parameters
        else:
            self.output_lmobj = Linear(n_units, vocab)

    # Initialize parameters
    self.reset_parameters(param_init)

    # prediction network initialization with pre-trained LM
    if lm_init is not None:
        assert lm_init.vocab == vocab
        assert lm_init.n_units == n_units
        assert lm_init.n_projs == n_projs
        assert lm_init.n_layers == n_layers
        assert lm_init.residual == residual
        param_dict = dict(lm_init.named_parameters())
        for n, p in self.named_parameters():
            if n in param_dict.keys() and p.size() == param_dict[n].size():
                if 'output' in n:
                    continue
                p.data = param_dict[n].data
                logger.info('Overwrite %s' % n)
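# Hedged reconstruction of the joint() that forward_transducer (first snippet
# above) calls on this class. It is an assumption based on the w_enc/w_dec/
# output layers defined in __init__, not verbatim repository code: the two
# bottleneck projections are broadcast-added over the time and label axes.
def joint(self, eouts, douts):
    """Combine encoder outputs and prediction network outputs.

    Args:
        eouts (FloatTensor): `[B, T, enc_n_units]`
        douts (FloatTensor): `[B, L+1, dec_n_units]`
    Returns:
        out (FloatTensor): `[B, T, L+1, vocab]`

    """
    eouts = self.w_enc(eouts).unsqueeze(2)  # `[B, T, 1, bottleneck_dim]`
    douts = self.w_dec(douts).unsqueeze(1)  # `[B, 1, L+1, bottleneck_dim]`
    return self.output(torch.tanh(eouts + douts))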