コード例 #1
0
def decode_test(session,
                decode_op,
                network,
                dataset,
                label_type,
                save_path=None):
    """Decode each utterance of a dataset and print reference vs. prediction.

    Args:
        session: session of training model, used to run `decode_op`
        decode_op: operation for decoding
        network: network to evaluate (provides the placeholders that are fed)
        dataset: An instance of a `Dataset` class
        label_type: string, phone39 or phone48 or phone61 or character
        save_path: path to save decoding results (not used by this function)
    """
    # One utterance is decoded per step, so the step count is the data size
    num_utterances = dataset.data_num
    batches = dataset.next_batch(batch_size=1)

    is_character = label_type == 'character'
    if is_character:
        map_file_path = '../metrics/mapping_files/attention/char2num.txt'
    else:
        # Characters 5-6 of the label type select the mapping file,
        # e.g. 'phone61' -> 'phone2num_61.txt'
        phone_suffix = label_type[5:7]
        map_file_path = ('../metrics/mapping_files/attention/phone2num_' +
                         phone_suffix + '.txt')

    for _step in range(num_utterances):
        inputs, labels_true, inputs_seq_len, _, input_names = next(batches)

        # Both keep probabilities are fed as 1.0 while decoding
        labels_pred = session.run(decode_op, feed_dict={
            network.inputs: inputs,
            network.labels: labels_true,
            network.inputs_seq_len: inputs_seq_len,
            network.keep_prob_input: 1.0,
            network.keep_prob_hidden: 1.0
        })

        print('----- wav: %s -----' % input_names[0])

        # Pick the index-to-string converter matching the label type
        to_text = num2char if is_character else num2phone

        # The first and last reference labels are dropped before printing
        # (presumably start/end-of-sequence markers -- TODO confirm)
        reference = to_text(labels_true[0][1:-1], map_file_path)
        print('True: %s' % reference)

        # Any '>' symbols are stripped from the decoded hypothesis
        hypothesis = to_text(labels_pred[0], map_file_path).replace('>', '')
        print('Pred: %s' % hypothesis)
コード例 #2
0
    def check_loading(self, num_gpu, is_sorted):
        """Load mini-batches from the training `Dataset` and sanity-check them.

        Builds a multitask `Dataset` (character main labels, phone61 sub
        labels), pulls mini-batches from it inside a TensorFlow session and,
        in the single-GPU case, verifies that no input is shorter than its
        label sequence.

        Args:
            num_gpu: number of GPUs; forwarded to `Dataset` and used to
                compute the number of iterations per epoch
            is_sorted: forwarded to `Dataset` as `is_sorted`

        Raises:
            ValueError: (only when `num_gpu == 1`) if an input is shorter
                than its character or phone label list
        """
        print('----- num_gpu: ' + str(num_gpu) + ', is_sorted: ' +
              str(is_sorted) + ' -----')

        batch_size = 64
        dataset = Dataset(data_type='train',
                          label_type_main='character',
                          label_type_sub='phone61',
                          batch_size=batch_size,
                          num_stack=3,
                          num_skip=3,
                          is_sorted=is_sorted,
                          is_progressbar=True,
                          num_gpu=num_gpu)

        # Start from a clean default graph before opening the session
        tf.reset_default_graph()
        with tf.Session().as_default() as sess:
            print('=> Loading mini-batch...')
            map_file_path_char = '../metrics/mapping_files/ctc/char2num.txt'
            map_file_path_phone = '../metrics/mapping_files/ctc/phone2num_61.txt'

            # Make data generator
            mini_batch = dataset.next_batch(session=sess)

            iter_per_epoch = int(dataset.data_num / (batch_size * num_gpu)) + 1
            # NOTE(review): one more step than `iter_per_epoch` is run --
            # presumably to exercise the transition into the next epoch; confirm
            for i in range(iter_per_epoch + 1):
                # Element 0: inputs, 1: character labels, 2: phone labels
                return_tuple = mini_batch.__next__()
                inputs = return_tuple[0]
                labels_char_st = return_tuple[1]
                labels_phone_st = return_tuple[2]

                if num_gpu > 1:
                    # Print the shape of every element of `inputs`
                    # (presumably one element per GPU -- TODO confirm)
                    for inputs_gpu in inputs:
                        print(inputs_gpu.shape)
                    # Only the first element of each label container is
                    # inspected below
                    labels_char_st = labels_char_st[0]
                    labels_phone_st = labels_phone_st[0]

                # NOTE(review): `sparsetensor2list` presumably turns a sparse
                # label tensor into per-utterance label lists; with
                # num_gpu > 1, `len(inputs)` is the length of the container
                # iterated above rather than of a single batch -- confirm
                labels_char = sparsetensor2list(labels_char_st,
                                                batch_size=len(inputs))
                labels_phone = sparsetensor2list(labels_phone_st,
                                                 batch_size=len(inputs))

                if num_gpu == 1:
                    # Each input must be at least as long as its labels;
                    # print both lengths before failing to ease debugging
                    for inputs_i, labels_i in zip(inputs, labels_char):
                        if len(inputs_i) < len(labels_i):
                            print(len(inputs_i))
                            print(len(labels_i))
                            raise ValueError
                    for inputs_i, labels_i in zip(inputs, labels_phone):
                        if len(inputs_i) < len(labels_i):
                            print(len(inputs_i))
                            print(len(labels_i))
                            raise ValueError

                # Convert the first utterance's labels to strings via the
                # mapping files; underscores in the character string are
                # replaced with spaces
                str_true_char = num2char(labels_char[0], map_file_path_char)
                str_true_char = re.sub(r'_', ' ', str_true_char)
                str_true_phone = num2phone(labels_phone[0],
                                           map_file_path_phone)
