示例#1
0
    def trainsplit(self, ntrain=1000, tt=4000):
        """Build real-valued input rows and split them into training/validation sets.

        The selected real entries (index pair ``self.rnzl``) and imaginary
        entries (index pair ``self.inzl``) of every slice of ``self.denMO`` are
        laid side by side. Rows ``2 .. tt + 1`` are then kept. The first
        ``ntrain`` of those rows form the training set and the remainder the
        validation set. Time stamps are spaced by ``self.dt``.

        The field commutator ``self.eftraincommuteMOflat`` is stored as
        ``[real | imag]`` columns in ``self.hpcommute_train``. Its rows
        ``1 .. ntrain - 2`` are stored in ``self.hpcommute_train_loss``.

        Args:
            ntrain: number of leading rows used for training.
            tt: number of rows kept after the offset.

        Returns:
            True, to signal that the method ran to completion.
        """
        # denMO is indexed as [:, i, j], so it is at least 3-D; presumably
        # (time, orbital, orbital) -- TODO confirm against the caller.
        self.x_inp = np.hstack([
            np.real(self.denMO)[:, self.rnzl[0], self.rnzl[1]],
            np.imag(self.denMO)[:, self.inzl[0], self.inzl[1]],
        ])

        self.offset, self.tt, self.ntrain = 2, tt, ntrain
        first_row = self.offset
        self.x_inp = self.x_inp[first_row:first_row + self.tt, :]

        self.dt = 0.08268
        n_rows = self.x_inp.shape[0]
        self.tint_whole = np.arange(n_rows) * self.dt

        # Leading rows -> training, remaining rows -> validation.
        self.x_inp_train, self.x_inp_valid = self.x_inp[:ntrain, :], self.x_inp[ntrain:, :]
        self.tint, self.tint_valid = self.tint_whole[:ntrain], self.tint_whole[ntrain:]

        # Field-commutator terms, split into real and imaginary column blocks.
        commutator = self.eftraincommuteMOflat
        self.hpcommute_train = np.hstack([np.real(commutator), np.imag(commutator)])
        # Drop the first row and keep rows up to (excluding) ntrain - 1.
        self.hpcommute_train_loss = self.hpcommute_train[1:self.ntrain - 1, :]

        return True
示例#2
0
def pinv(model: SpectralSobolev1Fit):
    """Return the pseudo-inverse of the stacked value/gradient design matrix.

    The value rows hold a constant column of ones followed by the Vandermonde
    columns evaluated on ``model.mesh``. The gradient rows hold a zero column
    (the constant has no gradient) followed by the gradient-Vandermonde
    columns. For periodic models the complex columns are split into real and
    imaginary parts.
    """
    exponents = model.exponents
    values = vander_builder(model.grid, exponents)(model.mesh)
    grads = vandergrad_builder(model.grid, exponents)(model.mesh)
    ones_col = np.ones((np.size(values, 0), 1))
    zeros_col = np.zeros((np.size(grads, 0), 1))

    if model.is_periodic:
        # NOTE(review): value columns are ordered (real, imag) while gradient
        # columns are ordered (imag, real); presumably this matches the
        # convention of vandergrad_builder -- confirm the pairing is intended.
        value_parts = (ones_col, np.real(values), np.imag(values))
        grad_parts = (zeros_col, np.imag(grads), np.real(grads))
    else:
        value_parts = (ones_col, values)
        grad_parts = (zeros_col, grads)

    design = np.vstack((np.hstack(value_parts), np.hstack(grad_parts)))
    return np.linalg.pinv(design)
示例#3
0
def pinv(model: SpectralGradientFit):
    """Return the pseudo-inverse of the gradient-Vandermonde matrix on ``model.mesh``.

    For periodic models the complex matrix is replaced by the horizontal
    stack ``[imag | real]`` before inversion.
    """
    grad_matrix = vandergrad_builder(model.grid, model.exponents)(model.mesh)
    if not model.is_periodic:
        return np.linalg.pinv(grad_matrix)
    return np.linalg.pinv(np.hstack((np.imag(grad_matrix), np.real(grad_matrix))))
示例#4
0
def newton(fn, jac_fn, U, maxit=5, tol=1e-8):
    """Solve ``fn(U, Uold) = 0`` for ``U`` with a plain Newton iteration.

    ``Uold`` is frozen at the initial guess and passed to both callbacks on
    every call. The residual reported and tested against ``tol`` is
    ``norm(y / norm(U, inf), inf)``, where ``y`` is evaluated at the iterate
    *before* that iteration's update.

    Args:
        fn: residual function, called as ``fn(U, Uold)``.
        jac_fn: Jacobian of ``fn``, called as ``jac_fn(U, Uold)`` and passed
            to ``solve`` (presumably w.r.t. the first argument -- confirm).
        U: initial guess.
        maxit: maximum total number of Newton steps, including the first one
            taken before the loop. The default 5 is the value the previous
            body actually used; an earlier ``maxit=20`` was overwritten.
        tol: convergence tolerance on the scaled residual.

    Returns:
        ``(U, fail)``. ``fail`` is 1 when the last step contains NaN, has a
        non-zero imaginary part, or the residual never dropped to ``tol``.
        Otherwise it is 0.
    """
    Uold = U

    # The first step is taken and timed separately from the loop.
    start = timeit.default_timer()
    J = jac_fn(U, Uold)
    y = fn(U, Uold)
    res0 = norm(y / norm(U, np.inf), np.inf)
    delta = solve(J, y)
    U = U - delta
    count = 1
    end = timeit.default_timer()
    print("time elapsed in first loop", end - start)
    print(count, res0)

    # Unknown residual until the loop measures one; forces at least one check.
    res = np.inf
    while count < maxit and res > tol:
        start1 = timeit.default_timer()
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        res = norm(y / norm(U, np.inf), np.inf)
        delta = solve(J, y)
        U = U - delta
        count += 1
        end1 = timeit.default_timer()
        print("time per loop", end1 - start1)
        print(count, res)

    # Diagnose the final step. The checks are exclusive, first match wins.
    fail = 0
    if np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")
    elif np.any(np.imag(delta) != 0):
        fail = 1
        print("solution complex")
    elif res > tol:
        fail = 1
        print('Newton fail: no convergence')

    return U, fail
