def cache_train_vehicle():
    """Cache the training images and masks for classes 9 and 10 in HDF5.

    Every image returned by
    ``extra_functions.get_class_image(classes=[9, 10])`` is read with
    ``extra_functions.read_image_22``, resized to the smallest height and
    width recorded for those images in ``shapes``, and written channels-first
    to ``train_vehicle.h5`` under ``data_path``.

    The file receives three datasets:

    * ``train`` -- float32 images, ``(num_train, 22, rows, cols)``
    * ``train_mask`` -- uint8 masks, ``(num_train, 2, rows, cols)``
    * ``train_ids`` -- the image ids as 9-byte strings

    NOTE(review): the statements after the HDF5 file is written do not belong
    to the caching logic; see the note placed directly above them.
    """
    image_set = extra_functions.get_class_image(classes=[9, 10])

    num_train = len(image_set)
    print('num_train_images =', num_train)

    train_shapes = shapes[shapes['image_id'].isin(image_set)]

    # Every image is shrunk to the smallest height/width in the set so that
    # all of them fit one fixed-size dataset. Cast to plain Python ints so the
    # values are usable as HDF5 shape entries and as a cv2 size no matter
    # which integer dtype pandas hands back.
    image_rows = int(train_shapes['height'].min())
    image_cols = int(train_shapes['width'].min())
    target_size = (image_cols, image_rows)  # cv2.resize expects (width, height)

    num_channels = 22
    num_mask_channels = 2

    ids = []

    # The context manager closes the file even if reading or resizing one of
    # the images raises part-way through the loop.
    with h5py.File(os.path.join(data_path, 'train_vehicle.h5'), 'w') as f:
        imgs = f.create_dataset(
            'train',
            (num_train, num_channels, image_rows, image_cols),
            dtype=np.float32,
            compression='gzip',
            compression_opts=9)
        imgs_mask = f.create_dataset(
            'train_mask',
            (num_train, num_mask_channels, image_rows, image_cols),
            dtype=np.uint8,
            compression='gzip',
            compression_opts=9)

        for i, image_id in enumerate(image_set):
            print(image_id)
            image = extra_functions.read_image_22(image_id)
            height, width, _ = image.shape

            # The image is (height, width, channels); store it channels-first.
            imgs[i] = np.transpose(
                cv2.resize(image, target_size,
                           interpolation=cv2.INTER_CUBIC), (2, 0, 1))

            mask = extra_functions.generate_mask(
                image_id,
                height,
                width,
                start=0,
                num_mask_channels=num_mask_channels,
                train=train_wkt)
            # Move the mask's first axis last for cv2.resize, then move it
            # back to the front before storing.
            mask = np.transpose(mask, (1, 2, 0))
            imgs_mask[i] = np.transpose(
                cv2.resize(mask, target_size,
                           interpolation=cv2.INTER_CUBIC), (2, 0, 1))

            ids.append(image_id)

        # fix from there: https://github.com/h5py/h5py/issues/441
        f['train_ids'] = np.array(ids).astype('|S9')

    # NOTE(review): everything below is unrelated to caching and is kept
    # verbatim only because its intended home cannot be determined from here.
    # It reads names that are never assigned in this function (H, W, x_max,
    # y_min, first_class, new_mask, threashold1, threashold2), and because
    # ``result`` is only ever augmented here Python treats it as a local, so
    # the first ``result += ...`` raises UnboundLocalError. It looks like the
    # tail of another function (possibly ``predict_poly``, whose call sites
    # pass a model, two thresholds, ``result`` and a class index) -- confirm
    # and move it there.
    x_scaler, y_scaler = extra_functions.get_scalers(H, W, x_max, y_min)

    mask_channel = first_class
    result += [(image_id, mask_channel + 1,
                mask2poly(new_mask[:, :, 0], threashold1, x_scaler, y_scaler))]
    mask_channel = first_class + 1
    result += [(image_id, mask_channel + 1,
                mask2poly(new_mask[:, :, 1], threashold2, x_scaler, y_scaler))]
    return result


# One row per model, in the order they are applied to every test image:
# weights file, the two numeric arguments that precede ``result`` in the
# predict_poly call, and the integer that follows it.
prediction_runs = (
    ('b_s.h5', 0.1, 0.1, 0),
    ('r_t.h5', 0.1, 0.1, 2),
    ('t_c.h5', 0.9, 0.9, 4),
    ('vehicle.h5', 0.9, 0.05, 8),
)

for image_id in tqdm(test_ids):
    print(image_id)
    image = extra_functions.read_image_22(image_id)

    # Kept as module-level names on purpose: other code in this file reads
    # H, W, x_max and y_min without receiving them as arguments.
    H, W = image.shape[:2]

    x_max, y_min = extra_functions._get_xmax_ymin(image_id)

    # NOTE(review): each model is re-read from disk for every image; hoisting
    # the loads would keep all four models in memory at once, so confirm that
    # is acceptable before changing it.
    for weights_file, first_value, second_value, class_index in prediction_runs:
        model = read_model(weights_file)
        result = predict_poly(model, first_value, second_value, result,
                              class_index)