def cache_train_16():
    """Crop every 16-band training image/mask to the smallest common size
    and save each pair as per-image .npy files under ../data/data_files/.

    NOTE(review): this file defines cache_train_16 several times; at import
    time only the last definition is kept — confirm which one is intended.

    Relies on module-level globals: train_wkt, shapes, num_mask_channels,
    extra_functions, np, tqdm.
    """
    print('num_train_images =', train_wkt['ImageId'].nunique())
    train_shapes = shapes[shapes['image_id'].isin(
        train_wkt['ImageId'].unique())]
    # Persist the ids of the images that make up the training set.
    np.save('../data/train_ids.npy', train_shapes['image_id'])
    # Cropping to the minimum height/width guarantees uniform array shapes.
    min_train_height = train_shapes['height'].min()
    min_train_width = train_shapes['width'].min()

    for image_id in tqdm(sorted(train_wkt['ImageId'].unique())):
        image = extra_functions.read_image_16(image_id)
        _, height, width = image.shape  # channels-first layout

        img = image[:, :min_train_height, :min_train_width]
        # The mask is generated at full resolution, then cropped identically.
        img_mask = extra_functions.generate_mask(
            image_id,
            height,
            width,
            num_mask_channels=num_mask_channels,
            train=train_wkt)[:, :min_train_height, :min_train_width]

        np.save('../data/data_files/{}_img.npy'.format(image_id), img)
        np.save('../data/data_files/{}_mask.npy'.format(image_id), img_mask)
def cache_train_vehicle():
    """Cache the 22-channel vehicle-class training data into train_vehicle.h5.

    Selects the images containing classes 9/10 (vehicles), resizes every
    image and mask to the smallest height/width seen in that set, and stores
    them as gzip-compressed HDF5 datasets 'train' and 'train_mask', plus the
    image ids under 'train_ids'.

    Relies on module-level globals: extra_functions, shapes, train_wkt,
    data_path, h5py, cv2, np, os.
    """

    # Only images that contain class 9 or 10 (the vehicle classes).
    image_set = extra_functions.get_class_image(classes=[9, 10])

    num_train = len(image_set)

    print('num_train_images =', num_train)

    train_shapes = shapes[shapes['image_id'].isin(image_set)]

    # Resize target: the smallest height/width across the selected images.
    image_rows = train_shapes['height'].min()
    image_cols = train_shapes['width'].min()

    num_channels = 22

    # One mask channel per vehicle class.
    num_mask_channels = 2

    f = h5py.File(os.path.join(data_path, 'train_vehicle.h5'), 'w')

    imgs = f.create_dataset('train',
                            (num_train, num_channels, image_rows, image_cols),
                            dtype=np.float32,
                            compression='gzip',
                            compression_opts=9)
    imgs_mask = f.create_dataset(
        'train_mask', (num_train, num_mask_channels, image_rows, image_cols),
        dtype=np.uint8,
        compression='gzip',
        compression_opts=9)

    ids = []

    i = 0
    for image_id in image_set:
        print(image_id)
        image = extra_functions.read_image_22(image_id)
        # assumes read_image_22 returns (height, width, channels) — note that
        # read_image_16 elsewhere in this file is unpacked channels-first;
        # TODO confirm.
        height, width, _ = image.shape

        # Resize to the common size, then HWC -> CHW for storage.
        imgs[i] = np.transpose(
            cv2.resize(image, (image_cols, image_rows),
                       interpolation=cv2.INTER_CUBIC), (2, 0, 1))
        # Mask pipeline: generate CHW mask -> HWC, resize, -> back to CHW.
        # NOTE(review): INTER_CUBIC on a uint8 mask produces interpolated
        # (non-binary) edge values — INTER_NEAREST may have been intended;
        # confirm what downstream consumers expect before changing.
        imgs_mask[i] = np.transpose(
            cv2.resize(np.transpose(
                extra_functions.generate_mask(
                    image_id,
                    height,
                    width,
                    start=0,
                    num_mask_channels=num_mask_channels,
                    train=train_wkt), (1, 2, 0)), (image_cols, image_rows),
                       interpolation=cv2.INTER_CUBIC), (2, 0, 1))

        ids += [image_id]
        i += 1

    # fix from there: https://github.com/h5py/h5py/issues/441
    f['train_ids'] = np.array(ids).astype('|S9')

    f.close()
