def _multi_gamma_sequence(self, a, p, name="multi_gamma_sequence"):
  """Creates sequence used in multivariate (di)gamma; shape = shape(a)+[p]."""
  with self._name_and_control_scope(name):
    # Linspace only takes scalars, so we'll add in the offset afterwards.
    seq = tf.linspace(
        tf.constant(0., dtype=self.dtype),
        0.5 - 0.5 * p,
        tf.cast(p, tf.int32))
    return seq + tf.expand_dims(a, axis=-1)
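# A minimal standalone sketch (not part of the class, assuming TF 2.x) of how
# the sequence above is used: the multivariate log-gamma is
#   log Gamma_p(a) = p * (p - 1) / 4 * log(pi) + sum_{j=1..p} lgamma(a - (j - 1) / 2),
# i.e. a sum of lgamma over exactly the sequence a, a - 0.5, ..., a - 0.5 * (p - 1).
# The helper name `multi_lgamma_sketch` is hypothetical.
import numpy as np
import tensorflow as tf

def multi_lgamma_sketch(a, p):
  a = tf.convert_to_tensor(a, dtype=tf.float32)
  seq = tf.linspace(0., 0.5 - 0.5 * p, p) + a[..., tf.newaxis]
  return (0.25 * p * (p - 1) * np.log(np.pi) +
          tf.reduce_sum(tf.math.lgamma(seq), axis=-1))

# For p = 1 this reduces to the ordinary lgamma.
print(multi_lgamma_sketch(3., 1).numpy(), tf.math.lgamma(3.).numpy())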
def _variance(self):
  # df has no matrix dimensions, so we need to expand it to match
  # scale_operator. We use ellipsis notation (...) to keep all batch
  # dimensions and add two singleton dimensions at the end.
  df = self.df[..., tf.newaxis, tf.newaxis]
  x = tf.sqrt(df) * self._square_scale_operator()
  d = tf.expand_dims(tf.linalg.diag_part(x), -1)
  v = tf.square(x) + tf.matmul(d, d, adjoint_b=True)
  return v
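# A standalone sketch (not from the library) checking the tensor algebra above
# against the closed form Var[X]_ij = df * (S_ij**2 + S_ii * S_jj) for a
# Wishart with scale matrix S; the concrete `df` and `scale` values are made up.
import tensorflow as tf

df = tf.constant(5.)
scale = tf.constant([[2., 0.5],
                     [0.5, 1.]])
x = tf.sqrt(df[..., tf.newaxis, tf.newaxis]) * scale
d = tf.expand_dims(tf.linalg.diag_part(x), -1)
variance = tf.square(x) + tf.matmul(d, d, adjoint_b=True)
# Entry (0, 1): 5 * (0.5**2 + 2. * 1.) = 11.25.
print(variance.numpy())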
def _expand_to_event_rank(self, x):
  """Expand the rank of x up to static_event_rank times for broadcasting.

  The static event rank was checked to not be None at construction time.

  Args:
    x: A tensor to expand.
  Returns:
    The expanded tensor.
  """
  expanded_x = x
  for _ in range(tensorshape_util.rank(self.event_shape)):
    expanded_x = tf.expand_dims(expanded_x, -1)
  return expanded_x
def _broadcast_event_and_samples(event, samples, event_ndims):
  """Broadcasts the event or samples."""
  # This is the shape of self.samples, without the samples axis, i.e. the shape
  # of the result of a call to dist.sample(). This way we can broadcast it with
  # event to get a properly-sized event, then add the singleton dim back at
  # -event_ndims - 1.
  samples_shape = tf.concat(
      [
          tf.shape(samples)[:-event_ndims - 1],
          tf.shape(samples)[tf.rank(samples) - event_ndims:]
      ],
      axis=0)
  event = event * tf.ones(samples_shape, dtype=event.dtype)
  event = tf.expand_dims(event, axis=-event_ndims - 1)
  samples = samples * tf.ones_like(event, dtype=samples.dtype)

  return event, samples
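# A hypothetical shape walk-through of the broadcast above, assuming samples
# laid out as [batch_shape, num_samples, event_shape] (the layout produced by
# an Empirical-style distribution); the concrete shapes are made up.
import tensorflow as tf

event_ndims = 1
samples = tf.zeros([7, 3, 20, 2])   # batch [7, 3], 20 samples, event [2]
event = tf.zeros([2])               # a single event to evaluate

samples_shape = tf.concat([
    tf.shape(samples)[:-event_ndims - 1],
    tf.shape(samples)[tf.rank(samples) - event_ndims:]
], axis=0)                                                     # [7, 3, 2]
event = event * tf.ones(samples_shape, dtype=event.dtype)      # [7, 3, 2]
event = tf.expand_dims(event, axis=-event_ndims - 1)           # [7, 3, 1, 2]
samples = samples * tf.ones_like(event, dtype=samples.dtype)   # [7, 3, 20, 2]
print(event.shape, samples.shape)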
def _stddev(self):
  samples = tf.convert_to_tensor(self._samples)
  axis = self._samples_axis
  r = samples - tf.expand_dims(self._mean(samples), axis=axis)
  var = tf.reduce_mean(tf.square(r), axis=axis)
  return tf.sqrt(var)
def _reparameterize_sample(self, x):
  """Adds reparameterization (pathwise) gradients to samples of the mixture.

  Implicit reparameterization gradients are
     dx/dphi = -(d transform(x, phi) / dx)^-1 * d transform(x, phi) / dphi,
  where transform(x, phi) is a distributional transform that removes all
  parameters from samples x.

  We implement them by replacing x with
    -stop_gradient(d transform(x, phi) / dx)^-1 * transform(x, phi)
  for the backward pass (gradient computation).
  The derivative of this quantity w.r.t. phi is then the implicit
  reparameterization gradient.
  Note that this replaces the gradients w.r.t. both the mixture
  distribution parameters and the component distributions' parameters.

  Limitations:
    1. Fundamental: components must be fully reparameterized.
    2. Distributional transform is currently only implemented for
       factorized components.
    3. Distributional transform currently only works for known rank of the
       batch tensor.

  Arguments:
    x: Sample of mixture distribution

  Returns:
    Tensor with the same value as x, but with reparameterization gradients.
  """
  # Remove the existing gradients of x wrt parameters of the components.
  x = tf.stop_gradient(x)

  x_2d_shape = [-1, self._event_size]  # [S*prod(B), prod(E)]

  # Perform distributional transform of x in [S, B, E] shape,
  # but have Jacobian of size [S*prod(B), prod(E), prod(E)].
  def reshaped_distributional_transform(x_2d):
    return tf.reshape(
        self._distributional_transform(tf.reshape(x_2d, tf.shape(x))),
        x_2d_shape)

  # transform_2d: [S*prod(B), prod(E)]
  # jacobian: [S*prod(B), prod(E), prod(E)]
  x_2d = tf.reshape(x, x_2d_shape)
  with tf.GradientTape() as tape:
    tape.watch(x_2d)
    transform_2d = reshaped_distributional_transform(x_2d)
  jacobian = tape.batch_jacobian(transform_2d, x_2d)

  # We only provide the first derivative; the second derivative computed by
  # autodiff would be incorrect, so we raise an error if it is requested.
  transform_2d = _prevent_2nd_derivative(transform_2d)

  # Compute [- stop_gradient(jacobian)^-1 * transform] by solving a linear
  # system. The Jacobian is lower triangular because the distributional
  # transform for the i-th event dimension does not depend on the next
  # dimensions.
  surrogate_x_2d = -tf.linalg.triangular_solve(
      tf.stop_gradient(jacobian),
      tf.expand_dims(transform_2d, axis=-1),
      lower=True)  # [S*prod(B), prod(E), 1]
  surrogate_x = tf.reshape(surrogate_x_2d, tf.shape(x))

  # Replace gradients of x with gradients of surrogate_x, but keep the value.
  return x + (surrogate_x - tf.stop_gradient(surrogate_x))
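# A minimal univariate sketch of the surrogate-variable trick above, assuming
# TFP is available and using a Normal's CDF as the distributional transform:
# the surrogate is -stop_gradient(dF/dx)^-1 * F(x; theta), and its gradient
# w.r.t. theta is the implicit reparameterization gradient.
import tensorflow as tf
import tensorflow_probability as tfp

loc = tf.Variable(0.5)
dist = tfp.distributions.Normal(loc=loc, scale=2.)
x = tf.stop_gradient(dist.sample())

with tf.GradientTape() as tape:
  transform = dist.cdf(x)                   # F(x; theta)
  density = tf.stop_gradient(dist.prob(x))  # dF/dx, held constant
  surrogate = -transform / density
  y = x + (surrogate - tf.stop_gradient(surrogate))

# For a location parameter, dx/dloc = 1, so this should print ~1.0.
print(tape.gradient(y, loc).numpy())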
def _sample_n(self, n, seed=None):
  sample_and_batch_shape = tf.concat([[n], self.batch_shape_tensor()], 0)
  flat_batch_and_sample_shape = tf.stack(
      [tf.reduce_prod(self.batch_shape_tensor()), n])

  # In order to be reparameterizable we sample from a standard truncated
  # normal (zero mean, unit scale) with the standardized truncation bounds,
  # and shift/scale the samples afterwards.

  @tf.custom_gradient
  def _std_samples_with_gradients(lower, upper):
    """Standard truncated Normal with gradient support for low, high."""
    # Note: Unlike the convention in tf_probability,
    # parameterized_truncated_normal returns a tensor with the final dimension
    # being the sample dimension.
    std_samples = random_ops.parameterized_truncated_normal(
        shape=flat_batch_and_sample_shape,
        means=0.0,
        stddevs=1.0,
        minvals=lower,
        maxvals=upper,
        dtype=self.dtype,
        seed=seed)

    def grad(dy):
      """Computes a derivative for the min and max parameters.

      This function implements the derivative wrt the truncation bounds, which
      get blocked by the sampler. We use a custom expression for numerical
      stability instead of automatic differentiation on the CDF for implicit
      gradients.

      Args:
        dy: output gradients

      Returns:
        The gradients wrt the lower bound and the upper bound.
      """
      # std_samples has an extra dimension (the sample dimension), expand
      # lower and upper so they broadcast along this dimension.
      # See the note above regarding parameterized_truncated_normal: the
      # sample dimension is the final dimension.
      lower_broadcast = lower[..., tf.newaxis]
      upper_broadcast = upper[..., tf.newaxis]

      cdf_samples = ((special_math.ndtr(std_samples) -
                      special_math.ndtr(lower_broadcast)) /
                     (special_math.ndtr(upper_broadcast) -
                      special_math.ndtr(lower_broadcast)))

      # tiny, eps are tolerance parameters to ensure we stay away from giving
      # a zero arg to the log CDF expression.
      tiny = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny
      eps = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).eps
      cdf_samples = tf.clip_by_value(cdf_samples, tiny, 1 - eps)

      du = tf.exp(0.5 * (std_samples**2 - upper_broadcast**2) +
                  tf.math.log(cdf_samples))
      dl = tf.exp(0.5 * (std_samples**2 - lower_broadcast**2) +
                  tf.math.log1p(-cdf_samples))

      # Reduce the gradient across the samples.
      grad_u = tf.reduce_sum(dy * du, axis=-1)
      grad_l = tf.reduce_sum(dy * dl, axis=-1)
      return [grad_l, grad_u]

    return std_samples, grad

  std_samples = _std_samples_with_gradients(
      tf.reshape(self._standardized_low, [-1]),
      tf.reshape(self._standardized_high, [-1]))

  # The returned shape is [flat_batch, n].
  std_samples = tf.transpose(a=std_samples, perm=[1, 0])

  std_samples = tf.reshape(std_samples, sample_and_batch_shape)
  samples = (std_samples * tf.expand_dims(self._scale, axis=0) +
             tf.expand_dims(self._loc, axis=0))
  return samples
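# A standalone numerical check (assuming SciPy is available) of the custom
# gradient above. Writing the standard truncated-normal sample via the inverse
# CDF,
#   x = ndtri(ndtr(l) + u * (ndtr(h) - ndtr(l))),   u ~ Uniform(0, 1),
# gives dx/dh = exp(0.5 * (x**2 - h**2)) * u and
#       dx/dl = exp(0.5 * (x**2 - l**2)) * (1 - u),
# which are the `du` / `dl` factors used in `grad` (with u playing the role of
# `cdf_samples`). The concrete u, l, h values are made up.
import numpy as np
from scipy.stats import norm

u, l, h = 0.3, -1.0, 2.0
sample = lambda lo, hi: norm.ppf(norm.cdf(lo) + u * (norm.cdf(hi) - norm.cdf(lo)))
x, eps = sample(l, h), 1e-6
fd_h = (sample(l, h + eps) - sample(l, h - eps)) / (2 * eps)
fd_l = (sample(l + eps, h) - sample(l - eps, h)) / (2 * eps)
print(fd_h, np.exp(0.5 * (x**2 - h**2)) * u)        # should agree
print(fd_l, np.exp(0.5 * (x**2 - l**2)) * (1 - u))  # should agree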
def _observation_log_probs(self, observations, mask):
  """Compute and shape tensor of log probs associated with observations."""

  # Let E be the underlying event shape
  #     M the number of steps in the HMM
  #     N the number of states of the HMM
  #
  # Then the incoming observations have shape
  #
  # observations : batch_o [M] E
  #
  # and the mask (if present) has shape
  #
  # mask : batch_m [M]
  #
  # Let this HMM distribution have batch shape batch_d.
  # We need to broadcast all three of these batch shapes together
  # into the shape batch.
  #
  # We need to move the step dimension to the first dimension to make
  # them suitable for folding or scanning over.
  #
  # When we call `log_prob` for our observations we need to
  # do this for each state the observation could correspond to.
  # We do this by expanding the dimensions by 1 so we end up with:
  #
  # observations : [M] batch [1] [E]
  #
  # After calling `log_prob` we get
  #
  # observation_log_probs : [M] batch [N]
  #
  # We wish to use `mask` to select from this so we also
  # reshape and broadcast it up to shape
  #
  # mask : [M] batch [N]

  observation_tensor_shape = tf.shape(observations)
  observation_batch_shape = observation_tensor_shape[
      :-1 - self._underlying_event_rank]
  observation_event_shape = observation_tensor_shape[
      -1 - self._underlying_event_rank:]

  if mask is not None:
    mask_tensor_shape = tf.shape(mask)
    mask_batch_shape = mask_tensor_shape[:-1]

  batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                           self.batch_shape_tensor())

  if mask is not None:
    batch_shape = tf.broadcast_dynamic_shape(batch_shape,
                                             mask_batch_shape)
  observations = tf.broadcast_to(
      observations,
      tf.concat([batch_shape, observation_event_shape], axis=0))
  observation_rank = tf.rank(observations)
  underlying_event_rank = self._underlying_event_rank
  observations = distribution_util.move_dimension(
      observations, observation_rank - underlying_event_rank - 1, 0)
  observations = tf.expand_dims(
      observations,
      observation_rank - underlying_event_rank)
  observation_log_probs = self._observation_distribution.log_prob(
      observations)

  if mask is not None:
    mask = tf.broadcast_to(
        mask,
        tf.concat([batch_shape, [self._num_steps]], axis=0))
    mask = distribution_util.move_dimension(mask, -1, 0)
    observation_log_probs = tf.where(mask[..., tf.newaxis],
                                     tf.zeros_like(observation_log_probs),
                                     observation_log_probs)

  return observation_log_probs
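# A hypothetical shape walk-through of the reshaping above, assuming TFP and an
# observation distribution with batch shape [N] = [3] (one component per hidden
# state) and event shape [2]; the batch [4] and M = 6 steps are made up.
import tensorflow as tf
import tensorflow_probability as tfp

observations = tf.zeros([4, 6, 2])           # batch [4], M = 6 steps, event [2]
obs = tf.transpose(observations, [1, 0, 2])  # step dim to the front: [6, 4, 2]
obs = tf.expand_dims(obs, axis=-2)           # add the state dim: [6, 4, 1, 2]

observation_dist = tfp.distributions.MultivariateNormalDiag(loc=tf.zeros([3, 2]))
log_probs = observation_dist.log_prob(obs)   # broadcasts [1] against [N] = [3]
print(log_probs.shape)                       # (6, 4, 3): [M] batch [N]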