def _add_attr_to_loaded_data(attr_name: str,
                             data: sc.Variable,
                             value: np.ndarray,
                             unit: sc.Unit,
                             dtype: Optional[Any] = None):
    """
    Store ``value`` under ``attr_name`` on loaded data.

    The entry goes into ``data.attrs`` when ``data`` has an ``attrs``
    property, and into ``data`` itself otherwise. ``dtype`` selects how the
    value is wrapped:

    - vector3, affine_transform3, linear_transform3, rotation3 and
      translation3 use their dedicated constructors;
    - any other non-None dtype is passed to ``sc.scalar``;
    - ``None`` lets ``sc.scalar`` infer the dtype.

    A ``KeyError`` raised while building or storing the entry is swallowed,
    in which case nothing is added.

    :param attr_name: Name of the entry to add.
    :param data: Container receiving the entry (its ``attrs`` if present).
    :param value: Raw value to wrap.
    :param unit: Unit of the value.
    :param dtype: Optional scipp dtype controlling the wrapping.
    :raises sc.UnitError: If ``dtype`` is rotation3 and ``unit`` is not
        dimensionless.
    """
    try:
        target = data.attrs
    except AttributeError:
        # Nothing called ``attrs``: write into the object itself.
        target = data

    try:
        if dtype is None:
            entry = sc.scalar(value=value, unit=unit)
        elif dtype == sc.DType.vector3:
            entry = sc.vector(value=value, unit=unit)
        elif dtype == sc.DType.affine_transform3:
            entry = affine_transform(value=value, unit=unit)
        elif dtype == sc.DType.linear_transform3:
            entry = linear_transform(value=value, unit=unit)
        elif dtype == sc.DType.rotation3:
            # The rotation constructor takes no unit, so insist on "one".
            if unit != sc.units.one:
                raise sc.UnitError(
                    f'Rotations must be dimensionless, got unit {unit}')
            entry = rotation(value=value)
        elif dtype == sc.DType.translation3:
            entry = translation(value=value, unit=unit)
        else:
            entry = sc.scalar(value=value, dtype=dtype, unit=unit)
        target[attr_name] = entry
    except KeyError:
        pass
def _fit_workspace(ws, mantid_args):
    """
    Run Mantid's ``Fit`` algorithm on a workspace and convert its outputs.

    :param ws: The workspace on which the fit will be performed.
    :param mantid_args: Extra keyword arguments forwarded to ``Fit``.
    :returns: A tuple ``(parameters, data)``.
        ``parameters`` is the ``OutputParameters`` table passed through
        ``_table_to_data_array`` (Name/Value/Error columns). It carries scalar
        coords 'status', 'chi^2/d.o.f.', 'function' and 'cost_function'.
        ``data`` is a Dataset whose 'data', 'calculated' and 'diff' items are
        slices 0, 1 and 2 along the 'empty' dimension of ``OutputWorkspace``.
    """
    with run_mantid_alg('Fit',
                        InputWorkspace=ws,
                        **mantid_args,
                        CreateOutput=True) as fit:
        # This is assuming that all parameters are dimensionless. If this is
        # not the case we should use a dataset with a scalar variable per
        # parameter instead. Or better, a dict of scalar variables?
        table = convert_TableWorkspace_to_dataset(fit.OutputParameters)
        parameters = _table_to_data_array(table,
                                          key='Name',
                                          value='Value',
                                          stddev='Error')
        fitted = convert_Workspace2D_to_data_array(fit.OutputWorkspace)

        data = sc.Dataset()
        for index, name in enumerate(('data', 'calculated', 'diff')):
            data[name] = fitted['empty', index]

        summary = {
            'status': fit.OutputStatus,
            'chi^2/d.o.f.': fit.OutputChi2overDoF,
            'function': str(fit.Function),
            'cost_function': fit.CostFunction,
        }
        for key, item in summary.items():
            parameters.coords[key] = sc.scalar(item)
        return parameters, data
def _exec_to_spherical(self, x, y, z):
    """
    Run ``scn.mantid._to_spherical`` on a single Cartesian point.

    A Dataset holding scalar 'x', 'y' and 'z' entries in metres is built,
    and a position vector is formed from them. Both are handed to
    ``_to_spherical``, which is expected to add its results to the dataset
    (NOTE(review): confirm against ``_to_spherical``).

    :returns: The dataset after ``_to_spherical`` has been applied.
    """
    in_out = sc.Dataset()
    for name, component in (('x', x), ('y', y), ('z', z)):
        in_out[name] = sc.scalar(component, unit=sc.units.m)
    point = sc.geometry.position(*(in_out[name].data for name in 'xyz'))
    scn.mantid._to_spherical(point, in_out)
    return in_out
