Example #1
0
def eval_word(models, dataset, decode_params, epoch, progressbar=False):
    """Evaluate the word-level model by WER.

    Every batch drawn from `dataset` is decoded with `models[0]`. References
    and hypotheses are written to `ref.trn` / `hyp.trn` inside a directory
    whose name encodes the decoding hyperparameters, and the per-utterance
    values returned by `compute_wer` are accumulated and finally divided by
    the total number of reference words.

    Args:
        models (list): the models to evaluate (only `models[0]` is used)
        dataset: An instance of a `Dataset' class
        decode_params (dict): decoding hyperparameters. The keys read here
            are `beam_width`, `length_penalty`, `coverage_penalty`,
            `min_len_ratio`, `max_len_ratio`, `rnnlm_weight`, `batch_size`
            and `resolving_unk`.
        epoch (int): epoch number embedded in the decode directory name
        progressbar (bool): if True, visualize the progressbar
    Returns:
        wer (float): Word error rate
        num_sub: substitution errors divided by the number of reference words
        num_ins: insertion errors divided by the number of reference words
        num_del: deletion errors divided by the number of reference words
        decode_dir (str): directory (under `model.save_path`) holding the
            trn files
    Raises:
        ZeroDivisionError: if no reference word was seen during evaluation

    """
    # Reset data counter
    dataset.reset()

    model = models[0]
    # TODO(hirofumi): ensemble decoding

    decode_dir = 'decode_' + dataset.set + '_ep' + str(epoch) + '_beam' + str(
        decode_params['beam_width'])
    decode_dir += '_lp' + str(decode_params['length_penalty'])
    decode_dir += '_cp' + str(decode_params['coverage_penalty'])
    decode_dir += '_' + str(decode_params['min_len_ratio']) + '_' + str(
        decode_params['max_len_ratio'])
    decode_dir += '_rnnlm' + str(decode_params['rnnlm_weight'])

    ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
    hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')

    wer = 0
    num_sub, num_ins, num_del = 0, 0, 0
    num_words = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO(hirofumi): fix this

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path,
                                                     'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, aw, perm_idx = model.decode(batch['xs'],
                                                   decode_params,
                                                   exclude_eos=True)
            # NOTE(review): the references are reordered by `perm_idx`, but
            # `batch['xs']` and `batch['utt_ids']` are indexed in batch order
            # below -- confirm that `decode` returns hypotheses in an order
            # for which these agree.
            ys = [batch['ys'][i] for i in perm_idx]

            for b in range(len(batch['xs'])):
                # Reference
                if dataset.is_test:
                    text_ref = ys[b]
                else:
                    text_ref = dataset.idx2word(ys[b])

                # Hypothesis
                text_hyp = dataset.idx2word(best_hyps[b])

                # Resolving UNK
                if decode_params['resolving_unk'] and '<unk>' in text_hyp:
                    # Re-decode this single utterance. The arguments mirror
                    # the batch-level call above: inputs first, then the
                    # decoding parameters.
                    best_hyps_sub, aw_sub, _ = model.decode(
                        batch['xs'][b:b + 1],
                        decode_params,
                        exclude_eos=True)
                    # task_index=1

                    factor_all = 2**sum(model.subsample_list)
                    factor_sub = 2**sum(
                        model.subsample_list[:model.encoder_num_layers_sub -
                                             1])
                    text_hyp = resolve_unk(
                        text_hyp,
                        best_hyps_sub[0],
                        aw[b],
                        aw_sub[0],
                        dataset.idx2char,
                        diff_time_resolution=factor_all // factor_sub)
                    text_hyp = text_hyp.replace('*', '')

                # Write to trn: the utterance id is split into
                # <speaker fields...>_<start>_<end> after mapping '-' to '_'
                id_fields = batch['utt_ids'][b].replace('-', '_').split('_')
                speaker = '_'.join(id_fields[:-2])
                start = id_fields[-2]
                end = id_fields[-1]
                trn_suffix = ' (' + speaker + '-' + start + '-' + end + ')\n'
                f_ref.write(text_ref + trn_suffix)
                f_hyp.write(text_hyp + trn_suffix)

                # Compute WER
                ref_words = text_ref.split(' ')
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=ref_words,
                    hyp=text_hyp.split(' '),
                    normalize=False)
                wer += wer_b
                num_sub += sub_b
                num_ins += ins_b
                num_del += del_b
                num_words += len(ref_words)

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Normalize the accumulated statistics by the reference length
    wer /= num_words
    num_sub /= num_words
    num_ins /= num_words
    num_del /= num_words

    return wer, num_sub, num_ins, num_del, os.path.join(
        model.save_path, decode_dir)
