Example #1
0
    def get(self, image, sample, objective, **kwargs):
        """Image a sample through an objective and tag the input with the result.

        The objective's current property values are used as keyword defaults
        (explicit ``kwargs`` take precedence) when resolving ``sample``. The
        resolved scatterers are combined into a volume by ``_create_volume``
        and that volume is passed to ``objective.resolve``.

        Returns the imaged sample when ``image`` is falsy; otherwise returns
        ``image`` as a list whose entries have had the imaged sample's
        properties merged in.
        """
        # Objective properties first, caller-supplied values override them.
        resolve_kwargs = objective.properties.current_value_dict()
        resolve_kwargs.update(kwargs)

        scatterers = sample.resolve(**resolve_kwargs)
        if not isinstance(scatterers, list):
            scatterers = [scatterers]

        volume, limits = _create_volume(scatterers, **resolve_kwargs)
        sample_volume = Image(volume)
        for scatterer in scatterers:
            sample_volume.merge_properties_from(scatterer)

        imaged_sample = objective.resolve(
            sample_volume, limits=limits, **resolve_kwargs
        )

        # Nothing to merge into: the imaged sample is the result.
        if not image:
            return imaged_sample

        targets = image if isinstance(image, list) else [image]
        for target in targets:
            target.merge_properties_from(imaged_sample)
        return targets
Example #2
0
    def get(self, image, snr, background, **kwargs):
        """Return a Poisson-noised copy of ``image``.

        Negative pixels of ``image`` are set to zero in place. The image is
        scaled by ``snr**2 / peak**2``, where ``peak`` is the absolute
        difference between the image maximum and ``background``, sampled with
        ``np.random.poisson`` and scaled back. The result shares the input's
        ``properties`` object.
        """
        # Poisson rates must be non-negative; note this mutates the input.
        image[image < 0] = 0

        signal_peak = np.abs(np.max(image) - background)
        scale = snr**2 / signal_peak**2

        sampled = np.random.poisson(image * scale)
        result = Image(sampled / scale)
        result.properties = image.properties
        return result


## IMGAUG IMGCORRUPTLIKE
# Currently unavailable until there's a better way to implement constricted datatypes (only uint8)
# Please see https://github.com/aleju/imgaug/blob/master/imgaug/augmenters/imgcorruptlike.py
# for source implementation

# import imgaug.augmenters as iaa
# import deeptrack as dt
# import inspect

# def init_method(self, **kwargs):
#     dt.ImgAug.__init__(self, **kwargs)

# augs = inspect.getmembers(iaa.blur, lambda x: inspect.isclass(x))

# for augname, aug in augs:

#     print(augname, aug.__module__)

#     globals()[augname] = type(augname, (aug, dt.ImgAug), {
#         "augmenter": aug,
#         "__init__": init_method})
Example #3
0
    def _process_and_get(self,
                         *args,
                         update_properties=None,
                         index=0,
                         **kwargs):
        """Run ``get`` on a preloaded result and return the processed entries.

        Parameters
        ----------
        *args
            Ignored; the input is read from ``self.preloaded_results``.
        update_properties : callable, optional
            If truthy, called as ``update_properties(image, **kwargs)`` on
            every processed entry after its properties have been replaced by
            shallow dict copies.
        index : int
            Position in ``self.preloaded_results`` to load.
        **kwargs
            Forwarded to ``self.get`` and to ``update_properties``.

        Returns
        -------
        list
            One entry per stored item: the processed image for a non-list
            item, or a nested list of processed images for a list item.
        """
        # Loads a result from storage
        stored = self.preloaded_results[index]
        if not isinstance(stored, list):
            stored = [stored]

        new_list_of_lists = []
        # Calls get
        for image_list in stored:
            if isinstance(image_list, list):
                # NOTE(review): the result is wrapped in an extra list level;
                # kept as-is because callers may rely on this nesting.
                new_list_of_lists.append([[
                    self.get(Image(image),
                             **kwargs).merge_properties_from(image)
                    for image in image_list
                ]])
            else:
                new_list_of_lists.append(
                    self.get(Image(image_list),
                             **kwargs).merge_properties_from(image_list))

        if update_properties:
            for image_list in new_list_of_lists:
                # The entry itself (not the outer container, which is always
                # a list) decides whether wrapping is needed. Without this a
                # single image would be iterated row by row below.
                if not isinstance(image_list, list):
                    image_list = [image_list]
                for image in image_list:
                    # Copy each property dict so updates do not leak into the
                    # stored originals.
                    image.properties = [
                        dict(prop) for prop in image.properties
                    ]
                    update_properties(image, **kwargs)

        return new_list_of_lists
Example #4
0
    def _process_and_get(self, *args, update_properties=None, **kwargs):
        """Process cached (or passed-in) images with ``get``.

        When ``self.feature`` is truthy its resolved output is cached on
        ``self.cache`` and refreshed once ``kwargs["update_tally"]`` has
        advanced by at least ``kwargs["updates_per_reload"]`` since
        ``self.last_update``. Otherwise the first positional argument is used
        as input. NumPy's global RNG is seeded from ``kwargs["hash_key"][0]``
        before processing.

        If ``update_properties`` is truthy it is called on each processed
        entry after the entry's properties are replaced by dict copies.
        """
        # Refresh the cache if it is missing or stale.
        if self.feature:
            needs_reload = (
                not hasattr(self, "cache")
                or kwargs["update_tally"] - self.last_update
                >= kwargs["updates_per_reload"]
            )
            if needs_reload:
                if isinstance(self.feature, list):
                    self.cache = [item.resolve() for item in self.feature]
                else:
                    self.cache = self.feature.resolve()
                self.last_update = kwargs["update_tally"]

        # Pick the input: the cache when a feature is set, else the argument.
        if self.feature:
            image_list_of_lists = self.cache
        else:
            image_list_of_lists = args[0]

        if not isinstance(image_list_of_lists, list):
            image_list_of_lists = [image_list_of_lists]

        np.random.seed(kwargs["hash_key"][0])

        new_list_of_lists = []
        for image_list in image_list_of_lists:
            if isinstance(self.feature, list):
                # With several features, restart the RNG for each one so they
                # all see the same random sequence.
                np.random.seed(kwargs["hash_key"][0])

            if isinstance(image_list, list):
                processed = [
                    [
                        Image(
                            self.get(Image(image), **kwargs)
                        ).merge_properties_from(image)
                        for image in image_list
                    ]
                ]
            else:
                processed = Image(
                    self.get(Image(image_list), **kwargs)
                ).merge_properties_from(image_list)
            new_list_of_lists.append(processed)

        if update_properties:
            for entry in new_list_of_lists:
                entries = entry if isinstance(entry, list) else [entry]
                for image in entries:
                    image.properties = [dict(prop) for prop in image.properties]
                    update_properties(image, **kwargs)

        return new_list_of_lists
Example #5
0
 def get(self, image, snr=None, **kwargs):
     """Return a Poisson-noised copy of ``image``.

     Negative pixels are zeroed in place, the image is scaled by
     ``snr**2 / max(image)``, sampled with ``np.random.poisson`` and scaled
     back. The result shares the input's ``properties`` object.
     """
     # Clamp in place: Poisson rates cannot be negative.
     image[image < 0] = 0

     brightest = np.max(image)
     scale = snr**2 / brightest

     result = Image(np.random.poisson(image * scale) / scale)
     result.properties = image.properties
     return result
Example #6
0
    def _continuous_get_training_data(self):
        index = 0
        while True:
            # Stop generator
            if self.exit_signal:
                break

            new_image = self._get(self.feature, self.feature_kwargs)

            if self.label_function:
                new_label = Image(self.label_function(new_image))

            if self.batch_function:
                new_image = Image(self.batch_function(new_image))

            if new_image.ndim < self.ndim:
                new_image = [new_image]
                new_label = [new_label]

            for new_image_i, new_label_i in zip(new_image, new_label):
                if len(self.data) >= self.max_data_size:
                    self.data[index % self.max_data_size] = (new_image_i, new_label_i)
                else:
                    self.data.append((new_image_i, new_label_i))

                index += 1
