Example #1
0
    def __init__(self, config):
        """Build encoder, decoder, joint network and loss from ``config``.

        Args:
            config: configuration object accessed by attribute. Sections used:
                ``config.joint`` (provides ``vocab_size`` and is also unpacked
                as keyword arguments into ``JointNetwork``), ``config.mask``
                (``encoder_left_mask``, ``encoder_right_mask``,
                ``decoder_left_mask``), and ``config.enc`` / ``config.dec``
                (unpacked as keyword arguments into ``TransformerEncoder``).
        """
        super(TransformerTransducer, self).__init__()

        joint_cfg = config.joint
        mask_cfg = config.mask

        # The last vocabulary index serves as both start and end symbol.
        self.vocab_size = joint_cfg.vocab_size
        last_index = self.vocab_size - 1
        self.sos = last_index
        self.eos = last_index
        self.ignore_id = -1

        self.encoder_left_mask = mask_cfg.encoder_left_mask
        self.encoder_right_mask = mask_cfg.encoder_right_mask
        self.decoder_left_mask = mask_cfg.decoder_left_mask

        # Keep the submodule construction order (encoder, decoder, joint,
        # loss): it fixes registration order and, presumably, the RNG
        # sequence consumed by parameter initialisation.
        self.encoder = TransformerEncoder(**config.enc)
        # The decoder reuses the TransformerEncoder class.
        self.decoder = TransformerEncoder(**config.dec)
        self.joint = JointNetwork(**joint_cfg)
        self.loss = TransLoss(trans_type="warp-transducer",
                              blank_id=0)  # todo: check blank id
