Пример #1
0
def normalize_dim(arr: DataType, dim_or_dims, keep_id=False):
    """
    Normalizes the intensity so that all values along arr.sum(dims other than those in ``dim_or_dims``)
    have the same value.

    NaNs are replaced by the global mean of ``arr`` before summing over every dimension that is not in
    ``dim_or_dims``. Each point is then divided by (that sum / N), where N is the number of points spanned
    by ``dim_or_dims``. Without NaNs, the sum over the remaining dimensions therefore equals N everywhere
    along ``dim_or_dims``.

    :param arr: Data to normalize.
    :param dim_or_dims: A dimension name, or a list of dimension names, to normalize along.
    :param keep_id: If False (default), the ``id`` attribute copied from ``arr`` is removed from the output.
    :return: xr.DataArray with the coords, dims and attrs of ``arr`` holding the normalized values.
    """
    dims = [dim_or_dims] if isinstance(dim_or_dims, str) else dim_or_dims

    summed_arr = arr.fillna(arr.mean()).sum(
        [d for d in arr.dims if d not in dims])
    # np.prod rather than np.product: the latter was deprecated in NumPy 1.25 and removed in 2.0.
    normalized_arr = arr / (summed_arr / np.prod(summed_arr.shape))

    to_return = xr.DataArray(normalized_arr.values,
                             arr.coords,
                             arr.dims,
                             attrs=arr.attrs)

    if not keep_id and 'id' in to_return.attrs:
        del to_return.attrs['id']

    provenance(to_return, arr, {
        'what': 'Normalize axis or axes',
        'by': 'normalize_dim',
        'dims': dims,
    })

    return to_return
Пример #2
0
def soft_normalize_dim(arr: xr.DataArray,
                       dim_or_dims,
                       keep_id=False,
                       amp_limit=100):
    """
    Normalizes the intensity of ``arr`` along ``dim_or_dims``.

    NaNs are replaced by the global mean of ``arr`` before summing over every dimension that is not in
    ``dim_or_dims``. Each point is then divided by (that sum / N), where N is the number of points spanned
    by ``dim_or_dims``.

    :param arr: Data to normalize.
    :param dim_or_dims: A dimension name, or a list of dimension names, to normalize along.
    :param keep_id: If False (default), the ``id`` attribute copied from ``arr`` is removed from the output.
    :param amp_limit: Accepted for API compatibility but currently ignored by the implementation;
                      no limit is applied to the amplification factor.
    :return: xr.DataArray with the coords, dims and attrs of ``arr`` holding the normalized values.
    """
    dims = [dim_or_dims] if isinstance(dim_or_dims, str) else dim_or_dims

    summed_arr = arr.fillna(arr.mean()).sum(
        [d for d in arr.dims if d not in dims])
    # np.prod rather than np.product: the latter was deprecated in NumPy 1.25 and removed in 2.0.
    normalized_arr = arr / (summed_arr / np.prod(summed_arr.shape))

    to_return = xr.DataArray(normalized_arr.values,
                             arr.coords,
                             arr.dims,
                             attrs=arr.attrs)

    if not keep_id and 'id' in to_return.attrs:
        del to_return.attrs['id']

    # Record this function (not normalize_dim) as the producer of the result.
    provenance(to_return, arr, {
        'what': 'Normalize axis or axes',
        'by': 'soft_normalize_dim',
        'dims': dims,
    })

    return to_return
Пример #3
0
def dn_along_axis(arr: xr.DataArray, axis=None, smooth_fn=None, order=2):
    """
    Takes ``order`` successive derivatives of ``arr`` along a single axis with ``np.gradient``.

    Before each differentiation the current values are wrapped in an xr.DataArray (with the coords and
    dims of ``arr``) and passed through ``smooth_fn``. When ``smooth_fn`` is None, no smoothing is done.

    When ``axis`` is None, the first of
    ['eV', 'kp', 'kx', 'kz', 'ky', 'phi', 'beta', 'theta']
    present in ``arr.dims`` is used. If none is present, the first dimension is used and a warning is issued.

    :param arr: Input data.
    :param axis: Name of the dimension to differentiate along.
    :param smooth_fn: Optional callable taking and returning an xr.DataArray (its ``.values`` are used).
    :param order: How many derivatives to take.
    :return: xr.DataArray with the same coords, dims and attrs as ``arr`` (minus ``id``).
    """
    preferred_axes = ['eV', 'kp', 'kx', 'kz', 'ky', 'phi', 'beta', 'theta']
    if axis is None:
        axis = next((name for name in preferred_axes if name in arr.dims), None)
        if axis is None:
            # Nothing preferred is available, fall back to the leading dimension.
            axis = arr.dims[0]
            warnings.warn(
                'Choosing axis: {} for the second derivative, no preferred axis found.'
                .format(axis))

    def smooth(data):
        return data if smooth_fn is None else smooth_fn(data)

    coordinate = arr.coords[axis]
    # Uniform spacing is assumed: the step is taken from the first two coordinate values.
    step = float(coordinate[1] - coordinate[0])
    axis_position = arr.dims.index(axis)

    current = arr.values
    for _ in range(order):
        smoothed = smooth(xr.DataArray(current, arr.coords, arr.dims))
        current = np.gradient(smoothed.values, step, axis=axis_position)

    dn_arr = xr.DataArray(current, arr.coords, arr.dims, attrs=arr.attrs)

    if 'id' in dn_arr.attrs:
        del dn_arr.attrs['id']
        provenance(
            dn_arr, arr, {
                'what': '{}th derivative'.format(order),
                'by': 'dn_along_axis',
                'axis': axis,
                'order': order,
            })

    return dn_arr
