Example #1
0
def train(
    q,
    k,
    scale,
    proj,
    true_attn,
    L_dL,
    proj_fn,
    alpha,
    num_iters,
    key,
    sample=True,
    post_renorm=False,
):
    """Run gradient descent on ``q`` and ``k`` with a (re)sampled projection.

    Each iteration asks ``proj_fn`` for a projection matrix, evaluates
    ``L_dL`` to get the loss and the gradients w.r.t. ``q`` and ``k``, and
    takes a step of size ``alpha``.

    Args:
      q, k: parameters being optimized.
      scale: not used by the loop; returned unchanged.
      proj: returned as the projection matrix when ``num_iters`` is 0.
      true_attn: target forwarded to ``L_dL``.
      L_dL: callable ``(q, k, projection_matrix, true_attn) -> (loss, (dq, dk))``.
      proj_fn: callable mapping a PRNG key to a projection matrix.
      alpha: step size.
      num_iters: number of gradient steps.
      key: JAX PRNG key.
      sample: if True, split ``key`` each iteration so a fresh projection is
        drawn; otherwise ``key`` itself is used every time.
      post_renorm: if True, apply ``renorm`` to ``q`` and ``k`` after each step.

    Returns:
      ``(losses, grads, q, k, scale, projection_matrix)``:
      - ``losses[i]`` is the loss at step ``i``.
      - ``grads[i]`` is ``|dq|^2 + |dk|^2`` at step ``i``.
      - ``projection_matrix`` is the one used in the last step, or ``proj``
        when no step was taken.
    """
    losses = onp.zeros((num_iters, ))
    grads = onp.zeros((num_iters, ))
    # Bound up front so the return works when num_iters == 0.
    projection_matrix = proj
    for i in range(num_iters):
        if sample:
            key, key_sample = jax.random.split(key)
        else:
            key_sample = key
        projection_matrix = proj_fn(key_sample)
        kl_val, (dq, dk) = L_dL(q, k, projection_matrix, true_attn)
        # Out-of-place updates: `q -= ...` would modify the caller's arrays in
        # place when they are NumPy arrays.
        q = q - alpha * dq
        k = k - alpha * dk
        losses[i] = kl_val
        grads[i] = norm(dq)**2 + norm(dk)**2

        if post_renorm:
            q = renorm(q)
            k = renorm(k)

    return losses, grads, q, k, scale, projection_matrix
Example #2
0
def newton(fn, jac_fn, U):
    """Solve ``fn(U, Uold) = 0`` for ``U`` with an undamped Newton iteration.

    ``Uold`` is frozen at the initial guess and passed unchanged to ``fn``
    and ``jac_fn``. One Newton step is always taken before the loop. The
    loop then repeats until the scaled residual ``||y / ||U||_inf||_inf``
    is at most ``tol``, or ``maxit`` steps have been taken in total. Timing
    and residual information is printed along the way.

    Args:
      fn: residual function ``(U, Uold) -> y``.
      jac_fn: Jacobian of ``fn`` w.r.t. ``U``, ``(U, Uold) -> J``.
      U: initial guess.

    Returns:
      ``(U, fail)``. ``fail`` is 1 if any of these hold:
      - the last update contains NaNs;
      - the last update has a nonzero imaginary part;
      - the residual did not reach ``tol``.
      Otherwise ``fail`` is 0.
    """
    # maxit used to be assigned twice (20, then 5); only 5 was ever in effect.
    maxit = 5
    tol = 1e-8
    count = 0
    res = 100  # sentinel so the loop body runs at least once
    fail = 0

    Uold = U

    print("here")
    start = timeit.default_timer()
    J = jac_fn(U, Uold)
    print("computed jacobian")
    y = fn(U, Uold)
    res0 = norm(y / norm(U, np.inf), np.inf)
    delta = solve(J, y)
    U = U - delta
    count = count + 1
    end = timeit.default_timer()
    print("time elapsed in first loop", end - start)
    print(count, res0)
    while count < maxit and res > tol:
        start1 = timeit.default_timer()
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        # residual of the iterate *before* this step's update
        res = norm(y / norm(U, np.inf), np.inf)
        delta = solve(J, y)
        U = U - delta
        count = count + 1
        end1 = timeit.default_timer()
        print("time per loop", end1 - start1)
        print(count, res)

    # Report the first failure condition found; np.any works for any rank,
    # unlike the builtin max over an array.
    if np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")
    elif np.any(np.imag(delta) != 0):
        fail = 1
        print("solution complex")
    elif res > tol:
        fail = 1
        print('Newton fail: no convergence')

    return U, fail
