コード例 #1
0
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Assemble the encoder, decoder, loss and reporting helpers.

        :param idim: forwarded to ``Encoder`` as ``idim``
        :param odim: forwarded to ``Decoder``, ``LabelSmoothingLoss`` and
            ``CTC``; ``odim - 1`` is used as both the sos and the eos id
        :param args: option namespace; ``transformer_attn_dropout_rate`` is
            overwritten in place with ``dropout_rate`` when it is ``None``
        :param int ignore_id: forwarded to ``LabelSmoothingLoss`` and stored
            on the instance
        """
        torch.nn.Module.__init__(self)

        # The attention dropout falls back to the general dropout rate and
        # the resolved value is written back onto ``args``.
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        dropout = args.dropout_rate
        attn_dropout = args.transformer_attn_dropout_rate

        # Encoder: chunk/hop length and embedding switches come straight
        # from the corresponding ``transformer_encoder_*`` options.
        self.encoder = Encoder(
            idim=idim,
            center_len=args.transformer_encoder_center_chunk_len,
            left_len=args.transformer_encoder_left_chunk_len,
            hop_len=args.transformer_encoder_hop_len,
            right_len=args.transformer_encoder_right_chunk_len,
            abs_pos=args.transformer_encoder_abs_embed,
            rel_pos=args.transformer_encoder_rel_embed,
            use_mem=args.transformer_encoder_use_memory,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=dropout,
            positional_dropout_rate=dropout,
            attention_dropout_rate=attn_dropout,
        )
        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=dropout,
            positional_dropout_rate=dropout,
            self_attention_dropout_rate=attn_dropout,
            src_attention_dropout_rate=attn_dropout,
        )

        # sos and eos share the last output index.
        self.sos = self.eos = odim - 1
        self.odim, self.ignore_id = odim, ignore_id
        self.subsample = [1]
        self.reporter = Reporter()

        self.criterion = LabelSmoothingLoss(
            odim,
            ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )

        # reset_parameters runs before the CTC module below is created.
        self.reset_parameters(args)
        self.adim, self.mtlalpha = args.adim, args.mtlalpha
        self.ctc = (
            CTC(odim, args.adim, args.dropout_rate,
                ctc_type=args.ctc_type, reduce=True)
            if args.mtlalpha > 0.0
            else None
        )

        # An error calculator is built when CER/WER reporting is requested
        # or when the CTC weight is positive.
        needs_error_calculator = (
            args.report_cer or args.report_wer or args.mtlalpha > 0.0
        )
        if not needs_error_calculator:
            self.error_calculator = None
        else:
            from espnet.nets.e2e_asr_common import ErrorCalculator
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        self.rnnlm = None
コード例 #2
0
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Set up encoder, decoder, label-smoothing loss and optional CTC.

        :param idim: passed to ``Encoder`` together with ``args``
        :param odim: passed to ``Decoder``, ``LabelSmoothingLoss`` and
            ``CTC``; ``odim - 1`` is used as both the sos and the eos id
        :param args: option namespace; ``transformer_attn_dropout_rate`` is
            overwritten in place with ``dropout_rate`` when it is ``None``
        :param int ignore_id: forwarded to ``LabelSmoothingLoss`` and stored
            on the instance
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            # default the attention dropout to the general dropout rate
            args.transformer_attn_dropout_rate = args.dropout_rate

        # Both sub-networks configure themselves from the whole namespace.
        self.encoder = Encoder(idim, args)
        self.decoder = Decoder(odim, args)

        # sos and eos share the last output index.
        self.sos = self.eos = odim - 1
        self.odim, self.ignore_id = odim, ignore_id
        self.subsample = [1]
        self.reporter = Reporter()

        self.criterion = LabelSmoothingLoss(
            odim,
            ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )

        # reset_parameters runs before the CTC module below is created.
        self.reset_parameters(args)
        self.recog_args = None  # unused
        self.adim, self.mtlalpha = args.adim, args.mtlalpha
        self.ctc = (
            CTC(odim, args.adim, args.dropout_rate,
                ctc_type=args.ctc_type, reduce=True)
            if args.mtlalpha > 0.0
            else None
        )
コード例 #3
0
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options;
            ``transformer_attn_dropout_rate`` is overwritten in place with
            ``dropout_rate`` when it is ``None``
        :param int ignore_id: forwarded to ``LabelSmoothingLoss`` and stored
            on the instance
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        attn_dropout = args.transformer_attn_dropout_rate

        # Options that the encoder and the decoder receive identically.
        shared_options = {
            "attention_dim": args.adim,
            "attention_heads": args.aheads,
            "dropout_rate": args.dropout_rate,
            "positional_dropout_rate": args.dropout_rate,
        }
        self.encoder = Encoder(
            idim=idim,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            attention_dropout_rate=attn_dropout,
            **shared_options
        )
        self.decoder = Decoder(
            odim=odim,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            self_attention_dropout_rate=attn_dropout,
            src_attention_dropout_rate=attn_dropout,
            **shared_options
        )

        # sos and eos share the last output index.
        self.sos = self.eos = odim - 1
        self.odim, self.ignore_id = odim, ignore_id
        self.subsample = get_subsample(args, mode='asr', arch='transformer')
        self.reporter = Reporter()

        self.criterion = LabelSmoothingLoss(
            odim,
            ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )

        # reset_parameters runs before the CTC module below is created.
        self.reset_parameters(args)
        self.adim, self.mtlalpha = args.adim, args.mtlalpha
        self.ctc = (
            CTC(odim, args.adim, args.dropout_rate,
                ctc_type=args.ctc_type, reduce=True)
            if args.mtlalpha > 0.0
            else None
        )

        # The error calculator exists only when CER or WER is requested.
        if not (args.report_cer or args.report_wer):
            self.error_calculator = None
        else:
            from espnet.nets.e2e_asr_common import ErrorCalculator
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        self.rnnlm = None
コード例 #4
0
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options; it is
            first passed through ``fill_missing_args`` and the result is used
        :param int ignore_id: forwarded to ``LabelSmoothingLoss`` and stored
            on the instance
        """
        torch.nn.Module.__init__(self)

        # fill missing arguments for compatibility
        args = fill_missing_args(args, self.add_arguments)

        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        dropout = args.dropout_rate
        attn_dropout = args.transformer_attn_dropout_rate

        self.encoder = Encoder(
            idim=idim,
            selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_encoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=dropout,
            positional_dropout_rate=dropout,
            attention_dropout_rate=attn_dropout,
        )

        # Decoder and its loss are only built when mtlalpha is below 1.
        self.decoder = None
        self.criterion = None
        if args.mtlalpha < 1:
            self.decoder = Decoder(
                odim=odim,
                selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                conv_wshare=args.wshare,
                conv_kernel_length=args.ldconv_decoder_kernel_length,
                conv_usebias=args.ldconv_usebias,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                dropout_rate=dropout,
                positional_dropout_rate=dropout,
                self_attention_dropout_rate=attn_dropout,
                src_attention_dropout_rate=attn_dropout,
            )
            self.criterion = LabelSmoothingLoss(
                odim,
                ignore_id,
                args.lsm_weight,
                args.transformer_length_normalized_loss,
            )

        self.blank = 0
        self.decoder_mode = args.decoder_mode
        # In "maskctc" mode the last index is the mask token and sos/eos
        # move to the index before it; otherwise sos/eos use the last index.
        uses_mask_token = self.decoder_mode == "maskctc"
        if uses_mask_token:
            self.mask_token = odim - 1
        self.sos = self.eos = (odim - 2) if uses_mask_token else (odim - 1)
        self.odim, self.ignore_id = odim, ignore_id
        self.subsample = get_subsample(args, mode="asr", arch="transformer")
        self.reporter = Reporter()

        # reset_parameters runs before the CTC module below is created.
        self.reset_parameters(args)
        self.adim = args.adim  # used for CTC (equal to d_model)
        self.mtlalpha = args.mtlalpha
        self.ctc = (
            CTC(odim, args.adim, args.dropout_rate,
                ctc_type=args.ctc_type, reduce=True)
            if args.mtlalpha > 0.0
            else None
        )

        # The error calculator exists only when CER or WER is requested.
        if not (args.report_cer or args.report_wer):
            self.error_calculator = None
        else:
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        self.rnnlm = None
コード例 #5
0
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object with two encoders and one decoder.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options;
            ``transformer_attn_dropout_rate`` is overwritten in place with
            ``dropout_rate`` when it is ``None``
        :param int ignore_id: forwarded to ``LabelSmoothingLoss`` and stored
            on the instance
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        attn_dropout = args.transformer_attn_dropout_rate

        # ``cn_encoder`` and ``en_encoder`` are built from identical options.
        encoder_options = {
            "idim": idim,
            "attention_dim": args.adim,
            "attention_heads": args.aheads,
            "linear_units": args.eunits,
            "num_blocks": args.elayers,
            "input_layer": args.transformer_input_layer,
            "dropout_rate": args.dropout_rate,
            "positional_dropout_rate": args.dropout_rate,
            "attention_dropout_rate": attn_dropout,
        }
        self.cn_encoder = Encoder(**encoder_options)
        self.en_encoder = Encoder(**encoder_options)
        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=attn_dropout,
            src_attention_dropout_rate=attn_dropout,
        )

        # sos and eos share the last output index.
        self.sos = self.eos = odim - 1
        self.odim, self.ignore_id = odim, ignore_id
        self.subsample = [1]
        self.reporter = Reporter()

        self.criterion = LabelSmoothingLoss(
            odim,
            ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )
        self.adim, self.mtlalpha = args.adim, args.mtlalpha
        self.ctc = (
            CTC(odim, args.adim, args.dropout_rate,
                ctc_type=args.ctc_type, reduce=True)
            if args.mtlalpha > 0.0
            else None
        )

        # The error calculator exists only when CER or WER is requested.
        if not (args.report_cer or args.report_wer):
            self.error_calculator = None
        else:
            from espnet.nets.e2e_asr_common import ErrorCalculator
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        self.rnnlm = None

        # yzl23 config
        self.remove_blank_in_ctc_mode = True
        # Here reset_parameters runs last, after every sub-module
        # (including CTC) has been created.
        self.reset_parameters(args)
        self.enc_lambda = args.enc_lambda
        logging.warning(
            "Using fixed encoder lambda: {}".format(self.enc_lambda))
        total_size = self.count_parameters()
        trainable_size = self.count_parameters(requires_grad=True)
        logging.warning(
            "Model total size: {}M, requires_grad size: {}M".format(
                total_size, trainable_size))
