Ejemplo n.º 1
0
def bg_correct(raw, bg, df=None):
    """Remove background structure from an image by background division.

    The result is ``(raw - df) / (bg - df)``; zero-valued pixels of the
    denominator are first replaced by ``zero_filter``.

    Parameters
    ----------
    raw : xarray.DataArray
        Image to be background divided.
    bg : xarray.DataArray
        Background image recorded with the same optical setup.
    df : xarray.DataArray, optional
        Dark field image recorded without illumination. If omitted, a
        zero-filled copy of `raw` is used, so nothing is subtracted.

    Returns
    -------
    corrected_image : xarray.DataArray
        Background-divided image with `raw`'s metadata copied onto it. If the
        result has a `noise_sd` attribute equal to None and `bg` has a
        `noise_sd` attribute, it is replaced by `bg.noise_sd`.

    Raises
    ------
    BadImage
        If `raw`, `bg` and `df` differ in shape or in spacing.
    """
    dark = df
    if dark is None:
        # All-zero dark field built from a copy of raw.
        dark = raw.copy()
        dark[:] = 0

    # Spacing is only inspected once the shapes are known to agree.
    shapes_agree = raw.shape == bg.shape == dark.shape
    if not shapes_agree or not (list(get_spacing(raw)) == list(
            get_spacing(bg)) == list(get_spacing(dark))):
        raise BadImage(
            "raw and background images must have the same shape and spacing")

    numerator = raw - dark
    corrected = copy_metadata(raw, numerator / zero_filter(bg - dark))

    needs_noise = (hasattr(corrected, 'noise_sd')
                   and hasattr(bg, 'noise_sd')
                   and corrected.noise_sd is None)
    if needs_noise:
        corrected = update_metadata(corrected, noise_sd=bg.noise_sd)
    return corrected
Ejemplo n.º 2
0
def zero_filter(image):
    '''
    Search for and interpolate pixels equal to 0.

    This is to avoid NaN's when a hologram is divided by a BG with 0's.

    Parameters
    ----------
    image : xarray.DataArray
       Image to process. Rows and columns are the first two axes; any
       further axes are carried along unchanged in shape.

    Returns
    -------
    image : xarray.DataArray
       Copy of the input (metadata copied via ``copy_metadata``) in which
       each pixel equal to 0 is replaced by the sum of its 3x3 neighbourhood
       divided by 8, i.e. the average of its 8 neighbours. Pixels on the
       image border are treated as if the image were padded with the mean of
       its non-zero pixels. dtype is the same as the input image.

    Raises
    ------
    BadImage
        If two zero pixels are horizontally or vertically adjacent, since
        each would then corrupt the other's neighbour average.
    '''
    # Boolean mask of dead (zero-valued) pixels as a plain ndarray, so the
    # shifted comparisons below are purely positional.
    dead = np.asarray(image == 0)
    # Row/column indices of every dead pixel (trailing axes ignored).
    dead_rows, dead_cols = np.nonzero(dead)[:2]

    # Compare each pixel with the one below it and the one to its right.
    # (Comparing consecutive entries of the row-major index list misses
    # vertical neighbours whenever another zero lies between them.)
    if (np.any(dead[1:] & dead[:-1])
            or np.any(dead[:, 1:] & dead[:, :-1])):
        raise BadImage(
            'Image has adjacent dead pixels, cannot remove dead pixels')

    n_rows, n_cols = image.shape[0], image.shape[1]
    output = image.copy()
    padded = None  # built once, only if some dead pixel lies on the border

    for row, col in zip(dead_rows, dead_cols):
        if 0 < row < n_rows - 1 and 0 < col < n_cols - 1:
            # In the bulk: the dead pixel contributes 0 to the 3x3 sum, so
            # dividing by 8 gives the mean of its eight neighbours.
            output[row, col] = np.sum(
                image[row - 1:row + 2, col - 1:col + 2]) / 8.
        else:
            if padded is None:
                # Pad one pixel on every side with the mean of the non-zero
                # pixels so border pixels also have eight "neighbours".
                fill = np.asarray(image.sum()) / (image.size - len(dead_rows))
                padded = np.ones(
                    (n_rows + 2, n_cols + 2) + tuple(image.shape[2:])) * fill
                padded[1:-1, 1:-1] = image
            output[row, col] = np.sum(padded[row:row + 3, col:col + 3]) / 8.
        print('Pixel with value 0 reset to nearest neighbor average')

    return copy_metadata(image, output)