Example #3
0
def train_proj(
    q,
    k,
    scale,
    proj,
    true_attn,
    L_dL,
    proj_fn_unused,
    alpha,
    num_iters,
    key,
    sample,
    post_renorm=False,
):
    """Gradient descent on ``q``, ``k`` and optionally ``scale`` and ``proj``.

    ``L_dL(q, k, scale, proj, true_attn)`` returns
    ``(loss, (dq, dk, dscale, dproj))``. A gradient of ``None`` means that
    parameter is held fixed. ``proj_fn_unused``, ``key`` and ``sample`` are
    accepted for signature compatibility and ignored.

    Args:
      q, k, scale, proj: parameters.
      true_attn: target forwarded to ``L_dL``.
      L_dL: loss-and-gradient callable (see above).
      alpha: step size.
      num_iters: number of gradient steps.
      post_renorm: if True, apply ``renorm`` to ``q`` and ``k`` after each step.

    Returns:
      ``(losses, grads, q, k, scale, proj)``:
      - ``losses[i]`` is the loss at step ``i``.
      - ``grads[i]`` is the sum of squared norms of the non-None gradients
        at step ``i``.
    """
    losses = onp.zeros((num_iters, ))
    grads = onp.zeros((num_iters, ))
    for i in range(num_iters):
        kl_val, (dq, dk, dscale, dproj) = L_dL(q, k, scale, proj, true_attn)

        # Out-of-place updates so caller-owned NumPy arrays are never modified
        # in place (`x -= ...` would do that for NumPy inputs).
        q = q - alpha * dq
        k = k - alpha * dk
        sq_grad = norm(dq)**2 + norm(dk)**2
        if dscale is not None:
            scale = scale - alpha * dscale
            sq_grad += norm(dscale)**2
        if dproj is not None:
            proj = proj - alpha * dproj
            sq_grad += norm(dproj)**2

        losses[i] = kl_val
        grads[i] = sq_grad

        if post_renorm:
            q = renorm(q)
            k = renorm(k)

    return losses, grads, q, k, scale, proj
Example #4
0
 def body_fun(val):
     """Advance the carry ``(U, count, res, fail)`` by one Newton step.

     ``jac_fn``, ``fn``, ``Uold``, ``solve``, ``norm`` and ``np`` come from
     the enclosing scope. The residual uses ``y`` evaluated at the incoming
     ``U`` but is scaled by the updated ``U``. ``fail`` is passed through
     unchanged, and the incoming ``res`` is ignored.
     """
     U_cur, it, _, fail_flag = val
     jac = jac_fn(U_cur, Uold)
     resid = fn(U_cur, Uold)
     U_next = U_cur - solve(jac, resid)
     res_next = norm(resid / norm(U_next, np.inf), np.inf)
     it_next = it + 1
     print(it_next, res_next)
     return U_next, it_next, res_next, fail_flag
Example #5
0
def damped_newton(fn, jac_fn, U):
    """Solve ``fn(U, Uold) = 0`` with Newton's method.

    Despite the name, no damping or line search is applied: every update is
    a full Newton step. ``Uold`` is frozen at the initial guess. One step is
    taken up front. Up to ``maxit`` further steps follow while the scaled
    residual ``||y / ||U||_inf||_inf`` exceeds ``tol``.

    Args:
      fn: residual function ``(U, Uold) -> y``.
      jac_fn: Jacobian of ``fn`` w.r.t. ``U``, ``(U, Uold) -> J``.
      U: initial guess.

    Returns:
      ``(U, fail)``. ``fail`` is 1 if any of these hold:
      - the last update contains NaNs;
      - the last update has a nonzero imaginary part;
      - the residual never reached ``tol``.
      Otherwise ``fail`` is 0.
    """
    maxit = 10
    tol = 1e-8
    count = 0
    res = 100  # sentinel so the loop body runs at least once
    fail = 0
    Uold = U
    J = jac_fn(U, Uold)
    y = fn(U, Uold)
    delta = solve(J, y)
    U = U - delta
    # This initial residual is scaled by the already-updated U.
    res0 = norm(y / norm(U, np.inf), np.inf)
    print(count, res0)
    while count < maxit and res > tol:
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        res = norm(y / norm(U, np.inf), np.inf)
        print(count, res)
        delta = solve(J, y)
        U = U - delta
        count = count + 1

    # Report the first failure condition found; np.any works for any rank,
    # unlike the builtin max over an array.
    if np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")
    elif np.any(np.imag(delta) != 0):
        fail = 1
        print("solution complex")
    elif res > tol:
        fail = 1
        print('Newton fail: no convergence')

    return U, fail
