Exemplo n.º 1
0
  def process_one_eqn(eqn: DimEquation) -> bool:
    """Try to solve one dimension equation for a single unknown variable.

    The equation ``eqn.poly == eqn.dim_expr`` is rewritten as
    ``var * factor_var = dim_expr`` (a linear, univariate equation) by
    substituting every monomial whose variables are already bound in
    ``shapeenv``. On success the solved value is stored in ``shapeenv[var]``.

    Args:
      eqn: the equation to process.

    Returns:
      True if the equation was solved for one new variable, or if it has no
      unknowns left. False if the rewrite fails: there is more than one
      unknown, or an unknown monomial is not a single variable.

    Raises:
      ValueError: when the relevant values are known constants (eager mode)
        and the solution is not an integer, is <= 0, or a fully resolved
        equation does not hold.
    """
    # The invariant is: var * factor_var + rest_eqn_poly = dim_expr
    var, factor_var = None, None
    dim_expr = eqn.dim_expr

    for mon, factor in eqn.poly.monomials():
      # Perhaps we can already evaluate this monomial (all vars solved).
      # Only the evaluation is inside the `try`. A KeyError raised by the
      # arithmetic below must not be mistaken for an unsolved variable.
      try:
        mon_value = mon.evaluate(shapeenv)
      except KeyError:
        # There are some indeterminate variables. We handle only the case of
        # a single, linear remaining indeterminate.
        v = mon.to_var()
        if v is not None and var is None:
          var = v
          factor_var = factor
          continue
        return False
      dim_expr = dim_expr - _multiply(mon_value, np.int32(factor))

    if var is None:
      # All variables are resolved for this equation; check consistency.
      dim_expr_int = _is_known_constant(dim_expr)
      if dim_expr_int is not None and dim_expr_int != 0:
        err_msg = (
            "Found inconsistency when solving "
            f"{eqn.poly} == {eqn.dim_expr}.{_shapeenv_to_str()}")
        raise ValueError(err_msg)
      return True

    if factor_var == 1:
      var_value, var_remainder = dim_expr, np.int32(0)
    else:
      var_value = lax.div(dim_expr, np.int32(factor_var))  # type: ignore
      var_remainder = lax.rem(dim_expr, np.int32(factor_var))  # type: ignore

    # Check that the division is even. Works only in eager mode, i.e. when
    # _is_known_constant yields a concrete value.
    var_remainder_int = _is_known_constant(var_remainder)
    if var_remainder_int is not None and var_remainder_int != 0:
      # TODO(necula): check even in graph mode, by embedding the checks in
      # the graph.
      dim_expr_int = _is_known_constant(dim_expr)
      # Guard against a non-constant dim_expr instead of calling int(None).
      found = (int(dim_expr_int) / factor_var  # type: ignore
               if dim_expr_int is not None else f"{dim_expr} / {factor_var}")
      msg = (f"Dimension variable {var} must have integer value >= 1. "
             f"Found value {found} when solving "
             f"{eqn.poly} == {eqn.dim_expr}.{_shapeenv_to_str()}")
      raise ValueError(msg)
    var_value_int = _is_known_constant(var_value)
    if var_value_int is not None and var_value_int <= 0:
      # The offending value appears once, in "Found value ...".
      msg = (f"Dimension variable {var} must have integer value >= 1. "
             f"Found value {int(var_value_int)} when solving "
             f"{eqn.poly} == {eqn.dim_expr}.{_shapeenv_to_str()}")
      raise ValueError(msg)

    shapeenv[var] = var_value
    return True
Exemplo n.º 2
0
def _atan2_taylor(primals_in, series_in):
  """Taylor-mode (jet) rule for ``atan2(x, y)``.

  The primal output comes from ``lax.atan2`` directly. For the higher-order
  terms, the quotient ``q = x / y`` is propagated with ``jet``. The series of
  ``1 / (1 + q**2)`` (the derivative of ``atan`` at ``q``) is then combined
  with the series of ``q`` through a coefficient recurrence.

  NOTE(review): the exact normalization depends on the helpers ``fact`` and
  ``_scale`` defined elsewhere; presumably this is the usual recurrence for
  ``v' = c * u'``. Confirm against their definitions.

  Returns:
    A pair ``(primal_out, series_out)``. ``series_out`` is a list with one
    entry per input series term.
  """
  x, y = primals_in
  angle = lax.atan2(x, y)

  # Series of the quotient q = x / y.
  q, q_series = jet(lax.div, primals_in, series_in)
  # Series of 1 / (1 + q**2).
  d_head, d_tail = jet(lambda t: lax.div(1, 1 + lax.square(t)), (q,), (q_series,))
  deriv_coefs = [d_head] + d_tail
  quot_coefs = [q] + q_series

  out_series = []
  for order in range(1, len(q_series) + 1):
    acc = sum(_scale(order, j) * deriv_coefs[order - j] * quot_coefs[j]
              for j in range(1, order + 1))
    out_series.append(fact(order - 1) * acc)
  return angle, out_series
Exemplo n.º 3
0
 def _reduce_chooser_taylor_rule(g):
     """Reduce one series coefficient ``g`` over ``axes``, masked and normalized.

     Multiplies ``g`` elementwise by ``location_indicators``, sums over
     ``axes`` and divides by ``counts``. All three names come from the
     enclosing scope.
     NOTE(review): presumably ``location_indicators`` marks the positions of
     the chosen (e.g. max/min) value and ``counts`` is their number, making
     this an average over ties. Confirm in the enclosing function.
     """
     masked = lax.mul(g, location_indicators)
     summed = lax._reduce_sum(masked, axes)
     return lax.div(summed, counts)
Exemplo n.º 4
0

def def_comp(prim, comp):
    """Register the jet rule for ``prim`` as ``jet`` applied to ``comp``.

    ``comp`` is a function built from simpler primitives that computes the
    same thing as ``prim``. The registered rule is ``partial(jet, comp)``, so
    calling it with ``(primals, series)`` runs ``jet(comp, primals, series)``.
    """
    composed_rule = partial(jet, comp)
    jet_rules[prim] = composed_rule


# Jet rules for these primitives are obtained by running `jet` through an
# equivalent composition of primitives that already have jet rules.
def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x**0.5)
def_comp(lax.rsqrt_p, lambda x: x**-0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def_comp(lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)))
def_comp(lax.erfc_p, lambda x: 1 - lax.erf(x))
# lax.rem takes the sign of the dividend (truncating division), so the
# quotient must be rounded toward zero. floor(x / y) would instead give
# Python-style modulo and disagree with lax.rem (in the primal and in the
# y-series) whenever x and y have opposite signs.
def_comp(lax.rem_p,
         lambda x, y: x - y * (lax.sign(x / y) * lax.floor(lax.abs(x / y))))
# clamp_p operands are ordered (min, operand, max), as in lax.clamp.
def_comp(lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b))


def _erf_inv_rule(primals_in, series_in):
    x, = primals_in
    series, = series_in

    u = [x] + series
    primal_out = lax.erf_inv(x)
    v = [primal_out] + [None] * len(series)

    # derivative on co-domain for caching purposes
    deriv_const = np.sqrt(np.pi) / 2.