Exemple #1
0
 def forward(self, batch_shape: torch.Size, same_on_batch: bool = False) -> Dict[str, Tensor]:  # type:ignore
     """Sample the colour-jitter parameters for one batch.

     Args:
         batch_shape: shape of the input batch; only its first entry
             (the batch size) is read.
         same_on_batch: forwarded to the parameter check and to the
             sampling helper.

     Returns:
         A dict with ``brightness_factor``, ``contrast_factor``,
         ``hue_factor`` and ``saturation_factor`` (each requested with
         shape ``(batch_size,)``) plus ``order``, a ``long`` tensor built
         from ``self.randperm(4)``.
     """
     num_items = batch_shape[0]
     _common_param_check(num_items, same_on_batch)
     device, dtype = _extract_device_dtype(
         [self.brightness, self.contrast, self.hue, self.saturation])

     def _draw(sampler):
         # One request of shape (batch_size,) from the given sampler.
         return _adapted_rsampling((num_items, ), sampler, same_on_batch)

     # The draw order below fixes how the random stream is consumed.
     raw = {}
     raw["brightness_factor"] = _draw(self.brightness_sampler)
     raw["contrast_factor"] = _draw(self.contrast_sampler)
     raw["hue_factor"] = _draw(self.hue_sampler)
     raw["saturation_factor"] = _draw(self.saturation_sampler)

     params = {
         key: value.to(device=device, dtype=dtype)
         for key, value in raw.items()
     }
     # The permutation is produced last, after all factors are drawn.
     params["order"] = self.randperm(4).to(device=device, dtype=dtype).long()
     return params
Exemple #2
0
    def forward(self, batch_shape: torch.Size, same_on_batch: bool = False) -> Dict[str, torch.Tensor]:  # type:ignore
        """Sample source and destination corner points for a perspective warp.

        Args:
            batch_shape: shape of the input batch; the first entry is the
                batch size and the last two entries are height and width.
            same_on_batch: forwarded to the parameter check and to the
                sampling helper.

        Returns:
            A dict with ``start_points`` (the four image corners, expanded
            to the batch size) and ``end_points`` (the corners plus a random
            offset).

        Raises:
            AssertionError: if height or width is not a positive ``int``.
            ValueError: if ``self.sampling_method`` is neither ``"basic"``
                nor ``"area_preserving"``.
        """
        batch_size = batch_shape[0]
        height = batch_shape[-2]
        width = batch_shape[-1]

        _device, _dtype = _extract_device_dtype([self.distortion_scale])
        _common_param_check(batch_size, same_on_batch)
        if not (type(height) is int and height > 0 and type(width) is int and width > 0):
            raise AssertionError(f"'height' and 'width' must be integers. Got {height}, {width}.")

        # Corners in (x, y) order: top-left, top-right, bottom-right, bottom-left.
        start_points: torch.Tensor = torch.tensor(
            [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]], device=_device, dtype=_dtype
        ).expand(batch_size, -1, -1)

        # generate random offset not larger than half of the image
        fx = self._distortion_scale * width / 2
        fy = self._distortion_scale * height / 2

        factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2).to(device=_device, dtype=_dtype)

        # TODO: This line somehow breaks the gradcheck
        rand_val: torch.Tensor = _adapted_rsampling(start_points.shape, self.rand_val_sampler, same_on_batch).to(
            device=_device, dtype=_dtype
        )
        if self.sampling_method == "basic":
            # Per-corner signs applied to the sampled offsets.
            pts_norm = torch.tensor([[[1, 1], [-1, 1], [-1, -1], [1, -1]]], device=_device, dtype=_dtype)
            offset = factor * rand_val * pts_norm
        elif self.sampling_method == "area_preserving":
            # Re-centre the samples around zero before scaling.
            offset = 2 * factor * (rand_val - 0.5)
        else:
            # Previously this fell through and failed later with an opaque
            # UnboundLocalError on `offset`.
            raise ValueError(
                f"Unsupported sampling_method {self.sampling_method!r}; "
                "expected 'basic' or 'area_preserving'."
            )

        end_points = start_points + offset

        return dict(start_points=start_points, end_points=end_points)
Exemple #3
0
    def forward(self, batch_shape: torch.Size, same_on_batch: bool = False) -> Dict[str, torch.Tensor]:
        """Sample one index per batch element from ``self.pl_idx_dist``.

        Args:
            batch_shape: shape of the input batch; only its first entry
                (the batch size) is read.
            same_on_batch: forwarded to the parameter check and to the
                sampling helper.

        Returns:
            A dict with a single key ``idx`` holding the samples cast to
            ``long``.
        """
        num_items = batch_shape[0]
        _common_param_check(num_items, same_on_batch)

        sampled = _adapted_rsampling((num_items,), self.pl_idx_dist, same_on_batch)
        indices = sampled.long()
        return {"idx": indices}
