Beispiel #1
0
def main():
    """Run the wotan regression checks; the first mismatch raises AssertionError.

    The steps run strictly in sequence because the synthetic inputs are drawn
    from the global numpy random state seeded once near the top:

    1. ``t14`` is compared against two reference values.
    2. ``slide_clip`` is run on a synthetic light curve.
    3. A TESS light curve is read with ``load_file`` and ``flatten`` is run
       once per detrending method; the NaN-ignoring sum of the flattened flux
       is compared with a hard-coded reference.
    4. The ``hampel`` and ``ramsay`` methods are run on a fresh synthetic
       light curve.

    ``load_file`` is handed an https://archive.stsci.edu URL, so this
    presumably needs network access.
    """

    def synthetic_light_curve(points=1000):
        """Return ``(time, flux)``: a noisy sinusoid with injected dips/bumps."""
        time = numpy.linspace(0, 30, points)
        flux = 1 + numpy.sin(time) / points
        flux += numpy.random.normal(0, 0.0001, points)
        # Every 75th sample starts a 5-sample dip, followed 50 samples later
        # by a 2-sample bump (slices past the array end are simply empty).
        for start in range(0, points, 75):
            flux[start:start + 5] -= 0.0004  # Add some transits
            flux[start + 50:start + 52] += 0.0002  # and flares
        return time, flux

    def check_flatten(label, *args, expected, decimal=2, **kwargs):
        """Print *label*, run ``flatten`` and check the sum of its first output.

        ``args``/``kwargs`` are forwarded to ``flatten`` unchanged, with
        ``return_trend=True`` added, so exactly two arrays must come back.
        """
        print(label)
        flatten_lc, _trend_lc = flatten(*args, return_trend=True, **kwargs)
        numpy.testing.assert_almost_equal(
            numpy.nansum(flatten_lc), expected, decimal=decimal
        )

    print("Starting tests for wotan...")

    numpy.testing.assert_almost_equal(
        t14(R_s=1, M_s=1, P=365), 0.6490025258902046
    )
    numpy.testing.assert_almost_equal(
        t14(R_s=1, M_s=1, P=365, small_planet=True), 0.5403690143737738
    )
    print("Transit duration correct.")

    numpy.random.seed(seed=0)  # reproducibility

    print("Slide clipper...")
    time, flux = synthetic_light_curve()
    clipped = slide_clip(
        time, flux, window_length=0.5, low=3, high=2, method='mad', center='median'
    )
    numpy.testing.assert_almost_equal(numpy.nansum(clipped), 948.9926368754939)

    # TESS test
    print('Loading TESS data from archive.stsci.edu...')
    path = 'https://archive.stsci.edu/hlsps/tess-data-alerts/'
    filename = "hlsp_tess-data-alerts_tess_phot_00062483237-s01_tess_v1_lc.fits"
    time, flux = load_file(path + filename)

    window_length = 0.5

    # The first method is checked in more detail than the others: output
    # length, extrema and one sample of both the trend and the flattened flux.
    print("Detrending 1 (biweight)...")
    flatten_lc, trend_lc = flatten(
        time,
        flux,
        window_length,
        edge_cutoff=1,
        break_tolerance=0.1,
        return_trend=True,
        cval=5.0,
    )

    numpy.testing.assert_equal(len(trend_lc), 20076)
    numpy.testing.assert_almost_equal(numpy.nanmax(trend_lc), 28754.985299070882)
    numpy.testing.assert_almost_equal(numpy.nanmin(trend_lc), 28615.108124724477)
    numpy.testing.assert_almost_equal(trend_lc[500], 28671.686308143515)

    numpy.testing.assert_equal(len(flatten_lc), 20076)
    numpy.testing.assert_almost_equal(numpy.nanmax(flatten_lc), 1.0034653549250616)
    numpy.testing.assert_almost_equal(numpy.nanmin(flatten_lc), 0.996726610702177)
    numpy.testing.assert_almost_equal(flatten_lc[500], 1.000577429565131)

    check_flatten(
        "Detrending 2 (andrewsinewave)...",
        time, flux, window_length,
        method='andrewsinewave',
        expected=18119.15471987987,
    )
    check_flatten(
        "Detrending 3 (welsch)...",
        time, flux, window_length,
        method='welsch',
        expected=18119.16764691235,
    )
    check_flatten(
        "Detrending 4 (hodges)...",
        time[:1000], flux[:1000], window_length,
        method='hodges',
        expected=994.0110525909206,
    )
    check_flatten(
        "Detrending 5 (median)...",
        time, flux, window_length,
        method='median',
        expected=18119.122065014355,
    )
    check_flatten(
        "Detrending 6 (mean)...",
        time, flux, window_length,
        method='mean',
        expected=18119.032473037714,
    )
    check_flatten(
        "Detrending 7 (trim_mean)...",
        time, flux, window_length,
        method='trim_mean',
        expected=18119.095164910334,
    )
    check_flatten(
        "Detrending 8 (supersmoother)...",
        time, flux, window_length,
        method='supersmoother',
        expected=18123.00632204841,
    )
    check_flatten(
        "Detrending 9 (hspline)...",
        time, flux, window_length,
        method='hspline',
        expected=18123.07625225313,
    )
    check_flatten(
        "Detrending 10 (cofiam)...",
        time[:2000], flux[:2000], window_length,
        method='cofiam',
        expected=1948.9999999987976,
    )
    check_flatten(
        "Detrending 11 (savgol)...",
        time, flux,
        window_length=301,
        method='savgol',
        expected=18123.003465539354,
    )
    check_flatten(
        "Detrending 12 (medfilt)...",
        time, flux,
        window_length=301,
        method='medfilt',
        expected=18123.22609806557,
    )
    # NOTE: the step number 12 is reused by the next label; the printed
    # labels are kept verbatim.
    check_flatten(
        "Detrending 12 (gp squared_exp)...",
        time[:2000], flux[:2000],
        method='gp',
        kernel='squared_exp',
        kernel_size=10,
        expected=1948.99958552324,
    )
    check_flatten(
        "Detrending 13 (gp squared_exp robust)...",
        time[:2000], flux[:2000],
        method='gp',
        kernel='squared_exp',
        kernel_size=10,
        robust=True,
        expected=1948.8820772313468,
    )
    check_flatten(
        "Detrending 14 (gp matern)...",
        time[:2000], flux[:2000],
        method='gp',
        kernel='matern',
        kernel_size=10,
        expected=1949.0001583058202,
    )
    check_flatten(
        "Detrending 15 (gp periodic)...",
        time[:2000], flux[:2000],
        method='gp',
        kernel='periodic',
        kernel_size=1,
        kernel_period=10,
        expected=1948.9999708985608,
    )

    # Small synthetic sinusoid for the automatic period search.  The flux is
    # computed from the unscaled time axis; the time axis is stretched after.
    time_synth = numpy.linspace(0, 30, 200)
    flux_synth = numpy.sin(time_synth) + numpy.random.normal(0, 0.1, 200)
    flux_synth = 1 + flux_synth / 100
    time_synth *= 1.5
    check_flatten(
        "Detrending 16 (gp periodic_auto)...",
        time_synth, flux_synth,
        method='gp',
        kernel='periodic_auto',
        kernel_size=1,
        expected=200,
        decimal=1,
    )

    check_flatten(
        "Detrending 17 (rspline)...",
        time, flux,
        method='rspline',
        window_length=1,
        expected=18121.812790732245,
    )
    check_flatten(
        "Detrending 18 (huber)...",
        time[:1000], flux[:1000],
        method='huber',
        window_length=0.5,
        edge_cutoff=0,
        break_tolerance=0.4,
        expected=994.01102,
    )
    check_flatten(
        "Detrending 19 (winsorize)...",
        time, flux,
        method='winsorize',
        window_length=0.5,
        edge_cutoff=0,
        break_tolerance=0.4,
        proportiontocut=0.1,
        expected=18119.064587196448,
    )
    check_flatten(
        "Detrending 20 (pspline)...",
        time, flux,
        method='pspline',
        expected=18121.832133916843,
    )
    check_flatten(
        "Detrending 21 (hampelfilt)...",
        time, flux,
        method='hampelfilt',
        window_length=0.5,
        cval=3,
        expected=18119.158072498867,
    )
    check_flatten(
        "Detrending 22 (lowess)...",
        time, flux,
        method='lowess',
        window_length=1,
        expected=18123.08085676265,
    )
    check_flatten(
        "Detrending 23 (huber_psi)...",
        time, flux,
        method='huber_psi',
        window_length=0.5,
        expected=18119.122065014355,
    )
    check_flatten(
        "Detrending 24 (tau)...",
        time, flux,
        method='tau',
        window_length=0.5,
        expected=18119.02772621119,
    )

    # Fresh synthetic light curve (continues the seeded random stream).
    time, flux = synthetic_light_curve()

    # No reference value exists for this configuration, so it only checks
    # that the call completes and returns two outputs.
    print("Detrending 25a (hampel 17A)...")
    _flatten_lc, _trend_lc = flatten(
        time,
        flux,
        method='hampel',
        cval=(1.7, 3.4, 8.5),
        window_length=0.5,
        return_trend=True,
    )

    check_flatten(
        "Detrending 25b (hampel 25A)...",
        time, flux,
        method='hampel',
        cval=(2.5, 4.5, 9.5),
        window_length=0.5,
        expected=997.9994362858843,
    )
    check_flatten(
        "Detrending 26 (ramsay)...",
        time, flux,
        method='ramsay',
        cval=0.3,
        window_length=0.5,
        expected=997.9974021484584,
    )

    print('All tests completed.')
