def testQdwhWithOnRankDeficientInput(self, m, n, log_cond):
    """Tests qdwh with rank-deficient input."""
    a = jnp.triu(jnp.ones((m, n))).astype(_QDWH_TEST_DTYPE)

    # Generates a rank-deficient input.
    u, s, v = jnp.linalg.svd(a, full_matrices=False)
    cond = 10**log_cond
    s = jnp.linspace(cond, 1, min(m, n))
    s = jnp.expand_dims(s.at[-1].set(0), range(u.ndim - 1))
    a = (u * s) @ v

    is_hermitian = _check_symmetry(a)
    max_iterations = 15

    actual_u, actual_h, _, _ = qdwh.qdwh(
        a, is_hermitian=is_hermitian, max_iterations=max_iterations)
    _, expected_h = osp_linalg.polar(a)

    # Sets the test tolerance.
    rtol = 1E4 * _QDWH_TEST_EPS

    # For rank-deficient matrix, `u` is not unique.
    with self.subTest('Test h.'):
        relative_diff_h = _compute_relative_diff(actual_h, expected_h)
        np.testing.assert_almost_equal(relative_diff_h, 1E-6, decimal=5)

    with self.subTest('Test u.dot(h).'):
        a_round_trip = _dot(actual_u, actual_h)
        relative_diff_a = _compute_relative_diff(a_round_trip, a)
        np.testing.assert_almost_equal(relative_diff_a, 1E-6, decimal=5)

    with self.subTest('Test orthogonality.'):
        actual_results = _dot(actual_u.T.conj(), actual_u)
        expected_results = np.eye(n)
        self.assertAllClose(
            actual_results, expected_results, rtol=rtol, atol=1E-6)
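The qdwh and svd tests in this listing lean on two small helpers that are not shown in the excerpt. A minimal sketch of what they plausibly compute, with names and semantics inferred from usage rather than taken from the source:

import numpy as np

def _compute_relative_diff(actual, expected):
    # Assumed helper: Frobenius-norm relative error between two matrices.
    return np.linalg.norm(actual - expected) / np.linalg.norm(expected)

def _dot(a, b):
    # Assumed helper: plain matrix product (the real tests may pin precision).
    return np.dot(a, b)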
def testQdwhWithUpperTriangularInputAllOnes(self, m, n, log_cond):
    """Tests qdwh with upper triangular input of all ones."""
    a = jnp.triu(jnp.ones((m, n))).astype(_QDWH_TEST_DTYPE)
    u, s, v = jnp.linalg.svd(a, full_matrices=False)
    cond = 10**log_cond
    s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1))
    a = (u * s) @ v

    is_hermitian = _check_symmetry(a)
    max_iterations = 10

    actual_u, actual_h, _, _ = qdwh.qdwh(
        a, is_hermitian=is_hermitian, max_iterations=max_iterations)
    expected_u, expected_h = osp_linalg.polar(a)

    # Sets the test tolerance.
    rtol = 1E6 * _QDWH_TEST_EPS

    with self.subTest('Test u.'):
        relative_diff_u = _compute_relative_diff(actual_u, expected_u)
        np.testing.assert_almost_equal(relative_diff_u, 1E-6, decimal=5)

    with self.subTest('Test h.'):
        relative_diff_h = _compute_relative_diff(actual_h, expected_h)
        np.testing.assert_almost_equal(relative_diff_h, 1E-6, decimal=5)

    with self.subTest('Test u.dot(h).'):
        a_round_trip = _dot(actual_u, actual_h)
        relative_diff_a = _compute_relative_diff(a_round_trip, a)
        np.testing.assert_almost_equal(relative_diff_a, 1E-6, decimal=5)

    with self.subTest('Test orthogonality.'):
        actual_results = _dot(actual_u.T, actual_u)
        expected_results = np.eye(n)
        self.assertAllClose(
            actual_results, expected_results, rtol=rtol, atol=1E-5)
def init_fun(rng, input_dim, **kwargs):
    W = orthogonal()(rng, (input_dim, input_dim))
    P, L, U = scipy.linalg.lu(W)
    S = np.diag(U)
    U = np.triu(U, 1)
    identity = np.eye(input_dim)

    def direct_fun(params, inputs, **kwargs):
        L, U, S = params
        L = np.tril(L, -1) + identity
        U = np.triu(U, 1)
        W = P @ L @ (U + np.diag(S))
        outputs = inputs @ W
        log_det_jacobian = np.full(inputs.shape[:1], np.log(np.abs(S)).sum())
        return outputs, log_det_jacobian

    def inverse_fun(params, inputs, **kwargs):
        L, U, S = params
        L = np.tril(L, -1) + identity
        U = np.triu(U, 1)
        W = P @ L @ (U + np.diag(S))
        outputs = inputs @ linalg.inv(W)
        log_det_jacobian = np.full(inputs.shape[:1], -np.log(np.abs(S)).sum())
        return outputs, log_det_jacobian

    return (L, U, S), direct_fun, inverse_fun
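This is the LU-parameterized invertible linear layer used in Glow-style normalizing flows: W is stored as P L (U + diag(S)) so that log|det W| is simply sum(log|S|). A minimal round-trip sketch, assuming `orthogonal` is `jax.nn.initializers.orthogonal` and `np`/`linalg` are the JAX versions, as in the surrounding project:

import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
params, direct_fun, inverse_fun = init_fun(key, input_dim=4)

x = random.normal(key, (8, 4))
y, logdet = direct_fun(params, x)
x_back, inv_logdet = inverse_fun(params, y)

# The inverse undoes the forward map, and the log-det terms cancel.
assert jnp.allclose(x, x_back, atol=1e-4)
assert jnp.allclose(logdet, -inv_logdet)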
def crossover(parent_1, parent_2, offspring_size):
    all_offspring = []
    for o in range(offspring_size):
        lower_1 = np.tril(parent_1)
        upper_2 = np.triu(parent_2)
        # Note: np.tril and np.triu both keep the main diagonal, so the
        # offspring's diagonal is the sum of both parents' diagonals.
        offspring = lower_1 + upper_2
        all_offspring.append(offspring)
    return all_offspring
def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg):
    operand, = args
    lu, pivots, perm = result_tf
    batch_dims = operand.shape[:-2]
    m, n = operand.shape[-2], operand.shape[-1]

    def _make_permutation_matrix(perm):
        result = []
        for idx in itertools.product(*map(range, operand.shape[:-1])):
            result += [0 if c != perm[idx] else 1 for c in range(m)]
        result = np.reshape(np.array(result, dtype=dtype), [*batch_dims, m, m])
        return result

    k = min(m, n)
    l = jnp.tril(lu, -1)[..., :, :k] + jnp.eye(m, k, dtype=dtype)
    u = jnp.triu(lu)[..., :k, :]
    p_mat = _make_permutation_matrix(perm)

    tst.assertArraysEqual(
        lax.linalg.lu_pivots_to_permutation(pivots, m), perm)
    tst.assertAllClose(jnp.matmul(p_mat, operand), jnp.matmul(l, u),
                       atol=tol, rtol=tol, err_msg=err_msg)
def build_attention_mask(self):
    # lazily create causal attention mask, with full attention between the vision tokens
    # pytorch uses additive attention mask; fill with -inf (a large negative number here)
    mask = jnp.zeros((self.context_length, self.context_length))
    mask -= 10e10
    mask = jnp.triu(mask, 1)  # zero out the diagonal and lower triangle
    return mask
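For intuition, here is the additive mask this produces for a context length of 4 (value assumed for the example): each position may attend to itself and to earlier positions, while strictly-future positions carry a large negative bias that softmax drives to ~0.

import jax.numpy as jnp

mask = jnp.triu(jnp.zeros((4, 4)) - 10e10, 1)
# Row i is the additive bias for query position i:
# [[ 0, -1e11, -1e11, -1e11],
#  [ 0,     0, -1e11, -1e11],
#  [ 0,     0,     0, -1e11],
#  [ 0,     0,     0,     0]]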
def _band_part(input, num_lower, num_upper, name=None):  # pylint: disable=redefined-builtin
    del name
    result = input
    if num_lower > -1:
        result = np.triu(result, -num_lower)
    if num_upper > -1:
        result = np.tril(result, num_upper)
    return result
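This mirrors the semantics of tf.linalg.band_part: keep `num_lower` subdiagonals and `num_upper` superdiagonals, with -1 meaning "keep that whole triangle". A quick illustration (input values assumed):

import numpy as np

a = np.arange(16).reshape(4, 4)
_band_part(a, 1, 0)   # keeps the main diagonal and one subdiagonal
_band_part(a, -1, 0)  # keeps the full lower triangle (num_lower = -1)
_band_part(a, 0, 0)   # keeps only the main diagonal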
def __call__(self, inputs: Array) -> Array:
    """
    Applies a masked linear transformation to the inputs.

    Args:
      inputs: input data with dimensions (batch, length, features).

    Returns:
      The transformed data.
    """
    if inputs.ndim == 2:
        is_single_input = True
        inputs = jnp.expand_dims(inputs, axis=0)
    else:
        is_single_input = False

    batch, size, in_features = inputs.shape
    inputs = inputs.reshape((batch, size * in_features))

    if self.use_bias:
        bias = self.param(
            "bias", self.bias_init, (size, self.features), self.param_dtype
        )
    else:
        bias = None

    mask = jnp.ones((size, size), dtype=self.param_dtype)
    mask = jnp.triu(mask, self.exclusive)
    mask = jnp.kron(
        mask, jnp.ones((in_features, self.features), dtype=self.param_dtype)
    )

    kernel = self.param(
        "kernel",
        wrap_kernel_init(self.kernel_init, mask),
        (size * in_features, size * self.features),
        self.param_dtype,
    )

    inputs, mask, kernel, bias = promote_dtype(
        inputs, mask, kernel, bias, dtype=None
    )

    y = lax.dot(inputs, mask * kernel, precision=self.precision)
    y = y.reshape((batch, size, self.features))

    if is_single_input:
        y = y.squeeze(axis=0)

    if self.use_bias:
        y = y + bias

    return y
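The autoregressive structure here comes entirely from the triangular mask before the Kronecker expansion. A small sketch of the connectivity it encodes, treating `exclusive` as the integer offset 0 or 1 (sizes assumed for the example): entry [j, i] of the site-level mask is 1 exactly when output site i is allowed to see input site j.

import jax.numpy as jnp

size = 3
mask_incl = jnp.triu(jnp.ones((size, size)), 0)  # exclusive = 0
mask_excl = jnp.triu(jnp.ones((size, size)), 1)  # exclusive = 1
# mask_incl[j, i] == 1 iff j <= i:     output i sees inputs 0..i (itself included)
# mask_excl[j, i] == 1 iff j <= i - 1: output i sees only strictly earlier inputs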
def factored_to_QR(h, beta):
    """
    Computes dense matrices Q and R from the factored QR representation
    [h, tau] as computed by qr with mode == "factored".
    """
    m, n = h.shape
    R = jnp.triu(h)
    Q = jnp.eye(m, dtype=h.dtype)
    for j in range(n - 1, -1, -1):
        v = jnp.concatenate((jnp.array([1.]), h[j + 1:, j]))
        Q = index_update(Q, index[j:, j:],
                         house_leftmult(Q[j:, j:], v, beta[j]))
    out = [Q, R]
    return out
def testSvdWithOnRankDeficientInput(self, m, r, log_cond):
    """Tests SVD with rank-deficient input."""
    with jax.default_matmul_precision('float32'):
        a = jnp.triu(jnp.ones((m, m))).astype(_SVD_TEST_DTYPE)

        # Generates a rank-deficient input.
        u, s, v = jnp.linalg.svd(a, full_matrices=False)
        cond = 10**log_cond
        s = jnp.linspace(cond, 1, m)
        s = s.at[r:m].set(jnp.zeros((m - r,)))
        a = (u * s) @ v

    with jax.default_matmul_precision('float32'):
        u, s, v = svd.svd(a, full_matrices=False, hermitian=False)
        diff = np.linalg.norm(a - (u * s) @ v)
        np.testing.assert_almost_equal(diff, 1E-4, decimal=2)
def _transform_to_covariance_matrix(self, sq_mat):
    '''
    Takes the upper triangular part of the given matrix and multiplies it
    by its transpose, yielding a symmetric positive semi-definite matrix.
    https://ericmjl.github.io/notes/stats-ml/estimating-a-multivariate-gaussians-parameters-by-gradient-descent/

    Parameters
    ----------
    sq_mat : array
        Square matrix

    Returns
    -------
    * array
    '''
    U = jnp.triu(sq_mat)
    U_T = jnp.transpose(U)
    return jnp.dot(U_T, U)
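A quick way to see why this is a convenient parameterization: for any square matrix A, the product U^T U built from its upper triangle is symmetric and positive semi-definite, so gradient descent can run on the unconstrained entries of A while the covariance stays valid. A minimal check (values assumed):

import jax.numpy as jnp

A = jnp.array([[1., 2.], [3., 4.]])
U = jnp.triu(A)
cov = U.T @ U
assert jnp.allclose(cov, cov.T)                      # symmetric
assert bool(jnp.all(jnp.linalg.eigvalsh(cov) >= 0))  # non-negative eigenvalues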
def testQdwhUnconvergedAfterMaxNumberIterations(self, m, n, log_cond):
    """Tests unconvergence after maximum number of iterations."""
    a = jnp.triu(jnp.ones((m, n)))
    u, s, v = jnp.linalg.svd(a, full_matrices=False)
    cond = 10**log_cond
    s = jnp.linspace(cond, 1, min(m, n))
    a = (u * s) @ v

    is_symmetric = _check_symmetry(a)
    max_iterations = 2

    _, _, actual_num_iterations, is_converged = qdwh.qdwh(
        a, is_symmetric, max_iterations)

    with self.subTest('Number of iterations.'):
        self.assertEqual(max_iterations, actual_num_iterations)

    with self.subTest('Converged.'):
        self.assertFalse(is_converged)
def fill_triangular(x, upper=False):
    m = x.shape[-1]
    if len(x.shape) != 1:
        raise ValueError("Only handles 1D to 2D transformation, because tril/u")
    m = np.int32(m)
    n = np.sqrt(0.25 + 2. * m) - 0.5
    if n != np.floor(n):
        raise ValueError('Input right-most shape ({}) does not '
                         'correspond to a triangular matrix.'.format(m))
    n = np.int32(n)
    final_shape = list(x.shape[:-1]) + [n, n]
    if upper:
        x_list = [x, np.flip(x[..., n:], -1)]
    else:
        x_list = [x[..., n:], np.flip(x, -1)]
    x = np.reshape(np.concatenate(x_list, axis=-1), final_shape)
    if upper:
        x = np.triu(x)
    else:
        x = np.tril(x)
    return x
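This is the packing trick from TensorFlow Probability's fill_triangular: a length-m vector with m = n(n+1)/2 is concatenated with a reversed copy, reshaped to n x n, and then masked, so no gather/scatter is needed. A small worked example (values assumed), packing six entries into a 3 x 3 lower triangle:

import numpy as np

x = np.arange(1., 7.)  # m = 6 -> n = 3
L = fill_triangular(x)
# L == [[4., 0., 0.],
#       [6., 5., 0.],
#       [3., 2., 1.]]
assert np.allclose(L, np.tril(L))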
def __call__(self, x, pos_emb, mask):
    dim_in, h = x.shape[-1], self.heads
    scale = dim_in**-0.5

    norm = nn.LayerNorm()
    to_qkv = nn.Dense(features=self.dim_head * h * 3, use_bias=False)
    to_out = nn.Dense(features=dim_in)

    x = norm(x)
    qkv = np.split(to_qkv(x), 3, axis=-1)
    q, k, v = map(lambda t: rearrange(t, "i (h d) -> i h d", h=h), qkv)

    q = index_update(q, index[1:], apply_rotary_pos_emb(q[1:], pos_emb))
    k = index_update(k, index[1:], apply_rotary_pos_emb(k[1:], pos_emb))

    sim = einsum("i h d, j h d -> i j h", q, k) * scale

    mask = np.pad(mask, (1, 0), constant_values=True)
    mask = rearrange(mask, "j -> () j ()")

    if self.causal:
        i, j = sim.shape[:2]
        tri_mask = np.ones((i - 1, j - 1), dtype=bool)
        tri_mask = np.pad(tri_mask, ((1, 0), (1, 0)), constant_values=False)
        causal_mask = np.triu(tri_mask, j - i + 1)
        causal_mask = rearrange(causal_mask, "i j -> i j ()")
        mask = ~causal_mask * mask

    sim = np.where(mask, sim, LARGE_NEG_VALUE)
    attn = nn.softmax(sim, axis=-2)

    out = einsum("i j h, j h d -> i h d", attn, v)
    out = rearrange(out, "i h d -> i (h d)")
    return to_out(out)
def setup(self):
    """Initialize P, L, U, s"""
    # W = PL(U + s)
    # Based on https://github.com/openai/glow/blob/master/model.py#L485
    dim = self.input_dim

    # Sample random rotation matrix
    q, _ = np.linalg.qr(jax.random.normal(self.rng, (dim, dim)),
                        mode="complete")
    p, l, u = jax.scipy.linalg.lu(q)

    # Fixed Permutation (non-trainable)
    self.P = p
    self.P_inv = jax.scipy.linalg.inv(p)

    # Init value from LU decomposition
    L_init = l
    U_init = np.triu(u, k=1)
    s = np.diag(u)
    self.sign_s = np.sign(s)
    S_log_init = np.log(np.abs(s))

    self.l_mask = np.tril(np.ones((dim, dim)), k=-1)
    self.u_mask = np.transpose(self.l_mask)

    # Define trainable variables
    self.L = self.param("L", lambda k, sh: L_init, (dim, dim))
    self.U = self.param("U", lambda k, sh: U_init, (dim, dim))
    self.log_s = self.param("log_s", lambda k, sh: S_log_init, (dim,))
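The masks stored here are what keep the parameterization stable under training: at call time the layer rebuilds W from the trainable pieces, with the sign of s frozen and only log|s| learned. A plausible reconstruction helper, assuming the conventional Glow forward pass (this helper is hypothetical; the call method is not shown in the excerpt):

import jax.numpy as np

def reconstruct_W(self):
    # Hypothetical helper: rebuild the coupling matrix from P, L, U, log_s.
    L = self.L * self.l_mask + np.eye(self.input_dim)  # unit lower-triangular
    U = self.U * self.u_mask + np.diag(self.sign_s * np.exp(self.log_s))
    return self.P @ L @ U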
def house_qr(A, mode="reduced"):
    """
    Performs a QR decomposition of the m x n real or complex matrix A
    using the Householder algorithm.

    The string parameter 'mode' determines the representation of the output.
    In this way, one can retrieve various implicit representations of the
    factored matrices. This can be a significant optimization in the case of
    a highly rectangular A, which is the reason for this function's
    existence.

    Parameters
    ----------
    A : array_like, shape (M, N)
        Matrix to be factored.
    mode: {'reduced', 'complete', 'r', 'factored', 'WY'}, optional
        If K = min(M, N), then:
        - 'reduced': returns Q, R with dimensions (M, K), (K, N) (default)
        - 'complete': returns Q, R with dimensions (M, M), (M, N)
        - 'r': returns r only with dimensions (K, N)
        - 'factored': returns H, beta with dimensions (N, M), (K), read
          below for details.
        - 'WY': returns W, Y with dimensions (M, K), read below for details.

        With 'reduced', 'complete', or 'r', this function simply passes to
        jnp.linalg.qr, which depending on the current status of Jax may lead
        to NotImplemented if A is complex.

        With 'factored' this function returns the same H, beta as generated
        by the Lapack function dgeqrf() (but in row-major form). Thus, H
        contains the upper triangular matrix R in its upper triangle, and
        the j'th Householder reflector forming Q in the j'th column of its
        lower triangle. beta[j] contains the normalization factor of the
        j'th reflector, called 'beta' in the function 'house' in this
        module. The matrix Q is then represented implicitly as

            Q = H(0) H(1) ... H(K),  H(i) = I - tau[i] v dag(v)

        with v[:j] = 0; v[j] = 1; v[j+1:] = A[j+1:, j].

        Application of Q (C -> dag(Q) C) can be made directly from this
        implicit representation using the function factored_multiply(C).
        When K << max(M, N), both the QR factorization and multiplication
        by Q using factored_multiply theoretically require far fewer
        operations than would an explicit representation of Q. However,
        these applications are mostly Level-2 BLAS operations.

        With 'WY' this function returns (M, K) matrices W and Y such that
        Q = I - W dag(Y). Y is the lower-triangular matrix of Householder
        vectors, e.g. the lower triangle of the matrix H resulting from
        mode='factored'. W is then computed so that the above identity
        holds. Application of Q can be made directly from the WY
        representation using the function WY_multiply(C). The WY
        representation is a bit more expensive to compute than the factored
        one, though still less expensive than the full Q. Its advantage
        versus 'factored' is that WY_multiply calls depend mostly on
        Level-3 BLAS operations.

    Returns
    -------
    Q : ndarray of float or complex, optional
        The column-orthonormal orthogonal/unitary matrix Q.
    R : ndarray of float or complex, optional
        The upper-triangular matrix.
    [H, beta] : list of ndarrays of float or complex, optional
        The matrix H and scaling factors beta generating Q along with R in
        the 'factored' representation.
    [W, Y, R] : list of ndarrays of float or complex, optional
        The matrices W and Y generating Q along with R in the 'WY'
        representation.

    Raises
    ------
    LinAlgError
        If factoring fails.
    NotImplementedError
        In reduced, complete, or r mode with complex input. In factored or
        WY mode in the case M < N.
    """
    if mode == "reduced" or mode == "complete" or mode == "r":
        return jnp.linalg.qr(A, mode=mode)
    else:
        m, n = A.shape
        if n > m:
            raise NotImplementedError("n > m QR not implemented in factored "
                                      "or WY mode.")
        if mode == "factored":
            return __house_qr_factored(A)
        elif mode == "WY":
            hbetalist = __house_qr_factored(A)
            R = jnp.triu(hbetalist[0])
            WYlist = factored_to_WY(hbetalist)
            output = WYlist + [R]
            return output
        else:
            raise ValueError("Invalid mode: ", mode)
def step(self, s, a):
    """Apply control, damping, boundary, and collision forces.

    Args:
      s: (p, v, misc), where p and v are [n_entities,2] jnp.float32,
        and misc is child defined
      a: [n_agents, dim_a] jnp.float32

    Returns:
      A state tuple (p, v, misc)
    """
    p, v, misc = s  # [n,2], [n,2], [a_shape]
    f = jnp.zeros_like(p)  # [n,2]
    n = p.shape[0]  # number of entities

    # Calculate control forces
    f_control = jnp.pad(a, ((0, n - a.shape[0]), (0, 0)),
                        mode="constant")  # [n, dim_a]
    f += f_control

    # Calculate damping forces
    f_damping = -1.0 * self.damping * v  # [n,2]
    f = f + f_damping

    # Calculate boundary forces
    bounce = (((p + self.radius >= self.max_p) & (v >= 0.0)) |
              ((p - self.radius <= self.min_p) & (v <= 0.0)))  # [n,2]
    v_new = (-1.0 * bounce + 1.0 * ~bounce) * v  # [n,2]
    f_boundary = self.mass * (v_new - v) / self.dt  # [n,2]
    f = f + f_boundary

    # Calculate shared quantities for later calculations
    # same: [n,n,1], True if i==j
    same = jnp.expand_dims(jnp.eye(n, dtype=jnp.bool_), axis=-1)
    # p2p: [n,n,2], p2p[i,j,:] is the vector from entity i to entity j
    p2p = p - jnp.expand_dims(p, axis=1)
    # dist: [n,n,1], dist[i,j,0] is the distance between i and j
    dist = jnp.linalg.norm(p2p, axis=-1, keepdims=True)
    # overlap: [n,n,1], overlap[i,j,0] is the overlap between i and j
    overlap = ((jnp.expand_dims(self.radius, axis=1) +
                jnp.expand_dims(self.radius, axis=0)) - dist)

    if self.same_position_check:
        # ontop: [n,n,1], ontop[i,j,0] = True if i is at the exact location of j
        ontop = (dist == 0.0)
        # ontop_dir: [n,n,1], (1,0) above diagonal, (-1,0) below diagonal
        ontop_dir = jnp.stack([jnp.triu(jnp.ones((n, n))) * 2 - 1,
                               jnp.zeros((n, n))], axis=-1)
        # contact_dir: [n,n,2], contact_dir[i,j,:] is the unit vector in the
        # direction of j from i
        contact_dir = ((~ontop * p2p + (ontop * ontop_dir)) /
                       (~ontop * dist + ontop * 1.0))
    else:
        # contact_dir: [n,n,2], contact_dir[i,j,:] is the unit vector in the
        # direction of j from i
        contact_dir = p2p / (dist + same)

    # collideable: [n,n,1], True if i and j are collideable
    collideable = (jnp.expand_dims(self.collideable, axis=1) &
                   jnp.expand_dims(self.collideable, axis=0))
    # overlapping: [n,n,1], True if i,j overlap
    overlapping = overlap > 0

    # Calculate collision forces
    # Assume all entities collide with all entities, then mask out
    # non-collisions.
    #
    # For approaching, colliding entities, apply forces along the direction
    # of collision that result in relative velocities consistent with the
    # coefficient of restitution (c) and preservation of momentum in that
    # direction.
    #
    # momentum: m_a*v_a + m_b*v_b = m_a*v'_a + m_b*v'_b
    # restitution: v'_b - v'_a = -c*(v_b - v_a)
    # solve for v'_a:
    #   v'_a = [m_a*v_a + m_b*v_b + m_b*c*(v_b - v_a)]/(m_a + m_b)

    # v_contact_dir: [n,n] speed of i in dir of j
    v_contact_dir = jnp.sum(jnp.expand_dims(v, axis=-2) * contact_dir, axis=-1)
    # v_approach: [n,n] speed that i,j are approaching each other
    v_approach = jnp.transpose(v_contact_dir) + v_contact_dir
    # momentum: [n,n] joint momentum in direction of contact (i->j)
    momentum = (self.mass * v_contact_dir -
                jnp.transpose(self.mass * v_contact_dir))
    # v_result: [n,n] speed of i in dir of j after collision
    v_result = ((momentum +
                 self.restitution * jnp.transpose(self.mass) * (-v_approach)) /
                (self.mass + jnp.transpose(self.mass)))
    # f_collision: [n,n] force on i in dir of j to realize acceleration
    f_collision = self.mass * (v_result - v_contact_dir) / self.dt
    # f_collision: [n,n,2] force on i to realize acceleration due to
    # collision with j
    f_collision = jnp.expand_dims(f_collision, axis=-1) * contact_dir
    # collision_mask: [n,n,1]
    collision_mask = (collideable & overlapping & ~same &
                      (jnp.expand_dims(v_approach, axis=-1) > 0))
    # f_collision: [n,2], sum of collision forces on i
    f_collision = jnp.sum(f_collision * collision_mask, axis=-2)
    f = f + f_collision

    # Calculate overlapping spring forces
    # This corrects for any overlap due to discrete steps.
    # f_overlap: [n,n,2], force in the negative contact dir due to overlap
    f_overlap = -1.0 * contact_dir * overlap * self.overlap_spring_constant
    # overlapping_mask: [n,n,1], True if i,j are collideable, overlap,
    # and i != j
    overlapping_mask = collideable & overlapping & ~same
    # f_overlap: [n,2], sum of spring forces on i
    f_overlap = jnp.sum(f_overlap * overlapping_mask, axis=-2)
    f = f + f_overlap

    # apply forces
    v = v + (f / self.mass) * self.dt
    p = p + v * self.dt

    # update misc
    misc = self._update_misc((p, v, misc), a)  # pylint: disable=assignment-from-none

    return (p, v, misc)
def mass_matrix_inv_mul(self, q: jnp.ndarray, v: jnp.ndarray,
                        **kwargs) -> jnp.ndarray:
    """Computes the product of the inverse mass matrix with a vector."""
    if self.kinetic_func_form in ("separable_net", "dep_net"):
        raise ValueError(
            "It is not possible to compute `M^-1 p` when using a "
            "network for the kinetic energy.")

    if self.kinetic_func_form in ("pure_quad", "embed_quad"):
        return v

    if self.kinetic_func_form == "matrix_diag_quad":
        if self.parametrize_mass_matrix:
            m_diag_log = hk.get_parameter(
                "MassMatrixDiagLog", shape=[self.system_dim],
                init=hk.initializers.Constant(0.0))
            m_inv_diag = 1.0 / (jnp.exp(m_diag_log) + self.mass_eps)
        else:
            m_inv_diag_log = hk.get_parameter(
                "InvMassMatrixDiagLog", shape=[self.system_dim],
                init=hk.initializers.Constant(0.0))
            m_inv_diag = jnp.exp(m_inv_diag_log) + self.mass_eps
        return m_inv_diag * v

    if self.kinetic_func_form == "matrix_quad":
        if self.parametrize_mass_matrix:
            m_triu = hk.get_parameter(
                "MassMatrixU", shape=[self.system_dim, self.system_dim],
                init=hk.initializers.Identity())
            m_triu = jnp.triu(m_triu)
            m = jnp.matmul(m_triu.T, m_triu)
            m = m + self.mass_eps * jnp.eye(self.system_dim)
            solve = jnp.linalg.solve
            for _ in range(v.ndim + 1 - m.ndim):
                solve = jax.vmap(solve, in_axes=(None, 0))
            return solve(m, v)
        else:
            m_inv_triu = hk.get_parameter(
                "InvMassMatrixU", shape=[self.system_dim, self.system_dim],
                init=hk.initializers.Identity())
            m_inv_triu = jnp.triu(m_inv_triu)
            m_inv = jnp.matmul(m_inv_triu.T, m_inv_triu)
            m_inv = m_inv + self.mass_eps * jnp.eye(self.system_dim)
            return self.feature_matrix_vector(m_inv, v)

    if self.kinetic_func_form in ("matrix_dep_diag_quad",
                                  "matrix_dep_diag_embed_quad"):
        if self.parametrize_mass_matrix:
            m_diag_log = self.mass_matrix_net(q, **kwargs)
            m_inv_diag = 1.0 / (jnp.exp(m_diag_log) + self.mass_eps)
        else:
            m_inv_diag_log = self.mass_matrix_net(q, **kwargs)
            m_inv_diag = jnp.exp(m_inv_diag_log) + self.mass_eps
        return m_inv_diag * v

    if self.kinetic_func_form in ("matrix_dep_quad", "matrix_dep_embed_quad"):
        if self.parametrize_mass_matrix:
            m_triu = self.mass_matrix_net(q, **kwargs)
            m_triu = utils.triu_matrix_from_v(m_triu, self.system_dim)
            m = jnp.matmul(jnp.swapaxes(m_triu, -2, -1), m_triu)
            m = m + self.mass_eps * jnp.eye(self.system_dim)
            return jnp.linalg.solve(m, v)
        else:
            m_inv_triu = self.mass_matrix_net(q, **kwargs)
            m_inv_triu = utils.triu_matrix_from_v(m_inv_triu, self.system_dim)
            m_inv = jnp.matmul(jnp.swapaxes(m_inv_triu, -2, -1), m_inv_triu)
            m_inv = m_inv + self.mass_eps * jnp.eye(self.system_dim)
            return self.feature_matrix_vector(m_inv, v)

    raise NotImplementedError()
def triu(a, k=0):
    if isinstance(a, JaxArray):
        a = a.value
    return JaxArray(jnp.triu(a, k))
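Since nearly every snippet in this listing hinges on the `k` offset of `jnp.triu`/`jnp.tril`, a compact reference (values assumed):

import jax.numpy as jnp

a = jnp.ones((3, 3))
jnp.triu(a)      # keep the main diagonal and everything above it
jnp.triu(a, 1)   # keep only entries strictly above the diagonal
jnp.triu(a, -1)  # keep the first subdiagonal and everything above it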
# errs_a.append(np.stack(erra).mean())
# errs_std_a.append(np.stack(erra).std())
# errs_b.append(np.stack(errb).mean())
# errs_std_b.append(np.stack(errb).std())

# Compute generalization error, rapid but heuristic
errs_a = []
errs_b = []
for ii, a in enumerate(range(K)):
    Xa = np.array(manifolds[a])
    for b in range(a + 1, K):
        Xb = np.array(manifolds[b])
        erra = []
        errb = []
        key, _ = random.split(key)
        erra, errb = mshot_err_fast(key, Xa, Xb)
        errs_a.append(erra)
        errs_b.append(errb)
    print('Manifold {} of {}. Avg. acc: {}'.format(ii, K, 1 - errs_a[-1].mean()))

# Combine errs_a and errs_b into K x K matrix
errs_full = np.triu(squareform(errs_a)) + np.tril(squareform(errs_b))

# Save
np.save(save_path, errs_full)
print('Finished with acc. ' + str(1 - np.mean(errs_full)) + '. Saved.')
def compute_OBC_energy_vectorized(distance_matrix, radii, scales, charges,
                                  offset=0.009, screening=138.935484,
                                  surface_tension=28.3919551,
                                  solvent_dielectric=78.5,
                                  solute_dielectric=1.0):
    """Compute GBSA-OBC energy from a distance matrix"""
    N = len(radii)
    #print(type(distance_matrix))
    eye = np.eye(N, dtype=distance_matrix.dtype)
    #print(type(eye))
    r = distance_matrix + eye  # so I don't have divide-by-zero nonsense
    or1 = radii.reshape((N, 1)) - offset
    or2 = radii.reshape((1, N)) - offset
    sr2 = scales.reshape((1, N)) * or2

    L = np.maximum(or1, abs(r - sr2))
    U = r + sr2
    I = step(r + sr2 - or1) * 0.5 * (
        1 / L - 1 / U + 0.25 * (r - sr2**2 / r) * (1 / (U**2) - 1 / (L**2))
        + 0.5 * np.log(L / U) / r)
    I -= np.diag(np.diag(I))
    I = np.sum(I, axis=1)

    # okay, next compute born radii
    offset_radius = radii - offset
    psi = I * offset_radius
    psi_coefficient = 0.8
    psi2_coefficient = 0
    psi3_coefficient = 2.909125
    psi_term = ((psi_coefficient * psi) + (psi2_coefficient * psi**2)
                + (psi3_coefficient * psi**3))
    B = 1 / (1 / offset_radius - np.tanh(psi_term) / radii)

    # finally, compute the three energy terms
    E = 0.0

    # single particle
    E += np.sum(surface_tension * (radii + 0.14)**2 * (radii / B)**6)
    E += np.sum(-0.5 * screening
                * (1 / solute_dielectric - 1 / solvent_dielectric)
                * charges**2 / B)

    # particle pair
    f = np.sqrt(r**2 + np.outer(B, B) * np.exp(-r**2 / (4 * np.outer(B, B))))
    charge_products = np.outer(charges, charges)
    E += np.sum(np.triu(
        -screening * (1 / solute_dielectric - 1 / solvent_dielectric)
        * charge_products / f, k=1))

    return E
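The final np.triu(..., k=1) is the standard trick for summing a symmetric pairwise interaction matrix: keeping only the strict upper triangle counts each unordered pair exactly once and drops the self-interaction diagonal. A minimal illustration (values assumed):

import numpy as np

pair = np.array([[9., 1., 2.],
                 [1., 9., 3.],
                 [2., 3., 9.]])  # symmetric; diagonal holds self-interactions
np.sum(np.triu(pair, k=1))      # 1 + 2 + 3 = 6.0: each pair counted once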
def gbsa_obc(
        coords,
        # params,
        lamb,
        # box,
        charge_params,
        gb_params,
        # charge_idxs,
        # radii_idxs,
        # scale_idxs,
        alpha,
        beta,
        gamma,
        cutoff_radii,
        cutoff_force,
        lambda_plane_idxs,
        lambda_offset_idxs,
        dielectric_offset=0.009,
        surface_tension=28.3919551,
        solute_dielectric=1.0,
        solvent_dielectric=78.5,
        probe_radius=0.14):

    box = None

    assert cutoff_radii == cutoff_force

    coords_4d = convert_to_4d(coords, lamb, lambda_plane_idxs,
                              lambda_offset_idxs, cutoff_radii)

    N = len(charge_params)

    radii = gb_params[:, 0]
    scales = gb_params[:, 1]

    ri = np.expand_dims(coords_4d, 0)
    rj = np.expand_dims(coords_4d, 1)

    dij = distance(ri, rj, box)
    eye = np.eye(N, dtype=dij.dtype)

    r = dij + eye  # so I don't have divide-by-zero nonsense
    or1 = radii.reshape((N, 1)) - dielectric_offset
    or2 = radii.reshape((1, N)) - dielectric_offset
    sr2 = scales.reshape((1, N)) * or2

    L = np.maximum(or1, abs(r - sr2))
    U = r + sr2

    I = (1 / L - 1 / U + 0.25 * (r - sr2**2 / r) * (1 / (U**2) - 1 / (L**2))
         + 0.5 * np.log(L / U) / r)
    # handle the interior case
    I = np.where(or1 < (sr2 - r), I + 2 * (1 / or1 - 1 / L), I)
    I = step(r + sr2 - or1) * 0.5 * I  # note the extra 0.5 here
    I -= np.diag(np.diag(I))

    # switch I only for now
    # inner = (np.pi*np.power(dij,8))/(2*cutoff_radii)
    # sw = np.power(np.cos(inner), 2)
    # I = I*sw

    I = np.where(dij > cutoff_radii, 0, I)
    I = np.sum(I, axis=1)

    # okay, next compute born radii
    offset_radius = radii - dielectric_offset
    psi = I * offset_radius
    psi_coefficient = alpha
    psi2_coefficient = beta
    psi3_coefficient = gamma
    psi_term = ((psi_coefficient * psi) - (psi2_coefficient * psi**2)
                + (psi3_coefficient * psi**3))

    B = 1 / (1 / offset_radius - np.tanh(psi_term) / radii)

    E = 0.0

    # single particle
    # ACE
    E += np.sum(surface_tension * (radii + probe_radius)**2 * (radii / B)**6)

    # on-diagonal
    charges = charge_params
    E += np.sum(-0.5 * (1 / solute_dielectric - 1 / solvent_dielectric)
                * charges**2 / B)

    # particle pair
    f = np.sqrt(r**2 + np.outer(B, B) * np.exp(-r**2 / (4 * np.outer(B, B))))
    charge_products = np.outer(charges, charges)

    ixns = (-(1 / solute_dielectric - 1 / solvent_dielectric)
            * charge_products / f)

    # sw = np.power(np.cos((np.pi*dij)/(2*cutoff_radii)), 2)
    # ixns = ixns*sw
    ixns = np.where(dij > cutoff_force, 0, ixns)

    E += np.sum(np.triu(ixns, k=1))

    return E
def _connection_weights(num_iterations, num_mixing_iterations):
    """Gets the connection weights."""
    mask = jnp.triu(jnp.tril(jnp.ones((num_iterations, num_iterations))),
                    k=-num_mixing_iterations + 1)
    return mask / jnp.sum(mask, axis=1, keepdims=True)
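The tril-then-triu composition builds a banded lower-triangular mask: row i is nonzero only for columns i, i-1, ..., i-num_mixing_iterations+1, so each row averages over the most recent iterations. A small check (argument values assumed):

import jax.numpy as jnp

w = _connection_weights(num_iterations=4, num_mixing_iterations=2)
# Each row averages the current and previous iteration (fewer at the start):
# [[1. , 0. , 0. , 0. ],
#  [0.5, 0.5, 0. , 0. ],
#  [0. , 0.5, 0.5, 0. ],
#  [0. , 0. , 0.5, 0.5]]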
def update_site(self, inputs: Array, index: int) -> Array:
    """
    Adds an input site into the cache, and applies the masked linear
    transformation to the cache.

    Args:
      inputs: an input site to be added into the cache with dimensions
        (batch, features).
      index: the index of the output site. The index of the input site
        should be `index - self.exclusive`.

    Returns:
      The output site with dimensions (batch, features).
    """
    dtype = jnp.promote_types(inputs.dtype, self.dtype)
    inputs = jnp.asarray(inputs, dtype)

    is_single_input = False
    if inputs.ndim == 1:
        is_single_input = True
        inputs = jnp.expand_dims(inputs, axis=0)

    batch, in_features = inputs.shape
    size = self.size

    # Number of input sites depended on by the output site at the index
    size_i = index + 1

    # Initialize the cache with zeros, and the RNG key is None
    # `cache.dtype` must be the same as `inputs.dtype` (no promotion)
    _cache = self.variable("cache", "inputs", zeros, None,
                           (batch, size, in_features), inputs.dtype)

    initializing = self.is_mutable_collection("params")
    if not initializing:
        # Add the input site into the cache
        # To write the cache, use `_cache.value` as the left value of the
        # assignment
        _cache.value = lax.cond(
            index - self.exclusive >= 0,
            lambda _: _cache.value.at[:, index - self.exclusive, :].set(inputs),
            lambda _: _cache.value,
            None,
        )

    cache = _cache.value
    cache = jnp.asarray(cache, dtype)

    cache_i = cache[:, :size_i, :]
    cache_i = cache_i.reshape((batch, size_i * in_features))

    # The construction of `mask` will be optimized to a constant by JIT
    mask = jnp.ones((size, size), dtype=self.dtype)
    mask = jnp.triu(mask, self.exclusive)
    mask = jnp.kron(
        mask, jnp.ones((in_features, self.features), dtype=self.dtype))

    kernel = self.param(
        "kernel",
        wrap_kernel_init(self.kernel_init, mask),
        (size * in_features, size * self.features),
        self.dtype,
    )

    mask = jnp.asarray(mask, dtype)
    kernel = jnp.asarray(kernel, dtype)

    mask_i = mask.reshape((size, in_features, size, self.features))
    mask_i = mask_i[:size_i, :, index, :]
    mask_i = mask_i.reshape((size_i * in_features, self.features))

    kernel_i = kernel.reshape((size, in_features, size, self.features))
    kernel_i = kernel_i[:size_i, :, index, :]
    kernel_i = kernel_i.reshape((size_i * in_features, self.features))

    y_i = lax.dot(cache_i, mask_i * kernel_i, precision=self.precision)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (size, self.features),
                          self.dtype)
        bias = jnp.asarray(bias, dtype)
        bias_i = bias[index, :]
        y_i = y_i + bias_i

    assert y_i.shape[1] == self.features

    if is_single_input:
        y_i = y_i.squeeze(axis=0)

    return y_i
def returns(self, r):
    # r: [n_steps]
    return jnp.dot(jnp.triu(jnp.ones((self.n_steps, self.n_steps))), r)  # R: [n_steps]
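Multiplying by an upper-triangular matrix of ones computes all the tail sums in one matrix product: R[t] = sum of r[t'] for t' >= t, i.e. the undiscounted return-to-go at each step. A quick standalone check (reward values assumed):

import jax.numpy as jnp

r = jnp.array([1., 2., 3.])
R = jnp.dot(jnp.triu(jnp.ones((3, 3))), r)
# R == [6., 5., 3.]: total reward from step t to the end of the episode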
def gbsa(conf, params, box, param_idxs, dielectric_offset=0.009, cutoff=2.0,
         alpha_obc=1.0, beta_obc=0.8, gamma_obc=4.85, solute_dielectric=1.0,
         solvent_dielectric=78.3, electric_constant=-69.467728,
         probe_radius=0.14, surface_area_energy=2.25936):
    """
    Computes the GBSA energy with support for full OBC style parameters.

    For detailed notes on the values of the undocumented keyword args,
    please refer to the OpenMM theory manual:

    http://docs.openmm.org/latest/userguide/theory.html#gbsaobcforce

    Parameters
    ----------
    conf: shape [num_atoms, 3] np.array
        atomic coordinates

    params: shape [num_params,] np.array
        unique parameters

    box: shape [3, 3] np.array
        periodic boundary vectors, if not None

    param_idxs: shape [num_atoms, 3]
        a list of 3-tuple parameter indices, where the 0th index indicates
        charges, the 1st indicates radii, and the 2nd indicates scale_factors
    """
    if box is not None:
        raise ValueError("Periodic GBSA is not supported.")

    num_atoms = conf.shape[0]

    if solute_dielectric != 0.0 and solvent_dielectric != 0.0:
        prefactor = 2.0 * electric_constant * (1.0 / solute_dielectric -
                                               1.0 / solvent_dielectric)
    else:
        prefactor = 0.0

    # (ytz): The rough sketch of the algorithm is as follows:
    # 1. Compute the adjusted GB radii
    # 2. Use the adjusted radii to compute the shielded electrostatic potential
    # 3. Compute the non-polar contribution using the GB radii

    charges = params[param_idxs[:, 0]]
    atomic_radii = params[param_idxs[:, 1]]
    scaled_factors = params[param_idxs[:, 2]]

    br = born_radii(conf, atomic_radii, scaled_factors, dielectric_offset,
                    alpha_obc, beta_obc, gamma_obc)

    r_i = np.expand_dims(conf, axis=0)
    r_j = np.expand_dims(conf, axis=1)

    q_i = np.expand_dims(charges, axis=0)
    q_j = np.expand_dims(charges, axis=1)
    q_ij = q_i * q_j

    br_i = np.expand_dims(br, axis=0)
    br_j = np.expand_dims(br, axis=1)

    r2 = np.sum(np.power(r_i - r_j, 2), axis=-1)
    alpha2_ij = br_i * br_j
    D_ij = r2 / (4.0 * alpha2_ij)
    expTerm = np.exp(-D_ij)
    denom2 = r2 + alpha2_ij * expTerm
    denom = np.sqrt(denom2)
    pq_ij = prefactor * q_ij
    Gpol = pq_ij / denom
    energy = Gpol

    pi4Asolv = 4 * np.pi * surface_area_energy

    nonpolar_nrg = non_polar_ace(br, atomic_radii, probe_radius, pi4Asolv)

    # compute using only the upper triangle
    return (np.sum(np.triu(energy)) + np.sum(np.diagonal(energy) / 2.0)
            + nonpolar_nrg)