def do_eval_cer(save_paths,
                dataset,
                data_type,
                label_type,
                num_classes,
                beam_width,
                temperature_infer,
                is_test=False,
                progressbar=False):
    """Evaluate an ensemble of trained models by Character Error Rate.

    For every utterance the posteriors previously saved by each model are
    loaded from disk, averaged over the models and decoded with a beam
    search decoder. The per-utterance CER and WER are summed and divided by
    `len(dataset)`.

    Args:
        save_paths (list): one root directory per model; posteriors are read
            from `<save_path>/temp<temperature_infer>/<data_type>/probs_utt/
            <speaker>/<utterance name>.npy`
        dataset: An instance of a `Dataset` class
        data_type (string): used as a directory name in the posterior path
        label_type (string): character
        num_classes (int): `num_classes - 1` is used as the blank index
        beam_width (int): the size of beam
        temperature_infer (int): temperature in the inference stage
        is_test (bool, optional): set to True when evaluating by the test set
        progressbar (bool, optional): if True, visualize the progressbar
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    Raises:
        TypeError: if `label_type` is not 'character'
    """
    if label_type != 'character':
        raise TypeError(
            'label_type must be "character", got %r' % (label_type,))

    idx2char = Idx2char(
        map_file_path='../metrics/mapping_files/character.txt')
    char2idx = Char2idx(
        map_file_path='../metrics/mapping_files/character.txt')

    # Define decoder
    decoder = BeamSearchDecoder(space_index=char2idx(str_char='_')[0],
                                blank_index=num_classes - 1)

    ##################################################
    # Compute mean probabilities
    ##################################################
    if progressbar:
        pbar = tqdm(total=len(dataset))
    cer_mean, wer_mean = 0, 0
    for data, is_new_epoch in dataset:

        inputs, labels_true, inputs_seq_len, input_names = data

        batch_size = inputs[0].shape[0]
        for i_batch in range(batch_size):
            # The utterance name and its speaker prefix are the same for
            # every model, so resolve them once per utterance
            utt_name = input_names[0][i_batch]
            speaker = utt_name.split('-')[0]

            probs_ensemble = None
            for save_path in save_paths:

                # Load posteriors
                prob_save_path = join(save_path,
                                      'temp' + str(temperature_infer),
                                      data_type, 'probs_utt', speaker,
                                      utt_name + '.npy')
                probs_model_i = np.load(prob_save_path)
                # NOTE: probs_model_i: `[T, num_classes]`

                # Sum over probs
                if probs_ensemble is None:
                    probs_ensemble = probs_model_i
                else:
                    probs_ensemble += probs_model_i

            # Compute mean posteriors
            probs_ensemble /= len(save_paths)

            # Decode per utterance
            labels_pred, scores = decoder(
                probs=probs_ensemble[np.newaxis, :, :],
                seq_len=inputs_seq_len[0][i_batch:i_batch + 1],
                beam_width=beam_width)

            # Convert from list of index to string
            if is_test:
                str_true = labels_true[0][i_batch][0]
                # NOTE: transcript is separated by space('_')
            else:
                str_true = idx2char(labels_true[0][i_batch],
                                    padded_value=dataset.padded_value)
            str_pred = idx2char(labels_pred[0])

            # Remove consecutive spaces
            str_pred = re.sub(r'[_]+', '_', str_pred)

            # Remove garbage labels
            str_true = re.sub(r'[\']+', '', str_true)
            str_pred = re.sub(r'[\']+', '', str_pred)

            # Compute WER
            # NOTE: the reference is the ground truth and the hypothesis is
            # the prediction (they were previously passed the wrong way
            # round, which normalized the WER by the hypothesis)
            wer_mean += compute_wer(ref=str_true.split('_'),
                                    hyp=str_pred.split('_'),
                                    normalize=True)

            # Remove spaces
            str_true = re.sub(r'[_]+', '', str_true)
            str_pred = re.sub(r'[_]+', '', str_pred)

            # Compute CER
            cer_mean += compute_cer(str_pred=str_pred,
                                    str_true=str_true,
                                    normalize=True)

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    cer_mean /= (len(dataset))
    wer_mean /= (len(dataset))
    # TODO: Fix this

    return cer_mean, wer_mean
