class E2E(STInterface, torch.nn.Module): """E2E module. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ @staticmethod def add_arguments(parser): """Add arguments.""" group = parser.add_argument_group("transformer model setting") group.add_argument("--transformer-init", type=str, default="pytorch", choices=["pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal"], help='how to initialize transformer parameters') group.add_argument("--transformer-input-layer", type=str, default="conv2d", choices=["conv2d", "linear", "embed"], help='transformer input layer type') group.add_argument('--transformer-attn-dropout-rate', default=None, type=float, help='dropout in transformer attention. use --dropout-rate if None is set') group.add_argument('--transformer-lr', default=10.0, type=float, help='Initial value of learning rate') group.add_argument('--transformer-warmup-steps', default=25000, type=int, help='optimizer warmup steps') group.add_argument('--transformer-length-normalized-loss', default=True, type=strtobool, help='normalize loss by length') group.add_argument('--dropout-rate', default=0.0, type=float, help='Dropout rate for the encoder') # Encoder group.add_argument('--elayers', default=4, type=int, help='Number of encoder layers (for shared recognition part in multi-speaker asr mode)') group.add_argument('--eunits', '-u', default=300, type=int, help='Number of encoder hidden units') # Attention group.add_argument('--adim', default=320, type=int, help='Number of attention transformation dimensions') group.add_argument('--aheads', default=4, type=int, help='Number of heads for multi head attention') # Decoder group.add_argument('--dlayers', default=1, type=int, help='Number of decoder layers') group.add_argument('--dunits', default=320, type=int, help='Number of decoder hidden units') return parser @property def attention_plot_class(self): """Return PlotAttentionReport.""" return PlotAttentionReport def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. 
:param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate ) self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate ) self.pad = 0 self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = get_subsample(args, mode='st', arch='transformer') self.reporter = Reporter() # self.lsm_weight = a self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) # self.verbose = args.verbose self.adim = args.adim # submodule for ASR task self.mtlalpha = args.mtlalpha self.asr_weight = getattr(args, "asr_weight", 0.0) if self.asr_weight > 0 and args.mtlalpha < 1: self.decoder_asr = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, ) # submodule for MT task self.mt_weight = getattr(args, "mt_weight", 0.0) if self.mt_weight > 0: self.encoder_mt = Encoder( idim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, input_layer='embed', dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, padding_idx=0 ) self.reset_parameters(args) # place after the submodule initialization if args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None if self.asr_weight > 0 and (args.report_cer or args.report_wer): from espnet.nets.e2e_asr_common import ErrorCalculator self.error_calculator = ErrorCalculator(args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer) else: self.error_calculator = None self.rnnlm = None # multilingual E2E-ST related self.multilingual = getattr(args, "multilingual", False) self.replace_sos = getattr(args, "replace_sos", False) if self.multilingual: assert self.replace_sos def reset_parameters(self, args): """Initialize parameters.""" # initialize parameters initialize(self, args.transformer_init) if self.mt_weight > 0: torch.nn.init.normal_(self.encoder_mt.embed[0].weight, mean=0, std=args.adim ** -0.5) torch.nn.init.constant_(self.encoder_mt.embed[0].weight[self.pad], 0) def forward(self, xs_pad, ilens, ys_pad, ys_pad_src): """E2E forward. 
:param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of source sequences (B) :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) :param torch.Tensor ys_pad_src: batch of padded source sequences (B, Lmax) :return: ctc loss value :rtype: torch.Tensor :return: attention loss value :rtype: torch.Tensor :return: accuracy in attention decoder :rtype: float """ # 0. Extract target language ID # src_lang_ids = None tgt_lang_ids = None if self.multilingual: tgt_lang_ids = ys_pad[:, 0:1] ys_pad = ys_pad[:, 1:] # remove target language ID in the beginning # 1. forward encoder xs_pad = xs_pad[:, :max(ilens)] # for data parallel src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2) hs_pad, hs_mask = self.encoder(xs_pad, src_mask) self.hs_pad = hs_pad # 2. forward decoder ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) # replace <sos> with target language ID if self.replace_sos: ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1) ys_mask = target_mask(ys_in_pad, self.ignore_id) pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) self.pred_pad = pred_pad pred_pad_asr, pred_pad_mt = None, None # 3. compute attention loss loss_asr, loss_mt = 0.0, 0.0 loss_att = self.criterion(pred_pad, ys_out_pad) # Multi-task w/ ASR if self.asr_weight > 0 and self.mtlalpha < 1.0: # forward ASR decoder ys_in_pad_asr, ys_out_pad_asr = add_sos_eos(ys_pad_src, self.sos, self.eos, self.ignore_id) ys_mask_asr = target_mask(ys_in_pad_asr, self.ignore_id) pred_pad_asr, _ = self.decoder_asr(ys_in_pad_asr, ys_mask_asr, hs_pad, hs_mask) # compute loss loss_asr = self.criterion(pred_pad_asr, ys_out_pad_asr) # Multi-task w/ MT if self.mt_weight > 0: # forward MT encoder ilens_mt = torch.sum(ys_pad_src != self.ignore_id, dim=1).cpu().numpy() # NOTE: ys_pad_src is padded with -1 ys_src = [y[y != self.ignore_id] for y in ys_pad_src] # parse padded ys_src ys_zero_pad_src = pad_list(ys_src, self.pad) # re-pad with zero ys_zero_pad_src = ys_zero_pad_src[:, :max(ilens_mt)] # for data parallel src_mask_mt = (~make_pad_mask(ilens_mt.tolist())).to(ys_zero_pad_src.device).unsqueeze(-2) # ys_zero_pad_src, ys_pad = self.target_forcing(ys_zero_pad_src, ys_pad) hs_pad_mt, hs_mask_mt = self.encoder_mt(ys_zero_pad_src, src_mask_mt) # forward MT decoder pred_pad_mt, _ = self.decoder(ys_in_pad, ys_mask, hs_pad_mt, hs_mask_mt) # compute loss loss_mt = self.criterion(pred_pad_mt, ys_out_pad) self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id) if pred_pad_asr is not None: self.acc_asr = th_accuracy(pred_pad_asr.view(-1, self.odim), ys_out_pad_asr, ignore_label=self.ignore_id) else: self.acc_asr = 0.0 if pred_pad_mt is not None: self.acc_mt = th_accuracy(pred_pad_mt.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id) else: self.acc_mt = 0.0 # TODO(karita) show predicted text # TODO(karita) calculate these stats cer_ctc = None if self.mtlalpha == 0.0 or self.asr_weight == 0: loss_ctc = 0.0 else: batch_size = xs_pad.size(0) hs_len = hs_mask.view(batch_size, -1).sum(1) loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad_src) if self.error_calculator is not None: ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad_src.cpu(), is_ctc=True) # 5. 
compute cer/wer cer, wer = None, None # TODO(hirofumi0810): fix later # if self.training or (self.asr_weight == 0 or self.mtlalpha == 1 or not (self.report_cer or self.report_wer)): # cer, wer = None, None # else: # ys_hat = pred_pad.argmax(dim=-1) # cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) # copyied from e2e_asr alpha = self.mtlalpha self.loss = (1 - self.asr_weight - self.mt_weight) * loss_att + self.asr_weight * \ (alpha * loss_ctc + (1 - alpha) * loss_asr) + self.mt_weight * loss_mt loss_asr_data = float(alpha * loss_ctc + (1 - alpha) * loss_asr) loss_mt_data = None if self.mt_weight == 0 else float(loss_mt) loss_st_data = float(loss_att) loss_data = float(self.loss) if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): self.reporter.report(loss_asr_data, loss_mt_data, loss_st_data, self.acc_asr, self.acc_mt, self.acc, cer_ctc, cer, wer, 0.0, # TODO(hirofumi0810): bleu loss_data) else: logging.warning('loss (=%f) is not correct', loss_data) return self.loss def scorers(self): """Scorers.""" return dict(decoder=self.decoder) def encode(self, x): """Encode source acoustic features. :param ndarray x: source acoustic feature (T, D) :return: encoder outputs :rtype: torch.Tensor """ self.eval() x = torch.as_tensor(x).unsqueeze(0) enc_output, _ = self.encoder(x, None) return enc_output.squeeze(0) def translate(self, x, trans_args, char_list=None, rnnlm=None, use_jit=False): """Translate input speech. :param ndnarray x: input acoustic feature (B, T, D) or (T, D) :param Namespace trans_args: argment Namespace contraining options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list """ # preprate sos if getattr(trans_args, "tgt_lang", False): if self.replace_sos: y = char_list.index(trans_args.tgt_lang) else: y = self.sos logging.info('<sos> index: ' + str(y)) logging.info('<sos> mark: ' + char_list[y]) enc_output = self.encode(x).unsqueeze(0) h = enc_output.squeeze(0) logging.info('input lengths: ' + str(h.size(0))) # search parms beam = trans_args.beam_size penalty = trans_args.penalty vy = h.new_zeros(1).long() if trans_args.maxlenratio == 0: maxlen = h.shape[0] else: # maxlen >= 1 maxlen = max(1, int(trans_args.maxlenratio * h.size(0))) minlen = int(trans_args.minlenratio * h.size(0)) logging.info('max output length: ' + str(maxlen)) logging.info('min output length: ' + str(minlen)) # initialize hypothesis if rnnlm: hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None} else: hyp = {'score': 0.0, 'yseq': [y]} hyps = [hyp] ended_hyps = [] import six traced_decoder = None for i in six.moves.range(maxlen): logging.debug('position ' + str(i)) hyps_best_kept = [] for hyp in hyps: vy.unsqueeze(1) vy[0] = hyp['yseq'][i] # get nbest local scores and their ids ys_mask = subsequent_mask(i + 1).unsqueeze(0) ys = torch.tensor(hyp['yseq']).unsqueeze(0) # FIXME: jit does not match non-jit result if use_jit: if traced_decoder is None: traced_decoder = torch.jit.trace(self.decoder.forward_one_step, (ys, ys_mask, enc_output)) local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0] else: local_att_scores = self.decoder.forward_one_step(ys, ys_mask, enc_output)[0] if rnnlm: rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy) local_scores = local_att_scores + trans_args.lm_weight * local_lm_scores else: local_scores = local_att_scores local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1) for j in six.moves.range(beam): new_hyp = {} new_hyp['score'] = 
hyp['score'] + float(local_best_scores[0, j]) new_hyp['yseq'] = [0] * (1 + len(hyp['yseq'])) new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq'] new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j]) if rnnlm: new_hyp['rnnlm_prev'] = rnnlm_state # will be (2 x beam) hyps at most hyps_best_kept.append(new_hyp) hyps_best_kept = sorted( hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam] # sort and get nbest hyps = hyps_best_kept logging.debug('number of pruned hypothes: ' + str(len(hyps))) if char_list is not None: logging.debug( 'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]])) # add eos in the final loop to avoid that there are no ended hyps if i == maxlen - 1: logging.info('adding <eos> in the last postion in the loop') for hyp in hyps: hyp['yseq'].append(self.eos) # add ended hypothes to a final list, and removed them from current hypothes # (this will be a probmlem, number of hyps < beam) remained_hyps = [] for hyp in hyps: if hyp['yseq'][-1] == self.eos: # only store the sequence that has more than minlen outputs # also add penalty if len(hyp['yseq']) > minlen: hyp['score'] += (i + 1) * penalty if rnnlm: # Word LM needs to add final <eos> score hyp['score'] += trans_args.lm_weight * rnnlm.final( hyp['rnnlm_prev']) ended_hyps.append(hyp) else: remained_hyps.append(hyp) # end detection from espnet.nets.e2e_asr_common import end_detect if end_detect(ended_hyps, i) and trans_args.maxlenratio == 0.0: logging.info('end detected at %d', i) break hyps = remained_hyps if len(hyps) > 0: logging.debug('remeined hypothes: ' + str(len(hyps))) else: logging.info('no hypothesis. Finish decoding.') break if char_list is not None: for hyp in hyps: logging.debug( 'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]])) logging.debug('number of ended hypothes: ' + str(len(ended_hyps))) nbest_hyps = sorted( ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), trans_args.nbest)] # check number of hypotheis if len(nbest_hyps) == 0: logging.warning('there is no N-best results, perform recognition again with smaller minlenratio.') # should copy becasuse Namespace will be overwritten globally trans_args = Namespace(**vars(trans_args)) trans_args.minlenratio = max(0.0, trans_args.minlenratio - 0.1) return self.translate(x, trans_args, char_list, rnnlm) logging.info('total log probability: ' + str(nbest_hyps[0]['score'])) logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq']))) return nbest_hyps def calculate_all_attentions(self, xs_pad, ilens, ys_pad, ys_pad_src): """E2E attention calculation. :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of input sequences (B) :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) :param torch.Tensor ys_pad_src: batch of padded token id sequence tensor (B, Lmax) :return: attention weights with the following shape, 1) multi-head case => attention weights (B, H, Lmax, Tmax), 2) other case => attention weights (B, Lmax, Tmax). :rtype: float ndarray """ with torch.no_grad(): self.forward(xs_pad, ilens, ys_pad, ys_pad_src) ret = dict() for name, m in self.named_modules(): if isinstance(m, MultiHeadedAttention) and m.attn is not None: # skip MHA for submodules ret[name] = m.attn.cpu().numpy() return ret
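# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how the multi-task
# ST objective computed in E2E.forward above is combined. The helper name
# `combine_st_losses` and the plain-float inputs are assumptions for
# illustration only; in the model the same weighting is applied to the torch
# scalars loss_att (ST), loss_ctc / loss_asr (auxiliary ASR), and loss_mt
# (auxiliary MT).
# ---------------------------------------------------------------------------
def combine_st_losses(loss_att, loss_ctc, loss_asr, loss_mt,
                      asr_weight=0.3, mt_weight=0.0, mtlalpha=0.3):
    """Weight the ST attention loss with the auxiliary ASR (CTC + attention)
    and MT losses, mirroring the formula used in E2E.forward."""
    asr_term = mtlalpha * loss_ctc + (1 - mtlalpha) * loss_asr
    return ((1 - asr_weight - mt_weight) * loss_att
            + asr_weight * asr_term
            + mt_weight * loss_mt)

# Example: with asr_weight=0.3, mt_weight=0.0 and mtlalpha=0.3 the total loss is
# 0.7 * loss_att + 0.3 * (0.3 * loss_ctc + 0.7 * loss_asr), e.g.
# combine_st_losses(1.0, 2.0, 3.0, 0.0) == 0.7 + 0.3 * (0.6 + 2.1) == 1.51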
def __init__(self, idim, odim, args=None): """Initialize Transformer-VC module. Args: idim (int): Dimension of the inputs. odim (int): Dimension of the outputs. args (Namespace, optional): - eprenet_conv_layers (int): Number of encoder prenet convolution layers. - eprenet_conv_chans (int): Number of encoder prenet convolution channels. - eprenet_conv_filts (int): Filter size of encoder prenet convolution. - transformer_input_layer (str): Input layer before the encoder. - dprenet_layers (int): Number of decoder prenet layers. - dprenet_units (int): Number of decoder prenet hidden units. - elayers (int): Number of encoder layers. - eunits (int): Number of encoder hidden units. - adim (int): Number of attention transformation dimensions. - aheads (int): Number of heads for multi head attention. - dlayers (int): Number of decoder layers. - dunits (int): Number of decoder hidden units. - postnet_layers (int): Number of postnet layers. - postnet_chans (int): Number of postnet channels. - postnet_filts (int): Filter size of postnet. - use_scaled_pos_enc (bool): Whether to use trainable scaled positional encoding. - use_batch_norm (bool): Whether to use batch normalization in encoder prenet. - encoder_normalize_before (bool): Whether to perform layer normalization before encoder block. - decoder_normalize_before (bool): Whether to perform layer normalization before decoder block. - encoder_concat_after (bool): Whether to concatenate attention layer's input and output in encoder. - decoder_concat_after (bool): Whether to concatenate attention layer's input and output in decoder. - reduction_factor (int): Reduction factor (for decoder). - encoder_reduction_factor (int): Reduction factor (for encoder). - spk_embed_dim (int): Number of speaker embedding dimensions. - spk_embed_integration_type (str): How to integrate speaker embedding. - transformer_init (str): How to initialize transformer parameters. - transformer_lr (float): Initial value of learning rate. - transformer_warmup_steps (int): Optimizer warmup steps. - transformer_enc_dropout_rate (float): Dropout rate in encoder except attention & positional encoding. - transformer_enc_positional_dropout_rate (float): Dropout rate after encoder positional encoding. - transformer_enc_attn_dropout_rate (float): Dropout rate in encoder self-attention module. - transformer_dec_dropout_rate (float): Dropout rate in decoder except attention & positional encoding. - transformer_dec_positional_dropout_rate (float): Dropout rate after decoder positional encoding. - transformer_dec_attn_dropout_rate (float): Dropout rate in decoder self-attention module. - transformer_enc_dec_attn_dropout_rate (float): Dropout rate in encoder-decoder attention module. - eprenet_dropout_rate (float): Dropout rate in encoder prenet. - dprenet_dropout_rate (float): Dropout rate in decoder prenet. - postnet_dropout_rate (float): Dropout rate in postnet. - use_masking (bool): Whether to apply masking for padded part in loss calculation. - use_weighted_masking (bool): Whether to apply weighted masking in loss calculation. - bce_pos_weight (float): Positive sample weight in bce calculation (only for use_masking=true). - loss_type (str): How to calculate loss. - use_guided_attn_loss (bool): Whether to use guided attention loss. - num_heads_applied_guided_attn (int): Number of heads in each layer to apply guided attention loss. - num_layers_applied_guided_attn (int): Number of layers to apply guided attention loss. 
- modules_applied_guided_attn (list): List of module names to apply guided attention loss. - guided-attn-loss-sigma (float) Sigma in guided attention loss. - guided-attn-loss-lambda (float): Lambda in guided attention loss. """ # initialize base classes TTSInterface.__init__(self) torch.nn.Module.__init__(self) # fill missing arguments args = fill_missing_args(args, self.add_arguments) # store hyperparameters self.idim = idim self.odim = odim self.spk_embed_dim = args.spk_embed_dim if self.spk_embed_dim is not None: self.spk_embed_integration_type = args.spk_embed_integration_type self.use_scaled_pos_enc = args.use_scaled_pos_enc self.reduction_factor = args.reduction_factor self.encoder_reduction_factor = args.encoder_reduction_factor self.transformer_input_layer = args.transformer_input_layer self.loss_type = args.loss_type self.use_guided_attn_loss = args.use_guided_attn_loss if self.use_guided_attn_loss: if args.num_layers_applied_guided_attn == -1: self.num_layers_applied_guided_attn = args.elayers else: self.num_layers_applied_guided_attn = ( args.num_layers_applied_guided_attn) if args.num_heads_applied_guided_attn == -1: self.num_heads_applied_guided_attn = args.aheads else: self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn self.modules_applied_guided_attn = args.modules_applied_guided_attn # use idx 0 as padding idx padding_idx = 0 # get positional encoding class pos_enc_class = (ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding) # define transformer encoder if args.eprenet_conv_layers != 0: # encoder prenet encoder_input_layer = torch.nn.Sequential( EncoderPrenet( idim=idim, elayers=0, econv_layers=args.eprenet_conv_layers, econv_chans=args.eprenet_conv_chans, econv_filts=args.eprenet_conv_filts, use_batch_norm=args.use_batch_norm, dropout_rate=args.eprenet_dropout_rate, padding_idx=padding_idx, input_layer=torch.nn.Linear( idim * args.encoder_reduction_factor, idim), ), torch.nn.Linear(args.eprenet_conv_chans, args.adim), ) elif args.transformer_input_layer == "linear": encoder_input_layer = torch.nn.Linear( idim * args.encoder_reduction_factor, args.adim) else: encoder_input_layer = args.transformer_input_layer self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=encoder_input_layer, dropout_rate=args.transformer_enc_dropout_rate, positional_dropout_rate=args. 
transformer_enc_positional_dropout_rate, attention_dropout_rate=args.transformer_enc_attn_dropout_rate, pos_enc_class=pos_enc_class, normalize_before=args.encoder_normalize_before, concat_after=args.encoder_concat_after, positionwise_layer_type=args.positionwise_layer_type, positionwise_conv_kernel_size=args.positionwise_conv_kernel_size, ) # define projection layer if self.spk_embed_dim is not None: if self.spk_embed_integration_type == "add": self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim) else: self.projection = torch.nn.Linear( args.adim + self.spk_embed_dim, args.adim) # define transformer decoder if args.dprenet_layers != 0: # decoder prenet decoder_input_layer = torch.nn.Sequential( DecoderPrenet( idim=odim, n_layers=args.dprenet_layers, n_units=args.dprenet_units, dropout_rate=args.dprenet_dropout_rate, ), torch.nn.Linear(args.dprenet_units, args.adim), ) else: decoder_input_layer = "linear" self.decoder = Decoder( odim=-1, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.transformer_dec_dropout_rate, positional_dropout_rate=args. transformer_dec_positional_dropout_rate, self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate, src_attention_dropout_rate=args. transformer_enc_dec_attn_dropout_rate, input_layer=decoder_input_layer, use_output_layer=False, pos_enc_class=pos_enc_class, normalize_before=args.decoder_normalize_before, concat_after=args.decoder_concat_after, ) # define final projection self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor) self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor) # define postnet self.postnet = (None if args.postnet_layers == 0 else Postnet( idim=idim, odim=odim, n_layers=args.postnet_layers, n_chans=args.postnet_chans, n_filts=args.postnet_filts, use_batch_norm=args.use_batch_norm, dropout_rate=args.postnet_dropout_rate, )) # define loss function self.criterion = TransformerLoss( use_masking=args.use_masking, use_weighted_masking=args.use_weighted_masking, bce_pos_weight=args.bce_pos_weight, ) if self.use_guided_attn_loss: self.attn_criterion = GuidedMultiHeadAttentionLoss( sigma=args.guided_attn_loss_sigma, alpha=args.guided_attn_loss_lambda, ) # initialize parameters self._reset_parameters( init_type=args.transformer_init, init_enc_alpha=args.initial_encoder_alpha, init_dec_alpha=args.initial_decoder_alpha, ) # load pretrained model if args.pretrained_model is not None: self.load_pretrained_model(args.pretrained_model)
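# ---------------------------------------------------------------------------
# Illustrative sketch (assumption, not part of the original class): how the
# `spk_embed_integration_type` option and the `self.projection` layer defined
# in __init__ above are typically used. The helper `integrate_spk_embedding`
# is hypothetical; the normalization step and exact call site are assumptions,
# but the projection shapes match the layers created in __init__
# (Linear(spk_embed_dim, adim) for "add", Linear(adim + spk_embed_dim, adim)
# for "concat").
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F


def integrate_spk_embedding(hs, spembs, projection, integration_type="add"):
    """hs: encoder states (B, Tmax, adim); spembs: speaker embeddings (B, spk_embed_dim).

    `projection` is assumed to be the linear layer built for the chosen
    integration type.
    """
    if integration_type == "add":
        # project the speaker embedding to adim and add it to every frame
        spembs = projection(F.normalize(spembs))
        return hs + spembs.unsqueeze(1)
    # "concat": tile the speaker embedding along time, concatenate with the
    # encoder states, then project back to adim
    spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
    return projection(torch.cat([hs, spembs], dim=-1))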
class E2E(ASRInterface, torch.nn.Module): """E2E module. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ @staticmethod def add_arguments(parser): """Add arguments.""" group = parser.add_argument_group("transformer model setting") group.add_argument("--transformer-init", type=str, default="pytorch", choices=[ "pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal" ], help='how to initialize transformer parameters') group.add_argument("--transformer-input-layer", type=str, default="conv2d", choices=["conv2d", "linear", "embed"], help='transformer input layer type') group.add_argument( '--transformer-attn-dropout-rate', default=None, type=float, help= 'dropout in transformer attention. use --dropout-rate if None is set' ) group.add_argument('--transformer-lr', default=10.0, type=float, help='Initial value of learning rate') group.add_argument('--transformer-warmup-steps', default=25000, type=int, help='optimizer warmup steps') group.add_argument('--transformer-length-normalized-loss', default=True, type=strtobool, help='normalize loss by length') group.add_argument('--dropout-rate', default=0.0, type=float, help='Dropout rate for the encoder') # Encoder group.add_argument( '--elayers', default=4, type=int, help= 'Number of encoder layers (for shared recognition part in multi-speaker asr mode)' ) group.add_argument('--eunits', '-u', default=300, type=int, help='Number of encoder hidden units') # Attention group.add_argument( '--adim', default=320, type=int, help='Number of attention transformation dimensions') group.add_argument('--aheads', default=4, type=int, help='Number of heads for multi head attention') # Decoder group.add_argument('--dlayers', default=1, type=int, help='Number of decoder layers') group.add_argument('--dunits', default=320, type=int, help='Number of decoder hidden units') # ctc init path group.add_argument('--pretrained-cn-ctc-model', default='', type=str, help='pretrained cn ctc model') group.add_argument('--pretrained-en-ctc-model', default='', type=str, help='pretrained en ctc model') group.add_argument('--pretrained-cn-jca-model', default='', type=str, help='pretrained cn jca model') return parser @property def attention_plot_class(self): """Return PlotAttentionReport.""" return PlotAttentionReport def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. 
:param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.cn_encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate) self.en_encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate) self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = [1] self.reporter = Reporter() # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) # self.verbose = args.verbose self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: # Note: here CTC also need to have seperate ctc_lo layer self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None if args.report_cer or args.report_wer: from espnet.nets.e2e_asr_common import ErrorCalculator self.error_calculator = ErrorCalculator(args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer) else: self.error_calculator = None self.rnnlm = None # yzl23 config self.remove_blank_in_ctc_mode = True self.reset_parameters(args) # reset params at the last def reset_parameters(self, args): """Initialize parameters.""" # load state_dict, and keeps only encoder part # note that self.ctc.ctc_lo is also removed # prefix is added to meet the needs of moe structure def load_state_dict_encoder(path, prefix=''): if 'snapshot' in path: model_state_dict = torch.load( path, map_location=lambda storage, loc: storage)['model'] else: model_state_dict = torch.load( path, map_location=lambda storage, loc: storage) for k in list(model_state_dict.keys()): if 'encoder' in k: new_k = k.replace('encoder.', prefix + 'encoder.') model_state_dict[new_k] = model_state_dict.pop(k) elif 'ctc_lo' in k: new_k = k.replace('ctc_lo', prefix + 'ctc_lo') model_state_dict[new_k] = model_state_dict.pop(k) else: # remove this key del model_state_dict[k] return model_state_dict def load_state_dict_all(path, prefix=''): if 'snapshot' in path: model_state_dict = torch.load( path, map_location=lambda storage, loc: storage)['model'] else: model_state_dict = torch.load( path, map_location=lambda storage, loc: storage) for k in list(model_state_dict.keys()): if 'encoder' in k: new_k = k.replace('encoder.', prefix + 'encoder.') model_state_dict[new_k] = model_state_dict.pop(k) elif 'ctc_lo' in k: new_k = k.replace('ctc_lo', prefix + 'ctc_lo') model_state_dict[new_k] = model_state_dict.pop(k) else: continue return model_state_dict # initialize 
parameters if args.pretrained_cn_ctc_model and args.pretrained_en_ctc_model: logging.warning( "loading pretrained ctc model for parallel encoder") # still need to initialize the 'other' params initialize(self, args.transformer_init) cn_state_dict = load_state_dict_encoder( args.pretrained_cn_ctc_model, prefix='cn_') self.load_state_dict(cn_state_dict, strict=False) del cn_state_dict en_state_dict = load_state_dict_encoder( args.pretrained_en_ctc_model, prefix='en_') self.load_state_dict(en_state_dict, strict=False) del en_state_dict elif args.pretrained_cn_jca_model and args.pretrained_en_ctc_model: logging.warning( "loading pretrained cn-jca & en-ctc model for parallel encoder" ) initialize(self, args.transformer_init) en_state_dict = load_state_dict_encoder( args.pretrained_en_ctc_model, prefix='en_') self.load_state_dict(en_state_dict, strict=False) del en_state_dict cn_state_dict = load_state_dict_all(args.pretrained_cn_jca_model, prefix='cn_') self.load_state_dict(cn_state_dict, strict=False) del cn_state_dict else: initialize(self, args.transformer_init) def forward(self, xs_pad, ilens, ys_pad, moe_coes, moe_coe_lens): """E2E forward. :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of source sequences (B) :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) :return: ctc loass value :rtype: torch.Tensor :return: attention loss value :rtype: torch.Tensor :return: accuracy in attention decoder :rtype: float """ # 1. forward encoder moe_coes = moe_coes[:, :max(moe_coe_lens)] # for data parallel xs_pad = xs_pad[:, :max(ilens)] # for data parallel src_mask = (~make_pad_mask(ilens.tolist())).to( xs_pad.device).unsqueeze(-2) # multi-encoder forward cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask) en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask) moe_coes = moe_coes.unsqueeze(-1) hs_pad = cn_hs_pad * moe_coes[:, :, 1] + en_hs_pad * moe_coes[:, :, 0] self.hs_pad = hs_pad # TODO(karita) show predicted text # TODO(karita) calculate these stats cer_ctc = None if self.mtlalpha == 0.0: loss_ctc = None else: batch_size = xs_pad.size(0) hs_len = hs_mask.view(batch_size, -1).sum(1) loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad, moe_coes) if self.error_calculator is not None: ys_hat = self.ctc.argmax( hs_pad.view(batch_size, -1, self.adim), moe_coes).data cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) if self.mtlalpha == 1: self.loss_att, acc = None, None else: # 2. forward decoder ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_mask = target_mask(ys_in_pad, self.ignore_id) pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) self.pred_pad = pred_pad # 3. compute attention loss loss_att = self.criterion(pred_pad, ys_out_pad) acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id) self.acc = acc # 5. 
compute cer/wer if self.training or self.error_calculator is None: cer, wer = None, None else: ys_hat = pred_pad.argmax(dim=-1) cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) # copyied from e2e_asr alpha = self.mtlalpha if alpha == 0: self.loss = loss_att loss_att_data = float(loss_att) loss_ctc_data = None elif alpha == 1: self.loss = loss_ctc loss_att_data = None loss_ctc_data = float(loss_ctc) else: self.loss = alpha * loss_ctc + (1 - alpha) * loss_att loss_att_data = float(loss_att) loss_ctc_data = float(loss_ctc) loss_data = float(self.loss) if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data) else: logging.warning('loss (=%f) is not correct', loss_data) return self.loss def encode(self, x): """Encode acoustic features. :param ndarray x: source acoustic feature (T, D) :return: encoder outputs :rtype: torch.Tensor """ self.eval() fbank_feats, moe_coe = x x = torch.as_tensor(fbank_feats).unsqueeze(0) # (B, T, D) with #B=1 moe_coe = torch.as_tensor(moe_coe).unsqueeze(0) cn_enc_output, _ = self.cn_encoder(x, None) en_enc_output, _ = self.en_encoder(x, None) moe_coe = moe_coe.unsqueeze(-1) # (B, T, 2, 1) enc_output = cn_enc_output * moe_coe[:, :, 1] + en_enc_output * moe_coe[:, :, 0] return enc_output.squeeze(0) # returns tensor(T, D) def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False): if recog_args.ctc_greedy_decoding: return self.recognize_ctc_greedy(x, recog_args) else: return self.recognize_jca(x, recog_args, char_list, rnnlm, use_jit) def store_penultimate_state(self, xs_pad, ilens, ys_pad, moe_coes, moe_coe_lens): moe_coes = moe_coes[:, :max(moe_coe_lens)] # for data parallel xs_pad = xs_pad[:, :max(ilens)] # for data parallel src_mask = (~make_pad_mask(ilens.tolist())).to( xs_pad.device).unsqueeze(-2) # multi-encoder forward cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask) en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask) moe_coes = moe_coes.unsqueeze(-1) hs_pad = cn_hs_pad * moe_coes[:, :, 1] + en_hs_pad * moe_coes[:, :, 0] self.hs_pad = hs_pad # forward decoder ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_mask = target_mask(ys_in_pad, self.ignore_id) pred_pad, pred_mask, penultimate_state = self.decoder( ys_in_pad, ys_mask, hs_pad, hs_mask, return_penultimate_state=True) # plot penultimate_state, (B,T,att_dim) return penultimate_state.squeeze(0).detach().cpu().numpy() def recognize_ctc_greedy(self, x, recog_args): """Recognize input speech with ctc greedy decoding. 
:param ndnarray x: input acoustic feature (B, T, D) or (T, D) :param Namespace recog_args: argment Namespace contraining options :return: N-best decoding results (fake results for compatibility) :rtype: list """ enc_output = self.encode(x).unsqueeze(0) # (1, T, D) lpz = self.ctc.log_softmax(enc_output, torch.as_tensor(x[1]).unsqueeze(0)) lpz = lpz.squeeze(0) # shape of (T, D) idx = lpz.argmax(-1).cpu().numpy().tolist() hyp = {} if recog_args.ctc_raw_results: hyp['yseq'] = [ self.sos ] + idx # not apply ctc mapping, to get ctc alignment else: # <sos> is added here to be compatible with S2S decoding, # file: espnet/asr/asr_utils/parse_hypothesis hyp['yseq'] = [self.sos] + self.ctc_mapping(idx) logging.info(hyp['yseq']) hyp['score'] = -1 return [hyp] def ctc_mapping(self, x, blank=0): prev = blank y = [] for i in x: if i != blank and i != prev: y.append(i) prev = i return y def recognize_jca(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False): """Recognize input speech. :param ndnarray x: input acoustic feature (B, T, D) or (T, D) :param Namespace recog_args: argment Namespace contraining options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list """ enc_output = self.encode(x).unsqueeze(0) # (1, T, D) if recog_args.ctc_weight > 0.0: lpz = self.ctc.log_softmax(enc_output, torch.as_tensor(x[1]).unsqueeze(0)) lpz = lpz.squeeze(0) # shape of (T, D) else: lpz = None h = enc_output.squeeze(0) # (B, T, D), #B=1 logging.info('input lengths: ' + str(h.size(0))) # search parms beam = recog_args.beam_size penalty = recog_args.penalty ctc_weight = recog_args.ctc_weight # preprare sos y = self.sos vy = h.new_zeros(1).long() if recog_args.maxlenratio == 0: maxlen = h.shape[0] else: # maxlen >= 1 maxlen = max(1, int(recog_args.maxlenratio * h.size(0))) minlen = int(recog_args.minlenratio * h.size(0)) logging.info('max output length: ' + str(maxlen)) logging.info('min output length: ' + str(minlen)) # initialize hypothesis if rnnlm: hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None} else: hyp = {'score': 0.0, 'yseq': [y]} if lpz is not None: import numpy from espnet.nets.ctc_prefix_score import CTCPrefixScore ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy) hyp['ctc_state_prev'] = ctc_prefix_score.initial_state() hyp['ctc_score_prev'] = 0.0 if ctc_weight != 1.0: # pre-pruning based on attention scores from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO)) else: if self.remove_blank_in_ctc_mode: ctc_beam = lpz.shape[-1] - 1 # except blank else: ctc_beam = lpz.shape[-1] hyps = [hyp] ended_hyps = [] import six traced_decoder = None for i in six.moves.range(maxlen): logging.debug('position ' + str(i)) hyps_best_kept = [] for hyp in hyps: vy.unsqueeze(1) vy[0] = hyp['yseq'][i] # get nbest local scores and their ids ys_mask = subsequent_mask(i + 1).unsqueeze( 0) # mask scores of future state ys = torch.tensor(hyp['yseq']).unsqueeze(0) # FIXME: jit does not match non-jit result if use_jit: if traced_decoder is None: traced_decoder = torch.jit.trace( self.decoder.forward_one_step, (ys, ys_mask, enc_output)) local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0] else: local_att_scores = self.decoder.forward_one_step( ys, ys_mask, enc_output)[0] if rnnlm: rnnlm_state, local_lm_scores = rnnlm.predict( hyp['rnnlm_prev'], vy) local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores else: 
local_scores = local_att_scores if lpz is not None: if self.remove_blank_in_ctc_mode: # here we need to filter out <blank> in local_best_ids # it happens in pure ctc-mode, when ctc_beam equals to #vocab local_best_scores, local_best_ids = torch.topk( local_att_scores[:, 1:], ctc_beam, dim=1) local_best_ids += 1 # hack else: local_best_scores, local_best_ids = torch.topk( local_att_scores, ctc_beam, dim=1) ctc_scores, ctc_states = ctc_prefix_score( hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev']) local_scores = \ (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \ + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev']) if rnnlm: local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[ 0]] local_best_scores, joint_best_ids = torch.topk( local_scores, beam, dim=1) local_best_ids = local_best_ids[:, joint_best_ids[0]] else: local_best_scores, local_best_ids = torch.topk( local_scores, beam, dim=1) for j in six.moves.range(beam): new_hyp = {} new_hyp['score'] = hyp['score'] + float( local_best_scores[0, j]) new_hyp['yseq'] = [0] * (1 + len(hyp['yseq'])) new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq'] new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j]) if rnnlm: new_hyp['rnnlm_prev'] = rnnlm_state if lpz is not None: new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[ 0, j]] new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[ 0, j]] # will be (2 x beam) hyps at most hyps_best_kept.append(new_hyp) hyps_best_kept = sorted(hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam] # sort and get nbest hyps = hyps_best_kept logging.debug('number of pruned hypothes: ' + str(len(hyps))) if char_list is not None: logging.debug( 'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]])) # add eos in the final loop to avoid that there are no ended hyps if i == maxlen - 1: logging.info('adding <eos> in the last postion in the loop') for hyp in hyps: hyp['yseq'].append(self.eos) # add ended hypothes to a final list, and removed them from current hypothes # (this will be a probmlem, number of hyps < beam) remained_hyps = [] for hyp in hyps: if hyp['yseq'][-1] == self.eos: # only store the sequence that has more than minlen outputs # also add penalty if len(hyp['yseq']) > minlen: hyp['score'] += (i + 1) * penalty if rnnlm: # Word LM needs to add final <eos> score hyp['score'] += recog_args.lm_weight * rnnlm.final( hyp['rnnlm_prev']) ended_hyps.append(hyp) else: remained_hyps.append(hyp) # end detection # from espnet.nets.e2e_asr_common import end_detect # if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0: from espnet.nets.e2e_asr_common import end_detect_yzl23 if end_detect_yzl23(ended_hyps, remained_hyps, penalty) and recog_args.maxlenratio == 0.0: logging.info('end detected at %d', i) break hyps = remained_hyps if len(hyps) > 0: logging.debug('remeined hypothes: ' + str(len(hyps))) else: logging.info('no hypothesis. Finish decoding.') break if char_list is not None: for hyp in hyps: logging.debug( 'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]])) logging.debug('number of ended hypothes: ' + str(len(ended_hyps))) nbest_hyps = sorted( ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), recog_args.nbest)] # check number of hypotheis if len(nbest_hyps) == 0: logging.warning( 'there is no N-best results, perform recognition again with smaller minlenratio.' 
) # should copy because Namespace will be overwritten globally recog_args = Namespace(**vars(recog_args)) recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1) return self.recognize(x, recog_args, char_list, rnnlm) logging.info('total log probability: ' + str(nbest_hyps[0]['score'])) logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq']))) return nbest_hyps def calculate_all_attentions(self, xs_pad, ilens, ys_pad, moe_coes, moe_coe_lens): """E2E attention calculation. :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of input sequences (B) :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) :param torch.Tensor moe_coes: batch of padded mixture-of-experts coefficients (B, Tmax, 2) :param torch.Tensor moe_coe_lens: batch of lengths of mixture coefficient sequences (B) :return: attention weights with the following shape, 1) multi-head case => attention weights (B, H, Lmax, Tmax), 2) other case => attention weights (B, Lmax, Tmax). :rtype: float ndarray """ with torch.no_grad(): self.forward(xs_pad, ilens, ys_pad, moe_coes, moe_coe_lens) ret = dict() for name, m in self.named_modules(): if isinstance(m, MultiHeadedAttention): ret[name] = m.attn.cpu().numpy() return ret
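# ---------------------------------------------------------------------------
# Illustrative sketch (for clarity only, not part of the original class): the
# frame-level mixture-of-encoders combination used in E2E.forward and
# E2E.encode above. The helper name `mix_encoder_outputs` is hypothetical; in
# the model the same weighting is applied to the cn_encoder / en_encoder
# outputs, with moe_coes unsqueezed to (B, T, 2, 1) so that index 1 weights
# the CN encoder output and index 0 the EN encoder output.
# ---------------------------------------------------------------------------
import torch


def mix_encoder_outputs(cn_hs, en_hs, moe_coes):
    """cn_hs, en_hs: encoder outputs (B, T, adim); moe_coes: per-frame weights (B, T, 2)."""
    moe_coes = moe_coes.unsqueeze(-1)  # (B, T, 2, 1), broadcast over adim
    return cn_hs * moe_coes[:, :, 1] + en_hs * moe_coes[:, :, 0]

# Example: a frame with weights [0.0, 1.0] is taken entirely from the CN
# encoder, while [1.0, 0.0] selects the EN encoder output for that frame.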
class E2E(torch.nn.Module): """E2E module. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ @staticmethod def add_arguments(parser): """Add arguments.""" group = parser.add_argument_group("transformer model setting") group.add_argument("--transformer-init", type=str, default="pytorch", choices=[ "pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal" ], help='how to initialize transformer parameters') group.add_argument("--transformer-input-layer", type=str, default="conv2d", choices=["conv2d", "linear", "embed", "custom"], help='transformer input layer type') group.add_argument("--transformer-output-layer", type=str, default='embed', choices=['conv', 'embed', 'linear']) group.add_argument( '--transformer-attn-dropout-rate', default=None, type=float, help= 'dropout in transformer attention. use --dropout-rate if None is set' ) group.add_argument('--transformer-lr', default=10.0, type=float, help='Initial value of learning rate') group.add_argument('--transformer-warmup-steps', default=25000, type=int, help='optimizer warmup steps') group.add_argument('--transformer-length-normalized-loss', default=True, type=strtobool, help='normalize loss by length') group.add_argument('--dropout-rate', default=0.0, type=float, help='Dropout rate for the encoder') # Encoder group.add_argument( '--elayers', default=4, type=int, help= 'Number of encoder layers (for shared recognition part in multi-speaker asr mode)' ) group.add_argument('--eunits', '-u', default=300, type=int, help='Number of encoder hidden units') # Attention group.add_argument( '--adim', default=320, type=int, help='Number of attention transformation dimensions') group.add_argument('--aheads', default=4, type=int, help='Number of heads for multi head attention') # Decoder group.add_argument('--dlayers', default=1, type=int, help='Number of decoder layers') group.add_argument('--dunits', default=320, type=int, help='Number of decoder hidden units') # Streaming params group.add_argument( '--chunk', default=True, type=strtobool, help= 'streaming mode, set True for chunk-encoder, False for look-ahead encoder' ) group.add_argument('--chunk-size', default=16, type=int, help='chunk size for chunk-based encoder') group.add_argument( '--left-window', default=1000, type=int, help='left window size for look-ahead based encoder') group.add_argument( '--right-window', default=1000, type=int, help='right window size for look-ahead based encoder') group.add_argument( '--dec-left-window', default=0, type=int, help='left window size for decoder (look-ahead based method)') group.add_argument( '--dec-right-window', default=6, type=int, help='right window size for decoder (look-ahead based method)') return parser def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. 
:param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, input_layer=args.transformer_output_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = [1] # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) # self.verbose = args.verbose self.reset_parameters(args) self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None self.rnnlm = None self.left_window = args.dec_left_window self.right_window = args.dec_right_window def reset_parameters(self, args): """Initialize parameters.""" # initialize parameters initialize(self, args.transformer_init) def forward(self, xs_pad, ilens, ys_pad, enc_mask=None, dec_mask=None): """E2E forward. :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of source sequences (B) :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) :return: ctc loass value :rtype: torch.Tensor :return: attention loss value :rtype: torch.Tensor :return: accuracy in attention decoder :rtype: float """ # 1. forward encoder xs_pad = xs_pad[:, :max(ilens)] # for data parallel batch_size = xs_pad.shape[0] src_mask = make_non_pad_mask(ilens.tolist()).to( xs_pad.device).unsqueeze(-2) if isinstance(self.encoder.embed, EncoderConv2d): xs, hs_mask = self.encoder.embed(xs_pad, torch.sum(src_mask, 2).squeeze()) hs_mask = hs_mask.unsqueeze(1) else: xs, hs_mask = self.encoder.embed(xs_pad, src_mask) if enc_mask is not None: enc_mask = enc_mask[:, :hs_mask.shape[2], :hs_mask.shape[2]] enc_mask = enc_mask & hs_mask if enc_mask is not None else hs_mask hs_pad, _ = self.encoder.encoders(xs, enc_mask) if self.encoder.normalize_before: hs_pad = self.encoder.after_norm(hs_pad) # CTC forward ys = [y[y != self.ignore_id] for y in ys_pad] y_len = max([len(y) for y in ys]) ys_pad = ys_pad[:, :y_len] if dec_mask is not None: dec_mask = dec_mask[:, :y_len + 1, :hs_pad.shape[1]] self.hs_pad = hs_pad batch_size = xs_pad.size(0) if self.mtlalpha == 0.0: loss_ctc = None else: batch_size = xs_pad.size(0) hs_len = hs_mask.view(batch_size, -1).sum(1) loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad) # trigger mask hs_mask = hs_mask & dec_mask if dec_mask is not None else hs_mask # 2. 
forward decoder ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_mask = target_mask(ys_in_pad, self.ignore_id) pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) self.pred_pad = pred_pad # 3. compute attention loss loss_att = self.criterion(pred_pad, ys_out_pad) self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id) # copyied from e2e_asr alpha = self.mtlalpha if alpha == 0: self.loss = loss_att loss_att_data = float(loss_att) loss_ctc_data = None elif alpha == 1: self.loss = loss_ctc loss_att_data = None loss_ctc_data = float(loss_ctc) else: self.loss = alpha * loss_ctc + (1 - alpha) * loss_att loss_att_data = float(loss_att) loss_ctc_data = float(loss_ctc) return self.loss, loss_ctc_data, loss_att_data, self.acc def scorers(self): """Scorers.""" return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos)) def encode(self, x, mask=None): """Encode acoustic features. :param ndarray x: source acoustic feature (T, D) :return: encoder outputs :rtype: torch.Tensor """ self.eval() x = torch.as_tensor(x).unsqueeze(0).cuda() if mask is not None: mask = mask.cuda() if isinstance(self.encoder.embed, EncoderConv2d): hs, _ = self.encoder.embed( x, torch.Tensor([float(x.shape[1])]).cuda()) else: hs, _ = self.encoder.embed(x, None) hs, _ = self.encoder.encoders(hs, mask) if self.encoder.normalize_before: hs = self.encoder.after_norm(hs) return hs.squeeze(0) def viterbi_decode(self, x, y, mask=None): enc_output = self.encode(x, mask) logits = self.ctc.ctc_lo(enc_output).detach().data logit = np.array(logits.cpu().data).T align = viterbi_align(logit, y)[0] return align def ctc_decode(self, x, mask=None): enc_output = self.encode(x, mask) logits = self.ctc.argmax(enc_output.view(1, -1, 512)).detach().data path = np.array(logits.cpu()[0]) return path def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False): """Recognize input speech. 
:param ndnarray x: input acoustic feature (B, T, D) or (T, D) :param Namespace recog_args: argment Namespace contraining options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list """ enc_output = self.encode(x).unsqueeze(0) if recog_args.ctc_weight > 0.0: lpz = self.ctc.log_softmax(enc_output) lpz = lpz.squeeze(0) else: lpz = None h = enc_output.squeeze(0) logging.info('input lengths: ' + str(h.size(0))) # search parms beam = recog_args.beam_size penalty = recog_args.penalty ctc_weight = recog_args.ctc_weight # preprare sos y = self.sos vy = h.new_zeros(1).long() if recog_args.maxlenratio == 0: maxlen = h.shape[0] else: # maxlen >= 1 maxlen = max(1, int(recog_args.maxlenratio * h.size(0))) minlen = int(recog_args.minlenratio * h.size(0)) logging.info('max output length: ' + str(maxlen)) logging.info('min output length: ' + str(minlen)) # initialize hypothesis if rnnlm: hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None} else: hyp = {'score': 0.0, 'yseq': [y]} if lpz is not None: import numpy from espnet.nets.ctc_prefix_score import CTCPrefixScore ctc_prefix_score = CTCPrefixScore(lpz.cpu().detach().numpy(), 0, self.eos, numpy) hyp['ctc_state_prev'] = ctc_prefix_score.initial_state() hyp['ctc_score_prev'] = 0.0 if ctc_weight != 1.0: # pre-pruning based on attention scores ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO)) else: ctc_beam = lpz.shape[-1] hyps = [hyp] ended_hyps = [] import six traced_decoder = None for i in six.moves.range(maxlen): logging.debug('position ' + str(i)) hyps_best_kept = [] for hyp in hyps: vy.unsqueeze(1) vy[0] = hyp['yseq'][i] # get nbest local scores and their ids ys_mask = subsequent_mask(i + 1).unsqueeze(0).cuda() ys = torch.tensor(hyp['yseq']).unsqueeze(0).cuda() # FIXME: jit does not match non-jit result if use_jit: if traced_decoder is None: traced_decoder = torch.jit.trace( self.decoder.forward_one_step, (ys, ys_mask, enc_output)) local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0] else: local_att_scores = self.decoder.forward_one_step( ys, ys_mask, enc_output)[0] if rnnlm: rnnlm_state, local_lm_scores = rnnlm.predict( hyp['rnnlm_prev'], vy) local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores else: local_scores = local_att_scores if lpz is not None: local_best_scores, local_best_ids = torch.topk( local_att_scores, ctc_beam, dim=1) ctc_scores, ctc_states = ctc_prefix_score( hyp['yseq'], local_best_ids[0].cpu(), hyp['ctc_state_prev']) local_scores = \ (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]].cpu() \ + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev']) if rnnlm: local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[ 0]].cpu() local_best_scores, joint_best_ids = torch.topk( local_scores, beam, dim=1) local_best_ids = local_best_ids[:, joint_best_ids[0]] else: local_best_scores, local_best_ids = torch.topk( local_scores, beam, dim=1) for j in six.moves.range(beam): new_hyp = {} new_hyp['score'] = hyp['score'] + float( local_best_scores[0, j]) new_hyp['yseq'] = [0] * (1 + len(hyp['yseq'])) new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq'] new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j]) if rnnlm: new_hyp['rnnlm_prev'] = rnnlm_state if lpz is not None: new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[ 0, j]] new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[ 0, j]] # will be (2 x beam) hyps at most hyps_best_kept.append(new_hyp) hyps_best_kept = 
sorted(hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam] # sort and get nbest hyps = hyps_best_kept logging.debug('number of pruned hypotheses: ' + str(len(hyps))) if char_list is not None: logging.debug( 'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]])) # add eos in the final loop to avoid that there are no ended hyps if i == maxlen - 1: logging.info('adding <eos> in the last position in the loop') for hyp in hyps: hyp['yseq'].append(self.eos) # add ended hypotheses to a final list, and remove them from the current hypotheses # (this will be a problem if the number of hyps < beam) remained_hyps = [] for hyp in hyps: if hyp['yseq'][-1] == self.eos: # only store the sequence that has more than minlen outputs # also add penalty if len(hyp['yseq']) > minlen: hyp['score'] += (i + 1) * penalty if rnnlm: # Word LM needs to add final <eos> score hyp['score'] += recog_args.lm_weight * rnnlm.final( hyp['rnnlm_prev']) ended_hyps.append(hyp) else: remained_hyps.append(hyp) # end detection from espnet.nets.e2e_asr_common import end_detect if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0: logging.info('end detected at %d', i) break hyps = remained_hyps if len(hyps) > 0: logging.debug('remaining hypotheses: ' + str(len(hyps))) else: logging.info('no hypothesis. Finish decoding.') break if char_list is not None: for hyp in hyps: logging.debug( 'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]])) logging.debug('number of ended hypotheses: ' + str(len(ended_hyps))) nbest_hyps = sorted( ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), recog_args.nbest)] # check the number of hypotheses if len(nbest_hyps) == 0: logging.warning( 'there are no N-best results, perform recognition again with smaller minlenratio.' 
) # should copy becasuse Namespace will be overwritten globally recog_args = Namespace(**vars(recog_args)) recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1) return self.recognize(x, recog_args, char_list, rnnlm) logging.info('total log probability: ' + str(nbest_hyps[0]['score'])) logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq']))) return nbest_hyps def prefix_recognize(self, x, recog_args, train_args, char_list=None, rnnlm=None): '''recognize feat :param ndnarray x: input acouctic feature (B, T, D) or (T, D) :param namespace recog_args: argment namespace contraining options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list TODO(karita): do not recompute previous attention for faster decoding ''' pad_len = self.eos - len(char_list) + 1 for i in range(pad_len): char_list.append('<eos>') if isinstance(self.encoder.embed, EncoderConv2d): seq_len = ((x.shape[0] + 1) // 2 + 1) // 2 else: seq_len = ((x.shape[0] - 1) // 2 - 1) // 2 if train_args.chunk: s = np.arange(0, seq_len, train_args.chunk_size) mask = adaptive_enc_mask(seq_len, s).unsqueeze(0) else: mask = turncated_mask(1, seq_len, train_args.left_window, train_args.right_window) enc_output = self.encode(x, mask).unsqueeze(0) lpz = torch.nn.functional.softmax(self.ctc.ctc_lo(enc_output), dim=-1) lpz = lpz.squeeze(0) h = enc_output.squeeze(0) logging.info('input lengths: ' + str(h.size(0))) h_len = h.size(0) # search parms beam = recog_args.beam_size penalty = recog_args.penalty ctc_weight = recog_args.ctc_weight # preprare sos y = self.sos vy = h.new_zeros(1).long() if recog_args.maxlenratio == 0: maxlen = h.shape[0] else: # maxlen >= 1 maxlen = max(1, int(recog_args.maxlenratio * h.size(0))) minlen = int(recog_args.minlenratio * h.size(0)) hyp = { 'score': 0.0, 'yseq': [y], 'rnnlm_prev': None, 'seq': char_list[y], 'last_time': [], "ctc_score": 0.0, "rnnlm_score": 0.0, "att_score": 0.0, "cache": None, "precache": None, "preatt_score": 0.0, "prev_score": 0.0 } hyps = {char_list[y]: hyp} hyps_att = {char_list[y]: hyp} Pb_prev, Pnb_prev = Counter(), Counter() Pb, Pnb = Counter(), Counter() Pjoint = Counter() lpz = lpz.cpu().detach().numpy() vocab_size = lpz.shape[1] r = np.ndarray((vocab_size), dtype=np.float32) l = char_list[y] Pb_prev[l] = 1 Pnb_prev[l] = 0 A_prev = [l] A_prev_id = [[y]] vy.unsqueeze(1) total_copy = time.time() - time.time() samelen = 0 hat_att = {} if mask is not None: chunk_pos = set(np.array(mask.sum(dim=-1))[0]) for i in chunk_pos: hat_att[i] = {} else: hat_att[enc_output.shape[1]] = {} for i in range(h_len): hyps_ctc = {} threshold = recog_args.threshold #self.threshold #np.percentile(r, 98) pos_ctc = np.where(lpz[i] > threshold)[0] #self.removeIlegal(hyps) if mask is not None: chunk_index = mask[0][i].sum().item() else: chunk_index = h_len hyps_res = {} for l, hyp in hyps.items(): if l in hat_att[chunk_index]: hyp['tmp_cache'] = hat_att[chunk_index][l]['cache'] hyp['tmp_att'] = hat_att[chunk_index][l]['att_scores'] else: hyps_res[l] = hyp tmp = self.clusterbyLength( hyps_res ) # This step clusters hyps according to length dict:{length,hyps} start = time.time() # pre-compute beam self.compute_hyps(tmp, i, h_len, enc_output, hat_att[chunk_index], mask) total_copy += time.time() - start # Assign score and tokens to hyps #print(hyps.keys()) for l, hyp in hyps.items(): if 'tmp_att' not in hyp: continue #Todo check why local_att_scores = hyp['tmp_att'] local_best_scores, 
local_best_ids = torch.topk( local_att_scores, 5, dim=1) pos_att = np.array(local_best_ids[0].cpu()) pos = np.union1d(pos_ctc, pos_att) hyp['pos'] = pos # pre-compute ctc beam hyps_ctc_compute = self.get_ctchyps2compute(hyps, hyps_ctc, i) hyps_res2 = {} for l, hyp in hyps_ctc_compute.items(): l_minus = ' '.join(l.split()[:-1]) if l_minus in hat_att[chunk_index]: hyp['tmp_cur_new_cache'] = hat_att[chunk_index][l_minus][ 'cache'] hyp['tmp_cur_att_scores'] = hat_att[chunk_index][l_minus][ 'att_scores'] else: hyps_res2[l] = hyp tmp2_cluster = self.clusterbyLength(hyps_res2) self.compute_hyps_ctc(tmp2_cluster, h_len, enc_output, hat_att[chunk_index], mask) for l, hyp in hyps.items(): start = time.time() l_id = hyp['yseq'] l_end = l_id[-1] vy[0] = l_end prefix_len = len(l_id) if rnnlm: rnnlm_state, local_lm_scores = rnnlm.predict( hyp['rnnlm_prev'], vy) else: rnnlm_state = None local_lm_scores = torch.zeros([1, len(char_list)]) r = lpz[i] * (Pb_prev[l] + Pnb_prev[l]) start = time.time() if 'tmp_att' not in hyp: continue #Todo check why local_att_scores = hyp['tmp_att'] new_cache = hyp['tmp_cache'] align = [0] * prefix_len align[:prefix_len - 1] = hyp['last_time'][:] align[-1] = i pos = hyp['pos'] if 0 in pos or l_end in pos: if l not in hyps_ctc: hyps_ctc[l] = {'yseq': l_id} hyps_ctc[l]['rnnlm_prev'] = hyp['rnnlm_prev'] hyps_ctc[l]['rnnlm_score'] = hyp['rnnlm_score'] if l_end != self.eos: hyps_ctc[l]['last_time'] = [0] * prefix_len hyps_ctc[l]['last_time'][:] = hyp['last_time'][:] hyps_ctc[l]['last_time'][-1] = i try: cur_att_scores = hyps_ctc_compute[l][ "tmp_cur_att_scores"] cur_new_cache = hyps_ctc_compute[l][ "tmp_cur_new_cache"] except: pdb.set_trace() hyps_ctc[l]['att_score'] = hyp['preatt_score'] + \ float(cur_att_scores[0, l_end].data) hyps_ctc[l]['cur_att'] = float( cur_att_scores[0, l_end].data) hyps_ctc[l]['cache'] = cur_new_cache else: if len(hyps_ctc[l]["yseq"]) > 1: hyps_ctc[l]["end"] = True hyps_ctc[l]['last_time'] = [] hyps_ctc[l]['att_score'] = hyp['att_score'] hyps_ctc[l]['cur_att'] = 0 hyps_ctc[l]['cache'] = hyp['cache'] hyps_ctc[l]['prev_score'] = hyp['prev_score'] hyps_ctc[l]['preatt_score'] = hyp['preatt_score'] hyps_ctc[l]['precache'] = hyp['precache'] hyps_ctc[l]['seq'] = hyp['seq'] for c in list(pos): if c == 0: Pb[l] += lpz[i][0] * (Pb_prev[l] + Pnb_prev[l]) else: l_plus = l + " " + char_list[c] if l_plus not in hyps_ctc: hyps_ctc[l_plus] = {} if "end" in hyp: hyps_ctc[l_plus]['yseq'] = True hyps_ctc[l_plus]['yseq'] = [0] * (prefix_len + 1) hyps_ctc[l_plus]['yseq'][:len(hyp['yseq'])] = l_id hyps_ctc[l_plus]['yseq'][-1] = int(c) hyps_ctc[l_plus]['rnnlm_prev'] = rnnlm_state hyps_ctc[l_plus][ 'rnnlm_score'] = hyp['rnnlm_score'] + float( local_lm_scores[0, c].data) hyps_ctc[l_plus]['att_score'] = hyp['att_score'] \ + float(local_att_scores[0, c].data) hyps_ctc[l_plus]['cur_att'] = float( local_att_scores[0, c].data) hyps_ctc[l_plus]['cache'] = new_cache hyps_ctc[l_plus]['precache'] = hyp['cache'] hyps_ctc[l_plus]['preatt_score'] = hyp['att_score'] hyps_ctc[l_plus]['prev_score'] = hyp['score'] hyps_ctc[l_plus]['last_time'] = align hyps_ctc[l_plus]['rule_penalty'] = 0 hyps_ctc[l_plus]['seq'] = l_plus if l_end != self.eos and c == l_end: Pnb[l_plus] += lpz[i][l_end] * Pb_prev[l] Pnb[l] += lpz[i][l_end] * Pnb_prev[l] else: Pnb[l_plus] += r[c] if l_plus not in hyps: Pb[l_plus] += lpz[i][0] * (Pb_prev[l_plus] + Pnb_prev[l_plus]) Pb[l_plus] += lpz[i][c] * Pnb_prev[l_plus] #total_copy += time.time() - start for l in hyps_ctc.keys(): if Pb[l] != 0 or Pnb[l] != 0: 
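# Pb[l] / Pnb[l] accumulate the probability of prefix l ending in a blank / non-blank at the current frame, as in standard CTC prefix beam search; their sum below gives the CTC prefix probability of l.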
hyps_ctc[l]['ctc_score'] = np.log(Pb[l] + Pnb[l]) else: hyps_ctc[l]['ctc_score'] = float('-inf') local_score = hyps_ctc[l]['ctc_score'] + recog_args.ctc_lm_weight * hyps_ctc[l]['rnnlm_score'] + \ recog_args.penalty * (len(hyps_ctc[l]['yseq'])) hyps_ctc[l]['local_score'] = local_score hyps_ctc[l]['score'] = (1-recog_args.ctc_weight) * hyps_ctc[l]['att_score'] \ + recog_args.ctc_weight * hyps_ctc[l]['ctc_score'] + \ recog_args.penalty * (len(hyps_ctc[l]['yseq'])) + \ recog_args.lm_weight * hyps_ctc[l]['rnnlm_score'] Pb_prev = Pb Pnb_prev = Pnb Pb = Counter() Pnb = Counter() hyps1 = sorted(hyps_ctc.items(), key=lambda x: x[1]['local_score'], reverse=True)[:beam] hyps1 = dict(hyps1) hyps2 = sorted(hyps_ctc.items(), key=lambda x: x[1]['att_score'], reverse=True)[:beam] hyps2 = dict(hyps2) hyps = sorted(hyps_ctc.items(), key=lambda x: x[1]['score'], reverse=True)[:beam] hyps = dict(hyps) for key in hyps1.keys(): if key not in hyps: hyps[key] = hyps1[key] for key in hyps2.keys(): if key not in hyps: hyps[key] = hyps2[key] hyps = sorted(hyps.items(), key=lambda x: x[1]['score'], reverse=True)[:beam] hyps = dict(hyps) logging.info('input lengths: ' + str(h.size(0))) logging.info('max output length: ' + str(maxlen)) logging.info('min output length: ' + str(minlen)) if "<eos>" in hyps.keys(): del hyps["<eos>"] #for key in hyps.keys(): # logging.info("{0}\tctc:{1}\tatt:{2}\trnnlm:{3}\tscore:{4}".format(key,hyps[key]["ctc_score"],hyps[key]['att_score'], # hyps[key]['rnnlm_score'], hyps[key]['score'])) # print("!!!","Decoding None") best = list(hyps.keys())[0] ids = hyps[best]['yseq'] score = hyps[best]['score'] logging.info('score: ' + str(score)) #if l in hyps.keys(): # logging.info(l) #print(samelen,h_len) return best, ids, score def removeIlegal(self, hyps): max_y = max([len(hyp['yseq']) for l, hyp in hyps.items()]) for_remove = [] for l, hyp in hyps.items(): if max_y - len(hyp['yseq']) > 4: for_remove.append(l) for cur_str in for_remove: del hyps[cur_str] def clusterbyLength(self, hyps): tmp = {} for l, hyp in hyps.items(): prefix_len = len(hyp['yseq']) if prefix_len > 1 and hyp['yseq'][-1] == self.eos: continue else: if prefix_len not in tmp: tmp[prefix_len] = [] tmp[prefix_len].append(hyp) return tmp def compute_hyps(self, current_hyps, curren_frame, total_frame, enc_output, hat_att, enc_mask=None): for length, hyps_t in current_hyps.items(): ys_mask = subsequent_mask(length).unsqueeze(0).cuda() ys_mask4use = ys_mask.repeat(len(hyps_t), 1, 1) # print(ys_mask4use.shape) l_id = [hyp_t['yseq'] for hyp_t in hyps_t] ys4use = torch.tensor(l_id).cuda() enc_output4use = enc_output.repeat(len(hyps_t), 1, 1) if hyps_t[0]["cache"] is None: cache4use = None else: cache4use = [] for decode_num in range(len(hyps_t[0]["cache"])): current_cache = [] for hyp_t in hyps_t: current_cache.append( hyp_t["cache"][decode_num].squeeze(0)) # print( torch.stack(current_cache).shape) current_cache = torch.stack(current_cache) cache4use.append(current_cache) partial_mask4use = [] for hyp_t in hyps_t: #partial_mask4use.append(torch.ones([1, len(hyp_t['last_time'])+1, enc_mask.shape[1]]).byte()) align = [0] * length align[:length - 1] = hyp_t['last_time'][:] align[-1] = curren_frame align_tensor = torch.tensor(align).unsqueeze(0) if enc_mask is not None: partial_mask = enc_mask[0][align_tensor] else: right_window = self.right_window partial_mask = trigger_mask(1, total_frame, align_tensor, self.left_window, right_window) partial_mask4use.append(partial_mask) partial_mask4use = torch.stack(partial_mask4use).cuda().squeeze(1) 
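# Batched one-step decoding: hypotheses of the same length were stacked above (token ids, target masks, per-layer decoder caches and per-hypothesis encoder masks built from last_time / the chunk mask), so a single forward_one_step call scores all of them; the results are unpacked per hypothesis below.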
local_att_scores_b, new_cache_b = self.decoder.forward_one_step( ys4use, ys_mask4use, enc_output4use, partial_mask4use, cache4use) for idx, hyp_t in enumerate(hyps_t): hyp_t['tmp_cache'] = [ new_cache_b[decode_num][idx].unsqueeze(0) for decode_num in range(len(new_cache_b)) ] hyp_t['tmp_att'] = local_att_scores_b[idx].unsqueeze(0) hat_att[hyp_t['seq']] = {} hat_att[hyp_t['seq']]['cache'] = hyp_t['tmp_cache'] hat_att[hyp_t['seq']]['att_scores'] = hyp_t['tmp_att'] def get_ctchyps2compute(self, hyps, hyps_ctc, current_frame): tmp2 = {} for l, hyp in hyps.items(): l_id = hyp['yseq'] l_end = l_id[-1] if "pos" not in hyp: continue if 0 in hyp['pos'] or l_end in hyp['pos']: #l_minus = ' '.join(l.split()[:-1]) #if l_minus in hat_att: # hyps[l]['tmp_cur_new_cache'] = hat_att[l_minus]['cache'] # hyps[l]['tmp_cur_att_scores'] = hat_att[l_minus]['att_scores'] # continue if l not in hyps_ctc and l_end != self.eos: tmp2[l] = {'yseq': l_id} tmp2[l]['seq'] = l tmp2[l]['rnnlm_prev'] = hyp['rnnlm_prev'] tmp2[l]['rnnlm_score'] = hyp['rnnlm_score'] if l_end != self.eos: tmp2[l]['last_time'] = [0] * len(l_id) tmp2[l]['last_time'][:] = hyp['last_time'][:] tmp2[l]['last_time'][-1] = current_frame return tmp2 def compute_hyps_ctc(self, hyps_ctc_cluster, total_frame, enc_output, hat_att, enc_mask=None): for length, hyps_t in hyps_ctc_cluster.items(): ys_mask = subsequent_mask(length - 1).unsqueeze(0).cuda() ys_mask4use = ys_mask.repeat(len(hyps_t), 1, 1) l_id = [hyp_t['yseq'][:-1] for hyp_t in hyps_t] ys4use = torch.tensor(l_id).cuda() enc_output4use = enc_output.repeat(len(hyps_t), 1, 1) if "precache" not in hyps_t[0] or hyps_t[0]["precache"] is None: cache4use = None else: cache4use = [] for decode_num in range(len(hyps_t[0]["precache"])): current_cache = [] for hyp_t in hyps_t: # print(length, hyp_t["yseq"], hyp_t["cache"][0].shape, # hyp_t["cache"][2].shape, hyp_t["cache"][4].shape) current_cache.append( hyp_t["precache"][decode_num].squeeze(0)) current_cache = torch.stack(current_cache) cache4use.append(current_cache) partial_mask4use = [] for hyp_t in hyps_t: #partial_mask4use.append(torch.ones([1, len(hyp_t['last_time']), enc_mask.shape[1]]).byte()) align = hyp_t['last_time'] align_tensor = torch.tensor(align).unsqueeze(0) if enc_mask is not None: partial_mask = enc_mask[0][align_tensor] else: right_window = self.right_window partial_mask = trigger_mask(1, total_frame, align_tensor, self.left_window, right_window) partial_mask4use.append(partial_mask) partial_mask4use = torch.stack(partial_mask4use).cuda().squeeze(1) local_att_scores_b, new_cache_b = \ self.decoder.forward_one_step(ys4use, ys_mask4use, enc_output4use, partial_mask4use, cache4use) for idx, hyp_t in enumerate(hyps_t): hyp_t['tmp_cur_new_cache'] = [ new_cache_b[decode_num][idx].unsqueeze(0) for decode_num in range(len(new_cache_b)) ] hyp_t['tmp_cur_att_scores'] = local_att_scores_b[ idx].unsqueeze(0) l_minus = ' '.join(hyp_t['seq'].split()[:-1]) hat_att[l_minus] = {} hat_att[l_minus]['att_scores'] = hyp_t['tmp_cur_att_scores'] hat_att[l_minus]['cache'] = hyp_t['tmp_cur_new_cache']
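# Illustrative usage sketch (not part of the original sources): driving the streaming
# prefix decoder defined above. `model`, `feats`, `recog_args`, `train_args` and
# `char_list` are assumed to be prepared by the caller; recog_args must carry
# beam_size, penalty, ctc_weight, ctc_lm_weight, lm_weight, threshold, maxlenratio
# and minlenratio, and train_args must carry either chunk/chunk_size or
# left_window/right_window for the encoder mask.
def _example_prefix_recognize(model, feats, recog_args, train_args, char_list):
    # returns the best hypothesis string, its token ids and its joint score
    best_text, best_ids, best_score = model.prefix_recognize(
        feats, recog_args, train_args, char_list=char_list, rnnlm=None)
    return best_text, best_ids, best_score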
class E2E(ASRInterface, torch.nn.Module): """E2E module. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ @staticmethod def add_arguments(parser): """Add arguments.""" group = parser.add_argument_group("transformer model setting") group.add_argument( "--transformer-init", type=str, default="pytorch", choices=[ "pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal", ], help="how to initialize transformer parameters", ) group.add_argument( "--transformer-input-layer", type=str, default="conv2d", choices=["conv2d", "linear", "embed"], help="transformer input layer type", ) group.add_argument( "--transformer-attn-dropout-rate", default=None, type=float, help= "dropout in transformer attention. use --dropout-rate if None is set", ) group.add_argument( "--transformer-lr", default=10.0, type=float, help="Initial value of learning rate", ) group.add_argument( "--transformer-warmup-steps", default=25000, type=int, help="optimizer warmup steps", ) group.add_argument( "--transformer-length-normalized-loss", default=True, type=strtobool, help="normalize loss by length", ) group.add_argument( "--transformer-encoder-selfattn-layer-type", type=str, default="selfattn", choices=[ "selfattn", "rel_selfattn", "lightconv", "lightconv2d", "dynamicconv", "dynamicconv2d", "light-dynamicconv2d", ], help="transformer encoder self-attention layer type", ) group.add_argument( "--transformer-decoder-selfattn-layer-type", type=str, default="selfattn", choices=[ "selfattn", "lightconv", "lightconv2d", "dynamicconv", "dynamicconv2d", "light-dynamicconv2d", ], help="transformer decoder self-attention layer type", ) # Lightweight/Dynamic convolution related parameters. # See https://arxiv.org/abs/1912.11793v2 # and https://arxiv.org/abs/1901.10430 for detail of the method. # Configurations used in the first paper are in # egs/{csj, librispeech}/asr1/conf/tuning/ld_conv/ group.add_argument( "--wshare", default=4, type=int, help="Number of parameter shargin for lightweight convolution", ) group.add_argument( "--ldconv-encoder-kernel-length", default="21_23_25_27_29_31_33_35_37_39_41_43", type=str, help="kernel size for lightweight/dynamic convolution: " 'Encoder side. For example, "21_23_25" means kernel length 21 for ' "First layer, 23 for Second layer and so on.", ) group.add_argument( "--ldconv-decoder-kernel-length", default="11_13_15_17_19_21", type=str, help="kernel size for lightweight/dynamic convolution: " 'Decoder side. 
For example, "21_23_25" means kernel length 21 for ' "First layer, 23 for Second layer and so on.", ) group.add_argument( "--ldconv-usebias", type=strtobool, default=False, help="use bias term in lightweight/dynamic convolution", ) group.add_argument( "--dropout-rate", default=0.0, type=float, help="Dropout rate for the encoder", ) # Encoder group.add_argument( "--elayers", default=4, type=int, help="Number of encoder layers (for shared recognition part " "in multi-speaker asr mode)", ) group.add_argument( "--eunits", "-u", default=300, type=int, help="Number of encoder hidden units", ) # Attention group.add_argument( "--adim", default=320, type=int, help="Number of attention transformation dimensions", ) group.add_argument( "--aheads", default=4, type=int, help="Number of heads for multi head attention", ) # Decoder group.add_argument("--dlayers", default=1, type=int, help="Number of decoder layers") group.add_argument("--dunits", default=320, type=int, help="Number of decoder hidden units") return parser @property def attention_plot_class(self): """Return PlotAttentionReport.""" return PlotAttentionReport def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) # fill missing arguments for compatibility args = fill_missing_args(args, self.add_arguments) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, selfattention_layer_type=args. transformer_encoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_encoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, ) if args.mtlalpha < 1: self.decoder = Decoder( odim=odim, selfattention_layer_type=args. 
transformer_decoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_decoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, ) else: self.decoder = None self.blank = 0 self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = get_subsample(args, mode="asr", arch="transformer") self.reporter = Reporter() # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss, ) # self.verbose = args.verbose self.reset_parameters(args) self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None if args.report_cer or args.report_wer: self.error_calculator = ErrorCalculator( args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer, ) else: self.error_calculator = None self.rnnlm = None def reset_parameters(self, args): """Initialize parameters.""" # initialize parameters initialize(self, args.transformer_init) def forward(self, xs_pad, ilens, ys_pad): """E2E forward. :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of source sequences (B) :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) :return: ctc loass value :rtype: torch.Tensor :return: attention loss value :rtype: torch.Tensor :return: accuracy in attention decoder :rtype: float """ # 1. forward encoder xs_pad = xs_pad[:, :max(ilens)] # for data parallel src_mask = make_non_pad_mask(ilens.tolist()).to( xs_pad.device).unsqueeze(-2) hs_pad, hs_mask = self.encoder(xs_pad, src_mask) self.hs_pad = hs_pad # 2. forward decoder if self.decoder is not None: ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_mask = target_mask(ys_in_pad, self.ignore_id) pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) self.pred_pad = pred_pad # 3. compute attention loss loss_att = self.criterion(pred_pad, ys_out_pad) self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id) else: loss_att = None self.acc = None # TODO(karita) show predicted text # TODO(karita) calculate these stats cer_ctc = None if self.mtlalpha == 0.0: loss_ctc = None else: batch_size = xs_pad.size(0) hs_len = hs_mask.view(batch_size, -1).sum(1) loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad) if self.error_calculator is not None: ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) # 5. 
compute cer/wer if self.training or self.error_calculator is None or self.mtlalpha == 1.0: cer, wer = None, None else: ys_hat = pred_pad.argmax(dim=-1) cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) # copyied from e2e_asr alpha = self.mtlalpha if alpha == 0: self.loss = loss_att loss_att_data = float(loss_att) loss_ctc_data = None elif alpha == 1: self.loss = loss_ctc loss_att_data = None loss_ctc_data = float(loss_ctc) else: self.loss = alpha * loss_ctc + (1 - alpha) * loss_att loss_att_data = float(loss_att) loss_ctc_data = float(loss_ctc) loss_data = float(self.loss) if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data) else: logging.warning("loss (=%f) is not correct", loss_data) return self.loss def scorers(self): """Scorers.""" return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos)) def encode(self, x): """Encode acoustic features. :param ndarray x: source acoustic feature (T, D) :return: encoder outputs :rtype: torch.Tensor """ self.eval() x = torch.as_tensor(x).unsqueeze(0) enc_output, _ = self.encoder(x, None) return enc_output.squeeze(0) def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False): """Recognize input speech. :param ndnarray x: input acoustic feature (B, T, D) or (T, D) :param Namespace recog_args: argment Namespace contraining options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list """ enc_output = self.encode(x).unsqueeze(0) if self.mtlalpha == 1.0: recog_args.ctc_weight = 1.0 logging.info("Set to pure CTC decoding mode.") if self.mtlalpha > 0 and recog_args.ctc_weight == 1.0: from itertools import groupby lpz = self.ctc.argmax(enc_output) collapsed_indices = [x[0] for x in groupby(lpz[0])] hyp = [ x for x in filter(lambda x: x != self.blank, collapsed_indices) ] nbest_hyps = [{"score": 0.0, "yseq": hyp}] if recog_args.beam_size > 1: raise NotImplementedError( "Pure CTC beam search is not implemented.") # TODO(hirofumi0810): Implement beam search return nbest_hyps elif self.mtlalpha > 0 and recog_args.ctc_weight > 0.0: lpz = self.ctc.log_softmax(enc_output) lpz = lpz.squeeze(0) else: lpz = None h = enc_output.squeeze(0) logging.info("input lengths: " + str(h.size(0))) # search parms beam = recog_args.beam_size penalty = recog_args.penalty ctc_weight = recog_args.ctc_weight # preprare sos y = self.sos vy = h.new_zeros(1).long() if recog_args.maxlenratio == 0: maxlen = h.shape[0] else: # maxlen >= 1 maxlen = max(1, int(recog_args.maxlenratio * h.size(0))) minlen = int(recog_args.minlenratio * h.size(0)) logging.info("max output length: " + str(maxlen)) logging.info("min output length: " + str(minlen)) # initialize hypothesis if rnnlm: hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None} else: hyp = {"score": 0.0, "yseq": [y]} if lpz is not None: ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy) hyp["ctc_state_prev"] = ctc_prefix_score.initial_state() hyp["ctc_score_prev"] = 0.0 if ctc_weight != 1.0: # pre-pruning based on attention scores ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO)) else: ctc_beam = lpz.shape[-1] hyps = [hyp] ended_hyps = [] import six traced_decoder = None for i in six.moves.range(maxlen): logging.debug("position " + str(i)) hyps_best_kept = [] for hyp in hyps: vy[0] = hyp["yseq"][i] # get nbest local scores and their ids ys_mask = subsequent_mask(i + 
1).unsqueeze(0) ys = torch.tensor(hyp["yseq"]).unsqueeze(0) # FIXME: jit does not match non-jit result if use_jit: if traced_decoder is None: traced_decoder = torch.jit.trace( self.decoder.forward_one_step, (ys, ys_mask, enc_output)) local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0] else: local_att_scores = self.decoder.forward_one_step( ys, ys_mask, enc_output)[0] if rnnlm: rnnlm_state, local_lm_scores = rnnlm.predict( hyp["rnnlm_prev"], vy) local_scores = (local_att_scores + recog_args.lm_weight * local_lm_scores) else: local_scores = local_att_scores if lpz is not None: local_best_scores, local_best_ids = torch.topk( local_att_scores, ctc_beam, dim=1) ctc_scores, ctc_states = ctc_prefix_score( hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"]) local_scores = ( 1.0 - ctc_weight) * local_att_scores[:, local_best_ids[ 0]] + ctc_weight * torch.from_numpy( ctc_scores - hyp["ctc_score_prev"]) if rnnlm: local_scores += (recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]) local_best_scores, joint_best_ids = torch.topk( local_scores, beam, dim=1) local_best_ids = local_best_ids[:, joint_best_ids[0]] else: local_best_scores, local_best_ids = torch.topk( local_scores, beam, dim=1) for j in six.moves.range(beam): new_hyp = {} new_hyp["score"] = hyp["score"] + float( local_best_scores[0, j]) new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"])) new_hyp["yseq"][:len(hyp["yseq"])] = hyp["yseq"] new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j]) if rnnlm: new_hyp["rnnlm_prev"] = rnnlm_state if lpz is not None: new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[ 0, j]] new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[ 0, j]] # will be (2 x beam) hyps at most hyps_best_kept.append(new_hyp) hyps_best_kept = sorted(hyps_best_kept, key=lambda x: x["score"], reverse=True)[:beam] # sort and get nbest hyps = hyps_best_kept logging.debug("number of pruned hypothes: " + str(len(hyps))) if char_list is not None: logging.debug( "best hypo: " + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])) # add eos in the final loop to avoid that there are no ended hyps if i == maxlen - 1: logging.info("adding <eos> in the last postion in the loop") for hyp in hyps: hyp["yseq"].append(self.eos) # add ended hypothes to a final list, and removed them from current hypothes # (this will be a probmlem, number of hyps < beam) remained_hyps = [] for hyp in hyps: if hyp["yseq"][-1] == self.eos: # only store the sequence that has more than minlen outputs # also add penalty if len(hyp["yseq"]) > minlen: hyp["score"] += (i + 1) * penalty if rnnlm: # Word LM needs to add final <eos> score hyp["score"] += recog_args.lm_weight * rnnlm.final( hyp["rnnlm_prev"]) ended_hyps.append(hyp) else: remained_hyps.append(hyp) # end detection if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0: logging.info("end detected at %d", i) break hyps = remained_hyps if len(hyps) > 0: logging.debug("remeined hypothes: " + str(len(hyps))) else: logging.info("no hypothesis. 
Finish decoding.") break if char_list is not None: for hyp in hyps: logging.debug( "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])) logging.debug("number of ended hypothes: " + str(len(ended_hyps))) nbest_hyps = sorted( ended_hyps, key=lambda x: x["score"], reverse=True)[:min(len(ended_hyps), recog_args.nbest)] # check number of hypotheis if len(nbest_hyps) == 0: logging.warning("there is no N-best results, perform recognition " "again with smaller minlenratio.") # should copy becasuse Namespace will be overwritten globally recog_args = Namespace(**vars(recog_args)) recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1) return self.recognize(x, recog_args, char_list, rnnlm) logging.info("total log probability: " + str(nbest_hyps[0]["score"])) logging.info("normalized log probability: " + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))) return nbest_hyps def calculate_all_attentions(self, xs_pad, ilens, ys_pad): """E2E attention calculation. :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of input sequences (B) :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) :return: attention weights with the following shape, 1) multi-head case => attention weights (B, H, Lmax, Tmax), 2) other case => attention weights (B, Lmax, Tmax). :rtype: float ndarray """ with torch.no_grad(): self.forward(xs_pad, ilens, ys_pad) ret = dict() for name, m in self.named_modules(): if (isinstance(m, MultiHeadedAttention) or isinstance(m, DynamicConvolution) or isinstance(m, RelPositionMultiHeadedAttention)): ret[name] = m.attn.cpu().numpy() if isinstance(m, DynamicConvolution2D): ret[name + "_time"] = m.attn_t.cpu().numpy() ret[name + "_freq"] = m.attn_f.cpu().numpy() return ret
class Transformer(TTSInterface, torch.nn.Module): """Transformer for TTS - Reference: Neural Speech Synthesis with Transformer Network (https://arxiv.org/pdf/1809.08895.pdf) :param int idim: dimension of the inputs :param int odim: dimension of the outputs :param Namespace args: arguments containing the following attributes (int) embed_dim: dimension of character embedding (int) eprenet_conv_layers: number of encoder prenet convolution layers (int) eprenet_conv_chans: number of encoder prenet convolution channels (int) eprenet_conv_filts: filter size of encoder prenet convolution (int) dprenet_layers: number of decoder prenet layers (int) dprenet_units: number of decoder prenet hidden units (int) elayers: number of encoder layers (int) eunits: number of encoder hidden units (int) adim: number of attention transformation dimensions (int) aheads: number of heads for multi head attention (int) dlayers: number of decoder layers (int) dunits: number of decoder hidden units (int) postnet_layers: number of postnet layers (int) postnet_chans: number of postnet channels (int) postnet_filts: filter size of postnet (bool) use_scaled_pos_enc: whether to use trainable scaled positional encoding instead of the fixed scale one (bool) use_batch_norm: whether to use batch normalization (str) transformer_init: how to initialize transformer parameters (float) transformer_lr: initial value of learning rate (int) transformer_warmup_steps: optimizer warmup steps (float) transformer_attn_dropout_rate: dropout in transformer attention. use dropout_rate if None is set (float) eprenet_dropout_rate: dropout rate in encoder prenet. use dropout_rate if None is set (float) dprenet_dropout_rate: dropout rate in decoder prenet. use dropout_rate if None is set (float) postnet_dropout_rate: dropout rate in postnet.
use dropout_rate if none is set (float) dropout_rate: dropout rate in the other module (bool) use_masking: whether to use masking in calculation of loss (float) bce_pos_weight: positive sample weight in bce calculation (only for use_masking=true) (str) loss_type: how to calculate loss """ @staticmethod def add_arguments(parser): group = parser.add_argument_group("transformer model setting") # network structure related group.add_argument("--embed-dim", default=512, type=int, help="Dimension of character embedding") group.add_argument("--eprenet-conv-layers", default=3, type=int, help="Number of encoder prenet convolution layers") group.add_argument( "--eprenet-conv-chans", default=512, type=int, help="Number of encoder prenet convolution channels") group.add_argument("--eprenet-conv-filts", default=5, type=int, help="Filter size of encoder prenet convolution") group.add_argument("--dprenet-layers", default=2, type=int, help="Number of decoder prenet layers") group.add_argument("--dprenet-units", default=256, type=int, help="Number of decoder prenet hidden units") group.add_argument("--elayers", default=3, type=int, help="Number of encoder layers") group.add_argument("--eunits", default=2048, type=int, help="Number of encoder hidden units") group.add_argument( "--adim", default=512, type=int, help="Number of attention transformation dimensions") group.add_argument("--aheads", default=4, type=int, help="Number of heads for multi head attention") group.add_argument("--dlayers", default=3, type=int, help="Number of decoder layers") group.add_argument("--dunits", default=2048, type=int, help="Number of decoder hidden units") group.add_argument("--postnet-layers", default=5, type=int, help="Number of postnet layers") group.add_argument("--postnet-chans", default=512, type=int, help="Number of postnet channels") group.add_argument("--postnet-filts", default=5, type=int, help="Filter size of postnet") group.add_argument( "--use-scaled-pos-enc", default=True, type=strtobool, help= "use trainable scaled positional encoding instead of the fixed scale one." 
) group.add_argument("--use-batch-norm", default=True, type=strtobool, help="Whether to use batch normalization") group.add_argument( "--encoder-normalize-before", default=True, type=strtobool, help="Whether to apply layer norm before encoder block") group.add_argument( "--decoder-normalize-before", default=True, type=strtobool, help="Whether to apply layer norm before decoder block") group.add_argument( "--encoder-concat-after", default=False, type=strtobool, help= "Whether to concatenate attention layer\"s input and output in encoder" ) group.add_argument( "--decoder-concat-after", default=False, type=strtobool, help= "Whether to concatenate attention layer\"s input and output in decoder" ) parser.add_argument("--reduction-factor", default=1, type=int, help="Reduction factor") # training related group.add_argument("--transformer-init", type=str, default="pytorch", choices=[ "pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal" ], help="how to initialize transformer parameters") group.add_argument( "--initial-encoder-alpha", type=float, default=1.0, help="initial alpha value in encoder\"s ScaledPositionalEncoding") group.add_argument( "--initial-decoder-alpha", type=float, default=1.0, help="initial alpha value in decoder\"s ScaledPositionalEncoding") group.add_argument("--transformer-lr", default=1.0, type=float, help="Initial value of learning rate") group.add_argument("--transformer-warmup-steps", default=4000, type=int, help="optimizer warmup steps") group.add_argument( "--transformer-enc-dropout-rate", default=0.1, type=float, help="dropout rate for transformer encoder except for attention") group.add_argument( "--transformer-enc-positional-dropout-rate", default=0.1, type=float, help="dropout rate for transformer encoder positional encoding") group.add_argument( "--transformer-enc-attn-dropout-rate", default=0.0, type=float, help="dropout rate for transformer encoder self-attention") group.add_argument( "--transformer-dec-dropout-rate", default=0.1, type=float, help= "dropout rate for transformer decoder except for attention and pos encoding" ) group.add_argument( "--transformer-dec-positional-dropout-rate", default=0.1, type=float, help="dropout rate for transformer decoder positional encoding") group.add_argument( "--transformer-dec-attn-dropout-rate", default=0.3, type=float, help="dropout rate for transformer decoder self-attention") group.add_argument( "--transformer-enc-dec-attn-dropout-rate", default=0.0, type=float, help="dropout rate for transformer encoder-decoder attention") group.add_argument("--eprenet-dropout-rate", default=0.1, type=float, help="dropout rate in encoder prenet") group.add_argument("--dprenet-dropout-rate", default=0.5, type=float, help="dropout rate in decoder prenet") group.add_argument("--postnet-dropout-rate", default=0.1, type=float, help="dropout rate in postnet") # loss related group.add_argument( "--use-masking", default=True, type=strtobool, help="Whether to use masking in calculation of loss") group.add_argument("--loss-type", default="L1", choices=["L1", "L2", "L1+L2"], help="How to calc loss") group.add_argument( "--bce-pos-weight", default=5.0, type=float, help= "Positive sample weight in BCE calculation (only for use-masking=True)" ) group.add_argument("--use-guided-attn-loss", default=False, type=strtobool, help="Whether to use guided attention loss") group.add_argument("--guided-attn-loss-sigma", default=0.4, type=float, help="Sigma in guided attention loss") group.add_argument( 
"--num-heads-applied-guided-attn", default=2, type=int, help= "Number of heads in each layer to be applied guided attention loss" "if set -1, all of the heads will be applied.") group.add_argument( "--num-layers-applied-guided-attn", default=2, type=int, help="Number of layers to be applied guided attention loss" "if set -1, all of the layers will be applied.") group.add_argument( "--modules-applied-guided-attn", type=str, nargs="+", default=["encoder", "decoder", "encoder-decoder"], help="Module name list to be applied guided attention loss") return parser @property def attention_plot_class(self): return TTSPlot def __init__(self, idim, odim, args): # initialize base classes TTSInterface.__init__(self) torch.nn.Module.__init__(self) # store hyperparameters self.idim = idim self.odim = odim self.use_scaled_pos_enc = args.use_scaled_pos_enc self.reduction_factor = args.reduction_factor self.loss_type = args.loss_type self.use_guided_attn_loss = args.use_guided_attn_loss if self.use_guided_attn_loss: if args.num_layers_applied_guided_attn == -1: self.num_layers_applied_guided_attn = args.elayers else: self.num_layers_applied_guided_attn = args.num_layers_applied_guided_attn if args.num_heads_applied_guided_attn == -1: self.num_heads_applied_guided_attn = args.aheads else: self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn self.modules_applied_guided_attn = args.modules_applied_guided_attn # use idx 0 as padding idx padding_idx = 0 # get positional encoding class pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding # define transformer encoder if args.eprenet_conv_layers != 0: # encoder prenet encoder_input_layer = torch.nn.Sequential( EncoderPrenet(idim=idim, embed_dim=args.embed_dim, elayers=0, econv_layers=args.eprenet_conv_layers, econv_chans=args.eprenet_conv_chans, econv_filts=args.eprenet_conv_filts, use_batch_norm=args.use_batch_norm, dropout_rate=args.eprenet_dropout_rate, padding_idx=padding_idx), torch.nn.Linear(args.eprenet_conv_chans, args.adim)) else: encoder_input_layer = torch.nn.Embedding(num_embeddings=idim, embedding_dim=args.adim, padding_idx=padding_idx) self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=encoder_input_layer, dropout_rate=args.transformer_enc_dropout_rate, positional_dropout_rate=args. transformer_enc_positional_dropout_rate, attention_dropout_rate=args.transformer_enc_attn_dropout_rate, pos_enc_class=pos_enc_class, normalize_before=args.encoder_normalize_before, concat_after=args.encoder_concat_after) # define transformer decoder if args.dprenet_layers != 0: # decoder prenet decoder_input_layer = torch.nn.Sequential( DecoderPrenet(idim=odim, n_layers=args.dprenet_layers, n_units=args.dprenet_units, dropout_rate=args.dprenet_dropout_rate), torch.nn.Linear(args.dprenet_units, args.adim)) else: decoder_input_layer = "linear" self.decoder = Decoder( odim=-1, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.transformer_dec_dropout_rate, positional_dropout_rate=args. transformer_dec_positional_dropout_rate, self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate, src_attention_dropout_rate=args. 
transformer_enc_dec_attn_dropout_rate, input_layer=decoder_input_layer, use_output_layer=False, pos_enc_class=pos_enc_class, normalize_before=args.decoder_normalize_before, concat_after=args.decoder_concat_after) # define final projection self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor) self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor) # define postnet self.postnet = None if args.postnet_layers == 0 else Postnet( idim=idim, odim=odim, n_layers=args.postnet_layers, n_chans=args.postnet_chans, n_filts=args.postnet_filts, use_batch_norm=args.use_batch_norm, dropout_rate=args.postnet_dropout_rate) # define loss function self.criterion = TransformerLoss(args) if self.use_guided_attn_loss: self.attn_criterion = GuidedMultiHeadAttentionLoss( args.guided_attn_loss_sigma) # initialize parameters self._reset_parameters(args) def _reset_parameters(self, args): if self.use_scaled_pos_enc: # alpha in scaled positional encoding init self.encoder.embed[-1].alpha.data = torch.tensor( args.initial_encoder_alpha) self.decoder.embed[-1].alpha.data = torch.tensor( args.initial_decoder_alpha) if args.transformer_init == "pytorch": return # weight init for p in self.parameters(): if p.dim() > 1: if args.transformer_init == "xavier_uniform": torch.nn.init.xavier_uniform_(p.data) elif args.transformer_init == "xavier_normal": torch.nn.init.xavier_normal_(p.data) elif args.transformer_init == "kaiming_uniform": torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu") elif args.transformer_init == "kaiming_normal": torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu") else: raise ValueError("Unknown initialization: " + args.transformer_init) # bias init for p in self.parameters(): if p.dim() == 1: p.data.zero_() # reset some modules with default init for m in self.modules(): if isinstance(m, (torch.nn.Embedding, LayerNorm)): m.reset_parameters() def _add_first_frame_and_remove_last_frame(self, ys): ys_in = torch.cat( [ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1) return ys_in def forward(self, xs, ilens, ys, labels, olens, *args, **kwargs): """Transformer forward computation :param torch.Tensor xs: batch of padded character ids (B, Tmax) :param torch.Tensor ilens: list of lengths of each input batch (B) :param torch.Tensor ys: batch of padded target features (B, Lmax, odim) :param torch.Tensor olens: batch of the lengths of each target (B) :return: loss value :rtype: torch.Tensor """ # remove unnecessary padded part (for multi-gpus) max_ilen = max(ilens) max_olen = max(olens) if max_ilen != xs.shape[1]: xs = xs[:, :max_ilen] if max_olen != ys.shape[1]: ys = ys[:, :max_olen] labels = labels[:, :max_olen] # forward encoder x_masks = self._source_mask(ilens) hs, _ = self.encoder(xs, x_masks) # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim) if self.reduction_factor > 1: ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor] olens_in = olens.new( [olen // self.reduction_factor for olen in olens]) else: ys_in, olens_in = ys, olens # add first zero frame and remove last frame for auto-regressive ys_in = self._add_first_frame_and_remove_last_frame(ys_in) # forward decoder y_masks = self._target_mask(olens_in) xy_masks = self._source_to_target_mask(ilens, olens_in) zs, _ = self.decoder(ys_in, y_masks, hs, xy_masks) # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim) before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) # (B, Lmax//r, r) -> (B, Lmax//r * r) logits = self.prob_out(zs).view(zs.size(0), -1) # postnet 
-> (B, Lmax//r * r, odim) if self.postnet is None: after_outs = before_outs else: after_outs = before_outs + self.postnet(before_outs.transpose( 1, 2)).transpose(1, 2) # modifiy mod part of groundtruth if self.reduction_factor > 1: olens = olens.new( [olen - olen % self.reduction_factor for olen in olens]) max_olen = max(olens) ys = ys[:, :max_olen] labels = labels[:, :max_olen] labels[:, -1] = 1.0 # make sure at least one frame has 1 # caluculate loss values l1_loss, l2_loss, bce_loss = self.criterion(after_outs, before_outs, logits, ys, labels, olens) if self.loss_type == "L1": loss = l1_loss + bce_loss elif self.loss_type == "L2": loss = l2_loss + bce_loss elif self.loss_type == "L1+L2": loss = l1_loss + l2_loss + bce_loss else: raise ValueError("unknown --loss-type " + self.loss_type) report_keys = [ { "l1_loss": l1_loss.item() }, { "l2_loss": l2_loss.item() }, { "bce_loss": bce_loss.item() }, { "loss": loss.item() }, ] # calculate guided attention loss if self.use_guided_attn_loss: # calculate for encoder if "encoder" in self.modules_applied_guided_attn: att_ws = [] for idx, layer_idx in enumerate( reversed(range(len(self.decoder.decoders)))): att_ws += [ self.encoder.encoders[layer_idx].self_attn. attn[:, :self.num_heads_applied_guided_attn] ] if idx + 1 == self.num_layers_applied_guided_attn: break att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_in, T_in) enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens) loss = loss + enc_attn_loss report_keys += [{"enc_attn_loss": enc_attn_loss.item()}] # calculate for decoder if "decoder" in self.modules_applied_guided_attn: att_ws = [] for idx, layer_idx in enumerate( reversed(range(len(self.decoder.decoders)))): att_ws += [ self.decoder.decoders[layer_idx].self_attn. attn[:, :self.num_heads_applied_guided_attn] ] if idx + 1 == self.num_layers_applied_guided_attn: break att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_out, T_out) dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in) loss = loss + dec_attn_loss report_keys += [{"dec_attn_loss": dec_attn_loss.item()}] # calculate for encoder-decoder if "encoder-decoder" in self.modules_applied_guided_attn: att_ws = [] for idx, layer_idx in enumerate( reversed(range(len(self.decoder.decoders)))): att_ws += [ self.decoder.decoders[layer_idx].src_attn. 
attn[:, :self.num_heads_applied_guided_attn] ] if idx + 1 == self.num_layers_applied_guided_attn: break att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_out, T_in) enc_dec_attn_loss = self.attn_criterion( att_ws, ilens, olens_in) loss = loss + enc_dec_attn_loss report_keys += [{ "enc_dec_attn_loss": enc_dec_attn_loss.item() }] # report extra information if self.use_scaled_pos_enc: report_keys += [ { "encoder_alpha": self.encoder.embed[-1].alpha.data.item() }, { "decoder_alpha": self.decoder.embed[-1].alpha.data.item() }, ] self.reporter.report(report_keys) return loss def inference(self, x, inference_args, *args, **kwargs): """Generates the sequence of features given the sequences of characters :param torch.Tensor x: the sequence of characters (T) :param Namespace inference_args: argments containing following attributes (float) threshold: threshold in inference (float) minlenratio: minimum length ratio in inference (float) maxlenratio: maximum length ratio in inference :rtype: torch.Tensor :return: the sequence of stop probabilities (L) :rtype: torch.Tensor """ # get options threshold = inference_args.threshold minlenratio = inference_args.minlenratio maxlenratio = inference_args.maxlenratio # forward encoder xs = x.unsqueeze(0) hs, _ = self.encoder(xs, None) # set limits of length maxlen = int(hs.size(1) * maxlenratio / self.reduction_factor) minlen = int(hs.size(1) * minlenratio / self.reduction_factor) # initialize idx = 0 ys = hs.new_zeros(1, 1, self.odim) outs, probs = [], [] # forward decoder step-by-step while True: # update index idx += 1 # calculate output and stop prob at idx-th step y_masks = subsequent_mask(idx).unsqueeze(0) z = self.decoder.recognize(ys, y_masks, hs) # (B, adim) outs += [self.feat_out(z).view(self.reduction_factor, self.odim)] # [(r, odim), ...] probs += [torch.sigmoid(self.prob_out(z))[0]] # [(r), ...] 
# update next inputs ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.odim)), dim=1) # (1, idx + 1, odim) # check whether to finish generation if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen: # check mininum length if idx < minlen: continue outs = torch.cat(outs, dim=0).unsqueeze(0).transpose( 1, 2) # (L, odim) -> (1, L, odim) -> (1, odim, L) if self.postnet is not None: outs = outs + self.postnet(outs) # (1, odim, L) outs = outs.transpose(2, 1).squeeze(0) # (L, odim) probs = torch.cat(probs, dim=0) break return outs, probs def calculate_all_attentions(self, xs, ilens, ys, olens, *args, **kwargs): """Calculate attention weights :param torch.Tensor xs: batch of padded character ids (B, Tmax) :param torch.Tensor ilens: list of lengths of each input batch (B) :param torch.Tensor ys: batch of padded target features (B, Lmax, odim) :param torch.Tensor ilens: list of lengths of each output batch (B) :return: attention weights dict :rtype: dict """ with torch.no_grad(): # forward encoder x_masks = self._source_mask(ilens) hs, _ = self.encoder(xs, x_masks) # add first zero frame and remove last frame for auto-regressive ys_in = self._add_first_frame_and_remove_last_frame(ys) # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim) if self.reduction_factor > 1: ys_in = ys_in[:, self.reduction_factor - 1::self.reduction_factor] olens_in = olens.new( [olen // self.reduction_factor for olen in olens]) else: olens_in = olens # forward decoder y_masks = self._target_mask(olens_in) xy_masks = self._source_to_target_mask(ilens, olens_in) zs, _ = self.decoder(ys_in, y_masks, hs, xy_masks) # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim) before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) # postnet -> (B, Lmax//r * r, odim) if self.postnet is None: after_outs = before_outs else: after_outs = before_outs + self.postnet( before_outs.transpose(1, 2)).transpose(1, 2) # modifiy mod part of groundtruth if self.reduction_factor > 1: olens = olens.new( [olen - olen % self.reduction_factor for olen in olens]) att_ws_dict = dict() for name, m in self.named_modules(): if isinstance(m, MultiHeadedAttention): attn = m.attn.cpu().numpy() if "encoder" in name: attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())] elif "decoder" in name: if "src" in name: attn = [ a[:, :ol, :il] for a, il, ol in zip( attn, ilens.tolist(), olens_in.tolist()) ] elif "self" in name: attn = [ a[:, :l, :l] for a, l in zip(attn, olens_in.tolist()) ] else: logging.warning("unknown attention module: " + name) else: logging.warning("unknown attention module: " + name) att_ws_dict[name] = attn att_ws_dict["before_postnet_fbank"] = [ m[:l].T for m, l in zip(before_outs.cpu().numpy(), olens.tolist()) ] att_ws_dict["after_postnet_fbank"] = [ m[:l].T for m, l in zip(after_outs.cpu().numpy(), olens.tolist()) ] return att_ws_dict def _source_mask(self, ilens): """Make mask for MultiHeadedAttention using padded sequences >>> ilens = [5, 3] >>> self._source_mask(ilens) tensor([[[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]], [[1, 1, 1, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]], dtype=torch.uint8) """ x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1) def _target_mask(self, olens): """Make mask for MaskedMultiHeadedAttention using padded sequences >>> olens = [5, 3] >>> self._target_mask(olens) tensor([[[1, 0, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 1, 
0], [1, 1, 1, 1, 1]], [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]], dtype=torch.uint8) """ y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device) s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0) return y_masks.unsqueeze(-2) & s_masks & y_masks.unsqueeze(-1) def _source_to_target_mask(self, ilens, olens): """Make source to target mask for MultiHeadedAttention using padded sequences >>> ilens = [4, 2] >>> olens = [5, 3] >>> self._source_to_target_mask(ilens) tensor([[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]], [[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], dtype=torch.uint8) """ x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device) return x_masks.unsqueeze(-2) & y_masks.unsqueeze(-1) @property def base_plot_keys(self): """base key names to plot during training. keys should match what `chainer.reporter` reports if you add the key `loss`, the reporter will report `main/loss` and `validation/main/loss` values. also `loss.png` will be created as a figure visulizing `main/loss` and `validation/main/loss` values. :rtype list[str] plot_keys: base keys to plot during training """ plot_keys = ["loss", "l1_loss", "l2_loss", "bce_loss"] if self.use_scaled_pos_enc: plot_keys += ["encoder_alpha", "decoder_alpha"] if self.use_guided_attn_loss: if "encoder" in self.modules_applied_guided_attn: plot_keys += ["enc_attn_loss"] if "decoder" in self.modules_applied_guided_attn: plot_keys += ["dec_attn_loss"] if "encoder-decoder" in self.modules_applied_guided_attn: plot_keys += ["enc_dec_attn_loss"] return plot_keys
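# Illustrative usage sketch (not part of the original sources): autoregressive
# synthesis with the TTS Transformer defined above. `model` is a trained instance
# and `text_ids` a LongTensor of character ids (T,); the Namespace mirrors the
# attributes inference() reads, with example values.
def _example_tts_inference(model, text_ids):
    from argparse import Namespace
    inference_args = Namespace(threshold=0.5, minlenratio=0.0, maxlenratio=10.0)
    # outs: (L, odim) acoustic features, probs: (L,) stop-token probabilities
    outs, probs = model.inference(text_ids, inference_args)
    return outs, probs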
class E2E(ASRInterface, torch.nn.Module): @staticmethod def add_arguments(parser): group = parser.add_argument_group("transformer model setting") group.add_argument("--transformer-init", type=str, default="pytorch", choices=["pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal"], help='how to initialize transformer parameters') group.add_argument("--transformer-input-layer", type=str, default="conv2d", choices=["conv2d", "linear", "embed"], help='transformer input layer type') group.add_argument('--transformer-attn-dropout-rate', default=None, type=float, help='dropout in transformer attention. use --dropout-rate if None is set') group.add_argument('--transformer-lr', default=10.0, type=float, help='Initial value of learning rate') group.add_argument('--transformer-warmup-steps', default=25000, type=int, help='optimizer warmup steps') group.add_argument('--transformer-length-normalized-loss', default=True, type=strtobool, help='normalize loss by length') return parser @property def attention_plot_class(self): return PlotAttentionReport def __init__(self, idim, odim, args, ignore_id=-1): torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate ) self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate ) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = [1] self.reporter = Reporter() # self.lsm_weight = a self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) # self.char_list = args.char_list # self.verbose = args.verbose self.reset_parameters(args) self.recog_args = None # unused self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None def reset_parameters(self, args): if args.transformer_init == "pytorch": return # weight init for p in self.parameters(): if p.dim() > 1: if args.transformer_init == "xavier_uniform": torch.nn.init.xavier_uniform_(p.data) elif args.transformer_init == "xavier_normal": torch.nn.init.xavier_normal_(p.data) elif args.transformer_init == "kaiming_uniform": torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu") elif args.transformer_init == "kaiming_normal": torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu") else: raise ValueError("Unknown initialization: " + args.transformer_init) # bias init for p in self.parameters(): if p.dim() == 1: p.data.zero_() # reset some modules with default init for m in self.modules(): if isinstance(m, (torch.nn.Embedding, LayerNorm)): m.reset_parameters() def add_sos_eos(self, ys_pad): from espnet.nets.pytorch_backend.nets_utils import pad_list eos = ys_pad.new([self.eos]) sos = ys_pad.new([self.sos]) ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys ys_in = [torch.cat([sos, 
y], dim=0) for y in ys] ys_out = [torch.cat([y, eos], dim=0) for y in ys] return pad_list(ys_in, self.eos), pad_list(ys_out, self.ignore_id) def target_mask(self, ys_in_pad): ys_mask = ys_in_pad != self.ignore_id m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0) return ys_mask.unsqueeze(-2) & m def forward(self, xs_pad, ilens, ys_pad): '''E2E forward :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of source sequences (B) :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) :return: ctc loass value :rtype: torch.Tensor :return: attention loss value :rtype: torch.Tensor :return: accuracy in attention decoder :rtype: float ''' # forward encoder xs_pad = xs_pad[:, :max(ilens)] # for data parallel src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2) hs_pad, hs_mask = self.encoder(xs_pad, src_mask) self.hs_pad = hs_pad # forward decoder ys_in_pad, ys_out_pad = self.add_sos_eos(ys_pad) ys_mask = self.target_mask(ys_in_pad) pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) self.pred_pad = pred_pad # compute loss loss_att = self.criterion(pred_pad, ys_out_pad) self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id) # TODO(karita) show predected text # TODO(karita) calculate these stats cer, cer_ctc, wer = 0.0, 0.0, 0.0 if self.ctc is None: loss_ctc = None else: batch_size = xs_pad.size(0) hs_len = hs_mask.view(batch_size, -1).sum(1) loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad) # copyied from e2e_asr alpha = self.mtlalpha if alpha == 0: self.loss = loss_att loss_att_data = float(loss_att) loss_ctc_data = None elif alpha == 1: self.loss = loss_ctc loss_att_data = None loss_ctc_data = float(loss_ctc) else: self.loss = alpha * loss_ctc + (1 - alpha) * loss_att loss_att_data = float(loss_att) loss_ctc_data = float(loss_ctc) loss_data = float(self.loss) if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data) else: logging.warning('loss (=%f) is not correct', loss_data) return self.loss def recognize(self, feat, recog_args, char_list=None, rnnlm=None, use_jit=False): '''recognize feat :param ndnarray x: input acouctic feature (B, T, D) or (T, D) :param namespace recog_args: argment namespace contraining options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list TODO(karita): do not recompute previous attention for faster decoding ''' self.eval() feat = torch.as_tensor(feat).unsqueeze(0) enc_output, _ = self.encoder(feat, None) if recog_args.ctc_weight > 0.0: lpz = self.ctc.log_softmax(enc_output) lpz = lpz.squeeze(0) else: lpz = None h = enc_output.squeeze(0) logging.info('input lengths: ' + str(h.size(0))) # search parms beam = recog_args.beam_size penalty = recog_args.penalty ctc_weight = recog_args.ctc_weight # preprare sos y = self.sos vy = h.new_zeros(1).long() if recog_args.maxlenratio == 0: maxlen = h.shape[0] else: # maxlen >= 1 maxlen = max(1, int(recog_args.maxlenratio * h.size(0))) minlen = int(recog_args.minlenratio * h.size(0)) logging.info('max output length: ' + str(maxlen)) logging.info('min output length: ' + str(minlen)) # initialize hypothesis if rnnlm: hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None} else: hyp = {'score': 0.0, 'yseq': [y]} if lpz is not None: import 
numpy from espnet.nets.ctc_prefix_score import CTCPrefixScore ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy) hyp['ctc_state_prev'] = ctc_prefix_score.initial_state() hyp['ctc_score_prev'] = 0.0 if ctc_weight != 1.0: # pre-pruning based on attention scores from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO)) else: ctc_beam = lpz.shape[-1] hyps = [hyp] ended_hyps = [] import six traced_decoder = None for i in six.moves.range(maxlen): logging.debug('position ' + str(i)) hyps_best_kept = [] for hyp in hyps: vy.unsqueeze(1) vy[0] = hyp['yseq'][i] # get nbest local scores and their ids ys_mask = subsequent_mask(i + 1).unsqueeze(0) ys = torch.tensor(hyp['yseq']).unsqueeze(0) # FIXME: jit does not match non-jit result if use_jit: if traced_decoder is None: traced_decoder = torch.jit.trace(self.decoder.recognize, (ys, ys_mask, enc_output)) local_att_scores = traced_decoder(ys, ys_mask, enc_output) else: local_att_scores = self.decoder.recognize(ys, ys_mask, enc_output) if rnnlm: rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy) local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores else: local_scores = local_att_scores if lpz is not None: local_best_scores, local_best_ids = torch.topk( local_att_scores, ctc_beam, dim=1) ctc_scores, ctc_states = ctc_prefix_score( hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev']) local_scores = \ (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \ + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev']) if rnnlm: local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]] local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1) local_best_ids = local_best_ids[:, joint_best_ids[0]] else: local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1) for j in six.moves.range(beam): new_hyp = {} new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j]) new_hyp['yseq'] = [0] * (1 + len(hyp['yseq'])) new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq'] new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j]) if rnnlm: new_hyp['rnnlm_prev'] = rnnlm_state if lpz is not None: new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]] new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]] # will be (2 x beam) hyps at most hyps_best_kept.append(new_hyp) hyps_best_kept = sorted( hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam] # sort and get nbest hyps = hyps_best_kept logging.debug('number of pruned hypothes: ' + str(len(hyps))) if char_list is not None: logging.debug( 'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]])) # add eos in the final loop to avoid that there are no ended hyps if i == maxlen - 1: logging.info('adding <eos> in the last postion in the loop') for hyp in hyps: hyp['yseq'].append(self.eos) # add ended hypothes to a final list, and removed them from current hypothes # (this will be a probmlem, number of hyps < beam) remained_hyps = [] for hyp in hyps: if hyp['yseq'][-1] == self.eos: # only store the sequence that has more than minlen outputs # also add penalty if len(hyp['yseq']) > minlen: hyp['score'] += (i + 1) * penalty if rnnlm: # Word LM needs to add final <eos> score hyp['score'] += recog_args.lm_weight * rnnlm.final( hyp['rnnlm_prev']) ended_hyps.append(hyp) else: remained_hyps.append(hyp) # end detection from espnet.nets.e2e_asr_common import end_detect if end_detect(ended_hyps, i) and 
recog_args.maxlenratio == 0.0: logging.info('end detected at %d', i) break hyps = remained_hyps if len(hyps) > 0: logging.debug('remaining hypotheses: ' + str(len(hyps))) else: logging.info('no hypothesis. Finish decoding.') break if char_list is not None: for hyp in hyps: logging.debug( 'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]])) logging.debug('number of ended hypotheses: ' + str(len(ended_hyps))) nbest_hyps = sorted( ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), recog_args.nbest)] # check the number of hypotheses if len(nbest_hyps) == 0: logging.warning('there are no N-best results, perform recognition again with a smaller minlenratio.') # should copy because the Namespace will be overwritten globally recog_args = Namespace(**vars(recog_args)) recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1) return self.recognize(feat, recog_args, char_list, rnnlm) logging.info('total log probability: ' + str(nbest_hyps[0]['score'])) logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq']))) return nbest_hyps def calculate_all_attentions(self, xs_pad, ilens, ys_pad): '''E2E attention calculation. :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of input sequences (B) :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax) :return: attention weights with the following shape, 1) multi-head case => attention weights (B, H, Lmax, Tmax), 2) other case => attention weights (B, Lmax, Tmax). :rtype: float ndarray ''' with torch.no_grad(): self.forward(xs_pad, ilens, ys_pad) ret = dict() for name, m in self.named_modules(): if isinstance(m, MultiHeadedAttention): ret[name] = m.attn.cpu().numpy() return ret
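For reference, here is a standalone sketch of what `add_sos_eos` in the class above produces: `ys_in` gets `<sos>` prepended and is padded with `eos`, while `ys_out` gets `<eos>` appended and is padded with `ignore_id`. The `pad_list` helper is re-implemented locally, so treat this as an illustration of the behaviour rather than the ESPnet implementation.

# Sketch of the add_sos_eos / pad_list behaviour used by the decoder above.
import torch

def pad_list(xs, pad_value):
    # pad a list of 1-D tensors into a (B, Lmax) tensor filled with pad_value
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new_full((n_batch, max_len), pad_value)
    for i, x in enumerate(xs):
        pad[i, : x.size(0)] = x
    return pad

def add_sos_eos(ys_pad, sos, eos, ignore_id):
    ys = [y[y != ignore_id] for y in ys_pad]                      # strip padding
    ys_in = [torch.cat([y.new_full((1,), sos), y]) for y in ys]   # <sos> + y
    ys_out = [torch.cat([y, y.new_full((1,), eos)]) for y in ys]  # y + <eos>
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)

ys_pad = torch.tensor([[3, 4, 5], [6, 7, -1]])                    # -1 = ignore_id
ys_in, ys_out = add_sos_eos(ys_pad, sos=9, eos=9, ignore_id=-1)
# ys_in:  [[9, 3, 4, 5], [9, 6, 7, 9]]   (padded with eos)
# ys_out: [[3, 4, 5, 9], [6, 7, 9, -1]]  (padded with ignore_id)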
def __init__(self, idim, odim, args, ignore_id=-1): torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, center_len=args.transformer_encoder_center_chunk_len, left_len=args.transformer_encoder_left_chunk_len, hop_len=args.transformer_encoder_hop_len, right_len=args.transformer_encoder_right_chunk_len, abs_pos=args.transformer_encoder_abs_embed, rel_pos=args.transformer_encoder_rel_embed, use_mem=args.transformer_encoder_use_memory, att_type=args.transformer_encoder_att_type, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate) self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = [1] self.reporter = Reporter() # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) # self.verbose = args.verbose self.reset_parameters(args) self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None if args.report_cer or args.report_wer or args.mtlalpha > 0.0: from espnet.nets.e2e_asr_common import ErrorCalculator self.error_calculator = ErrorCalculator(args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer) else: self.error_calculator = None self.rnnlm = None
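This constructor wires several chunk/streaming options straight from `args` into the Encoder. A hedged sketch of how such a Namespace might be assembled is shown below; the attribute names mirror those read by the constructor above, but every value chosen here is an illustrative assumption, not a project default.

# Illustrative only: an args Namespace carrying the chunk-related fields this
# __init__ reads (values are assumptions, not project defaults).
from argparse import Namespace

chunk_args = Namespace(
    transformer_encoder_center_chunk_len=64,   # frames scored per chunk
    transformer_encoder_left_chunk_len=64,     # left-context frames
    transformer_encoder_right_chunk_len=0,     # 0 keeps the encoder causal
    transformer_encoder_hop_len=64,            # hop between consecutive chunks
    transformer_encoder_abs_embed=True,        # absolute positional embedding
    transformer_encoder_rel_embed=False,       # relative positional embedding
    transformer_encoder_use_memory=False,      # carry a memory across chunks
    transformer_encoder_att_type="chunk",      # assumed attention-type tag
)
# The remaining fields (adim, aheads, eunits, elayers, dropout rates, ...) are
# the standard transformer options shared with the other E2E variants above.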
def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer='embed', dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.pad = 0 self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = get_subsample(args, mode='mt', arch='transformer') self.reporter = Reporter() # tie source and target embeddings if args.tie_src_tgt_embedding: if idim != odim: raise ValueError( 'When using tie_src_tgt_embedding, idim and odim must be equal.' ) self.encoder.embed[0].weight = self.decoder.embed[0].weight # tie embeddings and the classifier if args.tie_classifier: self.decoder.output_layer.weight = self.decoder.embed[0].weight # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) self.normalize_length = args.transformer_length_normalized_loss # for PPL # self.verbose = args.verbose self.reset_parameters(args) self.adim = args.adim if args.report_bleu: from espnet.nets.e2e_mt_common import ErrorCalculator self.error_calculator = ErrorCalculator(args.char_list, args.sym_space, args.report_bleu) else: self.error_calculator = None self.rnnlm = None # multilingual NMT related self.multilingual = args.multilingual
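Tying `encoder.embed[0].weight`, `decoder.embed[0].weight`, and `decoder.output_layer.weight` as done above simply makes the modules share a single Parameter. A minimal sketch with plain `torch.nn` modules follows; the names are illustrative, not the ESPnet classes.

# Minimal sketch of the weight tying performed above (plain torch.nn modules).
import torch

vocab, adim = 1000, 256
src_embed = torch.nn.Embedding(vocab, adim, padding_idx=0)
tgt_embed = torch.nn.Embedding(vocab, adim, padding_idx=0)
output_layer = torch.nn.Linear(adim, vocab)

# tie source and target embeddings (requires idim == odim, as checked above)
src_embed.weight = tgt_embed.weight
# tie the target embedding and the output classifier
output_layer.weight = tgt_embed.weight

# all three modules now update the same tensor during training
assert src_embed.weight.data_ptr() == output_layer.weight.data_ptr()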
class E2E(torch.nn.Module): """E2E module. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ @staticmethod def add_arguments(parser): """Add arguments.""" group = parser.add_argument_group("transformer model setting") group.add_argument("--transformer-init", type=str, default="pytorch", choices=[ "pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal" ], help='how to initialize transformer parameters') group.add_argument("--transformer-input-layer", type=str, default="conv2d", choices=["conv2d", "linear", "embed", "custom"], help='transformer input layer type') group.add_argument("--transformer-output-layer", type=str, default="embed", choices=["embed", "linear", "conv"]) group.add_argument( '--transformer-attn-dropout-rate', default=None, type=float, help= 'dropout in transformer attention. use --dropout-rate if None is set' ) group.add_argument('--transformer-lr', default=10.0, type=float, help='Initial value of learning rate') group.add_argument('--transformer-warmup-steps', default=25000, type=int, help='optimizer warmup steps') group.add_argument('--transformer-length-normalized-loss', default=True, type=strtobool, help='normalize loss by length') group.add_argument('--dropout-rate', default=0.0, type=float, help='Dropout rate for the encoder') # Encoder group.add_argument( '--elayers', default=4, type=int, help= 'Number of encoder layers (for shared recognition part in multi-speaker asr mode)' ) group.add_argument('--eunits', '-u', default=300, type=int, help='Number of encoder hidden units') # Attention group.add_argument( '--adim', default=320, type=int, help='Number of attention transformation dimensions') group.add_argument('--aheads', default=4, type=int, help='Number of heads for multi head attention') # Decoder group.add_argument('--dlayers', default=1, type=int, help='Number of decoder layers') group.add_argument('--dunits', default=320, type=int, help='Number of decoder hidden units') return parser def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. 
:param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate) self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, input_layer=args.transformer_output_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = [1] # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) # self.verbose = args.verbose self.reset_parameters(args) self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None self.rnnlm = None def reset_parameters(self, args): """Initialize parameters.""" # initialize parameters initialize(self, args.transformer_init) def forward(self, xs_pad, ilens, ys_pad): """E2E forward. :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of source sequences (B) :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) :return: ctc loass value :rtype: torch.Tensor :return: attention loss value :rtype: torch.Tensor :return: accuracy in attention decoder :rtype: float """ # 1. forward encoder xs_pad = xs_pad[:, :max(ilens)] # for data parallel src_mask = make_non_pad_mask(ilens.tolist()).to( xs_pad.device).unsqueeze(-2) hs_pad, hs_mask = self.encoder(xs_pad, src_mask) # CTC forward ys = [y[y != self.ignore_id] for y in ys_pad] y_len = max([len(y) for y in ys]) ys_pad = ys_pad[:, :y_len] self.hs_pad = hs_pad cer_ctc = None batch_size = xs_pad.size(0) if self.mtlalpha == 0.0: loss_ctc = None else: batch_size = xs_pad.size(0) hs_len = hs_mask.view(batch_size, -1).sum(1) loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad) # trigger mask start_time = time.time() # 2. forward decoder ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_mask = target_mask(ys_in_pad, self.ignore_id) pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) self.pred_pad = pred_pad # 3. compute attention loss loss_att = self.criterion(pred_pad, ys_out_pad) self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id) # copyied from e2e_asr alpha = self.mtlalpha if alpha == 0: self.loss = loss_att loss_att_data = float(loss_att) loss_ctc_data = None elif alpha == 1: self.loss = loss_ctc loss_att_data = None loss_ctc_data = float(loss_ctc) else: self.loss = alpha * loss_ctc + (1 - alpha) * loss_att loss_att_data = float(loss_att) return self.loss, loss_ctc_data, loss_att_data, self.acc def encode(self, x, mask=None): """Encode acoustic features. 
:param ndarray x: source acoustic feature (T, D) :return: encoder outputs :rtype: torch.Tensor """ self.eval() x = torch.as_tensor(x).unsqueeze(0).cuda() if mask is not None: mask = mask.cuda() if isinstance(self.encoder.embed, EncoderConv2d): hs, _ = self.encoder.embed( x, torch.Tensor([float(x.shape[1])]).cuda()) else: hs, _ = self.encoder.embed(x, None) enc_output, _ = self.encoder.encoders(hs, mask) if self.encoder.normalize_before: enc_output = self.encoder.after_norm(enc_output) return enc_output.squeeze(0) def viterbi_decode(self, x, y, mask=None): enc_output = self.encode(x) logits = self.ctc.ctc_lo(enc_output).detach().data logit = np.array(logits.cpu().data).T align = viterbi_align(logit, y)[0] return align def ctc_decode(self, x, mask=None): enc_output = self.encode(x) logits = self.ctc.argmax(enc_output.view(1, -1, 512)).detach().data path = np.array(logits.cpu()[0]) return path def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False): """Recognize input speech. :param ndnarray x: input acoustic feature (B, T, D) or (T, D) :param Namespace recog_args: argment Namespace contraining options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list """ enc_output = self.encode(x).unsqueeze(0) if recog_args.ctc_weight > 0.0: lpz = self.ctc.log_softmax(enc_output) lpz = lpz.squeeze(0) else: lpz = None h = enc_output.squeeze(0) logging.info('input lengths: ' + str(h.size(0))) # search parms beam = recog_args.beam_size penalty = recog_args.penalty ctc_weight = recog_args.ctc_weight # preprare sos y = self.sos vy = h.new_zeros(1).long() if recog_args.maxlenratio == 0: maxlen = h.shape[0] else: # maxlen >= 1 maxlen = max(1, int(recog_args.maxlenratio * h.size(0))) minlen = int(recog_args.minlenratio * h.size(0)) logging.info('max output length: ' + str(maxlen)) logging.info('min output length: ' + str(minlen)) # initialize hypothesis if rnnlm: hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None} else: hyp = {'score': 0.0, 'yseq': [y]} if lpz is not None: import numpy from espnet.nets.ctc_prefix_score import CTCPrefixScore ctc_prefix_score = CTCPrefixScore(lpz.cpu().detach().numpy(), 0, self.eos, numpy) hyp['ctc_state_prev'] = ctc_prefix_score.initial_state() hyp['ctc_score_prev'] = 0.0 if ctc_weight != 1.0: # pre-pruning based on attention scores from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO)) else: ctc_beam = lpz.shape[-1] hyps = [hyp] ended_hyps = [] import six traced_decoder = None for i in six.moves.range(maxlen): logging.debug('position ' + str(i)) hyps_best_kept = [] for hyp in hyps: vy.unsqueeze(1) vy[0] = hyp['yseq'][i] # get nbest local scores and their ids ys_mask = subsequent_mask(i + 1).unsqueeze(0).cuda() ys = torch.tensor(hyp['yseq']).unsqueeze(0).cuda() # FIXME: jit does not match non-jit result if use_jit: if traced_decoder is None: traced_decoder = torch.jit.trace( self.decoder.forward_one_step, (ys, ys_mask, enc_output)) local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0] else: local_att_scores = self.decoder.forward_one_step( ys, ys_mask, enc_output)[0] if rnnlm: rnnlm_state, local_lm_scores = rnnlm.predict( hyp['rnnlm_prev'], vy) local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores else: local_scores = local_att_scores if lpz is not None: local_best_scores, local_best_ids = torch.topk( local_att_scores, ctc_beam, dim=1) ctc_scores, ctc_states = 
ctc_prefix_score( hyp['yseq'], local_best_ids[0].cpu(), hyp['ctc_state_prev']) local_scores = \ (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]].cpu() \ + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev']) if rnnlm: local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[ 0]].cpu() local_best_scores, joint_best_ids = torch.topk( local_scores, beam, dim=1) local_best_ids = local_best_ids[:, joint_best_ids[0]] else: local_best_scores, local_best_ids = torch.topk( local_scores, beam, dim=1) for j in six.moves.range(beam): new_hyp = {} new_hyp['score'] = hyp['score'] + float( local_best_scores[0, j]) new_hyp['yseq'] = [0] * (1 + len(hyp['yseq'])) new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq'] new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j]) if rnnlm: new_hyp['rnnlm_prev'] = rnnlm_state if lpz is not None: new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[ 0, j]] new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[ 0, j]] # will be (2 x beam) hyps at most hyps_best_kept.append(new_hyp) hyps_best_kept = sorted(hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam] # sort and get nbest hyps = hyps_best_kept logging.debug('number of pruned hypothes: ' + str(len(hyps))) if char_list is not None: logging.debug( 'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]])) # add eos in the final loop to avoid that there are no ended hyps if i == maxlen - 1: logging.info('adding <eos> in the last postion in the loop') for hyp in hyps: hyp['yseq'].append(self.eos) # add ended hypothes to a final list, and removed them from current hypothes # (this will be a probmlem, number of hyps < beam) remained_hyps = [] for hyp in hyps: if hyp['yseq'][-1] == self.eos: # only store the sequence that has more than minlen outputs # also add penalty if len(hyp['yseq']) > minlen: hyp['score'] += (i + 1) * penalty if rnnlm: # Word LM needs to add final <eos> score hyp['score'] += recog_args.lm_weight * rnnlm.final( hyp['rnnlm_prev']) ended_hyps.append(hyp) else: remained_hyps.append(hyp) # end detection from espnet.nets.e2e_asr_common import end_detect if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0: logging.info('end detected at %d', i) break hyps = remained_hyps if len(hyps) > 0: logging.debug('remeined hypothes: ' + str(len(hyps))) else: logging.info('no hypothesis. Finish decoding.') break if char_list is not None: for hyp in hyps: logging.debug( 'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]])) logging.debug('number of ended hypothes: ' + str(len(ended_hyps))) nbest_hyps = sorted( ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), recog_args.nbest)] # check number of hypotheis if len(nbest_hyps) == 0: logging.warning( 'there is no N-best results, perform recognition again with smaller minlenratio.' ) # should copy becasuse Namespace will be overwritten globally recog_args = Namespace(**vars(recog_args)) recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1) return self.recognize(x, recog_args, char_list, rnnlm) logging.info('total log probability: ' + str(nbest_hyps[0]['score'])) logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq']))) return nbest_hyps
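The pruning step above first takes the top `ctc_beam` attention candidates and then rescores them with the CTC prefix scorer. The score combination can be summarized in isolation as in the sketch below, with dummy tensors standing in for real beam-search state; the prefix scores are assumed to be per-candidate log probabilities, as in the code above.

# Sketch of the joint attention + CTC prefix rescoring used during beam search.
import torch

def rescore(local_att_scores, ctc_scores, prev_ctc_score, ctc_weight,
            lm_scores=None, lm_weight=0.0):
    # local_att_scores: (1, ctc_beam) attention log-probs of pre-pruned candidates
    # ctc_scores:       (ctc_beam,)   CTC prefix log-probs of hyp + candidate
    # prev_ctc_score:   scalar        CTC prefix log-prob of the hyp alone
    scores = (1.0 - ctc_weight) * local_att_scores \
        + ctc_weight * (torch.as_tensor(ctc_scores) - prev_ctc_score)
    if lm_scores is not None:
        scores = scores + lm_weight * lm_scores
    return scores

att = torch.log_softmax(torch.randn(1, 5), dim=-1)   # pretend ctc_beam = 5
ctc = torch.randn(5)                                  # pretend prefix scores
print(rescore(att, ctc, prev_ctc_score=-1.2, ctc_weight=0.3))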
def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = get_subsample(args, mode="asr", arch="transformer") self.reporter = Reporter() # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss, ) # self.verbose = args.verbose self.reset_parameters(args) self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None if args.report_cer or args.report_wer: self.error_calculator = ErrorCalculator( args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer, ) else: self.error_calculator = None self.rnnlm = None
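The `LabelSmoothingLoss` created in these constructors compares the decoder's log probabilities against a smoothed target distribution while masking `ignore_id` positions. The following is a generic sketch of that idea, not the ESPnet implementation, and it omits the length-normalization option controlled by `--transformer-length-normalized-loss`.

# Generic label-smoothing sketch: KL divergence against a smoothed one-hot target.
import torch
import torch.nn.functional as F

def label_smoothing_loss(logits, target, smoothing=0.1, ignore_id=-1):
    # logits: (N, odim) raw decoder outputs; target: (N,) token ids
    odim = logits.size(-1)
    log_probs = F.log_softmax(logits, dim=-1)
    with torch.no_grad():
        true_dist = torch.full_like(log_probs, smoothing / (odim - 1))
        mask = target != ignore_id
        safe_target = target.clamp(min=0)                  # avoid -1 as an index
        true_dist.scatter_(1, safe_target.unsqueeze(1), 1.0 - smoothing)
    loss = F.kl_div(log_probs, true_dist, reduction="none").sum(dim=-1)
    return (loss * mask).sum() / mask.sum()                # average over real tokens

logits = torch.randn(6, 50)
target = torch.tensor([3, 7, 1, -1, 9, 2])                 # -1 = ignored padding
print(label_smoothing_loss(logits, target))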
def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder_type = getattr(args, 'encoder_type', 'all_add') self.vbs = getattr(args, 'vbs', False) self.noise = getattr(args, 'noise_type', 'none') if self.encoder_type == 'all_add': from espnet.nets.pytorch_backend.transformer.multimodal_encoder_all_add import MultimodalEncoder elif self.encoder_type == 'proportion_add': from espnet.nets.pytorch_backend.transformer.multimodal_encoder_proportion_add import MultimodalEncoder elif self.encoder_type == 'vat': from espnet.nets.pytorch_backend.transformer.multimodal_encoder_vat import MultimodalEncoder self.encoder = MultimodalEncoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, visual_dim=args.visual_dim, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, vbs=self.vbs) self.decoder = MultimodalDecoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate) self.pad = 0 self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = get_subsample(args, mode='st', arch='transformer') self.reporter = Reporter() # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) # self.verbose = args.verbose self.adim = args.adim # submodule for ASR task self.mtlalpha = args.mtlalpha self.asr_weight = getattr(args, "asr_weight", 0.0) if self.asr_weight > 0 and args.mtlalpha < 1: self.decoder_asr = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, ) # submodule for MT task self.mt_weight = getattr(args, "mt_weight", 0.0) if self.mt_weight > 0: self.encoder_mt = Encoder( idim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, input_layer='embed', dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, padding_idx=0) self.reset_parameters(args) # place after the submodule initialization if args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None if self.asr_weight > 0 and (args.report_cer or args.report_wer): from espnet.nets.e2e_asr_common import ErrorCalculator self.error_calculator = ErrorCalculator(args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer) else: self.error_calculator = None self.rnnlm = None # multilingual E2E-ST related self.multilingual = getattr(args, "multilingual", False) self.replace_sos = getattr(args, 
"replace_sos", False) if self.multilingual: assert self.replace_sos self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.cn_encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate) self.en_encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate) # gated add module self.vectorize_lambda = args.vectorize_lambda lambda_dim = args.adim if self.vectorize_lambda else 1 # note: dropout is activated in the linear proj layer, not 1-layer lstm self.aggregation_module = EncoderAggregrationLSTM(idim=2 * args.adim, odim=lambda_dim, num_layers=1, hidden_dim=args.adim, bi=True, drop=0.1) self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = [1] self.reporter = Reporter() # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) # self.verbose = args.verbose self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None if args.report_cer or args.report_wer: from espnet.nets.e2e_asr_common import ErrorCalculator self.error_calculator = ErrorCalculator(args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer) else: self.error_calculator = None self.rnnlm = None # yzl23 config self.remove_blank_in_ctc_mode = True self.reset_parameters(args) # reset params at the last logging.warning( "Model total size: {}M, requires_grad size: {}M".format( self.count_parameters(), self.count_parameters(requires_grad=True)))
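The logging call at the end of this constructor reports `self.count_parameters()`, which is not shown in this excerpt. A plausible sketch of such a helper is given below; its behaviour is an assumption, chosen to return millions of parameters so it matches the "{}M" format string used above.

# Hedged sketch of a count_parameters helper matching the "{}M" log format above.
def count_parameters(model, requires_grad=False):
    """Return the number of parameters (optionally trainable only) in millions."""
    params = (
        p for p in model.parameters()
        if not requires_grad or p.requires_grad
    )
    return sum(p.numel() for p in params) / 1e6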
class E2E(ASRInterface, torch.nn.Module): """E2E module. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ @staticmethod def add_arguments(parser): """Add arguments.""" group = parser.add_argument_group("transformer model setting") group = add_arguments_transformer_common(group) return parser @property def attention_plot_class(self): """Return PlotAttentionReport.""" return PlotAttentionReport def get_total_subsampling_factor(self): """Get total subsampling factor.""" return self.encoder.conv_subsampling_factor * int(numpy.prod(self.subsample)) def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) # fill missing arguments for compatibility args = fill_missing_args(args, self.add_arguments) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.adim = args.adim # used for CTC (equal to d_model) self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC( odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True ) else: self.ctc = None self.intermediate_ctc_weight = args.intermediate_ctc_weight self.intermediate_ctc_layers = None if args.intermediate_ctc_layer != "": self.intermediate_ctc_layers = [ int(i) for i in args.intermediate_ctc_layer.split(",") ] self.encoder = Encoder( idim=idim, selfattention_layer_type=args.transformer_encoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_encoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, stochastic_depth_rate=args.stochastic_depth_rate, intermediate_layers=self.intermediate_ctc_layers, ctc_softmax=self.ctc.softmax if args.self_conditioning else None, conditioning_layer_dim=odim, ) if args.mtlalpha < 1: self.decoder = Decoder( odim=odim, selfattention_layer_type=args.transformer_decoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_decoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.criterion = LabelSmoothingLoss( odim, ignore_id, args.lsm_weight, args.transformer_length_normalized_loss, ) else: self.decoder = None self.criterion = None self.blank = 0 self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = get_subsample(args, mode="asr", arch="transformer") self.reporter = Reporter() self.reset_parameters(args) if args.report_cer or args.report_wer: self.error_calculator = ErrorCalculator( args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer, ) else: self.error_calculator = None self.rnnlm = None def reset_parameters(self, args): """Initialize parameters.""" # initialize parameters initialize(self, args.transformer_init) def forward(self, xs_pad, ilens, ys_pad): 
"""E2E forward. :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of source sequences (B) :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) :return: ctc loss value :rtype: torch.Tensor :return: attention loss value :rtype: torch.Tensor :return: accuracy in attention decoder :rtype: float """ # 1. forward encoder xs_pad = xs_pad[:, : max(ilens)] # for data parallel src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2) if self.intermediate_ctc_layers: hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask) else: hs_pad, hs_mask = self.encoder(xs_pad, src_mask) self.hs_pad = hs_pad # 2. forward decoder if self.decoder is not None: ys_in_pad, ys_out_pad = add_sos_eos( ys_pad, self.sos, self.eos, self.ignore_id ) ys_mask = target_mask(ys_in_pad, self.ignore_id) pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) self.pred_pad = pred_pad # 3. compute attention loss loss_att = self.criterion(pred_pad, ys_out_pad) self.acc = th_accuracy( pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id ) else: loss_att = None self.acc = None # TODO(karita) show predicted text # TODO(karita) calculate these stats cer_ctc = None loss_intermediate_ctc = 0.0 if self.mtlalpha == 0.0: loss_ctc = None else: batch_size = xs_pad.size(0) hs_len = hs_mask.view(batch_size, -1).sum(1) loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad) if not self.training and self.error_calculator is not None: ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) # for visualization if not self.training: self.ctc.softmax(hs_pad) if self.intermediate_ctc_weight > 0 and self.intermediate_ctc_layers: for hs_intermediate in hs_intermediates: # assuming hs_intermediates and hs_pad has same length / padding loss_inter = self.ctc( hs_intermediate.view(batch_size, -1, self.adim), hs_len, ys_pad ) loss_intermediate_ctc += loss_inter loss_intermediate_ctc /= len(self.intermediate_ctc_layers) # 5. compute cer/wer if self.training or self.error_calculator is None or self.decoder is None: cer, wer = None, None else: ys_hat = pred_pad.argmax(dim=-1) cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) # copied from e2e_asr alpha = self.mtlalpha if alpha == 0: self.loss = loss_att loss_att_data = float(loss_att) loss_ctc_data = None elif alpha == 1: self.loss = loss_ctc if self.intermediate_ctc_weight > 0: self.loss = ( 1 - self.intermediate_ctc_weight ) * loss_ctc + self.intermediate_ctc_weight * loss_intermediate_ctc loss_att_data = None loss_ctc_data = float(loss_ctc) else: self.loss = alpha * loss_ctc + (1 - alpha) * loss_att if self.intermediate_ctc_weight > 0: self.loss = ( (1 - alpha - self.intermediate_ctc_weight) * loss_att + alpha * loss_ctc + self.intermediate_ctc_weight * loss_intermediate_ctc ) loss_att_data = float(loss_att) loss_ctc_data = float(loss_ctc) loss_data = float(self.loss) if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): self.reporter.report( loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data ) else: logging.warning("loss (=%f) is not correct", loss_data) return self.loss def scorers(self): """Scorers.""" return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos)) def encode(self, x): """Encode acoustic features. 
:param ndarray x: source acoustic feature (T, D) :return: encoder outputs :rtype: torch.Tensor """ self.eval() x = torch.as_tensor(x).unsqueeze(0) enc_output, *_ = self.encoder(x, None) return enc_output.squeeze(0) def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False): """Recognize input speech. :param ndnarray x: input acoustic feature (B, T, D) or (T, D) :param Namespace recog_args: argment Namespace contraining options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list """ enc_output = self.encode(x).unsqueeze(0) if self.mtlalpha == 1.0: recog_args.ctc_weight = 1.0 logging.info("Set to pure CTC decoding mode.") if self.mtlalpha > 0 and recog_args.ctc_weight == 1.0: from itertools import groupby lpz = self.ctc.argmax(enc_output) collapsed_indices = [x[0] for x in groupby(lpz[0])] hyp = [x for x in filter(lambda x: x != self.blank, collapsed_indices)] nbest_hyps = [{"score": 0.0, "yseq": [self.sos] + hyp}] if recog_args.beam_size > 1: raise NotImplementedError("Pure CTC beam search is not implemented.") # TODO(hirofumi0810): Implement beam search return nbest_hyps elif self.mtlalpha > 0 and recog_args.ctc_weight > 0.0: lpz = self.ctc.log_softmax(enc_output) lpz = lpz.squeeze(0) else: lpz = None h = enc_output.squeeze(0) logging.info("input lengths: " + str(h.size(0))) # search parms beam = recog_args.beam_size penalty = recog_args.penalty ctc_weight = recog_args.ctc_weight # preprare sos y = self.sos vy = h.new_zeros(1).long() if recog_args.maxlenratio == 0: maxlen = h.shape[0] else: # maxlen >= 1 maxlen = max(1, int(recog_args.maxlenratio * h.size(0))) minlen = int(recog_args.minlenratio * h.size(0)) logging.info("max output length: " + str(maxlen)) logging.info("min output length: " + str(minlen)) # initialize hypothesis if rnnlm: hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None} else: hyp = {"score": 0.0, "yseq": [y]} if lpz is not None: ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy) hyp["ctc_state_prev"] = ctc_prefix_score.initial_state() hyp["ctc_score_prev"] = 0.0 if ctc_weight != 1.0: # pre-pruning based on attention scores ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO)) else: ctc_beam = lpz.shape[-1] hyps = [hyp] ended_hyps = [] import six traced_decoder = None for i in six.moves.range(maxlen): logging.debug("position " + str(i)) hyps_best_kept = [] for hyp in hyps: vy[0] = hyp["yseq"][i] # get nbest local scores and their ids ys_mask = subsequent_mask(i + 1).unsqueeze(0) ys = torch.tensor(hyp["yseq"]).unsqueeze(0) # FIXME: jit does not match non-jit result if use_jit: if traced_decoder is None: traced_decoder = torch.jit.trace( self.decoder.forward_one_step, (ys, ys_mask, enc_output) ) local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0] else: local_att_scores = self.decoder.forward_one_step( ys, ys_mask, enc_output )[0] if rnnlm: rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy) local_scores = ( local_att_scores + recog_args.lm_weight * local_lm_scores ) else: local_scores = local_att_scores if lpz is not None: local_best_scores, local_best_ids = torch.topk( local_att_scores, ctc_beam, dim=1 ) ctc_scores, ctc_states = ctc_prefix_score( hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"] ) local_scores = (1.0 - ctc_weight) * local_att_scores[ :, local_best_ids[0] ] + ctc_weight * torch.from_numpy( ctc_scores - hyp["ctc_score_prev"] ) if rnnlm: local_scores += ( recog_args.lm_weight * 
local_lm_scores[:, local_best_ids[0]] ) local_best_scores, joint_best_ids = torch.topk( local_scores, beam, dim=1 ) local_best_ids = local_best_ids[:, joint_best_ids[0]] else: local_best_scores, local_best_ids = torch.topk( local_scores, beam, dim=1 ) for j in six.moves.range(beam): new_hyp = {} new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j]) new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"])) new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"] new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j]) if rnnlm: new_hyp["rnnlm_prev"] = rnnlm_state if lpz is not None: new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[0, j]] new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[0, j]] # will be (2 x beam) hyps at most hyps_best_kept.append(new_hyp) hyps_best_kept = sorted( hyps_best_kept, key=lambda x: x["score"], reverse=True )[:beam] # sort and get nbest hyps = hyps_best_kept logging.debug("number of pruned hypothes: " + str(len(hyps))) if char_list is not None: logging.debug( "best hypo: " + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]]) ) # add eos in the final loop to avoid that there are no ended hyps if i == maxlen - 1: logging.info("adding <eos> in the last position in the loop") for hyp in hyps: hyp["yseq"].append(self.eos) # add ended hypothes to a final list, and removed them from current hypothes # (this will be a probmlem, number of hyps < beam) remained_hyps = [] for hyp in hyps: if hyp["yseq"][-1] == self.eos: # only store the sequence that has more than minlen outputs # also add penalty if len(hyp["yseq"]) > minlen: hyp["score"] += (i + 1) * penalty if rnnlm: # Word LM needs to add final <eos> score hyp["score"] += recog_args.lm_weight * rnnlm.final( hyp["rnnlm_prev"] ) ended_hyps.append(hyp) else: remained_hyps.append(hyp) # end detection if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0: logging.info("end detected at %d", i) break hyps = remained_hyps if len(hyps) > 0: logging.debug("remeined hypothes: " + str(len(hyps))) else: logging.info("no hypothesis. Finish decoding.") break if char_list is not None: for hyp in hyps: logging.debug( "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]]) ) logging.debug("number of ended hypothes: " + str(len(ended_hyps))) nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[ : min(len(ended_hyps), recog_args.nbest) ] # check number of hypotheis if len(nbest_hyps) == 0: logging.warning( "there is no N-best results, perform recognition " "again with smaller minlenratio." ) # should copy becasuse Namespace will be overwritten globally recog_args = Namespace(**vars(recog_args)) recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1) return self.recognize(x, recog_args, char_list, rnnlm) logging.info("total log probability: " + str(nbest_hyps[0]["score"])) logging.info( "normalized log probability: " + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"])) ) return nbest_hyps def calculate_all_attentions(self, xs_pad, ilens, ys_pad): """E2E attention calculation. 
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of input sequences (B) :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) :return: attention weights (B, H, Lmax, Tmax) :rtype: float ndarray """ self.eval() with torch.no_grad(): self.forward(xs_pad, ilens, ys_pad) ret = dict() for name, m in self.named_modules(): if ( isinstance(m, MultiHeadedAttention) or isinstance(m, DynamicConvolution) or isinstance(m, RelPositionMultiHeadedAttention) ): ret[name] = m.attn.cpu().numpy() if isinstance(m, DynamicConvolution2D): ret[name + "_time"] = m.attn_t.cpu().numpy() ret[name + "_freq"] = m.attn_f.cpu().numpy() self.train() return ret def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad): """E2E CTC probability calculation. :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax) :param torch.Tensor ilens: batch of lengths of input sequences (B) :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) :return: CTC probability (B, Tmax, vocab) :rtype: float ndarray """ ret = None if self.mtlalpha == 0: return ret self.eval() with torch.no_grad(): self.forward(xs_pad, ilens, ys_pad) for name, m in self.named_modules(): if isinstance(m, CTC) and m.probs is not None: ret = m.probs.cpu().numpy() self.train() return ret
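Both helpers above harvest cached statistics by walking `named_modules()` after a no-grad forward pass. The same pattern is shown below on a toy module that caches its attention matrix in an `attn` attribute, mirroring how `MultiHeadedAttention` is read out; the toy class is illustrative only.

# Sketch: harvesting cached attention weights by walking named_modules(),
# mirroring calculate_all_attentions above (toy module, illustrative only).
import torch

class ToyAttention(torch.nn.Module):
    def forward(self, q):
        self.attn = torch.softmax(q @ q.transpose(-2, -1), dim=-1)  # cache like MHA
        return self.attn @ q

model = torch.nn.ModuleDict({"layer0": ToyAttention(), "layer1": ToyAttention()})
x = torch.randn(2, 5, 8)
for m in model.values():
    x = m(x)

att_ws = {
    name: m.attn.detach().cpu().numpy()
    for name, m in model.named_modules()
    if hasattr(m, "attn")
}
print({k: v.shape for k, v in att_ws.items()})  # {'layer0': (2, 5, 5), 'layer1': (2, 5, 5)}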
def __init__( self, # network structure related idim: int, odim: int, embed_dim: int = 512, eprenet_conv_layers: int = 3, eprenet_conv_chans: int = 256, eprenet_conv_filts: int = 5, dprenet_layers: int = 2, dprenet_units: int = 256, elayers: int = 6, eunits: int = 1024, adim: int = 512, aheads: int = 4, dlayers: int = 6, dunits: int = 1024, postnet_layers: int = 5, postnet_chans: int = 256, postnet_filts: int = 5, positionwise_layer_type: str = "conv1d", positionwise_conv_kernel_size: int = 1, use_scaled_pos_enc: bool = True, use_batch_norm: bool = True, encoder_normalize_before: bool = True, decoder_normalize_before: bool = True, encoder_concat_after: bool = False, decoder_concat_after: bool = False, reduction_factor: int = 1, spk_embed_dim: int = None, spk_embed_integration_type: str = "add", use_gst: bool = False, gst_tokens: int = 10, gst_heads: int = 4, gst_conv_layers: int = 6, gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128), gst_conv_kernel_size: int = 3, gst_conv_stride: int = 2, gst_gru_layers: int = 1, gst_gru_units: int = 128, # training related transformer_enc_dropout_rate: float = 0.1, transformer_enc_positional_dropout_rate: float = 0.1, transformer_enc_attn_dropout_rate: float = 0.1, transformer_dec_dropout_rate: float = 0.1, transformer_dec_positional_dropout_rate: float = 0.1, transformer_dec_attn_dropout_rate: float = 0.1, transformer_enc_dec_attn_dropout_rate: float = 0.1, eprenet_dropout_rate: float = 0.5, dprenet_dropout_rate: float = 0.5, postnet_dropout_rate: float = 0.5, init_type: str = "xavier_uniform", init_enc_alpha: float = 1.0, init_dec_alpha: float = 1.0, use_masking: bool = False, use_weighted_masking: bool = False, bce_pos_weight: float = 5.0, loss_type: str = "L1", use_guided_attn_loss: bool = True, num_heads_applied_guided_attn: int = 2, num_layers_applied_guided_attn: int = 2, modules_applied_guided_attn: Sequence[str] = ("encoder-decoder"), guided_attn_loss_sigma: float = 0.4, guided_attn_loss_lambda: float = 1.0, ): """Initialize Transformer module.""" assert check_argument_types() super().__init__() # store hyperparameters self.idim = idim self.odim = odim self.eos = idim - 1 self.spk_embed_dim = spk_embed_dim self.reduction_factor = reduction_factor self.use_gst = use_gst self.use_guided_attn_loss = use_guided_attn_loss self.use_scaled_pos_enc = use_scaled_pos_enc self.loss_type = loss_type self.use_guided_attn_loss = use_guided_attn_loss if self.use_guided_attn_loss: if num_layers_applied_guided_attn == -1: self.num_layers_applied_guided_attn = elayers else: self.num_layers_applied_guided_attn = num_layers_applied_guided_attn if num_heads_applied_guided_attn == -1: self.num_heads_applied_guided_attn = aheads else: self.num_heads_applied_guided_attn = num_heads_applied_guided_attn self.modules_applied_guided_attn = modules_applied_guided_attn if self.spk_embed_dim is not None: self.spk_embed_integration_type = spk_embed_integration_type # use idx 0 as padding idx self.padding_idx = 0 # get positional encoding class pos_enc_class = ( ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding ) # define transformer encoder if eprenet_conv_layers != 0: # encoder prenet encoder_input_layer = torch.nn.Sequential( EncoderPrenet( idim=idim, embed_dim=embed_dim, elayers=0, econv_layers=eprenet_conv_layers, econv_chans=eprenet_conv_chans, econv_filts=eprenet_conv_filts, use_batch_norm=use_batch_norm, dropout_rate=eprenet_dropout_rate, padding_idx=self.padding_idx, ), torch.nn.Linear(eprenet_conv_chans, adim), ) else: 
encoder_input_layer = torch.nn.Embedding( num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx ) self.encoder = Encoder( idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers, input_layer=encoder_input_layer, dropout_rate=transformer_enc_dropout_rate, positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate, pos_enc_class=pos_enc_class, normalize_before=encoder_normalize_before, concat_after=encoder_concat_after, positionwise_layer_type=positionwise_layer_type, positionwise_conv_kernel_size=positionwise_conv_kernel_size, ) # define GST if self.use_gst: self.gst = StyleEncoder( idim=odim, # the input is mel-spectrogram gst_tokens=gst_tokens, gst_token_dim=adim, gst_heads=gst_heads, conv_layers=gst_conv_layers, conv_chans_list=gst_conv_chans_list, conv_kernel_size=gst_conv_kernel_size, conv_stride=gst_conv_stride, gru_layers=gst_gru_layers, gru_units=gst_gru_units, ) # define projection layer if self.spk_embed_dim is not None: if self.spk_embed_integration_type == "add": self.projection = torch.nn.Linear(self.spk_embed_dim, adim) else: self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim) # define transformer decoder if dprenet_layers != 0: # decoder prenet decoder_input_layer = torch.nn.Sequential( DecoderPrenet( idim=odim, n_layers=dprenet_layers, n_units=dprenet_units, dropout_rate=dprenet_dropout_rate, ), torch.nn.Linear(dprenet_units, adim), ) else: decoder_input_layer = "linear" self.decoder = Decoder( odim=odim, # odim is needed when no prenet is used attention_dim=adim, attention_heads=aheads, linear_units=dunits, num_blocks=dlayers, dropout_rate=transformer_dec_dropout_rate, positional_dropout_rate=transformer_dec_positional_dropout_rate, self_attention_dropout_rate=transformer_dec_attn_dropout_rate, src_attention_dropout_rate=transformer_enc_dec_attn_dropout_rate, input_layer=decoder_input_layer, use_output_layer=False, pos_enc_class=pos_enc_class, normalize_before=decoder_normalize_before, concat_after=decoder_concat_after, ) # define final projection self.feat_out = torch.nn.Linear(adim, odim * reduction_factor) self.prob_out = torch.nn.Linear(adim, reduction_factor) # define postnet self.postnet = ( None if postnet_layers == 0 else Postnet( idim=idim, odim=odim, n_layers=postnet_layers, n_chans=postnet_chans, n_filts=postnet_filts, use_batch_norm=use_batch_norm, dropout_rate=postnet_dropout_rate, ) ) # define loss function self.criterion = TransformerLoss( use_masking=use_masking, use_weighted_masking=use_weighted_masking, bce_pos_weight=bce_pos_weight, ) if self.use_guided_attn_loss: self.attn_criterion = GuidedMultiHeadAttentionLoss( sigma=guided_attn_loss_sigma, alpha=guided_attn_loss_lambda, ) # initialize parameters self._reset_parameters( init_type=init_type, init_enc_alpha=init_enc_alpha, init_dec_alpha=init_dec_alpha, )
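In the constructor above, `feat_out` maps the decoder dimension to `odim * reduction_factor`, so each decoder step emits `reduction_factor` output frames, and `prob_out` emits the matching stop-token logits. A small shape sketch of the corresponding projection and reshape, with illustrative sizes:

# Shape sketch for the reduction-factor projections defined above.
import torch

adim, odim, reduction_factor = 384, 80, 2
feat_out = torch.nn.Linear(adim, odim * reduction_factor)
prob_out = torch.nn.Linear(adim, reduction_factor)

zs = torch.randn(4, 100, adim)                            # decoder outputs (B, Lmax, adim)
before_outs = feat_out(zs).view(zs.size(0), -1, odim)     # (B, Lmax * r, odim) mel frames
logits = prob_out(zs).view(zs.size(0), -1)                # (B, Lmax * r) stop logits
print(before_outs.shape, logits.shape)                    # [4, 200, 80] and [4, 200]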
class E2E(MTInterface, torch.nn.Module): """E2E module. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ @staticmethod def add_arguments(parser): """Add arguments.""" group = parser.add_argument_group("transformer model setting") group.add_argument("--transformer-init", type=str, default="xavier_uniform", choices=[ "pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal" ], help='how to initialize transformer parameters') group.add_argument( '--transformer-attn-dropout-rate', default=None, type=float, help= 'dropout in transformer attention. use --dropout-rate if None is set' ) group.add_argument('--transformer-lr', default=1.0, type=float, help='Initial value of learning rate') group.add_argument('--transformer-warmup-steps', default=4000, type=int, help='optimizer warmup steps') group.add_argument('--transformer-length-normalized-loss', default=False, type=strtobool, help='normalize loss by length') group.add_argument('--dropout-rate', default=0.1, type=float, help='Dropout rate for the encoder') # Encoder group.add_argument( '--elayers', default=6, type=int, help= 'Number of encoder layers (for shared recognition part in multi-speaker asr mode)' ) group.add_argument('--eunits', '-u', default=2048, type=int, help='Number of encoder hidden units') # Attention group.add_argument( '--adim', default=256, type=int, help='Number of attention transformation dimensions') group.add_argument('--aheads', default=4, type=int, help='Number of heads for multi head attention') # Decoder group.add_argument('--dlayers', default=6, type=int, help='Number of decoder layers') group.add_argument('--dunits', default=2048, type=int, help='Number of decoder hidden units') return parser @property def attention_plot_class(self): """Return PlotAttentionReport.""" return PlotAttentionReport def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer='embed', dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.pad = 0 self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = get_subsample(args, mode='mt', arch='transformer') self.reporter = Reporter() # tie source and target emeddings if args.tie_src_tgt_embedding: if idim != odim: raise ValueError( 'When using tie_src_tgt_embedding, idim and odim must be equal.' 
) self.encoder.embed[0].weight = self.decoder.embed[0].weight # tie emeddings and the classfier if args.tie_classifier: self.decoder.output_layer.weight = self.decoder.embed[0].weight # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) self.normalize_length = args.transformer_length_normalized_loss # for PPL # self.verbose = args.verbose self.reset_parameters(args) self.adim = args.adim if args.report_bleu: from espnet.nets.e2e_mt_common import ErrorCalculator self.error_calculator = ErrorCalculator(args.char_list, args.sym_space, args.report_bleu) else: self.error_calculator = None self.rnnlm = None # multilingual NMT related self.multilingual = args.multilingual def reset_parameters(self, args): """Initialize parameters.""" # initialize parameters initialize(self, args.transformer_init) torch.nn.init.normal_(self.encoder.embed[0].weight, mean=0, std=args.adim**-0.5) torch.nn.init.constant_(self.encoder.embed[0].weight[self.pad], 0) torch.nn.init.normal_(self.decoder.embed[0].weight, mean=0, std=args.adim**-0.5) torch.nn.init.constant_(self.decoder.embed[0].weight[self.pad], 0) def forward(self, xs_pad, ilens, ys_pad): """E2E forward. :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax) :param torch.Tensor ilens: batch of lengths of source sequences (B) :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) :rtype: torch.Tensor :return: attention loss value :rtype: torch.Tensor :return: accuracy in attention decoder :rtype: float """ # 1. forward encoder xs_pad = xs_pad[:, :max(ilens)] # for data parallel src_mask = (~make_pad_mask(ilens.tolist())).to( xs_pad.device).unsqueeze(-2) xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad) hs_pad, hs_mask = self.encoder(xs_pad, src_mask) self.hs_pad = hs_pad # 2. forward decoder ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_mask = target_mask(ys_in_pad, self.ignore_id) pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) self.pred_pad = pred_pad # 3. compute attention loss loss = self.criterion(pred_pad, ys_out_pad) self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id) # TODO(karita) show predicted text # TODO(karita) calculate these stats # 5. compute bleu if self.training or self.error_calculator is None: bleu = 0.0 else: ys_hat = pred_pad.argmax(dim=-1) bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) # copyied from e2e_mt self.loss = loss loss_data = float(self.loss) if self.normalize_length: self.ppl = np.exp(loss_data) else: ys_out_pad = ys_out_pad.view(-1) ignore = ys_out_pad == self.ignore_id # (B,) total = len(ys_out_pad) - ignore.sum().item() self.ppl = np.exp(loss_data * ys_out_pad.size(0) / total) if not math.isnan(loss_data): self.reporter.report(loss_data, self.acc, self.ppl, bleu) else: logging.warning('loss (=%f) is not correct', loss_data) return self.loss def scorers(self): """Scorers.""" return dict(decoder=self.decoder) def encode(self, xs): """Encode source sentences.""" self.eval() xs = torch.as_tensor(xs).unsqueeze(0) enc_output, _ = self.encoder(xs, None) return enc_output.squeeze(0) def target_forcing(self, xs_pad, ys_pad=None, tgt_lang=None): """Prepend target language IDs to source sentences for multilingual NMT. These tags are prepended in source/target sentences as pre-processing. 
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax) :return: source text without language IDs :rtype: torch.Tensor :return: target text without language IDs :rtype: torch.Tensor :return: target language IDs :rtype: torch.Tensor (B, 1) """ if self.multilingual: xs_pad = xs_pad[:, 1:] # remove source language IDs here if ys_pad is not None: # remove language ID in the beginning lang_ids = ys_pad[:, 0].unsqueeze(1) ys_pad = ys_pad[:, 1:] elif tgt_lang is not None: lang_ids = xs_pad.new_zeros(xs_pad.size(0), 1).fill_(tgt_lang) else: raise ValueError("Set ys_pad or tgt_lang.") # prepend target language ID to source sentences xs_pad = torch.cat([lang_ids, xs_pad], dim=1) return xs_pad, ys_pad def translate(self, x, trans_args, char_list=None, rnnlm=None, use_jit=False): """Translate source text. :param list x: input source text feature (T,) :param Namespace trans_args: argment Namespace contraining options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list """ self.eval( ) # NOTE: this is important because self.encode() is not used assert isinstance(x, list) # make a utt list (1) to use the same interface for encoder if self.multilingual: x = to_device( self, torch.from_numpy( np.fromiter(map(int, x[0][1:]), dtype=np.int64))) else: x = to_device( self, torch.from_numpy(np.fromiter(map(int, x[0]), dtype=np.int64))) xs_pad = x.unsqueeze(0) tgt_lang = None if trans_args.tgt_lang: tgt_lang = char_list.index(trans_args.tgt_lang) xs_pad, _ = self.target_forcing(xs_pad, tgt_lang=tgt_lang) enc_output, _ = self.encoder(xs_pad, None) h = enc_output.squeeze(0) logging.info('input lengths: ' + str(h.size(0))) # search parms beam = trans_args.beam_size penalty = trans_args.penalty # preprare sos y = self.sos vy = h.new_zeros(1).long() if trans_args.maxlenratio == 0: maxlen = h.shape[0] else: # maxlen >= 1 maxlen = max(1, int(trans_args.maxlenratio * h.size(0))) minlen = int(trans_args.minlenratio * h.size(0)) logging.info('max output length: ' + str(maxlen)) logging.info('min output length: ' + str(minlen)) # initialize hypothesis if rnnlm: hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None} else: hyp = {'score': 0.0, 'yseq': [y]} hyps = [hyp] ended_hyps = [] import six traced_decoder = None for i in six.moves.range(maxlen): logging.debug('position ' + str(i)) hyps_best_kept = [] for hyp in hyps: vy[0] = hyp['yseq'][i] # get nbest local scores and their ids ys_mask = subsequent_mask(i + 1).unsqueeze(0) ys = torch.tensor(hyp['yseq']).unsqueeze(0) # FIXME: jit does not match non-jit result if use_jit: if traced_decoder is None: traced_decoder = torch.jit.trace( self.decoder.forward_one_step, (ys, ys_mask, enc_output)) local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0] else: local_att_scores = self.decoder.forward_one_step( ys, ys_mask, enc_output)[0] if rnnlm: rnnlm_state, local_lm_scores = rnnlm.predict( hyp['rnnlm_prev'], vy) local_scores = local_att_scores + trans_args.lm_weight * local_lm_scores else: local_scores = local_att_scores local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1) for j in six.moves.range(beam): new_hyp = {} new_hyp['score'] = hyp['score'] + float( local_best_scores[0, j]) new_hyp['yseq'] = [0] * (1 + len(hyp['yseq'])) new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq'] new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j]) if rnnlm: new_hyp['rnnlm_prev'] = rnnlm_state # will be (2 x beam) hyps at most hyps_best_kept.append(new_hyp) 
hyps_best_kept = sorted(hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam] # sort and get nbest hyps = hyps_best_kept logging.debug('number of pruned hypothes: ' + str(len(hyps))) if char_list is not None: logging.debug( 'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]])) # add eos in the final loop to avoid that there are no ended hyps if i == maxlen - 1: logging.info('adding <eos> in the last postion in the loop') for hyp in hyps: hyp['yseq'].append(self.eos) # add ended hypothes to a final list, and removed them from current hypothes # (this will be a probmlem, number of hyps < beam) remained_hyps = [] for hyp in hyps: if hyp['yseq'][-1] == self.eos: # only store the sequence that has more than minlen outputs # also add penalty if len(hyp['yseq']) > minlen: hyp['score'] += (i + 1) * penalty if rnnlm: # Word LM needs to add final <eos> score hyp['score'] += trans_args.lm_weight * rnnlm.final( hyp['rnnlm_prev']) ended_hyps.append(hyp) else: remained_hyps.append(hyp) # end detection from espnet.nets.e2e_asr_common import end_detect if end_detect(ended_hyps, i) and trans_args.maxlenratio == 0.0: logging.info('end detected at %d', i) break hyps = remained_hyps if len(hyps) > 0: logging.debug('remeined hypothes: ' + str(len(hyps))) else: logging.info('no hypothesis. Finish decoding.') break if char_list is not None: for hyp in hyps: logging.debug( 'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]])) logging.debug('number of ended hypothes: ' + str(len(ended_hyps))) nbest_hyps = sorted( ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), trans_args.nbest)] # check number of hypotheis if len(nbest_hyps) == 0: logging.warning( 'there is no N-best results, perform recognition again with smaller minlenratio.' ) # should copy becasuse Namespace will be overwritten globally trans_args = Namespace(**vars(trans_args)) trans_args.minlenratio = max(0.0, trans_args.minlenratio - 0.1) return self.translate(x, trans_args, char_list, rnnlm) logging.info('total log probability: ' + str(nbest_hyps[0]['score'])) logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq']))) return nbest_hyps def calculate_all_attentions(self, xs_pad, ilens, ys_pad): """E2E attention calculation. :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax) :param torch.Tensor ilens: batch of lengths of input sequences (B) :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax) :return: attention weights with the following shape, 1) multi-head case => attention weights (B, H, Lmax, Tmax), 2) other case => attention weights (B, Lmax, Tmax). :rtype: float ndarray """ with torch.no_grad(): self.forward(xs_pad, ilens, ys_pad) ret = dict() for name, m in self.named_modules(): if isinstance(m, MultiHeadedAttention): ret[name] = m.attn.cpu().numpy() return ret
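The forward() method above reports a perplexity derived from the attention loss in two ways, depending on transformer_length_normalized_loss. A minimal sketch of that bookkeeping with plain cross-entropy standing in for LabelSmoothingLoss; the function name and toy tensors are illustrative only, not ESPnet API.

import torch
import torch.nn.functional as F

def token_level_ppl(logits, targets, ignore_id=-1, normalize_length=False):
    """Reproduce the PPL bookkeeping of forward() on toy tensors.

    logits:  (B, T, V) decoder output scores
    targets: (B, T) target ids, padded with ignore_id
    """
    bsize, _, vocab = logits.shape
    flat_tgt = targets.reshape(-1)
    n_tokens = int((flat_tgt != ignore_id).sum())
    # cross entropy summed over all non-padded tokens
    ce_sum = F.cross_entropy(
        logits.reshape(-1, vocab), flat_tgt, ignore_index=ignore_id, reduction="sum"
    )
    if normalize_length:
        loss = ce_sum / n_tokens  # already per token, so exp() is the PPL directly
        return float(torch.exp(loss))
    loss = ce_sum / bsize  # per-sequence average, as reported by the model
    return float(torch.exp(loss * bsize / n_tokens))  # rescale to per token before exp()

logits = torch.randn(2, 4, 10)
targets = torch.tensor([[1, 2, 3, -1], [4, 5, -1, -1]])
print(token_level_ppl(targets=targets, logits=logits))

Both branches compute exp(total cross-entropy / number of non-padded tokens); the difference is only whether the rescaling happens inside the loss or afterwards.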
class Transformer(AbsTTS): """TTS-Transformer module. This is a module of text-to-speech Transformer described in `Neural Speech Synthesis with Transformer Network`_, which convert the sequence of tokens into the sequence of Mel-filterbanks. .. _`Neural Speech Synthesis with Transformer Network`: https://arxiv.org/pdf/1809.08895.pdf Args: idim (int): Dimension of the inputs. odim (int): Dimension of the outputs. embed_dim (int, optional): Dimension of character embedding. eprenet_conv_layers (int, optional): Number of encoder prenet convolution layers. eprenet_conv_chans (int, optional): Number of encoder prenet convolution channels. eprenet_conv_filts (int, optional): Filter size of encoder prenet convolution. dprenet_layers (int, optional): Number of decoder prenet layers. dprenet_units (int, optional): Number of decoder prenet hidden units. elayers (int, optional): Number of encoder layers. eunits (int, optional): Number of encoder hidden units. adim (int, optional): Number of attention transformation dimensions. aheads (int, optional): Number of heads for multi head attention. dlayers (int, optional): Number of decoder layers. dunits (int, optional): Number of decoder hidden units. postnet_layers (int, optional): Number of postnet layers. postnet_chans (int, optional): Number of postnet channels. postnet_filts (int, optional): Filter size of postnet. use_scaled_pos_enc (bool, optional): Whether to use trainable scaled positional encoding. use_batch_norm (bool, optional): Whether to use batch normalization in encoder prenet. encoder_normalize_before (bool, optional): Whether to perform layer normalization before encoder block. decoder_normalize_before (bool, optional): Whether to perform layer normalization before decoder block. encoder_concat_after (bool, optional): Whether to concatenate attention layer's input and output in encoder. decoder_concat_after (bool, optional): Whether to concatenate attention layer's input and output in decoder. positionwise_layer_type (str, optional): Position-wise operation type. positionwise_conv_kernel_size (int, optional): Kernel size in position wise conv 1d. reduction_factor (int, optional): Reduction factor. spk_embed_dim (int, optional): Number of speaker embedding dimenstions. spk_embed_integration_type (str, optional): How to integrate speaker embedding. use_gst (str, optional): Whether to use global style token. gst_tokens (int, optional): The number of GST embeddings. gst_heads (int, optional): The number of heads in GST multihead attention. gst_conv_layers (int, optional): The number of conv layers in GST. gst_conv_chans_list: (Sequence[int], optional): List of the number of channels of conv layers in GST. gst_conv_kernel_size (int, optional): Kernal size of conv layers in GST. gst_conv_stride (int, optional): Stride size of conv layers in GST. gst_gru_layers (int, optional): The number of GRU layers in GST. gst_gru_units (int, optional): The number of GRU units in GST. transformer_lr (float, optional): Initial value of learning rate. transformer_warmup_steps (int, optional): Optimizer warmup steps. transformer_enc_dropout_rate (float, optional): Dropout rate in encoder except attention and positional encoding. transformer_enc_positional_dropout_rate (float, optional): Dropout rate after encoder positional encoding. transformer_enc_attn_dropout_rate (float, optional): Dropout rate in encoder self-attention module. transformer_dec_dropout_rate (float, optional): Dropout rate in decoder except attention & positional encoding. 
transformer_dec_positional_dropout_rate (float, optional): Dropout rate after decoder positional encoding. transformer_dec_attn_dropout_rate (float, optional): Dropout rate in deocoder self-attention module. transformer_enc_dec_attn_dropout_rate (float, optional): Dropout rate in encoder-deocoder attention module. init_type (str, optional): How to initialize transformer parameters. init_enc_alpha (float, optional): Initial value of alpha in scaled pos encoding of the encoder. init_dec_alpha (float, optional): Initial value of alpha in scaled pos encoding of the decoder. eprenet_dropout_rate (float, optional): Dropout rate in encoder prenet. dprenet_dropout_rate (float, optional): Dropout rate in decoder prenet. postnet_dropout_rate (float, optional): Dropout rate in postnet. use_masking (bool, optional): Whether to apply masking for padded part in loss calculation. use_weighted_masking (bool, optional): Whether to apply weighted masking in loss calculation. bce_pos_weight (float, optional): Positive sample weight in bce calculation (only for use_masking=true). loss_type (str, optional): How to calculate loss. use_guided_attn_loss (bool, optional): Whether to use guided attention loss. num_heads_applied_guided_attn (int, optional): Number of heads in each layer to apply guided attention loss. num_layers_applied_guided_attn (int, optional): Number of layers to apply guided attention loss. modules_applied_guided_attn (Sequence[str], optional): List of module names to apply guided attention loss. guided_attn_loss_sigma (float, optional) Sigma in guided attention loss. guided_attn_loss_lambda (float, optional): Lambda in guided attention loss. """ def __init__( self, # network structure related idim: int, odim: int, embed_dim: int = 512, eprenet_conv_layers: int = 3, eprenet_conv_chans: int = 256, eprenet_conv_filts: int = 5, dprenet_layers: int = 2, dprenet_units: int = 256, elayers: int = 6, eunits: int = 1024, adim: int = 512, aheads: int = 4, dlayers: int = 6, dunits: int = 1024, postnet_layers: int = 5, postnet_chans: int = 256, postnet_filts: int = 5, positionwise_layer_type: str = "conv1d", positionwise_conv_kernel_size: int = 1, use_scaled_pos_enc: bool = True, use_batch_norm: bool = True, encoder_normalize_before: bool = True, decoder_normalize_before: bool = True, encoder_concat_after: bool = False, decoder_concat_after: bool = False, reduction_factor: int = 1, spk_embed_dim: int = None, spk_embed_integration_type: str = "add", use_gst: bool = False, gst_tokens: int = 10, gst_heads: int = 4, gst_conv_layers: int = 6, gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128), gst_conv_kernel_size: int = 3, gst_conv_stride: int = 2, gst_gru_layers: int = 1, gst_gru_units: int = 128, # training related transformer_enc_dropout_rate: float = 0.1, transformer_enc_positional_dropout_rate: float = 0.1, transformer_enc_attn_dropout_rate: float = 0.1, transformer_dec_dropout_rate: float = 0.1, transformer_dec_positional_dropout_rate: float = 0.1, transformer_dec_attn_dropout_rate: float = 0.1, transformer_enc_dec_attn_dropout_rate: float = 0.1, eprenet_dropout_rate: float = 0.5, dprenet_dropout_rate: float = 0.5, postnet_dropout_rate: float = 0.5, init_type: str = "xavier_uniform", init_enc_alpha: float = 1.0, init_dec_alpha: float = 1.0, use_masking: bool = False, use_weighted_masking: bool = False, bce_pos_weight: float = 5.0, loss_type: str = "L1", use_guided_attn_loss: bool = True, num_heads_applied_guided_attn: int = 2, num_layers_applied_guided_attn: int = 2, 
modules_applied_guided_attn: Sequence[str] = ("encoder-decoder"), guided_attn_loss_sigma: float = 0.4, guided_attn_loss_lambda: float = 1.0, ): """Initialize Transformer module.""" assert check_argument_types() super().__init__() # store hyperparameters self.idim = idim self.odim = odim self.eos = idim - 1 self.spk_embed_dim = spk_embed_dim self.reduction_factor = reduction_factor self.use_gst = use_gst self.use_guided_attn_loss = use_guided_attn_loss self.use_scaled_pos_enc = use_scaled_pos_enc self.loss_type = loss_type self.use_guided_attn_loss = use_guided_attn_loss if self.use_guided_attn_loss: if num_layers_applied_guided_attn == -1: self.num_layers_applied_guided_attn = elayers else: self.num_layers_applied_guided_attn = num_layers_applied_guided_attn if num_heads_applied_guided_attn == -1: self.num_heads_applied_guided_attn = aheads else: self.num_heads_applied_guided_attn = num_heads_applied_guided_attn self.modules_applied_guided_attn = modules_applied_guided_attn if self.spk_embed_dim is not None: self.spk_embed_integration_type = spk_embed_integration_type # use idx 0 as padding idx self.padding_idx = 0 # get positional encoding class pos_enc_class = ( ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding ) # define transformer encoder if eprenet_conv_layers != 0: # encoder prenet encoder_input_layer = torch.nn.Sequential( EncoderPrenet( idim=idim, embed_dim=embed_dim, elayers=0, econv_layers=eprenet_conv_layers, econv_chans=eprenet_conv_chans, econv_filts=eprenet_conv_filts, use_batch_norm=use_batch_norm, dropout_rate=eprenet_dropout_rate, padding_idx=self.padding_idx, ), torch.nn.Linear(eprenet_conv_chans, adim), ) else: encoder_input_layer = torch.nn.Embedding( num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx ) self.encoder = Encoder( idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers, input_layer=encoder_input_layer, dropout_rate=transformer_enc_dropout_rate, positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate, pos_enc_class=pos_enc_class, normalize_before=encoder_normalize_before, concat_after=encoder_concat_after, positionwise_layer_type=positionwise_layer_type, positionwise_conv_kernel_size=positionwise_conv_kernel_size, ) # define GST if self.use_gst: self.gst = StyleEncoder( idim=odim, # the input is mel-spectrogram gst_tokens=gst_tokens, gst_token_dim=adim, gst_heads=gst_heads, conv_layers=gst_conv_layers, conv_chans_list=gst_conv_chans_list, conv_kernel_size=gst_conv_kernel_size, conv_stride=gst_conv_stride, gru_layers=gst_gru_layers, gru_units=gst_gru_units, ) # define projection layer if self.spk_embed_dim is not None: if self.spk_embed_integration_type == "add": self.projection = torch.nn.Linear(self.spk_embed_dim, adim) else: self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim) # define transformer decoder if dprenet_layers != 0: # decoder prenet decoder_input_layer = torch.nn.Sequential( DecoderPrenet( idim=odim, n_layers=dprenet_layers, n_units=dprenet_units, dropout_rate=dprenet_dropout_rate, ), torch.nn.Linear(dprenet_units, adim), ) else: decoder_input_layer = "linear" self.decoder = Decoder( odim=odim, # odim is needed when no prenet is used attention_dim=adim, attention_heads=aheads, linear_units=dunits, num_blocks=dlayers, dropout_rate=transformer_dec_dropout_rate, positional_dropout_rate=transformer_dec_positional_dropout_rate, 
self_attention_dropout_rate=transformer_dec_attn_dropout_rate, src_attention_dropout_rate=transformer_enc_dec_attn_dropout_rate, input_layer=decoder_input_layer, use_output_layer=False, pos_enc_class=pos_enc_class, normalize_before=decoder_normalize_before, concat_after=decoder_concat_after, ) # define final projection self.feat_out = torch.nn.Linear(adim, odim * reduction_factor) self.prob_out = torch.nn.Linear(adim, reduction_factor) # define postnet self.postnet = ( None if postnet_layers == 0 else Postnet( idim=idim, odim=odim, n_layers=postnet_layers, n_chans=postnet_chans, n_filts=postnet_filts, use_batch_norm=use_batch_norm, dropout_rate=postnet_dropout_rate, ) ) # define loss function self.criterion = TransformerLoss( use_masking=use_masking, use_weighted_masking=use_weighted_masking, bce_pos_weight=bce_pos_weight, ) if self.use_guided_attn_loss: self.attn_criterion = GuidedMultiHeadAttentionLoss( sigma=guided_attn_loss_sigma, alpha=guided_attn_loss_lambda, ) # initialize parameters self._reset_parameters( init_type=init_type, init_enc_alpha=init_enc_alpha, init_dec_alpha=init_dec_alpha, ) def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_alpha=1.0): # initialize parameters if init_type != "pytorch": initialize(self, init_type) # initialize alpha in scaled positional encoding if self.use_scaled_pos_enc: self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha) def forward( self, text: torch.Tensor, text_lengths: torch.Tensor, speech: torch.Tensor, speech_lengths: torch.Tensor, spembs: torch.Tensor = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Calculate forward propagation. Args: text (LongTensor): Batch of padded character ids (B, Tmax). text_lengths (LongTensor): Batch of lengths of each input batch (B,). speech (Tensor): Batch of padded target features (B, Lmax, odim). speech_lengths (LongTensor): Batch of the lengths of each target (B,). spembs (Tensor, optional): Batch of speaker embeddings (B, spk_embed_dim). Returns: Tensor: Loss scalar value. Dict: Statistics to be monitored. Tensor: Weight value. 
""" text = text[:, : text_lengths.max()] # for data-parallel speech = speech[:, : speech_lengths.max()] # for data-parallel batch_size = text.size(0) # Add eos at the last of sequence xs = F.pad(text, [0, 1], "constant", self.padding_idx) for i, l in enumerate(text_lengths): xs[i, l] = self.eos ilens = text_lengths + 1 ys = speech olens = speech_lengths # make labels for stop prediction labels = make_pad_mask(olens - 1).to(ys.device, ys.dtype) labels = F.pad(labels, [0, 1], "constant", 1.0) # calculate transformer outputs after_outs, before_outs, logits = self._forward(xs, ilens, ys, olens, spembs) # modifiy mod part of groundtruth olens_in = olens if self.reduction_factor > 1: olens_in = olens.new([olen // self.reduction_factor for olen in olens]) olens = olens.new([olen - olen % self.reduction_factor for olen in olens]) max_olen = max(olens) ys = ys[:, :max_olen] labels = labels[:, :max_olen] labels[:, -1] = 1.0 # make sure at least one frame has 1 # caluculate loss values l1_loss, l2_loss, bce_loss = self.criterion( after_outs, before_outs, logits, ys, labels, olens ) if self.loss_type == "L1": loss = l1_loss + bce_loss elif self.loss_type == "L2": loss = l2_loss + bce_loss elif self.loss_type == "L1+L2": loss = l1_loss + l2_loss + bce_loss else: raise ValueError("unknown --loss-type " + self.loss_type) stats = dict( l1_loss=l1_loss.item(), l2_loss=l2_loss.item(), bce_loss=bce_loss.item(), ) # calculate guided attention loss if self.use_guided_attn_loss: # calculate for encoder if "encoder" in self.modules_applied_guided_attn: att_ws = [] for idx, layer_idx in enumerate( reversed(range(len(self.encoder.encoders))) ): att_ws += [ self.encoder.encoders[layer_idx].self_attn.attn[ :, : self.num_heads_applied_guided_attn ] ] if idx + 1 == self.num_layers_applied_guided_attn: break att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_in, T_in) enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens) loss = loss + enc_attn_loss stats.update(enc_attn_loss=enc_attn_loss.item()) # calculate for decoder if "decoder" in self.modules_applied_guided_attn: att_ws = [] for idx, layer_idx in enumerate( reversed(range(len(self.decoder.decoders))) ): att_ws += [ self.decoder.decoders[layer_idx].self_attn.attn[ :, : self.num_heads_applied_guided_attn ] ] if idx + 1 == self.num_layers_applied_guided_attn: break att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_out, T_out) dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in) loss = loss + dec_attn_loss stats.update(dec_attn_loss=dec_attn_loss.item()) # calculate for encoder-decoder if "encoder-decoder" in self.modules_applied_guided_attn: att_ws = [] for idx, layer_idx in enumerate( reversed(range(len(self.decoder.decoders))) ): att_ws += [ self.decoder.decoders[layer_idx].src_attn.attn[ :, : self.num_heads_applied_guided_attn ] ] if idx + 1 == self.num_layers_applied_guided_attn: break att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_out, T_in) enc_dec_attn_loss = self.attn_criterion(att_ws, ilens, olens_in) loss = loss + enc_dec_attn_loss stats.update(enc_dec_attn_loss=enc_dec_attn_loss.item()) stats.update(loss=loss.item()) # report extra information if self.use_scaled_pos_enc: stats.update( encoder_alpha=self.encoder.embed[-1].alpha.data.item(), decoder_alpha=self.decoder.embed[-1].alpha.data.item(), ) loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight def _forward( self, xs: torch.Tensor, ilens: torch.Tensor, ys: torch.Tensor, olens: torch.Tensor, spembs: torch.Tensor, ) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor]: # forward encoder x_masks = self._source_mask(ilens) hs, h_masks = self.encoder(xs, x_masks) # integrate with GST if self.use_gst: style_embs = self.gst(ys) hs = hs + style_embs.unsqueeze(1) # integrate speaker embedding if self.spk_embed_dim is not None: hs = self._integrate_with_spk_embed(hs, spembs) # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim) if self.reduction_factor > 1: ys_in = ys[:, self.reduction_factor - 1 :: self.reduction_factor] olens_in = olens.new([olen // self.reduction_factor for olen in olens]) else: ys_in, olens_in = ys, olens # add first zero frame and remove last frame for auto-regressive ys_in = self._add_first_frame_and_remove_last_frame(ys_in) # forward decoder y_masks = self._target_mask(olens_in) zs, _ = self.decoder(ys_in, y_masks, hs, h_masks) # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim) before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) # (B, Lmax//r, r) -> (B, Lmax//r * r) logits = self.prob_out(zs).view(zs.size(0), -1) # postnet -> (B, Lmax//r * r, odim) if self.postnet is None: after_outs = before_outs else: after_outs = before_outs + self.postnet( before_outs.transpose(1, 2) ).transpose(1, 2) return after_outs, before_outs, logits def inference( self, text: torch.Tensor, speech: torch.Tensor = None, spembs: torch.Tensor = None, threshold: float = 0.5, minlenratio: float = 0.0, maxlenratio: float = 10.0, use_teacher_forcing: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Generate the sequence of features given the sequences of characters. Args: text (LongTensor): Input sequence of characters (T,). speech (Tensor, optional): Feature sequence to extract style (N, idim). spembs (Tensor, optional): Speaker embedding vector (spk_embed_dim,). threshold (float, optional): Threshold in inference. minlenratio (float, optional): Minimum length ratio in inference. maxlenratio (float, optional): Maximum length ratio in inference. use_teacher_forcing (bool, optional): Whether to use teacher forcing. Returns: Tensor: Output sequence of features (L, odim). Tensor: Output sequence of stop probabilities (L,). Tensor: Encoder-decoder (source) attention weights (#layers, #heads, L, T). """ x = text y = speech spemb = spembs # add eos at the last of sequence x = F.pad(x, [0, 1], "constant", self.eos) # inference with teacher forcing if use_teacher_forcing: assert speech is not None, "speech must be provided with teacher forcing." 
# get teacher forcing outputs xs, ys = x.unsqueeze(0), y.unsqueeze(0) spembs = None if spemb is None else spemb.unsqueeze(0) ilens = x.new_tensor([xs.size(1)]).long() olens = y.new_tensor([ys.size(1)]).long() outs, *_ = self._forward(xs, ilens, ys, olens, spembs) # get attention weights att_ws = [] for i in range(len(self.decoder.decoders)): att_ws += [self.decoder.decoders[i].src_attn.attn] att_ws = torch.stack(att_ws, dim=1) # (B, L, H, T_out, T_in) return outs[0], None, att_ws[0] # forward encoder xs = x.unsqueeze(0) hs, _ = self.encoder(xs, None) # integrate GST if self.use_gst: style_embs = self.gst(y.unsqueeze(0)) hs = hs + style_embs.unsqueeze(1) # integrate speaker embedding if self.spk_embed_dim is not None: spembs = spemb.unsqueeze(0) hs = self._integrate_with_spk_embed(hs, spembs) # set limits of length maxlen = int(hs.size(1) * maxlenratio / self.reduction_factor) minlen = int(hs.size(1) * minlenratio / self.reduction_factor) # initialize idx = 0 ys = hs.new_zeros(1, 1, self.odim) outs, probs = [], [] # forward decoder step-by-step z_cache = self.decoder.init_state(x) while True: # update index idx += 1 # calculate output and stop prob at idx-th step y_masks = subsequent_mask(idx).unsqueeze(0).to(x.device) z, z_cache = self.decoder.forward_one_step( ys, y_masks, hs, cache=z_cache ) # (B, adim) outs += [ self.feat_out(z).view(self.reduction_factor, self.odim) ] # [(r, odim), ...] probs += [torch.sigmoid(self.prob_out(z))[0]] # [(r), ...] # update next inputs ys = torch.cat( (ys, outs[-1][-1].view(1, 1, self.odim)), dim=1 ) # (1, idx + 1, odim) # get attention weights att_ws_ = [] for name, m in self.named_modules(): if isinstance(m, MultiHeadedAttention) and "src" in name: att_ws_ += [m.attn[0, :, -1].unsqueeze(1)] # [(#heads, 1, T),...] if idx == 1: att_ws = att_ws_ else: # [(#heads, l, T), ...] att_ws = [ torch.cat([att_w, att_w_], dim=1) for att_w, att_w_ in zip(att_ws, att_ws_) ] # check whether to finish generation if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen: # check mininum length if idx < minlen: continue outs = ( torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2) ) # (L, odim) -> (1, L, odim) -> (1, odim, L) if self.postnet is not None: outs = outs + self.postnet(outs) # (1, odim, L) outs = outs.transpose(2, 1).squeeze(0) # (L, odim) probs = torch.cat(probs, dim=0) break # concatenate attention weights -> (#layers, #heads, L, T) att_ws = torch.stack(att_ws, dim=0) return outs, probs, att_ws def _add_first_frame_and_remove_last_frame(self, ys: torch.Tensor) -> torch.Tensor: ys_in = torch.cat( [ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1 ) return ys_in def _source_mask(self, ilens): """Make masks for self-attention. Args: ilens (LongTensor): Batch of lengths (B,). Returns: Tensor: Mask tensor for self-attention. dtype=torch.uint8 in PyTorch 1.2- dtype=torch.bool in PyTorch 1.2+ (including 1.2) Examples: >>> ilens = [5, 3] >>> self._source_mask(ilens) tensor([[[1, 1, 1, 1, 1], [[1, 1, 1, 0, 0]]], dtype=torch.uint8) """ x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) return x_masks.unsqueeze(-2) def _target_mask(self, olens: torch.Tensor) -> torch.Tensor: """Make masks for masked self-attention. Args: olens (LongTensor): Batch of lengths (B,). Returns: Tensor: Mask tensor for masked self-attention. 
dtype=torch.uint8 in PyTorch 1.2- dtype=torch.bool in PyTorch 1.2+ (including 1.2) Examples: >>> olens = [5, 3] >>> self._target_mask(olens) tensor([[[1, 0, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 1, 0], [1, 1, 1, 1, 1]], [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 0, 0]]], dtype=torch.uint8) """ y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device) s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0) return y_masks.unsqueeze(-2) & s_masks def _integrate_with_spk_embed( self, hs: torch.Tensor, spembs: torch.Tensor ) -> torch.Tensor: """Integrate speaker embedding with hidden states. Args: hs (Tensor): Batch of hidden state sequences (B, Tmax, adim). spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). Returns: Tensor: Batch of integrated hidden state sequences (B, Tmax, adim). """ if self.spk_embed_integration_type == "add": # apply projection and then add to hidden states spembs = self.projection(F.normalize(spembs)) hs = hs + spembs.unsqueeze(1) elif self.spk_embed_integration_type == "concat": # concat hidden states with spk embeds and then apply projection spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1) hs = self.projection(torch.cat([hs, spembs], dim=-1)) else: raise NotImplementedError("support only add or concat.") return hs
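_target_mask() combines a padding mask with a causal (subsequent) mask so each decoder position can only attend to valid, earlier frames. A small self-contained sketch that rebuilds the same mask with arange/tril in place of make_non_pad_mask/subsequent_mask; toy_target_mask is an illustrative name, not ESPnet API.

import torch

def toy_target_mask(olens):
    """Rebuild the mask of _target_mask() with plain torch ops.

    olens: list of output lengths, e.g. [5, 3]
    returns: (B, Lmax, Lmax) bool tensor, padding mask AND-ed with a causal mask
    """
    olens = torch.as_tensor(olens)
    lmax = int(olens.max())
    # non-pad mask: True for valid frames, (B, Lmax)
    non_pad = torch.arange(lmax).unsqueeze(0) < olens.unsqueeze(1)
    # lower-triangular causal mask, (Lmax, Lmax)
    causal = torch.tril(torch.ones(lmax, lmax, dtype=torch.bool))
    return non_pad.unsqueeze(-2) & causal.unsqueeze(0)  # broadcasts to (B, Lmax, Lmax)

print(toy_target_mask([5, 3]).int())  # matches the example in the docstring above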
class E2E(MTInterface, torch.nn.Module): """E2E module. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ @staticmethod def add_arguments(parser): """Add arguments.""" group = parser.add_argument_group("transformer model setting") group = add_arguments_transformer_common(group) return parser @property def attention_plot_class(self): """Return PlotAttentionReport.""" return PlotAttentionReport def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) # fill missing arguments for compatibility args = fill_missing_args(args, self.add_arguments) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, selfattention_layer_type=args. transformer_encoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_encoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.eunits, num_blocks=args.elayers, input_layer="embed", dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.decoder = Decoder( odim=odim, selfattention_layer_type=args. transformer_decoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_decoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.pad = 0 # use <blank> for padding self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = get_subsample(args, mode="mt", arch="transformer") self.reporter = Reporter() # tie source and target emeddings if args.tie_src_tgt_embedding: if idim != odim: raise ValueError( "When using tie_src_tgt_embedding, idim and odim must be equal." ) self.encoder.embed[0].weight = self.decoder.embed[0].weight # tie emeddings and the classfier if args.tie_classifier: self.decoder.output_layer.weight = self.decoder.embed[0].weight self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss, ) self.normalize_length = args.transformer_length_normalized_loss # for PPL self.reset_parameters(args) self.adim = args.adim self.error_calculator = ErrorCalculator(args.char_list, args.sym_space, args.sym_blank, args.report_bleu) self.rnnlm = None # multilingual MT related self.multilingual = args.multilingual def reset_parameters(self, args): """Initialize parameters.""" initialize(self, args.transformer_init) torch.nn.init.normal_(self.encoder.embed[0].weight, mean=0, std=args.adim**-0.5) torch.nn.init.constant_(self.encoder.embed[0].weight[self.pad], 0) torch.nn.init.normal_(self.decoder.embed[0].weight, mean=0, std=args.adim**-0.5) torch.nn.init.constant_(self.decoder.embed[0].weight[self.pad], 0) def forward(self, xs_pad, ilens, ys_pad): """E2E forward. 
:param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax) :param torch.Tensor ilens: batch of lengths of source sequences (B) :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) :rtype: torch.Tensor :return: attention loss value :rtype: torch.Tensor :return: accuracy in attention decoder :rtype: float """ # 1. forward encoder xs_pad = xs_pad[:, :max(ilens)] # for data parallel src_mask = (~make_pad_mask(ilens.tolist())).to( xs_pad.device).unsqueeze(-2) xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad) hs_pad, hs_mask = self.encoder(xs_pad, src_mask) # 2. forward decoder ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_mask = target_mask(ys_in_pad, self.ignore_id) pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) # 3. compute attention loss self.loss = self.criterion(pred_pad, ys_out_pad) self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id) # 4. compute corpus-level bleu in a mini-batch if self.training: self.bleu = None else: ys_hat = pred_pad.argmax(dim=-1) self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) loss_data = float(self.loss) if self.normalize_length: self.ppl = np.exp(loss_data) else: batch_size = ys_out_pad.size(0) ys_out_pad = ys_out_pad.view(-1) ignore = ys_out_pad == self.ignore_id # (B*T,) total_n_tokens = len(ys_out_pad) - ignore.sum().item() self.ppl = np.exp(loss_data * batch_size / total_n_tokens) if not math.isnan(loss_data): self.reporter.report(loss_data, self.acc, self.ppl, self.bleu) else: logging.warning("loss (=%f) is not correct", loss_data) return self.loss def scorers(self): """Scorers.""" return dict(decoder=self.decoder) def encode(self, xs): """Encode source sentences.""" self.eval() xs = torch.as_tensor(xs).unsqueeze(0) enc_output, _ = self.encoder(xs, None) return enc_output.squeeze(0) def target_forcing(self, xs_pad, ys_pad=None, tgt_lang=None): """Prepend target language IDs to source sentences for multilingual MT. These tags are prepended in source/target sentences as pre-processing. :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax) :return: source text without language IDs :rtype: torch.Tensor :return: target text without language IDs :rtype: torch.Tensor :return: target language IDs :rtype: torch.Tensor (B, 1) """ if self.multilingual: xs_pad = xs_pad[:, 1:] # remove source language IDs here if ys_pad is not None: # remove language ID in the beginning lang_ids = ys_pad[:, 0].unsqueeze(1) ys_pad = ys_pad[:, 1:] elif tgt_lang is not None: lang_ids = xs_pad.new_zeros(xs_pad.size(0), 1).fill_(tgt_lang) else: raise ValueError("Set ys_pad or tgt_lang.") # prepend target language ID to source sentences xs_pad = torch.cat([lang_ids, xs_pad], dim=1) return xs_pad, ys_pad def translate(self, x, trans_args, char_list=None): """Translate source text. 
:param list x: input source text feature (T,) :param Namespace trans_args: argment Namespace contraining options :param list char_list: list of characters :return: N-best decoding results :rtype: list """ self.eval( ) # NOTE: this is important because self.encode() is not used assert isinstance(x, list) # make a utt list (1) to use the same interface for encoder if self.multilingual: x = to_device( self, torch.from_numpy( np.fromiter(map(int, x[0][1:]), dtype=np.int64))) else: x = to_device( self, torch.from_numpy(np.fromiter(map(int, x[0]), dtype=np.int64))) logging.info("input lengths: " + str(x.size(0))) xs_pad = x.unsqueeze(0) tgt_lang = None if trans_args.tgt_lang: tgt_lang = char_list.index(trans_args.tgt_lang) xs_pad, _ = self.target_forcing(xs_pad, tgt_lang=tgt_lang) h, _ = self.encoder(xs_pad, None) logging.info("encoder output lengths: " + str(h.size(1))) # search parms beam = trans_args.beam_size penalty = trans_args.penalty if trans_args.maxlenratio == 0: maxlen = h.size(1) else: # maxlen >= 1 maxlen = max(1, int(trans_args.maxlenratio * h.size(1))) minlen = int(trans_args.minlenratio * h.size(1)) logging.info("max output length: " + str(maxlen)) logging.info("min output length: " + str(minlen)) # initialize hypothesis hyp = {"score": 0.0, "yseq": [self.sos]} hyps = [hyp] ended_hyps = [] for i in range(maxlen): logging.debug("position " + str(i)) # batchfy ys = h.new_zeros((len(hyps), i + 1), dtype=torch.int64) for j, hyp in enumerate(hyps): ys[j, :] = torch.tensor(hyp["yseq"]) ys_mask = subsequent_mask(i + 1).unsqueeze(0).to(h.device) local_scores = self.decoder.forward_one_step( ys, ys_mask, h.repeat([len(hyps), 1, 1]))[0] hyps_best_kept = [] for j, hyp in enumerate(hyps): local_best_scores, local_best_ids = torch.topk( local_scores[j:j + 1], beam, dim=1) for j in range(beam): new_hyp = {} new_hyp["score"] = hyp["score"] + float( local_best_scores[0, j]) new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"])) new_hyp["yseq"][:len(hyp["yseq"])] = hyp["yseq"] new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j]) # will be (2 x beam) hyps at most hyps_best_kept.append(new_hyp) hyps_best_kept = sorted(hyps_best_kept, key=lambda x: x["score"], reverse=True)[:beam] # sort and get nbest hyps = hyps_best_kept logging.debug("number of pruned hypothes: " + str(len(hyps))) if char_list is not None: logging.debug( "best hypo: " + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])) # add eos in the final loop to avoid that there are no ended hyps if i == maxlen - 1: logging.info("adding <eos> in the last position in the loop") for hyp in hyps: hyp["yseq"].append(self.eos) # add ended hypothes to a final list, and removed them from current hypothes # (this will be a probmlem, number of hyps < beam) remained_hyps = [] for hyp in hyps: if hyp["yseq"][-1] == self.eos: # only store the sequence that has more than minlen outputs # also add penalty if len(hyp["yseq"]) > minlen: hyp["score"] += (i + 1) * penalty ended_hyps.append(hyp) else: remained_hyps.append(hyp) # end detection if end_detect(ended_hyps, i) and trans_args.maxlenratio == 0.0: logging.info("end detected at %d", i) break hyps = remained_hyps if len(hyps) > 0: logging.debug("remeined hypothes: " + str(len(hyps))) else: logging.info("no hypothesis. 
Finish decoding.") break if char_list is not None: for hyp in hyps: logging.debug( "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])) logging.debug("number of ended hypothes: " + str(len(ended_hyps))) nbest_hyps = sorted( ended_hyps, key=lambda x: x["score"], reverse=True)[:min(len(ended_hyps), trans_args.nbest)] # check number of hypotheis if len(nbest_hyps) == 0: logging.warning("there is no N-best results, perform translation " "again with smaller minlenratio.") # should copy becasuse Namespace will be overwritten globally trans_args = Namespace(**vars(trans_args)) trans_args.minlenratio = max(0.0, trans_args.minlenratio - 0.1) return self.translate(x, trans_args, char_list) logging.info("total log probability: " + str(nbest_hyps[0]["score"])) logging.info("normalized log probability: " + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))) return nbest_hyps def calculate_all_attentions(self, xs_pad, ilens, ys_pad): """E2E attention calculation. :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax) :param torch.Tensor ilens: batch of lengths of input sequences (B) :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) :return: attention weights (B, H, Lmax, Tmax) :rtype: float ndarray """ self.eval() with torch.no_grad(): self.forward(xs_pad, ilens, ys_pad) ret = dict() for name, m in self.named_modules(): if isinstance(m, MultiHeadedAttention) and m.attn is not None: ret[name] = m.attn.cpu().numpy() self.train() return ret
def __init__(self, idim, odim, args): # initialize base classes TTSInterface.__init__(self) torch.nn.Module.__init__(self) # store hyperparameters self.idim = idim self.odim = odim self.use_scaled_pos_enc = args.use_scaled_pos_enc self.reduction_factor = args.reduction_factor self.loss_type = args.loss_type self.use_guided_attn_loss = args.use_guided_attn_loss if self.use_guided_attn_loss: if args.num_layers_applied_guided_attn == -1: self.num_layers_applied_guided_attn = args.elayers else: self.num_layers_applied_guided_attn = args.num_layers_applied_guided_attn if args.num_heads_applied_guided_attn == -1: self.num_heads_applied_guided_attn = args.aheads else: self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn self.modules_applied_guided_attn = args.modules_applied_guided_attn # use idx 0 as padding idx padding_idx = 0 # get positional encoding class pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding # define transformer encoder if args.eprenet_conv_layers != 0: # encoder prenet encoder_input_layer = torch.nn.Sequential( EncoderPrenet(idim=idim, embed_dim=args.embed_dim, elayers=0, econv_layers=args.eprenet_conv_layers, econv_chans=args.eprenet_conv_chans, econv_filts=args.eprenet_conv_filts, use_batch_norm=args.use_batch_norm, dropout_rate=args.eprenet_dropout_rate, padding_idx=padding_idx), torch.nn.Linear(args.eprenet_conv_chans, args.adim)) else: encoder_input_layer = torch.nn.Embedding(num_embeddings=idim, embedding_dim=args.adim, padding_idx=padding_idx) self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=encoder_input_layer, dropout_rate=args.transformer_enc_dropout_rate, positional_dropout_rate=args. transformer_enc_positional_dropout_rate, attention_dropout_rate=args.transformer_enc_attn_dropout_rate, pos_enc_class=pos_enc_class, normalize_before=args.encoder_normalize_before, concat_after=args.encoder_concat_after) # define transformer decoder if args.dprenet_layers != 0: # decoder prenet decoder_input_layer = torch.nn.Sequential( DecoderPrenet(idim=odim, n_layers=args.dprenet_layers, n_units=args.dprenet_units, dropout_rate=args.dprenet_dropout_rate), torch.nn.Linear(args.dprenet_units, args.adim)) else: decoder_input_layer = "linear" self.decoder = Decoder( odim=-1, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.transformer_dec_dropout_rate, positional_dropout_rate=args. transformer_dec_positional_dropout_rate, self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate, src_attention_dropout_rate=args. 
transformer_enc_dec_attn_dropout_rate, input_layer=decoder_input_layer, use_output_layer=False, pos_enc_class=pos_enc_class, normalize_before=args.decoder_normalize_before, concat_after=args.decoder_concat_after) # define final projection self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor) self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor) # define postnet self.postnet = None if args.postnet_layers == 0 else Postnet( idim=idim, odim=odim, n_layers=args.postnet_layers, n_chans=args.postnet_chans, n_filts=args.postnet_filts, use_batch_norm=args.use_batch_norm, dropout_rate=args.postnet_dropout_rate) # define loss function self.criterion = TransformerLoss(args) if self.use_guided_attn_loss: self.attn_criterion = GuidedMultiHeadAttentionLoss( args.guided_attn_loss_sigma) # initialize parameters self._reset_parameters(args)
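With reduction_factor r > 1 the decoder emits r acoustic frames per step: feat_out projects each decoder state to odim * r values and prob_out to r stop logits, and both are unfolded back along the time axis. A toy shape check of that projection; the dimensions are chosen arbitrarily for illustration.

import torch

B, L_over_r, adim, odim, r = 2, 6, 8, 4, 3
feat_out = torch.nn.Linear(adim, odim * r)   # one decoder state -> r feature frames
prob_out = torch.nn.Linear(adim, r)          # one decoder state -> r stop logits

zs = torch.randn(B, L_over_r, adim)              # decoder outputs
before_outs = feat_out(zs).view(B, -1, odim)     # (B, L_over_r * r, odim)
logits = prob_out(zs).view(B, -1)                # (B, L_over_r * r)
print(before_outs.shape, logits.shape)           # torch.Size([2, 18, 4]) torch.Size([2, 18])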
def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate) self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = [1] self.reporter = Reporter() # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) # self.verbose = args.verbose self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None if args.report_cer or args.report_wer: from espnet.nets.e2e_asr_common import ErrorCalculator self.error_calculator = ErrorCalculator(args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer) else: self.error_calculator = None self.rnnlm = None # yzl23 config self.remove_blank_in_ctc_mode = True # lid multitask related adim = args.adim self.lid_odim = 2 # cn and en # src attention self.lid_src_att = MultiHeadedAttention( args.aheads, args.adim, args.transformer_attn_dropout_rate) # self.lid_output_layer = torch.nn.Sequential(torch.nn.Linear(adim, adim), # torch.nn.Tanh(), # torch.nn.Linear(adim, self.lid_odim)) self.lid_output_layer = torch.nn.Linear(adim, self.lid_odim) # here we hack to use lsm loss, but with lsm_weight ZERO self.lid_criterion = LanguageIDMultitakLoss(self.ignore_id, \ normalize_length=args.transformer_length_normalized_loss) self.lid_mtl_alpha = args.lid_mtl_alpha logging.warning("language id multitask training alpha %f" % (self.lid_mtl_alpha)) self.log_lid_mtl_acc = args.log_lid_mtl_acc # reset parameters self.reset_parameters(args)
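The language-ID multitask branch above attends over the encoder output and classifies each position as one of two languages (lid_odim = 2); its forward pass is not shown in this excerpt, but presumably the decoder states act as the query. A rough stand-alone sketch of that head using torch.nn.MultiheadAttention in place of ESPnet's MultiHeadedAttention; shapes and dimensions are illustrative assumptions.

import torch

adim, aheads, lid_odim = 256, 4, 2
lid_src_att = torch.nn.MultiheadAttention(adim, aheads, batch_first=True)
lid_output_layer = torch.nn.Linear(adim, lid_odim)

dec_states = torch.randn(2, 7, adim)    # (B, Lmax, adim) decoder hidden states (query)
enc_out = torch.randn(2, 30, adim)      # (B, Tmax, adim) encoder outputs (key/value)
att_out, _ = lid_src_att(dec_states, enc_out, enc_out)   # source attention
lid_logits = lid_output_layer(att_out)                    # (B, Lmax, 2) language logits
print(lid_logits.shape)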
class E2E(ASRInterface, torch.nn.Module): """E2E module. Args: idim (int): dimension of inputs odim (int): dimension of outputs args (Namespace): argument Namespace containing options """ @staticmethod def add_arguments(parser): """Extend arguments for transducer models. Both Transformer and RNN modules are supported. General options encapsulate both modules options. """ group = parser.add_argument_group("transformer model setting") # Encoder - general group.add_argument( '--elayers', default=4, type=int, help= 'Number of encoder layers (for shared recognition part in multi-speaker asr mode)' ) group.add_argument('--eunits', '-u', default=300, type=int, help='Number of encoder hidden units') group.add_argument('--kernel-size', default=32, type=int) group.add_argument('--dropout-rate', default=0.0, type=float, help='Dropout rate for the encoder') # Encoder - RNN group.add_argument('--eprojs', default=320, type=int, help='Number of encoder projection units') group.add_argument( '--subsample', default="1", type=str, help='Subsample input frames x_y_z means subsample every x frame \ at 1st layer, every y frame at 2nd layer etc.') # Attention - general group.add_argument( '--transformer-attn-dropout-rate', default=None, type=float, help= 'dropout in transformer attention. use --dropout-rate if None is set' ) group.add_argument( '--adim', default=320, type=int, help='Number of attention transformation dimensions') group.add_argument('--aheads', default=4, type=int, help='Number of heads for multi head attention') # Decoder - general group.add_argument('--dtype', default='lstm', type=str, choices=['lstm', 'gru', 'transformer'], help='Type of decoder to use.') group.add_argument('--dlayers', default=1, type=int, help='Number of decoder layers') group.add_argument('--dunits', default=320, type=int, help='Number of decoder hidden units') group.add_argument('--dropout-rate-decoder', default=0.0, type=float, help='Dropout rate for the decoder') group.add_argument('--relative-v', default=False, type=strtobool) # Decoder - RNN group.add_argument('--dec-embed-dim', default=320, type=int, help='Number of decoder embeddings dimensions') group.add_argument('--dropout-rate-embed-decoder', default=0.0, type=float, help='Dropout rate for the decoder embeddings') # Transformer group.add_argument( "--input-layer", type=str, default="conv2d", choices=["conv2d", "vgg2l", "linear", "embed", "custom"], help='transformer encoder input layer type') group.add_argument('--transformer-lr', default=10.0, type=float, help='Initial value of learning rate') group.add_argument('--transformer-warmup-steps', default=25000, type=int, help='optimizer warmup steps') group.add_argument('--transformer-length-normalized-loss', default=False, type=strtobool, help='normalize loss by length') group.add_argument("--transformer-init", type=str, default="pytorch", choices=[ "pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal" ], help='how to initialize transformer parameters') group.add_argument('--causal', type=strtobool, default=False) return parser def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0): """Construct an E2E object for transducer model. 
Args: idim (int): dimension of inputs odim (int): dimension of outputs args (Namespace): argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder(idim=idim, d_model=args.adim, n_heads=args.aheads, d_ffn=args.eunits, layers=args.elayers, kernel_size=args.kernel_size, input_layer=args.input_layer, dropout_rate=args.dropout_rate, causal=args.causal) args.eprojs = args.adim if args.mtlalpha < 1.0: self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate) self.sos = odim - 1 self.eos = odim - 1 self.ignore_id = ignore_id self.odim = odim self.adim = args.adim self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) # self.verbose = args.verbose self.reset_parameters(args) self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None def reset_parameters(self, args): """Initialize parameters.""" # initialize parameters initialize(self, args.transformer_init) def forward(self, xs_pad, ilens, ys_pad, enc_mask=None): """E2E forward. Args: xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim) ilens (torch.Tensor): batch of lengths of input sequences (B) ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax) Returns: loss (torch.Tensor): transducer loss value """ # 1. encoder xs_pad = xs_pad[:, :max(ilens)] masks = (~make_pad_mask(ilens.tolist())).to( xs_pad.device).unsqueeze(-2) hs_pad, hs_mask = self.encoder(xs_pad, masks) self.hs_pad = hs_pad # CTC forward ys = [y[y != self.ignore_id] for y in ys_pad] y_len = max([len(y) for y in ys]) ys_pad = ys_pad[:, :y_len] self.hs_pad = hs_pad cer_ctc = None batch_size = xs_pad.size(0) if self.mtlalpha == 0.0: loss_ctc = None else: batch_size = xs_pad.size(0) hs_len = hs_mask.view(batch_size, -1).sum(1) loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad) # trigger mask start_time = time.time() # 2. forward decoder ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_mask = target_mask(ys_in_pad, self.ignore_id) pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) self.pred_pad = pred_pad # 3. compute attention loss loss_att = self.criterion(pred_pad, ys_out_pad) self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id) # copyied from e2e_asr alpha = self.mtlalpha if alpha == 0: self.loss = loss_att loss_att_data = float(loss_att) loss_ctc_data = None elif alpha == 1: self.loss = loss_ctc loss_att_data = None loss_ctc_data = float(loss_ctc) else: self.loss = alpha * loss_ctc + (1 - alpha) * loss_att loss_att_data = float(loss_att) loss_ctc_data = float(loss_ctc) return self.loss, loss_att_data, loss_ctc_data, self.acc def encode(self, x, streaming_mask=None): """Encode acoustic features. 
Args: x (ndarray): input acoustic feature (T, D) Returns: x (torch.Tensor): encoded features (T, attention_dim) """ self.eval() x = torch.as_tensor(x).unsqueeze(0).cuda() enc_output, _ = self.encoder(x, None, streaming_mask) return enc_output.squeeze(0) def greedy_recognize(self, x, recog_args, streaming_mask=None): h = self.encode(x, streaming_mask).unsqueeze(0) lpz = self.ctc.log_softmax(h) ys_hat = lpz.argmax(dim=-1)[0].cpu() ids = [] for i in range(len(ys_hat)): if ys_hat[i].item() != 0 and ys_hat[i].item() != ys_hat[i - 1].item(): ids.append(ys_hat[i].item()) rets = [{'score': 0.0, 'yseq': ids}] return rets def recognize(self, x, recog_args, char_list=None, rnnlm=None, streaming_mask=None): """ Recognize input features. Args: x (ndarray): input acoustic feature (T, D) recog_args (namespace): argument Namespace containing options char_list (list): list of characters rnnlm (torch.nn.Module): language model module Returns: y (list): n-best decoding results """ h = self.encode(x, streaming_mask) params = [h, recog_args] if recog_args.beam_size == 1: nbest_hyps = self.decoder.recognize(*params) else: params.append(rnnlm) start_time = time.time() nbest_hyps = self.decoder.recognize_beam(*params) #nbest_hyps, decoder_time = self.decoder.beam_search(h, recog_args, prefix=False) end_time = time.time() decode_time = end_time - start_time return nbest_hyps def calculate_all_attentions(self, xs_pad, ilens, ys_pad): """E2E attention calculation. Args: xs_pad (torch.Tensor): batch of padded input sequences (B, Tmax, idim) ilens (torch.Tensor): batch of lengths of input sequences (B) ys_pad (torch.Tensor): batch of padded character id sequence tensor (B, Lmax) Returns: ret (ndarray): attention weights with the following shape, 1) multi-head case => attention weights (B, H, Lmax, Tmax), 2) other case => attention weights (B, Lmax, Tmax). """ if self.etype == 'transformer' and self.dtype != 'transformer' and \ self.rnnt_mode == 'rnnt-att': raise NotImplementedError( "Transformer encoder with rnn attention decoder" "is not supported yet.") elif self.etype != 'transformer' and self.dtype != 'transformer': if self.rnnt_mode == 'rnnt': return [] else: with torch.no_grad(): hs_pad, hlens = xs_pad, ilens hpad, hlens, _ = self.encoder(hs_pad, hlens) ret = self.decoder.calculate_all_attentions( hpad, hlens, ys_pad) else: with torch.no_grad(): self.forward(xs_pad, ilens, ys_pad) ret = dict() for name, m in self.named_modules(): if isinstance(m, MultiHeadedAttention): ret[name] = m.attn.cpu().numpy() return ret
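# `greedy_recognize` above performs the standard CTC greedy collapse on the frame-wise
# argmax: merge consecutive repeats and drop the blank symbol (index 0 here).  The
# following is a minimal, self-contained sketch of that collapse; the tensor values are
# illustrative only and this helper is not part of the espnet API.
import torch


def ctc_greedy_collapse(frame_ids, blank_id=0):
    """Collapse a frame-level CTC argmax path into a label sequence."""
    ids = []
    prev = blank_id
    for i in frame_ids.tolist():
        # keep a frame only if it is not blank and differs from the previous frame
        if i != blank_id and i != prev:
            ids.append(i)
        prev = i
    return ids


print(ctc_greedy_collapse(torch.tensor([0, 3, 3, 0, 5, 5, 0, 2])))  # -> [3, 5, 2]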
def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, attention_dim=2 * args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate) self.decoder = Decoder( odim=odim, attention_dim=2 * args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = [1] self.reporter = Reporter() # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) # self.verbose = args.verbose self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC(odim, 2 * args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None if args.report_cer or args.report_wer: from espnet.nets.e2e_asr_common import ErrorCalculator self.error_calculator = ErrorCalculator(args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer) else: self.error_calculator = None self.rnnlm = None # yzl23 config # the encoder is enlarged to 2*adim, # thus a layer-norm affine transformation is needed # self.enc_proj = torch.nn.Linear(2*args.adim, 2*args.adim, bias=True) # self.enc_proj_ln = LayerNorm(2*args.adim) # compatible with previous self.enc_proj = torch.nn.Linear(2 * args.adim, 2 * args.adim, bias=False) # espnet CTC decoding-bug, remove blank in prefix-decoding self.remove_blank_in_ctc_mode = True self.reset_parameters(args) # reset params at the last logging.warning( "Model total size: {}M, requires_grad size: {}M".format( self.count_parameters(), self.count_parameters(requires_grad=True)))
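# The logging call above assumes a `count_parameters` helper on the model that reports
# sizes in millions of parameters.  A minimal sketch of what such a helper might look
# like; the rounding and the "millions" convention are assumptions, not espnet API.
import torch


def count_parameters(model: torch.nn.Module, requires_grad: bool = False) -> float:
    """Return the number of (optionally trainable-only) parameters, in millions."""
    params = (p for p in model.parameters() if p.requires_grad or not requires_grad)
    return round(sum(p.numel() for p in params) / 1e6, 2)

# usage sketch:
#   logging.warning("Model total size: {}M, requires_grad size: {}M".format(
#       count_parameters(model), count_parameters(model, requires_grad=True)))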
class E2E(STInterface, torch.nn.Module): """E2E module. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ @staticmethod def add_arguments(parser): """Add arguments.""" group = parser.add_argument_group("transformer model setting") group = add_arguments_transformer_common(group) return parser @property def attention_plot_class(self): """Return PlotAttentionReport.""" return PlotAttentionReport def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) # fill missing arguments for compatibility args = fill_missing_args(args, self.add_arguments) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, selfattention_layer_type=args. transformer_encoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_encoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.decoder = Decoder( odim=odim, selfattention_layer_type=args. transformer_decoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_decoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.pad = 0 # use <blank> for padding self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = get_subsample(args, mode="st", arch="transformer") self.reporter = Reporter() self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss, ) # submodule for ASR task self.mtlalpha = args.mtlalpha self.asr_weight = args.asr_weight if self.asr_weight > 0 and args.mtlalpha < 1: self.decoder_asr = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, ) # submodule for MT task self.mt_weight = args.mt_weight if self.mt_weight > 0: self.encoder_mt = Encoder( idim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, input_layer="embed", dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, padding_idx=0, ) self.reset_parameters( args) # NOTE: place after the submodule initialization self.adim = args.adim # used for CTC (equal to d_model) if self.asr_weight > 0 and args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None # translation error calculator self.error_calculator = 
MTErrorCalculator(args.char_list, args.sym_space, args.sym_blank, args.report_bleu) # recognition error calculator self.error_calculator_asr = ASRErrorCalculator( args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer, ) self.rnnlm = None # multilingual E2E-ST related self.multilingual = getattr(args, "multilingual", False) self.replace_sos = getattr(args, "replace_sos", False) def reset_parameters(self, args): """Initialize parameters.""" initialize(self, args.transformer_init) if self.mt_weight > 0: torch.nn.init.normal_(self.encoder_mt.embed[0].weight, mean=0, std=args.adim**-0.5) torch.nn.init.constant_(self.encoder_mt.embed[0].weight[self.pad], 0) def forward(self, xs_pad, ilens, ys_pad, ys_pad_src): """E2E forward. :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of source sequences (B) :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) :param torch.Tensor ys_pad_src: batch of padded target sequences (B, Lmax) :return: ctc loss value :rtype: torch.Tensor :return: attention loss value :rtype: torch.Tensor :return: accuracy in attention decoder :rtype: float """ # 0. Extract target language ID tgt_lang_ids = None if self.multilingual: tgt_lang_ids = ys_pad[:, 0:1] ys_pad = ys_pad[:, 1:] # remove target language ID in the beggining # 1. forward encoder xs_pad = xs_pad[:, :max(ilens)] # for data parallel src_mask = make_non_pad_mask(ilens.tolist()).to( xs_pad.device).unsqueeze(-2) hs_pad, hs_mask = self.encoder(xs_pad, src_mask) # 2. forward decoder ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) # replace <sos> with target language ID if self.replace_sos: ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1) ys_mask = target_mask(ys_in_pad, self.ignore_id) pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) # 3. compute ST loss loss_att = self.criterion(pred_pad, ys_out_pad) self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id) # 4. compute corpus-level bleu in a mini-batch if self.training: self.bleu = None else: ys_hat = pred_pad.argmax(dim=-1) self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) # 5. compute auxiliary ASR loss loss_asr_att, acc_asr, loss_asr_ctc, cer_ctc, cer, wer = self.forward_asr( hs_pad, hs_mask, ys_pad_src) # 6. compute auxiliary MT loss loss_mt, acc_mt = 0.0, None if self.mt_weight > 0: loss_mt, acc_mt = self.forward_mt(ys_pad_src, ys_in_pad, ys_out_pad, ys_mask) asr_ctc_weight = self.mtlalpha self.loss = ((1 - self.asr_weight - self.mt_weight) * loss_att + self.asr_weight * (asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att) + self.mt_weight * loss_mt) loss_asr_data = float(asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att) loss_mt_data = None if self.mt_weight == 0 else float(loss_mt) loss_st_data = float(loss_att) loss_data = float(self.loss) if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): self.reporter.report( loss_asr_data, loss_mt_data, loss_st_data, acc_asr, acc_mt, self.acc, cer_ctc, cer, wer, self.bleu, loss_data, ) else: logging.warning("loss (=%f) is not correct", loss_data) return self.loss def forward_asr(self, hs_pad, hs_mask, ys_pad): """Forward pass in the auxiliary ASR task. 
:param torch.Tensor hs_pad: batch of padded source sequences (B, Tmax, idim) :param torch.Tensor hs_mask: batch of input token mask (B, Lmax) :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) :return: ASR attention loss value :rtype: torch.Tensor :return: accuracy in ASR attention decoder :rtype: float :return: ASR CTC loss value :rtype: torch.Tensor :return: character error rate from CTC prediction :rtype: float :return: character error rate from attetion decoder prediction :rtype: float :return: word error rate from attetion decoder prediction :rtype: float """ loss_att, loss_ctc = 0.0, 0.0 acc = None cer, wer = None, None cer_ctc = None if self.asr_weight == 0: return loss_att, acc, loss_ctc, cer_ctc, cer, wer # attention if self.mtlalpha < 1: ys_in_pad_asr, ys_out_pad_asr = add_sos_eos( ys_pad, self.sos, self.eos, self.ignore_id) ys_mask_asr = target_mask(ys_in_pad_asr, self.ignore_id) pred_pad, _ = self.decoder_asr(ys_in_pad_asr, ys_mask_asr, hs_pad, hs_mask) loss_att = self.criterion(pred_pad, ys_out_pad_asr) acc = th_accuracy( pred_pad.view(-1, self.odim), ys_out_pad_asr, ignore_label=self.ignore_id, ) if not self.training: ys_hat_asr = pred_pad.argmax(dim=-1) cer, wer = self.error_calculator_asr(ys_hat_asr.cpu(), ys_pad.cpu()) # CTC if self.mtlalpha > 0: batch_size = hs_pad.size(0) hs_len = hs_mask.view(batch_size, -1).sum(1) loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad) if not self.training: ys_hat_ctc = self.ctc.argmax( hs_pad.view(batch_size, -1, self.adim)).data cer_ctc = self.error_calculator_asr(ys_hat_ctc.cpu(), ys_pad.cpu(), is_ctc=True) # for visualization self.ctc.softmax(hs_pad) return loss_att, acc, loss_ctc, cer_ctc, cer, wer def forward_mt(self, xs_pad, ys_in_pad, ys_out_pad, ys_mask): """Forward pass in the auxiliary MT task. :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) :param torch.Tensor ys_in_pad: batch of padded target sequences (B, Lmax) :param torch.Tensor ys_out_pad: batch of padded target sequences (B, Lmax) :param torch.Tensor ys_mask: batch of input token mask (B, Lmax) :return: MT loss value :rtype: torch.Tensor :return: accuracy in MT decoder :rtype: float """ loss, acc = 0.0, None if self.mt_weight == 0: return loss, acc ilens = torch.sum(xs_pad != self.ignore_id, dim=1).cpu().numpy() # NOTE: xs_pad is padded with -1 xs = [x[x != self.ignore_id] for x in xs_pad] # parse padded xs xs_zero_pad = pad_list(xs, self.pad) # re-pad with zero xs_zero_pad = xs_zero_pad[:, :max(ilens)] # for data parallel src_mask = (make_non_pad_mask(ilens.tolist()).to( xs_zero_pad.device).unsqueeze(-2)) hs_pad, hs_mask = self.encoder_mt(xs_zero_pad, src_mask) pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) loss = self.criterion(pred_pad, ys_out_pad) acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id) return loss, acc def scorers(self): """Scorers.""" return dict(decoder=self.decoder) def encode(self, x): """Encode source acoustic features. :param ndarray x: source acoustic feature (T, D) :return: encoder outputs :rtype: torch.Tensor """ self.eval() x = torch.as_tensor(x).unsqueeze(0) enc_output, _ = self.encoder(x, None) return enc_output.squeeze(0) def translate( self, x, trans_args, char_list=None, ): """Translate input speech. 
:param ndarray x: input acoustic feature (B, T, D) or (T, D) :param Namespace trans_args: argument Namespace containing options :param list char_list: list of characters :return: N-best decoding results :rtype: list """ # prepare sos if getattr(trans_args, "tgt_lang", False): if self.replace_sos: y = char_list.index(trans_args.tgt_lang) else: y = self.sos logging.info("<sos> index: " + str(y)) logging.info("<sos> mark: " + char_list[y]) logging.info("input lengths: " + str(x.shape[0])) enc_output = self.encode(x).unsqueeze(0) h = enc_output logging.info("encoder output lengths: " + str(h.size(1))) # search params beam = trans_args.beam_size penalty = trans_args.penalty if trans_args.maxlenratio == 0: maxlen = h.size(1) else: # maxlen >= 1 maxlen = max(1, int(trans_args.maxlenratio * h.size(1))) minlen = int(trans_args.minlenratio * h.size(1)) logging.info("max output length: " + str(maxlen)) logging.info("min output length: " + str(minlen)) # initialize hypothesis hyp = {"score": 0.0, "yseq": [y]} hyps = [hyp] ended_hyps = [] for i in range(maxlen): logging.debug("position " + str(i)) # batchify ys = h.new_zeros((len(hyps), i + 1), dtype=torch.int64) for j, hyp in enumerate(hyps): ys[j, :] = torch.tensor(hyp["yseq"]) ys_mask = subsequent_mask(i + 1).unsqueeze(0).to(h.device) local_scores = self.decoder.forward_one_step( ys, ys_mask, h.repeat([len(hyps), 1, 1]))[0] hyps_best_kept = [] for j, hyp in enumerate(hyps): local_best_scores, local_best_ids = torch.topk( local_scores[j:j + 1], beam, dim=1) for j in range(beam): new_hyp = {} new_hyp["score"] = hyp["score"] + float( local_best_scores[0, j]) new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"])) new_hyp["yseq"][:len(hyp["yseq"])] = hyp["yseq"] new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j]) # will be (2 x beam) hyps at most hyps_best_kept.append(new_hyp) hyps_best_kept = sorted(hyps_best_kept, key=lambda x: x["score"], reverse=True)[:beam] # sort and get nbest hyps = hyps_best_kept logging.debug("number of pruned hypotheses: " + str(len(hyps))) if char_list is not None: logging.debug( "best hypo: " + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])) # add eos in the final loop to avoid having no ended hyps if i == maxlen - 1: logging.info("adding <eos> in the last position in the loop") for hyp in hyps: hyp["yseq"].append(self.eos) # add ended hypotheses to the final list and remove them from the current hypotheses # (this will be a problem if the number of hyps < beam) remained_hyps = [] for hyp in hyps: if hyp["yseq"][-1] == self.eos: # only store the sequence that has more than minlen outputs # also add penalty if len(hyp["yseq"]) > minlen: hyp["score"] += (i + 1) * penalty ended_hyps.append(hyp) else: remained_hyps.append(hyp) # end detection if end_detect(ended_hyps, i) and trans_args.maxlenratio == 0.0: logging.info("end detected at %d", i) break hyps = remained_hyps if len(hyps) > 0: logging.debug("remaining hypotheses: " + str(len(hyps))) else: logging.info("no hypothesis. Finish decoding.") break if char_list is not None: for hyp in hyps: logging.debug( "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])) logging.debug("number of ended hypotheses: " + str(len(ended_hyps))) nbest_hyps = sorted( ended_hyps, key=lambda x: x["score"], reverse=True)[:min(len(ended_hyps), trans_args.nbest)] # check the number of hypotheses if len(nbest_hyps) == 0: logging.warning("there are no N-best results, perform translation " "again with smaller minlenratio.") # should copy because Namespace will be overwritten globally trans_args = Namespace(**vars(trans_args)) trans_args.minlenratio = max(0.0, trans_args.minlenratio - 0.1) return self.translate(x, trans_args, char_list) logging.info("total log probability: " + str(nbest_hyps[0]["score"])) logging.info("normalized log probability: " + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))) return nbest_hyps def calculate_all_attentions(self, xs_pad, ilens, ys_pad, ys_pad_src): """E2E attention calculation. :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of input sequences (B) :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) :param torch.Tensor ys_pad_src: batch of padded token id sequence tensor (B, Lmax) :return: attention weights (B, H, Lmax, Tmax) :rtype: float ndarray """ self.eval() with torch.no_grad(): self.forward(xs_pad, ilens, ys_pad, ys_pad_src) ret = dict() for name, m in self.named_modules(): if (isinstance(m, MultiHeadedAttention) and m.attn is not None): # skip MHA for submodules ret[name] = m.attn.cpu().numpy() self.train() return ret def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad, ys_pad_src): """E2E CTC probability calculation. :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax) :param torch.Tensor ilens: batch of lengths of input sequences (B) :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) :param torch.Tensor ys_pad_src: batch of padded token id sequence tensor (B, Lmax) :return: CTC probability (B, Tmax, vocab) :rtype: float ndarray """ ret = None if self.asr_weight == 0 or self.mtlalpha == 0: return ret self.eval() with torch.no_grad(): self.forward(xs_pad, ilens, ys_pad, ys_pad_src) ret = None for name, m in self.named_modules(): if isinstance(m, CTC) and m.probs is not None: ret = m.probs.cpu().numpy() self.train() return ret
class Transformer(TTSInterface, torch.nn.Module): """VC Transformer module. This is a module of the Voice Transformer Network (a.k.a. VTN or Transformer-VC) described in `Voice Transformer Network: Sequence-to-Sequence Voice Conversion Using Transformer with Text-to-Speech Pretraining`_, which convert the sequence of acoustic features into the sequence of acoustic features. .. _`Voice Transformer Network: Sequence-to-Sequence Voice Conversion Using Transformer with Text-to-Speech Pretraining`: https://arxiv.org/pdf/1912.06813.pdf """ @staticmethod def add_arguments(parser): """Add model-specific arguments to the parser.""" group = parser.add_argument_group("transformer model setting") # network structure related group.add_argument( "--eprenet-conv-layers", default=0, type=int, help="Number of encoder prenet convolution layers", ) group.add_argument( "--eprenet-conv-chans", default=0, type=int, help="Number of encoder prenet convolution channels", ) group.add_argument( "--eprenet-conv-filts", default=0, type=int, help="Filter size of encoder prenet convolution", ) group.add_argument( "--transformer-input-layer", default="linear", type=str, help="Type of input layer (linear or conv2d)", ) group.add_argument( "--dprenet-layers", default=2, type=int, help="Number of decoder prenet layers", ) group.add_argument( "--dprenet-units", default=256, type=int, help="Number of decoder prenet hidden units", ) group.add_argument("--elayers", default=3, type=int, help="Number of encoder layers") group.add_argument("--eunits", default=1536, type=int, help="Number of encoder hidden units") group.add_argument( "--adim", default=384, type=int, help="Number of attention transformation dimensions", ) group.add_argument( "--aheads", default=4, type=int, help="Number of heads for multi head attention", ) group.add_argument("--dlayers", default=3, type=int, help="Number of decoder layers") group.add_argument("--dunits", default=1536, type=int, help="Number of decoder hidden units") group.add_argument( "--positionwise-layer-type", default="linear", type=str, choices=["linear", "conv1d", "conv1d-linear"], help="Positionwise layer type.", ) group.add_argument( "--positionwise-conv-kernel-size", default=1, type=int, help="Kernel size of positionwise conv1d layer", ) group.add_argument("--postnet-layers", default=5, type=int, help="Number of postnet layers") group.add_argument("--postnet-chans", default=256, type=int, help="Number of postnet channels") group.add_argument("--postnet-filts", default=5, type=int, help="Filter size of postnet") group.add_argument( "--use-scaled-pos-enc", default=True, type=strtobool, help="Use trainable scaled positional encoding" "instead of the fixed scale one.", ) group.add_argument( "--use-batch-norm", default=True, type=strtobool, help="Whether to use batch normalization", ) group.add_argument( "--encoder-normalize-before", default=False, type=strtobool, help="Whether to apply layer norm before encoder block", ) group.add_argument( "--decoder-normalize-before", default=False, type=strtobool, help="Whether to apply layer norm before decoder block", ) group.add_argument( "--encoder-concat-after", default=False, type=strtobool, help= "Whether to concatenate attention layer's input and output in encoder", ) group.add_argument( "--decoder-concat-after", default=False, type=strtobool, help= "Whether to concatenate attention layer's input and output in decoder", ) group.add_argument( "--reduction-factor", default=1, type=int, help="Reduction factor (for decoder)", ) group.add_argument( 
"--encoder-reduction-factor", default=1, type=int, help="Reduction factor (for encoder)", ) group.add_argument( "--spk-embed-dim", default=None, type=int, help="Number of speaker embedding dimensions", ) group.add_argument( "--spk-embed-integration-type", type=str, default="add", choices=["add", "concat"], help="How to integrate speaker embedding", ) # training related group.add_argument( "--transformer-init", type=str, default="pytorch", choices=[ "pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal", ], help="How to initialize transformer parameters", ) group.add_argument( "--initial-encoder-alpha", type=float, default=1.0, help="Initial alpha value in encoder's ScaledPositionalEncoding", ) group.add_argument( "--initial-decoder-alpha", type=float, default=1.0, help="Initial alpha value in decoder's ScaledPositionalEncoding", ) group.add_argument( "--transformer-lr", default=1.0, type=float, help="Initial value of learning rate", ) group.add_argument( "--transformer-warmup-steps", default=4000, type=int, help="Optimizer warmup steps", ) group.add_argument( "--transformer-enc-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer encoder except for attention", ) group.add_argument( "--transformer-enc-positional-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer encoder positional encoding", ) group.add_argument( "--transformer-enc-attn-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer encoder self-attention", ) group.add_argument( "--transformer-dec-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer decoder " "except for attention and pos encoding", ) group.add_argument( "--transformer-dec-positional-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer decoder positional encoding", ) group.add_argument( "--transformer-dec-attn-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer decoder self-attention", ) group.add_argument( "--transformer-enc-dec-attn-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer encoder-decoder attention", ) group.add_argument( "--eprenet-dropout-rate", default=0.5, type=float, help="Dropout rate in encoder prenet", ) group.add_argument( "--dprenet-dropout-rate", default=0.5, type=float, help="Dropout rate in decoder prenet", ) group.add_argument( "--postnet-dropout-rate", default=0.5, type=float, help="Dropout rate in postnet", ) group.add_argument("--pretrained-model", default=None, type=str, help="Pretrained model path") # loss related group.add_argument( "--use-masking", default=True, type=strtobool, help="Whether to use masking in calculation of loss", ) group.add_argument( "--use-weighted-masking", default=False, type=strtobool, help="Whether to use weighted masking in calculation of loss", ) group.add_argument( "--loss-type", default="L1", choices=["L1", "L2", "L1+L2"], help="How to calc loss", ) group.add_argument( "--bce-pos-weight", default=5.0, type=float, help="Positive sample weight in BCE calculation " "(only for use-masking=True)", ) group.add_argument( "--use-guided-attn-loss", default=False, type=strtobool, help="Whether to use guided attention loss", ) group.add_argument( "--guided-attn-loss-sigma", default=0.4, type=float, help="Sigma in guided attention loss", ) group.add_argument( "--guided-attn-loss-lambda", default=1.0, type=float, help="Lambda in guided attention loss", ) group.add_argument( "--num-heads-applied-guided-attn", default=2, type=int, help= 
"Number of heads in each layer to be applied guided attention loss" "if set -1, all of the heads will be applied.", ) group.add_argument( "--num-layers-applied-guided-attn", default=2, type=int, help="Number of layers to be applied guided attention loss" "if set -1, all of the layers will be applied.", ) group.add_argument( "--modules-applied-guided-attn", type=str, nargs="+", default=["encoder-decoder"], help="Module name list to be applied guided attention loss", ) return parser @property def attention_plot_class(self): """Return plot class for attention weight plot.""" return TTSPlot def __init__(self, idim, odim, args=None): """Initialize Transformer-VC module. Args: idim (int): Dimension of the inputs. odim (int): Dimension of the outputs. args (Namespace, optional): - eprenet_conv_layers (int): Number of encoder prenet convolution layers. - eprenet_conv_chans (int): Number of encoder prenet convolution channels. - eprenet_conv_filts (int): Filter size of encoder prenet convolution. - transformer_input_layer (str): Input layer before the encoder. - dprenet_layers (int): Number of decoder prenet layers. - dprenet_units (int): Number of decoder prenet hidden units. - elayers (int): Number of encoder layers. - eunits (int): Number of encoder hidden units. - adim (int): Number of attention transformation dimensions. - aheads (int): Number of heads for multi head attention. - dlayers (int): Number of decoder layers. - dunits (int): Number of decoder hidden units. - postnet_layers (int): Number of postnet layers. - postnet_chans (int): Number of postnet channels. - postnet_filts (int): Filter size of postnet. - use_scaled_pos_enc (bool): Whether to use trainable scaled positional encoding. - use_batch_norm (bool): Whether to use batch normalization in encoder prenet. - encoder_normalize_before (bool): Whether to perform layer normalization before encoder block. - decoder_normalize_before (bool): Whether to perform layer normalization before decoder block. - encoder_concat_after (bool): Whether to concatenate attention layer's input and output in encoder. - decoder_concat_after (bool): Whether to concatenate attention layer's input and output in decoder. - reduction_factor (int): Reduction factor (for decoder). - encoder_reduction_factor (int): Reduction factor (for encoder). - spk_embed_dim (int): Number of speaker embedding dimenstions. - spk_embed_integration_type: How to integrate speaker embedding. - transformer_init (float): How to initialize transformer parameters. - transformer_lr (float): Initial value of learning rate. - transformer_warmup_steps (int): Optimizer warmup steps. - transformer_enc_dropout_rate (float): Dropout rate in encoder except attention & positional encoding. - transformer_enc_positional_dropout_rate (float): Dropout rate after encoder positional encoding. - transformer_enc_attn_dropout_rate (float): Dropout rate in encoder self-attention module. - transformer_dec_dropout_rate (float): Dropout rate in decoder except attention & positional encoding. - transformer_dec_positional_dropout_rate (float): Dropout rate after decoder positional encoding. - transformer_dec_attn_dropout_rate (float): Dropout rate in deocoder self-attention module. - transformer_enc_dec_attn_dropout_rate (float): Dropout rate in encoder-deocoder attention module. - eprenet_dropout_rate (float): Dropout rate in encoder prenet. - dprenet_dropout_rate (float): Dropout rate in decoder prenet. - postnet_dropout_rate (float): Dropout rate in postnet. 
- use_masking (bool): Whether to apply masking for padded part in loss calculation. - use_weighted_masking (bool): Whether to apply weighted masking in loss calculation. - bce_pos_weight (float): Positive sample weight in bce calculation (only for use_masking=true). - loss_type (str): How to calculate loss. - use_guided_attn_loss (bool): Whether to use guided attention loss. - num_heads_applied_guided_attn (int): Number of heads in each layer to apply guided attention loss. - num_layers_applied_guided_attn (int): Number of layers to apply guided attention loss. - modules_applied_guided_attn (list): List of module names to apply guided attention loss. - guided-attn-loss-sigma (float) Sigma in guided attention loss. - guided-attn-loss-lambda (float): Lambda in guided attention loss. """ # initialize base classes TTSInterface.__init__(self) torch.nn.Module.__init__(self) # fill missing arguments args = fill_missing_args(args, self.add_arguments) # store hyperparameters self.idim = idim self.odim = odim self.spk_embed_dim = args.spk_embed_dim if self.spk_embed_dim is not None: self.spk_embed_integration_type = args.spk_embed_integration_type self.use_scaled_pos_enc = args.use_scaled_pos_enc self.reduction_factor = args.reduction_factor self.encoder_reduction_factor = args.encoder_reduction_factor self.transformer_input_layer = args.transformer_input_layer self.loss_type = args.loss_type self.use_guided_attn_loss = args.use_guided_attn_loss if self.use_guided_attn_loss: if args.num_layers_applied_guided_attn == -1: self.num_layers_applied_guided_attn = args.elayers else: self.num_layers_applied_guided_attn = ( args.num_layers_applied_guided_attn) if args.num_heads_applied_guided_attn == -1: self.num_heads_applied_guided_attn = args.aheads else: self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn self.modules_applied_guided_attn = args.modules_applied_guided_attn # use idx 0 as padding idx padding_idx = 0 # get positional encoding class pos_enc_class = (ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding) # define transformer encoder if args.eprenet_conv_layers != 0: # encoder prenet encoder_input_layer = torch.nn.Sequential( EncoderPrenet( idim=idim, elayers=0, econv_layers=args.eprenet_conv_layers, econv_chans=args.eprenet_conv_chans, econv_filts=args.eprenet_conv_filts, use_batch_norm=args.use_batch_norm, dropout_rate=args.eprenet_dropout_rate, padding_idx=padding_idx, input_layer=torch.nn.Linear( idim * args.encoder_reduction_factor, idim), ), torch.nn.Linear(args.eprenet_conv_chans, args.adim), ) elif args.transformer_input_layer == "linear": encoder_input_layer = torch.nn.Linear( idim * args.encoder_reduction_factor, args.adim) else: encoder_input_layer = args.transformer_input_layer self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=encoder_input_layer, dropout_rate=args.transformer_enc_dropout_rate, positional_dropout_rate=args. 
transformer_enc_positional_dropout_rate, attention_dropout_rate=args.transformer_enc_attn_dropout_rate, pos_enc_class=pos_enc_class, normalize_before=args.encoder_normalize_before, concat_after=args.encoder_concat_after, positionwise_layer_type=args.positionwise_layer_type, positionwise_conv_kernel_size=args.positionwise_conv_kernel_size, ) # define projection layer if self.spk_embed_dim is not None: if self.spk_embed_integration_type == "add": self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim) else: self.projection = torch.nn.Linear( args.adim + self.spk_embed_dim, args.adim) # define transformer decoder if args.dprenet_layers != 0: # decoder prenet decoder_input_layer = torch.nn.Sequential( DecoderPrenet( idim=odim, n_layers=args.dprenet_layers, n_units=args.dprenet_units, dropout_rate=args.dprenet_dropout_rate, ), torch.nn.Linear(args.dprenet_units, args.adim), ) else: decoder_input_layer = "linear" self.decoder = Decoder( odim=-1, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.transformer_dec_dropout_rate, positional_dropout_rate=args. transformer_dec_positional_dropout_rate, self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate, src_attention_dropout_rate=args. transformer_enc_dec_attn_dropout_rate, input_layer=decoder_input_layer, use_output_layer=False, pos_enc_class=pos_enc_class, normalize_before=args.decoder_normalize_before, concat_after=args.decoder_concat_after, ) # define final projection self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor) self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor) # define postnet self.postnet = (None if args.postnet_layers == 0 else Postnet( idim=idim, odim=odim, n_layers=args.postnet_layers, n_chans=args.postnet_chans, n_filts=args.postnet_filts, use_batch_norm=args.use_batch_norm, dropout_rate=args.postnet_dropout_rate, )) # define loss function self.criterion = TransformerLoss( use_masking=args.use_masking, use_weighted_masking=args.use_weighted_masking, bce_pos_weight=args.bce_pos_weight, ) if self.use_guided_attn_loss: self.attn_criterion = GuidedMultiHeadAttentionLoss( sigma=args.guided_attn_loss_sigma, alpha=args.guided_attn_loss_lambda, ) # initialize parameters self._reset_parameters( init_type=args.transformer_init, init_enc_alpha=args.initial_encoder_alpha, init_dec_alpha=args.initial_decoder_alpha, ) # load pretrained model if args.pretrained_model is not None: self.load_pretrained_model(args.pretrained_model) def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_alpha=1.0): # initialize parameters initialize(self, init_type) # initialize alpha in scaled positional encoding if self.use_scaled_pos_enc: self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha) def _add_first_frame_and_remove_last_frame(self, ys): ys_in = torch.cat( [ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1) return ys_in def forward(self, xs, ilens, ys, labels, olens, spembs=None, *args, **kwargs): """Calculate forward propagation. Args: xs (Tensor): Batch of padded acoustic features (B, Tmax, idim). ilens (LongTensor): Batch of lengths of each input batch (B,). ys (Tensor): Batch of padded target features (B, Lmax, odim). olens (LongTensor): Batch of the lengths of each target (B,). spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim). Returns: Tensor: Loss value. 
""" # remove unnecessary padded part (for multi-gpus) max_ilen = max(ilens) max_olen = max(olens) if max_ilen != xs.shape[1]: xs = xs[:, :max_ilen] if max_olen != ys.shape[1]: ys = ys[:, :max_olen] labels = labels[:, :max_olen] # thin out input frames for reduction factor # (B, Lmax, idim) -> (B, Lmax // r, idim * r) if self.encoder_reduction_factor > 1: B, Lmax, idim = xs.shape if Lmax % self.encoder_reduction_factor != 0: xs = xs[:, :-(Lmax % self.encoder_reduction_factor), :] xs_ds = xs.contiguous().view( B, int(Lmax / self.encoder_reduction_factor), idim * self.encoder_reduction_factor, ) ilens_ds = ilens.new( [ilen // self.encoder_reduction_factor for ilen in ilens]) else: xs_ds, ilens_ds = xs, ilens # forward encoder x_masks = self._source_mask(ilens_ds) hs, hs_masks = self.encoder(xs_ds, x_masks) # integrate speaker embedding if self.spk_embed_dim is not None: hs_int = self._integrate_with_spk_embed(hs, spembs) else: hs_int = hs # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim) if self.reduction_factor > 1: ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor] olens_in = olens.new( [olen // self.reduction_factor for olen in olens]) else: ys_in, olens_in = ys, olens # add first zero frame and remove last frame for auto-regressive ys_in = self._add_first_frame_and_remove_last_frame(ys_in) # if conv2d, modify mask. Use ceiling division here if "conv2d" in self.transformer_input_layer: ilens_ds_st = ilens_ds.new([((ilen - 2 + 1) // 2 - 2 + 1) // 2 for ilen in ilens_ds]) else: ilens_ds_st = ilens_ds # forward decoder y_masks = self._target_mask(olens_in) zs, _ = self.decoder(ys_in, y_masks, hs_int, hs_masks) # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim) before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) # (B, Lmax//r, r) -> (B, Lmax//r * r) logits = self.prob_out(zs).view(zs.size(0), -1) # postnet -> (B, Lmax//r * r, odim) if self.postnet is None: after_outs = before_outs else: after_outs = before_outs + self.postnet(before_outs.transpose( 1, 2)).transpose(1, 2) # modifiy mod part of groundtruth if self.reduction_factor > 1: assert olens.ge(self.reduction_factor).all( ), "Output length must be greater than or equal to reduction factor." olens = olens.new( [olen - olen % self.reduction_factor for olen in olens]) max_olen = max(olens) ys = ys[:, :max_olen] labels = labels[:, :max_olen] labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1), 1.0) # see #3388 # calculate loss values l1_loss, l2_loss, bce_loss = self.criterion(after_outs, before_outs, logits, ys, labels, olens) if self.loss_type == "L1": loss = l1_loss + bce_loss elif self.loss_type == "L2": loss = l2_loss + bce_loss elif self.loss_type == "L1+L2": loss = l1_loss + l2_loss + bce_loss else: raise ValueError("unknown --loss-type " + self.loss_type) report_keys = [ { "l1_loss": l1_loss.item() }, { "l2_loss": l2_loss.item() }, { "bce_loss": bce_loss.item() }, { "loss": loss.item() }, ] # calculate guided attention loss if self.use_guided_attn_loss: # calculate for encoder if "encoder" in self.modules_applied_guided_attn: att_ws = [] for idx, layer_idx in enumerate( reversed(range(len(self.encoder.encoders)))): att_ws += [ self.encoder.encoders[layer_idx].self_attn. attn[:, :self.num_heads_applied_guided_attn] ] if idx + 1 == self.num_layers_applied_guided_attn: break att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_in, T_in) enc_attn_loss = self.attn_criterion( att_ws, ilens_ds_st, ilens_ds_st ) # TODO(unilight): is changing to ilens_ds_st right? 
loss = loss + enc_attn_loss report_keys += [{"enc_attn_loss": enc_attn_loss.item()}] # calculate for decoder if "decoder" in self.modules_applied_guided_attn: att_ws = [] for idx, layer_idx in enumerate( reversed(range(len(self.decoder.decoders)))): att_ws += [ self.decoder.decoders[layer_idx].self_attn. attn[:, :self.num_heads_applied_guided_attn] ] if idx + 1 == self.num_layers_applied_guided_attn: break att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_out, T_out) dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in) loss = loss + dec_attn_loss report_keys += [{"dec_attn_loss": dec_attn_loss.item()}] # calculate for encoder-decoder if "encoder-decoder" in self.modules_applied_guided_attn: att_ws = [] for idx, layer_idx in enumerate( reversed(range(len(self.decoder.decoders)))): att_ws += [ self.decoder.decoders[layer_idx].src_attn. attn[:, :self.num_heads_applied_guided_attn] ] if idx + 1 == self.num_layers_applied_guided_attn: break att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_out, T_in) enc_dec_attn_loss = self.attn_criterion( att_ws, ilens_ds_st, olens_in ) # TODO(unilight): is changing to ilens_ds_st right? loss = loss + enc_dec_attn_loss report_keys += [{ "enc_dec_attn_loss": enc_dec_attn_loss.item() }] # report extra information if self.use_scaled_pos_enc: report_keys += [ { "encoder_alpha": self.encoder.embed[-1].alpha.data.item() }, { "decoder_alpha": self.decoder.embed[-1].alpha.data.item() }, ] self.reporter.report(report_keys) return loss def inference(self, x, inference_args, spemb=None, *args, **kwargs): """Generate the sequence of features given the sequences of acoustic features. Args: x (Tensor): Input sequence of acoustic features (T, idim). inference_args (Namespace): - threshold (float): Threshold in inference. - minlenratio (float): Minimum length ratio in inference. - maxlenratio (float): Maximum length ratio in inference. spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim). Returns: Tensor: Output sequence of features (L, odim). Tensor: Output sequence of stop probabilities (L,). Tensor: Encoder-decoder (source) attention weights (#layers, #heads, L, T). """ # get options threshold = inference_args.threshold minlenratio = inference_args.minlenratio maxlenratio = inference_args.maxlenratio use_att_constraint = getattr(inference_args, "use_att_constraint", False) # keep compatibility if use_att_constraint: logging.warning( "Attention constraint is not yet supported in Transformer. Not enabled." 
) # thin out input frames for reduction factor # (B, Lmax, idim) -> (B, Lmax // r, idim * r) if self.encoder_reduction_factor > 1: Lmax, idim = x.shape if Lmax % self.encoder_reduction_factor != 0: x = x[:-(Lmax % self.encoder_reduction_factor), :] x_ds = x.contiguous().view( int(Lmax / self.encoder_reduction_factor), idim * self.encoder_reduction_factor, ) else: x_ds = x # forward encoder x_ds = x_ds.unsqueeze(0) hs, _ = self.encoder(x_ds, None) # integrate speaker embedding if self.spk_embed_dim is not None: spembs = spemb.unsqueeze(0) hs = self._integrate_with_spk_embed(hs, spembs) # set limits of length maxlen = int(hs.size(1) * maxlenratio / self.reduction_factor) minlen = int(hs.size(1) * minlenratio / self.reduction_factor) # initialize idx = 0 ys = hs.new_zeros(1, 1, self.odim) outs, probs = [], [] # forward decoder step-by-step z_cache = self.decoder.init_state(x) while True: # update index idx += 1 # calculate output and stop prob at idx-th step y_masks = subsequent_mask(idx).unsqueeze(0).to(x.device) z, z_cache = self.decoder.forward_one_step( ys, y_masks, hs, cache=z_cache) # (B, adim) outs += [self.feat_out(z).view(self.reduction_factor, self.odim)] # [(r, odim), ...] probs += [torch.sigmoid(self.prob_out(z))[0]] # [(r), ...] # update next inputs ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.odim)), dim=1) # (1, idx + 1, odim) # get attention weights att_ws_ = [] for name, m in self.named_modules(): if isinstance(m, MultiHeadedAttention) and "src" in name: att_ws_ += [m.attn[0, :, -1].unsqueeze(1) ] # [(#heads, 1, T),...] if idx == 1: att_ws = att_ws_ else: # [(#heads, l, T), ...] att_ws = [ torch.cat([att_w, att_w_], dim=1) for att_w, att_w_ in zip(att_ws, att_ws_) ] # check whether to finish generation if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen: # check mininum length if idx < minlen: continue outs = (torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2) ) # (L, odim) -> (1, L, odim) -> (1, odim, L) if self.postnet is not None: outs = outs + self.postnet(outs) # (1, odim, L) outs = outs.transpose(2, 1).squeeze(0) # (L, odim) probs = torch.cat(probs, dim=0) break # concatenate attention weights -> (#layers, #heads, L, T) att_ws = torch.stack(att_ws, dim=0) return outs, probs, att_ws def calculate_all_attentions(self, xs, ilens, ys, olens, spembs=None, skip_output=False, keep_tensor=False, *args, **kwargs): """Calculate all of the attention weights. Args: xs (Tensor): Batch of padded acoustic features (B, Tmax, idim). ilens (LongTensor): Batch of lengths of each input batch (B,). ys (Tensor): Batch of padded target features (B, Lmax, odim). olens (LongTensor): Batch of the lengths of each target (B,). spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim). skip_output (bool, optional): Whether to skip calculate the final output. keep_tensor (bool, optional): Whether to keep original tensor. Returns: dict: Dict of attention weights and outputs. 
""" with torch.no_grad(): # thin out input frames for reduction factor # (B, Lmax, idim) -> (B, Lmax // r, idim * r) if self.encoder_reduction_factor > 1: B, Lmax, idim = xs.shape if Lmax % self.encoder_reduction_factor != 0: xs = xs[:, :-(Lmax % self.encoder_reduction_factor), :] xs_ds = xs.contiguous().view( B, int(Lmax / self.encoder_reduction_factor), idim * self.encoder_reduction_factor, ) ilens_ds = ilens.new( [ilen // self.encoder_reduction_factor for ilen in ilens]) else: xs_ds, ilens_ds = xs, ilens # forward encoder x_masks = self._source_mask(ilens_ds) hs, hs_masks = self.encoder(xs_ds, x_masks) # integrate speaker embedding if self.spk_embed_dim is not None: hs = self._integrate_with_spk_embed(hs, spembs) # thin out frames for reduction factor # (B, Lmax, odim) -> (B, Lmax//r, odim) if self.reduction_factor > 1: ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor] olens_in = olens.new( [olen // self.reduction_factor for olen in olens]) else: ys_in, olens_in = ys, olens # add first zero frame and remove last frame for auto-regressive ys_in = self._add_first_frame_and_remove_last_frame(ys_in) # forward decoder y_masks = self._target_mask(olens_in) zs, _ = self.decoder(ys_in, y_masks, hs, hs_masks) # calculate final outputs if not skip_output: before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) if self.postnet is None: after_outs = before_outs else: after_outs = before_outs + self.postnet( before_outs.transpose(1, 2)).transpose(1, 2) # modifiy mod part of output lengths due to reduction factor > 1 if self.reduction_factor > 1: olens = olens.new( [olen - olen % self.reduction_factor for olen in olens]) # store into dict att_ws_dict = dict() if keep_tensor: for name, m in self.named_modules(): if isinstance(m, MultiHeadedAttention): att_ws_dict[name] = m.attn if not skip_output: att_ws_dict["before_postnet_fbank"] = before_outs att_ws_dict["after_postnet_fbank"] = after_outs else: for name, m in self.named_modules(): if isinstance(m, MultiHeadedAttention): attn = m.attn.cpu().numpy() if "encoder" in name: attn = [ a[:, :l, :l] for a, l in zip(attn, ilens.tolist()) ] elif "decoder" in name: if "src" in name: attn = [ a[:, :ol, :il] for a, il, ol in zip( attn, ilens.tolist(), olens_in.tolist()) ] elif "self" in name: attn = [ a[:, :l, :l] for a, l in zip(attn, olens_in.tolist()) ] else: logging.warning("unknown attention module: " + name) else: logging.warning("unknown attention module: " + name) att_ws_dict[name] = attn if not skip_output: before_outs = before_outs.cpu().numpy() after_outs = after_outs.cpu().numpy() att_ws_dict["before_postnet_fbank"] = [ m[:l].T for m, l in zip(before_outs, olens.tolist()) ] att_ws_dict["after_postnet_fbank"] = [ m[:l].T for m, l in zip(after_outs, olens.tolist()) ] return att_ws_dict def _integrate_with_spk_embed(self, hs, spembs): """Integrate speaker embedding with hidden states. Args: hs (Tensor): Batch of hidden state sequences (B, Tmax, adim). spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). 
Returns: Tensor: Batch of integrated hidden state sequences (B, Tmax, adim) """ if self.spk_embed_integration_type == "add": # apply projection and then add to hidden states spembs = self.projection(F.normalize(spembs)) hs = hs + spembs.unsqueeze(1) elif self.spk_embed_integration_type == "concat": # concat hidden states with spk embeds and then apply projection spembs = F.normalize(spembs).unsqueeze(1).expand( -1, hs.size(1), -1) hs = self.projection(torch.cat([hs, spembs], dim=-1)) else: raise NotImplementedError("support only add or concat.") return hs def _source_mask(self, ilens): """Make masks for self-attention. Args: ilens (LongTensor or List): Batch of lengths (B,). Returns: Tensor: Mask tensor for self-attention. dtype=torch.uint8 in PyTorch 1.2- dtype=torch.bool in PyTorch 1.2+ (including 1.2) Examples: >>> ilens = [5, 3] >>> self._source_mask(ilens) tensor([[[1, 1, 1, 1, 1], [[1, 1, 1, 0, 0]]], dtype=torch.uint8) """ x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) return x_masks.unsqueeze(-2) def _target_mask(self, olens): """Make masks for masked self-attention. Args: olens (LongTensor or List): Batch of lengths (B,). Returns: Tensor: Mask tensor for masked self-attention. dtype=torch.uint8 in PyTorch 1.2- dtype=torch.bool in PyTorch 1.2+ (including 1.2) Examples: >>> olens = [5, 3] >>> self._target_mask(olens) tensor([[[1, 0, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 1, 0], [1, 1, 1, 1, 1]], [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 0, 0]]], dtype=torch.uint8) """ y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device) s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0) return y_masks.unsqueeze(-2) & s_masks @property def base_plot_keys(self): """Return base key names to plot during training. keys should match what `chainer.reporter` reports. If you add the key `loss`, the reporter will report `main/loss` and `validation/main/loss` values. also `loss.png` will be created as a figure visulizing `main/loss` and `validation/main/loss` values. Returns: list: List of strings which are base keys to plot during training. """ plot_keys = ["loss", "l1_loss", "l2_loss", "bce_loss"] if self.use_scaled_pos_enc: plot_keys += ["encoder_alpha", "decoder_alpha"] if self.use_guided_attn_loss: if "encoder" in self.modules_applied_guided_attn: plot_keys += ["enc_attn_loss"] if "decoder" in self.modules_applied_guided_attn: plot_keys += ["dec_attn_loss"] if "encoder-decoder" in self.modules_applied_guided_attn: plot_keys += ["enc_dec_attn_loss"] return plot_keys
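# Both `forward` and `inference` of the VC Transformer above thin out encoder input
# frames by the encoder reduction factor, stacking r consecutive frames into the
# feature axis: (B, Lmax, idim) -> (B, Lmax // r, idim * r).  A small sketch with
# illustrative shapes:
import torch

B, Lmax, idim, r = 2, 7, 3, 2
xs = torch.randn(B, Lmax, idim)
if Lmax % r != 0:
    xs = xs[:, :-(Lmax % r), :]                     # drop the leftover frames
xs_ds = xs.contiguous().view(B, xs.size(1) // r, idim * r)
print(xs_ds.shape)                                  # torch.Size([2, 3, 6])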
def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) # fill missing arguments for compatibility args = fill_missing_args(args, self.add_arguments) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, selfattention_layer_type=args. transformer_encoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_encoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.decoder = Decoder( odim=odim, selfattention_layer_type=args. transformer_decoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_decoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.pad = 0 # use <blank> for padding self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = get_subsample(args, mode="st", arch="transformer") self.reporter = Reporter() self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss, ) # submodule for ASR task self.mtlalpha = args.mtlalpha self.asr_weight = args.asr_weight if self.asr_weight > 0 and args.mtlalpha < 1: self.decoder_asr = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, ) # submodule for MT task self.mt_weight = args.mt_weight if self.mt_weight > 0: self.encoder_mt = Encoder( idim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, input_layer="embed", dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, padding_idx=0, ) self.reset_parameters( args) # NOTE: place after the submodule initialization self.adim = args.adim # used for CTC (equal to d_model) if self.asr_weight > 0 and args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None # translation error calculator self.error_calculator = MTErrorCalculator(args.char_list, args.sym_space, args.sym_blank, args.report_bleu) # recognition error calculator self.error_calculator_asr = ASRErrorCalculator( args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer, ) self.rnnlm = None # multilingual E2E-ST related self.multilingual = getattr(args, "multilingual", False) self.replace_sos = getattr(args, "replace_sos", False)
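# The attention decoders built above are trained with teacher forcing: targets are
# prefixed with <sos> for the decoder input and suffixed with <eos> for the loss
# target, while ignore_id marks padding.  A minimal illustrative sketch of what an
# add_sos_eos-style helper produces (a reimplementation for illustration only, not
# espnet's own function):
import torch
from torch.nn.utils.rnn import pad_sequence


def add_sos_eos_sketch(ys_pad, sos, eos, ignore_id):
    ys = [y[y != ignore_id] for y in ys_pad]                     # strip padding
    ys_in = [torch.cat([y.new_tensor([sos]), y]) for y in ys]    # decoder inputs
    ys_out = [torch.cat([y, y.new_tensor([eos])]) for y in ys]   # decoder targets
    return (pad_sequence(ys_in, batch_first=True, padding_value=eos),
            pad_sequence(ys_out, batch_first=True, padding_value=ignore_id))


ys_pad = torch.tensor([[5, 7, -1], [3, 4, 6]])                    # -1 = ignore_id
ys_in, ys_out = add_sos_eos_sketch(ys_pad, sos=9, eos=9, ignore_id=-1)
# ys_in:  [[9, 5, 7, 9], [9, 3, 4, 6]]    ys_out: [[5, 7, 9, -1], [3, 4, 6, 9]]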
def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.cn_encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate) self.en_encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate) self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = [1] self.reporter = Reporter() # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) # self.verbose = args.verbose self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: # Note: here CTC also need to have seperate ctc_lo layer self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None if args.report_cer or args.report_wer: from espnet.nets.e2e_asr_common import ErrorCalculator self.error_calculator = ErrorCalculator(args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer) else: self.error_calculator = None self.rnnlm = None # yzl23 config self.remove_blank_in_ctc_mode = True self.reset_parameters(args) # reset params at the last
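# LabelSmoothingLoss above spreads a small amount of probability mass over the
# non-target tokens before comparing against the decoder's log-probabilities.  A
# minimal numeric sketch of the smoothed target distribution for one token
# (padding/ignore_id handling is omitted here):
import torch

vocab_size, smoothing, target = 5, 0.1, 2
true_dist = torch.full((vocab_size,), smoothing / (vocab_size - 1))
true_dist[target] = 1.0 - smoothing
print(true_dist)  # tensor([0.0250, 0.0250, 0.9000, 0.0250, 0.0250])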
    numpy.testing.assert_allclose(y.numpy(), y_fast.numpy(), rtol=RTOL)


if __name__ == "__main__":
    # benchmark with a synthetic dataset
    from time import time

    import matplotlib.pyplot as plt

    adim = 4
    odim = 5
    model = "decoder"
    if model == "decoder":
        decoder = Decoder(
            odim=odim,
            attention_dim=adim,
            linear_units=3,
            num_blocks=2,
            dropout_rate=0.0,
        )
        decoder.eval()
    else:
        encoder = Encoder(
            idim=odim,
            attention_dim=adim,
            linear_units=3,
            num_blocks=2,
            dropout_rate=0.0,
            input_layer="embed",
        )
        encoder.eval()
    xlen = 100
    xs = torch.randint(0, odim, (1, xlen))
    memory = torch.randn(2, 500, adim)
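    # --- hedged continuation sketch (not the original benchmark body) -------------
    # The lines below illustrate one way the timing sweep could proceed, using the
    # full Decoder.forward(tgt, tgt_mask, memory, memory_mask) call; only the
    # `model == "decoder"` branch set up above is covered, and the loop bounds and
    # plot are illustrative assumptions.
    from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask

    lengths, elapsed = [], []
    with torch.no_grad():
        for n in range(10, xlen + 1, 10):
            ys = xs[:, :n]  # growing prefix of the synthetic token sequence
            ys_mask = subsequent_mask(n).unsqueeze(0)  # (1, n, n) causal mask
            start = time()
            decoder(ys, ys_mask, memory[:1], None)  # single-utterance memory matches batch
            lengths.append(n)
            elapsed.append(time() - start)

    plt.plot(lengths, elapsed)
    plt.xlabel("prefix length")
    plt.ylabel("seconds per forward pass")
    plt.show()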
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    self.encoder = Encoder(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
    )
    if args.mtlalpha < 1:
        self.decoder = Decoder(
            odim=odim,
            selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_decoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        self.criterion = LabelSmoothingLoss(
            odim,
            ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )
    else:
        self.decoder = None
        self.criterion = None
    self.blank = 0
    self.decoder_mode = args.decoder_mode
    if self.decoder_mode == "maskctc":
        self.mask_token = odim - 1
        self.sos = odim - 2
        self.eos = odim - 2
    else:
        self.sos = odim - 1
        self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="asr", arch="transformer")
    self.reporter = Reporter()

    self.reset_parameters(args)
    self.adim = args.adim  # used for CTC (equal to d_model)
    self.mtlalpha = args.mtlalpha
    if args.mtlalpha > 0.0:
        self.ctc = CTC(
            odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
        )
    else:
        self.ctc = None

    if args.report_cer or args.report_wer:
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None
    self.rnnlm = None
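# --- illustration of the output-token index layout implied by the branches above,
# using a hypothetical toy vocabulary of size odim = 8 (values are assumptions).
odim = 8
sos = eos = odim - 1        # 7: standard CTC/attention mode, <sos>/<eos> share the last index
mask_token = odim - 1       # 7: maskctc mode reuses the last index as <mask>
sos_maskctc = eos_maskctc = odim - 2  # 6: <sos>/<eos> move down by one in maskctc mode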
def __init__(self, num_time_mask=2, num_freq_mask=2, freq_mask_length=15,
             time_mask_length=15, feature_dim=320, model_size=512,
             feed_forward_size=1024, hidden_size=64, dropout=0.1, num_head=8,
             num_encoder_layer=6, num_decoder_layer=6,
             vocab_path='testing_vocab.model', max_feature_length=1024,
             max_token_length=50, enable_spec_augment=True, share_weight=True,
             smoothing=0.1, restrict_left_length=20, restrict_right_length=20,
             mtlalpha=0.2, report_wer=True):
    super(Transformer, self).__init__()

    self.enable_spec_augment = enable_spec_augment
    self.max_token_length = max_token_length
    self.restrict_left_length = restrict_left_length
    self.restrict_right_length = restrict_right_length
    self.vocab = Vocab(vocab_path)
    self.sos = self.vocab.bos_id
    self.eos = self.vocab.eos_id
    self.adim = model_size
    self.odim = self.vocab.vocab_size
    self.ignore_id = self.vocab.pad_id

    if enable_spec_augment:
        self.spec_augment = SpecAugment(
            num_time_mask=num_time_mask,
            num_freq_mask=num_freq_mask,
            freq_mask_length=freq_mask_length,
            time_mask_length=time_mask_length,
            max_sequence_length=max_feature_length)

    self.encoder = Encoder(
        idim=feature_dim,
        attention_dim=model_size,
        attention_heads=num_head,
        linear_units=feed_forward_size,
        num_blocks=num_encoder_layer,
        dropout_rate=dropout,
        positional_dropout_rate=dropout,
        attention_dropout_rate=dropout,
        input_layer='linear',
        padding_idx=self.vocab.pad_id)

    self.decoder = Decoder(
        odim=self.vocab.vocab_size,
        attention_dim=model_size,
        attention_heads=num_head,
        linear_units=feed_forward_size,
        num_blocks=num_decoder_layer,
        dropout_rate=dropout,
        positional_dropout_rate=dropout,
        self_attention_dropout_rate=dropout,
        src_attention_dropout_rate=0,
        input_layer='embed',
        use_output_layer=False)
    self.decoder_linear = t.nn.Linear(model_size, self.vocab.vocab_size, bias=True)
    self.decoder_switch_linear = t.nn.Linear(model_size, 4, bias=True)

    self.criterion = LabelSmoothingLoss(
        size=self.odim,
        smoothing=smoothing,
        padding_idx=self.vocab.pad_id,
        normalize_length=True)
    self.switch_criterion = LabelSmoothingLoss(
        size=4,
        smoothing=0,
        padding_idx=self.vocab.pad_id,
        normalize_length=True)
    self.mtlalpha = mtlalpha
    if mtlalpha > 0.0:
        self.ctc = CTC(
            self.odim,
            eprojs=self.adim,
            dropout_rate=dropout,
            ctc_type='builtin',
            reduce=False)
    else:
        self.ctc = None

    if report_wer:
        from espnet.nets.e2e_asr_common import ErrorCalculator

        def load_token_list(path=vocab_path.replace('.model', '.vocab')):
            # read the sentencepiece-style vocab file: one "token\t..." entry per line
            with open(path) as reader:
                data = reader.readlines()
            return [i.split('\t')[0] for i in data]

        self.char_list = load_token_list()
        self.error_calculator = ErrorCalculator(
            char_list=self.char_list,
            sym_space=' ',
            sym_blank=self.vocab.blank_token,
            report_wer=True)
    else:
        self.error_calculator = None
    self.rnnlm = None
    self.reporter = Reporter()
    self.switch_loss = LabelSmoothingLoss(size=4, smoothing=0, padding_idx=0)

    print('initializing parameters')
    initialize(self, init_type='xavier_normal')
    print('initialization done')