def _parameter_control_dependencies(self, is_init):
    """Build assertions enforcing `low < high`, `low <= peak <= high`.

    Each ordering constraint involves two parameters. It is built at
    construction time (`is_init=True`) only when neither parameter is a
    reference value (`tensor_util.is_ref`, e.g. Variable-backed). It is built at
    call time (`is_init=False`) whenever at least one of the two is a
    reference, since such a value may have changed after construction.

    Args:
      is_init: Python `bool`; True when called from the constructor.

    Returns:
      A list of assertion ops; empty when `validate_args` is False.
    """
    if not self.validate_args:
        return []

    low = tf.convert_to_tensor(self.low)
    high = tf.convert_to_tensor(self.high)
    peak = tf.convert_to_tensor(self.peak)

    low_is_ref = tensor_util.is_ref(self.low)
    high_is_ref = tensor_util.is_ref(self.high)
    peak_is_ref = tensor_util.is_ref(self.peak)

    assertions = []
    # Use `or`, not `and`: a pair with one mutable and one fixed parameter
    # must still be re-checked at call time, otherwise it is never checked.
    if is_init != (low_is_ref or high_is_ref):
        assertions.append(
            assert_util.assert_less(
                low, high, message='triangular not defined when low >= high.'))
    if is_init != (low_is_ref or peak_is_ref):
        assertions.append(
            assert_util.assert_less_equal(
                low, peak, message='triangular not defined when low > peak.'))
    if is_init != (high_is_ref or peak_is_ref):
        assertions.append(
            assert_util.assert_less_equal(
                peak, high, message='triangular not defined when peak > high.'))
    return assertions
def maybe_assert_continuous_bernoulli_param_correctness(
        is_init, validate_args, probs, logits):
    """Build validity assertions for a `Bernoulli`-style parameterization.

    At construction time (`is_init=True`) the active parameter (`logits` if
    given, else `probs`) must have a floating dtype; otherwise `TypeError` is
    raised. When `validate_args` is set and `probs` is provided, runtime checks
    that `probs` lies in `[0, 1]` are returned. They are built when `is_init`
    differs from whether `probs` is a reference value.

    Args:
      is_init: Python `bool`; True when called from a constructor.
      validate_args: Python `bool`; whether to build runtime assertions.
      probs: Probability parameter, or `None`.
      logits: Logit parameter, or `None`.

    Returns:
      A (possibly empty) list of assertion ops.
    """
    if is_init:
        if logits is None:
            active, active_name = probs, 'probs'
        else:
            active, active_name = logits, 'logits'
        if not dtype_util.is_floating(active.dtype):
            raise TypeError(
                'Argument `{}` must having floating type.'.format(active_name))

    if not validate_args:
        return []

    # Nothing to check when `probs` is absent, or when this call is not the
    # one responsible for it (see the docstring).
    if probs is None or is_init == tensor_util.is_ref(probs):
        return []

    probs = tf.convert_to_tensor(probs)
    upper = tf.constant(1.0, probs.dtype)
    return [
        assert_util.assert_non_negative(
            probs, message='probs has components less than 0.'),
        assert_util.assert_less_equal(
            probs, upper, message='probs has components greater than 1.'),
    ]
def _inverse(self, y):
    """Map a point `y` on the simplex back to unconstrained coordinates.

    With `y[i] = exp(x[i]) / normalization` and `y[end] = 1 / normalization`,
    the inverse is `x[i] = log(y[i]) - log(y[end])`: the normalization
    cancels between the two logs.

    When `validate_args` is set, `y` is checked to be non-negative, at most 1
    elementwise, and to sum to 1 along the last axis.
    """
    checks = []
    if self.validate_args:
        unit = tf.ones([], y.dtype)
        sum_tolerance = 2. * np.finfo(dtype_util.as_numpy_dtype(y.dtype)).eps
        checks = [
            assert_util.assert_near(
                tf.reduce_sum(y, axis=-1), unit, sum_tolerance,
                message='Last dimension of `y` must sum to `1`.'),
            assert_util.assert_less_equal(
                y, unit,
                message='Elements of `y` must be less than or equal to `1`.'),
            assert_util.assert_non_negative(
                y, message='Elements of `y` must be non-negative.'),
        ]

    # `log(y)` is also computed in `_inverse_log_det_jacobian`; taking it
    # first here lets common-subexpression elimination share it.
    with tf.control_dependencies(checks):
        log_y = tf.math.log(y)
    log_y_leading, log_y_last = tf.split(
        log_y, num_or_size_splits=[-1, 1], axis=-1)
    return log_y_leading - log_y_last
