def cholesky_jvp_rule(primals, tangents):
  """JVP (forward-mode) rule for the Cholesky primitive.

  Given the primal input ``x`` and its tangent ``sigma_dot``, returns the
  primal output ``L`` (lower-triangular Cholesky factor) and its tangent
  ``L_dot``.
  """
  x, = primals
  sigma_dot, = tangents
  # Keep only the lower triangle of the primitive's output; values above
  # the diagonal are discarded.
  L = np.tril(cholesky_p.bind(x))

  # Forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf
  def phi(X):
    # Lower triangle of X with its diagonal halved: tril(X) / (1 + I).
    l = np.tril(X)
    return l / (np._constant_like(X, 1) + np.eye(X.shape[-1], dtype=X.dtype))

  # Right-side solve against the conjugate transpose of L
  # (transpose_a=True, conjugate_a=True), i.e. tmp = sigma_dot @ L^{-H}.
  tmp = triangular_solve(L, sigma_dot, left_side=False, transpose_a=True,
                         conjugate_a=True, lower=True)
  # L_dot = L @ phi(L^{-1} @ tmp); the matmul is requested at the highest
  # available precision to limit error in the tangent.
  L_dot = lax.batch_matmul(L, phi(
      triangular_solve(L, tmp, left_side=True, transpose_a=False,
                       lower=True)),
      precision=lax.Precision.HIGHEST)
  return L, L_dot
def cholesky_jvp_rule(primals, tangents):
  """JVP (forward-mode) rule for the Cholesky primitive.

  Returns the primal output ``L`` and its tangent ``L_dot``.
  """
  # NOTE(review): this redefines cholesky_jvp_rule from earlier in the
  # file.  Unlike that version, it does not pass conjugate_a=True in the
  # right-side solve, does not tril() the primal output, and requests no
  # explicit matmul precision — likely incorrect for complex inputs;
  # confirm which definition is intended to win.
  x, = primals
  sigma_dot, = tangents
  L = cholesky_p.bind(x)

  # Forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf
  # phi keeps the lower triangle and halves its diagonal: tril(X) / (1 + I).
  phi = lambda X: np.tril(X) / (1 + np.eye(X.shape[-1], dtype=X.dtype))

  # Right-side transposed solve: tmp = sigma_dot @ L^{-T}.
  tmp = triangular_solve(L, sigma_dot, left_side=False, transpose_a=True,
                         lower=True)
  # L_dot = L @ phi(L^{-1} @ tmp).
  L_dot = lax.batch_matmul(L, phi(triangular_solve(
      L, tmp, left_side=True, transpose_a=False, lower=True)))
  return L, L_dot
def mpnn(params, A, F, nonlin=identity):
  """Message passing neural network layer.

  Performs one round of message passing according to the adjacency-like
  matrices in ``A``, and then one "dense" matrix multiplication on top of
  the message passing.

  :param params: A dictionary of parameters with keys ``"w"`` (dense
      weight matrix applied to the feature dimension) and ``"b"`` (bias).
  :param A: A 3D-tensor of adjacency matrices. 1st dimension is the
      sample/batch dimension; 2nd and 3rd dimension must be equal.
  :param F: A 3D-tensor of feature matrices. 1st dimension is the
      sample/batch dimension; 2nd dimension is the node dimension;
      3rd dimension is the feature dimension.
  :param nonlin: Elementwise nonlinearity applied to the output; defaults
      to ``identity`` (no nonlinearity).
  :returns: A 3D-tensor of transformed features. 1st dimension is the
      sample/batch dimension; 2nd dimension is the node dimension;
      3rd dimension is the feature dimension.
  """
  # Message passing: aggregate neighbor features via the adjacency
  # matrices.  Result shape: n_samps x n_nodes x n_feats.
  messages = batch_matmul(A, F)
  # Dense transform applied to every node's aggregated feature vector.
  # (Bound to a fresh name rather than rebinding the parameter F.)
  out = np.dot(messages, params["w"]) + params["b"]
  return nonlin(out)