Exemplo n.º 1
0
def complex_center_crop(data_list, shape, offset=1, contiguous=False):
    """
    Apply a center crop to the input data, or to a list of complex images

    Parameters
    ----------
    data_list : List[torch.Tensor] or torch.Tensor
        The input tensor, or list of tensors, to be center cropped. ``assert_same_shape`` is called on the inputs,
        so all tensors are expected to share one shape. The crop is applied along dimensions ``offset`` up to
        ``offset + len(shape) - 1``; all other dimensions are kept in full.
    shape : Tuple[int, ...]
        The output size along the cropped dimensions. Each value should not exceed the corresponding dimension
        of the data. A falsy entry (``None``, ``0`` or ``False``) keeps that dimension at its full size.
    offset : int
        Index of the first dimension to crop. Negative values count from the end, as in Python indexing.
    contiguous : bool
        Return as a contiguous array. Useful for fast reshaping or viewing.

    Returns
    -------
    torch.Tensor or list[torch.Tensor]: The center cropped input. A single tensor is returned when exactly one
    tensor results, otherwise a list.

    Raises
    ------
    ValueError
        If a requested crop size is larger than the data along a cropped dimension.
    """
    data_list = ensure_list(data_list)
    assert_same_shape(data_list)

    image_shape = list(data_list[0].shape)
    ndim = data_list[0].ndim
    # bbox layout passed to crop_to_bbox: the first ndim entries are start indices, the last ndim
    # entries are sizes. It starts out covering the whole tensor.
    bbox = [0] * ndim + image_shape

    # bbox has 2 * ndim entries, so indexing it with a negative dimension would address the size half
    # instead of the start half (silently swapping start and size). Use an absolute dimension index.
    if offset < 0:
        offset += ndim

    # Allow for falsy values (None, 0, False) in crop directions: keep the full size there.
    shape = [size if size else image_shape[idx + offset] for idx, size in enumerate(shape)]
    for idx, size in enumerate(shape):
        dim = idx + offset
        bbox[dim] = (image_shape[dim] - size) // 2
        bbox[ndim + dim] = size

    # A negative start means the crop is larger than the data along that dimension.
    if not all(start >= 0 for start in bbox[:ndim]):
        raise ValueError(
            f"Bounding box requested has negative values, "
            f"this is likely to data size being smaller than the crop size. Got {bbox} with image_shape {image_shape} "
            f"and requested shape {shape}.")

    output = [crop_to_bbox(data, bbox) for data in data_list]

    if contiguous:
        output = [_.contiguous() for _ in output]

    if len(output) == 1:  # Only one element:
        output = output[0]
    return output
Exemplo n.º 2
0
def complex_center_crop_previous(data_list, shape, didx=-3, contiguous=False):
    """
    Apply a center crop to the input data, or to a list of complex images

    Parameters
    ----------
    data_list : List[torch.Tensor] or torch.Tensor
        The complex input tensor(s) to be center cropped. The cropping is applied along dimensions ``didx`` and
        ``didx + 1``. With ``didx=-3`` the last dimension is kept untouched (presumably the real/imaginary axis of
        size 2; this is not checked here).
    shape : Tuple[int, int]
        The output shape. Both values must be positive and not larger than the corresponding dimensions of data.
    didx : int
        Starting dimension for cropping; must be -3 or -2.
    contiguous : bool
        Return as a contiguous array. Useful for fast reshaping or viewing.

    Returns
    -------
    torch.Tensor or list[torch.Tensor]: The center cropped input. A single tensor is returned when the input
    holds exactly one tensor, otherwise a list.

    Raises
    ------
    AssertionError
        If ``didx`` is not -3 or -2, or ``shape`` does not fit inside a tensor. Note these checks are
        ``assert`` statements and are skipped when Python runs with ``-O``.
    """
    # TODO(jt): We can use crop_to_bbox here.
    data_list = ensure_list(data_list)
    # didx does not depend on the tensors, so validate it once instead of once per tensor.
    assert didx in [
        -3, -2
    ], "Cropping needs to be done in the spatial dimensions."
    for data in data_list:
        assert 0 < shape[0] <= data.shape[didx]
        assert 0 < shape[1] <= data.shape[didx + 1]

    # NOTE(review): the crop window is computed from the first tensor only; the other tensors are just checked
    # to be large enough, so they are only centered when all tensors share the same spatial size.
    w_from = (data_list[0].shape[didx] - shape[0]) // 2
    h_from = (data_list[0].shape[didx + 1] - shape[1]) // 2
    w_to = w_from + shape[0]
    h_to = h_from + shape[1]
    if didx == -3:
        # Keep the trailing dimension as is.
        output = [data[..., w_from:w_to, h_from:h_to, :] for data in data_list]
    else:
        output = [data[..., w_from:w_to, h_from:h_to] for data in data_list]

    if contiguous:
        output = [_.contiguous() for _ in output]

    if len(output) == 1:  # Only one element:
        output = output[0]
    return output
