Beispiel #1
0
    def __init__(self, args):
        """Set up the tokenizer, RNN-T loss and Transducer model for training.

        Args:
            args: parsed options; this constructor reads ``tokenizer``,
                ``audio_feat``, ``concat``, ``vocab_dim``, ``h_dim``,
                ``layers``, ``pred_layers`` and ``dropout``.
        """
        super(ParallelTraining, self).__init__()
        # Any tokenizer value other than 'char' selects the BPE tokenizer.
        if args.tokenizer == 'char':
            _tokenizer = CharTokenizer()
        else:
            print('use BPE 1000')
            _tokenizer = HuggingFaceTokenizer()  # use BPE-1000
        audio_feature = args.audio_feat
        # With frame concatenation the encoder input size is tripled
        # (presumably three stacked frames — TODO confirm).
        if args.concat:
            audio_feature *= 3

        self.tokenizer = _tokenizer
        # NOTE(review): blank index 0 assumes both tokenizers reserve id 0 for
        # the blank symbol — confirm.
        self.loss_fn = RNNTLoss(blank=0)
        self.model = Transducer(
            audio_feature,
            _tokenizer.vocab_size,
            args.vocab_dim,  # vocab embedding dim
            args.h_dim,  # hidden dim
            args.layers,
            pred_num_layers=args.pred_layers,
            dropout=args.dropout)
        # Training-progress state.
        self.latest_alignment = None
        self.steps = 0
        self.epoch = 0
        self.args = args
        # Initial "best" WER set high so the first evaluation is expected to
        # improve on it.
        self.best_wer = 1000
Beispiel #2
0
    def __init__(self, input_size, vocab_size, hidden_size, decoder_num_layers, encoder_num_layers, dropout=0.5, blank=0, bidirectional=False, LM_model_path=False):
        """Build encoder, prediction decoder, joint layers and RNN-T loss.

        Args:
            input_size: acoustic feature size fed to the encoder.
            vocab_size: number of output symbols (including blank).
            hidden_size: hidden size of encoder, decoder and joint layers.
            decoder_num_layers: number of decoder (prediction network) layers.
            encoder_num_layers: number of encoder layers.
            dropout: dropout passed to encoder and decoder.
            blank: index of the blank symbol; stored and passed to the loss.
            bidirectional: whether the encoder is bidirectional.
            LM_model_path: optional path to a state dict used to warm-start
                the decoder (loaded with ``strict=False``); falsy to skip.
        """
        super(Transducer, self).__init__()
        self.blank = blank
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.decoder_num_layers = decoder_num_layers
        self.encoder_num_layers = encoder_num_layers

        # Pass the configured blank so the loss agrees with ``self.blank``
        # (it previously always used RNNTLoss's default blank).
        self.loss = RNNTLoss(blank=blank)

        self.decoder = DecoderModel(embed_size=vocab_size,
                                    vocab_size=vocab_size,
                                    num_layers=decoder_num_layers,
                                    hidden_size=hidden_size,
                                    dropout=dropout)
        if LM_model_path:
            # load_state_dict copies values into the existing parameters, so
            # mapping to CPU is safe and lets GPU-saved checkpoints load on
            # CPU-only hosts.
            self.decoder.load_state_dict(
                torch.load(LM_model_path, map_location='cpu'), strict=False)

        self.encoder = EncoderModel(input_size=input_size,
                                    vocab_size=hidden_size,
                                    hidden_size=hidden_size,
                                    num_layers=encoder_num_layers,
                                    dropout=dropout,
                                    bidirectional=bidirectional)

        # Joint layers: fc1 takes 2 * hidden_size inputs (presumably encoder
        # and decoder states concatenated — TODO confirm in forward).
        self.fc1 = nn.Linear(2 * hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.2)
 def __init__(self,
              n_cnn_layers,
              n_rnn_layers,
              rnn_dim,
              n_class,
              n_feats,
              hidden_size,
              vocab_size,
              output_size,
              input_size,
              inner_dim,
              n_layers,
              stride=2,
              dropout=0.2,
              share_weight=False):
     """Build the speech encoder, prediction decoder, joint network and loss.

     Args:
         n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats: forwarded to
             the ``SpeechRecognitionModel`` encoder.
         hidden_size, vocab_size, output_size, n_layers: forwarded to the
             ``BaseDecoder``.
         input_size, inner_dim: ``JointNet`` sizes; ``vocab_size`` is its
             output size.
         stride: encoder stride (was previously ignored and fixed at 2).
         dropout: decoder dropout (was previously ignored and fixed at 0.2).
         share_weight: forwarded to ``BaseDecoder`` (was previously ignored
             and fixed at False).
     """
     super(Transducer, self).__init__()
     # define encoder
     # NOTE(review): the encoder keeps its fixed dropout of 0.1; ``dropout``
     # only applies to the decoder, which preserves the previous defaults.
     self.encoder = SpeechRecognitionModel(n_cnn_layers,
                                           n_rnn_layers,
                                           rnn_dim,
                                           n_class,
                                           n_feats,
                                           stride=stride,
                                           dropout=0.1)
     # define decoder
     self.decoder = BaseDecoder(hidden_size,
                                vocab_size,
                                output_size,
                                n_layers,
                                dropout=dropout,
                                share_weight=share_weight)
     # define JointNet
     self.joint = JointNet(input_size, inner_dim, vocab_size)
     # Use RNN-T Loss function to evaluate the goodness of the model, most important part in this project
     self.criterion = RNNTLoss()
Beispiel #4
0
 def __init__(self,
              input_size,
              vocab_size,
              enc_hidden_size,
              enc_num_layers,
              dec_units,
              dec_layers,
              dropout=.5,
              blank=0,
              bidirectional=False,
              embed_dim=512):
     """Build an LSTM transducer: RNN encoder, LSTM prediction net, joint layers.

     Args:
         input_size: acoustic feature size fed to the encoder.
         vocab_size: number of output symbols (including blank).
         enc_hidden_size: encoder hidden size (also the joint hidden size).
         enc_num_layers: number of encoder layers.
         dec_units: hidden size of the prediction LSTM.
         dec_layers: number of prediction LSTM layers.
         dropout: dropout for encoder and prediction LSTM.
         blank: blank index; embedding padding index and RNN-T loss blank.
         bidirectional: whether the encoder is bidirectional.
         embed_dim: label embedding size (previously hard-coded to 512).
     """
     super(Transducer, self).__init__()
     self.blank = blank
     self.vocab_size = vocab_size
     self.enc_hidden_size = enc_hidden_size
     self.enc_num_layers = enc_num_layers
     # Pass the configured blank so the loss agrees with the embedding's
     # padding index (it previously always used RNNTLoss's default).
     self.loss = RNNTLoss(blank=blank)
     # NOTE encoder & decoder only use lstm
     self.encoder = RNNModel(input_size,
                             enc_hidden_size,
                             enc_hidden_size,
                             enc_num_layers,
                             dropout,
                             bidirectional=bidirectional)
     self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=blank)
     self.decoder = nn.LSTM(embed_dim,
                            dec_units,
                            dec_layers,
                            batch_first=True,
                            dropout=dropout)
     # fc1 input width is dec_units + enc_hidden_size (presumably decoder and
     # encoder states concatenated — TODO confirm in forward).
     self.fc1 = nn.Linear(dec_units + enc_hidden_size, enc_hidden_size)
     self.fc2 = nn.Linear(enc_hidden_size, vocab_size)
     self._nparams_dict = {}
     self._nparams = 0
 def __init__(self,
              vocab_embed_size,
              vocab_size,
              input_size,
              enc_hidden_size,
              enc_layers,
              enc_dropout,
              enc_proj_size,
              dec_hidden_size,
              dec_layers,
              dec_dropout,
              dec_proj_size,
              joint_size,
              blank=NUL):
     """Build encoder, decoder (prediction network), joint network and loss.

     Encoder and decoder outputs are projected to ``enc_proj_size`` and
     ``dec_proj_size``; the joint network's input size is their sum, so it
     presumably receives both projections concatenated (TODO confirm in
     Joint). ``blank`` (default ``NUL``) is stored and passed to ``RNNTLoss``.
     """
     super().__init__()
     self.blank = blank
     # Encoder
     self.encoder = Encoder(input_size=input_size,
                            hidden_size=enc_hidden_size,
                            num_layers=enc_layers,
                            dropout=enc_dropout,
                            proj_size=enc_proj_size)
     # Decoder
     self.decoder = Decoder(vocab_embed_size=vocab_embed_size,
                            vocab_size=vocab_size,
                            hidden_size=dec_hidden_size,
                            num_layers=dec_layers,
                            dropout=dec_dropout,
                            proj_size=dec_proj_size)
     # Joint
     self.joint = Joint(input_size=enc_proj_size + dec_proj_size,
                        hidden_size=joint_size,
                        vocab_size=vocab_size)
     # Loss uses the same blank index as the model.
     self.loss_fn = RNNTLoss(blank=blank)
