def test_local_pixel_derivatives(spatial_wcs_2d_small_angle):
    """Check local_partial_pixel_derivatives near and far from the reference pixel."""
    # Boolean mask selecting the two off-diagonal entries of a 2x2 matrix.
    not_diag = np.logical_not(np.diag([1, 1]))

    # At (or close to) the reference pixel this should equal the cdelt
    derivs = local_partial_pixel_derivatives(spatial_wcs_2d_small_angle, 3, 3)
    np.testing.assert_allclose(np.diag(derivs),
                               spatial_wcs_2d_small_angle.wcs.cdelt)
    np.testing.assert_allclose(derivs[not_diag].flat, [0, 0], atol=1e-10)

    # Far away from the reference pixel this should not equal the cdelt
    derivs = local_partial_pixel_derivatives(spatial_wcs_2d_small_angle,
                                             3e4, 3e4)
    assert not np.allclose(np.diag(derivs),
                           spatial_wcs_2d_small_angle.wcs.cdelt)

    # With normalize_by_world=True, at (or close to) the reference pixel the
    # diagonal should be 1 and the off-diagonal terms should vanish.
    derivs = local_partial_pixel_derivatives(spatial_wcs_2d_small_angle, 3, 3,
                                             normalize_by_world=True)
    np.testing.assert_allclose(np.diag(derivs), [1, 1])
    np.testing.assert_allclose(derivs[not_diag].flat, [0, 0], atol=1e-8)
