Ejemplo n.º 1
0
    def __init__(self, global_shape, patch_res, map_count, modes, gauss_mu_range, gauss_sig_scaling):
        """
        Stores the displacement-map configuration, tiles global_shape into patches and builds the first maps.

        :param global_shape: Shape tiled by image_utils.compute_patch_boundingboxes (presumably the full image size — TODO confirm).
        :param patch_res: Patch resolution; also used as the tiling stride.
        :param map_count: Stored as self.map_count.
        :param modes: Stored as self.modes.
        :param gauss_mu_range: Stored as self.gauss_mu_range.
        :param gauss_sig_scaling: Stored as self.gauss_sig_scaling.
        """
        # Generation parameters (presumably read by create_new_disp_maps — TODO confirm)
        self.global_shape, self.patch_res = global_shape, patch_res
        self.map_count, self.modes = map_count, modes
        self.gauss_mu_range, self.gauss_sig_scaling = gauss_mu_range, gauss_sig_scaling

        # Tile global_shape with a stride equal to the patch resolution
        self.patch_boundingboxes = image_utils.compute_patch_boundingboxes(
            self.global_shape,
            stride=self.patch_res,
            patch_res=self.patch_res,
        )
        # -1: no patch visited yet (assumed to be advanced before first use — TODO confirm)
        self.current_patch_index = -1
        self.disp_maps = None
        self.create_new_disp_maps()
def sample_patches(params):
    """
    Randomly sample up to ``count`` patch bounding boxes of a tile, each containing at least one polygon.

    :param params: Tuple ``(raw_dirpath, tile_info, ds_fac, size, count, seed)``, packed into a single
        argument (presumably so the function can be used with a pool ``map`` — TODO confirm):
        - raw_dirpath: root directory passed to the ``read`` helpers.
        - tile_info: dict with at least "city", "number" and "pixelsize" keys. It is modified in place.
        - ds_fac: downsampling factor, corrected by REFERENCE_PIXEL_SIZE / tile_info["pixelsize"].
        - size: patch side length, used both as stride and patch resolution in the rescaled image.
        - count: maximum number of bounding boxes to keep; ``count <= 0`` yields no box.
        - seed: seed of the shuffle that randomizes which patches are picked.
    :return: ``tile_info`` with "bbox_list" (sampled bounding boxes, in rescaled coordinates) and
        "scale_factor" (factor applied to the original image size and polygons) added.
    """
    raw_dirpath, tile_info, ds_fac, size, count, seed = params

    im_filepath = read.get_image_filepath(raw_dirpath, tile_info["city"],
                                          tile_info["number"])
    im_size = image_utils.get_image_size(im_filepath)
    polygon_list = read.load_polygons(raw_dirpath, read.POLYGON_DIRNAME,
                                      tile_info["city"], tile_info["number"])

    # Rescale data to the requested downsampling, relative to the reference pixel size
    corrected_factor = ds_fac * REFERENCE_PIXEL_SIZE / tile_info["pixelsize"]
    scale_factor = 1 / corrected_factor
    im_size = (int(np.round(im_size[0] * scale_factor)),
               int(np.round(im_size[1] * scale_factor)))
    ds_polygon_list = polygon_utils.rescale_polygon(polygon_list, scale_factor)

    # Copy into a plain list: shuffling then only reorders references, which stays correct even if
    # the helper returns a 2-D NumPy array (random.shuffle on such an array can duplicate rows).
    bbox_list = list(image_utils.compute_patch_boundingboxes(im_size, size, size))

    # Private RNG instead of reseeding the global `random` module (which would affect every other
    # user of `random`). Random(seed).shuffle gives the same permutation as random.seed + random.shuffle.
    random.Random(seed).shuffle(bbox_list)

    # Sample <count> patches in tile, making sure there is at least a polygon inside
    sampled_bbox_list = []
    for bbox in bbox_list:
        # Checked before sampling so that count <= 0 returns no box at all
        if count <= len(sampled_bbox_list):
            break
        bbox_polygon_list = polygon_utils.filter_polygons_in_bounding_box(
            ds_polygon_list, bbox)
        if len(bbox_polygon_list) >= 1:
            sampled_bbox_list.append(bbox)

    tile_info["bbox_list"] = sampled_bbox_list
    tile_info["scale_factor"] = scale_factor

    return tile_info