Пример #4
0
def gaussian_filter_arr(arr: xr.DataArray, sigma=None, n=1, default_size=1):
    """
    Functionally wraps scipy.ndimage.gaussian_filter with the advantage that the sigma
    is coordinate aware.

    :param arr: Input data.
    :param sigma: dict mapping dimension name -> kernel sigma in that axis's coordinate units. Each value is
                  converted to index units by dividing by the spacing of the first two coordinate values and
                  truncating to an int; the sign of the spacing is ignored, so descending axes work.
                  An axis that is not specified will have a kernel width of `default_size` in index units.
    :param n: Repeats the filter n times.
    :param default_size: Changes the default kernel width for axes not specified in `sigma`. Changing this
                         parameter and leaving `sigma` as None allows you to smooth with an even-width
                         kernel in index-coordinates.
    :return: xr.DataArray: smoothed data with a deep copy of ``arr.attrs`` (minus ``id``).
    """
    if sigma is None:
        sigma = {}

    # scipy.ndimage.gaussian_filter skips every axis whose sigma is not positive, so a descending
    # coordinate (negative spacing) would silently disable smoothing along it; hence ``abs``.
    sigma_idx = {
        k: abs(int(v / (arr.coords[k][1] - arr.coords[k][0])))
        for k, v in sigma.items()
    }
    sigma = tuple(sigma_idx.get(d, default_size) for d in arr.dims)

    values = arr.values
    for _ in range(n):
        # Public scipy.ndimage entry point; the ``ndimage.filters`` namespace is deprecated.
        values = ndimage.gaussian_filter(values, sigma)

    filtered_arr = xr.DataArray(values,
                                arr.coords,
                                arr.dims,
                                attrs=copy.deepcopy(arr.attrs))

    if 'id' in filtered_arr.attrs:
        del filtered_arr.attrs['id']

        provenance(
            filtered_arr, arr, {
                'what': 'Gaussian filtered data',
                'by': 'gaussian_filter_arr',
                'sigma': sigma,
            })

    return filtered_arr
def apply_quadratic_fermi_edge_correction(
        arr: xr.DataArray,
        correction: lf.model.ModelResult = None,
        offset=None):
    """
    Shifts every ``phi`` slice of ``arr`` along ``eV`` so that the modelled Fermi edge lies at 0.

    :param arr: Data with ``eV`` and ``phi`` dimensions.
    :param correction: Fit result (``lf.model.ModelResult``) evaluated at the ``phi`` coordinates to get the
                       edge energy per slice; built with ``build_quadratic_fermi_edge_correction(arr)`` if None.
    :param offset: Optional energy subtracted from the modelled edge positions before shifting.
    :return: xr.DataArray with the shifted values (``id`` removed, provenance recorded).

    Note: ``correction.best_values`` is stored in ``arr.attrs['corrections']['FE_Corr']``, i.e. the
    input array's attrs are modified in place.
    """
    assert isinstance(arr, xr.DataArray)
    if correction is None:
        correction = build_quadratic_fermi_edge_correction(arr)

    arr.attrs.setdefault('corrections', {})['FE_Corr'] = correction.best_values

    energies = arr.coords['eV'].values
    energy_step = energies[1] - energies[0]
    dim_list = list(arr.dims)
    energy_axis, phi_axis = dim_list.index('eV'), dim_list.index('phi')

    edge_energy = correction.eval(x=arr.coords['phi'].values)
    if offset is not None:
        edge_energy = edge_energy - offset

    # Express the shift in (fractional) index units along eV.
    shifted_values = shift_by(arr.values,
                              -edge_energy / energy_step,
                              axis=energy_axis,
                              by_axis=phi_axis,
                              order=1)
    corrected_arr = xr.DataArray(shifted_values, arr.coords, arr.dims, attrs=arr.attrs)
    corrected_arr.attrs.pop('id', None)

    provenance(
        corrected_arr, arr, {
            'what': 'Shifted Fermi edge to align at 0',
            'by': 'apply_quadratic_fermi_edge_correction',
            'correction': correction.best_values,
        })

    return corrected_arr