# Example no. 2
# 0
def do_eval_cer(session,
                decode_op,
                model,
                dataset,
                label_type,
                eval_batch_size=None,
                progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Character Error Rate.

    Every mini-batch of `dataset` is decoded with `decode_op`. The
    per-utterance CER and WER are summed and divided by `len(dataset)`.

    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
        eval_batch_size (int, optional): the batch size when evaluating the
            model. NOTE: accepted for interface compatibility but not used
            by this function
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    Raises:
        TypeError: if `label_type` is not supported
    """
    # Reset data counter
    dataset.reset()

    if label_type == 'character':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/ctc/character.txt')
    elif label_type == 'character_capital_divide':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/ctc/character_capital_divide.txt',
            capital_divide=True,
            space_mark='_')
    else:
        # Fail here instead of hitting an unbound `idx2char` further down
        raise TypeError(
            'label_type must be "character" or "character_capital_divide", '
            'got %r' % (label_type,))

    cer_mean, wer_mean = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:

        # Create feed dictionary for next mini batch
        if is_multitask:
            inputs, labels_true, _, inputs_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, _ = data

        # Keep probabilities of 1.0 are fed for all dropout placeholders
        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0
        }

        batch_size_each = len(inputs)

        labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
        labels_pred = sparsetensor2list(labels_pred_st, batch_size_each)
        for i_batch in range(batch_size_each):

            # Convert from list of index to string
            str_true = idx2char(labels_true[i_batch])
            str_pred = idx2char(labels_pred[i_batch])

            # Remove consecutive spaces
            str_pred = re.sub(r'[_]+', '_', str_pred)

            # Remove garbage labels
            str_true = re.sub(r'[\'\":;!?,.-]+', "", str_true)
            str_pred = re.sub(r'[\'\":;!?,.-]+', "", str_pred)

            # Compute WER
            wer_mean += compute_wer(ref=str_true.split('_'),
                                    hyp=str_pred.split('_'),
                                    normalize=True)

            # Remove spaces
            str_pred = re.sub(r'[_]+', "", str_pred)
            str_true = re.sub(r'[_]+', "", str_true)

            # Compute CER
            cer_mean += compute_cer(str_pred=str_pred,
                                    str_true=str_true,
                                    normalize=True)

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    cer_mean /= len(dataset)
    wer_mean /= len(dataset)

    return cer_mean, wer_mean
def do_eval_cer(save_paths, dataset, data_type, label_type, num_classes,
                beam_width, temperature_infer,
                is_test=False, progressbar=False):
    """Evaluate an ensemble of trained models by Character Error Rate.

    The posteriors saved by each model are loaded per utterance, averaged
    over the models and decoded with a beam search decoder. The
    per-utterance CER and WER are summed and divided by `len(dataset)`.

    Args:
        save_paths (list): one root directory per model; posteriors are read
            from `<save_path>/temp<temperature_infer>/<data_type>/probs_utt/
            <speaker>/<utterance name>.npy`
        dataset: An instance of a `Dataset` class
        data_type (string): used as a directory name in the posterior path
        label_type (string): character
        num_classes (int): `num_classes - 1` is used as the blank index
        beam_width (int): the size of beam
        temperature_infer (int): temperature in the inference stage
        is_test (bool, optional): set to True when evaluating by the test set
        progressbar (bool, optional): if True, visualize the progressbar
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    Raises:
        TypeError: if `label_type` is not 'character'
    """
    if label_type != 'character':
        raise TypeError(
            'label_type must be "character", got %r' % (label_type,))

    idx2char = Idx2char(
        map_file_path='../metrics/mapping_files/character.txt')
    char2idx = Char2idx(
        map_file_path='../metrics/mapping_files/character.txt')

    # Define decoder
    decoder = BeamSearchDecoder(space_index=char2idx(str_char='_')[0],
                                blank_index=num_classes - 1)

    ##################################################
    # Compute mean probabilities
    ##################################################
    if progressbar:
        pbar = tqdm(total=len(dataset))
    cer_mean, wer_mean = 0, 0
    for data, is_new_epoch in dataset:

        inputs, labels_true, inputs_seq_len, input_names = data

        batch_size = inputs[0].shape[0]
        for i_batch in range(batch_size):
            # Same utterance (and speaker directory) for every model
            utt_name = input_names[0][i_batch]
            speaker = utt_name.split('-')[0]

            probs_ensemble = None
            for save_path in save_paths:

                # Load posteriors
                prob_save_path = join(
                    save_path, 'temp' + str(temperature_infer),
                    data_type, 'probs_utt', speaker,
                    utt_name + '.npy')
                probs_model_i = np.load(prob_save_path)
                # NOTE: probs_model_i: `[T, num_classes]`

                # Sum over probs
                if probs_ensemble is None:
                    probs_ensemble = probs_model_i
                else:
                    probs_ensemble += probs_model_i

            # Compute mean posteriors
            probs_ensemble /= len(save_paths)

            # Decode per utterance
            labels_pred, scores = decoder(
                probs=probs_ensemble[np.newaxis, :, :],
                seq_len=inputs_seq_len[0][i_batch: i_batch + 1],
                beam_width=beam_width)

            # Convert from list of index to string
            if is_test:
                str_true = labels_true[0][i_batch][0]
                # NOTE: transcript is separated by space('_')
            else:
                str_true = idx2char(labels_true[0][i_batch],
                                    padded_value=dataset.padded_value)
            str_pred = idx2char(labels_pred[0])

            # Remove consecutive spaces
            str_pred = re.sub(r'[_]+', '_', str_pred)

            # Remove garbage labels
            str_true = re.sub(r'[\']+', '', str_true)
            str_pred = re.sub(r'[\']+', '', str_pred)

            # Compute WER
            # NOTE: ground truth is the reference, prediction is the
            # hypothesis (the two were previously swapped, so the WER was
            # normalized by the hypothesis)
            wer_mean += compute_wer(ref=str_true.split('_'),
                                    hyp=str_pred.split('_'),
                                    normalize=True)

            # Remove spaces
            str_true = re.sub(r'[_]+', '', str_true)
            str_pred = re.sub(r'[_]+', '', str_pred)

            # Compute CER
            cer_mean += compute_cer(str_pred=str_pred,
                                    str_true=str_true,
                                    normalize=True)

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    cer_mean /= (len(dataset))
    wer_mean /= (len(dataset))
    # TODO: Fix this

    return cer_mean, wer_mean
# Example no. 4
# 0
def do_eval_cer(session,
                decode_ops,
                model,
                dataset,
                label_type,
                train_data_size,
                is_test=False,
                eval_batch_size=None,
                progressbar=False):
    """Evaluate trained model by Character Error Rate.

    Every mini-batch of `dataset` is decoded with `decode_ops` (one
    operation per device). The per-utterance CER is summed and divided by
    `len(dataset)`. `dataset.batch_size` is temporarily replaced by
    `eval_batch_size` and is always restored, even when an error is raised.

    Args:
        session: session of training model
        decode_ops (list): operations for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): a label type containing 'kanji' or 'kana'
            (e.g. kanji, kanji_divide, kana_divide)
        train_data_size (string): train_subset or train_fullset; only used
            to select the mapping file of kanji label types
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
    Return:
        cer_mean (float): An average of CER
    Raises:
        TypeError: if `label_type` contains neither 'kanji' nor 'kana'
    """
    assert isinstance(decode_ops, list), "decode_ops must be a list."

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    try:
        if 'kanji' in label_type:
            map_file_path = '../metrics/mapping_files/' + \
                label_type + '_' + train_data_size + '.txt'
        elif 'kana' in label_type:
            map_file_path = '../metrics/mapping_files/' + label_type + '.txt'
        else:
            raise TypeError(
                'label_type must contain "kanji" or "kana", got %r' %
                (label_type,))

        idx2char = Idx2char(map_file_path=map_file_path)

        cer_mean = 0
        if progressbar:
            pbar = tqdm(total=len(dataset))
        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini batch
            inputs, labels_true, inputs_seq_len, labels_seq_len, _ = data

            feed_dict = {}
            for i_device in range(len(decode_ops)):
                feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
                feed_dict[model.inputs_seq_len_pl_list[i_device]] = \
                    inputs_seq_len[i_device]
                feed_dict[model.keep_prob_encoder_pl_list[i_device]] = 1.0
                feed_dict[model.keep_prob_decoder_pl_list[i_device]] = 1.0
                feed_dict[model.keep_prob_embedding_pl_list[i_device]] = 1.0

            labels_pred_list = session.run(decode_ops, feed_dict=feed_dict)
            for i_device, labels_pred in enumerate(labels_pred_list):
                for i_batch in range(len(inputs[i_device])):

                    # Convert from list of index to string
                    if is_test:
                        str_true = labels_true[i_device][i_batch][0]
                        # NOTE: transcript is separated by space('_')
                    else:
                        # Drop the first and the last label of the reference
                        # (presumably <SOS> and <EOS> -- TODO confirm)
                        label_len = labels_seq_len[i_device][i_batch]
                        str_true = idx2char(
                            labels_true[i_device][i_batch][1:label_len - 1])
                    str_pred = idx2char(labels_pred[i_batch]).split('>')[0]
                    # NOTE: Truncate by <EOS>

                    # Remove garbage labels
                    str_true = re.sub(r'[_NZー・<>]+', '', str_true)
                    str_pred = re.sub(r'[_NZー・<>]+', '', str_pred)

                    # Compute CER
                    cer_mean += compute_cer(str_pred=str_pred,
                                            str_true=str_true,
                                            normalize=True)

                    if progressbar:
                        pbar.update(1)

            if is_new_epoch:
                break

        cer_mean /= len(dataset)
    finally:
        # Register original batch size (also when evaluation failed, so the
        # caller's dataset is never left with the evaluation batch size)
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original

    return cer_mean
def do_eval_cer(session, decode_op, model, dataset, label_type, ss_type,
                is_test=False, eval_batch_size=None, progressbar=False):
    """Evaluate trained model by Character Error Rate.

    Every mini-batch of `dataset` is decoded with `decode_op`. A mini-batch
    whose predictions cannot be converted is skipped as a whole: it does not
    contribute to the CER sum and its size is removed from the divisor.
    `dataset.batch_size` is temporarily replaced by `eval_batch_size` and is
    always restored, even when an error is raised.

    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): kana
        ss_type (string): remove or insert_left or insert_both or insert_right
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
    Return:
        cer_mean (float): An average of CER over the non-skipped data
    """
    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    try:
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/' + label_type + '_' + ss_type + '.txt')

        cer_mean = 0
        skip_data_num = 0
        if progressbar:
            pbar = tqdm(total=len(dataset))
        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini batch
            inputs, labels_true, inputs_seq_len, _ = data

            feed_dict = {
                model.inputs_pl_list[0]: inputs[0],
                model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
                model.keep_prob_pl_list[0]: 1.0
            }

            batch_size = inputs[0].shape[0]

            labels_pred_st = session.run(decode_op, feed_dict=feed_dict)

            try:
                labels_pred = sparsetensor2list(labels_pred_st, batch_size)

                # Accumulate per batch so that a batch failing half-way does
                # not leave a partial sum behind while being counted as
                # skipped in the divisor
                cer_batch = 0
                for i_batch in range(batch_size):

                    # Convert from list of index to string
                    if is_test:
                        str_true = labels_true[0][i_batch][0]
                        # NOTE: transcript is separated by space('_')
                    else:
                        str_true = idx2char(labels_true[0][i_batch],
                                            padded_value=dataset.padded_value)
                    str_pred = idx2char(labels_pred[i_batch])

                    # Remove garbage labels
                    str_true = re.sub(r'[_NZLFBDlfbdー]+', '', str_true)
                    str_pred = re.sub(r'[_NZLFBDlfbdー]+', '', str_pred)

                    # Compute CER
                    cer_batch += compute_cer(str_pred=str_pred,
                                             str_true=str_true,
                                             normalize=True)
            except Exception:
                # Best-effort skip of a broken batch. `Exception` (not a
                # bare `except`) keeps KeyboardInterrupt/SystemExit working
                print('skipped')
                skip_data_num += batch_size
                # TODO: Conduct decoding again with batch size 1
            else:
                cer_mean += cer_batch

            # One update per batch keeps the bar consistent for processed
            # and skipped batches alike
            if progressbar:
                pbar.update(batch_size)

            if is_new_epoch:
                break

        cer_mean /= (len(dataset) - skip_data_num)
    finally:
        # Register original batch size (also when evaluation failed)
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original

    return cer_mean
    def check(self, decoder_type):
        """Decode one generated mini-batch with a restored CTC model and print the result.

        A CTC model is built in a fresh graph, its weights are restored from
        the checkpoint found in './', and the mini-batch returned by
        `generate_data` is decoded with the requested decoder. CER, the
        reference and the hypothesis of the first utterance are printed.

        Args:
            decoder_type (string): the branches below handle tf_greedy,
                tf_beam_search, np_greedy and np_beam_search
        Raises:
            ValueError: if no checkpoint is found in './'
        """

        print('==================================================')
        print('  decoder_type: %s' % decoder_type)
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 2
            num_stack = 2
            inputs, labels, inputs_seq_len = generate_data(
                label_type='character',
                model='ctc',
                batch_size=batch_size,
                num_stack=num_stack,
                splice=1)
            # Used below to reshape the posteriors to
            # `[-1, max_time, num_classes]`
            max_time = inputs.shape[1]

            # Define model graph
            model = CTC(encoder_type='blstm',
                        input_size=inputs[0].shape[-1],
                        splice=1,
                        num_stack=num_stack,
                        num_units=256,
                        num_layers=2,
                        num_classes=27,
                        lstm_impl='LSTMBlockCell',
                        parameter_init=0.1,
                        clip_grad_norm=5.0,
                        clip_activation=50,
                        num_proj=256,
                        weight_decay=1e-6)

            # Define placeholders
            model.create_placeholders()

            # Add to the graph each operation
            _, logits = model.compute_loss(
                model.inputs_pl_list[0],
                model.labels_pl_list[0],
                model.inputs_seq_len_pl_list[0],
                model.keep_prob_pl_list[0])
            # Beam width 20 for the beam search decoders, 1 otherwise
            beam_width = 20 if 'beam_search' in decoder_type else 1
            decode_op = model.decoder(logits,
                                      model.inputs_seq_len_pl_list[0],
                                      beam_width=beam_width)
            ler_op = model.compute_ler(decode_op, model.labels_pl_list[0])
            posteriors_op = model.posteriors(logits, blank_prior=1)

            # Numpy decoders; `decoder` stays undefined for the tf_* types,
            # which only use `decode_op`
            # NOTE(review): the greedy decoder gets
            # `blank_index=model.num_classes` while the beam search decoder
            # gets `model.num_classes - 1` -- confirm which index the blank
            # label really has
            if decoder_type == 'np_greedy':
                decoder = GreedyDecoder(blank_index=model.num_classes)
            elif decoder_type == 'np_beam_search':
                decoder = BeamSearchDecoder(space_index=26,
                                            blank_index=model.num_classes - 1)

            # Make feed dict
            feed_dict = {
                model.inputs_pl_list[0]: inputs,
                model.labels_pl_list[0]: list2sparsetensor(labels,
                                                           padded_value=-1),
                model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                model.keep_prob_pl_list[0]: 1.0
            }

            # Create a saver for writing training checkpoints
            # (only used here to restore an existing checkpoint)
            saver = tf.train.Saver()

            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state('./')

                # If check point exists
                if ckpt:
                    model_path = ckpt.model_checkpoint_path
                    saver.restore(sess, model_path)
                    print("Model restored: " + model_path)
                else:
                    raise ValueError('There are not any checkpoints.')

                if decoder_type in ['tf_greedy', 'tf_beam_search']:
                    # Decode
                    labels_pred_st = sess.run(decode_op, feed_dict=feed_dict)
                    labels_pred = sparsetensor2list(
                        labels_pred_st, batch_size=batch_size)

                    # Compute accuracy
                    # (value of `ler_op` for the fed mini-batch)
                    cer = sess.run(ler_op, feed_dict=feed_dict)
                else:
                    # Compute CTC posteriors
                    probs = sess.run(posteriors_op, feed_dict=feed_dict)
                    probs = probs.reshape(-1, max_time, model.num_classes)

                    if decoder_type == 'np_greedy':
                        # Decode
                        labels_pred = decoder(probs=probs,
                                              seq_len=inputs_seq_len)

                    elif decoder_type == 'np_beam_search':
                        # Decode
                        labels_pred, scores = decoder(probs=probs,
                                                      seq_len=inputs_seq_len,
                                                      beam_width=beam_width)

                    # Compute accuracy
                    # (first utterance of the mini-batch only)
                    cer = compute_cer(str_pred=idx2alpha(labels_pred[0]),
                                      str_true=idx2alpha(labels[0]),
                                      normalize=True)

                # Visualize
                print('CER: %.3f %%' % (cer * 100))
                print('Ref: %s' % idx2alpha(labels[0]))
                print('Hyp: %s' % idx2alpha(labels_pred[0]))
def do_eval_cer2(session,
                 posteriors_ops,
                 beam_width,
                 model,
                 dataset,
                 label_type,
                 is_test=False,
                 eval_batch_size=None,
                 progressbar=False,
                 is_multitask=False):
    """Evaluate trained model by Character Error Rate.

    The posteriors computed by `posteriors_ops` (one operation per device)
    are decoded utterance by utterance with a beam search decoder. The
    per-utterance CER and WER are summed and divided by `len(dataset)`.
    `dataset.batch_size` is temporarily replaced by `eval_batch_size` and is
    always restored, even when an error is raised.

    Args:
        session: session of training model
        posteriors_ops: list of operations for computing posteriors
        beam_width (int): the size of beam
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character (character_capital_divide is not
            implemented)
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    Raises:
        NotImplementedError: if `label_type` is character_capital_divide
        TypeError: if `label_type` is anything else but character
    """
    assert isinstance(posteriors_ops, list), "posteriors_ops must be a list."

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    try:
        if label_type == 'character':
            idx2char = Idx2char(
                map_file_path='../metrics/mapping_files/character.txt')
            char2idx = Char2idx(
                map_file_path='../metrics/mapping_files/character.txt')
        elif label_type == 'character_capital_divide':
            raise NotImplementedError
        else:
            raise TypeError(
                'label_type must be "character", got %r' % (label_type,))

        # Define decoder
        decoder = BeamSearchDecoder(space_index=char2idx('_')[0],
                                    blank_index=model.num_classes - 1)

        cer_mean, wer_mean = 0, 0
        if progressbar:
            pbar = tqdm(total=len(dataset))
        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini batch
            if is_multitask:
                inputs, _, labels_true, inputs_seq_len, _ = data
            else:
                inputs, labels_true, inputs_seq_len, _ = data

            feed_dict = {}
            for i_device in range(len(posteriors_ops)):
                feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
                feed_dict[model.inputs_seq_len_pl_list[i_device]] = \
                    inputs_seq_len[i_device]
                feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

            posteriors_list = session.run(posteriors_ops, feed_dict=feed_dict)
            for i_device, posteriors_device in enumerate(posteriors_list):
                batch_size_device, max_time = inputs[i_device].shape[:2]

                posteriors = posteriors_device.reshape(
                    batch_size_device, max_time, model.num_classes)

                for i_batch in range(batch_size_device):

                    # Decode per utterance
                    labels_pred, scores = decoder(
                        probs=posteriors[i_batch:i_batch + 1],
                        seq_len=inputs_seq_len[i_device][i_batch:i_batch + 1],
                        beam_width=beam_width)

                    # Convert from list of index to string
                    if is_test:
                        str_true = labels_true[i_device][i_batch][0]
                        # NOTE: transcript is separated by space('_')
                    else:
                        str_true = idx2char(labels_true[i_device][i_batch],
                                            padded_value=dataset.padded_value)
                    str_pred = idx2char(labels_pred[0])

                    # Remove consecutive spaces
                    str_pred = re.sub(r'[_]+', '_', str_pred)

                    # Remove garbage labels
                    str_true = re.sub(r'[\']+', '', str_true)
                    str_pred = re.sub(r'[\']+', '', str_pred)

                    # Compute WER
                    wer_mean += compute_wer(ref=str_true.split('_'),
                                            hyp=str_pred.split('_'),
                                            normalize=True)

                    # Remove spaces
                    str_true = re.sub(r'[_]+', '', str_true)
                    str_pred = re.sub(r'[_]+', '', str_pred)

                    # Compute CER
                    cer_mean += compute_cer(str_pred=str_pred,
                                            str_true=str_true,
                                            normalize=True)

                    if progressbar:
                        pbar.update(1)

            if is_new_epoch:
                break

        cer_mean /= (len(dataset))
        wer_mean /= (len(dataset))
    finally:
        # Register original batch size (also when evaluation failed, so the
        # caller's dataset is never left with the evaluation batch size)
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original

    return cer_mean, wer_mean
def do_eval_cer(session,
                decode_ops,
                model,
                dataset,
                label_type,
                is_test=False,
                eval_batch_size=None,
                progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Character Error Rate.

    Every mini-batch of `dataset` is decoded with `decode_ops` (one
    operation per device). A device batch that raises `IndexError` is
    skipped as a whole: it does not contribute to the CER/WER sums and its
    size is removed from the divisor. `dataset.batch_size` is temporarily
    replaced by `eval_batch_size` and is always restored, even when an error
    is raised.

    Args:
        session: session of training model
        decode_ops: list of operations for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        cer_mean (float): An average of CER over the non-skipped data
        wer_mean (float): An average of WER over the non-skipped data
    Raises:
        TypeError: if `label_type` is not supported
    """
    assert isinstance(decode_ops, list), "decode_ops must be a list."

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    try:
        if label_type == 'character':
            idx2char = Idx2char(
                map_file_path='../metrics/mapping_files/character.txt')
        elif label_type == 'character_capital_divide':
            idx2char = Idx2char(
                map_file_path='../metrics/mapping_files/character_capital_divide.txt',
                capital_divide=True,
                space_mark='_')
        else:
            raise TypeError(
                'label_type must be "character" or '
                '"character_capital_divide", got %r' % (label_type,))

        cer_mean, wer_mean = 0, 0
        skip_data_num = 0
        if progressbar:
            pbar = tqdm(total=len(dataset))
        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini batch
            if is_multitask:
                inputs, _, labels_true, inputs_seq_len, _ = data
            else:
                inputs, labels_true, inputs_seq_len, _ = data

            feed_dict = {}
            for i_device in range(len(decode_ops)):
                feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
                feed_dict[model.inputs_seq_len_pl_list[i_device]] = \
                    inputs_seq_len[i_device]
                feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

            labels_pred_st_list = session.run(decode_ops, feed_dict=feed_dict)
            for i_device, labels_pred_st in enumerate(labels_pred_st_list):
                batch_size_device = len(inputs[i_device])
                try:
                    labels_pred = sparsetensor2list(labels_pred_st,
                                                    batch_size_device)

                    # Accumulate per device batch so that a batch failing
                    # half-way does not leave partial sums behind while
                    # being counted as skipped in the divisor
                    cer_batch, wer_batch = 0, 0
                    for i_batch in range(batch_size_device):

                        # Convert from list of index to string
                        if is_test:
                            str_true = labels_true[i_device][i_batch][0]
                            # NOTE: transcript is separated by space('_')
                        else:
                            str_true = idx2char(
                                labels_true[i_device][i_batch],
                                padded_value=dataset.padded_value)
                        str_pred = idx2char(labels_pred[i_batch])

                        # Remove consecutive spaces
                        str_pred = re.sub(r'[_]+', '_', str_pred)

                        # Remove garbage labels
                        str_true = re.sub(r'[\']+', '', str_true)
                        str_pred = re.sub(r'[\']+', '', str_pred)

                        # Compute WER
                        wer_batch += compute_wer(ref=str_true.split('_'),
                                                 hyp=str_pred.split('_'),
                                                 normalize=True)

                        # Remove spaces
                        str_true = re.sub(r'[_]+', '', str_true)
                        str_pred = re.sub(r'[_]+', '', str_pred)

                        # Compute CER
                        cer_batch += compute_cer(str_pred=str_pred,
                                                 str_true=str_true,
                                                 normalize=True)

                except IndexError:
                    print('skipped')
                    skip_data_num += batch_size_device
                    # TODO: Conduct decoding again with batch size 1
                else:
                    cer_mean += cer_batch
                    wer_mean += wer_batch

                # One update per device batch keeps the bar consistent for
                # processed and skipped batches alike
                if progressbar:
                    pbar.update(batch_size_device)

            if is_new_epoch:
                break

        cer_mean /= (len(dataset) - skip_data_num)
        wer_mean /= (len(dataset) - skip_data_num)
        # TODO: Fix this
    finally:
        # Register original batch size (also when evaluation failed)
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original

    return cer_mean, wer_mean