示例#5
0
def loss(params):
    """Spectral loss plus a firing-rate penalty for the model run at fixed contrasts.

    Runs ``ssn_PS`` at contrasts ``[0, 25, 50, 100]``. If it reports
    non-convergence, returns ``np.inf``. Otherwise returns the sum of:

    - the spectral loss over frequencies ``fs > 20``;
    - a rate penalty on ``r_fp[:, 1:]`` (the non-zero contrasts), bounded
      below by -5 and above by 70 / 100 for the two rows.

    Prints a warning when the spectrum's imaginary part exceeds 0.01 in
    magnitude.
    """
    contrasts = np.array([0, 25, 50, 100])
    spect, fs, f0, r_fp, CONVG, _ = ssn_PS(params, contrasts)

    if not CONVG:
        return np.inf

    if np.max(np.abs(np.imag(spect))) > 0.01:
        print("Spectrum is dangerously imaginary")

    cons = len(contrasts)
    lower_bound_rates = -5 * np.ones([2, cons - 1])
    upper_bound_rates = np.vstack((70 * np.ones(cons - 1), 100 * np.ones(cons - 1)))
    # How quickly log(1 + exp(x)) goes to ~x, where x = target_rates - found_rates.
    kink_control = 1
    prefact_rates = 1

    # Integer indices of frequencies above 20. flatnonzero keeps an integer
    # dtype even when nothing matches; an empty np.array([]) would be float
    # and make the fancy indexing below raise IndexError.
    fs_loss_inds = np.flatnonzero(fs > 20)
    spect_loss = losses.loss_spect_nonzero_contrasts(fs[fs_loss_inds], spect[fs_loss_inds, :])

    rates_loss = prefact_rates * losses.loss_rates_contrasts(
        r_fp[:, 1:], lower_bound_rates, upper_bound_rates, kink_control)
    return spect_loss + rates_loss
示例#6
0
def newton_while_lax(fn, jac_fn, U, maxit, tol):
    """Solve ``fn(U, Uold) = 0`` with Newton's method, looping via ``lax.while_loop``.

    One Newton step is taken eagerly. Up to ``maxit`` further steps then run
    inside ``lax.while_loop`` while the scaled residual
    ``norm(y / norm(U_new, inf), inf)`` exceeds ``tol``. Here ``y`` is the
    residual at the iterate before the step and ``U_new`` the updated iterate.
    ``Uold`` is frozen at the initial guess.

    The loop carry is ``(U, count, res, delta)``. The post-loop checks
    therefore see the *last* step and residual, and the eager first step is
    kept. Printing was removed from the loop functions: inside
    ``lax.while_loop`` they run only once, at trace time.

    Args:
        fn: residual function, called as ``fn(U, Uold)``; must be traceable by jax.
        jac_fn: Jacobian callback, called as ``jac_fn(U, Uold)``; must be traceable.
        U: initial guess.
        maxit: maximum number of loop iterations after the first step.
        tol: convergence tolerance on the scaled residual.

    Returns:
        ``(U, fail)``. ``fail`` is 1 if the last step has NaNs, a non-zero
        imaginary part, or the residual is still above ``tol``. Otherwise 0.
    """
    Uold = U

    def _step(U):
        # One Newton update; returns (new iterate, step taken, scaled residual).
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        delta = solve(J, y)
        U_new = U - delta
        res = norm(y / norm(U_new, np.inf), np.inf)
        return U_new, delta, res

    U, delta, res = _step(U)

    def cond_fun(val):
        _, count, res, _ = val
        # `&` works for both numpy values and jax tracers.
        return (res > tol) & (count < maxit)

    def body_fun(val):
        U, count, _, _ = val
        U, delta, res = _step(U)
        return U, count + 1, res, delta

    U, _, res, delta = lax.while_loop(cond_fun, body_fun, (U, 0, res, delta))

    # Diagnose the final step. The checks are exclusive, first match wins.
    fail = 0
    if np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")
    elif np.any(np.imag(delta) != 0):
        fail = 1
        print("solution complex")
    elif res > tol:
        fail = 1
        print('Newton fail: no convergence')

    return U, fail
