def _log_average_probs_process_args(
    logits, validate_args, sample_axis, event_axis):
  """Processes args for `log_average_probs`."""
  rank = ps.rank(logits)
  if sample_axis is None or validate_args:
    event_axis = ps.reshape(
        ps.non_negative_axis(event_axis, rank), shape=[-1])
  if sample_axis is None:
    sample_axis = ps.setdiff1d(ps.range(rank), event_axis)
  elif validate_args:
    sample_axis = ps.reshape(
        ps.non_negative_axis(sample_axis, rank), shape=[-1])
  return sample_axis, event_axis
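# Illustrative behavior of the helper above (a sketch only, not library code),
# assuming eager execution and a rank-3 `logits`:
#
#   logits = tf.zeros([4, 5, 3])
#   sample_axis, event_axis = _log_average_probs_process_args(
#       logits, validate_args=False, sample_axis=None, event_axis=-1)
#   # event_axis  ==> [2]     (negative axis resolved against rank 3)
#   # sample_axis ==> [0, 1]  (the complement of the event axis)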
def _make_static_axis_non_negative_list(axis, ndims):
  """Convert possibly negatively indexed axis to non-negative list of ints.

  Args:
    axis: Integer Tensor.
    ndims: Number of dimensions into which `axis` indexes.

  Returns:
    A list of non-negative Python integers.

  Raises:
    ValueError: If `axis` is not statically defined.
  """
  axis = prefer_static.non_negative_axis(axis, ndims)
  axis_const = tf.get_static_value(axis)
  if axis_const is None:
    raise ValueError(
        'Expected argument `axis` to be statically available.  Found: %s' %
        axis)
  # Make at least 1-D.
  axis = axis_const + np.zeros([1], dtype=axis_const.dtype)
  return list(int(dim) for dim in axis)
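# Example usage (a sketch, assuming the axis is statically known; not part of
# the library):
#
#   _make_static_axis_non_negative_list(axis=-1, ndims=3)       # ==> [2]
#   _make_static_axis_non_negative_list(axis=[0, -2], ndims=4)  # ==> [0, 2]
#
# If `axis` is a Tensor whose value is only known at graph execution time,
# `tf.get_static_value` returns `None` and the `ValueError` above is raised.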
def _effective_sample_size_single_state(states, filter_beyond_lag,
                                         filter_threshold,
                                         filter_beyond_positive_pairs,
                                         cross_chain_dims,
                                         validate_args):
  """ESS computation for one single Tensor argument."""
  with tf.name_scope('effective_sample_size_single_state'):

    states = tf.convert_to_tensor(states, name='states')
    dt = states.dtype

    # filter_beyond_lag == None ==> auto_cov is the full sequence.
    auto_cov = stats.auto_correlation(
        states, axis=0, max_lags=filter_beyond_lag, normalize=False)
    n = _axis_size(states, axis=0)

    if cross_chain_dims is not None:
      num_chains = _axis_size(states, cross_chain_dims)
      num_chains_ = tf.get_static_value(num_chains)

      assertions = []
      msg = ('When `cross_chain_dims` is not `None`, there must be > 1 chain '
             'in `states`.')
      if num_chains_ is not None:
        if num_chains_ < 2:
          raise ValueError(msg)
      elif validate_args:
        assertions.append(
            assert_util.assert_greater(num_chains, 1., message=msg))

      with tf.control_dependencies(assertions):
        # We're computing the R[k] from equation 10 of Vehtari et al. (2019):
        #
        #   R[k] := 1 - (W - 1/C * Sum_{c=1}^C s_c**2 R[k, c]) / (var^+),
        #
        # where:
        #   C := number of chains
        #   N := length of chains
        #   x_hat[c] := 1 / N Sum_{n=1}^N x[n, c], chain mean.
        #   x_hat := 1 / C Sum_{c=1}^C x_hat[c], overall mean.
        #   W := 1/C Sum_{c=1}^C s_c**2, within-chain variance.
        #   B := N / (C - 1) Sum_{c=1}^C (x_hat[c] - x_hat)**2, between-chain
        #     variance.
        #   s_c**2 := 1 / (N - 1) Sum_{n=1}^N (x[n, c] - x_hat[c])**2, chain
        #     variance.
        #   R[k, m] := auto_corr[k, m, ...], auto-correlation indexed by chain.
        #   var^+ := (N - 1) / N * W + B / N

        cross_chain_dims = prefer_static.non_negative_axis(
            cross_chain_dims, prefer_static.rank(states))
        # B / N
        between_chain_variance_div_n = _reduce_variance(
            tf.reduce_mean(states, axis=0),
            biased=False,  # This makes the denominator be C - 1.
            axis=cross_chain_dims - 1)
        # W * (N - 1) / N
        biased_within_chain_variance = tf.reduce_mean(auto_cov[0],
                                                      cross_chain_dims - 1)
        # var^+
        approx_variance = (
            biased_within_chain_variance + between_chain_variance_div_n)
        # 1/C * Sum_{c=1}^C s_c**2 R[k, c]
        mean_auto_cov = tf.reduce_mean(auto_cov, cross_chain_dims)
        auto_corr = 1. - (biased_within_chain_variance -
                          mean_auto_cov) / approx_variance
    else:
      auto_corr = auto_cov / auto_cov[:1]
      num_chains = 1

    # With R[k] := auto_corr[k, ...],
    # ESS = N / {1 + 2 * Sum_{k=1}^N R[k] * (N - k) / N}
    #     = N / {-1 + 2 * Sum_{k=0}^N R[k] * (N - k) / N} (since R[0] = 1)
    #     approx N / {-1 + 2 * Sum_{k=0}^M R[k] * (N - k) / N}
    # where M is the filter_beyond_lag truncation point chosen above.

    # Get the factor (N - k) / N, and give it shape [M, 1,...,1], having total
    # ndims the same as auto_corr.
    k = tf.range(0., _axis_size(auto_corr, axis=0))
    nk_factor = (n - k) / n
    if tensorshape_util.rank(auto_corr.shape) is not None:
      new_shape = [-1] + [1] * (tensorshape_util.rank(auto_corr.shape) - 1)
    else:
      new_shape = tf.concat(
          ([-1], tf.ones([tf.rank(auto_corr) - 1], dtype=tf.int32)),
          axis=0)
    nk_factor = tf.reshape(nk_factor, new_shape)
    weighted_auto_corr = nk_factor * auto_corr

    if filter_beyond_positive_pairs:
      def _sum_pairs(x):
        x_len = tf.shape(x)[0]
        # For odd sequences, we drop the final value.
        x = x[:x_len - x_len % 2]
        new_shape = tf.concat([[x_len // 2, 2], tf.shape(x)[1:]], axis=0)
        return tf.reduce_sum(tf.reshape(x, new_shape), 1)

      # Pairwise sums are all positive for auto-correlation spectra derived
      # from reversible MCMC chains.
      # E.g. imagine the pairwise sums are [0.2, 0.1, -0.1, -0.2]
      # Step 1: mask = [False, False, True, True]
      mask = _sum_pairs(auto_corr) < 0.
      # Step 2: mask = [0, 0, 1, 1]
      mask = tf.cast(mask, dt)
      # Step 3: mask = [0, 0, 1, 2]
      mask = tf.cumsum(mask, axis=0)
      # Step 4: mask = [1, 1, 0, 0]
      mask = tf.maximum(1. - mask, 0.)

      # N.B. this reduces the length of weighted_auto_corr by a factor of 2.
      # It still works fine in the formula below.
      weighted_auto_corr = _sum_pairs(weighted_auto_corr) * mask
    elif filter_threshold is not None:
      filter_threshold = tf.convert_to_tensor(
          filter_threshold, dtype=dt, name='filter_threshold')
      # Get a binary mask to zero out values of auto_corr below the threshold.
      #   mask[i, ...] = 1 if auto_corr[j, ...] > threshold for all j <= i,
      #   mask[i, ...] = 0, otherwise.
      # So, along dimension zero, the mask will look like [1, 1, ..., 0, 0,...]
      # Building step by step,
      #   Assume auto_corr = [1, 0.5, 0.0, 0.3], and filter_threshold = 0.2.
      # Step 1:  mask = [False, False, True, False]
      mask = auto_corr < filter_threshold
      # Step 2:  mask = [0, 0, 1, 0]
      mask = tf.cast(mask, dtype=dt)
      # Step 3:  mask = [0, 0, 1, 1]
      mask = tf.cumsum(mask, axis=0)
      # Step 4:  mask = [1, 1, 0, 0]
      mask = tf.maximum(1. - mask, 0.)

      weighted_auto_corr *= mask

    return num_chains * n / (-1 + 2 * tf.reduce_sum(weighted_auto_corr,
                                                    axis=0))
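# The four masking steps above can be checked with a small NumPy sketch
# (illustration only, not library code).  Suppose the pairwise sums of the
# auto-correlation sequence are [0.2, 0.1, -0.1, -0.2]:
#
#   import numpy as np
#   pair_sums = np.array([0.2, 0.1, -0.1, -0.2])
#   mask = (pair_sums < 0.).astype(np.float64)  # Step 2: [0., 0., 1., 1.]
#   mask = np.cumsum(mask, axis=0)              # Step 3: [0., 0., 1., 2.]
#   mask = np.maximum(1. - mask, 0.)            # Step 4: [1., 1., 0., 0.]
#
# Only the leading pairs with positive sums keep a nonzero weight, which is
# the truncation used by initial-positive-sequence style ESS estimators.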
def _log_loosum_exp_impl(logx, axis, keepdims, compute_mean):
  """Implementation for `*loosum*` functions."""
  with tf.name_scope('log_loosum_exp_impl'):
    logx = tf.convert_to_tensor(logx, name='logx')
    dtype = dtype_util.as_numpy_dtype(logx.dtype)
    if axis is not None:
      x = np.array(axis)
      axis = (tf.convert_to_tensor(axis, name='axis', dtype_hint=tf.int32)
              if x.dtype is np.object else x.astype(np.int32))

    log_sum_x = tf.reduce_logsumexp(logx, axis=axis, keepdims=True)

    # Later we'll want to compute the mean from a sum, so we calculate the
    # number of reduced elements, n.
    n = prefer_static.size(logx) // prefer_static.size(log_sum_x)
    n = prefer_static.cast(n, dtype)

    # log_loosum_x[i] =
    # = logsumexp(logx[j] : j != i)
    # = log( exp(logsumexp(logx)) - exp(logx[i]) )
    # = log( exp(logsumexp(logx - logx[i])) exp(logx[i])  - exp(logx[i]))
    # = logx[i] + log(exp(logsumexp(logx - logx[i])) - 1)
    # = logx[i] + log(exp(logsumexp(logx) - logx[i]) - 1)
    # = logx[i] + softplus_inverse(logsumexp(logx) - logx[i])
    d = log_sum_x - logx
    # We use `d != 0` rather than `d > 0.` because `d < 0.` should never
    # happen; if it does we want to complain loudly (which `softplus_inverse`
    # will).
    d_ok = tf.not_equal(d, 0.)
    safe_d = tf.where(d_ok, d, 1.)
    d_ok_result = logx + softplus_inverse(safe_d)

    neg_inf = tf.constant(-np.inf, dtype=dtype)

    # When not(d_ok) and is_positive_and_largest, we manually compute the
    # log_loosum_x.  (We can efficiently do this for any one point but not
    # all, hence we still need the above calculation.)  This is good because
    # when this condition is met, we cannot use the above calculation; it's
    # -inf.
    # We now compute the log-leave-out-max-sum, replicate it to every point,
    # and make sure to select it only when we need to.
    max_logx = tf.reduce_max(logx, axis=axis, keepdims=True)
    is_positive_and_largest = (logx > 0.) & tf.equal(logx, max_logx)
    log_lomsum_x = tf.reduce_logsumexp(
        tf.where(is_positive_and_largest, neg_inf, logx),
        axis=axis,
        keepdims=True)
    d_not_ok_result = tf.where(is_positive_and_largest, log_lomsum_x, neg_inf)

    log_loosum_x = tf.where(d_ok, d_ok_result, d_not_ok_result)

    # We now squeeze log_sum_x as if we had used `keepdims=False`.
    # TODO(b/136176077): These mental gymnastics could all be replaced with
    #   `tf.squeeze(log_sum_x, axis)` if tf.squeeze supported Tensor valued
    #   `axis` arguments.
    if not keepdims:
      if axis is None:
        keepdims = np.array([], dtype=np.int32)
      else:
        rank = prefer_static.rank(logx)
        keepdims = prefer_static.setdiff1d(
            prefer_static.range(rank),
            prefer_static.non_negative_axis(axis, rank))
      squeeze_shape = tf.gather(prefer_static.shape(logx), indices=keepdims)
      log_sum_x = tf.reshape(log_sum_x, shape=squeeze_shape)
      if prefer_static.is_numpy(keepdims):
        tensorshape_util.set_shape(log_sum_x, np.array(logx.shape)[keepdims])

    # Set static shapes just in case we lost them.
    tensorshape_util.set_shape(n, [])
    tensorshape_util.set_shape(log_loosum_x, logx.shape)

    if not compute_mean:
      return log_loosum_x, log_sum_x, n

    log_nm1 = prefer_static.log(max(1., n - 1.))
    log_n = prefer_static.log(n)
    return log_loosum_x - log_nm1, log_sum_x - log_n, n
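# The leave-one-out identity used above can be sanity-checked with a small
# NumPy/SciPy sketch (illustration only, not library code):
#
#   import numpy as np
#   from scipy.special import logsumexp
#   logx = np.log([1., 2., 3.])
#   # Direct leave-one-out logsumexp: reduce over all j != i.
#   direct = np.array([logsumexp(np.delete(logx, i)) for i in range(3)])
#   # Via the identity logx[i] + softplus_inverse(logsumexp(logx) - logx[i]),
#   # where softplus_inverse(y) = log(expm1(y)).
#   via_identity = logx + np.log(np.expm1(logsumexp(logx) - logx))
#   np.testing.assert_allclose(direct, via_identity)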
def test_dynamic_vector_index(self):
  axis = tf.Variable([0, -2])
  positive_axis = ps.non_negative_axis(axis=axis, rank=4)
  self.evaluate(axis.initializer)
  self.assertAllEqual([0, 2], self.evaluate(positive_axis))
def test_static_vector_index(self):
  positive_axis = ps.non_negative_axis(axis=[0, -2], rank=4)
  self.assertAllEqual([0, 2], positive_axis)
def test_static_scalar_negative_index(self):
  positive_axis = ps.non_negative_axis(axis=-1, rank=4)
  self.assertAllEqual(3, positive_axis)
def batch_interp_regular_nd_grid(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis,
                                 fill_value='constant_extension',
                                 name=None):
  """Multi-linear interpolation on a regular (constant spacing) grid.

  Given [a batch of] reference values, this function computes a multi-linear
  interpolant and evaluates it on [a batch of] new `x` values.

  The interpolant is built from reference values indexed by `nd` dimensions
  of `y_ref`, starting at `axis`.

  For example, take the case of a `2-D` scalar valued function and no leading
  batch dimensions.  In this case, `y_ref.shape = [C1, C2]` and `y_ref[i, j]`
  is the reference value corresponding to grid point

  ```
  [x_ref_min[0] + i * (x_ref_max[0] - x_ref_min[0]) / (C1 - 1),
   x_ref_min[1] + j * (x_ref_max[1] - x_ref_min[1]) / (C2 - 1)]
  ```

  In the general case, dimensions to the left of `axis` in `y_ref` are
  broadcast with leading dimensions in `x`, `x_ref_min`, `x_ref_max`.

  Args:
    x: Numeric `Tensor`.  The x-coordinates of the interpolated output values
      for each batch.  Shape `[..., D, nd]`, designating [a batch of] `D`
      coordinates in `nd` space.  `D` must be `>= 1` and is not a batch dim.
    x_ref_min:  `Tensor` of same `dtype` as `x`.  The minimum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    x_ref_max:  `Tensor` of same `dtype` as `x`.  The maximum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    y_ref:  `Tensor` of same `dtype` as `x`.  The reference output values.
      Shape `[..., C1, ..., Cnd, B1,...,BM]`, designating [a batch of]
      reference values indexed by `nd` dimensions, of a shape `[B1,...,BM]`
      valued function (for `M >= 0`).
    axis:  Scalar integer `Tensor`.  Dimensions `[axis, axis + nd)` of `y_ref`
      index the interpolation table.  E.g. `3-D` interpolation of a scalar
      valued function requires `axis=-3` and a `3-D` matrix valued function
      requires `axis=-5`.
    fill_value:  Determines what values output should take for `x` values that
      are below `x_ref_min` or above `x_ref_max`.  Scalar `Tensor` or
      'constant_extension' ==> Extend as constant function.
      Default value: `'constant_extension'`
    name:  A name to prepend to created ops.
      Default value: `'batch_interp_regular_nd_grid'`.

  Returns:
    y_interp:  Interpolation between members of `y_ref`, at points `x`.
      `Tensor` of same `dtype` as `x`, and shape `[..., D, B1, ..., BM].`

  Raises:
    ValueError:  If `rank(x) < 2` is determined statically.
    ValueError:  If `axis` is not a scalar, determined statically.
    ValueError:  If `axis + nd > rank(y_ref)` is determined statically.

  #### Examples

  Interpolate a function of one variable.

  ```python
  y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))

  tfp.math.batch_interp_regular_nd_grid(
      # x.shape = [3, 1], x_ref_min/max.shape = [1].  Trailing `1` for `1-D`.
      x=[[6.0], [0.5], [3.3]], x_ref_min=[0.], x_ref_max=[10.], y_ref=y_ref,
      axis=0)
  ==> approx [exp(6.0), exp(0.5), exp(3.3)]
  ```

  Interpolate a scalar function of two variables.

  ```python
  x_ref_min = [0., 0.]
  x_ref_max = [2 * np.pi, 2 * np.pi]

  # Build y_ref.
  x0s, x1s = tf.meshgrid(
      tf.linspace(x_ref_min[0], x_ref_max[0], num=100),
      tf.linspace(x_ref_min[1], x_ref_max[1], num=100),
      indexing='ij')

  def func(x0, x1):
    return tf.sin(x0) * tf.cos(x1)

  y_ref = func(x0s, x1s)

  x = np.pi * tf.random.uniform(shape=(10, 2))

  tfp.math.batch_interp_regular_nd_grid(x, x_ref_min, x_ref_max, y_ref,
                                         axis=-2)
  ==> tf.sin(x[:, 0]) * tf.cos(x[:, 1])
  ```
  """
  with tf.name_scope(name or 'interp_regular_nd_grid'):
    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    dtype_hint=tf.float32)

    # Arg checking.
    if isinstance(fill_value, str):
      if fill_value != 'constant_extension':
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fill_value, 'constant_extension'))
    else:
      fill_value = tf.convert_to_tensor(
          fill_value, name='fill_value', dtype=dtype)
      _assert_ndims_statically(fill_value, expect_ndims=0)

    # x.shape = [..., nd].
    x = tf.convert_to_tensor(x, name='x', dtype=dtype)
    _assert_ndims_statically(x, expect_ndims_at_least=2)

    # y_ref.shape = [..., C1,...,Cnd, B1,...,BM]
    y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

    # x_ref_min.shape = [nd]
    x_ref_min = tf.convert_to_tensor(x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(x_ref_max, name='x_ref_max', dtype=dtype)
    _assert_ndims_statically(
        x_ref_min, expect_ndims_at_least=1, expect_static=True)
    _assert_ndims_statically(
        x_ref_max, expect_ndims_at_least=1, expect_static=True)

    # nd is the number of dimensions indexing the interpolation table; it's
    # the 'nd' in the function name.
    nd = tf.compat.dimension_value(x_ref_min.shape[-1])
    if nd is None:
      raise ValueError('`x_ref_min.shape[-1]` must be known statically.')
    tensorshape_util.assert_is_compatible_with(
        x_ref_max.shape[-1:], x_ref_min.shape[-1:])

    # Convert axis and check it statically.
    axis = tf.convert_to_tensor(axis, dtype=tf.int32, name='axis')
    axis = ps.non_negative_axis(axis, tf.rank(y_ref))
    tensorshape_util.assert_has_rank(axis.shape, 0)
    axis_ = tf.get_static_value(axis)
    y_ref_rank_ = tf.get_static_value(tf.rank(y_ref))
    if axis_ is not None and y_ref_rank_ is not None:
      if axis_ + nd > y_ref_rank_:
        raise ValueError(
            'Since dims `[axis, axis + nd)` index the interpolation table, we '
            'must have `axis + nd <= rank(y_ref)`.  Found: '
            '`axis`: {}, rank(y_ref): {}, and inferred `nd` from trailing '
            'dimensions of `x_ref_min` to be {}.'.format(
                axis_, y_ref_rank_, nd))

    x_batch_shape = ps.shape(x)[:-2]
    x_ref_min_batch_shape = ps.shape(x_ref_min)[:-1]
    x_ref_max_batch_shape = ps.shape(x_ref_max)[:-1]
    y_ref_batch_shape = ps.shape(y_ref)[:axis]

    # Do a brute-force broadcast of batch dims (add zeros).
    batch_shape = y_ref_batch_shape
    for tensor in [x_batch_shape, x_ref_min_batch_shape,
                   x_ref_max_batch_shape]:
      batch_shape = ps.broadcast_shape(batch_shape, tensor)

    def _batch_shape_of_zeros_with_rightmost_singletons(n_singletons):
      """Return Tensor of zeros with some singletons on the rightmost dims."""
      ones = ps.ones(shape=[n_singletons], dtype=tf.int32)
      return ps.concat([batch_shape, ones], axis=0)

    x = _broadcast_with(
        x, _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=2))
    x_ref_min = _broadcast_with(
        x_ref_min,
        _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=1))
    x_ref_max = _broadcast_with(
        x_ref_max,
        _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=1))
    y_ref = _broadcast_with(
        y_ref,
        _batch_shape_of_zeros_with_rightmost_singletons(
            n_singletons=tf.rank(y_ref) - axis))

    return _batch_interp_with_gather_nd(
        x=x,
        x_ref_min=x_ref_min,
        x_ref_max=x_ref_max,
        y_ref=y_ref,
        nd=nd,
        fill_value=fill_value,
        batch_dims=ps.rank(x) - 2)
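# Shape bookkeeping sketch for `batch_interp_regular_nd_grid` (illustration
# only; the shapes below are hypothetical).  For a batched 2-D scalar-valued
# interpolation with
#
#   x.shape         = [7, 10, 2]   # batch 7, D = 10 points, nd = 2
#   x_ref_min.shape = [2]
#   x_ref_max.shape = [2]
#   y_ref.shape     = [7, 30, 40]  # batch 7, C1 = 30, C2 = 40
#   axis            = -2           # resolves to 1; dims [1, 3) index the table
#
# the broadcast batch shape is [7] and the result has shape [7, 10]
# (`[..., D, B1, ..., BM]` with M = 0).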
def _interp_regular_1d_grid_impl(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis=-1,
                                 batch_y_ref=False,
                                 fill_value='constant_extension',
                                 fill_value_below=None,
                                 fill_value_above=None,
                                 grid_regularizing_transform=None,
                                 name=None):
  """1-D interpolation that works with/without batching."""
  # Note: we do *not* make the no-batch version a special case of the batch
  # version, because that would be an inefficient use of batch_gather with
  # unnecessarily broadcast args.
  with tf.name_scope(name or 'interp_regular_1d_grid_impl'):

    # Arg checking.
    allowed_fv_st = ('constant_extension', 'extrapolate')
    for fv in (fill_value, fill_value_below, fill_value_above):
      if isinstance(fv, str) and fv not in allowed_fv_st:
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fv, allowed_fv_st))

    # Separate value fills for below/above incur extra cost, so keep track of
    # whether this is needed.
    need_separate_fills = (
        fill_value_above is not None or fill_value_below is not None or
        fill_value == 'extrapolate'  # always requires separate below/above
    )
    if need_separate_fills and fill_value_above is None:
      fill_value_above = fill_value
    if need_separate_fills and fill_value_below is None:
      fill_value_below = fill_value

    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    dtype_hint=tf.float32)
    x = tf.convert_to_tensor(x, name='x', dtype=dtype)

    x_ref_min = tf.convert_to_tensor(x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(x_ref_max, name='x_ref_max', dtype=dtype)
    if not batch_y_ref:
      _assert_ndims_statically(x_ref_min, expect_ndims=0)
      _assert_ndims_statically(x_ref_max, expect_ndims=0)

    y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

    if batch_y_ref:
      # If we're batching,
      #   x.shape ~ [A1,...,AN, D],  x_ref_min/max.shape ~ [A1,...,AN]
      # So to add together we'll append a singleton.
      # If not batching, x_ref_min/max are scalar, so this isn't an issue;
      # moreover, if not batching, x can be scalar, and expanding x_ref_min/max
      # would cause a bad expansion of x when added to x (confused yet?).
      x_ref_min = x_ref_min[..., tf.newaxis]
      x_ref_max = x_ref_max[..., tf.newaxis]

    axis = tf.convert_to_tensor(axis, name='axis', dtype=tf.int32)
    axis = ps.non_negative_axis(axis, tf.rank(y_ref))
    _assert_ndims_statically(axis, expect_ndims=0)

    ny = tf.cast(tf.shape(y_ref)[axis], dtype)

    # Map [x_ref_min, x_ref_max] to [0, ny - 1].
    # This is the (fractional) index of x.
    if grid_regularizing_transform is None:
      g = lambda x: x
    else:
      g = grid_regularizing_transform

    fractional_idx = ((g(x) - g(x_ref_min)) / (g(x_ref_max) - g(x_ref_min)))
    x_idx_unclipped = fractional_idx * (ny - 1)

    # Wherever x is NaN, x_idx_unclipped will be NaN as well.
    # Keep track of the nan indices here (so we can impute NaN later).
    # Also eliminate any NaN indices, since there is no NaN-valued integer
    # index.
    nan_idx = tf.math.is_nan(x_idx_unclipped)
    zero = tf.zeros((), dtype=dtype)
    x_idx_unclipped = tf.where(nan_idx, zero, x_idx_unclipped)

    x_idx = tf.clip_by_value(x_idx_unclipped, zero, ny - 1)

    # Get the index above and below x_idx.
    # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
    # however, this results in idx_below == idx_above whenever x is on a grid.
    # This in turn results in y_ref_below == y_ref_above, and then the gradient
    # at this point is zero.  So here we 'jitter' one of idx_below, idx_above,
    # so that they are at different values.  This jittering does not affect the
    # interpolated value, but does make the gradient nonzero (unless of course
    # the y_ref values are the same).
    idx_below = tf.floor(x_idx)
    idx_above = tf.minimum(idx_below + 1, ny - 1)
    idx_below = tf.maximum(idx_above - 1, 0)

    # These are the values of y_ref corresponding to above/below indices.
    idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
    idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)
    if batch_y_ref:
      # If y_ref.shape ~ [A1,...,AN, C, B1,...,BN],
      # and x.shape, x_ref_min/max.shape ~ [A1,...,AN, D]
      # Then y_ref_below.shape ~ [A1,...,AN, D, B1,...,BN]
      y_ref_below = _batch_gather_with_broadcast(y_ref, idx_below_int32, axis)
      y_ref_above = _batch_gather_with_broadcast(y_ref, idx_above_int32, axis)
    else:
      # Here, y_ref_below.shape =
      #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis + 1:]
      y_ref_below = tf.gather(y_ref, idx_below_int32, axis=axis)
      y_ref_above = tf.gather(y_ref, idx_above_int32, axis=axis)

    # Use t to get a convex combination of the below/above values.
    t = x_idx - idx_below

    # x, and tensors shaped like x, need to be added to, and selected with
    # (using tf.where) the output y.  This requires appending singletons.
    # Make functions appropriate for batch/no-batch.
    if batch_y_ref:
      # In the batch case, the output shape is going to be
      #   Broadcast(y_ref.shape[:axis], x.shape[:-1]) +
      #   x.shape[-1:] + y_ref.shape[axis+1:]
      expand_x_fn = _make_expand_x_fn_for_batch_interpolation(y_ref, axis)
    else:
      # In the non-batch case, the output shape is going to be
      #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis+1:]
      expand_x_fn = _make_expand_x_fn_for_non_batch_interpolation(y_ref, axis)

    t = expand_x_fn(t)
    nan_idx = expand_x_fn(nan_idx, broadcast=True)
    x_idx_unclipped = expand_x_fn(x_idx_unclipped, broadcast=True)

    y = t * y_ref_above + (1 - t) * y_ref_below

    # Now begins a long excursion to fill values outside [x_min, x_max].

    # Re-insert NaN wherever x was NaN.
    y = tf.where(nan_idx, tf.constant(np.nan, y.dtype), y)

    if not need_separate_fills:
      if fill_value == 'constant_extension':
        pass  # Already handled by clipping x_idx_unclipped.
      else:
        y = tf.where((x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1),
                     fill_value, y)
    else:
      # Fill values below x_ref_min <==> x_idx_unclipped < 0.
      if fill_value_below == 'constant_extension':
        pass  # Already handled by the clipping that created x_idx_unclipped.
      elif fill_value_below == 'extrapolate':
        if batch_y_ref:
          # For every batch member, gather the first two elements of y across
          # `axis`.
          y_0 = tf.gather(y_ref, [0], axis=axis)
          y_1 = tf.gather(y_ref, [1], axis=axis)
        else:
          # If not batching, we want to gather the first two elements, just
          # like above.  However, these results need to be replicated for
          # every member of x.  An easy way to do that is to gather using
          # indices = zeros/ones(x.shape).
          y_0 = tf.gather(y_ref, tf.zeros(tf.shape(x), dtype=tf.int32),
                          axis=axis)
          y_1 = tf.gather(y_ref, tf.ones(tf.shape(x), dtype=tf.int32),
                          axis=axis)
        x_delta = (x_ref_max - x_ref_min) / (ny - 1)
        x_factor = expand_x_fn((x - x_ref_min) / x_delta, broadcast=True)
        y = tf.where(x_idx_unclipped < 0, y_0 + x_factor * (y_1 - y_0), y)
      else:
        y = tf.where(x_idx_unclipped < 0, fill_value_below, y)
      # Fill values above x_ref_max <==> x_idx_unclipped > ny - 1.
      if fill_value_above == 'constant_extension':
        pass  # Already handled by the clipping that created x_idx_unclipped.
      elif fill_value_above == 'extrapolate':
        ny_int32 = tf.shape(y_ref)[axis]
        if batch_y_ref:
          y_n1 = tf.gather(y_ref, [tf.shape(y_ref)[axis] - 1], axis=axis)
          y_n2 = tf.gather(y_ref, [tf.shape(y_ref)[axis] - 2], axis=axis)
        else:
          y_n1 = tf.gather(y_ref, tf.fill(tf.shape(x), ny_int32 - 1),
                           axis=axis)
          y_n2 = tf.gather(y_ref, tf.fill(tf.shape(x), ny_int32 - 2),
                           axis=axis)
        x_delta = (x_ref_max - x_ref_min) / (ny - 1)
        x_factor = expand_x_fn((x - x_ref_max) / x_delta, broadcast=True)
        y = tf.where(x_idx_unclipped > ny - 1,
                     y_n1 + x_factor * (y_n1 - y_n2), y)
      else:
        y = tf.where(x_idx_unclipped > ny - 1, fill_value_above, y)

    return y
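# A small sketch (illustration only) of the 'extrapolate' fill handled above.
# With reference values sampled from the line y(x) = x on [0, 10], linear
# extrapolation recovers the line outside the reference interval:
#
#   y_ref = tf.linspace(0., 10., num=11)
#   y = _interp_regular_1d_grid_impl(
#       x=[-2., 12.], x_ref_min=0., x_ref_max=10., y_ref=y_ref,
#       fill_value='extrapolate')
#   # ==> approximately [-2., 12.]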
def test_static_scalar_positive_index(self):
  positive_axis = prefer_static.non_negative_axis(axis=2, rank=4)
  self.assertAllEqual(2, positive_axis)