Exemplo n.º 1
0
 def __init__(self, idim, odim, args, ignore_id=-1, flag_return=True):
     """Initialize the transformer E2E model.

     Args:
         idim (int): Input dimension, forwarded to ``Encoder``.
         odim (int): Output dimension. Index ``odim - 1`` is used as both
             the start-of-sequence and the end-of-sequence id.
         args (Namespace): Model configuration. Attributes read here include
             ``mtlalpha``, ``dropout_rate``, ``transformer_attn_dropout_rate``,
             ``char_list``, ``sym_space``, ``sym_blank``, ``adim``,
             ``lsm_weight``, ``transformer_length_normalized_loss``,
             ``ctc_type``, ``report_cer``, ``report_wer`` and, optionally,
             ``verbose``. The whole object is also handed to
             ``reset_parameters``, ``Encoder`` and ``Decoder``.
         ignore_id (int): Stored as ``self.ignore_id``.
         flag_return (bool): Stored as ``self.flag_return``.

     Raises:
         AssertionError: If ``args.mtlalpha`` is outside ``[0, 1]``.
         ValueError: If CTC is enabled (``mtlalpha > 0``) and
             ``args.ctc_type`` is neither ``"builtin"`` nor ``"warpctc"``.

     """
     chainer.Chain.__init__(self)
     self.mtlalpha = args.mtlalpha
     assert 0 <= self.mtlalpha <= 1, "mtlalpha must be [0,1]"
     # Use the attention-specific dropout rate when configured, otherwise
     # fall back to the generic dropout rate.
     if args.transformer_attn_dropout_rate is None:
         self.dropout = args.dropout_rate
     else:
         self.dropout = args.transformer_attn_dropout_rate
     self.use_label_smoothing = False
     self.char_list = args.char_list
     self.space = args.sym_space
     self.blank = args.sym_blank
     self.scale_emb = args.adim**0.5
     # sos and eos deliberately share the last output index.
     self.sos = odim - 1
     self.eos = odim - 1
     self.subsample = [0]
     self.ignore_id = ignore_id
     # Must run before init_scope: it provides self.initialW / self.initialB
     # used by the encoder and decoder below.
     self.reset_parameters(args)
     with self.init_scope():
         self.encoder = Encoder(idim,
                                args,
                                initialW=self.initialW,
                                initial_bias=self.initialB)
         self.decoder = Decoder(odim,
                                args,
                                initialW=self.initialW,
                                initial_bias=self.initialB)
         self.criterion = LabelSmoothingLoss(
             args.lsm_weight, len(args.char_list),
             args.transformer_length_normalized_loss)
         # The CTC branch is only built when it contributes to the loss.
         if args.mtlalpha > 0.0:
             if args.ctc_type == 'builtin':
                 logging.info("Using chainer CTC implementation")
                 self.ctc = ctc.CTC(odim, args.adim, args.dropout_rate)
             elif args.ctc_type == 'warpctc':
                 logging.info("Using warpctc CTC implementation")
                 self.ctc = ctc.WarpCTC(odim, args.adim, args.dropout_rate)
             else:
                 raise ValueError(
                     'ctc_type must be "builtin" or "warpctc": {}'.format(
                         args.ctc_type))
         else:
             self.ctc = None
     self.dims = args.adim
     self.odim = odim
     self.flag_return = flag_return
     if args.report_cer or args.report_wer:
         from espnet.nets.e2e_asr_common import ErrorCalculator
         self.error_calculator = ErrorCalculator(args.char_list,
                                                 args.sym_space,
                                                 args.sym_blank,
                                                 args.report_cer,
                                                 args.report_wer)
     else:
         self.error_calculator = None
     # ``verbose`` may be absent from ``args`` or explicitly None; both mean
     # "not verbose". getattr handles argparse.Namespace and any other
     # attribute container uniformly, instead of string-matching the type
     # name. It also never leaves self.verbose as None.
     verbose = getattr(args, "verbose", None)
     self.verbose = 0 if verbose is None else verbose
 def __init__(self, idim, odim, args, ignore_id=-1, flag_return=True):
     """Initialize the transformer.

     Args:
         idim (int): Input dimension, forwarded to ``Encoder`` as ``idim``.
         odim (int): Output dimension. Index ``odim - 1`` is used as both
             the start-of-sequence and the end-of-sequence id.
         args (Namespace): Model configuration. Attributes read here include
             ``mtlalpha``, ``dropout_rate``, ``transformer_attn_dropout_rate``,
             ``char_list``, ``sym_space``, ``sym_blank``, ``adim``,
             ``aheads``, ``eunits``, ``transformer_input_layer``,
             ``lsm_weight``, ``transformer_length_normalized_loss``,
             ``ctc_type``, ``report_cer``, ``report_wer`` and, optionally,
             ``verbose``. The object is also passed to ``get_subsample``,
             ``reset_parameters`` and ``Decoder``.
         ignore_id (int): Stored as ``self.ignore_id``.
         flag_return (bool): Stored as ``self.flag_return``.

     Raises:
         AssertionError: If ``args.mtlalpha`` is outside ``[0, 1]``.
         ValueError: If CTC is enabled (``mtlalpha > 0``) and
             ``args.ctc_type`` is neither ``"builtin"`` nor ``"warpctc"``.

     Note:
         If ``args.transformer_attn_dropout_rate`` is None, it is overwritten
         in place on the caller's ``args`` with ``args.dropout_rate``.

     """
     chainer.Chain.__init__(self)
     self.mtlalpha = args.mtlalpha
     assert 0 <= self.mtlalpha <= 1, "mtlalpha must be [0,1]"
     # NOTE(review): this mutates the caller's args. It is kept because
     # Decoder receives ``args`` and presumably reads this field too.
     # Confirm before removing.
     if args.transformer_attn_dropout_rate is None:
         args.transformer_attn_dropout_rate = args.dropout_rate
     self.use_label_smoothing = False
     self.char_list = args.char_list
     self.space = args.sym_space
     self.blank = args.sym_blank
     self.scale_emb = args.adim**0.5
     # sos and eos deliberately share the last output index.
     self.sos = odim - 1
     self.eos = odim - 1
     self.subsample = get_subsample(args, mode="asr", arch="transformer")
     self.ignore_id = ignore_id
     # Must run before init_scope: it provides self.initialW / self.initialB
     # used by the encoder and decoder below.
     self.reset_parameters(args)
     with self.init_scope():
         self.encoder = Encoder(
             idim=idim,
             attention_dim=args.adim,
             attention_heads=args.aheads,
             linear_units=args.eunits,
             input_layer=args.transformer_input_layer,
             dropout_rate=args.dropout_rate,
             positional_dropout_rate=args.dropout_rate,
             attention_dropout_rate=args.transformer_attn_dropout_rate,
             initialW=self.initialW,
             initial_bias=self.initialB,
         )
         self.decoder = Decoder(odim,
                                args,
                                initialW=self.initialW,
                                initial_bias=self.initialB)
         self.criterion = LabelSmoothingLoss(
             args.lsm_weight,
             len(args.char_list),
             args.transformer_length_normalized_loss,
         )
         # The CTC branch is only built when it contributes to the loss.
         if args.mtlalpha > 0.0:
             if args.ctc_type == "builtin":
                 logging.info("Using chainer CTC implementation")
                 self.ctc = ctc.CTC(odim, args.adim, args.dropout_rate)
             elif args.ctc_type == "warpctc":
                 logging.info("Using warpctc CTC implementation")
                 self.ctc = ctc.WarpCTC(odim, args.adim, args.dropout_rate)
             else:
                 raise ValueError(
                     'ctc_type must be "builtin" or "warpctc": {}'.format(
                         args.ctc_type))
         else:
             self.ctc = None
     self.dims = args.adim
     self.odim = odim
     self.flag_return = flag_return
     if args.report_cer or args.report_wer:
         self.error_calculator = ErrorCalculator(
             args.char_list,
             args.sym_space,
             args.sym_blank,
             args.report_cer,
             args.report_wer,
         )
     else:
         self.error_calculator = None
     # ``verbose`` may be absent from ``args`` or explicitly None; both mean
     # "not verbose". getattr handles argparse.Namespace and any other
     # attribute container uniformly, instead of string-matching the type
     # name. It also never leaves self.verbose as None.
     verbose = getattr(args, "verbose", None)
     self.verbose = 0 if verbose is None else verbose