コード例 #1
0
ファイル: call_tf.py プロジェクト: frederikwilde/jax
    def make_tensorspec(a_jax):
      """Build a `tf.TensorSpec` with the same shape and dtype as `a_jax`.

      Args:
        a_jax: an object exposing `.shape` and `.dtype` (a JAX abstract value
          or array).

      Returns:
        `tf.TensorSpec(a_jax.shape, <TF dtype for a_jax.dtype>)`.

      Raises:
        ValueError: if some dimension of `a_jax.shape` is not a constant
          (i.e. the argument is shape-polymorphic), which call_tf rejects.
      """
      # The dtype is converted before the shape is inspected, as in the
      # original ordering.
      tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype)
      shape = a_jax.shape
      has_symbolic_dim = not all(core.is_constant_dim(dim) for dim in shape)
      if has_symbolic_dim:
        raise ValueError(
            "call_tf cannot be applied to shape-polymorphic arguments. "
            f"Found argument shape: {shape}. "
            "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.")
      return tf.TensorSpec(shape, tf_dtype)
コード例 #2
0
ファイル: impl_no_xla.py プロジェクト: matthewfeickert/jax
def _argminmax(is_min: bool, operand: TfVal, axes: Sequence[int],
               index_dtype: DType, _in_avals: Sequence[core.ShapedArray],
               _out_aval: core.ShapedArray):
    """Lower a single-axis JAX argmin/argmax to `tf.math.argmin`/`argmax`.

    Args:
      is_min: use `tf.math.argmin` when true, `tf.math.argmax` otherwise.
      operand: the TF value being reduced.
      axes: must contain exactly one axis; it is unpacked directly.
      index_dtype: the JAX dtype requested for the returned indices.
      _in_avals: unused here.
      _out_aval: unused here.

    Returns:
      The index tensor, cast to the TF equivalent of `index_dtype`.
    """
    # NOTE: results are known to differ from JAX when NaNs are present.
    (reduction_axis,) = axes
    # TF computes the indices in int64 only when the requested index type is
    # wider than 32 bits; otherwise int32 is used.
    needs_wide_index = dtypes.iinfo(index_dtype).bits > 32
    tf_index_type = tf.int64 if needs_wide_index else tf.int32
    # TODO(phawkins): handle axes larger than 2^31.
    reducer = tf.math.argmin if is_min else tf.math.argmax
    indices = reducer(operand, axis=reduction_axis, output_type=tf_index_type)
    return tf.cast(indices, jax2tf._to_tf_dtype(index_dtype))
コード例 #3
0
ファイル: impl_no_xla.py プロジェクト: matthewfeickert/jax
def _gather_with_batch_dims(args: GatherArgs):
    """Implement a gather that has non-empty batch dimensions.

    This arises e.g. from `jax.vmap(lax.dynamic_slice)`. The start indices are
    first clipped against the operand shape, then `tf.map_fn` takes one
    `tf.slice` of the operand per row of start indices, and finally axis 1 of
    the stacked result is squeezed away.
    """
    operand = args.operand
    sizes = args.slice_sizes
    clipped_starts = _clip(jax2tf._eval_shape(args.op_shape),
                           args.start_indices, sizes)

    def _slice_at(begin):
        # One slice of the full operand starting at `begin`.
        return tf.slice(operand, begin=begin, size=sizes)

    stacked = tf.map_fn(
        _slice_at,
        clipped_starts,
        fn_output_signature=jax2tf._to_tf_dtype(operand.dtype))
    # NOTE(review): axis 1 is presumably the size-1 slice along a batched
    # dimension; confirm against how the batch dims are laid out.
    return tf.squeeze(stacked, axis=1)