예제 #1
0
def test_local_pixel_derivatives(spatial_wcs_2d_small_angle):
    """
    Check ``local_partial_pixel_derivatives`` near and far from the
    reference pixel, with and without ``normalize_by_world``.
    """
    # Boolean mask selecting the two off-diagonal entries of a 2x2 matrix.
    not_diag = np.logical_not(np.diag([1, 1]))

    # At (or close to) the reference pixel the diagonal should equal the
    # cdelt and the off-diagonal (cross) terms should vanish.
    derivs = local_partial_pixel_derivatives(spatial_wcs_2d_small_angle, 3, 3)
    np.testing.assert_allclose(np.diag(derivs), spatial_wcs_2d_small_angle.wcs.cdelt)
    np.testing.assert_allclose(derivs[not_diag].flat, [0, 0], atol=1e-10)

    # Far away from the reference pixel this should not equal the cdelt
    derivs = local_partial_pixel_derivatives(spatial_wcs_2d_small_angle, 3e4, 3e4)
    assert not np.allclose(np.diag(derivs), spatial_wcs_2d_small_angle.wcs.cdelt)

    # With normalize_by_world=True, the diagonal at (or close to) the
    # reference pixel should be 1 rather than the cdelt, and the
    # off-diagonal terms should still vanish.
    derivs = local_partial_pixel_derivatives(spatial_wcs_2d_small_angle, 3, 3,
                                             normalize_by_world=True)
    np.testing.assert_allclose(np.diag(derivs), [1, 1])
    np.testing.assert_allclose(derivs[not_diag].flat, [0, 0], atol=1e-8)
예제 #2
0
def _coord_type_wrap_and_format_unit(axis_type, axis_unit):
    """
    Classify a single world axis from its physical type.

    Parameters
    ----------
    axis_type : str or None
        The world axis physical type (dot-separated tokens), or `None`.
    axis_unit : `~astropy.units.Unit`
        The world axis unit; used as the format unit unless a more specific
        one is chosen below.

    Returns
    -------
    coord_type : str
        One of ``'longitude'``, ``'latitude'`` or ``'scalar'``.
    coord_wrap : float or None
        Wrapping angle for some solar longitudes, otherwise `None`.
    format_unit : `~astropy.units.Unit`
        Unit in which the coordinate should be formatted.
    """
    coord_type = 'scalar'
    coord_wrap = None
    format_unit = axis_unit

    if axis_type is None:
        return coord_type, coord_wrap, format_unit

    # Solar frames are matched by substring; the first match wins.
    for solar_type, solar_coord_type, solar_wrap, solar_unit in (
            ("pos.helioprojective.lon", "longitude", 180., u.arcsec),
            ("pos.helioprojective.lat", "latitude", None, u.arcsec),
            ("pos.heliographic.stonyhurst.lon", "longitude", 180., u.deg),
            ("pos.heliographic.stonyhurst.lat", "latitude", None, u.deg),
            ("pos.heliographic.carrington.lon", "longitude", 360., u.deg),
            ("pos.heliographic.carrington.lat", "latitude", None, u.deg)):
        if solar_type in axis_type:
            return solar_coord_type, solar_wrap, solar_unit

    axis_type_split = axis_type.split('.')

    # Only positional axes can be angular coordinates.
    if "pos" not in axis_type_split:
        return coord_type, coord_wrap, format_unit

    if "lon" in axis_type_split:
        coord_type = "longitude"
    elif "lat" in axis_type_split:
        coord_type = "latitude"
    elif "ra" in axis_type_split:
        coord_type = "longitude"
        format_unit = u.hourangle
    elif "dec" in axis_type_split:
        coord_type = "latitude"
    elif "alt" in axis_type_split:
        # Altitude (e.g. 'pos.az.alt') is a latitude-like angle.
        coord_type = "latitude"
    elif "azi" in axis_type_split:
        # Azimuth (e.g. 'pos.az.azi') wraps around like a longitude.
        coord_type = "longitude"
    elif "az" in axis_type_split:
        # Other alt-az quantities (e.g. 'pos.az.zd') keep the previous
        # latitude classification.
        coord_type = "latitude"
    elif "long" in axis_type_split:
        coord_type = "longitude"

    return coord_type, coord_wrap, format_unit


