Example #1
    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if isinstance(image, torch.Tensor):
            if image.ndimension() not in {2, 3}:
                raise ValueError("image should be 2/3 dimensional. Got {} dimensions.".format(image.ndimension()))
            elif image.ndimension() == 2:
                image = image.unsqueeze(0)

        r = torch.rand(7)  # one draw per decision: brightness, order, contrast (x2), saturation, hue, channel shuffle

        if r[0] < self.p:
            image = self._brightness(image)

        contrast_before = r[1] < 0.5  # contrast is applied either before or after saturation/hue
        if contrast_before:
            if r[2] < self.p:
                image = self._contrast(image)

        if r[3] < self.p:
            image = self._saturation(image)

        if r[4] < self.p:
            image = self._hue(image)

        if not contrast_before:
            if r[5] < self.p:
                image = self._contrast(image)

        if r[6] < self.p:
            # Randomly permute the colour channels.
            channels = F.get_image_num_channels(image)
            permutation = torch.randperm(channels)

            is_pil = F._is_pil_image(image)
            if is_pil:
                # PIL images cannot be fancy-indexed; round-trip through a tensor.
                image = F.pil_to_tensor(image)
                image = F.convert_image_dtype(image)
            image = image[..., permutation, :, :]
            if is_pil:
                image = F.to_pil_image(image)

        return image, target
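A minimal standalone sketch of the random gating used above, assuming nothing beyond torch; the names p and r mirror the method, everything else is illustrative.

import torch

p = 0.5
r = torch.rand(7)                    # one draw per decision, as in forward()
apply_brightness = bool(r[0] < p)
contrast_before = bool(r[1] < 0.5)   # contrast goes before or after saturation/hue
apply_contrast = bool((r[2] if contrast_before else r[5]) < p)
apply_channel_shuffle = bool(r[6] < p)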
Example #2
    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if isinstance(image, torch.Tensor):
            if image.ndimension() not in {2, 3}:
                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
            elif image.ndimension() == 2:
                image = image.unsqueeze(0)

        if torch.rand(1) >= self.p:
            return image, target

        _, orig_h, orig_w = F.get_dimensions(image)

        # Sample a zoom-out ratio and the enlarged canvas size.
        r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
        canvas_width = int(orig_w * r)
        canvas_height = int(orig_h * r)

        r = torch.rand(2)
        left = int((canvas_width - orig_w) * r[0])
        top = int((canvas_height - orig_h) * r[1])
        right = canvas_width - (left + orig_w)
        bottom = canvas_height - (top + orig_h)

        if torch.jit.is_scripting():
            fill = 0
        else:
            fill = self._get_fill_value(F._is_pil_image(image))

        image = F.pad(image, [left, top, right, bottom], fill=fill)
        if isinstance(image, torch.Tensor):
            # F.pad only accepts a scalar fill for tensors, so paint the padded
            # border with the full fill colour afterwards.
            v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
            image[..., :top, :] = v
            image[..., :, :left] = v
            image[..., (top + orig_h):, :] = v
            image[..., :, (left + orig_w):] = v

        if target is not None:
            # Boxes shift by the top-left padding offset.
            target["boxes"][:, 0::2] += left
            target["boxes"][:, 1::2] += top

        return image, target
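For reference, the zoom-out geometry above can be reproduced standalone; side_range and the image size here are illustrative values, not taken from the source.

import torch

side_range = (1.0, 4.0)
orig_h, orig_w = 480, 640
r = side_range[0] + torch.rand(1) * (side_range[1] - side_range[0])
canvas_w, canvas_h = int(orig_w * r), int(orig_h * r)
left = int((canvas_w - orig_w) * torch.rand(1))
top = int((canvas_h - orig_h) * torch.rand(1))
right = canvas_w - (left + orig_w)
bottom = canvas_h - (top + orig_h)   # pad widths passed to F.pad above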
Example #3
    def __call__(self, img):
        """Call function of PBATransformer.

        :param img: input image
        :type img: numpy or tensor
        :return: the image after transform
        :rtype: numpy or tensor
        """
        count = np.random.choice([0, 1, 2], p=[0.2, 0.3, 0.5])  # how many ops to apply
        policys = list(self.policys)  # copy so shuffling does not mutate the stored policies
        np.random.shuffle(policys)
        whether_cutout = [0, 0]
        for policy in policys:
            if count == 0:
                break
            if len(policy) != 3:
                raise ValueError(
                    'Illegal policy; each policy must be an (op, prob, mag) tuple.')
            op, prob, mag = policy
            if np.random.random() > prob:
                continue
            count -= 1
            if op == "Cutout":
                # Defer Cutout: remember up to two magnitudes and apply them last.
                if whether_cutout[0] == 0:
                    whether_cutout[0] = mag
                else:
                    whether_cutout[1] = mag
                continue
            operation = ClassFactory.get_cls(ClassType.TRANSFORM, op)
            current_operation = operation(mag)
            img = current_operation(img)

        from torchvision.transforms import functional as F
        if F._is_pil_image(img):
            img = F.to_tensor(img)

        # Always apply a small Cutout, then any magnitudes deferred above.
        img = Cutout(8)(img)
        for i in whether_cutout:
            if i:
                img = Cutout(i)(img)
        return img
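A hedged sketch of the inputs this transform expects; the op names and the constructor signature below are assumptions, not taken from the source.

# Each policy is an (op, prob, mag) tuple; op must be registered with ClassFactory.
policys = [("Rotate", 0.3, 5), ("Cutout", 0.5, 8), ("Color", 0.7, 3)]  # hypothetical ops
transformer = PBATransformer(policys)   # assumed constructor; adjust to your codebase
augmented = transformer(img)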
Example #4
    def __call__(self, img, target):
        """
        Args:
            img (PIL Image): Image to be Perspectively transformed.

        Returns:
            PIL Image: Random perspectivley transformed image.
        """
        if not F._is_pil_image(img):
            raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
        if self.custom != None:
            target = self.custom(target)

        if random.random() < self.p:
            width, height = img.size
            startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
            if isinstance(target, Image) and target.size == img.size:
                target = F.perspective(target, startpoints, endpoints, self.interpolation)
            return F.perspective(img, startpoints, endpoints, self.interpolation), target
        return img, target
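A minimal usage sketch, assuming the class (here called RandomPerspectiveWithTarget, a hypothetical name) takes p, distortion_scale, and an optional custom callback.

from PIL import Image

img = Image.open("scene.jpg")        # illustrative paths
mask = Image.open("mask.png")        # same size as img, so it is warped identically
transform = RandomPerspectiveWithTarget(p=0.5, distortion_scale=0.5)  # hypothetical name
img, mask = transform(img, mask)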
Example #5
from collections.abc import Iterable

from PIL import Image
from torchvision.transforms.functional import _is_pil_image  # private torchvision helper


def resize(img, label, size, interpolation=Image.BILINEAR):
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if isinstance(size, int):
        w, h = img.size
        if (w <= h and w == size) or (h <= w and h == size):
            return img, label
        if w < h:
            ow = size
            oh = int(size * h / w)
            return img.resize((ow, oh), interpolation), label.resize((ow, oh), interpolation)
        else:
            oh = size
            ow = int(size * w / h)
            return img.resize((ow, oh), interpolation), label.resize((ow, oh), interpolation)
    else:
        # A 2-tuple size is (h, w); PIL's resize expects (w, h).
        return img.resize(size[::-1], interpolation), label.resize(size[::-1], interpolation)
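Usage sketch for resize above: an int resizes the shorter side and keeps the aspect ratio; a 2-tuple is an exact (h, w) target. Paths are illustrative.

from PIL import Image

img = Image.open("scene.jpg")
label = Image.open("label.png")
img, label = resize(img, label, 512)          # shorter side -> 512
img, label = resize(img, label, (480, 640))   # exact (h, w)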
Example #6
    def invert(self, img):
        r"""Invert the input PIL Image.
        Args:
            img (PIL Image): Image to be inverted.
        Returns:
            PIL Image: Inverted image.
        """
        if not F._is_pil_image(img):
            raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

        if img.mode == 'RGBA':
            r, g, b, a = img.split()
            rgb = Image.merge('RGB', (r, g, b))
            inv = ImageOps.invert(rgb)
            r, g, b = inv.split()
            inv = Image.merge('RGBA', (r, g, b, a))
        elif img.mode == 'LA':
            l, a = img.split()
            l = ImageOps.invert(l)
            inv = Image.merge('LA', (l, a))
        else:
            inv = ImageOps.invert(img)
        return inv
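A small self-check for invert above; Inverter is a hypothetical name for the class holding this method.

from PIL import Image

rgba = Image.new("RGBA", (4, 4), (10, 20, 30, 128))
out = Inverter().invert(rgba)                        # hypothetical class name
assert out.getpixel((0, 0)) == (245, 235, 225, 128)  # RGB inverted, alpha untouched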
Example #7
    def __call__(self, image, target):
        """
        Args:
            img (PIL Image): Image to be distort transformed.

        Returns:
            PIL Image: distort transformed image.
        """
        if not F._is_pil_image(image):
            raise TypeError('img should be PIL Image. Got {}'.format(
                type(image)))

        image = np.array(image)

        # The two branches differ only in whether contrast is adjusted first or
        # last; self.convert mutates the array in place (its return is unused).
        if random.randrange(2):

            # brightness distortion
            if random.randrange(2):
                self.convert(image, beta=random.uniform(-32, 32))

            # contrast distortion
            if random.randrange(2):
                self.convert(image, alpha=random.uniform(0.5, 1.5))

            image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

            # saturation distortion
            if random.randrange(2):
                self.convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))

            # hue distortion
            if random.randrange(2):
                tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
                tmp %= 180
                image[:, :, 0] = tmp

            image = cv2.cvtColor(image, cv2.COLOR_HSV2RGB)

        else:

            # brightness distortion
            if random.randrange(2):
                self.convert(image, beta=random.uniform(-32, 32))

            image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

            # saturation distortion
            if random.randrange(2):
                self.convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))

            # hue distortion
            if random.randrange(2):
                tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
                tmp %= 180
                image[:, :, 0] = tmp

            image = cv2.cvtColor(image, cv2.COLOR_HSV2RGB)

            # contrast distortion
            if random.randrange(2):
                self.convert(image, alpha=random.uniform(0.5, 1.5))

        image = Image.fromarray(image)
        return image, target
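The hue arithmetic above relies on OpenCV storing uint8 hue in [0, 180); a standalone sketch of the wrap-around:

import numpy as np

hue = np.array([175, 5], dtype=np.uint8)
shifted = (hue.astype(int) + 10) % 180   # -> array([5, 15]); wraps past 180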
Example #8
    def __call__(self, img):
        assert TVF._is_pil_image(img)

        # Crop the largest centered square (side length = shorter image side).
        return TVF.center_crop(img, min(img.size))
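Usage sketch; SquareCenterCrop is a hypothetical name for the class holding this __call__.

from PIL import Image

img = Image.new("RGB", (640, 480))
square = SquareCenterCrop()(img)   # hypothetical class name
assert square.size == (480, 480)   # largest centered square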
Example #9
    def __call__(self, example):
        assert isinstance(example, dict)
        assert 'image' in example
        # assert 'coordinates' in example
        # assert isinstance(example['coordinates'], np.ndarray)
        # assert example['coordinates'].ndim == 3
        # assert example['coordinates'].shape[0] == 2
        # assert example['coordinates'].shape[1] == 4

        image = example['image']

        if not F._is_pil_image(image):
            raise TypeError('image should be PIL Image. Got {}'.format(
                type(image)))
        if not (isinstance(self._output_size, int) or
                (isinstance(self._output_size, Iterable) and
                 len(self._output_size) == 2)):
            raise TypeError('Got inappropriate self._output_size '
                            'arg: {}'.format(self._output_size))

        input_width, input_height = image.size
        if isinstance(self._output_size, int):
            # If the shorter side already matches and both sides are multiples
            # of pconfig.SDR, there is nothing to do.
            if ((input_width <= input_height and input_width == self._output_size) or
                    (input_height <= input_width and input_height == self._output_size)):
                if not (input_height % pconfig.SDR or input_width % pconfig.SDR):
                    return example
            if input_width < input_height:
                output_width = self._output_size
                output_height = int(self._output_size * input_height / input_width)
                # Pad the long side up to the next multiple of pconfig.SDR.
                padded_height = math.ceil(output_height / pconfig.SDR) * pconfig.SDR
                image = image.resize((output_width, output_height),
                                     self._interpolation)
                image = F.pad(image,
                              padding=(0, 0, 0, padded_height - output_height),
                              fill=0,
                              padding_mode='constant')
            else:
                output_height = self._output_size
                output_width = int(self._output_size * input_width / input_height)
                padded_width = math.ceil(output_width / pconfig.SDR) * pconfig.SDR
                image = image.resize((output_width, output_height),
                                     self._interpolation)
                image = F.pad(image,
                              padding=(0, 0, padded_width - output_width, 0),
                              fill=0,
                              padding_mode='constant')
        else:
            # A 2-tuple output size is (h, w); PIL's resize expects (w, h).
            output_width, output_height = self._output_size[::-1]
            image = image.resize(self._output_size[::-1], self._interpolation)
        example['image'] = image
        if 'coordinates' in example:
            # Rescale x and y coordinates by the same factors as the image.
            example['coordinates'][0, ...] = (
                example['coordinates'][0, ...] * (output_width / input_width))
            example['coordinates'][1, ...] = (
                example['coordinates'][1, ...] * (output_height / input_height))
        # Keep the auxiliary maps at half the (possibly padded) image resolution.
        for key in ('region_map', 'affinity_map', 'confidence_map'):
            if (key in example and
                    (example[key].shape[0] != int(image.size[1] / 2) or
                     example[key].shape[1] != int(image.size[0] / 2))):
                example[key] = cv2.resize(
                    example[key],
                    (int(image.size[0] / 2), int(image.size[1] / 2)))
        return example
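The pad-to-multiple computation above, standalone; SDR = 32 is an illustrative stride, the real value comes from pconfig.SDR.

import math

SDR = 32
output_height = 417
padded_height = math.ceil(output_height / SDR) * SDR   # -> 448
pad_bottom = padded_height - output_height             # -> 31, passed to F.pad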
Example #10
def transpose(image):
    if not F._is_pil_image(image):
        raise TypeError('image should be PIL Image. Got {}'.format(type(image)))

    return image.transpose(Image.TRANSPOSE)
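Usage sketch: Image.TRANSPOSE reflects across the main diagonal, so width and height swap.

from PIL import Image

img = Image.new("RGB", (640, 480))
assert transpose(img).size == (480, 640)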