Ejemplo n.º 3
0
def load_image(inf,
               spacing=None,
               medium_index=None,
               illum_wavelen=None,
               illum_polarization=None,
               normals=None,
               noise_sd=None,
               channel=None,
               name=None):
    """
    Load an image file into an xarray.DataArray with associated metadata.

    Parameters
    ----------
    inf : string
        File to load.
    spacing : float or (float, float) (optional)
        pixel size of images in each dimension - assumes square pixels if single value.
        set equal to 1 if not passed in and issues warning.
    medium_index : float (optional)
        refractive index of the medium
    illum_wavelen : float (optional)
        wavelength (in vacuum) of illuminating light. When several channels
        are loaded and one value per channel is given (not as a dict), the
        values are attached to the channels along the illumination dimension.
    illum_polarization : (float, float) (optional)
        (x, y) polarization vector of the illuminating light. When several
        channels are loaded, a 2D array-like holding one vector per channel
        is stacked along the illumination dimension.
    normals : None
        Deprecated; any value other than None raises ValueError.
    noise_sd : float (optional)
        noise level in the image, normalized to image intensity
    channel : int, tuple of ints or 'all' (optional)
        number(s) of channel to load for a color image (in general 0=red,
        1=green, 2=blue); 'all' loads every channel. When more than one
        channel is loaded the result gains an illumination dimension, and
        channel numbers 0-2 are labelled 'red', 'green' and 'blue'. Ignored
        for greyscale images (with a warning unless it is 'all').
    name : str (optional)
        name to assign the xr.DataArray object resulting from load_image.
        Defaults to the file name without its directory or extension.

    Returns
    -------
    obj : xarray.DataArray representation of the image with associated metadata

    Raises
    ------
    ValueError
        If `normals` is not None.
    BadImage
        If the image has colour channels but `channel` is None.
    LoadError
        If a requested channel number does not exist in the image.
    """
    if normals is not None:
        raise ValueError(NORMALS_DEPRECATION_MESSAGE)
    if name is None:
        name = os.path.splitext(os.path.split(inf)[-1])[0]

    with open(inf, 'rb') as pi_raw:
        pi = pilimage.open(pi_raw)
        arr = np.asarray(pi).astype('d')
        # A YAML mapping in the TIFF ImageDescription tag (270) is metadata
        # that this function does not read; point the user at hp.load.
        try:
            description = yaml.safe_load(pi.tag[270][0])
        except (AttributeError, KeyError, yaml.YAMLError):
            # The image object has no `tag` attribute, there is no
            # description tag, or the description is free text that is not
            # valid YAML: there is no metadata to warn about.
            description = None
        if isinstance(description, dict):
            warnings.warn(
                "Metadata detected but ignored. Use hp.load to read it.")

    # `channel` may be an int, a sequence or a str; only a str can mean
    # 'all' (comparing an array with 'all' would be elementwise).
    load_all = isinstance(channel, str) and channel == 'all'

    extra_dims = None
    if channel is None:
        if arr.ndim > 2:
            raise BadImage(
                'Not a greyscale image. You must specify which channel(s) to use'
            )
    elif arr.ndim == 2:
        if not load_all:
            warnings.warn("Not a color image (channel number ignored)")
    else:
        # color image with specified channel(s)
        if load_all:
            channel = range(arr.shape[2])
        channel = ensure_array(channel)
        if channel.max() >= arr.shape[2]:
            raise LoadError(
                inf, "The image doesn't have a channel number {0}".format(
                    channel.max()))
        arr = arr[:, :, channel]
        if len(channel) == 1:
            # Drop only the channel axis; squeeze() would also collapse an
            # image that is a single pixel high or wide.
            arr = arr[:, :, 0]
        else:
            # multiple channels. increase output dimensionality
            if channel.max() <= 2:
                channel = [['red', 'green', 'blue'][c] for c in channel]
            extra_dims = {illumination: channel}
            if illum_wavelen is not None and not isinstance(
                    illum_wavelen, dict) and len(
                        ensure_array(illum_wavelen)) == len(channel):
                illum_wavelen = xr.DataArray(ensure_array(illum_wavelen),
                                             dims=illumination,
                                             coords=extra_dims)
            if not isinstance(illum_polarization, dict) and np.array(
                    illum_polarization).ndim == 2:
                pol_index = xr.DataArray(channel,
                                         dims=illumination,
                                         name=illumination)
                illum_polarization = xr.concat(
                    [to_vector(pol) for pol in illum_polarization],
                    pol_index)

    image = data_grid(arr,
                      spacing=spacing,
                      medium_index=medium_index,
                      illum_wavelen=illum_wavelen,
                      illum_polarization=illum_polarization,
                      noise_sd=noise_sd,
                      name=name,
                      extra_dims=extra_dims)
    return image