def apply_photon_energy_fermi_edge_correction(arr: xr.DataArray,
                                              correction=None,
                                              **kwargs):
    """
    Shifts ``arr`` along ``eV``, slice by slice along ``hv``, so that the fitted Fermi edge lies at 0.

    :param arr: Data with ``eV`` and ``hv`` dimensions.
    :param correction: Collection of fit results whose ``center`` parameter gives the edge position
                       (presumably one fit per ``hv`` value - TODO confirm). Built with
                       ``build_photon_energy_fermi_edge_correction(arr, **kwargs)`` when None.
    :param kwargs: Forwarded to ``build_photon_energy_fermi_edge_correction``.
    :return: xr.DataArray with the shifted values (``id`` removed, provenance recorded).

    Note: the edge centers are also written to ``arr.attrs['corrections']['hv_correction']``, i.e. the
    input array's attrs are modified in place.
    """
    if correction is None:
        correction = build_photon_energy_fermi_edge_correction(arr, **kwargs)

    edge_centers = correction.T.map(lambda fit: fit.params['center'].value)
    arr.attrs.setdefault('corrections', {})['hv_correction'] = list(edge_centers.values)

    # Convert edge positions from energy units into index units along eV.
    energy_step = arr.T.stride(generic_dim_names=False)['eV']
    shift_amount = -edge_centers / energy_step

    shifted_values = shift_by(arr.values,
                              shift_amount,
                              axis=arr.dims.index('eV'),
                              by_axis=arr.dims.index('hv'),
                              order=1)
    corrected_arr = xr.DataArray(shifted_values, arr.coords, arr.dims, attrs=arr.attrs)
    corrected_arr.attrs.pop('id', None)

    provenance(
        corrected_arr, arr, {
            'what': 'Shifted Fermi edge to align at 0 along hv axis',
            'by': 'apply_photon_energy_fermi_edge_correction',
            'correction': list(edge_centers.values),
        })

    return corrected_arr
Пример #7
0
def fft_filter(data: xr.DataArray, stops):
    """
    Applies a brick wall filter at region in ``stops`` in the Fourier transform of data. Use with care.

    The transform is computed with ``xrft.dft``. Every selection in ``stops`` is zeroed, and the result is
    transformed back with ``xrft.idft``. Hard edges in frequency space produce ringing.

    :param data: Input data.
    :param stops: Iterable of dicts mapping dimension names to ``.loc`` selections in frequency space.
                  Keys are prefixed with ``freq_`` unless they already contain it, so ``'x'`` and
                  ``'freq_x'`` are equivalent.
    :return: xr.DataArray holding the real part of the filtered signal, offset so that its minimum equals
             the mean of ``data``; attrs are a copy of ``data.attrs`` (minus ``id``).
    """
    # This won't tolerate inverse inverse filtering ;)
    kdata = xrft.dft(data)

    # this will produce a fair amount of ringing, Scipy isn't very clear about how to butterworth filter in Nd
    for stop in stops:
        kstop = {k if 'freq_' in k else 'freq_' + k: v for k, v in stop.items()}  # be nice
        kdata.loc[kstop] = 0

    kkdata = xrft.idft(kdata)
    real_values = np.real(kkdata.values)
    real_values = real_values - np.min(real_values) + np.mean(data.values)

    filtered_arr = xr.DataArray(
        real_values,
        data.coords,
        data.dims,
        attrs=data.attrs.copy()
    )

    # Check the attrs: ``'id' in filtered_arr`` would test membership in the array's *values*.
    if 'id' in filtered_arr.attrs:
        del filtered_arr.attrs['id']

        provenance(filtered_arr, data, {
            'what': 'Apply a filter in frequency space by brick walling coordinate regions.',
            'by': 'fft_filter',
            'stops': stops,
        })

    return filtered_arr