Exemple #4
0
    def forward(self, batch_shape: torch.Size, same_on_batch: bool = False) -> Dict[str, torch.Tensor]:  # type:ignore
        """Sample kernel size, angles and direction for one batch.

        Args:
            batch_shape: shape of the input batch; only its first entry
                (the batch size) is read.
            same_on_batch: forwarded to the parameter check and to the
                sampling helper.

        Returns:
            A dict with ``ksize_factor`` (``int32``), ``angle_factor``
            (yaw, pitch and roll stacked along dim 1) and
            ``direction_factor``.
        """
        num_items = batch_shape[0]
        _common_param_check(num_items, same_on_batch)
        device, dtype = _extract_device_dtype([self.angle, self.direction])

        def _draw(sampler):
            # One request of shape (batch_size,) from the given sampler.
            return _adapted_rsampling((num_items,), sampler, same_on_batch)

        # Draw order matters for reproducibility: yaw, pitch, roll, direction, ksize.
        yaw = _draw(self.yaw_sampler)
        pitch = _draw(self.pitch_sampler)
        roll = _draw(self.roll_sampler)
        angles = torch.stack([yaw, pitch, roll], dim=1)

        direction = _draw(self.direction_sampler)

        # 2 * k + 1 is always odd.
        half_ksize = _draw(self.ksize_sampler).int()
        ksize = half_ksize * 2 + 1

        return {
            "ksize_factor": ksize.to(device=device, dtype=torch.int32),
            "angle_factor": angles.to(device=device, dtype=dtype),
            "direction_factor": direction.to(device=device, dtype=dtype),
        }
Exemple #5
0
    def forward(self, batch_shape: torch.Size, same_on_batch: bool = False) -> Dict[str, torch.Tensor]:  # type:ignore
        """Sample yaw, pitch and roll values for one batch.

        Args:
            batch_shape: shape of the input batch; only its first entry
                (the batch size) is read.
            same_on_batch: forwarded to the parameter check and to the
                sampling helper.

        Returns:
            A dict with keys ``yaw``, ``pitch`` and ``roll``, each moved to
            the device/dtype derived from ``self.degrees``.
        """
        num_items = batch_shape[0]
        _common_param_check(num_items, same_on_batch)
        device, dtype = _extract_device_dtype([self.degrees])

        def _draw(sampler):
            # Sample, then immediately move to the target device/dtype.
            drawn = _adapted_rsampling((num_items, ), sampler, same_on_batch)
            return drawn.to(device=device, dtype=dtype)

        yaw = _draw(self.yaw_sampler)
        pitch = _draw(self.pitch_sampler)
        roll = _draw(self.roll_sampler)
        return {"yaw": yaw, "pitch": pitch, "roll": roll}
Exemple #6
0
    def forward(self, batch_shape: torch.Size, same_on_batch: bool = False) -> Dict[str, torch.Tensor]:  # type:ignore
        """Sample source and destination corner points for a 3D perspective warp.

        Args:
            batch_shape: shape of the input batch; the first entry is the
                batch size and the last three entries are depth, height and
                width.
            same_on_batch: forwarded to the parameter check and to the
                sampling helper.

        Returns:
            A dict with ``start_points`` (the eight volume corners, expanded
            to the batch size) and ``end_points`` (the corners plus a signed
            random offset).
        """
        batch_size = batch_shape[0]
        depth, height, width = batch_shape[-3], batch_shape[-2], batch_shape[-1]

        _common_param_check(batch_size, same_on_batch)
        device, dtype = _extract_device_dtype([self.distortion_scale])

        # Corners in (x, y, z) order: the z = 0 face first, then the far face.
        corners = [[
            [0.0, 0, 0],
            [width - 1, 0, 0],
            [width - 1, height - 1, 0],
            [0, height - 1, 0],
            [0.0, 0, depth - 1],
            [width - 1, 0, depth - 1],
            [width - 1, height - 1, depth - 1],
            [0, height - 1, depth - 1],
        ]]
        start_points = torch.tensor(corners, device=device, dtype=dtype)
        start_points = start_points.expand(batch_size, -1, -1)

        # Random offset is capped at half of the extent along each axis.
        half_extents = [
            self._distortion_scale * width / 2,
            self._distortion_scale * height / 2,
            self._distortion_scale * depth / 2,
        ]
        factor = torch.stack(half_extents, dim=0).view(-1, 1, 3)
        factor = factor.to(device=device, dtype=dtype)

        rand_val = _adapted_rsampling(start_points.shape, self.rand_sampler, same_on_batch)
        rand_val = rand_val.to(device=device, dtype=dtype)

        # Per-corner signs applied to the sampled offsets.
        signs = torch.tensor(
            [[
                [1, 1, 1],
                [-1, 1, 1],
                [-1, -1, 1],
                [1, -1, 1],
                [1, 1, -1],
                [-1, 1, -1],
                [-1, -1, -1],
                [1, -1, -1],
            ]],
            device=device,
            dtype=dtype,
        )
        end_points = start_points + factor * rand_val * signs

        return {"start_points": start_points, "end_points": end_points}
Exemple #7
0
    def forward(self, batch_shape: torch.Size, same_on_batch: bool = False) -> Dict[str, torch.Tensor]:  # type:ignore
        """Sample one tensor per entry of ``self.sampler_dict``.

        Args:
            batch_shape: shape of the input batch; only its first entry
                (the batch size) is read.
            same_on_batch: forwarded to the parameter check and to the
                sampling helper.

        Returns:
            A dict mapping every key of ``self.sampler_dict`` to a sample
            requested with shape ``(batch_size,)``, moved to the
            device/dtype derived from the first element of each
            ``self.samplers`` entry.
        """
        num_items = batch_shape[0]
        _common_param_check(num_items, same_on_batch)

        # Each entry of self.samplers is a 4-tuple; only its first element is used here.
        reference_tensors = [first for first, _, _, _ in self.samplers]
        device, dtype = _extract_device_dtype(reference_tensors)

        params = {}
        for name, dist in self.sampler_dict.items():
            drawn = _adapted_rsampling((num_items, ), dist, same_on_batch)
            params[name] = drawn.to(device=device, dtype=dtype)
        return params
