def do_eval_wer(model,
                dataset,
                beam_width,
                max_decode_len,
                eval_batch_size=None,
                progressbar=False):
    """Evaluate trained model by Word Error Rate.

    Edit statistics are accumulated over every utterance of `dataset` and
    divided once by the total number of reference words, i.e. the result is
    a corpus-level error rate.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset' class
        beam_width: (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction when EOS token have not been emitted.
            This is used for seq2seq models.
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
    Returns:
        wer (float): Word error rate (fraction, not percent). 0.0 when no
            reference word was scored.
        df_wer (pd.DataFrame): dataframe of substitution, insertion, and
            deletion rates in percent
    """
    # Reset data counter
    dataset.reset()

    idx2word = Idx2word(vocab_file_path=dataset.vocab_file_path)

    wer = 0
    sub, ins, dele = 0, 0, 0
    num_words = 0
    pbar = tqdm(total=len(dataset)) if progressbar else None  # TODO: fix this
    try:
        while True:
            batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

            # Decode
            best_hyps, perm_idx = model.decode(batch['xs'],
                                               batch['x_lens'],
                                               beam_width=beam_width,
                                               max_decode_len=max_decode_len)
            # Reorder the references with the permutation returned by decode
            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]

            for b in range(len(batch['xs'])):

                ##############################
                # Reference
                ##############################
                if dataset.is_test:
                    str_ref = ys[b][0]
                    # NOTE: transcript is separated by space('_')
                else:
                    # Convert from list of index to string
                    str_ref = idx2word(ys[b][:y_lens[b]])

                ##############################
                # Hypothesis
                ##############################
                str_hyp = idx2word(best_hyps[b])
                if 'attention' in model.model_type:
                    # NOTE: Truncate by the first <EOS>
                    str_hyp = str_hyp.split('>')[0]

                    # Remove the last space
                    if len(str_hyp) > 0 and str_hyp[-1] == '_':
                        str_hyp = str_hyp[:-1]

                ##############################
                # Post-processing
                ##############################
                # Remove garbage labels
                str_ref = re.sub(r'[\'>]+', '', str_ref)
                str_hyp = re.sub(r'[\'>]+', '', str_hyp)
                # TODO: is it OK to remove these when computing WER?

                # Compute WER
                # normalize=False: the per-utterance value must stay
                # un-normalized because the sum is divided by the total word
                # count below (same scheme as the phone/char evaluators in
                # this file). NOTE(review): assumes compute_wer returns raw
                # edit counts when normalize=False - confirm.
                ref_words = str_ref.split('_')
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=ref_words,
                    hyp=str_hyp.split('_'),
                    normalize=False)
                wer += wer_b
                sub += sub_b
                ins += ins_b
                dele += del_b
                num_words += len(ref_words)

                if pbar is not None:
                    pbar.update(1)

            if is_new_epoch:
                break
    finally:
        # Release the progress bar even if decoding fails midway
        if pbar is not None:
            pbar.close()

    # Guard against an empty evaluation set; float() keeps true division
    # on Python 2 as well.
    denom = float(max(num_words, 1))
    wer /= denom
    sub /= denom
    ins /= denom
    dele /= denom

    df_wer = pd.DataFrame(
        {
            'SUB': [sub * 100],
            'INS': [ins * 100],
            'DEL': [dele * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['WER'])

    return wer, df_wer
def do_eval_cer(save_paths,
                dataset,
                data_type,
                label_type,
                num_classes,
                beam_width,
                temperature_infer,
                is_test=False,
                progressbar=False):
    """Evaluate an ensemble of trained models by Character Error Rate.

    For every utterance the posteriors saved by each model are loaded from
    disk, averaged, and decoded with a beam search decoder.

    Args:
        save_paths (list): root directories of the models whose saved
            posteriors are averaged
        dataset: An instance of a `Dataset` class
        data_type (string): sub-directory name used in the posterior path
        label_type (string): character
        num_classes (int): the last class index (`num_classes - 1`) is used
            as the blank index
        beam_width (int): the size of beam
        temperature_infer (int): temperature in the inference stage (used
            in the posterior path)
        is_test (bool, optional): set to True when evaluating by the test set
        progressbar (bool, optional): if True, visualize the progressbar
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    Raises:
        TypeError: if `label_type` is not 'character'
    """
    if label_type != 'character':
        raise TypeError(
            "label_type must be 'character', got %r" % (label_type,))

    map_file_path = '../metrics/mapping_files/character.txt'
    idx2char = Idx2char(map_file_path=map_file_path)
    char2idx = Char2idx(map_file_path=map_file_path)

    # Define decoder
    decoder = BeamSearchDecoder(space_index=char2idx(str_char='_')[0],
                                blank_index=num_classes - 1)

    ##################################################
    # Compute mean probabilities
    ##################################################
    pbar = tqdm(total=len(dataset)) if progressbar else None
    cer_mean, wer_mean = 0, 0
    try:
        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini batch
            inputs, labels_true, inputs_seq_len, input_names = data

            batch_size = inputs[0].shape[0]
            for i_batch in range(batch_size):
                utt_name = input_names[0][i_batch]
                speaker = utt_name.split('-')[0]

                probs_ensemble = None
                for model_path in save_paths:

                    # Load posteriors
                    prob_save_path = join(model_path,
                                          'temp' + str(temperature_infer),
                                          data_type, 'probs_utt', speaker,
                                          utt_name + '.npy')
                    probs_model_i = np.load(prob_save_path)
                    # NOTE: probs_model_i: `[T, num_classes]`

                    # Sum over probs
                    if probs_ensemble is None:
                        probs_ensemble = probs_model_i
                    else:
                        probs_ensemble += probs_model_i

                # Compute mean posteriors
                probs_ensemble /= len(save_paths)

                # Decode per utterance
                labels_pred, scores = decoder(
                    probs=probs_ensemble[np.newaxis, :, :],
                    seq_len=inputs_seq_len[0][i_batch:i_batch + 1],
                    beam_width=beam_width)

                # Convert from list of index to string
                if is_test:
                    str_true = labels_true[0][i_batch][0]
                    # NOTE: transcript is separated by space('_')
                else:
                    str_true = idx2char(labels_true[0][i_batch],
                                        padded_value=dataset.padded_value)
                str_pred = idx2char(labels_pred[0])

                # Remove consecutive spaces
                str_pred = re.sub(r'[_]+', '_', str_pred)

                # Remove garbage labels
                str_true = re.sub(r'[\']+', '', str_true)
                str_pred = re.sub(r'[\']+', '', str_pred)

                # Compute WER
                # The ground truth is the reference and the prediction is
                # the hypothesis (normalization divides by the reference).
                wer_result = compute_wer(ref=str_true.split('_'),
                                         hyp=str_pred.split('_'),
                                         normalize=True)
                # NOTE(review): other callers in this file unpack a 4-tuple
                # (wer, sub, ins, del) from compute_wer; accept both forms.
                if isinstance(wer_result, tuple):
                    wer_result = wer_result[0]
                wer_mean += wer_result

                # Remove spaces
                str_true = re.sub(r'[_]+', '', str_true)
                str_pred = re.sub(r'[_]+', '', str_pred)

                # Compute CER
                cer_mean += compute_cer(str_pred=str_pred,
                                        str_true=str_true,
                                        normalize=True)

                if pbar is not None:
                    pbar.update(1)

            if is_new_epoch:
                break
    finally:
        # Release the progress bar even if loading/decoding fails midway
        if pbar is not None:
            pbar.close()

    cer_mean /= (len(dataset))
    wer_mean /= (len(dataset))
    # TODO: Fix this

    return cer_mean, wer_mean
# Example #3
0
def eval_phone(model,
               dataset,
               map_file_path,
               eval_batch_size,
               beam_width,
               max_decode_len,
               min_decode_len=0,
               length_penalty=0,
               coverage_penalty=0,
               progressbar=False):
    """Evaluate trained model by Phone Error Rate.

    Edit statistics are accumulated over every utterance and divided once
    by the total number of reference phones. The dataset counter is reset
    before and after the evaluation.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset' class
        map_file_path (string): path to phones.60-48-39.map
        eval_batch_size (int): the batch size when evaluating the model
        beam_width: (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction. This is used for seq2seq models.
        min_decode_len (int): the minimum sequence length to emit
        length_penalty (float): length penalty in beam search decoding
        coverage_penalty (float): coverage penalty in beam search decoding
        progressbar (bool): if True, visualize the progressbar
    Returns:
        per (float): Phone error rate. 0.0 when no reference phone was
            scored.
        df_phone (pd.DataFrame): dataframe of substitution, insertion, and
            deletion rates in percent
    """
    import logging

    # Reset data counter
    dataset.reset()

    map2phone39 = Map2phone39(label_type=dataset.label_type,
                              map_file_path=map_file_path)

    per = 0
    sub, ins, dele = 0, 0, 0
    num_phones = 0
    pbar = tqdm(total=len(dataset)) if progressbar else None  # TODO: fix this
    try:
        while True:
            batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

            # Decode
            best_hyps, _, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                min_decode_len=min_decode_len,
                length_penalty=length_penalty,
                coverage_penalty=coverage_penalty)

            # Reorder the references with the permutation returned by decode
            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]

            for b in range(len(batch['xs'])):
                ##############################
                # Reference
                ##############################
                if dataset.is_test:
                    phone_ref_list = ys[b][0].split(' ')
                    # NOTE: transcript is separated by space(' ')
                else:
                    # Convert from index to phone (-> list of phone strings)
                    phone_ref_list = dataset.idx2phone(
                        ys[b][:y_lens[b]]).split(' ')

                ##############################
                # Hypothesis
                ##############################
                # Convert from index to phone (-> list of phone strings)
                str_hyp = dataset.idx2phone(best_hyps[b])
                # NOTE: Truncate by the first <EOS>. (A greedy regex would
                # cut at the last ' >' and keep everything emitted between
                # the first and the last EOS.)
                str_hyp = str_hyp.split(' >')[0]

                phone_hyp_list = str_hyp.split(' ')

                # Mapping to 39 phones (-> list of phone strings)
                if dataset.label_type != 'phone39':
                    phone_ref_list = map2phone39(phone_ref_list)
                    phone_hyp_list = map2phone39(phone_hyp_list)

                # Compute PER
                try:
                    per_b, sub_b, ins_b, del_b = compute_wer(
                        ref=phone_ref_list,
                        hyp=phone_hyp_list,
                        normalize=False)
                except Exception:
                    # Best effort: an utterance that cannot be scored is
                    # excluded from the statistics, but it is reported
                    # instead of being dropped silently.
                    logging.getLogger(__name__).warning(
                        'eval_phone: skipped an utterance that could not '
                        'be scored', exc_info=True)
                else:
                    per += per_b
                    sub += sub_b
                    ins += ins_b
                    dele += del_b
                    num_phones += len(phone_ref_list)

                if pbar is not None:
                    pbar.update(1)

            if is_new_epoch:
                break
    finally:
        # Release the progress bar even if decoding fails midway
        if pbar is not None:
            pbar.close()

    # Reset data counters
    dataset.reset()

    # Guard against an empty evaluation set; float() keeps true division
    # on Python 2 as well.
    denom = float(max(num_phones, 1))
    per /= denom
    sub /= denom
    ins /= denom
    dele /= denom

    df_phone = pd.DataFrame(
        {
            'SUB': [sub * 100],
            'INS': [ins * 100],
            'DEL': [dele * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['PER'])

    return per, df_phone
def decode(model, dataset, beam_width, eval_batch_size=None, save_path=None,
           glm_path='/n/sd8/inaguma/corpus/swbd/data/eval2000/LDC2002T43/reference/en20000405_hub5.glm'):
    """Visualize label outputs.

    Prints the reference, the hypothesis and the per-utterance WER for
    every utterance of `dataset`.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width: (int): the size of beam
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string): if not None, the output is written to
            `decode.txt` under `model.model_dir` instead of the console.
            NOTE(review): the value itself is not used as the destination -
            confirm whether that is intended.
        glm_path (string, optional): path to the GLM file used to normalize
            the transcripts
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    if 'char' in dataset.label_type:
        map_fn = Idx2char(
            dataset.vocab_file_path,
            capital_divide=dataset.label_type == 'character_capital_divide')
        max_decode_len = MAX_DECODE_LEN_CHAR
    else:
        map_fn = Idx2word(dataset.vocab_file_path)
        max_decode_len = MAX_DECODE_LEN_WORD

    # Read GLM file
    glm = GLM(glm_path=glm_path)

    # Redirect stdout to a file when requested; it is restored (and the
    # file closed) in the `finally` clause below.
    orig_stdout = sys.stdout
    out_file = None
    if save_path is not None:
        out_file = open(join(model.model_dir, 'decode.txt'), 'w')
        sys.stdout = out_file

    try:
        # CTC outputs are only decoded for joint CTC-attention models
        use_ctc = model.model_type == 'attention' and model.ctc_loss_weight > 0

        for batch, is_new_epoch in dataset:

            # Decode
            if model.model_type == 'nested_attention':
                best_hyps, _, _, _, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len,
                    max_decode_len_sub=max_decode_len)
            else:
                best_hyps, _, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len)

            if use_ctc:
                best_hyps_ctc, perm_idx = model.decode_ctc(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width)

            # Reorder the references with the permutation returned by decode
            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]

            for b in range(len(batch['xs'])):
                ##############################
                # Reference
                ##############################
                if dataset.is_test:
                    str_ref_original = ys[b][0]
                    # NOTE: transcript is separated by space('_')
                else:
                    # Convert from list of index to string
                    str_ref_original = map_fn(ys[b][:y_lens[b]])

                ##############################
                # Hypothesis
                ##############################
                # Convert from list of index to string
                str_hyp = map_fn(best_hyps[b])

                ##############################
                # Post-processing
                ##############################
                str_ref = fix_trans(str_ref_original, glm)
                str_hyp = fix_trans(str_hyp, glm)

                if len(str_ref) == 0:
                    continue

                print('----- wav: %s -----' % batch['input_names'][b])
                print('Ref: %s' % str_ref.replace('_', ' '))
                print('Hyp: %s' % str_hyp.replace('_', ' '))
                if use_ctc:
                    str_hyp_ctc = map_fn(best_hyps_ctc[b])
                    print('Hyp (CTC): %s' % str_hyp_ctc)

                try:
                    # Drop the first '>' (EOS) and everything after it,
                    # plus the word boundary directly before it.
                    # str.replace() takes its pattern literally, so a
                    # regex-looking pattern would never truncate anything.
                    hyp_head, eos, _ = str_hyp.partition('>')
                    if eos and hyp_head.endswith('_'):
                        hyp_head = hyp_head[:-1]

                    # Compute WER
                    wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                               hyp=hyp_head.split('_'),
                                               normalize=True)
                    print('WER: %.3f %%' % (wer * 100))
                    if use_ctc:
                        wer_ctc, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=str_hyp_ctc.split('_'),
                            normalize=True)
                        print('WER (CTC): %.3f %%' % (wer_ctc * 100))
                except Exception:
                    # Best effort: an utterance that cannot be scored is
                    # reported and skipped.
                    print('--- skipped ---')

            if is_new_epoch:
                break
    finally:
        if out_file is not None:
            sys.stdout = orig_stdout
            out_file.close()
def do_eval_cer(save_paths, dataset, data_type, label_type, num_classes,
                beam_width, temperature_infer,
                is_test=False, progressbar=False):
    """Evaluate an ensemble of trained models by Character Error Rate.

    For every utterance the posteriors saved by each model are loaded from
    disk, averaged, and decoded with a beam search decoder.

    Args:
        save_paths (list): root directories of the models whose saved
            posteriors are averaged
        dataset: An instance of a `Dataset` class
        data_type (string): sub-directory name used in the posterior path
        label_type (string): character
        num_classes (int): the last class index (`num_classes - 1`) is used
            as the blank index
        beam_width (int): the size of beam
        temperature_infer (int): temperature in the inference stage (used
            in the posterior path)
        is_test (bool, optional): set to True when evaluating by the test set
        progressbar (bool, optional): if True, visualize the progressbar
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    Raises:
        TypeError: if `label_type` is not 'character'
    """
    if label_type != 'character':
        raise TypeError(
            "label_type must be 'character', got %r" % (label_type,))

    map_file_path = '../metrics/mapping_files/character.txt'
    idx2char = Idx2char(map_file_path=map_file_path)
    char2idx = Char2idx(map_file_path=map_file_path)

    # Define decoder
    decoder = BeamSearchDecoder(space_index=char2idx(str_char='_')[0],
                                blank_index=num_classes - 1)

    ##################################################
    # Compute mean probabilities
    ##################################################
    pbar = tqdm(total=len(dataset)) if progressbar else None
    cer_mean, wer_mean = 0, 0
    try:
        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini batch
            inputs, labels_true, inputs_seq_len, input_names = data

            batch_size = inputs[0].shape[0]
            for i_batch in range(batch_size):
                utt_name = input_names[0][i_batch]
                speaker = utt_name.split('-')[0]

                probs_ensemble = None
                for model_path in save_paths:

                    # Load posteriors
                    prob_save_path = join(
                        model_path, 'temp' + str(temperature_infer),
                        data_type, 'probs_utt', speaker,
                        utt_name + '.npy')
                    probs_model_i = np.load(prob_save_path)
                    # NOTE: probs_model_i: `[T, num_classes]`

                    # Sum over probs
                    if probs_ensemble is None:
                        probs_ensemble = probs_model_i
                    else:
                        probs_ensemble += probs_model_i

                # Compute mean posteriors
                probs_ensemble /= len(save_paths)

                # Decode per utterance
                labels_pred, scores = decoder(
                    probs=probs_ensemble[np.newaxis, :, :],
                    seq_len=inputs_seq_len[0][i_batch: i_batch + 1],
                    beam_width=beam_width)

                # Convert from list of index to string
                if is_test:
                    str_true = labels_true[0][i_batch][0]
                    # NOTE: transcript is separated by space('_')
                else:
                    str_true = idx2char(labels_true[0][i_batch],
                                        padded_value=dataset.padded_value)
                str_pred = idx2char(labels_pred[0])

                # Remove consecutive spaces
                str_pred = re.sub(r'[_]+', '_', str_pred)

                # Remove garbage labels
                str_true = re.sub(r'[\']+', '', str_true)
                str_pred = re.sub(r'[\']+', '', str_pred)

                # Compute WER
                # The ground truth is the reference and the prediction is
                # the hypothesis (normalization divides by the reference).
                wer_result = compute_wer(ref=str_true.split('_'),
                                         hyp=str_pred.split('_'),
                                         normalize=True)
                # NOTE(review): other callers in this file unpack a 4-tuple
                # (wer, sub, ins, del) from compute_wer; accept both forms.
                if isinstance(wer_result, tuple):
                    wer_result = wer_result[0]
                wer_mean += wer_result

                # Remove spaces
                str_true = re.sub(r'[_]+', '', str_true)
                str_pred = re.sub(r'[_]+', '', str_pred)

                # Compute CER
                cer_mean += compute_cer(str_pred=str_pred,
                                        str_true=str_true,
                                        normalize=True)

                if pbar is not None:
                    pbar.update(1)

            if is_new_epoch:
                break
    finally:
        # Release the progress bar even if loading/decoding fails midway
        if pbar is not None:
            pbar.close()

    cer_mean /= (len(dataset))
    wer_mean /= (len(dataset))
    # TODO: Fix this

    return cer_mean, wer_mean
def eval_char(models,
              eval_batch_size,
              dataset,
              beam_width,
              max_decode_len,
              min_decode_len=0,
              length_penalty=0,
              coverage_penalty=0,
              progressbar=False):
    """Evaluate trained model by Character Error Rate.

    Word- and character-level edit statistics are accumulated over every
    utterance and divided once by the total number of reference words /
    characters. The dataset counter is reset before and after the
    evaluation.

    Args:
        models (list): the models to evaluate (only the first one is
            currently used)
        dataset: An instance of a `Dataset' class
        eval_batch_size (int): the batch size when evaluating the model
        beam_width: (int): the size of beam
        max_decode_len (int): the maximum sequence length to emit
        min_decode_len (int): the minimum sequence length to emit
        length_penalty (float): length penalty in beam search decoding
        coverage_penalty (float): coverage penalty in beam search decoding
        progressbar (bool): if True, visualize the progressbar
    Returns:
        wer (float): Word error rate. 0.0 when no reference word was scored.
        cer (float): Character error rate. 0.0 when no reference character
            was scored.
        df_word (pd.DataFrame): dataframe of substitution, insertion, and
            deletion rates in percent (rows 'WER' and 'CER')
    """
    import logging

    # Reset data counter
    dataset.reset()

    model = models[0]
    # TODO: fix this

    # Single-task models emit characters from the main task (index 0);
    # every other model type emits them from the sub task (index 1).
    if model.model_type in ['ctc', 'attention']:
        task_index, ys_key, y_lens_key = 0, 'ys', 'y_lens'
    else:
        task_index, ys_key, y_lens_key = 1, 'ys_sub', 'y_lens_sub'

    wer, cer = 0, 0
    sub_word, ins_word, del_word = 0, 0, 0
    sub_char, ins_char, del_char = 0, 0, 0
    num_words, num_chars = 0, 0
    pbar = tqdm(total=len(dataset)) if progressbar else None  # TODO: fix this
    try:
        while True:
            batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

            # Decode
            best_hyps, _, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                min_decode_len=min_decode_len,
                length_penalty=length_penalty,
                coverage_penalty=coverage_penalty,
                task_index=task_index)
            # Reorder the references with the permutation returned by decode
            ys = batch[ys_key][perm_idx]
            y_lens = batch[y_lens_key][perm_idx]

            for b in range(len(batch['xs'])):
                ##############################
                # Reference
                ##############################
                if dataset.is_test:
                    str_ref = ys[b][0]
                    # NOTE: transcript is separated by space('_')
                else:
                    # Convert from list of index to string
                    str_ref = dataset.idx2char(ys[b][:y_lens[b]])

                ##############################
                # Hypothesis
                ##############################
                str_hyp = dataset.idx2char(best_hyps[b])
                # NOTE: Truncate by the first <EOS>. (A greedy regex would
                # cut at the last '>' and keep everything emitted between
                # the first and the last EOS.)
                str_hyp = str_hyp.split('>')[0]

                ##############################
                # Post-processing
                ##############################
                # Remove garbage labels
                str_ref = re.sub(r'[@>]+', '', str_ref)
                str_hyp = re.sub(r'[@>]+', '', str_hyp)
                # NOTE: @ means noise

                # Remove consecutive spaces
                str_ref = re.sub(r'[_]+', '_', str_ref)
                str_hyp = re.sub(r'[_]+', '_', str_hyp)

                try:
                    # Compute WER
                    ref_words = str_ref.split('_')
                    wer_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref_words,
                        hyp=str_hyp.split('_'),
                        normalize=False)
                    wer += wer_b
                    sub_word += sub_b
                    ins_word += ins_b
                    del_word += del_b
                    num_words += len(ref_words)

                    # Compute CER
                    ref_chars = list(str_ref.replace('_', ''))
                    cer_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref_chars,
                        hyp=list(str_hyp.replace('_', '')),
                        normalize=False)
                    cer += cer_b
                    sub_char += sub_b
                    ins_char += ins_b
                    del_char += del_b
                    num_chars += len(ref_chars)
                except Exception:
                    # Best effort: an utterance that cannot be scored is
                    # excluded from the statistics, but it is reported
                    # instead of being dropped silently.
                    logging.getLogger(__name__).warning(
                        'eval_char: skipped an utterance that could not '
                        'be scored', exc_info=True)

                if pbar is not None:
                    pbar.update(1)

            if is_new_epoch:
                break
    finally:
        # Release the progress bar even if decoding fails midway
        if pbar is not None:
            pbar.close()

    # Reset data counters
    dataset.reset()

    # Guard against an empty evaluation set; float() keeps true division
    # on Python 2 as well.
    word_denom = float(max(num_words, 1))
    char_denom = float(max(num_chars, 1))
    wer /= word_denom
    sub_word /= word_denom
    ins_word /= word_denom
    del_word /= word_denom
    cer /= char_denom
    sub_char /= char_denom
    ins_char /= char_denom
    del_char /= char_denom

    df_word = pd.DataFrame(
        {
            'SUB': [sub_word * 100, sub_char * 100],
            'INS': [ins_word * 100, ins_char * 100],
            'DEL': [del_word * 100, del_char * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['WER', 'CER'])

    return wer, cer, df_word
def main():
    """Decode the 'eval1' set and print hypotheses and error rates.

    Loads the model configuration and checkpoint given on the command
    line, decodes every utterance with the main (word) and sub (character)
    outputs - optionally also with joint decoding and UNK resolution - and
    prints the reference, the hypotheses and per-utterance WER/CER.
    """
    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Load dataset
    # NOTE: set data_type to 'eval2' or 'eval3' to decode the other sets
    dataset = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        data_type='eval1',
        data_size=params['data_size'],
        label_type=params['label_type'],
        label_type_sub=params['label_type_sub'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        sort_utt=False,
        reverse=False,
        tool=params['tool'])

    params['num_classes'] = dataset.num_classes
    params['num_classes_sub'] = dataset.num_classes_sub

    # Load model
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    # Restore the saved parameters
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    ######################################################################

    word2char = Word2char(dataset.vocab_file_path, dataset.vocab_file_path_sub)

    # Joint decoding is only run for hierarchical attention models
    use_joint = (model.model_type == 'hierarchical_attention' and
                 args.joint_decoding is not None)

    for batch, is_new_epoch in dataset:
        # Decode
        if model.model_type == 'nested_attention':
            best_hyps, aw, best_hyps_sub, aw_sub, _, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width,
                beam_width_sub=args.beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_WORD,
                max_decode_len_sub=MAX_DECODE_LEN_CHAR,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty)
        else:
            best_hyps, aw, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD,
                min_decode_len=MIN_DECODE_LEN_WORD,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty)
            best_hyps_sub, aw_sub, _ = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_CHAR,
                min_decode_len=MIN_DECODE_LEN_CHAR,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty,
                task_index=1)

        if use_joint:
            best_hyps_joint, aw_joint, best_hyps_sub_joint, aw_sub_joint, _ = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD,
                min_decode_len=MIN_DECODE_LEN_WORD,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty,
                joint_decoding=args.joint_decoding,
                space_index=dataset.char2idx('_')[0],
                oov_index=dataset.word2idx('OOV')[0],
                word2char=word2char,
                idx2word=dataset.idx2word,
                idx2char=dataset.idx2char,
                score_sub_weight=args.score_sub_weight)

        # Reorder the references with the permutation returned by decode
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]
        ys_sub = batch['ys_sub'][perm_idx]
        y_lens_sub = batch['y_lens_sub'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                str_ref_sub = ys_sub[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = dataset.idx2word(ys[b][:y_lens[b]])
                str_ref_sub = dataset.idx2char(ys_sub[b][:y_lens_sub[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = dataset.idx2word(best_hyps[b])
            str_hyp_sub = dataset.idx2char(best_hyps_sub[b])
            if use_joint:
                str_hyp_joint = dataset.idx2word(best_hyps_joint[b])
                str_hyp_sub_joint = dataset.idx2char(best_hyps_sub_joint[b])

            ##############################
            # Resolving UNK
            ##############################
            resolve_main = 'OOV' in str_hyp and args.resolving_unk
            if resolve_main:
                str_hyp_no_unk = resolve_unk(str_hyp, best_hyps_sub[b], aw[b],
                                             aw_sub[b], dataset.idx2char)
            resolve_joint = (use_joint and 'OOV' in str_hyp_joint and
                             args.resolving_unk)
            if resolve_joint:
                str_hyp_no_unk_joint = resolve_unk(str_hyp_joint,
                                                   best_hyps_sub_joint[b],
                                                   aw_joint[b],
                                                   aw_sub_joint[b],
                                                   dataset.idx2char)

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref         : %s' % str_ref.replace('_', ' '))
            print('Hyp (main)  : %s' % str_hyp.replace('_', ' '))
            print('Hyp (sub)   : %s' % str_hyp_sub.replace('_', ' '))
            if resolve_main:
                print('Hyp (no UNK): %s' % str_hyp_no_unk.replace('_', ' '))
            if use_joint:
                print('===== joint decoding =====')
                print('Hyp (main)  : %s' % str_hyp_joint.replace('_', ' '))
                print('Hyp (sub)   : %s' % str_hyp_sub_joint.replace('_', ' '))
                if resolve_joint:
                    print('Hyp (no UNK): %s' %
                          str_hyp_no_unk_joint.replace('_', ' '))

            # Hypotheses are truncated at the FIRST EOS ('>') before
            # scoring; a greedy regex would cut at the last one and keep
            # whatever was emitted between the first and the last EOS.
            try:
                wer, _, _, _ = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=str_hyp.split('_>')[0].split('_'),
                    normalize=True)
                print('WER (main)  : %.3f %%' % (wer * 100))
                if dataset.label_type_sub == 'character_wb':
                    wer_sub, _, _, _ = compute_wer(
                        ref=str_ref_sub.split('_'),
                        hyp=str_hyp_sub.split('>')[0].split('_'),
                        normalize=True)
                    print('WER (sub)   : %.3f %%' % (wer_sub * 100))
                else:
                    cer, _, _, _ = compute_wer(
                        ref=list(str_ref_sub.replace('_', '')),
                        hyp=list(
                            str_hyp_sub.split('>')[0].replace('_', '')),
                        normalize=True)
                    print('CER (sub)   : %.3f %%' % (cer * 100))
                if resolve_main:
                    wer_no_unk, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=str_hyp_no_unk.replace(
                            '*', '').split('_>')[0].split('_'),
                        normalize=True)
                    print('WER (no UNK): %.3f %%' % (wer_no_unk * 100))

                if use_joint:
                    print('===== joint decoding =====')
                    wer_joint, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=str_hyp_joint.split('_>')[0].split('_'),
                        normalize=True)
                    print('WER (main)  : %.3f %%' % (wer_joint * 100))
                    if resolve_joint:
                        wer_no_unk_joint, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=str_hyp_no_unk_joint.replace(
                                '*', '').split('_>')[0].split('_'),
                            normalize=True)
                        print('WER (no UNK): %.3f %%' %
                              (wer_no_unk_joint * 100))

            except Exception:
                # Best effort: an utterance that cannot be scored is
                # reported and skipped (Ctrl-C still interrupts the run).
                print('--- skipped ---')
            print('\n')

        if is_new_epoch:
            break
def do_eval_per(session, decode_op, per_op, model, dataset, label_type,
                eval_batch_size=None, progressbar=False,
                is_multitask=False, is_jointctcatt=False):
    """Evaluate trained model by Phone Error Rate.

    Hypotheses are converted with the mapping file of the label set the
    model was trained on (`label_type`); references are converted with the
    mapping file of the label set of `dataset`. Both sides are then mapped
    to 39 phones before scoring.

    Args:
        session: session of training model
        decode_op: operation for decoding
        per_op: operation for computing phone error rate
            (not referenced in this function)
        model: the model to evaluate
        dataset: An instance of a `Dataset' class
        label_type (string): phone39 or phone48 or phone61
        eval_batch_size (int, optional): the batch size when evaluating the
            model (not referenced in this function)
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
        is_jointctcatt (bool, optional): if True, evaluate the joint
            CTC-Attention model
    Returns:
        per_mean (float): An average of PER
    """
    # Reset data counter
    dataset.reset()

    train_label_type = label_type
    eval_label_type = dataset.label_type_sub if is_multitask else dataset.label_type

    idx2phone_train = Idx2phone(
        map_file_path='../metrics/mapping_files/attention/' + train_label_type + '.txt')
    idx2phone_eval = Idx2phone(
        map_file_path='../metrics/mapping_files/attention/' + eval_label_type + '.txt')
    map2phone39_train = Map2phone39(
        label_type=train_label_type,
        map_file_path='../metrics/mapping_files/phone2phone.txt')
    map2phone39_eval = Map2phone39(
        label_type=eval_label_type,
        map_file_path='../metrics/mapping_files/phone2phone.txt')

    per_mean = 0
    pbar = tqdm(total=len(dataset)) if progressbar else None
    for data, is_new_epoch in dataset:

        # The position of the reference labels inside the mini-batch tuple
        # depends on the kind of model being evaluated
        if is_multitask:
            inputs, _, labels_true, inputs_seq_len, _, _ = data
        elif is_jointctcatt:
            inputs, labels_true, _, inputs_seq_len, _, _ = data
        else:
            inputs, labels_true, inputs_seq_len, _, _ = data

        # Create feed dictionary for next mini-batch
        # (all keep probabilities are 1.0, i.e. dropout is disabled)
        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0
        }

        batch_size_each = len(inputs_seq_len)

        # Evaluate by 39 phones
        labels_pred = session.run(decode_op, feed_dict=feed_dict)

        for i_batch in range(batch_size_each):
            ###############
            # Hypothesis
            ###############
            # Convert from num to phone (-> list of phone strings)
            phone_pred_list = idx2phone_train(labels_pred[i_batch]).split(' ')

            # Mapping to 39 phones (-> list of phone strings)
            phone_pred_list = map2phone39_train(phone_pred_list)

            ###############
            # Reference
            ###############
            # Convert from num to phone (-> list of phone strings)
            phone_true_list = idx2phone_eval(labels_true[i_batch]).split(' ')

            # Mapping to 39 phones (-> list of phone strings)
            phone_true_list = map2phone39_eval(phone_true_list)

            # Compute PER
            per_mean += compute_wer(str_pred=' '.join(phone_pred_list),
                                    str_true=' '.join(phone_true_list),
                                    normalize=True,
                                    space_mark=' ')

            if pbar is not None:
                pbar.update(1)

        if is_new_epoch:
            break

    # Release the progress bar so that it does not linger after evaluation
    if pbar is not None:
        pbar.close()

    per_mean /= len(dataset)

    return per_mean
# Example #9
# 0
    def check(self,
              encoder_type,
              decoder_type,
              bidirectional=False,
              attention_type='location',
              subsample=False,
              projection=False,
              ctc_loss_weight_sub=0,
              conv=False,
              batch_norm=False,
              residual=False,
              dense_residual=False,
              num_heads=1,
              backward_sub=False):
        """Train a small hierarchical attention model on one generated batch.

        Builds a `HierarchicalAttentionSeq2seq` (chainer backend) from the
        given options, runs at most 300 update steps on a single batch and,
        every 10 steps, prints the losses and the WER / CER of the first
        utterance. Training stops early once the CER drops below 0.1.

        Args:
            encoder_type (string): forwarded to the model; 'cnn' also
                selects the convolutional settings
            decoder_type (string): forwarded to the model
            bidirectional (bool, optional): passed as `encoder_bidirectional`
            attention_type (string, optional): forwarded to the model
            subsample (optional): if truthy, it is used as `subsample_type`
                and `subsample_list` becomes [True, False]
            projection (bool, optional): if True, `encoder_num_proj` is 320
            ctc_loss_weight_sub (float, optional): forwarded to the model;
                when it is non-zero, `sub_loss_weight` is set to 0
            conv (bool, optional): if True, use the VGG-like conv settings
            batch_norm (bool, optional): forwarded to the model
            residual (bool, optional): passed as both `encoder_residual`
                and `decoder_residual`
            dense_residual (bool, optional): passed as both
                `encoder_dense_residual` and `decoder_dense_residual`
            num_heads (int, optional): passed as `num_heads` and
                `num_heads_sub`
            backward_sub (bool, optional): forwarded to the model
        """

        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  bidirectional: %s' % str(bidirectional))
        print('  projection: %s' % str(projection))
        print('  decoder_type: %s' % decoder_type)
        print('  attention_type: %s' % attention_type)
        print('  subsample: %s' % str(subsample))
        print('  ctc_loss_weight_sub: %s' % str(ctc_loss_weight_sub))
        print('  conv: %s' % str(conv))
        print('  batch_norm: %s' % str(batch_norm))
        print('  residual: %s' % str(residual))
        print('  dense_residual: %s' % str(dense_residual))
        print('  backward_sub: %s' % str(backward_sub))
        print('  num_heads: %s' % str(num_heads))
        print('==================================================')

        if conv or encoder_type == 'cnn':
            # pattern 1
            # conv_channels = [32, 32]
            # conv_kernel_sizes = [[41, 11], [21, 11]]
            # conv_strides = [[2, 2], [2, 1]]
            # poolings = [[], []]

            # pattern 2 (VGG like)
            conv_channels = [64, 64]
            conv_kernel_sizes = [[3, 3], [3, 3]]
            conv_strides = [[1, 1], [1, 1]]
            poolings = [[2, 2], [2, 2]]
        else:
            conv_channels = []
            conv_kernel_sizes = []
            conv_strides = []
            poolings = []

        # Load batch data
        splice = 1
        num_stack = 1 if subsample or conv or encoder_type == 'cnn' else 2
        xs, ys, ys_sub, x_lens, y_lens, y_lens_sub = generate_data(
            label_type='word_char',
            batch_size=2,
            num_stack=num_stack,
            splice=splice,
            backend='chainer')

        num_classes = 11
        num_classes_sub = 27

        # Load model
        model = HierarchicalAttentionSeq2seq(
            input_size=xs[0].shape[-1] // splice // num_stack,  # 120
            encoder_type=encoder_type,
            encoder_bidirectional=bidirectional,
            encoder_num_units=320,
            encoder_num_proj=320 if projection else 0,
            encoder_num_layers=2,
            encoder_num_layers_sub=1,
            attention_type=attention_type,
            attention_dim=128,
            decoder_type=decoder_type,
            decoder_num_units=320,
            decoder_num_layers=1,
            decoder_num_units_sub=320,
            decoder_num_layers_sub=1,
            embedding_dim=64,
            embedding_dim_sub=32,
            dropout_input=0.1,
            dropout_encoder=0.1,
            dropout_decoder=0.1,
            dropout_embedding=0.1,
            main_loss_weight=0.8,
            sub_loss_weight=0.2 if ctc_loss_weight_sub == 0 else 0,
            num_classes=num_classes,
            num_classes_sub=num_classes_sub,
            parameter_init_distribution='uniform',
            parameter_init=0.1,
            recurrent_weight_orthogonal=False,
            init_forget_gate_bias_with_one=True,
            subsample_list=[] if not subsample else [True, False],
            subsample_type='drop' if not subsample else subsample,
            bridge_layer=True,
            init_dec_state='first',
            sharpening_factor=1,
            logits_temperature=1,
            sigmoid_smoothing=False,
            ctc_loss_weight_sub=ctc_loss_weight_sub,
            attention_conv_num_channels=10,
            attention_conv_width=201,
            input_channel=3,
            num_stack=num_stack,
            splice=splice,
            conv_channels=conv_channels,
            conv_kernel_sizes=conv_kernel_sizes,
            conv_strides=conv_strides,
            poolings=poolings,
            activation='relu',
            batch_norm=batch_norm,
            scheduled_sampling_prob=0.1,
            scheduled_sampling_max_step=200,
            label_smoothing_prob=0.1,
            weight_noise_std=0,
            encoder_residual=residual,
            encoder_dense_residual=dense_residual,
            decoder_residual=residual,
            decoder_dense_residual=dense_residual,
            decoding_order='attend_generate_update',
            bottleneck_dim=256,
            bottleneck_dim_sub=256,
            backward_sub=backward_sub,
            num_heads=num_heads,
            num_heads_sub=num_heads)

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            print("%s %d" % (name, num_params))
        print("Total %.3f M parameters" % (model.total_parameters / 1000000))

        # Define optimizer
        learning_rate = 1e-3
        model.set_optimizer('adam',
                            learning_rate_init=learning_rate,
                            weight_decay=1e-6,
                            lr_schedule=False,
                            factor=0.1,
                            patience_epoch=5)

        # Define learning rate controller
        lr_controller = Controller(learning_rate_init=learning_rate,
                                   backend='chainer',
                                   decay_start_epoch=20,
                                   decay_rate=0.9,
                                   decay_patient_epoch=10,
                                   lower_better=True)

        # GPU setting
        model.set_cuda(deterministic=False, benchmark=True)

        # Train model
        max_step = 300
        start_time_step = time.time()
        for step in range(max_step):

            # Step for parameter update
            loss, loss_main, loss_sub = model(xs, ys, x_lens, y_lens, ys_sub,
                                              y_lens_sub)
            model.optimizer.target.cleargrads()
            model.cleargrads()
            loss.backward()
            loss.unchain_backward()
            model.optimizer.update()

            if (step + 1) % 10 == 0:
                # Compute loss
                loss, loss_main, loss_sub = model(xs,
                                                  ys,
                                                  x_lens,
                                                  y_lens,
                                                  ys_sub,
                                                  y_lens_sub,
                                                  is_eval=True)

                # Decode (main task, then sub task via task_index=1)
                best_hyps, _, _ = model.decode(
                    xs,
                    x_lens,
                    beam_width=1,
                    # beam_width=2,
                    max_decode_len=30)
                best_hyps_sub, _, _ = model.decode(
                    xs,
                    x_lens,
                    beam_width=1,
                    # beam_width=2,
                    max_decode_len=60,
                    task_index=1)

                # Only the first utterance is scored; the hypothesis is cut
                # at the first '>' symbol
                str_hyp = idx2word(best_hyps[0][:-1]).split('>')[0]
                str_ref = idx2word(ys[0])
                str_hyp_sub = idx2char(best_hyps_sub[0][:-1]).split('>')[0]
                str_ref_sub = idx2char(ys_sub[0])

                # Compute accuracy
                try:
                    wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                               hyp=str_hyp.split('_'),
                                               normalize=True)
                    cer, _, _, _ = compute_wer(
                        ref=list(str_ref_sub.replace('_', '')),
                        hyp=list(str_hyp_sub.replace('_', '')),
                        normalize=True)
                except Exception:
                    # If scoring fails, count this step as 100% error.
                    # `Exception` (not a bare except) keeps Ctrl-C and
                    # SystemExit working during the training loop.
                    wer = 1
                    cer = 1

                duration_step = time.time() - start_time_step
                print(
                    'Step %d: loss=%.3f(%.3f/%.3f) / wer=%.3f / cer=%.3f / lr=%.5f (%.3f sec)'
                    % (step + 1, loss, loss_main, loss_sub, wer, cer,
                       learning_rate, duration_step))
                start_time_step = time.time()

                # Visualize
                print('Ref: %s' % str_ref)
                print('Hyp (word): %s' % str_hyp)
                print('Hyp (char): %s' % str_hyp_sub)

                if cer < 0.1:
                    print('Modle is Converged.')
                    break

                # Update learning rate
                model.optimizer, learning_rate = lr_controller.decay_lr(
                    optimizer=model.optimizer,
                    learning_rate=learning_rate,
                    epoch=step,
                    value=wer)
def eval_char(models, dataset, beam_width, max_decode_len,
              eval_batch_size=None, length_penalty=0,
              progressbar=False, temperature=1):
    """Evaluate trained model by Character Error Rate.

    When more than one model is given, their posteriors are averaged and
    decoded with the first model (CTC models only). Otherwise the single
    model is decoded directly.

    Args:
        models (list): the models to evaluate
        dataset: An instance of a `Dataset' class
        beam_width: (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction when EOS token have not been emitted.
            This is used for seq2seq models.
        eval_batch_size (int, optional): the batch size when evaluating the model
        length_penalty (float, optional): forwarded to `model.decode`
            (single-model case only)
        progressbar (bool, optional): if True, visualize the progressbar
        temperature (int, optional): forwarded to `model.posteriors`
            (ensemble case only)
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate
        df_wer_cer (pd.DataFrame): dataframe of substitution, insertion, and deletion
    Raises:
        ZeroDivisionError: if no word or character was accumulated
    """
    # Reset data counter
    dataset.reset()

    # Character vocabulary: main task for ctc/attention models, sub task otherwise
    if models[0].model_type in ['ctc', 'attention']:
        idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path)
    else:
        idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path_sub)

    wer, cer = 0, 0
    sub_word, ins_word, del_word = 0, 0, 0
    sub_char, ins_char, del_char = 0, 0, 0
    num_words, num_chars = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        # Decode
        if len(models) > 1:
            assert models[0].model_type in ['ctc']
            # Average the posteriors of all models
            # (x_lens and perm_idx are taken from the last model)
            for i, model in enumerate(models):
                probs, x_lens, perm_idx = model.posteriors(
                    batch['xs'], batch['x_lens'], temperature=temperature)
                if i == 0:
                    probs_ensemble = probs
                else:
                    probs_ensemble += probs
            probs_ensemble /= len(models)

            best_hyps = models[0].decode_from_probs(
                probs_ensemble, x_lens, beam_width=beam_width)
        else:
            model = models[0]
            # TODO: fix this

            best_hyps, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                length_penalty=length_penalty,
                task_index=0 if model.model_type in ['ctc', 'attention'] else 1)

        # Reorder the references in the same way as the decoder outputs
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = idx2char(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = idx2char(best_hyps[b])
            # NOTE: Truncate by the first <EOS>.
            # A greedy regex such as r'(.*)>(.*)' would cut at the LAST '>'
            # and keep everything emitted between the first and last <EOS>.
            str_hyp = str_hyp.split('>')[0]

            ##############################
            # Post-processing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[@>]+', '', str_ref)
            str_hyp = re.sub(r'[@>]+', '', str_hyp)
            # NOTE: @ means noise

            # Remove consecutive spaces
            str_ref = re.sub(r'[_]+', '_', str_ref)
            str_hyp = re.sub(r'[_]+', '_', str_hyp)

            try:
                # Compute WER
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=str_hyp.split('_'),
                    normalize=False)
                wer += wer_b
                sub_word += sub_b
                ins_word += ins_b
                del_word += del_b
                num_words += len(str_ref.split('_'))

                # Compute CER
                cer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=list(str_ref.replace('_', '')),
                    hyp=list(str_hyp.replace('_', '')),
                    normalize=False)
                cer += cer_b
                sub_char += sub_b
                ins_char += ins_b
                del_char += del_b
                num_chars += len(str_ref.replace('_', ''))
            except Exception:
                # Best effort: an utterance that cannot be scored is left out
                # of the statistics. `Exception` (not a bare except) keeps
                # Ctrl-C and SystemExit working during long evaluations.
                pass

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Errors were accumulated un-normalized, so normalize by the total
    # number of reference words / characters here
    wer /= num_words
    sub_word /= num_words
    ins_word /= num_words
    del_word /= num_words
    cer /= num_chars
    sub_char /= num_chars
    ins_char /= num_chars
    del_char /= num_chars

    df_wer_cer = pd.DataFrame(
        {'SUB': [sub_word * 100, sub_char * 100],
         'INS': [ins_word * 100, ins_char * 100],
         'DEL': [del_word * 100, del_char * 100]},
        columns=['SUB', 'INS', 'DEL'], index=['WER', 'CER'])

    return wer, cer, df_wer_cer
# Example #11
# 0
    def check(self, usage_dec_sub='all', att_reg_weight=1,
              main_loss_weight=0.5, ctc_loss_weight_sub=0,
              dec_attend_temperature=1,
              dec_sigmoid_smoothing=False,
              backward_sub=False, num_heads=1, second_pass=False,
              relax_context_vec_dec=False):
        """Train a small nested attention model on one generated batch.

        Builds a `NestedAttentionSeq2seq` (pytorch backend), runs at most
        300 update steps on a single batch and, every 10 steps, prints the
        loss and the WER / CER of the first utterance. Training stops early
        once the value stored in `cer` drops below 0.1.

        Args:
            usage_dec_sub (string, optional): forwarded to the model
            att_reg_weight (float, optional): forwarded to the model
            main_loss_weight (float, optional): only printed
                NOTE(review): the model is always built with
                main_loss_weight=0.8 -- confirm whether this is intended
            ctc_loss_weight_sub (float, optional): forwarded to the model;
                when it is non-zero, `sub_loss_weight` is set to 0
            dec_attend_temperature (float, optional): forwarded to the model
            dec_sigmoid_smoothing (bool, optional): forwarded to the model
            backward_sub (bool, optional): forwarded to the model
            num_heads (int, optional): passed as `num_heads`,
                `num_heads_sub` and `num_heads_dec`
            second_pass (bool, optional): if True, the sub task uses the
                word vocabulary size (11 classes), the model is called
                without sub labels and returns a single loss, and the sub
                hypothesis is scored against the word-level reference
            relax_context_vec_dec (bool, optional): forwarded to the model
        """

        print('==================================================')
        print('  usage_dec_sub: %s' % usage_dec_sub)
        print('  att_reg_weight: %s' % str(att_reg_weight))
        print('  main_loss_weight: %s' % str(main_loss_weight))
        print('  ctc_loss_weight_sub: %s' % str(ctc_loss_weight_sub))
        print('  dec_attend_temperature: %s' % str(dec_attend_temperature))
        print('  dec_sigmoid_smoothing: %s' % str(dec_sigmoid_smoothing))
        print('  backward_sub: %s' % str(backward_sub))
        print('  num_heads: %s' % str(num_heads))
        print('  second_pass: %s' % str(second_pass))
        print('  relax_context_vec_dec: %s' % str(relax_context_vec_dec))
        print('==================================================')

        # Load batch data
        splice = 1
        num_stack = 1
        xs, ys, ys_sub, x_lens, y_lens, y_lens_sub = generate_data(
            label_type='word_char',
            batch_size=2,
            num_stack=num_stack,
            splice=splice)

        # Load model
        model = NestedAttentionSeq2seq(
            input_size=xs.shape[-1] // splice // num_stack,  # 120
            encoder_type='lstm',
            encoder_bidirectional=True,
            encoder_num_units=256,
            encoder_num_proj=0,
            encoder_num_layers=2,
            encoder_num_layers_sub=2,
            attention_type='location',
            attention_dim=128,
            decoder_type='lstm',
            decoder_num_units=256,
            decoder_num_layers=1,
            decoder_num_units_sub=256,
            decoder_num_layers_sub=1,
            embedding_dim=64,
            embedding_dim_sub=32,
            dropout_input=0.1,
            dropout_encoder=0.1,
            dropout_decoder=0.1,
            dropout_embedding=0.1,
            main_loss_weight=0.8,
            sub_loss_weight=0.2 if ctc_loss_weight_sub == 0 else 0,
            num_classes=11,
            num_classes_sub=27 if not second_pass else 11,
            parameter_init_distribution='uniform',
            parameter_init=0.1,
            recurrent_weight_orthogonal=False,
            init_forget_gate_bias_with_one=True,
            subsample_list=[True, False],
            subsample_type='drop',
            init_dec_state='first',
            sharpening_factor=1,
            logits_temperature=1,
            sigmoid_smoothing=False,
            ctc_loss_weight_sub=ctc_loss_weight_sub,
            attention_conv_num_channels=10,
            attention_conv_width=201,
            num_stack=num_stack,
            splice=splice,
            conv_channels=[],
            conv_kernel_sizes=[],
            conv_strides=[],
            poolings=[],
            batch_norm=False,
            scheduled_sampling_prob=0.1,
            scheduled_sampling_max_step=200,
            label_smoothing_prob=0.1,
            weight_noise_std=0,
            encoder_residual=False,
            encoder_dense_residual=False,
            decoder_residual=False,
            decoder_dense_residual=False,
            decoding_order='attend_generate_update',
            # decoding_order='attend_update_generate',
            # decoding_order='conditional',
            bottleneck_dim=256,
            bottleneck_dim_sub=256,
            backward_sub=backward_sub,
            num_heads=num_heads,
            num_heads_sub=num_heads,
            num_heads_dec=num_heads,
            usage_dec_sub=usage_dec_sub,
            att_reg_weight=att_reg_weight,
            dec_attend_temperature=dec_attend_temperature,
            # Pass the smoothing flag itself (the temperature used to be
            # passed here by mistake, so the flag never reached the model)
            dec_sigmoid_smoothing=dec_sigmoid_smoothing,
            relax_context_vec_dec=relax_context_vec_dec,
            dec_attention_type='location')

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            print("%s %d" % (name, num_params))
        print("Total %.3f M parameters" % (model.total_parameters / 1000000))

        # Define optimizer
        learning_rate = 1e-3
        model.set_optimizer('adam',
                            learning_rate_init=learning_rate,
                            weight_decay=1e-6,
                            lr_schedule=False,
                            factor=0.1,
                            patience_epoch=5)

        # Define learning rate controller
        lr_controller = Controller(learning_rate_init=learning_rate,
                                   backend='pytorch',
                                   decay_start_epoch=20,
                                   decay_rate=0.9,
                                   decay_patient_epoch=10,
                                   lower_better=True)

        # GPU setting
        model.set_cuda(deterministic=False, benchmark=True)

        # Train model
        max_step = 300
        start_time_step = time.time()
        for step in range(max_step):

            # Step for parameter update
            model.optimizer.zero_grad()
            if second_pass:
                loss = model(xs, ys, x_lens, y_lens)
            else:
                loss, loss_main, loss_sub = model(
                    xs, ys, x_lens, y_lens, ys_sub, y_lens_sub)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            model.optimizer.step()

            if (step + 1) % 10 == 0:
                # Compute loss
                if second_pass:
                    loss = model(xs, ys, x_lens, y_lens, is_eval=True)
                else:
                    loss, loss_main, loss_sub = model(
                        xs, ys, x_lens, y_lens, ys_sub, y_lens_sub, is_eval=True)

                best_hyps, _, best_hyps_sub, _, perm_idx = model.decode(
                    xs, x_lens, beam_width=1,
                    max_decode_len=30,
                    max_decode_len_sub=60)

                # Only the first utterance is scored
                str_hyp = idx2word(best_hyps[0][:-1])
                str_ref = idx2word(ys[0])
                if second_pass:
                    str_hyp_sub = idx2word(best_hyps_sub[0][:-1])
                    str_ref_sub = idx2word(ys[0])
                else:
                    str_hyp_sub = idx2char(best_hyps_sub[0][:-1])
                    str_ref_sub = idx2char(ys_sub[0])

                # Compute accuracy
                try:
                    wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                               hyp=str_hyp.split('_'),
                                               normalize=True)
                    if second_pass:
                        # Word-level score of the sub hypothesis
                        # (still reported under the name `cer`)
                        cer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                                   hyp=str_hyp_sub.split('_'),
                                                   normalize=True)
                    else:
                        cer, _, _, _ = compute_wer(
                            ref=list(str_ref_sub.replace('_', '')),
                            hyp=list(str_hyp_sub.replace('_', '')),
                            normalize=True)
                except Exception:
                    # If scoring fails, count this step as 100% error.
                    # `Exception` (not a bare except) keeps Ctrl-C and
                    # SystemExit working during the training loop.
                    wer = 1
                    cer = 1

                duration_step = time.time() - start_time_step
                if second_pass:
                    print('Step %d: loss=%.3f / wer=%.3f / cer=%.3f / lr=%.5f (%.3f sec)' %
                          (step + 1, loss, wer, cer, learning_rate, duration_step))
                else:
                    print('Step %d: loss=%.3f(%.3f/%.3f) / wer=%.3f / cer=%.3f / lr=%.5f (%.3f sec)' %
                          (step + 1, loss, loss_main, loss_sub,
                           wer, cer, learning_rate, duration_step))

                start_time_step = time.time()

                # Visualize
                print('Ref: %s' % str_ref)
                print('Hyp (word): %s' % str_hyp)
                print('Hyp (char): %s' % str_hyp_sub)

                if cer < 0.1:
                    print('Modle is Converged.')
                    break

                # Update learning rate
                model.optimizer, learning_rate = lr_controller.decay_lr(
                    optimizer=model.optimizer,
                    learning_rate=learning_rate,
                    epoch=step,
                    value=wer)
# Example #12
# 0
def decode(model,
           dataset,
           eval_batch_size,
           beam_width,
           length_penalty,
           save_path=None):
    """Visualize label outputs.

    Prints, for every utterance, the reference, the hypothesis and the
    resulting PER. For attention models with a positive `ctc_loss_weight`
    the CTC hypothesis and its PER are printed as well.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model
            (not referenced in this function)
        beam_width: (int): the size of beam
        length_penalty (float): forwarded to `model.decode`
        save_path (string): if not None, the printed output is redirected
            to `decode.txt` under `model.model_dir`
            NOTE(review): the value of save_path itself is not used as the
            output location -- confirm whether this is intended
    """
    idx2phone = Idx2phone(dataset.vocab_file_path)

    # Redirect stdout only for the duration of this call, so that the file
    # is closed and later prints of the caller go to the console again
    stdout_original = sys.stdout
    out_file = None
    if save_path is not None:
        out_file = open(join(model.model_dir, 'decode.txt'), 'w')
        sys.stdout = out_file

    try:
        for batch, is_new_epoch in dataset:

            # Decode
            best_hyps, _, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=MAX_DECODE_LEN_PHONE,
                length_penalty=length_penalty)
            if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                # NOTE: perm_idx is overwritten by the one from decode_ctc
                best_hyps_ctc, perm_idx = model.decode_ctc(batch['xs'],
                                                           batch['x_lens'],
                                                           beam_width=beam_width)

            # Reorder the references in the same way as the decoder outputs
            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]

            for b in range(len(batch['xs'])):
                ##############################
                # Reference
                ##############################
                if dataset.is_test:
                    str_ref = ys[b][0]
                    # NOTE: transcript is separated by space(' ')
                else:
                    # Convert from list of index to string
                    str_ref = idx2phone(ys[b][:y_lens[b]])

                ##############################
                # Hypothesis
                ##############################
                # Convert from list of index to string
                str_hyp = idx2phone(best_hyps[b])

                print('----- wav: %s -----' % batch['input_names'][b])
                print('Ref      : %s' % str_ref)
                print('Hyp      : %s' % str_hyp)
                if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                    str_hyp_ctc = idx2phone(best_hyps_ctc[b])
                    print('Hyp (CTC): %s' % str_hyp_ctc)

                # Compute PER (the ' >' symbol and what follows it are
                # stripped from the hypothesis before scoring)
                per, _, _, _ = compute_wer(ref=str_ref.split(' '),
                                           hyp=re.sub(r'(.*) >(.*)', r'\1',
                                                      str_hyp).split(' '),
                                           normalize=True)
                print('PER: %.3f %%' % (per * 100))
                if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                    per_ctc, _, _, _ = compute_wer(ref=str_ref.split(' '),
                                                   hyp=str_hyp_ctc.split(' '),
                                                   normalize=True)
                    print('PER (CTC): %.3f %%' % (per_ctc * 100))

            if is_new_epoch:
                break
    finally:
        if out_file is not None:
            sys.stdout = stdout_original
            out_file.close()
def do_eval_wer(session, decode_ops, model, dataset, train_data_size,
                is_test=False, eval_batch_size=None, progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Word Error Rate.

    One decoding operation is run per device. A device batch that raises
    IndexError while being converted or scored is skipped as a whole and
    excluded from the average.

    Args:
        session: session of training model
        decode_ops: list of operations for decoding
        model: the model to evaluate
        dataset: An instance of `Dataset` class
        train_data_size (string): train100h or train460h or train960h
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        wer_mean (float): An average of WER
    """
    assert isinstance(decode_ops, list), "decode_ops must be a list."

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    idx2word = Idx2word(
        map_file_path='../metrics/mapping_files/word_' + train_data_size + '.txt')

    wer_mean = 0
    skip_data_num = 0
    pbar = tqdm(total=len(dataset)) if progressbar else None
    for data, is_new_epoch in dataset:

        # Create feed dictionary for next mini batch
        if is_multitask:
            inputs, labels_true, _, inputs_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, _ = data

        feed_dict = {}
        for i_device in range(len(decode_ops)):
            feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
            feed_dict[model.inputs_seq_len_pl_list[i_device]
                      ] = inputs_seq_len[i_device]
            feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

        labels_pred_st_list = session.run(decode_ops, feed_dict=feed_dict)
        for i_device, labels_pred_st in enumerate(labels_pred_st_list):
            batch_size_device = len(inputs[i_device])

            # Remember the running sum so that a batch failing half-way can
            # be rolled back: a skipped batch is removed from the
            # denominator, so it must not contribute to the numerator either
            wer_before = wer_mean
            try:
                labels_pred = sparsetensor2list(labels_pred_st,
                                                batch_size_device)

                for i_batch in range(batch_size_device):

                    if is_test:
                        str_true = labels_true[i_device][i_batch][0]
                        # NOTE: transcript is separated by space('_')
                    else:
                        str_true = '_'.join(
                            idx2word(labels_true[i_device][i_batch]))
                    str_pred = '_'.join(idx2word(labels_pred[i_batch]))

                    # Compute WER
                    wer_mean += compute_wer(ref=str_true.split('_'),
                                            hyp=str_pred.split('_'),
                                            normalize=True)

            except IndexError:
                print('skipped')
                wer_mean = wer_before
                skip_data_num += batch_size_device
                # TODO: Conduct decoding again with batch size 1

            # Advance once per device batch so that a skipped batch is not
            # counted twice on the progress bar
            if pbar is not None:
                pbar.update(batch_size_device)

        if is_new_epoch:
            break

    if pbar is not None:
        pbar.close()

    wer_mean /= (len(dataset) - skip_data_num)

    # Register original batch size
    if eval_batch_size is not None:
        dataset.batch_size = batch_size_original

    return wer_mean
def do_eval_cer(session, decode_ops, model, dataset, label_type,
                is_test=False, eval_batch_size=None, progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Character Error Rate and Word Error Rate.

    All decode ops are run in a single `session.run` call (one op per entry
    of the model's per-device placeholder lists) and every utterance of every
    device batch is scored against its reference transcript. A device batch
    that raises `IndexError` while being converted or scored is skipped as a
    whole and excluded from both averages.
    Args:
        session: session of training model
        decode_ops: list of operations for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    Raises:
        AssertionError: if `decode_ops` is not a list
        TypeError: if `label_type` is not supported
        ZeroDivisionError: if no utterance is left to average over
    """
    assert isinstance(decode_ops, list), "decode_ops must be a list."

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    pbar = None
    try:
        # Set batch size in the evaluation
        if eval_batch_size is not None:
            dataset.batch_size = eval_batch_size

        if label_type == 'character':
            idx2char = Idx2char(
                map_file_path='../metrics/mapping_files/character.txt')
        elif label_type == 'character_capital_divide':
            idx2char = Idx2char(
                map_file_path='../metrics/mapping_files/character_capital_divide.txt',
                capital_divide=True,
                space_mark='_')
        else:
            raise TypeError(
                'label_type must be "character" or '
                '"character_capital_divide".')

        cer_mean, wer_mean = 0, 0
        skip_data_num = 0
        if progressbar:
            pbar = tqdm(total=len(dataset))
        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini batch
            if is_multitask:
                inputs, _, labels_true, inputs_seq_len, _ = data
            else:
                inputs, labels_true, inputs_seq_len, _ = data

            feed_dict = {}
            for i_device in range(len(decode_ops)):
                feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
                feed_dict[model.inputs_seq_len_pl_list[i_device]
                          ] = inputs_seq_len[i_device]
                # Keep probability 1.0 (presumably disables dropout)
                feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

            labels_pred_st_list = session.run(decode_ops, feed_dict=feed_dict)
            for i_device, labels_pred_st in enumerate(labels_pred_st_list):
                batch_size_device = len(inputs[i_device])

                # Score the device batch into local sums first, so that a
                # batch which fails half-way contributes nothing to the
                # global sums (it is counted as skipped as a whole).
                cer_device, wer_device = 0, 0
                num_scored = 0
                try:
                    labels_pred = sparsetensor2list(labels_pred_st,
                                                    batch_size_device)
                    for i_batch in range(batch_size_device):

                        # Convert from list of index to string
                        if is_test:
                            str_true = labels_true[i_device][i_batch][0]
                            # NOTE: transcript is separated by space('_')
                        else:
                            str_true = idx2char(
                                labels_true[i_device][i_batch],
                                padded_value=dataset.padded_value)
                        str_pred = idx2char(labels_pred[i_batch])

                        # Remove consecutive spaces
                        str_pred = re.sub(r'[_]+', '_', str_pred)

                        # Remove garbage labels
                        str_true = re.sub(r'[\']+', '', str_true)
                        str_pred = re.sub(r'[\']+', '', str_pred)

                        # Compute WER
                        wer_device += compute_wer(ref=str_true.split('_'),
                                                  hyp=str_pred.split('_'),
                                                  normalize=True)

                        # Remove spaces
                        str_true = re.sub(r'[_]+', '', str_true)
                        str_pred = re.sub(r'[_]+', '', str_pred)

                        # Compute CER
                        cer_device += compute_cer(str_pred=str_pred,
                                                  str_true=str_true,
                                                  normalize=True)

                        num_scored += 1
                        if pbar is not None:
                            pbar.update(1)

                except IndexError:
                    print('skipped')
                    skip_data_num += batch_size_device
                    # TODO: Conduct decoding again with batch size 1

                    # Advance only by the utterances not reported yet
                    if pbar is not None:
                        pbar.update(batch_size_device - num_scored)
                else:
                    cer_mean += cer_device
                    wer_mean += wer_device

            if is_new_epoch:
                break

        # NOTE: averaged over len(dataset) minus the skipped utterances
        cer_mean /= (len(dataset) - skip_data_num)
        wer_mean /= (len(dataset) - skip_data_num)
        # TODO: Fix this
    finally:
        if pbar is not None:
            pbar.close()

        # Register original batch size (also when the evaluation failed)
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original

    return cer_mean, wer_mean
def do_eval_cer2(session, posteriors_ops, beam_width, model, dataset,
                 label_type, is_test=False, eval_batch_size=None,
                 progressbar=False, is_multitask=False):
    """Evaluate trained model by Character Error Rate and Word Error Rate.

    The posteriors of every device batch are computed in a single
    `session.run` call and then decoded utterance by utterance with a
    `BeamSearchDecoder`.
    Args:
        session: session of training model
        posteriors_ops: list of operations for computing posteriors
        beam_width (int): the beam width handed to the beam search decoder
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
            (only character is implemented)
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    Raises:
        AssertionError: if `posteriors_ops` is not a list
        NotImplementedError: if `label_type` is character_capital_divide
        TypeError: if `label_type` is not supported
    """
    assert isinstance(posteriors_ops, list), "posteriors_ops must be a list."

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    pbar = None
    try:
        # Set batch size in the evaluation
        if eval_batch_size is not None:
            dataset.batch_size = eval_batch_size

        if label_type == 'character':
            idx2char = Idx2char(
                map_file_path='../metrics/mapping_files/character.txt')
            char2idx = Char2idx(
                map_file_path='../metrics/mapping_files/character.txt')
        elif label_type == 'character_capital_divide':
            raise NotImplementedError
        else:
            raise TypeError

        # Define decoder
        # NOTE: the last class index is used as the blank label
        decoder = BeamSearchDecoder(space_index=char2idx('_')[0],
                                    blank_index=model.num_classes - 1)

        cer_mean, wer_mean = 0, 0
        if progressbar:
            pbar = tqdm(total=len(dataset))
        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini batch
            if is_multitask:
                inputs, _, labels_true, inputs_seq_len, _ = data
            else:
                inputs, labels_true, inputs_seq_len, _ = data

            feed_dict = {}
            for i_device in range(len(posteriors_ops)):
                feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
                feed_dict[model.inputs_seq_len_pl_list[i_device]
                          ] = inputs_seq_len[i_device]
                # Keep probability 1.0 (presumably disables dropout)
                feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

            posteriors_list = session.run(posteriors_ops, feed_dict=feed_dict)
            for i_device, posteriors_device in enumerate(posteriors_list):
                batch_size_device, max_time = inputs[i_device].shape[:2]

                # Restore the (batch, time, class) layout of the posteriors
                posteriors = posteriors_device.reshape(
                    batch_size_device, max_time, model.num_classes)

                for i_batch in range(batch_size_device):

                    # Decode per utterance (slices keep the batch axis)
                    labels_pred, _ = decoder(
                        probs=posteriors[i_batch:i_batch + 1],
                        seq_len=inputs_seq_len[i_device][i_batch: i_batch + 1],
                        beam_width=beam_width)

                    # Convert from list of index to string
                    if is_test:
                        str_true = labels_true[i_device][i_batch][0]
                        # NOTE: transcript is separated by space('_')
                    else:
                        str_true = idx2char(labels_true[i_device][i_batch],
                                            padded_value=dataset.padded_value)
                    str_pred = idx2char(labels_pred[0])

                    # Remove consecutive spaces
                    str_pred = re.sub(r'[_]+', '_', str_pred)

                    # Remove garbage labels
                    str_true = re.sub(r'[\']+', '', str_true)
                    str_pred = re.sub(r'[\']+', '', str_pred)

                    # Compute WER
                    wer_mean += compute_wer(ref=str_true.split('_'),
                                            hyp=str_pred.split('_'),
                                            normalize=True)

                    # Remove spaces
                    str_true = re.sub(r'[_]+', '', str_true)
                    str_pred = re.sub(r'[_]+', '', str_pred)

                    # Compute CER
                    cer_mean += compute_cer(str_pred=str_pred,
                                            str_true=str_true,
                                            normalize=True)

                    if pbar is not None:
                        pbar.update(1)

            if is_new_epoch:
                break

        cer_mean /= (len(dataset))
        wer_mean /= (len(dataset))
    finally:
        if pbar is not None:
            pbar.close()

        # Register original batch size (also when the evaluation failed)
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original

    return cer_mean, wer_mean
示例#16
0
def do_eval_cer(session, decode_op, model, dataset, label_type,
                is_test=False, eval_batch_size=None, progressbar=False,
                is_multitask=False, is_jointctcatt=False):
    """Evaluate trained model by Character Error Rate and Word Error Rate.

    Only the first entry of each per-device list (index 0) is fed and
    scored.
    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
            (used to build the path of the mapping file)
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
        is_jointctcatt (bool, optional): if True, evaluate the joint
            CTC-Attention model
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    """
    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    pbar = None
    try:
        # Set batch size in the evaluation
        if eval_batch_size is not None:
            dataset.batch_size = eval_batch_size

        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/' + label_type + '.txt')

        cer_mean, wer_mean = 0, 0
        if progressbar:
            pbar = tqdm(total=len(dataset))
        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini-batch
            # NOTE: multitask and joint CTC-Attention batches are unpacked
            # the same way (one extra element after the main labels)
            if is_multitask or is_jointctcatt:
                inputs, labels_true, _, inputs_seq_len, labels_seq_len, _ = data
            else:
                inputs, labels_true, inputs_seq_len, labels_seq_len, _ = data

            # Keep probabilities 1.0 (presumably disables dropout)
            feed_dict = {
                model.inputs_pl_list[0]: inputs[0],
                model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
                model.keep_prob_encoder_pl_list[0]: 1.0,
                model.keep_prob_decoder_pl_list[0]: 1.0,
                model.keep_prob_embedding_pl_list[0]: 1.0
            }

            batch_size = inputs[0].shape[0]

            labels_pred = session.run(decode_op, feed_dict=feed_dict)
            for i_batch in range(batch_size):

                # Convert from list of index to string
                if is_test:
                    str_true = labels_true[0][i_batch][0]
                    # NOTE: transcript is separated by space('_')
                else:
                    # Drop the first and the last label of the reference
                    # (presumably <SOS> and <EOS> -- TODO confirm)
                    str_true = idx2char(
                        labels_true[0][i_batch][1:labels_seq_len[0][i_batch] - 1],
                        padded_value=dataset.padded_value)
                str_pred = idx2char(labels_pred[i_batch]).split('>')[0]
                # NOTE: Truncate by <EOS>

                # Remove consecutive spaces
                str_pred = re.sub(r'[_]+', '_', str_pred)

                # Remove garbage labels
                str_true = re.sub(r'[<>\'\":;!?,.-]+', '', str_true)
                str_pred = re.sub(r'[<>\'\":;!?,.-]+', '', str_pred)

                # Compute WER
                wer_mean += compute_wer(hyp=str_pred.split('_'),
                                        ref=str_true.split('_'),
                                        normalize=True)

                # Remove spaces
                str_pred = re.sub(r'[_]+', '', str_pred)
                str_true = re.sub(r'[_]+', '', str_true)

                # Compute CER
                cer_mean += compute_cer(str_pred=str_pred,
                                        str_true=str_true,
                                        normalize=True)

                if pbar is not None:
                    pbar.update(1)

            if is_new_epoch:
                break

        cer_mean /= len(dataset)
        wer_mean /= len(dataset)
    finally:
        if pbar is not None:
            pbar.close()

        # Register original batch size (also when the evaluation failed)
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original

    return cer_mean, wer_mean
示例#17
0
def decode(model,
           dataset,
           beam_width,
           max_decode_len,
           max_decode_len_sub,
           eval_batch_size=None,
           save_path=None):
    """Visualize label outputs of the main (word) and sub (character) tasks.

    For every utterance the reference, both hypotheses, the WER of the main
    task and the CER of the sub task are printed.
    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width: (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction when EOS token have not been emitted.
            This is used for seq2seq models.
        max_decode_len_sub (int): the same limit for the sub task
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string): if not None, the results are written to
            `decode.txt` under `model.model_dir` instead of the current
            stdout. The value itself is not used as a path.
    Raises:
        TypeError: if `dataset.label_type_sub` is not supported
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    idx2word = Idx2word(vocab_file_path=dataset.vocab_file_path)
    if dataset.label_type_sub == 'character':
        idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path_sub)
    elif dataset.label_type_sub == 'character_capital_divide':
        idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path_sub,
                            capital_divide=True)
    else:
        # Fail early instead of hitting an unbound `idx2char` later
        raise TypeError(
            'dataset.label_type_sub must be "character" or '
            '"character_capital_divide".')

    # Redirect stdout to the result file only for the duration of this call
    stdout_original = sys.stdout
    decode_file = None
    if save_path is not None:
        decode_file = open(join(model.model_dir, 'decode.txt'), 'w')
        sys.stdout = decode_file

    try:
        for batch, is_new_epoch in dataset:

            # Decode
            if model.model_type == 'charseq_attention':
                best_hyps, best_hyps_sub, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len,
                    max_decode_len_sub=max_decode_len_sub)
            else:
                best_hyps, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len)
                best_hyps_sub, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len_sub,
                    is_sub_task=True)

            # Reorder the references with the permutation returned by decode
            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]
            ys_sub = batch['ys_sub'][perm_idx]
            y_lens_sub = batch['y_lens_sub'][perm_idx]

            for b in range(len(batch['xs'])):

                ##############################
                # Reference
                ##############################
                if dataset.is_test:
                    str_ref = ys[b][0]
                    str_ref_sub = ys_sub[b][0]
                    # NOTE: transcript is separated by space('_')
                else:
                    # Convert from list of index to string
                    str_ref = idx2word(ys[b][:y_lens[b]])
                    # The sub task uses the character vocabulary, like the
                    # sub-task hypothesis below
                    str_ref_sub = idx2char(ys_sub[b][:y_lens_sub[b]])

                ##############################
                # Hypothesis
                ##############################
                # Convert from list of index to string
                str_hyp = idx2word(best_hyps[b])
                str_hyp_sub = idx2char(best_hyps_sub[b])

                if model.model_type != 'hierarchical_ctc':
                    str_hyp = str_hyp.split('>')[0]
                    str_hyp_sub = str_hyp_sub.split('>')[0]
                    # NOTE: Truncate by the first <EOS>

                    # Remove the last space
                    if len(str_hyp) > 0 and str_hyp[-1] == '_':
                        str_hyp = str_hyp[:-1]
                    if len(str_hyp_sub) > 0 and str_hyp_sub[-1] == '_':
                        str_hyp_sub = str_hyp_sub[:-1]

                ##############################
                # Post-processing
                ##############################
                # Remove garbage labels
                str_ref = re.sub(r'[\'>]+', '', str_ref)
                str_hyp = re.sub(r'[\'>]+', '', str_hyp)

                print('----- wav: %s -----' % batch['input_names'][b])
                print('Ref: %s' % str_ref.replace('_', ' '))
                print('Hyp (main): %s' % str_hyp.replace('_', ' '))
                # print('Ref (sub): %s' % str_ref_sub.replace('_', ' '))
                print('Hyp (sub): %s' % str_hyp_sub.replace('_', ' '))

                wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                           hyp=str_hyp.split('_'),
                                           normalize=True)
                print('WER: %f %%' % (wer * 100))
                # CER is computed on characters with the spaces removed
                cer, _, _, _ = compute_wer(
                    ref=list(str_ref_sub.replace('_', '')),
                    hyp=list(str_hyp_sub.replace('_', '')),
                    normalize=True)
                print('CER: %f %%' % (cer * 100))

            if is_new_epoch:
                break
    finally:
        if decode_file is not None:
            sys.stdout = stdout_original
            decode_file.close()
def do_eval_cer2(session,
                 posteriors_ops,
                 beam_width,
                 model,
                 dataset,
                 label_type,
                 is_test=False,
                 eval_batch_size=None,
                 progressbar=False,
                 is_multitask=False):
    """Evaluate trained model by Character Error Rate and Word Error Rate.

    The posteriors of every device batch are computed in a single
    `session.run` call and then decoded utterance by utterance with a
    `BeamSearchDecoder`.
    Args:
        session: session of training model
        posteriors_ops: list of operations for computing posteriors
        beam_width (int): the beam width handed to the beam search decoder
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
            (only character is implemented)
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    Raises:
        AssertionError: if `posteriors_ops` is not a list
        NotImplementedError: if `label_type` is character_capital_divide
        TypeError: if `label_type` is not supported
    """
    assert isinstance(posteriors_ops, list), "posteriors_ops must be a list."

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    pbar = None
    try:
        # Set batch size in the evaluation
        if eval_batch_size is not None:
            dataset.batch_size = eval_batch_size

        if label_type == 'character':
            idx2char = Idx2char(
                map_file_path='../metrics/mapping_files/character.txt')
            char2idx = Char2idx(
                map_file_path='../metrics/mapping_files/character.txt')
        elif label_type == 'character_capital_divide':
            raise NotImplementedError
        else:
            raise TypeError

        # Define decoder
        # NOTE: the last class index is used as the blank label
        decoder = BeamSearchDecoder(space_index=char2idx('_')[0],
                                    blank_index=model.num_classes - 1)

        cer_mean, wer_mean = 0, 0
        if progressbar:
            pbar = tqdm(total=len(dataset))
        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini batch
            if is_multitask:
                inputs, _, labels_true, inputs_seq_len, _ = data
            else:
                inputs, labels_true, inputs_seq_len, _ = data

            feed_dict = {}
            for i_device in range(len(posteriors_ops)):
                feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
                feed_dict[model.inputs_seq_len_pl_list[i_device]] = inputs_seq_len[
                    i_device]
                # Keep probability 1.0 (presumably disables dropout)
                feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

            posteriors_list = session.run(posteriors_ops, feed_dict=feed_dict)
            for i_device, posteriors_device in enumerate(posteriors_list):
                batch_size_device, max_time = inputs[i_device].shape[:2]

                # Restore the (batch, time, class) layout of the posteriors
                posteriors = posteriors_device.reshape(
                    batch_size_device, max_time, model.num_classes)

                for i_batch in range(batch_size_device):

                    # Decode per utterance (slices keep the batch axis)
                    labels_pred, _ = decoder(
                        probs=posteriors[i_batch:i_batch + 1],
                        seq_len=inputs_seq_len[i_device][i_batch:i_batch + 1],
                        beam_width=beam_width)

                    # Convert from list of index to string
                    if is_test:
                        str_true = labels_true[i_device][i_batch][0]
                        # NOTE: transcript is separated by space('_')
                    else:
                        str_true = idx2char(labels_true[i_device][i_batch],
                                            padded_value=dataset.padded_value)
                    str_pred = idx2char(labels_pred[0])

                    # Remove consecutive spaces
                    str_pred = re.sub(r'[_]+', '_', str_pred)

                    # Remove garbage labels
                    str_true = re.sub(r'[\']+', '', str_true)
                    str_pred = re.sub(r'[\']+', '', str_pred)

                    # Compute WER
                    wer_mean += compute_wer(ref=str_true.split('_'),
                                            hyp=str_pred.split('_'),
                                            normalize=True)

                    # Remove spaces
                    str_true = re.sub(r'[_]+', '', str_true)
                    str_pred = re.sub(r'[_]+', '', str_pred)

                    # Compute CER
                    cer_mean += compute_cer(str_pred=str_pred,
                                            str_true=str_true,
                                            normalize=True)

                    if pbar is not None:
                        pbar.update(1)

            if is_new_epoch:
                break

        cer_mean /= (len(dataset))
        wer_mean /= (len(dataset))
    finally:
        if pbar is not None:
            pbar.close()

        # Register original batch size (also when the evaluation failed)
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original

    return cer_mean, wer_mean
示例#19
0
def main():
    """Decode the test data with a saved model and print the PER per utterance.

    The model is restored from `args.model_path` / `args.epoch`. When the
    model is an attention model with a positive CTC loss weight, the CTC
    hypothesis and its PER are printed as well.
    """
    args = parser.parse_args()

    # Load a config file (.yml)
    config_path = join(args.model_path, 'config.yml')
    params = load_config(config_path, is_eval=True)

    # Load dataset
    dataset = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        data_type='test',
        label_type=params['label_type'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        sort_utt=True,
        reverse=True,
        tool=params['tool'])

    params['num_classes'] = dataset.num_classes

    # Load model
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    # Restore the saved parameters
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    # sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    ######################################################################

    for batch, is_new_epoch in dataset:
        xs = batch['xs']
        x_lens = batch['x_lens']

        # Decode
        best_hyps, _, perm_idx = model.decode(
            xs, x_lens,
            beam_width=args.beam_width,
            max_decode_len=MAX_DECODE_LEN_PHONE,
            min_decode_len=MIN_DECODE_LEN_PHONE,
            length_penalty=args.length_penalty,
            coverage_penalty=args.coverage_penalty)

        # The CTC output is only shown for attention models trained with a
        # CTC loss
        with_ctc = (model.model_type == 'attention'
                    and model.ctc_loss_weight > 0)
        if with_ctc:
            best_hyps_ctc, perm_idx = model.decode_ctc(
                xs, x_lens,
                beam_width=args.beam_width)

        # Reorder the references with the permutation returned by decoding
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for utt in range(len(xs)):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                # NOTE: transcript is separated by space(' ')
                str_ref = ys[utt][0]
            else:
                # Convert from list of index to string
                str_ref = dataset.idx2phone(ys[utt][: y_lens[utt]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = dataset.idx2phone(best_hyps[utt])

            print('----- wav: %s -----' % batch['input_names'][utt])
            print('Ref      : %s' % str_ref)
            print('Hyp      : %s' % str_hyp)
            if with_ctc:
                str_hyp_ctc = dataset.idx2phone(best_hyps_ctc[utt])
                print('Hyp (CTC): %s' % str_hyp_ctc)

            # Compute PER
            # Everything from the last ' >' onwards is cut off the hypothesis
            ref_phones = str_ref.split(' ')
            hyp_phones = re.sub(r'(.*) >(.*)', r'\1', str_hyp).split(' ')
            per, _, _, _ = compute_wer(ref=ref_phones,
                                       hyp=hyp_phones,
                                       normalize=True)
            print('PER: %.3f %%' % (per * 100))
            if with_ctc:
                per_ctc, _, _, _ = compute_wer(ref=str_ref.split(' '),
                                               hyp=str_hyp_ctc.split(' '),
                                               normalize=True)
                print('PER (CTC): %.3f %%' % (per_ctc * 100))

        if is_new_epoch:
            break
def do_eval_cer(session,
                decode_ops,
                model,
                dataset,
                label_type,
                is_test=False,
                eval_batch_size=None,
                progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Character Error Rate and Word Error Rate.

    All decode ops are run in a single `session.run` call (one op per entry
    of the model's per-device placeholder lists) and every utterance of every
    device batch is scored against its reference transcript. A device batch
    that raises `IndexError` while being converted or scored is skipped as a
    whole and excluded from both averages.
    Args:
        session: session of training model
        decode_ops: list of operations for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    Raises:
        AssertionError: if `decode_ops` is not a list
        TypeError: if `label_type` is not supported
        ZeroDivisionError: if no utterance is left to average over
    """
    assert isinstance(decode_ops, list), "decode_ops must be a list."

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    pbar = None
    try:
        # Set batch size in the evaluation
        if eval_batch_size is not None:
            dataset.batch_size = eval_batch_size

        if label_type == 'character':
            idx2char = Idx2char(
                map_file_path='../metrics/mapping_files/character.txt')
        elif label_type == 'character_capital_divide':
            idx2char = Idx2char(
                map_file_path=
                '../metrics/mapping_files/character_capital_divide.txt',
                capital_divide=True,
                space_mark='_')
        else:
            raise TypeError(
                'label_type must be "character" or '
                '"character_capital_divide".')

        cer_mean, wer_mean = 0, 0
        skip_data_num = 0
        if progressbar:
            pbar = tqdm(total=len(dataset))
        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini batch
            if is_multitask:
                inputs, _, labels_true, inputs_seq_len, _ = data
            else:
                inputs, labels_true, inputs_seq_len, _ = data

            feed_dict = {}
            for i_device in range(len(decode_ops)):
                feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
                feed_dict[model.inputs_seq_len_pl_list[i_device]] = inputs_seq_len[
                    i_device]
                # Keep probability 1.0 (presumably disables dropout)
                feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

            labels_pred_st_list = session.run(decode_ops, feed_dict=feed_dict)
            for i_device, labels_pred_st in enumerate(labels_pred_st_list):
                batch_size_device = len(inputs[i_device])

                # Score the device batch into local sums first, so that a
                # batch which fails half-way contributes nothing to the
                # global sums (it is counted as skipped as a whole).
                cer_device, wer_device = 0, 0
                num_scored = 0
                try:
                    labels_pred = sparsetensor2list(labels_pred_st,
                                                    batch_size_device)
                    for i_batch in range(batch_size_device):

                        # Convert from list of index to string
                        if is_test:
                            str_true = labels_true[i_device][i_batch][0]
                            # NOTE: transcript is separated by space('_')
                        else:
                            str_true = idx2char(
                                labels_true[i_device][i_batch],
                                padded_value=dataset.padded_value)
                        str_pred = idx2char(labels_pred[i_batch])

                        # Remove consecutive spaces
                        str_pred = re.sub(r'[_]+', '_', str_pred)

                        # Remove garbage labels
                        str_true = re.sub(r'[\']+', '', str_true)
                        str_pred = re.sub(r'[\']+', '', str_pred)

                        # Compute WER
                        wer_device += compute_wer(ref=str_true.split('_'),
                                                  hyp=str_pred.split('_'),
                                                  normalize=True)

                        # Remove spaces
                        str_true = re.sub(r'[_]+', '', str_true)
                        str_pred = re.sub(r'[_]+', '', str_pred)

                        # Compute CER
                        cer_device += compute_cer(str_pred=str_pred,
                                                  str_true=str_true,
                                                  normalize=True)

                        num_scored += 1
                        if pbar is not None:
                            pbar.update(1)

                except IndexError:
                    print('skipped')
                    skip_data_num += batch_size_device
                    # TODO: Conduct decoding again with batch size 1

                    # Advance only by the utterances not reported yet
                    if pbar is not None:
                        pbar.update(batch_size_device - num_scored)
                else:
                    cer_mean += cer_device
                    wer_mean += wer_device

            if is_new_epoch:
                break

        # NOTE: averaged over len(dataset) minus the skipped utterances
        cer_mean /= (len(dataset) - skip_data_num)
        wer_mean /= (len(dataset) - skip_data_num)
        # TODO: Fix this
    finally:
        if pbar is not None:
            pbar.close()

        # Register original batch size (also when the evaluation failed)
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original

    return cer_mean, wer_mean
def decode(model, dataset, beam_width, beam_width_sub,
           eval_batch_size=None, save_path=None, resolving_unk=False):
    """Print references, hypotheses and per-utterance WER for both tasks.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam in the main task
        beam_width_sub (int): the size of beam in the sub task
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string, optional): if not None, the results are written to
            `decode.txt` under `model.model_dir` instead of the terminal.
            NOTE(review): the value itself is not used as the destination.
        resolving_unk (bool, optional): if True, hypotheses that contain 'OOV'
            are additionally passed through `resolve_unk`
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    idx2word = Idx2word(vocab_file_path=dataset.vocab_file_path)
    idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path_sub,
                        capital_divide=dataset.label_type_sub == 'character_capital_divide')

    # Read GLM file
    glm = GLM(
        glm_path='/n/sd8/inaguma/corpus/swbd/data/eval2000/LDC2002T43/reference/en20000405_hub5.glm')

    # Redirect stdout only for the duration of this call so that the file is
    # always closed and the caller's stdout is always restored.
    stdout_original = sys.stdout
    f_out = None
    if save_path is not None:
        f_out = open(join(model.model_dir, 'decode.txt'), 'w')
        sys.stdout = f_out

    try:
        for batch, is_new_epoch in dataset:

            # Decode
            if model.model_type == 'nested_attention':
                best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    beam_width_sub=beam_width_sub,
                    max_decode_len=MAX_DECODE_LEN_WORD,
                    max_decode_len_sub=MAX_DECODE_LEN_CHAR)
            else:
                best_hyps, aw, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=MAX_DECODE_LEN_WORD)
                best_hyps_sub, aw_sub, _ = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width_sub,
                    max_decode_len=MAX_DECODE_LEN_CHAR,
                    task_index=1)

            # Reorder the references in the same way as the hypotheses
            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]
            ys_sub = batch['ys_sub'][perm_idx]
            y_lens_sub = batch['y_lens_sub'][perm_idx]

            for b in range(len(batch['xs'])):

                ##############################
                # Reference
                ##############################
                if dataset.is_test:
                    str_ref_original = ys[b][0]
                    str_ref_sub = ys_sub[b][0]
                    # NOTE: transcript is separated by space('_')
                else:
                    # Convert from list of index to string
                    str_ref_original = idx2word(ys[b][:y_lens[b]])
                    # The sub task uses the sub vocabulary, so its reference
                    # must be mapped with idx2char (not idx2word)
                    str_ref_sub = idx2char(ys_sub[b][:y_lens_sub[b]])

                ##############################
                # Hypothesis
                ##############################
                # Convert from list of index to string
                str_hyp = idx2word(best_hyps[b])
                str_hyp_sub = idx2char(best_hyps_sub[b])

                ##############################
                # Resolving UNK
                ##############################
                # Reset per utterance so that a value from a previous
                # utterance is never reused
                str_hyp_no_unk = None
                if 'OOV' in str_hyp and resolving_unk:
                    str_hyp_no_unk = resolve_unk(
                        str_hyp, best_hyps_sub[b], aw[b], aw_sub[b], idx2char)

                ##############################
                # Post-processing
                ##############################
                str_ref = fix_trans(str_ref_original, glm)
                str_ref_sub = fix_trans(str_ref_sub, glm)
                str_hyp = fix_trans(str_hyp, glm)
                str_hyp_sub = fix_trans(str_hyp_sub, glm)
                if str_hyp_no_unk is not None:
                    str_hyp_no_unk = fix_trans(str_hyp_no_unk, glm)

                if len(str_ref) == 0:
                    continue

                # The resolved hypothesis is reported only when it exists and
                # the post-processed main hypothesis still contains 'OOV'
                show_no_unk = str_hyp_no_unk is not None and 'OOV' in str_hyp

                print('----- wav: %s -----' % batch['input_names'][b])
                print('Ref         : %s' % str_ref.replace('_', ' '))
                print('Hyp (main)  : %s' % str_hyp.replace('_', ' '))
                print('Hyp (sub)   : %s' % str_hyp_sub.replace('_', ' '))
                if show_no_unk:
                    print('Hyp (no UNK): %s' % str_hyp_no_unk.replace('_', ' '))

                try:
                    # Compute WER
                    # Everything from the first '>' on is dropped before
                    # scoring; str.replace() cannot do this because it treats
                    # its argument literally.
                    wer, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=str_hyp.split('_>')[0].split('_'),
                        normalize=True)
                    print('WER (main)  : %.3f %%' % (wer * 100))
                    wer_sub, _, _, _ = compute_wer(
                        ref=str_ref_sub.split('_'),
                        hyp=str_hyp_sub.split('>')[0].split('_'),
                        normalize=True)
                    print('WER (sub)   : %.3f %%' % (wer_sub * 100))
                    if show_no_unk:
                        wer_no_unk, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=str_hyp_no_unk.replace(
                                '*', '').split('_>')[0].split('_'),
                            normalize=True)
                        print('WER (no UNK): %.3f %%' % (wer_no_unk * 100))
                except Exception:
                    # Best effort: an utterance that cannot be scored is only
                    # reported, while Ctrl-C / SystemExit still propagate
                    print('--- skipped ---')

            if is_new_epoch:
                break
    finally:
        if f_out is not None:
            sys.stdout = stdout_original
            f_out.close()
def do_eval_wer(session,
                decode_ops,
                model,
                dataset,
                train_data_size,
                is_test=False,
                eval_batch_size=None,
                progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Word Error Rate.

    Args:
        session: session of training model
        decode_ops: list of operations for decoding (one per device)
        model: the model to evaluate
        dataset: An instance of `Dataset` class
        train_data_size (string): train100h or train460h or train960h
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        wer_mean (float): the sum of the per-utterance values returned by
            `compute_wer`, divided by the number of non-skipped utterances
    """
    assert isinstance(decode_ops, list), "decode_ops must be a list."

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    pbar = None
    try:
        idx2word = Idx2word(map_file_path='../metrics/mapping_files/word_' +
                            train_data_size + '.txt')

        wer_mean = 0
        skip_data_num = 0
        if progressbar:
            pbar = tqdm(total=len(dataset))
        for data, is_new_epoch in dataset:

            # Unpack the mini batch (the multitask dataset carries one extra
            # label field, which is not needed here)
            if is_multitask:
                inputs, labels_true, _, inputs_seq_len, _ = data
            else:
                inputs, labels_true, inputs_seq_len, _ = data

            # Create feed dictionary for next mini batch
            feed_dict = {}
            for i_device in range(len(decode_ops)):
                feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
                feed_dict[model.inputs_seq_len_pl_list[i_device]] = inputs_seq_len[
                    i_device]
                # Disable dropout in the evaluation
                feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

            labels_pred_st_list = session.run(decode_ops, feed_dict=feed_dict)
            for i_device, labels_pred_st in enumerate(labels_pred_st_list):
                batch_size_device = len(inputs[i_device])
                try:
                    labels_pred = sparsetensor2list(labels_pred_st,
                                                    batch_size_device)

                    for i_batch in range(batch_size_device):

                        if is_test:
                            str_true = labels_true[i_device][i_batch][0]
                            # NOTE: transcript is separated by space('_')
                        else:
                            str_true = '_'.join(
                                idx2word(labels_true[i_device][i_batch]))
                        str_pred = '_'.join(idx2word(labels_pred[i_batch]))

                        # Compute WER
                        wer_mean += compute_wer(ref=str_true.split('_'),
                                                hyp=str_pred.split('_'),
                                                normalize=True)

                        if progressbar:
                            pbar.update(1)

                except IndexError:
                    # The whole device batch is dropped and excluded from the
                    # denominator below
                    print('skipped')
                    skip_data_num += batch_size_device
                    # TODO: Conduct decoding again with batch size 1

                    if progressbar:
                        pbar.update(batch_size_device)

            if is_new_epoch:
                break

        # Averaged while the evaluation batch size is still in effect, as in
        # the original ordering
        wer_mean /= (len(dataset) - skip_data_num)
    finally:
        if pbar is not None:
            pbar.close()

        # Register original batch size (also when decoding fails)
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original

    return wer_mean
    def check(self,
              encoder_type,
              decoder_type,
              bidirectional=False,
              attention_type='location',
              label_type='char',
              subsample=False,
              projection=False,
              init_dec_state='first',
              ctc_loss_weight=0,
              conv=False,
              batch_norm=False,
              residual=False,
              dense_residual=False,
              decoding_order='bahdanau_attention',
              backward_loss_weight=0,
              num_heads=1,
              beam_width=1):
        """Train an `AttentionSeq2seq` model on one small generated batch.

        The model is built from the given options and trained for at most
        300 steps on a batch of 2 utterances from `generate_data`. Every 10
        steps the first utterance is decoded and its label error rate is
        printed; training stops once that error rate falls below 0.1.

        Args:
            encoder_type (string): forwarded to the model; 'cnn' also selects
                the convolutional settings
            decoder_type (string): forwarded to the model
            bidirectional (bool): forwarded as `encoder_bidirectional`
            attention_type (string): forwarded to the model
            label_type (string): 'char' or 'word'
            subsample (bool or string): if not False, 2 encoder layers are
                used and the value is forwarded as `subsample_type`
            projection (bool): if True, `encoder_num_proj` is set to 256
            init_dec_state (string): forwarded to the model
            ctc_loss_weight (float): forwarded to the model; the CTC
                hypothesis is printed too when it is >= 0.1
            conv (bool): if True, use the convolutional settings
            batch_norm (bool): forwarded to the model
            residual (bool): forwarded as encoder/decoder residual
            dense_residual (bool): forwarded as encoder/decoder dense residual
            decoding_order (string): forwarded to the model
            backward_loss_weight (float): forwarded to the model
            num_heads (int): forwarded to the model
            beam_width (int): the size of beam used when decoding
        Raises:
            ValueError: if `label_type` is neither 'char' nor 'word'
        """
        print('==================================================')
        print('  label_type: %s' % label_type)
        print('  encoder_type: %s' % encoder_type)
        print('  bidirectional: %s' % str(bidirectional))
        print('  projection: %s' % str(projection))
        print('  decoder_type: %s' % decoder_type)
        print('  init_dec_state: %s' % init_dec_state)
        print('  attention_type: %s' % attention_type)
        print('  subsample: %s' % str(subsample))
        print('  ctc_loss_weight: %s' % str(ctc_loss_weight))
        print('  conv: %s' % str(conv))
        print('  batch_norm: %s' % str(batch_norm))
        print('  residual: %s' % str(residual))
        print('  dense_residual: %s' % str(dense_residual))
        print('  decoding_order: %s' % decoding_order)
        print('  backward_loss_weight: %s' % str(backward_loss_weight))
        print('  num_heads: %s' % str(num_heads))
        print('  beam_width: %s' % str(beam_width))
        print('==================================================')

        if conv or encoder_type == 'cnn':
            # pattern 1
            # conv_channels = [32, 32]
            # conv_kernel_sizes = [[41, 11], [21, 11]]
            # conv_strides = [[2, 2], [2, 1]]
            # poolings = [[], []]

            # pattern 2 (VGG like)
            conv_channels = [64, 64]
            conv_kernel_sizes = [[3, 3], [3, 3]]
            conv_strides = [[1, 1], [1, 1]]
            poolings = [[2, 2], [2, 2]]
        else:
            conv_channels = []
            conv_kernel_sizes = []
            conv_strides = []
            poolings = []

        # Load batch data
        splice = 1
        # Frame stacking is used only without subsampling / convolution
        num_stack = 1 if subsample or conv or encoder_type == 'cnn' else 3
        xs, ys, x_lens, y_lens = generate_data(label_type=label_type,
                                               batch_size=2,
                                               num_stack=num_stack,
                                               splice=splice)

        if label_type == 'char':
            num_classes = 27
            map_fn = idx2char
        elif label_type == 'word':
            num_classes = 11
            map_fn = idx2word
        else:
            # Fail here with a clear message instead of an unbound-variable
            # error at model construction
            raise ValueError(
                "label_type must be 'char' or 'word': %s" % str(label_type))

        # Load model
        model = AttentionSeq2seq(
            input_size=xs.shape[-1] // splice // num_stack,  # 120
            encoder_type=encoder_type,
            encoder_bidirectional=bidirectional,
            encoder_num_units=256,
            encoder_num_proj=256 if projection else 0,
            encoder_num_layers=1 if not subsample else 2,
            attention_type=attention_type,
            attention_dim=128,
            decoder_type=decoder_type,
            decoder_num_units=256,
            decoder_num_layers=1,
            embedding_dim=32,
            dropout_input=0.1,
            dropout_encoder=0.1,
            dropout_decoder=0.1,
            dropout_embedding=0.1,
            num_classes=num_classes,
            parameter_init_distribution='uniform',
            parameter_init=0.1,
            recurrent_weight_orthogonal=False,
            init_forget_gate_bias_with_one=True,
            subsample_list=[] if not subsample else [True, False],
            subsample_type='concat' if not subsample else subsample,
            bridge_layer=True,
            init_dec_state=init_dec_state,
            sharpening_factor=1,
            logits_temperature=1,
            sigmoid_smoothing=False,
            coverage_weight=0,
            ctc_loss_weight=ctc_loss_weight,
            attention_conv_num_channels=10,
            attention_conv_width=201,
            num_stack=num_stack,
            splice=splice,
            input_channel=3,
            conv_channels=conv_channels,
            conv_kernel_sizes=conv_kernel_sizes,
            conv_strides=conv_strides,
            poolings=poolings,
            activation='relu',
            batch_norm=batch_norm,
            scheduled_sampling_prob=0.1,
            scheduled_sampling_max_step=200,
            label_smoothing_prob=0.1,
            weight_noise_std=1e-9,
            encoder_residual=residual,
            encoder_dense_residual=dense_residual,
            decoder_residual=residual,
            decoder_dense_residual=dense_residual,
            decoding_order=decoding_order,
            bottleneck_dim=256,
            backward_loss_weight=backward_loss_weight,
            num_heads=num_heads)

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            print("%s %d" % (name, num_params))
        print("Total %.3f M parameters" % (model.total_parameters / 1000000))

        # Define optimizer
        learning_rate = 1e-3
        model.set_optimizer('adam',
                            learning_rate_init=learning_rate,
                            weight_decay=1e-8,
                            lr_schedule=False,
                            factor=0.1,
                            patience_epoch=5)

        # Define learning rate controller
        lr_controller = Controller(learning_rate_init=learning_rate,
                                   backend='pytorch',
                                   decay_start_epoch=20,
                                   decay_rate=0.9,
                                   decay_patient_epoch=10,
                                   lower_better=True)

        # GPU setting
        model.set_cuda(deterministic=False, benchmark=True)

        # Train model
        max_step = 300
        start_time_step = time.time()
        for step in range(max_step):

            # Step for parameter update
            model.optimizer.zero_grad()
            loss = model(xs, ys, x_lens, y_lens)
            loss.backward()
            # NOTE: the commented-out forms here and below are the spellings
            # used by newer PyTorch releases
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            torch.nn.utils.clip_grad_norm(model.parameters(), 5)
            model.optimizer.step()

            # Inject Gaussian noise to all parameters
            # if loss.item() < 50:
            if loss.data[0] < 50:
                model.weight_noise_injection = True

            if (step + 1) % 10 == 0:
                # Compute loss
                loss = model(xs, ys, x_lens, y_lens, is_eval=True)

                # Decode
                best_hyps, _, perm_idx = model.decode(xs,
                                                      x_lens,
                                                      beam_width,
                                                      max_decode_len=60)

                # NOTE(review): the reference is taken from ys[0] without
                # applying perm_idx -- confirm that it lines up with
                # best_hyps[0]
                str_ref = map_fn(ys[0])
                # The last predicted token is dropped before mapping
                str_hyp = map_fn(best_hyps[0][:-1])

                # Compute accuracy
                try:
                    if label_type == 'char':
                        ler, _, _, _ = compute_wer(
                            ref=list(str_ref.replace('_', '')),
                            hyp=list(str_hyp.replace('_', '')),
                            normalize=True)
                    elif label_type == 'word':
                        ler, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                                   hyp=str_hyp.split('_'),
                                                   normalize=True)
                except Exception:
                    # Treat a scoring failure as the worst error rate, but let
                    # Ctrl-C / SystemExit stop the training loop
                    ler = 1

                duration_step = time.time() - start_time_step
                print('Step %d: loss=%.3f / ler=%.3f / lr=%.5f (%.3f sec)' %
                      (step + 1, loss, ler, learning_rate, duration_step))
                start_time_step = time.time()

                # Visualize
                print('Ref: %s' % str_ref)
                print('Hyp: %s' % str_hyp)

                # Decode by the CTC decoder
                if model.ctc_loss_weight >= 0.1:
                    best_hyps_ctc, perm_idx = model.decode_ctc(xs,
                                                               x_lens,
                                                               beam_width=1)
                    str_pred_ctc = map_fn(best_hyps_ctc[0])
                    print('Hyp (CTC): %s' % str_pred_ctc)

                if ler < 0.1:
                    print('Modle is Converged.')
                    break

                # Update learning rate
                model.optimizer, learning_rate = lr_controller.decay_lr(
                    optimizer=model.optimizer,
                    learning_rate=learning_rate,
                    epoch=step,
                    value=ler)
# Example #24
# 0
def eval_word(models, dataset, eval_batch_size,
              beam_width, max_decode_len, min_decode_len=0,
              beam_width_sub=1, max_decode_len_sub=200, min_decode_len_sub=0,
              length_penalty=0, coverage_penalty=0,
              progressbar=False, resolving_unk=False, a2c_oracle=False,
              joint_decoding=None, score_sub_weight=0):
    """Evaluate trained model by Word Error Rate.

    Args:
        models (list): the models to evaluate (only the first one is used)
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model
        beam_width (int): the size of beam in the main task
        max_decode_len (int): the maximum sequence length of tokens in the main task
        min_decode_len (int): the minimum sequence length of tokens in the main task
        beam_width_sub (int): the size of beam in the sub task
            This is used for the nested attention
        max_decode_len_sub (int): the maximum sequence length of tokens in the sub task
        min_decode_len_sub (int): the minimum sequence length of tokens in the sub task
        length_penalty (float): length penalty in beam search decoding
        coverage_penalty (float): coverage penalty in beam search decoding
        progressbar (bool): if True, visualize the progressbar
        resolving_unk (bool): if True, hypotheses that contain 'OOV' are
            passed through `resolve_unk` with the sub task outputs
        a2c_oracle (bool): if True, the nested attention model is decoded
            with teacher forcing on the sub task references
        joint_decoding: onepass or rescoring or None; forwarded to the
            hierarchical attention model when it is not None
        score_sub_weight (float): forwarded to the joint decoding
    Returns:
        wer (float): Word error rate
        df_word (pd.DataFrame): dataframe of substitution, insertion, and deletion
    """
    # Reset data counter
    dataset.reset()

    model = models[0]
    # TODO: fix this

    if model.model_type == 'hierarchical_attention' and joint_decoding is not None:
        word2char = Word2char(dataset.vocab_file_path,
                              dataset.vocab_file_path_sub)

    # Marker at which a hypothesis is cut (word-level outputs carry the
    # preceding space too)
    eos_mark = '_>' if dataset.label_type == 'word' else '>'

    wer = 0
    sub, ins, dele = 0, 0, 0
    num_words = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        batch_size = len(batch['xs'])

        # Decode
        if model.model_type == 'nested_attention':
            if a2c_oracle and dataset.is_test:
                # References are raw transcripts in the test set: convert
                # them to indices and pad with -1
                max_label_num = max(
                    [len(list(batch['ys_sub'][b][0]))
                     for b in range(batch_size)] + [0])

                ys_sub = np.zeros(
                    (batch_size, max_label_num), dtype=np.int32)
                ys_sub -= 1  # pad with -1
                y_lens_sub = np.zeros((batch_size,), dtype=np.int32)
                for b in range(batch_size):
                    indices = dataset.char2idx(batch['ys_sub'][b][0])
                    ys_sub[b, :len(indices)] = indices
                    y_lens_sub[b] = len(indices)
                    # NOTE: transcript is separated by space('_')
            else:
                ys_sub = batch['ys_sub']
                y_lens_sub = batch['y_lens_sub']
                if a2c_oracle:
                    # Previously this case left ys_sub / max_label_num
                    # undefined. NOTE(review): the longest reference length
                    # mirrors the test-set branch above -- confirm that
                    # y_lens_sub counts the same tokens.
                    max_label_num = int(np.max(y_lens_sub))

            best_hyps, aw, best_hyps_sub, aw_sub, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                min_decode_len=min_decode_len,
                beam_width_sub=beam_width_sub,
                max_decode_len_sub=max_label_num if a2c_oracle else max_decode_len_sub,
                min_decode_len_sub=min_decode_len_sub,
                length_penalty=length_penalty,
                coverage_penalty=coverage_penalty,
                teacher_forcing=a2c_oracle,
                ys_sub=ys_sub,
                y_lens_sub=y_lens_sub)
        elif model.model_type == 'hierarchical_attention' and joint_decoding is not None:
            best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                min_decode_len=min_decode_len,
                length_penalty=length_penalty,
                coverage_penalty=coverage_penalty,
                joint_decoding=joint_decoding,
                space_index=dataset.char2idx('_')[0],
                oov_index=dataset.word2idx('OOV')[0],
                word2char=word2char,
                idx2word=dataset.idx2word,
                idx2char=dataset.idx2char,
                score_sub_weight=score_sub_weight)
        else:
            best_hyps, aw, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                min_decode_len=min_decode_len,
                length_penalty=length_penalty,
                coverage_penalty=coverage_penalty)
            if resolving_unk:
                # The sub task outputs are needed only to resolve OOV
                best_hyps_sub, aw_sub, _ = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len_sub,
                    min_decode_len=min_decode_len_sub,
                    length_penalty=length_penalty,
                    coverage_penalty=coverage_penalty,
                    task_index=1)

        # Reorder the references in the same way as the hypotheses
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(batch_size):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = dataset.idx2word(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = dataset.idx2word(best_hyps[b])
            # NOTE: Truncate by the first <EOS>
            # (the former greedy pattern cut at the last one instead)
            str_hyp = str_hyp.split(eos_mark)[0]

            ##############################
            # Resolving UNK
            ##############################
            if resolving_unk and 'OOV' in str_hyp:
                str_hyp = resolve_unk(
                    str_hyp, best_hyps_sub[b], aw[b], aw_sub[b], dataset.idx2char)
                str_hyp = str_hyp.replace('*', '')

            ##############################
            # Post-processing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[@>]+', '', str_ref)
            str_hyp = re.sub(r'[@>]+', '', str_hyp)
            # NOTE: @ means noise

            # Remove consecutive spaces
            str_ref = re.sub(r'[_]+', '_', str_ref)
            str_hyp = re.sub(r'[_]+', '_', str_hyp)

            # Compute WER
            try:
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=str_hyp.split('_'),
                    normalize=False)
                wer += wer_b
                sub += sub_b
                ins += ins_b
                dele += del_b
                num_words += len(str_ref.split('_'))
            except Exception:
                # Best effort: an utterance that cannot be scored is left out
                # of the totals, while Ctrl-C / SystemExit still propagate
                pass

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Normalize the accumulated counts by the number of reference words
    wer /= num_words
    sub /= num_words
    ins /= num_words
    dele /= num_words

    df_word = pd.DataFrame(
        {'SUB': [sub * 100], 'INS': [ins * 100], 'DEL': [dele * 100]},
        columns=['SUB', 'INS', 'DEL'], index=['WER'])

    return wer, df_word
# Example #25
# 0
def eval_char(models,
              dataset,
              beam_width,
              max_decode_len,
              eval_batch_size=None,
              length_penalty=0,
              progressbar=False,
              temperature=1):
    """Evaluate trained model by Word and Character Error Rate.

    Args:
        models (list): the models to evaluate (only the first one is used)
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction when EOS token have not been emitted.
            This is used for seq2seq models.
        eval_batch_size (int, optional): the batch size when evaluating the model
        length_penalty (float, optional): forwarded to `model.decode`
        progressbar (bool, optional): if True, visualize the progressbar
        temperature (float, optional): currently unused in this function
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate
        df_wer_cer (pd.DataFrame): dataframe of substitution, insertion, and deletion
    """
    # Reset data counter
    dataset.reset()

    # TODO: fix this (add CTC ensemble)
    model = models[0]

    # 'ctc' / 'attention' models are scored on the main task, every other
    # model type on the sub task (task_index=1)
    is_main_task = model.model_type in ['ctc', 'attention']
    is_attention = 'attention' in model.model_type

    if is_main_task:
        vocab_path = dataset.vocab_file_path
        label_type = dataset.label_type
        ys_key, y_lens_key = 'ys', 'y_lens'
    else:
        vocab_path = dataset.vocab_file_path_sub
        label_type = dataset.label_type_sub
        ys_key, y_lens_key = 'ys_sub', 'y_lens_sub'
    idx2char = Idx2char(
        vocab_file_path=vocab_path,
        capital_divide=(label_type == 'character_capital_divide'))

    # Read GLM file
    glm = GLM(
        glm_path=
        '/n/sd8/inaguma/corpus/swbd/data/eval2000/LDC2002T43/reference/en20000405_hub5.glm'
    )

    # Accumulators (error / substitution / insertion / deletion / length)
    word_err, word_sub, word_ins, word_del, n_words = 0, 0, 0, 0, 0
    char_err, char_sub, char_ins, char_del, n_chars = 0, 0, 0, 0, 0

    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this

    is_new_epoch = False
    while not is_new_epoch:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        # Decode
        decode_options = {
            'beam_width': beam_width,
            'max_decode_len': max_decode_len,
            'length_penalty': length_penalty,
        }
        if not is_main_task:
            decode_options['task_index'] = 1
        best_hyps, _, perm_idx = model.decode(
            batch['xs'], batch['x_lens'], **decode_options)

        # Reorder the references in the same way as the hypotheses
        ys = batch[ys_key][perm_idx]
        y_lens = batch[y_lens_key][perm_idx]

        num_utt = len(batch['xs'])
        for b in range(num_utt):

            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                # NOTE: transcript is separated by space('_')
                ref = ys[b][0]
            else:
                # Convert from list of index to string
                ref = idx2char(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            hyp = idx2char(best_hyps[b])
            if is_attention:
                # NOTE: Truncate by the first <EOS>
                hyp = hyp.split('>')[0]

                # Remove the last space
                if len(hyp) > 0 and hyp[-1] == '_':
                    hyp = hyp[:-1]

            ##############################
            # Post-processing
            ##############################
            ref = fix_trans(ref, glm)
            hyp = fix_trans(hyp, glm)

            # Empty references are not scored, but still counted as processed
            if len(ref) != 0:
                try:
                    # Compute WER
                    ref_words = ref.split('_')
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref_words,
                        hyp=hyp.split('_'),
                        normalize=False)
                    word_err += err_b
                    word_sub += sub_b
                    word_ins += ins_b
                    word_del += del_b
                    n_words += len(ref_words)

                    # Compute CER
                    ref_chars = ref.replace('_', '')
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=list(ref_chars),
                        hyp=list(hyp.replace('_', '')),
                        normalize=False)
                    char_err += err_b
                    char_sub += sub_b
                    char_ins += ins_b
                    char_del += del_b
                    n_chars += len(ref_chars)
                except:
                    # Utterances that cannot be scored are silently left out
                    pass

            if progressbar:
                pbar.update(1)

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Normalize by the total reference length
    wer = word_err / n_words
    word_sub /= n_words
    word_ins /= n_words
    word_del /= n_words
    cer = char_err / n_chars
    char_sub /= n_chars
    char_ins /= n_chars
    char_del /= n_chars

    df_wer_cer = pd.DataFrame(
        {
            'SUB': [word_sub * 100, char_sub * 100],
            'INS': [word_ins * 100, char_ins * 100],
            'DEL': [word_del * 100, char_del * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['WER', 'CER'])

    return wer, cer, df_wer_cer
    def check(self,
              encoder_type,
              bidirectional=False,
              subsample=False,
              projection=False,
              conv=False,
              batch_norm=False,
              activation='relu',
              encoder_residual=False,
              encoder_dense_residual=False,
              label_smoothing=False):
        """Overfit a small HierarchicalCTC model on one generated batch.

        Builds the model from the given switches, trains it for at most 300
        steps on a single batch from `generate_data`, and every 10 steps
        prints the loss together with the WER of the main (word) task and
        the CER of the sub (character) task. Training stops early once the
        CER falls below 0.1.

        Args:
            encoder_type (str): encoder type passed to the model; 'cnn'
                also enables the convolutional settings
            bidirectional (bool, optional): passed as `encoder_bidirectional`
            subsample (bool, optional): if True, `subsample_list` is
                [True, False] and the input frames are not stacked
            projection (bool, optional): if True, `encoder_num_proj` is 256
            conv (bool, optional): if True, use the convolutional settings
            batch_norm (bool, optional): passed to the model
            activation (str, optional): accepted but not used in this method
            encoder_residual (bool, optional): passed to the model
            encoder_dense_residual (bool, optional): passed to the model
            label_smoothing (bool, optional): if True,
                `label_smoothing_prob` is 0.1, otherwise 0
        """
        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  bidirectional: %s' % str(bidirectional))
        print('  projection: %s' % str(projection))
        print('  subsample: %s' % str(subsample))
        print('  conv: %s' % str(conv))
        print('  batch_norm: %s' % str(batch_norm))
        print('  encoder_residual: %s' % str(encoder_residual))
        print('  encoder_dense_residual: %s' % str(encoder_dense_residual))
        print('  label_smoothing: %s' % str(label_smoothing))
        print('==================================================')

        if conv or encoder_type == 'cnn':
            # pattern 1 (alternative configuration, kept for reference)
            # conv_channels = [32, 32]
            # conv_kernel_sizes = [[41, 11], [21, 11]]
            # conv_strides = [[2, 2], [2, 1]]
            # poolings = [[], []]

            # pattern 2 (VGG like)
            conv_channels = [64, 64]
            conv_kernel_sizes = [[3, 3], [3, 3]]
            conv_strides = [[1, 1], [1, 1]]
            poolings = [[2, 2], [2, 2]]

            fc_list = [786, 786]
        else:
            conv_channels = []
            conv_kernel_sizes = []
            conv_strides = []
            poolings = []
            fc_list = []

        # Load batch data
        # Frames are stacked in pairs only when neither subsampling nor
        # convolution is in use.
        num_stack = 1 if subsample or conv or encoder_type == 'cnn' else 2
        splice = 1
        xs, ys, ys_sub, x_lens, y_lens, y_lens_sub = generate_data(
            label_type='word_char',
            batch_size=2,
            num_stack=num_stack,
            splice=splice)

        num_classes = 11
        num_classes_sub = 27

        # Load model
        model = HierarchicalCTC(
            input_size=xs.shape[-1] // splice // num_stack,  # 120
            encoder_type=encoder_type,
            encoder_bidirectional=bidirectional,
            encoder_num_units=256,
            encoder_num_proj=256 if projection else 0,
            encoder_num_layers=2,
            encoder_num_layers_sub=1,
            fc_list=fc_list,
            fc_list_sub=fc_list,
            dropout_input=0.1,
            dropout_encoder=0.1,
            main_loss_weight=0.8,
            sub_loss_weight=0.2,
            num_classes=num_classes,
            num_classes_sub=num_classes_sub,
            parameter_init_distribution='uniform',
            parameter_init=0.1,
            recurrent_weight_orthogonal=False,
            init_forget_gate_bias_with_one=True,
            subsample_list=[] if not subsample else [True, False],
            num_stack=num_stack,
            splice=splice,
            input_channel=3,
            conv_channels=conv_channels,
            conv_kernel_sizes=conv_kernel_sizes,
            conv_strides=conv_strides,
            poolings=poolings,
            batch_norm=batch_norm,
            label_smoothing_prob=0.1 if label_smoothing else 0,
            weight_noise_std=0,
            encoder_residual=encoder_residual,
            encoder_dense_residual=encoder_dense_residual)

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            print("%s %d" % (name, num_params))
        print("Total %.3f M parameters" % (model.total_parameters / 1000000))

        # Define optimizer
        learning_rate = 1e-3
        model.set_optimizer('adam',
                            learning_rate_init=learning_rate,
                            weight_decay=1e-6,
                            lr_schedule=False,
                            factor=0.1,
                            patience_epoch=5)

        # Define learning rate controller
        lr_controller = Controller(learning_rate_init=learning_rate,
                                   backend='pytorch',
                                   decay_start_epoch=20,
                                   decay_rate=0.9,
                                   decay_patient_epoch=10,
                                   lower_better=True)

        # GPU setting
        model.set_cuda(deterministic=False, benchmark=True)

        # Train model
        max_step = 300
        start_time_step = time.time()
        for step in range(max_step):

            # Step for parameter update
            model.optimizer.zero_grad()
            loss, loss_main, loss_sub = model(xs, ys, x_lens, y_lens, ys_sub,
                                              y_lens_sub)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            model.optimizer.step()

            if (step + 1) % 10 == 0:
                # Compute loss
                loss, loss_main, loss_sub = model(xs,
                                                  ys,
                                                  x_lens,
                                                  y_lens,
                                                  ys_sub,
                                                  y_lens_sub,
                                                  is_eval=True)

                # Decode the main task (task_index=0) and the sub task
                # (task_index=1)
                best_hyps, _, _ = model.decode(xs,
                                               x_lens,
                                               beam_width=2,
                                               task_index=0)
                best_hyps_sub, _, _ = model.decode(xs,
                                                   x_lens,
                                                   beam_width=2,
                                                   task_index=1)

                # Only the first utterance of the batch is scored and shown
                str_ref = idx2word(ys[0, :y_lens[0]])
                str_hyp = idx2word(best_hyps[0])
                str_ref_sub = idx2char(ys_sub[0, :y_lens_sub[0]])
                str_hyp_sub = idx2char(best_hyps_sub[0])

                # Compute accuracy
                try:
                    wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                               hyp=str_hyp.split('_'),
                                               normalize=True)
                    cer, _, _, _ = compute_wer(
                        ref=list(str_ref_sub.replace('_', '')),
                        hyp=list(str_hyp_sub.replace('_', '')),
                        normalize=True)
                except Exception:
                    # Scoring failures count as the worst case. Catching
                    # Exception (not a bare except) keeps Ctrl-C working.
                    wer = 1
                    cer = 1

                duration_step = time.time() - start_time_step
                print(
                    'Step %d: loss=%.3f(%.3f/%.3f) / wer=%.3f / cer=%.3f / lr=%.5f (%.3f sec)'
                    % (step + 1, loss, loss_main, loss_sub, wer, cer,
                       learning_rate, duration_step))
                start_time_step = time.time()

                # Visualize
                print('Ref: %s' % str_ref)
                print('Hyp (word): %s' % str_hyp)
                print('Hyp (char): %s' % str_hyp_sub)

                if cer < 0.1:
                    print('Modle is Converged.')
                    break

                # Update learning rate (the step index is passed as the
                # epoch and the WER as the monitored value)
                model.optimizer, learning_rate = lr_controller.decay_lr(
                    optimizer=model.optimizer,
                    learning_rate=learning_rate,
                    epoch=step,
                    value=wer)
def eval_word(models,
              dataset,
              beam_width,
              max_decode_len,
              beam_width_sub=1,
              max_decode_len_sub=300,
              eval_batch_size=None,
              length_penalty=0,
              progressbar=False,
              temperature=1,
              resolving_unk=False,
              a2c_oracle=False):
    """Evaluate trained model by Word Error Rate.
    Args:
        models (list): the models to evaluate. When more than one model is
            given, their posteriors are averaged; this path asserts that
            the first model is a 'ctc' model.
        dataset: An instance of a `Dataset' class
        beam_width (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction. This is used for seq2seq models.
        beam_width_sub (int, optional): the size of beam in ths sub task
            This is used for the nested attention
        max_decode_len_sub (int, optional): the length of output sequences
            to stop prediction. This is used for the nested attention
        eval_batch_size (int, optional): the batch size when evaluating the model
        length_penalty (float, optional): passed through to `model.decode`
        progressbar (bool, optional): if True, visualize the progressbar
        temperature (int, optional): not used in this function
        resolving_unk (bool, optional): if True, hypotheses containing
            'OOV' are rewritten by `resolve_unk` using the sub-task output
        a2c_oracle (bool, optional): if True, the nested attention model is
            decoded with teacher forcing on the sub-task labels
    Returns:
        wer (float): Word error rate
        df_wer (pd.DataFrame): dataframe of substitution, insertion, and deletion
    Raises:
        ZeroDivisionError: if no reference word was scored
    """
    # Reset data counter
    dataset.reset()

    idx2word = Idx2word(dataset.vocab_file_path)
    if models[0].model_type == 'nested_attention':
        char2idx = Char2idx(dataset.vocab_file_path_sub)
    if resolving_unk:
        # `resolve_unk` below needs this mapper on every path where
        # `resolving_unk` is set. The previous guard compared the model
        # object itself with strings, so the mapper was never created.
        idx2char = Idx2char(dataset.vocab_file_path_sub,
                            capital_divide=dataset.label_type_sub ==
                            'character_capital_divide')

    wer = 0
    sub, ins, dele, = 0, 0, 0
    num_words = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        batch_size = len(batch['xs'])

        # Decode
        if len(models) > 1:
            # Ensemble: average the posteriors of all models
            assert models[0].model_type in ['ctc']
            for i, model in enumerate(models):
                probs, x_lens, perm_idx = model.posteriors(
                    batch['xs'], batch['x_lens'])
                if i == 0:
                    probs_ensemble = probs
                else:
                    probs_ensemble += probs
            probs_ensemble /= len(models)

            best_hyps = models[0].decode_from_probs(probs_ensemble,
                                                    x_lens,
                                                    beam_width=1)
        else:
            model = models[0]
            # TODO: fix this

            if model.model_type == 'nested_attention':
                if a2c_oracle and dataset.is_test:
                    # Test-set references are raw transcripts, so convert
                    # them to padded index arrays for teacher forcing.
                    max_label_num = 0
                    for b in range(batch_size):
                        max_label_num = max(max_label_num,
                                            len(batch['ys_sub'][b][0]))

                    # pad with -1
                    ys_sub = np.full((batch_size, max_label_num), -1,
                                     dtype=np.int32)
                    y_lens_sub = np.zeros((batch_size, ), dtype=np.int32)
                    for b in range(batch_size):
                        indices = char2idx(batch['ys_sub'][b][0])
                        ys_sub[b, :len(indices)] = indices
                        y_lens_sub[b] = len(indices)
                        # NOTE: transcript is seperated by space('_')
                else:
                    ys_sub = batch['ys_sub']
                    y_lens_sub = batch['y_lens_sub']
                    if a2c_oracle:
                        # NOTE(review): this combination used to fail with
                        # a NameError (no labels and no `max_label_num`).
                        # The longest reference length is assumed to be the
                        # right decoding budget here -- confirm.
                        max_label_num = int(max(y_lens_sub))

                best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    beam_width_sub=beam_width_sub,
                    max_decode_len=max_decode_len,
                    max_decode_len_sub=max_label_num
                    if a2c_oracle else max_decode_len_sub,
                    length_penalty=length_penalty,
                    teacher_forcing=a2c_oracle,
                    ys_sub=ys_sub,
                    y_lens_sub=y_lens_sub)
            else:
                best_hyps, aw, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len,
                    length_penalty=length_penalty)
                if resolving_unk:
                    # Second pass over the sub task (task_index=1)
                    best_hyps_sub, aw_sub, _ = model.decode(
                        batch['xs'],
                        batch['x_lens'],
                        beam_width=beam_width,
                        max_decode_len=max_decode_len_sub,
                        length_penalty=length_penalty,
                        task_index=1)

        # Reorder the references the same way the decoder reordered the batch
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(batch_size):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is seperated by space('_')
            else:
                # Convert from list of index to string
                str_ref = idx2word(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = idx2word(best_hyps[b])
            # NOTE: Truncate by the first <EOS>. The previous greedy regex
            # cut at the last '>' instead.
            if dataset.label_type == 'word':
                str_hyp = str_hyp.split('_>')[0]
            else:
                str_hyp = str_hyp.split('>')[0]

            ##############################
            # Resolving UNK
            ##############################
            if resolving_unk and 'OOV' in str_hyp:
                str_hyp = resolve_unk(str_hyp, best_hyps_sub[b], aw[b],
                                      aw_sub[b], idx2char)
                str_hyp = str_hyp.replace('*', '')

            ##############################
            # Post-proccessing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[@>]+', '', str_ref)
            str_hyp = re.sub(r'[@>]+', '', str_hyp)
            # NOTE: @ means noise

            # Remove consecutive spaces
            str_ref = re.sub(r'[_]+', '_', str_ref)
            str_hyp = re.sub(r'[_]+', '_', str_hyp)

            # Compute WER
            try:
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=str_hyp.split('_'),
                    normalize=False)
                wer += wer_b
                sub += sub_b
                ins += ins_b
                dele += del_b
                num_words += len(str_ref.split('_'))
            except Exception:
                # Best effort: an utterance whose scoring fails is left out
                # of the totals. Catching Exception (not a bare except)
                # lets KeyboardInterrupt/SystemExit propagate.
                pass

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Turn the accumulated edit counts into rates per reference word
    wer /= num_words
    sub /= num_words
    ins /= num_words
    dele /= num_words

    df_wer = pd.DataFrame(
        {
            'SUB': [sub * 100],
            'INS': [ins * 100],
            'DEL': [dele * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['WER'])

    return wer, df_wer
def do_eval_cer(models,
                model_type,
                dataset,
                label_type,
                beam_width,
                max_decode_len,
                eval_batch_size=None,
                temperature=1,
                progressbar=False):
    """Evaluate trained models by Character Error Rate.
    Args:
        models (list): the models to evaluate; their posteriors are averaged
        model_type (string): ctc or attention. hierarchical_ctc and
            hierarchical_attention are not implemented.
        dataset: An instance of a `Dataset' class
        label_type (string): character or character_capital_divide
            (not read here; `dataset.label_type` is used instead)
        beam_width: (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction when EOS token have not been emitted.
            This is used for seq2seq models.
        eval_batch_size (int, optional): the batch size when evaluating the model
        temperature (int, optional): passed through to `model.posteriors`
        progressbar (bool, optional): if True, visualize the progressbar
    Returns:
        cer (float): Character error rate
        wer (float): Word error rate
        df_wer_cer (pd.DataFrame): dataframe of substitution, insertion, and
            deletion (row 'CER' first, then row 'WER')
    Raises:
        NotImplementedError: if `model_type` is neither 'attention' nor 'ctc'
        ZeroDivisionError: if no reference word or character was scored
    """
    if model_type not in ['attention', 'ctc']:
        # Reject unsupported types before touching the dataset. Types
        # outside both former lists used to fall through to a NameError.
        raise NotImplementedError(
            'model_type: %s is not supported' % model_type)

    # Reset data counter
    dataset.reset()

    idx2char = Idx2char(
        vocab_file_path=dataset.vocab_file_path,
        capital_divide=(dataset.label_type == 'character_capital_divide'))

    cer, wer = 0, 0
    sub_char, ins_char, del_char = 0, 0, 0
    sub_word, ins_word, del_word = 0, 0, 0
    num_words, num_chars = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        # Decode the ensemble
        for i, model in enumerate(models):
            probs_i, perm_idx = model.posteriors(batch['xs'],
                                                 batch['x_lens'],
                                                 temperature=temperature)
            if i == 0:
                probs = probs_i
            else:
                probs += probs_i
            # NOTE: probs: `[1 (B), T, num_classes]`
        probs /= len(models)

        # `model` is the last model of the list at this point
        best_hyps = model.decode_from_probs(probs,
                                            batch['x_lens'][perm_idx],
                                            beam_width=beam_width,
                                            max_decode_len=max_decode_len)
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):

            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is seperated by space('_')
            else:
                # Convert from list of index to string
                str_ref = idx2char(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = idx2char(best_hyps[b])
            if 'attention' in model.model_type:
                str_hyp = str_hyp.split('>')[0]
                # NOTE: Trancate by the first <EOS>

                # Remove the last space
                if len(str_hyp) > 0 and str_hyp[-1] == '_':
                    str_hyp = str_hyp[:-1]

            # Remove consecutive spaces
            str_hyp = re.sub(r'[_]+', '_', str_hyp)

            ##############################
            # Post-proccessing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[\'>]+', '', str_ref)
            str_hyp = re.sub(r'[\'>]+', '', str_hyp)
            # TODO: is it OK to remove these when computing WER?

            # Compute WER
            wer_b, sub_b, ins_b, del_b = compute_wer(ref=str_ref.split('_'),
                                                     hyp=str_hyp.split('_'),
                                                     normalize=False)
            wer += wer_b
            sub_word += sub_b
            ins_word += ins_b
            del_word += del_b
            num_words += len(str_ref.split('_'))

            # Compute CER
            cer_b, sub_b, ins_b, del_b = compute_wer(
                ref=list(str_ref.replace('_', '')),
                hyp=list(str_hyp.replace('_', '')),
                normalize=False)
            cer += cer_b
            sub_char += sub_b
            ins_char += ins_b
            del_char += del_b
            num_chars += len(str_ref.replace('_', ''))

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Turn the accumulated edit counts into rates per reference unit
    wer /= num_words
    cer /= num_chars
    sub_char /= num_chars
    ins_char /= num_chars
    del_char /= num_chars
    sub_word /= num_words
    ins_word /= num_words
    del_word /= num_words

    df_wer_cer = pd.DataFrame(
        {
            'SUB': [sub_char * 100, sub_word * 100],
            'INS': [ins_char * 100, ins_word * 100],
            'DEL': [del_char * 100, del_word * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['CER', 'WER'])

    return cer, wer, df_wer_cer
示例#29
0
def do_eval_cer(session,
                decode_op,
                model,
                dataset,
                label_type,
                eval_batch_size=None,
                progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Character Error Rate.
    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
        eval_batch_size (int, optional): the batch size when evaluating the
            model (not used in this function)
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
            (each mini batch then has five elements, of which the third and
            the fifth are ignored)
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    Raises:
        ValueError: if `label_type` is not one of the supported values
    """
    # Reset data counter
    dataset.reset()

    if label_type == 'character':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/ctc/character.txt')
    elif label_type == 'character_capital_divide':
        idx2char = Idx2char(
            map_file_path=
            '../metrics/mapping_files/ctc/character_capital_divide.txt',
            capital_divide=True,
            space_mark='_')
    else:
        # Without this branch `idx2char` stays undefined and the loop below
        # fails with a NameError.
        raise ValueError('label_type: %s is not supported' % label_type)

    cer_mean, wer_mean = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:

        # Create feed dictionary for next mini batch
        if is_multitask:
            inputs, labels_true, _, inputs_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, _ = data

        # All keep probabilities are 1.0, i.e. dropout is disabled
        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0
        }

        batch_size_each = len(inputs)

        labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
        labels_pred = sparsetensor2list(labels_pred_st, batch_size_each)
        for i_batch in range(batch_size_each):

            # Convert from list of index to string
            str_true = idx2char(labels_true[i_batch])
            str_pred = idx2char(labels_pred[i_batch])

            # Remove consecutive spaces
            str_pred = re.sub(r'[_]+', '_', str_pred)

            # Remove garbage labels
            str_true = re.sub(r'[\'\":;!?,.-]+', "", str_true)
            str_pred = re.sub(r'[\'\":;!?,.-]+', "", str_pred)

            # Compute WER
            wer_mean += compute_wer(hyp=str_pred.split('_'),
                                    ref=str_true.split('_'),
                                    normalize=True)
            # substitute, insert, delete = wer_align(
            #     ref=str_pred.split('_'),
            #     hyp=str_true.split('_'))
            # print(substitute)
            # print(insert)
            # print(delete)

            # Remove spaces
            str_pred = re.sub(r'[_]+', "", str_pred)
            str_true = re.sub(r'[_]+', "", str_true)

            # Compute CER
            cer_mean += compute_cer(str_pred=str_pred,
                                    str_true=str_true,
                                    normalize=True)

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        # Release the progress bar once the epoch has been consumed
        pbar.close()

    # The per-utterance rates are summed above and averaged over
    # len(dataset)
    cer_mean /= len(dataset)
    wer_mean /= len(dataset)

    return cer_mean, wer_mean
    def check(self, encoder_type, bidirectional=False, label_type='char',
              subsample=False,  projection=False,
              conv=False, batch_norm=False, activation='relu',
              encoder_residual=False, encoder_dense_residual=False,
              label_smoothing=False):
        """Overfit a small CTC model (chainer backend) on one generated batch.

        Builds the model from the given switches, trains it for at most 300
        steps on a single batch from `generate_data`, and every 10 steps
        prints the loss and the label error rate of the first utterance.
        Training stops early once the label error rate falls below 0.05.

        Args:
            encoder_type (str): encoder type passed to the model; 'cnn'
                also enables the convolutional settings
            bidirectional (bool, optional): passed as `encoder_bidirectional`
            label_type (str, optional): 'char' (27 classes) or 'word'
                (11 classes)
            subsample (bool, optional): if True, two encoder layers are
                used, `subsample_list` is [True, True] and the input frames
                are not stacked
            projection (bool, optional): if True, `encoder_num_proj` is 256
            conv (bool, optional): if True, use the convolutional settings
            batch_norm (bool, optional): passed to the model
            activation (str, optional): passed to the model
            encoder_residual (bool, optional): passed to the model
            encoder_dense_residual (bool, optional): passed to the model
            label_smoothing (bool, optional): if True,
                `label_smoothing_prob` is 0.1, otherwise 0
        Raises:
            ValueError: if `label_type` is neither 'char' nor 'word'
        """
        print('==================================================')
        print('  label_type: %s' % label_type)
        print('  encoder_type: %s' % encoder_type)
        print('  bidirectional: %s' % str(bidirectional))
        print('  projection: %s' % str(projection))
        print('  subsample: %s' % str(subsample))
        print('  conv: %s' % str(conv))
        print('  batch_norm: %s' % str(batch_norm))
        print('  activation: %s' % activation)
        print('  encoder_residual: %s' % str(encoder_residual))
        print('  encoder_dense_residual: %s' % str(encoder_dense_residual))
        print('  label_smoothing: %s' % str(label_smoothing))
        print('==================================================')

        if conv or encoder_type == 'cnn':
            # pattern 1 (alternative configuration, kept for reference)
            # conv_channels = [32, 32]
            # conv_kernel_sizes = [[41, 11], [21, 11]]
            # conv_strides = [[2, 2], [2, 1]]
            # poolings = [[], []]

            # pattern 2 (VGG like)
            conv_channels = [64, 64]
            conv_kernel_sizes = [[3, 3], [3, 3]]
            conv_strides = [[1, 1], [1, 1]]
            poolings = [[2, 2], [2, 2]]

            fc_list = [786, 786]
        else:
            conv_channels = []
            conv_kernel_sizes = []
            conv_strides = []
            poolings = []
            fc_list = []

        # Load batch data
        # Frames are stacked in pairs only when neither subsampling nor
        # convolution is in use.
        splice = 1
        num_stack = 1 if subsample or conv or encoder_type == 'cnn' else 2
        xs, ys, x_lens, y_lens = generate_data(
            label_type=label_type,
            batch_size=2,
            num_stack=num_stack,
            splice=splice,
            backend='chainer')

        if label_type == 'char':
            num_classes = 27
            map_fn = idx2char
        elif label_type == 'word':
            num_classes = 11
            map_fn = idx2word
        else:
            # Otherwise `num_classes` and `map_fn` would be undefined below
            raise ValueError('label_type: %s is not supported' % label_type)

        # Load model
        model = CTC(
            input_size=xs[0].shape[-1] // splice // num_stack,  # 120
            encoder_type=encoder_type,
            encoder_bidirectional=bidirectional,
            encoder_num_units=256,
            encoder_num_proj=256 if projection else 0,
            encoder_num_layers=1 if not subsample else 2,
            fc_list=fc_list,
            dropout_input=0.1,
            dropout_encoder=0.1,
            num_classes=num_classes,
            parameter_init_distribution='uniform',
            parameter_init=0.1,
            recurrent_weight_orthogonal=False,
            init_forget_gate_bias_with_one=True,
            subsample_list=[] if not subsample else [True] * 2,
            num_stack=num_stack,
            splice=splice,
            input_channel=3,
            conv_channels=conv_channels,
            conv_kernel_sizes=conv_kernel_sizes,
            conv_strides=conv_strides,
            poolings=poolings,
            activation=activation,
            batch_norm=batch_norm,
            label_smoothing_prob=0.1 if label_smoothing else 0,
            weight_noise_std=0,
            encoder_residual=encoder_residual,
            encoder_dense_residual=encoder_dense_residual)

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            print("%s %d" % (name, num_params))
        print("Total %.3f M parameters" % (model.total_parameters / 1000000))

        # Define optimizer
        learning_rate = 1e-3
        model.set_optimizer(
            'adam',
            # 'adadelta',
            learning_rate_init=learning_rate,
            weight_decay=1e-6,
            clip_grad_norm=5,
            lr_schedule=None,
            factor=None,
            patience_epoch=None)

        # Define learning rate controller
        lr_controller = Controller(learning_rate_init=learning_rate,
                                   backend='chainer',
                                   decay_start_epoch=20,
                                   decay_rate=0.9,
                                   decay_patient_epoch=10,
                                   lower_better=True)

        # GPU setting
        model.set_cuda(deterministic=False, benchmark=True)

        # Train model
        max_step = 300
        start_time_step = time.time()
        for step in range(max_step):

            # Step for parameter update
            loss = model(xs, ys, x_lens, y_lens)
            model.optimizer.target.cleargrads()
            model.cleargrads()
            loss.backward()
            loss.unchain_backward()
            model.optimizer.update()

            # Inject Gaussian noise to all parameters
            # (placeholder: no code is present for this step)

            if (step + 1) % 10 == 0:
                # Compute loss
                loss = model(xs, ys, x_lens, y_lens, is_eval=True)

                # Decode
                best_hyps, _, _ = model.decode(xs, x_lens, beam_width=1)
                # TODO: fix beam search

                # Only the first utterance of the batch is scored and shown
                str_ref = map_fn(ys[0, :y_lens[0]])
                str_hyp = map_fn(best_hyps[0])

                # Compute accuracy
                try:
                    if label_type == 'char':
                        ler, _, _, _ = compute_wer(
                            ref=list(str_ref.replace('_', '')),
                            hyp=list(str_hyp.replace('_', '')),
                            normalize=True)
                    elif label_type == 'word':
                        ler, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                                   hyp=str_hyp.split('_'),
                                                   normalize=True)
                except Exception:
                    # Scoring failures count as the worst case. Catching
                    # Exception (not a bare except) keeps Ctrl-C working.
                    ler = 1

                duration_step = time.time() - start_time_step
                print('Step %d: loss=%.3f / ler=%.3f / lr=%.5f (%.3f sec)' %
                      (step + 1, loss, ler, learning_rate, duration_step))
                start_time_step = time.time()

                # Visualize
                print('Ref: %s' % str_ref)
                print('Hyp: %s' % str_hyp)

                if ler < 0.05:
                    print('Modle is Converged.')
                    break

                # Update learning rate (the step index is passed as the
                # epoch and the label error rate as the monitored value)
                model.optimizer, learning_rate = lr_controller.decay_lr(
                    optimizer=model.optimizer,
                    learning_rate=learning_rate,
                    epoch=step,
                    value=ler)
def decode(model,
           dataset,
           eval_batch_size,
           beam_width,
           length_penalty,
           save_path=None):
    """Visualize label outputs.

    Decodes every batch yielded by `dataset`, then prints the reference, the
    hypothesis and the per-utterance WER (for `word` / `kanji_wb` labels) or
    CER (otherwise). When the model is an attention model trained with an
    auxiliary CTC loss (`ctc_loss_weight > 0`), the CTC hypothesis and its
    error rate are printed as well.
    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model
            (not used by this function; kept for interface compatibility)
        beam_width: (int): the size of beam
        length_penalty (float): not used by this function; kept for
            interface compatibility
        save_path (string): if not None, the output is written to
            `<model.model_dir>/decode.txt` instead of the console.
            NOTE(review): only the None-ness of this argument is used; the
            destination is always derived from `model.model_dir`.
    """
    if 'word' in dataset.label_type:
        map_fn = Idx2word(dataset.vocab_file_path)
        max_decode_len = MAX_DECODE_LEN_WORD
    else:
        map_fn = Idx2char(dataset.vocab_file_path)
        max_decode_len = MAX_DECODE_LEN_CHAR

    # Loop invariant: is there an auxiliary CTC output to decode as well?
    use_ctc = model.model_type == 'attention' and model.ctc_loss_weight > 0
    word_level = dataset.label_type in ('word', 'kanji_wb')

    # Redirect print() output to a file, but remember the previous stream so
    # that it can be restored (and the file closed) when decoding finishes or
    # fails; otherwise stdout stays hijacked for the rest of the process.
    original_stdout = sys.stdout
    out_file = None
    if save_path is not None:
        out_file = open(join(model.model_dir, 'decode.txt'), 'w')
        sys.stdout = out_file

    try:
        for batch, is_new_epoch in dataset:

            # Decode
            if model.model_type == 'nested_attention':
                best_hyps, _, best_hyps_sub, _, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len,
                    max_decode_len_sub=max_decode_len)
            else:
                best_hyps, _, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len)

            if use_ctc:
                best_hyps_ctc, perm_idx = model.decode_ctc(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width)

            # Reorder the references with the permutation returned by the
            # decoder so that they line up with the hypotheses
            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]

            for b in range(len(batch['xs'])):
                ##############################
                # Reference
                ##############################
                if dataset.is_test:
                    str_ref = ys[b][0]
                    # NOTE: transcript is separated by space('_')
                else:
                    # Convert from list of index to string
                    str_ref = map_fn(ys[b][:y_lens[b]])

                ##############################
                # Hypothesis
                ##############################
                # Convert from list of index to string
                str_hyp = map_fn(best_hyps[b])

                print('----- wav: %s -----' % batch['input_names'][b])
                print('Ref: %s' % str_ref.replace('_', ' '))
                print('Hyp: %s' % str_hyp.replace('_', ' '))
                if use_ctc:
                    str_hyp_ctc = map_fn(best_hyps_ctc[b])
                    print('Hyp (CTC): %s' % str_hyp_ctc)

                # Scoring is best-effort: a failure on one utterance must not
                # abort the whole run. Only ordinary errors are swallowed, so
                # that Ctrl-C / SystemExit still propagate.
                try:
                    if word_level:
                        # Drop everything from the '>' marker onwards before
                        # scoring
                        wer, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=re.sub(r'(.*)[_]*>(.*)', r'\1',
                                       str_hyp).split('_'),
                            normalize=True)
                        print('WER: %.3f %%' % (wer * 100))
                        if use_ctc:
                            wer_ctc, _, _, _ = compute_wer(
                                ref=str_ref.split('_'),
                                hyp=str_hyp_ctc.split('_'),
                                normalize=True)
                            print('WER (CTC): %.3f %%' % (wer_ctc * 100))
                    else:
                        # Character level: word boundaries ('_') are removed
                        cer, _, _, _ = compute_wer(
                            ref=list(str_ref.replace('_', '')),
                            hyp=list(
                                re.sub(r'(.*)>(.*)', r'\1',
                                       str_hyp).replace('_', '')),
                            normalize=True)
                        print('CER: %.3f %%' % (cer * 100))
                        if use_ctc:
                            cer_ctc, _, _, _ = compute_wer(
                                ref=list(str_ref.replace('_', '')),
                                hyp=list(str_hyp_ctc.replace('_', '')),
                                normalize=True)
                            print('CER (CTC): %.3f %%' % (cer_ctc * 100))
                except Exception:
                    print('--- skipped ---')

            if is_new_epoch:
                break
    finally:
        if out_file is not None:
            sys.stdout = original_stdout
            out_file.close()
示例#32
0
def decode(model, dataset, beam_width, max_decode_len,
           eval_batch_size=None, save_path=None):
    """Visualize label outputs.

    Decodes every batch yielded by `dataset`, then prints the reference, the
    hypothesis and the per-utterance WER (when `dataset.label_type` contains
    'word') or CER (otherwise). When the model is an attention model trained
    with an auxiliary CTC loss (`ctc_loss_weight > 0`), the CTC hypothesis and
    its error rate are printed as well.
    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width: (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction when EOS token have not been emitted.
            This is used for seq2seq models.
        eval_batch_size (int, optional): the batch size when evaluating the
            model. If given, `dataset.batch_size` is overwritten with it.
        save_path (string): if not None, the output is written to
            `<model.model_dir>/decode.txt` instead of the console.
            NOTE(review): only the None-ness of this argument is used; the
            destination is always derived from `model.model_dir`.
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    vocab_file_path = '../metrics/vocab_files/' + \
        dataset.label_type + '_' + dataset.data_size + '.txt'
    if dataset.label_type == 'character':
        map_fn = Idx2char(vocab_file_path)
    elif dataset.label_type == 'character_capital_divide':
        map_fn = Idx2char(vocab_file_path, capital_divide=True)
    else:
        map_fn = Idx2word(vocab_file_path)

    # Loop invariants. `str_hyp_ctc` only exists when `use_ctc` is true, so
    # every CTC-related branch below is guarded by this single flag.
    is_attention = model.model_type == 'attention'
    use_ctc = is_attention and model.ctc_loss_weight > 0

    # Redirect print() output to a file, but remember the previous stream so
    # that it can be restored (and the file closed) when decoding finishes or
    # fails; otherwise stdout stays hijacked for the rest of the process.
    original_stdout = sys.stdout
    out_file = None
    if save_path is not None:
        out_file = open(join(model.model_dir, 'decode.txt'), 'w')
        sys.stdout = out_file

    try:
        for batch, is_new_epoch in dataset:

            # Decode
            best_hyps, perm_idx = model.decode(batch['xs'], batch['x_lens'],
                                               beam_width=beam_width,
                                               max_decode_len=max_decode_len)

            if use_ctc:
                best_hyps_ctc, perm_idx = model.decode_ctc(
                    batch['xs'], batch['x_lens'], beam_width=beam_width)

            # Reorder the references with the permutation returned by the
            # decoder so that they line up with the hypotheses
            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]

            for b in range(len(batch['xs'])):

                ##############################
                # Reference
                ##############################
                if dataset.is_test:
                    str_ref = ys[b][0]
                    # NOTE: transcript is separated by space('_')
                else:
                    # Convert from list of index to string
                    str_ref = map_fn(ys[b][:y_lens[b]])

                ##############################
                # Hypothesis
                ##############################
                # Convert from list of index to string
                str_hyp = map_fn(best_hyps[b])

                if is_attention:
                    str_hyp = str_hyp.split('>')[0]
                    # NOTE: Truncate by the first <EOS>

                    # Remove the last space
                    if len(str_hyp) > 0 and str_hyp[-1] == '_':
                        str_hyp = str_hyp[:-1]

                ##############################
                # Post-processing
                ##############################
                # Remove garbage labels
                str_ref = re.sub(r'[\'>]+', '', str_ref)
                str_hyp = re.sub(r'[\'>]+', '', str_hyp)

                print('----- wav: %s -----' % batch['input_names'][b])
                print('Ref: %s' % str_ref.replace('_', ' '))
                print('Hyp: %s' % str_hyp.replace('_', ' '))
                if use_ctc:
                    str_hyp_ctc = map_fn(best_hyps_ctc[b])
                    print('Hyp (CTC): %s' % str_hyp_ctc)

                if 'word' in dataset.label_type:
                    # Compute WER over '_'-separated words
                    wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                               hyp=str_hyp.split('_'),
                                               normalize=True)
                    print('WER: %f %%' % (wer * 100))
                    if use_ctc:
                        wer_ctc, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=str_hyp_ctc.split('_'),
                            normalize=True)
                        print('WER (CTC): %f %%' % (wer_ctc * 100))
                else:
                    # Compute CER over characters; word boundaries ('_') are
                    # removed so the rate is really character-level
                    chars_ref = list(str_ref.replace('_', ''))
                    cer, _, _, _ = compute_wer(
                        ref=chars_ref,
                        hyp=list(str_hyp.replace('_', '')),
                        normalize=True)
                    print('CER: %f %%' % (cer * 100))
                    if use_ctc:
                        # Score the CTC hypothesis, not the attention one
                        cer_ctc, _, _, _ = compute_wer(
                            ref=chars_ref,
                            hyp=list(str_hyp_ctc.replace('_', '')),
                            normalize=True)
                        print('CER (CTC): %f %%' % (cer_ctc * 100))

            if is_new_epoch:
                break
    finally:
        if out_file is not None:
            sys.stdout = original_stdout
            out_file.close()