Example #1
def apply_perspective(input: torch.Tensor,
                      params: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Perform perspective transform of the given torch.Tensor or batch of tensors.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
        params (Dict[str, torch.Tensor]):
            - params['batch_prob']: A boolean tensor indicating whether to transform each image in the batch.
            - params['start_points']: Tensor containing [top-left, top-right, bottom-right,
              bottom-left] of the original image with shape Bx4x2.
            - params['end_points']: Tensor containing [top-left, top-right, bottom-right,
              bottom-left] of the transformed image with shape Bx4x2.
            - params['interpolation']: Integer tensor. NEAREST = 0, BILINEAR = 1.
            - params['align_corners']: Boolean tensor.

    Returns:
        torch.Tensor: Perspectively transformed tensor.
    """

    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    # arrange input data
    x_data: torch.Tensor = input.view(-1, *input.shape[-3:])

    _, _, height, width = x_data.shape

    # compute the homography between the input points
    transform: torch.Tensor = compute_perspective_transformation(input, params)

    out_data: torch.Tensor = x_data.clone()

    # process valid samples
    mask: torch.Tensor = params['batch_prob'].to(input.device)

    # TODO: look for a workaround for this hack. In CUDA it fails when no elements found.
    # TODO: this if statement is awkward and sum here is not the proper way to check
    # validity. In addition, 'interpolation' shouldn't be a reason to enter the branch.

    if bool(mask.sum() > 0) and ('interpolation' in params):
        # apply the computed transform
        height, width = x_data.shape[-2:]
        resample_name: str = Resample(
            params['interpolation'].item()).name.lower()
        align_corners: bool = cast(bool, params['align_corners'].item())

        out_data[mask] = warp_perspective(x_data[mask],
                                          transform[mask], (height, width),
                                          flags=resample_name,
                                          align_corners=align_corners)

    return out_data.view_as(input)
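
A minimal usage sketch for this variant (an illustration, not part of the library): it assumes apply_perspective and its helpers are in scope, and the corner coordinates are invented for the example. The dict keys mirror the docstring above.

import torch

B, C, H, W = 2, 3, 32, 32
img = torch.rand(B, C, H, W)

# Corners as [top-left, top-right, bottom-right, bottom-left], shape Bx4x2.
start_points = torch.tensor(
    [[[0., 0.], [W - 1., 0.], [W - 1., H - 1.], [0., H - 1.]]]).repeat(B, 1, 1)
end_points = start_points.clone()
end_points[:, :2, 0] += torch.tensor([4., -4.])  # pinch the top edge inwards

params = {
    'batch_prob': torch.tensor([True, False]),  # warp only the first sample
    'start_points': start_points,
    'end_points': end_points,
    'interpolation': torch.tensor(1),           # BILINEAR
    'align_corners': torch.tensor(False),
}
out = apply_perspective(img, params)            # same shape as img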
Example #2
def apply_perspective(input: torch.Tensor,
                      params: Dict[str, torch.Tensor],
                      return_transform: bool = False) -> UnionType:
    r"""Perform perspective transform of the given torch.Tensor or batch of tensors.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape (H, W), (C, H, W), (*, C, H, W).
        params (Dict[str, torch.Tensor]):
            - params['batch_prob']: A boolean tensor indicating whether to transform each image in the batch.
            - params['start_points']: Tensor containing [top-left, top-right, bottom-right,
              bottom-left] of the original image with shape Bx4x2.
            - params['end_points']: Tensor containing [top-left, top-right, bottom-right,
              bottom-left] of the transformed image with shape Bx4x2.
            - params['interpolation']: Integer tensor. NEAREST = 0, BILINEAR = 1.
        return_transform (bool): if ``True``, also return the matrix describing the
            transformation applied to each image. Default: False.

    Returns:
        torch.Tensor: Perspectively transformed tensor, plus the Bx3x3 transformation
        matrices when ``return_transform`` is ``True``.
    """

    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    # arrange input data
    x_data: torch.Tensor = input.view(-1, *input.shape[-3:])

    _, _, height, width = x_data.shape

    # compute the homography between the input points
    transform: torch.Tensor = get_perspective_transform(
        params['start_points'], params['end_points']).type_as(input)

    out_data: torch.Tensor = x_data.clone()

    # process valid samples
    mask = params['batch_prob'].to(input.device)

    # TODO: look for a workaround for this hack. In CUDA it fails when no elements found.

    if bool(mask.sum() > 0):
        # apply the computed transform
        height, width = x_data.shape[-2:]
        resample_name = Resample(params['interpolation'].item()).name.lower()
        out_data[mask] = warp_perspective(x_data[mask],
                                          transform[mask], (height, width),
                                          flags=resample_name)

    if return_transform:
        return out_data.view_as(input), transform

    return out_data.view_as(input)
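
Because this variant accepts return_transform, it can also hand back the Bx3x3 homographies, which makes the warp easy to undo. A hedged sketch reusing img and params from the sketch under Example #1:

out, trans = apply_perspective(img, params, return_transform=True)
assert trans.shape == (img.shape[0], 3, 3)

# Approximately undo the warp with the inverse homographies.
restored = warp_perspective(out, torch.inverse(trans),
                            (img.shape[-2], img.shape[-1]), flags="bilinear")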
Example #3
def apply_perspective(input: torch.Tensor,
                      params: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Perform perspective transform of the given torch.Tensor or batch of tensors.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape (H, W), (C, H, W), (*, C, H, W).
        params (Dict[str, torch.Tensor]):
            - params['batch_prob']: A boolean tensor indicating whether to transform each image in the batch.
            - params['start_points']: Tensor containing [top-left, top-right, bottom-right,
              bottom-left] of the original image with shape Bx4x2.
            - params['end_points']: Tensor containing [top-left, top-right, bottom-right,
              bottom-left] of the transformed image with shape Bx4x2.
            - params['interpolation']: Integer tensor. NEAREST = 0, BILINEAR = 1.
            - params['align_corners']: Boolean tensor.

    Returns:
        torch.Tensor: Perspectively transformed tensor.
    """

    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    # arrange input data
    x_data: torch.Tensor = input.view(-1, *input.shape[-3:])

    _, _, height, width = x_data.shape

    # compute the homography between the input points
    transform = compute_perspective_transformation(input, params)

    out_data: torch.Tensor = x_data.clone()

    # process valid samples
    mask = params['batch_prob'].to(input.device)

    # TODO: look for a workaround for this hack. In CUDA it fails when no elements found.
    # TODO: this if statement is awkward and sum here is not the proper way to check
    # validity. In addition, 'interpolation' shouldn't be a reason to enter the branch.

    if bool(mask.sum() > 0) and ('interpolation' in params):
        # apply the computed transform
        height, width = x_data.shape[-2:]
        resample_name = Resample(params['interpolation'].item()).name.lower()
        out_data[mask] = warp_perspective(
            x_data[mask],
            transform[mask],
            (height, width),
            flags=resample_name,
            align_corners=cast(bool, params['align_corners'].item()))
    return out_data.view_as(input)
Example #4
def apply_perspective(input: torch.Tensor, params: Dict[str, torch.Tensor],
                      flags: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Perform perspective transform of the given torch.Tensor or batch of tensors.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
        params (Dict[str, torch.Tensor]):
            - params['start_points']: Tensor containing [top-left, top-right, bottom-right,
              bottom-left] of the original image with shape Bx4x2.
            - params['end_points']: Tensor containing [top-left, top-right, bottom-right,
              bottom-left] of the transformed image with shape Bx4x2.
        flags (Dict[str, torch.Tensor]):
            - flags['interpolation']: Integer tensor. NEAREST = 0, BILINEAR = 1.
            - flags['align_corners']: Boolean tensor.

    Returns:
        torch.Tensor: Perspectively transformed tensor.
    """
    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    _, _, height, width = input.shape

    # compute the homography between the input points
    transform: torch.Tensor = compute_perspective_transformation(input, params)

    out_data: torch.Tensor = input.clone()

    # apply the computed transform
    height, width = input.shape[-2:]
    resample_name: str = Resample(flags['interpolation'].item()).name.lower()
    align_corners: bool = cast(bool, flags['align_corners'].item())

    out_data = warp_perspective(input,
                                transform, (height, width),
                                flags=resample_name,
                                align_corners=align_corners)

    return out_data.view_as(input)
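
This variant separates the point correspondences (params) from the static sampling options (flags) and warps every sample unconditionally. A hedged call-site sketch, reusing the tensors from the sketch under Example #1:

params = {'start_points': start_points, 'end_points': end_points}
flags = {
    'interpolation': torch.tensor(1),    # BILINEAR
    'align_corners': torch.tensor(False),
}
out = apply_perspective(img, params, flags)  # no batch_prob gating in this variant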
Example #5
File: export.py Project: wx-b/SOLD2
def homography_adaptation(input_images, model, grid_size, homography_cfg):
    """ The homography adaptation process.
    Arguments:
        input_images: The images to be evaluated.
        model: The pytorch model in evaluation mode.
        grid_size: Grid size of the junction decoder.
        homography_cfg: Homography adaptation configurations.
    """
    # Get the device of the current model
    device = next(model.parameters()).device

    # Define some constants and placeholder
    batch_size, _, H, W = input_images.shape
    num_iter = homography_cfg["num_iter"]
    junc_probs = torch.zeros([batch_size, num_iter, H, W], device=device)
    junc_counts = torch.zeros([batch_size, 1, H, W], device=device)
    heatmap_probs = torch.zeros([batch_size, num_iter, H, W], device=device)
    heatmap_counts = torch.zeros([batch_size, 1, H, W], device=device)
    margin = homography_cfg["valid_border_margin"]

    # Keep a config with no artifacts
    homography_cfg_no_artifacts = copy.copy(homography_cfg["homographies"])
    homography_cfg_no_artifacts["allow_artifacts"] = False

    for idx in range(num_iter):
        if idx <= num_iter // 5:
            # Ensure that at least 20% of the homographies have no artifact
            H_mat_lst = [sample_homography(
                [H,W], **homography_cfg_no_artifacts)[0][None]
                         for _ in range(batch_size)]
        else:
            H_mat_lst = [sample_homography(
                [H,W], **homography_cfg["homographies"])[0][None]
                         for _ in range(batch_size)]

        H_mats = np.concatenate(H_mat_lst, axis=0)
        H_tensor = torch.tensor(H_mats, dtype=torch.float, device=device)
        H_inv_tensor = torch.inverse(H_tensor)

        # Perform the homography warp
        images_warped = warp_perspective(input_images, H_tensor, (H, W),
                                         flags="bilinear")
        
        # Warp the mask
        masks_junc_warped = warp_perspective(
            torch.ones([batch_size, 1, H, W], device=device),
            H_tensor, (H, W), flags="nearest")
        masks_heatmap_warped = warp_perspective(
            torch.ones([batch_size, 1, H, W], device=device),
            H_tensor, (H, W), flags="nearest")

        # Run the network forward pass
        with torch.no_grad():
            outputs = model(images_warped)
        
        # Unwarp and mask the junction prediction
        junc_prob_warped = pixel_shuffle(softmax(
            outputs["junctions"], dim=1)[:, :-1, :, :], grid_size)
        junc_prob = warp_perspective(junc_prob_warped, H_inv_tensor,
                                     (H, W), flags="bilinear")

        # Create the out of boundary mask
        out_boundary_mask = warp_perspective(
            torch.ones([batch_size, 1, H, W], device=device),
            H_inv_tensor, (H, W), flags="nearest")
        out_boundary_mask = adjust_border(out_boundary_mask, device, margin)

        junc_prob = junc_prob * out_boundary_mask
        junc_count = warp_perspective(masks_junc_warped * out_boundary_mask,
                                      H_inv_tensor, (H, W), flags="nearest")

        # Unwarp the mask and heatmap prediction
        # Always fetch only one channel
        if outputs["heatmap"].shape[1] == 2:
            # Convert to single channel directly from here
            heatmap_prob_warped = softmax(outputs["heatmap"],
                                          dim=1)[:, 1:, :, :]
        else:
            heatmap_prob_warped = torch.sigmoid(outputs["heatmap"])
        
        heatmap_prob_warped = heatmap_prob_warped * masks_heatmap_warped
        heatmap_prob = warp_perspective(heatmap_prob_warped, H_inv_tensor,
                                        (H, W), flags="bilinear")
        heatmap_count = warp_perspective(masks_heatmap_warped, H_inv_tensor,
                                         (H, W), flags="nearest")

        # Record the results
        junc_probs[:, idx:idx+1, :, :] = junc_prob
        heatmap_probs[:, idx:idx+1, :, :] = heatmap_prob
        junc_counts += junc_count
        heatmap_counts += heatmap_count

    # Perform the accumulation operation
    if homography_cfg["min_counts"] > 0:
        min_counts = homography_cfg["min_counts"]
        junc_count_mask = (junc_counts < min_counts)
        heatmap_count_mask = (heatmap_counts < min_counts)
        junc_counts[junc_count_mask] = 0
        heatmap_counts[heatmap_count_mask] = 0
    else:
        junc_count_mask = torch.zeros_like(junc_counts, dtype=torch.bool)
        heatmap_count_mask = torch.zeros_like(heatmap_counts, dtype=torch.bool)
    
    # Compute the mean accumulation
    junc_probs_mean = torch.sum(junc_probs, dim=1, keepdim=True) / junc_counts
    junc_probs_mean[junc_count_mask] = 0.
    heatmap_probs_mean = (torch.sum(heatmap_probs, dim=1, keepdim=True)
                          / heatmap_counts)
    heatmap_probs_mean[heatmap_count_mask] = 0.

    # Compute the max accumulation
    junc_probs_max = torch.max(junc_probs, dim=1, keepdim=True)[0]
    junc_probs_max[junc_count_mask] = 0.
    heatmap_probs_max = torch.max(heatmap_probs, dim=1, keepdim=True)[0]
    heatmap_probs_max[heatmap_count_mask] = 0.

    return {"junc_probs_mean": junc_probs_mean,
            "junc_probs_max": junc_probs_max,
            "junc_counts": junc_counts,
            "heatmap_probs_mean": heatmap_probs_mean,
            "heatmap_probs_max": heatmap_probs_max,
            "heatmap_counts": heatmap_counts}