Ejemplo n.º 3
0
Archivo: model.py Proyecto: dtekeshe/ml
    def inference(self, image_array, ori_gt_array, checkpoints_dir):
        """
        Runs inference on image_array and ori_gt_array with the model checkpoint in checkpoints_dir.

        The image is tiled into patches of self.input_res pixels taken every self.output_res pixels. Each
        batch of patches is run through the network, and the predictions are written into full-size
        output arrays.

        :param image_array: Array of shape (H, W, C); only the first 3 channels are used. Values are divided
            by 255 (presumably uint8 input — TODO confirm) and mapped to self.image_dynamic_range.
        :param ori_gt_array: Array of shape (H, W, C), divided by 255 and fed patch by patch to
            self.input_disp_polygon_map.
        :param checkpoints_dir: Directory the checkpoint is restored from. The process exits via sys.exit
            if no checkpoint is found.
        :return: Tuple (complete_pred_field_map, complete_segmentation_image). Both have spatial size
            (H - 2 * padding, W - 2 * padding) with padding = (input_res - output_res) // 2. The field map
            has self.disp_output_channels channels and is de-normalized by disp_map_dynamic_range_fac and
            disp_max_abs_value. The segmentation image has self.seg_output_channels channels and is all
            zeros when self.add_seg_output is False.
        """
        spatial_shape = image_array.shape[:2]

        # Format inputs
        image_array = image_array[:, :, :3]  # Remove alpha channel if any
        image_array = (image_array / 255) * (self.image_dynamic_range[1] - self.image_dynamic_range[0]) + \
                      self.image_dynamic_range[0]
        ori_gt_array = ori_gt_array / 255

        # The network output is smaller than its input by `padding` pixels on each side
        padding = (self.input_res - self.output_res) // 2
        output_shape = (spatial_shape[0] - 2 * padding, spatial_shape[1] - 2 * padding)

        # Init displacement field and segmentation image
        complete_pred_field_map = np.zeros(output_shape + (self.disp_output_channels,))
        complete_segmentation_image = np.zeros(output_shape + (self.seg_output_channels,))

        # Tile the image: patches of input_res taken every output_res pixels
        patch_boundingboxes = image_utils.compute_patch_boundingboxes(spatial_shape,
                                                                      stride=self.output_res,
                                                                      patch_res=self.input_res)
        # pad=True presumably makes every chunk exactly batch_size long — TODO confirm
        batch_boundingboxes_list = list(
            python_utils.split_list_into_chunks(patch_boundingboxes, self.batch_size, pad=True))

        # Saver
        saver = tf.train.Saver(save_relative_paths=True)
        with tf.Session() as sess:
            # Restore checkpoint
            restore_checkpoint_success = self.restore_checkpoint(sess, saver, checkpoints_dir)
            if not restore_checkpoint_success:
                sys.exit('No checkpoint found in {}'.format(checkpoints_dir))

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            try:
                # Loop over every batch
                for batch_index, batch_boundingboxes in enumerate(batch_boundingboxes_list):
                    if batch_index % 10 == 0:
                        print("Processing batch {}/{}"
                              .format(batch_index + 1, len(batch_boundingboxes_list)))

                    # Form batch
                    batch_image = np.stack([image_array[bbox[0]:bbox[2], bbox[1]:bbox[3], :]
                                            for bbox in batch_boundingboxes], axis=0)
                    batch_ori_gt = np.stack([ori_gt_array[bbox[0]:bbox[2], bbox[1]:bbox[3], :]
                                             for bbox in batch_boundingboxes], axis=0)

                    feed_dict = {
                        self.input_image: batch_image,
                        self.input_disp_polygon_map: batch_ori_gt,
                        self.keep_prob: 1.0,
                    }
                    if self.add_seg_output:
                        batch_pred_disp_field_map, batch_pred_seg = sess.run(
                            [self.level_0_disp_pred, self.level_0_seg_pred], feed_dict=feed_dict)
                    else:
                        batch_pred_disp_field_map = sess.run(self.level_0_disp_pred, feed_dict=feed_dict)
                        # No segmentation output: use an all-zero segmentation per patch
                        batch_pred_seg = np.zeros(
                            batch_pred_disp_field_map.shape[:3] + (self.seg_output_channels,))

                    # Fill complete outputs. Named patch_index so it does not shadow batch_index.
                    for patch_index, boundingbox in enumerate(batch_boundingboxes):
                        # padded_boundingbox presumably shrinks the input box by `padding` to the
                        # predicted region — TODO confirm
                        padded_boundingbox = image_utils.padded_boundingbox(boundingbox, padding)
                        # Shift from input-image coordinates to output-array coordinates
                        translated_padded_boundingbox = [x - padding for x in padded_boundingbox]
                        row_slice = slice(translated_padded_boundingbox[0], translated_padded_boundingbox[2])
                        col_slice = slice(translated_padded_boundingbox[1], translated_padded_boundingbox[3])
                        complete_pred_field_map[row_slice, col_slice, :] = batch_pred_disp_field_map[patch_index]
                        complete_segmentation_image[row_slice, col_slice, :] = batch_pred_seg[patch_index]
            finally:
                # Always stop the queue-runner threads, even if a batch raises
                coord.request_stop()
                coord.join(threads)

        # De-normalize field map
        complete_pred_field_map = complete_pred_field_map / self.disp_map_dynamic_range_fac  # Within [-1, 1]
        complete_pred_field_map = complete_pred_field_map * self.disp_max_abs_value  # Within [-config.DISP_MAX_ABS_VALUE, config.DISP_MAX_ABS_VALUE]

        return complete_pred_field_map, complete_segmentation_image