Example #7
0
    def get(self, image, snr, background, **kwargs):
        """Apply Poisson noise to ``image`` at the requested ``snr``.

        Pixels below zero are set to zero in place. ``peak`` is the absolute
        difference between the image maximum and ``background``; the image is
        multiplied by ``snr**2 / peak**2`` before Poisson sampling and divided
        by the same factor afterwards. The returned image reuses the input's
        ``properties`` object.
        """
        # In-place clamp; the caller's array is modified.
        image[image < 0] = 0

        peak_above_background = np.abs(np.max(image) - background)
        factor = snr**2 / peak_above_background**2

        counts = np.random.poisson(image * factor)
        noisy = Image(counts / factor)
        noisy.properties = image.properties
        return noisy
Example #8
0
    def get(self, images, axis, features, **kwargs):
        """Average a set of images along ``axis``.

        When ``features`` is not None, the images are obtained by resolving
        each feature and the ``images`` argument is ignored. The averaged
        image receives the merged properties of every source image.
        """
        sources = images
        if features is not None:
            sources = [feature.resolve() for feature in features]

        averaged = Image(np.mean(sources, axis=axis))
        for source in sources:
            averaged.merge_properties_from(source)

        return averaged
Example #9
0
    def get(self, image, sample, objective, **kwargs):
        """Image a sample (volumes and fields) through an objective.

        The objective's current property values, overridden by ``kwargs``,
        are used to resolve ``sample``. Scatterers whose ``is_field`` property
        is truthy are passed to the objective as ``fields``; all others are
        combined into a volume by ``_create_volume``. If ``upscale`` is above
        1, the imaged result is block-averaged over ``upscale`` x ``upscale``
        pixel blocks along its first two axes.

        Returns the imaged sample when ``image`` is falsy; otherwise returns
        ``image`` as a list with the imaged sample's properties merged in.
        """
        # Objective properties act as defaults for the sample.
        resolve_kwargs = objective.properties.current_value_dict(**kwargs)
        resolve_kwargs.update(kwargs)

        scatterers = sample.resolve(**resolve_kwargs)
        if not isinstance(scatterers, list):
            scatterers = [scatterers]

        def is_field(item):
            return item.get_property("is_field", default=False)

        volume_samples = [item for item in scatterers if not is_field(item)]
        field_samples = [item for item in scatterers if is_field(item)]

        volume, limits = _create_volume(volume_samples, **resolve_kwargs)
        sample_volume = Image(volume)
        for item in volume_samples + field_samples:
            sample_volume.merge_properties_from(item)

        imaged_sample = objective.resolve(
            sample_volume, limits=limits, fields=field_samples, **resolve_kwargs
        )

        upscale = resolve_kwargs["upscale"]
        shape = imaged_sample.shape
        if upscale > 1:
            # Split each of the first two axes into (blocks, upscale) and
            # average over the within-block axes.
            blocked_shape = (
                shape[0] // upscale,
                upscale,
                shape[1] // upscale,
                upscale,
                shape[2],
            )
            averaged = np.reshape(imaged_sample, blocked_shape).mean(axis=(3, 1))
            imaged_sample = Image(averaged).merge_properties_from(imaged_sample)

        if not image:
            return imaged_sample

        targets = image if isinstance(image, list) else [image]
        for target in targets:
            target.merge_properties_from(imaged_sample)
        return targets
Example #10
0
    def _process_and_get(self, image_list, **feature_input):

        if self.__distributed__:
            return [
                Image(self.get(image, **feature_input)) for image in image_list
            ]
        else:
            new_list = self.get(image_list, **feature_input)

            if not isinstance(new_list, list):
                new_list = [Image(new_list)]

            return new_list
Example #11
0
 def _process_and_get(self, image_list, **feature_input) -> List[Image]:
     """Control how ``get`` is invoked for the given images.

     In distributed mode ``get`` is called on each image separately and the
     result takes over the properties of the image it came from. Otherwise
     ``get`` is called once on the whole list, and a non-list result is
     wrapped as ``[Image(result)]``.
     """
     if not self.__distributed__:
         # Call get on the entire list.
         outcome = self.get(image_list, **feature_input)
         return outcome if isinstance(outcome, list) else [Image(outcome)]

     # Call get per image and merge properties from the matching input.
     return [
         Image(self.get(item, **feature_input)).merge_properties_from(item)
         for item in image_list
     ]
Example #12
0
 def test_Background(self):
     """Background(offset=0.5) on an all-zero image gives 0.5 everywhere."""
     blank = Image(np.zeros((256, 256)))
     background = noises.Background(offset=0.5)

     result = background.resolve(blank)

     self.assertIsInstance(result, Image)
     self.assertEqual(result.shape, (256, 256))
     self.assertTrue(np.all(np.array(result) == 0.5))
Example #13
0
    def get(self, image, features=None, axis=None):
        """Resolve several features on one image and concatenate the results.

        Each feature in ``features`` is resolved with ``image`` as input and
        the outputs are joined with ``np.concatenate`` along ``axis``. The
        result's properties are the input's properties followed by whatever
        each resolved image holds beyond the input's original property count.
        """
        resolved = [feature.resolve(image) for feature in features]
        merged_image = Image(np.concatenate(resolved, axis=axis))

        image = Image(image)
        # Number of properties the input already had; anything past this
        # index in a resolved image is taken as added by its feature.
        inherited_count = len(image.properties)

        # Extended in place, so this is the same object as image.properties.
        combined = image.properties
        for branch in resolved:
            combined += branch.properties[inherited_count:]

        merged_image.properties = combined
        return merged_image
Example #14
0
    def _pupil(self,
               shape,
               NA,
               wavelength,
               refractive_index_medium,
               voxel_size,
               upscale,
               pupil,
               aberration=None,
               include_aberration=True,
               defocus=0,
               **kwargs):
        # Calculates the pupil at each z-position in defocus.
        shape = np.array(shape)

        # Pupil radius
        R = NA / wavelength * np.array(voxel_size)[:2]

        x_radius = R[0] * shape[0]
        y_radius = R[1] * shape[1]

        x = (np.linspace(-(shape[0] / 2), shape[0] / 2 - 1,
                         shape[0])) / x_radius + 1e-8
        y = (np.linspace(-(shape[1] / 2), shape[1] / 2 - 1,
                         shape[1])) / y_radius + 1e-8

        W, H = np.meshgrid(y, x)
        RHO = W**2 + H**2
        RHO[RHO > 1] = 1
        pupil_function = ((RHO < 1) * 1.0).astype(np.complex)
        # Defocus
        z_shift = (2 * np.pi * refractive_index_medium / wavelength *
                   voxel_size[2] *
                   np.sqrt(1 - (NA / refractive_index_medium)**2 * RHO))

        # Downsample the upsampled pupil

        pupil_function[np.isnan(pupil_function)] = 0
        pupil_function[np.isinf(pupil_function)] = 0
        pupil_function_is_nonzero = pupil_function != 0

        if include_aberration:
            pupil = pupil or aberration
            if isinstance(pupil, Feature):
                pupil_function = pupil.resolve(pupil_function, **kwargs)
            elif isinstance(pupil, np.ndarray):
                pupil_function *= pupil

        pupil_functions = []
        for z in defocus:
            pupil_at_z = Image(pupil_function)
            pupil_at_z[pupil_function_is_nonzero] *= np.exp(
                1j * z_shift[pupil_function_is_nonzero] * z)
            pupil_functions.append(pupil_at_z)

        return pupil_functions