def apply_direct_fermi_edge_correction(arr: xr.DataArray,
                                       correction=None,
                                       *args,
                                       **kwargs):
    """
    Shifts ``arr`` along ``eV`` by a per-slice correction so that the Fermi edge lies at 0.

    :param arr: Data with an ``eV`` dimension.
    :param correction: xr.DataArray of edge positions (in the units of the ``eV`` coordinate) along one
                       dimension of ``arr`` (its first dim). Built with
                       ``build_direct_fermi_edge_correction(arr, *args, **kwargs)`` when None.
    :param args: Forwarded to ``build_direct_fermi_edge_correction``.
    :param kwargs: Forwarded to ``build_direct_fermi_edge_correction``.
    :return: xr.DataArray with the shifted values (``id`` removed, provenance recorded).
    """
    if correction is None:
        correction = build_direct_fermi_edge_correction(arr, *args, **kwargs)

    # Convert edge positions from energy units into (fractional) index units along eV.
    shift_amount = -correction / arr.T.stride(generic_dim_names=False)['eV']  # pylint: disable=invalid-unary-operand-type
    energy_axis = list(arr.dims).index('eV')

    correction_dim = correction.dims[0]
    correction_axis = list(arr.dims).index(correction_dim)

    corrected_arr = xr.DataArray(shift_by(arr.values,
                                          shift_amount,
                                          axis=energy_axis,
                                          by_axis=correction_axis,
                                          order=1),
                                 arr.coords,
                                 arr.dims,
                                 attrs=arr.attrs)

    if 'id' in corrected_arr.attrs:
        del corrected_arr.attrs['id']

    # Label the record with this function and the axis actually corrected along
    # (previously copied from the photon-energy variant).
    provenance(
        corrected_arr, arr, {
            'what':
            'Shifted Fermi edge to align at 0 along {} axis'.format(correction_dim),
            'by':
            'apply_direct_fermi_edge_correction',
            'correction':
            list(correction.values if isinstance(correction, xr.DataArray
                                                 ) else correction),
        })

    return corrected_arr
Пример #9
0
def boxcar_filter_arr(arr: xr.DataArray,
                      size=None,
                      n=1,
                      default_size=1,
                      skip_nan=True):
    """
    Functionally wraps scipy.ndimage.uniform_filter (a boxcar / moving-average filter) with the advantage
    that the kernel size is coordinate aware.

    :param arr: Input data.
    :param size: dict mapping dimension name -> kernel size in that axis's coordinate units. Each value is
                 converted to index units by dividing by the spacing of the first two coordinate values and
                 truncating to an int; the sign of the spacing is ignored, so descending axes work.
                 An axis that is not specified will have a kernel width of `default_size` in index units.
    :param n: Repeats the filter n times.
    :param default_size: Changes the default kernel width for axes not specified in `size`. Changing this
                         parameter and leaving `size` as None allows you to smooth with an even-width
                         kernel in index-coordinates.
    :param skip_nan: If True (default), NaN samples are excluded from each local average (the average is
                     renormalized by the local fraction of valid samples) and are set to 0 in the output.
    :return: xr.DataArray: smoothed data with a deep copy of ``arr.attrs`` (minus ``id``).
    """
    if size is None:
        size = {}

    # scipy.ndimage.uniform_filter skips every axis whose size is <= 1, so a descending coordinate
    # (negative spacing) would silently disable smoothing along it; hence ``abs``.
    size_idx = {
        k: abs(int(v / (arr.coords[k][1] - arr.coords[k][0])))
        for k, v in size.items()
    }
    size = tuple(size_idx.get(d, default_size) for d in arr.dims)

    if skip_nan:
        # 1 where the data is valid, 0 where it is NaN.
        nan_mask = np.copy(arr.values) * 0 + 1
        nan_mask[arr.values != arr.values] = 0
        filtered_mask = ndimage.uniform_filter(nan_mask, size)

        values = np.copy(arr.values)
        values[values != values] = 0

        for _ in range(n):
            # filtered_mask is only 0 around NaN samples whose whole neighbourhood is NaN;
            # those samples are reset to 0 right after, so the 0/0 warning is noise.
            with np.errstate(divide='ignore', invalid='ignore'):
                values = ndimage.uniform_filter(values, size) / filtered_mask
            values[nan_mask == 0] = 0
    else:
        # ``values`` must be initialized here too (previously an UnboundLocalError).
        values = arr.values
        for _ in range(n):
            values = ndimage.uniform_filter(values, size)

    filtered_arr = xr.DataArray(values,
                                arr.coords,
                                arr.dims,
                                attrs=copy.deepcopy(arr.attrs))

    if 'id' in arr.attrs:
        del filtered_arr.attrs['id']

        provenance(
            filtered_arr, arr, {
                'what': 'Boxcar filtered data',
                'by': 'boxcar_filter_arr',
                'size': size,
                'skip_nan': skip_nan,
            })

    return filtered_arr
