def predict(model):
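    """Run the model over every matching prediction-set file and save the
    raw model outputs as one .npy array per input file."""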

    filename_pattern = (FLAGS.filePattern + FLAGS.predictionSetExtension +
                        '.*.' + FLAGS.extension)
    input_files = fnmatch.filter(listdir(FLAGS.filePath), filename_pattern)
    if len(input_files) == 0:
        sys.exit("File not found: " + filename_pattern)

    for input_file in input_files:
        print('Number of predictions:        ', end='')
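        # DataGenerator is a project-specific class defined elsewhere in this
        # file; the calls in this script assume roughly this interface (an
        # assumption inferred from usage, not a documented signature):
        #   DataGenerator(data_shape, num_class, file_path, file_pattern,
        #                 batch_size, useOddSample=..., percent=...,
        #                 queue_size=..., shuffle=...)
        #   .num_steps()  -> number of batches to draw
        #   .terminate()  -> stop background loading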
        predict_data_generator = DataGenerator(
            data_shape,
            num_class,
            FLAGS.filePath,
            input_file,
            FLAGS.testBatchSize,
            useOddSample=True,
        )
        prediction = model.predict_generator(
            predict_data_generator,
            predict_data_generator.num_steps(),
            max_queue_size=FLAGS.queueSize,
        )
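        # Save predictions next to the input: swap the prediction-set
        # extension for the prediction extension and strip the data-file
        # suffix (np.save appends '.npy' to the resulting path).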
        np.save(
            path.join(
                FLAGS.filePath,
                input_file.replace(FLAGS.predictionSetExtension,
                                   FLAGS.predictionExtension).replace(
                                       '.' + FLAGS.extension, '')), prediction)


def test(model):
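    """Evaluate the model on the test set and print loss and accuracy."""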
    print('Number of test samples:       ', end='')
    test_data_generator = DataGenerator(
        data_shape,
        num_class,
        FLAGS.filePath,
        FLAGS.filePattern + FLAGS.testSetExtension + '.*.' + FLAGS.extension,
        FLAGS.testBatchSize,
        useOddSample=True,
        percent=FLAGS.testPercent,
        shuffle=True,
    )

    score = model.evaluate_generator(
        test_data_generator,
        test_data_generator.num_steps(),
        max_queue_size=FLAGS.queueSize,
    )

    print('Test loss:    ', score[0])
    print('Test accuracy:', score[1])


def main():
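    """Build the model, then test, predict, or train according to FLAGS."""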

    model_serial = define_model()

    # Save model
    model_yaml = model_serial.to_yaml()
    with open(path.join(FLAGS.modelPath, FLAGS.sessionName + '.model.yaml'),
              'w') as model_yaml_file:
        model_yaml_file.write(model_yaml)
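
    # The saved YAML can rebuild the architecture later. A minimal sketch,
    # assuming the legacy Keras used here (which provides model_from_yaml);
    # `yaml_path` is a hypothetical variable:
    #   from keras.models import model_from_yaml
    #   with open(yaml_path) as f:
    #       restored_model = model_from_yaml(f.read())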

    # keras.utils.plot_model(model_serial, to_file=path.join(FLAGS.modelPath, FLAGS.sessionName + '.model.png'), show_shapes=True)

    # Load weights when resuming from a saved epoch
    if FLAGS.startEpoch > 0:
        model_serial.load_weights(
            path.join(
                FLAGS.modelPath, FLAGS.sessionName +
                '.weights.{:02d}.hdf5'.format(FLAGS.startEpoch)))
        print('Loaded epoch', FLAGS.startEpoch)

    if (FLAGS.testOnly or FLAGS.predictOnly) and FLAGS.startEpoch == 0:
        model_serial.load_weights(
            path.join(FLAGS.modelPath, FLAGS.sessionName + '.weights.hdf5'))

    model = model_serial
    model.summary()

    model.compile(
        loss='categorical_crossentropy',
        optimizer=keras.optimizers.Nadam(lr=FLAGS.learningRate),
        metrics=['accuracy'],
    )

    if FLAGS.randomSeed > 0:
        random.seed(FLAGS.randomSeed)

    if FLAGS.testOnly:
        test(model)
        sys.exit()

    if FLAGS.predictOnly:
        predict(model)
        sys.exit()

    trainingParameterText = 'BatchSize : {:02d}; learningRate : {:.8f}; L2 : {:.8f}; Dropout : {:.4f}'.format(
        FLAGS.batchSize, FLAGS.learningRate, FLAGS.L2, FLAGS.dropout)
    print(trainingParameterText)
    print('_________________________________________________________________')

    print('Number of training samples:   ', end='')
    training_data_generator = DataGenerator(
        data_shape,
        num_class,
        FLAGS.filePath,
        FLAGS.filePattern + FLAGS.trainingSetExtension + '.*.' +
        FLAGS.extension,
        FLAGS.batchSize,
        percent=FLAGS.trainingPercent,
        queue_size=FLAGS.queueSize,
        shuffle=True,
        #class_weight=class_weight,
    )
    print('Number of validation samples: ', end='')
    validation_data_generator = DataGenerator(
        data_shape,
        num_class,
        FLAGS.filePath,
        FLAGS.filePattern + FLAGS.validationSetExtension + '.*.' +
        FLAGS.extension,
        FLAGS.testBatchSize,
        percent=FLAGS.validationPercent,
        shuffle=True,
        #class_weight=class_weight,
    )
    print('Number of test samples:       ', end='')
    test_data_generator = DataGenerator(
        data_shape,
        num_class,
        FLAGS.filePath,
        FLAGS.filePattern + FLAGS.testSetExtension + '.*.' + FLAGS.extension,
        FLAGS.testBatchSize,
        percent=FLAGS.testPercent,
        shuffle=True,
    )

    print('_________________________________________________________________')

    callback_list = []

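    # RecordEpoch is a custom callback defined elsewhere in this file. From
    # its use here it is assumed to checkpoint weights each epoch, append
    # metrics to the CSV log, stop training after `patience` epochs without
    # improvement, and record the best epoch in best_epoch_dict.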
    callback_list.append(
        RecordEpoch(
            best_epoch_dict,
            weigh_filepath=path.join(
                FLAGS.modelPath,
                FLAGS.sessionName + '.weights.{epoch:02d}.hdf5'),
            csv_log_filepath=path.join(
                FLAGS.modelPath,
                FLAGS.sessionName + '.' + FLAGS.logExtension,
            ),
            csv_log_header=trainingParameterText,
            patience=FLAGS.earlyStoppingPatience,
            sleep_after_epoch=sleep_after_epoch,
        ))

    history = model.fit_generator(
        training_data_generator,
        training_data_generator.num_steps(),
        epochs=FLAGS.maxEpoch,
        initial_epoch=FLAGS.startEpoch,
        verbose=1,
        callbacks=callback_list,
        #class_weight=class_weight,
        validation_data=validation_data_generator,
        validation_steps=validation_data_generator.num_steps(),
        max_queue_size=FLAGS.queueSize,
    )

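    # Stop the training generator's background workers (the validation and
    # test generators are assumed to clean up on their own).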
    training_data_generator.terminate()

    print('Best epoch:', best_epoch_dict['epoch_index'])
    print('Validation accuracy:', best_epoch_dict['val_acc'])

    model.load_weights(
        path.join(
            FLAGS.modelPath, FLAGS.sessionName +
            '.weights.{:02d}.hdf5'.format(best_epoch_dict['epoch_index'])))
    print('Reloaded weights from the best epoch')

    test(model)
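

# Entry point sketch: assumes FLAGS is defined and parsed elsewhere in this
# file before main() runs. If the original script already invokes main() via
# tf.app.run() or similar, this guard is redundant.
if __name__ == '__main__':
    main()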