Example #6
0
    def body(state):
        """Take one Newton step and return the updated loop state.

        Reads ``state.U`` and ``state.count``. ``jac_fn``, ``fn``, ``Uold``,
        ``solve``, ``norm``, ``tol`` and ``np`` come from the enclosing scope.
        The returned state has:
        - ``U`` advanced by one step;
        - ``count`` incremented;
        - ``converged`` set when the scaled residual is below ``tol``;
        - ``fail`` set when the update contains NaNs.
        """
        J = jac_fn(state.U, Uold)
        y = fn(state.U, Uold)
        delta = solve(J, y)
        U = state.U - delta
        res = norm(y / norm(U, np.inf), np.inf)
        # `_replace` builds a *new* state (presumably a NamedTuple) rather than
        # mutating `state`, so its result must be returned; dropping it leaves
        # the loop stuck on the same state.
        return state._replace(count=state.count + 1,
                              converged=res < tol,
                              fail=np.any(np.isnan(delta)),
                              U=U)
Example #7
0
def exp(q, eps=1e-8):
    """Computes the quaternion exponential.

    References:
      https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions

    Args:
      q: the quaternion in (x,y,z,w) format, or a pure quaternion given as
        only its vector part (x,y,z). Pure input is detected from the size of
        the last axis being 3; it is not a separate argument.
      eps: lower bound on |v| in the division, so a zero vector part does not
        produce NaNs.

    Returns:
      The exponential of q in (x,y,z,w) format.
    """
    is_pure = q.shape[-1] == 3
    if is_pure:
        s = jnp.zeros_like(q[..., -1:])
        v = q
    else:
        v = im(q)
        s = re(q)

    norm_v = linalg.norm(v, axis=-1, keepdims=True)
    exp_s = jnp.exp(s)
    w = jnp.cos(norm_v)
    # sin(|v|) * v / |v|, with |v| clamped to eps so v == 0 stays finite.
    xyz = jnp.sin(norm_v) * v / jnp.maximum(norm_v, eps)
    return exp_s * jnp.concatenate((xyz, w), axis=-1)
Example #8
0
def angle(p1, p2, p3):
    """
    Returns the angle defined by three points in space
    (around the one in the middle).

    The angle between the arms ``p1 - p2`` and ``p3 - p2`` is computed as
    ``arctan2(|a x b|, a . b)``. It is in radians and lies in [0, pi],
    because the first argument is a non-negative norm.
    """
    arm_a, arm_b = p1 - p2, p3 - p2
    sin_part = linalg.norm(np.cross(arm_a, arm_b))
    cos_part = np.dot(arm_a, arm_b)
    return np.arctan2(sin_part, cos_part)
Example #9
0
def newton(fn, jac_fn, U):
    """Solve ``fn(U, Uold) = 0`` for ``U`` with an undamped Newton iteration.

    ``Uold`` is frozen at the initial guess. One Newton step is always taken
    first. The loop then repeats until the scaled residual
    ``||y / ||U||_inf||_inf`` is at most ``tol``, or ``maxit`` steps have
    been taken in total. Residuals and per-step timings are printed.

    Args:
      fn: residual function ``(U, Uold) -> y``.
      jac_fn: Jacobian of ``fn`` w.r.t. ``U``, ``(U, Uold) -> J``.
      U: initial guess.

    Returns:
      ``(U, fail)``. ``fail`` is 1 if any of these hold:
      - the last update contains NaNs;
      - the last update has a nonzero imaginary part;
      - the residual did not reach ``tol``.
      Otherwise ``fail`` is 0.
    """
    maxit = 20
    tol = 1e-8
    count = 0
    res = 100  # sentinel so the loop body runs at least once
    fail = 0
    Uold = U

    start = timeit.default_timer()
    J = jac_fn(U, Uold)
    y = fn(U, Uold)
    res0 = norm(y / norm(U, np.inf), np.inf)
    delta = solve(J, y)
    U = U - delta
    count = count + 1
    end = timeit.default_timer()
    print("time elapsed in first loop", end - start)
    print(count, res0)
    while count < maxit and res > tol:
        start1 = timeit.default_timer()
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        # residual of the iterate *before* this step's update
        res = norm(y / norm(U, np.inf), np.inf)
        delta = solve(J, y)
        U = U - delta
        count = count + 1
        end1 = timeit.default_timer()
        print(count, res)
        print("time per loop", end1 - start1)

    # Report the first failure condition found; np.any works for any rank,
    # unlike the builtin max over an array.
    if np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")
    elif np.any(np.imag(delta) != 0):
        fail = 1
        print("solution complex")
    elif res > tol:
        fail = 1
        print('Newton fail: no convergence')

    return U, fail
