def _variance(self):
    """Assemble the matrix this distribution reports as its variance.

    Let `scale = alpha_sum * sqrt(1 + alpha_sum)` and `a = alpha / scale`.
    The returned [batch-]matrix has `-a[i] * a[j]` in its off-diagonal
    entries and `a[i] * (alpha_sum / scale - a[i])` on its diagonal.

    Returns:
      The `Tensor` produced by `batch_matrix_set_diag`.
    """
    normalizer = self.alpha_sum * math_ops.sqrt(1.0 + self.alpha_sum)
    scaled_alpha = self.alpha / normalizer

    # Negated outer product of the scaled concentration with itself.
    as_column = array_ops.expand_dims(scaled_alpha, dim=-1)
    as_row = array_ops.expand_dims(scaled_alpha, dim=-2)
    off_diagonal = -math_ops.batch_matmul(as_column, as_row)

    # Only the diagonal differs from the outer product, so overwrite it.
    diagonal = scaled_alpha * (self.alpha_sum / normalizer - scaled_alpha)
    return array_ops.batch_matrix_set_diag(off_diagonal, diagonal)
def _variance(self):
    """Assemble the matrix this distribution reports as its variance.

    `self.p` is first multiplied by `ones_like(self.n)` (with a trailing
    axis added) so that it is broadcast against `n`. The result has
    `-mean[i] * p[j]` off the diagonal and `mean[i] - mean[i] * p[i]` on it,
    where `mean` is `self._mean_val`.

    Returns:
      The `Tensor` produced by `batch_matrix_set_diag`.
    """
    broadcast_p = self.p * array_ops.expand_dims(
        array_ops.ones_like(self.n), -1)

    mean_column = array_ops.expand_dims(self._mean_val, -1)
    p_row = array_ops.expand_dims(broadcast_p, -2)
    mean_p_outer = math_ops.batch_matmul(mean_column, p_row)

    # Off-diagonal entries are the negated outer product; the diagonal is
    # replaced separately.
    negated_outer = -mean_p_outer
    diagonal = self._mean_val - self._mean_val * broadcast_p
    return array_ops.batch_matrix_set_diag(negated_outer, diagonal)
def batch_matrix_diag_transform(matrix, transform=None, name=None):
    """Apply `transform` to the diagonal of a [batch-]matrix only.

    Every off-diagonal entry of `matrix` is passed through untouched.

    Example: a trainable covariance defined by a Cholesky factor.

    ```python
    # Transform network layer into 2 x 2 array.
    matrix_values = tf.contrib.layers.fully_connected(activations, 4)
    matrix = tf.reshape(matrix_values, (batch_size, 2, 2))

    # Make the diagonal positive.  If the upper triangle was zero, this would
    # be a valid Cholesky factor.
    chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)

    # OperatorPDCholesky ignores the upper triangle.
    operator = OperatorPDCholesky(chol)
    ```

    Example: heteroskedastic 2-D linear regression.

    ```python
    # Get a trainable Cholesky factor.
    matrix_values = tf.contrib.layers.fully_connected(activations, 4)
    matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
    chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)

    # Get a trainable mean.
    mu = tf.contrib.layers.fully_connected(activations, 2)

    # This is a fully trainable multivariate normal!
    dist = tf.contrib.distributions.MVNCholesky(mu, chol)

    # Standard log loss.  Minimizing this will "train" mu and chol, and then
    # dist will be a distribution predicting labels as multivariate Gaussians.
    loss = -1 * tf.reduce_mean(dist.log_pdf(labels))
    ```

    Args:
      matrix:  Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
        equal.
      transform:  Element-wise function mapping `Tensors` to `Tensors`.  It is
        applied to the diagonal of `matrix`.  When `None` (the default),
        `matrix` is returned unchanged (after conversion to a `Tensor`).
      name:  A name to give created ops.
        Defaults to "batch_matrix_diag_transform".

    Returns:
      A `Tensor` with same shape and `dtype` as `matrix`.
    """
    with ops.name_scope(name, "batch_matrix_diag_transform", [matrix]):
        matrix = ops.convert_to_tensor(matrix, name="matrix")
        if transform is not None:
            # Pull the diagonal out, transform it, and write it back in place.
            original_diag = array_ops.batch_matrix_diag_part(matrix)
            return array_ops.batch_matrix_set_diag(
                matrix, transform(original_diag))
        return matrix