Exemple #8
0
    def forward(self, batch_shape: torch.Size, same_on_batch: bool = False) -> Dict[str, torch.Tensor]:
        """Sample mixup pairings and mixing coefficients for one batch.

        Args:
            batch_shape: shape of the input batch; only its first entry
                (the batch size) is read.
            same_on_batch: forwarded to the parameter check and to the
                sampling helpers.

        Returns:
            A dict with ``mixup_pairs`` (a random permutation of the batch
            indices, ``long``) and ``mixup_lambdas`` (samples from
            ``self.lambda_sampler`` multiplied element-wise by the draw from
            ``self.prob_sampler``).
        """
        num_items = batch_shape[0]

        _common_param_check(num_items, same_on_batch)
        device, dtype = _extract_device_dtype([self.lambda_val])
        per_item = (num_items, )

        # The probability draw is taken without tracking gradients.
        with torch.no_grad():
            prob_draw = _adapted_sampling(per_item, self.prob_sampler, same_on_batch)

        permutation = torch.randperm(num_items, device=device, dtype=dtype).long()

        lambdas = _adapted_rsampling(per_item, self.lambda_sampler, same_on_batch)
        lambdas = lambdas * prob_draw

        return {
            "mixup_pairs": permutation.to(device=device, dtype=torch.long),
            "mixup_lambdas": lambdas.to(device=device, dtype=dtype),
        }
Exemple #9
0
    def forward(self, batch_shape: torch.Size, same_on_batch: bool = False) -> Dict[str, torch.Tensor]:
        """Sample the parameters of a random 3D affine transform.

        Args:
            batch_shape: shape of the input batch; the first entry is the
                batch size and the last three entries are depth, height and
                width.
            same_on_batch: forwarded to the sampling helper.

        Returns:
            A dict with ``translations``, ``center``, ``scale``, ``angles``
            and the six shear terms ``sxy``, ``sxz``, ``syx``, ``syz``,
            ``szx``, ``szy``, all converted to the device/dtype derived from
            the transform's own parameters.

        Raises:
            AssertionError: if depth, height or width is not a positive
                ``int``.
        """
        batch_size = batch_shape[0]
        depth, height, width = batch_shape[-3], batch_shape[-2], batch_shape[-1]

        dims_ok = (type(depth) is int and depth > 0 and type(height) is int
                   and height > 0 and type(width) is int and width > 0)
        if not dims_ok:
            raise AssertionError(
                f"'depth', 'height' and 'width' must be integers. Got {depth}, {height}, {width}."
            )

        device, dtype = _extract_device_dtype(
            [self.degrees, self.translate, self.scale, self.shears])

        def _draw(sampler):
            # One request of shape (batch_size,) from the given sampler.
            return _adapted_rsampling((batch_size, ), sampler, same_on_batch)

        # Rotation angles; the draw order fixes how the random stream is consumed.
        yaw = _draw(self.yaw_sampler)
        pitch = _draw(self.pitch_sampler)
        roll = _draw(self.roll_sampler)
        angles = torch.stack([yaw, pitch, roll], dim=1)

        # Per-axis scale, or all ones when no scale range is configured.
        if self._scale is None:
            scale = torch.ones(batch_size, device=device,
                               dtype=dtype).reshape(batch_size, 1).repeat(1, 3)
        else:
            scale_1 = _draw(self.scale_1_sampler)
            scale_2 = _draw(self.scale_2_sampler)
            scale_3 = _draw(self.scale_3_sampler)
            scale = torch.stack([scale_1, scale_2, scale_3], dim=1)

        # Translations in x, y, z; zero when no translate range is configured.
        if self._translate is None:
            translations = torch.zeros((batch_size, 3), device=device, dtype=dtype)
        else:
            max_dx = self._translate[0] * width
            max_dy = self._translate[1] * height
            max_dz = self._translate[2] * depth
            shift_x = (_draw(self.uniform_sampler) - 0.5) * max_dx * 2
            shift_y = (_draw(self.uniform_sampler) - 0.5) * max_dy * 2
            shift_z = (_draw(self.uniform_sampler) - 0.5) * max_dz * 2
            translations = torch.stack([shift_x, shift_y, shift_z], dim=1)

        # Centre of the volume in x, y, z.
        extents = torch.tensor([width, height, depth], device=device, dtype=dtype)
        center = extents.view(1, 3) / 2.0 - 0.5
        center = center.expand(batch_size, -1)

        if self.shears is None:
            # A single zero tensor is shared by all six shear terms.
            no_shear = torch.tensor([0] * batch_size, device=device, dtype=dtype)
            sxy = sxz = syx = syz = szx = szy = no_shear
        else:
            sxy = _draw(self.sxy_sampler)
            sxz = _draw(self.sxz_sampler)
            syx = _draw(self.syx_sampler)
            syz = _draw(self.syz_sampler)
            szx = _draw(self.szx_sampler)
            szy = _draw(self.szy_sampler)

        raw = {
            "translations": translations,
            "center": center,
            "scale": scale,
            "angles": angles,
            "sxy": sxy,
            "sxz": sxz,
            "syx": syx,
            "syz": syz,
            "szx": szx,
            "szy": szy,
        }
        return {
            key: torch.as_tensor(value, device=device, dtype=dtype)
            for key, value in raw.items()
        }
