Ejemplo n.º 1
0
def _argminmax(is_min: bool, operand: TfVal, axes: Sequence[int],
               index_dtype: DType, _in_avals: Sequence[core.ShapedArray],
               _out_aval: core.ShapedArray):
    """Lower a JAX argmin/argmax reduction onto TensorFlow.

    Args:
      is_min: select ``tf.math.argmin`` when true, ``tf.math.argmax`` otherwise.
      operand: the TF value being reduced.
      axes: the reduction axes; exactly one axis is expected (unpacking fails
        otherwise).
      index_dtype: the JAX dtype requested for the returned indices.
      _in_avals, _out_aval: abstract values supplied by the caller; unused here.

    Returns:
      The TF index tensor, cast to the TF equivalent of ``index_dtype``.
    """
    # NOTE: known to diverge from JAX behavior when the operand contains NaN.
    (reduction_axis,) = axes
    # TF computes indices in int32 unless the requested index type is wider.
    needs_wide_index = dtypes.iinfo(index_dtype).bits > 32
    tf_index_type = tf.int64 if needs_wide_index else tf.int32
    # TODO(phawkins): handle axes larger than 2^31.
    reducer = tf.math.argmin if is_min else tf.math.argmax
    indices = reducer(operand, axis=reduction_axis, output_type=tf_index_type)
    return tf.cast(indices, jax2tf._to_tf_dtype(index_dtype))
Ejemplo n.º 2
0
def _reduction_init_val(a, init_val):
    """Return ``init_val`` as a NumPy array in the canonical dtype of ``a``.

    NumPy (``np.*``) is used on purpose: lax pattern matches against the
    specific concrete values of the reduction inputs.

    * Boolean dtype: the result is ``init_val > 0``.
    * Otherwise ``init_val`` is converted directly. If that conversion raises
      ``OverflowError`` (e.g. an infinite float into an integer dtype), the
      dtype must be an integer type. The value is then clamped to the dtype's
      minimum when ``init_val`` is negative, and to its maximum otherwise.
    """
    target = dtypes.canonicalize_dtype(dtypes.dtype(a))
    if target == 'bool':
        return np.array(init_val > 0, dtype=target)
    try:
        converted = np.array(init_val, dtype=target)
    except OverflowError:
        # Only integer targets are expected to overflow on conversion.
        assert dtypes.issubdtype(target, np.integer)
        direction = np.sign(init_val)
        limits = dtypes.iinfo(target)
        saturated = limits.min if direction < 0 else limits.max
        return np.array(saturated, dtype=target)
    return converted