def _parameter_control_dependencies(self, is_init):
    """Build assertions on `mean_direction` and `concentration`.

    `mean_direction` must have last dimension in `(1, 5]` and unit norm along
    that axis; `concentration` must be non-negative. Each parameter's checks
    are built when `is_init` differs from whether it is a reference value.

    Returns:
      A list of assertion ops; empty when `validate_args` is False.
    """
    if not self.validate_args:
        return []

    mean_direction = tf.convert_to_tensor(self.mean_direction)
    concentration = tf.convert_to_tensor(self.concentration)

    checks = []
    if is_init != tensor_util.is_ref(self._mean_direction):
        event_size = tf.shape(mean_direction)[-1]
        checks.extend([
            assert_util.assert_greater(
                event_size, 1,
                message='`mean_direction` may not have scalar event shape'),
            assert_util.assert_less_equal(
                event_size, 5,
                message='von Mises-Fisher ndims > 5 is not currently supported'),
            assert_util.assert_near(
                1., tf.linalg.norm(mean_direction, axis=-1),
                message='`mean_direction` must be unit-length'),
        ])
    if is_init != tensor_util.is_ref(self._concentration):
        checks.append(
            assert_util.assert_non_negative(
                concentration, message='`concentration` must be non-negative'))
    return checks
def calculate_reshape(original_shape, new_shape, validate=False, name=None):
    """Resolve a reshape target that may contain a single `-1` placeholder.

    If `new_shape` is statically fully known, it is returned directly as an
    `np.int32` array. Otherwise the `-1` entry is replaced by
    `prod(original_shape) // prod(explicit dims of new_shape)`.

    Args:
      original_shape: 1-D shape being reshaped from.
      new_shape: 1-D target shape; may contain one `-1`.
      validate: Python `bool`; if True, also return runtime assertions.
      name: Optional name scope.

    Returns:
      `(expanded_new_shape, batch_shape_static, validations)`, where
      `validations` is a list of assertion ops (empty unless `validate`).
    """
    static_new_shape = tensorshape_util.constant_value_as_shape(new_shape)
    if tensorshape_util.is_fully_defined(static_new_shape):
        return np.int32(static_new_shape), static_new_shape, []

    with tf.name_scope(name or 'calculate_reshape'):
        original_size = tf.reduce_prod(original_shape)
        is_placeholder = tf.equal(new_shape, -1)
        # With exactly one `-1`, the product of `new_shape` is minus the
        # product of the explicit dims. Clamping to 1 keeps the divisor
        # positive when no `-1` is present (the result is then unused).
        explicit_size = tf.maximum(1, -tf.reduce_prod(new_shape))
        inferred_dim = original_size // explicit_size
        # Assumes exactly one `-1`.
        expanded_new_shape = tf.where(is_placeholder, inferred_dim, new_shape)

        if validate:
            validations = [
                assert_util.assert_rank(
                    original_shape, 1,
                    message='Original shape must be a vector.'),
                assert_util.assert_rank(
                    new_shape, 1, message='New shape must be a vector.'),
                assert_util.assert_less_equal(
                    tf.math.count_nonzero(is_placeholder, dtype=tf.int32), 1,
                    message='At most one dimension can be unknown.'),
                assert_util.assert_positive(
                    expanded_new_shape,
                    message='Shape elements must be >=-1.'),
                assert_util.assert_equal(
                    tf.reduce_prod(expanded_new_shape), original_size,
                    message='Shape sizes do not match.'),
            ]
        else:
            validations = []
        return expanded_new_shape, static_new_shape, validations
