Code Example #1
File: word.py  Project: nipengmath/neural_sp
def eval_word(models, dataset, decode_params, epoch, progressbar=False):
    """Evaluate the word-level model by WER.

    Args:
        models (list): the models to evaluate
        dataset: An instance of a `Dataset` class
        decode_params (dict):
        epoch (int):
        progressbar (bool): if True, visualize the progressbar
    Returns:
        wer (float): Word error rate
        num_sub (int): the number of substitution errors
        num_ins (int): the number of insertion errors
        num_del (int): the number of deletion errors
        decode_dir (str):

    """
    # Reset data counter
    dataset.reset()

    model = models[0]
    # TODO(hirofumi): ensemble decoding

    decode_dir = 'decode_' + dataset.set + '_ep' + str(epoch) + '_beam' + str(
        decode_params['beam_width'])
    decode_dir += '_lp' + str(decode_params['length_penalty'])
    decode_dir += '_cp' + str(decode_params['coverage_penalty'])
    decode_dir += '_' + str(decode_params['min_len_ratio']) + '_' + str(
        decode_params['max_len_ratio'])
    decode_dir += '_rnnlm' + str(decode_params['rnnlm_weight'])
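    # For illustration (hypothetical values): beam_width=4, length_penalty=0.1,
    # coverage_penalty=0.0, min/max length ratios 0.0/1.0 and rnnlm_weight=0.3
    # on the 'eval1' set at epoch 25 yield
    # 'decode_eval1_ep25_beam4_lp0.1_cp0.0_0.0_1.0_rnnlm0.3'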

    ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
    hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')

    wer = 0
    num_sub, num_ins, num_del = 0, 0, 0
    num_words = 0
    num_oov_total = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO(hirofumi): fix this

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path,
                                                     'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, aw, perm_idx = model.decode(batch['xs'],
                                                   decode_params,
                                                   exclude_eos=True)
            ys = [batch['ys'][i] for i in perm_idx]

            for b in range(len(batch['xs'])):
                # Reference
                if dataset.is_test:
                    text_ref = ys[b]
                else:
                    text_ref = dataset.idx2word(ys[b])

                # Hypothesis
                text_hyp = dataset.idx2word(best_hyps[b])
                num_oov_total += text_hyp.count('<unk>')

                # Resolving UNK
                if decode_params['resolving_unk'] and '<unk>' in text_hyp:
                    # Re-decode with the character-level sub task
                    # (task_index=1) to recover OOV words
                    best_hyps_sub, aw_sub, _ = model.decode(
                        batch['xs'][b:b + 1], decode_params,
                        exclude_eos=True, task_index=1)

                    text_hyp = resolve_unk(
                        text_hyp,
                        best_hyps_sub[0],
                        aw[b],
                        aw_sub[0],
                        dataset.idx2char,
                        diff_time_resolution=2 ** sum(model.subsample_list) //
                        2 ** sum(model.subsample_list[:model.encoder_num_layers_sub - 1]))
                    text_hyp = text_hyp.replace('*', '')

                # Write to trn
                speaker = '_'.join(batch['utt_ids'][b].replace(
                    '-', '_').split('_')[:-2])
                start = batch['utt_ids'][b].replace('-', '_').split('_')[-2]
                end = batch['utt_ids'][b].replace('-', '_').split('_')[-1]
                f_ref.write(text_ref + ' (' + speaker + '-' + start + '-' +
                            end + ')\n')
                f_hyp.write(text_hyp + ' (' + speaker + '-' + start + '-' +
                            end + ')\n')

                # Compute WER
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=text_ref.split(' '),
                    hyp=text_hyp.split(' '),
                    normalize=False)
                wer += wer_b
                num_sub += sub_b
                num_ins += ins_b
                num_del += del_b
                num_words += len(text_ref.split(' '))

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    wer /= num_words
    num_sub /= num_words
    num_ins /= num_words
    num_del /= num_words

    return wer, num_sub, num_ins, num_del, os.path.join(
        model.save_path, decode_dir)
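All of these evaluators accumulate raw substitution/insertion/deletion counts from compute_wer and micro-average them at the end (total edits divided by total reference tokens). compute_wer itself is imported from the project's metric utilities; the following self-contained sketch only illustrates the edit-distance bookkeeping it performs (the name, tuple layout, and return convention here are assumptions for illustration, not neural_sp's actual implementation):

def compute_wer_sketch(ref, hyp):
    """Return (n_err, n_sub, n_ins, n_del) aligning `hyp` against `ref`."""
    n, m = len(ref), len(hyp)
    # dp[i][j] holds (total_errors, n_sub, n_ins, n_del) for ref[:i] vs hyp[:j]
    dp = [[None] * (m + 1) for _ in range(n + 1)]
    dp[0][0] = (0, 0, 0, 0)
    for i in range(1, n + 1):
        dp[i][0] = (i, 0, 0, i)  # delete every remaining reference token
    for j in range(1, m + 1):
        dp[0][j] = (j, 0, j, 0)  # insert every remaining hypothesis token
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            if ref[i - 1] == hyp[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]  # match: no new error
            else:
                s, a, d = dp[i - 1][j - 1], dp[i][j - 1], dp[i - 1][j]
                dp[i][j] = min((s[0] + 1, s[1] + 1, s[2], s[3]),  # substitution
                               (a[0] + 1, a[1], a[2] + 1, a[3]),  # insertion
                               (d[0] + 1, d[1], d[2], d[3] + 1))  # deletion
    return dp[n][m]

# e.g. compute_wer_sketch('a b c'.split(), 'a x c d'.split()) -> (2, 1, 1, 0),
# i.e. one substitution and one insertion; WER = 100 * 2 / 3 reference tokens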
Code Example #2
def eval_char(models, dataloader, recog_params, epoch,
              recog_dir=None, streaming=False, progressbar=False, task_idx=0):
    """Evaluate the character-level model by WER & CER.

    Args:
        models (list): models to evaluate
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        recog_params (dict):
        epoch (int):
        recog_dir (str):
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
        task_idx (int): the index of the target task of interest
            0: main task
            1: sub task
            2: sub sub task
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate

    """
    if recog_dir is None:
        recog_dir = 'decode_' + dataloader.set + '_ep' + str(epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(recog_params['recog_min_len_ratio']) + '_' + str(recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    n_streamable, quantity_rate, n_utt = 0, 0, 0
    last_success_frame_ratio = 0

    # Reset data counter
    dataloader.reset(recog_params['recog_batch_size'])

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    if task_idx == 0:
        task = 'ys'
    elif task_idx == 1:
        task = 'ys_sub1'
    elif task_idx == 2:
        task = 'ys_sub2'
    elif task_idx == 3:
        task = 'ys_sub3'

    with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
            codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref:
        while True:
            batch, is_new_epoch = dataloader.next(recog_params['recog_batch_size'])
            if streaming or recog_params['recog_chunk_sync']:
                best_hyps_id, _ = models[0].decode_streaming(
                    batch['xs'], recog_params, dataloader.idx2token[0],
                    exclude_eos=True)
            else:
                best_hyps_id, _ = models[0].decode(
                    batch['xs'], recog_params,
                    idx2token=dataloader.idx2token[task_idx] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'] if task_idx == 0 else batch['ys_sub' + str(task_idx)],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataloader.corpus == 'swbd' else 'speakers'],
                    task=task,
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                hyp = dataloader.idx2token[task_idx](best_hyps_id[b])

                # Truncate the first and last spaces for the char_space unit
                if len(hyp) > 0 and hyp[0] == ' ':
                    hyp = hyp[1:]
                if len(hyp) > 0 and hyp[-1] == ' ':
                    hyp = hyp[:-1]

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % hyp)
                logger.debug('-' * 150)

                if not streaming:
                    if ('char' in dataloader.unit and 'nowb' not in dataloader.unit) or (task_idx > 0 and dataloader.unit_sub1 == 'char'):
                        # Compute WER
                        wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                                 hyp=hyp.split(' '),
                                                                 normalize=False)
                        wer += wer_b
                        n_sub_w += sub_b
                        n_ins_w += ins_b
                        n_del_w += del_b
                        n_word += len(ref.split(' '))
                        # NOTE: sentence error rate for Chinese

                    # Compute CER
                    if dataloader.corpus == 'csj':
                        ref = ref.replace(' ', '')
                        hyp = hyp.replace(' ', '')
                    cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                             hyp=list(hyp),
                                                             normalize=False)
                    cer += cer_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref)
                    if models[0].streamable():
                        n_streamable += 1
                    else:
                        last_success_frame_ratio += models[0].last_success_frame_ratio()
                    quantity_rate += models[0].quantity_rate()
                    n_utt += 1

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset()

    if not streaming:
        if ('char' in dataloader.unit and 'nowb' not in dataloader.unit) or (task_idx > 0 and dataloader.unit_sub1 == 'char'):
            wer /= n_word
            n_sub_w /= n_word
            n_ins_w /= n_word
            n_del_w /= n_word
        else:
            wer = n_sub_w = n_ins_w = n_del_w = 0

        cer /= n_char
        n_sub_c /= n_char
        n_ins_c /= n_char
        n_del_c /= n_char

        if n_utt - n_streamable > 0:
            last_success_frame_ratio /= (n_utt - n_streamable)
        n_streamable /= n_utt
        quantity_rate /= n_utt

    logger.debug('WER (%s): %.2f %%' % (dataloader.set, wer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w))
    logger.debug('CER (%s): %.2f %%' % (dataloader.set, cer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c))

    logger.info('Streamability (%s): %.2f %%' % (dataloader.set, n_streamable * 100))
    logger.info('Quantity rate (%s): %.2f %%' % (dataloader.set, quantity_rate * 100))
    logger.info('Last success frame ratio (%s): %.2f %%' % (dataloader.set, last_success_frame_ratio))

    return wer, cer
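The ref.trn/hyp.trn files written above follow the transcript (trn) convention of NIST's sclite scorer: one utterance per line with a trailing '(speaker-utterance_id)' tag, so the outputs can be rescored externally. A hypothetical invocation of the evaluator (model and dataloader construction is project-specific and omitted):

wer, cer = eval_char([model], dev_loader, recog_params, epoch=25,
                     progressbar=True)
print('dev WER: %.2f %% / CER: %.2f %%' % (wer, cer))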
Code Example #3
def eval_word(models,
              dataset,
              recog_params,
              epoch,
              recog_dir=None,
              streaming=False,
              progressbar=False):
    """Evaluate the word-level model by WER.

    Args:
        models (list): models to evaluate
        dataset (Dataset): evaluation dataset
        recog_params (dict):
        epoch (int):
        recog_dir (str):
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate
        n_oov_total (int): total number of OOVs

    """
    # Reset data counter
    dataset.reset(recog_params['recog_batch_size'])

    if recog_dir is None:
        recog_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(
            recog_params['recog_min_len_ratio']) + '_' + str(
                recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    n_oov_total = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path,
                                                     'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(
                recog_params['recog_batch_size'])
            if streaming or recog_params['recog_chunk_sync']:
                best_hyps_id, _ = models[0].decode_streaming(
                    batch['xs'],
                    recog_params,
                    dataset.idx2token[0],
                    exclude_eos=True)
            else:
                best_hyps_id, aws = models[0].decode(
                    batch['xs'],
                    recog_params,
                    idx2token=dataset.idx2token[0] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataset.corpus ==
                                   'swbd' else 'speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                hyp = dataset.idx2token[0](best_hyps_id[b])

                n_oov_total += hyp.count('<unk>')

                # Resolving UNK
                if recog_params['recog_resolving_unk'] and '<unk>' in hyp:
                    recog_params_char = copy.deepcopy(recog_params)
                    recog_params_char['recog_lm_weight'] = 0
                    recog_params_char['recog_beam_width'] = 1
                    best_hyps_id_char, aw_char = models[0].decode(
                        batch['xs'][b:b + 1],
                        recog_params_char,
                        idx2token=dataset.idx2token[1]
                        if progressbar else None,
                        exclude_eos=True,
                        refs_id=batch['ys_sub1'],
                        utt_ids=batch['utt_ids'],
                        speakers=batch['sessions']
                        if dataset.corpus == 'swbd' else batch['speakers'],
                        task='ys_sub1')
                    # TODO(hirofumi): support ys_sub2 and ys_sub3

                    assert not streaming

                    hyp = resolve_unk(
                        hyp,
                        best_hyps_id_char[0],
                        aws[b],
                        aw_char[0],
                        dataset.idx2token[1],
                        subsample_factor_word=np.prod(models[0].subsample),
                        subsample_factor_char=np.prod(
                            models[0].subsample[:models[0].enc_n_layers_sub1 -
                                                1]))
                    logger.debug('Hyp (after OOV resolution): %s' % hyp)
                    hyp = hyp.replace('*', '')

                    # Compute CER
                    ref_char = ref
                    hyp_char = hyp
                    if dataset.corpus == 'csj':
                        ref_char = ref.replace(' ', '')
                        hyp_char = hyp.replace(' ', '')
                    cer_b, sub_b, ins_b, del_b = compute_wer(
                        ref=list(ref_char),
                        hyp=list(hyp_char),
                        normalize=False)
                    cer += cer_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref_char)

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % hyp)
                logger.debug('-' * 150)

                if not streaming:
                    # Compute WER
                    wer_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref.split(' '),
                        hyp=hyp.split(' '),
                        normalize=False)
                    wer += wer_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    if not streaming:
        wer /= n_word
        n_sub_w /= n_word
        n_ins_w /= n_word
        n_del_w /= n_word

        if n_char > 0:
            cer /= n_char
            n_sub_c /= n_char
            n_ins_c /= n_char
            n_del_c /= n_char

    logger.debug('WER (%s): %.2f %%' % (dataset.set, wer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                 (n_sub_w, n_ins_w, n_del_w))
    logger.debug('CER (%s): %.2f %%' % (dataset.set, cer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                 (n_sub_c, n_ins_c, n_del_c))
    logger.debug('OOV (total): %d' % (n_oov_total))

    return wer, cer, n_oov_total
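The OOV-resolution step above decodes the same utterance again with the character-level sub task and then uses the two attention weight matrices to splice a character-level hypothesis over each '<unk>'. resolve_unk is part of neural_sp; the standalone sketch below only illustrates the attention-peak alignment idea (the function name, the peak-matching heuristic, and all argument names are assumptions):

import numpy as np

def resolve_unk_sketch(word_hyp, char_tokens, aw_word, aw_char,
                       factor_word=1, factor_char=1):
    """Replace each '<unk>' with the character token whose attention
    peak lands closest on a common (raw-frame) time axis."""
    words = word_hyp.split(' ')
    # aw_* has shape (n_output_tokens, n_encoder_frames); the argmax frame is
    # rescaled by each branch's encoder subsampling factor so that the word
    # and character models become comparable in time
    peaks_word = np.argmax(aw_word, axis=1) * factor_word
    peaks_char = np.argmax(aw_char, axis=1) * factor_char
    for i, w in enumerate(words):
        if w == '<unk>' and i < len(peaks_word):
            j = int(np.argmin(np.abs(peaks_char - peaks_word[i])))
            words[i] = char_tokens[j]
    return ' '.join(words)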
Code Example #4
def eval_wordpiece(models,
                   dataset,
                   recog_params,
                   epoch,
                   recog_dir=None,
                   progressbar=False):
    """Evaluate the wordpiece-level model by WER.

    Args:
        models (list): models to evaluate
        dataset: An instance of a `Dataset` class
        recog_params (dict):
        epoch (int):
        recog_dir (str):
        progressbar (bool): visualize the progressbar
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate

    """
    # Reset data counter
    dataset.reset()

    if recog_dir is None:
        recog_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(
            recog_params['recog_min_len_ratio']) + '_' + str(
                recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path,
                                                     'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(
                recog_params['recog_batch_size'])
            best_hyps_id, _, perm_id, _ = models[0].decode(
                batch['xs'],
                recog_params,
                dataset.idx2token[0],
                exclude_eos=True,
                refs_id=batch['ys'],
                utt_ids=batch['utt_ids'],
                speakers=batch['sessions']
                if dataset.corpus == 'swbd' else batch['speakers'],
                ensemble_models=models[1:] if len(models) > 1 else [])
            ys = [batch['text'][i] for i in perm_id]

            for b in range(len(batch['xs'])):
                ref = ys[b]
                hyp = dataset.idx2token[0](best_hyps_id[b])

                # Write to trn
                utt_id = str(batch['utt_ids'][b])
                speaker = str(batch['speakers'][b]).replace('-', '_')
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref)
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 150)

                # Compute WER
                wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                         hyp=hyp.split(' '),
                                                         normalize=False)
                wer += wer_b
                n_sub_w += sub_b
                n_ins_w += ins_b
                n_del_w += del_b
                n_word += len(ref.split(' '))

                # Compute CER
                if dataset.corpus == 'csj':
                    ref = ref.replace(' ', '')
                    hyp = hyp.replace(' ', '')
                cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                         hyp=list(hyp),
                                                         normalize=False)
                cer += cer_b
                n_sub_c += sub_b
                n_ins_c += ins_b
                n_del_c += del_b
                n_char += len(ref)

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    wer /= n_word
    n_sub_w /= n_word
    n_ins_w /= n_word
    n_del_w /= n_word

    cer /= n_char
    n_sub_c /= n_char
    n_ins_c /= n_char
    n_del_c /= n_char

    logger.info('WER (%s): %.2f %%' % (dataset.set, wer))
    logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                (n_sub_w, n_ins_w, n_del_w))
    logger.info('CER (%s): %.2f %%' % (dataset.set, cer))
    logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                (n_sub_c, n_ins_c, n_del_c))

    return wer, cer
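For CSJ (Japanese), inter-word spaces are an artifact of the tokenization, so both strings are stripped of spaces before the character-level pass; list(...) then feeds the same edit-distance routine characters instead of words. A minimal illustration (the strings are made up):

ref = 'こんにちは 世界'.replace(' ', '')  # -> 'こんにちは世界'
hyp = 'こんにちは せ界'.replace(' ', '')
# CER compares character lists: list(ref) == ['こ', 'ん', 'に', 'ち', 'は', '世', '界']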
Code Example #5
File: phone.py  Project: yuekaizhang/neural_sp
def eval_phone(models,
               dataloader,
               recog_params,
               epoch,
               recog_dir=None,
               streaming=False,
               progressbar=False,
               fine_grained=False,
               oracle=False,
               teacher_force=False):
    """Evaluate a phone-level model by PER.

    Args:
        models (List): models to evaluate
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        recog_params (dict):
        epoch (int):
        recog_dir (str):
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
        oracle (bool): calculate oracle PER
        fine_grained (bool): calculate fine-grained PER distributions based on input lengths
        teacher_force (bool): conduct decoding in teacher-forcing mode
    Returns:
        per (float): Phone error rate

    """
    if recog_dir is None:
        recog_dir = 'decode_' + dataloader.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(
            recog_params['recog_min_len_ratio']) + '_' + str(
                recog_params['recog_max_len_ratio'])

        ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn')

    per = 0
    n_sub, n_ins, n_del = 0, 0, 0
    n_phone = 0
    per_dist = {}  # calculate PER distribution based on input lengths

    per_oracle = 0
    n_oracle_hit = 0
    n_utt = 0

    # Reset data counter
    dataloader.reset(recog_params['recog_batch_size'])

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
            codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref:
        while True:
            batch, is_new_epoch = dataloader.next(
                recog_params['recog_batch_size'])
            if streaming or recog_params['recog_block_sync']:
                nbest_hyps_id = models[0].decode_streaming(
                    batch['xs'],
                    recog_params,
                    dataloader.idx2token[0],
                    exclude_eos=True)[0]
            else:
                nbest_hyps_id = models[0].decode(
                    batch['xs'],
                    recog_params,
                    idx2token=dataloader.idx2token[0] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataloader.corpus ==
                                   'swbd' else 'speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [])[0]

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                nbest_hyps = [
                    dataloader.idx2token[0](hyp_id)
                    for hyp_id in nbest_hyps_id[b]
                ]

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id +
                            ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % nbest_hyps[0])
                logger.debug('-' * 150)

                if not streaming:
                    # Compute PER
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref.split(' '), hyp=nbest_hyps[0].split(' '))
                    per += err_b
                    n_sub += sub_b
                    n_ins += ins_b
                    n_del += del_b
                    n_phone += len(ref.split(' '))

                    if fine_grained:
                        # Bin the per-utterance PER by input length in
                        # 200-frame bins (mirrors the word-level evaluator;
                        # per_dist was otherwise never populated)
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        per_dist.setdefault(xlen_bin, []).append(err_b / 100)

                    # Compute oracle PER
                    if oracle and len(nbest_hyps) > 1:
                        pers_b = [err_b] + [
                            compute_wer(ref=ref.split(' '),
                                        hyp=hyp_n.split(' '))[0]
                            for hyp_n in nbest_hyps[1:]
                        ]
                        oracle_idx = np.argmin(np.array(pers_b))
                        if oracle_idx == 0:
                            n_oracle_hit += 1
                        per_oracle += pers_b[oracle_idx]

                n_utt += 1
                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset()

    if not streaming:
        per /= n_phone
        n_sub /= n_phone
        n_ins /= n_phone
        n_del /= n_phone

        if recog_params['recog_beam_width'] > 1:
            logger.info('PER (%s): %.2f %%' % (dataloader.set, per))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub, n_ins, n_del))

        if oracle:
            per_oracle /= n_phone
            oracle_hit_rate = n_oracle_hit * 100 / n_utt
            logger.info('Oracle PER (%s): %.2f %%' %
                        (dataloader.set, per_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' %
                        (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, pers in sorted(per_dist.items(), key=lambda x: x[0]):
                logger.info('  PER (%s): %.2f %% (%d)' %
                            (dataloader.set, sum(pers) / len(pers), len_bin))

    return per
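Oracle PER picks, per utterance, the n-best hypothesis with the lowest error count, and the hit rate measures how often the 1-best hypothesis (index 0) is already optimal. A toy illustration with made-up scores:

import numpy as np

pers_b = [12.5, 8.0, 15.0]           # PERs of the 3-best hypotheses
oracle_idx = int(np.argmin(pers_b))  # -> 1: the 2nd hypothesis is best
oracle_hit = oracle_idx == 0         # -> False: no hit for this utterance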
Code Example #6
File: word.py  Project: mbencherif/neural_sp
def eval_word(models,
              dataloader,
              recog_params,
              epoch,
              recog_dir=None,
              streaming=False,
              progressbar=False,
              edit_distance=True,
              fine_grained=False,
              oracle=False,
              teacher_force=False):
    """Evaluate a word-level model by WER.

    Args:
        models (List): models to evaluate
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        recog_params (omegaconf.dictconfig.DictConfig): decoding hyperparameters
        epoch (int): current epoch
        recog_dir (str): directory path to save hypotheses
        streaming (bool): streaming decoding for session-level evaluation
        progressbar (bool): visualize progressbar
        edit_distance (bool): calculate edit-distance (can be skipped for RTF calculation)
        fine_grained (bool): calculate fine-grained WER distributions based on input lengths
        oracle (bool): calculate oracle WER
        teacher_force (bool): conduct decoding in teacher-forcing mode
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate
        n_oov_total (int): total number of OOVs

    """
    if recog_dir is None:
        recog_dir = 'decode_' + dataloader.set + '_ep' + \
            str(epoch) + '_beam' + str(recog_params.get('recog_beam_width'))
        recog_dir += '_lp' + str(recog_params.get('recog_length_penalty'))
        recog_dir += '_cp' + str(recog_params.get('recog_coverage_penalty'))
        recog_dir += '_' + str(recog_params.get('recog_min_len_ratio')) + '_' + \
            str(recog_params.get('recog_max_len_ratio'))
        recog_dir += '_lm' + str(recog_params.get('recog_lm_weight'))

        ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    wer_dist = {}  # calculate WER distribution based on input lengths
    n_oov_total = 0

    wer_oracle = 0
    n_oracle_hit = 0
    n_utt = 0

    # Reset data counter
    dataloader.reset(recog_params.get('recog_batch_size'))

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
            codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref:
        for batch in dataloader:
            speakers = batch['sessions' if dataloader.corpus ==
                             'swbd' else 'speakers']
            if streaming or recog_params.get('recog_block_sync'):
                nbest_hyps_id = models[0].decode_streaming(
                    batch['xs'],
                    recog_params,
                    dataloader.idx2token[0],
                    exclude_eos=True,
                    speaker=speakers[0])[0]
            else:
                nbest_hyps_id, aws = models[0].decode(
                    batch['xs'],
                    recog_params,
                    idx2token=dataloader.idx2token[0],
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=speakers,
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                nbest_hyps = [
                    dataloader.idx2token[0](hyp_id)
                    for hyp_id in nbest_hyps_id[b]
                ]
                n_oov_total += nbest_hyps[0].count('<unk>')

                # Resolving UNK
                if recog_params.get(
                        'recog_resolving_unk') and '<unk>' in nbest_hyps[0]:
                    recog_params_char = copy.deepcopy(recog_params)
                    recog_params_char['recog_lm_weight'] = 0
                    recog_params_char['recog_beam_width'] = 1
                    best_hyps_id_char, aw_char = models[0].decode(
                        batch['xs'][b:b + 1],
                        recog_params_char,
                        idx2token=dataloader.idx2token[1],
                        exclude_eos=True,
                        refs_id=batch['ys_sub1'],
                        utt_ids=batch['utt_ids'],
                        speakers=speakers,
                        task='ys_sub1')
                    # TODO(hirofumi): support ys_sub2

                    assert not streaming

                    nbest_hyps[0] = resolve_unk(
                        nbest_hyps[0],
                        best_hyps_id_char[0],
                        aws[b],
                        aw_char[0],
                        dataloader.idx2token[1],
                        subsample_factor_word=np.prod(models[0].subsample),
                        subsample_factor_char=np.prod(
                            models[0].subsample[:models[0].enc_n_layers_sub1 -
                                                1]))
                    logger.debug('Hyp (after OOV resolution): %s' %
                                 nbest_hyps[0])
                    nbest_hyps[0] = nbest_hyps[0].replace('*', '')

                    # Compute CER
                    ref_char = ref
                    hyp_char = nbest_hyps[0]
                    if dataloader.corpus == 'csj':
                        ref_char = ref_char.replace(' ', '')
                        hyp_char = hyp_char.replace(' ', '')
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=list(ref_char), hyp=list(hyp_char))
                    cer += err_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref_char)

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id +
                            ')\n')
                logger.debug('utt-id (%d/%d): %s' %
                             (n_utt + 1, len(dataloader), utt_id))
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % nbest_hyps[0])
                logger.debug('-' * 150)

                if edit_distance and not streaming:
                    # Compute WER
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref.split(' '), hyp=nbest_hyps[0].split(' '))
                    wer += err_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))

                    # Compute oracle WER
                    if oracle and len(nbest_hyps) > 1:
                        wers_b = [err_b] + [
                            compute_wer(ref=ref.split(' '),
                                        hyp=hyp_n.split(' '))[0]
                            for hyp_n in nbest_hyps[1:]
                        ]
                        oracle_idx = np.argmin(np.array(wers_b))
                        if oracle_idx == 0:
                            n_oracle_hit += len(batch['utt_ids'])
                        wer_oracle += wers_b[oracle_idx]
                        # NOTE: OOV resolution is not considered

                    if fine_grained:
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        if xlen_bin in wer_dist.keys():
                            wer_dist[xlen_bin] += [err_b / 100]
                        else:
                            wer_dist[xlen_bin] = [err_b / 100]

            n_utt += len(batch['utt_ids'])
            if progressbar:
                pbar.update(len(batch['utt_ids']))

    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset(is_new_epoch=True)

    if edit_distance and not streaming:
        wer /= n_word
        n_sub_w /= n_word
        n_ins_w /= n_word
        n_del_w /= n_word

        if n_char > 0:
            cer /= n_char
            n_sub_c /= n_char
            n_ins_c /= n_char
            n_del_c /= n_char

        if recog_params.get('recog_beam_width') > 1:
            logger.info('WER (%s): %.2f %%' % (dataloader.set, wer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_w, n_ins_w, n_del_w))
            logger.info('CER (%s): %.2f %%' % (dataloader.set, cer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_c, n_ins_c, n_del_c))
            logger.info('OOV (total): %d' % (n_oov_total))

        if oracle:
            wer_oracle /= n_word
            oracle_hit_rate = n_oracle_hit * 100 / n_utt
            logger.info('Oracle WER (%s): %.2f %%' %
                        (dataloader.set, wer_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' %
                        (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, wers in sorted(wer_dist.items(), key=lambda x: x[0]):
                logger.info('  WER (%s): %.2f %% (%d)' %
                            (dataloader.set, sum(wers) / len(wers), len_bin))

    return wer, cer, n_oov_total
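The fine-grained report buckets utterances by input length in 200-frame bins: an utterance with xlens frames contributes its WER (stored as a fraction) to bin (xlens // 200 + 1) * 200, and each bin is averaged for the log. For example:

xlen = 350
xlen_bin = (xlen // 200 + 1) * 200  # -> 400; utterances of 200-399 frames share this bin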
Code Example #7
def eval_phone(models, dataset, decode_params, epoch, progressbar=False):
    """Evaluate a phone-level model by PER.

    Args:
        models (list): the models to evaluate
        dataset: An instance of a `Dataset` class
        decode_params (dict):
        epoch (int):
        progressbar (bool): if True, visualize the progressbar
    Returns:
        per (float): Phone error rate
        num_sub (int): the number of substitution errors
        num_ins (int): the number of insertion errors
        num_del (int): the number of deletion errors
        decode_dir (str):

    """
    # Reset data counter
    dataset.reset()

    model = models[0]
    # TODO(hirofumi): ensemble decoding

    decode_dir = 'decode_' + dataset.set + '_ep' + str(epoch) + '_beam' + str(decode_params['beam_width'])
    decode_dir += '_lp' + str(decode_params['length_penalty'])
    decode_dir += '_cp' + str(decode_params['coverage_penalty'])
    decode_dir += '_' + str(decode_params['min_len_ratio']) + '_' + str(decode_params['max_len_ratio'])
    decode_dir += '_rnnlm' + str(decode_params['rnnlm_weight'])

    ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
    hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')

    per = 0
    num_sub, num_ins, num_del = 0, 0, 0
    num_phones = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, _, perm_idx = model.decode(batch['xs'], decode_params,
                                                  exclude_eos=True)
            ys = [batch['ys'][i] for i in perm_idx]

            for b in range(len(batch['xs'])):
                # Reference
                if dataset.is_test:
                    text_ref = ys[b]
                else:
                    text_ref = dataset.idx2phone(ys[b])

                # Hypothesis
                text_hyp = dataset.idx2phone(best_hyps[b])

                # Write to trn
                speaker = '_'.join(batch['utt_ids'][b].replace('-', '_').split('_')[:-2])
                start = batch['utt_ids'][b].replace('-', '_').split('_')[-2]
                end = batch['utt_ids'][b].replace('-', '_').split('_')[-1]
                f_ref.write(text_ref + ' (' + speaker + '-' + start + '-' + end + ')\n')
                f_hyp.write(text_hyp + ' (' + speaker + '-' + start + '-' + end + ')\n')

                # Compute PER
                per_b, sub_b, ins_b, del_b = compute_wer(ref=text_ref.split(' '),
                                                         hyp=text_hyp.split(' '),
                                                         normalize=False)
                per += per_b
                num_sub += sub_b
                num_ins += ins_b
                num_del += del_b
                num_phones += len(text_ref.split(' '))

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    per /= num_phones
    num_sub /= num_phones
    num_ins /= num_phones
    num_del /= num_phones

    return per, num_sub, num_ins, num_del, os.path.join(model.save_path, decode_dir)
Code Example #8
File: character.py  Project: sscorpio93/neural_sp
def eval_char(models,
              dataset,
              recog_params,
              epoch,
              recog_dir=None,
              progressbar=False,
              task_idx=0):
    """Evaluate the character-level model by WER & CER.

    Args:
        models (list): models to evaluate
        dataset (Dataset): evaluation dataset
        recog_params (dict):
        epoch (int):
        recog_dir (str):
        progressbar (bool): visualize the progressbar
        task_idx (int): the index of the target task of interest
            0: main task
            1: sub task
            2: sub sub task
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate

    """
    # Reset data counter
    dataset.reset()

    if recog_dir is None:
        recog_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(
            recog_params['recog_min_len_ratio']) + '_' + str(
                recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    if task_idx == 0:
        task = 'ys'
    elif task_idx == 1:
        task = 'ys_sub1'
    elif task_idx == 2:
        task = 'ys_sub2'
    elif task_idx == 3:
        task = 'ys_sub3'

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path,
                                                     'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(
                recog_params['recog_batch_size'])
            best_hyps_id, _, _ = models[0].decode(
                batch['xs'],
                recog_params,
                dataset.idx2token[task_idx],
                exclude_eos=True,
                refs_id=batch['ys'] if task_idx == 0 else batch['ys_sub' +
                                                                str(task_idx)],
                utt_ids=batch['utt_ids'],
                speakers=batch['sessions']
                if dataset.corpus == 'swbd' else batch['speakers'],
                task=task,
                ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                hyp = dataset.idx2token[task_idx](best_hyps_id[b])

                # Write to trn
                utt_id = str(batch['utt_ids'][b])
                speaker = str(batch['speakers'][b]).replace('-', '_')
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.info('utt-id: %s' % utt_id)
                logger.info('Ref: %s' % ref)
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 150)

                if ('char' in dataset.unit and 'nowb' not in dataset.unit) or (
                        task_idx > 0 and dataset.unit_sub1 == 'char'):
                    # Compute WER
                    wer_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref.split(' '),
                        hyp=hyp.split(' '),
                        normalize=False)
                    wer += wer_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))

                # Compute CER
                if dataset.corpus == 'csj':
                    ref = ref.replace(' ', '')
                    hyp = hyp.replace(' ', '')
                cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                         hyp=list(hyp),
                                                         normalize=False)
                cer += cer_b
                n_sub_c += sub_b
                n_ins_c += ins_b
                n_del_c += del_b
                n_char += len(ref)

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    if ('char' in dataset.unit and 'nowb' not in dataset.unit) or (
            task_idx > 0 and dataset.unit_sub1 == 'char'):
        wer /= n_word
        n_sub_w /= n_word
        n_ins_w /= n_word
        n_del_w /= n_word
    else:
        wer = n_sub_w = n_ins_w = n_del_w = 0

    cer /= n_char
    n_sub_c /= n_char
    n_ins_c /= n_char
    n_del_c /= n_char

    logger.info('WER (%s): %.2f %%' % (dataset.set, wer))
    logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                (n_sub_w, n_ins_w, n_del_w))
    logger.info('CER (%s): %.2f %%' % (dataset.set, cer))
    logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                (n_sub_c, n_ins_c, n_del_c))

    return wer, cer
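Word-level WER is only meaningful when the recognition unit carries word boundaries ('char' without the 'nowb' suffix) or when a character-level sub task is scored; otherwise only CER is reported and the word counters stay zero. The gate reduces to:

unit, unit_sub1, task_idx = 'char_nowb', None, 0  # hypothetical configuration
use_wer = ('char' in unit and 'nowb' not in unit) or \
          (task_idx > 0 and unit_sub1 == 'char')  # -> False: skip WER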
Code Example #9
def eval_phone(models,
               dataset,
               recog_params,
               epoch,
               recog_dir=None,
               progressbar=False):
    """Evaluate a phone-level model by PER.

    Args:
        models (list): models to evaluate
        dataset (Dataset): evaluation dataset
        recog_params (dict):
        epoch (int):
        recog_dir (str):
        progressbar (bool): visualize the progressbar
    Returns:
        per (float): Phone error rate

    """
    # Reset data counter
    dataset.reset()

    if recog_dir is None:
        recog_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(
            recog_params['recog_min_len_ratio']) + '_' + str(
                recog_params['recog_max_len_ratio'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    per = 0
    n_sub, n_ins, n_del = 0, 0, 0
    n_phone = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path,
                                                     'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(
                recog_params['recog_batch_size'])
            best_hyps_id, _, _ = models[0].decode(
                batch['xs'],
                recog_params,
                dataset.idx2token[0],
                exclude_eos=True,
                refs_id=batch['ys'],
                utt_ids=batch['utt_ids'],
                speakers=batch['sessions']
                if dataset.corpus == 'swbd' else batch['speakers'],
                ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                hyp = dataset.idx2token[0](best_hyps_id[b])

                # Write to trn
                utt_id = str(batch['utt_ids'][b])
                speaker = str(batch['speakers'][b]).replace('-', '_')
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref)
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 150)

                # Compute PER
                per_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                         hyp=hyp.split(' '),
                                                         normalize=False)
                per += per_b
                n_sub += sub_b
                n_ins += ins_b
                n_del += del_b
                n_phone += len(ref.split(' '))

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    per /= n_phone
    n_sub /= n_phone
    n_ins /= n_phone
    n_del /= n_phone

    logger.info('PER (%s): %.2f %%' % (dataset.set, per))
    logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub, n_ins, n_del))

    return per
Code Example #10
def eval_char(models,
              dataloader,
              params,
              epoch=-1,
              rank=0,
              save_dir=None,
              streaming=False,
              progressbar=False,
              task_idx=0,
              edit_distance=True,
              fine_grained=False,
              oracle=False,
              teacher_force=False):
    """Evaluate a character-level model by WER & CER.

    Args:
        models (List): models to evaluate
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        params (omegaconf.dictconfig.DictConfig): decoding hyperparameters
        epoch (int): current epoch
        rank (int): rank of current process group
        save_dir (str): directory path to save hypotheses
        streaming (bool): streaming decoding for session-level evaluation
        progressbar (bool): visualize progressbar
        edit_distance (bool): calculate edit-distance (can be skipped for RTF calculation)
        task_idx (int): index of target task of interest
            0: main task
            1: sub task
            2: sub sub task
        fine_grained (bool): calculate fine-grained WER distributions based on input lengths
        oracle (bool): calculate oracle WER
        teacher_force (bool): conduct decoding in teacher-forcing mode
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate

    """
    if save_dir is None:
        save_dir = 'decode_' + dataloader.set + '_ep' + \
            str(epoch) + '_beam' + str(params.get('recog_beam_width'))
        save_dir += '_lp' + str(params.get('recog_length_penalty'))
        save_dir += '_cp' + str(params.get('recog_coverage_penalty'))
        save_dir += '_' + str(params.get('recog_min_len_ratio')) + '_' + \
            str(params.get('recog_max_len_ratio'))
        save_dir += '_lm' + str(params.get('recog_lm_weight'))

        ref_trn_path = mkdir_join(models[0].save_path,
                                  save_dir,
                                  'ref.trn',
                                  rank=rank)
        hyp_trn_path = mkdir_join(models[0].save_path,
                                  save_dir,
                                  'hyp.trn',
                                  rank=rank)
    else:
        ref_trn_path = mkdir_join(save_dir, 'ref.trn', rank=rank)
        hyp_trn_path = mkdir_join(save_dir, 'hyp.trn', rank=rank)

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    cer_dist = {}  # calculate CER distribution based on input lengths

    cer_oracle = 0
    n_oracle_hit = 0

    n_streamable, quantity_rate, n_utt = 0, 0, 0
    last_success_frame_ratio = 0

    # Reset data counter
    dataloader.reset(params.get('recog_batch_size'), 'seq')

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    if rank == 0:
        f_hyp = codecs.open(hyp_trn_path, 'w', encoding='utf-8')
        f_ref = codecs.open(ref_trn_path, 'w', encoding='utf-8')

    if task_idx == 0:
        task = 'ys'
    elif task_idx == 1:
        task = 'ys_sub1'
    elif task_idx == 2:
        task = 'ys_sub2'
    elif task_idx == 3:
        task = 'ys_sub3'

    for batch in dataloader:
        speakers = batch['sessions' if dataloader.corpus ==
                         'swbd' else 'speakers']
        if streaming or params.get('recog_block_sync'):
            nbest_hyps_id = models[0].decode_streaming(batch['xs'],
                                                       params,
                                                       dataloader.idx2token[0],
                                                       exclude_eos=True,
                                                       speaker=speakers[0])[0]
        else:
            nbest_hyps_id = models[0].decode(
                batch['xs'],
                params,
                idx2token=dataloader.idx2token[0],
                exclude_eos=True,
                refs_id=batch['ys'] if task_idx == 0 else batch['ys_sub' +
                                                                str(task_idx)],
                utt_ids=batch['utt_ids'],
                speakers=speakers,
                task=task,
                ensemble_models=models[1:] if len(models) > 1 else [],
                teacher_force=teacher_force)[0]

        for b in range(len(batch['xs'])):
            ref = batch['text'][b]
            nbest_hyps_tmp = [
                dataloader.idx2token[0](hyp_id) for hyp_id in nbest_hyps_id[b]
            ]
            # Truncate the first and last spaces for the char_space unit
            nbest_hyps = []
            for hyp in nbest_hyps_tmp:
                if len(hyp) > 0 and hyp[0] == ' ':
                    hyp = hyp[1:]
                if len(hyp) > 0 and hyp[-1] == ' ':
                    hyp = hyp[:-1]
                nbest_hyps.append(hyp)

            # Write to trn
            speaker = str(batch['speakers'][b]).replace('-', '_')
            if streaming:
                utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
            else:
                utt_id = str(batch['utt_ids'][b])
            if rank == 0:
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id +
                            ')\n')
            logger.debug('utt-id (%d/%d): %s' %
                         (n_utt + 1, len(dataloader), utt_id))
            logger.debug('Ref: %s' % ref)
            logger.debug('Hyp: %s' % nbest_hyps[0])
            logger.debug('-' * 150)

            if edit_distance and not streaming:
                if ('char' in dataloader.unit and 'nowb' not in dataloader.unit
                    ) or (task_idx > 0 and dataloader.unit_sub1 == 'char'):
                    # Compute WER
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref.split(' '), hyp=nbest_hyps[0].split(' '))
                    wer += err_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))
                    # NOTE: sentence error rate for Chinese

                # Compute CER
                if dataloader.corpus == 'csj':
                    ref = ref.replace(' ', '')
                    nbest_hyps[0] = nbest_hyps[0].replace(' ', '')
                err_b, sub_b, ins_b, del_b = compute_wer(
                    ref=list(ref), hyp=list(nbest_hyps[0]))
                cer += err_b
                n_sub_c += sub_b
                n_ins_c += ins_b
                n_del_c += del_b
                n_char += len(ref)

                # Compute oracle CER
                if oracle and len(nbest_hyps) > 1:
                    cers_b = [err_b] + [
                        compute_wer(ref=list(ref), hyp=list(hyp_n))[0]
                        for hyp_n in nbest_hyps[1:]
                    ]
                    oracle_idx = np.argmin(np.array(cers_b))
                    if oracle_idx == 0:
                        n_oracle_hit += len(batch['utt_ids'])
                    cer_oracle += cers_b[oracle_idx]

                if fine_grained:
                    xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                    if xlen_bin in cer_dist.keys():
                        cer_dist[xlen_bin] += [err_b / 100]
                    else:
                        cer_dist[xlen_bin] = [err_b / 100]

                if models[0].streamable():
                    n_streamable += len(batch['utt_ids'])
                else:
                    last_success_frame_ratio += models[0].last_success_frame_ratio()
                quantity_rate += models[0].quantity_rate()

        n_utt += len(batch['utt_ids'])
        if progressbar:
            pbar.update(len(batch['utt_ids']))

    if rank == 0:
        f_hyp.close()
        f_ref.close()
    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset(is_new_epoch=True)

    if edit_distance and not streaming:
        if ('char' in dataloader.unit and 'nowb' not in dataloader.unit) or (
                task_idx > 0 and dataloader.unit_sub1 == 'char'):
            wer /= n_word
            n_sub_w /= n_word
            n_ins_w /= n_word
            n_del_w /= n_word
        else:
            wer = n_sub_w = n_ins_w = n_del_w = 0

        cer /= n_char
        n_sub_c /= n_char
        n_ins_c /= n_char
        n_del_c /= n_char

        if n_utt - n_streamable > 0:
            last_success_frame_ratio /= (n_utt - n_streamable)
        n_streamable /= n_utt
        quantity_rate /= n_utt

        if params.get('recog_beam_width') > 1:
            logger.info('WER (%s): %.2f %%' % (dataloader.set, wer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_w, n_ins_w, n_del_w))
            logger.info('CER (%s): %.2f %%' % (dataloader.set, cer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_c, n_ins_c, n_del_c))

        if oracle:
            cer_oracle /= n_char
            oracle_hit_rate = n_oracle_hit * 100 / n_utt
            logger.info('Oracle CER (%s): %.2f %%' %
                        (dataloader.set, cer_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' %
                        (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, cers in sorted(cer_dist.items(), key=lambda x: x[0]):
                logger.info('  CER (%s): %.2f %% (%d)' %
                            (dataloader.set, sum(cers) / len(cers), len_bin))

        logger.info('Streamability (%s): %.2f %%' %
                    (dataloader.set, n_streamable * 100))
        logger.info('Quantity rate (%s): %.2f %%' %
                    (dataloader.set, quantity_rate * 100))
        logger.info('Last success frame ratio (%s): %.2f %%' %
                    (dataloader.set, last_success_frame_ratio))

    return wer, cer
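The evaluator above reads every decoding option through params.get(...) and returns a (wer, cer) pair. Below is a minimal driver sketch, assuming the function is character.py's eval_char (consistent with its body and ys_sub* tasks) and that model and dataloader are an already-trained neural_sp ASR model and its evaluation dataloader; all option values are illustrative placeholders, not recommendations.

# Hypothetical driver; `model` and `dataloader` are assumed to exist already.
params = {
    'recog_beam_width': 5,
    'recog_length_penalty': 0.1,
    'recog_coverage_penalty': 0.0,
    'recog_min_len_ratio': 0.0,
    'recog_max_len_ratio': 1.0,
    'recog_lm_weight': 0.3,
    'recog_batch_size': 1,
    'recog_block_sync': False,
}
wer, cer = eval_char([model], dataloader, params, epoch=25, rank=0,
                     progressbar=True)
print('WER: %.2f%% / CER: %.2f%%' % (wer, cer))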
Code Example #11
def eval_wordpiece(models,
                   dataloader,
                   recog_params,
                   epoch,
                   recog_dir=None,
                   streaming=False,
                   progressbar=False,
                   fine_grained=False,
                   oracle=False,
                   teacher_force=False):
    """Evaluate a wordpiece-level model by WER.

    Args:
        models (List): models to evaluate
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        recog_params (dict): decoding hyperparameters
        epoch (int): current epoch
        recog_dir (str): directory path to save hypotheses
        streaming (bool): streaming decoding for session-level evaluation
        progressbar (bool): visualize progressbar
        fine_grained (bool): calculate fine-grained WER distributions based on input lengths
        oracle (bool): calculate oracle WER
        teacher_force (bool): conduct decoding in teacher-forcing mode
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate

    """
    if recog_dir is None:
        recog_dir = 'decode_' + dataloader.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(
            recog_params['recog_min_len_ratio']) + '_' + str(
                recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    wer_dist = {}  # calculate WER distribution based on input lengths

    wer_oracle = 0
    n_oracle_hit = 0

    n_streamable, quantity_rate, n_utt = 0, 0, 0
    last_success_frame_ratio = 0

    # Reset data counter
    dataloader.reset(recog_params['recog_batch_size'])

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
            codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref:
        while True:
            batch, is_new_epoch = dataloader.next(
                recog_params['recog_batch_size'])
            if streaming or recog_params['recog_block_sync']:
                nbest_hyps_id = models[0].decode_streaming(
                    batch['xs'],
                    recog_params,
                    dataloader.idx2token[0],
                    exclude_eos=True)[0]
            else:
                nbest_hyps_id = models[0].decode(
                    batch['xs'],
                    recog_params,
                    idx2token=dataloader.idx2token[0] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataloader.corpus ==
                                   'swbd' else 'speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [])[0]

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                if ref[0] == '<':
                    ref = ref.split('>')[1]
                nbest_hyps = [
                    dataloader.idx2token[0](hyp_id)
                    for hyp_id in nbest_hyps_id[b]
                ]

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id +
                            ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % nbest_hyps[0])
                logger.debug('-' * 150)

                if not streaming:
                    # Compute WER
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref.split(' '), hyp=nbest_hyps[0].split(' '))
                    wer += err_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))

                    # Compute oracle WER
                    if oracle and len(nbest_hyps) > 1:
                        wers_b = [err_b] + [
                            compute_wer(ref=ref.split(' '),
                                        hyp=hyp_n.split(' '))[0]
                            for hyp_n in nbest_hyps[1:]
                        ]
                        oracle_idx = np.argmin(np.array(wers_b))
                        if oracle_idx == 0:
                            n_oracle_hit += 1
                        wer_oracle += wers_b[oracle_idx]

                    if fine_grained:
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        if xlen_bin in wer_dist.keys():
                            wer_dist[xlen_bin] += [err_b / 100]
                        else:
                            wer_dist[xlen_bin] = [err_b / 100]

                    # Compute CER
                    if dataloader.corpus == 'csj':
                        ref = ref.replace(' ', '')
                        nbest_hyps[0] = nbest_hyps[0].replace(' ', '')
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=list(ref), hyp=list(nbest_hyps[0]))
                    cer += err_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref)

                    if models[0].streamable():
                        n_streamable += 1
                    else:
                        last_success_frame_ratio += models[0].last_success_frame_ratio()
                    quantity_rate += models[0].quantity_rate()

                n_utt += 1
                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset()

    if not streaming:
        wer /= n_word
        n_sub_w /= n_word
        n_ins_w /= n_word
        n_del_w /= n_word

        cer /= n_char
        n_sub_c /= n_char
        n_ins_c /= n_char
        n_del_c /= n_char

        if n_utt - n_streamable > 0:
            last_success_frame_ratio /= (n_utt - n_streamable)
        n_streamable /= n_utt
        quantity_rate /= n_utt

        if recog_params['recog_beam_width'] > 1:
            logger.info('WER (%s): %.2f %%' % (dataloader.set, wer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_w, n_ins_w, n_del_w))
            logger.info('CER (%s): %.2f %%' % (dataloader.set, cer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_c, n_ins_c, n_del_c))

        if oracle:
            wer_oracle /= n_word
            oracle_hit_rate = n_oracle_hit * 100 / n_utt
            logger.info('Oracle WER (%s): %.2f %%' %
                        (dataloader.set, wer_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' %
                        (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, wers in sorted(wer_dist.items(), key=lambda x: x[0]):
                logger.info('  WER (%s): %.2f %% (%d)' %
                            (dataloader.set, sum(wers) / len(wers), len_bin))

        logger.info('Streamability (%s): %.2f %%' %
                    (dataloader.set, n_streamable * 100))
        logger.info('Quantity rate (%s): %.2f %%' %
                    (dataloader.set, quantity_rate * 100))
        logger.info('Last success frame ratio (%s): %.2f %%' %
                    (dataloader.set, last_success_frame_ratio))

    return wer, cer
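Every evaluator in this collection delegates the edit-distance bookkeeping to compute_wer, whose definition is not shown on this page. The sketch below reconstructs its apparent contract — Levenshtein distance over token lists, split into substitution/insertion/deletion counts by a backtrace — and is not neural_sp's actual implementation. The 100x scaling mirrors the accumulation pattern above: per-utterance errors are summed and later divided by the total reference length to yield a percentage.

import numpy as np

def compute_wer_sketch(ref, hyp):
    # Standard Levenshtein DP over token sequences (words, chars, or phones).
    n, m = len(ref), len(hyp)
    d = np.zeros((n + 1, m + 1), dtype=np.int64)
    d[:, 0] = np.arange(n + 1)
    d[0, :] = np.arange(m + 1)
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i, j] = min(d[i - 1, j - 1] + cost,  # match / substitution
                          d[i - 1, j] + 1,         # deletion
                          d[i, j - 1] + 1)         # insertion
    # Backtrace the optimal path to split the distance into S/I/D counts.
    i, j = n, m
    n_sub = n_ins = n_del = 0
    while i > 0 or j > 0:
        if i > 0 and j > 0 and ref[i - 1] == hyp[j - 1] and d[i, j] == d[i - 1, j - 1]:
            i, j = i - 1, j - 1                    # match
        elif i > 0 and j > 0 and d[i, j] == d[i - 1, j - 1] + 1:
            n_sub += 1
            i, j = i - 1, j - 1                    # substitution
        elif j > 0 and d[i, j] == d[i, j - 1] + 1:
            n_ins += 1
            j -= 1                                 # insertion
        else:
            n_del += 1
            i -= 1                                 # deletion
    return 100.0 * (n_sub + n_ins + n_del), n_sub, n_ins, n_del

err, s, i_, d_ = compute_wer_sketch('a b c d'.split(), 'a x c'.split())
# err == 200.0 (one substitution + one deletion), s == 1, i_ == 0, d_ == 1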
Code Example #12
File: wordpiece.py Project: caochensi/neural_sp
def eval_wordpiece(models,
                   dataset,
                   decode_params,
                   epoch,
                   decode_dir=None,
                   progressbar=False):
    """Evaluate the wordpiece-level model by WER.

    Args:
        models (list): the models to evaluate
        dataset: An instance of a `Dataset' class
        decode_params (dict): decoding hyperparameters
        epoch (int): current epoch
        decode_dir (str): directory path to save decoding results
        progressbar (bool): if True, visualize the progressbar
    Returns:
        wer (float): Word error rate
        nsub (int): the number of substitution errors
        nins (int): the number of insertion errors
        ndel (int): the number of deletion errors

    """
    # Reset data counter
    dataset.reset()

    model = models[0]
    # TODO(hirofumi): ensemble decoding

    if decode_dir is None:
        decode_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(decode_params['beam_width'])
        decode_dir += '_lp' + str(decode_params['length_penalty'])
        decode_dir += '_cp' + str(decode_params['coverage_penalty'])
        decode_dir += '_' + str(decode_params['min_len_ratio']) + '_' + str(
            decode_params['max_len_ratio'])
        decode_dir += '_rnnlm' + str(decode_params['rnnlm_weight'])

        ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(decode_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(decode_dir, 'hyp.trn')

    wer = 0
    nsub, nins, ndel = 0, 0, 0
    nword = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path,
                                                     'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, _, perm_id = model.decode(batch['xs'],
                                                 decode_params,
                                                 exclude_eos=True,
                                                 id2token=dataset.id2wp,
                                                 refs=batch['ys'])
            ys = [batch['text'][i] for i in perm_id]

            for b in six.moves.range(len(batch['xs'])):
                ref = ys[b]
                hyp = dataset.id2wp(best_hyps[b])

                # Write to trn
                speaker = '_'.join(batch['utt_ids'][b].replace(
                    '-', '_').split('_')[:-2])
                start = batch['utt_ids'][b].replace('-', '_').split('_')[-2]
                end = batch['utt_ids'][b].replace('-', '_').split('_')[-1]
                f_ref.write(ref + ' (' + speaker + '-' + start + '-' + end +
                            ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + start + '-' + end +
                            ')\n')
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref)
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

                # Compute WER
                wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                         hyp=hyp.split(' '),
                                                         normalize=False)
                wer += wer_b
                nsub += sub_b
                nins += ins_b
                ndel += del_b
                nword += len(ref.split(' '))

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    wer /= nword
    nsub /= nword
    nins /= nword
    ndel /= nword

    return wer, nsub, nins, ndel
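mkdir_join, used throughout these listings, is a small neural_sp path utility. Judging from the calls above, it joins its arguments into one path, creates any missing directories along the way, and returns the result; a final component with an extension, such as 'ref.trn', is a file name rather than a directory. The following is a rough stand-in under those assumptions, not the library's code — the rank keyword accepted by the distributed-aware callers is omitted because its exact semantics are not shown on this page.

import os

def mkdir_join_sketch(path, *parts):
    # Join the components, create the directory portion if needed, and
    # return the full path; 'ref.trn'-style leaves are treated as files.
    full = os.path.join(path, *parts)
    dir_part = os.path.dirname(full) if os.path.splitext(full)[1] else full
    if dir_part:
        os.makedirs(dir_part, exist_ok=True)
    return full

ref_trn = mkdir_join_sketch('results', 'decode_dev_ep25_beam5', 'ref.trn')
# -> 'results/decode_dev_ep25_beam5/ref.trn', with the directory created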
Code Example #13
def eval_char(models,
              dataset,
              decode_params,
              epoch,
              decode_dir=None,
              progressbar=False,
              task_id=0):
    """Evaluate the character-level model by WER & CER.

    Args:
        models (list): the models to evaluate
        dataset: An instance of a `Dataset' class
        decode_params (dict): decoding hyperparameters
        epoch (int): current epoch
        decode_dir (str): directory path to save decoding results
        progressbar (bool): if True, visualize the progressbar
        task_id (int): the index of the target task of interest
            0: main task
            1: sub task
            2: sub sub task
    Returns:
        wer (float): Word error rate
        nsub_w (int): the number of substitution errors for WER
        nins_w (int): the number of insertion errors for WER
        ndel_w (int): the number of deletion errors for WER
        cer (float): Character error rate
        nsub_c (int): the number of substitution errors for CER
        nins_c (int): the number of insertion errors for CER
        ndel_c (int): the number of deletion errors for CER

    """
    # Reset data counter
    dataset.reset()

    model = models[0]

    if decode_dir is None:
        decode_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(decode_params['beam_width'])
        decode_dir += '_lp' + str(decode_params['length_penalty'])
        decode_dir += '_cp' + str(decode_params['coverage_penalty'])
        decode_dir += '_' + str(decode_params['min_len_ratio']) + '_' + str(
            decode_params['max_len_ratio'])
        decode_dir += '_rnnlm' + str(decode_params['rnnlm_weight'])

        ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(decode_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(decode_dir, 'hyp.trn')

    wer, cer = 0, 0
    nsub_w, nins_w, ndel_w = 0, 0, 0
    nsub_c, nins_c, ndel_c = 0, 0, 0
    nword, nchar = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    if task_id == 0:
        task = 'ys'
    elif task_id == 1:
        task = 'ys_sub1'
    elif task_id == 2:
        task = 'ys_sub2'
    else:
        raise ValueError('task_id must be 0, 1, or 2.')

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path,
                                                     'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, _, perm_ids = model.decode(batch['xs'],
                                                  decode_params,
                                                  exclude_eos=True,
                                                  task=task)
            ys = [batch['text'][i] for i in perm_ids]

            for b in six.moves.range(len(batch['xs'])):
                ref = ys[b]
                hyp = dataset.id2char(best_hyps[b])

                # Write to trn
                speaker = '_'.join(batch['utt_ids'][b].replace(
                    '-', '_').split('_')[:-2])
                start = batch['utt_ids'][b].replace('-', '_').split('_')[-2]
                end = batch['utt_ids'][b].replace('-', '_').split('_')[-1]
                f_ref.write(ref + ' (' + speaker + '-' + start + '-' + end +
                            ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + start + '-' + end +
                            ')\n')
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref)
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

                if ('char' in dataset.unit and 'nowb' not in dataset.unit) or (
                        task_id > 0 and dataset.unit_sub1 == 'char'):
                    # Compute WER
                    wer_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref.split(' '),
                        hyp=hyp.split(' '),
                        normalize=False)
                    wer += wer_b
                    nsub_w += sub_b
                    nins_w += ins_b
                    ndel_w += del_b
                    nword += len(ref.split(' '))

                # Compute CER
                cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                         hyp=list(hyp),
                                                         normalize=False)
                cer += cer_b
                nsub_c += sub_b
                nins_c += ins_b
                ndel_c += del_b
                nchar += len(ref)

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    if ('char' in dataset.unit and 'nowb' not in dataset.unit) or (
            task_id > 0 and dataset.unit_sub1 == 'char'):
        wer /= nword
        nsub_w /= nword
        nins_w /= nword
        ndel_w /= nword
    else:
        wer = nsub_w = nins_w = ndel_w = 0

    cer /= nchar
    nsub_c /= nchar
    nins_c /= nchar
    ndel_c /= nchar

    return (wer, nsub_w, nins_w, ndel_w), (cer, nsub_c, nins_c, ndel_c)
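Unlike the newer evaluators that return a plain (wer, cer) pair, this variant packs the WER and CER statistics into two tuples. A hedged usage sketch follows — model, dev_set, and every decode_params value here are assumed placeholders, not part of the listing above.

decode_params = {
    'beam_width': 5, 'length_penalty': 0.1, 'coverage_penalty': 0.0,
    'min_len_ratio': 0.0, 'max_len_ratio': 1.0, 'rnnlm_weight': 0.3,
    'batch_size': 1,
}
(wer, nsub_w, nins_w, ndel_w), (cer, nsub_c, nins_c, ndel_c) = eval_char(
    [model], dev_set, decode_params, epoch=20, progressbar=True, task_id=0)
print('WER: %.2f%% (S %.2f / I %.2f / D %.2f)' % (wer, nsub_w, nins_w, ndel_w))
print('CER: %.2f%% (S %.2f / I %.2f / D %.2f)' % (cer, nsub_c, nins_c, ndel_c))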
Code Example #14
File: wordpiece.py Project: entn-at/neural_sp
def eval_wordpiece(models, dataset, recog_params, epoch,
                   recog_dir=None, streaming=False, progressbar=False,
                   fine_grained=False):
    """Evaluate the wordpiece-level model by WER.

    Args:
        models (list): models to evaluate
        dataset (Dataset): evaluation dataset
        recog_params (dict): decoding hyperparameters
        epoch (int): current epoch
        recog_dir (str): directory path to save hypotheses
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
        fine_grained (bool): calculate fine-grained WER distributions based on input lengths
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate

    """
    # Reset data counter
    dataset.reset(recog_params['recog_batch_size'])

    if recog_dir is None:
        recog_dir = 'decode_' + dataset.set + '_ep' + str(epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(recog_params['recog_min_len_ratio']) + '_' + str(recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    n_streamable, quantity_rate, n_utt = 0, 0, 0
    last_success_frame_ratio = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    # calculate WER distribution based on input lengths
    wer_dist = {}

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(recog_params['recog_batch_size'])
            if streaming or recog_params['recog_chunk_sync']:
                best_hyps_id, _ = models[0].decode_streaming(
                    batch['xs'], recog_params, dataset.idx2token[0],
                    exclude_eos=True)
            else:
                best_hyps_id, _ = models[0].decode(
                    batch['xs'], recog_params,
                    idx2token=dataset.idx2token[0] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataset.corpus == 'swbd' else 'speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                if ref[0] == '<':
                    ref = ref.split('>')[1]
                hyp = dataset.idx2token[0](best_hyps_id[b])

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % hyp)
                logger.debug('-' * 150)

                if not streaming:
                    # Compute WER
                    wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                             hyp=hyp.split(' '),
                                                             normalize=False)
                    wer += wer_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))

                    if fine_grained:
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        if xlen_bin in wer_dist.keys():
                            wer_dist[xlen_bin] += [wer_b / 100]
                        else:
                            wer_dist[xlen_bin] = [wer_b / 100]

                    # Compute CER
                    if dataset.corpus == 'csj':
                        ref = ref.replace(' ', '')
                        hyp = hyp.replace(' ', '')
                    cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                             hyp=list(hyp),
                                                             normalize=False)
                    cer += cer_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref)
                    if models[0].streamable():
                        n_streamable += 1
                    else:
                        last_success_frame_ratio += models[0].last_success_frame_ratio()
                    quantity_rate += models[0].quantity_rate()
                    n_utt += 1

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    if not streaming:
        wer /= n_word
        n_sub_w /= n_word
        n_ins_w /= n_word
        n_del_w /= n_word

        cer /= n_char
        n_sub_c /= n_char
        n_ins_c /= n_char
        n_del_c /= n_char

        if n_utt - n_streamable > 0:
            last_success_frame_ratio /= (n_utt - n_streamable)
        n_streamable /= n_utt
        quantity_rate /= n_utt

        if fine_grained:
            for len_bin, wers in sorted(wer_dist.items(), key=lambda x: x[0]):
                logger.info('  WER (%s): %.2f %% (%d)' % (dataset.set, sum(wers) / len(wers), len_bin))

    logger.debug('WER (%s): %.2f %%' % (dataset.set, wer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w))
    logger.debug('CER (%s): %.2f %%' % (dataset.set, cer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c))

    logger.info('Streamability (%s): %.2f %%' % (dataset.set, n_streamable * 100))
    logger.info('Quantity rate (%s): %.2f %%' % (dataset.set, quantity_rate * 100))
    logger.info('Last success frame ratio (%s): %.2f %%' % (dataset.set, last_success_frame_ratio))

    return wer, cer
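Here and in the other fine_grained branches, utterances are bucketed by input length via (xlens // 200 + 1) * 200: each utterance lands in the bin labelled by the smallest multiple of 200 strictly greater than its frame count. A quick check of the bin edges:

for xlen in (1, 199, 200, 201, 512):
    print(xlen, '->', (xlen // 200 + 1) * 200)
# 1 -> 200, 199 -> 200, 200 -> 400, 201 -> 400, 512 -> 600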
Code Example #15
File: character.py Project: nipengmath/neural_sp
def eval_char(models, dataset, decode_params, epoch, progressbar=False):
    """Evaluate the character-level model by WER & CER.

    Args:
        models (list): the models to evaluate
        dataset: An instance of a `Dataset' class
        decode_params (dict): decoding hyperparameters
        epoch (int): current epoch
        progressbar (bool): if True, visualize the progressbar
    Returns:
        wer (float): Word error rate
        num_sub_w (int): the number of substitution errors for WER
        num_ins_w (int): the number of insertion errors for WER
        num_del_w (int): the number of deletion errors for WER
        cer (float): Character error rate
        num_sub_c (int): the number of substitution errors for CER
        num_ins_c (int): the number of insertion errors for CER
        num_del_c (int): the number of deletion errors for CER
        decode_dir (str): directory path where the trn files are saved

    """
    # Reset data counter
    dataset.reset()

    model = models[0]

    decode_dir = 'decode_' + dataset.set + '_ep' + str(epoch) + '_beam' + str(decode_params['beam_width'])
    decode_dir += '_lp' + str(decode_params['length_penalty'])
    decode_dir += '_cp' + str(decode_params['coverage_penalty'])
    decode_dir += '_' + str(decode_params['min_len_ratio']) + '_' + str(decode_params['max_len_ratio'])
    decode_dir += '_rnnlm' + str(decode_params['rnnlm_weight'])

    ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
    hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')

    wer, cer = 0, 0
    num_sub_w, num_ins_w, num_del_w = 0, 0, 0
    num_sub_c, num_ins_c, num_del_c = 0, 0, 0
    num_words, num_chars = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, aw, perm_idx = model.decode(batch['xs'], decode_params,
                                                   exclude_eos=True)
            task_index = 0  # main task only; referenced by the unit checks below
            ys = [batch['ys'][i] for i in perm_idx]

            for b in range(len(batch['xs'])):
                # Reference
                if dataset.is_test:
                    text_ref = ys[b]
                else:
                    text_ref = dataset.idx2char(ys[b])

                # Hypothesis
                text_hyp = dataset.idx2char(best_hyps[b])

                # Write to trn
                speaker = '_'.join(batch['utt_ids'][b].replace('-', '_').split('_')[:-2])
                start = batch['utt_ids'][b].replace('-', '_').split('_')[-2]
                end = batch['utt_ids'][b].replace('-', '_').split('_')[-1]
                f_ref.write(text_ref + ' (' + speaker + '-' + start + '-' + end + ')\n')
                f_hyp.write(text_hyp + ' (' + speaker + '-' + start + '-' + end + ')\n')

                if ('character' in dataset.label_type and 'nowb' not in dataset.label_type) \
                        or (task_index > 0 and dataset.label_type_sub == 'character'):
                    # Compute WER
                    wer_b, sub_b, ins_b, del_b = compute_wer(ref=text_ref.split(' '),
                                                             hyp=text_hyp.split(' '),
                                                             normalize=False)
                    wer += wer_b
                    num_sub_w += sub_b
                    num_ins_w += ins_b
                    num_del_w += del_b
                    num_words += len(text_ref.split(' '))

                # Compute CER
                cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(text_ref.replace(' ', '')),
                                                         hyp=list(text_hyp.replace(' ', '')),
                                                         normalize=False)
                cer += cer_b
                num_sub_c += sub_b
                num_ins_c += ins_b
                num_del_c += del_b
                num_chars += len(text_ref)

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    if ('character' in dataset.label_type and 'nowb' not in dataset.label_type) \
            or (task_index > 0 and dataset.label_type_sub == 'character'):
        wer /= num_words
        num_sub_w /= num_words
        num_ins_w /= num_words
        num_del_w /= num_words
    else:
        wer = num_sub_w = num_ins_w = num_del_w = 0

    cer /= num_chars
    num_sub_c /= num_chars
    num_ins_c /= num_chars
    num_del_c /= num_chars

    return (wer, num_sub_w, num_ins_w, num_del_w), (cer, num_sub_c, num_ins_c, num_del_c), os.path.join(model.save_path, decode_dir)
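This character.py variant additionally returns the decoding directory as a third value, which lets a caller point sclite or other scoring tools at the freshly written trn files. A sketch under the same assumptions as before (trained model and dataset objects; a decode_params dict like the one in the earlier sketch):

(wer, _, _, _), (cer, _, _, _), decode_dir = eval_char(
    [model], dev_set, decode_params, epoch=10, progressbar=True)
print('CER: %.2f%%; trn files under %s' % (cer, decode_dir))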
Code Example #16
def eval(epoch):
    recog_dir = args.out
    ref_trn_save_path = recog_dir + '/ref_epoch_' + str(epoch) + '.trn'
    hyp_trn_save_path = recog_dir + '/hyp_epoch_' + str(epoch) + '.trn'
    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    pbar = tqdm(total=len(devset))
    f_hyp = open(hyp_trn_save_path, 'w')
    f_ref = open(ref_trn_save_path, 'w')
    losses = []
    is_new_epoch = 0
    step = 0
    while True:
        batch, is_new_epoch = devset.next()
        ys = batch['ys']
        xs = [np2tensor(x).float() for x in batch['xs']]
        xlen = torch.IntTensor([len(x) for x in batch['xs']])
        xs = pad_list(xs, 0.0).cuda()
        _ys = [np2tensor(np.fromiter(y, dtype=np.int64), -1) for y in ys]
        ys_out_pad = pad_list(_ys, 0).long().cuda()
        ylen = np2tensor(np.fromiter([y.size(0) for y in _ys], dtype=np.int32))
        model.eval()
        loss = model(xs, ys_out_pad, xlen, ylen)
        loss = float(loss.data) * len(xlen)  # weight by batch size for averaging
        losses.append(loss)
        step += 1  # TODO(vishay): un-hardcode the batch size
        best_hyps_id, _ = model.greedy_decode(xs)

        for b in range(len(batch['xs'])):
            ref = batch['text'][b]
            hyp = devset.idx2token[0](best_hyps_id[b])
            hyp = removeDuplicates(hyp)
            # Write to trn
            utt_id = str(batch['utt_ids'][b])
            speaker = str(batch['speakers'][b]).replace('-', '_')
            if hyp is None:
                hyp = "none"
            f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
            f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
            logging.info('utt-id: %s' % utt_id)
            logging.info('Ref: %s' % ref)
            logging.info('Hyp: %s' % hyp)
            logging.info('-' * 150)

            if 'char' in devset.unit:  # TODO(vishay): WER is only computed for the char unit
                # Compute WER
                wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                         hyp=hyp.split(' '),
                                                         normalize=False)
                wer += wer_b
                n_sub_w += sub_b
                n_ins_w += ins_b
                n_del_w += del_b
                n_word += len(ref.split(' '))

            # Compute CER
            cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                     hyp=list(hyp),
                                                     normalize=False)
            cer += cer_b
            n_sub_c += sub_b
            n_ins_c += ins_b
            n_del_c += del_b
            n_char += len(ref)

        pbar.update(len(batch['xs']))
        if is_new_epoch:
            break

    pbar.close()

    # Reset data counters
    devset.reset()

    if 'char' in devset.unit:
        wer /= n_word
        n_sub_w /= n_word
        n_ins_w /= n_word
        n_del_w /= n_word
    else:
        wer = n_sub_w = n_ins_w = n_del_w = 0

    cer /= n_char
    n_sub_c /= n_char
    n_ins_c /= n_char
    n_del_c /= n_char

    logging.info('WER (%s): %.2f %%' % (devset.set, wer))
    logging.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                 (n_sub_w, n_ins_w, n_del_w))
    logging.info('CER (%s): %.2f %%' % (devset.set, cer))
    logging.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                 (n_sub_c, n_ins_c, n_del_c))

    return sum(losses) / len(devset), wer, cer
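Since this eval() also averages the dev loss, it can drive checkpoint selection in the surrounding training script (note that the name shadows Python's builtin eval within that script). A sketch of such a loop; train_one_epoch is hypothetical, and model, args, and devset are the script's assumed globals:

import logging
import torch

best_cer = float('inf')
for ep in range(1, 51):
    train_one_epoch(ep)                  # hypothetical training step
    dev_loss, wer, cer = eval(ep)        # the evaluator defined above
    logging.info('epoch %d: dev loss %.3f / WER %.2f%% / CER %.2f%%',
                 ep, dev_loss, wer, cer)
    if cer < best_cer:                   # keep the best-CER checkpoint
        best_cer = cer
        torch.save(model.state_dict(), args.out + '/model.best.pt')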
Code Example #17
File: phone.py Project: ishine/neural_sp
def eval_phone(models,
               dataloader,
               params,
               epoch=-1,
               rank=0,
               save_dir=None,
               streaming=False,
               progressbar=False,
               edit_distance=True,
               fine_grained=False,
               oracle=False,
               teacher_force=False):
    """Evaluate a phone-level model by PER.

    Args:
        models (List): models to evaluate
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        params (omegaconf.dictconfig.DictConfig): decoding hyperparameters
        epoch (int): current epoch
        rank (int): rank of current process group
        save_dir (str): directory path to save hypotheses
        streaming (bool): streaming decoding for session-level evaluation
        progressbar (bool): visualize progressbar
        edit_distance (bool): calculate edit-distance (can be skipped for RTF calculation)
        fine_grained (bool): calculate fine-grained PER distributions based on input lengths
        oracle (bool): calculate oracle PER
        teacher_force (bool): conduct decoding in teacher-forcing mode
    Returns:
        per (float): Phone error rate

    """
    if save_dir is None:
        save_dir = 'decode_' + dataloader.set + '_ep' + \
            str(epoch) + '_beam' + str(params.get('recog_beam_width'))
        save_dir += '_lp' + str(params.get('recog_length_penalty'))
        save_dir += '_cp' + str(params.get('recog_coverage_penalty'))
        save_dir += '_' + str(params.get('recog_min_len_ratio')) + '_' + \
            str(params.get('recog_max_len_ratio'))

        ref_trn_path = mkdir_join(models[0].save_path,
                                  save_dir,
                                  'ref.trn',
                                  rank=rank)
        hyp_trn_path = mkdir_join(models[0].save_path,
                                  save_dir,
                                  'hyp.trn',
                                  rank=rank)
    else:
        ref_trn_path = mkdir_join(save_dir, 'ref.trn', rank=rank)
        hyp_trn_path = mkdir_join(save_dir, 'hyp.trn', rank=rank)

    per = 0
    n_sub, n_ins, n_del = 0, 0, 0
    n_phone = 0
    per_dist = {}  # calculate PER distribution based on input lengths

    per_oracle = 0
    n_oracle_hit = 0
    n_utt = 0

    # Reset data counter
    dataloader.reset(params.get('recog_batch_size'), 'seq')

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    if rank == 0:
        f_hyp = codecs.open(hyp_trn_path, 'w', encoding='utf-8')
        f_ref = codecs.open(ref_trn_path, 'w', encoding='utf-8')

    for batch in dataloader:
        speakers = batch['sessions' if dataloader.corpus ==
                         'swbd' else 'speakers']
        if streaming or params.get('recog_block_sync'):
            nbest_hyps_id = models[0].decode_streaming(batch['xs'],
                                                       params,
                                                       dataloader.idx2token[0],
                                                       exclude_eos=True,
                                                       speaker=speakers[0])[0]
        else:
            nbest_hyps_id = models[0].decode(
                batch['xs'],
                params,
                idx2token=dataloader.idx2token[0],
                exclude_eos=True,
                refs_id=batch['ys'],
                utt_ids=batch['utt_ids'],
                speakers=speakers,
                ensemble_models=models[1:] if len(models) > 1 else [],
                teacher_force=teacher_force)[0]

        for b in range(len(batch['xs'])):
            ref = batch['text'][b]
            nbest_hyps = [
                dataloader.idx2token[0](hyp_id) for hyp_id in nbest_hyps_id[b]
            ]

            # Write to trn
            speaker = str(batch['speakers'][b]).replace('-', '_')
            if streaming:
                utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
            else:
                utt_id = str(batch['utt_ids'][b])
            if rank == 0:
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id +
                            ')\n')
            logger.debug('utt-id (%d/%d): %s' %
                         (n_utt + 1, len(dataloader), utt_id))
            logger.debug('Ref: %s' % ref)
            logger.debug('Hyp: %s' % nbest_hyps[0])
            logger.debug('-' * 150)

            if edit_distance and not streaming:
                # Compute PER
                err_b, sub_b, ins_b, del_b = compute_wer(
                    ref=ref.split(' '), hyp=nbest_hyps[0].split(' '))
                per += err_b
                n_sub += sub_b
                n_ins += ins_b
                n_del += del_b
                n_phone += len(ref.split(' '))

                # Compute oracle PER
                if oracle and len(nbest_hyps) > 1:
                    pers_b = [err_b] + [
                        compute_wer(ref=ref.split(' '),
                                    hyp=hyp_n.split(' '))[0]
                        for hyp_n in nbest_hyps[1:]
                    ]
                    oracle_idx = np.argmin(np.array(pers_b))
                    if oracle_idx == 0:
                        n_oracle_hit += len(batch['utt_ids'])
                    per_oracle += pers_b[oracle_idx]

                # Accumulate the fine-grained PER distribution; without this,
                # per_dist stays empty and the fine_grained report below is silent
                if fine_grained:
                    xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                    if xlen_bin in per_dist.keys():
                        per_dist[xlen_bin] += [err_b / 100]
                    else:
                        per_dist[xlen_bin] = [err_b / 100]

        n_utt += len(batch['utt_ids'])
        if progressbar:
            pbar.update(len(batch['utt_ids']))

    if rank == 0:
        f_hyp.close()
        f_ref.close()
    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset(is_new_epoch=True)

    if edit_distance and not streaming:
        per /= n_phone
        n_sub /= n_phone
        n_ins /= n_phone
        n_del /= n_phone

        if params.get('recog_beam_width') > 1:
            logger.info('PER (%s): %.2f %%' % (dataloader.set, per))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub, n_ins, n_del))

        if oracle:
            per_oracle /= n_phone
            oracle_hit_rate = n_oracle_hit * 100 / n_utt
            logger.info('Oracle PER (%s): %.2f %%' %
                        (dataloader.set, per_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' %
                        (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, pers in sorted(per_dist.items(), key=lambda x: x[0]):
                logger.info('  PER (%s): %.2f %% (%d)' %
                            (dataloader.set, sum(pers) / len(pers), len_bin))

    return per
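Usage mirrors the other rank-aware evaluators, with a single PER scalar coming back. The docstring says real callers pass an omegaconf DictConfig; a plain dict with the same recog_* keys suffices for a sketch (all values illustrative; model and dev_loader are assumed objects):

params = {
    'recog_beam_width': 10,
    'recog_length_penalty': 0.0,
    'recog_coverage_penalty': 0.0,
    'recog_min_len_ratio': 0.0,
    'recog_max_len_ratio': 1.0,
    'recog_batch_size': 1,
    'recog_block_sync': False,
}
per = eval_phone([model], dev_loader, params, epoch=30, rank=0,
                 progressbar=True)
print('PER: %.2f%%' % per)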