Exemple #10
0
    def forward(self,
                batch_shape: torch.Size,
                same_on_batch: bool = False) -> Dict[str, torch.Tensor]:
        """Sample the pairings and crop boxes for a cutmix-style augmentation.

        Args:
            batch_shape: shape of the input batch; the first entry is the
                batch size and the last two entries are height and width.
            same_on_batch: forwarded to the parameter check and to the
                sampling helpers; also changes how the box start positions
                are drawn (see below).

        Returns:
            A dict with ``mix_pairs`` (``long``; the argsort along dim 1 of a
            ``(num_mix, batch_size)`` random tensor) and ``crop_src`` (the
            generated boxes viewed as ``(num_mix, batch_size, 4, 2)``).

        Raises:
            AssertionError: if height or width is not a positive ``int``.
        """
        batch_size = batch_shape[0]
        height = batch_shape[-2]
        width = batch_shape[-1]

        if not (type(height) is int and height > 0 and type(width) is int
                and width > 0):
            raise AssertionError(
                f"'height' and 'width' must be integers. Got {height}, {width}."
            )
        _device, _dtype = _extract_device_dtype([self.beta, self.cut_size])
        _common_param_check(batch_size, same_on_batch)

        # Empty batch: return empty tensors without sampling anything.
        # NOTE(review): these shapes/dtypes ([0, 3] and long) differ from the
        # non-empty path ((num_mix, B) and (num_mix, B, 4, 2) in _dtype) - confirm intended.
        if batch_size == 0:
            return dict(
                mix_pairs=torch.zeros([0, 3], device=_device,
                                      dtype=torch.long),
                crop_src=torch.zeros([0, 4, 2],
                                     device=_device,
                                     dtype=torch.long),
            )

        # The probability draw is taken without tracking gradients.
        with torch.no_grad():
            batch_probs: torch.Tensor = _adapted_sampling(
                (batch_size * self.num_mix, ), self.prob_sampler,
                same_on_batch)
        # Argsort of uniform noise gives one random permutation per mix row.
        mix_pairs: torch.Tensor = torch.rand(self.num_mix,
                                             batch_size,
                                             device=_device,
                                             dtype=_dtype).argsort(dim=1)
        cutmix_betas: torch.Tensor = _adapted_rsampling(
            (batch_size * self.num_mix, ), self.beta_sampler, same_on_batch)

        # Note: torch.clamp does not accept tensor, cutmix_betas.clamp(cut_size[0], cut_size[1]) throws:
        # Argument 1 to "clamp" of "_TensorBase" has incompatible type "Tensor"; expected "float"
        cutmix_betas = torch.min(torch.max(cutmix_betas, self._cut_size[0]),
                                 self._cut_size[1])
        # Multiplying by batch_probs scales the rate by the probability draw.
        cutmix_rate = torch.sqrt(1.0 - cutmix_betas) * batch_probs

        cut_height = (cutmix_rate * height).floor().to(device=_device,
                                                       dtype=_dtype)
        cut_width = (cutmix_rate * width).floor().to(device=_device,
                                                     dtype=_dtype)
        # NOTE(review): when same_on_batch is False only a shape-(1,) draw is
        # requested below; whether that produces per-item start positions
        # depends on the batch shape of self.rand_sampler - confirm.
        _gen_shape = (1, )

        if same_on_batch:
            # Use the first cut size for the whole batch and request one
            # draw per (batch * num_mix) entry instead.
            _gen_shape = (cut_height.size(0), )
            cut_height = cut_height[0]
            cut_width = cut_width[0]

        # Reserve at least 1 pixel for cropping.
        x_start: torch.Tensor = _adapted_rsampling(
            _gen_shape, self.rand_sampler,
            same_on_batch) * (width - cut_width - 1)
        y_start: torch.Tensor = _adapted_rsampling(
            _gen_shape, self.rand_sampler,
            same_on_batch) * (height - cut_height - 1)
        x_start = x_start.floor().to(device=_device, dtype=_dtype)
        y_start = y_start.floor().to(device=_device, dtype=_dtype)

        crop_src = bbox_generator(x_start.squeeze(), y_start.squeeze(),
                                  cut_width, cut_height)

        # (B * num_mix, 4, 2) => (num_mix, batch_size, 4, 2)
        crop_src = crop_src.view(self.num_mix, batch_size, 4, 2)

        return dict(
            mix_pairs=mix_pairs.to(device=_device, dtype=torch.long),
            crop_src=crop_src.floor().to(device=_device, dtype=_dtype),
        )