示例#7
0
文件: random_test.py 项目: zizai/jax
  def testNormalComplex(self, dtype):
    """Real and imaginary parts of complex normal draws follow N(0, 1/2), jitted or not."""
    key = random.PRNGKey(0)

    def sample(k):
      return random.normal(k, (10000,), dtype)

    eager_draws = sample(key)
    jitted_draws = api.jit(sample)(key)

    # Each component is checked against a normal with std 1/sqrt(2).
    for draws in (eager_draws, jitted_draws):
      for component in (jnp.real, jnp.imag):
        self._CheckKolmogorovSmirnovCDF(
            component(draws), scipy.stats.norm(scale=1/np.sqrt(2)).cdf)
示例#8
0
def damped_newton(fn, jac_fn, U, maxit=10, tol=1e-8):
    """Solve ``fn(U, Uold) = 0`` with Newton's method.

    Despite the name, no damping or line search is currently applied. Every
    update is the full Newton step ``U - delta``.

    One step is taken before the loop. The loop then runs at most ``maxit``
    more steps while the scaled residual ``norm(y / norm(U, inf), inf)`` is
    above ``tol``. ``y`` is evaluated at the iterate before each update.
    ``Uold`` is frozen at the initial guess.

    Args:
        fn: residual function, called as ``fn(U, Uold)``.
        jac_fn: Jacobian callback, called as ``jac_fn(U, Uold)`` and passed to ``solve``.
        U: initial guess.
        maxit: maximum number of loop iterations after the first step (default 10).
        tol: convergence tolerance on the scaled residual (default 1e-8).

    Returns:
        ``(U, fail)``. ``fail`` is 1 if the last step has NaNs, a non-zero
        imaginary part, or the residual never dropped to ``tol``. Otherwise 0.
    """
    count = 0
    Uold = U

    J = jac_fn(U, Uold)
    y = fn(U, Uold)
    delta = solve(J, y)
    U = U - delta
    res0 = norm(y / norm(U, np.inf), np.inf)
    print(count, res0)

    # Unknown residual until the loop measures one; forces at least one check.
    res = np.inf
    while count < maxit and res > tol:
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        res = norm(y / norm(U, np.inf), np.inf)
        print(count, res)
        delta = solve(J, y)
        U = U - delta
        count += 1

    # Diagnose the final step. The checks are exclusive, first match wins.
    fail = 0
    if np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")
    elif np.any(np.imag(delta) != 0):
        fail = 1
        print("solution complex")
    elif res > tol:
        fail = 1
        print('Newton fail: no convergence')

    return U, fail
示例#9
0
    def __call__(self, x):
        """Forward pass with a second stream fed the negated input ``-x``.

        Both streams pass through ``dense_symm`` and the equivariant layers.
        At every layer each stream mixes with the other, and each sum is
        scaled by 1/sqrt(2). The two streams are concatenated on the last
        axis and reduced with ``logsumexp`` (or ``logsumexp_cplx`` when
        ``complex_output``) over the last two axes.

        The reduction weights are the characters followed by either the
        characters again (``parity == 1``) or their negation. With
        ``equal_amplitudes`` only ``1j * imag`` of the result is returned.
        """
        # Inputs with fewer than 3 dims get a singleton feature axis at -2.
        if x.ndim < 3:
            x = jnp.expand_dims(x, -2)

        x_flip = self.dense_symm(-1 * x)
        x = self.dense_symm(x)

        for layer in range(self.layers - 1):
            x = self.activation(x)
            x_flip = self.activation(x_flip)
            same = self.equivariant_layers[layer]
            cross = self.equivariant_layers_flip[layer]
            # Both right-hand sides use the pre-update streams.
            x, x_flip = (
                jnp.array((same(x) + cross(x_flip)) / np.sqrt(2), copy=True),
                (same(x_flip) + cross(x)) / np.sqrt(2),
            )

        x = jnp.concatenate((x, x_flip), -1)
        x = self.output_activation(x)

        chars = jnp.array(self.characters)
        partner = chars if self.parity == 1 else -1 * chars
        par_chars = jnp.expand_dims(jnp.concatenate((chars, partner), 0), (0, 1))

        reduce = logsumexp_cplx if self.complex_output else logsumexp
        x = reduce(x, axis=(-2, -1), b=par_chars)

        # Keeping only the imaginary part presumably fixes all log-amplitudes
        # to zero real part -- confirm with the model's conventions.
        return 1j * jnp.imag(x) if self.equal_amplitudes else x
 def compute_expectations(self, mean, cov):
     """
     Return Gaussian moments for x ~ N(mean, cov), as a list of four entries.

     1. ``mean``;
     2. the elementwise second moment ``diag(cov) + mean**2``;
     3. and 4. the real and imaginary parts of the characteristic function
        ``exp(i t.mean - t.cov.t / 2)`` at ``t = [1, ..., 1]`` (length
        ``self.d``), i.e. E[cos(sum(x))] and E[sin(sum(x))].

     These equal elementwise E[cos(x)] and E[sin(x)] only when d == 1.
     NOTE(review): earlier docs said "cos(x), sin(x)"; confirm whether
     elementwise moments were intended for d > 1.
     """
     ones = np.ones(self.d)
     exponent = np.vdot(ones, (1j * mean - np.dot(cov, ones) / 2))
     char = np.exp(exponent)
     second_moment = np.diagonal(cov) + mean**2
     return [mean, second_moment, np.real(char), np.imag(char)]
