def batch_matrix_diag_transform(matrix, transform=None, name=None):
  """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.

  Create a trainable covariance defined by a Cholesky factor:

  ```python
  # Transform network layer into 2 x 2 array.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))

  # Make the diagonal positive.  If the upper triangle was zero, this would be
  # a valid Cholesky factor.
  chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # OperatorPDCholesky ignores the upper triangle.
  operator = OperatorPDCholesky(chol)
  ```

  Example of heteroskedastic 2-D linear regression.

  ```python
  # Get a trainable Cholesky factor.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
  chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # Get a trainable mean.
  mu = tf.contrib.layers.fully_connected(activations, 2)

  # This is a fully trainable multivariate normal!
  dist = tf.contrib.distributions.MVNCholesky(mu, chol)

  # Standard log loss.  Minimizing this will "train" mu and chol, and then
  # dist will be a distribution predicting labels as multivariate Gaussians.
  loss = -1 * tf.reduce_mean(dist.log_pdf(labels))
  ```

  Args:
    matrix:  Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
      equal.
    transform:  Element-wise function mapping `Tensors` to `Tensors`.  To be
      applied to the diagonal of `matrix`.  If `None`, `matrix` is returned
      unchanged.  Defaults to `None`.
    name:  A name to give created ops.  Defaults to
      "batch_matrix_diag_transform".

  Returns:
    A `Tensor` with same shape and `dtype` as `matrix`.
  """
  with ops.op_scope([matrix], name, 'batch_matrix_diag_transform'):
    matrix = ops.convert_to_tensor(matrix, name='matrix')
    if transform is None:
      return matrix
    # Replace the diag with transformed diag.
    diag = array_ops.batch_matrix_diag_part(matrix)
    transformed_diag = transform(diag)
    matrix += array_ops.batch_matrix_diag(transformed_diag - diag)
    return matrix
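For readers without a TensorFlow graph at hand, here is a minimal single-matrix NumPy sketch of the same idea; the helper name `matrix_diag_transform_np` and the sample values are illustrative, not part of any TF API:

```python
import numpy as np

def matrix_diag_transform_np(matrix, transform=None):
    """NumPy sketch: replace the diagonal with transform(diagonal)."""
    if transform is None:
        return matrix
    out = matrix.copy()
    diag = np.diagonal(matrix)  # read the current diagonal
    out[np.diag_indices_from(out)] = transform(diag)  # write transformed diag
    return out

softplus = lambda x: np.log1p(np.exp(x))  # maps reals to positives

m = np.array([[-1.0, 2.0],
              [ 3.0, -4.0]])
# Zeroing the upper triangle after the transform yields a valid Cholesky
# factor, since the diagonal is now strictly positive.
chol = np.tril(matrix_diag_transform_np(m, softplus))
print(chol @ chol.T)  # a positive-definite covariance matrix
```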
def variance(self, name="variance"): """Variance of the distribution.""" with ops.name_scope(self.name): with ops.op_scope([self._n, self._p, self._mean], name): p = array_ops.expand_dims( self._p * array_ops.expand_dims( array_ops.ones_like(self._n), -1), -1) variance = -math_ops.batch_matmul( array_ops.expand_dims(self._mean, -1), p, adj_y=True) variance += array_ops.batch_matrix_diag(self._mean) return variance
def variance(self, name="variance"): """Variance of the distribution.""" with ops.name_scope(self.name): with ops.name_scope(name, values=[self._n, self._p, self._mean]): p = array_ops.expand_dims( self._p * array_ops.expand_dims( array_ops.ones_like(self._n), -1), -1) variance = -math_ops.batch_matmul( array_ops.expand_dims(self._mean, -1), p, adj_y=True) variance += array_ops.batch_matrix_diag(self._mean) return variance
def variance(self, name="variance"): """Variance of the distribution.""" with ops.name_scope(self.name): with ops.op_scope([self._alpha, self._alpha_0], name): alpha = array_ops.expand_dims(self._alpha, -1) alpha_0 = array_ops.expand_dims(self._alpha_0, -1) expanded_alpha_0 = array_ops.expand_dims(alpha_0, -1) variance = -math_ops.batch_matmul(alpha, alpha, adj_y=True) / ( expanded_alpha_0 ** 2 * (expanded_alpha_0 + 1)) diagonal = self._alpha / (alpha_0 * (alpha_0 + 1)) variance += array_ops.batch_matrix_diag(diagonal) return variance
def variance(self, name="variance"): """Variance of the distribution.""" with ops.name_scope(self.name): with ops.name_scope(name, values=[self._alpha, self._alpha_0]): alpha = array_ops.expand_dims(self._alpha, -1) alpha_0 = array_ops.expand_dims(self._alpha_0, -1) expanded_alpha_0 = array_ops.expand_dims(alpha_0, -1) variance = -math_ops.batch_matmul(alpha, alpha, adj_y=True) / ( expanded_alpha_0 ** 2 * (expanded_alpha_0 + 1)) diagonal = self._alpha / (alpha_0 * (alpha_0 + 1)) variance += array_ops.batch_matrix_diag(diagonal) return variance
def variance(self, name="mean"): """Class variances for every batch member. The variance for each batch member is defined as the following: ``` Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) * (n + alpha_0) / (1 + alpha_0) ``` where `alpha_0 = sum_j alpha_j`. The covariance between elements in a batch is defined as: ``` Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 * (n + alpha_0) / (1 + alpha_0) ``` Args: name: The name for this op. Returns: A `Tensor` representing the variances for each batch member. """ alpha = self._alpha alpha_sum = self._alpha_sum n = self._n with ops.name_scope(self.name): with ops.name_scope(name, values=[alpha, alpha_sum, n]): expanded_alpha_sum = array_ops.expand_dims(alpha_sum, -1) shared_factor = n * (expanded_alpha_sum + n) / ( expanded_alpha_sum + 1) * array_ops.ones_like(alpha) mean_no_n = alpha / expanded_alpha_sum expanded_mean_no_n = array_ops.expand_dims(mean_no_n, -1) variance = -math_ops.batch_matmul( expanded_mean_no_n, expanded_mean_no_n, adj_y=True) variance += array_ops.batch_matrix_diag(mean_no_n) variance *= array_ops.expand_dims(shared_factor, -1) return variance
def variance(self, name='variance'):
  """Class variances for every batch member.

  The variance for each batch member is defined as the following:

  ```
  Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
    (n + alpha_0) / (1 + alpha_0)
  ```

  where `alpha_0 = sum_j alpha_j`.

  The covariance between components `X_i` and `X_j` (`i != j`) of a single
  draw is defined as:

  ```
  Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 *
    (n + alpha_0) / (1 + alpha_0)
  ```

  Args:
    name: The name for this op.

  Returns:
    A `Tensor` representing the variances for each batch member.
  """
  alpha = self._alpha
  alpha_sum = self._alpha_sum
  n = self._n
  with ops.name_scope(self.name):
    with ops.op_scope([alpha, alpha_sum, n], name):
      expanded_alpha_sum = array_ops.expand_dims(alpha_sum, -1)
      shared_factor = n * (expanded_alpha_sum + n) / (
          expanded_alpha_sum + 1) * array_ops.ones_like(alpha)
      mean_no_n = alpha / expanded_alpha_sum
      expanded_mean_no_n = array_ops.expand_dims(mean_no_n, -1)
      variance = -math_ops.batch_matmul(
          expanded_mean_no_n, expanded_mean_no_n, adj_y=True)
      variance += array_ops.batch_matrix_diag(mean_no_n)
      variance *= array_ops.expand_dims(shared_factor, -1)
      return variance
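The code above factors the Dirichlet-multinomial covariance as a shared over-dispersion factor `n (alpha_0 + n) / (alpha_0 + 1)` times the plain multinomial shape `diag(p) - p pᵀ`, with `p = alpha / alpha_0`. A NumPy check that the diagonal of this factorization matches the docstring's `Var(X_j)` formula, using illustrative values:

```python
import numpy as np

n = 5.0
alpha = np.array([1.0, 2.0, 3.0])
a0 = alpha.sum()
p = alpha / a0  # `mean_no_n` in the TF code

shared = n * (a0 + n) / (a0 + 1)  # over-dispersion vs. a plain multinomial
cov = shared * (np.diag(p) - np.outer(p, p))

# Diagonal matches Var(X_j) = n p_j (1 - p_j) (n + a0) / (1 + a0).
print(np.allclose(np.diag(cov), n * p * (1 - p) * (n + a0) / (1 + a0)))  # True
```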
def _BatchMatrixDiagPartGrad(_, grad):
  # The adjoint of extracting a diagonal: place the incoming gradient on the
  # diagonal of an otherwise-zero matrix.
  return array_ops.batch_matrix_diag(grad)
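To see why the gradient takes this form, a tiny NumPy sketch with an illustrative 2x2 case: the forward op reads only the diagonal entries, so the backward pass must scatter the incoming gradient onto exactly those positions and put zeros everywhere else.

```python
import numpy as np

# Forward: y = diag_part(x) reads x[0, 0] and x[1, 1].
# Backward: dL/dx places dL/dy on those positions, zero elsewhere.
grad_y = np.array([0.5, -2.0])  # incoming gradient w.r.t. the diagonal
grad_x = np.diag(grad_y)        # what batch_matrix_diag(grad) produces per matrix
print(grad_x)
# [[ 0.5  0. ]
#  [ 0.  -2. ]]
```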
def _sqrt_to_dense(self):
  return array_ops.batch_matrix_diag(self._diag)
def _to_dense(self):
  return array_ops.batch_matrix_diag(math_ops.square(self._diag))
def _to_dense(self):
  diag = array_ops.ones(self.vector_shape(), dtype=self.dtype)
  dense = array_ops.batch_matrix_diag(diag)
  dense.set_shape(self.get_shape())
  return dense
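These three operator helpers materialize diagonal representations as dense matrices: the square-root operator is `diag(d)`, the full operator is `diag(d²)` (since `diag(d) diag(d)ᵀ = diag(d²)`), and the identity operator is `diag(1)`. A NumPy sketch of the relationship, with illustrative diagonal values:

```python
import numpy as np

d = np.array([1.0, 2.0, 3.0])

sqrt_dense = np.diag(d)              # analogue of _sqrt_to_dense
dense = np.diag(np.square(d))        # analogue of _to_dense for the diag operator
identity = np.diag(np.ones_like(d))  # analogue of _to_dense for the identity operator

# The dense operator is the square root times its own transpose.
print(np.allclose(sqrt_dense @ sqrt_dense.T, dense))  # True
```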