Example #1
0
 def testResultTypeWeakFlag(self):
     """Check dtypes.result_type with and without return_weak_type_flag.

     An array built from a Python float is expected to report weak_type=True,
     while its astype() copy in the canonical float dtype reports False; both
     report the same dtype.
     """
     canonical_float = dtypes.canonicalize_dtype(dtypes.float_)
     weak_value = jnp.array(1.)
     strong_value = weak_value.astype(canonical_float)
     for value, expected_weak in ((weak_value, True), (strong_value, False)):
         self.assertEqual(dtypes.result_type(value), canonical_float)
         self.assertEqual(
             dtypes.result_type(value, return_weak_type_flag=True),
             (canonical_float, expected_weak))
Example #2
0
def _ndarray_constant_handler(c, val, canonicalize_types=True):
    """Constant handler for ndarray literals, handling zero-size strides.

    This function essentially calls _numpy_array_constant(val) except it has
    special handling of arrays with any strides of size zero: for those, it
    generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose
    to avoid staging in large literals that might arise from np.zeros or np.ones
    or the output of lax.broadcast (which uses np.broadcast_to which in turn
    uses size-zero strides). Arrays whose dtype is float0 are staged as an
    all-False boolean constant of the same shape.

    Args:
      c: an XlaBuilder
      val: an ndarray.
      canonicalize_types: forwarded to ``_numpy_array_constant`` on the
        zero-stride and default paths (the float0 path uses that helper's
        default).

    Returns:
      A sequence of XLA ComputationDataHandle / XlaOp values representing the
      constant ndarray staged into the XLA Computation (a one-element list on
      the zero-stride path).
    """
    # TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose
    if dtypes.result_type(val) == dtypes.float0:
        # float0 arrays are staged as bool zeros with the same shape.
        return _numpy_array_constant(c, np.zeros(val.shape, dtype=np.bool_))
    elif np.any(np.equal(0, val.strides)) and val.size > 0:
        # Only non-empty arrays take this path; presumably because element 0 is
        # indexed along every zero-stride axis below.
        # Stride-0 axes repeat identical data; the other axes hold distinct data.
        zero_stride_axes, = np.where(np.equal(0, val.strides))
        other_axes, = np.where(np.not_equal(0, val.strides))
        # Keep element 0 along each zero-stride axis and full slices elsewhere,
        # so only the distinct data is staged as a literal.
        collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
                                  for ax in range(val.ndim))]
        # Broadcast re-expands the zero-stride axes; XLA's Broadcast adds the
        # new dimensions in front of the operand's existing ones...
        xla_val = xops.Broadcast(
            _numpy_array_constant(c, collapsed_val, canonicalize_types)[0],
            np.take(val.shape, zero_stride_axes))
        # ...so transpose them back into their original axis positions.
        permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes))
        return [xops.Transpose(xla_val, permutation)]
    else:
        return _numpy_array_constant(c, val, canonicalize_types)
Example #3
0
def normalize_to_xla_dtypes(val):
  """Normalize dtypes in a value.

  Array-likes (anything with ``__array__``) and NumPy scalars are converted
  with ``np.asarray`` to the canonicalized ``dtypes.result_type`` of the value.
  Tuples and lists are normalized element-wise and returned as a tuple. Any
  other value raises ``TypeError``.
  """
  looks_like_array = hasattr(val, '__array__') or np.isscalar(val)
  if looks_like_array:
    target_dtype = dtypes.canonicalize_dtype(dtypes.result_type(val))
    return np.asarray(val, dtype=target_dtype)
  if isinstance(val, (tuple, list)):
    return tuple(map(normalize_to_xla_dtypes, val))
  raise TypeError(f"Can't convert to XLA: {val}")
