def _reduction(a, name, np_fun, op, init_val, has_identity=True,
               preproc=None, bool_op=None, upcast_f16_for_computation=False,
               axis=None, dtype=None, out=None, keepdims=False, initial=None,
               where_=None, parallel_reduce=None):
  bool_op = bool_op or op
  # Note: we must accept out=None as an argument, because numpy reductions delegate to
  # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
  # exists, passing along all its arguments.
  if out is not None:
    raise NotImplementedError(
        f"The 'out' argument to jnp.{name} is not supported.")
  _check_arraylike(name, a)
  lax_internal._check_user_dtype_supported(dtype, name)
  axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")

  if initial is None and not has_identity:
    if not _all(core.greater_equal_dim(d, 1) for d in np.shape(a)):
      raise ValueError(
          f"zero-size array to reduction operation {name} which has no identity")
    if where_ is not None:
      raise ValueError(
          f"reduction operation {name} does not have an identity, so to use a "
          f"where mask one has to specify 'initial'")

  a = a if isinstance(a, ndarray) else _asarray(a)
  a = preproc(a) if preproc else a
  pos_dims, dims = _reduction_dims(a, axis)
  result_dtype = dtypes.canonicalize_dtype(
      dtype or dtypes.dtype(np_fun(np.ones((), dtype=dtypes.dtype(a)))))
  if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact):
    computation_dtype = dtypes.promote_types(result_dtype, np.float32)
  else:
    computation_dtype = result_dtype
  a = lax.convert_element_type(a, computation_dtype)
  op = op if computation_dtype != np.bool_ else bool_op
  # NB: in XLA, init_val must be an identity for the op, so the user-specified
  # initial value must be applied afterward.
  init_val = _reduction_init_val(a, init_val)

  if where_ is not None:
    a = _where(where_, a, init_val)
  if pos_dims is not dims:
    if parallel_reduce is None:
      raise NotImplementedError(
          f"Named reductions not implemented for jnp.{name}()")
    result = parallel_reduce(a, dims)
  else:
    result = lax.reduce(a, init_val, op, dims)
  if initial is not None:
    result = op(lax.convert_element_type(initial, a.dtype), result)
  if keepdims:
    result = lax.expand_dims(result, pos_dims)
  return lax.convert_element_type(result, dtype or result_dtype)

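# Usage sketch (illustrative, not part of this module): a jnp.sum-style
# reduction can be assembled from _reduction roughly as follows. The helper
# name `_example_reduce_sum` is hypothetical; the monoid op is lax.add with
# identity 0, booleans fall back to lax.bitwise_or, and named-axis reductions
# dispatch to lax.psum.
def _example_reduce_sum(a, axis=None, dtype=None, out=None, keepdims=False,
                        initial=None, where=None):
  return _reduction(a, "sum", np.sum, lax.add, 0,
                    bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
                    axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                    initial=initial, where_=where, parallel_reduce=lax.psum)
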
def absolute(x):
  _check_arraylike('absolute', x)
  dt = dtypes.dtype(x)
  # Booleans and unsigned integers are their own absolute value.
  return x if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x)

def heaviside(x1, x2):
  _check_arraylike("heaviside", x1, x2)
  x1, x2 = _promote_dtypes_inexact(x1, x2)
  zero = _lax_const(x1, 0)
  return _where(lax.lt(x1, zero), zero,
                _where(lax.gt(x1, zero), _lax_const(x1, 1), x2))

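# Example (illustrative; shown via the public jnp namespace, with JAX's
# default float32 and modern Array repr):
#   >>> jnp.heaviside(jnp.array([-1.5, 0.0, 2.0]), 0.5)
#   Array([0. , 0.5, 1. ], dtype=float32)
# At x1 == 0 the result is x2, matching the NumPy convention.
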
def reciprocal(x):
  _check_arraylike("reciprocal", x)
  x, = _promote_dtypes_inexact(x)
  return lax.integer_pow(x, -1)

def modf(x, out=None):
  _check_arraylike("modf", x)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.modf is not supported.")
  # Truncate toward zero: floor for non-negative inputs, ceil for negative.
  whole = _where(lax.ge(x, lax_internal._zero(x)), floor(x), ceil(x))
  return x - whole, whole

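# Example (illustrative, default float32):
#   >>> jnp.modf(jnp.array([2.75, -2.75]))
#   (Array([ 0.75, -0.75], dtype=float32), Array([ 2., -2.], dtype=float32))
# Both the fractional and whole parts carry the sign of the input, since
# `whole` truncates toward zero.
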
def isnan(x):
  _check_arraylike("isnan", x)
  # NaN is the only value not equal to itself.
  return lax.ne(x, x)

def imag(val):
  _check_arraylike("imag", val)
  return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0)

def real(val):
  _check_arraylike("real", val)
  return lax.real(val) if np.iscomplexobj(val) else val

def conjugate(x):
  _check_arraylike("conjugate", x)
  return lax.conj(x) if np.iscomplexobj(x) else x

def square(x):
  _check_arraylike("square", x)
  return lax.integer_pow(x, 2)

def fmod(x1, x2):
  _check_arraylike("fmod", x1, x2)
  if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer):
    # Guard against integer division by zero in lax.rem.
    x2 = _where(x2 == 0, lax_internal._ones(x2), x2)
  return lax.rem(*_promote_args("fmod", x1, x2))

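# Example (illustrative): like C's fmod, the result takes the sign of the
# dividend, unlike jnp.mod, which takes the sign of the divisor:
#   >>> jnp.fmod(jnp.array([-7.0]), 3.0)
#   Array([-1.], dtype=float32)
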
def polysub(a1, a2):
  _check_arraylike("polysub", a1, a2)
  a1, a2 = _promote_dtypes(a1, a2)
  return polyadd(a1, -a2)

def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
  _check_arraylike("polyfit", x, y)
  deg = core.concrete_or_error(int, deg, "deg must be int")
  order = deg + 1
  # check arguments
  if deg < 0:
    raise ValueError("expected deg >= 0")
  if x.ndim != 1:
    raise TypeError("expected 1D vector for x")
  if x.size == 0:
    raise TypeError("expected non-empty vector for x")
  if y.ndim < 1 or y.ndim > 2:
    raise TypeError("expected 1D or 2D array for y")
  if x.shape[0] != y.shape[0]:
    raise TypeError("expected x and y to have same length")

  # set rcond
  if rcond is None:
    rcond = len(x) * finfo(x.dtype).eps
  rcond = core.concrete_or_error(float, rcond, "rcond must be float")

  # set up least squares equation for powers of x
  lhs = vander(x, order)
  rhs = y

  # apply weighting
  if w is not None:
    _check_arraylike("polyfit", w)
    w, = _promote_dtypes_inexact(w)
    if w.ndim != 1:
      raise TypeError("expected a 1-d array for weights")
    if w.shape[0] != y.shape[0]:
      raise TypeError("expected w and y to have the same length")
    lhs *= w[:, np.newaxis]
    if rhs.ndim == 2:
      rhs *= w[:, np.newaxis]
    else:
      rhs *= w

  # scale lhs to improve condition number and solve
  scale = sqrt((lhs * lhs).sum(axis=0))
  lhs /= scale[np.newaxis, :]
  c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond)
  c = (c.T / scale).T  # broadcast scale coefficients

  if full:
    return c, resids, rank, s, rcond
  elif cov:
    Vbase = linalg.inv(dot(lhs.T, lhs))
    Vbase /= outer(scale, scale)
    if cov == "unscaled":
      fac = 1
    else:
      if len(x) <= order:
        raise ValueError("the number of data points must exceed order "
                         "to scale the covariance matrix")
      fac = resids / (len(x) - order)
      fac = fac[0]  # extract a scalar from the shape-(1,) residual array
    if y.ndim == 1:
      return c, Vbase * fac
    else:
      return c, Vbase[:, :, np.newaxis] * fac
  else:
    return c
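
# Usage sketch (illustrative): fitting a quadratic to noiseless samples
# recovers the coefficients, highest degree first, up to float32 round-off
# from the lstsq solve:
#   >>> x = jnp.arange(5.0)
#   >>> y = 2.0 * x**2 + 3.0 * x + 1.0
#   >>> jnp.polyfit(x, y, 2)   # approximately [2., 3., 1.]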