def do_eval_cer(session,
                decode_ops,
                model,
                dataset,
                label_type,
                train_data_size,
                is_test=False,
                eval_batch_size=None,
                progressbar=False,
                is_multitask=False,
                is_main=False):
    """Evaluate trained model by Character Error Rate.

    Every device batch is scored all-or-nothing: if decoding or scoring any
    utterance of the batch fails, 'skipped' is printed and the whole batch is
    left out of both the CER sum and the averaging denominator.

    Args:
        session: session of training model
        decode_ops (list): operations for decoding, one per device
        model: the model to evaluate
        dataset: An instance of `Dataset` class
        label_type (string): must contain 'kanji' or 'kana'
        train_data_size (string): train_subset or train_fullset. Only used
            to build the mapping file name when label_type contains 'kanji'
        is_test (bool, optional): set to True when evaluating by the test
            set. The reference is then taken from
            labels_true[i_device][i_batch][0] without index conversion
        eval_batch_size (int, optional): the batch size when evaluating the
            model. dataset.batch_size is restored before returning
        progressbar (bool, optional): if True, visualize progressbar
        is_multitask (bool, optional): not read by this function
        is_main (bool, optional): not read by this function
    Return:
        cer_mean (float): sum of the per-utterance CER divided by
            (len(dataset) - number of skipped utterances)
    Raises:
        TypeError: if label_type contains neither 'kanji' nor 'kana'
        ZeroDivisionError: if the number of skipped utterances equals
            len(dataset)
    """
    # NOTE: add multitask version

    assert isinstance(decode_ops, list), "decode_ops must be a list."

    # Resolve the mapping file before touching the dataset, so that an
    # invalid label_type cannot leave dataset.batch_size overwritten
    if 'kanji' in label_type:
        map_file_path = '../metrics/mapping_files/' + \
            label_type + '_' + train_data_size + '.txt'
    elif 'kana' in label_type:
        map_file_path = '../metrics/mapping_files/' + label_type + '.txt'
    else:
        raise TypeError('label_type must contain "kanji" or "kana".')

    idx2char = Idx2char(map_file_path=map_file_path)

    # Labels that are removed from both strings before scoring
    garbage_labels = re.compile(r'[_NZー・]+')

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    cer_mean = 0
    skip_data_num = 0
    pbar = None
    try:
        if progressbar:
            pbar = tqdm(total=len(dataset))

        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini batch
            inputs, labels_true, inputs_seq_len, _ = data

            feed_dict = {}
            for i_device in range(len(decode_ops)):
                feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
                feed_dict[model.inputs_seq_len_pl_list[i_device]] = \
                    inputs_seq_len[i_device]
                # Disable dropout in the evaluation
                feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

            labels_pred_st_list = session.run(decode_ops, feed_dict=feed_dict)
            for i_device, labels_pred_st in enumerate(labels_pred_st_list):
                batch_size_device = len(inputs[i_device])
                try:
                    labels_pred = sparsetensor2list(labels_pred_st,
                                                    batch_size_device)

                    # Accumulate locally so that a failure in the middle of
                    # the batch does not leave a partial sum in cer_mean
                    cer_batch = 0
                    for i_batch in range(batch_size_device):

                        # Convert from list of index to string
                        if is_test:
                            str_true = labels_true[i_device][i_batch][0]
                            # NOTE: transcript may be separated by space('_')
                        else:
                            str_true = idx2char(
                                labels_true[i_device][i_batch])
                        str_pred = idx2char(labels_pred[i_batch])

                        # Remove garbage labels
                        str_true = garbage_labels.sub('', str_true)
                        str_pred = garbage_labels.sub('', str_pred)

                        # Compute CER
                        cer_batch += compute_cer(str_pred=str_pred,
                                                 str_true=str_true,
                                                 normalize=True)
                except Exception:
                    # Best-effort skip is kept, but KeyboardInterrupt and
                    # SystemExit are no longer swallowed
                    print('skipped')
                    skip_data_num += batch_size_device
                    # TODO: Conduct decoding again with batch size 1
                else:
                    cer_mean += cer_batch

                if pbar is not None:
                    pbar.update(batch_size_device)

            if is_new_epoch:
                break

        cer_mean /= (len(dataset) - skip_data_num)
    finally:
        if pbar is not None:
            pbar.close()

        # Register original batch size (also when the evaluation failed)
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original

    return cer_mean
