Exemplo n.º 1
0
def _quick_spline_knots(time, bkspace):
    """Returns the interior knots used by the quick (LSQUnivariateSpline) fit.

    Args:
        time: Numpy array of time values already rescaled into [0, 1]. The first,
            second, second-to-last and last elements are read to place the knots,
            so the array is presumably sorted in ascending order (TODO confirm
            against callers).
        bkspace: Break point spacing, in the same rescaled units as `time`.

    Returns:
        Numpy array of interior knot positions; empty if fewer than 4 candidate
        knots fit into the time range.
    """
    nknot = int((time[-1] - time[0]) / bkspace)
    if nknot == 0:
        nknot = 1

    # The candidate knots are trimmed with [1:-2] below, which is always empty
    # for fewer than 4 candidates. Returning early also avoids the 0/0 division
    # that normalising a single knot would otherwise perform.
    if nknot < 4:
        return np.empty(0)

    knots = np.linspace(time[1], time[-2], nknot)

    # Normalise the candidates onto [0, 1], then drop the first candidate and
    # the last two so that only interior knots remain.
    k_min = np.min(knots)
    k_max = np.max(knots)
    knots = (knots - k_min) / (k_max - k_min)
    return knots[1:-2]


def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3, QUICK_SPLINE=False):
    """Computes a best-fit spline curve for a light curve segment.

    The spline is fit using an iterative process to remove outliers that may cause
    the spline to be "pulled" by discrepant points. In each iteration the spline
    is fit, and if there are any points where the absolute deviation from the
    median residual is at least outlier_cut*sigma (where sigma is a robust
    estimate of the standard deviation of the residuals), those points are
    removed and the spline is re-fit.

    Args:
        time: Numpy array; the time values of the light curve.
        flux: Numpy array; the flux (brightness) values of the light curve.
        bkspace: Spline break point spacing in time units.
        maxiter: Maximum number of attempts to fit the spline after removing badly
                fit points. If this is 0 no fit is attempted and (None, None) is
                returned.
        outlier_cut: The maximum number of standard deviations from the median
                spline residual before a point is considered an outlier.
        QUICK_SPLINE: If True, fit a cubic LSQUnivariateSpline on evenly spaced
                interior knots instead of using bspline.iterfit.

    Returns:
        spline: The values of the fitted spline corresponding to the input time
                values.
        mask: Boolean mask indicating the points used to fit the final spline.

    Raises:
        InsufficientPointsError: If there were insufficient points (after removing
                outliers) for spline fitting.
        SplineError: If the spline could not be fit, for example if the breakpoint
                spacing is too small or if all time values are identical.
    """
    if len(time) < 4:
        raise InsufficientPointsError(
                "Cannot fit a spline on less than 4 points. Got %d points." % len(time))

    # Rescale time into [0, 1].
    t_min = np.min(time)
    t_max = np.max(time)
    t_span = t_max - t_min

    # A zero (or NaN) span would turn every rescaled time value into NaN/inf, so
    # fail with the documented error type rather than fitting garbage.
    if not t_span > 0:
        raise SplineError(
                "Fitting spline failed: the time values must span a positive, "
                "finite range, but got a range of %s." % t_span)

    time = (time - t_min) / t_span
    # Rescale bucket spacing. A plain division (not /=) is used so that a
    # caller-owned numpy value is never modified in place.
    bkspace = bkspace / t_span

    knots = _quick_spline_knots(time, bkspace) if QUICK_SPLINE else None

    # Values of the best fitting spline evaluated at the time points.
    spline = None

    # Mask indicating the points used to fit the spline.
    mask = None

    for _ in range(maxiter):
        if spline is None:
            # Try to fit all points. The builtin bool is used because the np.bool
            # alias no longer exists in current NumPy releases.
            mask = np.ones_like(time, dtype=bool)
        else:
            # Choose points where the absolute deviation from the median residual is
            # less than outlier_cut*sigma, where sigma is a robust estimate of the
            # standard deviation of the residuals from the previous spline.
            residuals = flux - spline
            new_mask = robust_mean(residuals, cut=outlier_cut)[2]

            if np.all(new_mask == mask):
                break    # Spline converged.

            mask = new_mask

        num_points = np.sum(mask)
        if num_points < 4:
            # Fewer than 4 points after removing outliers. We could plausibly return
            # the spline from the previous iteration because it was fit with at least
            # 4 points. However, since the outliers were such a significant fraction
            # of the curve, the spline from the previous iteration is probably junk,
            # and we consider this a fatal error.
            raise InsufficientPointsError(
                    "Cannot fit a spline on less than 4 points. After removing "
                    "outliers, got %d points." % num_points)

        try:
            with warnings.catch_warnings():
                # Suppress warning messages printed by pydlutils.bspline. Instead we
                # catch any exception and raise a more informative error.
                warnings.simplefilter("ignore")

                # Fit the spline on non-outlier points.
                if QUICK_SPLINE:
                    curve = LSQUnivariateSpline(time[mask], flux[mask], knots, k=3)
                else:
                    curve = bspline.iterfit(time[mask], flux[mask], bkspace=bkspace)[0]

            # Evaluate spline at all time points, including the masked-out ones.
            if QUICK_SPLINE:
                spline = curve(time)
            else:
                spline = curve.value(time)[0]

        except (IndexError, TypeError, ValueError) as e:
            raise SplineError(
                    "Fitting spline failed with error: '%s'. This might be caused by the "
                    "breakpoint spacing being too small, and/or there being insufficient "
                    "points to fit the spline in one of the intervals." % e)

    return spline, mask