def resize_keep_aspect(image, target_res, pad=False): """Resizes image to the target_res while keeping aspect ratio by cropping image: an 3d array with dims [channel, height, width] target_res: [height, width] pad: if True, will pad zeros instead of cropping to preserve aspect ratio """ im_res = image.shape[-2:] # finds the resolution needed for either dimension to have the target aspect # ratio, when the other is kept constant. If the image doesn't have the # target ratio, then one of these two will be larger, and the other smaller, # than the current image dimensions resized_res = (int(np.ceil(im_res[1] * target_res[0] / target_res[1])), int(np.ceil(im_res[0] * target_res[1] / target_res[0]))) # only pads smaller or crops larger dims, meaning that the resulting image # size will be the target aspect ratio after a single pad/crop to the # resized_res dimensions if pad: image = utils.pad_image(image, resized_res, pytorch=False) else: image = utils.crop_image(image, resized_res, pytorch=False) # switch to numpy channel dim convention, resize, switch back image = np.transpose(image, axes=(1, 2, 0)) image = resize(image, target_res, mode='reflect') return np.transpose(image, axes=(2, 0, 1))
def pad_crop_to_res(image, target_res): """Pads with 0 and crops as needed to force image to be target_res image: an array with dims [..., channel, height, width] target_res: [height, width] """ return utils.crop_image(utils.pad_image(image, target_res, pytorch=False), target_res, pytorch=False)
def propagation_ASM(u_in, feature_size, wavelength, z, linear_conv=True, padtype='zero', return_H=False, precomped_H=None, return_H_exp=False, precomped_H_exp=None, dtype=torch.float32): """Propagates the input field using the angular spectrum method Inputs ------ u_in: complex field of size (num_images, 1, height, width, 2) where the last two channels are real and imaginary values feature_size: (height, width) of individual holographic features in m wavelength: wavelength in m z: propagation distance linear_conv: if True, pad the input to obtain a linear convolution padtype: 'zero' to pad with zeros, 'median' to pad with median of u_in's amplitude return_H[_exp]: used for precomputing H or H_exp, ends the computation early and returns the desired variable precomped_H[_exp]: the precomputed value for H or H_exp dtype: torch dtype for computation at different precision Output ------ tensor of size (num_images, 1, height, width, 2) """ if linear_conv: # preprocess with padding for linear conv. input_resolution = u_in.size()[-3:-1] conv_size = [i * 2 for i in input_resolution] if padtype == 'zero': padval = 0 elif padtype == 'median': padval = torch.median(torch.pow((u_in**2).sum(-1), 0.5)) u_in = utils.pad_image(u_in, conv_size, padval=padval) if precomped_H is None and precomped_H_exp is None: # resolution of input field, should be: (num_images, num_channels, height, width, 2) field_resolution = u_in.size() # number of pixels num_y, num_x = field_resolution[2], field_resolution[3] # sampling inteval size dy, dx = feature_size # size of the field y, x = (dy * float(num_y), dx * float(num_x)) # frequency coordinates sampling fy = np.linspace(-1 / (2 * dy) + 0.5 / (2 * y), 1 / (2 * dy) - 0.5 / (2 * y), num_y) fx = np.linspace(-1 / (2 * dx) + 0.5 / (2 * x), 1 / (2 * dx) - 0.5 / (2 * x), num_x) # momentum/reciprocal space FX, FY = np.meshgrid(fx, fy) # transfer function in numpy (omit distance) HH = 2 * math.pi * np.sqrt(1 / wavelength**2 - (FX**2 + FY**2)) # create tensor & upload to device (GPU) H_exp = torch.tensor(HH, dtype=dtype).to(u_in.device) ### # here one may iterate over multiple distances, once H_exp is uploaded on GPU # reshape tensor and multiply H_exp = torch.reshape(H_exp, (1, 1, *H_exp.size())) # handle loading the precomputed H_exp value, or saving it for later runs elif precomped_H_exp is not None: H_exp = precomped_H_exp if precomped_H is None: # multiply by distance H_exp = torch.mul(H_exp, z) # band-limited ASM - Matsushima et al. (2009) fy_max = 1 / np.sqrt((2 * z * (1 / y))**2 + 1) / wavelength fx_max = 1 / np.sqrt((2 * z * (1 / x))**2 + 1) / wavelength H_filter = torch.tensor( ((np.abs(FX) < fx_max) & (np.abs(FY) < fy_max)).astype(np.uint8), dtype=dtype) # get real/img components H_real, H_imag = utils.polar_to_rect(H_filter.to(u_in.device), H_exp) H = torch.stack((H_real, H_imag), 4) H = utils.ifftshift(H) else: H = precomped_H # return for use later as precomputed inputs if return_H_exp: return H_exp if return_H: return H # angular spectrum U1 = torch.fft(utils.ifftshift(u_in), 2, True) # convolution of the system U2 = utils.mul_complex(H, U1) # Fourier transform of the convolution to the observation plane u_out = utils.fftshift(torch.ifft(U2, 2, True)) if linear_conv: return utils.crop_image(u_out, input_resolution) else: return u_out