def _argminmax(is_min: bool, operand: TfVal, axes: Sequence[int], index_dtype: DType, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray):
  """Lowers lax argmin/argmax to the corresponding tf.math reduction.

  Picks tf.math.argmin when `is_min` else tf.math.argmax, computes the result
  in int32 (or int64 when the requested index dtype is wider than 32 bits),
  and finally casts to the TF equivalent of `index_dtype`.

  The following is known to diverge from JAX behavior for NaN.
  """
  # lax argmin/argmax reduce over exactly one axis.
  axis, = axes
  # TODO(phawkins): handle axes larger than 2^31.
  output_type = tf.int64 if dtypes.iinfo(index_dtype).bits > 32 else tf.int32
  tf_reduce = tf.math.argmin if is_min else tf.math.argmax
  raw_indices = tf_reduce(operand, axis=axis, output_type=output_type)
  return tf.cast(raw_indices, jax2tf._to_tf_dtype(index_dtype))
def _reduction_init_val(a, init_val): # This function uses np.* functions because lax pattern matches against the # specific concrete values of the reduction inputs. a_dtype = dtypes.canonicalize_dtype(dtypes.dtype(a)) if a_dtype == 'bool': return np.array(init_val > 0, dtype=a_dtype) try: return np.array(init_val, dtype=a_dtype) except OverflowError: assert dtypes.issubdtype(a_dtype, np.integer) sign, info = np.sign(init_val), dtypes.iinfo(a_dtype) return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)