def _convert_time_to_datetime64(
        raw_times: sc.Variable,
        group_path: str,
        start: Optional[str] = None,
        scaling_factor: Optional[Union[float, np.float64]] = None
) -> sc.Variable:
    """
    The nexus standard allows an arbitrary scaling factor to be inserted
    between the numbers in the `time` series and the unit of time reported
    in the nexus attribute.

    The times are also relative to a given log start time, which might be
    different for each log. If this log start time is not available, the start
    of the unix epoch (1970-01-01T00:00:00Z) is used instead.

    See https://manual.nexusformat.org/classes/base_classes/NXlog.html

    Args:
        raw_times: The raw time data from a nexus file.
        group_path: The path within the nexus file to the log being read.
            Used to generate warnings if loading the log fails.
        start: Optional, the start time of the log in an ISO8601
            string. If not provided, defaults to the beginning of the
            unix epoch (1970-01-01T00:00:00Z).
        scaling_factor: Optional, the scaling factor between the provided
            time series data and the unit of the raw_times Variable. If
            not provided, defaults to 1 (a no-op scaling factor).

    Returns:
        Absolute times: the start timestamp (datetime64, ns) plus the scaled
        raw times converted to integer nanoseconds.

    Raises:
        BadSource: If ``raw_times`` cannot be converted to nanoseconds, or
            ``start`` cannot be parsed by ``np.datetime64``.
    """
    try:
        raw_times_ns = sc.to_unit(raw_times, sc.units.ns, copy=False)
    except sc.UnitError as err:
        raise BadSource(
            f"The units of time in the entry at "
            f"'{group_path}/time{{units}}' must be convertible to seconds, "
            f"but this cannot be done for '{raw_times.unit}'. Skipping "
            f"loading group at '{group_path}'.") from err

    try:
        _start_ts = sc.scalar(value=np.datetime64(start
                                                  or "1970-01-01T00:00:00Z"),
                              unit=sc.units.ns,
                              dtype=sc.DType.datetime64)
    except ValueError as err:
        raise BadSource(
            f"The date string '{start}' in the entry at "
            f"'{group_path}/time@start' failed to parse as an ISO8601 date. "
            f"Skipping loading group at '{group_path}'") from err

    _scale = sc.scalar(
        value=scaling_factor if scaling_factor is not None else 1,
        unit=sc.units.dimensionless)

    # Scale first, then truncate to int64 ns: datetime arithmetic in scipp
    # needs an integer offset.
    return _start_ts + (raw_times_ns * _scale).astype(sc.DType.int64,
                                                      copy=False)
def make_beamline(
    sample_rotation: sc.Variable = None,
    beam_size: sc.Variable = 0.001 * sc.units.m,
    sample_size: sc.Variable = 0.01 * sc.units.m,
    detector_spatial_resolution: sc.Variable = 0.0025 * sc.units.m,
    gravity: sc.Variable = sc.vector(value=[0, -1, 0]) * g,
    chopper_frequency: sc.Variable = sc.scalar(20 / 3, unit='Hz'),
    chopper_phase: sc.Variable = sc.scalar(-8.0, unit='deg'),
    chopper_position: sc.Variable = sc.vector(value=[0, 0, -15.0], unit='m')
) -> dict:
    """
    Amor beamline components.

    :param sample_rotation: Sample rotation (omega) angle. Default is `None`.
    :type sample_rotation: Variable.
    :param beam_size: Size of the beam perpendicular to the scattering
        surface. Default is `0.001 m`.
    :param sample_size: Size of the sample in direction of the beam.
        Default is `0.01 m`.
    :param detector_spatial_resolution: Spatial resolution of the detector.
        Default is `2.5 mm`.
    :param gravity: Vector representing the direction and magnitude of the
        Earth's gravitational field. Default is `[0, -g, 0]`.
    :param chopper_frequency: Rotational frequency of the chopper.
        Default is `6.6666... Hz`.
    :param chopper_phase: Phase offset between chopper pulse and ToF zero.
        Default is `-8. degrees of arc`.
    :param chopper_position: Position vector of the chopper.
        Default is `[0, 0, -15] m`.
    :returns: A dict with the keys 'sample_rotation', 'beam_size',
        'sample_size', 'detector_spatial_resolution', 'gravity' and
        'source_chopper'. The last is a scalar wrapping the Dataset built by
        ``make_chopper``.
    :rtype: dict
    """
    # TODO: in scn.load_nexus, the chopper parameters are stored as coordinates
    # of a DataArray, and the data value is a string containing the name of the
    # chopper. This does not allow storing e.g. chopper cutout angles.
    # We should change this to be a Dataset, which is what we do here.
    source_chopper = make_chopper(frequency=chopper_frequency,
                                  phase=chopper_phase,
                                  position=chopper_position)
    return {
        'sample_rotation': sample_rotation,
        'beam_size': beam_size,
        'sample_size': sample_size,
        'detector_spatial_resolution': detector_spatial_resolution,
        'gravity': gravity,
        "source_chopper": sc.scalar(source_chopper),
    }
