def plot(model, dataset, eval_batch_size=None, save_path=None,
         space_index=None):
    """Plot CTC posteriors together with the greedy hypothesis.

    Iterates over `dataset` until the end of the current epoch and saves one
    figure per utterance to `save_path/<speaker>/<book>/<input_name>.png`,
    where speaker and book are the first two '-'-separated fields of the
    utterance name.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string): path to save figures of CTC posteriors.
            An existing directory at this path is removed and recreated.
        space_index (int, optional): index of the space label; passed
            through unchanged to `plot_ctc_probs`
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Clean directory
    # NOTE: `os.path.isdir(None)` raises TypeError, so check for None first.
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    vocab_file_path = '../metrics/vocab_files/' + \
        dataset.label_type + '_' + dataset.data_size + '.txt'
    if dataset.label_type == 'character':
        map_fn = Idx2char(vocab_file_path)
    elif dataset.label_type == 'character_capital_divide':
        map_fn = Idx2char(vocab_file_path, capital_divide=True)
    else:
        map_fn = Idx2word(vocab_file_path)

    for batch, is_new_epoch in dataset:

        # Get CTC probs
        probs = model.posteriors(batch['xs'], batch['x_lens'], temperature=1)
        # NOTE: probs: '[B, T, num_classes]'

        # Decode greedily (beam width 1); the second return value is unused.
        best_hyps, _ = model.decode(batch['xs'], batch['x_lens'], beam_width=1)

        # Visualize
        for b in range(len(batch['xs'])):

            # Convert from list of index to string
            str_pred = map_fn(best_hyps[b])

            speaker, book = batch['input_names'][b].split('-')[:2]
            x_len = batch['x_lens'][b]
            plot_ctc_probs(
                probs[b, :x_len, :],
                frame_num=x_len,
                num_stack=dataset.num_stack,
                space_index=space_index,
                str_pred=str_pred,
                save_path=mkdir_join(save_path, speaker, book,
                                     batch['input_names'][b] + '.png'))

        # Stop after a single pass over the data
        if is_new_epoch:
            break
def plot(model, dataset, eval_batch_size, save_path=None):
    """Plot CTC posteriors of the main and sub tasks of a hierarchical model.

    For every utterance until the end of the current epoch, the greedy
    (beam width 1) hypotheses of both tasks are decoded and one figure is
    saved to `save_path/<speaker>/<input_name>.png`, where speaker is the
    first '_'-separated field of the utterance name.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model.
            If None, the dataset's current batch size is kept.
        save_path (string): path to save figures of CTC posteriors.
            An existing directory at this path is removed and recreated.
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Clean directory
    # NOTE: `os.path.isdir(None)` raises TypeError, so check for None first.
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    idx2word = Idx2word(dataset.vocab_file_path)
    # The sub task is character-level, so its hypotheses are mapped with the
    # sub-task vocabulary. Datasets without `vocab_file_path_sub` fall back
    # to the main vocabulary (the previous behavior).
    idx2char = Idx2char(
        getattr(dataset, 'vocab_file_path_sub', dataset.vocab_file_path),
        capital_divide=dataset.label_type_sub == 'character_capital_divide')

    for batch, is_new_epoch in dataset:

        # Get CTC probs
        probs = model.posteriors(batch['xs'], batch['x_lens'], temperature=1)
        probs_sub = model.posteriors(batch['xs'],
                                     batch['x_lens'],
                                     is_sub_task=True,
                                     temperature=1)
        # NOTE: probs: '[B, T, num_classes]'
        # NOTE: probs_sub: '[B, T, num_classes_sub]'

        # Decode
        best_hyps = model.decode(batch['xs'], batch['x_lens'], beam_width=1)
        best_hyps_sub = model.decode(batch['xs'],
                                     batch['x_lens'],
                                     beam_width=1,
                                     is_sub_task=True)

        # Visualize
        for b in range(len(batch['xs'])):

            # Convert from list of index to string
            str_hyp = idx2word(best_hyps[b])
            str_hyp_sub = idx2char(best_hyps_sub[b])

            speaker = batch['input_names'][b].split('_')[0]
            x_len = batch['x_lens'][b]
            plot_hierarchical_ctc_probs(probs[b, :x_len, :],
                                        probs_sub[b, :x_len, :],
                                        frame_num=x_len,
                                        num_stack=dataset.num_stack,
                                        str_hyp=str_hyp,
                                        str_hyp_sub=str_hyp_sub,
                                        save_path=mkdir_join(
                                            save_path, speaker,
                                            batch['input_names'][b] + '.png'))

        # Stop after a single pass over the data
        if is_new_epoch:
            break
# Example #3 (0)
def plot(model, dataset, beam_width,
         eval_batch_size=None, save_path=None):
    """Visualize attention weights of attention-based model.

    Each utterance is decoded with beam search and one figure per utterance
    is saved to `save_path/<speaker>/<input_name>.png` (speaker is the first
    two '_'-separated fields of the name) until the end of the current epoch.

    Args:
        model: model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width: (int): the size of beam
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string, optional): path to save attention weights plotting.
            An existing directory at this path is removed and recreated.
    """
    # Set batch size in the evaluation
    # NOTE: this argument used to be accepted but silently ignored.
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Clean directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    if 'char' in dataset.label_type:
        map_fn = Idx2char(dataset.vocab_file_path,
                          capital_divide=dataset.label_type == 'character_capital_divide',
                          return_list=True)
        max_decode_len = MAX_DECODE_LEN_CHAR
    else:
        map_fn = Idx2word(dataset.vocab_file_path, return_list=True)
        max_decode_len = MAX_DECODE_LEN_WORD

    for batch, is_new_epoch in dataset:

        # Decode
        best_hyps, aw, perm_idx = model.attention_weights(
            batch['xs'], batch['x_lens'],
            beam_width=beam_width,
            max_decode_len=max_decode_len)

        # Align the references with the order returned by the model
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is seperated by space('_')
            else:
                # Convert from list of index to string
                str_ref = map_fn(ys[b][:y_lens[b]])

            token_list = map_fn(best_hyps[b])

            speaker = '_'.join(batch['input_names'][b].split('_')[:2])
            plot_attention_weights(
                aw[b, :len(token_list), :batch['x_lens'][b]],
                label_list=token_list,
                spectrogram=batch['xs'][b, :, :dataset.input_freq],
                str_ref=str_ref,
                save_path=mkdir_join(save_path, speaker,
                                     batch['input_names'][b] + '.png'),
                figsize=(20, 8))

        # Stop after a single pass over the data
        if is_new_epoch:
            break
    def check_loading(self, label_type, data_type='dev',
                      shuffle=False, sort_utt=False, sort_stop_epoch=None,
                      frame_stacking=False, splice=1):
        """Load mini-batches and print the attention and CTC references.

        For every mini-batch, the first utterance's attention and CTC labels
        are mapped back to strings ('_' replaced by a space) and printed.

        Args:
            label_type (string): 'character' and 'character_capital_divide'
                are mapped with `Idx2char`, any other value with `Idx2phone`
            data_type (string, optional): data split to load
            shuffle (bool, optional): passed to `Dataset`
            sort_utt (bool, optional): passed to `Dataset`
            sort_stop_epoch (int, optional): passed to `Dataset`
            frame_stacking (bool, optional): if True, passes
                num_stack=num_skip=3 to `Dataset` (1 otherwise)
            splice (int, optional): passed to `Dataset`
        """
        print('========================================')
        print('  label_type: %s' % label_type)
        print('  data_type: %s' % data_type)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print('========================================')

        num_stack = 3 if frame_stacking else 1
        num_skip = 3 if frame_stacking else 1
        dataset = Dataset(
            data_type=data_type, label_type=label_type,
            batch_size=64, eos_index=1, max_epoch=2, splice=splice,
            num_stack=num_stack, num_skip=num_skip,
            shuffle=shuffle,
            sort_utt=sort_utt, sort_stop_epoch=sort_stop_epoch,
            progressbar=True)

        print('=> Loading mini-batch...')
        if label_type in ['character', 'character_capital_divide']:
            map_fn_ctc = Idx2char(
                map_file_path='../../metrics/mapping_files/ctc/' + label_type + '.txt')
            map_fn_att = Idx2char(
                map_file_path='../../metrics/mapping_files/attention/' + label_type + '.txt')
        else:
            map_fn_ctc = Idx2phone(
                map_file_path='../../metrics/mapping_files/ctc/' + label_type + '.txt')
            map_fn_att = Idx2phone(
                map_file_path='../../metrics/mapping_files/attention/' + label_type + '.txt')

        for data, is_new_epoch in dataset:
            inputs, att_labels, ctc_labels, inputs_seq_len, att_labels_seq_len, input_names = data

            att_str_true = map_fn_att(att_labels[0][0: att_labels_seq_len[0]])
            ctc_str_true = map_fn_ctc(ctc_labels[0])
            att_str_true = re.sub(r'_', ' ', att_str_true)
            ctc_str_true = re.sub(r'_', ' ', ctc_str_true)
            print('----- %s ----- (epoch: %.3f)' %
                  (input_names[0], dataset.epoch_detail))
            print(att_str_true)
            print(ctc_str_true)
# Example #5 (0)
    def check(self, ss_type, data_type='dev',
              shuffle=False, sort_utt=False, sort_stop_epoch=None,
              frame_stacking=False, splice=1):
        """Load kana mini-batches and print the first reference of each.

        For the training split, every input of the first device's batch is
        checked to be at least as long as its (unpadded) label sequence.
        Iteration stops once `dataset.epoch_detail` reaches 0.2.

        Args:
            ss_type (string): selects the `kana_<ss_type>.txt` mapping file
            data_type (string, optional): data split to load
            shuffle (bool, optional): passed to `Dataset`
            sort_utt (bool, optional): passed to `Dataset`
            sort_stop_epoch (int, optional): passed to `Dataset`
            frame_stacking (bool, optional): if True, passes
                num_stack=num_skip=3 to `Dataset` (1 otherwise)
            splice (int, optional): passed to `Dataset`

        Raises:
            ValueError: if a training input is shorter than its labels
        """
        banner = '========================================'
        print(banner)
        for fmt, value in (('  label_type: %s', 'kana'),
                           ('  ss_type: %s', ss_type),
                           ('  data_type: %s', data_type),
                           ('  shuffle: %s', str(shuffle)),
                           ('  sort_utt: %s', str(sort_utt)),
                           ('  sort_stop_epoch: %s', str(sort_stop_epoch)),
                           ('  frame_stacking: %s', str(frame_stacking)),
                           ('  splice: %d', splice)):
            print(fmt % value)
        print(banner)

        map_file_path = '../../metrics/mapping_files/kana_' + ss_type + '.txt'

        # Stacking and skipping always use the same factor
        stack_factor = 3 if frame_stacking else 1
        dataset = Dataset(
            data_type=data_type, label_type='kana', ss_type=ss_type,
            batch_size=64, map_file_path=map_file_path,
            max_epoch=2, splice=splice,
            num_stack=stack_factor, num_skip=stack_factor,
            shuffle=shuffle,
            sort_utt=sort_utt, sort_stop_epoch=sort_stop_epoch,
            progressbar=True)

        print('=> Loading mini-batch...')
        idx2char = Idx2char(map_file_path=map_file_path)

        for data, _ in dataset:
            inputs, labels, _, labels_seq_len, input_names = data

            if data_type == 'train':
                for frames, label_seq in zip(inputs[0], labels[0]):
                    pad_positions = np.where(
                        label_seq == dataset.padded_value)[0]
                    # Effective label length: up to the first padding
                    # position, or the whole sequence if unpadded
                    if len(pad_positions) > 0:
                        label_len = pad_positions[0]
                    else:
                        label_len = len(label_seq)
                    if frames.shape[0] < label_len:
                        raise ValueError(
                            'input length must be longer than label length.')

            if data_type == 'test':
                # Test references are used as-is (presumably raw transcripts)
                str_true = labels[0][0][0]
            else:
                str_true = idx2char(labels[0][0][:labels_seq_len[0][0]])

            print('----- %s ----- (epoch: %.3f)' %
                  (input_names[0][0], dataset.epoch_detail))
            print(inputs[0][0].shape)
            print(str_true)

            if dataset.epoch_detail >= 0.2:
                break
    def check(self,
              label_type_main,
              data_type='dev',
              shuffle=False,
              sort_utt=False,
              sort_stop_epoch=None,
              frame_stacking=False,
              splice=1):
        """Load mini-batches with main-task and phone61 labels and print them.

        For every mini-batch, the first utterance's main-task (character)
        and sub-task (phone61) references are printed.

        Args:
            label_type_main (string): label type of the main task; also
                selects the character mapping file
            data_type (string, optional): data split to load
            shuffle (bool, optional): passed to `Dataset`
            sort_utt (bool, optional): passed to `Dataset`
            sort_stop_epoch (int, optional): passed to `Dataset`
            frame_stacking (bool, optional): if True, passes
                num_stack=num_skip=3 to `Dataset` (1 otherwise)
            splice (int, optional): passed to `Dataset`
        """
        banner = '========================================'
        print(banner)
        for fmt, value in (('  label_type_main: %s', label_type_main),
                           ('  data_type: %s', data_type),
                           ('  shuffle: %s', str(shuffle)),
                           ('  sort_utt: %s', str(sort_utt)),
                           ('  sort_stop_epoch: %s', str(sort_stop_epoch)),
                           ('  frame_stacking: %s', str(frame_stacking)),
                           ('  splice: %d', splice)):
            print(fmt % value)
        print(banner)

        # Stacking and skipping always use the same factor
        stack_factor = 3 if frame_stacking else 1
        dataset = Dataset(data_type=data_type,
                          label_type_main=label_type_main,
                          label_type_sub='phone61',
                          batch_size=64,
                          max_epoch=2,
                          splice=splice,
                          num_stack=stack_factor,
                          num_skip=stack_factor,
                          shuffle=shuffle,
                          sort_utt=sort_utt,
                          sort_stop_epoch=sort_stop_epoch,
                          progressbar=True)

        print('=> Loading mini-batch...')
        char_map_path = ('../../metrics/mapping_files/' +
                         label_type_main + '.txt')
        idx2char = Idx2char(map_file_path=char_map_path)
        idx2phone = Idx2phone(
            map_file_path='../../metrics/mapping_files/phone61.txt')

        for data, _ in dataset:
            _, labels_char, labels_phone, _, input_names = data

            if data_type == 'test':
                # Test references are used as-is (presumably raw transcripts)
                str_true_char = labels_char[0][0][0]
                str_true_phone = labels_phone[0][0][0]
            else:
                str_true_char = idx2char(labels_char[0][0])
                str_true_phone = idx2phone(labels_phone[0][0])

            print('----- %s ----- (epoch: %.3f)' %
                  (input_names[0][0], dataset.epoch_detail))
            print(str_true_char)
            print(str_true_phone)
