Beispiel #1
0
def cholesky_jvp_rule(primals, tangents):
    """Forward-mode (JVP) rule for the ``cholesky_p`` primitive.

    Implements the forward-mode rule from
    https://arxiv.org/pdf/1602.07527.pdf:

        L_dot = L @ phi(L^{-1} @ sigma_dot @ L^{-H})

    where ``phi`` keeps the lower triangle of its argument and halves the
    diagonal.

    Args:
        primals: One-element sequence holding the input matrix ``x``.
        tangents: One-element sequence holding the tangent ``sigma_dot``.

    Returns:
        Tuple ``(L, L_dot)``: the Cholesky factor with its strict upper
        triangle zeroed by ``np.tril``, and the tangent of that factor.
    """
    (x,) = primals
    (sigma_dot,) = tangents
    factor = np.tril(cholesky_p.bind(x))

    def phi(mat):
        # Dividing by (1 + I) leaves the off-diagonal entries of the lower
        # triangle unchanged and halves the diagonal entries.
        lower = np.tril(mat)
        one = np._constant_like(mat, 1)
        return lower / (one + np.eye(mat.shape[-1], dtype=mat.dtype))

    # Right-side solve against the conjugate transpose of the factor:
    # sigma_dot @ L^{-H}.
    right = triangular_solve(factor, sigma_dot, left_side=False,
                             transpose_a=True, conjugate_a=True, lower=True)
    # Left-side solve: L^{-1} @ (sigma_dot @ L^{-H}).
    middle = triangular_solve(factor, right, left_side=True,
                              transpose_a=False, lower=True)
    # Highest matmul precision is requested explicitly for the final product.
    factor_dot = lax.batch_matmul(factor, phi(middle),
                                  precision=lax.Precision.HIGHEST)
    return factor, factor_dot
Beispiel #2
0
def cholesky_jvp_rule(primals, tangents):
  """Forward-mode (JVP) rule for the ``cholesky_p`` primitive.

  Implements the forward-mode rule from
  https://arxiv.org/pdf/1602.07527.pdf:

      L_dot = L @ phi(L^{-1} @ sigma_dot @ L^{-H})

  where ``phi`` keeps the lower triangle of its argument and halves the
  diagonal.

  Args:
    primals: One-element sequence holding the input matrix ``x``.
    tangents: One-element sequence holding the tangent ``sigma_dot``.

  Returns:
    Tuple ``(L, L_dot)``: the Cholesky factor with its strict upper triangle
    zeroed, and the tangent of that factor.
  """
  x, = primals
  sigma_dot, = tangents
  # `lax.batch_matmul` below reads the full matrix, so zero everything above
  # the diagonal in case the primitive does not leave it zeroed.
  L = np.tril(cholesky_p.bind(x))

  def phi(X):
    # Dividing by (1 + I) keeps off-diagonal lower entries and halves the
    # diagonal.
    return np.tril(X) / (1 + np.eye(X.shape[-1], dtype=X.dtype))

  # sigma_dot @ L^{-H}. `conjugate_a=True` is required so complex (Hermitian)
  # inputs use the conjugate transpose rather than the plain transpose.
  tmp = triangular_solve(L, sigma_dot, left_side=False, transpose_a=True,
                         conjugate_a=True, lower=True)
  # L^{-1} @ (sigma_dot @ L^{-H})
  middle = triangular_solve(L, tmp, left_side=True, transpose_a=False,
                            lower=True)
  # Request full precision; backends may otherwise use reduced-precision
  # matmuls by default.
  L_dot = lax.batch_matmul(L, phi(middle), precision=lax.Precision.HIGHEST)
  return L, L_dot
Beispiel #3
0
def mpnn(params, A, F, nonlin=identity):
    """
    Apply one message-passing neural network (MPNN) layer.

    Performs one round of message passing by batch-multiplying the
    adjacency-like matrices in ``A`` with the feature matrices in ``F``,
    then applies one dense affine transform (``params["w"]`` and
    ``params["b"]``) and finally ``nonlin``.

    :param params: A dictionary of parameters. Must contain the weight
        matrix under ``"w"`` and the bias under ``"b"``.
    :param A: A 3D-tensor of adjacency matrices.
        1st dimension is the sample/batch dimension;
        2nd and 3rd dimension must be equal.
    :param F: A 3D-tensor of feature matrices.
        1st dimension is the sample/batch dimension;
        2nd dimension is the node dimension;
        3rd dimension is the feature dimension.
    :param nonlin: Callable applied to the transformed features before
        they are returned; defaults to ``identity``.
    :returns: F, a 3D-tensor of transformed features.
        1st dimension is the sample/batch dimension;
        2nd dimension is the node dimension;
        3rd dimension is the feature dimension.
    """
    # Aggregate neighbour features; shape is n_samps x n_nodes x n_feats.
    messages = batch_matmul(A, F)
    weights = params["w"]
    bias = params["b"]
    transformed = np.dot(messages, weights) + bias
    return nonlin(transformed)