def test_too_many_color_channels(self):
    # Stacking two 3-channel images along ILLUM yields 6 colour channels,
    # which display_image must reject with BadImage.
    halves = [
        convert_ndarray_to_xarray(ARRAY_4D, extra_dims={ILLUM: channel_ids})
        for channel_ids in ([0, 1, 2], [3, 4, 5])
    ]
    six_channel_image = clean_concat(halves, dim=ILLUM)
    assert_raises(BadImage, display_image, six_channel_image)
def test_preserves_data_order(self):
    """clean_concat along 'point' keeps its inputs in their original order."""
    data = [make_data(seed=1), make_data(seed=2)]
    concatenated = clean_concat(data, 'point')
    for index, original in enumerate(data):
        # assert_array_equal also checks shapes (np.all(a == b) would
        # silently broadcast) and reports the mismatching elements on failure.
        np.testing.assert_array_equal(
            concatenated.values[..., index], original.values)
def test_preserves_metadata_keys(self):
    # Every metadata key set on the inputs must survive concatenation, both
    # as an entry in attrs and as an attribute of the result.
    inputs = [update_metadata(make_data(seed=seed), **METADATA_VALUES)
              for seed in (1, 2)]
    concatenated = clean_concat(inputs, 'point')
    for key in METADATA_VALUES:
        self.assertIn(key, concatenated.attrs)
        self.assertTrue(hasattr(concatenated, key))
def test_save_multiple_images_writes_image_files(self):
    # Only checks that one image file per hologram gets written; the
    # contents of those files are not validated here.
    savenames = [self._make_unused_filename_in_tempdir('png', index)
                 for index, _ in enumerate(self.holograms)]
    assert not any(os.path.exists(name) for name in savenames)
    save_plot(savenames, clean_concat(self.holograms, dim='z'))
    self.assertTrue(all(os.path.exists(name) for name in savenames))
def test_preserves_metadata_values(self):
    # Scalar metadata must come through concatenation unchanged; the
    # polarization is compared on its first two components only.
    inputs = [update_metadata(make_data(seed=seed), **METADATA_VALUES)
              for seed in (1, 2)]
    concatenated = clean_concat(inputs, 'point')
    for key, expected in METADATA_VALUES.items():
        if key == 'illum_polarization':
            continue
        self.assertEqual(getattr(concatenated, key), expected)
    expected_polarization = METADATA_VALUES['illum_polarization']
    self.assertTrue(np.all(
        concatenated.illum_polarization[:2] == expected_polarization))
def _calculate_multiple_color_scattered_field(self, scatterer, schema):
    """Compute the scattered field separately for each illumination colour.

    For each value on the ``illumination`` coordinate of
    ``schema.illum_wavelen``:

    - build a single-colour schema carrying that colour's wavelength and
      polarization;
    - compute the field for the matching ``scatterer.select(...)``.

    The per-colour fields are then concatenated along that same
    illumination coordinate.
    """
    wavelens = schema.illum_wavelen

    def single_colour_field(colour):
        colour_schema = update_metadata(
            schema,
            illum_wavelen=ensure_array(
                wavelens.sel(illumination=colour).values)[0],
            illum_polarization=ensure_array(
                schema.illum_polarization.sel(illumination=colour).values))
        return self._calculate_single_color_scattered_field(
            scatterer.select({illumination: colour}), colour_schema)

    fields = [single_colour_field(colour)
              for colour in wavelens.illumination.values]
    return clean_concat(fields, dim=wavelens.illumination)