def _sample_n(self, n, seed):
    """Draw `n` samples by building random lower-triangular factors.

    A standard-normal tensor is masked to its lower triangle, its diagonal is
    replaced by the square root of Gamma draws, and the result is multiplied
    by `self.scale_operator_pd.sqrt_matmul`. Unless
    `self.cholesky_input_output_matrices` is truthy, each factor `x` is then
    turned into `x x^T`.

    Args:
      n: Number of samples; used as the leading dimension of the draw.
      seed: Seed forwarded to both random ops.

    Returns:
      A `Tensor` laid out as sample dims, then batch dims, then event dims.
    """
    # NOTE: ops are created in a fixed order on purpose; graph-level seeding
    # makes random results sensitive to the order and number of ops created.
    batch_dims = self.batch_shape()
    event_dims = self.event_shape()
    num_batch_axes = array_ops.shape(batch_dims)[0]
    total_axes = num_batch_axes + 3  # 1 sample axis + 2 event axes.
    normal_shape = array_ops.concat(0, ((n,), batch_dims, event_dims))

    # Complexity: O(nbk^2)
    # NOTE(review): the same `seed` is handed to both random ops below;
    # confirm this does not couple the normal and gamma streams.
    samples = random_ops.random_normal(
        shape=normal_shape,
        mean=0.0,
        stddev=1.0,
        dtype=self.dtype,
        seed=seed)

    # Complexity: O(nbk)
    # Gamma(alpha=k/2, beta=1/2) is the same distribution as ChiSquared(k).
    gamma_draws = random_ops.random_gamma(
        shape=(n,),
        alpha=self._multi_gamma_sequence(0.5 * self.df, self.dimension),
        beta=0.5,
        dtype=self.dtype,
        seed=seed)

    # Complexity: O(nbk^2) -- keep only the lower triangle.
    samples = array_ops.batch_matrix_band_part(samples, -1, 0)

    # Complexity: O(nbk) -- diagonal becomes sqrt of the gamma draws.
    samples = array_ops.batch_matrix_set_diag(
        samples, math_ops.sqrt(gamma_draws))

    # Move the sample axis to the end and fold it into the last matrix axis
    # so that a single batch matmul handles all samples.
    # Complexity: O(nbk^2)
    to_back = array_ops.concat(0, (math_ops.range(1, total_axes), (0,)))
    samples = array_ops.transpose(samples, to_back)
    folded_shape = array_ops.concat(0, (batch_dims, (event_dims[0], -1)))
    samples = array_ops.reshape(samples, folded_shape)

    # Complexity: O(nbM) where M is the complexity of the operator solving a
    # vector system.  E.g., for OperatorPDDiag, each matmul is O(k^2), so
    # this complexity is O(nbk^2). For OperatorPDCholesky, each matmul is
    # O(k^3) so this step has complexity O(nbk^3).
    samples = self.scale_operator_pd.sqrt_matmul(samples)

    # Unfold the sample axis and move it back to the front.
    # Complexity: O(nbk^2)
    unfolded_shape = array_ops.concat(0, (batch_dims, event_dims, (n,)))
    samples = array_ops.reshape(samples, unfolded_shape)
    to_front = array_ops.concat(
        0, ((total_axes - 1,), math_ops.range(0, total_axes - 1)))
    samples = array_ops.transpose(samples, to_front)

    if not self.cholesky_input_output_matrices:
        # Complexity: O(nbk^3)
        samples = math_ops.batch_matmul(samples, samples, adj_y=True)

    return samples
