def _prob(self, x):
  if self.validate_args:
    with tf.control_dependencies([
        assert_util.assert_greater_equal(x, self.low),
        assert_util.assert_less_equal(x, self.high)
    ]):
      x = tf.identity(x)

  broadcast_x_to_high = _broadcast_to(x, [self.high])
  left_of_peak = tf.logical_and(
      broadcast_x_to_high > self.low, broadcast_x_to_high <= self.peak)

  interval_length = self.high - self.low
  # This is the pdf function when low <= x <= high. It looks like a
  # triangle, so we have to treat each line segment separately.
  result_inside_interval = tf.where(
      left_of_peak,
      # Line segment from (self.low, 0) to
      # (self.peak, 2 / (self.high - self.low)).
      2. * (x - self.low) / (interval_length * (self.peak - self.low)),
      # Line segment from (self.peak, 2 / (self.high - self.low)) to
      # (self.high, 0).
      2. * (self.high - x) / (interval_length * (self.high - self.peak)))

  broadcast_x_to_peak = _broadcast_to(x, [self.peak])
  outside_interval = tf.logical_or(
      broadcast_x_to_peak < self.low, broadcast_x_to_peak > self.high)

  broadcast_shape = tf.broadcast_dynamic_shape(
      tf.shape(input=x), self.batch_shape_tensor())

  return tf.where(
      outside_interval,
      tf.zeros(broadcast_shape, dtype=self.dtype),
      result_inside_interval)
def _variance(self):
  tailweight = tf.convert_to_tensor(self.tailweight)
  scale = tf.convert_to_tensor(self.scale)
  # For tailweight < 0.5, the variance is finite. See Eq (18) in
  # https://www.hindawi.com/journals/tswj/2015/909231/
  var = (
      tf.cast(tf.pow(1. - 2. * tailweight, -3. / 2.), dtype=self.dtype) *
      tf.math.square(scale))
  # We need to put the tf.where inside the outer tf.where to ensure we never
  # hit a NaN in the gradient.
  result_where_defined = tf.where(
      tailweight < 0.5,
      var,
      tf.convert_to_tensor(np.inf, dtype=self.dtype))

  if self.allow_nan_stats:
    return tf.where(tailweight < 1.0,
                    result_where_defined,
                    tf.convert_to_tensor(np.nan, self.dtype))
  else:
    return distribution_util.with_dependencies([
        assert_util.assert_greater_equal(
            tf.ones([], dtype=self.dtype),
            tailweight,
            message="variance not defined for components of tailweight >= 1"),
    ], result_where_defined)
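# Illustrative NumPy mirror of `_variance` above (a sketch, not part of the
# original source): the variance is finite only for tailweight < 0.5 (Eq (18)
# of the cited paper), diverges on [0.5, 1), and is undefined beyond that.
# Clamping before exponentiating plays the role of the nested tf.where trick
# that keeps NaNs out of the gradient.
import numpy as np

def _variance_numpy(tailweight, scale):
  tailweight = np.asarray(tailweight, dtype=np.float64)
  scale = np.asarray(scale, dtype=np.float64)
  safe_tail = np.where(tailweight < 0.5, tailweight, 0.)  # avoid (neg)**-1.5
  var = np.power(1. - 2. * safe_tail, -1.5) * np.square(scale)
  var = np.where(tailweight < 0.5, var, np.inf)
  return np.where(tailweight < 1.0, var, np.nan)  # allow_nan_stats behavior

# _variance_numpy([0.1, 0.7, 1.5], 2.) -> [~5.59, inf, nan]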
def assert_mvn_target_conservation(event_size, batch_size, **kwargs):
  strm = tfp_test_util.test_seed_stream()
  initialization = tfd.MultivariateNormalFullCovariance(
      loc=tf.zeros(event_size),
      covariance_matrix=tf.eye(event_size)).sample(
          batch_size, seed=strm())
  samples, _ = run_nuts_chain(
      event_size, batch_size, num_steps=1,
      initial_state=initialization, **kwargs)
  answer = samples[0][-1]
  check_cdf_agrees = (
      st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
          answer, initialization, num_projections=100, false_fail_rate=1e-6))
  check_sample_shape = assert_util.assert_equal(
      tf.shape(answer)[0], batch_size)
  movement = tf.linalg.norm(answer - initialization, axis=-1)
  # This movement distance (0.3) was copied from the univariate case.
  check_movement = assert_util.assert_greater_equal(
      tf.reduce_mean(movement), 0.3)
  check_enough_power = assert_util.assert_less(
      st.min_discrepancy_of_true_cdfs_detectable_by_dkwm_two_sample(
          batch_size, batch_size, false_fail_rate=1e-8, false_pass_rate=1e-6),
      0.055)
  return (
      check_cdf_agrees,
      check_sample_shape,
      check_movement,
      check_enough_power,
  )
def _maybe_assert_float_matrix(logu, validate_args):
  """Assertion check for the scores matrix to be float type."""
  logu = tf.convert_to_tensor(logu, dtype_hint=tf.float32, name='logu')
  if not dtype_util.is_floating(logu.dtype):
    raise TypeError('Input argument must be `float` type.')

  assertions = []
  # Check scores is a matrix.
  msg = 'Input argument must be a (batch of) matrix.'
  rank = tensorshape_util.rank(logu.shape)
  if rank is not None:
    if rank < 2:
      raise ValueError(msg)
  elif validate_args:
    assertions.append(assert_util.assert_rank_at_least(logu, 2, msg))

  # Check scores has the shape [..., N, M], M >= N.
  msg = ('Input argument must be a (batch of) matrix of the shape [N, M], '
         'M >= N.')
  if (rank is not None and
      tensorshape_util.is_fully_defined(logu.shape[-2:])):
    if logu.shape[-2] > logu.shape[-1]:
      raise ValueError(msg)
  elif validate_args:
    # Use the dynamic shape here: the static shape may not be fully defined
    # on this branch, in which case unstacking it would fail.
    n, m = tf.unstack(tf.shape(logu)[-2:])
    assertions.append(assert_util.assert_greater_equal(m, n, message=msg))
  return assertions
def _parameter_control_dependencies(self, is_init):
  assertions = []

  logits = self._logits
  probs = self._probs
  param, name = (probs, 'probs') if logits is None else (logits, 'logits')

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    if not dtype_util.is_floating(param.dtype):
      raise TypeError(
          'Argument `{}` must have floating type.'.format(name))

    msg = 'Argument `{}` must have rank at least 1.'.format(name)
    shape_static = tensorshape_util.dims(param.shape)
    if shape_static is not None:
      if len(shape_static) < 1:
        raise ValueError(msg)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_rank_at_least(param, 1, message=msg))
      with tf.control_dependencies(assertions):
        param = tf.identity(param)

    msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
    msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
        name, dtype_util.max(tf.int32))
    event_size = shape_static[-1] if shape_static is not None else None
    if event_size is not None:
      if event_size < 1:
        raise ValueError(msg1)
      if event_size > dtype_util.max(tf.int32):
        raise ValueError(msg2)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(assert_util.assert_greater_equal(
          tf.shape(param)[-1], 1, message=msg1))
      # NOTE: For now, we leave out a runtime assertion that
      # `tf.shape(param)[-1] <= tf.int32.max`. An earlier `tf.shape` call
      # will fail before we get to this point.

  if not self.validate_args:
    assert not assertions  # Should never happen.
    return []

  if probs is not None:
    probs = param  # reuse tensor conversion from above
    if is_init != tensor_util.is_ref(probs):
      probs = tf.convert_to_tensor(probs)
      one = tf.ones([], dtype=probs.dtype)
      assertions.extend([
          assert_util.assert_non_negative(probs),
          assert_util.assert_less_equal(probs, one),
          assert_util.assert_near(
              tf.reduce_sum(probs, axis=-1), one,
              message='Argument `probs` must sum to 1.'),
      ])

  return assertions
def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
  """Check whether event_ndims is at least min_event_ndims."""
  event_ndims = tf.convert_to_tensor(event_ndims, name='event_ndims')
  event_ndims_ = tf.get_static_value(event_ndims)
  assertions = []

  if not dtype_util.is_integer(event_ndims.dtype):
    raise ValueError('Expected integer dtype, got dtype {}'.format(
        event_ndims.dtype))

  if event_ndims_ is not None:
    if tensorshape_util.rank(event_ndims.shape) != 0:
      raise ValueError('Expected scalar event_ndims, got shape {}'.format(
          event_ndims.shape))
    if min_event_ndims > event_ndims_:
      raise ValueError('event_ndims ({}) must be at least '
                       'min_event_ndims ({})'.format(event_ndims_,
                                                     min_event_ndims))
  elif self.validate_args:
    assertions += [
        assert_util.assert_greater_equal(event_ndims, min_event_ndims)
    ]

  if tensorshape_util.is_fully_defined(event_ndims.shape):
    if tensorshape_util.rank(event_ndims.shape) != 0:
      raise ValueError('Expected scalar shape, got ndims {}'.format(
          tensorshape_util.rank(event_ndims.shape)))
  elif self.validate_args:
    assertions += [
        assert_util.assert_rank(event_ndims, 0, message='Expected scalar.')
    ]
  return assertions
def _maybe_assert_valid_sample(self, x, loc):
  """Checks the validity of a sample."""
  if not self.validate_args:
    return []
  return [
      assert_util.assert_greater_equal(
          x, loc, message='x is not in the support of the distribution')
  ]
def _check_arg_and_apply_f(*args, **kwargs):
  dist = args[0]
  x = args[1]
  with tf.control_dependencies([
      assert_util.assert_greater_equal(
          x, dist.loc,
          message="x is not in the support of the distribution")
  ] if dist.validate_args else []):
    return f(*args, **kwargs)
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  if is_init != tensor_util.is_ref(self._num_steps):
    assertions.append(assert_util.assert_greater_equal(
        self._num_steps, 1,
        message='Argument `num_steps` must be at least 1.'))
  return assertions
def _sample_control_dependencies(self, x):
  assertions = []
  if not self.validate_args:
    return assertions
  assertions.append(assert_util.assert_greater_equal(
      x, self.low,
      message='Sample must be greater than or equal to `low`.'))
  assertions.append(assert_util.assert_less_equal(
      x, self.high,
      message='Sample must be less than or equal to `high`.'))
  return assertions
def _sample_control_dependencies(self, x): """Checks the validity of a sample.""" assertions = [] if not self.validate_args: return assertions loc = tf.convert_to_tensor(self.loc) assertions.append( assert_util.assert_greater_equal( x, loc, message='Sample must be greater than or equal to `loc`.')) return assertions
def _maybe_assert_valid_x(self, x):
  if not self.validate_args:
    return []
  return [
      assert_util.assert_greater_equal(
          x, self.loc,
          message='Forward transformation input must be greater than `loc`.')
  ]
def _parameter_control_dependencies(self, is_init):
  assertions = []

  scores = self._scores
  param, name = (scores, 'scores')

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    if not dtype_util.is_floating(param.dtype):
      raise TypeError(
          'Argument `{}` must have floating type.'.format(name))

    msg = 'Argument `{}` must have rank at least 1.'.format(name)
    shape_static = tensorshape_util.dims(param.shape)
    if shape_static is not None:
      if len(shape_static) < 1:
        raise ValueError(msg)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_rank_at_least(param, 1, message=msg))
      with tf.control_dependencies(assertions):
        param = tf.identity(param)

    msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
    msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
        name, tf.int32.max)
    event_size = shape_static[-1] if shape_static is not None else None
    if event_size is not None:
      if event_size < 1:
        raise ValueError(msg1)
      if event_size > tf.int32.max:
        raise ValueError(msg2)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_greater_equal(
              tf.shape(param)[-1], 1, message=msg1))
      # NOTE: For now, we leave out a runtime assertion that
      # `tf.shape(param)[-1] <= tf.int32.max`. An earlier `tf.shape` call
      # will fail before we get to this point.

  if not self.validate_args:
    assert not assertions  # Should never happen.
    return []

  if is_init != tensor_util.is_ref(scores):
    scores = tf.convert_to_tensor(scores)
    assertions.extend([
        assert_util.assert_positive(scores),
    ])

  return assertions
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  if is_init != tensor_util.is_ref(self._tailweight):
    assertions.append(
        assert_util.assert_greater_equal(
            self._tailweight, tf.zeros([], dtype=self.dtype),
            message="Argument `tailweight` must be non-negative."))
  return assertions
def _prob(self, x):
  concentration = tf.convert_to_tensor(self.concentration)
  scale = tf.convert_to_tensor(self.scale)
  with tf.control_dependencies([
      assert_util.assert_greater_equal(
          x, scale,
          message='`x` is not in the support of the distribution.')
  ] if self.validate_args else []):

    def prob_on_support(z):
      return concentration * (scale**concentration) / (
          z**(concentration + 1))

    return self._extend_support(x, scale, prob_on_support, alt=0.)
def _prob(self, x):
  with tf.control_dependencies([
      assert_util.assert_greater_equal(
          x, self.scale,
          message="x is not in the support of the distribution.")
  ] if self.validate_args else []):

    def prob_on_support(z):
      return (self.concentration * (self.scale ** self.concentration) /
              (z ** (self.concentration + 1)))

    return self._extend_support(x, prob_on_support, alt=0.)
def _call_quantile(self, value, name, **kwargs):
  with self._name_and_control_scope(name):
    dtype = tf.float32 if tf.nest.is_nested(self.dtype) else self.dtype
    value = tf.convert_to_tensor(value, name='value', dtype_hint=dtype)
    if self.validate_args:
      value = distribution_util.with_dependencies([
          assert_util.assert_less_equal(
              value, tf.cast(1, value.dtype),
              message='`value` must be <= 1'),
          assert_util.assert_greater_equal(
              value, tf.cast(0, value.dtype),
              message='`value` must be >= 0')
      ], value)
    return self._quantile(value, **kwargs)
def _assertions(self, t):
  if not self.validate_args:
    return []
  return [
      assert_util.assert_greater_equal(
          t, dtype_util.as_numpy_dtype(t.dtype)(-1),
          message="Inverse transformation input must be >= -1."),
      assert_util.assert_less_equal(
          t, dtype_util.as_numpy_dtype(t.dtype)(1),
          message="Inverse transformation input must be <= 1.")
  ]
def _log_prob(self, x):
  with tf.control_dependencies([
      assert_util.assert_greater_equal(
          x, self.scale,
          message="x is not in the support of the distribution.")
  ] if self.validate_args else []):

    def log_prob_on_support(z):
      return (tf.math.log(self.concentration) +
              self.concentration * tf.math.log(self.scale) -
              (self.concentration + 1.) * tf.math.log(z))

    return self._extend_support(x, log_prob_on_support, alt=-np.inf)
def _assert_compatible_shape(self, index, sample_shape, samples):
  requested_shape, _ = self._expand_sample_shape_to_vector(
      tf.convert_to_tensor(sample_shape, dtype=tf.int32),
      name='requested_shape')
  actual_shape = prefer_static.shape(samples)
  actual_rank = prefer_static.rank_from_shape(actual_shape)
  requested_rank = prefer_static.rank_from_shape(requested_shape)

  # We test for two properties we expect of yielded distributions:
  # (1) The rank of the tensor of generated samples must be at least
  #     as large as the rank requested.
  # (2) The requested shape must be a prefix of the shape of the
  #     generated tensor of samples.
  # We attempt to perform test (1) statically first.
  # We don't need to do this explicitly for test (2) because
  # `assert_equal` evaluates statically if it can.
  static_actual_rank = tf.get_static_value(actual_rank)
  static_requested_rank = tf.get_static_value(requested_rank)

  assertion_message = ('Samples yielded by distribution #{} are not '
                       'consistent with `sample_shape` passed to '
                       '`JointDistributionCoroutine` '
                       'distribution.'.format(index))

  # TODO(b/138738650): Remove this static check.
  if (static_actual_rank is not None and
      static_requested_rank is not None):
    # We're able to statically check the rank.
    if static_actual_rank < static_requested_rank:
      raise ValueError(assertion_message)
    else:
      control_dependencies = []
  else:
    # We're not able to statically check the rank.
    control_dependencies = [
        assert_util.assert_greater_equal(
            actual_rank, requested_rank, message=assertion_message)
    ]

  with tf.control_dependencies(control_dependencies):
    trimmed_actual_shape = actual_shape[:requested_rank]

  control_dependencies = [
      assert_util.assert_equal(
          requested_shape, trimmed_actual_shape, message=assertion_message)
  ]
  return control_dependencies
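# Toy example (not from the original source) of the two properties checked
# above: with sample_shape=[3] and yielded samples of shape [3, 2, 5], the
# sample rank (3) is at least the requested rank (1), and [3] is a prefix of
# [3, 2, 5], so both checks pass; samples of shape [4, 2] would fail (2).
import tensorflow as tf

requested_shape = tf.constant([3], dtype=tf.int32)
samples = tf.zeros([3, 2, 5])
actual_shape = tf.shape(samples)
requested_rank = tf.size(requested_shape)
# Check (2): the requested shape must equal the matching prefix of the
# actual shape (check (1) is implied when this succeeds).
tf.debugging.assert_equal(requested_shape, actual_shape[:requested_rank])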
def _log_prob(self, x):
  concentration = tf.convert_to_tensor(self.concentration)
  scale = tf.convert_to_tensor(self.scale)
  with tf.control_dependencies([
      assert_util.assert_greater_equal(
          x, scale,
          message='`x` is not in the support of the distribution.')
  ] if self.validate_args else []):

    def log_prob_on_support(z):
      # This can also be written as log(c) + c * log(s) - (c + 1) * log(z).
      # However, when c >> 1 and s and z are of the same magnitude, this can
      # lead to loss of precision (log(c) vs. log(c) - log(z)).
      return (tf.math.log(concentration / z) +
              concentration * tf.math.log(scale / z))

    return self._extend_support(
        x, scale, log_prob_on_support, alt=-np.inf)
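# A small standalone check (not from the original source) of the
# rearrangement above: log(c) + c*log(s) - (c+1)*log(z) is algebraically
# equal to log(c/z) + c*log(s/z), but the first form subtracts two large,
# nearly equal terms when c >> 1 and s ~ z, while the second computes the
# small ratio log(s/z) before scaling it by c.
import numpy as np

concentration, scale, z = 1e8, 3., 4.
naive = (np.log(concentration) + concentration * np.log(scale)
         - (concentration + 1.) * np.log(z))
stable = np.log(concentration / z) + concentration * np.log(scale / z)
# Both are approximately -2.8768e7; `stable` retains more significant digits
# because no large cancelling terms are formed.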
def _sample_control_dependencies(self, x):
  assertions = []
  if not self.validate_args:
    return assertions
  loc = tf.convert_to_tensor(self.loc)
  scale = tf.convert_to_tensor(self.scale)
  concentration = tf.convert_to_tensor(self.concentration)
  assertions.append(assert_util.assert_greater_equal(
      x, loc, message='Sample must be greater than or equal to `loc`.'))
  assertions.append(assert_util.assert_equal(
      tf.logical_or(tf.greater_equal(concentration, 0),
                    tf.less_equal(x, loc - scale / concentration)),
      True,
      message=('If `concentration < 0`, sample must be less than or '
               'equal to `loc - scale / concentration`.'),
      summarize=100))
  return assertions
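# Worked numbers (illustrative, not from the source) for the second check
# above: with `concentration < 0` the distribution has bounded support
# [loc, loc - scale / concentration].
import numpy as np

loc, scale, concentration = 0., 2., -0.5
upper = loc - scale / concentration  # 0 - 2 / (-0.5) = 4.0
x = np.array([1., 4., 5.])
in_support = (x >= loc) & ((concentration >= 0) | (x <= upper))
# in_support -> [True, True, False]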
def assert_univariate_target_conservation(test, target_d, step_size):
  # Sample count limited partly by memory reliably available on Forge. The
  # test remains reasonable even if the nuts recursion limit is severely
  # curtailed (e.g., 3 or 4 levels), so use that to recover some memory
  # footprint and bump the sample count.
  num_samples = int(5e4)
  num_steps = 1
  strm = test_util.test_seed_stream()
  # We wrap the initial values in `tf.identity` to avoid broken gradients
  # resulting from a bijector cache hit, since bijectors of the same
  # type/parameterization now share a cache.
  # TODO(b/72831017): Fix broken gradients caused by bijector caching.
  initialization = tf.identity(target_d.sample([num_samples], seed=strm()))

  @tf.function(autograph=False)
  def run_chain():
    nuts = tfp.experimental.mcmc.PreconditionedNoUTurnSampler(
        target_d.log_prob,
        step_size=step_size,
        max_tree_depth=3,
        unrolled_leapfrog_steps=2)
    result = tfp.mcmc.sample_chain(
        num_results=num_steps,
        num_burnin_steps=0,
        current_state=initialization,
        trace_fn=None,
        kernel=nuts,
        seed=strm())
    return result

  result = run_chain()
  test.assertAllEqual([num_steps, num_samples], result.shape)
  answer = result[0]
  check_cdf_agrees = st.assert_true_cdf_equal_by_dkwm(
      answer, target_d.cdf, false_fail_rate=1e-6)
  check_enough_power = assert_util.assert_less(
      st.min_discrepancy_of_true_cdfs_detectable_by_dkwm(
          num_samples, false_fail_rate=1e-6, false_pass_rate=1e-6), 0.025)
  movement = tf.abs(answer - initialization)
  test.assertAllEqual([num_samples], movement.shape)
  # This movement distance (1 * step_size) was selected by reducing until
  # 100 runs with independent seeds all passed.
  check_movement = assert_util.assert_greater_equal(
      tf.reduce_mean(movement), 1 * step_size)
  return (check_cdf_agrees, check_enough_power, check_movement)
def _check_at_least_two_chains(accept_prob, reduce_chain_axis_names,
                               validate_args, message):
  """Checks that the number of chains is at least 2."""
  # Number of total chains is local batch size * distributed axis size.
  local_axis_size = ps.size(accept_prob)
  distributed_axis_size = int(
      ps.reduce_prod([
          distribute_lib.get_axis_size(a) for a in reduce_chain_axis_names
      ]))
  num_chains = local_axis_size * distributed_axis_size
  num_chains_ = tf.get_static_value(num_chains)
  if num_chains_ is not None:
    if num_chains_ < 2:
      raise ValueError('{} Got: {}'.format(message, num_chains_))
  elif validate_args:
    with tf.control_dependencies(
        [assert_util.assert_greater_equal(num_chains, 2, message)]):
      accept_prob = tf.identity(accept_prob)
  return accept_prob
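# The function above follows a recurring pattern in this codebase: try
# `tf.get_static_value` first and raise eagerly at trace time, falling back
# to a graph-mode assertion only when the chain count is unknown until
# runtime. A minimal single-host sketch of the same pattern (hypothetical
# helper, no distributed axes) using only public TF ops:
import tensorflow as tf

def check_at_least_two_chains_local(accept_prob, validate_args,
                                    message='Must have at least 2 chains.'):
  num_chains = tf.size(accept_prob)  # local batch size only
  num_chains_ = tf.get_static_value(num_chains)
  if num_chains_ is not None:  # shape known statically: fail at trace time
    if num_chains_ < 2:
      raise ValueError('{} Got: {}'.format(message, num_chains_))
  elif validate_args:  # otherwise defer the check to graph execution
    with tf.control_dependencies(
        [tf.debugging.assert_greater_equal(num_chains, 2, message=message)]):
      accept_prob = tf.identity(accept_prob)
  return accept_prob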
def assert_univariate_target_conservation(test, target_d, step_size):
  # Sample count limited partly by memory reliably available on Forge. The
  # test remains reasonable even if the nuts recursion limit is severely
  # curtailed (e.g., 3 or 4 levels), so use that to recover some memory
  # footprint and bump the sample count.
  num_samples = int(5e4)
  num_steps = 1
  strm = tfp.util.SeedStream(salt='univariate_nuts_test', seed=1)
  initialization = target_d.sample([num_samples], seed=strm())

  @tf.function(autograph=False)
  def run_chain():
    nuts = tfp.mcmc.NoUTurnSampler(
        target_d.log_prob,
        step_size=step_size,
        max_tree_depth=3,
        unrolled_leapfrog_steps=2,
        seed=strm())
    result, _ = tfp.mcmc.sample_chain(
        num_results=num_steps,
        num_burnin_steps=0,
        current_state=initialization,
        kernel=nuts)
    return result

  result = run_chain()
  test.assertAllEqual([num_steps, num_samples], result.shape)
  answer = result[0]
  check_cdf_agrees = st.assert_true_cdf_equal_by_dkwm(
      answer, target_d.cdf, false_fail_rate=1e-6)
  check_enough_power = assert_util.assert_less(
      st.min_discrepancy_of_true_cdfs_detectable_by_dkwm(
          num_samples, false_fail_rate=1e-6, false_pass_rate=1e-6), 0.025)
  movement = tf.abs(answer - initialization)
  test.assertAllEqual([num_samples], movement.shape)
  # This movement distance (1 * step_size) was selected by reducing until
  # 100 runs with independent seeds all passed.
  check_movement = assert_util.assert_greater_equal(
      tf.reduce_mean(movement), 1 * step_size)
  return (check_cdf_agrees, check_enough_power, check_movement)
def _prob(self, x):
  if self.validate_args:
    with tf.control_dependencies([
        assert_util.assert_greater_equal(x, self.low),
        assert_util.assert_less_equal(x, self.high)
    ]):
      x = tf.identity(x)
  interval_length = self.high - self.low
  # This is the pdf function when low <= x <= high. It looks like a
  # triangle, so we have to treat each line segment separately.
  result_inside_interval = tf.where(
      (x >= self.low) & (x <= self.peak),
      # Line segment from (self.low, 0) to
      # (self.peak, 2 / (self.high - self.low)).
      2. * (x - self.low) / (interval_length * (self.peak - self.low)),
      # Line segment from (self.peak, 2 / (self.high - self.low)) to
      # (self.high, 0).
      2. * (self.high - x) / (interval_length * (self.high - self.peak)))
  return tf.where((x < self.low) | (x > self.high),
                  tf.zeros_like(x),
                  result_inside_interval)
def _prob(self, x):
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  peak = tf.convert_to_tensor(self.peak)
  if self.validate_args:
    with tf.control_dependencies([
        assert_util.assert_greater_equal(x, low),
        assert_util.assert_less_equal(x, high)
    ]):
      x = tf.identity(x)
  interval_length = high - low
  # This is the pdf function when low <= x <= high. It looks like a
  # triangle, so we have to treat each line segment separately.
  result_inside_interval = tf.where(
      (x >= low) & (x <= peak),
      # Line segment from (low, 0) to (peak, 2 / (high - low)).
      2. * (x - low) / (interval_length * (peak - low)),
      # Line segment from (peak, 2 / (high - low)) to (high, 0).
      2. * (high - x) / (interval_length * (high - peak)))
  return tf.where((x < low) | (x > high),
                  tf.zeros_like(x),
                  result_inside_interval)
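# Quick NumPy sanity check (illustrative, not from the original source) of
# the piecewise density above: the two line segments meet at the peak, where
# the density equals 2 / (high - low) so that the triangle has unit area.
import numpy as np

low, peak, high = 0., 1., 4.
x = np.array([-1., 0.5, 1., 3., 5.])
interval_length = high - low
inside = np.where(
    (x >= low) & (x <= peak),
    2. * (x - low) / (interval_length * (peak - low)),
    2. * (high - x) / (interval_length * (high - peak)))
pdf = np.where((x < low) | (x > high), 0., inside)
# pdf -> [0., 0.25, 0.5, 0.1667, 0.]; peak density is 2 / (high - low) = 0.5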
def __init__(self,
             initial_distribution,
             transition_distribution,
             observation_distribution,
             num_steps,
             validate_args=False,
             allow_nan_stats=True,
             name="HiddenMarkovModel"):
  """Initialize hidden Markov model.

  Args:
    initial_distribution: A `Categorical`-like instance.
      Determines probability of first hidden state in Markov chain.
      The number of categories must match the number of categories of
      `transition_distribution` as well as both the rightmost batch
      dimension of `transition_distribution` and the rightmost batch
      dimension of `observation_distribution`.
    transition_distribution: A `Categorical`-like instance.
      The rightmost batch dimension indexes the probability distribution
      of each hidden state conditioned on the previous hidden state.
    observation_distribution: A `tfp.distributions.Distribution`-like
      instance. The rightmost batch dimension indexes the distribution
      of each observation conditioned on the corresponding hidden state.
    num_steps: The number of steps taken in Markov chain. A python `int`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
      Default value: `True`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: "HiddenMarkovModel".

  Raises:
    ValueError: if `num_steps` is not at least 1.
    ValueError: if `initial_distribution` does not have scalar `event_shape`.
    ValueError: if `transition_distribution` does not have scalar
      `event_shape`.
    ValueError: if `transition_distribution` and `observation_distribution`
      are fully defined but don't have matching rightmost dimension.
  """
  parameters = dict(locals())

  # pylint: disable=protected-access
  with tf.name_scope(name) as name:
    self._runtime_assertions = []
    # pylint: enable=protected-access

    num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps")
    if validate_args:
      self._runtime_assertions += [
          assert_util.assert_equal(
              tf.rank(num_steps), 0,
              message="`num_steps` must be a scalar")
      ]
      self._runtime_assertions += [
          assert_util.assert_greater_equal(
              num_steps, 1,
              message="`num_steps` must be at least 1.")
      ]

    self._initial_distribution = initial_distribution
    self._observation_distribution = observation_distribution
    self._transition_distribution = transition_distribution

    if (initial_distribution.event_shape is not None and
        tensorshape_util.rank(initial_distribution.event_shape) != 0):
      raise ValueError(
          "`initial_distribution` must have scalar `event_dim`s")
    elif validate_args:
      self._runtime_assertions += [
          assert_util.assert_equal(
              tf.shape(initial_distribution.event_shape_tensor())[0], 0,
              message="`initial_distribution` must have scalar "
                      "`event_dim`s")
      ]

    if (transition_distribution.event_shape is not None and
        tensorshape_util.rank(transition_distribution.event_shape) != 0):
      raise ValueError(
          "`transition_distribution` must have scalar `event_dim`s")
    elif validate_args:
      self._runtime_assertions += [
          assert_util.assert_equal(
              tf.shape(transition_distribution.event_shape_tensor())[0], 0,
              message="`transition_distribution` must have scalar "
                      "`event_dim`s")
      ]

    if (tensorshape_util.dims(transition_distribution.batch_shape)
        is not None and
        tensorshape_util.rank(transition_distribution.batch_shape) == 0):
      raise ValueError(
          "`transition_distribution` can't have scalar batches")
    elif validate_args:
      self._runtime_assertions += [
          assert_util.assert_greater(
              tf.size(transition_distribution.batch_shape_tensor()), 0,
              message="`transition_distribution` can't have scalar "
                      "batches")
      ]

    if (tensorshape_util.dims(observation_distribution.batch_shape)
        is not None and
        tensorshape_util.rank(observation_distribution.batch_shape) == 0):
      raise ValueError(
          "`observation_distribution` can't have scalar batches")
    elif validate_args:
      self._runtime_assertions += [
          assert_util.assert_greater(
              tf.size(observation_distribution.batch_shape_tensor()), 0,
              message="`observation_distribution` can't have scalar "
                      "batches")
      ]

    # Infer number of hidden states and check consistency
    # between transitions and observations.
    with tf.control_dependencies(self._runtime_assertions):
      self._num_states = (
          (tensorshape_util.dims(transition_distribution.batch_shape)
           is not None and
           tensorshape_util.as_list(
               transition_distribution.batch_shape)[-1]) or
          transition_distribution.batch_shape_tensor()[-1])

      observation_states = (
          (tensorshape_util.dims(observation_distribution.batch_shape)
           is not None and
           tensorshape_util.as_list(
               observation_distribution.batch_shape)[-1]) or
          observation_distribution.batch_shape_tensor()[-1])

    if tf.is_tensor(self._num_states) or tf.is_tensor(observation_states):
      if validate_args:
        self._runtime_assertions += [
            assert_util.assert_equal(
                self._num_states, observation_states,
                message="`transition_distribution` and "
                        "`observation_distribution` must agree on "
                        "last dimension of batch size")
        ]
    elif self._num_states != observation_states:
      raise ValueError("`transition_distribution` and "
                       "`observation_distribution` must agree on "
                       "last dimension of batch size")

    self._log_init = _extract_log_probs(self._num_states,
                                        initial_distribution)
    self._log_trans = _extract_log_probs(self._num_states,
                                         transition_distribution)

    self._num_steps = num_steps
    self._num_states = tf.shape(self._log_init)[-1]
    self._underlying_event_rank = tf.size(
        self._observation_distribution.event_shape_tensor())

    num_steps_ = tf.get_static_value(num_steps)
    if num_steps_ is not None:
      self.static_event_shape = tf.TensorShape(
          [num_steps_]).concatenate(
              self._observation_distribution.event_shape)
    else:
      self.static_event_shape = None

    with tf.control_dependencies(self._runtime_assertions):
      self.static_batch_shape = tf.broadcast_static_shape(
          self._initial_distribution.batch_shape,
          tf.broadcast_static_shape(
              self._transition_distribution.batch_shape[:-1],
              self._observation_distribution.batch_shape[:-1]))

    # pylint: disable=protected-access
    super(HiddenMarkovModel, self).__init__(
        dtype=self._observation_distribution.dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
    # pylint: enable=protected-access

    self._parameters = parameters
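# Usage sketch based on the documented `tfd.HiddenMarkovModel` API: a
# two-state chain with Gaussian observations. The rightmost batch dimension
# of the transition and observation distributions (2 here) is the quantity
# the consistency checks in the constructor above compare.
import tensorflow_probability as tfp
tfd = tfp.distributions

hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]]),
    observation_distribution=tfd.Normal(loc=[0., 15.], scale=[5., 10.]),
    num_steps=7)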
def _sample_n(self, n, seed=None):
  dim0_seed, otherdims_seed = samplers.split_seed(seed,
                                                  salt='von_mises_fisher')
  # The sampling strategy relies on the fact that vMF variates are symmetric
  # about the mean direction. Accordingly, if we have a sampling strategy
  # for the away-from-mean angle, then we can uniformly sample the remaining
  # dimensions on the S^{dim-2} sphere for free, and rotate these samples
  # from a (1, 0, 0, ..., 0)-mode distribution into the target orientation.
  #
  # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
  # von-Mises distributed `x` value in [-1, 1], then uniformly select what
  # amounts to a "up" or "down" additional degree of freedom after unit
  # normalizing, followed by a final rotation to the desired mean direction
  # from a basis of (1, 0).
  #
  # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
  # unit sphere over which the distribution is uniform, in particular the
  # circle where x = \hat{x} intersects the unit sphere. We pick a point on
  # that circle, then rotate to the desired mean direction from a basis of
  # (1, 0, 0).
  mean_direction = tf.convert_to_tensor(self.mean_direction)
  concentration = tf.convert_to_tensor(self.concentration)
  event_dim = (
      tf.compat.dimension_value(self.event_shape[0]) or
      self._event_shape_tensor(mean_direction=mean_direction)[0])

  sample_batch_shape = ps.concat([[n], self._batch_shape_tensor(
      mean_direction=mean_direction, concentration=concentration)], axis=0)
  dim = tf.cast(event_dim - 1, self.dtype)
  if event_dim == 3:
    samples_dim0 = self._sample_3d(n,
                                   mean_direction=mean_direction,
                                   concentration=concentration,
                                   seed=dim0_seed)
  else:
    # Wood'94 provides a rejection algorithm to sample the x coordinate.
    # Wood'94 definition of b:
    #     b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
    # https://stats.stackexchange.com/questions/156729 suggests:
    b = dim / (2 * concentration + tf.sqrt(4 * concentration**2 + dim**2))
    # TODO(bjp): Integrate any useful numerical tricks from hyperspherical
    #     VAE: https://github.com/nicola-decao/s-vae-tf/
    x = (1 - b) / (1 + b)
    c = concentration * x + dim * tf.math.log1p(-x**2)
    beta = beta_lib.Beta(dim / 2, dim / 2)

    def cond_fn(w, should_continue, seed):
      del w, seed
      return tf.reduce_any(should_continue)

    def body_fn(w, should_continue, seed):
      """While loop body for sampling the angle `w`."""
      beta_seed, unif_seed, next_seed = samplers.split_seed(seed, n=3)
      z = beta.sample(sample_shape=sample_batch_shape, seed=beta_seed)
      # set_shape needed here because of b/139013403
      tensorshape_util.set_shape(z, w.shape)
      w = tf.where(should_continue,
                   (1. - (1. + b) * z) / (1. - (1. - b) * z),
                   w)
      if not self.allow_nan_stats:
        w = tf.debugging.check_numerics(w, 'w')
      unif = samplers.uniform(
          sample_batch_shape, seed=unif_seed, dtype=self.dtype)
      # set_shape needed here because of b/139013403
      tensorshape_util.set_shape(unif, w.shape)
      should_continue = should_continue & (
          concentration * w + dim * tf.math.log1p(-x * w) - c <
          # Use log1p(-unif) to prevent log(0) and ensure that log(1) is
          # possible.
          tf.math.log1p(-unif))
      return w, should_continue, next_seed

    w = tf.zeros(sample_batch_shape, dtype=self.dtype)
    should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
    samples_dim0, _, _ = tf.while_loop(
        cond=cond_fn, body=body_fn,
        loop_vars=(w, should_continue, dim0_seed))
    samples_dim0 = samples_dim0[..., tf.newaxis]

  if not self._allow_nan_stats:
    # Verify samples are w/in -1, 1, with useful error output tensors (top
    # value rather than all values).
    with tf.control_dependencies([
        assert_util.assert_less_equal(
            samples_dim0,
            dtype_util.as_numpy_dtype(self.dtype)(1.01)),
        assert_util.assert_greater_equal(
            samples_dim0,
            dtype_util.as_numpy_dtype(self.dtype)(-1.01)),
    ]):
      samples_dim0 = tf.identity(samples_dim0)

  samples_otherdims_shape = ps.concat(
      [sample_batch_shape, [event_dim - 1]], axis=0)
  unit_otherdims = tf.math.l2_normalize(
      samplers.normal(
          samples_otherdims_shape, seed=otherdims_seed, dtype=self.dtype),
      axis=-1)
  samples = tf.concat([
      samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
      tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
  ], axis=-1)
  samples = tf.math.l2_normalize(samples, axis=-1)
  if not self.allow_nan_stats:
    samples = tf.debugging.check_numerics(samples, 'samples')

  # Runtime assert that samples are unit length.
  if not self.allow_nan_stats:
    worst, _ = tf.math.top_k(
        tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
    with tf.control_dependencies([
        assert_util.assert_near(
            dtype_util.as_numpy_dtype(self.dtype)(0),
            worst,
            atol=1e-4,
            summarize=100)
    ]):
      samples = tf.identity(samples)

  # The samples generated are symmetric around a mode at
  # (1, 0, 0, ..., 0). Now, we move the mode to `self.mean_direction`
  # using a rotation matrix.
  if not self.allow_nan_stats:
    # Assert that the basis vector rotates to the mean direction, as
    # expected.
    basis = tf.cast(tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                    self.dtype)
    with tf.control_dependencies([
        assert_util.assert_less(
            tf.linalg.norm(
                self._rotate(basis, mean_direction=mean_direction) -
                mean_direction, axis=-1),
            dtype_util.as_numpy_dtype(self.dtype)(1e-5))
    ]):
      return self._rotate(samples, mean_direction=mean_direction)
  return self._rotate(samples, mean_direction=mean_direction)
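# The two forms of Wood'94's `b` used above are algebraically identical
# (multiply numerator and denominator by 2*kappa + sqrt(4*kappa**2 + dim**2)),
# but the second avoids subtracting nearly equal quantities when
# kappa >> dim. A standalone NumPy check with illustrative values:
import numpy as np

kappa, dim = 1e4, 2.
b_wood = (-2. * kappa + np.sqrt(4. * kappa**2 + dim**2)) / dim
b_stable = dim / (2. * kappa + np.sqrt(4. * kappa**2 + dim**2))
# Both are ~5e-5; b_wood loses digits to cancellation as kappa grows.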
def _parameter_control_dependencies(self, is_init):
  assertions = []

  # Check num_steps is a scalar that's at least 1.
  if is_init != tensor_util.is_ref(self.num_steps):
    num_steps = tf.convert_to_tensor(self.num_steps)
    num_steps_ = tf.get_static_value(num_steps)
    if num_steps_ is not None:
      if np.ndim(num_steps_) != 0:
        raise ValueError(
            '`num_steps` must be a scalar but it has rank {}'.format(
                np.ndim(num_steps_)))
      if num_steps_ < 1:
        raise ValueError('`num_steps` must be at least 1.')
    elif self.validate_args:
      message = '`num_steps` must be a scalar'
      assertions.append(
          assert_util.assert_rank_at_most(self.num_steps, 0,
                                          message=message))
      assertions.append(
          assert_util.assert_greater_equal(
              num_steps, 1,
              message='`num_steps` must be at least 1.'))

  # Check that the initial distribution has scalar events over the
  # integers.
  if is_init and not dtype_util.is_integer(self.initial_distribution.dtype):
    raise ValueError(
        '`initial_distribution.dtype` ({}) is not over integers'.format(
            dtype_util.name(self.initial_distribution.dtype)))

  if tensorshape_util.rank(
      self.initial_distribution.event_shape) is not None:
    if tensorshape_util.rank(self.initial_distribution.event_shape) != 0:
      raise ValueError(
          '`initial_distribution` must have scalar `event_dim`s')
  elif self.validate_args:
    assertions += [
        assert_util.assert_equal(
            ps.size(self.initial_distribution.event_shape_tensor()), 0,
            message='`initial_distribution` must have scalar `event_dim`s'),
    ]

  # Check that the transition distribution is over the integers.
  if (is_init and
      not dtype_util.is_integer(self.transition_distribution.dtype)):
    raise ValueError(
        '`transition_distribution.dtype` ({}) is not over integers'.format(
            dtype_util.name(self.transition_distribution.dtype)))

  # Check observations have non-scalar batches.
  # The graph version of this assertion is incorporated as
  # a control dependency of the transition/observation
  # compatibility test.
  if tensorshape_util.rank(self.observation_distribution.batch_shape) == 0:
    raise ValueError(
        "`observation_distribution` can't have scalar batches")

  # Check transitions have non-scalar batches.
  # The graph version of this assertion is incorporated as
  # a control dependency of the transition/observation
  # compatibility test.
  if tensorshape_util.rank(self.transition_distribution.batch_shape) == 0:
    raise ValueError(
        "`transition_distribution` can't have scalar batches")

  # Check compatibility of transition distribution and observation
  # distribution.
  tdbs = self.transition_distribution.batch_shape
  odbs = self.observation_distribution.batch_shape
  if (tensorshape_util.dims(tdbs) is not None and
      tf.compat.dimension_value(odbs[-1]) is not None):
    if (tf.compat.dimension_value(tdbs[-1]) !=
        tf.compat.dimension_value(odbs[-1])):
      raise ValueError(
          '`transition_distribution` and `observation_distribution` '
          'must agree on last dimension of batch size')
  elif self.validate_args:
    tdbs = self.transition_distribution.batch_shape_tensor()
    odbs = self.observation_distribution.batch_shape_tensor()
    transition_precondition = assert_util.assert_greater(
        ps.size(tdbs), 0,
        message=('`transition_distribution` can\'t have scalar '
                 'batches'))
    observation_precondition = assert_util.assert_greater(
        ps.size(odbs), 0,
        message=('`observation_distribution` can\'t have scalar '
                 'batches'))
    with tf.control_dependencies([
        transition_precondition,
        observation_precondition]):
      assertions += [
          assert_util.assert_equal(
              tdbs[-1],
              odbs[-1],
              message=('`transition_distribution` and '
                       '`observation_distribution` '
                       'must agree on last dimension of batch size'))]

  return assertions