def modf(x, out=None):
  """Split ``x`` elementwise into fractional and integral parts.

  The integral part is rounded toward zero: ``floor`` is used where
  ``x >= 0`` and ``ceil`` elsewhere. The fractional part is ``x`` minus that
  integral part, so for finite inputs it has the same sign as ``x`` (or is
  zero).

  Args:
    x: array-like input.
    out: not supported; must be None.

  Returns:
    A ``(fractional, integral)`` tuple.

  Raises:
    NotImplementedError: if ``out`` is not None.
  """
  _check_arraylike("modf", x)
  if out is not None:
    raise NotImplementedError(
        "The 'out' argument to jnp.modf is not supported.")
  nonnegative = lax.ge(x, lax_internal._zero(x))
  # Truncate toward zero: round down for non-negative values, up otherwise.
  integral = _where(nonnegative, floor(x), ceil(x))
  fractional = x - integral
  return fractional, integral
def _where(condition, x=None, y=None):
  """Select elements from ``x`` or ``y`` according to ``condition``.

  Only the three-argument form is handled here: both ``x`` and ``y`` must be
  supplied.

  Args:
    condition: array-like selector. If its dtype is not boolean, it is
      converted with ``condition != 0``.
    x: values taken where ``condition`` is true.
    y: values taken where ``condition`` is false.

  Returns:
    The selected array. ``x`` and ``y`` are first passed through
    ``_promote_dtypes``. Then ``condition``, ``x`` and ``y`` are passed
    through ``_broadcast_arrays`` before selecting.

  Raises:
    ValueError: if ``x`` or ``y`` is None.
  """
  if x is None or y is None:
    raise ValueError("Either both or neither of the x and y arguments should "
                     "be provided to jax.numpy.where, got {} and {}."
                     .format(x, y))
  if not np.issubdtype(_dtype(condition), np.bool_):
    # Non-boolean conditions select ``x`` wherever they are nonzero.
    condition = lax.ne(condition, lax_internal._zero(condition))
  x, y = _promote_dtypes(x, y)
  condition, x, y = _broadcast_arrays(condition, x, y)
  try:
    is_always_empty = core.is_empty_shape(np.shape(x))
  except Exception:  # can fail with dynamic shapes; fall back to a select.
    # Catch only ordinary errors: a bare ``except:`` would also swallow
    # KeyboardInterrupt / SystemExit / GeneratorExit.
    is_always_empty = False
  # When the broadcast shape is statically empty there is nothing to select,
  # so the promoted/broadcast ``x`` is returned without emitting a select.
  return x if is_always_empty else lax.select(condition, x, y)
def _reduce_window_sum_transpose_rule(cotangent, operand, *, window_dimensions,
                                      window_strides, padding, base_dilation,
                                      window_dilation):
  """Transpose rule for the windowed-sum reduction.

  The incoming ``cotangent`` is padded and then summed over windows again, so
  that the result has the shape of ``operand``.

  Args:
    cotangent: cotangent of the forward window-sum output.
    operand: the undefined (linear) primal input; only its ``aval.shape`` is
      read.
    window_dimensions, window_strides, padding, base_dilation,
    window_dilation: the parameters of the forward window sum.

  Returns:
    A one-element list holding the cotangent for ``operand``.
  """
  assert ad.is_undefined_primal(operand)
  in_shape = operand.aval.shape
  rank = len(in_shape)
  lhs_pads = convolution._conv_general_vjp_lhs_padding(
      in_shape, window_dimensions, window_strides, cotangent.shape, padding,
      base_dilation, window_dilation)
  # Interior padding of (stride - 1) re-inserts the positions that the forward
  # strides stepped over.
  pad_config = [(low, high, step - 1)
                for (low, high), step in zip(lhs_pads, window_strides)]
  padded = lax.pad(cotangent, lax._zero(cotangent), pad_config)
  unit_dilation = [1] * rank
  no_padding = [(0, 0)] * rank
  # The forward base dilation is passed in the window-strides position here:
  # it becomes the stride of the transposed window sum.
  transposed = _reduce_window_sum(padded, window_dimensions, base_dilation,
                                  no_padding,
                                  base_dilation=unit_dilation,
                                  window_dilation=window_dilation)
  assert transposed.shape == in_shape, (transposed.shape, in_shape)
  return [transposed]
def _select_and_scatter_add_impl(source, operand, *,
                                 select_prim, window_dimensions, window_strides,
                                 padding, expand_padding):
  """Implementation of select-and-scatter-add via ``_select_and_scatter``.

  Within each window, ``select_prim`` chooses a position in ``operand``. The
  ``source`` values are combined into those positions, starting from a zero
  accumulator. Boolean sources are combined with ``bitwise_or`` and all other
  dtypes with ``add``.

  If ``expand_padding`` is true, ``operand`` is explicitly padded with the
  selection identity instead of relying on implicit padding. The identity is
  the max identity when ``select_prim is lax.ge_p`` and the min identity
  otherwise. The padded region is sliced off the result afterwards.
  """
  dtype = source.dtype
  select = lambda a, b: select_prim.bind(a, b)
  combine = lax.bitwise_or if dtype == np.bool_ else lax.add

  if not expand_padding:
    return _select_and_scatter(
        operand, select, window_dimensions, window_strides, padding, source,
        lax._zero(operand), combine)

  unpadded_shape = operand.shape
  # Pad with a value that can never win the selection.
  if select_prim is lax.ge_p:
    fill = lax._get_max_identity(dtype)
  else:
    fill = lax._get_min_identity(dtype)
  padded_operand = lax.pad(operand, fill,
                           [(lo, hi, 0) for (lo, hi) in padding])
  zero_padding = [(0, 0) for _ in padding]
  scattered = _select_and_scatter(
      padded_operand, select, window_dimensions, window_strides, zero_padding,
      source, lax._zero(padded_operand), combine)
  # Drop the explicit padding so the result matches the original operand.
  starts = [lo for (lo, _) in padding]
  stops = [lo + size for ((lo, _), size) in zip(padding, unpadded_shape)]
  return slicing.slice(scattered, starts, stops)