Exemple #11
0
 def forward(self, batch_shape: torch.Size, same_on_batch: bool = False) -> Dict[str, torch.Tensor]:  # type:ignore
     """Sample the bit count for each batch element.

     Args:
         batch_shape: shape of the input batch; only its first entry
             (the batch size) is read.
         same_on_batch: forwarded to the parameter check and to the
             sampling helper.

     Returns:
         A dict with a single key ``bits_factor`` holding ``int32`` samples
         from ``self.bit_sampler``.
     """
     num_items = batch_shape[0]
     _common_param_check(num_items, same_on_batch)

     # Only a tensor-valued `bits` can contribute a device; otherwise pass None.
     device_source = self.bits if isinstance(self.bits, torch.Tensor) else None
     device, _ = _extract_device_dtype([device_source])

     sampled_bits = _adapted_rsampling((num_items,), self.bit_sampler, same_on_batch)
     return {"bits_factor": sampled_bits.to(device=device, dtype=torch.int32)}
Exemple #12
0
    def forward(self, batch_shape: torch.Size, same_on_batch: bool = False) -> Dict[str, torch.Tensor]:  # type:ignore
        """Sample rectangle sizes and positions for each batch element.

        Args:
            batch_shape: shape of the input batch; the first entry is the
                batch size and the last two entries are height and width.
            same_on_batch: forwarded to the parameter check and to the
                sampling helper.

        Returns:
            A dict with floored ``widths``, ``heights``, ``xs`` and ``ys``
            plus ``values``, a tensor holding ``self.value`` repeated once
            per batch element.

        Raises:
            AssertionError: if height or width is not a positive ``int``.
        """
        batch_size = batch_shape[0]
        height, width = batch_shape[-2], batch_shape[-1]
        dims_ok = (type(height) is int and height > 0 and type(width) is int
                   and width > 0)
        if not dims_ok:
            raise AssertionError(
                f"'height' and 'width' must be integers. Got {height}, {width}."
            )

        _common_param_check(batch_size, same_on_batch)
        device, dtype = _extract_device_dtype([self.ratio, self.scale])
        per_item = (batch_size, )

        def _draw(shape, sampler):
            return _adapted_rsampling(shape, sampler, same_on_batch)

        # Target area = sampled fraction of the full image area.
        images_area = height * width
        area_fraction = _draw(per_item, self.scale_sampler).to(device=device, dtype=dtype)
        target_areas = area_fraction * images_area

        if self.ratio[0] < 1.0 and self.ratio[1] > 1.0:
            # The ratio range straddles 1: draw from both samplers and pick
            # one of the two per element using a rounded index draw.
            ratios_a = _draw(per_item, self.ratio_sampler1)
            ratios_b = _draw(per_item, self.ratio_sampler2)
            if same_on_batch:
                single_pick = torch.round(_draw((1, ), self.index_sampler))
                use_first = single_pick.repeat(batch_size).bool()
            else:
                use_first = torch.round(_draw(per_item, self.index_sampler)).bool()
            aspect_ratios = torch.where(use_first, ratios_a, ratios_b)
        else:
            aspect_ratios = _draw(per_item, self.ratio_sampler)

        aspect_ratios = aspect_ratios.to(device=device, dtype=dtype)

        # Rectangle sides from area and ratio, clamped to [1, image side].
        one = torch.tensor(1.0, device=device, dtype=dtype)

        raw_heights = torch.round((target_areas * aspect_ratios) ** 0.5)
        max_height = torch.tensor(height, device=device, dtype=dtype)
        heights = torch.min(torch.max(raw_heights, one), max_height)

        raw_widths = torch.round((target_areas / aspect_ratios) ** 0.5)
        max_width = torch.tensor(width, device=device, dtype=dtype)
        widths = torch.min(torch.max(raw_widths, one), max_width)

        # Top-left corner: a uniform draw scaled by the remaining room.
        xs_ratio = _draw(per_item, self.uniform_sampler).to(device=device, dtype=dtype)
        ys_ratio = _draw(per_item, self.uniform_sampler).to(device=device, dtype=dtype)
        xs = xs_ratio * (width - widths + 1)
        ys = ys_ratio * (height - heights + 1)

        return {
            "widths": widths.floor(),
            "heights": heights.floor(),
            "xs": xs.floor(),
            "ys": ys.floor(),
            "values": torch.tensor([self.value] * batch_size, device=device, dtype=dtype),
        }
