def _parameter_control_dependencies(self, is_init):
  assertions = []

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init and self._is_vector:
    msg = "Argument `loc` must be at least rank 1."
    if tensorshape_util.rank(self.loc.shape) is not None:
      if tensorshape_util.rank(self.loc.shape) < 1:
        raise ValueError(msg)
    elif self.validate_args:
      assertions.append(
          assert_util.assert_rank_at_least(self.loc, 1, message=msg))

  if not self.validate_args:
    assert not assertions  # Should never happen.
    return []

  if is_init != tensor_util.is_mutable(self.atol):
    assertions.append(
        assert_util.assert_non_negative(
            self.atol, message="Argument `atol` must be non-negative."))
  if is_init != tensor_util.is_mutable(self.rtol):
    assertions.append(
        assert_util.assert_non_negative(
            self.rtol, message="Argument `rtol` must be non-negative."))
  return assertions
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  if is_init != tensor_util.is_mutable(self.concentration):
    assertions.append(
        assert_util.assert_positive(
            self.concentration,
            message="Argument `concentration` must be positive."))
  if is_init != tensor_util.is_mutable(self.rate):
    assertions.append(
        assert_util.assert_positive(
            self.rate,
            message="Argument `rate` must be positive."))
  return assertions
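The `is_init != tensor_util.is_mutable(param)` tests above implement an XOR: immutable args are validated once at construction, while Variable-backed args are re-validated at call time because their values can change after `__init__`. A minimal sketch of that decision, using `isinstance(arg, tf.Variable)` as a stand-in for `tensor_util.is_mutable` (an assumption for illustration only, not the library's check):

import tensorflow as tf

def needs_check_now(is_init, arg):
  # XOR: at init (is_init=True) check only fixed args; at call time
  # (is_init=False) check only Variable-backed args, whose values may
  # have changed since construction.
  return is_init != isinstance(arg, tf.Variable)

assert needs_check_now(True, tf.constant(1.))      # constants: check at init
assert not needs_check_now(True, tf.Variable(1.))  # Variables: defer...
assert needs_check_now(False, tf.Variable(1.))     # ...to call time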
def maybe_assert_categorical_param_correctness(
    is_init, validate_args, probs, logits):
  """Return assertions for `Categorical`-type distributions."""
  assertions = []

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    x, name = (probs, 'probs') if logits is None else (logits, 'logits')

    if not dtype_util.is_floating(x.dtype):
      raise TypeError('Argument `{}` must have floating type.'.format(name))

    msg = 'Argument `{}` must have rank at least 1.'.format(name)
    ndims = tensorshape_util.rank(x.shape)
    if ndims is not None:
      if ndims < 1:
        raise ValueError(msg)
    elif validate_args:
      x = tf.convert_to_tensor(x)
      probs = x if logits is None else None  # Retain tensor conversion.
      logits = x if probs is None else None
      assertions.append(assert_util.assert_rank_at_least(x, 1, message=msg))

  if not validate_args:
    assert not assertions  # Should never happen.
    return []

  if logits is not None:
    if is_init != tensor_util.is_mutable(logits):
      logits = tf.convert_to_tensor(logits)
      assertions.extend(
          distribution_util.assert_categorical_event_shape(logits))

  if probs is not None:
    if is_init != tensor_util.is_mutable(probs):
      probs = tf.convert_to_tensor(probs)
      assertions.extend([
          assert_util.assert_non_negative(probs),
          assert_util.assert_near(
              tf.reduce_sum(probs, axis=-1),
              np.array(1, dtype=dtype_util.as_numpy_dtype(probs.dtype)),
              message='Argument `probs` must sum to 1.')
      ])
      assertions.extend(
          distribution_util.assert_categorical_event_shape(probs))

  return assertions
def maybe_assert_bernoulli_param_correctness(
    is_init, validate_args, probs, logits):
  """Return assertions for `Bernoulli`-type distributions."""
  if is_init:
    x, name = (probs, 'probs') if logits is None else (logits, 'logits')
    if not dtype_util.is_floating(x.dtype):
      raise TypeError('Argument `{}` must have floating type.'.format(name))

  if not validate_args:
    return []

  assertions = []

  if probs is not None:
    if is_init != tensor_util.is_mutable(probs):
      probs = tf.convert_to_tensor(probs)
      one = tf.constant(1., probs.dtype)
      assertions += [
          assert_util.assert_non_negative(
              probs, message='probs has components less than 0.'),
          assert_util.assert_less_equal(
              probs, one, message='probs has components greater than 1.')
      ]

  return assertions
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  if is_init != tensor_util.is_mutable(self._df):
    assertions.append(assert_util.assert_positive(
        self._df, message='Argument `df` must be positive.'))
  return assertions
def _parameter_control_dependencies(self, is_init):
  assertions = categorical_lib.maybe_assert_categorical_param_correctness(
      is_init, self.validate_args, self._probs, self._logits)
  if not self.validate_args:
    return assertions
  if is_init != tensor_util.is_mutable(self.total_count):
    assertions.extend(distribution_util.assert_nonnegative_integer_form(
        self.total_count))
  return assertions
def test_various_types(self):
  self.assertFalse(tensor_util.is_mutable(0.))
  self.assertFalse(tensor_util.is_mutable(FakeModule(0.)))
  self.assertFalse(tensor_util.is_mutable([tf.Variable(0.)]))  # Note!
  self.assertFalse(tensor_util.is_mutable(np.array(0., np.float32)))
  self.assertFalse(tensor_util.is_mutable(tf.constant(0.)))
  self.assertTrue(tensor_util.is_mutable(FakeModule(tf.Variable(0.))))
  self.assertTrue(tensor_util.is_mutable(tf.Variable(0.)))
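The test above pins down the contract: plain Python values, numpy arrays, `tf.constant`s, and even a raw list of Variables read as immutable, while a `tf.Variable`, or a `tf.Module` that owns one, reads as mutable. A hedged sketch of that contract follows; it is not the actual TFP implementation, just one function consistent with these assertions:

import tensorflow as tf

def is_mutable_sketch(x):
  # A tf.Variable can be reassigned after construction, so it is mutable.
  if isinstance(x, tf.Variable):
    return True
  # A tf.Module (like FakeModule above) is mutable only if it owns Variables.
  if isinstance(x, tf.Module):
    return bool(x.variables)
  # Everything else -- floats, numpy arrays, tf.constant, and notably a plain
  # Python list of Variables (the "Note!" case) -- is treated as immutable.
  return False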
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  for concentration in [self.concentration0, self.concentration1]:
    if is_init != tensor_util.is_mutable(concentration):
      assertions.append(assert_util.assert_positive(
          concentration,
          message="Concentration parameter must be positive."))
  return assertions
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  if is_init != tensor_util.is_mutable(self.df):
    assertions.append(
        assert_util.assert_greater(
            self.df,
            dtype_util.as_numpy_dtype(self.df.dtype)(2.),
            message='`df` must be greater than 2.'))
  return assertions
def _parameter_control_dependencies(self, is_init):
  if is_init:
    dtype_util.assert_same_float_dtype([self.loc, self.scale])
  if not self.validate_args:
    return []
  assertions = []
  if is_init != tensor_util.is_mutable(self._scale):
    assertions.append(
        assert_util.assert_positive(
            self._scale, message='Argument `scale` must be positive.'))
  return assertions
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  if is_init != tensor_util.is_mutable(self.scale):
    assertions.append(
        assert_util.assert_none_equal(
            self.scale,
            tf.zeros([], dtype=self._scale.dtype),
            message="Argument `scale` must be non-zero."))
  return assertions
def _parameter_control_dependencies(self, is_init):
  if is_init and not dtype_util.is_integer(self.axis.dtype):
    raise TypeError('Argument `axis` is not an `int` type.')
  if not self.validate_args:
    return []
  assertions = []
  if is_init != tensor_util.is_mutable(self.axis):
    assertions.append(assert_util.assert_negative(
        self.axis, message='Argument `axis` must be negative.'))
  return assertions
def assert_no_none_grad(bijector, method, wrt_vars, grads):
  for var, grad in zip(wrt_vars, grads):
    if 'log_det_jacobian' in method:
      if tensor_util.is_mutable(var):
        # We check tensor_util.is_mutable to account for xs/ys being in vars.
        # rstrip drops both the ':0' tensor suffix and any '_<n>' dedup
        # counter from the variable name.
        var_name = var.name.rstrip('_0123456789:')
      else:
        var_name = '[arg]'
      to_check = bijector.bijector if is_invert(bijector) else bijector
      if var_name in NO_LDJ_GRADS_EXPECTED.get(type(to_check).__name__, ()):
        continue
    if grad is None:
      raise AssertionError(
          'Missing `{}` -> {} grad for bijector {}'.format(
              method, var, bijector))
def _parameter_control_dependencies(self, is_init): """Checks the validity of the concentration parameter.""" assertions = [] # In init, we can always build shape and dtype checks because # we assume shape doesn't change for Variable backed args. if is_init: if not dtype_util.is_floating(self.concentration.dtype): raise TypeError('Argument `concentration` must be float type.') msg = 'Argument `concentration` must have rank at least 1.' ndims = tensorshape_util.rank(self.concentration.shape) if ndims is not None: if ndims < 1: raise ValueError(msg) elif self.validate_args: assertions.append( assert_util.assert_rank_at_least(self.concentration, 1, message=msg)) msg = 'Argument `concentration` must have `event_size` at least 2.' event_size = tf.compat.dimension_value( self.concentration.shape[-1]) if event_size is not None: if event_size < 2: raise ValueError(msg) elif self.validate_args: assertions.append( assert_util.assert_less(1, tf.shape(self.concentration)[-1], message=msg)) if not self.validate_args: assert not assertions # Should never happen. return [] if is_init != tensor_util.is_mutable(self.concentration): assertions.append( assert_util.assert_positive( self.concentration, message='Argument `concentration` must be positive.')) return assertions
def testDistribution(self, dist_name, data):
  if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
    return
  tf.compat.v1.set_random_seed(
      data.draw(
          hpnp.arrays(dtype=np.int64, shape=[]).filter(lambda x: x != 0)))
  dist, batch_shape = data.draw(
      distributions(dist_name=dist_name, enable_vars=True))
  del batch_shape
  logging.info(
      'distribution: %s; parameters used: %s', dist,
      [k for k, v in six.iteritems(dist.parameters) if v is not None])
  self.evaluate([var.initializer for var in dist.variables])
  for k, v in six.iteritems(dist.parameters):
    if not tensor_util.is_mutable(v):
      continue
    try:
      self.assertIs(getattr(dist, k), v)
    except AssertionError as e:
      raise AssertionError(
          'No attr found for parameter {} of distribution {}: \n{}'.format(
              k, dist_name, e))

  stat = data.draw(
      hps.one_of(
          map(hps.just,
              ['mean', 'mode', 'variance', 'covariance', 'entropy'])))
  try:
    VAR_USAGES.clear()
    getattr(dist, stat)()
    var_nusages = {var: len(usages) for var, usages in VAR_USAGES.items()}
    max_permissible = 2  # TODO(jvdillon): Reduce this to 1.
    if any(len(usages) > max_permissible for usages in VAR_USAGES.values()):
      for var, usages in six.iteritems(VAR_USAGES):
        if len(usages) > max_permissible:
          print('While executing statistic `{}` of `{}`, detected {} '
                'Tensor conversions for `{}`:'.format(
                    stat, dist, len(usages), var))
          for i, usage in enumerate(usages):
            print('Conversion {} of {}:\n{}'.format(
                i + 1, len(usages), ''.join(usage)))
      raise AssertionError(
          'Excessive tensor conversions detected for {} {}: {}'.format(
              dist_name, stat, var_nusages))
  except NotImplementedError:
    pass

  if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
    with tf.GradientTape() as tape:
      samp = dist.sample()
    grads = tape.gradient(samp, dist.variables)
    for grad, var in zip(grads, dist.variables):
      if grad is None:
        raise AssertionError(
            'Missing sample -> {} grad for distribution {}'.format(
                var, dist_name))

  if dist_name not in NO_LOG_PROB_PARAM_GRADS:
    # Turn off validations, since log_prob can choke on dist's own samples.
    dist = dist.copy(validate_args=False)
    with tf.GradientTape() as tape:
      lp = dist.log_prob(tf.stop_gradient(dist.sample()))
    grads = tape.gradient(lp, dist.variables)
    for grad, var in zip(grads, dist.variables):
      if grad is None:
        raise AssertionError(
            'Missing log_prob -> {} grad for distribution {}'.format(
                var, dist_name))
def testDistribution(self, dist_name, data):
  if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
    return
  tf1.set_random_seed(
      data.draw(
          hpnp.arrays(dtype=np.int64, shape=[]).filter(lambda x: x != 0)))
  dist, batch_shape = data.draw(
      distributions(dist_name=dist_name, enable_vars=True))
  batch_shape2 = data.draw(tfp_hps.broadcast_compatible_shape(batch_shape))
  dist2, _ = data.draw(
      distributions(
          dist_name=dist_name,
          batch_shape=batch_shape2,
          event_dim=get_event_dim(dist),
          enable_vars=True))
  del batch_shape
  logging.info(
      'distribution: %s; parameters used: %s', dist,
      [k for k, v in six.iteritems(dist.parameters) if v is not None])
  self.evaluate([var.initializer for var in dist.variables])
  for k, v in six.iteritems(dist.parameters):
    if not tensor_util.is_mutable(v):
      continue
    try:
      self.assertIs(getattr(dist, k), v)
    except AssertionError as e:
      raise AssertionError(
          'No attr found for parameter {} of distribution {}: \n{}'.format(
              k, dist_name, e))

  for stat in data.draw(
      hps.sets(
          hps.one_of(
              map(hps.just, [
                  'covariance', 'entropy', 'mean', 'mode', 'stddev',
                  'variance'
              ])),
          min_size=3,
          max_size=3)):
    logging.info('%s.%s', dist_name, stat)
    try:
      with tfp_hps.assert_no_excessive_var_usage(
          'statistic `{}` of `{}`'.format(stat, dist)):
        getattr(dist, stat)()
    except NotImplementedError:
      pass

  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `sample` of `{}`'.format(dist)):
      sample = dist.sample()
  if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
    grads = tape.gradient(sample, dist.variables)
    for grad, var in zip(grads, dist.variables):
      var_name = var.name.rstrip('_0123456789:')
      if var_name in NO_SAMPLE_PARAM_GRADS.get(dist_name, ()):
        continue
      if grad is None:
        raise AssertionError(
            'Missing sample -> {} grad for distribution {}'.format(
                var_name, dist_name))

  # Turn off validations, since log_prob can choke on dist's own samples.
  # Also, to relax conversion counts for KL (might do >2 w/ validate_args).
  dist = dist.copy(validate_args=False)
  dist2 = dist2.copy(validate_args=False)

  try:
    for d1, d2 in (dist, dist2), (dist2, dist):
      with tf.GradientTape() as tape:
        with tfp_hps.assert_no_excessive_var_usage(
            '`kl_divergence` of (`{}` (vars {}), `{}` (vars {}))'.format(
                d1, d1.variables, d2, d2.variables),
            max_permissible=1):  # No validation => 1 convert per var.
          kl = d1.kl_divergence(d2)
      wrt_vars = list(d1.variables) + list(d2.variables)
      grads = tape.gradient(kl, wrt_vars)
      for grad, var in zip(grads, wrt_vars):
        if grad is None and dist_name not in NO_KL_PARAM_GRADS:
          raise AssertionError(
              'Missing KL({} || {}) -> {} grad:\n'
              '{} vars: {}\n{} vars: {}'.format(
                  d1, d2, var, d1, d1.variables, d2, d2.variables))
  except NotImplementedError:
    pass

  if dist_name not in NO_LOG_PROB_PARAM_GRADS:
    with tf.GradientTape() as tape:
      lp = dist.log_prob(tf.stop_gradient(sample))
    grads = tape.gradient(lp, dist.variables)
    for grad, var in zip(grads, dist.variables):
      if grad is None:
        raise AssertionError(
            'Missing log_prob -> {} grad for distribution {}'.format(
                var, dist_name))

  for evaluative in data.draw(
      hps.sets(
          hps.one_of(
              map(hps.just, [
                  'log_prob', 'prob', 'log_cdf', 'cdf',
                  'log_survival_function', 'survival_function'
              ])),
          min_size=3,
          max_size=3)):
    logging.info('%s.%s', dist_name, evaluative)
    try:
      with tfp_hps.assert_no_excessive_var_usage(
          'evaluative `{}` of `{}`'.format(evaluative, dist),
          max_permissible=1):  # No validation => 1 convert.
        getattr(dist, evaluative)(sample)
    except NotImplementedError:
      pass
def testDistribution(self, dist_name, data):
  if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
    return
  tf.compat.v1.set_random_seed(
      data.draw(
          hpnp.arrays(dtype=np.int64, shape=[]).filter(lambda x: x != 0)))
  dist, batch_shape = data.draw(
      distributions(dist_name=dist_name, enable_vars=True))
  del batch_shape
  logging.info(
      'distribution: %s; parameters used: %s', dist,
      [k for k, v in six.iteritems(dist.parameters) if v is not None])
  self.evaluate([var.initializer for var in dist.variables])
  for k, v in six.iteritems(dist.parameters):
    if not tensor_util.is_mutable(v):
      continue
    try:
      self.assertIs(getattr(dist, k), v)
    except AssertionError as e:
      raise AssertionError(
          'No attr found for parameter {} of distribution {}: \n{}'.format(
              k, dist_name, e))

  for stat in data.draw(
      hps.permutations([
          'covariance', 'entropy', 'mean', 'mode', 'stddev', 'variance'
      ]))[:3]:
    logging.info('%s.%s', dist_name, stat)
    try:
      VAR_USAGES.clear()
      getattr(dist, stat)()
      assert_no_excessive_var_usage(
          'statistic `{}` of `{}`'.format(stat, dist))
    except NotImplementedError:
      pass

  VAR_USAGES.clear()
  with tf.GradientTape() as tape:
    sample = dist.sample()
  assert_no_excessive_var_usage('method `sample` of `{}`'.format(dist))
  if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
    grads = tape.gradient(sample, dist.variables)
    for grad, var in zip(grads, dist.variables):
      if grad is None:
        raise AssertionError(
            'Missing sample -> {} grad for distribution {}'.format(
                var, dist_name))

  # Turn off validations, since log_prob can choke on dist's own samples.
  dist = dist.copy(validate_args=False)

  if dist_name not in NO_LOG_PROB_PARAM_GRADS:
    with tf.GradientTape() as tape:
      lp = dist.log_prob(tf.stop_gradient(sample))
    grads = tape.gradient(lp, dist.variables)
    for grad, var in zip(grads, dist.variables):
      if grad is None:
        raise AssertionError(
            'Missing log_prob -> {} grad for distribution {}'.format(
                var, dist_name))

  for evaluative in data.draw(
      hps.permutations([
          'log_prob', 'prob', 'log_cdf', 'cdf', 'log_survival_function',
          'survival_function'
      ]))[:3]:
    logging.info('%s.%s', dist_name, evaluative)
    try:
      VAR_USAGES.clear()
      getattr(dist, evaluative)(sample)
      assert_no_excessive_var_usage(
          'evaluative `{}` of `{}`'.format(evaluative, dist),
          max_permissible=1)  # No validation => 1 convert.
    except NotImplementedError:
      pass