def _variance(self):
    """Assemble the matrix this distribution reports as its variance.

    With `m = alpha / alpha_sum`, a base matrix is built with `-m[i] * m[j]`
    off the diagonal and `m[i] * (1 - m[i])` on it. That matrix is then
    multiplied by `n * (alpha_sum + n) / (alpha_sum + 1)`, broadcast to the
    shape of `alpha` and given a trailing axis.

    Returns:
      The scaled [batch-]matrix `Tensor`.
    """
    # Trailing axis so the sum broadcasts against `alpha`.
    total = array_ops.expand_dims(self.alpha_sum, -1)
    mean_probs = self.alpha / total

    column = array_ops.expand_dims(mean_probs, -1)
    row = array_ops.expand_dims(mean_probs, -2)
    base = -math_ops.batch_matmul(column, row)
    base = array_ops.batch_matrix_set_diag(
        base, mean_probs * (1. - mean_probs))

    # The `ones_like` factor broadcasts the scalar-per-batch term to the
    # shape of `alpha` before the trailing axis is appended.
    numerator = self.n * (total + self.n)
    common_scale = numerator / (total + 1) * array_ops.ones_like(self.alpha)
    return base * array_ops.expand_dims(common_scale, -1)
def _BatchMatrixSetDiagGrad(op, grad):
    """Compute the two gradients returned for a set-diag op.

    The first returned tensor is `grad` with its diagonal overwritten by
    zeros; the second is the diagonal of `grad`.

    Args:
      op: The op being differentiated; `op.inputs[0]` and `op.inputs[1]` are
        read for their static shapes only.
      grad: Incoming gradient `Tensor`.

    Returns:
      A tuple `(grad_input, grad_diag)`.
    """
    # Combine every static hint available for the diagonal's shape: the
    # second input itself, and the first input / `grad` minus their last axis.
    static_diag_shape = (
        op.inputs[1].get_shape()
        .merge_with(op.inputs[0].get_shape()[:-1])
        .merge_with(grad.get_shape()[:-1]))

    if static_diag_shape.is_fully_defined():
        zeros_shape = static_diag_shape.as_list()
    else:
        # Fall back to a runtime shape: all but the last axis of `grad`.
        runtime_shape = array_ops.shape(grad)
        zeros_shape = array_ops.slice(
            runtime_shape, [0], [array_ops.rank(grad) - 1])

    zero_diag = array_ops.zeros(zeros_shape, dtype=grad.dtype)
    grad_input = array_ops.batch_matrix_set_diag(grad, zero_diag)
    grad_diag = array_ops.batch_matrix_diag_part(grad)
    return (grad_input, grad_diag)
