def make_kernel_bias_posterior_mvn_diag(kernel_shape,
                                        bias_shape,
                                        dtype=tf.float32,
                                        kernel_initializer=None,
                                        bias_initializer=None):
  """Create learnable posterior for Variational layers with kernel and bias."""
  if kernel_initializer is None:
    kernel_initializer = tf.initializers.glorot_normal()
  if bias_initializer is None:
    bias_initializer = tf.initializers.glorot_normal()
  make_loc = lambda shape, init, name: tf.Variable(  # pylint: disable=g-long-lambda
      init(shape, dtype=dtype),
      name=name + '_loc')
  make_scale = lambda shape, name: TransformedVariable(  # pylint: disable=g-long-lambda
      tf.ones(shape, dtype=dtype),
      Chain([Shift(1e-5), Softplus()]),
      name=name + '_scale')
  return JointDistributionSequential([
      Independent(Normal(loc=make_loc(kernel_shape,
                                      kernel_initializer,
                                      'posterior_kernel'),
                         scale=make_scale(kernel_shape, 'posterior_kernel')),
                  reinterpreted_batch_ndims=prefer_static.size(kernel_shape),
                  name='posterior_kernel'),
      Independent(Normal(loc=make_loc(bias_shape,
                                      bias_initializer,
                                      'posterior_bias'),
                         scale=make_scale(bias_shape, 'posterior_bias')),
                  reinterpreted_batch_ndims=prefer_static.size(bias_shape),
                  name='posterior_bias'),
  ])
def reduce_logmeanexp(input_tensor, axis=None, keepdims=False, name=None):
  """Computes `log(mean(exp(input_tensor)))`.

  Reduces `input_tensor` along the dimensions given in `axis`. Unless
  `keepdims` is true, the rank of the tensor is reduced by 1 for each entry
  in `axis`. If `keepdims` is true, the reduced dimensions are retained with
  length 1.

  If `axis` has no entries, all dimensions are reduced, and a tensor with a
  single element is returned.

  This function is more numerically stable than
  `log(reduce_mean(exp(input)))`. It avoids overflows caused by taking the
  exp of large inputs and underflows caused by taking the log of small
  inputs.

  Args:
    input_tensor: The tensor to reduce. Should have numeric type.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keepdims: Boolean. Whether to keep the axis as singleton dimensions.
      Default value: `False` (i.e., squeeze the reduced dimensions).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., `'reduce_logmeanexp'`).

  Returns:
    log_mean_exp: The reduced tensor.
  """
  with tf.name_scope(name or 'reduce_logmeanexp'):
    lse = tf.reduce_logsumexp(input_tensor, axis=axis, keepdims=keepdims)
    n = prefer_static.size(input_tensor) // prefer_static.size(lse)
    log_n = tf.math.log(tf.cast(n, lse.dtype))
    return lse - log_n
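# A quick numerical sanity check for `reduce_logmeanexp` (illustrative only,
# assuming `tensorflow` is imported as `tf` as in the code above): the naive
# formulation overflows in float32, while the logsumexp-based one does not.
import tensorflow as tf

x = tf.constant([1000., 1001., 1002.])
naive = tf.math.log(tf.reduce_mean(tf.exp(x)))  # inf: exp(1000.) overflows.
stable = reduce_logmeanexp(x)                   # ~1001.31, computed stably.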
def make_kernel_bias_posterior_mvn_diag(
    kernel_shape,
    bias_shape,
    kernel_initializer=None,
    bias_initializer=None,
    kernel_batch_ndims=0,
    bias_batch_ndims=0,
    dtype=tf.float32,
    kernel_name='posterior_kernel',
    bias_name='posterior_bias'):
  """Create learnable posterior for Variational layers with kernel and bias.

  Args:
    kernel_shape: ...
    bias_shape: ...
    kernel_initializer: ...
      Default value: `None` (i.e.,
      `tfp.experimental.nn.initializers.glorot_uniform()`).
    bias_initializer: ...
      Default value: `None` (i.e., `tf.zeros`).
    kernel_batch_ndims: ...
      Default value: `0`.
    bias_batch_ndims: ...
      Default value: `0`.
    dtype: ...
      Default value: `tf.float32`.
    kernel_name: ...
      Default value: `"posterior_kernel"`.
    bias_name: ...
      Default value: `"posterior_bias"`.

  Returns:
    kernel_and_bias_distribution: ...
  """
  if kernel_initializer is None:
    kernel_initializer = nn_init_lib.glorot_uniform()
  if bias_initializer is None:
    bias_initializer = tf.zeros
  make_loc = lambda init_fn, shape, batch_ndims, name: tf.Variable(  # pylint: disable=g-long-lambda
      _try_call_init_fn(init_fn, shape, dtype, batch_ndims),
      name=name + '_loc')
  # Setting the initial scale to a relatively small value causes the `loc` to
  # quickly move toward a lower loss value.
  make_scale = lambda shape, name: TransformedVariable(  # pylint: disable=g-long-lambda
      tf.fill(shape, value=tf.constant(1e-3, dtype=dtype)),
      Chain([Shift(1e-5), Softplus()]),
      name=name + '_scale')
  return JointDistributionSequential([
      Independent(Normal(loc=make_loc(kernel_initializer,
                                      kernel_shape,
                                      kernel_batch_ndims,
                                      kernel_name),
                         scale=make_scale(kernel_shape, kernel_name)),
                  reinterpreted_batch_ndims=prefer_static.size(kernel_shape),
                  name=kernel_name),
      Independent(Normal(loc=make_loc(bias_initializer,
                                      bias_shape,
                                      bias_batch_ndims,
                                      bias_name),
                         scale=make_scale(bias_shape, bias_name)),
                  reinterpreted_batch_ndims=prefer_static.size(bias_shape),
                  name=bias_name),
  ])
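# A minimal usage sketch (not from the library's docs): the returned
# `JointDistributionSequential` yields a (kernel, bias) pair whose `loc` and
# `scale` are trainable variables, so it can serve directly as a surrogate
# posterior. Assumes the TFP symbols used above are in scope and a recent TFP
# that accepts stateless seed tuples.
posterior = make_kernel_bias_posterior_mvn_diag(
    kernel_shape=[5, 3], bias_shape=[3])
kernel, bias = posterior.sample(seed=(0, 1))    # Shapes [5, 3] and [3].
nll = -posterior.log_prob((kernel, bias))       # Scalar, e.g. for a KL term.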
def _reshape_part(part, event_shape):
  # Note: `self` is captured from the enclosing method's scope.
  part = tf.cast(part, self.dtype)
  static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
  if static_rank == 1:
    return part
  new_shape = ps.concat([
      ps.shape(part)[:ps.size(ps.shape(part)) - ps.size(event_shape)],
      [-1],
  ], axis=-1)
  return tf.reshape(part, ps.cast(new_shape, tf.int32))
def _cumulative_broadcast_dynamic(event_shape):
  broadcast_shapes = [
      ps.slice(s, begin=[0], size=[ps.size(s) - 1]) for s in event_shape]
  cumulative_shapes = [broadcast_shapes[0]]
  for shape in broadcast_shapes[1:]:
    out_shape = ps.broadcast_shape(shape, cumulative_shapes[-1])
    cumulative_shapes.append(out_shape)
  return [
      ps.concat([b, ps.slice(s, begin=[ps.size(s) - 1], size=[1])], axis=0)
      for b, s in zip(cumulative_shapes, event_shape)]
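# For concreteness (hypothetical shapes): only the trailing (event) dimension
# of each shape is preserved; the leading parts are cumulatively broadcast.
shapes = [tf.constant([1, 4, 2]), tf.constant([3, 1, 5])]
out = _cumulative_broadcast_dynamic(shapes)
# out[0] == [1, 4, 2]   (first shape passes through unchanged)
# out[1] == [3, 4, 5]   ([3, 1] broadcast with [1, 4] -> [3, 4]; 5 kept)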
def _reshape_part(part, event_shape):
  part = tf.cast(part, self.dtype)
  new_shape = ps.concat([
      ps.shape(part)[:ps.size(ps.shape(part)) - ps.size(event_shape)],
      [-1],
  ], axis=-1)
  return tf.reshape(part, ps.cast(new_shape, tf.int32))
def reduce_logmeanexp(input_tensor,
                      axis=None,
                      keepdims=False,
                      experimental_named_axis=None,
                      experimental_allow_all_gather=False,
                      name=None):
  """Computes `log(mean(exp(input_tensor)))`.

  Reduces `input_tensor` along the dimensions given in `axis`. Unless
  `keepdims` is true, the rank of the tensor is reduced by 1 for each entry
  in `axis`. If `keepdims` is true, the reduced dimensions are retained with
  length 1.

  If `axis` has no entries, all dimensions are reduced, and a tensor with a
  single element is returned.

  This function is more numerically stable than
  `log(reduce_mean(exp(input)))`. It avoids overflows caused by taking the
  exp of large inputs and underflows caused by taking the log of small
  inputs.

  Args:
    input_tensor: The tensor to reduce. Should have numeric type.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keepdims: Boolean. Whether to keep the axis as singleton dimensions.
      Default value: `False` (i.e., squeeze the reduced dimensions).
    experimental_named_axis: A `str` or list of `str` axis names to
      additionally reduce over. Providing `None` will not reduce over any
      axes.
    experimental_allow_all_gather: Allow using an `all_gather`-based fallback
      under TensorFlow when computing the distributed maximum. This fallback
      is only efficient when `axis` reduces away most of the dimensions of
      `input_tensor`.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., `'reduce_logmeanexp'`).

  Returns:
    log_mean_exp: The reduced tensor.
  """
  with tf.name_scope(name or 'reduce_logmeanexp'):
    named_axes = distribute_lib.canonicalize_named_axis(
        experimental_named_axis)
    lse = distribute_lib.reduce_logsumexp(
        input_tensor,
        axis=axis,
        keepdims=keepdims,
        named_axis=named_axes,
        allow_all_gather=experimental_allow_all_gather)
    n = ps.size(input_tensor) // ps.size(lse)
    for named_axis in named_axes:
      n = n * distribute_lib.get_axis_size(named_axis)
    log_n = tf.math.log(tf.cast(n, lse.dtype))
    return lse - log_n
def make_kernel_bias_posterior_mvn_diag(kernel_shape,
                                        bias_shape,
                                        dtype=tf.float32,
                                        kernel_initializer=None,
                                        bias_initializer=None,
                                        kernel_name='posterior_kernel',
                                        bias_name='posterior_bias'):
  """Create learnable posterior for Variational layers with kernel and bias.

  Args:
    kernel_shape: ...
    bias_shape: ...
    dtype: ...
      Default value: `tf.float32`.
    kernel_initializer: ...
      Default value: `None` (i.e., `tf.initializers.glorot_uniform()`).
    bias_initializer: ...
      Default value: `None` (i.e., `tf.zeros`).
    kernel_name: ...
      Default value: `"posterior_kernel"`.
    bias_name: ...
      Default value: `"posterior_bias"`.

  Returns:
    kernel_and_bias_distribution: ...
  """
  if kernel_initializer is None:
    kernel_initializer = tf.initializers.glorot_uniform()
  if bias_initializer is None:
    bias_initializer = tf.zeros
  make_loc = lambda shape, init, name: tf.Variable(  # pylint: disable=g-long-lambda
      init(shape, dtype=dtype),
      name=name + '_loc')
  make_scale = lambda shape, name: TransformedVariable(  # pylint: disable=g-long-lambda
      tf.ones(shape, dtype=dtype),
      Chain([Shift(1e-5), Softplus()]),
      name=name + '_scale')
  return JointDistributionSequential([
      Independent(Normal(loc=make_loc(kernel_shape,
                                      kernel_initializer,
                                      kernel_name),
                         scale=make_scale(kernel_shape, kernel_name)),
                  reinterpreted_batch_ndims=prefer_static.size(kernel_shape),
                  name=kernel_name),
      Independent(Normal(loc=make_loc(bias_shape,
                                      bias_initializer,
                                      bias_name),
                         scale=make_scale(bias_shape, bias_name)),
                  reinterpreted_batch_ndims=prefer_static.size(bias_shape),
                  name=bias_name),
  ])
def adjacent_swaps(num_replica, batch_shape=(), seed=None):
  """Make random shuffle using only one time swaps."""
  # Note: `name` and `prob_swap` are captured from the enclosing scope.
  with tf.name_scope(name or 'adjacent_swaps'):
    seed = SeedStream(seed, salt='random_adjacent_shuffle')
    # u selects parity. E.g.,
    #  u==True  ==> [0, 2, 1, 4, 3] like swaps
    #  u==False ==> [1, 0, 3, 2, 4] like swaps
    # If there are only 2 replicas, then the "True" swaps are null
    # swaps...which would contradict the user provided `prob_swap`.
    # So special case num_replica==2, forcing u==False in this case.
    u_shape = prefer_static.concat(
        (tf.ones(1, dtype=tf.int32), tf.cast(batch_shape, tf.int32)), axis=0)
    u = tf.random.uniform(u_shape, seed=seed()) < 0.5
    u = tf.where(num_replica > 2, u, False)

    x = mcmc_util.left_justified_expand_dims_to(
        tf.range(num_replica, dtype=tf.int64),
        rank=prefer_static.size(u_shape))
    y = tf.where(tf.equal(x % 2, tf.cast(u, dtype=tf.int64)), x + 1, x - 1)
    y = tf.clip_by_value(y, 0, num_replica - 1)
    # TODO(b/142689785): Consider using tf.cond and returning an empty list
    # then in REMC consider using a tf.cond for short-circuiting.
    return tf.where(
        tf.random.uniform(batch_shape, seed=seed()) < prob_swap, y, x)
def _forward_event_shape_tensor(self, input_shape, is_inverse=False):
  ndims = ps.size(input_shape)
  indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1])
  extra_sizes = ps.reduce_sum(self.paddings, axis=-1)
  update_fn = (ps.tensor_scatter_nd_sub if is_inverse
               else ps.tensor_scatter_nd_add)
  return update_fn(ps.identity(input_shape), indices, extra_sizes)
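# This helper implements the event-shape bookkeeping behind a padding
# bijector such as `tfp.bijectors.Pad`; a usage sketch under that assumption:
import tensorflow_probability as tfp

bij = tfp.bijectors.Pad(paddings=[[1, 2]])  # Pad 1 left, 2 right on axis -1.
bij.forward_event_shape_tensor([3])         # [6] == 3 + 1 + 2.
bij.inverse_event_shape_tensor([6])         # [3]: the padding is removed.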
def test_dynamic(self):
  if tf.executing_eagerly():
    return
  x = tf1.placeholder_with_default(
      tf.random.normal([3, 4, 5], seed=tfp_test_util.test_seed()),
      shape=None)
  self.assertAllEqual(3 * 4 * 5, self.evaluate(prefer_static.size(x)))
def prepare_tuple_argument(arg, n, arg_name, validate_args=False):
  """Helper which processes `Tensor`s to tuples in standard form."""
  # Short-circuiting incoming lists and tuples here avoids both Tensor
  # packing / unpacking and numpy 1.20.+ pickiness about
  # np.array(tuple of Tensor).
  if isinstance(arg, (tuple, list)):
    if len(arg) == n:
      return tuple(arg)
    if len(arg) == 1:
      return (arg[0],) * n

  arg_size = ps.size(arg)
  arg_size_ = tf.get_static_value(arg_size)
  assertions = []
  if arg_size_ is not None:
    if arg_size_ not in (1, n):
      raise ValueError(
          'The size of `{}` must be equal to `1` or to the rank '
          'of the convolution (={}). Saw size = {}'.format(
              arg_name, n, arg_size_))
  elif validate_args:
    assertions.append(assert_util.assert_equal(
        ps.logical_or(arg_size == 1, arg_size == n),
        True,
        message=('The size of `{}` must be equal to `1` or to the rank of '
                 'the convolution (={})'.format(arg_name, n))))

  with tf.control_dependencies(assertions):
    arg = ps.broadcast_to(arg, shape=[n])
    arg = ps.unstack(arg, num=n)
    return arg
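# Hypothetical calls showing the normalization this helper performs for,
# e.g., convolution `strides`:
prepare_tuple_argument([1, 2], n=2, arg_name='strides')  # -> (1, 2)
prepare_tuple_argument([3], n=2, arg_name='strides')     # -> (3, 3)
prepare_tuple_argument(3, n=2, arg_name='strides')       # broadcast to length 2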
def bootstrap_results(self, init_state):
  """Creates initial `previous_kernel_results` using a supplied `state`."""
  with tf.name_scope(self.name + '.bootstrap_results'):
    if not tf.nest.is_nested(init_state):
      init_state = [init_state]
    # Pad the step_size so it is compatible with the states.
    step_size = self.step_size
    if len(step_size) == 1:
      step_size = step_size * len(init_state)
      self._step_size = step_size
    if len(step_size) != len(init_state):
      raise ValueError('Expected either one step size or {} (size of '
                       '`init_state`), but found {}'.format(
                           len(init_state), len(step_size)))
    dummy_momentum = [tf.ones_like(state) for state in init_state]
    [
        _,
        _,
        current_target_log_prob,
        current_grads_log_prob,
    ] = leapfrog_impl.process_args(self.target_log_prob_fn, dummy_momentum,
                                   init_state)
    batch_size = prefer_static.size(current_target_log_prob)
    return NUTSKernelResults(
        target_log_prob=current_target_log_prob,
        grads_target_log_prob=current_grads_log_prob,
        leapfrogs_computed=tf.zeros([], dtype=tf.int32,
                                    name='leapfrogs_computed'),
        is_accepted=tf.zeros([batch_size], dtype=tf.bool,
                             name='is_accepted'),
        reach_max_depth=tf.zeros([batch_size], dtype=tf.bool,
                                 name='reach_max_depth'),
    )
def expand_dims(x, axis, name=None):
  """Like `tf.expand_dims` but accepts a vector of axes to expand."""
  with tf.name_scope(name or 'expand_dims'):
    x = tf.convert_to_tensor(x, name='x')
    axis = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis')
    nx = prefer_static.rank(x)
    na = prefer_static.size(axis)
    is_neg_axis = axis < 0
    k = prefer_static.reduce_sum(
        prefer_static.cast(is_neg_axis, axis.dtype))
    axis = prefer_static.where(is_neg_axis, axis + nx, axis)
    axis = prefer_static.sort(axis)
    axis_neg, axis_pos = prefer_static.split(axis, [k, -1])
    idx = prefer_static.argsort(prefer_static.concat([
        axis_pos,
        prefer_static.range(nx),
        axis_neg,
    ], axis=0), stable=True)
    shape = prefer_static.pad(prefer_static.shape(x),
                              paddings=[[na - k, k]],
                              constant_values=1)
    shape = prefer_static.gather(shape, idx)
    return tf.reshape(x, shape)
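# Behavioral examples (assuming `tf` in scope): for a single axis this
# matches `tf.expand_dims`; the payoff is accepting several axes in one call.
x = tf.zeros([3, 4])
expand_dims(x, axis=[1]).shape     # [3, 1, 4], same as tf.expand_dims(x, 1).
expand_dims(x, axis=[-2]).shape    # [3, 1, 4], same as tf.expand_dims(x, -2).
expand_dims(x, axis=[0, 2]).shape  # [1, 3, 4, 1]; positive axes index the
                                   # input's dims, so 2 appends at the end.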
def _prepare_args_with_initial_vertex(objective_function,
                                      initial_vertex,
                                      step_sizes,
                                      objective_at_initial_vertex,
                                      batch_evaluate_objective):
  """Constructs a standard axes aligned simplex."""
  dim = ps.size(initial_vertex)
  # tf.eye complains about np.array(.., np.int32) num_rows, only welcomes
  # numpy scalars. TODO(b/162529062): Remove the following line.
  dim = dim if tf.is_tensor(dim) else int(dim)
  num_vertices = dim + 1
  unit_vectors_along_axes = tf.reshape(
      tf.eye(dim, dim, dtype=dtype_util.base_dtype(initial_vertex.dtype)),
      ps.concat([[dim], ps.shape(initial_vertex)], axis=0))

  # If step_sizes does not broadcast to initial_vertex, the multiplication
  # in the second term will fail.
  simplex_face = initial_vertex + step_sizes * unit_vectors_along_axes
  simplex = tf.concat([tf.expand_dims(initial_vertex, axis=0), simplex_face],
                      axis=0)
  # Evaluate the objective function at the simplex vertices.
  if objective_at_initial_vertex is None:
    objective_at_simplex, num_evaluations = _evaluate_objective_multiple(
        objective_function, simplex, batch_evaluate_objective)
  else:
    objective_at_simplex_face, num_evaluations = (
        _evaluate_objective_multiple(
            objective_function, simplex_face, batch_evaluate_objective))
    objective_at_simplex = tf.concat(
        [tf.expand_dims(objective_at_initial_vertex, axis=0),
         objective_at_simplex_face],
        axis=0)

  return (dim, num_vertices, simplex, objective_at_simplex, num_evaluations)
def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
  """Common argument defaulting logic for windowed statistics."""
  if high_indices is None:
    high_indices = tf.range(ps.shape(x)[axis]) + 1
  else:
    high_indices = tf.convert_to_tensor(high_indices)
  if low_indices is None:
    low_indices = high_indices // 2
  else:
    low_indices = tf.convert_to_tensor(low_indices)
  # Broadcast indices together.
  high_indices = high_indices + tf.zeros_like(low_indices)
  low_indices = low_indices + tf.zeros_like(high_indices)

  # TODO(axch): Support batch low and high indices. That would complicate
  # this shape munging (though tf.gather should work fine).

  # We want to place `low_counts` and `high_counts` at the `axis` position,
  # so we reshape them to shape `[1, 1, ..., 1, N, 1, ..., 1]`, where the
  # `N` is at `axis`. The `counts_shp`, below, is this shape.
  size = ps.size(high_indices)
  counts_shp = ps.one_hot(axis, depth=ps.rank(x), on_value=size, off_value=1)

  low_counts = tf.reshape(tf.cast(low_indices, dtype=x.dtype),
                          shape=counts_shp)
  high_counts = tf.reshape(tf.cast(high_indices, dtype=x.dtype),
                           shape=counts_shp)
  return low_indices, high_indices, low_counts, high_counts
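# Defaulting behavior on a length-6 series (illustrative): each window ends
# at `high` (exclusive) and starts at `low`, which defaults to half of
# `high`, i.e. the second half of the data seen so far.
x = tf.range(6, dtype=tf.float32)
low, high, low_counts, high_counts = _prepare_window_args(x)
# high == [1, 2, 3, 4, 5, 6]; low == high // 2 == [0, 1, 1, 2, 2, 3].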
def _axis_size(x, axis=None):
  """Get number of elements of `x` in `axis`, as type `x.dtype`."""
  if axis is None:
    return prefer_static.cast(prefer_static.size(x), x.dtype)
  return prefer_static.cast(
      prefer_static.reduce_prod(
          prefer_static.gather(prefer_static.shape(x), axis)),
      x.dtype)
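# Small examples (assuming `prefer_static` as imported above):
x = tf.zeros([2, 3, 5])
_axis_size(x)               # 30.0: total element count, cast to x.dtype.
_axis_size(x, axis=[0, 2])  # 10.0: product of the sizes of axes 0 and 2.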
def expand_dims_(x):
  """Implementation of `expand_dims`."""
  # Note: `name` and `axis` are captured from the enclosing scope.
  with tf.name_scope(name or 'expand_dims'):
    x = tf.convert_to_tensor(x, name='x')
    new_axis = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis')
    nx = prefer_static.rank(x)
    na = prefer_static.size(new_axis)
    is_neg_axis = new_axis < 0
    k = prefer_static.reduce_sum(
        prefer_static.cast(is_neg_axis, new_axis.dtype))
    new_axis = prefer_static.where(is_neg_axis, new_axis + nx, new_axis)
    new_axis = prefer_static.sort(new_axis)
    axis_neg, axis_pos = prefer_static.split(new_axis, [k, -1])
    idx = prefer_static.argsort(prefer_static.concat([
        axis_pos,
        prefer_static.range(nx),
        axis_neg,
    ], axis=0), stable=True)
    shape = prefer_static.pad(prefer_static.shape(x),
                              paddings=[[na - k, k]],
                              constant_values=1)
    shape = prefer_static.gather(shape, idx)
    return tf.reshape(x, shape)
def _get_permutations(num_results, dims, seed=None):
  """Uniform iid sample from the space of permutations.

  Draws a sample of size `num_results` from the group of permutations of
  degrees specified by the `dims` tensor. These are packed together into one
  tensor such that each row is one sample from each of the dimensions in
  `dims`. For example, if dims = [2,3] and num_results = 2, the result is a
  tensor of shape [2, 2 + 3] and the first row of the result might look like:
  [1, 0, 2, 0, 1]. The first two elements are a permutation over 2 elements
  while the next three are a permutation over 3 elements.

  Args:
    num_results: A positive scalar `Tensor` of integral type. The number of
      draws from the discrete uniform distribution over the permutation
      groups.
    dims: A 1D `Tensor` of the same dtype as `num_results`. The degree of
      the permutation groups from which to sample.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    permutations: A `Tensor` of shape `[num_results, sum(dims)]` and the
      same dtype as `dims`.
  """
  seeds = samplers.split_seed(seed, n=ps.size(dims))

  def generate_one(dim, seed):
    return tf.argsort(samplers.uniform([num_results, dim], seed=seed),
                      axis=-1)

  return tf.concat(
      [generate_one(dim, seed)
       for dim, seed in zip(tf.unstack(dims), tf.unstack(seeds))],
      axis=-1)
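# A hypothetical draw (stateless seed tuple shown; any seed accepted by
# `tfp.random.sanitize_seed` works):
perms = _get_permutations(num_results=2, dims=tf.constant([2, 3]),
                          seed=(0, 1))
# perms.shape == [2, 5]: columns 0-1 permute {0, 1}; columns 2-4 permute
# {0, 1, 2}, e.g. [[1, 0, 2, 0, 1], [0, 1, 1, 2, 0]].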
def adjacent_swaps(num_replica, batch_shape=(), step_count=None, seed=None):
  """Make random shuffle using only one time swaps."""
  del step_count  # Unused for this function.
  # Note: `name` and `prob_swap` are captured from the enclosing scope.
  with tf.name_scope(name or 'adjacent_swaps'):
    parity_seed, proposal_seed = samplers.split_seed(seed)
    # u selects parity. E.g.,
    #  u==False ==> [1, 0, 3, 2, 4] even parity swaps
    #  u==True  ==> [0, 2, 1, 4, 3] odd parity swaps
    # If there are only 2 replicas, then the "True" swaps are null
    # swaps...which would contradict the user provided `prob_swap`.
    # So special case num_replica==2, forcing u==False in this case.
    u_shape = ps.concat(
        (ps.ones(1, dtype=tf.int32), ps.cast(batch_shape, tf.int32)), axis=0)
    u = samplers.uniform(u_shape, seed=parity_seed) < 0.5
    u = tf.where(num_replica > 2, u, False)

    x = bu.left_justified_expand_dims_to(
        ps.range(num_replica, dtype=tf.int64),
        rank=ps.size(u_shape))
    y = tf.where(tf.equal(x % 2, tf.cast(u, dtype=tf.int64)), x + 1, x - 1)
    y = tf.clip_by_value(y, 0, num_replica - 1)
    # TODO(b/142689785): Consider using tf.cond and returning an empty list
    # then in REMC consider using a tf.cond for short-circuiting.
    return tf.where(
        samplers.uniform(batch_shape, seed=proposal_seed) < prob_swap, y, x)
def _slice_single_param(param, param_event_ndims, slices, dist_batch_shape):
  """Slices a single parameter of a distribution.

  Args:
    param: A `Tensor`, the original parameter to slice.
    param_event_ndims: `int` event parameterization rank for this parameter.
    slices: A `tuple` of normalized slices.
    dist_batch_shape: The distribution's batch shape `Tensor`.

  Returns:
    new_param: A `Tensor`, batch-sliced according to slices.
  """
  # Extend param shape with ones on the left to match dist_batch_shape.
  param_shape = ps.shape(param)
  insert_ones = ps.ones(
      [ps.size(dist_batch_shape) + param_event_ndims - ps.rank(param)],
      dtype=param_shape.dtype)
  new_param_shape = ps.concat([insert_ones, param_shape], axis=0)
  full_batch_param = tf.reshape(param, new_param_shape)
  param_slices = []
  # We separately track the batch axis from the parameter axis because we
  # want them to align for positive indexing, and be offset by
  # param_event_ndims for negative indexing.
  param_dim_idx = 0
  batch_dim_idx = 0
  for slc in slices:
    if slc is tf.newaxis:
      param_slices.append(slc)
      continue
    if slc is Ellipsis:
      if batch_dim_idx < 0:
        raise ValueError('Found multiple `...` in slices {}'.format(slices))
      param_slices.append(slc)
      # Switch over to negative indexing for the broadcast check.
      num_remaining_non_newaxis_slices = sum(
          [s is not tf.newaxis
           for s in slices[slices.index(Ellipsis) + 1:]])
      batch_dim_idx = -num_remaining_non_newaxis_slices
      param_dim_idx = batch_dim_idx - param_event_ndims
      continue
    # Find the batch dimension sizes for both parameter and distribution.
    param_dim_size = new_param_shape[param_dim_idx]
    batch_dim_size = dist_batch_shape[batch_dim_idx]
    is_broadcast = batch_dim_size > param_dim_size
    # Slices are denoted by start:stop:step.
    if isinstance(slc, slice):
      start, stop, step = slc.start, slc.stop, slc.step
      if start is not None:
        start = ps.where(is_broadcast, 0, start)
      if stop is not None:
        stop = ps.where(is_broadcast, 1, stop)
      if step is not None:
        step = ps.where(is_broadcast, 1, step)
      param_slices.append(slice(start, stop, step))
    else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
      param_slices.append(ps.where(is_broadcast, 0, slc))
    param_dim_idx += 1
    batch_dim_idx += 1
  param_slices.extend([ALL_SLICE] * param_event_ndims)
  return full_batch_param.__getitem__(tuple(param_slices))
def expand_right_dims(x, broadcast=False):
  """Expand `x` so it can broadcast with tensors of output shape."""
  # Note: `y_ref_shape_left` and `y_ref_shape_right` are captured from the
  # enclosing scope.
  expanded_shape_left = ps.broadcast_shape(
      ps.shape(x)[:-1],
      ps.ones([ps.size(y_ref_shape_left)], dtype=tf.int32))
  expanded_shape = ps.concat(
      (expanded_shape_left, ps.shape(x)[-1:],
       ps.ones([ps.size(y_ref_shape_right)], dtype=tf.int32)),
      axis=0)
  x_expanded = tf.reshape(x, expanded_shape)
  if broadcast:
    broadcast_shape_left = ps.broadcast_shape(
        ps.shape(x)[:-1], y_ref_shape_left)
    broadcast_shape = ps.concat(
        (broadcast_shape_left, ps.shape(x)[-1:], y_ref_shape_right),
        axis=0)
    x_expanded = _broadcast_with(x_expanded, broadcast_shape)
  return x_expanded
def __init__(
    self,
    input_size,
    output_size,
    # Weights
    init_kernel_fn=None,  # tfp.experimental.nn.initializers.glorot_uniform()
    init_bias_fn=None,    # tf.initializers.zeros()
    make_kernel_bias_fn=nn_util_lib.make_kernel_bias,
    dtype=tf.float32,
    batch_shape=(),
    # Misc
    activation_fn=None,
    name=None):
  """Constructs layer.

  Args:
    input_size: ...
    output_size: ...
    init_kernel_fn: ...
      Default value: `None` (i.e.,
      `tfp.experimental.nn.initializers.glorot_uniform()`).
    init_bias_fn: ...
      Default value: `None` (i.e., `tf.initializers.zeros()`).
    make_kernel_bias_fn: ...
      Default value: `tfp.experimental.nn.util.make_kernel_bias`.
    dtype: ...
      Default value: `tf.float32`.
    batch_shape: ...
      Default value: `()`.
    activation_fn: ...
      Default value: `None`.
    name: ...
      Default value: `None` (i.e., `'Affine'`).
  """
  batch_shape = tf.constant(
      [], dtype=tf.int32) if batch_shape is None else prefer_static.cast(
          prefer_static.reshape(batch_shape, shape=[-1]), tf.int32)
  batch_ndims = prefer_static.size(batch_shape)
  kernel_shape = prefer_static.concat(
      [batch_shape, [input_size, output_size]], axis=0)
  bias_shape = prefer_static.concat([batch_shape, [output_size]], axis=0)
  apply_kernel_fn = lambda x, k: tf.matmul(  # pylint: disable=g-long-lambda
      x[..., tf.newaxis, :], k)[..., 0, :]
  kernel, bias = make_kernel_bias_fn(
      kernel_shape, bias_shape,
      init_kernel_fn, init_bias_fn,
      batch_ndims, batch_ndims,
      dtype)
  self._make_kernel_bias_fn = make_kernel_bias_fn  # For tracking.
  super(Affine, self).__init__(
      kernel=kernel,
      bias=bias,
      apply_kernel_fn=apply_kernel_fn,
      activation_fn=activation_fn,
      dtype=dtype,
      name=name)
def _dense_to_sparse(self, student_ids, question_ids, dense_correct):
  test_y_idx = np.stack([student_ids, question_ids], axis=-1)
  # Need to tile the indices across the batch, for gather_nd.
  batch_shape = ps.shape(dense_correct)[:-2]
  broadcast_shape = ps.concat([ps.ones_like(batch_shape), test_y_idx.shape],
                              axis=-1)
  test_y_idx = tf.reshape(test_y_idx, broadcast_shape)
  test_y_idx = tf.tile(test_y_idx,
                       ps.concat([batch_shape, [1, 1]], axis=-1))
  return tf.gather_nd(dense_correct, test_y_idx,
                      batch_dims=ps.size(batch_shape))
def _compute_calibration_bin_statistics(num_bins,
                                        logits=None,
                                        labels_true=None,
                                        labels_predicted=None):
  """Compute binning statistics required for calibration measures.

  Args:
    num_bins: int, number of probability bins, e.g. 10.
    logits: Tensor, (n, nlabels), with logits for n instances and nlabels.
    labels_true: Tensor, (n,), with tf.int32 or tf.int64 elements containing
      ground truth class labels in the range [0, nlabels).
    labels_predicted: Tensor, (n,), with tf.int32 or tf.int64 elements
      containing decisions of the predictive system. If `None`, we will use
      the argmax decision using the `logits`.

  Returns:
    bz: Tensor, shape (2, num_bins), tf.int32, counts of incorrect (row 0)
      and correct (row 1) predictions in each of the `num_bins` probability
      bins.
    pmean_observed: Tensor, shape (num_bins,), tf.float32, the mean
      predictive probabilities in each probability bin.
  """
  if labels_predicted is None:
    # If no labels are provided, we take the label with the maximum
    # probability decision. This corresponds to the optimal expected minimum
    # loss decision under 0/1 loss.
    pred_y = tf.argmax(logits, axis=1, output_type=labels_true.dtype)
  else:
    pred_y = labels_predicted

  correct = tf.cast(tf.equal(pred_y, labels_true), tf.int32)

  # Collect predicted probabilities of decisions.
  pred = tf.nn.softmax(logits, axis=1)
  prob_y = tf.gather(pred, pred_y[:, tf.newaxis],
                     batch_dims=1)  # p(pred_y | x)
  prob_y = tf.reshape(prob_y, (ps.size(prob_y),))

  # Compute b/z histogram statistics:
  # bz[0, bin] contains counts of incorrect predictions in the bin.
  # bz[1, bin] contains counts of correct predictions in the bin.
  bins = tf.histogram_fixed_width_bins(prob_y, [0.0, 1.0], nbins=num_bins)
  event_bin_counts = tf.math.bincount(correct * num_bins + bins,
                                      minlength=2 * num_bins,
                                      maxlength=2 * num_bins)
  event_bin_counts = tf.reshape(event_bin_counts, (2, num_bins))

  # Compute mean predicted probability value in each of the `num_bins` bins.
  pmean_observed = tf.math.unsorted_segment_sum(prob_y, bins, num_bins)
  tiny = np.finfo(dtype_util.as_numpy_dtype(logits.dtype)).tiny
  pmean_observed = pmean_observed / (
      tf.cast(tf.reduce_sum(event_bin_counts, axis=0), logits.dtype) + tiny)

  return event_bin_counts, pmean_observed
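# Sketch: assembling expected calibration error (ECE) from these statistics.
# This follows the textbook ECE definition, not necessarily the library's
# downstream code; `logits` and `labels` here are placeholder data.
logits = tf.random.normal([100, 4])
labels = tf.random.uniform([100], maxval=4, dtype=tf.int64)
counts, pmean = _compute_calibration_bin_statistics(
    10, logits=logits, labels_true=labels)
totals = tf.cast(tf.reduce_sum(counts, axis=0), tf.float32)  # Per-bin totals.
acc = tf.cast(counts[1], tf.float32) / (totals + 1e-12)      # Per-bin accuracy.
ece = tf.reduce_sum((totals / tf.reduce_sum(totals)) * tf.abs(acc - pmean))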
def _update_principal_component_ema(
    self,
    reduce_axes,
    state,
    step,
    principal_component_ema_points,
    ema_principal_component,
):
  # This is a batched version of Oja's algorithm. For the learning rate
  # step, we use Welford's algorithm where the number of points is clamped
  # to a function that grows slower than N.
  event_axes = tf.nest.map_structure(
      lambda x: ps.range(ps.size(reduce_axes), ps.rank(x)) - ps.rank(x),
      state)
  if self.experimental_shard_axis_names is None:
    shard_axis_names = tf.nest.map_structure(lambda _: None, state)
  else:
    shard_axis_names = self.experimental_shard_axis_names

  def _center_part(x):
    return x - distribute_lib.reduce_mean(
        x, reduce_axes, self.experimental_reduce_chain_axis_names)

  state_dot_p = _dot_product(
      tf.nest.map_structure(_center_part, state),
      ema_principal_component,
      event_axes, shard_axis_names)

  def _weighted_sum_part(x):
    return distribute_lib.reduce_sum(
        bu.left_justified_expand_dims_like(state_dot_p, x) * x,
        reduce_axes, self.experimental_reduce_chain_axis_names)

  new_principal_component = _normalize(
      tf.nest.map_structure(_weighted_sum_part, state),
      event_axes, shard_axis_names)

  def _ema_part(old_x, new_x):
    weight = 1. / (
        tf.cast(principal_component_ema_points, old_x.dtype) + 1.)
    return old_x + (new_x - old_x) * weight

  new_principal_component_ema_points = tf.minimum(
      principal_component_ema_points + 1,
      tf.maximum(1, step // self.principal_component_ema_factor))
  new_ema_principal_component = _normalize(
      tf.nest.map_structure(_ema_part, ema_principal_component,
                            new_principal_component),
      event_axes, shard_axis_names)
  return tf.nest.map_structure(
      lambda x, y: tf.where(step < self.num_adaptation_steps, x, y),
      (new_principal_component_ema_points, new_ema_principal_component),
      (principal_component_ema_points, ema_principal_component),
  )
def _get_reinterpreted_batch_ndims(self,
                                   distribution_batch_shape_tensor=None):
  if self._static_reinterpreted_batch_ndims is not None:
    return self._static_reinterpreted_batch_ndims
  if self._reinterpreted_batch_ndims is not None:
    return tf.convert_to_tensor(self._reinterpreted_batch_ndims)

  if distribution_batch_shape_tensor is None:
    distribution_batch_shape_tensor = self.distribution.batch_shape_tensor()
  return ps.cast(
      ps.maximum(0, ps.size(distribution_batch_shape_tensor) - 1),
      np.int32)
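# The default computed here ("all but the leftmost batch dimension") is what
# `tfd.Independent` uses when `reinterpreted_batch_ndims` is unspecified.
# Illustration, assuming standard TFP imports:
import tensorflow_probability as tfp
tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([4, 3]), scale=1.)  # batch_shape == [4, 3].
ind = tfd.Independent(base)  # reinterpreted_batch_ndims defaults to 2 - 1 = 1.
# ind.batch_shape == [4]; ind.event_shape == [3].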
def even_odd_swaps(num_replica, batch_shape=(), step_count=None, seed=None):
  """Make deterministic even_odd one time swaps."""
  if step_count is None:
    raise ValueError('`step_count` must be supplied. Found `None`.')
  del seed  # Unused for this function.
  # Note: `name` and `swap_frequency` are captured from the enclosing scope.
  with tf.name_scope(name or 'even_odd_swaps'):
    # Period is 1 / frequency, and we want period = Inf if frequency = 0.
    # safe_swap_period is the correct swap period in case swap_frequency > 0.
    # If swap_frequency == 0, safe_swap_period is set to 1 (to avoid integer
    # div by zero below). We will hard-set this case to "null swap."
    swap_freq = tf.convert_to_tensor(swap_frequency, name='swap_frequency')
    safe_swap_period = tf.cast(
        tf.where(swap_freq > 0,
                 tf.math.ceil(tf.math.reciprocal_no_nan(swap_freq)),
                 1),
        # Although period = 1 / frequency may have roundoff error, and
        # result in a period different than what the user intended, the user
        # will end up with a single integer period, and thus well defined
        # deterministic swaps.
        tf.int32,
    )
    # u selects parity. E.g.,
    #  u==False ==> [1, 0, 3, 2, 4] even parity swaps
    #  u==True  ==> [0, 2, 1, 4, 3] odd parity swaps
    # If there are only 2 replicas, then the "True" swaps are null
    # swaps...which would contradict the user provided `swap_frequency`.
    # So special case num_replica==2, forcing u==False in this case.
    u_shape = ps.concat(
        (ps.ones(1, dtype=tf.int32), ps.cast(batch_shape, tf.int32)), axis=0)
    u = tf.fill(u_shape,
                tf.cast((step_count // safe_swap_period) % 2, tf.bool))
    u = tf.where(num_replica > 2, u, False)

    x = bu.left_justified_expand_dims_to(
        tf.range(num_replica, dtype=tf.int64),
        rank=ps.size(u_shape))
    y = tf.where(tf.equal(x % 2, tf.cast(u, dtype=tf.int64)), x + 1, x - 1)
    y = tf.clip_by_value(y, 0, num_replica - 1)
    # TODO(b/142689785): Consider using tf.cond and returning an empty list
    # then in REMC consider using a tf.cond for short-circuiting.
    return tf.where(
        (tf.cast(step_count % safe_swap_period, tf.bool)
         | tf.math.equal(swap_freq, 0)),
        x,  # Don't swap.
        y,  # Swap.
    )
def _canonicalize_steps_to_trace(step_indices_to_trace, num_timesteps):
  """Canonicalizes `3` -> `[3]`, `[-2, -1]` -> `[N - 2, N - 1]`, etc."""
  step_indices_to_trace = tf.convert_to_tensor(
      step_indices_to_trace,
      dtype_hint=tf.int32)  # Warning: breaks gradients.
  traced_steps_have_rank_zero = ps.equal(
      ps.rank_from_shape(ps.shape(step_indices_to_trace)), 0)
  # Canonicalize negative step indices as positive.
  step_indices_to_trace = ps.where(step_indices_to_trace < 0,
                                   num_timesteps + step_indices_to_trace,
                                   step_indices_to_trace)
  # Canonicalize scalars as length-one vectors.
  return (ps.reshape(step_indices_to_trace,
                     [ps.size(step_indices_to_trace)]),
          traced_steps_have_rank_zero)
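# Illustrative canonicalizations with `num_timesteps=10`:
steps, was_scalar = _canonicalize_steps_to_trace(-1, num_timesteps=10)
# steps == [9]; was_scalar == True (so callers can squeeze the result back).
steps, _ = _canonicalize_steps_to_trace([0, -2], num_timesteps=10)
# steps == [0, 8].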
def _build_sub_tree(self,
                    direction,
                    log_slice_sample,
                    nsteps,
                    initial_state,
                    continue_tree,
                    trace_arrays,
                    name=None):
  with tf.name_scope('build_sub_tree'):
    batch_size = prefer_static.size(log_slice_sample)
    initial_state_candidate = TreeDoublingStateCandidate(
        state=initial_state.state,
        target=initial_state.target,
        target_grad_parts=initial_state.target_grad_parts,
        # We never want to select the initial state.
        weight=tf.zeros(batch_size, dtype=TREE_COUNT_DTYPE))
    [
        leapfrogs_computed,
        final_state,
        candidate_tree_state,
        final_continue_tree,
        trace_arrays,
    ] = tf.while_loop(
        cond=lambda iter_, state, state_c, continue_tree, trace_arrays: (  # pylint: disable=g-long-lambda
            (iter_ < nsteps) & tf.reduce_any(continue_tree)),
        body=lambda iter_, state, state_c, continue_tree, trace_arrays: (  # pylint: disable=g-long-lambda
            self._loop_build_sub_tree(
                direction, log_slice_sample, iter_, state, state_c,
                continue_tree, trace_arrays)),
        loop_vars=(
            tf.zeros([], dtype=tf.int32, name='iter'),
            initial_state,
            initial_state_candidate,
            continue_tree,
            trace_arrays,
        ),
        parallel_iterations=TF_WHILE_PARALLEL_ITERATIONS,
    )

    return (
        candidate_tree_state,
        final_state,
        final_continue_tree,
        leapfrogs_computed,
    )