def maybe_assert_negative_binomial_param_correctness(
        is_init, validate_args, total_count, probs, logits):
    """Build validity assertions for `NegativeBinomial`-style parameters.

    At construction time the active parameter (`logits` if given, else
    `probs`) must be floating point; otherwise `TypeError` is raised. With
    `validate_args`, runtime checks are returned for:
    - `total_count`: positive and integer-valued;
    - `probs` (if given): within `[0, 1]`.
    Each parameter's checks are built when `is_init` differs from whether it
    is a reference value.

    Returns:
      A (possibly empty) list of assertion ops.
    """
    if is_init:
        use_probs = logits is None
        active = probs if use_probs else logits
        active_name = 'probs' if use_probs else 'logits'
        if not dtype_util.is_floating(active.dtype):
            raise TypeError(
                'Argument `{}` must having floating type.'.format(active_name))

    if not validate_args:
        return []

    checks = []
    if is_init != tensor_util.is_ref(total_count):
        total_count = tf.convert_to_tensor(total_count)
        checks += [
            assert_util.assert_positive(
                total_count,
                message='`total_count` has components less than or equal to 0.'),
            distribution_util.assert_integer_form(
                total_count,
                message='`total_count` has fractional components.'),
        ]

    if probs is not None and is_init != tensor_util.is_ref(probs):
        probs = tf.convert_to_tensor(probs)
        upper = tf.constant(1., probs.dtype)
        checks += [
            assert_util.assert_non_negative(
                probs, message='`probs` has components less than 0.'),
            assert_util.assert_less_equal(
                probs, upper, message='`probs` has components greater than 1.'),
        ]
    return checks
def _parameter_control_dependencies(self, is_init):
    """Validate the `logits`/`probs` parameter and return runtime assertions.

    At construction time (`is_init=True`), static problems raise immediately:
    - a non-floating dtype, or
    - a statically known rank of 0, or
    - a statically known final dimension below 1 or above the int32 maximum.
    When rank or final dimension is not statically known and `validate_args`
    is set, the equivalent checks are emitted as runtime assertions instead.
    With `validate_args`, `probs` (if set) is also checked to be non-negative,
    at most 1, and to sum to 1 along the last axis.

    Args:
      is_init: Python `bool`; True when called from the constructor.

    Returns:
      A list of assertion ops; empty when `validate_args` is False.

    Raises:
      TypeError: at init, if the parameter dtype is not floating point.
      ValueError: at init, on a statically invalid rank or final dimension.
    """
    assertions = []

    logits = self._logits
    probs = self._probs
    # Validate whichever parameterization is active; `logits` takes precedence
    # whenever it is set.
    param, name = (probs, 'probs') if logits is None else (logits, 'logits')

    # In init, we can always build shape and dtype checks because
    # we assume shape doesn't change for Variable backed args.
    if is_init:
        if not dtype_util.is_floating(param.dtype):
            raise TypeError('Argument `{}` must having floating type.'.format(name))

        msg = 'Argument `{}` must have rank at least 1.'.format(name)
        shape_static = tensorshape_util.dims(param.shape)
        if shape_static is not None:
            # Rank is statically known: fail fast on a scalar parameter.
            if len(shape_static) < 1:
                raise ValueError(msg)
        elif self.validate_args:
            # Rank is unknown statically: check it at runtime. Routing `param`
            # through `tf.identity` under the assertion makes later ops built
            # on it (e.g. `tf.shape(param)` below) depend on the rank check.
            param = tf.convert_to_tensor(param)
            assertions.append(
                assert_util.assert_rank_at_least(param, 1, message=msg))
            with tf.control_dependencies(assertions):
                param = tf.identity(param)

        msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
        msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
            name, dtype_util.max(tf.int32))
        # `event_size` is None when the rank or the final dimension is not
        # statically known.
        event_size = shape_static[-1] if shape_static is not None else None
        if event_size is not None:
            if event_size < 1:
                raise ValueError(msg1)
            if event_size > dtype_util.max(tf.int32):
                raise ValueError(msg2)
        elif self.validate_args:
            param = tf.convert_to_tensor(param)
            assertions.append(assert_util.assert_greater_equal(
                tf.shape(param)[-1], 1, message=msg1))
            # NOTE: For now, we leave out a runtime assertion that
            # `tf.shape(param)[-1] <= tf.int32.max`. An earlier `tf.shape` call
            # will fail before we get to this point.

    if not self.validate_args:
        # Every append above is guarded by `self.validate_args`, so nothing
        # can have been collected on this path.
        assert not assertions  # Should never happen.
        return []

    if probs is not None:
        # NOTE(review): `param` is `probs` only when `logits` is None; this
        # relies on `logits` and `probs` never both being set — confirm the
        # constructor enforces that.
        probs = param  # reuse tensor conversion from above
        # NOTE(review): if `param` was converted via `tf.convert_to_tensor`
        # above, `is_ref` inspects the converted tensor rather than
        # `self._probs`.
        if is_init != tensor_util.is_ref(probs):
            probs = tf.convert_to_tensor(probs)
            one = tf.ones([], dtype=probs.dtype)
            assertions.extend([
                assert_util.assert_non_negative(probs),
                assert_util.assert_less_equal(probs, one),
                assert_util.assert_near(
                    tf.reduce_sum(probs, axis=-1), one,
                    message='Argument `probs` must sum to 1.'),
            ])

    return assertions
