Example #1
def apply_affine(input: torch.Tensor,
                 params: Dict[str, torch.Tensor],
                 return_transform: bool = False) -> UnionType:
    r"""Random affine transformation of the image keeping center invariant.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape (H, W), (C, H, W), (*, C, H, W).
        params (Dict[str, torch.Tensor]):
            - params['angle']: Degrees of rotation.
            - params['translations']: Horizontal and vertical translations.
            - params['center']: Rotation center.
            - params['scale']: Scaling params.
            - params['sx']: Shear parameter along the x-axis, in degrees.
            - params['sy']: Shear parameter along the y-axis, in degrees.
            - params['resample']: Integer tensor; can be retrieved from Resample. 0 is NEAREST, 1 is BILINEAR.
        return_transform (bool): if ``True``, also return the matrix describing the transformation
            applied to each input. Default: False.
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    # arrange input data
    x_data: torch.Tensor = input.view(-1, *input.shape[-3:])

    height, width = x_data.shape[-2:]

    # compose the affine transformation matrix from the sampled parameters
    transform = get_affine_matrix2d(params['translations'], params['center'],
                                    params['scale'], params['angle'],
                                    deg2rad(params['sx']),
                                    deg2rad(params['sy'])).type_as(input)

    resample_name = Resample(params['resample'].item()).name.lower()

    out_data: torch.Tensor = warp_affine(x_data, transform[:, :2, :],
                                         (height, width), resample_name)

    if return_transform:
        return out_data.view_as(input), transform

    return out_data.view_as(input)
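
The params dict above mirrors the arguments of get_affine_matrix2d. A minimal usage sketch follows; the tensor shapes are an assumption based on kornia's conventions around version 0.3/0.4 (in particular, the shape of 'scale' varies between releases), not something the excerpt itself pins down.

import torch

B, C, H, W = 2, 3, 32, 32
img = torch.rand(B, C, H, W)

params = {
    'translations': torch.zeros(B, 2),             # (dx, dy) in pixels
    'center': torch.tensor([[W / 2, H / 2]] * B),  # rotate about the image center
    'scale': torch.ones(B),                        # shape is version-dependent
    'angle': torch.full((B,), 15.0),               # degrees
    'sx': torch.zeros(B),                          # x-shear, degrees
    'sy': torch.zeros(B),                          # y-shear, degrees
    'resample': torch.tensor(1),                   # 1 = BILINEAR
}

out = apply_affine(img, params)                              # same shape as img
out, mat = apply_affine(img, params, return_transform=True)  # mat: (B, 3, 3)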
Example #2
    def forward(self, inp, target, save_dir=None, debug=True):
        bs, _, h, w = inp.shape
        # Binary mask marking the patch region at the image center.
        masks = make_center_mask(self.image_dim, self.patch_size, bs)
        if debug:
            assert (masks.max() <= 1) and (masks.min() >= 0)
        # Select the patch for each target in the batch.
        patches = self.patches[target]

        if self.apply_transforms:
            # Randomly transform the masks, then warp the patches with the
            # same affine matrices so patch and mask stay aligned.
            masks, tx_fn = self.aff_transformer(masks)
            patches = warp_affine(patches, tx_fn[:, :2, :], dsize=(h, w))

        # Composite: paste the (warped) patch wherever the mask is 1.
        inp = (masks * patches) + (1 - masks) * inp
        return inp
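
make_center_mask is not shown in this excerpt. A plausible sketch of what it might do, purely an assumption for illustration (a (bs, 1, image_dim, image_dim) mask that is 1 inside a centered patch_size square and 0 elsewhere):

import torch

def make_center_mask(image_dim: int, patch_size: int, bs: int) -> torch.Tensor:
    # Hypothetical helper: binary mask, 1 inside a centered square patch.
    mask = torch.zeros(bs, 1, image_dim, image_dim)
    lo = (image_dim - patch_size) // 2
    mask[:, :, lo:lo + patch_size, lo:lo + patch_size] = 1.0
    return mask

Under that assumption, (masks * patches) + (1 - masks) * inp pastes each warped patch into its image exactly where the identically warped mask is still 1.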
Example #3
    def forward(self, inp, target, save_dir=None):
        inp = inp.detach()
        # Shape variables
        bs, _, h, w = inp.shape
        gpu_num = inp.device.index
        max_bs = self.batch_size // self.num_gpus + 1
        ind = ch.zeros_like(target) if self.textures.shape[0] == 1 else target
        start, end = (gpu_num * max_bs, gpu_num * max_bs + bs)
        # Get texture to render: copy a channels-last view of this GPU's
        # slice into shared memory and enqueue the indices for the workers.
        tex_to_render = self.textures[ind, ...].permute(0, 2, 3, 1)
        tex_to_render = tex_to_render.reshape(bs, -1).detach().cpu()
        self.texture_sh_mem[start:end] = tex_to_render
        for i in range(start, end):
            self.in_q.put(i)
        if self.debug:
            a = time.time()
        # Spin until the workers mark every item in this slice as done.
        while not ch.all(self.dones_sh_mem[start:end]):
            pass
        if self.debug:
            print(f"Time spent waiting: {time.time() - a:.2f}s")
        self.dones_sh_mem[start:end] = False

        # Differentiable rendering (used for the backward pass)
        uv_data = self.uv_map_sh_mem[start:end].to(inp.device)
        tex_to_render = self.textures[ind,...]
        renders, tx_mat = self.diff_render_with_tex(tex_to_render, uv_data, inp)
        renders = ch.clamp(renders, 0, 1)

        # Real rendering: fetch the workers' RGBA render, warp it into the
        # input frame, and alpha-composite it over the input image.
        if self.render_forward:
            real_render = self.render_sh_mem[start:end].to(inp.device).detach()
            real_render = real_render.permute(0, 3, 1, 2).contiguous()
            real_render = warp_affine(real_render, tx_mat[:, :2, :], dsize=(h, w))
            real_render_alpha = real_render[:, 3, ...][:, None, ...]  # alpha channel as (B, 1, H, W)
            real_render = real_render[:, :3] * real_render_alpha + inp * (1 - real_render_alpha)
            real_render = ch.clamp(real_render, 0, 1)
        else:
            real_render = renders

        if (save_dir is not None) and (inp.device.index == 0):
            vis_tools.show_image_row([real_render[:10].cpu()])
            plt.savefig(str(save_dir / "real_render_batch.png"))
            plt.close()
            if self.corruptions is not None:
                vis_tools.show_image_row([self.corruptions(real_render)[:10].cpu()])
                plt.savefig(str(save_dir / "real_render_batch_wc.png"))
                plt.close()
            vis_tools.show_image_row([renders[:10].cpu()])
            plt.savefig(str(save_dir / "diff_render_batch.png"))
            plt.close()
            self.output_vis(save_dir)

        # Straight-through estimator: the forward value equals real_render,
        # while gradients flow through the differentiable renders.
        res = renders - renders.detach() + real_render
        if self.corruptions is not None: 
            return self.corruptions(res)
        return res
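
The line res = renders - renders.detach() + real_render is the straight-through trick: in the forward pass the subtraction cancels and res equals real_render (the non-differentiable render), while in the backward pass gradients flow through the differentiable renders. A toy demonstration:

import torch

x = torch.tensor(2.0, requires_grad=True)
soft = x ** 2             # differentiable surrogate (plays the role of renders)
hard = torch.tensor(5.0)  # non-differentiable value (plays the role of real_render)

out = soft - soft.detach() + hard  # forward value == hard
out.backward()
print(out.item(), x.grad.item())   # 5.0 4.0 -> gradient of the surrogate, d(x**2)/dx = 2x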
Example #4
def apply_affine(input: torch.Tensor, params: Dict[str, torch.Tensor],
                 flags: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Random affine transformation of the image keeping center invariant.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
        params (Dict[str, torch.Tensor]):
            - params['angle']: Degrees of rotation.
            - params['translations']: Horizontal and vertical translations.
            - params['center']: Rotation center.
            - params['scale']: Scaling params.
            - params['sx']: Shear parameter along the x-axis.
            - params['sy']: Shear parameter along the y-axis.
        flags (Dict[str, torch.Tensor]):
            - flags['resample']: Integer tensor. NEAREST = 0, BILINEAR = 1.
            - flags['padding_mode']: Integer tensor, see the SamplePadding enum.
            - flags['align_corners']: Boolean tensor.

    Returns:
        torch.Tensor: The transformed input.
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    # arrange input data
    x_data: torch.Tensor = input.view(-1, *input.shape[-3:])

    height, width = x_data.shape[-2:]

    # compute the affine transformation matrix
    transform: torch.Tensor = compute_affine_transformation(input, params)

    resample_name: str = Resample(flags['resample'].item()).name.lower()
    padding_mode: str = SamplePadding(
        flags['padding_mode'].item()).name.lower()
    align_corners: bool = cast(bool, flags['align_corners'].item())

    out_data: torch.Tensor = warp_affine(x_data,
                                         transform[:, :2, :], (height, width),
                                         resample_name,
                                         align_corners=align_corners,
                                         padding_mode=padding_mode)
    return out_data.view_as(input)
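
A minimal call sketch for this variant; it assumes the Resample and SamplePadding enums come from kornia.constants (the docstring itself only guarantees NEAREST = 0 and BILINEAR = 1), and reuses an img tensor and params dict shaped as in the first example:

import torch
from kornia.constants import Resample, SamplePadding

flags = {
    'resample': torch.tensor(Resample.BILINEAR.value),
    'padding_mode': torch.tensor(SamplePadding.ZEROS.value),
    'align_corners': torch.tensor(True),
}

out = apply_affine(img, params, flags)  # out.shape == img.shape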