def resize_keep_aspect(image, target_res, pad=False):
    """Resizes image to the target_res while keeping aspect ratio by cropping

    image: an 3d array with dims [channel, height, width]
    target_res: [height, width]
    pad: if True, will pad zeros instead of cropping to preserve aspect ratio
    """
    im_res = image.shape[-2:]

    # finds the resolution needed for either dimension to have the target aspect
    # ratio, when the other is kept constant. If the image doesn't have the
    # target ratio, then one of these two will be larger, and the other smaller,
    # than the current image dimensions
    resized_res = (int(np.ceil(im_res[1] * target_res[0] / target_res[1])),
                   int(np.ceil(im_res[0] * target_res[1] / target_res[0])))

    # only pads smaller or crops larger dims, meaning that the resulting image
    # size will be the target aspect ratio after a single pad/crop to the
    # resized_res dimensions
    if pad:
        image = utils.pad_image(image, resized_res, pytorch=False)
    else:
        image = utils.crop_image(image, resized_res, pytorch=False)

    # switch to numpy channel dim convention, resize, switch back
    image = np.transpose(image, axes=(1, 2, 0))
    image = resize(image, target_res, mode='reflect')
    return np.transpose(image, axes=(2, 0, 1))
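Every example in this collection leans on utils.crop_image / utils.pad_image, which are never shown. Below is a minimal sketch of what a center-crop along the trailing [height, width] dims could look like, followed by a hypothetical call to resize_keep_aspect; the center_crop name and behavior are assumptions, not the real utils module, and the call assumes the module-level np / utils / resize imports that the function body implies.

import numpy as np
from skimage.transform import resize

def center_crop(image, target_res):
    # center-crop the last two dims to at most target_res (illustrative only,
    # a guess at what utils.crop_image does)
    h, w = image.shape[-2:]
    th, tw = min(target_res[0], h), min(target_res[1], w)
    top, left = (h - th) // 2, (w - tw) // 2
    return image[..., top:top + th, left:left + tw]

img = np.random.rand(3, 480, 640)
out = resize_keep_aspect(img, (256, 256))
print(out.shape)  # expected: (3, 256, 256)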
Example #2
def find_enemies(full_gray_np, game_phase=1) -> Tuple[List[Rectangle], Any]:
    # Assert
    assert_rectangle_shape(
        full_gray_np, screen_template,
        f'Shape of the image should be {screen_template.shape()}')
    assert_gray_img(full_gray_np)

    roi_gray_np = utils.crop_image(full_gray_np, enemy_segment_template)

    # phase 2 inverts the game's colors (night mode), so invert back before
    # thresholding
    if game_phase == 2:
        roi_gray_np = cv2.bitwise_not(roi_gray_np)

    _, mask = cv2.threshold(roi_gray_np, 150, 240, cv2.THRESH_BINARY_INV)

    kernel = np.ones((3, 10), np.uint8)
    dilation = cv2.dilate(mask, kernel)

    # note: the three-value return is the OpenCV 3.x findContours signature;
    # OpenCV 4.x returns only (contours, hierarchy)
    _, contours, _ = cv2.findContours(dilation, cv2.RETR_TREE,
                                      cv2.CHAIN_APPROX_NONE)

    to_return = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        if h < 8:  # discard contours too short to be an enemy
            continue
        to_return.append(Rectangle(x, y, x + w, y + h))

    return to_return, enemy_segment_template.shape()
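The detection above is a standard OpenCV pipeline. A self-contained sketch of the same threshold -> dilate -> findContours sequence on a synthetic frame (values are illustrative; the two-value findContours return assumes OpenCV 4.x):

import cv2
import numpy as np

# synthetic grayscale frame: light background, two dark "enemies"
frame = np.full((100, 300), 247, np.uint8)
frame[40:60, 50:70] = 83
frame[42:58, 180:210] = 83

# dark pixels become white in the mask (THRESH_BINARY_INV)
_, mask = cv2.threshold(frame, 150, 255, cv2.THRESH_BINARY_INV)

# a wide kernel merges horizontally adjacent blobs before contour extraction
dilation = cv2.dilate(mask, np.ones((3, 10), np.uint8))

contours, _ = cv2.findContours(dilation, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
boxes = [cv2.boundingRect(c) for c in contours
         if cv2.boundingRect(c)[3] >= 8]  # same h >= 8 filter as above
print(boxes)  # [(x, y, w, h), ...]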
Example #3
def compute_zernike_basis(num_polynomials, field_res, dtype=torch.float32, wo_piston=False):
    """Computes a set of Zernike basis function with resolution field_res

    num_polynomials: number of Zernike polynomials in this basis
    field_res: [height, width] in px, any list-like object
    dtype: torch dtype for computation at different precision
    """

    # size the zernike basis to avoid circular masking
    zernike_diam = int(np.ceil(np.sqrt(field_res[0]**2 + field_res[1]**2)))

    # create zernike functions

    if not wo_piston:
        zernike = zernikeArray(num_polynomials, zernike_diam)
    else:  # 200427 - exclude the piston term
        idxs = range(2, 2 + num_polynomials)
        zernike = zernikeArray(idxs, zernike_diam)

    zernike = utils.crop_image(zernike, field_res, pytorch=False)

    # convert to tensor and create phase
    zernike = torch.tensor(zernike, dtype=dtype, requires_grad=False)

    return zernike
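zernikeArray comes from the aotools package in codebases like this one; a hypothetical call, assuming aotools is installed (the import path is an assumption) and that utils.crop_image center-crops the trailing dims:

import torch
from aotools import zernikeArray  # assumed import path for zernikeArray

# 10 polynomials over a 256x256 field, piston excluded
basis = compute_zernike_basis(10, (256, 256), dtype=torch.float32,
                              wo_piston=True)
print(basis.shape)  # expected: torch.Size([10, 256, 256])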
Example #4
def pad_crop_to_res(image, target_res):
    """Pads with 0 and crops as needed to force image to be target_res

    image: an array with dims [..., channel, height, width]
    target_res: [height, width]
    """
    return utils.crop_image(utils.pad_image(image, target_res, pytorch=False),
                            target_res,
                            pytorch=False)
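A quick sanity check of pad_crop_to_res, assuming the utils helpers zero-pad and center-crop along the trailing [height, width] dims:

import numpy as np

# taller-but-narrower input: height is cropped, width is zero-padded
img = np.random.rand(3, 300, 100)
out = pad_crop_to_res(img, (200, 200))
print(out.shape)  # expected: (3, 200, 200)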
Example #5
def get_game_status(full_gray_np) -> str:
    # Assert
    assert_rectangle_shape(full_gray_np, screen_template, f'Shape of the image should be {screen_template.shape()}')
    assert_gray_img(full_gray_np)

    game_over_img_gray = utils.crop_image(full_gray_np, game_over_template)
    if _find_rectangle(game_over_img_gray, (1200, 1250)):
        return 'game_over'
    else:
        return 'playing'
Example #6
def get_phase(full_gray_np) -> int:
    assert_rectangle_shape(full_gray_np, screen_template, f'Shape of the image should be {screen_template.shape()}')
    assert_gray_img(full_gray_np)
    roi_gray_np = utils.crop_image(full_gray_np, phase_recognition_template)

    # average ROI brightness separates the two phases
    all_pixels_average = roi_gray_np.mean()
    if all_pixels_average > 150:
        # Phase 1
        return 1
    else:
        return 2
Example #7
    def add_zeroth_order(self, idx=0):
        """
        Plot output of model with zero-phase input.

        :param idx: Global step value to record
        """

        zero_phase = torch.zeros((1, 1, *self.slm_res)).to(self.model.dev)
        recon_field = self.model(zero_phase)
        recon_amp = recon_field.abs()
        recon_amp = utils.crop_image(
            recon_amp, self.slm_res,
            stacked_complex=False).cpu().detach().squeeze().unsqueeze(0)
        self.add_image('parameters/zero_input_1080p',
                       (recon_amp - recon_amp.min()) /
                       (recon_amp.max() - recon_amp.min()), idx)
        self.add_figure_cmap('parameters/zero_input_figure',
                             recon_amp.squeeze(), idx, self.cmap_rgb)
Example #8
def find_dino(full_gray_np) -> List[Rectangle]:
    # Assert
    assert_rectangle_shape(
        full_gray_np, screen_template,
        f'Shape of the image should be {screen_template.shape()}')
    assert_gray_img(full_gray_np)

    roi_gray_np = utils.crop_image(full_gray_np, dino_segment_template)
    _, mask = cv2.threshold(roi_gray_np, 200, 240, cv2.THRESH_BINARY_INV)

    kernel = np.ones((10, 10), np.uint8)
    dilation = cv2.dilate(mask, kernel)

    _, contours, _ = cv2.findContours(dilation, cv2.RETR_TREE,
                                      cv2.CHAIN_APPROX_NONE)

    # expect exactly one contour: the dino itself
    if len(contours) != 1:
        return []

    x, y, w, h = cv2.boundingRect(contours[0])
    rectangle = Rectangle(x, y, x + w, y + h)
    return rectangle.relativize_from(dino_segment_template)
Example #9
def propagation_ASM(u_in,
                    feature_size,
                    wavelength,
                    z,
                    linear_conv=True,
                    padtype='zero',
                    return_H=False,
                    precomped_H=None,
                    return_H_exp=False,
                    precomped_H_exp=None,
                    dtype=torch.float32):
    """Propagates the input field using the angular spectrum method

    Inputs
    ------
    u_in: complex field of size (num_images, 1, height, width, 2)
        where the last two channels are real and imaginary values
    feature_size: (height, width) of individual holographic features in m
    wavelength: wavelength in m
    z: propagation distance
    linear_conv: if True, pad the input to obtain a linear convolution
    padtype: 'zero' to pad with zeros, 'median' to pad with median of u_in's
        amplitude
    return_H[_exp]: used for precomputing H or H_exp, ends the computation early
        and returns the desired variable
    precomped_H[_exp]: the precomputed value for H or H_exp
    dtype: torch dtype for computation at different precision

    Output
    ------
    tensor of size (num_images, 1, height, width, 2)
    """

    if linear_conv:
        # preprocess with padding for linear conv.
        input_resolution = u_in.size()[-3:-1]
        conv_size = [i * 2 for i in input_resolution]
        if padtype == 'zero':
            padval = 0
        elif padtype == 'median':
            padval = torch.median(torch.pow((u_in**2).sum(-1), 0.5))
        else:
            raise ValueError(f'unknown padtype: {padtype}')
        u_in = utils.pad_image(u_in, conv_size, padval=padval)

    if precomped_H is None and precomped_H_exp is None:
        # resolution of input field, should be: (num_images, num_channels, height, width, 2)
        field_resolution = u_in.size()

        # number of pixels
        num_y, num_x = field_resolution[2], field_resolution[3]

        # sampling interval size
        dy, dx = feature_size

        # size of the field
        y, x = (dy * float(num_y), dx * float(num_x))

        # frequency coordinates sampling
        fy = np.linspace(-1 / (2 * dy) + 0.5 / (2 * y),
                         1 / (2 * dy) - 0.5 / (2 * y), num_y)
        fx = np.linspace(-1 / (2 * dx) + 0.5 / (2 * x),
                         1 / (2 * dx) - 0.5 / (2 * x), num_x)

        # momentum/reciprocal space
        FX, FY = np.meshgrid(fx, fy)

        # transfer function in numpy (omit distance)
        HH = 2 * math.pi * np.sqrt(1 / wavelength**2 - (FX**2 + FY**2))

        # create tensor & upload to device (GPU)
        H_exp = torch.tensor(HH, dtype=dtype).to(u_in.device)

        ###
        # once H_exp is on the GPU, one may iterate over multiple propagation
        # distances from this point on

        # reshape tensor and multiply
        H_exp = torch.reshape(H_exp, (1, 1, *H_exp.size()))

    # handle loading the precomputed H_exp value, or saving it for later runs
    elif precomped_H_exp is not None:
        H_exp = precomped_H_exp

    if precomped_H is None:
        # multiply by distance
        # (note: y, x, FX, FY below come from the H_exp block above, so this
        # branch assumes that block ran rather than precomped_H_exp being
        # supplied on its own)
        H_exp = torch.mul(H_exp, z)

        # band-limited ASM - Matsushima et al. (2009)
        fy_max = 1 / np.sqrt((2 * z * (1 / y))**2 + 1) / wavelength
        fx_max = 1 / np.sqrt((2 * z * (1 / x))**2 + 1) / wavelength
        H_filter = torch.tensor(
            ((np.abs(FX) < fx_max) & (np.abs(FY) < fy_max)).astype(np.uint8),
            dtype=dtype)

        # get the real/imaginary components
        H_real, H_imag = utils.polar_to_rect(H_filter.to(u_in.device), H_exp)

        H = torch.stack((H_real, H_imag), 4)
        H = utils.ifftshift(H)
    else:
        H = precomped_H

    # return for use later as precomputed inputs
    if return_H_exp:
        return H_exp
    if return_H:
        return H

    # angular spectrum
    # (torch.fft/torch.ifft here are the pre-1.8 PyTorch FFT API)
    U1 = torch.fft(utils.ifftshift(u_in), 2, True)

    # convolution of the system
    U2 = utils.mul_complex(H, U1)

    # Fourier transform of the convolution to the observation plane
    u_out = utils.fftshift(torch.ifft(U2, 2, True))

    if linear_conv:
        return utils.crop_image(u_out, input_resolution)
    else:
        return u_out
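The return_H / precomped_H arguments let the transfer function be computed once and reused across frames, as the in-function comment suggests. A hedged usage sketch (parameters are illustrative; u_in follows the (num_images, 1, H, W, 2) real/imaginary layout from the docstring, and the pre-1.8 torch.fft API is assumed to be available):

import torch

feature_size = (6.4e-6, 6.4e-6)  # SLM pixel pitch in m
wavelength = 520e-9
z = 0.1  # propagation distance in m

u_in = torch.zeros(1, 1, 540, 960, 2)
u_in[..., 0] = 1.0  # unit-amplitude, zero-phase field

# precompute the (padded) transfer function once...
H = propagation_ASM(u_in, feature_size, wavelength, z, return_H=True)

# ...then reuse it on every subsequent call
u_out = propagation_ASM(u_in, feature_size, wavelength, z, precomped_H=H)
print(u_out.shape)  # (1, 1, 540, 960, 2)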
Example #10
for e in range(opt.num_epochs):

    print(f'   - Epoch {e+1} ...')
    # visualize all the modules in the model on tensorboard
    with torch.no_grad():
        writer.visualize_model(e)

    for i, target in enumerate(image_loader):
        target_amp, _, target_filenames = target

        # extract indices of images
        idxs = []
        for name in target_filenames:
            _, target_filename = os.path.split(name)
            idxs.append(target_filename.split('_')[-1])
        target_amp = utils.crop_image(target_amp, target_shape=roi_res, stacked_complex=False).to(device)

        # load phases
        slm_phases = []
        for k, idx in enumerate(idxs):
            # Load pre-computed phases
            # (alternatively, phases could be optimized from scratch with a
            # small number of iterations)
            if e > 0:
                phase_filename = os.path.join(phase_path, f'{chan_str}', f'{idx}.png')
            else:
                phase_filename = os.path.join(phase_path, f'{chan_str}', f'{idx}_{channel}', f'phasemaps_1000.png')
            slm_phase = skimage.io.imread(phase_filename) / np.iinfo(np.uint8).max

            # invert phase (our SLM setup)
            slm_phase = torch.tensor((1 - slm_phase) * 2 * np.pi - np.pi,
                                     dtype=dtype).reshape(1, 1, *slm_res).to(device)
Example #11
        final_phase_num_in = 4
    blur = utils.make_kernel_gaussian(0.849, 3)

    # load camera model and set it into eval mode
    model_prop = ModelPropagate(distance=prop_dist,
                                feature_size=feature_size,
                                wavelength=wavelength,
                                blur=blur).to(device)
    model_prop.load_state_dict(torch.load(opt.model_path, map_location=device))

    # Here, we crop model parameters to match the Holonet resolution, which is slightly different from 1080p.
    # parameters for CITL model
    zernike_coeffs = model_prop.coeffs_fourier
    source_amplitude = model_prop.source_amp
    latent_codes = model_prop.latent_code
    latent_codes = utils.crop_image(latent_codes, target_shape=image_res, pytorch=True, stacked_complex=False)

    # content independent target field (See Section 5.1.1.)
    u_t_amp = utils.crop_image(model_prop.target_constant_amp, target_shape=image_res, stacked_complex=False)
    u_t_phase = utils.crop_image(model_prop.target_constant_phase, target_shape=image_res, stacked_complex=False)
    real, imag = utils.polar_to_rect(u_t_amp, u_t_phase)
    u_t = torch.complex(real, imag)

    # match the shape of forward model parameters (1072, 1920)

    # Making these nn.Parameter means Holonet will include them in files
    # written by torch.save
    model_prop.latent_code = nn.Parameter(latent_codes.detach(), requires_grad=False)
    model_prop.target_constant_amp = nn.Parameter(u_t_amp.detach(), requires_grad=False)
    model_prop.target_constant_phase = nn.Parameter(u_t_phase.detach(), requires_grad=False)

    # But as these parameters are already in the CITL-calibrated models,
Example #12
                           crop_to_homography=True,
                           shuffle=False,
                           vertical_flips=False,
                           horizontal_flips=False)

# Loop over the dataset
for k, target in enumerate(image_loader):
    # get target image
    target_amp, target_res, target_filename = target
    target_path, target_filename = os.path.split(target_filename[0])
    target_idx = target_filename.split('_')[-1]
    target_amp = target_amp.to(device)
    print(target_idx)

    # if you want to separate folders by target_idx or whatever, you can do so here.
    phase_only_algorithm.init_scale = s0 * utils.crop_image(
        target_amp, roi_res, stacked_complex=False).mean()
    phase_only_algorithm.phase_path = os.path.join(root_path)

    # run algorithm (See algorithm_modules.py and algorithms.py)
    if opt.method in ['DPAC', 'HOLONET', 'UNET']:
        # direct methods
        _, final_phase = phase_only_algorithm(target_amp)
    else:
        # iterative methods, initial phase: random guess
        init_phase = (-0.5 + 1.0 * torch.rand(1, 1, *slm_res)).to(device)
        final_phase = phase_only_algorithm(target_amp, init_phase)

    print(final_phase.shape)

    # save the final result somewhere.
    phase_out_8bit = utils.phasemap_8bit(final_phase.cpu().detach(),
Example #13
psnrs = {'amp': [], 'lin': [], 'srgb': []}
ssims = {'amp': [], 'lin': [], 'srgb': []}
idxs = []

# Loop over the dataset
for k, target in enumerate(image_loader):
    # get target image
    target_amp, target_res, target_filename = target
    target_path, target_filename = os.path.split(target_filename[0])
    target_idx = target_filename.split('_')[-1]
    target_amp = target_amp.to(device)

    print(f'    - running for img_{target_idx}...')

    # crop to ROI
    target_amp = utils.crop_image(target_amp, target_shape=roi_res, stacked_complex=False).to(device)

    recon_amp = []

    # for each channel, propagate wave from the SLM plane to the image plane and get the reconstructed image.
    for c in chs:
        # load and invert phase (our SLM setup)
        phase_filename = os.path.join(opt.root_path, chan_strs[c], f'{target_idx}.png')
        slm_phase = skimage.io.imread(phase_filename) / 255.
        slm_phase = torch.tensor((1 - slm_phase) * 2 * np.pi - np.pi, dtype=dtype).reshape(1, 1, *slm_res).to(device)

        # propagate field
        real, imag = utils.polar_to_rect(torch.ones_like(slm_phase), slm_phase)
        slm_field = torch.complex(real, imag)

        if opt.prop_model.upper() == 'MODEL':
Example #14
def stochastic_gradient_descent(init_phase,
                                target_amp,
                                num_iters,
                                prop_dist,
                                wavelength,
                                feature_size,
                                roi_res=None,
                                phase_path=None,
                                prop_model='ASM',
                                propagator=None,
                                loss=nn.MSELoss(),
                                lr=0.01,
                                lr_s=0.003,
                                s0=1.0,
                                citl=False,
                                camera_prop=None,
                                writer=None,
                                dtype=torch.float32,
                                precomputed_H=None):
    """
    Given the initial guess, run the SGD algorithm to calculate the optimal phase pattern of the spatial light modulator.

    Input
    ------
    :param init_phase: a tensor, in the shape of (1,1,H,W), initial guess for the phase.
    :param target_amp: a tensor, in the shape of (1,1,H,W), the amplitude of the target image.
    :param num_iters: the number of iterations to run the SGD.
    :param prop_dist: propagation distance in m.
    :param wavelength: wavelength in m.
    :param feature_size: the SLM pixel pitch, in meters, default 6.4e-6
    :param roi_res: a tuple of integers, the region of interest, e.g. (880, 1600)
    :param phase_path: a string, path for saving intermediate phases
    :param prop_model: a string indicating the propagation model ('ASM' or 'MODEL')
    :param propagator: predefined function or model instance for the propagation.
    :param loss: loss function, default L2
    :param lr: learning rate for optimization variables
    :param lr_s: learning rate for learnable scale
    :param s0: initial scale
    :param writer: Tensorboard writer instance
    :param dtype: default torch.float32
    :param precomputed_H: A Pytorch complex64 tensor, a pre-computed kernel of shape (1,1,2H,2W) for fast computation.

    Output
    ------
    :return: a tensor, the optimized phase pattern at the SLM plane, in the shape of (1,1,H,W)
    """

    device = init_phase.device
    s = torch.tensor(s0, requires_grad=True, device=device)

    # phase at the slm plane
    slm_phase = init_phase.requires_grad_(True)

    # optimization variables and adam optimizer
    optvars = [{'params': slm_phase}]
    if lr_s > 0:
        optvars += [{'params': s, 'lr': lr_s}]
    optimizer = optim.Adam(optvars, lr=lr)

    # crop target roi
    target_amp = utils.crop_image(target_amp, roi_res, stacked_complex=False)

    # run the iterative algorithm
    for k in range(num_iters):
        optimizer.zero_grad()
        # forward propagation from the SLM plane to the target plane
        real, imag = utils.polar_to_rect(torch.ones_like(slm_phase), slm_phase)
        slm_field = torch.complex(real, imag)

        recon_field = utils.propagate_field(slm_field, propagator, prop_dist,
                                            wavelength, feature_size,
                                            prop_model, dtype, precomputed_H)

        # get amplitude
        recon_amp = recon_field.abs()

        # crop roi
        recon_amp = utils.crop_image(recon_amp,
                                     target_shape=roi_res,
                                     stacked_complex=False)

        # camera-in-the-loop technique
        if citl:
            captured_amp = camera_prop(slm_phase)

            # use the gradient of the proxy, replacing the amplitudes;
            # captured_amp is assumed to already match the size of recon_amp
            out_amp = recon_amp + (captured_amp - recon_amp).detach()
        else:
            out_amp = recon_amp

        # calculate loss and backprop
        lossValue = loss(s * out_amp, target_amp)
        lossValue.backward()
        optimizer.step()

        # write to tensorboard / write phase image
        # (the tensorboard write itself takes a fraction of a second)
        with torch.no_grad():
            if k % 50 == 0:
                print(k)
                utils.write_sgd_summary(slm_phase,
                                        out_amp,
                                        target_amp,
                                        k,
                                        writer=writer,
                                        path=phase_path,
                                        s=s,
                                        prefix='test')

    return slm_phase
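A hedged end-to-end sketch of calling stochastic_gradient_descent with the ASM propagator from Example #9 (resolutions, pixel pitch, and the random target are stand-ins, and utils.propagate_field is assumed to dispatch to the given propagator):

import torch

slm_res = (1080, 1920)
roi_res = (880, 1600)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# random initial phase, matching the initialization in Example #12
init_phase = (-0.5 + 1.0 * torch.rand(1, 1, *slm_res)).to(device)
target_amp = torch.rand(1, 1, *slm_res).to(device)  # stand-in target image

final_phase = stochastic_gradient_descent(
    init_phase, target_amp, num_iters=500,
    prop_dist=0.1, wavelength=520e-9, feature_size=(6.4e-6, 6.4e-6),
    roi_res=roi_res, prop_model='ASM', propagator=propagation_ASM)
print(final_phase.shape)  # (1, 1, 1080, 1920)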