示例#11
0
    def fun_on_leaf(_z):
        """Return the real part of the (presumably scalar) leaf ``_z`` after validation.

        - NaN leaves raise ``ValueError`` unless ``ignore_nan`` (a variable of
          the enclosing scope) is true, in which case their real part is
          returned.
        - Unless ``ignore_im_part`` is true, raises ``ValueError`` when the
          imaginary part is not negligible under ``rtol``/``atol``.
        """
        if np.isnan(_z):
            if ignore_nan:
                return np.real(_z)
            raise ValueError('NaN encountered')

        real_part = np.real(_z)
        if ignore_im_part:
            return real_part

        # allclose(a, b) checks |a - b| <= atol + rtol * |b|. The leaf is
        # accepted when |imag| <= atol + rtol * |real + imag|.
        if np.allclose(real_part, real_part + np.imag(_z), rtol=rtol, atol=atol):
            return real_part
        raise ValueError(
            'Significant imaginary part encountered where it was not expected'
        )
示例#12
0
    def __call__(self, x):
        """Equivariant forward pass followed by a character-weighted logsumexp.

        The character weights get two leading singleton axes and the
        reduction runs over the last two axes. With ``equal_amplitudes`` only
        ``1j * imag`` of the result is returned.
        """
        x = self.dense_symm(x)
        for layer_idx in range(self.layers - 1):
            x = self.equivariant_layers[layer_idx](self.activation(x))
        x = self.output_activation(x)

        weights = jnp.expand_dims(jnp.asarray(self.characters), (0, 1))
        x = logsumexp(x, axis=(-2, -1), b=weights)

        return 1j * jnp.imag(x) if self.equal_amplitudes else x
示例#13
0
def newton(fn, jac_fn, U, maxit=20, tol=1e-8):
    """Solve ``fn(U, Uold) = 0`` for ``U`` with a plain Newton iteration.

    ``Uold`` is frozen at the initial guess and passed to both callbacks. The
    residual tested against ``tol`` is ``norm(y / norm(U, inf), inf)``, where
    ``y`` is evaluated at the iterate before that iteration's update.

    Args:
        fn: residual function, called as ``fn(U, Uold)``.
        jac_fn: Jacobian callback, called as ``jac_fn(U, Uold)`` and passed to ``solve``.
        U: initial guess.
        maxit: maximum total number of Newton steps, including the first one
            taken before the loop (default 20).
        tol: convergence tolerance on the scaled residual (default 1e-8).

    Returns:
        ``(U, fail)``. ``fail`` is 1 if the last step has NaNs, a non-zero
        imaginary part, or the residual never dropped to ``tol``. Otherwise 0.
    """
    Uold = U

    # The first step is taken and timed separately from the loop.
    start = timeit.default_timer()
    J = jac_fn(U, Uold)
    y = fn(U, Uold)
    res0 = norm(y / norm(U, np.inf), np.inf)
    delta = solve(J, y)
    U = U - delta
    count = 1
    end = timeit.default_timer()
    print("time elapsed in first loop", end - start)
    print(count, res0)

    # Unknown residual until the loop measures one; forces at least one check.
    res = np.inf
    while count < maxit and res > tol:
        start1 = timeit.default_timer()
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        res = norm(y / norm(U, np.inf), np.inf)
        delta = solve(J, y)
        U = U - delta
        count += 1
        end1 = timeit.default_timer()
        print(count, res)
        print("time per loop", end1 - start1)

    # Diagnose the final step. The checks are exclusive, first match wins.
    fail = 0
    if np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")
    elif np.any(np.imag(delta) != 0):
        fail = 1
        print("solution complex")
    elif res > tol:
        fail = 1
        print('Newton fail: no convergence')

    return U, fail
示例#14
0
    def _apply(
        self,
        iter: jnp.ndarray,
        grad: jnp.ndarray,
        state: Tuple[jnp.ndarray],
        param: jnp.ndarray,
        precond: Union[None, jnp.ndarray],
        use_precond=False
    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray]]:
        """Perform one Riemannian Adam (optionally AMSGrad) step on the manifold.

        Steps:

        1. Convert the conjugated Euclidean gradient to a Riemannian gradient.
        2. Update the first moment and the squared-norm second moment.
        3. Apply bias correction.
        4. Retract ``param`` along the search direction, transporting the
           momentum with it.

        ``precond`` is forwarded to the manifold calls only when
        ``use_precond`` is true.

        Returns:
            ``(param, (momentum, v))``, or ``(param, (momentum, v, v_hat))``
            when ``self.ams``.
        """
        # Optional trailing argument for the manifold operations.
        metric_args = (precond,) if use_precond else ()

        rgrad = self.manifold.egrad_to_rgrad(param, grad.conj(), *metric_args)
        momentum = self.beta1 * state[0] + (1 - self.beta1) * rgrad
        sq_norm = self.manifold.inner(param, rgrad, rgrad, *metric_args)
        v = self.beta2 * state[1] + (1 - self.beta2) * sq_norm

        # Bias correction
        step = iter + 1
        lr_corr = self.learning_rate * jnp.sqrt(1 - self.beta2 ** step) / (1 - self.beta1 ** step)

        if self.ams:
            # Running max of the real part; the imaginary part of the current v is kept.
            v_hat = jax.lax.complex(jnp.maximum(jnp.real(v), jnp.real(state[2])), jnp.imag(v))
            denom = jnp.sqrt(v_hat) + self.eps
        else:
            denom = jnp.sqrt(v) + self.eps

        search_dir = -lr_corr * momentum / denom
        param, momentum = self.manifold.retraction_transport(
            param, momentum, search_dir
        )

        if self.ams:
            return param, (momentum, v, v_hat)
        return param, (momentum, v)
