def complex_center_crop(data_list, shape, offset=1, contiguous=False):
    """
    Apply a center crop to the input data, or to a list of complex images.

    Parameters
    ----------
    data_list : List[torch.Tensor] or torch.Tensor
        The complex input tensor(s) to be center cropped. All tensors must have the same shape. The crop is
        applied along dimensions ``offset, offset + 1, ..., offset + len(shape) - 1``. The last dimension is
        expected to have a size of 2 (this is not checked here).
    shape : Sequence[int]
        The output size along each cropped dimension. Each value must not exceed the corresponding dimension
        of the data. A falsy value (e.g. ``None``, ``0`` or ``False``) keeps the full size of that dimension.
    offset : int
        First dimension to crop. Negative values count from the end, as in regular Python indexing.
    contiguous : bool
        Return as a contiguous array. Useful for fast reshaping or viewing.

    Returns
    -------
    torch.Tensor or list[torch.Tensor]:
        The center cropped tensor when a single tensor (or a one-element list) is given, otherwise the list
        of center cropped tensors.

    Raises
    ------
    ValueError
        If a requested crop size is larger than the corresponding data dimension.
    IndexError
        If ``offset`` (or ``offset + len(shape) - 1``) does not refer to a dimension of the data.
    """
    data_list = ensure_list(data_list)
    assert_same_shape(data_list)

    image_shape = list(data_list[0].shape)
    ndim = data_list[0].ndim

    if offset < 0:
        # Negative offsets count from the end. Without this normalization `bbox[idx + offset]` below
        # would index the second (size) half of `bbox` and silently build a wrong bounding box.
        offset += ndim
        if offset < 0:
            raise IndexError(
                f"offset {offset - ndim} is out of range for data with {ndim} dimensions."
            )

    # bbox layout: the first `ndim` entries are start indices, the last `ndim` entries are sizes.
    # Dimensions that are not cropped keep start 0 and their full size.
    bbox = [0] * ndim + image_shape

    # Allow for False / None in crop directions: keep the full size of that dimension.
    shape = [size if size else image_shape[idx + offset] for idx, size in enumerate(shape)]

    for idx, size in enumerate(shape):
        dim = idx + offset
        # Floor division: when the margin is odd, the extra element is dropped at the end.
        bbox[dim] = (image_shape[dim] - size) // 2
        bbox[ndim + dim] = size

    # A negative start means the requested crop is larger than the data along that dimension.
    if any(start < 0 for start in bbox[:ndim]):
        raise ValueError(
            f"Bounding box requested has negative values, "
            f"this is likely to data size being smaller than the crop size. Got {bbox} with image_shape {image_shape} "
            f"and requested shape {shape}.")

    output = [crop_to_bbox(data, bbox) for data in data_list]
    if contiguous:
        output = [tensor.contiguous() for tensor in output]

    if len(output) == 1:  # Only one element: return it unwrapped.
        return output[0]
    return output