Beispiel #6
0
    def __init__(self, config):
        """Build the transducer from ``config``, optionally warm-starting parts.

        Args:
            config: model config. Passed whole to ``build_encoder`` and
                ``build_decoder``; also reads ``lm_pre_train``/``lm_model_path``,
                ``ctc_pre_train``/``ctc_model_path``, ``joint.input_size``,
                ``joint.inner_size``, ``vocab_size`` and ``share_embedding``.
        """
        super(Transducer, self).__init__()
        # define encoder
        self.config = config
        self.encoder = build_encoder(config)
        # define decoder
        self.decoder = build_decoder(config)

        # Optional warm start. Paths are joined onto the module-level
        # ``home_dir``; a missing file is skipped without error. load_state_dict
        # copies values into the existing parameters, so mapping tensors to CPU
        # is safe and lets GPU-saved checkpoints load on CPU-only hosts.
        if config.lm_pre_train:
            lm_path = os.path.join(home_dir, config.lm_model_path)
            if os.path.exists(lm_path):
                print('load language model')
                self.decoder.load_state_dict(
                    torch.load(lm_path, map_location='cpu'), strict=False)

        if config.ctc_pre_train:
            ctc_path = os.path.join(home_dir, config.ctc_model_path)
            if os.path.exists(ctc_path):
                print('load ctc pretrain model')
                self.encoder.load_state_dict(
                    torch.load(ctc_path, map_location='cpu'), strict=False)

        # define JointNet
        self.joint = JointNet(input_size=config.joint.input_size,
                              inner_dim=config.joint.inner_size,
                              vocab_size=config.vocab_size)

        if config.share_embedding:
            # Tie the joint output projection to the decoder embedding; the two
            # weight matrices must have exactly the same shape. Report both
            # full shapes (the old message showed only dim 1, so a dim-0
            # mismatch printed e.g. "320 != 320").
            emb_shape = tuple(self.decoder.embedding.weight.size())
            proj_shape = tuple(self.joint.project_layer.weight.size())
            assert emb_shape == proj_shape, '%s != %s' % (emb_shape, proj_shape)
            self.joint.project_layer.weight = self.decoder.embedding.weight

        self.crit = RNNTLoss()
Beispiel #7
0
 def __init__(self,
              input_size,
              vocab_size,
              hidden_size,
              num_layers,
              dropout=.5,
              blank=0,
              bidirectional=False):
     """Build an LSTM transducer whose decoder input is a frozen one-hot embedding.

     Args:
         input_size: acoustic feature size fed to the encoder.
         vocab_size: number of output symbols (including blank).
         hidden_size: hidden size of encoder, decoder and joint layers.
         num_layers: number of encoder layers (the decoder has one layer).
         dropout: dropout for encoder and decoder.
         blank: blank index; embedding padding index and RNN-T loss blank.
         bidirectional: whether the encoder is bidirectional.
     """
     super(Transducer, self).__init__()
     self.blank = blank
     self.vocab_size = vocab_size
     self.hidden_size = hidden_size
     self.num_layers = num_layers
     # Pass the configured blank so the loss agrees with ``self.blank`` (it
     # previously always used RNNTLoss's default blank).
     self.loss = RNNTLoss(blank=blank, size_average=True)
     # NOTE encoder & decoder only use lstm
     self.encoder = RNNModel(input_size,
                             hidden_size,
                             hidden_size,
                             num_layers,
                             dropout,
                             bidirectional=bidirectional)
     # Frozen one-hot label embedding: rows 1..vocab_size-1 become the
     # identity, so label i > 0 maps to unit vector i-1 and row 0 stays zero.
     # NOTE(review): this layout assumes blank == 0; with another blank the
     # padding row is overwritten by the identity.
     self.embed = nn.Embedding(vocab_size,
                               vocab_size - 1,
                               padding_idx=blank)
     self.embed.weight.data[1:] = torch.eye(vocab_size - 1)
     self.embed.weight.requires_grad = False
     self.decoder = nn.LSTM(vocab_size - 1,
                            hidden_size,
                            1,
                            batch_first=True,
                            dropout=dropout)
     # fc1 takes 2 * hidden_size inputs (presumably encoder and decoder states
     # concatenated — TODO confirm in forward).
     self.fc1 = nn.Linear(2 * hidden_size, hidden_size)
     self.fc2 = nn.Linear(hidden_size, vocab_size)
Beispiel #8
0
    def __init__(self, trans_type, blank_id):
        """Construct a TransLoss object.

        Args:
            trans_type: loss backend, "warp-transducer" or "warp-rnnt".
            blank_id: index of the blank symbol.

        Raises:
            ValueError: if "warp-rnnt" is requested and CUDA is unavailable.
            ImportError: if "warp-rnnt" is requested but not installed
                (the original import error is chained as the cause).
            NotImplementedError: for any other ``trans_type``.
        """
        super(TransLoss, self).__init__()

        if trans_type == "warp-transducer":
            from warprnnt_pytorch import RNNTLoss

            self.trans_loss = RNNTLoss(blank=blank_id)
        elif trans_type == "warp-rnnt":
            # This code path only allows warp-rnnt when CUDA is available.
            if not torch.cuda.is_available():
                raise ValueError("warp-rnnt is not supported in CPU mode")
            try:
                from warp_rnnt import rnnt_loss
            except ImportError as err:
                raise ImportError(
                    "warp-rnnt is not installed. Please re-setup"
                    " espnet or use 'warp-transducer'") from err

            self.trans_loss = rnnt_loss
        else:
            raise NotImplementedError(
                "Unsupported trans_type: %r" % (trans_type,))

        self.trans_type = trans_type
        self.blank_id = blank_id
