def eval_word(models,
              dataset,
              beam_width,
              max_decode_len,
              beam_width_sub=1,
              max_decode_len_sub=300,
              eval_batch_size=None,
              length_penalty=0,
              progressbar=False,
              temperature=1,
              resolving_unk=False,
              a2c_oracle=False):
    """Evaluate trained model by Word Error Rate.
    Args:
        models (list): the models to evaluate. If more than one model is
            given, they must be CTC models; their posteriors are averaged
            and decoded with beam_width=1.
        dataset: An instance of a `Dataset' class
        beam_width (int): the size of beam in the main task
        max_decode_len (int): the length of output sequences
            to stop prediction. This is used for seq2seq models.
        beam_width_sub (int, optional): the size of beam in the sub task.
            This is used for the nested attention
        max_decode_len_sub (int, optional): the length of output sequences
            to stop prediction in the sub task. For the nested attention
            with `a2c_oracle`, the longest oracle label length is used
            instead.
        eval_batch_size (int, optional): the batch size when evaluating the model
        length_penalty (float, optional): length penalty in beam search decoding
        progressbar (bool, optional): if True, visualize the progressbar
        temperature (int, optional): not used in this function
        resolving_unk (bool, optional): if True, replace 'OOV' words in the
            hypothesis using the sub-task output (see `resolve_unk`)
        a2c_oracle (bool, optional): if True, the nested attention model is
            decoded with teacher forcing on the reference sub-task labels
    Returns:
        wer (float): Word error rate
        df_wer (pd.DataFrame): dataframe of substitution, insertion, and deletion
    """
    # Reset data counter
    dataset.reset()

    idx2word = Idx2word(dataset.vocab_file_path)
    if models[0].model_type == 'nested_attention':
        char2idx = Char2idx(dataset.vocab_file_path_sub)
    if resolving_unk:
        # `resolve_unk` below needs a character converter for every model type
        idx2char = Idx2char(dataset.vocab_file_path_sub,
                            capital_divide=dataset.label_type_sub ==
                            'character_capital_divide')

    wer = 0
    sub, ins, dele = 0, 0, 0
    num_words = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        batch_size = len(batch['xs'])

        # Decode
        if len(models) > 1:
            # Ensemble: average the CTC posteriors of all models
            assert models[0].model_type in ['ctc']
            for i, model in enumerate(models):
                probs, x_lens, perm_idx = model.posteriors(
                    batch['xs'], batch['x_lens'])
                if i == 0:
                    probs_ensemble = probs
                else:
                    probs_ensemble += probs
            probs_ensemble /= len(models)

            best_hyps = models[0].decode_from_probs(probs_ensemble,
                                                    x_lens,
                                                    beam_width=1)
        else:
            model = models[0]
            # TODO: fix this

            if model.model_type == 'nested_attention':
                max_label_num = None
                if a2c_oracle and dataset.is_test:
                    # Test references are raw transcripts (separated by '_'):
                    # convert them to character indices padded with -1.
                    # The width is taken from the converted indices so every
                    # row is guaranteed to fit.
                    indices_list = [char2idx(batch['ys_sub'][b][0])
                                    for b in range(batch_size)]
                    max_label_num = max(
                        [len(indices) for indices in indices_list] + [0])
                    ys_sub = np.full((batch_size, max_label_num), -1,
                                     dtype=np.int32)
                    y_lens_sub = np.zeros((batch_size, ), dtype=np.int32)
                    for b, indices in enumerate(indices_list):
                        ys_sub[b, :len(indices)] = indices
                        y_lens_sub[b] = len(indices)
                else:
                    ys_sub = batch['ys_sub']
                    y_lens_sub = batch['y_lens_sub']
                    if a2c_oracle:
                        # Labels are already indices; stop at the longest one
                        max_label_num = (int(np.max(y_lens_sub))
                                         if len(y_lens_sub) > 0 else 0)

                best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    beam_width_sub=beam_width_sub,
                    max_decode_len=max_decode_len,
                    max_decode_len_sub=max_label_num
                    if a2c_oracle else max_decode_len_sub,
                    length_penalty=length_penalty,
                    teacher_forcing=a2c_oracle,
                    ys_sub=ys_sub,
                    y_lens_sub=y_lens_sub)
            else:
                best_hyps, aw, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len,
                    length_penalty=length_penalty)
                if resolving_unk:
                    # Sub-task (character) outputs are needed to resolve OOVs
                    best_hyps_sub, aw_sub, _ = model.decode(
                        batch['xs'],
                        batch['x_lens'],
                        beam_width=beam_width,
                        max_decode_len=max_decode_len_sub,
                        length_penalty=length_penalty,
                        task_index=1)

        # Decoders may reorder the batch; align references with hypotheses
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(batch_size):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = idx2word(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = idx2word(best_hyps[b])
            # Truncate at the first <EOS>
            if dataset.label_type == 'word':
                str_hyp = re.sub(r'_>.*', '', str_hyp)
            else:
                str_hyp = re.sub(r'>.*', '', str_hyp)

            ##############################
            # Resolving UNK
            ##############################
            if resolving_unk and 'OOV' in str_hyp:
                str_hyp = resolve_unk(str_hyp, best_hyps_sub[b], aw[b],
                                      aw_sub[b], idx2char)
                str_hyp = str_hyp.replace('*', '')

            ##############################
            # Post-processing
            ##############################
            # Remove garbage labels
            # NOTE: @ means noise
            str_ref = re.sub(r'[@>]+', '', str_ref)
            str_hyp = re.sub(r'[@>]+', '', str_hyp)

            # Remove consecutive spaces
            str_ref = re.sub(r'[_]+', '_', str_ref)
            str_hyp = re.sub(r'[_]+', '_', str_hyp)

            # Compute WER
            try:
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=str_hyp.split('_'),
                    normalize=False)
                wer += wer_b
                sub += sub_b
                ins += ins_b
                dele += del_b
                num_words += len(str_ref.split('_'))
            except Exception:
                # Best-effort: skip utterances whose WER cannot be computed
                pass

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Normalize the accumulated edit counts by the number of reference words
    wer /= num_words
    sub /= num_words
    ins /= num_words
    dele /= num_words

    df_wer = pd.DataFrame(
        {
            'SUB': [sub * 100],
            'INS': [ins * 100],
            'DEL': [dele * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['WER'])

    return wer, df_wer
def main():
    """Decode the evaluation set and print per-utterance transcripts and scores.

    Reads command-line arguments from the module-level ``parser``, restores
    the model stored under ``args.model_path``, and prints the following for
    every utterance of the ``eval1`` split:

    - the reference;
    - the main-task and sub-task hypotheses;
    - the UNK-resolved hypothesis, when requested;
    - the joint-decoding results, for the hierarchical attention model;
    - the corresponding WER/CER.

    Utterances whose score cannot be computed are reported as skipped.
    """
    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Load dataset
    dataset = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        data_type='eval1',
        # data_type='eval2',
        # data_type='eval3',
        data_size=params['data_size'],
        label_type=params['label_type'],
        label_type_sub=params['label_type_sub'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        sort_utt=False,
        reverse=False,
        tool=params['tool'])

    params['num_classes'] = dataset.num_classes
    params['num_classes_sub'] = dataset.num_classes_sub

    # Load model
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    # Restore the saved parameters
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    # Joint word/character decoding is only run for the hierarchical model
    use_joint = (model.model_type == 'hierarchical_attention' and
                 args.joint_decoding is not None)

    word2char = Word2char(dataset.vocab_file_path, dataset.vocab_file_path_sub)

    for batch, is_new_epoch in dataset:
        # Decode
        if model.model_type == 'nested_attention':
            best_hyps, aw, best_hyps_sub, aw_sub, _, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width,
                beam_width_sub=args.beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_WORD,
                max_decode_len_sub=MAX_DECODE_LEN_CHAR,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty)
        else:
            best_hyps, aw, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD,
                min_decode_len=MIN_DECODE_LEN_WORD,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty)
            best_hyps_sub, aw_sub, _ = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_CHAR,
                min_decode_len=MIN_DECODE_LEN_CHAR,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty,
                task_index=1)

        if use_joint:
            best_hyps_joint, aw_joint, best_hyps_sub_joint, aw_sub_joint, _ = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD,
                min_decode_len=MIN_DECODE_LEN_WORD,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty,
                joint_decoding=args.joint_decoding,
                space_index=dataset.char2idx('_')[0],
                oov_index=dataset.word2idx('OOV')[0],
                word2char=word2char,
                idx2word=dataset.idx2word,
                idx2char=dataset.idx2char,
                score_sub_weight=args.score_sub_weight)

        # Decoders may reorder the batch; align references with hypotheses
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]
        ys_sub = batch['ys_sub'][perm_idx]
        y_lens_sub = batch['y_lens_sub'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                str_ref_sub = ys_sub[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = dataset.idx2word(ys[b][:y_lens[b]])
                str_ref_sub = dataset.idx2char(ys_sub[b][:y_lens_sub[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = dataset.idx2word(best_hyps[b])
            str_hyp_sub = dataset.idx2char(best_hyps_sub[b])
            if use_joint:
                str_hyp_joint = dataset.idx2word(best_hyps_joint[b])
                str_hyp_sub_joint = dataset.idx2char(best_hyps_sub_joint[b])

            ##############################
            # Resolving UNK
            ##############################
            has_oov = 'OOV' in str_hyp and args.resolving_unk
            if has_oov:
                str_hyp_no_unk = resolve_unk(str_hyp, best_hyps_sub[b], aw[b],
                                             aw_sub[b], dataset.idx2char)
            # Short-circuit on `use_joint` so `str_hyp_joint` exists when read
            has_oov_joint = (use_joint and 'OOV' in str_hyp_joint and
                             args.resolving_unk)
            if has_oov_joint:
                str_hyp_no_unk_joint = resolve_unk(str_hyp_joint,
                                                   best_hyps_sub_joint[b],
                                                   aw_joint[b],
                                                   aw_sub_joint[b],
                                                   dataset.idx2char)

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref         : %s' % str_ref.replace('_', ' '))
            print('Hyp (main)  : %s' % str_hyp.replace('_', ' '))
            print('Hyp (sub)   : %s' % str_hyp_sub.replace('_', ' '))
            if has_oov:
                print('Hyp (no UNK): %s' % str_hyp_no_unk.replace('_', ' '))
            if use_joint:
                print('===== joint decoding =====')
                print('Hyp (main)  : %s' % str_hyp_joint.replace('_', ' '))
                print('Hyp (sub)   : %s' % str_hyp_sub_joint.replace('_', ' '))
                if has_oov_joint:
                    print('Hyp (no UNK): %s' %
                          str_hyp_no_unk_joint.replace('_', ' '))

            try:
                wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                           hyp=re.sub(r'(.*)_>(.*)', r'\1',
                                                      str_hyp).split('_'),
                                           normalize=True)
                print('WER (main)  : %.3f %%' % (wer * 100))
                if dataset.label_type_sub == 'character_wb':
                    # Word-boundary characters: score the sub task by words
                    wer_sub, _, _, _ = compute_wer(ref=str_ref_sub.split('_'),
                                                   hyp=re.sub(
                                                       r'(.*)>(.*)', r'\1',
                                                       str_hyp_sub).split('_'),
                                                   normalize=True)
                    print('WER (sub)   : %.3f %%' % (wer_sub * 100))
                else:
                    # Otherwise score the sub task by characters (spaces removed)
                    cer, _, _, _ = compute_wer(
                        ref=list(str_ref_sub.replace('_', '')),
                        hyp=list(
                            re.sub(r'(.*)>(.*)', r'\1',
                                   str_hyp_sub).replace('_', '')),
                        normalize=True)
                    print('CER (sub)   : %.3f %%' % (cer * 100))
                if has_oov:
                    wer_no_unk, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=re.sub(r'(.*)_>(.*)', r'\1',
                                   str_hyp_no_unk.replace('*', '')).split('_'),
                        normalize=True)
                    print('WER (no UNK): %.3f %%' % (wer_no_unk * 100))

                if use_joint:
                    print('===== joint decoding =====')
                    wer_joint, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=re.sub(r'(.*)_>(.*)', r'\1',
                                   str_hyp_joint).split('_'),
                        normalize=True)
                    print('WER (main)  : %.3f %%' % (wer_joint * 100))
                    if has_oov_joint:
                        wer_no_unk_joint, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=re.sub(r'(.*)_>(.*)', r'\1',
                                       str_hyp_no_unk_joint.replace(
                                           '*', '')).split('_'),
                            normalize=True)
                        print('WER (no UNK): %.3f %%' %
                              (wer_no_unk_joint * 100))

            except Exception:
                # Best-effort: report and continue with the next utterance
                print('--- skipped ---')
            print('\n')

        if is_new_epoch:
            break
def decode(model, dataset, beam_width, beam_width_sub,
           eval_batch_size=None, save_path=None, resolving_unk=False):
    """Visualize label outputs.

    For every utterance, print the reference, the main-task (word) and
    sub-task (character) hypotheses, optionally the UNK-resolved hypothesis,
    and their error rates. References and hypotheses are normalized with
    `fix_trans` using a GLM file read from a hard-coded path before scoring.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam in the main task
        beam_width_sub (int): the size of beam in the sub task
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string, optional): directory to save decoding results.
            If given, everything printed is written to
            `<save_path>/decode.txt`; stdout is restored afterwards.
        resolving_unk (bool, optional): if True, replace 'OOV' words in the
            main hypothesis using the sub-task output (see `resolve_unk`)
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    idx2word = Idx2word(vocab_file_path=dataset.vocab_file_path)
    idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path_sub,
                        capital_divide=dataset.label_type_sub == 'character_capital_divide')

    # Read GLM file
    glm = GLM(
        glm_path='/n/sd8/inaguma/corpus/swbd/data/eval2000/LDC2002T43/reference/en20000405_hub5.glm')

    # Redirect printed output to a file; stdout is restored and the file
    # closed even if decoding raises.
    stdout_orig = sys.stdout
    f_out = None
    if save_path is not None:
        f_out = open(join(save_path, 'decode.txt'), 'w')
        sys.stdout = f_out

    try:
        for batch, is_new_epoch in dataset:

            # Decode
            if model.model_type == 'nested_attention':
                best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    beam_width_sub=beam_width_sub,
                    max_decode_len=MAX_DECODE_LEN_WORD,
                    max_decode_len_sub=MAX_DECODE_LEN_CHAR)
            else:
                best_hyps, aw, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=MAX_DECODE_LEN_WORD)
                best_hyps_sub, aw_sub, _ = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width_sub,
                    max_decode_len=MAX_DECODE_LEN_CHAR,
                    task_index=1)

            # Decoders may reorder the batch; align references with hypotheses
            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]
            ys_sub = batch['ys_sub'][perm_idx]
            y_lens_sub = batch['y_lens_sub'][perm_idx]

            for b in range(len(batch['xs'])):

                ##############################
                # Reference
                ##############################
                if dataset.is_test:
                    str_ref_original = ys[b][0]
                    str_ref_sub = ys_sub[b][0]
                    # NOTE: transcript is separated by space('_')
                else:
                    # Convert from list of index to string
                    str_ref_original = idx2word(ys[b][: y_lens[b]])
                    # Sub-task labels are characters: use the char converter
                    str_ref_sub = idx2char(ys_sub[b][:y_lens_sub[b]])

                ##############################
                # Hypothesis
                ##############################
                # Convert from list of index to string
                str_hyp = idx2word(best_hyps[b])
                str_hyp_sub = idx2char(best_hyps_sub[b])

                ##############################
                # Resolving UNK
                ##############################
                # Reset per utterance so no value leaks from a previous one
                str_hyp_no_unk = None
                if 'OOV' in str_hyp and resolving_unk:
                    str_hyp_no_unk = resolve_unk(
                        str_hyp, best_hyps_sub[b], aw[b], aw_sub[b], idx2char)

                ##############################
                # Post-processing
                ##############################
                str_ref = fix_trans(str_ref_original, glm)
                str_ref_sub = fix_trans(str_ref_sub, glm)
                str_hyp = fix_trans(str_hyp, glm)
                str_hyp_sub = fix_trans(str_hyp_sub, glm)
                if str_hyp_no_unk is not None:
                    str_hyp_no_unk = fix_trans(str_hyp_no_unk, glm)

                if len(str_ref) == 0:
                    continue

                # Report the UNK-resolved result only when one was produced
                show_no_unk = str_hyp_no_unk is not None and 'OOV' in str_hyp

                print('----- wav: %s -----' % batch['input_names'][b])
                print('Ref         : %s' % str_ref.replace('_', ' '))
                print('Hyp (main)  : %s' % str_hyp.replace('_', ' '))
                print('Hyp (sub)   : %s' % str_hyp_sub.replace('_', ' '))
                if show_no_unk:
                    print('Hyp (no UNK): %s' % str_hyp_no_unk.replace('_', ' '))

                try:
                    # Compute WER (hypotheses are cut at the first <EOS>)
                    wer, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=re.sub(r'_>.*', '', str_hyp).split('_'),
                        normalize=True)
                    print('WER (main)  : %.3f %%' % (wer * 100))
                    wer_sub, _, _, _ = compute_wer(
                        ref=str_ref_sub.split('_'),
                        hyp=re.sub(r'>.*', '', str_hyp_sub).split('_'),
                        normalize=True)
                    print('WER (sub)   : %.3f %%' % (wer_sub * 100))
                    if show_no_unk:
                        wer_no_unk, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=re.sub(r'_>.*', '', str_hyp_no_unk.replace(
                                '*', '')).split('_'),
                            normalize=True)
                        print('WER (no UNK): %.3f %%' % (wer_no_unk * 100))
                except Exception:
                    # Best-effort: report and continue with the next utterance
                    print('--- skipped ---')

            if is_new_epoch:
                break
    finally:
        if f_out is not None:
            sys.stdout = stdout_orig
            f_out.close()
# Example #4
# 0
def eval_word(models, dataset, eval_batch_size,
              beam_width, max_decode_len, min_decode_len=0,
              beam_width_sub=1, max_decode_len_sub=200, min_decode_len_sub=0,
              length_penalty=0, coverage_penalty=0,
              progressbar=False, resolving_unk=False, a2c_oracle=False,
              joint_decoding=None, score_sub_weight=0):
    """Evaluate trained model by Word Error Rate.
    Args:
        models (list): the models to evaluate (only the first one is used)
        dataset: An instance of a `Dataset' class
        eval_batch_size (int): the batch size when evaluating the model
        beam_width (int): the size of beam in the main task
        max_decode_len (int): the maximum sequence length of tokens in the main task
        min_decode_len (int): the minimum sequence length of tokens in the main task
        beam_width_sub (int): the size of beam in the sub task.
            This is used for the nested attention
        max_decode_len_sub (int): the maximum sequence length of tokens in the
            sub task. For the nested attention with `a2c_oracle`, the longest
            oracle label length is used instead.
        min_decode_len_sub (int): the minimum sequence length of tokens in the sub task
        length_penalty (float): length penalty in beam search decoding
        coverage_penalty (float): coverage penalty in beam search decoding
        progressbar (bool): if True, visualize the progressbar
        resolving_unk (bool): if True, replace 'OOV' words in the hypothesis
            using the sub-task output (see `resolve_unk`)
        a2c_oracle (bool): if True, the nested attention model is decoded
            with teacher forcing on the reference sub-task labels
        joint_decoding (str or None): onepass or rescoring, or None to disable
            joint decoding (only used by the hierarchical attention model)
        score_sub_weight (float): weight of the sub-task score, passed to
            `model.decode` for joint decoding
    Returns:
        wer (float): Word error rate
        df_word (pd.DataFrame): dataframe of substitution, insertion, and deletion
    """
    # Reset data counter
    dataset.reset()

    model = models[0]
    # TODO: fix this

    use_joint = (model.model_type == 'hierarchical_attention' and
                 joint_decoding is not None)
    if use_joint:
        word2char = Word2char(dataset.vocab_file_path,
                              dataset.vocab_file_path_sub)

    wer = 0
    sub, ins, dele = 0, 0, 0
    num_words = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        batch_size = len(batch['xs'])

        # Decode
        if model.model_type == 'nested_attention':
            max_label_num = None
            if a2c_oracle and dataset.is_test:
                # Test references are raw transcripts (separated by '_'):
                # convert them to character indices padded with -1.
                # The width is taken from the converted indices so every row
                # is guaranteed to fit.
                indices_list = [dataset.char2idx(batch['ys_sub'][b][0])
                                for b in range(batch_size)]
                max_label_num = max(
                    [len(indices) for indices in indices_list] + [0])
                ys_sub = np.full(
                    (batch_size, max_label_num), -1, dtype=np.int32)
                y_lens_sub = np.zeros((batch_size,), dtype=np.int32)
                for b, indices in enumerate(indices_list):
                    ys_sub[b, :len(indices)] = indices
                    y_lens_sub[b] = len(indices)
            else:
                ys_sub = batch['ys_sub']
                y_lens_sub = batch['y_lens_sub']
                if a2c_oracle:
                    # Labels are already indices; stop at the longest one
                    max_label_num = (int(np.max(y_lens_sub))
                                     if len(y_lens_sub) > 0 else 0)

            best_hyps, aw, best_hyps_sub, aw_sub, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                min_decode_len=min_decode_len,
                beam_width_sub=beam_width_sub,
                max_decode_len_sub=max_label_num if a2c_oracle else max_decode_len_sub,
                min_decode_len_sub=min_decode_len_sub,
                length_penalty=length_penalty,
                coverage_penalty=coverage_penalty,
                teacher_forcing=a2c_oracle,
                ys_sub=ys_sub,
                y_lens_sub=y_lens_sub)
        elif use_joint:
            best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                min_decode_len=min_decode_len,
                length_penalty=length_penalty,
                coverage_penalty=coverage_penalty,
                joint_decoding=joint_decoding,
                space_index=dataset.char2idx('_')[0],
                oov_index=dataset.word2idx('OOV')[0],
                word2char=word2char,
                idx2word=dataset.idx2word,
                idx2char=dataset.idx2char,
                score_sub_weight=score_sub_weight)
        else:
            best_hyps, aw, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                min_decode_len=min_decode_len,
                length_penalty=length_penalty,
                coverage_penalty=coverage_penalty)
            if resolving_unk:
                # Sub-task (character) outputs are needed to resolve OOVs
                best_hyps_sub, aw_sub, _ = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len_sub,
                    min_decode_len=min_decode_len_sub,
                    length_penalty=length_penalty,
                    coverage_penalty=coverage_penalty,
                    task_index=1)

        # Decoders may reorder the batch; align references with hypotheses
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(batch_size):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = dataset.idx2word(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = dataset.idx2word(best_hyps[b])
            # Truncate at the first <EOS>
            if dataset.label_type == 'word':
                str_hyp = re.sub(r'_>.*', '', str_hyp)
            else:
                str_hyp = re.sub(r'>.*', '', str_hyp)

            ##############################
            # Resolving UNK
            ##############################
            if resolving_unk and 'OOV' in str_hyp:
                str_hyp = resolve_unk(
                    str_hyp, best_hyps_sub[b], aw[b], aw_sub[b], dataset.idx2char)
                str_hyp = str_hyp.replace('*', '')

            ##############################
            # Post-processing
            ##############################
            # Remove garbage labels
            # NOTE: @ means noise
            str_ref = re.sub(r'[@>]+', '', str_ref)
            str_hyp = re.sub(r'[@>]+', '', str_hyp)

            # Remove consecutive spaces
            str_ref = re.sub(r'[_]+', '_', str_ref)
            str_hyp = re.sub(r'[_]+', '_', str_hyp)

            # Compute WER
            try:
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=str_hyp.split('_'),
                    normalize=False)
                wer += wer_b
                sub += sub_b
                ins += ins_b
                dele += del_b
                num_words += len(str_ref.split('_'))
            except Exception:
                # Best-effort: skip utterances whose WER cannot be computed
                pass

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Normalize the accumulated edit counts by the number of reference words
    wer /= num_words
    sub /= num_words
    ins /= num_words
    dele /= num_words

    df_word = pd.DataFrame(
        {'SUB': [sub * 100], 'INS': [ins * 100], 'DEL': [dele * 100]},
        columns=['SUB', 'INS', 'DEL'], index=['WER'])

    return wer, df_word