def _prob(self, x):
    """Triangular density at `x`.

    For `low <= x <= high` the density is piecewise linear:
    - it rises from 0 at `low` to `2 / (high - low)` at `peak`;
    - it then falls back to 0 at `high`.
    It is 0 outside `[low, high]`. With `validate_args`, `x` is asserted to
    lie in `[low, high]`.
    """
    if self.validate_args:
        support_checks = [
            assert_util.assert_greater_equal(x, self.low),
            assert_util.assert_less_equal(x, self.high),
        ]
        with tf.control_dependencies(support_checks):
            x = tf.identity(x)

    low, high, peak = self.low, self.high, self.peak

    x_like_high = _broadcast_to(x, [high])
    # NOTE: the strict `> low` means that at x == low (with low < peak) the
    # falling-edge formula is selected rather than the rising one.
    on_rising_edge = tf.logical_and(x_like_high > low, x_like_high <= peak)
    width = high - low
    # Segment from (low, 0) to (peak, 2 / width).
    rising = 2. * (x - low) / (width * (peak - low))
    # Segment from (peak, 2 / width) to (high, 0).
    falling = 2. * (high - x) / (width * (high - peak))
    density_in_interval = tf.where(on_rising_edge, rising, falling)

    x_like_peak = _broadcast_to(x, [peak])
    outside = tf.logical_or(x_like_peak < low, x_like_peak > high)
    out_shape = tf.broadcast_dynamic_shape(
        tf.shape(input=x), self.batch_shape_tensor())
    return tf.where(
        outside, tf.zeros(out_shape, dtype=self.dtype), density_in_interval)
def _parameter_control_dependencies(self, is_init):
    """Validate `scale` and `df` and return runtime assertions.

    At construction time, `scale` must be floating point and square, and
    share its float dtype with `df`. If both `df` and the scale dimension are
    statically known at init, `df < dimension` raises `ValueError`
    immediately. Otherwise, with `validate_args`, a runtime
    `dimension <= df` assertion is added. This happens when `is_init` differs
    from whether `df` or `scale` is a reference value.
    """
    checks = []
    if is_init:
        scale_dtype = self._scale.dtype
        if not dtype_util.is_floating(scale_dtype):
            raise TypeError(
                'scale.dtype={} is not a floating-point type.'.format(
                    scale_dtype))
        if not self._scale.is_square:
            raise ValueError('scale must be square.')
        dtype_util.assert_same_float_dtype([self._df, self._scale])

    static_df = tf.get_static_value(self._df)
    static_dim = tf.compat.dimension_value(self._scale.shape[-1])
    msg = ('Degrees of freedom (`df = {}`) cannot be less than dimension of '
           'scale matrix (`scale.dimension = {}`).')

    if is_init and static_df is not None and static_dim is not None:
        # Promote scalars to length-1 arrays; higher-rank values pass through.
        static_df = np.atleast_1d(static_df)
        static_dim = np.atleast_1d(static_dim)
        if np.any(static_df < static_dim):
            raise ValueError(msg.format(static_df, static_dim))
    elif self.validate_args:
        df_is_ref = tensor_util.is_ref(self._df)
        scale_is_ref = tensor_util.is_ref(self._scale)
        if is_init != df_is_ref or is_init != scale_is_ref:
            df = tf.convert_to_tensor(self._df)
            dimension = self._dimension()
            checks.append(assert_util.assert_less_equal(
                dimension, df, message=(msg.format(df, dimension))))
    return checks
