Example 1

# NOTE: These snippets are from TensorFlow Probability internals and assume
# the following imports (module paths as in the TFP source tree):
import numpy as np
import tensorflow as tf
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util

def _make_static_axis_non_negative_list(axis, ndims):
    """Convert possibly negatively indexed axis to non-negative list of ints.

  Args:
    axis:  Integer Tensor.
    ndims:  Number of dimensions into which axis indexes.

  Returns:
    A list of non-negative Python integers.

  Raises:
    ValueError: If `axis` is not statically defined.
  """
    axis = distribution_util.make_non_negative_axis(axis, ndims)

    axis_const = tf.get_static_value(axis)
    if axis_const is None:
        raise ValueError(
            'Expected argument `axis` to be statically available.  Found: %s' %
            axis)

    # Make at least 1-D.
    axis = axis_const + np.zeros([1], dtype=axis_const.dtype)

    return list(int(dim) for dim in axis)
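
A minimal usage sketch of the helper above (eager TF2 assumed, for illustration only):

```python
import tensorflow as tf

# Hypothetical check: axis=-1 against ndims=3 should resolve to [2].
axis = tf.constant(-1)
print(_make_static_axis_non_negative_list(axis, ndims=3))  # ==> [2]
```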
Example 2
def _interp_regular_1d_grid_impl(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis=-1,
                                 batch_y_ref=False,
                                 fill_value='constant_extension',
                                 fill_value_below=None,
                                 fill_value_above=None,
                                 grid_regularizing_transform=None,
                                 name=None):
    """1-D interpolation that works with/without batching."""
    # To understand the implementation differences between the batch/no-batch
    # versions of this function, you should probably understand the difference
    # between tf.gather and tf.batch_gather.  In particular, we do *not* make
    # the no-batch version a special case of the batch version, because that
    # would be an inefficient use of batch_gather with unnecessarily broadcast
    # args.
    with tf.name_scope(name,
                       values=[
                           x, x_ref_min, x_ref_max, y_ref, axis, fill_value,
                           fill_value_below, fill_value_above
                       ]):

        # Arg checking.
        allowed_fv_st = ('constant_extension', 'extrapolate')
        for fv in (fill_value, fill_value_below, fill_value_above):
            if isinstance(fv, str) and fv not in allowed_fv_st:
                raise ValueError(
                    'A fill value ({}) was not an allowed string ({})'.format(
                        fv, allowed_fv_st))

        # Using separate fill values for below/above incurs extra cost, so
        # keep track of whether this is needed.
        need_separate_fills = (
            fill_value_above is not None or fill_value_below is not None or
            fill_value == 'extrapolate'  # always requires separate below/above
        )
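        # ('extrapolate' always needs the below/above split because the two
        # sides extrapolate from different endpoint pairs: the first two
        # reference values below vs. the last two above.)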
        if need_separate_fills and fill_value_above is None:
            fill_value_above = fill_value
        if need_separate_fills and fill_value_below is None:
            fill_value_below = fill_value

        dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                        preferred_dtype=tf.float32)
        x = tf.convert_to_tensor(value=x, name='x', dtype=dtype)

        x_ref_min = tf.convert_to_tensor(value=x_ref_min,
                                         name='x_ref_min',
                                         dtype=dtype)
        x_ref_max = tf.convert_to_tensor(value=x_ref_max,
                                         name='x_ref_max',
                                         dtype=dtype)
        if not batch_y_ref:
            _assert_ndims_statically(x_ref_min, expect_ndims=0)
            _assert_ndims_statically(x_ref_max, expect_ndims=0)

        y_ref = tf.convert_to_tensor(value=y_ref, name='y_ref', dtype=dtype)

        if batch_y_ref:
            # If we're batching,
            #   x.shape ~ [A1,...,AN, D],  x_ref_min/max.shape ~ [A1,...,AN]
            # So to add together we'll append a singleton.
            # If not batching, x_ref_min/max are scalar, so this isn't an issue,
            # moreover, if not batching, x can be scalar, and expanding x_ref_min/max
            # would cause a bad expansion of x when added to x (confused yet?).
            x_ref_min = x_ref_min[..., tf.newaxis]
            x_ref_max = x_ref_max[..., tf.newaxis]

        axis = tf.convert_to_tensor(value=axis, name='axis', dtype=tf.int32)
        axis = distribution_util.make_non_negative_axis(axis, tf.rank(y_ref))
        _assert_ndims_statically(axis, expect_ndims=0)

        ny = tf.cast(tf.shape(input=y_ref)[axis], dtype)

        # Map [x_ref_min, x_ref_max] to [0, ny - 1].
        # This is the (fractional) index of x.
        if grid_regularizing_transform is None:
            g = lambda x: x
        else:
            g = grid_regularizing_transform
        fractional_idx = ((g(x) - g(x_ref_min)) /
                          (g(x_ref_max) - g(x_ref_min)))
        x_idx_unclipped = fractional_idx * (ny - 1)

        # Wherever x is NaN, x_idx_unclipped will be NaN as well.
        # Keep track of the NaN indices here (so we can impute NaN later).
        # Also eliminate any NaN indices, since int32 (which the indices are
        # cast to below) has no NaN.
        nan_idx = tf.math.is_nan(x_idx_unclipped)
        x_idx_unclipped = tf.where(nan_idx, tf.zeros_like(x_idx_unclipped),
                                   x_idx_unclipped)

        x_idx = tf.clip_by_value(x_idx_unclipped, tf.zeros((), dtype=dtype),
                                 ny - 1)

        # Get the index above and below x_idx.
        # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
        # however, this results in idx_below == idx_above whenever x is on a grid.
        # This in turn results in y_ref_below == y_ref_above, and then the gradient
        # at this point is zero.  So here we "jitter" one of idx_below, idx_above,
        # so that they are at different values.  This jittering does not affect the
        # interpolated value, but does make the gradient nonzero (unless of course
        # the y_ref values are the same).
        idx_below = tf.floor(x_idx)
        idx_above = tf.minimum(idx_below + 1, ny - 1)
        idx_below = tf.maximum(idx_above - 1, 0)
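        # For example (numbers for illustration only): with ny = 5 and
        # x_idx = 3.0 exactly on a grid point, floor gives 3.0, then
        # idx_above = min(4, 4) = 4 and idx_below = max(4 - 1, 0) = 3, so the
        # two indices differ even on the grid.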

        # These are the values of y_ref corresponding to above/below indices.
        idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
        idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)
        if batch_y_ref:
            # If y_ref.shape ~ [A1,...,AN, C, B1,...,BN],
            # and x.shape, x_ref_min/max.shape ~ [A1,...,AN, D]
            # Then y_ref_below.shape ~ [A1,...,AN, D, B1,...,BN]
            y_ref_below = _batch_gather_with_broadcast(y_ref, idx_below_int32,
                                                       axis)
            y_ref_above = _batch_gather_with_broadcast(y_ref, idx_above_int32,
                                                       axis)
        else:
            # Here, y_ref_below.shape =
            #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis + 1:]
            y_ref_below = tf.gather(y_ref, idx_below_int32, axis=axis)
            y_ref_above = tf.gather(y_ref, idx_above_int32, axis=axis)

        # Use t to get a convex combination of the below/above values.
        t = x_idx - idx_below

        # x, and tensors shaped like x, need to be added to the output y, and
        # selected against it with tf.where.  This requires appending
        # singletons.  Make functions appropriate for batch/no-batch.
        if batch_y_ref:
            # In the batch case, the output shape is going to be
            #   Broadcast(y_ref.shape[:axis], x.shape[:-1]) +
            #   x.shape[-1:] + y_ref.shape[axis+1:]
            expand_x_fn = _make_expand_x_fn_for_batch_interpolation(
                y_ref, axis)
        else:
            # In the non-batch case, the output shape is going to be
            #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis+1:]
            expand_x_fn = _make_expand_x_fn_for_non_batch_interpolation(
                y_ref, axis)

        t = expand_x_fn(t)
        nan_idx = expand_x_fn(nan_idx, broadcast=True)
        x_idx_unclipped = expand_x_fn(x_idx_unclipped, broadcast=True)

        y = t * y_ref_above + (1 - t) * y_ref_below

        # Now begins a long excursion to fill values outside [x_min, x_max].

        # Re-insert NaN wherever x was NaN.
        y = tf.where(nan_idx,
                     tf.fill(tf.shape(input=y), tf.constant(np.nan, y.dtype)),
                     y)

        if not need_separate_fills:
            if fill_value == 'constant_extension':
                pass  # Already handled by clipping x_idx_unclipped.
            else:
                y = tf.where(
                    (x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1),
                    fill_value + tf.zeros_like(y), y)
        else:
            # Fill values below x_ref_min <==> x_idx_unclipped < 0.
            if fill_value_below == 'constant_extension':
                pass  # Already handled by the clipping that created x_idx_unclipped.
            elif fill_value_below == 'extrapolate':
                if batch_y_ref:
                    # For every batch member, gather the first two elements of y across
                    # `axis`.
                    y_0 = tf.gather(y_ref, [0], axis=axis)
                    y_1 = tf.gather(y_ref, [1], axis=axis)
                else:
                    # If not batching, we want to gather the first two elements, just like
                    # above.  However, these results need to be replicated for every
                    # member of x.  An easy way to do that is to gather using
                    # indices = zeros/ones(x.shape).
                    y_0 = tf.gather(y_ref,
                                    tf.zeros(tf.shape(input=x),
                                             dtype=tf.int32),
                                    axis=axis)
                    y_1 = tf.gather(y_ref,
                                    tf.ones(tf.shape(input=x), dtype=tf.int32),
                                    axis=axis)
                x_delta = (x_ref_max - x_ref_min) / (ny - 1)
                x_factor = expand_x_fn((x - x_ref_min) / x_delta,
                                       broadcast=True)
                y = tf.where(x_idx_unclipped < 0, y_0 + x_factor * (y_1 - y_0),
                             y)
            else:
                y = tf.where(x_idx_unclipped < 0,
                             fill_value_below + tf.zeros_like(y), y)
            # Fill values above x_ref_max <==> x_idx_unclipped > ny - 1.
            if fill_value_above == 'constant_extension':
                pass  # Already handled by the clipping that created x_idx_unclipped.
            elif fill_value_above == 'extrapolate':
                ny_int32 = tf.shape(input=y_ref)[axis]
                if batch_y_ref:
                    y_n1 = tf.gather(y_ref, [tf.shape(input=y_ref)[axis] - 1],
                                     axis=axis)
                    y_n2 = tf.gather(y_ref, [tf.shape(input=y_ref)[axis] - 2],
                                     axis=axis)
                else:
                    y_n1 = tf.gather(y_ref,
                                     tf.fill(tf.shape(input=x), ny_int32 - 1),
                                     axis=axis)
                    y_n2 = tf.gather(y_ref,
                                     tf.fill(tf.shape(input=x), ny_int32 - 2),
                                     axis=axis)
                x_delta = (x_ref_max - x_ref_min) / (ny - 1)
                x_factor = expand_x_fn((x - x_ref_max) / x_delta,
                                       broadcast=True)
                y = tf.where(x_idx_unclipped > ny - 1,
                             y_n1 + x_factor * (y_n1 - y_n2), y)
            else:
                y = tf.where(x_idx_unclipped > ny - 1,
                             fill_value_above + tf.zeros_like(y), y)

        return y
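
The fractional-index arithmetic at the heart of this function is easiest to see outside of TensorFlow.  Here is a minimal NumPy sketch of the same scheme (standalone and illustrative only, not the library's code path):

```python
import numpy as np

# Map x in [x_ref_min, x_ref_max] to a fractional index in [0, ny - 1],
# then form a convex combination of the two neighboring reference values.
x_ref_min, x_ref_max = 0.0, 10.0
y_ref = np.exp(np.linspace(0.0, 10.0, 20))  # ny = 20 reference values
ny = y_ref.shape[0]
x = np.array([6.0, 0.5, 3.3])

x_idx = (x - x_ref_min) / (x_ref_max - x_ref_min) * (ny - 1)
idx_below = np.floor(x_idx).astype(np.int64)
idx_above = np.minimum(idx_below + 1, ny - 1)  # the same "jitter" as above
idx_below = np.maximum(idx_above - 1, 0)
t = x_idx - idx_below
y = t * y_ref[idx_above] + (1 - t) * y_ref[idx_below]
print(y)  # approximately exp([6.0, 0.5, 3.3])
```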
Example 3
def batch_interp_regular_nd_grid(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis,
                                 fill_value='constant_extension',
                                 name=None):
    """Multi-linear interpolation on a regular (constant spacing) grid.

  Given [a batch of] reference values, this function computes a multi-linear
  interpolant and evaluates it on [a batch of] new `x` values.

  The interpolant is built from reference values indexed by `nd` dimensions
  of `y_ref`, starting at `axis`.

  For example, take the case of a `2-D` scalar valued function and no leading
  batch dimensions.  In this case, `y_ref.shape = [C1, C2]` and `y_ref[i, j]`
  is the reference value corresponding to grid point

  ```
  [x_ref_min[0] + i * (x_ref_max[0] - x_ref_min[0]) / (C1 - 1),
   x_ref_min[1] + j * (x_ref_max[1] - x_ref_min[1]) / (C2 - 1)]
  ```

  In the general case, dimensions to the left of `axis` in `y_ref` are broadcast
  with leading dimensions in `x`, `x_ref_min`, `x_ref_max`.

  Args:
    x: Numeric `Tensor` The x-coordinates of the interpolated output values for
      each batch.  Shape `[..., D, nd]`, designating [a batch of] `D`
      coordinates in `nd` space.  `D` must be `>= 1` and is not a batch dim.
    x_ref_min:  `Tensor` of same `dtype` as `x`.  The minimum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    x_ref_max:  `Tensor` of same `dtype` as `x`.  The maximum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    y_ref:  `Tensor` of same `dtype` as `x`.  The reference output values. Shape
      `[..., C1, ..., Cnd, B1,...,BM]`, designating [a batch of] reference
      values indexed by `nd` dimensions, of a shape `[B1,...,BM]` valued
      function (for `M >= 0`).
    axis:  Scalar integer `Tensor`.  Dimensions `[axis, axis + nd)` of `y_ref`
      index the interpolation table.  E.g. `3-D` interpolation of a scalar
      valued function requires `axis=-3` and a `3-D` matrix valued function
      requires `axis=-5`.
    fill_value:  Determines what values the output should take for `x` values
      that are below `x_ref_min` or above `x_ref_max`. Scalar `Tensor` or
      "constant_extension" ==> Extend as constant function.
      Default value: `"constant_extension"`
    name:  A name to prepend to created ops.
      Default value: `"batch_interp_regular_nd_grid"`.

  Returns:
    y_interp:  Interpolation between members of `y_ref`, at points `x`.
      `Tensor` of same `dtype` as `x`, and shape `[..., D, B1, ..., BM].`

  Raises:
    ValueError:  If `rank(x) < 2` is determined statically.
    ValueError:  If `axis` is not a scalar, and this is determined statically.
    ValueError:  If `axis + nd > rank(y_ref)` is determined statically.

  #### Examples

  Interpolate a function of one variable.

  ```python
  y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))

  tfp.math.batch_interp_regular_nd_grid(
      # x.shape = [3, 1], x_ref_min/max.shape = [1].  Trailing `1` for `1-D`.
      x=[[6.0], [0.5], [3.3]], x_ref_min=[0.], x_ref_max=[10.], y_ref=y_ref)
  ==> approx [exp(6.0), exp(0.5), exp(3.3)]
  ```

  Interpolate a scalar function of two variables.

  ```python
  x_ref_min = [0., 0.]
  x_ref_max = [2 * np.pi, 2 * np.pi]

  # Build y_ref.
  x0s, x1s = tf.meshgrid(
      tf.linspace(x_ref_min[0], x_ref_max[0], num=100),
      tf.linspace(x_ref_min[1], x_ref_max[1], num=100),
      indexing='ij')

  def func(x0, x1):
    return tf.sin(x0) * tf.cos(x1)

  y_ref = func(x0s, x1s)

  x = np.pi * tf.random_uniform(shape=(10, 2))

  tfp.math.batch_interp_regular_nd_grid(x, x_ref_min, x_ref_max, y_ref, axis=-2)
  ==> tf.sin(x[:, 0]) * tf.cos(x[:, 1])
  ```

  """
    with tf.compat.v1.name_scope(
            name,
            default_name='interp_regular_nd_grid',
            values=[x, x_ref_min, x_ref_max, y_ref, fill_value]):
        dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                        preferred_dtype=tf.float32)

        # Arg checking.
        if isinstance(fill_value, str):
            if fill_value != 'constant_extension':
                raise ValueError(
                    'A fill value ({}) was not an allowed string ({})'.format(
                        fill_value, 'constant_extension'))
        else:
            fill_value = tf.convert_to_tensor(value=fill_value,
                                              name='fill_value',
                                              dtype=dtype)
            _assert_ndims_statically(fill_value, expect_ndims=0)

        # x.shape = [..., nd].
        x = tf.convert_to_tensor(value=x, name='x', dtype=dtype)
        _assert_ndims_statically(x, expect_ndims_at_least=2)

        # y_ref.shape = [..., C1,...,Cnd, B1,...,BM]
        y_ref = tf.convert_to_tensor(value=y_ref, name='y_ref', dtype=dtype)

        # x_ref_min.shape = [nd]
        x_ref_min = tf.convert_to_tensor(value=x_ref_min,
                                         name='x_ref_min',
                                         dtype=dtype)
        x_ref_max = tf.convert_to_tensor(value=x_ref_max,
                                         name='x_ref_max',
                                         dtype=dtype)
        _assert_ndims_statically(x_ref_min,
                                 expect_ndims_at_least=1,
                                 expect_static=True)
        _assert_ndims_statically(x_ref_max,
                                 expect_ndims_at_least=1,
                                 expect_static=True)

        # nd is the number of dimensions indexing the interpolation table;
        # it's the "nd" in the function name.
        nd = tf.compat.dimension_value(x_ref_min.shape[-1])
        if nd is None:
            raise ValueError('`x_ref_min.shape[-1]` must be known statically.')
        x_ref_max.shape[-1:].assert_is_compatible_with(x_ref_min.shape[-1:])

        # Convert axis and check it statically.
        axis = tf.convert_to_tensor(value=axis, dtype=tf.int32, name='axis')
        axis = distribution_util.make_non_negative_axis(axis, tf.rank(y_ref))
        axis.shape.assert_has_rank(0)
        axis_ = tf.get_static_value(axis)
        y_ref_rank_ = tf.get_static_value(tf.rank(y_ref))
        if axis_ is not None and y_ref_rank_ is not None:
            if axis_ + nd > y_ref_rank_:
                raise ValueError(
                    'Since dims `[axis, axis + nd)` index the interpolation table, we '
                    'must have `axis + nd <= rank(y_ref)`.  Found: '
                    '`axis`: {},  rank(y_ref): {}, and inferred `nd` from trailing '
                    'dimensions of `x_ref_min` to be {}.'.format(
                        axis_, y_ref_rank_, nd))

        x_batch_shape = tf.shape(input=x)[:-2]
        x_ref_min_batch_shape = tf.shape(input=x_ref_min)[:-1]
        x_ref_max_batch_shape = tf.shape(input=x_ref_max)[:-1]
        y_ref_batch_shape = tf.shape(input=y_ref)[:axis]

        # Do a brute-force broadcast of batch dims (add zeros).
        batch_shape = y_ref_batch_shape
        for tensor in [
                x_batch_shape, x_ref_min_batch_shape, x_ref_max_batch_shape
        ]:
            batch_shape = tf.broadcast_dynamic_shape(batch_shape, tensor)

        def _batch_of_zeros_with_rightmost_singletons(n_singletons):
            """Return Tensor of zeros with some singletons on the rightmost dims."""
            ones = tf.ones(shape=[n_singletons], dtype=tf.int32)
            return tf.zeros(shape=tf.concat([batch_shape, ones], axis=0),
                            dtype=dtype)

        x += _batch_of_zeros_with_rightmost_singletons(n_singletons=2)
        x_ref_min += _batch_of_zeros_with_rightmost_singletons(n_singletons=1)
        x_ref_max += _batch_of_zeros_with_rightmost_singletons(n_singletons=1)
        y_ref += _batch_of_zeros_with_rightmost_singletons(
            n_singletons=tf.rank(y_ref) - axis)
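        # For example (shapes chosen for illustration): if batch_shape is [2]
        # and x.shape is [1, 10, nd], then adding zeros of shape [2, 1, 1]
        # broadcasts x up to [2, 10, nd], so every argument carries the full
        # batch shape before the gather_nd-based interpolation below.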

        return _batch_interp_with_gather_nd(
            x=x,
            x_ref_min=x_ref_min,
            x_ref_max=x_ref_max,
            y_ref=y_ref,
            nd=nd,
            fill_value=fill_value,
            batch_dims=tf.get_static_value(tf.rank(x)) - 2)
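
The "add zeros" step above is a generic trick for materializing a common batch shape before gathering.  A standalone sketch (not library code; shapes chosen for illustration):

```python
import tensorflow as tf

# Adding a zeros Tensor whose shape is the broadcast batch shape (plus
# rightmost singletons) forces each operand onto that common shape.
x = tf.ones([1, 10, 2])      # one batch member, D=10 points, nd=2
zeros = tf.zeros([3, 1, 1])  # target batch shape [3], plus two singletons
x_broadcast = x + zeros
print(x_broadcast.shape)     # (3, 10, 2)
```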
Example 4
def interp_regular_1d_grid(x,
                           x_ref_min,
                           x_ref_max,
                           y_ref,
                           axis=-1,
                           fill_value='constant_extension',
                           fill_value_below=None,
                           fill_value_above=None,
                           grid_regularizing_transform=None,
                           name=None):
  """Linear `1-D` interpolation on a regular (constant spacing) grid.

  Given reference values, this function computes a piecewise linear interpolant
  and evaluates it on a new set of `x` values.

  The interpolant is built from `M` reference values indexed by one dimension
  of `y_ref` (specified by the `axis` kwarg).

  If `y_ref` is a vector, then each value `y_ref[i]` is considered to be equal
  to `f(x_ref[i])`, for `M` (implicitly defined) reference values between
  `x_ref_min` and `x_ref_max`:

  ```none
  x_ref[i] = x_ref_min + i * (x_ref_max - x_ref_min) / (M - 1),
  i = 0, ..., M - 1.
  ```

  If `rank(y_ref) > 1`, then `y_ref` contains `M` reference values of a
  `rank(y_ref) - 1` rank tensor valued function of one variable.
  `x` is a `Tensor` of values of that variable (any shape allowed).

  Args:
    x: Numeric `Tensor` The x-coordinates of the interpolated output values.
    x_ref_min:  `Tensor` of same `dtype` as `x`.  The minimum value of the
      (implicitly defined) reference `x_ref`.
    x_ref_max:  `Tensor` of same `dtype` as `x`.  The maximum value of the
      (implicitly defined) reference `x_ref`.
    y_ref:  `N-D` `Tensor` (`N > 0`) of same `dtype` as `x`.
      The reference output values.
    axis:  Scalar `Tensor` designating the dimension of `y_ref` that indexes
      values of the interpolation variable.
      Default value: `-1`, the rightmost axis.
    fill_value:  Determines what values the output should take for `x` values
      that are below `x_ref_min` or above `x_ref_max`.
      `Tensor` or one of the strings
        "constant_extension" ==> Extend as constant function.
        "extrapolate" ==> Extrapolate in a linear fashion.
      Default value: `"constant_extension"`
    fill_value_below:  Optional override of `fill_value` for `x < x_ref_min`.
    fill_value_above:  Optional override of `fill_value` for `x > x_ref_max`.
    grid_regularizing_transform:  Optional transformation `g` which regularizes
      the implied spacing of the x reference points.  In other words, if
      provided, we assume `g(x_ref_i)` is a regular grid between `g(x_ref_min)`
      and `g(x_ref_max)`.
    name:  A name to prepend to created ops.
      Default value: `"interp_regular_1d_grid"`.

  Returns:
    y_interp:  Interpolation between members of `y_ref`, at points `x`.
      `Tensor` of same `dtype` as `x`, and shape
      `y_ref.shape[:axis] + x.shape + y_ref.shape[axis + 1:]`

  Raises:
    ValueError:  If `fill_value` is not an allowed string.
    ValueError:  If `axis` is not a scalar.

  #### Examples

  Interpolate a function of one variable:

  ```python
  y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))

  interp_regular_1d_grid(
      x=[6.0, 0.5, 3.3], x_ref_min=0., x_ref_max=10., y_ref=y_ref)
  ==> approx [exp(6.0), exp(0.5), exp(3.3)]
  ```

  Interpolate a matrix-valued function of one variable:

  ```python
  mat_0 = [[1., 0.], [0., 1.]]
  mat_1 = [[0., -1], [1, 0]]
  y_ref = [mat_0, mat_1]

  # Get three output matrices at once.
  tfp.math.interp_regular_1d_grid(
      x=[0., 0.5, 1.], x_ref_min=0., x_ref_max=1., y_ref=y_ref, axis=0)
  ==> [mat_0, 0.5 * mat_0 + 0.5 * mat_1, mat_1]
  ```

  Interpolate a function of one variable on a log-spaced grid:

  ```python
  x_ref = tf.exp(tf.linspace(tf.log(1.), tf.log(100000.), num_pts))
  y_ref = tf.log(x_ref + x_ref**2)

  interp_regular_1d_grid(x=[1.1, 2.2], x_ref_min=1., x_ref_max=100000.,
      y_ref=y_ref, grid_regularizing_transform=tf.log)
  ==> [tf.log(1.1 + 1.1**2), tf.log(2.2 + 2.2**2)]
  ```

  """

  with tf.name_scope(
      name,
      'interp_regular_1d_grid',
      values=[
          x, x_ref_min, x_ref_max, y_ref, axis, fill_value, fill_value_below,
          fill_value_above
      ]):

    # Arg checking.
    allowed_fv_st = ('constant_extension', 'extrapolate')
    for fv in (fill_value, fill_value_below, fill_value_above):
      if isinstance(fv, str) and fv not in allowed_fv_st:
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fv, allowed_fv_st))

    # Using separate fill values for below/above incurs extra cost, so keep
    # track of whether this is needed.
    need_separate_fills = (
        fill_value_above is not None or fill_value_below is not None or
        fill_value == 'extrapolate'  # always requires separate below/above
    )
    if need_separate_fills and fill_value_above is None:
      fill_value_above = fill_value
    if need_separate_fills and fill_value_below is None:
      fill_value_below = fill_value

    axis = tf.convert_to_tensor(axis, name='axis', dtype=tf.int32)
    _assert_ndims_statically(axis, expect_ndims=0)
    axis = distribution_util.make_non_negative_axis(axis, tf.rank(y_ref))

    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    preferred_dtype=tf.float32)
    x = tf.convert_to_tensor(x, name='x', dtype=dtype)
    x_ref_min = tf.convert_to_tensor(x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(x_ref_max, name='x_ref_max', dtype=dtype)
    y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

    ny = tf.cast(tf.shape(y_ref)[axis], dtype)

    # Map [x_ref_min, x_ref_max] to [0, ny - 1].
    # This is the (fractional) index of x.
    if grid_regularizing_transform is None:
      g = lambda x: x
    else:
      g = grid_regularizing_transform
    fractional_idx = ((g(x) - g(x_ref_min)) / (g(x_ref_max) - g(x_ref_min)))
    x_idx_unclipped = fractional_idx * (ny - 1)

    # Wherever x is NaN, x_idx_unclipped will be NaN as well.
    # Keep track of the NaN indices here (so we can impute NaN later).
    # Also eliminate any NaN indices, since int32 (which the indices are cast
    # to below) has no NaN.
    nan_idx = tf.is_nan(x_idx_unclipped)
    x_idx_unclipped = tf.where(nan_idx, tf.zeros_like(x_idx_unclipped),
                               x_idx_unclipped)

    x_idx = tf.clip_by_value(x_idx_unclipped, tf.zeros((), dtype=dtype), ny - 1)

    # Get the index above and below x_idx.
    # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
    # however, this results in idx_below == idx_above whenever x is on a grid.
    # This in turn results in y_ref_below == y_ref_above, and then the gradient
    # at this point is zero.  So here we "jitter" one of idx_below, idx_above,
    # so that they are at different values.  This jittering does not affect the
    # interpolated value, but does make the gradient nonzero (unless of course
    # the y_ref values are the same).
    idx_below = tf.floor(x_idx)
    idx_above = tf.minimum(idx_below + 1, ny - 1)
    idx_below = tf.maximum(idx_above - 1, 0)

    # These are the values of y_ref corresponding to above/below indices.
    idx_below_int32 = tf.to_int32(idx_below)
    idx_above_int32 = tf.to_int32(idx_above)
    y_ref_below = tf.gather(y_ref, idx_below_int32, axis=axis)
    y_ref_above = tf.gather(y_ref, idx_above_int32, axis=axis)

    # out_shape = y_ref.shape[:axis] + x.shape + y_ref.shape[axis + 1:]
    out_shape = tf.shape(y_ref_below)

    # Return a convex combination.
    t = x_idx - idx_below

    t = _expand_ends(t, out_shape, axis)

    y = t * y_ref_above + (1 - t) * y_ref_below

    # Now begins a long excursion to fill values outside [x_min, x_max].

    # Re-insert NaN wherever x was NaN.
    y = tf.where(
        _expand_ends(nan_idx, out_shape, axis, broadcast=True),
        tf.fill(tf.shape(y), tf.constant(np.nan, y.dtype)), y)

    x_idx_unclipped = _expand_ends(
        x_idx_unclipped, out_shape, axis, broadcast=True)

    if not need_separate_fills:
      if fill_value == 'constant_extension':
        pass  # Already handled by clipping x_idx_unclipped.
      else:
        y = tf.where((x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1),
                     fill_value + tf.zeros_like(y), y)
    else:
      # Fill values below x_ref_min <==> x_idx_unclipped < 0.
      if fill_value_below == 'constant_extension':
        pass  # Already handled by the clipping that created x_idx_unclipped.
      elif fill_value_below == 'extrapolate':
        y_0 = tf.gather(y_ref, tf.zeros(tf.shape(x), dtype=tf.int32), axis=axis)
        y_1 = tf.gather(y_ref, tf.ones(tf.shape(x), dtype=tf.int32), axis=axis)
        x_delta = (x_ref_max - x_ref_min) / (ny - 1)
        x_factor = (x - x_ref_min) / x_delta
        x_factor = _expand_ends(x_factor, out_shape, axis, broadcast=True)
        y = tf.where(x_idx_unclipped < 0, y_0 + x_factor * (y_1 - y_0), y)
      else:
        y = tf.where(x_idx_unclipped < 0, fill_value_below + tf.zeros_like(y),
                     y)
      # Fill values above x_ref_max <==> x_idx_unclipped > ny - 1.
      if fill_value_above == 'constant_extension':
        pass  # Already handled by the clipping that created x_idx_unclipped.
      elif fill_value_above == 'extrapolate':
        ny_int32 = tf.shape(y_ref)[axis]
        y_n1 = tf.gather(y_ref, tf.fill(tf.shape(x), ny_int32 - 1), axis=axis)
        y_n2 = tf.gather(y_ref, tf.fill(tf.shape(x), ny_int32 - 2), axis=axis)
        x_delta = (x_ref_max - x_ref_min) / (ny - 1)
        x_factor = (x - x_ref_max) / x_delta
        x_factor = _expand_ends(x_factor, out_shape, axis, broadcast=True)
        y = tf.where(x_idx_unclipped > ny - 1,
                     y_n1 + x_factor * (y_n1 - y_n2), y)
      else:
        y = tf.where(x_idx_unclipped > ny - 1,
                     fill_value_above + tf.zeros_like(y), y)

    return y
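
For the 'extrapolate' fill path, the code extends the line through the two reference points nearest the violated boundary.  A small NumPy sketch of the below-grid case (standalone, values chosen for illustration):

```python
import numpy as np

# Extend the line through (x_ref[0], y_ref[0]) and (x_ref[1], y_ref[1]) to
# points below x_ref_min; 'constant_extension' is instead handled by the
# clipping that produced x_idx from x_idx_unclipped.
x_ref_min, x_ref_max, ny = 0.0, 1.0, 5
y_ref = np.array([0.0, 1.0, 4.0, 9.0, 16.0])  # values on the regular grid
x = np.array([-0.5])                          # a point below x_ref_min

x_delta = (x_ref_max - x_ref_min) / (ny - 1)  # grid spacing: 0.25 here
x_factor = (x - x_ref_min) / x_delta          # signed cells from left edge
y_below = y_ref[0] + x_factor * (y_ref[1] - y_ref[0])
print(y_below)  # [-2.] : the line through (0, 0) and (0.25, 1) at x = -0.5
```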