def bg_correct(raw, bg, df=None):
    """Correct a noisy image by dividing out a background.

    Computes (raw - df) / (bg - df). Zeros in the denominator are
    interpolated away by ``zero_filter`` so the division cannot
    produce NaN's.

    Parameters
    ----------
    raw : xarray.DataArray
        Image to be background divided.
    bg : xarray.DataArray
        Background image recorded with the same optical setup.
    df : xarray.DataArray, optional
        Dark field image recorded without illumination. When omitted,
        an all-zero image shaped like ``raw`` is used.

    Returns
    -------
    corrected_image : xarray.DataArray
        Background-divided copy of the input image. If its ``noise_sd``
        is None and ``bg`` carries one, ``noise_sd`` is updated to
        match ``bg``.

    Raises
    ------
    BadImage
        If the three images do not all share the same shape and spacing.
    """
    if df is None:
        # Zero dark field: copy raw to inherit its coords, then blank it.
        df = raw.copy()
        df[:] = 0

    shapes_match = raw.shape == bg.shape == df.shape
    spacings_match = (list(get_spacing(raw)) == list(get_spacing(bg))
                      == list(get_spacing(df)))
    if not (shapes_match and spacings_match):
        raise BadImage(
            "raw and background images must have the same shape and spacing")

    corrected = copy_metadata(raw, (raw - df) / zero_filter(bg - df))
    # Inherit the background's noise estimate when the result has none.
    if (hasattr(corrected, 'noise_sd') and hasattr(bg, 'noise_sd')
            and corrected.noise_sd is None):
        corrected = update_metadata(corrected, noise_sd=bg.noise_sd)
    return corrected
def zero_filter(image):
    '''
    Search for and interpolate pixels equal to 0.

    This is to avoid NaN's when a hologram is divided by a BG with 0's:
    each dead (zero) pixel is replaced by the average of its 8
    neighbors.

    Parameters
    ----------
    image : xarray.DataArray
        Image to process.

    Returns
    -------
    image : xarray.DataArray
        Image where pixels == 0 are instead given values equal to the
        average of their neighbors. dtype is the same as the input
        image.

    Raises
    ------
    BadImage
        If two dead pixels are adjacent, since then a neighbor average
        would itself include a dead pixel.
    '''
    # Row/column index arrays of every zero-valued pixel.
    zero_pix = np.where(image == 0)
    output = image.copy()

    # check to see if adjacent pixels are 0, if more than 1 dead pixel.
    # np.where returns indices sorted row-major, so adjacent dead pixels
    # show up as consecutive entries differing by 1 in one coordinate
    # while the other coordinate is unchanged. (The np.roll wrap-around
    # pair compares first vs. last index; with sorted indices that delta
    # is <= 0, so it cannot trigger a false positive.)
    if len(zero_pix[0]) > 1:
        delta_rows = zero_pix[0] - np.roll(zero_pix[0], 1)
        delta_cols = zero_pix[1] - np.roll(zero_pix[1], 1)
        if ((1 in delta_rows[np.where(delta_cols == 0)]) or
                (1 in delta_cols[np.where(delta_rows == 0)])):
            raise BadImage(
                'Image has adjacent dead pixels, cannot remove dead pixels')

    for row, col in zip(zero_pix[0], zero_pix[1]):
        # in the bulk: the 3x3 window around (row, col) lies fully inside
        # the image, so sum it and divide by 8 (the center itself is 0
        # and contributes nothing to the sum).
        if ((row > 0) and (row < (image.shape[0] - 1)) and
                (col > 0) and (col < image.shape[1] - 1)):
            output[row, col] = np.sum(
                image[row - 1:row + 2, col - 1:col + 2]) / 8.
        else:
            # deal with edges by padding: surround the image with a
            # 1-pixel border filled with the mean of all live pixels,
            # then take the same 3x3 average. Note the window in the
            # padded frame is [row:row+3, col:col+3] because the border
            # shifts all indices by +1.
            im_avg = image.sum() / (image.size - len(zero_pix[0]))
            padded_im = np.ones(
                (image.shape[0] + 2, image.shape[1] + 2)) * im_avg
            padded_im[1:-1, 1:-1] = image
            output[row, col] = np.sum(
                padded_im[row:row + 3, col:col + 3]) / 8.
        # NOTE(review): emitted once per repaired pixel; presumably
        # intended as a user-visible notice rather than logging.
        print('Pixel with value 0 reset to nearest neighbor average')
    return copy_metadata(image, output)
