예제 #1
0
파일: stats.py 프로젝트: LiZhenghua0311/lip
def stats(weight_path, dataset_path, img_c, img_w, img_h, frames_n,
          absolute_max_string_len, minibatch_size):
    """Load trained LipNet weights and run a single ``Statistics`` pass.

    Args:
        weight_path: weights file handed to ``lipnet.model.load_weights``.
        dataset_path: dataset root handed to ``BasicGenerator``.
        img_c, img_w, img_h, frames_n, absolute_max_string_len: passed
            unchanged to both ``BasicGenerator`` and ``LipNet`` so the data
            generator and the network are built with the same settings.
        minibatch_size: batch size handed to ``BasicGenerator``.

    Returns:
        None. Any output is produced by ``Statistics.on_epoch_end``; it is
        constructed with ``output_dir=None`` (presumably meaning nothing is
        saved to disk -- confirm against ``Statistics``).
    """
    # Keyword arguments shared verbatim by the generator and the network.
    geometry = {
        'img_c': img_c,
        'img_w': img_w,
        'img_h': img_h,
        'frames_n': frames_n,
        'absolute_max_string_len': absolute_max_string_len,
    }

    lip_gen = BasicGenerator(dataset_path=dataset_path,
                             minibatch_size=minibatch_size,
                             **geometry).build()
    lipnet = LipNet(output_size=lip_gen.get_output_size(), **geometry)

    # The 'ctc' output is treated as the loss value itself, so the compiled
    # loss simply forwards y_pred and ignores y_true.
    lipnet.model.compile(
        loss={'ctc': lambda y_true, y_pred: y_pred},
        optimizer=Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
    )
    lipnet.model.load_weights(weight_path)

    spell = Spell(path=PREDICT_DICTIONARY)
    decoder = Decoder(greedy=PREDICT_GREEDY,
                      beam_width=PREDICT_BEAM_WIDTH,
                      postprocessors=[labels_to_text, spell.sentence])

    # NOTE(review): the positional 256 is interpreted by Statistics
    # (presumably the number of samples to evaluate) -- confirm.
    statistics = Statistics(lipnet, lip_gen.next_val(), decoder, 256,
                            output_dir=None)

    # Invoke the callback hooks directly: generator start-of-training hook
    # first, then the statistics end-of-epoch hook for epoch 0.
    lip_gen.on_train_begin()
    statistics.on_epoch_end(0)
예제 #2
0
def train(run_name, speaker, start_epoch, stop_epoch, img_c, img_w, img_h, frames_n, absolute_max_string_len, minibatch_size):
    """Train, or resume training, LipNet on one speaker's data.

    All paths are derived from the module-level ``CURRENT_PATH``:
    ``<CURRENT_PATH>/<speaker>/datasets`` holds the data,
    ``<CURRENT_PATH>/<speaker>/results`` holds weights and callback output,
    and ``<CURRENT_PATH>/<speaker>/logs`` holds TensorBoard/CSV logs. Per-run
    artifacts go in a ``run_name`` sub-directory of ``results`` and ``logs``.

    Args:
        run_name: name of this run; used for its output/log sub-directories
            and for the CSV log file name (``training-<run_name>.csv``).
        speaker: speaker directory under ``CURRENT_PATH``.
        start_epoch: Keras ``initial_epoch``. If > 0, the checkpoint
            ``results/<run_name>/weights<start_epoch-1>.h5`` is loaded.
            Otherwise initial weights are loaded from
            ``results/weightsa.h5``, which must already exist.
        stop_epoch: passed to Keras as ``epochs``.
        img_c, img_w, img_h, frames_n, absolute_max_string_len: passed
            unchanged to both ``BasicGenerator`` and ``LipNet``.
        minibatch_size: generator batch size; also the number of sentences
            ``Visualize`` is asked to display.
    """
    speaker_dir = os.path.join(CURRENT_PATH, speaker)
    dataset_dir = os.path.join(speaker_dir, 'datasets')
    output_dir = os.path.join(speaker_dir, 'results')
    log_dir = os.path.join(speaker_dir, 'logs')
    run_output_dir = os.path.join(output_dir, run_name)
    run_log_dir = os.path.join(log_dir, run_name)

    curriculum = Curriculum(curriculum_rules)
    lip_gen = BasicGenerator(dataset_path=dataset_dir,
                             minibatch_size=minibatch_size,
                             img_c=img_c, img_w=img_w, img_h=img_h, frames_n=frames_n,
                             absolute_max_string_len=absolute_max_string_len,
                             curriculum=curriculum, start_epoch=start_epoch).build()

    lipnet = LipNet(img_c=img_c, img_w=img_w, img_h=img_h, frames_n=frames_n,
                    absolute_max_string_len=absolute_max_string_len,
                    output_size=lip_gen.get_output_size())
    lipnet.summary()

    adam = Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    # the loss calc occurs elsewhere, so use a dummy lambda func for the loss
    lipnet.model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=adam)

    if start_epoch > 0:
        # Resume from the file the ModelCheckpoint below writes after each epoch.
        # NOTE(review): "start_epoch - 1" assumes this Keras version formats
        # {epoch} 0-based in ModelCheckpoint; newer Keras versions format it
        # 1-based -- confirm against the installed version.
        weight_file = os.path.join(run_output_dir, 'weights%02d.h5' % (start_epoch - 1))
    else:
        # Fresh run: initial weights live directly in the results directory.
        # (Previously built by joining an absolute path onto output_dir, which
        # only worked when CURRENT_PATH was absolute.)
        weight_file = os.path.join(output_dir, 'weightsa.h5')
    lipnet.model.load_weights(weight_file)

    spell = Spell(path=PREDICT_DICTIONARY)
    decoder = Decoder(greedy=PREDICT_GREEDY, beam_width=PREDICT_BEAM_WIDTH,
                      postprocessors=[labels_to_text, spell.sentence])

    # define callbacks
    statistics = Statistics(lipnet, lip_gen.next_val(), decoder, 256,
                            output_dir=run_output_dir)
    visualize = Visualize(run_output_dir, lipnet, lip_gen.next_val(), decoder,
                          num_display_sentences=minibatch_size)
    tensorboard = TensorBoard(log_dir=run_log_dir)
    csv_logger = CSVLogger(os.path.join(log_dir, 'training-{}.csv'.format(run_name)),
                           separator=',', append=True)
    checkpoint = ModelCheckpoint(os.path.join(run_output_dir, "weights{epoch:02d}.h5"),
                                 monitor='val_loss', save_weights_only=True,
                                 mode='auto', period=1)

    lipnet.model.fit_generator(generator=lip_gen.next_train(),
                               steps_per_epoch=lip_gen.default_training_steps,
                               epochs=stop_epoch,
                               validation_data=lip_gen.next_val(),
                               validation_steps=lip_gen.default_validation_steps,
                               callbacks=[checkpoint, statistics, visualize, lip_gen,
                                          tensorboard, csv_logger],
                               initial_epoch=start_epoch,
                               verbose=1,
                               max_q_size=5,
                               workers=2,
                               pickle_safe=True)