Ejemplo n.º 1
0
def _stack_images_vertically(paths):
    """
    Combines the images top-to-bottom, in the order given, into a single RGB image.

    The combined image is as wide as the widest input image and as high as all input
    images together; narrower images are left-aligned and the remaining area stays black (zeros).

    :param paths: the image files to combine
    :type paths: list
    :return: the combined image
    :rtype: Image.Image
    """
    arrays = []
    for path in paths:
        with Image.open(path) as img:
            # copy the pixels while the file is still open, so the handle can be released
            arrays.append(np.asarray(remove_alpha_channel(img)))
    total_height = sum(arr.shape[0] for arr in arrays)
    max_width = max(arr.shape[1] for arr in arrays)
    canvas = np.zeros((total_height, max_width, 3), np.uint8)
    offset = 0
    for arr in arrays:
        height, width = arr.shape[:2]
        canvas[offset:offset + height, :width, :3] = arr
        offset += height
    return Image.fromarray(canvas)


def predict_on_images(input_dir, graph, sess, output_dir, tmp_dir, score_threshold, categories, num_imgs, inference_times,
                      delete_input, output_polygons, mask_threshold, mask_nth, output_minrect, view_margin, fully_connected,
                      fit_bbox_to_polygon, output_width_height, bbox_as_fallback):
    """
    Method performing predictions on all images one by one or combined as specified by the int value of num_imgs.

    Images are picked up from input_dir in batches of up to num_imgs. A batch of more than one image is
    stacked vertically into a single image before presenting it to the graph; afterwards the detections
    are split up again per image and the coordinates written to each image's ROI CSV file are translated
    back into that image's own coordinate system. Processed images are moved to output_dir (or deleted).
    The method returns once no complete images are left in input_dir.

    :param input_dir: the directory with the images
    :type input_dir: str
    :param graph: the graph to use
    :type graph: tf.Graph()
    :param sess: the tensorflow session
    :type sess: tf.Session
    :param output_dir: the output directory to move the images to and store the predictions
    :type output_dir: str
    :param tmp_dir: the temporary directory to store the predictions until finished
    :type tmp_dir: str
    :param score_threshold: the minimum score predictions have to have
    :type score_threshold: float
    :param categories: the label map
    :param num_imgs: the number of images to combine into one before presenting to graph
    :type num_imgs: int
    :param inference_times: whether to output a CSV file with the inference times
    :type inference_times: bool
    :param delete_input: whether to delete the input images rather than moving them to the output directory
    :type delete_input: bool
    :param output_polygons: whether the model predicts masks and polygons should be stored in the CSV files
    :type output_polygons: bool
    :param mask_threshold: the threshold to use for determining the contour of a mask
    :type mask_threshold: float
    :param mask_nth: to speed up polygon computation, use only every nth row and column from mask
    :type mask_nth: int
    :param output_minrect: when predicting polygons, whether to output the minimal rectangles around the objects as well
    :type output_minrect: bool
    :param view_margin: the margin in pixels to use around the masks
    :type view_margin: int
    :param fully_connected: whether regions of 'high' or 'low' values should be fully-connected at isthmuses
    :type fully_connected: str
    :param fit_bbox_to_polygon: whether to fit the bounding box to the polygon
    :type fit_bbox_to_polygon: bool
    :param output_width_height: whether to output x/y/w/h instead of x0/y0/x1/y1
    :type output_width_height: bool
    :param bbox_as_fallback: if ratio between polygon-bbox and bbox is smaller than this value, use bbox as fallback polygon, ignored if < 0
    :type bbox_as_fallback: float
    """

    total_time = 0
    if inference_times:
        times = list()
        times.append("Image(s)_file_name(s),Total_time(ms),Number_of_images,Time_per_image(ms)\n")

    # counter for keeping track of images that cannot be processed (yet)
    incomplete_counter = dict()

    while True:
        start_time = datetime.now()
        im_list = []
        # Pick up images equal to num_imgs or the remaining images if less
        for image_path in os.listdir(input_dir):
            # Load images only; the extension check is case-insensitive (e.g. ".JPG")
            ext_lower = os.path.splitext(image_path)[1].lower()
            if ext_lower in SUPPORTED_EXTS:
                full_path = os.path.join(input_dir, image_path)
                if auto.is_image_complete(full_path):
                    im_list.append(full_path)
                else:
                    incomplete_counter[full_path] = incomplete_counter.get(full_path, 0) + 1
            if len(im_list) == num_imgs:
                break

        # give up on images that were flagged as incomplete too many times
        exhausted = [k for k, count in incomplete_counter.items() if count >= MAX_INCOMPLETE]
        for k in exhausted:
            print("%s - %s" % (str(datetime.now()), os.path.basename(k)))
            del incomplete_counter[k]
            try:
                if delete_input:
                    print("  flagged as incomplete {} times, deleting\n".format(MAX_INCOMPLETE))
                    os.remove(k)
                else:
                    print("  flagged as incomplete {} times, skipping\n".format(MAX_INCOMPLETE))
                    os.rename(k, os.path.join(output_dir, os.path.basename(k)))
            except Exception:
                print(traceback.format_exc())

        if len(im_list) == 0:
            time.sleep(1)
            break
        print("%s - %s" % (str(datetime.now()), ", ".join(os.path.basename(x) for x in im_list)))

        try:
            is_combined = len(im_list) > 1
            if is_combined:
                image = remove_alpha_channel(_stack_images_vertically(im_list))
            else:
                image = remove_alpha_channel(Image.open(im_list[0]))

            image_np = image_to_numpyarray(image)
            output_dict = inference_for_image(image_np, graph, sess)

            # Loading results
            boxes = output_dict['detection_boxes']
            scores = output_dict['detection_scores']
            classes = output_dict['detection_classes']

            # Split the rois into one csv per original image; image k occupies the horizontal
            # band [min_height, max_height] of the combined image
            max_height = 0
            for im_path in im_list:
                with Image.open(im_path) as img:
                    img_width, img_height = img.size
                min_height = max_height
                max_height += img_height
                im_base = os.path.splitext(os.path.basename(im_path))[0]
                roi_path = "{}/{}-rois.csv".format(output_dir, im_base)
                roi_path_tmp = "{}/{}-rois.tmp".format(tmp_dir if tmp_dir is not None else output_dir, im_base)

                roiobjs = []
                for index in range(output_dict['num_detections']):
                    score = scores[index]

                    # Ignore this roi if the score is less than the provided threshold
                    if score < score_threshold:
                        continue

                    y0n, x0n, y1n, x1n = boxes[index]

                    # Translate normalized roi coordinates into (combined) image coordinates
                    x0 = x0n * image.width
                    y0 = y0n * image.height
                    x1 = x1n * image.width
                    y1 = y1n * image.height

                    # only keep rois that lie completely within the band of the current image
                    if y0 > max_height or y1 > max_height:
                        continue
                    elif y0 < min_height or y1 < min_height:
                        continue

                    label = classes[index]
                    label_str = categories[label - 1]['name']

                    px = None
                    py = None
                    pxn = None
                    pyn = None
                    bw = None
                    bh = None
                    if output_polygons:
                        px = []
                        py = []
                        pxn = []
                        pyn = []
                        bw = ""
                        bh = ""
                        if 'detection_masks' in output_dict:
                            poly = mask_to_polygon(output_dict['detection_masks'][index], mask_threshold=mask_threshold,
                                                   mask_nth=mask_nth, view=(x0, y0, x1, y1), view_margin=view_margin,
                                                   fully_connected=fully_connected)
                            if len(poly) > 0:
                                px, py = polygon_to_lists(poly[0], swap_x_y=True, normalize=False)
                                pxn, pyn = polygon_to_lists(poly[0], swap_x_y=True, normalize=True, img_width=image.width, img_height=image.height)
                                if output_minrect:
                                    bw, bh = polygon_to_minrect(poly[0])
                                if bbox_as_fallback >= 0:
                                    if len(px) >= 3:
                                        p_x0n, p_y0n, p_x1n, p_y1n = polygon_to_bbox(lists_to_polygon(pxn, pyn))
                                        p_area = (p_x1n - p_x0n) * (p_y1n - p_y0n)
                                        b_area = (x1n - x0n) * (y1n - y0n)
                                        if (b_area > 0) and (p_area / b_area < bbox_as_fallback):
                                            px = [float(v) for v in [x0, x1, x1, x0]]
                                            py = [float(v) for v in [y0, y0, y1, y1]]
                                            pxn = [float(v) for v in [x0n, x1n, x1n, x0n]]
                                            pyn = [float(v) for v in [y0n, y0n, y1n, y1n]]
                                    else:
                                        px = [float(v) for v in [x0, x1, x1, x0]]
                                        py = [float(v) for v in [y0, y0, y1, y1]]
                                        pxn = [float(v) for v in [x0n, x1n, x1n, x0n]]
                                        pyn = [float(v) for v in [y0n, y0n, y1n, y1n]]
                                    # NOTE(review): this overwrites the polygon's minrect with the bbox size even
                                    # when the polygon was kept (no fallback) -- confirm this is intended
                                    if output_minrect:
                                        bw = x1 - x0 + 1
                                        bh = y1 - y0 + 1
                                if fit_bbox_to_polygon:
                                    if len(px) >= 3:
                                        x0, y0, x1, y1 = polygon_to_bbox(lists_to_polygon(px, py))
                                        x0n, y0n, x1n, y1n = polygon_to_bbox(lists_to_polygon(pxn, pyn))

                    if is_combined:
                        # Translate from combined-image coordinates into the coordinates of this image:
                        # shift up by the start of its band and normalize by its own dimensions
                        y0 -= min_height
                        y1 -= min_height
                        x0n = x0 / img_width
                        y0n = y0 / img_height
                        x1n = x1 / img_width
                        y1n = y1 / img_height
                        if px is not None:
                            py = [float(v) - min_height for v in py]
                            pxn = [float(v) / img_width for v in px]
                            pyn = [v / img_height for v in py]

                    roiobj = ROIObject(x0, y0, x1, y1, x0n, y0n, x1n, y1n, label, label_str, score=score,
                                       poly_x=px, poly_y=py, poly_xn=pxn, poly_yn=pyn,
                                       minrect_w=bw, minrect_h=bh)
                    roiobjs.append(roiobj)

                info = ImageInfo(os.path.basename(im_path))
                roiext = (info, roiobjs)
                options = ["--output", str(tmp_dir if tmp_dir is not None else output_dir), "--no-images"]
                if output_width_height:
                    options.append("--size-mode")
                roiwriter = ROIWriter(options)
                roiwriter.save([roiext])
                if tmp_dir is not None:
                    # NOTE(review): assumes ROIWriter wrote its output under the name roi_path_tmp -- confirm
                    os.rename(roi_path_tmp, roi_path)
        except Exception:
            print("Failed processing images: {}".format(",".join(im_list)))
            print(traceback.format_exc())

        # Move finished images to output_path or delete them
        for im_path in im_list:
            if delete_input:
                os.remove(im_path)
            else:
                os.rename(im_path, os.path.join(output_dir, os.path.basename(im_path)))

        end_time = datetime.now()
        inference_time = int((end_time - start_time).total_seconds() * 1000)
        time_per_image = int(inference_time / len(im_list))
        if inference_times:
            line = "".join("{}|".format(os.path.basename(p)) for p in im_list)
            line += ",{},{},{}\n".format(inference_time, len(im_list), time_per_image)
            times.append(line)
        print("  Inference + I/O time: {} ms\n".format(inference_time))
        total_time += inference_time

        if inference_times:
            # rewritten after every batch so the files are up to date even if processing is interrupted
            with open(os.path.join(output_dir, "inference_time.csv"), "w") as time_file:
                time_file.writelines(times)
            with open(os.path.join(output_dir, "total_time.txt"), "w") as total_time_file:
                total_time_file.write("Total inference and I/O time: {} ms\n".format(total_time))