Example #15
0
    def _format_input(self, image_list, **kwargs) -> List[Image]:
        """Normalise the input to a list of ``Image`` objects.

        ``None`` becomes an empty list, a non-list value becomes a
        single-element list, and every element is passed through ``Image``.
        """
        if image_list is None:
            return []

        items = image_list if isinstance(image_list, list) else [image_list]
        return [Image(item) for item in items]
Example #16
0
    def _process_and_get(self,
                         *args,
                         voxel_size,
                         upsample,
                         upsample_axes=None,
                         crop_empty=True,
                         **kwargs):
        """Create the object at higher resolution, then downsample and crop.

        Parameters
        ----------
        *args
            Forwarded to the parent ``_process_and_get``.
        voxel_size : sequence of float
            Voxel size; divided by ``upsample`` along each upsampled axis
            before being forwarded to the parent.
        upsample : int
            Upsampling factor.
        upsample_axes : iterable of int, optional
            Axes to upsample. Defaults to ``range(3)``.
        crop_empty : bool
            If True, remove all-zero slices along each of the three axes.
        **kwargs
            Forwarded to the parent ``_process_and_get``.

        Returns
        -------
        list of Image
            A single-element list with the post-processed image.
        """
        # Calculates upsampled voxel_size
        if upsample_axes is None:
            upsample_axes = range(3)

        # Float dtype so that the in-place division below also works when
        # voxel_size is given as integers.
        voxel_size = np.array(voxel_size, dtype=float)
        for axis in upsample_axes:
            voxel_size[axis] /= upsample

        # calls parent _process_and_get
        new_image = super()._process_and_get(*args,
                                             voxel_size=voxel_size,
                                             upsample=upsample,
                                             **kwargs)
        new_image = new_image[0]

        # Downsamples the image along the axes it was upsampled
        if upsample != 1 and upsample_axes:

            # Pad image to ensure it is divisible by upsample. Only upsampled
            # axes are padded, and only by the amount needed to reach the
            # next multiple (0 when already divisible).
            increase = np.zeros(new_image.ndim, dtype=int)
            for axis in upsample_axes:
                increase[axis] = (-new_image.shape[axis]) % upsample
            pad_width = [(0, inc) for inc in increase]
            new_image = np.pad(new_image, pad_width, mode='constant')

            # Finds reshape size for downsampling: each upsampled axis is
            # split into (blocks, upsample).
            new_shape = []
            for axis in range(new_image.ndim):
                if axis in upsample_axes:
                    new_shape += [new_image.shape[axis] // upsample, upsample]
                else:
                    new_shape += [new_image.shape[axis]]

            # Downsamples by averaging over the within-block axes.
            # NOTE(review): the `* 2 + 1` axis mapping assumes every axis
            # before an upsampled one is itself upsampled - confirm for
            # partial upsample_axes.
            new_image = np.reshape(new_image, new_shape).mean(
                axis=tuple(np.array(upsample_axes, dtype=np.int32) * 2 + 1))

        # Crops empty slices
        if crop_empty:
            new_image = new_image[~np.all(new_image == 0, axis=(1, 2))]
            new_image = new_image[:, ~np.all(new_image == 0, axis=(0, 2))]
            new_image = new_image[:, :, ~np.all(new_image == 0, axis=(0, 1))]

        return [Image(new_image)]
    def get(self,
            image,
            angle=None,
            axes=(1, 0),
            reshape=None,
            order=None,
            mode=None,
            cval=None,
            prefilter=None,
            **kwargs):
        """Rotate ``image`` with ``ndimage.rotate``.

        All named arguments are forwarded unchanged to ``ndimage.rotate``;
        extra ``kwargs`` are ignored. The returned image reuses the input's
        ``properties`` object.
        """
        rotate_options = {
            "axes": axes,
            "reshape": reshape,
            "order": order,
            "mode": mode,
            "cval": cval,
            "prefilter": prefilter,
        }
        rotated = ndimage.rotate(image, angle, **rotate_options)

        new_image = Image(rotated)
        new_image.properties = image.properties
        return new_image
Example #18
0
    def get(self, image, sample=None, objective=None, pupil=None, **kwargs):
        """Image a sample through an objective and add it to the input.

        The objective's current property values (overridden by ``kwargs``)
        are used to resolve ``sample`` and build a volume with
        ``create_volume``. The volume is imaged via ``objective.resolve``.

        Returns the imaged sample when ``image`` is falsy. Otherwise the
        imaged sample is added in place to each input image, and any of its
        properties whose ``"hash_key"`` is not already present on that image
        are appended. The input is returned as a list.
        """
        merged_kwargs = objective.properties.current_value_dict()
        merged_kwargs.update(kwargs)

        scatterers = sample.resolve(**merged_kwargs)
        if not isinstance(scatterers, list):
            scatterers = [scatterers]

        volume, limits = create_volume(scatterers, **merged_kwargs)
        sample_volume = Image(volume)
        for scatterer in scatterers:
            sample_volume.properties += scatterer.properties

        imaged_sample = objective.resolve(
            sample_volume, pupil=pupil, limits=limits
        )

        # Merge with input
        if not image:
            return imaged_sample

        if not isinstance(image, list):
            image = [image]

        for i in range(len(image)):
            image[i] += imaged_sample
            for prop in imaged_sample.properties:
                already_present = any(
                    prop["hash_key"] == existing["hash_key"]
                    for existing in image[i].properties
                )
                if already_present:
                    continue
                image[i].properties.append(prop)

        return image
Example #19
0
def _create_volume(list_of_scatterers,
                   pad=(0, 0, 0, 0),
                   upscaled_output_region=(None, None, None, None),
                   refractive_index_medium=1.33,
                   upscale=1,
                   **kwargs):
    # Converts a list of scatterers into a volume.

    if not isinstance(list_of_scatterers, list):
        list_of_scatterers = [list_of_scatterers]

    volume = np.zeros((1, 1, 1), dtype=np.complex)
    limits = None
    OR = np.zeros((4, ))
    OR[0] = np.inf if upscaled_output_region[0] is None else int(
        upscaled_output_region[0] - pad[0])
    OR[1] = -np.inf if upscaled_output_region[1] is None else int(
        upscaled_output_region[1] - pad[1])
    OR[2] = np.inf if upscaled_output_region[2] is None else int(
        upscaled_output_region[2] + pad[2])
    OR[3] = -np.inf if upscaled_output_region[3] is None else int(
        upscaled_output_region[3] + pad[3])

    for scatterer in list_of_scatterers:

        position = _get_position(scatterer, mode="corner", return_z=True)

        if scatterer.get_property("intensity", None) is not None:
            scatterer_value = scatterer.get_property("intensity")
        elif scatterer.get_property("refractive_index", None) is not None:
            scatterer_value = scatterer.get_property(
                "refractive_index") - refractive_index_medium
        else:
            scatterer_value = scatterer.get_property("value")

        scatterer = scatterer * scatterer_value

        if limits is None:
            limits = np.zeros((3, 2), dtype=np.int32)
            limits[:, 0] = np.floor(position).astype(np.int32)
            limits[:, 1] = np.floor(position).astype(np.int32) + 1

        if (position[0] + scatterer.shape[0] < OR[0] or position[0] > OR[2]
                or position[1] + scatterer.shape[1] < OR[1]
                or position[1] > OR[3]):
            continue

        padded_scatterer = Image(
            np.pad(scatterer, [(2, 2), (2, 2), (2, 2)],
                   'constant',
                   constant_values=0))
        padded_scatterer.properties = scatterer.properties
        scatterer = padded_scatterer
        position = _get_position(scatterer, mode="corner", return_z=True)
        shape = np.array(scatterer.shape)

        if position is None:
            RuntimeWarning(
                "Optical device received an image without a position property. It will be ignored."
            )
            continue

        splined_scatterer = np.zeros_like(scatterer)

        x_off = position[0] - np.floor(position[0])
        y_off = position[1] - np.floor(position[1])

        kernel = np.array([[0, 0, 0],
                           [0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off],
                           [0, x_off * (1 - y_off), x_off * y_off]])

        for z in range(scatterer.shape[2]):
            splined_scatterer[:, :, z] = convolve(scatterer[:, :, z],
                                                  kernel,
                                                  mode="constant")

        scatterer = splined_scatterer
        position = np.floor(position)
        new_limits = np.zeros(limits.shape, dtype=np.int32)
        for i in range(3):
            new_limits[i, :] = (
                np.min([limits[i, 0], position[i]]),
                np.max([limits[i, 1], position[i] + shape[i]]),
            )

        if not (np.array(new_limits) == np.array(limits)).all():
            new_volume = np.zeros(np.diff(new_limits,
                                          axis=1)[:, 0].astype(np.int32),
                                  dtype=np.complex)
            old_region = (limits - new_limits).astype(np.int32)
            limits = limits.astype(np.int32)
            new_volume[old_region[0, 0]:old_region[0, 0] + limits[0, 1] -
                       limits[0, 0], old_region[1, 0]:old_region[1, 0] +
                       limits[1, 1] - limits[1, 0],
                       old_region[2, 0]:old_region[2, 0] + limits[2, 1] -
                       limits[2, 0]] = volume
            volume = new_volume
            limits = new_limits

        within_volume_position = position - limits[:, 0]

        # NOTE: Maybe shouldn't be additive.
        volume[int(within_volume_position[0]):int(within_volume_position[0] +
                                                  shape[0]),
               int(within_volume_position[1]):int(within_volume_position[1] +
                                                  shape[1]),
               int(within_volume_position[2]):int(within_volume_position[2] +
                                                  shape[2])] += scatterer
    return volume, limits
Example #20
0
    def get(self, illuminated_volume, limits, fields, **kwargs):
        """Propagate light through the volume slice by slice and image it.

        Parameters
        ----------
        illuminated_volume : Image
            The sample volume; its properties are attached to the output.
        limits : array
            ``(3, 2)`` limits of the volume, refined by ``self._pad_volume``.
        fields : list
            Pre-computed fields added at their z-position. Entries consumed
            during propagation are removed from this list in place.
        **kwargs
            Must contain ``voxel_size``, ``wavelength`` and
            ``refractive_index_medium`` plus whatever ``self._pupil`` and
            ``self._pad_volume`` need. Optional: ``padding``,
            ``upscaled_output_region``, ``illumination``, ``return_field``.

        Returns
        -------
        Image
            The complex field if ``return_field`` is truthy, otherwise its
            squared magnitude.
        """
        # Pad volume
        padded_volume, limits = self._pad_volume(illuminated_volume,
                                                 limits=limits,
                                                 **kwargs)

        # Extract indexes of the output region
        pad = kwargs.get("padding", (0, 0, 0, 0))
        output_region = np.array(
            kwargs.get("upscaled_output_region", (None, None, None, None)))
        output_region[0] = None if output_region[0] is None else int(
            output_region[0] - limits[0, 0] - pad[0])
        output_region[1] = None if output_region[1] is None else int(
            output_region[1] - limits[1, 0] - pad[1])
        output_region[2] = None if output_region[2] is None else int(
            output_region[2] - limits[0, 0] + pad[2])
        output_region[3] = None if output_region[3] is None else int(
            output_region[3] - limits[1, 0] + pad[3])

        padded_volume = padded_volume[output_region[0]:output_region[2],
                                      output_region[1]:output_region[3], :]
        z_limits = limits[2, :]

        output_image = Image(np.zeros((*padded_volume.shape[0:2], 1)))

        index_iterator = range(padded_volume.shape[2])
        z_iterator = np.linspace(z_limits[0],
                                 z_limits[1],
                                 num=padded_volume.shape[2],
                                 endpoint=False)

        # Slices that are entirely zero skip the interaction step below.
        zero_plane = np.all(padded_volume == 0, axis=(0, 1), keepdims=False)

        volume = pad_image_to_fft(padded_volume, axes=(0, 1))

        voxel_size = kwargs['voxel_size']

        # pupils[0]: one-voxel propagation step without aberration;
        # pupils[-1]: refocusing pupil with aberration included.
        pupils = (self._pupil(
            volume.shape[:2], defocus=[1], include_aberration=False, **kwargs)
                  + self._pupil(volume.shape[:2],
                                defocus=[-z_limits[1]],
                                include_aberration=True,
                                **kwargs))

        pupil_step = np.fft.fftshift(pupils[0])

        # The builtin complex replaces the np.complex alias removed from
        # NumPy.
        if "illumination" in kwargs:
            light_in = np.ones(volume.shape[:2], dtype=complex)
            light_in = kwargs["illumination"].resolve(light_in, **kwargs)
            light_in = np.fft.fft2(light_in)
        else:
            light_in = np.zeros(volume.shape[:2], dtype=complex)
            light_in[0, 0] = light_in.size

        K = 2 * np.pi / kwargs["wavelength"]

        field_z = [_get_position(field, return_z=True)[-1] for field in fields]
        field_offsets = [
            field.get_property("offset_z", default=0) for field in fields
        ]

        # Fallback for the "remaining fields" step if the loop body never
        # runs; otherwise z ends as the last slice position.
        z = z_limits[1]
        for i, z in zip(index_iterator, z_iterator):
            light_in = light_in * pupil_step

            # Add every field located before the current slice.
            to_remove = []
            for idx, fz in enumerate(field_z):
                if fz < z:
                    propagation_matrix = self._pupil(
                        fields[idx].shape,
                        defocus=[z - fz - field_offsets[idx] / voxel_size[-1]],
                        include_aberration=False,
                        **kwargs)[0]
                    propagation_matrix = propagation_matrix * np.exp(
                        1j * voxel_size[-1] * 2 * np.pi / kwargs["wavelength"]
                        * kwargs["refractive_index_medium"] * (z - fz))
                    light_in += np.fft.fft2(
                        fields[idx][:, :,
                                    0]) * np.fft.fftshift(propagation_matrix)
                    to_remove.append(idx)

            # Pop from the back so earlier indices stay valid.
            for idx in reversed(to_remove):
                fields.pop(idx)
                field_z.pop(idx)
                field_offsets.pop(idx)

            if zero_plane[i]:
                continue

            ri_slice = volume[:, :, i]
            light = np.fft.ifft2(light_in)
            light_out = light * np.exp(1j * ri_slice * voxel_size[-1] * K)
            light_in = np.fft.fft2(light_out)

        # Add remaining fields
        for idx, fz in enumerate(field_z):
            prop_dist = z - fz - field_offsets[idx] / voxel_size[-1]
            propagation_matrix = self._pupil(fields[idx].shape,
                                             defocus=[prop_dist],
                                             include_aberration=False,
                                             **kwargs)[0]
            propagation_matrix = propagation_matrix * np.exp(
                -1j * voxel_size[-1] * 2 * np.pi / kwargs["wavelength"] *
                kwargs["refractive_index_medium"] * prop_dist)
            light_in += np.fft.fft2(
                fields[idx][:, :, 0]) * np.fft.fftshift(propagation_matrix)

        light_in_focus = light_in * np.fft.fftshift(pupils[-1])

        output_image = np.fft.ifft2(
            light_in_focus)[:padded_volume.shape[0], :padded_volume.shape[1]]
        output_image = np.expand_dims(output_image, axis=-1)
        # `-0` as a slice end would produce an empty result, so a zero pad
        # maps to an open-ended slice instead.
        output_image = Image(output_image[pad[0]:(-pad[2] or None),
                                          pad[1]:(-pad[3] or None)])

        if not kwargs.get("return_field", False):
            output_image = np.square(np.abs(output_image))

        output_image.properties = illuminated_volume.properties

        return output_image
Example #21
0
    def get(self, illuminated_volume, limits, **kwargs):
        """Convolve each non-empty z-slice of the volume with its PSF.

        Parameters
        ----------
        illuminated_volume : Image
            The sample volume; its properties are attached to the output.
        limits : array
            ``(3, 2)`` limits of the volume, refined by ``self._pad_volume``.
        **kwargs
            Forwarded to ``self._pad_volume`` and ``self._pupil``. Optional
            keys read here: ``padding`` and ``upscaled_output_region``.

        Returns
        -------
        Image
            Sum over slices of the slice convolved with the squared magnitude
            of the inverse-transformed pupil, cropped by ``padding``.
        """
        # Pad volume
        padded_volume, limits = self._pad_volume(illuminated_volume,
                                                 limits=limits,
                                                 **kwargs)

        # Extract indexes of the output region
        pad = kwargs.get("padding", (0, 0, 0, 0))
        output_region = np.array(
            kwargs.get("upscaled_output_region", (None, None, None, None)))
        output_region[0] = None if output_region[0] is None else int(
            output_region[0] - limits[0, 0] - pad[0])
        output_region[1] = None if output_region[1] is None else int(
            output_region[1] - limits[1, 0] - pad[1])
        output_region[2] = None if output_region[2] is None else int(
            output_region[2] - limits[0, 0] + pad[2])
        output_region[3] = None if output_region[3] is None else int(
            output_region[3] - limits[1, 0] + pad[3])

        padded_volume = padded_volume[output_region[0]:output_region[2],
                                      output_region[1]:output_region[3], :]
        z_limits = limits[2, :]

        output_image = Image(np.zeros((*padded_volume.shape[0:2], 1)))

        index_iterator = range(padded_volume.shape[2])

        # Get planes in volume where not all values are 0.
        z_iterator = np.linspace(z_limits[0],
                                 z_limits[1],
                                 num=padded_volume.shape[2],
                                 endpoint=False)
        zero_plane = np.all(padded_volume == 0, axis=(0, 1), keepdims=False)
        z_values = z_iterator[~zero_plane]

        # Further pad image to speed up fft
        volume = pad_image_to_fft(padded_volume, axes=(0, 1))

        # One pupil per non-empty plane, consumed in order below.
        pupils = self._pupil(volume.shape[:2], defocus=z_values, **kwargs)
        pupil_iterator = iter(pupils)

        # Stays None when every plane is empty; replaces catching
        # UnboundLocalError after the loop.
        pupil = None

        # Loop through volume and convolve sample with pupil function
        for i, z in zip(index_iterator, z_iterator):

            if zero_plane[i]:
                continue

            image = volume[:, :, i]
            pupil = Image(next(pupil_iterator))

            psf = np.square(np.abs(np.fft.ifft2(np.fft.fftshift(pupil))))
            optical_transfer_function = np.fft.fft2(psf)

            fourier_field = np.fft.fft2(image)
            convolved_fourier_field = fourier_field * optical_transfer_function

            field = Image(np.fft.ifft2(convolved_fourier_field))

            # Discard remaining imaginary part (should be 0 up to rounding error)
            field = np.real(field)

            output_image[:, :, 0] += field[:padded_volume.
                                           shape[0], :padded_volume.shape[1]]

        # `-0` as a slice end would produce an empty result, so a zero pad
        # maps to an open-ended slice instead.
        output_image = output_image[pad[0]:(-pad[2] or None),
                                    pad[1]:(-pad[3] or None)]

        if pupil is not None:
            output_image.properties = illuminated_volume.properties + pupil.properties
        else:
            output_image.properties = illuminated_volume.properties

        return output_image
Example #22
0
    def plot(
        self,
        input_image: Image or List[Image] = None,
        resolve_kwargs: dict = None,
        interval: float = None,
        **kwargs
    ):
        """Visualizes the output of the feature.

        Resolves the feature and visualizes the result. If the output is an Image,
        show it using `pyplot.imshow`. If the output is a list, create an `Animation`.
        For notebooks, the animation is played inline using `to_jshtml()`. For scripts,
        the animation is played using the matplotlib backend.

        Any parameters in kwargs will be passed to `pyplot.imshow`.

        Parameters
        ----------
        input_image : Image or List[Image], optional
            Passed as argument to `resolve` call
        resolve_kwargs : dict, optional
            Passed as kwarg arguments to `resolve` call
        interval : float
            The time between frames in animation in ms. Default 33.
        kwargs
            keyword arguments passed to the method pyplot.imshow()

        Returns
        -------
        The animation object when played inline in a notebook, the
        `ipywidgets.interact` result when the inline animation fails, and
        None otherwise.
        """

        import warnings

        import matplotlib.pyplot as plt
        import matplotlib.animation as animation

        if input_image is not None:
            input_image = [Image(input_image)]

        output_image = self.resolve(input_image, **(resolve_kwargs or {}))

        # If a list, assume video
        if isinstance(output_image, Image):
            # Single image
            plt.imshow(output_image[:, :, 0], **kwargs)
            plt.show()

        else:
            # Assume video
            fig = plt.figure()
            images = []
            plt.axis("off")
            for image in output_image:
                images.append([plt.imshow(image[:, :, 0], **kwargs)])

            interval = (
                interval or output_image[0].get_property("interval") or (1 / 30 * 1000)
            )

            anim = animation.ArtistAnimation(
                fig, images, interval=interval, blit=True, repeat_delay=0
            )

            try:
                get_ipython  # Throws NameError if not in Notebook
                # Imported only on the notebook path so that plotting from a
                # plain script does not require IPython to be installed.
                from IPython.display import HTML, display

                display(HTML(anim.to_jshtml()))
                return anim

            except NameError:
                # Not in a notebook
                plt.show()

            except RuntimeError:
                # In notebook, but animation failed
                import ipywidgets as widgets

                # Emit the warning; constructing Warning(...) alone is a no-op.
                warnings.warn(
                    "Javascript animation failed. This is a non-performant fallback."
                )

                def plotter(frame=0):
                    plt.imshow(output_image[frame][:, :, 0], **kwargs)
                    plt.show()

                return widgets.interact(
                    plotter,
                    frame=widgets.IntSlider(
                        value=0, min=0, max=len(images) - 1, step=1
                    ),
                )
Example #23
0
def _create_volume(list_of_scatterers,
                   pad=(0, 0, 0, 0),
                   output_region=(None, None, None, None),
                   refractive_index_medium=1.33,
                   **kwargs):
    """Merge a list of scatterers into one complex-valued volume.

    Each scatterer is scaled by its ``intensity`` property if present,
    otherwise by its ``refractive_index`` minus the refractive index of the
    medium, otherwise by its ``value`` property. It is then zero-padded by
    two pixels along the first two axes, resampled onto the integer pixel
    grid with a bivariate spline, and added to a volume that is grown on
    demand so that it encloses every included scatterer.

    Parameters
    ----------
    list_of_scatterers
        A scatterer, or a list of scatterers. A single scatterer is wrapped
        in a list.
    pad
        Four offsets applied to `output_region` when deciding whether a
        scatterer is close enough to the region to be included.
    output_region
        Four bounds: entries 0 and 2 are compared against axis 0, entries
        1 and 3 against axis 1. A ``None`` entry leaves that side unbounded.
    refractive_index_medium
        Subtracted from a scatterer's ``refractive_index`` property.
    kwargs
        Not used; accepted so that a full property dictionary can be passed.

    Returns
    -------
    tuple
        ``(volume, limits)``, where ``volume`` is a complex array and
        ``limits`` is a ``(3, 2)`` array holding the lower and upper limit
        of the volume along each axis.
    """
    import warnings

    missing_position_message = (
        "Optical device received a feature without a position property. It will be ignored."
    )

    if not isinstance(list_of_scatterers, list):
        list_of_scatterers = [list_of_scatterers]

    # np.complex was removed in NumPy 1.24; it was an alias of the builtin
    # complex, i.e. complex128.
    volume = np.zeros((1, 1, 1), dtype=np.complex128)

    # x, y, z limits of the volume
    limits = np.array([(0, 1), (0, 1), (0, 1)])

    for scatterer in list_of_scatterers:

        position = _get_position(scatterer, mode="corner", return_z=True)

        # Check before the position is indexed below, and actually emit the
        # warning instead of only constructing the exception object.
        if position is None:
            warnings.warn(missing_position_message, RuntimeWarning)
            continue

        if scatterer.get_property("intensity", None) is not None:
            scatterer_value = scatterer.get_property("intensity")
        elif scatterer.get_property("refractive_index", None) is not None:
            scatterer_value = scatterer.get_property(
                "refractive_index") - refractive_index_medium
        else:
            scatterer_value = scatterer.get_property("value")

        scatterer = scatterer * scatterer_value

        # Bounds of the region of interest. A missing bound must never
        # exclude a scatterer: lower bounds default to -inf and upper bounds
        # to +inf.
        lower_x = -np.inf if output_region[0] is None else int(
            output_region[0] - limits[0, 0] - pad[0])
        lower_y = -np.inf if output_region[1] is None else int(
            output_region[1] - limits[1, 0] - pad[1])
        upper_x = np.inf if output_region[2] is None else int(
            output_region[2] - limits[0, 0] + pad[2])
        upper_y = np.inf if output_region[3] is None else int(
            output_region[3] - limits[1, 0] + pad[3])

        # Skip scatterers that lie entirely outside the region.
        if (position[0] + scatterer.shape[0] < lower_x
                or position[0] > upper_x
                or position[1] + scatterer.shape[1] < lower_y
                or position[1] > upper_y):
            continue

        # Zero-pad in x and y before resampling; only the interior
        # ([1:-1, 1:-1]) of the resampled result is filled in below.
        padded_scatterer = Image(
            np.pad(scatterer, [(2, 2), (2, 2), (0, 0)],
                   'constant',
                   constant_values=0))
        padded_scatterer.properties = scatterer.properties
        scatterer = padded_scatterer

        # The corner position is recomputed for the padded scatterer.
        position = _get_position(scatterer, mode="corner", return_z=True)
        shape = np.array(scatterer.shape)

        if position is None:
            warnings.warn(missing_position_message, RuntimeWarning)
            continue

        # Sub-pixel sample coordinates and the integer grid they map to.
        x_pos = position[0] + np.arange(scatterer.shape[0])
        y_pos = position[1] + np.arange(scatterer.shape[1])

        target_x_pos = np.round(x_pos)
        target_y_pos = np.round(y_pos)

        is_complex = np.iscomplexobj(scatterer)

        splined_scatterer = np.zeros_like(scatterer)
        for z in range(scatterer.shape[2]):

            scatterer_spline = RectBivariateSpline(x_pos, y_pos,
                                                   np.real(scatterer[:, :, z]))
            splined_scatterer[1:-1, 1:-1,
                              z] = scatterer_spline(target_x_pos[1:-1],
                                                    target_y_pos[1:-1])

            # The spline is real-valued, so the imaginary part is resampled
            # separately.
            if is_complex:
                scatterer_spline = RectBivariateSpline(
                    x_pos, y_pos, np.imag(scatterer[:, :, z]))
                splined_scatterer[1:-1, 1:-1, z] += 1j * \
                    scatterer_spline(target_x_pos[1:-1], target_y_pos[1:-1])

        scatterer = splined_scatterer
        position = np.round(position)

        # Limits of a volume enclosing both the current volume and the
        # scatterer.
        new_limits = np.zeros(limits.shape, dtype=np.int32)
        for i in range(3):
            new_limits[i, :] = (
                np.min([limits[i, 0], position[i]]),
                np.max([limits[i, 1], position[i] + shape[i]]),
            )

        if not (np.array(new_limits) == np.array(limits)).all():
            # Grow the volume and copy the old content to its new offset.
            new_volume = np.zeros(np.diff(new_limits,
                                          axis=1)[:, 0].astype(np.int32),
                                  dtype=np.complex128)
            old_region = (limits - new_limits).astype(np.int32)
            limits = limits.astype(np.int32)
            old_start = old_region[:, 0]
            old_stop = old_start + limits[:, 1] - limits[:, 0]
            new_volume[old_start[0]:old_stop[0],
                       old_start[1]:old_stop[1],
                       old_start[2]:old_stop[2]] = volume
            volume = new_volume
            limits = new_limits

        within_volume_position = position - limits[:, 0]

        # NOTE: Maybe shouldn't be additive.
        volume[int(within_volume_position[0]):int(within_volume_position[0] +
                                                  shape[0]),
               int(within_volume_position[1]):int(within_volume_position[1] +
                                                  shape[1]),
               int(within_volume_position[2]):int(within_volume_position[2] +
                                                  shape[2])] += scatterer

    return volume, limits
Example #24
0
    def get(self, illuminated_volume, limits, **kwargs):
        """Propagate light through the volume and image it with the pupil.

        The volume is padded with ``self._pad_volume``, cropped to the
        output region and padded to an FFT-friendly size. Light is then
        propagated plane by plane along the last axis: it is multiplied by
        a one-step pupil in Fourier space and, for planes that are not
        entirely zero, by ``exp(1j * slice * voxel_size[-1] * K)`` in real
        space, with ``K = 2 * pi / wavelength``. Finally the field is
        multiplied by the focusing pupil and the padding is cropped away.

        Parameters
        ----------
        illuminated_volume
            The volume to image. Its properties are copied to the output.
        limits
            Array whose rows hold the lower and upper limit of the volume
            along each axis.
        kwargs
            Must contain ``voxel_size`` and ``wavelength``. The optional
            entries ``padding``, ``output_region``, ``illumination`` and
            ``return_field`` are also read, and everything is forwarded to
            ``self._pad_volume`` and ``self._pupil``.

        Returns
        -------
        Image
            The complex field if ``return_field`` is truthy, otherwise the
            squared magnitude of the field.
        """

        # Pad volume
        padded_volume, limits = self._pad_volume(illuminated_volume,
                                                 limits=limits,
                                                 **kwargs)

        # Extract indexes of the output region
        pad = kwargs.get("padding", (0, 0, 0, 0))
        output_region = np.array(
            kwargs.get("output_region", (None, None, None, None)))
        output_region[0] = None if output_region[0] is None else int(
            output_region[0] - limits[0, 0] - pad[0])
        output_region[1] = None if output_region[1] is None else int(
            output_region[1] - limits[1, 0] - pad[1])
        output_region[2] = None if output_region[2] is None else int(
            output_region[2] - limits[0, 0] + pad[2])
        output_region[3] = None if output_region[3] is None else int(
            output_region[3] - limits[1, 0] + pad[3])

        padded_volume = padded_volume[output_region[0]:output_region[2],
                                      output_region[1]:output_region[3], :]
        z_limits = limits[2, :]

        # Planes without any scatterer only need the free-space step.
        zero_plane = np.all(padded_volume == 0, axis=(0, 1), keepdims=False)

        volume = pad_image_to_fft(padded_volume, axes=(0, 1))

        voxel_size = kwargs['voxel_size']

        # pupils[0] advances the field by one step without aberration;
        # pupils[-1] applies the defocus -z_limits[1] with aberration.
        pupils = (self._pupil(
            volume.shape[:2], defocus=[1], include_aberration=False, **kwargs)
                  + self._pupil(volume.shape[:2],
                                defocus=[-z_limits[1]],
                                include_aberration=True,
                                **kwargs))

        pupil_step = np.fft.fftshift(pupils[0])

        if "illumination" in kwargs:
            light_in = np.ones(volume.shape[:2])
            light_in = kwargs["illumination"].resolve(light_in, **kwargs)
            light_in = np.fft.fft2(light_in)
        else:
            # Fourier transform of a uniform field of ones.
            light_in = np.zeros(volume.shape[:2])
            light_in[0, 0] = light_in.size

        K = 2 * np.pi / kwargs["wavelength"]

        for i in range(padded_volume.shape[2]):

            light_in = light_in * pupil_step

            if zero_plane[i]:
                continue

            ri_slice = volume[:, :, i]

            light = np.fft.ifft2(light_in)

            light_out = light * np.exp(1j * ri_slice * voxel_size[-1] * K)

            light_in = np.fft.fft2(light_out)

        light_in_focus = light_in * np.fft.fftshift(pupils[-1])

        output_image = np.fft.ifft2(
            light_in_focus)[:padded_volume.shape[0], :padded_volume.shape[1]]
        output_image = np.expand_dims(output_image, axis=-1)

        # Crop the padding using explicit end indices: a slice ending at
        # ``-pad`` is empty when the padding is 0.
        x_end = output_image.shape[0] - pad[2]
        y_end = output_image.shape[1] - pad[3]
        output_image = Image(output_image[pad[0]:x_end, pad[1]:y_end])

        if not kwargs.get("return_field", False):
            output_image = np.square(np.abs(output_image))

        output_image.properties = illuminated_volume.properties

        return output_image
Example #25
0
 def test_Gaussian(self):
     """Gaussian noise resolves to an Image with the shape of its input."""
     expected_shape = (256, 256)
     gaussian = noises.Gaussian(mu=0.1, sigma=0.05)
     blank = Image(np.zeros(expected_shape))
     noisy = gaussian.resolve(blank)
     self.assertIsInstance(noisy, Image)
     self.assertEqual(noisy.shape, expected_shape)
Example #26
0
    def get(self, image, scale, translate, rotate, shear, **kwargs):
        """Apply an affine transformation to an image and its positions.

        The pixel mapping is the product of a scaling, a rotation and a
        shear matrix, applied about the center of the first two axes with
        ``scipy.ndimage.affine_transform``. Every property dictionary of
        the image that has a ``"position"`` entry gets the first two
        position coordinates mapped with the inverse transformation; any
        further coordinates are kept unchanged.

        Parameters
        ----------
        image
            A 2- or 3-dimensional image. A 3-dimensional image is
            transformed channel by channel, in place.
        scale
            Pair of scaling factors ``(fx, fy)``.
        translate
            Pair of translations ``(dx, dy)``.
        rotate
            Rotation angle, passed to ``np.cos`` and ``np.sin``.
        shear
            Shear angle, passed to ``np.tan``.
        kwargs
            Forwarded to ``ndimage.affine_transform`` through
            ``utils.safe_call``. The keys ``input``, ``matrix``, ``offset``
            and ``output`` are removed first.

        Returns
        -------
        Image
            A new image for 2-dimensional input, the modified input for
            3-dimensional input.
        """

        assert (
            image.ndim == 2 or image.ndim == 3
        ), "Affine only supports 2-dimensional or 3-dimension inputs."

        dx, dy = translate
        fx, fy = scale

        cr = np.cos(rotate)
        sr = np.sin(rotate)

        k = np.tan(shear)

        scale_map = np.array([[1 / fx, 0], [0, 1 / fy]])
        rotation_map = np.array([[cr, sr], [-sr, cr]])
        shear_map = np.array([[1, 0], [-k, 1]])

        mapping = scale_map @ rotation_map @ shear_map

        shape = image.shape
        center = np.array(shape[:2]) / 2
        shift = np.array([dy, dx])

        # Offset that keeps the center fixed under the mapping, then
        # translates.
        d = center - np.dot(mapping, center) - shift

        # These arguments are set explicitly below and must not be
        # duplicated through kwargs.
        for reserved in ("input", "matrix", "offset", "output"):
            kwargs.pop(reserved, None)

        # Call affine_transform
        if image.ndim == 2:
            new_image = utils.safe_call(
                ndimage.affine_transform,
                input=image,
                matrix=mapping,
                offset=d,
                **kwargs
            )

            new_image = Image(new_image)
            new_image.merge_properties_from(image)
            image = new_image

        elif image.ndim == 3:
            for z in range(shape[-1]):
                image[:, :, z] = utils.safe_call(
                    ndimage.affine_transform,
                    input=image[:, :, z],
                    matrix=mapping,
                    offset=d,
                    **kwargs
                )

        # Map positions
        inverse_mapping = np.linalg.inv(mapping)
        for prop in image.properties:
            if "position" in prop:
                position = np.array(prop["position"])
                moved = inverse_mapping @ (position[:2] - center + shift) + center
                # Keep every coordinate after the first two; slicing from
                # index 3 would silently drop the third one.
                prop["position"] = np.array((*moved, *position[2:]))

        return image
Example #27
0
    def pupil(self,
              shape,
              NA=None,
              wavelength=None,
              refractive_index_medium=None,
              voxel_size=None,
              defocus=None,
              upscale=None,
              pupil=None,
              aberration=None,
              include_aberration=True,
              **kwargs):
        """Calculate the pupil function at one or more defocus values.

        The pupil is computed on a grid that is ``upscale`` times finer than
        `shape` and averaged back down to `shape`. It is 1 inside the
        normalized pupil radius and 0 outside. For each entry ``z`` of
        `defocus`, the non-zero part of the pupil is multiplied by
        ``exp(1j * z_shift * z)``.

        Parameters
        ----------
        shape
            The shape of the pupil function.
        NA
            Numerical aperture.
        wavelength
            Wavelength of the light.
        refractive_index_medium
            Refractive index of the medium.
        voxel_size
            Sequence of at least three elements; the first two set the pupil
            radius, the third scales the defocus phase.
        defocus
            Iterable of defocus values; one pupil is returned per value.
        upscale
            Integer oversampling factor of the internal grid.
        pupil, aberration
            Optional modification of the pupil, used when
            `include_aberration` is true. `pupil` takes precedence over
            `aberration`. A numpy array is multiplied onto the pupil; a
            Feature is resolved with the pupil as input.
        include_aberration
            Whether to apply `pupil` / `aberration`.
        kwargs
            Forwarded to ``resolve`` when the modification is a Feature.

        Returns
        -------
        list
            One pupil function, wrapped in Image, per entry of `defocus`.
        """
        shape = np.array(shape)

        upscaled_shape = shape * upscale
        # Pupil radius
        R = NA / wavelength * np.array(voxel_size)[:2]

        x_radius = R[0] * upscaled_shape[0]
        y_radius = R[1] * upscaled_shape[1]

        # The small offset keeps the coordinates away from exactly zero.
        x = (np.linspace(-(upscaled_shape[0] / 2), upscaled_shape[0] / 2 - 1,
                         upscaled_shape[0])) / x_radius + 1e-8
        y = (np.linspace(-(upscaled_shape[1] / 2), upscaled_shape[1] / 2 - 1,
                         upscaled_shape[1])) / y_radius + 1e-8

        W, H = np.meshgrid(y, x)
        RHO = W**2 + H**2
        RHO[RHO > 1] = 1
        # np.complex was removed in NumPy 1.24; it was an alias of the
        # builtin complex, i.e. complex128.
        pupil_function = ((RHO < 1) * 1.0).astype(np.complex128)

        # Defocus
        z_shift = (2 * np.pi * refractive_index_medium / wavelength *
                   voxel_size[2] *
                   np.sqrt(1 - (NA / refractive_index_medium * RHO)**2))

        # Downsample the upsampled pupil
        if upscale > 1:
            block_shape = (shape[0], upscale, shape[1], upscale)
            pupil_function = np.reshape(pupil_function,
                                        block_shape).mean(axis=(3, 1))
            z_shift = np.reshape(z_shift, block_shape).mean(axis=(3, 1))

        pupil_function[np.isnan(pupil_function)] = 0
        pupil_function[np.isinf(pupil_function)] = 0
        pupil_function_is_nonzero = pupil_function != 0

        if include_aberration:
            # ``pupil or aberration`` raises for a multi-element array
            # because its truth value is ambiguous, so only fall back to
            # `aberration` when `pupil` is not an array.
            if not isinstance(pupil, np.ndarray):
                pupil = pupil or aberration
            if isinstance(pupil, np.ndarray):
                pupil_function *= pupil
            elif isinstance(pupil, Feature):
                pupil_function = pupil.resolve(pupil_function, **kwargs)

        pupil_functions = []
        for z in defocus:
            pupil_at_z = Image(pupil_function)
            pupil_at_z[pupil_function_is_nonzero] *= np.exp(
                1j * z_shift[pupil_function_is_nonzero] * z)
            pupil_functions.append(pupil_at_z)

        return pupil_functions
Example #28
0
    def _process_and_get(self, images, **kwargs):
        """Build one label per positioned property and merge them into masks.

        For a single image, every property dictionary containing a
        ``"position"`` entry yields one label: a copy of the image carrying
        only that property is passed to ``self.get``. For a list with any
        other length than one, the labels come from the parent class. Each
        label is then placed into the output at its position(s), cropped to
        the output bounds, and combined per mask channel with the configured
        merge method.

        Parameters
        ----------
        images
            An image, or a list of images.
        kwargs
            Must contain ``output_region``, ``number_of_masks`` and
            ``merge_method``. ``merge_method`` is either one method or a
            list with one method per mask; a method is ``"add"``,
            ``"overwrite"``, ``"or"``, ``"mul"`` or a function taking the
            current output region and the label. All of kwargs is forwarded
            to ``self.get``.

        Returns
        -------
        Image
            Array of shape ``(output_region[2], output_region[3],
            number_of_masks)`` carrying the properties of every label.
        """
        if isinstance(images, list) and len(images) != 1:
            list_of_labels = super()._process_and_get(images, **kwargs)
        else:
            if isinstance(images, list):
                images = images[0]
            list_of_labels = []
            for prop in images.properties:

                if "position" in prop:

                    inp = Image(np.array(images))
                    inp.append(prop)
                    out = Image(self.get(inp, **kwargs))
                    out.merge_properties_from(inp)
                    list_of_labels.append(out)

        output_region = kwargs["output_region"]
        number_of_masks = kwargs["number_of_masks"]
        merge_method = kwargs["merge_method"]

        output = np.zeros((output_region[2], output_region[3], number_of_masks))

        for label in list_of_labels:
            positions = _get_position(label)
            for position in positions:
                p0 = np.round(position - output_region[0:2])

                # Skip labels that fall outside the output.
                if np.any(p0 > output.shape[0:2]) or np.any(p0 + label.shape[0:2] < 0):
                    continue

                # Part of the label that lies inside the output.
                crop_x = int(-np.min([p0[0], 0]))
                crop_y = int(-np.min([p0[1], 0]))
                crop_x_end = int(
                    label.shape[0]
                    - np.max([p0[0] + label.shape[0] - output.shape[0], 0])
                )
                crop_y_end = int(
                    label.shape[1]
                    - np.max([p0[1] + label.shape[1] - output.shape[1], 0])
                )

                labelarg = label[crop_x:crop_x_end, crop_y:crop_y_end, :]

                p0[0] = np.max([p0[0], 0])
                p0[1] = np.max([p0[1], 0])

                # np.int was removed in NumPy 1.24; it was an alias of the
                # builtin int.
                p0 = p0.astype(int)

                rows = slice(p0[0], p0[0] + labelarg.shape[0])
                cols = slice(p0[1], p0[1] + labelarg.shape[1])

                # Basic slicing returns a view, so writes to output_slice
                # land in output.
                output_slice = output[rows, cols]

                for label_index in range(number_of_masks):

                    if isinstance(merge_method, list):
                        merge = merge_method[label_index]
                    else:
                        merge = merge_method

                    label_channel = labelarg[..., label_index]

                    if merge == "add":
                        output[rows, cols, label_index] += label_channel

                    elif merge == "overwrite":
                        # Only non-zero label pixels replace the output.
                        nonzero = label_channel != 0
                        output_slice[nonzero, label_index] = labelarg[
                            nonzero, label_index
                        ]

                    elif merge == "or":
                        output[rows, cols, label_index] = (
                            output_slice[..., label_index] != 0
                        ) | (label_channel != 0)

                    elif merge == "mul":
                        output[rows, cols, label_index] *= label_channel

                    else:
                        # No match, assume function
                        output[rows, cols, label_index] = merge(
                            output_slice[..., label_index], label_channel
                        )

        output = Image(output)
        for label in list_of_labels:
            output.merge_properties_from(label)
        return output
Example #29
0
 def test_Poisson(self):
     """Poisson noise resolves to an Image with the shape of its input."""
     expected_shape = (256, 256)
     poisson = noises.Poisson(snr=20)
     dim_image = Image(np.ones(expected_shape) * 0.1)
     noisy = poisson.resolve(dim_image)
     self.assertIsInstance(noisy, Image)
     self.assertEqual(noisy.shape, expected_shape)
Example #30
0
    def plot(self, input_image=None, interval=None, **kwargs):
        """Resolve the feature and show the result.

        A single resolved image is shown with ``plt.imshow``. Any other
        result is treated as a sequence of frames and shown as an
        animation: as a Javascript animation inside a notebook, with
        ``plt.show`` outside a notebook, and with an ipywidgets slider if
        the Javascript animation fails.

        Parameters
        ----------
        input_image
            Optional image to resolve the feature on. It is wrapped in an
            Image and passed to ``self.resolve`` inside a list.
        interval
            Time between frames of the animation. If falsy, the
            ``"interval"`` property of the first frame is used, and if that
            is also falsy, ``1 / 30 * 1000``.
        kwargs
            keyword arguments passed to the method plt.imshow()

        Returns
        -------
        The value of ``widgets.interact`` when the slider fallback is used,
        otherwise None.
        """
        import warnings

        import matplotlib.pyplot as plt
        import matplotlib.animation as animation

        if input_image is not None:
            input_image = [Image(input_image)]

        output_image = self.resolve(input_image)

        # If a list, assume video
        if isinstance(output_image, Image):
            # Single image
            plt.imshow(output_image[:, :, 0], **kwargs)
            plt.show()

        else:
            # Assume video
            fig = plt.figure()
            images = []
            for image in output_image:
                images.append([plt.imshow(image[:, :, 0], **kwargs)])

            interval = (interval or get_property(output_image[0], "interval")
                        or (1 / 30 * 1000))

            anim = animation.ArtistAnimation(fig,
                                             images,
                                             interval=interval,
                                             blit=True,
                                             repeat_delay=0)

            try:
                get_ipython  # Throws NameError if not in Notebook

                # Imported only here so that plotting outside a notebook
                # does not require IPython to be installed.
                from IPython.display import HTML, display

                display(HTML(anim.to_jshtml()))

            except NameError:
                # Not in a notebook
                plt.show()

            except RuntimeError:
                # In notebook, but animation failed
                import ipywidgets as widgets

                # Emit the warning; constructing Warning(...) alone does
                # nothing.
                warnings.warn(
                    "Javascript animation failed. This is a non-performant fallback."
                )

                def plotter(frame=0):
                    plt.imshow(output_image[frame][:, :, 0], **kwargs)
                    plt.show()

                return widgets.interact(plotter,
                                        frame=widgets.IntSlider(
                                            value=0,
                                            min=0,
                                            max=len(images) - 1,
                                            step=1))