Example #4
0
def _ravel_list(lst):
  if not lst: return jnp.array([], jnp.float32), lambda _: []
  from_dtypes = [dtypes.dtype(l) for l in lst]
  to_dtype = dtypes.result_type(*from_dtypes)
  sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst)
  indices = np.cumsum(sizes)

  if all(dt == to_dtype for dt in from_dtypes):
    # Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`.
    # See https://github.com/google/jax/issues/7809.
    del from_dtypes, to_dtype
    def unravel(arr):
      chunks = jnp.split(arr, indices[:-1])
      return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
    raveled = jnp.concatenate([jnp.ravel(e) for e in lst])
    return raveled, unravel

  # When there is more than one distinct input dtype, we perform type
  # conversions and produce a dtype-specific unravel function.
  def unravel(arr):
    arr_dtype = dtypes.dtype(arr)
    if arr_dtype != to_dtype:
      raise TypeError(f"unravel function given array of dtype {arr_dtype}, "
                      f"but expected dtype {to_dtype}")
    chunks = jnp.split(arr, indices[:-1])
    with warnings.catch_warnings():
      warnings.simplefilter("ignore")  # ignore complex-to-real cast warning
      return [lax.convert_element_type(chunk.reshape(shape), dtype)
              for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)]

  ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype))
  raveled = jnp.concatenate([ravel(e) for e in lst])
  return raveled, unravel
Example #5
0
 def val_to_typecode(val):
     """Map ``val`` to its typecode string via ``dtype_to_typecode``.

     For weakly typed values the last character of the typecode is replaced
     by ``'*'``.
     """
     result_dtype = dtypes.result_type(val)
     is_weak = dtypes.is_weakly_typed(val)
     base_code = dtype_to_typecode[result_dtype]
     return base_code[:-1] + '*' if is_weak else base_code
Example #6
0
File: mlir.py Project: rsepassi/jax
def _ndarray_constant_handler(val: np.ndarray,
                              canonicalize_types) -> Sequence[ir.Value]:
    """Constant handler for ndarray literals, handling zero-size strides.

    In most cases this function calls _numpy_array_constant(val) except it has
    special handling of arrays with any strides of size zero: for those, it
    stages only the non-repeated data and expands it with a single
    mhlo.BroadcastInDimOp, to avoid staging in large literals that might arise
    from np.zeros or np.ones or the output of lax.broadcast (which uses
    np.broadcast_to which in turn uses size-zero strides). Arrays whose dtype
    is float0 are staged as an all-False boolean constant of the same shape.

    Args:
      val: an ndarray.
      canonicalize_types: forwarded to ``_numpy_array_constant`` on the
        zero-stride and default paths; the float0 path passes False.

    Returns:
      A sequence of MLIR ``ir.Value`` objects representing the constant
      ndarray (a one-element tuple on the zero-stride path).
    """
    if dtypes.result_type(val) == dtypes.float0:
        # float0 arrays are staged as bool zeros with the same shape.
        return _numpy_array_constant(np.zeros(val.shape, dtype=np.bool_),
                                     canonicalize_types=False)
    elif np.any(np.equal(0, val.strides)) and val.size > 0:
        # Stride-0 axes repeat identical data; the other axes hold distinct data.
        zero_stride_axes, = np.where(np.equal(0, val.strides))
        other_axes, = np.where(np.not_equal(0, val.strides))
        # Keep element 0 along each zero-stride axis and full slices elsewhere.
        collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
                                  for ax in range(val.ndim))]
        # The result type is the full shape/dtype of `val`; broadcast dimensions
        # `other_axes` place each dim of `collapsed_val` at its original axis,
        # and the zero-stride axes are filled in by broadcasting.
        out = mhlo.BroadcastInDimOp(
            aval_to_ir_type(xla.abstractify(val)),
            _numpy_array_constant(collapsed_val, canonicalize_types)[0],
            dense_int_elements(other_axes)).result
        return (out, )
    else:
        return _numpy_array_constant(val, canonicalize_types)
Example #7
0
def div(lhs, rhs):
  """Divide ``lhs`` by ``rhs``, truncating toward zero for integer inputs.

  Non-integer inputs (by ``dtypes.result_type(lhs)``) use ``np.divide``.
  """
  if not dtypes.issubdtype(dtypes.result_type(lhs), np.integer):
    return np.divide(lhs, rhs)
  floored = np.floor_divide(lhs, rhs)
  # floor_divide rounds toward -inf; when the operands' signs differ and the
  # division is inexact, add one so the quotient is truncated toward zero.
  needs_bump = np.logical_and(np.sign(lhs) != np.sign(rhs),
                              np.remainder(lhs, rhs) != 0)
  return np.where(needs_bump, floored + 1, floored)
