示例#1
0
    def get_sample(self, idx):
        """
        Return a sample without transformation, for visualization.

        The sample consists of the resized previous and current frames
        together with the target that is passed to the network. According
        to the original description, bounding box values are normalized
        between 0 and 1 with respect to the target frame and then scaled by
        a factor of 10 (this is expected to happen inside ``Rescale``, which
        is not shown here).

        Args:
            idx: index passed to ``self.get_orig_sample``; frame 0 is used
                as the previous frame and frame 1 as the current frame.

        Returns:
            ``(training_sample, opts_curr)``. ``training_sample`` has the
            keys 'previmg', 'currimg', 'previmg_x2', 'currimg_x2' and
            'currbb'; the '_x2' images are crops taken with a context
            factor of 4. ``opts_curr`` holds 'edge_spacing_x',
            'edge_spacing_y', 'search_location' and 'search_region' of the
            current-frame crop.
        """
        # Load each original frame once instead of once per field.
        # NOTE(review): assumes get_orig_sample is deterministic for a
        # given (idx, frame) pair.
        curr_orig = self.get_orig_sample(idx, 1)
        curr_img = curr_orig['image']
        currbb = curr_orig['bb']
        prev_orig = self.get_orig_sample(idx, 0)
        prevbb = prev_orig['bb']

        # Crop the current frame around the previous frame's box.
        bbox_curr_shift = BoundingBox(prevbb[0], prevbb[1], prevbb[2],
                                      prevbb[3])
        (rand_search_region, rand_search_location, edge_spacing_x,
         edge_spacing_y) = cropPadImage(bbox_curr_shift, curr_img)

        # Express the ground-truth box relative to that crop. recenter() is
        # handed a fresh BoundingBox as its last argument; its return value
        # is what gets used.
        bbox_curr_gt = BoundingBox(currbb[0], currbb[1], currbb[2], currbb[3])
        bbox_gt_recentered = bbox_curr_gt.recenter(rand_search_location,
                                                   edge_spacing_x,
                                                   edge_spacing_y,
                                                   BoundingBox(0, 0, 0, 0))

        # Larger context around the same box; only the crop itself is kept.
        bbox_curr_shift.kContextFactor = 4
        rand_search_region_x2, _, _, _ = cropPadImage(bbox_curr_shift,
                                                      curr_img)

        curr_sample = {
            'image': rand_search_region,
            'image_x2': rand_search_region_x2,
            'bb': bbox_gt_recentered.get_bb_list(),
        }

        # additional options for visualization
        opts_curr = {
            'edge_spacing_x': edge_spacing_x,
            'edge_spacing_y': edge_spacing_y,
            'search_location': rand_search_location,
            'search_region': rand_search_region,
        }

        # Build the previous-frame sample at default and at 4x context. The
        # 4x crop gets its own fresh load in case crop_sample modifies the
        # dict it is given.
        prev_sample, opts_prev = crop_sample(prev_orig)
        prev_sample_x2, _ = crop_sample(self.get_orig_sample(idx, 0), 4)
        prev_sample['image_x2'] = prev_sample_x2['image']

        # scale
        scale = Rescale((self.input_size, self.input_size))
        scaled_curr_obj = scale(curr_sample, opts_curr)
        scaled_prev_obj = scale(prev_sample, opts_prev)
        training_sample = {
            'previmg': scaled_prev_obj['image'],
            'currimg': scaled_curr_obj['image'],
            'previmg_x2': scaled_prev_obj['image_x2'],
            'currimg_x2': scaled_curr_obj['image_x2'],
            'currbb': scaled_curr_obj['bb']
        }

        return training_sample, opts_curr
示例#2
0
 def show_sample_no_wait(self, idx):
     """
     Show the GOTURN input pair for sample ``idx`` without blocking.

     The previous and current frames from ``self.get_sample`` are converted
     from RGB to BGR for OpenCV. The current frame gets its target box
     (unscaled against that frame) drawn in green. Both are displayed side
     by side, and the window is refreshed with ``cv2.waitKey(150)`` rather
     than waiting for a key press.
     """
     sample, _ = self.get_sample(idx)
     rgb_prev, rgb_curr = sample['previmg'], sample['currimg']
     raw_box = sample['currbb']

     target = BoundingBox(raw_box[0], raw_box[1], raw_box[2], raw_box[3])
     # unscale() updates `target` itself; its return value is not used.
     target.unscale(rgb_curr)
     corners = [int(coord) for coord in target.get_bb_list()]

     bgr_prev = cv2.cvtColor(rgb_prev, cv2.COLOR_RGB2BGR)
     bgr_curr = cv2.cvtColor(rgb_curr, cv2.COLOR_RGB2BGR)
     top_left = (corners[0], corners[1])
     bottom_right = (corners[2], corners[3])
     bgr_curr = cv2.rectangle(bgr_curr, top_left, bottom_right,
                              (0, 255, 0), 2)

     side_by_side = np.hstack((bgr_prev, bgr_curr))
     cv2.imshow('imagenet dataset sample', side_by_side)
     cv2.waitKey(150)