def make_variables_from_run_logs(ws):
    """
    Yield ``(property_name, variable)`` pairs for every run-log property.

    Units:
        Units come from ``additional_unit_mapping`` when the log's unit
        string is a key there. Otherwise the string is parsed with
        ``sc.Unit``. A non-empty string that yields no unit triggers a
        warning, and such logs (and logs without units) become
        dimensionless.

    Shapes:
        - Time-series logs (those exposing ``.times``) are yielded as a
          scalar wrapping a DataArray with a "time" coordinate.
        - Other multi-valued logs are yielded as a scalar wrapping the 1-D
          variable, so their dimension does not clash with the dimensions
          of an output Dataset.
        - Scalar logs are yielded as plain scalar variables.

    :param ws: Mantid workspace whose ``run()`` logs are converted.
    """
    run = ws.run()
    for property_name in run.keys():
        log = run[property_name]
        units_string = log.units
        # Consult the explicit mapping *before* parsing. ``dict.get`` with a
        # ``sc.Unit(...)`` default would parse eagerly. A mapped string that
        # sc.Unit cannot parse would then raise, bypass the mapping, and
        # trigger a spurious "unrecognised units" warning.
        if units_string in additional_unit_mapping:
            unit = additional_unit_mapping[units_string]
        else:
            try:
                unit = sc.Unit(units_string)
            except RuntimeError:  # TODO catch UnitError once exposed from C++
                # Parsing unit string failed
                unit = None

        values = deepcopy(log.value)

        if units_string and unit is None:
            warnings.warn(f"Workspace run log '{property_name}' "
                          f"has unrecognised units: '{units_string}'")
        if unit is None:
            unit = sc.units.one

        try:
            times = deepcopy(log.times)
            is_time_series = True
            dimension_label = "time"
        except AttributeError:
            times = None
            is_time_series = False
            dimension_label = property_name

        is_scalar = np.isscalar(values)
        if is_scalar:
            property_data = sc.scalar(values, unit=unit)
        else:
            property_data = sc.Variable(values=values,
                                        unit=unit,
                                        dims=[dimension_label])

        if is_time_series:
            # If property has timestamps, create a DataArray
            data_array = sc.DataArray(data=property_data,
                                      coords={
                                          dimension_label:
                                          sc.Variable(dims=[dimension_label],
                                                      values=times)
                                      })
            yield property_name, sc.scalar(data_array)
        elif not is_scalar:
            # If property is multi-valued, create a wrapper single
            # value variable. This prevents interference with
            # global dimensions for the output Dataset.
            yield property_name, sc.scalar(property_data)
        else:
            yield property_name, property_data
def test_find_chopper_keys():
    """Only coords whose names look like choppers are returned, in order."""
    coord_names = [
        'chopper3', 'abc', 'chopper_1', 'sample', 'source', 'Chopper_wfm',
        'chopper0', 'chopper5', 'monitor'
    ]
    da = sc.DataArray(data=sc.scalar('dummy'),
                      coords={name: sc.scalar(0)
                              for name in coord_names})
    expected = [
        'chopper3', 'chopper_1', 'Chopper_wfm', 'chopper0', 'chopper5'
    ]
    assert ch.find_chopper_keys(da) == expected
def test_scalar_without_dtype():
    """sc.scalar on a string matches the positional Variable constructor."""
    value = 'temp'
    actual, expected = sc.scalar(value), sc.Variable(value)
    # String-dtype variables cannot be compared directly; compare values.
    assert actual.values == expected.values
def params():
    """
    Chopper parameters used by the tests.

    Six cutouts along 'frame': centres from 0.25 to 2*pi rad and widths from
    0.1 to 0.6 rad, each linearly spaced.
    """
    dim = 'frame'
    n_cutouts = 6
    centers = sc.linspace(dim=dim,
                          start=0.25,
                          stop=2.0 * np.pi,
                          num=n_cutouts,
                          unit='rad')
    widths = sc.linspace(dim=dim,
                         start=0.1,
                         stop=0.6,
                         num=n_cutouts,
                         unit='rad')
    return dict(
        frequency=sc.scalar(56.0, unit="Hz"),
        phase=sc.scalar(0.5, unit='rad'),
        position=sc.vector(value=[0., 0., 5.], unit='m'),
        cutout_angles_center=centers,
        cutout_angles_width=widths,
        kind=sc.scalar('wfm'),
    )
def mean_from_adj_pixels(data):
    """
    Applies a mean across 8 neighboring pixels (plus centre value)
    for data with 'x' and 'y' dimensions (at least).
    Result will calculate mean from slices across additional dimensions.

    For example if there is a tof dimension in addition to x, and y,
    for each set of neighbours the returned mean will take the mean tof value
    in the neighbour group.

    Neighbour slots whose value is <= ``fill`` are masked out of the mean.
    ``fill`` is the most negative finite value of the data's float dtype.
    NOTE(review): this assumes ``_shift`` writes ``fill`` into positions
    shifted in from outside the array, so that edge pixels average over fewer
    neighbours. Confirm against ``_shift``. A genuine value equal to ``fill``
    would also be masked.

    :param data: Variable with at least 'x' and 'y' dimensions and a
        floating-point dtype (``np.finfo`` is used to pick ``fill``).
    :returns: Variable with the dims of ``data`` holding the neighbour means.
    """
    fill = np.finfo(data.values.dtype).min
    has_variances = data.variances is not None
    # Unpack dims/shape: this works whether scipp reports them as lists or
    # as tuples (``list + tuple`` raises TypeError).
    container = sc.empty(dims=['neighbor', *data.dims],
                         dtype=data.dtype,
                         shape=[9, *data.shape],
                         with_variances=has_variances,
                         unit=data.unit)
    # Slot 0 is the pixel itself; slots 1-4 are its direct x/y neighbours.
    container['neighbor', 0] = data
    container['neighbor', 1] = _shift(data, "x", True, fill)
    container['neighbor', 2] = _shift(data, "x", False, fill)
    container['neighbor', 3] = _shift(data, "y", True, fill)
    container['neighbor', 4] = _shift(data, "y", False, fill)
    # Slots 5-8 are the diagonals: the x-neighbours shifted along y.
    container['neighbor', 5:7] = _shift(container['neighbor', 1:3], "y", True,
                                        fill)
    container['neighbor', 7:9] = _shift(container['neighbor', 1:3], "y",
                                        False, fill)
    edges_mask = container <= sc.scalar(value=fill, unit=data.unit)
    da = sc.DataArray(data=container, masks={'edges': edges_mask})
    return sc.mean(da, dim='neighbor').data