def display_image(im, scaling='auto', vert_axis='x', horiz_axis='y',
                  depth_axis='z', colour_axis='illumination'):
    """Normalise an image into a (depth, vert, horiz[, colour]) DataArray.

    Parameters
    ----------
    im : xr.DataArray or array-like
        Image to display. A DataArray keeps its own dims, which name the
        axes. Anything else is converted with ``ensure_array`` and must be
        2D or 3D; a 2D array gains a length-1 leading axis.
    scaling : 'auto', None, or (lo, hi)
        'auto' uses the image's own (min, max). A (lo, hi) pair clips the
        image to that range and maps it linearly onto [0, 1]. None leaves
        values unscaled.
    vert_axis, horiz_axis, depth_axis : str or int
        For DataArray input, these are dimension names. For ndarray input,
        an int pins that array axis. Non-int values are inferred: depth is
        the shortest unclaimed axis, then vertical, then horizontal. After
        inference the ndarray is wrapped with ``data_grid`` and the axes
        are named 'z', 'x', 'y'.
    colour_axis : str
        Name of the dimension holding colour channels, if present.

    Returns
    -------
    xr.DataArray
        Transposed to (depth, vert, horiz[, colour]), truncated to the
        array's number of dims. The scaling applied is stored in
        ``attrs['_image_scaling']``. A 1-channel colour dim is squeezed
        away. A 2-channel image is padded with a constant third channel at
        the image minimum.

    Raises
    ------
    BadImage
        If the image has too many or too few dims, or more than 3 colour
        channels.
    ValueError
        If integer axis specifications are repeated or not in {0, 1, 2}.
    """
    im = im.copy()
    if isinstance(im, xr.DataArray):
        # Drop a length-1 'z' dim when some other axis is used for depth.
        if hasattr(im, 'z') and len(im['z']) == 1 and depth_axis != 'z':
            im = im[{'z': 0}]
        if depth_axis == 'z' and 'z' not in im.dims:
            im = im.expand_dims('z')
        # Three spatial dims are allowed, plus one more if it is the colour dim.
        if im.ndim > 3 + (colour_axis in im.dims):
            raise BadImage("Too many dims on DataArray to output properly.")
        attrs = im.attrs
    else:
        attrs = {}
        im = ensure_array(im)
        if im.ndim > 3:
            raise BadImage("Too many dims on ndarray to output properly.")
        elif im.ndim == 2:
            im = np.array([im])  # add a length-1 leading axis
        elif im.ndim < 2:
            raise BadImage("Too few dims on ndarray to output properly.")
        # Resolve which ndarray axis is depth/vertical/horizontal. Integer
        # specifications are claimed first; the others are inferred below.
        axes = [0, 1, 2]
        for axis in [vert_axis, horiz_axis, depth_axis]:
            if isinstance(axis, int):
                try:
                    axes.remove(axis)
                except ValueError:
                    # list.remove signals a missing (repeated or out-of-range)
                    # axis with ValueError, not KeyError.
                    raise ValueError(
                        "Cannot interpret axis specifications.") from None
        if len(axes) > 0:
            if not isinstance(depth_axis, int):
                # Depth defaults to the shortest unclaimed axis.
                depth_axis = axes[np.argmin([im.shape[i] for i in axes])]
                axes.remove(depth_axis)
            if not isinstance(vert_axis, int):
                vert_axis = axes[0]
                axes.pop(0)
            if not isinstance(horiz_axis, int):
                horiz_axis = axes[0]
        im = im.transpose([depth_axis, vert_axis, horiz_axis])
        depth_axis = 'z'
        vert_axis = 'x'
        horiz_axis = 'y'
        im = data_grid(im, spacing=1, z=range(len(im)))
    if np.iscomplex(im).any():
        warn("Image contains complex values. Taking image magnitude.")
        im = np.abs(im)
    # Compare by value: `scaling is 'auto'` depends on string interning and
    # is a SyntaxWarning on Python 3.8+.
    if isinstance(scaling, str) and scaling == 'auto':
        scaling = (ensure_scalar(im.min()), ensure_scalar(im.max()))
    if scaling is not None:
        # Clip to [lo, hi], then map that range linearly onto [0, 1].
        # NOTE(review): a constant image gives lo == hi and divides by zero.
        im = np.maximum(im, scaling[0])
        im = np.minimum(im, scaling[1])
        im = (im - scaling[0]) / (scaling[1] - scaling[0])
    # The arithmetic above may not carry attrs through, so reinstate them.
    im.attrs = attrs
    im.attrs['_image_scaling'] = scaling
    if colour_axis in im.dims:
        # First letter of each channel label, capitalised (' ' if not a str).
        cols = [col[0].capitalize() if isinstance(col, str) else ' '
                for col in im[colour_axis].values]
        RGB_names = np.all([letter in 'RGB' for letter in cols])
        if len(im[colour_axis]) == 1:
            im = im.squeeze(dim=colour_axis)
        elif len(im[colour_axis]) > 3:
            raise BadImage('Cannot output more than 3 colour channels')
        elif RGB_names:
            # Channels are labelled with R/G/B names: reorder them as R, G, B.
            channels = {col: im[{colour_axis: i}] for i, col in enumerate(cols)}
            if len(channels) == 2:
                # Fill the missing colour with a constant channel at the
                # image minimum and record which position it occupies.
                dummy = im[{colour_axis: 0}].copy()
                dummy[:] = im.min()
                for i, col in enumerate('RGB'):
                    if col not in cols:
                        dummy[colour_axis] = col
                        channels[col] = dummy
                        channels['R'].attrs['_dummy_channel'] = i
                        break
            channels = [channels[col] for col in 'RGB']
            im = clean_concat(channels, colour_axis)
        elif len(im[colour_axis]) == 2:
            # Unlabelled 2-channel image: append a constant third channel
            # (coordinate NaN) at the image minimum.
            dummy = xr.full_like(im[{colour_axis: 0}], fill_value=im.min())
            # np.NaN was removed in NumPy 2.0; np.nan is the portable spelling.
            dummy = dummy.expand_dims({colour_axis: [np.nan]})
            im.attrs['_dummy_channel'] = -1
            im = clean_concat([im, dummy], colour_axis)
    dim_order = [depth_axis, vert_axis, horiz_axis, colour_axis][:im.ndim]
    return im.transpose(*dim_order)
def test_concatenates_data(self):
    # Concatenating two datasets along a new 'point' dim appends a trailing
    # axis of length 2 to the original shape.
    first, second = make_data(seed=1), make_data(seed=2)
    result = clean_concat([first, second], 'point')
    expected_shape = (*first.shape, 2)
    self.assertEqual(result.shape, expected_shape)