def _parameter_control_dependencies(self, is_init):
    """Build assertions on `total_count` and (if set) `probs`.

    `total_count` must be a non-negative integer; `probs` must lie in
    `[0, 1]`. Each parameter's checks are built when `is_init` differs from
    whether it is a reference value.

    Returns:
      A list of assertion ops; empty when `validate_args` is False.
    """
    if not self.validate_args:
        return []

    checks = []
    if is_init != tensor_util.is_ref(self.total_count):
        total_count = tf.convert_to_tensor(self.total_count)
        checks.append(assert_util.assert_non_negative(
            total_count,
            message='Argument `total_count` must be non-negative.'))
        checks.append(distribution_util.assert_integer_form(
            total_count,
            message='Argument `total_count` cannot contain fractional components.'))

    probs = self._probs
    if probs is not None and is_init != tensor_util.is_ref(probs):
        probs = tf.convert_to_tensor(probs)
        upper = tf.constant(1., probs.dtype)
        checks.append(assert_util.assert_non_negative(
            probs, message='probs has components less than 0.'))
        checks.append(assert_util.assert_less_equal(
            probs, upper, message='probs has components greater than 1.'))
    return checks
def _parameter_control_dependencies(self, is_init):
    """Build assertions validating the requested batch reshape.

    The unexpanded target shape must be a vector with at most one `-1` and no
    entries below `-1`. These checks are built when `is_init` differs from
    whether the shape is a reference value. The total size of the expanded
    target must also equal the original batch size.

    Returns:
      A list of assertion ops; empty when `validate_args` is False.
    """
    if not self.validate_args:
        # Skip computing intermediates that only the assertions would need.
        return []

    target = self._batch_shape_unexpanded
    checks = []
    if is_init != tensor_util.is_ref(target):
        placeholder_mask = ps.equal(target, -1)
        checks.append(assert_util.assert_rank(
            target, 1, message='New shape must be a vector.'))
        checks.append(assert_util.assert_less_equal(
            tf.math.count_nonzero(placeholder_mask, dtype=tf.int32), 1,
            message='At most one dimension can be unknown.'))
        checks.append(assert_util.assert_non_negative(
            target + 1, message='Shape elements must be >=-1.'))

    # The reshape must preserve the total number of batch elements.
    expanded_new_shape, original_size = self._calculate_new_shape()
    new_size = ps.reduce_prod(expanded_new_shape)
    checks.append(assert_util.assert_equal(
        new_size, tf.cast(original_size, new_size.dtype),
        message='Shape sizes do not match.'))
    return checks
def _parameter_control_dependencies(self, is_init):
    """Build assertions on `temperature` and (if set) `probs`.

    `temperature` must be positive; `probs` must lie in `[0, 1]`. Each
    parameter's checks are built when `is_init` differs from whether it is a
    reference value.

    Returns:
      A list of assertion ops; empty when `validate_args` is False.
    """
    if not self.validate_args:
        return []

    checks = []
    if is_init != tensor_util.is_ref(self._temperature):
        temperature = tf.convert_to_tensor(self._temperature)
        checks.append(assert_util.assert_positive(
            temperature, message='Argument `temperature` must be positive.'))

    probs = self._probs
    if probs is None or is_init == tensor_util.is_ref(probs):
        return checks

    probs = tf.convert_to_tensor(probs)
    upper = tf.constant(1., probs.dtype)
    checks.append(assert_util.assert_non_negative(
        probs, message='Argument `probs` has components less than 0.'))
    checks.append(assert_util.assert_less_equal(
        probs, upper,
        message='Argument `probs` has components greater than 1.'))
    return checks