def load_image(inf, spacing=None, medium_index=None, illum_wavelen=None,
               illum_polarization=None, normals=None, noise_sd=None,
               channel=None, name=None):
    """
    Load data or results.

    Parameters
    ----------
    inf : string
        File to load.
    spacing : float or (float, float) (optional)
        pixel size of images in each dimension - assumes square pixels
        if single value. set equal to 1 if not passed in and issues
        warning.
    medium_index : float (optional)
        refractive index of the medium
    illum_wavelen : float (optional)
        wavelength (in vacuum) of illuminating light
    illum_polarization : (float, float) (optional)
        (x, y) polarization vector of the illuminating light
    normals : deprecated
        Passing a value raises ValueError.
    noise_sd : float (optional)
        noise level in the image, normalized to image intensity
    channel : int or tuple of ints (optional)
        number(s) of channel to load for a color image (in general
        0=red, 1=green, 2=blue)
    name : str (optional)
        name to assign the xr.DataArray object resulting from load_image

    Returns
    -------
    obj : xarray.DataArray
        representation of the image with associated metadata

    Raises
    ------
    BadImage
        If a multichannel image is loaded without specifying a channel.
    LoadError
        If a requested channel index does not exist in the image.
    """
    if normals is not None:
        raise ValueError(NORMALS_DEPRECATION_MESSAGE)
    if name is None:
        # Default the array name to the file's basename sans extension.
        name = os.path.splitext(os.path.split(inf)[-1])[0]

    with open(inf, 'rb') as pi_raw:
        pi = pilimage.open(pi_raw)
        arr = np.asarray(pi).astype('d')
        # TIFF tag 270 (ImageDescription) may hold serialized holopy
        # metadata; we only warn, since this loader ignores it.
        try:
            if isinstance(yaml.safe_load(pi.tag[270][0]), dict):
                warnings.warn(
                    "Metadata detected but ignored. Use hp.load to read it.")
        except (AttributeError, KeyError):
            # Not a TIFF, or no description tag present.
            pass

    extra_dims = None
    if channel is None:
        if arr.ndim > 2:
            raise BadImage('Not a greyscale image. You must specify which '
                           'channel(s) to use')
    elif arr.ndim == 2:
        # Greyscale file but a channel was requested: ignore it.
        if channel != 'all':
            warnings.warn("Not a color image (channel number ignored)")
    else:
        # color image with specified channel(s)
        if channel == 'all':
            channel = range(arr.shape[2])
        channel = ensure_array(channel)
        if channel.max() >= arr.shape[2]:
            # BUG FIX: previously raised LoadError(filename, ...) where
            # `filename` was an undefined name (NameError); the path
            # parameter is `inf`.
            raise LoadError(
                inf,
                "The image doesn't have a channel number {0}".format(
                    channel.max()))
        arr = arr[:, :, channel].squeeze()

        if len(channel) > 1:
            # multiple channels. increase output dimensionality
            if channel.max() <= 2:
                channel = [['red', 'green', 'blue'][c] for c in channel]
            extra_dims = {illumination: channel}
            # Attach per-channel wavelengths along the illumination dim
            # when one value was supplied per channel.
            if (illum_wavelen is not None
                    and not isinstance(illum_wavelen, dict)
                    and len(ensure_array(illum_wavelen)) == len(channel)):
                illum_wavelen = xr.DataArray(
                    ensure_array(illum_wavelen), dims=illumination,
                    coords=extra_dims)
            # Likewise stack per-channel polarization vectors.
            if (not isinstance(illum_polarization, dict)
                    and np.array(illum_polarization).ndim == 2):
                pol_index = xr.DataArray(
                    channel, dims=illumination, name=illumination)
                illum_polarization = xr.concat(
                    [to_vector(pol) for pol in illum_polarization],
                    pol_index)

    return data_grid(arr, spacing=spacing, medium_index=medium_index,
                     illum_wavelen=illum_wavelen,
                     illum_polarization=illum_polarization,
                     noise_sd=noise_sd, name=name, extra_dims=extra_dims)