def sample_n(self, n, seed=None, name='sample'):  # pylint: disable=line-too-long
    """Generate `n` samples.

    Complexity: O(nbk^3)

    The sampling procedure is based on the [Bartlett decomposition](
    https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition)
    and [using a Gamma distribution to generate Chi2 random variates](
    https://en.wikipedia.org/wiki/Chi-squared_distribution#Gamma.2C_exponential.2C_and_related_distributions).

    Args:
      n: `Scalar` `Tensor` of type `int32`, the number of observations to
        sample.  Other integer types are rejected.
      seed: Python integer; random number generator seed.
      name: The name of this op.

    Returns:
      samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape`
        with values of type `self.dtype`.

    Raises:
      TypeError: if `n`, once converted to a `Tensor`, is not `int32`.
    """
    with ops.name_scope(self.name):
        with ops.name_scope(name, values=[n] + list(self.inputs.values())):
            n = ops.convert_to_tensor(n, name='n')
            if n.dtype != dtypes.int32:
                raise TypeError('n.dtype=%s which is not int32' % n.dtype)
            batch_shape = self.batch_shape()
            event_shape = self.event_shape()
            batch_ndims = array_ops.shape(batch_shape)[0]

            ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
            shape = array_ops.concat(0, ((n,), batch_shape, event_shape))

            # Complexity: O(nbk^2)
            # NOTE(review): the same `seed` is handed to both random ops
            # below; confirm this does not couple the two streams.
            x = random_ops.random_normal(shape=shape,
                                         mean=0.,
                                         stddev=1.,
                                         dtype=self.dtype,
                                         seed=seed)

            # Complexity: O(nbk)
            # This parametrization is equivalent to Chi2, i.e.,
            # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
            g = random_ops.random_gamma(shape=(n,),
                                        alpha=self._multi_gamma_sequence(
                                            0.5 * self.df, self.dimension),
                                        beta=0.5,
                                        dtype=self.dtype,
                                        seed=seed)

            # Complexity: O(nbk^2)
            x = array_ops.batch_matrix_band_part(x, -1, 0)  # Tri-lower.

            # Complexity: O(nbk)
            x = array_ops.batch_matrix_set_diag(x, math_ops.sqrt(g))

            # Make batch-op ready: move the sample axis last and fold it into
            # the trailing matrix axis.
            # Complexity: O(nbk^2)
            perm = array_ops.concat(0, (math_ops.range(1, ndims), (0,)))
            x = array_ops.transpose(x, perm)
            shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1)))
            x = array_ops.reshape(x, shape)

            # Complexity: O(nbM) where M is the complexity of the operator
            # solving a vector system.  E.g., for OperatorPDDiag, each matmul
            # is O(k^2), so this complexity is O(nbk^2). For
            # OperatorPDCholesky, each matmul is O(k^3) so this step has
            # complexity O(nbk^3).
            x = self.scale_operator_pd.sqrt_matmul(x)

            # Undo make batch-op ready.
            # Complexity: O(nbk^2)
            shape = array_ops.concat(0, (batch_shape, event_shape, (n,)))
            x = array_ops.reshape(x, shape)
            perm = array_ops.concat(
                0, ((ndims-1,), math_ops.range(0, ndims-1)))
            x = array_ops.transpose(x, perm)

            if not self.cholesky_input_output_matrices:
                # Complexity: O(nbk^3)
                x = math_ops.batch_matmul(x, x, adj_y=True)

            # Set shape hints.
            if self.scale_operator_pd.get_shape().ndims is not None:
                x.set_shape(tensor_shape.TensorShape(
                    [tensor_util.constant_value(n)] +
                    self.scale_operator_pd.get_shape().as_list()))
            elif x.get_shape().ndims is not None:
                # Only the leading (sample) dimension can be refined here.
                # `Dimension.merge_with` returns a new object rather than
                # updating the tensor, so the hint is applied via `set_shape`,
                # which merges it with what is already known.
                sample_dim = tensor_shape.TensorShape(
                    [tensor_util.constant_value(n)])
                x.set_shape(sample_dim.concatenate(x.get_shape()[1:]))

            return x
def _add_to_tensor(self, mat):
    """Return `mat` with `square(self._diag)` added onto its diagonal.

    Only the diagonal is read and rewritten; off-diagonal entries of `mat`
    are passed through unchanged.
    """
    existing_diag = array_ops.batch_matrix_diag_part(mat)
    updated_diag = math_ops.square(self._diag) + existing_diag
    return array_ops.batch_matrix_set_diag(mat, updated_diag)
