Example #1
  def testQdwhWithOnRankDeficientInput(self, m, n, log_cond):
    """Tests qdwh with rank-deficient input."""
    a = jnp.triu(jnp.ones((m, n))).astype(_QDWH_TEST_DTYPE)

    # Generates a rank-deficient input.
    u, s, v = jnp.linalg.svd(a, full_matrices=False)
    cond = 10**log_cond
    s = jnp.linspace(cond, 1, min(m, n))
    s = jnp.expand_dims(s.at[-1].set(0), range(u.ndim - 1))
    a = (u * s) @ v

    is_hermitian = _check_symmetry(a)
    max_iterations = 15
    actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian=is_hermitian,
                                         max_iterations=max_iterations)
    _, expected_h = osp_linalg.polar(a)

    # Sets the test tolerance.
    rtol = 1E4 * _QDWH_TEST_EPS

    # For a rank-deficient matrix, `u` is not unique.
    with self.subTest('Test h.'):
      relative_diff_h = _compute_relative_diff(actual_h, expected_h)
      np.testing.assert_almost_equal(relative_diff_h, 1E-6, decimal=5)

    with self.subTest('Test u.dot(h).'):
      a_round_trip = _dot(actual_u, actual_h)
      relative_diff_a = _compute_relative_diff(a_round_trip, a)
      np.testing.assert_almost_equal(relative_diff_a, 1E-6, decimal=5)

    with self.subTest('Test orthogonality.'):
      actual_results = _dot(actual_u.T.conj(), actual_u)
      expected_results = np.eye(n)
      self.assertAllClose(
          actual_results, expected_results, rtol=rtol, atol=1E-6)
Example #2
  def testQdwhWithUpperTriangularInputAllOnes(self, m, n, log_cond):
    """Tests qdwh with upper triangular input of all ones."""
    a = jnp.triu(jnp.ones((m, n))).astype(_QDWH_TEST_DTYPE)
    u, s, v = jnp.linalg.svd(a, full_matrices=False)
    cond = 10**log_cond
    s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
    a = (u * s) @ v
    is_hermitian = _check_symmetry(a)
    max_iterations = 10

    actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian=is_hermitian,
                                         max_iterations=max_iterations)
    expected_u, expected_h = osp_linalg.polar(a)

    # Sets the test tolerance.
    rtol = 1E6 * _QDWH_TEST_EPS

    with self.subTest('Test u.'):
      relative_diff_u = _compute_relative_diff(actual_u, expected_u)
      np.testing.assert_almost_equal(relative_diff_u, 1E-6, decimal=5)

    with self.subTest('Test h.'):
      relative_diff_h = _compute_relative_diff(actual_h, expected_h)
      np.testing.assert_almost_equal(relative_diff_h, 1E-6, decimal=5)

    with self.subTest('Test u.dot(h).'):
      a_round_trip = _dot(actual_u, actual_h)
      relative_diff_a = _compute_relative_diff(a_round_trip, a)
      np.testing.assert_almost_equal(relative_diff_a, 1E-6, decimal=5)

    with self.subTest('Test orthogonality.'):
      actual_results = _dot(actual_u.T, actual_u)
      expected_results = np.eye(n)
      self.assertAllClose(
          actual_results, expected_results, rtol=rtol, atol=1E-5)
Example #3
    def init_fun(rng, input_dim, **kwargs):
        W = orthogonal()(rng, (input_dim, input_dim))
        P, L, U = scipy.linalg.lu(W)
        S = np.diag(U)
        U = np.triu(U, 1)
        identity = np.eye(input_dim)

        def direct_fun(params, inputs, **kwargs):
            L, U, S = params
            L = np.tril(L, -1) + identity
            U = np.triu(U, 1)
            W = P @ L @ (U + np.diag(S))

            outputs = inputs @ W
            log_det_jacobian = np.full(inputs.shape[:1],
                                       np.log(np.abs(S)).sum())
            return outputs, log_det_jacobian

        def inverse_fun(params, inputs, **kwargs):
            L, U, S = params
            L = np.tril(L, -1) + identity
            U = np.triu(U, 1)
            W = P @ L @ (U + np.diag(S))

            outputs = inputs @ linalg.inv(W)
            log_det_jacobian = np.full(inputs.shape[:1],
                                       -np.log(np.abs(S)).sum())
            return outputs, log_det_jacobian

        return (L, U, S), direct_fun, inverse_fun
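For illustration, a minimal round-trip sketch (not part of the original snippet): it assumes `init_fun` above is in scope and that its module provides the names the snippet relies on (`np` as jax.numpy, `scipy.linalg`, `linalg.inv`, and an `orthogonal` initializer, here assumed to come from jax.nn.initializers).

import jax
import jax.numpy as np                      # the snippet aliases jax.numpy as np
import scipy.linalg                         # used by init_fun above
from jax.numpy import linalg                # inverse_fun above references linalg.inv
from jax.nn.initializers import orthogonal  # assumed source of orthogonal()

rng = jax.random.PRNGKey(0)
params, direct_fun, inverse_fun = init_fun(rng, input_dim=4)
x = jax.random.normal(rng, (8, 4))
y, ldj = direct_fun(params, x)
x_back, ldj_inv = inverse_fun(params, y)
assert np.allclose(x, x_back, atol=1e-4)  # the inverse undoes the forward pass
assert np.allclose(ldj, -ldj_inv)         # log-determinants are exact negatives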
Example #4
def crossover(parent_1, parent_2, offspring_size):
    all_offspring = []
    for o in range(offspring_size):
        lower_1 = np.tril(parent_1)
        upper_2 = np.triu(parent_2)
        # Both tril and triu include the main diagonal by default, so the
        # offspring's diagonal is the sum of both parents' diagonals.
        offspring = lower_1 + upper_2
        all_offspring.append(offspring)
    return all_offspring
Example #5
        def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg):
            operand, = args
            lu, pivots, perm = result_tf
            batch_dims = operand.shape[:-2]
            m, n = operand.shape[-2], operand.shape[-1]

            def _make_permutation_matrix(perm):
                result = []
                for idx in itertools.product(*map(range, operand.shape[:-1])):
                    result += [0 if c != perm[idx] else 1 for c in range(m)]
                result = np.reshape(np.array(result, dtype=dtype),
                                    [*batch_dims, m, m])
                return result

            k = min(m, n)
            l = jnp.tril(lu, -1)[..., :, :k] + jnp.eye(m, k, dtype=dtype)
            u = jnp.triu(lu)[..., :k, :]
            p_mat = _make_permutation_matrix(perm)

            tst.assertArraysEqual(
                lax.linalg.lu_pivots_to_permutation(pivots, m), perm)
            tst.assertAllClose(jnp.matmul(p_mat, operand),
                               jnp.matmul(l, u),
                               atol=tol,
                               rtol=tol,
                               err_msg=err_msg)
Example #6
 def build_attention_mask(self):
     # lazily create causal attention mask, with full attention between the vision tokens
     # pytorch uses an additive attention mask; fill with a large negative value in place of -inf
     mask = jnp.zeros((self.context_length, self.context_length))
     mask -= 10e10
     mask = jnp.triu(mask, 1)  # keep the large negatives strictly above the diagonal; zero the diagonal and below
     return mask
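For reference, a standalone sketch (values chosen for illustration, not from the original class) of what this mask evaluates to for a context length of 4:

import jax.numpy as jnp

mask = jnp.triu(jnp.zeros((4, 4)) - 10e10, 1)
# [[ 0. -1e11 -1e11 -1e11]
#  [ 0.    0. -1e11 -1e11]
#  [ 0.    0.    0. -1e11]
#  [ 0.    0.    0.    0.]]
# Added to attention logits, this blocks attention to future positions while
# leaving the current and past positions untouched.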
Example #7
def _band_part(input, num_lower, num_upper, name=None):  # pylint: disable=redefined-builtin
    del name
    result = input
    if num_lower > -1:
        result = np.triu(result, -num_lower)
    if num_upper > -1:
        result = np.tril(result, num_upper)
    return result
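A hypothetical usage sketch, assuming `_band_part` above is in scope and `np` is NumPy; keeping one subdiagonal and no superdiagonals mirrors tf.linalg.band_part(x, 1, 0):

import numpy as np

x = np.arange(16).reshape(4, 4)
banded = _band_part(x, num_lower=1, num_upper=0)
# Only the main diagonal and the first subdiagonal survive:
# [[ 0  0  0  0]
#  [ 4  5  0  0]
#  [ 0  9 10  0]
#  [ 0  0 14 15]]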
Example #9
    def __call__(self, inputs: Array) -> Array:
        """
        Applies a masked linear transformation to the inputs.

        Args:
          inputs: input data with dimensions (batch, length, features).

        Returns:
          The transformed data.
        """
        if inputs.ndim == 2:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)
        else:
            is_single_input = False

        batch, size, in_features = inputs.shape
        inputs = inputs.reshape((batch, size * in_features))

        if self.use_bias:
            bias = self.param(
                "bias", self.bias_init, (size, self.features), self.param_dtype
            )
        else:
            bias = None

        mask = jnp.ones((size, size), dtype=self.param_dtype)
        mask = jnp.triu(mask, self.exclusive)
        mask = jnp.kron(
            mask, jnp.ones((in_features, self.features), dtype=self.param_dtype)
        )

        kernel = self.param(
            "kernel",
            wrap_kernel_init(self.kernel_init, mask),
            (size * in_features, size * self.features),
            self.param_dtype,
        )

        inputs, mask, kernel, bias = promote_dtype(
            inputs, mask, kernel, bias, dtype=None
        )

        y = lax.dot(inputs, mask * kernel, precision=self.precision)

        y = y.reshape((batch, size, self.features))

        if is_single_input:
            y = y.squeeze(axis=0)

        if self.use_bias:
            y = y + bias

        return y
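The autoregressive structure comes from the kron-expanded triangular mask. A standalone sketch with illustrative values (size=3, in_features=1, features=1, exclusive=1; none of these are taken from the snippet):

import jax.numpy as jnp

size, in_features, features, exclusive = 3, 1, 1, 1
mask = jnp.triu(jnp.ones((size, size)), exclusive)
mask = jnp.kron(mask, jnp.ones((in_features, features)))
# Rows index input sites, columns index output sites:
# [[0. 1. 1.]
#  [0. 0. 1.]
#  [0. 0. 0.]]
# Output site i only receives input sites j < i, so the layer stays
# autoregressive.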
Example #10
File: qr.py  Project: ithanlevin/dfact
def factored_to_QR(h, beta):
    """
    Computes dense matrices Q and R from the factored QR representation
    [h, beta] as computed by qr with mode == "factored".
    """
    m, n = h.shape
    R = jnp.triu(h)
    Q = jnp.eye(m, dtype=h.dtype)
    for j in range(n - 1, -1, -1):
        v = jnp.concatenate((jnp.array([1.]), h[j + 1:, j]))
        # .at[].set() replaces the deprecated jax.ops.index_update / index
        Q = Q.at[j:, j:].set(house_leftmult(Q[j:, j:], v, beta[j]))
    out = [Q, R]
    return out
Example #11
    def testSvdWithOnRankDeficientInput(self, m, r, log_cond):
        """Tests SVD with rank-deficient input."""
        with jax.default_matmul_precision('float32'):
            a = jnp.triu(jnp.ones((m, m))).astype(_SVD_TEST_DTYPE)

            # Generates a rank-deficient input.
            u, s, v = jnp.linalg.svd(a, full_matrices=False)
            cond = 10**log_cond
            s = jnp.linspace(cond, 1, m)
            s = s.at[r:m].set(jnp.zeros((m - r, )))
            a = (u * s) @ v

            with jax.default_matmul_precision('float32'):
                u, s, v = svd.svd(a, full_matrices=False, hermitian=False)
            diff = np.linalg.norm(a - (u * s) @ v)

            np.testing.assert_almost_equal(diff, 1E-4, decimal=2)
Example #12
    def _transform_to_covariance_matrix(self, sq_mat):
        '''
        Takes the upper triangular part U of the given square matrix and
        returns U^T @ U, a symmetric positive semi-definite covariance matrix.
        https://ericmjl.github.io/notes/stats-ml/estimating-a-multivariate-gaussians-parameters-by-gradient-descent/

        Parameters
        ----------
        sq_mat : array
            Square matrix

        Returns
        -------
        array
            Covariance matrix
        '''
        U = jnp.triu(sq_mat)
        U_T = jnp.transpose(U)
        return jnp.dot(U_T, U)
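A standalone check of the property this parameterization relies on (illustrative values, not from the source): U^T U is symmetric positive semi-definite for any upper triangular U.

import jax.numpy as jnp

U = jnp.triu(jnp.arange(1.0, 10.0).reshape(3, 3))
cov = jnp.dot(U.T, U)
assert jnp.allclose(cov, cov.T)                    # symmetric
assert jnp.all(jnp.linalg.eigvalsh(cov) >= -1e-5)  # PSD up to round-off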
Example #13
    def testQdwhUnconvergedAfterMaxNumberIterations(self, m, n, log_cond):
        """Tests unconvergence after maximum number of iterations."""
        a = jnp.triu(jnp.ones((m, n)))
        u, s, v = jnp.linalg.svd(a, full_matrices=False)
        cond = 10**log_cond
        s = jnp.linspace(cond, 1, min(m, n))
        a = (u * s) @ v
        is_symmetric = _check_symmetry(a)
        max_iterations = 2

        _, _, actual_num_iterations, is_converged = qdwh.qdwh(
            a, is_symmetric, max_iterations)

        with self.subTest('Number of iterations.'):
            self.assertEqual(max_iterations, actual_num_iterations)

        with self.subTest('Converged.'):
            self.assertFalse(is_converged)
Example #14
def fill_triangular(x, upper=False):
    m = x.shape[-1]
    if len(x.shape) != 1:
        raise ValueError("Only 1D inputs are handled, because np.tril/np.triu "
                         "here produce a single 2D triangular matrix.")
    m = np.int32(m)
    n = np.sqrt(0.25 + 2. * m) - 0.5
    if n != np.floor(n):
        raise ValueError('Input right-most shape ({}) does not '
                         'correspond to a triangular matrix.'.format(m))
    n = np.int32(n)
    final_shape = list(x.shape[:-1]) + [n, n]
    if upper:
        x_list = [x, np.flip(x[..., n:], -1)]

    else:
        x_list = [x[..., n:], np.flip(x, -1)]
    x = np.reshape(np.concatenate(x_list, axis=-1), final_shape)
    if upper:
        x = np.triu(x)
    else:
        x = np.tril(x)
    return x
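A hypothetical usage sketch, assuming `fill_triangular` above is in scope and `np` is NumPy: a length-6 vector packs into a 3 x 3 lower triangle (6 = n*(n+1)/2 with n = 3), in the same packing order as tfp.math.fill_triangular.

import numpy as np

x = np.arange(1, 7)  # [1 2 3 4 5 6]
fill_triangular(x)
# [[4 0 0]
#  [6 5 0]
#  [3 2 1]]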
Example #15
    def __call__(self, x, pos_emb, mask):
        dim_in, h = x.shape[-1], self.heads
        scale = dim_in**-0.5

        norm = nn.LayerNorm()
        to_qkv = nn.Dense(features=self.dim_head * h * 3, use_bias=False)
        to_out = nn.Dense(features=dim_in)

        x = norm(x)
        qkv = np.split(to_qkv(x), 3, axis=-1)
        q, k, v = map(lambda t: rearrange(t, "i (h d) -> i h d", h=h), qkv)

        # .at[].set() replaces the deprecated jax.ops.index_update / index
        q = q.at[1:].set(apply_rotary_pos_emb(q[1:], pos_emb))
        k = k.at[1:].set(apply_rotary_pos_emb(k[1:], pos_emb))

        sim = einsum("i h d, j h d -> i j h", q, k) * scale

        mask = np.pad(mask, (1, 0), constant_values=True)
        mask = rearrange(mask, "j -> () j ()")

        if self.causal:
            i, j = sim.shape[:2]
            tri_mask = np.ones((i - 1, j - 1), dtype=bool)
            tri_mask = np.pad(tri_mask, ((1, 0), (1, 0)),
                              constant_values=False)
            causal_mask = np.triu(tri_mask, j - i + 1)
            causal_mask = rearrange(causal_mask, "i j -> i j ()")
            mask = ~causal_mask * mask

        sim = np.where(mask, sim, LARGE_NEG_VALUE)

        attn = nn.softmax(sim, axis=-2)

        out = einsum("i j h, j h d -> i h d", attn, v)

        out = rearrange(out, "i h d -> i (h d)")
        return to_out(out)
Example #16
 def setup(self):
     """Initialize P, L, U, s"""
     # W = P @ L @ (U + diag(s))
     # Based on https://github.com/openai/glow/blob/master/model.py#L485
     dim = self.input_dim
     # Sample random rotation matrix
     q, _ = np.linalg.qr(jax.random.normal(self.rng, (dim, dim)),
                         mode="complete")
     p, l, u = jax.scipy.linalg.lu(q)
     # Fixed Permutation (non-trainable)
     self.P = p
     self.P_inv = jax.scipy.linalg.inv(p)
     # Init value from LU decomposition
     L_init = l
     U_init = np.triu(u, k=1)
     s = np.diag(u)
     self.sign_s = np.sign(s)
     S_log_init = np.log(np.abs(s))
     self.l_mask = np.tril(np.ones((dim, dim)), k=-1)
     self.u_mask = np.transpose(self.l_mask)
     # Define trainable variables
     self.L = self.param("L", lambda k, sh: L_init, (dim, dim))
     self.U = self.param("U", lambda k, sh: U_init, (dim, dim))
     self.log_s = self.param("log_s", lambda k, sh: S_log_init, (dim, ))
Example #17
File: qr.py  Project: ithanlevin/dfact
def house_qr(A, mode="reduced"):
    """
    Performs a QR decomposition of the m x n real or complex matrix A
    using the Householder algorithm.

    The string parameter 'mode' determines the representation of the output.
    In this way, one can retrieve various implicit representations of the
    factored matrices. This can be a significant optimization in the case
    of a highly rectangular A, which is the reason for this function's
    existence.

    Parameters
    ----------
    A : array_like, shape (M, N)
        Matrix to be factored.

    mode : {'reduced', 'complete', 'r', 'factored', 'WY'}, optional
        If K = min(M, N), then:
          - 'reduced': returns Q, R with dimensions (M, K), (K, N)
            (default)
          - 'complete': returns Q, R with dimensions (M, M), (M, N)
          - 'r': returns r only with dimensions (K, N)
          - 'factored': returns H, beta with dimensions (N, M), (K); read
            below for details.
          - 'WY': returns W, Y with dimensions (M, K); read below for
            details.

    With 'reduced', 'complete', or 'r', this function simply passes to
    jnp.linalg.qr, which depending on the current status of Jax may raise
    NotImplementedError if A is complex.

    With 'factored' this function returns the same H, beta as generated by
    the Lapack function dgeqrf() (but in row-major form). Thus,
    H contains the upper triangular matrix R in its upper triangle, and
    the j'th Householder reflector forming Q in the j'th column of its
    lower triangle. beta[j] contains the normalization factor of the j'th
    reflector, called 'beta' in the function 'house' in this module.

    The matrix Q is then represented implicitly as
        Q = H(0) H(1) ... H(K-1), H(i) = I - beta[i] v dag(v)
    with v[:i] = 0; v[i] = 1; v[i+1:] = A[i+1:, i].

    Application of Q (C -> dag{Q} C) can be made directly from this implicit
    representation using the function factored_multiply(C). When
    K << max(M, N), both the QR factorization and multiplication by Q
    using factored_multiply theoretically require far fewer operations than
    would an explicit representation of Q. However, these applications
    are mostly Level-2 BLAS operations.

    With 'WY' this function returns (M, K) matrices W and Y such that
        Q = I - W dag(Y).
    Y is the lower-triangular matrix of Householder vectors, i.e. the lower
    triangle of the matrix H resulting from mode='factored'. W is then
    computed so that the above identity holds.

    Application of Q can be made directly from the WY representation
    using the function WY_multiply(C). The WY representation is
    a bit more expensive to compute than the factored one, though still less
    expensive than the full Q. Its advantage versus 'factored' is that
    WY_multiply calls depend mostly on Level-3 BLAS operations.


    Returns
    -------
    Q: ndarray of float or complex, optional
        The column-orthonormal orthogonal/unitary matrix Q.

    R: ndarray of float or complex, optional.
        The upper-triangular matrix.

    [H, beta]: list of ndarrays of float or complex, optional.
        The matrix H and scaling factors beta generating Q along with R in the
        'factored' representation.

    [W, Y, R] : list of ndarrays of float or complex, optional.
        The matrices W and Y generating Q along with R in the 'WY'
        representation.

    Raises
    ------
    LinAlgError
        If factoring fails.

    NotImplementedError
        In reduced, complete, or r mode with complex input.
        In factored or WY mode in the case M < N.
    """
    if mode == "reduced" or mode == "complete" or mode == "r":
        return jnp.linalg.qr(A, mode=mode)
    else:
        m, n = A.shape
        if n > m:
            raise NotImplementedError("n > m QR not implemented in factored "
                                      "or WY mode.")
        if mode == "factored":
            return __house_qr_factored(A)
        elif mode == "WY":
            hbetalist = __house_qr_factored(A)
            R = jnp.triu(hbetalist[0])
            WYlist = factored_to_WY(hbetalist)
            output = WYlist + [R]
            return output
        else:
            raise ValueError("Invalid mode: ", mode)
Example #18
  def step(self, s, a):
    """Apply control, damping, boundary, and collision forces.

    Args:
      s: (p, v, misc), where p and v are [n_entities,2] jnp.float32,
         and misc is child defined
      a: [n_agents, dim_a] jnp.float32

    Returns:
      A state tuple (p, v, misc)
    """
    p, v, misc = s  # [n,2], [n,2], [a_shape]
    f = jnp.zeros_like(p)  # [n,2]
    n = p.shape[0]  # number of entities

    # Calculate control forces
    f_control = jnp.pad(a, ((0, n-a.shape[0]), (0, 0)),
                        mode="constant")  # [n, dim_a]
    f += f_control

    # Calculate damping forces
    f_damping = -1.0*self.damping*v  # [n,2]
    f = f + f_damping

    # Calculate boundary forces
    bounce = (((p+self.radius >= self.max_p) & (v >= 0.0)) |
              ((p-self.radius <= self.min_p) & (v <= 0.0)))  # [n,2]
    v_new = (-1.0*bounce + 1.0*~bounce)*v  # [n,2]
    f_boundary = self.mass*(v_new - v)/self.dt  # [n,2]
    f = f + f_boundary

    # Calculate shared quantities for later calculations
    # same: [n,n,1], True if i==j
    same = jnp.expand_dims(jnp.eye(n, dtype=jnp.bool_), axis=-1)
    # p2p: [n,n,2], p2p[i,j,:] is the vector from entity i to entity j
    p2p = p - jnp.expand_dims(p, axis=1)
    # dist: [n,n,1], dist[i,j,0] is the distance between i and j
    dist = jnp.linalg.norm(p2p, axis=-1, keepdims=True)
    # overlap: [n,n,1], overlap[i,j,0] is the overlap between i and j
    overlap = ((jnp.expand_dims(self.radius, axis=1) +
                jnp.expand_dims(self.radius, axis=0)) -
               dist)
    if self.same_position_check:
      # ontop: [n,n,1], ontop[i,j,0] = True if i is at the exact location of j
      ontop = (dist == 0.0)
      # ontop_dir: [n,n,2], (1,0) on and above the diagonal, (-1,0) below
      ontop_dir = jnp.stack([jnp.triu(jnp.ones((n, n)))*2-1, jnp.zeros((n, n))],
                            axis=-1)
      # contact_dir: [n,n,2], contact_dir[i,j,:] is the unit vector in the
      # direction of j from i
      contact_dir = (~ontop*p2p + (ontop*ontop_dir))/(~ontop*dist + ontop*1.0)
    else:
      # contact_dir: [n,n,2], contact_dir[i,j,:] is the unit vector in the
      # direction of j from i
      contact_dir = p2p/(dist+same)
    # collideable: [n,n,1], True if i and j are collideable
    collideable = (jnp.expand_dims(self.collideable, axis=1) &
                   jnp.expand_dims(self.collideable, axis=0))
    # overlapping: [n,n,1], True if i,j overlap
    overlapping = overlap > 0

    # Calculate collision forces
    # Assume all entities collide with all entities, then mask out
    # non-collisions.
    #
    # For approaching, colliding entities, apply a force
    # along the direction of collision that results in
    # relative velocities consistent with the coefficient of
    # restitution (c) and preservation of momentum in that
    # direction.
    # momentum: m_a*v_a + m_b*v_b = m_a*v'_a + m_b*v'_b
    # restitution: v'_b - v'_a = -c*(v_b-v_a)
    # solve for v'_a:
    #  v'_a = [m_a*v_a + m_b*v_b + m_b*c*(v_b-v_a)]/(m_a + m_b)
    #
    # v_contact_dir: [n,n] speed of i in dir of j
    v_contact_dir = jnp.sum(jnp.expand_dims(v, axis=-2)*contact_dir, axis=-1)
    # v_approach: [n,n] speed that i,j are approaching each other
    v_approach = jnp.transpose(v_contact_dir) + v_contact_dir
    # momentum: [n,n] joint momentum in direction of contact (i->j)
    momentum = self.mass*v_contact_dir - jnp.transpose(self.mass*v_contact_dir)
    # v_result: [n,n] speed of i in dir of j after collision
    v_result = ((momentum +
                 self.restitution*jnp.transpose(self.mass)*(-v_approach)) /
                (self.mass + jnp.transpose(self.mass)))
    # f_collision: [n,n] force on i in dir of j to realize acceleration
    f_collision = self.mass*(v_result - v_contact_dir)/self.dt
    # f_collision: [n,n,2] force on i to realize acceleration due to
    # collision with j
    f_collision = jnp.expand_dims(f_collision, axis=-1)*contact_dir
    # collision_mask: [n,n,1]
    collision_mask = (collideable & overlapping & ~same &
                      (jnp.expand_dims(v_approach, axis=-1) > 0))
    # f_collision: [n,2], sum of collision forces on i
    f_collision = jnp.sum(f_collision*collision_mask, axis=-2)
    f = f + f_collision

    # Calculate overlapping spring forces
    # This corrects for any overlap due to discrete steps.
    # f_overlap: [n,n,2], force in the negative contact dir due to overlap
    f_overlap = -1.0*contact_dir*overlap*self.overlap_spring_constant
    # overlapping_mask: [n,n,1], True if i,j are collideable, overlap,
    # and i != j
    overlapping_mask = collideable & overlapping & ~same
    # f_overlap: [n,2], sum of spring forces on i
    f_overlap = jnp.sum(f_overlap*overlapping_mask, axis=-2)
    f = f + f_overlap

    # apply forces
    v = v + (f/self.mass)*self.dt
    p = p + v*self.dt

    # update misc
    misc = self._update_misc((p, v, misc), a)  # pylint: disable=assignment-from-none

    return (p, v, misc)
Example #19
 def mass_matrix_inv_mul(self, q: jnp.ndarray, v: jnp.ndarray,
                         **kwargs) -> jnp.ndarray:
     """Computes the product of the inverse mass matrix with a vector."""
     if self.kinetic_func_form in ("separable_net", "dep_net"):
         raise ValueError(
             "It is not possible to compute `M^-1 p` when using a "
             "network for the kinetic energy.")
     if self.kinetic_func_form in ("pure_quad", "embed_quad"):
         return v
     if self.kinetic_func_form == "matrix_diag_quad":
         if self.parametrize_mass_matrix:
             m_diag_log = hk.get_parameter(
                 "MassMatrixDiagLog",
                 shape=[self.system_dim],
                 init=hk.initializers.Constant(0.0))
             m_inv_diag = 1.0 / (jnp.exp(m_diag_log) + self.mass_eps)
         else:
             m_inv_diag_log = hk.get_parameter(
                 "InvMassMatrixDiagLog",
                 shape=[self.system_dim],
                 init=hk.initializers.Constant(0.0))
             m_inv_diag = jnp.exp(m_inv_diag_log) + self.mass_eps
         return m_inv_diag * v
     if self.kinetic_func_form == "matrix_quad":
         if self.parametrize_mass_matrix:
             m_triu = hk.get_parameter(
                 "MassMatrixU",
                 shape=[self.system_dim, self.system_dim],
                 init=hk.initializers.Identity())
             m_triu = jnp.triu(m_triu)
             m = jnp.matmul(m_triu.T, m_triu)
             m = m + self.mass_eps * jnp.eye(self.system_dim)
             solve = jnp.linalg.solve
             for _ in range(v.ndim + 1 - m.ndim):
                 solve = jax.vmap(solve, in_axes=(None, 0))
             return solve(m, v)
         else:
             m_inv_triu = hk.get_parameter(
                 "InvMassMatrixU",
                 shape=[self.system_dim, self.system_dim],
                 init=hk.initializers.Identity())
             m_inv_triu = jnp.triu(m_inv_triu)
             m_inv = jnp.matmul(m_inv_triu.T, m_inv_triu)
             m_inv = m_inv + self.mass_eps * jnp.eye(self.system_dim)
             return self.feature_matrix_vector(m_inv, v)
     if self.kinetic_func_form in ("matrix_dep_diag_quad",
                                   "matrix_dep_diag_embed_quad"):
         if self.parametrize_mass_matrix:
             m_diag_log = self.mass_matrix_net(q, **kwargs)
             m_inv_diag = 1.0 / (jnp.exp(m_diag_log) + self.mass_eps)
         else:
             m_inv_diag_log = self.mass_matrix_net(q, **kwargs)
             m_inv_diag = jnp.exp(m_inv_diag_log) + self.mass_eps
         return m_inv_diag * v
     if self.kinetic_func_form in ("matrix_dep_quad",
                                   "matrix_dep_embed_quad"):
         if self.parametrize_mass_matrix:
             m_triu = self.mass_matrix_net(q, **kwargs)
             m_triu = utils.triu_matrix_from_v(m_triu, self.system_dim)
             m = jnp.matmul(jnp.swapaxes(m_triu, -2, -1), m_triu)
             m = m + self.mass_eps * jnp.eye(self.system_dim)
             return jnp.linalg.solve(m, v)
         else:
             m_inv_triu = self.mass_matrix_net(q, **kwargs)
             m_inv_triu = utils.triu_matrix_from_v(m_inv_triu,
                                                   self.system_dim)
             m_inv = jnp.matmul(jnp.swapaxes(m_inv_triu, -2, -1),
                                m_inv_triu)
             m_inv = m_inv + self.mass_eps * jnp.eye(self.system_dim)
             return self.feature_matrix_vector(m_inv, v)
     raise NotImplementedError()
Example #20
def triu(a, k=0):
  if isinstance(a, JaxArray): a = a.value
  return JaxArray(jnp.triu(a, k))
Example #21
#         errs_a.append(np.stack(erra).mean())
#         errs_std_a.append(np.stack(erra).std())
#         errs_b.append(np.stack(errb).mean())
#         errs_std_b.append(np.stack(errb).std())

# Compute generalization error, rapid but heuristic
errs_a = []
errs_b = []
for ii, a in enumerate(range(K)):
    Xa = np.array(manifolds[a])
    for b in range(a + 1, K):
        Xb = np.array(manifolds[b])
        erra = []
        errb = []

        key, _ = random.split(key)
        erra, errb = mshot_err_fast(key, Xa, Xb)

        errs_a.append(erra)
        errs_b.append(errb)

    print('Manifold {} of {}. Avg. acc: {}'.format(ii, K,
                                                   1 - errs_a[-1].mean()))

# Combine errs_a and errs_b into K x K matrix
errs_full = np.triu(squareform(errs_a)) + np.tril(squareform(errs_b))

# Save
np.save(save_path, errs_full)
print('Finished with acc. ' + str(1 - np.mean(errs_full)) + '. Saved.')
Example #22
def compute_OBC_energy_vectorized(
    distance_matrix,
    radii,
    scales,
    charges,
    offset=0.009,
    screening=138.935484,
    surface_tension=28.3919551,
    solvent_dielectric=78.5,
    solute_dielectric=1.0,
):
    """Compute GBSA-OBC energy from a distance matrix"""
    N = len(radii)
    #print(type(distance_matrix))
    eye = np.eye(N, dtype=distance_matrix.dtype)
    #print(type(eye))
    r = distance_matrix + eye  # so I don't have divide-by-zero nonsense
    or1 = radii.reshape((N, 1)) - offset
    or2 = radii.reshape((1, N)) - offset
    sr2 = scales.reshape((1, N)) * or2

    L = np.maximum(or1, abs(r - sr2))
    U = r + sr2
    I = step(r + sr2 - or1) * 0.5 * (1 / L - 1 / U + 0.25 * (r - sr2**2 / r) *
                                     (1 / (U**2) - 1 /
                                      (L**2)) + 0.5 * np.log(L / U) / r)

    I -= np.diag(np.diag(I))
    I = np.sum(I, axis=1)

    # okay, next compute born radii
    offset_radius = radii - offset
    psi = I * offset_radius
    psi_coefficient = 0.8
    psi2_coefficient = 0
    psi3_coefficient = 2.909125

    psi_term = (psi_coefficient * psi) + (psi2_coefficient *
                                          psi**2) + (psi3_coefficient * psi**3)

    B = 1 / (1 / offset_radius - np.tanh(psi_term) / radii)

    # finally, compute the three energy terms
    E = 0.0

    # single particle
    E += np.sum(surface_tension * (radii + 0.14)**2 * (radii / B)**6)
    E += np.sum(-0.5 * screening *
                (1 / solute_dielectric - 1 / solvent_dielectric) * charges**2 /
                B)

    # particle pair
    f = np.sqrt(r**2 + np.outer(B, B) * np.exp(-r**2 / (4 * np.outer(B, B))))
    charge_products = np.outer(charges, charges)

    E += np.sum(
        np.triu(-screening * (1 / solute_dielectric - 1 / solvent_dielectric) *
                charge_products / f,
                k=1))

    return E
Example #23
def gbsa_obc(
        coords,
        # params,
        lamb,
        # box,
        charge_params,
        gb_params,
        # charge_idxs,
        # radii_idxs,
        # scale_idxs,
        alpha,
        beta,
        gamma,
        cutoff_radii,
        cutoff_force,
        lambda_plane_idxs,
        lambda_offset_idxs,
        dielectric_offset=0.009,
        surface_tension=28.3919551,
        solute_dielectric=1.0,
        solvent_dielectric=78.5,
        probe_radius=0.14):

    box = None

    assert cutoff_radii == cutoff_force

    coords_4d = convert_to_4d(coords, lamb, lambda_plane_idxs,
                              lambda_offset_idxs, cutoff_radii)

    N = len(charge_params)

    radii = gb_params[:, 0]
    scales = gb_params[:, 1]

    ri = np.expand_dims(coords_4d, 0)
    rj = np.expand_dims(coords_4d, 1)

    dij = distance(ri, rj, box)

    eye = np.eye(N, dtype=dij.dtype)

    r = dij + eye  # so I don't have divide-by-zero nonsense
    or1 = radii.reshape((N, 1)) - dielectric_offset
    or2 = radii.reshape((1, N)) - dielectric_offset
    sr2 = scales.reshape((1, N)) * or2

    L = np.maximum(or1, abs(r - sr2))
    U = r + sr2

    I = 1 / L - 1 / U + 0.25 * (r - sr2**2 / r) * (1 / (U**2) - 1 /
                                                   (L**2)) + 0.5 * np.log(
                                                       L / U) / r
    # handle the interior case
    I = np.where(or1 < (sr2 - r), I + 2 * (1 / or1 - 1 / L), I)
    I = step(r + sr2 - or1) * 0.5 * I  # note the extra 0.5 here
    I -= np.diag(np.diag(I))

    # switch I only for now
    # inner = (np.pi*np.power(dij,8))/(2*cutoff_radii)
    # sw = np.power(np.cos(inner), 2)
    # I = I*sw

    I = np.where(dij > cutoff_radii, 0, I)
    I = np.sum(I, axis=1)

    # okay, next compute born radii
    offset_radius = radii - dielectric_offset

    psi = I * offset_radius

    psi_coefficient = alpha
    psi2_coefficient = beta
    psi3_coefficient = gamma

    psi_term = (psi_coefficient * psi) - (psi2_coefficient *
                                          psi**2) + (psi3_coefficient * psi**3)

    B = 1 / (1 / offset_radius - np.tanh(psi_term) / radii)

    E = 0.0
    # single particle
    # ACE
    E += np.sum(surface_tension * (radii + probe_radius)**2 * (radii / B)**6)

    # on-diagonal
    charges = charge_params

    E += np.sum(-0.5 * (1 / solute_dielectric - 1 / solvent_dielectric) *
                charges**2 / B)

    # particle pair
    f = np.sqrt(r**2 + np.outer(B, B) * np.exp(-r**2 / (4 * np.outer(B, B))))
    charge_products = np.outer(charges, charges)

    ixns = -(1 / solute_dielectric -
             1 / solvent_dielectric) * charge_products / f

    # sw = np.power(np.cos((np.pi*dij)/(2*cutoff_radii)), 2)
    # ixns = ixns*sw
    ixns = np.where(dij > cutoff_force, 0, ixns)

    E += np.sum(np.triu(ixns, k=1))

    return E
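The np.sum(np.triu(ixns, k=1)) pattern above is the usual triu idiom for pairwise sums: for a symmetric matrix, the strict upper triangle counts each unordered pair exactly once. A standalone illustration with made-up values:

import jax.numpy as jnp

M = jnp.array([[2., 1., 4.],
               [1., 6., 3.],
               [4., 3., 8.]])
pair_sum = jnp.sum(jnp.triu(M, k=1))  # 1 + 4 + 3 = 8
assert pair_sum == (jnp.sum(M) - jnp.trace(M)) / 2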
Example #24
def _connection_weights(num_iterations, num_mixing_iterations):
    """Gets the connection weights."""
    mask = jnp.triu(jnp.tril(jnp.ones((num_iterations, num_iterations))),
                    k=-num_mixing_iterations + 1)
    return mask / jnp.sum(mask, axis=1, keepdims=True)
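A hypothetical evaluation, assuming `_connection_weights` above is in scope and using illustrative arguments: with num_iterations=4 and num_mixing_iterations=2, row i averages iterations i-1 and i.

import jax.numpy as jnp

w = _connection_weights(4, 2)
# [[1.  0.  0.  0. ]
#  [0.5 0.5 0.  0. ]
#  [0.  0.5 0.5 0. ]
#  [0.  0.  0.5 0.5]]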
Example #25
    def update_site(self, inputs: Array, index: int) -> Array:
        """
        Adds an input site into the cache, and applies the masked linear transformation to the cache.

        Args:
          inputs: an input site to be added into the cache with dimensions (batch, features).
          index: the index of the output site. The index of the input site should be `index - self.exclusive`.

        Returns:
          The output site with dimensions (batch, features).
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)

        inputs = jnp.asarray(inputs, dtype)

        is_single_input = False
        if inputs.ndim == 1:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        batch, in_features = inputs.shape
        size = self.size

        # Number of input sites that the output site at `index` depends on
        size_i = index + 1

        # Initialize the cache with zeros, and the RNG key is None
        # `cache.dtype` must be the same as `inputs.dtype` (no promotion)
        _cache = self.variable("cache", "inputs", zeros, None,
                               (batch, size, in_features), inputs.dtype)

        initializing = self.is_mutable_collection("params")
        if not initializing:
            # Add the input site into the cache
            # To write the cache, assign to `_cache.value`
            _cache.value = lax.cond(
                index - self.exclusive >= 0,
                lambda _: _cache.value.at[:, index - self.exclusive, :].set(
                    inputs),
                lambda _: _cache.value,
                None,
            )

        cache = _cache.value
        cache = jnp.asarray(cache, dtype)

        cache_i = cache[:, :size_i, :]
        cache_i = cache_i.reshape((batch, size_i * in_features))

        # The construction of `mask` will be optimized to a constant by JIT
        mask = jnp.ones((size, size), dtype=self.dtype)
        mask = jnp.triu(mask, self.exclusive)
        mask = jnp.kron(
            mask, jnp.ones((in_features, self.features), dtype=self.dtype))

        kernel = self.param(
            "kernel",
            wrap_kernel_init(self.kernel_init, mask),
            (size * in_features, size * self.features),
            self.dtype,
        )
        mask = jnp.asarray(mask, dtype)
        kernel = jnp.asarray(kernel, dtype)

        mask_i = mask.reshape((size, in_features, size, self.features))
        mask_i = mask_i[:size_i, :, index, :]
        mask_i = mask_i.reshape((size_i * in_features, self.features))

        kernel_i = kernel.reshape((size, in_features, size, self.features))
        kernel_i = kernel_i[:size_i, :, index, :]
        kernel_i = kernel_i.reshape((size_i * in_features, self.features))

        y_i = lax.dot(cache_i, mask_i * kernel_i, precision=self.precision)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (size, self.features),
                              self.dtype)
            bias = jnp.asarray(bias, dtype)

            bias_i = bias[index, :]

            y_i = y_i + bias_i

        assert y_i.shape[1] == self.features

        if is_single_input:
            y_i = y_i.squeeze(axis=0)

        return y_i
Example #26
 def returns(self, r):
     # r: [n_steps]
     return jnp.dot(jnp.triu(jnp.ones((self.n_steps, self.n_steps))),
                    r)  # R: [n_steps]
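Multiplying by an upper triangular matrix of ones is a compact way to compute undiscounted returns-to-go, R[t] = r[t] + r[t+1] + ... A standalone sketch with made-up rewards:

import jax.numpy as jnp

r = jnp.array([1., 2., 3.])
R = jnp.dot(jnp.triu(jnp.ones((3, 3))), r)
# R == [6. 5. 3.]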
Example #27
def gbsa(conf,
         params,
         box,
         param_idxs,
         dielectric_offset=0.009,
         cutoff=2.0,
         alpha_obc=1.0,
         beta_obc=0.8,
         gamma_obc=4.85,
         solute_dielectric=1.0,
         solvent_dielectric=78.3,
         electric_constant=-69.467728,
         probe_radius=0.14,
         surface_area_energy=2.25936):
    """
    Computes the GBSA energy with support for full OBC style parameters.

    For detailed notes on the values of the undocumented keyword args, please
    refer to the OpenMM theory manual:

    http://docs.openmm.org/latest/userguide/theory.html#gbsaobcforce

    Parameters
    ----------
    conf: shape [num_atoms, 3] np.array
        atomic coordinates

    params: shape [num_params,] np.array
        unique parameters

    box: shape [3, 3] np.array
        periodic boundary vectors, if not None
    
    param_idxs: shape [num_atoms, 3]
        a list of 3-tuple parameter indices, where the
        0th index indicates charges, the 1st indicates radii,
        and the 2nd indicates scale_factors

    """

    if box is not None:
        raise ValueError("Periodic GBSA is not supported.")

    num_atoms = conf.shape[0]

    if solute_dielectric != 0.0 and solvent_dielectric != 0.0:
        prefactor = 2.0 * electric_constant * (1.0 / solute_dielectric -
                                               1.0 / solvent_dielectric)
    else:
        prefactor = 0.0

    # (ytz): The rough sketch of the algorithm is as follows:
    # 1. Compute the adjusted GB radii
    # 2. Use the adjusted radii to compute the shielded electrostatic potential
    # 3. Compute the non-polar contribution using the GB radii

    charges = params[param_idxs[:, 0]]
    atomic_radii = params[param_idxs[:, 1]]
    scaled_factors = params[param_idxs[:, 2]]

    br = born_radii(conf, atomic_radii, scaled_factors, dielectric_offset,
                    alpha_obc, beta_obc, gamma_obc)

    r_i = np.expand_dims(conf, axis=0)
    r_j = np.expand_dims(conf, axis=1)

    q_i = np.expand_dims(charges, axis=0)
    q_j = np.expand_dims(charges, axis=1)
    q_ij = q_i * q_j

    br_i = np.expand_dims(br, axis=0)
    br_j = np.expand_dims(br, axis=1)

    r2 = np.sum(np.power(r_i - r_j, 2), axis=-1)
    alpha2_ij = br_i * br_j
    D_ij = r2 / (4.0 * alpha2_ij)
    expTerm = np.exp(-D_ij)
    denom2 = r2 + alpha2_ij * expTerm
    denom = np.sqrt(denom2)
    pq_ij = prefactor * q_ij

    Gpol = pq_ij / denom
    energy = Gpol

    pi4Asolv = 4 * np.pi * surface_area_energy

    nonpolar_nrg = non_polar_ace(br, atomic_radii, probe_radius, pi4Asolv)

    # compute pair energies using only the strict upper triangle, plus half
    # of each diagonal (self-interaction) term
    return np.sum(np.triu(energy, k=1)) + np.sum(
        np.diagonal(energy) / 2.0) + nonpolar_nrg