def _add(self, op1, op2, operator_name, hints):
  """Sum two operators into a `LinearOperatorFullMatrix`.

  One operand is converted with `to_dense()` and handed to the other
  operand's `add_to_tensor` method; the resulting matrix is wrapped in a
  full-matrix operator.

  Args:
    op1: First operator to add.
    op2: Second operator to add.
    operator_name: Value forwarded as `name` to the returned operator.
    hints: Object whose `is_non_singular`, `is_self_adjoint` and
      `is_positive_definite` attributes are forwarded to the returned
      operator.

  Returns:
    A `LinearOperatorFullMatrix` built from the summed matrix.
  """
  # `op1` is the `add_to_tensor` receiver only when `_type(op1)` is listed in
  # `_EFFICIENT_ADD_TO_TENSOR`; in every other case `op2` takes that role and
  # `op1` is the operand that gets densified.
  receiver, densified = (
      (op1, op2) if _type(op1) in _EFFICIENT_ADD_TO_TENSOR else (op2, op1))

  dense_sum = receiver.add_to_tensor(densified.to_dense())
  return linear_operator_full_matrix.LinearOperatorFullMatrix(
      matrix=dense_sum,
      is_non_singular=hints.is_non_singular,
      is_self_adjoint=hints.is_self_adjoint,
      is_positive_definite=hints.is_positive_definite,
      name=operator_name)
def _mean_of_covariance_given_quadrature_component(self, diag_only):
  """Computes `E[Cov(Z|V)]` from the interpolated affine scales.

  Every entry of `self.interpolated_affine` contributes its squared scale,
  weighted by the matching slice of `self.mixture_distribution.probs`, to one
  of three running sums selected by the type of the scale: scaled-identity,
  diagonal, or full. The sums are then multiplied by the variance of
  `self.distribution` and combined.

  The running sums are merged with the module's `add` helper, which is
  assumed to treat `None` as "nothing accumulated yet" -- confirm against its
  definition.

  Args:
    diag_only: When truthy, only the diagonal is produced; contributions from
      full scales are reduced with `matrix_diag_part` before being summed.

  Returns:
    If `diag_only` is truthy, the combination of the three running sums (the
    scaled-identity sum is multiplied by a ones tensor of the event shape
    when it is the only sum present). Otherwise the dense form of the summed
    operators, or `None` when no sum was accumulated.
  """
  probs = self.mixture_distribution.probs

  # One accumulator per structural category. `None` marks a category for
  # which no term has been encountered.
  identity_sum = None
  diag_sum = None
  full_sum = None

  for k, aff in enumerate(self.interpolated_affine):
    # Read the property exactly once, in case `aff.scale` has side-effects.
    scale = aff.scale

    # A missing scale is handled exactly like an identity scale.
    if scale is None or isinstance(
        scale, linop_identity_lib.LinearOperatorIdentity):
      identity_sum = add(identity_sum, probs[..., k, array_ops.newaxis])
      continue

    if isinstance(scale, linop_identity_lib.LinearOperatorScaledIdentity):
      weighted = (
          probs[..., k, array_ops.newaxis] * math_ops.square(scale.multiplier))
      identity_sum = add(identity_sum, weighted)
      continue

    if isinstance(scale, linop_diag_lib.LinearOperatorDiag):
      weighted = (
          probs[..., k, array_ops.newaxis] *
          math_ops.square(scale.diag_part()))
      diag_sum = add(diag_sum, weighted)
      continue

    # Any other operator type: form the weighted product of the scale with
    # its own adjoint explicitly.
    weighted = (
        probs[..., k, array_ops.newaxis, array_ops.newaxis] *
        scale.matmul(scale.to_dense(), adjoint_arg=True))
    if diag_only:
      weighted = array_ops.matrix_diag_part(weighted)
    full_sum = add(full_sum, weighted)

  # The base distribution need not have unit variance. Because
  # X ~ iid Law(X_0), `Cov(SX+m) = S Cov(X) S.T = S S.T Diag(Var(X_0))`, so
  # scaling by `Var(X)` rather than `Cov(X)` suffices: X is `d` iid samples
  # from a scalar-event distribution.
  base_variance = self.distribution.variance()
  if identity_sum is not None:
    identity_sum *= base_variance
  if diag_sum is not None:
    diag_sum *= base_variance[..., array_ops.newaxis]
  if full_sum is not None:
    full_sum *= base_variance[..., array_ops.newaxis]

  if diag_only:
    # Only the diagonal was requested, so no operators need to be built.
    diag_total = add(diag_sum, full_sum)
    if diag_total is None and identity_sum is not None:
      ones = array_ops.ones(self.event_shape_tensor(), dtype=self.dtype)
      return identity_sum[..., array_ops.newaxis] * ones
    return add(diag_total, identity_sum)

  # `None` (rather than `False`) signals that positive-definiteness of the
  # result is unknown.
  # NOTE(review): unlike the loop above, this does not guard against
  # `aff.scale` being `None` -- confirm endpoint scales are always set.
  if all(aff.scale.is_positive_definite for aff in self.endpoint_affine):
    is_positive_definite = True
  else:
    is_positive_definite = None

  operators = []
  if diag_sum is not None:
    diag_operator = linop_diag_lib.LinearOperatorDiag(
        diag=diag_sum, is_positive_definite=is_positive_definite)
    operators.append(diag_operator)
  if full_sum is not None:
    full_operator = linop_full_lib.LinearOperatorFullMatrix(
        matrix=full_sum, is_positive_definite=is_positive_definite)
    operators.append(full_operator)
  if identity_sum is not None:
    identity_operator = linop_identity_lib.LinearOperatorScaledIdentity(
        num_rows=self.event_shape_tensor()[0],
        multiplier=identity_sum,
        is_positive_definite=is_positive_definite)
    operators.append(identity_operator)

  if not operators:
    return None
  return linop_add_lib.add_operators(operators)[0].to_dense()