def test_extract_energy_initial():
    """Loading CNCS_51936 yields an incident_energy coord of 3 meV."""
    from mantid.simpleapi import mtd
    mtd.clear()
    path = scn.data.get_path("CNCS_51936_event.nxs")
    ds = scn.load(path, mantid_args={"SpectrumMax": 1})
    expected = sc.scalar(value=3.0, unit=sc.Unit("meV"))
    assert sc.identical(ds.coords["incident_energy"], expected)
def load_monitor_data(monitor_groups: List[Group],
                      nexus: LoadFromNexus) -> Dict:
    """
    Load monitor data.

    Event-mode data takes precedence over histogram-mode data. A group
    raising ``KeyError`` during loading is skipped with a warning.

    :param monitor_groups: NXmonitor groups to load.
    :param nexus: File accessor.
    :returns: Dict keyed by the last path component of each group's name,
        mapping to ``sc.scalar`` wrapping the loaded monitor.
    """
    monitors = {}
    for group in monitor_groups:
        try:
            nxmonitor = NXmonitor(group, nexus)
            # Standard loading requires binning monitor into pulses and adding
            # detector IDs. This is currently encapsulated in
            # load_detector_data, so we cannot readily use NXmonitor and bin
            # afterwards without duplication.
            if not nxmonitor._is_events:
                monitor = nxmonitor[()]
            else:
                monitor = load_detector_data([group], [], nexus, True, True)
                warnings.warn(
                    f"Event data present in NXmonitor group {group.name}. "
                    f"Histogram-mode monitor data from this group will be "
                    f"ignored.")
            short_name = group.name.rsplit("/", 1)[-1]
            monitors[short_name] = sc.scalar(value=monitor)
        except KeyError:
            warnings.warn(
                f"No event-mode or histogram-mode monitor data found for "
                f"NXMonitor group {group.name}. Skipping this group.")
    return monitors
def _load_event_time_zero(group: Group,
                          nexus: LoadFromNexus,
                          index=...) -> sc.Variable:
    """
    Load the pulse times ("event_time_zero") of an NXevent_data group.

    The raw values are converted to ns and cast to int64. They are then
    offset by the dataset's "offset" attribute (an ISO8601 string), or by
    the unix epoch if that attribute is missing.

    :raises BadSource: If the units cannot be converted to nanoseconds.
    :returns: datetime64 variable (ns) along the pulse dimension.
    """
    dataset_name = "event_time_zero"
    raw_pulse_times = nexus.load_dataset(group,
                                         dataset_name,
                                         dimensions=[_pulse_dimension],
                                         index=index)
    try:
        pulse_times_ns = sc.to_unit(raw_pulse_times, sc.units.ns, copy=False)
    except sc.UnitError:
        raise BadSource(f"Could not load pulse times: units attribute "
                        f"'{raw_pulse_times.unit}' in NXEvent at "
                        f"{group.name}/{dataset_name} is not convertible"
                        f" to nanoseconds.")

    try:
        offset = nexus.get_string_attribute(
            nexus.get_dataset_from_group(group, dataset_name), "offset")
    except MissingAttribute:
        offset = "1970-01-01T00:00:00Z"

    # The float64 values must become int64 for datetime arithmetic; casting
    # after the ns conversion keeps as much precision as possible.
    pulse_times_int = pulse_times_ns.astype(sc.DType.int64, copy=False)
    epoch = sc.scalar(np.datetime64(offset),
                      unit=sc.units.ns,
                      dtype=sc.DType.datetime64)
    return pulse_times_int + epoch
def test_illumination_correction_no_spill():
    """A beam much smaller than the sample gives a correction of 1."""
    beam_size, sample_size = 1.0 * sc.units.m, 10.0 * sc.units.m
    theta = sc.array(dims=['event'], values=[30.0], unit=sc.units.deg)
    result = corrections.illumination_correction(beam_size, sample_size,
                                                 theta)
    assert sc.allclose(result, sc.scalar(1.0))