# Example #7 (0)
def plot(model, dataset, eval_batch_size, beam_width, beam_width_sub,
         length_penalty, save_path=None):
    """Visualize attention weights of a hierarchical attention-based model.

    Both the main (word) and sub (character) tasks are decoded with beam
    search. One figure per utterance is saved to
    `save_path/<speaker>/<input_name>.png` (speaker is the first
    '_'-separated field of the name) until the end of the current epoch.

    Args:
        model: model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model.
            If None, the dataset's current batch size is kept.
        beam_width: (int): the size of beam in the main task
        beam_width_sub: (int): the size of beam in the sub task
        length_penalty (float): length penalty passed to `model.decode`
            for both tasks
        save_path (string, optional): path to save attention weights plotting.
            An existing directory at this path is removed and recreated.
    """
    # Set batch size in the evaluation
    # NOTE: this argument used to be accepted but silently ignored.
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Clean directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    map_fn_main = Idx2word(dataset.vocab_file_path, return_list=True)
    map_fn_sub = Idx2char(dataset.vocab_file_path_sub, return_list=True)

    for batch, is_new_epoch in dataset:

        # Decode
        # NOTE: `length_penalty` used to be accepted but never forwarded.
        best_hyps, aw, perm_idx = model.decode(
            batch['xs'], batch['x_lens'],
            beam_width=beam_width,
            max_decode_len=MAX_DECODE_LEN_WORD,
            length_penalty=length_penalty)
        best_hyps_sub, aw_sub, _ = model.decode(
            batch['xs'], batch['x_lens'],
            beam_width=beam_width_sub,
            max_decode_len=MAX_DECODE_LEN_CHAR,
            length_penalty=length_penalty,
            task_index=1)

        for b in range(len(batch['xs'])):

            word_list = map_fn_main(best_hyps[b])
            char_list = map_fn_sub(best_hyps_sub[b])

            speaker = batch['input_names'][b].split('_')[0]

            plot_hierarchical_attention_weights(
                aw[b][:len(word_list), :batch['x_lens'][b]],
                aw_sub[b][:len(char_list), :batch['x_lens'][b]],
                label_list=word_list,
                label_list_sub=char_list,
                spectrogram=batch['xs'][b, :, :dataset.input_freq],
                save_path=mkdir_join(save_path, speaker,
                                     batch['input_names'][b] + '.png'),
                figsize=(40, 8)
            )

        # Stop after a single pass over the data
        if is_new_epoch:
            break
    def check(self, label_type, data_type='dev',
              shuffle=False, sort_utt=False, sort_stop_epoch=None,
              frame_stacking=False, splice=1):
        """Load mini-batches for one epoch and print the first reference.

        Args:
            label_type (string): 'character' and 'character_capital_divide'
                are mapped with `Idx2char`, any other value with `Idx2phone`;
                also selects the mapping file
            data_type (string, optional): data split to load
            shuffle (bool, optional): passed to `Dataset`
            sort_utt (bool, optional): passed to `Dataset`
            sort_stop_epoch (int, optional): passed to `Dataset`
            frame_stacking (bool, optional): if True, passes
                num_stack=num_skip=3 to `Dataset` (1 otherwise)
            splice (int, optional): passed to `Dataset`
        """
        banner = '========================================'
        print(banner)
        for fmt, value in (('  label_type: %s', label_type),
                           ('  data_type: %s', data_type),
                           ('  shuffle: %s', str(shuffle)),
                           ('  sort_utt: %s', str(sort_utt)),
                           ('  sort_stop_epoch: %s', str(sort_stop_epoch)),
                           ('  frame_stacking: %s', str(frame_stacking)),
                           ('  splice: %d', splice)):
            print(fmt % value)
        print(banner)

        map_file_path = '../../metrics/mapping_files/' + label_type + '.txt'

        # Stacking and skipping always use the same factor
        stack_factor = 3 if frame_stacking else 1
        dataset = Dataset(
            data_type=data_type, label_type=label_type,
            batch_size=64, map_file_path=map_file_path,
            max_epoch=1, splice=splice,
            num_stack=stack_factor, num_skip=stack_factor,
            shuffle=shuffle,
            sort_utt=sort_utt, sort_stop_epoch=sort_stop_epoch,
            progressbar=True)

        print('=> Loading mini-batch...')
        is_char = label_type in ['character', 'character_capital_divide']
        mapper_cls = Idx2char if is_char else Idx2phone
        map_fn = mapper_cls(map_file_path=map_file_path)

        for data, _ in dataset:
            inputs, labels, _, labels_seq_len, input_names = data

            if data_type == 'test':
                # Test references are used as-is (presumably raw transcripts)
                str_true = labels[0][0][0]
            else:
                str_true = map_fn(labels[0][0][:labels_seq_len[0][0]])

            print('----- %s ----- (epoch: %.3f)' %
                  (input_names[0][0], dataset.epoch_detail))
            print(inputs[0][0].shape)
            print(str_true)
def do_eval_cer(save_paths,
                dataset,
                data_type,
                label_type,
                num_classes,
                beam_width,
                temperature_infer,
                is_test=False,
                progressbar=False):
    """Evaluate an ensemble of trained models by Character Error Rate.

    The saved per-utterance posteriors of every model in `save_paths` are
    averaged and decoded with a CTC beam search decoder.

    Args:
        save_paths (list): directories of the models to ensemble. Posteriors
            are loaded from `<save_path>/temp<temperature_infer>/<data_type>/
            probs_utt/<speaker>/<input_name>.npy`
        dataset: An instance of a `Dataset` class
        data_type (string): data split name used in the posterior path
        label_type (string): character (the only supported value)
        num_classes (int): the number of classes; index `num_classes - 1`
            is passed to the decoder as the blank index
        beam_width (int): the size of beam
        temperature_infer (int): temperature in the inference stage; selects
            which saved posteriors are loaded
        is_test (bool, optional): set to True when evaluating by the test set
        progressbar (bool, optional): if True, visualize the progressbar
    Returns:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    Raises:
        TypeError: if `label_type` is not 'character'
    """
    if label_type != 'character':
        raise TypeError(
            'label_type must be "character", got %r' % (label_type,))
    idx2char = Idx2char(
        map_file_path='../metrics/mapping_files/character.txt')
    char2idx = Char2idx(
        map_file_path='../metrics/mapping_files/character.txt')

    # Define decoder
    decoder = BeamSearchDecoder(space_index=char2idx(str_char='_')[0],
                                blank_index=num_classes - 1)

    ##################################################
    # Compute mean probabilities
    ##################################################
    if progressbar:
        pbar = tqdm(total=len(dataset))
    cer_mean, wer_mean = 0, 0
    for data, is_new_epoch in dataset:

        # Create feed dictionary for next mini batch
        inputs, labels_true, inputs_seq_len, input_names = data

        batch_size = inputs[0].shape[0]
        for i_batch in range(batch_size):
            input_name = input_names[0][i_batch]
            speaker = input_name.split('-')[0]

            # Sum over the posteriors of all models
            probs_ensemble = None
            for save_path in save_paths:
                prob_save_path = join(save_path,
                                      'temp' + str(temperature_infer),
                                      data_type, 'probs_utt', speaker,
                                      input_name + '.npy')
                probs_model_i = np.load(prob_save_path)
                # NOTE: probs_model_i: `[T, num_classes]`
                if probs_ensemble is None:
                    probs_ensemble = probs_model_i
                else:
                    probs_ensemble += probs_model_i

            # Compute mean posteriors
            probs_ensemble /= len(save_paths)

            # Decode per utterance
            labels_pred, scores = decoder(
                probs=probs_ensemble[np.newaxis, :, :],
                seq_len=inputs_seq_len[0][i_batch:i_batch + 1],
                beam_width=beam_width)

            # Convert from list of index to string
            if is_test:
                str_true = labels_true[0][i_batch][0]
                # NOTE: transcript is seperated by space('_')
            else:
                str_true = idx2char(labels_true[0][i_batch],
                                    padded_value=dataset.padded_value)
            str_pred = idx2char(labels_pred[0])

            # Remove consecutive spaces
            str_pred = re.sub(r'[_]+', '_', str_pred)

            # Remove garbage labels
            str_true = re.sub(r'[\']+', '', str_true)
            str_pred = re.sub(r'[\']+', '', str_pred)

            # Compute WER
            # NOTE: the reference must be passed as `ref` and the prediction
            # as `hyp`; they used to be swapped, which (presumably) made the
            # normalization use the hypothesis length.
            wer_mean += compute_wer(ref=str_true.split('_'),
                                    hyp=str_pred.split('_'),
                                    normalize=True)

            # Remove spaces
            str_true = re.sub(r'[_]+', '', str_true)
            str_pred = re.sub(r'[_]+', '', str_pred)

            # Compute CER
            cer_mean += compute_cer(str_pred=str_pred,
                                    str_true=str_true,
                                    normalize=True)

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    cer_mean /= (len(dataset))
    wer_mean /= (len(dataset))
    # TODO: Fix this

    return cer_mean, wer_mean