Beispiel #2
0
def main():
    """Run the wotan regression checks; the first mismatch raises AssertionError.

    The steps run strictly in sequence because the synthetic inputs are drawn
    from the global numpy random state seeded once near the top:

    1. ``t14`` is compared against two reference values.
    2. ``slide_clip`` is run on a synthetic light curve.
    3. A TESS light curve is read with ``load_file`` and ``flatten`` is run
       once per detrending method; the NaN-ignoring sum of the flattened flux
       is compared with a hard-coded reference.
    4. Further methods are run on a fresh synthetic light curve.
    5. A second TESS light curve is read and ``flatten`` is run with a
       ``transit_mask`` and with the extra ``pspline`` options.

    ``load_file`` is handed https://archive.stsci.edu URLs, so this presumably
    needs network access.
    """

    def synthetic_light_curve(points=1000):
        """Return ``(time, flux)``: a noisy sinusoid with injected dips/bumps."""
        time = numpy.linspace(0, 30, points)
        flux = 1 + numpy.sin(time) / points
        flux += numpy.random.normal(0, 0.0001, points)
        # Every 75th sample starts a 5-sample dip, followed 50 samples later
        # by a 2-sample bump (slices past the array end are simply empty).
        for start in range(0, points, 75):
            flux[start : start + 5] -= 0.0004  # Add some transits
            flux[start + 50 : start + 52] += 0.0002  # and flares
        return time, flux

    def check_flatten(label, *args, expected, decimal=2, **kwargs):
        """Print *label*, run ``flatten`` and check the sum of its first output.

        ``args``/``kwargs`` are forwarded to ``flatten`` unchanged, with
        ``return_trend=True`` added, so exactly two arrays must come back.
        """
        print(label)
        flatten_lc, _trend_lc = flatten(*args, return_trend=True, **kwargs)
        numpy.testing.assert_almost_equal(
            numpy.nansum(flatten_lc), expected, decimal=decimal
        )

    print("Starting tests for wotan...")

    numpy.testing.assert_almost_equal(t14(R_s=1, M_s=1, P=365), 0.6490025258902046)
    numpy.testing.assert_almost_equal(
        t14(R_s=1, M_s=1, P=365, small_planet=True), 0.5403690143737738
    )
    print("Transit duration correct.")

    numpy.random.seed(seed=0)  # reproducibility

    print("Slide clipper...")
    time, flux = synthetic_light_curve()
    clipped = slide_clip(
        time, flux, window_length=0.5, low=3, high=2, method="mad", center="median"
    )
    numpy.testing.assert_almost_equal(numpy.nansum(clipped), 948.9926368754939)

    # TESS test
    print("Loading TESS data from archive.stsci.edu...")
    path = "https://archive.stsci.edu/hlsps/tess-data-alerts/"
    filename = "hlsp_tess-data-alerts_tess_phot_00062483237-s01_tess_v1_lc.fits"
    time, flux = load_file(path + filename)

    window_length = 0.5

    # The first method is checked in more detail than the others: output
    # length, extrema and one sample of both the trend and the flattened flux.
    print("Detrending 1 (biweight)...")
    flatten_lc, trend_lc = flatten(
        time,
        flux,
        window_length,
        edge_cutoff=1,
        break_tolerance=0.1,
        return_trend=True,
        cval=5.0,
    )

    numpy.testing.assert_equal(len(trend_lc), 20076)
    numpy.testing.assert_almost_equal(
        numpy.nanmax(trend_lc), 28755.03811866676, decimal=2
    )
    numpy.testing.assert_almost_equal(
        numpy.nanmin(trend_lc), 28615.110229935075, decimal=2
    )
    numpy.testing.assert_almost_equal(trend_lc[500], 28671.650565730513, decimal=2)

    numpy.testing.assert_equal(len(flatten_lc), 20076)
    numpy.testing.assert_almost_equal(
        numpy.nanmax(flatten_lc), 1.0034653549250616, decimal=2
    )
    numpy.testing.assert_almost_equal(
        numpy.nanmin(flatten_lc), 0.996726610702177, decimal=2
    )
    numpy.testing.assert_almost_equal(flatten_lc[500], 1.000577429565131, decimal=2)

    check_flatten(
        "Detrending 2 (andrewsinewave)...",
        time, flux, window_length,
        method="andrewsinewave",
        expected=18123.15456308313,
    )
    check_flatten(
        "Detrending 3 (welsch)...",
        time, flux, window_length,
        method="welsch",
        expected=18123.16770590837,
    )
    check_flatten(
        "Detrending 4 (hodges)...",
        time[:1000], flux[:1000], window_length,
        method="hodges",
        expected=996.0113241694287,
    )
    check_flatten(
        "Detrending 5 (median)...",
        time, flux, window_length,
        method="median",
        expected=18123.12166552401,
    )
    check_flatten(
        "Detrending 6 (mean)...",
        time, flux, window_length,
        method="mean",
        expected=18123.032058753546,
    )
    check_flatten(
        "Detrending 7 (trim_mean)...",
        time, flux, window_length,
        method="trim_mean",
        expected=18123.094751124332,
    )
    check_flatten(
        "Detrending 8 (supersmoother)...",
        time, flux, window_length,
        method="supersmoother",
        expected=18123.00632204841,
    )
    check_flatten(
        "Detrending 9 (hspline)...",
        time, flux, window_length,
        method="hspline",
        expected=18123.082601463717,
        decimal=1,
    )
    check_flatten(
        "Detrending 10 (cofiam)...",
        time[:2000], flux[:2000], window_length,
        method="cofiam",
        expected=1948.9999999987976,
        decimal=1,
    )
    check_flatten(
        "Detrending 11 (savgol)...",
        time, flux,
        window_length=301,
        method="savgol",
        expected=18123.003465539354,
        decimal=1,
    )
    check_flatten(
        "Detrending 12 (medfilt)...",
        time, flux,
        window_length=301,
        method="medfilt",
        expected=18123.22609806557,
        decimal=1,
    )
    # NOTE: the step number 12 is reused by the next label; the printed
    # labels are kept verbatim.
    check_flatten(
        "Detrending 12 (gp squared_exp)...",
        time[:2000], flux[:2000],
        method="gp",
        kernel="squared_exp",
        kernel_size=10,
        expected=1948.9672036416687,
    )
    check_flatten(
        "Detrending 13 (gp squared_exp robust)...",
        time[:2000], flux[:2000],
        method="gp",
        kernel="squared_exp",
        kernel_size=10,
        robust=True,
        expected=1948.8820772313468,
    )
    check_flatten(
        "Detrending 14 (gp matern)...",
        time[:2000], flux[:2000],
        method="gp",
        kernel="matern",
        kernel_size=10,
        expected=1948.9672464898367,
    )
    check_flatten(
        "Detrending 15 (gp periodic)...",
        time[:2000], flux[:2000],
        method="gp",
        kernel="periodic",
        kernel_size=1,
        kernel_period=10,
        expected=1948.9999708985608,
    )

    # Small synthetic sinusoid for the automatic period search.  The flux is
    # computed from the unscaled time axis; the time axis is stretched after.
    time_synth = numpy.linspace(0, 30, 200)
    flux_synth = numpy.sin(time_synth) + numpy.random.normal(0, 0.1, 200)
    flux_synth = 1 + flux_synth / 100
    time_synth *= 1.5
    check_flatten(
        "Detrending 16 (gp periodic_auto)...",
        time_synth, flux_synth,
        method="gp",
        kernel="periodic_auto",
        kernel_size=1,
        expected=200,
        decimal=1,
    )

    check_flatten(
        "Detrending 17 (rspline)...",
        time, flux,
        method="rspline",
        window_length=1,
        expected=18121.812790732245,
    )

    # NOTE: "Detrending 18 (huber)..." is disabled in this version.  It ran
    # flatten on time[:1000], flux[:1000] with method='huber',
    # window_length=0.5, edge_cutoff=0, break_tolerance=0.4 and expected
    # nansum(flatten_lc) == 996.0112964009066 to 2 decimals.

    check_flatten(
        "Detrending 19 (winsorize)...",
        time, flux,
        method="winsorize",
        window_length=0.5,
        edge_cutoff=0,
        break_tolerance=0.4,
        proportiontocut=0.1,
        expected=18123.064149766662,
    )
    check_flatten(
        "Detrending 20 (pspline)...",
        time, flux,
        method="pspline",
        expected=18122.740535799767,
    )
    check_flatten(
        "Detrending 21 (hampelfilt)...",
        time, flux,
        method="hampelfilt",
        window_length=0.5,
        cval=3,
        expected=18123.157973016467,
    )
    check_flatten(
        "Detrending 22 (lowess)...",
        time, flux,
        method="lowess",
        window_length=1,
        expected=18123.039744125545,
    )
    check_flatten(
        "Detrending 23 (huber_psi)...",
        time, flux,
        method="huber_psi",
        window_length=0.5,
        expected=18123.110893063527,
    )
    check_flatten(
        "Detrending 24 (tau)...",
        time, flux,
        method="tau",
        window_length=0.5,
        expected=18123.026005725977,
    )
    check_flatten(
        "Detrending 25 (cosine)...",
        time, flux,
        method="cosine",
        window_length=0.5,
        expected=18122.999999974905,
    )
    # NOTE: the step number 25 is reused by the next label; kept verbatim.
    check_flatten(
        "Detrending 25 (cosine robust)...",
        time, flux,
        method="cosine",
        robust=True,
        window_length=0.5,
        expected=18122.227938535038,
    )

    # Fresh synthetic light curve (continues the seeded random stream).
    time, flux = synthetic_light_curve()

    # No reference value exists for this configuration, so it only checks
    # that the call completes and returns two outputs.
    print("Detrending 26 (hampel 17A)...")
    _flatten_lc, _trend_lc = flatten(
        time,
        flux,
        method="hampel",
        cval=(1.7, 3.4, 8.5),
        window_length=0.5,
        return_trend=True,
    )

    check_flatten(
        "Detrending 27 (hampel 25A)...",
        time, flux,
        method="hampel",
        cval=(2.5, 4.5, 9.5),
        window_length=0.5,
        expected=999.9992212031945,
    )
    check_flatten(
        "Detrending 28 (ramsay)...",
        time, flux,
        method="ramsay",
        cval=0.3,
        window_length=0.5,
        expected=999.9970566765148,
    )
    check_flatten(
        "Detrending 29 (ridge)...",
        time, flux,
        window_length=0.5,
        method="ridge",
        cval=1,
        expected=999.9999958887022,
        decimal=1,
    )
    check_flatten(
        "Detrending 30 (lasso)...",
        time, flux,
        window_length=0.5,
        method="lasso",
        cval=1,
        expected=999.9999894829843,
        decimal=1,
    )
    check_flatten(
        "Detrending 31 (elasticnet)...",
        time, flux,
        window_length=0.5,
        method="elasticnet",
        cval=1,
        expected=999.9999945063338,
        decimal=1,
    )

    # Test of transit mask
    print("Testing transit_mask")
    filename = "hlsp_tess-data-alerts_tess_phot_00207081058-s01_tess_v1_lc.fits"
    time, flux = load_file(path + filename)

    from wotan import transit_mask

    mask = transit_mask(time=time, period=14.77338, duration=0.21060, T0=1336.141095)
    numpy.testing.assert_almost_equal(numpy.sum(mask), 302, decimal=1)

    check_flatten(
        "Detrending 32 (transit_mask cosine)",
        time, flux,
        method="cosine",
        window_length=0.4,
        robust=True,
        mask=mask,
        expected=18119.281265446625,
        decimal=1,
    )
    check_flatten(
        "Detrending 33 (transit_mask lowess)",
        time, flux,
        method="lowess",
        window_length=0.8,
        robust=True,
        mask=mask,
        expected=18119.30865711536,
        decimal=1,
    )

    # The mask passed alongside the first 2000 samples is rebuilt from that
    # same 2000-sample time slice.
    print("Detrending 34 (transit_mask GP)")
    mask = transit_mask(time=time[:2000], period=100, duration=0.3, T0=1327.4)
    flatten_lc2, _trend_lc2 = flatten(
        time[:2000],
        flux[:2000],
        method="gp",
        kernel="matern",
        kernel_size=0.8,
        return_trend=True,
        robust=True,
        mask=mask,
    )
    numpy.testing.assert_almost_equal(
        numpy.nansum(flatten_lc2), 1948.9000170463796, decimal=1
    )

    # With return_nsplines=True a third value is returned, so these two
    # steps unpack three outputs explicitly instead of using check_flatten.
    print("Detrending 35 (pspline full features)")
    flatten_lc, _trend_lc, _nsplines = flatten(
        time,
        flux,
        method="pspline",
        max_splines=100,
        edge_cutoff=0.5,
        return_trend=True,
        return_nsplines=True,
        verbose=True,
    )
    numpy.testing.assert_almost_equal(
        numpy.nansum(flatten_lc), 16678.312693036027, decimal=1
    )

    print("Detrending 36 (pspline variable PSPLINES_STDEV_CUT)")
    flatten_lc, _trend_lc, _nsplines = flatten(
        time,
        flux,
        method="pspline",
        max_splines=100,
        edge_cutoff=0.5,
        stdev_cut=3,
        return_trend=True,
        return_nsplines=True,
        verbose=True,
    )
    numpy.testing.assert_almost_equal(
        numpy.nansum(flatten_lc), 16678.292210380347, decimal=2
    )

    print("All tests completed.")
