def main():
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s : %(levelname)s : %(module)s : %(message)s", datefmt="%d-%m-%Y %H:%M:%S"
    )

    matplotlib_setup()

    images_masks = load_pickle(IMAGES_MASKS_FILENAME)
    logging.info('Masks: %s', len(images_masks))

    images = sorted(images_masks.keys())
    logging.info('Images: %s', len(images))

    id2class = {i + 1: c for i, c in enumerate(CLASSES_NAMES)}
    nb_classes = len(CLASSES_NAMES)
    logging.info('Classes: %s', nb_classes)

    # calc intersections
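    # ordered pairs (i, j) rather than combinations: with mode='mask1' the overlap is
    # presumably normalized by class i's mask, so the resulting matrix is not symmetric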
    permutations = list(itertools.permutations(range(1, nb_classes + 1), 2))
    intersections = []
    for k, (i, j) in enumerate(permutations):
        inter = get_intersections(images, images_masks, i, j, mode='mask1')
        intersections.append(inter)

        logging.info('Finished %s/%s [%.2f%%, classes %s - %s]', k + 1, len(permutations), 100 * (k + 1) / len(permutations), i, j)


    intersection_data = np.zeros((nb_classes, nb_classes), dtype=np.float32)
    for k, (i, j) in enumerate(permutations):
        intersection_data[i-1, j-1] = intersections[k]

    plot_intersections(intersection_data, id2class)
def main(model_names, output_name):
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s : %(levelname)s : %(module)s : %(message)s", datefmt="%d-%m-%Y %H:%M:%S"
    )

    matplotlib_setup()

    logging.info('Combining masks:')
    for mn in model_names:
        logging.info(' - %s', mn)

    logging.info('Output masks: %s', output_name)

    images_all, images_train, images_test = get_train_test_images_ids()
    logging.info('Train: %s, test: %s, all: %s', len(images_train), len(images_test), len(images_all))

    target_images = IMAGES_TEST_REAL
    logging.info('Target images: %s', len(target_images))
    for img_number, img_id in enumerate(target_images):
        # if img_id != '6060_2_3':
        #     continue

        img_masks = [load_prediction_mask(IMAGES_PREDICTION_MASK_DIR, img_id, model_name) for model_name in model_names]
        img_masks = np.array(img_masks)

        img_masks_combined = np.sum(img_masks, axis=0)
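        # element-wise sum of the per-model masks: an unnormalized ensemble vote that the
        # downstream thresholding presumably accounts for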

        save_prediction_mask(IMAGES_PREDICTION_MASK_DIR, img_masks_combined, img_id, output_name)

        logging.info('Combined: %s/%s [%.2f%%]',
                     img_number + 1, len(target_images), 100 * (img_number + 1) / len(target_images))
def main():
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s : %(levelname)s : %(module)s : %(message)s",
        datefmt="%d-%m-%Y %H:%M:%S")

    matplotlib_setup()

    # load images
    images_data = load_pickle(IMAGES_NORMALIZED_FILENAME)
    images_masks = load_pickle(IMAGES_MASKS_FILENAME)
    logging.info('Images: %s, masks: %s', len(images_data), len(images_masks))

    images_all, images_train, images_test = get_train_test_images_ids()
    logging.info('Train: %s, test: %s, all: %s', len(images_train),
                 len(images_test), len(images_all))

    nb_classes = len(images_masks[images_train[0]])
    classes = np.arange(1, nb_classes + 1)

    images_masks_stacked = stack_masks(images_train, images_masks, classes)
    logging.info('Masks stacked: %s', len(images_masks_stacked))

    # # plot_image(images_data[test_img_id] * channels_std + channels_mean)
    # for img_id in images_train:
    #     if images_masks[img_id][9].sum() > 0:
    #         logging.info('Image: %s, Vehicle: %s - %s, Crops: %s, Trees: %s',
    #                      img_id,
    #                      images_masks[img_id][9].sum(), images_masks[img_id][10].sum(),
    #                      images_masks[img_id][6].sum(), images_masks[img_id][5].sum(),
    #                      )

    # test on trees/crops
    test_img_id = '6120_2_2'

    # convert to one-hot and back
    masks_softmax = convert_masks_to_softmax(
        np.expand_dims(images_masks_stacked[test_img_id], 0))[0]
    masks_softmax_flattened = np.reshape(masks_softmax, (-1))
    mask_probs = np.zeros((masks_softmax_flattened.shape[0], 11))
    mask_probs[np.arange(masks_softmax_flattened.shape[0]),
               masks_softmax_flattened] = 1
    mask_probs = np.reshape(
        mask_probs, (masks_softmax.shape[0], masks_softmax.shape[1], 11))
    masks_converted = convert_softmax_to_masks(mask_probs)
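    # label round-trip sanity check: stacked masks -> per-pixel class ids -> one-hot
    # probabilities (presumably 10 classes + background, hence width 11) -> per-class masks;
    # the plots below compare the reconstruction against the original masks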

    classes_to_plot = [
        [5, 'Trees'],
        [6, 'Crops'],
        # [3, 'Roads'],
        # [4, 'Track'],
    ]
    for class_type, class_title in classes_to_plot:
        titles = [class_title + ' - ' + t for t in ['True', 'Reconstructed']]
        plot_two_masks(images_masks[test_img_id][class_type][:750, :750],
                       masks_converted[:750, :750, class_type - 1],
                       titles=titles)
def main():
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s : %(levelname)s : %(module)s : %(message)s",
        datefmt="%d-%m-%Y %H:%M:%S")

    matplotlib_setup()

    mean_sharpened, std_sharpened = load_pickle(IMAGES_MEANS_STDS_FILENAME)
    logging.info('Mean: %s, Std: %s', mean_sharpened.shape,
                 std_sharpened.shape)

    images_all, images_train, images_test = get_train_test_images_ids()

    logging.info('Train: %s, test: %s, all: %s', len(images_train),
                 len(images_test), len(images_all))

    target_images = images_all
    logging.info('Target images: %s', len(target_images))

    Parallel(n_jobs=8, verbose=5)(delayed(load_and_normalize_image)(
        img_id, mean_sharpened, std_sharpened) for img_id in target_images)
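    # load_and_normalize_image presumably reads each raw image, applies the channel-wise
    # mean/std normalization, and persists the result, so nothing needs to be collected here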
def main():
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s : %(levelname)s : %(module)s : %(message)s",
        datefmt="%d-%m-%Y %H:%M:%S")

    matplotlib_setup()

    images_data = load_pickle(IMAGES_NORMALIZED_FILENAME)
    images_masks = load_pickle(IMAGES_MASKS_FILENAME)
    logging.info('Images: %s, masks: %s', len(images_data), len(images_masks))

    images_metadata, channels_mean, channels_std = load_pickle(
        IMAGES_METADATA_FILENAME)
    logging.info('Images metadata: %s, mean: %s, std: %s',
                 len(images_metadata), channels_mean.shape, channels_std.shape)

    patch_size = (256, 256)
    nb_channels = 3
    nb_classes = 10

    nb_patches = 1000000
    mask_threshold = 0.15

    images = np.array(list(images_data.keys()))
    classes = np.arange(1, nb_classes + 1)

    images_masks_stacked = stack_masks(images, images_masks, classes)
    logging.info('Masks stacked: %s', len(images_masks_stacked))

    train_patches_coordinates = []
    while len(train_patches_coordinates) < nb_patches:
        try:
            img_id = np.random.choice(images)
            img_mask_data = images_masks_stacked[img_id]

            # # sample 8*32 masks from the same image
            # masks_batch = Parallel(n_jobs=4, verbose=10)(delayed(sample_patch)(
            #     img_id, img_mask_data, patch_size, mask_threshold, 32, 99) for i in range(64))
            #
            # for masks in masks_batch:
            #     train_patches_coordinates.extend(masks)
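            # sample up to 32 patch coordinates from this image; patches whose mask coverage
            # falls below mask_threshold are presumably rejected inside sample_patch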
            masks = sample_patch(img_id,
                                 img_mask_data,
                                 patch_size,
                                 threshold=mask_threshold,
                                 nb_masks=32)
            train_patches_coordinates.extend(masks)

            nb_sampled = len(train_patches_coordinates)
            if nb_sampled % 50 == 0:
                logging.info('Sampled %s/%s [%.2f%%]', nb_sampled, nb_patches,
                             100 * nb_sampled / nb_patches)

        except KeyboardInterrupt:
            break

    shuffle(train_patches_coordinates)
    logging.info('Sampled patches: %s', len(train_patches_coordinates))

    save_pickle(TRAIN_PATCHES_COORDINATES_FILENAME, train_patches_coordinates)
    logging.info('Saved: %s',
                 os.path.basename(TRAIN_PATCHES_COORDINATES_FILENAME))
def main(kind):
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s : %(levelname)s : %(module)s : %(message)s",
        datefmt="%d-%m-%Y %H:%M:%S")

    matplotlib_setup()

    images_data = load_pickle(IMAGES_NORMALIZED_SHARPENED_FILENAME)
    logging.info('Images: %s', len(images_data))

    images_masks = load_pickle(IMAGES_MASKS_FILENAME)
    logging.info('Masks: %s', len(images_masks))

    images_metadata = load_pickle(IMAGES_METADATA_FILENAME)
    logging.info('Metadata: %s', len(images_metadata))

    images_metadata_polygons = load_pickle(IMAGES_METADATA_POLYGONS_FILENAME)
    logging.info('Polygons metadata: %s', len(images_metadata_polygons))

    mean_sharpened, std_sharpened = load_pickle(IMAGES_MEANS_STDS_FILENAME)
    logging.info('Mean: %s, Std: %s', mean_sharpened.shape,
                 std_sharpened.shape)

    images_all, images_train, images_test = get_train_test_images_ids()
    logging.info('Train: %s, test: %s, all: %s', len(images_train),
                 len(images_test), len(images_all))

    if kind == 'test':
        target_images = images_test
    elif kind == 'train':
        target_images = images_train
    else:
        raise ValueError('Unknown kind: {}'.format(kind))

    nb_target_images = len(target_images)
    logging.info('Target images: %s - %s', kind, nb_target_images)

    nb_classes = len(images_masks[images_train[0]])
    classes = np.arange(1, nb_classes + 1)

    images_masks_stacked = None
    if kind == 'train':
        images_masks_stacked = stack_masks(target_images, images_masks,
                                           classes)
        logging.info('Masks stacked: %s', len(images_masks_stacked))

    jaccards = []
    jaccards_simplified = []
    model_name = 'softmax_pansharpen_tiramisu_small_patch'
    for img_idx, img_id in enumerate(target_images):
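        # debug filter: evaluate and plot only this image; remove the check to process all target images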
        if img_id != '6040_4_4':  # 6010_1_2 6040_4_4 6060_2_3
            continue

        mask_filename = os.path.join(IMAGES_PREDICTION_MASK_DIR,
                                     '{0}_{1}.npy'.format(img_id, model_name))
        if not os.path.isfile(mask_filename):
            logging.warning('Cannot find masks for image: %s', img_id)
            continue

        img_data = None
        if kind == 'train':
            img_data = images_data[img_id] * std_sharpened + mean_sharpened
        if kind == 'test':
            img_filename = os.path.join(IMAGES_NORMALIZED_DATA_DIR,
                                        img_id + '.npy')
            img_data = np.load(img_filename)

        img_metadata = images_metadata[img_id]
        img_mask_pred = np.load(mask_filename)

        if kind == 'train':
            img_poly_true = images_metadata_polygons[img_id]
            img_mask_true = images_masks_stacked[img_id]
        else:
            img_poly_true = None
            img_mask_true = None

        # plot_image(img_data[:,:,:3])

        img_mask_pred_simplified = simplify_mask(img_mask_pred, kernel_size=5)
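        # simplify_mask presumably smooths the prediction with a 5x5 morphological kernel,
        # removing small speckles before the jaccard comparison below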

        # if kind == 'train':
        #     for i, class_name in enumerate(CLASSES_NAMES):
        #         if img_mask_true[:,:,i].sum() > 0:
        #             plot_two_masks(img_mask_true[:,:,i], img_mask_pred[:,:,i],
        #                 titles=['Ground Truth - {}'.format(class_name), 'Prediction - {}'.format(class_name)])
        #             plot_two_masks(img_mask_pred[:,:,i], img_mask_pred_simplified[:,:,i],
        #                 titles=['Ground Truth - {}'.format(class_name), 'Prediction Simplified - {}'.format(class_name)])

        # img_poly_pred = create_image_polygons(img_mask_pred, img_metadata, scale=False)
        # plot_polygons(img_data[:,:,:3], img_metadata, img_poly_pred, img_poly_true, title=img_id, show=False)

        if kind == 'train':
            # convert predicted polygons to mask
            jaccard = jaccard_coef(img_mask_pred, img_mask_true)
            jaccards.append(jaccard)
            jaccard_simplified = jaccard_coef(img_mask_pred_simplified,
                                              img_mask_true)
            jaccards_simplified.append(jaccard_simplified)
            logging.info('Image: %s, jaccard: %s, jaccard simplified: %s',
                         img_id, jaccard, jaccard_simplified)

    if kind == 'train':
        logging.info('Mean jaccard: %s, Mean jaccard simplified: %s',
                     np.mean(jaccards), np.mean(jaccards_simplified))

    import matplotlib.pyplot as plt
    plt.show()
def main(model_name, classes_to_skip, patch_size, nb_iterations, batch_size,
         debug, regularization, model_load_step, prediction_images):
    # set-up matplotlib
    matplotlib_setup()

    logging.info('Model name: %s', model_name)

    classes_names = [
        c.strip().lower().replace(' ', '_').replace('.', '')
        for c in CLASSES_NAMES
    ]
    nb_classes = len(classes_names)
    classes = np.arange(1, nb_classes + 1)
    logging.info('Classes: %s', nb_classes)

    images_all, images_train, images_test = get_train_test_images_ids()
    logging.info('Train: %s, test: %s, all: %s', len(images_train),
                 len(images_test), len(images_all))

    # load images data
    images_data = load_pickle(IMAGES_NORMALIZED_SHARPENED_FILENAME)
    logging.info('Images: %s', len(images_data))

    # load masks
    images_masks = load_pickle(IMAGES_MASKS_FILENAME)
    logging.info('Masks: %s', len(images_masks))

    # load images metadata
    images_metadata = load_pickle(IMAGES_METADATA_FILENAME)
    logging.info('Metadata: %s', len(images_metadata))

    mean_sharpened, std_sharpened = load_pickle(IMAGES_MEANS_STDS_FILENAME)
    logging.info('Mean: %s, Std: %s', mean_sharpened.shape,
                 std_sharpened.shape)

    images = sorted(list(images_data.keys()))
    nb_images = len(images)
    logging.info('Train images: %s', nb_images)

    images_masks_stacked = stack_masks(images, images_masks, classes)
    logging.info('Masks stacked: %s', len(images_masks_stacked))

    nb_channels = images_data[images[0]].shape[2]
    logging.info('Channels: %s', nb_channels)

    # drop the classes passed in classes_to_skip (typically vehicles and misc manmade structures)
    needed_classes = [
        c for c in range(nb_classes) if c + 1 not in classes_to_skip
    ]
    needed_classes_names = [
        c for i, c in enumerate(classes_names) if i + 1 not in classes_to_skip
    ]
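    # classes_to_skip holds 1-based class ids; needed_classes are the matching 0-based channel indices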
    nb_needed_classes = len(needed_classes)
    logging.info('Skipping classes: %s, needed classes: %s', classes_to_skip,
                 nb_needed_classes)

    # create model

    sess_config = tf.ConfigProto(inter_op_parallelism_threads=4,
                                 intra_op_parallelism_threads=4)
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)

    model_params = {
        'nb_classes': nb_needed_classes,
        'regularization': regularization,
    }
    model = create_model(model_params)  # SimpleModel(**model_params)
    model.set_session(sess)
    if not debug:
        model.set_tensorboard_dir(os.path.join(TENSORBOARD_DIR, model_name))

    # TODO: not a fixed size
    model.add_input('X', [patch_size[0], patch_size[1], nb_channels])
    model.add_input('Y_softmax', [patch_size[0], patch_size[1]], dtype=tf.uint8)

    model.build_model()

    # train model
    if model_load_step == -1:

        iteration_number = 0
        jaccard_val_mean = 0
        jaccard_train_mean = 0
        for iteration_number in range(1, nb_iterations + 1):
            try:
                data_dict_train = sample_data_dict(images,
                                                   images_masks_stacked,
                                                   images_data, 'train',
                                                   needed_classes)
                model.train_model(data_dict_train,
                                  nb_epoch=1,
                                  batch_size=batch_size)

                # validate the model
                if iteration_number % 5 == 0:
                    jaccard_val = evaluate_model_jaccard(model,
                                                         images,
                                                         images_masks_stacked,
                                                         images_data,
                                                         needed_classes,
                                                         kind='val')
                    jaccard_train = evaluate_model_jaccard(
                        model,
                        images,
                        images_masks_stacked,
                        images_data,
                        needed_classes,
                        kind='train')
                    jaccard_val_mean = np.mean(jaccard_val)
                    jaccard_train_mean = np.mean(jaccard_train)

                    logging.info(
                        'Iteration %s, jaccard val: %.5f, jaccard train: %.5f',
                        iteration_number, jaccard_val_mean, jaccard_train_mean)

                    for i, cls in enumerate(needed_classes_names):
                        model.write_scalar_summary(
                            'jaccard_val/{}'.format(cls), jaccard_val[i])
                        model.write_scalar_summary(
                            'jaccard_train/{}'.format(cls), jaccard_train[i])

                    model.write_scalar_summary('jaccard_mean/val',
                                               jaccard_val_mean)
                    model.write_scalar_summary('jaccard_mean/train',
                                               jaccard_train_mean)

                # save the model
                if iteration_number % 500 == 0:
                    model_filename = os.path.join(MODELS_DIR, model_name)
                    saved_filename = model.save_model(model_filename)
                    logging.info('Model saved: %s', saved_filename)

            except KeyboardInterrupt:
                break

        model_filename = os.path.join(MODELS_DIR, model_name)
        saved_filename = model.save_model(model_filename)
        logging.info('Model saved: %s', saved_filename)

        result = {
            'iteration': iteration_number,
            'jaccard_val': jaccard_val_mean,
            'jaccard_train': jaccard_train_mean,
        }
        return result

    # predict
    else:
        model_to_restore = '{}-{}'.format(model_name, model_load_step)
        model_filename = os.path.join(MODELS_DIR, model_to_restore)
        model.restore_model(model_filename)
        logging.info('Model restored: %s', os.path.basename(model_filename))

        logging.info('Target images: %s', prediction_images)

        if prediction_images == 'train':
            target_images = images_train
        elif prediction_images == 'test':
            target_images = images_test
        elif prediction_images == 'test_real':
            target_images = IMAGES_TEST_REAL
        else:
            raise ValueError(
                'Prediction images `{}` unknown'.format(prediction_images))

        for img_number, img_id in enumerate(target_images):
            # if img_id != '6060_2_3':
            #     continue

            img_filename = os.path.join(IMAGES_NORMALIZED_DATA_DIR,
                                        img_id + '.pkl')
            img_normalized = load_pickle(img_filename)

            patches, patches_coord = split_image_to_patches(
                [img_normalized], [patch_size], overlap=0.8)
            logging.info('Patches: %s', len(patches[0]))

            X = patches[0]

            data_dict = {'X': X}
            classes_probs_patches = model.predict(data_dict,
                                                  batch_size=batch_size)

            classes_probs = join_mask_patches(
                classes_probs_patches,
                patches_coord[0],
                images_metadata[img_id]['height_rgb'],
                images_metadata[img_id]['width_rgb'],
                softmax=True,
                normalization=False)

            masks_without_excluded = convert_softmax_to_masks(classes_probs)

            # join masks and put zeros instead of excluded classes
            zeros_filler = np.zeros_like(masks_without_excluded[:, :, 0])
            masks_all = []
            j = 0
            for i in range(nb_classes):
                if i + 1 not in classes_to_skip:
                    masks_all.append(masks_without_excluded[:, :, j])
                    j += 1
                else:
                    masks_all.append(zeros_filler)

            masks = np.stack(masks_all, axis=-1)

            save_prediction_mask(IMAGES_PREDICTION_MASK_DIR, masks, img_id,
                                 model_name)

            logging.info('Predicted: %s/%s [%.2f%%]', img_number + 1,
                         len(target_images),
                         100 * (img_number + 1) / len(target_images))

        result = {}
        return result
def predict(kind, model_name, global_step):
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s : %(levelname)s : %(module)s : %(message)s",
        datefmt="%d-%m-%Y %H:%M:%S")

    matplotlib_setup()

    logging.info('Prediction mode')

    nb_channels_m = 8
    nb_channels_sharpened = 4
    nb_classes = 10

    # skip vehicles and misc manmade structures
    classes_to_skip = {2, 9, 10}
    logging.info('Skipping classes: %s', classes_to_skip)

    # skip M bands that were pansharpened
    m_bands_to_skip = {4, 2, 1, 6}
    needed_m_bands = [
        i for i in range(nb_channels_m) if i not in m_bands_to_skip
    ]
    logging.info('Skipping M bands: %s', m_bands_to_skip)

    patch_size = (64, 64)  # (224, 224)
    patch_size_sharpened = (patch_size[0], patch_size[1])
    patch_size_m = (patch_size_sharpened[0] // 4, patch_size_sharpened[1] // 4)
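    # the M-band imagery is presumably at 1/4 the spatial resolution of the pansharpened bands,
    # hence the four-times-smaller patch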
    logging.info('Patch sizes: %s, %s, %s', patch_size, patch_size_sharpened,
                 patch_size_m)

    images_metadata = load_pickle(IMAGES_METADATA_FILENAME)
    logging.info('Metadata: %s', len(images_metadata))

    images_metadata_polygons = load_pickle(IMAGES_METADATA_POLYGONS_FILENAME)
    logging.info('Polygons metadata: %s', len(images_metadata_polygons))

    images_all, images_train, images_test = get_train_test_images_ids()

    if kind == 'test':
        target_images = images_test
    elif kind == 'train':
        target_images = images_train
    else:
        raise ValueError('Unknown kind: {}'.format(kind))

    nb_target_images = len(target_images)
    logging.info('Target images: %s - %s', kind, nb_target_images)

    batch_size = 25

    # create and load model
    sess_config = tf.ConfigProto(inter_op_parallelism_threads=4,
                                 intra_op_parallelism_threads=4)
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)

    model_params = {
        'nb_classes': nb_classes - len(classes_to_skip),
    }
    model = CombinedModel(**model_params)
    model.set_session(sess)
    # model.set_tensorboard_dir(os.path.join(TENSORBOARD_DIR, 'simple_model'))

    # TODO: not a fixed size
    model.add_input('X_sharpened', [
        patch_size_sharpened[0], patch_size_sharpened[1], nb_channels_sharpened
    ])
    model.add_input('X_m', [
        patch_size_m[0], patch_size_m[1], nb_channels_m - len(m_bands_to_skip)
    ])
    model.add_input('Y', [patch_size[0], patch_size[1]], dtype=tf.uint8)

    model.build_model()

    model_to_restore = '{}-{}'.format(model_name, global_step)
    model_filename = os.path.join(MODELS_DIR, model_to_restore)
    model.restore_model(model_filename)
    logging.info('Model restored: %s', os.path.basename(model_filename))

    for img_number, img_id in enumerate(target_images):
        img_filename = os.path.join(IMAGES_NORMALIZED_DATA_DIR,
                                    img_id + '.pkl')
        img_normalized_sharpened, img_normalized_m = load_pickle(img_filename)

        patches, patches_coord = split_image_to_patches(
            [img_normalized_sharpened, img_normalized_m],
            [patch_size_sharpened, patch_size_m],
            overlap=0.5)

        X_sharpened = np.array(patches[0])
        X_m = np.array(patches[1])
        X_m = X_m[:, :, :, needed_m_bands]

        data_dict = {
            'X_sharpened': X_sharpened,
            'X_m': X_m,
        }
        classes_probs_patches = model.predict(data_dict, batch_size=batch_size)

        classes_probs = join_mask_patches(
            classes_probs_patches,
            patches_coord[0],
            images_metadata[img_id]['height_rgb'],
            images_metadata[img_id]['width_rgb'],
            softmax=True,
            normalization=False)

        masks_without_excluded = convert_softmax_to_masks(classes_probs)

        # join masks and put zeros instead of excluded classes
        zeros_filler = np.zeros_like(masks_without_excluded[:, :, 0])
        masks_all = []
        j = 0
        for i in range(nb_classes):
            if i + 1 not in classes_to_skip:
                masks_all.append(masks_without_excluded[:, :, j])
                j += 1
            else:
                masks_all.append(zeros_filler)

        masks = np.stack(masks_all, axis=-1)

        mask_filename = os.path.join(IMAGES_PREDICTION_MASK_DIR,
                                     '{0}_{1}.npy'.format(img_id, model_name))
        np.save(mask_filename, masks)

        logging.info('Predicted: %s/%s [%.2f%%]', img_number + 1,
                     nb_target_images,
                     100 * (img_number + 1) / nb_target_images)
def main(model_name):
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s : %(levelname)s : %(module)s : %(message)s",
        datefmt="%d-%m-%Y %H:%M:%S")

    matplotlib_setup()

    classes_names = [
        c.strip().lower().replace(' ', '_').replace('.', '')
        for c in CLASSES_NAMES
    ]
    nb_classes = len(classes_names)
    classes = np.arange(1, nb_classes + 1)
    logging.info('Classes: %s', nb_classes)

    # load images data
    images_data_m = load_pickle(IMAGES_NORMALIZED_M_FILENAME)
    images_data_sharpened = load_pickle(IMAGES_NORMALIZED_SHARPENED_FILENAME)
    logging.info('Images: %s, %s', len(images_data_m),
                 len(images_data_sharpened))

    # load masks
    images_masks = load_pickle(IMAGES_MASKS_FILENAME)
    logging.info('Masks: %s', len(images_masks))

    # load images metadata
    images_metadata = load_pickle(IMAGES_METADATA_FILENAME)
    logging.info('Metadata: %s', len(images_metadata))

    mean_m, std_m, mean_sharpened, std_sharpened = load_pickle(
        IMAGES_MEANS_STDS_FILENAME)
    logging.info('Mean & Std: %s - %s, %s - %s', mean_m.shape, std_m.shape,
                 mean_sharpened.shape, std_sharpened.shape)

    images = sorted(list(images_data_sharpened.keys()))
    nb_images = len(images)
    logging.info('Train images: %s', nb_images)

    images_masks_stacked = stack_masks(images, images_masks, classes)
    logging.info('Masks stacked: %s', len(images_masks_stacked))

    nb_channels_m = images_data_m[images[0]].shape[2]
    nb_channels_sharpened = images_data_sharpened[images[0]].shape[2]
    logging.info('Channels: %s, %s', nb_channels_m, nb_channels_sharpened)

    # keep only misc manmade structures and vehicles; the commented set {2, 9, 10} is the
    # original choice of classes to skip
    classes_to_skip = {1, 3, 4, 5, 6, 7, 8}  # {2, 9, 10}
    needed_classes = [
        c for c in range(nb_classes) if c + 1 not in classes_to_skip
    ]
    needed_classes_names = [
        c for i, c in enumerate(classes_names) if i + 1 not in classes_to_skip
    ]
    logging.info('Skipping classes: %s', classes_to_skip)

    # skip M bands that were pansharpened
    m_bands_to_skip = {4, 2, 1, 6}
    needed_m_bands = [
        i for i in range(nb_channels_m) if i not in m_bands_to_skip
    ]
    logging.info('Skipping M bands: %s', m_bands_to_skip)

    patch_size = (64, 64)  # (224, 224)
    patch_size_sharpened = (patch_size[0], patch_size[1])
    patch_size_m = (patch_size_sharpened[0] // 4, patch_size_sharpened[1] // 4)
    logging.info('Patch sizes: %s, %s, %s', patch_size, patch_size_sharpened,
                 patch_size_m)

    val_size = 256
    logging.info('Validation size: %s', val_size)

    sess_config = tf.ConfigProto(inter_op_parallelism_threads=4,
                                 intra_op_parallelism_threads=4)
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)

    model_params = {
        'nb_classes': nb_classes - len(classes_to_skip),
    }
    model = CombinedModel(**model_params)
    model.set_session(sess)
    model.set_tensorboard_dir(os.path.join(TENSORBOARD_DIR, model_name))

    # TODO: not a fixed size
    model.add_input('X_sharpened', [
        patch_size_sharpened[0], patch_size_sharpened[1], nb_channels_sharpened
    ])
    model.add_input('X_m', [
        patch_size_m[0], patch_size_m[1], nb_channels_m - len(m_bands_to_skip)
    ])
    model.add_input('Y', [patch_size[0], patch_size[1]], dtype=tf.uint8)

    model.build_model()

    # train model

    nb_iterations = 100000
    nb_samples_train = 1000  # 10 1000
    nb_samples_val = 512  # 10 512
    batch_size = 30  # 5 30

    for iteration_number in range(1, nb_iterations + 1):
        try:
            patches = sample_patches(
                images,
                [images_masks_stacked, images_data_sharpened, images_data_m],
                [patch_size, patch_size_sharpened, patch_size_m],
                nb_samples_train,
                kind='train',
                val_size=val_size)

            Y, X_sharpened, X_m = patches[0], patches[1], patches[2]
            Y_softmax = convert_masks_to_softmax(Y,
                                                 needed_classes=needed_classes)
            X_m = X_m[:, :, :, needed_m_bands]

            data_dict_train = {
                'X_sharpened': X_sharpened,
                'X_m': X_m,
                'Y': Y_softmax
            }
            model.train_model(data_dict_train,
                              nb_epoch=1,
                              batch_size=batch_size)

            # validate the model
            if iteration_number % 5 == 0:

                # calc jaccard val
                patches_val = sample_patches(
                    images,
                    [images_masks_stacked, images_data_sharpened, images_data_m],
                    [patch_size, patch_size_sharpened, patch_size_m],
                    nb_samples_val,
                    kind='val',
                    val_size=val_size)

                Y_val, X_sharpened_val, X_m_val = patches_val[0], patches_val[1], patches_val[2]
                X_m_val = X_m_val[:, :, :, needed_m_bands]

                data_dict_val = {
                    'X_sharpened': X_sharpened_val,
                    'X_m': X_m_val,
                }
                Y_val_pred_probs = model.predict(data_dict_val,
                                                 batch_size=batch_size)
                Y_val_pred = np.stack(
                    [convert_softmax_to_masks(Y_val_pred_probs[i]) for i in range(nb_samples_val)],
                    axis=0)

                Y_val = Y_val[:, :, :, needed_classes]
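                # restrict the ground truth to the predicted classes so shapes match for the jaccard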
                jaccard_val = jaccard_coef(Y_val_pred, Y_val, mean=False)

                # calc jaccard train
                patches_train_val = sample_patches(
                    images,
                    [images_masks_stacked, images_data_sharpened, images_data_m],
                    [patch_size, patch_size_sharpened, patch_size_m],
                    nb_samples_val,
                    kind='train',
                    val_size=val_size)

                Y_train_val, X_sharpened_train_val, X_m_train_val = \
                    patches_train_val[0], patches_train_val[1], patches_train_val[2]
                X_m_train_val = X_m_train_val[:, :, :, needed_m_bands]

                data_dict_val = {
                    'X_sharpened': X_sharpened_train_val,
                    'X_m': X_m_train_val,
                }
                Y_train_val_pred_probs = model.predict(data_dict_val,
                                                       batch_size=batch_size)
                Y_train_val_pred = np.stack(
                    [convert_softmax_to_masks(Y_train_val_pred_probs[i]) for i in range(nb_samples_val)],
                    axis=0)

                Y_train_val = Y_train_val[:, :, :, needed_classes]
                jaccard_train_val = jaccard_coef(Y_train_val_pred,
                                                 Y_train_val,
                                                 mean=False)

                logging.info(
                    'Iteration %s, jaccard val: %.5f, jaccard train: %.5f',
                    iteration_number, np.mean(jaccard_val),
                    np.mean(jaccard_train_val))

                for i, cls in enumerate(needed_classes_names):
                    model.write_scalar_summary('jaccard_val/{}'.format(cls),
                                               jaccard_val[i])
                    model.write_scalar_summary('jaccard_train/{}'.format(cls),
                                               jaccard_train_val[i])

                model.write_scalar_summary('jaccard_mean/val',
                                           np.mean(jaccard_val))
                model.write_scalar_summary('jaccard_mean/train',
                                           np.mean(jaccard_train_val))

            # save the model
            if iteration_number % 15 == 0:
                model_filename = os.path.join(MODELS_DIR, model_name)
                saved_filename = model.save_model(model_filename)
                logging.info('Model saved: %s', saved_filename)

        except KeyboardInterrupt:
            break
def main(model_name):
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s : %(levelname)s : %(module)s : %(message)s", datefmt="%d-%m-%Y %H:%M:%S"
    )

    nb_classes = 10
    skip_classes = None
    double_pass = False
    use_close = False
    use_min_area = False

    matplotlib_setup()

    min_areas = {cls: 1.0 for cls in range(nb_classes)}
    if use_min_area:
        min_areas[5] = 5000 # crops
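    # per-class minimum polygon area (keyed by 0-based class index), passed to
    # create_image_polygons below; with use_min_area disabled every class keeps the 1.0 default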


    logging.info('Skip classes: %s', skip_classes)
    logging.info('Mode: %s', 'double pass' if double_pass else 'single pass')

    # load images metadata
    images_metadata = load_pickle(IMAGES_METADATA_FILENAME)
    logging.info('Images metadata: %s', len(images_metadata))

    sample_submission = load_sample_submission(SAMPLE_SUBMISSION_FILENAME)
    submission_order = [(row['ImageId'], row['ClassType']) for i, row in sample_submission.iterrows()]

    target_images = sorted(set([r[0] for r in submission_order]))
    # target_images = target_images[:10]
    logging.info('Target images: %s', len(target_images))

    polygons = {}
    for i, img_id in enumerate(target_images):
        img_metadata = images_metadata[img_id]

        try:
            img_mask = load_prediction_mask(IMAGES_PREDICTION_MASK_DIR, img_id, model_name)
        except IOError:
            logging.warning('Cannot find prediction mask for image: %s, skipping', img_id)
            continue

        # do closing for roads and tracks
        if use_close:
            img_mask_closed_tmp = close_mask(img_mask, kernel_size=5)
            img_mask_closed = np.copy(img_mask)
            img_mask_closed[:, :, [2, 3]] = img_mask_closed_tmp[:, :, [2, 3]]
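            # channel indices 2 and 3 are the 0-based positions of roads and tracks (classes 3 and 4)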

            img_mask = img_mask_closed

        if not double_pass:
            # img_mask_simplified = simplify_mask(img_mask, kernel_size=5)
            img_polygons = create_image_polygons(img_mask, img_metadata,
                                                 scale=True, skip_classes=skip_classes, min_areas=min_areas)
        else:
            # img_polygons = create_image_polygons(img_mask, img_metadata, scale=False, skip_classes=skip_classes)
            #
            # img_mask_reconstructed = []
            # for class_type in sorted(img_polygons.keys()):
            #     ploy_metadata = {'ploy_scaled': img_polygons[class_type].wkt}
            #     img_class_mask_reconstructed = create_mask_from_metadata(img_metadata, ploy_metadata)
            #     img_mask_reconstructed.append(img_class_mask_reconstructed)
            #
            # img_mask = np.stack(img_mask_reconstructed, axis=-1)
            # img_polygons = create_image_polygons(img_mask, img_metadata, scale=True, skip_classes=skip_classes)
            raise NotImplementedError('Double pass is not implemented yet')

        polygons[img_id] = img_polygons

        if (i + 1) % 10 == 0:
            logging.info('\n\nProcessed images: %s/%s [%.2f%%]\n\n',
                         i + 1, len(target_images), 100 * (i + 1) / len(target_images))

    submission_filename = os.path.join(SUBMISSION_DIR, 'submission_{}.csv'.format(model_name))
    save_submission(polygons, submission_order, submission_filename, skip_classes=skip_classes)
def main():
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s : %(levelname)s : %(module)s : %(message)s",
        datefmt="%d-%m-%Y %H:%M:%S")

    matplotlib_setup()

    grid_sizes = load_grid_sizes(GRID_SIZES_FILENAME)
    polygons = load_polygons(POLYGONS_FILENAME)

    images_all, images_train, images_test = get_train_test_images_ids()
    logging.info('Train: %s, Test: %s, All: %s', len(images_train),
                 len(images_test), len(images_all))

    # create images metadata
    images_sizes_rgb = get_images_sizes(IMAGES_THREE_BAND_DIR,
                                        target_images=images_all)
    images_sizes_m = get_images_sizes(IMAGES_SIXTEEN_BAND_DIR,
                                      target_images=images_all,
                                      target_format='M')
    images_sizes_p = get_images_sizes(IMAGES_SIXTEEN_BAND_DIR,
                                      target_images=images_all,
                                      target_format='P')
    images_metadata, images_metadata_polygons = create_images_metadata(
        grid_sizes, polygons, images_sizes_rgb, images_sizes_m, images_sizes_p)
    logging.info('Metadata: %s, polygons metadata: %s', len(images_metadata),
                 len(images_metadata_polygons))

    # load train images
    images_data_m = load_images(IMAGES_SIXTEEN_BAND_DIR,
                                target_images=images_train,
                                target_format='M')
    images_data_p = load_images(IMAGES_SIXTEEN_BAND_DIR,
                                target_images=images_train,
                                target_format='P')

    # pansharpen to get (R,G,B,NIR) + (rest,) scaled images
    images_data_sharpened = pansharpen_images(images_data_m, images_data_p)
    logging.info('Images sharpened: %s', len(images_data_sharpened))

    # create masks using RGB sizes
    images_masks = create_classes_masks(images_metadata,
                                        images_metadata_polygons)
    logging.info('Masks created: %s', len(images_masks))

    # free the memory
    del images_data_m
    del images_data_p

    # normalize the data channel by channel
    nb_channels_sharpened = images_data_sharpened[images_train[0]].shape[2]
    channels_means_stds_sharpened = []
    for i in range(nb_channels_sharpened):
        ch_mean_std = calculate_channel_mean_std(images_data_sharpened, i)
        channels_means_stds_sharpened.append(ch_mean_std)

        logging.info('Channel normalized: %s', i)

    channels_means_stds_sharpened = np.array(channels_means_stds_sharpened)
    mean_sharpened = channels_means_stds_sharpened[:, 0]
    std_sharpened = channels_means_stds_sharpened[:, 1]

    images_data_sharpened_normalized = normalize_images(
        images_data_sharpened, mean_sharpened, std_sharpened)

    save_pickle(IMAGES_METADATA_FILENAME, images_metadata)
    save_pickle(IMAGES_METADATA_POLYGONS_FILENAME, images_metadata_polygons)
    save_pickle(IMAGES_MASKS_FILENAME, images_masks)

    save_pickle(IMAGES_NORMALIZED_SHARPENED_FILENAME,
                images_data_sharpened_normalized)

    save_pickle(IMAGES_MEANS_STDS_FILENAME, [mean_sharpened, std_sharpened])