def __getitem__(self, index):
        img = ImageHelper.read_image(
            self.item_list[index][0],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))

        # Size before augmentation as (width, height); the post-augmentation
        # size is kept as (height, width) for later border bookkeeping.
        ori_img_size = ImageHelper.get_size(img)
        if self.aug_transform is not None:
            img = self.aug_transform(img)

        border_hw = ImageHelper.get_size(img)[::-1]
        if self.img_transform is not None:
            img = self.img_transform(img)

        meta = dict(ori_img_size=ori_img_size,
                    border_hw=border_hw,
                    img_path=self.item_list[index][0],
                    filename=self.item_list[index][1])
        return dict(img=DataContainer(img,
                                      stack=True,
                                      return_dc=True,
                                      samples_per_gpu=True),
                    meta=DataContainer(meta,
                                       stack=False,
                                       cpu_only=True,
                                       return_dc=True,
                                       samples_per_gpu=True))
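Every example on this page wraps its outputs in DataContainer. The real class lives elsewhere in the repository; purely as a mental model, a minimal stand-in consistent with the attributes used here (data, stack, samples_per_gpu, cpu_only, return_dc) might look like the hypothetical sketch below, which is not the project's actual implementation.

# Hypothetical stand-in, inferred only from how DataContainer is used in the
# examples on this page; the repository's real class may differ.
class SimpleDataContainer(object):

    def __init__(self, data, stack=False, samples_per_gpu=False,
                 cpu_only=False, return_dc=False):
        self._data = data
        self.stack = stack                      # collate via torch.stack
        self.samples_per_gpu = samples_per_gpu  # split the batch per GPU
        self.cpu_only = cpu_only                # keep on CPU (e.g. meta dicts)
        self.return_dc = return_dc              # keep the wrapper after collation

    @property
    def data(self):
        return self._data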
Example #2
    def todc(data_list,
             samples_per_gpu=True,
             stack=False,
             cpu_only=False,
             device_ids=None,
             concat=False):
        if not samples_per_gpu:
            if not stack:
                return DataContainer(data_list,
                                     stack=stack,
                                     samples_per_gpu=samples_per_gpu,
                                     cpu_only=cpu_only)
            else:
                return DataContainer(torch.stack(data_list, 0),
                                     stack=stack,
                                     samples_per_gpu=samples_per_gpu,
                                     cpu_only=cpu_only)

        # Default to all visible GPUs and split the list into one chunk per
        # device, each holding ceil(len(data_list) / len(device_ids)) samples.
        device_ids = list(range(
            torch.cuda.device_count())) if device_ids is None else device_ids
        samples = (len(data_list) - 1 + len(device_ids)) // len(device_ids)
        stacked = []
        for i in range(0, len(data_list), samples):
            if not stack and not concat:
                stacked.append(data_list[i:i + samples])
            elif stack:
                stacked.append(torch.stack(data_list[i:i + samples], 0))
            else:
                stacked.append(torch.cat(data_list[i:i + samples], 0))

        return DataContainer(stacked,
                             stack=stack,
                             samples_per_gpu=samples_per_gpu,
                             cpu_only=cpu_only)
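A hedged usage sketch for todc, called here as a plain function for brevity (in the repository it appears as a helper method). The tensor shapes and device ids below are made up for illustration:

import torch

# Eight per-sample tensors, e.g. CHW images already converted by img_transform.
data_list = [torch.randn(3, 224, 224) for _ in range(8)]

# With samples_per_gpu=True, stack=True and two devices, the list is split into
# two chunks of four samples, each chunk stacked to shape (4, 3, 224, 224), and
# the two resulting tensors are wrapped in a single DataContainer.
dc = todc(data_list, samples_per_gpu=True, stack=True, device_ids=[0, 1])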
Example #3
    def __getitem__(self, index):
        imgA = ImageHelper.read_image(
            self.imgA_list[index],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))

        indexB = random.randint(0, len(self.imgB_list) - 1)
        imgB = ImageHelper.read_image(
            self.imgB_list[indexB],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))

        if self.aug_transform is not None:
            imgA = self.aug_transform(imgA)
            imgB = self.aug_transform(imgB)

        if self.img_transform is not None:
            imgA = self.img_transform(imgA)
            imgB = self.img_transform(imgB)

        return dict(imgA=DataContainer(imgA, stack=True),
                    imgB=DataContainer(imgB, stack=True),
                    labelA=DataContainer(self.labelA_list[index], stack=True),
                    labelB=DataContainer(self.labelB_list[indexB], stack=True))
Example #4
def stack(batch, data_key=None, device_ids=None):
    # Default to all visible GPUs so the samples_per_gpu branches below do not
    # fail when device_ids is left as None.
    device_ids = list(range(
        torch.cuda.device_count())) if device_ids is None else device_ids
    if isinstance(batch[0][data_key], DataContainer):
        if batch[0][data_key].stack:
            assert isinstance(
                batch[0][data_key].data,
                (torch.Tensor, int_classes, float, string_classes,
                 collections.abc.Mapping, collections.abc.Sequence))
            stacked = []
            if batch[0][data_key].samples_per_gpu and len(device_ids) > 1:
                samples_per_gpu = (len(batch) - 1 +
                                   len(device_ids)) // len(device_ids)
                for i in range(0, len(batch), samples_per_gpu):
                    stacked.append(
                        default_collate([
                            sample[data_key].data
                            for sample in batch[i:i + samples_per_gpu]
                        ]))
            else:
                stacked = default_collate(
                    [sample[data_key].data for sample in batch])

            if batch[0][data_key].return_dc and len(device_ids) > 1:
                return DataContainer(
                    stacked,
                    stack=batch[0][data_key].stack,
                    samples_per_gpu=batch[0][data_key].samples_per_gpu,
                    cpu_only=batch[0][data_key].cpu_only)
            else:
                return stacked
        else:
            stacked = []
            if batch[0][data_key].samples_per_gpu and len(device_ids) > 1:
                samples_per_gpu = (len(batch) - 1 +
                                   len(device_ids)) // len(device_ids)
                for i in range(0, len(batch), samples_per_gpu):
                    stacked.append([
                        sample[data_key].data
                        for sample in batch[i:i + samples_per_gpu]
                    ])
            else:
                stacked = [sample[data_key].data for sample in batch]

            if batch[0][data_key].return_dc and len(device_ids) > 1:
                return DataContainer(
                    stacked,
                    stack=batch[0][data_key].stack,
                    samples_per_gpu=batch[0][data_key].samples_per_gpu,
                    cpu_only=batch[0][data_key].cpu_only)
            else:
                return stacked
    else:
        return default_collate([sample[data_key] for sample in batch])
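For context, stack is the kind of helper a custom collate_fn would call once per key in the sample dict. The wiring below is an illustrative sketch, not the repository's actual collate function; dataset is a placeholder name:

import functools
from torch.utils.data import DataLoader

def dc_collate(batch, device_ids=(0,)):
    # Illustrative collate_fn: run the stack() helper above on every key that
    # the dataset's __getitem__ returns (e.g. 'img', 'bboxes', 'meta').
    return {key: stack(batch, data_key=key, device_ids=list(device_ids))
            for key in batch[0]}

# Hypothetical wiring, assuming 'dataset' is one of the datasets shown above:
# loader = DataLoader(dataset, batch_size=8, shuffle=True,
#                     collate_fn=functools.partial(dc_collate, device_ids=[0]))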
Example #5
    def __getitem__(self, index):
        img = ImageHelper.read_image(
            self.img_list[index],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))
        label = self.label_list[index]

        if self.aug_transform is not None:
            img = self.aug_transform(img)

        if self.img_transform is not None:
            img = self.img_transform(img)

        return dict(
            img=DataContainer(img, stack=True),
            label=DataContainer(label, stack=True),
        )
Example #6
    def __getitem__(self, index):
        img = ImageHelper.read_image(self.img_list[index],
                                     tool=self.configer.get('data', 'image_tool'),
                                     mode=self.configer.get('data', 'input_mode'))
        labels, bboxes, polygons = self.__read_json_file(self.json_list[index])

        if self.aug_transform is not None:
            img, bboxes, labels, polygons = self.aug_transform(img, bboxes=bboxes,
                                                               labels=labels, polygons=polygons)

        if self.img_transform is not None:
            img = self.img_transform(img)

        return dict(
            img=DataContainer(img, stack=True),
            bboxes=DataContainer(bboxes, stack=False),
            labels=DataContainer(labels, stack=False),
            polygons=DataContainer(polygons, stack=False, cpu_only=True)
        )
Example #7
    def __getitem__(self, index):
        img = ImageHelper.read_image(self.img_list[index],
                                     tool=self.configer.get('data', 'image_tool'),
                                     mode=self.configer.get('data', 'input_mode'))
        if os.path.exists(self.mask_list[index]):
            maskmap = ImageHelper.read_image(self.mask_list[index],
                                             tool=self.configer.get('data', 'image_tool'), mode='P')
        else:
            maskmap = np.ones((img.size[1], img.size[0]), dtype=np.uint8)
            if self.configer.get('data', 'image_tool') == 'pil':
                maskmap = ImageHelper.to_img(maskmap)

        kpts, bboxes = self.__read_json_file(self.json_list[index])

        if self.aug_transform is not None and len(bboxes) > 0:
            img, maskmap, kpts, bboxes = self.aug_transform(img, maskmap=maskmap, kpts=kpts, bboxes=bboxes)

        elif self.aug_transform is not None:
            img, maskmap, kpts = self.aug_transform(img, maskmap=maskmap, kpts=kpts)

        width, height = ImageHelper.get_size(maskmap)
        # Downsample the mask by the network stride; nearest-neighbour keeps
        # the mask values unblended.
        maskmap = ImageHelper.resize(maskmap,
                                     (width // self.configer.get('network', 'stride'),
                                      height // self.configer.get('network', 'stride')),
                                     interpolation='nearest')

        maskmap = torch.from_numpy(np.array(maskmap, dtype=np.float32))
        maskmap = maskmap.unsqueeze(0)
        heatmap = self.heatmap_generator(kpts, [width, height], maskmap)
        vecmap = self.paf_generator(kpts, [width, height], maskmap)
        if self.img_transform is not None:
            img = self.img_transform(img)

        meta = dict(
            kpts=kpts,
        )
        return dict(
            img=DataContainer(img, stack=True),
            heatmap=DataContainer(heatmap, stack=True),
            maskmap=DataContainer(maskmap, stack=True),
            vecmap=DataContainer(vecmap, stack=True),
            meta=DataContainer(meta, stack=False, cpu_only=True),
        )
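heatmap_generator and paf_generator are separate target-generation modules configured elsewhere. As an illustration of the general technique only, a minimal Gaussian keypoint-heatmap generator could look like the sketch below; the keypoint layout, visibility convention, and argument names are assumptions, not the repository's API:

import numpy as np
import torch

def gaussian_heatmaps(kpts, img_size, stride=8, sigma=7.0):
    # Hypothetical sketch: one Gaussian peak per keypoint, rendered at the
    # stride-reduced resolution. kpts is assumed to be a list of persons, each
    # a list of (x, y, visibility) triples in input-image coordinates.
    width, height = img_size
    out_w, out_h = width // stride, height // stride
    num_kpts = len(kpts[0]) if len(kpts) > 0 else 0
    heatmaps = np.zeros((num_kpts, out_h, out_w), dtype=np.float32)
    ys, xs = np.mgrid[0:out_h, 0:out_w]
    for person in kpts:
        for k, (x, y, vis) in enumerate(person):
            if vis < 0:  # skip unlabeled keypoints (convention assumed here)
                continue
            g = np.exp(-((xs - x / stride) ** 2 + (ys - y / stride) ** 2)
                       / (2.0 * sigma ** 2))
            heatmaps[k] = np.maximum(heatmaps[k], g)
    return torch.from_numpy(heatmaps)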
Example #8
    def __getitem__(self, index):
        img = ImageHelper.read_image(
            self.img_list[index],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))
        img_size = ImageHelper.get_size(img)
        bboxes, labels = self.__read_json_file(self.json_list[index])
        ori_bboxes, ori_labels = bboxes.copy(), labels.copy()

        if self.aug_transform is not None:
            img, bboxes, labels = self.aug_transform(img,
                                                     bboxes=bboxes,
                                                     labels=labels)

        labels = torch.from_numpy(labels).long()
        bboxes = torch.from_numpy(bboxes).float()

        meta = dict(ori_img_size=img_size,
                    border_wh=ImageHelper.get_size(img),
                    ori_bboxes=torch.from_numpy(ori_bboxes).float(),
                    ori_labels=torch.from_numpy(ori_labels).long())
        if self.img_transform is not None:
            img = self.img_transform(img)

        return dict(img=DataContainer(img,
                                      stack=True,
                                      return_dc=True,
                                      samples_per_gpu=True),
                    bboxes=DataContainer(bboxes,
                                         stack=False,
                                         return_dc=True,
                                         samples_per_gpu=True),
                    labels=DataContainer(labels,
                                         stack=False,
                                         return_dc=True,
                                         samples_per_gpu=True),
                    meta=DataContainer(meta,
                                       stack=False,
                                       cpu_only=True,
                                       return_dc=True,
                                       samples_per_gpu=True))
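On the consuming side, return_dc=True means the collated batch keeps its DataContainer wrappers when more than one GPU is in use, so a training loop has to unwrap via .data. A hedged sketch, reusing the collate idea from the earlier example:

# Illustrative only: unwrap a batch whose collate_fn (see the earlier sketch)
# kept DataContainer wrappers because return_dc=True and several GPUs are used.
def unwrap_images(batch):
    imgs = batch['img']
    if isinstance(imgs, DataContainer):   # multi-GPU: one stacked chunk per device
        return imgs.data
    return [imgs]                         # single GPU: one plain stacked tensor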
Example #9
    def __getitem__(self, index):
        img = ImageHelper.read_image(
            self.img_list[index],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))

        kpts, bboxes = self.__read_json_file(self.json_list[index])

        if self.aug_transform is not None:
            img, kpts, bboxes = self.aug_transform(img,
                                                   kpts=kpts,
                                                   bboxes=bboxes)

        heatmap = self.heatmap_generator(kpts, ImageHelper.get_size(img))
        if self.img_transform is not None:
            img = self.img_transform(img)

        return dict(
            img=DataContainer(img, stack=True),
            heatmap=DataContainer(heatmap, stack=True),
        )
Example #10
    def __getitem__(self, index):
        imgA = ImageHelper.read_image(
            self.imgA_list[index],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))

        imgB = ImageHelper.read_image(
            self.imgB_list[index],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))

        if self.aug_transform is not None:
            imgA, imgB = self.aug_transform([imgA, imgB])

        if self.img_transform is not None:
            imgA = self.img_transform(imgA)
            imgB = self.img_transform(imgB)

        return dict(
            imgA=DataContainer(imgA, stack=True),
            imgB=DataContainer(imgB, stack=True),
        )
Example #11
    def __getitem__(self, index):
        img = ImageHelper.read_image(
            self.img_list[index],
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))
        img_size = ImageHelper.get_size(img)
        labelmap = ImageHelper.read_image(self.label_list[index],
                                          tool=self.configer.get(
                                              'data', 'image_tool'),
                                          mode='P')
        if self.configer.get('data.label_list', default=None):
            labelmap = self._encode_label(labelmap)

        if self.configer.get('data.reduce_zero_label', default=None):
            labelmap = self._reduce_zero_label(labelmap)

        ori_target = ImageHelper.to_np(labelmap)

        if self.aug_transform is not None:
            img, labelmap = self.aug_transform(img, labelmap=labelmap)

        border_size = ImageHelper.get_size(img)

        if self.img_transform is not None:
            img = self.img_transform(img)

        if self.label_transform is not None:
            labelmap = self.label_transform(labelmap)

        meta = dict(ori_img_wh=img_size,
                    border_wh=border_size,
                    ori_target=ori_target)
        return dict(
            img=DataContainer(img, stack=True),
            labelmap=DataContainer(labelmap, stack=True),
            meta=DataContainer(meta, stack=False, cpu_only=True),
        )
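_encode_label and _reduce_zero_label are defined elsewhere in the dataset class. Purely as a sketch of the usual convention (label 0 becomes the ignore index 255 and the remaining class ids shift down by one), _reduce_zero_label might be implemented as below; treat this as an assumption rather than the project's actual code:

import numpy as np

def reduce_zero_label(labelmap):
    # Hypothetical sketch: map the original class 0 to the ignore index 255
    # and shift every other class id down by one.
    labelmap = np.array(labelmap)
    encoded = labelmap.astype(np.int64) - 1
    encoded[labelmap == 0] = 255
    return encoded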