Example #10
0
def dihedral_angle(p1, p2, p3, p4):
    """
    Returns the dihedral angle defined by four points in space
    (around the line defined by the two central points).

    The result is a signed angle in radians, in [-pi, pi], computed with
    ``arctan2`` from the two plane normals.
    """
    axis = p3 - p2
    normal_1 = np.cross(p2 - p1, axis)
    normal_2 = np.cross(axis, p4 - p3)
    y = np.dot(np.cross(normal_1, normal_2), axis)
    x = np.dot(normal_1, normal_2) * linalg.norm(axis)
    return np.arctan2(y, x)
Example #11
0
def blackman_kernel(dims, M):
    """Build a normalized ``dims``-dimensional kernel from ``blackman``.

    With ``n = M - 2``, the kernel has ``n`` samples per axis, placed at
    offsets ``1-n, 3-n, ..., n-1`` from the centre. Each sample's weight is
    ``blackman(M, r / 2)``, where ``r`` is the Euclidean length of its offset
    vector; ``blackman`` is defined elsewhere and presumably evaluates a
    Blackman window (TODO confirm). The weights are divided by their sum, and
    the result has shape ``(n,) * dims``.
    """
    n = M - 2
    offsets = np.arange(1 - n, n, 2)
    grid = np.stack(np.meshgrid(*([offsets] * dims)), axis=-1)
    points = grid.reshape(-1, dims)
    weight_at = jax.vmap(lambda p: blackman(M, norm(np.float64(p)) / 2))
    kernel = weight_at(points)
    shape = (n,) * dims
    return (kernel / kernel.sum()).reshape(*shape)
Example #12
0
def cond(x, p=None):
  """Condition number of ``x``, batched over leading axes.

  Args:
    x: array whose last two axes hold the matrices.
    p: order of the norm.
      - ``None`` or ``2`` use the ratio of largest to smallest singular value.
      - ``-2`` uses the inverse of that ratio.
      - Any other order uses ``norm(x, p) * norm(inv(x), p)`` and requires
        square matrices.

  Returns:
    The condition number(s). A NaN result (e.g. 0/0 for an all-zero matrix)
    is reported as ``inf``, unless the corresponding input matrix itself
    contains NaN.
  """
  _assertNoEmpty2d(x)
  if p in (None, 2):
    s = la.svd(x, compute_uv=False)
    r = s[..., 0] / s[..., -1]
  elif p == -2:
    s = la.svd(x, compute_uv=False)
    r = s[..., -1] / s[..., 0]
  else:
    _assertRankAtLeast2(x)
    _assertNdSquareness(x)
    invx = la.inv(x)
    r = la.norm(x, ord=p, axis=(-2, -1)) * la.norm(invx, ord=p, axis=(-2, -1))

  # Convert NaNs to infs for every branch, but keep NaN where the original
  # matrix had NaN entries. The mask is per matrix; it must not be gated on
  # "r has no NaN at all", which would turn this into a no-op.
  nan_mask = np.logical_and(np.isnan(r), ~np.isnan(x).any(axis=(-2, -1)))
  return np.where(nan_mask, np.inf, r)
Example #13
0
def log(q, eps=1e-8):
    """Computes the quaternion logarithm.

    For q with vector part v, real part s and magnitude |q|:
      log(q) = (v / |v| * acos(s / |q|), log(|q|)).

    References:
      https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions

    Args:
      q: the quaternion in (x,y,z,w) format.
      eps: lower bound applied to |v| and |q| in divisions, so zero-length
        parts do not produce NaNs from the division.

    Returns:
      The logarithm of q in (x,y,z,w) format.
    """
    mag = linalg.norm(q, axis=-1, keepdims=True)
    v = im(q)
    s = re(q)
    w = jnp.log(mag)
    denom = jnp.maximum(linalg.norm(v, axis=-1, keepdims=True), eps)
    # The rotation angle is acos(s / |q|). Dividing by eps instead would
    # push the argument far outside [-1, 1] for any non-tiny real part.
    xyz = v / denom * safe_acos(s / jnp.maximum(mag, eps))
    return jnp.concatenate((xyz, w), axis=-1)
