コード例 #1
0
ファイル: log.py プロジェクト: johSchm/magpie
def log_confusion_matrix(path, epoch, classes, ground_truth, predictions):
    """ Logs a row-normalized confusion matrix as an image summary.

    The confusion matrix is computed with ``tf.math.confusion_matrix``.
    Each row is divided by its sum and rounded to two decimals, then the
    matrix is rendered as a seaborn heatmap. The figure is written as a PNG
    image summary named "Confusion Matrix".

    :param path: log directory prefix; the writer directory is
        ``path_utils.join(path + 'cm')``
    :param epoch: step value attached to the image summary
    :param classes: class labels used for the heatmap rows and columns
    :param ground_truth: true class labels
    :param predictions: predicted class labels
    """
    # NOTE(review): ``path + 'cm'`` concatenates without a separator, so the
    # directory is ``<path>/cm`` only if ``path`` ends with a separator.
    # Kept as-is so existing log locations do not move -- confirm intent.
    file_writer = tf.summary.create_file_writer(path_utils.join(path + 'cm'))
    con_mat = tf.math.confusion_matrix(labels=ground_truth,
                                       predictions=predictions).numpy()
    row_sums = con_mat.sum(axis=1)[:, np.newaxis]
    # A class with no ground-truth samples has a zero row sum. Leave that
    # row at 0 instead of producing NaN (and a divide-by-zero warning).
    con_mat_norm = np.around(
        np.divide(con_mat.astype('float'), row_sums,
                  out=np.zeros(con_mat.shape, dtype='float'),
                  where=row_sums != 0),
        decimals=2)
    con_mat_df = pd.DataFrame(con_mat_norm, index=classes, columns=classes)

    figure = plt.figure(figsize=(8, 8))
    sns.heatmap(con_mat_df, annot=True, cmap=plt.cm.Blues)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Compute the layout after the axis labels exist so they are not clipped.
    plt.tight_layout()

    # Render the figure to an in-memory PNG and free the figure.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close(figure)
    buf.seek(0)
    # Decode to a 4-channel image tensor and add a leading batch dimension.
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    image = tf.expand_dims(image, 0)

    with file_writer.as_default():
        tf.summary.image("Confusion Matrix", image, step=epoch)
    # A new writer is created on every call; close it so the event file is
    # flushed and its handle released.
    file_writer.close()
コード例 #2
0
def get_settings_file_path(master_path=None,
                           auto_search=True,
                           auto_iteration=0,
                           max_iterations=5):
    """ Loads the master settings JSON file and returns its parsed content.

    Despite the name, the return value is the data decoded by ``json.load``,
    not a path. If the file does not exist and ``auto_search`` is enabled,
    the lookup is retried with ``'..'`` prepended to the path. At most
    ``max_iterations`` such retries are made.

    :param master_path: path of the JSON file; defaults to ``MASTER_PATH``
    :param auto_search: retry in parent directories when the file is missing
    :param auto_iteration: number of parent-directory retries already made
        (incremented by the recursion)
    :param max_iterations: maximum number of parent-directory retries
    :return: the parsed JSON content
    :raises FileNotFoundError: if the file cannot be found
    """
    if master_path is None:
        master_path = MASTER_PATH
    try:
        settings_file = open(master_path, "r")
    except FileNotFoundError:
        if not auto_search or auto_iteration >= max_iterations:
            raise FileNotFoundError("File not found!") from None
        # Return the recursive result so that a file found in a parent
        # directory is actually returned. Also pass the caller's search
        # settings down.
        return get_settings_file_path(
            master_path=path_utils.join('..', master_path),
            auto_search=auto_search,
            auto_iteration=auto_iteration + 1,
            max_iterations=max_iterations)
    # ``with`` closes the file even if the JSON is malformed.
    with settings_file:
        return json.load(settings_file)
コード例 #3
0
 def __init__(self, path=None, model_idx=0, settings_idx=0):
     """ Init method.

     Stores the settings path on ``self.path``. A path that is passed in is
     stored unchanged. Otherwise the path is obtained from
     ``utils.get_settings_file_path()``. If that returns a dict, the entry
     ``'settings_<settings_idx>'`` is indexed with ``model_idx`` to pick one
     path. ``'..'`` is then prepended via ``path_utils.join``.

     :param path: explicit settings path; used as-is when given
     :param model_idx: index into the selected ``'settings_<settings_idx>'``
         entry (for multiple settings paths; presumably a list -- confirm)
     :param settings_idx: selects the ``'settings_<settings_idx>'`` entry
     :raises KeyError: if the lookup result is a dict without a
         ``'settings_<settings_idx>'`` entry
     """
     self.path = path
     if path is None:
         path = utils.get_settings_file_path()
         if isinstance(path, dict):
             settings_key = 'settings_' + str(settings_idx)
             if settings_key not in path:
                 # Fail with a clear message instead of "'NoneType' object
                 # is not subscriptable" from ``.get(key)[model_idx]``.
                 raise KeyError("Settings key '{}' not found".format(settings_key))
             path = path[settings_key][model_idx]
         self.path = path_utils.join('..', path)