Beispiel #9
0
    def __init__(self, config):
        """Build the transducer plus optional CTC / cross-entropy auxiliary heads.

        Args:
            config: model config. Passed whole to ``build_encoder`` and
                ``build_decoder``; also reads ``joint.*``, ``vocab_size``,
                ``share_embedding``, ``enc.ctc_weight``/``enc.output_size``
                and ``dec.ce_weight``/``dec.output_size``.
        """
        super(Transducer, self).__init__()
        self.config = config
        # define encoder
        self.encoder = build_encoder(config)
        # define decoder
        self.decoder = build_decoder(config)
        # define JointNet ("concat" is used when no joint type is configured)
        self.joint = JointNet(
            input_size=config.joint.input_size,
            inner_dim=config.joint.inner_size,
            vocab_size=config.vocab_size,
            joint=config.joint.type if config.joint.type else "concat")

        if config.share_embedding:
            # Tie the joint output projection to the decoder embedding; the two
            # weight matrices must have exactly the same shape. Report both
            # full shapes (the old message showed only dim 1).
            emb_shape = tuple(self.decoder.embedding.weight.size())
            proj_shape = tuple(self.joint.project_layer.weight.size())
            assert emb_shape == proj_shape, '%s != %s' % (emb_shape, proj_shape)
            self.joint.project_layer.weight = self.decoder.embedding.weight

        self.transducer_loss = RNNTLoss()

        # multask learning (loss_decoder and loss_encoder): the auxiliary
        # heads are only created when their weight is set and positive.
        if config.enc.ctc_weight and config.enc.ctc_weight > 0.0:
            self.ctc_loss = nn.CTCLoss()
            self.encoder_project_layer = nn.Sequential(
                nn.Tanh(),
                nn.Linear(self.config.enc.output_size, self.config.vocab_size))
        if config.dec.ce_weight and config.dec.ce_weight > 0.0:
            self.nll_loss = nll_loss
            self.decoder_project_layer = nn.Sequential(
                nn.Tanh(),
                nn.Linear(self.config.dec.output_size, self.config.vocab_size))
    def set_model(self):
        """Set up the ASR transducer model, RNN-T loss, optimizer and scheduler.

        Builds ``Transducer`` from ``self.config['model']``, warm-starts its
        ``predictor`` and ``encoder`` from pre-trained checkpoints, creates a
        warp-transducer ``RNNTLoss``, an Adam optimizer (lr=1e-4) and a
        ``ReduceLROnPlateau`` scheduler (mode 'min', patience 5), then calls
        ``self.enable_apex()``.

        Checkpoint paths default to the previously hard-coded locations and can
        be overridden through an optional ``pretrain`` section of the config::

            pretrain:
              predictor: /path/to/lm/best_ppx.pth
              encoder: /path/to/asr/best_ctc.pth
        """
        # Model
        init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = Transducer(self.feat_dim, self.vocab_size, init_adadelta,
                                **self.config['model']).to(self.device)

        # Pre-trained checkpoint locations (overridable, see docstring).
        pretrain_cfg = self.config.get('pretrain') or {}
        predictor_ckpt = pretrain_cfg.get(
            'predictor',
            '/home/eden.chien/DLHLP/DEV/ckpt/lm_dlhlp_sd0/best_ppx.pth')
        encoder_ckpt = pretrain_cfg.get(
            'encoder',
            '/home/eden.chien/DLHLP/RNNT/ckpt/asr_dlhlp_ctc05_layer5_sd0/best_ctc.pth')
        # Outside training mode, checkpoint tensors are mapped to the CPU.
        map_location = self.device if self.mode == 'train' else 'cpu'

        # load pre-trained model
        print('loading pre-trained predictor')
        self.model.predictor.load_state_dict(
            torch.load(predictor_ckpt, map_location=map_location)['model'])
        print('loading pre-trained encoder')
        self.model.encoder.load_state_dict(
            torch.load(encoder_ckpt, map_location=map_location)['encoder'])

        # Loss
        from warprnnt_pytorch import RNNTLoss
        self.rnntloss = RNNTLoss()

        # Optimizer: fixed Adam settings; only ``init_adadelta`` (passed to
        # Transducer above) is derived from self.config['hparas'].
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0001)
        self.lr_scheduler = ReduceLROnPlateau(self.optimizer,
                                              'min',
                                              patience=5)
        # Enable AMP if needed
        self.enable_apex()
    def __init__(self,
                 eprojs,
                 odim,
                 dtype,
                 dlayers,
                 dunits,
                 blank,
                 att,
                 embed_dim,
                 joint_dim,
                 dropout=0.0,
                 dropout_embed=0.0,
                 rnnt_type='warp-transducer'):
        """Transducer with attention initializer.

        Args:
            eprojs: encoder output size; added to ``embed_dim`` for the first
                decoder cell's input and projected by ``lin_enc``.
            odim: output vocabulary size (embedding rows, ``lin_out`` outputs).
            dtype: "lstm" selects LSTMCell; any other value selects GRUCell.
            dlayers: number of stacked decoder cells.
            dunits: hidden size of each decoder cell.
            blank: blank index; embedding padding index and RNN-T loss blank.
            att: attention module, stored as ``self.att``.
            embed_dim: label embedding size.
            joint_dim: hidden size of the joint network.
            dropout: rate of the Dropout module created per decoder layer.
            dropout_embed: rate of the embedding Dropout module.
            rnnt_type: loss backend; only 'warp-transducer' is supported.

        Raises:
            NotImplementedError: for any other ``rnnt_type``.
        """
        super(DecoderRNNTAtt, self).__init__()

        self.embed = torch.nn.Embedding(odim, embed_dim, padding_idx=blank)
        self.dropout_emb = torch.nn.Dropout(p=dropout_embed)

        if dtype == "lstm":
            dec_net = torch.nn.LSTMCell
        else:
            dec_net = torch.nn.GRUCell

        # The first cell consumes embedding + encoder context (embed_dim +
        # eprojs); deeper cells consume the previous cell's dunits output.
        self.decoder = torch.nn.ModuleList(
            [dec_net((embed_dim + eprojs), dunits)])
        self.dropout_dec = torch.nn.ModuleList([torch.nn.Dropout(p=dropout)])

        for _ in six.moves.range(1, dlayers):
            self.decoder += [dec_net(dunits, dunits)]
            self.dropout_dec += [torch.nn.Dropout(p=dropout)]

        if rnnt_type == 'warp-transducer':
            from warprnnt_pytorch import RNNTLoss

            self.rnnt_loss = RNNTLoss(blank=blank)
        else:
            raise NotImplementedError

        # Joint network: encoder and decoder are projected to joint_dim
        # (decoder projection without bias), then mapped to odim outputs.
        self.lin_enc = torch.nn.Linear(eprojs, joint_dim)
        self.lin_dec = torch.nn.Linear(dunits, joint_dim, bias=False)
        self.lin_out = torch.nn.Linear(joint_dim, odim)

        self.att = att

        self.dtype = dtype
        self.dlayers = dlayers
        self.dunits = dunits
        self.embed_dim = embed_dim
        self.joint_dim = joint_dim
        self.odim = odim

        self.rnnt_type = rnnt_type

        self.ignore_id = -1
        self.blank = blank
Beispiel #12
0
    def __init__(self, trans_type, blank_id):
        """Construct a TransLoss object.

        Args:
            trans_type: loss backend; only "warp-transducer" is supported.
            blank_id: index of the blank symbol, passed to ``RNNTLoss``.

        Raises:
            NotImplementedError: for any other ``trans_type`` (the message
                now names the rejected value).
        """
        super(TransLoss, self).__init__()

        if trans_type == "warp-transducer":
            self.trans_loss = RNNTLoss(blank=blank_id)
        else:
            raise NotImplementedError(
                "Unsupported trans_type: %r" % (trans_type,))

        self.blank_id = blank_id
Beispiel #13
0
    def __init__(self, hidden_size, vocab_size):
        """Joint network layers mapping encoder/prediction features to logits.

        Args:
            hidden_size: width of both input branches and their projections.
            vocab_size: number of output classes of the final layer.
        """
        super(JointNetwork, self).__init__()
        self.vocab_size = vocab_size

        width = hidden_size
        # Creation order is kept fixed: it determines parameter registration
        # order and RNG consumption during initialization.
        self.linear_enc = nn.Linear(width, width)
        self.linear_pred = nn.Linear(width, width)
        # Output layer input is twice the branch width (presumably the two
        # projected branches concatenated — TODO confirm in forward).
        self.linear_feed_forward = nn.Linear(2 * width, vocab_size)
        self.tanH = nn.Tanh()
        self.loss = RNNTLoss()
Beispiel #14
0
 def __init__(self,
              vocab_embed_size,
              vocab_size,
              input_size,
              enc_hidden_size,
              enc_layers,
              enc_dropout,
              enc_proj_size,
              dec_hidden_size,
              dec_layers,
              dec_dropout,
              dec_proj_size,
              joint_size,
              enc_time_reductions=None,
              blank=NUL,
              module_type='LSTM',
              output_loss=True):
     """Build encoder, decoder, joint network and (optionally) the RNN-T loss.

     Args:
         enc_time_reductions: per-layer time reductions forwarded to
             ``Encoder``; ``None`` means ``[1]``. A fresh list is created per
             call (the old ``[1]`` default was one list shared by every
             instance, so any in-place change leaked between models).
         blank: blank index, stored and passed to ``RNNTLoss``.
         module_type: 'LSTM' (ResLayerNormLSTM) or 'GRU' (ResLayerNormGRU)
             encoder layers.
         output_loss: when true, create ``self.loss_fn``.
         Remaining arguments size the encoder, decoder and joint network;
         the joint input size is ``enc_proj_size + dec_proj_size``.

     Raises:
         ValueError: if ``module_type`` is not 'GRU' or 'LSTM'.
     """
     super().__init__()
     if enc_time_reductions is None:
         enc_time_reductions = [1]
     self.blank = blank
     # Encoder
     if module_type not in ['GRU', 'LSTM']:
         raise ValueError('Unsupported module type')
     if module_type == 'GRU':
         module = ResLayerNormGRU
     else:
         module = ResLayerNormLSTM
     self.encoder = Encoder(input_size=input_size,
                            hidden_size=enc_hidden_size,
                            num_layers=enc_layers,
                            dropout=enc_dropout,
                            proj_size=enc_proj_size,
                            time_reductions=enc_time_reductions,
                            module=module)
     # Decoder
     self.decoder = Decoder(vocab_embed_size=vocab_embed_size,
                            vocab_size=vocab_size,
                            hidden_size=dec_hidden_size,
                            num_layers=dec_layers,
                            dropout=dec_dropout,
                            proj_size=dec_proj_size)
     # Joint
     self.joint = Joint(input_size=enc_proj_size + dec_proj_size,
                        hidden_size=joint_size,
                        vocab_size=vocab_size)
     self.output_loss = output_loss
     if output_loss:
         self.loss_fn = RNNTLoss(blank=blank)