# Example #10 (0)
def do_eval_cer(session,
                decode_ops,
                model,
                dataset,
                label_type,
                train_data_size,
                is_test=False,
                eval_batch_size=None,
                progressbar=False):
    """Evaluate trained model by Character Error Rate.

    The dataset counter is reset before evaluation. The original batch size
    is restored afterwards, even if evaluation raises.

    Args:
        session: session of training model
        decode_ops (list): operations for decoding, one per device
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): kanji, kana, kanji_divide or kana_divide
        train_data_size (string): train_subset or train_fullset
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
    Return:
        cer_mean (float): An average of CER
    Raises:
        TypeError: if `label_type` contains neither 'kanji' nor 'kana'
    """
    assert isinstance(decode_ops, list), "decode_ops must be a list."

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    pbar = None
    try:
        if 'kanji' in label_type:
            map_file_path = '../metrics/mapping_files/' + \
                label_type + '_' + train_data_size + '.txt'
        elif 'kana' in label_type:
            map_file_path = '../metrics/mapping_files/' + label_type + '.txt'
        else:
            raise TypeError(
                'label_type must contain "kanji" or "kana", got %r'
                % (label_type,))

        idx2char = Idx2char(map_file_path=map_file_path)

        cer_mean = 0
        if progressbar:
            pbar = tqdm(total=len(dataset))
        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini batch
            inputs, labels_true, inputs_seq_len, labels_seq_len, _ = data

            feed_dict = {}
            for i_device in range(len(decode_ops)):
                feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
                feed_dict[model.inputs_seq_len_pl_list[i_device]] = \
                    inputs_seq_len[i_device]
                # Disable dropout during evaluation
                feed_dict[model.keep_prob_encoder_pl_list[i_device]] = 1.0
                feed_dict[model.keep_prob_decoder_pl_list[i_device]] = 1.0
                feed_dict[model.keep_prob_embedding_pl_list[i_device]] = 1.0

            labels_pred_list = session.run(decode_ops, feed_dict=feed_dict)
            for i_device in range(len(labels_pred_list)):
                for i_batch in range(len(inputs[i_device])):

                    # Convert from list of index to string
                    if is_test:
                        str_true = labels_true[i_device][i_batch][0]
                        # NOTE: transcript is seperated by space('_')
                    else:
                        # Drop the first and last label (presumably the
                        # <SOS>/<EOS> markers)
                        str_true = idx2char(
                            labels_true[i_device][i_batch]
                            [1:labels_seq_len[i_device][i_batch] - 1])
                    str_pred = idx2char(
                        labels_pred_list[i_device][i_batch]).split('>')[0]
                    # NOTE: Trancate by <EOS>

                    # Remove garbage labels
                    str_true = re.sub(r'[_NZー・<>]+', '', str_true)
                    str_pred = re.sub(r'[_NZー・<>]+', '', str_pred)

                    # Compute CER
                    cer_mean += compute_cer(str_pred=str_pred,
                                            str_true=str_true,
                                            normalize=True)

                    if pbar is not None:
                        pbar.update(1)

            if is_new_epoch:
                break

        cer_mean /= len(dataset)
    finally:
        if pbar is not None:
            pbar.close()

        # Register original batch size
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original

    return cer_mean
# Example #11 (0)
def eval_char(models,
              dataset,
              beam_width,
              max_decode_len,
              eval_batch_size=None,
              length_penalty=0,
              progressbar=False,
              temperature=1):
    """Evaluate trained model by Character Error Rate.

    References and hypotheses are normalized with `fix_trans` using the
    eval2000 GLM file. Utterances with an empty reference are skipped.

    Args:
        models (list): the models to evaluate. Only the first model is
            currently used.
        dataset: An instance of a `Dataset' class
        beam_width: (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction when EOS token have not been emitted.
            This is used for seq2seq models.
        eval_batch_size (int, optional): the batch size when evaluating the model
        length_penalty (float, optional): passed to `model.decode`
        progressbar (bool, optional): if True, visualize the progressbar
        temperature (int, optional): currently unused
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate
        df_wer_cer (pd.DataFrame): dataframe of substitution, insertion, and
            deletion rates (in %) for words (row 'WER') and characters
            (row 'CER')
    """
    import warnings

    # Reset data counter
    dataset.reset()

    if models[0].model_type in ['ctc', 'attention']:
        idx2char = Idx2char(
            vocab_file_path=dataset.vocab_file_path,
            capital_divide=(dataset.label_type == 'character_capital_divide'))
    else:
        # Hierarchical models: characters are the sub task
        idx2char = Idx2char(
            vocab_file_path=dataset.vocab_file_path_sub,
            capital_divide=(
                dataset.label_type_sub == 'character_capital_divide'))

    # Read GLM file
    glm = GLM(
        glm_path='/n/sd8/inaguma/corpus/swbd/data/eval2000/LDC2002T43/reference/en20000405_hub5.glm'
    )

    wer, cer = 0, 0
    sub_word, ins_word, del_word = 0, 0, 0
    sub_char, ins_char, del_char = 0, 0, 0
    num_words, num_chars = 0, 0
    num_skipped = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        # TODO: add CTC ensemble

        # Decode
        model = models[0]
        # TODO: fix this

        if model.model_type in ['ctc', 'attention']:
            best_hyps, _, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                length_penalty=length_penalty)
            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]
        else:
            best_hyps, _, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                length_penalty=length_penalty,
                task_index=1)
            ys = batch['ys_sub'][perm_idx]
            y_lens = batch['y_lens_sub'][perm_idx]

        for b in range(len(batch['xs'])):

            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is seperated by space('_')
            else:
                # Convert from list of index to string
                str_ref = idx2char(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = idx2char(best_hyps[b])
            if 'attention' in model.model_type:
                str_hyp = str_hyp.split('>')[0]
                # NOTE: Trancate by the first <EOS>

                # Remove the last space
                if len(str_hyp) > 0 and str_hyp[-1] == '_':
                    str_hyp = str_hyp[:-1]

            ##############################
            # Post-proccessing
            ##############################
            str_ref = fix_trans(str_ref, glm)
            str_hyp = fix_trans(str_hyp, glm)

            if len(str_ref) == 0:
                if progressbar:
                    pbar.update(1)
                continue

            try:
                # Compute WER
                wer_b, sub_wb, ins_wb, del_wb = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=str_hyp.split('_'),
                    normalize=False)
                # Compute CER
                cer_b, sub_cb, ins_cb, del_cb = compute_wer(
                    ref=list(str_ref.replace('_', '')),
                    hyp=list(str_hyp.replace('_', '')),
                    normalize=False)
            except Exception:
                # Best-effort scoring: skip utterances the scorer cannot
                # handle. Unlike a bare `except`, this does not swallow
                # KeyboardInterrupt/SystemExit. Both scores are computed
                # before accumulating, so word and character statistics
                # always cover the same utterances.
                num_skipped += 1
            else:
                wer += wer_b
                sub_word += sub_wb
                ins_word += ins_wb
                del_word += del_wb
                num_words += len(str_ref.split('_'))

                cer += cer_b
                sub_char += sub_cb
                ins_char += ins_cb
                del_char += del_cb
                num_chars += len(str_ref.replace('_', ''))

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    if num_skipped > 0:
        warnings.warn('eval_char: %d utterance(s) skipped because scoring '
                      'failed.' % num_skipped)

    # Guard against ZeroDivisionError when nothing could be scored
    if num_words > 0:
        wer /= num_words
        sub_word /= num_words
        ins_word /= num_words
        del_word /= num_words
    if num_chars > 0:
        cer /= num_chars
        sub_char /= num_chars
        ins_char /= num_chars
        del_char /= num_chars

    df_wer_cer = pd.DataFrame(
        {
            'SUB': [sub_word * 100, sub_char * 100],
            'INS': [ins_word * 100, ins_char * 100],
            'DEL': [del_word * 100, del_char * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['WER', 'CER'])

    return wer, cer, df_wer_cer
# Example #12 (0)
    def check(self,
              label_type,
              data_type='dev',
              shuffle=False,
              sort_utt=False,
              sort_stop_epoch=None,
              frame_stacking=False,
              splice=1,
              num_gpu=1):
        """Load kana/kanji mini-batches and print the first reference of each.

        For the training split, every input of the first device's batch is
        checked to be at least as long as its label sequence. Iteration
        stops once `dataset.epoch_detail` reaches 0.1.

        Args:
            label_type (string): must contain 'kana' or 'kanji'; selects the
                mapping file
            data_type (string, optional): data split to load
            shuffle (bool, optional): passed to `Dataset`
            sort_utt (bool, optional): passed to `Dataset`
            sort_stop_epoch (int, optional): passed to `Dataset`
            frame_stacking (bool, optional): if True, passes
                num_stack=num_skip=3 to `Dataset` (1 otherwise)
            splice (int, optional): passed to `Dataset`
            num_gpu (int, optional): passed to `Dataset`; when > 1, the
                input shape of every device is printed

        Raises:
            ValueError: if `label_type` contains neither 'kana' nor 'kanji',
                or a training input is shorter than its labels
        """
        print('========================================')
        print('  label_type: %s' % label_type)
        print('  data_type: %s' % data_type)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print('  num_gpu: %d' % num_gpu)
        print('========================================')

        if 'kana' in label_type:
            map_file_path = '../../metrics/mapping_files/' + label_type + '.txt'
        elif 'kanji' in label_type:
            map_file_path = '../../metrics/mapping_files/' + \
                label_type + '_train_subset.txt'
        else:
            # Without this, `map_file_path` stays unbound and the failure
            # surfaces later as an UnboundLocalError.
            raise ValueError(
                'label_type must contain "kana" or "kanji", got %r'
                % (label_type,))

        num_stack = 3 if frame_stacking else 1
        num_skip = 3 if frame_stacking else 1
        dataset = Dataset(data_type=data_type,
                          train_data_size='train_subset',
                          label_type=label_type,
                          map_file_path=map_file_path,
                          batch_size=64,
                          max_epoch=2,
                          splice=splice,
                          num_stack=num_stack,
                          num_skip=num_skip,
                          shuffle=shuffle,
                          sort_utt=sort_utt,
                          sort_stop_epoch=sort_stop_epoch,
                          progressbar=True,
                          num_gpu=num_gpu)

        print('=> Loading mini-batch...')

        idx2char = Idx2char(map_file_path)

        for data, is_new_epoch in dataset:
            inputs, labels, inputs_seq_len, labels_seq_len, input_names = data

            if data_type == 'train':
                for input_frames, label_seq in zip(inputs[0], labels[0]):
                    if len(input_frames) < len(label_seq):
                        raise ValueError(
                            'input length must be longer than label length.')

            if num_gpu > 1:
                for inputs_gpu in inputs:
                    print(inputs_gpu.shape)

            if 'eval' in data_type:
                # Evaluation references are used as-is (presumably raw
                # transcripts rather than index sequences)
                str_true = labels[0][0][0]
            else:
                str_true = idx2char(labels[0][0][0:labels_seq_len[0][0]])

            print('----- %s (epoch: %.3f) -----' %
                  (input_names[0][0], dataset.epoch_detail))
            print(inputs[0].shape)
            print(str_true)

            if dataset.epoch_detail >= 0.1:
                break
# Example #13 (0)
def decode(model,
           dataset,
           beam_width,
           max_decode_len,
           max_decode_len_sub,
           eval_batch_size=None,
           save_path=None):
    """Visualize label outputs of the main (word) and sub (character) tasks.

    For each utterance of one pass over ``dataset``, print the reference and
    hypothesis of both tasks, the WER of the main task and the CER of the
    sub task.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction when EOS token have not been emitted.
            This is used for seq2seq models.
        max_decode_len_sub (int): the same as ``max_decode_len`` for the
            sub task
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string, optional): if not None, ``sys.stdout`` is
            redirected to ``decode.txt`` in ``model.model_dir`` (the value of
            ``save_path`` itself is not used) and is not restored afterwards
    Raises:
        ValueError: if ``dataset.label_type_sub`` is neither ``character``
            nor ``character_capital_divide``
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    idx2word = Idx2word(vocab_file_path=dataset.vocab_file_path)
    if dataset.label_type_sub == 'character':
        idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path_sub)
    elif dataset.label_type_sub == 'character_capital_divide':
        idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path_sub,
                            capital_divide=True)
    else:
        # Fail early instead of hitting a NameError on `idx2char` below
        raise ValueError('Unsupported label_type_sub: %s' %
                         dataset.label_type_sub)

    if save_path is not None:
        # NOTE: the file stays open and stdout stays redirected after return
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for batch, is_new_epoch in dataset:

        # Decode
        if model.model_type == 'charseq_attention':
            best_hyps, best_hyps_sub, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                max_decode_len_sub=max_decode_len_sub)
        else:
            best_hyps, perm_idx = model.decode(batch['xs'],
                                               batch['x_lens'],
                                               beam_width=beam_width,
                                               max_decode_len=max_decode_len)
            best_hyps_sub, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len_sub,
                is_sub_task=True)

        # Reorder references with the permutation returned by the model
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]
        ys_sub = batch['ys_sub'][perm_idx]
        y_lens_sub = batch['y_lens_sub'][perm_idx]

        for b in range(len(batch['xs'])):

            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                # NOTE: transcript is separated by space('_')
                str_ref = ys[b][0]
                str_ref_sub = ys_sub[b][0]
            else:
                # Convert from list of index to string; the sub task is
                # decoded with the character vocabulary, like its hypothesis
                str_ref = idx2word(ys[b][:y_lens[b]])
                str_ref_sub = idx2char(ys_sub[b][:y_lens_sub[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = idx2word(best_hyps[b])
            str_hyp_sub = idx2char(best_hyps_sub[b])

            if model.model_type != 'hierarchical_ctc':
                # Truncate by the first <EOS>
                str_hyp = str_hyp.split('>')[0]
                str_hyp_sub = str_hyp_sub.split('>')[0]

                # Remove the last space
                if str_hyp.endswith('_'):
                    str_hyp = str_hyp[:-1]
                if str_hyp_sub.endswith('_'):
                    str_hyp_sub = str_hyp_sub[:-1]

            ##############################
            # Post-processing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[\'>]+', '', str_ref)
            str_hyp = re.sub(r'[\'>]+', '', str_hyp)

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref: %s' % str_ref.replace('_', ' '))
            print('Hyp (main): %s' % str_hyp.replace('_', ' '))
            # print('Ref (sub): %s' % str_ref_sub.replace('_', ' '))
            print('Hyp (sub): %s' % str_hyp_sub.replace('_', ' '))

            wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                       hyp=str_hyp.split('_'),
                                       normalize=True)
            print('WER: %f %%' % (wer * 100))
            cer, _, _, _ = compute_wer(ref=list(str_ref_sub.replace('_', '')),
                                       hyp=list(str_hyp_sub.replace('_', '')),
                                       normalize=True)
            print('CER: %f %%' % (cer * 100))

        if is_new_epoch:
            break
def decode(session,
           decode_op,
           model,
           dataset,
           label_type,
           is_test=True,
           save_path=None):
    """Visualize label outputs of CTC model.

    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): phone39 or phone48 or phone61 or character or
            character_capital_divide
        is_test (bool, optional): if True, references are taken as raw
            transcripts (``labels_true[0][i][0]``) instead of index sequences
        save_path (string, optional): if not None, ``sys.stdout`` is
            redirected to ``decode.txt`` in ``model.model_dir`` and is not
            restored afterwards
    """
    if label_type == 'character':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        map_fn = Idx2char(
            map_file_path=
            '../metrics/mapping_files/character_capital_divide.txt',
            capital_divide=True)
    else:
        map_fn = Idx2phone(map_file_path='../metrics/mapping_files/' +
                           label_type + '.txt')

    if save_path is not None:
        # NOTE: the file stays open and stdout stays redirected after return
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for data, is_new_epoch in dataset:

        # Create feed dictionary for next mini batch
        inputs, labels_true, inputs_seq_len, input_names = data

        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_pl_list[0]: 1.0
        }

        batch_size = inputs[0].shape[0]
        labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
        no_output_flag = False
        try:
            labels_pred = sparsetensor2list(labels_pred_st,
                                            batch_size=batch_size)
        except IndexError:
            # No output for the whole mini-batch. A one-element placeholder
            # list would raise IndexError for every utterance after the
            # first, so every hypothesis becomes empty instead.
            no_output_flag = True

        for i_batch in range(batch_size):
            print('----- wav: %s -----' % input_names[0][i_batch])
            if is_test:
                str_true = labels_true[0][i_batch][0]
            else:
                str_true = map_fn(labels_true[0][i_batch])
            if no_output_flag:
                str_pred = ''
            else:
                str_pred = map_fn(labels_pred[i_batch])

            print('Ref: %s' % str_true)
            print('Hyp: %s' % str_pred)

        if is_new_epoch:
            break
Beispiel #15
0
def do_eval_cer(session,
                decode_op,
                model,
                dataset,
                label_type,
                eval_batch_size=None,
                progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Character Error Rate.

    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
        eval_batch_size (int, optional): the batch size when evaluating the
            model. NOTE: currently not read by this function
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    Raises:
        ValueError: if ``label_type`` is not a supported character type
    """
    # Validate before touching the dataset; an unsupported label type used
    # to surface much later as a NameError on `idx2char`.
    if label_type == 'character':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/ctc/character.txt')
    elif label_type == 'character_capital_divide':
        idx2char = Idx2char(
            map_file_path=
            '../metrics/mapping_files/ctc/character_capital_divide.txt',
            capital_divide=True,
            space_mark='_')
    else:
        raise ValueError('Unsupported label_type: %s' % label_type)

    # Reset data counter
    dataset.reset()

    cer_mean, wer_mean = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:

        # Create feed dictionary for next mini batch
        if is_multitask:
            inputs, labels_true, _, inputs_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, _ = data

        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0
        }

        batch_size_each = len(inputs)

        labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
        labels_pred = sparsetensor2list(labels_pred_st, batch_size_each)
        for i_batch in range(batch_size_each):

            # Convert from list of index to string
            str_true = idx2char(labels_true[i_batch])
            str_pred = idx2char(labels_pred[i_batch])

            # Remove consecutive spaces
            str_pred = re.sub(r'[_]+', '_', str_pred)

            # Remove garbage labels
            str_true = re.sub(r'[\'\":;!?,.-]+', "", str_true)
            str_pred = re.sub(r'[\'\":;!?,.-]+', "", str_pred)

            # Compute WER
            wer_mean += compute_wer(hyp=str_pred.split('_'),
                                    ref=str_true.split('_'),
                                    normalize=True)

            # Remove spaces
            str_pred = re.sub(r'[_]+', "", str_pred)
            str_true = re.sub(r'[_]+', "", str_true)

            # Compute CER
            cer_mean += compute_cer(str_pred=str_pred,
                                    str_true=str_true,
                                    normalize=True)

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        # Release the progress bar (and its terminal line)
        pbar.close()

    cer_mean /= len(dataset)
    wer_mean /= len(dataset)

    return cer_mean, wer_mean