def _overplot_transit_times(ax, epoch, period):
    """
    Draw dashed vertical lines on `ax` at `epoch + k*period` for integer k in
    [-1000, 1000), then restore the x limits and the (min, max) y range that
    were in place before the lines were added.
    """
    tra_times = epoch + np.arange(-1000,1000,1)*period

    # Record the view before drawing so that it can be restored afterwards.
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    ax.vlines(tra_times, min(ylim), max(ylim), color='orangered',
              linestyle='--', zorder=-2, lw=0.5, alpha=0.3)
    ax.set_ylim((min(ylim), max(ylim)))
    ax.set_xlim(xlim)


def explore_flux_lightcurves(
    data, ticid, outdir=None, period=None, epoch=None, pipeline=None,
    detrend=False, window_length=None, do_phasefold=0, badtimewindows=None,
    get_lc=False, require_quality_zero=1, forceylim=None, normstitch=True,
    slideclipdict={'window_length':1, 'high':3, 'low':8},
    mask_orbit_edges=False
):
    """
    Given a list of light curve FITS tables (one per sector / quarter), make a
    diagnostic plot for each, stitch them together, and plot and save the
    stitched light curve.  Optionally also make phase-folded plots.

    Files written to `outdir`: one PNG per entry of `data`, one "allsector"
    PNG of the stitched and sigma-clipped light curve, a CSV with the
    sigma-clipped time and flux, and, if `do_phasefold`, two phase-folded PNGs
    ("xwide" and "xnarrow") plus a CSV of the stitched (unclipped) time and
    flux.

    Args:

        data (list): from `get_tess_data`, contents [hdulistA[1].data,
            hdulistB[1].data], etc..

        ticid (str): TIC ID.

        pipeline (str): one of ['spoc', 'kepler', 'qlp', 'cdips'].  This is
        used to look up the flux, time, and quality column names, the time
        offset, and the provenance and instrument strings in LCKEYDICT.  Any
        other value raises NotImplementedError.

        outdir (str): diagnostic plots are written here. If None, goes to
        RESULTSDIR/quicklooklc/TIC{ticid}.

    Optional kwargs:

        period, epoch (float): optional.  If `epoch` is given, `period` must
        be given too; the implied transit times are then marked on the plots.
        A float `epoch` below 2450000 is rejected as not being in BJD_TDB.

        detrend (bool, or string): one of 'pspline', 'biweight', 'minimal',
        'median', or 'best'.  Fixed parameters are used for each.  Any other
        truthy value raises NotImplementedError.

        window_length: accepted but not used by this function.

        do_phasefold (bool): if truthy, also make phase-folded plots.  This
        requires `period` and `epoch` to be int or float.

        badtimewindows (list): to manually mask out, [(1656, 1658), (1662,
            1663)], for instance.  The windows are compared against the raw
            time column, i.e., before the pipeline time offset is added.

        get_lc (bool): if True, returns time and flux arrays.

        require_quality_zero (bool): if True, keeps only cadences whose
        quality flag equals 0 ('G' for the cdips pipeline), throwing out lots
        of data.  Cadences with non-positive flux are always dropped.

        forceylim (list or tuple): if given, y limits for the per-sector
        plots.

        normstitch (bool): normalize flux across sectors s.t. the relative
        amplitude remains fixed.

        slideclipdict (dict): e.g., {'window_length':1, 'high':3, 'low':8} for
        1 day sliding window, exclude +3MAD from median above and -8MAD from
        median below.  Anything that is not a dict disables the sliding clip.
        The default dict is only read, never modified.

        mask_orbit_edges (bool): if True, pass the light curve through
        `moe.mask_orbit_start_and_end` before clipping.

    Returns:

        (times, fluxs) if `get_lc` is True: the stitched time and flux arrays,
        without the final sigma clip applied.  Otherwise None.

    Raises:

        NotImplementedError: unsupported `pipeline`, `detrend`, or instrument.

        ValueError: `epoch` is a float below 2450000, or `epoch` is given
        without `period`.
    """

    assert isinstance(data, list), 'Expected list of FITStables.'

    if pipeline not in ['spoc', 'kepler', 'qlp', 'cdips']:
        raise NotImplementedError

    if isinstance(epoch, float):
        if epoch < 2450000:
            raise ValueError(f'Expected epoch in BJDTDB. Got epoch={epoch:.6f}.')

    # The transit times are epoch + k*period; without a period this would only
    # fail later, inside the plotting code, with an opaque TypeError.
    if epoch is not None and period is None:
        raise ValueError('A period is required whenever an epoch is given.')

    ykey = LCKEYDICT[pipeline]['flux']
    xkey = LCKEYDICT[pipeline]['time']
    qualkey = LCKEYDICT[pipeline]['quality']
    prov = LCKEYDICT[pipeline]['prov']
    inst = LCKEYDICT[pipeline]['inst']

    if outdir is None:
        outdir = os.path.join(RESULTSDIR, 'quicklooklc', f'TIC{ticid}')

    # Keyword arguments handed to dtr.detrend_flux for the fixed-parameter
    # detrending options ('best' is handled separately below).
    detrend_kwargs = {
        # default pspline detrending
        'pspline': {},
        # in some cases, might prefer the biweight
        'biweight': {'method': 'biweight', 'cval': 5, 'window_length': 0.5,
                     'break_tolerance': 0.5},
        'minimal': {'method': 'biweight', 'cval': 2, 'window_length': 3.5,
                    'break_tolerance': 0.5},
        'median': {'method': 'median', 'window_length': 0.6,
                   'break_tolerance': 0.5, 'edge_cutoff': 0.},
    }

    # File names carry the detrending method only when detrending is on.
    dtrstr = f'{detrend}_' if detrend else ''

    times, fluxs = [], []
    for ix, d in enumerate(data):

        savpath = os.path.join(
            outdir,
            f'TIC{ticid}_{prov}_{inst}_lightcurve_{dtrstr}{str(ix).zfill(2)}.png'
        )

        plt.close('all')
        f,ax = plt.subplots(figsize=(16*2,4*1.5))

        if require_quality_zero:
            okkey = 0 if pipeline in ('spoc', 'kepler', 'qlp') else 'G'
            sel = (d[qualkey] == okkey) & (d[ykey] > 0)
            print(42*'.')
            print('WRN!: omitting all non-zero quality flags. throws out good data!')
            print(42*'.')
        else:
            sel = (d[ykey] > 0)
        if badtimewindows is not None:
            for w in badtimewindows:
                sel &= ~(
                    (d[xkey] > w[0])
                    &
                    (d[xkey] < w[1])
                )

        # correct time column to BJD_TDB
        x_offset = LCKEYDICT[pipeline]['time_offset']
        x_obs = d[xkey][sel] + x_offset

        # get the median-normalized flux
        y_obs = d[ykey][sel]
        if pipeline == 'cdips':
            # presumably the cdips column holds magnitudes, with 1e-3 of the
            # value passed as a nominal uncertainty -- verify against
            # _given_mag_get_flux
            y_obs, _ = _given_mag_get_flux(y_obs, y_obs*1e-3)
        y_obs /= np.nanmedian(y_obs)

        if mask_orbit_edges:
            x_obs, y_obs, _ = moe.mask_orbit_start_and_end(
                x_obs, y_obs, raise_expectation_error=False, orbitgap=0.7,
                orbitpadding=12/(24),
                return_inds=True
            )

        # slide clip -- remove outliers with windowed stdevn removal
        if isinstance(slideclipdict, dict):
            y_obs = slide_clip(x_obs, y_obs, slideclipdict['window_length'],
                               low=slideclipdict['low'],
                               high=slideclipdict['high'], method='mad',
                               center='median')

        if detrend:
            # show the data as they were before detrending, under the trend
            ax.scatter(x_obs, y_obs, c='k', s=4, zorder=2)

            if detrend == 'best':
                from cdips.lcproc.find_planets import run_periodograms_and_detrend
                dtr_dict = {'method':'best', 'break_tolerance':0.5, 'window_length':0.5}
                lsp_options = {'period_min':0.1, 'period_max':20}

                # r = [source_id, ls_period, ls_fap, ls_amplitude, tls_period, tls_sde,
                #      tls_t0, tls_depth, tls_duration, tls_distinct_transit_count,
                #      tls_odd_even, dtr_method]

                r, search_time, search_flux, dtr_stages_dict = run_periodograms_and_detrend(
                    ticid, x_obs, y_obs, dtr_dict,
                    period_min=lsp_options['period_min'],
                    period_max=lsp_options['period_max'], dtr_method='best',
                    return_extras=True,
                    magisflux=True
                )
                y_trend, x_trend, dtr_method = (
                    dtr_stages_dict['trend_flux'],
                    dtr_stages_dict['trend_time'],
                    dtr_stages_dict['dtr_method_used']
                )
                x_obs, y_obs = deepcopy(search_time), deepcopy(search_flux)
                print(f'TIC{ticid} TLS results')
                print(f'dtr_method_used: {dtr_method}')
                print(r)

            elif detrend in detrend_kwargs:
                y_obs, y_trend = dtr.detrend_flux(
                    x_obs, y_obs, **detrend_kwargs[detrend]
                )
                x_trend = deepcopy(x_obs)

            else:
                raise NotImplementedError

            ax.plot(x_trend, y_trend, c='r', lw=0.5, zorder=3)
        else:
            ax.scatter(x_obs, y_obs, c='k', s=4, zorder=2)

        times.append( x_obs )
        fluxs.append( y_obs )

        ax.set_xlabel('time [bjdtdb]')
        ax.set_ylabel(ykey)

        ax.set_title(ix)

        if detrend:
            _ylim = _get_ylim(y_trend)
        else:
            _ylim = _get_ylim(y_obs)

        ax.set_ylim(_ylim)
        if isinstance(forceylim, (list, tuple)):
            ax.set_ylim(forceylim)

        if epoch is not None:
            _overplot_transit_times(ax, epoch, period)

        f.savefig(savpath, dpi=300, bbox_inches='tight')
        print('made {}'.format(savpath))

    if normstitch:
        times, fluxs, _ = lcu.stitch_light_curves(
            times, fluxs, fluxs, magsarefluxes=True, normstitch=True
        )
    else:
        # Sectors generally have different numbers of cadences, so concatenate
        # the list directly rather than going through a (ragged) 2D array.
        times = np.hstack(times)
        fluxs = np.hstack(fluxs)

    # NOTE: this call is deprecated
    stimes, smags, _ = lcmath.sigclip_magseries(
        times, fluxs, np.ones_like(fluxs), sigclip=[20,20], iterative=True,
        magsarefluxes=True
    )

    savpath = os.path.join(
        outdir,
        f'TIC{ticid}_{prov}_{inst}_lightcurve_{dtrstr}{str(ykey).zfill(2)}_allsector.png'
    )

    plt.close('all')
    f,ax = plt.subplots(figsize=(16,4))

    ax.scatter(stimes, smags, c='k', s=1)

    if epoch is not None:
        _overplot_transit_times(ax, epoch, period)

    ax.set_xlabel('time [bjdtdb]')
    ax.set_ylabel('relative '+ykey)

    # `ix` is left over from the loop above: the title is the last index.
    ax.set_title(ix)

    f.savefig(savpath, dpi=400, bbox_inches='tight')
    print('made {}'.format(savpath))

    csvpath = savpath.replace('.png','_sigclipped.csv')
    pd.DataFrame({
        'time': stimes, 'flux': smags,
    }).to_csv(csvpath, index=False)
    print(f'made {csvpath}')


    if do_phasefold:

        assert (
            isinstance(period, (float,int)) and isinstance(epoch, (float,int))
        )

        #
        # ax: primary transit
        #
        if inst == 'kepler':
            phasebin = 1e-3
        elif inst == 'tess':
            phasebin = 5e-3
        else:
            raise NotImplementedError(
                f'No phase bin size is defined for inst={inst}.'
            )
        minbinelems = 2
        plotxlims = [(-0.5, 0.5), (-0.05,0.05)]
        xlimstrs = ['xwide','xnarrow']
        plotylim = [0.9, 1.08]#None #[0.9,1.1]
        do_vlines = False

        for plotxlim, xstr in zip(plotxlims, xlimstrs):

            plt.close('all')
            fig, ax = plt.subplots(figsize=(4,3))

            # use times and fluxs, instead of the sigma clipped thing.
            _make_phased_magseries_plot(ax, 0, times, fluxs,
                                        np.ones_like(fluxs)/1e4, period, epoch,
                                        True, True, phasebin, minbinelems,
                                        plotxlim, '', xliminsetmode=False,
                                        magsarefluxes=True, phasems=0.8,
                                        phasebinms=4.0, verbose=True)
            if isinstance(plotylim, (list, tuple)):
                ax.set_ylim(plotylim)
            else:
                plotylim = _get_ylim(fluxs)
                ax.set_ylim(plotylim)

            if do_vlines:
                ax.vlines(1/6, min(plotylim), max(plotylim), color='orangered',
                          linestyle='--', zorder=-2, lw=1, alpha=0.8)
                ax.set_ylim(plotylim)

            dstr = detrend if detrend else ''
            savpath = os.path.join(
                outdir, f'TIC{ticid}_{prov}_{inst}_lightcurve_{dstr}_{ykey}_{xstr}_allsector_phasefold.png'
            )

            fig.savefig(savpath, dpi=400, bbox_inches='tight')
            print(f'made {savpath}')

        # Named after the last ("xnarrow") plot.  Only the extension is
        # swapped, so a "png" elsewhere in the path is left alone.
        csvpath = os.path.splitext(savpath)[0] + '.csv'
        pd.DataFrame({
            'time': times, 'flux': fluxs
        }).to_csv(csvpath, index=False)
        print(f'made {csvpath}')

    if get_lc:
        return times, fluxs