def _sample_control_dependencies(self, x):
    """Check that `x` is a batch of `dimension x dimension` correlation matrices.

    A statically known trailing shape that does not match raises
    `ValueError` regardless of `validate_args`. With `validate_args`, an
    unknown trailing shape is checked at runtime. Unless
    `input_output_cholesky` is set, entries are also checked to lie in
    `[-1, 1]`, the diagonal to be 1, and the matrices to be symmetric.

    Returns:
      A list of assertion ops.
    """
    checks = []
    size = self.dimension
    mismatch = 'Input dimension mismatch: expected [..., {}, {}], got {}'

    if tensorshape_util.is_fully_defined(x.shape[-2:]):
        dims = tensorshape_util.dims(x.shape)
        if not (dims[-2] == dims[-1] == size):
            raise ValueError(mismatch.format(size, size, dims))
    elif self.validate_args:
        runtime_shape = tf.shape(x)
        msg = mismatch.format(size, size, runtime_shape)
        checks.append(
            assert_util.assert_equal(runtime_shape[-2], size, message=msg))
        checks.append(
            assert_util.assert_equal(runtime_shape[-1], size, message=msg))

    if not self.validate_args or self.input_output_cholesky:
        return checks

    np_dtype = dtype_util.as_numpy_dtype(x.dtype)
    checks.extend([
        assert_util.assert_less_equal(
            np_dtype(-1), x,
            message='Correlations must be >= -1.', summarize=30),
        assert_util.assert_less_equal(
            x, np_dtype(1),
            message='Correlations must be <= 1.', summarize=30),
        assert_util.assert_near(
            tf.linalg.diag_part(x), np_dtype(1),
            message='Self-correlations must be = 1.', summarize=30),
        assert_util.assert_near(
            x, tf.linalg.matrix_transpose(x),
            message='Correlation matrices must be symmetric.', summarize=30),
    ])
    return checks
def _maybe_assert_valid_sample(self, x, dtype):
    """Return `x`, gated on simplex checks when `validate_args` is set.

    The checks require entries in `[0, 1]` that sum to 1 along the last axis.
    `dtype` is the dtype used for the constant 1.
    """
    if not self.validate_args:
        return x
    unit = tf.ones([], dtype=dtype)
    simplex_checks = [
        assert_util.assert_non_negative(x),
        assert_util.assert_less_equal(x, unit),
        assert_util.assert_near(unit, tf.reduce_sum(x, axis=[-1])),
    ]
    return distribution_util.with_dependencies(simplex_checks, x)
def _sample_control_dependencies(self, x):
    """Assert `low <= x <= high` when `validate_args` is set; else no checks."""
    if not self.validate_args:
        return []
    above_low = assert_util.assert_greater_equal(
        x, self.low, message='Sample must be greater than or equal to `low`.')
    below_high = assert_util.assert_less_equal(
        x, self.high, message='Sample must be less than or equal to `high`.')
    return [above_low, below_high]
def _sample_control_dependencies(self, x):
    """Assert `x` holds non-negative integers no greater than 1.

    Returns an empty list when `validate_args` is False.
    """
    if not self.validate_args:
        return []
    checks = list(distribution_util.assert_nonnegative_integer_form(x))
    at_most_one = assert_util.assert_less_equal(
        x, tf.ones([], dtype=x.dtype), message='Elements cannot exceed 1.')
    checks.append(at_most_one)
    return checks
def _maybe_assert_valid_y(self, y):
    """Return assertions that `0 <= y <= 1`, or `[]` without `validate_args`."""
    if not self.validate_args:
        return []
    # NOTE(review): the first message says "greater than 0", but the check
    # is non-negativity, so `y == 0` is accepted.
    return [
        assert_util.assert_non_negative(
            y, message='Inverse transformation input must be greater than 0.'),
        assert_util.assert_less_equal(
            y, tf.constant(1., y.dtype),
            message='Inverse transformation input must be less than or equal to 1.'),
    ]
def _maybe_assert_valid_y(self, y):
    """Return `y`, gated on `0 <= y <= 1` checks when `validate_args` is set."""
    if not self.validate_args:
        return y
    upper = tf.constant(1., y.dtype)
    # NOTE(review): the first message says "greater than 0", but the check
    # is non-negativity, so `y == 0` is accepted.
    range_checks = [
        assert_util.assert_non_negative(
            y, message="Inverse transformation input must be greater than 0."),
        assert_util.assert_less_equal(
            y, upper,
            message="Inverse transformation input must be less than or equal to 1."),
    ]
    return distribution_util.with_dependencies(range_checks, y)
