コード例 #1
0
def test_on_folder(model, folder_path, save_path='test_outputs/'):
    """Run the detector on every ``*.jpg`` in ``folder_path`` and save renders.

    Args:
        model: detector exposing ``detect(images, verbose=...)``; element 0 of
            its result is read as a dict with 'rois', 'masks', 'class_ids' and
            'scores' keys.
        folder_path: directory searched (non-recursively) for ``*.jpg`` files.
        save_path: output directory; each image's prediction is written to
            ``<save_path>/pred_<index>.jpg`` by ``visualize.save_instances``.

    NOTE(review): ``os.mkdir`` only creates the last path component, so a
    ``save_path`` whose parent does not exist raises; ``glob.glob`` returns
    files in arbitrary order, so ``<index>`` is not guaranteed to map to the
    same file across runs.
    """

    # Create the output directory on first use.
    if not os.path.exists(save_path):
        os.mkdir(save_path)

    files = glob.glob(folder_path + '/*.jpg')

    for i in range(len(files)):
        # The position in ``files`` doubles as the output-file id.
        image_id = i
        image = skimage.io.imread(files[image_id])
        # Add a leading batch axis: detect() receives a batch of one image.
        results = model.detect(image[np.newaxis], verbose=0)
        results = results[0]
        # Every detection is labelled 'slum'; one extra entry is allocated,
        # presumably for the background class at index 0.
        class_names = ['slum'] * (len(results['class_ids']) + 1)
        mask = results['masks']  # unused in the visible code

        file_to_save = save_path + '/pred_' + str(image_id) + '.jpg'

        visualize.save_instances(image,
                                 results['rois'],
                                 results['masks'],
                                 results['class_ids'],
                                 class_names,
                                 file_to_save,
                                 results['scores'],
                                 ax=None,
                                 show_bbox=False,
                                 show_mask=True,
                                 title="Predictions " + str(image_id))

        # Uncomment the block below to visualize using matplotlib.
        """
コード例 #2
0
 def do_detech_buildings(self):
     """Run building detection on every file in ``self.images_path``.

     Each file is read with ``skimage.io.imread``. Single-channel (2-D)
     images are replicated to three channels before ``model.detect`` is
     called. The rendered detections for each image are written by
     ``visualize.save_instances`` with ``save_name=image_name`` and
     ``save_path=self.ouput_path``.

     NOTE(review): ``model`` and ``class_names`` are module-level globals
     here, not ``self.model`` / ``self.class_names``. ``i``, ``res``,
     ``tif_tans``, ``w`` and ``h`` are unused in the visible code. This
     snippet may be truncated, so confirm before removing them.
     """
     images_path, ouput_path = self.images_path, self.ouput_path
     i = 0  # per-file counter; not read in the visible code
     res = []
     tif_tans = TIF_TRANS()
     for image_name in os.listdir(images_path):
         i += 1
         img_path = os.path.join(images_path, image_name)
         # image = Image.open(imgpath).convert('RGB')
         image = skimage.io.imread(img_path)
         # Grayscale input: stack the single channel three times so the
         # detector receives a 3-channel image.
         if len(image.shape) == 2:
             image = image[:, :, np.newaxis]
             image = np.concatenate((image, image, image), axis=2)
         # NOTE(review): numpy image shape is (rows, cols, channels), so ``w``
         # actually holds the height and ``h`` the width.
         w, h, _ = image.shape  # w = 400,h = 400
         results = model.detect([image], verbose=1)
         # Visualize results
         r = results[0]
         visualize.save_instances(image,
                                  r['rois'],
                                  r['masks'],
                                  r['class_ids'],
                                  class_names,
                                  r['scores'],
                                  save_name=image_name,
                                  save_path=ouput_path)
コード例 #3
0
    def do_detech_roads(self):
        """Detect roads on every patch in ``self.images_path`` and cluster offsets.

        Patch file names must match ``<a>_<b>_<row>_<col>.jpg`` (four
        integers). The third and fourth integers are used as the patch row and
        column and passed to ``TIF_TRANS.imagexy2geo``. This yields geographic
        coordinates before and after applying the offset found by
        ``self.center_point``. Files whose names do not match are logged and
        skipped. Previously a failed match raised ``AttributeError`` on ``None``
        after detection had already run.

        For every processed patch:
            - the detection render is saved to ``self.ouput_path``;
            - a record ``[offset_x, offset_y, dis, x_before, y_before,
              x_after, y_after, img_path]`` is collected and logged.

        After the loop:
            - records with non-zero ``dis`` are written to
              ``self.all_patch_res_path``;
            - their offsets are clustered with ``self.culster``, and the
              result is saved to ``self.culster_csv``;
            - the records whose ``jllable`` cluster label is 0 are written
              to ``self.filter_patch_res_path``.

        NOTE(review): ``model`` and ``class_names`` are module-level globals.
        """
        images_path, ouput_path = self.images_path, self.ouput_path
        # The '.' is escaped so only a literal ".jpg" suffix matches.
        name_pattern = re.compile(r'(\d+)_(\d+)_(\d+)_(\d+)\.jpg')
        res = []
        tif_tans = TIF_TRANS()
        for image_name in os.listdir(images_path):
            m = name_pattern.match(image_name)
            if m is None:
                # Skip before reading the image or running detection on it.
                logging.warning('skipping %s: name does not match '
                                '<a>_<b>_<row>_<col>.jpg', image_name)
                continue
            row_point, col_point = int(m.group(3)), int(m.group(4))

            img_path = os.path.join(images_path, image_name)
            image = skimage.io.imread(img_path)
            if len(image.shape) == 2:
                # Grayscale: replicate the single channel to three.
                image = image[:, :, np.newaxis]
                image = np.concatenate((image, image, image), axis=2)
            # shape is (rows, cols, channels): ``w`` holds the row count and
            # ``h`` the column count; center_point receives them in this order.
            w, h, _ = image.shape
            r = model.detect([image], verbose=1)[0]
            visualize.save_instances(image,
                                     r['rois'],
                                     r['masks'],
                                     r['class_ids'],
                                     class_names,
                                     r['scores'],
                                     save_name=image_name,
                                     save_path=ouput_path)
            dis, offset_xy = self.center_point(r['rois'], w, h)
            x_before, y_before = tif_tans.imagexy2geo(col_point, row_point)
            x_after, y_after = tif_tans.imagexy2geo(col_point + offset_xy[1],
                                                    row_point + offset_xy[0])
            temp = [
                offset_xy[0], offset_xy[1], dis, x_before, y_before, x_after,
                y_after, img_path
            ]
            res.append(temp)
            logging.info('_'.join(str(val) for val in temp))

        res_data_frame = pd.DataFrame(res,
                                      columns=[
                                          'offset_x', 'offset_y', 'dis',
                                          'x_before', 'y_before', 'x_after',
                                          'y_after', 'img_path'
                                      ])
        # Keep only patches where a non-zero displacement was found.
        all_patch_res = res_data_frame[res_data_frame['dis'] != 0]
        all_patch_res.to_csv(self.all_patch_res_path)
        cluster_res = self.culster(all_patch_res[['offset_x', 'offset_y']])
        cluster_res.to_csv(self.culster_csv)
        filter_patch_res = all_patch_res[cluster_res['jllable'] == 0]
        filter_patch_res.to_csv(self.filter_patch_res_path)
コード例 #4
0
 def do_detech_roads(self, image, image_name, res):
     """Detect roads in one patch and append its displacement record to ``res``.

     Args:
         image: image array whose ``shape`` is unpacked as ``w, h, _``.
         image_name: patch file name matching ``<a>_<row>_<col>.jpg`` (three
             integers). The second and third integers are the row/column
             passed to ``self.imagexy2geo``.
         res: list that receives one record ``[offset_x, offset_y, dis,
             x_before, y_before, x_after, y_after, is_real, img_path]``, where
             ``img_path`` is ``image_name`` joined onto ``self.save_path``.

     Raises:
         ValueError: if ``image_name`` does not match the pattern above. The
             check runs before detection. Previously the failure surfaced as
             an AttributeError on ``None`` after the detector had run.
     """
     # '.' escaped so only a literal ".jpg" suffix matches.
     m = re.match(r'(\d+)_(\d+)_(\d+)\.jpg', image_name)
     if m is None:
         raise ValueError(
             'unexpected patch name {!r}; expected <a>_<row>_<col>.jpg'.format(image_name))
     row_point, col_point = int(m.group(2)), int(m.group(3))

     ouput_path = self.save_path
     img_path = os.path.join(ouput_path, image_name)
     # shape is (rows, cols, channels): ``w`` holds the row count.
     w, h, _ = image.shape
     r = self.model.detect([image], verbose=1)[0]
     visualize.save_instances(image, r['rois'], r['masks'], r['class_ids'],
                              self.class_names, r['scores'], save_name=image_name, save_path=ouput_path)
     dis, offset_xy, is_real = self.center_point(r['rois'], w, h)
     x_before, y_before = self.imagexy2geo(col_point, row_point)
     x_after, y_after = self.imagexy2geo(col_point + offset_xy[1], row_point + offset_xy[0])
     temp = [offset_xy[0], offset_xy[1], dis, x_before, y_before, x_after, y_after, is_real, img_path]
     res.append(temp)
     logging.info('_'.join(str(val) for val in temp))
コード例 #5
0
def test_on_file(model, img, save_path=None):
    """Run the slum detector on one image file and save the rendered result.

    Args:
        model: detector exposing ``detect(images, verbose=...)``; element 0 of
            its result is read as a dict with 'rois', 'masks', 'class_ids' and
            'scores' keys.
        img: path of the image to read with ``skimage.io.imread``.
        save_path: file the render is written to by
            ``visualize.save_instances``. ``None`` keeps the path that was
            previously hard-coded here.
    """
    if save_path is None:
        save_path = '/media/shaheem/Data/Projects/PycharmProjects/djangoProject/templates/assets/results/pred.jpg'

    image = skimage.io.imread(img)
    # Drop an alpha channel so the detector receives 3-channel input; the
    # ndim check avoids touching 2-D grayscale images whose width is 4.
    if image.ndim == 3 and image.shape[-1] == 4:
        image = image[..., :3]

    # Leading batch axis: detect() receives a batch of one image.
    results = model.detect(image[np.newaxis], verbose=0)[0]
    # Every detection is labelled 'slum'; one extra entry is allocated,
    # presumably for the background class at index 0.
    class_names = ['slum'] * (len(results['class_ids']) + 1)

    visualize.save_instances(
        image,
        results['rois'],
        results['masks'],
        results['class_ids'],
        class_names,
        save_path,
        results['scores'],
        ax=None,
        show_bbox=False,
        show_mask=True,
        title="Predictions")
コード例 #6
0
def main(args):
    """Run a frozen detection graph over a directory of images.

    Args:
        args: namespace with ``model_path`` (frozen ``.pb`` graph file),
            ``test_images_path`` (input directory) and ``test_results_path``
            (output directory).

    Each input file ``<stem>.<ext>`` is rendered with
    ``visualize.save_instances`` and written to
    ``<test_results_path>/<stem>.png``.

    NOTE(review): ``detect`` and ``inference_config`` are module-level names
    defined elsewhere.
    """
    pb_filepath = args.model_path
    InIMAGE_DIR = args.test_images_path
    OutIMAGE_DIR = args.test_results_path

    # Load the frozen graph into a fresh default graph.
    tf.reset_default_graph()
    with tf.gfile.GFile(pb_filepath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
    print('Graph loaded.')

    image_names = os.listdir(InIMAGE_DIR)
    class_names = ["BG", "building"]

    sess = tf.InteractiveSession()
    try:
        for image_name in image_names:
            print("Processing:", image_name)
            start = time.time()

            # splitext handles extensions of any length (the former [:-4]
            # assumed exactly three characters).
            stem = os.path.splitext(image_name)[0]
            output_path = os.path.join(OutIMAGE_DIR, stem + ".png")

            image = imageio.imread(os.path.join(InIMAGE_DIR, image_name))
            image = image[:, :, :3]  # remove alpha channel

            r = detect(sess, [image], inference_config)[0]

            end = time.time()
            print("Processing time = {} seconds \n".format(end - start))

            fig = visualize.save_instances(image, r['rois'], r['masks'],
                                           r['class_ids'], class_names,
                                           r['scores'])
            plt.imsave(output_path, fig)
    finally:
        # Release the TF session even if an image fails.
        sess.close()

    print('Done')
def save_detect_list(model, image_dir=None):
    """Run ``model`` on every file in ``image_dir`` and save the renders.

    An empty ``draw`` directory is (re)created in the current working
    directory first; presumably it is where ``visualize.save_instances``
    writes, but confirm against the visualize module. Each render is saved
    under the image's file name.

    Args:
        model: detector exposing ``detect(images, verbose=...)``.
        image_dir: directory holding the images. ``None`` means the current
            working directory (previously ``None`` crashed on ``None + str``).

    NOTE(review): ``class_names`` is a module-level global.
    """
    image_dir = '.' if image_dir is None else image_dir
    image_names = os.listdir(image_dir)
    if image_names:
        print(image_names[0])

    # Start from an empty output directory; a missing one is not an error.
    shutil.rmtree('draw', ignore_errors=True)
    os.makedirs('draw', exist_ok=True)

    for image_name in image_names:
        image = skimage.io.imread(os.path.join(image_dir, image_name))
        # Detect objects
        predictions = model.detect([image], verbose=1)[0]
        visualize.save_instances(image, predictions['rois'], predictions['masks'],
                                 predictions['class_ids'], class_names, image_name,
                                 predictions['scores'])
コード例 #8
0
def test_on_img(model, file, id_):
    """Run the slum detector on one image and save the render for the web app.

    Args:
        model: detector exposing ``detect(images, verbose=...)``.
        file: path of the image to read with ``skimage.io.imread``.
        id_: identifier embedded in the output file name and plot title.

    Returns:
        ``(img, file_to_save)``, where ``img`` is the second element returned
        by ``visualize.save_instances`` and ``file_to_save`` is
        ``static/images/pred_<id_>.jpg``.
    """
    image = skimage.io.imread(file)
    # Drop an alpha channel. The ndim check matters: a 2-D grayscale image
    # whose width is 4 also has shape[-1] == 4 and was previously cropped.
    if image.ndim == 3 and image.shape[-1] == 4:
        image = image[..., :3]

    # Leading batch axis: detect() receives a batch of one image.
    results = model.detect(image[np.newaxis], verbose=0)[0]
    # Every detection is labelled 'slum'; one extra entry is allocated,
    # presumably for the background class at index 0.
    class_names = ['slum'] * (len(results['class_ids']) + 1)

    file_to_save = 'static/images/pred_' + str(id_) + '.jpg'

    _, img = visualize.save_instances(image,
                                      results['rois'],
                                      results['masks'],
                                      results['class_ids'],
                                      class_names,
                                      file_to_save,
                                      results['scores'],
                                      ax=get_ax(0),
                                      show_bbox=False,
                                      show_mask=True,
                                      title="Predictions " + str(id_))
    return img, file_to_save
コード例 #9
0
    # Evaluate one fixed dataset image; switch to the commented line to pick
    # an image at random instead.
    # image_id = random.choice(dataset.image_ids)
    image_id = 1473
    # load_image_gt returns four values, unpacked as the image, its meta
    # vector, ground-truth class ids and ground-truth boxes.
    image, image_meta, gt_class_id, gt_bbox = \
        modellib.load_image_gt(dataset, config, image_id)
    info = dataset.image_info[image_id]
    print("image ID: {}.{} ({}) {}".format(info["source"], info["id"],
                                           image_id,
                                           dataset.image_reference(image_id)))
    # Run object detection
    results = model.detect([image], verbose=1)

    # Display results
    ax = get_ax(1)
    r = results[0]
    save_path = "{}_new_weights".format(image_id)
    # NOTE(review): this replaces the image that detect() ran on with a fresh
    # dataset.load_image() copy. If load_image_gt resizes/pads, r['rois'] and
    # gt_bbox will not line up with this reloaded image; confirm.
    image = dataset.load_image(image_id)
    # NOTE(review): this save_instances variant appears to take predicted
    # boxes, ground-truth boxes, predicted and ground-truth class ids (no
    # masks); confirm against the project's visualize module.
    visualize.save_instances(image,
                             r['rois'],
                             gt_bbox,
                             r['class_ids'],
                             gt_class_id,
                             dataset.class_names,
                             r['scores'],
                             ax=ax,
                             title="Predictions_{}".format(info["id"]),
                             path=save_path,
                             show_mask=False)
    print("gt_class_id", gt_class_id)
    print("gt_bbox", gt_bbox)
コード例 #10
0
    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush'
]

## Run Object Detection
# Render detections for every *.jpg in InIMAGE_DIR into OutIMAGE_DIR, keeping
# file names. Images whose output already exists are skipped, so an
# interrupted run can be resumed.
filenames = glob.glob(os.path.join(InIMAGE_DIR, "*.jpg"))
for counter, fl in enumerate(filenames):
    print("counter = {:5d}".format(counter))
    # basename honours the platform separator. With split('/') on Windows,
    # image_name was the full input path, so output_path pointed at the input
    # file itself and every image was skipped.
    image_name = os.path.basename(fl)
    output_path = os.path.join(OutIMAGE_DIR, image_name)
    if os.path.isfile(output_path):
        continue

    image = skimage.io.imread(fl)
    image = image[:, :, :3]  # remove alpha channel

    r = model.detect([image], verbose=1)[0]
    # The return value is handed to plt.imsave, so presumably an image array.
    fig = visualize.save_instances(image, r['rois'], r['masks'],
                                   r['class_ids'], class_names, r['scores'])
    plt.imsave(output_path, fig)
コード例 #11
0
ファイル: Predict.py プロジェクト: roeiherz/RelationMaskRCNN
    # Dataset images to visualize, selected by uuid (alternatives left
    # commented out for quick switching).
    uuids = ["c927d51b-92852659"]
    # uuids = ["b1d0a191-06deb55d"]
    ids = get_ids_from_uuids(dataset, uuids)
    # ids = [random.choice(dataset.image_ids)]
    # ids = [1536]

    for image_id in ids:
        # Image plus ground-truth class ids and boxes; the second value
        # returned by load_image_gt is ignored.
        image, _, gt_class_id, gt_bbox = modellib.load_image_gt(dataset, config, image_id)
        info = dataset.image_info[image_id]
        print("image ID: {}.{} ({}) {}".format(info["source"], info["id"], image_id,
                                               dataset.image_reference(image_id)))
        # Run object detection
        results = model.detect([image], verbose=1, gpi_type=config.GPI_TYPE)

        # Display results
        ax = get_ax(1)
        r = results[0]
        # image = dataset.load_image(image_id)
        # Output file: <save_path>/<model path's parent dir name>_<image id>.jpg
        # (split('/')[-2] assumes '/'-separated model paths).
        visualize.save_instances(image, r['rois'], gt_bbox, r['class_ids'], gt_class_id, dataset.class_names,
                                 r['scores'],
                                 ax=ax, title="Predictions_{}".format(info["id"]),
                                 path="{}/{}_{}.jpg".format(args.save_path, args.model.split('/')[-2], info["id"]),
                                 show_mask=False)
        # NOTE(review): indexing raises KeyError if detect() did not return a
        # 'relation_attention' entry; confirm it is always present.
        if r['relation_attention'] is not None:
            visualize.draw_attention(r['rois'], r['relation_attention'], image, info["id"])

        print("gt_class_id", gt_class_id)
        print("gt_bbox", gt_bbox)

    print("End Graph Detector Prediction")
コード例 #12
0
class_names = ['BG', 'apple']

# ## Run Object Detection
# Run detection on every file under IMAGE_DIR (recursively) and save each
# render under <ROOT_DIR>/images/AusTestOutput with the input's file name.
# NOTE(review): files with the same name in different subdirectories
# overwrite each other's output.
output_dir = os.path.join(ROOT_DIR, "images", "AusTestOutput")
# Create the destination once up front; it was previously assumed to exist.
os.makedirs(output_dir, exist_ok=True)

for root, dirs, files in os.walk(IMAGE_DIR, topdown=False):
    for name in files:
        inputFilePath = os.path.join(root, name)
        outputFilePath = os.path.join(output_dir, name)
        image = skimage.io.imread(inputFilePath)
        # Run detection and save the rendered result.
        r = model.detect([image], verbose=1)[0]
        visualize.save_instances(image,
                                 r['rois'],
                                 r['masks'],
                                 r['class_ids'],
                                 class_names,
                                 r['scores'],
                                 savepath=outputFilePath)

# In[ ]:
コード例 #13
0
ファイル: BDD100K.py プロジェクト: roeiherz/RelationMaskRCNN
def _get_detections_annotations(dataset,
                                model,
                                save_path=None,
                                config=None,
                                batch_size=1):
    """
    Get the detections from the model using the generator.
    The result is a list of lists such that the size is:
        all_detections[num_images][num_classes] = detections[num_detections, 4 + num_classes]

    # Arguments
        generator       : The generator used to run images through the model.
        model           : The model to run on the images.
        save_path       : The path to save the images with visualized detections to.
    # Returns
        A list of lists containing the detections for each image in the data set.
    """
    all_detections = [[None for i in range(dataset.num_labels())]
                      for j in range(dataset.size())]
    all_annotations = [[None for i in range(dataset.num_labels())]
                       for j in range(dataset.size())]
    image_ids = dataset.image_ids

    size = len(image_ids)
    t_prediction = 0
    t_start = time.time()

    # Decide batches per epoch
    if size % batch_size == 0:
        num_of_batches_per_epoch = size / batch_size
    else:
        num_of_batches_per_epoch = size / batch_size + 1

    for batch in range(num_of_batches_per_epoch):
        # Define number of samples per batch
        if batch_size * (batch + 1) >= size:
            nof_samples_per_batch = size - batch_size * batch
        else:
            nof_samples_per_batch = batch_size

        image_lst = []
        gt_class_id_lst = []
        gt_bbox_lst = []
        for current_index in range(nof_samples_per_batch):
            # Get index from files
            ind = batch * batch_size + current_index
            image_id = image_ids[ind]
            # Get data
            image, _, gt_class_id, gt_bbox = modellib.load_image_gt(
                dataset, config, image_id)
            # Append
            image_lst.append(image)
            gt_class_id_lst.append(gt_class_id)
            gt_bbox_lst.append(gt_bbox)

        # Run detection
        t = time.time()
        r_lst = model.detect(image_lst, verbose=0, gpi_type=config.GPI_TYPE)
        t_prediction += (time.time() - t)

        for current_index in range(nof_samples_per_batch):
            # Get index from files
            i = batch * batch_size + current_index
            # Get data
            image_id = image_ids[i]
            gt_class_id = gt_class_id_lst[current_index]
            gt_bbox = gt_bbox_lst[current_index]
            r = r_lst[current_index]

            image_boxes = r["rois"]
            image_labels = r["class_ids"]
            image_scores = r["scores"]
            id = dataset.image_info[image_id]['id']

            if save_path is not None:
                image = dataset.load_image(image_id)
                visualize.save_instances(image,
                                         r['rois'],
                                         gt_bbox,
                                         r['class_ids'],
                                         gt_class_id,
                                         dataset.class_names,
                                         r['scores'],
                                         ax=None,
                                         show_mask=False,
                                         path=os.path.join(
                                             save_path, "{}.jpg".format(id)),
                                         title="Predictions_{}".format(id))

            # select detections - [[num_boxes, y1, x1, y2, x2, score, class_id]]
            image_detections = np.concatenate([
                image_boxes,
                np.expand_dims(image_scores, axis=1),
                np.expand_dims(image_labels, axis=1)
            ],
                                              axis=1)

            # load the annotations - [[num_boxes, y1, x1, y2, x2, class_id]]
            annotations = np.concatenate(
                [gt_bbox, np.expand_dims(gt_class_id, axis=1)], axis=1)

            # copy detections to all_detections
            for label in range(dataset.num_labels()):
                all_detections[i][label] = image_detections[
                    image_detections[:, -1] == label, :-1]

            # copy detections to all_annotations
            for label in range(dataset.num_labels()):
                all_annotations[i][label] = annotations[annotations[:, 4] ==
                                                        label, :4].copy()

            # print('{}/{}'.format(i + 1, dataset.size()))

        print('Batch {}/{}'.format(batch + 1, num_of_batches_per_epoch))

    print("Prediction time: {}. Average {}/image".format(
        t_prediction, t_prediction / len(image_ids)))
    print("Total time: ", time.time() - t_start)
    return all_detections, all_annotations
コード例 #14
0
def main():
    """Run the satellite building detector over a folder of images.

    Command line:
        --size {small,large}
            small: run detection on every ``*.jpg`` in the input directory
                and save a render with the same name in the output
                directory (skipping ones that already exist).
            large: slice the first file of the input directory into 64
                tiles, run detection on each tile, and join the rendered
                tiles back into one image in the output directory.

    Input/output directories and the weights file are hard-coded below.
    Side effect: the Keras model architecture is written to ``model.json``
    in the current working directory.
    """
    import argparse

    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description='Train Mask R-CNN on satellite images.')
    # choices= rejects typos up front; an unknown size used to load the
    # model and then silently do nothing.
    parser.add_argument("--size",
                        metavar="size",
                        default="small",
                        choices=["small", "large"],
                        help="'large' or 'small' images")

    args = parser.parse_args()
    print("Image Type: ", args.size)

    # Root directory of the project
    ROOT_DIR = os.path.abspath("../")

    # Import Mask RCNN from the local checkout.
    sys.path.append(ROOT_DIR)
    import mrcnn.model as modellib
    from mrcnn import visualize

    # The satellite config module lives next to the COCO sample.
    sys.path.append(os.path.join(ROOT_DIR, "samples/coco/"))
    import sate

    # Directory to save logs and trained model
    MODEL_DIR = os.path.join(ROOT_DIR, "logs")

    # Local path to trained weights file
    COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_sate_b2_0071.h5")

    # Directories of images to run detection on / write results to.
    InIMAGE_DIR = "/home/ashwin/Desktop/Test/input"
    OutIMAGE_DIR = "/home/ashwin/Desktop/Test/output"
    os.makedirs(OutIMAGE_DIR, exist_ok=True)

    if args.size == 'large':
        # Raise PIL's decompression-bomb limit for very large inputs.
        Image.MAX_IMAGE_PIXELS = 1600 * 1600 * 10 * 10
        img_name = os.listdir(InIMAGE_DIR)[0]
        img = os.path.join(InIMAGE_DIR, img_name)
        num_tiles = 64
        print("Slicing image into {} slices".format(num_tiles))
        tiles = image_slicer.slice(img, num_tiles)

    ## Configurations

    class InferenceConfig(sate.CocoConfig):
        # One image per batch for inference:
        # batch size = GPU_COUNT * IMAGES_PER_GPU.
        GPU_COUNT = 1
        IMAGES_PER_GPU = 1

    config = InferenceConfig()
    config.display()

    ## Create Model and Load Trained Weights

    model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)

    # Export the architecture as JSON.
    with open("model.json", "w") as json_file:
        json_file.write(model.KM.to_json())

    print("Loading weights from:", COCO_MODEL_PATH)
    model.load_weights(COCO_MODEL_PATH, by_name=True)

    ## Class Names (index in the list is the class ID)
    class_names = ['BG', 'building']

    if args.size == 'large':
        for tile in tiles:
            image = skimage.io.imread(tile.filename)
            image = image[:, :, :3]  # remove alpha channel

            r = model.detect([image], verbose=1)[0]
            fig = visualize.save_instances(image, r['rois'], r['masks'], r['class_ids'], class_names, r['scores'])

            # Overwrite the tile with its render, then reload it for joining.
            plt.imsave(tile.filename, fig)
            tile.image = Image.open(tile.filename)

        image = join(tiles)
        image.save(os.path.join(OutIMAGE_DIR, img_name))

    if args.size == 'small':
        # Set to 1 to print ground-truth box counts alongside detections.
        gt_flag = 0

        # Ground truth is only needed for that comparison. Loading it
        # unconditionally made plain inference crash when the file was absent.
        gj = None
        if gt_flag == 1:
            with open('./coco/annotations_batch1/train.json') as f:
                gj = json.load(f)

        # TODO: If any tif files are present, convert them to jpg

        ## Run Object Detection
        filenames = glob.glob(os.path.join(InIMAGE_DIR, "*.jpg"))
        for counter, fl in enumerate(filenames):
            print("counter = {:5d}".format(counter))
            # basename honours the platform path separator.
            image_name = os.path.basename(fl)
            output_path = os.path.join(OutIMAGE_DIR, image_name)
            # Resume support: skip images that were already processed.
            if os.path.isfile(output_path):
                continue

            image = skimage.io.imread(fl)
            image = image[:, :, :3]  # remove alpha channel

            r = model.detect([image], verbose=1)[0]

            if gt_flag == 1:
                # Annotation image ids are the file names without ".jpg".
                image_id = image_name[:-4]
                print("Image ID:", image_id)
                bbox_actual = [ann['bbox'] for ann in gj['annotations']
                               if ann['image_id'] == image_id]

            fig = visualize.save_instances(image, r['rois'], r['masks'], r['class_ids'], class_names, r['scores'])

            print("Number of detected buildings =", len(r['rois']))
            if gt_flag == 1:
                print("Number of ground truth buildings", len(bbox_actual))

            plt.imsave(output_path, fig)
コード例 #15
0
def infer(args):
    """Run leaf-segmentation inference and save the requested outputs.

    Args:
        args: namespace with
            ``path``: input location passed to ``generate_images``;
            ``output``: model ``model_dir``;
            ``no_pictures`` / ``no_contours`` / ``no_masks``: disable the
                corresponding outputs;
            ``model``: weights path, resolved through ``prompt_model``;
            ``gt``: ground-truth adapter name, or ``None`` to skip the
                comparison;
            ``task``: task id passed to ``generate_annotation_map``.

    Outputs are written to ``model.log_dir``:
        - per-image masks;
        - renders;
        - per-polygon contour ``.txt`` files;
        - a ``CONTOUR_FILE_NAME`` JSON of all contours;
        - when ``gt`` is set, ground-truth masks, with the average IoU
          printed at the end.
    """
    from mrcnn import model as mrcnn_lib
    infer_path = args.path
    output = args.output
    do_pictures = not args.no_pictures
    do_contours = not args.no_contours
    model_path = args.model
    should_save_masks = not args.no_masks
    gt_adapter = args.gt
    task_id = args.task
    compare_to_gt = gt_adapter is not None

    # Retrieve images
    images = list(generate_images(infer_path))

    # Retrieve gt masks
    if compare_to_gt:
        adapter_class = locate(gt_adapter + '.' + gt_adapter)
        gt_annotation_map = generate_annotation_map(adapter_class, infer_path,
                                                    images, task_id)

    # Retrieve model path
    model_path = prompt_model(model_path)

    # Load model
    inference_config = get_inference_config(LeafSegmentorConfig)
    os.makedirs(output, exist_ok=True)
    model = mrcnn_lib.MaskRCNN(mode="inference",
                               config=inference_config,
                               model_dir=output)
    model.load_weights(model_path, by_name=True)
    model.set_log_dir()

    output_dir = model.log_dir
    os.makedirs(output_dir, exist_ok=True)

    # Infer
    inference_dict = {}
    IoU_dict = {}
    for image_path in tqdm(images):
        inference_dict[image_path] = []
        image_name = os.path.basename(image_path)
        image_stem = os.path.splitext(image_name)[0]
        image = np.array(Image.open(image_path))
        r = model.detect([image])[0]
        if should_save_masks:
            save_masks(r['masks'], output_dir, image_name)

        if do_pictures:
            output_file_path = os.path.join(output_dir, image_name)
            fig, ax = plt.subplots(1, figsize=(16, 16))
            visualize.save_instances(image,
                                     r['rois'],
                                     r['masks'],
                                     r['class_ids'], ['BG', 'leave'],
                                     r['scores'],
                                     save_to=output_file_path,
                                     ax=ax)
            # Free the figure; otherwise one 16x16in figure stays alive per
            # image (closing an already-closed figure is harmless).
            plt.close(fig)

        if do_contours:
            inference_dict[image_path], txt_contours = get_contours(r)

            # One file per polygon: <stem>_<leaf index, 3 digits>_<polygon>.txt
            for i, leaf_contour in enumerate(txt_contours):
                for j, polygon_contour in enumerate(leaf_contour):
                    contour_file_name = os.path.join(output_dir, image_stem) + \
                                        "_" + str(i).zfill(3) + "_" + str(j) + ".txt"
                    np.savetxt(contour_file_name,
                               polygon_contour,
                               fmt='%.1f',
                               delimiter=' , ')

        if compare_to_gt:
            gt_masks = get_all_masks(gt_annotation_map[image_path], image_path)
            gt_image_name = image_stem + "_GT.png"
            save_masks(gt_masks, output_dir, gt_image_name)
            IoU_dict[image_path] = calculate_iou(image_name, r['masks'],
                                                 gt_masks)

    if do_contours:
        with open(os.path.join(output_dir, CONTOUR_FILE_NAME), 'w') as f:
            json.dump(inference_dict, f, indent=2)

    if compare_to_gt:
        # Guard against an empty image set (previously ZeroDivisionError).
        if IoU_dict:
            total_score = sum(IoU_dict.values()) / len(IoU_dict)
            print("average IoU scores: " + str(total_score))
        else:
            print("average IoU scores: no images were evaluated")