Exemple #13
0
    def forward(self, batch_shape: torch.Size, same_on_batch: bool = False) -> Dict[str, torch.Tensor]:  # type:ignore
        """Sample source and destination boxes for a random 3D crop.

        Args:
            batch_shape: five-entry shape unpacked as
                ``(batch_size, _, depth, height, width)``.
            same_on_batch: forwarded to the parameter check and to the
                sampling helper.

        Returns:
            A dict with ``src`` (crop boxes inside the input volume) and
            ``dst`` (the matching target boxes, sized either like the crop
            or like ``self.resize_to``).

        Raises:
            AssertionError: if ``size`` is not shaped ``(B, 3)``, if depth,
                height or width is not a positive ``int``, or if
                ``self.resize_to`` is set but is not three positive ints.
            ValueError: if the crop is larger than the input along any axis.
        """
        batch_size, _, depth, height, width = batch_shape
        _common_param_check(batch_size, same_on_batch)

        size_is_tensor = isinstance(self.size, torch.Tensor)
        device, dtype = _extract_device_dtype([self.size if size_is_tensor else None])

        # Normalise `size` to a (B, 3) tensor ordered (depth, height, width).
        if size_is_tensor:
            size = self.size.to(device=device, dtype=dtype)
        else:
            size = torch.tensor(self.size, device=device, dtype=dtype).repeat(batch_size, 1)

        expected_shape = torch.Size([batch_size, 3])
        if size.shape != expected_shape:
            raise AssertionError(
                "If `size` is a tensor, it must be shaped as (B, 3). "
                f"Got {size.shape} while expecting {expected_shape}."
            )

        dims = (depth, height, width)
        if not (all(isinstance(d, int) for d in dims) and all(d > 0 for d in dims)):
            raise AssertionError(
                f"`batch_shape` should not contain negative values. Got {batch_shape}."
            )

        # Number of valid start positions along each axis.
        x_diff = width - size[:, 2] + 1
        y_diff = height - size[:, 1] + 1
        z_diff = depth - size[:, 0] + 1

        too_large = (x_diff < 0).any() or (y_diff < 0).any() or (z_diff < 0).any()
        if too_large:
            raise ValueError(
                f"input_size {dims} cannot be smaller than crop size {str(size)} in any dimension."
            )

        if batch_size == 0:
            empty = {}
            empty["src"] = torch.zeros([0, 8, 3], device=device, dtype=dtype)
            empty["dst"] = torch.zeros([0, 8, 3], device=device, dtype=dtype)
            return empty

        def _draw_unit():
            drawn = _adapted_rsampling((batch_size, ), self.rand_sampler, same_on_batch)
            return drawn.to(device=device, dtype=dtype)

        def _zeros():
            return torch.tensor([0] * batch_size, device=device, dtype=dtype)

        # Draw all three axes first, then scale to valid start positions.
        x_unit = _draw_unit()
        y_unit = _draw_unit()
        z_unit = _draw_unit()

        x_start = (x_unit * x_diff).floor()
        y_start = (y_unit * y_diff).floor()
        z_start = (z_unit * z_diff).floor()

        crop_src = bbox_generator3d(
            x_start.view(-1),
            y_start.view(-1),
            z_start.view(-1),
            size[:, 2] - 1,
            size[:, 1] - 1,
            size[:, 0] - 1,
        )

        if self.resize_to is None:
            # No resize: the destination box has the crop's own extent, anchored at the origin.
            crop_dst = bbox_generator3d(
                _zeros(),
                _zeros(),
                _zeros(),
                size[:, 2] - 1,
                size[:, 1] - 1,
                size[:, 0] - 1,
            )
        else:
            target = self.resize_to
            target_ok = (len(target) == 3
                         and all(isinstance(v, int) for v in target)
                         and all(v > 0 for v in target))
            if not target_ok:
                raise AssertionError(
                    f"`resize_to` must be a tuple of 3 positive integers. Got {self.resize_to}."
                )
            # resize_to is ordered (depth, height, width); boxes are in (x, y, z).
            last_x = self.resize_to[-1] - 1
            last_y = self.resize_to[-2] - 1
            last_z = self.resize_to[-3] - 1
            crop_dst = torch.tensor(
                [[
                    [0, 0, 0],
                    [last_x, 0, 0],
                    [last_x, last_y, 0],
                    [0, last_y, 0],
                    [0, 0, last_z],
                    [last_x, 0, last_z],
                    [last_x, last_y, last_z],
                    [0, last_y, last_z],
                ]],
                device=device,
                dtype=dtype,
            ).repeat(batch_size, 1, 1)

        return {"src": crop_src.to(device=device), "dst": crop_dst.to(device=device)}