def decode(session,
           decode_op,
           model,
           dataset,
           label_type,
           train_data_size,
           is_test=True,
           save_path=None):
    """Print reference and hypothesis transcripts produced by a CTC model.

    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): kanji or kana or kanji_divide or kana_divide
        train_data_size (string): train_subset or train_fullset; only used
            to locate the mapping file of kanji label types
        is_test (bool, optional): set to True when evaluating by the test set;
            references are then printed as raw transcripts
        save_path (string, optional): if not None, ``sys.stdout`` is
            redirected to ``decode.txt`` in ``model.model_dir``
    Raises:
        TypeError: if ``label_type`` contains neither 'kanji' nor 'kana'
    """
    mapping_dir = '../metrics/mapping_files/'
    if 'kanji' in label_type:
        # Kanji vocabularies depend on the size of the training set
        map_file_path = mapping_dir + label_type + '_' + \
            train_data_size + '.txt'
    elif 'kana' in label_type:
        map_file_path = mapping_dir + label_type + '.txt'
    else:
        raise TypeError

    idx2char = Idx2char(map_file_path=map_file_path)

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for data, is_new_epoch in dataset:
        inputs, labels_true, inputs_seq_len, input_names = data

        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_pl_list[0]: 1.0
        }

        # Run the decoder on the whole mini-batch
        num_utts = inputs[0].shape[0]
        labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
        has_output = True
        try:
            labels_pred = sparsetensor2list(labels_pred_st,
                                            batch_size=num_utts)
        except IndexError:
            # The decoder emitted nothing for this mini-batch
            has_output = False

        for i_utt in range(num_utts):
            print('----- wav: %s -----' % input_names[0][i_utt])
            str_true = (labels_true[0][i_utt][0] if is_test
                        else idx2char(labels_true[0][i_utt]))
            str_pred = idx2char(labels_pred[i_utt]) if has_output else ''

            print('Ref: %s' % str_true)
            print('Hyp: %s' % str_pred)

        if is_new_epoch:
            break
    def __init__(self, data_save_path,
                 backend, input_freq, use_delta, use_double_delta,
                 data_type, data_size, label_type, label_type_sub,
                 batch_size, max_epoch=None, splice=1,
                 num_stack=1, num_skip=1,
                 min_frame_num=40,
                 shuffle=False, sort_utt=False, reverse=False,
                 sort_stop_epoch=None, num_gpus=1, tool='htk',
                 num_enque=None, dynamic_batching=False):
        """A class for loading dataset.
        Args:
            data_save_path (string): path to saved data
            backend (string): pytorch or chainer
            input_freq (int): the number of dimensions of acoustics
            use_delta (bool): if True, use the delta feature
            use_double_delta (bool): if True, use the acceleration feature
            data_type (string): train or dev_clean or dev_other or test_clean
                or test_other
            data_size (string): 100 or 460 or 960
            label_type (string): word
            label_type_sub (string): character or character_capital_divide
            batch_size (int): the size of mini-batch; ``self.batch_size`` is
                set to ``batch_size * num_gpus``
            max_epoch (int): the max epoch. None means infinite loop.
            splice (int): frames to splice. Default is 1 frame.
            num_stack (int): the number of frames to stack
            num_skip (int): the number of frames to skip
            min_frame_num (int): utterances with fewer frames than this are
                removed from both the main and the sub-task tables, unless
                ``data_type`` contains 'test'
            shuffle (bool): if True, shuffle utterances. This is
                disabled when sort_utt is True.
            sort_utt (bool): if True, sort all utterances in the
                ascending order
            reverse (bool): if True, sort utterances in the
                descending order
            sort_stop_epoch (int): After sort_stop_epoch, training
                will revert back to a random order
            num_gpus (int): the number of GPUs
            tool (string): htk or librosa or python_speech_features
            num_enque (int): the number of elements to enqueue
            dynamic_batching (bool): if True, batch size will be
                changed dynamically in training
        Raises:
            ValueError: if the main and sub-task tables contain a different
                number of utterances
        """
        self.backend = backend
        self.input_freq = input_freq
        self.use_delta = use_delta
        self.use_double_delta = use_double_delta
        self.data_type = data_type
        self.data_size = data_size
        self.label_type = label_type
        self.label_type_sub = label_type_sub
        self.batch_size = batch_size * num_gpus
        self.max_epoch = max_epoch
        self.splice = splice
        self.num_stack = num_stack
        self.num_skip = num_skip
        self.shuffle = shuffle
        self.sort_utt = sort_utt
        self.sort_stop_epoch = sort_stop_epoch
        self.num_gpus = num_gpus
        self.tool = tool
        self.num_enque = num_enque
        self.dynamic_batching = dynamic_batching
        self.is_test = 'test' in data_type

        self.vocab_file_path = join(
            data_save_path, 'vocab', data_size, label_type + '.txt')
        self.idx2word = Idx2word(self.vocab_file_path)
        self.word2idx = Word2idx(self.vocab_file_path)
        self.vocab_file_path_sub = join(
            data_save_path, 'vocab', data_size, label_type_sub + '.txt')
        capital_divide = label_type_sub == 'character_capital_divide'
        self.idx2char = Idx2char(self.vocab_file_path_sub,
                                 capital_divide=capital_divide)
        self.char2idx = Char2idx(self.vocab_file_path_sub,
                                 capital_divide=capital_divide)

        super(Dataset, self).__init__(vocab_file_path=self.vocab_file_path,
                                      vocab_file_path_sub=self.vocab_file_path_sub)

        # Load dataset file
        dataset_path = join(
            data_save_path, 'dataset', tool, data_size, data_type, label_type + '.csv')
        dataset_path_sub = join(
            data_save_path, 'dataset', tool, data_size, data_type, label_type_sub + '.csv')
        df = pd.read_csv(dataset_path)
        df = df.loc[:, ['frame_num', 'input_path', 'transcript']]
        df_sub = pd.read_csv(dataset_path_sub)
        df_sub = df_sub.loc[:, ['frame_num', 'input_path', 'transcript']]

        # Remove inappropriate utterances
        if not self.is_test:
            logger.info('Original utterance num: %d' % len(df))
            df = df[df['frame_num'] >= min_frame_num]
            # The sub-task table must be filtered identically, otherwise the
            # length check below fails as soon as one utterance is dropped.
            df_sub = df_sub[df_sub['frame_num'] >= min_frame_num]
            logger.info('Restricted utterance num: %d' % len(df))

        # Sort paths to input & label
        # NOTE(review): df and df_sub are sorted independently; rows with
        # equal frame_num stay aligned only if both CSVs list utterances in
        # the same order — confirm.
        if sort_utt:
            df = df.sort_values(by='frame_num', ascending=not reverse)
            df_sub = df_sub.sort_values(by='frame_num', ascending=not reverse)
        else:
            df = df.sort_values(by='input_path', ascending=True)
            df_sub = df_sub.sort_values(by='input_path', ascending=True)

        # Explicit check instead of `assert`, which is stripped under -O
        if len(df) != len(df_sub):
            raise ValueError(
                'Utterance num mismatch: %d (main) vs %d (sub)' %
                (len(df), len(df_sub)))

        self.df = df
        self.df_sub = df_sub
        self.rest = set(list(df.index))