示例#3
0
    def _get_rect(self, sample):
        """
        Run the GOTURN network on one sample and return the predicted
        bounding box in the coordinates of the original current frame.

        Args:
            sample: dict with 'previmg' and 'currimg' tensors; a batch
                dimension of size 1 is added here before the forward pass.

        Returns:
            The box as produced by ``BoundingBox.get_bb_list()`` after
            unscaling against ``self.opts['search_region']`` and uncentering
            into ``self.curr_img``.
        """
        import torch  # local import keeps this block self-contained

        x1 = sample['previmg'].unsqueeze(0).to(self.device)
        x2 = sample['currimg'].unsqueeze(0).to(self.device)
        # Inference only: no autograd graph is needed for the prediction.
        with torch.no_grad():
            y = self.net(x1, x2)
        # Row 0 of the (1, N) output holds the box values.
        bb = y.cpu().numpy()[0]
        bbox = BoundingBox(bb[0], bb[1], bb[2], bb[3])

        # inplace conversion back to original-frame coordinates
        bbox.unscale(self.opts['search_region'])
        bbox.uncenter(self.curr_img, self.opts['search_location'],
                      self.opts['edge_spacing_x'], self.opts['edge_spacing_y'])
        return bbox.get_bb_list()
示例#4
0
    def visu_network_input_pred(self,
                                tag,
                                epoch,
                                data,
                                images,
                                target,
                                cam,
                                max_images=10,
                                store=False,
                                jupyter=False,
                                method='def'):
        """
        Plot network inputs next to the projected target for a batch.

        For each of the first ``min(max_images, data.shape[0])`` samples one
        figure row with three panels is drawn: channels 0-2 of ``data[i]``,
        channels 3 onwards of ``data[i]``, and ``images[i]`` with the
        back-projected ``target[i]`` points painted green and a bounding box
        around them.

        Args:
            tag: name used in the saved file name and for the writer entry.
            epoch: used in the saved file name and as the writer global_step.
            data: tensor indexed as ``data[i, channel, :, :]``; each channel
                group is transposed to channel-last before plotting.
            images: tensor indexed as ``images[i, row, col, channel]``.
                NOTE: modified in place; projected pixels are set to green.
            target: per-sample input to ``backproject_points``.
            cam: per-sample camera values; columns 0..3 are passed to
                ``backproject_points`` as cx, cy, fx, fy.
            max_images: upper bound on the number of rows.
            store: save the figure as a PNG under ``self.p_visu``.
            jupyter: call ``plt.show()``.
            method: any value other than 'def' returns the rendered figure
                as a uint8 array instead of storing/showing/logging it.

        Returns:
            The figure as a uint8 array when ``method != 'def'``, else None.
        """
        num = min(max_images, data.shape[0])
        fig = plt.figure(figsize=(10.5, num * 3.5))

        for i in range(num):
            # real and render input (channel-first -> channel-last)
            real = np.transpose(
                data[i, :3, :, :].cpu().numpy().astype(np.uint8), (1, 2, 0))
            render = np.transpose(
                data[i, 3:, :, :].cpu().numpy().astype(np.uint8), (1, 2, 0))
            fig.add_subplot(num, 3, i * 3 + 1)
            plt.imshow(real)
            plt.tight_layout()
            fig.add_subplot(num, 3, i * 3 + 2)
            plt.imshow(render)
            plt.tight_layout()

            # prediction
            masked_idx = backproject_points(target[i],
                                            fx=cam[i, 2],
                                            fy=cam[i, 3],
                                            cx=cam[i, 0],
                                            cy=cam[i, 1])

            # Paint the projected points green. Column 0 of masked_idx
            # indexes dim 1 of `images`, column 1 indexes dim 2. int()
            # truncates toward zero, so a coordinate lands on a valid pixel
            # exactly when -1 < x < size. Filtering on that replaces a bare
            # per-point try/except, which also let negative indices wrap
            # around to the opposite border. NaNs fail every comparison and
            # are skipped as well.
            size0, size1 = images.shape[1], images.shape[2]
            p0, p1 = masked_idx[:, 0], masked_idx[:, 1]
            inside = (p0 > -1) & (p0 < size0) & (p1 > -1) & (p1 < size1)
            idx0 = p0[inside].long().to(images.device)
            idx1 = p1[inside].long().to(images.device)
            images[i, idx0, idx1, :3] = torch.tensor(
                [0, 255, 0], dtype=images.dtype, device=images.device)

            # Box spanning all projected points (including off-image ones).
            min1 = torch.min(masked_idx[:, 0])
            max1 = torch.max(masked_idx[:, 0])
            max2 = torch.max(masked_idx[:, 1])
            min2 = torch.min(masked_idx[:, 1])
            bb = BoundingBox(p1=torch.stack([min1, min2]),
                             p2=torch.stack([max1, max2]))
            bb_img = bb.plot(images[i, :, :, :3].cpu().numpy().astype(
                np.uint8))
            fig.add_subplot(num, 3, i * 3 + 3)
            plt.imshow(bb_img)

        if method != 'def':
            a = get_img_from_fig(fig).astype(np.uint8)
            plt.close(fig)
            return a

        if store:
            plt.savefig(
                f'{self.p_visu}/{str(epoch)}_{tag}_network_input_and_prediction.png',
                dpi=300)
        if jupyter:
            plt.show()
        if self.writer is not None:
            # you can get a high-resolution image as numpy array!!
            plot_img_np = get_img_from_fig(fig)
            self.writer.add_image(tag,
                                  plot_img_np,
                                  global_step=epoch,
                                  dataformats='HWC')
        plt.close(fig)