示例#15
0
def sd_dir(grad, eps_iter):
    """Steepest-ascent direction under a per-frequency DFT-domain constraint.

    ``grad`` is transformed with the unitary DFT (``scale='sqrtn'``). Each
    complex coefficient, viewed as a 2-vector (real, imag), is rescaled to
    Euclidean length exactly ``eps_iter``. Unlike a projection, this can
    grow as well as shrink it. The result is transformed back with the
    conjugate-transpose DFT and returned; it may be complex.
    """
    n = grad.shape[0]
    dft = np.matrix(scipy.linalg.dft(n, scale='sqrtn'))
    spectrum = (dft @ grad).reshape(1, -1)

    # Row 0 = real parts, row 1 = imaginary parts; one coefficient per column.
    stacked = jnp.concatenate((jnp.real(spectrum), jnp.imag(spectrum)), axis=0)

    # The floor on the squared norm avoids division by zero.
    col_norm = jnp.sqrt(jnp.maximum(1e-15, jnp.sum(stacked**2, axis=0, keepdims=True)))
    stacked = stacked * eps_iter / col_norm

    rescaled = (stacked[0, :] + 1j * stacked[1, :]).reshape(grad.shape)
    return dft.getH() @ rescaled
示例#16
0
def wofzs2(x, y):
    """Asymptotic approximation of the Faddeeva function w(z), z = x + iy, for large |z|.

    Meant for |z|**2 > 112 (target accuracy quoted as "e = 10e-6").
    See Zaghloul (2018) arxiv:1806.01656.

    Evaluates the truncated series, with a = 1 / (2 z**2):

        w(z) ~ i / (z sqrt(pi)) * (1 + a + 3 a**2 + 15 a**3)

    The polynomial in ``a`` is evaluated in nested (Horner) form.

    Args:
        x: real part of z.
        y: imaginary part of z.

    Returns:
        jnp.array, jnp.array: H=real(wofz(x+iy)), L=imag(wofz(x+iy))
    """
    z = x + y * 1j
    inv_two_z2 = 1.0 / (2.0 * z * z)
    series = 1.0 + inv_two_z2 * (1.0 + inv_two_z2 * (3.0 + inv_two_z2 * 15.0))
    w = (1j) / (z * jnp.sqrt(jnp.pi)) * series
    return jnp.real(w), jnp.imag(w)
def store_eig_vec(evals_small, evecs_small, filename, tol=1e-6):
    """Save the eigenvector of the smallest eigenvalue if it is real up to a global phase.

    The real and imaginary parts of the selected column are normalised
    separately. If the vector equals exp(i*theta) times a real vector, the
    two normalised parts agree up to sign, so ``|<re, im>|`` is 1. A vector
    whose real or imaginary part is exactly zero is already real up to a
    phase.

    When the test passes, the normalised real representative is written to
    ``filename`` (one value per line, ``%.8e``). Otherwise diagnostics are
    printed.

    Args:
        evals_small: eigenvalues; the column of the smallest one is used.
        evecs_small: eigenvectors as columns.
        filename: output path.
        tol: allowed deviation of ``|<re, im>|`` from 1 (default 1e-6).

    NOTE: the complex branch reads a name ``H`` that is not defined here. It
    must exist as a module-level global (presumably the Hamiltonian), or
    that branch raises NameError.
    """
    idx_min = np.argmin(evals_small)
    print("GS energy: %f" % evals_small[idx_min])
    vec_r = np.real(evecs_small[:, idx_min])
    vec_i = np.imag(evecs_small[:, idx_min])
    norm_r = np.linalg.norm(vec_r)
    norm_i = np.linalg.norm(vec_i)

    if norm_i == 0:
        # Already real: normalising the zero imaginary part would give 0/0 = NaN.
        real_vec, deviation = vec_r / norm_r, 0.0
    elif norm_r == 0:
        # Purely imaginary: a global phase of -i makes it real.
        real_vec, deviation = vec_i / norm_i, 0.0
    else:
        vec_r = vec_r / norm_r
        vec_i = vec_i / norm_i
        real_vec = vec_r
        # abs() is essential: |dot| <= 1, so |dot| - 1 alone is never positive.
        deviation = abs(np.abs(vec_r.dot(vec_i)) - 1.)

    if deviation < tol:
        print("Eigen Vec can be casted as real")
        with open(filename, 'wb') as log_file:
            np.savetxt(log_file, real_vec, fmt='%.8e', delimiter=',')
    else:
        print(np.abs(vec_r.dot(vec_i)) - 1.)
        print("Complex Eigen Vec !!!")
        print("The real part <E> : %f " % vec_r.T.dot(H.dot(vec_r)))
        print("The imag part <E> : %f " % vec_i.T.dot(H.dot(vec_i)))