Exemplo n.º 3
0
def complex_random_crop(data_list, shape, contiguous=False):
    """
    Randomly crop a complex tensor, or several complex tensors, to ``shape``.

    A single crop window is drawn uniformly at random and applied to every tensor, along dimensions -3 and -2;
    the last dimension is kept untouched (presumably the real/imaginary axis of size 2 — not checked here).

    Parameters
    ----------
    data_list : Union[List[torch.Tensor], torch.Tensor]
        Tensor or list of tensors to crop. Each must have at least 3 dimensions.
    shape : Tuple[int, int]
        Size of the crop along dimensions -3 and -2. Both values must be positive and fit inside every tensor.
    contiguous : bool
        Return as a contiguous array. Useful for fast reshaping or viewing.

    Returns
    -------
    torch.Tensor or list[torch.Tensor]: The randomly cropped tensor, or a list when more than one tensor is given.
    """
    tensors = ensure_list(data_list)

    # TODO: Check if all have same shape. Only the first tensor's size is used to draw the window below;
    # the remaining tensors are merely verified to be large enough.
    for tensor in tensors:
        assert 0 < shape[0] <= tensor.shape[-3]
        assert 0 < shape[1] <= tensor.shape[-2]

    reference = tensors[0].shape
    row_start = np.random.randint(0, reference[-3] - shape[0] + 1)
    col_start = np.random.randint(0, reference[-2] - shape[1] + 1)
    rows = slice(row_start, row_start + shape[0])
    cols = slice(col_start, col_start + shape[1])

    cropped = [tensor[..., rows, cols, :] for tensor in tensors]
    if contiguous:
        cropped = [tensor.contiguous() for tensor in cropped]

    return cropped[0] if len(cropped) == 1 else cropped
Exemplo n.º 4
0
def complex_random_crop(
    data_list,
    crop_shape,
    offset: int = 1,
    contiguous: bool = False,
    sampler: str = "uniform",
    sigma=None,
):
    """
    Apply a random crop to the input data tensor or a list of complex tensors.

    The same crop window is applied to every tensor in ``data_list``.

    Parameters
    ----------
    data_list : Union[List[torch.Tensor], torch.Tensor]
        The input tensor, or list of tensors, to be randomly cropped. ``assert_same_shape`` is called on the
        inputs, so all tensors are expected to share one shape. The crop is applied along dimensions ``offset``
        up to ``offset + len(crop_shape) - 1``; all other dimensions are kept in full.
    crop_shape : Tuple[int, ...]
        The output size along the cropped dimensions. Each value must not exceed the corresponding dimension
        of the data. A falsy entry (``None``, ``0`` or ``False``) keeps that dimension at its full size.
    offset : int
        Index of the first dimension to crop. Negative values count from the end, as in Python indexing.
    contiguous : bool
        Return as a contiguous array. Useful for fast reshaping or viewing.
    sampler : str
        Select the random indices from either a `uniform` or `gaussian` distribution (around the center).
    sigma : float or list of float, optional
        Standard deviation of the gaussian when sampler is `gaussian`. Either one value shared by all cropped
        dimensions or one value per cropped dimension. If not set (None, empty or a falsy scalar), 1/6 of the
        data size along each cropped dimension is used. Must be None for the `uniform` sampler.

    Returns
    -------
    torch.Tensor or list[torch.Tensor]: The randomly cropped input tensor, or a list of tensors when more than
    one tensor results.

    Raises
    ------
    ValueError
        If `sigma` is given with the `uniform` sampler, if the crop is larger than the data, if `sigma` has
        the wrong length, or if `sampler` is unknown.
    """
    if sampler == "uniform" and sigma is not None:
        raise ValueError(
            f"sampler `uniform` is incompatible with sigma {sigma}, has to be None."
        )

    data_list = ensure_list(data_list)
    assert_same_shape(data_list)

    image_shape = list(data_list[0].shape)

    ndim = data_list[0].ndim
    # bbox layout passed to crop_to_bbox: the first ndim entries are start indices, the last ndim
    # entries are sizes. It starts out covering the whole tensor.
    bbox = [0] * ndim + image_shape

    # bbox has 2 * ndim entries, so a negative dimension index would address the size half instead of the
    # start half, and a slice such as image_shape[-2:0] would be empty. Use an absolute dimension index.
    if offset < 0:
        offset += ndim

    # Falsy entries (None, 0, False) keep the full size along that dimension.
    crop_shape = np.asarray(
        [size if size else image_shape[idx + offset] for idx, size in enumerate(crop_shape)]
    )

    # Largest valid start index along each cropped dimension.
    limits = np.asarray(
        [image_shape[offset + idx] - crop_shape[idx] for idx in range(len(crop_shape))]
    )

    if not all(limit >= 0 for limit in limits):
        raise ValueError(
            f"Bounding box limits have negative values, "
            f"this is likely to data size being smaller than the crop size. Got {limits}"
        )

    if sampler == "uniform":
        lower_point = np.random.randint(0, limits + 1).tolist()
    elif sampler == "gaussian":
        data_shape = np.asarray(image_shape[offset:offset + len(crop_shape)])
        # Treat None, an empty sequence or a falsy scalar as "not set" (same as the former `if not sigma`),
        # without evaluating the truth value of a multi-element array, which would raise.
        if sigma is None or np.size(sigma) == 0 or (np.ndim(sigma) == 0 and not sigma):
            sigma = data_shape / 6  # w, h
        # Accept a scalar as well as a sequence: len() on a plain float would raise TypeError.
        scale = np.atleast_1d(np.asarray(sigma, dtype=float))
        if len(scale) != 1 and len(scale) != len(crop_shape):
            raise ValueError(
                f"Either one sigma has to be set or same as the length of the bounding box. Got {sigma}."
            )
        # Sample the crop center around the data center, then shift to the lower corner of the crop.
        center = np.random.normal(loc=data_shape / 2, scale=scale, size=len(data_shape))
        lower_point = (center - crop_shape / 2).astype(int)
        # Keep the crop inside the data; convert to Python ints like the uniform branch.
        lower_point = np.clip(lower_point, 0, limits).tolist()
    else:
        raise ValueError(
            f"Sampler is either `uniform` or `gaussian`. Got {sampler}.")

    for idx, size in enumerate(crop_shape):
        bbox[offset + idx] = lower_point[idx]
        bbox[offset + ndim + idx] = size

    output = [crop_to_bbox(data, bbox) for data in data_list]

    if contiguous:
        output = [_.contiguous() for _ in output]

    if len(output) == 1:
        return output[0]

    return output