def process_one_eqn(eqn: DimEquation) -> bool:
  """Try to solve one equation for its single remaining dimension variable.

  Rewrites ``eqn`` as the linear uni-variate equation
  ``var * factor_var = dim_expr`` by subtracting from ``eqn.dim_expr`` every
  monomial that can already be evaluated under ``shapeenv`` (a name resolved
  from the enclosing scope). The loop maintains the invariant
  ``var * factor_var + rest_eqn_poly = dim_expr``.

  Args:
    eqn: the equation ``eqn.poly == eqn.dim_expr`` to process.

  Returns:
    False if the rewrite fails: some unevaluable monomial is not a single
    variable (``mon.to_var()`` is None), or more than one unknown variable
    remains. True otherwise: either the unknown variable was solved and its
    value stored in ``shapeenv``, or no unknowns were left.

  Raises:
    ValueError: when the values involved are known constants and the solved
      variable is not an integer, is <= 0, or a fully-resolved equation does
      not hold.
  """
  # The invariant is: var * factor_var + rest_eqn_poly = dim_expr
  var, factor_var = None, None
  dim_expr = eqn.dim_expr

  for mon, factor in eqn.poly.monomials():
    # Perhaps we can already evaluate this monomial (all vars solved).
    try:
      mon_value = mon.evaluate(shapeenv)
    except KeyError:
      # There are some indeterminate variables. We handle only the case of
      # a single, linear remaining indeterminate.
      v = mon.to_var()
      if v is not None and var is None:
        var = v
        factor_var = factor
        continue
      return False
    else:
      dim_expr = dim_expr - _multiply(mon_value, np.int32(factor))

  if var is not None:
    if factor_var == 1:
      var_value, var_remainder = dim_expr, np.int32(0)
    else:
      var_value = lax.div(dim_expr, np.int32(factor_var))  # type: ignore
      var_remainder = lax.rem(dim_expr, np.int32(factor_var))  # type: ignore

    # Check that the division is even. Works only in eager mode.
    var_remainder_int = _is_known_constant(var_remainder)
    if var_remainder_int is not None and var_remainder_int != 0:
      # TODO(necula): check even in graph mode, by embedding the checks in
      # the graph.
      msg = (f"Dimension variable {var} must have integer value >= 1. "  # type: ignore
             f"Found value {int(_is_known_constant(dim_expr)) / factor_var} when solving "
             f"{eqn.poly} == {eqn.dim_expr}.{_shapeenv_to_str()}")
      raise ValueError(msg)
    var_value_int = _is_known_constant(var_value)
    if var_value_int is not None and var_value_int <= 0:
      # Same message format as the non-integer case above; the value is
      # reported once, after "Found value".
      msg = (f"Dimension variable {var} must have integer value >= 1. "
             f"Found value {int(var_value_int)} when solving "
             f"{eqn.poly} == {eqn.dim_expr}.{_shapeenv_to_str()}")
      raise ValueError(msg)

    shapeenv[var] = var_value
    return True
  else:
    # All variables are resolved for this equation: it must hold exactly.
    dim_expr_int = _is_known_constant(dim_expr)
    if dim_expr_int is not None and dim_expr_int != 0:
      err_msg = (
          "Found inconsistency when solving "
          f"{eqn.poly} == {eqn.dim_expr}.{_shapeenv_to_str()}")
      raise ValueError(err_msg)
    return True
def _atan2_taylor(primals_in, series_in):
  """Taylor (jet) rule for ``lax.atan2``.

  Uses ``atan2(x, y) = atan(x / y)``. The series of the quotient
  ``u = x / y`` is obtained by jetting ``lax.div``. The series of
  ``c = 1 / (1 + u**2)``, the derivative of ``atan`` at ``u``, is obtained by
  jetting that expression. The output coefficients then follow from
  ``v' = c * u'`` through the recurrence in the loop below.

  Args:
    primals_in: pair ``(x, y)`` of primal operands, as passed to ``lax.atan2``.
    series_in: pair of Taylor-coefficient lists, one per operand.

  Returns:
    ``(primal_out, series_out)`` with ``primal_out = lax.atan2(x, y)``.
  """
  x, y = primals_in
  primal_out = lax.atan2(x, y)

  quot, quot_series = jet(lax.div, primals_in, series_in)
  # Python `/` promotes the constant 1 to the operand's dtype; `lax.div(1, t)`
  # would pass an int alongside a float, and `lax` ops do no implicit
  # dtype promotion.
  c0, cs = jet(lambda t: 1 / (1 + lax.square(t)), (quot, ), (quot_series, ))
  c = [c0] + cs
  u = [quot] + quot_series
  v = [primal_out] + [None] * len(quot_series)
  for k in range(1, len(v)):
    v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))
  primal_out, *series_out = v
  return primal_out, series_out
def _reduce_chooser_taylor_rule(g):
  """Reduce one series coefficient ``g`` over ``axes`` under a mask.

  Multiplies ``g`` elementwise by ``location_indicators``, sums the product
  over ``axes`` and divides the result by ``counts``. All three names are
  free variables resolved from the enclosing scope. Presumably
  ``location_indicators`` flags the positions picked by the reduction and
  ``counts`` is how many were picked per output element, which would make
  this an average over ties. TODO confirm against the enclosing rule.
  """
  masked = lax.mul(g, location_indicators)
  reduced = lax._reduce_sum(masked, axes)
  return lax.div(reduced, counts)
def def_comp(prim, comp):
  """Register a jet rule for ``prim`` expressed as a composition.

  The rule stored in ``jet_rules[prim]`` is ``partial(jet, comp)``. Taylor
  propagation through ``prim`` is therefore done by running ``jet`` on
  ``comp``, which should compute the same function out of simpler primitives.

  Args:
    prim: the primitive whose jet rule is being defined.
    comp: a callable taking the primitive's operands and returning its output.
  """
  jet_rules[prim] = partial(jet, comp)


def _trunc(x):
  """Round ``x`` toward zero, computed as ``sign(x) * floor(|x|)``."""
  return lax.sign(x) * lax.floor(lax.abs(x))


def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x**0.5)
def_comp(lax.rsqrt_p, lambda x: x**-0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def_comp(lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)))
def_comp(lax.erfc_p, lambda x: 1 - lax.erf(x))
# lax.rem follows C fmod semantics: the result takes the sign of the dividend,
# so the implied quotient is x / y truncated toward zero. Using floor instead
# (Python `%` semantics) gives the wrong value, and the wrong coefficient for
# y, whenever x / y is negative and not an integer.
def_comp(lax.rem_p, lambda x, y: x - y * _trunc(x / y))
def_comp(lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b))


def _erf_inv_rule(primals_in, series_in):
  x, = primals_in
  series, = series_in

  u = [x] + series
  primal_out = lax.erf_inv(x)
  v = [primal_out] + [None] * len(series)

  # derivative on co-domain for caching purposes
  deriv_const = np.sqrt(np.pi) / 2.