def _sample_control_dependencies(self, x):
    """Assert `0 <= x <= 1` when `validate_args` is set; else no checks."""
    if not self.validate_args:
        return []
    return [
        assert_util.assert_non_negative(
            x, message='Sample must be non-negative.'),
        assert_util.assert_less_equal(
            x, tf.ones([], x.dtype),
            message='Sample must be less than or equal to `1`.'),
    ]
def _maybe_assert_valid_sample(self, counts):
    """Validate sampled counts and return them.

    With `validate_args`, `counts` is first passed through
    `embed_check_nonnegative_integer_form`. The result is then gated on
    `counts <= total_count` elementwise. Without `validate_args`, `counts` is
    returned unchanged.
    """
    if not self.validate_args:
        return counts
    counts = distribution_util.embed_check_nonnegative_integer_form(counts)
    within_total = assert_util.assert_less_equal(
        counts, self.total_count,
        message=('Sampled counts must be itemwise less than '
                 'or equal to `total_count` parameter.'))
    return distribution_util.with_dependencies([within_total], counts)
def _sample_control_dependencies(self, x):
    """Assert `x` holds integers in `[0, self._num_categories()]`.

    Returns an empty list when `validate_args` is False.
    """
    if not self.validate_args:
        return []
    checks = list(distribution_util.assert_nonnegative_integer_form(x))
    upper = tf.cast(self._num_categories(), x.dtype)
    checks.append(assert_util.assert_less_equal(
        x, upper,
        message=('StoppingRatioLogistic samples must be `>= 0` and `<= K` '
                 'where `K` is the number of cutpoints.')))
    return checks
def _is_valid_correlation_matrix(self, x):
    """Return correlation-matrix checks for `x`.

    Entries must lie in `[-1, 1]`, the diagonal must be 1, and the matrices
    must be symmetric. Returns `[]` when `validate_args` is False or
    `input_output_cholesky` is set.
    """
    if not self.validate_args or self.input_output_cholesky:
        return []
    np_dtype = dtype_util.as_numpy_dtype(x.dtype)
    lower, upper = np_dtype(-1), np_dtype(1)
    return [
        assert_util.assert_less_equal(
            lower, x, message='Correlations must be >= -1.'),
        assert_util.assert_less_equal(
            x, upper, message='Correlations must be <= 1.'),
        assert_util.assert_near(
            tf.linalg.diag_part(x), upper,
            message='Self-correlations must be = 1.'),
        assert_util.assert_near(
            x, tf.linalg.matrix_transpose(x),
            message='Correlation matrices must be symmetric'),
    ]
def _maybe_assert_valid(self, x):
    """Return `x`, gated on `0 <= x <= 1` checks when `validate_args` is set.

    The upper bound 1 is built in the dtype of `self.concentration0`.
    """
    if not self.validate_args:
        return x
    upper = tf.ones([], self.concentration0.dtype)
    in_unit_interval = [
        assert_util.assert_non_negative(
            x, message='Sample must be non-negative.'),
        assert_util.assert_less_equal(
            x, upper, message='Sample must be less than or equal to `1`.'),
    ]
    return distribution_util.with_dependencies(in_unit_interval, x)
def _maybe_assert_valid_sample(self, counts):
    """Validate counts and return them.

    With `validate_args`, `counts` is passed through
    `embed_check_nonnegative_integer_form`. The result is then gated on
    `counts <= total_count`. Otherwise `counts` is returned unchanged.
    """
    if not self.validate_args:
        return counts
    checked = distribution_util.embed_check_nonnegative_integer_form(counts)
    bounded = assert_util.assert_less_equal(
        checked, self.total_count,
        message='counts are not less than or equal to n.')
    return distribution_util.with_dependencies([bounded], checked)
