예제 #1
0
    def _zscale(self, image, vmin, vmax, stretch, contrast=0.25,
                power=2, log_a=1000, asinh_a=0.1):
        """Build an ``ImageNormalize`` object that uses the ZScale interval.

        See :class:`astropy.visualization.ZScaleInterval`.

        :param image: the image whose values the ZScale limits are computed
            from; it is passed to ``ImageNormalize`` as ``data``.
        :type image: :mod:`astropy.io.fits` HDU or CCDData
            (NOTE(review): the object is handed straight to
            ``ImageNormalize(data=...)``; confirm an HDU itself, rather than
            its ``.data`` array, is accepted there)
        :param vmin: lower display limit forwarded to ``ImageNormalize``.
            Per the astropy documentation an explicit value overrides the
            limit computed by the interval; ``None`` keeps the ZScale value.
        :type vmin: float or None
        :param vmax: upper display limit, with the same semantics as ``vmin``.
        :type vmax: float or None
        :param stretch: name of the stretch to apply: ``'linear'``,
            ``'sqrt'``, ``'power'``, ``'log'`` or ``'asinh'``.
        :type stretch: str
        :param contrast: The scaling factor (between 0 and 1) for determining
            the minimum and maximum value. Larger values decrease the
            difference between the minimum and maximum values used for
            display. Defaults to 0.25.
        :type contrast: float
        :param power: exponent of ``PowerStretch``, used only when
            ``stretch='power'``. Defaults to 2.
        :type power: float
        :param log_a: ``a`` parameter of ``LogStretch``, used only when
            ``stretch='log'``. Defaults to 1000.
        :type log_a: float
        :param asinh_a: ``a`` parameter of ``AsinhStretch``, used only when
            ``stretch='asinh'``. Defaults to 0.1.
        :type asinh_a: float
        :returns: the normalization object
        :rtype: :class:`astropy.visualization.ImageNormalize`
        :raises ValueError: if ``stretch`` is not one of the names above
        """
        # Only the requested stretch is constructed, so a parameter meant for
        # another stretch (e.g. a bad ``log_a``) cannot affect this call.
        if stretch == 'linear':
            s = LinearStretch()
        elif stretch == 'sqrt':
            s = SqrtStretch()
        elif stretch == 'power':
            s = PowerStretch(power)
        elif stretch == 'log':
            s = LogStretch(log_a)
        elif stretch == 'asinh':
            s = AsinhStretch(asinh_a)
        else:
            raise ValueError(f'Unknown stretch: {stretch}.')

        # clip=False required or NaNs get max color value,
        # see https://github.com/astropy/astropy/issues/8165
        return ImageNormalize(data=image,
                              vmin=vmin,
                              vmax=vmax,
                              interval=ZScaleInterval(contrast=contrast),
                              stretch=s,
                              clip=False)
예제 #2
0
def test_invalid_power_log_a(a):
    # Every stretch that takes an ``a`` parameter must reject the supplied
    # value with the same message. ``a`` is presumably provided by a
    # parametrize decorator that is not visible here.
    expected_msg = 'a must be > 0'
    for stretch_cls in (PowerStretch, LogStretch, InvertedLogStretch):
        with pytest.raises(ValueError, match=expected_msg):
            stretch_cls(a=a)
예제 #3
0
def simple_norm(data,
                stretch='linear',
                power=1.0,
                asinh_a=0.1,
                log_a=1000,
                min_cut=None,
                max_cut=None,
                min_percent=None,
                max_percent=None,
                percent=None,
                clip=True):
    """Return an ``ImageNormalize`` built from a stretch name and interval options.

    The interval is selected by the first rule that applies:

    1. ``percent`` given -> ``PercentileInterval(percent)``.
    2. ``min_percent`` or ``max_percent`` given ->
       ``AsymmetricPercentileInterval``; a missing lower bound defaults to
       0 and a missing upper bound to 100.
    3. ``min_cut`` or ``max_cut`` given -> ``ManualInterval(min_cut, max_cut)``.
    4. Otherwise -> ``MinMaxInterval()``.

    Parameters
    ----------
    data : array-like
        Data from which the interval limits are computed.
    stretch : {'linear', 'sqrt', 'power', 'log', 'asinh'}, optional
        Name of the stretch to apply.
    power : float, optional
        Exponent of ``PowerStretch``; used only when ``stretch='power'``.
    asinh_a : float, optional
        ``a`` parameter of ``AsinhStretch``; used only when ``stretch='asinh'``.
    log_a : float, optional
        ``a`` parameter of ``LogStretch``; used only when ``stretch='log'``.
    min_cut, max_cut : float or None, optional
        Explicit data limits for ``ManualInterval``.
    min_percent, max_percent : float or None, optional
        Lower/upper percentiles for ``AsymmetricPercentileInterval``.
    percent : float or None, optional
        Symmetric percentile for ``PercentileInterval``; takes precedence
        over all other interval options.
    clip : bool, optional
        Forwarded to ``ImageNormalize``.

    Returns
    -------
    ImageNormalize
        Normalization with ``vmin``/``vmax`` fixed to the computed limits.

    Raises
    ------
    ValueError
        If ``stretch`` is not one of the supported names.
    """
    if percent is not None:
        interval = PercentileInterval(percent)
    elif min_percent is not None or max_percent is not None:
        # Explicit ``is None`` checks: ``max_percent or 100.`` would silently
        # replace a caller-supplied 0 with 100.
        lower = 0. if min_percent is None else min_percent
        upper = 100. if max_percent is None else max_percent
        interval = AsymmetricPercentileInterval(lower, upper)
    elif min_cut is not None or max_cut is not None:
        interval = ManualInterval(min_cut, max_cut)
    else:
        interval = MinMaxInterval()

    if stretch == 'linear':
        stretch = LinearStretch()
    elif stretch == 'sqrt':
        stretch = SqrtStretch()
    elif stretch == 'power':
        stretch = PowerStretch(power)
    elif stretch == 'log':
        stretch = LogStretch(log_a)
    elif stretch == 'asinh':
        stretch = AsinhStretch(asinh_a)
    else:
        raise ValueError('Unknown stretch: {0}.'.format(stretch))

    vmin, vmax = interval.get_limits(data)

    return ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch, clip=clip)