def display_image(im, scaling='auto', vert_axis='x', horiz_axis='y',
                  depth_axis='z', colour_axis='illumination'):
    """Prepare an image for display.

    Normalizes an image (xarray.DataArray or ndarray) into a
    z-stack-ordered DataArray scaled into [0, 1], padding colour
    channels with a dummy channel so RGB display works with 1-3
    channels.

    Parameters
    ----------
    im : xarray.DataArray or ndarray
        Image (2D or 3D, optionally with a colour dimension) to prepare.
    scaling : 'auto', None, or (min, max)
        'auto' scales to the image's own min/max; a (min, max) pair
        clips and rescales to that range; None leaves values untouched.
    vert_axis, horiz_axis, depth_axis : str or int
        Dimension names (DataArray) or axis indices (ndarray) for the
        output layout.
    colour_axis : str
        Name of the colour dimension, if any.

    Returns
    -------
    xarray.DataArray
        Image transposed to (depth, vert, horiz[, colour]) with a
        '_image_scaling' attr recording the applied scaling.

    Raises
    ------
    BadImage
        If the input has too many/few dimensions or more than 3 colour
        channels.
    ValueError
        If integer axis specifications cannot be interpreted.
    """
    im = im.copy()
    if isinstance(im, xr.DataArray):
        # BUG FIX: was `depth_axis is not 'z'` — identity comparison
        # against a string literal is implementation-dependent (and a
        # SyntaxWarning on modern CPython); use equality.
        if hasattr(im, 'z') and len(im['z']) == 1 and depth_axis != 'z':
            im = im[{'z': 0}]
        if depth_axis == 'z' and 'z' not in im.dims:
            im = im.expand_dims('z')
        if im.ndim > 3 + (colour_axis in im.dims):
            raise BadImage("Too many dims on DataArray to output properly.")
        attrs = im.attrs
    else:
        attrs = {}
        im = ensure_array(im)
        if im.ndim > 3:
            raise BadImage("Too many dims on ndarray to output properly.")
        elif im.ndim == 2:
            # Promote a plain 2D image to a single-slice stack.
            im = np.array([im])
        elif im.ndim < 2:
            raise BadImage("Too few dims on ndarray to output properly.")

        # Work out which numeric axes the caller pinned; whatever
        # remains in `axes` gets auto-assigned below.
        axes = [0, 1, 2]
        for axis in [vert_axis, horiz_axis, depth_axis]:
            if isinstance(axis, int):
                try:
                    axes.remove(axis)
                # BUG FIX: list.remove raises ValueError, not KeyError,
                # so the intended error path was unreachable.
                except ValueError:
                    raise ValueError("Cannot interpret axis specifications.")
        if len(axes) > 0:
            if not isinstance(depth_axis, int):
                # Assume the shortest remaining axis is depth.
                depth_axis = axes[np.argmin([im.shape[i] for i in axes])]
                axes.remove(depth_axis)
            if not isinstance(vert_axis, int):
                vert_axis = axes[0]
                axes.pop(0)
            if not isinstance(horiz_axis, int):
                horiz_axis = axes[0]
        im = im.transpose([depth_axis, vert_axis, horiz_axis])
        depth_axis = 'z'
        vert_axis = 'x'
        horiz_axis = 'y'
        im = data_grid(im, spacing=1, z=range(len(im)))

    if np.iscomplex(im).any():
        warn("Image contains complex values. Taking image magnitude.")
        im = np.abs(im)
    # BUG FIX: was `scaling is 'auto'` — identity comparison on a
    # string literal; use equality.
    if scaling == 'auto':
        scaling = (ensure_scalar(im.min()), ensure_scalar(im.max()))
    if scaling is not None:
        # Clip to the scaling range, then rescale into [0, 1].
        im = np.maximum(im, scaling[0])
        im = np.minimum(im, scaling[1])
        im = (im - scaling[0]) / (scaling[1] - scaling[0])
    im.attrs = attrs
    im.attrs['_image_scaling'] = scaling

    if colour_axis in im.dims:
        # First letter of each channel name, capitalized ('Red' -> 'R').
        cols = [col[0].capitalize() if isinstance(col, str) else ' '
                for col in im[colour_axis].values]
        RGB_names = np.all([letter in 'RGB' for letter in cols])
        if len(im[colour_axis]) == 1:
            im = im.squeeze(dim=colour_axis)
        elif len(im[colour_axis]) > 3:
            raise BadImage('Cannot output more than 3 colour channels')
        elif RGB_names:
            channels = {col: im[{colour_axis: i}]
                        for i, col in enumerate(cols)}
            if len(channels) == 2:
                # Fill the missing RGB channel with a flat dummy at the
                # image minimum and remember which slot it occupies.
                dummy = im[{colour_axis: 0}].copy()
                dummy[:] = im.min()
                for i, col in enumerate('RGB'):
                    if col not in cols:
                        dummy[colour_axis] = col
                        channels[col] = dummy
                        channels['R'].attrs['_dummy_channel'] = i
                        break
            channels = [channels[col] for col in 'RGB']
            im = clean_concat(channels, colour_axis)
        elif len(im[colour_axis]) == 2:
            # Unnamed 2-channel image: append a NaN-labelled dummy.
            dummy = xr.full_like(im[{colour_axis: 0}], fill_value=im.min())
            # np.nan (np.NaN was removed in NumPy 2.0).
            dummy = dummy.expand_dims({colour_axis: [np.nan]})
            im.attrs['_dummy_channel'] = -1
            im = clean_concat([im, dummy], colour_axis)

    dim_order = [depth_axis, vert_axis, horiz_axis, colour_axis][:im.ndim]
    return im.transpose(*dim_order)