コード例 #6
0
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object with two encoders, a gate and two CTCs.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options;
            ``transformer_attn_dropout_rate`` is overwritten in place with
            ``dropout_rate`` when it is ``None``
        :param int ignore_id: forwarded to ``LabelSmoothingLoss`` and stored
            on the instance
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        attn_dropout = args.transformer_attn_dropout_rate

        # ``cn_encoder`` and ``en_encoder`` are built from identical options.
        encoder_options = {
            "idim": idim,
            "attention_dim": args.adim,
            "attention_heads": args.aheads,
            "linear_units": args.eunits,
            "num_blocks": args.elayers,
            "input_layer": args.transformer_input_layer,
            "dropout_rate": args.dropout_rate,
            "positional_dropout_rate": args.dropout_rate,
            "attention_dropout_rate": attn_dropout,
        }
        self.cn_encoder = Encoder(**encoder_options)
        self.en_encoder = Encoder(**encoder_options)

        # Gating module: a linear layer over 2 * adim inputs followed by a
        # sigmoid. It emits adim values when ``vectorize_lambda`` is truthy
        # and a single value otherwise.
        self.vectorize_lambda = args.vectorize_lambda
        if self.vectorize_lambda:
            gate_width = args.adim
        else:
            gate_width = 1
        gate_projection = torch.nn.Linear(2 * args.adim, gate_width)
        self.aggregation_module = torch.nn.Sequential(
            gate_projection, torch.nn.Sigmoid())
        self.language_divider = 1000

        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=attn_dropout,
            src_attention_dropout_rate=attn_dropout,
        )

        # sos and eos share the last output index.
        self.sos = self.eos = odim - 1
        self.odim, self.ignore_id = odim, ignore_id
        self.subsample = [1]
        self.reporter = Reporter()

        self.criterion = LabelSmoothingLoss(
            odim,
            ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )
        self.adim, self.mtlalpha = args.adim, args.mtlalpha

        # One CTC module per encoder, or none at all when mtlalpha is not
        # positive.
        self.cn_ctc = None
        self.en_ctc = None
        if args.mtlalpha > 0.0:
            self.cn_ctc = CTC(odim, args.adim, args.dropout_rate,
                              ctc_type=args.ctc_type, reduce=True)
            self.en_ctc = CTC(odim, args.adim, args.dropout_rate,
                              ctc_type=args.ctc_type, reduce=True)

        self.rnnlm = None

        # yzl23 config
        self.remove_blank_in_ctc_mode = True
        # Here reset_parameters runs last, after every sub-module
        # (including both CTCs) has been created.
        self.reset_parameters(args)

        total_size = self.count_parameters()
        trainable_size = self.count_parameters(requires_grad=True)
        logging.warning(
            "Model total size: {}M, requires_grad size: {}M".format(
                total_size, trainable_size))