示例#18
0
文件: special.py 项目: jbampton/jax
def _sph_harm(m: jnp.ndarray, n: jnp.ndarray, theta: jnp.ndarray,
              phi: jnp.ndarray, n_max: int) -> jnp.ndarray:
    """Computes the spherical harmonics.

    The harmonic for ``|m|`` is ``P_n^{|m|}(cos(phi)) * exp(i |m| theta)``,
    using the associated Legendre values from ``_gen_associated_legendre``.
    Entries with negative order are replaced by
    ``(-1)**|m| * conj(harmonic)``.
    """
    legendre = _gen_associated_legendre(n_max, jnp.cos(phi), True)
    abs_m = abs(m)
    # One Legendre value per (|m|, n) pair; out-of-range indices are clipped.
    legendre_val = legendre.at[abs_m, n, jnp.arange(len(n))].get(mode="clip")

    angle = abs_m * theta
    harmonics = lax.complex(legendre_val * jnp.cos(angle),
                            legendre_val * jnp.sin(angle))

    # Negative order.
    return jnp.where(m < 0, (-1.0)**abs_m * jnp.conjugate(harmonics), harmonics)
示例#19
0
def norm_projection(delta, norm_type, eps=1.):
  """Projects to a norm-ball centered at 0.

  Args:
    delta: An array of size dim x num containing vectors (columns) to be
      projected. A 1-D array is treated as a single column.
    norm_type: A string denoting the type of the norm-ball:
      - 'linf': elementwise clip to [-eps, eps];
      - 'l2': per-column Euclidean projection;
      - 'l1': via ``l1_unit_projection``;
      - 'dftinf': per-coefficient projection in the unitary DFT domain.
      Any other value returns delta unchanged.
    eps: A float denoting the radius of the norm-ball.

  Returns:
    An array with the shape of delta, the projection of delta to the
    norm-ball.
  """
  shape = delta.shape
  if len(delta.shape) == 1:
    delta = delta.reshape(-1, 1)
  if norm_type == 'linf':
    delta = jnp.clip(delta, -eps, eps)
  elif norm_type == 'l2':
    # Euclidean projection: scale each column by min(1, eps / ||column||).
    avoid_zero_div = 1e-12
    norm2 = jnp.sum(delta**2, axis=0, keepdims=True)
    norm = jnp.sqrt(jnp.maximum(avoid_zero_div, norm2))
    # Only decrease the norm, never increase. The bounds are passed
    # positionally because newer jax deprecates the a_min/a_max keywords.
    delta = delta * jnp.clip(eps / norm, None, 1)
  elif norm_type == 'l1':
    delta = l1_unit_projection(delta / eps) * eps
  elif norm_type == 'dftinf':
    # Transform with the unitary DFT. Each complex coefficient, viewed as a
    # 2-vector (real, imag), is projected onto the disc of radius eps, and
    # the result is transformed back.
    dft = np.matrix(scipy.linalg.dft(delta.shape[0], scale='sqrtn'))
    dftxdelta = dft @ delta
    dftz = dftxdelta.reshape(1, -1)
    dftz = jnp.concatenate((jnp.real(dftz), jnp.imag(dftz)), axis=0)
    dftz = norm_projection(dftz, 'l2', eps)
    dftz = (dftz[0, :] + 1j * dftz[1, :]).reshape(delta.shape)
    delta = dft.getH() @ dftz
    # The projected vector can have an imaginary part; keep the real part.
    delta = jnp.real(delta)
  return delta.reshape(shape)
示例#20
0
    def __call__(self, x):
        """Equivariant forward pass reduced by a character-weighted logsumexp.

        Uses ``logsumexp_cplx`` when ``complex_output``, else ``logsumexp``,
        over the last two axes. With ``equal_amplitudes`` only
        ``1j * imag`` of the result is returned.
        """
        # Inputs with fewer than 3 dims get a singleton feature axis at -2.
        if x.ndim < 3:
            x = jnp.expand_dims(x, -2)

        x = self.dense_symm(x)
        for idx in range(self.layers - 1):
            x = self.equivariant_layers[idx](self.activation(x))
        x = self.output_activation(x)

        pool = logsumexp_cplx if self.complex_output else logsumexp
        x = pool(x, axis=(-2, -1), b=jnp.asarray(self.characters))

        return 1j * jnp.imag(x) if self.equal_amplitudes else x