def _inverse_block_lower_triangular(block_lower_triangular_operator):
  """Builds the inverse of a `LinearOperatorBlockLowerTriangular`.

  The inverse is assembled one block row at a time from the 2-by-2 block
  identity

  ```none
  |A 0|'  =  |  A'    0 |
  |B C|      |-C'BA'  C'|
  ```

  in which `A` is n-by-n, `B` is m-by-n, `C` is m-by-m and `'` stands for
  inversion. Multiplying the two sides out confirms it:

  ```none
  |A 0| |  A'    0 |     |    AA'      0 |     |I 0|
  |B C| |-C'BA'  C'|  =  |BA'-CC'BA'  CC'|  =  |0 I|
  ```

  `A` is the operator made of every block row except the last; its inverse
  is obtained by calling `.inverse()` on it. `C` is the bottom-right block
  and `B` is the remainder of the last block row.

  Args:
    block_lower_triangular_operator: Instance of
      `LinearOperatorBlockLowerTriangular`.

  Returns:
    block_lower_triangular_operator_inverse: Instance of
      `LinearOperatorBlockLowerTriangular`, the inverse of
      `block_lower_triangular_operator`. It carries the input's
      `is_non_singular`, `is_self_adjoint` and `is_positive_definite` hints
      and is marked square.
  """
  operator = block_lower_triangular_operator
  block_lt_cls = (
      linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular)

  def _with_input_hints(rows):
    # Wrap `rows` in a new block operator that inherits the input's hints.
    return block_lt_cls(
        rows,
        is_non_singular=operator.is_non_singular,
        is_self_adjoint=operator.is_self_adjoint,
        is_positive_definite=operator.is_positive_definite,
        is_square=True)

  def _as_addable(op):
    # `add_operators` is only fed types listed in `SUPPORTED_OPERATORS`;
    # anything else is converted to a dense full-matrix operator first.
    if any(isinstance(op, op_type)
           for op_type in linear_operator_addition.SUPPORTED_OPERATORS):
      return op
    return linear_operator_full_matrix.LinearOperatorFullMatrix(op.to_dense())

  block_rows = operator.operators
  num_block_rows = len(block_rows)

  # Base case: a single block is inverted directly.
  if num_block_rows == 1:
    return _with_input_hints([[block_rows[0][0].inverse()]])

  # `A'`: inverse of the operator formed by all block rows but the last.
  upper_left_inverse = block_lt_cls(block_rows[:-1]).inverse()
  upper_left_inverse_rows = upper_left_inverse.operators

  # `C'`: inverse of the bottom-right block.
  bottom_row = block_rows[-1]
  bottom_right_inverse = bottom_row[-1].inverse()

  # Assemble `[-C'BA', C']`, the last block row of the inverse, walking over
  # the column partitions of `A'`.
  last_col = num_block_rows - 1
  inverse_bottom_row = []
  for col in range(last_col):
    # Block `col` of `BA'`. Only rows `col` and below of `A'` are visited,
    # since `A'` is itself block lower triangular.
    products = [
        _as_addable(bottom_row[k].matmul(upper_left_inverse_rows[k][col]))
        for k in range(col, last_col)]
    summed = linear_operator_addition.add_operators(products)
    assert len(summed) == 1
    ba_inverse_block = summed[0]

    # Block `col` of `-C'BA'`: left-multiply by `C'`, then negate by
    # multiplying with a scaled identity whose multiplier is -1.
    scaled_block = bottom_right_inverse.matmul(ba_inverse_block)
    negation = linear_operator_identity.LinearOperatorScaledIdentity(
        num_rows=bottom_right_inverse.domain_dimension_tensor(),
        multiplier=math_ops.cast(-1, dtype=scaled_block.dtype))
    inverse_bottom_row.append(negation.matmul(scaled_block))

  # The final block of the last row is `C'` itself.
  inverse_bottom_row.append(bottom_right_inverse)

  return _with_input_hints(upper_left_inverse_rows + [inverse_bottom_row])