def __init__(
    self,
    vocab_size: int,
    token_list: Union[Tuple[str, ...], List[str]],
    frontend: Optional[AbsFrontend],
    specaug: Optional[AbsSpecAug],
    normalize: Optional[AbsNormalize],
    preencoder: Optional[AbsPreEncoder],
    encoder: AbsEncoder,
    ignore_id: int = -1,
    lsm_weight: float = 0.0,
    length_normalized_loss: bool = False,
    report_cer: bool = False,
    report_wer: bool = False,
    sym_space: str = "<space>",
    sym_blank: str = "<blank>",
    pred_masked_weight: float = 1.0,
    pred_nomask_weight: float = 0.0,
    loss_weights: float = 0.0,
):
    """Store the model components and build the HuBERT pretraining loss.

    Args:
        vocab_size: Size of the token vocabulary. The last ID
            (``vocab_size - 1``) is used for both ``self.sos`` and
            ``self.eos``.
        token_list: Token strings. A list copy is stored on
            ``self.token_list``; both tuples and lists are accepted, as the
            type hint advertises.
        frontend: Optional front-end component, stored as-is.
        specaug: Optional SpecAug component, stored as-is.
        normalize: Optional normalization component, stored as-is.
        preencoder: Optional pre-encoder component, stored as-is.
        encoder: Encoder component, stored as-is.
        ignore_id: Stored on ``self.ignore_id``.
        lsm_weight: Accepted but not used by this constructor.
        length_normalized_loss: Accepted but not used by this constructor.
        report_cer: If this or ``report_wer`` is true, an ``ErrorCalculator``
            is created; otherwise ``self.error_calculator`` is ``None``.
        report_wer: See ``report_cer``.
        sym_space: Space symbol forwarded to ``ErrorCalculator``.
        sym_blank: Blank symbol forwarded to ``ErrorCalculator``.
        pred_masked_weight: Forwarded (positionally, first) to
            ``HubertPretrainLoss`` and kept as an attribute.
        pred_nomask_weight: Forwarded (positionally, second) to
            ``HubertPretrainLoss`` and kept as an attribute.
        loss_weights: Forwarded (positionally, third) to
            ``HubertPretrainLoss`` and kept as an attribute.
    """
    assert check_argument_types()
    super().__init__()
    # note that eos is the same as sos (equivalent ID)
    self.sos = vocab_size - 1
    self.eos = vocab_size - 1
    self.vocab_size = vocab_size
    self.ignore_id = ignore_id
    # list(...) rather than .copy(): tuples are allowed by the type hint but
    # have no .copy() method, so .copy() raised AttributeError for them.
    # For a list input this is the same shallow copy as before.
    self.token_list = list(token_list)

    self.frontend = frontend
    self.specaug = specaug
    self.normalize = normalize
    self.preencoder = preencoder
    self.encoder = encoder

    self.criterion_att = HubertPretrainLoss(
        pred_masked_weight,
        pred_nomask_weight,
        loss_weights,
    )
    self.pred_masked_weight = pred_masked_weight
    self.pred_nomask_weight = pred_nomask_weight
    self.loss_weights = loss_weights

    # Only build the (optional) error calculator when some metric is requested.
    if report_cer or report_wer:
        self.error_calculator = ErrorCalculator(
            token_list, sym_space, sym_blank, report_cer, report_wer
        )
    else:
        self.error_calculator = None
def test_hubert_loss_forward_backward(hubert_args):
    """Smoke-test ``HubertPretrainLoss`` on the ``hubert_args`` inputs.

    A loss object is built with its default constructor arguments and called
    with ``hubert_args`` unpacked as positional arguments. The test passes as
    long as that call does not raise. ``hubert_args`` is presumably a pytest
    fixture; confirm in the surrounding conftest/test module.

    NOTE(review): despite the name, no backward pass is exercised here.
    """
    HubertPretrainLoss()(*hubert_args)