Example #14
0
    def _trust_region_body_f(
            params: _TrustRegionResults) -> _TrustRegionResults:
        """Run one trust-region iteration and return the next state.

        Steps:
        1. Solve the subproblem with ``subp`` around ``params.x_k``.
        2. Evaluate the objective and gradient at the proposed point.
        3. Use the ratio ``rho`` of actual to predicted decrease to resize
           the trust radius.
        4. Accept the step only when ``rho > eta``.

        Branching is expressed with ``jnp.where`` (presumably so this can
        serve as a ``lax.while_loop`` body). Free variables from the
        enclosing scope: ``_hvp``, ``g_f``, ``vg_f``, ``subp``, ``norm``,
        ``max_trust_radius``, ``eta`` and ``gtol``.
        """
        # compute a new hessian vector product function given the current state
        hessvp = partial(_hvp, g_f, params.x_k)

        # we should add an internal success check for future subp approaches that might not be solvable
        # (e.g., non-PSD hessian)
        result = subp(params.f_k,
                      params.g_k,
                      params.g_k_mag,
                      hessvp,
                      params.trust_radius,
                      norm=norm)

        pred_f_kp1 = result.pred_f
        x_kp1 = params.x_k + result.step
        f_kp1, g_kp1 = vg_f(x_kp1)

        # actual vs. model-predicted decrease of the objective
        delta = params.f_k - f_kp1
        pred_delta = params.f_k - pred_f_kp1

        # update the trust radius according to the actual/predicted ratio
        # use `where` to avoid branching. this is a simple scalar check so not much computational overhead
        # rho < 0.25: shrink the radius 4x; rho > 0.75 and the step hit the
        # boundary: double it, capped at max_trust_radius.
        rho = delta / pred_delta
        tr = params.trust_radius
        cur_tradius = jnp.where(rho < 0.25, tr * 0.25, tr)
        cur_tradius = jnp.where((rho > 0.75) & result.hits_boundary,
                                jnp.minimum(2. * tr, max_trust_radius),
                                cur_tradius)

        # compute norm to check for convergence
        g_kp1_mag = jnpla.norm(g_kp1, ord=norm)

        # if the ratio is high enough then accept the proposed step
        # repeated check to skirt using cond/branching
        f_kp1 = jnp.where(rho > eta, f_kp1, params.f_k)
        x_kp1 = jnp.where(rho > eta, x_kp1, params.x_k)
        g_kp1 = jnp.where(rho > eta, g_kp1, params.g_k)
        g_kp1_mag = jnp.where(rho > eta, g_kp1_mag, params.g_k_mag)

        # good_approx: the model predicted a strict decrease.
        # nfev/ngev: the extra +1 counts the vg_f call above.
        iter_params = _TrustRegionResults(converged=g_kp1_mag < gtol,
                                          good_approx=pred_delta > 0,
                                          k=params.k + 1,
                                          x_k=x_kp1,
                                          f_k=f_kp1,
                                          g_k=g_kp1,
                                          g_k_mag=g_kp1_mag,
                                          nfev=params.nfev + result.nfev + 1,
                                          ngev=params.ngev + result.ngev + 1,
                                          trust_radius=cur_tradius,
                                          status=params.status)

        return iter_params
Example #15
0
def newton_tol(fn, jac_fn, U, tol):
    """Solve ``fn(U, Uold) = 0`` with Newton's method to a given tolerance.

    ``Uold`` is frozen at the initial guess. At most ``maxit`` Newton steps
    are taken. Iteration stops once the scaled residual
    ``||y / ||U||_inf||_inf`` (measured before each update) is at most
    ``tol``.

    Args:
      fn: residual function ``(U, Uold) -> y``.
      jac_fn: Jacobian of ``fn`` w.r.t. ``U``, ``(U, Uold) -> J``.
      U: initial guess.
      tol: convergence tolerance on the scaled residual.

    Returns:
      ``(U, fail)``. ``fail`` is 1 if any of these hold:
      - the last update contains NaNs;
      - the last update has a nonzero imaginary part;
      - the residual never reached ``tol``.
      Otherwise ``fail`` is 0.
    """
    maxit = 20
    count = 0
    res = 100  # sentinel so the loop runs at least once (when tol < 100)
    fail = 0
    Uold = U
    # Bound up front so the checks below work even if the loop never runs.
    delta = np.zeros_like(U)

    while count < maxit and res > tol:
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        # Scale the residual by the iterate. The former measure
        # max|y| / ||y||_2 always lies in [1/sqrt(len(y)), 1] for nonzero y,
        # so it could never drop below a small tol.
        res = norm(y / norm(U, np.inf), np.inf)
        print(count, res)
        delta = solve(J, y)
        U = U - delta
        count = count + 1

    # Report the first failure condition found; np.any works for any rank,
    # unlike the builtin max over an array.
    if np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")
    elif np.any(np.imag(delta) != 0):
        fail = 1
        print("solution complex")
    elif res > tol:
        fail = 1
        print('Newton fail: no convergence')

    return U, fail