Ejemplo n.º 10
0
def do_eval_cer(session, decode_op, model, dataset, label_type,
                is_test=False, eval_batch_size=None, progressbar=False,
                is_multitask=False, is_jointctcatt=False):
    """Compute the mean Character / Word Error Rate over `dataset`.

    Only the first device (index 0 of every placeholder list and of every
    element of the mini-batch) is used.

    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide. Used to
            build the mapping file name
        is_test (bool, optional): set to True when evaluating by the test
            set. The reference is then labels_true[0][i_batch][0] as-is
        eval_batch_size (int, optional): batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, the mini-batch is unpacked
            as a 6-tuple instead of a 5-tuple
        is_jointctcatt (bool, optional): if True, the mini-batch is unpacked
            as a 6-tuple (same layout as the multitask case)
    Return:
        cer_mean (float): sum of per-utterance CER divided by len(dataset)
        wer_mean (float): sum of per-utterance WER divided by len(dataset)
    """
    original_batch_size = dataset.batch_size

    # Start from the beginning of the data
    dataset.reset()

    # Temporarily switch to the evaluation batch size
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    mapping_path = '../metrics/mapping_files/' + label_type + '.txt'
    idx2char = Idx2char(map_file_path=mapping_path)

    # Patterns applied to every utterance
    space_run = re.compile(r'[_]+')
    garbage = re.compile(r'[<>\'\":;!?,.-]+')

    cer_total = 0
    wer_total = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for batch, epoch_finished in dataset:

        # Multitask and joint CTC-attention batches carry one extra element
        if is_multitask or is_jointctcatt:
            inputs, labels_true, _, inputs_seq_len, labels_seq_len, _ = batch
        else:
            inputs, labels_true, inputs_seq_len, labels_seq_len, _ = batch

        # All keep probabilities are fixed to 1.0 for the evaluation
        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_encoder_pl_list[0]: 1.0,
            model.keep_prob_decoder_pl_list[0]: 1.0,
            model.keep_prob_embedding_pl_list[0]: 1.0
        }

        num_utterances = inputs[0].shape[0]

        labels_pred = session.run(decode_op, feed_dict=feed_dict)
        for i_batch in range(num_utterances):

            # Reference transcript
            if is_test:
                # NOTE: transcript is separated by space('_')
                reference = labels_true[0][i_batch][0]
            else:
                # Drop the first and the last label of the sequence
                # (presumably <SOS>/<EOS> -- confirm against the dataset)
                last = labels_seq_len[0][i_batch] - 1
                reference = idx2char(labels_true[0][i_batch][1:last],
                                     padded_value=dataset.padded_value)

            # Hypothesis: keep only the text before the first '>'
            # NOTE: Truncate by <EOS>
            hypothesis = idx2char(labels_pred[i_batch]).split('>')[0]

            # Collapse consecutive spaces in the hypothesis
            hypothesis = space_run.sub('_', hypothesis)

            # Strip garbage labels from both strings
            reference = garbage.sub('', reference)
            hypothesis = garbage.sub('', hypothesis)

            # Word level score ('_' is the word boundary)
            wer_total += compute_wer(hyp=hypothesis.split('_'),
                                     ref=reference.split('_'),
                                     normalize=True)

            # Character level score is computed without spaces
            hypothesis = space_run.sub('', hypothesis)
            reference = space_run.sub('', reference)
            cer_total += compute_cer(str_pred=hypothesis,
                                     str_true=reference,
                                     normalize=True)

            if progressbar:
                pbar.update(1)

        if epoch_finished:
            break

    cer_mean = cer_total / len(dataset)
    wer_mean = wer_total / len(dataset)

    # Put the original batch size back
    if eval_batch_size is not None:
        dataset.batch_size = original_batch_size

    return cer_mean, wer_mean