예제 #4
0
from astropy.utils.exceptions import AstropyDeprecationWarning
from astropy.visualization.mpl_normalize import ImageNormalize, simple_norm, imshow_norm
from astropy.visualization.interval import ManualInterval, PercentileInterval
from astropy.visualization.stretch import LogStretch, PowerStretch, SqrtStretch
from astropy.utils.compat.optional_deps import HAS_MATPLOTLIB, HAS_PLT  # noqa

if HAS_MATPLOTLIB:
    import matplotlib
    # Version flag; presumably consumed by version-dependent tests further
    # down the file (not visible here).
    MATPLOTLIB_LT_32 = Version(matplotlib.__version__) < Version('3.2')

# Shared inputs for the normalization tests.
DATA = np.linspace(0.0, 15.0, 6)
DATA2 = np.arange(3)
DATA2SCL = DATA2 * 0.5
DATA3 = np.linspace(-3.0, 3.0, 7)
STRETCHES = (
    SqrtStretch(),
    PowerStretch(0.5),
    LogStretch(),
)
INVALID = (None, -np.inf, -1)


@pytest.mark.skipif('HAS_MATPLOTLIB')
def test_normalize_error_message():
    # Without matplotlib, constructing ImageNormalize must fail with this
    # exact ImportError message.
    expected = "matplotlib is required in order to use this class."
    with pytest.raises(ImportError) as excinfo:
        ImageNormalize()
    assert excinfo.value.args[0] == expected


@pytest.mark.skipif('not HAS_MATPLOTLIB')
class TestNormalize:
    def test_invalid_interval(self):
        with pytest.raises(TypeError):
예제 #5
0
    LinearStretch, SqrtStretch, PowerStretch, PowerDistStretch,
    InvertedPowerDistStretch, SquaredStretch, LogStretch, InvertedLogStretch,
    AsinhStretch, SinhStretch, HistEqStretch, InvertedHistEqStretch,
    ContrastBiasStretch)

DATA = np.array([0.00, 0.25, 0.50, 0.75, 1.00])

# Expected output of each stretch applied to DATA with clip=False.
# Insertion order determines the order of the parametrized test cases.
RESULTS = {
    LinearStretch(): np.array([0.00, 0.25, 0.50, 0.75, 1.00]),
    LinearStretch(intercept=0.5) + LinearStretch(slope=0.5):
        np.array([0.5, 0.625, 0.75, 0.875, 1.]),
    SqrtStretch(): np.array([0., 0.5, 0.70710678, 0.8660254, 1.]),
    SquaredStretch(): np.array([0., 0.0625, 0.25, 0.5625, 1.]),
    PowerStretch(0.5): np.array([0., 0.5, 0.70710678, 0.8660254, 1.]),
    PowerDistStretch(): np.array([0., 0.004628, 0.030653, 0.177005, 1.]),
    LogStretch(): np.array([0., 0.799776, 0.899816, 0.958408, 1.]),
    AsinhStretch(): np.array([0., 0.549402, 0.77127, 0.904691, 1.]),
    SinhStretch(): np.array([0., 0.082085, 0.212548, 0.46828, 1.]),
    ContrastBiasStretch(contrast=2., bias=0.4):
        np.array([-0.3, 0.2, 0.7, 1.2, 1.7]),
    HistEqStretch(DATA): DATA,
    HistEqStretch(DATA[::-1]): DATA,
    HistEqStretch(DATA**0.5): np.array([0., 0.125, 0.25, 0.5674767, 1.]),
}


class TestStretch:
    @pytest.mark.parametrize('stretch', RESULTS.keys())
    def test_no_clip(self, stretch):
        # Unclipped output must match the tabulated expectation to 1e-6.
        actual = stretch(DATA, clip=False)
        expected = RESULTS[stretch]
        np.testing.assert_allclose(actual, expected, atol=1.e-6)