Example #16
0
    def body_f(iterp: _CGSteihaugState) -> _CGSteihaugState:
        """Run one CG-Steihaug iteration and return the next state.

        The CG iterate ``z``, residual ``r`` and direction ``d`` are always
        advanced. Which terminal handler ``lax.switch`` picks only affects
        the ``step``, ``hits_boundary`` and ``converged`` fields of the
        returned state.

        Free variables from the enclosing scope: ``hessvp``, ``_dot``
        (presumably a dot product), ``noop``, ``step1``, ``step2``,
        ``step3``, ``norm``, ``trust_radius`` and ``tolerance``.
        """

        # do an iteration: curvature of the model along the current direction
        Bd = hessvp(iterp.d)
        dBd = _dot(iterp.d, Bd)

        # CG step length and candidate iterate
        r_squared = _dot(iterp.r, iterp.r)
        alpha = r_squared / dBd
        z_next = iterp.z + alpha * iterp.d

        # updated residual
        r_next = iterp.r + alpha * Bd
        r_next_squared = _dot(r_next, r_next)

        # include a junk switch to catch the case where none should be executed
        # argmax returns the first True entry, or index 0 (noop) when none is.
        # The other branches are:
        #   1 (step1): non-positive curvature, d^T B d <= 0
        #   2 (step2): the candidate leaves the trust region
        #   3 (step3): the residual norm is below tolerance
        index = jnp.argmax(
            jnp.array([
                False, dBd <= 0,
                jnpla.norm(z_next, ord=norm) >= trust_radius,
                jnp.sqrt(r_next_squared) < tolerance
            ]))
        result = lax.switch(index, [noop, step1, step2, step3],
                            (iterp, z_next))

        # update the state for the next iteration (conjugate direction)
        beta_next = r_next_squared / r_squared
        d_next = -r_next + beta_next * iterp.d

        state = _CGSteihaugState(z=z_next,
                                 r=r_next,
                                 d=d_next,
                                 step=result.step,
                                 hits_boundary=result.hits_boundary,
                                 converged=result.converged)
        return state
Example #17
0
def _dihedral_angle(p1, p2, p3, p4):
    q = p3 - p2
    r = np.cross(p2 - p1, q)
    s = np.cross(q, p4 - p3)
    return np.arctan2(np.dot(np.cross(r, s), q), np.dot(r, s) * linalg.norm(q))
Example #18
0
def _angle(p1, p2, p3):
    q = p1 - p2
    r = p3 - p2
    return np.arctan2(linalg.norm(np.cross(q, r)), np.dot(q, r))
Example #19
0
def distance(r1, r2):
    """Return ``linalg.norm(r1 - r2)``, i.e. the Euclidean distance for 1-D points."""
    displacement = r1 - r2
    return linalg.norm(displacement)
Example #20
0
 def cond_fun(val):
     """Loop predicate for the carry ``(U, count, res, fail)``.

     Continues while the carried residual ``res`` exceeds ``tol`` and fewer
     than ``maxit`` iterations have been taken. ``tol``, ``maxit`` and ``np``
     come from the enclosing scope.
     """
     _, count, res, _ = val
     # Test the residual stored in the carry. Recomputing it here from ``y``,
     # a name outside the carry, pairs a stale residual with the current U.
     print("res:", res)
     return np.logical_and(res > tol, count < maxit)