Example #8
0
 def testUnaryPromotion(self, dtype, weak_type):
     """Regression test for https://github.com/google/jax/issues/6051.

     A strongly typed value keeps its own dtype under result_type. A weakly
     typed one maps to the canonical default type for its kind, with bfloat16
     treated as kind 'f'.
     """
     x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
     expected = x.dtype
     if weak_type:
         kind = 'f' if x.dtype == 'bfloat16' else x.dtype.kind
         expected = dtypes.canonicalize_dtype(dtypes._default_types[kind])
     self.assertEqual(dtypes.result_type(x), expected)
Example #9
0
def _make_reducer(py_binop, init_val):
    """Make a reducer function given a Python binop and an initial value.

    If ``py_binop``'s ``__name__`` has an entry in ``_monoids`` and
    ``init_val`` equals that monoid's identity for ``init_val``'s result dtype,
    the registered reducer is returned. Otherwise, including when
    ``py_binop`` has no ``__name__`` (for example a ``functools.partial`` or a
    callable object), a reducer is built with
    ``_reducer_from_pyfunc(py_binop, init_val)``.
    """
    # It's tempting to use np.ufunc.reduce (even with a ufunc generated by
    # np.frompyfunc(py_binop)), but this may not agree with custom init_val.
    # We make an attempt to uncover an underlying numpy ufunc (which might be
    # wrapped by autograd or lax) and check its identity against init_val.
    # Default to None so nameless callables fall through to the generic path
    # instead of raising AttributeError.
    monoid_record = _monoids.get(getattr(py_binop, '__name__', None))
    if monoid_record:
        reducer, monoid_identity = monoid_record
        if init_val == monoid_identity(dtypes.result_type(init_val)):
            return reducer
    return _reducer_from_pyfunc(py_binop, init_val)
Example #10
0
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode, normalize_indices):
    """Scatter ``y`` into ``x`` at an index expression using ``scatter_op``.

    Args:
      x: operand array being updated.
      y: update values; broadcast to the indexed slice shape.
      scatter_op: a lax scatter function, called as
        ``scatter_op(x, indices, y, dnums, indices_are_sorted=...,
        unique_indices=..., mode=...)``.
      treedef, static_idx, dynamic_idx: pieces recombined into the index
        expression by ``jnp._merge_static_and_dynamic_indices``.
      indices_are_sorted: hint OR-ed with the indexer's own sortedness flag.
      unique_indices: hint OR-ed with the indexer's own uniqueness flag.
      mode: forwarded to ``scatter_op``.
      normalize_indices: forwarded to ``jnp._index_to_gather``.

    Returns:
      The scattered result converted back to ``x``'s original dtype and weak
      type; ``x`` itself is returned unchanged when the slice shape is empty.

    Warns:
      FutureWarning: if ``dtypes.result_type(x, y)`` differs from ``x``'s dtype.
    """
    # Remember x's dtype and weak type so the result can be cast back at the end.
    dtype = lax.dtype(x)
    weak_type = dtypes.is_weakly_typed(x)

    if dtype != dtypes.result_type(x, y):
        # TODO(jakevdp): change this to an error after the deprecation period.
        warnings.warn(
            "scatter inputs have incompatible types: cannot safely cast "
            f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. "
            "In future JAX releases this will result in an error.",
            FutureWarning)

    idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx,
                                                dynamic_idx)
    indexer = jnp._index_to_gather(jnp.shape(x),
                                   idx,
                                   normalize_indices=normalize_indices)

    # Avoid calling scatter if the slice shape is empty, both as a fast path and
    # to handle cases like zeros(0)[array([], int32)].
    if core.is_empty_shape(indexer.slice_shape):
        return x

    x, y = jnp._promote_dtypes(x, y)

    # Broadcast `y` to the slice output shape.
    y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
    # Collapse any `None`/`jnp.newaxis` dimensions.
    y = jnp.squeeze(y, axis=indexer.newaxis_dims)
    if indexer.reversed_y_dims:
        y = lax.rev(y, indexer.reversed_y_dims)

    # Transpose the gather dimensions into scatter dimensions (cf.
    # lax._gather_transpose_rule)
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=indexer.dnums.offset_dims,
        inserted_window_dims=indexer.dnums.collapsed_slice_dims,
        scatter_dims_to_operand_dims=indexer.dnums.start_index_map)
    out = scatter_op(x,
                     indexer.gather_indices,
                     y,
                     dnums,
                     indices_are_sorted=indexer.indices_are_sorted
                     or indices_are_sorted,
                     unique_indices=indexer.unique_indices or unique_indices,
                     mode=mode)
    # Undo the promotion above so the output matches x's original type.
    return lax_internal._convert_element_type(out, dtype, weak_type)