def process_image(msg_cont):
    """
    Processes the message container, loading the image from the message and forwarding the predictions.

    The image is decoded from msg_cont.message['data'] and passed to inference_detector. Detections
    with a score of at least config.score_threshold are published as an ObjectPredictions JSON string
    on msg_cont.params.channel_out. A KeyboardInterrupt sets msg_cont.params.stopped; any other
    error is logged.

    :param msg_cont: the message container to process
    :type msg_cont: MessageContainer
    """
    config = msg_cont.params.config

    try:
        start_time = datetime.now()
        image = Image.open(io.BytesIO(msg_cont.message['data']))
        image = remove_alpha_channel(image)
        image_array = image_to_numpyarray(image)
        detection = inference_detector(model, image_array)

        if not isinstance(config.class_names, (tuple, list)):
            raise TypeError("class_names must be a tuple or list, got: %s" % type(config.class_names))
        if isinstance(detection, tuple):
            bbox_result, segm_result = detection
            if isinstance(segm_result, tuple):
                segm_result = segm_result[0]  # ms rcnn
        else:
            bbox_result, segm_result = detection, None
        bboxes = np.vstack(bbox_result)
        labels = np.concatenate([
            np.full(bbox.shape[0], i, dtype=np.int32)
            for i, bbox in enumerate(bbox_result)
        ])
        masks = None
        if segm_result is not None:
            segms = mmcv.concat_list(segm_result)
            # without any detections there are no masks: segms[0] / np.stack would fail
            if segms:
                if isinstance(segms[0], torch.Tensor):
                    masks = torch.stack(segms, dim=0).detach().cpu().numpy()
                else:
                    masks = np.stack(segms, axis=0)

        objs = []
        for index in range(len(bboxes)):
            x0, y0, x1, y1, score = bboxes[index]
            label = labels[index]
            label_str = config.class_names[label]

            # Ignore this roi if the score is less than the provided threshold
            if score < config.score_threshold:
                continue

            # Normalize roi coordinates by the image dimensions
            x0n = x0 / image.width
            y0n = y0 / image.height
            x1n = x1 / image.width
            y1n = y1 / image.height

            px = None
            py = None

            if segm_result is not None:
                px = []
                py = []
                mask = masks[index].astype(bool)
                poly = mask_to_polygon(mask,
                                       config.mask_threshold,
                                       mask_nth=config.mask_nth,
                                       view=(x0, y0, x1, y1),
                                       view_margin=config.view_margin,
                                       fully_connected=config.fully_connected)
                if len(poly) > 0:
                    px, py = polygon_to_lists(poly[0],
                                              swap_x_y=True,
                                              normalize=False)
                    pxn, pyn = polygon_to_lists(poly[0],
                                                swap_x_y=True,
                                                normalize=True,
                                                img_width=image.width,
                                                img_height=image.height)
                    if config.bbox_as_fallback >= 0:
                        if len(px) >= 3:
                            p_x0n, p_y0n, p_x1n, p_y1n = polygon_to_bbox(
                                lists_to_polygon(pxn, pyn))
                            p_area = (p_x1n - p_x0n) * (p_y1n - p_y0n)
                            b_area = (x1n - x0n) * (y1n - y0n)
                            # polygon covers too little of its bbox: use the bbox as polygon instead
                            if (b_area > 0) and (p_area / b_area <
                                                 config.bbox_as_fallback):
                                px = [float(v) for v in [x0, x1, x1, x0]]
                                py = [float(v) for v in [y0, y0, y1, y1]]
                        else:
                            px = [float(v) for v in [x0, x1, x1, x0]]
                            py = [float(v) for v in [y0, y0, y1, y1]]
                    if config.fit_bbox_to_polygon:
                        if len(px) >= 3:
                            x0, y0, x1, y1 = polygon_to_bbox(
                                lists_to_polygon(px, py))

            bbox = BBox(left=int(x0),
                        top=int(y0),
                        right=int(x1),
                        bottom=int(y1))
            # models without masks: use the bbox corners as polygon
            if px is None:
                px = [x0, x1, x1, x0]
                py = [y0, y0, y1, y1]
            poly = Polygon(points=[[int(x), int(y)] for x, y in zip(px, py)])
            pred = ObjectPrediction(label=label_str,
                                    score=float(score),
                                    bbox=bbox,
                                    polygon=poly)
            objs.append(pred)

        preds = ObjectPredictions(id=str(start_time),
                                  timestamp=str(start_time),
                                  objects=objs)
        msg_cont.params.redis.publish(msg_cont.params.channel_out,
                                      preds.to_json_string())
        if config.verbose:
            log("process_images - predictions string published: %s" %
                msg_cont.params.channel_out)
            end_time = datetime.now()
            processing_time = end_time - start_time
            processing_time = int(processing_time.total_seconds() * 1000)
            log("process_images - finished processing image: %d ms" %
                processing_time)

    except KeyboardInterrupt:
        msg_cont.params.stopped = True
    except Exception:
        log("process_images - failed to process: %s" % traceback.format_exc())
