Code example #1

def plot(model, dataset, beam_width,
         eval_batch_size=None, save_path=None):
    """Visualize attention weights of attetnion-based model.
    Args:
        model: model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string, optional): path to save attention weights plotting
    """
    # Clean directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    if 'char' in dataset.label_type:
        map_fn = Idx2char(dataset.vocab_file_path,
                          capital_divide=dataset.label_type == 'character_capital_divide',
                          return_list=True)
        max_decode_len = MAX_DECODE_LEN_CHAR
    else:
        map_fn = Idx2word(dataset.vocab_file_path, return_list=True)
        max_decode_len = MAX_DECODE_LEN_WORD

    for batch, is_new_epoch in dataset:

        # Decode
        best_hyps, aw, perm_idx = model.attention_weights(
            batch['xs'], batch['x_lens'],
            beam_width=beam_width,
            max_decode_len=max_decode_len)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = map_fn(ys[b][:y_lens[b]])

            token_list = map_fn(best_hyps[b])

            speaker = '_'.join(batch['input_names'][b].split('_')[:2])
            plot_attention_weights(
                aw[b, :len(token_list), :batch['x_lens'][b]],
                label_list=token_list,
                spectrogram=batch['xs'][b, :, :dataset.input_freq],
                str_ref=str_ref,
                save_path=mkdir_join(save_path, speaker,
                                     batch['input_names'][b] + '.png'),
                figsize=(20, 8))

        if is_new_epoch:
            break
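
The examples in this section lean on `Idx2word` / `Idx2char` converters imported from the surrounding codebase. As a rough orientation only, here is a minimal sketch of what such a mapper might look like, assuming a vocabulary file with one token per line; the real classes (e.g. the `capital_divide` handling) are more involved.

class Idx2wordSketch(object):
    """Minimal, hypothetical index-to-word mapper; not the codebase's class."""

    def __init__(self, vocab_file_path, return_list=False):
        # Assumption: one token per line, line number == index
        with open(vocab_file_path) as f:
            self._vocab = [line.strip() for line in f]
        self._return_list = return_list

    def __call__(self, indices):
        tokens = [self._vocab[i] for i in indices]
        if self._return_list:
            return tokens
        return '_'.join(tokens)  # '_' stands in for a space, as in the examples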

Code example #2

def plot(model, dataset, eval_batch_size=None, save_path=None):
    """Visualize CTC posteriors of the hierarchical model.
    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string, optional): path to save figures of CTC posteriors
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Clean directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    idx2word = Idx2word(dataset.vocab_file_path)
    idx2char = Idx2char(
        dataset.vocab_file_path,
        capital_divide=dataset.label_type_sub == 'character_capital_divide')

    for batch, is_new_epoch in dataset:

        # Get CTC probs
        probs = model.posteriors(batch['xs'], batch['x_lens'], temperature=1)
        probs_sub = model.posteriors(batch['xs'],
                                     batch['x_lens'],
                                     is_sub_task=True,
                                     temperature=1)
        # NOTE: probs: '[B, T, num_classes]'
        # NOTE: probs_sub: '[B, T, num_classes_sub]'

        # Decode
        best_hyps = model.decode(batch['xs'], batch['x_lens'], beam_width=1)
        best_hyps_sub = model.decode(batch['xs'],
                                     batch['x_lens'],
                                     beam_width=1,
                                     is_sub_task=True)

        # Visualize
        for b in range(len(batch['xs'])):

            # Convert from list of index to string
            str_hyp = idx2word(best_hyps[b])
            str_hyp_sub = idx2char(best_hyps_sub[b])

            speaker = batch['input_names'][b].split('_')[0]
            plot_hierarchical_ctc_probs(probs[b, :batch['x_lens'][b], :],
                                        probs_sub[b, :batch['x_lens'][b], :],
                                        frame_num=batch['x_lens'][b],
                                        num_stack=dataset.num_stack,
                                        str_hyp=str_hyp,
                                        str_hyp_sub=str_hyp_sub,
                                        save_path=mkdir_join(
                                            save_path, speaker,
                                            batch['input_names'][b] + '.png'))

        if is_new_epoch:
            break
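
`plot_hierarchical_ctc_probs` itself is defined elsewhere in the codebase. As a hedged stand-in, a basic CTC-posterior plot over a `[T, num_classes]` array could look like the following; all names here are illustrative, not the real plotting utility.

import numpy as np
import matplotlib.pyplot as plt


def plot_ctc_probs_sketch(probs, save_path=None):
    # probs: [T, num_classes] posteriors, as returned by model.posteriors()
    for c in range(probs.shape[1]):
        plt.plot(np.arange(probs.shape[0]), probs[:, c])
    plt.xlabel('frame')
    plt.ylabel('posterior')
    if save_path is not None:
        plt.savefig(save_path)
    plt.close()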

Code example #3

def plot(model, dataset, eval_batch_size=None, save_path=None,
         space_index=None):
    """
    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string): path to save figures of CTC posteriors
        space_index (int, optional):
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Clean directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    vocab_file_path = '../metrics/vocab_files/' + \
        dataset.label_type + '_' + dataset.data_size + '.txt'
    if dataset.label_type == 'character':
        map_fn = Idx2char(vocab_file_path)
    elif dataset.label_type == 'character_capital_divide':
        map_fn = Idx2char(vocab_file_path, capital_divide=True)
    else:
        map_fn = Idx2word(vocab_file_path)

    for batch, is_new_epoch in dataset:

        # Get CTC probs
        probs = model.posteriors(batch['xs'], batch['x_lens'], temperature=1)
        # NOTE: probs: '[B, T, num_classes]'

        # Decode
        best_hyps, _ = model.decode(batch['xs'], batch['x_lens'], beam_width=1)

        # Visualize
        for b in range(len(batch['xs'])):

            # Convert from list of index to string
            str_pred = map_fn(best_hyps[b])

            speaker, book = batch['input_names'][b].split('-')[:2]
            plot_ctc_probs(
                probs[b, :batch['x_lens'][b], :],
                frame_num=batch['x_lens'][b],
                num_stack=dataset.num_stack,
                space_index=space_index,
                str_pred=str_pred,
                save_path=mkdir_join(save_path, speaker, book, batch['input_names'][b] + '.png'))

        if is_new_epoch:
            break
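
The `temperature=1` argument to `model.posteriors` presumably smooths the output distribution before plotting. For reference, temperature smoothing applied to raw logits is just a scaled softmax; this is a sketch of the idea, not the model's actual code.

import numpy as np


def softmax_with_temperature(logits, temperature=1.0):
    z = logits / temperature               # temperature > 1 flattens the distribution
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)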

Code example #4

def plot(model, dataset, eval_batch_size, beam_width, beam_width_sub,
         length_penalty, save_path=None):
    """Visualize attention weights of Attetnion-based model.
    Args:
        model: model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model
        beam_width (int): the size of beam in the main task
        beam_width_sub (int): the size of beam in the sub task
        length_penalty (float):
        save_path (string, optional): path to save attention weights plotting
    """
    # Clean directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    map_fn_main = Idx2word(dataset.vocab_file_path, return_list=True)
    map_fn_sub = Idx2char(dataset.vocab_file_path_sub, return_list=True)

    for batch, is_new_epoch in dataset:

        # Decode
        best_hyps, aw, perm_idx = model.decode(
            batch['xs'], batch['x_lens'],
            beam_width=beam_width,
            max_decode_len=MAX_DECODE_LEN_WORD)
        best_hyps_sub, aw_sub, _ = model.decode(
            batch['xs'], batch['x_lens'],
            beam_width=beam_width_sub,
            max_decode_len=MAX_DECODE_LEN_CHAR,
            task_index=1)

        for b in range(len(batch['xs'])):

            word_list = map_fn_main(best_hyps[b])
            char_list = map_fn_sub(best_hyps_sub[b])

            speaker = batch['input_names'][b].split('_')[0]

            plot_hierarchical_attention_weights(
                aw[b][:len(word_list), :batch['x_lens'][b]],
                aw_sub[b][:len(char_list), :batch['x_lens'][b]],
                label_list=word_list,
                label_list_sub=char_list,
                spectrogram=batch['xs'][b, :, :dataset.input_freq],
                save_path=mkdir_join(save_path, speaker,
                                     batch['input_names'][b] + '.png'),
                figsize=(40, 8)
            )

        if is_new_epoch:
            break
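
Every plotting example saves figures via `mkdir_join`, imported from the codebase's utilities. A plausible minimal version, assuming the last component is a file name and everything before it should be created as directories, is sketched below; the real helper may behave differently.

import os


def mkdir_join_sketch(path, *names):
    # Join path components, creating directories for everything but the
    # final (file) component. This is a guess at the real helper's behavior.
    full_path = os.path.join(path, *names)
    dir_path = os.path.dirname(full_path)
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)
    return full_path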

Code example #5

    def __init__(self, data_save_path,
                 backend, input_freq, use_delta, use_double_delta,
                 data_type, data_size, label_type, label_type_sub,
                 batch_size, max_epoch=None, splice=1,
                 num_stack=1, num_skip=1,
                 min_frame_num=40,
                 shuffle=False, sort_utt=False, reverse=False,
                 sort_stop_epoch=None, num_gpus=1, tool='htk',
                 num_enque=None, dynamic_batching=False):
        """A class for loading dataset.
        Args:
            data_save_path (string): path to saved data
            backend (string): pytorch or chainer
            input_freq (int): the number of dimensions of acoustics
            use_delta (bool): if True, use the delta feature
            use_double_delta (bool): if True, use the acceleration feature
            data_type (string): train or dev_clean or dev_other or test_clean
                or test_other
            data_size (string): 100 or 460 or 960
            label_type (string): word
            label_type_sub (string): character or character_capital_divide
            batch_size (int): the size of mini-batch
            max_epoch (int): the max epoch. None means infinite loop.
            splice (int): frames to splice. Default is 1 frame.
            num_stack (int): the number of frames to stack
            num_skip (int): the number of frames to skip
            shuffle (bool): if True, shuffle utterances. This is
                disabled when sort_utt is True.
            sort_utt (bool): if True, sort all utterances in the
                ascending order
            reverse (bool): if True, sort utterances in the
                descending order
            sort_stop_epoch (int): After sort_stop_epoch, training
                will revert back to a random order
            num_gpus (int): the number of GPUs
            tool (string): htk or librosa or python_speech_features
            num_enque (int): the number of elements to enqueue
            dynamic_batching (bool): if True, batch size will be
                changed dynamically in training
        """
        self.backend = backend
        self.input_freq = input_freq
        self.use_delta = use_delta
        self.use_double_delta = use_double_delta
        self.data_type = data_type
        self.data_size = data_size
        self.label_type = label_type
        self.label_type_sub = label_type_sub
        self.batch_size = batch_size * num_gpus
        self.max_epoch = max_epoch
        self.splice = splice
        self.num_stack = num_stack
        self.num_skip = num_skip
        self.shuffle = shuffle
        self.sort_utt = sort_utt
        self.sort_stop_epoch = sort_stop_epoch
        self.num_gpus = num_gpus
        self.tool = tool
        self.num_enque = num_enque
        self.dynamic_batching = dynamic_batching
        self.is_test = True if 'test' in data_type else False

        self.vocab_file_path = join(
            data_save_path, 'vocab', data_size, label_type + '.txt')
        self.idx2word = Idx2word(self.vocab_file_path)
        self.word2idx = Word2idx(self.vocab_file_path)
        self.vocab_file_path_sub = join(
            data_save_path, 'vocab', data_size, label_type_sub + '.txt')
        self.idx2char = Idx2char(
            self.vocab_file_path_sub,
            capital_divide=label_type_sub == 'character_capital_divide')
        self.char2idx = Char2idx(
            self.vocab_file_path_sub,
            capital_divide=label_type_sub == 'character_capital_divide')

        super(Dataset, self).__init__(vocab_file_path=self.vocab_file_path,
                                      vocab_file_path_sub=self.vocab_file_path_sub)

        # Load dataset file
        dataset_path = join(
            data_save_path, 'dataset', tool, data_size, data_type, label_type + '.csv')
        dataset_path_sub = join(
            data_save_path, 'dataset', tool, data_size, data_type, label_type_sub + '.csv')
        df = pd.read_csv(dataset_path)
        df = df.loc[:, ['frame_num', 'input_path', 'transcript']]
        df_sub = pd.read_csv(dataset_path_sub)
        df_sub = df_sub.loc[:, ['frame_num', 'input_path', 'transcript']]

        # Remove inappropriate utterances
        if not self.is_test:
            logger.info('Original utterance num: %d' % len(df))
            df = df[df.apply(
                lambda x: min_frame_num <= x['frame_num'], axis=1)]
            logger.info('Restricted utterance num: %d' % len(df))

        # Sort paths to input & label
        if sort_utt:
            df = df.sort_values(by='frame_num', ascending=not reverse)
            df_sub = df_sub.sort_values(by='frame_num', ascending=not reverse)
        else:
            df = df.sort_values(by='input_path', ascending=True)
            df_sub = df_sub.sort_values(by='input_path', ascending=True)

        assert len(df) == len(df_sub)

        self.df = df
        self.df_sub = df_sub
        self.rest = set(list(df.index))
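
All of the evaluation loops iterate with `for batch, is_new_epoch in dataset`, so the base class presumably implements the iterator protocol and yields `(batch, is_new_epoch)` pairs. A minimal sketch of that contract follows; the real base class additionally handles sorting, enqueuing, and multi-GPU splitting.

class DatasetIterSketch(object):
    """Illustrative only: yields (batch, is_new_epoch) like the real Dataset."""

    def __init__(self, samples, batch_size):
        self.samples = samples
        self.batch_size = batch_size

    def __iter__(self):
        for start in range(0, len(self.samples), self.batch_size):
            batch = self.samples[start:start + self.batch_size]
            # The flag goes True on the batch that exhausts the epoch
            is_new_epoch = start + self.batch_size >= len(self.samples)
            yield batch, is_new_epoch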

Code example #6

def plot(model,
         dataset,
         beam_width,
         beam_width_sub,
         eval_batch_size=None,
         a2c_oracle=False,
         save_path=None):
    """Visualize attention weights of Attetnion-based model.
    Args:
        model: model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam in the main task
        beam_width_sub (int): the size of beam in the sub task
        eval_batch_size (int, optional): the batch size when evaluating the model
        a2c_oracle (bool, optional):
        save_path (string, optional): path to save attention weights plotting
    """
    # Clean directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    idx2word = Idx2word(dataset.vocab_file_path, return_list=True)
    idx2char = Idx2char(dataset.vocab_file_path_sub, return_list=True)

    for batch, is_new_epoch in dataset:
        batch_size = len(batch['xs'])

        if a2c_oracle:
            if dataset.is_test:
                max_label_num = 0
                for b in range(batch_size):
                    if max_label_num < len(list(batch['ys_sub'][b][0])):
                        max_label_num = len(list(batch['ys_sub'][b][0]))

                ys_sub = np.zeros((batch_size, max_label_num), dtype=np.int32)
                ys_sub -= 1  # pad with -1
                y_lens_sub = np.zeros((batch_size, ), dtype=np.int32)
                for b in range(batch_size):
                    indices = dataset.char2idx(batch['ys_sub'][b][0])
                    ys_sub[b, :len(indices)] = indices
                    y_lens_sub[b] = len(indices)
                    # NOTE: transcript is separated by space('_')
            else:
                ys_sub = batch['ys_sub']
                y_lens_sub = batch['y_lens_sub']
        else:
            ys_sub = None
            y_lens_sub = None

        best_hyps, best_hyps_sub, aw, aw_sub, aw_dec = model.attention_weights(
            batch['xs'],
            batch['x_lens'],
            beam_width=beam_width,
            beam_width_sub=beam_width_sub,
            max_decode_len=MAX_DECODE_LEN_WORD,
            max_decode_len_sub=MAX_DECODE_LEN_CHAR,
            teacher_forcing=a2c_oracle,
            ys_sub=ys_sub,
            y_lens_sub=y_lens_sub)

        for b in range(len(batch['xs'])):
            word_list = idx2word(best_hyps[b])
            if 'word' in dataset.label_type_sub:
                char_list = idx2word(best_hyps_sub[b])
            else:
                char_list = idx2char(best_hyps_sub[b])

            # word to acoustic & character to acoustic
            plot_hierarchical_attention_weights(
                aw[b][:len(word_list), :batch['x_lens'][b]],
                aw_sub[b][:len(char_list), :batch['x_lens'][b]],
                label_list=word_list,
                label_list_sub=char_list,
                spectrogram=batch['xs'][b, :, :dataset.input_freq],
                save_path=mkdir_join(save_path,
                                     batch['input_names'][b] + '.png'),
                figsize=(40, 8))

            # word to character
            plot_word2char_attention_weights(
                aw_dec[b][:len(word_list), :len(char_list)],
                label_list=word_list,
                label_list_sub=char_list,
                save_path=mkdir_join(
                    save_path, batch['input_names'][b] + '_word2char.png'),
                figsize=(40, 8))

            # with open(join(save_path, speaker, batch['input_names'][b] + '.txt'), 'w') as f:
            #     f.write(batch['ys'][b][0])

        if is_new_epoch:
            break
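
The oracle branch above pads the character indices into a fixed-size array with -1. That idiom, isolated into a self-contained helper (names hypothetical):

import numpy as np


def pad_label_lists(label_lists, pad_value=-1):
    # Pad variable-length index lists into a [B, L_max] int array and
    # return the true lengths alongside, as the oracle branch does.
    max_len = max(len(labels) for labels in label_lists)
    ys = np.full((len(label_lists), max_len), pad_value, dtype=np.int32)
    y_lens = np.zeros((len(label_lists),), dtype=np.int32)
    for b, labels in enumerate(label_lists):
        ys[b, :len(labels)] = labels
        y_lens[b] = len(labels)
    return ys, y_lens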

Code example #7

def decode(model, dataset, beam_width, max_decode_len,
           eval_batch_size=None, save_path=None):
    """Visualize label outputs.
    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam
        max_decode_len (int): the maximum length of output sequences,
            used to stop prediction when the EOS token has not been emitted.
            This is used for seq2seq models.
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string, optional): path to save decoding results
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    vocab_file_path = '../metrics/vocab_files/' + \
        dataset.label_type + '_' + dataset.data_size + '.txt'
    if dataset.label_type == 'character':
        map_fn = Idx2char(vocab_file_path)
    elif dataset.label_type == 'character_capital_divide':
        map_fn = Idx2char(vocab_file_path, capital_divide=True)
    else:
        map_fn = Idx2word(vocab_file_path)

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for batch, is_new_epoch in dataset:

        # Decode
        best_hyps, perm_idx = model.decode(batch['xs'], batch['x_lens'],
                                           beam_width=beam_width,
                                           max_decode_len=max_decode_len)

        if model.model_type == 'attention' and model.ctc_loss_weight > 0:
            best_hyps_ctc, perm_idx = model.decode_ctc(
                batch['xs'], batch['x_lens'], beam_width=beam_width)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):

            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = map_fn(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = map_fn(best_hyps[b])

            if model.model_type == 'attention':
                str_hyp = str_hyp.split('>')[0]
                # NOTE: Truncate by the first <EOS>

                # Remove the last space
                if len(str_hyp) > 0 and str_hyp[-1] == '_':
                    str_hyp = str_hyp[:-1]

            ##############################
            # Post-processing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[\'>]+', '', str_ref)
            str_hyp = re.sub(r'[\'>]+', '', str_hyp)

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref: %s' % str_ref.replace('_', ' '))
            print('Hyp: %s' % str_hyp.replace('_', ' '))
            if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                str_hyp_ctc = map_fn(best_hyps_ctc[b])
                print('Hyp (CTC): %s' % str_hyp_ctc)

            # Compute WER / CER
            if 'word' in dataset.label_type:
                wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                           hyp=str_hyp.split('_'),
                                           normalize=True)
                print('WER: %f %%' % (wer * 100))
                if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                    wer_ctc, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                                   hyp=str_hyp_ctc.split('_'),
                                                   normalize=True)
                    print('WER (CTC): %f %%' % (wer_ctc * 100))
            else:
                cer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                           hyp=str_hyp.split('_'),
                                           normalize=True)
                print('CER: %f %%' % (cer * 100))
                if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                    cer_ctc, _, _, _ = compute_wer(
                        ref=list(str_ref.replace('_', '')),
                        hyp=list(str_hyp.replace('_', '')),
                        normalize=True)
                    print('CER (CTC): %f %%' % (cer_ctc * 100))

        if is_new_epoch:
            break
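
The attention branch above truncates hypotheses at the first EOS and strips a trailing space marker. As a self-contained illustration of that post-processing, assuming '>' closes the EOS token and '_' marks spaces, as the examples do:

def truncate_at_eos(str_hyp):
    str_hyp = str_hyp.split('>')[0]  # keep everything before the first EOS
    if str_hyp.endswith('_'):        # drop a trailing space marker
        str_hyp = str_hyp[:-1]
    return str_hyp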

Code example #8

    def check_loading(self, label_type, data_type='dev_clean',
                      shuffle=False, sort_utt=False, sort_stop_epoch=None,
                      frame_stacking=False, splice=1, num_gpu=1):

        print('========================================')
        print('  label_type: %s' % label_type)
        print('  data_type: %s' % data_type)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print('  num_gpu: %d' % num_gpu)
        print('========================================')

        num_stack = 3 if frame_stacking else 1
        num_skip = 3 if frame_stacking else 1
        dataset = Dataset(
            data_type=data_type, train_data_size='train_clean100',
            label_type=label_type,
            batch_size=64, max_epoch=1, splice=splice,
            num_stack=num_stack, num_skip=num_skip,
            shuffle=shuffle, sort_utt=sort_utt, sort_stop_epoch=sort_stop_epoch,
            progressbar=True, num_gpu=num_gpu)

        print('=> Loading mini-batch...')
        if label_type == 'character':
            map_file_path = '../../metrics/mapping_files/ctc/character.txt'
        elif label_type == 'character_capital_divide':
            map_file_path = '../../metrics/mapping_files/ctc/character_capital.txt'
        elif label_type == 'word':
            map_file_path = '../../metrics/mapping_files/ctc/word_' + \
                dataset.train_data_size + '.txt'

        idx2char = Idx2char(map_file_path)
        idx2word = Idx2word(map_file_path)

        for data, is_new_epoch in dataset:
            inputs, labels, inputs_seq_len, input_names = data

            if not self.length_check:
                for i, l in zip(inputs[0], labels[0]):
                    if len(i) < len(l):
                        raise ValueError(
                            'input length must be longer than label length.')
                self.length_check = True

            if num_gpu > 1:
                for inputs_gpu in inputs:
                    print(inputs_gpu.shape)

            if label_type == 'word':
                if 'test' not in data_type:
                    str_true = ' '.join(idx2word(labels[0][0]))
                else:
                    word_list = np.delete(labels[0][0], np.where(
                        labels[0][0] == None), axis=0)
                    str_true = ' '.join(word_list)
            else:
                str_true = idx2char(labels[0][0])
            str_true = re.sub(r'_', ' ', str_true)
            print('----- %s (epoch: %.3f) -----' %
                  (input_names[0][0], dataset.epoch_detail))
            print(inputs[0].shape)
            print(str_true)

            if dataset.epoch_detail >= 0.05:
                break
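
`frame_stacking` above turns on `num_stack=3, num_skip=3`, the usual trick of concatenating consecutive frames and subsampling in time. A minimal numpy sketch of that transform; the codebase's own implementation may differ in its padding details.

import numpy as np


def stack_frames(x, num_stack=3, num_skip=3):
    # x: [T, D] features -> [ceil(T / num_skip), D * num_stack]
    T, D = x.shape
    stacked = []
    for t in range(0, T, num_skip):
        window = x[t:t + num_stack]
        if len(window) < num_stack:  # zero-pad the last window
            pad = np.zeros((num_stack - len(window), D), dtype=x.dtype)
            window = np.concatenate([window, pad], axis=0)
        stacked.append(window.reshape(-1))
    return np.stack(stacked)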

Code example #9

def eval_word(models,
              dataset,
              beam_width,
              max_decode_len,
              beam_width_sub=1,
              max_decode_len_sub=300,
              eval_batch_size=None,
              length_penalty=0,
              progressbar=False,
              temperature=1,
              resolving_unk=False,
              a2c_oracle=False):
    """Evaluate trained model by Word Error Rate.
    Args:
        models (list): the models to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction. This is used for seq2seq models.
        beam_width_sub (int, optional): the size of beam in the sub task.
            This is used for the nested attention
        max_decode_len_sub (int, optional): the length of output sequences
            to stop prediction. This is used for the nested attention
        eval_batch_size (int, optional): the batch size when evaluating the model
        length_penalty (float, optional):
        progressbar (bool, optional): if True, visualize the progressbar
        temperature (int, optional):
        resolving_unk (bool, optional):
        a2c_oracle (bool, optional):
    Returns:
        wer (float): Word error rate
        df_wer (pd.DataFrame): dataframe of substitution, insertion, and deletion
    """
    # Reset data counter
    dataset.reset()

    idx2word = Idx2word(dataset.vocab_file_path)
    if models[0].model_type == 'nested_attention':
        char2idx = Char2idx(dataset.vocab_file_path_sub)
    if models[0].model_type in ['ctc', 'attention'] and resolving_unk:
        idx2char = Idx2char(dataset.vocab_file_path_sub,
                            capital_divide=dataset.label_type_sub ==
                            'character_capital_divide')

    wer = 0
    sub, ins, dele = 0, 0, 0
    num_words = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        batch_size = len(batch['xs'])

        # Decode
        if len(models) > 1:
            assert models[0].model_type in ['ctc']
            for i, model in enumerate(models):
                probs, x_lens, perm_idx = model.posteriors(
                    batch['xs'], batch['x_lens'])
                if i == 0:
                    probs_ensemble = probs
                else:
                    probs_ensemble += probs
            probs_ensemble /= len(models)

            best_hyps = models[0].decode_from_probs(probs_ensemble,
                                                    x_lens,
                                                    beam_width=1)
        else:
            model = models[0]
            # TODO: fix this

            if model.model_type == 'nested_attention':
                if a2c_oracle:
                    if dataset.is_test:
                        max_label_num = 0
                        for b in range(batch_size):
                            if max_label_num < len(list(
                                    batch['ys_sub'][b][0])):
                                max_label_num = len(list(
                                    batch['ys_sub'][b][0]))

                        ys_sub = np.zeros((batch_size, max_label_num),
                                          dtype=np.int32)
                        ys_sub -= 1  # pad with -1
                        y_lens_sub = np.zeros((batch_size, ), dtype=np.int32)
                        for b in range(batch_size):
                            indices = char2idx(batch['ys_sub'][b][0])
                            ys_sub[b, :len(indices)] = indices
                            y_lens_sub[b] = len(indices)
                            # NOTE: transcript is separated by space('_')
                    else:
                        ys_sub = batch['ys_sub']
                        y_lens_sub = batch['y_lens_sub']
                else:
                    ys_sub = None
                    y_lens_sub = None

                best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    beam_width_sub=beam_width_sub,
                    max_decode_len=max_decode_len,
                    max_decode_len_sub=max_label_num
                    if (a2c_oracle and dataset.is_test) else max_decode_len_sub,
                    length_penalty=length_penalty,
                    teacher_forcing=a2c_oracle,
                    ys_sub=ys_sub,
                    y_lens_sub=y_lens_sub)
            else:
                best_hyps, aw, perm_idx = model.decode(
                    batch['xs'],
                    batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len,
                    length_penalty=length_penalty)
                if resolving_unk:
                    best_hyps_sub, aw_sub, _ = model.decode(
                        batch['xs'],
                        batch['x_lens'],
                        beam_width=beam_width,
                        max_decode_len=max_decode_len_sub,
                        length_penalty=length_penalty,
                        task_index=1)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(batch_size):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = idx2word(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = idx2word(best_hyps[b])
            if dataset.label_type == 'word':
                str_hyp = re.sub(r'(.*)_>(.*)', r'\1', str_hyp)
            else:
                str_hyp = re.sub(r'(.*)>(.*)', r'\1', str_hyp)
            # NOTE: Truncate by the first <EOS>

            ##############################
            # Resolving UNK
            ##############################
            if resolving_unk and 'OOV' in str_hyp:
                str_hyp = resolve_unk(str_hyp, best_hyps_sub[b], aw[b],
                                      aw_sub[b], idx2char)
                str_hyp = str_hyp.replace('*', '')

            ##############################
            # Post-processing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[@>]+', '', str_ref)
            str_hyp = re.sub(r'[@>]+', '', str_hyp)
            # NOTE: @ means noise

            # Remove consecutive spaces
            str_ref = re.sub(r'[_]+', '_', str_ref)
            str_hyp = re.sub(r'[_]+', '_', str_hyp)

            # Compute WER
            try:
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=str_hyp.split('_'),
                    normalize=False)
                wer += wer_b
                sub += sub_b
                ins += ins_b
                dele += del_b
                num_words += len(str_ref.split('_'))
            except Exception:
                pass

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    wer /= num_words
    sub /= num_words
    ins /= num_words
    dele /= num_words

    df_wer = pd.DataFrame(
        {
            'SUB': [sub * 100],
            'INS': [ins * 100],
            'DEL': [dele * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['WER'])

    return wer, df_wer
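
`compute_wer` returns the raw error count plus its substitution/insertion/deletion breakdown. Below is a self-contained edit-distance sketch with the same interface; the real utility may count or normalize differently.

def compute_wer_sketch(ref, hyp, normalize=False):
    # Levenshtein alignment over token lists; returns (wer, sub, ins, del).
    n, m = len(ref), len(hyp)
    d = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        d[i][0] = i
    for j in range(1, m + 1):
        d[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j - 1] + cost,  # match / substitution
                          d[i - 1][j] + 1,         # deletion
                          d[i][j - 1] + 1)         # insertion
    # Backtrace to split the distance into sub / ins / del counts
    i, j, sub, ins, dele = n, m, 0, 0, 0
    while i > 0 or j > 0:
        if i > 0 and j > 0 and d[i][j] == d[i - 1][j - 1] \
                and ref[i - 1] == hyp[j - 1]:
            i, j = i - 1, j - 1
        elif i > 0 and j > 0 and d[i][j] == d[i - 1][j - 1] + 1:
            sub += 1
            i, j = i - 1, j - 1
        elif i > 0 and d[i][j] == d[i - 1][j] + 1:
            dele += 1
            i -= 1
        else:
            ins += 1
            j -= 1
    wer = d[n][m] / len(ref) if normalize else d[n][m]
    return wer, sub, ins, dele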

Code example #10

    def check(self,
              label_type,
              data_type='dev_clean',
              shuffle=False,
              sort_utt=False,
              sort_stop_epoch=None,
              frame_stacking=False,
              splice=1,
              num_gpu=1):

        print('========================================')
        print('  label_type: %s' % label_type)
        print('  data_type: %s' % data_type)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print('  num_gpu: %d' % num_gpu)
        print('========================================')

        num_stack = 3 if frame_stacking else 1
        num_skip = 3 if frame_stacking else 1
        dataset = Dataset(data_type=data_type,
                          train_data_size='train100h',
                          label_type=label_type,
                          batch_size=64,
                          max_epoch=2,
                          splice=splice,
                          num_stack=num_stack,
                          num_skip=num_skip,
                          shuffle=shuffle,
                          sort_utt=sort_utt,
                          sort_stop_epoch=sort_stop_epoch,
                          progressbar=True,
                          num_gpu=num_gpu)

        print('=> Loading mini-batch...')
        if label_type == 'character':
            map_file_path = '../../metrics/mapping_files/character.txt'
        else:
            map_file_path = '../../metrics/mapping_files/' + label_type + '_' + \
                dataset.train_data_size + '.txt'

        idx2char = Idx2char(map_file_path)
        idx2word = Idx2word(map_file_path)

        for data, is_new_epoch in dataset:
            inputs, labels, inputs_seq_len, input_names = data

            if data_type == 'train':
                for i, l in zip(inputs[0], labels[0]):
                    if len(i) < len(l):
                        raise ValueError(
                            'input length must be longer than label length.')

            if num_gpu > 1:
                for inputs_gpu in inputs:
                    print(inputs_gpu.shape)

            if 'test' in data_type:
                str_true = labels[0][0][0]
            else:
                if 'word' in label_type:
                    str_true = '_'.join(idx2word(labels[0][0]))
                else:
                    str_true = idx2char(labels[0][0])

            print('----- %s (epoch: %.3f) -----' %
                  (input_names[0][0], dataset.epoch_detail))
            print(inputs[0].shape)
            print(str_true)

            if dataset.epoch_detail >= 0.1:
                break

Code example #11

    def check(self,
              label_type,
              data_type='dev',
              data_size='all',
              backend='pytorch',
              shuffle=False,
              sort_utt=True,
              sort_stop_epoch=None,
              frame_stacking=False,
              splice=1,
              num_gpus=1):

        print('========================================')
        print('  backend: %s' % backend)
        print('  label_type: %s' % label_type)
        print('  data_type: %s' % data_type)
        print('  data_size: %s' % data_size)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print('  num_gpus: %d' % num_gpus)
        print('========================================')

        num_stack = 3 if frame_stacking else 1
        num_skip = 3 if frame_stacking else 1
        dataset = Dataset(data_save_path='/n/sd8/inaguma/corpus/csj/kaldi',
                          backend=backend,
                          input_freq=80,
                          use_delta=True,
                          use_double_delta=True,
                          data_type=data_type,
                          data_size=data_size,
                          label_type=label_type,
                          batch_size=64,
                          max_epoch=1,
                          splice=splice,
                          num_stack=num_stack,
                          num_skip=num_skip,
                          min_frame_num=40,
                          shuffle=shuffle,
                          sort_utt=sort_utt,
                          reverse=False,
                          sort_stop_epoch=sort_stop_epoch,
                          num_gpus=num_gpus,
                          tool='htk',
                          num_enque=None)

        print('=> Loading mini-batch...')
        if 'word' in label_type:
            map_fn = Idx2word(dataset.vocab_file_path)
        else:
            map_fn = Idx2char(dataset.vocab_file_path)

        for batch, is_new_epoch in dataset:
            if data_type == 'train' and backend == 'pytorch':
                for i in range(len(batch['xs'])):
                    if batch['x_lens'][i] < batch['y_lens'][i]:
                        raise ValueError(
                            'input length must be longer than label length.')

            if dataset.is_test:
                str_true = batch['ys'][0][0]
            else:
                str_true = map_fn(batch['ys'][0][:batch['y_lens'][0]])

            print('----- %s (epoch: %.3f, batch: %d) -----' %
                  (batch['input_names'][0], dataset.epoch_detail,
                   len(batch['xs'])))
            print(str_true)
            print('x_lens: %d' % (batch['x_lens'][0] * num_stack))
            if not dataset.is_test:
                print('y_lens: %d' % batch['y_lens'][0])

            if dataset.epoch_detail >= 0.1:
                break
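
`use_delta` / `use_double_delta` above append first- and second-order dynamic features to the 80-dimensional filterbanks. A crude sketch of the idea follows; real pipelines such as HTK or librosa compute deltas with a regression window rather than a plain gradient.

import numpy as np


def add_deltas(feats):
    # feats: [T, D] static features -> [T, 3 * D] with delta and delta-delta
    delta = np.gradient(feats, axis=0)
    delta2 = np.gradient(delta, axis=0)
    return np.concatenate([feats, delta, delta2], axis=1)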

Code example #12

def plot(model,
         dataset,
         beam_width,
         beam_width_sub,
         eval_batch_size=None,
         save_path=None):
    """Visualize attention weights of Attetnion-based model.
    Args:
        model: model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam in the main task
        beam_width_sub (int): the size of beam in the sub task
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string, optional): path to save attention weights plotting
    """
    # Clean directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    idx2word = Idx2word(dataset.vocab_file_path, return_list=True)
    idx2char = Idx2char(dataset.vocab_file_path_sub, return_list=True)

    for batch, is_new_epoch in dataset:

        best_hyps, best_hyps_sub, aw, aw_sub, aw_dec = model.attention_weights(
            batch['xs'],
            batch['x_lens'],
            beam_width=beam_width,
            beam_width_sub=beam_width_sub,
            max_decode_len=MAX_DECODE_LEN_WORD,
            max_decode_len_sub=MAX_DECODE_LEN_CHAR)

        for b in range(len(batch['xs'])):

            word_list = idx2word(best_hyps[b])
            char_list = idx2char(best_hyps_sub[b])

            # if word_list.count('OOV') < 1:
            #     continue

            speaker = '_'.join(batch['input_names'][b].split('_')[:2])

            # word to acoustic & character to acoustic
            plot_hierarchical_attention_weights(
                aw[b][:len(word_list), :batch['x_lens'][b]],
                aw_sub[b][:len(char_list), :batch['x_lens'][b]],
                label_list=word_list,
                label_list_sub=char_list,
                spectrogram=batch['xs'][b, :, :dataset.input_freq],
                save_path=mkdir_join(save_path, speaker,
                                     batch['input_names'][b] + '.png'),
                figsize=(50, 10))

            # word to character attention
            plot_word2char_attention_weights(
                aw_dec[b][:len(word_list), :len(char_list)],
                label_list=word_list,
                label_list_sub=char_list,
                save_path=mkdir_join(
                    save_path, speaker,
                    batch['input_names'][b] + '_word2char.png'),
                figsize=(50, 10))

            with open(
                    join(save_path, speaker, batch['input_names'][b] + '.txt'),
                    'w') as f:
                f.write(batch['ys'][b][0])

        if is_new_epoch:
            break

Code example #13

def decode_test_multitask(session,
                          decode_op_main,
                          decode_op_sub,
                          model,
                          dataset,
                          train_data_size,
                          label_type_main,
                          label_type_sub,
                          is_test=False,
                          save_path=None):
    """Visualize label outputs of Multi-task CTC model.
    Args:
        session: session of training model
        decode_op_main: operation for decoding in the main task
        decode_op_sub: operation for decoding in the sub task
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        train_data_size (string): train_clean100 or train_clean360 or
            train_other500 or train_all
        label_type_main (string): word
        label_type_sub (string): character or character_capital_divide
        is_test (bool, optional): set to True when evaluating by the test set
        save_path (string, optional): path to save decoding results
    """
    idx2word = Idx2word(map_file_path='../metrics/mapping_files/ctc/word_' +
                        train_data_size + '.txt')
    idx2char = Idx2char(map_file_path='../metrics/mapping_files/ctc/' +
                        label_type_sub + '.txt')

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    while True:

        # Create feed dictionary for next mini batch
        data, is_new_epoch = dataset.next(batch_size=1)
        inputs, labels_true_word, labels_true_char, inputs_seq_len, input_names = data
        # NOTE: Batch size is expected to be 1

        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0
        }

        # Visualize
        labels_pred_st_word, labels_pred_st_char = session.run(
            [decode_op_main, decode_op_sub], feed_dict=feed_dict)
        try:
            labels_pred_word = sparsetensor2list(labels_pred_st_word,
                                                 batch_size=1)
        except IndexError:
            # no output
            labels_pred_word = ['']

        try:
            labels_pred_char = sparsetensor2list(labels_pred_st_char,
                                                 batch_size=1)
        except IndexError:
            # no output
            labels_pred_char = ['']

        print('----- wav: %s -----' % input_names[0][0])
        if dataset.is_test:
            str_true_word = labels_true_word[0][0][0]
        else:
            str_true_word = ' '.join(idx2word(labels_true_word[0][0]))
        str_pred_word = ' '.join(idx2word(labels_pred_word[0]))
        print('Ref (word): %s' % str_true_word)
        print('Hyp (word): %s' % str_pred_word)

        str_true_char = idx2char(labels_true_char[0][0])
        str_pred_char = idx2char(labels_pred_char[0]).replace('_', ' ')
        print('Ref (char): %s' % str_true_char)
        print('Hyp (char): %s' % str_pred_char)

        if is_new_epoch:
            break
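
`sparsetensor2list` converts the SparseTensorValue returned by TensorFlow's CTC decoder into per-utterance label lists. A minimal sketch of that conversion, assuming `indices[:, 0]` holds the batch index as in `tf.nn.ctc_beam_search_decoder` output:

def sparsetensor2list_sketch(sparse_tensor, batch_size):
    # The IndexError handled above presumably fires when the decoder emits
    # nothing; this sketch simply returns empty lists in that case.
    labels = [[] for _ in range(batch_size)]
    for (b, _t), value in zip(sparse_tensor.indices, sparse_tensor.values):
        labels[int(b)].append(int(value))
    return labels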

Code example #14

def decode_test(session,
                decode_op,
                model,
                dataset,
                label_type,
                train_data_size,
                save_path=None):
    """Visualize label outputs of CTC model.
    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide or word
        train_data_size (string): train_clean100 or train_clean360 or
            train_other500 or train_all
        save_path (string, optional): path to save decoding results
    """
    if label_type == 'character':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/ctc/character.txt')
    elif label_type == 'character_capital_divide':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/ctc/'
                          'character_capital_divide.txt',
            capital_divide=True)
    elif label_type == 'word':
        idx2word = Idx2word(
            map_file_path='../metrics/mapping_files/ctc/word_' +
            train_data_size + '.txt')

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    while True:

        # Create feed dictionary for next mini batch
        data, is_new_epoch = dataset.next(batch_size=1)
        inputs, labels_true, inputs_seq_len, input_names = data
        # NOTE: Batch size is expected to be 1

        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0
        }

        # Visualize
        labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
        try:
            labels_pred = sparsetensor2list(labels_pred_st, batch_size=1)
        except IndexError:
            # no output
            labels_pred = ['']
        finally:
            print('----- wav: %s -----' % input_names[0][0])
            if label_type == 'character':
                str_true = idx2char(labels_true[0][0]).replace('_', ' ')
                str_pred = idx2char(labels_pred[0]).replace('_', ' ')
            elif label_type == 'character_capital_divide':
                str_true = idx2char(labels_true[0][0])
                str_pred = idx2char(labels_pred[0])
            else:
                if dataset.is_test:
                    str_true = labels_true[0][0][0]
                else:
                    str_true = ' '.join(idx2word(labels_true[0][0]))
                str_pred = ' '.join(idx2word(labels_pred[0]))

            print('Ref: %s' % str_true)
            print('Hyp: %s' % str_pred)
            # wer_align(ref=str_true.split(), hyp=str_pred.split())

        if is_new_epoch:
            break

Code example #15

def decode(model, dataset, beam_width, beam_width_sub,
           eval_batch_size=None, save_path=None, resolving_unk=False):
    """Visualize label outputs.
    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam in the main task
        beam_width_sub (int): the size of beam in the sub task
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string, optional): path to save decoding results
        resolving_unk (bool, optional):
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    idx2word = Idx2word(vocab_file_path=dataset.vocab_file_path)
    idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path_sub,
                        capital_divide=dataset.label_type_sub == 'character_capital_divide')

    # Read GLM file
    glm = GLM(
        glm_path='/n/sd8/inaguma/corpus/swbd/data/eval2000/LDC2002T43/reference/en20000405_hub5.glm')

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for batch, is_new_epoch in dataset:

        # Decode
        if model.model_type == 'nested_attention':
            best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                beam_width_sub=beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_WORD,
                max_decode_len_sub=MAX_DECODE_LEN_CHAR)
        else:
            best_hyps, aw, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD)
            best_hyps_sub, aw_sub, _ = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_CHAR,
                task_index=1)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]
        ys_sub = batch['ys_sub'][perm_idx]
        y_lens_sub = batch['y_lens_sub'][perm_idx]

        for b in range(len(batch['xs'])):

            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref_original = ys[b][0]
                str_ref_sub = ys_sub[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref_original = idx2word(ys[b][:y_lens[b]])
                str_ref_sub = idx2char(ys_sub[b][:y_lens_sub[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = idx2word(best_hyps[b])
            str_hyp_sub = idx2char(best_hyps_sub[b])

            ##############################
            # Resolving UNK
            ##############################
            str_hyp_no_unk = str_hyp
            if 'OOV' in str_hyp and resolving_unk:
                str_hyp_no_unk = resolve_unk(
                    str_hyp, best_hyps_sub[b], aw[b], aw_sub[b], idx2char)

            # if 'OOV' not in str_hyp:
            #     continue

            ##############################
            # Post-processing
            ##############################
            str_ref = fix_trans(str_ref_original, glm)
            str_ref_sub = fix_trans(str_ref_sub, glm)
            str_hyp = fix_trans(str_hyp, glm)
            str_hyp_sub = fix_trans(str_hyp_sub, glm)
            str_hyp_no_unk = fix_trans(str_hyp_no_unk, glm)

            if len(str_ref) == 0:
                continue

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref         : %s' % str_ref.replace('_', ' '))
            print('Hyp (main)  : %s' % str_hyp.replace('_', ' '))
            print('Hyp (sub)   : %s' % str_hyp_sub.replace('_', ' '))
            if 'OOV' in str_hyp and resolving_unk:
                print('Hyp (no UNK): %s' % str_hyp_no_unk.replace('_', ' '))

            try:
                # Compute WER
                wer, _, _, _ = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=re.sub(r'_>.*', '', str_hyp).split('_'),
                    normalize=True)
                print('WER (main)  : %.3f %%' % (wer * 100))
                wer_sub, _, _, _ = compute_wer(
                    ref=str_ref_sub.split('_'),
                    hyp=re.sub(r'>.*', '', str_hyp_sub).split('_'),
                    normalize=True)
                print('WER (sub)   : %.3f %%' % (wer_sub * 100))
                if 'OOV' in str_hyp and resolving_unk:
                    wer_no_unk, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=re.sub(r'_>.*', '',
                                   str_hyp_no_unk.replace('*', '')).split('_'),
                        normalize=True)
                    print('WER (no UNK): %.3f %%' % (wer_no_unk * 100))
            except Exception:
                print('--- skipped ---')

        if is_new_epoch:
            break
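
`resolve_unk` replaces OOV word tokens using the character-level hypothesis; judging from its arguments, it aligns the word-level and character-level attention over the encoder frames. The following is only a guess at that algorithm, included to make the call comprehensible; it is not the codebase's implementation, and the 4-frame window is arbitrary.

import numpy as np


def resolve_unk_sketch(str_hyp, best_hyps_sub, aw, aw_sub, idx2char):
    # Assumption: row i of `aw` is word i's attention over input frames and
    # row j of `aw_sub` is character j's; an OOV word is rewritten with the
    # characters whose attention peaks fall near its own peak frame.
    words = str_hyp.split('_')
    word_peaks = aw[:len(words)].argmax(axis=1)
    char_peaks = aw_sub[:len(best_hyps_sub)].argmax(axis=1)
    for i, word in enumerate(words):
        if word == 'OOV':
            near = np.abs(char_peaks - word_peaks[i]) <= 4
            chars = idx2char([c for c, keep in
                              zip(best_hyps_sub, near) if keep])
            words[i] = chars.replace('_', '')
    return '_'.join(words)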

Code example #16

def decode(model, dataset, beam_width, eval_batch_size=None, save_path=None):
    """Visualize label outputs.
    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string, optional): path to save decoding results
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    if 'char' in dataset.label_type:
        map_fn = Idx2char(
            dataset.vocab_file_path,
            capital_divide=dataset.label_type == 'character_capital_divide')
        max_decode_len = MAX_DECODE_LEN_CHAR
    else:
        map_fn = Idx2word(dataset.vocab_file_path)
        max_decode_len = MAX_DECODE_LEN_WORD

    # Read GLM file
    glm = GLM(
        glm_path='/n/sd8/inaguma/corpus/swbd/data/eval2000/LDC2002T43/reference/en20000405_hub5.glm')

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for batch, is_new_epoch in dataset:

        # Decode
        if model.model_type == 'nested_attention':
            best_hyps, _, best_hyps_sub, _, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                max_decode_len_sub=max_decode_len)
        else:
            best_hyps, _, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len)

        if model.model_type == 'attention' and model.ctc_loss_weight > 0:
            best_hyps_ctc, perm_idx = model.decode_ctc(batch['xs'],
                                                       batch['x_lens'],
                                                       beam_width=beam_width)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref_original = ys[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref_original = map_fn(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = map_fn(best_hyps[b])

            ##############################
            # Post-processing
            ##############################
            str_ref = fix_trans(str_ref_original, glm)
            str_hyp = fix_trans(str_hyp, glm)

            if len(str_ref) == 0:
                continue

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref: %s' % str_ref.replace('_', ' '))
            print('Hyp: %s' % str_hyp.replace('_', ' '))
            if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                str_hyp_ctc = map_fn(best_hyps_ctc[b])
                print('Hyp (CTC): %s' % str_hyp_ctc)

            try:
                # Compute WER
                wer, _, _, _ = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=re.sub(r'_?>.*', '', str_hyp).split('_'),
                    normalize=True)
                print('WER: %.3f %%' % (wer * 100))
                if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                    wer_ctc, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                                   hyp=str_hyp_ctc.split('_'),
                                                   normalize=True)
                    print('WER (CTC): %.3f %%' % (wer_ctc * 100))
            except Exception:
                print('--- skipped ---')

        if is_new_epoch:
            break
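
`GLM` and `fix_trans` apply the eval2000 scoring rules (the .glm file maps hesitations, contractions, and alternates to canonical forms before scoring). As a hedged illustration, treating the rules as a simple token-replacement map; the real GLM class parses LDC's rule file, which is richer than this:

def fix_trans_sketch(trans, glm_map):
    # glm_map: dict of token -> canonical form (an assumption, for
    # illustration only). Empty replacements drop the token entirely.
    tokens = [glm_map.get(token, token) for token in trans.split('_')]
    return '_'.join(t for t in tokens if t)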
def decode(session, decode_op_main, decode_op_sub, model,
           dataset, train_data_size, label_type_main,
           label_type_sub, is_test=True, save_path=None):
    """Visualize label outputs of Multi-task CTC model.
    Args:
        session: session of training model
        decode_op_main: operation for decoding in the main task
        decode_op_sub: operation for decoding in the sub task
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type_main (string): word
        label_type_sub (string): character or character_capital_divide
        train_data_size (string, optional): train100h or train460h or
            train960h
        is_test (bool, optional): set to True when evaluating by the test set
        save_path (string, optional): path to save decoding results
    """
    idx2word = Idx2word(
        map_file_path='../metrics/mapping_files/word_' + train_data_size + '.txt')
    idx2char = Idx2char(
        map_file_path='../metrics/mapping_files/' + label_type_sub + '.txt')

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for data, is_new_epoch in dataset:

        # Create feed dictionary for next mini batch
        inputs, labels_true_word, labels_true_char, inputs_seq_len, input_names = data
        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_hidden_pl_list[0]: 1.0
        }

        # Decode
        batch_size = inputs[0].shape[0]
        labels_pred_st_word, labels_pred_st_char = session.run(
            [decode_op_main, decode_op_sub], feed_dict=feed_dict)
        try:
            labels_pred_word = sparsetensor2list(
                labels_pred_st_word, batch_size=batch_size)
        except IndexError:
            # no output
            labels_pred_word = ['']
        try:
            labels_pred_char = sparsetensor2list(
                labels_pred_st_char, batch_size=batch_size)
        except IndexError:
            # no output
            labels_pred_char = ['']

        # Visualize
        for i_batch in range(batch_size):
            print('----- wav: %s -----' % input_names[0][i_batch])
            if is_test:
                str_true_word = labels_true_word[0][i_batch][0]
                str_true_char = labels_true_char[0][i_batch][0]
            else:
                str_true_word = '_'.join(
                    idx2word(labels_true_word[0][i_batch]))
                str_true_char = idx2char(labels_true_char[0][i_batch])

            str_pred_word = '_'.join(idx2word(labels_pred_word[i_batch]))
            str_pred_char = idx2char(labels_pred_char[i_batch])

            print('Ref (word): %s' % str_true_word)
            print('Ref (char): %s' % str_true_char)
            print('Hyp (word): %s' % str_pred_word)
            print('Hyp (char): %s' % str_pred_char)

        if is_new_epoch:
            break
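
sparsetensor2list, used above, apparently unpacks the SparseTensorValue produced by TensorFlow's CTC decode op into one label list per utterance. A minimal sketch under that assumption (the repository's helper, including its IndexError behavior on empty output, may differ):

def sparsetensor2list_sketch(labels_st, batch_size):
    """Convert a tf.SparseTensorValue from a CTC decode op into a list
    of per-utterance label-index lists (sketch)."""
    labels = [[] for _ in range(batch_size)]
    # labels_st.indices: [num_labels, 2] rows of (batch index, time index)
    # labels_st.values:  [num_labels] decoded label indices
    for (b, _), v in zip(labels_st.indices, labels_st.values):
        labels[int(b)].append(int(v))
    return labels
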
Code example #18
def decode(session,
           decode_op,
           model,
           dataset,
           label_type,
           train_data_size,
           is_test=True,
           save_path=None):
    """Visualize label outputs of CTC model.
    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string):  character or character_capital_divide or word
        train_data_size (string, optional): train100h or train460h or
            train960h
        is_test (bool, optional): set to True when evaluating by the test set
        save_path (string, optional): path to save decoding results
    """
    if label_type == 'character':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        map_fn = Idx2char(
            map_file_path=
            '../metrics/mapping_files/character_capital_divide.txt',
            capital_divide=True)
    elif label_type == 'word':
        map_fn = Idx2word(map_file_path='../metrics/mapping_files/word_' +
                          train_data_size + '.txt')

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for data, is_new_epoch in dataset:

        # Create feed dictionary for next mini batch
        inputs, labels_true, inputs_seq_len, input_names = data

        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_pl_list[0]: 1.0
        }

        # Decode
        batch_size = inputs[0].shape[0]
        labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
        try:
            labels_pred = sparsetensor2list(labels_pred_st,
                                            batch_size=batch_size)
        except IndexError:
            # no output
            labels_pred = ['']

        # Visualize
        for i_batch in range(batch_size):
            print('----- wav: %s -----' % input_names[0][i_batch])
            if 'char' in label_type:
                if is_test:
                    str_true = labels_true[0][i_batch][0]
                else:
                    str_true = map_fn(labels_true[0][i_batch])
                str_pred = map_fn(labels_pred[i_batch])
            else:
                if is_test:
                    str_true = labels_true[0][i_batch][0]
                else:
                    str_true = '_'.join(map_fn(labels_true[0][i_batch]))
                str_pred = '_'.join(map_fn(labels_pred[i_batch]))

            print('Ref: %s' % str_true)
            print('Hyp: %s' % str_pred)
            # wer_align(ref=str_true.split(), hyp=str_pred.split())

        if is_new_epoch:
            break
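
A side note on the `sys.stdout = open(...)` pattern used throughout these decode functions: it never restores the original stream or closes the file. If you adapt the code, contextlib provides a safer equivalent; a sketch (decode_to_file is a hypothetical wrapper, not part of this repository):

from contextlib import redirect_stdout
from os.path import join

def decode_to_file(model_dir, decode_fn, *args, **kwargs):
    """Run a decode function with stdout redirected to decode.txt,
    restoring the original stream afterwards (sketch)."""
    with open(join(model_dir, 'decode.txt'), 'w') as f:
        with redirect_stdout(f):
            decode_fn(*args, **kwargs)
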
Code example #19
def do_eval_wer(session,
                decode_ops,
                model,
                dataset,
                train_data_size,
                is_test=False,
                eval_batch_size=None,
                progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Word Error Rate.
    Args:
        session: session of training model
        decode_ops: list of operations for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        train_data_size (string): train100h or train460h or train960h
        is_test (bool, optional): set to True when evaluating on the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multi-task model
    Returns:
        wer_mean (float): the average WER over the dataset
    """
    assert isinstance(decode_ops, list), "decode_ops must be a list."

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    idx2word = Idx2word(map_file_path='../metrics/mapping_files/word_' +
                        train_data_size + '.txt')

    wer_mean = 0
    skip_data_num = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:

        # Create feed dictionary for next mini batch
        if is_multitask:
            inputs, labels_true, _, inputs_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, _ = data

        feed_dict = {}
        for i_device in range(len(decode_ops)):
            feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
            feed_dict[model.inputs_seq_len_pl_list[i_device]] = inputs_seq_len[
                i_device]
            feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

        labels_pred_st_list = session.run(decode_ops, feed_dict=feed_dict)
        for i_device, labels_pred_st in enumerate(labels_pred_st_list):
            batch_size_device = len(inputs[i_device])
            try:
                labels_pred = sparsetensor2list(labels_pred_st,
                                                batch_size_device)

                for i_batch in range(batch_size_device):

                    if is_test:
                        str_true = labels_true[i_device][i_batch][0]
                        # NOTE: transcript is separated by space ('_')
                    else:
                        str_true = '_'.join(
                            idx2word(labels_true[i_device][i_batch]))
                    str_pred = '_'.join(idx2word(labels_pred[i_batch]))

                    # if len(str_true.split('_')) == 0:
                    #     print(str_true)
                    #     print(str_pred)

                    # Compute WER
                    wer_mean += compute_wer(ref=str_true.split('_'),
                                            hyp=str_pred.split('_'),
                                            normalize=True)
                    # substitute, insert, delete = wer_align(
                    #     ref=str_true.split(' '),
                    #     hyp=str_pred.split(' '))
                    # print('SUB: %d' % substitute)
                    # print('INS: %d' % insert)
                    # print('DEL: %d' % delete)

                    if progressbar:
                        pbar.update(1)

            except IndexError:
                print('skipped')
                skip_data_num += batch_size_device
                # TODO: Conduct decoding again with batch size 1

                if progressbar:
                    pbar.update(batch_size_device)

        if is_new_epoch:
            break

    wer_mean /= (len(dataset) - skip_data_num)

    # Register original batch size
    if eval_batch_size is not None:
        dataset.batch_size = batch_size_original

    return wer_mean
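
Note that this do_eval_wer accumulates one normalized WER per utterance and divides by the number of utterances, whereas the do_eval_wer in code example #22 below divides total errors by the total number of reference words. The two conventions differ whenever utterance lengths vary, as a small numeric illustration shows:

# Two utterances: a 10-word reference with 1 error, a 1-word reference with 1 error
per_utt_avg = (1 / 10 + 1 / 1) / 2   # 0.55: mean of per-utterance WERs
corpus_wer = (1 + 1) / (10 + 1)      # ~0.18: total errors over total words
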
    def check(self, label_type, label_type_sub,
              data_type='dev', data_size='300h', backend='pytorch',
              shuffle=False, sort_utt=True, sort_stop_epoch=None,
              frame_stacking=False, splice=1, num_gpus=1):

        print('========================================')
        print('  backend: %s' % backend)
        print('  label_type: %s' % label_type)
        print('  label_type_sub: %s' % label_type_sub)
        print('  data_type: %s' % data_type)
        print('  data_size: %s' % data_size)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print('  num_gpus: %d' % num_gpus)
        print('========================================')

        num_stack = 3 if frame_stacking else 1
        num_skip = 3 if frame_stacking else 1
        dataset = Dataset(
            # data_save_path='/n/sd8/inaguma/corpus/swbd/kaldi/' + data_size,
            data_save_path='/n/sd8/inaguma/corpus/swbd/kaldi',
            backend=backend,
            input_freq=40, use_delta=True, use_double_delta=True,
            data_type=data_type, data_size=data_size,
            label_type=label_type, label_type_sub=label_type_sub,
            batch_size=64, max_epoch=1, splice=splice,
            num_stack=num_stack, num_skip=num_skip,
            shuffle=shuffle,
            sort_utt=sort_utt, reverse=True, sort_stop_epoch=sort_stop_epoch,
            num_gpus=num_gpus, tool='htk',
            num_enque=None)

        print('=> Loading mini-batch...')
        idx2word = Idx2word(dataset.vocab_file_path)
        idx2char = Idx2char(dataset.vocab_file_path_sub)

        for batch, is_new_epoch in dataset:
            if data_type == 'train' and backend == 'pytorch':
                # Sanity check: inputs must be longer than label sequences
                if batch['xs'].shape[1] < batch['ys'].shape[1]:
                    raise ValueError(
                        'input length must be longer than label length.')

            if dataset.is_test:
                str_ref = batch['ys'][0][0]
                str_ref = str_ref.lower()
                str_ref = str_ref.replace('(', '').replace(')', '')

                str_ref_sub = batch['ys_sub'][0][0]
                str_ref_sub = str_ref_sub.lower()
                str_ref_sub = str_ref_sub.replace('(', '').replace(')', '')
            else:
                str_ref = idx2word(batch['ys'][0][:batch['y_lens'][0]])
                str_ref_sub = idx2char(
                    batch['ys_sub'][0][:batch['y_lens_sub'][0]])

            print('----- %s (epoch: %.3f, batch: %d) -----' %
                  (batch['input_names'][0], dataset.epoch_detail, len(batch['xs'])))
            print('=' * 20)
            print(str_ref)
            print('-' * 10)
            print(str_ref_sub)
            print('x_lens: %d' % (batch['x_lens'][0] * num_stack))
            if not dataset.is_test:
                print('y_lens (word): %d' % batch['y_lens'][0])
                print('y_lens_sub (char): %d' % batch['y_lens_sub'][0])

            if dataset.epoch_detail >= 1:
                break
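
Idx2word and Idx2char appear in every example; judging by their call sites they map index sequences back to strings using a vocabulary file. A minimal sketch of the word-level mapper, assuming one token per line with the line number as index (the capital_divide and return_list options of the real classes are omitted):

class Idx2wordSketch(object):
    """Map word indices to a string joined by a space mark (sketch)."""

    def __init__(self, vocab_file_path, space_mark='_'):
        with open(vocab_file_path) as f:
            self.idx2token = [line.strip() for line in f]
        self.space_mark = space_mark

    def __call__(self, indices):
        return self.space_mark.join(self.idx2token[int(i)] for i in indices)
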
Code example #21
def decode(model,
           dataset,
           beam_width,
           max_decode_len,
           max_decode_len_sub,
           eval_batch_size=None,
           save_path=None):
    """Visualize label outputs.
    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of the beam
        max_decode_len (int): the maximum length of output sequences.
            Decoding stops at this length if the EOS token has not
            been emitted. This is used for seq2seq models.
        max_decode_len_sub (int): the maximum output length for the sub task
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string, optional): path to save decoding results
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    idx2word = Idx2word(vocab_file_path=dataset.vocab_file_path)
    if dataset.label_type_sub == 'character':
        idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path_sub)
    elif dataset.label_type_sub == 'character_capital_divide':
        idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path_sub,
                            capital_divide=True)

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for batch, is_new_epoch in dataset:

        # Decode
        if model.model_type == 'charseq_attention':
            best_hyps, best_hyps_sub, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                max_decode_len_sub=100)
        else:
            best_hyps, perm_idx = model.decode(batch['xs'],
                                               batch['x_lens'],
                                               beam_width=beam_width,
                                               max_decode_len=max_decode_len)
            best_hyps_sub, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len_sub,
                is_sub_task=True)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]
        ys_sub = batch['ys_sub'][perm_idx]
        y_lens_sub = batch['y_lens_sub'][perm_idx]

        for b in range(len(batch['xs'])):

            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                str_ref_sub = ys_sub[b][0]
                # NOTE: transcript is separated by space ('_')
            else:
                # Convert from list of index to string
                str_ref = idx2word(ys[b][:y_lens[b]])
                # NOTE: the character-level sub reference must be converted
                # with idx2char, not idx2word
                str_ref_sub = idx2char(ys_sub[b][:y_lens_sub[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = idx2word(best_hyps[b])
            str_hyp_sub = idx2char(best_hyps_sub[b])

            if model.model_type != 'hierarchical_ctc':
                str_hyp = str_hyp.split('>')[0]
                str_hyp_sub = str_hyp_sub.split('>')[0]
                # NOTE: Truncate at the first <EOS>

                # Remove the last space
                if len(str_hyp) > 0 and str_hyp[-1] == '_':
                    str_hyp = str_hyp[:-1]
                if len(str_hyp_sub) > 0 and str_hyp_sub[-1] == '_':
                    str_hyp_sub = str_hyp_sub[:-1]

            ##############################
            # Post-processing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[\'>]+', '', str_ref)
            str_hyp = re.sub(r'[\'>]+', '', str_hyp)

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref: %s' % str_ref.replace('_', ' '))
            print('Hyp (main): %s' % str_hyp.replace('_', ' '))
            # print('Ref (sub): %s' % str_ref_sub.replace('_', ' '))
            print('Hyp (sub): %s' % str_hyp_sub.replace('_', ' '))

            wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                       hyp=str_hyp.split('_'),
                                       normalize=True)
            print('WER: %f %%' % (wer * 100))
            cer, _, _, _ = compute_wer(ref=list(str_ref_sub.replace('_', '')),
                                       hyp=list(str_hyp_sub.replace('_', '')),
                                       normalize=True)
            print('CER: %f %%' % (cer * 100))

        if is_new_epoch:
            break
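
The hypothesis clean-up above (truncate at the first '>', drop a trailing '_', strip quote and EOS characters) recurs across several examples; collected into one helper it might look like this (a sketch, not a function from this repository):

import re

def clean_hyp(str_hyp):
    """Truncate an attention hypothesis at the first <EOS> ('>'), drop a
    trailing '_' separator, and remove garbage labels (sketch)."""
    str_hyp = str_hyp.split('>')[0]
    if len(str_hyp) > 0 and str_hyp[-1] == '_':
        str_hyp = str_hyp[:-1]
    return re.sub(r'[\'>]+', '', str_hyp)
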
Code example #22
def do_eval_wer(model,
                dataset,
                beam_width,
                max_decode_len,
                eval_batch_size=None,
                progressbar=False):
    """Evaluate trained model by Word Error Rate.
    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of the beam
        max_decode_len (int): the maximum length of output sequences.
            Decoding stops at this length if the EOS token has not
            been emitted. This is used for seq2seq models.
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
    Returns:
        wer (float): Word error rate
        df_wer (pd.DataFrame): dataframe of substitution, insertion, and deletion
    """
    # Reset data counter
    dataset.reset()

    idx2word = Idx2word(vocab_file_path=dataset.vocab_file_path)

    wer = 0
    sub, ins, dele = 0, 0, 0
    num_words = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        # Decode
        best_hyps, perm_idx = model.decode(batch['xs'],
                                           batch['x_lens'],
                                           beam_width=beam_width,
                                           max_decode_len=max_decode_len)
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):

            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space ('_')
            else:
                # Convert from list of index to string
                str_ref = idx2word(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = idx2word(best_hyps[b])
            if 'attention' in model.model_type:
                str_hyp = str_hyp.split('>')[0]
                # NOTE: Truncate at the first <EOS>

                # Remove the last space
                if len(str_hyp) > 0 and str_hyp[-1] == '_':
                    str_hyp = str_hyp[:-1]

            ##############################
            # Post-processing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[\'>]+', '', str_ref)
            str_hyp = re.sub(r'[\'>]+', '', str_hyp)
            # TODO: is it safe to remove these when computing the WER?

            # Compute WER
            wer_b, sub_b, ins_b, del_b = compute_wer(ref=str_ref.split('_'),
                                                     hyp=str_hyp.split('_'),
                                                     normalize=True)
            wer += wer_b
            sub += sub_b
            ins += ins_b
            dele += del_b
            num_words += len(str_ref.split('_'))

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    wer /= num_words
    sub /= num_words
    ins /= num_words
    dele /= num_words

    df_wer = pd.DataFrame(
        {
            'SUB': [sub * 100],
            'INS': [ins * 100],
            'DEL': [dele * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['WER'])

    return wer, df_wer
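
A typical call site for this evaluator might look as follows (the beam width and maximum length are illustrative values; model and dataset construction are not shown):

wer, df_wer = do_eval_wer(model, dataset,
                          beam_width=4,
                          max_decode_len=100,
                          eval_batch_size=1,
                          progressbar=True)
print('WER: %.3f %%' % (wer * 100))
print(df_wer)  # SUB / INS / DEL rates in %
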
Code example #23
    def check(self,
              label_type,
              label_type_sub,
              data_type='dev_clean',
              data_size='100h',
              backend='pytorch',
              shuffle=False,
              sort_utt=True,
              sort_stop_epoch=None,
              frame_stacking=False,
              splice=1,
              num_gpus=1):

        print('========================================')
        print('  backend: %s' % backend)
        print('  label_type: %s' % label_type)
        print('  label_type_sub: %s' % label_type_sub)
        print('  data_type: %s' % data_type)
        print('  data_size: %s' % data_size)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print('  num_gpus: %d' % num_gpus)
        print('========================================')

        vocab_file_path = '../../metrics/vocab_files/' + \
            label_type + '_' + data_size + '.txt'
        vocab_file_path_sub = '../../metrics/vocab_files/' + \
            label_type_sub + '_' + data_size + '.txt'

        num_stack = 3 if frame_stacking else 1
        num_skip = 3 if frame_stacking else 1
        dataset = Dataset(backend=backend,
                          input_channel=40,
                          use_delta=True,
                          use_double_delta=True,
                          data_type=data_type,
                          data_size=data_size,
                          label_type=label_type,
                          label_type_sub=label_type_sub,
                          batch_size=64,
                          vocab_file_path=vocab_file_path,
                          vocab_file_path_sub=vocab_file_path_sub,
                          max_epoch=1,
                          splice=splice,
                          num_stack=num_stack,
                          num_skip=num_skip,
                          shuffle=shuffle,
                          sort_utt=sort_utt,
                          reverse=True,
                          sort_stop_epoch=sort_stop_epoch,
                          num_gpus=num_gpus,
                          save_format='numpy',
                          num_enque=None)

        print('=> Loading mini-batch...')
        idx2word = Idx2word(vocab_file_path, space_mark=' ')
        idx2char = Idx2char(vocab_file_path_sub)

        for batch, is_new_epoch in dataset:
            if data_type == 'train' and backend == 'pytorch':
                # Sanity check: inputs must be longer than label sequences
                if batch['xs'].shape[1] < batch['ys'].shape[1]:
                    raise ValueError(
                        'input length must be longer than label length.')

            if dataset.is_test:
                str_ref = batch['ys'][0][0]
                str_ref_sub = batch['ys_sub'][0][0]
            else:
                str_ref = idx2word(batch['ys'][0][:batch['y_lens'][0]])
                str_ref_sub = idx2char(
                    batch['ys_sub'][0][:batch['y_lens_sub'][0]])

            print('----- %s (epoch: %.3f, batch: %d) -----' %
                  (batch['input_names'][0], dataset.epoch_detail,
                   len(batch['xs'])))
            print('=' * 20)
            print(str_ref)
            print('-' * 10)
            print(str_ref_sub)
            print('x_lens: %d' % (batch['x_lens'][0] * num_stack))
            if not dataset.is_test:
                print('y_lens (word): %d' % batch['y_lens'][0])
                print('y_lens_sub (char): %d' % batch['y_lens_sub'][0])

            if dataset.epoch_detail >= 0.01:
                break
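
frame_stacking above sets num_stack = num_skip = 3: every three consecutive frames are concatenated along the feature axis and the frame rate drops threefold, which is why the `x_lens` print multiplies by num_stack. A minimal numpy sketch of that transform (the repository's loader may implement it differently):

import numpy as np

def stack_frames(x, num_stack=3, num_skip=3):
    """Concatenate num_stack consecutive frames and keep every num_skip-th
    position: [T, F] -> [ceil(T / num_skip), F * num_stack] (sketch).
    Assumes num_stack >= num_skip."""
    T, _ = x.shape
    # Pad by repeating the last frame so every window is complete
    pad = (-T) % num_skip
    x = np.pad(x, ((0, pad + num_stack - num_skip), (0, 0)), mode='edge')
    return np.stack([x[t:t + num_stack].reshape(-1)
                     for t in range(0, T + pad, num_skip)])
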
Code example #24
def decode(model,
           dataset,
           eval_batch_size,
           beam_width,
           length_penalty,
           save_path=None):
    """Visualize label outputs.
    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model
        beam_width (int): the size of the beam
        length_penalty (float): length penalty applied during beam search
        save_path (string, optional): path to save decoding results
    """
    if 'word' in dataset.label_type:
        map_fn = Idx2word(dataset.vocab_file_path)
        max_decode_len = MAX_DECODE_LEN_WORD
    else:
        map_fn = Idx2char(dataset.vocab_file_path)
        max_decode_len = MAX_DECODE_LEN_CHAR

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for batch, is_new_epoch in dataset:

        # Decode
        if model.model_type == 'nested_attention':
            best_hyps, _, best_hyps_sub, _, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                max_decode_len_sub=max_decode_len)
        else:
            best_hyps, _, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len)

        if model.model_type == 'attention' and model.ctc_loss_weight > 0:
            best_hyps_ctc, perm_idx = model.decode_ctc(batch['xs'],
                                                       batch['x_lens'],
                                                       beam_width=beam_width)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space ('_')
            else:
                # Convert from list of index to string
                str_ref = map_fn(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = map_fn(best_hyps[b])

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref: %s' % str_ref.replace('_', ' '))
            print('Hyp: %s' % str_hyp.replace('_', ' '))
            if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                str_hyp_ctc = map_fn(best_hyps_ctc[b])
                print('Hyp (CTC): %s' % str_hyp_ctc)

            try:
                if dataset.label_type == 'word' or dataset.label_type == 'kanji_wb':
                    wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                               hyp=re.sub(
                                                   r'(.*)[_]*>(.*)', r'\1',
                                                   str_hyp).split('_'),
                                               normalize=True)
                    print('WER: %.3f %%' % (wer * 100))
                    if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                        wer_ctc, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=str_hyp_ctc.split('_'),
                            normalize=True)
                        print('WER (CTC): %.3f %%' % (wer_ctc * 100))
                else:
                    cer, _, _, _ = compute_wer(
                        ref=list(str_ref.replace('_', '')),
                        hyp=list(
                            re.sub(r'(.*)>(.*)', r'\1',
                                   str_hyp).replace('_', '')),
                        normalize=True)
                    print('CER: %.3f %%' % (cer * 100))
                    if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                        cer_ctc, _, _, _ = compute_wer(
                            ref=list(str_ref.replace('_', '')),
                            hyp=list(str_hyp_ctc.replace('_', '')),
                            normalize=True)
                        print('CER (CTC): %.3f %%' % (cer_ctc * 100))
            except Exception:
                print('--- skipped ---')

        if is_new_epoch:
            break
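
length_penalty is threaded into model.decode above without further explanation; in beam search it typically offsets a hypothesis' log-probability by its length so that emitting <EOS> early is not unfairly rewarded. One common additive convention, as a sketch (the repository's decoder may use a different formula):

def beam_score(log_prob_sum, length, length_penalty):
    """Score a partial beam-search hypothesis: reward longer outputs in
    proportion to length_penalty (one common convention; sketch)."""
    return log_prob_sum + length_penalty * length
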
def plot_attention(model,
                   dataset,
                   max_decode_len,
                   eval_batch_size=None,
                   save_path=None):
    """Visualize attention weights of attetnion-based model.
    Args:
        model: model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int, optional): the batch size when evaluating the model
        max_decode_len (int): the length of output sequences
            to stop prediction when EOS token have not been emitted.
        save_path (string, optional): path to save attention weights plotting
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Clean directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    vocab_file_path = '../metrics/vocab_files/' + \
        dataset.label_type + '_' + dataset.data_size + '.txt'
    if 'char' in dataset.label_type:
        map_fn = Idx2char(vocab_file_path)
    else:
        map_fn = Idx2word(vocab_file_path)

    for batch, is_new_epoch in dataset:

        # Decode
        best_hyps, att_weights = model.attention_weights(
            batch['xs'], batch['x_lens'], max_decode_len=max_decode_len)
        # NOTE: attention_weights: `[B, T_out, T_in]`

        # Visualize
        for b in range(len(batch['xs'])):

            # Check whether the attention weights sum to 1 per output step
            # print(np.sum(att_weights[b], axis=1))

            str_pred = map_fn(best_hyps[b])
            eos = '>' in str_pred

            str_pred = str_pred.split('>')[0]
            # NOTE: Truncate at the first <EOS>

            # Remove the last space
            if len(str_pred) > 0 and str_pred[-1] == '_':
                str_pred = str_pred[:-1]

            if eos:
                str_pred += '_>'

            speaker = batch['input_names'][b].split('_')[0]
            plot_attention_weights(attention_weights=att_weights[
                b, :len(str_pred.split('_')), :batch['x_lens'][b]],
                                   label_list=str_pred.split('_'),
                                   save_path=mkdir_join(
                                       save_path, speaker,
                                       batch['input_names'][b] + '.png'),
                                   figsize=(20, 8))

        if is_new_epoch:
            break
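
plot_attention_weights, called by both plotting examples, can be assumed to render the [T_out, T_in] weight matrix as a heat map with the output tokens on the y-axis. A minimal matplotlib sketch under that assumption (the spectrogram and str_ref handling seen in the first example are omitted):

import matplotlib
matplotlib.use('Agg')  # write files without a display
import matplotlib.pyplot as plt

def plot_attention_weights_sketch(attention_weights, label_list,
                                  save_path=None, figsize=(20, 8)):
    """Render a [T_out, T_in] attention matrix as a heat map (sketch)."""
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(attention_weights, aspect='auto', origin='lower',
              interpolation='nearest')
    ax.set_yticks(range(len(label_list)))
    ax.set_yticklabels(label_list)
    ax.set_xlabel('Input frames')
    ax.set_ylabel('Output labels')
    if save_path is not None:
        fig.savefig(save_path)
    plt.close(fig)
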