def _sample_control_dependencies(self, x):
    """Build assertions that `x` holds valid category indices.

    Valid samples are integers in `[0, n - 1]`, where `n` is the number of
    categories, as stated in the assertion message.

    Args:
      x: Candidate samples.

    Returns:
      A list of assertion ops; empty when `validate_args` is False.
    """
    assertions = []
    if not self.validate_args:
        return assertions
    assertions.extend(distribution_util.assert_nonnegative_integer_form(x))
    # Indices run from 0 to n-1, so `n` itself is out of range. A `<=`
    # comparison would wrongly accept x == n; use a strict bound instead.
    assertions.append(
        assert_util.assert_less(
            x, tf.cast(self._num_categories(), x.dtype),
            message=('Categorical samples must be between `0` and `n-1` '
                     'where `n` is the number of categories.')))
    return assertions
def _assertions(self, t):
    """Return checks that `0 <= t <= 1`, or `[]` without `validate_args`."""
    if not self.validate_args:
        return []
    upper = dtype_util.as_numpy_dtype(t.dtype)(1.)
    # NOTE(review): the first message says "greater than 0", but the check
    # is non-negativity, so `t == 0` is accepted.
    return [
        assert_util.assert_non_negative(
            t, message="Inverse transformation input must be greater than 0."),
        assert_util.assert_less_equal(
            t, upper,
            message="Inverse transformation input must be less than or equal to 1."),
    ]
def _sample_control_dependencies(self, counts):
    """Assert `counts` are non-negative integers bounded by `total_count`.

    Returns an empty list when `validate_args` is False.
    """
    if not self.validate_args:
        return []
    checks = list(distribution_util.assert_nonnegative_integer_form(counts))
    checks.append(assert_util.assert_less_equal(
        counts, self.total_count,
        message=('Sampled counts must be itemwise less than '
                 'or equal to `total_count` parameter.')))
    return checks
def _call_quantile(self, value, name, **kwargs):
    """Convert `value` to a tensor, optionally check it, and call `_quantile`.

    For a nested (structured) `self.dtype`, `tf.float32` is used as the dtype
    hint. With `validate_args`, `value` is gated on `0 <= value <= 1`.
    """
    with self._name_and_control_scope(name):
        if tf.nest.is_nested(self.dtype):
            hint = tf.float32
        else:
            hint = self.dtype
        value = tf.convert_to_tensor(value, name='value', dtype_hint=hint)
        if self.validate_args:
            one = tf.cast(1, value.dtype)
            zero = tf.cast(0, value.dtype)
            value = distribution_util.with_dependencies([
                assert_util.assert_less_equal(
                    value, one, message='`value` must be <= 1'),
                assert_util.assert_greater_equal(
                    value, zero, message='`value` must be >= 0'),
            ], value)
        return self._quantile(value, **kwargs)
def _validate_correlationness(self, x):
    """Return `x`, gated on correlation-matrix checks.

    The checks require entries in `[-1, 1]`, a unit diagonal, and symmetry.
    They are skipped, and `x` is returned as-is, when `validate_args` is
    False or `input_output_cholesky` is set.
    """
    if not self.validate_args or self.input_output_cholesky:
        return x
    np_dtype = dtype_util.as_numpy_dtype(x.dtype)
    lower, upper = np_dtype(-1), np_dtype(1)
    checks = [
        assert_util.assert_less_equal(
            lower, x, message='Correlations must be >= -1.'),
        assert_util.assert_less_equal(
            x, upper, message='Correlations must be <= 1.'),
        assert_util.assert_near(
            tf.linalg.diag_part(x), upper,
            message='Self-correlations must be = 1.'),
        assert_util.assert_near(
            x, tf.linalg.matrix_transpose(x),
            message='Correlation matrices must be symmetric'),
    ]
    with tf.control_dependencies(checks):
        return tf.identity(x)
def _assertions(self, t):
    """Return checks that `-1 <= t <= 1`, or `[]` without `validate_args`."""
    if not self.validate_args:
        return []
    np_dtype = dtype_util.as_numpy_dtype(t.dtype)
    lower_ok = assert_util.assert_greater_equal(
        t, np_dtype(-1), message="Inverse transformation input must be >= -1.")
    upper_ok = assert_util.assert_less_equal(
        t, np_dtype(1), message="Inverse transformation input must be <= 1.")
    return [lower_ok, upper_ok]