Ejemplo n.º 3
0
def process_image(fname, output_dir, poller):
    """
    Method for processing an image.

    Runs inference_detector on the image and writes the detections with a score of at least
    poller.params.score_threshold to "<name>-rois.csv" in output_dir. If
    poller.params.output_mask_image is set and the model produces masks, it also writes a
    palette PNG "<name>-mask.png" where each pixel holds label+1 of the first mask covering it
    (0 = background). Errors are reported via poller.error and not raised.

    :param fname: the image to process
    :type fname: str
    :param output_dir: the directory to write the image to
    :type output_dir: str
    :param poller: the Poller instance that called the method
    :type poller: Poller
    :return: the list of generated output files
    :rtype: list
    """
    result = []

    try:
        image = Image.open(fname)
        image = remove_alpha_channel(image)
        image_array = image_to_numpyarray(image)
        detection = inference_detector(model, image_array)

        if not isinstance(poller.params.class_names, (tuple, list)):
            raise TypeError("class_names must be a tuple or list, got: %s" % type(poller.params.class_names))
        if isinstance(detection, tuple):
            bbox_result, segm_result = detection
            if isinstance(segm_result, tuple):
                # ms rcnn returns (masks, mask scores); only the masks are used
                segm_result = segm_result[0]
        else:
            bbox_result, segm_result = detection, None
        bboxes = np.vstack(bbox_result)
        labels = np.concatenate([np.full(bbox.shape[0], i, dtype=np.int32) for i, bbox in enumerate(bbox_result)])

        # flattened per class, just like the bboxes above, so that index i refers to the same
        # detection in both (assumes mmdet's per-class result layout)
        segms = None
        if segm_result is not None:
            segms = mmcv.concat_list(segm_result)

        base = os.path.splitext(os.path.basename(fname))[0]
        roi_path = "{}/{}-rois.csv".format(output_dir, base)
        img_path = "{}/{}-mask.png".format(output_dir, base)

        # rois
        roiobjs = []
        mask_comb = None
        for index in range(len(bboxes)):
            x0, y0, x1, y1, score = bboxes[index]
            label = labels[index]
            label_str = poller.params.class_names[label]

            # Ignore this roi if the score is less than the provided threshold
            if score < poller.params.score_threshold:
                continue

            # Normalize roi coordinates by the image dimensions
            x0n = x0 / image.width
            y0n = y0 / image.height
            x1n = x1 / image.width
            y1n = y1 / image.height

            px = None
            py = None
            pxn = None
            pyn = None
            bw = None
            bh = None

            if segms is not None:
                px = []
                py = []
                pxn = []
                pyn = []
                bw = ""
                bh = ""
                # builtin int: the np.int alias was removed in NumPy 1.24
                mask = maskUtils.decode(segms[index]).astype(int)
                poly = mask_to_polygon(mask, poller.params.mask_threshold, mask_nth=poller.params.mask_nth, view=(x0, y0, x1, y1),
                                       view_margin=poller.params.view_margin, fully_connected=poller.params.fully_connected)
                if len(poly) > 0:
                    px, py = polygon_to_lists(poly[0], swap_x_y=True, normalize=False)
                    pxn, pyn = polygon_to_lists(poly[0], swap_x_y=True, normalize=True, img_width=image.width, img_height=image.height)
                    if poller.params.output_minrect:
                        bw, bh = polygon_to_minrect(poly[0])
                    if poller.params.bbox_as_fallback >= 0:
                        if len(px) >= 3:
                            p_x0n, p_y0n, p_x1n, p_y1n = polygon_to_bbox(lists_to_polygon(pxn, pyn))
                            p_area = (p_x1n - p_x0n) * (p_y1n - p_y0n)
                            b_area = (x1n - x0n) * (y1n - y0n)
                            # polygon covers too little of its bbox: use the bbox as polygon instead
                            if (b_area > 0) and (p_area / b_area < poller.params.bbox_as_fallback):
                                px = [float(v) for v in [x0, x1, x1, x0]]
                                py = [float(v) for v in [y0, y0, y1, y1]]
                                pxn = [float(v) for v in [x0n, x1n, x1n, x0n]]
                                pyn = [float(v) for v in [y0n, y0n, y1n, y1n]]
                        else:
                            px = [float(v) for v in [x0, x1, x1, x0]]
                            py = [float(v) for v in [y0, y0, y1, y1]]
                            pxn = [float(v) for v in [x0n, x1n, x1n, x0n]]
                            pyn = [float(v) for v in [y0n, y0n, y1n, y1n]]
                        # NOTE(review): overwrites the polygon's minrect with the bbox size even when the
                        # polygon was kept -- confirm this is intended
                        if poller.params.output_minrect:
                            bw = x1 - x0 + 1
                            bh = y1 - y0 + 1
                    if poller.params.fit_bbox_to_polygon:
                        if len(px) >= 3:
                            x0, y0, x1, y1 = polygon_to_bbox(lists_to_polygon(px, py))
                            x0n, y0n, x1n, y1n = polygon_to_bbox(lists_to_polygon(pxn, pyn))

                if poller.params.output_mask_image:
                    mask_img = mask.copy()
                    mask_img[mask_img < poller.params.mask_threshold] = 0
                    mask_img[mask_img >= poller.params.mask_threshold] = label+1  # first label is 0
                    if mask_comb is None:
                        mask_comb = mask_img
                    else:
                        # earlier masks take precedence where masks overlap
                        mask_comb = np.where(mask_comb == 0, mask_img, mask_comb)

            roiobj = ROIObject(x0, y0, x1, y1, x0n, y0n, x1n, y1n, label, label_str, score=score,
                               poly_x=px, poly_y=py, poly_xn=pxn, poly_yn=pyn,
                               minrect_w=bw, minrect_h=bh)
            roiobjs.append(roiobj)

        info = ImageInfo(os.path.basename(fname))
        roiext = (info, roiobjs)
        options = ["--output", output_dir, "--no-images"]
        if poller.params.output_width_height:
            options.append("--size-mode")
        roiwriter = ROIWriter(options)
        roiwriter.save([roiext])
        result.append(roi_path)

        if mask_comb is not None:
            im = Image.fromarray(np.uint8(mask_comb), 'P')
            im.save(img_path, "PNG")
            result.append(img_path)
    except KeyboardInterrupt:
        poller.keyboard_interrupt()
    except Exception:
        poller.error("Failed to process image: %s\n%s" % (fname, traceback.format_exc()))
    return result