Example #1
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate
        )
        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate
        )
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode='asr', arch='transformer')
        self.reporter = Reporter()

        # self.lsm_weight = a
        self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id, args.lsm_weight,
                                            args.transformer_length_normalized_loss)
        # self.verbose = args.verbose
        self.reset_parameters(args)
        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        else:
            self.ctc = None

        if args.report_cer or args.report_wer:
            from espnet.nets.e2e_asr_common import ErrorCalculator
            self.error_calculator = ErrorCalculator(args.char_list,
                                                    args.sym_space, args.sym_blank,
                                                    args.report_cer, args.report_wer)
        else:
            self.error_calculator = None
        self.rnnlm = None
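
A minimal instantiation sketch for this constructor (illustrative only: the option values and dimensions below are assumptions, and the espnet E2E class with its imports must already be available):

    from argparse import Namespace

    # Hypothetical option values; real recipes load these from a training config.
    args = Namespace(
        adim=256, aheads=4, eunits=2048, elayers=12, dunits=2048, dlayers=6,
        transformer_input_layer="conv2d", transformer_init="pytorch",
        dropout_rate=0.1, transformer_attn_dropout_rate=None,
        lsm_weight=0.1, transformer_length_normalized_loss=False,
        mtlalpha=0.3, ctc_type="warpctc",
        report_cer=False, report_wer=False,
        char_list=None, sym_space="<space>", sym_blank="<blank>",
    )
    model = E2E(idim=83, odim=5002, args=args)  # e.g. 83 fbank+pitch dims, 5002 vocab units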
Example #2
 def __init__(self, idim, odim, args, ignore_id=-1, flag_return=True):
     """Initialize the transformer."""
     chainer.Chain.__init__(self)
     self.mtlalpha = args.mtlalpha
     assert 0 <= self.mtlalpha <= 1, "mtlalpha must be in [0, 1]"
     if args.transformer_attn_dropout_rate is None:
         args.transformer_attn_dropout_rate = args.dropout_rate
     self.use_label_smoothing = False
     self.char_list = args.char_list
     self.space = args.sym_space
     self.blank = args.sym_blank
     self.scale_emb = args.adim ** 0.5
     self.sos = odim - 1
     self.eos = odim - 1
     self.subsample = get_subsample(args, mode='asr', arch='transformer')
     self.ignore_id = ignore_id
     self.reset_parameters(args)
     with self.init_scope():
         self.encoder = Encoder(
             idim=idim,
             attention_dim=args.adim,
             attention_heads=args.aheads,
             linear_units=args.eunits,
             input_layer=args.transformer_input_layer,
             dropout_rate=args.dropout_rate,
             positional_dropout_rate=args.dropout_rate,
             attention_dropout_rate=args.transformer_attn_dropout_rate,
             initialW=self.initialW,
             initial_bias=self.initialB)
         self.decoder = Decoder(odim, args, initialW=self.initialW, initial_bias=self.initialB)
         self.criterion = LabelSmoothingLoss(args.lsm_weight, len(args.char_list),
                                             args.transformer_length_normalized_loss)
         if args.mtlalpha > 0.0:
             if args.ctc_type == 'builtin':
                 logging.info("Using chainer CTC implementation")
                 self.ctc = ctc.CTC(odim, args.adim, args.dropout_rate)
             elif args.ctc_type == 'warpctc':
                 logging.info("Using warpctc CTC implementation")
                 self.ctc = ctc.WarpCTC(odim, args.adim, args.dropout_rate)
             else:
                 raise ValueError('ctc_type must be "builtin" or "warpctc": {}'
                                  .format(args.ctc_type))
         else:
             self.ctc = None
     self.dims = args.adim
     self.odim = odim
     self.flag_return = flag_return
     if args.report_cer or args.report_wer:
         from espnet.nets.e2e_asr_common import ErrorCalculator
         self.error_calculator = ErrorCalculator(args.char_list,
                                                 args.sym_space, args.sym_blank,
                                                 args.report_cer, args.report_wer)
     else:
         self.error_calculator = None
     if 'Namespace' in str(type(args)):
         self.verbose = 0 if 'verbose' not in args else args.verbose
     else:
         self.verbose = 0 if args.verbose is None else args.verbose
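
The Namespace/None branching in the last four lines can be collapsed; a sketch of an equivalent form (it additionally maps an explicit None to 0, which appears to be the intent):

    # Default to 0 whenever `verbose` is missing or unset on the options object.
    self.verbose = getattr(args, 'verbose', 0) or 0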
Example #3
    def __init__(self, idim, odim, args, flag_return=True):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        chainer.Chain.__init__(self)
        self.mtlalpha = args.mtlalpha
        assert 0 <= self.mtlalpha <= 1, "mtlalpha must be [0,1]"
        self.etype = args.etype
        self.verbose = args.verbose
        self.char_list = args.char_list
        self.outdir = args.outdir

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        self.subsample = get_subsample(args, mode="asr", arch="rnn")

        # label smoothing info
        if args.lsm_type:
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim,
                                             args.lsm_type,
                                             transcript=args.train_json)
        else:
            labeldist = None

        with self.init_scope():
            # encoder
            self.enc = encoder_for(args, idim, self.subsample)
            # ctc
            self.ctc = ctc_for(args, odim)
            # attention
            self.att = att_for(args)
            # decoder
            self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                                   labeldist)

        self.acc = None
        self.loss = None
        self.flag_return = flag_return
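
Here label_smoothing_dist estimates a label prior from the training transcripts. Schematically, a "unigram" distribution is just normalized token counts; a toy sketch under that assumption (not the espnet implementation):

    import numpy as np

    def unigram_labeldist(odim, transcripts):
        # transcripts: iterable of token-id sequences
        counts = np.zeros(odim)
        for token_ids in transcripts:
            for t in token_ids:
                counts[t] += 1
        return counts / max(counts.sum(), 1)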
Example #4
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)

        # fill missing arguments for compatibility
        args = fill_missing_args(args, self.add_arguments)

        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_encoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        odim = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode="asr", arch="transformer")
        self.reporter = Reporter()

        self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id, args.lsm_weight,
                                            args.transformer_length_normalized_loss)
        self.output = torch.nn.Linear(256, self.odim) # mean + std pooling
        self.att = Attention(256)

        self.reset_parameters(args)
        self.adim = args.adim  # used for CTC (equal to d_model)
Example #5
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate
        )

        # submodule for ASR task
        self.mtlalpha = args.mtlalpha
        self.asr_weight = getattr(args, "asr_weight", 0.0)
        self.do_asr = self.asr_weight > 0 and args.mtlalpha < 1

        # cross-attention parameters
        self.cross_weight = getattr(args, "cross_weight", 0.0)
        self.cross_self = getattr(args, "cross_self", False)
        self.cross_src = getattr(args, "cross_src", False)
        self.cross_operator = getattr(args, "cross_operator", None)
        self.cross_to_asr = getattr(args, "cross_to_asr", False)
        self.cross_to_st = getattr(args, "cross_to_st", False)
        self.num_decoders = getattr(args, "num_decoders", 1)
        self.wait_k_asr = getattr(args, "wait_k_asr", 0)
        self.wait_k_st = getattr(args, "wait_k_st", 0)
        self.cross_src_from = getattr(args, "cross_src_from", "embedding")
        self.cross_self_from = getattr(args, "cross_self_from", "embedding")
        self.cross_weight_learnable = getattr(args, "cross_weight_learnable", False)

        # one-to-many ST experiments
        self.one_to_many = getattr(args, "one_to_many", False)
        self.langs_dict = getattr(args, "langs_dict", None)
        self.lang_tok = getattr(args, "lang_tok", None)

        self.normalize_before = getattr(args, "normalize_before", True)
        logging.info(f'self.normalize_before = {self.normalize_before}')

        # Check parameters
        if self.cross_operator == 'sum' and self.cross_weight <= 0:
            assert (not self.cross_to_asr) and (not self.cross_to_st)
        if self.cross_to_asr or self.cross_to_st:
            assert self.do_asr
            assert self.cross_self or self.cross_src
        assert bool(self.cross_operator) == (self.do_asr and (self.cross_to_asr or self.cross_to_st))
        if self.cross_src_from != "embedding" or self.cross_self_from != "embedding":
            assert self.normalize_before
        if self.wait_k_asr > 0:
            assert self.wait_k_st == 0
        elif self.wait_k_st > 0:
            assert self.wait_k_asr == 0
        else:
            assert self.wait_k_asr == 0
            assert self.wait_k_st == 0

        logging.info("*** Cross attention parameters ***")
        if self.cross_to_asr:
            logging.info("| Cross to ASR")
        if self.cross_to_st:
            logging.info("| Cross to ST")
        if self.cross_self:
            logging.info("| Cross at Self")
        if self.cross_src:
            logging.info("| Cross at Source")
        if self.cross_to_asr or self.cross_to_st:
            logging.info(f'| Cross operator: {self.cross_operator}')
            logging.info(f'| Cross sum weight: {self.cross_weight}')
            if self.cross_src:
                logging.info(f'| Cross source from: {self.cross_src_from}')
            if self.cross_self:
                logging.info(f'| Cross self from: {self.cross_self_from}')
        logging.info(f'| wait_k_asr = {self.wait_k_asr}')
        logging.info(f'| wait_k_st = {self.wait_k_st}')
        
        if (self.cross_src_from != "embedding" and self.cross_src) and (not self.normalize_before):
            logging.warning('Falling back to cross_src_from == "embedding" for cross attention at source.')
        if (self.cross_self_from != "embedding" and self.cross_self) and (not self.normalize_before):
            logging.warning('Falling back to cross_self_from == "embedding" for cross attention at self.')

        self.dual_decoder = DualDecoder(
                odim=odim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                self_attention_dropout_rate=args.transformer_attn_dropout_rate,
                src_attention_dropout_rate=args.transformer_attn_dropout_rate,
                normalize_before=self.normalize_before,
                cross_operator=self.cross_operator,
                cross_weight_learnable=self.cross_weight_learnable,
                cross_weight=self.cross_weight,
                cross_self=self.cross_self,
                cross_src=self.cross_src,
                cross_to_asr=self.cross_to_asr,
                cross_to_st=self.cross_to_st
        )

        self.pad = 0
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.idim = idim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode='st', arch='transformer')
        self.reporter = Reporter()

        # self.lsm_weight = a
        self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id, args.lsm_weight,
                                            args.transformer_length_normalized_loss)
        # self.verbose = args.verbose
        self.adim = args.adim
        
        # submodule for MT task
        self.mt_weight = getattr(args, "mt_weight", 0.0)
        if self.mt_weight > 0:
            self.encoder_mt = Encoder(
                idim=odim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                input_layer='embed',
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                attention_dropout_rate=args.transformer_attn_dropout_rate,
                padding_idx=0
            )
        self.reset_parameters(args)  # place after the submodule initialization
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        else:
            self.ctc = None

        if self.asr_weight > 0 and (args.report_cer or args.report_wer):
            from espnet.nets.e2e_asr_common import ErrorCalculator
            self.error_calculator = ErrorCalculator(args.char_list,
                                                    args.sym_space, args.sym_blank,
                                                    args.report_cer, args.report_wer)
        else:
            self.error_calculator = None
        self.rnnlm = None

        # multilingual E2E-ST related
        self.multilingual = getattr(args, "multilingual", False)
        self.replace_sos = getattr(args, "replace_sos", False)
        if self.multilingual:
            assert self.replace_sos

        if self.lang_tok == "encoder-pre-sum":
            self.language_embeddings = build_embedding(self.langs_dict, self.idim, padding_idx=self.pad)
            logging.info(f'language_embeddings: {self.language_embeddings}')
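
Most of the optional flags above are read with getattr(args, name, default) so that configurations written before a flag existed still load; a tiny sketch of the pattern:

    from argparse import Namespace

    old_args = Namespace(adim=256)  # config predating the cross-attention flags
    cross_weight = getattr(old_args, "cross_weight", 0.0)  # 0.0, no AttributeError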
Example #6
    def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0):
        """Construct an E2E object for transducer model."""
        torch.nn.Module.__init__(self)

        if "transformer" in args.etype:
            if args.enc_block_arch is None:
                raise ValueError(
                    "Transformer-based blocks in transducer mode should be "
                    "defined individually in the YAML file. "
                    "See egs/vivos/asr1/conf/transducer/* for more info.")

            self.subsample = get_subsample(args,
                                           mode="asr",
                                           arch="transformer")
            # 2. use transformer to joint feature maps
            # transformer without positional encoding

            self.clayers = repeat(
                2,
                lambda lnum: EncoderLayer(
                    16,
                    MultiHeadedAttention(4, 16, 0.1),
                    PositionwiseFeedForward(16, 2048, 0.1),
                    dropout_rate=0.1,
                    normalize_before=True,
                    concat_after=False,
                ),
            )

            self.conv = torch.nn.Sequential(
                torch.nn.Conv2d(1, 32, kernel_size=(3, 5), stride=(1, 2)),
                torch.nn.ReLU(),
                torch.nn.Conv2d(32, 32, kernel_size=(3, 7), stride=(2, 2)),
                torch.nn.ReLU())

            self.encoder = Encoder(
                idim,
                args.enc_block_arch,
                input_layer=args.transformer_enc_input_layer,
                repeat_block=args.enc_block_repeat,
                self_attn_type=args.transformer_enc_self_attn_type,
                positional_encoding_type=args.transformer_enc_positional_encoding_type,
                positionwise_activation_type=args.transformer_enc_pw_activation_type,
                conv_mod_activation_type=args.transformer_enc_conv_mod_activation_type,
            )
            encoder_out = self.encoder.enc_out
            args.eprojs = self.encoder.enc_out

            self.most_dom_list = args.enc_block_arch[:]
        else:
            self.subsample = get_subsample(args, mode="asr", arch="rnn-t")

            self.enc = encoder_for(args, idim, self.subsample)

            encoder_out = args.eprojs

        if "transformer" in args.dtype:
            if args.dec_block_arch is None:
                raise ValueError(
                    "Transformer-based blocks in transducer mode should be "
                    "defined individually in the YAML file. "
                    "See egs/vivos/asr1/conf/transducer/* for more info.")

            self.decoder = DecoderTT(
                odim,
                encoder_out,
                args.joint_dim,
                args.dec_block_arch,
                input_layer=args.transformer_dec_input_layer,
                repeat_block=args.dec_block_repeat,
                joint_activation_type=args.joint_activation_type,
                positionwise_activation_type=args.transformer_dec_pw_activation_type,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
            )

            if "transformer" in args.etype:
                self.most_dom_list += args.dec_block_arch[:]
            else:
                self.most_dom_list = args.dec_block_arch[:]
        else:
            if args.rnnt_mode == "rnnt-att":
                self.att = att_for(args)

                self.dec = DecoderRNNTAtt(
                    args.eprojs,
                    odim,
                    args.dtype,
                    args.dlayers,
                    args.dunits,
                    blank_id,
                    self.att,
                    args.dec_embed_dim,
                    args.joint_dim,
                    args.joint_activation_type,
                    args.dropout_rate_decoder,
                    args.dropout_rate_embed_decoder,
                )
            else:
                self.dec = DecoderRNNT(
                    args.eprojs,
                    odim,
                    args.dtype,
                    args.dlayers,
                    args.dunits,
                    blank_id,
                    args.dec_embed_dim,
                    args.joint_dim,
                    args.joint_activation_type,
                    args.dropout_rate_decoder,
                    args.dropout_rate_embed_decoder,
                )

        if hasattr(self, "most_dom_list"):
            self.most_dom_dim = sorted(
                Counter(d["d_hidden"] for d in self.most_dom_list
                        if "d_hidden" in d).most_common(),
                key=lambda x: x[0],
                reverse=True,
            )[0][0]

        self.etype = args.etype
        self.dtype = args.dtype
        self.rnnt_mode = args.rnnt_mode

        self.sos = odim - 1
        self.eos = odim - 1
        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.space = args.sym_space
        self.blank = args.sym_blank

        self.odim = odim

        self.reporter = Reporter()

        self.criterion = TransLoss(args.trans_type, self.blank_id)

        self.default_parameters(args)

        if args.report_cer or args.report_wer:
            from espnet.nets.e2e_asr_common import ErrorCalculatorTransducer

            if self.dtype == "transformer":
                decoder = self.decoder
            else:
                decoder = self.dec

            self.error_calculator = ErrorCalculatorTransducer(
                decoder,
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None

        self.loss = None
        self.rnnlm = None
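
The most_dom_dim expression is denser than it looks: Counter keys are unique, so re-sorting most_common() by value descending and taking the first entry yields the largest d_hidden found in any block, independent of frequency. A toy check:

    from collections import Counter

    blocks = [{"d_hidden": 256}, {"d_hidden": 256}, {"d_hidden": 512}]
    dim = sorted(
        Counter(d["d_hidden"] for d in blocks if "d_hidden" in d).most_common(),
        key=lambda x: x[0],
        reverse=True,
    )[0][0]
    print(dim)  # 512, regardless of how often each dimension occurs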
Example #7
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        # target matching system organization
        self.oversampling = args.oversampling
        self.residual = args.residual
        self.outer = args.outer
        self.poster = torch.nn.Linear(args.adim, odim * self.oversampling)
        if self.outer:
            if self.residual:
                self.matcher_res = torch.nn.Linear(idim, odim)
                self.matcher = torch.nn.Linear(odim, odim)
            else:
                self.matcher = torch.nn.Linear(odim + idim, odim)

        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode="asr", arch="transformer")
        self.reporter = Reporter()

        # self.lsm_weight = a
        self.criterion = LabelSmoothingLoss(
            self.odim,
            self.ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )
        # self.verbose = args.verbose
        self.reset_parameters(args)
        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(
                odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
            )
        else:
            self.ctc = None

        if args.report_cer or args.report_wer:
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None
        self.rnnlm = None
Example #8
    def __init__(self,
                 idim,
                 odim,
                 args,
                 ignore_id=-1,
                 blank_id=0,
                 training=True):
        """Construct an E2E object for transducer model."""
        torch.nn.Module.__init__(self)

        if "transformer" in args.etype:
            if args.enc_block_arch is None:
                raise ValueError(
                    "Transformer-based blocks in transducer mode should be "
                    "defined individually in the YAML file. "
                    "See egs/vivos/asr1/conf/transducer/* for more info.")

            self.subsample = get_subsample(args,
                                           mode="asr",
                                           arch="transformer")

            self.encoder = Encoder(
                idim,
                args.enc_block_arch,
                input_layer=args.transformer_enc_input_layer,
                repeat_block=args.enc_block_repeat,
                self_attn_type=args.transformer_enc_self_attn_type,
                positional_encoding_type=args.transformer_enc_positional_encoding_type,
                positionwise_activation_type=args.transformer_enc_pw_activation_type,
                conv_mod_activation_type=args.transformer_enc_conv_mod_activation_type,
            )

            encoder_out = self.encoder.enc_out
            args.eprojs = self.encoder.enc_out

            self.most_dom_list = args.enc_block_arch[:]
        else:
            self.subsample = get_subsample(args, mode="asr", arch="rnn-t")

            self.enc = encoder_for(args, idim, self.subsample)

            encoder_out = args.eprojs

        if "transformer" in args.dtype:
            if args.dec_block_arch is None:
                raise ValueError(
                    "Transformer-based blocks in transducer mode should be "
                    "defined individually in the YAML file. "
                    "See egs/vivos/asr1/conf/transducer/* for more info.")

            self.decoder = DecoderTT(
                odim,
                encoder_out,
                args.joint_dim,
                args.dec_block_arch,
                input_layer=args.transformer_dec_input_layer,
                repeat_block=args.dec_block_repeat,
                joint_activation_type=args.joint_activation_type,
                positionwise_activation_type=args.transformer_dec_pw_activation_type,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
            )

            if "transformer" in args.etype:
                self.most_dom_list += args.dec_block_arch[:]
            else:
                self.most_dom_list = args.dec_block_arch[:]
        else:
            if args.rnnt_mode == "rnnt-att":
                self.att = att_for(args)

                self.dec = DecoderRNNTAtt(
                    args.eprojs,
                    odim,
                    args.dtype,
                    args.dlayers,
                    args.dunits,
                    blank_id,
                    self.att,
                    args.dec_embed_dim,
                    args.joint_dim,
                    args.joint_activation_type,
                    args.dropout_rate_decoder,
                    args.dropout_rate_embed_decoder,
                )
            else:
                self.dec = DecoderRNNT(
                    args.eprojs,
                    odim,
                    args.dtype,
                    args.dlayers,
                    args.dunits,
                    blank_id,
                    args.dec_embed_dim,
                    args.joint_dim,
                    args.joint_activation_type,
                    args.dropout_rate_decoder,
                    args.dropout_rate_embed_decoder,
                )

        if hasattr(self, "most_dom_list"):
            self.most_dom_dim = sorted(
                Counter(d["d_hidden"] for d in self.most_dom_list
                        if "d_hidden" in d).most_common(),
                key=lambda x: x[0],
                reverse=True,
            )[0][0]

        self.etype = args.etype
        self.dtype = args.dtype
        self.rnnt_mode = args.rnnt_mode

        self.sos = odim - 1
        self.eos = odim - 1
        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.space = args.sym_space
        self.blank = args.sym_blank

        self.odim = odim

        self.reporter = Reporter()

        if training:
            self.criterion = TransLoss(args.trans_type, self.blank_id)

        self.default_parameters(args)

        if args.report_cer or args.report_wer:
            from espnet.nets.e2e_asr_common import ErrorCalculatorTransducer

            if self.dtype == "transformer":
                decoder = self.decoder
            else:
                decoder = self.dec

            self.error_calculator = ErrorCalculatorTransducer(
                decoder,
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None

        self.loss = None
        self.rnnlm = None
Example #9
    def __init__(self,
                 idim,
                 odim,
                 args,
                 ignore_id=-1,
                 blank_id=0,
                 training=True):
        """Construct an E2E object for transducer model."""
        torch.nn.Module.__init__(self)

        args = fill_missing_args(args, self.add_arguments)

        self.is_rnnt = True
        self.transducer_weight = args.transducer_weight

        self.use_aux_task = args.aux_task_type is not None and training

        self.use_aux_ctc = args.aux_ctc and training
        self.aux_ctc_weight = args.aux_ctc_weight

        self.use_aux_cross_entropy = args.aux_cross_entropy and training
        self.aux_cross_entropy_weight = args.aux_cross_entropy_weight

        if self.use_aux_task:
            if args.enc_block_arch is not None:
                n_layers = len(args.enc_block_arch) * args.enc_block_repeat - 1
            else:
                n_layers = args.elayers - 1

            aux_task_layer_list = valid_aux_task_layer_list(
                args.aux_task_layer_list,
                n_layers,
            )
        else:
            aux_task_layer_list = []

        if "custom" in args.etype:
            if args.enc_block_arch is None:
                raise ValueError(
                    "When specifying custom encoder type, --enc-block-arch "
                    "should also be specified in training config. See "
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.subsample = get_subsample(args,
                                           mode="asr",
                                           arch="transformer")

            self.encoder = CustomEncoder(
                idim,
                args.enc_block_arch,
                input_layer=args.custom_enc_input_layer,
                repeat_block=args.enc_block_repeat,
                self_attn_type=args.custom_enc_self_attn_type,
                positional_encoding_type=args.custom_enc_positional_encoding_type,
                positionwise_activation_type=args.custom_enc_pw_activation_type,
                conv_mod_activation_type=args.custom_enc_conv_mod_activation_type,
                aux_task_layer_list=aux_task_layer_list,
            )
            encoder_out = self.encoder.enc_out

            self.most_dom_list = args.enc_block_arch[:]
        else:
            self.subsample = get_subsample(args, mode="asr", arch="rnn-t")

            self.enc = encoder_for(
                args,
                idim,
                self.subsample,
                aux_task_layer_list=aux_task_layer_list,
            )
            encoder_out = args.eprojs

        if "custom" in args.dtype:
            if args.dec_block_arch is None:
                raise ValueError(
                    "When specifying custom decoder type, --dec-block-arch "
                    "should also be specified in training config. See "
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.decoder = CustomDecoder(
                odim,
                args.dec_block_arch,
                input_layer=args.custom_dec_input_layer,
                repeat_block=args.dec_block_repeat,
                positionwise_activation_type=args.custom_dec_pw_activation_type,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
            )
            decoder_out = self.decoder.dunits

            if "custom" in args.etype:
                self.most_dom_list += args.dec_block_arch[:]
            else:
                self.most_dom_list = args.dec_block_arch[:]
        else:
            self.dec = DecoderRNNT(
                odim,
                args.dtype,
                args.dlayers,
                args.dunits,
                blank_id,
                args.dec_embed_dim,
                args.dropout_rate_decoder,
                args.dropout_rate_embed_decoder,
            )
            decoder_out = args.dunits

        self.joint_network = JointNetwork(odim, encoder_out, decoder_out,
                                          args.joint_dim,
                                          args.joint_activation_type)

        if hasattr(self, "most_dom_list"):
            self.most_dom_dim = sorted(
                Counter(d["d_hidden"] for d in self.most_dom_list
                        if "d_hidden" in d).most_common(),
                key=lambda x: x[0],
                reverse=True,
            )[0][0]

        self.etype = args.etype
        self.dtype = args.dtype

        self.sos = odim - 1
        self.eos = odim - 1
        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.space = args.sym_space
        self.blank = args.sym_blank

        self.odim = odim

        self.reporter = Reporter()
        self.error_calculator = None

        self.default_parameters(args)

        if training:
            self.criterion = TransLoss(args.trans_type, self.blank_id)

            decoder = self.decoder if self.dtype == "custom" else self.dec

            if args.report_cer or args.report_wer:
                self.error_calculator = ErrorCalculator(
                    decoder,
                    self.joint_network,
                    args.char_list,
                    args.sym_space,
                    args.sym_blank,
                    args.report_cer,
                    args.report_wer,
                )

            if self.use_aux_task:
                self.auxiliary_task = AuxiliaryTask(
                    decoder,
                    self.joint_network,
                    self.criterion,
                    args.aux_task_type,
                    args.aux_task_weight,
                    encoder_out,
                    args.joint_dim,
                )

            if self.use_aux_ctc:
                self.aux_ctc = ctc_for(
                    Namespace(
                        num_encs=1,
                        eprojs=encoder_out,
                        dropout_rate=args.aux_ctc_dropout_rate,
                        ctc_type="warpctc",
                    ),
                    odim,
                )

            if self.use_aux_cross_entropy:
                self.aux_decoder_output = torch.nn.Linear(decoder_out, odim)

                self.aux_cross_entropy = LabelSmoothingLoss(
                    odim, ignore_id, args.aux_cross_entropy_smoothing)

        self.loss = None
        self.rnnlm = None
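
The three weights stored above suggest how the losses combine during training; schematically (the actual espnet forward pass has its own bookkeeping, so treat this purely as a sketch):

    # Hypothetical combination of the transducer loss with the auxiliary losses.
    loss = transducer_weight * loss_trans
    if use_aux_ctc:
        loss = loss + aux_ctc_weight * loss_ctc
    if use_aux_cross_entropy:
        loss = loss + aux_cross_entropy_weight * loss_ce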
Example #10
    def __init__(self, idim, odim, args):
        """Initialize multi-speaker E2E module."""
        torch.nn.Module.__init__(self)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be in [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.reporter = Reporter()
        self.num_spkrs = args.num_spkrs
        self.spa = args.spa
        self.pit = PIT(self.num_spkrs)

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        self.subsample = get_subsample(args, mode='asr', arch='rnn_mix')

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim,
                                             args.lsm_type,
                                             transcript=args.train_json)
        else:
            labeldist = None

        if getattr(args, "use_frontend",
                   False):  # use getattr to keep compatibility
            # Relative importing because of using python3 syntax
            from espnet.nets.pytorch_backend.frontends.feature_transform \
                import feature_transform_for
            from espnet.nets.pytorch_backend.frontends.frontend \
                import frontend_for

            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(
                args, (idim - 1) * 2)
            idim = args.n_mels
        else:
            self.frontend = None

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # ctc
        self.ctc = ctc_for(args, odim, reduce=False)
        # attention
        num_att = self.num_spkrs if args.spa else 1
        self.att = att_for(args, num_att)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        if 'report_cer' in vars(args) and (args.report_cer or args.report_wer):
            recog_args = {
                'beam_size': args.beam_size,
                'penalty': args.penalty,
                'ctc_weight': args.ctc_weight,
                'maxlenratio': args.maxlenratio,
                'minlenratio': args.minlenratio,
                'lm_weight': args.lm_weight,
                'rnnlm': args.rnnlm,
                'nbest': args.nbest,
                'space': args.sym_space,
                'blank': args.sym_blank
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
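
PIT (permutation-invariant training) scores every hypothesis-to-reference assignment and keeps the cheapest one, which is what makes the loss order-independent across speakers; a minimal schematic of the idea (not the espnet PIT class):

    import itertools
    import torch

    def pit_min_loss(pair_loss):
        # pair_loss[i, j]: loss of hypothesis i scored against reference j (n x n tensor)
        n = pair_loss.size(0)
        return min(
            sum(pair_loss[i, p] for i, p in enumerate(perm))
            for perm in itertools.permutations(range(n))
        )

    print(pit_min_loss(torch.tensor([[0.1, 0.9], [0.8, 0.2]])))  # tensor(0.3000): identity assignment wins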
Example #11
    def __init__(self, idim, odim, mono_odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be in [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method
        self.outdir = args.outdir

        # target matching system organization
        self.oversampling = args.oversampling
        self.residual = args.residual
        self.outer = args.outer
        self.poster = torch.nn.Linear(args.eprojs, odim * self.oversampling)
        self.poster_mono = torch.nn.Linear(args.eprojs,
                                           mono_odim * self.oversampling)

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1
        self.sos_mono = mono_odim - 1
        self.eos_mono = mono_odim - 1
        self.odim = odim
        self.mono_odim = mono_odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode="asr", arch="rnn")
        self.reporter = Reporter()

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim,
                                             args.lsm_type,
                                             transcript=args.train_json)
        else:
            labeldist = None

        if getattr(args, "use_frontend",
                   False):  # use getattr to keep compatibility
            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(
                args, (idim - 1) * 2)
            idim = args.n_mels
        else:
            self.frontend = None

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # ctc
        self.ctc = ctc_for(args, odim)

        # weight initialization
        if args.initializer == "lecun":
            self.init_like_chainer()
        elif args.initializer == "orthogonal":
            self.init_orthogonal()
        elif args.initializer == "xavier":
            self.init_xavier()
        else:
            raise NotImplementedError("unknown initializer: " +
                                      args.initializer)

        if args.report_cer or args.report_wer:
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
Example #12
        arch_subsample = "transformer"
    else:
        raise ValueError("Unsupported model module: %s" % model_module)

    model_class = dynamic_import(model_module)
    model_class.add_arguments(parser)
    args = parser.parse_args(cmd_args)

    # subsampling info
    if hasattr(args, "etype") and args.etype.startswith("vgg"):
        # Subsampling is not performed for vgg*.
        # It is performed in max pooling layers at CNN.
        min_io_ratio = 4
    else:
        subsample = get_subsample(args,
                                  mode=args.mode_subsample,
                                  arch=arch_subsample)
        # the minimum input-output length ratio for all samples
        min_io_ratio = reduce(mul, subsample)

    # load dictionary
    with open(args.data_json, "rb") as f:
        j = json.load(f)["utts"]

    # remove samples with IO ratio smaller than `min_io_ratio`
    for key in list(j.keys()):
        ilen = j[key]["input"][0]["shape"][0]
        olen = min(x["shape"][0] for x in j[key]["output"])
        if float(ilen) - float(olen) * min_io_ratio < args.min_io_delta:
            j.pop(key)
            print("'{}' removed".format(key))
Example #13
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer='embed',
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        self.pad = 0
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode='mt', arch='transformer')
        self.reporter = Reporter()

        # tie source and target embeddings
        if args.tie_src_tgt_embedding:
            if idim != odim:
                raise ValueError(
                    'When using tie_src_tgt_embedding, idim and odim must be equal.'
                )
            self.encoder.embed[0].weight = self.decoder.embed[0].weight

        # tie embeddings and the classifier
        if args.tie_classifier:
            self.decoder.output_layer.weight = self.decoder.embed[0].weight

        # self.lsm_weight = a
        self.criterion = LabelSmoothingLoss(
            self.odim, self.ignore_id, args.lsm_weight,
            args.transformer_length_normalized_loss)
        self.normalize_length = args.transformer_length_normalized_loss  # for PPL
        # self.verbose = args.verbose
        self.reset_parameters(args)
        self.adim = args.adim
        if args.report_bleu:
            from espnet.nets.e2e_mt_common import ErrorCalculator
            self.error_calculator = ErrorCalculator(args.char_list,
                                                    args.sym_space,
                                                    args.report_bleu)
        else:
            self.error_calculator = None
        self.rnnlm = None

        # multilingual NMT related
        self.multilingual = args.multilingual
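
Assigning one module's .weight to another, as tie_src_tgt_embedding and tie_classifier do above, makes both modules share a single Parameter (and its gradients); a quick demonstration:

    import torch

    emb = torch.nn.Embedding(10, 4)
    out = torch.nn.Linear(4, 10, bias=False)  # weight shape (10, 4) matches the embedding
    out.weight = emb.weight                   # both now point at one Parameter
    print(out.weight.data_ptr() == emb.weight.data_ptr())  # True: same storage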
Example #14
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate

        self.encoder_type = getattr(args, 'encoder_type', 'all_add')
        self.vbs = getattr(args, 'vbs', False)
        self.noise = getattr(args, 'noise_type', 'none')

        if self.encoder_type == 'all_add':
            from espnet.nets.pytorch_backend.transformer.multimodal_encoder_all_add import MultimodalEncoder
        elif self.encoder_type == 'proportion_add':
            from espnet.nets.pytorch_backend.transformer.multimodal_encoder_proportion_add import MultimodalEncoder
        elif self.encoder_type == 'vat':
            from espnet.nets.pytorch_backend.transformer.multimodal_encoder_vat import MultimodalEncoder

        self.encoder = MultimodalEncoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            visual_dim=args.visual_dim,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            vbs=self.vbs)
        self.decoder = MultimodalDecoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate)
        self.pad = 0
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode='st', arch='transformer')
        self.reporter = Reporter()

        # self.lsm_weight = a
        self.criterion = LabelSmoothingLoss(
            self.odim, self.ignore_id, args.lsm_weight,
            args.transformer_length_normalized_loss)
        # self.verbose = args.verbose
        self.adim = args.adim
        # submodule for ASR task
        self.mtlalpha = args.mtlalpha
        self.asr_weight = getattr(args, "asr_weight", 0.0)
        if self.asr_weight > 0 and args.mtlalpha < 1:
            self.decoder_asr = Decoder(
                odim=odim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                self_attention_dropout_rate=args.transformer_attn_dropout_rate,
                src_attention_dropout_rate=args.transformer_attn_dropout_rate,
            )
        # submodule for MT task
        self.mt_weight = getattr(args, "mt_weight", 0.0)
        if self.mt_weight > 0:
            self.encoder_mt = Encoder(
                idim=odim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                input_layer='embed',
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                attention_dropout_rate=args.transformer_attn_dropout_rate,
                padding_idx=0)
        self.reset_parameters(args)  # place after the submodule initialization
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim,
                           args.adim,
                           args.dropout_rate,
                           ctc_type=args.ctc_type,
                           reduce=True)
        else:
            self.ctc = None

        if self.asr_weight > 0 and (args.report_cer or args.report_wer):
            from espnet.nets.e2e_asr_common import ErrorCalculator
            self.error_calculator = ErrorCalculator(args.char_list,
                                                    args.sym_space,
                                                    args.sym_blank,
                                                    args.report_cer,
                                                    args.report_wer)
        else:
            self.error_calculator = None
        self.rnnlm = None

        # multilingual E2E-ST related
        self.multilingual = getattr(args, "multilingual", False)
        self.replace_sos = getattr(args, "replace_sos", False)
        if self.multilingual:
            assert self.replace_sos
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
Example #15
    def __init__(self, idim, odim, args):
        """Initialize transducer modules.

        Args:
            idim (int): dimension of inputs
            odim (int): dimension of outputs
            args (Namespace): argument Namespace containing options

        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        self.rnnt_mode = args.rnnt_mode
        self.etype = args.etype
        self.verbose = args.verbose
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()
        self.beam_size = args.beam_size

        # note that eos is the same as sos (equivalent ID)
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        self.subsample = get_subsample(args, mode='asr', arch='rnn-t')

        if args.use_frontend:
            # Relative importing because of using python3 syntax
            from espnet.nets.pytorch_backend.frontends.feature_transform \
                import feature_transform_for
            from espnet.nets.pytorch_backend.frontends.frontend \
                import frontend_for

            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(
                args, (idim - 1) * 2)
            idim = args.n_mels
        else:
            self.frontend = None

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)

        if args.rnnt_mode == 'rnnt-att':
            # attention
            self.att = att_for(args)
            # decoder
            self.dec = decoder_for(args, odim, self.att)
        else:
            # prediction
            self.dec = decoder_for(args, odim)
        # weight initialization
        self.init_like_chainer()

        # options for beam search
        if 'report_cer' in vars(args) and (args.report_cer or args.report_wer):
            recog_args = {
                'beam_size': args.beam_size,
                'nbest': args.nbest,
                'space': args.sym_space,
                'score_norm_transducer': args.score_norm_transducer
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False

        self.logzero = -10000000000.0
        self.rnnlm = None
        self.loss = None
Example #16
    def __init__(self, idim, odim, args):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)

        # fill missing arguments for compatibility
        args = fill_missing_args(args, self.add_arguments)

        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be in [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # gs534 - word vocab
        bpe = len(self.char_list) > 100 # hack here for bpe flag
        self.vocabulary = Vocabulary(args.dictfile, bpe) if args.dictfile != '' else None

        # gs534 - create lexicon tree
        lextree = None
        self.meeting_KB = None
        self.n_KBs = getattr(args, 'dynamicKBs', 0)
        pretrain_emb = []
        if args.meetingKB and args.meetingpath != '':
            if self.n_KBs == 0 or not os.path.isdir(os.path.join(args.meetingpath, 'split_0')):
                self.meeting_KB = KBmeeting(self.vocabulary, args.meetingpath, args.char_list, bpe)
            else:
                # arrange multiple KBs
                self.meeting_KB = []
                for i in range(self.n_KBs):
                    self.meeting_KB.append(KBmeeting(self.vocabulary,
                        os.path.join(args.meetingpath, 'split_{}'.format(i)), args.char_list, bpe))

        # subsample info
        self.subsample = get_subsample(args, mode="asr", arch="rnn")

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(
                odim, args.lsm_type, transcript=args.train_json
            )
        else:
            labeldist = None

        if getattr(args, "use_frontend", False):  # use getattr to keep compatibility
            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
            idim = args.n_mels
        else:
            self.frontend = None

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # ctc
        self.ctc = ctc_for(args, odim)
        # attention
        self.att = att_for(args)
        # decoder
        self.dec = decoder_for(
            args, odim, self.sos, self.eos, self.att, labeldist,
            meetingKB=self.meeting_KB[0] if isinstance(self.meeting_KB, list) else self.meeting_KB)

        # weight initialization
        self.init_from = getattr(args, 'init_full_model', None)
        self.init_like_chainer()

        # options for beam search
        if args.report_cer or args.report_wer:
            recog_args = {
                "beam_size": args.beam_size,
                "penalty": args.penalty,
                "ctc_weight": args.ctc_weight,
                "maxlenratio": args.maxlenratio,
                "minlenratio": args.minlenratio,
                "lm_weight": args.lm_weight,
                "rnnlm": args.rnnlm,
                "nbest": args.nbest,
                "space": args.sym_space,
                "blank": args.sym_blank,
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
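
The recurring self.sos = self.eos = odim - 1 assignments encode the convention, stated in the comments above, that the last vocabulary index serves as both the start- and end-of-sequence symbol. A toy illustration of how targets would be framed under that convention (the token ids are invented):

odim = 7                          # output vocabulary size; last id is reserved
sos = eos = odim - 1              # sos and eos share one id, as in the models above

target = [3, 1, 4, 1, 5]          # hypothetical label sequence
decoder_input = [sos] + target    # tokens fed to the decoder (teacher forcing)
decoder_output = target + [eos]   # tokens the decoder is trained to predict

assert decoder_input[0] == decoder_output[-1] == 6
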
Example #17
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)

        # fill missing arguments for compatibility
        args = fill_missing_args(args, self.add_arguments)

        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_encoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer="embed",
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        self.decoder = Decoder(
            odim=odim,
            selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_decoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        self.pad = 0  # use <blank> for padding
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode="mt", arch="transformer")
        self.reporter = Reporter()

        # tie source and target embeddings
        if args.tie_src_tgt_embedding:
            if idim != odim:
                raise ValueError(
                    "When using tie_src_tgt_embedding, idim and odim must be equal."
                )
            self.encoder.embed[0].weight = self.decoder.embed[0].weight

        # tie embeddings and the classifier
        if args.tie_classifier:
            self.decoder.output_layer.weight = self.decoder.embed[0].weight

        self.criterion = LabelSmoothingLoss(
            self.odim,
            self.ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )
        self.normalize_length = args.transformer_length_normalized_loss  # for PPL
        self.reset_parameters(args)
        self.adim = args.adim
        self.error_calculator = ErrorCalculator(args.char_list, args.sym_space,
                                                args.sym_blank,
                                                args.report_bleu)
        self.rnnlm = None

        # multilingual MT related
        self.multilingual = args.multilingual
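
Weight tying, as used above for tie_src_tgt_embedding and tie_classifier, simply aliases one Parameter object between two modules, which is why idim and odim must match. A minimal PyTorch sketch of the mechanism (the sizes are illustrative, not taken from any config):

import torch

vocab, dim = 100, 16
embed = torch.nn.Embedding(vocab, dim)
output_layer = torch.nn.Linear(dim, vocab, bias=False)

# Tie the classifier to the embedding: both modules now share a single
# (vocab, dim) weight tensor, so updating one updates the other.
output_layer.weight = embed.weight
assert output_layer.weight.data_ptr() == embed.weight.data_ptr()
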
Example #18
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method
        self.outdir = args.outdir

        # the last vocabulary index is used as the sos/eos ID
        # note that sos and eos share the same ID
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode="asr", arch="rnn")
        self.reporter = Reporter()

        # ICT related
        self.scheme = args.mixup_scheme
        self.consistency_weight = args.consistency_weight
        self.consistency_rampup_starts = args.consistency_rampup_starts
        self.consistency_rampup_ends = args.consistency_rampup_ends
        self.mixup_alpha = args.mixup_alpha

        # if True, print out student model accuracy
        self.show_student_model_acc = args.show_student_model_acc

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(
                odim, args.lsm_type, transcript=args.train_json
            )
        else:
            labeldist = None

        if getattr(args, "use_frontend", False):  # use getattr to keep compatibility
            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
            idim = args.n_mels
        else:
            self.frontend = None

        # encoder
        self.enc = encoder_for(args, idim, odim, self.subsample)
        self.ema_enc = encoder_for(args, idim, odim, self.subsample)
        for param in self.ema_enc.parameters():
            param.detach_()
        # leave ctc for future works
        # self.ctc = ctc_for(args, odim)

        # weight initialization
        if args.initializer == "lecun":
            self.init_like_chainer()
        elif args.initializer == "orthogonal":
            self.init_orthogonal()
        else:
            raise NotImplementedError(
                "unknown initializer: " + args.initializer
            )

        if args.report_cer or args.report_wer:
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
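
The detached ema_enc above is the usual mean-teacher arrangement: the teacher encoder receives no gradients and is instead tracked as an exponential moving average of the student's weights. A sketch of the customary update step (the function name and alpha value are assumptions, not taken from this code):

import torch

def update_ema(student: torch.nn.Module, teacher: torch.nn.Module, alpha: float = 0.999):
    """EMA update: teacher <- alpha * teacher + (1 - alpha) * student."""
    with torch.no_grad():
        for s_param, t_param in zip(student.parameters(), teacher.parameters()):
            t_param.mul_(alpha).add_(s_param, alpha=1.0 - alpha)
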
Example #19
    def __init__(
        self,
        idim: int,
        odim: int,
        args: Namespace,
        ignore_id: int = -1,
        blank_id: int = 0,
        training: bool = True,
    ):
        """Construct an E2E object for transducer model."""
        torch.nn.Module.__init__(self)

        args = fill_missing_args(args, self.add_arguments)

        self.is_transducer = True

        self.use_auxiliary_enc_outputs = bool(
            training and args.use_aux_transducer_loss)

        self.subsample = get_subsample(
            args,
            mode="asr",
            arch="transformer" if args.etype == "custom" else "rnn-t")

        if self.use_auxiliary_enc_outputs:
            if args.enc_block_arch is not None:
                n_layers = len(args.enc_block_arch) * args.enc_block_repeat - 1
            else:
                n_layers = args.elayers - 1

            aux_enc_output_layers = valid_aux_encoder_output_layers(
                args.aux_transducer_loss_enc_output_layers,
                n_layers,
                args.use_symm_kl_div_loss,
                self.subsample,
            )
        else:
            aux_enc_output_layers = []

        if args.etype == "custom":
            if args.enc_block_arch is None:
                raise ValueError(
                    "When specifying custom encoder type, --enc-block-arch"
                    "should also be specified in training config. See"
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.encoder = CustomEncoder(
                idim,
                args.enc_block_arch,
                input_layer=args.custom_enc_input_layer,
                repeat_block=args.enc_block_repeat,
                self_attn_type=args.custom_enc_self_attn_type,
                positional_encoding_type=args.custom_enc_positional_encoding_type,
                positionwise_activation_type=args.custom_enc_pw_activation_type,
                conv_mod_activation_type=args.custom_enc_conv_mod_activation_type,
                aux_enc_output_layers=aux_enc_output_layers,
            )
            encoder_out = self.encoder.enc_out
        else:
            self.enc = encoder_for(
                args,
                idim,
                self.subsample,
                aux_enc_output_layers=aux_enc_output_layers,
            )
            encoder_out = args.eprojs

        if args.dtype == "custom":
            if args.dec_block_arch is None:
                raise ValueError(
                    "When specifying custom decoder type, --dec-block-arch"
                    "should also be specified in training config. See"
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.decoder = CustomDecoder(
                odim,
                args.dec_block_arch,
                input_layer=args.custom_dec_input_layer,
                repeat_block=args.dec_block_repeat,
                positionwise_activation_type=args.custom_dec_pw_activation_type,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
                blank_id=blank_id,
            )
            decoder_out = self.decoder.dunits
        else:
            self.dec = RNNDecoder(
                odim,
                args.dtype,
                args.dlayers,
                args.dunits,
                args.dec_embed_dim,
                dropout_rate=args.dropout_rate_decoder,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
                blank_id=blank_id,
            )
            decoder_out = args.dunits

        self.transducer_tasks = TransducerTasks(
            encoder_out,
            decoder_out,
            args.joint_dim,
            odim,
            joint_activation_type=args.joint_activation_type,
            transducer_loss_weight=args.transducer_weight,
            ctc_loss=args.use_ctc_loss,
            ctc_loss_weight=args.ctc_loss_weight,
            ctc_loss_dropout_rate=args.ctc_loss_dropout_rate,
            lm_loss=args.use_lm_loss,
            lm_loss_weight=args.lm_loss_weight,
            lm_loss_smoothing_rate=args.lm_loss_smoothing_rate,
            aux_transducer_loss=args.use_aux_transducer_loss,
            aux_transducer_loss_weight=args.aux_transducer_loss_weight,
            aux_transducer_loss_mlp_dim=args.aux_transducer_loss_mlp_dim,
            aux_trans_loss_mlp_dropout_rate=args.aux_transducer_loss_mlp_dropout_rate,
            symm_kl_div_loss=args.use_symm_kl_div_loss,
            symm_kl_div_loss_weight=args.symm_kl_div_loss_weight,
            fastemit_lambda=args.fastemit_lambda,
            blank_id=blank_id,
            ignore_id=ignore_id,
            training=training,
        )

        if training and (args.report_cer or args.report_wer):
            self.error_calculator = ErrorCalculator(
                self.decoder if args.dtype == "custom" else self.dec,
                self.transducer_tasks.joint_network,
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None

        self.etype = args.etype
        self.dtype = args.dtype

        self.sos = odim - 1
        self.eos = odim - 1
        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.space = args.sym_space
        self.blank = args.sym_blank

        self.odim = odim

        self.reporter = Reporter()

        self.default_parameters(args)

        self.loss = None
        self.rnnlm = None
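
TransducerTasks above is configured with one weight per auxiliary objective. Its internals are not shown here, but the *_loss_weight arguments suggest a plain linear combination of the individual losses; a hedged numeric sketch of that reading (values invented, not the module's actual code):

import torch

# Hypothetical per-task losses standing in for the transducer/CTC/LM terms.
trans_loss, ctc_loss, lm_loss = torch.tensor(1.2), torch.tensor(0.8), torch.tensor(0.5)
transducer_weight, ctc_weight, lm_weight = 1.0, 0.5, 0.3

# Weighted sum, mirroring the weight arguments passed above.
total = transducer_weight * trans_loss + ctc_weight * ctc_loss + lm_weight * lm_loss
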
Example #20
    def __init__(self, idim, odim, args):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)

        # fill missing arguments for compatibility
        args = fill_missing_args(args, self.add_arguments)

        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # the last vocabulary index is used as the sos/eos ID
        # note that sos and eos share the same ID
        self.sos = odim - 1
        self.eos = odim - 1
        self.pad = 0
        # NOTE: we reserve index:0 for <pad> although this is reserved for a blank class
        # in ASR. However, blank labels are not used in MT.
        # To keep the vocabulary size,
        # we use index:0 for padding instead of adding one more class.

        # subsample info
        self.subsample = get_subsample(args, mode="mt", arch="rnn")

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim,
                                             args.lsm_type,
                                             transcript=args.train_json)
        else:
            labeldist = None

        # multilingual related
        self.multilingual = getattr(args, "multilingual", False)
        self.replace_sos = getattr(args, "replace_sos", False)

        # encoder
        self.embed = torch.nn.Embedding(idim,
                                        args.eunits,
                                        padding_idx=self.pad)
        self.dropout = torch.nn.Dropout(p=args.dropout_rate)
        self.enc = encoder_for(args, args.eunits, self.subsample)
        # attention
        self.att = att_for(args)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        # tie source and target embeddings
        if args.tie_src_tgt_embedding:
            if idim != odim:
                raise ValueError(
                    "When using tie_src_tgt_embedding, idim and odim must be equal."
                )
            if args.eunits != args.dunits:
                raise ValueError(
                    "When using tie_src_tgt_embedding, eunits and dunits must be equal."
                )
            self.embed.weight = self.dec.embed.weight

        # tie embeddings and the classifier
        if args.tie_classifier:
            if args.context_residual:
                raise ValueError(
                    "When using tie_classifier, context_residual must be turned off."
                )
            self.dec.output.weight = self.dec.embed.weight

        # weight initialization
        self.init_like_fairseq()

        # options for beam search
        if args.report_bleu:
            trans_args = {
                "beam_size": args.beam_size,
                "penalty": args.penalty,
                "ctc_weight": 0,
                "maxlenratio": args.maxlenratio,
                "minlenratio": args.minlenratio,
                "lm_weight": args.lm_weight,
                "rnnlm": args.rnnlm,
                "nbest": args.nbest,
                "space": args.sym_space,
                "blank": args.sym_blank,
                "tgt_lang": False,
            }

            self.trans_args = argparse.Namespace(**trans_args)
            self.report_bleu = args.report_bleu
        else:
            self.report_bleu = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
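
The padding_idx=self.pad argument on the source embedding above makes index 0 map to a zero vector that receives no gradient, which is exactly what the <pad> comment calls for. A small runnable check:

import torch

pad = 0
embed = torch.nn.Embedding(10, 4, padding_idx=pad)

batch = torch.tensor([[5, 7, pad, pad]])   # one padded sequence
vecs = embed(batch)
assert torch.all(vecs[0, 2] == 0) and torch.all(vecs[0, 3] == 0)
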
Example #21
    def __init__(self, idims, odim, args):
        """Initialize this class with python-level args.

        Args:
            idims (list): list of the number of an input feature dim.
            odim (int): The number of output vocab.
            args (Namespace): arguments

        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()
        self.num_encs = args.num_encs
        self.share_ctc = args.share_ctc

        # the last vocabulary index is used as the sos/eos ID
        # note that sos and eos share the same ID
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        self.subsample_list = get_subsample(args,
                                            mode="asr",
                                            arch="rnn_mulenc")

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim,
                                             args.lsm_type,
                                             transcript=args.train_json)
        else:
            labeldist = None

        # speech translation related
        self.replace_sos = getattr(args, "replace_sos",
                                   False)  # use getattr to keep compatibility

        self.frontend = None

        # encoder
        self.enc = encoder_for(args, idims, self.subsample_list)
        # ctc
        self.ctc = ctc_for(args, odim)
        # attention
        self.att = att_for(args)
        # hierarchical attention network
        han = att_for(args, han_mode=True)
        self.att.append(han)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        if args.mtlalpha > 0 and self.num_encs > 1:
            # weighted CTC losses:
            # ctc_loss = w_1*ctc_1_loss + w_2*ctc_2_loss + ... + w_N*ctc_N_loss
            self.weights_ctc_train = args.weights_ctc_train / np.sum(
                args.weights_ctc_train)  # normalize
            self.weights_ctc_dec = args.weights_ctc_dec / np.sum(
                args.weights_ctc_dec)  # normalize
            logging.info("ctc weights (training during training): " +
                         " ".join([str(x) for x in self.weights_ctc_train]))
            logging.info("ctc weights (decoding during training): " +
                         " ".join([str(x) for x in self.weights_ctc_dec]))
        else:
            self.weights_ctc_dec = [1.0]
            self.weights_ctc_train = [1.0]

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        if args.report_cer or args.report_wer:
            recog_args = {
                "beam_size": args.beam_size,
                "penalty": args.penalty,
                "ctc_weight": args.ctc_weight,
                "maxlenratio": args.maxlenratio,
                "minlenratio": args.minlenratio,
                "lm_weight": args.lm_weight,
                "rnnlm": args.rnnlm,
                "nbest": args.nbest,
                "space": args.sym_space,
                "blank": args.sym_blank,
                "tgt_lang": False,
                "ctc_weights_dec": self.weights_ctc_dec,
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
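
The multi-encoder CTC handling above first normalizes the per-encoder weights to sum to one, then mixes the individual CTC losses with them. A numeric sketch of that normalization and mixture (the weights and losses are invented):

import numpy as np

weights_ctc_train = np.array([2.0, 1.0, 1.0])
weights = weights_ctc_train / np.sum(weights_ctc_train)   # -> [0.5, 0.25, 0.25]

ctc_losses = np.array([1.2, 0.9, 1.5])                    # hypothetical per-encoder losses
ctc_loss = float(np.sum(weights * ctc_losses))            # weighted mixture
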
Example #22
    def __init__(self, idim, odim, args):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)

        # fill missing arguments for compatibility
        args = fill_missing_args(args, self.add_arguments)

        self.asr_weight = args.asr_weight
        self.mt_weight = args.mt_weight
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.asr_weight < 1.0, "asr_weight should be [0.0, 1.0)"
        assert 0.0 <= self.mt_weight < 1.0, "mt_weight should be [0.0, 1.0)"
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # the last vocabulary index is used as the sos/eos ID
        # note that sos and eos share the same ID
        self.sos = odim - 1
        self.eos = odim - 1
        self.pad = 0
        # NOTE: we reserve index:0 for <pad> although this is reserved for a blank class
        # in ASR. However, blank labels are not used in MT.
        # To keep the vocabulary size,
        # we use index:0 for padding instead of adding one more class.

        # subsample info
        self.subsample = get_subsample(args, mode="st", arch="rnn")

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim,
                                             args.lsm_type,
                                             transcript=args.train_json)
        else:
            labeldist = None

        # multilingual related
        self.multilingual = getattr(args, "multilingual", False)
        self.replace_sos = getattr(args, "replace_sos", False)

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # attention (ST)
        self.att = att_for(args)
        # decoder (ST)
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        # submodule for ASR task
        self.ctc = None
        self.att_asr = None
        self.dec_asr = None
        if self.asr_weight > 0:
            if self.mtlalpha > 0.0:
                self.ctc = CTC(
                    odim,
                    args.eprojs,
                    args.dropout_rate,
                    ctc_type=args.ctc_type,
                    reduce=True,
                )
            if self.mtlalpha < 1.0:
                # attention (asr)
                self.att_asr = att_for(args)
                # decoder (asr)
                args_asr = copy.deepcopy(args)
                args_asr.atype = "location"  # TODO(hirofumi0810): make this option
                self.dec_asr = decoder_for(args_asr, odim, self.sos, self.eos,
                                           self.att_asr, labeldist)

        # submodule for MT task
        if self.mt_weight > 0:
            self.embed_mt = torch.nn.Embedding(odim,
                                               args.eunits,
                                               padding_idx=self.pad)
            self.dropout_mt = torch.nn.Dropout(p=args.dropout_rate)
            self.enc_mt = encoder_for(args,
                                      args.eunits,
                                      subsample=np.ones(args.elayers + 1,
                                                        dtype=int))

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        if self.asr_weight > 0 and (args.report_cer or args.report_wer):
            recog_args = {
                "beam_size": args.beam_size,
                "penalty": args.penalty,
                "ctc_weight": args.ctc_weight,
                "maxlenratio": args.maxlenratio,
                "minlenratio": args.minlenratio,
                "lm_weight": args.lm_weight,
                "rnnlm": args.rnnlm,
                "nbest": args.nbest,
                "space": args.sym_space,
                "blank": args.sym_blank,
                "tgt_lang": False,
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        if args.report_bleu:
            trans_args = {
                "beam_size": args.beam_size,
                "penalty": args.penalty,
                "ctc_weight": 0,
                "maxlenratio": args.maxlenratio,
                "minlenratio": args.minlenratio,
                "lm_weight": args.lm_weight,
                "rnnlm": args.rnnlm,
                "nbest": args.nbest,
                "space": args.sym_space,
                "blank": args.sym_blank,
                "tgt_lang": False,
            }

            self.trans_args = argparse.Namespace(**trans_args)
            self.report_bleu = args.report_bleu
        else:
            self.report_bleu = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
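
Several of these constructors read options with getattr(args, name, default) instead of direct attribute access, so configs written before an option existed keep working. A minimal illustration of the pattern:

from argparse import Namespace

old_args = Namespace(etype='blstmp')                      # an old config without the newer flag
multilingual = getattr(old_args, 'multilingual', False)   # falls back to the default
assert multilingual is False
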
Example #23
    def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0):
        """Construct an E2E object for transducer model.

        Args:
            idim (int): dimension of inputs
            odim (int): dimension of outputs
            args (Namespace): argument Namespace containing options

        """
        torch.nn.Module.__init__(self)

        if args.etype == 'transformer':
            self.encoder = Encoder(idim=idim,
                                   attention_dim=args.adim,
                                   attention_heads=args.aheads,
                                   linear_units=args.eunits,
                                   num_blocks=args.elayers,
                                   input_layer=args.transformer_input_layer,
                                   dropout_rate=args.dropout_rate,
                                   positional_dropout_rate=args.dropout_rate,
                                   attention_dropout_rate=args.transformer_attn_dropout_rate_encoder)

            self.subsample = [1]
        else:
            self.subsample = get_subsample(args, mode='asr', arch='rnn-t')

            self.encoder = encoder_for(args, idim, self.subsample)

        if args.dtype == 'transformer':
            self.decoder = Decoder(
                odim=odim,
                jdim=args.joint_dim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                input_layer=args.transformer_dec_input_layer,
                dropout_rate=args.dropout_rate_decoder,
                positional_dropout_rate=args.dropout_rate_decoder,
                attention_dropout_rate=args.transformer_attn_dropout_rate_decoder)
        else:
            if args.etype == 'transformer':
                args.eprojs = args.adim

            if args.rnnt_mode == 'rnnt-att':
                self.att = att_for(args)
                self.decoder = decoder_for(args, odim, self.att)
            else:
                self.decoder = decoder_for(args, odim)

        self.etype = args.etype
        self.dtype = args.dtype
        self.rnnt_mode = args.rnnt_mode

        self.sos = odim - 1
        self.eos = odim - 1
        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.space = args.sym_space
        self.blank = args.sym_blank

        self.odim = odim
        self.adim = args.adim

        self.reporter = Reporter()

        self.criterion = TransLoss(args.trans_type, self.blank_id)

        self.default_parameters(args)

        if args.report_cer or args.report_wer:
            from espnet.nets.e2e_asr_common import ErrorCalculatorTrans

            self.error_calculator = ErrorCalculatorTrans(self.decoder, args)
        else:
            self.error_calculator = None

        self.logzero = -10000000000.0
        self.loss = None
        self.rnnlm = None
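
get_subsample above returns per-layer frame-subsampling factors, while the transformer branch pins self.subsample to [1] because its convolutional front-end already downsamples. A sketch of how such a factor is conventionally applied to an utterance (the shapes are invented):

import numpy as np

subsample = [2]                        # hypothetical factor for the first layer
feats = np.random.randn(100, 83)       # (frames, feature dim)
feats_sub = feats[::subsample[0], :]   # keep every 2nd frame
assert feats_sub.shape[0] == 50
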
Example #24
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_encoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            attention_type=getattr(args, 'transformer_enc_attn_type',
                                   'self_attn'),
            max_attn_span=getattr(args, 'enc_max_attn_span', [None]),
            span_init=getattr(args, 'span_init', None),
            span_ratio=getattr(args, 'span_ratio', None),
            ratio_adaptive=getattr(args, 'ratio_adaptive', None))
        self.decoder = Decoder(
            odim=odim,
            selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_decoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate,
            attention_type=getattr(args, 'transformer_dec_attn_type',
                                   'self_attn'),
            max_attn_span=getattr(args, 'dec_max_attn_span', [None]),
            span_init=getattr(args, 'span_init', None),
            span_ratio=getattr(args, 'span_ratio', None),
            ratio_adaptive=getattr(args, 'ratio_adaptive', None))
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode="asr", arch="transformer")
        self.reporter = Reporter()

        # self.lsm_weight = a
        self.criterion = LabelSmoothingLoss(
            self.odim,
            self.ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )
        # self.verbose = args.verbose
        self.reset_parameters(args)
        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim,
                           args.adim,
                           args.dropout_rate,
                           ctc_type=args.ctc_type,
                           reduce=True)
        else:
            self.ctc = None

        if args.report_cer or args.report_wer:
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None
        self.rnnlm = None
        self.attention_enc_type = getattr(args, 'transformer_enc_attn_type',
                                          'self_attn')
        self.attention_dec_type = getattr(args, 'transformer_dec_attn_type',
                                          'self_attn')
        self.span_loss_coef = getattr(args, 'span_loss_coef', None)
        self.ratio_adaptive = getattr(args, 'ratio_adaptive', None)
        self.sym_blank = args.sym_blank
Example #25
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)

        # fill missing arguments for compatibility
        args = fill_missing_args(args, self.add_arguments)

        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_encoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        self.decoder = Decoder(
            odim=odim,
            selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_decoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        self.pad = 0  # use <blank> for padding
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode="st", arch="transformer")
        self.reporter = Reporter()

        self.criterion = LabelSmoothingLoss(
            self.odim,
            self.ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )
        # submodule for ASR task
        self.mtlalpha = args.mtlalpha
        self.asr_weight = getattr(args, "asr_weight", 0.0)
        if self.asr_weight > 0 and args.mtlalpha < 1:
            self.decoder_asr = Decoder(
                odim=odim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                self_attention_dropout_rate=args.transformer_attn_dropout_rate,
                src_attention_dropout_rate=args.transformer_attn_dropout_rate,
            )

        # submodule for MT task
        self.mt_weight = getattr(args, "mt_weight", 0.0)
        if self.mt_weight > 0:
            self.encoder_mt = Encoder(
                idim=odim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                input_layer="embed",
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                attention_dropout_rate=args.transformer_attn_dropout_rate,
                padding_idx=0,
            )
        self.reset_parameters(args)  # NOTE: place after the submodule initialization
        self.adim = args.adim  # used for CTC (equal to d_model)
        if self.asr_weight > 0 and args.mtlalpha > 0.0:
            self.ctc = CTC(odim,
                           args.adim,
                           args.dropout_rate,
                           ctc_type=args.ctc_type,
                           reduce=True)
        else:
            self.ctc = None

        # translation error calculator
        self.error_calculator = MTErrorCalculator(args.char_list,
                                                  args.sym_space,
                                                  args.sym_blank,
                                                  args.report_bleu)

        # recognition error calculator
        self.error_calculator_asr = ASRErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
        self.rnnlm = None

        # multilingual E2E-ST related
        self.multilingual = getattr(args, "multilingual", False)
        self.replace_sos = getattr(args, "replace_sos", False)
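
The replace_sos flag above belongs to the multilingual setup: instead of the generic start-of-sequence token, the decoder is primed with a target-language tag so one model can serve several language pairs. A toy illustration (all ids invented):

sos = 99
tgt_lang_id = 42                # e.g. the id of a hypothetical '<2de>' tag
replace_sos = True

target = [7, 8, 9]
ys_in = ([tgt_lang_id] if replace_sos else [sos]) + target
assert ys_in[0] == 42
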
Example #26
    def __init__(self, idim, odim, args):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)

        # fill missing arguments for compatibility
        args = fill_missing_args(args, self.add_arguments)

        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # the last vocabulary index is used as the sos/eos ID
        # note that sos and eos share the same ID
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        self.subsample = get_subsample(args, mode="asr", arch="rnn")

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim,
                                             args.lsm_type,
                                             transcript=args.train_json)
        else:
            labeldist = None

        if getattr(args, "use_frontend",
                   False):  # use getattr to keep compatibility
            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(
                args, (idim - 1) * 2)
            idim = args.n_mels
        else:
            self.frontend = None

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # ctc
        self.ctc = ctc_for(args, odim)
        # attention
        self.att = att_for(args)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        if args.report_cer or args.report_wer:
            recog_args = {
                "beam_size": args.beam_size,
                "penalty": args.penalty,
                "ctc_weight": args.ctc_weight,
                "maxlenratio": args.maxlenratio,
                "minlenratio": args.minlenratio,
                "lm_weight": args.lm_weight,
                "rnnlm": args.rnnlm,
                "nbest": args.nbest,
                "space": args.sym_space,
                "blank": args.sym_blank,
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
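
The mtlalpha asserted at the top of this example is the hybrid CTC/attention interpolation weight; in these hybrid models the training loss is conventionally the convex combination of the two objectives. A toy illustration of the mixing:

mtlalpha = 0.3
loss_ctc, loss_att = 2.0, 1.0    # hypothetical per-batch losses
loss = mtlalpha * loss_ctc + (1 - mtlalpha) * loss_att
assert abs(loss - 1.3) < 1e-9
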
Example #27
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)

        # fill missing arguments for compatibility
        args = fill_missing_args(args, self.add_arguments)

        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_encoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        if args.mtlalpha < 1:
            self.decoder = Decoder(
                odim=odim,
                selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                conv_wshare=args.wshare,
                conv_kernel_length=args.ldconv_decoder_kernel_length,
                conv_usebias=args.ldconv_usebias,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                self_attention_dropout_rate=args.transformer_attn_dropout_rate,
                src_attention_dropout_rate=args.transformer_attn_dropout_rate,
            )
            self.criterion = LabelSmoothingLoss(
                odim,
                ignore_id,
                args.lsm_weight,
                args.transformer_length_normalized_loss,
            )
        else:
            self.decoder = None
            self.criterion = None
        self.blank = 0
        self.decoder_mode = args.decoder_mode
        if self.decoder_mode == "maskctc":
            self.mask_token = odim - 1
            self.sos = odim - 2
            self.eos = odim - 2
        else:
            self.sos = odim - 1
            self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode="asr", arch="transformer")
        self.reporter = Reporter()

        self.reset_parameters(args)
        self.adim = args.adim  # used for CTC (equal to d_model)
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(
                odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
            )
        else:
            self.ctc = None

        if args.report_cer or args.report_wer:
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None
        self.rnnlm = None
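
For the Mask CTC decoder mode above, the vocabulary reserves its last id for the <mask> token and shifts sos/eos down by one, while the default mode keeps sos/eos at the last id. A quick check of that indexing:

odim = 10
decoder_mode = 'maskctc'
if decoder_mode == 'maskctc':
    mask_token, sos, eos = odim - 1, odim - 2, odim - 2
else:
    sos = eos = odim - 1
assert (mask_token, sos, eos) == (9, 8, 8)
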
Example #28
    def __init__(self,
                 idim,
                 odim,
                 args,
                 ignore_id=-1,
                 blank_id=0,
                 training=True):
        """Construct an E2E object for transducer model."""
        torch.nn.Module.__init__(self)

        self.is_rnnt = True

        if "custom" in args.etype:
            if args.enc_block_arch is None:
                raise ValueError(
                    "When specifying custom encoder type, --enc-block-arch"
                    "should also be specified in training config. See"
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.subsample = get_subsample(args,
                                           mode="asr",
                                           arch="transformer")

            self.encoder = CustomEncoder(
                idim,
                args.enc_block_arch,
                input_layer=args.custom_enc_input_layer,
                repeat_block=args.enc_block_repeat,
                self_attn_type=args.custom_enc_self_attn_type,
                positional_encoding_type=args.custom_enc_positional_encoding_type,
                positionwise_activation_type=args.custom_enc_pw_activation_type,
                conv_mod_activation_type=args.custom_enc_conv_mod_activation_type,
            )
            encoder_out = self.encoder.enc_out

            self.most_dom_list = args.enc_block_arch[:]
        else:
            self.subsample = get_subsample(args, mode="asr", arch="rnn-t")

            self.enc = encoder_for(args, idim, self.subsample)
            encoder_out = args.eprojs

        if "custom" in args.dtype:
            if args.dec_block_arch is None:
                raise ValueError(
                    "When specifying custom decoder type, --dec-block-arch"
                    "should also be specified in training config. See"
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.decoder = CustomDecoder(
                odim,
                args.dec_block_arch,
                input_layer=args.custom_dec_input_layer,
                repeat_block=args.dec_block_repeat,
                positionwise_activation_type=args.custom_dec_pw_activation_type,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
            )
            decoder_out = self.decoder.dunits

            if "custom" in args.etype:
                self.most_dom_list += args.dec_block_arch[:]
            else:
                self.most_dom_list = args.dec_block_arch[:]
        else:
            self.dec = DecoderRNNT(
                odim,
                args.dtype,
                args.dlayers,
                args.dunits,
                blank_id,
                args.dec_embed_dim,
                args.dropout_rate_decoder,
                args.dropout_rate_embed_decoder,
            )
            decoder_out = args.dunits

        self.joint_network = JointNetwork(odim, encoder_out, decoder_out,
                                          args.joint_dim,
                                          args.joint_activation_type)

        if hasattr(self, "most_dom_list"):
            self.most_dom_dim = sorted(
                Counter(d["d_hidden"] for d in self.most_dom_list
                        if "d_hidden" in d).most_common(),
                key=lambda x: x[0],
                reverse=True,
            )[0][0]

        self.etype = args.etype
        self.dtype = args.dtype

        self.sos = odim - 1
        self.eos = odim - 1
        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.space = args.sym_space
        self.blank = args.sym_blank

        self.odim = odim

        self.reporter = Reporter()

        if training:
            self.criterion = TransLoss(args.trans_type, self.blank_id)

        self.default_parameters(args)

        if training and (args.report_cer or args.report_wer):
            self.error_calculator = ErrorCalculator(
                self.decoder if self.dtype == "custom" else self.dec,
                self.joint_network,
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None

        self.loss = None
        self.rnnlm = None
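
The most_dom_dim computation near the end collects the d_hidden values observed across the custom block architectures; note that re-sorting most_common() by the dimension value means the largest observed d_hidden wins, regardless of its count. A standalone sketch with an invented arch list:

from collections import Counter

most_dom_list = [{'d_hidden': 256}, {'d_hidden': 256}, {'d_hidden': 512}, {'type': 'conv'}]

most_dom_dim = sorted(
    Counter(d['d_hidden'] for d in most_dom_list if 'd_hidden' in d).most_common(),
    key=lambda x: x[0],
    reverse=True,
)[0][0]
assert most_dom_dim == 512   # 512 beats 256 despite appearing less often
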