# Example n. 1
    def update(self, detections: BoxList, time: float):
        """Fold one frame's detections into the per-instance track history.

        Each detection must carry an 'index' field (the persistent instance
        id) and boxes must be in 'xyxy' mode.  For every detection we derive
        an approximate ground-contact point (horizontal box center at the
        bottom edge), map it to a named region, and either extend the
        existing record for that instance or open a new one.

        Args:
            detections: BoxList carrying 'index', 'labels', 'mask' and
                'scores' fields.
            time: timestamp of this update; stored as ``last_update`` and as
                each touched instance's ``lost`` time (``appeared`` for new
                instances).
        """
        self.last_update = time
        assert detections.has_field('index') and detections.mode == 'xyxy'
        for i, ind in enumerate(detections.get_field('index')):
            ind = int(ind)
            box = detections.bbox[i]
            label = detections.get_field('labels')[i]
            mask = detections.get_field('mask')[i]
            score = detections.get_field('scores')[i]
            # Assumed car position: x = box center, y = bottom edge.
            # Use builtin ``int``: ``np.int`` was removed in NumPy 1.24.
            location = np.asarray(
                [(box[0] + box[2]) / 2, box[-1]]).round().astype(int)
            region_code = self._object_region(location)
            region = CODE_TO_REGION[region_code]  # position at the moment

            if ind in self.instances:
                # Known instance: extend its history and refresh 'lost'.
                inst = self.instances[ind]
                inst['regions'].append(region)
                inst['labels'].append(int(label))
                inst['scores'].append(float(score))
                inst['locations'].append(location)
                inst['box'].append(box)
                inst['mask'].append(mask)
                inst['lost'] = self.last_update
            else:
                # First sighting: open a fresh record.
                self.instances[ind] = {
                    "regions": [region],
                    "labels": [int(label)],
                    "scores": [float(score)],
                    "locations": [location],
                    "box": [box],
                    "mask": [mask],
                    "appeared": self.last_update,
                    "lost": self.last_update,
                }
    def __getitem__(self, idx):
        """Load sample ``idx`` and build its training target.

        Returns:
            (img, target, idx): ``target`` is a BoxList in 'xyxy' mode with a
            'labels' field, plus optional 'masks', 'keypoints', 'seg_masks'
            (panoptic), 'targets_map' and 'binary_masks' fields depending on
            the annotations and dataset configuration.
        """
        img, anno = super(COCODataset, self).__getitem__(idx)

        # filter crowd annotations
        # TODO might be better to add an extra field
        anno = [obj for obj in anno if obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")

        # Map raw COCO category ids to the contiguous label space.
        classes = [obj["category_id"] for obj in anno]
        classes = [self.json_category_id_to_contiguous_id[c] for c in classes]
        classes = torch.tensor(classes)
        target.add_field("labels", classes)

        if anno and "segmentation" in anno[0]:
            masks = [obj["segmentation"] for obj in anno]
            masks = SegmentationMask(masks, img.size, mode='poly')
            target.add_field("masks", masks)

        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = PersonKeypoints(keypoints, img.size)
            target.add_field("keypoints", keypoints)

        target = target.clip_to_image(remove_empty=True)

        if self.panoptic_on:
            # add semantic masks to the boxlist for panoptic
            img_id = self.ids[idx]
            img_path = self.coco.loadImgs(img_id)[0]['file_name']

            # Derive the semantic-segmentation PNG path from the image path.
            seg_path = self.root.replace('coco', 'coco/annotations').replace(
                'train2017',
                'panoptic_train2017_semantic_trainid_stff').replace(
                    'val2017',
                    'panoptic_val2017_semantic_trainid_stff') + '/' + img_path
            seg_img = Image.open(seg_path.replace('jpg', 'png'))

            # seg_img.mode = 'L'
            # Decode raw PIL bytes into a (C, H, W) float tensor.
            seg_gt = torch.ByteTensor(
                torch.ByteStorage.from_buffer(seg_img.tobytes()))
            seg_gt = seg_gt.view(seg_img.size[1], seg_img.size[0], 1)
            seg_gt = seg_gt.transpose(0, 1).transpose(0,
                                                      2).contiguous().float()

            seg_gt = SegmentationMask(seg_gt, seg_img.size, "mask")
            target.add_field("seg_masks", seg_gt)

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        if self.use_binary_mask and target.has_field("masks"):
            if self.use_polygon_det:
                # compute target maps (per-pixel instance id; 255 = background)
                masks = target.get_field("masks")
                w, h = target.size
                assert target.mode == "xyxy"
                targets_map = np.ones((h, w), dtype=np.uint8) * 255
                assert len(masks.instances) <= 255
                for target_id, polygons in enumerate(masks.instances):
                    targets_map = self.compute_target_maps(
                        targets_map, target_id, polygons)
                # BUGFIX: add the field once, after the loop — the original
                # rebuilt the full H x W tensor and re-added the field on
                # every iteration.  The guard preserves the original
                # behavior of adding no field when there are no instances.
                if len(masks.instances) > 0:
                    target.add_field("targets_map", torch.Tensor(targets_map))

            # compute binary masks, one flattened fixed-size mask per instance
            MASK_SIZE = self.binary_mask_size
            binary_masks = torch.zeros(len(target),
                                       MASK_SIZE[0] * MASK_SIZE[1])
            masks = target.get_field("masks")
            # assert len(target) == len(masks.instances)
            for i, polygons in enumerate(masks.instances):
                mask = self.polygons_to_mask(polygons)
                mask = mask.to(binary_masks.device)
                mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), MASK_SIZE)
                binary_masks[i, :] = mask.view(-1)
            target.add_field("binary_masks", binary_masks)

        return img, target, idx