Beispiel #4
0
def tls_search(time,
               flux,
               flux_err,
               known_transits=None,
               tls_kwargs=None,
               wotan_kwargs=None,
               options=None):
    '''
    Summary:
    -------
    This runs TLS on these data with the given infos.  After each significant
    detection the detected transits are masked out and TLS is run again, until
    a run fails one of the SNR / SDE / FAP thresholds.

    Side effects: the global NumPy random seed is set to 42; 'logfile.log',
    one 'tls_signal_<i>.txt' per detection and (when detrending)
    'flux_<method>.csv' are written to options['outdir'], plus PDF plots if
    options['save_plot'] is set.  Missing defaults are filled into the
    tls_kwargs, wotan_kwargs and options dicts that were passed in.

    Inputs:
    -------
    time : array of float
        time stamps of observations
    flux : array of float
        normalized flux
    flux_err : array of float or None
        error of normalized flux
        samples with a NaN in any of the given arrays are dropped


    Optional Inputs:
    ----------------
    tls_kwargs : None or dict, keywords:
        R_star : float
            radius of the star (e.g. median)
            default 1 R_sun (from TLS)
        R_star_min : float
            minimum radius of the star (e.g. 1st percentile)
            default 0.13 R_sun (from TLS)
        R_star_max : float
            maximum radius of the star (e.g. 99th percentile)
            default 3.5 R_sun (from TLS)
        M_star : float
            mass of the star (e.g. median)
            default 1. M_sun (from TLS)
        M_star_min : float
            minimum mass of the star (e.g. 1st percentile)
            default 0.1 M_sun (from TLS)
        M_star_max : float
            maximum mass of the star (e.g. 99th percentile)
            default 1. M_sun (from TLS)
        u : list
            quadratic limb darkening parameters
            default [0.4804, 0.1867]
        show_progress_bar : bool
            default is False
        SNR_threshold : float
            the SNR threshold at which to stop the TLS search
            default is 5.
        SDE_threshold : float
            the SDE threshold at which to stop the TLS search
            default is 5.
        FAP_threshold : float
            the FAP threshold above which to stop the TLS search
            default is 0.05
        ...
        The three thresholds are used by this function only; every other
        keyword is passed on to the TLS power() call.

    wotan_kwargs : None or dict, keywords:
        if None, the flux is not detrended
        slide_clip : dict
            keywords passed to slide_clip
            defaults are window_length=1., low=20., high=3.
        flatten : dict
            keywords passed to flatten
            defaults are method='biweight', window_length=1.

    known_transits : None or dict
        if dict and one transit is already known:
            known_transits = {'period':[1.3], 'duration':[2.1], 'epoch':[245800.0]}
        if dict and multiple transits are already known:
            known_transits = {'name':['b','c'], 'period':[1.3, 21.0], 'duration':[2.1, 4.1], 'epoch':[245800.0, 245801.0]}
        'period' is the period of the transit
        'duration' must be the total duration, i.e. from first ingress point to last egress point, in days
        'epoch' is the epoch of the transit

    options : None or dict, keywords:
        show_plot : bool
            show a plot of each phase-folded transit candidate and TLS model in the terminal
            default is False
        save_plot : bool
            save a plot of each phase-folded transit candidate and TLS model into outdir
            default is False
        outdir : string
            if None or "", use the current working directory
            default is ""

    Returns:
    -------
    List of all TLS results that passed the thresholds
    '''

    #::: seed (this resets the global NumPy random state)
    np.random.seed(42)

    #::: handle inputs
    if flux_err is None:
        ind = np.where(~np.isnan(time * flux))[0]
        time = time[ind]
        flux = flux[ind]
    else:
        ind = np.where(~np.isnan(time * flux * flux_err))[0]
        time = time[ind]
        flux = flux[ind]
        flux_err = flux_err[ind]

    time_input = 1. * time
    flux_input = 1. * flux  #for plotting

    #::: fill in defaults (in place, i.e. the caller's dicts are updated)
    detrend = wotan_kwargs is not None
    if detrend:
        slide_clip_kwargs = wotan_kwargs.setdefault('slide_clip', {})
        slide_clip_kwargs.setdefault('window_length', 1.)
        slide_clip_kwargs.setdefault('low', 20.)
        slide_clip_kwargs.setdefault('high', 3.)

        flatten_kwargs = wotan_kwargs.setdefault('flatten', {})
        flatten_kwargs.setdefault('method', 'biweight')
        flatten_kwargs.setdefault('window_length', 1.)
        #the rest is filled automatically by Wotan

    if tls_kwargs is None: tls_kwargs = {}
    tls_kwargs.setdefault('show_progress_bar', False)
    tls_kwargs.setdefault('SNR_threshold', 5.)
    tls_kwargs.setdefault('SDE_threshold', 5.)
    tls_kwargs.setdefault('FAP_threshold', 0.05)
    tls_kwargs_original = {
        key: value
        for key, value in tls_kwargs.items()
        if key not in ['SNR_threshold', 'SDE_threshold', 'FAP_threshold']
    }  #for the original tls
    #the rest is filled automatically by TLS

    if options is None: options = {}
    options.setdefault('show_plot', False)
    options.setdefault('save_plot', False)
    options.setdefault('outdir', '')
    # outdir=None is documented as "current working directory"; normalise it
    # so that len() and os.path.join() below accept it.
    if options['outdir'] is None: options['outdir'] = ''

    def passes_thresholds(snr, sde, fap):
        # True if a TLS run counts as a significant detection
        return (snr >= tls_kwargs['SNR_threshold']) and (
            sde >= tls_kwargs['SDE_threshold']) and (
                fap <= tls_kwargs['FAP_threshold'])

    #::: init
    SNR = 1e12
    SDE = 1e12
    FAP = 0
    results_all = []
    if len(options['outdir']) > 0 and not os.path.exists(options['outdir']):
        os.makedirs(options['outdir'])

    #::: logprint
    with open(os.path.join(options['outdir'], 'logfile.log'), 'w') as f:
        f.write('TLS search, UTC ' +
                datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S') + '\n')
    logprint('\nWotan kwargs:', options=options)
    logpprint(wotan_kwargs, options=options)
    logprint('\nTLS kwargs:', options=options)
    logpprint(tls_kwargs, options=options)
    logprint('\nOptions:', options=options)
    logpprint(options, options=options)

    #::: apply a mask (if wished so)
    if known_transits is not None:
        for period, duration, T0 in zip(known_transits['period'],
                                        known_transits['duration'],
                                        known_transits['epoch']):
            time, flux, flux_err = mask(time, flux, flux_err, period, duration,
                                        T0)

    #::: global sigma clipping
    flux = sigma_clip(flux, sigma_upper=3, sigma_lower=20)

    #::: detrend (if wished so)
    if detrend:
        flux = slide_clip(time, flux, **wotan_kwargs['slide_clip'])
        flux, trend = flatten(time,
                              flux,
                              return_trend=True,
                              **wotan_kwargs['flatten'])

        if options['show_plot'] or options['save_plot']:
            fig, axes = plt.subplots(2, 1, figsize=(40, 8))
            # NOTE(review): flux_input was copied before known transits were
            # masked; if `mask` removes samples it no longer matches `time`
            # in length -- confirm against `mask`.
            axes[0].plot(time, flux_input, 'b.', rasterized=True)
            axes[0].plot(time, trend, 'r-', lw=2)
            axes[0].set(ylabel='Flux (input)', xticklabels=[])
            axes[1].plot(time, flux, 'b.', rasterized=True)
            axes[1].set(ylabel='Flux (detrended)', xlabel='Time (BJD)')
            # save / show / close independently, as for the other plots below
            if options['save_plot']:
                fig.savefig(os.path.join(
                    options['outdir'],
                    'flux_' + wotan_kwargs['flatten']['method'] + '.pdf'),
                            bbox_inches='tight')
            if options['show_plot']:
                plt.show()
            else:
                plt.close(fig)

        # without flux errors, write a NaN column to keep the 4-column layout
        if flux_err is None:
            flux_err_column = np.full(len(time), np.nan)
        else:
            flux_err_column = flux_err
        X = np.column_stack((time, flux, flux_err_column, trend))
        np.savetxt(os.path.join(
            options['outdir'],
            'flux_' + wotan_kwargs['flatten']['method'] + '.csv'),
                   X,
                   delimiter=',',
                   header='time,flux_detrended,flux_err,trend')

        time_detrended = 1. * time
        flux_detrended = 1. * flux  #for plotting

    #::: search for the rest
    i = 0
    ind_trs = []
    # The sentinel values above guarantee a first run (for the default
    # thresholds); afterwards the loop continues only while the previous run
    # was a significant detection.
    while passes_thresholds(SNR, SDE, FAP):

        model = tls(time, flux, flux_err)
        results = model.power(**tls_kwargs_original)

        if passes_thresholds(results.snr, results.SDE, results.FAP):

            #::: calculate the correct_duration, as TLS sometimes returns unreasonable durations
            ind_tr_phase = np.where(results['model_folded_model'] < 1.)[0]
            correct_duration = results['period'] * (
                results['model_folded_phase'][ind_tr_phase[-1]] -
                results['model_folded_phase'][ind_tr_phase[0]])

            #::: mark transit
            ind_tr, _ = index_transits(time_input, results['T0'],
                                       results['period'], correct_duration)
            ind_trs.append(ind_tr)

            #::: mask out detected transits and append results
            time1, flux1 = time, flux  #for plotting
            time, flux, flux_err = mask(time, flux, flux_err, results.period,
                                        1.5 * correct_duration, results.T0)
            results_all.append(results)

            #::: write TLS stats to file
            with open(
                    os.path.join(options['outdir'],
                                 'tls_signal_' + str(i) + '.txt'),
                    'wt') as out:
                pprint(results, stream=out)

            #::: individual TLS plots
            if options['show_plot'] or options['save_plot']:
                fig = plt.figure(figsize=(20, 8), tight_layout=True)
                gs = fig.add_gridspec(2, 3)

                #::: full light curve with the TLS model
                ax = fig.add_subplot(gs[0, :])
                ax.plot(time1, flux1, 'b.', rasterized=True)
                ax.plot(results['model_lightcurve_time'],
                        results['model_lightcurve_model'],
                        'r-',
                        lw=3)
                ax.set(xlabel='Time (BJD)', ylabel='Flux')

                #::: phase-folded, full phase range
                ax = fig.add_subplot(gs[1, 0])
                ax.plot(results['folded_phase'],
                        results['folded_y'],
                        'b.',
                        rasterized=True)
                ax.plot(results['model_folded_phase'],
                        results['model_folded_model'],
                        'r-',
                        lw=3)
                ax.set(xlabel='Phase', ylabel='Flux')

                #::: phase-folded, zoom on the transit, in hours
                ax = fig.add_subplot(gs[1, 1])
                ax.plot(
                    (results['folded_phase'] - 0.5) * results['period'] * 24,
                    results['folded_y'],
                    'b.',
                    rasterized=True)
                ax.plot((results['model_folded_phase'] - 0.5) *
                        results['period'] * 24,
                        results['model_folded_model'],
                        'r-',
                        lw=3)
                ax.set(xlim=[
                    -1.5 * correct_duration * 24, +1.5 * correct_duration * 24
                ],
                       xlabel='Time (h)',
                       yticks=[])

                #::: text panel with the key numbers
                ax = fig.add_subplot(gs[1, 2])
                summary_lines = [
                    'P = ' +
                    np.format_float_positional(results['period'], 4) + ' d',
                    'Depth = ' + np.format_float_positional(
                        1e3 * (1. - results['depth']), 4) + ' ppt',
                    'Duration = ' +
                    np.format_float_positional(24 * correct_duration, 4) +
                    ' h',
                    'T_0 = ' +
                    np.format_float_positional(results['T0'], 4) + ' d',
                    'SNR = ' + np.format_float_positional(results['snr'], 4),
                    'SDE = ' + np.format_float_positional(results['SDE'], 4),
                    'FAP = ' + np.format_float_scientific(results['FAP'], 4),
                ]
                text_heights = (0.95, 0.85, 0.75, 0.65, 0.55, 0.45, 0.35)
                for text_height, line in zip(text_heights, summary_lines):
                    ax.text(.02,
                            text_height,
                            line,
                            ha='left',
                            va='center',
                            transform=ax.transAxes)
                ax.set_axis_off()
                if options['save_plot']:
                    fig.savefig(os.path.join(options['outdir'],
                                             'tls_signal_' + str(i) + '.pdf'),
                                bbox_inches='tight')
                if options['show_plot']:
                    plt.show()
                else:
                    plt.close(fig)

        SNR = results.snr
        SDE = results.SDE
        FAP = results.FAP
        i += 1

    #::: full lightcurve plot
    if options['show_plot'] or options['save_plot']:

        if detrend:
            fig, axes = plt.subplots(2, 1, figsize=(40, 8), tight_layout=True)
            ax = axes[0]
            ax.plot(time_input,
                    flux_input,
                    '.',
                    color='grey',
                    rasterized=True)
            # NOTE(review): trend / time_detrended were computed after known
            # transits were masked, while ind_tr indexes time_input; if `mask`
            # removes samples these no longer line up -- confirm.
            ax.plot(time_input, trend, 'r-', lw=2)
            for number, ind_tr in enumerate(ind_trs):
                ax.plot(time_input[ind_tr],
                        flux_input[ind_tr],
                        marker='.',
                        linestyle='none',
                        label='signal ' + str(number))
            ax.set(ylabel='Flux (input)', xticklabels=[])
            if ind_trs:
                ax.legend()

            ax = axes[1]
            ax.plot(time_detrended,
                    flux_detrended,
                    '.',
                    color='grey',
                    rasterized=True)
            for number, ind_tr in enumerate(ind_trs):
                ax.plot(time_detrended[ind_tr],
                        flux_detrended[ind_tr],
                        marker='.',
                        linestyle='none',
                        label='signal ' + str(number))
            ax.set(ylabel='Flux (detrended)', xlabel='Time (BJD)')
            if ind_trs:
                ax.legend()

        else:
            fig, ax = plt.subplots(1, 1, figsize=(40, 4))
            ax.plot(time_input,
                    flux_input,
                    '.',
                    color='grey',
                    rasterized=True)
            ax.set(ylabel='Flux (input)', xlabel='Time (BJD)')
            for number, ind_tr in enumerate(ind_trs):
                ax.plot(time_input[ind_tr],
                        flux_input[ind_tr],
                        marker='.',
                        linestyle='none',
                        label='signal ' + str(number))
            if ind_trs:
                ax.legend()

        if options['save_plot']:
            fig.savefig(os.path.join(options['outdir'], 'tls_signal_all.pdf'),
                        bbox_inches='tight')
        if options['show_plot']:
            plt.show()
        else:
            plt.close(fig)

    return results_all
Beispiel #5
0
def clean_rotationsignal_tess_singlesector_light_curve(time,
                                                       mag,
                                                       magisflux=False,
                                                       dtr_dict=None,
                                                       lsp_dict=None,
                                                       maskorbitedge=True,
                                                       lsp_options=None,
                                                       verbose=True):
    """
    Remove a stellar rotation signal from a single TESS light curve (ideally
    one without severe instrumental systematics) while preserving transits.

    "Cleaning" by default is taken to mean the sequence of mask_orbit_edge ->
    slide_clip -> detrend -> slide_clip.  "Detrend" can mean any of the Wotan
    flatteners, Notch, or LOCOR. "slide_clip" means apply windowed
    sigma-clipping removal.

    The input arrays are never modified: the median normalization is done
    out-of-place, and default detrending options are not written back into a
    caller-supplied `dtr_dict`.

    Args:
        time, mag (np.ndarray): time and magnitude (or flux) vectors

        magisflux (bool): True if the "mag" vector is a flux already

        dtr_dict (optional dict): dictionary containing arguments passed to
        Wotan, Notch, or LOCOR. Keys read directly by this function:

            'method' (str): detrending method handed to the Wotan-style
            flatteners. If absent, the method resolved by
            `_get_detrending_method(dtr_dict)` is used instead. The supported
            methods are: ['best', 'notch', 'locor', 'pspline', 'biweight',
            'none'].
            NOTE(review): an earlier revision of this docstring named the key
            'dtr_method'; confirm against `_get_detrending_method`.

            'break_tolerance' (float): number of days past which a segment of
            light curve is considered a "new segment". Defaults to None.

            'window_length' (float): length of sliding window in days.
            Defaults to None.

        lsp_dict (optional dict): dictionary containing Lomb Scargle
        periodogram information, which is used in the "best" method for
        choosing between LOCOR or Notch detrending.  If this is not passed,
        it'll be constructed here after the mask_orbit_edge -> slide_clip
        steps (only for the 'locor' and 'best' methods).

        lsp_options (optional dict): contains keys period_min and period_max,
        used for the internal Lomb Scargle periodogram search. Defaults to
        {'period_min': 0.1, 'period_max': 20}.

        maskorbitedge (bool): whether to apply the initial "mask_orbit_edge"
        step. Probably would only want to be false if you had already done it
        elsewhere.

        verbose (bool): forwarded to the orbit-edge masking and Notch steps.

    Returns:
        search_time, search_flux, dtr_stages_dict (np.ndarrays and dict): light
            curve ready for TLS or BLS style periodograms; and a dictionary of
            the different processing stages (see comments for details of
            `dtr_stages_dict` contents).

    Raises:
        NotImplementedError: with the 'best' method, if the Lomb-Scargle
            period is not a positive number.
        ValueError: if the resolved detrending method is not one of the
            supported methods.
    """

    # A None sentinel avoids sharing one mutable default dict across calls.
    if lsp_options is None:
        lsp_options = {'period_min': 0.1, 'period_max': 20}

    dtr_method = _get_detrending_method(dtr_dict)

    #
    # convert mag to flux and median-normalize
    #
    if magisflux:
        flux = mag
    else:
        f_x0 = 1e4
        m_x0 = 10
        flux = f_x0 * 10**(-0.4 * (mag - m_x0))

    # Out-of-place division: with magisflux=True `flux` aliases the caller's
    # `mag`, and an in-place `/=` would silently rescale the caller's array
    # (and fail outright for integer-typed input).
    flux = flux / np.nanmedian(flux)

    #
    # ignore the times near the edges of orbits for TLS.
    #
    if maskorbitedge:
        _time, _flux = moe.mask_orbit_start_and_end(
            time, flux, raise_expectation_error=False, verbose=verbose)
    else:
        _time, _flux = time, flux

    #
    # sliding sigma clip asymmetric [20,3]*MAD, about median. use a 3-day
    # window, to give ~100 to 150 data points. mostly to avoid big flares.
    # the same settings are reused after detrending.
    #
    clip_kwargs = dict(window_length=3,
                       low=20,
                       high=3,
                       method='mad',
                       center='median')
    clipped_flux = slide_clip(_time, _flux, **clip_kwargs)
    sel0 = ~np.isnan(clipped_flux)

    #
    # for "best" or LOCOR detrending, you need to know the stellar rotation
    # period.  so, if it hasn't already been run, run the LS periodogram here.
    # in `lsp_dict`, cache the LS peak period, amplitude, and FAP.
    #
    if (not isinstance(lsp_dict, dict)) and (dtr_method in ['locor', 'best']):

        period_min = lsp_options['period_min']
        period_max = lsp_options['period_max']

        # third positional argument: per-point uncertainty, taken here as
        # 0.1% of the flux value.
        ls = LombScargle(_time[sel0], clipped_flux[sel0],
                         clipped_flux[sel0] * 1e-3)
        freq, power = ls.autopower(minimum_frequency=1 / period_max,
                                   maximum_frequency=1 / period_min)
        ls_fap = ls.false_alarm_probability(power.max())
        best_freq = freq[np.argmax(power)]
        ls_period = 1 / best_freq
        theta = ls.model_parameters(best_freq)
        ls_amplitude = theta[1]

        lsp_dict = {
            'ls_period': ls_period,
            'ls_amplitude': np.abs(ls_amplitude),
            'ls_fap': ls_fap,
        }

    if not isinstance(dtr_dict, dict):
        dtr_dict = {'method': dtr_method}

    #
    # apply the detrending call based on the method given
    #

    dtr_method_used = dtr_method

    if dtr_method in ['pspline', 'biweight', 'none']:

        # `.get` supplies the defaults without mutating the caller's dict and
        # without a KeyError when 'method' is missing from it.
        flat_flux, trend_flux = detrend_flux(
            _time[sel0],
            clipped_flux[sel0],
            break_tolerance=dtr_dict.get('break_tolerance'),
            method=dtr_dict.get('method', dtr_method),
            cval=None,
            window_length=dtr_dict.get('window_length'),
            edge_cutoff=None)

    elif dtr_method == 'notch':

        flat_flux, trend_flux, _ = _run_notch(_time[sel0],
                                              clipped_flux[sel0],
                                              dtr_dict,
                                              verbose=verbose)

    elif dtr_method == 'locor':

        flat_flux, trend_flux, _ = _run_locor(_time[sel0],
                                              clipped_flux[sel0], dtr_dict,
                                              lsp_dict)

    elif dtr_method == 'best':

        # for stars with Prot < 1 day, use LOCOR.  for stars with Prot >= 1
        # day, use Notch.  (or pspline?).  The boundary value itself goes to
        # Notch so that a period of exactly PERIOD_CUTOFF is not rejected.
        PERIOD_CUTOFF = 1.0
        ls_period = lsp_dict['ls_period']

        if ls_period >= PERIOD_CUTOFF:
            flat_flux, trend_flux, _ = _run_notch(_time[sel0],
                                                  clipped_flux[sel0],
                                                  dtr_dict,
                                                  verbose=verbose)
            dtr_method_used += '-notch'
        elif 0 < ls_period < PERIOD_CUTOFF:
            flat_flux, trend_flux, _ = _run_locor(_time[sel0],
                                                  clipped_flux[sel0],
                                                  dtr_dict, lsp_dict)
            dtr_method_used += '-locor'
        else:
            # non-positive or NaN period
            raise NotImplementedError(f"Got LS period {lsp_dict['ls_period']}")

    else:
        # Fail with a clear message instead of an UnboundLocalError on
        # `flat_flux` further down.
        raise ValueError(f"Unsupported detrending method: {dtr_method!r}")

    #
    # re-apply sliding sigma clip asymmetric [20,3]*MAD, about median, after
    # detrending.
    #
    clipped_flat_flux = slide_clip(_time[sel0], flat_flux, **clip_kwargs)
    sel1 = ~np.isnan(clipped_flat_flux)

    search_flux = clipped_flat_flux[sel1]
    search_time = _time[sel0][sel1]

    dtr_stages_dict = {
        # non-nan indices from clipped_flux
        'sel0': sel0,
        # non-nan indices from clipped_flat_flux
        'sel1': sel1,
        # after initial window sigma_clip on flux, what is left?
        'clipped_flux': clipped_flux,
        # after detrending, what is left?
        'flat_flux': flat_flux,
        # after window sigma_clip on flat_flux, what is left?
        'clipped_flat_flux': clipped_flat_flux,
        # what does the detrending algorithm give as the "trend"?
        'trend_flux': trend_flux,
        'trend_time': _time[sel0],
        # what method was used? if "best", gives "best-notch" or "best-locor"
        'dtr_method_used': dtr_method_used,
        # times and fluxes used
        'search_time': search_time,
        'search_flux': search_flux
    }
    if isinstance(lsp_dict, dict):
        # in most cases, cache the LS period, amplitude, and FAP
        dtr_stages_dict['lsp_dict'] = lsp_dict

    return search_time, search_flux, dtr_stages_dict
Beispiel #6
0
def _synthetic_light_curve(points):
    """Return (time, flux) for a noisy sinusoid with injected dips and flares.

    Draws exactly `points` values from the global numpy RNG, so the position
    of each call relative to other random draws matters for reproducibility.
    """
    time = numpy.linspace(0, 30, points)
    flux = 1 + numpy.sin(time) / points
    flux += numpy.random.normal(0, 0.0001, points)

    for i in range(0, points, 75):
        flux[i:i + 5] -= 0.0004  # Add some transits
        flux[i + 50:i + 52] += 0.0002  # and flares
    return time, flux


def _assert_flatten_sum(expected, decimal, time, flux, *args, **kwargs):
    """Run `flatten` (with return_trend=True) and check nansum of the result.

    `args`/`kwargs` are forwarded to `flatten` unchanged; `expected` is the
    reference value for the NaN-ignoring sum of the flattened light curve and
    `decimal` the precision passed to `assert_almost_equal`.
    """
    flatten_lc, _trend_lc = flatten(time, flux, *args, return_trend=True,
                                    **kwargs)
    numpy.testing.assert_almost_equal(numpy.nansum(flatten_lc), expected,
                                      decimal=decimal)


def main():
    """Regression test run for wotan.

    Checks the transit-duration helper, the sliding clipper, and every
    detrending method of `flatten` against stored reference values, using
    synthetic data and TESS light curves fetched from archive.stsci.edu
    (network access required). Raises AssertionError on the first mismatch.
    """
    print("Starting tests for wotan...")

    numpy.testing.assert_almost_equal(
        t14(R_s=1, M_s=1, P=365),
        0.6490025258902046)

    numpy.testing.assert_almost_equal(
        t14(R_s=1, M_s=1, P=365, small_planet=True),
        0.5403690143737738)
    print("Transit duration correct.")

    numpy.random.seed(seed=0)  # reproducibility

    print("Slide clipper...")
    time, flux = _synthetic_light_curve(1000)

    clipped = slide_clip(
        time,
        flux,
        window_length=0.5,
        low=3,
        high=2,
        method='mad',
        center='median')
    numpy.testing.assert_almost_equal(numpy.nansum(clipped), 948.9926368754939)

    # Optional visual check of the clipper:
    #   import matplotlib.pyplot as plt
    #   plt.scatter(time, flux, s=3, color='black')
    #   plt.scatter(time, clipped, s=3, color='orange')
    #   plt.show()

    # TESS test
    print('Loading TESS data from archive.stsci.edu...')
    path = 'https://archive.stsci.edu/hlsps/tess-data-alerts/'
    #path = 'P:/P/Dok/tess_alarm/'
    filename = "hlsp_tess-data-alerts_tess_phot_00062483237-s01_tess_v1_lc.fits"
    time, flux = load_file(path + filename)

    window_length = 0.5

    print("Detrending 1 (biweight)...")
    flatten_lc, trend_lc = flatten(
        time,
        flux,
        window_length,
        edge_cutoff=1,
        break_tolerance=0.1,
        return_trend=True,
        cval=5.0)

    numpy.testing.assert_equal(len(trend_lc), 20076)
    numpy.testing.assert_almost_equal(numpy.nanmax(trend_lc), 28755.03811866676, decimal=2)
    numpy.testing.assert_almost_equal(numpy.nanmin(trend_lc), 28615.110229935075, decimal=2)
    numpy.testing.assert_almost_equal(trend_lc[500], 28671.650565730513, decimal=2)

    numpy.testing.assert_equal(len(flatten_lc), 20076)
    numpy.testing.assert_almost_equal(numpy.nanmax(flatten_lc), 1.0034653549250616, decimal=2)
    numpy.testing.assert_almost_equal(numpy.nanmin(flatten_lc), 0.996726610702177, decimal=2)
    numpy.testing.assert_almost_equal(flatten_lc[500], 1.000577429565131, decimal=2)

    print("Detrending 2 (andrewsinewave)...")
    _assert_flatten_sum(18119.15471987987, 2, time, flux, window_length,
                        method='andrewsinewave')

    print("Detrending 3 (welsch)...")
    _assert_flatten_sum(18119.16764691235, 2, time, flux, window_length,
                        method='welsch')

    print("Detrending 4 (hodges)...")
    _assert_flatten_sum(994.0110525909206, 2, time[:1000], flux[:1000],
                        window_length, method='hodges')

    print("Detrending 5 (median)...")
    _assert_flatten_sum(18119.122065014355, 2, time, flux, window_length,
                        method='median')

    print("Detrending 6 (mean)...")
    _assert_flatten_sum(18119.032473037714, 2, time, flux, window_length,
                        method='mean')

    print("Detrending 7 (trim_mean)...")
    _assert_flatten_sum(18119.095164910334, 2, time, flux, window_length,
                        method='trim_mean')

    print("Detrending 8 (supersmoother)...")
    _assert_flatten_sum(18123.00632204841, 2, time, flux, window_length,
                        method='supersmoother')

    print("Detrending 9 (hspline)...")
    _assert_flatten_sum(18123.082601463717, 1, time, flux, window_length,
                        method='hspline')

    print("Detrending 10 (cofiam)...")
    _assert_flatten_sum(1948.9999999987976, 1, time[:2000], flux[:2000],
                        window_length, method='cofiam')

    print("Detrending 11 (savgol)...")
    _assert_flatten_sum(18123.003465539354, 1, time, flux,
                        window_length=301, method='savgol')

    print("Detrending 12 (medfilt)...")
    _assert_flatten_sum(18123.22609806557, 1, time, flux,
                        window_length=301, method='medfilt')

    print("Detrending 12 (gp squared_exp)...")
    _assert_flatten_sum(1948.9672036416687, 2, time[:2000], flux[:2000],
                        method='gp',
                        kernel='squared_exp',
                        kernel_size=10)

    print("Detrending 13 (gp squared_exp robust)...")
    _assert_flatten_sum(1948.8820772313468, 2, time[:2000], flux[:2000],
                        method='gp',
                        kernel='squared_exp',
                        kernel_size=10,
                        robust=True)

    print("Detrending 14 (gp matern)...")
    _assert_flatten_sum(1948.9672464898367, 2, time[:2000], flux[:2000],
                        method='gp',
                        kernel='matern',
                        kernel_size=10)

    print("Detrending 15 (gp periodic)...")
    _assert_flatten_sum(1948.9999708985608, 2, time[:2000], flux[:2000],
                        method='gp',
                        kernel='periodic',
                        kernel_size=1,
                        kernel_period=10)

    time_synth = numpy.linspace(0, 30, 200)
    flux_synth = numpy.sin(time_synth) + numpy.random.normal(0, 0.1, 200)
    flux_synth = 1 + flux_synth / 100
    time_synth *= 1.5
    print("Detrending 16 (gp periodic_auto)...")
    _assert_flatten_sum(200, 1, time_synth, flux_synth,
                        method='gp',
                        kernel='periodic_auto',
                        kernel_size=1)

    print("Detrending 17 (rspline)...")
    _assert_flatten_sum(18121.812790732245, 2, time, flux,
                        method='rspline',
                        window_length=1)

    print("Detrending 18 (huber)...")
    _assert_flatten_sum(994.01102, 2, time[:1000], flux[:1000],
                        method='huber',
                        window_length=0.5,
                        edge_cutoff=0,
                        break_tolerance=0.4)

    print("Detrending 19 (winsorize)...")
    _assert_flatten_sum(18119.064587196448, 2, time, flux,
                        method='winsorize',
                        window_length=0.5,
                        edge_cutoff=0,
                        break_tolerance=0.4,
                        proportiontocut=0.1)

    print("Detrending 20 (pspline)...")
    _assert_flatten_sum(18121.832133916843, 2, time, flux,
                        method='pspline')

    print("Detrending 21 (hampelfilt)...")
    _assert_flatten_sum(18119.158072498867, 2, time, flux,
                        method='hampelfilt',
                        window_length=0.5,
                        cval=3)

    print("Detrending 22 (lowess)...")
    _assert_flatten_sum(18123.039744125545, 2, time, flux,
                        method='lowess',
                        window_length=1)

    print("Detrending 23 (huber_psi)...")
    _assert_flatten_sum(18119.122065014355, 2, time, flux,
                        method='huber_psi',
                        window_length=0.5)

    print("Detrending 24 (tau)...")
    _assert_flatten_sum(18119.02772621119, 2, time, flux,
                        method='tau',
                        window_length=0.5)

    print("Detrending 25 (cosine)...")
    _assert_flatten_sum(18122.999999974905, 2, time, flux,
                        method='cosine',
                        window_length=0.5)

    print("Detrending 25 (cosine robust)...")
    _assert_flatten_sum(18122.227938535038, 2, time, flux,
                        method='cosine',
                        robust=True,
                        window_length=0.5)

    # Fresh synthetic data for the remaining location estimators and the
    # regression-based methods.
    time, flux = _synthetic_light_curve(1000)

    print("Detrending 26 (hampel 17A)...")
    # No reference value is stored for this configuration; the call only
    # verifies that it completes without raising.
    flatten(
        time,
        flux,
        method='hampel',
        cval=(1.7, 3.4, 8.5),
        window_length=0.5,
        return_trend=True)

    print("Detrending 27 (hampel 25A)...")
    _assert_flatten_sum(997.9994362858843, 2, time, flux,
                        method='hampel',
                        cval=(2.5, 4.5, 9.5),
                        window_length=0.5)

    print("Detrending 28 (ramsay)...")
    _assert_flatten_sum(997.9974021484584, 2, time, flux,
                        method='ramsay',
                        cval=0.3,
                        window_length=0.5)

    print("Detrending 29 (ridge)...")
    _assert_flatten_sum(999.9999958887022, 1, time, flux,
                        window_length=0.5,
                        method='ridge',
                        cval=1)

    print("Detrending 30 (lasso)...")
    _assert_flatten_sum(999.9999894829843, 1, time, flux,
                        window_length=0.5,
                        method='lasso',
                        cval=1)

    print("Detrending 31 (elasticnet)...")
    _assert_flatten_sum(999.9999945063338, 1, time, flux,
                        window_length=0.5,
                        method='elasticnet',
                        cval=1)

    # Test of transit mask
    print('Testing transit_mask')
    filename = 'hlsp_tess-data-alerts_tess_phot_00207081058-s01_tess_v1_lc.fits'
    time, flux = load_file(path + filename)

    from wotan import transit_mask
    mask = transit_mask(
        time=time,
        period=14.77338,
        duration=0.21060,
        T0=1336.141095)
    numpy.testing.assert_almost_equal(numpy.sum(mask), 302, decimal=1)

    print('Detrending 32 (transit_mask cosine)')
    _assert_flatten_sum(18119.281265446625, 1, time, flux,
                        method='cosine',
                        window_length=0.4,
                        robust=True,
                        mask=mask)

    print('Detrending 33 (transit_mask lowess)')
    _assert_flatten_sum(18119.30865711536, 1, time, flux,
                        method='lowess',
                        window_length=0.8,
                        robust=True,
                        mask=mask)

    print('Detrending 34 (transit_mask GP)')
    mask = transit_mask(
        time=time[:2000],
        period=100,
        duration=0.3,
        T0=1327.4)
    # Written out in full so that `trend_lc2` stays available for the
    # optional plot below.
    flatten_lc2, trend_lc2 = flatten(
        time[:2000],
        flux[:2000],
        method='gp',
        kernel='matern',
        kernel_size=0.8,
        return_trend=True,
        robust=True,
        mask=mask)
    numpy.testing.assert_almost_equal(numpy.nansum(flatten_lc2), 1948.9000170463796, decimal=1)

    # Optional visual check of the masked GP trend:
    #   import matplotlib.pyplot as plt
    #   plt.scatter(time[:2000], flux[:2000], s=1, color='black')
    #   plt.plot(time[:2000], trend_lc2, color='red', linewidth=2, linestyle='dashed')
    #   plt.show()
    #   plt.close()

    print('All tests completed.')