Example #2
0
def eval_word(models,
              dataloader,
              recog_params,
              epoch,
              recog_dir=None,
              streaming=False,
              progressbar=False,
              edit_distance=True,
              fine_grained=False,
              oracle=False,
              teacher_force=False):
    """Evaluate a word-level model by WER.

    Each batch is decoded with `models[0]` (the remaining models are passed
    as `ensemble_models` in the non-streaming path). References and 1-best
    hypotheses are written to `ref.trn` / `hyp.trn`, and the error statistics
    returned by `compute_wer` are accumulated and divided by the number of
    reference words.

    Args:
        models (List): models to evaluate
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        recog_params (omegaconf.dictconfig.DictConfig): decoding hyperparameters
        epoch (int): current epoch
        recog_dir (str): directory path to save hypotheses
        streaming (bool): streaming decoding for session-level evaluation
        progressbar (bool): visualize progressbar
        edit_distance (bool): calculate edit-distance (can be skipped for RTF calculation)
        fine_grained (bool): calculate fine-grained WER distributions based on input lengths
        oracle (bool): calculate oracle WER
        teacher_force (bool): conduct decoding in teacher-forcing mode
            (not referenced inside this function)
    Returns:
        wer (float): Word error rate (left at 0 when `edit_distance` is False
            or `streaming` is True)
        cer (float): Character error rate, accumulated only over utterances
            that went through OOV resolution
        n_oov_total (int): total number of OOV
    Raises:
        ZeroDivisionError: if edit distance is computed but no reference
            word was seen

    """
    if recog_dir is None:
        recog_dir = 'decode_' + dataloader.set + '_ep' + \
            str(epoch) + '_beam' + str(recog_params.get('recog_beam_width'))
        recog_dir += '_lp' + str(recog_params.get('recog_length_penalty'))
        recog_dir += '_cp' + str(recog_params.get('recog_coverage_penalty'))
        recog_dir += '_' + str(recog_params.get('recog_min_len_ratio')) + '_' + \
            str(recog_params.get('recog_max_len_ratio'))
        recog_dir += '_lm' + str(recog_params.get('recog_lm_weight'))

        ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    wer_dist = {}  # calculate WER distribution based on input lengths
    n_oov_total = 0

    wer_oracle = 0
    n_oracle_hit = 0
    n_utt = 0

    # Reset data counter
    dataloader.reset(recog_params.get('recog_batch_size'))

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
            codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref:
        for batch in dataloader:
            speakers = batch['sessions' if dataloader.corpus ==
                             'swbd' else 'speakers']
            if streaming or recog_params.get('recog_block_sync'):
                nbest_hyps_id = models[0].decode_streaming(
                    batch['xs'],
                    recog_params,
                    dataloader.idx2token[0],
                    exclude_eos=True,
                    speaker=speakers[0])[0]
            else:
                nbest_hyps_id, aws = models[0].decode(
                    batch['xs'],
                    recog_params,
                    idx2token=dataloader.idx2token[0],
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=speakers,
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                nbest_hyps = [
                    dataloader.idx2token[0](hyp_id)
                    for hyp_id in nbest_hyps_id[b]
                ]
                n_oov_total += nbest_hyps[0].count('<unk>')

                # Resolving UNK
                if recog_params.get(
                        'recog_resolving_unk') and '<unk>' in nbest_hyps[0]:
                    recog_params_char = copy.deepcopy(recog_params)
                    recog_params_char['recog_lm_weight'] = 0
                    recog_params_char['recog_beam_width'] = 1
                    # Only utterance `b` is re-decoded, so the per-utterance
                    # metadata is sliced to the same single element as `xs`.
                    best_hyps_id_char, aw_char = models[0].decode(
                        batch['xs'][b:b + 1],
                        recog_params_char,
                        idx2token=dataloader.idx2token[1],
                        exclude_eos=True,
                        refs_id=batch['ys_sub1'][b:b + 1],
                        utt_ids=batch['utt_ids'][b:b + 1],
                        speakers=speakers[b:b + 1],
                        task='ys_sub1')
                    # TODO(hirofumi): support ys_sub2

                    # `aws` is only produced by the non-streaming decode
                    assert not streaming

                    nbest_hyps[0] = resolve_unk(
                        nbest_hyps[0],
                        best_hyps_id_char[0],
                        aws[b],
                        aw_char[0],
                        dataloader.idx2token[1],
                        subsample_factor_word=np.prod(models[0].subsample),
                        subsample_factor_char=np.prod(
                            models[0].subsample[:models[0].enc_n_layers_sub1 -
                                                1]))
                    logger.debug('Hyp (after OOV resolution): %s',
                                 nbest_hyps[0])
                    nbest_hyps[0] = nbest_hyps[0].replace('*', '')

                    # Compute CER
                    ref_char = ref
                    hyp_char = nbest_hyps[0]
                    if dataloader.corpus == 'csj':
                        ref_char = ref_char.replace(' ', '')
                        hyp_char = hyp_char.replace(' ', '')
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=list(ref_char), hyp=list(hyp_char))
                    cer += err_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref_char)

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id +
                            ')\n')
                # `n_utt` is advanced once per batch, so add the in-batch
                # offset to get the running utterance index.
                logger.debug('utt-id (%d/%d): %s', n_utt + b + 1,
                             len(dataloader), utt_id)
                logger.debug('Ref: %s', ref)
                logger.debug('Hyp: %s', nbest_hyps[0])
                logger.debug('-' * 150)

                if edit_distance and not streaming:
                    # Compute WER
                    ref_words = ref.split(' ')
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref_words, hyp=nbest_hyps[0].split(' '))
                    wer += err_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref_words)

                    # Compute oracle WER
                    if oracle and len(nbest_hyps) > 1:
                        wers_b = [err_b] + [
                            compute_wer(ref=ref_words,
                                        hyp=hyp_n.split(' '))[0]
                            for hyp_n in nbest_hyps[1:]
                        ]
                        oracle_idx = np.argmin(np.array(wers_b))
                        if oracle_idx == 0:
                            # One hit per utterance whose 1-best is also the
                            # best of the n-best list.
                            n_oracle_hit += 1
                        wer_oracle += wers_b[oracle_idx]
                        # NOTE: OOV resolution is not considered

                    if fine_grained:
                        # Bucket by input length in steps of 200
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        wer_dist.setdefault(xlen_bin, []).append(err_b / 100)

            n_utt += len(batch['utt_ids'])
            if progressbar:
                pbar.update(len(batch['utt_ids']))

    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset(is_new_epoch=True)

    if edit_distance and not streaming:
        wer /= n_word
        n_sub_w /= n_word
        n_ins_w /= n_word
        n_del_w /= n_word

        if n_char > 0:
            cer /= n_char
            n_sub_c /= n_char
            n_ins_c /= n_char
            n_del_c /= n_char

        if recog_params.get('recog_beam_width') > 1:
            logger.info('WER (%s): %.2f %%' % (dataloader.set, wer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_w, n_ins_w, n_del_w))
            logger.info('CER (%s): %.2f %%' % (dataloader.set, cer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_c, n_ins_c, n_del_c))
            logger.info('OOV (total): %d' % (n_oov_total))

        if oracle:
            wer_oracle /= n_word
            oracle_hit_rate = n_oracle_hit * 100 / n_utt
            logger.info('Oracle WER (%s): %.2f %%' %
                        (dataloader.set, wer_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' %
                        (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, wers in sorted(wer_dist.items(), key=lambda x: x[0]):
                logger.info('  WER (%s): %.2f %% (%d)' %
                            (dataloader.set, sum(wers) / len(wers), len_bin))

    return wer, cer, n_oov_total
Example #3
0
def eval_word(models,
              dataset,
              recog_params,
              epoch,
              recog_dir=None,
              streaming=False,
              progressbar=False):
    """Evaluate the word-level model by WER.

    Each batch is decoded with `models[0]` (the remaining models are passed
    as `ensemble_models` in the non-streaming path). References and
    hypotheses are written to `ref.trn` / `hyp.trn`, and the statistics
    returned by `compute_wer` are accumulated and divided by the number of
    reference words.

    Args:
        models (list): models to evaluate
        dataset (Dataset): evaluation dataset
        recog_params (dict): decoding hyperparameters (`recog_*` keys)
        epoch (int): epoch number embedded in the default directory name
        recog_dir (str): directory to save the trn files in; if None, a
            directory under `models[0].save_path` is derived from
            `recog_params`
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
    Returns:
        wer (float): Word error rate (left at 0 when `streaming` is True)
        cer (float): Character error rate, accumulated only over utterances
            that went through OOV resolution
        n_oov_total (int): total number of OOV
    Raises:
        ZeroDivisionError: if `streaming` is False and no reference word
            was seen

    """
    # Reset data counter
    dataset.reset(recog_params['recog_batch_size'])

    if recog_dir is None:
        recog_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(
            recog_params['recog_min_len_ratio']) + '_' + str(
                recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    n_oov_total = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path,
                                                     'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(
                recog_params['recog_batch_size'])
            if streaming or recog_params['recog_chunk_sync']:
                best_hyps_id, _ = models[0].decode_streaming(
                    batch['xs'],
                    recog_params,
                    dataset.idx2token[0],
                    exclude_eos=True)
            else:
                best_hyps_id, aws = models[0].decode(
                    batch['xs'],
                    recog_params,
                    idx2token=dataset.idx2token[0] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataset.corpus ==
                                   'swbd' else 'speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                hyp = dataset.idx2token[0](best_hyps_id[b])

                n_oov_total += hyp.count('<unk>')

                # Resolving UNK
                if recog_params['recog_resolving_unk'] and '<unk>' in hyp:
                    recog_params_char = copy.deepcopy(recog_params)
                    recog_params_char['recog_lm_weight'] = 0
                    recog_params_char['recog_beam_width'] = 1
                    speakers_all = (batch['sessions']
                                    if dataset.corpus == 'swbd' else
                                    batch['speakers'])
                    # Only utterance `b` is re-decoded, so the per-utterance
                    # metadata is sliced to the same single element as `xs`.
                    best_hyps_id_char, aw_char = models[0].decode(
                        batch['xs'][b:b + 1],
                        recog_params_char,
                        idx2token=dataset.idx2token[1]
                        if progressbar else None,
                        exclude_eos=True,
                        refs_id=batch['ys_sub1'][b:b + 1],
                        utt_ids=batch['utt_ids'][b:b + 1],
                        speakers=speakers_all[b:b + 1],
                        task='ys_sub1')
                    # TODO(hirofumi): support ys_sub2 and ys_sub3

                    # `aws` is only produced by the non-streaming decode
                    assert not streaming

                    hyp = resolve_unk(
                        hyp,
                        best_hyps_id_char[0],
                        aws[b],
                        aw_char[0],
                        dataset.idx2token[1],
                        subsample_factor_word=np.prod(models[0].subsample),
                        subsample_factor_char=np.prod(
                            models[0].subsample[:models[0].enc_n_layers_sub1 -
                                                1]))
                    logger.debug('Hyp (after OOV resolution): %s', hyp)
                    hyp = hyp.replace('*', '')

                    # Compute CER
                    ref_char = ref
                    hyp_char = hyp
                    if dataset.corpus == 'csj':
                        ref_char = ref.replace(' ', '')
                        hyp_char = hyp.replace(' ', '')
                    cer_b, sub_b, ins_b, del_b = compute_wer(
                        ref=list(ref_char),
                        hyp=list(hyp_char),
                        normalize=False)
                    cer += cer_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref_char)

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id: %s', utt_id)
                logger.debug('Ref: %s', ref)
                logger.debug('Hyp: %s', hyp)
                logger.debug('-' * 150)

                if not streaming:
                    # Compute WER
                    ref_words = ref.split(' ')
                    wer_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref_words,
                        hyp=hyp.split(' '),
                        normalize=False)
                    wer += wer_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref_words)

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    if not streaming:
        wer /= n_word
        n_sub_w /= n_word
        n_ins_w /= n_word
        n_del_w /= n_word

        if n_char > 0:
            cer /= n_char
            n_sub_c /= n_char
            n_ins_c /= n_char
            n_del_c /= n_char

    logger.debug('WER (%s): %.2f %%', dataset.set, wer)
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f', n_sub_w, n_ins_w,
                 n_del_w)
    logger.debug('CER (%s): %.2f %%', dataset.set, cer)
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f', n_sub_c, n_ins_c,
                 n_del_c)
    logger.debug('OOV (total): %d', n_oov_total)

    return wer, cer, n_oov_total