def decode(session,
           decode_op_main,
           decode_op_sub,
           model,
           dataset,
           label_type_main,
           label_type_sub,
           is_test=True,
           save_path=None):
    """Visualize label outputs of Multi-task CTC model.

    Args:
        session: session of training model
        decode_op_main: operation for decoding in the main task
        decode_op_sub: operation for decoding in the sub task
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type_main (string): character or character_capital_divide
        label_type_sub (string): phone39 or phone48 or phone61
        is_test (bool, optional): if True, references are taken as raw
            transcripts instead of index sequences
        save_path (string, optional): if not None, ``sys.stdout`` is
            redirected to ``decode.txt`` in ``model.model_dir`` and is not
            restored afterwards
    """
    idx2char = Idx2char(
        map_file_path='../metrics/mapping_files/' + label_type_main + '.txt',
        capital_divide=label_type_main == 'character_capital_divide')
    idx2phone = Idx2phone(map_file_path='../metrics/mapping_files/' +
                          label_type_sub + '.txt')

    if save_path is not None:
        # NOTE: the file stays open and stdout stays redirected after return
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for data, is_new_epoch in dataset:

        # Create feed dictionary for next mini batch
        inputs, labels_true_char, labels_true_phone, inputs_seq_len, input_names = data

        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_pl_list[0]: 1.0
        }

        batch_size = inputs[0].shape[0]
        labels_pred_char_st, labels_pred_phone_st = session.run(
            [decode_op_main, decode_op_sub], feed_dict=feed_dict)

        # None marks "no output for this mini-batch" in each task
        try:
            labels_pred_char = sparsetensor2list(labels_pred_char_st,
                                                 batch_size=batch_size)
        except IndexError:
            labels_pred_char = None
        try:
            # Sub-task hypotheses come from the sub-task decoder
            labels_pred_phone = sparsetensor2list(labels_pred_phone_st,
                                                  batch_size=batch_size)
        except IndexError:
            labels_pred_phone = None

        for i_batch in range(batch_size):
            print('----- wav: %s -----' % input_names[0][i_batch])

            if is_test:
                str_true_char = labels_true_char[0][i_batch][0].replace(
                    '_', ' ')
                str_true_phone = labels_true_phone[0][i_batch][0]
            else:
                str_true_char = idx2char(labels_true_char[0][i_batch])
                str_true_phone = idx2phone(labels_true_phone[0][i_batch])

            if labels_pred_char is None:
                str_pred_char = ''
            else:
                str_pred_char = idx2char(labels_pred_char[i_batch])
            if labels_pred_phone is None:
                str_pred_phone = ''
            else:
                str_pred_phone = idx2phone(labels_pred_phone[i_batch])

            print('Ref (char): %s' % str_true_char)
            print('Hyp (char): %s' % str_pred_char)
            print('Ref (phone): %s' % str_true_phone)
            print('Hyp (phone): %s' % str_pred_phone)

        if is_new_epoch:
            break
Beispiel #19
0
def decode(session,
           decode_op,
           model,
           dataset,
           label_type,
           is_test=False,
           save_path=None):
    """Visualize label outputs of Attention-based model.

    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): phone39 or phone48 or phone61 or character or
            character_capital_divide
        is_test (bool, optional): if True, references are taken as raw
            transcripts instead of index sequences
        save_path (string): if not None, ``sys.stdout`` is redirected to
            ``decode.txt`` in ``model.model_dir`` and is not restored
    """
    if label_type == 'character':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        map_fn = Idx2char(
            map_file_path=
            '../metrics/mapping_files/character_capital_divide.txt',
            capital_divide=True)
    else:
        map_fn = Idx2phone(map_file_path='../metrics/mapping_files/' +
                           label_type + '.txt')

    if save_path is not None:
        # NOTE: the file stays open and stdout stays redirected after return
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for data, is_new_epoch in dataset:

        # Create feed dictionary for next mini batch
        inputs, labels_true, inputs_seq_len, labels_seq_len, input_names = data

        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_encoder_pl_list[0]: 1.0,
            model.keep_prob_decoder_pl_list[0]: 1.0,
            model.keep_prob_embedding_pl_list[0]: 1.0
        }

        batch_size = inputs[0].shape[0]
        labels_pred = session.run(decode_op, feed_dict=feed_dict)
        for i_batch in range(batch_size):
            print('----- wav: %s -----' % input_names[0][i_batch])
            if is_test:
                str_true = labels_true[0][i_batch][0]
            else:
                # NOTE: Exclude <SOS> and <EOS>
                str_true = map_fn(
                    labels_true[0][i_batch][1:labels_seq_len[0][i_batch] - 1])
            # NOTE: Truncate by <EOS>
            str_pred = map_fn(labels_pred[i_batch]).split('>')[0]

            # Remove the last space; `endswith` is safe when the hypothesis
            # is empty (e.g. <EOS> emitted first), unlike `str_pred[-1]`
            if 'phone' in label_type and str_pred.endswith(' '):
                str_pred = str_pred[:-1]

            print('Ref: %s' % str_true)
            print('Hyp: %s' % str_pred)

        if is_new_epoch:
            break
