def _reduce_zero_label(self, labelmap):
    """Shift every label down by one so that class 0 is folded onto 255.

    Controlled by the ``('data', 'reduce_zero_label')`` config entry.

    If that entry is falsy, ``labelmap`` is returned as-is (same object).
    Otherwise a copy is made with ``np.array`` in which:

    * the values 0 and 255 both become 255
      (255 is presumably the ignore index; confirm against the loss config);
    * every other value ``v`` becomes ``v - 1``.

    With the ``'pil'`` image tool, the result is cast to uint8 and passed
    through ``ImageHelper.to_img``. Otherwise the ndarray itself is returned.
    """
    if self.configer.get('data', 'reduce_zero_label'):
        shifted = np.array(labelmap)
        # Park class 0 on 255 first, so the subtraction below lands it on 254 ...
        shifted[shifted == 0] = 255
        shifted = shifted - 1
        # ... then fold 254 (former 0 and former 255) back onto 255.
        shifted[shifted == 254] = 255
        if self.configer.get('data', 'image_tool') != 'pil':
            return shifted
        return ImageHelper.to_img(shifted.astype(np.uint8))
    return labelmap
def _encode_label(self, labelmap):
    """Remap raw class ids to contiguous indices taken from ``label_list``.

    Each pixel equal to ``label_list[i]`` (config entry
    ``('data', 'label_list')``) is set to ``i``. Every pixel matching no
    entry is set to 255. If several entries share a value, the last one
    wins, because later assignments overwrite earlier ones.

    Args:
        labelmap: anything ``np.array`` accepts (e.g. a PIL image or an
            ndarray). Only its first two dimensions size the output, so it
            is presumably a 2-D (H, W) label image; confirm with callers.

    Returns:
        A float32 ndarray of shape ``(H, W)``. With the ``'pil'`` image tool,
        the result is instead cast to uint8 and passed through
        ``ImageHelper.to_img``. Indices >= 256 would wrap on that cast.
    """
    labelmap = np.array(labelmap)
    shape = labelmap.shape
    # Hoisted: the list is loop-invariant, so fetch it once rather than
    # twice per class as before.
    label_list = self.configer.get('data', 'label_list')
    encoded_labelmap = np.full((shape[0], shape[1]), 255, dtype=np.float32)
    for train_id, class_id in enumerate(label_list):
        encoded_labelmap[labelmap == class_id] = train_id

    if self.configer.get('data', 'image_tool') == 'pil':
        encoded_labelmap = ImageHelper.to_img(encoded_labelmap.astype(np.uint8))

    return encoded_labelmap
def __getitem__(self, index):
    """Load sample ``index`` and build its training targets.

    Steps:

    1. Read the image at ``self.img_list[index]``.
    2. Read its mask from ``self.mask_list[index]``. If that file does not
       exist, use an all-ones mask the size of the image.
    3. Read keypoints and boxes from ``self.json_list[index]``.
    4. Apply ``self.aug_transform`` if set. Boxes are passed only when there
       are any.
    5. Downsample the mask by the ``('network', 'stride')`` config value.
    6. Render heatmap and PAF targets with ``self.heatmap_generator`` and
       ``self.paf_generator``.
    7. Apply ``self.img_transform`` to the image if set.

    Returns:
        dict with keys ``img``, ``heatmap``, ``maskmap``, ``vecmap`` and
        ``meta``. The first four are stacked DataContainers. ``meta`` (holding
        ``kpts``) is a non-stacked, CPU-only DataContainer.
    """
    image_tool = self.configer.get('data', 'image_tool')
    img = ImageHelper.read_image(self.img_list[index],
                                 tool=image_tool,
                                 mode=self.configer.get('data', 'input_mode'))

    if os.path.exists(self.mask_list[index]):
        maskmap = ImageHelper.read_image(self.mask_list[index], tool=image_tool, mode='P')
    else:
        # No mask on disk: mark every pixel with 1.
        # Size it via ImageHelper.get_size rather than ``img.size``. For a
        # PIL image ``.size`` is (w, h), but when ``img`` is an ndarray
        # (presumably the non-'pil' tools) ``.size`` is the element count,
        # and ``img.size[1]`` raised TypeError.
        img_width, img_height = ImageHelper.get_size(img)
        maskmap = np.ones((img_height, img_width), dtype=np.uint8)
        if image_tool == 'pil':
            maskmap = ImageHelper.to_img(maskmap)

    kpts, bboxes = self.__read_json_file(self.json_list[index])

    if self.aug_transform is not None and len(bboxes) > 0:
        img, maskmap, kpts, bboxes = self.aug_transform(img, maskmap=maskmap, kpts=kpts, bboxes=bboxes)
    elif self.aug_transform is not None:
        img, maskmap, kpts = self.aug_transform(img, maskmap=maskmap, kpts=kpts)

    # Downsample the mask by the network stride. Nearest-neighbour
    # interpolation avoids blending mask values. The generators below still
    # receive the pre-resize (width, height).
    stride = self.configer.get('network', 'stride')
    width, height = ImageHelper.get_size(maskmap)
    maskmap = ImageHelper.resize(maskmap,
                                 (width // stride, height // stride),
                                 interpolation='nearest')

    # Add a leading channel axis to the float32 mask tensor.
    maskmap = torch.from_numpy(np.array(maskmap, dtype=np.float32)).unsqueeze(0)
    heatmap = self.heatmap_generator(kpts, [width, height], maskmap)
    vecmap = self.paf_generator(kpts, [width, height], maskmap)

    if self.img_transform is not None:
        img = self.img_transform(img)

    meta = dict(kpts=kpts)
    return dict(
        img=DataContainer(img, stack=True),
        heatmap=DataContainer(heatmap, stack=True),
        maskmap=DataContainer(maskmap, stack=True),
        vecmap=DataContainer(vecmap, stack=True),
        meta=DataContainer(meta, stack=False, cpu_only=True),
    )