def make_detector_info(ws, spectrum_dim):
    """
    Build the detector -> spectrum mapping of a Mantid workspace.

    Only detectors that belong to some spectrum are kept.

    :param ws: Mantid workspace.
    :param spectrum_dim: Name of the coord holding spectrum indices.
    :raises RuntimeError: If any spectrum definition has a non-zero time
        index (scanning instrument).
    :returns: ``sc.scalar`` wrapping a Dataset with coords 'detector'
        (detector IDs) and ``spectrum_dim`` (spectrum indices), both along
        'detector'.
    """
    detector_info = ws.detectorInfo()
    n_detectors = detector_info.size()
    # det -> spec mapping
    spectrum_of = np.empty(shape=(n_detectors, ), dtype=np.int32)
    is_mapped = np.zeros(n_detectors, dtype=bool)
    for spectrum_index, spec in enumerate(ws.spectrumInfo()):
        definition = spec.spectrumDefinition
        for k in range(len(definition)):
            det_index, time_index = definition[k]
            if time_index != 0:
                raise RuntimeError(
                    "Conversion of Mantid Workspace with scanning instrument "
                    "not supported yet.")
            spectrum_of[det_index] = spectrum_index
            is_mapped[det_index] = True

    # Store only information about detectors with data (a spectrum). The rest
    # mostly just gets in the way and including it in the default converter
    # is probably not required. More (e.g. positions) may be added later.
    spectrum = sc.array(dims=['detector'], values=spectrum_of[is_mapped])
    detector = sc.array(dims=['detector'],
                        values=detector_info.detectorIDs()[is_mapped])
    return sc.scalar(
        sc.Dataset(coords={
            'detector': detector,
            spectrum_dim: spectrum
        }))
def from_mantid(workspace, **kwargs):
    """Convert Mantid workspace to a scipp data array or dataset.

    Supported workspace ids: Workspace2D, RebinnedOutput, MaskWorkspace,
    EventWorkspace, TableWorkspace, MDHistoWorkspace and WorkspaceGroup.

    Monitors end up in the result's ``attrs``, each wrapped in ``sc.scalar``.
    They come from one of two places:

    - For the Workspace2D-like ids, when some (but not all) spectra are
      monitors, they are split off with Mantid's ``ExtractMonitors``.
    - Otherwise, from a separate monitor workspace returned by
      ``getMonitorWorkspace``.

    Workspaces created by ``ExtractMonitors`` are deleted before returning,
    including when the conversion raises.

    :param workspace: Mantid workspace to convert.
    :param kwargs: Forwarded to the type-specific converter functions.
    :raises RuntimeError: If the workspace type, or the type of its monitor
        workspace, is not supported.
    :returns: The converted object (Dataset or DataArray; for a
        WorkspaceGroup, whatever ``convert_WorkspaceGroup_to_dataarray_dict``
        returns).
    """
    scipp_obj = None  # This is either a Dataset or DataArray
    monitor_ws = None
    workspaces_to_delete = []
    w_id = workspace.id()
    try:
        if w_id in ('Workspace2D', 'RebinnedOutput', 'MaskWorkspace'):
            spec_info = workspace.spectrumInfo()
            n_monitor = sum(1 for i in range(len(spec_info))
                            if spec_info.hasDetectors(i)
                            and spec_info.isMonitor(i))
            # If there are *only* monitors we do not move them to an attribute
            if 0 < n_monitor < len(spec_info):
                import mantid.simpleapi as mantid
                workspace, monitor_ws = mantid.ExtractMonitors(workspace)
                workspaces_to_delete.append(workspace)
                workspaces_to_delete.append(monitor_ws)
            scipp_obj = convert_Workspace2D_to_data_array(workspace, **kwargs)
        elif w_id == 'EventWorkspace':
            scipp_obj = convert_EventWorkspace_to_data_array(
                workspace, **kwargs)
        elif w_id == 'TableWorkspace':
            scipp_obj = convert_TableWorkspace_to_dataset(workspace, **kwargs)
        elif w_id == 'MDHistoWorkspace':
            scipp_obj = convert_MDHistoWorkspace_to_data_array(
                workspace, **kwargs)
        elif w_id == 'WorkspaceGroup':
            scipp_obj = convert_WorkspaceGroup_to_dataarray_dict(
                workspace, **kwargs)

        if scipp_obj is None:
            raise RuntimeError('Unsupported workspace type {}'.format(w_id))

        # TODO Is there ever a case where a Workspace2D has a separate monitor
        # workspace? This is not handled by ExtractMonitors above, I think.
        if monitor_ws is None and hasattr(workspace, 'getMonitorWorkspace'):
            try:
                monitor_ws = workspace.getMonitorWorkspace()
            except RuntimeError:
                # Have to try/fail here. No inspect method on Mantid for this.
                pass

        if monitor_ws is not None:
            monitor_id = monitor_ws.id()
            if monitor_id in ('MaskWorkspace', 'Workspace2D'):
                converter = convert_Workspace2D_to_data_array
            elif monitor_id == 'EventWorkspace':
                converter = convert_EventWorkspace_to_data_array
            else:
                # Previously this fell through to an UnboundLocalError.
                raise RuntimeError(
                    'Unsupported monitor workspace type {}'.format(monitor_id))
            monitors = convert_monitors_ws(monitor_ws, converter, **kwargs)
            for name, monitor in monitors:
                scipp_obj.attrs[name] = sc.scalar(monitor)
    finally:
        # Clean up temporaries even if a conversion step raised.
        if workspaces_to_delete:
            import mantid.simpleapi as mantid
            for ws in workspaces_to_delete:
                mantid.DeleteWorkspace(ws)

    return scipp_obj
