Esempio n. 1
0
 def test_too_many_color_channels(self):
     # Build two images whose ILLUM coordinates are [0, 1, 2] and [3, 4, 5];
     # stacked along ILLUM they carry six channels, which display_image
     # is expected to reject with BadImage.
     first_three, last_three = (
         convert_ndarray_to_xarray(ARRAY_4D, extra_dims={ILLUM: labels})
         for labels in ([0, 1, 2], [3, 4, 5]))
     six_channels = clean_concat([first_three, last_three], dim=ILLUM)
     assert_raises(BadImage, display_image, six_channels)
Esempio n. 2
0
    def test_preserves_data_order(self):
        # Input i must appear at index i along the last axis of the result.
        data = [make_data(seed=seed) for seed in (1, 2)]

        concatenated = clean_concat(data, 'point')
        for index, original in enumerate(data):
            self.assertTrue(
                np.all(concatenated.values[..., index] == original.values))
Esempio n. 3
0
    def test_preserves_metadata_keys(self):
        # Every metadata key must survive concatenation, both as an entry in
        # attrs and as an attribute of the resulting object.
        data = [update_metadata(make_data(seed=seed), **METADATA_VALUES)
                for seed in (1, 2)]

        concatenated = clean_concat(data, 'point')
        for key in METADATA_VALUES:
            self.assertIn(key, concatenated.attrs)
            self.assertTrue(hasattr(concatenated, key))
Esempio n. 4
0
 def test_save_multiple_images_writes_image_files(self):
     # Only checks that one image file per hologram gets written; the
     # contents of those files are not verified.
     savenames = [
         self._make_unused_filename_in_tempdir('png', index)
         for index, _hologram in enumerate(self.holograms)
     ]
     assert not any(os.path.exists(name) for name in savenames)
     stacked = clean_concat(self.holograms, dim='z')
     save_plot(savenames, stacked)
     self.assertTrue(all(os.path.exists(name) for name in savenames))
Esempio n. 5
0
    def test_preserves_metadata_values(self):
        # Metadata values must pass through concatenation unchanged;
        # polarization is compared on its first two entries only.
        data = [update_metadata(make_data(seed=seed), **METADATA_VALUES)
                for seed in (1, 2)]

        concatenated = clean_concat(data, 'point')
        for key, expected in METADATA_VALUES.items():
            if key == 'illum_polarization':
                continue
            self.assertEqual(getattr(concatenated, key), expected)
        leading_polarization = concatenated.illum_polarization[:2]
        self.assertTrue(
            np.all(leading_polarization ==
                   METADATA_VALUES['illum_polarization']))
Esempio n. 6
0
 def _calculate_multiple_color_scattered_field(self, scatterer, schema):
     """Compute the scattered field separately for each illumination.

     For every label on ``schema.illum_wavelen``'s ``illumination``
     coordinate, build a single-colour schema from that illumination's
     wavelength and polarization. Compute its field with
     ``_calculate_single_color_scattered_field`` on the matching selection
     of ``scatterer``. Finally, concatenate the per-colour fields along the
     illumination coordinate.
     """
     per_colour_fields = []
     for label in schema.illum_wavelen.illumination.values:
         wavelength = ensure_array(
             schema.illum_wavelen.sel(illumination=label).values)[0]
         polarization = ensure_array(
             schema.illum_polarization.sel(illumination=label).values)
         single_schema = update_metadata(
             schema,
             illum_wavelen=wavelength,
             illum_polarization=polarization)
         selected = scatterer.select({illumination: label})
         per_colour_fields.append(
             self._calculate_single_color_scattered_field(
                 selected, single_schema))
     return clean_concat(per_colour_fields,
                         dim=schema.illum_wavelen.illumination)
