def _inverse(self, x):
    """Apply the inverse affine map: solve `scale @ y = (x - loc)` for y."""
    # Shift by the location parameter first.
    shifted = x - self.loc
    # Flatten sample dims into the batch-of-matrices layout the solver expects.
    matrices, sample_shape = self.shaper.make_batch_of_event_sample_matrices(
        shifted)
    # Batched forward/back substitution against the triangular scale factor.
    solved = linalg_ops.batch_matrix_triangular_solve(self.scale, matrices)
    # Restore the caller's original sample/batch/event layout.
    return self.shaper.undo_make_batch_of_event_sample_matrices(
        solved, sample_shape)
def _x_whitened_if_should_flip(self, x):
    """Whiten `x` when its leading dims are sample dims (the "flip" case).

    Used when x.shape = [M1,...,Mm] + chol.shape[:-1], which is common
    if `x` was sampled.
    """
    # Rotate the leading sample dims to the back so the solver sees a
    # batch of matrices.
    flipped = self._flip_front_dims_to_back(x)
    # Batch version of: L^{-1} x.
    solved = linalg_ops.batch_matrix_triangular_solve(self._chol, flipped)
    # NOTE(review): "_unfip" looks like a typo of "unflip", but the helper is
    # defined elsewhere in this file, so the call must match its actual name.
    return self._unfip_back_dims_to_front(
        solved, array_ops.shape(x), x.get_shape())
def _x_whitened_if_should_flip(self, x):
    """Compute L^{-1} x for the case where `x` carries leading sample dims.

    Appropriate when x.shape = [M1,...,Mm] + chol.shape[:-1], as happens
    when `x` was drawn via sampling.
    """
    dynamic_shape = array_ops.shape(x)
    static_shape = x.get_shape()
    # Rotate sample dims to the back, run the batched triangular solve
    # (L^{-1} x), then rotate the dims back to the front.
    whitened_flipped = linalg_ops.batch_matrix_triangular_solve(
        self._chol, self._flip_front_dims_to_back(x))
    # NOTE(review): "_unfip" is presumably a typo of "unflip"; keep the call
    # matching the helper's actual (visible-elsewhere) name.
    return self._unfip_back_dims_to_front(
        whitened_flipped, dynamic_shape, static_shape)
def _x_whitened_if_no_flip(self, x):
    """x_whitened when `x` and `chol` already have broadcast-compatible shapes."""
    # Broadcast chol and x against each other. x comes back with trailing
    # dims (k, 1) — i.e. "1" system of k linear equations, the layout the
    # triangular solver expects.
    chol_compat, x_compat = self._get_chol_and_x_compatible_shape(x)
    # Batch version of: L^{-1} x.
    solved = linalg_ops.batch_matrix_triangular_solve(chol_compat, x_compat)
    # Drop the trailing singleton "number of systems" dimension.
    return array_ops.squeeze(solved, squeeze_dims=[-1])
def _BatchMatrixTriangularSolveGrad(op, grad):
    """Gradient for BatchMatrixTriangularSolve.

    For c = solve(a, b): grad_b is a solve against a with the opposite
    adjoint setting, and grad_a = -matmul(grad_b, c^H) (or its adjoint-side
    flip), masked to the triangle of `a` that the forward op actually reads.
    """
    matrix = op.inputs[0]
    adjoint = op.get_attr("adjoint")
    lower = op.get_attr("lower")
    solution = op.outputs[0]
    # Backprop through the solve itself: same triangle, opposite adjoint.
    grad_b = linalg_ops.batch_matrix_triangular_solve(
        matrix, grad, lower=lower, adjoint=not adjoint)
    if adjoint:
        grad_a = -math_ops.batch_matmul(solution, grad_b, adj_y=True)
    else:
        grad_a = -math_ops.batch_matmul(grad_b, solution, adj_y=True)
    # Only the triangle the op reads receives gradient; zero out the rest.
    band_limits = (-1, 0) if lower else (0, -1)
    grad_a = array_ops.batch_matrix_band_part(grad_a, *band_limits)
    return (grad_a, grad_b)