def do_eval_cer(session,
                decode_op,
                model,
                dataset,
                label_type,
                ss_type,
                is_test=False,
                eval_batch_size=None,
                progressbar=False):
    """Evaluate trained model by Character Error Rate.

    Only the first device (index 0) is used. Every mini-batch is scored
    all-or-nothing: if decoding or scoring any utterance of the batch fails,
    'skipped' is printed and the whole batch is left out of both the CER sum
    and the averaging denominator.

    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): kana
        ss_type (string): remove or insert_left or insert_both or
            insert_right. Together with label_type it selects the mapping
            file '<label_type>_<ss_type>.txt'
        is_test (bool, optional): set to True when evaluating by the test
            set. The reference is then labels_true[0][i_batch][0] as-is
        eval_batch_size (int, optional): the batch size when evaluating the
            model. dataset.batch_size is restored before returning
        progressbar (bool, optional): if True, visualize the progressbar
    Return:
        cer_mean (float): sum of the per-utterance CER divided by
            (len(dataset) - number of skipped utterances)
    Raises:
        ZeroDivisionError: if the number of skipped utterances equals
            len(dataset)
    """
    # Build the converter before touching the dataset, so that a failure
    # here cannot leave dataset.batch_size overwritten
    idx2char = Idx2char(map_file_path='../metrics/mapping_files/' +
                        label_type + '_' + ss_type + '.txt')

    # Labels that are removed from both strings before scoring
    garbage_labels = re.compile(r'[_NZLFBDlfbdー]+')

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    cer_mean = 0
    skip_data_num = 0
    pbar = None
    try:
        if progressbar:
            pbar = tqdm(total=len(dataset))

        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini batch
            inputs, labels_true, inputs_seq_len, _ = data

            feed_dict = {
                model.inputs_pl_list[0]: inputs[0],
                model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
                # Disable dropout in the evaluation
                model.keep_prob_pl_list[0]: 1.0
            }

            batch_size = inputs[0].shape[0]

            labels_pred_st = session.run(decode_op, feed_dict=feed_dict)

            try:
                labels_pred = sparsetensor2list(labels_pred_st, batch_size)

                # Accumulate locally so that a failure in the middle of the
                # batch does not leave a partial sum in cer_mean
                cer_batch = 0
                for i_batch in range(batch_size):

                    # Convert from list of index to string
                    if is_test:
                        str_true = labels_true[0][i_batch][0]
                        # NOTE: transcript is separated by space('_')
                    else:
                        str_true = idx2char(
                            labels_true[0][i_batch],
                            padded_value=dataset.padded_value)
                    str_pred = idx2char(labels_pred[i_batch])

                    # Remove garbage labels
                    str_true = garbage_labels.sub('', str_true)
                    str_pred = garbage_labels.sub('', str_pred)

                    # Compute CER
                    cer_batch += compute_cer(str_pred=str_pred,
                                             str_true=str_true,
                                             normalize=True)
            except Exception:
                # Best-effort skip is kept, but KeyboardInterrupt and
                # SystemExit are no longer swallowed
                print('skipped')
                skip_data_num += batch_size
                # TODO: Conduct decoding again with batch size 1
            else:
                cer_mean += cer_batch

            if pbar is not None:
                pbar.update(batch_size)

            if is_new_epoch:
                break

        cer_mean /= (len(dataset) - skip_data_num)
    finally:
        if pbar is not None:
            pbar.close()

        # Register original batch size (also when the evaluation failed)
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original

    return cer_mean
    def check(self, decoder_type):
        """Restore a CTC checkpoint and decode one generated batch.

        Builds a 2-layer BLSTM CTC graph, restores the latest checkpoint
        found in './', decodes a batch of two generated utterances with the
        requested decoder and prints the CER, the reference and the
        hypothesis of the first utterance.

        Args:
            decoder_type (string): tf_greedy or tf_beam_search (decoding is
                done by the TensorFlow op) or np_greedy or np_beam_search
                (decoding is done on the posteriors by `GreedyDecoder` /
                `BeamSearchDecoder`)
        Raises:
            ValueError: if no checkpoint is found in './'
        """
        # NOTE(review): any other decoder_type reaches the posterior branch
        # without assigning `labels_pred`, which ends in a NameError at the
        # compute_cer call -- confirm whether an explicit check is wanted.

        print('==================================================')
        print('  decoder_type: %s' % decoder_type)
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 2
            num_stack = 2
            inputs, labels, inputs_seq_len = generate_data(
                label_type='character',
                model='ctc',
                batch_size=batch_size,
                num_stack=num_stack,
                splice=1)
            # The second axis of `inputs` is used below as the time length
            # when the posteriors are reshaped
            max_time = inputs.shape[1]

            # Define model graph
            model = CTC(encoder_type='blstm',
                        input_size=inputs[0].shape[-1],
                        splice=1,
                        num_stack=num_stack,
                        num_units=256,
                        num_layers=2,
                        num_classes=27,
                        lstm_impl='LSTMBlockCell',
                        parameter_init=0.1,
                        clip_grad_norm=5.0,
                        clip_activation=50,
                        num_proj=256,
                        weight_decay=1e-6)

            # Define placeholders
            model.create_placeholders()

            # Add to the graph each operation
            # (only the first element of every placeholder list is used)
            _, logits = model.compute_loss(model.inputs_pl_list[0],
                                           model.labels_pl_list[0],
                                           model.inputs_seq_len_pl_list[0],
                                           model.keep_prob_pl_list[0])
            # Beam width 20 for both beam search decoders, 1 otherwise
            beam_width = 20 if 'beam_search' in decoder_type else 1
            decode_op = model.decoder(logits,
                                      model.inputs_seq_len_pl_list[0],
                                      beam_width=beam_width)
            ler_op = model.compute_ler(decode_op, model.labels_pl_list[0])
            posteriors_op = model.posteriors(logits, blank_prior=1)

            # Decoders working on the posteriors; `decoder` stays undefined
            # for the tf_* decoder types, which do not use it
            if decoder_type == 'np_greedy':
                # NOTE(review): the beam search decoder and the reshape of
                # the posteriors below both use model.num_classes as the
                # size of the class axis with the blank at num_classes - 1;
                # blank_index=model.num_classes looks inconsistent with
                # that -- confirm against GreedyDecoder.
                decoder = GreedyDecoder(blank_index=model.num_classes)
            elif decoder_type == 'np_beam_search':
                decoder = BeamSearchDecoder(space_index=26,
                                            blank_index=model.num_classes - 1)

            # Make feed dict
            feed_dict = {
                model.inputs_pl_list[0]: inputs,
                model.labels_pl_list[0]: list2sparsetensor(labels,
                                                           padded_value=-1),
                model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                # Keep probability 1.0 (no dropout)
                model.keep_prob_pl_list[0]: 1.0
            }

            # Create a saver for writing training checkpoints
            # (it is only used for restoring here)
            saver = tf.train.Saver()

            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state('./')

                # If check point exists
                if ckpt:
                    model_path = ckpt.model_checkpoint_path
                    saver.restore(sess, model_path)
                    print("Model restored: " + model_path)
                else:
                    raise ValueError('There are not any checkpoints.')

                if decoder_type in ['tf_greedy', 'tf_beam_search']:
                    # Decode
                    labels_pred_st = sess.run(decode_op, feed_dict=feed_dict)
                    labels_pred = sparsetensor2list(labels_pred_st,
                                                    batch_size=batch_size)

                    # Compute accuracy
                    # (value of ler_op for the fed batch)
                    cer = sess.run(ler_op, feed_dict=feed_dict)
                else:
                    # Compute CTC posteriors
                    probs = sess.run(posteriors_op, feed_dict=feed_dict)
                    probs = probs.reshape(-1, max_time, model.num_classes)

                    if decoder_type == 'np_greedy':
                        # Decode
                        labels_pred = decoder(probs=probs,
                                              seq_len=inputs_seq_len)

                    elif decoder_type == 'np_beam_search':
                        # Decode
                        # (`scores` is returned by the decoder but not used)
                        labels_pred, scores = decoder(probs=probs,
                                                      seq_len=inputs_seq_len,
                                                      beam_width=beam_width)

                    # Compute accuracy
                    # (first utterance of the batch only)
                    cer = compute_cer(str_pred=idx2alpha(labels_pred[0]),
                                      str_true=idx2alpha(labels[0]),
                                      normalize=True)

                # Visualize
                # (reference and hypothesis of the first utterance)
                print('CER: %.3f %%' % (cer * 100))
                print('Ref: %s' % idx2alpha(labels[0]))
                print('Hyp: %s' % idx2alpha(labels_pred[0]))