def complex_random_crop(
    data_list,
    crop_shape,
    offset: int = 1,
    contiguous: bool = False,
    sampler: str = "uniform",
    sigma: "float | list[float] | None" = None,
):
    """
    Apply a random crop to the input data tensor or a list of complex tensors.

    The same random bounding box is applied to every tensor in the list.

    Parameters
    ----------
    data_list : Union[List[torch.Tensor], torch.Tensor]
        The complex input tensor(s) to be randomly cropped. All tensors must have the same shape. The crop is
        applied along dimensions ``offset, offset + 1, ..., offset + len(crop_shape) - 1``. The last dimension
        is expected to have a size of 2 (this is not checked here).
    crop_shape : Tuple[int, ...]
        The output size along each cropped dimension. Each value must not exceed the corresponding dimension
        of the data. A falsy value (e.g. ``None`` or ``0``) keeps the full size of that dimension.
    offset : int
        First dimension to crop. Negative values count from the end, as in regular Python indexing.
    contiguous : bool
        Return as a contiguous array. Useful for fast reshaping or viewing.
    sampler : str
        How to select the crop start indices:

        * ``"uniform"``: each start index is drawn uniformly from ``[0, data_size - crop_size]`` (inclusive).
        * ``"gaussian"``: the crop center is drawn from a normal distribution around the middle of the data.
          The start index is that center minus half the crop size, truncated to int and clipped to
          ``[0, data_size - crop_size]``.

        Sampling uses NumPy's global random state (``np.random``).
    sigma : float or list of float, optional
        Standard deviation of the gaussian when ``sampler`` is ``"gaussian"``. Either a single value used for
        all cropped dimensions, or one value per cropped dimension. If ``None``, it defaults to one sixth of
        each cropped data dimension, so that three standard deviations span half of that dimension. Must be
        ``None`` when ``sampler`` is ``"uniform"``.

    Returns
    -------
    torch.Tensor or list[torch.Tensor]:
        The randomly cropped tensor when a single tensor (or a one-element list) is given, otherwise the
        list of cropped tensors.

    Raises
    ------
    ValueError
        If ``sigma`` is given with the uniform sampler, if ``sampler`` is unknown, if ``sigma`` has the wrong
        length, or if a crop size is larger than the corresponding data dimension.
    IndexError
        If ``offset`` (or ``offset + len(crop_shape) - 1``) does not refer to a dimension of the data.
    """
    if sampler == "uniform" and sigma is not None:
        raise ValueError(
            f"sampler `uniform` is incompatible with sigma {sigma}, has to be None."
        )

    data_list = ensure_list(data_list)
    assert_same_shape(data_list)

    image_shape = list(data_list[0].shape)
    ndim = data_list[0].ndim

    if offset < 0:
        # Negative offsets count from the end. Without this normalization `bbox[offset + idx]` below
        # would index the second (size) half of `bbox` and silently build a wrong bounding box.
        offset += ndim
        if offset < 0:
            raise IndexError(
                f"offset {offset - ndim} is out of range for data with {ndim} dimensions."
            )

    # bbox layout: the first `ndim` entries are start indices, the last `ndim` entries are sizes.
    # Dimensions that are not cropped keep start 0 and their full size.
    bbox = [0] * ndim + image_shape

    # Falsy entries keep the full size of that dimension.
    crop_shape = np.asarray(
        [size if size else image_shape[idx + offset] for idx, size in enumerate(crop_shape)]
    )

    # Largest admissible start index along each cropped dimension.
    limits = np.asarray(
        [image_shape[offset + idx] - size for idx, size in enumerate(crop_shape)]
    )
    if not all(limit >= 0 for limit in limits):
        raise ValueError(
            f"Bounding box limits have negative values, "
            f"this is likely to data size being smaller than the crop size. Got {limits}"
        )

    if sampler == "uniform":
        # randint's upper bound is exclusive, hence +1 so that `limits` itself can be drawn.
        lower_point = np.random.randint(0, limits + 1).tolist()
    elif sampler == "gaussian":
        data_shape = np.asarray(image_shape[offset:offset + len(crop_shape)])
        if sigma is None:
            # Default: three standard deviations span half of each cropped dimension.
            scale = data_shape / 6  # w, h
        else:
            # Accept a scalar or a sequence. `np.atleast_1d` makes `len` valid for a bare float (the
            # documented type, which previously raised TypeError). It also avoids the ambiguous truth value
            # of a multi-element array.
            scale = np.atleast_1d(np.asarray(sigma, dtype=float))
            if scale.ndim != 1 or len(scale) not in (1, len(crop_shape)):
                raise ValueError(
                    f"Either one sigma has to be set or same as the length of the bounding box. Got {sigma}."
                )
        center = np.random.normal(loc=data_shape / 2, scale=scale, size=len(data_shape))
        lower_point = (center - crop_shape / 2).astype(int)
        # Keep the crop fully inside the data.
        lower_point = np.clip(lower_point, 0, limits)
    else:
        raise ValueError(
            f"Sampler is either `uniform` or `gaussian`. Got {sampler}.")

    for idx, (start, size) in enumerate(zip(lower_point, crop_shape)):
        bbox[offset + idx] = start
        bbox[offset + ndim + idx] = size

    output = [crop_to_bbox(data, bbox) for data in data_list]
    if contiguous:
        output = [tensor.contiguous() for tensor in output]

    if len(output) == 1:  # Only one element: return it unwrapped.
        return output[0]
    return output