コード例 #3
0
def decode_test(session, decode_op, network, dataset, label_type,
                save_path=None):
    """Visualize label outputs of CTC model.

    Every utterance of `dataset` is decoded one at a time and the reference
    and predicted transcriptions are printed.

    Args:
        session: session of training model
        decode_op: operation for decoding
        network: network to evaluate
        dataset: An instance of a `Dataset` class
        label_type: string, kanji or kana or phone. For any other value the
            utterances are still decoded but nothing is printed.
        save_path: if not None, the output is written to `decode.txt` under
            `network.model_dir` instead of the current stdout.
            NOTE: the value is only compared against None; it is not used
            as the output location.
    """
    # Batch size is expected to be 1
    iteration = dataset.data_num

    # Make data generator
    mini_batch = dataset.next_batch(batch_size=1)

    if label_type == 'kanji':
        map_file_path = '../metrics/mapping_files/ctc/kanji2num.txt'
    elif label_type == 'kana':
        map_file_path = '../metrics/mapping_files/ctc/kana2num.txt'
    elif label_type == 'phone':
        map_file_path = '../metrics/mapping_files/ctc/phone2num.txt'

    # Redirect stdout to the result file only for the duration of this call;
    # the previous stream is restored and the file closed in `finally`, so
    # the caller's later prints are unaffected and no handle is leaked.
    stdout_backup = sys.stdout
    decode_file = None
    if save_path is not None:
        decode_file = open(join(network.model_dir, 'decode.txt'), 'w')
        sys.stdout = decode_file

    try:
        for _step in range(iteration):
            # Create feed dictionary for next mini batch
            inputs, labels_true, inputs_seq_len, input_names = next(
                mini_batch)
            # NOTE: labels_true is expected to be a list of string when
            # evaluation using dataset where label_type is kanji or kana

            feed_dict = {
                network.inputs: inputs,
                network.inputs_seq_len: inputs_seq_len,
                network.keep_prob_input: 1.0,
                network.keep_prob_hidden: 1.0
            }

            # Visualize
            labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
            labels_pred = sparsetensor2list(labels_pred_st, batch_size=1)

            if label_type in ['kanji', 'kana']:
                print('----- wav: %s -----' % input_names[0])
                # The reference is printed as-is (see NOTE above)
                print('True: %s' % labels_true[0])
                print('Pred: %s' % num2char(labels_pred[0], map_file_path))

            elif label_type == 'phone':
                print('----- wav: %s -----' % input_names[0])
                print('True: %s' % num2phone(labels_true[0], map_file_path))
                print('Pred: %s' % num2phone(labels_pred[0], map_file_path))
    finally:
        if decode_file is not None:
            sys.stdout = stdout_backup
            decode_file.close()