def cache_train_16():
    """Cache the 16-band training images and masks into train_16.h5.

    Crops every image/mask to the smallest height/width across the training
    set and stores them as HDF5 datasets 'train' (float16) and 'train_mask'
    (uint8), with the image ids under 'train_ids'.

    NOTE(review): this file defines cache_train_16 several times; at import
    time only the last definition is kept — confirm which one is intended.

    Relies on module-level globals: train_wkt, shapes, data_path, h5py,
    extra_functions, np, os, tqdm.
    """
    print('num_train_images =', train_wkt['ImageId'].nunique())

    train_shapes = shapes[shapes['image_id'].isin(
        train_wkt['ImageId'].unique())]

    # Cropping to the minimum height/width guarantees uniform array shapes.
    min_train_height = train_shapes['height'].min()
    min_train_width = train_shapes['width'].min()

    num_train = train_shapes.shape[0]

    image_rows = min_train_height
    image_cols = min_train_width

    num_channels = 16
    num_mask_channels = 10

    # BUG FIX: h5py.File() does not accept the dataset-creation keywords
    # 'compression'/'compression_opts' — passing them raises TypeError —
    # and 'blosc:lz4' is not a filter h5py supports natively. Compression
    # belongs on create_dataset (gzip, as cache_train_vehicle does).
    f = h5py.File(os.path.join(data_path, 'train_16.h5'), 'w')

    imgs = f.create_dataset('train',
                            (num_train, num_channels, image_rows, image_cols),
                            dtype=np.float16,
                            compression='gzip',
                            compression_opts=9)
    imgs_mask = f.create_dataset(
        'train_mask', (num_train, num_mask_channels, image_rows, image_cols),
        dtype=np.uint8,
        compression='gzip',
        compression_opts=9)

    ids = []

    for i, image_id in enumerate(tqdm(sorted(train_wkt['ImageId'].unique()))):
        image = extra_functions.read_image_16(image_id)
        _, height, width = image.shape  # channels-first layout

        imgs[i] = image[:, :min_train_height, :min_train_width]
        imgs_mask[i] = extra_functions.generate_mask(
            image_id,
            height,
            width,
            num_mask_channels=num_mask_channels,
            train=train_wkt)[:, :min_train_height, :min_train_width]

        ids += [image_id]

    # fix from there: https://github.com/h5py/h5py/issues/441
    f['train_ids'] = np.array(ids).astype('|S9')

    f.close()
def cache_test():
    """Cache three-band images and 10-channel masks for every wkt-labelled
    image into all.h5 ('image' / 'image_mask' datasets, ids in 'train_ids').

    NOTE(review): despite the name, this reads train_wkt_v4.csv and iterates
    the *training* image ids — confirm this is intentional.

    Relies on module-level globals: data_path, h5py, tiff, extra_functions,
    np, os, pd, tqdm.
    """
    train_wkt = pd.read_csv(os.path.join(data_path, 'train_wkt_v4.csv'))

    print('num_test_images =', train_wkt['ImageId'].nunique())

    num_images = train_wkt['ImageId'].nunique()

    # All images are cropped to a fixed 3328x3328 window.
    image_rows = 3328
    image_cols = 3328

    num_channels = 3
    num_mask_channels = 10

    # BUG FIX: h5py.File() does not accept the dataset-creation keywords
    # 'compression'/'compression_opts' — passing them raises TypeError —
    # and 'blosc:lz4' is not a filter h5py supports natively. Compression
    # belongs on create_dataset (gzip, as cache_train_vehicle does).
    f = h5py.File(os.path.join(data_path, 'all.h5'), 'w')

    imgs = f.create_dataset(
        'image',
        (num_images, num_channels, image_rows, image_cols),
        dtype=np.float16,
        compression='gzip',
        compression_opts=9)
    imgs_mask = f.create_dataset(
        'image_mask',
        (num_images, num_mask_channels, image_rows, image_cols),
        dtype=np.float16,
        compression='gzip',
        compression_opts=9)

    ids = []

    # Loop-invariant: build the path template once, not per iteration.
    img_fpath = os.path.join(data_path, 'three_band', '{}.tif')

    for i, image_id in enumerate(tqdm(sorted(train_wkt['ImageId'].unique()))):
        # 2047 = max value of the 11-bit three-band imagery; scale to [0, 1].
        image = tiff.imread(img_fpath.format(image_id)) / 2047.0
        _, height, width = image.shape  # channels-first layout
        imgs[i] = image[:, :image_rows, :image_cols]
        imgs_mask[i] = extra_functions.generate_mask(
            image_id,
            height,
            width,
            num_mask_channels=num_mask_channels,
            train=train_wkt)[:, :image_rows, :image_cols]

        ids += [image_id]

    # fix from there: https://github.com/h5py/h5py/issues/441
    f['train_ids'] = np.array(ids).astype('|S9')
    f.close()
def cache_train_16():
    """Crop 16-band training images/masks to the smallest common size and
    save each pair as per-image .npy files under ../data/data_files/.

    NOTE(review): this file defines cache_train_16 several times; at import
    time only the last definition is kept — confirm which one is intended.

    Relies on module-level globals: train_wkt, shapes, num_mask_channels,
    extra_functions, np, tqdm.
    """
    print('num_train_images =', train_wkt['ImageId'].nunique())

    train_shapes = shapes[shapes['image_id'].isin(
        train_wkt['ImageId'].unique())]
    # Persist the ids of the images that make up the training set.
    np.save('../data/train_ids.npy', train_shapes['image_id'])
    # Cropping to the minimum height/width guarantees uniform array shapes.
    min_train_height = train_shapes['height'].min()
    min_train_width = train_shapes['width'].min()

    for image_id in tqdm(sorted(train_wkt['ImageId'].unique())):
        image = extra_functions.read_image_16(image_id)
        _, height, width = image.shape  # channels-first layout

        img = image[:, :min_train_height, :min_train_width]
        # The mask is generated at full resolution, then cropped identically.
        img_mask = extra_functions.generate_mask(
            image_id,
            height,
            width,
            num_mask_channels=num_mask_channels,
            train=train_wkt)[:, :min_train_height, :min_train_width]

        np.save('../data/data_files/{}_img.npy'.format(image_id), img)
        np.save('../data/data_files/{}_mask.npy'.format(image_id), img_mask)