def _fits_wcs_axis_name(wcs, idx, axis_type):
    """
    Return the name (or tuple of name aliases) for world axis ``idx`` of a
    FITS WCS, or of a `SlicedLowLevelWCS` wrapping a FITS WCS.

    The aliases are the lower-cased full CTYPE and its first four characters
    with dashes removed (deduplicated). If ``axis_type`` is truthy it is
    prepended. A single remaining name is returned as a plain string.
    """
    if isinstance(wcs, WCS):
        ctype = wcs.wcs.ctype[idx]
    else:
        # Sliced WCS: map the kept world axis back to the underlying WCS.
        ctype = wcs._wcs.wcs.ctype[wcs._world_keep[idx]]

    name = [ctype.lower(), ctype[:4].replace('-', '').lower()]
    if name[0] == name[1]:
        name = name[0:1]
    if axis_type:
        name.insert(0, axis_type)
    return tuple(name) if len(name) > 1 else name[0]


def transform_coord_meta_from_wcs(wcs, frame_class, slices=None):
    """
    Build the pixel-to-world transform and coordinate metadata for a WCS.

    Parameters
    ----------
    wcs : object
        WCS object exposing ``pixel_n_dim``, ``world_n_dim``,
        ``world_axis_physical_types`` and ``world_axis_units``.
    frame_class : type
        Frame class. `RectangularFrame`, `RectangularFrame1D` and
        `EllipticalFrame` get special handling. Any other class must provide
        ``spine_names``.
    slices : iterable, optional
        Slicing specification passed to ``apply_slices``. Required when the
        WCS has more than two pixel dimensions.

    Returns
    -------
    transform : `WCSPixel2WorldTransform`
        Transform built from the sliced WCS.
    coord_meta : dict
        Per-world-axis lists under ``'name'``, ``'type'``, ``'wrap'``,
        ``'unit'``, ``'visible'``, ``'format_unit'``,
        ``'default_axislabel_position'``, ``'default_ticklabel_position'``
        and ``'default_ticks_position'``.

    Raises
    ------
    ValueError
        If the WCS has more than two pixel dimensions and ``slices`` is
        missing or has the wrong length.
    """
    if slices is not None:
        slices = tuple(slices)

    if wcs.pixel_n_dim > 2:
        if slices is None:
            raise ValueError("WCS has more than 2 pixel dimensions, so "
                             "'slices' should be set")
        elif len(slices) != wcs.pixel_n_dim:
            raise ValueError("'slices' should have as many elements as WCS "
                             "has pixel dimensions (should be {})".format(
                                 wcs.pixel_n_dim))

    # A SlicedLowLevelWCS wrapping a FITS WCS still has CTYPEs available, so
    # it gets the same CTYPE-based name aliases as a plain FITS WCS.
    is_fits_wcs = isinstance(wcs, WCS) or (
        isinstance(wcs, SlicedLowLevelWCS) and isinstance(wcs._wcs, WCS))

    coord_meta = {
        'name': [],
        'type': [],
        'wrap': [],
        'unit': [],
        'visible': [],
        'format_unit': [],
    }

    for idx in range(wcs.world_n_dim):

        axis_type = wcs.world_axis_physical_types[idx]
        axis_unit = u.Unit(wcs.world_axis_units[idx])
        coord_type, coord_wrap, format_unit = _coord_type_wrap_and_format_unit(
            axis_type, axis_unit)

        coord_meta['type'].append(coord_type)
        coord_meta['wrap'].append(coord_wrap)
        coord_meta['format_unit'].append(format_unit)
        coord_meta['unit'].append(axis_unit)

        # For FITS-WCS, for backward-compatibility, we need to make sure that we
        # provide aliases based on CTYPE for the name.
        if is_fits_wcs:
            name = _fits_wcs_axis_name(wcs, idx, axis_type)
        else:
            name = axis_type or ''

        coord_meta['name'].append(name)

    coord_meta['default_axislabel_position'] = [''] * wcs.world_n_dim
    coord_meta['default_ticklabel_position'] = [''] * wcs.world_n_dim
    coord_meta['default_ticks_position'] = [''] * wcs.world_n_dim

    def _set_default_positions(index, spines):
        # Axis label, tick labels and ticks default to the same spines.
        coord_meta['default_axislabel_position'][index] = spines
        coord_meta['default_ticklabel_position'][index] = spines
        coord_meta['default_ticks_position'][index] = spines

    transform_wcs, invert_xy, world_map = apply_slices(wcs, slices)

    transform = WCSPixel2WorldTransform(transform_wcs, invert_xy=invert_xy)

    # Only world axes that survive slicing are visible.
    coord_meta['visible'].extend(idx in world_map
                                 for idx in range(len(coord_meta['type'])))

    inv_all_corr = [False] * wcs.world_n_dim
    m = transform_wcs.axis_correlation_matrix.copy()
    if invert_xy:
        inv_all_corr = np.all(m, axis=1)
        m = m[:, ::-1]

    if frame_class is RectangularFrame:

        for i, spine_name in enumerate('bltr'):
            pos = np.nonzero(m[:, i % 2])[0]
            # If all the axes we have are correlated with each other and we
            # have inverted the axes, then we need to reverse the index so we
            # put the 'y' on the left.
            if inv_all_corr[i % 2]:
                pos = pos[::-1]

            if len(pos) > 0:
                _set_default_positions(world_map[pos[0]], spine_name)
                # Clear this world axis so it is not assigned another spine.
                m[pos[0], :] = 0

        # In the special and common case where the frame is rectangular and
        # we are dealing with 2-d WCS (after slicing), we show all ticks on
        # all axes for backward-compatibility.
        if len(world_map) == 2:
            for index in world_map:
                coord_meta['default_ticks_position'][index] = 'bltr'

    elif frame_class is RectangularFrame1D:
        derivs = np.abs(
            local_partial_pixel_derivatives(transform_wcs,
                                            *[0] * transform_wcs.pixel_n_dim,
                                            normalize_by_world=False))[:, 0]
        for spine_name in 'bt':
            # Here we are iterating over the correlated axes in world axis order.
            # We want to sort the correlated axes by their partial derivatives,
            # so we put the most rapidly changing world axis on the bottom.
            pos = np.nonzero(m[:, 0])[0]
            order = np.argsort(derivs[pos])[::-1]  # Sort largest to smallest
            pos = pos[order]
            if len(pos) > 0:
                _set_default_positions(world_map[pos[0]], spine_name)
                # Clear this world axis so the top spine picks the next one.
                m[pos[0], :] = 0

        # In the special and common case where the frame is 1-d rectangular
        # and we are dealing with 1-d WCS (after slicing), we show ticks on
        # both the bottom and top spines for backward-compatibility.
        if len(world_map) == 1:
            for index in world_map:
                coord_meta['default_ticks_position'][index] = 'bt'

    elif frame_class is EllipticalFrame:

        if 'longitude' in coord_meta['type']:
            _set_default_positions(coord_meta['type'].index('longitude'), 'h')

        if 'latitude' in coord_meta['type']:
            _set_default_positions(coord_meta['type'].index('latitude'), 'c')

    else:

        for index in range(len(coord_meta['type'])):
            if index in world_map:
                _set_default_positions(index, frame_class.spine_names)

    return transform, coord_meta