def do_eval_cer(session, decode_ops, model, dataset, label_type,
                train_data_size,
                is_test=False, eval_batch_size=None, progressbar=False,
                is_multitask=False, is_main=False):
    """Evaluate trained model by Character Error Rate.

    Every device batch is scored all-or-nothing: if decoding or scoring any
    utterance of the batch fails, 'skipped' is printed and the whole batch is
    left out of both the CER sum and the averaging denominator.

    Args:
        session: session of training model
        decode_ops (list): operations for decoding, one per device
        model: the model to evaluate
        dataset: An instance of `Dataset` class
        label_type (string): must contain 'kanji' or 'kana'
        train_data_size (string): train_subset or train_fullset. Only used
            to build the mapping file name when label_type contains 'kanji'
        is_test (bool, optional): set to True when evaluating by the test
            set. The reference is then taken from
            labels_true[i_device][i_batch][0] without index conversion
        eval_batch_size (int, optional): the batch size when evaluating the
            model. dataset.batch_size is restored before returning
        progressbar (bool, optional): if True, visualize progressbar
        is_multitask (bool, optional): not read by this function
        is_main (bool, optional): not read by this function
    Return:
        cer_mean (float): sum of the per-utterance CER divided by
            (len(dataset) - number of skipped utterances)
    Raises:
        TypeError: if label_type contains neither 'kanji' nor 'kana'
        ZeroDivisionError: if the number of skipped utterances equals
            len(dataset)
    """
    # NOTE: add multitask version

    assert isinstance(decode_ops, list), "decode_ops must be a list."

    # Resolve the mapping file before touching the dataset, so that an
    # invalid label_type cannot leave dataset.batch_size overwritten
    if 'kanji' in label_type:
        map_file_path = '../metrics/mapping_files/' + \
            label_type + '_' + train_data_size + '.txt'
    elif 'kana' in label_type:
        map_file_path = '../metrics/mapping_files/' + label_type + '.txt'
    else:
        raise TypeError('label_type must contain "kanji" or "kana".')

    idx2char = Idx2char(map_file_path=map_file_path)

    # Labels that are removed from both strings before scoring
    garbage_labels = re.compile(r'[_NZー・]+')

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    cer_mean = 0
    skip_data_num = 0
    pbar = None
    try:
        if progressbar:
            pbar = tqdm(total=len(dataset))

        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini batch
            inputs, labels_true, inputs_seq_len, _ = data

            feed_dict = {}
            for i_device in range(len(decode_ops)):
                feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
                feed_dict[model.inputs_seq_len_pl_list[i_device]] = \
                    inputs_seq_len[i_device]
                # Disable dropout in the evaluation
                feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

            labels_pred_st_list = session.run(decode_ops, feed_dict=feed_dict)
            for i_device, labels_pred_st in enumerate(labels_pred_st_list):
                batch_size_device = len(inputs[i_device])
                try:
                    labels_pred = sparsetensor2list(labels_pred_st,
                                                    batch_size_device)

                    # Accumulate locally so that a failure in the middle of
                    # the batch does not leave a partial sum in cer_mean
                    cer_batch = 0
                    for i_batch in range(batch_size_device):

                        # Convert from list of index to string
                        if is_test:
                            str_true = labels_true[i_device][i_batch][0]
                            # NOTE: transcript may be separated by space('_')
                        else:
                            str_true = idx2char(
                                labels_true[i_device][i_batch])
                        str_pred = idx2char(labels_pred[i_batch])

                        # Remove garbage labels
                        str_true = garbage_labels.sub('', str_true)
                        str_pred = garbage_labels.sub('', str_pred)

                        # Compute CER
                        cer_batch += compute_cer(str_pred=str_pred,
                                                 str_true=str_true,
                                                 normalize=True)
                except Exception:
                    # Best-effort skip is kept, but KeyboardInterrupt and
                    # SystemExit are no longer swallowed
                    print('skipped')
                    skip_data_num += batch_size_device
                    # TODO: Conduct decoding again with batch size 1
                else:
                    cer_mean += cer_batch

                if pbar is not None:
                    pbar.update(batch_size_device)

            if is_new_epoch:
                break

        cer_mean /= (len(dataset) - skip_data_num)
    finally:
        if pbar is not None:
            pbar.close()

        # Register original batch size (also when the evaluation failed)
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original

    return cer_mean