def transform_coord_meta_from_wcs(wcs, frame_class, slices=None):
    """
    Create the pixel-to-world transform and coordinate metadata for a WCS.

    Parameters
    ----------
    wcs : object
        Low-level WCS object; ``pixel_n_dim``, ``world_n_dim``,
        ``world_axis_physical_types`` and ``world_axis_units`` are read from
        it. FITS-WCS objects (``WCS``), and ``SlicedLowLevelWCS`` instances
        wrapping one, additionally get CTYPE-based aliases in ``'name'``.
    frame_class : type
        Frame class of the axes. ``RectangularFrame``, ``RectangularFrame1D``
        and ``EllipticalFrame`` get dedicated default placements; any other
        class must provide a ``spine_names`` attribute.
    slices : iterable, optional
        Converted to a tuple and forwarded to ``apply_slices``. Required, with
        one element per pixel dimension, when the WCS has more than two pixel
        dimensions.

    Returns
    -------
    transform : WCSPixel2WorldTransform
        Transform built from the WCS returned by ``apply_slices``.
    coord_meta : dict
        One list entry per world axis under each of the keys ``'name'``,
        ``'type'``, ``'wrap'``, ``'unit'``, ``'visible'``, ``'format_unit'``,
        ``'default_axislabel_position'``, ``'default_ticklabel_position'``
        and ``'default_ticks_position'``.

    Raises
    ------
    ValueError
        If the WCS has more than two pixel dimensions and ``slices`` is not
        given or does not have one element per pixel dimension.
    """
    if slices is not None:
        slices = tuple(slices)

    if wcs.pixel_n_dim > 2:
        if slices is None:
            raise ValueError("WCS has more than 2 pixel dimensions, so "
                             "'slices' should be set")
        elif len(slices) != wcs.pixel_n_dim:
            raise ValueError("'slices' should have as many elements as WCS "
                             "has pixel dimensions (should be {})".format(
                                 wcs.pixel_n_dim))

    # A SlicedLowLevelWCS wrapping a FITS WCS still has CTYPEs available on
    # the wrapped object, so it must count as FITS-WCS too; otherwise the
    # sliced branch of the CTYPE alias code below could never be reached.
    is_fits_wcs = isinstance(wcs, WCS) or (
        isinstance(wcs, SlicedLowLevelWCS) and isinstance(wcs._wcs, WCS))

    coord_meta = {key: [] for key in
                  ('name', 'type', 'wrap', 'unit', 'visible', 'format_unit')}

    # Solar physical types, matched by substring in this order before the
    # generic ``pos.*`` handling. Entries are
    # (physical type, wrap, format unit, coordinate type).
    solar_axis_types = (
        ("pos.helioprojective.lon", 180., u.arcsec, "longitude"),
        ("pos.helioprojective.lat", None, u.arcsec, "latitude"),
        ("pos.heliographic.stonyhurst.lon", 180., u.deg, "longitude"),
        ("pos.heliographic.stonyhurst.lat", None, u.deg, "latitude"),
        ("pos.heliographic.carrington.lon", 360., u.deg, "longitude"),
        ("pos.heliographic.carrington.lat", None, u.deg, "latitude"),
    )

    for idx in range(wcs.world_n_dim):

        axis_type = wcs.world_axis_physical_types[idx]
        axis_unit = u.Unit(wcs.world_axis_units[idx])
        coord_wrap = None
        format_unit = axis_unit
        coord_type = 'scalar'

        if axis_type is not None:
            for solar_type, wrap, unit, solar_coord_type in solar_axis_types:
                if solar_type in axis_type:
                    coord_wrap = wrap
                    format_unit = unit
                    coord_type = solar_coord_type
                    break
            else:
                # No solar type matched: use the generic ``pos.*`` handling.
                axis_type_split = axis_type.split('.')
                if "pos" in axis_type_split:
                    if "lon" in axis_type_split:
                        coord_type = "longitude"
                    elif "lat" in axis_type_split:
                        coord_type = "latitude"
                    elif "ra" in axis_type_split:
                        coord_type = "longitude"
                        format_unit = u.hourangle
                    elif "dec" in axis_type_split:
                        coord_type = "latitude"
                    elif "alt" in axis_type_split or "zd" in axis_type_split:
                        # Altitude / zenith distance are latitude-like. This
                        # must be checked before "az", since types such as
                        # 'pos.az.alt' contain both "az" and "alt".
                        coord_type = "latitude"
                    elif "az" in axis_type_split or "azi" in axis_type_split:
                        # Azimuth is longitude-like.
                        coord_type = "longitude"
                    elif "long" in axis_type_split:
                        coord_type = "longitude"

        coord_meta['type'].append(coord_type)
        coord_meta['wrap'].append(coord_wrap)
        coord_meta['format_unit'].append(format_unit)
        coord_meta['unit'].append(axis_unit)

        # For FITS-WCS, for backward-compatibility, we need to make sure that
        # we provide aliases based on CTYPE for the name: the full lowercased
        # CTYPE and its first four characters with dashes removed.
        if is_fits_wcs:
            if isinstance(wcs, WCS):
                ctype = wcs.wcs.ctype[idx]
            else:
                # Sliced FITS WCS: map the sliced world index back to the
                # world index of the wrapped WCS.
                ctype = wcs._wcs.wcs.ctype[wcs._world_keep[idx]]
            name = [ctype.lower(), ctype[:4].replace('-', '').lower()]
            if name[0] == name[1]:
                name = name[0:1]
            if axis_type:
                name.insert(0, axis_type)
            name = tuple(name) if len(name) > 1 else name[0]
        else:
            name = axis_type or ''
        coord_meta['name'].append(name)

    n_world = wcs.world_n_dim
    coord_meta['default_axislabel_position'] = [''] * n_world
    coord_meta['default_ticklabel_position'] = [''] * n_world
    coord_meta['default_ticks_position'] = [''] * n_world

    def set_default_positions(index, position):
        # Axis label, tick labels and ticks share the same default position.
        coord_meta['default_axislabel_position'][index] = position
        coord_meta['default_ticklabel_position'][index] = position
        coord_meta['default_ticks_position'][index] = position

    transform_wcs, invert_xy, world_map = apply_slices(wcs, slices)

    transform = WCSPixel2WorldTransform(transform_wcs, invert_xy=invert_xy)

    # A world axis is visible when apply_slices kept it in world_map.
    coord_meta['visible'].extend(
        i in world_map for i in range(len(coord_meta['type'])))

    inv_all_corr = [False] * n_world
    m = transform_wcs.axis_correlation_matrix.copy()
    if invert_xy:
        inv_all_corr = np.all(m, axis=1)
        m = m[:, ::-1]

    if frame_class is RectangularFrame:

        for i, spine_name in enumerate('bltr'):
            pos = np.nonzero(m[:, i % 2])[0]

            # If all the axes we have are correlated with each other and we
            # have inverted the axes, then we need to reverse the index so we
            # put the 'y' on the left.
            if inv_all_corr[i % 2]:
                pos = pos[::-1]

            if len(pos) > 0:
                index = world_map[pos[0]]
                set_default_positions(index, spine_name)
                # Consume this world axis so later spines pick another one.
                m[pos[0], :] = 0

        # In the special and common case where the frame is rectangular and
        # we are dealing with 2-d WCS (after slicing), we show all ticks on
        # all axes for backward-compatibility.
        if len(world_map) == 2:
            for index in world_map:
                coord_meta['default_ticks_position'][index] = 'bltr'

    elif frame_class is RectangularFrame1D:

        derivs = np.abs(local_partial_pixel_derivatives(
            transform_wcs,
            *[0] * transform_wcs.pixel_n_dim,
            normalize_by_world=False))[:, 0]

        for spine_name in 'bt':
            # Here we are iterating over the correlated axes in world axis
            # order. We want to sort the correlated axes by their partial
            # derivatives, so we put the most rapidly changing world axis on
            # the bottom.
            pos = np.nonzero(m[:, 0])[0]
            order = np.argsort(derivs[pos])[::-1]  # Sort largest to smallest
            pos = pos[order]
            if len(pos) > 0:
                index = world_map[pos[0]]
                set_default_positions(index, spine_name)
                # Consume this world axis so the top spine picks another one.
                m[pos[0], :] = 0

        # In the special and common case where the frame is rectangular and
        # we are dealing with 1-d WCS (after slicing), we show all ticks on
        # all axes for backward-compatibility.
        if len(world_map) == 1:
            for index in world_map:
                coord_meta['default_ticks_position'][index] = 'bt'

    elif frame_class is EllipticalFrame:

        # Only the first longitude and first latitude axes get placements.
        if 'longitude' in coord_meta['type']:
            set_default_positions(coord_meta['type'].index('longitude'), 'h')

        if 'latitude' in coord_meta['type']:
            set_default_positions(coord_meta['type'].index('latitude'), 'c')

    else:

        for index in range(len(coord_meta['type'])):
            if index in world_map:
                set_default_positions(index, frame_class.spine_names)

    return transform, coord_meta