def assert_rank_at_most(x, rank, data=None, summarize=None, message=None, name=None): """Assert `x` has rank equal to `rank` or smaller. Example of adding a dependency to an operation: ```python with tf.control_dependencies([tf.assert_rank_at_most(x, 2)]): output = tf.reduce_sum(x) ``` Args: x: Numeric `Tensor`. rank: Scalar `Tensor`. data: The tensors to print out if the condition is False. Defaults to error message and first few entries of `x`. summarize: Print this many entries of each tensor. message: A string to prefix to the default message. name: A name for this operation (optional). Defaults to "assert_rank_at_most". Returns: Op raising `InvalidArgumentError` unless `x` has specified rank or lower. If static checks determine `x` has correct rank, a `no_op` is returned. Raises: ValueError: If static checks determine `x` has wrong rank. """ with tf.name_scope(name or 'assert_rank_at_most'): return tf1.assert_less_equal( tf.rank(x), rank, data=data, summarize=summarize, message=message)
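# Illustrative usage sketch (not part of the library): the assertion above
# passes for tensors of rank <= `rank` and fails otherwise, either statically
# with `ValueError` or at run time with `InvalidArgumentError`, per the
# docstring. Assumes TF2 eager execution and `assert_rank_at_most` in scope.
import tensorflow as tf

x = tf.ones([3, 4])                                   # rank 2
with tf.control_dependencies([assert_rank_at_most(x, 2)]):
  total = tf.reduce_sum(x)                            # passes: 2 <= 2

y = tf.ones([2, 3, 4])                                # rank 3
try:
  with tf.control_dependencies([assert_rank_at_most(y, 2)]):
    tf.reduce_sum(y)
except (ValueError, tf.errors.InvalidArgumentError):
  pass                                                # fails: 3 > 2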
def lu_reconstruct_assertions(lower_upper, perm, validate_args): """Returns list of assertions related to `lu_reconstruct` assumptions.""" assertions = [] message = 'Input `lower_upper` must have at least 2 dimensions.' if tensorshape_util.rank(lower_upper.shape) is not None: if tensorshape_util.rank(lower_upper.shape) < 2: raise ValueError(message) elif validate_args: assertions.append( assert_util.assert_rank_at_least(lower_upper, rank=2, message=message)) message = '`rank(lower_upper)` must equal `rank(perm) + 1`' if (tensorshape_util.rank(lower_upper.shape) is not None and tensorshape_util.rank(perm.shape) is not None): if (tensorshape_util.rank(lower_upper.shape) != tensorshape_util.rank(perm.shape) + 1): raise ValueError(message) elif validate_args: assertions.append( assert_util.assert_rank(lower_upper, rank=tf.rank(perm) + 1, message=message)) message = '`lower_upper` must be square.' if tensorshape_util.is_fully_defined(lower_upper.shape[:-2]): if lower_upper.shape[-2] != lower_upper.shape[-1]: raise ValueError(message) elif validate_args: m, n = tf.split(tf.shape(lower_upper)[-2:], num_or_size_splits=2) assertions.append(assert_util.assert_equal(m, n, message=message)) return assertions
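# Illustrative usage sketch (not part of the library): a caller typically gates
# downstream computation on the returned assertion list via
# `tf.control_dependencies`. Assumes `lower_upper`/`perm` come from
# `tf.linalg.lu`; with fully static, square input all checks resolve statically
# and the list is empty.
import tensorflow as tf

lower_upper, perm = tf.linalg.lu(tf.random.normal([4, 4]))
assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args=True)
with tf.control_dependencies(assertions):
  lower_upper = tf.identity(lower_upper)  # downstream ops depend on the checks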
def _make_columnar(self, x):
  """Ensures non-scalar input has at least two dimensions.

  Example:
    If `x = [1, 2, 3]` then the output is `[[1, 2, 3]]` (a leading unit
    dimension is added).

    If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.

    If `x = 1` then the output is unchanged.

  Args:
    x: `Tensor`.

  Returns:
    columnar_x: `Tensor` with at least two dimensions.
  """
  if tensorshape_util.rank(x.shape) is not None:
    if tensorshape_util.rank(x.shape) == 1:
      x = x[tf.newaxis, :]
    return x
  shape = tf.shape(x)
  maybe_expanded_shape = tf.concat([
      shape[:-1],
      distribution_util.pick_vector(
          tf.equal(tf.rank(x), 1), [1], np.array([], dtype=np.int32)),
      shape[-1:],
  ], 0)
  return tf.reshape(x, maybe_expanded_shape)
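# Illustrative behaviour (not part of the library): the rank-1 branch above is
# equivalent to adding a leading unit dimension, so a vector becomes a
# single-row matrix; higher-rank inputs pass through unchanged.
import tensorflow as tf

v = tf.constant([1., 2., 3.])
print(v[tf.newaxis, :].shape)        # (1, 3) -- i.e. [[1., 2., 3.]]
m = tf.constant([[1., 2., 3.], [4., 5., 6.]])
print(m.shape)                       # (2, 3) -- already >= 2 dims, unchanged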
def _validate_block_sizes(block_sizes, bijectors, validate_args): """Helper to validate block sizes.""" block_sizes_shape = block_sizes.shape if tensorshape_util.is_fully_defined(block_sizes_shape): if (tensorshape_util.rank(block_sizes_shape) != 1 or (tensorshape_util.num_elements(block_sizes_shape) != len(bijectors))): raise ValueError( '`block_sizes` must be `None`, or a vector of the same length as ' '`bijectors`. Got a `Tensor` with shape {} and `bijectors` of ' 'length {}'.format(block_sizes_shape, len(bijectors))) return block_sizes elif validate_args: message = ( '`block_sizes` must be `None`, or a vector of the same length ' 'as `bijectors`.') with tf.control_dependencies([ assert_util.assert_equal(tf.size(block_sizes), len(bijectors), message=message), assert_util.assert_equal(tf.rank(block_sizes), 1) ]): return tf.identity(block_sizes) else: return block_sizes
def _slice_single_param(param, param_event_ndims, slices, dist_batch_shape): """Slices a single parameter of a distribution. Args: param: A `Tensor`, the original parameter to slice. param_event_ndims: `int` event parameterization rank for this parameter. slices: A `tuple` of normalized slices. dist_batch_shape: The distribution's batch shape `Tensor`. Returns: new_param: A `Tensor`, batch-sliced according to slices. """ # Extend param shape with ones on the left to match dist_batch_shape. param_shape = tf.shape(input=param) insert_ones = tf.ones( [tf.size(input=dist_batch_shape) + param_event_ndims - tf.rank(param)], dtype=param_shape.dtype) new_param_shape = tf.concat([insert_ones, param_shape], axis=0) full_batch_param = tf.reshape(param, new_param_shape) param_slices = [] # We separately track the batch axis from the parameter axis because we want # them to align for positive indexing, and be offset by param_event_ndims for # negative indexing. param_dim_idx = 0 batch_dim_idx = 0 for slc in slices: if slc is tf.newaxis: param_slices.append(slc) continue if slc is Ellipsis: if batch_dim_idx < 0: raise ValueError('Found multiple `...` in slices {}'.format(slices)) param_slices.append(slc) # Switch over to negative indexing for the broadcast check. num_remaining_non_newaxis_slices = sum( [s is not tf.newaxis for s in slices[slices.index(Ellipsis) + 1:]]) batch_dim_idx = -num_remaining_non_newaxis_slices param_dim_idx = batch_dim_idx - param_event_ndims continue # Find the batch dimension sizes for both parameter and distribution. param_dim_size = new_param_shape[param_dim_idx] batch_dim_size = dist_batch_shape[batch_dim_idx] is_broadcast = batch_dim_size > param_dim_size # Slices are denoted by start:stop:step. if isinstance(slc, slice): start, stop, step = slc.start, slc.stop, slc.step if start is not None: start = tf.where(is_broadcast, 0, start) if stop is not None: stop = tf.where(is_broadcast, 1, stop) if step is not None: step = tf.where(is_broadcast, 1, step) param_slices.append(slice(start, stop, step)) else: # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2] param_slices.append(tf.where(is_broadcast, 0, slc)) param_dim_idx += 1 batch_dim_idx += 1 param_slices.extend([ALL_SLICE] * param_event_ndims) return full_batch_param.__getitem__(param_slices)
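# Worked example (not part of the library), assuming eager execution. The
# parameter below has shape [4] while the distribution's batch shape is
# [3, 4], i.e. the parameter is broadcast over the leading batch axis. A slice
# that would shrink that axis is rewritten to preserve the size-1 broadcast
# dimension instead of indexing into it.
import tensorflow as tf

param = tf.constant([1., 2., 3., 4.])                 # broadcasts over axis 0
dist_batch_shape = tf.constant([3, 4], dtype=tf.int32)
sliced = _slice_single_param(
    param, param_event_ndims=0,
    slices=(slice(1, 3), slice(None)),                # i.e. dist[1:3, :]
    dist_batch_shape=dist_batch_shape)
print(sliced.shape)                                   # (1, 4), not (2, 4)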
def _pad_sample_dims(self, x): with tf.name_scope("pad_sample_dims"): ndims = tensorshape_util.rank(x.shape) if tensorshape_util.rank( x.shape) is not None else tf.rank(x) shape = tf.shape(x) d = ndims - self._event_ndims x = tf.reshape(x, shape=tf.concat([shape[:d], [1], shape[d:]], axis=0)) return x
def softmax(x, axis, name=None):
  """Equivalent to tf.math.softmax but works around b/70297725."""
  with tf.name_scope(name or "softmax"):
    x = tf.convert_to_tensor(x, name="x")
    ndims = (tensorshape_util.rank(x.shape)
             if tensorshape_util.rank(x.shape) is not None
             else tf.rank(x, name="ndims"))
    axis = tf.convert_to_tensor(axis, dtype=tf.int32, name="axis")
    axis_ = tf.get_static_value(axis)
    if axis_ is not None:
      # Use the builtin `int`; the `np.int` alias has been removed from NumPy.
      axis = int(ndims + axis_ if axis_ < 0 else axis_)
    else:
      axis = tf.where(axis < 0, ndims + axis, axis)
    return tf.math.softmax(x, axis=axis)
def _log_prob(self, x): # By convention, we always put the grid points right-most. y = tf.stack([aff.inverse(x) for aff in self.interpolated_affine], axis=-1) log_prob = tf.reduce_sum(self.distribution.log_prob(y), axis=-2) # Because the affine transformation has a constant Jacobian, it is the case # that `affine.fldj(x) = -affine.ildj(x)`. This is not true in general. fldj = tf.stack([ aff.forward_log_det_jacobian( x, event_ndims=tf.rank(self.event_shape_tensor())) for aff in self.interpolated_affine ], axis=-1) return tf.reduce_logsumexp(self.mixture_distribution.logits - fldj + log_prob, axis=-1)
def _broadcast_event_and_samples(event, samples, event_ndims): """Broadcasts the event or samples.""" # This is the shape of self.samples, without the samples axis, i.e. the shape # of the result of a call to dist.sample(). This way we can broadcast it with # event to get a properly-sized event, then add the singleton dim back at # -event_ndims - 1. samples_shape = tf.concat([ tf.shape(samples)[:-event_ndims - 1], tf.shape(samples)[tf.rank(samples) - event_ndims:] ], axis=0) event = event * tf.ones(samples_shape, dtype=event.dtype) event = tf.expand_dims(event, axis=-event_ndims - 1) samples = samples * tf.ones_like(event, dtype=samples.dtype) return event, samples
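# Illustrative shape check (not part of the library), assuming eager execution.
# `samples` holds 5 draws of a length-3 event for each of 2 batch members, so
# the samples axis sits at position -2 (i.e. -event_ndims - 1).
import tensorflow as tf

samples = tf.random.normal([2, 5, 3])     # [batch=2, num_samples=5, event=3]
event = tf.constant([0., 1., 2.])         # a single event, to be broadcast
event_b, samples_b = _broadcast_event_and_samples(event, samples,
                                                  event_ndims=1)
print(event_b.shape, samples_b.shape)     # (2, 1, 3) (2, 5, 3)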
def _prob(self, x): if self.validate_args: is_vector_check = assert_util.assert_rank_at_least(x, 1) right_vec_space_check = assert_util.assert_equal( self.event_shape_tensor(), tf.gather(tf.shape(x), tf.rank(x) - 1), message= "Argument 'x' not defined in the same space R^k as this distribution" ) with tf.control_dependencies([is_vector_check]): with tf.control_dependencies([right_vec_space_check]): x = tf.identity(x) loc = tf.convert_to_tensor(self.loc) return tf.cast(tf.reduce_all(tf.abs(x - loc) <= self._slack(loc), axis=-1), dtype=self.dtype)
def _maybe_check_valid_shape(shape, validate_args): """Check that a shape Tensor is int-type and otherwise sane.""" if not dtype_util.is_integer(shape.dtype): raise TypeError('{} dtype ({}) should be `int`-like.'.format( shape, dtype_util.name(shape.dtype))) assertions = [] message = '`{}` rank should be <= 1.' if tensorshape_util.rank(shape.shape) is not None: if tensorshape_util.rank(shape.shape) > 1: raise ValueError(message.format(shape)) elif validate_args: assertions.append( assert_util.assert_less(tf.rank(shape), 2, message=message.format(shape))) shape_ = tf.get_static_value(shape) message = '`{}` elements must have at most one `-1`.' if shape_ is not None: if sum(shape_ == -1) > 1: raise ValueError(message.format(shape)) elif validate_args: assertions.append( assert_util.assert_less(tf.reduce_sum( tf.cast(tf.equal(shape, -1), tf.int32)), 2, message=message.format(shape))) message = '`{}` elements must be either positive integers or `-1`.' if shape_ is not None: if np.any(shape_ < -1): raise ValueError(message.format(shape)) elif validate_args: assertions.append( assert_util.assert_greater(shape, -2, message=message.format(shape))) return assertions
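# Illustrative behaviour (not part of the library), assuming eager execution:
# the static path rejects bad reshape targets immediately (integer dtype
# required, at most one `-1`, entries no smaller than `-1`).
import tensorflow as tf

assert not _maybe_check_valid_shape(tf.constant([2, -1, 3]),
                                    validate_args=False)  # OK: no assertions
try:
  _maybe_check_valid_shape(tf.constant([-1, -1, 3]), validate_args=False)
except ValueError:
  pass  # more than one -1 is rejected statically
try:
  _maybe_check_valid_shape(tf.constant([2., 3.]), validate_args=False)
except TypeError:
  pass  # non-integer dtype is rejected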
def _sample_shape(self, x): """Computes graph and static `sample_shape`.""" x_ndims = (tf.rank(x) if tensorshape_util.rank(x.shape) is None else tensorshape_util.rank(x.shape)) event_ndims = (tf.size(self.event_shape_tensor()) if tensorshape_util.rank(self.event_shape) is None else tensorshape_util.rank(self.event_shape)) batch_ndims = (tf.size(self._batch_shape_unexpanded) if tensorshape_util.rank(self.batch_shape) is None else tensorshape_util.rank(self.batch_shape)) sample_ndims = x_ndims - batch_ndims - event_ndims if isinstance(sample_ndims, int): static_sample_shape = x.shape[:sample_ndims] else: static_sample_shape = tf.TensorShape(None) if tensorshape_util.is_fully_defined(static_sample_shape): sample_shape = np.int32(static_sample_shape) else: sample_shape = tf.shape(x)[:sample_ndims] return sample_shape, static_sample_shape
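# Illustrative shape arithmetic (not part of the library): the decomposition
# above peels `sample_shape` off the left of `x`'s shape once the batch and
# event ranks are known. For a distribution with batch shape [2] and event
# shape [3]:
import tensorflow as tf

x = tf.zeros([7, 2, 3])          # sample [7], batch [2], event [3]
batch_ndims, event_ndims = 1, 1
sample_ndims = len(x.shape) - batch_ndims - event_ndims   # 3 - 1 - 1 = 1
print(x.shape[:sample_ndims])    # (7,)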
def _forward_log_det_jacobian(self, x, **kwargs): x = tf.convert_to_tensor(x, name="x") fldj = tf.cast(0., dtype=dtype_util.base_dtype(x.dtype)) if not self.bijectors: return fldj event_ndims = self._maybe_get_static_event_ndims( self.forward_min_event_ndims) if _use_static_shape(x, event_ndims): event_shape = x.shape[tensorshape_util.rank(x.shape) - event_ndims:] else: event_shape = tf.shape(x)[tf.rank(x) - event_ndims:] # TODO(b/129973548): Document and simplify. for b in reversed(self.bijectors): fldj = fldj + b.forward_log_det_jacobian( x, event_ndims=event_ndims, **kwargs.get(b.name, {})) if _use_static_shape(x, event_ndims): event_shape = b.forward_event_shape(event_shape) event_ndims = self._maybe_get_static_event_ndims( tensorshape_util.rank(event_shape)) else: event_shape = b.forward_event_shape_tensor(event_shape) event_shape_ = distribution_util.maybe_get_static_value( event_shape) event_ndims = tf.size(event_shape) event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) if event_ndims_ is not None and event_shape_ is not None: event_ndims = event_ndims_ event_shape = event_shape_ x = b.forward(x, **kwargs.get(b.name, {})) return fldj
def _inverse_log_det_jacobian(self, y, **kwargs): y = tf.convert_to_tensor(y, name="y") ildj = tf.cast(0., dtype=dtype_util.base_dtype(y.dtype)) if not self.bijectors: return ildj event_ndims = self._maybe_get_static_event_ndims( self.inverse_min_event_ndims) if _use_static_shape(y, event_ndims): event_shape = y.shape[tensorshape_util.rank(y.shape) - event_ndims:] else: event_shape = tf.shape(y)[tf.rank(y) - event_ndims:] # TODO(b/129973548): Document and simplify. for b in self.bijectors: ildj = ildj + b.inverse_log_det_jacobian( y, event_ndims=event_ndims, **kwargs.get(b.name, {})) if _use_static_shape(y, event_ndims): event_shape = b.inverse_event_shape(event_shape) event_ndims = self._maybe_get_static_event_ndims( tensorshape_util.rank(event_shape)) else: event_shape = b.inverse_event_shape_tensor(event_shape) event_shape_ = distribution_util.maybe_get_static_value( event_shape) event_ndims = tf.size(event_shape) event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) if event_ndims_ is not None and event_shape_ is not None: event_ndims = event_ndims_ event_shape = event_shape_ y = b.inverse(y, **kwargs.get(b.name, {})) return ildj
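# Hedged sketch (not part of the library) of the accumulation done by the two
# methods above, using two scalar bijectors; assumes a TFP version providing
# `tfb.Scale`. Chain([Scale(2.), Exp()]).forward(x) = 2 * exp(x), so its
# forward log-det-Jacobian at x is fldj_Exp(x) + fldj_Scale(exp(x)) = x + log(2),
# and the inverse log-det-Jacobian at forward(x) is the negative of that.
import numpy as np
import tensorflow_probability as tfp
tfb = tfp.bijectors

chain = tfb.Chain([tfb.Scale(2.), tfb.Exp()])
x = np.float32(1.5)
fldj = chain.forward_log_det_jacobian(x, event_ndims=0)
print(float(fldj), 1.5 + np.log(2.))    # both ~2.193
ildj = chain.inverse_log_det_jacobian(chain.forward(x), event_ndims=0)
print(float(ildj))                      # ~-2.193, i.e. -fldj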
  return np.ones(s_, dtype_util.as_numpy_dtype(dtype or input.dtype))
  return tf.ones(s, dtype or s.dtype, name)


ones_like = _copy_docstring(tf.ones_like, _ones_like)

range = _prefer_static(  # pylint: disable=redefined-builtin
    tf.range,
    lambda start, limit=None, delta=1, dtype=None, name='range': np.arange(  # pylint: disable=g-long-lambda
        start, limit, delta).astype(_numpy_dtype(
            dtype or np.array(tf.get_static_value(start)).dtype)))

rank = _copy_docstring(
    tf.rank,
    lambda input, name=None: (  # pylint: disable=redefined-builtin,g-long-lambda
        tf.rank(input) if tensorshape_util.rank(input.shape) is None
        else np.int32(tensorshape_util.rank(input.shape))))

reduce_all = _prefer_static(
    tf.reduce_all,
    lambda input_tensor, axis=None, keepdims=False, name=None: np.all(  # pylint: disable=g-long-lambda
        input_tensor, axis, keepdims=keepdims))

reduce_any = _prefer_static(
    tf.reduce_any,
    lambda input_tensor, axis=None, keepdims=False, name=None: np.any(  # pylint: disable=g-long-lambda
        input_tensor, axis, keepdims=keepdims))
def _replicate(n, tensor): """Replicate the input tensor n times along a new (major) dimension.""" # TODO(axch) Does this already exist somewhere? Should it get contributed? multiples = tf.concat([[n], tf.ones([tf.rank(tensor)], dtype=n.dtype)], axis=0) return tf.tile(tensor[tf.newaxis], multiples)
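# Illustrative usage (not part of the library), assuming eager execution; `n`
# must be a Tensor because its dtype is read inside the helper.
import tensorflow as tf

t = tf.constant([[1., 2.], [3., 4.]])
r = _replicate(tf.constant(3), t)
print(r.shape)   # (3, 2, 2): the input stacked 3 times along a new leading axis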
def _forward_log_det_jacobian(self, x): # Let Y be a symmetric, positive definite matrix and write: # Y = X X.T # where X is lower-triangular. # # Observe that, # dY[i,j]/dX[a,b] # = d/dX[a,b] { X[i,:] X[j,:] } # = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] } # # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is # symmetric and X is lower-triangular, we need vectors of dimension: # d = p (p + 1) / 2 # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e., # k = { i (i + 1) / 2 + j i>=j # { undef i<j # and assume zero-based indexes. When k is undef, the element is dropped. # Example: # j k # 0 1 2 3 / # 0 [ 0 . . . ] # i 1 [ 1 2 . . ] # 2 [ 3 4 5 . ] # 3 [ 6 7 8 9 ] # Write vec[.] to indicate transforming a matrix to vector via k(i,j). (With # slight abuse: k(i,j)=undef means the element is dropped.) # # We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are # defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b. # In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since: # (1) j<=i<a thus i,j!=a. # (2) i=a>j thus i,j!=a. # # Since the Jacobian is lower-triangular, we need only compute the product # of diagonal elements: # d vec[Y] / d vec[X] @[k(i,j), k(i,j)] # = X[j,j] + I[i=j] X[i,j] # = 2 X[j,j]. # Since there is a 2 X[j,j] term for every lower-triangular element of X we # conclude: # |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}. diag = tf.linalg.diag_part(x) # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the # output is unchanged. diag = self._make_columnar(diag) with tf.control_dependencies(self._assertions(x)): # Create a vector equal to: [p, p-1, ..., 2, 1]. if tf.compat.dimension_value(x.shape[-1]) is None: p_int = tf.shape(x)[-1] p_float = tf.cast(p_int, dtype=x.dtype) else: p_int = tf.compat.dimension_value(x.shape[-1]) p_float = dtype_util.as_numpy_dtype(x.dtype)(p_int) exponents = tf.linspace(p_float, 1., p_int) sum_weighted_log_diag = tf.squeeze( tf.matmul(tf.math.log(diag), exponents[..., tf.newaxis]), axis=-1) fldj = p_float * np.log(2.) + sum_weighted_log_diag # We finally need to undo adding an extra column in non-scalar cases # where there is a single matrix as input. if tensorshape_util.rank(x.shape) is not None: if tensorshape_util.rank(x.shape) == 2: fldj = tf.squeeze(fldj, axis=-1) return fldj shape = tf.shape(fldj) maybe_squeeze_shape = tf.concat([ shape[:-1], distribution_util.pick_vector( tf.equal(tf.rank(x), 2), np.array([], dtype=np.int32), shape[-1:])], 0) return tf.reshape(fldj, maybe_squeeze_shape)
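# Numerical check (not part of the library) of the derivation above for p = 2.
# With X = [[a, 0], [b, c]], vec[Y] = (a**2, a*b, b**2 + c**2), and the claimed
# determinant is 2**2 * a**2 * c**1 = 4 * a**2 * c. The pure-NumPy sketch below
# compares that formula with a finite-difference Jacobian.
import numpy as np

def vec_y(params):
  a, b, c = params
  x = np.array([[a, 0.], [b, c]])
  y = x @ x.T
  return np.array([y[0, 0], y[1, 0], y[1, 1]])  # row-major lower triangle

params = np.array([1.3, 0.4, 2.1])
eps = 1e-6
jac = np.stack(
    [(vec_y(params + eps * e) - vec_y(params - eps * e)) / (2. * eps)
     for e in np.eye(3)], axis=-1)
print(np.abs(np.linalg.det(jac)))   # ~14.196
print(2.**2 * 1.3**2 * 2.1)         # 14.196, matching 2^p prod_j X[j,j]^(p-j)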
def _validate_sample_arg(self, x):
  """Helper which validates sample arg, e.g., input to `log_prob`."""
  with tf.name_scope('validate_sample_arg'):
    x_ndims = (
        tf.rank(x) if tensorshape_util.rank(x.shape) is None else
        tensorshape_util.rank(x.shape))
    event_ndims = (
        tf.size(self.event_shape_tensor())
        if tensorshape_util.rank(self.event_shape) is None else
        tensorshape_util.rank(self.event_shape))
    batch_ndims = (
        tf.size(self._batch_shape_unexpanded)
        if tensorshape_util.rank(self.batch_shape) is None else
        tensorshape_util.rank(self.batch_shape))
    expected_batch_event_ndims = batch_ndims + event_ndims

    if (isinstance(x_ndims, int) and
        isinstance(expected_batch_event_ndims, int)):
      if x_ndims < expected_batch_event_ndims:
        raise NotImplementedError(
            'Broadcasting is not supported; too few batch and event dims '
            '(expected at least {}, saw {}).'.format(
                expected_batch_event_ndims, x_ndims))
      ndims_assertion = []
    elif self.validate_args:
      ndims_assertion = [
          assert_util.assert_greater_equal(
              x_ndims,
              expected_batch_event_ndims,
              message=('Broadcasting is not supported; too few '
                       'batch and event dims.'),
              name='assert_batch_and_event_ndims_large_enough'),
      ]

    if (tensorshape_util.is_fully_defined(self.batch_shape) and
        tensorshape_util.is_fully_defined(self.event_shape)):
      expected_batch_event_shape = np.int32(
          tensorshape_util.concatenate(self.batch_shape, self.event_shape))
    else:
      expected_batch_event_shape = tf.concat([
          self.batch_shape_tensor(),
          self.event_shape_tensor(),
      ], axis=0)

    sample_ndims = x_ndims - expected_batch_event_ndims
    if isinstance(sample_ndims, int):
      sample_ndims = max(sample_ndims, 0)
    if (isinstance(sample_ndims, int) and
        tensorshape_util.is_fully_defined(x.shape[sample_ndims:])):
      actual_batch_event_shape = np.int32(x.shape[sample_ndims:])
    else:
      sample_ndims = tf.maximum(sample_ndims, 0)
      actual_batch_event_shape = tf.shape(x)[sample_ndims:]

    if (isinstance(expected_batch_event_shape, np.ndarray) and
        isinstance(actual_batch_event_shape, np.ndarray)):
      if any(expected_batch_event_shape != actual_batch_event_shape):
        raise NotImplementedError('Broadcasting is not supported; '
                                  'unexpected batch and event shape '
                                  '(expected {}, saw {}).'.format(
                                      expected_batch_event_shape,
                                      actual_batch_event_shape))
      # We need to set the final runtime-assertions to `ndims_assertion` since
      # it's possible this assertion was created. We could add a condition to
      # only do so if `self.validate_args == True`, however this is redundant
      # as `ndims_assertion` already encodes this information.
      runtime_assertions = ndims_assertion
    elif self.validate_args:
      # We need to make the `ndims_assertion` a control dep because otherwise
      # TF itself might raise an exception owing to this assertion being
      # ill-defined, i.e., one cannot even compare different-rank Tensors.
      with tf.control_dependencies(ndims_assertion):
        shape_assertion = assert_util.assert_equal(
            expected_batch_event_shape,
            actual_batch_event_shape,
            message=('Broadcasting is not supported; '
                     'unexpected batch and event shape.'),
            name='assert_batch_and_event_shape_same')
      runtime_assertions = [shape_assertion]
    else:
      runtime_assertions = []

    return runtime_assertions
def __init__(self,
             initial_distribution,
             transition_distribution,
             observation_distribution,
             num_steps,
             validate_args=False,
             allow_nan_stats=True,
             name="HiddenMarkovModel"):
  """Initialize hidden Markov model.

  Args:
    initial_distribution: A `Categorical`-like instance. Determines the
      probability of the first hidden state in the Markov chain. The number
      of categories must match the number of categories of
      `transition_distribution` as well as both the rightmost batch dimension
      of `transition_distribution` and the rightmost batch dimension of
      `observation_distribution`.
    transition_distribution: A `Categorical`-like instance. The rightmost
      batch dimension indexes the probability distribution of each hidden
      state conditioned on the previous hidden state.
    observation_distribution: A `tfp.distributions.Distribution`-like
      instance. The rightmost batch dimension indexes the distribution of
      each observation conditioned on the corresponding hidden state.
    num_steps: The number of steps taken in the Markov chain. A Python `int`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
      Default value: `True`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: "HiddenMarkovModel".

  Raises:
    ValueError: if `num_steps` is not at least 1.
    ValueError: if `initial_distribution` does not have scalar `event_shape`.
    ValueError: if `transition_distribution` does not have scalar
      `event_shape`.
    ValueError: if `transition_distribution` and `observation_distribution`
      are fully defined but don't have matching rightmost dimension.
  """
  parameters = dict(locals())

  # pylint: disable=protected-access
  with tf.name_scope(name) as name:
    self._runtime_assertions = []  # pylint: enable=protected-access

    num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps")
    if validate_args:
      self._runtime_assertions += [
          assert_util.assert_equal(
              tf.rank(num_steps), 0,
              message="`num_steps` must be a scalar")]
      self._runtime_assertions += [
          assert_util.assert_greater_equal(
              num_steps, 1,
              message="`num_steps` must be at least 1.")]

    self._initial_distribution = initial_distribution
    self._observation_distribution = observation_distribution
    self._transition_distribution = transition_distribution

    if (initial_distribution.event_shape is not None and
        tensorshape_util.rank(initial_distribution.event_shape) != 0):
      raise ValueError("`initial_distribution` must have scalar `event_dim`s")
    elif validate_args:
      self._runtime_assertions += [
          assert_util.assert_equal(
              tf.shape(initial_distribution.event_shape_tensor())[0], 0,
              message="`initial_distribution` must have scalar "
                      "`event_dim`s")]

    if (transition_distribution.event_shape is not None and
        tensorshape_util.rank(transition_distribution.event_shape) != 0):
      raise ValueError(
          "`transition_distribution` must have scalar `event_dim`s")
    elif validate_args:
      self._runtime_assertions += [
          assert_util.assert_equal(
              tf.shape(transition_distribution.event_shape_tensor())[0], 0,
              message="`transition_distribution` must have scalar "
                      "`event_dim`s")]

    if (transition_distribution.batch_shape is not None and
        tensorshape_util.rank(transition_distribution.batch_shape) == 0):
      raise ValueError("`transition_distribution` can't have scalar batches")
    elif validate_args:
      self._runtime_assertions += [
          assert_util.assert_greater(
              tf.size(transition_distribution.batch_shape_tensor()), 0,
              message="`transition_distribution` can't have scalar "
                      "batches")]

    if (observation_distribution.batch_shape is not None and
        tensorshape_util.rank(observation_distribution.batch_shape) == 0):
      raise ValueError("`observation_distribution` can't have scalar batches")
    elif validate_args:
      self._runtime_assertions += [
          assert_util.assert_greater(
              tf.size(observation_distribution.batch_shape_tensor()), 0,
              message="`observation_distribution` can't have scalar "
                      "batches")]

    # Infer number of hidden states and check consistency
    # between transitions and observations.
    with tf.control_dependencies(self._runtime_assertions):
      self._num_states = ((transition_distribution.batch_shape and
                           transition_distribution.batch_shape[-1]) or
                          transition_distribution.batch_shape_tensor()[-1])

      observation_states = ((observation_distribution.batch_shape and
                             observation_distribution.batch_shape[-1]) or
                            observation_distribution.batch_shape_tensor()[-1])

      if tf.is_tensor(self._num_states) or tf.is_tensor(observation_states):
        if validate_args:
          self._runtime_assertions += [
              assert_util.assert_equal(
                  self._num_states, observation_states,
                  message="`transition_distribution` and "
                          "`observation_distribution` must agree on "
                          "last dimension of batch size")]
      elif self._num_states != observation_states:
        raise ValueError("`transition_distribution` and "
                         "`observation_distribution` must agree on "
                         "last dimension of batch size")

      self._log_init = _extract_log_probs(self._num_states,
                                          initial_distribution)
      self._log_trans = _extract_log_probs(self._num_states,
                                           transition_distribution)

      self._num_steps = num_steps
      self._num_states = tf.shape(self._log_init)[-1]

      self._underlying_event_rank = tf.size(
          self._observation_distribution.event_shape_tensor())

      num_steps_ = tf.get_static_value(num_steps)
      if num_steps_ is not None:
        self.static_event_shape = tf.TensorShape(
            [num_steps_]).concatenate(
                self._observation_distribution.event_shape)
      else:
        self.static_event_shape = None

      with tf.control_dependencies(self._runtime_assertions):
        self.static_batch_shape = tf.broadcast_static_shape(
            self._initial_distribution.batch_shape,
            tf.broadcast_static_shape(
                self._transition_distribution.batch_shape[:-1],
                self._observation_distribution.batch_shape[:-1]))

    # pylint: disable=protected-access
    super(HiddenMarkovModel, self).__init__(
        dtype=self._observation_distribution.dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
    # pylint: enable=protected-access

    self._parameters = parameters
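# Hedged end-to-end usage sketch (not part of the library source), adapted from
# the classic example in the TFP documentation; assumes `tfd = tfp.distributions`.
# A 2-state chain with scalar-event Normal observations over 7 steps.
import tensorflow_probability as tfp
tfd = tfp.distributions

hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]]),
    observation_distribution=tfd.Normal(loc=[0., 15.], scale=[5., 10.]),
    num_steps=7)

print(hmm.event_shape)    # (7,): num_steps prepended to the observation event
print(hmm.batch_shape)    # ()
x = hmm.sample(3)         # shape [3, 7]
lp = hmm.log_prob(x)      # shape [3]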
def _observation_log_probs(self, observations, mask):
  """Compute and shape tensor of log probs associated with observations."""

  # Let E be the underlying event shape
  #     M the number of steps in the HMM
  #     N the number of states of the HMM
  #
  # Then the incoming observations have shape
  #
  # observations : batch_o [M] E
  #
  # and the mask (if present) has shape
  #
  # mask : batch_m [M]
  #
  # Let this HMM distribution have batch shape batch_d.
  # We need to broadcast all three of these batch shapes together
  # into the shape batch.
  #
  # We need to move the step dimension to the first dimension to make
  # them suitable for folding or scanning over.
  #
  # When we call `log_prob` for our observations we need to
  # do this for each state the observation could correspond to.
  # We do this by expanding the dimensions by 1 so we end up with:
  #
  # observations : [M] batch [1] [E]
  #
  # After calling `log_prob` we get
  #
  # observation_log_probs : [M] batch [N]
  #
  # We wish to use `mask` to select from this so we also
  # reshape and broadcast it up to shape
  #
  # mask : [M] batch [N]

  observation_tensor_shape = tf.shape(observations)
  observation_batch_shape = observation_tensor_shape[
      :-1 - self._underlying_event_rank]
  observation_event_shape = observation_tensor_shape[
      -1 - self._underlying_event_rank:]

  if mask is not None:
    mask_tensor_shape = tf.shape(mask)
    mask_batch_shape = mask_tensor_shape[:-1]

  batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                           self.batch_shape_tensor())

  if mask is not None:
    batch_shape = tf.broadcast_dynamic_shape(batch_shape,
                                             mask_batch_shape)
  observations = tf.broadcast_to(
      observations,
      tf.concat([batch_shape, observation_event_shape], axis=0))
  observation_rank = tf.rank(observations)
  underlying_event_rank = self._underlying_event_rank
  observations = distribution_util.move_dimension(
      observations, observation_rank - underlying_event_rank - 1, 0)
  observations = tf.expand_dims(observations,
                                observation_rank - underlying_event_rank)
  observation_log_probs = self._observation_distribution.log_prob(
      observations)

  if mask is not None:
    mask = tf.broadcast_to(
        mask, tf.concat([batch_shape, [self._num_steps]], axis=0))
    mask = distribution_util.move_dimension(mask, -1, 0)
    observation_log_probs = tf.where(mask[..., tf.newaxis],
                                     tf.zeros_like(observation_log_probs),
                                     observation_log_probs)

  return observation_log_probs
def _transpose(self, x, perm): perm = self._make_perm(tf.rank(x), perm) return tf.transpose(a=x, perm=perm)
def _log_prob(self, x): if self.input_output_cholesky: x_sqrt = x else: # Complexity: O(nbk**3) x_sqrt = tf.linalg.cholesky(x) batch_shape = self.batch_shape_tensor() event_shape = self.event_shape_tensor() x_ndims = tf.rank(x_sqrt) num_singleton_axes_to_prepend = ( tf.maximum(tf.size(batch_shape) + 2, x_ndims) - x_ndims) x_with_prepended_singletons_shape = tf.concat([ tf.ones([num_singleton_axes_to_prepend], dtype=tf.int32), tf.shape(x_sqrt) ], 0) x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape) ndims = tf.rank(x_sqrt) # sample_ndims = ndims - batch_ndims - event_ndims sample_ndims = ndims - tf.size(batch_shape) - 2 sample_shape = tf.shape(x_sqrt)[:sample_ndims] # We need to be able to pre-multiply each matrix by its corresponding # batch scale matrix. Since a Distribution Tensor supports multiple # samples per batch, this means we need to reshape the input matrix `x` # so that the first b dimensions are batch dimensions and the last two # are of shape [dimension, dimensions*number_of_samples]. Doing these # gymnastics allows us to do a batch_solve. # # After we're done with sqrt_solve (the batch operation) we need to undo # this reshaping so what we're left with is a Tensor partitionable by # sample, batch, event dimensions. # Complexity: O(nbk**2) since transpose must access every element. scale_sqrt_inv_x_sqrt = x_sqrt perm = tf.concat( [tf.range(sample_ndims, ndims), tf.range(0, sample_ndims)], 0) scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm) last_dim_size = ( tf.cast(self.dimension, dtype=tf.int32) * tf.reduce_prod(x_with_prepended_singletons_shape[:sample_ndims])) shape = tf.concat([ x_with_prepended_singletons_shape[sample_ndims:-2], [tf.cast(self.dimension, dtype=tf.int32), last_dim_size] ], axis=0) scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape) # Complexity: O(nbM*k) where M is the complexity of the operator solving a # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2) so # this step has complexity O(nbk^3). scale_sqrt_inv_x_sqrt = self.scale_operator.solve( scale_sqrt_inv_x_sqrt) # Undo make batch-op ready. # Complexity: O(nbk**2) shape = tf.concat( [tf.shape(scale_sqrt_inv_x_sqrt)[:-2], event_shape, sample_shape], axis=0) scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape) perm = tf.concat([ tf.range(ndims - sample_ndims, ndims), tf.range(0, ndims - sample_ndims) ], 0) scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm) # Write V = SS', X = LL'. Then: # tr[inv(V) X] = tr[inv(S)' inv(S) L L'] # = tr[inv(S) L L' inv(S)'] # = tr[(inv(S) L) (inv(S) L)'] # = sum_{ik} (inv(S) L)_{ik}**2 # The second equality follows from the cyclic permutation property. # Complexity: O(nbk**2) trace_scale_inv_x = tf.reduce_sum(tf.square(scale_sqrt_inv_x_sqrt), axis=[-2, -1]) # Complexity: O(nbk) half_log_det_x = tf.reduce_sum(tf.math.log( tf.linalg.diag_part(x_sqrt)), axis=[-1]) # Complexity: O(nbk**2) log_prob = ((self.df - self.dimension - 1.) * half_log_det_x - 0.5 * trace_scale_inv_x - self.log_normalization()) # Set shape hints. # Try to merge what we know from the input x with what we know from the # parameters of this distribution. if tensorshape_util.rank( x.shape) is not None and tensorshape_util.rank( self.batch_shape) is not None: tensorshape_util.set_shape( log_prob, tf.broadcast_static_shape(x.shape[:-2], self.batch_shape)) return log_prob