# Example no. 1
# 0
class E2E(ASRInterface, torch.nn.Module):
    """E2E module for transducer models.

    Args:
        idim: Dimension of inputs.
        odim: Dimension of outputs.
        args: Namespace containing model options.
        ignore_id: Padding symbol ID.
        blank_id: Blank symbol ID.
        training: Whether the model is built for training. When False, the
            auxiliary encoder outputs and the error calculator are disabled.

    """

    @staticmethod
    def add_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Add arguments for transducer model.

        Registers every argument group used by this model on ``parser``.

        Args:
            parser: Parser to extend.

        Returns:
            The same ``parser``, for chaining.

        """
        E2E.encoder_add_general_arguments(parser)
        E2E.encoder_add_rnn_arguments(parser)
        E2E.encoder_add_custom_arguments(parser)

        E2E.decoder_add_general_arguments(parser)
        E2E.decoder_add_rnn_arguments(parser)
        E2E.decoder_add_custom_arguments(parser)

        E2E.training_add_custom_arguments(parser)
        E2E.transducer_add_arguments(parser)
        E2E.auxiliary_task_add_arguments(parser)

        return parser

    @staticmethod
    def encoder_add_general_arguments(
            parser: ArgumentParser) -> ArgumentParser:
        """Add general arguments for encoder."""
        group = parser.add_argument_group("Encoder general arguments")
        add_encoder_general_arguments(group)

        return parser

    @staticmethod
    def encoder_add_rnn_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Add arguments for RNN encoder."""
        group = parser.add_argument_group("RNN encoder arguments")
        add_rnn_encoder_arguments(group)

        return parser

    @staticmethod
    def encoder_add_custom_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Add arguments for Custom encoder."""
        group = parser.add_argument_group("Custom encoder arguments")
        add_custom_encoder_arguments(group)

        return parser

    @staticmethod
    def decoder_add_general_arguments(
            parser: ArgumentParser) -> ArgumentParser:
        """Add general arguments for decoder."""
        group = parser.add_argument_group("Decoder general arguments")
        add_decoder_general_arguments(group)

        return parser

    @staticmethod
    def decoder_add_rnn_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Add arguments for RNN decoder."""
        group = parser.add_argument_group("RNN decoder arguments")
        add_rnn_decoder_arguments(group)

        return parser

    @staticmethod
    def decoder_add_custom_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Add arguments for Custom decoder."""
        group = parser.add_argument_group("Custom decoder arguments")
        add_custom_decoder_arguments(group)

        return parser

    @staticmethod
    def training_add_custom_arguments(
            parser: ArgumentParser) -> ArgumentParser:
        """Add arguments for Custom architecture training."""
        group = parser.add_argument_group(
            "Training arguments for custom architecture")
        add_custom_training_arguments(group)

        return parser

    @staticmethod
    def transducer_add_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Add arguments for transducer model."""
        group = parser.add_argument_group("Transducer model arguments")
        add_transducer_arguments(group)

        return parser

    @staticmethod
    def auxiliary_task_add_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Add arguments for auxiliary task."""
        group = parser.add_argument_group("Auxiliary task arguments")
        add_auxiliary_task_arguments(group)

        return parser

    @property
    def attention_plot_class(self):
        """Get attention plot class."""
        return PlotAttentionReport

    def get_total_subsampling_factor(self) -> float:
        """Get total subsampling factor.

        Returns:
            The encoder's convolutional subsampling factor multiplied by the
            product of the frame subsampling factors in ``self.subsample``.

        """
        encoder = self.encoder if self.etype == "custom" else self.enc

        return encoder.conv_subsampling_factor * int(
            numpy.prod(self.subsample))

    def __init__(
        self,
        idim: int,
        odim: int,
        args: Namespace,
        ignore_id: int = -1,
        blank_id: int = 0,
        training: bool = True,
    ):
        """Construct an E2E object for transducer model.

        Args:
            idim: Dimension of inputs.
            odim: Dimension of outputs.
            args: Namespace containing model options.
            ignore_id: Padding symbol ID.
            blank_id: Blank symbol ID.
            training: Whether the model is built for training.

        Raises:
            ValueError: If a custom encoder/decoder type is requested without
                the corresponding ``--enc-block-arch``/``--dec-block-arch``.

        """
        torch.nn.Module.__init__(self)

        args = fill_missing_args(args, self.add_arguments)

        self.is_transducer = True

        # Auxiliary encoder outputs are only requested when the auxiliary
        # transducer loss is enabled at training time.
        self.use_auxiliary_enc_outputs = bool(
            training and args.use_aux_transducer_loss)

        self.subsample = get_subsample(
            args,
            mode="asr",
            arch="transformer" if args.etype == "custom" else "rnn-t")

        if self.use_auxiliary_enc_outputs:
            # Index of the last encoder layer that may serve as an auxiliary
            # output (the final layer itself is excluded).
            n_layers = (((len(args.enc_block_arch) * args.enc_block_repeat) -
                         1) if args.enc_block_arch is not None else
                        (args.elayers - 1))

            aux_enc_output_layers = valid_aux_encoder_output_layers(
                args.aux_transducer_loss_enc_output_layers,
                n_layers,
                args.use_symm_kl_div_loss,
                self.subsample,
            )
        else:
            aux_enc_output_layers = []

        if args.etype == "custom":
            if args.enc_block_arch is None:
                raise ValueError(
                    "When specifying custom encoder type, --enc-block-arch "
                    "should also be specified in training config. See "
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.encoder = CustomEncoder(
                idim,
                args.enc_block_arch,
                input_layer=args.custom_enc_input_layer,
                repeat_block=args.enc_block_repeat,
                self_attn_type=args.custom_enc_self_attn_type,
                positional_encoding_type=args.
                custom_enc_positional_encoding_type,
                positionwise_activation_type=args.
                custom_enc_pw_activation_type,
                conv_mod_activation_type=args.
                custom_enc_conv_mod_activation_type,
                aux_enc_output_layers=aux_enc_output_layers,
            )
            encoder_out = self.encoder.enc_out
        else:
            self.enc = encoder_for(
                args,
                idim,
                self.subsample,
                aux_enc_output_layers=aux_enc_output_layers,
            )
            encoder_out = args.eprojs

        if args.dtype == "custom":
            if args.dec_block_arch is None:
                raise ValueError(
                    "When specifying custom decoder type, --dec-block-arch "
                    "should also be specified in training config. See "
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.decoder = CustomDecoder(
                odim,
                args.dec_block_arch,
                input_layer=args.custom_dec_input_layer,
                repeat_block=args.dec_block_repeat,
                positionwise_activation_type=args.
                custom_dec_pw_activation_type,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
                blank_id=blank_id,
            )
            decoder_out = self.decoder.dunits
        else:
            self.dec = RNNDecoder(
                odim,
                args.dtype,
                args.dlayers,
                args.dunits,
                args.dec_embed_dim,
                dropout_rate=args.dropout_rate_decoder,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
                blank_id=blank_id,
            )
            decoder_out = args.dunits

        self.transducer_tasks = TransducerTasks(
            encoder_out,
            decoder_out,
            args.joint_dim,
            odim,
            joint_activation_type=args.joint_activation_type,
            transducer_loss_weight=args.transducer_weight,
            ctc_loss=args.use_ctc_loss,
            ctc_loss_weight=args.ctc_loss_weight,
            ctc_loss_dropout_rate=args.ctc_loss_dropout_rate,
            lm_loss=args.use_lm_loss,
            lm_loss_weight=args.lm_loss_weight,
            lm_loss_smoothing_rate=args.lm_loss_smoothing_rate,
            aux_transducer_loss=args.use_aux_transducer_loss,
            aux_transducer_loss_weight=args.aux_transducer_loss_weight,
            aux_transducer_loss_mlp_dim=args.aux_transducer_loss_mlp_dim,
            aux_trans_loss_mlp_dropout_rate=args.
            aux_transducer_loss_mlp_dropout_rate,
            symm_kl_div_loss=args.use_symm_kl_div_loss,
            symm_kl_div_loss_weight=args.symm_kl_div_loss_weight,
            fastemit_lambda=args.fastemit_lambda,
            blank_id=blank_id,
            ignore_id=ignore_id,
            training=training,
        )

        if training and (args.report_cer or args.report_wer):
            self.error_calculator = ErrorCalculator(
                self.decoder if args.dtype == "custom" else self.dec,
                self.transducer_tasks.joint_network,
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None

        self.etype = args.etype
        self.dtype = args.dtype

        self.sos = odim - 1
        self.eos = odim - 1
        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.space = args.sym_space
        self.blank = args.sym_blank

        self.odim = odim

        self.reporter = Reporter()

        self.default_parameters(args)

        self.loss = None
        self.rnnlm = None

    def default_parameters(self, args: Namespace):
        """Initialize/reset parameters for transducer.

        Args:
            args: Namespace containing model options.

        """
        initializer(self, args)

    def forward(self, feats: torch.Tensor, feats_len: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """E2E forward.

        Args:
            feats: Feature sequences. (B, F, D_feats)
            feats_len: Feature sequences lengths. (B,)
            labels: Label ID sequences. (B, L)

        Returns:
            loss: Transducer loss value

        """
        # 1. encoder
        # Drop any padding beyond the longest sequence in the batch.
        feats = feats[:, :max(feats_len)]

        if self.etype == "custom":
            feats_mask = (make_non_pad_mask(feats_len.tolist()).to(
                feats.device).unsqueeze(-2))

            _enc_out, _enc_out_len = self.encoder(feats, feats_mask)
        else:
            _enc_out, _enc_out_len, _ = self.enc(feats, feats_len)

        # With auxiliary outputs enabled the encoder returns
        # (main, auxiliary) pairs for both outputs and lengths.
        if self.use_auxiliary_enc_outputs:
            enc_out, aux_enc_out = _enc_out[0], _enc_out[1]
            enc_out_len, aux_enc_out_len = _enc_out_len[0], _enc_out_len[1]
        else:
            enc_out, aux_enc_out = _enc_out, None
            enc_out_len, aux_enc_out_len = _enc_out_len, None

        # 2. decoder
        dec_in = get_decoder_input(labels, self.blank_id, self.ignore_id)

        if self.dtype == "custom":
            self.decoder.set_device(enc_out.device)

            dec_in_mask = target_mask(dec_in, self.blank_id)
            dec_out, _ = self.decoder(dec_in, dec_in_mask)
        else:
            self.dec.set_device(enc_out.device)

            dec_out = self.dec(dec_in)

        # 3. transducer tasks computation
        losses = self.transducer_tasks(
            enc_out,
            aux_enc_out,
            dec_out,
            labels,
            enc_out_len,
            aux_enc_out_len,
        )

        # CER/WER are only computed outside training mode, and only when an
        # error calculator was built.
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            cer, wer = self.error_calculator(
                enc_out, self.transducer_tasks.get_target())

        self.loss = sum(losses)
        loss_data = float(self.loss)

        if not math.isnan(loss_data):
            self.reporter.report(
                loss_data,
                *[float(loss) for loss in losses],
                cer,
                wer,
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)

        return self.loss

    def encode_custom(self, feats: numpy.ndarray) -> torch.Tensor:
        """Encode acoustic features with the custom encoder.

        The input is moved to the device and dtype of the model parameters,
        the same way ``encode_rnn`` does.

        Args:
            feats: Feature sequence. (F, D_feats)

        Returns:
            enc_out: Encoded feature sequence. (T, D_enc)

        """
        p = next(self.parameters())

        feats = torch.as_tensor(feats, device=p.device, dtype=p.dtype)
        enc_out, _ = self.encoder(feats.unsqueeze(0), None)

        return enc_out.squeeze(0)

    def encode_rnn(self, feats: numpy.ndarray) -> torch.Tensor:
        """Encode acoustic features with the RNN encoder.

        Args:
            feats: Feature sequence. (F, D_feats)

        Returns:
            enc_out: Encoded feature sequence. (T, D_enc)

        """
        p = next(self.parameters())

        feats = feats[::self.subsample[0], :]
        # Take the length after frame subsampling so it matches the sequence
        # actually passed to the encoder.
        feats_len = [feats.shape[0]]

        feats = torch.as_tensor(feats, device=p.device, dtype=p.dtype)
        feats = feats.contiguous().unsqueeze(0)

        enc_out, _, _ = self.enc(feats, feats_len)

        return enc_out.squeeze(0)

    def recognize(self, feats: numpy.ndarray,
                  beam_search: BeamSearchTransducer) -> List:
        """Recognize input features.

        Puts the model in eval mode before decoding.

        Args:
            feats: Feature sequence. (F, D_feats)
            beam_search: Beam search class.

        Returns:
            nbest_hyps: N-best decoding results, each converted with
                ``dataclasses.asdict``.

        """
        self.eval()

        if self.etype == "custom":
            enc_out = self.encode_custom(feats)
        else:
            enc_out = self.encode_rnn(feats)

        nbest_hyps = beam_search(enc_out)

        return [asdict(n) for n in nbest_hyps]

    def calculate_all_attentions(self, feats: torch.Tensor,
                                 feats_len: torch.Tensor,
                                 labels: torch.Tensor) -> numpy.ndarray:
        """E2E attention calculation.

        The model is switched to eval mode for the computation. Its previous
        train/eval mode is restored on every exit path, including the early
        return and exceptions.

        Args:
            feats: Feature sequences. (B, F, D_feats)
            feats_len: Feature sequences lengths. (B,)
            labels: Label ID sequences. (B, L)

        Returns:
            ret: Empty list when neither encoder nor decoder is custom.
                Otherwise a dict mapping module names to attention weights
                with the following shape,
                1) multi-head case => attention weights. (B, D_att, U, T),
                2) other case => attention weights. (B, U, T)

        """
        was_training = self.training
        self.eval()

        try:
            # Only custom (self-attention based) modules store attention
            # weights.
            if self.etype != "custom" and self.dtype != "custom":
                return []

            with torch.no_grad():
                self.forward(feats, feats_len, labels)

            ret = {}
            for name, m in self.named_modules():
                if isinstance(m, (MultiHeadedAttention,
                                  RelPositionMultiHeadedAttention)):
                    ret[name] = m.attn.cpu().numpy()

            return ret
        finally:
            self.train(was_training)
# Example no. 2
# 0
    def __init__(self,
                 idim,
                 odim,
                 args,
                 ignore_id=-1,
                 blank_id=0,
                 training=True):
        """Construct an E2E object for transducer model.

        Args:
            idim: Dimension of inputs.
            odim: Dimension of outputs.
            args: Namespace containing model options.
            ignore_id: Padding symbol ID.
            blank_id: Blank symbol ID.
            training: Whether the model is built for training. The transducer
                criterion, the error calculator and the auxiliary tasks are
                only created when True.

        Raises:
            ValueError: If a custom encoder/decoder type is requested without
                the corresponding ``--enc-block-arch``/``--dec-block-arch``.

        """
        torch.nn.Module.__init__(self)

        args = fill_missing_args(args, self.add_arguments)

        self.is_rnnt = True
        self.transducer_weight = args.transducer_weight

        self.use_aux_task = bool(args.aux_task_type is not None and training)

        self.use_aux_ctc = args.aux_ctc and training
        self.aux_ctc_weight = args.aux_ctc_weight

        self.use_aux_cross_entropy = args.aux_cross_entropy and training
        self.aux_cross_entropy_weight = args.aux_cross_entropy_weight

        if self.use_aux_task:
            # Index of the last encoder layer that may feed an auxiliary
            # task (the final layer itself is excluded).
            n_layers = ((len(args.enc_block_arch) * args.enc_block_repeat -
                         1) if args.enc_block_arch is not None else
                        (args.elayers - 1))

            aux_task_layer_list = valid_aux_task_layer_list(
                args.aux_task_layer_list,
                n_layers,
            )
        else:
            aux_task_layer_list = []

        if "custom" in args.etype:
            if args.enc_block_arch is None:
                raise ValueError(
                    "When specifying custom encoder type, --enc-block-arch "
                    "should also be specified in training config. See "
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.subsample = get_subsample(args,
                                           mode="asr",
                                           arch="transformer")

            self.encoder = CustomEncoder(
                idim,
                args.enc_block_arch,
                input_layer=args.custom_enc_input_layer,
                repeat_block=args.enc_block_repeat,
                self_attn_type=args.custom_enc_self_attn_type,
                positional_encoding_type=args.
                custom_enc_positional_encoding_type,
                positionwise_activation_type=args.
                custom_enc_pw_activation_type,
                conv_mod_activation_type=args.
                custom_enc_conv_mod_activation_type,
                aux_task_layer_list=aux_task_layer_list,
            )
            encoder_out = self.encoder.enc_out

            self.most_dom_list = args.enc_block_arch[:]
        else:
            self.subsample = get_subsample(args, mode="asr", arch="rnn-t")

            self.enc = encoder_for(
                args,
                idim,
                self.subsample,
                aux_task_layer_list=aux_task_layer_list,
            )
            encoder_out = args.eprojs

        use_custom_decoder = "custom" in args.dtype

        if use_custom_decoder:
            if args.dec_block_arch is None:
                raise ValueError(
                    "When specifying custom decoder type, --dec-block-arch "
                    "should also be specified in training config. See "
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.decoder = CustomDecoder(
                odim,
                args.dec_block_arch,
                input_layer=args.custom_dec_input_layer,
                repeat_block=args.dec_block_repeat,
                positionwise_activation_type=args.
                custom_dec_pw_activation_type,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
            )
            decoder_out = self.decoder.dunits

            if "custom" in args.etype:
                self.most_dom_list += args.dec_block_arch[:]
            else:
                self.most_dom_list = args.dec_block_arch[:]
        else:
            self.dec = DecoderRNNT(
                odim,
                args.dtype,
                args.dlayers,
                args.dunits,
                blank_id,
                args.dec_embed_dim,
                args.dropout_rate_decoder,
                args.dropout_rate_embed_decoder,
            )
            decoder_out = args.dunits

        self.joint_network = JointNetwork(odim, encoder_out, decoder_out,
                                          args.joint_dim,
                                          args.joint_activation_type)

        if hasattr(self, "most_dom_list"):
            # Largest "d_hidden" among the custom blocks that define one: the
            # (value, count) pairs are sorted on the value, descending.
            # NOTE(review): raises IndexError if no block defines "d_hidden".
            self.most_dom_dim = sorted(
                Counter(d["d_hidden"] for d in self.most_dom_list
                        if "d_hidden" in d).most_common(),
                key=lambda x: x[0],
                reverse=True,
            )[0][0]

        self.etype = args.etype
        self.dtype = args.dtype

        self.sos = odim - 1
        self.eos = odim - 1
        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.space = args.sym_space
        self.blank = args.sym_blank

        self.odim = odim

        self.reporter = Reporter()
        self.error_calculator = None

        self.default_parameters(args)

        if training:
            self.criterion = TransLoss(args.trans_type, self.blank_id)

            # Select the decoder with the same test used to build it above, so
            # the attribute that was actually created is the one picked.
            decoder = self.decoder if use_custom_decoder else self.dec

            if args.report_cer or args.report_wer:
                self.error_calculator = ErrorCalculator(
                    decoder,
                    self.joint_network,
                    args.char_list,
                    args.sym_space,
                    args.sym_blank,
                    args.report_cer,
                    args.report_wer,
                )

            if self.use_aux_task:
                self.auxiliary_task = AuxiliaryTask(
                    decoder,
                    self.joint_network,
                    self.criterion,
                    args.aux_task_type,
                    args.aux_task_weight,
                    encoder_out,
                    args.joint_dim,
                )

            if self.use_aux_ctc:
                self.aux_ctc = ctc_for(
                    Namespace(
                        num_encs=1,
                        eprojs=encoder_out,
                        dropout_rate=args.aux_ctc_dropout_rate,
                        ctc_type="warpctc",
                    ),
                    odim,
                )

            if self.use_aux_cross_entropy:
                self.aux_decoder_output = torch.nn.Linear(decoder_out, odim)

                self.aux_cross_entropy = LabelSmoothingLoss(
                    odim, ignore_id, args.aux_cross_entropy_smoothing)

        self.loss = None
        self.rnnlm = None
# Example no. 3
# 0
    def __init__(
        self,
        idim: int,
        odim: int,
        args: Namespace,
        ignore_id: int = -1,
        blank_id: int = 0,
        training: bool = True,
    ):
        """Construct an E2E object for transducer model.

        Args:
            idim: Dimension of inputs.
            odim: Dimension of outputs.
            args: Namespace containing model options.
            ignore_id: Padding symbol ID.
            blank_id: Blank symbol ID.
            training: Whether the model is built for training. When False, the
                auxiliary encoder outputs and the error calculator are
                disabled.

        Raises:
            ValueError: If a custom encoder/decoder type is requested without
                the corresponding ``--enc-block-arch``/``--dec-block-arch``.

        """
        torch.nn.Module.__init__(self)

        args = fill_missing_args(args, self.add_arguments)

        self.is_transducer = True

        # Auxiliary encoder outputs are only requested when the auxiliary
        # transducer loss is enabled at training time.
        self.use_auxiliary_enc_outputs = bool(
            training and args.use_aux_transducer_loss)

        self.subsample = get_subsample(
            args,
            mode="asr",
            arch="transformer" if args.etype == "custom" else "rnn-t")

        if self.use_auxiliary_enc_outputs:
            # Index of the last encoder layer that may serve as an auxiliary
            # output (the final layer itself is excluded).
            n_layers = (((len(args.enc_block_arch) * args.enc_block_repeat) -
                         1) if args.enc_block_arch is not None else
                        (args.elayers - 1))

            aux_enc_output_layers = valid_aux_encoder_output_layers(
                args.aux_transducer_loss_enc_output_layers,
                n_layers,
                args.use_symm_kl_div_loss,
                self.subsample,
            )
        else:
            aux_enc_output_layers = []

        if args.etype == "custom":
            if args.enc_block_arch is None:
                raise ValueError(
                    "When specifying custom encoder type, --enc-block-arch "
                    "should also be specified in training config. See "
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.encoder = CustomEncoder(
                idim,
                args.enc_block_arch,
                input_layer=args.custom_enc_input_layer,
                repeat_block=args.enc_block_repeat,
                self_attn_type=args.custom_enc_self_attn_type,
                positional_encoding_type=args.
                custom_enc_positional_encoding_type,
                positionwise_activation_type=args.
                custom_enc_pw_activation_type,
                conv_mod_activation_type=args.
                custom_enc_conv_mod_activation_type,
                aux_enc_output_layers=aux_enc_output_layers,
            )
            encoder_out = self.encoder.enc_out
        else:
            self.enc = encoder_for(
                args,
                idim,
                self.subsample,
                aux_enc_output_layers=aux_enc_output_layers,
            )
            encoder_out = args.eprojs

        if args.dtype == "custom":
            if args.dec_block_arch is None:
                raise ValueError(
                    "When specifying custom decoder type, --dec-block-arch "
                    "should also be specified in training config. See "
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.decoder = CustomDecoder(
                odim,
                args.dec_block_arch,
                input_layer=args.custom_dec_input_layer,
                repeat_block=args.dec_block_repeat,
                positionwise_activation_type=args.
                custom_dec_pw_activation_type,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
                blank_id=blank_id,
            )
            decoder_out = self.decoder.dunits
        else:
            self.dec = RNNDecoder(
                odim,
                args.dtype,
                args.dlayers,
                args.dunits,
                args.dec_embed_dim,
                dropout_rate=args.dropout_rate_decoder,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
                blank_id=blank_id,
            )
            decoder_out = args.dunits

        self.transducer_tasks = TransducerTasks(
            encoder_out,
            decoder_out,
            args.joint_dim,
            odim,
            joint_activation_type=args.joint_activation_type,
            transducer_loss_weight=args.transducer_weight,
            ctc_loss=args.use_ctc_loss,
            ctc_loss_weight=args.ctc_loss_weight,
            ctc_loss_dropout_rate=args.ctc_loss_dropout_rate,
            lm_loss=args.use_lm_loss,
            lm_loss_weight=args.lm_loss_weight,
            lm_loss_smoothing_rate=args.lm_loss_smoothing_rate,
            aux_transducer_loss=args.use_aux_transducer_loss,
            aux_transducer_loss_weight=args.aux_transducer_loss_weight,
            aux_transducer_loss_mlp_dim=args.aux_transducer_loss_mlp_dim,
            aux_trans_loss_mlp_dropout_rate=args.
            aux_transducer_loss_mlp_dropout_rate,
            symm_kl_div_loss=args.use_symm_kl_div_loss,
            symm_kl_div_loss_weight=args.symm_kl_div_loss_weight,
            fastemit_lambda=args.fastemit_lambda,
            blank_id=blank_id,
            ignore_id=ignore_id,
            training=training,
        )

        if training and (args.report_cer or args.report_wer):
            self.error_calculator = ErrorCalculator(
                self.decoder if args.dtype == "custom" else self.dec,
                self.transducer_tasks.joint_network,
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None

        self.etype = args.etype
        self.dtype = args.dtype

        self.sos = odim - 1
        self.eos = odim - 1
        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.space = args.sym_space
        self.blank = args.sym_blank

        self.odim = odim

        self.reporter = Reporter()

        self.default_parameters(args)

        self.loss = None
        self.rnnlm = None
# Example no. 4
# 0
    def __init__(self,
                 idim,
                 odim,
                 args,
                 ignore_id=-1,
                 blank_id=0,
                 training=True):
        """Construct an E2E object for transducer model.

        Args:
            idim: Dimension of inputs.
            odim: Dimension of outputs.
            args: Namespace containing model options. With a custom encoder,
                ``args.eprojs`` is overwritten with the encoder output size.
            ignore_id: Padding symbol ID.
            blank_id: Blank symbol ID.
            training: Whether the model is built for training. The transducer
                criterion is only created when True.

        Raises:
            ValueError: If a custom encoder/decoder type is requested without
                the corresponding ``--enc-block-arch``/``--dec-block-arch``.

        """
        torch.nn.Module.__init__(self)

        self.is_rnnt = True

        if "custom" in args.etype:
            if args.enc_block_arch is None:
                raise ValueError(
                    "When specifying custom encoder type, --enc-block-arch "
                    "should also be specified in training config. See "
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.subsample = get_subsample(args,
                                           mode="asr",
                                           arch="transformer")

            self.encoder = CustomEncoder(
                idim,
                args.enc_block_arch,
                input_layer=args.custom_enc_input_layer,
                repeat_block=args.enc_block_repeat,
                self_attn_type=args.custom_enc_self_attn_type,
                positional_encoding_type=args.
                custom_enc_positional_encoding_type,
                positionwise_activation_type=args.
                custom_enc_pw_activation_type,
                conv_mod_activation_type=args.
                custom_enc_conv_mod_activation_type,
            )

            encoder_out = self.encoder.enc_out
            # The RNN decoder below is built from args.eprojs, so it must
            # reflect the custom encoder's output size.
            args.eprojs = self.encoder.enc_out

            self.most_dom_list = args.enc_block_arch[:]
        else:
            self.subsample = get_subsample(args, mode="asr", arch="rnn-t")

            self.enc = encoder_for(args, idim, self.subsample)

            encoder_out = args.eprojs

        use_custom_decoder = "custom" in args.dtype

        if use_custom_decoder:
            if args.dec_block_arch is None:
                raise ValueError(
                    "When specifying custom decoder type, --dec-block-arch "
                    "should also be specified in training config. See "
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.decoder = CustomDecoder(
                odim,
                encoder_out,
                args.joint_dim,
                args.dec_block_arch,
                input_layer=args.custom_dec_input_layer,
                repeat_block=args.dec_block_repeat,
                joint_activation_type=args.joint_activation_type,
                positionwise_activation_type=args.
                custom_dec_pw_activation_type,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
            )

            if "custom" in args.etype:
                self.most_dom_list += args.dec_block_arch[:]
            else:
                self.most_dom_list = args.dec_block_arch[:]
        else:
            self.dec = DecoderRNNT(
                args.eprojs,
                odim,
                args.dtype,
                args.dlayers,
                args.dunits,
                blank_id,
                args.dec_embed_dim,
                args.joint_dim,
                args.joint_activation_type,
                args.dropout_rate_decoder,
                args.dropout_rate_embed_decoder,
            )

        if hasattr(self, "most_dom_list"):
            # Largest "d_hidden" among the custom blocks that define one: the
            # (value, count) pairs are sorted on the value, descending.
            # NOTE(review): raises IndexError if no block defines "d_hidden".
            self.most_dom_dim = sorted(
                Counter(d["d_hidden"] for d in self.most_dom_list
                        if "d_hidden" in d).most_common(),
                key=lambda x: x[0],
                reverse=True,
            )[0][0]

        self.etype = args.etype
        self.dtype = args.dtype

        self.sos = odim - 1
        self.eos = odim - 1
        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.space = args.sym_space
        self.blank = args.sym_blank

        self.odim = odim

        self.reporter = Reporter()

        if training:
            self.criterion = TransLoss(args.trans_type, self.blank_id)

        self.default_parameters(args)

        if args.report_cer or args.report_wer:
            from espnet.nets.e2e_asr_common import ErrorCalculatorTransducer

            # Select the decoder with the same test used to build it above, so
            # the attribute that was actually created is the one picked.
            decoder = self.decoder if use_custom_decoder else self.dec

            self.error_calculator = ErrorCalculatorTransducer(
                decoder,
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None

        self.loss = None
        self.rnnlm = None