Ejemplo n.º 4
0
def display_image(im,
                  scaling='auto',
                  vert_axis='x',
                  horiz_axis='y',
                  depth_axis='z',
                  colour_axis='illumination'):
    """Normalise an image into a DataArray laid out for display.

    Parameters
    ----------
    im : xarray.DataArray or array-like
        Image to display. A DataArray may have at most three dimensions
        besides `colour_axis`. Any other input is converted with
        ``ensure_array`` and must be 2D or 3D (a 2D array gets a leading
        depth axis of length 1).
    scaling : 'auto', (float, float) or None
        Intensity range ``(low, high)``. Values are clipped to this range and
        mapped linearly onto [0, 1]. 'auto' uses the image minimum and
        maximum; None leaves values unscaled. The value used is stored in
        ``attrs['_image_scaling']``.
    vert_axis, horiz_axis, depth_axis : str or int
        Dimension names (DataArray input) or axis numbers (array input).
        For array input, any of these that is not an int is chosen
        automatically: depth is the shortest unclaimed axis, then vertical
        and horizontal take the remaining axes in order. The array is then
        wrapped with ``data_grid`` and the names become 'z', 'x', 'y'.
        For DataArray input, a length-1 'z' dimension is dropped unless
        `depth_axis` is 'z', and one is added if `depth_axis` is 'z' but
        missing.
    colour_axis : str
        Name of the colour-channel dimension.

    Returns
    -------
    xarray.DataArray
        Image transposed to (depth_axis, vert_axis, horiz_axis[, colour_axis]).
        Complex values are replaced by their magnitude (with a warning). A
        single colour channel is squeezed out. Channels whose names start
        with R/G/B are reordered to R, G, B; if only two of them are present,
        the missing one is filled with the image minimum. Two channels that
        are not all R/G/B-named get a third channel, filled with the image
        minimum and with coordinate NaN, appended. ``'_dummy_channel'`` is
        written into the attrs to record an added channel.

    Raises
    ------
    BadImage
        If the image has too many or too few dimensions, or more than three
        colour channels.
    ValueError
        If integer axis specifications repeat or are not 0, 1 or 2.
    """
    im = im.copy()
    if isinstance(im, xr.DataArray):
        # Drop a trivial z axis unless z is the requested depth axis.
        if hasattr(im, 'z') and len(im['z']) == 1 and depth_axis != 'z':
            im = im[{'z': 0}]
        if depth_axis == 'z' and 'z' not in im.dims:
            im = im.expand_dims('z')
        if im.ndim > 3 + (colour_axis in im.dims):
            raise BadImage("Too many dims on DataArray to output properly.")
        attrs = im.attrs
    else:
        attrs = {}
        im = ensure_array(im)
        if im.ndim > 3:
            raise BadImage("Too many dims on ndarray to output properly.")
        elif im.ndim == 2:
            im = np.array([im])
        elif im.ndim < 2:
            raise BadImage("Too few dims on ndarray to output properly.")
        # Axes not yet claimed by an explicit integer specification.
        axes = [0, 1, 2]
        for axis in [vert_axis, horiz_axis, depth_axis]:
            if isinstance(axis, int):
                try:
                    axes.remove(axis)
                except ValueError:
                    # list.remove raises ValueError (not KeyError) for an
                    # axis that is repeated or outside 0-2.
                    raise ValueError(
                        "Cannot interpret axis specifications.") from None
        if len(axes) > 0:
            if not isinstance(depth_axis, int):
                # Default depth axis: the shortest unclaimed axis.
                depth_axis = axes[np.argmin([im.shape[i] for i in axes])]
                axes.remove(depth_axis)
            if not isinstance(vert_axis, int):
                vert_axis = axes[0]
                axes.pop(0)
            if not isinstance(horiz_axis, int):
                horiz_axis = axes[0]
        im = im.transpose([depth_axis, vert_axis, horiz_axis])
        depth_axis = 'z'
        vert_axis = 'x'
        horiz_axis = 'y'
        im = data_grid(im, spacing=1, z=range(len(im)))
    if np.iscomplex(im).any():
        warn("Image contains complex values. Taking image magnitude.")
        im = np.abs(im)
    # Compare by value: an 'auto' string built at runtime need not be the
    # same object as the literal, and a tuple/array must not match.
    if isinstance(scaling, str) and scaling == 'auto':
        scaling = (ensure_scalar(im.min()), ensure_scalar(im.max()))
    if scaling is not None:
        # Clip to [low, high], then map that range linearly onto [0, 1].
        im = np.maximum(im, scaling[0])
        im = np.minimum(im, scaling[1])
        im = (im - scaling[0]) / (scaling[1] - scaling[0])
    im.attrs = attrs
    im.attrs['_image_scaling'] = scaling

    if colour_axis in im.dims:
        # First letter of each channel name, or ' ' for non-string names.
        cols = [
            col[0].capitalize() if isinstance(col, str) else ' '
            for col in im[colour_axis].values
        ]
        RGB_names = np.all([letter in 'RGB' for letter in cols])
        if len(im[colour_axis]) == 1:
            im = im.squeeze(dim=colour_axis)
        elif len(im[colour_axis]) > 3:
            raise BadImage('Cannot output more than 3 colour channels')
        elif RGB_names:
            channels = {
                col: im[{
                    colour_axis: i
                }]
                for i, col in enumerate(cols)
            }
            if len(channels) == 2:
                # Fill the one missing primary with a flat channel at the
                # image minimum.
                dummy = im[{colour_axis: 0}].copy()
                dummy[:] = im.min()
                for i, col in enumerate('RGB'):
                    if col not in cols:
                        dummy[colour_axis] = col
                        channels[col] = dummy
                        channels['R'].attrs['_dummy_channel'] = i
                        break
            channels = [channels[col] for col in 'RGB']
            im = clean_concat(channels, colour_axis)
        elif len(im[colour_axis]) == 2:
            # Two channels not all named R/G/B: append a flat third channel.
            dummy = xr.full_like(im[{colour_axis: 0}], fill_value=im.min())
            dummy = dummy.expand_dims({colour_axis: [np.nan]})
            im.attrs['_dummy_channel'] = -1
            im = clean_concat([im, dummy], colour_axis)
    dim_order = [depth_axis, vert_axis, horiz_axis, colour_axis][:im.ndim]
    return im.transpose(*dim_order)