def _make_static_axis_non_negative_list(axis, ndims):
  """Convert possibly negatively indexed axis to non-negative list of ints.

  Args:
    axis: Integer Tensor.
    ndims: Number of dimensions into which axis indexes.

  Returns:
    A list of non-negative Python integers.

  Raises:
    ValueError: If `axis` is not statically defined.
  """
  axis = distribution_util.make_non_negative_axis(axis, ndims)

  # `tf.get_static_value` is the public, TF2-compatible replacement for the
  # (removed) `tf.contrib.util.constant_value`; it returns `None` when the
  # value cannot be determined without running the graph.
  axis_const = tf.get_static_value(axis)
  if axis_const is None:
    raise ValueError(
        'Expected argument `axis` to be statically available. Found: %s' % axis)

  # Make at least 1-D, so a scalar axis can be iterated like a vector.
  axis = axis_const + np.zeros([1], dtype=axis_const.dtype)

  return [int(dim) for dim in axis]
def _interp_regular_1d_grid_impl(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis=-1,
                                 batch_y_ref=False,
                                 fill_value='constant_extension',
                                 fill_value_below=None,
                                 fill_value_above=None,
                                 grid_regularizing_transform=None,
                                 name=None):
  """1-D interpolation that works with/without batching.

  Args:
    x: Points at which to interpolate. When `batch_y_ref` is True, this code
      treats it as shaped `[A1,...,AN, D]`.
    x_ref_min: Minimum of the implicitly defined regular reference grid. Must
      be a scalar when `batch_y_ref` is False; otherwise shaped `[A1,...,AN]`.
    x_ref_max: Maximum of the reference grid; same shape rules as `x_ref_min`.
    y_ref: Reference values. Dimension `axis` indexes the grid points.
    axis: Scalar integer designating the dimension of `y_ref` that indexes the
      grid points. May be negative.
    batch_y_ref: Python bool. If True, values are gathered with
      `_batch_gather_with_broadcast` (leading dims are batch dims); otherwise
      with `tf.gather`.
    fill_value: `Tensor`, or one of the strings "constant_extension" or
      "extrapolate". Used for `x` outside `[x_ref_min, x_ref_max]`.
    fill_value_below: Optional override of `fill_value` for `x < x_ref_min`.
    fill_value_above: Optional override of `fill_value` for `x > x_ref_max`.
    grid_regularizing_transform: Optional callable `g`; the reference grid is
      assumed to be regular in `g(x)`.
    name: Name scope for created ops.

  Returns:
    The interpolated `Tensor`. NaN is re-inserted wherever `x` produced a NaN
    fractional index.

  Raises:
    ValueError: If any fill value is a string other than "constant_extension"
      or "extrapolate".
  """
  # To understand the implementation differences between the batch/no-batch
  # versions of this function, you should probably understand the difference
  # between tf.gather and tf.batch_gather. In particular, we do *not* make the
  # no-batch version a special case of the batch version, because that would
  # be an inefficient use of batch_gather with unnecessarily broadcast args.
  with tf.compat.v1.name_scope(
      name,
      default_name='interp_regular_1d_grid_impl',
      values=[
          x, x_ref_min, x_ref_max, y_ref, axis, fill_value, fill_value_below,
          fill_value_above
      ]):

    def _is_option(fill, option):
      # Fill values may be Tensors/arrays, for which `==` is elementwise (or
      # raises). Only compare against the sentinel strings for Python strings.
      return isinstance(fill, str) and fill == option

    # Arg checking.
    allowed_fv_st = ('constant_extension', 'extrapolate')
    for fv in (fill_value, fill_value_below, fill_value_above):
      if isinstance(fv, str) and fv not in allowed_fv_st:
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fv, allowed_fv_st))

    # Separate value fills for below/above incurs extra cost, so keep track of
    # whether this is needed.
    need_separate_fills = (
        fill_value_above is not None or fill_value_below is not None or
        # 'extrapolate' always requires separate below/above handling.
        _is_option(fill_value, 'extrapolate'))
    if need_separate_fills and fill_value_above is None:
      fill_value_above = fill_value
    if need_separate_fills and fill_value_below is None:
      fill_value_below = fill_value

    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    preferred_dtype=tf.float32)
    x = tf.convert_to_tensor(value=x, name='x', dtype=dtype)

    x_ref_min = tf.convert_to_tensor(
        value=x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(
        value=x_ref_max, name='x_ref_max', dtype=dtype)
    if not batch_y_ref:
      _assert_ndims_statically(x_ref_min, expect_ndims=0)
      _assert_ndims_statically(x_ref_max, expect_ndims=0)

    y_ref = tf.convert_to_tensor(value=y_ref, name='y_ref', dtype=dtype)

    if batch_y_ref:
      # If we're batching,
      #   x.shape ~ [A1,...,AN, D],  x_ref_min/max.shape ~ [A1,...,AN]
      # So to add together we'll append a singleton.
      # If not batching, x_ref_min/max are scalar, so this isn't an issue,
      # moreover, if not batching, x can be scalar, and expanding x_ref_min/max
      # would cause a bad expansion of x when added to x (confused yet?).
      x_ref_min = x_ref_min[..., tf.newaxis]
      x_ref_max = x_ref_max[..., tf.newaxis]

    axis = tf.convert_to_tensor(value=axis, name='axis', dtype=tf.int32)
    axis = distribution_util.make_non_negative_axis(axis, tf.rank(y_ref))
    _assert_ndims_statically(axis, expect_ndims=0)

    ny = tf.cast(tf.shape(input=y_ref)[axis], dtype)

    # Map [x_ref_min, x_ref_max] to [0, ny - 1].
    # This is the (fractional) index of x.
    if grid_regularizing_transform is None:
      g = lambda x: x
    else:
      g = grid_regularizing_transform
    fractional_idx = ((g(x) - g(x_ref_min)) / (g(x_ref_max) - g(x_ref_min)))
    x_idx_unclipped = fractional_idx * (ny - 1)

    # Wherever x is NaN, x_idx_unclipped will be NaN as well.
    # Keep track of the nan indices here (so we can impute NaN later).
    # Also eliminate any NaN indices, since the indices are cast to int32
    # below and int32 has no NaN.
    nan_idx = tf.math.is_nan(x_idx_unclipped)
    x_idx_unclipped = tf.where(nan_idx, tf.zeros_like(x_idx_unclipped),
                               x_idx_unclipped)

    x_idx = tf.clip_by_value(x_idx_unclipped, tf.zeros((), dtype=dtype), ny - 1)

    # Get the index above and below x_idx.
    # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
    # however, this results in idx_below == idx_above whenever x is on a grid.
    # This in turn results in y_ref_below == y_ref_above, and then the gradient
    # at this point is zero.  So here we "jitter" one of idx_below, idx_above,
    # so that they are at different values.  This jittering does not affect the
    # interpolated value, but does make the gradient nonzero (unless of course
    # the y_ref values are the same).
    idx_below = tf.floor(x_idx)
    idx_above = tf.minimum(idx_below + 1, ny - 1)
    idx_below = tf.maximum(idx_above - 1, 0)

    # These are the values of y_ref corresponding to above/below indices.
    idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
    idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)
    if batch_y_ref:
      # If y_ref.shape ~ [A1,...,AN, C, B1,...,BN],
      # and x.shape, x_ref_min/max.shape ~ [A1,...,AN, D]
      # Then y_ref_below.shape ~ [A1,...,AN, D, B1,...,BN]
      y_ref_below = _batch_gather_with_broadcast(y_ref, idx_below_int32, axis)
      y_ref_above = _batch_gather_with_broadcast(y_ref, idx_above_int32, axis)
    else:
      # Here, y_ref_below.shape =
      #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis + 1:]
      y_ref_below = tf.gather(y_ref, idx_below_int32, axis=axis)
      y_ref_above = tf.gather(y_ref, idx_above_int32, axis=axis)

    # Use t to get a convex combination of the below/above values.
    t = x_idx - idx_below

    # x, and tensors shaped like x, need to be added to, and selected with
    # (using tf.where) the output y.  This requires appending singletons.
    # Make functions appropriate for batch/no-batch.
    if batch_y_ref:
      # In the batch case, the output shape is going to be
      #   Broadcast(y_ref.shape[:axis], x.shape[:-1]) +
      #   x.shape[-1:] +  y_ref.shape[axis+1:]
      expand_x_fn = _make_expand_x_fn_for_batch_interpolation(y_ref, axis)
    else:
      # In the non-batch case, the output shape is going to be
      #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis+1:]
      expand_x_fn = _make_expand_x_fn_for_non_batch_interpolation(y_ref, axis)

    t = expand_x_fn(t)
    nan_idx = expand_x_fn(nan_idx, broadcast=True)
    x_idx_unclipped = expand_x_fn(x_idx_unclipped, broadcast=True)

    y = t * y_ref_above + (1 - t) * y_ref_below

    # Now begins a long excursion to fill values outside [x_min, x_max].

    # Re-insert NaN wherever x was NaN.
    y = tf.where(nan_idx,
                 tf.fill(tf.shape(input=y), tf.constant(np.nan, y.dtype)), y)

    if not need_separate_fills:
      if _is_option(fill_value, 'constant_extension'):
        pass  # Already handled by clipping x_idx_unclipped.
      else:
        y = tf.where(
            (x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1),
            fill_value + tf.zeros_like(y), y)
    else:
      # NOTE(review): the extrapolation branches below use the linear spacing
      # of x (not of g(x)), even when grid_regularizing_transform is given.

      # Fill values below x_ref_min <==> x_idx_unclipped < 0.
      if _is_option(fill_value_below, 'constant_extension'):
        pass  # Already handled by the clipping that created x_idx_unclipped.
      elif _is_option(fill_value_below, 'extrapolate'):
        if batch_y_ref:
          # For every batch member, gather the first two elements of y across
          # `axis`.
          y_0 = tf.gather(y_ref, [0], axis=axis)
          y_1 = tf.gather(y_ref, [1], axis=axis)
        else:
          # If not batching, we want to gather the first two elements, just like
          # above.  However, these results need to be replicated for every
          # member of x.  An easy way to do that is to gather using
          # indices = zeros/ones(x.shape).
          y_0 = tf.gather(
              y_ref, tf.zeros(tf.shape(input=x), dtype=tf.int32), axis=axis)
          y_1 = tf.gather(
              y_ref, tf.ones(tf.shape(input=x), dtype=tf.int32), axis=axis)
        x_delta = (x_ref_max - x_ref_min) / (ny - 1)
        x_factor = expand_x_fn((x - x_ref_min) / x_delta, broadcast=True)
        y = tf.where(x_idx_unclipped < 0, y_0 + x_factor * (y_1 - y_0), y)
      else:
        y = tf.where(x_idx_unclipped < 0, fill_value_below + tf.zeros_like(y),
                     y)

      # Fill values above x_ref_max <==> x_idx_unclipped > ny - 1.
      if _is_option(fill_value_above, 'constant_extension'):
        pass  # Already handled by the clipping that created x_idx_unclipped.
      elif _is_option(fill_value_above, 'extrapolate'):
        ny_int32 = tf.shape(input=y_ref)[axis]
        if batch_y_ref:
          y_n1 = tf.gather(y_ref, [ny_int32 - 1], axis=axis)
          y_n2 = tf.gather(y_ref, [ny_int32 - 2], axis=axis)
        else:
          y_n1 = tf.gather(
              y_ref, tf.fill(tf.shape(input=x), ny_int32 - 1), axis=axis)
          y_n2 = tf.gather(
              y_ref, tf.fill(tf.shape(input=x), ny_int32 - 2), axis=axis)
        x_delta = (x_ref_max - x_ref_min) / (ny - 1)
        x_factor = expand_x_fn((x - x_ref_max) / x_delta, broadcast=True)
        y = tf.where(x_idx_unclipped > ny - 1,
                     y_n1 + x_factor * (y_n1 - y_n2), y)
      else:
        y = tf.where(x_idx_unclipped > ny - 1,
                     fill_value_above + tf.zeros_like(y), y)

    return y