def plot(model,
         dataset,
         beam_width,
         beam_width_sub,
         eval_batch_size=None,
         a2c_oracle=False,
         save_path=None):
    """Visualize attention weights of Attention-based hierarchical model.

    For every utterance, save a word/character-to-acoustic attention plot
    and a word-to-character attention plot under ``save_path``.

    Args:
        model: model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width: (int): the size of beam in the main task
        beam_width_sub: (int): the size of beam in the sub task
        eval_batch_size (int, optional): the batch size when evaluating the model
        a2c_oracle (bool, optional): if True, the reference character
            sequences are fed to the model as ``ys_sub`` with
            ``teacher_forcing=True``
        save_path (string, optional): path to save attention weights plotting;
            an existing directory is removed and recreated first
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Clean directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    idx2word = Idx2word(dataset.vocab_file_path, return_list=True)
    idx2char = Idx2char(dataset.vocab_file_path_sub, return_list=True)
    if a2c_oracle and dataset.is_test:
        # Test-set references are raw transcripts and must be mapped back to
        # character indices for teacher forcing.
        char2idx = Char2idx(dataset.vocab_file_path_sub)

    for batch, is_new_epoch in dataset:
        batch_size = len(batch['xs'])

        if a2c_oracle:
            if dataset.is_test:
                max_label_num = max(
                    (len(list(batch['ys_sub'][b][0]))
                     for b in range(batch_size)),
                    default=0)

                # pad with -1
                ys_sub = np.full((batch_size, max_label_num), -1,
                                 dtype=np.int32)
                y_lens_sub = np.zeros((batch_size, ), dtype=np.int32)
                for b in range(batch_size):
                    # NOTE: transcript is separated by space('_')
                    indices = char2idx(batch['ys_sub'][b][0])
                    ys_sub[b, :len(indices)] = indices
                    y_lens_sub[b] = len(indices)
            else:
                ys_sub = batch['ys_sub']
                y_lens_sub = batch['y_lens_sub']
        else:
            ys_sub = None
            y_lens_sub = None

        best_hyps, best_hyps_sub, aw, aw_sub, aw_dec = model.attention_weights(
            batch['xs'],
            batch['x_lens'],
            beam_width=beam_width,
            beam_width_sub=beam_width_sub,
            max_decode_len=MAX_DECODE_LEN_WORD,
            max_decode_len_sub=MAX_DECODE_LEN_CHAR,
            teacher_forcing=a2c_oracle,
            ys_sub=ys_sub,
            y_lens_sub=y_lens_sub)

        for b in range(batch_size):
            word_list = idx2word(best_hyps[b])
            if 'word' in dataset.label_type_sub:
                char_list = idx2word(best_hyps_sub[b])
            else:
                char_list = idx2char(best_hyps_sub[b])

            # word to acoustic & character to acoustic
            plot_hierarchical_attention_weights(
                aw[b][:len(word_list), :batch['x_lens'][b]],
                aw_sub[b][:len(char_list), :batch['x_lens'][b]],
                label_list=word_list,
                label_list_sub=char_list,
                spectrogram=batch['xs'][b, :, :dataset.input_freq],
                save_path=mkdir_join(save_path,
                                     batch['input_names'][b] + '.png'),
                figsize=(40, 8))

            # word to character
            plot_word2char_attention_weights(
                aw_dec[b][:len(word_list), :len(char_list)],
                label_list=word_list,
                label_list_sub=char_list,
                save_path=mkdir_join(
                    save_path, batch['input_names'][b] + '_word2char.png'),
                figsize=(40, 8))

        if is_new_epoch:
            break
Beispiel #21
0
    def check(self,
              label_type,
              label_type_sub,
              data_type='dev_clean',
              data_size='100h',
              backend='pytorch',
              shuffle=False,
              sort_utt=True,
              sort_stop_epoch=None,
              frame_stacking=False,
              splice=1,
              num_gpus=1):
        """Load mini-batches and print the first utterance of each one.

        Stops once ``dataset.epoch_detail`` reaches 0.01. For pytorch
        training data, raises ValueError if any utterance has fewer input
        frames than labels.
        """
        print('========================================')
        print('  backend: %s' % backend)
        print('  label_type: %s' % label_type)
        print('  label_type_sub: %s' % label_type_sub)
        print('  data_type: %s' % data_type)
        print('  data_size: %s' % data_size)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print('  num_gpus: %d' % num_gpus)
        print('========================================')

        vocab_file_path = '../../metrics/vocab_files/' + \
            label_type + '_' + data_size + '.txt'
        vocab_file_path_sub = '../../metrics/vocab_files/' + \
            label_type_sub + '_' + data_size + '.txt'

        num_stack = 3 if frame_stacking else 1
        num_skip = 3 if frame_stacking else 1
        dataset = Dataset(backend=backend,
                          input_channel=40,
                          use_delta=True,
                          use_double_delta=True,
                          data_type=data_type,
                          data_size=data_size,
                          label_type=label_type,
                          label_type_sub=label_type_sub,
                          batch_size=64,
                          vocab_file_path=vocab_file_path,
                          vocab_file_path_sub=vocab_file_path_sub,
                          max_epoch=1,
                          splice=splice,
                          num_stack=num_stack,
                          num_skip=num_skip,
                          shuffle=shuffle,
                          sort_utt=sort_utt,
                          reverse=True,
                          sort_stop_epoch=sort_stop_epoch,
                          num_gpus=num_gpus,
                          save_format='numpy',
                          num_enque=None)

        print('=> Loading mini-batch...')
        idx2word = Idx2word(vocab_file_path, space_mark=' ')
        idx2char = Idx2char(vocab_file_path_sub)

        for batch, is_new_epoch in dataset:
            if data_type == 'train' and backend == 'pytorch':
                # Check every utterance with its own lengths. Comparing the
                # padded array widths only (the same test for every
                # utterance) misses short utterances in a long batch.
                for x_len, y_len in zip(batch['x_lens'], batch['y_lens']):
                    if x_len < y_len:
                        raise ValueError(
                            'input length must be longer than label length.')

            if dataset.is_test:
                str_ref = batch['ys'][0][0]
                str_ref_sub = batch['ys_sub'][0][0]
            else:
                str_ref = idx2word(batch['ys'][0][:batch['y_lens'][0]])
                str_ref_sub = idx2char(
                    batch['ys_sub'][0][:batch['y_lens_sub'][0]])

            print('----- %s (epoch: %.3f, batch: %d) -----' %
                  (batch['input_names'][0], dataset.epoch_detail,
                   len(batch['xs'])))
            print('=' * 20)
            print(str_ref)
            print('-' * 10)
            print(str_ref_sub)
            print('x_lens: %d' % (batch['x_lens'][0] * num_stack))
            if not dataset.is_test:
                print('y_lens (word): %d' % batch['y_lens'][0])
                print('y_lens_sub (char): %d' % batch['y_lens_sub'][0])

            if dataset.epoch_detail >= 0.01:
                break
def eval_word(models,
              dataset,
              beam_width,
              max_decode_len,
              beam_width_sub=1,
              max_decode_len_sub=300,
              eval_batch_size=None,
              length_penalty=0,
              progressbar=False,
              temperature=1,
              resolving_unk=False,
              a2c_oracle=False):
    """Evaluate trained model by Word Error Rate.

    Args:
        models (list): the models to evaluate. With more than one model the
            posteriors are averaged and decoded with `beam_width=1`; this
            ensemble path only supports `ctc` models.
        dataset: An instance of a `Dataset' class
        beam_width (int): the size of beam in the main task
        max_decode_len (int): the length of output sequences
            to stop prediction. This is used for seq2seq models.
        beam_width_sub (int, optional): the size of beam in the sub task
            This is used for the nested attention
        max_decode_len_sub (int, optional): the length of output sequences
            to stop prediction. This is used for the nested attention
        eval_batch_size (int, optional): the batch size when evaluating the model
        length_penalty (float, optional): forwarded to `model.decode`
        progressbar (bool, optional): if True, visualize the progressbar
        temperature (int, optional): not used by this function
        resolving_unk (bool, optional): if True, hypotheses containing 'OOV'
            are passed to `resolve_unk` together with the sub-task
            hypothesis and attention weights
        a2c_oracle (bool, optional): if True, the nested attention model is
            decoded with teacher forcing on the reference sub-task labels
    Returns:
        wer (float): Word error rate
        df_wer (pd.DataFrame): dataframe of substitution, insertion, and deletion
    """
    # Reset data counter
    dataset.reset()

    idx2word = Idx2word(dataset.vocab_file_path)
    if models[0].model_type == 'nested_attention':
        char2idx = Char2idx(dataset.vocab_file_path_sub)
    if resolving_unk:
        # Needed by `resolve_unk` below whenever OOV resolution is requested
        idx2char = Idx2char(dataset.vocab_file_path_sub,
                            capital_divide=dataset.label_type_sub ==
                            'character_capital_divide')

    wer = 0
    sub, ins, dele = 0, 0, 0
    num_words = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        batch_size = len(batch['xs'])

        # Decode
        if len(models) > 1:
            # NOTE(review): this path defines no attention weights, so
            # `resolving_unk` is only meaningful with a single model.
            assert models[0].model_type in ['ctc']
            for i, model in enumerate(models):
                probs, x_lens, perm_idx = model.posteriors(
                    batch['xs'], batch['x_lens'])
                if i == 0:
                    probs_ensemble = probs
                else:
                    probs_ensemble += probs
            probs_ensemble /= len(models)

            best_hyps = models[0].decode_from_probs(probs_ensemble,
                                                    x_lens,
                                                    beam_width=1)
        else:
            model = models[0]
            # TODO: fix this

            if model.model_type == 'nested_attention':
                ys_sub = batch['ys_sub']
                y_lens_sub = batch['y_lens_sub']
                sub_decode_len = max_decode_len_sub
                if a2c_oracle:
                    if dataset.is_test:
                        # Test references are raw transcripts: convert them
                        # to character indices padded with -1
                        indices_list = [char2idx(batch['ys_sub'][b][0])
                                        for b in range(batch_size)]
                        # NOTE: transcript is seperated by space('_')
                        max_label_num = max(
                            [len(indices) for indices in indices_list] + [0])
                        ys_sub = np.full((batch_size, max_label_num), -1,
                                         dtype=np.int32)
                        y_lens_sub = np.zeros((batch_size, ), dtype=np.int32)
                        for b, indices in enumerate(indices_list):
                            ys_sub[b, :len(indices)] = indices
                            y_lens_sub[b] = len(indices)
                    else:
                        max_label_num = (int(np.max(y_lens_sub))
                                         if batch_size > 0 else 0)
                    # With teacher forcing, never decode past the longest
                    # reference sequence
                    sub_decode_len = max_label_num

                best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    beam_width_sub=beam_width_sub,
                    max_decode_len=max_decode_len,
                    max_decode_len_sub=sub_decode_len,
                    length_penalty=length_penalty,
                    teacher_forcing=a2c_oracle,
                    ys_sub=ys_sub,
                    y_lens_sub=y_lens_sub)
            else:
                best_hyps, aw, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len,
                    length_penalty=length_penalty)
                if resolving_unk:
                    best_hyps_sub, aw_sub, _ = model.decode(
                        batch['xs'],
                        batch['x_lens'],
                        beam_width=beam_width,
                        max_decode_len=max_decode_len_sub,
                        length_penalty=length_penalty,
                        task_index=1)

        # Hypotheses come back in `perm_idx` order; align references to them
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(batch_size):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is seperated by space('_')
            else:
                # Convert from list of index to string
                str_ref = idx2word(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = idx2word(best_hyps[b])
            if dataset.label_type == 'word':
                str_hyp = re.sub(r'(.*)_>(.*)', r'\1', str_hyp)
            else:
                str_hyp = re.sub(r'(.*)>(.*)', r'\1', str_hyp)
            # NOTE: Trancate by the first <EOS>

            ##############################
            # Resolving UNK
            ##############################
            if resolving_unk and 'OOV' in str_hyp:
                str_hyp = resolve_unk(str_hyp, best_hyps_sub[b], aw[b],
                                      aw_sub[b], idx2char)
                str_hyp = str_hyp.replace('*', '')

            ##############################
            # Post-proccessing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[@>]+', '', str_ref)
            str_hyp = re.sub(r'[@>]+', '', str_hyp)
            # NOTE: @ means noise

            # Remove consecutive spaces
            str_ref = re.sub(r'[_]+', '_', str_ref)
            str_hyp = re.sub(r'[_]+', '_', str_hyp)

            # Compute WER (best-effort: an utterance that cannot be scored
            # is excluded from both the error counts and the word total)
            try:
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=str_hyp.split('_'),
                    normalize=False)
            except Exception:
                pass
            else:
                wer += wer_b
                sub += sub_b
                ins += ins_b
                dele += del_b
                num_words += len(str_ref.split('_'))

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Avoid ZeroDivisionError when nothing could be scored (all counts are 0)
    num_words = max(num_words, 1)
    wer /= num_words
    sub /= num_words
    ins /= num_words
    dele /= num_words

    df_wer = pd.DataFrame(
        {
            'SUB': [sub * 100],
            'INS': [ins * 100],
            'DEL': [dele * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['WER'])

    return wer, df_wer
    def check(self, label_type, label_type_sub,
              data_type='dev', data_size='300h', backend='pytorch',
              shuffle=False, sort_utt=True, sort_stop_epoch=None,
              frame_stacking=False, splice=1, num_gpus=1):
        """Load mini-batches and print the first utterance of each batch.

        A manual smoke test for `Dataset`. It prints the configuration and
        iterates over the data. For the first utterance of every mini-batch
        it prints the word and character references together with their
        lengths. Iteration stops once `dataset.epoch_detail` reaches 1.

        Args:
            label_type (string): label type of the main (word) task
            label_type_sub (string): label type of the sub (character) task
            data_type (string, optional): data split to load
            data_size (string, optional): data size passed to `Dataset`
            backend (string, optional): backend passed to `Dataset`
            shuffle (bool, optional): passed to `Dataset`
            sort_utt (bool, optional): passed to `Dataset`
            sort_stop_epoch (int, optional): passed to `Dataset`
            frame_stacking (bool, optional): if True, use num_stack=3 and
                num_skip=3
            splice (int, optional): passed to `Dataset`
            num_gpus (int, optional): passed to `Dataset`
        Raises:
            ValueError: on the pytorch backend's train split, when the padded
                inputs (`xs.shape[1]`) are shorter than the padded labels
                (`ys.shape[1]`)
        """
        print('========================================')
        print('  backend: %s' % backend)
        print('  label_type: %s' % label_type)
        print('  label_type_sub: %s' % label_type_sub)
        print('  data_type: %s' % data_type)
        print('  data_size: %s' % data_size)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print('  num_gpus: %d' % num_gpus)
        print('========================================')

        num_stack = 3 if frame_stacking else 1
        num_skip = 3 if frame_stacking else 1
        dataset = Dataset(
            # data_save_path='/n/sd8/inaguma/corpus/swbd/kaldi/' + data_size,
            data_save_path='/n/sd8/inaguma/corpus/swbd/kaldi',
            backend=backend,
            input_freq=40, use_delta=True, use_double_delta=True,
            data_type=data_type, data_size=data_size,
            label_type=label_type, label_type_sub=label_type_sub,
            batch_size=64, max_epoch=1, splice=splice,
            num_stack=num_stack, num_skip=num_skip,
            shuffle=shuffle,
            sort_utt=sort_utt, reverse=True, sort_stop_epoch=sort_stop_epoch,
            num_gpus=num_gpus, tool='htk',
            num_enque=None)

        print('=> Loading mini-batch...')
        idx2word = Idx2word(dataset.vocab_file_path)
        idx2char = Idx2char(dataset.vocab_file_path_sub)

        for batch, is_new_epoch in dataset:
            if data_type == 'train' and backend == 'pytorch':
                # The check compares padded batch-wide shapes, so running it
                # once per batch is enough (skipped for an empty batch).
                if (len(batch['xs']) > 0 and
                        batch['xs'].shape[1] < batch['ys'].shape[1]):
                    raise ValueError(
                        'input length must be longer than label length.')

            if dataset.is_test:
                # Test references are raw transcripts: normalize for display
                str_ref = batch['ys'][0][0]
                str_ref = str_ref.lower()
                str_ref = str_ref.replace('(', '').replace(')', '')

                str_ref_sub = batch['ys_sub'][0][0]
                str_ref_sub = str_ref_sub.lower()
                str_ref_sub = str_ref_sub.replace('(', '').replace(')', '')
            else:
                str_ref = idx2word(batch['ys'][0][:batch['y_lens'][0]])
                str_ref_sub = idx2char(
                    batch['ys_sub'][0][:batch['y_lens_sub'][0]])

            print('----- %s (epoch: %.3f, batch: %d) -----' %
                  (batch['input_names'][0], dataset.epoch_detail, len(batch['xs'])))
            print('=' * 20)
            print(str_ref)
            print('-' * 10)
            print(str_ref_sub)
            print('x_lens: %d' % (batch['x_lens'][0] * num_stack))
            if not dataset.is_test:
                print('y_lens (word): %d' % batch['y_lens'][0])
                print('y_lens_sub (char): %d' % batch['y_lens_sub'][0])

            if dataset.epoch_detail >= 1:
                break
def do_eval_cer(session,
                decode_ops,
                model,
                dataset,
                label_type,
                is_test=False,
                eval_batch_size=None,
                progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Character Error Rate.

    Args:
        session: session of training model
        decode_ops: list of operations for decoding (one per device)
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        cer_mean (float): An average of CER over the evaluated utterances
        wer_mean (float): An average of WER over the evaluated utterances
    Raises:
        TypeError: if `label_type` is not supported
    """
    assert isinstance(decode_ops, list), "decode_ops must be a list."

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Validate before touching the dataset's batch size
    if label_type == 'character':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        idx2char = Idx2char(
            map_file_path=
            '../metrics/mapping_files/character_capital_divide.txt',
            capital_divide=True,
            space_mark='_')
    else:
        raise TypeError('label_type must be "character" or '
                        '"character_capital_divide", got %r' % (label_type,))

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    cer_mean, wer_mean = 0, 0
    skip_data_num = 0
    pbar = tqdm(total=len(dataset)) if progressbar else None
    try:
        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini batch
            if is_multitask:
                inputs, _, labels_true, inputs_seq_len, _ = data
            else:
                inputs, labels_true, inputs_seq_len, _ = data

            feed_dict = {}
            for i_device in range(len(decode_ops)):
                feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
                feed_dict[model.inputs_seq_len_pl_list[i_device]] = \
                    inputs_seq_len[i_device]
                feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

            labels_pred_st_list = session.run(decode_ops, feed_dict=feed_dict)
            for i_device, labels_pred_st in enumerate(labels_pred_st_list):
                batch_size_device = len(inputs[i_device])
                try:
                    labels_pred = sparsetensor2list(labels_pred_st,
                                                    batch_size_device)
                except IndexError:
                    # Presumably raised when the decoder emits no labels.
                    # The whole device batch is skipped before anything is
                    # accumulated, so the averages stay consistent.
                    print('skipped')
                    skip_data_num += batch_size_device
                    # TODO: Conduct decoding again with batch size 1
                    if pbar is not None:
                        pbar.update(batch_size_device)
                    continue

                for i_batch in range(batch_size_device):

                    # Convert from list of index to string
                    if is_test:
                        str_true = labels_true[i_device][i_batch][0]
                        # NOTE: transcript is seperated by space('_')
                    else:
                        str_true = idx2char(labels_true[i_device][i_batch],
                                            padded_value=dataset.padded_value)
                    str_pred = idx2char(labels_pred[i_batch])

                    # Remove consecutive spaces
                    str_pred = re.sub(r'[_]+', '_', str_pred)

                    # Remove garbage labels
                    str_true = re.sub(r'[\']+', '', str_true)
                    str_pred = re.sub(r'[\']+', '', str_pred)

                    # Compute WER
                    wer_mean += compute_wer(ref=str_true.split('_'),
                                            hyp=str_pred.split('_'),
                                            normalize=True)

                    # Remove spaces
                    str_true = re.sub(r'[_]+', '', str_true)
                    str_pred = re.sub(r'[_]+', '', str_pred)

                    # Compute CER
                    cer_mean += compute_cer(str_pred=str_pred,
                                            str_true=str_true,
                                            normalize=True)

                    if pbar is not None:
                        pbar.update(1)

            if is_new_epoch:
                break
    finally:
        if pbar is not None:
            pbar.close()
        # Register original batch size (also when decoding fails)
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original

    num_evaluated = len(dataset) - skip_data_num
    # TODO: Fix this
    if num_evaluated > 0:
        cer_mean /= num_evaluated
        wer_mean /= num_evaluated

    return cer_mean, wer_mean
    def check_loading(self, label_type, data_type='dev_clean',
                      shuffle=False, sort_utt=False, sort_stop_epoch=None,
                      frame_stacking=False, splice=1, num_gpu=1):
        """Load mini-batches and print the first reference of each batch.

        A manual smoke test for `Dataset`. Iteration stops once
        `dataset.epoch_detail` reaches 0.05.

        Args:
            label_type (string): 'character', 'character_capital_divide'
                or 'word'
            data_type (string, optional): data split to load
            shuffle (bool, optional): passed to `Dataset`
            sort_utt (bool, optional): passed to `Dataset`
            sort_stop_epoch (int, optional): passed to `Dataset`
            frame_stacking (bool, optional): if True, use num_stack=3 and
                num_skip=3
            splice (int, optional): passed to `Dataset`
            num_gpu (int, optional): passed to `Dataset`
        Raises:
            ValueError: if `label_type` is unsupported, or if an utterance
                of the first checked batch has fewer inputs than labels
        """
        print('========================================')
        print('  label_type: %s' % label_type)
        print('  data_type: %s' % data_type)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print('  num_gpu: %d' % num_gpu)
        print('========================================')

        num_stack = 3 if frame_stacking else 1
        num_skip = 3 if frame_stacking else 1
        dataset = Dataset(
            data_type=data_type, train_data_size='train_clean100',
            label_type=label_type,
            batch_size=64, max_epoch=1, splice=splice,
            num_stack=num_stack, num_skip=num_skip,
            shuffle=shuffle, sort_utt=sort_utt, sort_stop_epoch=sort_stop_epoch,
            progressbar=True, num_gpu=num_gpu)

        print('=> Loading mini-batch...')
        if label_type == 'character':
            map_file_path = '../../metrics/mapping_files/ctc/character.txt'
        elif label_type == 'character_capital_divide':
            map_file_path = '../../metrics/mapping_files/ctc/character_capital.txt'
        elif label_type == 'word':
            map_file_path = '../../metrics/mapping_files/ctc/word_' + \
                dataset.train_data_size + '.txt'
        else:
            # Otherwise `map_file_path` would be unbound below
            raise ValueError('Unsupported label_type: %s' % label_type)

        idx2char = Idx2char(map_file_path)
        idx2word = Idx2word(map_file_path)

        for data, is_new_epoch in dataset:
            inputs, labels, inputs_seq_len, input_names = data

            if not self.length_check:
                # Only the first device's batch is checked, and only once
                # per instance (tracked by `self.length_check`).
                for input_seq, label_seq in zip(inputs[0], labels[0]):
                    if len(input_seq) < len(label_seq):
                        raise ValueError(
                            'input length must be longer than label length.')
                self.length_check = True

            if num_gpu > 1:
                for inputs_gpu in inputs:
                    print(inputs_gpu.shape)

            if label_type == 'word':
                if 'test' not in data_type:
                    str_true = ' '.join(idx2word(labels[0][0]))
                else:
                    # `== None` (not `is None`) so the comparison is
                    # elementwise; assumes labels[0][0] is a numpy object
                    # array of words padded with None — TODO confirm
                    word_list = np.delete(labels[0][0], np.where(
                        labels[0][0] == None), axis=0)
                    str_true = ' '.join(word_list)
            else:
                str_true = idx2char(labels[0][0])
            str_true = re.sub(r'_', ' ', str_true)
            print('----- %s (epoch: %.3f) -----' %
                  (input_names[0][0], dataset.epoch_detail))
            print(inputs[0].shape)
            print(str_true)

            if dataset.epoch_detail >= 0.05:
                break
def decode(session,
           decode_op,
           model,
           dataset,
           label_type,
           ss_type,
           is_test=False,
           eval_batch_size=None,
           save_path=None):
    """Visualize label outputs of CTC model.

    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): kana
        ss_type (string): remove or insert_left or insert_both or insert_right
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string, optional): if not None, the output is written to
            `decode.txt` under `model.model_dir` instead of the console.
            NOTE: the value itself is not used as a directory.
    """
    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    idx2char = Idx2char(map_file_path='../metrics/mapping_files/' +
                        label_type + '_' + ss_type + '.txt')

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Redirect printing only for the duration of this call
    stdout_original = sys.stdout
    log_file = None
    if save_path is not None:
        log_file = open(join(model.model_dir, 'decode.txt'), 'w')
        sys.stdout = log_file

    try:
        for data, is_new_epoch in dataset:

            # Create feed dictionary for next mini batch
            inputs, labels_true, inputs_seq_len, input_names = data

            feed_dict = {
                model.inputs_pl_list[0]: inputs[0],
                model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
                model.keep_prob_pl_list[0]: 1.0
            }

            batch_size = inputs[0].shape[0]
            labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
            try:
                labels_pred = sparsetensor2list(labels_pred_st,
                                                batch_size=batch_size)
            except IndexError:
                # no output: one empty hypothesis per utterance, so the loop
                # below can index every position
                labels_pred = [''] * batch_size

            for i_batch in range(batch_size):
                print('----- wav: %s -----' % input_names[0][i_batch])
                if is_test:
                    str_true = labels_true[0][i_batch][0]
                else:
                    str_true = idx2char(labels_true[0][i_batch])
                str_pred = idx2char(labels_pred[i_batch])

                print('Ref: %s' % str_true)
                print('Hyp: %s' % str_pred)

            if is_new_epoch:
                break
    finally:
        sys.stdout = stdout_original
        if log_file is not None:
            log_file.close()
        # Register original batch size
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original
def decode(model, dataset, beam_width, beam_width_sub,
           eval_batch_size=None, save_path=None, resolving_unk=False):
    """Visualize label outputs.

    Decodes each mini-batch with the main (word) task and the sub
    (character) task. References and hypotheses are normalized with
    `fix_trans` and the GLM file, and each utterance is printed with its WERs.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam in the main task
        beam_width_sub (int): the size of beam in the sub task
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string, optional): if not None, the output is written to
            `decode.txt` under `model.model_dir` instead of the console.
            NOTE: the value itself is not used as a directory.
        resolving_unk (bool, optional): if True, hypotheses containing 'OOV'
            are also printed and scored after `resolve_unk`
    """
    import re

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    idx2word = Idx2word(vocab_file_path=dataset.vocab_file_path)
    idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path_sub,
                        capital_divide=dataset.label_type_sub == 'character_capital_divide')

    # Read GLM file
    glm = GLM(
        glm_path='/n/sd8/inaguma/corpus/swbd/data/eval2000/LDC2002T43/reference/en20000405_hub5.glm')

    # Redirect printing only for the duration of this call
    stdout_original = sys.stdout
    log_file = None
    if save_path is not None:
        log_file = open(join(model.model_dir, 'decode.txt'), 'w')
        sys.stdout = log_file

    try:
        for batch, is_new_epoch in dataset:

            # Decode
            if model.model_type == 'nested_attention':
                best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    beam_width_sub=beam_width_sub,
                    max_decode_len=MAX_DECODE_LEN_WORD,
                    max_decode_len_sub=MAX_DECODE_LEN_CHAR)
            else:
                best_hyps, aw, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=MAX_DECODE_LEN_WORD)
                best_hyps_sub, aw_sub, _ = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width_sub,
                    max_decode_len=MAX_DECODE_LEN_CHAR,
                    task_index=1)

            # Hypotheses come back in `perm_idx` order; align references
            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]
            ys_sub = batch['ys_sub'][perm_idx]
            y_lens_sub = batch['y_lens_sub'][perm_idx]

            for b in range(len(batch['xs'])):

                ##############################
                # Reference
                ##############################
                if dataset.is_test:
                    str_ref_original = ys[b][0]
                    str_ref_sub = ys_sub[b][0]
                    # NOTE: transcript is seperated by space('_')
                else:
                    # Convert from list of index to string
                    str_ref_original = idx2word(ys[b][:y_lens[b]])
                    # The sub task uses the character vocabulary
                    str_ref_sub = idx2char(ys_sub[b][:y_lens_sub[b]])

                ##############################
                # Hypothesis
                ##############################
                # Convert from list of index to string
                str_hyp = idx2word(best_hyps[b])
                str_hyp_sub = idx2char(best_hyps_sub[b])

                ##############################
                # Resolving UNK
                ##############################
                # Decide once, before post-processing can alter `str_hyp`,
                # so `str_hyp_no_unk` is used only when it was computed for
                # this utterance
                resolve = resolving_unk and 'OOV' in str_hyp
                if resolve:
                    str_hyp_no_unk = resolve_unk(
                        str_hyp, best_hyps_sub[b], aw[b], aw_sub[b], idx2char)

                ##############################
                # Post-proccessing
                ##############################
                str_ref = fix_trans(str_ref_original, glm)
                str_ref_sub = fix_trans(str_ref_sub, glm)
                str_hyp = fix_trans(str_hyp, glm)
                str_hyp_sub = fix_trans(str_hyp_sub, glm)
                if resolve:
                    str_hyp_no_unk = fix_trans(str_hyp_no_unk, glm)

                if len(str_ref) == 0:
                    continue

                print('----- wav: %s -----' % batch['input_names'][b])
                print('Ref         : %s' % str_ref.replace('_', ' '))
                print('Hyp (main)  : %s' % str_hyp.replace('_', ' '))
                print('Hyp (sub)   : %s' % str_hyp_sub.replace('_', ' '))
                if resolve:
                    print('Hyp (no UNK): %s' % str_hyp_no_unk.replace('_', ' '))

                try:
                    # Compute WER; truncate hypotheses at the first EOS ('>')
                    # with a regex (`str.replace` would match literally)
                    wer, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=re.sub(r'_>.*', '', str_hyp).split('_'),
                        normalize=True)
                    print('WER (main)  : %.3f %%' % (wer * 100))
                    wer_sub, _, _, _ = compute_wer(
                        ref=str_ref_sub.split('_'),
                        hyp=re.sub(r'>.*', '', str_hyp_sub).split('_'),
                        normalize=True)
                    print('WER (sub)   : %.3f %%' % (wer_sub * 100))
                    if resolve:
                        wer_no_unk, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=re.sub(r'_>.*', '',
                                       str_hyp_no_unk.replace('*', '')).split('_'),
                            normalize=True)
                        print('WER (no UNK): %.3f %%' % (wer_no_unk * 100))
                except Exception:
                    print('--- skipped ---')

            if is_new_epoch:
                break
    finally:
        sys.stdout = stdout_original
        if log_file is not None:
            log_file.close()
def do_eval_cer(models,
                model_type,
                dataset,
                label_type,
                beam_width,
                max_decode_len,
                eval_batch_size=None,
                temperature=1,
                progressbar=False):
    """Evaluate trained models by Character Error Rate.

    The posteriors of all models are averaged (ensemble) and then decoded.
    WER and CER are accumulated as edit counts and normalized by the total
    number of reference words / characters.

    Args:
        models (list): the model to evaluate
        model_type (string): ctc or attention or hierarchical_ctc or
            hierarchical_attention
        dataset: An instance of a `Dataset' class
        label_type (string): character or character_capital_divide
            (NOTE: the mapping is actually built from `dataset.label_type`)
        beam_width: (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction when EOS token have not been emitted.
            This is used for seq2seq models.
        eval_batch_size (int, optional): the batch size when evaluating the model
        temperature (int, optional): forwarded to `model.posteriors`
        progressbar (bool, optional): if True, visualize the progressbar
    Returns:
        cer (float): Character error rate
        wer (float): Word error rate
        df_wer_cer (pd.DataFrame): dataframe of substitution, insertion, and deletion
    Raises:
        NotImplementedError: for hierarchical model types
        ValueError: for any other unsupported `model_type`
    """
    # Validate up front; otherwise `best_hyps` would be undefined below
    if model_type in ['hierarchical_attention', 'hierarchical_ctc']:
        raise NotImplementedError
    if model_type not in ['attention', 'ctc']:
        raise ValueError('Unsupported model_type: %s' % model_type)

    # Reset data counter
    dataset.reset()

    idx2char = Idx2char(
        vocab_file_path=dataset.vocab_file_path,
        capital_divide=(dataset.label_type == 'character_capital_divide'))

    cer, wer = 0, 0
    sub_char, ins_char, del_char = 0, 0, 0
    sub_word, ins_word, del_word = 0, 0, 0
    num_words, num_chars = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        # Decode the ensemble: average the posteriors of all models
        for i, model in enumerate(models):
            probs_i, perm_idx = model.posteriors(batch['xs'],
                                                 batch['x_lens'],
                                                 temperature=temperature)
            if i == 0:
                probs = probs_i
            else:
                probs += probs_i
            # NOTE: probs: `[1 (B), T, num_classes]`
        probs /= len(models)

        # NOTE(review): decoding and `perm_idx` come from the last model;
        # this assumes all models order the batch identically — confirm
        best_hyps = model.decode_from_probs(probs,
                                            batch['x_lens'][perm_idx],
                                            beam_width=beam_width,
                                            max_decode_len=max_decode_len)
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):

            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is seperated by space('_')
            else:
                # Convert from list of index to string
                str_ref = idx2char(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = idx2char(best_hyps[b])
            if 'attention' in model.model_type:
                str_hyp = str_hyp.split('>')[0]
                # NOTE: Trancate by the first <EOS>

                # Remove the last space
                if len(str_hyp) > 0 and str_hyp[-1] == '_':
                    str_hyp = str_hyp[:-1]

            # Remove consecutive spaces
            str_hyp = re.sub(r'[_]+', '_', str_hyp)

            ##############################
            # Post-proccessing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[\'>]+', '', str_ref)
            str_hyp = re.sub(r'[\'>]+', '', str_hyp)
            # TODO: is it OK to remove these before computing WER?

            # Compute WER
            wer_b, sub_b, ins_b, del_b = compute_wer(ref=str_ref.split('_'),
                                                     hyp=str_hyp.split('_'),
                                                     normalize=False)
            wer += wer_b
            sub_word += sub_b
            ins_word += ins_b
            del_word += del_b
            num_words += len(str_ref.split('_'))

            # Compute CER
            cer_b, sub_b, ins_b, del_b = compute_wer(
                ref=list(str_ref.replace('_', '')),
                hyp=list(str_hyp.replace('_', '')),
                normalize=False)
            cer += cer_b
            sub_char += sub_b
            ins_char += ins_b
            del_char += del_b
            num_chars += len(str_ref.replace('_', ''))

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Avoid ZeroDivisionError on empty references (all counts are then 0)
    num_words = max(num_words, 1)
    num_chars = max(num_chars, 1)

    wer /= num_words
    cer /= num_chars
    sub_char /= num_chars
    ins_char /= num_chars
    del_char /= num_chars
    sub_word /= num_words
    ins_word /= num_words
    del_word /= num_words

    df_wer_cer = pd.DataFrame(
        {
            'SUB': [sub_char * 100, sub_word * 100],
            'INS': [ins_char * 100, ins_word * 100],
            'DEL': [del_char * 100, del_word * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['CER', 'WER'])

    return cer, wer, df_wer_cer
def decode(model,
           dataset,
           eval_batch_size,
           beam_width,
           length_penalty,
           save_path=None):
    """Decode a dataset and print references, hypotheses and error rates.

    For every utterance, the reference and the hypothesis are printed. When an
    attention model has a CTC branch (``ctc_loss_weight > 0``), the CTC
    hypothesis is printed as well. A per-utterance WER is printed for the
    'word' / 'kanji_wb' label types; a CER is printed for all other label types.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model.
            If None, the dataset's current batch size is kept.
        beam_width: (int): the size of beam
        length_penalty (float): currently unused; it is not forwarded to
            `model.decode` and is kept only for interface compatibility
        save_path (string): if not None, everything printed by this function
            is written to `decode.txt` under `model.model_dir`. Stdout is
            restored and the file closed when the function returns.
            NOTE(review): the file goes to `model.model_dir`, not `save_path`.
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    if 'word' in dataset.label_type:
        map_fn = Idx2word(dataset.vocab_file_path)
        max_decode_len = MAX_DECODE_LEN_WORD
    else:
        map_fn = Idx2char(dataset.vocab_file_path)
        max_decode_len = MAX_DECODE_LEN_CHAR

    # Whether the attention model also has a CTC branch to decode with
    use_ctc_branch = (model.model_type == 'attention' and
                      model.ctc_loss_weight > 0)
    is_word_level = dataset.label_type in ('word', 'kanji_wb')

    # Redirect printed results to a file; undone in the `finally` below so
    # the caller's stdout is not left pointing at a closed/leaked file.
    stdout_orig = sys.stdout
    log_file = None
    if save_path is not None:
        log_file = open(join(model.model_dir, 'decode.txt'), 'w')
        sys.stdout = log_file

    try:
        for batch, is_new_epoch in dataset:

            # Decode
            if model.model_type == 'nested_attention':
                best_hyps, _, _, _, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len,
                    max_decode_len_sub=max_decode_len)
            else:
                best_hyps, _, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len)

            if use_ctc_branch:
                best_hyps_ctc, perm_idx = model.decode_ctc(
                    batch['xs'], batch['x_lens'], beam_width=beam_width)

            # Reorder references to match the order of the hypotheses
            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]

            for b in range(len(batch['xs'])):
                ##############################
                # Reference
                ##############################
                if dataset.is_test:
                    # NOTE: transcript is separated by space ('_')
                    str_ref = ys[b][0]
                else:
                    # Convert from list of index to string
                    str_ref = map_fn(ys[b][:y_lens[b]])

                ##############################
                # Hypothesis
                ##############################
                # Convert from list of index to string
                str_hyp = map_fn(best_hyps[b])

                print('----- wav: %s -----' % batch['input_names'][b])
                print('Ref: %s' % str_ref.replace('_', ' '))
                print('Hyp: %s' % str_hyp.replace('_', ' '))
                if use_ctc_branch:
                    str_hyp_ctc = map_fn(best_hyps_ctc[b])
                    print('Hyp (CTC): %s' % str_hyp_ctc.replace('_', ' '))

                # Truncate by the FIRST '>' and drop trailing spaces ('_').
                # Otherwise split('_') would yield an empty trailing "word".
                str_hyp_trunc = str_hyp.split('>', 1)[0].rstrip('_')

                try:
                    if is_word_level:
                        wer, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=str_hyp_trunc.split('_'),
                            normalize=True)
                        print('WER: %.3f %%' % (wer * 100))
                        if use_ctc_branch:
                            wer_ctc, _, _, _ = compute_wer(
                                ref=str_ref.split('_'),
                                hyp=str_hyp_ctc.split('_'),
                                normalize=True)
                            print('WER (CTC): %.3f %%' % (wer_ctc * 100))
                    else:
                        cer, _, _, _ = compute_wer(
                            ref=list(str_ref.replace('_', '')),
                            hyp=list(str_hyp_trunc.replace('_', '')),
                            normalize=True)
                        print('CER: %.3f %%' % (cer * 100))
                        if use_ctc_branch:
                            cer_ctc, _, _, _ = compute_wer(
                                ref=list(str_ref.replace('_', '')),
                                hyp=list(str_hyp_ctc.replace('_', '')),
                                normalize=True)
                            print('CER (CTC): %.3f %%' % (cer_ctc * 100))
                except Exception:
                    # compute_wer may fail on degenerate utterances (presumably
                    # empty references); skip them rather than abort decoding.
                    print('--- skipped ---')

            if is_new_epoch:
                break
    finally:
        if log_file is not None:
            sys.stdout = stdout_orig
            log_file.close()