def _load_title(entry_group: Group, nexus: LoadFromNexus) -> Dict:
    """
    Read the entry's "title" dataset.

    :returns: ``{"experiment_title": scalar}``, or ``{}`` if the dataset is
        missing.
    """
    try:
        title = nexus.load_scalar_string(entry_group, "title")
    except MissingDataset:
        return {}
    return {"experiment_title": sc.scalar(value=title)}
def _extract_einitial(ws):
    """
    Return the incident energy of a workspace as a meV scalar.

    The "Ei" run property is preferred. Failing that, the last value of
    "EnergyRequest" is used. Otherwise the result is 0.
    """
    run = ws.run()
    mev = sc.Unit("meV")
    if run.hasProperty("Ei"):
        return sc.scalar(run.getProperty("Ei").value, unit=mev)
    if run.hasProperty('EnergyRequest'):
        return sc.scalar(run.getProperty('EnergyRequest').value[-1], unit=mev)
    return sc.scalar(0, unit=mev)
def _load_start_and_end_time(entry_group: Group,
                             nexus: LoadFromNexus) -> Dict:
    """
    Read the entry's "start_time" and "end_time" strings.

    :returns: Dict containing a scalar for each of the two datasets that
        exists. Missing datasets are simply omitted.
    """
    result = {}
    for name in ("start_time", "end_time"):
        try:
            text = nexus.load_scalar_string(entry_group, name)
        except MissingDataset:
            continue
        result[name] = sc.scalar(value=text)
    return result
def make_chopper(frequency: sc.Variable,
                 position: sc.Variable,
                 phase: sc.Variable = None,
                 cutout_angles_center: sc.Variable = None,
                 cutout_angles_width: sc.Variable = None,
                 cutout_angles_begin: sc.Variable = None,
                 cutout_angles_end: sc.Variable = None,
                 kind: Union[str, sc.Variable] = None) -> sc.Dataset:
    """
    Create a Dataset that holds chopper parameters.

    This ensures the Dataset is compatible with the other functions in the
    choppers module. A chopper's cutout angles can be defined either by an
    array of cutout centers plus an array of angular widths, or by two
    arrays containing the begin (leading) and end (closing) angles of the
    cutout windows.

    :param frequency: The rotational frequency of the chopper.
    :param position: The position vector of the chopper.
    :param phase: The chopper phase.
    :param cutout_angles_center: The centers of the chopper cutout angles.
    :param cutout_angles_width: The angular widths of the chopper cutouts.
    :param cutout_angles_begin: The starting/opening angles of the chopper
        cutouts.
    :param cutout_angles_end: The ending/closing angles of the chopper
        cutouts.
    :param kind: The chopper kind. Any string can be supplied, but WFM
        choppers should be given 'wfm' as their kind. A plain ``str`` is
        wrapped with ``sc.scalar``; a variable is stored as-is.
    :raises ValueError: If a complete set of cutout angles is given and it
        has a negative width, or non-monotonic begin or end angles.
    :returns: Dataset with 'frequency', 'position' and every optional
        parameter that was not None.
    """
    if isinstance(kind, str):
        # Store a plain-string kind as a 0-D variable, like the other entries.
        kind = sc.scalar(kind)

    optional = {
        "phase": phase,
        "cutout_angles_center": cutout_angles_center,
        "cutout_angles_width": cutout_angles_width,
        "cutout_angles_begin": cutout_angles_begin,
        "cutout_angles_end": cutout_angles_end,
        "kind": kind,
    }
    data = {"frequency": frequency, "position": position}
    data.update(
        {key: item
         for key, item in optional.items() if item is not None})
    chopper = sc.Dataset(data=data)

    # Explicit identity checks: ``None in [var, ...]`` would route through
    # scipp's overloaded ``__eq__`` for each element.
    has_begin_end = (cutout_angles_begin is not None
                     and cutout_angles_end is not None)
    has_center_width = (cutout_angles_center is not None
                        and cutout_angles_width is not None)

    # Sanitize input parameters
    if has_begin_end or has_center_width:
        widths = utils.cutout_angles_width(chopper)
        if (sc.min(widths) < sc.scalar(0.0, unit=widths.unit)).value:
            raise ValueError(
                "Negative window width found in chopper cutout angles.")
        if not sc.allsorted(utils.cutout_angles_begin(chopper),
                            dim=widths.dim):
            raise ValueError("Chopper begin cutout angles are not monotonic.")
        if not sc.allsorted(utils.cutout_angles_end(chopper), dim=widths.dim):
            raise ValueError("Chopper end cutout angles are not monotonic.")

    return chopper
