def _underdetermined(op, grad):
  """Gradients for the underdetermined case of MatrixSolveLs.

  This is the backprop for the solution to the normal equations of the second
  kind:
    X = F(A, B) = A * (A*A^T + lambda*I)^{-1} * B
  that (for lambda=0) solve the least squares problem
    min ||X||_F subject to A*X = B.
  """
  a = op.inputs[0]
  b = op.inputs[1]
  l2_regularizer = op.inputs[2]
  a_shape = array_ops.shape(a)
  batch_shape = a_shape[:-2]
  m = a_shape[-2]

  identity = linalg_ops.eye(m, batch_shape=batch_shape, dtype=a.dtype)
  gramian = math_ops.batch_matmul(
      a, a, adj_y=True) + l2_regularizer * identity
  chol = linalg_ops.cholesky(gramian)
  grad_b = linalg_ops.cholesky_solve(chol, math_ops.batch_matmul(a, grad))
  # Temporary z = (A * A^T + lambda * I)^{-1} * B.
  z = linalg_ops.cholesky_solve(chol, b)
  bz = -math_ops.batch_matmul(grad_b, z, adj_y=True)
  bz_sym = bz + array_ops.matrix_transpose(bz)
  grad_a = math_ops.batch_matmul(bz_sym, a) + math_ops.batch_matmul(z, grad)
  return (grad_a, grad_b, None)
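# Illustrative check (not part of the library): for lambda = 0, the forward
# solution X = A^T (A A^T)^{-1} B that this gradient assumes is the
# minimum-norm solution of the underdetermined system, i.e. it matches
# pinv(A) @ B and satisfies A @ X = B exactly. A minimal NumPy sketch:
import numpy as np

rng = np.random.RandomState(0)
A = rng.randn(3, 5)   # wide matrix: fewer equations than unknowns
B = rng.randn(3, 2)

X = A.T @ np.linalg.solve(A @ A.T, B)         # normal equations, second kind
assert np.allclose(A @ X, B)                  # interpolates the constraints
assert np.allclose(X, np.linalg.pinv(A) @ B)  # minimum Frobenius-norm solution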
def _BatchMatrixInverseGrad(op, grad):
  """Gradient for BatchMatrixInverse."""
  ainv = op.outputs[0]
  return -math_ops.batch_matmul(
      ainv, math_ops.batch_matmul(grad, ainv, adj_y=True), adj_x=True)
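# Illustrative check (not part of the library): the formula above is the
# standard identity d(A^{-1}) = -A^{-1} dA A^{-1}, which for a scalar loss L
# gives dL/dA = -A^{-T} (dL/dA^{-1}) A^{-T}. A small NumPy finite-difference
# sketch:
import numpy as np

rng = np.random.RandomState(0)
A = rng.randn(4, 4) + 4.0 * np.eye(4)   # well-conditioned test matrix
W = rng.randn(4, 4)                     # dL/dC for L = sum(W * inv(A))

ainv = np.linalg.inv(A)
grad_a = -ainv.T @ W @ ainv.T           # the formula implemented above

eps, num = 1e-6, np.zeros_like(A)
for i in range(4):
  for j in range(4):
    dA = np.zeros_like(A)
    dA[i, j] = eps
    num[i, j] = (np.sum(W * np.linalg.inv(A + dA)) -
                 np.sum(W * np.linalg.inv(A - dA))) / (2 * eps)
assert np.allclose(grad_a, num, atol=1e-5)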
def matmul(self, x, name='matmul'):
  """Left (batch) matrix multiplication of `x` by this operator."""
  chol = self._chol
  with ops.name_scope(self.name):
    with ops.op_scope(self.inputs, name):
      a_times_x = math_ops.batch_matmul(chol, x, adj_x=True)
      return math_ops.batch_matmul(chol, a_times_x)
def _batch_sqrt_solve(self, rhs):
  # Recall the square root of this operator is M + VDV^T.
  # The Woodbury formula gives:
  # (M + VDV^T)^{-1}
  # = M^{-1} - M^{-1} V (D^{-1} + V^T M^{-1} V)^{-1} V^T M^{-1}
  # = M^{-1} - M^{-1} V C^{-1} V^T M^{-1}
  # where C is the capacitance matrix.
  m = self._operator
  v = self._v
  cchol = self._chol_capacitance(batch_mode=True)

  # The operators will use batch/singleton mode automatically.  We don't
  # override.
  # M^{-1} rhs
  minv_rhs = m.solve(rhs)
  # V^T M^{-1} rhs
  vt_minv_rhs = math_ops.batch_matmul(v, minv_rhs, adj_x=True)
  # C^{-1} V^T M^{-1} rhs
  cinv_vt_minv_rhs = linalg_ops.batch_cholesky_solve(cchol, vt_minv_rhs)
  # V C^{-1} V^T M^{-1} rhs
  v_cinv_vt_minv_rhs = math_ops.batch_matmul(v, cinv_vt_minv_rhs)
  # M^{-1} V C^{-1} V^T M^{-1} rhs
  minv_v_cinv_vt_minv_rhs = m.solve(v_cinv_vt_minv_rhs)

  # M^{-1} - M^{-1} V C^{-1} V^T M^{-1}
  return minv_rhs - minv_v_cinv_vt_minv_rhs
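# Illustrative check (not part of the library): the Woodbury identity used in
# the comments above, written out with dense NumPy matrices.
import numpy as np

rng = np.random.RandomState(0)
k, r = 5, 2
M = np.diag(rng.rand(k) + 1.0)            # positive definite "base" operator
V = rng.randn(k, r)
D = np.diag(rng.rand(r) + 1.0)            # positive definite diagonal update

direct = np.linalg.inv(M + V @ D @ V.T)

Minv = np.linalg.inv(M)
C = np.linalg.inv(D) + V.T @ Minv @ V     # capacitance matrix
woodbury = Minv - Minv @ V @ np.linalg.solve(C, V.T @ Minv)

assert np.allclose(direct, woodbury)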
def _overdetermined(op, grad):
  """Gradients for the overdetermined case of MatrixSolveLs.

  This is the backprop for the solution to the normal equations of the first
  kind:
    X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
  which solve the least squares problem
    min ||A * X - B||_F^2 + lambda ||X||_F^2.
  """
  a = op.inputs[0]
  b = op.inputs[1]
  l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
  x = op.outputs[0]
  a_shape = array_ops.shape(a)
  batch_shape = a_shape[:-2]
  n = a_shape[-1]

  identity = linalg_ops.eye(n, batch_shape=batch_shape, dtype=a.dtype)
  gramian = math_ops.batch_matmul(
      a, a, adj_x=True) + l2_regularizer * identity
  chol = linalg_ops.cholesky(gramian)
  # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
  z = linalg_ops.cholesky_solve(chol, grad)
  xzt = math_ops.batch_matmul(x, z, adj_y=True)
  zx_sym = xzt + array_ops.matrix_transpose(xzt)
  grad_a = -math_ops.batch_matmul(a, zx_sym) + math_ops.batch_matmul(
      b, z, adj_y=True)
  grad_b = math_ops.batch_matmul(a, z)
  return (grad_a, grad_b, None)
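# Illustrative check (not part of the library): the forward solution assumed
# above, X = (A^T A + lambda I)^{-1} A^T B, is the stationary point of
# ||A X - B||_F^2 + lambda ||X||_F^2, i.e. A^T (A X - B) + lambda X = 0.
import numpy as np

rng = np.random.RandomState(0)
A = rng.randn(6, 3)        # tall matrix: overdetermined system
B = rng.randn(6, 2)
lam = 0.1

X = np.linalg.solve(A.T @ A + lam * np.eye(3), A.T @ B)
assert np.allclose(A.T @ (A @ X - B) + lam * X, 0.0)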
def _BatchMatrixInverseGrad(op, grad):
  """Gradient for BatchMatrixInverse."""
  ainv = op.outputs[0]
  return -math_ops.batch_matmul(
      ainv, math_ops.batch_matmul(grad, ainv, adj_y=True), adj_x=True)
def _batch_matmul(self, x, transpose_x=False):
  # tf.batch_matmul is defined x * y, so "y" is on the right, not "x".
  chol = array_ops.batch_matrix_band_part(self._chol, -1, 0)
  chol_times_x = math_ops.batch_matmul(
      chol, x, adj_x=True, adj_y=transpose_x)
  return math_ops.batch_matmul(chol, chol_times_x)
def _BatchMatrixSolveGrad(op, grad):
  """Gradient for BatchMatrixSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  c = op.outputs[0]
  grad_b = linalg_ops.batch_matrix_solve(a, grad, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.batch_matmul(c, grad_b, adj_y=True)
  else:
    grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
  return (grad_a, grad_b)
def _BatchMatrixSolveGrad(op, grad):
  """Gradient for BatchMatrixSolve."""
  a = op.inputs[0]
  c = op.outputs[0]
  # TODO(rmlarsen): Replace the following two lines with a single call to
  # batch_matrix_solve after adding in an option to solve for A^T X = Y.
  ainv = linalg_ops.batch_matrix_inverse(a)
  grad_b = math_ops.batch_matmul(ainv, grad, adj_x=True)
  grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
  return (grad_a, grad_b)
def _MatrixSolveGrad(op, grad):
  """Gradient for MatrixSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.batch_matmul(c, grad_b, adj_y=True)
  else:
    grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
  return (grad_a, grad_b)
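# Illustrative check (not part of the library): for C = solve(A, B) and a
# scalar loss L with upstream gradient W = dL/dC, the rule implemented above
# (non-adjoint case) is grad_B = A^{-T} W and grad_A = -grad_B C^T. A NumPy
# finite-difference sketch for grad_A:
import numpy as np

rng = np.random.RandomState(0)
A = rng.randn(3, 3) + 3.0 * np.eye(3)
B = rng.randn(3, 2)
W = rng.randn(3, 2)                     # dL/dC for L = sum(W * solve(A, B))

C = np.linalg.solve(A, B)
grad_b = np.linalg.solve(A.T, W)
grad_a = -grad_b @ C.T

eps, num = 1e-6, np.zeros_like(A)
for i in range(3):
  for j in range(3):
    dA = np.zeros_like(A)
    dA[i, j] = eps
    num[i, j] = (np.sum(W * np.linalg.solve(A + dA, B)) -
                 np.sum(W * np.linalg.solve(A - dA, B))) / (2 * eps)
assert np.allclose(grad_a, num, atol=1e-5)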
def _test1(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2 derived with Joan with no adjustment for
  subspace."""
  e = op.outputs[0]
  v = op.outputs[1]
  # dim = v.get_shape()
  with ops.control_dependencies([grad_e.op, grad_v.op]):
    if grad_v is not None:
      E = array_ops.diag(e)
      v_proj = array_ops.slice(v, [0, 0], [20, 2])
      grad_grassman = grad_v - math_ops.batch_matmul(
          math_ops.batch_matmul(v_proj, array_ops.transpose(v_proj)), grad_v)
      grad_a = (
          math_ops.batch_matmul(
              grad_grassman,
              math_ops.batch_matmul(E, array_ops.transpose(grad_v))) +
          math_ops.batch_matmul(
              grad_v,
              math_ops.batch_matmul(E, array_ops.transpose(grad_grassman))))
      return grad_a
def _batch_sqrt_matmul(self, x, transpose_x=False):
  v = self._v
  m = self._operator
  d = self._diag_operator
  # The operators call the appropriate matmul/batch_matmul automatically.
  # We cannot override.
  # batch_matmul is defined as: x * y, so adj_x and adj_y are the ways to
  # transpose the left and right.
  mx = m.matmul(x, transpose_x=transpose_x)
  vt_x = math_ops.batch_matmul(v, x, adj_x=True, adj_y=transpose_x)
  d_vt_x = d.matmul(vt_x)
  v_d_vt_x = math_ops.batch_matmul(v, d_vt_x)
  return mx + v_d_vt_x
def _variance(self):
  p = self.p * array_ops.expand_dims(array_ops.ones_like(self.n), -1)
  outer_prod = math_ops.batch_matmul(
      array_ops.expand_dims(self._mean_val, -1),
      array_ops.expand_dims(p, -2))
  return array_ops.batch_matrix_set_diag(
      -outer_prod, self._mean_val - self._mean_val * p)
def __call__(self, inputs, state, scope=None):
  state, fw = state
  with vs.variable_scope(scope or type(self).__name__) as scope:
    """Wh(t) + Cx(t)"""
    linear = self.fw_calc([state, inputs], self._hidden_units, False)
    """h_0(t+1) = f(Wh(t) + Cx(t))"""
    if not self._norm_re:
      h = self._activation(self._norm(linear, scope="Norm0"))
    else:
      h = self._activation(self._norm(linear))

    h = self._vec2mat(h)
    linear = self._vec2mat(linear)
    for i in range(self._S):
      """
      h_{s+1}(t+1) = f([Wh(t) + Cx(t)] + A(t) h_s(t+1)), S times.
      From Eqn (2).
      """
      if not self._norm_re:
        h = self._activation(
            self._norm(linear + tf.matmul(fw, h), scope="Norm%d" % (i + 1)))
      else:
        h = self._activation(
            self._norm(linear + math_ops.batch_matmul(fw, h)))

    """
    Compute A(t+1) according to Eqn (4)
    """
    state = self._vec2mat(state)
    new_fw = self._lambda * fw + self._eta * tf.matmul(
        state, state, adjoint_b=True)

    h = self._mat2vec(h)
    return h, (h, new_fw)
def variance(self, name="variance"): """Variance of the Wishart distribution. This function should not be confused with the covariance of the Wishart. The covariance matrix would have shape `q x q` where, `q = dimension * (dimension+1) / 2` and having elements corresponding to some mapping from a lower-triangular matrix to a vector-space. This function returns the diagonal of the Covariance matrix but shaped as a `dimension x dimension` matrix. Args: name: The name of this op. Returns: variance: `Tensor` of dtype `self.dtype`. """ with ops.name_scope(self.name): with ops.name_scope(name, values=list(self.inputs.values())): x = math_ops.sqrt(self.df) * self.scale_operator_pd.to_dense() d = array_ops.expand_dims(array_ops.batch_matrix_diag_part(x), -1) v = math_ops.square(x) + math_ops.batch_matmul(d, d, adj_y=True) if self.cholesky_input_output_matrices: return linalg_ops.batch_cholesky(v) else: return v
def variance(self, name='variance'):
  """Variance of the Wishart distribution.

  This function should not be confused with the covariance of the Wishart.
  The covariance matrix would have shape `q x q` where
  `q = dimension * (dimension+1) / 2`, with elements corresponding to some
  mapping from a lower-triangular matrix to a vector space.

  This function returns the diagonal of the Covariance matrix but shaped
  as a `dimension x dimension` matrix.

  Args:
    name: The name of this op.

  Returns:
    variance: `Tensor` of dtype `self.dtype`.
  """
  with ops.name_scope(self.name):
    with ops.name_scope(name, values=list(self.inputs.values())):
      x = math_ops.sqrt(self.df) * self.scale_operator_pd.to_dense()
      d = array_ops.expand_dims(array_ops.batch_matrix_diag_part(x), -1)
      v = math_ops.square(x) + math_ops.batch_matmul(d, d, adj_y=True)
      if self.cholesky_input_output_matrices:
        return linalg_ops.batch_cholesky(v)
      else:
        return v
def _BatchMatrixSolveGrad(op, grad):
  """Gradient for BatchMatrixSolve."""
  a = op.inputs[0]
  c = op.outputs[0]
  grad_b = linalg_ops.batch_matrix_solve(a, grad, adjoint=True)
  grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
  return (grad_a, grad_b)
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.batch_matmul(c, grad_b, adj_y=True)
  else:
    grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
  if lower_a:
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  return (grad_a, grad_b)
def _forward(self, x):
  x, sample_shape = self.shaper.make_batch_of_event_sample_matrices(x)
  x = math_ops.batch_matmul(self.scale, x)
  x = self.shaper.undo_make_batch_of_event_sample_matrices(x, sample_shape)
  x += self.loc
  return x
def _variance(self):
  x = math_ops.sqrt(self.df) * self.scale_operator_pd.to_dense()
  d = array_ops.expand_dims(array_ops.matrix_diag_part(x), -1)
  v = math_ops.square(x) + math_ops.batch_matmul(d, d, adj_y=True)
  if self.cholesky_input_output_matrices:
    return linalg_ops.cholesky(v)
  return v
def _variance(self):
  scale = self.alpha_sum * math_ops.sqrt(1.0 + self.alpha_sum)
  alpha = self.alpha / scale

  outer_prod = -math_ops.batch_matmul(
      array_ops.expand_dims(alpha, dim=-1),  # column
      array_ops.expand_dims(alpha, dim=-2))  # row

  return array_ops.batch_matrix_set_diag(
      outer_prod, alpha * (self.alpha_sum / scale - alpha))
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.batch_matmul(c, grad_b, adj_y=True)
  else:
    grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
  if lower_a:
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  return (grad_a, grad_b)
def _variance(self):
  scale = self.alpha_sum * math_ops.sqrt(1. + self.alpha_sum)
  alpha = self.alpha / scale

  outer_prod = -math_ops.batch_matmul(
      array_ops.expand_dims(alpha, dim=-1),  # column
      array_ops.expand_dims(alpha, dim=-2))  # row

  return array_ops.batch_matrix_set_diag(
      outer_prod, alpha * (self.alpha_sum / scale - alpha))
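# Illustrative check (not part of the library): the scaled construction above
# reproduces the textbook Dirichlet covariance
#   Cov(X_i, X_j) = (delta_ij * alpha_i * alpha_0 - alpha_i * alpha_j)
#                   / (alpha_0^2 * (alpha_0 + 1)).
import numpy as np

alpha = np.array([1.0, 2.0, 3.5])
alpha_sum = alpha.sum()

# Construction mirroring the snippet above.
scale = alpha_sum * np.sqrt(1.0 + alpha_sum)
a = alpha / scale
cov = -np.outer(a, a)
np.fill_diagonal(cov, a * (alpha_sum / scale - a))

# Textbook formula.
expected = (np.diag(alpha) * alpha_sum - np.outer(alpha, alpha)) / (
    alpha_sum ** 2 * (alpha_sum + 1.0))
assert np.allclose(cov, expected)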
def _sample_n(self, n, seed):
  batch_shape = self.batch_shape()
  event_shape = self.event_shape()
  batch_ndims = array_ops.shape(batch_shape)[0]

  ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
  shape = array_ops.concat(0, ((n,), batch_shape, event_shape))

  # Complexity: O(nbk^2)
  x = random_ops.random_normal(shape=shape,
                               mean=0.,
                               stddev=1.,
                               dtype=self.dtype,
                               seed=seed)

  # Complexity: O(nbk)
  # This parametrization is equivalent to Chi2, i.e.,
  # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
  g = random_ops.random_gamma(shape=(n,),
                              alpha=self._multi_gamma_sequence(
                                  0.5 * self.df, self.dimension),
                              beta=0.5,
                              dtype=self.dtype,
                              seed=distribution_util.gen_new_seed(
                                  seed, "wishart"))

  # Complexity: O(nbk^2)
  x = array_ops.matrix_band_part(x, -1, 0)  # Tri-lower.

  # Complexity: O(nbk)
  x = array_ops.matrix_set_diag(x, math_ops.sqrt(g))

  # Make batch-op ready.
  # Complexity: O(nbk^2)
  perm = array_ops.concat(0, (math_ops.range(1, ndims), (0,)))
  x = array_ops.transpose(x, perm)
  shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1)))
  x = array_ops.reshape(x, shape)

  # Complexity: O(nbM) where M is the complexity of the operator solving a
  # vector system.  E.g., for OperatorPDDiag, each matmul is O(k^2), so
  # this complexity is O(nbk^2).  For OperatorPDCholesky, each matmul is
  # O(k^3) so this step has complexity O(nbk^3).
  x = self.scale_operator_pd.sqrt_matmul(x)

  # Undo make batch-op ready.
  # Complexity: O(nbk^2)
  shape = array_ops.concat(0, (batch_shape, event_shape, (n,)))
  x = array_ops.reshape(x, shape)
  perm = array_ops.concat(0, ((ndims - 1,), math_ops.range(0, ndims - 1)))
  x = array_ops.transpose(x, perm)

  if not self.cholesky_input_output_matrices:
    # Complexity: O(nbk^3)
    x = math_ops.batch_matmul(x, x, adj_y=True)
  return x
def _sample_n(self, n, seed):
  batch_shape = self.batch_shape()
  event_shape = self.event_shape()
  batch_ndims = array_ops.shape(batch_shape)[0]

  ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
  shape = array_ops.concat(0, ((n,), batch_shape, event_shape))

  # Complexity: O(nbk^2)
  x = random_ops.random_normal(shape=shape,
                               mean=0.,
                               stddev=1.,
                               dtype=self.dtype,
                               seed=seed)

  # Complexity: O(nbk)
  # This parametrization is equivalent to Chi2, i.e.,
  # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
  g = random_ops.random_gamma(
      shape=(n,),
      alpha=self._multi_gamma_sequence(0.5 * self.df, self.dimension),
      beta=0.5,
      dtype=self.dtype,
      seed=distribution_util.gen_new_seed(seed, "wishart"))

  # Complexity: O(nbk^2)
  x = array_ops.matrix_band_part(x, -1, 0)  # Tri-lower.

  # Complexity: O(nbk)
  x = array_ops.matrix_set_diag(x, math_ops.sqrt(g))

  # Make batch-op ready.
  # Complexity: O(nbk^2)
  perm = array_ops.concat(0, (math_ops.range(1, ndims), (0,)))
  x = array_ops.transpose(x, perm)
  shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1)))
  x = array_ops.reshape(x, shape)

  # Complexity: O(nbM) where M is the complexity of the operator solving a
  # vector system.  E.g., for OperatorPDDiag, each matmul is O(k^2), so
  # this complexity is O(nbk^2).  For OperatorPDCholesky, each matmul is
  # O(k^3) so this step has complexity O(nbk^3).
  x = self.scale_operator_pd.sqrt_matmul(x)

  # Undo make batch-op ready.
  # Complexity: O(nbk^2)
  shape = array_ops.concat(0, (batch_shape, event_shape, (n,)))
  x = array_ops.reshape(x, shape)
  perm = array_ops.concat(0, ((ndims - 1,), math_ops.range(0, ndims - 1)))
  x = array_ops.transpose(x, perm)

  if not self.cholesky_input_output_matrices:
    # Complexity: O(nbk^3)
    x = math_ops.batch_matmul(x, x, adj_y=True)
  return x
def _sqrt_to_dense(self):
  v = self._v
  d = self._diag_operator
  m = self._operator

  d_vt = d.matmul(v, transpose_x=True)
  # Batch op won't be efficient for singletons.  Currently we don't break
  # to_dense into batch/singleton methods.
  v_d_vt = math_ops.batch_matmul(v, d_vt)
  m_plus_v_d_vt = m.to_dense() + v_d_vt
  return m_plus_v_d_vt
def variance(self, name="variance"): """Variance of the distribution.""" with ops.name_scope(self.name): with ops.op_scope([self._n, self._p, self._mean], name): p = array_ops.expand_dims( self._p * array_ops.expand_dims( array_ops.ones_like(self._n), -1), -1) variance = -math_ops.batch_matmul( array_ops.expand_dims(self._mean, -1), p, adj_y=True) variance += array_ops.batch_matrix_diag(self._mean) return variance
def variance(self, name="variance"): """Variance of the distribution.""" with ops.name_scope(self.name): with ops.name_scope(name, values=[self._n, self._p, self._mean]): p = array_ops.expand_dims( self._p * array_ops.expand_dims( array_ops.ones_like(self._n), -1), -1) variance = -math_ops.batch_matmul( array_ops.expand_dims(self._mean, -1), p, adj_y=True) variance += array_ops.batch_matrix_diag(self._mean) return variance
def _variance(self):
  alpha_sum = array_ops.expand_dims(self.alpha_sum, -1)
  normalized_alpha = self.alpha / alpha_sum
  variance = -math_ops.batch_matmul(
      array_ops.expand_dims(normalized_alpha, -1),
      array_ops.expand_dims(normalized_alpha, -2))
  variance = array_ops.matrix_set_diag(
      variance, normalized_alpha * (1. - normalized_alpha))
  shared_factor = (self.n * (alpha_sum + self.n) / (alpha_sum + 1) *
                   array_ops.ones_like(self.alpha))
  variance *= array_ops.expand_dims(shared_factor, -1)
  return variance
def _variance(self):
  alpha_sum = array_ops.expand_dims(self.alpha_sum, -1)
  normalized_alpha = self.alpha / alpha_sum
  variance = -math_ops.batch_matmul(
      array_ops.expand_dims(normalized_alpha, -1),
      array_ops.expand_dims(normalized_alpha, -2))
  variance = array_ops.batch_matrix_set_diag(
      variance, normalized_alpha * (1. - normalized_alpha))
  shared_factor = (self.n * (alpha_sum + self.n) / (alpha_sum + 1) *
                   array_ops.ones_like(self.alpha))
  variance *= array_ops.expand_dims(shared_factor, -1)
  return variance
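# Illustrative check (not part of the library): the construction above matches
# the Dirichlet-multinomial covariance
#   Cov(X_i, X_j) = n * (delta_ij * pi_i - pi_i * pi_j) * (n + alpha_0) / (1 + alpha_0),
# where pi_i = alpha_i / alpha_0.
import numpy as np

alpha = np.array([1.0, 2.0, 3.5])
n = 4.0
alpha_sum = alpha.sum()
pi = alpha / alpha_sum

# Construction mirroring the snippet above.
shared_factor = n * (alpha_sum + n) / (alpha_sum + 1.0)
cov = -np.outer(pi, pi)
np.fill_diagonal(cov, pi * (1.0 - pi))
cov *= shared_factor

# Textbook formula.
expected = n * (np.diag(pi) - np.outer(pi, pi)) * (n + alpha_sum) / (1.0 + alpha_sum)
assert np.allclose(cov, expected)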
def _underdetermined(op, grad):
  """Gradients for the underdetermined case of MatrixSolveLs.

  This is the backprop for the solution to the normal equations of the second
  kind:
    X = F(A, B) = A * (A*A^T + lambda*I)^{-1} * B
  that (for lambda=0) solve the least squares problem
    min ||X||_F subject to A*X = B.
  """
  a = op.inputs[0]
  b = op.inputs[1]
  l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
  a_shape = array_ops.shape(a)
  batch_shape = a_shape[:-2]
  m = a_shape[-2]

  identity = linalg_ops.eye(m, batch_shape=batch_shape, dtype=a.dtype)
  gramian = math_ops.batch_matmul(
      a, a, adj_y=True) + l2_regularizer * identity
  chol = linalg_ops.cholesky(gramian)
  grad_b = linalg_ops.cholesky_solve(chol, math_ops.batch_matmul(a, grad))
  # Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
  tmp = linalg_ops.cholesky_solve(chol, b)
  a1 = math_ops.batch_matmul(tmp, a, adj_x=True)
  a1 = -math_ops.batch_matmul(grad_b, a1)
  a2 = grad - math_ops.batch_matmul(a, grad_b, adj_x=True)
  a2 = math_ops.batch_matmul(tmp, a2, adj_y=True)
  grad_a = a1 + a2
  return (grad_a, grad_b, None)
def sample(self, n, seed=None, name=None):
  """Sample `n` observations from the Multivariate Normal Distributions.

  Args:
    n: `Scalar`, type int32, the number of observations to sample.
    seed: Python integer, the random seed.
    name: The name to give this op.

  Returns:
    samples: `[n, ...]`, a `Tensor` of `n` samples for each of the
      distributions determined by broadcasting the hyperparameters.
  """
  with ops.op_scope([self._mu, self._sigma_chol, n], name,
                    "MultivariateNormalSample"):
    # TODO(ebrevdo): Is there a better way to get broadcast_shape?
    broadcast_shape = self.mu.get_shape()
    n = ops.convert_to_tensor(n)
    sigma_shape_left = array_ops.slice(
        array_ops.shape(self._sigma_chol),
        [0], array_ops.pack([array_ops.rank(self._sigma_chol) - 2]))

    k_n = array_ops.pack([self._k, n])
    shape = array_ops.concat(0, [sigma_shape_left, k_n])

    white_samples = random_ops.random_normal(
        shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)

    correlated_samples = math_ops.batch_matmul(
        self._sigma_chol, white_samples)

    # Move the last dimension to the front
    perm = array_ops.concat(
        0,
        (array_ops.pack([array_ops.rank(correlated_samples) - 1]),
         math_ops.range(0, array_ops.rank(correlated_samples) - 1)))

    # TODO(ebrevdo): Once we get a proper tensor contraction op,
    # perform the inner product using that instead of batch_matmul
    # and this slow transpose can go away!
    correlated_samples = array_ops.transpose(correlated_samples, perm)

    samples = correlated_samples + self.mu

    # Provide some hints to shape inference
    n_val = tensor_util.constant_value(n)
    final_shape = tensor_shape.vector(n_val).concatenate(broadcast_shape)
    samples.set_shape(final_shape)

    return samples
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  v = op.outputs[1]
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e.op, grad_v.op]):
    if grad_v is not None:
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          math_ops.inv(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.batch_matmul(
          v,
          math_ops.batch_matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.batch_matmul(v, grad_v, adj_x=True),
              v,
              adj_y=True))
    else:
      grad_a = math_ops.batch_matmul(
          v,
          math_ops.batch_matmul(
              array_ops.matrix_diag(grad_e), v, adj_y=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle
    grad_a = array_ops.matrix_band_part(
        grad_a + array_ops.matrix_transpose(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(
        grad_a, 0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a
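# Illustrative check (not part of the library): the formula above is built from
# the first-order perturbation identities for a symmetric matrix with distinct
# eigenvalues,
#   de_i = v_i^T dA v_i,
#   dv_i = sum_{j != i} (v_j^T dA v_i) / (e_i - e_j) * v_j,
# which is where the matrix f(i,j) = 1/(e_i - e_j) comes from. A small NumPy
# sketch comparing them against a re-decomposition of a perturbed matrix:
import numpy as np

rng = np.random.RandomState(0)
q, _ = np.linalg.qr(rng.randn(4, 4))
A = q @ np.diag([1.0, 2.0, 3.0, 4.0]) @ q.T   # symmetric, distinct eigenvalues
dA = rng.randn(4, 4)
dA = 0.5 * (dA + dA.T)
eps = 1e-6

e, v = np.linalg.eigh(A)
e2, v2 = np.linalg.eigh(A + eps * dA)
v2 *= np.sign(np.sum(v * v2, axis=0))         # fix per-column sign ambiguity

de_pred = np.array([v[:, i] @ dA @ v[:, i] for i in range(4)])
assert np.allclose((e2 - e) / eps, de_pred, atol=1e-4)

i = 1
dv_pred = sum((v[:, j] @ dA @ v[:, i]) / (e[i] - e[j]) * v[:, j]
              for j in range(4) if j != i)
assert np.allclose((v2[:, i] - v[:, i]) / eps, dv_pred, atol=1e-4)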
def variance(self, name="variance"): """Variance of the distribution.""" with ops.name_scope(self.name): with ops.name_scope(name, values=[self._alpha, self._alpha_0]): alpha = array_ops.expand_dims(self._alpha, -1) alpha_0 = array_ops.expand_dims(self._alpha_0, -1) expanded_alpha_0 = array_ops.expand_dims(alpha_0, -1) variance = -math_ops.batch_matmul(alpha, alpha, adj_y=True) / ( expanded_alpha_0 ** 2 * (expanded_alpha_0 + 1)) diagonal = self._alpha / (alpha_0 * (alpha_0 + 1)) variance += array_ops.batch_matrix_diag(diagonal) return variance
def variance(self, name="variance"): """Variance of the distribution.""" with ops.name_scope(self.name): with ops.op_scope([self._alpha, self._alpha_0], name): alpha = array_ops.expand_dims(self._alpha, -1) alpha_0 = array_ops.expand_dims(self._alpha_0, -1) expanded_alpha_0 = array_ops.expand_dims(alpha_0, -1) variance = -math_ops.batch_matmul(alpha, alpha, adj_y=True) / ( expanded_alpha_0 ** 2 * (expanded_alpha_0 + 1)) diagonal = self._alpha / (alpha_0 * (alpha_0 + 1)) variance += array_ops.batch_matrix_diag(diagonal) return variance
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  v = op.outputs[1]
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e.op, grad_v.op]):
    if grad_v is not None:
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          math_ops.inv(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.batch_matmul(
          v,
          math_ops.batch_matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.batch_matmul(v, grad_v, adj_x=True),
              v,
              adj_y=True))
    else:
      grad_a = math_ops.batch_matmul(
          v,
          math_ops.batch_matmul(
              array_ops.matrix_diag(grad_e), v, adj_y=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle
    grad_a = array_ops.matrix_band_part(
        grad_a + array_ops.matrix_transpose(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(
        grad_a, 0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a
def sqrt_matmul(self, x, name='sqrt_matmul'):
  """Left (batch) matmul `x` by a sqrt of this matrix: `Sx` where `A = S S^T`.

  Args:
    x: `Tensor` with shape broadcastable to `[N1,...,Nb, k]` and same `dtype`
      as self.
    name: A name scope to use for ops added by this method.

  Returns:
    Shape `[N1,...,Nb, k]` `Tensor` holding the product `S x`.
  """
  with ops.name_scope(self.name):
    with ops.op_scope([x] + self.inputs, name):
      chol_lower = array_ops.batch_matrix_band_part(self._chol, -1, 0)
      return math_ops.batch_matmul(chol_lower, x)
def __call__(self, inputs, state, scope=None):
  state, fast_weights = state
  with vs.variable_scope(scope or type(self).__name__) as scope:
    """Compute Wh(t) + Cx(t)"""
    linear = self._fwlinear([state, inputs], self._num_units, False)
    """Compute h_0(t+1) = f(Wh(t) + Cx(t))"""
    if not self._reuse_norm:
      h = self._activation(self._norm(linear, scope="Norm0"))
    else:
      h = self._activation(self._norm(linear))

    h = self._vector2matrix(h)
    linear = self._vector2matrix(linear)
    for i in range(self._S):
      """
      Compute h_{s+1}(t+1) = f([Wh(t) + Cx(t)] + A(t) h_s(t+1)), S times.
      See Eqn (2) in the paper.
      """
      if not self._reuse_norm:
        h = self._activation(
            self._norm(linear + math_ops.batch_matmul(fast_weights, h),
                       scope="Norm%d" % (i + 1)))
      else:
        h = self._activation(
            self._norm(linear + math_ops.batch_matmul(fast_weights, h)))

    """
    Compute A(t+1) according to Eqn (4)
    """
    state = self._vector2matrix(state)
    new_fast_weights = self._lambda * fast_weights + self._eta * math_ops.batch_matmul(
        state, state, adj_y=True)

    h = self._matrix2vector(h)
    return h, (h, new_fast_weights)
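# Illustrative sketch (not part of the cell above): the fast-weights recurrence
# that the Eqn (2)/Eqn (4) comments refer to (Ba et al., "Using Fast Weights to
# Attend to the Recent Past"), written with plain NumPy for a single example.
# The function name and the values of S, lam, and eta are assumptions chosen
# for the example, not values taken from the cell.
import numpy as np

def fast_weights_step(linear, h_prev, A, S=1, lam=0.95, eta=0.5):
  """One time step of the fast-weights inner loop and memory update."""
  f = np.tanh
  h = f(linear)                         # h_0(t+1) = f(Wh(t) + Cx(t))
  for _ in range(S):                    # Eqn (2): h_{s+1} = f(linear + A h_s)
    h = f(linear + A @ h)
  A_next = lam * A + eta * np.outer(h_prev, h_prev)   # Eqn (4), as in the cell
  return h, A_next

rng = np.random.RandomState(0)
d = 8
h_prev, A = rng.randn(d), np.zeros((d, d))
linear = rng.randn(d)                   # stands in for W h(t) + C x(t)
h_new, A_new = fast_weights_step(linear, h_prev, A)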
def sample(self, n, seed=None, name=None):
  """Sample `n` observations from the Multivariate Normal Distributions.

  Args:
    n: `Scalar`, type int32, the number of observations to sample.
    seed: Python integer, the random seed.
    name: The name to give this op.

  Returns:
    samples: `[n, ...]`, a `Tensor` of `n` samples for each of the
      distributions determined by broadcasting the hyperparameters.
  """
  with ops.op_scope([self._mu, self._sigma_chol, n], name,
                    "MultivariateNormalSample"):
    # TODO(ebrevdo): Is there a better way to get broadcast_shape?
    broadcast_shape = self.mu.get_shape()
    n = ops.convert_to_tensor(n)
    sigma_shape_left = array_ops.slice(
        array_ops.shape(self._sigma_chol),
        [0], array_ops.pack([array_ops.rank(self._sigma_chol) - 2]))

    k_n = array_ops.pack([self._k, n])
    shape = array_ops.concat(0, [sigma_shape_left, k_n])

    white_samples = random_ops.random_normal(
        shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)

    correlated_samples = math_ops.batch_matmul(
        self._sigma_chol, white_samples)

    # Move the last dimension to the front
    perm = array_ops.concat(
        0,
        (array_ops.pack([array_ops.rank(correlated_samples) - 1]),
         math_ops.range(0, array_ops.rank(correlated_samples) - 1)))

    # TODO(ebrevdo): Once we get a proper tensor contraction op,
    # perform the inner product using that instead of batch_matmul
    # and this slow transpose can go away!
    correlated_samples = array_ops.transpose(correlated_samples, perm)

    samples = correlated_samples + self.mu

    # Provide some hints to shape inference
    n_val = tensor_util.constant_value(n)
    final_shape = tensor_shape.vector(n_val).concatenate(broadcast_shape)
    samples.set_shape(final_shape)

    return samples
def sqrt_matmul(self, x, name='sqrt_matmul'):
  """Left (batch) matmul `x` by a sqrt of this matrix: `Sx` where `A = S S^T`.

  Args:
    x: `Tensor` with shape broadcastable to `[N1,...,Nb, k]` and same `dtype`
      as self.
    name: A name scope to use for ops added by this method.

  Returns:
    Shape `[N1,...,Nb, k]` `Tensor` holding the product `S x`.
  """
  with ops.name_scope(self.name):
    with ops.op_scope([x] + self.inputs, name):
      chol_lower = array_ops.batch_matrix_band_part(self._chol, -1, 0)
      return math_ops.batch_matmul(chol_lower, x)
def _chol_capacitance(self, batch_mode):
  """Cholesky factorization of the capacitance term."""
  # Cholesky factor for (D^{-1} + V^T M^{-1} V), which is sometimes
  # known as the "capacitance" matrix.

  # self._operator will use batch if need be.  Automatically.  We cannot force
  # that here.
  # M^{-1} V
  minv_v = self._operator.solve(self._v)
  # V^T M^{-1} V
  if batch_mode:
    vt_minv_v = math_ops.batch_matmul(self._v, minv_v, adj_x=True)
  else:
    vt_minv_v = math_ops.matmul(self._v, minv_v, transpose_a=True)

  # D^{-1} + V^T M^{-1} V
  capacitance = self._diag_inv_operator.add_to_tensor(vt_minv_v)
  # Cholesky[D^{-1} + V^T M^{-1} V]
  return linalg_ops.cholesky(capacitance)
def variance(self, name="mean"): """Class variances for every batch member. The variance for each batch member is defined as the following: ``` Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) * (n + alpha_0) / (1 + alpha_0) ``` where `alpha_0 = sum_j alpha_j`. The covariance between elements in a batch is defined as: ``` Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 * (n + alpha_0) / (1 + alpha_0) ``` Args: name: The name for this op. Returns: A `Tensor` representing the variances for each batch member. """ alpha = self._alpha alpha_sum = self._alpha_sum n = self._n with ops.name_scope(self.name): with ops.name_scope(name, values=[alpha, alpha_sum, n]): expanded_alpha_sum = array_ops.expand_dims(alpha_sum, -1) shared_factor = n * (expanded_alpha_sum + n) / ( expanded_alpha_sum + 1) * array_ops.ones_like(alpha) mean_no_n = alpha / expanded_alpha_sum expanded_mean_no_n = array_ops.expand_dims(mean_no_n, -1) variance = -math_ops.batch_matmul( expanded_mean_no_n, expanded_mean_no_n, adj_y=True) variance += array_ops.batch_matrix_diag(mean_no_n) variance *= array_ops.expand_dims(shared_factor, -1) return variance
def variance(self, name='variance'):
  """Class variances for every batch member.

  The variance for each batch member is defined as the following:

  ```
  Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
    (n + alpha_0) / (1 + alpha_0)
  ```

  where `alpha_0 = sum_j alpha_j`.

  The covariance between elements in a batch is defined as:

  ```
  Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 *
    (n + alpha_0) / (1 + alpha_0)
  ```

  Args:
    name: The name for this op.

  Returns:
    A `Tensor` representing the variances for each batch member.
  """
  alpha = self._alpha
  alpha_sum = self._alpha_sum
  n = self._n
  with ops.name_scope(self.name):
    with ops.op_scope([alpha, alpha_sum, n], name):
      expanded_alpha_sum = array_ops.expand_dims(alpha_sum, -1)
      shared_factor = n * (expanded_alpha_sum + n) / (
          expanded_alpha_sum + 1) * array_ops.ones_like(alpha)

      mean_no_n = alpha / expanded_alpha_sum
      expanded_mean_no_n = array_ops.expand_dims(mean_no_n, -1)
      variance = -math_ops.batch_matmul(
          expanded_mean_no_n, expanded_mean_no_n, adj_y=True)
      variance += array_ops.batch_matrix_diag(mean_no_n)
      variance *= array_ops.expand_dims(shared_factor, -1)
      return variance
def _BatchMatMul(op, grad):
  """Returns the gradient of x and y given the gradient of x * y."""
  x = op.inputs[0]
  y = op.inputs[1]
  adj_x = op.get_attr("adj_x")
  adj_y = op.get_attr("adj_y")

  if not adj_x:
    if not adj_y:
      grad_x = math_ops.batch_matmul(grad, y, False, True)
      grad_y = math_ops.batch_matmul(x, grad, True, False)
    else:
      grad_x = math_ops.batch_matmul(grad, y, False, False)
      grad_y = math_ops.batch_matmul(grad, x, True, False)
  else:
    if not adj_y:
      grad_x = math_ops.batch_matmul(y, grad, False, True)
      grad_y = math_ops.batch_matmul(x, grad, False, False)
    else:
      grad_x = math_ops.batch_matmul(y, grad, True, True)
      grad_y = math_ops.batch_matmul(grad, x, True, True)

  return grad_x, grad_y
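# Illustrative check (not part of the library): the first branch above
# (adj_x=False, adj_y=False) is the usual matmul backprop rule
#   grad_x = grad @ y^T,  grad_y = x^T @ grad,
# verified here with NumPy finite differences on one entry of x and of y.
import numpy as np

rng = np.random.RandomState(0)
x, y = rng.randn(2, 3), rng.randn(3, 4)
w = rng.randn(2, 4)                      # upstream gradient dL/d(x @ y)

grad_x, grad_y = w @ y.T, x.T @ w

eps = 1e-6
dx = np.zeros_like(x)
dx[1, 2] = eps
num_x = (np.sum(w * ((x + dx) @ y)) - np.sum(w * ((x - dx) @ y))) / (2 * eps)
assert np.isclose(num_x, grad_x[1, 2])

dy = np.zeros_like(y)
dy[0, 3] = eps
num_y = (np.sum(w * (x @ (y + dy))) - np.sum(w * (x @ (y - dy)))) / (2 * eps)
assert np.isclose(num_y, grad_y[0, 3])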
def log_pdf(self, x, name=None):
  """Log pdf of observations `x` given these Multivariate Normals.

  Args:
    x: tensor of dtype `dtype`, must be broadcastable with `mu`.
    name: The name to give this op.

  Returns:
    log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
  """
  with ops.op_scope([self._mu, self._sigma_chol, x], name,
                    "MultivariateNormalLogPdf"):
    x = ops.convert_to_tensor(x)
    contrib_tensor_util.assert_same_float_dtype((self._mu, x))

    x_centered = x - self.mu

    x_rank = array_ops.rank(x_centered)
    sigma_rank = array_ops.rank(self._sigma_chol)

    x_rank_vec = array_ops.pack([x_rank])
    sigma_rank_vec = array_ops.pack([sigma_rank])
    x_shape = array_ops.shape(x_centered)

    # sigma_chol is shaped [D, E, F, ..., k, k]
    # x_centered shape is one of:
    #   [D, E, F, ..., k], or [F, ..., k], or
    #   [A, B, C, D, E, F, ..., k]
    # and we need to convert x_centered to shape:
    #   [D, E, F, ..., k, A*B*C] (or 1 if A, B, C don't exist)
    # then transpose and reshape x_whitened back to one of the shapes:
    #   [D, E, F, ..., k], or [1, 1, F, ..., k], or
    #   [A, B, C, D, E, F, ..., k]

    # This helper handles the case where rank(x_centered) < rank(sigma)
    def _broadcast_x_not_higher_rank_than_sigma():
      return array_ops.reshape(
          x_centered,
          array_ops.concat(
              # Reshape to ones(deficient x rank) + x_shape + [1]
              0, (array_ops.ones(array_ops.pack([sigma_rank - x_rank - 1]),
                                 dtype=x_rank.dtype),
                  x_shape,
                  [1])))

    # These helpers handle the case where rank(x_centered) >= rank(sigma)
    def _broadcast_x_higher_rank_than_sigma():
      x_shape_left = array_ops.slice(x_shape, [0], sigma_rank_vec - 1)
      x_shape_right = array_ops.slice(
          x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
      x_shape_perm = array_ops.concat(
          0, (math_ops.range(sigma_rank - 1, x_rank),
              math_ops.range(0, sigma_rank - 1)))
      return array_ops.reshape(
          # Convert to [D, E, F, ..., k, B, C]
          array_ops.transpose(x_centered, perm=x_shape_perm),
          # Reshape to [D, E, F, ..., k, B*C]
          array_ops.concat(
              0, (x_shape_right,
                  array_ops.pack([math_ops.reduce_prod(x_shape_left, 0)]))))

    def _unbroadcast_x_higher_rank_than_sigma():
      x_shape_left = array_ops.slice(x_shape, [0], sigma_rank_vec - 1)
      x_shape_right = array_ops.slice(
          x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
      x_shape_perm = array_ops.concat(
          0, (math_ops.range(sigma_rank - 1, x_rank),
              math_ops.range(0, sigma_rank - 1)))
      return array_ops.transpose(
          # [D, E, F, ..., k, B, C] => [B, C, D, E, F, ..., k]
          array_ops.reshape(
              # convert to [D, E, F, ..., k, B, C]
              x_whitened_broadcast,
              array_ops.concat(0, (x_shape_right, x_shape_left))),
          perm=x_shape_perm)

    # Step 1: reshape x_centered
    x_centered_broadcast = control_flow_ops.cond(
        # x_centered == [D, E, F, ..., k] => [D, E, F, ..., k, 1]
        # or         == [F, ..., k]       => [1, 1, F, ..., k, 1]
        x_rank <= sigma_rank - 1,
        _broadcast_x_not_higher_rank_than_sigma,
        # x_centered == [B, C, D, E, F, ..., k] => [D, E, F, ..., k, B*C]
        _broadcast_x_higher_rank_than_sigma)

    x_whitened_broadcast = linalg_ops.batch_matrix_triangular_solve(
        self._sigma_chol, x_centered_broadcast)

    # Reshape x_whitened_broadcast back to x_whitened
    x_whitened = control_flow_ops.cond(
        x_rank <= sigma_rank - 1,
        lambda: array_ops.reshape(x_whitened_broadcast, x_shape),
        _unbroadcast_x_higher_rank_than_sigma)

    x_whitened = array_ops.expand_dims(x_whitened, -1)
    # Reshape x_whitened to contain row vectors
    # Returns a batchwise scalar
    x_whitened_norm = math_ops.batch_matmul(
        x_whitened, x_whitened, adj_x=True)
    x_whitened_norm = control_flow_ops.cond(
        x_rank <= sigma_rank - 1,
        lambda: array_ops.squeeze(x_whitened_norm, [-2, -1]),
        lambda: array_ops.squeeze(x_whitened_norm, [-1]))

    log_two_pi = constant_op.constant(math.log(2 * math.pi), dtype=self.dtype)
    k = math_ops.cast(self._k, self.dtype)
    log_pdf_value = (
        -math_ops.log(self._sigma_det) - k * log_two_pi - x_whitened_norm) / 2
    final_shaped_value = control_flow_ops.cond(
        x_rank <= sigma_rank - 1,
        lambda: log_pdf_value,
        lambda: array_ops.squeeze(log_pdf_value, [-1]))

    output_static_shape = x_centered.get_shape()[:-1]
    final_shaped_value.set_shape(output_static_shape)
    return final_shaped_value
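# Illustrative check (not part of the library): the density assembled above,
#   log p(x) = -0.5 * (log det(Sigma) + k*log(2*pi) + (x-mu)^T Sigma^{-1} (x-mu)),
# with the quadratic form computed by whitening against the Cholesky factor,
# agrees with scipy.stats.multivariate_normal.
import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.RandomState(0)
k = 3
mu = rng.randn(k)
L = np.tril(rng.randn(k, k)) + 2.0 * np.eye(k)   # Cholesky-like factor
sigma = L @ L.T
x = rng.randn(k)

z = np.linalg.solve(L, x - mu)                   # whitened residual
log_pdf = -0.5 * (np.linalg.slogdet(sigma)[1] + k * np.log(2 * np.pi) + z @ z)
assert np.isclose(log_pdf, multivariate_normal(mu, sigma).logpdf(x))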
def _to_dense(self):
  sqrt = self.sqrt_to_dense()
  return math_ops.batch_matmul(sqrt, sqrt, adj_y=True)
def __init__(self, mu, sigma=None, sigma_chol=None, name=None):
  """Multivariate Normal distributions on `R^k`.

  User must provide means `mu`, which are tensors of rank `N+1` (`N >= 0`)
  with the last dimension having length `k`.

  User must provide exactly one of `sigma` (the covariance matrices) or
  `sigma_chol` (the cholesky decompositions of the covariance matrices).
  `sigma` or `sigma_chol` must be of rank `N+2`.  The last two dimensions
  must both have length `k`.  The first `N` dimensions correspond to batch
  indices.

  If `sigma_chol` is not provided, the batch cholesky factorization of
  `sigma` is calculated for you.

  The shapes of `mu` and `sigma` must match for the first `N` dimensions.

  Regardless of which parameter is provided, the covariance matrices must all
  be **positive definite** (an error is raised if one of them is not).

  Args:
    mu: (N+1)-D.  `float` or `double` tensor, the means of the distributions.
    sigma: (N+2)-D.  (optional) `float` or `double` tensor, the covariances
      of the distribution(s).  The first `N+1` dimensions must match those
      of `mu`.  Must be batch-positive-definite.
    sigma_chol: (N+2)-D.  (optional) `float` or `double` tensor, a
      lower-triangular factorization of `sigma`
      (`sigma = sigma_chol . sigma_chol^*`).  The first `N+1` dimensions
      must match those of `mu`.  The tensor itself need not be batch lower
      triangular: we ignore the upper triangular part.  However, the batch
      diagonals must be positive (i.e., sigma_chol must be
      batch-positive-definite).
    name: The name to give Ops created by the initializer.

  Raises:
    ValueError: if neither or both of `sigma` and `sigma_chol` are provided.
    TypeError: if mu and sigma (resp. sigma_chol) are different dtypes.
  """
  if (sigma is None) == (sigma_chol is None):
    raise ValueError("Exactly one of sigma and sigma_chol must be provided")

  with ops.op_scope([mu, sigma, sigma_chol], name, "MultivariateNormal"):
    sigma_or_half = sigma_chol if sigma is None else sigma

    mu = ops.convert_to_tensor(mu)
    sigma_or_half = ops.convert_to_tensor(sigma_or_half)

    contrib_tensor_util.assert_same_float_dtype((mu, sigma_or_half))

    with ops.control_dependencies([
        _assert_compatible_shapes(mu, sigma_or_half)]):
      mu = array_ops.identity(mu, name="mu")

      # Store the dimensionality of the MVNs
      self._k = array_ops.gather(array_ops.shape(mu), array_ops.rank(mu) - 1)

      if sigma_chol is not None:
        # Ensure we only keep the lower triangular part.
        sigma_chol = array_ops.batch_matrix_band_part(
            sigma_chol, num_lower=-1, num_upper=0)
        sigma_det = _determinant_from_sigma_chol(sigma_chol)
        with ops.control_dependencies([
            _assert_batch_positive_definite(sigma_chol)]):
          self._sigma = math_ops.batch_matmul(
              sigma_chol, sigma_chol, adj_y=True, name="sigma")
          self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
          self._sigma_det = array_ops.identity(sigma_det, "sigma_det")
          self._mu = array_ops.identity(mu, "mu")
      else:  # sigma is not None
        sigma_chol = linalg_ops.batch_cholesky(sigma)
        sigma_det = _determinant_from_sigma_chol(sigma_chol)
        # batch_cholesky checks for PSD; so we can just use it here.
        with ops.control_dependencies([sigma_chol]):
          self._sigma = array_ops.identity(sigma, "sigma")
          self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
          self._sigma_det = array_ops.identity(sigma_det, "sigma_det")
          self._mu = array_ops.identity(mu, "mu")
def sample_n(self, n, seed=None, name="sample"): # pylint: disable=line-too-long """Generate `n` samples. Complexity: O(nbk^3) The sampling procedure is based on the [Bartlett decomposition]( https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition) and [using a Gamma distribution to generate Chi2 random variates]( https://en.wikipedia.org/wiki/Chi-squared_distribution#Gamma.2C_exponential.2C_and_related_distributions). Args: n: `Scalar` `Tensor` of type `int32` or `int64`, the number of observations to sample. seed: Python integer; random number generator seed. name: The name of this op. Returns: samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape` with values of type `self.dtype`. """ with ops.name_scope(self.name): with ops.name_scope(name, values=[n] + list(self.inputs.values())): n = ops.convert_to_tensor(n, name="n") if n.dtype != dtypes.int32: raise TypeError("n.dtype=%s which is not int32" % n.dtype) batch_shape = self.batch_shape() event_shape = self.event_shape() batch_ndims = array_ops.shape(batch_shape)[0] ndims = batch_ndims + 3 # sample_ndims=1, event_ndims=2 shape = array_ops.concat(0, ((n,), batch_shape, event_shape)) # Complexity: O(nbk^2) x = random_ops.random_normal(shape=shape, mean=0.0, stddev=1.0, dtype=self.dtype, seed=seed) # Complexity: O(nbk) # This parametrization is equivalent to Chi2, i.e., # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2) g = random_ops.random_gamma( shape=(n,), alpha=self._multi_gamma_sequence(0.5 * self.df, self.dimension), beta=0.5, dtype=self.dtype, seed=seed, ) # Complexity: O(nbk^2) x = array_ops.batch_matrix_band_part(x, -1, 0) # Tri-lower. # Complexity: O(nbk) x = array_ops.batch_matrix_set_diag(x, math_ops.sqrt(g)) # Make batch-op ready. # Complexity: O(nbk^2) perm = array_ops.concat(0, (math_ops.range(1, ndims), (0,))) x = array_ops.transpose(x, perm) shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1))) x = array_ops.reshape(x, shape) # Complexity: O(nbM) where M is the complexity of the operator solving a # vector system. E.g., for OperatorPDDiag, each matmul is O(k^2), so # this complexity is O(nbk^2). For OperatorPDCholesky, each matmul is # O(k^3) so this step has complexity O(nbk^3). x = self.scale_operator_pd.sqrt_matmul(x) # Undo make batch-op ready. # Complexity: O(nbk^2) shape = array_ops.concat(0, (batch_shape, event_shape, (n,))) x = array_ops.reshape(x, shape) perm = array_ops.concat(0, ((ndims - 1,), math_ops.range(0, ndims - 1))) x = array_ops.transpose(x, perm) if not self.cholesky_input_output_matrices: # Complexity: O(nbk^3) x = math_ops.batch_matmul(x, x, adj_y=True) # Set shape hints. if self.scale_operator_pd.get_shape().ndims is not None: x.set_shape( tensor_shape.TensorShape( [tensor_util.constant_value(n)] + self.scale_operator_pd.get_shape().as_list() ) ) elif x.get_shape().ndims is not None: x.get_shape()[0].merge_with(tensor_shape.TensorDimension(tensor_util.constant_value(n))) return x