def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  if self._num_steps is not None:
    if is_init != tensor_util.is_ref(self._num_steps):
      assertions.append(assert_util.assert_rank(
          self._num_steps, 0,
          message='Argument `num_steps` must be a scalar'))
      assertions.append(assert_util.assert_positive(
          self._num_steps,
          message='Argument `num_steps` must be positive'))
  return assertions
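# The guard `is_init != tensor_util.is_ref(param)` used throughout these
# `_parameter_control_dependencies` methods is an XOR deciding *when* to
# validate. A minimal sketch of the idea (hedged: `tensor_util` is a
# TFP-internal module; `is_ref` returns True for Variable-like,
# reference-semantics arguments):
#   - A plain Tensor cannot change after `__init__`, so it is validated
#     exactly once, at init time (is_init=True, is_ref=False -> check).
#   - A Variable-backed argument can change between calls, so it is
#     re-validated on every method call (is_init=False, is_ref=True -> check).
import tensorflow as tf
from tensorflow_probability.python.internal import tensor_util

for is_init in (True, False):
  for param in (tf.constant(1.), tf.Variable(1.)):
    should_check = is_init != tensor_util.is_ref(param)
    print('is_init={} is_ref={} -> check={}'.format(
        is_init, tensor_util.is_ref(param), should_check))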
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  if is_init != tensor_util.is_ref(self.total_count):
    total_count = tf.convert_to_tensor(self.total_count)
    msg1 = 'Argument `total_count` must be non-negative.'
    msg2 = 'Argument `total_count` cannot contain fractional components.'
    assertions += [
        assert_util.assert_non_negative(total_count, message=msg1),
        distribution_util.assert_integer_form(total_count, message=msg2),
    ]
  for concentration in [self.concentration1, self.concentration0]:
    if is_init != tensor_util.is_ref(concentration):
      assertions.append(
          assert_util.assert_positive(
              concentration,
              message='Concentration parameter must be positive.'))
  return assertions
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  low = None
  high = None
  if is_init != tensor_util.is_ref(self.low):
    low = tf.convert_to_tensor(self.low)
    high = tf.convert_to_tensor(self.high)
    assertions.append(
        assert_util.assert_less(
            low, high,
            message='uniform not defined when `low` >= `high`.'))
  if is_init != tensor_util.is_ref(self.high):
    low = tf.convert_to_tensor(self.low) if low is None else low
    high = tf.convert_to_tensor(self.high) if high is None else high
    assertions.append(
        assert_util.assert_less(
            low, high,
            message='uniform not defined when `low` >= `high`.'))
  return assertions
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  if is_init != tensor_util.is_ref(self._temperature):
    msg1 = 'Argument `temperature` must be positive.'
    temperature = tf.convert_to_tensor(self._temperature)
    assertions.append(assert_util.assert_positive(temperature, message=msg1))
  if self._probs is not None:
    if is_init != tensor_util.is_ref(self._probs):
      probs = tf.convert_to_tensor(self._probs)
      one = tf.constant(1., probs.dtype)
      assertions.extend([
          assert_util.assert_non_negative(
              probs, message='Argument `probs` has components less than 0.'),
          assert_util.assert_less_equal(
              probs, one,
              message='Argument `probs` has components greater than 1.')
      ])
  return assertions
def testVariableParametersArePreserved(self, process_name, data):
  # Check that the process passes Variables through to the accessor
  # properties (without converting them to Tensor or anything like that).
  process = data.draw(
      stochastic_processes(process_name, enable_vars=True))
  self.evaluate([var.initializer for var in process.variables])
  for k, v in six.iteritems(process.parameters):
    if not tensor_util.is_ref(v):
      continue
    self.assertIs(
        getattr(process, k), v,
        'Parameter equivalence assertion failed for parameter `{}`'.format(k))
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  for param_name, param in dict(
      concentration=self.concentration,
      mixing_concentration=self.mixing_concentration,
      mixing_rate=self.mixing_rate).items():
    if is_init != tensor_util.is_ref(param):
      assertions.append(assert_util.assert_positive(
          param,
          message='Argument `{}` must be positive.'.format(param_name)))
  return assertions
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  for c in [self.concentration0_numerator,
            self.concentration1_numerator,
            self.concentration0_denominator,
            self.concentration1_denominator]:
    if is_init != tensor_util.is_ref(c):
      assertions.append(assert_util.assert_positive(
          c, message='`concentration` must be positive.'))
  return assertions
def _parameter_control_dependencies(self, is_init):
  assertions = []
  if is_init:
    try:
      self._batch_shape()
    except ValueError:
      raise ValueError(
          'Arguments `loc`, `scale`, and `rate` must have compatible shapes; '
          'loc.shape={}, scale.shape={}, rate.shape={}.'.format(
              self.loc.shape, self.scale.shape, self.rate.shape))
    # We don't bother checking the shapes in the dynamic case because
    # all member functions access both arguments anyway.

  if not self.validate_args:
    return []

  if is_init != tensor_util.is_ref(self.scale):
    assertions.append(assert_util.assert_positive(
        self.scale, message='Argument `scale` must be positive.'))
  if is_init != tensor_util.is_ref(self.rate):
    assertions.append(assert_util.assert_positive(
        self.rate, message='Argument `rate` must be positive.'))
  return assertions
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  if is_init != tensor_util.is_ref(self.mass):
    assertions.append(
        assert_util.assert_positive(
            self.mass, message='Argument `mass` must be positive.'))
  if is_init != tensor_util.is_ref(self.width):
    assertions.append(
        assert_util.assert_positive(
            self.width, message='Argument `width` must be positive.'))
  if is_init != tensor_util.is_ref(self.smin):
    assertions.append(
        assert_util.assert_non_negative(
            self.smin, message='Argument `smin` must be positive or zero.'))
  if is_init != tensor_util.is_ref(self.smax):
    assertions.append(
        assert_util.assert_greater(
            self.smax, self.smin,
            message='Argument `smax` must be larger than `smin`.'))
  return assertions
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  if is_init != tensor_util.is_ref(self._batch_shape_parameter):
    assertions.append(
        assert_util.assert_rank(
            self._batch_shape_parameter, 1,
            message='Batch shape must be a vector.'))
    assertions.append(
        assert_util.assert_non_negative(
            self._batch_shape_parameter,
            message='Shape elements must be non-negative.'))
  return assertions
def _parameter_control_dependencies(self, is_init):
  assertions = super(Wishart, self)._parameter_control_dependencies(is_init)
  if not self.validate_args:
    assert not assertions
    return []
  if self._scale_full is None:
    if is_init != tensor_util.is_ref(self._scale_tril):
      shape = prefer_static.shape(self._scale_tril)
      assertions.extend(
          [assert_util.assert_positive(
              tf.linalg.diag_part(self._scale_tril),
              message='`scale_tril` must be positive definite.'),
           assert_util.assert_equal(
               shape[-1], shape[-2],
               message='`scale_tril` must be square.')])
  else:
    if is_init != tensor_util.is_ref(self._scale_full):
      assertions.append(distribution_util.assert_symmetric(self._scale_full))
  return assertions
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  tailweight_is_ref = tensor_util.is_ref(self.tailweight)
  tailweight = tf.convert_to_tensor(self.tailweight)
  if (is_init != tailweight_is_ref and
      is_init != tensor_util.is_ref(self.skewness)):
    assertions.append(
        assert_util.assert_less(
            tf.math.abs(self.skewness), tailweight,
            message='Expect `tailweight > |skewness|`'))
  if is_init != tensor_util.is_ref(self.scale):
    assertions.append(
        assert_util.assert_positive(
            self.scale, message='Argument `scale` must be positive.'))
  if is_init != tailweight_is_ref:
    assertions.append(
        assert_util.assert_positive(
            tailweight, message='Argument `tailweight` must be positive.'))
  return assertions
def assert_no_none_grad(bijector, method, wrt_vars, grads):
  for var, grad in zip(wrt_vars, grads):
    if 'log_det_jacobian' in method:
      if tensor_util.is_ref(var):
        # We check tensor_util.is_ref to account for xs/ys being in vars.
        var_name = var.name.rstrip('_0123456789:')
      else:
        var_name = '[arg]'
      to_check = bijector.bijector if is_invert(bijector) else bijector
      if var_name in NO_LDJ_GRADS_EXPECTED.get(type(to_check).__name__, ()):
        continue
    if grad is None:
      raise AssertionError(
          'Missing `{}` -> {} grad for bijector {}'.format(
              method, var, bijector))
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  if self._probs is not None:
    if is_init != tensor_util.is_ref(self._probs):
      probs = tf.convert_to_tensor(self._probs)
      assertions.append(
          assert_util.assert_positive(
              probs, message='Argument `probs` must be positive.'))
      assertions.append(
          assert_util.assert_less_equal(
              probs, dtype_util.as_numpy_dtype(self.dtype)(1.),
              message='Argument `probs` must be less than or equal to 1.'))
  return assertions
def _parameter_control_dependencies(self, is_init):
  assertions = []

  sample_shape = None  # Memoize concretization.

  # Check valid shape.
  ndims_ = tensorshape_util.rank(self.sample_shape.shape)
  if is_init != (ndims_ is None):
    msg = 'Argument `sample_shape` must be either a scalar or a vector.'
    if ndims_ is not None:
      if ndims_ > 1:
        raise ValueError(msg)
    elif self.validate_args:
      if sample_shape is None:
        sample_shape = tf.convert_to_tensor(self.sample_shape)
      assertions.append(
          assert_util.assert_less(tf.rank(sample_shape), 2, message=msg))

  # Check valid dtype.
  if is_init:  # No xor check because `dtype` cannot change.
    dtype_ = self.sample_shape.dtype
    if dtype_ is None:
      if sample_shape is None:
        sample_shape = tf.convert_to_tensor(self.sample_shape)
      dtype_ = sample_shape.dtype
    if dtype_util.base_dtype(dtype_) not in {tf.int32, tf.int64}:
      raise TypeError(
          'Argument `sample_shape` must be integer type; '
          'saw {}.'.format(dtype_util.name(dtype_)))

  # Check valid "value".
  if is_init != tensor_util.is_ref(self.sample_shape):
    sample_shape_ = tf.get_static_value(self.sample_shape)
    msg = 'Argument `sample_shape` must have non-negative values.'
    if sample_shape_ is not None:
      if np.any(np.array(sample_shape_) < 0):
        raise ValueError('{} Saw: {}'.format(msg, sample_shape_))
    elif self.validate_args:
      if sample_shape is None:
        sample_shape = tf.convert_to_tensor(self.sample_shape)
      assertions.append(
          assert_util.assert_greater(sample_shape, -1, message=msg))

  return assertions
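# The "static else dynamic" pattern above -- try `tf.get_static_value`, fail
# fast in Python if the value is known at trace time, otherwise fall back to
# a graph-time assertion -- can be seen in isolation below. This is a sketch
# using public `tf.debugging` ops rather than the internal `assert_util`
# wrappers used in the snippet:
import numpy as np
import tensorflow as tf

def check_non_negative(sample_shape, validate_args=True):
  """Raises eagerly if `sample_shape` is statically known; else defers."""
  msg = 'Argument `sample_shape` must have non-negative values.'
  static_val = tf.get_static_value(sample_shape)
  if static_val is not None:  # Known at trace time: raise immediately.
    if np.any(np.array(static_val) < 0):
      raise ValueError('{} Saw: {}'.format(msg, static_val))
    return []
  if validate_args:  # Only known at runtime: emit a graph assertion.
    return [tf.debugging.assert_greater(
        tf.convert_to_tensor(sample_shape), -1, message=msg)]
  return []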
def _parameter_control_dependencies(self, is_init): """Checks the validity of the concentration parameter.""" assertions = [] # In init, we can always build shape and dtype checks because # we assume shape doesn't change for Variable backed args. if is_init: if not dtype_util.is_floating(self.concentration.dtype): raise TypeError('Argument `concentration` must be float type.') msg = 'Argument `concentration` must have rank at least 1.' ndims = tensorshape_util.rank(self.concentration.shape) if ndims is not None: if ndims < 1: raise ValueError(msg) elif self.validate_args: assertions.append( assert_util.assert_rank_at_least(self.concentration, 1, message=msg)) msg = 'Argument `concentration` must have `event_size` at least 2.' event_size = tf.compat.dimension_value( self.concentration.shape[-1]) if event_size is not None: if event_size < 2: raise ValueError(msg) elif self.validate_args: assertions.append( assert_util.assert_less(1, tf.shape(self.concentration)[-1], message=msg)) if not self.validate_args: assert not assertions # Should never happen. return [] if is_init != tensor_util.is_ref(self.concentration): assertions.append( assert_util.assert_positive( self.concentration, message='Argument `concentration` must be positive.')) return assertions
def _parameter_control_dependencies(self, is_init):
  assertions = super(
      WishartTriL, self)._parameter_control_dependencies(is_init)
  if not self.validate_args:
    assert not assertions
    return []
  if is_init != tensor_util.is_ref(self._scale_tril):
    shape = ps.shape(self._scale_tril)
    assertions.extend([
        assert_util.assert_positive(
            tf.linalg.diag_part(self._scale_tril),
            message='`scale_tril` must be positive definite.'),
        assert_util.assert_equal(
            shape[-1], shape[-2],
            message='`scale_tril` must be square.')
    ])
  return assertions
def assert_no_none_grad(bijector, method, wrt_vars, grads):
  for var, grad in zip(wrt_vars, grads):
    expect_grad = var.dtype not in (tf.int32, tf.int64)
    if 'log_det_jacobian' in method:
      if tensor_util.is_ref(var):
        # We check tensor_util.is_ref to account for xs/ys being in vars.
        var_name = var.name.rstrip('_0123456789:').split('/')[-1]
      else:
        var_name = '[arg]'
      to_check = bijector.bijector if is_invert(bijector) else bijector
      to_check_method = INVERT_LDJ[method] if is_invert(bijector) else method
      if var_name == '[arg]' and bijector.is_constant_jacobian:
        expect_grad = False
      exempt_var_method = NO_LDJ_GRADS_EXPECTED.get(
          type(to_check).__name__, {})
      if to_check_method in exempt_var_method.get(var_name, ()):
        expect_grad = False

    if expect_grad != (grad is not None):
      raise AssertionError('{} `{}` -> {} grad for bijector {}'.format(
          'Missing' if expect_grad else 'Unexpected', method, var, bijector))
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  ok_to_check = lambda x: (  # pylint:disable=g-long-lambda
      x is not None) and (is_init != tensor_util.is_ref(x))

  bias_variance = self.bias_variance
  slope_variance = self.slope_variance

  if ok_to_check(self.exponent):
    exponent = tf.convert_to_tensor(self.exponent)
    assertions.append(
        assert_util.assert_positive(
            exponent, message='`exponent` must be positive.'))
    from tensorflow_probability.python.internal import distribution_util  # pylint: disable=g-import-not-at-top
    assertions.append(
        distribution_util.assert_integer_form(
            exponent, message='`exponent` must be an integer.'))
  if ok_to_check(self.bias_variance):
    bias_variance = tf.convert_to_tensor(self.bias_variance)
    assertions.append(
        assert_util.assert_non_negative(
            bias_variance, message='`bias_variance` must be non-negative.'))
  if ok_to_check(self.slope_variance):
    slope_variance = tf.convert_to_tensor(self.slope_variance)
    assertions.append(
        assert_util.assert_non_negative(
            slope_variance,
            message='`slope_variance` must be non-negative.'))

  if ok_to_check(self.bias_variance) and ok_to_check(self.slope_variance):
    assertions.append(
        assert_util.assert_positive(
            tf.math.abs(slope_variance) + tf.math.abs(bias_variance),
            message=('`slope_variance` and `bias_variance` '
                     'can not both be zero.')))
  return assertions
def base_kernels(draw,
                 kernel_name=None,
                 batch_shape=None,
                 event_dim=None,
                 feature_dim=None,
                 feature_ndims=None,
                 enable_vars=False):
  if kernel_name is None:
    kernel_name = draw(hps.sampled_from(sorted(INSTANTIABLE_BASE_KERNELS)))
  if batch_shape is None:
    batch_shape = draw(tfp_hps.shapes())
  if event_dim is None:
    event_dim = draw(hps.integers(min_value=2, max_value=6))
  if feature_dim is None:
    feature_dim = draw(hps.integers(min_value=2, max_value=6))
  if feature_ndims is None:
    feature_ndims = draw(hps.integers(min_value=2, max_value=6))

  kernel_params = draw(
      broadcasting_params(kernel_name, batch_shape, event_dim=event_dim,
                          enable_vars=enable_vars))
  kernel_variable_names = [
      k for k in kernel_params if tensor_util.is_ref(kernel_params[k])
  ]
  hp.note('Forming kernel {} with constrained parameters {}'.format(
      kernel_name, kernel_params))
  ctor = getattr(tfpk, kernel_name)
  result_kernel = ctor(
      validate_args=True,
      feature_ndims=feature_ndims,
      **kernel_params)
  if batch_shape != result_kernel.batch_shape:
    msg = ('Kernel strategy generated a bad batch shape '
           'for {}, should have been {}.').format(result_kernel, batch_shape)
    raise AssertionError(msg)
  return result_kernel, kernel_variable_names
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    # Avoid computing intermediates needed to construct the assertions.
    return []
  assertions = []
  if is_init != tensor_util.is_ref(self._batch_shape_unexpanded):
    implicit_dim_mask = prefer_static.equal(self._batch_shape_unexpanded, -1)
    assertions.append(assert_util.assert_rank(
        self._batch_shape_unexpanded, 1,
        message='New shape must be a vector.'))
    assertions.append(assert_util.assert_less_equal(
        tf.math.count_nonzero(implicit_dim_mask, dtype=tf.int32), 1,
        message='At most one dimension can be unknown.'))
    assertions.append(assert_util.assert_non_negative(
        self._batch_shape_unexpanded + 1,
        message='Shape elements must be >=-1.'))
    # Check that the old and new shapes are the same size.
    expanded_new_shape, original_size = self._calculate_new_shape()
    new_size = prefer_static.reduce_prod(expanded_new_shape)
    assertions.append(assert_util.assert_equal(
        new_size, tf.cast(original_size, new_size.dtype),
        message='Shape sizes do not match.'))
  return assertions
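# For intuition, the reshape checks above follow NumPy's `-1` placeholder
# semantics: at most one unknown dimension, and it must absorb exactly the
# remaining size. A standalone NumPy sketch of the same size bookkeeping
# (`resolve_new_shape` is a hypothetical helper for illustration, not TFP's
# `_calculate_new_shape`):
import numpy as np

def resolve_new_shape(new_shape, original_size):
  new_shape = np.asarray(new_shape)
  if np.sum(new_shape == -1) > 1:
    raise ValueError('At most one dimension can be unknown.')
  known = np.prod(new_shape[new_shape != -1], dtype=np.int64)
  if np.any(new_shape == -1):
    if original_size % known:
      raise ValueError('Shape sizes do not match.')
    return np.where(new_shape == -1, original_size // known, new_shape)
  if known != original_size:
    raise ValueError('Shape sizes do not match.')
  return new_shape

print(resolve_new_shape([2, -1], original_size=6))  # ==> [2 3]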
def _parameter_control_dependencies(self, is_init):
  assertions = []

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    if not dtype_util.is_floating(self.cutpoints.dtype):
      raise TypeError('Argument `cutpoints` must have floating type.')
    if not dtype_util.is_floating(self.loc.dtype):
      raise TypeError('Argument `loc` must have floating type.')

    cutpoint_dims = tensorshape_util.rank(self.cutpoints.shape)
    msg = 'Argument `cutpoints` must have rank at least 1.'
    if cutpoint_dims is not None:
      if cutpoint_dims < 1:
        raise ValueError(msg)
    elif self.validate_args:
      cutpoints = tf.convert_to_tensor(self.cutpoints)
      assertions.append(
          assert_util.assert_rank_at_least(cutpoints, 1, message=msg))

  if not self.validate_args:
    return []

  if is_init != tensor_util.is_ref(self.cutpoints):
    cutpoints = tf.convert_to_tensor(self.cutpoints)
    assertions.append(
        distribution_util.assert_nondecreasing(
            cutpoints,
            message='Argument `cutpoints` must be non-decreasing.'))

  return assertions
def _parameter_control_dependencies(self, is_init):
  assertions = []
  if is_init != tensor_util.is_ref(self.permutation):
    if not dtype_util.is_integer(self.permutation.dtype):
      raise TypeError('permutation.dtype ({}) should be `int`-like.'.format(
          dtype_util.name(self.permutation.dtype)))

    p = tf.get_static_value(self.permutation)
    if p is not None:
      if set(p) != set(np.arange(p.size)):
        raise ValueError('Permutation over `d` must contain exactly one of '
                         'each of `{0, 1, ..., d-1}`.')
    if self.validate_args:
      p = tf.sort(self.permutation, axis=-1)
      assertions.append(
          assert_util.assert_equal(
              p, tf.range(tf.shape(p)[-1]),
              message=('Permutation over `d` must contain exactly one of '
                       'each of `{0, 1, ..., d-1}`.')))
  return assertions
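# The runtime branch above relies on a simple fact: sorting a valid
# permutation of `{0, 1, ..., d-1}` along the last axis yields exactly
# `range(d)`. A standalone illustration with public TF ops:
import tensorflow as tf

def is_permutation(p):
  """True iff the last axis of `p` is a permutation of 0..d-1."""
  sorted_p = tf.sort(p, axis=-1)
  return tf.reduce_all(
      tf.equal(sorted_p, tf.range(tf.shape(p)[-1])), axis=-1)

print(is_permutation(tf.constant([2, 0, 1])))  # ==> True
print(is_permutation(tf.constant([2, 2, 1])))  # ==> False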
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  peak = tf.convert_to_tensor(self.peak)
  assertions = []
  if (is_init != tensor_util.is_ref(self.low) and
      is_init != tensor_util.is_ref(self.high)):
    assertions.append(assert_util.assert_less(
        low, high,
        message='triangular not defined when low >= high.'))
  if (is_init != tensor_util.is_ref(self.low) and
      is_init != tensor_util.is_ref(self.peak)):
    assertions.append(
        assert_util.assert_less_equal(
            low, peak,
            message='triangular not defined when low > peak.'))
  if (is_init != tensor_util.is_ref(self.high) and
      is_init != tensor_util.is_ref(self.peak)):
    assertions.append(
        assert_util.assert_less_equal(
            peak, high,
            message='triangular not defined when peak > high.'))
  return assertions
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  low = None
  high = None
  if is_init != tensor_util.is_ref(self.low):
    low = tf.convert_to_tensor(self.low)
    assertions.append(
        assert_util.assert_finite(low, message='`low` is not finite'))
  if is_init != tensor_util.is_ref(self.high):
    high = tf.convert_to_tensor(self.high)
    assertions.append(
        assert_util.assert_finite(high, message='`high` is not finite'))
  if is_init != tensor_util.is_ref(self.loc):
    assertions.append(
        assert_util.assert_finite(self.loc, message='`loc` is not finite'))
  if is_init != tensor_util.is_ref(self.scale):
    scale = tf.convert_to_tensor(self.scale)
    assertions.extend([
        assert_util.assert_positive(
            scale, message='`scale` must be positive'),
        assert_util.assert_finite(scale, message='`scale` is not finite'),
    ])
  if (is_init != tensor_util.is_ref(self.low) or
      is_init != tensor_util.is_ref(self.high)):
    low = tf.convert_to_tensor(self.low) if low is None else low
    high = tf.convert_to_tensor(self.high) if high is None else high
    assertions.append(
        assert_util.assert_greater(
            high, low,
            message='TruncatedCauchy not defined when `low >= high`.'))
  return assertions
def testDistribution(self, dist_name, data):
  seed = test_util.test_seed()
  # Explicitly draw event_dim here to avoid relying on _params_event_ndims
  # later, so this test can support distributions that do not implement the
  # slicing protocol.
  event_dim = data.draw(hps.integers(min_value=2, max_value=6))
  dist = data.draw(dhps.distributions(
      dist_name=dist_name, event_dim=event_dim, enable_vars=True))
  batch_shape = dist.batch_shape
  batch_shape2 = data.draw(tfp_hps.broadcast_compatible_shape(batch_shape))
  dist2 = data.draw(
      dhps.distributions(
          dist_name=dist_name,
          batch_shape=batch_shape2,
          event_dim=event_dim,
          enable_vars=True))
  self.evaluate([var.initializer for var in dist.variables])

  # Check that the distribution passes Variables through to the accessor
  # properties (without converting them to Tensor or anything like that).
  for k, v in six.iteritems(dist.parameters):
    if not tensor_util.is_ref(v):
      continue
    self.assertIs(getattr(dist, k), v)

  # Check that standard statistics do not read distribution parameters more
  # than twice (once in the stat itself and up to once in any validation
  # assertions).
  max_permissible = 2 + extra_tensor_conversions_allowed(dist)
  for stat in sorted(data.draw(
      hps.sets(
          hps.one_of(
              map(hps.just, [
                  'covariance', 'entropy', 'mean', 'mode', 'stddev',
                  'variance'
              ])),
          min_size=3,
          max_size=3))):
    hp.note('Testing excessive var usage in {}.{}'.format(dist_name, stat))
    try:
      with tfp_hps.assert_no_excessive_var_usage(
          'statistic `{}` of `{}`'.format(stat, dist),
          max_permissible=max_permissible):
        getattr(dist, stat)()
    except NotImplementedError:
      pass

  # Check that `sample` doesn't read distribution parameters more than
  # twice, and that it produces non-None gradients (if the distribution is
  # fully reparameterized).
  with tf.GradientTape() as tape:
    # TDs do bijector assertions twice (once by distribution.sample, and
    # once by bijector.forward).
    max_permissible = 2 + extra_tensor_conversions_allowed(dist)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `sample` of `{}`'.format(dist),
        max_permissible=max_permissible):
      sample = dist.sample(seed=seed)
  if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
    grads = tape.gradient(sample, dist.variables)
    for grad, var in zip(grads, dist.variables):
      var_name = var.name.rstrip('_0123456789:')
      if var_name in NO_SAMPLE_PARAM_GRADS.get(dist_name, ()):
        continue
      if grad is None:
        raise AssertionError(
            'Missing sample -> {} grad for distribution {}'.format(
                var_name, dist_name))

  # Turn off validations, since TODO(b/129271256) log_prob can choke on
  # dist's own samples. Also, to relax conversion counts for KL (might do
  # >2 w/ validate_args).
  dist = dist.copy(validate_args=False)
  dist2 = dist2.copy(validate_args=False)

  # Test that KL divergence reads distribution parameters at most once, and
  # that it produces non-None gradients.
  try:
    for d1, d2 in (dist, dist2), (dist2, dist):
      with tf.GradientTape() as tape:
        with tfp_hps.assert_no_excessive_var_usage(
            '`kl_divergence` of (`{}` (vars {}), `{}` (vars {}))'.format(
                d1, d1.variables, d2, d2.variables),
            max_permissible=1):  # No validation => 1 convert per var.
          kl = d1.kl_divergence(d2)
      wrt_vars = list(d1.variables) + list(d2.variables)
      grads = tape.gradient(kl, wrt_vars)
      for grad, var in zip(grads, wrt_vars):
        if grad is None and dist_name not in NO_KL_PARAM_GRADS:
          raise AssertionError(
              'Missing KL({} || {}) -> {} grad:\n'
              '{} vars: {}\n{} vars: {}'.format(
                  d1, d2, var, d1, d1.variables, d2, d2.variables))
  except NotImplementedError:
    pass

  # Test that log_prob produces non-None gradients, except for
  # distributions on the NO_LOG_PROB_PARAM_GRADS blacklist.
  if dist_name not in NO_LOG_PROB_PARAM_GRADS:
    with tf.GradientTape() as tape:
      lp = dist.log_prob(tf.stop_gradient(sample))
    grads = tape.gradient(lp, dist.variables)
    for grad, var in zip(grads, dist.variables):
      if grad is None:
        raise AssertionError(
            'Missing log_prob -> {} grad for distribution {}'.format(
                var, dist_name))

  # Test that all forms of probability evaluation avoid reading
  # distribution parameters more than once.
  for evaluative in sorted(data.draw(
      hps.sets(
          hps.one_of(
              map(hps.just, [
                  'log_prob', 'prob', 'log_cdf', 'cdf',
                  'log_survival_function', 'survival_function'
              ])),
          min_size=3,
          max_size=3))):
    hp.note('Testing excessive var usage in {}.{}'.format(
        dist_name, evaluative))
    try:
      # No validation => 1 convert. But for TD we allow 2:
      # dist.log_prob(bijector.inverse(samp)) + bijector.ildj(samp)
      max_permissible = 2 + extra_tensor_conversions_allowed(dist)
      with tfp_hps.assert_no_excessive_var_usage(
          'evaluative `{}` of `{}`'.format(evaluative, dist),
          max_permissible=max_permissible):
        getattr(dist, evaluative)(sample)
    except NotImplementedError:
      pass
def _maybe_validate_shape_override(self, override_shape, base_is_scalar_fn,
                                   static_base_shape, is_init):
  """Helper which ensures override batch/event_shape are valid."""
  assertions = []
  concretized_shape = None

  # Check valid dtype
  if is_init:  # No xor check because `dtype` cannot change.
    dtype_ = override_shape.dtype
    if dtype_ is None:
      if concretized_shape is None:
        concretized_shape = tf.convert_to_tensor(override_shape)
      dtype_ = concretized_shape.dtype
    if dtype_util.base_dtype(dtype_) not in {tf.int32, tf.int64}:
      raise TypeError('Shape override must be integer type; '
                      'saw {}.'.format(dtype_util.name(dtype_)))

  # Check non-negative elements
  if is_init != tensor_util.is_ref(override_shape):
    override_shape_ = tf.get_static_value(override_shape)
    msg = 'Shape override must have non-negative elements.'
    if override_shape_ is not None:
      if np.any(np.array(override_shape_) < 0):
        raise ValueError('{} Saw: {}'.format(msg, override_shape_))
    elif self.validate_args:
      if concretized_shape is None:
        concretized_shape = tf.convert_to_tensor(override_shape)
      assertions.append(
          assert_util.assert_non_negative(concretized_shape, message=msg))

  # Check valid shape
  override_ndims_ = tensorshape_util.rank(override_shape.shape)
  if is_init != (override_ndims_ is None):
    msg = 'Shape override must be a vector.'
    if override_ndims_ is not None:
      if override_ndims_ != 1:
        raise ValueError(msg)
    elif self.validate_args:
      if concretized_shape is None:
        concretized_shape = tf.convert_to_tensor(override_shape)
      override_rank = tf.rank(concretized_shape)
      assertions.append(
          assert_util.assert_equal(override_rank, 1, message=msg))

  static_base_rank = tensorshape_util.rank(static_base_shape)
  # Determine if the override shape is `[]` (static_override_dims == [0]),
  # in which case the base distribution may be nonscalar.
  static_override_dims = tensorshape_util.dims(override_shape.shape)
  if is_init != (static_base_rank is None or static_override_dims is None):
    msg = 'Base distribution is not scalar.'
    if static_base_rank is not None and static_override_dims is not None:
      if static_base_rank != 0 and static_override_dims != [0]:
        raise ValueError(msg)
    elif self.validate_args:
      if concretized_shape is None:
        concretized_shape = tf.convert_to_tensor(override_shape)
      override_is_empty = tf.logical_not(
          self._has_nonzero_rank(concretized_shape))
      assertions.append(
          assert_util.assert_equal(
              tf.logical_or(base_is_scalar_fn(), override_is_empty),
              True, message=msg))
  return assertions
def _parameter_control_dependencies(self, is_init):
  assertions = []

  logits = self._logits
  probs = self._probs
  param, name = (probs, 'probs') if logits is None else (logits, 'logits')

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    if not dtype_util.is_floating(param.dtype):
      raise TypeError(
          'Argument `{}` must have floating type.'.format(name))

    msg = 'Argument `{}` must have rank at least 1.'.format(name)
    shape_static = tensorshape_util.dims(param.shape)
    if shape_static is not None:
      if len(shape_static) < 1:
        raise ValueError(msg)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_rank_at_least(param, 1, message=msg))
      with tf.control_dependencies(assertions):
        param = tf.identity(param)

    msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
    msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
        name, dtype_util.max(tf.int32))
    event_size = shape_static[-1] if shape_static is not None else None
    if event_size is not None:
      if event_size < 1:
        raise ValueError(msg1)
      if event_size > dtype_util.max(tf.int32):
        raise ValueError(msg2)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_greater_equal(
              tf.shape(param)[-1], 1, message=msg1))
      # NOTE: For now, we leave out a runtime assertion that
      # `tf.shape(param)[-1] <= tf.int32.max`. An earlier `tf.shape` call
      # will fail before we get to this point.

  if not self.validate_args:
    assert not assertions  # Should never happen.
    return []

  if probs is not None:
    probs = param  # reuse tensor conversion from above
    if is_init != tensor_util.is_ref(probs):
      probs = tf.convert_to_tensor(probs)
      one = tf.ones([], dtype=probs.dtype)
      assertions.extend([
          assert_util.assert_non_negative(probs),
          assert_util.assert_less_equal(probs, one),
          assert_util.assert_near(
              tf.reduce_sum(probs, axis=-1), one,
              message='Argument `probs` must sum to 1.'),
      ])

  return assertions
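# The final block above pins `probs` to the probability simplex
# (non-negative, at most 1, summing to 1 along the last axis). A minimal
# standalone equivalent using public `tf.debugging` ops (the `assert_util`
# wrappers above are TFP internals):
import tensorflow as tf

probs = tf.constant([[0.2, 0.3, 0.5],
                     [0.1, 0.1, 0.8]])
one = tf.ones([], dtype=probs.dtype)
checks = [
    tf.debugging.assert_non_negative(probs),
    tf.debugging.assert_less_equal(probs, one),
    tf.debugging.assert_near(tf.reduce_sum(probs, axis=-1), one,
                             message='Argument `probs` must sum to 1.'),
]
with tf.control_dependencies(checks):
  # In graph mode, ops derived from `probs` now depend on the checks.
  probs = tf.identity(probs)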
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  if is_init != any(tensor_util.is_ref(v) for v in self.scale.variables):
    return [self.scale.assert_non_singular()]
  return []
def as_composite(obj):
  """Returns a `CompositeTensor` equivalent to the given object.

  Note that the returned object will have any `Variable`,
  `tfp.util.DeferredTensor`, or `tfp.util.TransformedVariable` references it
  closes over converted to tensors at the time this function is called. The
  type of the returned object will be a subclass of both `CompositeTensor`
  and `type(obj)`.

  For this reason, one should be careful about using `as_composite()`,
  especially for `tf.Module` objects. For example, even when the composite
  tensor is created as part of a `tf.Module`, it "fixes" the values of the
  `DeferredTensor` and `tf.Variable` objects it uses:

  ```python
  class M(tf.Module):
    def __init__(self):
      self._v = tf.Variable(1.)
      self._d = tfp.distributions.Normal(
          tfp.util.DeferredTensor(self._v, lambda v: v + 1), 10)
      self._dct = tfp.experimental.as_composite(self._d)

    @tf.function
    def mean(self):
      return self._dct.mean()

  m = M()
  m.mean()
  # ==> <tf.Tensor: numpy=2.0>
  m._v.assign(2.)  # Doesn't update the CompositeTensor distribution.
  m.mean()
  # ==> <tf.Tensor: numpy=2.0>
  ```

  If, however, the creation of the composite is deferred to a method call,
  then the Variable and DeferredTensor will be properly captured and
  respected by the Module and its `SavedModel` (if it is serialized).

  ```python
  class M(tf.Module):
    def __init__(self):
      self._v = tf.Variable(1.)
      self._d = tfp.distributions.Normal(
          tfp.util.DeferredTensor(self._v, lambda v: v + 1), 10)

    @tf.function
    def d(self):
      return tfp.experimental.as_composite(self._d)

  m = M()
  m.d().mean()
  # ==> <tf.Tensor: numpy=2.0>
  m._v.assign(2.)
  m.d().mean()
  # ==> <tf.Tensor: numpy=3.0>
  ```

  Note: This method is best-effort and based on a heuristic for what the
  tensor parameters are and what the non-tensor parameters are. Things might
  be broken, especially for meta-distributions like `TransformedDistribution`
  or `Independent`. (We try to raise NotImplementedError in such cases.) If
  you'd benefit from better coverage, please file an issue on github or send
  an email to `tfprobability@tensorflow.org`.

  Args:
    obj: A `tfp.distributions.Distribution`.

  Returns:
    obj: A `tfp.distributions.Distribution` that extends `CompositeTensor`.
  """
  if isinstance(obj, CompositeTensor):
    return obj
  cls = _make_convertible(type(obj))
  kwargs = dict(obj.parameters)

  def mk_err_msg(suffix=''):
    return ('Unable to make a CompositeTensor for "{}" of type `{}`. Email '
            '`tfprobability@tensorflow.org` or file an issue on github if '
            'you would benefit from this working. {}'.format(
                obj, type(obj), suffix))

  try:
    composite_tensor_params = obj._composite_tensor_params  # pylint: disable=protected-access
  except (AttributeError, NotImplementedError):
    composite_tensor_params = ()
  for k in composite_tensor_params:
    # Use dtype inference from ctor.
    if k in kwargs and kwargs[k] is not None:
      v = getattr(obj, k, kwargs[k])
      try:
        kwargs[k] = tf.convert_to_tensor(v, name=k)
      except (ValueError, TypeError) as e:
        kwargs[k] = v
  for k, v in kwargs.items():

    def composite_helper(v):
      # If we have a parameters attribute, then we may be able to convert to
      # a composite tensor by guessing which of the parameters are tensors.
      # In essence, we duck-type based on this attribute.
      if hasattr(v, 'parameters'):
        return as_composite(v)
      return v

    kwargs[k] = tf.nest.map_structure(composite_helper, v)
    # Unfortunately, tensor_util.is_ref(v) returns true for a
    # tf.linalg.LinearOperator even though that is not ideal behavior.
    if tensor_util.is_ref(v) and not isinstance(v, tf.linalg.LinearOperator):
      try:
        kwargs[k] = tf.convert_to_tensor(v, name=k)
      except TypeError as e:
        raise NotImplementedError(
            mk_err_msg('(Unable to convert dependent entry \'{}\' of object '
                       '\'{}\': {})'.format(k, obj, str(e))))
  result = cls(**kwargs)
  struct_coder = nested_structure_coder.StructureCoder()
  try:
    struct_coder.encode_structure(result._type_spec)  # pylint: disable=protected-access
  except nested_structure_coder.NotEncodableError as e:
    raise NotImplementedError(
        mk_err_msg('(Unable to serialize: {})'.format(str(e))))
  return result
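# A short usage sketch for `as_composite`: once converted, the distribution
# is a `CompositeTensor` and can be passed across a `tf.function` boundary
# as a structured argument instead of a Python closure. (Hedged: this
# assumes the `tfp.experimental.as_composite` export described in the
# docstring above.)
import tensorflow as tf
import tensorflow_probability as tfp

d = tfp.distributions.Normal(loc=0., scale=1.)
dct = tfp.experimental.as_composite(d)

@tf.function
def entropy(dist):
  return dist.entropy()

print(entropy(dct))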