def get_metadata_array(self) -> Tuple[bool, sc.Variable]:
    """
    Copy collected data from the buffer and reset the fill counter.

    :returns: ``(new_data_exists, scalar)``. ``new_data_exists`` is True if
        the buffer held any entries. ``scalar`` wraps a copy of the filled
        part of ``self._data_array[self._name, ...]``.
    """
    with self._buffer_mutex:
        filled = self._buffer_filled_size
        snapshot = self._data_array[self._name, :filled].copy()
        self._buffer_filled_size = 0
    return filled != 0, sc.scalar(snapshot)
def _add_coord_to_loaded_data(attr_name: str,
                              data: sc.Variable,
                              value: np.ndarray,
                              unit: sc.Unit,
                              dtype: Optional[Any] = None):
    """
    Store ``value`` under ``attr_name`` as a coordinate of loaded data.

    For a DataArray the entry goes into ``data.coords``; anything else is
    written to directly. vector3 dtypes use ``sc.vector``. Other dtypes, or
    ``None`` for inference, go through ``sc.scalar``. A ``KeyError`` during
    construction or assignment is swallowed, so nothing is added then.
    """
    target = data.coords if isinstance(data, sc.DataArray) else data
    try:
        if dtype is None:
            entry = sc.scalar(value, unit=unit)
        elif dtype == sc.DType.vector3:
            entry = sc.vector(value=value, unit=unit)
        else:
            entry = sc.scalar(value, dtype=dtype, unit=unit)
        target[attr_name] = entry
    except KeyError:
        pass
def test_EventWorkspace_with_pulse_times():
    """Pulse times are loaded as datetime64 and match Mantid's first value."""
    import mantid.simpleapi as sapi
    event_ws = sapi.CreateSampleWorkspace(WorkspaceType='Event',
                                          NumBanks=1,
                                          NumEvents=10)
    converted = scn.mantid.convert_EventWorkspace_to_data_array(
        event_ws, load_pulse_times=True)
    pulse_time = converted.data.values[0].coords['pulse_time']
    assert pulse_time.dtype == sc.DType.datetime64
    first_pulse = event_ws.getSpectrum(0).getPulseTimes()[0].to_datetime64()
    assert sc.identical(pulse_time['event', 0], sc.scalar(value=first_pulse))
def test_scalar_with_dtype():
    """sc.scalar with explicit dtype matches the Variable constructor."""
    value = 1.0
    variance = 5.0
    unit = sc.units.m
    # sc.DType, as used elsewhere in the module (sc.dtype is the older name).
    dtype = sc.DType.float64
    var = sc.scalar(value=value, variance=variance, unit=unit, dtype=dtype)
    expected = sc.Variable(value=value,
                           variance=variance,
                           unit=unit,
                           dtype=dtype)
    # ``identical`` also checks variances, unit and dtype, which an
    # elementwise ``==`` on the values would not.
    assert sc.identical(var, expected)
def _load_instrument_name(instrument_groups: List[Group],
                          nexus: LoadFromNexus) -> Dict:
    """
    Read the "name" dataset of the first instrument group.

    A warning is issued if there is more than one group. An empty list
    raises IndexError.

    :returns: ``{"instrument_name": scalar}``, or ``{}`` if the dataset is
        missing.
    """
    first_group = instrument_groups[0]
    if len(instrument_groups) > 1:
        warn(f"More than one {nx_instrument} found in file, "
             f"loading name from {first_group.name} only")
    try:
        name = nexus.load_scalar_string(first_group, "name")
    except MissingDataset:
        return {}
    return {"instrument_name": sc.scalar(value=name)}
def _make_component_settings(*,
                             data,
                             center='sample_position',
                             type='box',
                             size_unit=sc.units.m,
                             wireframe=False,
                             component_size=(0.1, 0.1, 0.1)):
    """
    Build the settings dict for drawing an instrument component.

    :param data: Object whose ``meta`` is used to resolve a string ``center``.
    :param center: A meta key, or an explicit position used as-is.
    :param type: Shape name stored under 'type'.
    :param size_unit: Unit applied to ``component_size``.
    :param wireframe: Accepted but not used by this function.
    :param component_size: Tuple -> vector size; anything else -> scalar.
    :returns: ``{'center': ..., 'size': ..., 'type': ...}``.
    """
    if isinstance(component_size, tuple):
        size = sc.vector(value=component_size, unit=size_unit)
    else:
        size = sc.scalar(value=component_size, unit=size_unit)
    if isinstance(center, str):
        center = data.meta[center]
    return {'center': center, 'size': size, 'type': type}