Exemple #14
0
    def forward(self, batch_shape: torch.Size, same_on_batch: bool = False) -> Dict[str, torch.Tensor]:  # type:ignore
        """Sample source and destination boxes for a random 2D crop.

        Args:
            batch_shape: shape of the input batch; the first entry is the
                batch size and the last two entries are height and width.
            same_on_batch: forwarded to the parameter check and to the
                sampling helper; when true, the first element's range is
                used to scale every start position.

        Returns:
            A dict with ``src`` and ``dst`` boxes plus ``input_size`` and
            ``output_size`` (both ``long``). For an empty batch only ``src``
            and ``dst`` are returned.

        Raises:
            AssertionError: if ``size`` is not shaped ``(B, 2)``, if the
                input size or ``size`` is not positive, or if
                ``self.resize_to`` is set but is not two positive ints.
        """
        batch_size = batch_shape[0]
        _common_param_check(batch_size, same_on_batch)

        size_is_tensor = isinstance(self.size, torch.Tensor)
        device, dtype = _extract_device_dtype([self.size if size_is_tensor else None])

        if batch_size == 0:
            empty = {}
            empty["src"] = torch.zeros([0, 4, 2], device=device, dtype=dtype)
            empty["dst"] = torch.zeros([0, 4, 2], device=device, dtype=dtype)
            return empty

        input_size = (batch_shape[-2], batch_shape[-1])

        # Normalise `size` to a (B, 2) tensor ordered (height, width).
        if size_is_tensor:
            size = self.size.to(device=device, dtype=dtype)
        else:
            size = torch.tensor(self.size, device=device, dtype=dtype).repeat(batch_size, 1)

        expected_shape = torch.Size([batch_size, 2])
        if size.shape != expected_shape:
            raise AssertionError(
                "If `size` is a tensor, it must be shaped as (B, 2). "
                f"Got {size.shape} while expecting {expected_shape}."
            )
        if not (input_size[0] > 0 and input_size[1] > 0 and (size > 0).all()):
            raise AssertionError(
                f"Got non-positive input size or size. {input_size}, {size}.")
        size = size.floor()

        # Valid start range per axis; clamped so start is 0 when the crop exceeds the input.
        x_diff = (input_size[1] - size[:, 1] + 1).clamp(0)
        y_diff = (input_size[0] - size[:, 0] + 1).clamp(0)

        # With same_on_batch, the first element's range scales all samples.
        x_range = x_diff[0] if same_on_batch else x_diff
        y_range = y_diff[0] if same_on_batch else y_diff

        x_unit = _adapted_rsampling((batch_size, ), self.rand_sampler, same_on_batch)
        x_start = (x_unit.to(x_diff) * x_range).floor()
        y_unit = _adapted_rsampling((batch_size, ), self.rand_sampler, same_on_batch)
        y_start = (y_unit.to(y_diff) * y_range).floor()

        # A floored size of 0 falls back to the full input extent on that axis.
        src_width = torch.where(
            size[:, 1] == 0,
            torch.tensor(input_size[1], device=device, dtype=dtype),
            size[:, 1],
        )
        src_height = torch.where(
            size[:, 0] == 0,
            torch.tensor(input_size[0], device=device, dtype=dtype),
            size[:, 0],
        )
        crop_src = bbox_generator(
            x_start.view(-1).to(device=device, dtype=dtype),
            y_start.view(-1).to(device=device, dtype=dtype),
            src_width,
            src_height,
        )

        if self.resize_to is None:
            # No resize: the destination box has the crop's own extent, anchored at the origin.
            origin_x = torch.tensor([0] * batch_size, device=device, dtype=dtype)
            origin_y = torch.tensor([0] * batch_size, device=device, dtype=dtype)
            crop_dst = bbox_generator(origin_x, origin_y, size[:, 1], size[:, 0])
            output_size = size.to(dtype=torch.long)
        else:
            target_ok = (len(self.resize_to) == 2
                         and isinstance(self.resize_to[0], int)
                         and isinstance(self.resize_to[1], int)
                         and self.resize_to[0] > 0
                         and self.resize_to[1] > 0)
            if not target_ok:
                raise AssertionError(
                    f"`resize_to` must be a tuple of 2 positive integers. Got {self.resize_to}."
                )
            # resize_to is ordered (height, width); boxes are in (x, y).
            last_x = self.resize_to[1] - 1
            last_y = self.resize_to[0] - 1
            crop_dst = torch.tensor(
                [[
                    [0, 0],
                    [last_x, 0],
                    [last_x, last_y],
                    [0, last_y],
                ]],
                device=device,
                dtype=dtype,
            ).repeat(batch_size, 1, 1)
            output_size = torch.tensor(self.resize_to, device=device, dtype=torch.long)
            output_size = output_size.expand(batch_size, -1)

        input_size_t = torch.tensor(input_size, device=device, dtype=torch.long)
        input_size_t = input_size_t.expand(batch_size, -1)

        return {
            "src": crop_src,
            "dst": crop_dst,
            "input_size": input_size_t,
            "output_size": output_size,
        }
Exemple #15
0
    def forward(
            self,
            batch_shape: torch.Size,
            same_on_batch: bool = False
    ) -> Dict[str, torch.Tensor]:  # type:ignore
        """Sample a crop size per batch element, then delegate to the parent.

        Ten (area, aspect-ratio) candidates are drawn per element and the
        first one that fits strictly inside the image is used. Elements with
        no fitting candidate fall back to a deterministic crop derived from
        the image's own height/width ratio and the allowed ratio range.

        Side effect: ``self.size`` is overwritten with the chosen
        ``(height, width)`` per element before ``super().forward`` is called.

        Args:
            batch_shape: shape of the input batch; the first entry is the
                batch size and the last two entries are height and width.
            same_on_batch: forwarded to the sampling helper and to the
                parent ``forward``.

        Returns:
            Whatever the parent ``forward`` returns for the chosen sizes.
            For an empty batch, a dict of empty ``src``, ``dst`` and
            ``size`` tensors is returned without calling the parent.
        """
        batch_size = batch_shape[0]
        size = (batch_shape[-2], batch_shape[-1])  # (height, width)
        _device, _dtype = _extract_device_dtype([self.scale, self.ratio])

        if batch_size == 0:
            return dict(
                src=torch.zeros([0, 4, 2], device=_device, dtype=_dtype),
                dst=torch.zeros([0, 4, 2], device=_device, dtype=_dtype),
                size=torch.zeros([0, 2], device=_device, dtype=_dtype),
            )

        # Ten candidate areas per element, as a scaled fraction of the image area.
        rand = _adapted_rsampling((batch_size, 10), self.rand_sampler,
                                  same_on_batch).to(device=_device,
                                                    dtype=_dtype)
        area = (rand * (self.scale[1] - self.scale[0]) +
                self.scale[0]) * size[0] * size[1]
        log_ratio = _adapted_rsampling((batch_size, 10),
                                       self.log_ratio_sampler,
                                       same_on_batch).to(device=_device,
                                                         dtype=_dtype)
        aspect_ratio = torch.exp(log_ratio)

        # The ratio convention in this method is height / width, so
        # sqrt(area * ratio) is the candidate *height*.
        cand_h = torch.sqrt(area * aspect_ratio).round().floor()
        cand_w = torch.sqrt(area / aspect_ratio).round().floor()
        # Element-wise fit condition: strictly inside the image.
        cond = ((0 < cand_h) * (cand_h < size[0]) * (0 < cand_w) *
                (cand_w < size[1])).int()

        # torch.argmax is not reproducible across devices: https://github.com/pytorch/pytorch/issues/17738
        # Here, we will select the first occurrence of the duplicated elements.
        cond_bool, argmax_dim1 = ((cond.cumsum(1) == 1) & cond.bool()).max(1)
        batch_idx = torch.arange(0, batch_size, device=_device, dtype=torch.long)
        h_out = cand_h[batch_idx, argmax_dim1]
        w_out = cand_w[batch_idx, argmax_dim1]

        if not cond_bool.all():
            # Fallback for elements where none of the ten candidates fit.
            in_ratio = float(size[0]) / float(size[1])
            if isinstance(self.ratio, torch.Tensor):
                _min, _max = self.ratio.min(), self.ratio.max()
            else:
                _min, _max = min(self.ratio), max(self.ratio)
            if in_ratio < _min:  # type:ignore
                # Image is flatter than allowed: keep full height, shrink width.
                h_ct = torch.tensor(size[0], device=_device, dtype=_dtype)
                w_ct = torch.round(h_ct / _min)
            elif in_ratio > _max:  # type:ignore
                # Image is taller than allowed: keep full width, shrink height.
                # (Previously compared against and scaled by the *minimum*
                # ratio, which made the whole-image branch unreachable for
                # any ratio inside the allowed range.)
                w_ct = torch.tensor(size[1], device=_device, dtype=_dtype)
                h_ct = torch.round(w_ct * _max)
            else:  # whole image
                h_ct = torch.tensor(size[0], device=_device, dtype=_dtype)
                w_ct = torch.tensor(size[1], device=_device, dtype=_dtype)
            h_ct = h_ct.floor()
            w_ct = w_ct.floor()

            # Keep sampled sizes where a candidate fit; use the fallback elsewhere.
            h_out = h_out.where(cond_bool, h_ct)
            w_out = w_out.where(cond_bool, w_ct)

        # Update the crop size.
        self.size = torch.stack([h_out, w_out], dim=1)
        return super().forward(batch_shape, same_on_batch)
