def main():
    """
    Main function.
    Does the following step by step:
    * Loads images (from which to extract cat faces) from SOURCE_DIR
    * Initializes model (as trained via train_cat_face_locator.py)
    * Prepares images for the model (i.e. shrinks them, squares them)
    * Lets model locate cat faces in the images
    * Projects face coordinates onto original images
    * Squares the face rectangles (as we want to get square images at the end)
    * Extracts faces from images with some pixels of padding around them
    * Augments each face image several times
    * Removes the padding from each face image
    * Resizes each face image to OUT_SCALE (height, width)
    * Saves each face image (unaugmented + augmented images)
    """
    
    # --------------
    # load images
    # --------------
    images, paths = get_images([SOURCE_DIR])
    images = images
    paths = paths
    # we will use the image filenames when saving the images at the end
    images_filenames = [path[path.rfind("/")+1:] for path in paths]
    
    # --------------
    # create model
    # --------------
    #model = create_model_tiny(MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH, Adam())
    model = create_model(MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH, Adam())
    load_weights_seq(model, WEIGHTS_FILEPATH)

    # --------------
    # make all images square with required sizes
    # and roll color channel to dimension index 1 (required by theano)
    # --------------
    paddings = []
    images_padded = np.zeros((len(images), MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH, 3))
    for idx, image in enumerate(images):
        if idx == 0:
            print(idx, image.shape, paths[idx])
        image_padded, (pad_top, pad_right, pad_bottom, pad_left) = square_image(image)
        images_padded[idx] = misc.imresize(image_padded, (MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH))
        paddings.append((pad_top, pad_right, pad_bottom, pad_left))
    
    #misc.imshow(images_padded[0])
    
    # roll color channel
    images_padded = np.rollaxis(images_padded, 3, 1)

    # project to 0-1
    images_padded /= 255
    #print(images_padded[0])

    # --------------
    # predict positions of faces
    # --------------
    coordinates_predictions = predict_on_images(model, images_padded)
    
    print("[Predicted positions]", coordinates_predictions[0])
    """
    for idx, (tl_y, tl_x, br_y, br_x) in enumerate(coordinates_predictions):
        marked_image = visualize_rectangle(images_padded[idx]*255, tl_x, br_x, tl_y, br_y, \
                                           (255,), channel_is_first_axis=True)
        misc.imshow(marked_image)
    """
    
    # --------------
    # project coordinates from small padded images to full-sized original images (without padding)
    # --------------
    coordinates_orig = []
    for idx, (tl_y, tl_x, br_y, br_x) in enumerate(coordinates_predictions):
        pad_top, pad_right, pad_bottom, pad_left = paddings[idx]
        height_full = images[idx].shape[0] + pad_top + pad_bottom
        width_full = images[idx].shape[1] + pad_right + pad_left
        height_orig = images[idx].shape[0]
        width_orig = images[idx].shape[1]
        
        tl_y_perc = tl_y / MODEL_IMAGE_HEIGHT
        tl_x_perc = tl_x / MODEL_IMAGE_WIDTH
        br_y_perc = br_y / MODEL_IMAGE_HEIGHT
        br_x_perc = br_x / MODEL_IMAGE_WIDTH
        
        # coordinates on full sized squared image version
        tl_y_full = int(tl_y_perc * height_full)
        tl_x_full = int(tl_x_perc * width_full)
        br_y_full = int(br_y_perc * height_full)
        br_x_full = int(br_x_perc * width_full)
        
        # remove paddings to get coordinates on original images
        tl_y_orig = tl_y_full - pad_top
        tl_x_orig = tl_x_full - pad_left
        br_y_orig = br_y_full - pad_top
        br_x_orig = br_x_full - pad_left
        
        # fix broken coordinates
        # anything below 0
        # anything above image height (y) or width (x)
        # anything where top left >= bottom right
        tl_y_orig = min(max(tl_y_orig, 0), height_orig)
        tl_x_orig = min(max(tl_x_orig, 0), width_orig)
        br_y_orig = min(max(br_y_orig, 0), height_orig)
        br_x_orig = min(max(br_x_orig, 0), width_orig)
        
        if tl_y_orig >= br_y_orig:
            tl_y_orig = br_y_orig - 1
        if tl_x_orig >= br_x_orig:
            tl_x_orig = br_x_orig - 1
        
        coordinates_orig.append((tl_y_orig, tl_x_orig, br_y_orig, br_x_orig))
    
    """
    # project face coordinates to original image sizes
    coordinates_orig = []
    for idx, (tl_y, tl_x, br_y, br_x) in enumerate(coordinates_nopad):
        height_orig = images[idx].shape[0]
        width_orig = images[idx].shape[1]
        
        tl_y_perc = tl_y / MODEL_IMAGE_HEIGHT
        tl_x_perc = tl_x / MODEL_IMAGE_WIDTH
        br_y_perc = br_y / MODEL_IMAGE_HEIGHT
        br_x_perc = br_x / MODEL_IMAGE_WIDTH
        
        tl_y_orig = int(tl_y_perc * height_orig)
        tl_x_orig = int(tl_x_perc * width_orig)
        br_y_orig = int(br_y_perc * height_orig)
        br_x_orig = int(br_x_perc * width_orig)
        
        coordinates_orig.append((tl_y_orig, tl_x_orig, br_y_orig, br_x_orig))
    
    print("[Coordinates on original image]", coordinates_orig[0])
    
    # remove padding from predicted face coordinates
    # tl = top left, br = bottom right
    coordinates_nopad = []
    for idx, (tl_y, tl_x, br_y, br_x) in enumerate(coordinates_predictions):
        pad_top, pad_right, pad_bottom, pad_left = paddings[idx]
        tl_y_nopad = tl_y - pad_top
        tl_x_nopad = tl_x - pad_left
        br_y_nopad = br_y - pad_top
        br_x_nopad = br_x - pad_left
        tpl = (tl_y_nopad, tl_x_nopad, br_y_nopad, br_x_nopad)
        tpl_fixed = [max(coord, 0) for coord in tpl]
        if tpl_fixed[0] >= tpl_fixed[2]:
            tpl_fixed[2] += 1
        elif tpl_fixed[1] >= tpl_fixed[3]:
            tpl_fixed[3] += 1
        tpl_fixed = tuple(tpl_fixed)
        
        if tpl != tpl_fixed:
            print("[WARNING] Predicted coordinate below 0 after padding-removal. Bad prediction." \
                  " (In image %d, coordinates nopad: %s, coordinates pred: %s)" \
                  % (idx, tpl, coordinates_predictions[idx]))
        
        coordinates_nopad.append(tpl_fixed)
    """
    
    print("[Removed padding from predicted coordinates]", coordinates_orig[0])
    
    # --------------
    # square faces
    # --------------
    coordinates_orig_square = []
    for idx, (tl_y, tl_x, br_y, br_x) in enumerate(coordinates_orig):
        height = br_y - tl_y
        width = br_x - tl_x
        i = 0
        # we remove here instead of adding rows/cols, because that way we won't exceed
        # the maximum image sizes
        while height > width:
            if i % 2 == 0:
                tl_y += 1
            else:
                br_y -= 1
            height -= 1
            i += 1
        while width > height:
            if i % 2 == 0:
                tl_x += 1
            else:
                br_x -= 1
            width -= 1
            i += 1
        print("New height:", (br_y-tl_y), "New width:", (br_x-tl_x))
        coordinates_orig_square.append((tl_y, tl_x, br_y, br_x))
    
    print("[Squared face coordinates]", coordinates_orig_square[0])
    
    # --------------
    # pad faces
    # --------------
    # extract "padded" faces, where the padding is part of the original image
    # (N pixels around the face)
    # After doing that, we can augment the "padded" faces, then remove the padding and have less
    # augmentation damage (i.e. areas that would otherwise be black will now be filled with parts
    # of the original image)
    faces_padded = []
    for idx, (tl_y, tl_x, br_y, br_x) in enumerate(coordinates_orig_square):
        image = images[idx]
        # we pad the whole image by N pixels so that we can safely extract an area of N pixels
        # around the face
        image_padded = np.pad(image, ((AUGMENTATION_PADDING, AUGMENTATION_PADDING), \
                                      (AUGMENTATION_PADDING, AUGMENTATION_PADDING), \
                                      (0, 0)), mode=str("median"))
        face_padded = image_padded[tl_y:br_y+2*AUGMENTATION_PADDING, \
                                   tl_x:br_x+2*AUGMENTATION_PADDING, \
                                   ...]
        faces_padded.append(face_padded)
    
    print("[Extracted face with padding]")
    misc.imshow(faces_padded[0])
    
    # --------------
    # augment and save images
    # --------------
    for idx, face_padded in enumerate(faces_padded):
        # these should be the same values for all images
        image_height = face_padded.shape[0]
        image_width = face_padded.shape[1]
        print("[specs of padded face] height", image_height, "width", image_width)
        
        # augment the padded images
        ia = ImageAugmenter(image_width, image_height,
                            channel_is_first_axis=False,
                            hflip=True, vflip=False,
                            scale_to_percent=(0.90, 1.10), scale_axis_equally=True,
                            rotation_deg=45, shear_deg=0,
                            translation_x_px=8, translation_y_px=8)
        images_aug = np.zeros((AUGMENTATION_ITERATIONS, image_height, image_width, 3),
                              dtype=np.uint8)
        for i in range(AUGMENTATION_ITERATIONS):
            images_aug[i, ...] = face_padded
        print("images_aug.shape", images_aug.shape)
        images_aug = ia.augment_batch(images_aug)
        
        # randomly change brightness of whole images
        for idx_aug, image_aug in enumerate(images_aug):
            by_percent = random.uniform(0.90, 1.10)
            images_aug[idx_aug] = np.clip(image_aug * by_percent, 0.0, 1.0)
        print("images_aug.shape [0]:", images_aug.shape)
        
        # add gaussian noise
        # skipped, because that could be added easily in torch as a layer
        #images_aug = images_aug + np.random.normal(0.0, 0.05, images_aug.shape)
        
        # remove the padding
        images_aug = images_aug[:,
                                AUGMENTATION_PADDING:-AUGMENTATION_PADDING,
                                AUGMENTATION_PADDING:-AUGMENTATION_PADDING,
                                ...]
        print("images_aug.shape [1]:", images_aug.shape)
        
        # add the unaugmented image
        images_aug = np.vstack((images_aug, \
                                [face_padded[AUGMENTATION_PADDING:-AUGMENTATION_PADDING, \
                                             AUGMENTATION_PADDING:-AUGMENTATION_PADDING, \
                                             ...]]))
        
        print("images_aug.shape [2]:", images_aug.shape)
        
        # save images
        for i, image_aug in enumerate(images_aug):
            if image_aug.shape[0] * image_aug.shape[1] < MINIMUM_AREA:
                print("Ignoring image %d / %d because it is too small (area of %d vs min. %d)" \
                       % (idx, i, image_aug.shape[0] * image_aug.shape[1], MINIMUM_AREA))
            else:
                image_resized = misc.imresize(image_aug, (OUT_SCALE, OUT_SCALE))
                filename_aug = "%s_%d.jpg" % (images_filenames[idx].replace(".jpg", ""), i)
                #misc.imshow(image_resized)
                misc.imsave(os.path.join(TARGET_DIR, filename_aug), image_resized)
        print(image.shape)
        print(Y_train[i])
        tl_y, tl_x, br_y, br_x = center_scale_to_pixels(image,
                                                        Y_train[i][0], Y_train[i][1],
                                                        Y_train[i][2], Y_train[i][3])
        marked_image = visualize_rectangle(image*255,
                                           tl_x, br_x, tl_y, br_y, (255,),
                                           channel_is_first_axis=True)
        misc.imshow(np.squeeze(marked_image))
    """
    
    model = create_model(MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH, Adam())
    
    if args.load:
        print("Loading weights...")
        load_weights_seq(model, args.load)

    model.fit(X_train, Y_train, batch_size=128, nb_epoch=EPOCHS, validation_split=0.0,
              validation_data=(X_val, Y_val), show_accuracy=False)

    print("Saving weights...")
    model.save_weights(SAVE_WEIGHTS_FILEPATH, overwrite=SAVE_AUTO_OVERWRITE)

    if SAVE_EXAMPLES:
        print("Saving examples (predictions)...")
        y_preds = predict_on_images(model, X_val)
        for img_idx, (tl_y, tl_x, br_y, br_x) in enumerate(y_preds):
            image = np.rollaxis(X_val[img_idx, ...], 0, 3)
            if GRAYSCALE:
                image_marked = visualize_rectangle(image*255,
                                                   tl_x, br_x, tl_y, br_y, (255,),