コード例 #4
0
ファイル: log.py プロジェクト: johSchm/magpie
def get_checkpoint_cb(path):
    """ Sets up a model checkpoint callback.

    The full model (``save_weights_only=False``) is stored after every epoch
    to ``ckpt-e{epoch:02d}.hdf5`` inside ``path``.

    :param path: checkpoint directory; created if it does not exist.
        If None, no callback is created.
    :return: cb (the ModelCheckpoint), or None if ``path`` is None
    """
    if path is None:
        return None
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(path, exist_ok=True)
    file_path = path_utils.join(path, "ckpt-e{epoch:02d}.hdf5")
    # ModelCheckpoint saves once per epoch by default. The previous
    # ``save_frequency=1`` is not a ModelCheckpoint parameter (the real one
    # is ``save_freq``). Depending on the Keras version it was silently
    # ignored or rejected, so it is dropped.
    return ModelCheckpoint(file_path,
                           verbose=1,
                           save_best_only=False,
                           save_weights_only=False)
コード例 #5
0
ファイル: root_model.py プロジェクト: johSchm/magpie
 def generator_batch(db, channel_first=False, use_paths=False, rnd_sample_frame=False,
                     num_frames=24, frame_size=(224, 224)):
     """ An endless generator over a dataset.

     ``db`` is iterated over again and again. For tuple elements,
     ``data[0][0]`` is the input and ``data[0][1]`` the labels. Any other
     element is yielded unchanged.

     :param db: re-iterable dataset
     :param channel_first: yield the first input with its last axis moved
         to position 1
     :param use_paths: treat ``data[0][0]`` as a batch of string tensors.
         Each is a directory holding the frames ``0.jpg`` ..
         ``<num_frames - 1>.jpg``. The frames are loaded, resized, scaled by
         1/255 and stacked.
     :param rnd_sample_frame: additionally yield one randomly chosen frame,
         as ``([frame, imgs], labels)``
     :param num_frames: frames per sample for ``use_paths``, and the bound
         of the random frame index for ``rnd_sample_frame`` (formerly hard
         coded to 24)
     :param frame_size: ``(width, height)`` frames are resized to for
         ``use_paths`` (formerly hard coded to 224x224)
     :return (yield) data
     """
     width, height = frame_size
     while True:
         for data in db:
             if type(data) is not tuple:
                 yield data
                 continue
             labels = data[0][1]
             if channel_first:
                 # Only the first input is yielded, so only it is transposed.
                 first = data[0][0]
                 yield np.moveaxis(first, first.ndim - 1, 1), labels
             elif use_paths:
                 batch = []
                 for path in data[0][0]:
                     sample_dir = path.numpy().decode('utf-8')
                     frames = []
                     for i in range(num_frames):
                         frame_path = path_utils.join(sample_dir, str(i) + '.jpg')
                         with Image.open(frame_path) as frame:
                             frames.append(np.array(frame.resize((width, height))) / 255.0)
                     # Assumes 3-channel (RGB) frames -- TODO confirm for other color modes.
                     batch.append(np.array(frames).reshape([num_frames, height, width, 3]))
                 yield tf.convert_to_tensor(batch, dtype=tf.float32), labels
             else:
                 if len(data) > 1:
                     imgs = [item[0] for item in data]
                 else:
                     imgs = data[0][0]
                 if rnd_sample_frame:
                     # Assumes imgs is an array with the frame axis at position 1 -- TODO confirm.
                     img = imgs[:, randint(0, num_frames - 1)]
                     yield ([img, imgs], labels)
                 else:
                     yield imgs, labels
コード例 #6
0
ファイル: train.py プロジェクト: johSchm/magpie
    # Training setup: enforce channels-last ordering, build absolute paths,
    # and read the hyper parameters.
    # NOTE(review): BASE_PATH, key, hyper_param_path, checkpoint_path,
    # class_path, log_path, model_path and model_format are used here before
    # any assignment; presumably they are parameters or globals of the
    # enclosing scope -- confirm.
    import os
    import learn.model.model_utils as model_utils
    import utils.path_utils as path_utils
    import utils.os_utils as os_utils
    import tensorflow as tf
    from keras import backend

    # force channels-last ordering
    # Set it on both the tf.keras backend and the standalone keras backend,
    # then print the resulting setting of each.
    tf.keras.backend.set_image_data_format('channels_last')
    backend.set_image_data_format('channels_last')
    print("Enforcing channel ordering: " + str(backend.image_data_format()))
    print("Enforcing channel ordering: " +
          str(tf.keras.backend.image_data_format()))

    # adapted relative paths
    # Each configured path is joined (via path_utils.join) under the current
    # working directory, BASE_PATH and the model key. The model path also
    # gets the file name key + model_format.
    hyper_param_path = path_utils.join(os.getcwd(), BASE_PATH, key,
                                       hyper_param_path)
    checkpoint_path = path_utils.join(os.getcwd(), BASE_PATH, key,
                                      checkpoint_path)
    class_path = path_utils.join(os.getcwd(), BASE_PATH, key, class_path)
    log_path = path_utils.join(os.getcwd(), BASE_PATH, key, log_path)
    model_path = path_utils.join(os.getcwd(), BASE_PATH, key, model_path,
                                 key + model_format)

    # On Windows, replace forward slashes with backslashes.
    # NOTE(review): hyper_param_path and class_path are not converted here --
    # confirm whether that is intentional.
    if os_utils.get_operation_system() == os_utils.OperatingSystems.WIN:
        checkpoint_path = checkpoint_path.replace('/', '\\')
        model_path = model_path.replace('/', '\\')
        log_path = log_path.replace('/', '\\')

    # setup model hyper parameter
    # The settings manager reads from the hyper-parameter file built above.
    settings_manager = sManager.SettingsManager(path=hyper_param_path)
    loss = settings_manager.read("loss")