Exemple #16
0
    def forward(
        self,
        batch_shape: torch.Size,
        same_on_batch: bool = False
    ) -> Dict[str, torch.Tensor]:  # type: ignore
        """Draw the per-sample parameters of a random affine transform.

        Args:
            batch_shape: shape of the input batch; index 0 is the batch size and
                the last two entries are read as height and width.
            same_on_batch: passed to the parameter check and to every sampler.

        Returns:
            A dict with the keys ``translations``, ``center``, ``scale``,
            ``angle``, ``sx`` and ``sy``.

        Raises:
            AssertionError: if width or height is not a positive ``int``.
        """
        batch_size, height, width = batch_shape[0], batch_shape[-2], batch_shape[-1]

        _device, _dtype = _extract_device_dtype(
            [self.degrees, self.translate, self.scale, self.shear])
        _common_param_check(batch_size, same_on_batch)
        if (not isinstance(width, int) or not isinstance(height, int)
                or width <= 0 or height <= 0):
            raise AssertionError(
                f"`width` and `height` must be positive integers. Got {width}, {height}."
            )

        target = dict(device=_device, dtype=_dtype)

        def _draw(sampler):
            # One value per batch element from the given sampler.
            return _adapted_rsampling((batch_size, ), sampler, same_on_batch)

        angle = _draw(self.degree_sampler).to(**target)

        # Scale: identity when no sampler is configured; otherwise one draw is
        # duplicated into both columns and the second column is optionally
        # redrawn from its own sampler.
        if self.scale_2_sampler is None:
            _scale = torch.ones((batch_size, 2), **target)
        else:
            _scale = _draw(self.scale_2_sampler).unsqueeze(1).repeat(1, 2)
            if self.scale_4_sampler is not None:
                _scale[:, 1] = _draw(self.scale_4_sampler)
            _scale = _scale.to(**target)

        # Translations: sampled fractions multiplied by width / height, or zero
        # when either sampler is missing.
        if self.translate_x_sampler is None or self.translate_y_sampler is None:
            translations = torch.zeros((batch_size, 2), **target)
        else:
            shift_x = _draw(self.translate_x_sampler) * width
            shift_y = _draw(self.translate_y_sampler) * height
            translations = torch.stack([shift_x, shift_y], dim=-1).to(**target)

        # Center is (width, height) / 2 - 0.5, shared by every batch element.
        half_extent = torch.tensor([width, height], **target).view(1, 2) / 2.0
        center: torch.Tensor = (half_extent - 0.5).expand(batch_size, -1)

        # Shear: both samplers are required, otherwise both factors are zero
        # (and refer to the same tensor).
        if self.shear_x_sampler is None or self.shear_y_sampler is None:
            sx = sy = torch.tensor([0] * batch_size, **target)
        else:
            sx = _draw(self.shear_x_sampler)
            sy = _draw(self.shear_y_sampler)
            sx = sx.to(**target)
            sy = sy.to(**target)

        return {
            "translations": translations,
            "center": center,
            "scale": _scale,
            "angle": angle,
            "sx": sx,
            "sy": sy,
        }