def decomposition_along(data: DataType,
                        axes: List[str],
                        decomposition_cls,
                        correlation=False,
                        **kwargs):
    """
    Performs a change of basis of your data according to `sklearn` decomposition classes. This allows
    for robust and simple PCA, ICA, factor analysis, and other decompositions of your data even when it
    is very high dimensional.

    Generally speaking, PCA and similar techniques work when data is 2D, i.e. a sequence of 1D observations.
    We can make the same techniques work by unravelling a ND dataset into 1D (i.e. np.ndarray.ravel()) and
    unravelling a KD set of observations into a 1D set of observations. This is basically grouping axes. As
    an example, if you had a 4D dataset which consisted of 2D-scanning valence band ARPES, then the dimensions
    on our dataset would be "[x,y,eV,phi]". We can group these into [spatial=(x, y), spectral=(eV, phi)] and
    perform PCA or another analysis of the spectral features over different spatial observations.

    If our data was called `f`, this can be accomplished with:

    ```
    transformed, decomp = decomposition_along(f.stack(spectral=['eV', 'phi']), ['x', 'y'], PCA)
    transformed.dims # -> [X, Y, components]
    ```

    The results of `decomposition_along` can be explored with `arpes.widgets.pca_explorer`, regardless of
    the decomposition class.

    :param data: Input data, can be N-dimensional but should only include one "spectral" axis.
    :param axes: Several axes to be treated as a single axis labeling the list of observations.
    :param decomposition_cls: A sklearn.decomposition class (such as PCA or ICA) to be used
                              to perform the decomposition.
    :param correlation: Controls whether StandardScaler() is used as the first stage of the data ingestion
                        pipeline for sklearn.
    :param kwargs: Forwarded to ``decomposition_cls``.
    :return: Tuple of (transformed data with a ``components`` dimension, fitted decomposition instance).
    :raises ValueError: If the data is not 2D after grouping ``axes``.
    """
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler

    if len(axes) > 1:
        flattened_data = normalize_to_spectrum(data).stack(fit_axis=axes)
        stacked = True
    else:
        flattened_data = normalize_to_spectrum(data).S.transpose_to_back(
            axes[0])
        stacked = False

    if len(flattened_data.dims) != 2:
        raise ValueError(
            'Inappropriate number of dimensions after flattening: [{}]'.format(
                flattened_data.dims))

    if correlation:
        pipeline = make_pipeline(StandardScaler(), decomposition_cls(**kwargs))
    else:
        pipeline = make_pipeline(decomposition_cls(**kwargs))

    # sklearn expects (n_observations, n_features): observations are along the last dim.
    pipeline.fit(flattened_data.values.T)

    decomp = pipeline.steps[-1][1]

    transform = decomp.transform(flattened_data.values.T)

    # Reuse the leading (feature) dimension as the component axis: keep one entry per component and rename it.
    into = flattened_data.copy(deep=True)
    component_dim = into.dims[0]
    into = into.isel(**{component_dim: slice(0, transform.shape[1])})
    into = into.rename({component_dim: 'components'})

    into.values = transform.T

    if stacked:
        into = into.unstack('fit_axis')

    provenance(
        into, data, {
            'what': 'sklearn decomposition',
            'by': 'decomposition_along',
            'axes': axes,
            # Record the value actually used (previously always False).
            'correlation': correlation,
            'decomposition_cls': decomposition_cls.__name__,
        })

    return into, decomp
Пример #11
0
def curvature(arr: xr.DataArray, directions=None, alpha=1, beta=None):
    """
    Computes the two-dimensional curvature of ``arr`` in the plane spanned by ``directions``.

    With x = directions[0], y = directions[1] and f_x, f_y, f_xx, f_yy, f_xy the numerical derivatives
    (f_xy is the y-derivative of f_x), the result is

        C = ((1 + C_x f_x^2) C_y f_yy - 2 C_x C_y f_x f_y f_xy + (1 + C_y f_y^2) C_x f_xx)
            / (1 + C_x f_x^2 + C_y f_y^2)^(3/2)

    where, with M = max|f_x|^2 + max|f_y|^2,

        C_y = (dy / dx) * M * alpha
        C_x = (dx / dy) * M * alpha

    Steps dx and dy are taken from the first two coordinate values of each direction. NaN/inf in the first
    derivatives are replaced in place via ``np.nan_to_num`` before anything else is computed.

    :param arr: Input data.
    :param directions: Two dimension names; defaults to the first two dims of ``arr``.
    :param alpha: Dimensionless regularization parameter, chosen semi-universally, but with no particular
                  justification.
    :param beta: If given, overrides ``alpha`` with ``10 ** beta``.
    :return: xr.DataArray with the coords, dims and attrs of ``arr`` (minus ``id``).
    """
    if beta is not None:
        alpha = np.power(10., beta)

    if directions is None:
        directions = arr.dims[:2]

    axes = tuple(arr.dims.index(name) for name in directions)
    step_x, step_y = (
        float(arr.coords[name][1] - arr.coords[name][0]) for name in directions)

    grad_x, grad_y = np.gradient(arr.values, step_x, step_y, axis=axes)
    for gradient in (grad_x, grad_y):
        np.nan_to_num(gradient, copy=False)

    magnitude = np.max(np.abs(grad_x)) ** 2 + np.max(np.abs(grad_y)) ** 2
    c_y = (step_y / step_x) * magnitude * alpha
    c_x = (step_x / step_y) * magnitude * alpha

    grad_x_sq = np.power(grad_x, 2)
    grad_y_sq = np.power(grad_y, 2)
    second_y = np.gradient(grad_y, step_y, axis=axes[1])
    second_x = np.gradient(grad_x, step_x, axis=axes[0])
    mixed = np.gradient(grad_x, step_y, axis=axes[1])

    denominator = np.power((1 + c_x * grad_x_sq + c_y * grad_y_sq), 1.5)
    numerator = ((1 + c_x * grad_x_sq) * c_y * second_y
                 - 2 * c_x * c_y * grad_x * grad_y * mixed
                 + (1 + c_y * grad_y_sq) * c_x * second_x)

    curv = xr.DataArray(numerator / denominator,
                        arr.coords,
                        arr.dims,
                        attrs=arr.attrs)

    if 'id' in curv.attrs:
        del curv.attrs['id']
        provenance(
            curv, arr, {
                'what': 'Curvature',
                'by': 'curvature',
                'directions': directions,
                'alpha': alpha,
            })
    return curv
