def _batch_shape_tensor(self, temperature=None, logits=None):
  param = logits
  if param is None:
    param = self._logits if self._logits is not None else self._probs
  if temperature is None:
    temperature = self.temperature
  return ps.broadcast_shape(ps.shape(temperature), ps.shape(param)[:-1])
def _log_prob(self, x):
  scores = tf.convert_to_tensor(self.scores)
  event_size = self._event_size(scores)
  x = tf.cast(x, self.dtype)

  # Broadcast scores or x if need be.
  if (not tensorshape_util.is_fully_defined(x.shape) or
      not tensorshape_util.is_fully_defined(scores.shape) or
      x.shape != scores.shape):
    broadcast_shape = ps.broadcast_shape(ps.shape(scores), ps.shape(x))
    scores = tf.broadcast_to(scores, broadcast_shape)
    x = tf.broadcast_to(x, broadcast_shape)

  scores_shape = ps.shape(scores)[:-1]
  scores_2d = tf.reshape(scores, [-1, event_size])
  x_2d = tf.reshape(x, [-1, event_size])

  # Ensure that these are indices that we can use in a gather.
  if dtype_util.is_floating(x_2d.dtype):
    x_2d = tf.cast(x_2d, tf.int32)

  rearranged_scores = tf.gather(scores_2d, x_2d, batch_dims=1)
  normalization_terms = tf.cumsum(rearranged_scores, axis=-1, reverse=True)
  ret = tf.math.reduce_sum(
      tf.math.log(rearranged_scores / normalization_terms), axis=-1)
  # Reshape back to user-supplied batch and sample dims prior to 2D reshape.
  ret = tf.reshape(ret, scores_shape)
  return ret
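# Illustrative sketch (not from the library; values are hypothetical): the
# gather + reverse-cumsum above is the Plackett-Luce decomposition
# P(ranking) = prod_i s[x_i] / sum_{j >= i} s[x_j]. In plain NumPy for a
# single ranking:
import numpy as np

scores = np.array([4., 2., 1.])             # per-item scores
ranking = np.array([1, 0, 2])               # item 1 first, then 0, then 2
rearranged = scores[ranking]                # [2., 4., 1.]
denoms = np.cumsum(rearranged[::-1])[::-1]  # reverse cumsum: [7., 5., 1.]
log_prob = np.sum(np.log(rearranged / denoms))
assert np.isclose(log_prob, np.log(2. / 7.) + np.log(4. / 5.))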
def test_with_broadcast_batch_shape(self, bijector_fn, x_event_ndims=None):
  bijector = bijector_fn()
  if x_event_ndims is None:
    x_event_ndims = bijector.forward_min_event_ndims
  batch_shape = bijector.experimental_batch_shape(
      x_event_ndims=x_event_ndims)
  param_batch_shapes = batch_shape_lib.batch_shape_parts(
      bijector, bijector_x_event_ndims=x_event_ndims)
  new_batch_shape = [4, 2, 1, 1, 1]

  broadcast_bijector = bijector._broadcast_parameters_with_batch_shape(
      new_batch_shape, x_event_ndims)
  broadcast_batch_shape = broadcast_bijector.experimental_batch_shape_tensor(
      x_event_ndims=x_event_ndims)
  self.assertAllEqual(broadcast_batch_shape,
                      ps.broadcast_shape(batch_shape, new_batch_shape))

  # Check that all params have the expected batch shape.
  broadcast_param_batch_shapes = batch_shape_lib.batch_shape_parts(
      broadcast_bijector, bijector_x_event_ndims=x_event_ndims)

  def _maybe_broadcast_param_batch_shape(p, s):
    if isinstance(p, tfb.Invert) and not p.bijector._params_event_ndims():
      return s  # Can't broadcast a bijector that doesn't itself have params.
    return ps.broadcast_shape(s, new_batch_shape)

  expected_broadcast_param_batch_shapes = tf.nest.map_structure(
      _maybe_broadcast_param_batch_shape,
      {param: getattr(bijector, param) for param in param_batch_shapes},
      param_batch_shapes)
  self.assertAllEqualNested(broadcast_param_batch_shapes,
                            expected_broadcast_param_batch_shapes)
def _log_prob(self, x):
  log_nsphere_surface_area = (
      np.log(2.) + (self.dimension / 2) * np.log(np.pi) -
      tf.math.lgamma(tf.cast(self.dimension / 2., x.dtype)))
  batch_shape = ps.broadcast_shape(ps.shape(x)[:-1], self.batch_shape)
  return tf.fill(batch_shape, -log_nsphere_surface_area)
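# Quick sanity check (illustrative, not from the library): the constant above
# is log(2 * pi**(n/2) / Gamma(n/2)), the log surface area of the unit
# (n-1)-sphere in R^n. For n = 3 it should recover log(4 * pi):
import numpy as np
from scipy.special import gammaln

n = 3
log_area = np.log(2.) + (n / 2) * np.log(np.pi) - gammaln(n / 2)
assert np.isclose(log_area, np.log(4. * np.pi))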
def _batch_shape_tensor(self, logits_or_probs=None, total_count=None):
  if logits_or_probs is None:
    # Fall back to whichever parameterization was supplied at construction.
    logits_or_probs = self._logits if self._probs is None else self._probs
  total_count = self._total_count if total_count is None else total_count
  return prefer_static.broadcast_shape(
      prefer_static.shape(logits_or_probs),
      prefer_static.shape(total_count))
def _reduce_ldj_ratio(unreduced_ldj_ratio, p, q, input_shape, min_event_ndims,
                      event_ndims):
  """Reduces an LDJ ratio computed with event_ndims=min_event_ndims."""
  # pylint: disable=protected-access
  have_parameter_batch_shape = (p._parameter_batch_shape is not None and
                                q._parameter_batch_shape is not None)
  if have_parameter_batch_shape:
    parameter_batch_shape = ps.broadcast_shape(p._parameter_batch_shape,
                                               q._parameter_batch_shape)
  else:
    parameter_batch_shape = None

  reduce_shape, assertions = bijector_lib.ldj_reduction_shape(
      input_shape,
      event_ndims=event_ndims,
      min_event_ndims=min_event_ndims,
      parameter_batch_shape=parameter_batch_shape,
      allow_event_shape_broadcasting=not (p._parts_interact or
                                          q._parts_interact),
      validate_args=p.validate_args or q.validate_args)
  sum_fn = getattr(p, '_sum_fn', getattr(q, '_sum_fn', tf.reduce_sum))
  with tf.control_dependencies(assertions):
    return bijector_lib.reduce_jacobian_det_over_shape(
        unreduced_ldj_ratio, reduce_shape=reduce_shape, sum_fn=sum_fn)
def _entropy(self):
  scale = tf.broadcast_to(
      self.scale,
      ps.broadcast_shape(ps.shape(self.scale), ps.shape(self.loc)))
  euler_gamma = tf.constant(np.euler_gamma, self.dtype)
  return 1. + tf.math.log(scale) + euler_gamma * (1. + self.concentration)
def _sample_n(self, n, seed=None): """Gamma sampler. Rather than use `tf.random.gamma` (which is as of February 2020 implemented in C++ for CPU only), we implement our own gamma sampler in Python, using `batched_las_vegas_algorithm` as a substrate. This has the advantage that our sampler is XLA compilable. If sampling becomes a bottleneck on CPU, one way to gain speed would be to consider switching back to the C++ sampler. Args: n: Number of samples to draw. seed: (optional) The random seed. Returns: n samples from the gamma distribution. """ n = tf.convert_to_tensor(n, name='shape', dtype=tf.int32) alpha = tf.convert_to_tensor(self.concentration, name='alpha') beta = tf.convert_to_tensor(self.rate, name='beta') broadcast_shape = prefer_static.broadcast_shape( prefer_static.shape(alpha), prefer_static.shape(beta)) result_shape = tf.concat([[n], broadcast_shape], axis=0) return random_gamma(result_shape, alpha, beta, seed=seed)
def _batch_gather_with_broadcast(params, indices, axis):
  """Like batch_gather, but broadcasts to the left of axis."""
  # batch_gather assumes...
  #   params.shape =  [A1,...,AN, B1,...,BM]
  #   indices.shape = [A1,...,AN, C]
  # which gives output of shape
  #   [A1,...,AN, C, B1,...,BM]
  # Here we broadcast dims of each to the left of `axis` in params, and left
  # of the rightmost dim in indices, e.g. we can have
  #   params.shape =  [A1,...,AN, B1,...,BM]
  #   indices.shape = [a1,...,aN, C],
  # where ai broadcasts with Ai.

  # leading_bcast_shape is the broadcast of [A1,...,AN] and [a1,...,aN].
  leading_bcast_shape = ps.broadcast_shape(
      ps.shape_slice(params, np.s_[:axis]),
      ps.shape_slice(indices, np.s_[:-1]))
  params = _broadcast_with(
      params,
      ps.concat((leading_bcast_shape, ps.shape_slice(params, np.s_[axis:])),
                axis=0))
  indices = _broadcast_with(
      indices,
      ps.concat((leading_bcast_shape, ps.shape_slice(indices, np.s_[-1:])),
                axis=0))
  return tf.gather(params, indices,
                   batch_dims=tensorshape_util.rank(indices.shape) - 1)
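# Self-contained demonstration (assumed shapes, not from the library) of the
# batched gather the helper reduces to once both operands share their leading
# batch dims: row b of the output gathers from row b of `params`.
import tensorflow as tf

params = tf.reshape(tf.range(8), [2, 4])   # [A1=2, B1=4]
indices = tf.constant([[0, 3], [1, 1]])    # [A1=2, C=2]
out = tf.gather(params, indices, batch_dims=1)
print(out.numpy())  # [[0 3], [5 5]], shape (2, 2)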
def random_gamma_with_runtime(shape, concentration, rate=None, log_rate=None,
                              seed=None, log_space=False):
  """Returns both a sample and the id of the implementation-selected runtime."""
  # This method exists chiefly for testing purposes.
  dtype = dtype_util.common_dtype([concentration, rate, log_rate], tf.float32)
  concentration = tf.convert_to_tensor(concentration, dtype=dtype)
  shape = ps.convert_to_shape_tensor(shape, dtype_hint=tf.int32, name='shape')

  if rate is not None and log_rate is not None:
    raise ValueError('At most one of `rate` and `log_rate` may be specified.')
  if rate is not None:
    rate = tf.convert_to_tensor(rate, dtype=dtype)
  if log_rate is not None:
    log_rate = tf.convert_to_tensor(log_rate, dtype=dtype)

  total_shape = ps.concat(
      [shape,
       ps.broadcast_shape(ps.shape(concentration),
                          _shape_or_scalar(rate, log_rate))],
      axis=0)
  seed = samplers.sanitize_seed(seed, salt='random_gamma')
  return _random_gamma_gradient(
      total_shape, concentration, rate, log_rate, seed, log_space)
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []

  sample_shape = tf.concat(
      [self._batch_shape_tensor(), self._event_shape_tensor()], axis=0)

  low = None if self._low is None else tf.convert_to_tensor(self._low)
  high = None if self._high is None else tf.convert_to_tensor(self._high)

  assertions = []
  if self._low is not None and is_init != tensor_util.is_ref(self._low):
    low_shape = ps.shape(low)
    broadcast_shape = ps.broadcast_shape(sample_shape, low_shape)
    assertions.extend([
        distribution_util.assert_integer_form(
            low, message='`low` has non-integer components.'),
        assert_util.assert_equal(
            tf.reduce_prod(broadcast_shape),
            tf.reduce_prod(sample_shape),
            message=('Shape of `low` adds extra batch dimensions to '
                     'sample shape.'))
    ])
  if self._high is not None and is_init != tensor_util.is_ref(self._high):
    high_shape = ps.shape(high)
    broadcast_shape = ps.broadcast_shape(sample_shape, high_shape)
    assertions.extend([
        distribution_util.assert_integer_form(
            high, message='`high` has non-integer components.'),
        assert_util.assert_equal(
            tf.reduce_prod(broadcast_shape),
            tf.reduce_prod(sample_shape),
            message=('Shape of `high` adds extra batch dimensions to '
                     'sample shape.'))
    ])
  if (self._low is not None and self._high is not None and
      (is_init != (tensor_util.is_ref(self._low) or
                   tensor_util.is_ref(self._high)))):
    assertions.append(
        assert_util.assert_less(
            low, high, message='`low` must be strictly less than `high`.'))
  return assertions
def test_batching(self, input_batch_shape, kernel_batch_shape):
  input_shape = (12, 12, 2)
  filter_shape = (2, 2)
  channels_out = 3
  strides = (1, 1)
  dilations = (1, 1)
  padding = 'SAME'

  x, k = _make_input_and_kernel(
      self.make_input,
      input_batch_shape=input_batch_shape,
      input_shape=input_shape,
      kernel_batch_shape=kernel_batch_shape,
      filter_shape=filter_shape,
      channels_out=channels_out,
      dtype=self.dtype)

  conv_fn = tfn.util.make_convolution_fn(
      filter_shape, rank=2, strides=strides, padding=padding,
      dilations=dilations, validate_args=True)
  y_batched = conv_fn(x, k)

  broadcast_batch_shape = ps.broadcast_shape(input_batch_shape,
                                             kernel_batch_shape)
  broadcasted_input = tf.broadcast_to(
      x, shape=ps.concat([broadcast_batch_shape, input_shape], axis=0))
  broadcasted_kernel = tf.broadcast_to(
      k, shape=ps.concat([broadcast_batch_shape, ps.shape(k)[-2:]], axis=0))

  flat_y = tf.reshape(
      y_batched,
      shape=ps.pad(ps.shape(y_batched)[-3:], paddings=[[1, 0]],
                   constant_values=-1))
  flat_x = tf.reshape(
      broadcasted_input,
      shape=ps.pad(input_shape, paddings=[[1, 0]], constant_values=-1))
  flat_tf_kernel = tf.reshape(
      broadcasted_kernel,
      shape=ps.concat([(-1,), filter_shape, (input_shape[-1], channels_out)],
                      axis=0))

  y_expected = tf.vectorized_map(
      lambda args: tf.nn.conv2d(  # pylint: disable=g-long-lambda
          args[0][tf.newaxis], args[1], strides=strides, padding=padding),
      elems=(flat_x, flat_tf_kernel))

  [y_actual_, y_expected_] = self.evaluate(
      [flat_y, tf.squeeze(y_expected, axis=1)])
  self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0)
def _broadcast_params(self):
  lower_upper = tf.convert_to_tensor(self.lower_upper)
  perm = tf.convert_to_tensor(self.permutation)
  shape = ps.broadcast_shape(ps.shape(lower_upper)[:-1], ps.shape(perm))
  lower_upper = tf.broadcast_to(lower_upper, ps.concat([shape, shape[-1:]], 0))
  perm = tf.broadcast_to(perm, shape)
  return lower_upper, perm
def _sample_n(self, n, seed=None):
  broadcast_shape = prefer_static.broadcast_shape(
      prefer_static.shape(self.concentration),
      prefer_static.shape(self.scale))
  return 1. / gamma.random_gamma(
      sample_shape=tf.concat([[n], broadcast_shape], axis=0),
      alpha=self.concentration,
      beta=self.scale,
      seed=seed)
def fn(self, *args, **kwargs):
  val = getattr(self.distribution, fn_name)(*args, **kwargs)
  single_val_shape = self.batch_shape_tensor()
  if n_event_shapes:
    single_val_shape = ps.concat(
        [single_val_shape] + [self.event_shape_tensor()] * n_event_shapes,
        axis=0)
  return tf.broadcast_to(
      val, ps.broadcast_shape(ps.shape(val), single_val_shape))
def _cdf(self, x):
  loc = tf.convert_to_tensor(self.loc)
  concentration = tf.convert_to_tensor(self.concentration)
  batch_shape = ps.broadcast_shape(
      self._batch_shape_tensor(loc=loc, concentration=concentration),
      ps.shape(x))
  z = tf.broadcast_to(self._z(x, loc=loc), batch_shape)
  concentration = tf.broadcast_to(concentration, batch_shape)
  return von_mises_cdf(z, concentration)
def expand_right_dims(x, broadcast=False):
  """Expand x so it can bcast w/ tensors of output shape."""
  expanded_shape_left = ps.broadcast_shape(
      ps.shape(x)[:-1],
      ps.ones([ps.size(y_ref_shape_left)], dtype=tf.int32))
  expanded_shape = ps.concat(
      (expanded_shape_left, ps.shape(x)[-1:],
       ps.ones([ps.size(y_ref_shape_right)], dtype=tf.int32)),
      axis=0)
  x_expanded = tf.reshape(x, expanded_shape)

  if broadcast:
    broadcast_shape_left = ps.broadcast_shape(
        ps.shape(x)[:-1], y_ref_shape_left)
    broadcast_shape = ps.concat(
        (broadcast_shape_left, ps.shape(x)[-1:], y_ref_shape_right), axis=0)
    x_expanded = _broadcast_with(x_expanded, broadcast_shape)
  return x_expanded
def _cumulative_broadcast_dynamic(event_shape):
  broadcast_shapes = [
      ps.slice(s, begin=[0], size=[ps.size(s) - 1]) for s in event_shape]
  cumulative_shapes = [broadcast_shapes[0]]
  for shape in broadcast_shapes[1:]:
    out_shape = ps.broadcast_shape(shape, cumulative_shapes[-1])
    cumulative_shapes.append(out_shape)
  return [
      ps.concat([b, ps.slice(s, begin=[ps.size(s) - 1], size=[1])], axis=0)
      for b, s in zip(cumulative_shapes, event_shape)]
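# Worked example (plain NumPy, illustrative only): each event shape keeps its
# own rightmost dim, while the dims to its left are cumulatively broadcast
# across the parts, left to right.
import numpy as np

event_shape = [(2, 1, 3), (1, 5, 4), (2, 5, 6)]
cumulative = [event_shape[0][:-1]]
for s in event_shape[1:]:
  cumulative.append(np.broadcast_shapes(s[:-1], cumulative[-1]))
result = [c + s[-1:] for c, s in zip(cumulative, event_shape)]
assert result == [(2, 1, 3), (2, 5, 4), (2, 5, 6)]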
def _broadcast_with(tensor, shape):
  """Like broadcast_to, but allows singletons in the destination shape."""
  res = tf.broadcast_to(tensor, ps.broadcast_shape(ps.shape(tensor), shape))
  # We need this done explicitly because ps.broadcast_shape cannot deal with
  # partially specified shapes.
  tensorshape_util.set_shape(
      res,
      tf.broadcast_static_shape(tensor.shape,
                                tf.TensorShape(tf.get_static_value(shape))))
  return res
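# Illustrative contrast (assumed shapes, not from the library):
# `tf.broadcast_to` requires each destination dim to be able to absorb the
# source dim, so a singleton destination dim against a size-4 source dim
# fails. Broadcasting the two shapes first, as `_broadcast_with` does,
# accepts it.
import tensorflow as tf

x = tf.zeros([4])
# tf.broadcast_to(x, [3, 1]) would raise: dim 4 cannot broadcast into dim 1.
dest = tf.broadcast_static_shape(x.shape, tf.TensorShape([3, 1]))  # [3, 4]
y = tf.broadcast_to(x, dest.as_list())
print(y.shape)  # (3, 4)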
def test_dynamic(self):
  if tf.executing_eagerly():
    return
  shape = prefer_static.broadcast_shape(
      tf.convert_to_tensor([3, 2, 1]),
      tf.shape(
          tf1.placeholder_with_default(np.zeros((1, 5)), shape=(None, 5))))
  self.assertIsNone(tf.get_static_value(shape))
  self.assertAllEqual([3, 2, 5], self.evaluate(shape))
def _variance(self):
  if self._precision is None:
    precision = self._precision_factor.matmul(
        self._precision_factor, adjoint_arg=True)
  else:
    precision = self._precision
  variance = precision.inverse().diag_part()
  return tf.broadcast_to(
      variance,
      ps.broadcast_shape(ps.shape(variance), ps.shape(self.loc)))
def test_works_correctly(self, input_size, output_size, kernel_batch_shape,
                         input_batch_shape):
  affine = tfn.Affine(
      input_size, output_size=output_size, batch_shape=kernel_batch_shape)
  x = tf.ones(input_batch_shape + (input_size,), dtype=tf.float32)
  y = affine(x)
  self.assertAllEqual(
      y.shape,
      ps.broadcast_shape(kernel_batch_shape,
                         input_batch_shape).concatenate(output_size))
def testAffineBatching(self, layer_batch, input_batch):
  dist = jdlayers.Affine(4, 3, dtype=self.dtype)
  layer = dist.sample(
      layer_batch, seed=test_util.test_seed(sampler_type='stateless'))
  # Validate that we can map the layer.
  layer = tf.nest.map_structure(lambda x: x + 0., layer)
  x = tf.ones(input_batch + [3], dtype=self.dtype)
  y = layer(x)
  self.assertAllEqual(
      list(ps.broadcast_shape(layer_batch, input_batch)) + [4], y.shape)
  self.assertEqual(self.dtype, y.dtype)
def _random_gamma_noncpu(shape, concentration, rate, seed=None):
  """Sample using XLA-friendly python-based rejection sampler."""
  shape = tf.concat(
      [shape,
       prefer_static.broadcast_shape(tf.shape(concentration),
                                     tf.shape(rate))],
      axis=0)
  return random_gamma_rejection(
      sample_shape=shape, alpha=concentration, beta=rate, seed=seed)
def random_gamma(shape, concentration, rate, seed=None):
  shape = ps.convert_to_shape_tensor(shape, dtype_hint=tf.int32, name='shape')
  total_shape = ps.concat(
      [shape, ps.broadcast_shape(ps.shape(concentration), ps.shape(rate))],
      axis=0)
  seed = samplers.sanitize_seed(seed, salt='random_gamma')
  return _random_gamma_gradient(total_shape, concentration, rate, seed)
def _cdf(self, x):
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  batch_shape = self.batch_shape
  if not tensorshape_util.is_fully_defined(batch_shape):
    batch_shape = self._batch_shape_tensor(low=low, high=high)
  broadcast_shape = ps.broadcast_shape(ps.shape(x), batch_shape)
  zeros = tf.zeros(broadcast_shape, dtype=self.dtype)
  ones = tf.ones(broadcast_shape, dtype=self.dtype)
  result_if_not_big = tf.where(
      x < low, zeros, (x - low) / self._range(low=low, high=high))
  return tf.where(x >= high, ones, result_if_not_big)
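# The two `tf.where` calls implement the piecewise uniform CDF
#   F(x) = 0                        for x < low,
#   F(x) = (x - low) / (high - low) for low <= x < high,
#   F(x) = 1                        for x >= high.
# Quick illustrative check (hypothetical values, not from the library):
import numpy as np

low, high = 1., 3.
x = np.array([0., 2., 5.])
cdf = np.where(x >= high, 1.,
               np.where(x < low, 0., (x - low) / (high - low)))
assert np.allclose(cdf, [0., 0.5, 1.])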
def _finish_log_prob(self, lp, aux):
  (sample_ndims, extra_sample_ndims, batch_ndims) = aux
  # (1) Ensure lp is fully broadcast in the sample dims, i.e. ensure lp has
  #     full sample shape in the sample axes, before we reduce.
  bcast_lp_shape = ps.broadcast_shape(
      ps.shape(lp),
      ps.concat([ps.ones([sample_ndims], tf.int32),
                 ps.reshape(self.sample_shape, shape=[-1]),
                 ps.ones([batch_ndims], tf.int32)], axis=0))
  lp = tf.broadcast_to(lp, bcast_lp_shape)
  # (2) Make the final reduction.
  axis = ps.range(sample_ndims, sample_ndims + extra_sample_ndims)
  return self._sum_fn()(lp, axis=axis)
def _log_gamma_difference_jvp(primals, tangents):
  """Computes JVP for log-gamma-difference (supports JAX custom derivative)."""
  x, y = primals
  dx, dy = tangents
  # TODO(https://github.com/google/jax/issues/3768): eliminate broadcast_to?
  bc_shp = prefer_static.broadcast_shape(prefer_static.shape(dx),
                                         prefer_static.shape(dy))
  dx = tf.broadcast_to(dx, bc_shp)
  dy = tf.broadcast_to(dy, bc_shp)
  # See note above in _log_gamma_difference_bwd.
  px = -tf.math.digamma(x + y)
  py = tf.math.digamma(y) + px
  return _log_gamma_difference_naive_gradient(x, y), px * dx + py * dy
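# The partials above are consistent with defining log_gamma_difference(x, y)
# = lgamma(y) - lgamma(x + y), giving d/dx = -digamma(x + y) and
# d/dy = digamma(y) - digamma(x + y). Finite-difference spot check
# (illustrative, SciPy-based):
import numpy as np
from scipy.special import gammaln, digamma

f = lambda x, y: gammaln(y) - gammaln(x + y)
x, y, eps = 2.0, 7.0, 1e-6
assert np.isclose((f(x + eps, y) - f(x - eps, y)) / (2 * eps),
                  -digamma(x + y), atol=1e-4)
assert np.isclose((f(x, y + eps) - f(x, y - eps)) / (2 * eps),
                  digamma(y) - digamma(x + y), atol=1e-4)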
def _lbeta_jvp(primals, tangents):
  """Computes JVP for log-beta (supports JAX custom derivative)."""
  x, y = primals
  dx, dy = tangents
  # TODO(https://github.com/google/jax/issues/3768): eliminate broadcast_to?
  bc_shp = prefer_static.broadcast_shape(prefer_static.shape(dx),
                                         prefer_static.shape(dy))
  dx = tf.broadcast_to(dx, bc_shp)
  dy = tf.broadcast_to(dy, bc_shp)
  total_digamma = tf.math.digamma(x + y)
  px = tf.math.digamma(x) - total_digamma
  py = tf.math.digamma(y) - total_digamma
  return _lbeta_naive_gradient(x, y), px * dx + py * dy
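# Same pattern for log-beta: lbeta(x, y) = lgamma(x) + lgamma(y)
# - lgamma(x + y), so both partials subtract digamma(x + y). Spot check
# (illustrative, SciPy-based):
import numpy as np
from scipy.special import betaln, digamma

x, y, eps = 2.0, 7.0, 1e-6
assert np.isclose((betaln(x + eps, y) - betaln(x - eps, y)) / (2 * eps),
                  digamma(x) - digamma(x + y), atol=1e-4)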
def _inner_apply(x1, x2):
  order = ps.shape(self.amplitudes)[-1]

  def scan_fn(esp, i):
    s = self.kernel[..., i].apply(
        x1[..., i][..., tf.newaxis],
        x2[..., i][..., tf.newaxis],
        example_ndims=example_ndims)
    next_esp = esp[..., 1:] + s[..., tf.newaxis] * esp[..., :-1]
    # Add the zero-th polynomial.
    next_esp = tf.concat(
        [tf.ones_like(esp[..., 0][..., tf.newaxis]), next_esp], axis=-1)
    return next_esp

  batch_shape = ps.broadcast_shape(
      ps.shape(x1)[:-self.kernel.feature_ndims],
      ps.shape(x2)[:-self.kernel.feature_ndims])
  batch_shape = ps.broadcast_shape(
      batch_shape,
      ps.concat([self.batch_shape_tensor(), [1] * example_ndims], axis=0))
  initializer = tf.concat(
      [tf.ones(ps.concat([batch_shape, [1]], axis=0), dtype=self.dtype),
       tf.zeros(ps.concat([batch_shape, [order]], axis=0), dtype=self.dtype)],
      axis=-1)
  esps = tf.scan(
      scan_fn,
      elems=ps.range(0, ps.shape(x1)[-1], dtype=tf.int32),
      parallel_iterations=32,
      initializer=initializer)[-1, ..., 1:]
  amplitudes = util.pad_shape_with_ones(
      self.amplitudes, ndims=example_ndims, start=-2)
  return tf.reduce_sum(esps * tf.math.square(amplitudes), axis=-1)
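# Sketch (plain NumPy, hypothetical values): the scan above runs the standard
# elementary-symmetric-polynomial recurrence e_k <- e_k + s_i * e_{k-1} over
# the per-dimension kernel values s_i, keeping e_1..e_order at the end.
import numpy as np
from itertools import combinations

s = np.array([2., 3., 5.])
order = 2
esp = np.zeros(order + 1)
esp[0] = 1.                          # e_0 = 1
for si in s:
  esp[1:] = esp[1:] + si * esp[:-1]  # ESP recurrence, one dimension at a time
# e_1 = 2 + 3 + 5 = 10; e_2 = 2*3 + 2*5 + 3*5 = 31.
brute = [sum(np.prod(c) for c in combinations(s, k)) for k in (1, 2)]
assert np.allclose(esp[1:], brute)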