コード例 #7
0
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object and optionally freeze part of its parameters.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options;
            ``transformer_attn_dropout_rate`` is overwritten in place with
            ``dropout_rate`` when it is ``None``. ``activated_keys`` is a
            comma-separated list of substrings: when it yields at least one
            non-empty key, only parameters whose name contains one of the
            keys keep ``requires_grad=True`` and every other parameter is
            frozen.
        :param int ignore_id: forwarded to ``LabelSmoothingLoss`` and stored
            on the instance
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.cn_encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate)
        self.en_encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate)
        # Gating module: a linear layer over 2 * adim inputs followed by a
        # sigmoid. It emits adim values when ``vectorize_lambda`` is truthy
        # and a single value otherwise.
        self.vectorize_lambda = args.vectorize_lambda
        lambda_dim = args.adim if self.vectorize_lambda else 1
        self.aggregation_module = torch.nn.Sequential(
            torch.nn.Linear(2 * args.adim, lambda_dim), torch.nn.Sigmoid())

        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate)
        # sos and eos share the last output index.
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = [1]
        self.reporter = Reporter()

        self.criterion = LabelSmoothingLoss(
            self.odim, self.ignore_id, args.lsm_weight,
            args.transformer_length_normalized_loss)
        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim,
                           args.adim,
                           args.dropout_rate,
                           ctc_type=args.ctc_type,
                           reduce=True)
        else:
            self.ctc = None

        # The error calculator exists only when CER or WER is requested.
        if args.report_cer or args.report_wer:
            from espnet.nets.e2e_asr_common import ErrorCalculator
            self.error_calculator = ErrorCalculator(args.char_list,
                                                    args.sym_space,
                                                    args.sym_blank,
                                                    args.report_cer,
                                                    args.report_wer)
        else:
            self.error_calculator = None
        self.rnnlm = None

        # yzl23 config
        self.remove_blank_in_ctc_mode = True
        self.reset_parameters(args)  # reset params at the last

        # Freeze every parameter whose name matches none of the activated
        # keys. Keys are stripped and empty ones dropped: an empty key (e.g.
        # from a trailing comma) is a substring of every name and would
        # silently activate all parameters, while a key with stray spaces
        # ("a, b") would never match and silently freeze its target.
        activated_keys = [
            key.strip()
            for key in (args.activated_keys or "").split(',')
            if key.strip()
        ]
        if activated_keys:
            for name, params in self.named_parameters():
                params.requires_grad = any(
                    key in name for key in activated_keys)
        else:
            logging.warning("Not frozen anything.")

        logging.warning(
            "Model total size: {}M, requires_grad size: {}M".format(
                self.count_parameters(),
                self.count_parameters(requires_grad=True)))