def cache_train_16():
    """Cache three-band imagery into three HDF5 splits: train_label.h5,
    train_unlabel.h5 and validation.h5.

    Each file gets a float16 'polygons' dataset, a uint8 'train_mask'
    dataset, and the split's image ids ('train_ids' for the two training
    splits, 'validation_ids' for the validation split). Every image/mask is
    cropped to a fixed 3328x3328 window.

    NOTE(review): this file defines cache_train_16 several times; at import
    time only the last definition is kept — confirm which one is intended.

    Relies on module-level globals: data_path, train_id, unlabel_id,
    validation_id, h5py, tiff, extra_functions, np, os, pd, tqdm.
    """
    train_wkt = pd.read_csv(os.path.join(data_path, 'train_wkt_v4.csv'))

    print('num_label_train_images =', len(train_id))
    print('num_unlabel_train_images =', len(unlabel_id))
    print('num_validation_images =', len(validation_id))

    min_train_height, min_train_width = 3328, 3328
    image_rows, image_cols = min_train_height, min_train_width

    num_channels = 3
    num_mask_channels = 10

    tif_fname = os.path.join(data_path, 'three_band', '{}.tif')

    def _cache_split(h5_name, id_list, ids_key):
        # Write one split's images + masks to data_path/h5_name, storing the
        # image ids under ids_key.
        # BUG FIX: h5py.File() does not accept the dataset-creation keywords
        # 'compression'/'compression_opts' — passing them raises TypeError —
        # and 'blosc:lz4' is not a filter h5py supports natively, so the
        # keywords are dropped here.
        f = h5py.File(os.path.join(data_path, h5_name), 'w')
        imgs = f.create_dataset(
            'polygons', (len(id_list), num_channels, image_rows, image_cols),
            dtype=np.float16)
        imgs_mask = f.create_dataset(
            'train_mask',
            (len(id_list), num_mask_channels, image_rows, image_cols),
            dtype=np.uint8)

        split_ids = []
        for i, image_id in enumerate(tqdm(id_list)):
            # 2047 = max value of the 11-bit three-band imagery; scale to [0, 1].
            image = tiff.imread(tif_fname.format(image_id)) / 2047.0
            _, height, width = image.shape  # channels-first layout
            imgs[i] = image[:, :min_train_height, :min_train_width]
            imgs_mask[i] = extra_functions.generate_mask(
                image_id,
                height,
                width,
                num_mask_channels=num_mask_channels,
                train=train_wkt)[:, :min_train_height, :min_train_width]
            split_ids += [image_id]

        # fix from there: https://github.com/h5py/h5py/issues/441
        f[ids_key] = np.array(split_ids).astype('|S9')
        f.close()

    _cache_split('train_label.h5', train_id, 'train_ids')
    _cache_split('train_unlabel.h5', unlabel_id, 'train_ids')
    _cache_split('validation.h5', validation_id, 'validation_ids')
import os

import cv2
import numpy as np
import pandas as pd

import extra_functions

# NOTE(review): h5py, tifffile (used as `tiff`), and tqdm are called by the
# functions above but are not imported anywhere visible in this file —
# confirm they are imported elsewhere, or add them here.

# Script entry: render the predicted masks for a fixed list of test images
# to PNG files ('mask<image_id>.png') in the working directory.
data_path = os.getcwd()
num_channels = 22
num_mask_channels = 2
pred = pd.read_csv('temp_b_s.csv')
shapes = pd.read_csv(os.path.join(data_path, '3_shapes.csv'))
#test_id = pred['ImageId']
test_id = [
    '6050_4_4', '6060_0_1', '6060_1_4', '6100_0_2', '6100_2_4', '6110_2_3',
    '6120_1_4', '6120_3_3'
]

for image_id in test_id:
    print(image_id)
    # Look the shape row up once per image instead of twice, and take the
    # scalar with .iloc[0]: int(<single-element Series>) is deprecated and
    # raises TypeError in modern pandas.
    shape_row = shapes[shapes['image_id'] == image_id]
    mask = extra_functions.generate_mask(
        image_id,
        int(shape_row['height'].iloc[0]),
        int(shape_row['width'].iloc[0]),
        start=0,
        num_mask_channels=num_mask_channels,
        train=pred)
    # channels-first -> channels-last for image operations.
    mask = np.transpose(mask, (1, 2, 0))
    mask = extra_functions.stretch_n(mask)
    # Duplicate channel 0 so the 2-channel mask becomes a 3-channel image
    # that cv2.imwrite can save as PNG.
    img = np.concatenate([mask, np.expand_dims(mask[:, :, 0], 2)], axis=2)
    img = 255 * img
    img = img.astype(np.uint8)
    cv2.imwrite('mask' + image_id + '.png', img)