Esempio n. 7
0
def display_image(im,
                  scaling='auto',
                  vert_axis='x',
                  horiz_axis='y',
                  depth_axis='z',
                  colour_axis='illumination'):
    """Prepare an image for display as a rescaled, consistently ordered DataArray.

    Parameters
    ----------
    im : xr.DataArray or array-like
        Image to display. Non-DataArray input goes through ``ensure_array``
        and must be 2D or 3D; a 2D array becomes a stack of one slice. It
        is then wrapped with ``data_grid``.
    scaling : 'auto', None, or (low, high)
        ``'auto'`` uses the image's own min and max. A ``(low, high)`` pair
        clips values to that range and maps it linearly onto [0, 1].
        ``None`` leaves values unscaled. The value actually used is stored
        in ``attrs['_image_scaling']``.
    vert_axis, horiz_axis, depth_axis : str or int
        Dimension names for a DataArray. For array input, ints pick axes
        explicitly. Any non-int entry is filled from the remaining axes:
        the shortest remaining axis becomes depth, then vertical, then
        horizontal.
    colour_axis : str
        Name of the colour-channel dimension, if present. One channel is
        squeezed away. Labels starting with R/G/B are reordered into RGB
        order, and a missing colour is filled with a constant channel. Two
        channels with other labels are padded with a constant third channel.

    Returns
    -------
    xr.DataArray
        Transposed to ``(depth_axis, vert_axis, horiz_axis[, colour_axis])``.

    Raises
    ------
    BadImage
        If the image has too many or too few dimensions, or more than three
        colour channels.
    ValueError
        If integer axis specifications for array input are out of range or
        repeated.
    """
    im = im.copy()
    if isinstance(im, xr.DataArray):
        # A singleton z dimension is dropped unless z is the depth axis.
        # (`!=`, not `is not`: identity comparison of strings is unreliable.)
        if hasattr(im, 'z') and len(im['z']) == 1 and depth_axis != 'z':
            im = im[{'z': 0}]
        if depth_axis == 'z' and 'z' not in im.dims:
            im = im.expand_dims('z')
        # At most three spatial dims, plus the colour dim if present.
        if im.ndim > 3 + (colour_axis in im.dims):
            raise BadImage("Too many dims on DataArray to output properly.")
        attrs = im.attrs
    else:
        attrs = {}
        im = ensure_array(im)
        if im.ndim > 3:
            raise BadImage("Too many dims on ndarray to output properly.")
        elif im.ndim == 2:
            # Promote a single 2D slice to a stack of one.
            im = np.array([im])
        elif im.ndim < 2:
            raise BadImage("Too few dims on ndarray to output properly.")
        # Decide which array axis plays each role. Integer specifications
        # claim their axis first; the rest are filled from what remains.
        axes = [0, 1, 2]
        for axis in [vert_axis, horiz_axis, depth_axis]:
            if isinstance(axis, int):
                try:
                    axes.remove(axis)
                except ValueError:
                    # list.remove reports an out-of-range or already-claimed
                    # axis with ValueError (not KeyError).
                    raise ValueError(
                        "Cannot interpret axis specifications.") from None
        if len(axes) > 0:
            if not isinstance(depth_axis, int):
                # The shortest remaining axis is taken as depth.
                depth_axis = axes[np.argmin([im.shape[i] for i in axes])]
                axes.remove(depth_axis)
            if not isinstance(vert_axis, int):
                vert_axis = axes[0]
                axes.pop(0)
            if not isinstance(horiz_axis, int):
                horiz_axis = axes[0]
        im = im.transpose([depth_axis, vert_axis, horiz_axis])
        # Axes are now (depth, vert, horiz). From here on they are referred
        # to by name; this assumes data_grid labels them z, x, y.
        depth_axis = 'z'
        vert_axis = 'x'
        horiz_axis = 'y'
        im = data_grid(im, spacing=1, z=range(len(im)))
    if np.iscomplex(im).any():
        warn("Image contains complex values. Taking image magnitude.")
        im = np.abs(im)
    # The isinstance guard keeps array-valued `scaling` from triggering an
    # elementwise comparison.
    if isinstance(scaling, str) and scaling == 'auto':
        scaling = (ensure_scalar(im.min()), ensure_scalar(im.max()))
    if scaling is not None:
        # Clip to [low, high], then map that range onto [0, 1].
        # NOTE(review): low == high (e.g. a constant image with 'auto')
        # divides by zero here.
        im = np.maximum(im, scaling[0])
        im = np.minimum(im, scaling[1])
        im = (im - scaling[0]) / (scaling[1] - scaling[0])
    # Reattach the attrs captured above; the arithmetic may not carry them.
    im.attrs = attrs
    im.attrs['_image_scaling'] = scaling

    if colour_axis in im.dims:
        # Upper-cased first letter of each channel label. Non-string labels
        # become ' ' so they never match R, G or B.
        cols = [
            col[0].capitalize() if isinstance(col, str) else ' '
            for col in im[colour_axis].values
        ]
        RGB_names = np.all([letter in 'RGB' for letter in cols])
        if len(im[colour_axis]) == 1:
            im = im.squeeze(dim=colour_axis)
        elif len(im[colour_axis]) > 3:
            raise BadImage('Cannot output more than 3 colour channels')
        elif RGB_names:
            # Index channels by colour letter so they can be emitted in
            # R, G, B order.
            channels = {
                col: im[{
                    colour_axis: i
                }]
                for i, col in enumerate(cols)
            }
            if len(channels) == 2:
                # Fill the missing colour with a constant channel at the
                # image minimum. Record its position in 'RGB' on the
                # 'R' channel's attrs.
                dummy = im[{colour_axis: 0}].copy()
                dummy[:] = im.min()
                for i, col in enumerate('RGB'):
                    if col not in cols:
                        dummy[colour_axis] = col
                        channels[col] = dummy
                        channels['R'].attrs['_dummy_channel'] = i
                        break
            channels = [channels[col] for col in 'RGB']
            im = clean_concat(channels, colour_axis)
        elif len(im[colour_axis]) == 2:
            # Two channels without R/G/B labels: append a constant third
            # channel (coordinate NaN) at the end, marked with index -1.
            dummy = xr.full_like(im[{colour_axis: 0}], fill_value=im.min())
            # np.nan, not np.NaN: the alias was removed in NumPy 2.0.
            dummy = dummy.expand_dims({colour_axis: [np.nan]})
            im.attrs['_dummy_channel'] = -1
            im = clean_concat([im, dummy], colour_axis)
    dim_order = [depth_axis, vert_axis, horiz_axis, colour_axis][:im.ndim]
    return im.transpose(*dim_order)
Esempio n. 8
0
    def test_concatenates_data(self):
        # Concatenating two inputs along 'point' should append a trailing
        # axis of length 2 to the input shape.
        first, second = make_data(seed=1), make_data(seed=2)

        concatenated = clean_concat([first, second], 'point')
        expected_shape = first.shape + (2, )
        self.assertEqual(concatenated.shape, expected_shape)