Ejemplo n.º 4
0
    def compute_patch_gradients(self, ori_image, polygon_map_array,
                                checkpoints_dir):
        """
        Computes, for every patch of ori_image, the gradients of the level-0 displacement prediction
        with respect to all trainable variables, using the model checkpoint in checkpoints_dir.

        Patches of self.input_res pixels are taken with a stride of self.input_res. The x (channel 0) and
        y (channel 1) components of self.level_0_disp_pred are differentiated separately.

        Note: tf.gradients adds new ops to the default graph on every call of this method.

        :param ori_image: Array of shape (H, W, C); only the first 3 channels are fed to the network. Values
            are divided by 255 (presumably uint8 — TODO confirm) and mapped to self.image_dynamic_range.
        :param polygon_map_array: Array of shape (H, W, C), divided by 255 and fed patch by patch to
            self.input_disp_polygon_map.
        :param checkpoints_dir: Directory the checkpoint is restored from. The process exits via sys.exit
            if no checkpoint is found.
        :return: List with one dict per patch:
            {"bbox": bbox, "image": un-normalized crop of ori_image (all channels),
             "grads": {"x": [...], "y": [...]}}.
            Gradients that tf.gradients returns as None are dropped, so the lists are not necessarily
            index-aligned with tf.trainable_variables().
        """
        spatial_shape = ori_image.shape[:2]
        # Format inputs
        image = ori_image[:, :, :3]  # Remove alpha channel if any
        image = (image / 255) * (self.image_dynamic_range[1] - self.image_dynamic_range[0]) + \
                self.image_dynamic_range[0]
        polygon_map_array = polygon_map_array / 255

        # One info dict per patch
        patch_info_list = []

        # Iterate over every patch and compute all gradients for this patch
        patch_bbox_list = image_utils.compute_patch_boundingboxes(
            spatial_shape, stride=self.input_res, patch_res=self.input_res)
        y_x = self.level_0_disp_pred[:, :, :, 0]
        y_y = self.level_0_disp_pred[:, :, :, 1]
        xs = tf.trainable_variables()  # All trainable variables
        grad_x_ops = tf.gradients(y_x, xs, name='gradients')
        grad_y_ops = tf.gradients(y_y, xs, name='gradients')
        # sess.run cannot fetch None, so drop variables with no gradient
        grad_x_fetches = [grad for grad in grad_x_ops if grad is not None]
        grad_y_fetches = [grad for grad in grad_y_ops if grad is not None]

        # Saver
        saver = tf.train.Saver(save_relative_paths=True)
        with tf.Session() as sess:
            # Restore checkpoint
            restore_checkpoint_success = self.restore_checkpoint(
                sess, saver, checkpoints_dir)
            if not restore_checkpoint_success:
                sys.exit('No checkpoint found in {}'.format(checkpoints_dir))

            # Loop over every patch
            for bbox in tqdm(patch_bbox_list, desc="Computing patch gradients"):
                patch_image = image[bbox[0]:bbox[2], bbox[1]:bbox[3], :]
                patch_polygon_map = polygon_map_array[bbox[0]:bbox[2],
                                                      bbox[1]:bbox[3], :]

                # Batch of size 1
                feed_dict = {
                    self.input_image: np.expand_dims(patch_image, axis=0),
                    self.input_disp_polygon_map: np.expand_dims(patch_polygon_map, axis=0),
                    self.keep_prob: 1.0
                }
                patch_grads_x, patch_grads_y = sess.run([grad_x_fetches, grad_y_fetches],
                                                        feed_dict=feed_dict)

                patch_info_list.append({
                    "bbox": bbox,
                    "image": ori_image[bbox[0]:bbox[2], bbox[1]:bbox[3], :],
                    "grads": {
                        "x": patch_grads_x,
                        "y": patch_grads_y,
                    },
                })

        return patch_info_list