Beispiel #15
0
    def __init__(self):
        """Set up the RNN-T loss, tokenizer and Transducer model from ``FLAGS``.

        Only the input feature size is used from ``build_transform``; the
        first two returned values are discarded here.
        """
        super(ParallelTraining, self).__init__()
        _, _, input_size = build_transform(feature_type=FLAGS.feature,
                                           feature_size=FLAGS.feature_size,
                                           n_fft=FLAGS.n_fft,
                                           win_length=FLAGS.win_length,
                                           hop_length=FLAGS.hop_length,
                                           delta=FLAGS.delta,
                                           cmvn=FLAGS.cmvn,
                                           downsample=FLAGS.downsample,
                                           T_mask=FLAGS.T_mask,
                                           T_num_mask=FLAGS.T_num_mask,
                                           F_mask=FLAGS.F_mask,
                                           F_num_mask=FLAGS.F_num_mask)
        self.log_path = None
        # The loss lives here; the model below is built with output_loss=False.
        self.loss_fn = RNNTLoss(blank=NUL)

        # NOTE(review): ``self.logdir`` is not assigned in this constructor
        # (only ``self.log_path`` is); confirm it is defined elsewhere on the
        # class, otherwise the 'char' branch raises AttributeError.
        if FLAGS.tokenizer == 'char':
            self.tokenizer = CharTokenizer(cache_dir=self.logdir)
        else:
            # NOTE(review): cache_dir is fixed to 'BPE-2048' regardless of
            # FLAGS.bpe_size — confirm that is intended.
            self.tokenizer = HuggingFaceTokenizer(cache_dir='BPE-2048',
                                                  vocab_size=FLAGS.bpe_size)
        self.vocab_size = self.tokenizer.vocab_size
        print(FLAGS.enc_type)

        self.model = Transducer(
            vocab_embed_size=FLAGS.vocab_embed_size,
            vocab_size=self.vocab_size,
            input_size=input_size,
            enc_hidden_size=FLAGS.enc_hidden_size,
            enc_layers=FLAGS.enc_layers,
            enc_dropout=FLAGS.enc_dropout,
            enc_proj_size=FLAGS.enc_proj_size,
            dec_hidden_size=FLAGS.dec_hidden_size,
            dec_layers=FLAGS.dec_layers,
            dec_dropout=FLAGS.dec_dropout,
            dec_proj_size=FLAGS.dec_proj_size,
            joint_size=FLAGS.joint_size,
            module_type=FLAGS.enc_type,
            output_loss=False,
        )
        # Training-progress state; best_wer starts high so the first
        # evaluation is expected to improve on it.
        self.latest_alignment = None
        self.steps = 0
        self.epoch = 0
        self.best_wer = 1000
Beispiel #16
0
    def __init__(self, config):
        """Build encoder, decoder, joint network and RNN-T loss from ``config``.

        Args:
            config: model config; passed whole to ``build_encoder`` and
                ``build_decoder`` and read for ``joint.input_size``,
                ``joint.inner_size``, ``vocab_size`` and ``share_embedding``.
        """
        super(Transducer, self).__init__()
        # define encoder
        self.config = config
        self.encoder = build_encoder(config)
        # define decoder
        self.decoder = build_decoder(config)
        # define JointNet
        self.joint = JointNet(
            input_size=config.joint.input_size,
            inner_dim=config.joint.inner_size,
            vocab_size=config.vocab_size
            )

        if config.share_embedding:
            # Tie the joint output projection to the decoder embedding; the two
            # weight matrices must have exactly the same shape. Report both
            # full shapes (the old message showed only dim 1).
            emb_shape = tuple(self.decoder.embedding.weight.size())
            proj_shape = tuple(self.joint.project_layer.weight.size())
            assert emb_shape == proj_shape, '%s != %s' % (emb_shape, proj_shape)
            self.joint.project_layer.weight = self.decoder.embedding.weight

        # NOTE(review): blank index is hard-coded to 28 — confirm it matches
        # the vocabulary's blank id.
        self.crit = RNNTLoss(blank=28)
Beispiel #17
0
    def __init__(self, config):
        """Build a transducer with an optional first-stage encoder ``fir_enc``.

        Args:
            config: model config. Reads ``alpha``, ``enc``, ``fir_enc``,
                ``dec``, ``max_target_length``, ``joint.*``, ``vocab_size``,
                ``share_embedding`` and ``fir_enc_or_not``.
        """
        super(Transducer, self).__init__()
        self.config = config
        self.alpha = config.alpha
        # define encoder
        self.encoder = build_encoder(config.enc)
        self.fir_enc = buildFir_enc(config.fir_enc)
        # define decoder
        self.decoder = build_decoder(config.dec)
        self.max_target_length = config.max_target_length
        # define JointNet
        self.joint = JointNet(input_size=config.joint.input_size,
                              inner_dim=config.joint.inner_size,
                              vocab_size=config.vocab_size)

        if config.share_embedding:
            # Tie the joint output projection to the decoder embedding; the two
            # weight matrices must have exactly the same shape. Report both
            # full shapes (the old message showed only dim 1).
            emb_shape = tuple(self.decoder.embedding.weight.size())
            proj_shape = tuple(self.joint.project_layer.weight.size())
            assert emb_shape == proj_shape, '%s != %s' % (emb_shape, proj_shape)
            self.joint.project_layer.weight = self.decoder.embedding.weight

        # Two criteria: the RNN-T loss and a cross-entropy loss.
        self.rnnt = RNNTLoss()
        self.crit = nn.CrossEntropyLoss()

        # Whether the hierarchical first-stage encoder is used.
        self.fir_enc_or_not = config.fir_enc_or_not
Beispiel #18
0
    def __init__(self, config):
        """Build encoder, projection, decoder, joint network and RNN-T loss.

        Args:
            config: model config; passed whole to ``build_encoder`` and
                ``build_decoder`` and read for ``joint.input_size``,
                ``joint.inner_size``, ``vocab_size`` and ``share_embedding``.
        """
        super(Transducer, self).__init__()
        # define encoder
        self.config = config
        self.encoder = build_encoder(config)
        # NOTE(review): 320 -> 213 sizes are hard-coded; confirm they match the
        # encoder output size and vocabulary size in the config.
        self.project_layer = nn.Linear(320, 213)
        # define decoder
        self.decoder = build_decoder(config)
        # define JointNet
        self.joint = JointNet(input_size=config.joint.input_size,
                              inner_dim=config.joint.inner_size,
                              vocab_size=config.vocab_size)

        if config.share_embedding:
            # Tie the joint output projection to the decoder embedding; the two
            # weight matrices must have exactly the same shape. Report both
            # full shapes (the old message showed only dim 1).
            emb_shape = tuple(self.decoder.embedding.weight.size())
            proj_shape = tuple(self.joint.project_layer.weight.size())
            assert emb_shape == proj_shape, '%s != %s' % (emb_shape, proj_shape)
            self.joint.project_layer.weight = self.decoder.embedding.weight
        self.rnnt_crit = RNNTLoss()
Beispiel #19
0
    def __init__(self, blank_id):
        """Construct an TransLoss object.

        Wraps warp-transducer's ``RNNTLoss`` with a fixed blank index.

        Args:
            blank_id: index of the blank label; forwarded to ``RNNTLoss`` and
                kept on the instance as ``self.blank_id``.
        """
        super().__init__()

        self.blank_id = blank_id
        self.trans_loss = RNNTLoss(blank=self.blank_id)
Beispiel #20
0
import argparse
import numpy as np
import time
import torch
import torch.autograd as autograd
import torch.nn as nn

from warprnnt_pytorch import RNNTLoss
from transducer_np import RNNTLoss as rnntloss

# Command-line switch: --np selects the reference numpy loss from
# transducer_np instead of warprnnt_pytorch's RNNTLoss.
# NOTE(review): the description mentions MXNet although this script uses
# PyTorch tensors — likely a leftover from the original test.
parser = argparse.ArgumentParser(description='MXNet RNN Transducer Test.')
parser.add_argument('--np', default=False, action='store_true', help='numpy loss')
args = parser.parse_args()

