def check(self,
              label_type,
              data_type='dev',
              shuffle=False,
              sort_utt=False,
              sort_stop_epoch=None,
              frame_stacking=False,
              splice=1):

        print('========================================')
        print('  label_type: %s' % label_type)
        print('  data_type: %s' % data_type)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print('========================================')

        num_stack = 3 if frame_stacking else 1
        num_skip = 3 if frame_stacking else 1
        dataset = Dataset(data_type=data_type,
                          label_type=label_type,
                          batch_size=64,
                          max_epoch=2,
                          splice=splice,
                          num_stack=num_stack,
                          num_skip=num_skip,
                          shuffle=shuffle,
                          sort_utt=sort_utt,
                          sort_stop_epoch=sort_stop_epoch,
                          progressbar=True)

        print('=> Loading mini-batch...')
        map_fn = Idx2phone(map_file_path='../../metrics/mapping_files/' +
                           label_type + '.txt')

        for data, is_new_epoch in dataset:
            inputs, labels, inputs_seq_len, input_names = data

            if data_type == 'train':
                for i_batch, l_batch in zip(inputs[0], labels[0]):
                    # Effective label length: position of the first padded
                    # value, or the full label length if there is no padding
                    pad_pos = np.where(l_batch == dataset.padded_value)[0]
                    label_len = pad_pos[0] if len(pad_pos) > 0 else len(l_batch)
                    if i_batch.shape[0] < label_len:
                        raise ValueError(
                            'input length must be longer than label length.')

            str_true = map_fn(labels[0][0])

            print('----- %s ----- (epoch: %.3f)' %
                  (input_names[0][0], dataset.epoch_detail))
            print(inputs[0][0].shape)
            print(str_true)
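
# --- Illustration (not part of the original tests) ---
# A minimal sketch of the padded-label length check performed above, with
# toy numpy arrays. The padded value (-1 here) is an assumption; the real
# one comes from the Dataset instance.
import numpy as np

def label_length(l_batch, padded_value=-1):
    """Return the effective label length, ignoring right padding."""
    pad_pos = np.where(l_batch == padded_value)[0]
    return pad_pos[0] if len(pad_pos) > 0 else len(l_batch)

assert label_length(np.array([3, 7, 2, -1, -1])) == 3  # 3 valid labels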
    def check_loading(self, label_type, data_type='dev',
                      shuffle=False, sort_utt=False, sort_stop_epoch=None,
                      frame_stacking=False, splice=1):

        print('========================================')
        print('  label_type: %s' % label_type)
        print('  data_type: %s' % data_type)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print('========================================')

        num_stack = 3 if frame_stacking else 1
        num_skip = 3 if frame_stacking else 1
        dataset = Dataset(
            data_type=data_type, label_type=label_type,
            batch_size=64, eos_index=1, max_epoch=2, splice=splice,
            num_stack=num_stack, num_skip=num_skip,
            shuffle=shuffle,
            sort_utt=sort_utt, sort_stop_epoch=sort_stop_epoch,
            progressbar=True)

        print('=> Loading mini-batch...')
        if label_type in ['character', 'character_capital_divide']:
            map_fn_ctc = Idx2char(
                map_file_path='../../metrics/mapping_files/ctc/' + label_type + '.txt')
            map_fn_att = Idx2char(
                map_file_path='../../metrics/mapping_files/attention/' + label_type + '.txt')
        else:
            map_fn_ctc = Idx2phone(
                map_file_path='../../metrics/mapping_files/ctc/' + label_type + '.txt')
            map_fn_att = Idx2phone(
                map_file_path='../../metrics/mapping_files/attention/' + label_type + '.txt')

        for data, is_new_epoch in dataset:
            inputs, att_labels, ctc_labels, inputs_seq_len, att_labels_seq_len, input_names = data

            att_str_true = map_fn_att(att_labels[0][0: att_labels_seq_len[0]])
            ctc_str_true = map_fn_ctc(ctc_labels[0])
            att_str_true = re.sub(r'_', ' ', att_str_true)
            ctc_str_true = re.sub(r'_', ' ', ctc_str_true)
            print('----- %s ----- (epoch: %.3f)' %
                  (input_names[0], dataset.epoch_detail))
            print(att_str_true)
            print(ctc_str_true)
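
# --- Illustration (not part of the original tests) ---
# A hedged sketch of what an Idx2phone-style lookup might do. The real
# class lives in the project's utils; the mapping-file format assumed here
# (one "<phone> <index>" pair per line) may differ from the actual one.
def make_idx2phone(map_file_path):
    idx2phone = {}
    with open(map_file_path) as f:
        for line in f:
            phone, idx = line.strip().split()
            idx2phone[int(idx)] = phone
    return lambda indices: ' '.join(idx2phone[int(i)] for i in indices)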
def plot_attention(model,
                   dataset,
                   eval_batch_size,
                   beam_width,
                   length_penalty,
                   save_path=None):
    """Visualize attention weights of the attetnion-based model.
    Args:
        model: model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model
        beam_width (int): the size of beam
        length_penalty (float):
        save_path (string, optional): path to save attention weights plotting
    """
    # Clean (or create) the output directory
    if save_path is not None:
        if isdir(save_path):
            shutil.rmtree(save_path)
        mkdir(save_path)

    idx2phone = Idx2phone(dataset.vocab_file_path)

    for batch, is_new_epoch in dataset:

        # Decode
        best_hyps, aw, perm_idx = model.decode(
            batch['xs'],
            batch['x_lens'],
            beam_width=beam_width,
            max_decode_len=MAX_DECODE_LEN_PHONE,
            length_penalty=length_penalty)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by spaces (' ')
            else:
                # Convert from list of index to string
                str_ref = idx2phone(ys[b][:y_lens[b]])

            token_list = idx2phone(best_hyps[b])

            plot_attention_weights(
                aw[b][:len(token_list), :batch['x_lens'][b]],
                label_list=token_list,
                spectrogram=batch['xs'][b, :, :40],
                str_ref=str_ref,
                save_path=join(save_path, batch['input_names'][b] + '.png'),
                figsize=(20, 8))

        if is_new_epoch:
            break
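
# Hedged usage sketch for plot_attention; constructing the model and
# dataset is project-specific, so these names are placeholders:
#
#   plot_attention(model, dataset, eval_batch_size=1, beam_width=4,
#                  length_penalty=0, save_path='./att_weights')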
    def check(self,
              label_type_main,
              data_type='dev',
              shuffle=False,
              sort_utt=False,
              sort_stop_epoch=None,
              frame_stacking=False,
              splice=1):

        print('========================================')
        print('  label_type_main: %s' % label_type_main)
        print('  data_type: %s' % data_type)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print('========================================')

        num_stack = 3 if frame_stacking else 1
        num_skip = 3 if frame_stacking else 1
        dataset = Dataset(data_type=data_type,
                          label_type_main=label_type_main,
                          label_type_sub='phone61',
                          batch_size=64,
                          max_epoch=2,
                          splice=splice,
                          num_stack=num_stack,
                          num_skip=num_skip,
                          shuffle=shuffle,
                          sort_utt=sort_utt,
                          sort_stop_epoch=sort_stop_epoch,
                          progressbar=True)

        print('=> Loading mini-batch...')
        idx2char = Idx2char(map_file_path='../../metrics/mapping_files/' +
                            label_type_main + '.txt')
        idx2phone = Idx2phone(
            map_file_path='../../metrics/mapping_files/phone61.txt')

        for data, is_new_epoch in dataset:
            inputs, labels_char, labels_phone, inputs_seq_len, input_names = data

            if data_type != 'test':
                str_true_char = idx2char(labels_char[0][0])
                str_true_phone = idx2phone(labels_phone[0][0])
            else:
                str_true_char = labels_char[0][0][0]
                str_true_phone = labels_phone[0][0][0]

            print('----- %s ----- (epoch: %.3f)' %
                  (input_names[0][0], dataset.epoch_detail))
            print(str_true_char)
            print(str_true_phone)
Example #5
    def check(self,
              label_type,
              data_type='dev',
              shuffle=False,
              sort_utt=False,
              sort_stop_epoch=None,
              frame_stacking=False,
              splice=1):

        print('========================================')
        print('  label_type: %s' % label_type)
        print('  data_type: %s' % data_type)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print('========================================')

        map_file_path = '../../metrics/mapping_files/' + label_type + '.txt'

        num_stack = 3 if frame_stacking else 1
        num_skip = 3 if frame_stacking else 1
        dataset = Dataset(data_type=data_type,
                          label_type=label_type,
                          batch_size=64,
                          map_file_path=map_file_path,
                          max_epoch=1,
                          splice=splice,
                          num_stack=num_stack,
                          num_skip=num_skip,
                          shuffle=shuffle,
                          sort_utt=sort_utt,
                          sort_stop_epoch=sort_stop_epoch,
                          progressbar=True)

        print('=> Loading mini-batch...')
        map_fn = Idx2phone(map_file_path=map_file_path)

        for data, is_new_epoch in dataset:
            inputs, labels, inputs_seq_len, labels_seq_len, input_names = data

            str_true = map_fn(labels[0][0][0:labels_seq_len[0][0]])

            print('----- %s ----- (epoch: %.3f)' %
                  (input_names[0][0], dataset.epoch_detail))
            print(inputs[0][0].shape)
            print(str_true)
    def check(self, encoder_type, lstm_impl, time_major=False):

        print('==================================================')
        print('  encoder_type: %s' % str(encoder_type))
        print('  lstm_impl: %s' % str(lstm_impl))
        print('  time_major: %s' % str(time_major))
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 2
            inputs, labels_char, labels_phone, inputs_seq_len = generate_data(
                label_type='multitask', model='ctc', batch_size=batch_size)

            # Define model graph
            num_classes_main = 27
            num_classes_sub = 61
            model = MultitaskCTC(
                encoder_type=encoder_type,
                input_size=inputs[0].shape[1],
                num_units=256,
                num_layers_main=2,
                num_layers_sub=1,
                num_classes_main=num_classes_main,
                num_classes_sub=num_classes_sub,
                main_task_weight=0.8,
                lstm_impl=lstm_impl,
                parameter_init=0.1,
                clip_grad_norm=5.0,
                clip_activation=50,
                num_proj=256,
                weight_decay=1e-8,
                # bottleneck_dim=50,
                bottleneck_dim=None,
                time_major=time_major)

            # Define placeholders
            model.create_placeholders()
            learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

            # Add to the graph each operation
            loss_op, logits_main, logits_sub = model.compute_loss(
                model.inputs_pl_list[0], model.labels_pl_list[0],
                model.labels_sub_pl_list[0], model.inputs_seq_len_pl_list[0],
                model.keep_prob_pl_list[0])
            train_op = model.train(loss_op,
                                   optimizer='adam',
                                   learning_rate=learning_rate_pl)
            decode_op_main, decode_op_sub = model.decoder(
                logits_main,
                logits_sub,
                model.inputs_seq_len_pl_list[0],
                beam_width=20)
            ler_op_main, ler_op_sub = model.compute_ler(
                decode_op_main, decode_op_sub, model.labels_pl_list[0],
                model.labels_sub_pl_list[0])

            # Define learning rate controller
            learning_rate = 1e-3
            lr_controller = Controller(learning_rate_init=learning_rate,
                                       decay_start_epoch=20,
                                       decay_rate=0.9,
                                       decay_patient_epoch=5,
                                       lower_better=True)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()), "{:,}".format(
                      total_parameters / 1000000)))

            # Make feed dict
            feed_dict = {
                model.inputs_pl_list[0]:
                inputs,
                model.labels_pl_list[0]:
                list2sparsetensor(labels_char, padded_value=-1),
                model.labels_sub_pl_list[0]:
                list2sparsetensor(labels_phone, padded_value=-1),
                model.inputs_seq_len_pl_list[0]:
                inputs_seq_len,
                model.keep_prob_pl_list[0]:
                0.9,
                learning_rate_pl:
                learning_rate
            }

            idx2phone = Idx2phone(map_file_path='./phone61.txt')

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Wrapper for tfdbg
                # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

                # Train model
                max_steps = 1000
                start_time_step = time.time()
                for step in range(max_steps):

                    # Compute loss
                    _, loss_train = sess.run([train_op, loss_op],
                                             feed_dict=feed_dict)

                    # Gradient check
                    # grads = sess.run(model.clipped_grads,
                    #                  feed_dict=feed_dict)
                    # for grad in grads:
                    #     print(np.max(grad))

                    if (step + 1) % 10 == 0:
                        # Change to evaluation mode
                        feed_dict[model.keep_prob_pl_list[0]] = 1.0

                        # Compute accuracy
                        ler_train_char, ler_train_phone = sess.run(
                            [ler_op_main, ler_op_sub], feed_dict=feed_dict)

                        duration_step = time.time() - start_time_step
                        print(
                            'Step %d: loss = %.3f / cer = %.3f / per = %.3f (%.3f sec) / lr = %.5f'
                            % (step + 1, loss_train, ler_train_char,
                               ler_train_phone, duration_step, learning_rate))
                        start_time_step = time.time()

                        # Visualize
                        labels_pred_char_st, labels_pred_phone_st = sess.run(
                            [decode_op_main, decode_op_sub],
                            feed_dict=feed_dict)
                        labels_pred_char = sparsetensor2list(
                            labels_pred_char_st, batch_size=batch_size)
                        labels_pred_phone = sparsetensor2list(
                            labels_pred_phone_st, batch_size=batch_size)

                        print('Character')
                        try:
                            print('  Ref: %s' % idx2alpha(labels_char[0]))
                            print('  Hyp: %s' % idx2alpha(labels_pred_char[0]))
                        except IndexError:
                            print('  Ref: %s' % idx2alpha(labels_char[0]))
                            print('  Hyp: %s' % '')
                            # NOTE: This is for no prediction

                        print('Phone')
                        try:
                            print('  Ref: %s' % idx2phone(labels_phone[0]))
                            print('  Hyp: %s' %
                                  idx2phone(labels_pred_phone[0]))
                        except IndexError:
                            print('  Ref: %s' % idx2phone(labels_phone[0]))
                            print('  Hyp: %s' % '')
                            # NOTE: This is for no prediction
                        print('-' * 30)

                        if ler_train_char < 0.1:
                            print('Model has converged.')
                            break

                        # Update learning rate
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=step,
                            value=ler_train_char)
                        feed_dict[learning_rate_pl] = learning_rate
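
# --- Illustration (not part of the original tests) ---
# A hedged sketch of the decay policy the Controller above appears to
# implement: hold the learning rate until decay_start_epoch, then decay it
# by decay_rate once the monitored value stops improving for
# decay_patient_epoch epochs. The project's actual Controller may differ.
class ControllerSketch(object):
    def __init__(self, learning_rate_init, decay_start_epoch, decay_rate,
                 decay_patient_epoch, lower_better=True):
        self.decay_start_epoch = decay_start_epoch
        self.decay_rate = decay_rate
        self.decay_patient_epoch = decay_patient_epoch
        self.lower_better = lower_better
        self.best = float('inf') if lower_better else float('-inf')
        self.patience = 0

    def decay_lr(self, learning_rate, epoch, value):
        improved = value < self.best if self.lower_better else value > self.best
        if improved:
            self.best = value
            self.patience = 0
        elif epoch >= self.decay_start_epoch:
            self.patience += 1
            if self.patience > self.decay_patient_epoch:
                self.patience = 0
                return learning_rate * self.decay_rate
        return learning_rate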
def decode(session,
           decode_op_main,
           decode_op_sub,
           model,
           dataset,
           label_type_main,
           label_type_sub,
           is_test=True,
           save_path=None):
    """Visualize label outputs of Multi-task CTC model.
    Args:
        session: session of training model
        decode_op_main: operation for decoding in the main task
        decode_op_sub: operation for decoding in the sub task
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type_main (string): character or character_capital_divide
        label_type_sub (string): phone39 or phone48 or phone61
        is_test (bool, optional):
        save_path (string, optional): path to save decoding results
    """
    idx2char = Idx2char(map_file_path='../metrics/mapping_files/' +
                        label_type_main + '.txt')
    idx2phone = Idx2phone(map_file_path='../metrics/mapping_files/' +
                          label_type_sub + '.txt')

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for data, is_new_epoch in dataset:

        # Create feed dictionary for next mini batch
        inputs, labels_true_char, labels_true_phone, inputs_seq_len, input_names = data

        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_pl_list[0]: 1.0
        }

        batch_size = inputs[0].shape[0]
        labels_pred_char_st, labels_pred_phone_st = session.run(
            [decode_op_main, decode_op_sub], feed_dict=feed_dict)
        try:
            labels_pred_char = sparsetensor2list(labels_pred_char_st,
                                                 batch_size=batch_size)
        except IndexError:
            # no output
            labels_pred_char = ['']
        try:
            labels_pred_phone = sparsetensor2list(labels_pred_phone_st,
                                                  batch_size=batch_size)
        except IndexError:
            # no output
            labels_pred_phone = ['']

        for i_batch in range(batch_size):
            print('----- wav: %s -----' % input_names[0][i_batch])

            if is_test:
                str_true_char = labels_true_char[0][i_batch][0].replace(
                    '_', ' ')
                str_true_phone = labels_true_phone[0][i_batch][0]
            else:
                str_true_char = idx2char(labels_true_char[0][i_batch])
                str_true_phone = idx2phone(labels_true_phone[0][i_batch])

            str_pred_char = idx2char(labels_pred_char[i_batch])
            str_pred_phone = idx2phone(labels_pred_phone[i_batch])

            print('Ref (char): %s' % str_true_char)
            print('Hyp (char): %s' % str_pred_char)
            print('Ref (phone): %s' % str_true_phone)
            print('Hyp (phone): %s' % str_pred_phone)

        if is_new_epoch:
            break
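
# --- Illustration (not part of the original tests) ---
# A hedged sketch of what sparsetensor2list does: regroup the flat
# (indices, values) pairs of a decoded tf.SparseTensorValue into one label
# list per utterance. The project's helper may handle edge cases
# differently (e.g. empty hypotheses raising IndexError, as caught above).
def sparsetensor2list_sketch(labels_st, batch_size):
    labels = [[] for _ in range(batch_size)]
    for (row, _), value in zip(labels_st.indices, labels_st.values):
        labels[row].append(value)
    return labels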
Example #8
    def check_training(self,
                       encoder_type,
                       label_type,
                       lstm_impl='LSTMBlockCell',
                       save_params=False):

        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  label_type: %s' % label_type)
        print('  lstm_impl: %s' % lstm_impl)
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 1
            splice = 11 if encoder_type in [
                'vgg_blstm', 'vgg_lstm', 'vgg_wang', 'resnet_wang', 'cnn_zhang'
            ] else 1
            inputs, labels_true_st, inputs_seq_len = generate_data(
                label_type=label_type,
                model='ctc',
                batch_size=batch_size,
                splice=splice)
            # NOTE: input_size must be even number when using CudnnLSTM

            # Define model graph
            num_classes = 26 if label_type == 'character' else 61
            model = CTC(
                encoder_type=encoder_type,
                input_size=inputs[0].shape[-1] // splice,
                splice=splice,
                num_units=256,
                num_layers=2,
                num_classes=num_classes,
                lstm_impl=lstm_impl,
                parameter_init=0.1,
                clip_grad=5.0,
                clip_activation=50,
                num_proj=256,
                # bottleneck_dim=50,
                bottleneck_dim=None,
                weight_decay=1e-8)

            # Define placeholders
            model.create_placeholders()
            learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

            # Add to the graph each operation
            loss_op, logits = model.compute_loss(
                model.inputs_pl_list[0], model.labels_pl_list[0],
                model.inputs_seq_len_pl_list[0],
                model.keep_prob_input_pl_list[0],
                model.keep_prob_hidden_pl_list[0],
                model.keep_prob_output_pl_list[0])
            train_op = model.train(loss_op,
                                   optimizer='adam',
                                   learning_rate=learning_rate_pl)
            # NOTE: Adam does not run on CudnnLSTM
            decode_op = model.decoder(logits,
                                      model.inputs_seq_len_pl_list[0],
                                      beam_width=20)
            ler_op = model.compute_ler(decode_op, model.labels_pl_list[0])

            # Define learning rate controller
            learning_rate = 1e-3
            lr_controller = Controller(learning_rate_init=learning_rate,
                                       decay_start_epoch=10,
                                       decay_rate=0.98,
                                       decay_patient_epoch=5,
                                       lower_better=True)

            if save_params:
                # Create a saver for writing training checkpoints
                saver = tf.train.Saver(max_to_keep=None)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            if lstm_impl != 'CudnnLSTM':
                parameters_dict, total_parameters = count_total_parameters(
                    tf.trainable_variables())
                for parameter_name in sorted(parameters_dict.keys()):
                    print("%s %d" %
                          (parameter_name, parameters_dict[parameter_name]))
                print("Total %d variables, %s M parameters" %
                      (len(parameters_dict.keys()), "{:,}".format(
                          total_parameters / 1000000)))

            # Make feed dict
            feed_dict = {
                model.inputs_pl_list[0]: inputs,
                model.labels_pl_list[0]: labels_true_st,
                model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                model.keep_prob_input_pl_list[0]: 1.0,
                model.keep_prob_hidden_pl_list[0]: 1.0,
                model.keep_prob_output_pl_list[0]: 1.0,
                learning_rate_pl: learning_rate
            }

            idx2phone = Idx2phone(map_file_path='./phone61_ctc.txt')

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Wrapper for tfdbg
                # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

                # Train model
                max_steps = 1000
                start_time_global = time.time()
                start_time_step = time.time()
                ler_train_pre = 1
                not_improved_count = 0
                for step in range(max_steps):

                    # Compute loss
                    _, loss_train = sess.run([train_op, loss_op],
                                             feed_dict=feed_dict)

                    # Gradient check
                    # grads = sess.run(model.clipped_grads,
                    #                  feed_dict=feed_dict)
                    # for grad in grads:
                    #     print(np.max(grad))

                    if (step + 1) % 10 == 0:
                        # Change to evaluation mode
                        feed_dict[model.keep_prob_input_pl_list[0]] = 1.0
                        feed_dict[model.keep_prob_hidden_pl_list[0]] = 1.0
                        feed_dict[model.keep_prob_output_pl_list[0]] = 1.0

                        # Compute accuracy
                        ler_train = sess.run(ler_op, feed_dict=feed_dict)

                        duration_step = time.time() - start_time_step
                        print(
                            'Step %d: loss = %.3f / ler = %.3f (%.3f sec) / lr = %.5f'
                            % (step + 1, loss_train, ler_train, duration_step,
                               learning_rate))
                        start_time_step = time.time()

                        # Decode
                        labels_pred_st = sess.run(decode_op,
                                                  feed_dict=feed_dict)
                        labels_true = sparsetensor2list(labels_true_st,
                                                        batch_size=batch_size)

                        # Visualize
                        try:
                            labels_pred = sparsetensor2list(
                                labels_pred_st, batch_size=batch_size)
                            if label_type == 'character':
                                print('Ref: %s' % idx2alpha(labels_true[0]))
                                print('Hyp: %s' % idx2alpha(labels_pred[0]))
                            else:
                                print('Ref: %s' % idx2phone(labels_true[0]))
                                print('Hyp: %s' % idx2phone(labels_pred[0]))

                        except IndexError:
                            if label_type == 'character':
                                print('Ref: %s' % idx2alpha(labels_true[0]))
                                print('Hyp: %s' % '')
                            else:
                                print('Ref: %s' % idx2phone(labels_true[0]))
                                print('Hyp: %s' % '')
                            # NOTE: This is for no prediction

                        if ler_train >= ler_train_pre:
                            not_improved_count += 1
                        else:
                            not_improved_count = 0
                        if ler_train < 0.05:
                            print('Model has converged.')
                            if save_params:
                                # Save model (check point)
                                checkpoint_file = './model.ckpt'
                                save_path = saver.save(sess,
                                                       checkpoint_file,
                                                       global_step=1)
                                print("Model saved in file: %s" % save_path)
                            break
                        ler_train_pre = ler_train

                        # Update learning rate
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=step,
                            value=ler_train)
                        feed_dict[learning_rate_pl] = learning_rate

                duration_global = time.time() - start_time_global
                print('Total time: %.3f sec' % (duration_global))
Example #9
def decode(session,
           decode_op,
           model,
           dataset,
           label_type,
           is_test=False,
           save_path=None):
    """Visualize label outputs of Attention-based model.
    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): phone39 or phone48 or phone61 or character or
            character_capital_divide
        is_test (bool, optional):
        save_path (string, optional): path to save decoding results
    """
    if label_type == 'character':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        map_fn = Idx2char(
            map_file_path=
            '../metrics/mapping_files/character_capital_divide.txt',
            capital_divide=True)
    else:
        map_fn = Idx2phone(map_file_path='../metrics/mapping_files/' +
                           label_type + '.txt')

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for data, is_new_epoch in dataset:

        # Create feed dictionary for next mini batch
        inputs, labels_true, inputs_seq_len, labels_seq_len, input_names = data

        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_encoder_pl_list[0]: 1.0,
            model.keep_prob_decoder_pl_list[0]: 1.0,
            model.keep_prob_embedding_pl_list[0]: 1.0
        }

        batch_size = inputs[0].shape[0]
        labels_pred = session.run(decode_op, feed_dict=feed_dict)
        for i_batch in range(batch_size):
            print('----- wav: %s -----' % input_names[0][i_batch])
            if is_test:
                str_true = labels_true[0][i_batch][0]
            else:
                str_true = map_fn(
                    labels_true[0][i_batch][1:labels_seq_len[0][i_batch] - 1])
                # NOTE: Exclude <SOS> and <EOS>
            str_pred = map_fn(labels_pred[i_batch]).split('>')[0]
            # NOTE: Truncate by <EOS>

            if 'phone' in label_type:
                # Remove the trailing space
                if len(str_pred) > 0 and str_pred[-1] == ' ':
                    str_pred = str_pred[:-1]

            print('Ref: %s' % str_true)
            print('Hyp: %s' % str_pred)

        if is_new_epoch:
            break
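
# --- Illustration (not part of the original tests) ---
# Toy example of the <EOS> truncation above, assuming the <EOS> symbol is
# rendered as a '>'-terminated token by map_fn (an assumption here):
decoded = 'sh iy hh ae d > ih sil'
hypothesis = decoded.split('>')[0]   # -> 'sh iy hh ae d '
hypothesis = hypothesis.rstrip(' ')  # drop the trailing space, as above
assert hypothesis == 'sh iy hh ae d'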
def decode(session,
           decode_op,
           model,
           dataset,
           label_type,
           is_test=True,
           save_path=None):
    """Visualize label outputs of CTC model.
    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): phone39 or phone48 or phone61 or character or
            character_capital_divide
        is_test (bool, optional):
        save_path (string, optional): path to save decoding results
    """
    if label_type == 'character':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        map_fn = Idx2char(
            map_file_path=
            '../metrics/mapping_files/character_capital_divide.txt',
            capital_divide=True)
    else:
        map_fn = Idx2phone(map_file_path='../metrics/mapping_files/' +
                           label_type + '.txt')

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for data, is_new_epoch in dataset:

        # Create feed dictionary for next mini batch
        inputs, labels_true, inputs_seq_len, input_names = data

        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_pl_list[0]: 1.0
        }

        batch_size = inputs[0].shape[0]
        labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
        try:
            labels_pred = sparsetensor2list(labels_pred_st,
                                            batch_size=batch_size)
        except IndexError:
            # no output
            labels_pred = ['']

        for i_batch in range(batch_size):
            print('----- wav: %s -----' % input_names[0][i_batch])
            if is_test:
                str_true = labels_true[0][i_batch][0]
            else:
                str_true = map_fn(labels_true[0][i_batch])
            str_pred = map_fn(labels_pred[i_batch])
            # NOTE: the same conversion applies to both character and
            # phone label types

            print('Ref: %s' % str_true)
            print('Hyp: %s' % str_pred)

        if is_new_epoch:
            break
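
# --- Illustration (not part of the original tests) ---
# A hedged sketch of the greedy collapse a CTC decoder performs internally
# (merge repeated indices, then drop blanks). The code above uses TF's
# beam-search decoder; this is for intuition only, with blank index 0
# assumed.
def ctc_greedy_collapse(frame_ids, blank=0):
    out, prev = [], None
    for i in frame_ids:
        if i != prev and i != blank:
            out.append(i)
        prev = i
    return out

assert ctc_greedy_collapse([0, 3, 3, 0, 5, 5, 0]) == [3, 5]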
Example #11
def do_eval_per(session, decode_op, per_op, model, dataset, label_type,
                is_test=False, eval_batch_size=None, progressbar=False,
                is_multitask=False, is_jointctcatt=False):
    """Evaluate trained model by Phone Error Rate.
    Args:
        session: session of training model
        decode_op: operation for decoding
        per_op: operation for computing phone error rate
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): phone39 or phone48 or phone61
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
        is_jointctcatt (bool, optional): if True, evaluate the joint
            CTC-Attention model
    Returns:
        per_mean (float): An average of PER
    """
    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    train_label_type = label_type
    eval_label_type = dataset.label_type_sub if is_multitask else dataset.label_type

    idx2phone_train = Idx2phone(
        map_file_path='../metrics/mapping_files/' + train_label_type + '.txt')
    idx2phone_eval = Idx2phone(
        map_file_path='../metrics/mapping_files/' + eval_label_type + '.txt')
    map2phone39_train = Map2phone39(
        label_type=train_label_type,
        map_file_path='../metrics/mapping_files/phone2phone.txt')
    map2phone39_eval = Map2phone39(
        label_type=eval_label_type,
        map_file_path='../metrics/mapping_files/phone2phone.txt')

    per_mean = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:

        # Create feed dictionary for next mini-batch
        if is_multitask:
            inputs, _, labels_true, inputs_seq_len, labels_seq_len, _ = data
        elif is_jointctcatt:
            inputs, labels_true, _, inputs_seq_len, labels_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, labels_seq_len, _ = data

        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_encoder_pl_list[0]: 1.0,
            model.keep_prob_decoder_pl_list[0]: 1.0,
            model.keep_prob_embedding_pl_list[0]: 1.0
        }

        batch_size = inputs[0].shape[0]

        # Evaluate by 39 phones
        labels_pred = session.run(decode_op, feed_dict=feed_dict)

        for i_batch in range(batch_size):
            ###############
            # Hypothesis
            ###############
            # Convert from index to phone (-> list of phone strings)
            str_pred = idx2phone_train(labels_pred[i_batch]).split('>')[0]
            # NOTE: Truncate by <EOS>

            # Remove the trailing space
            if len(str_pred) > 0 and str_pred[-1] == ' ':
                str_pred = str_pred[:-1]

            phone_pred_list = str_pred.split(' ')

            ###############
            # Reference
            ###############
            if is_test:
                phone_true_list = labels_true[0][i_batch][0].split(' ')
            else:
                # Convert from index to phone (-> list of phone strings)
                phone_true_list = idx2phone_eval(
                    labels_true[0][i_batch][1:labels_seq_len[0][i_batch] - 1]).split(' ')
                # NOTE: Exclude <SOS> and <EOS>

            # Mapping to 39 phones (-> list of phone strings)
            phone_pred_list = map2phone39_train(phone_pred_list)
            phone_true_list = map2phone39_eval(phone_true_list)

            # Compute PER
            per_mean += compute_per(ref=phone_true_list,
                                    hyp=phone_pred_list,
                                    normalize=True)

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    per_mean /= len(dataset)

    # Register original batch size
    if eval_batch_size is not None:
        dataset.batch_size = batch_size_original

    return per_mean
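
# --- Illustration (not part of the original tests) ---
# A hedged sketch of compute_per as a normalized Levenshtein distance over
# phone lists; the project's implementation may differ in detail.
def compute_per_sketch(ref, hyp, normalize=True):
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,         # deletion
                          d[i][j - 1] + 1,         # insertion
                          d[i - 1][j - 1] + cost)  # substitution
    dist = d[len(ref)][len(hyp)]
    return float(dist) / len(ref) if normalize else dist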
Example #12
def plot(session,
         decode_op,
         attention_weights_op,
         model,
         dataset,
         label_type,
         is_test=False,
         save_path=None,
         show=False):
    """Visualize attention weights of Attetnion-based model.
    Args:
        session: session of training model
        decode_op: operation for decoding
        attention_weights_op: operation for computing attention weights
        model: model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): phone39 or phone48 or phone61 or character or
            character_capital_divide
        is_test (bool, optional):
        save_path (string, optional): path to save attention weights plotting
        show (bool, optional): if True, show each figure
    """
    # Clean (or create) the output directory
    if save_path is not None:
        if isdir(save_path):
            shutil.rmtree(save_path)
        mkdir(save_path)

    if label_type == 'character':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        map_fn = Idx2char(
            map_file_path=
            '../metrics/mapping_files/character_capital_divide.txt',
            capital_divide=True)
    else:
        map_fn = Idx2phone(map_file_path='../metrics/mapping_files/' +
                           label_type + '.txt')

    for data, is_new_epoch in dataset:

        # Create feed dictionary for next mini batch
        inputs, labels_true, inputs_seq_len, _, input_names = data

        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_encoder_pl_list[0]: 1.0,
            model.keep_prob_decoder_pl_list[0]: 1.0,
            model.keep_prob_embedding_pl_list[0]: 1.0
        }

        # Visualize
        batch_size, max_frame_num = inputs.shape[:2]
        attention_weights, labels_pred = session.run(
            [attention_weights_op, decode_op], feed_dict=feed_dict)

        for i_batch in range(batch_size):

            # t_out, t_in = attention_weights[i_batch].shape

            # Check if the sum of attention weights equals to 1
            # print(np.sum(attention_weights[i_batch], axis=1))

            # Convert from index to label
            str_pred = map_fn(labels_pred[i_batch])
            if 'phone' in label_type:
                label_list = str_pred.split(' ')
            else:
                raise NotImplementedError

            plt.clf()
            plt.figure(figsize=(10, 4))
            sns.heatmap(attention_weights[i_batch],
                        cmap='Blues',
                        xticklabels=False,
                        yticklabels=label_list)

            plt.xlabel('Input frames', fontsize=12)
            plt.ylabel('Output labels (top to bottom)', fontsize=12)

            if show:
                plt.show()

            # Save as a png file
            if save_path is not None:
                plt.savefig(join(save_path, input_names[0][i_batch] + '.png'),
                            dpi=500)

        if is_new_epoch:
            break
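
# --- Illustration (not part of the original tests) ---
# The commented-out check inside the loop above: for soft attention, each
# output step's weights should sum to ~1 over the input frames. Toy example:
import numpy as np
aw = np.random.rand(5, 20)           # (output steps, input frames)
aw /= aw.sum(axis=1, keepdims=True)  # normalize as a softmax would
assert np.allclose(aw.sum(axis=1), 1.0)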
Example #13
def decode_test_multitask(session,
                          decode_op_main,
                          decode_op_sub,
                          model,
                          dataset,
                          label_type_main,
                          label_type_sub,
                          save_path=None):
    """Visualize label outputs of Multi-task CTC model.
    Args:
        session: session of training model
        decode_op_main: operation for decoding in the main task
        decode_op_sub: operation for decoding in the sub task
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type_main (string): character or character_capital_divide
        label_type_sub (string): phone39 or phone48 or phone61
        save_path (string, optional): path to save decoding results
    """
    # TODO: fix

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    # Decode character
    print('===== ' + label_type_main + ' =====')
    idx2char = Idx2char(map_file_path='../metrics/mapping_files/ctc/' +
                        label_type_main + '.txt')
    while True:

        # Create feed dictionary for next mini batch
        data, is_new_epoch = dataset.next(batch_size=1)
        inputs, labels_true, _, inputs_seq_len, input_names = data
        # NOTE: Batch size is expected to be 1

        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0
        }

        # Visualize
        labels_pred_st = session.run(decode_op_main, feed_dict=feed_dict)
        labels_pred = sparsetensor2list(labels_pred_st, batch_size=1)

        print('----- wav: %s -----' % input_names[0])
        print('Ref: %s' % idx2char(labels_true[0]))
        print('Hyp: %s' % idx2char(labels_pred[0]))

        if is_new_epoch:
            break

    # Decode phone
    print('\n===== ' + label_type_sub + ' =====')
    idx2phone = Idx2phone(map_file_path='../metrics/mapping_files/ctc/' +
                          label_type_sub + '.txt')
    while True:

        # Create feed dictionary for next mini batch
        data, is_new_epoch = dataset.next(batch_size=1)
        inputs, _, labels_true, inputs_seq_len, input_names = data

        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0
        }

        # Visualize
        labels_pred_st = session.run(decode_op_sub, feed_dict=feed_dict)
        try:
            labels_pred = sparsetensor2list(labels_pred_st, batch_size=1)
        except IndexError:
            # no output
            labels_pred = ['']
        finally:
            print('----- wav: %s -----' % input_names[0])
            print('Ref: %s' % idx2phone(labels_true[0]))
            print('Hyp: %s' % idx2phone(labels_pred[0]))

        if is_new_epoch:
            break
Example #14
def decode_test(session,
                decode_op,
                model,
                dataset,
                label_type,
                save_path=None):
    """Visualize label outputs of CTC model.
    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): phone39 or phone48 or phone61 or character or
            character_capital_divide
        save_path (string, optional): path to save decoding results
    """
    if label_type == 'character':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/ctc/character.txt')
    elif label_type == 'character_capital_divide':
        map_fn = Idx2char(
            map_file_path=
            '../metrics/mapping_files/ctc/character_capital_divide.txt',
            capital_divide=True)
    else:
        map_fn = Idx2phone(map_file_path='../metrics/mapping_files/ctc/' +
                           label_type + '.txt')

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    while True:

        # Create feed dictionary for next mini batch
        data, is_new_epoch = dataset.next(batch_size=1)
        inputs, labels_true, inputs_seq_len, input_names = data
        # NOTE: Batch size is expected to be 1

        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0
        }

        # Visualize
        labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
        try:
            labels_pred = sparsetensor2list(labels_pred_st, batch_size=1)
        except IndexError:
            # no output
            labels_pred = ['']
        finally:
            print('----- wav: %s -----' % input_names[0])
            if label_type == 'character':
                true_seq = map_fn(labels_true[0]).replace('_', ' ')
                pred_seq = map_fn(labels_pred[0]).replace('_', ' ')
            else:
                true_seq = map_fn(labels_true[0])
                pred_seq = map_fn(labels_pred[0])
            print('Ref: %s' % true_seq)
            print('Hyp: %s' % pred_seq)

        if is_new_epoch:
            break
Example #15
def do_eval_per(session,
                decode_op,
                per_op,
                model,
                dataset,
                label_type,
                eval_batch_size=None,
                progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Phone Error Rate.
    Args:
        session: session of training model
        decode_op: operation for decoding
        per_op: operation for computing phone error rate
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): phone39 or phone48 or phone61
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Returns:
        per_mean (float): An average of PER
    """
    # Reset data counter
    dataset.reset()

    train_label_type = label_type
    eval_label_type = dataset.label_type_sub if is_multitask else dataset.label_type

    # phone2idx_39_map_file_path = '../metrics/mapping_files/ctc/phone39.txt'
    idx2phone_train = Idx2phone(map_file_path='../metrics/mapping_files/ctc/' +
                                train_label_type + '.txt')
    idx2phone_eval = Idx2phone(map_file_path='../metrics/mapping_files/ctc/' +
                               eval_label_type + '.txt')
    map2phone39_train = Map2phone39(
        label_type=train_label_type,
        map_file_path='../metrics/mapping_files/phone2phone.txt')
    map2phone39_eval = Map2phone39(
        label_type=eval_label_type,
        map_file_path='../metrics/mapping_files/phone2phone.txt')

    per_mean = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:

        # Create feed dictionary for next mini batch
        if is_multitask:
            inputs, _, labels_true, inputs_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, _ = data

        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0
        }

        batch_size_each = len(inputs)

        # Evaluate by 39 phones
        labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
        labels_pred = sparsetensor2list(labels_pred_st, batch_size_each)

        for i_batch in range(batch_size_each):
            ###############
            # Hypothesis
            ###############
            # Convert from index to phone (-> list of phone strings)
            phone_pred_list = idx2phone_train(labels_pred[i_batch]).split(' ')

            # Mapping to 39 phones (-> list of phone strings)
            phone_pred_list = map2phone39_train(phone_pred_list)

            ###############
            # Reference
            ###############
            # Convert from index to phone (-> list of phone strings)
            phone_true_list = idx2phone_eval(labels_true[i_batch]).split(' ')

            # Mapping to 39 phones (-> list of phone strings)
            phone_true_list = map2phone39_eval(phone_true_list)

            # Compute PER
            per_mean += compute_per(ref=phone_true_list,
                                    hyp=phone_pred_list,
                                    normalize=True)

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    per_mean /= len(dataset)

    return per_mean
    def check(self,
              label_type,
              data_type='dev',
              backend='pytorch',
              shuffle=False,
              sort_utt=False,
              sort_stop_epoch=None,
              frame_stacking=False,
              splice=1):

        print('========================================')
        print('  backend: %s' % backend)
        print('  label_type: %s' % label_type)
        print('  data_type: %s' % data_type)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print('========================================')

        num_stack = 3 if frame_stacking else 1
        num_skip = 3 if frame_stacking else 1
        dataset = Dataset(data_save_path='/n/sd8/inaguma/corpus/timit/kaldi',
                          backend=backend,
                          input_freq=41,
                          use_delta=True,
                          use_double_delta=True,
                          data_type=data_type,
                          label_type=label_type,
                          batch_size=64,
                          max_epoch=2,
                          splice=splice,
                          num_stack=num_stack,
                          num_skip=num_skip,
                          shuffle=shuffle,
                          sort_utt=sort_utt,
                          sort_stop_epoch=sort_stop_epoch,
                          tool='htk',
                          num_enque=None)

        print('=> Loading mini-batch...')
        idx2phone = Idx2phone(dataset.vocab_file_path)

        for batch, is_new_epoch in dataset:
            if data_type == 'train' and backend == 'pytorch':
                for i in range(len(batch['xs'])):
                    if batch['x_lens'][i] < batch['y_lens'][i]:
                        raise ValueError(
                            'input length must be longer than label length.')

            if dataset.is_test:
                str_true = batch['ys'][0][0]
            else:
                str_true = idx2phone(batch['ys'][0][:batch['y_lens'][0]])

            print('----- %s (epoch: %.3f, batch: %d) -----' %
                  (batch['input_names'][0], dataset.epoch_detail,
                   len(batch['xs'])))
            print(str_true)
            print('x_lens: %d' % (batch['x_lens'][0] * num_stack))
            if not dataset.is_test:
                print('y_lens: %d' % batch['y_lens'][0])
    def __init__(self,
                 data_save_path,
                 backend,
                 input_freq,
                 use_delta,
                 use_double_delta,
                 data_type,
                 label_type,
                 batch_size,
                 max_epoch=None,
                 splice=1,
                 num_stack=1,
                 num_skip=1,
                 shuffle=False,
                 sort_utt=False,
                 reverse=False,
                 sort_stop_epoch=None,
                 tool='htk',
                 num_enque=None,
                 dynamic_batching=False):
        """A class for loading dataset.
        Args:
            data_save_path (string): path to saved data
            backend (string): pytorch or chainer
            input_freq (int): the number of dimensions of acoustics
            use_delta (bool): if True, use the delta feature
            use_double_delta (bool): if True, use the acceleration feature
            data_type (string): train or dev or test
            label_type (string): phone39 or phone48 or phone61
            batch_size (int): the size of mini-batch
            max_epoch (int): the max epoch. None means infinite loop.
            splice (int): frames to splice. Default is 1 frame.
            num_stack (int): the number of frames to stack
            num_skip (int): the number of frames to skip
            shuffle (bool): if True, shuffle utterances. This is
                disabled when sort_utt is True.
            sort_utt (bool): if True, sort all utterances in the
                ascending order
            reverse (bool): if True, sort utterances in the
                descending order
            sort_stop_epoch (int): After sort_stop_epoch, training
                will revert to a random order
            tool (string): htk or librosa or python_speech_features
            num_enque (int): the number of elements to enqueue
            dynamic_batching (bool): if True, the batch size will be
                changed dynamically during training
        """
        self.backend = backend
        self.input_freq = input_freq
        self.use_delta = use_delta
        self.use_double_delta = use_double_delta
        self.data_type = data_type
        self.label_type = label_type
        self.batch_size = batch_size
        self.max_epoch = max_epoch
        self.splice = splice
        self.num_stack = num_stack
        self.num_skip = num_skip
        self.shuffle = shuffle
        self.sort_utt = sort_utt
        self.sort_stop_epoch = sort_stop_epoch
        self.num_gpus = 1
        self.tool = tool
        self.num_enque = num_enque
        self.dynamic_batching = dynamic_batching
        self.is_test = (data_type == 'test')

        self.vocab_file_path = join(data_save_path, 'vocab',
                                    label_type + '.txt')
        self.idx2phone = Idx2phone(self.vocab_file_path)

        super(Dataset, self).__init__(vocab_file_path=self.vocab_file_path)

        # Load dataset file
        dataset_path = join(data_save_path, 'dataset', tool, data_type,
                            label_type + '.csv')
        df = pd.read_csv(dataset_path)
        df = df.loc[:, ['frame_num', 'input_path', 'transcript']]

        # Sort paths to input & label
        if sort_utt:
            df = df.sort_values(by='frame_num', ascending=not reverse)
        else:
            df = df.sort_values(by='input_path', ascending=True)

        self.df = df
        self.rest = set(df.index)
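
# Usage sketch (an illustrative addition, not part of the original file): how
# the Dataset above might be instantiated and iterated. The path and feature
# settings below are placeholder assumptions rather than values taken from
# this repository.
#
#   dataset = Dataset(data_save_path='/path/to/processed/timit',
#                     backend='pytorch',
#                     input_freq=40, use_delta=True, use_double_delta=True,
#                     data_type='dev', label_type='phone61',
#                     batch_size=64, max_epoch=1,
#                     num_stack=3, num_skip=3, sort_utt=True, tool='htk')
#   for batch, is_new_epoch in dataset:
#       print(batch['input_names'][0], batch['xs'][0].shape)
#       if is_new_epoch:
#           break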
def eval_phone(model,
               dataset,
               map_file_path,
               eval_batch_size,
               beam_width,
               max_decode_len,
               length_penalty=0,
               progressbar=False):
    """Evaluate trained model by Phone Error Rate.
    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        map_file_path (string): path to phones.60-48-39.map
        eval_batch_size (int): the batch size when evaluating the model
        beam_width (int): the size of beam
        max_decode_len (int): the maximum length of output sequences
            at which to stop prediction. This is used for seq2seq models.
        length_penalty (float, optional): length penalty in beam search decoding
        progressbar (bool, optional): if True, visualize the progressbar
    Returns:
        per (float): Phone error rate
        df_per (pd.DataFrame): dataframe of substitution, insertion, and deletion
    """
    # Reset data counter
    dataset.reset()

    idx2phone = Idx2phone(vocab_file_path=dataset.vocab_file_path)
    map2phone39 = Map2phone39(label_type=dataset.label_type,
                              map_file_path=map_file_path)

    per = 0
    sub, ins, dele = 0, 0, 0
    num_phones = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        # Decode
        best_hyps, _, perm_idx = model.decode(batch['xs'],
                                              batch['x_lens'],
                                              beam_width=beam_width,
                                              max_decode_len=max_decode_len,
                                              length_penalty=length_penalty)
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                phone_ref_list = ys[b][0].split(' ')
                # NOTE: the transcript is separated by spaces (' ')
            else:
                # Convert from index to phone (-> list of phone strings)
                phone_ref_list = idx2phone(ys[b][:y_lens[b]]).split(' ')

            ##############################
            # Hypothesis
            ##############################
            # Convert from index to phone (-> list of phone strings)
            str_hyp = idx2phone(best_hyps[b])
            str_hyp = re.sub(r'(.*) >(.*)', r'\1', str_hyp)
            # NOTE: truncate at the first <EOS>

            phone_hyp_list = str_hyp.split(' ')

            # Mapping to 39 phones (-> list of phone strings)
            if dataset.label_type != 'phone39':
                phone_ref_list = map2phone39(phone_ref_list)
                phone_hyp_list = map2phone39(phone_hyp_list)

            # Compute PER
            try:
                per_b, sub_b, ins_b, del_b = compute_wer(ref=phone_ref_list,
                                                         hyp=phone_hyp_list,
                                                         normalize=False)
                per += per_b
                sub += sub_b
                ins += ins_b
                dele += del_b
                num_phones += len(phone_ref_list)
            except Exception:
                # Skip utterances for which the edit-distance computation fails
                pass

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    per /= num_phones
    sub /= num_phones
    ins /= num_phones
    dele /= num_phones

    df_per = pd.DataFrame(
        {
            'SUB': [sub * 100],
            'INS': [ins * 100],
            'DEL': [dele * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['PER'])

    return per, df_per
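
# Minimal sketch of the PER arithmetic above (an illustrative addition): with
# normalize=False, compute_wer is used here as returning the raw edit distance
# and the substitution/insertion/deletion counts, which are summed per
# utterance and divided by the total number of reference phones at the end.
#
#   ref = 'sh iy hh ae d'.split(' ')
#   hyp = 'sh iy hh eh d'.split(' ')
#   dist, n_sub, n_ins, n_del = compute_wer(ref=ref, hyp=hyp, normalize=False)
#   per = dist / len(ref)  # one substitution in five phones -> 0.2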
Example #19
    def check(self, encoder_type, attention_type, label_type='character'):

        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  attention_type: %s' % attention_type)
        print('  label_type: %s' % label_type)
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 4
            inputs, labels, inputs_seq_len, labels_seq_len = generate_data(
                label_type=label_type,
                model='attention',
                batch_size=batch_size)

            # Define model graph
            num_classes = 27 if label_type == 'character' else 61
            model = AttentionSeq2Seq(input_size=inputs[0].shape[1],
                                     encoder_type=encoder_type,
                                     encoder_num_units=256,
                                     encoder_num_layers=2,
                                     encoder_num_proj=None,
                                     attention_type=attention_type,
                                     attention_dim=128,
                                     decoder_type='lstm',
                                     decoder_num_units=256,
                                     decoder_num_layers=1,
                                     embedding_dim=64,
                                     num_classes=num_classes,
                                     sos_index=num_classes,
                                     eos_index=num_classes + 1,
                                     max_decode_length=100,
                                     use_peephole=True,
                                     splice=1,
                                     parameter_init=0.1,
                                     clip_grad_norm=5.0,
                                     clip_activation_encoder=50,
                                     clip_activation_decoder=50,
                                     weight_decay=1e-8,
                                     time_major=True,
                                     sharpening_factor=1.0,
                                     logits_temperature=1.0)

            # Define placeholders
            model.create_placeholders()
            learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

            # Add each operation to the graph
            loss_op, logits, decoder_outputs_train, decoder_outputs_infer = model.compute_loss(
                model.inputs_pl_list[0],
                model.labels_pl_list[0],
                model.inputs_seq_len_pl_list[0],
                model.labels_seq_len_pl_list[0],
                model.keep_prob_encoder_pl_list[0],
                model.keep_prob_decoder_pl_list[0],
                model.keep_prob_embedding_pl_list[0])
            train_op = model.train(loss_op,
                                   optimizer='adam',
                                   learning_rate=learning_rate_pl)
            decode_op_train, decode_op_infer = model.decode(
                decoder_outputs_train, decoder_outputs_infer)
            ler_op = model.compute_ler(model.labels_st_true_pl,
                                       model.labels_st_pred_pl)

            # Define learning rate controller
            learning_rate = 1e-3
            lr_controller = Controller(learning_rate_init=learning_rate,
                                       decay_start_epoch=20,
                                       decay_rate=0.9,
                                       decay_patient_epoch=10,
                                       lower_better=True)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()),
                   "{:,}".format(total_parameters / 1000000)))

            # Make feed dict
            feed_dict = {
                model.inputs_pl_list[0]: inputs,
                model.labels_pl_list[0]: labels,
                model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                model.labels_seq_len_pl_list[0]: labels_seq_len,
                model.keep_prob_encoder_pl_list[0]: 0.8,
                model.keep_prob_decoder_pl_list[0]: 1.0,
                model.keep_prob_embedding_pl_list[0]: 1.0,
                learning_rate_pl: learning_rate
            }

            idx2phone = Idx2phone(map_file_path='./phone61.txt')

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Wrapper for tfdbg
                # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

                # Train model
                max_steps = 1000
                start_time_step = time.time()
                for step in range(max_steps):

                    # Compute loss
                    _, loss_train = sess.run(
                        [train_op, loss_op], feed_dict=feed_dict)

                    # Gradient check
                    # grads = sess.run(model.clipped_grads,
                    #                  feed_dict=feed_dict)
                    # for grad in grads:
                    #     print(np.max(grad))

                    if (step + 1) % 10 == 0:
                        # Change to evaluation mode
                        feed_dict[model.keep_prob_encoder_pl_list[0]] = 1.0
                        feed_dict[model.keep_prob_decoder_pl_list[0]] = 1.0
                        feed_dict[model.keep_prob_embedding_pl_list[0]] = 1.0

                        # Predict class ids
                        predicted_ids_train, predicted_ids_infer = sess.run(
                            [decode_op_train, decode_op_infer],
                            feed_dict=feed_dict)

                        # Compute accuracy
                        try:
                            feed_dict_ler = {
                                model.labels_st_true_pl: list2sparsetensor(
                                    labels, padded_value=model.eos_index),
                                model.labels_st_pred_pl: list2sparsetensor(
                                    predicted_ids_infer, padded_value=model.eos_index)
                            }
                            ler_train = sess.run(
                                ler_op, feed_dict=feed_dict_ler)
                        except IndexError:
                            ler_train = 1

                        duration_step = time.time() - start_time_step
                        print('Step %d: loss = %.3f / ler = %.3f (%.3f sec) / lr = %.5f' %
                              (step + 1, loss_train, ler_train, duration_step, learning_rate))
                        start_time_step = time.time()

                        # Visualize
                        if label_type == 'character':
                            print('True            : %s' %
                                  idx2alpha(labels[0]))
                            print('Pred (Training) : <%s' %
                                  idx2alpha(predicted_ids_train[0]))
                            print('Pred (Inference): <%s' %
                                  idx2alpha(predicted_ids_infer[0]))
                        else:
                            print('True            : %s' %
                                  idx2phone(labels[0]))
                            print('Pred (Training) : < %s' %
                                  idx2phone(predicted_ids_train[0]))
                            print('Pred (Inference): < %s' %
                                  idx2phone(predicted_ids_infer[0]))

                        if ler_train < 0.1:
                            print('Model converged.')
                            break

                        # Update learning rate
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=step,
                            value=ler_train)
                        feed_dict[learning_rate_pl] = learning_rate
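
# Hedged reading of the schedule above (an interpretation, not confirmed by
# this snippet): the Controller keeps the learning rate fixed until
# decay_start_epoch, then multiplies it by decay_rate whenever the monitored
# value (the label error rate, hence lower_better=True) has not improved for
# decay_patient_epoch epochs. Note that the running step index is passed as
# the epoch argument in this test.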
Example #20
def decode(model,
           dataset,
           eval_batch_size,
           beam_width,
           length_penalty,
           save_path=None):
    """Visualize label outputs.
    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model
        beam_width (int): the size of beam
        length_penalty (float): length penalty in beam search decoding
        save_path (string): path to save decoding results
    """
    idx2phone = Idx2phone(dataset.vocab_file_path)

    if save_path is not None:
        sys.stdout = open(join(save_path, 'decode.txt'), 'w')

    for batch, is_new_epoch in dataset:

        # Decode
        best_hyps, _, perm_idx = model.decode(
            batch['xs'],
            batch['x_lens'],
            beam_width=beam_width,
            max_decode_len=MAX_DECODE_LEN_PHONE,
            length_penalty=length_penalty)
        if model.model_type == 'attention' and model.ctc_loss_weight > 0:
            best_hyps_ctc, perm_idx = model.decode_ctc(batch['xs'],
                                                       batch['x_lens'],
                                                       beam_width=beam_width)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: the transcript is separated by spaces (' ')
            else:
                # Convert from list of index to string
                str_ref = idx2phone(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = idx2phone(best_hyps[b])

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref      : %s' % str_ref)
            print('Hyp      : %s' % str_hyp)
            if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                str_hyp_ctc = idx2phone(best_hyps_ctc[b])
                print('Hyp (CTC): %s' % str_hyp_ctc)

            # Compute PER
            per, _, _, _ = compute_wer(ref=str_ref.split(' '),
                                       hyp=re.sub(r'(.*) >(.*)', r'\1',
                                                  str_hyp).split(' '),
                                       normalize=True)
            print('PER: %.3f %%' % (per * 100))
            if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                per_ctc, _, _, _ = compute_wer(ref=str_ref.split(' '),
                                               hyp=str_hyp_ctc.split(' '),
                                               normalize=True)
                print('PER (CTC): %.3f %%' % (per_ctc * 100))

        if is_new_epoch:
            break
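
# Usage sketch (an illustrative addition): how decode() might be invoked after
# training. MAX_DECODE_LEN_PHONE is defined elsewhere in this repository; the
# remaining values are placeholder assumptions.
#
#   decode(model=model,
#          dataset=test_dataset,
#          eval_batch_size=1,
#          beam_width=4,
#          length_penalty=0,
#          save_path=model.model_dir)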