def _batch_gather_with_broadcast(params, indices, axis):
  """Like batch_gather, but broadcasts to the left of axis."""
  # batch_gather assumes...
  #   params.shape =  [A1,...,AN, B1,...,BM]
  #   indices.shape = [A1,...,AN, C]
  # and gathers along dim B1, which gives output of shape
  #                   [A1,...,AN, C, B2,...,BM]
  # Here we broadcast dims of each to the left of `axis` in params, and left of
  # the rightmost dim in indices, e.g. we can have
  #   params.shape =  [A1,...,AN, B1,...,BM]
  #   indices.shape = [a1,...,aN, C],
  # where ai broadcasts with Ai.

  # leading_bcast_shape is the broadcast of [A1,...,AN] and [a1,...,aN].
  leading_bcast_shape = ps.broadcast_shape(
      ps.shape_slice(params, np.s_[:axis]),
      ps.shape_slice(indices, np.s_[:-1]))
  params = _broadcast_with(
      params,
      ps.concat((leading_bcast_shape, ps.shape_slice(params, np.s_[axis:])),
                axis=0))
  indices = _broadcast_with(
      indices,
      ps.concat((leading_bcast_shape, ps.shape_slice(indices, np.s_[-1:])),
                axis=0))
  return tf.gather(
      params, indices, batch_dims=tensorshape_util.rank(indices.shape) - 1)
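
# A minimal usage sketch of `_batch_gather_with_broadcast` (illustrative only;
# the shapes below are hypothetical and the module-level `tf` import is
# assumed).  With `axis=1`, `params` has one leading batch dim that the size-1
# leading dim of `indices` broadcasts against.
def _sketch_batch_gather_with_broadcast():
  params = tf.reshape(tf.range(2 * 5 * 3), [2, 5, 3])  # [A1=2, 5, B=3]
  indices = tf.constant([[0, 4, 2]])    # [a1=1, C=3]; a1 broadcasts with A1.
  out = _batch_gather_with_broadcast(params, indices, axis=1)
  return out  # shape [2, 3, 3] == [A1, C, B]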
def expand_right_dims(x, broadcast=False):
  """Expand `x` so it can broadcast with tensors of the output shape."""
  # Note: `y_ref_shape_left` and `y_ref_shape_right` are assumed to be defined
  # in the enclosing scope.
  x_shape_left = ps.shape_slice(x, np.s_[:-1])
  x_shape_right = ps.shape_slice(x, np.s_[-1:])
  expanded_shape_left = ps.broadcast_shape(
      x_shape_left, ps.ones([ps.size(y_ref_shape_left)], dtype=tf.int32))
  expanded_shape = ps.concat(
      (expanded_shape_left, x_shape_right,
       ps.ones([ps.size(y_ref_shape_right)], dtype=tf.int32)),
      axis=0)
  x_expanded = tf.reshape(x, expanded_shape)
  if broadcast:
    broadcast_shape_left = ps.broadcast_shape(x_shape_left, y_ref_shape_left)
    broadcast_shape = ps.concat(
        (broadcast_shape_left, x_shape_right, y_ref_shape_right), axis=0)
    x_expanded = _broadcast_with(x_expanded, broadcast_shape)
  return x_expanded
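
# Illustrative sketch of the shape arithmetic in `expand_right_dims`
# (hypothetical values; `y_ref_shape_left` / `y_ref_shape_right` come from the
# enclosing scope in real use).  With x.shape = [2, 3],
# y_ref_shape_left = [4, 2] and y_ref_shape_right = [5]:
#   expanded_shape_left = broadcast([2], [1, 1]) = [1, 2]
#   expanded_shape      = [1, 2] + [3] + [1]     = [1, 2, 3, 1]
# and with `broadcast=True` the result is further broadcast to
#   [4, 2] + [3] + [5]                           = [4, 2, 3, 5].
def _sketch_expand_right_dims():
  x = tf.zeros([2, 3])
  x_expanded = tf.reshape(x, [1, 2, 3, 1])                # as computed above
  x_broadcast = tf.broadcast_to(x_expanded, [4, 2, 3, 5])
  return x_broadcast  # shape [4, 2, 3, 5]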
def test_shape_slice(self):
  shape = [3, 2, 1]
  slice_ = slice(1, 2)
  slice_tensor = slice(tf.constant(1), tf.constant(2))

  # Case: numpy input.
  self.assertEqual(ps.shape_slice(np.zeros(shape), slice_), shape[slice_])
  self.assertEqual(
      ps.shape_slice(np.zeros(shape), slice_tensor), shape[slice_])

  # Case: static-shape Tensor input.
  self.assertEqual(ps.shape_slice(tf.zeros(shape), slice_), shape[slice_])
  self.assertNotIsInstance(
      ps.shape_slice(tf.zeros(shape), slice_), tf.TensorShape)
  self.assertEqual(
      ps.shape_slice(tf.zeros(shape), slice_tensor), shape[slice_])
  self.assertNotIsInstance(
      ps.shape_slice(tf.zeros(shape), slice_tensor), tf1.Dimension)

  if tf.executing_eagerly():
    return

  # Case: input is a Tensor with fully unknown shape.
  zeros_pl = tf1.placeholder_with_default(tf.zeros(shape), shape=None)
  slice_pl = slice(tf1.placeholder_with_default(1, shape=[]),
                   tf1.placeholder_with_default(2, shape=[]))
  self.assertAllEqual(ps.shape_slice(zeros_pl, slice_), shape[slice_])
  self.assertAllEqual(ps.shape_slice(zeros_pl, slice_tensor), shape[slice_])
  self.assertAllEqual(ps.shape_slice(zeros_pl, slice_pl), shape[slice_])

  # Case: input is a Tensor with partially known shape.
  # The result should be static if slice_ is.
  zeros_partial_pl = tf1.placeholder_with_default(
      tf.zeros(shape), shape=tf.TensorShape([None, 2, 1]))
  self.assertEqual(ps.shape_slice(zeros_partial_pl, slice_), shape[slice_])
  self.assertEqual(
      ps.shape_slice(zeros_partial_pl, slice_tensor), shape[slice_])
  self.assertAllEqual(
      ps.shape_slice(zeros_partial_pl, slice_pl), shape[slice_])
def _batch_interp_with_gather_nd(x, x_ref_min, x_ref_max, y_ref, nd,
                                 fill_value, batch_dims):
  """N-D interpolation that works with leading batch dims."""
  dtype = x.dtype

  # In this function,
  #   x.shape = [A1, ..., An, D, nd], where n = batch_dims
  # and
  #   y_ref.shape = [A1, ..., An, C1, C2,..., Cnd, B1,...,BM]
  # y_ref[A1, ..., An, i1,...,ind] is a shape [B1,...,BM] Tensor with the value
  # at index [i1,...,ind] in the interpolation table.
  # x_ref_min and x_ref_max have shapes [A1, ..., An, nd].

  # ny[k] is the number of y reference points in interp dim k.
  ny = tf.cast(ps.shape_slice(y_ref, np.s_[batch_dims:batch_dims + nd]), dtype)

  # Map [x_ref_min, x_ref_max] to [0, ny - 1].
  # This is the (fractional) index of x.
  # x_idx_unclipped[A1, ..., An, d, k] is the fractional index into dim k of
  # the interpolation table for the dth x value.
  x_ref_min_expanded = tf.expand_dims(x_ref_min, axis=-2)
  x_ref_max_expanded = tf.expand_dims(x_ref_max, axis=-2)
  x_idx_unclipped = (ny - 1) * (x - x_ref_min_expanded) / (
      x_ref_max_expanded - x_ref_min_expanded)

  # Wherever x is NaN, x_idx_unclipped will be NaN as well.
  # Keep track of the NaN indices here (so we can impute NaN later).
  # Also eliminate any NaN indices, since there is no NaN in 32bit integers.
  nan_idx = tf.math.is_nan(x_idx_unclipped)
  x_idx_unclipped = tf.where(nan_idx, tf.cast(0., dtype=dtype),
                             x_idx_unclipped)

  # x_idx.shape = [A1, ..., An, D, nd]
  x_idx = tf.clip_by_value(x_idx_unclipped, tf.zeros((), dtype=dtype), ny - 1)

  # Get the index above and below x_idx.
  # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
  # however, this results in idx_below == idx_above whenever x is on a grid.
  # This in turn results in y_ref_below == y_ref_above, and then the gradient
  # at this point is zero.  So here we 'jitter' one of idx_below, idx_above,
  # so that they are at different values.  This jittering does not affect the
  # interpolated value, but does make the gradient nonzero (unless of course
  # the y_ref values are the same).
  idx_below = tf.floor(x_idx)
  idx_above = tf.minimum(idx_below + 1, ny - 1)
  idx_below = tf.maximum(idx_above - 1, 0)

  # These are the values of y_ref corresponding to above/below indices.
  # idx_below_int32.shape = x.shape[:-1] + [nd]
  idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
  idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)

  # idx_below_list is a length nd list of shape x.shape[:-1] int32 tensors.
  idx_below_list = tf.unstack(idx_below_int32, axis=-1)
  idx_above_list = tf.unstack(idx_above_int32, axis=-1)

  # Use t to get a convex combination of the below/above values.
  # t.shape = [A1, ..., An, D, nd]
  t = x_idx - idx_below

  # x, and tensors shaped like x, need to be added to, and selected with
  # (using tf.where) the output y.  This requires appending singletons.
  def _expand_x_fn(tensor):
    # Reshape tensor to tensor.shape + [1] * M.
    extended_shape = ps.concat(
        [
            ps.shape(tensor),
            ps.ones_like(
                ps.convert_to_shape_tensor(
                    ps.shape_slice(y_ref, np.s_[batch_dims + nd:])))
        ],
        axis=0,
    )
    return tf.reshape(tensor, extended_shape)

  # Now, t.shape = [A1, ..., An, D, nd] + [1] * (rank(y_ref) - nd - batch_dims)
  t = _expand_x_fn(t)
  s = 1 - t

  # Re-insert NaN wherever x was NaN.
  nan_idx = _expand_x_fn(nan_idx)
  t = tf.where(nan_idx, tf.constant(np.nan, dtype), t)

  terms = []
  # Our work above has located x's fractional index inside a cube of
  # above/below indices.  The distance to the below indices is t, and to the
  # above indices is s.  Drawing lines from x to the cube walls, we get 2**nd
  # smaller cubes.
  # Each term in the result is a product of a reference point, gathered from
  # y_ref, multiplied by a volume.  The volume is that of the cube opposite to
  # the reference point.  E.g. if the reference point is below x in every axis,
  # the volume is that of the cube with corner above x in every axis,
  # s[0]*...*s[nd-1].
  # We could probably do this with one massive gather, but that would be very
  # unreadable and un-debuggable.  It also would create a large Tensor.
  for zero_ones_list in _binary_count(nd):
    gather_from_y_ref_idx = []
    opposite_volume_t_idx = []
    opposite_volume_s_idx = []

    for k, zero_or_one in enumerate(zero_ones_list):
      if zero_or_one == 0:
        # If the kth iterate has zero_or_one = 0, we will gather from the
        # 'below' reference point along axis k.
        gather_from_y_ref_idx.append(idx_below_list[k])
        # Now append the index to gather for computing opposite_volume.
        # This could be done by initializing opposite_volume to 1, then here:
        #   opposite_volume *= tf.gather(s, indices=k, axis=tf.rank(x) - 1)
        # but that puts a gather in the 'inner loop.'  Better to append the
        # index and do one larger gather down below.
        opposite_volume_s_idx.append(k)
      else:
        gather_from_y_ref_idx.append(idx_above_list[k])
        # Append an index to gather, having the same effect as
        #   opposite_volume *= tf.gather(t, indices=k, axis=tf.rank(x) - 1)
        opposite_volume_t_idx.append(k)

    # Compute opposite_volume (volume of cube opposite the ref point):
    # Recall t.shape = s.shape = [A1, ..., An, D, nd] + [1, ..., 1].
    # Gather from t and s along the 'nd' axis, which is rank(x) - 1.
    ov_axis = ps.cast(ps.rank(x) - 1, tf.int32)
    opposite_volume = (
        tf.reduce_prod(
            tf.gather(
                t,
                indices=tf.cast(opposite_volume_t_idx, dtype=tf.int32),
                axis=ov_axis),
            axis=ov_axis) *
        tf.reduce_prod(
            tf.gather(
                s,
                indices=tf.cast(opposite_volume_s_idx, dtype=tf.int32),
                axis=ov_axis),
            axis=ov_axis))  # pyformat: disable

    y_ref_pt = tf.gather_nd(
        y_ref, tf.stack(gather_from_y_ref_idx, axis=-1),
        batch_dims=batch_dims)

    terms.append(y_ref_pt * opposite_volume)

  y = tf.math.add_n(terms)

  if tf.debugging.is_numeric_tensor(fill_value):
    # Recall x_idx_unclipped.shape = [A1, ..., An, D, nd], so here we check
    # whether it was out of bounds in any of the nd dims.
    # Thus, oob_idx.shape = [A1, ..., An, D].
    oob_idx = tf.reduce_any(
        (x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1), axis=-1)

    # Now, y.shape = [A1, ..., An, D, B1,...,BM], so we'll have to broadcast
    # oob_idx.
    oob_idx = _expand_x_fn(oob_idx)  # Shape [A1, ..., An, D, 1,...,1]
    oob_idx = _broadcast_with(oob_idx, ps.shape(y))
    y = tf.where(oob_idx, fill_value, y)

  return y
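
# For intuition, the corner loop above reduces to the familiar bilinear
# formula when nd = 2 and there are no batch or B dims.  A plain-Python sketch
# (hypothetical names; t0, t1 are the fractional offsets from the 'below'
# grid indices, y01 is the value at the below/above corner, etc.):
def _sketch_bilinear_weights(y00, y01, y10, y11, t0, t1):
  s0, s1 = 1. - t0, 1. - t1
  # Each corner value is weighted by the area of the rectangle *opposite* it.
  return (y00 * s0 * s1 +   # below/below corner, opposite area s0*s1
          y01 * s0 * t1 +   # below/above
          y10 * t0 * s1 +   # above/below
          y11 * t0 * t1)    # above/above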
def batch_interp_regular_nd_grid(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis,
                                 fill_value='constant_extension',
                                 name=None):
  """Multi-linear interpolation on a regular (constant spacing) grid.

  Given [a batch of] reference values, this function computes a multi-linear
  interpolant and evaluates it on [a batch of] new `x` values.

  The interpolant is built from reference values indexed by `nd` dimensions
  of `y_ref`, starting at `axis`.

  For example, take the case of a `2-D` scalar valued function and no leading
  batch dimensions.  In this case, `y_ref.shape = [C1, C2]` and `y_ref[i, j]`
  is the reference value corresponding to grid point

  ```
  [x_ref_min[0] + i * (x_ref_max[0] - x_ref_min[0]) / (C1 - 1),
   x_ref_min[1] + j * (x_ref_max[1] - x_ref_min[1]) / (C2 - 1)]
  ```

  In the general case, dimensions to the left of `axis` in `y_ref` are
  broadcast with leading dimensions in `x`, `x_ref_min`, `x_ref_max`.

  Args:
    x: Numeric `Tensor` The x-coordinates of the interpolated output values for
      each batch.  Shape `[..., D, nd]`, designating [a batch of] `D`
      coordinates in `nd` space.  `D` must be `>= 1` and is not a batch dim.
    x_ref_min:  `Tensor` of same `dtype` as `x`.  The minimum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    x_ref_max:  `Tensor` of same `dtype` as `x`.  The maximum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    y_ref:  `Tensor` of same `dtype` as `x`.  The reference output values.
      Shape `[..., C1, ..., Cnd, B1,...,BM]`, designating [a batch of]
      reference values indexed by `nd` dimensions, of a shape `[B1,...,BM]`
      valued function (for `M >= 0`).
    axis:  Scalar integer `Tensor`.  Dimensions `[axis, axis + nd)` of `y_ref`
      index the interpolation table.  E.g. `3-D` interpolation of a scalar
      valued function requires `axis=-3` and a `3-D` matrix valued function
      requires `axis=-5`.
    fill_value:  Determines what values output should take for `x` values that
      are below `x_ref_min` or above `x_ref_max`.  Scalar `Tensor` or
      'constant_extension' ==> Extend as constant function.
      Default value: `'constant_extension'`
    name:  A name to prepend to created ops.
      Default value: `'batch_interp_regular_nd_grid'`.

  Returns:
    y_interp:  Interpolation between members of `y_ref`, at points `x`.
      `Tensor` of same `dtype` as `x`, and shape `[..., D, B1, ..., BM].`

  Raises:
    ValueError:  If `rank(x) < 2` is determined statically.
    ValueError:  If `axis` is not a scalar, determined statically.
    ValueError:  If `axis + nd > rank(y_ref)` is determined statically.

  #### Examples

  Interpolate a function of one variable.

  ```python
  y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))

  tfp.math.batch_interp_regular_nd_grid(
      # x.shape = [3, 1], x_ref_min/max.shape = [1].  Trailing `1` for `1-D`.
      x=[[6.0], [0.5], [3.3]], x_ref_min=[0.], x_ref_max=[10.], y_ref=y_ref,
      axis=0)
  ==> approx [exp(6.0), exp(0.5), exp(3.3)]
  ```

  Interpolate a scalar function of two variables.

  ```python
  x_ref_min = [0., 0.]
  x_ref_max = [2 * np.pi, 2 * np.pi]

  # Build y_ref.
  x0s, x1s = tf.meshgrid(
      tf.linspace(x_ref_min[0], x_ref_max[0], num=100),
      tf.linspace(x_ref_min[1], x_ref_max[1], num=100),
      indexing='ij')

  def func(x0, x1):
    return tf.sin(x0) * tf.cos(x1)

  y_ref = func(x0s, x1s)

  x = np.pi * tf.random.uniform(shape=(10, 2))

  tfp.math.batch_interp_regular_nd_grid(x, x_ref_min, x_ref_max, y_ref,
                                        axis=-2)
  ==> tf.sin(x[:, 0]) * tf.cos(x[:, 1])
  ```
  """
  with tf.name_scope(name or 'interp_regular_nd_grid'):
    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    dtype_hint=tf.float32)

    # Arg checking.
    if isinstance(fill_value, str):
      if fill_value != 'constant_extension':
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fill_value, 'constant_extension'))
    else:
      fill_value = tf.convert_to_tensor(
          fill_value, name='fill_value', dtype=dtype)
      _assert_ndims_statically(fill_value, expect_ndims=0)

    # x.shape = [..., D, nd].
    x = tf.convert_to_tensor(x, name='x', dtype=dtype)
    _assert_ndims_statically(x, expect_ndims_at_least=2)

    # y_ref.shape = [..., C1,...,Cnd, B1,...,BM]
    y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

    # x_ref_min.shape = [..., nd]
    x_ref_min = tf.convert_to_tensor(x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(x_ref_max, name='x_ref_max', dtype=dtype)
    _assert_ndims_statically(
        x_ref_min, expect_ndims_at_least=1, expect_static=True)
    _assert_ndims_statically(
        x_ref_max, expect_ndims_at_least=1, expect_static=True)

    # nd is the number of dimensions indexing the interpolation table; it's the
    # 'nd' in the function name.
    nd = tf.compat.dimension_value(x_ref_min.shape[-1])
    if nd is None:
      raise ValueError('`x_ref_min.shape[-1]` must be known statically.')
    tensorshape_util.assert_is_compatible_with(
        x_ref_max.shape[-1:], x_ref_min.shape[-1:])

    # Convert axis and check it statically.
    axis = ps.convert_to_shape_tensor(axis, dtype=tf.int32, name='axis')
    axis = ps.non_negative_axis(axis, ps.rank(y_ref))
    tensorshape_util.assert_has_rank(axis.shape, 0)
    axis_ = tf.get_static_value(axis)
    y_ref_rank_ = tf.get_static_value(tf.rank(y_ref))
    if axis_ is not None and y_ref_rank_ is not None:
      if axis_ + nd > y_ref_rank_:
        raise ValueError(
            'Since dims `[axis, axis + nd)` index the interpolation table, we '
            'must have `axis + nd <= rank(y_ref)`.  Found: '
            '`axis`: {}, rank(y_ref): {}, and inferred `nd` from trailing '
            'dimensions of `x_ref_min` to be {}.'.format(
                axis_, y_ref_rank_, nd))

    x_batch_shape = ps.shape_slice(x, np.s_[:-2])
    x_ref_min_batch_shape = ps.shape_slice(x_ref_min, np.s_[:-1])
    x_ref_max_batch_shape = ps.shape_slice(x_ref_max, np.s_[:-1])
    y_ref_batch_shape = ps.shape_slice(y_ref, np.s_[:axis])

    # Do a brute-force broadcast of batch dims (add zeros).
    batch_shape = y_ref_batch_shape
    for tensor in [x_batch_shape, x_ref_min_batch_shape,
                   x_ref_max_batch_shape]:
      batch_shape = ps.broadcast_shape(batch_shape, tensor)

    def _batch_shape_of_zeros_with_rightmost_singletons(n_singletons):
      """Return Tensor of zeros with some singletons on the rightmost dims."""
      ones = ps.ones(shape=[n_singletons], dtype=tf.int32)
      return ps.concat([batch_shape, ones], axis=0)

    x = _broadcast_with(
        x, _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=2))
    x_ref_min = _broadcast_with(
        x_ref_min,
        _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=1))
    x_ref_max = _broadcast_with(
        x_ref_max,
        _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=1))
    y_ref = _broadcast_with(
        y_ref,
        _batch_shape_of_zeros_with_rightmost_singletons(
            n_singletons=ps.rank(y_ref) - axis))

    return _batch_interp_with_gather_nd(
        x=x,
        x_ref_min=x_ref_min,
        x_ref_max=x_ref_max,
        y_ref=y_ref,
        nd=nd,
        fill_value=fill_value,
        batch_dims=ps.rank(x) - 2)
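
# A short batched-usage sketch (hypothetical shapes and values): two
# independent 1-D tables share the same query points; the size-1 leading dim
# of `x` broadcasts against the batch dim of `y_ref`.
def _sketch_batched_interp_regular_nd_grid():
  y_ref = tf.stack([tf.linspace(0., 1., 11),    # table 0: f(u) = u / 10
                    tf.linspace(0., 2., 11)])   # table 1: f(u) = u / 5
  # y_ref.shape = [2, 11]; axis=-1 ==> nd = 1 with a batch shape of [2].
  x = tf.constant([[[2.5], [7.5]]])             # shape [1, D=2, nd=1]
  return batch_interp_regular_nd_grid(
      x=x, x_ref_min=[0.], x_ref_max=[10.], y_ref=y_ref, axis=-1)
  # ==> shape [2, 2]; approx [[0.25, 0.75], [0.5, 1.5]].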