def _get_transformation_magnitude_and_unit(
        group_name: str, transform: Union[h5py.Dataset, GroupObject],
        nexus: LoadFromNexus) -> Union[sc.Variable, sc.DataArray]:
    """
    Gets the magnitude (with its unit) of a transformation.

    If ``transform`` is a group, its "value" and "time" datasets are loaded.
    A data array of magnitudes with a "time" coordinate is returned. The
    unit is taken from the value dataset, falling back to the group itself.
    Otherwise ``transform`` is read as a single number and returned as a
    float64 scalar.

    :param group_name: Name of the component owning the transformation
        chain; used in error messages.
    :param transform: The transformation dataset or group.
    :param nexus: File accessor.
    :raises TransformationError: If the log is empty, the value and time
        lengths differ, a required dataset is missing (``MissingDataset``),
        or no unit can be found.
    """
    if not nexus.is_group(transform):
        magnitude = nexus.load_dataset_as_numpy_array(transform).astype(
            float).item()
        unit = nexus.get_unit(transform)
        if unit == sc.units.dimensionless:
            raise TransformationError(f"Missing units for transformation at "
                                      f"{nexus.get_name(transform)}")
        return sc.scalar(value=magnitude, unit=unit, dtype=sc.DType.float64)

    try:
        values = nexus.load_dataset(transform, "value", dimensions=["time"])
        times = load_time_dataset(nexus=nexus,
                                  group=transform,
                                  dataset_name="time",
                                  dim="time",
                                  group_name=group_name)
        if len(values) == 0:
            raise TransformationError(f"Found empty NXlog as a "
                                      f"transformation for {group_name}")
        if len(values) != len(times):
            raise TransformationError(
                f"Mismatched time and value dataset lengths "
                f"for transformation at {group_name}")
    except MissingDataset:
        raise TransformationError(
            f"Encountered {nexus.get_name(transform)} in transformation "
            f"chain for {group_name} but it is a group without a value "
            "dataset; not a valid transformation")

    unit = nexus.get_unit(nexus.get_dataset_from_group(transform, "value"))
    if unit == sc.units.dimensionless:
        # See if the value unit is on the NXLog itself instead
        unit = nexus.get_unit(transform)
    if unit == sc.units.dimensionless:
        raise TransformationError(f"Missing units for transformation at "
                                  f"{nexus.get_name(transform)}")
    # Attach the resolved unit. When it came from the NXlog group rather than
    # the value dataset, it would otherwise be lost from the result.
    values.unit = unit
    return sc.DataArray(data=values, coords={"time": times})
def test_time_open_closed(params):
    """Open/close times follow the cutout angles and shift with the phase."""
    dim = 'frame'

    def _seconds_as_us(seconds):
        return sc.to_unit(sc.array(dims=[dim], values=seconds, unit='s'),
                          'us')

    opening = np.array([0.0, 0.5, 1.0])
    closing = np.array([0.5, 1.0, 1.5])
    chopper = ch.make_chopper(
        frequency=sc.scalar(0.5, unit=sc.units.one / sc.units.s),
        phase=sc.scalar(0., unit='rad'),
        position=params['position'],
        cutout_angles_begin=sc.array(dims=[dim],
                                     values=np.pi * opening,
                                     unit='rad'),
        cutout_angles_end=sc.array(dims=[dim],
                                   values=np.pi * closing,
                                   unit='rad'),
        kind=params['kind'])

    assert sc.allclose(ch.time_open(chopper), _seconds_as_us(opening))
    assert sc.allclose(ch.time_closed(chopper), _seconds_as_us(closing))

    chopper["phase"] = sc.scalar(2.0 * np.pi / 3.0, unit='rad')
    delay = 2.0 / 3.0
    assert sc.allclose(ch.time_open(chopper), _seconds_as_us(opening + delay))
    assert sc.allclose(ch.time_closed(chopper),
                       _seconds_as_us(closing + delay))
def test_groupby2d_simple_case_neutron_specific():
    """groupby2D reduces x/y and keeps non-dimension coords such as source_position."""
    image = sc.array(dims=['wavelength', 'y', 'x'],
                     values=np.arange(100.0).reshape(1, 10, 10))
    coords = {
        'y': sc.array(dims=['y'], values=np.arange(10)),
        'x': sc.array(dims=['x'], values=np.arange(10)),
        'wavelength': sc.scalar(value=1.0),
        'source_position': sc.vector(value=[0, 0, -10]),
    }
    ds = sc.Dataset(data={'a': image}, coords=coords)

    grouped = groupby2D(ds, 5, 5)
    assert grouped['a'].shape == [1, 5, 5]

    grouped = groupby2D(ds, 1, 1)
    assert grouped['a'].shape == [1, 1, 1]
    assert 'source_position' in grouped['a'].meta
def _rotation_matrix_from_axis_and_angle(axis: np.ndarray,
                                         angles: sc.DataArray) -> sc.DataArray:
    """
    Produce one rotation about ``axis`` per entry of ``angles``.

    Each rotation vector is ``-angle * axis``. The sign gives a "passive
    transform".

    NOTE(review): the rotation vectors are always passed on with unit rad,
    whatever ``angles.unit`` is. Callers presumably supply angles in
    radians; confirm.

    Args:
        axis: numpy array of length 3 specifying the rotation axis.
        angles: DataArray of N angles (cast to float64).

    Returns:
        The rotations built by ``sc.spatial.rotations_from_rotvecs`` with the
        dims of ``angles``.
    """
    negated_axis = sc.vector(value=axis) * sc.scalar(-1.0)
    rotvecs = negated_axis * angles.astype(sc.DType.float64, copy=False)
    return sc.spatial.rotations_from_rotvecs(dims=angles.dims,
                                             values=rotvecs.values,
                                             unit=sc.units.rad)