def make_tensorspec(a_jax):
  """Builds a tf.TensorSpec mirroring the shape and dtype of `a_jax`.

  Args:
    a_jax: a JAX abstract value (has `.shape` and `.dtype`).

  Returns:
    A `tf.TensorSpec` with the same shape and the corresponding TF dtype.

  Raises:
    ValueError: if any dimension of `a_jax.shape` is shape-polymorphic,
      which `call_tf` does not support.
  """
  tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype)
  # call_tf needs fully static shapes to construct the TensorSpec.
  if not all(core.is_constant_dim(dim) for dim in a_jax.shape):
    raise ValueError(
        "call_tf cannot be applied to shape-polymorphic arguments. "
        f"Found argument shape: {a_jax.shape}. "
        "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.")
  return tf.TensorSpec(a_jax.shape, tf_dtype)
def _argminmax(is_min: bool, operand: TfVal, axes: Sequence[int],
               index_dtype: DType,
               _in_avals: Sequence[core.ShapedArray],
               _out_aval: core.ShapedArray):
  """Lowers lax argmin/argmax to `tf.math.argmin`/`tf.math.argmax`.

  The following is known to diverge from JAX behavior for NaN.
  """
  axis, = axes
  # TF only emits int32/int64 indices; widen only when the requested
  # index dtype does not fit in 32 bits.
  out_type = tf.int64 if dtypes.iinfo(index_dtype).bits > 32 else tf.int32
  # TODO(phawkins): handle axes larger than 2^31.
  if is_min:
    raw_indices = tf.math.argmin(operand, axis=axis, output_type=out_type)
  else:
    raw_indices = tf.math.argmax(operand, axis=axis, output_type=out_type)
  return tf.cast(raw_indices, jax2tf._to_tf_dtype(index_dtype))
def _gather_with_batch_dims(args: GatherArgs):
  """Implements call to gather with non-empty batch dimensions.

  E.g., when doing `jax.vmap(lax.dynamic_slice)`.
  """
  op_shape = jax2tf._eval_shape(args.op_shape)
  # Clamp the start indices so every slice stays in bounds.
  clipped_starts = _clip(op_shape, args.start_indices, args.slice_sizes)

  def slice_at(begin):
    # One fixed-size slice per batch element.
    return tf.slice(args.operand, begin=begin, size=args.slice_sizes)

  sliced = tf.map_fn(
      slice_at,
      clipped_starts,
      fn_output_signature=jax2tf._to_tf_dtype(args.operand.dtype))
  return tf.squeeze(sliced, axis=1)