def _kl_blockwise_blockwise(b0, b1, name=None):
  """Calculate the batched KL divergence KL(b0 || b1) with b0 and b1 Blockwise distributions.

  Args:
    b0: instance of a Blockwise distribution object.
    b1: instance of a Blockwise distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_blockwise_blockwise".

  Returns:
    kl_blockwise_blockwise: `Tensor`. The batchwise KL(b0 || b1).
  """
  if len(b0.distributions) != len(b1.distributions):
    raise ValueError(
        'Can only compute KL divergence between Blockwise distributions with '
        'the same number of component distributions.')

  # We also need to check that the event shapes match for each one.
  b0_event_sizes = [_event_size(d) for d in b0.distributions]
  b1_event_sizes = [_event_size(d) for d in b1.distributions]

  assertions = []
  message = ('Can only compute KL divergence between Blockwise distributions '
             'with the same pairwise event shapes.')

  if (all(isinstance(event_size, int) for event_size in b0_event_sizes) and
      all(isinstance(event_size, int) for event_size in b1_event_sizes)):
    if b0_event_sizes != b1_event_sizes:
      raise ValueError(message)
  else:
    if b0.validate_args or b1.validate_args:
      assertions.extend(
          assert_util.assert_equal(  # pylint: disable=g-complex-comprehension
              e1, e2, message=message)
          for e1, e2 in zip(b0_event_sizes, b1_event_sizes))

  with tf.name_scope(name or 'kl_blockwise_blockwise'):
    with tf.control_dependencies(assertions):
      return sum([
          kullback_leibler.kl_divergence(d1, d2) for d1, d2 in zip(
              b0.distributions, b1.distributions)])
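# Illustrative usage sketch (not part of the library source): the Blockwise
# KL above is just the sum of per-block KLs. Assumes TensorFlow Probability
# is importable; `tfd.Blockwise`, `tfd.Normal`, `tfd.MultivariateNormalDiag`,
# and `tfd.kl_divergence` are public TFP API.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

b0 = tfd.Blockwise([tfd.Normal(0., 1.), tfd.MultivariateNormalDiag(tf.zeros(3))])
b1 = tfd.Blockwise([tfd.Normal(1., 2.), tfd.MultivariateNormalDiag(tf.ones(3))])
# Equals KL(Normal || Normal) + KL(MVNDiag || MVNDiag).
kl = tfd.kl_divergence(b0, b1)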
def _kl_sample(a, b, name='kl_sample'):
  """Batched KL divergence `KL(a || b)` for Sample distributions.

  We can leverage the fact that:

  ```
  KL(Sample(a) || Sample(b)) = sum(KL(a || b))
  ```

  where the sum is over the `sample_shape` dims.

  Args:
    a: Instance of `Sample` distribution.
    b: Instance of `Sample` distribution.
    name: (optional) name to use for created ops.
      Default value: `"kl_sample"`.

  Returns:
    kldiv: Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the `sample_shape` of `a` and `b` don't match.
  """
  assertions = []
  a_ss = tf.get_static_value(a.sample_shape)
  b_ss = tf.get_static_value(b.sample_shape)
  msg = '`a.sample_shape` must be identical to `b.sample_shape`.'
  if a_ss is not None and b_ss is not None:
    if not np.array_equal(a_ss, b_ss):
      raise ValueError(msg)
  elif a.validate_args or b.validate_args:
    assertions.append(
        assert_util.assert_equal(a.sample_shape, b.sample_shape, message=msg))
  with tf.control_dependencies(assertions):
    kl = kullback_leibler.kl_divergence(a.distribution, b.distribution,
                                        name=name)
    n = ps.reduce_prod(a.sample_shape)
    return tf.cast(x=n, dtype=kl.dtype) * kl
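# Illustrative usage sketch (not library source): KL between two `tfd.Sample`
# wrappers is the base KL scaled by the number of i.i.d. draws. Names used
# (`tfd.Sample`, `tfd.Normal`, `tfd.kl_divergence`) are public TFP API.
import tensorflow_probability as tfp
tfd = tfp.distributions

a = tfd.Sample(tfd.Normal(0., 1.), sample_shape=[3])
b = tfd.Sample(tfd.Normal(1., 2.), sample_shape=[3])
kl = tfd.kl_divergence(a, b)  # == 3 * KL(Normal(0, 1) || Normal(1, 2))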
def _assert_batch_shape_matches_weights(distribution, weights_shape, diststr):
  """Checks that all parts of a distribution have the expected batch shape."""
  shapes = [weights_shape] + tf.nest.flatten(distribution.batch_shape_tensor())
  static_shapes = [tf.get_static_value(ps.convert_to_shape_tensor(s))
                   for s in shapes]
  static_shapes_not_none = [s for s in static_shapes if s is not None]
  static_shapes_match = all([
      np.all(a == b)  # Also need to check for rank mismatch (below).
      for (a, b) in zip(static_shapes_not_none[1:],
                        static_shapes_not_none[:-1])])

  # Build a separate list of static ranks, since rank is often static even when
  # shape is not.
  ranks = [ps.rank_from_shape(s) for s in shapes]
  static_ranks = [int(r) for r in ranks if not tf.is_tensor(r)]
  static_ranks_match = all([a == b for (a, b) in zip(static_ranks[1:],
                                                     static_ranks[:-1])])

  msg = (
      "The {diststr} distribution's batch shape does not match the particle "
      "weights; a correct {diststr} distribution must return an independent "
      "log-density for each particle. You may be "
      "creating a joint distribution in which some parts do not depend on the "
      "previous particles, and/or you are creating an autobatched joint "
      "distribution without setting `batch_ndims`.".format(diststr=diststr))
  if not (static_ranks_match and static_shapes_match):
    raise ValueError(msg + ' ' +
                     'Weights have shape {}, but the distribution has batch '
                     'shape {}.'.format(weights_shape,
                                        distribution.batch_shape))

  assertions = []
  if distribution.validate_args and any([s is None for s in static_shapes]):
    assertions = [assert_util.assert_equal(a, b, message=msg)
                  for a, b in zip(shapes[1:], shapes[:-1])]
  return assertions
def _sample_control_dependencies(self, samples):
  inner_sample_dim = samples.shape[-1]
  shape_msg = ('Samples must have innermost dimension matching that of '
               '`self.dimension`. Found {}, expected {}'.format(
                   inner_sample_dim, self.dimension))
  if inner_sample_dim is not None:
    if self.dimension != inner_sample_dim:
      raise ValueError(shape_msg)

  assertions = []
  if not self.validate_args:
    return assertions

  assertions.append(assert_util.assert_near(
      tf.cast(1., dtype=self.dtype),
      tf.linalg.norm(samples, axis=-1),
      message='Samples must be unit length.'))
  assertions.append(assert_util.assert_equal(
      tf.shape(samples)[-1:],
      self.dimension,
      message=shape_msg))
  return assertions
def _sample_control_dependencies(self, x):
  assertions = []
  if not self.validate_args:
    return assertions
  loc = tf.convert_to_tensor(self.loc)
  scale = tf.convert_to_tensor(self.scale)
  concentration = tf.convert_to_tensor(self.concentration)
  assertions.append(assert_util.assert_greater_equal(
      x, loc, message='Sample must be greater than or equal to `loc`.'))
  assertions.append(assert_util.assert_equal(
      tf.logical_or(tf.greater_equal(concentration, 0),
                    tf.less_equal(x, loc - scale / concentration)),
      True,
      message=('If `concentration < 0`, sample must be less than or '
               'equal to `loc - scale / concentration`.'),
      summarize=100))
  return assertions
def _parameter_control_dependencies(self, is_init):
  if tensorshape_util.is_fully_defined(self.distribution.batch_shape):
    if self.to_shape is not None:
      static_to_shape = tf.get_static_value(self.to_shape)
      if static_to_shape is not None:
        bcast_shp = tf.broadcast_static_shape(
            tf.TensorShape(static_to_shape),
            self.distribution.batch_shape)
        if bcast_shp != static_to_shape:
          raise ValueError(f'Argument `to_shape` ({static_to_shape}) '
                           'is incompatible with underlying distribution '
                           f'batch shape ({self.distribution.batch_shape}).')
    else:
      static_with_shape = tf.get_static_value(self.with_shape)
      if static_with_shape is not None:
        tf.broadcast_static_shape(  # Ensure compatible.
            tf.TensorShape(static_with_shape),
            self.distribution.batch_shape)

  underlying = self.distribution._parameter_control_dependencies(is_init)  # pylint: disable=protected-access
  if not self.validate_args:
    return underlying

  checks = []
  if self.to_shape is not None:
    if tensor_util.is_ref(self.to_shape) != is_init:
      checks += [assert_util.assert_equal(
          self.to_shape,
          ps.broadcast_shape(self.distribution.batch_shape_tensor(),
                             self.to_shape),
          message='Argument `to_shape` is incompatible with underlying '
                  'distribution batch shape.')]
  else:
    if tensor_util.is_ref(self.with_shape) != is_init:
      checks += [tf.broadcast_dynamic_shape(
          self.distribution.batch_shape_tensor(), self.with_shape)]

  return tuple(checks) + tuple(underlying)
def _kl_power_uniform_spherical(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b).

  Args:
    a: instance of a PowerSpherical distribution object.
    b: instance of a SphericalUniform distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_power_uniform_spherical".

  Returns:
    Batchwise KL(a || b)

  Raises:
    ValueError: If the two distributions are over spheres of
      different dimensions.

  #### References

  [1] Nicola de Cao, Wilker Aziz. The Power Spherical distribution.
      https://arxiv.org/abs/2006.04437.
  """
  with tf.name_scope(name or 'kl_power_uniform_spherical'):
    msg = (
        'Can not compute the KL divergence between a `PowerSpherical` and '
        '`SphericalUniform` of different dimensions.')
    deps = []
    if a.event_shape[-1] is not None:
      if a.event_shape[-1] != b.dimension:
        raise ValueError(
            (msg + ' Got {} vs. {}').format(a.event_shape[-1], b.dimension))
    elif a.validate_args or b.validate_args:
      deps += [
          assert_util.assert_equal(a.event_shape_tensor()[-1],
                                   b.dimension,
                                   message=msg)
      ]

    with tf.control_dependencies(deps):
      return b.entropy() - a.entropy()
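# Illustrative usage sketch (not library source): for this pair the KL has
# the closed form `b.entropy() - a.entropy()`. `tfd.PowerSpherical` and
# `tfd.SphericalUniform` are public TFP API.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

a = tfd.PowerSpherical(mean_direction=tf.constant([0., 1.]),
                       concentration=tf.constant(5.))
b = tfd.SphericalUniform(dimension=2)
kl = tfd.kl_divergence(a, b)  # == b.entropy() - a.entropy()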
def _inverse(self, y):
  ndims = ps.rank(y)
  indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1])
  num_left, num_right = ps.unstack(self.paddings, num=2, axis=-1)
  x = tf.slice(
      y,
      begin=ps.tensor_scatter_nd_update(
          ps.zeros(ndims, dtype=tf.int32), indices, num_left),
      size=ps.tensor_scatter_nd_sub(
          ps.shape(y), indices, num_left + num_right))
  if not self.validate_args:
    return x
  assertions = [
      assert_util.assert_equal(
          self._forward(x), y,
          message=('Argument `y` to `inverse` was not padded with '
                   '`constant_values`.')),
  ]
  with tf.control_dependencies(assertions):
    return tf.identity(x)
def _sample_control_dependencies(self, samples): """Check samples for proper shape and whether samples are unit vectors.""" inner_sample_dim = samples.shape[-1] event_size = self.event_shape[-1] shape_msg = ('Samples must have innermost dimension matching that of ' '`self.mean_direction`.') if event_size is not None and inner_sample_dim is not None: if event_size != inner_sample_dim: raise ValueError(shape_msg) assertions = [] if not self.validate_args: return assertions assertions.append(assert_util.assert_near( 1., tf.linalg.norm(samples, axis=-1), message='Samples must be unit length.')) assertions.append(assert_util.assert_equal( tf.shape(samples)[-1:], self.event_shape_tensor(), message=shape_msg)) return assertions
def vector_size_to_square_matrix_size(d, validate_args, name=None):
  """Convert a vector size to a matrix size."""
  if isinstance(d, (float, int, np.generic, np.ndarray)):
    n = (-1 + np.sqrt(1 + 8 * d)) / 2.
    if float(int(n)) != n:
      raise ValueError('Vector length {} is not a triangular number.'.format(d))
    return int(n)
  else:
    with tf.name_scope(name or 'vector_size_to_square_matrix_size') as name:
      n = (-1. + tf.sqrt(1 + 8. * tf.cast(d, dtype=tf.float32))) / 2.
      if validate_args:
        with tf.control_dependencies([
            assert_util.assert_equal(
                tf.cast(tf.cast(n, dtype=tf.int32), dtype=tf.float32),
                n,
                data=[d],
                message='Vector length is not a triangular number')
        ]):
          n = tf.identity(n)
      return tf.cast(n, d.dtype)
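# Worked example (illustrative): a length-`d` vector fills the lower triangle
# of an `n x n` matrix when d = n(n+1)/2, so n = (-1 + sqrt(1 + 8d)) / 2.
# For d = 6: n = (-1 + sqrt(49)) / 2 = 3, i.e. a 3x3 lower-triangular matrix.
n = vector_size_to_square_matrix_size(6, validate_args=False)  # == 3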
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    # Avoid computing intermediates needed to construct the assertions.
    return []
  assertions = []
  if is_init != tensor_util.is_ref(self._batch_shape_unexpanded):
    implicit_dim_mask = prefer_static.equal(self._batch_shape_unexpanded, -1)
    assertions.append(assert_util.assert_rank(
        self._batch_shape_unexpanded, 1,
        message='New shape must be a vector.'))
    assertions.append(assert_util.assert_less_equal(
        tf.math.count_nonzero(implicit_dim_mask, dtype=tf.int32), 1,
        message='At most one dimension can be unknown.'))
    assertions.append(assert_util.assert_non_negative(
        self._batch_shape_unexpanded + 1,
        message='Shape elements must be >=-1.'))
    # Check that the old and new shapes are the same size.
    expanded_new_shape, original_size = self._calculate_new_shape()
    new_size = prefer_static.reduce_prod(expanded_new_shape)
    assertions.append(assert_util.assert_equal(
        new_size, tf.cast(original_size, new_size.dtype),
        message='Shape sizes do not match.'))
  return assertions
def _maybe_warn_increased_dof(self, component_name, component_ldj,
                              increased_dof):
  """Warns or raises when `increased_dof` is True."""
  # Short-circuit when the component LDJ is statically zero.
  if (tf.get_static_value(tf.rank(component_ldj)) == 0
      and tf.get_static_value(component_ldj) == 0):
    return

  # Short-circuit when increased_dof is statically False.
  increased_dof_ = tf.get_static_value(increased_dof)
  if increased_dof_ is False:  # pylint: disable=g-bool-id-comparison
    return

  error_message = (
      'Nested component "{}" in composition "{}" operates on inputs '
      'with increased degrees of freedom. This may result in an '
      'incorrect log_det_jacobian.'
      ).format(component_name, self.name)

  # When validate_args is True, we raise on increased DoF.
  if self._validate_args:
    if increased_dof_:
      raise ValueError(error_message)
    return assert_util.assert_equal(False, increased_dof, error_message)

  if (not tf.executing_eagerly() and
      control_flow_util.GraphOrParentsInXlaContext(tf1.get_default_graph())):
    return  # No StringFormat or Print ops in XLA.

  # Otherwise, we print a warning and continue.
  return ps.cond(
      pred=increased_dof,
      false_fn=tf.no_op,
      true_fn=lambda: tf.print(  # pylint: disable=g-long-lambda
          'WARNING: ' + error_message,
          output_stream=sys.stderr))
def _maybe_validate_rightmost_transposed_ndims(
    initial_rightmost_transposed_ndims,
    rightmost_transposed_ndims, validate_args, name=None):
  """Checks that `rightmost_transposed_ndims` is valid."""
  with tf.name_scope(name or 'maybe_validate_rightmost_transposed_ndims'):
    assertions = []
    if tensorshape_util.rank(rightmost_transposed_ndims.shape) is not None:
      if tensorshape_util.rank(rightmost_transposed_ndims.shape) != 0:
        raise ValueError('`rightmost_transposed_ndims` must be a scalar, '
                         'saw rank: {}.'.format(
                             tensorshape_util.rank(
                                 rightmost_transposed_ndims.shape)))
    elif validate_args:
      assertions += [
          assert_util.assert_rank(rightmost_transposed_ndims, 0),
          assert_util.assert_equal(
              rightmost_transposed_ndims,
              initial_rightmost_transposed_ndims,
              message='`rightmost_transposed_ndims` must not change '
                      'from the value set when the `Transpose` '
                      'bijector was constructed.')]

    rightmost_transposed_ndims_ = tf.get_static_value(
        rightmost_transposed_ndims)
    msg = '`rightmost_transposed_ndims` must be non-negative.'
    if rightmost_transposed_ndims_ is not None:
      if rightmost_transposed_ndims_ < 0:
        raise ValueError(
            msg[:-1] + ', saw: {}.'.format(rightmost_transposed_ndims_))
    elif validate_args:
      assertions += [
          assert_util.assert_non_negative(
              rightmost_transposed_ndims, message=msg)
      ]

    return assertions
def _lu_solve_assertions(lower_upper, perm, rhs, validate_args):
  """Returns list of assertions related to `lu_solve` assumptions."""
  assertions = _lu_reconstruct_assertions(lower_upper, perm, validate_args)

  message = 'Input `rhs` must have at least 2 dimensions.'
  if rhs.shape.ndims is not None:
    if rhs.shape.ndims < 2:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        assert_util.assert_rank_at_least(rhs, rank=2, message=message))

  message = '`lower_upper.shape[-1]` must equal `rhs.shape[-2]`.'
  if (tf.compat.dimension_value(lower_upper.shape[-1]) is not None and
      tf.compat.dimension_value(rhs.shape[-2]) is not None):
    if lower_upper.shape[-1] != rhs.shape[-2]:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        assert_util.assert_equal(
            tf.shape(lower_upper)[-1], tf.shape(rhs)[-2], message=message))

  return assertions
def _parameter_control_dependencies(self, is_init):
  assertions = []
  if is_init != tensor_util.is_ref(self.permutation):
    if not dtype_util.is_integer(self.permutation.dtype):
      raise TypeError('permutation.dtype ({}) should be `int`-like.'.format(
          dtype_util.name(self.permutation.dtype)))

    p = tf.get_static_value(self.permutation)
    if p is not None:
      if set(p) != set(np.arange(p.size)):
        raise ValueError('Permutation over `d` must contain exactly one of '
                         'each of `{0, 1, ..., d-1}`.')
    if self.validate_args:
      p = tf.sort(self.permutation, axis=-1)
      assertions.append(
          assert_util.assert_equal(
              p, tf.range(tf.shape(p)[-1]),
              message=('Permutation over `d` must contain exactly one of '
                       'each of `{0, 1, ..., d-1}`.')))

  return assertions
def _parameter_control_dependencies(self, is_init):
  assertions = super(Wishart, self)._parameter_control_dependencies(is_init)

  if not self.validate_args:
    assert not assertions
    return []

  if self._scale_full is None:
    if is_init != tensor_util.is_ref(self._scale_tril):
      shape = prefer_static.shape(self._scale_tril)
      assertions.extend(
          [assert_util.assert_positive(
              tf.linalg.diag_part(self._scale_tril),
              message='`scale_tril` must be positive definite.'),
           assert_util.assert_equal(
               shape[-1],
               shape[-2],
               message='`scale_tril` must be square.')])
  else:
    if is_init != tensor_util.is_ref(self._scale_full):
      assertions.append(distribution_util.assert_symmetric(self._scale_full))
  return assertions
def _maybe_assert_valid_x(self, x, loc=None, scale=None, concentration=None):
  if not self.validate_args:
    return []
  loc = tf.convert_to_tensor(self.loc) if loc is None else loc
  scale = tf.convert_to_tensor(self.scale) if scale is None else scale
  concentration = (tf.convert_to_tensor(self.concentration)
                   if concentration is None else concentration)
  # The support of this bijector depends on the sign of concentration.
  is_in_bounds = tf.where(
      concentration > 0.,
      x >= loc - scale / concentration,
      x <= loc - scale / concentration)
  # For concentration 0, the domain is the whole line.
  is_in_bounds = is_in_bounds | tf.math.equal(concentration, 0.)

  return [
      assert_util.assert_equal(
          is_in_bounds,
          True,
          message='Forward transformation input must be inside domain.')
  ]
def prepare_tuple_argument(arg, n, arg_name, validate_args=False):
  """Helper which processes `Tensor`s to tuples in standard form."""
  arg_size = ps.size(arg)
  arg_size_ = tf.get_static_value(arg_size)
  assertions = []
  if arg_size_ is not None:
    if arg_size_ not in (1, n):
      raise ValueError('The size of `{}` must be equal to `1` or to the rank '
                       'of the convolution (={}). Saw size = {}'.format(
                           arg_name, n, arg_size_))
  elif validate_args:
    assertions.append(assert_util.assert_equal(
        ps.logical_or(arg_size == 1, arg_size == n),
        True,
        message=('The size of `{}` must be equal to `1` or to the rank of the '
                 'convolution (={})'.format(arg_name, n))))

  with tf.control_dependencies(assertions):
    arg = ps.broadcast_to(arg, shape=[n])
    arg = ps.unstack(arg, num=n)
    return arg
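# Illustrative sketch (not library source): a scalar argument is broadcast to
# one value per convolution dimension, e.g. a single stride for a rank-2
# convolution becomes a pair.
strides = prepare_tuple_argument(2, n=2, arg_name='strides')
# `strides` is now a list of two (scalar) entries: [2, 2].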
def maybe_check_quadrature_param(param, name, validate_args):
  """Helper which checks validity of `loc` and `scale` init args."""
  with tf.name_scope("check_" + name):
    assertions = []
    if tensorshape_util.rank(param.shape) is not None:
      if tensorshape_util.rank(param.shape) == 0:
        raise ValueError("Mixing params must be a (batch of) vector; "
                         "{}.rank={} is not at least one.".format(
                             name, tensorshape_util.rank(param.shape)))
    elif validate_args:
      assertions.append(
          assert_util.assert_rank_at_least(
              param, 1,
              message=("Mixing params must be a (batch of) vector; "
                       "{}.rank is not at least one.".format(name))))

    # TODO(jvdillon): Remove once we support k-mixtures.
    if tensorshape_util.with_rank_at_least(param.shape, 1)[-1] is not None:
      if tf.compat.dimension_value(param.shape[-1]) != 1:
        raise NotImplementedError("Currently only bimixtures are supported; "
                                  "{}.shape[-1]={} is not 1.".format(
                                      name,
                                      tf.compat.dimension_value(
                                          param.shape[-1])))
    elif validate_args:
      assertions.append(
          assert_util.assert_equal(
              tf.shape(input=param)[-1], 1,
              message=("Currently only bimixtures are supported; "
                       "{}.shape[-1] is not 1.".format(name))))

    if assertions:
      return distribution_util.with_dependencies(assertions, param)
    return param
def _maybe_validate_shape_override(self, override_shape, base_is_scalar_fn,
                                   static_base_shape, is_init):
  """Helper which ensures override batch/event_shape are valid."""
  assertions = []
  concretized_shape = None

  # Check valid dtype
  if is_init:  # No xor check because `dtype` cannot change.
    dtype_ = override_shape.dtype
    if dtype_ is None:
      if concretized_shape is None:
        concretized_shape = tf.convert_to_tensor(override_shape)
      dtype_ = concretized_shape.dtype
    if dtype_util.base_dtype(dtype_) not in {tf.int32, tf.int64}:
      raise TypeError('Shape override must be integer type; '
                      'saw {}.'.format(dtype_util.name(dtype_)))

  # Check non-negative elements
  if is_init != tensor_util.is_ref(override_shape):
    override_shape_ = tf.get_static_value(override_shape)
    msg = 'Shape override must have non-negative elements.'
    if override_shape_ is not None:
      if np.any(np.array(override_shape_) < 0):
        raise ValueError('{} Saw: {}'.format(msg, override_shape_))
    elif self.validate_args:
      if concretized_shape is None:
        concretized_shape = tf.convert_to_tensor(override_shape)
      assertions.append(
          assert_util.assert_non_negative(concretized_shape, message=msg))

  # Check valid shape
  override_ndims_ = tensorshape_util.rank(override_shape.shape)
  if is_init != (override_ndims_ is None):
    msg = 'Shape override must be a vector.'
    if override_ndims_ is not None:
      if override_ndims_ != 1:
        raise ValueError(msg)
    elif self.validate_args:
      if concretized_shape is None:
        concretized_shape = tf.convert_to_tensor(override_shape)
      override_rank = tf.rank(concretized_shape)
      assertions.append(
          assert_util.assert_equal(override_rank, 1, message=msg))

  static_base_rank = tensorshape_util.rank(static_base_shape)

  # Determine if the override shape is `[]` (static_override_dims == [0]),
  # in which case the base distribution may be nonscalar.
  static_override_dims = tensorshape_util.dims(override_shape.shape)

  if is_init != (static_base_rank is None or static_override_dims is None):
    msg = 'Base distribution is not scalar.'
    if static_base_rank is not None and static_override_dims is not None:
      if static_base_rank != 0 and static_override_dims != [0]:
        raise ValueError(msg)
    elif self.validate_args:
      if concretized_shape is None:
        concretized_shape = tf.convert_to_tensor(override_shape)
      override_is_empty = tf.logical_not(
          self._has_nonzero_rank(concretized_shape))
      assertions.append(
          assert_util.assert_equal(
              tf.logical_or(base_is_scalar_fn(), override_is_empty),
              True, message=msg))
  return assertions
def _distributional_transform(self, x): """Performs distributional transform of the mixture samples. Distributional transform removes the parameters from samples of a multivariate distribution by applying conditional CDFs: (F(x_1), F(x_2 | x1_), ..., F(x_d | x_1, ..., x_d-1)) (the indexing is over the "flattened" event dimensions). The result is a sample of product of Uniform[0, 1] distributions. We assume that the components are factorized, so the conditional CDFs become F(x_i | x_1, ..., x_i-1) = sum_k w_i^k F_k (x_i), where w_i^k is the posterior mixture weight: for i > 0 w_i^k = w_k prob_k(x_1, ..., x_i-1) / sum_k' w_k' prob_k'(x_1, ..., x_i-1) and w_0^k = w_k is the mixture probability of the k-th component. Arguments: x: Sample of mixture distribution Returns: Result of the distributional transform """ if x.shape.ndims is None: # tf.nn.softmax raises an error when applied to inputs of undefined rank. raise ValueError( "Distributional transform does not support inputs of " "undefined rank.") # Obtain factorized components distribution and assert that it's # a scalar distribution. if isinstance(self._components_distribution, independent.Independent): univariate_components = self._components_distribution.distribution else: univariate_components = self._components_distribution with tf.control_dependencies([ assert_util.assert_equal( univariate_components.is_scalar_event(), True, message="`univariate_components` must have scalar event") ]): x_padded = self._pad_sample_dims(x) # [S, B, 1, E] log_prob_x = univariate_components.log_prob( x_padded) # [S, B, k, E] cdf_x = univariate_components.cdf(x_padded) # [S, B, k, E] # log prob_k (x_1, ..., x_i-1) cumsum_log_prob_x = tf.reshape( tf.math.cumsum( # [S*prod(B)*k, prod(E)] tf.reshape(log_prob_x, [-1, self._event_size]), exclusive=True, axis=-1), tf.shape(input=log_prob_x)) # [S, B, k, E] logits_mix_prob = distribution_utils.pad_mixture_dimensions( self.mixture_distribution.logits, self, self.mixture_distribution, self._event_ndims) # [B, k, 1] # Logits of the posterior weights: log w_k + log prob_k (x_1, ..., x_i-1) log_posterior_weights_x = logits_mix_prob + cumsum_log_prob_x component_axis = x.shape.ndims - self._event_ndims posterior_weights_x = tf.nn.softmax(log_posterior_weights_x, axis=component_axis) return tf.reduce_sum(input_tensor=posterior_weights_x * cdf_x, axis=component_axis)
def __init__(self,
             mixture_distribution,
             components_distribution,
             reparameterize=False,
             validate_args=False,
             allow_nan_stats=True,
             name="MixtureSameFamily"):
  """Construct a `MixtureSameFamily` distribution.

  Args:
    mixture_distribution: `tfp.distributions.Categorical`-like instance.
      Manages the probability of selecting components. The number of
      categories must match the rightmost batch dimension of the
      `components_distribution`. Must have either scalar `batch_shape` or
      `batch_shape` matching `components_distribution.batch_shape[:-1]`.
    components_distribution: `tfp.distributions.Distribution`-like instance.
      Right-most batch dimension indexes components.
    reparameterize: Python `bool`, default `False`. Whether to reparameterize
      samples of the distribution using implicit reparameterization gradients
      [(Figurnov et al., 2018)][1]. The gradients for the mixture logits are
      equivalent to the ones described by [(Graves, 2016)][2]. The gradients
      for the components parameters are also computed using implicit
      reparameterization (as opposed to ancestral sampling), meaning that
      all components are updated every step.
      Only works when:
        (1) components_distribution is fully reparameterized;
        (2) components_distribution is either a scalar distribution or
            fully factorized (tfd.Independent applied to a scalar
            distribution);
        (3) batch shape has a known rank.
      Experimental, may be slow and produce infs/NaNs.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: `if not mixture_distribution.dtype.is_integer`.
    ValueError: if `mixture_distribution` does not have scalar `event_shape`.
    ValueError: if `mixture_distribution.batch_shape` and
      `components_distribution.batch_shape[:-1]` are both fully defined and
      the former is neither scalar nor equal to the latter.
    ValueError: if the number of `mixture_distribution` categories does not
      equal the `components_distribution` rightmost batch shape.

  #### References

  [1]: Michael Figurnov, Shakir Mohamed and Andriy Mnih. Implicit
       reparameterization gradients. In _Neural Information Processing
       Systems_, 2018. https://arxiv.org/abs/1805.08498
  [2]: Alex Graves. Stochastic Backpropagation through Mixture Density
       Distributions. _arXiv_, 2016. https://arxiv.org/abs/1607.05690
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    self._mixture_distribution = mixture_distribution
    self._components_distribution = components_distribution
    self._runtime_assertions = []

    s = components_distribution.event_shape_tensor()
    self._event_ndims = tf.compat.dimension_value(s.shape[0])
    if self._event_ndims is None:
      self._event_ndims = tf.size(input=s)
    self._event_size = tf.reduce_prod(input_tensor=s)

    if not mixture_distribution.dtype.is_integer:
      raise ValueError(
          "`mixture_distribution.dtype` ({}) is not over integers".format(
              mixture_distribution.dtype.name))
    if (mixture_distribution.event_shape.ndims is not None
        and mixture_distribution.event_shape.ndims != 0):
      raise ValueError("`mixture_distribution` must have scalar `event_dim`s")
    elif validate_args:
      self._runtime_assertions += [
          assert_util.assert_equal(
              tf.size(input=mixture_distribution.event_shape_tensor()),
              0,
              message="`mixture_distribution` must have scalar `event_dim`s"),
      ]

    mdbs = mixture_distribution.batch_shape
    cdbs = components_distribution.batch_shape.with_rank_at_least(1)[:-1]
    if mdbs.is_fully_defined() and cdbs.is_fully_defined():
      if mdbs.ndims != 0 and mdbs != cdbs:
        raise ValueError(
            "`mixture_distribution.batch_shape` (`{}`) is not "
            "compatible with `components_distribution.batch_shape` "
            "(`{}`)".format(mdbs.as_list(), cdbs.as_list()))
    elif validate_args:
      mdbs = mixture_distribution.batch_shape_tensor()
      cdbs = components_distribution.batch_shape_tensor()[:-1]
      self._runtime_assertions += [
          assert_util.assert_equal(
              distribution_utils.pick_vector(
                  mixture_distribution.is_scalar_batch(), cdbs, mdbs),
              cdbs,
              message=(
                  "`mixture_distribution.batch_shape` is not "
                  "compatible with `components_distribution.batch_shape`"))
      ]

    km = tf.compat.dimension_value(
        mixture_distribution.logits.shape.with_rank_at_least(1)[-1])
    kc = tf.compat.dimension_value(
        components_distribution.batch_shape.with_rank_at_least(1)[-1])
    if km is not None and kc is not None and km != kc:
      raise ValueError(
          "`mixture_distribution` components ({}) does not "
          "equal `components_distribution.batch_shape[-1]` "
          "({})".format(km, kc))
    elif validate_args:
      km = tf.shape(input=mixture_distribution.logits)[-1]
      kc = components_distribution.batch_shape_tensor()[-1]
      self._runtime_assertions += [
          assert_util.assert_equal(
              km, kc,
              message=("`mixture_distribution` components does not equal "
                       "`components_distribution.batch_shape[-1]`")),
      ]
    elif km is None:
      km = tf.shape(input=mixture_distribution.logits)[-1]
    self._num_components = km

    self._reparameterize = reparameterize
    if reparameterize:
      # Note: tfd.Independent passes through the reparameterization type hence
      # we do not need separate logic for Independent.
      if (self._components_distribution.reparameterization_type !=
          reparameterization.FULLY_REPARAMETERIZED):
        raise ValueError("Cannot reparameterize a mixture of "
                         "non-reparameterized components.")
      reparameterization_type = reparameterization.FULLY_REPARAMETERIZED
    else:
      reparameterization_type = reparameterization.NOT_REPARAMETERIZED

    super(MixtureSameFamily, self).__init__(
        dtype=self._components_distribution.dtype,
        reparameterization_type=reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=(
            self._mixture_distribution._graph_parents  # pylint: disable=protected-access
            + self._components_distribution._graph_parents),  # pylint: disable=protected-access
        name=name)
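# Illustrative usage sketch (not library source): a two-component Gaussian
# mixture via the public `tfd.MixtureSameFamily` API.
import tensorflow_probability as tfp
tfd = tfp.distributions

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(loc=[-1., 1.], scale=[0.1, 0.5]))
samples = gm.sample(5)  # shape [5]
lp = gm.log_prob(0.)    # scalar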
def _parameter_control_dependencies(self, is_init):
  assertions = []

  # Check num_steps is a scalar that's at least 1.
  if is_init != tensor_util.is_ref(self.num_steps):
    num_steps = tf.convert_to_tensor(self.num_steps)
    num_steps_ = tf.get_static_value(num_steps)
    if num_steps_ is not None:
      if np.ndim(num_steps_) != 0:
        raise ValueError(
            '`num_steps` must be a scalar but it has rank {}'.format(
                np.ndim(num_steps_)))
      if num_steps_ < 1:
        raise ValueError('`num_steps` must be at least 1.')
    elif self.validate_args:
      message = '`num_steps` must be a scalar'
      assertions.append(
          assert_util.assert_rank_at_most(self.num_steps, 0, message=message))
      assertions.append(
          assert_util.assert_greater_equal(
              num_steps, 1,
              message='`num_steps` must be at least 1.'))

  # Check that the initial distribution has scalar events over the
  # integers.
  if is_init and not dtype_util.is_integer(self.initial_distribution.dtype):
    raise ValueError(
        '`initial_distribution.dtype` ({}) is not over integers'.format(
            dtype_util.name(self.initial_distribution.dtype)))

  if tensorshape_util.rank(self.initial_distribution.event_shape) is not None:
    if tensorshape_util.rank(self.initial_distribution.event_shape) != 0:
      raise ValueError('`initial_distribution` must have scalar `event_dim`s')
  elif self.validate_args:
    assertions += [
        assert_util.assert_equal(
            ps.size(self.initial_distribution.event_shape_tensor()),
            0,
            message='`initial_distribution` must have scalar `event_dim`s'),
    ]

  # Check that the transition distribution is over the integers.
  if (is_init and
      not dtype_util.is_integer(self.transition_distribution.dtype)):
    raise ValueError(
        '`transition_distribution.dtype` ({}) is not over integers'.format(
            dtype_util.name(self.transition_distribution.dtype)))

  # Check observations have non-scalar batches.
  # The graph version of this assertion is incorporated as
  # a control dependency of the transition/observation
  # compatibility test.
  if tensorshape_util.rank(self.observation_distribution.batch_shape) == 0:
    raise ValueError("`observation_distribution` can't have scalar batches")

  # Check transitions have non-scalar batches.
  # The graph version of this assertion is incorporated as
  # a control dependency of the transition/observation
  # compatibility test.
  if tensorshape_util.rank(self.transition_distribution.batch_shape) == 0:
    raise ValueError("`transition_distribution` can't have scalar batches")

  # Check compatibility of transition distribution and observation
  # distribution.
  tdbs = self.transition_distribution.batch_shape
  odbs = self.observation_distribution.batch_shape
  if (tensorshape_util.dims(tdbs) is not None and
      tf.compat.dimension_value(odbs[-1]) is not None):
    if (tf.compat.dimension_value(tdbs[-1]) !=
        tf.compat.dimension_value(odbs[-1])):
      raise ValueError('`transition_distribution` and '
                       '`observation_distribution` must agree on '
                       'last dimension of batch size')
  elif self.validate_args:
    tdbs = self.transition_distribution.batch_shape_tensor()
    odbs = self.observation_distribution.batch_shape_tensor()
    transition_precondition = assert_util.assert_greater(
        ps.size(tdbs), 0,
        message=('`transition_distribution` can\'t have scalar '
                 'batches'))
    observation_precondition = assert_util.assert_greater(
        ps.size(odbs), 0,
        message=('`observation_distribution` can\'t have scalar '
                 'batches'))
    with tf.control_dependencies([
        transition_precondition,
        observation_precondition]):
      assertions += [
          assert_util.assert_equal(
              tdbs[-1],
              odbs[-1],
              message=('`transition_distribution` and '
                       '`observation_distribution` '
                       'must agree on last dimension of batch size'))]

  return assertions
def _kl_independent(a, b, name='kl_independent'):
  """Batched KL divergence `KL(a || b)` for Independent distributions.

  We can leverage the fact that

  ```
  KL(Independent(a) || Independent(b)) = sum(KL(a || b))
  ```

  where the sum is over the `reinterpreted_batch_ndims`.

  Args:
    a: Instance of `Independent`.
    b: Instance of `Independent`.
    name: (optional) name to use for created ops. Default 'kl_independent'.

  Returns:
    Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the event shapes of `a` and `b`, or of their underlying
      distributions, don't match.
  """
  p = a.distribution
  q = b.distribution

  # The KL between any two (non)-batched distributions is a scalar.
  # Given that the KL between two factored distributions is the sum, i.e.
  # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(p2 || q2), we compute
  # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions.
  if (tensorshape_util.is_fully_defined(a.event_shape) and
      tensorshape_util.is_fully_defined(b.event_shape)):
    if a.event_shape == b.event_shape:
      if p.event_shape == q.event_shape:
        num_reduce_dims = (tensorshape_util.rank(a.event_shape) -
                           tensorshape_util.rank(p.event_shape))
        reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]

        return tf.reduce_sum(
            kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
      else:
        raise NotImplementedError('KL between Independents with different '
                                  'event shapes not supported.')
    else:
      raise ValueError('Event shapes do not match.')
  else:
    p_event_shape_tensor = p.event_shape_tensor()
    q_event_shape_tensor = q.event_shape_tensor()
    # NOTE: We could optimize by passing the event_shape_tensor of p and q
    # to a.event_shape_tensor() and b.event_shape_tensor().
    a_event_shape_tensor = a.event_shape_tensor()
    b_event_shape_tensor = b.event_shape_tensor()
    with tf.control_dependencies(
        [assert_util.assert_equal(
            a_event_shape_tensor, b_event_shape_tensor,
            message='Event shapes do not match.'),
         assert_util.assert_equal(
             p_event_shape_tensor, q_event_shape_tensor,
             message='Event shapes do not match.')]):
      num_reduce_dims = (
          prefer_static.rank_from_shape(
              a_event_shape_tensor, a.event_shape) -
          prefer_static.rank_from_shape(
              p_event_shape_tensor, p.event_shape))
      reduce_dims = prefer_static.range(-num_reduce_dims, 0, 1)
      return tf.reduce_sum(
          kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
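# Illustrative usage sketch (not library source): the Independent KL sums the
# per-dimension KLs over the reinterpreted batch dims.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

a = tfd.Independent(tfd.Normal(tf.zeros([4, 3]), 1.),
                    reinterpreted_batch_ndims=1)
b = tfd.Independent(tfd.Normal(tf.ones([4, 3]), 1.),
                    reinterpreted_batch_ndims=1)
kl = tfd.kl_divergence(a, b)  # shape [4]: each entry sums 3 scalar KLs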
def _parameter_control_dependencies(self, is_init):
  assertions = []

  if is_init and not dtype_util.is_integer(self.mixture_distribution.dtype):
    raise ValueError(
        '`mixture_distribution.dtype` ({}) is not over integers'.format(
            dtype_util.name(self.mixture_distribution.dtype)))

  if tensorshape_util.rank(self.mixture_distribution.event_shape) is not None:
    if tensorshape_util.rank(self.mixture_distribution.event_shape) != 0:
      raise ValueError('`mixture_distribution` must have scalar `event_dim`s')
  elif self.validate_args:
    assertions += [
        assert_util.assert_equal(
            tf.size(self.mixture_distribution.event_shape_tensor()),
            0,
            message='`mixture_distribution` must have scalar `event_dim`s'),
    ]

  # pylint: disable=protected-access
  mixture_dist_param = (self.mixture_distribution._probs
                        if self.mixture_distribution._logits is None
                        else self.mixture_distribution._logits)
  km = tf.compat.dimension_value(
      tensorshape_util.with_rank_at_least(mixture_dist_param.shape, 1)[-1])
  kc = tf.compat.dimension_value(
      tensorshape_util.with_rank_at_least(
          self.components_distribution.batch_shape, 1)[-1])
  component_bst = None
  if km is not None and kc is not None:
    if km != kc:
      raise ValueError('`mixture_distribution` components ({}) does not '
                       'equal `components_distribution.batch_shape[-1]` '
                       '({})'.format(km, kc))
  elif self.validate_args:
    if km is None:
      mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
      km = tf.shape(mixture_dist_param)[-1]
    if kc is None:
      component_bst = self.components_distribution.batch_shape_tensor()
      kc = component_bst[-1]
    assertions += [
        assert_util.assert_equal(
            km, kc,
            message=('`mixture_distribution` components does not equal '
                     '`components_distribution.batch_shape[-1]`')),
    ]

  mdbs = self.mixture_distribution.batch_shape
  cdbs = tensorshape_util.with_rank_at_least(
      self.components_distribution.batch_shape, 1)[:-1]
  if (tensorshape_util.is_fully_defined(mdbs) and
      tensorshape_util.is_fully_defined(cdbs)):
    if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
      raise ValueError(
          '`mixture_distribution.batch_shape` (`{}`) is not '
          'compatible with `components_distribution.batch_shape` '
          '(`{}`)'.format(tensorshape_util.as_list(mdbs),
                          tensorshape_util.as_list(cdbs)))
  elif self.validate_args:
    if not tensorshape_util.is_fully_defined(mdbs):
      mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
      mdbs = tf.shape(mixture_dist_param)[:-1]
    if not tensorshape_util.is_fully_defined(cdbs):
      if component_bst is None:
        component_bst = self.components_distribution.batch_shape_tensor()
      cdbs = component_bst[:-1]
    assertions += [
        assert_util.assert_equal(
            distribution_utils.pick_vector(
                tf.equal(tf.shape(mdbs)[0], 0), cdbs, mdbs),
            cdbs,
            message=('`mixture_distribution.batch_shape` is not '
                     'compatible with '
                     '`components_distribution.batch_shape`'))
    ]

  return assertions
def _distributional_transform(self, x, event_shape):
  """Performs distributional transform of the mixture samples.

  Distributional transform removes the parameters from samples of a
  multivariate distribution by applying conditional CDFs:
    (F(x_1), F(x_2 | x_1), ..., F(x_d | x_1, ..., x_d-1))
  (the indexing is over the 'flattened' event dimensions).
  The result is a sample of product of Uniform[0, 1] distributions.

  We assume that the components are factorized, so the conditional CDFs become
    F(x_i | x_1, ..., x_i-1) = sum_k w_i^k F_k (x_i),
  where w_i^k is the posterior mixture weight: for i > 0
    w_i^k = w_k prob_k(x_1, ..., x_i-1) / sum_k' w_k' prob_k'(x_1, ..., x_i-1)
  and w_0^k = w_k is the mixture probability of the k-th component.

  Arguments:
    x: Sample of mixture distribution
    event_shape: The event shape of this distribution

  Returns:
    Result of the distributional transform
  """
  if tensorshape_util.rank(x.shape) is None:
    # tf.math.softmax raises an error when applied to inputs of undefined
    # rank.
    raise ValueError('Distributional transform does not support inputs of '
                     'undefined rank.')

  # Obtain factorized components distribution and assert that it's
  # a scalar distribution.
  if isinstance(self._components_distribution, independent.Independent):
    univariate_components = self._components_distribution.distribution
  else:
    univariate_components = self._components_distribution

  with tf.control_dependencies([
      assert_util.assert_equal(
          univariate_components.is_scalar_event(),
          True,
          message='`univariate_components` must have scalar event')
  ]):
    event_ndims = ps.rank_from_shape(event_shape)
    x_padded = self._pad_sample_dims(
        x, event_ndims=event_ndims)  # [S, B, 1, E]
    log_prob_x = univariate_components.log_prob(x_padded)  # [S, B, k, E]
    cdf_x = univariate_components.cdf(x_padded)  # [S, B, k, E]

    # log prob_k (x_1, ..., x_i-1)
    event_size = ps.cast(ps.reduce_prod(event_shape), dtype=tf.int32)
    cumsum_log_prob_x = tf.reshape(
        tf.math.cumsum(
            # [S*prod(B)*k, prod(E)]
            tf.reshape(log_prob_x, [-1, event_size]),
            exclusive=True,
            axis=-1),
        ps.shape(log_prob_x))  # [S, B, k, E]

    logits_mix_prob = self.mixture_distribution.logits_parameter()
    logits_mix_prob = tf.reshape(
        logits_mix_prob,  # [k] or [B, k]
        ps.concat([
            ps.shape(logits_mix_prob),
            ps.ones([event_ndims], dtype=tf.int32),
        ], axis=0))  # [k, [1]*e] or [B, k, [1]*e]

    # Logits of the posterior weights: log w_k + log prob_k (x_1, ..., x_i-1)
    log_posterior_weights_x = logits_mix_prob + cumsum_log_prob_x

    component_axis = tensorshape_util.rank(x.shape) - event_ndims
    posterior_weights_x = tf.math.softmax(
        log_posterior_weights_x, axis=component_axis)
    return tf.reduce_sum(posterior_weights_x * cdf_x, axis=component_axis)
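# Illustrative usage sketch (not library source): the distributional
# transform above powers implicit reparameterization gradients, enabled via
# the public `reparameterize=True` flag of `tfd.MixtureSameFamily`.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

locs = tf.Variable([-1., 1.])
gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(loc=locs, scale=[0.1, 0.5]),
    reparameterize=True)
with tf.GradientTape() as tape:
  loss = tf.reduce_mean(gm.sample(10) ** 2)
grads = tape.gradient(loss, [locs])  # gradients flow through the samples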
def _sample_control_dependencies(self, x): """Helper which validates sample arg, e.g., input to `log_prob`.""" x_ndims = ( tf.rank(x) if tensorshape_util.rank(x.shape) is None else tensorshape_util.rank(x.shape)) event_ndims = ( tf.size(self.event_shape_tensor()) if tensorshape_util.rank(self.event_shape) is None else tensorshape_util.rank(self.event_shape)) batch_ndims = ( tf.size(self._batch_shape_unexpanded) if tensorshape_util.rank(self.batch_shape) is None else tensorshape_util.rank(self.batch_shape)) expected_batch_event_ndims = batch_ndims + event_ndims if (isinstance(x_ndims, int) and isinstance(expected_batch_event_ndims, int)): if x_ndims < expected_batch_event_ndims: raise NotImplementedError( 'Broadcasting is not supported; too few batch and event dims ' '(expected at least {}, saw {}).'.format( expected_batch_event_ndims, x_ndims)) ndims_assertion = [] elif self.validate_args: ndims_assertion = [ assert_util.assert_greater_equal( x_ndims, expected_batch_event_ndims, message=('Broadcasting is not supported; too few ' 'batch and event dims.'), name='assert_batch_and_event_ndims_large_enough'), ] if (tensorshape_util.is_fully_defined(self.batch_shape) and tensorshape_util.is_fully_defined(self.event_shape)): expected_batch_event_shape = np.int32( tensorshape_util.concatenate(self.batch_shape, self.event_shape)) else: expected_batch_event_shape = tf.concat( [ self.batch_shape_tensor(), self.event_shape_tensor(), ], axis=0) sample_ndims = x_ndims - expected_batch_event_ndims if isinstance(sample_ndims, int): sample_ndims = max(sample_ndims, 0) if (isinstance(sample_ndims, int) and tensorshape_util.is_fully_defined(x.shape[sample_ndims:])): actual_batch_event_shape = np.int32(x.shape[sample_ndims:]) else: sample_ndims = tf.maximum(sample_ndims, 0) actual_batch_event_shape = tf.shape(x)[sample_ndims:] assertions = [] if (isinstance(expected_batch_event_shape, np.ndarray) and isinstance(actual_batch_event_shape, np.ndarray)): if any(expected_batch_event_shape != actual_batch_event_shape): raise NotImplementedError('Broadcasting is not supported; ' 'unexpected batch and event shape ' '(expected {}, saw {}).'.format( expected_batch_event_shape, actual_batch_event_shape)) # We need to set the final runtime-assertions to `ndims_assertion` since # its possible this assertion was created. We could add a condition to # only do so if `self.validate_args == True`, however this is redundant # as `ndims_assertion` already encodes this information. assertions.extend(ndims_assertion) elif self.validate_args: # We need to make the `ndims_assertion` a control dep because otherwise # TF itself might raise an exception owing to this assertion being # ill-defined, ie, one cannot even compare different rank Tensors. with tf.control_dependencies(ndims_assertion): shape_assertion = assert_util.assert_equal( expected_batch_event_shape, actual_batch_event_shape, message=('Broadcasting is not supported; ' 'unexpected batch and event shape.'), name='assert_batch_and_event_shape_same') assertions.append(shape_assertion) return assertions
def __init__(self,
             initial_distribution,
             transition_distribution,
             observation_distribution,
             num_steps,
             validate_args=False,
             allow_nan_stats=True,
             name="HiddenMarkovModel"):
  """Initialize hidden Markov model.

  Args:
    initial_distribution: A `Categorical`-like instance.
      Determines probability of first hidden state in Markov chain.
      The number of categories must match the number of categories of
      `transition_distribution` as well as both the rightmost batch
      dimension of `transition_distribution` and the rightmost batch
      dimension of `observation_distribution`.
    transition_distribution: A `Categorical`-like instance.
      The rightmost batch dimension indexes the probability distribution
      of each hidden state conditioned on the previous hidden state.
    observation_distribution: A `tfp.distributions.Distribution`-like
      instance.  The rightmost batch dimension indexes the distribution
      of each observation conditioned on the corresponding hidden state.
    num_steps: The number of steps taken in Markov chain. A Python `int`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
      Default value: `True`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: "HiddenMarkovModel".

  Raises:
    ValueError: if `num_steps` is not at least 1.
    ValueError: if `initial_distribution` does not have scalar `event_shape`.
    ValueError: if `transition_distribution` does not have scalar
      `event_shape`.
    ValueError: if `transition_distribution` and `observation_distribution`
      are fully defined but don't have matching rightmost dimension.
""" parameters = dict(locals()) # pylint: disable=protected-access with tf.name_scope(name) as name: self._runtime_assertions = [] # pylint: enable=protected-access num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps") if validate_args: self._runtime_assertions += [ assert_util.assert_equal( tf.rank(num_steps), 0, message="`num_steps` must be a scalar") ] self._runtime_assertions += [ assert_util.assert_greater_equal( num_steps, 1, message="`num_steps` must be at least 1.") ] self._initial_distribution = initial_distribution self._observation_distribution = observation_distribution self._transition_distribution = transition_distribution if (initial_distribution.event_shape is not None and tensorshape_util.rank( initial_distribution.event_shape) != 0): raise ValueError( "`initial_distribution` must have scalar `event_dim`s") elif validate_args: self._runtime_assertions += [ assert_util.assert_equal( tf.shape(initial_distribution.event_shape_tensor())[0], 0, message="`initial_distribution` must have scalar" "`event_dim`s") ] if (transition_distribution.event_shape is not None and tensorshape_util.rank( transition_distribution.event_shape) != 0): raise ValueError( "`transition_distribution` must have scalar `event_dim`s") elif validate_args: self._runtime_assertions += [ assert_util.assert_equal( tf.shape( transition_distribution.event_shape_tensor())[0], 0, message="`transition_distribution` must have scalar" "`event_dim`s") ] if (tensorshape_util.dims(transition_distribution.batch_shape) is not None and tensorshape_util.rank( transition_distribution.batch_shape) == 0): raise ValueError( "`transition_distribution` can't have scalar batches") elif validate_args: self._runtime_assertions += [ assert_util.assert_greater( tf.size(transition_distribution.batch_shape_tensor()), 0, message="`transition_distribution` can't have scalar " "batches") ] if (tensorshape_util.dims(observation_distribution.batch_shape) is not None and tensorshape_util.rank( observation_distribution.batch_shape) == 0): raise ValueError( "`observation_distribution` can't have scalar batches") elif validate_args: self._runtime_assertions += [ assert_util.assert_greater( tf.size(observation_distribution.batch_shape_tensor()), 0, message="`observation_distribution` can't have scalar " "batches") ] # Infer number of hidden states and check consistency # between transitions and observations with tf.control_dependencies(self._runtime_assertions): self._num_states = ( (tensorshape_util.dims(transition_distribution.batch_shape) is not None and tensorshape_util.as_list( transition_distribution.batch_shape)[-1]) or transition_distribution.batch_shape_tensor()[-1]) observation_states = ( (tensorshape_util.dims( observation_distribution.batch_shape) is not None and tensorshape_util.as_list( observation_distribution.batch_shape)[-1]) or observation_distribution.batch_shape_tensor()[-1]) if (tf.is_tensor(self._num_states) or tf.is_tensor(observation_states)): if validate_args: self._runtime_assertions += [ assert_util.assert_equal( self._num_states, observation_states, message="`transition_distribution` and " "`observation_distribution` must agree on " "last dimension of batch size") ] elif self._num_states != observation_states: raise ValueError("`transition_distribution` and " "`observation_distribution` must agree on " "last dimension of batch size") self._log_init = _extract_log_probs(self._num_states, initial_distribution) self._log_trans = _extract_log_probs(self._num_states, transition_distribution) self._num_steps = 
    self._num_states = tf.shape(self._log_init)[-1]

    self._underlying_event_rank = tf.size(
        self._observation_distribution.event_shape_tensor())

    num_steps_ = tf.get_static_value(num_steps)
    if num_steps_ is not None:
      self.static_event_shape = tf.TensorShape(
          [num_steps_]).concatenate(
              self._observation_distribution.event_shape)
    else:
      self.static_event_shape = None

    with tf.control_dependencies(self._runtime_assertions):
      self.static_batch_shape = tf.broadcast_static_shape(
          self._initial_distribution.batch_shape,
          tf.broadcast_static_shape(
              self._transition_distribution.batch_shape[:-1],
              self._observation_distribution.batch_shape[:-1]))

    # pylint: disable=protected-access
    super(HiddenMarkovModel, self).__init__(
        dtype=self._observation_distribution.dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
    # pylint: enable=protected-access

    self._parameters = parameters
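# Illustrative usage sketch (not library source): a two-state HMM with
# Gaussian observations, via the public `tfd.HiddenMarkovModel` API.
import tensorflow_probability as tfp
tfd = tfp.distributions

hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]]),
    observation_distribution=tfd.Normal(loc=[0., 15.], scale=[5., 10.]),
    num_steps=7)
obs = hmm.sample()      # shape [7]
lp = hmm.log_prob(obs)  # scalar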
def _replace_event_shape_in_shape_tensor(
    input_shape, event_shape_in, event_shape_out, validate_args):
  """Replaces the rightmost dims in a `Tensor` representing a shape.

  Args:
    input_shape: a rank-1 `Tensor` of integers
    event_shape_in: the event shape expected to be present in rightmost dims
      of `input_shape`.
    event_shape_out: the event shape with which to replace `event_shape_in` in
      the rightmost dims of `input_shape`.
    validate_args: Python `bool` indicating whether arguments should
      be checked for correctness.

  Returns:
    output_shape: A rank-1 integer `Tensor` with the same contents as
      `input_shape` except for the event dims, which are replaced with
      `event_shape_out`.
  """
  output_tensorshape, is_validated = _replace_event_shape_in_tensorshape(
      tensorshape_util.constant_value_as_shape(input_shape),
      event_shape_in,
      event_shape_out)

  # TODO(b/124240153): Remove map(tf.identity, deps) once tf.function
  # correctly supports control_dependencies.
  validation_dependencies = (
      map(tf.identity, (event_shape_in, event_shape_out))
      if validate_args else ())

  if (tensorshape_util.is_fully_defined(output_tensorshape) and
      (is_validated or not validate_args)):
    with tf.control_dependencies(validation_dependencies):
      output_shape = tf.convert_to_tensor(
          output_tensorshape, name='output_shape', dtype_hint=tf.int32)
    return output_shape, output_tensorshape

  with tf.control_dependencies(validation_dependencies):
    event_shape_in_ndims = (
        tf.size(event_shape_in)
        if tensorshape_util.num_elements(event_shape_in.shape) is None else
        tensorshape_util.num_elements(event_shape_in.shape))
    input_non_event_shape, input_event_shape = tf.split(
        input_shape, num_or_size_splits=[-1, event_shape_in_ndims])

  additional_assertions = []
  if is_validated:
    pass
  elif validate_args:
    # Check that `input_event_shape` and `event_shape_in` are compatible in
    # the sense that they have equal entries in any position that isn't a
    # `-1` in `event_shape_in`. Note that our validations at construction
    # time ensure there is at most one such entry in `event_shape_in`.
    mask = event_shape_in >= 0
    explicit_input_event_shape = tf.boolean_mask(input_event_shape, mask=mask)
    explicit_event_shape_in = tf.boolean_mask(event_shape_in, mask=mask)
    additional_assertions.append(
        assert_util.assert_equal(
            explicit_input_event_shape,
            explicit_event_shape_in,
            message='Input `event_shape` does not match `event_shape_in`.'))
    # We don't explicitly additionally verify
    # `tf.size(input_shape) > tf.size(event_shape_in)` since `tf.split`
    # already makes this assertion.

  with tf.control_dependencies(additional_assertions):
    output_shape = tf.concat(
        [input_non_event_shape, event_shape_out], axis=0, name='output_shape')

  return output_shape, output_tensorshape
def custom_gradient(fx, gx, x, fx_gx_manually_stopped=False, name=None):
  """Embeds a custom gradient into a `Tensor`.

  This function works by clever application of `stop_gradient`. I.e., observe
  that:

  ```none
  h(x) = stop_gradient(f(x)) + stop_gradient(g(x)) * (x - stop_gradient(x))
  ```

  is such that `h(x) == stop_gradient(f(x))` and
  `grad[h(x), x] == stop_gradient(g(x)).`

  In addition to scalar-domain/scalar-range functions, this function also
  supports tensor-domain/scalar-range functions.

  Partial Custom Gradient:

  Suppose `h(x) = htilde(x, y)`. Note that `dh/dx = stop(g(x))` but `dh/dy =
  None`. This is because a `Tensor` cannot have only a portion of its gradient
  stopped. To circumvent this issue, one must manually `stop_gradient` the
  relevant portions of `f`, `g`. For example see the unit-test,
  `test_works_correctly_fx_gx_manually_stopped`.

  Args:
    fx: `Tensor`. Output of function evaluated at `x`.
    gx: `Tensor` or list of `Tensor`s. Gradient of function at (each) `x`.
    x: `Tensor` or list of `Tensor`s. Args of evaluation for `f`.
    fx_gx_manually_stopped: Python `bool` indicating that `fx`, `gx` manually
      have `stop_gradient` applied.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    fx: Floating-type `Tensor` equal to `f(x)` but which has gradient
      `stop_gradient(g(x))`.
  """
  def maybe_stop(x):
    if fx_gx_manually_stopped:
      return x
    return tf.stop_gradient(x)

  with tf.name_scope(name or 'custom_gradient'):
    fx = tf.convert_to_tensor(fx, name='fx')
    # We don't want to bother eagerly computing `gx` since we may not even
    # need it.
    with tf.control_dependencies([fx]):
      if is_list_like(x):
        x = [identity(x_, name='x') for x_ in x]
      else:
        x = [identity(x, name='x')]

      if is_list_like(gx):
        gx = [identity(gx_, dtype=fx.dtype, name='gx') for gx_ in gx]
      else:
        gx = [identity(gx, dtype=fx.dtype, name='gx')]

      override_grad = []
      for x_, gx_ in zip(x, gx):
        # Observe: tf.gradients(f(x), x)[i].shape == x[i].shape
        # thus we check that the user is supplying correct shapes.
        equal_shape = assert_util.assert_equal(
            tf.shape(x_),
            tf.shape(gx_),
            message='Each `x` must have the same shape as each `gx`.')
        with tf.control_dependencies([equal_shape]):
          # IEEE754 ensures `(x-x)==0.` and that `0.*x==0.` so we make sure
          # to write the code this way, rather than, e.g.,
          # `sum_x * stop(gx) + stop(fx - sum_x * gx)`.
          # For more discussion regarding the relevant portions of the
          # IEEE754 standard, see the StackOverflow question,
          # "Is there a floating point value of x, for which x-x == 0 is
          # false?" http://stackoverflow.com/q/2686644
          zeros_like_x_ = x_ - tf.stop_gradient(x_)
          override_grad.append(
              tf.reduce_sum(maybe_stop(gx_) * zeros_like_x_))
      override_grad = sum(override_grad)
      override_grad /= tf.cast(
          tf.size(fx), dtype=dtype_util.base_dtype(fx.dtype))

    # Proof of correctness:
    #
    #  f(x) = x * stop[gx] + stop[fx - x * gx]
    #       = stop[fx]
    #
    #  g(x) = grad[fx]
    #       = stop[gx] + grad[stop[fx - x * gx]]
    #       = stop[gx] + 0
    #
    # Notice that when x is zero it still works:
    # grad[x * stop(gx) + stop(fx - x * gx)] = 1 * stop[gx] + 0 = stop[gx]
    #
    # The proof is similar for the tensor-domain case, except that we
    # `reduce_sum` the `stop[gx] * (x - stop[x])` then rescale by
    # `tf.size(fx)` since this reduced version is broadcast to `fx`.
    return maybe_stop(fx) + override_grad
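# Illustrative usage sketch (not library source): forcing the gradient of
# x**2 at x to be the constant 7 while keeping the forward value x**2.
import tensorflow as tf

x = tf.constant(2.)
with tf.GradientTape() as tape:
  tape.watch(x)
  fx = custom_gradient(fx=x**2, gx=tf.constant(7.), x=x)
print(fx)                    # 4.0 (forward value unchanged)
print(tape.gradient(fx, x))  # 7.0 (overridden gradient)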