def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=None):
    """Remove a constant or piecewise-linear trend from ``data`` along ``axis``.

    Args:
      data: input values; converted with ``jnp.asarray`` and promoted to an
        inexact dtype before any arithmetic.
      axis: axis along which the trend is removed (default: the last axis).
      type: ``'constant'`` subtracts the mean along ``axis``; ``'linear'``
        subtracts a least-squares straight line from every segment
        delimited by the breakpoints.
      bp: breakpoint index, or sequence of indices, along ``axis``. Only used
        for ``type='linear'``. ``0`` and the axis length are always added as
        segment boundaries.
      overwrite_data: not supported; anything other than ``None`` raises.

    Returns:
      The detrended array, with the same shape as the (promoted) input.

    Raises:
      NotImplementedError: if ``overwrite_data`` is not ``None``.
      ValueError: if ``type`` is not ``'linear'``/``'constant'``, or if a
        breakpoint is negative or larger than the axis length.
    """
    if overwrite_data is not None:
        raise NotImplementedError("overwrite_data argument not implemented.")
    if type not in ('constant', 'linear'):
        raise ValueError("Trend type must be 'linear' or 'constant'.")

    data, = _promote_dtypes_inexact(jnp.asarray(data))

    if type == 'constant':
        return data - data.mean(axis, keepdims=True)

    n_samples = data.shape[axis]
    # Breakpoints are handled with NumPy on the host (they are treated as
    # static values), so nothing here is pushed to the device.
    edges = np.sort(np.unique(np.r_[0, bp, n_samples]))
    if edges[0] < 0 or edges[-1] > n_samples:
        raise ValueError(
            "Breakpoints must be non-negative and less than length of data "
            "along given axis."
        )

    # Work on a 2-D view: detrend axis first, all other axes flattened.
    moved = jnp.moveaxis(data, axis, 0)
    moved_shape = moved.shape
    flat = moved.reshape(n_samples, -1)

    for start, stop in zip(edges[:-1], edges[1:]):
        n_pts = stop - start
        # Design matrix with columns [1, k / n_pts] for k = 1..n_pts.
        design = jnp.stack(
            [
                jnp.ones(n_pts, dtype=flat.dtype),
                jnp.arange(1, n_pts + 1, dtype=flat.dtype) / n_pts,
            ],
            axis=1,
        )
        segment = slice(start, stop)
        coef = linalg.lstsq(design, flat[segment])[0]
        trend = jnp.matmul(design, coef, precision=lax.Precision.HIGHEST)
        flat = flat.at[segment].add(-trend)

    # Undo the flattening and put the detrend axis back where it was.
    return jnp.moveaxis(flat.reshape(moved_shape), 0, axis)
def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
    """Least-squares fit of a degree-``deg`` polynomial to the points (x, y).

    Args:
      x: 1-D, non-empty array of sample abscissae. Integer input is promoted
        to an inexact dtype.
      y: 1-D array of ordinates, or a 2-D array with one data set per column;
        its first dimension must match ``x``.
      deg: polynomial degree; must be a concrete non-negative int.
      rcond: relative cut-off passed to ``linalg.lstsq``. Defaults to
        ``len(x) * eps`` for the dtype of ``x``.
      full: if true, return ``(c, resids, rank, s, rcond)``.
      w: optional 1-D weights, one per sample, multiplied into both sides of
        the least-squares system.
      cov: if truthy (and ``full`` is false), also return a covariance
        estimate. The string ``"unscaled"`` skips scaling by the residuals.

    Returns:
      ``c`` alone; ``(c, resids, rank, s, rcond)`` when ``full`` is true; or
      ``(c, V)`` when ``cov`` is truthy. ``V`` has shape ``(order, order)``
      for 1-D ``y``. For 2-D ``y`` it is ``(order, order, K)`` when scaled by
      the residuals and ``(order, order, 1)`` for ``cov="unscaled"``.
      ``order`` is ``deg + 1``.

    Raises:
      ValueError: if ``deg < 0``, or if a scaled covariance is requested with
        ``len(x) <= deg + 1``.
      TypeError: if ``x``, ``y`` or ``w`` have unsupported shapes.
    """
    _check_arraylike("polyfit", x, y)
    # Promote x before touching its dtype: finfo() below is only defined for
    # inexact dtypes, so integer abscissae would otherwise raise.
    x, = _promote_dtypes_inexact(x)
    deg = core.concrete_or_error(int, deg, "deg must be int")
    order = deg + 1

    # check arguments
    if deg < 0:
        raise ValueError("expected deg >= 0")
    if x.ndim != 1:
        raise TypeError("expected 1D vector for x")
    if x.size == 0:
        raise TypeError("expected non-empty vector for x")
    if y.ndim < 1 or y.ndim > 2:
        raise TypeError("expected 1D or 2D array for y")
    if x.shape[0] != y.shape[0]:
        raise TypeError("expected x and y to have same length")

    # set rcond
    if rcond is None:
        rcond = len(x) * finfo(x.dtype).eps
    rcond = core.concrete_or_error(float, rcond, "rcond must be float")

    # set up least squares equation for powers of x
    lhs = vander(x, order)
    rhs = y

    # apply weighting
    if w is not None:
        _check_arraylike("polyfit", w)
        w, = _promote_dtypes_inexact(w)
        if w.ndim != 1:
            raise TypeError("expected a 1-d array for weights")
        if w.shape[0] != y.shape[0]:
            raise TypeError("expected w and y to have the same length")
        lhs = lhs * w[:, np.newaxis]
        rhs = rhs * (w[:, np.newaxis] if rhs.ndim == 2 else w)

    # scale lhs columns to unit norm to improve the condition number, solve,
    # then undo the scaling on the coefficients
    scale = sqrt((lhs * lhs).sum(axis=0))
    lhs = lhs / scale[np.newaxis, :]
    c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond)
    c = (c.T / scale).T  # broadcast scale coefficients

    if full:
        return c, resids, rank, s, rcond
    if not cov:
        return c

    Vbase = linalg.inv(dot(lhs.T, lhs)) / outer(scale, scale)
    if cov == "unscaled":
        fac = 1
    else:
        if len(x) <= order:
            raise ValueError("the number of data points must exceed order "
                             "to scale the covariance matrix")
        # One residual per data set: shape (1,) for 1-D y, (K,) for 2-D y.
        fac = resids / (len(x) - order)
        if y.ndim == 1:
            # Collapse the length-1 residual to a scalar so V stays 2-D.
            fac = fac[0]
    if y.ndim == 1:
        return c, Vbase * fac
    # Each column of a 2-D y gets its own scaling -> (order, order, K).
    return c, Vbase[:, :, np.newaxis] * fac