def _batch_shape(self):
  scalar_shape = tf.TensorShape([])
  return tf.broadcast_static_shape(
      scalar_shape if self.amplitude is None else self.amplitude.shape,
      scalar_shape if self.length_scale is None else self.length_scale.shape)
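# A minimal sketch (hypothetical shapes, not from the library) of how
# `tf.broadcast_static_shape` combines the two parameter shapes above. A
# `None` parameter contributes the scalar shape `TensorShape([])`, which
# broadcasts with anything.
import tensorflow as tf

amplitude_shape = tf.TensorShape([3, 1])
length_scale_shape = tf.TensorShape([1, 5])
print(tf.broadcast_static_shape(amplitude_shape, length_scale_shape))  # (3, 5)
print(tf.broadcast_static_shape(tf.TensorShape([]), length_scale_shape))  # (1, 5)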
def _batch_shape(self):
  return tf.broadcast_static_shape(self.loc.shape, self.concentration.shape)
def _batch_shape(self):
  return tf.broadcast_static_shape(
      (self._probs if self._logits is None else self._logits).shape[:-1],
      self.total_count.shape)
def _mvnormal_quasi(sample_shape,
                    mean,
                    random_type,
                    seed,
                    covariance_matrix=None,
                    scale_matrix=None,
                    validate_args=False,
                    dtype=None,
                    **kwargs):
  """Returns normal draws using low-discrepancy sequences."""
  if scale_matrix is None and covariance_matrix is None:
    scale_matrix = tf.linalg.eye(tf.shape(mean)[-1], dtype=mean.dtype)
  elif scale_matrix is None and covariance_matrix is not None:
    covariance_matrix = tf.convert_to_tensor(
        covariance_matrix, dtype=dtype, name='covariance_matrix')
    scale_matrix = tf.linalg.cholesky(covariance_matrix)
  else:
    scale_matrix = tf.convert_to_tensor(
        scale_matrix, dtype=dtype, name='scale_matrix')
  scale_shape = scale_matrix.shape
  dim = scale_shape[-1]
  if mean is None:
    mean = tf.zeros([dim], dtype=scale_matrix.dtype)
  # Batch shape of the output
  batch_shape = tf.broadcast_static_shape(mean.shape, scale_shape[:-1])
  # Reverse elements of the batch shape
  batch_shape_reverse = tf.TensorShape(reversed(batch_shape))
  # Transposed shape of the output
  output_shape_t = tf.concat([batch_shape_reverse, sample_shape], -1)
  # Number of quasi random samples
  num_samples = tf.reduce_prod(output_shape_t) // dim
  # Number of initial low discrepancy sequence numbers to skip
  skip = kwargs.get('skip', 0)
  if random_type == RandomType.SOBOL:
    # Shape [num_samples, dim] of the Sobol samples
    low_discrepancy_seq = sobol.sample(
        dim=dim, num_results=num_samples, skip=skip, dtype=mean.dtype)
  else:  # HALTON or HALTON_RANDOMIZED random_type
    randomization_params = kwargs.get('randomization_params')
    randomized = random_type == RandomType.HALTON_RANDOMIZED
    # Shape [num_samples, dim] of the Halton samples
    low_discrepancy_seq, _ = halton.sample(
        dim=dim,
        sequence_indices=tf.range(skip, skip + num_samples),
        randomized=randomized,
        randomization_params=randomization_params,
        seed=seed,
        validate_args=validate_args,
        dtype=mean.dtype)
  # Transpose to the shape [dim, num_samples]
  low_discrepancy_seq = tf.transpose(low_discrepancy_seq)
  size_sample = tf.size(sample_shape)
  size_batch = tf.size(batch_shape)
  # Permutation for `output_shape_t` to the output shape
  permutation = tf.concat([
      tf.range(size_batch, size_batch + size_sample),
      tf.range(size_batch - 1, -1, -1)
  ], -1)
  # Reshape the low-discrepancy samples to the correct output shape
  low_discrepancy_seq = tf.transpose(
      tf.reshape(low_discrepancy_seq, output_shape_t), permutation)
  # Apply the inverse Normal CDF to the low-discrepancy samples to obtain the
  # corresponding Normal samples
  samples = tf.math.erfinv((low_discrepancy_seq - 0.5) * 2) * _SQRT_2
  return mean + tf.linalg.matvec(scale_matrix, samples)
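# A minimal usage sketch for `_mvnormal_quasi` (not from the source): assumes
# the module-level names `RandomType`, `sobol`/`halton`, and `_SQRT_2` used
# above are available, and that `sample_shape` may be a Python list. Values
# and shapes are hypothetical.
import tensorflow as tf

mean = tf.constant([[0.1, -0.2], [0.3, 0.0]], dtype=tf.float64)  # 2 means in R^2
covariance = tf.constant([[1.0, 0.5], [0.5, 2.0]], dtype=tf.float64)
samples = _mvnormal_quasi(
    sample_shape=[1000],
    mean=mean,
    random_type=RandomType.SOBOL,
    seed=None,  # unused by the Sobol branch
    covariance_matrix=covariance,
    skip=64)  # skip the first 64 Sobol points
# `samples` has shape [1000, 2, 2]: `sample_shape` followed by the broadcast
# batch/event shape of `mean` and the Cholesky factor of `covariance`.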
def _parameter_control_dependencies(self, is_init): """Validate parameters.""" bw, bh, kd = None, None, None try: shape = tf.broadcast_static_shape(self.bin_widths.shape, self.bin_heights.shape) except ValueError as e: raise ValueError( '`bin_widths`, `bin_heights` must broadcast: {}'.format( str(e))) bin_sizes_shape = shape try: shape = tf.broadcast_static_shape(shape[:-1], self.knot_slopes.shape[:-1]) except ValueError as e: raise ValueError( '`bin_widths`, `bin_heights`, and `knot_slopes` must broadcast on ' 'batch axes: {}'.format(str(e))) assertions = [] if (tensorshape_util.is_fully_defined(bin_sizes_shape[-1:]) and tensorshape_util.is_fully_defined( self.knot_slopes.shape[-1:])): if tensorshape_util.rank(self.knot_slopes.shape) > 0: num_interior_knots = tensorshape_util.dims( bin_sizes_shape)[-1] - 1 if tensorshape_util.dims(self.knot_slopes.shape)[-1] not in ( 1, num_interior_knots): raise ValueError( 'Innermost axis of non-scalar `knot_slopes` must broadcast with ' '{}; got {}.'.format(num_interior_knots, self.knot_slopes.shape)) elif self.validate_args: if is_init != any( tensor_util.is_ref(t) for t in (self.bin_widths, self.bin_heights, self.knot_slopes)): bw = tf.convert_to_tensor( self.bin_widths) if bw is None else bw bh = tf.convert_to_tensor( self.bin_heights) if bh is None else bh kd = _ensure_at_least_1d( self.knot_slopes) if kd is None else kd shape = tf.broadcast_dynamic_shape( tf.shape((bw + bh)[..., :-1]), tf.shape(kd)) assertions.append( assert_util.assert_greater( tf.shape(shape)[0], tf.zeros([], dtype=shape.dtype), message= '`(bin_widths + bin_heights)[..., :-1]` must broadcast ' 'with `knot_slopes` to at least 1-D.')) if not self.validate_args: assert not assertions return assertions if (is_init != tensor_util.is_ref(self.bin_widths) or is_init != tensor_util.is_ref(self.bin_heights)): bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh assertions += [ assert_util.assert_near( tf.reduce_sum(bw, axis=-1), tf.reduce_sum(bh, axis=-1), message='`sum(bin_widths, axis=-1)` must equal ' '`sum(bin_heights, axis=-1)`.'), ] if is_init != tensor_util.is_ref(self.bin_widths): bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw assertions += [ assert_util.assert_positive( bw, message='`bin_widths` must be positive.'), ] if is_init != tensor_util.is_ref(self.bin_heights): bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh assertions += [ assert_util.assert_positive( bh, message='`bin_heights` must be positive.'), ] if is_init != tensor_util.is_ref(self.knot_slopes): kd = _ensure_at_least_1d(self.knot_slopes) if kd is None else kd assertions += [ assert_util.assert_positive( kd, message='`knot_slopes` must be positive.'), ] return assertions
def _batch_shape(self):
  return tf.broadcast_static_shape(self.loc.shape, self.scale.shape)
def _batch_shape(self):
  if self.to_shape is None:
    return tf.broadcast_static_shape(
        self.distribution.batch_shape,
        tf.TensorShape(tf.get_static_value(self.with_shape)))
  return tf.TensorShape(tf.get_static_value(self.to_shape))
def _batch_shape(self):
  return tf.broadcast_static_shape(self.loc.shape, self.cutpoints.shape[:-1])
def _kl_brute_force(a, b, name=None):
  """Batched KL divergence `KL(a || b)` for multivariate Normals.

  With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`, `mu_b`
  and covariance `C_a`, `C_b` respectively,

  ```
  KL(a || b) = 0.5 * ( L - k + T + Q ),
  L := Log[Det(C_b)] - Log[Det(C_a)]
  T := trace(C_b^{-1} C_a),
  Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
  ```

  This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
  methods for solving systems with `C_b` may be available, a dense version of
  (the square root of) `C_a` is used, so performance is `O(B s k**2)` where
  `B` is the batch size, and `s` is the cost of solving `C_b x = y` for
  vectors `x` and `y`.

  Args:
    a: Instance of `MultivariateNormalLinearOperator`.
    b: Instance of `MultivariateNormalLinearOperator`.
    name: (optional) name to use for created ops. Default "kl_mvn".

  Returns:
    Batchwise `KL(a || b)`.
  """

  def squared_frobenius_norm(x):
    """Helper to make KL calculation slightly more readable."""
    # http://mathworld.wolfram.com/FrobeniusNorm.html
    # The gradient of KL[p,q] is not defined when p==q. The culprit is
    # tf.norm, i.e., we cannot use the commented out code.
    # return tf.square(tf.norm(x, ord="fro", axis=[-2, -1]))
    return tf.reduce_sum(tf.square(x), axis=[-2, -1])

  # TODO(b/35041439): See also b/35040945. Remove this function once LinOp
  # supports something like:
  #   A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
  def is_diagonal(x):
    """Helper to identify if `LinearOperator` has only a diagonal component."""
    return (isinstance(x, tf.linalg.LinearOperatorIdentity) or
            isinstance(x, tf.linalg.LinearOperatorScaledIdentity) or
            isinstance(x, tf.linalg.LinearOperatorDiag))

  with tf.name_scope(name or 'kl_mvn'):
    # Calculation is based on:
    # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
    # and,
    # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
    # i.e.,
    #   If Ca = AA', Cb = BB', then
    #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
    #                  = tr[inv(B) A A' inv(B)']
    #                  = tr[(inv(B) A) (inv(B) A)']
    #                  = sum_{ij} (inv(B) A)_{ij}**2
    #                  = ||inv(B) A||_F**2
    # where ||.||_F is the Frobenius norm and the second equality follows from
    # the cyclic permutation property.
    if is_diagonal(a.scale) and is_diagonal(b.scale):
      # Using `stddev` because it handles expansion of Identity cases.
      b_inv_a = (a.stddev() / b.stddev())[..., tf.newaxis]
    else:
      b_inv_a = b.scale.solve(a.scale.to_dense())
    kl_div = (
        b.scale.log_abs_determinant() - a.scale.log_abs_determinant() +
        0.5 * (-tf.cast(a.scale.domain_dimension_tensor(), a.dtype) +
               squared_frobenius_norm(b_inv_a) + squared_frobenius_norm(
                   b.scale.solve((b.mean() - a.mean())[..., tf.newaxis]))))
    tensorshape_util.set_shape(
        kl_div, tf.broadcast_static_shape(a.batch_shape, b.batch_shape))
    return kl_div
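# A minimal usage sketch (hypothetical values). `MultivariateNormalDiag` is a
# `MultivariateNormalLinearOperator` subclass, so this exercises the diagonal
# fast path of `is_diagonal`; swapping `MultivariateNormalTriL` in for `b`
# would exercise the dense `solve` branch instead.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
a = tfd.MultivariateNormalDiag(loc=[0., 0.], scale_diag=[1., 1.])
b = tfd.MultivariateNormalDiag(loc=[1., -1.], scale_diag=[2., 0.5])
kl = _kl_brute_force(a, b)  # scalar Tensor; should agree with tfd.kl_divergence(a, b)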
def _batch_shape(self):
  x = self._probs if self._logits is None else self._logits
  return tf.broadcast_static_shape(self.total_count.shape, x.shape)
def _batch_shape(self):
  if self._is_fixed_inputs_empty():
    return self._base_kernel.batch_shape
  return tf.broadcast_static_shape(
      self._base_kernel.batch_shape,
      self._fixed_inputs.shape[:-(self._base_kernel.feature_ndims + 1)])
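# A short sketch of the shape arithmetic above (hypothetical shapes): with
# `feature_ndims = 1` and `fixed_inputs` of shape [7, 1, 20, 3] (batch
# [7, 1], 20 fixed points in R^3), `shape[:-(feature_ndims + 1)]` keeps only
# the batch part, which then broadcasts against the base kernel's batch shape.
import tensorflow as tf

fixed_inputs_shape = tf.TensorShape([7, 1, 20, 3])
feature_ndims = 1
batch_part = fixed_inputs_shape[:-(feature_ndims + 1)]  # TensorShape([7, 1])
print(tf.broadcast_static_shape(tf.TensorShape([5]), batch_part))  # (7, 5)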
def _batch_shape(self):
  return tf.broadcast_static_shape(
      self.distribution.batch_shape,
      self.mixture_distribution.logits.shape)[:-1]
def _parameter_control_dependencies(self, is_init):
  assertions = []
  if is_init:
    axis_ = tf.get_static_value(self._axis)
    if axis_ is not None and axis_ < 0:
      raise ValueError('Axis should be non-negative, %d was given' % axis_)
    if axis_ is None:
      # No static value available; assert on the dynamic value instead.
      assertions.append(tf.assert_greater_equal(self._axis, 0))
    all_event_shapes = [d.event_shape for d in self._distributions]
    if all(tensorshape_util.is_fully_defined(event_shape)
           for event_shape in all_event_shapes):
      if all_event_shapes[1:] != all_event_shapes[:-1]:
        raise ValueError('Distributions must have the same `event_shape`; '
                         'found: {}'.format(all_event_shapes))
    all_batch_shapes = [d.batch_shape for d in self._distributions]
    if all(tensorshape_util.is_fully_defined(batch_shape)
           for batch_shape in all_batch_shapes):
      batch_shape = all_batch_shapes[0].as_list()
      batch_shape[self._axis] = 1
      for b in all_batch_shapes[1:]:
        b = b.as_list()
        if len(batch_shape) != len(b):
          raise ValueError('Incompatible batch shape %s with %s' %
                           (batch_shape, b))
        b[self._axis] = 1
        tf.broadcast_static_shape(
            tensorshape_util.constant_value_as_shape(batch_shape),
            tensorshape_util.constant_value_as_shape(b))

  if not self.validate_args:
    return []

  # Validate that event shapes all match.
  all_event_shapes = [d.event_shape for d in self._distributions]
  if not all(tensorshape_util.is_fully_defined(event_shape)
             for event_shape in all_event_shapes):
    all_event_shape_tensors = [
        d.event_shape_tensor() for d in self._distributions
    ]

    def _get_shapes(static_shape, dynamic_shape):
      if tensorshape_util.is_fully_defined(static_shape):
        return static_shape
      else:
        return dynamic_shape

    event_shapes = tf.nest.map_structure(_get_shapes, all_event_shapes,
                                         all_event_shape_tensors)
    event_shapes = tf.nest.flatten(event_shapes)
    assertions.extend(
        assert_util.assert_equal(
            e1, e2, message='Distributions should have same event shapes.')
        for e1, e2 in zip(event_shapes[1:], event_shapes[:-1]))

  # Validate that batch shapes are broadcastable and concatenable along
  # the specified axis.
  if not all(tensorshape_util.is_fully_defined(d.batch_shape)
             for d in self._distributions):
    for i, d in enumerate(self._distributions[:-1]):
      assertions.append(
          tf.assert_equal(
              tf.size(d.batch_shape_tensor()),
              tf.size(self._distributions[i + 1].batch_shape_tensor())))

    batch_shape_tensors = [
        ps.tensor_scatter_nd_update(d.batch_shape_tensor(), [[self._axis]],
                                    [1]) for d in self._distributions
    ]
    assertions.append(
        functools.reduce(tf.broadcast_dynamic_shape,
                         batch_shape_tensors[1:],
                         batch_shape_tensors[0]))
  return assertions
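# A minimal sketch (hypothetical shapes) of the scatter-update trick above:
# setting the concatenation axis to 1 makes the batch shapes comparable, so
# any mismatch on the remaining axes surfaces when broadcasting them together.
import functools
import tensorflow as tf

axis = 1
batch_shape_tensors = [tf.constant([4, 2, 3]), tf.constant([4, 5, 3])]
normalized = [
    tf.tensor_scatter_nd_update(s, [[axis]], [1])  # both become [4, 1, 3]
    for s in batch_shape_tensors
]
print(functools.reduce(tf.broadcast_dynamic_shape, normalized[1:],
                       normalized[0]))  # tf.Tensor([4 1 3], ...)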