def decode(model, dataset, beam_width, eval_batch_size=None, save_path=None,
           glm_path=None):
    """Decode a dataset and print references, hypotheses and per-utterance WER.

    The reference and the hypothesis are both normalized with `fix_trans` using
    the GLM mapping before printing and scoring. Utterances whose normalized
    reference is empty are skipped. When an attention model has a CTC branch
    (``ctc_loss_weight > 0``), the CTC hypothesis is printed and scored too.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width: (int): the size of beam
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string): if not None, everything printed by this function
            is written to `decode.txt` under `model.model_dir`. Stdout is
            restored and the file closed when the function returns.
            NOTE(review): the file goes to `model.model_dir`, not `save_path`.
        glm_path (string, optional): path to the GLM file passed to `GLM`.
            Defaults to the eval2000 hub5 GLM path previously hard-coded here.
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    if 'char' in dataset.label_type:
        map_fn = Idx2char(
            dataset.vocab_file_path,
            capital_divide=dataset.label_type == 'character_capital_divide')
        max_decode_len = MAX_DECODE_LEN_CHAR
    else:
        map_fn = Idx2word(dataset.vocab_file_path)
        max_decode_len = MAX_DECODE_LEN_WORD

    # Read GLM file
    if glm_path is None:
        glm_path = ('/n/sd8/inaguma/corpus/swbd/data/eval2000/LDC2002T43/'
                    'reference/en20000405_hub5.glm')
    glm = GLM(glm_path=glm_path)

    # Whether the attention model also has a CTC branch to decode with.
    # The same condition guards both decoding and scoring of the CTC output.
    use_ctc_branch = (model.model_type == 'attention' and
                      model.ctc_loss_weight > 0)

    # Redirect printed results to a file; undone in the `finally` below so
    # the caller's stdout is not left pointing at a closed/leaked file.
    stdout_orig = sys.stdout
    log_file = None
    if save_path is not None:
        log_file = open(join(model.model_dir, 'decode.txt'), 'w')
        sys.stdout = log_file

    try:
        for batch, is_new_epoch in dataset:

            # Decode
            if model.model_type == 'nested_attention':
                best_hyps, _, _, _, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len,
                    max_decode_len_sub=max_decode_len)
            else:
                best_hyps, _, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len)

            if use_ctc_branch:
                best_hyps_ctc, perm_idx = model.decode_ctc(
                    batch['xs'], batch['x_lens'], beam_width=beam_width)

            # Reorder references to match the order of the hypotheses
            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]

            for b in range(len(batch['xs'])):
                ##############################
                # Reference
                ##############################
                if dataset.is_test:
                    # NOTE: transcript is separated by space ('_')
                    str_ref_original = ys[b][0]
                else:
                    # Convert from list of index to string
                    str_ref_original = map_fn(ys[b][:y_lens[b]])

                ##############################
                # Hypothesis
                ##############################
                # Convert from list of index to string
                str_hyp = map_fn(best_hyps[b])

                ##############################
                # Post-processing
                ##############################
                str_ref = fix_trans(str_ref_original, glm)
                str_hyp = fix_trans(str_hyp, glm)

                if not str_ref:
                    continue

                print('----- wav: %s -----' % batch['input_names'][b])
                print('Ref: %s' % str_ref.replace('_', ' '))
                print('Hyp: %s' % str_hyp.replace('_', ' '))
                if use_ctc_branch:
                    # Normalize like the attention hypothesis so both are
                    # scored against the same (fix_trans'd) reference.
                    str_hyp_ctc = fix_trans(map_fn(best_hyps_ctc[b]), glm)
                    print('Hyp (CTC): %s' % str_hyp_ctc.replace('_', ' '))

                # Truncate by the FIRST '>' and drop trailing spaces ('_').
                # (str.replace is literal, so a regex-like pattern cannot do this.)
                str_hyp_trunc = str_hyp.split('>', 1)[0].rstrip('_')

                try:
                    # Compute WER
                    wer, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=str_hyp_trunc.split('_'),
                        normalize=True)
                    print('WER: %.3f %%' % (wer * 100))
                    if use_ctc_branch:
                        wer_ctc, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=str_hyp_ctc.split('_'),
                            normalize=True)
                        print('WER (CTC): %.3f %%' % (wer_ctc * 100))
                except Exception:
                    # compute_wer may fail on degenerate utterances; skip them
                    # rather than abort decoding.
                    print('--- skipped ---')

            if is_new_epoch:
                break
    finally:
        if log_file is not None:
            sys.stdout = stdout_orig
            log_file.close()