示例#5
0
    def plot_batch_projection(self,
                              tag,
                              epoch,
                              images,
                              target,
                              cam,
                              max_images=10,
                              store=False,
                              jupyter=False,
                              method='def'):
        """
        Plot back-projected target points on a batch of images.

        For each of the first ``min(max_images, target.shape[0])`` samples
        one figure row with two panels is drawn: ``images[i]`` with the
        projected points painted green plus a bounding box around them, and
        ``images[i]`` without the box (the points are already painted on
        it).

        Args:
            tag: name used in the saved file name and for the writer entry.
            epoch: used in the saved file name and as the writer global_step.
            images: tensor indexed as ``images[i, row, col, channel]``.
                NOTE: modified in place; projected pixels are set to green.
            target: per-sample input to ``backproject_points``.
            cam: per-sample camera values; columns 0..3 are passed to
                ``backproject_points`` as cx, cy, fx, fy.
            max_images: upper bound on the number of rows.
            store: save the figure as a PNG under ``self.p_visu``.
            jupyter: call ``plt.show()``.
            method: any value other than 'def' returns the rendered figure
                as a uint8 array instead of storing/showing/logging it.

        Returns:
            The figure as a uint8 array when ``method != 'def'``.

        NOTE(review): no ``plt.close`` for the default path appears in the
        visible code. Confirm one follows; otherwise a figure stays open per
        call.
        """
        num = min(max_images, target.shape[0])
        fig = plt.figure(figsize=(7, num * 3.5))
        for i in range(num):
            masked_idx = backproject_points(target[i],
                                            fx=cam[i, 2],
                                            fy=cam[i, 3],
                                            cx=cam[i, 0],
                                            cy=cam[i, 1])

            # Paint the projected points green. Column 0 of masked_idx
            # indexes dim 1 of `images`, column 1 indexes dim 2. int()
            # truncates toward zero, so a coordinate lands on a valid pixel
            # exactly when -1 < x < size. Filtering on that replaces a bare
            # per-point try/except, which also let negative indices wrap
            # around to the opposite border. NaNs fail every comparison and
            # are skipped as well.
            size0, size1 = images.shape[1], images.shape[2]
            p0, p1 = masked_idx[:, 0], masked_idx[:, 1]
            inside = (p0 > -1) & (p0 < size0) & (p1 > -1) & (p1 < size1)
            idx0 = p0[inside].long().to(images.device)
            idx1 = p1[inside].long().to(images.device)
            images[i, idx0, idx1, :3] = torch.tensor(
                [0, 255, 0], dtype=images.dtype, device=images.device)

            # Box spanning all projected points (including off-image ones).
            min1 = torch.min(masked_idx[:, 0])
            max1 = torch.max(masked_idx[:, 0])
            max2 = torch.max(masked_idx[:, 1])
            min2 = torch.min(masked_idx[:, 1])

            bb = BoundingBox(p1=torch.stack([min1, min2]),
                             p2=torch.stack([max1, max2]))

            bb_img = bb.plot(images[i, :, :, :3].cpu().numpy().astype(
                np.uint8))
            fig.add_subplot(num, 2, i * 2 + 1)
            plt.imshow(bb_img)

            fig.add_subplot(num, 2, i * 2 + 2)
            real = images[i, :, :, :3].cpu().numpy().astype(np.uint8)
            plt.imshow(real)

        if method != 'def':
            a = get_img_from_fig(fig).astype(np.uint8)
            plt.close(fig)
            return a

        if store:
            plt.savefig(f'{self.p_visu}/{str(epoch)}_{tag}_project_batch.png',
                        dpi=300)
        if jupyter:
            plt.show()
        if self.writer is not None:
            # you can get a high-resolution image as numpy array!!
            plot_img_np = get_img_from_fig(fig)
            self.writer.add_image(tag,
                                  plot_img_np,
                                  global_step=epoch,
                                  dataformats='HWC')