def do_eval_cer(session, decode_ops, model, dataset, label_type,
                is_test=False, eval_batch_size=None, progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Character Error Rate.

    Every device batch is scored all-or-nothing: if an IndexError is raised
    while decoding or scoring any utterance of the batch, 'skipped' is
    printed and the whole batch is left out of the CER / WER sums and of the
    averaging denominator.

    Args:
        session: session of training model
        decode_ops (list): operations for decoding, one per device
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
        is_test (bool, optional): set to True when evaluating by the test
            set. The reference is then taken from
            labels_true[i_device][i_batch][0] without index conversion
        eval_batch_size (int, optional): the batch size when evaluating the
            model. dataset.batch_size is restored before returning
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, the mini-batch is unpacked
            as a 5-tuple whose third element holds the labels to evaluate
    Return:
        cer_mean (float): sum of the per-utterance CER divided by
            (len(dataset) - number of skipped utterances)
        wer_mean (float): sum of the per-utterance WER divided by the same
            denominator
    Raises:
        TypeError: if label_type is neither 'character' nor
            'character_capital_divide'
        ZeroDivisionError: if the number of skipped utterances equals
            len(dataset)
    """
    assert isinstance(decode_ops, list), "decode_ops must be a list."

    # Build the converter before touching the dataset, so that an invalid
    # label_type cannot leave dataset.batch_size overwritten
    if label_type == 'character':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/character_capital_divide.txt',
            capital_divide=True,
            space_mark='_')
    else:
        raise TypeError(
            'label_type must be "character" or "character_capital_divide".')

    # '_' is the space mark; apostrophes are treated as garbage labels
    space_run = re.compile(r'[_]+')
    apostrophes = re.compile(r'[\']+')

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    cer_mean, wer_mean = 0, 0
    skip_data_num = 0
    pbar = None
    try:
        if progressbar:
            pbar = tqdm(total=len(dataset))

        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini batch
            if is_multitask:
                inputs, _, labels_true, inputs_seq_len, _ = data
            else:
                inputs, labels_true, inputs_seq_len, _ = data

            feed_dict = {}
            for i_device in range(len(decode_ops)):
                feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
                feed_dict[model.inputs_seq_len_pl_list[i_device]] = \
                    inputs_seq_len[i_device]
                # Disable dropout in the evaluation
                feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

            labels_pred_st_list = session.run(decode_ops, feed_dict=feed_dict)
            for i_device, labels_pred_st in enumerate(labels_pred_st_list):
                batch_size_device = len(inputs[i_device])
                try:
                    labels_pred = sparsetensor2list(labels_pred_st,
                                                    batch_size_device)

                    # Accumulate locally so that a failure in the middle of
                    # the batch does not leave partial sums in the totals
                    cer_batch, wer_batch = 0, 0
                    for i_batch in range(batch_size_device):

                        # Convert from list of index to string
                        if is_test:
                            str_true = labels_true[i_device][i_batch][0]
                            # NOTE: transcript is separated by space('_')
                        else:
                            str_true = idx2char(
                                labels_true[i_device][i_batch],
                                padded_value=dataset.padded_value)
                        str_pred = idx2char(labels_pred[i_batch])

                        # Remove consecutive spaces
                        str_pred = space_run.sub('_', str_pred)

                        # Remove garbage labels
                        str_true = apostrophes.sub('', str_true)
                        str_pred = apostrophes.sub('', str_pred)

                        # Compute WER
                        wer_batch += compute_wer(ref=str_true.split('_'),
                                                 hyp=str_pred.split('_'),
                                                 normalize=True)

                        # Remove spaces
                        str_true = space_run.sub('', str_true)
                        str_pred = space_run.sub('', str_pred)

                        # Compute CER
                        cer_batch += compute_cer(str_pred=str_pred,
                                                 str_true=str_true,
                                                 normalize=True)
                except IndexError:
                    print('skipped')
                    skip_data_num += batch_size_device
                    # TODO: Conduct decoding again with batch size 1
                else:
                    cer_mean += cer_batch
                    wer_mean += wer_batch

                if pbar is not None:
                    pbar.update(batch_size_device)

            if is_new_epoch:
                break

        cer_mean /= (len(dataset) - skip_data_num)
        wer_mean /= (len(dataset) - skip_data_num)
        # TODO: Fix this
    finally:
        if pbar is not None:
            pbar.close()

        # Register original batch size (also when the evaluation failed)
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original

    return cer_mean, wer_mean
def do_eval_cer2(session, posteriors_ops, beam_width, model, dataset,
                 label_type, is_test=False, eval_batch_size=None,
                 progressbar=False, is_multitask=False):
    """Compute the mean CER / WER by beam search over the posteriors.

    The posteriors of every device are fetched from the session and each
    utterance is decoded on its own by a `BeamSearchDecoder`.

    Args:
        session: session of training model
        posteriors_ops (list): operations for computing posteriors, one per
            device
        beam_width (int): passed to the decoder for every utterance
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character (character_capital_divide is not
            implemented)
        is_test (bool, optional): set to True when evaluating by the test
            set. The reference is then taken from
            labels_true[i_device][i_batch][0] without index conversion
        eval_batch_size (int, optional): the batch size when evaluating the
            model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, the mini-batch is unpacked
            as a 5-tuple whose third element holds the labels to evaluate
    Return:
        cer_mean (float): sum of per-utterance CER divided by len(dataset)
        wer_mean (float): sum of per-utterance WER divided by len(dataset)
    Raises:
        NotImplementedError: if label_type is 'character_capital_divide'
        TypeError: for any other label_type than 'character'
    """
    assert isinstance(posteriors_ops, list), "posteriors_ops must be a list."

    original_batch_size = dataset.batch_size

    # Start from the beginning of the data
    dataset.reset()

    # Temporarily switch to the evaluation batch size
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Only the plain character label set is supported
    if label_type == 'character_capital_divide':
        raise NotImplementedError
    if label_type != 'character':
        raise TypeError
    idx2char = Idx2char(
        map_file_path='../metrics/mapping_files/character.txt')
    char2idx = Char2idx(
        map_file_path='../metrics/mapping_files/character.txt')

    # The blank label is the last class, '_' is the space mark
    decoder = BeamSearchDecoder(space_index=char2idx('_')[0],
                                blank_index=model.num_classes - 1)

    # Patterns applied to every utterance
    space_run = re.compile(r'[_]+')
    apostrophes = re.compile(r'[\']+')

    cer_total = 0
    wer_total = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for batch, epoch_finished in dataset:

        # Multitask batches carry one extra label element
        if is_multitask:
            inputs, _, labels_true, inputs_seq_len, _ = batch
        else:
            inputs, labels_true, inputs_seq_len, _ = batch

        # Feed every device; keep probability 1.0 for the evaluation
        feed_dict = {}
        for i_device in range(len(posteriors_ops)):
            feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
            feed_dict[model.inputs_seq_len_pl_list[i_device]] = \
                inputs_seq_len[i_device]
            feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

        posteriors_list = session.run(posteriors_ops, feed_dict=feed_dict)
        for i_device, device_posteriors in enumerate(posteriors_list):
            batch_size_device, max_time = inputs[i_device].shape[:2]

            posteriors = device_posteriors.reshape(
                batch_size_device, max_time, model.num_classes)
            seq_lens = inputs_seq_len[i_device]

            for i_batch in range(batch_size_device):
                one = slice(i_batch, i_batch + 1)

                # Decode a single utterance (slices keep the batch axis)
                labels_pred, scores = decoder(probs=posteriors[one],
                                              seq_len=seq_lens[one],
                                              beam_width=beam_width)

                # Reference transcript
                if is_test:
                    # NOTE: transcript is separated by space('_')
                    reference = labels_true[i_device][i_batch][0]
                else:
                    reference = idx2char(labels_true[i_device][i_batch],
                                         padded_value=dataset.padded_value)
                hypothesis = idx2char(labels_pred[0])

                # Collapse consecutive spaces in the hypothesis
                hypothesis = space_run.sub('_', hypothesis)

                # Strip garbage labels from both strings
                reference = apostrophes.sub('', reference)
                hypothesis = apostrophes.sub('', hypothesis)

                # Word level score ('_' is the word boundary)
                wer_total += compute_wer(ref=reference.split('_'),
                                         hyp=hypothesis.split('_'),
                                         normalize=True)

                # Character level score is computed without spaces
                reference = space_run.sub('', reference)
                hypothesis = space_run.sub('', hypothesis)
                cer_total += compute_cer(str_pred=hypothesis,
                                         str_true=reference,
                                         normalize=True)

                if progressbar:
                    pbar.update(1)

        if epoch_finished:
            break

    cer_mean = cer_total / (len(dataset))
    wer_mean = wer_total / (len(dataset))

    # Put the original batch size back
    if eval_batch_size is not None:
        dataset.batch_size = original_batch_size

    return cer_mean, wer_mean