Example #4
0
def eval_word(models,
              dataset,
              decode_params,
              epoch,
              decode_dir=None,
              progressbar=False):
    """Evaluate the word-level model by WER.

    Decodes every batch of `dataset` with `models[0]`, writes the reference
    and hypothesis transcripts to `ref.trn` / `hyp.trn`, and accumulates the
    values returned by `compute_wer`, which are finally divided by the
    total number of reference words.

    Args:
        models (list): the models to evaluate (only `models[0]` is used)
        dataset: An instance of a `Dataset' class
        decode_params (dict): decoding hyperparameters
        epoch (int): epoch number embedded in the default directory name
        decode_dir (str): directory for the trn files; if None, one is
            derived from `decode_params` under `model.save_path`
        progressbar (bool): if True, visualize the progressbar
    Returns:
        wer (float): Word error rate
        nsub: substitution errors divided by the number of reference words
        nins: insertion errors divided by the number of reference words
        ndel: deletion errors divided by the number of reference words
        noov_total (int): number of '<unk>' occurrences in the hypotheses
            before UNK resolution

    """
    # Reset data counter
    dataset.reset()

    # TODO(hirofumi): ensemble decoding
    model = models[0]

    if decode_dir is not None:
        ref_trn_save_path = mkdir_join(decode_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(decode_dir, 'hyp.trn')
    else:
        # Encode the decoding configuration in the directory name
        decode_dir = ''.join([
            'decode_' + dataset.set,
            '_ep' + str(epoch),
            '_beam' + str(decode_params['beam_width']),
            '_lp' + str(decode_params['length_penalty']),
            '_cp' + str(decode_params['coverage_penalty']),
            '_' + str(decode_params['min_len_ratio']),
            '_' + str(decode_params['max_len_ratio']),
            '_rnnlm' + str(decode_params['rnnlm_weight']),
        ])
        ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')

    err_sum = 0
    sub_sum, ins_sum, del_sum = 0, 0, 0
    ref_word_count = 0
    unk_count = 0
    # TODO(hirofumi): fix this
    pbar = tqdm(total=len(dataset)) if progressbar else None

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path,
                                                     'w') as f_ref:
        is_new_epoch = False
        while not is_new_epoch:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, aws, perm_ids = model.decode(batch['xs'],
                                                    decode_params,
                                                    exclude_eos=True)
            # NOTE(review): references follow `perm_ids` while the utterance
            # ids below are taken in batch order -- confirm these agree.
            refs = [batch['text'][i] for i in perm_ids]

            for b in six.moves.range(len(batch['xs'])):
                ref = refs[b]
                hyp = dataset.id2word(best_hyps[b])

                unk_count += hyp.count('<unk>')

                # Resolving UNK
                if decode_params['resolving_unk'] and '<unk>' in hyp:
                    single_xs = batch['xs'][b:b + 1]
                    best_hyps_sub, aw_sub, _ = model.decode(single_xs,
                                                            decode_params,
                                                            exclude_eos=True)
                    # task_index=1

                    factor_all = 2**sum(model.subsample)
                    factor_sub = 2**sum(
                        model.subsample[:model.enc_nlayers_sub - 1])
                    hyp = resolve_unk(
                        hyp,
                        best_hyps_sub[0],
                        aws[b],
                        aw_sub[0],
                        dataset.id2char,
                        diff_time_resolution=factor_all // factor_sub)
                    hyp = hyp.replace('*', '')

                # Write to trn: the id is <speaker fields...>_<start>_<end>
                # once '-' has been mapped to '_'
                utt_id = batch['utt_ids'][b]
                id_fields = utt_id.replace('-', '_').split('_')
                speaker = '_'.join(id_fields[:-2])
                start, end = id_fields[-2], id_fields[-1]
                trn_suffix = ' (' + speaker + '-' + start + '-' + end + ')\n'
                f_ref.write(ref + trn_suffix)
                f_hyp.write(hyp + trn_suffix)
                logger.info('utt-id: %s' % utt_id)
                logger.info('Ref: %s' % ref)
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

                # Compute WER
                ref_words = ref.split(' ')
                wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref_words,
                                                         hyp=hyp.split(' '),
                                                         normalize=False)
                err_sum += wer_b
                sub_sum += sub_b
                ins_sum += ins_b
                del_sum += del_b
                ref_word_count += len(ref_words)

                if pbar is not None:
                    pbar.update(1)

    if pbar is not None:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Normalize the accumulated statistics by the reference length
    err_sum /= ref_word_count
    sub_sum /= ref_word_count
    ins_sum /= ref_word_count
    del_sum /= ref_word_count

    return err_sum, sub_sum, ins_sum, del_sum, unk_count