Exemplo n.º 1
0
def prediction(args,
               number,
               data_dir='../samples/cha/stimuli3/0',
               model_path="",
               mode='train',
               data_suffix='npy'):
    """Print the ground-truth and predicted sequences for several batches.

    Args:
        args: configuration namespace forwarded to the model and batch
            helpers (contents defined elsewhere in the project).
        number: number of batches from the generator to decode and print.
        data_dir: directory containing the MFCC feature files.
        model_path: path to the model checkpoint (HDF5 file).
        mode: 'train', 'validation' or 'test' — selects the data folder.
        data_suffix: suffix of the feature files inside ``data_dir``.
    """
    # Load features and labels for the requested split.
    data, label, total_path_name = data_specification(mode, data_dir,
                                                      data_suffix)

    # Longest time dimension across all samples (needed for padding).
    max_sequence_length = max((i.shape[1] for i in data), default=0)

    # Build the network and restore the checkpoint weights.
    model = brnn_keras(args, max_sequence_length)
    model.model_1.load_weights(model_path)

    # The path shortening is loop-invariant — do it once, up front.
    display_name = model_path.replace("../CHECKPOINT/char/save/", "")
    # Test checkpoints use a different label alphabet; pick the decoder
    # once instead of re-testing inside every print branch.
    to_text = (int_sequence_to_text_test if "test" in display_name
               else int_sequence_to_text)

    batch = create_batch(args, data, max_sequence_length, label)
    for _ in range(number):
        inputs, outputs = next(batch)

        x = inputs['the_input']
        y = inputs['the_labels'][0]

        # Use a name distinct from this function's own name: the original
        # code rebound `prediction`, shadowing the function.
        probs = model.model_1.predict(x, steps=1)
        pred_ints = decode_ctc(args, probs)
        print(np.shape(probs))
        print('model ' + display_name)
        print('-' * 80)
        y = y.astype(int)
        print('GroundTruth speech:\n' + '\n' + to_text(y))

        print('-' * 80)
        print(pred_ints)
        print('Predicted speech:\n' + '\n' + ''.join(to_text(pred_ints)))

        output_matrix(probs[0])
Exemplo n.º 2
0
    def run_session(self, args):
        """Run one full training session.

        Loads the training and validation splits, fits the model with
        batch generators, checkpoints after every epoch, and pickles the
        Keras history dict to disk.
        """
        ################################
        # data and callbacks
        ################################
        train_data, train_label, _ = data_specification(args.mode,
                                                        args.data_dir, 'npy')
        # Temporarily reads the 'test' split; to be revised once the data
        # layout changes (translated from the original comment).
        dev_data, dev_label, _ = data_specification('test', args.data_dir,
                                                    'npy')

        steps_per_epoch = len(train_data) // args.batch_size
        dev_steps = len(dev_data) // args.batch_size

        # Longest time axis across BOTH splits — used to pad every batch.
        max_sequence_length = max(
            (sample.shape[1] for sample in train_data), default=0)
        for sample in dev_data:
            if sample.shape[1] > max_sequence_length:
                max_sequence_length = sample.shape[1]

        # Save model weights after each epoch.
        checkpointer = ModelCheckpoint(
            filepath=savedir + '/' + args.model_name, verbose=1)

        ################################
        # training phase
        ################################
        model = Deepspeech2(args, max_sequence_length)

        history = model.model_2.fit_generator(
            generator=create_batch(args, train_data, max_sequence_length,
                                   train_label),
            validation_data=create_batch(args, dev_data, max_sequence_length,
                                         dev_label),
            steps_per_epoch=steps_per_epoch,
            validation_steps=dev_steps,
            epochs=args.epochs,
            verbose=1,
            callbacks=[checkpointer])

        # Persist the per-epoch loss curves for later inspection.
        with open(savedir + '/' + args.pickle_name, 'wb') as f:
            pickle.dump(history.history, f)
Exemplo n.º 3
0
def get_data(mode='train',
             data_dir='./PreferredStimuliData',
             data_suffix='npy',
             args=None):
    """Return one batch of data for this visualization test.

    Reuses the data-helper functions and returns only the first batch
    produced by the generator (effectively batch size 1 usage).

    Args:
        mode: dataset split to load ('train', 'validation' or 'test').
        data_dir: directory containing the feature files.
        data_suffix: suffix of the feature files inside ``data_dir``.
        args: configuration passed to ``create_batch``.  The original code
            silently relied on a module-level ``args`` global; the new
            parameter defaults to that global for backward compatibility.

    Returns:
        Tuple ``(inputs, outputs, max_sequence_length)`` — the first batch
        from the generator plus the longest time dimension in the data.
    """
    if args is None:
        # Fall back to the module-level `args` the original implementation
        # depended on, so existing callers keep working unchanged.
        args = globals()['args']

    data, label, total_path_name = data_specification(mode, data_dir,
                                                      data_suffix)

    # Longest time axis across the samples (used to pad the batch).
    max_sequence_length = max((i.shape[1] for i in data), default=0)

    batch = create_batch(args, data, max_sequence_length, label)
    inputs, outputs = next(batch)
    return inputs, outputs, max_sequence_length