コード例 #7
0
ファイル: predict.py プロジェクト: johSchm/magpie
]
# Checkpoint file name inside checkpoint_path whose weights are loaded below.
# It follows the "ckpt-e{epoch:02d}.hdf5" naming, presumably the checkpoint
# saved at epoch 5 -- confirm.
CHECKPOINT = "ckpt-e05.hdf5"

# Read the prediction settings. settings_idx=23 and model_idx=0 are hard
# coded and select which settings entry is used.
settings_manager = sManager.SettingsManager(model_idx=0, settings_idx=23)
image_size = settings_manager.read("image_size")
checkpoint_path = settings_manager.read("checkpoint_path")
model_path = settings_manager.read("model_path")
key = settings_manager.read("key")
shape = settings_manager.read("input_shape")
class_path = settings_manager.read("class_path")
model_format = settings_manager.read("model_format")
optical_flow = settings_manager.read("optical_flow")
feature_detection = settings_manager.read("feature_detection")
log_path = settings_manager.read("log_path")

# Join the configured relative paths under BASE_PATH and the model key.
# The checkpoint path ends in the CHECKPOINT file, and the model path in
# key + model_format.
checkpoint_path = path_utils.join(BASE_PATH, key, checkpoint_path, CHECKPOINT)
model_path = path_utils.join(BASE_PATH, key, model_path, key + model_format)
class_path = path_utils.join(BASE_PATH, key, class_path)
log_path = path_utils.join(BASE_PATH, key, log_path)

# Load the saved model, then overwrite its weights with the selected checkpoint.
learner = learn.Learner()
learner.load(model_path)
learner.model.get_model().load_weights(checkpoint_path)

# Optical-flow estimator configured from the settings, passed to the
# prediction controller together with the loaded model.
optical_flow_estimator = optflow.OpticalFlowEstimator(
    estimation_method=optical_flow, feature_detection=feature_detection)
pc = predict.PredictionController(
    learner.model.get_model(),
    class_path,
    log_path=log_path,
    optical_flow_estimator=optical_flow_estimator)
コード例 #8
0
# Read the data-preprocessing settings.
partial_load = settings_manager.read("partial_load")
class_path = settings_manager.read("class_path")
test_split = settings_manager.read("test_split")
fpv = settings_manager.read("fpv")
augmentation_path = settings_manager.read("augmentation_path")
data_cache_path = settings_manager.read("data_cache_path")
from_cache = settings_manager.read("from_cache")
opt_flow = settings_manager.read("optical_flow")
segmentation = settings_manager.read("segmentation")
feature_detection = settings_manager.read("feature_detection")
pose_method = settings_manager.read("pose")
colormode = settings_manager.read("color_depth")
memory_data_storage = settings_manager.read("memory_data_storage")
key = settings_manager.read("key")

# Join the configured relative paths under BASE_PATH and the model key.
# NOTE(review): raw_data_path and data_path are not among the reads above;
# presumably they are assigned earlier in the script -- confirm.
class_path = path_utils.join(BASE_PATH, key, class_path)
raw_data_path = path_utils.join(BASE_PATH, key, raw_data_path)
data_path = path_utils.join(BASE_PATH, key, data_path)
augmentation_path = path_utils.join(BASE_PATH, key, augmentation_path)
data_cache_path = path_utils.join(BASE_PATH, key, data_cache_path)

# Preprocessor for video samples using the TF_RECORD data pattern, with the
# color mode and memory-storage option taken from the settings.
processor = pro.Preprocessor(colormode=colormode,
                             data_pattern=data_utils.DataPattern.TF_RECORD,
                             sample_type=sample.SampleType.VIDEO,
                             memory_data_storage=memory_data_storage)
sample_train, sample_test = processor.run(raw_data_path,
                                          class_path,
                                          data_cache_path,
                                          image_size,
                                          test_split,
                                          partial_load,