Example #21
0
def cosine_similarity(a, b, norm_axis=1):
    """Pairwise cosine similarity between the rows of ``a`` and ``b``.

    Each input is divided by its L2 norms taken along ``norm_axis``. Those
    norms are reshaped to a column ``(-1, 1)``, which lines up with the rows
    for 2-D inputs and the default ``norm_axis=1``. NOTE(review): other axes
    do not line up with that reshape; confirm before using them.

    Returns:
      ``unit_a @ unit_b.T``; entry ``[i, j]`` compares ``a[i]`` with ``b[j]``.
    """
    def _unit_rows(m):
        lengths = LA.norm(m, ord=2, axis=norm_axis)
        return m / lengths.reshape(-1, 1)

    return _unit_rows(a) @ _unit_rows(b).T
Example #22
0
File: tt-svd.py Project: fasghq/TT
'''

# Test tensor with A[i, j, q] = 1 / (i + j + q + 3) for i, j, q in [0, 100).
# Built with one vectorized call instead of a 10^6-iteration Python loop.
# The values are identical: integer sums are exact in float64.
A = np.fromfunction(lambda i, j, q: 1 / (i + j + q + 3), (100, 100, 100))

d = len(A.shape)
N = np.size(A)
n = A.shape

eps = 1e-12 # accuracy
delta = (eps/math.sqrt(d-1)) * la.norm(A) # cutting param

C = A # tmp tensor

G = [] # tt-cores
r = [] # tt-ranks
r.append(1)

for k in range(1, d):
  C = np.reshape(C, (r[k-1] * n[k-1], int(N / (r[k-1] * n[k-1]))))
  
  # calc low-rank approximation
  u, s, v = la.svd(C)
  sum = 0 
  nsize = np.size(s)
  rres = np.size(s)
Example #23
0
def norm(x, ord=None, axis=None, keepdims=False):
    """``linalg.norm`` that also accepts ``JaxArray`` inputs.

    A ``JaxArray`` argument is unwrapped via ``.value`` first. When ``axis``
    is given, the result is wrapped in a ``JaxArray``; with ``axis=None`` the
    raw result is returned.
    """
    raw = x.value if isinstance(x, JaxArray) else x
    result = linalg.norm(raw, ord=ord, axis=axis, keepdims=keepdims)
    if axis is None:
        return result
    return JaxArray(result)
Example #24
0
def normalize(q):
    """Normalize a quaternion to unit length along its last axis.

    Zero-length input is not guarded against.
    """
    magnitude = linalg.norm(q, axis=-1, keepdims=True)
    return q / magnitude
Example #25
0
def minimize_trust_region(
    fun: Callable,
    x0: jnp.ndarray,
    maxiter: Optional[int] = None,
    norm=jnp.inf,
    gtol: float = 1e-5,
    max_trust_radius: Union[float, jnp.ndarray] = 1000.,
    initial_trust_radius: Union[float, jnp.ndarray] = 1.0,
    eta: Union[float, jnp.ndarray] = 0.15,
    method="trust-ncg",
) -> _TrustRegionResults:
    """Minimize ``fun`` from ``x0`` with a trust-region method.

    Only ``method="trust-ncg"`` is supported. It uses
    ``CGSteihaugSubproblem`` with Hessian-vector products.

    Args:
      fun: scalar objective.
      x0: initial point.
      maxiter: iteration cap; defaults to ``200 * size(x0)``.
      norm: ``ord`` for ``jnpla.norm`` of gradients; also forwarded to the
        subproblem solver.
      gtol: stop once the gradient norm is below this value.
      max_trust_radius: upper bound on the trust radius.
      initial_trust_radius: starting radius, in ``(0, max_trust_radius)``.
      eta: step acceptance threshold on the actual/predicted decrease
        ratio, in ``[0, 0.25)``.
      method: subproblem solver name.

    Returns:
      The final ``_TrustRegionResults``. Its ``status`` is:
      - 0: converged;
      - 1: max iterations reached;
      - 2: poor model approximation;
      - -1: undefined.

    Raises:
      ValueError: for invalid hyper-parameters or an unknown ``method``.
    """
    # All argument errors are ValueError (a subclass of Exception, so callers
    # that caught Exception still catch them).
    if not (0 <= eta < 0.25):
        raise ValueError("invalid acceptance stringency")
    if gtol < 0.:
        raise ValueError("gradient tolerance must be non-negative")
    if max_trust_radius <= 0:
        raise ValueError("max trust radius must be positive")
    if initial_trust_radius <= 0:
        raise ValueError("initial trust radius must be positive")
    if initial_trust_radius >= max_trust_radius:
        raise ValueError(
            "initial trust radius must be less than the max trust radius")

    if method == "trust-ncg":
        subp = CGSteihaugSubproblem
    else:
        raise ValueError("Method {} not recognized".format(method))
    if maxiter is None:
        maxiter = jnp.size(x0) * 200

    vg_f = value_and_grad(fun)
    g_f = grad(fun)
    f_0, g_0 = vg_f(x0)
    g_0_mag = jnpla.norm(g_0, ord=norm)

    init_params = _TrustRegionResults(
        # If x0 already meets gtol, report convergence without taking a step.
        converged=g_0_mag < gtol,
        good_approx=jnp.isfinite(g_0_mag),
        k=1,
        x_k=x0,
        f_k=f_0,
        g_k=g_0,
        g_k_mag=g_0_mag,
        nfev=1,
        ngev=1,
        trust_radius=initial_trust_radius,
        status=0)

    # function to generate the hessian vector product function
    def _hvp(g_f, primals, tangents):
        """Hessian-vector product: the JVP of the gradient ``g_f`` at ``primals``."""
        return jvp(g_f, (primals, ), (tangents, ))[1]

    # condition for the main trust region optimization loop
    def _trust_region_cond_f(params: _TrustRegionResults) -> bool:
        """Continue until converged, out of iterations, or the model is poor."""
        return (jnp.logical_not(params.converged)
                & (params.k < maxiter)
                & params.good_approx)

    # function to take a constrained gradient step or adjust trust region size for next iteration
    def _trust_region_body_f(
            params: _TrustRegionResults) -> _TrustRegionResults:
        """Run one trust-region iteration; see comments for each stage."""
        # compute a new hessian vector product function given the current state
        hessvp = partial(_hvp, g_f, params.x_k)

        # we should add an internal success check for future subp approaches that might not be solvable
        # (e.g., non-PSD hessian)
        result = subp(params.f_k,
                      params.g_k,
                      params.g_k_mag,
                      hessvp,
                      params.trust_radius,
                      norm=norm)

        pred_f_kp1 = result.pred_f
        x_kp1 = params.x_k + result.step
        f_kp1, g_kp1 = vg_f(x_kp1)

        # actual vs. model-predicted decrease
        delta = params.f_k - f_kp1
        pred_delta = params.f_k - pred_f_kp1

        # update the trust radius according to the actual/predicted ratio:
        # rho < 0.25 shrinks it 4x; rho > 0.75 with a boundary step doubles
        # it, capped at max_trust_radius. `where` avoids branching.
        rho = delta / pred_delta
        tr = params.trust_radius
        cur_tradius = jnp.where(rho < 0.25, tr * 0.25, tr)
        cur_tradius = jnp.where((rho > 0.75) & result.hits_boundary,
                                jnp.minimum(2. * tr, max_trust_radius),
                                cur_tradius)

        # compute norm to check for convergence
        g_kp1_mag = jnpla.norm(g_kp1, ord=norm)

        # if the ratio is high enough then accept the proposed step
        # repeated check to skirt using cond/branching
        f_kp1 = jnp.where(rho > eta, f_kp1, params.f_k)
        x_kp1 = jnp.where(rho > eta, x_kp1, params.x_k)
        g_kp1 = jnp.where(rho > eta, g_kp1, params.g_k)
        g_kp1_mag = jnp.where(rho > eta, g_kp1_mag, params.g_k_mag)

        # the extra +1 in nfev/ngev counts the vg_f call above
        iter_params = _TrustRegionResults(converged=g_kp1_mag < gtol,
                                          good_approx=pred_delta > 0,
                                          k=params.k + 1,
                                          x_k=x_kp1,
                                          f_k=f_kp1,
                                          g_k=g_kp1,
                                          g_k_mag=g_kp1_mag,
                                          nfev=params.nfev + result.nfev + 1,
                                          ngev=params.ngev + result.ngev + 1,
                                          trust_radius=cur_tradius,
                                          status=params.status)

        return iter_params

    state = lax.while_loop(_trust_region_cond_f, _trust_region_body_f,
                           init_params)
    status = jnp.where(
        state.converged,
        0,  # converged
        jnp.where(
            state.k == maxiter,
            1,  # max iters reached
            jnp.where(
                state.good_approx,
                -1,  # undefined
                2,  # poor approx
            )))
    state = state._replace(status=status)

    return state
Example #26
0
def norm(q):
    """Euclidean norm of ``q`` over its last axis, keeping that axis as size 1."""
    last_axis = -1
    return linalg.norm(q, axis=last_axis, keepdims=True)