示例#21
0
    def trainsplit(self, ntrain=1000, tt=4000):
        """Build real-valued input rows and split them into training/validation sets.

        The selected real entries (``self.rnzl``) and imaginary entries
        (``self.inzl``) of every slice of ``self.denMO`` are laid side by
        side. Rows ``2 .. tt + 1`` are then kept. The first ``ntrain`` rows
        train and the rest validate. Time stamps are spaced by ``self.dt``.

        Args:
            ntrain: number of leading rows used for training.
            tt: number of rows kept after the offset.

        Returns:
            True, to signal that the method ran to completion.
        """
        # denMO is indexed as [:, i, j]; presumably (time, orbital, orbital) -- TODO confirm.
        real_cols = np.real(self.denMO)[:, self.rnzl[0], self.rnzl[1]]
        imag_cols = np.imag(self.denMO)[:, self.inzl[0], self.inzl[1]]
        self.x_inp = np.hstack([real_cols, imag_cols])

        self.offset = 2
        self.tt = tt
        self.ntrain = ntrain
        lo = self.offset
        hi = self.tt + self.offset
        self.x_inp = self.x_inp[lo:hi, :]

        self.dt = 0.08268
        self.tint_whole = self.dt * np.arange(self.x_inp.shape[0])

        self.x_inp_train, self.tint = self.x_inp[:ntrain, :], self.tint_whole[:ntrain]
        self.x_inp_valid, self.tint_valid = self.x_inp[ntrain:, :], self.tint_whole[ntrain:]

        return True
示例#22
0
def newton_tol(fn, jac_fn, U, tol, maxit=20):
    """Newton iteration for ``fn(U, Uold) = 0`` with a caller-supplied tolerance.

    ``Uold`` is frozen at the initial guess. The loop stops after ``maxit``
    steps or once the scaled residual ``norm(y / norm(U, inf), inf)`` is at
    most ``tol``. ``y`` is evaluated at the iterate before each update.

    A residual normalised by its own norm, ``max(abs(y / norm(y, 2)))``,
    always lies in [1/sqrt(n), 1] however small ``y`` is. It therefore
    cannot signal convergence, and the scaled residual above is used
    instead.

    Args:
        fn: residual function, called as ``fn(U, Uold)``.
        jac_fn: Jacobian callback, called as ``jac_fn(U, Uold)`` and passed to ``solve``.
        U: initial guess.
        tol: convergence tolerance on the scaled residual.
        maxit: maximum number of Newton steps (default 20).

    Returns:
        ``(U, fail)``. ``fail`` is 1 if the last step has NaNs, a non-zero
        imaginary part, or the residual never dropped to ``tol``. Otherwise 0.
    """
    count = 0
    res = np.inf  # unknown until measured; guarantees the loop body runs
    Uold = U
    delta = np.zeros_like(U)  # defined even if no step is taken (maxit <= 0)

    while count < maxit and res > tol:
        J = jac_fn(U, Uold)
        y = fn(U, Uold)
        res = norm(y / norm(U, np.inf), np.inf)
        print(count, res)
        delta = solve(J, y)
        U = U - delta
        count += 1

    # Diagnose the final step. The checks are exclusive, first match wins.
    fail = 0
    if np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")
    elif np.any(np.imag(delta) != 0):
        fail = 1
        print("solution complex")
    elif res > tol:
        fail = 1
        print('Newton fail: no convergence')

    return U, fail
 def compute_expectations(self, means, covs, weights):
     """
     Return moments of a Gaussian mixture as a list of four squeezed arrays.

     Component i has mean ``means[i]``, covariance ``covs[i]`` and weight
     ``weights[i]``. The entries are:

     1. the mixture mean ``sum_i w_i mean_i``;
     2. the elementwise second moment ``sum_i w_i (diag(cov_i) + mean_i**2)``;
     3. and 4. the real and imaginary parts of the weighted characteristic
        function at ``t = [1, ..., 1]`` (length ``self.d``).

     For d > 1, entries 3 and 4 are E[cos(sum(x))] and E[sin(sum(x))],
     not elementwise cos/sin.
     """
     ones = np.ones(self.d)
     # One scalar characteristic-function value per component: shape (k,).
     chars = np.array([
         np.exp(np.vdot(ones, (1j * mu - np.dot(sigma, ones) / 2)))
         for mu, sigma in zip(means, covs)
     ])
     char = np.vdot(weights, chars)
     second_moments = [np.diagonal(sigma) + mu**2 for mu, sigma in zip(means, covs)]
     expectations = [
         np.einsum("i,id->d", weights, means),
         np.einsum("i,id->d", weights, second_moments),
         np.real(char),
         np.imag(char),
     ]
     return [np.squeeze(e) for e in expectations]
示例#24
0
文件: nlp.py 项目: szhang104/pycomm
 def comp2real(z):
     """Concatenate the real parts of ``z`` followed by its imaginary parts (axis 0)."""
     real_part, imag_part = np.real(z), np.imag(z)
     return np.concatenate((real_part, imag_part))
示例#25
0
def test_expect_herm(oper, state):
    """Check that ``expect(oper, state)`` has an exactly zero imaginary part.

    Only the real-valued (Hermitian) case is asserted. No check is made that
    non-Hermitian operators give complex values. Presumably ``oper`` is
    parametrized with Hermitian operators -- confirm with the fixtures.
    """
    value = expect(oper, state)
    assert jnp.imag(value) == 0.0
示例#26
0
 def f(z):
   """Evaluate ``f_re`` on the real vector ``[real(z), imag(z)]``."""
   stacked = jnp.concatenate([jnp.real(z), jnp.imag(z)])
   return f_re(stacked)