Example #11
0
def _ravel_list(lst):
  if not lst: return jnp.array([], jnp.float32), lambda _: []
  from_dtypes = [dtypes.dtype(l) for l in lst]
  to_dtype = dtypes.result_type(*from_dtypes)
  sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst)
  indices = np.cumsum(sizes)

  def unravel(arr):
    chunks = jnp.split(arr, indices[:-1])
    with warnings.catch_warnings():
      warnings.simplefilter("ignore")  # ignore complex-to-real cast warning
      return [lax.convert_element_type(chunk.reshape(shape), dtype)
              for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)]

  ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype))
  raveled = jnp.concatenate([ravel(e) for e in lst])
  return raveled, unravel
Example #12
0
def _promote_to_complex(arg):
  dtype = dtypes.result_type(arg, np.complex64)
  return lax.convert_element_type(arg, dtype)
Example #13
0
def switch(index,
           branches: Sequence[Callable],
           *operands,
           operand=_no_operand_sentinel):
    """Apply exactly one of ``branches`` given by ``index``.

    If ``index`` is out of bounds, it is clamped to within bounds.

    Has the semantics of the following Python::

      def switch(index, branches, *operands):
        index = clamp(0, index, len(branches) - 1)
        return branches[index](*operands)

    Args:
      index: Integer scalar type, indicating which branch function to apply.
      branches: Sequence of functions (A -> B) to be applied based on ``index``.
      operands: Operands (A) input to whichever branch is applied.
      operand: Alternative keyword form for passing a single operand; cannot
        be combined with positional ``operands``.

    Returns:
      Value (B) of ``branch(*operands)`` for the branch that was selected based
      on ``index``.

    Raises:
      TypeError: if any branch is not callable, if both ``operand`` and
        positional operands are given, or if ``index`` is not a scalar of an
        integer (signed or unsigned) dtype.
      ValueError: if ``branches`` is empty.
      NotImplementedError: if the branches have effects outside
        ``allowed_effects``.
    """
    # NOTE(review): `branches` is iterated here and again by tuple() below, so a
    # one-shot iterator would be exhausted; callers must pass a real Sequence.
    if not all(callable(branch) for branch in branches):
        raise TypeError(
            "lax.switch: branches argument should be a sequence of callables.")
    # Fold the legacy `operand=` keyword into the positional operands tuple.
    if operand is not _no_operand_sentinel:
        if operands:
            raise TypeError(
                "if 'operand' keyword is passed then no positional "
                f"operands can be passed, got operand={operand} "
                f"and positional operands {operands}")
        operands = (operand, )
    del operand

    if len(np.shape(index)) != 0:
        raise TypeError(f"Branch index must be scalar, "
                        f"got {index} of shape {np.shape(index)}.")

    try:
        index_dtype = dtypes.result_type(index)
    except TypeError as err:
        msg = f"Index type must be an integer, got {index}."
        raise TypeError(msg) from err

    if index_dtype.kind not in 'iu':
        raise TypeError(
            f"Index type must be an integer, got {index} as {index_dtype}")

    branches = tuple(branches)

    if len(branches) == 0:
        raise ValueError("Empty branch sequence")
    elif len(branches) == 1:
        # Only one possible choice: call it directly without staging a cond.
        return branches[0](*operands)

    # Clamp to [0, len(branches) - 1]; this implements the documented
    # out-of-bounds behavior.
    index = lax.convert_element_type(index, np.int32)
    lo = np.array(0, np.int32)
    hi = np.array(len(branches) - 1, np.int32)
    index = lax.clamp(lo, index, hi)

    # With JIT disabled and a concrete index, run the selected branch eagerly.
    if (config.jax_disable_jit
            and isinstance(core.get_aval(index), ConcreteArray)):
        return branches[int(index)](*operands)

    ops, ops_tree = tree_flatten(operands)
    ops_avals = tuple(_map(_abstractify, ops))

    # Trace every branch against the same operand avals, sharing constants.
    jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
        branches, ops_tree, ops_avals, primitive_name='switch')
    # Each branch's output tree and avals are checked against branch 0's.
    for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
        _check_tree_and_avals(f"branch 0 and {i + 1} outputs", out_trees[0],
                              jaxprs[0].out_avals, out_tree, jaxpr.out_avals)
    joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
    disallowed_effects = joined_effects - allowed_effects
    if disallowed_effects:
        raise NotImplementedError(
            f'Effects not supported in `switch`: {disallowed_effects}')

    # No constant or operand is marked linear for switch.
    linear = (False, ) * (len(consts) + len(ops))
    out = cond_p.bind(index,
                      *consts,
                      *ops,
                      branches=tuple(jaxprs),
                      linear=linear)
    return tree_unflatten(out_trees[0], out)
