def testLogProb(self, event_shape, event_dims, training):
  with self.test_session() as sess:
    training = tf.placeholder_with_default(training, (), "training")
    layer = normalization.BatchNormalization(axis=event_dims, epsilon=0.)
    batch_norm = tfb.BatchNormalization(batchnorm_layer=layer,
                                        training=training)
    base_dist = distributions.MultivariateNormalDiag(
        loc=np.zeros(np.prod(event_shape), dtype=np.float32))
    # Reshape the events.
    if isinstance(event_shape, int):
      event_shape = [event_shape]
    base_dist = transformed_distribution_lib.TransformedDistribution(
        distribution=base_dist,
        bijector=tfb.Reshape(event_shape_out=event_shape))
    dist = transformed_distribution_lib.TransformedDistribution(
        distribution=base_dist, bijector=batch_norm, validate_args=True)
    samples = dist.sample(int(1e5))
    # No volume distortion since training=False, bijector is initialized
    # to the identity transformation.
    base_log_prob = base_dist.log_prob(samples)
    dist_log_prob = dist.log_prob(samples)
    tf.global_variables_initializer().run()
    base_log_prob_, dist_log_prob_ = sess.run(
        [base_log_prob, dist_log_prob])
    self.assertAllClose(base_log_prob_, dist_log_prob_)
def testMaximumLikelihoodTraining(self):
  # Test Maximum Likelihood training with default bijector.
  with self.cached_session() as sess:
    base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.])
    batch_norm = BatchNormalization(training=True)
    dist = transformed_distribution_lib.TransformedDistribution(
        distribution=base_dist,
        bijector=batch_norm)
    target_dist = distributions.MultivariateNormalDiag(loc=[1., 2.])
    target_samples = target_dist.sample(100)
    dist_samples = dist.sample(3000)
    loss = -math_ops.reduce_mean(dist.log_prob(target_samples))
    # Run the batch-norm moving-mean/variance update ops alongside each
    # training step so the moving statistics track the fitted target.
    with ops.control_dependencies(batch_norm.batchnorm.updates):
      train_op = adam.AdamOptimizer(1e-2).minimize(loss)
      moving_mean = array_ops.identity(batch_norm.batchnorm.moving_mean)
      moving_var = array_ops.identity(batch_norm.batchnorm.moving_variance)
    variables.global_variables_initializer().run()
    for _ in range(3000):
      sess.run(train_op)
    [dist_samples_, moving_mean_, moving_var_] = sess.run([
        dist_samples, moving_mean, moving_var])
    self.assertAllClose([1., 2.], np.mean(dist_samples_, axis=0), atol=5e-2)
    self.assertAllClose([1., 2.], moving_mean_, atol=5e-2)
    self.assertAllClose([1., 1.], moving_var_, atol=5e-2)
def testCompareToBijector(self):
  """Demonstrates equivalence between TD, Bijector approach and AR dist."""
  sample_shape = np.int32([4, 5])
  batch_shape = np.int32([])
  event_size = np.int32(2)
  with self.cached_session() as sess:
    batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0)
    sample0 = array_ops.zeros(batch_event_shape)
    affine = Affine(scale_tril=self._random_scale_tril(event_size))
    ar = autoregressive_lib.Autoregressive(
        self._normal_fn(affine), sample0, validate_args=True)
    ar_flow = MaskedAutoregressiveFlow(
        is_constant_jacobian=True,
        shift_and_log_scale_fn=lambda x: [None, affine.forward(x)],
        validate_args=True)
    td = transformed_distribution_lib.TransformedDistribution(
        distribution=normal_lib.Normal(loc=0., scale=1.),
        bijector=ar_flow,
        event_shape=[event_size],
        batch_shape=batch_shape,
        validate_args=True)
    x_shape = np.concatenate([sample_shape, batch_shape, [event_size]],
                             axis=0)
    x = 2. * self._rng.random_sample(x_shape).astype(np.float32) - 1.
    td_log_prob_, ar_log_prob_ = sess.run([td.log_prob(x), ar.log_prob(x)])
    self.assertAllClose(td_log_prob_, ar_log_prob_, atol=0., rtol=1e-6)
def testDocstringExample(self):
  with self.test_session():
    exp_gamma_distribution = (
        transformed_distribution_lib.TransformedDistribution(
            distribution=gamma_lib.Gamma(concentration=1., rate=2.),
            bijector=tfb.Invert(tfb.Exp())))
    self.assertAllEqual(
        [], tf.shape(exp_gamma_distribution.sample()).eval())
def testMutuallyConsistent(self):
  dims = 4
  with self.cached_session() as sess:
    ma = MaskedAutoregressiveFlow(
        validate_args=True,
        **self._autoregressive_flow_kwargs)
    dist = transformed_distribution_lib.TransformedDistribution(
        distribution=normal_lib.Normal(loc=0., scale=1.),
        bijector=ma,
        event_shape=[dims],
        validate_args=True)
    self.run_test_sample_consistent_log_prob(
        sess_run_fn=sess.run,
        dist=dist,
        num_samples=int(1e5),
        radius=1.,
        center=0.,
        rtol=0.02)
def testMutuallyConsistent(self):
  dims = 4
  with self.test_session() as sess:
    nvp = tfb.RealNVP(
        num_masked=3,
        validate_args=True,
        **self._real_nvp_kwargs)
    dist = transformed_distribution_lib.TransformedDistribution(
        distribution=tf.distributions.Normal(loc=0., scale=1.),
        bijector=nvp,
        event_shape=[dims],
        validate_args=True)
    self.run_test_sample_consistent_log_prob(
        sess_run_fn=sess.run,
        dist=dist,
        num_samples=int(1e5),
        radius=1.,
        center=0.,
        rtol=0.02)
def testLogProb(self):
  with self.test_session() as sess:
    layer = normalization.BatchNormalization(epsilon=0.)
    batch_norm = BatchNormalization(batchnorm_layer=layer, training=False)
    base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.])
    dist = transformed_distribution_lib.TransformedDistribution(
        distribution=base_dist, bijector=batch_norm, validate_args=True)
    samples = dist.sample(int(1e5))
    # No volume distortion since training=False, bijector is initialized
    # to the identity transformation.
    base_log_prob = base_dist.log_prob(samples)
    dist_log_prob = dist.log_prob(samples)
    variables.global_variables_initializer().run()
    base_log_prob_, dist_log_prob_ = sess.run(
        [base_log_prob, dist_log_prob])
    self.assertAllClose(base_log_prob_, dist_log_prob_)
def testInvertMutuallyConsistent(self):
  # BatchNorm bijector is only mutually consistent when training=False.
  dims = 4
  with self.cached_session() as sess:
    layer = normalization.BatchNormalization(epsilon=0.)
    batch_norm = Invert(
        BatchNormalization(batchnorm_layer=layer, training=False))
    dist = transformed_distribution_lib.TransformedDistribution(
        distribution=normal_lib.Normal(loc=0., scale=1.),
        bijector=batch_norm,
        event_shape=[dims],
        validate_args=True)
    self.run_test_sample_consistent_log_prob(
        sess_run_fn=sess.run,
        dist=dist,
        num_samples=int(1e5),
        radius=2.,
        center=0.,
        rtol=0.02)
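# Hedged sketch, not taken from the tests above: the pattern these
# BatchNormalization and MaskedAutoregressiveFlow tests exercise is typically
# stacked into a single normalizing flow. `Chain` and
# `masked_autoregressive_default_template` are assumed to be importable from
# the same bijectors package; `num_layers` and `hidden_units` are hypothetical
# names used only for illustration.
def make_maf_with_batchnorm(dims, num_layers=2, hidden_units=(32, 32)):
  bijectors = []
  for _ in range(num_layers):
    bijectors.append(MaskedAutoregressiveFlow(
        shift_and_log_scale_fn=masked_autoregressive_default_template(
            hidden_layers=list(hidden_units))))
    # training=False keeps the bijector mutually consistent, per the test
    # above; switch to training=True (plus the update-op pattern from
    # testMaximumLikelihoodTraining) when fitting.
    bijectors.append(Invert(BatchNormalization(training=False)))
  return transformed_distribution_lib.TransformedDistribution(
      distribution=normal_lib.Normal(loc=0., scale=1.),
      bijector=Chain(bijectors),
      event_shape=[dims])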
def quadrature_scheme_lognormal_quantiles(
    loc, scale, quadrature_size, validate_args=False, name=None):
  """Use LogNormal quantiles to form quadrature on positive-reals.

  Args:
    loc: `float`-like (batch of) scalar `Tensor`; the location parameter of
      the LogNormal prior.
    scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
      the LogNormal prior.
    quadrature_size: Python `int` scalar representing the number of quadrature
      points.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    grid: (Batch of) length-`quadrature_size` vectors representing the
      `log_rate` parameters of a `Poisson`.
    probs: (Batch of) length-`quadrature_size` vectors representing the
      weight associated with each `grid` value.
  """
  with ops.name_scope(name, "quadrature_scheme_lognormal_quantiles",
                      [loc, scale]):
    # Create a LogNormal distribution.
    dist = transformed_lib.TransformedDistribution(
        distribution=normal_lib.Normal(loc=loc, scale=scale),
        bijector=Exp(),
        validate_args=validate_args)
    batch_ndims = dist.batch_shape.ndims
    if batch_ndims is None:
      batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0]

    def _compute_quantiles():
      """Helper to build quantiles."""
      # Omit {0, 1} since they might lead to Inf/NaN.
      zero = array_ops.zeros([], dtype=dist.dtype)
      edges = math_ops.linspace(zero, 1., quadrature_size + 3)[1:-1]
      # Expand edges so it is broadcast across batch dims.
      edges = array_ops.reshape(
          edges,
          shape=array_ops.concat(
              [[-1], array_ops.ones([batch_ndims], dtype=dtypes.int32)],
              axis=0))
      quantiles = dist.quantile(edges)
      # Cyclically permute left by one.
      perm = array_ops.concat(
          [math_ops.range(1, 1 + batch_ndims), [0]], axis=0)
      quantiles = array_ops.transpose(quantiles, perm)
      return quantiles

    quantiles = _compute_quantiles()

    # Compute grid as quantile midpoints.
    grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2.
    # Set shape hints.
    grid.set_shape(dist.batch_shape.concatenate([quadrature_size]))

    # By construction probs is constant, i.e., `1 / quadrature_size`. This is
    # important, because non-constant probs leads to non-reparameterizable
    # samples.
    probs = array_ops.fill(
        dims=[quadrature_size],
        value=1. / math_ops.cast(quadrature_size, dist.dtype))

    return grid, probs
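# Hedged usage sketch, not part of the library code above: evaluating the
# quadrature grid and weights for a scalar LogNormal prior. Assumes a TF1-style
# `tf` import and session, as in the tests above; the helper name and the
# loc/scale/quadrature_size values are arbitrary illustrations.
def _example_lognormal_quadrature():
  grid, probs = quadrature_scheme_lognormal_quantiles(
      loc=0., scale=1., quadrature_size=5)
  with tf.Session() as sess:
    grid_, probs_ = sess.run([grid, probs])
  # grid_ holds the 5 midpoints of adjacent LogNormal(0, 1) quantiles;
  # probs_ is uniform, each weight equal to 1 / 5.
  return grid_, probs_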