Example #1
def _log_average_probs_process_args(logits, validate_args, sample_axis,
                                    event_axis):
    """Processes args for `log_average_probs`."""
    rank = ps.rank(logits)
    if sample_axis is None or validate_args:
        event_axis = ps.reshape(ps.non_negative_axis(event_axis, rank),
                                shape=[-1])
    if sample_axis is None:
        sample_axis = ps.setdiff1d(ps.range(rank), event_axis)
    elif validate_args:
        sample_axis = ps.reshape(ps.non_negative_axis(sample_axis, rank),
                                 shape=[-1])
    return sample_axis, event_axis
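
The helper above is mostly axis bookkeeping: negative `event_axis` entries are wrapped to non-negative indices and, when `sample_axis` is omitted, the sample axes default to the complement. A minimal NumPy sketch of that bookkeeping (illustrative only; `rank=3` and `event_axis=[-1]` are assumed values, not from the library):

```python
import numpy as np

rank = 3                                   # e.g. logits shaped [samples, batch, classes]
event_axis = np.reshape(np.asarray([-1]), [-1])
event_axis = np.where(event_axis < 0, event_axis + rank, event_axis)  # -> [2]
sample_axis = np.setdiff1d(np.arange(rank), event_axis)               # -> [0, 1]
print(event_axis, sample_axis)
```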
Example #2
def _make_static_axis_non_negative_list(axis, ndims):
    """Convert possibly negatively indexed axis to non-negative list of ints.

    Args:
      axis:  Integer Tensor.
      ndims:  Number of dimensions into which axis indexes.

    Returns:
      A list of non-negative Python integers.

    Raises:
      ValueError: If `axis` is not statically defined.
    """
    axis = prefer_static.non_negative_axis(axis, ndims)

    axis_const = tf.get_static_value(axis)
    if axis_const is None:
        raise ValueError(
            'Expected argument `axis` to be statically available.  Found: %s' %
            axis)

    # Make at least 1-D.
    axis = axis_const + np.zeros([1], dtype=axis_const.dtype)

    return list(int(dim) for dim in axis)
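
A NumPy-only sketch of the same conversion, assuming `axis` is statically known (hypothetical helper name; not the library implementation): wrap negative entries, force at least 1-D, and return plain Python ints.

```python
import numpy as np

def make_static_axis_non_negative_list_sketch(axis, ndims):
    # Wrap negative axes, e.g. -1 -> ndims - 1.
    axis = np.asarray(axis)
    axis = np.where(axis < 0, axis + ndims, axis)
    # Make at least 1-D (same trick as above: add a length-1 zeros array).
    axis = axis + np.zeros([1], dtype=axis.dtype)
    return [int(dim) for dim in axis]

print(make_static_axis_non_negative_list_sketch(-1, ndims=4))       # [3]
print(make_static_axis_non_negative_list_sketch([0, -2], ndims=4))  # [0, 2]
```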
Example #3
def _effective_sample_size_single_state(states, filter_beyond_lag,
                                        filter_threshold,
                                        filter_beyond_positive_pairs,
                                        cross_chain_dims, validate_args):
    """ESS computation for one single Tensor argument."""

    with tf.name_scope('effective_sample_size_single_state'):

        states = tf.convert_to_tensor(states, name='states')
        dt = states.dtype

        # filter_beyond_lag == None ==> auto_corr is the full sequence.
        auto_cov = stats.auto_correlation(states,
                                          axis=0,
                                          max_lags=filter_beyond_lag,
                                          normalize=False)
        n = _axis_size(states, axis=0)

        if cross_chain_dims is not None:
            num_chains = _axis_size(states, cross_chain_dims)
            num_chains_ = tf.get_static_value(num_chains)

            assertions = []
            msg = (
                'When `cross_chain_dims` is not `None`, there must be > 1 chain '
                'in `states`.')
            if num_chains_ is not None:
                if num_chains_ < 2:
                    raise ValueError(msg)
            elif validate_args:
                assertions.append(
                    assert_util.assert_greater(num_chains, 1., message=msg))

            with tf.control_dependencies(assertions):
                # We're computing the R[k] from equation 10 of Vehtari et al.
                # (2019):
                #
                # R[k] := 1 - (W - 1/C * Sum_{c=1}^C s_c**2 R[k, c]) / (var^+),
                #
                # where:
                #   C := number of chains
                #   N := length of chains
                #   x_hat[c] := 1 / N Sum_{n=1}^N x[n, c], chain mean.
                #   x_hat := 1 / C Sum_{c=1}^C x_hat[c], overall mean.
                #   W := 1/C Sum_{c=1}^C s_c**2, within-chain variance.
                #   B := N / (C - 1) Sum_{c=1}^C (x_hat[c] - x_hat)**2, between chain
                #     variance.
                #   s_c**2 := 1 / (N - 1) Sum_{n=1}^N (x[n, c] - x_hat[c])**2, chain
                #       variance
                #   R[k, m] := auto_corr[k, m, ...], auto-correlation indexed by chain.
                #   var^+ := (N - 1) / N * W + B / N

                cross_chain_dims = prefer_static.non_negative_axis(
                    cross_chain_dims, prefer_static.rank(states))
                # B / N
                between_chain_variance_div_n = _reduce_variance(
                    tf.reduce_mean(states, axis=0),
                    biased=False,  # This makes the denominator be C - 1.
                    axis=cross_chain_dims - 1)
                # W * (N - 1) / N
                biased_within_chain_variance = tf.reduce_mean(
                    auto_cov[0], cross_chain_dims - 1)
                # var^+
                approx_variance = (biased_within_chain_variance +
                                   between_chain_variance_div_n)
                # 1/C * Sum_{c=1}^C s_c**2 R[k, c]
                mean_auto_cov = tf.reduce_mean(auto_cov, cross_chain_dims)
                auto_corr = 1. - (biased_within_chain_variance -
                                  mean_auto_cov) / approx_variance
        else:
            auto_corr = auto_cov / auto_cov[:1]
            num_chains = 1

        # With R[k] := auto_corr[k, ...],
        # ESS = N / {1 + 2 * Sum_{k=1}^N R[k] * (N - k) / N}
        #     = N / {-1 + 2 * Sum_{k=0}^N R[k] * (N - k) / N} (since R[0] = 1)
        #     approx N / {-1 + 2 * Sum_{k=0}^M R[k] * (N - k) / N}
        # where M is the filter_beyond_lag truncation point chosen above.

        # Get the factor (N - k) / N, and give it shape [M, 1,...,1], having total
        # ndims the same as auto_corr
        k = tf.range(0., _axis_size(auto_corr, axis=0))
        nk_factor = (n - k) / n
        if tensorshape_util.rank(auto_corr.shape) is not None:
            new_shape = (
                [-1] + [1] * (tensorshape_util.rank(auto_corr.shape) - 1))
        else:
            new_shape = tf.concat(
                ([-1], tf.ones([tf.rank(auto_corr) - 1], dtype=tf.int32)),
                axis=0)
        nk_factor = tf.reshape(nk_factor, new_shape)
        weighted_auto_corr = nk_factor * auto_corr

        if filter_beyond_positive_pairs:

            def _sum_pairs(x):
                x_len = tf.shape(x)[0]
                # For odd sequences, we drop the final value.
                x = x[:x_len - x_len % 2]
                new_shape = tf.concat(
                    [[x_len // 2, 2], tf.shape(x)[1:]], axis=0)
                return tf.reduce_sum(tf.reshape(x, new_shape), 1)

            # Pairwise sums are all positive for auto-correlation spectra derived from
            # reversible MCMC chains.
            # E.g. imagine the pairwise sums are [0.2, 0.1, -0.1, -0.2]
            # Step 1: mask = [False, False, True, True]
            mask = _sum_pairs(auto_corr) < 0.
            # Step 2: mask = [0, 0, 1, 1]
            mask = tf.cast(mask, dt)
            # Step 3: mask = [0, 0, 1, 2]
            mask = tf.cumsum(mask, axis=0)
            # Step 4: mask = [1, 1, 0, 0]
            mask = tf.maximum(1. - mask, 0.)

            # N.B. this reduces the length of weighted_auto_corr by a factor of 2.
            # It still works fine in the formula below.
            weighted_auto_corr = _sum_pairs(weighted_auto_corr) * mask
        elif filter_threshold is not None:
            filter_threshold = tf.convert_to_tensor(filter_threshold,
                                                    dtype=dt,
                                                    name='filter_threshold')
            # Get a binary mask to zero out values of auto_corr below the threshold.
            #   mask[i, ...] = 1 if auto_corr[j, ...] > threshold for all j <= i,
            #   mask[i, ...] = 0, otherwise.
            # So, along dimension zero, the mask will look like [1, 1, ..., 0, 0,...]
            # Building step by step,
            #   Assume auto_corr = [1, 0.5, 0.0, 0.3], and filter_threshold = 0.2.
            # Step 1:  mask = [False, False, True, False]
            mask = auto_corr < filter_threshold
            # Step 2:  mask = [0, 0, 1, 0]
            mask = tf.cast(mask, dtype=dt)
            # Step 3:  mask = [0, 0, 1, 1]
            mask = tf.cumsum(mask, axis=0)
            # Step 4:  mask = [1, 1, 0, 0]
            mask = tf.maximum(1. - mask, 0.)
            weighted_auto_corr *= mask

        return num_chains * n / (-1 +
                                 2 * tf.reduce_sum(weighted_auto_corr, axis=0))
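
To make the truncated-sum formula in the comments concrete, here is a rough single-chain NumPy sketch (illustrative only; no cross-chain handling, positive-pairs, or threshold filtering, and a fixed `max_lag` standing in for `filter_beyond_lag`):

```python
import numpy as np

def ess_single_chain_sketch(states, max_lag):
    # ESS = N / (-1 + 2 * Sum_{k=0}^{M} R[k] * (N - k) / N), with R[0] = 1.
    n = len(states)
    x = states - states.mean()
    auto_cov = np.array([np.dot(x[:n - k], x[k:]) / n for k in range(max_lag + 1)])
    auto_corr = auto_cov / auto_cov[0]
    k = np.arange(max_lag + 1)
    return n / (-1. + 2. * np.sum(auto_corr * (n - k) / n))

# AR(1) chain with coefficient 0.9: ESS should come out well below n,
# roughly n * (1 - 0.9) / (1 + 0.9).
rng = np.random.default_rng(0)
noise = rng.normal(size=5000)
chain = np.zeros(5000)
for t in range(1, 5000):
    chain[t] = 0.9 * chain[t - 1] + noise[t]
print(ess_single_chain_sketch(chain, max_lag=100))
```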
Example #4
def _log_loosum_exp_impl(logx, axis, keepdims, compute_mean):
    """Implementation for `*loosum*` functions."""
    with tf.name_scope('log_loosum_exp_impl'):
        logx = tf.convert_to_tensor(logx, name='logx')
        dtype = dtype_util.as_numpy_dtype(logx.dtype)

        if axis is not None:
            x = np.array(axis)
            axis = (tf.convert_to_tensor(
                axis, name='axis', dtype_hint=tf.int32)
                    if x.dtype == np.object_ else x.astype(np.int32))

        log_sum_x = tf.reduce_logsumexp(logx, axis=axis, keepdims=True)

        # Later we'll want to compute the mean from a sum so we calculate the number
        # of reduced elements, n.
        n = prefer_static.size(logx) // prefer_static.size(log_sum_x)
        n = prefer_static.cast(n, dtype)

        # log_loosum_x[i] =
        # = logsumexp(logx[j] : j != i)
        # = log( exp(logsumexp(logx)) - exp(logx[i]) )
        # = log( exp(logsumexp(logx - logx[i])) exp(logx[i])  - exp(logx[i]))
        # = logx[i] + log(exp(logsumexp(logx - logx[i])) - 1)
        # = logx[i] + log(exp(logsumexp(logx) - logx[i]) - 1)
        # = logx[i] + softplus_inverse(logsumexp(logx) - logx[i])
        d = log_sum_x - logx
        # We use `d != 0` rather than `d > 0.` because `d < 0.` should never happen;
        # if it does we want to complain loudly (which `softplus_inverse` will).
        d_ok = tf.not_equal(d, 0.)
        safe_d = tf.where(d_ok, d, 1.)
        d_ok_result = logx + softplus_inverse(safe_d)

        neg_inf = tf.constant(-np.inf, dtype=dtype)

        # When not(d_ok) and is_positive_and_largest, we manually compute
        # log_loosum_x. (We can efficiently do this for any one point but not all,
        # hence we still need the above calculation.) This is good because when
        # this condition is met, we cannot use the above calculation; it's -inf.
        # We now compute the log-leave-out-max-sum, replicate it to every
        # point, and make sure to select it only when we need to.
        max_logx = tf.reduce_max(logx, axis=axis, keepdims=True)
        is_positive_and_largest = (logx > 0.) & tf.equal(logx, max_logx)
        log_lomsum_x = tf.reduce_logsumexp(tf.where(is_positive_and_largest,
                                                    neg_inf, logx),
                                           axis=axis,
                                           keepdims=True)
        d_not_ok_result = tf.where(is_positive_and_largest, log_lomsum_x,
                                   neg_inf)

        log_loosum_x = tf.where(d_ok, d_ok_result, d_not_ok_result)

        # We now squeeze log_sum_x so as if we used `keepdims=False`.
        # TODO(b/136176077): These mental gymnastics could all be replaced with
        # `tf.squeeze(log_sum_x, axis)` if tf.squeeze supported Tensor valued `axis`
        # arguments.
        if not keepdims:
            if axis is None:
                keepdims = np.array([], dtype=np.int32)
            else:
                rank = prefer_static.rank(logx)
                keepdims = prefer_static.setdiff1d(
                    prefer_static.range(rank),
                    prefer_static.non_negative_axis(axis, rank))
            squeeze_shape = tf.gather(prefer_static.shape(logx),
                                      indices=keepdims)
            log_sum_x = tf.reshape(log_sum_x, shape=squeeze_shape)
            if prefer_static.is_numpy(keepdims):
                tensorshape_util.set_shape(log_sum_x,
                                           np.array(logx.shape)[keepdims])

        # Set static shapes just in case we lost them.
        tensorshape_util.set_shape(n, [])
        tensorshape_util.set_shape(log_loosum_x, logx.shape)

        if not compute_mean:
            return log_loosum_x, log_sum_x, n

        log_nm1 = prefer_static.log(max(1., n - 1.))
        log_n = prefer_static.log(n)
        return log_loosum_x - log_nm1, log_sum_x - log_n, n
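
The leave-one-out identity in the comments, `log_loosum_x[i] = logx[i] + softplus_inverse(logsumexp(logx) - logx[i])`, can be sanity-checked with a small NumPy sketch (illustrative only; `log(exp(.) - 1)` stands in for `softplus_inverse`, and the input values are made up):

```python
import numpy as np

logx = np.log(np.array([0.1, 0.4, 2.0, 0.5]))
log_sum_x = np.log(np.sum(np.exp(logx)))  # plain logsumexp is fine at this scale

# Identity from the comments: logx[i] + log(exp(log_sum_x - logx[i]) - 1).
loo_identity = logx + np.log(np.exp(log_sum_x - logx) - 1.)

# Brute-force leave-one-out log-sum for comparison.
loo_direct = np.array(
    [np.log(np.sum(np.exp(np.delete(logx, i)))) for i in range(len(logx))])

print(np.allclose(loo_identity, loo_direct))  # True
```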
Example #5
def test_dynamic_vector_index(self):
    axis = tf.Variable([0, -2])
    positive_axis = ps.non_negative_axis(axis=axis, rank=4)
    self.evaluate(axis.initializer)
    self.assertAllEqual([0, 2], self.evaluate(positive_axis))
Example #6
def test_static_vector_index(self):
    positive_axis = ps.non_negative_axis(axis=[0, -2], rank=4)
    self.assertAllEqual([0, 2], positive_axis)
Example #7
def test_static_scalar_negative_index(self):
    positive_axis = ps.non_negative_axis(axis=-1, rank=4)
    self.assertAllEqual(3, positive_axis)
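
The expected values in these tests follow from the usual wrap-around rule for negative axes; a short NumPy sketch of that rule (illustrative, not the library implementation):

```python
import numpy as np

def non_negative_axis_sketch(axis, rank):
    axis = np.asarray(axis)
    return np.where(axis < 0, axis + rank, axis)

print(non_negative_axis_sketch([0, -2], rank=4))  # [0 2]
print(non_negative_axis_sketch(-1, rank=4))       # 3
```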
Example #8
def batch_interp_regular_nd_grid(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis,
                                 fill_value='constant_extension',
                                 name=None):
  """Multi-linear interpolation on a regular (constant spacing) grid.

  Given [a batch of] reference values, this function computes a multi-linear
  interpolant and evaluates it on [a batch of] of new `x` values.

  The interpolant is built from reference values indexed by `nd` dimensions
  of `y_ref`, starting at `axis`.

  For example, take the case of a `2-D` scalar valued function and no leading
  batch dimensions.  In this case, `y_ref.shape = [C1, C2]` and `y_ref[i, j]`
  is the reference value corresponding to grid point

  ```
  [x_ref_min[0] + i * (x_ref_max[0] - x_ref_min[0]) / (C1 - 1),
   x_ref_min[1] + j * (x_ref_max[1] - x_ref_min[1]) / (C2 - 1)]
  ```

  In the general case, dimensions to the left of `axis` in `y_ref` are broadcast
  with leading dimensions in `x`, `x_ref_min`, `x_ref_max`.

  Args:
    x: Numeric `Tensor` The x-coordinates of the interpolated output values for
      each batch.  Shape `[..., D, nd]`, designating [a batch of] `D`
      coordinates in `nd` space.  `D` must be `>= 1` and is not a batch dim.
    x_ref_min:  `Tensor` of same `dtype` as `x`.  The minimum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    x_ref_max:  `Tensor` of same `dtype` as `x`.  The maximum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    y_ref:  `Tensor` of same `dtype` as `x`.  The reference output values. Shape
      `[..., C1, ..., Cnd, B1,...,BM]`, designating [a batch of] reference
      values indexed by `nd` dimensions, of a shape `[B1,...,BM]` valued
      function (for `M >= 0`).
    axis:  Scalar integer `Tensor`.  Dimensions `[axis, axis + nd)` of `y_ref`
      index the interpolation table.  E.g. `3-D` interpolation of a scalar
      valued function requires `axis=-3` and a `3-D` matrix valued function
      requires `axis=-5`.
    fill_value:  Determines what values the output should take for `x` values
      that are below `x_ref_min` or above `x_ref_max`. Scalar `Tensor` or
      'constant_extension' ==> Extend as constant function.
      Default value: `'constant_extension'`
    name:  A name to prepend to created ops.
      Default value: `'batch_interp_regular_nd_grid'`.

  Returns:
    y_interp:  Interpolation between members of `y_ref`, at points `x`.
      `Tensor` of same `dtype` as `x`, and shape `[..., D, B1, ..., BM].`

  Raises:
    ValueError:  If `rank(x) < 2` is determined statically.
    ValueError:  If `axis` is not a scalar, as determined statically.
    ValueError:  If `axis + nd > rank(y_ref)` is determined statically.

  #### Examples

  Interpolate a function of one variable.

  ```python
  y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))

  tfp.math.batch_interp_regular_nd_grid(
      # x.shape = [3, 1], x_ref_min/max.shape = [1].  Trailing `1` for `1-D`.
      x=[[6.0], [0.5], [3.3]], x_ref_min=[0.], x_ref_max=[10.], y_ref=y_ref,
      axis=0)
  ==> approx [exp(6.0), exp(0.5), exp(3.3)]
  ```

  Interpolate a scalar function of two variables.

  ```python
  x_ref_min = [0., 0.]
  x_ref_max = [2 * np.pi, 2 * np.pi]

  # Build y_ref.
  x0s, x1s = tf.meshgrid(
      tf.linspace(x_ref_min[0], x_ref_max[0], num=100),
      tf.linspace(x_ref_min[1], x_ref_max[1], num=100),
      indexing='ij')

  def func(x0, x1):
    return tf.sin(x0) * tf.cos(x1)

  y_ref = func(x0s, x1s)

  x = np.pi * tf.random.uniform(shape=(10, 2))

  tfp.math.batch_interp_regular_nd_grid(x, x_ref_min, x_ref_max, y_ref, axis=-2)
  ==> tf.sin(x[:, 0]) * tf.cos(x[:, 1])
  ```

  """
  with tf.name_scope(name or 'interp_regular_nd_grid'):
    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    dtype_hint=tf.float32)

    # Arg checking.
    if isinstance(fill_value, str):
      if fill_value != 'constant_extension':
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fill_value, 'constant_extension'))
    else:
      fill_value = tf.convert_to_tensor(
          fill_value, name='fill_value', dtype=dtype)
      _assert_ndims_statically(fill_value, expect_ndims=0)

    # x.shape = [..., nd].
    x = tf.convert_to_tensor(x, name='x', dtype=dtype)
    _assert_ndims_statically(x, expect_ndims_at_least=2)

    # y_ref.shape = [..., C1,...,Cnd, B1,...,BM]
    y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

    # x_ref_min.shape = [nd]
    x_ref_min = tf.convert_to_tensor(
        x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(
        x_ref_max, name='x_ref_max', dtype=dtype)
    _assert_ndims_statically(
        x_ref_min, expect_ndims_at_least=1, expect_static=True)
    _assert_ndims_statically(
        x_ref_max, expect_ndims_at_least=1, expect_static=True)

    # nd is the number of dimensions indexing the interpolation table; it's the
    # 'nd' in the function name.
    nd = tf.compat.dimension_value(x_ref_min.shape[-1])
    if nd is None:
      raise ValueError('`x_ref_min.shape[-1]` must be known statically.')
    tensorshape_util.assert_is_compatible_with(
        x_ref_max.shape[-1:], x_ref_min.shape[-1:])

    # Convert axis and check it statically.
    axis = tf.convert_to_tensor(axis, dtype=tf.int32, name='axis')
    axis = ps.non_negative_axis(axis, tf.rank(y_ref))
    tensorshape_util.assert_has_rank(axis.shape, 0)
    axis_ = tf.get_static_value(axis)
    y_ref_rank_ = tf.get_static_value(tf.rank(y_ref))
    if axis_ is not None and y_ref_rank_ is not None:
      if axis_ + nd > y_ref_rank_:
        raise ValueError(
            'Since dims `[axis, axis + nd)` index the interpolation table, we '
            'must have `axis + nd <= rank(y_ref)`.  Found: '
            '`axis`: {},  rank(y_ref): {}, and inferred `nd` from trailing '
            'dimensions of `x_ref_min` to be {}.'.format(
                axis_, y_ref_rank_, nd))

    x_batch_shape = ps.shape(x)[:-2]
    x_ref_min_batch_shape = ps.shape(x_ref_min)[:-1]
    x_ref_max_batch_shape = ps.shape(x_ref_max)[:-1]
    y_ref_batch_shape = ps.shape(y_ref)[:axis]

    # Do a brute-force broadcast of batch dims (add zeros).
    batch_shape = y_ref_batch_shape
    for tensor in [x_batch_shape, x_ref_min_batch_shape, x_ref_max_batch_shape]:
      batch_shape = ps.broadcast_shape(batch_shape, tensor)

    def _batch_shape_of_zeros_with_rightmost_singletons(n_singletons):
      """Return Tensor of zeros with some singletons on the rightmost dims."""
      ones = ps.ones(shape=[n_singletons], dtype=tf.int32)
      return ps.concat([batch_shape, ones], axis=0)

    x = _broadcast_with(
        x, _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=2))
    x_ref_min = _broadcast_with(
        x_ref_min,
        _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=1))
    x_ref_max = _broadcast_with(
        x_ref_max,
        _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=1))
    y_ref = _broadcast_with(
        y_ref,
        _batch_shape_of_zeros_with_rightmost_singletons(
            n_singletons=tf.rank(y_ref) - axis))

    return _batch_interp_with_gather_nd(
        x=x,
        x_ref_min=x_ref_min,
        x_ref_max=x_ref_max,
        y_ref=y_ref,
        nd=nd,
        fill_value=fill_value,
        batch_dims=ps.rank(x) - 2)
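
The batch handling near the end of the function mirrors ordinary NumPy broadcasting: a common batch shape is computed, then each argument is broadcast against that shape with singleton dimensions appended for its non-batch trailing dims. A small NumPy sketch with assumed shapes (illustrative only; not library code):

```python
import numpy as np

# Assumed batch shapes for x, x_ref_min/max, and y_ref respectively.
batch_shape = np.broadcast_shapes((2, 1), (1, 3), (1, 1))
print(batch_shape)  # (2, 3)

x = np.zeros((2, 1, 5, 2))        # [..., D=5, nd=2]
x = np.broadcast_to(x, batch_shape + (5, 2))            # -> (2, 3, 5, 2)

y_ref = np.zeros((1, 1, 10, 10))  # [..., C1=10, C2=10], scalar-valued table
y_ref = np.broadcast_to(y_ref, batch_shape + (10, 10))  # -> (2, 3, 10, 10)
print(x.shape, y_ref.shape)
```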
Example #9
def _interp_regular_1d_grid_impl(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis=-1,
                                 batch_y_ref=False,
                                 fill_value='constant_extension',
                                 fill_value_below=None,
                                 fill_value_above=None,
                                 grid_regularizing_transform=None,
                                 name=None):
  """1-D interpolation that works with/without batching."""
  # Note: we do *not* make the no-batch version a special case of the batch
  # version, because that would be an inefficient use of batch_gather with
  # unnecessarily broadcast args.
  with tf.name_scope(name or 'interp_regular_1d_grid_impl'):

    # Arg checking.
    allowed_fv_st = ('constant_extension', 'extrapolate')
    for fv in (fill_value, fill_value_below, fill_value_above):
      if isinstance(fv, str) and fv not in allowed_fv_st:
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fv, allowed_fv_st))

    # Separate value fills for below/above incurs extra cost, so keep track of
    # whether this is needed.
    need_separate_fills = (
        fill_value_above is not None or fill_value_below is not None or
        fill_value == 'extrapolate'  # always requires separate below/above
    )
    if need_separate_fills and fill_value_above is None:
      fill_value_above = fill_value
    if need_separate_fills and fill_value_below is None:
      fill_value_below = fill_value

    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    dtype_hint=tf.float32)
    x = tf.convert_to_tensor(x, name='x', dtype=dtype)

    x_ref_min = tf.convert_to_tensor(
        x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(
        x_ref_max, name='x_ref_max', dtype=dtype)
    if not batch_y_ref:
      _assert_ndims_statically(x_ref_min, expect_ndims=0)
      _assert_ndims_statically(x_ref_max, expect_ndims=0)

    y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

    if batch_y_ref:
      # If we're batching,
      #   x.shape ~ [A1,...,AN, D],  x_ref_min/max.shape ~ [A1,...,AN]
      # So to add together we'll append a singleton.
      # If not batching, x_ref_min/max are scalar, so this isn't an issue,
      # moreover, if not batching, x can be scalar, and expanding x_ref_min/max
      # would cause a bad expansion of x when added to x (confused yet?).
      x_ref_min = x_ref_min[..., tf.newaxis]
      x_ref_max = x_ref_max[..., tf.newaxis]

    axis = tf.convert_to_tensor(axis, name='axis', dtype=tf.int32)
    axis = ps.non_negative_axis(axis, tf.rank(y_ref))
    _assert_ndims_statically(axis, expect_ndims=0)

    ny = tf.cast(tf.shape(y_ref)[axis], dtype)

    # Map [x_ref_min, x_ref_max] to [0, ny - 1].
    # This is the (fractional) index of x.
    if grid_regularizing_transform is None:
      g = lambda x: x
    else:
      g = grid_regularizing_transform
    fractional_idx = ((g(x) - g(x_ref_min)) / (g(x_ref_max) - g(x_ref_min)))
    x_idx_unclipped = fractional_idx * (ny - 1)

    # Wherever x is NaN, x_idx_unclipped will be NaN as well.
    # Keep track of the nan indices here (so we can impute NaN later).
    # Also eliminate any NaN indices, since a NaN index can't be cast to int32.
    nan_idx = tf.math.is_nan(x_idx_unclipped)
    zero = tf.zeros((), dtype=dtype)
    x_idx_unclipped = tf.where(nan_idx, zero, x_idx_unclipped)
    x_idx = tf.clip_by_value(x_idx_unclipped, zero, ny - 1)

    # Get the index above and below x_idx.
    # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
    # however, this results in idx_below == idx_above whenever x is on a grid.
    # This in turn results in y_ref_below == y_ref_above, and then the gradient
    # at this point is zero.  So here we 'jitter' one of idx_below, idx_above,
    # so that they are at different values.  This jittering does not affect the
    # interpolated value, but does make the gradient nonzero (unless of course
    # the y_ref values are the same).
    idx_below = tf.floor(x_idx)
    idx_above = tf.minimum(idx_below + 1, ny - 1)
    idx_below = tf.maximum(idx_above - 1, 0)

    # These are the values of y_ref corresponding to above/below indices.
    idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
    idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)
    if batch_y_ref:
      # If y_ref.shape ~ [A1,...,AN, C, B1,...,BN],
      # and x.shape, x_ref_min/max.shape ~ [A1,...,AN, D]
      # Then y_ref_below.shape ~ [A1,...,AN, D, B1,...,BN]
      y_ref_below = _batch_gather_with_broadcast(y_ref, idx_below_int32, axis)
      y_ref_above = _batch_gather_with_broadcast(y_ref, idx_above_int32, axis)
    else:
      # Here, y_ref_below.shape =
      #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis + 1:]
      y_ref_below = tf.gather(y_ref, idx_below_int32, axis=axis)
      y_ref_above = tf.gather(y_ref, idx_above_int32, axis=axis)

    # Use t to get a convex combination of the below/above values.
    t = x_idx - idx_below

    # x, and tensors shaped like x, need to be added to, and selected with
    # (using tf.where) the output y.  This requires appending singletons.
    # Make functions appropriate for batch/no-batch.
    if batch_y_ref:
      # In the batch case, the output shape is going to be
      #   Broadcast(y_ref.shape[:axis], x.shape[:-1]) +
      #   x.shape[-1:] + y_ref.shape[axis+1:]
      expand_x_fn = _make_expand_x_fn_for_batch_interpolation(y_ref, axis)
    else:
      # In the non-batch case, the output shape is going to be
      #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis + 1:]
      expand_x_fn = _make_expand_x_fn_for_non_batch_interpolation(y_ref, axis)

    t = expand_x_fn(t)
    nan_idx = expand_x_fn(nan_idx, broadcast=True)
    x_idx_unclipped = expand_x_fn(x_idx_unclipped, broadcast=True)

    y = t * y_ref_above + (1 - t) * y_ref_below

    # Now begins a long excursion to fill values outside [x_min, x_max].

    # Re-insert NaN wherever x was NaN.
    y = tf.where(nan_idx, tf.constant(np.nan, y.dtype), y)

    if not need_separate_fills:
      if fill_value == 'constant_extension':
        pass  # Already handled by clipping x_idx_unclipped.
      else:
        y = tf.where(
            (x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1),
            fill_value, y)
    else:
      # Fill values below x_ref_min <==> x_idx_unclipped < 0.
      if fill_value_below == 'constant_extension':
        pass  # Already handled by the clipping that created x_idx_unclipped.
      elif fill_value_below == 'extrapolate':
        if batch_y_ref:
          # For every batch member, gather the first two elements of y across
          # `axis`.
          y_0 = tf.gather(y_ref, [0], axis=axis)
          y_1 = tf.gather(y_ref, [1], axis=axis)
        else:
          # If not batching, we want to gather the first two elements, just like
          # above.  However, these results need to be replicated for every
          # member of x.  An easy way to do that is to gather using
          # indices = zeros/ones(x.shape).
          y_0 = tf.gather(
              y_ref, tf.zeros(tf.shape(x), dtype=tf.int32), axis=axis)
          y_1 = tf.gather(
              y_ref, tf.ones(tf.shape(x), dtype=tf.int32), axis=axis)
        x_delta = (x_ref_max - x_ref_min) / (ny - 1)
        x_factor = expand_x_fn((x - x_ref_min) / x_delta, broadcast=True)
        y = tf.where(x_idx_unclipped < 0, y_0 + x_factor * (y_1 - y_0), y)
      else:
        y = tf.where(x_idx_unclipped < 0, fill_value_below, y)
      # Fill values above x_ref_max <==> x_idx_unclipped > ny - 1.
      if fill_value_above == 'constant_extension':
        pass  # Already handled by the clipping that created x_idx_unclipped.
      elif fill_value_above == 'extrapolate':
        ny_int32 = tf.shape(y_ref)[axis]
        if batch_y_ref:
          y_n1 = tf.gather(y_ref, [tf.shape(y_ref)[axis] - 1], axis=axis)
          y_n2 = tf.gather(y_ref, [tf.shape(y_ref)[axis] - 2], axis=axis)
        else:
          y_n1 = tf.gather(
              y_ref, tf.fill(tf.shape(x), ny_int32 - 1), axis=axis)
          y_n2 = tf.gather(
              y_ref, tf.fill(tf.shape(x), ny_int32 - 2), axis=axis)
        x_delta = (x_ref_max - x_ref_min) / (ny - 1)
        x_factor = expand_x_fn((x - x_ref_max) / x_delta, broadcast=True)
        y = tf.where(x_idx_unclipped > ny - 1,
                     y_n1 + x_factor * (y_n1 - y_n2), y)
      else:
        y = tf.where(x_idx_unclipped > ny - 1, fill_value_above, y)

    return y
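
The core of the routine, stripped of batching, NaN handling, and fill values, is a fractional-index linear interpolation with the below/above "jitter" described in the comments. A minimal NumPy sketch (illustrative only), reusing the one-variable example from `batch_interp_regular_nd_grid`'s docstring:

```python
import numpy as np

def interp_regular_1d_sketch(x, x_ref_min, x_ref_max, y_ref):
    ny = len(y_ref)
    # Map [x_ref_min, x_ref_max] to the fractional index range [0, ny - 1].
    x_idx = (x - x_ref_min) / (x_ref_max - x_ref_min) * (ny - 1)
    x_idx = np.clip(x_idx, 0., ny - 1.)
    # 'Jitter' so idx_below != idx_above even when x sits exactly on a grid point.
    idx_below = np.floor(x_idx)
    idx_above = np.minimum(idx_below + 1, ny - 1)
    idx_below = np.maximum(idx_above - 1, 0)
    t = x_idx - idx_below
    y_below = y_ref[idx_below.astype(np.int32)]
    y_above = y_ref[idx_above.astype(np.int32)]
    return t * y_above + (1. - t) * y_below

y_ref = np.exp(np.linspace(0., 10., 20))
print(interp_regular_1d_sketch(np.array([6.0, 0.5, 3.3]), 0., 10., y_ref))
# approximately [exp(6.0), exp(0.5), exp(3.3)]
```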
Example #10
def test_static_scalar_positive_index(self):
    positive_axis = prefer_static.non_negative_axis(axis=2, rank=4)
    self.assertAllEqual(2, positive_axis)