def batch_interp_regular_nd_grid(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis,
                                 fill_value='constant_extension',
                                 name=None):
  """Multi-linear interpolation on a regular (constant spacing) grid.

  Given [a batch of] reference values, this function computes a multi-linear
  interpolant and evaluates it on [a batch of] of new `x` values.

  The interpolant is built from reference values indexed by `nd` dimensions
  of `y_ref`, starting at `axis`.

  For example, take the case of a `2-D` scalar valued function and no leading
  batch dimensions.  In this case, `y_ref.shape = [C1, C2]` and `y_ref[i, j]`
  is the reference value corresponding to grid point

  ```
  [x_ref_min[0] + i * (x_ref_max[0] - x_ref_min[0]) / (C1 - 1),
   x_ref_min[1] + j * (x_ref_max[1] - x_ref_min[1]) / (C2 - 1)]
  ```

  In the general case, dimensions to the left of `axis` in `y_ref` are broadcast
  with leading dimensions in `x`, `x_ref_min`, `x_ref_max`.

  Args:
    x: Numeric `Tensor` The x-coordinates of the interpolated output values for
      each batch.  Shape `[..., D, nd]`, designating [a batch of] `D`
      coordinates in `nd` space.  `D` must be `>= 1` and is not a batch dim.
    x_ref_min:  `Tensor` of same `dtype` as `x`.  The minimum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    x_ref_max:  `Tensor` of same `dtype` as `x`.  The maximum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    y_ref:  `Tensor` of same `dtype` as `x`.  The reference output values. Shape
      `[..., C1, ..., Cnd, B1,...,BM]`, designating [a batch of] reference
      values indexed by `nd` dimensions, of a shape `[B1,...,BM]` valued
      function (for `M >= 0`).
    axis:  Scalar integer `Tensor`.  Dimensions `[axis, axis + nd)` of `y_ref`
      index the interpolation table.  E.g. `3-D` interpolation of a scalar
      valued function requires `axis=-3` and a `3-D` matrix valued function
      requires `axis=-5`.
    fill_value:  Determines what values output should take for `x` values that
      are below `x_ref_min` or above `x_ref_max`.
      Scalar `Tensor` or "constant_extension" ==> Extend as constant function.
      Default value: `"constant_extension"`
    name:  A name to prepend to created ops.
      Default value: `"interp_regular_nd_grid"`.

  Returns:
    y_interp:  Interpolation between members of `y_ref`, at points `x`.
      `Tensor` of same `dtype` as `x`, and shape `[..., D, B1, ..., BM].`

  Raises:
    ValueError:  If `fill_value` is a string other than "constant_extension".
    ValueError:  If `rank(x) < 2` is determined statically.
    ValueError:  If `x_ref_min.shape[-1]` (i.e. `nd`) is not known statically.
    ValueError:  If `axis` is not a scalar is determined statically.
    ValueError:  If `axis + nd > rank(y_ref)` is determined statically.
    ValueError:  If the rank of `x`, after broadcasting against the batch dims
      of `x_ref_min`, `x_ref_max`, and `y_ref`, is not known statically.

  #### Examples

  Interpolate a function of one variable.

  ```python
  y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))

  tfp.math.batch_interp_regular_nd_grid(
      # x.shape = [3, 1], x_ref_min/max.shape = [1].  Trailing `1` for `1-D`.
      x=[[6.0], [0.5], [3.3]], x_ref_min=[0.], x_ref_max=[10.], y_ref=y_ref,
      axis=0)
  ==> approx [exp(6.0), exp(0.5), exp(3.3)]
  ```

  Interpolate a scalar function of two variables.

  ```python
  x_ref_min = [0., 0.]
  x_ref_max = [2 * np.pi, 2 * np.pi]

  # Build y_ref.
  x0s, x1s = tf.meshgrid(
      tf.linspace(x_ref_min[0], x_ref_max[0], num=100),
      tf.linspace(x_ref_min[1], x_ref_max[1], num=100),
      indexing='ij')

  def func(x0, x1):
    return tf.sin(x0) * tf.cos(x1)

  y_ref = func(x0s, x1s)

  x = np.pi * tf.random.uniform(shape=(10, 2))

  tfp.math.batch_interp_regular_nd_grid(x, x_ref_min, x_ref_max, y_ref, axis=-2)
  ==> tf.sin(x[:, 0]) * tf.cos(x[:, 1])
  ```
  """
  with tf.compat.v1.name_scope(
      name,
      default_name='interp_regular_nd_grid',
      values=[x, x_ref_min, x_ref_max, y_ref, fill_value]):
    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    preferred_dtype=tf.float32)

    # Arg checking.
    if isinstance(fill_value, str):
      if fill_value != 'constant_extension':
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fill_value, 'constant_extension'))
    else:
      fill_value = tf.convert_to_tensor(
          value=fill_value, name='fill_value', dtype=dtype)
      _assert_ndims_statically(fill_value, expect_ndims=0)

    # x.shape = [..., nd].
    x = tf.convert_to_tensor(value=x, name='x', dtype=dtype)
    _assert_ndims_statically(x, expect_ndims_at_least=2)

    # y_ref.shape = [..., C1,...,Cnd, B1,...,BM]
    y_ref = tf.convert_to_tensor(value=y_ref, name='y_ref', dtype=dtype)

    # x_ref_min.shape = [nd]
    x_ref_min = tf.convert_to_tensor(
        value=x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(
        value=x_ref_max, name='x_ref_max', dtype=dtype)
    _assert_ndims_statically(
        x_ref_min, expect_ndims_at_least=1, expect_static=True)
    _assert_ndims_statically(
        x_ref_max, expect_ndims_at_least=1, expect_static=True)

    # nd is the number of dimensions indexing the interpolation table, it's the
    # "nd" in the function name.
    nd = tf.compat.dimension_value(x_ref_min.shape[-1])
    if nd is None:
      raise ValueError('`x_ref_min.shape[-1]` must be known statically.')
    x_ref_max.shape[-1:].assert_is_compatible_with(x_ref_min.shape[-1:])

    # Convert axis and check it statically.
    axis = tf.convert_to_tensor(value=axis, dtype=tf.int32, name='axis')
    axis = distribution_util.make_non_negative_axis(axis, tf.rank(y_ref))
    axis.shape.assert_has_rank(0)
    axis_ = tf.get_static_value(axis)
    y_ref_rank_ = tf.get_static_value(tf.rank(y_ref))
    if axis_ is not None and y_ref_rank_ is not None:
      if axis_ + nd > y_ref_rank_:
        raise ValueError(
            'Since dims `[axis, axis + nd)` index the interpolation table, we '
            'must have `axis + nd <= rank(y_ref)`.  Found: '
            '`axis`: {},  rank(y_ref): {}, and inferred `nd` from trailing '
            'dimensions of `x_ref_min` to be {}.'.format(
                axis_, y_ref_rank_, nd))

    x_batch_shape = tf.shape(input=x)[:-2]
    x_ref_min_batch_shape = tf.shape(input=x_ref_min)[:-1]
    x_ref_max_batch_shape = tf.shape(input=x_ref_max)[:-1]
    y_ref_batch_shape = tf.shape(input=y_ref)[:axis]

    # Do a brute-force broadcast of batch dims (add zeros).
    batch_shape = y_ref_batch_shape
    for tensor in [
        x_batch_shape, x_ref_min_batch_shape, x_ref_max_batch_shape
    ]:
      batch_shape = tf.broadcast_dynamic_shape(batch_shape, tensor)

    def _batch_of_zeros_with_rightmost_singletons(n_singletons):
      """Return Tensor of zeros with some singletons on the rightmost dims."""
      ones = tf.ones(shape=[n_singletons], dtype=tf.int32)
      return tf.zeros(shape=tf.concat([batch_shape, ones], axis=0), dtype=dtype)

    x += _batch_of_zeros_with_rightmost_singletons(n_singletons=2)
    x_ref_min += _batch_of_zeros_with_rightmost_singletons(n_singletons=1)
    x_ref_max += _batch_of_zeros_with_rightmost_singletons(n_singletons=1)
    y_ref += _batch_of_zeros_with_rightmost_singletons(
        n_singletons=tf.rank(y_ref) - axis)

    # After broadcasting, every dim of `x` except the trailing `[D, nd]` is a
    # batch dim.  That count is passed on as a Python int, so the broadcast
    # rank must be static; fail with a clear message instead of `None - 2`.
    x_rank_ = tf.get_static_value(tf.rank(x))
    if x_rank_ is None:
      raise ValueError(
          'The rank of `x`, after broadcasting against the batch dims of '
          '`x_ref_min`, `x_ref_max`, and `y_ref`, must be known statically.')

    return _batch_interp_with_gather_nd(
        x=x,
        x_ref_min=x_ref_min,
        x_ref_max=x_ref_max,
        y_ref=y_ref,
        nd=nd,
        fill_value=fill_value,
        batch_dims=x_rank_ - 2)
def interp_regular_1d_grid(x,
                           x_ref_min,
                           x_ref_max,
                           y_ref,
                           axis=-1,
                           fill_value='constant_extension',
                           fill_value_below=None,
                           fill_value_above=None,
                           grid_regularizing_transform=None,
                           name=None):
  """Linear `1-D` interpolation on a regular (constant spacing) grid.

  Given reference values, this function computes a piecewise linear interpolant
  and evaluates it on a new set of `x` values.

  The interpolant is built from `M` reference values indexed by one dimension
  of `y_ref` (specified by the `axis` kwarg).

  If `y_ref` is a vector, then each value `y_ref[i]` is considered to be equal
  to `f(x_ref[i])`, for `M` (implicitly defined) reference values between
  `x_ref_min` and `x_ref_max`:

  ```none
  x_ref[i] = x_ref_min + i * (x_ref_max - x_ref_min) / (M - 1),
  i = 0, ..., M - 1.
  ```

  If `rank(y_ref) > 1`, then `y_ref` contains `M` reference values of a
  `rank(y_ref) - 1` rank tensor valued function of one variable.
  `x` is a `Tensor` of values of that variable (any shape allowed).

  Args:
    x: Numeric `Tensor` The x-coordinates of the interpolated output values.
    x_ref_min:  `Tensor` of same `dtype` as `x`.  The minimum value of the
      (implicitly defined) reference `x_ref`.
    x_ref_max:  `Tensor` of same `dtype` as `x`.  The maximum value of the
      (implicitly defined) reference `x_ref`.
    y_ref:  `N-D` `Tensor` (`N > 0`) of same `dtype` as `x`.
      The reference output values.
    axis:  Scalar `Tensor` designating the dimension of `y_ref` that indexes
      values of the interpolation variable.
      Default value: `-1`, the rightmost axis.
    fill_value:  Determines what values output should take for `x` values that
      are below `x_ref_min` or above `x_ref_max`.
      `Tensor` or one of the strings
        "constant_extension" ==> Extend as constant function.
        "extrapolate" ==> Extrapolate in a linear fashion.
      Default value: `"constant_extension"`
    fill_value_below:  Optional override of `fill_value` for `x < x_ref_min`.
    fill_value_above:  Optional override of `fill_value` for `x > x_ref_max`.
    grid_regularizing_transform:  Optional transformation `g` which regularizes
      the implied spacing of the x reference points.  In other words, if
      provided, we assume `g(x_ref_i)` is a regular grid between `g(x_ref_min)`
      and `g(x_ref_max)`.
    name:  A name to prepend to created ops.
      Default value: `"interp_regular_1d_grid"`.

  Returns:
    y_interp:  Interpolation between members of `y_ref`, at points `x`.
      `Tensor` of same `dtype` as `x`, and shape
      `y_ref.shape[:axis] + x.shape + y_ref.shape[axis + 1:]`

  Raises:
    ValueError:  If `fill_value` is not an allowed string.
    ValueError:  If `axis` is not a scalar.

  #### Examples

  Interpolate a function of one variable:

  ```python
  y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))

  interp_regular_1d_grid(
      x=[6.0, 0.5, 3.3], x_ref_min=0., x_ref_max=10., y_ref=y_ref)
  ==> approx [exp(6.0), exp(0.5), exp(3.3)]
  ```

  Interpolate a matrix-valued function of one variable:

  ```python
  mat_0 = [[1., 0.], [0., 1.]]
  mat_1 = [[0., -1], [1, 0]]
  y_ref = [mat_0, mat_1]

  # Get three output matrices at once.
  tfp.math.interp_regular_1d_grid(
      x=[0., 0.5, 1.], x_ref_min=0., x_ref_max=1., y_ref=y_ref, axis=0)
  ==> [mat_0, 0.5 * mat_0 + 0.5 * mat_1, mat_1]
  ```

  Interpolate a function of one variable on a log-spaced grid:

  ```python
  x_ref = tf.exp(tf.linspace(tf.math.log(1.), tf.math.log(100000.), num_pts))
  y_ref = tf.math.log(x_ref + x_ref**2)

  interp_regular_1d_grid(x=[1.1, 2.2], x_ref_min=1., x_ref_max=100000.,
                         y_ref=y_ref, grid_regularizing_transform=tf.math.log)
  ==> [tf.math.log(1.1 + 1.1**2), tf.math.log(2.2 + 2.2**2)]
  ```
  """
  with tf.compat.v1.name_scope(
      name,
      'interp_regular_1d_grid',
      values=[
          x, x_ref_min, x_ref_max, y_ref, axis, fill_value, fill_value_below,
          fill_value_above
      ]):

    def _is_option(fill, option):
      # Fill values may be Tensors/arrays, for which `==` is elementwise (or
      # raises). Only compare against the sentinel strings for Python strings.
      return isinstance(fill, str) and fill == option

    # Arg checking.
    allowed_fv_st = ('constant_extension', 'extrapolate')
    for fv in (fill_value, fill_value_below, fill_value_above):
      if isinstance(fv, str) and fv not in allowed_fv_st:
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fv, allowed_fv_st))

    # Separate value fills for below/above incurs extra cost, so keep track of
    # whether this is needed.
    need_separate_fills = (
        fill_value_above is not None or fill_value_below is not None or
        # 'extrapolate' always requires separate below/above handling.
        _is_option(fill_value, 'extrapolate'))
    if need_separate_fills and fill_value_above is None:
      fill_value_above = fill_value
    if need_separate_fills and fill_value_below is None:
      fill_value_below = fill_value

    axis = tf.convert_to_tensor(value=axis, name='axis', dtype=tf.int32)
    _assert_ndims_statically(axis, expect_ndims=0)
    axis = distribution_util.make_non_negative_axis(axis, tf.rank(y_ref))

    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    preferred_dtype=tf.float32)
    x = tf.convert_to_tensor(value=x, name='x', dtype=dtype)
    x_ref_min = tf.convert_to_tensor(
        value=x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(
        value=x_ref_max, name='x_ref_max', dtype=dtype)
    y_ref = tf.convert_to_tensor(value=y_ref, name='y_ref', dtype=dtype)

    ny = tf.cast(tf.shape(input=y_ref)[axis], dtype)

    # Map [x_ref_min, x_ref_max] to [0, ny - 1].
    # This is the (fractional) index of x.
    if grid_regularizing_transform is None:
      g = lambda x: x
    else:
      g = grid_regularizing_transform
    fractional_idx = ((g(x) - g(x_ref_min)) / (g(x_ref_max) - g(x_ref_min)))
    x_idx_unclipped = fractional_idx * (ny - 1)

    # Wherever x is NaN, x_idx_unclipped will be NaN as well.
    # Keep track of the nan indices here (so we can impute NaN later).
    # Also eliminate any NaN indices, since the indices are cast to int32
    # below and int32 has no NaN.
    nan_idx = tf.math.is_nan(x_idx_unclipped)
    x_idx_unclipped = tf.where(nan_idx, tf.zeros_like(x_idx_unclipped),
                               x_idx_unclipped)

    x_idx = tf.clip_by_value(x_idx_unclipped, tf.zeros((), dtype=dtype), ny - 1)

    # Get the index above and below x_idx.
    # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
    # however, this results in idx_below == idx_above whenever x is on a grid.
    # This in turn results in y_ref_below == y_ref_above, and then the gradient
    # at this point is zero.  So here we "jitter" one of idx_below, idx_above,
    # so that they are at different values.  This jittering does not affect the
    # interpolated value, but does make the gradient nonzero (unless of course
    # the y_ref values are the same).
    idx_below = tf.floor(x_idx)
    idx_above = tf.minimum(idx_below + 1, ny - 1)
    idx_below = tf.maximum(idx_above - 1, 0)

    # These are the values of y_ref corresponding to above/below indices.
    idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
    idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)
    y_ref_below = tf.gather(y_ref, idx_below_int32, axis=axis)
    y_ref_above = tf.gather(y_ref, idx_above_int32, axis=axis)

    # out_shape = y_ref.shape[:axis] + x.shape + y_ref.shape[axis + 1:]
    out_shape = tf.shape(input=y_ref_below)

    # Return a convex combination.
    t = x_idx - idx_below

    t = _expand_ends(t, out_shape, axis)

    y = t * y_ref_above + (1 - t) * y_ref_below

    # Now begins a long excursion to fill values outside [x_min, x_max].

    # Re-insert NaN wherever x was NaN.
    y = tf.where(
        _expand_ends(nan_idx, out_shape, axis, broadcast=True),
        tf.fill(tf.shape(input=y), tf.constant(np.nan, y.dtype)), y)

    x_idx_unclipped = _expand_ends(
        x_idx_unclipped, out_shape, axis, broadcast=True)

    if not need_separate_fills:
      if _is_option(fill_value, 'constant_extension'):
        pass  # Already handled by clipping x_idx_unclipped.
      else:
        y = tf.where((x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1),
                     fill_value + tf.zeros_like(y), y)
    else:
      # NOTE(review): the extrapolation branches below use the linear spacing
      # of x (not of g(x)), even when grid_regularizing_transform is given.

      # Fill values below x_ref_min <==> x_idx_unclipped < 0.
      if _is_option(fill_value_below, 'constant_extension'):
        pass  # Already handled by the clipping that created x_idx_unclipped.
      elif _is_option(fill_value_below, 'extrapolate'):
        y_0 = tf.gather(
            y_ref, tf.zeros(tf.shape(input=x), dtype=tf.int32), axis=axis)
        y_1 = tf.gather(
            y_ref, tf.ones(tf.shape(input=x), dtype=tf.int32), axis=axis)
        x_delta = (x_ref_max - x_ref_min) / (ny - 1)
        x_factor = (x - x_ref_min) / x_delta
        x_factor = _expand_ends(x_factor, out_shape, axis, broadcast=True)
        y = tf.where(x_idx_unclipped < 0, y_0 + x_factor * (y_1 - y_0), y)
      else:
        y = tf.where(x_idx_unclipped < 0, fill_value_below + tf.zeros_like(y),
                     y)

      # Fill values above x_ref_max <==> x_idx_unclipped > ny - 1.
      if _is_option(fill_value_above, 'constant_extension'):
        pass  # Already handled by the clipping that created x_idx_unclipped.
      elif _is_option(fill_value_above, 'extrapolate'):
        ny_int32 = tf.shape(input=y_ref)[axis]
        y_n1 = tf.gather(
            y_ref, tf.fill(tf.shape(input=x), ny_int32 - 1), axis=axis)
        y_n2 = tf.gather(
            y_ref, tf.fill(tf.shape(input=x), ny_int32 - 2), axis=axis)
        x_delta = (x_ref_max - x_ref_min) / (ny - 1)
        x_factor = (x - x_ref_max) / x_delta
        x_factor = _expand_ends(x_factor, out_shape, axis, broadcast=True)
        y = tf.where(x_idx_unclipped > ny - 1,
                     y_n1 + x_factor * (y_n1 - y_n2), y)
      else:
        y = tf.where(x_idx_unclipped > ny - 1,
                     fill_value_above + tf.zeros_like(y), y)

    return y