Пример #12
0
def slice_along_path(arr: xr.DataArray,
                     interpolation_points=None,
                     axis_name=None,
                     resolution=None,
                     n_points=None,
                     shift_gamma=True,
                     extend_to_edge=False,
                     **kwargs):
    """
    Interpolates along a path through a volume. If the volume is higher dimensional than the desired path, the
    interpolation is broadcasted along the free dimensions. This allows one to specify a k-space path and receive
    the band structure along this path in k-space.

    Points can either be specified by coordinates, or by reference to symmetry points, should they exist in the source
    array. These symmetry points are translated to regular coordinates immediately, but are provided as a convenience.
    If not all points specify the same set of coordinates, an attempt will be made to unify the coordinates. As an example,
    if the specified path is (kx=0, ky=0, T=20) -> (kx=1, ky=1), the path will be made between (kx=0, ky=0, T=20) ->
    (kx=1, ky=1, T=20). On the other hand, the path (kx=0, ky=0, T=20) -> (kx=1, ky=1, T=40) -> (kx=0, ky=1) will result
    in an error because there is no way to break the ambiguity on the temperature for the last coordinate.
    Waypoints are copied before this unification, so neither the caller's dicts nor
    ``arr.attrs['symmetry_points']`` are modified.

    A reasonable value will be chosen for the resolution, near the maximum resolution of any of the interpolated
    axes by default.

    This function transparently handles the entire path. An alternate approach would be to convert each segment
    separately and concatenate the interpolated axis with xarray. Every sample along the path, including the two
    end points, is mapped onto the segment containing it; floating point round-off at the ends is clamped onto the
    first/last segment instead of producing a spurious 0 coordinate.

    If the sentinel value 'G' for the Gamma point is included in the interpolation points, the coordinate axis of the
    interpolated coordinate will be shifted so that its value at the Gamma point is 0. You can opt out of this with the
    parameter 'shift_gamma'

    :param arr: Source data
    :param interpolation_points: Path vertices (at least two): coordinate dicts or symmetry point names.
    :param axis_name: Label for the interpolated axis. Under special circumstances a reasonable name will be chosen,
    such as when the interpolation dimensions are kx and ky: in this case the interpolated dimension will be labeled kp.
    In mixed or ambiguous situations the axis will be labeled by the default value 'inter'.
    :param resolution: Requested resolution along the interpolated axis.
    :param n_points: Requested number of points in the new array. Only one of resolution and n_points should be set.
    :param shift_gamma: Controls whether the interpolated axis is shifted to a value of 0 at Gamma.
    :param extend_to_edge: Controls whether or not to scale the vector S - G for symmetry point S so that you interpolate
    to the edge of the available data
    :param kwargs: Accepted for compatibility; unused.
    :return: xr.Dataset (from ``convert_coordinates(..., as_dataset=True)``) containing the interpolated data.
    :raises ValueError: On conflicting/missing arguments, fewer than two path points, or ambiguous broadcasting.
    """
    # collections.Iterable was removed in Python 3.10; the ABC lives in collections.abc.
    from collections.abc import Iterable

    if resolution is not None and n_points is not None:
        raise ValueError("Only set one of resolution and n_points!")

    if interpolation_points is None:
        raise ValueError(
            'You must provide points specifying an interpolation path')

    def extract_symmetry_point(name):
        raw_point = arr.attrs['symmetry_points'][name]
        G = arr.attrs['symmetry_points']['G']

        if not extend_to_edge or name == 'G':
            # Copy so that the broadcasting below never writes into arr.attrs.
            return dict(raw_point)

        # scale the point so that it reaches the edge of the dataset
        point_dims = [d for d in arr.dims if d in raw_point]
        S = np.array([raw_point[d] for d in point_dims])
        G = np.array([G[d] for d in point_dims])

        scale_factor = np.inf
        for i, d in enumerate(point_dims):
            dS = (S - G)[i]
            coord = arr.coords[d]

            if np.abs(dS) < 0.001:
                continue

            edge = np.min(coord) if dS < 0 else np.max(coord)
            required_scale = (edge - G[i]) / dS
            if required_scale < scale_factor:
                scale_factor = float(required_scale)

        S = (S - G) * scale_factor + G
        return dict(zip(point_dims, S))

    # Copy explicit waypoints too, so unification does not mutate the caller's dicts.
    parsed_interpolation_points = [
        dict(x) if isinstance(x, Iterable) and not isinstance(x, str)
        else extract_symmetry_point(x) for x in interpolation_points
    ]

    if len(parsed_interpolation_points) < 2:
        raise ValueError(
            'An interpolation path needs at least two points')

    free_coordinates = list(arr.dims)
    seen_coordinates = collections.defaultdict(set)
    for point in parsed_interpolation_points:
        for coord, value in point.items():
            seen_coordinates[coord].add(value)
            if coord in free_coordinates:
                free_coordinates.remove(coord)

    # Fill coordinates that only some waypoints specify, when the value is unambiguous.
    for point in parsed_interpolation_points:
        for coord, values in seen_coordinates.items():
            if coord in point:
                continue
            if len(values) != 1:
                raise ValueError(
                    'Ambiguous interpolation waypoint broadcast at dimension {}'
                    .format(coord))
            point[coord] = next(iter(values))

    if axis_name is None:
        axis_name = {
            ('beta', 'phi'): 'angle',
            ('chi', 'phi'): 'angle',
            ('phi', 'psi'): 'angle',
            ('phi', 'theta'): 'angle',
            ('kx', 'ky'): 'kp',
            ('kx', 'kz'): 'k',
            ('ky', 'kz'): 'k',
            ('kx', 'ky', 'kz'): 'k',
        }.get(tuple(sorted(seen_coordinates.keys())), 'inter')

        if axis_name in ('angle', 'inter'):
            warnings.warn('Interpolating along axes with different dimensions '
                          'will not include Jacobian correction factor.')

    converted_dims = free_coordinates + [axis_name]

    path_segments = list(
        zip(parsed_interpolation_points, parsed_interpolation_points[1:]))

    def element_distance(waypoint_a, waypoint_b):
        delta = np.array(
            [waypoint_a[k] - waypoint_b[k] for k in waypoint_a.keys()])
        return np.linalg.norm(delta)

    def required_sampling_density(waypoint_a, waypoint_b):
        ks = waypoint_a.keys()
        dist = element_distance(waypoint_a, waypoint_b)
        delta = np.array([waypoint_a[k] - waypoint_b[k] for k in ks])
        delta_idx = [
            abs(d / (arr.coords[k][1] - arr.coords[k][0]))
            for d, k in zip(delta, ks)
        ]
        return dist / np.max(delta_idx)

    # Approximate how many points we should use
    segment_lengths = [element_distance(*segment) for segment in path_segments]
    path_length = sum(segment_lengths)

    gamma_offset = 0  # offset the gamma point to a k coordinate of 0 if possible
    if 'G' in interpolation_points and shift_gamma:
        gamma_offset = sum(segment_lengths[0:interpolation_points.index('G')])

    if resolution is None:
        if n_points is None:
            resolution = np.min([
                required_sampling_density(*segment)
                for segment in path_segments
            ])
        else:
            # Informational only: n_points is used directly below.
            resolution = path_length / n_points

    # Path position at which each segment starts, followed by the total length.
    segment_starts = np.concatenate([[0.], np.cumsum(segment_lengths)])
    lengths = np.asarray(segment_lengths, dtype=float)
    last_segment = len(segment_lengths) - 1

    def converter_for_coordinate_name(name):
        if name in free_coordinates:
            def raw_interpolator(*coordinates):
                return coordinates[free_coordinates.index(name)]

            return raw_interpolator

        # Conversion involves the interpolated coordinates
        segment_first = np.array([seg[0][name] for seg in path_segments], dtype=float)
        segment_last = np.array([seg[1][name] for seg in path_segments], dtype=float)

        def interpolated_coordinate_to_raw(*coordinates):
            # Coordinate order is [*free_coordinates, interpolated]
            interpolated = coordinates[len(free_coordinates)] + gamma_offset

            # Index of the segment [start, end) holding each sample, clamped so that round-off at either
            # end of the path still lands on the first/last segment instead of being left at 0.
            segment_idx = np.clip(
                np.searchsorted(segment_starts, interpolated, side='right') - 1,
                0, last_segment)
            offsets = interpolated - segment_starts[segment_idx]
            seg_len = lengths[segment_idx]
            # Zero-length segments (repeated waypoints) map to their start point instead of dividing by 0.
            normalized = np.divide(offsets, seg_len,
                                   out=np.zeros_like(offsets, dtype=float),
                                   where=seg_len > 0)

            return (segment_first[segment_idx] * (1 - normalized)
                    + segment_last[segment_idx] * normalized)

        return interpolated_coordinate_to_raw

    converted_coordinates = {d: arr.coords[d].values for d in free_coordinates}

    if n_points is None:
        n_points = int(path_length / resolution)

    # Adjust this coordinate under special circumstances
    converted_coordinates[axis_name] = np.linspace(0, path_length,
                                                   n_points) - gamma_offset

    converted_ds = convert_coordinates(
        arr,
        converted_coordinates, {
            'dims': converted_dims,
            'transforms': {d: converter_for_coordinate_name(d) for d in arr.dims},
        },
        as_dataset=True)

    if axis_name in arr.dims and len(parsed_interpolation_points) == 2:
        if parsed_interpolation_points[1][
                axis_name] < parsed_interpolation_points[0][axis_name]:
            # swap the sign on this axis as a convenience to the caller
            converted_ds.coords[
                axis_name].data = -converted_ds.coords[axis_name].data

    if 'id' in converted_ds.attrs:
        del converted_ds.attrs['id']
        provenance(
            converted_ds, arr, {
                'what': 'Slice along path',
                'by': 'slice_along_path',
                'parsed_interpolation_points': parsed_interpolation_points,
                'interpolation_points': interpolation_points,
            })

    return converted_ds