コード例 #4
0
def decode_test_multitask(session, decode_op_main, decode_op_second, network,
                          dataset, label_type_second, save_path=None):
    """Visualize label outputs of Multi-task CTC model.

    Runs two passes of `dataset.data_num` steps each: the first decodes with
    `decode_op_main` and prints character transcriptions, the second decodes
    with `decode_op_second` and prints phone transcriptions.

    Args:
        session: session of training model
        decode_op_main: operation for decoding in the main task
        decode_op_second: operation for decoding in the second task
        network: network to evaluate
        dataset: An instance of a `Dataset` class
        label_type_second: string, phone39 or phone48 or phone61
        save_path: if not None, the output is written to `decode.txt` under
            `network.model_dir` instead of the current stdout.
            NOTE: the value is only compared against None; it is not used
            as the output location.
    """
    # Batch size is expected to be 1
    iteration = dataset.data_num

    # Make data generator
    mini_batch = dataset.next_batch(batch_size=1)

    # Redirect stdout to the result file only for the duration of this call;
    # the previous stream is restored and the file closed in `finally`, so
    # the caller's later prints are unaffected and no handle is leaked.
    stdout_backup = sys.stdout
    decode_file = None
    if save_path is not None:
        decode_file = open(join(network.model_dir, 'decode.txt'), 'w')
        sys.stdout = decode_file

    try:
        # Decode character
        print('===== character =====')
        map_file_path = '../metrics/mapping_files/ctc/char2num.txt'
        for _step in range(iteration):
            # Create feed dictionary for next mini batch
            # (element 1 holds the main-task labels, element 2 is ignored)
            inputs, labels_true_st, _, inputs_seq_len, input_names = next(
                mini_batch)

            feed_dict = {
                network.inputs: inputs,
                network.inputs_seq_len: inputs_seq_len,
                network.keep_prob_input: 1.0,
                network.keep_prob_hidden: 1.0
            }

            # Visualize
            labels_pred_st = session.run(decode_op_main, feed_dict=feed_dict)
            labels_true = sparsetensor2list(labels_true_st, batch_size=1)
            labels_pred = sparsetensor2list(labels_pred_st, batch_size=1)

            print('----- wav: %s -----' % input_names[0])
            print('True: %s' % num2char(labels_true[0], map_file_path))
            print('Pred: %s' % num2char(labels_pred[0], map_file_path))

        # Decode phone
        print('\n===== phone =====')
        # Characters 5-6 of the label type select the mapping file,
        # e.g. 'phone61' -> 'phone2num_61.txt'
        map_file_path = '../metrics/mapping_files/ctc/phone2num_' + \
            label_type_second[5:7] + '.txt'
        # NOTE(review): the same generator is reused, so this pass consumes
        # the next `iteration` mini-batches -- this assumes the generator
        # keeps yielding after one pass over the data; confirm
        for _step in range(iteration):
            # Create feed dictionary for next mini batch
            # (element 2 holds the second-task labels, element 1 is ignored)
            inputs, _, labels_true_st, inputs_seq_len, input_names = next(
                mini_batch)

            feed_dict = {
                network.inputs: inputs,
                network.inputs_seq_len: inputs_seq_len,
                network.keep_prob_input: 1.0,
                network.keep_prob_hidden: 1.0
            }

            # Visualize
            labels_pred_st = session.run(decode_op_second,
                                         feed_dict=feed_dict)
            labels_true = sparsetensor2list(labels_true_st, batch_size=1)
            labels_pred = sparsetensor2list(labels_pred_st, batch_size=1)

            print('----- wav: %s -----' % input_names[0])
            print('True: %s' % num2phone(labels_true[0], map_file_path))
            print('Pred: %s' % num2phone(labels_pred[0], map_file_path))
    finally:
        if decode_file is not None:
            sys.stdout = stdout_backup
            decode_file.close()
def do_eval_cer(session,
                decode_op,
                network,
                dataset,
                eval_batch_size=None,
                is_progressbar=False):
    """Evaluate trained model by Character Error Rate.

    Each utterance's CER is the Levenshtein distance between the cleaned
    prediction and the cleaned reference, divided by the reference length.
    The per-utterance values are summed and divided by `dataset.data_num`.

    Args:
        session: session of training model
        decode_op: operation for decoding
        network: network to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size: int, batch size when evaluating the model.
            If None, `dataset.batch_size` is used.
        is_progressbar: if True, visualize the progressbar
    Return:
        cer_mean: An average of CER
    """
    batch_size = (dataset.batch_size
                  if eval_batch_size is None else eval_batch_size)

    # Make data generator
    mini_batch = dataset.next_batch(batch_size=batch_size)

    num_examples = dataset.data_num
    # Ceiling division: the last mini-batch may hold fewer utterances
    iteration = (num_examples + batch_size - 1) // batch_size
    cer_sum = 0

    map_file_path = '../metrics/mapping_files/attention/char2num.txt'

    # Symbols stripped before scoring: _ < > , . ' ? ! and -
    # The hyphen is placed last so that it is a literal character; written
    # as "'-?" it formed a range that also removed digits and ()*+/:;=
    remove_pattern = re.compile(r'[_<>,.\'?!-]+')

    for _step in wrap_iterator(range(iteration), is_progressbar):
        # Create feed dictionary for next mini-batch
        # (only elements 0, 1 and 3 of the returned 6-tuple are used)
        inputs, att_labels_true, _, inputs_seq_len, _, _ = next(mini_batch)

        feed_dict = {
            network.inputs: inputs,
            network.inputs_seq_len: inputs_seq_len,
            network.keep_prob_input: 1.0,
            network.keep_prob_hidden: 1.0
        }

        # The actual size of this mini-batch (may be smaller than batch_size)
        batch_size_each = len(inputs_seq_len)

        predicted_ids = session.run(decode_op, feed_dict=feed_dict)
        for i_batch in range(batch_size_each):

            # Convert from list to string
            str_true = num2char(att_labels_true[i_batch], map_file_path)
            str_pred = num2char(predicted_ids[i_batch], map_file_path)

            # Remove silence(_), sequence markers and punctuation
            str_true = remove_pattern.sub("", str_true)
            str_pred = remove_pattern.sub("", str_pred)

            # Compute edit distance, normalized by the reference length.
            # An empty reference would divide by zero, so the denominator is
            # at least 1 (each predicted character then counts as one error).
            cer_sum += (Levenshtein.distance(str_pred, str_true) /
                        max(len(str_true), 1))

    cer_mean = cer_sum / num_examples

    return cer_mean