Example #14
0
def zeros_like_array(x):
    """Return ``ad_util.zeros_like_aval`` for a ShapedArray describing ``x``.

    The ShapedArray has ``x``'s shape and the canonicalized
    ``dtypes.result_type`` of ``x``.
    """
    canonical = dtypes.canonicalize_dtype(dtypes.result_type(x))
    return ad_util.zeros_like_aval(ShapedArray(np.shape(x), canonical))
Example #15
0
 def testResultTypeNone(self):
     """result_type(None) is the canonical default float dtype.

     This mirrors NumPy, where np.result_type(None) gives np.float_.
     """
     actual = dtypes.result_type(None)
     self.assertEqual(actual, dtypes.canonicalize_dtype(dtypes.float_))
Example #16
0
def make_shaped_array(x):
  """Describe ``x`` as a ShapedArray with its shape and canonical dtype."""
  canonical = dtypes.canonicalize_dtype(dtypes.result_type(x))
  shape = np.shape(x)
  return ShapedArray(shape, canonical)
Example #17
0
 def testUnaryPromotion(self, dtype, weak_type):
   """Regression test for https://github.com/google/jax/issues/6051.

   Building an array with the dtype result_type reports for a converted
   scalar must reproduce that scalar's dtype.
   """
   converted = lax._convert_element_type(0, dtype, weak_type=weak_type)
   promoted_dtype = dtypes.result_type(converted)
   rebuilt = jnp.array(0, dtype=promoted_dtype)
   assert converted.dtype == rebuilt.dtype
Example #18
0
def _canonicalize_ndarray_dtype(x):
    return np.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x)))
Example #19
0
def zeros_like_array(x):
  """Return ``zeros_like_shaped_array`` for a ShapedArray describing ``x``.

  The ShapedArray has ``x``'s shape and the canonicalized
  ``dtypes.result_type`` of ``x``.
  """
  canonical = dtypes.canonicalize_dtype(dtypes.result_type(x))
  aval = ShapedArray(np.shape(x), canonical)
  return zeros_like_shaped_array(aval)