Пример #13
0
def transform_dataarray_axis(f,
                             old_axis_name: str,
                             new_axis_name: str,
                             new_axis,
                             dataset: xr.Dataset,
                             prep_name,
                             transform_spectra=None,
                             remove_old=True):
    """
    Resamples data variables of ``dataset`` from ``old_axis_name`` onto ``new_axis`` with
    ``scipy.ndimage.geometric_transform`` (linear interpolation, float32 output).

    :param f: Coordinate mapping given to geometric_transform after binding ``axis=<index of the old axis>``
              with functools.partial; it receives output index coordinates and returns input index coordinates.
    :param old_axis_name: Dimension being replaced.
    :param new_axis_name: Name of the replacing dimension.
    :param new_axis: Coordinate values of the new dimension (its length sets the output size).
    :param dataset: Source xr.Dataset; it is copied, not modified.
    :param prep_name: Callable mapping a variable name to the name of its transformed counterpart.
    :param transform_spectra: Mapping whose keys name the variables to transform; by default every data
                              variable that has ``old_axis_name`` among its dims.
    :param remove_old: If True, the original variables are dropped from the result.
    :return: xr.Dataset with the transformed variables merged in (``id`` removed, provenance recorded).
    """
    ds = dataset.copy()
    if transform_spectra is None:
        # transform *all* DataArrays in the dataset that have old_axis_name in their dimensions
        transform_spectra = {
            k: v
            for k, v in ds.data_vars.items() if old_axis_name in v.dims
        }

    ds.coords[new_axis_name] = new_axis

    new_dataarrays = []
    for name in transform_spectra.keys():
        dr = ds[name]

        old_axis = dr.dims.index(old_axis_name)
        shape = list(dr.sizes.values())
        shape[old_axis] = len(new_axis)
        new_dims = list(dr.dims)
        new_dims[old_axis] = new_axis_name

        g = functools.partial(f, axis=old_axis)
        output = geometric_transform(dr.values,
                                     g,
                                     output_shape=shape,
                                     output='f',
                                     order=1)

        new_coords = dict(dr.coords)
        new_coords.pop(old_axis_name)

        new_dataarray = xr.DataArray(output,
                                     coords=new_coords,
                                     dims=new_dims,
                                     attrs=dr.attrs.copy(),
                                     name=prep_name(dr.name))
        new_dataarray.attrs.pop('id', None)
        new_dataarrays.append(new_dataarray)

        if remove_old:
            del ds[name]
        else:
            # ``assert cond and "msg"`` never displayed the message; use the assert-message form.
            assert prep_name(name) != name, "You must make sure names don't collide"

    new_ds = xr.merge([ds, *new_dataarrays])

    new_ds.attrs.update(ds.attrs.copy())

    # Check the attrs: ``'id' in new_ds`` tests for a *data variable* named 'id'.
    new_ds.attrs.pop('id', None)

    provenance(
        new_ds, dataset, {
            'what': 'Transformed a Dataset coordinate axis',
            'by': 'transform_dataarray_axis',
            'old_axis': old_axis_name,
            'new_axis': new_axis_name,
            'transformed_vars': list(transform_spectra.keys()),
        })

    return new_ds