def _log_average_probs_maybe_check_args(sample_axis, event_axis, validate_args):
  """Assertions for `log_average_probs`."""
  assertions = []
  msg = 'Arguments `sample_axis` and `event_axis` must be distinct.'
  sample_setdiff = ps.setdiff1d(sample_axis, event_axis)
  if ps.is_numpy(sample_setdiff):
    if not np.array_equal(sample_setdiff, tf.get_static_value(sample_axis)):
      raise ValueError(msg)
  elif validate_args:
    assertions.append(
        _assert_array_equal(sample_setdiff, sample_axis,
                            message=msg, name='sample_setdiff_rank_check'))
  event_setdiff = ps.setdiff1d(event_axis, sample_axis)
  if ps.is_numpy(event_setdiff):
    if not np.array_equal(event_setdiff, tf.get_static_value(event_axis)):
      raise ValueError(msg)
  elif validate_args:
    assertions.append(
        _assert_array_equal(event_setdiff, event_axis,
                            message=msg, name='event_setdiff_rank_check'))
  return assertions
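# A minimal sketch of the static branch above, assuming static (NumPy-backed)
# axis arguments. It is illustrated with `np.setdiff1d`, which may differ from
# `ps.setdiff1d` in ordering details, but shows the same distinctness idea:
# when the two axis sets overlap, the set difference no longer equals the
# original `sample_axis`, so the check raises.
import numpy as np

sample_axis = np.array([0, 1])
event_axis = np.array([1])  # Overlaps with sample_axis.
sample_setdiff = np.setdiff1d(sample_axis, event_axis)  # ==> [0]
assert not np.array_equal(sample_setdiff, sample_axis)  # The check would raise.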
def test_dynamic(self):
  if tf.executing_eagerly():
    return
  x = tf1.placeholder_with_default(np.arange(5), shape=None)
  self.assertAllEqual([0, 3, 4], self.evaluate(ps.setdiff1d(x, [1, 2])))
  x = tf1.placeholder_with_default(np.array([], np.int32), shape=None)
  self.assertAllEqual([], self.evaluate(ps.setdiff1d(x, [1, 2])))
  self.assertAllEqual([1, 2], self.evaluate(ps.setdiff1d([1, 2], x)))
def test_static(self):
  self.assertAllEqual(
      [0, 3, 4], prefer_static.setdiff1d(np.arange(5), [1, 2]))
  self.assertAllEqual(
      [], prefer_static.setdiff1d([], [1, 2]))
  self.assertAllEqual(
      [1, 2], prefer_static.setdiff1d([1, 2], []))
def _squeeze(x, axis):
  """A version of squeeze that works with dynamic axis."""
  x = tf.convert_to_tensor(x, name='x')
  if axis is None:
    return tf.squeeze(x, axis=None)
  axis = ps.convert_to_shape_tensor(axis, name='axis', dtype=tf.int32)
  axis = _make_list_or_1d_tensor(axis)  # Ensure at least 1d.
  keep_axis = ps.setdiff1d(ps.range(0, ps.rank(x)), axis)
  return tf.reshape(x, ps.gather(ps.shape(x), keep_axis))
def _squeeze(x, axis):
  """A version of squeeze that works with dynamic axis."""
  x = tf.convert_to_tensor(x, name='x')
  if axis is None:
    return tf.squeeze(x, axis=None)
  axis = ps.convert_to_shape_tensor(axis, name='axis', dtype=tf.int32)
  axis = axis + ps.zeros([1], dtype=axis.dtype)  # Make axis at least 1d.
  keep_axis = ps.setdiff1d(ps.range(0, ps.rank(x)), axis)
  return tf.reshape(x, ps.gather(ps.shape(x), keep_axis))
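# A minimal usage sketch for `_squeeze` above, assuming `tf` and the
# prefer_static alias `ps` are imported as in the surrounding module.
# `tf.squeeze` needs Python ints for `axis`, but `_squeeze` also accepts a
# Tensor-valued axis: it gathers the complementary (kept) dims via
# `ps.setdiff1d` and reshapes to that shape.
x = tf.zeros([2, 1, 3, 1])
axis = tf.constant([1, 3])  # Tensor-valued axis; tf.squeeze would reject it.
y = _squeeze(x, axis)       # ==> shape [2, 3]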
def _log_average_probs_process_args(
    logits, validate_args, sample_axis, event_axis):
  """Processes args for `log_average_probs`."""
  rank = ps.rank(logits)
  if sample_axis is None or validate_args:
    event_axis = ps.reshape(
        ps.non_negative_axis(event_axis, rank), shape=[-1])
  if sample_axis is None:
    sample_axis = ps.setdiff1d(ps.range(rank), event_axis)
  elif validate_args:
    sample_axis = ps.reshape(
        ps.non_negative_axis(sample_axis, rank), shape=[-1])
  return sample_axis, event_axis
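# A brief sketch of the defaulting behavior above, assuming static shapes so
# that `ps` falls back to NumPy: with `sample_axis=None`, the sample axes
# become the complement of `event_axis` over the rank of `logits`.
logits = tf.zeros([7, 2, 3])  # Hypothetical input.
sample_axis, event_axis = _log_average_probs_process_args(
    logits, validate_args=False, sample_axis=None, event_axis=[-1])
# event_axis  ==> [2]     (non-negative form of -1)
# sample_axis ==> [0, 1]  (everything that is not an event axis)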
def covariance(x,
               y=None,
               sample_axis=0,
               event_axis=-1,
               keepdims=False,
               name=None):
  """Sample covariance between observations indexed by `event_axis`.

  Given `N` samples of scalar random variables `X` and `Y`, covariance may be
  estimated as

  ```none
  Cov[X, Y] := N^{-1} sum_{n=1}^N (X_n - Xbar) Conj{(Y_n - Ybar)}
  Xbar := N^{-1} sum_{n=1}^N X_n
  Ybar := N^{-1} sum_{n=1}^N Y_n
  ```

  For vector-variate random variables `X = (X1, ..., Xd)`, `Y = (Y1, ..., Yd)`,
  one is often interested in the covariance matrix, `C_{ij} := Cov[Xi, Yj]`.

  ```python
  x = tf.random.normal(shape=(100, 2, 3))
  y = tf.random.normal(shape=(100, 2, 3))

  # cov[i, j] is the sample covariance between x[:, i, j] and y[:, i, j].
  cov = tfp.stats.covariance(x, y, sample_axis=0, event_axis=None)

  # cov_matrix[i, m, n] is the sample covariance of x[:, i, m] and y[:, i, n]
  cov_matrix = tfp.stats.covariance(x, y, sample_axis=0, event_axis=-1)
  ```

  Notice we divide by `N`, which does not create `NaN` when `N = 1`, but is
  slightly biased.

  Args:
    x: A numeric `Tensor` holding samples.
    y: Optional `Tensor` with same `dtype` and `shape` as `x`.
      Default value: `None` (`y` is effectively set to `x`).
    sample_axis: Scalar or vector `Tensor` designating axis holding samples, or
      `None` (meaning all axes hold samples).
      Default value: `0` (leftmost dimension).
    event_axis: Scalar or vector `Tensor`, or `None` (scalar events). Axis
      indexing random events, whose covariance we are interested in. If a
      vector, entries must form a contiguous block of dims. `sample_axis` and
      `event_axis` should not intersect.
      Default value: `-1` (rightmost axis holds events).
    keepdims: Boolean. Whether to keep the sample axis as singletons.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., `'covariance'`).

  Returns:
    cov: A `Tensor` of same `dtype` as the `x`, and rank equal to
      `rank(x) - len(sample_axis) + len(event_axis)`.

  Raises:
    AssertionError: If `x` and `y` are found to have different shape.
    ValueError: If `sample_axis` and `event_axis` are found to overlap.
    ValueError: If `event_axis` is found to not be contiguous.
  """
  with tf.name_scope(name or 'covariance'):
    x = tf.convert_to_tensor(x, name='x')
    # Covariance *only* uses the centered versions of x (and y).
    x = x - tf.reduce_mean(x, axis=sample_axis, keepdims=True)

    if y is None:
      y = x
    else:
      y = tf.convert_to_tensor(y, name='y', dtype=x.dtype)
      # If x and y have different shape, sample_axis and event_axis will
      # likely be wrong for one of them!
      tensorshape_util.assert_is_compatible_with(x.shape, y.shape)
      y = y - tf.reduce_mean(y, axis=sample_axis, keepdims=True)

    if event_axis is None:
      return tf.reduce_mean(
          x * tf.math.conj(y), axis=sample_axis, keepdims=keepdims)

    if sample_axis is None:
      raise ValueError(
          'sample_axis was None, which means all axis hold events, and this '
          'overlaps with event_axis ({})'.format(event_axis))

    event_axis = _make_positive_axis(event_axis, ps.rank(x))
    sample_axis = _make_positive_axis(sample_axis, ps.rank(x))

    # If we get lucky and axis is statically defined, we can do some checks.
    if _is_list_like(event_axis) and _is_list_like(sample_axis):
      event_axis = tuple(map(int, event_axis))
      sample_axis = tuple(map(int, sample_axis))
      if set(event_axis).intersection(sample_axis):
        raise ValueError(
            'sample_axis ({}) and event_axis ({}) overlapped'.format(
                sample_axis, event_axis))
      if (np.diff(np.array(sorted(event_axis))) > 1).any():
        raise ValueError(
            'event_axis must be contiguous. Found: {}'.format(event_axis))
      batch_axis = list(
          sorted(
              set(range(tensorshape_util.rank(
                  x.shape))).difference(sample_axis + event_axis)))
    else:
      batch_axis = ps.setdiff1d(
          ps.range(0, ps.rank(x)), ps.concat((sample_axis, event_axis), 0))

    event_axis = ps.cast(event_axis, dtype=tf.int32)
    sample_axis = ps.cast(sample_axis, dtype=tf.int32)
    batch_axis = ps.cast(batch_axis, dtype=tf.int32)

    # Permute x/y until shape = B + E + S
    perm_for_xy = ps.concat((batch_axis, event_axis, sample_axis), 0)
    x_permed = tf.transpose(a=x, perm=perm_for_xy)
    y_permed = tf.transpose(a=y, perm=perm_for_xy)

    batch_ndims = ps.size(batch_axis)
    batch_shape = ps.shape(x_permed)[:batch_ndims]
    event_ndims = ps.size(event_axis)
    event_shape = ps.shape(x_permed)[batch_ndims:batch_ndims + event_ndims]
    sample_shape = ps.shape(x_permed)[batch_ndims + event_ndims:]
    sample_ndims = ps.size(sample_shape)
    n_samples = ps.reduce_prod(sample_shape)
    n_events = ps.reduce_prod(event_shape)

    # Flatten sample_axis into one long dim.
    x_permed_flat = tf.reshape(
        x_permed, ps.concat((batch_shape, event_shape, [n_samples]), 0))
    y_permed_flat = tf.reshape(
        y_permed, ps.concat((batch_shape, event_shape, [n_samples]), 0))
    # Do the same for event_axis.
    x_permed_flat = tf.reshape(
        x_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0))
    y_permed_flat = tf.reshape(
        y_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0))

    # After matmul, cov.shape = batch_shape + [n_events, n_events]
    cov = tf.matmul(
        x_permed_flat, y_permed_flat, adjoint_b=True) / ps.cast(
            n_samples, x.dtype)

    # Insert some singletons to make
    # cov.shape = batch_shape + event_shape**2 + [1,...,1]
    # This is just like x_permed.shape, except the sample_axis is all 1's, and
    # the [n_events] became event_shape**2.
    cov = tf.reshape(
        cov,
        ps.concat(
            (
                batch_shape,
                # event_shape**2 used here because it is the same length as
                # event_shape, and has the same number of elements as one
                # batch of covariance.
                event_shape**2,
                ps.ones([sample_ndims], tf.int32)),
            0))
    # Permuting by the argsort inverts the permutation, making
    # cov.shape have ones in the position where there were samples, and
    # [n_events * n_events] in the event position.
    cov = tf.transpose(a=cov, perm=ps.invert_permutation(perm_for_xy))

    # Now expand event_shape**2 into event_shape + event_shape.
    # We here use (for the first time) the fact that we require event_axis to
    # be contiguous.
    e_start = event_axis[0]
    e_len = 1 + event_axis[-1] - event_axis[0]
    cov = tf.reshape(
        cov,
        ps.concat((ps.shape(cov)[:e_start], event_shape, event_shape,
                   ps.shape(cov)[e_start + e_len:]), 0))

    # tf.squeeze requires python ints for axis, not Tensor. This is enough to
    # require our axis args to be constants.
    if not keepdims:
      squeeze_axis = ps.where(sample_axis < e_start, sample_axis,
                              sample_axis + e_len)
      cov = _squeeze(cov, axis=squeeze_axis)

    return cov
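# A brief shape check for `covariance`, assuming the public wrapper
# `tfp.stats.covariance` (the name used in the docstring example above). With
# `sample_axis=0` and `event_axis=-1`, the sample dim is reduced away and the
# event dim appears twice in the output.
x = tf.random.normal(shape=(100, 2, 3))
cov_matrix = tfp.stats.covariance(x, sample_axis=0, event_axis=-1)
# cov_matrix.shape ==> [2, 3, 3]

cov_scalar = tfp.stats.covariance(x, sample_axis=0, event_axis=None)
# cov_scalar.shape ==> [2, 3]  (elementwise variances, since y defaults to x)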
def _log_loosum_exp_impl(logx, axis, keepdims, compute_mean):
  """Implementation for `*loosum*` functions."""
  with tf.name_scope('log_loosum_exp_impl'):
    logx = tf.convert_to_tensor(logx, name='logx')
    dtype = dtype_util.as_numpy_dtype(logx.dtype)

    if axis is not None:
      x = np.array(axis)
      axis = (tf.convert_to_tensor(axis, name='axis', dtype_hint=tf.int32)
              if x.dtype is np.object else x.astype(np.int32))

    log_sum_x = tf.reduce_logsumexp(logx, axis=axis, keepdims=True)

    # Later we'll want to compute the mean from a sum so we calculate the
    # number of reduced elements, n.
    n = prefer_static.size(logx) // prefer_static.size(log_sum_x)
    n = prefer_static.cast(n, dtype)

    # log_loosum_x[i] =
    # = logsumexp(logx[j] : j != i)
    # = log( exp(logsumexp(logx)) - exp(logx[i]) )
    # = log( exp(logsumexp(logx - logx[i])) exp(logx[i])  - exp(logx[i]))
    # = logx[i] + log(exp(logsumexp(logx - logx[i])) - 1)
    # = logx[i] + log(exp(logsumexp(logx) - logx[i]) - 1)
    # = logx[i] + softplus_inverse(logsumexp(logx) - logx[i])
    d = log_sum_x - logx
    # We use `d != 0` rather than `d > 0.` because `d < 0.` should never
    # happen; if it does we want to complain loudly (which `softplus_inverse`
    # will).
    d_ok = tf.not_equal(d, 0.)
    safe_d = tf.where(d_ok, d, 1.)
    d_ok_result = logx + softplus_inverse(safe_d)

    neg_inf = tf.constant(-np.inf, dtype=dtype)

    # When not(d_ok) and is_positive_and_largest then we manually compute the
    # log_loosum_x. (We can efficiently do this for any one point but not all,
    # hence we still need the above calculation.) This is good because when
    # this condition is met, we cannot use the above calculation; it's -inf.
    # We now compute the log-leave-out-max-sum, replicate it to every point
    # and make sure to select it only when we need to.
    max_logx = tf.reduce_max(logx, axis=axis, keepdims=True)
    is_positive_and_largest = (logx > 0.) & tf.equal(logx, max_logx)
    log_lomsum_x = tf.reduce_logsumexp(
        tf.where(is_positive_and_largest, neg_inf, logx),
        axis=axis,
        keepdims=True)
    d_not_ok_result = tf.where(is_positive_and_largest, log_lomsum_x, neg_inf)

    log_loosum_x = tf.where(d_ok, d_ok_result, d_not_ok_result)

    # We now squeeze log_sum_x so as if we used `keepdims=False`.
    # TODO(b/136176077): These mental gymnastics could all be replaced with
    # `tf.squeeze(log_sum_x, axis)` if tf.squeeze supported Tensor valued
    # `axis` arguments.
    if not keepdims:
      if axis is None:
        keepdims = np.array([], dtype=np.int32)
      else:
        rank = prefer_static.rank(logx)
        keepdims = prefer_static.setdiff1d(
            prefer_static.range(rank),
            prefer_static.non_negative_axis(axis, rank))
      squeeze_shape = tf.gather(prefer_static.shape(logx), indices=keepdims)
      log_sum_x = tf.reshape(log_sum_x, shape=squeeze_shape)
      if prefer_static.is_numpy(keepdims):
        tensorshape_util.set_shape(log_sum_x, np.array(logx.shape)[keepdims])

    # Set static shapes just in case we lost them.
    tensorshape_util.set_shape(n, [])
    tensorshape_util.set_shape(log_loosum_x, logx.shape)

    if not compute_mean:
      return log_loosum_x, log_sum_x, n

    log_nm1 = prefer_static.log(max(1., n - 1.))
    log_n = prefer_static.log(n)
    return log_loosum_x - log_nm1, log_sum_x - log_n, n
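# A small numeric check of the leave-one-out identity derived in the comments
# above, using only NumPy. Since softplus_inverse(y) = log(exp(y) - 1),
#   logsumexp(logx[j] : j != i)
#       == logx[i] + softplus_inverse(logsumexp(logx) - logx[i]).
import numpy as np

logx = np.array([-1.0, 0.5, 2.0])
lse = np.log(np.sum(np.exp(logx)))
for i in range(len(logx)):
  loo_direct = np.log(np.sum(np.exp(np.delete(logx, i))))
  loo_formula = logx[i] + np.log(np.expm1(lse - logx[i]))
  np.testing.assert_allclose(loo_direct, loo_formula)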
def _parameter_control_dependencies(self, is_init):
  assertions = []

  axis = None
  paddings = None

  if is_init != tensor_util.is_ref(self.axis):
    # First we check the shape of the axis argument.
    msg = 'Argument `axis` must be scalar or vector.'
    if tensorshape_util.rank(self.axis.shape) is not None:
      if tensorshape_util.rank(self.axis.shape) > 1:
        raise ValueError(msg)
    elif self.validate_args:
      if axis is None:
        axis = tf.convert_to_tensor(self.axis)
      assertions.append(assert_util.assert_rank_at_most(
          axis, 1, message=msg))
    # Next we check the values of the axis argument.
    axis_ = tf.get_static_value(self.axis)
    msg = 'Argument `axis` must be negative.'
    if axis_ is not None:
      if np.any(axis_ > -1):
        raise ValueError(msg)
    elif self.validate_args:
      if axis is None:
        axis = tf.convert_to_tensor(self.axis)
      assertions.append(assert_util.assert_less(axis, 0, message=msg))
    msg = 'Argument `axis` elements must be unique.'
    if axis_ is not None:
      if len(np.array(axis_).reshape(-1)) != len(np.unique(axis_)):
        raise ValueError(msg)
    elif self.validate_args:
      if axis is None:
        axis = tf.convert_to_tensor(self.axis)
      assertions.append(assert_util.assert_equal(
          prefer_static.size0(axis),
          prefer_static.size0(prefer_static.setdiff1d(axis)),
          message=msg))

  if is_init != tensor_util.is_ref(self.paddings):
    # First we check the shape of the paddings argument.
    msg = 'Argument `paddings` must be a vector of pairs.'
    if tensorshape_util.is_fully_defined(self.paddings.shape):
      shape = np.int32(self.paddings.shape)
      if len(shape) != 2 or shape[0] < 1 or shape[1] != 2:
        raise ValueError(msg)
    elif self.validate_args:
      if paddings is None:
        paddings = tf.convert_to_tensor(self.paddings)
      with tf.control_dependencies([
          assert_util.assert_equal(tf.rank(paddings), 2, message=msg)]):
        shape = tf.shape(paddings)
        assertions.extend([
            assert_util.assert_greater(shape[0], 0, message=msg),
            assert_util.assert_equal(shape[1], 2, message=msg),
        ])
    # Next we check the values of the paddings argument.
    paddings_ = tf.get_static_value(self.paddings)
    msg = 'Argument `paddings` must be non-negative.'
    if paddings_ is not None:
      if np.any(paddings_ < 0):
        raise ValueError(msg)
    elif self.validate_args:
      if paddings is None:
        paddings = tf.convert_to_tensor(self.paddings)
      assertions.append(assert_util.assert_greater(
          paddings, -1, message=msg))

  if is_init != (tensor_util.is_ref(self.axis) and
                 tensor_util.is_ref(self.paddings)):
    axis_ = tf.get_static_value(self.axis)
    if axis_ is None and axis is None:
      axis = tf.convert_to_tensor(self.axis)
    len_axis = prefer_static.size0(prefer_static.reshape(
        axis if axis_ is None else axis_, shape=-1))

    paddings_ = tf.get_static_value(self.paddings)
    if paddings_ is None and paddings is None:
      paddings = tf.convert_to_tensor(self.paddings)
    len_paddings = prefer_static.size0(
        paddings if paddings_ is None else paddings_)

    msg = ('Arguments `axis` and `paddings` must have the same number '
           'of elements.')
    if (prefer_static.is_numpy(len_axis) and
        prefer_static.is_numpy(len_paddings)):
      if len_axis != len_paddings:
        raise ValueError(msg + ' Saw: {}, {}.'.format(
            self.axis, self.paddings))
    elif self.validate_args:
      assertions.append(assert_util.assert_equal(
          len_axis, len_paddings, message=msg))

  return assertions
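# A hedged illustration of the constraints enforced above, assuming this is
# the parameter validation of a Pad-style bijector: `axis` must be a scalar or
# vector of unique, negative indices, and `paddings` must be a matching
# [len(axis), 2] matrix of non-negative (low, high) pad amounts.
axis = [-2, -1]              # Negative and unique.
paddings = [[1, 0], [0, 2]]  # One (low, high) pair per axis entry.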