示例#27
0
def transf(x):
    """Real FFT of ``x``, returned as ``stack([real, imag])`` along a new leading axis."""
    # jax.numpy (not numpy) keeps this function traceable.
    spectrum = jnp.fft.rfft(x)
    parts = (jnp.real(spectrum), jnp.imag(spectrum))
    return jnp.stack(parts)
示例#28
0
       hamimags[(i,j)] = cnt
       cnt += 1
"""
# Hamiltonian index structures: copies of the non-zero real/imaginary index
# objects built earlier in the script (nzreals / nzimags are not defined in
# this section -- presumably dicts or arrays of index pairs; verify).
hamreals = nzreals.copy()
hamimags = nzimags.copy()

print('hamreals:')
print(hamreals)
print('hamimags:')
print(hamimags)
# cnt comes from earlier in the script; presumably the number of free
# Hamiltonian parameters (degrees of freedom) -- TODO confirm.
hamdof = cnt
print('hamdof: ', hamdof)

# set up training data
# denMO is indexed as [:, i, j] (at least 3-D); presumably (time, orbital,
# orbital) -- confirm. Selected real entries and selected imaginary entries
# are placed side by side as columns.
x_inp_real = np.real(denMO)[:, rnzl[0], rnzl[1]]
x_inp_imag = np.imag(denMO)[:, inzl[0], inzl[1]]
x_inp = np.hstack([x_inp_real, x_inp_imag])

# Skip the first `offset` rows and keep the next `tt` rows.
offset = 2
tt = 4000
x_inp = x_inp[offset:(tt + offset), :]

# Uniform time grid with spacing dt for the kept rows (units not stated here).
dt = 0.08268
npts = x_inp.shape[0]
tint_whole = np.arange(npts) * dt

# Training-set size depends on the molecule selected earlier in the script.
if mol == 'c2h4':
    ntrain = 2000
else:
    ntrain = 1000
示例#29
0
def flat_gradient(fun, arg):
    """Gradient of the complex-valued ``fun(arg)`` w.r.t. the pytree ``fun``, flattened.

    ``fun`` must be a callable pytree. The gradients of the real and
    imaginary parts of ``fun(arg)`` are taken separately w.r.t. its leaves.
    Each is flattened into one vector, and the two are combined as
    ``grad_real + 1j * grad_imag``.
    """
    def _flat_grad(part):
        # Gradient w.r.t. the first argument (the pytree `fun`), leaves raveled and joined.
        tree_grad = grad(lambda f, a: part(f(a)))(fun, arg)
        leaves = tree_flatten(jax.tree_util.tree_map(lambda leaf: leaf.ravel(), tree_grad))[0]
        return jnp.concatenate(leaves)

    real_grad = _flat_grad(jnp.real)
    imag_grad = _flat_grad(jnp.imag)
    return real_grad + 1.j * imag_grad
示例#30
0
def loss(params):
    """Spectral loss over frequencies above 20 for the model at ``contrasts``.

    ``contrasts`` is a module-level global, not defined in this function.
    Returns ``np.inf`` when ``ssn_PS`` reports non-convergence. Otherwise
    returns ``losses.loss_spect_contrasts`` evaluated on the real part of
    the spectrum, restricted to ``fs > 20``.

    The rate, parameter and peak-frequency penalty terms are currently
    disabled; only the spectral term contributes.
    """
    spect, fs, obs_f0, r_fp, CONVG = ssn_PS(params, contrasts)

    if not CONVG:
        return np.inf

    if np.max(np.abs(np.imag(spect))) > 0.01:
        print("Spectrum is dangerously imaginary")

    # Integer indices of frequencies above 20. flatnonzero keeps an integer
    # dtype even when nothing matches; an empty np.array([]) would be float
    # and make the fancy indexing below raise IndexError.
    fs_loss_inds = np.flatnonzero(fs > 20)

    spect_loss = losses.loss_spect_contrasts(
        fs[fs_loss_inds], np.real(spect[fs_loss_inds, :]))
    return spect_loss


# def sigmoid_params(pos_params):
#     J_max = 3
#     i2e_max = 2
#     gE_max = 2
#     gI_max = 1.5 #because I do not want gI_min = 0, so I will offset the sigmoid
#     gI_min = 0.5
#     NMDA_max = 1

#     Jee = J_max * logistic_sig(pos_params[0])
#     Jei = J_max * logistic_sig(pos_params[1])
#     Jie = J_max * logistic_sig(pos_params[2])
#     Jii = J_max * logistic_sig(pos_params[3])

#     if len(pos_params) < 6:
#         i2e = i2e_max * logistic_sig(pos_params[4])
#         gE = 1
#         gI = 1
#         NMDAratio = 0.4

#         params = np.array([Jee, Jei, Jie, Jii, gE, gI, NMDAratio])

#     else:
#         i2e = 1
#         gE = gE_max * logistic_sig(pos_params[4])
#         gI = gI_max * logistic_sig(pos_params[5]) + gI_min
#         NMDAratio = NMDA_max * logistic_sig(pos_params[6])

#         params = np.array([Jee, Jei, Jie, Jii, gE, gI, NMDAratio])

#     return params

# def logistic_sig(x):
#     return 1/(1 + np.exp(-x))