def check_event_space_bijector_constrains(self, dist, data):
  event_space_bijector = dist.experimental_default_event_space_bijector()
  if event_space_bijector is None:
    return

  total_sample_shape = tensorshape_util.concatenate(
      # Draw a sample shape
      data.draw(tfp_hps.shapes()),
      # Draw a shape that broadcasts with `[batch_shape, inverse_event_shape]`
      # where `inverse_event_shape` is the event shape in the bijector's
      # domain. This is the shape of `y` in R**n, such that
      # x = event_space_bijector(y) has the event shape of the distribution.
      data.draw(tfp_hps.broadcasting_shapes(
          tensorshape_util.concatenate(
              dist.batch_shape,
              event_space_bijector.inverse_event_shape(dist.event_shape)),
          n=1))[0])

  y = data.draw(
      tfp_hps.constrained_tensors(
          tfp_hps.identity_fn, total_sample_shape.as_list()))
  with tfp_hps.no_tf_rank_errors():
    x = event_space_bijector(y)
    with tf.control_dependencies(dist._sample_control_dependencies(x)):
      self.evaluate(tf.identity(x))
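
# A minimal, self-contained sketch (not part of the test above) of the API
# being exercised: `experimental_default_event_space_bijector` returns a
# bijector whose `forward` maps unconstrained reals into the distribution's
# support. The Dirichlet here is an illustrative choice, not drawn from the
# test's strategies.
def _demo_event_space_bijector():
  dist = tfd.Dirichlet(concentration=[1., 1., 1.])  # event shape [3]
  bij = dist.experimental_default_event_space_bijector()
  # The unconstrained domain can have a different event shape: [2] here.
  y = tf.random.normal(bij.inverse_event_shape(dist.event_shape), seed=42)
  x = bij.forward(y)  # lies on the simplex, so log_prob is finite
  return dist.log_prob(x)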
def test_sample(self, data):
  batch_shape = data.draw(tfp_hps.shapes())
  bcast_arg, dist_batch_shp = data.draw(
      tfp_hps.broadcasting_shapes(batch_shape, 2))
  underlying = tfd.Normal(
      loc=tf.reshape(
          tf.range(float(np.prod(tensorshape_util.as_list(dist_batch_shp)))),
          dist_batch_shp),
      scale=0.01)
  if not self.is_static_shape:
    bcast_arg = tf.Variable(bcast_arg)
    self.evaluate(bcast_arg.initializer)
  dist = tfd.BatchBroadcast(underlying, bcast_arg)
  sample_shape = data.draw(
      hps.one_of(hps.integers(0, 13), tfp_hps.shapes()))
  sample_batch_event = tf.concat(
      [np.int32(sample_shape).reshape([-1]),
       batch_shape,
       dist.event_shape_tensor()],
      axis=0)
  sample = dist.sample(sample_shape, seed=test_util.test_seed())
  if self.is_static_shape:
    self.assertEqual(tf.TensorShape(self.evaluate(sample_batch_event)),
                     sample.shape)
  self.assertAllEqual(sample_batch_event, tf.shape(sample))
  # Since the `loc` of the underlying is simply 0...n-1 (reshaped), and the
  # scale is extremely small, we can verify that these locations are
  # effectively broadcast out to the full batch shape when sampling.
  self.assertAllClose(
      tf.broadcast_to(dist.distribution.loc, sample_batch_event),
      sample,
      atol=.1)
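
# A minimal sketch (not part of the test above) of `tfd.BatchBroadcast`: it
# broadcasts an underlying distribution's batch shape against the given shape
# without materializing copies of the parameters. The shapes below are
# illustrative, not drawn from the test's strategies.
def _demo_batch_broadcast():
  underlying = tfd.Normal(loc=tf.range(3.), scale=1.)  # batch shape [3]
  dist = tfd.BatchBroadcast(underlying, [2, 3])        # batch shape [2, 3]
  sample = dist.sample(seed=42)                        # shape [2, 3]
  return dist.batch_shape, sample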
def testCholeskyExtensionRandomized(self, data):
  jitter = lambda n: tf.linalg.eye(n, dtype=self.dtype) * 5e-5
  target_bs = data.draw(hpnp.array_shapes())
  prev_bs, new_bs = data.draw(tfp_hps.broadcasting_shapes(target_bs, 2))
  ones = tf.TensorShape([1] * len(target_bs))
  smallest_shared_shp = tuple(np.min(
      [tensorshape_util.as_list(tf.broadcast_static_shape(ones, shp))
       for shp in [prev_bs, new_bs]],
      axis=0))

  z = data.draw(hps.integers(min_value=1, max_value=12))
  n = data.draw(hps.integers(min_value=0, max_value=z - 1))
  m = z - n

  np.random.seed(data.draw(hps.integers(min_value=0, max_value=2**32 - 1)))
  xs = np.random.uniform(size=smallest_shared_shp + (n,))
  # Draw from `just` so the sampled array appears in Hypothesis' report of a
  # falsifying example.
  data.draw(hps.just(xs))
  xs = (xs + np.zeros(tensorshape_util.as_list(prev_bs) + [n]))[..., np.newaxis]
  xs = xs.astype(self.dtype)
  xs = tf1.placeholder_with_default(
      xs, shape=xs.shape if self.use_static_shape else None)

  k = tfp.math.psd_kernels.MaternOneHalf()
  mat = k.matrix(xs, xs) + jitter(n)
  chol = tf.linalg.cholesky(mat)

  ys = np.random.uniform(size=smallest_shared_shp + (m,))
  data.draw(hps.just(ys))
  ys = (ys + np.zeros(tensorshape_util.as_list(new_bs) + [m]))[..., np.newaxis]
  ys = ys.astype(self.dtype)
  ys = tf1.placeholder_with_default(
      ys, shape=ys.shape if self.use_static_shape else None)

  xsys = tf.concat([
      xs + tf.zeros(target_bs + (n, 1), dtype=self.dtype),
      ys + tf.zeros(target_bs + (m, 1), dtype=self.dtype)
  ], axis=-2)
  new_chol_expected = tf.linalg.cholesky(k.matrix(xsys, xsys) + jitter(z))
  new_chol = tfp.math.cholesky_concat(
      chol, k.matrix(xsys, ys) + jitter(z)[:, n:])
  self.assertAllClose(new_chol_expected, new_chol, rtol=1e-5, atol=2e-5)
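
# A minimal sketch (not part of the test above) of `tfp.math.cholesky_concat`:
# given `chol(a)` and the new columns [[b], [c]] of the extended matrix
# [[a, b], [b^T, c]], it returns the Cholesky factor of the extended matrix
# without refactorizing `a`. The matrices below are illustrative.
def _demo_cholesky_concat():
  a = tf.constant([[4., 2.], [2., 5.]])
  chol_a = tf.linalg.cholesky(a)
  cols = tf.constant([[1.], [.5], [3.]])  # b = [[1.], [.5]], c = [[3.]]
  extended = tfp.math.cholesky_concat(chol_a, cols)
  # Agrees with factorizing the full 3x3 matrix from scratch:
  full = tf.constant([[4., 2., 1.], [2., 5., .5], [1., .5, 3.]])
  return extended, tf.linalg.cholesky(full)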
@hps.composite
def normal_params(draw):
  shape = draw(shapes())
  arg_shapes = draw(
      tfp_hps.broadcasting_shapes(shape, 3).map(tensorshapes_to_tuples))
  include_arg = draw(hps.lists(hps.booleans(), min_size=2, max_size=2))
  dtype = draw(hps.sampled_from([np.float32, np.float64]))
  mean = (
      draw(single_arrays(shape=hps.just(arg_shapes[1]),
                         dtype=dtype,
                         elements=floats()))
      if include_arg[0] else 0)
  stddev = (
      draw(single_arrays(shape=hps.just(arg_shapes[2]),
                         dtype=dtype,
                         elements=positive_floats()))
      if include_arg[1] else 1)
  return (arg_shapes[0], mean, stddev, dtype)
def test_shapes(self, data):
  batch_shape = data.draw(tfp_hps.shapes())
  bcast_arg, dist_batch_shp = data.draw(
      tfp_hps.broadcasting_shapes(batch_shape, 2))
  underlying = data.draw(tfd_hps.distributions(batch_shape=dist_batch_shp))
  if not self.is_static_shape:
    bcast_arg = tf.Variable(bcast_arg)
    self.evaluate(bcast_arg.initializer)
  dist = tfd.BatchBroadcast(underlying, bcast_arg)
  if self.is_static_shape:
    self.assertEqual(batch_shape, dist.batch_shape)
    self.assertEqual(underlying.event_shape, dist.event_shape)
  self.assertAllEqual(batch_shape, dist.batch_shape_tensor())
  self.assertAllEqual(underlying.event_shape_tensor(),
                      dist.event_shape_tensor())
def testCholeskyUpdateRandomized(self, data):
  target_bs = data.draw(hpnp.array_shapes())
  chol_bs, u_bs, multiplier_bs = data.draw(
      tfp_hps.broadcasting_shapes(target_bs, 3))
  l = data.draw(hps.integers(min_value=1, max_value=12))
  rng_seed = data.draw(hps.integers(min_value=0, max_value=2**32 - 1))
  rng = np.random.RandomState(seed=rng_seed)

  xs = push_apart(
      rng.uniform(size=tensorshape_util.concatenate(chol_bs, (l, 1))),
      axis=-2)
  hp.note(xs)
  xs = xs.astype(self.dtype)
  xs = tf1.placeholder_with_default(
      xs, shape=xs.shape if self.use_static_shape else None)

  k = tfp.math.psd_kernels.MaternOneHalf()
  jitter = lambda n: tf.linalg.eye(n, dtype=self.dtype) * 5e-5
  mat = k.matrix(xs, xs) + jitter(l)
  chol = tf.linalg.cholesky(mat)

  u = rng.uniform(size=tensorshape_util.concatenate(u_bs, (l,)))
  hp.note(u)
  u = u.astype(self.dtype)
  u = tf1.placeholder_with_default(
      u, shape=u.shape if self.use_static_shape else None)

  multiplier = rng.uniform(size=multiplier_bs)
  hp.note(multiplier)
  multiplier = multiplier.astype(self.dtype)
  multiplier = tf1.placeholder_with_default(
      multiplier, shape=multiplier.shape if self.use_static_shape else None)

  new_chol_expected = tf.linalg.cholesky(
      mat + multiplier[..., tf.newaxis, tf.newaxis] * tf.linalg.matmul(
          u[..., tf.newaxis], u[..., tf.newaxis, :]))
  new_chol = tfp.math.cholesky_update(chol, u, multiplier=multiplier)
  self.assertAllClose(new_chol_expected, new_chol, rtol=1e-5, atol=2e-5)
  # The updated factor must remain lower triangular.
  self.assertAllEqual(tf.linalg.band_part(new_chol, -1, 0), new_chol)
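
# A minimal sketch (not part of the test above) of `tfp.math.cholesky_update`:
# given the factor of `mat`, it produces the factor of
# `mat + multiplier * u u^T` via a rank-1 update rather than a fresh O(n^3)
# decomposition. The values below are illustrative.
def _demo_cholesky_update():
  mat = tf.constant([[4., 2.], [2., 5.]])
  chol = tf.linalg.cholesky(mat)
  u = tf.constant([1., 2.])
  updated = tfp.math.cholesky_update(chol, u, multiplier=.5)
  # Agrees with factorizing the updated matrix directly:
  direct = tf.linalg.cholesky(mat + .5 * u[:, tf.newaxis] * u[tf.newaxis, :])
  return updated, direct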
@hps.composite
def uniform_params(draw):
  shape = draw(shapes())
  arg_shapes = draw(
      tfp_hps.broadcasting_shapes(shape, 3).map(tensorshapes_to_tuples))
  include_arg = draw(hps.lists(hps.booleans(), min_size=2, max_size=2))
  dtype = draw(hps.sampled_from([np.int32, np.int64, np.float32, np.float64]))
  elements = floats(), positive_floats()
  if dtype == np.int32 or dtype == np.int64:
    # TF RandomUniformInt only supports scalar min/max, so force scalar
    # shapes and integer elements; see the sketch after this strategy.
    arg_shapes = (arg_shapes[0], (), ())
    elements = integers(), integers(min_value=1)
  minval = (
      draw(single_arrays(shape=hps.just(arg_shapes[1]),
                         dtype=dtype,
                         elements=elements[0]))
      if include_arg[0] else 0)
  maxval = minval + (
      draw(single_arrays(shape=hps.just(arg_shapes[2]),
                         dtype=dtype,
                         elements=elements[1]))
      if include_arg[1] else dtype(10))
  return (arg_shapes[0], minval, maxval, dtype)
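
# A minimal sketch (not part of the strategy above) of the constraint noted
# in `uniform_params`: TF's integer uniform sampler requires scalar bounds,
# while float dtypes accept broadcastable tensors. The shapes are illustrative.
def _demo_uniform_int_bounds():
  ok_int = tf.random.uniform([3], minval=0, maxval=10, dtype=tf.int32)
  # The same call with minval=[0, 1, 2] would raise for tf.int32, but the
  # float analogue broadcasts fine:
  ok_float = tf.random.uniform([3], minval=[0., 1., 2.], maxval=10.)
  return ok_int, ok_float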
def test_default_bijector(self, data):
  batch_shape = data.draw(tfp_hps.shapes())
  bcast_arg, dist_batch_shp = data.draw(
      tfp_hps.broadcasting_shapes(batch_shape, 2))
  underlying = data.draw(tfd_hps.distributions(batch_shape=dist_batch_shp))
  if not self.is_static_shape:
    bcast_arg = tf.Variable(bcast_arg)
    self.evaluate(bcast_arg.initializer)
  dist = tfd.BatchBroadcast(underlying, bcast_arg)
  bijector = dist.experimental_default_event_space_bijector()
  hp.assume(bijector is not None)
  shp = bijector.inverse_event_shape_tensor(
      tf.concat([dist.batch_shape_tensor(), dist.event_shape_tensor()],
                axis=0))
  obs = bijector.forward(tf.random.normal(shp, seed=test_util.test_seed()))
  with tf.control_dependencies(dist._sample_control_dependencies(obs)):
    self.evaluate(tf.identity(obs))
def testDistribution(self, data):
  enable_vars = data.draw(hps.booleans())
  # TODO(b/146572907): Fix `enable_vars` for metadistributions.
  # Copy so extending does not mutate the module-level list.
  broken_dists = list(EVENT_SPACE_BIJECTOR_IS_BROKEN)
  if enable_vars:
    broken_dists.extend(dhps.INSTANTIABLE_META_DISTS)
  dist = data.draw(
      dhps.distributions(
          enable_vars=enable_vars,
          eligibility_filter=(lambda name: name not in broken_dists)))
  self.evaluate([var.initializer for var in dist.variables])
  self.check_bad_loc_scale(dist)
  event_space_bijector = dist._experimental_default_event_space_bijector()
  if event_space_bijector is None:
    return

  total_sample_shape = tensorshape_util.concatenate(
      # Draw a sample shape
      data.draw(tfp_hps.shapes()),
      # Draw a shape that broadcasts with `[batch_shape, inverse_event_shape]`
      # where `inverse_event_shape` is the event shape in the bijector's
      # domain. This is the shape of `y` in R**n, such that
      # x = event_space_bijector(y) has the event shape of the distribution.
      data.draw(tfp_hps.broadcasting_shapes(
          tensorshape_util.concatenate(
              dist.batch_shape,
              event_space_bijector.inverse_event_shape(dist.event_shape)),
          n=1))[0])
  y = data.draw(
      tfp_hps.constrained_tensors(
          tfp_hps.identity_fn, total_sample_shape.as_list()))
  x = event_space_bijector(y)
  with tf.control_dependencies(dist._sample_control_dependencies(x)):
    self.evaluate(tf.identity(x))
def test_log_prob(self, data):
  batch_shape = data.draw(tfp_hps.shapes())
  bcast_arg, dist_batch_shp = data.draw(
      tfp_hps.broadcasting_shapes(batch_shape, 2))
  underlying = tfd.Normal(
      loc=tf.reshape(
          tf.range(float(np.prod(tensorshape_util.as_list(dist_batch_shp)))),
          dist_batch_shp),
      scale=0.01)
  if not self.is_static_shape:
    bcast_arg = tf.Variable(bcast_arg)
    self.evaluate(bcast_arg.initializer)
  dist = tfd.BatchBroadcast(underlying, bcast_arg)
  sample_shape = data.draw(hps.one_of(hps.integers(0, 13), tfp_hps.shapes()))
  sample_batch_event = tf.concat(
      [np.int32(sample_shape).reshape([-1]),
       batch_shape,
       dist.event_shape_tensor()],
      axis=0)
  obsv = tf.broadcast_to(dist.distribution.loc, sample_batch_event)
  # With such a small scale, log-prob at the broadcast means must exceed
  # log-prob half a unit away.
  self.assertAllTrue(dist.log_prob(obsv) > dist.log_prob(obsv + .5))