Example #1
def make_kernel_bias_posterior_mvn_diag(kernel_shape,
                                        bias_shape,
                                        dtype=tf.float32,
                                        kernel_initializer=None,
                                        bias_initializer=None):
    """Create learnable posterior for Variational layers with kernel and bias."""
    if kernel_initializer is None:
        kernel_initializer = tf.initializers.glorot_normal()
    if bias_initializer is None:
        bias_initializer = tf.initializers.glorot_normal()
    make_loc = lambda shape, init, name: tf.Variable(  # pylint: disable=g-long-lambda
        init(shape, dtype=dtype),
        name=name + '_loc')
    make_scale = lambda shape, name: TransformedVariable(  # pylint: disable=g-long-lambda
        tf.ones(shape, dtype=dtype),
        Chain([Shift(1e-5), Softplus()]),
        name=name + '_scale')
    return JointDistributionSequential([
        Independent(Normal(loc=make_loc(kernel_shape, kernel_initializer,
                                        'posterior_kernel'),
                           scale=make_scale(kernel_shape, 'posterior_kernel')),
                    reinterpreted_batch_ndims=prefer_static.size(kernel_shape),
                    name='posterior_kernel'),
        Independent(Normal(loc=make_loc(bias_shape, bias_initializer,
                                        'posterior_bias'),
                           scale=make_scale(bias_shape, 'posterior_bias')),
                    reinterpreted_batch_ndims=prefer_static.size(bias_shape),
                    name='posterior_bias'),
    ])
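A minimal usage sketch (assuming the TFP-internal names used above, e.g. `JointDistributionSequential`, `Normal`, `Independent`, `TransformedVariable`, and `prefer_static`, are imported as in the original module; the layer shapes are made up):

# Hypothetical posterior for a dense layer with 3 inputs and 5 outputs.
posterior = make_kernel_bias_posterior_mvn_diag(kernel_shape=[3, 5], bias_shape=[5])
kernel_sample, bias_sample = posterior.sample()              # one draw of the weights
log_prob = posterior.log_prob([kernel_sample, bias_sample])
print(kernel_sample.shape, bias_sample.shape, log_prob.shape)  # (3, 5) (5,) ()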
Example #2
def reduce_logmeanexp(input_tensor, axis=None, keepdims=False, name=None):
    """Computes `log(mean(exp(input_tensor)))`.

  Reduces `input_tensor` along the dimensions given in `axis`.  Unless
  `keepdims` is true, the rank of the tensor is reduced by 1 for each entry in
  `axis`. If `keepdims` is true, the reduced dimensions are retained with length
  1.

  If `axis` has no entries, all dimensions are reduced, and a tensor with a
  single element is returned.

  This function is more numerically stable than `log(reduce_mean(exp(input)))`.
  It avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  Args:
    input_tensor: The tensor to reduce. Should have numeric type.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keepdims:  Boolean.  Whether to keep the axis as singleton dimensions.
      Default value: `False` (i.e., squeeze the reduced dimensions).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., `'reduce_logmeanexp'`).

  Returns:
    log_mean_exp: The reduced tensor.
  """
    with tf.name_scope(name or 'reduce_logmeanexp'):
        lse = tf.reduce_logsumexp(input_tensor, axis=axis, keepdims=keepdims)
        n = prefer_static.size(input_tensor) // prefer_static.size(lse)
        log_n = tf.math.log(tf.cast(n, lse.dtype))
        return lse - log_n
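A quick sanity check (a sketch assuming only TensorFlow plus the `prefer_static` import the function already relies on): the result matches the naive formula on well-scaled inputs and stays finite where the naive formula overflows.

x = tf.constant([[0., 1., 2.], [3., 4., 5.]])
naive = tf.math.log(tf.reduce_mean(tf.exp(x), axis=1))   # ~[1.31, 4.31]
stable = reduce_logmeanexp(x, axis=1)                    # same values
# For `x + 1000.` the naive form overflows to inf; the stable one returns ~[1001.31, 1004.31].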
Example #3
def make_kernel_bias_posterior_mvn_diag(
        kernel_shape,
        bias_shape,
        kernel_initializer=None,
        bias_initializer=None,
        kernel_batch_ndims=0,  # pylint: disable=unused-argument
        bias_batch_ndims=0,  # pylint: disable=unused-argument
        dtype=tf.float32,
        kernel_name='posterior_kernel',
        bias_name='posterior_bias'):
    """Create learnable posterior for Variational layers with kernel and bias.

  Args:
    kernel_shape: ...
    bias_shape: ...
    kernel_initializer: ...
      Default value: `None` (i.e., `tf.initializers.glorot_uniform()`).
    bias_initializer: ...
      Default value: `None` (i.e., `tf.zeros`).
    kernel_batch_ndims: ...
      Default value: `0`.
    bias_batch_ndims: ...
      Default value: `0`.
    dtype: ...
      Default value: `tf.float32`.
    kernel_name: ...
      Default value: `"posterior_kernel"`.
    bias_name: ...
      Default value: `"posterior_bias"`.

  Returns:
    kernel_and_bias_distribution: ...
  """
    if kernel_initializer is None:
        kernel_initializer = nn_init_lib.glorot_uniform()
    if bias_initializer is None:
        bias_initializer = tf.zeros
    make_loc = lambda init_fn, shape, batch_ndims, name: tf.Variable(  # pylint: disable=g-long-lambda
        _try_call_init_fn(init_fn, shape, dtype, batch_ndims),
        name=name + '_loc')
    # Setting the initial scale to a relatively small value causes the `loc` to
    # quickly move toward a lower loss value.
    make_scale = lambda shape, name: TransformedVariable(  # pylint: disable=g-long-lambda
        tf.fill(shape, value=tf.constant(1e-3, dtype=dtype)),
        Chain([Shift(1e-5), Softplus()]),
        name=name + '_scale')
    return JointDistributionSequential([
        Independent(Normal(loc=make_loc(kernel_initializer, kernel_shape,
                                        kernel_batch_ndims, kernel_name),
                           scale=make_scale(kernel_shape, kernel_name)),
                    reinterpreted_batch_ndims=prefer_static.size(kernel_shape),
                    name=kernel_name),
        Independent(Normal(loc=make_loc(bias_initializer, bias_shape,
                                        kernel_batch_ndims, bias_name),
                           scale=make_scale(bias_shape, bias_name)),
                    reinterpreted_batch_ndims=prefer_static.size(bias_shape),
                    name=bias_name),
    ])
Example #4
 def _reshape_part(part, event_shape):
   part = tf.cast(part, self.dtype)
   static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
   if static_rank == 1:
     return part
   new_shape = ps.concat([
       ps.shape(part)[:ps.size(ps.shape(part)) - ps.size(event_shape)], [-1]
   ],
                         axis=-1)
   return tf.reshape(part, ps.cast(new_shape, tf.int32))
def _cumulative_broadcast_dynamic(event_shape):
  broadcast_shapes = [
      ps.slice(s, begin=[0], size=[ps.size(s)-1]) for s in event_shape]
  cumulative_shapes = [broadcast_shapes[0]]
  for shape in broadcast_shapes[1:]:
    out_shape = ps.broadcast_shape(shape, cumulative_shapes[-1])
    cumulative_shapes.append(out_shape)
  return [
      ps.concat([b, ps.slice(s, begin=[ps.size(s)-1], size=[1])], axis=0)
      for b, s in zip(cumulative_shapes, event_shape)]
Example #6
 def _reshape_part(part, event_shape):
     part = tf.cast(part, self.dtype)
     new_shape = ps.concat(
         [
             ps.shape(part)[:ps.size(ps.shape(part)) -
                            ps.size(event_shape)], [-1]
         ],
         axis=-1,
     )
     return tf.reshape(part, ps.cast(new_shape, tf.int32))
Example #7
def reduce_logmeanexp(input_tensor,
                      axis=None,
                      keepdims=False,
                      experimental_named_axis=None,
                      experimental_allow_all_gather=False,
                      name=None):
    """Computes `log(mean(exp(input_tensor)))`.

  Reduces `input_tensor` along the dimensions given in `axis`.  Unless
  `keepdims` is true, the rank of the tensor is reduced by 1 for each entry in
  `axis`. If `keepdims` is true, the reduced dimensions are retained with length
  1.

  If `axis` has no entries, all dimensions are reduced, and a tensor with a
  single element is returned.

  This function is more numerically stable than `log(reduce_mean(exp(input)))`.
  It avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  Args:
    input_tensor: The tensor to reduce. Should have numeric type.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keepdims:  Boolean.  Whether to keep the axis as singleton dimensions.
      Default value: `False` (i.e., squeeze the reduced dimensions).
    experimental_named_axis: A `str` or list of `str` axis names to additionally
      reduce over. Providing `None` will not reduce over any axes.
    experimental_allow_all_gather: Allow using an `all_gather`-based fallback
      under TensorFlow when computing the distributed maximum. This fallback is
      only efficient when `axis` reduces away most of the dimensions of
      `input_tensor`.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., `'reduce_logmeanexp'`).

  Returns:
    log_mean_exp: The reduced tensor.
  """
    with tf.name_scope(name or 'reduce_logmeanexp'):
        named_axes = distribute_lib.canonicalize_named_axis(
            experimental_named_axis)
        lse = distribute_lib.reduce_logsumexp(
            input_tensor,
            axis=axis,
            keepdims=keepdims,
            named_axis=named_axes,
            allow_all_gather=experimental_allow_all_gather)
        n = ps.size(input_tensor) // ps.size(lse)
        for named_axis in named_axes:
            n = n * distribute_lib.get_axis_size(named_axis)
        log_n = tf.math.log(tf.cast(n, lse.dtype))
        return lse - log_n
Example #8
def make_kernel_bias_posterior_mvn_diag(kernel_shape,
                                        bias_shape,
                                        dtype=tf.float32,
                                        kernel_initializer=None,
                                        bias_initializer=None,
                                        kernel_name='posterior_kernel',
                                        bias_name='posterior_bias'):
    """Create learnable posterior for Variational layers with kernel and bias.

  Args:
    kernel_shape: ...
    bias_shape: ...
    dtype: ...
      Default value: `tf.float32`.
    kernel_initializer: ...
      Default value: `None` (i.e., `tf.initializers.glorot_uniform()`).
    bias_initializer: ...
      Default value: `None` (i.e., `tf.zeros`).
    kernel_name: ...
      Default value: `"posterior_kernel"`.
    bias_name: ...
      Default value: `"posterior_bias"`.

  Returns:
    kernel_and_bias_distribution: ...
  """
    if kernel_initializer is None:
        kernel_initializer = tf.initializers.glorot_uniform()
    if bias_initializer is None:
        bias_initializer = tf.zeros
    make_loc = lambda shape, init, name: tf.Variable(  # pylint: disable=g-long-lambda
        init(shape, dtype=dtype),
        name=name + '_loc')
    make_scale = lambda shape, name: TransformedVariable(  # pylint: disable=g-long-lambda
        tf.ones(shape, dtype=dtype),
        Chain([Shift(1e-5), Softplus()]),
        name=name + '_scale')
    return JointDistributionSequential([
        Independent(Normal(loc=make_loc(kernel_shape, kernel_initializer,
                                        kernel_name),
                           scale=make_scale(kernel_shape, kernel_name)),
                    reinterpreted_batch_ndims=prefer_static.size(kernel_shape),
                    name=kernel_name),
        Independent(Normal(loc=make_loc(bias_shape, bias_initializer,
                                        bias_name),
                           scale=make_scale(bias_shape, bias_name)),
                    reinterpreted_batch_ndims=prefer_static.size(bias_shape),
                    name=bias_name),
    ])
Example #9
    def adjacent_swaps(num_replica, batch_shape=(), seed=None):
        """Make random shuffle using only one time swaps."""
        with tf.name_scope(name or 'adjacent_swaps'):
            seed = SeedStream(seed, salt='random_adjacent_shuffle')
            # u selects parity.  E.g.,
            #  u==True ==>  [0, 2, 1, 4, 3] like swaps
            #  u==False ==> [1, 0, 3, 2, 4] like swaps
            # If there are only 2 replicas, then the "True" swaps are null
            # swaps...which would contradict the user provided `prob_swap`.
            # So special case num_replica==2, forcing u==False in this case.
            u_shape = prefer_static.concat(
                (tf.ones(1, dtype=tf.int32), tf.cast(batch_shape, tf.int32)),
                axis=0)
            u = tf.random.uniform(u_shape, seed=seed()) < 0.5
            u = tf.where(num_replica > 2, u, False)

            x = mcmc_util.left_justified_expand_dims_to(
                tf.range(num_replica, dtype=tf.int64),
                rank=prefer_static.size(u_shape))
            y = tf.where(tf.equal(x % 2, tf.cast(u, dtype=tf.int64)), x + 1,
                         x - 1)
            y = tf.clip_by_value(y, 0, num_replica - 1)
            # TODO(b/142689785): Consider using tf.cond and returning an empty list
            # then in REMC consider using a tf.cond for short-circuiting.
            return tf.where(
                tf.random.uniform(batch_shape, seed=seed()) < prob_swap, y, x)
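The parity logic above can be checked in isolation (a standalone sketch in plain TensorFlow, hard-coding `num_replica` and the parity flag instead of sampling them):

num_replica = 5
x = tf.range(num_replica, dtype=tf.int64)                # [0, 1, 2, 3, 4]
for u in (False, True):
    y = tf.where(tf.equal(x % 2, tf.cast(u, tf.int64)), x + 1, x - 1)
    y = tf.clip_by_value(y, 0, num_replica - 1)
    print(u, y.numpy())   # False -> [1 0 3 2 4], True -> [0 2 1 4 3]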
Example #10
 def _forward_event_shape_tensor(self, input_shape, is_inverse=False):
   ndims = ps.size(input_shape)
   indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1])
   extra_sizes = ps.reduce_sum(self.paddings, axis=-1)
   update_fn = (ps.tensor_scatter_nd_sub if is_inverse else
                ps.tensor_scatter_nd_add)
   return update_fn(ps.identity(input_shape), indices, extra_sizes)
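The scatter-based shape update can be illustrated with plain TensorFlow stand-ins for the `ps` calls (a sketch with made-up values: padding the last axis of a `[4, 3]` event by 1 on the left and 2 on the right):

input_shape = tf.constant([4, 3])
paddings = tf.constant([[1, 2]])                             # one padded axis
axis = tf.constant([-1])
indices = tf.reshape(axis + tf.size(input_shape), [-1, 1])   # [[1]]
extra_sizes = tf.reduce_sum(paddings, axis=-1)               # [3]
print(tf.tensor_scatter_nd_add(input_shape, indices, extra_sizes).numpy())  # [4 6]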
Example #11
 def test_dynamic(self):
   if tf.executing_eagerly(): return
   x = tf1.placeholder_with_default(
       tf.random.normal([3, 4, 5], seed=tfp_test_util.test_seed()), shape=None)
   self.assertAllEqual(
       3 * 4 * 5,
       self.evaluate(prefer_static.size(x)))
Example #12
def prepare_tuple_argument(arg, n, arg_name, validate_args=False):
    """Helper which processes `Tensor`s to tuples in standard form."""
    # Short-circuiting incoming lists and tuples here avoids both
    # Tensor packing / unpacking and numpy 1.20.+ pickiness about
    # np.array(tuple of Tensor).
    if isinstance(arg, (tuple, list)):
        if len(arg) == n:
            return tuple(arg)
        if len(arg) == 1:
            return (arg[0], ) * n

    arg_size = ps.size(arg)
    arg_size_ = tf.get_static_value(arg_size)
    assertions = []
    if arg_size_ is not None:
        if arg_size_ not in (1, n):
            raise ValueError(
                'The size of `{}` must be equal to `1` or to the rank '
                'of the convolution (={}). Saw size = {}'.format(
                    arg_name, n, arg_size_))
    elif validate_args:
        assertions.append(
            assert_util.assert_equal(
                ps.logical_or(arg_size == 1, arg_size == n),
                True,
                message=
                ('The size of `{}` must be equal to `1` or to the rank of the '
                 'convolution (={})'.format(arg_name, n))))

    with tf.control_dependencies(assertions):
        arg = ps.broadcast_to(arg, shape=[n])
        arg = ps.unstack(arg, num=n)
        return arg
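A usage sketch (assuming the `ps` and `assert_util` imports from TFP's internal modules that the function already uses): a full-length tuple, a length-1 list, and a scalar all normalize to `n` entries.

print(prepare_tuple_argument((1, 2, 3), n=3, arg_name='strides'))  # (1, 2, 3)
print(prepare_tuple_argument([2], n=3, arg_name='strides'))        # (2, 2, 2)
print(prepare_tuple_argument(2, n=3, arg_name='strides'))          # three 2s via the broadcast path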
Example #13
  def bootstrap_results(self, init_state):
    """Creates initial `previous_kernel_results` using a supplied `state`."""
    with tf.name_scope(self.name + '.bootstrap_results'):
      if not tf.nest.is_nested(init_state):
        init_state = [init_state]
      # Pad the step_size so it is compatible with the states.
      step_size = self.step_size
      if len(step_size) == 1:
        step_size = step_size * len(init_state)
        self._step_size = step_size
      if len(step_size) != len(init_state):
        raise ValueError('Expected either one step size or {} (size of '
                         '`init_state`), but found {}'.format(
                             len(init_state), len(step_size)))
      dummy_momentum = [tf.ones_like(state) for state in init_state]
      [
          _,
          _,
          current_target_log_prob,
          current_grads_log_prob,
      ] = leapfrog_impl.process_args(self.target_log_prob_fn,
                                     dummy_momentum,
                                     init_state)
      batch_size = prefer_static.size(current_target_log_prob)

      return NUTSKernelResults(
          target_log_prob=current_target_log_prob,
          grads_target_log_prob=current_grads_log_prob,
          leapfrogs_computed=tf.zeros([], dtype=tf.int32,
                                      name='leapfrogs_computed'),
          is_accepted=tf.zeros([batch_size], dtype=tf.bool,
                               name='is_accepted'),
          reach_max_depth=tf.zeros([batch_size], dtype=tf.bool,
                                   name='reach_max_depth'),
          )
Example #14
def expand_dims(x, axis, name=None):
    """Like `tf.expand_dims` but accepts a vector of axes to expand."""
    with tf.name_scope(name or 'expand_dims'):
        x = tf.convert_to_tensor(x, name='x')
        axis = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis')
        nx = prefer_static.rank(x)
        na = prefer_static.size(axis)
        is_neg_axis = axis < 0
        k = prefer_static.reduce_sum(
            prefer_static.cast(is_neg_axis, axis.dtype))
        axis = prefer_static.where(is_neg_axis, axis + nx, axis)
        axis = prefer_static.sort(axis)
        axis_neg, axis_pos = prefer_static.split(axis, [k, -1])
        idx = prefer_static.argsort(prefer_static.concat([
            axis_pos,
            prefer_static.range(nx),
            axis_neg,
        ],
                                                         axis=0),
                                    stable=True)
        shape = prefer_static.pad(prefer_static.shape(x),
                                  paddings=[[na - k, k]],
                                  constant_values=1)
        shape = prefer_static.gather(shape, idx)
        return tf.reshape(x, shape)
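A short usage sketch (assuming the same `prefer_static` import as the function above):

x = tf.zeros([3, 4])
print(expand_dims(x, axis=[0]).shape)     # (1, 3, 4), same as tf.expand_dims(x, 0)
print(expand_dims(x, axis=[-1]).shape)    # (3, 4, 1), same as tf.expand_dims(x, -1)
print(expand_dims(x, axis=[0, 1]).shape)  # (1, 3, 1, 4): a new axis before each original dim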
Example #15
def _prepare_args_with_initial_vertex(objective_function, initial_vertex,
                                      step_sizes, objective_at_initial_vertex,
                                      batch_evaluate_objective):
    """Constructs a standard axes aligned simplex."""
    dim = ps.size(initial_vertex)
    # tf.eye complains about np.array(..., np.int32) num_rows; it only accepts
    # numpy scalars. TODO(b/162529062): Remove the following line.
    dim = dim if tf.is_tensor(dim) else int(dim)
    num_vertices = dim + 1
    unit_vectors_along_axes = tf.reshape(
        tf.eye(dim, dim, dtype=dtype_util.base_dtype(initial_vertex.dtype)),
        ps.concat([[dim], ps.shape(initial_vertex)], axis=0))

    # If step_sizes does not broadcast to initial_vertex, the multiplication
    # in the second term will fail.
    simplex_face = initial_vertex + step_sizes * unit_vectors_along_axes
    simplex = tf.concat([tf.expand_dims(initial_vertex, axis=0), simplex_face],
                        axis=0)
    # Evaluate the objective function at the simplex vertices.
    if objective_at_initial_vertex is None:
        objective_at_simplex, num_evaluations = _evaluate_objective_multiple(
            objective_function, simplex, batch_evaluate_objective)
    else:
        objective_at_simplex_face, num_evaluations = _evaluate_objective_multiple(
            objective_function, simplex_face, batch_evaluate_objective)
        objective_at_simplex = tf.concat([
            tf.expand_dims(objective_at_initial_vertex, axis=0),
            objective_at_simplex_face
        ],
                                         axis=0)

    return (dim, num_vertices, simplex, objective_at_simplex, num_evaluations)
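The simplex construction itself is easy to see in plain TensorFlow (a standalone sketch with made-up values; the objective evaluation is omitted):

initial_vertex = tf.constant([1., 2., 3.])
step_sizes = 0.5
simplex_face = initial_vertex + step_sizes * tf.eye(3)     # one offset per axis
simplex = tf.concat([initial_vertex[tf.newaxis, ...], simplex_face], axis=0)
# simplex has 4 vertices: the initial vertex plus one vertex stepped along each axis:
# [[1.  2.  3. ], [1.5 2.  3. ], [1.  2.5 3. ], [1.  2.  3.5]]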
Example #16
def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
    """Common argument defaulting logic for windowed statistics."""
    if high_indices is None:
        high_indices = tf.range(ps.shape(x)[axis]) + 1
    else:
        high_indices = tf.convert_to_tensor(high_indices)
    if low_indices is None:
        low_indices = high_indices // 2
    else:
        low_indices = tf.convert_to_tensor(low_indices)
    # Broadcast indices together.
    high_indices = high_indices + tf.zeros_like(low_indices)
    low_indices = low_indices + tf.zeros_like(high_indices)

    # TODO(axch): Support batch low and high indices.  That would
    # complicate this shape munging (though tf.gather should work
    # fine).

    # We want to place `low_counts` and `high_counts` at the `axis`
    # position, so we reshape them to shape `[1, 1, ..., 1, N, 1, ...,
    # 1]`, where the `N` is at `axis`.  The `counts_shp`, below,
    # is this shape.
    size = ps.size(high_indices)
    counts_shp = ps.one_hot(axis, depth=ps.rank(x), on_value=size, off_value=1)

    low_counts = tf.reshape(tf.cast(low_indices, dtype=x.dtype),
                            shape=counts_shp)
    high_counts = tf.reshape(tf.cast(high_indices, dtype=x.dtype),
                             shape=counts_shp)
    return low_indices, high_indices, low_counts, high_counts
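A usage sketch with the defaults (assuming the same `ps` import the helper uses): an expanding-window setup over 5 draws of a 2-dimensional chain.

x = tf.zeros([5, 2])
low, high, low_counts, high_counts = _prepare_window_args(x)
# high = [1 2 3 4 5], low = [0 1 1 2 2]; the *_counts carry the same values
# reshaped to [5, 1] so they broadcast against `x` along axis 0.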
Example #17
def _axis_size(x, axis=None):
    """Get number of elements of `x` in `axis`, as type `x.dtype`."""
    if axis is None:
        return prefer_static.cast(prefer_static.size(x), x.dtype)
    return prefer_static.cast(
        prefer_static.reduce_prod(
            prefer_static.gather(prefer_static.shape(x), axis)), x.dtype)
Example #18
 def expand_dims_(x):
     """Implementation of `expand_dims`."""
     with tf.name_scope(name or 'expand_dims'):
         x = tf.convert_to_tensor(x, name='x')
         new_axis = tf.convert_to_tensor(axis,
                                         dtype_hint=tf.int32,
                                         name='axis')
         nx = prefer_static.rank(x)
         na = prefer_static.size(new_axis)
         is_neg_axis = new_axis < 0
         k = prefer_static.reduce_sum(
             prefer_static.cast(is_neg_axis, new_axis.dtype))
         new_axis = prefer_static.where(is_neg_axis, new_axis + nx,
                                        new_axis)
         new_axis = prefer_static.sort(new_axis)
         axis_neg, axis_pos = prefer_static.split(new_axis, [k, -1])
         idx = prefer_static.argsort(prefer_static.concat([
             axis_pos,
             prefer_static.range(nx),
             axis_neg,
         ],
                                                          axis=0),
                                     stable=True)
         shape = prefer_static.pad(prefer_static.shape(x),
                                   paddings=[[na - k, k]],
                                   constant_values=1)
         shape = prefer_static.gather(shape, idx)
         return tf.reshape(x, shape)
Example #19
def _get_permutations(num_results, dims, seed=None):
  """Uniform iid sample from the space of permutations.

  Draws a sample of size `num_results` from the group of permutations of degrees
  specified by the `dims` tensor. These are packed together into one tensor
  such that each row is one sample from each of the dimensions in `dims`. For
  example, if dims = [2,3] and num_results = 2, the result is a tensor of shape
  [2, 2 + 3] and the first row of the result might look like:
  [1, 0, 2, 0, 1]. The first two elements are a permutation over 2 elements
  while the next three are a permutation over 3 elements.

  Args:
    num_results: A positive scalar `Tensor` of integral type. The number of
      draws from the discrete uniform distribution over the permutation groups.
    dims: A 1D `Tensor` of the same dtype as `num_results`. The degree of the
      permutation groups from which to sample.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    permutations: A `Tensor` of shape `[num_results, sum(dims)]` and the same
    dtype as `dims`.
  """
  seeds = samplers.split_seed(seed, n=ps.size(dims))

  def generate_one(dim, seed):
    return tf.argsort(samplers.uniform([num_results, dim], seed=seed), axis=-1)

  return tf.concat([generate_one(dim, seed)
                    for dim, seed in zip(tf.unstack(dims), tf.unstack(seeds))],
                   axis=-1)
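The per-dimension draw is just an argsort of uniforms; a standalone sketch with plain `tf.random` in place of the stateless `samplers` calls:

num_results, dims = 2, [2, 3]
parts = [tf.argsort(tf.random.uniform([num_results, d]), axis=-1) for d in dims]
permutations = tf.concat(parts, axis=-1)   # shape [2, 2 + 3]
# Each row is a permutation of {0, 1} followed by a permutation of {0, 1, 2},
# e.g. [1, 0, 2, 0, 1].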
Example #20
    def adjacent_swaps(num_replica,
                       batch_shape=(),
                       step_count=None,
                       seed=None):
        """Make random shuffle using only one time swaps."""
        del step_count  # Unused for this function.
        with tf.name_scope(name or 'adjacent_swaps'):
            parity_seed, proposal_seed = samplers.split_seed(seed)
            # u selects parity.  E.g.,
            #  u==False ==> [1, 0, 3, 2, 4] even parity swaps
            #  u==True ==>  [0, 2, 1, 4, 3] odd parity swaps
            # If there are only 2 replicas, then the "True" swaps are null
            # swaps...which would contradict the user provided `prob_swap`.
            # So special case num_replica==2, forcing u==False in this case.
            u_shape = ps.concat(
                (ps.ones(1, dtype=tf.int32), ps.cast(batch_shape, tf.int32)),
                axis=0)
            u = samplers.uniform(u_shape, seed=parity_seed) < 0.5
            u = tf.where(num_replica > 2, u, False)

            x = bu.left_justified_expand_dims_to(ps.range(num_replica,
                                                          dtype=tf.int64),
                                                 rank=ps.size(u_shape))
            y = tf.where(tf.equal(x % 2, tf.cast(u, dtype=tf.int64)), x + 1,
                         x - 1)
            y = tf.clip_by_value(y, 0, num_replica - 1)
            # TODO(b/142689785): Consider using tf.cond and returning an empty list
            # then in REMC consider using a tf.cond for short-circuiting.
            return tf.where(
                samplers.uniform(batch_shape, seed=proposal_seed) < prob_swap,
                y, x)
Example #21
def _slice_single_param(param, param_event_ndims, slices, dist_batch_shape):
  """Slices a single parameter of a distribution.

  Args:
    param: A `Tensor`, the original parameter to slice.
    param_event_ndims: `int` event parameterization rank for this parameter.
    slices: A `tuple` of normalized slices.
    dist_batch_shape: The distribution's batch shape `Tensor`.

  Returns:
    new_param: A `Tensor`, batch-sliced according to slices.
  """
  # Extend param shape with ones on the left to match dist_batch_shape.
  param_shape = ps.shape(param)
  insert_ones = ps.ones(
      [ps.size(dist_batch_shape) + param_event_ndims - ps.rank(param)],
      dtype=param_shape.dtype)
  new_param_shape = ps.concat([insert_ones, param_shape], axis=0)
  full_batch_param = tf.reshape(param, new_param_shape)
  param_slices = []
  # We separately track the batch axis from the parameter axis because we want
  # them to align for positive indexing, and be offset by param_event_ndims for
  # negative indexing.
  param_dim_idx = 0
  batch_dim_idx = 0
  for slc in slices:
    if slc is tf.newaxis:
      param_slices.append(slc)
      continue
    if slc is Ellipsis:
      if batch_dim_idx < 0:
        raise ValueError('Found multiple `...` in slices {}'.format(slices))
      param_slices.append(slc)
      # Switch over to negative indexing for the broadcast check.
      num_remaining_non_newaxis_slices = sum(
          [s is not tf.newaxis for s in slices[slices.index(Ellipsis) + 1:]])
      batch_dim_idx = -num_remaining_non_newaxis_slices
      param_dim_idx = batch_dim_idx - param_event_ndims
      continue
    # Find the batch dimension sizes for both parameter and distribution.
    param_dim_size = new_param_shape[param_dim_idx]
    batch_dim_size = dist_batch_shape[batch_dim_idx]
    is_broadcast = batch_dim_size > param_dim_size
    # Slices are denoted by start:stop:step.
    if isinstance(slc, slice):
      start, stop, step = slc.start, slc.stop, slc.step
      if start is not None:
        start = ps.where(is_broadcast, 0, start)
      if stop is not None:
        stop = ps.where(is_broadcast, 1, stop)
      if step is not None:
        step = ps.where(is_broadcast, 1, step)
      param_slices.append(slice(start, stop, step))
    else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
      param_slices.append(ps.where(is_broadcast, 0, slc))
    param_dim_idx += 1
    batch_dim_idx += 1
  param_slices.extend([ALL_SLICE] * param_event_ndims)
  return full_batch_param.__getitem__(tuple(param_slices))
Example #22
 def expand_right_dims(x, broadcast=False):
   """Expand x so it can bcast w/ tensors of output shape."""
   expanded_shape_left = ps.broadcast_shape(
       ps.shape(x)[:-1],
       ps.ones([ps.size(y_ref_shape_left)], dtype=tf.int32))
   expanded_shape = ps.concat(
       (expanded_shape_left, ps.shape(x)[-1:],
        ps.ones([ps.size(y_ref_shape_right)], dtype=tf.int32)),
       axis=0)
   x_expanded = tf.reshape(x, expanded_shape)
   if broadcast:
     broadcast_shape_left = ps.broadcast_shape(
         ps.shape(x)[:-1], y_ref_shape_left)
     broadcast_shape = ps.concat(
         (broadcast_shape_left, ps.shape(x)[-1:], y_ref_shape_right),
         axis=0)
     x_expanded = _broadcast_with(x_expanded, broadcast_shape)
   return x_expanded
Example #23
  def __init__(
      self,
      input_size,
      output_size,
      # Weights
      init_kernel_fn=None,  # tfp.experimental.nn.initializers.glorot_uniform()
      init_bias_fn=None,    # tf.initializers.zeros()
      make_kernel_bias_fn=nn_util_lib.make_kernel_bias,
      dtype=tf.float32,
      batch_shape=(),
      # Misc
      activation_fn=None,
      name=None):
    """Constructs layer.

    Args:
      input_size: ...
      output_size: ...
      init_kernel_fn: ...
        Default value: `None` (i.e.,
        `tfp.experimental.nn.initializers.glorot_uniform()`).
      init_bias_fn: ...
        Default value: `None` (i.e., `tf.initializers.zeros()`).
      make_kernel_bias_fn: ...
        Default value: `tfp.experimental.nn.util.make_kernel_bias`.
      dtype: ...
        Default value: `tf.float32`.
      batch_shape: ...
        Default value: `()`.
      activation_fn: ...
        Default value: `None`.
      name: ...
        Default value: `None` (i.e., `'Affine'`).
    """
    batch_shape = tf.constant(
        [], dtype=tf.int32) if batch_shape is None else prefer_static.cast(
            prefer_static.reshape(batch_shape, shape=[-1]), tf.int32)
    batch_ndims = prefer_static.size(batch_shape)
    kernel_shape = prefer_static.concat([
        batch_shape, [input_size, output_size]], axis=0)
    bias_shape = prefer_static.concat([batch_shape, [output_size]], axis=0)
    apply_kernel_fn = lambda x, k: tf.matmul(
        x[..., tf.newaxis, :], k)[..., 0, :]  # pylint: disable=g-long-lambda
    kernel, bias = make_kernel_bias_fn(
        kernel_shape, bias_shape,
        init_kernel_fn, init_bias_fn,
        batch_ndims, batch_ndims,
        dtype)
    self._make_kernel_bias_fn = make_kernel_bias_fn  # For tracking.
    super(Affine, self).__init__(
        kernel=kernel,
        bias=bias,
        apply_kernel_fn=apply_kernel_fn,
        activation_fn=activation_fn,
        dtype=dtype,
        name=name)
Example #24
 def _dense_to_sparse(self, student_ids, question_ids, dense_correct):
   test_y_idx = np.stack([student_ids, question_ids], axis=-1)
   # Need to tile the indices across the batch, for gather_nd.
   batch_shape = ps.shape(dense_correct)[:-2]
   broadcast_shape = ps.concat([ps.ones_like(batch_shape), test_y_idx.shape],
                               axis=-1)
   test_y_idx = tf.reshape(test_y_idx, broadcast_shape)
   test_y_idx = tf.tile(test_y_idx, ps.concat([batch_shape, [1, 1]], axis=-1))
   return tf.gather_nd(
       dense_correct, test_y_idx, batch_dims=ps.size(batch_shape))
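The tile-then-gather pattern can be reproduced in plain TensorFlow (a standalone sketch with made-up shapes):

dense = tf.reshape(tf.range(2 * 3 * 4), [2, 3, 4])     # batch of 2 (students x questions) tables
idx = tf.constant([[0, 1], [2, 3]])                    # (student, question) pairs
idx = tf.tile(idx[tf.newaxis], [2, 1, 1])              # repeat the indices for each batch member
print(tf.gather_nd(dense, idx, batch_dims=1).numpy())  # [[ 1 11] [13 23]]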
Example #25
def _compute_calibration_bin_statistics(num_bins,
                                        logits=None,
                                        labels_true=None,
                                        labels_predicted=None):
    """Compute binning statistics required for calibration measures.

  Args:
    num_bins: int, number of probability bins, e.g. 10.
    logits: Tensor, (n,nlabels), with logits for n instances and nlabels.
    labels_true: Tensor, (n,), with tf.int32 or tf.int64 elements containing
      ground truth class labels in the range [0, nlabels).
    labels_predicted: Tensor, (n,), with tf.int32 or tf.int64 elements
      containing decisions of the predictive system.  If `None`, we will use
      the argmax decision using the `logits`.

  Returns:
    bz: Tensor, shape (2,num_bins), tf.int32, counts of incorrect (row 0) and
      correct (row 1) predictions in each of the `num_bins` probability bins.
    pmean_observed: Tensor, shape (num_bins,), tf.float32, the mean predictive
      probabilities in each probability bin.
  """

    if labels_predicted is None:
        # If no labels are provided, we take the label with the maximum probability
        # decision.  This corresponds to the optimal expected minimum loss decision
        # under 0/1 loss.
        pred_y = tf.argmax(logits, axis=1, output_type=labels_true.dtype)
    else:
        pred_y = labels_predicted

    correct = tf.cast(tf.equal(pred_y, labels_true), tf.int32)

    # Collect predicted probabilities of decisions
    pred = tf.nn.softmax(logits, axis=1)
    prob_y = tf.gather(pred, pred_y[:, tf.newaxis],
                       batch_dims=1)  # p(pred_y | x)
    prob_y = tf.reshape(prob_y, (ps.size(prob_y), ))

    # Compute b/z histogram statistics:
    # bz[0,bin] contains counts of incorrect predictions in the probability bin.
    # bz[1,bin] contains counts of correct predictions in the probability bin.
    bins = tf.histogram_fixed_width_bins(prob_y, [0.0, 1.0], nbins=num_bins)
    event_bin_counts = tf.math.bincount(correct * num_bins + bins,
                                        minlength=2 * num_bins,
                                        maxlength=2 * num_bins)
    event_bin_counts = tf.reshape(event_bin_counts, (2, num_bins))

    # Compute mean predicted probability value in each of the `num_bins` bins
    pmean_observed = tf.math.unsorted_segment_sum(prob_y, bins, num_bins)
    tiny = np.finfo(dtype_util.as_numpy_dtype(logits.dtype)).tiny
    pmean_observed = pmean_observed / (
        tf.cast(tf.reduce_sum(event_bin_counts, axis=0), logits.dtype) + tiny)

    return event_bin_counts, pmean_observed
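A toy usage sketch (assuming the `ps`, `dtype_util`, and `np` imports the function already relies on):

logits = tf.constant([[2., 0.], [0., 2.], [2., 0.], [0., 2.]])
labels = tf.constant([0, 1, 1, 1], dtype=tf.int64)
counts, pmean = _compute_calibration_bin_statistics(
    num_bins=5, logits=logits, labels_true=labels)
# counts has shape (2, 5): incorrect (row 0) and correct (row 1) predictions per bin;
# pmean has shape (5,): mean predicted probability within each bin.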
Example #26
    def _update_principal_component_ema(
        self,
        reduce_axes,
        state,
        step,
        principal_component_ema_points,
        ema_principal_component,
    ):
        # This is a batched version of Oja's algorithm. For the learning rate step,
        # we use Welford's algorithm where the number of points is clamped to a
        # function that grows slower than N.

        event_axes = tf.nest.map_structure(
            lambda x: ps.range(ps.size(reduce_axes), ps.rank(x)) - ps.rank(x),
            state)
        if self.experimental_shard_axis_names is None:
            shard_axis_names = tf.nest.map_structure(lambda _: None, state)
        else:
            shard_axis_names = self.experimental_shard_axis_names

        def _center_part(x):
            return x - distribute_lib.reduce_mean(
                x, reduce_axes, self.experimental_reduce_chain_axis_names)

        state_dot_p = _dot_product(tf.nest.map_structure(_center_part, state),
                                   ema_principal_component, event_axes,
                                   shard_axis_names)

        def _weighted_sum_part(x):
            return distribute_lib.reduce_sum(
                bu.left_justified_expand_dims_like(state_dot_p, x) * x,
                reduce_axes, self.experimental_reduce_chain_axis_names)

        new_principal_component = _normalize(
            tf.nest.map_structure(_weighted_sum_part, state), event_axes,
            shard_axis_names)

        def _ema_part(old_x, new_x):
            weight = 1. / (
                tf.cast(principal_component_ema_points, old_x.dtype) + 1.)
            return old_x + (new_x - old_x) * weight

        new_principal_component_ema_points = tf.minimum(
            principal_component_ema_points + 1,
            tf.maximum(1, step // self.principal_component_ema_factor))
        new_ema_principal_component = _normalize(
            tf.nest.map_structure(_ema_part, ema_principal_component,
                                  new_principal_component), event_axes,
            shard_axis_names)
        return tf.nest.map_structure(
            lambda x, y: tf.where(step < self.num_adaptation_steps, x, y),
            (new_principal_component_ema_points, new_ema_principal_component),
            (principal_component_ema_points, ema_principal_component),
        )
Example #27
  def _get_reinterpreted_batch_ndims(self,
                                     distribution_batch_shape_tensor=None):
    if self._static_reinterpreted_batch_ndims is not None:
      return self._static_reinterpreted_batch_ndims
    if self._reinterpreted_batch_ndims is not None:
      return tf.convert_to_tensor(self._reinterpreted_batch_ndims)

    if distribution_batch_shape_tensor is None:
      distribution_batch_shape_tensor = self.distribution.batch_shape_tensor()
    return ps.cast(
        ps.maximum(0, ps.size(distribution_batch_shape_tensor) - 1),
        np.int32)
Example #28
    def even_odd_swaps(num_replica,
                       batch_shape=(),
                       step_count=None,
                       seed=None):
        """Make deterministic even_odd one time swaps."""
        if step_count is None:
            raise ValueError('`step_count` must be supplied. Found `None`.')
        del seed  # Unused for this function.
        with tf.name_scope(name or 'even_odd_swaps'):
            # Period is 1 / frequency, and we want period = Inf if frequency = 0.
            # safe_swap_period is the correct swap period in case swap_frequency > 0.
            # If swap_frequency == 0, safe_swap_period is set to 1 (to avoid integer
            # div by zero below). We will hard-set this case to "null swap."
            swap_freq = tf.convert_to_tensor(swap_frequency,
                                             name='swap_frequency')
            safe_swap_period = tf.cast(
                tf.where(swap_freq > 0,
                         tf.math.ceil(tf.math.reciprocal_no_nan(swap_freq)),
                         1),
                # Although period = 1 / frequency may have roundoff error, and result
                # in a period different than what the user intended, the
                # user will end up with a single integer period, and thus well defined
                # deterministic swaps.
                tf.int32,
            )

            # u selects parity.  E.g.,
            #  u==False ==> [1, 0, 3, 2, 4] even parity swaps
            #  u==True ==>  [0, 2, 1, 4, 3] odd parity swaps
            # If there are 2 replicas, then the "True" swaps are null
            # swaps...which would contradict the user provided `swap_frequency`.
            # So special case num_replica==2, forcing u==False in this case.
            u_shape = ps.concat(
                (ps.ones(1, dtype=tf.int32), ps.cast(batch_shape, tf.int32)),
                axis=0)
            u = tf.fill(u_shape,
                        tf.cast((step_count // safe_swap_period) % 2, tf.bool))
            u = tf.where(num_replica > 2, u, False)

            x = bu.left_justified_expand_dims_to(tf.range(num_replica,
                                                          dtype=tf.int64),
                                                 rank=ps.size(u_shape))
            y = tf.where(tf.equal(x % 2, tf.cast(u, dtype=tf.int64)), x + 1,
                         x - 1)
            y = tf.clip_by_value(y, 0, num_replica - 1)
            # TODO(b/142689785): Consider using tf.cond and returning an empty list
            # then in REMC consider using a tf.cond for short-circuiting.
            return tf.where(
                (tf.cast(step_count % safe_swap_period, tf.bool)
                 | tf.math.equal(swap_freq, 0)),
                x,  # Don't swap
                y,  # Swap
            )
Example #29
def _canonicalize_steps_to_trace(step_indices_to_trace, num_timesteps):
    """Canonicalizes `3` -> `[3]`, `[-2, -1]` -> `[N - 2, N - 1]`, etc."""
    step_indices_to_trace = tf.convert_to_tensor(
        step_indices_to_trace,
        dtype_hint=tf.int32)  # Warning: breaks gradients.
    traced_steps_have_rank_zero = ps.equal(
        ps.rank_from_shape(ps.shape(step_indices_to_trace)), 0)
    # Canonicalize negative step indices as positive.
    step_indices_to_trace = ps.where(step_indices_to_trace < 0,
                                     num_timesteps + step_indices_to_trace,
                                     step_indices_to_trace)
    # Canonicalize scalars as length-one vectors.
    return (ps.reshape(step_indices_to_trace,
                       [ps.size(step_indices_to_trace)]),
            traced_steps_have_rank_zero)
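A usage sketch (assuming the same `ps` import as above):

print(_canonicalize_steps_to_trace(3, num_timesteps=10))         # ([3], True)
print(_canonicalize_steps_to_trace([-2, -1], num_timesteps=10))  # ([8, 9], False)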
Example #30
  def _build_sub_tree(self,
                      direction,
                      log_slice_sample,
                      nsteps,
                      initial_state,
                      continue_tree,
                      trace_arrays,
                      name=None):
    with tf.name_scope('build_sub_tree'):
      batch_size = prefer_static.size(log_slice_sample)
      initial_state_candidate = TreeDoublingStateCandidate(
          state=initial_state.state,
          target=initial_state.target,
          target_grad_parts=initial_state.target_grad_parts,
          # We never want to select the initial state
          weight=tf.zeros(batch_size, dtype=TREE_COUNT_DTYPE))
      [
          leapfrogs_computed,
          final_state,
          candidate_tree_state,
          final_continue_tree,
          trace_arrays,
      ] = tf.while_loop(
          cond=lambda iter_, state, state_c, continue_tree, trace_arrays: (  # pylint: disable=g-long-lambda
              (iter_ < nsteps) & tf.reduce_any(continue_tree)),
          body=lambda iter_, state, state_c, continue_tree, trace_arrays: (  # pylint: disable=g-long-lambda
              self._loop_build_sub_tree(
                  direction, log_slice_sample,
                  iter_, state, state_c, continue_tree, trace_arrays)),
          loop_vars=(
              tf.zeros([], dtype=tf.int32, name='iter'),
              initial_state,
              initial_state_candidate,
              continue_tree,
              trace_arrays,
          ),
          parallel_iterations=TF_WHILE_PARALLEL_ITERATIONS,
      )

    return (
        candidate_tree_state,
        final_state,
        final_continue_tree,
        leapfrogs_computed,
    )