Example #2
0
    def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0):
        """Construct an E2E object for transducer model.

        The encoder is a transformer ``Encoder`` when ``args.etype`` equals
        'transformer' (otherwise built by ``encoder_for``); the decoder is a
        transformer ``Decoder`` when ``args.dtype`` equals 'transformer'
        (otherwise built by ``decoder_for``, with an attention module when
        ``args.rnnt_mode == 'rnnt-att'``).

        Args:
            idim (int): dimension of inputs
            odim (int): dimension of outputs; ``odim - 1`` is used as both
                sos and eos
            args (Namespace): argument Namespace containing options.
                NOTE: ``args.eprojs`` is overwritten in place with
                ``args.adim`` when a transformer encoder is combined with a
                non-transformer decoder.
            ignore_id (int): label index stored as ``self.ignore_id``
            blank_id (int): index of the blank symbol, passed to the
                transducer loss

        """
        torch.nn.Module.__init__(self)

        if args.etype == 'transformer':
            self.encoder = Encoder(idim=idim,
                                   attention_dim=args.adim,
                                   attention_heads=args.aheads,
                                   linear_units=args.eunits,
                                   num_blocks=args.elayers,
                                   input_layer=args.transformer_input_layer,
                                   dropout_rate=args.dropout_rate,
                                   positional_dropout_rate=args.dropout_rate,
                                   attention_dropout_rate=args.
                                   transformer_attn_dropout_rate_encoder)

            # NOTE(review): subsampling is recorded as [1] for the transformer
            # encoder; confirm this matches its input layer's behaviour.
            self.subsample = [1]
        else:
            self.subsample = get_subsample(args, mode='asr', arch='rnn-t')

            self.encoder = encoder_for(args, idim, self.subsample)

        if args.dtype == 'transformer':
            self.decoder = Decoder(
                odim=odim,
                jdim=args.joint_dim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                input_layer=args.transformer_dec_input_layer,
                dropout_rate=args.dropout_rate_decoder,
                positional_dropout_rate=args.dropout_rate_decoder,
                attention_dropout_rate=args.
                transformer_attn_dropout_rate_decoder)
        else:
            # Expose the transformer encoder's output size through
            # args.eprojs; presumably read by att_for / decoder_for.
            if args.etype == 'transformer':
                args.eprojs = args.adim

            if args.rnnt_mode == 'rnnt-att':
                self.att = att_for(args)
                self.decoder = decoder_for(args, odim, self.att)
            else:
                self.decoder = decoder_for(args, odim)

        self.etype = args.etype
        self.dtype = args.dtype
        self.rnnt_mode = args.rnnt_mode

        # The last output index doubles as start and end symbol.
        self.sos = odim - 1
        self.eos = odim - 1
        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.space = args.sym_space
        self.blank = args.sym_blank

        self.odim = odim
        self.adim = args.adim

        self.reporter = Reporter()

        self.criterion = TransLoss(args.trans_type, self.blank_id)

        # Runs after every submodule above has been created; keep it here.
        self.default_parameters(args)

        if args.report_cer or args.report_wer:
            from espnet.nets.e2e_asr_common import ErrorCalculatorTrans

            self.error_calculator = ErrorCalculatorTrans(self.decoder, args)
        else:
            self.error_calculator = None

        # Large negative constant; presumably a log(0) stand-in — confirm.
        self.logzero = -10000000000.0
        self.loss = None
        self.rnnlm = None
    def __init__(self,
                 idim,
                 odim,
                 args,
                 ignore_id=-1,
                 blank_id=0,
                 training=True):
        """Construct an E2E object for transducer model.

        The encoder is transformer-based when ``args.etype`` contains
        "transformer" (otherwise an RNN encoder from ``encoder_for``); the
        decoder is transformer-based when ``args.dtype`` contains
        "transformer" (otherwise an RNN-T decoder, with attention when
        ``args.rnnt_mode == "rnnt-att"``).

        Args:
            idim (int): dimension of inputs.
            odim (int): dimension of outputs; ``odim - 1`` is used as both
                sos and eos.
            args (Namespace): model options. When a transformer encoder is
                built, ``args.eprojs`` is overwritten in place with its
                output size.
            ignore_id (int): label index stored as ``self.ignore_id``.
            blank_id (int): index of the blank symbol, passed to the RNN-T
                decoders and the transducer loss.
            training (bool): when False, ``self.criterion`` is not built.

        Raises:
            ValueError: if a transformer encoder (resp. decoder) is requested
                but ``args.enc_block_arch`` (resp. ``args.dec_block_arch``)
                is None.

        """
        torch.nn.Module.__init__(self)

        if "transformer" in args.etype:
            if args.enc_block_arch is None:
                # Adjacent literals are joined verbatim, so every piece but
                # the last carries its own trailing space.
                raise ValueError(
                    "Transformer-based blocks in transducer mode should be "
                    "defined individually in the YAML file. "
                    "See egs/vivos/asr1/conf/transducer/* for more info.")

            self.subsample = get_subsample(args,
                                           mode="asr",
                                           arch="transformer")

            self.encoder = Encoder(
                idim,
                args.enc_block_arch,
                input_layer=args.transformer_enc_input_layer,
                repeat_block=args.enc_block_repeat,
                self_attn_type=args.transformer_enc_self_attn_type,
                positional_encoding_type=args.
                transformer_enc_positional_encoding_type,
                positionwise_activation_type=args.
                transformer_enc_pw_activation_type,
                conv_mod_activation_type=args.
                transformer_enc_conv_mod_activation_type,
            )

            encoder_out = self.encoder.enc_out
            # The RNN-T decoders below are sized from args.eprojs.
            args.eprojs = self.encoder.enc_out

            # Copy, so extending with decoder blocks below (``+=``) does not
            # mutate args.enc_block_arch.
            self.most_dom_list = args.enc_block_arch[:]
        else:
            self.subsample = get_subsample(args, mode="asr", arch="rnn-t")

            self.enc = encoder_for(args, idim, self.subsample)

            encoder_out = args.eprojs

        if "transformer" in args.dtype:
            if args.dec_block_arch is None:
                raise ValueError(
                    "Transformer-based blocks in transducer mode should be "
                    "defined individually in the YAML file. "
                    "See egs/vivos/asr1/conf/transducer/* for more info.")

            self.decoder = DecoderTT(
                odim,
                encoder_out,
                args.joint_dim,
                args.dec_block_arch,
                input_layer=args.transformer_dec_input_layer,
                repeat_block=args.dec_block_repeat,
                joint_activation_type=args.joint_activation_type,
                positionwise_activation_type=args.
                transformer_dec_pw_activation_type,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
            )

            if "transformer" in args.etype:
                self.most_dom_list += args.dec_block_arch[:]
            else:
                self.most_dom_list = args.dec_block_arch[:]
        else:
            if args.rnnt_mode == "rnnt-att":
                self.att = att_for(args)

                self.dec = DecoderRNNTAtt(
                    args.eprojs,
                    odim,
                    args.dtype,
                    args.dlayers,
                    args.dunits,
                    blank_id,
                    self.att,
                    args.dec_embed_dim,
                    args.joint_dim,
                    args.joint_activation_type,
                    args.dropout_rate_decoder,
                    args.dropout_rate_embed_decoder,
                )
            else:
                self.dec = DecoderRNNT(
                    args.eprojs,
                    odim,
                    args.dtype,
                    args.dlayers,
                    args.dunits,
                    blank_id,
                    args.dec_embed_dim,
                    args.joint_dim,
                    args.joint_activation_type,
                    args.dropout_rate_decoder,
                    args.dropout_rate_embed_decoder,
                )

        if hasattr(self, "most_dom_list"):
            # Counter keys are distinct d_hidden values; sorting them by value
            # (x[0]) in descending order selects the LARGEST d_hidden, not the
            # most frequent one. Raises IndexError if no block defines
            # "d_hidden".
            self.most_dom_dim = sorted(
                Counter(d["d_hidden"] for d in self.most_dom_list
                        if "d_hidden" in d).most_common(),
                key=lambda x: x[0],
                reverse=True,
            )[0][0]

        self.etype = args.etype
        self.dtype = args.dtype
        self.rnnt_mode = args.rnnt_mode

        # The last output index doubles as start and end symbol.
        self.sos = odim - 1
        self.eos = odim - 1
        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.space = args.sym_space
        self.blank = args.sym_blank

        self.odim = odim

        self.reporter = Reporter()

        if training:
            self.criterion = TransLoss(args.trans_type, self.blank_id)

        # Runs after every submodule above has been created; keep it here.
        self.default_parameters(args)

        if args.report_cer or args.report_wer:
            from espnet.nets.e2e_asr_common import ErrorCalculatorTransducer

            # Use the same substring test as the decoder construction above;
            # an exact ``== "transformer"`` comparison would reach for the
            # never-created ``self.dec`` whenever dtype merely contains
            # "transformer".
            if "transformer" in self.dtype:
                decoder = self.decoder
            else:
                decoder = self.dec

            self.error_calculator = ErrorCalculatorTransducer(
                decoder,
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None

        self.loss = None
        self.rnnlm = None
Example #4
0
    def __init__(self,
                 idim,
                 odim,
                 args,
                 ignore_id=-1,
                 blank_id=0,
                 training=True):
        """Construct an E2E object for transducer model.

        The encoder is a ``CustomEncoder`` when ``args.etype`` contains
        "custom" (otherwise an RNN encoder from ``encoder_for``); the decoder
        is a ``CustomDecoder`` when ``args.dtype`` contains "custom"
        (otherwise ``DecoderRNNT``). A ``JointNetwork`` is built from both
        output sizes.

        Args:
            idim (int): dimension of inputs.
            odim (int): dimension of outputs; ``odim - 1`` is used as both
                sos and eos.
            args (Namespace): model options; missing entries are filled by
                ``fill_missing_args`` using ``self.add_arguments``.
            ignore_id (int): label index stored as ``self.ignore_id`` and
                passed to the auxiliary cross-entropy loss.
            blank_id (int): index of the blank symbol, passed to the RNN-T
                decoder and the transducer loss.
            training (bool): when False, the loss, the error calculator and
                all auxiliary-task modules are skipped.

        Raises:
            ValueError: if a custom encoder (resp. decoder) is requested but
                ``args.enc_block_arch`` (resp. ``args.dec_block_arch``) is
                None.

        """
        torch.nn.Module.__init__(self)

        args = fill_missing_args(args, self.add_arguments)

        self.is_rnnt = True
        self.transducer_weight = args.transducer_weight

        self.use_aux_task = bool(args.aux_task_type is not None and training)

        self.use_aux_ctc = args.aux_ctc and training
        self.aux_ctc_weight = args.aux_ctc_weight

        self.use_aux_cross_entropy = args.aux_cross_entropy and training
        self.aux_cross_entropy_weight = args.aux_cross_entropy_weight

        if self.use_aux_task:
            # Count layers of the encoder that is actually built: the block
            # architecture only applies to a custom encoder; RNN encoders
            # have ``elayers`` layers even if enc_block_arch happens to be set.
            if "custom" in args.etype and args.enc_block_arch is not None:
                n_layers = len(args.enc_block_arch) * args.enc_block_repeat - 1
            else:
                n_layers = args.elayers - 1

            aux_task_layer_list = valid_aux_task_layer_list(
                args.aux_task_layer_list,
                n_layers,
            )
        else:
            aux_task_layer_list = []

        if "custom" in args.etype:
            if args.enc_block_arch is None:
                # Adjacent literals are joined verbatim, so every piece but
                # the last carries its own trailing space.
                raise ValueError(
                    "When specifying custom encoder type, --enc-block-arch "
                    "should also be specified in training config. See "
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.subsample = get_subsample(args,
                                           mode="asr",
                                           arch="transformer")

            self.encoder = CustomEncoder(
                idim,
                args.enc_block_arch,
                input_layer=args.custom_enc_input_layer,
                repeat_block=args.enc_block_repeat,
                self_attn_type=args.custom_enc_self_attn_type,
                positional_encoding_type=args.
                custom_enc_positional_encoding_type,
                positionwise_activation_type=args.
                custom_enc_pw_activation_type,
                conv_mod_activation_type=args.
                custom_enc_conv_mod_activation_type,
                aux_task_layer_list=aux_task_layer_list,
            )
            encoder_out = self.encoder.enc_out

            # Copy, so extending with decoder blocks below (``+=``) does not
            # mutate args.enc_block_arch.
            self.most_dom_list = args.enc_block_arch[:]
        else:
            self.subsample = get_subsample(args, mode="asr", arch="rnn-t")

            self.enc = encoder_for(
                args,
                idim,
                self.subsample,
                aux_task_layer_list=aux_task_layer_list,
            )
            encoder_out = args.eprojs

        if "custom" in args.dtype:
            if args.dec_block_arch is None:
                raise ValueError(
                    "When specifying custom decoder type, --dec-block-arch "
                    "should also be specified in training config. See "
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.decoder = CustomDecoder(
                odim,
                args.dec_block_arch,
                input_layer=args.custom_dec_input_layer,
                repeat_block=args.dec_block_repeat,
                positionwise_activation_type=args.
                custom_dec_pw_activation_type,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
            )
            decoder_out = self.decoder.dunits

            if "custom" in args.etype:
                self.most_dom_list += args.dec_block_arch[:]
            else:
                self.most_dom_list = args.dec_block_arch[:]
        else:
            self.dec = DecoderRNNT(
                odim,
                args.dtype,
                args.dlayers,
                args.dunits,
                blank_id,
                args.dec_embed_dim,
                args.dropout_rate_decoder,
                args.dropout_rate_embed_decoder,
            )
            decoder_out = args.dunits

        self.joint_network = JointNetwork(odim, encoder_out, decoder_out,
                                          args.joint_dim,
                                          args.joint_activation_type)

        if hasattr(self, "most_dom_list"):
            # Counter keys are distinct d_hidden values; sorting them by value
            # (x[0]) in descending order selects the LARGEST d_hidden, not the
            # most frequent one. Raises IndexError if no block defines
            # "d_hidden".
            self.most_dom_dim = sorted(
                Counter(d["d_hidden"] for d in self.most_dom_list
                        if "d_hidden" in d).most_common(),
                key=lambda x: x[0],
                reverse=True,
            )[0][0]

        self.etype = args.etype
        self.dtype = args.dtype

        # The last output index doubles as start and end symbol.
        self.sos = odim - 1
        self.eos = odim - 1
        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.space = args.sym_space
        self.blank = args.sym_blank

        self.odim = odim

        self.reporter = Reporter()
        self.error_calculator = None

        # Runs after the main submodules exist; the training-only modules
        # below are created afterwards.
        self.default_parameters(args)

        if training:
            self.criterion = TransLoss(args.trans_type, self.blank_id)

            # Same substring test as the decoder construction above, so a
            # dtype that merely contains "custom" still finds self.decoder
            # instead of the never-created self.dec.
            decoder = self.decoder if "custom" in self.dtype else self.dec

            if args.report_cer or args.report_wer:
                self.error_calculator = ErrorCalculator(
                    decoder,
                    self.joint_network,
                    args.char_list,
                    args.sym_space,
                    args.sym_blank,
                    args.report_cer,
                    args.report_wer,
                )

            if self.use_aux_task:
                self.auxiliary_task = AuxiliaryTask(
                    decoder,
                    self.joint_network,
                    self.criterion,
                    args.aux_task_type,
                    args.aux_task_weight,
                    encoder_out,
                    args.joint_dim,
                )

            if self.use_aux_ctc:
                self.aux_ctc = ctc_for(
                    Namespace(
                        num_encs=1,
                        eprojs=encoder_out,
                        dropout_rate=args.aux_ctc_dropout_rate,
                        ctc_type="warpctc",
                    ),
                    odim,
                )

            if self.use_aux_cross_entropy:
                self.aux_decoder_output = torch.nn.Linear(decoder_out, odim)

                self.aux_cross_entropy = LabelSmoothingLoss(
                    odim, ignore_id, args.aux_cross_entropy_smoothing)

        self.loss = None
        self.rnnlm = None
Example #5
0
    def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0):
        """Construct an E2E object for transducer model.

        The encoder is transformer-based when ``args.etype`` contains
        "transformer" (otherwise an RNN encoder from ``encoder_for``); the
        decoder is transformer-based when ``args.dtype`` contains
        "transformer" (otherwise an RNN-T decoder, with attention when
        ``args.rnnt_mode == "rnnt-att"``).

        Args:
            idim (int): dimension of inputs.
            odim (int): dimension of outputs; ``odim - 1`` is used as both
                sos and eos.
            args (Namespace): model options. When a transformer encoder is
                built, ``args.eprojs`` is overwritten in place with its
                output size.
            ignore_id (int): label index stored as ``self.ignore_id``.
            blank_id (int): index of the blank symbol, passed to the RNN-T
                decoders and the transducer loss.

        Raises:
            ValueError: if a transformer encoder (resp. decoder) is requested
                but ``args.enc_block_arch`` (resp. ``args.dec_block_arch``)
                is None.

        """
        torch.nn.Module.__init__(self)

        if "transformer" in args.etype:
            if args.enc_block_arch is None:
                # Adjacent literals are joined verbatim, so every piece but
                # the last carries its own trailing space.
                raise ValueError(
                    "Transformer-based blocks in transducer mode should be "
                    "defined individually in the YAML file. "
                    "See egs/vivos/asr1/conf/transducer/* for more info.")

            self.subsample = get_subsample(args,
                                           mode="asr",
                                           arch="transformer")

            # Two extra transformer layers without positional encoding,
            # presumably used to fuse feature maps. Their sizes are
            # hard-coded (dim 16, 4 heads, FF 2048, dropout 0.1) and they are
            # not passed to the Encoder below.
            # NOTE(review): confirm how forward() uses them.
            self.clayers = repeat(
                2,
                lambda lnum: EncoderLayer(
                    16,
                    MultiHeadedAttention(4, 16, 0.1),
                    PositionwiseFeedForward(16, 2048, 0.1),
                    dropout_rate=0.1,
                    normalize_before=True,
                    concat_after=False,
                ),
            )

            # NOTE(review): this conv stack emits 32 channels while the layers
            # above are 16-dimensional; confirm the shapes line up in
            # forward().
            self.conv = torch.nn.Sequential(
                torch.nn.Conv2d(1, 32, kernel_size=(3, 5), stride=(1, 2)),
                torch.nn.ReLU(),
                torch.nn.Conv2d(32, 32, kernel_size=(3, 7), stride=(2, 2)),
                torch.nn.ReLU())

            self.encoder = Encoder(
                idim,
                args.enc_block_arch,
                input_layer=args.transformer_enc_input_layer,
                repeat_block=args.enc_block_repeat,
                self_attn_type=args.transformer_enc_self_attn_type,
                positional_encoding_type=args.
                transformer_enc_positional_encoding_type,
                positionwise_activation_type=args.
                transformer_enc_pw_activation_type,
                conv_mod_activation_type=args.
                transformer_enc_conv_mod_activation_type,
            )
            encoder_out = self.encoder.enc_out
            # The RNN-T decoders below are sized from args.eprojs.
            args.eprojs = self.encoder.enc_out

            # Copy, so extending with decoder blocks below (``+=``) does not
            # mutate args.enc_block_arch.
            self.most_dom_list = args.enc_block_arch[:]
        else:
            self.subsample = get_subsample(args, mode="asr", arch="rnn-t")

            self.enc = encoder_for(args, idim, self.subsample)

            encoder_out = args.eprojs

        if "transformer" in args.dtype:
            if args.dec_block_arch is None:
                raise ValueError(
                    "Transformer-based blocks in transducer mode should be "
                    "defined individually in the YAML file. "
                    "See egs/vivos/asr1/conf/transducer/* for more info.")

            self.decoder = DecoderTT(
                odim,
                encoder_out,
                args.joint_dim,
                args.dec_block_arch,
                input_layer=args.transformer_dec_input_layer,
                repeat_block=args.dec_block_repeat,
                joint_activation_type=args.joint_activation_type,
                positionwise_activation_type=args.
                transformer_dec_pw_activation_type,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
            )

            if "transformer" in args.etype:
                self.most_dom_list += args.dec_block_arch[:]
            else:
                self.most_dom_list = args.dec_block_arch[:]
        else:
            if args.rnnt_mode == "rnnt-att":
                self.att = att_for(args)

                self.dec = DecoderRNNTAtt(
                    args.eprojs,
                    odim,
                    args.dtype,
                    args.dlayers,
                    args.dunits,
                    blank_id,
                    self.att,
                    args.dec_embed_dim,
                    args.joint_dim,
                    args.joint_activation_type,
                    args.dropout_rate_decoder,
                    args.dropout_rate_embed_decoder,
                )
            else:
                self.dec = DecoderRNNT(
                    args.eprojs,
                    odim,
                    args.dtype,
                    args.dlayers,
                    args.dunits,
                    blank_id,
                    args.dec_embed_dim,
                    args.joint_dim,
                    args.joint_activation_type,
                    args.dropout_rate_decoder,
                    args.dropout_rate_embed_decoder,
                )

        if hasattr(self, "most_dom_list"):
            # Counter keys are distinct d_hidden values; sorting them by value
            # (x[0]) in descending order selects the LARGEST d_hidden, not the
            # most frequent one. Raises IndexError if no block defines
            # "d_hidden".
            self.most_dom_dim = sorted(
                Counter(d["d_hidden"] for d in self.most_dom_list
                        if "d_hidden" in d).most_common(),
                key=lambda x: x[0],
                reverse=True,
            )[0][0]

        self.etype = args.etype
        self.dtype = args.dtype
        self.rnnt_mode = args.rnnt_mode

        # The last output index doubles as start and end symbol.
        self.sos = odim - 1
        self.eos = odim - 1
        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.space = args.sym_space
        self.blank = args.sym_blank

        self.odim = odim

        self.reporter = Reporter()

        self.criterion = TransLoss(args.trans_type, self.blank_id)

        # Runs after every submodule above has been created; keep it here.
        self.default_parameters(args)

        if args.report_cer or args.report_wer:
            from espnet.nets.e2e_asr_common import ErrorCalculatorTransducer

            # Use the same substring test as the decoder construction above;
            # an exact ``== "transformer"`` comparison would reach for the
            # never-created ``self.dec`` whenever dtype merely contains
            # "transformer".
            if "transformer" in self.dtype:
                decoder = self.decoder
            else:
                decoder = self.dec

            self.error_calculator = ErrorCalculatorTransducer(
                decoder,
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None

        self.loss = None
        self.rnnlm = None
Example #6
0
    def __init__(self,
                 idim,
                 odim,
                 args,
                 ignore_id=-1,
                 blank_id=0,
                 training=True):
        """Construct an E2E object for transducer model.

        The encoder is a ``CustomEncoder`` when ``args.etype`` contains
        "custom" (otherwise an RNN encoder from ``encoder_for``); the decoder
        is a ``CustomDecoder`` when ``args.dtype`` contains "custom"
        (otherwise ``DecoderRNNT``). A ``JointNetwork`` is built from both
        output sizes.

        Args:
            idim (int): dimension of inputs.
            odim (int): dimension of outputs; ``odim - 1`` is used as both
                sos and eos.
            args (Namespace): model options.
            ignore_id (int): label index stored as ``self.ignore_id``.
            blank_id (int): index of the blank symbol, passed to the RNN-T
                decoder and the transducer loss.
            training (bool): when False, ``self.criterion`` is not built and
                no error calculator is created.

        Raises:
            ValueError: if a custom encoder (resp. decoder) is requested but
                ``args.enc_block_arch`` (resp. ``args.dec_block_arch``) is
                None.

        """
        torch.nn.Module.__init__(self)

        self.is_rnnt = True

        if "custom" in args.etype:
            if args.enc_block_arch is None:
                # Adjacent literals are joined verbatim, so every piece but
                # the last carries its own trailing space.
                raise ValueError(
                    "When specifying custom encoder type, --enc-block-arch "
                    "should also be specified in training config. See "
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.subsample = get_subsample(args,
                                           mode="asr",
                                           arch="transformer")

            self.encoder = CustomEncoder(
                idim,
                args.enc_block_arch,
                input_layer=args.custom_enc_input_layer,
                repeat_block=args.enc_block_repeat,
                self_attn_type=args.custom_enc_self_attn_type,
                positional_encoding_type=args.
                custom_enc_positional_encoding_type,
                positionwise_activation_type=args.
                custom_enc_pw_activation_type,
                conv_mod_activation_type=args.
                custom_enc_conv_mod_activation_type,
            )
            encoder_out = self.encoder.enc_out

            # Copy, so extending with decoder blocks below (``+=``) does not
            # mutate args.enc_block_arch.
            self.most_dom_list = args.enc_block_arch[:]
        else:
            self.subsample = get_subsample(args, mode="asr", arch="rnn-t")

            self.enc = encoder_for(args, idim, self.subsample)
            encoder_out = args.eprojs

        if "custom" in args.dtype:
            if args.dec_block_arch is None:
                raise ValueError(
                    "When specifying custom decoder type, --dec-block-arch "
                    "should also be specified in training config. See "
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.decoder = CustomDecoder(
                odim,
                args.dec_block_arch,
                input_layer=args.custom_dec_input_layer,
                repeat_block=args.dec_block_repeat,
                positionwise_activation_type=args.
                custom_dec_pw_activation_type,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
            )
            decoder_out = self.decoder.dunits

            if "custom" in args.etype:
                self.most_dom_list += args.dec_block_arch[:]
            else:
                self.most_dom_list = args.dec_block_arch[:]
        else:
            self.dec = DecoderRNNT(
                odim,
                args.dtype,
                args.dlayers,
                args.dunits,
                blank_id,
                args.dec_embed_dim,
                args.dropout_rate_decoder,
                args.dropout_rate_embed_decoder,
            )
            decoder_out = args.dunits

        self.joint_network = JointNetwork(odim, encoder_out, decoder_out,
                                          args.joint_dim,
                                          args.joint_activation_type)

        if hasattr(self, "most_dom_list"):
            # Counter keys are distinct d_hidden values; sorting them by value
            # (x[0]) in descending order selects the LARGEST d_hidden, not the
            # most frequent one. Raises IndexError if no block defines
            # "d_hidden".
            self.most_dom_dim = sorted(
                Counter(d["d_hidden"] for d in self.most_dom_list
                        if "d_hidden" in d).most_common(),
                key=lambda x: x[0],
                reverse=True,
            )[0][0]

        self.etype = args.etype
        self.dtype = args.dtype

        # The last output index doubles as start and end symbol.
        self.sos = odim - 1
        self.eos = odim - 1
        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.space = args.sym_space
        self.blank = args.sym_blank

        self.odim = odim

        self.reporter = Reporter()

        if training:
            self.criterion = TransLoss(args.trans_type, self.blank_id)

        # Runs after every submodule above has been created; keep it here.
        self.default_parameters(args)

        if training and (args.report_cer or args.report_wer):
            # Same substring test as the decoder construction above, so a
            # dtype that merely contains "custom" still finds self.decoder
            # instead of the never-created self.dec.
            self.error_calculator = ErrorCalculator(
                self.decoder if "custom" in self.dtype else self.dec,
                self.joint_network,
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None

        self.loss = None
        self.rnnlm = None