# Loss under test; size_average=False presumably requests an un-averaged
# (summed) loss — TODO confirm against the installed warprnnt_pytorch.
fn = rnntloss() if args.np else RNNTLoss(size_average=False) 

# CUDA device index used when moving tensors to the GPU.
gpu = 1
def wrap_and_call(acts, labels):
    acts = torch.FloatTensor(acts)
    if use_cuda:
        acts = acts.cuda(gpu)
    #acts = autograd.Variable(acts, requires_grad=True)
    acts.requires_grad = True

    lengths = [acts.shape[1]] * acts.shape[0]
    label_lengths = [len(l) for l in labels]
    labels = torch.IntTensor(labels)
    lengths = torch.IntTensor(lengths)
    label_lengths = torch.IntTensor(label_lengths)
    if use_cuda:
Beispiel #21
0
def main():
    """Train a Transducer model as described by a YAML config.

    Command-line options:
        -config: YAML experiment config (default config/joint_streaming.yaml).
        -log: log-file name, created inside the experiment directory.
        -mode: 'continue' restores optimizer state, epoch and step from the
            checkpoint at ``config.training.load_model``; any other value
            (default 'retrain') starts at epoch 0.

    The config copy, log file and per-epoch checkpoints are written to
    ``egs/<config.data.name>/<config.training.save_model>``.

    Raises:
        ValueError: if ``-mode continue`` is used without
            ``config.training.load_model`` (previously a late NameError, or a
            read of an encoder/decoder-only checkpoint).
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    parser = argparse.ArgumentParser()
    parser.add_argument('-config',
                        type=str,
                        default='config/joint_streaming.yaml')
    parser.add_argument('-log', type=str, default='train.log')
    parser.add_argument('-mode', type=str, default='retrain')
    opt = parser.parse_args()

    # Close the config file as soon as it is parsed (the handle used to leak).
    with open(opt.config) as configfile:
        config = AttrDict(yaml.load(configfile, Loader=yaml.FullLoader))

    # Resuming needs the full checkpoint; fail fast before any expensive setup.
    if opt.mode == 'continue' and not config.training.load_model:
        raise ValueError(
            "-mode continue requires config.training.load_model to be set")

    exp_name = os.path.join('egs', config.data.name,
                            config.training.save_model)
    os.makedirs(exp_name, exist_ok=True)
    logger = init_logger(os.path.join(exp_name, opt.log))

    shutil.copyfile(opt.config, os.path.join(exp_name, 'config.yaml'))
    logger.info('Save config info.')

    # create a visualizer
    if config.training.visualization:
        visualizer = SummaryWriter(exp_name)
        logger.info('Created a visualizer.')
    else:
        visualizer = None

    index2word, word2index = generate_dictionary(config.data.vocab)
    logger.info('Load Vocabulary!')

    train_dataset = AudioDataset(config.data, 'train', word2index)
    training_data = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.batch_size,
        shuffle=config.data.shuffle,
        num_workers=12)
    logger.info('Load Train Set!')

    dev_dataset = AudioDataset(config.data, 'dev', word2index)
    validate_data = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,
        num_workers=12)
    logger.info('Load Dev Set!')

    # torch.manual_seed seeds the CPU *and* all CUDA generators. The model is
    # constructed on the CPU below, so seeding only CUDA (as before) left its
    # initial weights unseeded on GPU runs.
    torch.manual_seed(config.training.seed)
    if config.training.num_gpu > 0:
        torch.backends.cudnn.deterministic = True
    logger.info('Set random seed: %d' % config.training.seed)

    model = Transducer(config.model)

    checkpoint = None
    if config.training.load_model:
        checkpoint = torch.load(config.training.load_model)
        model.encoder.load_state_dict(checkpoint['encoder'])
        model.decoder.load_state_dict(checkpoint['decoder'])
        model.joint.load_state_dict(checkpoint['joint'])
        logger.info('Loaded model from %s' % config.training.load_model)
    else:
        # Partial warm start; these checkpoints are not used for resuming.
        if config.training.load_encoder:
            enc_checkpoint = torch.load(config.training.load_encoder)
            model.encoder.load_state_dict(enc_checkpoint['encoder'])
            logger.info('Loaded encoder from %s' %
                        config.training.load_encoder)
        if config.training.load_decoder:
            dec_checkpoint = torch.load(config.training.load_decoder)
            model.decoder.load_state_dict(dec_checkpoint['decoder'])
            logger.info('Loaded decoder from %s' %
                        config.training.load_decoder)

    if config.training.num_gpu > 0:
        model = model.cuda()
        if config.training.num_gpu > 1:
            device_ids = list(range(config.training.num_gpu))
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        logger.info('Loaded the model to %d GPUs' % config.training.num_gpu)

    n_params, enc, dec = count_parameters(model)
    logger.info('# the number of parameters in the whole model: %d' % n_params)
    logger.info('# the number of parameters in the Encoder: %d' % enc)
    logger.info('# the number of parameters in the Decoder: %d' % dec)
    logger.info('# the number of parameters in the JointNet: %d' %
                (n_params - dec - enc))

    optimizer = Optimizer(model.parameters(), config.optim)
    logger.info('Created a %s optimizer.' % config.optim.type)

    criterion = RNNTLoss()
    logger.info('Created a RNNT loss.')

    if opt.mode == 'continue':
        # checkpoint is the full model checkpoint (validated above).
        optimizer.load_state_dict(checkpoint['optimizer'])
        # NOTE(review): if 'epoch' is the last *finished* epoch, that epoch is
        # trained again — confirm save_model's convention.
        start_epoch = checkpoint['epoch']
        global_step = checkpoint['step']
        optimizer.global_step = global_step
        optimizer.current_epoch = start_epoch
        logger.info('Load Optimizer State!')
    else:
        start_epoch = 0

    for epoch in range(start_epoch, config.training.epochs):

        train(epoch, config, model, training_data, optimizer, criterion,
              logger, visualizer)

        save_name = os.path.join(
            exp_name, '%s.epoch%d.chkpt' % (config.training.save_model, epoch))
        save_model(model, optimizer, config, save_name)
        logger.info('Epoch %d model has been saved.' % epoch)

        if config.training.eval_or_not:
            _ = eval(epoch, config, model, validate_data, logger, visualizer,
                     index2word)

        if epoch >= config.optim.begin_to_adjust_lr:
            optimizer.decay_lr()
            # early stop
            if optimizer.lr < 1e-6:
                logger.info('The learning rate is too low to train.')
                break
            logger.info('Epoch %d update learning rate: %.6f' %
                        (epoch, optimizer.lr))

    logger.info('The training process is OVER!')
Beispiel #22
0
    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: AbsDecoder,
        ctc: CTC,
        joint_network: Optional[torch.nn.Module],
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
    ):
        """Assemble the ASR model from its front-end, encoder and decoder parts.

        The model runs in transducer mode when ``joint_network`` is given
        (loss: warp-rnnt ``RNNTLoss``), otherwise in attention/CTC mode
        (decoder loss: ``LabelSmoothingLoss``).

        Args:
            vocab_size: Output vocabulary size. Its last ID (``vocab_size - 1``)
                is used as both <sos> and <eos>; ID 0 is the blank.
            token_list: Token strings (list or tuple); a list copy is stored.
            frontend: Optional feature front-end, stored as-is.
            specaug: Optional SpecAugment module, stored as-is.
            normalize: Optional feature normalization, stored as-is.
            preencoder: Optional pre-encoder, stored as-is.
            encoder: Encoder module. If it has no ``interctc_use_conditioning``
                attribute, that attribute is set to False on it.
            postencoder: Optional post-encoder, stored as-is.
            decoder: Decoder module. In attention/CTC mode it is dropped
                (``self.decoder = None``) when ``ctc_weight == 1.0``.
            ctc: CTC module. Dropped (``self.ctc = None``) when
                ``ctc_weight == 0.0``.
            joint_network: Joint network; non-None selects transducer mode.
            ctc_weight: CTC weight; must lie in [0, 1].
            interctc_weight: Intermediate-CTC weight; must lie in [0, 1).
            ignore_id: Padding label ID (``padding_idx`` of the attention loss).
            lsm_weight: Label-smoothing factor of the attention loss.
            length_normalized_loss: Passed as ``normalize_length`` to the
                attention loss.
            report_cer: Whether error calculators should report CER.
            report_wer: Whether error calculators should report WER.
            sym_space: Space symbol passed to the error calculators.
            sym_blank: Blank symbol passed to the error calculators.
            extract_feats_in_collect_stats: Flag stored on the model; it is
                consumed outside this method.
        """
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.blank_id = 0
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight
        # ``token_list`` may legitimately be a tuple (see the annotation), and
        # tuples have no ``.copy()``; ``list()`` copies either form.
        self.token_list = list(token_list)

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder

        # Encoders that do not declare intermediate-CTC conditioning get the
        # flag explicitly disabled so later checks can read it uniformly.
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )

        self.use_transducer_decoder = joint_network is not None

        self.error_calculator = None

        if self.use_transducer_decoder:
            # Imported lazily so warp-rnnt is only required in transducer mode.
            from warprnnt_pytorch import RNNTLoss

            self.decoder = decoder
            self.joint_network = joint_network

            self.criterion_transducer = RNNTLoss(
                blank=self.blank_id,
                fastemit_lambda=0.0,
            )

            if report_cer or report_wer:
                self.error_calculator_trans = ErrorCalculatorTransducer(
                    decoder,
                    joint_network,
                    token_list,
                    sym_space,
                    sym_blank,
                    report_cer=report_cer,
                    report_wer=report_wer,
                )
            else:
                self.error_calculator_trans = None

                # NOTE(review): this CTC error calculator is only built when
                # both report_cer and report_wer are False (and is then given
                # those False flags). That looks like an indentation slip;
                # confirm the intended condition before changing it.
                if self.ctc_weight != 0:
                    self.error_calculator = ErrorCalculator(
                        token_list, sym_space, sym_blank, report_cer, report_wer
                    )
        else:
            # we set self.decoder = None in the CTC mode since
            # self.decoder parameters were never used and PyTorch complained
            # and threw an Exception in the multi-GPU experiment.
            # thanks Jeff Farris for pointing out the issue.
            if ctc_weight == 1.0:
                self.decoder = None
            else:
                self.decoder = decoder

            self.criterion_att = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )

            if report_cer or report_wer:
                self.error_calculator = ErrorCalculator(
                    token_list, sym_space, sym_blank, report_cer, report_wer
                )

        # A zero CTC weight means the CTC branch is unused, so drop the module.
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
Beispiel #23
0
    def __init__(
        self,
        encoder_dim: int,
        decoder_dim: int,
        joint_dim: int,
        output_dim: int,
        joint_activation_type: str = "tanh",
        transducer_loss_weight: float = 1.0,
        ctc_loss: bool = False,
        ctc_loss_weight: float = 0.5,
        ctc_loss_dropout_rate: float = 0.0,
        lm_loss: bool = False,
        lm_loss_weight: float = 0.5,
        lm_loss_smoothing_rate: float = 0.0,
        aux_transducer_loss: bool = False,
        aux_transducer_loss_weight: float = 0.2,
        aux_transducer_loss_mlp_dim: int = 320,
        aux_trans_loss_mlp_dropout_rate: float = 0.0,
        symm_kl_div_loss: bool = False,
        symm_kl_div_loss_weight: float = 0.2,
        fastemit_lambda: float = 0.0,
        blank_id: int = 0,
        ignore_id: int = -1,
        training: bool = False,
    ):
        """Set up the joint network plus the optional Transducer training losses.

        When ``training`` is False, every auxiliary loss flag (CTC, LM,
        auxiliary transducer, symmetric KL) is forced off and no warp-rnnt
        transducer loss module is created.

        Args:
            encoder_dim: Encoder outputs dimension.
            decoder_dim: Decoder outputs dimension.
            joint_dim: Joint space dimension.
            output_dim: Output dimension.
            joint_activation_type: Type of activation for joint network.
            transducer_loss_weight: Weight for main transducer loss.
            ctc_loss: Compute CTC loss.
            ctc_loss_weight: Weight of CTC loss.
            ctc_loss_dropout_rate: Dropout rate for CTC loss inputs.
            lm_loss: Compute LM loss.
            lm_loss_weight: Weight of LM loss.
            lm_loss_smoothing_rate: Smoothing rate for LM loss' label smoothing.
            aux_transducer_loss: Compute auxiliary transducer loss.
            aux_transducer_loss_weight: Weight of auxiliary transducer loss.
            aux_transducer_loss_mlp_dim: Hidden dimension for aux. transducer MLP.
            aux_trans_loss_mlp_dropout_rate: Dropout rate for aux. transducer MLP.
            symm_kl_div_loss: Compute KL divergence loss (only instantiated
                together with the auxiliary transducer loss).
            symm_kl_div_loss_weight: Weight of KL divergence loss.
            fastemit_lambda: Regularization parameter for FastEmit.
            blank_id: Blank symbol ID.
            ignore_id: Padding symbol ID.
            training: Whether the model was initialized in training or
                inference mode.

        """
        super().__init__()

        # Inference mode: switch off every auxiliary objective up front.
        if not training:
            ctc_loss = lm_loss = aux_transducer_loss = symm_kl_div_loss = False

        # Submodules below are created in a fixed order (joint network, main
        # loss, CTC, aux MLP, KL, LM) so parameter initialization and module
        # registration order stay stable.
        self.joint_network = JointNetwork(
            output_dim, encoder_dim, decoder_dim, joint_dim, joint_activation_type
        )

        if training:
            from warprnnt_pytorch import RNNTLoss

            self.transducer_loss = RNNTLoss(
                blank=blank_id,
                reduction="sum",
                fastemit_lambda=fastemit_lambda,
            )

        if ctc_loss:
            self.ctc_lin = torch.nn.Linear(encoder_dim, output_dim)
            self.ctc_loss = torch.nn.CTCLoss(
                blank=blank_id, reduction="none", zero_infinity=True
            )

        if aux_transducer_loss:
            mlp_layers = [
                torch.nn.Linear(encoder_dim, aux_transducer_loss_mlp_dim),
                torch.nn.LayerNorm(aux_transducer_loss_mlp_dim),
                torch.nn.Dropout(p=aux_trans_loss_mlp_dropout_rate),
                torch.nn.ReLU(),
                torch.nn.Linear(aux_transducer_loss_mlp_dim, joint_dim),
            ]
            self.mlp = torch.nn.Sequential(*mlp_layers)

            # NOTE(review): `kl_div` only exists when the auxiliary transducer
            # loss is on, while `use_symm_kl_div_loss` below mirrors
            # `symm_kl_div_loss` alone — confirm the forward pass only reads it
            # inside the aux-transducer branch.
            if symm_kl_div_loss:
                self.kl_div = torch.nn.KLDivLoss(reduction="sum")

        if lm_loss:
            self.lm_lin = torch.nn.Linear(decoder_dim, output_dim)
            self.label_smoothing_loss = LabelSmoothingLoss(
                output_dim, ignore_id, lm_loss_smoothing_rate, normalize_length=False
            )

        # Plain configuration attributes read elsewhere in the class.
        self.output_dim = output_dim
        self.transducer_loss_weight = transducer_loss_weight

        self.use_ctc_loss, self.ctc_loss_weight = ctc_loss, ctc_loss_weight
        self.ctc_dropout_rate = ctc_loss_dropout_rate

        self.use_lm_loss, self.lm_loss_weight = lm_loss, lm_loss_weight

        self.use_aux_transducer_loss = aux_transducer_loss
        self.aux_transducer_loss_weight = aux_transducer_loss_weight

        self.use_symm_kl_div_loss = symm_kl_div_loss
        self.symm_kl_div_loss_weight = symm_kl_div_loss_weight

        self.blank_id, self.ignore_id = blank_id, ignore_id

        self.target = None
Beispiel #24
0
import argparse
import numpy as np
import time
import torch
import torch.autograd as autograd
import torch.nn as nn

from warprnnt_pytorch import RNNTLoss
from transducer_np import RNNTLoss as rnntloss

# Command line: a single switch choosing which RNN-T loss implementation to test.
parser = argparse.ArgumentParser(description='MXNet RNN Transducer Test.')
parser.add_argument('--np', default=False, action='store_true', help='numpy loss')
args = parser.parse_args()

# Loss under test: the NumPy implementation when --np is given, otherwise the
# warp-rnnt binding with summed reduction.
if args.np:
    fn = rnntloss()
else:
    fn = RNNTLoss(reduction='sum')

gpu = 1  # CUDA device index used when moving tensors to the GPU
def wrap_and_call(acts, labels):
    acts = torch.FloatTensor(acts)
    if use_cuda:
        acts = acts.cuda(gpu)
    #acts = autograd.Variable(acts, requires_grad=True)
    acts.requires_grad = True

    lengths = [acts.shape[1]] * acts.shape[0]
    label_lengths = [len(l) for l in labels]
    labels = torch.IntTensor(labels)
    lengths = torch.IntTensor(lengths)
    label_lengths = torch.IntTensor(label_lengths)
    if use_cuda:
# Command line for the RNN-T loss benchmark: problem sizes plus switches.
parser = argparse.ArgumentParser(description='MXNet RNN Transducer Test.')
# The sizes are positional but optional (nargs='?'). Without nargs, argparse
# treats positionals as mandatory and their `default` would never be used.
parser.add_argument('B', type=int, nargs='?', default=1, help='batch size')
parser.add_argument('T', type=int, nargs='?', default=300, help='time step')
parser.add_argument('U', type=int, nargs='?', default=100, help='prediction step')
parser.add_argument('V', type=int, nargs='?', default=60, help='vocab size')
parser.add_argument('--np',
                    default=False,
                    action='store_true',
                    help='use numpy loss')
parser.add_argument('--add',
                    default=False,
                    action='store_true',
                    help='add_network')
args = parser.parse_args()

# Loss under test: the NumPy implementation (--np) or the warp-rnnt binding
# constructed with its default arguments.
fn = rnntloss() if args.np else RNNTLoss()


def get_gpu_memory_map():
    """Query nvidia-smi for the used memory of every visible GPU.

    Returns:
        dict mapping GPU index (0, 1, ...) to the used memory as an int, in
        the unit nvidia-smi reports with ``nounits`` (presumably MiB — confirm).

    Raises:
        Whatever ``subprocess.check_output`` raises if nvidia-smi is missing
        or fails, and ValueError if a line is not an integer.
    """
    query = [
        'nvidia-smi',
        '--query-gpu=memory.used',
        '--format=csv,nounits,noheader',
    ]
    raw_output = subprocess.check_output(query, encoding='utf-8')
    # One line per GPU, in device-index order.
    lines = raw_output.strip().split('\n')
    used_per_gpu = [int(line) for line in lines]
    return {index: used for index, used in enumerate(used_per_gpu)}


def wrap_and_call():
Beispiel #26
0
import numpy as np
import sys
import os
import matplotlib.pyplot as plt
import ctcdecode
import time

import torch.utils
import torch.utils.checkpoint

#awni_transducer_path = '/home/lugosch/code/transducer'
#sys.path.insert(0, awni_transducer_path)
#from transducer import Transducer

from warprnnt_pytorch import RNNTLoss
# Module-level RNN-T loss (warp-rnnt binding), created once at import time
# with the library's default constructor arguments.
rnnt_loss = RNNTLoss()


class TransducerModel(torch.nn.Module):
    """RNN-T model assembled from an encoder, an autoregressive decoder and a joiner.

    The output vocabulary (blank index and number of outputs) is taken from
    the joiner. A ``ctcdecode.CTCBeamDecoder`` is also created, using the
    joiner's blank index and ``config.beam_width``.
    """

    def __init__(self, config):
        super().__init__()
        # Construction order (encoder -> decoder -> joiner) is kept fixed so
        # parameter initialization and registration order do not change.
        self.encoder = Encoder(config)
        self.decoder = AutoregressiveDecoder(config)
        self.joiner = Joiner(config)

        joiner = self.joiner
        self.blank_index = joiner.blank_index
        self.num_outputs = joiner.num_outputs

        # The beam decoder receives one placeholder label string ("a") per
        # output; presumably only the label count matters here — confirm.
        placeholder_labels = ["a"] * self.num_outputs
        self.ctc_decoder = ctcdecode.CTCBeamDecoder(
            placeholder_labels,
            blank_id=self.blank_index,
            beam_width=config.beam_width,
        )
Beispiel #27
0
def main(yaml_name="/home/jhjeong/jiho_deep/rnn-t/label,csv/RNN-T_mobile_2.yaml"):
    """Train an RNN-T model described by a YAML config, keeping the best checkpoint.

    Progress is written to ``./train.txt`` (recreated at start, appended each
    epoch). Whenever the validation loss improves, the model weights are
    saved to ``./model_save/model_save.pth``.

    Args:
        yaml_name: Path of the YAML training configuration. Defaults to the
            path the script was originally written against.

    Raises:
        ValueError: If ``config.optim.type`` is neither ``"AdamW"`` nor ``"Adam"``.
    """
    import os

    # Start a fresh training log. The Korean text means "training started";
    # UTF-8 is explicit so the log is writable regardless of the locale.
    with open("./train.txt", "w", encoding="utf-8") as f:
        f.write(yaml_name)
        f.write('\n')
        f.write('\n')
        f.write("학습 시작")
        f.write('\n')

    with open(yaml_name, encoding="utf-8") as configfile:
        # NOTE(review): FullLoader is less restrictive than yaml.safe_load;
        # only point this at trusted config files.
        config = AttrDict(yaml.load(configfile, Loader=yaml.FullLoader))

    # NOTE(review): `summary` is never used below; the writer is still created
    # in case its side effects (presumably a TensorBoard run dir) are relied on.
    summary = SummaryWriter()

    # The window function is selected by name (config.audio_data.window) and
    # resolved downstream; no scipy.signal lookup is needed here.
    audio_conf = dict(sample_rate=config.audio_data.sampling_rate,
                      window_size=config.audio_data.window_size,
                      window_stride=config.audio_data.window_stride,
                      window=config.audio_data.window)

    random.seed(config.data.seed)
    torch.manual_seed(config.data.seed)
    torch.cuda.manual_seed_all(config.data.seed)

    cuda = torch.cuda.is_available()
    device = torch.device('cuda' if cuda else 'cpu')

    # Model
    # Transcription network (acoustic encoder)
    enc = BaseEncoder(input_size=config.model.enc.input_size,
                      hidden_size=config.model.enc.hidden_size,
                      output_size=config.model.enc.output_size,
                      n_layers=config.model.enc.n_layers,
                      dropout=config.model.dropout,
                      bidirectional=config.model.enc.bidirectional)

    # Prediction network (label decoder)
    dec = BaseDecoder(embedding_size=config.model.dec.embedding_size,
                      hidden_size=config.model.dec.hidden_size,
                      vocab_size=config.model.vocab_size,
                      output_size=config.model.dec.output_size,
                      n_layers=config.model.dec.n_layers,
                      dropout=config.model.dropout)

    model = Transducer(enc, dec, config.model.joint.input_size,
                       config.model.joint.inner_dim, config.model.vocab_size)

    # To resume training, load a checkpoint here, e.g.:
    # model.load_state_dict(torch.load("/home/jhjeong/jiho_deep/rnn-t/model_save/model2_save_epoch_19.pth"))

    model = nn.DataParallel(model).to(device)

    if config.optim.type == "AdamW":
        optimizer = optim.AdamW(model.module.parameters(),
                                lr=config.optim.lr,
                                weight_decay=config.optim.weight_decay)
    elif config.optim.type == "Adam":
        optimizer = optim.Adam(model.module.parameters(),
                               lr=config.optim.lr,
                               weight_decay=config.optim.weight_decay)
    else:
        # Fail with a clear message instead of an unbound `optimizer` below.
        raise ValueError(
            "Unsupported optimizer type: {}".format(config.optim.type))

    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=config.optim.milestones,
        gamma=config.optim.decay_rate)

    # RNNTLoss inputs:
    #   acts: Tensor of [batch x seqLength x (labelLength + 1) x outputDim]
    #         containing output from network (+1 means first blank label
    #         prediction)
    #   labels: 2 dimensional Tensor containing all the targets of the batch
    #           with zero padded
    #   act_lens: Tensor of size (batch) containing size of each output
    #             sequence from the network
    #   label_lens: Tensor of (batch) containing label length of each example
    criterion = RNNTLoss().to(device)

    # Train dataset
    train_dataset = SpectrogramDataset(audio_conf,
                                       config.data.train_path,
                                       feature_type=config.audio_data.type,
                                       normalize=True,
                                       spec_augment=True)

    train_loader = AudioDataLoader(dataset=train_dataset,
                                   shuffle=True,
                                   num_workers=config.data.num_workers,
                                   batch_size=config.data.batch_size,
                                   drop_last=True)

    # Validation dataset
    val_dataset = SpectrogramDataset(audio_conf,
                                     config.data.val_path,
                                     feature_type=config.audio_data.type,
                                     normalize=True,
                                     spec_augment=False)

    # NOTE(review): shuffle=True together with drop_last=True makes the set of
    # evaluated validation samples vary between epochs — confirm intended.
    val_loader = AudioDataLoader(dataset=val_dataset,
                                 shuffle=True,
                                 num_workers=config.data.num_workers,
                                 batch_size=config.data.batch_size,
                                 drop_last=True)

    print(model)
    print("시작합니다.")  # "Starting."

    # torch.save does not create missing directories.
    os.makedirs("./model_save", exist_ok=True)

    best_val_loss = float("inf")
    for epoch in range(config.training.begin_epoch, config.training.end_epoch):
        print('{} 학습 시작'.format(datetime.datetime.now()))  # "training started"
        train_time = time.time()
        train_loss = train(model, train_loader, optimizer, criterion, device)
        train_total_time = time.time() - train_time
        print('{} Epoch {} (Training) Loss {:.4f}, time: {:.4f}'.format(
            datetime.datetime.now(), epoch + 1, train_loss, train_total_time))

        print('{} 평가 시작'.format(datetime.datetime.now()))  # "evaluation started"
        eval_time = time.time()
        val_loss = eval(model, val_loader, criterion, device)
        eval_total_time = time.time() - eval_time
        print('{} Epoch {} (val) Loss {:.4f}, time: {:.4f}'.format(
            datetime.datetime.now(), epoch + 1, val_loss, eval_total_time))

        scheduler.step()

        with open("./train.txt", "a", encoding="utf-8") as ff:
            ff.write('Epoch %d (Training) Loss %0.4f time %0.4f' %
                     (epoch + 1, train_loss, train_total_time))
            ff.write('\n')
            ff.write('Epoch %d (val) Loss %0.4f time %0.4f ' %
                     (epoch + 1, val_loss, eval_total_time))
            ff.write('\n')
            ff.write('\n')

        if val_loss < best_val_loss:
            print("best model을 저장하였습니다.")  # "Saved the best model."
            torch.save(model.module.state_dict(),
                       "./model_save/model_save.pth")
            best_val_loss = val_loss
def main(learning_rate=5e-4,
         batch_size=32,
         epochs=10,
         train_url="train-clean-100",
         test_url="test-clean",
         experiment=None):
    """Train and evaluate an RNN-T model on LibriSpeech.

    Datasets are downloaded into ``./data``. After every epoch the weights
    are saved to ``rnnt.params`` and the model is evaluated on the test split.

    Args:
        learning_rate: Peak learning rate of the OneCycle schedule.
        batch_size: Batch size for both loaders.
        epochs: Number of training epochs.
        train_url: LibriSpeech split used for training.
        test_url: LibriSpeech split used for evaluation.
        experiment: Experiment-tracking object exposing ``log_parameters``.
            When None, a disabled ``Experiment(api_key='dummy_key',
            disabled=True)`` is created for this call.
    """
    if experiment is None:
        # Built per call instead of once at definition (import) time, so
        # separate calls never share the same default experiment object.
        experiment = Experiment(api_key='dummy_key', disabled=True)

    hparams = {
        # Encoder Parameters
        "n_cnn_layers": 3,
        "n_rnn_layers": 5,
        "rnn_dim": 512,
        "n_class": 29,
        "n_feats": 128,
        "stride": 2,
        "dropout": 0.2,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        # Decoder Parameters
        "hidden_size": 256,
        "vocab_size": 29,
        "output_size": 29,
        "n_layers": 1,
        # Joint Parameters,
        "input_size": 58,
        "inner_dim": 256,
        "epochs": epochs,
    }
    experiment.log_parameters(hparams)

    use_cuda = torch.cuda.is_available()
    torch.manual_seed(7)
    device = torch.device("cuda" if use_cuda else "cpu")

    os.makedirs("./data", exist_ok=True)

    train_dataset = torchaudio.datasets.LIBRISPEECH("./data",
                                                    url=train_url,
                                                    download=True)
    test_dataset = torchaudio.datasets.LIBRISPEECH("./data",
                                                   url=test_url,
                                                   download=True)

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    train_loader = data.DataLoader(
        dataset=train_dataset,
        batch_size=hparams['batch_size'],
        shuffle=True,
        collate_fn=lambda x: data_processing(x, 'train'),
        **kwargs)
    test_loader = data.DataLoader(
        dataset=test_dataset,
        batch_size=hparams['batch_size'],
        shuffle=False,
        collate_fn=lambda x: data_processing(x, 'valid'),
        **kwargs)

    model = Transducer(hparams['n_cnn_layers'], hparams['n_rnn_layers'],
                       hparams['rnn_dim'], hparams['n_class'],
                       hparams['n_feats'], hparams['hidden_size'],
                       hparams['vocab_size'], hparams['output_size'],
                       hparams['input_size'], hparams['inner_dim'],
                       hparams['n_layers']).to(device)

    def init_weights(m):
        """Initialize one module in place, matching its exact type.

        Linear/Conv2d: Xavier-uniform weights. Embedding: uniform weights.
        GRU/LSTM: orthogonal weights for the first layer only. Exact type
        matching (not isinstance) means subclasses keep their default init.
        """
        module_type = type(m)
        if module_type in (nn.Linear, nn.Conv2d):
            torch.nn.init.xavier_uniform_(m.weight)
        elif module_type == nn.Embedding:
            torch.nn.init.uniform_(m.weight)
        elif module_type in (nn.GRU, nn.LSTM):
            # Only layer 0 is touched; deeper layers of a multi-layer RNN
            # keep PyTorch's default initialization.
            torch.nn.init.orthogonal_(m.weight_ih_l0)
            torch.nn.init.orthogonal_(m.weight_hh_l0)

    model.apply(init_weights)  # Initialize before start of training

    # To continue an interrupted run, comment out the initialization above and
    # load the saved weights instead:
    # model.load_state_dict(torch.load("rnnt.params"))

    print(model)
    print('Num Model Parameters',
          sum(param.nelement() for param in model.parameters()))

    optimizer = optim.AdamW(model.parameters(), hparams['learning_rate'])
    criterion = RNNTLoss()
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer,
                                              max_lr=hparams['learning_rate'],
                                              steps_per_epoch=len(train_loader),
                                              epochs=hparams['epochs'],
                                              anneal_strategy='linear')

    iter_meter = IterMeter()
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, criterion, optimizer, scheduler,
              epoch, iter_meter, experiment)
        # Save after every epoch so an interrupted run can be resumed.
        torch.save(model.state_dict(), 'rnnt.params')
        test(model, device, test_loader, criterion, epoch, iter_meter,
             experiment)
Beispiel #29
0
    def _calc_transducer_loss(
        self,
        encoder_out: torch.Tensor,
        joint_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
        """Compute Transducer loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            joint_out: Joint Network output sequences (B, T, U, D_joint)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)

        Return:
            loss_transducer: Transducer loss value.
            cer_transducer: Character error rate for Transducer.
            wer_transducer: Word Error Rate for Transducer.

        """
        if self.criterion_transducer is None:
            try:
                from warprnnt_pytorch import RNNTLoss

                self.criterion_transducer = RNNTLoss(
                    reduction="mean",
                    fastemit_lambda=self.fastemit_lambda,
                )
            except ImportError:
                logging.error("warp-rnnt was not installed."
                              "Please consult the installation documentation.")
                exit(1)

        loss_transducer = self.criterion_transducer(
            joint_out,
            target,
            t_len,
            u_len,
        )

        if not self.training and (self.report_cer or self.report_wer):
            if self.error_calculator is None:
                from espnet2.asr_transducer.error_calculator import ErrorCalculator

                self.error_calculator = ErrorCalculator(
                    self.decoder,
                    self.joint_network,
                    self.token_list,
                    self.sym_space,
                    self.sym_blank,
                    report_cer=self.report_cer,
                    report_wer=self.report_wer,
                )

            cer_transducer, wer_transducer = self.error_calculator(
                encoder_out, target)

            return loss_transducer, cer_transducer, wer_transducer

        return loss_transducer, None, None