def log_pdf(self, x, name=None):
    """Log pdf of observations `x` given these Multivariate Normals.

    Args:
      x: tensor of dtype `dtype`, must be broadcastable with `mu`.
      name: The name to give this op.

    Returns:
      log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
    """
    with ops.op_scope([self._mu, self._sigma_chol, x], name,
                      "MultivariateNormalLogPdf"):
      x = ops.convert_to_tensor(x)
      contrib_tensor_util.assert_same_float_dtype((self._mu, x))

      x_centered = x - self.mu

      # Ranks/shapes are computed dynamically: x's rank relative to
      # sigma_chol's is not known until graph execution.
      x_rank = array_ops.rank(x_centered)
      sigma_rank = array_ops.rank(self._sigma_chol)
      # Rank-1 wrappers, needed as size/begin args to array_ops.slice below.
      x_rank_vec = array_ops.pack([x_rank])
      sigma_rank_vec = array_ops.pack([sigma_rank])
      x_shape = array_ops.shape(x_centered)

      # sigma_chol is shaped [D, E, F, ..., k, k]
      # x_centered shape is one of:
      #   [D, E, F, ..., k], or [F, ..., k], or
      #   [A, B, C, D, E, F, ..., k]
      # and we need to convert x_centered to shape:
      #   [D, E, F, ..., k, A*B*C] (or 1 if A, B, C don't exist)
      # then transpose and reshape x_whitened back to one of the shapes:
      #   [D, E, F, ..., k], or [1, 1, F, ..., k], or
      #   [A, B, C, D, E, F, ..., k]

      # This helper handles the case where rank(x_centered) < rank(sigma)
      def _broadcast_x_not_higher_rank_than_sigma():
        return array_ops.reshape(
            x_centered,
            array_ops.concat(
                # Reshape to ones(deficient x rank) + x_shape + [1]
                0, (array_ops.ones(array_ops.pack([sigma_rank - x_rank - 1]),
                                   dtype=x_rank.dtype),
                    x_shape,
                    [1])))

      # These helpers handle the case where rank(x_centered) >= rank(sigma)
      def _broadcast_x_higher_rank_than_sigma():
        # Split x's shape into the "extra" leading dims and the part that
        # matches sigma's batch+event dims.
        x_shape_left = array_ops.slice(x_shape, [0], sigma_rank_vec - 1)
        x_shape_right = array_ops.slice(x_shape, sigma_rank_vec - 1,
                                        x_rank_vec - 1)
        x_shape_perm = array_ops.concat(
            0, (math_ops.range(sigma_rank - 1, x_rank),
                math_ops.range(0, sigma_rank - 1)))
        return array_ops.reshape(
            # Convert to [D, E, F, ..., k, B, C]
            array_ops.transpose(x_centered, perm=x_shape_perm),
            # Reshape to [D, E, F, ..., k, B*C]
            array_ops.concat(
                0, (x_shape_right,
                    array_ops.pack([math_ops.reduce_prod(x_shape_left, 0)]))))

      # Inverse of _broadcast_x_higher_rank_than_sigma; reads
      # x_whitened_broadcast from the enclosing scope (defined below).
      def _unbroadcast_x_higher_rank_than_sigma():
        x_shape_left = array_ops.slice(x_shape, [0], sigma_rank_vec - 1)
        x_shape_right = array_ops.slice(x_shape, sigma_rank_vec - 1,
                                        x_rank_vec - 1)
        x_shape_perm = array_ops.concat(
            0, (math_ops.range(sigma_rank - 1, x_rank),
                math_ops.range(0, sigma_rank - 1)))
        return array_ops.transpose(
            # [D, E, F, ..., k, B, C] => [B, C, D, E, F, ..., k]
            array_ops.reshape(
                # convert to [D, E, F, ..., k, B, C]
                x_whitened_broadcast,
                array_ops.concat(0, (x_shape_right, x_shape_left))),
            perm=x_shape_perm)

      # Step 1: reshape x_centered
      x_centered_broadcast = control_flow_ops.cond(
          # x_centered == [D, E, F, ..., k] => [D, E, F, ..., k, 1]
          # or == [F, ..., k] => [1, 1, F, ..., k, 1]
          x_rank <= sigma_rank - 1,
          _broadcast_x_not_higher_rank_than_sigma,
          # x_centered == [B, C, D, E, F, ..., k] => [D, E, F, ..., k, B*C]
          _broadcast_x_higher_rank_than_sigma)

      # Whiten: batch version of L^{-1} (x - mu), with L the Cholesky factor.
      x_whitened_broadcast = linalg_ops.batch_matrix_triangular_solve(
          self._sigma_chol, x_centered_broadcast)

      # Reshape x_whitened_broadcast back to x_whitened
      x_whitened = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: array_ops.reshape(x_whitened_broadcast, x_shape),
          _unbroadcast_x_higher_rank_than_sigma)

      # Trailing dim makes each event a column vector, so the adj_x matmul
      # below yields the batchwise squared norm |L^{-1}(x - mu)|^2.
      x_whitened = array_ops.expand_dims(x_whitened, -1)
      # Reshape x_whitened to contain row vectors
      # Returns a batchwise scalar
      x_whitened_norm = math_ops.batch_matmul(x_whitened, x_whitened,
                                              adj_x=True)
      x_whitened_norm = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: array_ops.squeeze(x_whitened_norm, [-2, -1]),
          lambda: array_ops.squeeze(x_whitened_norm, [-1]))

      log_two_pi = constant_op.constant(math.log(2 * math.pi),
                                        dtype=self.dtype)
      k = math_ops.cast(self._k, self.dtype)
      # Standard MVN log-density:
      #   -0.5 * (log|sigma| + k*log(2*pi) + |L^{-1}(x - mu)|^2)
      log_pdf_value = (-math_ops.log(self._sigma_det) - k * log_two_pi -
                       x_whitened_norm) / 2
      final_shaped_value = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: log_pdf_value,
          lambda: array_ops.squeeze(log_pdf_value, [-1]))

      # Static shape is x's shape minus the event dim.
      output_static_shape = x_centered.get_shape()[:-1]
      final_shaped_value.set_shape(output_static_shape)
      return final_shaped_value
