def _maybe_assert_valid_sample(self, x, loc):
  """Checks the validity of a sample."""
  if not self.validate_args:
    return []
  return [
      assert_util.assert_greater_equal(
          x, loc, message='x is not in the support of the distribution')
  ]
def _prob(self, x):
  concentration = tf.convert_to_tensor(self.concentration)
  scale = tf.convert_to_tensor(self.scale)
  with tf.control_dependencies([
      assert_util.assert_greater_equal(
          x, scale,
          message='`x` is not in the support of the distribution.')
  ] if self.validate_args else []):

    def prob_on_support(z):
      return concentration * (scale**concentration) / (z**(concentration + 1))

    return self._extend_support(x, scale, prob_on_support, alt=0.)
def _assert_compatible_shape(self, index, sample_shape, samples):
  requested_shape, _ = self._expand_sample_shape_to_vector(
      tf.convert_to_tensor(sample_shape, dtype=tf.int32),
      name='requested_shape')
  actual_shape = prefer_static.shape(samples)
  actual_rank = prefer_static.rank_from_shape(actual_shape)
  requested_rank = prefer_static.rank_from_shape(requested_shape)

  # We test for two properties we expect of yielded distributions:
  # (1) The rank of the tensor of generated samples must be at least
  #     as large as the rank requested.
  # (2) The requested shape must be a prefix of the shape of the
  #     generated tensor of samples.
  #
  # We attempt to perform test (1) statically first.
  # We don't need to do this explicitly for test (2) because
  # `assert_equal` evaluates statically if it can.
  static_actual_rank = tf.get_static_value(actual_rank)
  static_requested_rank = tf.get_static_value(requested_rank)

  assertion_message = ('Samples yielded by distribution #{} are not '
                       'consistent with `sample_shape` passed to '
                       '`JointDistributionCoroutine` '
                       'distribution.'.format(index))

  # TODO: Remove this static check (b/138738650).
  if (static_actual_rank is not None and
      static_requested_rank is not None):
    # We're able to statically check the rank.
    if static_actual_rank < static_requested_rank:
      raise ValueError(assertion_message)
    else:
      control_dependencies = []
  else:
    # We're not able to statically check the rank.
    control_dependencies = [
        assert_util.assert_greater_equal(
            actual_rank, requested_rank, message=assertion_message)
    ]

  with tf.control_dependencies(control_dependencies):
    trimmed_actual_shape = actual_shape[:requested_rank]

  control_dependencies = [
      assert_util.assert_equal(
          requested_shape, trimmed_actual_shape, message=assertion_message)
  ]
  return control_dependencies
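# A tiny illustrative sketch (not library code, `is_consistent` is a
# hypothetical helper) of the two properties the assertions above enforce,
# on concrete shapes: a tensor of samples with shape [3, 2, 5] is consistent
# with `sample_shape=[3, 2]` (the requested shape is a prefix and the rank is
# large enough), but not with `sample_shape=[4]`.
import tensorflow as tf

def is_consistent(requested_shape, samples):
  requested = tf.convert_to_tensor(requested_shape, dtype=tf.int32)
  actual = tf.shape(samples)
  if tf.size(actual) < tf.size(requested):  # Property (1): enough rank.
    return False
  prefix = actual[:tf.size(requested)]      # Property (2): prefix match.
  return bool(tf.reduce_all(tf.equal(prefix, requested)))

samples = tf.zeros([3, 2, 5])
print(is_consistent([3, 2], samples))  # True
print(is_consistent([4], samples))     # False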
def _log_prob(self, x):
  concentration = tf.convert_to_tensor(self.concentration)
  scale = tf.convert_to_tensor(self.scale)
  with tf.control_dependencies([
      assert_util.assert_greater_equal(
          x, scale,
          message='`x` is not in the support of the distribution.')
  ] if self.validate_args else []):

    def log_prob_on_support(z):
      return (tf.math.log(concentration) +
              concentration * tf.math.log(scale) -
              (concentration + 1.) * tf.math.log(z))

    return self._extend_support(x, scale, log_prob_on_support, alt=-np.inf)
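# A minimal usage sketch (not part of the library source), assuming the
# standard `tfp.distributions.Pareto(concentration, scale)` constructor:
# `prob` and `log_prob` above implement the Pareto density
#   p(x) = concentration * scale**concentration / x**(concentration + 1)
# for x >= scale, and return 0 (or -inf for `log_prob`) below the support.
import tensorflow_probability as tfp

tfd = tfp.distributions
d = tfd.Pareto(concentration=3., scale=2.)
print(d.prob(4.))      # 3 * 2**3 / 4**4 = 0.09375
print(d.log_prob(4.))  # log(0.09375)
print(d.prob(1.))      # 0. -- below the support when validate_args=False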
def _prob(self, x):
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  peak = tf.convert_to_tensor(self.peak)

  if self.validate_args:
    with tf.control_dependencies([
        assert_util.assert_greater_equal(x, low),
        assert_util.assert_less_equal(x, high)
    ]):
      x = tf.identity(x)

  interval_length = high - low
  # This is the pdf when low <= x <= high. The pdf looks like a triangle,
  # so we have to treat each line segment separately.
  result_inside_interval = tf.where(
      (x >= low) & (x <= peak),
      # Line segment from (low, 0) to (peak, 2 / (high - low)).
      2. * (x - low) / (interval_length * (peak - low)),
      # Line segment from (peak, 2 / (high - low)) to (high, 0).
      2. * (high - x) / (interval_length * (high - peak)))

  return tf.where((x < low) | (x > high),
                  tf.zeros_like(x),
                  result_inside_interval)
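# A hedged usage sketch (assumes the public `tfp.distributions.Triangular`
# constructor): the piecewise-linear pdf above rises from (low, 0) to
# (peak, 2 / (high - low)), falls back to (high, 0), and is 0 outside
# [low, high].
import tensorflow_probability as tfp

tfd = tfp.distributions
d = tfd.Triangular(low=0., high=4., peak=1.)
print(d.prob(1.))   # Peak height: 2 / (high - low) = 0.5
print(d.prob(0.5))  # Rising segment: 2 * 0.5 / (4 * 1) = 0.25
print(d.prob(3.))   # Falling segment: 2 * (4 - 3) / (4 * 3) = 0.1667
print(d.prob(5.))   # Outside [low, high]: 0.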
def _parameter_control_dependencies(self, is_init):
  assertions = []

  logits = self._logits
  probs = self._probs
  param, name = (probs, 'probs') if logits is None else (logits, 'logits')

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    if not dtype_util.is_floating(param.dtype):
      raise TypeError(
          'Argument `{}` must have floating type.'.format(name))

    msg = 'Argument `{}` must have rank at least 1.'.format(name)
    shape_static = tensorshape_util.dims(param.shape)
    if shape_static is not None:
      if len(shape_static) < 1:
        raise ValueError(msg)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_rank_at_least(param, 1, message=msg))
      with tf.control_dependencies(assertions):
        param = tf.identity(param)

    msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
    msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
        name, tf.int32.max)
    event_size = shape_static[-1] if shape_static is not None else None
    if event_size is not None:
      if event_size < 1:
        raise ValueError(msg1)
      if event_size > tf.int32.max:
        raise ValueError(msg2)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_greater_equal(
              tf.shape(param)[-1], 1, message=msg1))
      # NOTE: For now, we leave out a runtime assertion that
      # `tf.shape(param)[-1] <= tf.int32.max`. An earlier `tf.shape` call
      # will fail before we get to this point.

  if not self.validate_args:
    assert not assertions  # Should never happen.
    return []

  if probs is not None:
    probs = param  # Reuse tensor conversion from above.
    if is_init != tensor_util.is_ref(probs):
      probs = tf.convert_to_tensor(probs)
    one = tf.ones([], dtype=probs.dtype)
    assertions.extend([
        assert_util.assert_non_negative(probs),
        assert_util.assert_less_equal(probs, one),
        assert_util.assert_near(
            tf.reduce_sum(probs, axis=-1), one,
            message='Argument `probs` must sum to 1.'),
    ])

  return assertions
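# A small sketch (not authoritative) of what these runtime checks buy a user
# of a `probs`-parameterized distribution such as `tfd.Categorical`: with
# `validate_args=True`, probabilities that are negative, exceed 1, or do not
# sum to 1 along the last axis are expected to raise an error instead of
# silently producing incorrect results.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
ok = tfd.Categorical(probs=[0.2, 0.3, 0.5], validate_args=True)
print(ok.sample(3))

try:
  bad = tfd.Categorical(probs=[0.2, 0.3, 0.6], validate_args=True)
  bad.sample()  # `probs` sums to 1.1, so an InvalidArgumentError is expected.
except tf.errors.InvalidArgumentError as e:
  print('validation caught:', e.message)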
def _validate_sample_arg(self, x):
  """Helper which validates sample arg, e.g., input to `log_prob`."""
  with tf.name_scope('validate_sample_arg'):
    x_ndims = (
        tf.rank(x) if tensorshape_util.rank(x.shape) is None else
        tensorshape_util.rank(x.shape))
    event_ndims = (
        tf.size(self.event_shape_tensor())
        if tensorshape_util.rank(self.event_shape) is None else
        tensorshape_util.rank(self.event_shape))
    batch_ndims = (
        tf.size(self._batch_shape_unexpanded)
        if tensorshape_util.rank(self.batch_shape) is None else
        tensorshape_util.rank(self.batch_shape))
    expected_batch_event_ndims = batch_ndims + event_ndims

    if (isinstance(x_ndims, int) and
        isinstance(expected_batch_event_ndims, int)):
      if x_ndims < expected_batch_event_ndims:
        raise NotImplementedError(
            'Broadcasting is not supported; too few batch and event dims '
            '(expected at least {}, saw {}).'.format(
                expected_batch_event_ndims, x_ndims))
      ndims_assertion = []
    elif self.validate_args:
      ndims_assertion = [
          assert_util.assert_greater_equal(
              x_ndims,
              expected_batch_event_ndims,
              message=('Broadcasting is not supported; too few '
                       'batch and event dims.'),
              name='assert_batch_and_event_ndims_large_enough'),
      ]
    else:
      # Nothing to assert when `validate_args` is False.
      ndims_assertion = []

    if (tensorshape_util.is_fully_defined(self.batch_shape) and
        tensorshape_util.is_fully_defined(self.event_shape)):
      expected_batch_event_shape = np.int32(
          tensorshape_util.concatenate(self.batch_shape, self.event_shape))
    else:
      expected_batch_event_shape = tf.concat([
          self.batch_shape_tensor(),
          self.event_shape_tensor(),
      ], axis=0)

    sample_ndims = x_ndims - expected_batch_event_ndims
    if isinstance(sample_ndims, int):
      sample_ndims = max(sample_ndims, 0)
    if (isinstance(sample_ndims, int) and
        tensorshape_util.is_fully_defined(x.shape[sample_ndims:])):
      actual_batch_event_shape = np.int32(x.shape[sample_ndims:])
    else:
      sample_ndims = tf.maximum(sample_ndims, 0)
      actual_batch_event_shape = tf.shape(x)[sample_ndims:]

    if (isinstance(expected_batch_event_shape, np.ndarray) and
        isinstance(actual_batch_event_shape, np.ndarray)):
      if any(expected_batch_event_shape != actual_batch_event_shape):
        raise NotImplementedError(
            'Broadcasting is not supported; '
            'unexpected batch and event shape '
            '(expected {}, saw {}).'.format(
                expected_batch_event_shape, actual_batch_event_shape))
      # We need to set the final runtime-assertions to `ndims_assertion` since
      # it's possible this assertion was created. We could add a condition to
      # only do so if `self.validate_args == True`, however this is redundant
      # as `ndims_assertion` already encodes this information.
      runtime_assertions = ndims_assertion
    elif self.validate_args:
      # We need to make the `ndims_assertion` a control dep because otherwise
      # TF itself might raise an exception owing to this assertion being
      # ill-defined, i.e., one cannot even compare Tensors of different rank.
      with tf.control_dependencies(ndims_assertion):
        shape_assertion = assert_util.assert_equal(
            expected_batch_event_shape,
            actual_batch_event_shape,
            message=('Broadcasting is not supported; '
                     'unexpected batch and event shape.'),
            name='assert_batch_and_event_shape_same')
      runtime_assertions = [shape_assertion]
    else:
      runtime_assertions = []

    return runtime_assertions
def _sample_n(self, n, seed=None):
  seed = SeedStream(seed, salt='vom_mises_fisher')
  # The sampling strategy relies on the fact that vMF variates are symmetric
  # about the mean direction. Accordingly, if we have a sampling strategy for
  # the away-from-mean angle, then we can uniformly sample the remaining
  # dimensions on the S^{dim-2} sphere, and rotate these samples from a
  # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
  #
  # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
  # von Mises distributed `x` value in [-1, 1], then uniformly select what
  # amounts to an "up" or "down" additional degree of freedom after unit
  # normalizing, followed by a final rotation to the desired mean direction
  # from a basis of (1, 0).
  #
  # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
  # unit sphere over which the distribution is uniform, in particular the
  # circle where x = \hat{x} intersects the unit sphere. We pick a point on
  # that circle, then rotate to the desired mean direction from a basis of
  # (1, 0, 0).
  event_dim = (
      tf.compat.dimension_value(self.event_shape[0]) or
      self._event_shape_tensor()[0])

  sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
  dim = tf.cast(event_dim - 1, self.dtype)
  if event_dim == 3:
    samples_dim0 = self._sample_3d(n, seed=seed)
  else:
    # Wood'94 provides a rejection algorithm to sample the x coordinate.
    # Wood'94 definition of b:
    #   b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
    # https://stats.stackexchange.com/questions/156729 suggests:
    b = dim / (2 * self.concentration +
               tf.sqrt(4 * self.concentration**2 + dim**2))
    # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
    #     https://github.com/nicola-decao/s-vae-tf/
    x = (1 - b) / (1 + b)
    c = self.concentration * x + dim * tf.math.log1p(-x**2)
    beta = beta_lib.Beta(dim / 2, dim / 2)

    def cond_fn(w, should_continue):
      del w
      return tf.reduce_any(should_continue)

    def body_fn(w, should_continue):
      z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
      # set_shape needed here because of b/139013403
      z.set_shape(w.shape)
      w = tf.where(should_continue,
                   (1 - (1 + b) * z) / (1 - (1 - b) * z),
                   w)
      w = tf.debugging.check_numerics(w, 'w')
      unif = tf.random.uniform(
          sample_batch_shape, seed=seed(), dtype=self.dtype)
      # set_shape needed here because of b/139013403
      unif.set_shape(w.shape)
      should_continue = tf.logical_and(
          should_continue,
          self.concentration * w + dim * tf.math.log1p(-x * w) - c <
          tf.math.log(unif))
      return w, should_continue

    w = tf.zeros(sample_batch_shape, dtype=self.dtype)
    should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
    samples_dim0 = tf.while_loop(
        cond=cond_fn, body=body_fn, loop_vars=(w, should_continue))[0]
    samples_dim0 = samples_dim0[..., tf.newaxis]
  if not self._allow_nan_stats:
    # Verify samples are w/in -1, 1, with useful error output tensors (top
    # value rather than all values).
    with tf.control_dependencies([
        assert_util.assert_less_equal(
            samples_dim0,
            dtype_util.as_numpy_dtype(self.dtype)(1.01),
            data=[tf.math.top_k(tf.reshape(samples_dim0, [-1]))[0]]),
        assert_util.assert_greater_equal(
            samples_dim0,
            dtype_util.as_numpy_dtype(self.dtype)(-1.01),
            data=[-tf.math.top_k(tf.reshape(-samples_dim0, [-1]))[0]])
    ]):
      samples_dim0 = tf.identity(samples_dim0)

  samples_otherdims_shape = tf.concat(
      [sample_batch_shape, [event_dim - 1]], axis=0)
  unit_otherdims = tf.math.l2_normalize(
      tf.random.normal(
          samples_otherdims_shape, seed=seed(), dtype=self.dtype),
      axis=-1)
  samples = tf.concat([
      samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
      tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
  ], axis=-1)
  samples = tf.math.l2_normalize(samples, axis=-1)
  if not self._allow_nan_stats:
    samples = tf.debugging.check_numerics(samples, 'samples')

  # Runtime assert that samples are unit length.
  if not self._allow_nan_stats:
    worst, idx = tf.math.top_k(
        tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
    with tf.control_dependencies([
        assert_util.assert_near(
            dtype_util.as_numpy_dtype(self.dtype)(0),
            worst,
            data=[
                worst, idx,
                tf.gather(tf.reshape(samples, [-1, event_dim]), idx)
            ],
            atol=1e-4,
            summarize=100)
    ]):
      samples = tf.identity(samples)

  # The samples generated are symmetric around a mode at (1, 0, 0, ..., 0).
  # Now, we move the mode to `self.mean_direction` using a rotation matrix.
  if not self._allow_nan_stats:
    # Assert that the basis vector rotates to the mean direction, as expected.
    basis = tf.cast(
        tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0), self.dtype)
    with tf.control_dependencies([
        assert_util.assert_less(
            tf.linalg.norm(
                self._rotate(basis) - self.mean_direction, axis=-1),
            dtype_util.as_numpy_dtype(self.dtype)(1e-5))
    ]):
      return self._rotate(samples)
  return self._rotate(samples)
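# A brief usage sketch (assuming the public `tfp.distributions.VonMisesFisher`
# API): the rejection sampler above produces unit-norm samples whose first
# coordinate, in the rotated frame, follows the Wood '94 marginal; from the
# caller's point of view the samples simply live on the unit sphere and
# cluster around `mean_direction` as `concentration` grows.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
vmf = tfd.VonMisesFisher(
    mean_direction=tf.math.l2_normalize([1., 1., 0., 0.], axis=-1),
    concentration=10.)
s = vmf.sample(5, seed=42)
print(tf.linalg.norm(s, axis=-1))  # ~1 for every sample (unit sphere).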
def __init__(self,
             initial_distribution,
             transition_distribution,
             observation_distribution,
             num_steps,
             validate_args=False,
             allow_nan_stats=True,
             name="HiddenMarkovModel"):
  """Initialize hidden Markov model.

  Args:
    initial_distribution: A `Categorical`-like instance.
      Determines probability of first hidden state in Markov chain.
      The number of categories must match the number of categories of
      `transition_distribution` as well as both the rightmost batch
      dimension of `transition_distribution` and the rightmost batch
      dimension of `observation_distribution`.
    transition_distribution: A `Categorical`-like instance.
      The rightmost batch dimension indexes the probability distribution
      of each hidden state conditioned on the previous hidden state.
    observation_distribution: A `tfp.distributions.Distribution`-like
      instance. The rightmost batch dimension indexes the distribution
      of each observation conditioned on the corresponding hidden state.
    num_steps: The number of steps taken in Markov chain. A Python `int`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
      Default value: `True`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: "HiddenMarkovModel".

  Raises:
    ValueError: if `num_steps` is not at least 1.
    ValueError: if `initial_distribution` does not have scalar `event_shape`.
    ValueError: if `transition_distribution` does not have scalar
      `event_shape`.
    ValueError: if `transition_distribution` and `observation_distribution`
      are fully defined but don't have matching rightmost dimension.
""" parameters = dict(locals()) # pylint: disable=protected-access with tf.name_scope(name) as name: self._runtime_assertions = [] # pylint: enable=protected-access num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps") if validate_args: self._runtime_assertions += [ assert_util.assert_equal( tf.rank(num_steps), 0, message="`num_steps` must be a scalar") ] self._runtime_assertions += [ assert_util.assert_greater_equal( num_steps, 1, message="`num_steps` must be at least 1.") ] self._initial_distribution = initial_distribution self._observation_distribution = observation_distribution self._transition_distribution = transition_distribution if (initial_distribution.event_shape is not None and tensorshape_util.rank( initial_distribution.event_shape) != 0): raise ValueError( "`initial_distribution` must have scalar `event_dim`s") elif validate_args: self._runtime_assertions += [ assert_util.assert_equal( tf.shape(initial_distribution.event_shape_tensor())[0], 0, message="`initial_distribution` must have scalar" "`event_dim`s") ] if (transition_distribution.event_shape is not None and tensorshape_util.rank( transition_distribution.event_shape) != 0): raise ValueError( "`transition_distribution` must have scalar `event_dim`s") elif validate_args: self._runtime_assertions += [ assert_util.assert_equal( tf.shape( transition_distribution.event_shape_tensor())[0], 0, message="`transition_distribution` must have scalar" "`event_dim`s") ] if (transition_distribution.batch_shape is not None and tensorshape_util.rank( transition_distribution.batch_shape) == 0): raise ValueError( "`transition_distribution` can't have scalar batches") elif validate_args: self._runtime_assertions += [ assert_util.assert_greater( tf.size(transition_distribution.batch_shape_tensor()), 0, message="`transition_distribution` can't have scalar " "batches") ] if (observation_distribution.batch_shape is not None and tensorshape_util.rank( observation_distribution.batch_shape) == 0): raise ValueError( "`observation_distribution` can't have scalar batches") elif validate_args: self._runtime_assertions += [ assert_util.assert_greater( tf.size(observation_distribution.batch_shape_tensor()), 0, message="`observation_distribution` can't have scalar " "batches") ] # Infer number of hidden states and check consistency # between transitions and observations with tf.control_dependencies(self._runtime_assertions): self._num_states = ( (transition_distribution.batch_shape and transition_distribution.batch_shape[-1]) or transition_distribution.batch_shape_tensor()[-1]) observation_states = ( (observation_distribution.batch_shape and observation_distribution.batch_shape[-1]) or observation_distribution.batch_shape_tensor()[-1]) if (tf.is_tensor(self._num_states) or tf.is_tensor(observation_states)): if validate_args: self._runtime_assertions += [ assert_util.assert_equal( self._num_states, observation_states, message="`transition_distribution` and " "`observation_distribution` must agree on " "last dimension of batch size") ] elif self._num_states != observation_states: raise ValueError("`transition_distribution` and " "`observation_distribution` must agree on " "last dimension of batch size") self._log_init = _extract_log_probs(self._num_states, initial_distribution) self._log_trans = _extract_log_probs(self._num_states, transition_distribution) self._num_steps = num_steps self._num_states = tf.shape(self._log_init)[-1] self._underlying_event_rank = tf.size( self._observation_distribution.event_shape_tensor()) num_steps_ = 
tf.get_static_value(num_steps) if num_steps_ is not None: self.static_event_shape = tf.TensorShape([ num_steps_ ]).concatenate(self._observation_distribution.event_shape) else: self.static_event_shape = None with tf.control_dependencies(self._runtime_assertions): self.static_batch_shape = tf.broadcast_static_shape( self._initial_distribution.batch_shape, tf.broadcast_static_shape( self._transition_distribution.batch_shape[:-1], self._observation_distribution.batch_shape[:-1])) # pylint: disable=protected-access super(HiddenMarkovModel, self).__init__( dtype=self._observation_distribution.dtype, reparameterization_type=reparameterization.NOT_REPARAMETERIZED, validate_args=validate_args, allow_nan_stats=allow_nan_stats, parameters=parameters, name=name) # pylint: enable=protected-access self._parameters = parameters
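# A minimal construction sketch, mirroring the `tfp.distributions`
# documentation example for `HiddenMarkovModel` (a cold/hot weather chain
# emitting temperatures); the batch-shape requirements checked above are what
# make the three component distributions line up on the number of hidden
# states (2 here).
import tensorflow_probability as tfp

tfd = tfp.distributions
hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]]),
    observation_distribution=tfd.Normal(loc=[0., 15.], scale=[5., 10.]),
    num_steps=7,
    validate_args=True)
print(hmm.mean())          # Expected observation at each of the 7 steps.
print(hmm.sample().shape)  # Event shape is [num_steps] = [7].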