コード例 #6
0
def do_eval_cer(session,
                decode_op,
                network,
                dataset,
                label_type,
                is_test=None,
                eval_batch_size=None,
                is_progressbar=False,
                is_multitask=False,
                is_main=False):
    """Evaluate trained model by Character Error Rate.

    Each utterance's CER is the Levenshtein distance between the cleaned
    prediction and the cleaned reference, divided by the reference length.
    The per-utterance values are summed and divided by `dataset.data_num`.

    Args:
        session: session of training model
        decode_op: operation for decoding
        network: network to evaluate
        dataset: An instance of `Dataset` class
        label_type: string, kanji or kana or phone
        is_test: bool, set to True when evaluating by the test set
        eval_batch_size: int, the batch size when evaluating the model.
            If None, `dataset.batch_size` is used.
        is_progressbar: if True, visualize progressbar
        is_multitask: if True, evaluate the multitask model
        is_main: if True, evaluate the main task
    Return:
        cer_mean: An average of CER
    Raises:
        ValueError: if `label_type` is not kanji, kana or phone
    """
    # Fail early with a clear message instead of hitting an unbound
    # `map_file_path` in the middle of the evaluation loop
    if label_type == 'kanji':
        map_file_path = '../metrics/mapping_files/ctc/kanji2num.txt'
    elif label_type == 'kana':
        map_file_path = '../metrics/mapping_files/ctc/kana2num.txt'
    elif label_type == 'phone':
        # Assignment (the previous `==` comparison raised UnboundLocalError)
        map_file_path = '../metrics/mapping_files/ctc/phone2num.txt'
    else:
        raise ValueError(
            'label_type must be kanji, kana or phone: %r' % (label_type,))

    batch_size = (dataset.batch_size
                  if eval_batch_size is None else eval_batch_size)

    num_examples = dataset.data_num
    # Ceiling division: the last mini-batch may hold fewer utterances
    iteration = (num_examples + batch_size - 1) // batch_size
    cer_sum = 0

    # Make data generator
    mini_batch = dataset.next_batch(batch_size=batch_size)

    # Characters stripped before scoring: silence(_), noise(NZ) and the
    # long-vowel mark
    remove_pattern = re.compile(r'[_NZー]+')

    for _step in wrap_iterator(range(iteration), is_progressbar):
        # Create feed dictionary for next mini batch
        if not is_multitask:
            inputs, labels_true, inputs_seq_len, _ = next(mini_batch)
        elif is_main:
            # Multitask, main task: labels are element 1 of the 5-tuple
            inputs, labels_true, _, inputs_seq_len, _ = next(mini_batch)
        else:
            # Multitask, second task: labels are element 2 of the 5-tuple
            inputs, _, labels_true, inputs_seq_len, _ = next(mini_batch)

        feed_dict = {
            network.inputs: inputs,
            network.inputs_seq_len: inputs_seq_len,
            network.keep_prob_input: 1.0,
            network.keep_prob_hidden: 1.0
        }

        # The actual size of this mini-batch (may be smaller than batch_size)
        batch_size_each = len(inputs_seq_len)

        labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
        labels_pred = sparsetensor2list(labels_pred_st, batch_size_each)

        for i_batch in range(batch_size_each):
            # Convert from list to string
            if label_type != 'phone' and is_test:
                # NOTE: for kanji and kana, the test-set labels are stored
                # as they are, so they are only joined into one string
                str_true = ''.join(labels_true[i_batch])
            else:
                str_true = num2char(labels_true[i_batch], map_file_path)
            str_pred = num2char(labels_pred[i_batch], map_file_path)

            # Remove silence(_) & noise(NZ) labels
            str_true = remove_pattern.sub("", str_true)
            str_pred = remove_pattern.sub("", str_pred)

            # Compute edit distance, normalized by the reference length.
            # An empty reference would divide by zero, so the denominator is
            # at least 1 (each predicted character then counts as one error).
            cer_sum += (Levenshtein.distance(str_pred, str_true) /
                        max(len(str_true), 1))

    cer_mean = cer_sum / num_examples

    return cer_mean