def log_pdf(self, x, name=None):
    """Log pdf of observations `x` given these Multivariate Normals.

    Args:
      x: tensor of dtype `dtype`, must be broadcastable with `mu`.
      name: The name to give this op.

    Returns:
      log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
    """
    with ops.op_scope(
        [self._mu, self._sigma_chol, x], name, "MultivariateNormalLogPdf"):
      x = ops.convert_to_tensor(x)
      contrib_tensor_util.assert_same_float_dtype((self._mu, x))

      x_centered = x - self.mu

      # Ranks are dynamic: how x's rank compares to sigma_chol's is only
      # known at graph execution time, hence the cond() branches below.
      x_rank = array_ops.rank(x_centered)
      sigma_rank = array_ops.rank(self._sigma_chol)
      # Rank-1 wrappers for use as slice begin/size arguments.
      x_rank_vec = array_ops.pack([x_rank])
      sigma_rank_vec = array_ops.pack([sigma_rank])
      x_shape = array_ops.shape(x_centered)

      # sigma_chol is shaped [D, E, F, ..., k, k]
      # x_centered shape is one of:
      #   [D, E, F, ..., k], or [F, ..., k], or
      #   [A, B, C, D, E, F, ..., k]
      # and we need to convert x_centered to shape:
      #   [D, E, F, ..., k, A*B*C] (or 1 if A, B, C don't exist)
      # then transpose and reshape x_whitened back to one of the shapes:
      #   [D, E, F, ..., k], or [1, 1, F, ..., k], or
      #   [A, B, C, D, E, F, ..., k]

      # This helper handles the case where rank(x_centered) < rank(sigma)
      def _broadcast_x_not_higher_rank_than_sigma():
        return array_ops.reshape(
            x_centered,
            array_ops.concat(
                # Reshape to ones(deficient x rank) + x_shape + [1]
                0, (array_ops.ones(array_ops.pack([sigma_rank - x_rank - 1]),
                                   dtype=x_rank.dtype),
                    x_shape,
                    [1])))

      # These helpers handle the case where rank(x_centered) >= rank(sigma)
      def _broadcast_x_higher_rank_than_sigma():
        # Split x's shape into the "extra" leading dims and the dims that
        # line up with sigma's batch+event dims.
        x_shape_left = array_ops.slice(
            x_shape, [0], sigma_rank_vec - 1)
        x_shape_right = array_ops.slice(
            x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
        x_shape_perm = array_ops.concat(
            0, (math_ops.range(sigma_rank - 1, x_rank),
                math_ops.range(0, sigma_rank - 1)))
        return array_ops.reshape(
            # Convert to [D, E, F, ..., k, B, C]
            array_ops.transpose(
                x_centered, perm=x_shape_perm),
            # Reshape to [D, E, F, ..., k, B*C]
            array_ops.concat(
                0, (x_shape_right, array_ops.pack([
                    math_ops.reduce_prod(x_shape_left, 0)]))))

      # Inverse of _broadcast_x_higher_rank_than_sigma; closes over
      # x_whitened_broadcast, which is defined further below.
      def _unbroadcast_x_higher_rank_than_sigma():
        x_shape_left = array_ops.slice(
            x_shape, [0], sigma_rank_vec - 1)
        x_shape_right = array_ops.slice(
            x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
        x_shape_perm = array_ops.concat(
            0, (math_ops.range(sigma_rank - 1, x_rank),
                math_ops.range(0, sigma_rank - 1)))
        return array_ops.transpose(
            # [D, E, F, ..., k, B, C] => [B, C, D, E, F, ..., k]
            array_ops.reshape(
                # convert to [D, E, F, ..., k, B, C]
                x_whitened_broadcast,
                array_ops.concat(0, (x_shape_right, x_shape_left))),
            perm=x_shape_perm)

      # Step 1: reshape x_centered
      x_centered_broadcast = control_flow_ops.cond(
          # x_centered == [D, E, F, ..., k] => [D, E, F, ..., k, 1]
          # or == [F, ..., k] => [1, 1, F, ..., k, 1]
          x_rank <= sigma_rank - 1,
          _broadcast_x_not_higher_rank_than_sigma,
          # x_centered == [B, C, D, E, F, ..., k] => [D, E, F, ..., k, B*C]
          _broadcast_x_higher_rank_than_sigma)

      # Whiten: batch version of L^{-1} (x - mu), L the Cholesky factor.
      x_whitened_broadcast = linalg_ops.batch_matrix_triangular_solve(
          self._sigma_chol, x_centered_broadcast)

      # Reshape x_whitened_broadcast back to x_whitened
      x_whitened = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: array_ops.reshape(x_whitened_broadcast, x_shape),
          _unbroadcast_x_higher_rank_than_sigma)

      # Trailing singleton dim turns each event into a column vector, so
      # batch_matmul(..., adj_x=True) gives the batchwise squared norm.
      x_whitened = array_ops.expand_dims(x_whitened, -1)
      # Reshape x_whitened to contain row vectors
      # Returns a batchwise scalar
      x_whitened_norm = math_ops.batch_matmul(
          x_whitened, x_whitened, adj_x=True)
      x_whitened_norm = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: array_ops.squeeze(x_whitened_norm, [-2, -1]),
          lambda: array_ops.squeeze(x_whitened_norm, [-1]))

      log_two_pi = constant_op.constant(math.log(2 * math.pi),
                                        dtype=self.dtype)
      k = math_ops.cast(self._k, self.dtype)
      # Standard MVN log-density:
      #   -0.5 * (log|sigma| + k*log(2*pi) + |L^{-1}(x - mu)|^2)
      log_pdf_value = (
          -math_ops.log(self._sigma_det) -k * log_two_pi -
          x_whitened_norm) / 2
      final_shaped_value = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: log_pdf_value,
          lambda: array_ops.squeeze(log_pdf_value, [-1]))

      # Static output shape is x's shape minus the trailing event dim.
      output_static_shape = x_centered.get_shape()[:-1]
      final_shaped_value.set_shape(output_static_shape)
      return final_shaped_value
def _batch_sqrt_solve(self, rhs):
    """Solve `chol @ x = rhs` batchwise, where `chol` is lower triangular."""
    lower_factor = self._chol
    return linalg_ops.batch_matrix_triangular_solve(
        lower_factor, rhs, lower=True)
def _inverse(self, x):
    """Undo the affine forward map: y = scale^{-1} (x - loc)."""
    centered = x - self.loc
    # Pack sample dims into a batch-of-matrices layout for the solver.
    batched, sample_shape = self.shaper.make_batch_of_event_sample_matrices(
        centered)
    whitened = linalg_ops.batch_matrix_triangular_solve(self.scale, batched)
    # Unpack back to the caller's original layout.
    return self.shaper.undo_make_batch_of_event_sample_matrices(
        whitened, sample_shape)