def sample_n(self, n, seed=None, name="sample"):  # pylint: disable=line-too-long
    """Generate `n` samples.

    Complexity: O(nbk^3)

    The sampling procedure is based on the [Bartlett decomposition](
    https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition)
    and [using a Gamma distribution to generate Chi2 random variates](
    https://en.wikipedia.org/wiki/Chi-squared_distribution#Gamma.2C_exponential.2C_and_related_distributions).

    Args:
      n: `Scalar` `Tensor` of type `int32`, the number of observations to
        sample.  Other integer types are rejected.
      seed: Python integer; random number generator seed.
      name: The name of this op.

    Returns:
      samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape`
        with values of type `self.dtype`.

    Raises:
      TypeError: if `n`, once converted to a `Tensor`, is not `int32`.
    """
    with ops.name_scope(self.name):
        with ops.name_scope(name, values=[n] + list(self.inputs.values())):
            n = ops.convert_to_tensor(n, name="n")
            if n.dtype != dtypes.int32:
                raise TypeError("n.dtype=%s which is not int32" % n.dtype)
            batch_shape = self.batch_shape()
            event_shape = self.event_shape()
            batch_ndims = array_ops.shape(batch_shape)[0]

            ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
            shape = array_ops.concat(0, ((n,), batch_shape, event_shape))

            # Complexity: O(nbk^2)
            # NOTE(review): the same `seed` is handed to both random ops
            # below; confirm this does not couple the two streams.
            x = random_ops.random_normal(
                shape=shape, mean=0.0, stddev=1.0, dtype=self.dtype, seed=seed
            )

            # Complexity: O(nbk)
            # This parametrization is equivalent to Chi2, i.e.,
            # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
            g = random_ops.random_gamma(
                shape=(n,),
                alpha=self._multi_gamma_sequence(0.5 * self.df, self.dimension),
                beta=0.5,
                dtype=self.dtype,
                seed=seed,
            )

            # Complexity: O(nbk^2)
            x = array_ops.batch_matrix_band_part(x, -1, 0)  # Tri-lower.

            # Complexity: O(nbk)
            x = array_ops.batch_matrix_set_diag(x, math_ops.sqrt(g))

            # Make batch-op ready: move the sample axis last and fold it into
            # the trailing matrix axis.
            # Complexity: O(nbk^2)
            perm = array_ops.concat(0, (math_ops.range(1, ndims), (0,)))
            x = array_ops.transpose(x, perm)
            shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1)))
            x = array_ops.reshape(x, shape)

            # Complexity: O(nbM) where M is the complexity of the operator
            # solving a vector system.  E.g., for OperatorPDDiag, each matmul
            # is O(k^2), so this complexity is O(nbk^2). For
            # OperatorPDCholesky, each matmul is O(k^3) so this step has
            # complexity O(nbk^3).
            x = self.scale_operator_pd.sqrt_matmul(x)

            # Undo make batch-op ready.
            # Complexity: O(nbk^2)
            shape = array_ops.concat(0, (batch_shape, event_shape, (n,)))
            x = array_ops.reshape(x, shape)
            perm = array_ops.concat(
                0, ((ndims - 1,), math_ops.range(0, ndims - 1))
            )
            x = array_ops.transpose(x, perm)

            if not self.cholesky_input_output_matrices:
                # Complexity: O(nbk^3)
                x = math_ops.batch_matmul(x, x, adj_y=True)

            # Set shape hints.
            if self.scale_operator_pd.get_shape().ndims is not None:
                x.set_shape(
                    tensor_shape.TensorShape(
                        [tensor_util.constant_value(n)]
                        + self.scale_operator_pd.get_shape().as_list()
                    )
                )
            elif x.get_shape().ndims is not None:
                # Only the leading (sample) dimension can be refined here.
                # `Dimension.merge_with` returns a new object rather than
                # updating the tensor, so the hint is applied via `set_shape`,
                # which merges it with what is already known.
                sample_dim = tensor_shape.TensorShape(
                    [tensor_util.constant_value(n)]
                )
                x.set_shape(sample_dim.concatenate(x.get_shape()[1:]))

            return x
def _add_to_tensor(self, mat):
    """Return `mat` with the constant 1 added onto its diagonal.

    Touching only the diagonal avoids building a full matrix to add; the
    original author notes this makes the addition O(k).
    """
    existing_diag = array_ops.batch_matrix_diag_part(mat)
    one = constant_op.constant(1, dtype=self.dtype)
    return array_ops.batch_matrix_set_diag(mat, one + existing_diag)