コード例 #8
0
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object with an encoder and linear output heads.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options;
            ``transformer_attn_dropout_rate`` is overwritten in place with
            ``dropout_rate`` when it is ``None``
        :param int ignore_id: forwarded to ``LabelSmoothingLoss`` and stored
            on the instance
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        dropout = args.dropout_rate

        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=dropout,
            positional_dropout_rate=dropout,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
        )

        # target matching system organization
        oversampling = args.oversampling
        use_residual = args.residual
        use_outer = args.outer
        self.oversampling = oversampling
        self.residual = use_residual
        self.outer = use_outer
        self.poster = torch.nn.Linear(args.adim, odim * oversampling)
        # The matcher layers exist only when ``outer`` is set. With
        # ``residual`` the input features get their own projection;
        # without it a single layer takes odim + idim inputs.
        if use_outer and use_residual:
            self.matcher_res = torch.nn.Linear(idim, odim)
            self.matcher = torch.nn.Linear(odim, odim)
        elif use_outer:
            self.matcher = torch.nn.Linear(odim + idim, odim)

        # sos and eos share the last output index.
        self.sos = self.eos = odim - 1
        self.odim, self.ignore_id = odim, ignore_id
        self.subsample = get_subsample(args, mode="asr", arch="transformer")
        self.reporter = Reporter()

        self.criterion = LabelSmoothingLoss(
            odim,
            ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )

        # reset_parameters runs before the CTC module below is created.
        self.reset_parameters(args)
        self.adim, self.mtlalpha = args.adim, args.mtlalpha
        self.ctc = (
            CTC(odim, args.adim, args.dropout_rate,
                ctc_type=args.ctc_type, reduce=True)
            if args.mtlalpha > 0.0
            else None
        )

        # The error calculator exists only when CER or WER is requested.
        if not (args.report_cer or args.report_wer):
            self.error_calculator = None
        else:
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        self.rnnlm = None
コード例 #9
0
ファイル: e2e_asr_mix.py プロジェクト: diamondasr/espnet-1
    def __init__(self, idim, odim, args):
        """Initialize multi-speaker E2E module.

        :param idim: input dimension; when ``args.use_frontend`` is set it is
            used for the frontend and replaced by ``args.n_mels`` before the
            encoder is built
        :param odim: output dimension; ``odim - 1`` is the shared sos/eos id
        :param Namespace args: option namespace read throughout this method;
            ``args.char_list`` is set to ``None`` on it when it is missing
        """
        super(E2E, self).__init__()
        # NOTE(review): Module.__init__ is called explicitly right after
        # super().__init__(); presumably redundant — confirm against the
        # class's base classes.
        torch.nn.Module.__init__(self)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()
        self.num_spkrs = args.num_spkrs
        self.spa = args.spa
        self.pit = PIT(self.num_spkrs)

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        self.subsample = get_subsample(args, mode="asr", arch="rnn_mix")

        # label smoothing info
        # (only computed when a smoothing type is given and the training
        # json file exists on disk)
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(
                odim, args.lsm_type, transcript=args.train_json
            )
        else:
            labeldist = None

        if getattr(args, "use_frontend", False):  # use getattr to keep compatibility
            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
            # from here on the encoder input dimension is the mel dimension
            idim = args.n_mels
        else:
            # NOTE(review): self.feature_transform is only assigned in the
            # frontend branch above.
            self.frontend = None

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # ctc
        self.ctc = ctc_for(args, odim, reduce=False)
        # attention
        # (one attention per speaker when args.spa is truthy, otherwise one)
        num_att = self.num_spkrs if args.spa else 1
        self.att = att_for(args, num_att)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        # (collected only when args has a report_cer attribute and CER or
        # WER reporting is requested)
        if "report_cer" in vars(args) and (args.report_cer or args.report_wer):
            recog_args = {
                "beam_size": args.beam_size,
                "penalty": args.penalty,
                "ctc_weight": args.ctc_weight,
                "maxlenratio": args.maxlenratio,
                "minlenratio": args.minlenratio,
                "lm_weight": args.lm_weight,
                "rnnlm": args.rnnlm,
                "nbest": args.nbest,
                "space": args.sym_space,
                "blank": args.sym_blank,
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            # NOTE(review): self.recog_args is only assigned in the branch
            # above.
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
コード例 #10
0
    def __init__(self,
                 num_time_mask=2,
                 num_freq_mask=2,
                 freq_mask_length=15,
                 time_mask_length=15,
                 feature_dim=320,
                 model_size=512,
                 feed_forward_size=1024,
                 hidden_size=64,
                 dropout=0.1,
                 num_head=8,
                 num_encoder_layer=6,
                 num_decoder_layer=6,
                 vocab_path='testing_vocab.model',
                 max_feature_length=1024,
                 max_token_length=50,
                 enable_spec_augment=True,
                 share_weight=True,
                 smoothing=0.1,
                 restrict_left_length=20,
                 restrict_right_length=20,
                 mtlalpha=0.2,
                 report_wer=True):
        """Build the vocabulary, encoder, decoder, losses and reporting helpers.

        :param num_time_mask: forwarded to ``SpecAugment``
        :param num_freq_mask: forwarded to ``SpecAugment``
        :param freq_mask_length: forwarded to ``SpecAugment``
        :param time_mask_length: forwarded to ``SpecAugment``
        :param feature_dim: ``idim`` of the encoder
        :param model_size: ``attention_dim`` of encoder and decoder, input
            width of the two output linear layers, stored as ``self.adim``
        :param feed_forward_size: ``linear_units`` of encoder and decoder
        :param hidden_size: not used by this constructor
        :param dropout: dropout rate for encoder, decoder and CTC (the
            decoder source-attention dropout is fixed to 0)
        :param num_head: ``attention_heads`` of encoder and decoder
        :param num_encoder_layer: ``num_blocks`` of the encoder
        :param num_decoder_layer: ``num_blocks`` of the decoder
        :param vocab_path: path given to ``Vocab``; with ``.model`` replaced
            by ``.vocab`` it is also the token list file read when
            ``report_wer`` is set
        :param max_feature_length: ``max_sequence_length`` of ``SpecAugment``
        :param max_token_length: stored on the instance
        :param enable_spec_augment: when truthy a ``SpecAugment`` module is
            created as ``self.spec_augment``
        :param share_weight: not used by this constructor
        :param smoothing: smoothing value of the main label-smoothing loss
        :param restrict_left_length: stored on the instance
        :param restrict_right_length: stored on the instance
        :param mtlalpha: stored on the instance; a ``CTC`` module is created
            only when it is positive
        :param report_wer: when truthy the token list is loaded and an
            ``ErrorCalculator`` is created
        """
        super(Transformer, self).__init__()

        self.enable_spec_augment = enable_spec_augment
        self.max_token_length = max_token_length
        self.restrict_left_length = restrict_left_length
        self.restrict_right_length = restrict_right_length
        self.vocab = Vocab(vocab_path)
        self.sos = self.vocab.bos_id
        self.eos = self.vocab.eos_id
        self.adim = model_size
        self.odim = self.vocab.vocab_size
        self.ignore_id = self.vocab.pad_id

        if enable_spec_augment:
            self.spec_augment = SpecAugment(
                num_time_mask=num_time_mask,
                num_freq_mask=num_freq_mask,
                freq_mask_length=freq_mask_length,
                time_mask_length=time_mask_length,
                max_sequence_length=max_feature_length)

        self.encoder = Encoder(idim=feature_dim,
                               attention_dim=model_size,
                               attention_heads=num_head,
                               linear_units=feed_forward_size,
                               num_blocks=num_encoder_layer,
                               dropout_rate=dropout,
                               positional_dropout_rate=dropout,
                               attention_dropout_rate=dropout,
                               input_layer='linear',
                               padding_idx=self.vocab.pad_id)

        # The decoder is built without its own output layer; the two linear
        # layers below project its output to the vocabulary and to 4 classes.
        self.decoder = Decoder(odim=self.vocab.vocab_size,
                               attention_dim=model_size,
                               attention_heads=num_head,
                               linear_units=feed_forward_size,
                               num_blocks=num_decoder_layer,
                               dropout_rate=dropout,
                               positional_dropout_rate=dropout,
                               self_attention_dropout_rate=dropout,
                               src_attention_dropout_rate=0,
                               input_layer='embed',
                               use_output_layer=False)
        self.decoder_linear = t.nn.Linear(model_size,
                                          self.vocab.vocab_size,
                                          bias=True)
        self.decoder_switch_linear = t.nn.Linear(model_size, 4, bias=True)

        self.criterion = LabelSmoothingLoss(size=self.odim,
                                            smoothing=smoothing,
                                            padding_idx=self.vocab.pad_id,
                                            normalize_length=True)
        self.switch_criterion = LabelSmoothingLoss(
            size=4,
            smoothing=0,
            padding_idx=self.vocab.pad_id,
            normalize_length=True)
        self.mtlalpha = mtlalpha
        if mtlalpha > 0.0:
            self.ctc = CTC(self.odim,
                           eprojs=self.adim,
                           dropout_rate=dropout,
                           ctc_type='builtin',
                           reduce=False)
        else:
            self.ctc = None

        if report_wer:
            from espnet.nets.e2e_asr_common import ErrorCalculator

            def load_token_list(path=vocab_path.replace('.model', '.vocab')):
                # Read the first tab-separated field of every line. The file
                # is decoded explicitly as UTF-8 so that the result does not
                # depend on the locale's default encoding.
                with open(path, encoding='utf-8') as reader:
                    return [line.split('\t')[0] for line in reader]

            self.char_list = load_token_list()
            self.error_calculator = ErrorCalculator(
                char_list=self.char_list,
                sym_space=' ',
                sym_blank=self.vocab.blank_token,
                report_wer=True)
        else:
            self.error_calculator = None
        self.rnnlm = None
        self.reporter = Reporter()

        # NOTE(review): a second 4-class loss besides ``switch_criterion``;
        # this one uses padding_idx=0 and the default length normalisation.
        self.switch_loss = LabelSmoothingLoss(size=4,
                                              smoothing=0,
                                              padding_idx=0)
        print('initing')
        initialize(self, init_type='xavier_normal')
        print('inited')
コード例 #11
0
ファイル: e2e_asr_transformer.py プロジェクト: sas91/espnet
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options;
            ``transformer_attn_dropout_rate`` is overwritten in place with
            ``dropout_rate`` when it is ``None``. The attention-type and
            span options are optional and read with ``getattr`` defaults.
        :param int ignore_id: forwarded to ``LabelSmoothingLoss`` and stored
            on the instance
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        dropout = args.dropout_rate
        attn_dropout = args.transformer_attn_dropout_rate

        # Span options that the encoder and the decoder receive identically.
        span_options = {
            "span_init": getattr(args, 'span_init', None),
            "span_ratio": getattr(args, 'span_ratio', None),
            "ratio_adaptive": getattr(args, 'ratio_adaptive', None),
        }

        encoder_layer_type = args.transformer_encoder_selfattn_layer_type
        self.encoder = Encoder(
            idim=idim,
            selfattention_layer_type=encoder_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_encoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=dropout,
            positional_dropout_rate=dropout,
            attention_dropout_rate=attn_dropout,
            attention_type=getattr(
                args, 'transformer_enc_attn_type', 'self_attn'),
            max_attn_span=getattr(args, 'enc_max_attn_span', [None]),
            **span_options
        )

        decoder_layer_type = args.transformer_decoder_selfattn_layer_type
        self.decoder = Decoder(
            odim=odim,
            selfattention_layer_type=decoder_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_decoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=dropout,
            positional_dropout_rate=dropout,
            self_attention_dropout_rate=attn_dropout,
            src_attention_dropout_rate=attn_dropout,
            attention_type=getattr(
                args, 'transformer_dec_attn_type', 'self_attn'),
            max_attn_span=getattr(args, 'dec_max_attn_span', [None]),
            **span_options
        )

        # sos and eos share the last output index.
        self.sos = self.eos = odim - 1
        self.odim, self.ignore_id = odim, ignore_id
        self.subsample = get_subsample(args, mode="asr", arch="transformer")
        self.reporter = Reporter()

        self.criterion = LabelSmoothingLoss(
            odim,
            ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )

        # reset_parameters runs before the CTC module below is created.
        self.reset_parameters(args)
        self.adim, self.mtlalpha = args.adim, args.mtlalpha
        self.ctc = (
            CTC(odim, args.adim, args.dropout_rate,
                ctc_type=args.ctc_type, reduce=True)
            if args.mtlalpha > 0.0
            else None
        )

        # The error calculator exists only when CER or WER is requested.
        if not (args.report_cer or args.report_wer):
            self.error_calculator = None
        else:
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        self.rnnlm = None

        # Keep the attention/span configuration on the instance as well.
        self.attention_enc_type = getattr(
            args, 'transformer_enc_attn_type', 'self_attn')
        self.attention_dec_type = getattr(
            args, 'transformer_dec_attn_type', 'self_attn')
        self.span_loss_coef = getattr(args, 'span_loss_coef', None)
        self.ratio_adaptive = getattr(args, 'ratio_adaptive', None)
        self.sym_blank = args.sym_blank