Example #20
0
def _cond(pred,
          true_fun: Callable,
          false_fun: Callable,
          *operands,
          operand=_no_operand_sentinel,
          linear=None):
    """Conditionally apply ``true_fun`` or ``false_fun``.

    ``cond()`` has equivalent semantics to this Python implementation::

      def cond(pred, true_fun, false_fun, *operands):
        if pred:
          return true_fun(*operands)
        else:
          return false_fun(*operands)

    ``pred`` must be a scalar type.

    Args:
      pred: Scalar indicating which branch function to apply. Boolean values
        are used directly. Integer, unsigned, or floating-point values are
        converted with ``pred != 0``. Any other dtype raises ``TypeError``.
      true_fun: Function (A -> B), to be applied if ``pred`` is True.
      false_fun: Function (A -> B), to be applied if ``pred`` is False.
      operands: Operands (A) input to either branch depending on ``pred``. The
        type can be a scalar, array, or any pytree (nested Python tuple/list/dict)
        thereof.
      operand: Alternative keyword form for passing a single operand; cannot
        be combined with positional ``operands``.
      linear: Optional pytree (presumably of booleans) with the same structure
        as ``operands``, marking which flattened operands are linear. Defaults
        to all False.

    Returns:
      Value (B) of either ``true_fun(*operands)`` or ``false_fun(*operands)``,
      depending on the value of ``pred``. The type can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof.

    Raises:
      TypeError: if a branch is not callable, if both ``operand`` and
        positional operands are given, if ``pred`` is a Sequence, non-scalar,
        or of an unsupported dtype, or if ``linear``'s tree structure differs
        from that of ``operands``.
      NotImplementedError: if the branches have effects outside
        ``allowed_effects``.
    """
    if not (callable(true_fun) and callable(false_fun)):
        raise TypeError(
            "lax.cond: true_fun and false_fun arguments should be callable.")
    # Fold the legacy `operand=` keyword into the positional operands tuple.
    if operand is not _no_operand_sentinel:
        if operands:
            raise TypeError(
                "if 'operand' keyword is passed then no positional "
                f"operands can be passed, got operand={operand} "
                f"and positional operands {operands}")
        operands = (operand, )
    del operand

    # Reject Python sequences explicitly as well as non-scalar arrays.
    if isinstance(pred, Sequence) or np.ndim(pred) != 0:
        raise TypeError(f"Pred must be a scalar, got {pred} of " +
                        (f"type {type(pred)}" if isinstance(pred, Sequence)
                         else f"shape {np.shape(pred)}."))

    try:
        pred_dtype = dtypes.result_type(pred)
    except TypeError as err:
        msg = ("Pred type must be either boolean or number, got {}.")
        raise TypeError(msg.format(pred)) from err

    # Numeric predicates are converted to booleans with `!= 0`; other
    # non-boolean kinds (e.g. complex) are rejected.
    if pred_dtype.kind != 'b':
        if pred_dtype.kind in 'iuf':
            pred = pred != 0
        else:
            msg = ("Pred type must be either boolean or number, got {}.")
            raise TypeError(msg.format(pred_dtype))

    # With JIT disabled and a concrete predicate, run the chosen branch eagerly.
    if config.jax_disable_jit and isinstance(core.get_aval(pred),
                                             ConcreteArray):
        if pred:
            return true_fun(*operands)
        else:
            return false_fun(*operands)

    ops, ops_tree = tree_flatten(operands)
    if linear is None:
        linear_ops = [False] * len(ops)
    else:
        linear_ops, ops_tree2 = tree_flatten(linear)
        if ops_tree != ops_tree2:
            raise TypeError('linear tree and operand tree mismatch')
    ops_avals = tuple(_map(_abstractify, ops))

    # Trace both branches against the same operand avals, sharing constants.
    jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
        (true_fun, false_fun), ops_tree, ops_avals, 'cond')
    true_jaxpr, false_jaxpr = jaxprs
    out_tree, false_out_tree = out_trees

    _check_tree_and_avals("true_fun and false_fun output", out_tree,
                          true_jaxpr.out_avals, false_out_tree,
                          false_jaxpr.out_avals)
    joined_effects = core.join_effects(true_jaxpr.effects, false_jaxpr.effects)
    disallowed_effects = joined_effects - allowed_effects
    if disallowed_effects:
        raise NotImplementedError(
            f'Effects not supported in `cond`: {disallowed_effects}')

    # The boolean pred becomes an int32 branch index: False -> 0, True -> 1.
    index = lax.convert_element_type(pred, np.int32)

    # Shared constants are never linear; operands follow `linear_ops`.
    linear = [False] * len(consts) + linear_ops
    # Branch order matches the index above: 0 -> false_jaxpr, 1 -> true_jaxpr.
    out = cond_p.bind(index,
                      *consts,
                      *ops,
                      branches=(false_jaxpr, true_jaxpr),
                      linear=tuple(linear))
    return tree_unflatten(out_tree, out)
Example #21
0
def _promote_to_real(arg):
  dtype = dtypes.result_type(arg, np.float32)
  return lax.convert_element_type(arg, dtype)
Example #22
0
def _dtype(x):
  try:
    return dtypes.result_type(x)
  except ValueError:
    return dtypes.result_type(getattr(x, 'dtype'))
Example #23
0
def fmod(x1, x2):
    """Compute ``lax.rem`` of ``x1`` and ``x2`` after type promotion.

    For integer inputs (by ``dtypes.result_type(x1, x2)``), zero divisors in
    ``x2`` are first replaced with ones.
    """
    _check_arraylike("fmod", x1, x2)
    integral = dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer)
    divisor = _where(x2 == 0, lax_internal._ones(x2), x2) if integral else x2
    return lax.rem(*_promote_args("fmod", x1, divisor))