def process_sample_into_patches(patch_stride, patch_res, image, gt_polygon_map, disp_field_maps, disp_polygon_maps,
                                gt_polygons=None, disp_polygons_list=None):
    """
    Crops all inputs to patches generated with patch_stride and patch_res. Only patches whose inner
    region contains ground-truth polygon pixels are kept.

    A patch is kept when:
    - channel 2 of gt_polygon_map has a non-zero sum inside the centered inner region of side
      2 * patch_stride, and
    - when polygons are given, polygon_utils.prepare_polygons_for_tfrecord returned non-None
      ground-truth polygons for it.

    :param patch_stride: Stride between consecutive patches; the inner filtering region has side 2 * patch_stride.
    :param patch_res: Side length of each patch.
    :param image: Array of shape (H, W, C).
    :param gt_polygon_map: Array of shape (H, W, C) with C >= 3; channel 2 is used for filtering.
    :param disp_field_maps: Array of shape (N, H, W, C).
    :param disp_polygon_maps: Array of shape (N, H, W, C).
    :param gt_polygons: Optional ground-truth polygons. Polygons are processed only when both this and
        disp_polygons_list are given.
    :param disp_polygons_list: Optional displaced polygons.
    :return: List of dicts with keys "tile_res", "disp_map_count", "image", "gt_polygons",
        "disp_polygons", "gt_polygon_map", "disp_field_maps", "disp_polygon_maps".
    """
    include_polygons = gt_polygons is not None and disp_polygons_list is not None

    # Inner region used for filtering: side 2 * patch_stride, centered in the patch (loop-invariant).
    # Clamped at 0: an inner region cannot be larger than the patch itself.
    patch_inner_res = 2 * patch_stride
    patch_padding = max((patch_res - patch_inner_res) // 2, 0)

    patches = []
    patch_boundingboxes = image_utils.compute_patch_boundingboxes(image.shape[0:2],
                                                                  stride=patch_stride,
                                                                  patch_res=patch_res)
    for patch_boundingbox in patch_boundingboxes:
        row_slice = slice(patch_boundingbox[0], patch_boundingbox[2])
        col_slice = slice(patch_boundingbox[1], patch_boundingbox[3])

        patch_gt_polygon_map = gt_polygon_map[row_slice, col_slice, :]

        # Explicit end indices: with patch_padding == 0 the former `[0:-0]` slice was empty,
        # which rejected every patch.
        inner_patch_gt_polygon_map_corners = patch_gt_polygon_map[
            patch_padding:patch_gt_polygon_map.shape[0] - patch_padding,
            patch_padding:patch_gt_polygon_map.shape[1] - patch_padding,
            2]
        # Cheap test first, so polygon processing is skipped for empty patches
        if not np.sum(inner_patch_gt_polygon_map_corners):
            continue

        if include_polygons:
            patch_gt_polygons, \
            patch_disp_polygons_array = polygon_utils.prepare_polygons_for_tfrecord(gt_polygons, disp_polygons_list,
                                                                                    patch_boundingbox)
            if patch_gt_polygons is None:
                continue
        else:
            patch_gt_polygons = patch_disp_polygons_array = None

        # Crop the remaining inputs
        patch_image = image[row_slice, col_slice, :]
        patch_disp_field_maps = disp_field_maps[:, row_slice, col_slice, :]
        patch_disp_polygon_maps_array = disp_polygon_maps[:, row_slice, col_slice, :]

        assert patch_image.shape[0] == patch_image.shape[
            1], "image should be square otherwise tile_res cannot be defined"

        patches.append({
            "tile_res": patch_image.shape[0],
            "disp_map_count": patch_disp_polygon_maps_array.shape[0],
            "image": patch_image,
            "gt_polygons": patch_gt_polygons,
            "disp_polygons": patch_disp_polygons_array,
            "gt_polygon_map": patch_gt_polygon_map,
            "disp_field_maps": patch_disp_field_maps,
            "disp_polygon_maps": patch_disp_polygon_maps_array,
        })

    return patches