Exemple #1
0
    def __init__(self, idim, odim, args):
        """Construct an E2E object.

        :param int idim: dimension of inputs (number of source tokens); index
            ``idim`` is reserved as the source padding index, so the source
            embedding table has ``idim + 1`` rows
        :param int odim: dimension of outputs; index ``odim - 1`` is shared by
            <sos> and <eos>
        :param Namespace args: argument Namespace containing options
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        self.etype = args.etype
        self.verbose = args.verbose
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.reporter = Reporter()

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1
        self.pad = odim

        # subsample info
        # +1 means input (+1) and layers outputs (args.elayer).
        # No subsampling is applied here, so every factor stays 1.
        # NOTE: the builtin ``int`` replaces the ``np.int`` alias, which NumPy
        # deprecated in 1.20 and removed in 1.24 (AttributeError there).
        subsample = np.ones(args.elayers + 1, dtype=int)
        logging.warning(
            'Subsampling is not performed for machine translation.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        self.subsample = subsample

        # label smoothing info (only when a type is set and the training
        # json exists on disk)
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim,
                                             args.lsm_type,
                                             transcript=args.train_json)
        else:
            labeldist = None

        # multilingual related
        self.replace_sos = args.replace_sos

        # encoder
        # NOTE: +1 means the padding index (padding_idx=idim)
        self.embed_src = torch.nn.Embedding(idim + 1,
                                            args.eunits,
                                            padding_idx=idim)
        self.dropout_emb_src = torch.nn.Dropout(p=args.dropout_rate)
        # the encoder consumes the eunits-dim embeddings, not raw inputs
        self.enc = encoder_for(args, args.eunits, self.subsample)
        # attention
        self.att = att_for(args)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        # weight initialization
        self.init_like_chainer()

        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
Exemple #2
0
    def __init__(self, idim, odim, args):
        """Construct an E2E object.

        Builds the ST encoder/attention/decoder and, depending on
        ``asr_weight``, ``mt_weight`` and ``mtlalpha``, the auxiliary ASR
        submodules (CTC and/or attention decoder) and the auxiliary MT
        submodules (embedding, dropout and encoder).

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)

        # fill missing arguments for compatibility
        args = fill_missing_args(args, self.add_arguments)

        self.asr_weight = args.asr_weight
        self.mt_weight = args.mt_weight
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.asr_weight < 1.0, "asr_weight should be [0.0, 1.0)"
        assert 0.0 <= self.mt_weight < 1.0, "mt_weight should be [0.0, 1.0)"
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1
        self.pad = 0
        # NOTE: we reserve index:0 for <pad> although this is reserved for a blank class
        # in ASR. However, blank labels are not used in MT.
        # To keep the vocabulary size,
        # we use index:0 for padding instead of adding one more class.

        # subsample info
        self.subsample = get_subsample(args, mode="st", arch="rnn")

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim,
                                             args.lsm_type,
                                             transcript=args.train_json)
        else:
            labeldist = None

        # multilingual related
        self.multilingual = getattr(args, "multilingual", False)
        self.replace_sos = getattr(args, "replace_sos", False)

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # attention (ST)
        self.att = att_for(args)
        # decoder (ST)
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        # submodule for ASR task
        self.ctc = None
        self.att_asr = None
        self.dec_asr = None
        if self.asr_weight > 0:
            if self.mtlalpha > 0.0:
                self.ctc = CTC(
                    odim,
                    args.eprojs,
                    args.dropout_rate,
                    ctc_type=args.ctc_type,
                    reduce=True,
                )
            if self.mtlalpha < 1.0:
                # attention (asr)
                self.att_asr = att_for(args)
                # decoder (asr)
                args_asr = copy.deepcopy(args)
                args_asr.atype = "location"  # TODO(hirofumi0810): make this option
                self.dec_asr = decoder_for(args_asr, odim, self.sos, self.eos,
                                           self.att_asr, labeldist)

        # submodule for MT task
        if self.mt_weight > 0:
            self.embed_mt = torch.nn.Embedding(odim,
                                               args.eunits,
                                               padding_idx=self.pad)
            self.dropout_mt = torch.nn.Dropout(p=args.dropout_rate)
            # NOTE: builtin ``int`` instead of the ``np.int`` alias, which
            # NumPy deprecated in 1.20 and removed in 1.24.
            self.enc_mt = encoder_for(args,
                                      args.eunits,
                                      subsample=np.ones(args.elayers + 1,
                                                        dtype=int))

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        # NOTE: the parentheses matter. ``and`` binds tighter than ``or``, so
        # without them ``report_wer`` alone would enable CER/WER reporting even
        # when asr_weight == 0 and no ASR submodules were built above.
        if self.asr_weight > 0 and (args.report_cer or args.report_wer):
            recog_args = {
                "beam_size": args.beam_size,
                "penalty": args.penalty,
                "ctc_weight": args.ctc_weight,
                "maxlenratio": args.maxlenratio,
                "minlenratio": args.minlenratio,
                "lm_weight": args.lm_weight,
                "rnnlm": args.rnnlm,
                "nbest": args.nbest,
                "space": args.sym_space,
                "blank": args.sym_blank,
                "tgt_lang": False,
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        if args.report_bleu:
            # BLEU is computed on the ST decoder, so CTC is not used here
            trans_args = {
                "beam_size": args.beam_size,
                "penalty": args.penalty,
                "ctc_weight": 0,
                "maxlenratio": args.maxlenratio,
                "minlenratio": args.minlenratio,
                "lm_weight": args.lm_weight,
                "rnnlm": args.rnnlm,
                "nbest": args.nbest,
                "space": args.sym_space,
                "blank": args.sym_blank,
                "tgt_lang": False,
            }

            self.trans_args = argparse.Namespace(**trans_args)
            self.report_bleu = args.report_bleu
        else:
            self.report_bleu = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
Exemple #3
0
    def __init__(self, idim, odim, args):
        """Build an RNN-based E2E machine-translation model.

        Creates a source-token embedding, dropout and an RNN encoder, plus an
        attention module and a decoder. The source embedding and/or the output
        layer can optionally be tied to the decoder embedding.

        :param int idim: source vocabulary size (index 0 is used as <pad>)
        :param int odim: target vocabulary size (index ``odim - 1`` is both
            <sos> and <eos>)
        :param Namespace args: argument Namespace containing options
        :raises ValueError: if an embedding-tying option is requested with
            incompatible sizes, or together with ``context_residual``
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)

        # Back-fill options that older configurations may not define.
        args = fill_missing_args(args, self.add_arguments)

        self.etype = args.etype
        self.verbose = args.verbose
        # char_list is written back onto args because self.build relies on it.
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # Special token IDs: <sos> and <eos> share the last vocabulary index.
        # Index 0 (the blank class in ASR, unused in MT) doubles as <pad> so
        # the vocabulary does not need an extra entry.
        self.sos = self.eos = odim - 1
        self.pad = 0

        self.subsample = get_subsample(args, mode="mt", arch="rnn")

        # Label smoothing needs both a type and an existing training json.
        labeldist = None
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim, args.lsm_type,
                                             transcript=args.train_json)

        # Multilingual options (read with getattr for older configs).
        self.multilingual = getattr(args, "multilingual", False)
        self.replace_sos = getattr(args, "replace_sos", False)

        # Source side: embedding -> dropout -> RNN encoder over eunits-dim
        # embeddings.
        self.embed = torch.nn.Embedding(idim, args.eunits,
                                        padding_idx=self.pad)
        self.dropout = torch.nn.Dropout(p=args.dropout_rate)
        self.enc = encoder_for(args, args.eunits, self.subsample)
        # Target side: attention followed by the decoder.
        self.att = att_for(args)
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        # Optionally share the source embedding with the decoder embedding.
        if args.tie_src_tgt_embedding:
            if idim != odim:
                raise ValueError(
                    "When using tie_src_tgt_embedding, idim and odim must be equal."
                )
            if args.eunits != args.dunits:
                raise ValueError(
                    "When using tie_src_tgt_embedding, eunits and dunits must be equal."
                )
            self.embed.weight = self.dec.embed.weight

        # Optionally share the decoder embedding with its output classifier.
        if args.tie_classifier:
            if args.context_residual:
                raise ValueError(
                    "When using tie_classifier, context_residual must be turned off."
                )
            self.dec.output.weight = self.dec.embed.weight

        self.init_like_fairseq()

        # Beam-search settings used when BLEU is reported during training.
        if not args.report_bleu:
            self.report_bleu = False
        else:
            self.trans_args = argparse.Namespace(
                beam_size=args.beam_size,
                penalty=args.penalty,
                ctc_weight=0,
                maxlenratio=args.maxlenratio,
                minlenratio=args.minlenratio,
                lm_weight=args.lm_weight,
                rnnlm=args.rnnlm,
                nbest=args.nbest,
                space=args.sym_space,
                blank=args.sym_blank,
                tgt_lang=False,
            )
            self.report_bleu = args.report_bleu
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
Exemple #4
0
    def __init__(self, idims, odim, args):
        """Build a multi-encoder E2E ASR model from parsed options.

        Args:
            idims (list): input feature dimension for each encoder.
            odim (int): size of the output vocabulary; its last index serves
                as both <sos> and <eos>.
            args (Namespace): parsed options.

        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.verbose = args.verbose
        # self.build expects args.char_list to exist, so make sure it does.
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()
        self.num_encs = args.num_encs
        self.share_ctc = args.share_ctc

        # <sos> and <eos> share the last vocabulary index.
        self.sos = self.eos = odim - 1

        # Per-encoder subsampling factors.
        self.subsample_list = get_subsample(args, mode="asr", arch="rnn_mulenc")

        # Label smoothing needs both a type and an existing training json.
        labeldist = None
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(
                odim, args.lsm_type, transcript=args.train_json)

        # getattr keeps configs that predate replace_sos loadable.
        self.replace_sos = getattr(args, "replace_sos", False)

        self.frontend = None

        # Network: encoders, CTC, attention (+ hierarchical attention), decoder.
        self.enc = encoder_for(args, idims, self.subsample_list)
        self.ctc = ctc_for(args, odim)
        self.att = att_for(args)
        han_att = att_for(args, han_mode=True)
        self.att.append(han_att)
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        if args.mtlalpha > 0 and self.num_encs > 1:
            # Scale each weight set by its sum, e.g.
            # ctc_loss = w_1 * ctc_1_loss + w_2 * ctc_2_loss + ... + w_N * ctc_N_loss
            train_w = args.weights_ctc_train
            dec_w = args.weights_ctc_dec
            self.weights_ctc_train = train_w / np.sum(train_w)
            self.weights_ctc_dec = dec_w / np.sum(dec_w)
            logging.info("ctc weights (training during training): " +
                         " ".join(map(str, self.weights_ctc_train)))
            logging.info("ctc weights (decoding during training): " +
                         " ".join(map(str, self.weights_ctc_dec)))
        else:
            self.weights_ctc_dec = [1.0]
            self.weights_ctc_train = [1.0]

        self.init_like_chainer()

        # Beam-search settings used when CER/WER is reported during training.
        if not (args.report_cer or args.report_wer):
            self.report_cer = False
            self.report_wer = False
        else:
            self.recog_args = argparse.Namespace(
                beam_size=args.beam_size,
                penalty=args.penalty,
                ctc_weight=args.ctc_weight,
                maxlenratio=args.maxlenratio,
                minlenratio=args.minlenratio,
                lm_weight=args.lm_weight,
                rnnlm=args.rnnlm,
                nbest=args.nbest,
                space=args.sym_space,
                blank=args.sym_blank,
                tgt_lang=False,
                ctc_weights_dec=self.weights_ctc_dec,
            )
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
Exemple #5
0
    def __init__(self, idim, odim, args):
        """Construct an E2E object.

        Builds the ST encoder/attention/decoder and, depending on
        ``asr_weight``, ``mt_weight`` and ``mtlalpha``, the auxiliary ASR
        submodules (CTC and/or attention decoder) and the auxiliary MT
        submodules (embedding, dropout and encoder).

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        self.asr_weight = getattr(args, "asr_weight", 0)
        self.mt_weight = getattr(args, "mt_weight", 0)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.asr_weight < 1.0, "asr_weight should be [0.0, 1.0)"
        assert 0.0 <= self.mt_weight < 1.0, "mt_weight should be [0.0, 1.0)"
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1
        self.pad = 0
        # NOTE: we reserve index:0 for <pad> although this is reserved for a blank class
        # in ASR. However, blank labels are not used in NMT. To keep the vocabulary size,
        # we use index:0 for padding instead of adding one more class.

        # subsample info
        # +1 means input (+1) and layers outputs (args.elayer)
        # NOTE: builtin ``int`` instead of the ``np.int`` alias, which NumPy
        # deprecated in 1.20 and removed in 1.24.
        subsample = np.ones(args.elayers + 1, dtype=int)
        if args.etype.endswith("p") and not args.etype.startswith("vgg"):
            # factors are given as an underscore-separated string, e.g. "1_2_2"
            ss = args.subsample.split("_")
            for j in range(min(args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.'
            )
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        self.subsample = subsample

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim,
                                             args.lsm_type,
                                             transcript=args.train_json)
        else:
            labeldist = None

        # multilingual E2E-ST related
        self.multilingual = getattr(args, "multilingual", False)
        self.joint_asr = getattr(args, "joint_asr", False)
        self.replace_sos = getattr(args, "replace_sos", False)

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # attention (ST)
        self.att = att_for(args)
        # decoder (ST)
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        # submodule for ASR task
        self.ctc = None
        self.att_asr = None
        self.dec_asr = None
        if self.asr_weight > 0:
            if self.mtlalpha > 0.0:
                self.ctc = CTC(odim,
                               args.eprojs,
                               args.dropout_rate,
                               ctc_type=args.ctc_type,
                               reduce=True)
            if self.mtlalpha < 1.0:
                # attention (asr)
                self.att_asr = att_for(args)
                # decoder (asr)
                args_asr = copy.deepcopy(args)
                args_asr.atype = 'location'  # TODO(hirofumi0810): make this option
                self.dec_asr = decoder_for(args_asr, odim, self.sos, self.eos,
                                           self.att_asr, labeldist)

        # submodule for MT task
        if self.mt_weight > 0:
            self.embed_mt = torch.nn.Embedding(odim,
                                               args.eunits,
                                               padding_idx=self.pad)
            self.dropout_mt = torch.nn.Dropout(p=args.dropout_rate)
            self.enc_mt = encoder_for(args,
                                      args.eunits,
                                      subsample=np.ones(args.elayers + 1,
                                                        dtype=int))

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        # NOTE: the parentheses matter. ``and`` binds tighter than ``or``, so
        # without them ``report_wer`` alone would enable CER/WER reporting even
        # when asr_weight == 0 and no ASR submodules were built above.
        if self.asr_weight > 0 and (args.report_cer or args.report_wer):
            recog_args = {
                'beam_size': args.beam_size,
                'penalty': args.penalty,
                'ctc_weight': args.ctc_weight,
                'maxlenratio': args.maxlenratio,
                'minlenratio': args.minlenratio,
                'lm_weight': args.lm_weight,
                'rnnlm': args.rnnlm,
                'nbest': args.nbest,
                'space': args.sym_space,
                'blank': args.sym_blank,
                'tgt_lang': False
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        if args.report_bleu:
            # BLEU is computed on the ST decoder, so CTC is not used here
            trans_args = {
                'beam_size': args.beam_size,
                'penalty': args.penalty,
                'ctc_weight': 0,
                'maxlenratio': args.maxlenratio,
                'minlenratio': args.minlenratio,
                'lm_weight': args.lm_weight,
                'rnnlm': args.rnnlm,
                'nbest': args.nbest,
                'space': args.sym_space,
                'blank': args.sym_blank,
                'tgt_lang': False
            }

            self.trans_args = argparse.Namespace(**trans_args)
            self.report_bleu = args.report_bleu
        else:
            self.report_bleu = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
Exemple #6
0
    def __init__(self, idim, odim, args):
        """Construct an E2E ASR object.

        Builds an optional frontend/feature transform, the encoder, CTC,
        attention and decoder, and the beam-search options used when CER/WER
        is reported during training.

        :param int idim: dimension of inputs (replaced by ``args.n_mels``
            when a frontend is used)
        :param int odim: dimension of outputs; index ``odim - 1`` is shared by
            <sos> and <eos>
        :param Namespace args: argument Namespace containing options
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        # +1 means input (+1) and layers outputs (args.elayer)
        # NOTE: builtin ``int`` instead of the ``np.int`` alias, which NumPy
        # deprecated in 1.20 and removed in 1.24.
        subsample = np.ones(args.elayers + 1, dtype=int)
        if args.etype.endswith("p") and not args.etype.startswith("vgg"):
            # factors are given as an underscore-separated string, e.g. "1_2_2"
            ss = args.subsample.split("_")
            for j in range(min(args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.'
            )
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        self.subsample = subsample

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim,
                                             args.lsm_type,
                                             transcript=args.train_json)
        else:
            labeldist = None

        # speech translation related
        self.replace_sos = getattr(args, "replace_sos",
                                   False)  # use getattr to keep compatibility

        if getattr(args, "use_frontend",
                   False):  # use getattr to keep compatibility
            # Relative importing because of using python3 syntax
            from espnet.nets.pytorch_backend.frontends.feature_transform \
                import feature_transform_for
            from espnet.nets.pytorch_backend.frontends.frontend \
                import frontend_for

            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(
                args, (idim - 1) * 2)
            # the encoder sees the transformed features, not the raw input
            idim = args.n_mels
        else:
            self.frontend = None

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # ctc
        self.ctc = ctc_for(args, odim)
        # attention
        self.att = att_for(args)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        if args.report_cer or args.report_wer:
            recog_args = {
                'beam_size': args.beam_size,
                'penalty': args.penalty,
                'ctc_weight': args.ctc_weight,
                'maxlenratio': args.maxlenratio,
                'minlenratio': args.minlenratio,
                'lm_weight': args.lm_weight,
                'rnnlm': args.rnnlm,
                'nbest': args.nbest,
                'space': args.sym_space,
                'blank': args.sym_blank,
                'tgt_lang': False,
                'sampling': args.sampling
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
        self.loss_nll = torch.nn.NLLLoss()
Exemple #7
0
    def __init__(self, idims, odim, args):
        """Initialize this class with python-level args.

        Builds one encoder per entry of ``idims``, CTC, an attention list
        extended with a hierarchical attention network, and the decoder.

        Args:
            idims (list): list of the number of an input feature dim.
            odim (int): The number of output vocab.
            args (Namespace): arguments

        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()
        self.num_encs = args.num_encs
        self.share_ctc = args.share_ctc

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info (one factor array per encoder)
        self.subsample_list = []
        for idx in range(self.num_encs):
            # +1 means input (+1) and layers outputs (args.elayer)
            # NOTE: builtin ``int`` instead of the ``np.int`` alias, which
            # NumPy deprecated in 1.20 and removed in 1.24.
            subsample = np.ones(args.elayers[idx] + 1, dtype=int)
            if args.etype[idx].endswith(
                    "p") and not args.etype[idx].startswith("vgg"):
                # factors are given as an underscore-separated string
                ss = args.subsample[idx].split("_")
                for j in range(min(args.elayers[idx] + 1, len(ss))):
                    subsample[j] = int(ss[j])
            else:
                logging.warning(
                    'Encoder {}: Subsampling is not performed for vgg*. '
                    'It is performed in max pooling layers at CNN.'.format(
                        idx + 1))
            logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
            self.subsample_list.append(subsample)

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim,
                                             args.lsm_type,
                                             transcript=args.train_json)
        else:
            labeldist = None

        # speech translation related
        self.replace_sos = getattr(args, "replace_sos",
                                   False)  # use getattr to keep compatibility

        self.frontend = None

        # encoder
        self.enc = encoder_for(args, idims, self.subsample_list)
        # ctc
        self.ctc = ctc_for(args, odim)
        # attention
        self.att = att_for(args)
        # hierarchical attention network
        han = att_for(args, han_mode=True)
        self.att.append(han)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        if args.mtlalpha > 0 and self.num_encs > 1:
            # weights-ctc, e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
            self.weights_ctc_train = args.weights_ctc_train / np.sum(
                args.weights_ctc_train)  # normalize
            self.weights_ctc_dec = args.weights_ctc_dec / np.sum(
                args.weights_ctc_dec)  # normalize
            logging.info('ctc weights (training during training): ' +
                         ' '.join([str(x) for x in self.weights_ctc_train]))
            logging.info('ctc weights (decoding during training): ' +
                         ' '.join([str(x) for x in self.weights_ctc_dec]))
        else:
            self.weights_ctc_dec = [1.0]
            self.weights_ctc_train = [1.0]

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        if args.report_cer or args.report_wer:
            recog_args = {
                'beam_size': args.beam_size,
                'penalty': args.penalty,
                'ctc_weight': args.ctc_weight,
                'maxlenratio': args.maxlenratio,
                'minlenratio': args.minlenratio,
                'lm_weight': args.lm_weight,
                'rnnlm': args.rnnlm,
                'nbest': args.nbest,
                'space': args.sym_space,
                'blank': args.sym_blank,
                'tgt_lang': False,
                'ctc_weights_dec': self.weights_ctc_dec
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
Exemple #8
0
    def __init__(self, idim, odim, args):
        """Construct a multi-speaker E2E ASR object.

        Builds a shared encoder, a non-reducing CTC, one attention module per
        speaker when ``args.spa`` is set (otherwise a single one), and the
        decoder. A PIT helper is created for ``args.num_spkrs`` speakers.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs; index ``odim - 1`` is shared by
            <sos> and <eos>
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.reporter = Reporter()
        self.num_spkrs = args.num_spkrs
        self.spa = args.spa
        self.pit = PIT(self.num_spkrs)

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        # +1 means input (+1) and layers outputs (args.elayer_sd + args.elayers)
        # NOTE: builtin ``int`` instead of the ``np.int`` alias, which NumPy
        # deprecated in 1.20 and removed in 1.24.
        subsample = np.ones(args.elayers_sd + args.elayers + 1, dtype=int)
        if args.etype.endswith("p") and not args.etype.startswith("vgg"):
            # factors are given as an underscore-separated string, e.g. "1_2_2"
            ss = args.subsample.split("_")
            for j in range(min(args.elayers_sd + args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.'
            )
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        self.subsample = subsample

        # label smoothing info
        # NOTE: also require the training json to exist, so that rebuilding
        # the model without access to it does not crash inside
        # label_smoothing_dist.
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim,
                                             args.lsm_type,
                                             transcript=args.train_json)
        else:
            labeldist = None

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # ctc
        self.ctc = ctc_for(args, odim, reduce=False)
        # attention (one per speaker when speaker-parallel attention is on)
        num_att = self.num_spkrs if args.spa else 1
        self.att = att_for(args, num_att)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        # ('report_cer' may be absent from args built by older configs)
        if 'report_cer' in vars(args) and (args.report_cer or args.report_wer):
            recog_args = {
                'beam_size': args.beam_size,
                'penalty': args.penalty,
                'ctc_weight': args.ctc_weight,
                'maxlenratio': args.maxlenratio,
                'minlenratio': args.minlenratio,
                'lm_weight': args.lm_weight,
                'rnnlm': args.rnnlm,
                'nbest': args.nbest,
                'space': args.sym_space,
                'blank': args.sym_blank
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
    def __init__(self, idim, odim, args):
        """Initialize multi-speaker E2E module.

        :param int idim: dimension of the input features (replaced by
            ``args.n_mels`` when a frontend is used)
        :param int odim: dimension of the outputs; index ``odim - 1`` is used
            as both <sos> and <eos>
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.reporter = Reporter()
        self.num_spkrs = args.num_spkrs
        self.spa = args.spa
        self.pit = PIT(self.num_spkrs)

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        # +1 means input (+1) and layers outputs (args.elayers_sd + args.elayers)
        # NOTE: builtin ``int`` dtype; the ``np.int`` alias was deprecated in
        # NumPy 1.20 and removed in NumPy 1.24.
        n_subsample = args.elayers_sd + args.elayers + 1
        subsample = np.ones(n_subsample, dtype=int)
        if args.etype.endswith("p") and not args.etype.startswith(
                "vgg") and not args.etype.startswith("sinc"):
            # "_"-separated per-layer factors; extra entries are ignored
            ss = args.subsample.split("_")
            for j in range(min(n_subsample, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg* and sinc*. It is performed in max pooling layers at CNN (Not performed at all for SincNet). '
            )
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        self.subsample = subsample

        # label smoothing info (distribution derived from args.train_json)
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim,
                                             args.lsm_type,
                                             transcript=args.train_json)
        else:
            labeldist = None

        if getattr(args, "use_frontend",
                   False):  # use getattr to keep compatibility
            # Relative importing because of using python3 syntax
            from espnet.nets.pytorch_backend.frontends.feature_transform \
                import feature_transform_for
            from espnet.nets.pytorch_backend.frontends.frontend \
                import frontend_for

            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(
                args, (idim - 1) * 2)
            # the encoder consumes the frontend's n_mels-dim features
            idim = args.n_mels
        else:
            self.frontend = None

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # ctc (reduce=False keeps per-utterance losses)
        self.ctc = ctc_for(args, odim, reduce=False)
        # attention: one module per speaker when args.spa is set, else shared
        num_att = self.num_spkrs if args.spa else 1
        self.att = att_for(args, num_att)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        # weight initialization
        self.init_like_chainer()

        # options for beam search (only read when CER/WER reporting is on)
        if 'report_cer' in vars(args) and (args.report_cer or args.report_wer):
            recog_args = {
                'beam_size': args.beam_size,
                'penalty': args.penalty,
                'ctc_weight': args.ctc_weight,
                'maxlenratio': args.maxlenratio,
                'minlenratio': args.minlenratio,
                'lm_weight': args.lm_weight,
                'rnnlm': args.rnnlm,
                'nbest': args.nbest,
                'space': args.sym_space,
                'blank': args.sym_blank
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
Exemple #10
0
    def __init__(self, idim, odim, args):
        """Initialize multi-speaker E2E module.

        Builds, in order: an optional frontend, the shared encoder, a CTC
        module with unreduced loss, the attention module(s) (one per speaker
        when ``args.spa`` is set, a single one otherwise) and the decoder.

        :param int idim: input feature dimension (replaced by ``args.n_mels``
            when a frontend is used)
        :param int odim: output dimension; index ``odim - 1`` doubles as
            <sos> and <eos>
        :param Namespace args: model/training options
        """
        torch.nn.Module.__init__(self)

        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"

        # plain options mirrored from args
        self.etype = args.etype
        self.verbose = args.verbose
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.reporter = Reporter()
        self.num_spkrs = args.num_spkrs
        self.spa = args.spa
        self.pit = PIT(self.num_spkrs)

        # the last vocabulary index serves as both start and end symbol
        self.sos = self.eos = odim - 1

        self.subsample = get_subsample(args, mode='asr', arch='rnn_mix')

        # label smoothing needs the training transcript file to exist
        label_dist = None
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            label_dist = label_smoothing_dist(
                odim, args.lsm_type, transcript=args.train_json)

        use_frontend = getattr(args, "use_frontend", False)  # compat
        if not use_frontend:
            self.frontend = None
        else:
            # imported lazily: only needed when a frontend is configured
            from espnet.nets.pytorch_backend.frontends.feature_transform \
                import feature_transform_for
            from espnet.nets.pytorch_backend.frontends.frontend \
                import frontend_for

            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(
                args, (idim - 1) * 2)
            idim = args.n_mels

        # NOTE: keep this construction order; it fixes parameter order and
        # the RNG draws made by each submodule's default initialization.
        self.enc = encoder_for(args, idim, self.subsample)
        self.ctc = ctc_for(args, odim, reduce=False)
        n_attention = self.num_spkrs if args.spa else 1
        self.att = att_for(args, n_attention)
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               label_dist)

        self.init_like_chainer()

        # beam-search settings are only read when CER/WER reporting is on
        wants_report = ('report_cer' in vars(args)
                        and (args.report_cer or args.report_wer))
        if not wants_report:
            self.report_cer = False
            self.report_wer = False
        else:
            self.recog_args = argparse.Namespace(
                beam_size=args.beam_size,
                penalty=args.penalty,
                ctc_weight=args.ctc_weight,
                maxlenratio=args.maxlenratio,
                minlenratio=args.minlenratio,
                lm_weight=args.lm_weight,
                rnnlm=args.rnnlm,
                nbest=args.nbest,
                space=args.sym_space,
                blank=args.sym_blank,
            )
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer

        self.rnnlm = None
        self.logzero = -1e10
        self.loss = None
        self.acc = None
Exemple #11
0
    def __init__(self, idim, odim, args):
        """Assemble an attention encoder-decoder model with a CTC branch.

        :param int idim: input feature dimension (replaced by ``args.n_mels``
            when a frontend is used)
        :param int odim: output dimension; index ``odim - 1`` doubles as
            <sos> and <eos>
        :param Namespace args: argument Namespace containing options
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)

        # complete args with options it lacks (compatibility)
        args = fill_missing_args(args, self.add_arguments)

        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        # self.build may create the model without a char_list attribute
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # one shared id (the last vocabulary entry) for <sos> and <eos>
        self.sos = self.eos = odim - 1

        self.subsample = get_subsample(args, mode="asr", arch="rnn")

        # label smoothing is applied only if the transcript file is present
        label_dist = None
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            label_dist = label_smoothing_dist(
                odim, args.lsm_type, transcript=args.train_json)

        use_frontend = getattr(args, "use_frontend", False)  # compat
        if not use_frontend:
            self.frontend = None
        else:
            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(
                args, (idim - 1) * 2)
            idim = args.n_mels

        # NOTE: keep this construction order; it fixes parameter order and
        # the RNG draws made by each submodule's default initialization.
        self.enc = encoder_for(args, idim, self.subsample)
        self.ctc = ctc_for(args, odim)
        self.att = att_for(args)
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               label_dist)

        self.init_like_chainer()

        # beam-search settings are only read when CER/WER reporting is on
        if not (args.report_cer or args.report_wer):
            self.report_cer = False
            self.report_wer = False
        else:
            self.recog_args = argparse.Namespace(
                beam_size=args.beam_size,
                penalty=args.penalty,
                ctc_weight=args.ctc_weight,
                maxlenratio=args.maxlenratio,
                minlenratio=args.minlenratio,
                lm_weight=args.lm_weight,
                rnnlm=args.rnnlm,
                nbest=args.nbest,
                space=args.sym_space,
                blank=args.sym_blank,
            )
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer

        self.rnnlm = None
        self.logzero = -1e10
        self.loss = None
        self.acc = None
Exemple #12
0
    def __init__(self, idim, odim, args):
        """Construct an E2E object.

        :param int idim: dimension of inputs (size of the source embedding
            table)
        :param int odim: dimension of outputs; index ``odim - 1`` is used as
            both <sos> and <eos>
        :param Namespace args: argument Namespace containing options
        :raises ValueError: if ``tie_src_tgt_embedding`` is requested with
            ``idim != odim`` or ``eunits != dunits``, or ``tie_classifier`` is
            combined with ``context_residual``
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1
        self.pad = 0
        # NOTE: we reserve index:0 for <pad> although this is reserved for a blank class
        # in ASR. However, blank labels are not used in NMT. To keep the vocabulary size,
        # we use index:0 for padding instead of adding one more class.

        # subsample info
        # +1 means input (+1) and layers outputs (args.elayer)
        # NOTE: builtin ``int`` dtype; the ``np.int`` alias was deprecated in
        # NumPy 1.20 and removed in NumPy 1.24.
        subsample = np.ones(args.elayers + 1, dtype=int)
        logging.warning(
            'Subsampling is not performed for machine translation.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        self.subsample = subsample

        # label smoothing info (distribution derived from args.train_json)
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim,
                                             args.lsm_type,
                                             transcript=args.train_json)
        else:
            labeldist = None

        # multilingual related
        self.multilingual = getattr(args, "multilingual", False)
        self.replace_sos = getattr(args, "replace_sos", False)

        # encoder: source tokens are embedded to eunits before the encoder
        self.embed = torch.nn.Embedding(idim,
                                        args.eunits,
                                        padding_idx=self.pad)
        self.dropout = torch.nn.Dropout(p=args.dropout_rate)
        self.enc = encoder_for(args, args.eunits, self.subsample)
        # attention
        self.att = att_for(args)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        # tie source and target embeddings
        if args.tie_src_tgt_embedding:
            if idim != odim:
                raise ValueError(
                    'When using tie_src_tgt_embedding, idim and odim must be equal.'
                )
            if args.eunits != args.dunits:
                raise ValueError(
                    'When using tie_src_tgt_embedding, eunits and dunits must be equal.'
                )
            self.embed.weight = self.dec.embed.weight

        # tie embeddings and the classifier
        if args.tie_classifier:
            if args.context_residual:
                raise ValueError(
                    'When using tie_classifier, context_residual must be turned off.'
                )
            self.dec.output.weight = self.dec.embed.weight

        # weight initialization (runs after tying, so tied weights are
        # initialized once through the shared Parameter)
        self.init_like_fairseq()

        # options for beam search (only read when BLEU reporting is on)
        if args.report_bleu:
            trans_args = {
                'beam_size': args.beam_size,
                'penalty': args.penalty,
                'ctc_weight': 0,
                'maxlenratio': args.maxlenratio,
                'minlenratio': args.minlenratio,
                'lm_weight': args.lm_weight,
                'rnnlm': args.rnnlm,
                'nbest': args.nbest,
                'space': args.sym_space,
                'blank': args.sym_blank,
                'tgt_lang': False
            }

            self.trans_args = argparse.Namespace(**trans_args)
            self.report_bleu = args.report_bleu
        else:
            self.report_bleu = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
Exemple #13
0
    def __init__(self, idim, odim, args, asr_model=None, mt_model=None):
        """Construct a speech translation E2E object.

        :param int idim: dimension of inputs (replaced by ``args.n_mels`` when
            a frontend is used)
        :param int odim: dimension of outputs; index ``odim - 1`` is used as
            both <sos> and <eos>
        :param Namespace args: argument Namespace containing options
        :param asr_model: optional pre-trained model; its parameters whose
            names contain ``enc.enc`` and whose shapes match overwrite ours
        :param mt_model: optional pre-trained model; its parameters whose
            names contain ``dec.`` or ``att`` and whose shapes match
            overwrite ours
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        # +1 means input (+1) and layers outputs (args.elayer)
        # NOTE: builtin ``int`` dtype; the ``np.int`` alias was deprecated in
        # NumPy 1.20 and removed in NumPy 1.24.
        subsample = np.ones(args.elayers + 1, dtype=int)
        if args.etype.endswith("p") and not args.etype.startswith("vgg"):
            # "_"-separated per-layer factors; extra entries are ignored
            ss = args.subsample.split("_")
            for j in range(min(args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        self.subsample = subsample

        # label smoothing info (distribution derived from args.train_json)
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim, args.lsm_type, transcript=args.train_json)
        else:
            labeldist = None

        # speech translation related
        self.replace_sos = args.replace_sos

        if getattr(args, "use_frontend", False):  # use getattr to keep compatibility
            # Relative importing because of using python3 syntax
            from espnet.nets.pytorch_backend.frontends.feature_transform \
                import feature_transform_for
            from espnet.nets.pytorch_backend.frontends.frontend \
                import frontend_for

            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
            # the encoder consumes the frontend's n_mels-dim features
            idim = args.n_mels
        else:
            self.frontend = None

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # ctc
        self.ctc = ctc_for(args, odim)
        # attention
        self.att = att_for(args)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)

        # weight initialization
        self.init_like_chainer()

        # pre-training w/ ASR encoder and NMT decoder
        # NOTE: ``p.data = ...`` makes our parameter use the pre-trained
        # tensor itself (storage is shared, not copied).
        if asr_model is not None:
            param_dict = dict(asr_model.named_parameters())
            for n, p in self.named_parameters():
                # overwrite the encoder
                if n in param_dict and p.size() == param_dict[n].size():
                    if 'enc.enc' in n:
                        p.data = param_dict[n].data
                        logging.warning('Overwrite %s', n)
        if mt_model is not None:
            param_dict = dict(mt_model.named_parameters())
            for n, p in self.named_parameters():
                # overwrite the decoder
                if n in param_dict and p.size() == param_dict[n].size():
                    if 'dec.' in n or 'att' in n:
                        p.data = param_dict[n].data
                        logging.warning('Overwrite %s', n)

        # options for beam search (only read when CER/WER reporting is on)
        if args.report_cer or args.report_wer:
            recog_args = {'beam_size': args.beam_size, 'penalty': args.penalty,
                          'ctc_weight': args.ctc_weight, 'maxlenratio': args.maxlenratio,
                          'minlenratio': args.minlenratio, 'lm_weight': args.lm_weight,
                          'rnnlm': args.rnnlm, 'nbest': args.nbest,
                          'space': args.sym_space, 'blank': args.sym_blank,
                          'tgt_lang': False}

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
Exemple #14
0
    def __init__(self, idim, odim, args):
        """Construct an E2E object.

        :param int idim: dimension of inputs (replaced by ``args.n_mels`` when
            a frontend is used)
        :param int odim: dimension of outputs; index ``odim - 1`` is used as
            both <sos> and <eos>
        :param Namespace args: argument Namespace containing options
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)

        # fill missing arguments for compatibility
        args = fill_missing_args(args, self.add_arguments)

        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method (char_list may be absent -> None)
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # gs534 - word vocab
        # NOTE: heuristic BPE flag: an output list of more than 100 entries is
        # treated as BPE. char_list may be None (see above); treat that as
        # non-BPE instead of failing on len(None).
        bpe = self.char_list is not None and len(self.char_list) > 100
        # an empty/None dictfile means no word vocabulary
        self.vocabulary = Vocabulary(args.dictfile, bpe) if args.dictfile else None

        # gs534 - create lexicon tree
        lextree = None
        self.meeting_KB = None
        self.n_KBs = getattr(args, 'dynamicKBs', 0)
        pretrain_emb = []
        if args.meetingKB and args.meetingpath:
            if self.n_KBs == 0 or not os.path.isdir(os.path.join(args.meetingpath, 'split_0')):
                self.meeting_KB = KBmeeting(self.vocabulary, args.meetingpath, args.char_list, bpe)
            else:
                # arrange multiple KBs, one per split_<i> sub-directory
                self.meeting_KB = []
                for i in range(self.n_KBs):
                    self.meeting_KB.append(KBmeeting(
                        self.vocabulary,
                        os.path.join(args.meetingpath, 'split_{}'.format(i)),
                        args.char_list, bpe))

        # subsample info
        self.subsample = get_subsample(args, mode="asr", arch="rnn")

        # label smoothing info (distribution derived from args.train_json)
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(
                odim, args.lsm_type, transcript=args.train_json
            )
        else:
            labeldist = None

        if getattr(args, "use_frontend", False):  # use getattr to keep compatibility
            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
            idim = args.n_mels
        else:
            self.frontend = None

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # ctc
        self.ctc = ctc_for(args, odim)
        # attention
        self.att = att_for(args)
        # decoder: with several KBs, it is constructed with the first one
        dec_kb = (self.meeting_KB[0] if isinstance(self.meeting_KB, list)
                  else self.meeting_KB)
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist,
                               meetingKB=dec_kb)

        # weight initialization
        self.init_from = getattr(args, 'init_full_model', None)
        self.init_like_chainer()

        # options for beam search (only read when CER/WER reporting is on)
        if args.report_cer or args.report_wer:
            recog_args = {
                "beam_size": args.beam_size,
                "penalty": args.penalty,
                "ctc_weight": args.ctc_weight,
                "maxlenratio": args.maxlenratio,
                "minlenratio": args.minlenratio,
                "lm_weight": args.lm_weight,
                "rnnlm": args.rnnlm,
                "nbest": args.nbest,
                "space": args.sym_space,
                "blank": args.sym_blank,
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
Exemple #15
0
    def __init__(self, idim, odim, args, asr_model=None, mt_model=None):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs; index ``odim - 1`` is used as
            both <sos> and <eos>
        :param Namespace args: argument Namespace containing options
        :param asr_model: accepted for interface compatibility; not used here
        :param mt_model: accepted for interface compatibility; not used here
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        self.asr_weight = getattr(args, "asr_weight", 0)
        assert 0.0 <= self.asr_weight < 1.0, "asr_weight should be [0.0, 1.0)"
        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        # +1 means input (+1) and layers outputs (args.elayer)
        # NOTE: builtin ``int`` dtype; the ``np.int`` alias was deprecated in
        # NumPy 1.20 and removed in NumPy 1.24.
        subsample = np.ones(args.elayers + 1, dtype=int)
        if args.etype.endswith("p") and not args.etype.startswith("vgg"):
            # "_"-separated per-layer factors; extra entries are ignored
            ss = args.subsample.split("_")
            for j in range(min(args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.'
            )
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        self.subsample = subsample

        # label smoothing info (distribution derived from args.train_json)
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim,
                                             args.lsm_type,
                                             transcript=args.train_json)
        else:
            labeldist = None

        # multilingual related
        self.multilingual = getattr(args, "multilingual", False)
        self.replace_sos = args.replace_sos

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        if self.asr_weight > 0:
            # attention (asr)
            self.att_asr = att_for(args)
            # decoder (asr): built from a copy of args with the ASR overrides.
            # NOTE(review): att_asr above is built from the unmodified args,
            # so the 'location' override presumably does not change the ASR
            # attention type; confirm whether att_for(args_asr) was intended.
            args_asr = copy.deepcopy(args)
            args_asr.atype = 'location'  # TODO(hirofumi0810): make this option
            self.dec_asr = decoder_for(args_asr, odim, self.sos, self.eos,
                                       self.att_asr, labeldist)
        # attention (st)
        self.att = att_for(args)
        # decoder (st)
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        # weight initialization
        self.init_like_chainer()

        def _beam_search_args():
            # Built lazily so decoding options are only read when a report
            # that needs beam search is enabled.
            return argparse.Namespace(**{
                'beam_size': args.beam_size,
                'penalty': args.penalty,
                'ctc_weight': 0,
                'maxlenratio': args.maxlenratio,
                'minlenratio': args.minlenratio,
                'lm_weight': args.lm_weight,
                'rnnlm': args.rnnlm,
                'nbest': args.nbest,
                'space': args.sym_space,
                'blank': args.sym_blank,
                'tgt_lang': False
            })

        # options for beam search (CER/WER reporting)
        if args.report_cer or args.report_wer:
            self.recog_args = _beam_search_args()
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        # options for beam search (BLEU reporting); same settings as above
        if args.report_bleu:
            self.trans_args = _beam_search_args()
            self.report_bleu = args.report_bleu
        else:
            self.report_bleu = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
    def __init__(self, idim, odim, args):
        """Initialize an E2E model with a chunk-based Transformer encoder.

        :param int idim: dimension of the input features (replaced by
            ``args.n_mels`` when a frontend is used)
        :param int odim: dimension of the outputs; index ``odim - 1`` is used
            as both <sos> and <eos>
        :param Namespace args: argument Namespace containing options
        """
        import os  # needed for the train_json existence check below

        torch.nn.Module.__init__(self)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        # +1 means input (+1) and layers outputs (args.elayer)
        # NOTE: builtin ``int`` dtype; the ``np.int`` alias was deprecated in
        # NumPy 1.20 and removed in NumPy 1.24. The Encoder below is not given
        # this array; it is kept on self.subsample for other users.
        subsample = np.ones(args.elayers + 1, dtype=int)
        if args.etype.endswith("p") and not args.etype.startswith("vgg"):
            # "_"-separated per-layer factors; extra entries are ignored
            ss = args.subsample.split("_")
            for j in range(min(args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.'
            )
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        self.subsample = subsample

        # label smoothing info: the label distribution is derived from
        # args.train_json, so smoothing is only enabled when that file exists
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim,
                                             args.lsm_type,
                                             transcript=args.train_json)
        else:
            if args.lsm_type:
                logging.warning(
                    "Label smoothing (%s) is disabled: train_json %s not found.",
                    args.lsm_type, args.train_json)
            labeldist = None

        if getattr(args, "use_frontend", False):  # use getattr to keep compatibility
            # Relative importing because of using python3 syntax
            from espnet.nets.pytorch_backend.frontends.feature_transform \
                import feature_transform_for
            from espnet.nets.pytorch_backend.frontends.frontend \
                import frontend_for

            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(
                args, (idim - 1) * 2)
            # the encoder consumes the frontend's n_mels-dim features
            idim = args.n_mels
        else:
            self.frontend = None

        # encoder: Transformer encoder configured with center/left/right/hop
        # chunk lengths (replaces the RNN encoder_for used elsewhere)
        self.encoder = Encoder(
            idim=idim,
            center_len=args.transformer_encoder_center_chunk_len,
            left_len=args.transformer_encoder_left_chunk_len,
            hop_len=args.transformer_encoder_hop_len,
            right_len=args.transformer_encoder_right_chunk_len,
            abs_pos=args.transformer_encoder_abs_embed,
            rel_pos=args.transformer_encoder_rel_embed,
            use_mem=args.transformer_encoder_use_memory,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate)

        # ctc
        self.ctc = ctc_for(args, odim)
        # attention
        self.att = att_for(args)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        # weight initialization
        self.init_like_chainer()

        # options for beam search (only read when CER/WER reporting is on)
        if args.report_cer or args.report_wer:
            recog_args = {
                'beam_size': args.beam_size,
                'penalty': args.penalty,
                'ctc_weight': args.ctc_weight,
                'maxlenratio': args.maxlenratio,
                'minlenratio': args.minlenratio,
                'lm_weight': args.lm_weight,
                'rnnlm': args.rnnlm,
                'nbest': args.nbest,
                'space': args.sym_space,
                'blank': args.sym_blank
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None