Example #1
def _ravel_list(*leaves):
    leaves_metadata = tree_map(
        lambda l: pytree_metadata(jnp.ravel(l), jnp.shape(l), jnp.size(l),
                                  canonicalize_dtype(lax.dtype(l))), leaves)
    leaves_idx = jnp.cumsum(
        jnp.array((0, ) + tuple(d.size for d in leaves_metadata)))

    def unravel_list(arr):
        return [
            jnp.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.size),
                        m.shape).astype(m.dtype)
            for i, m in enumerate(leaves_metadata)
        ]

    flat = jnp.concatenate([m.flat for m in leaves_metadata
                            ]) if leaves_metadata else jnp.array([])
    return flat, unravel_list
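The pattern above (flatten every leaf into one vector and keep a closure that restores shapes and dtypes) is exposed publicly in JAX as jax.flatten_util.ravel_pytree; a minimal roundtrip sketch with a hypothetical parameter pytree:

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"b": jnp.zeros(3), "w": jnp.ones((2, 3))}   # hypothetical parameters
flat, unravel = ravel_pytree(params)                  # 1-D array of size 9
restored = unravel(flat)                              # same structure, shapes and dtypes as params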
Example #2
def dmrg_solve(A, L, R, mpo, n_krylov, tol, maxiter):
    """
    The local ground state step of single-site DMRG.
    """
    mpo_map = jax_dmrg.map.SingleMPOHeffMap(mpo, L, R)
    A_vec = jnp.ravel(A)

    #  E, eV, err = minimum_eigenpair_jit(mpo_map.matvec, mpo_map.data,
    #                                     n_krylov, maxiter, tol,
    #                                     A_vec)
    E, eV, err = minimum_eigenpair(mpo_map,
                                   n_krylov,
                                   maxiter=maxiter,
                                   tol=tol,
                                   v0=A_vec)
    #print(err)
    newA = eV.reshape(A.shape)
    return (E, newA, err)
Example #3
    def update(self, params, grad):
        """
        Description: Updates parameters based on the gradient and learning rate.
        Args:
            params (list/numpy.ndarray): Parameters of the controller to update
            grad (list/numpy.ndarray): Gradient of the loss with respect to params
        Returns:
            Updated parameters in same shape as input
        """
        # Make everything a list for generality
        is_list = True
        if(type(params) is not list):
            params = [params]
            grad = [grad]
            is_list = False

        # used to compute inverse matrix with respect to each parameter vector
        flat_grad = [np.ravel(dw) for dw in grad]

        # initialize A
        if self.A is None:
            self.A = [np.eye(dw.shape[0]) * self.eps for dw in flat_grad]
            self.Ainv = [np.eye(dw.shape[0]) * (1 / self.eps) for dw in flat_grad]

        # compute max norm and normalize accordingly
        eta = self.lr
        if(self.max_norm):                     
            self.max_norm = np.maximum(self.max_norm, np.linalg.norm([self.general_norm(dw) for dw in flat_grad]))
            eta = eta * self.max_norm
            
        # partial_update automatically reshapes flat_grad into correct params shape
        new_values = [self.partial_update(A, Ainv, g, w) for (A, Ainv, g, w) in zip(self.A, self.Ainv, flat_grad, params)]
        self.A, self.Ainv, new_grad = list(map(list, zip(*new_values)))

        new_params = [w - eta * dw for (w, dw) in zip(params, new_grad)]

        if self.project:
            self.min_radius = np.maximum(self.min_radius, self.general_norm(y))
            norm = 5. * self.min_radius
            new_params = [self.norm_project(p, A, norm) for (p, A) in zip(new_params, self.A)]

        return new_params if is_list else new_params[0]
Example #4
    def ravel(self, tensor):
        """
        Return a flattened view of the tensor, not a copy.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            >>> pyhf.tensorlib.ravel(tensor)
            DeviceArray([1., 2., 3., 4., 5., 6.], dtype=float64)

        Args:
            tensor (Tensor): Tensor object

        Returns:
            `jaxlib.xla_extension.DeviceArray`: A flattened array.
        """
        return jnp.ravel(tensor)
Example #5
def tree_ravel(pytree):
    r"""

    Flatten and concatenate all leaves into a single flat ndarray.

    Parameters
    ----------
    pytree : a pytree with ndarray leaves

        A typical example is a pytree of model parameters (weights) or gradients with respect to
        such model params.

    Returns
    -------
    arr : ndarray with ndim=1

        A single flat array.

    """
    return jnp.concatenate([jnp.ravel(leaf) for leaf in jax.tree_leaves(pytree)])
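A minimal usage sketch for tree_ravel, with a hypothetical parameter pytree (assumes jax and jax.numpy are imported as in the snippet):

params = {"b": jnp.zeros(3), "w": jnp.ones((2, 3))}   # hypothetical model params
flat = tree_ravel(params)
flat.shape   # (9,)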
Example #6
        def kmer_kernel_fn(theta, kmers1, kmers2):
            mapping = device_put(self.mapping)
            pos_kernel = partial(distance_kernel, theta)

            rev_complement = np.reshape(mapping[np.ravel(kmers2)],
                                        kmers2.shape)[:, ::-1]

            same_kmer = np.all(kmers1[:, None, :] == kmers2[None, :, :],
                               axis=2)
            same_rev_comp = np.all(
                kmers1[:, None, :] == rev_complement[None, :, :], axis=2)
            offsets1 = np.arange(kmers1.shape[0])
            offsets2 = np.arange(kmers2.shape[0])
            weight = vmap(lambda i: vmap(lambda j: pos_kernel(i, j))
                          (offsets2))(offsets1)
            weight_rev_comp = vmap(lambda i: vmap(lambda j: pos_kernel(i, j))
                                   (offsets2[::-1]))(offsets1)

            return (np.sum(same_kmer * weight) +
                    np.sum(same_rev_comp * weight_rev_comp))
Example #7
def _ravel_list(lst):
    if not lst: return jnp.array([], jnp.float32), lambda _: []
    from_dtypes = [dtypes.dtype(l) for l in lst]
    to_dtype = dtypes.result_type(*from_dtypes)
    sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst)
    indices = np.cumsum(sizes)

    def unravel(arr):
        chunks = jnp.split(arr, indices[:-1])
        with warnings.catch_warnings():
            warnings.simplefilter(
                "ignore")  # ignore complex-to-real cast warning
            return [
                lax.convert_element_type(chunk.reshape(shape), dtype)
                for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)
            ]

    ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype))
    raveled = jnp.concatenate([ravel(e) for e in lst])
    return raveled, unravel
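A minimal usage sketch for this variant, which promotes mixed dtypes before concatenating and restores them on unravel (inputs are hypothetical, and the snippet's own imports are assumed):

raveled, unravel = _ravel_list([jnp.ones((2, 2)), jnp.zeros(3, jnp.int32)])
raveled.shape                          # (7,), in the promoted floating dtype
[x.dtype for x in unravel(raveled)]    # original dtypes restored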
Example #8
def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
    out = lax.gather_p.bind(operand,
                            start_indices,
                            dimension_numbers=dimension_numbers,
                            slice_sizes=slice_sizes,
                            unique_indices=unique_indices,
                            indices_are_sorted=indices_are_sorted,
                            mode=mode,
                            fill_value=fill_value)

    if ErrorCategory.OOB not in enabled_errors:
        return out, error

    # compare to OOB masking logic in lax._gather_translation_rule
    dnums = dimension_numbers
    operand_dims = np.array(operand.shape)
    num_batch_dims = len(start_indices.shape) - 1

    upper_bound = operand_dims[np.array(dnums.start_index_map)]
    upper_bound -= np.array(slice_sizes)[np.array(dnums.start_index_map)]
    upper_bound = jnp.expand_dims(upper_bound,
                                  axis=tuple(range(num_batch_dims)))
    in_bounds = (start_indices >= 0) & (start_indices <= upper_bound.astype(
        start_indices.dtype))

    # Get first OOB index, axis and axis size so it can be added to the error msg.
    flat_idx = jnp.argmin(in_bounds)
    multi_idx = jnp.unravel_index(flat_idx, start_indices.shape)
    oob_axis = jnp.array(dnums.start_index_map)[multi_idx[-1]]
    oob_axis_size = jnp.array(operand.shape)[oob_axis]
    oob_index = jnp.ravel(start_indices)[flat_idx]
    payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32)

    msg = (f'out-of-bounds indexing at {summary()} for array of '
           f'shape {operand.shape}: '
           'index {payload0} is out of bounds for axis {payload1} '
           'with size {payload2}.')

    return out, assert_func(error, jnp.all(in_bounds), msg, payload)
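A small illustration of how the first out-of-bounds position is located above, assuming only jnp is imported: argmin over a boolean mask returns the flat position of the first False, and unravel_index maps it back to a multi-index.

in_bounds = jnp.array([[True, True], [True, False]])
flat_idx = jnp.argmin(in_bounds)                          # 3, the first False in flat order
multi_idx = jnp.unravel_index(flat_idx, in_bounds.shape)  # (1, 1)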
Example #9
def sol_recursive(f, z, t):
  """
  Recursively compute higher order derivatives of dynamics of ODE.
  """
  z_shape = z.shape
  z_t = jnp.concatenate((jnp.ravel(z), jnp.array([t])))

  def g(z_t):
    """
    Closure to expand z.
    """
    z, t = jnp.reshape(z_t[:-1], z_shape), z_t[-1]
    dz = jnp.ravel(f(z, t))
    dt = jnp.array([1.])
    dz_t = jnp.concatenate((dz, dt))
    return dz_t

  (y0, [y1h]) = jet(g, (z_t, ), ((jnp.ones_like(z_t), ), ))
  (y0, [y1, y2h]) = jet(g, (z_t, ), ((y0, y1h,), ))

  return (jnp.reshape(y0[:-1], z_shape), [jnp.reshape(y1[:-1], z_shape)])
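A minimal usage sketch with hypothetical linear dynamics (assumes the snippet's imports, including jet from jax.experimental.jet):

f = lambda z, t: -z                                 # hypothetical dynamics
y0, (y1,) = sol_recursive(f, jnp.ones(3), 0.0)
# y0 is f(z, t) reshaped to z's shape; y1 carries the next-order Taylor coefficient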
Example #10
        def evaluate_ode_for_extended_state(extended_state, ivp=ivp, dt=dt):
            r"""Evaluate the ODE for an extended state (x(t), t).

            More precisely, compute the derivative of the stacked state (x(t), t) according to the ODE.
            This function implements a rewriting of non-autonomous as autonomous ODEs.
            This means that

            .. math:: \dot x(t) = f(t, x(t))

            becomes

            .. math:: \dot z(t) = \frac{d}{dt} (x(t), t) = (f(t, x(t)), 1).

            Only considering autonomous ODEs makes the jet-implementation
            (and automatic differentiation in general) easier.
            """
            x, t = jnp.reshape(extended_state[:-1],
                               ivp.y0.shape), extended_state[-1]
            dx = ivp.f(t, x)
            dx_ravelled = jnp.ravel(dx)
            stacked_ode_eval = jnp.concatenate((dx_ravelled, dt))
            return stacked_ode_eval
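A standalone sketch of the same autonomous-rewrite trick, with a hypothetical vector field and initial value (not part of the original snippet):

import jax.numpy as jnp

f = lambda t, x: -x                                                   # hypothetical f(t, x)
y0 = jnp.ones(2)
extended_state = jnp.concatenate((jnp.ravel(y0), jnp.array([0.0])))   # z = (x(t0), t0)
x, t = extended_state[:-1].reshape(y0.shape), extended_state[-1]
dz = jnp.concatenate((jnp.ravel(f(t, x)), jnp.array([1.0])))          # dz = (f(t, x), 1)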
Example #11
def setup_sklearn():
    import sklearn.datasets
    from sklearn.model_selection import train_test_split

    iris = sklearn.datasets.load_iris()
    X = iris["data"]
    y = (iris["target"] == 2).astype(np.int)  # 1 if Iris-Virginica, else 0'
    N, D = X.shape  # 150, 4

    X_train, X_test, y_train, y_test = train_test_split(X,
                                                        y,
                                                        test_size=0.33,
                                                        random_state=42)

    from sklearn.linear_model import LogisticRegression

    # We set C to a large number to turn off regularization.
    # We don't fit the bias term to simplify the comparison below.
    log_reg = LogisticRegression(solver="lbfgs", C=1e5, fit_intercept=False)
    log_reg.fit(X_train, y_train)
    w_mle_sklearn = jnp.ravel(log_reg.coef_)
    set_seed(0)
    w = w_mle_sklearn
    return w, X_test, y_test
Example #12
def jacres_j(acc, pe, sep, ne, zcc):
    Mp = peq.M
    Mn = neq.M
    Ms = sepq.M
    Ma = accq.M
    Mz = zccq.M
    Np = peq.N
    Nn = neq.N

    arg_jp = [
        pe.jvec[0:Mp], pe.uvec[1:Mp + 1], pe.Tvec[1:Mp + 1], pe.etavec[0:Mp],
        pe.cmat[Np, :], pe.cmat[Np + 1, :], pe.cmax * np.ones([Mp, 1])
    ]
    res_jp = vmap(peq.ionic_flux)(*arg_jp)
    A_jp = vmap((grad(peq.ionic_flux, range(0, len(arg_jp) - 1))))(*arg_jp)

    J_jp = build_diag(Mp, A_jp[0], "square")
    J_jup = build_diag(Mp, A_jp[1], "wide")
    J_jTp = build_diag(Mp, A_jp[2], "wide")
    J_jetap = build_diag(Mp, A_jp[3], "square")

    col_cp = []
    data_cp = []
    #    row_cp = np.repeat( np.arange(0,Mp), 2)
    row_cp = np.repeat(np.arange(0, Mp), 2)
    for i in range(0, Mp):
        #        col_cp.append([Np + (Np+2)*(i), Np+1 + (Np+2)*(i) ])
        #        data_cp.append([A_jp[4][i], A_jp[5][i]])
        col_cp.append([Np + (Np + 2) * i, Np + 1 + (Np + 2) * (i)])
        data_cp.append([A_jp[4][i], A_jp[5][i]])
    data_cp = np.ravel(np.array(data_cp))
    col_cp = np.ravel(np.array(col_cp))
    J_cp = coo_matrix((data_cp, (row_cp, col_cp)),
                      shape=(Mp, Mp * (Np + 2) + Mn * (Nn + 2)))
    """ Negative Electrode"""

    arg_jn = [
        ne.jvec[0:Mn], ne.uvec[1:Mn + 1], ne.Tvec[1:Mn + 1], ne.etavec[0:Mn],
        ne.cmat[Nn, :], ne.cmat[Nn + 1, :], ne.cmax * np.ones([Mn, 1])
    ]
    res_jn = vmap(neq.ionic_flux)(*arg_jn)
    A_jn = vmap((grad(neq.ionic_flux, range(0, len(arg_jn) - 1))))(*arg_jn)

    J_jn = build_diag(Mn, A_jn[0], "square")
    J_jun = build_diag(Mn, A_jn[1], "wide")
    J_jTn = build_diag(Mn, A_jn[2], "wide")
    J_jetan = build_diag(Mn, A_jn[3], "square")

    col_cn = []
    data_cn = []
    offset = (Np + 2) * Mp
    row_cn = np.repeat(np.arange(0, Mn), 2)
    for i in range(0, Mn):
        col_cn.append(
            [Nn + (Nn + 2) * i + offset, Nn + 1 + (Nn + 2) * (i) + offset])
        data_cn.append([A_jn[4][i], A_jn[5][i]])
    data_cn = np.ravel(np.array(data_cn))
    col_cn = np.ravel(np.array(col_cn))
    J_cn = coo_matrix((data_cn, (row_cn, col_cn)),
                      shape=(Mn, Mp * (Np + 2) + Mn * (Nn + 2)))
    """" total """
    J_ju = hstack([
        vstack([J_jup, empty_rec(Mn, Mp + 2)]),
        empty_rec(Mp + Mn, Ms + 2),
        vstack([empty_rec(Mp, Mn + 2), J_jun])
    ])

    J_jT = hstack([
        empty_rec(Mp + Mn, Ma + 2),
        vstack([J_jTp, empty_rec(Mn, Mp + 2)]),
        empty_rec(Mp + Mn, Ms + 2),
        vstack([empty_rec(Mp, Mn + 2), J_jTn]),
        empty_rec(Mp + Mn, Mz + 2)
    ])

    J_jj = block_diag((J_jp, J_jn))

    J_jeta = block_diag((J_jetap, J_jetan))

    J_jc = vstack([J_cp, J_cn])

    res_j = np.hstack((res_jp, res_jn))
    J_j = hstack([
        J_jc, J_ju, J_jT,
        empty_rec(Mp + Mn, Mp + 2 + Ms + 2 + Mn + 2),
        empty_rec(Mp + Mn, Mp + 2 + Mn + 2), J_jj, J_jeta
    ])

    return res_j, J_j
Example #13
def ravel(x, order="C"):
  if isinstance(x, JaxArray): x = x.value
  return JaxArray(jnp.ravel(x, order=order))
Example #14
def makeInputs(OMap,
               r_cent,
               contrasts,
               X,
               Y,
               gridsizedeg=4,
               gridperdeg=5,
               AngWidth=32):
    '''
    makes the input arrays for the various stimulus conditions
    all radii at the highest contrast - to test Surround Suppression
    all contrasts at the highest radius - to test contrast effect
    highest contrast and radius with a Gabor filter - to test Ray-Maunsell Effect
    
    OMap = orientation preference across the cortex
    r_cent = array of stimulus radii
    contrasts = array of stimulus contrasts
    X,Y = matrices of distances in X and Y in degrees
    various parameters of the network
    
    Outputs
    StimConds = array of dim Ne x stimCondition (the name is short for Stimulus Conditions)
    stimCondition = [max radii * varying contrasts, max contrast * vary radii, Gabor]
    
    '''
    rads = np.hstack(
        (np.max(r_cent) * np.ones(len(contrasts) - 1), r_cent
         ))  # cause I don't want to double up the Contrast = 100 condition
    Contrasts = np.hstack(
        (contrasts, np.ones(len(r_cent)) * np.max(contrasts))
    )  # need to add one for Gabor condition, but I would subtract one to not double up the C= 100 R= max condition

    gridsize = OMap.shape
    dx = gridsizedeg / gridsize[0]  # dx is degrees between neurons

    Mid1 = int(np.floor(gridsize[0] / 2))
    Mid2 = int(np.floor(gridsize[1] / 2))

    # Python does linear indexing weird, just going to use the found midpts
    # trgt = onp.ravel_multi_index((Mid1, Mid2), (Len[0], Len[1]))

    Orientation = OMap[Mid1, Mid2]

    dOri = np.abs(OMap - Orientation)
    dOri = np.where(dOri > 90, 180 - dOri, dOri)
    In0 = np.ravel(np.exp(-dOri**2 / (2 * AngWidth**2)))

    RFdecay = 0.8 / 2  #biologic decay is 0.8 mm, magfactor =2 mm/deg
    RFdecay = RFdecay / 10  #trying to find good SI, this parameter has a large impact on the suppression curve
    RFdecay = 0.04
    #RFdecay = dx
    #GaborSigma = 0.3*np.max(r_cent)
    GaborSigma = 0.5

    x0 = X[Mid1, Mid2]
    y0 = Y[Mid1, Mid2]

    x_space = X - x0
    y_space = Y - y0

    # find the distances across the cortex
    r_space = np.ravel(np.sqrt(x_space**2 + y_space**2))

    #find the spatial input for a constant grating
    InSr = (1 - (1 / (1 + np.exp(-(r_space - rads[:, None]) / RFdecay))))
    #find the spatial input for a Gabor
    InGabor = np.exp(-r_space**2 / 2 / GaborSigma**2)
    #include the contrasts with it
    if len(contrasts) > 1:
        StimConds = Contrasts[:, None] * np.vstack((InSr, InGabor))
    else:
        StimConds = Contrasts[:, None] * InSr
    StimConds = StimConds * In0
    #include the relative drive between E and I cells  -- nixing this cause gE and gI are parametrs
    #InSpace = np.hstack( (StimConds, gI*StimConds)).T #.T makes it neurons by stimcond

    #array to reference to find max contrasts, etc
    stimulus_condition = np.vstack((Contrasts, np.hstack(
        (rads, np.max(rads)))))

    return StimConds.T, stimulus_condition, InSr
Example #15
def ssn_FP(pos_params, OLDSTYLE):
    ''' 
    Fcn that finds the fixed point and PS of the given SSN network. 
    returns spect, frequencies used to find that spect, peak frequencies (f0), fixed point rates (r_fp), and CONVG == 5 outputs
    
    inputs 
    pos_params are params that once sigmoided are always positive. 
    OLDSTYLE = old ranges for gE/gI in sigmoid params fcn
    '''

    params = sigmoid_params(pos_params, MULTI=True, OLDSTYLE=OLDSTYLE)

    #unpack parameters
    Jee = params[0] * np.pi * psi
    Jei = params[1] * np.pi * psi
    Jie = params[2] * np.pi * psi
    Jii = params[3] * np.pi * psi

    if len(params) < 8:
        i2e = params[4]
        Plocal = params[5]
        sigR = params[6]
        gE = 1
        gI = 1 * i2e
        NMDAratio = 0.1
    elif len(params) == 9:
        i2e = 1
        gE = params[4]
        gI = params[5]
        NMDAratio = params[6]
        Plocal = params[7]
        PlocalIE = Plocal
        sigR = params[8]
        sigEE = 0.35 * np.sqrt(sigR)
        sigIE = 0.35 / np.sqrt(sigR)
    elif len(params) == 10:
        i2e = 1
        gE = params[4]
        gI = params[5]
        NMDAratio = params[6]
        Plocal = params[7]
        sigEE = params[8]
        sigIE = params[9]
        PlocalIE = Plocal
    else:
        i2e = 1
        gE = params[4]
        gI = params[5]
        NMDAratio = params[6]
        Plocal = params[7]
        PlocalIE = params[8]
        sigEE = params[9]
        sigIE = params[10]

    sigEE = sigEE / magnFactor  # sigEE now in degrees
    sigIE = sigIE / magnFactor  # sigIE now in degrees

    W = make_conn.make_full_W(Plocal,
                              Jee,
                              Jei,
                              Jie,
                              Jii,
                              sigEE,
                              sigIE,
                              deltaD,
                              OMap,
                              sigXI=0.09,
                              PlocalIE=PlocalIE)

    ssn = SSN_classes._SSN_AMPAGABA(tau_s, NMDAratio, n, k, Ne, Ni, tau_vec, W)
    ssn.topos_vec = np.ravel(OMap)

    r_init = np.zeros([ssn.N, len(Contrasts)])
    #Inp_vec is Ne+Ni x stimulus conditions (typically 8) to find FP and such things
    inp_vec = np.vstack((gE * Inp, gI * Inp))

    r_fp, CONVG = ssn.fixed_point_r(inp_vec,
                                    r_init=r_init,
                                    Tmax=Tmax,
                                    dt=dt,
                                    xtol=xtol)

    #calculate power spectrum - find PS for each stimulus condition and concatenate them together
    for cc in range(cons):
        if cc == 0:
            spect, fs, _ = SSN_power_spec.linear_power_spect(
                ssn,
                r_fp[:, cc],
                noise_pars,
                freq_range,
                fnums,
                LFPrange=[LFPtarget[0]])
        elif cc == 1:
            spect_2, _, _ = SSN_power_spec.linear_power_spect(
                ssn,
                r_fp[:, cc],
                noise_pars,
                freq_range,
                fnums,
                LFPrange=[LFPtarget[0]])
            spect = np.concatenate((spect[:, None], spect_2[:, None]), axis=1)
        else:
            spect_2, _, _ = SSN_power_spec.linear_power_spect(
                ssn,
                r_fp[:, cc],
                noise_pars,
                freq_range,
                fnums,
                LFPrange=[LFPtarget[0]])
            spect = np.concatenate((spect, spect_2[:, None]), axis=1)

    # My one-shot way of finding the PS. The above version works better


#     if cons == 1:
#         spect, fs, f0, _ = SSN_power_spec.linear_power_spect(ssn, r_fp, noise_pars, freq_range, fnums, cons, LFPrange=[LFPtarget[0]])

#         if np.max(np.abs(np.imag(spect))) > 0.01:
#             print("Spectrum is dangerously imaginary")

#     else:
#         spect, fs, f0, _ = SSN_power_spec.linear_PS_sameTime(ssn, r_fp[:, con_inds], noise_pars, freq_range, fnums, cons, LFPrange=[LFPtarget[0]])

#find PS at the outer neurons
    outer_spect = make_outer_spect(ssn, r_fp[:, gabor_inds], probes)
    spect = np.concatenate((spect, outer_spect), axis=1)

    if np.max(np.abs(np.imag(spect))) > 0.01:
        print("Spectrum is dangerously imaginary")

    spect = np.real(spect)

    # I'm keeping f0 in the outputs just for congruency between the topo SSN and 2-D SSN code
    f0 = 0
    #print(spect.shape)

    return spect, fs, f0, r_fp, CONVG
Example #16
gen_inds = np.arange(len(Contrasts))
rad_inds = np.arange(
    len(contrasts) - 1,
    len(r_cent) + len(contrasts) -
    1)  #np.where(stimCon[0, :] == np.max(Contrasts), gen_inds, 0)
con_inds = np.hstack(
    (np.arange(0,
               len(contrasts) - 1), len(r_cent) + len(contrasts) -
     2))  #np.where(stimCon[1, :] == np.max(stimCon[1,:]), gen_inds, 0)
gabor_inds = -1

trgt = np.floor(Ne / 2)

con_inds = np.hstack((con_inds, gabor_inds))
cons = len(con_inds)
ssn_Ampa.topos_vec = np.ravel(OMap)

if ssn_Ampa.N > 2:
    LFPtarget = trgt + np.array(
        [ii * gridsize for ii in range(int(np.floor(gridsize / 2)))])
else:
    LFPtarget = None

spect, fs, f0, _ = SSN_power_spec.linear_PS_sameTime(
    ssn_Ampa,
    r_fp[:, con_inds],
    SSN_power_spec.NoisePars(),
    freq_range,
    fnums,
    cons,
    LFPrange=[LFPtarget[0]])
Example #17
def _get_leaf_diagnostics(leaf, key_prefix):
    # update this to add more grads diagnostics
    return {
        f'{key_prefix}max': jnp.max(jnp.abs(leaf)),
        f'{key_prefix}norm': jnp.linalg.norm(jnp.ravel(leaf)),
    }
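A minimal usage sketch applying the helper across a gradient pytree (the tree and prefix are hypothetical; assumes jax and jnp are imported):

grads = {"b": jnp.zeros(3), "w": jnp.ones((2, 3))}        # hypothetical gradients
per_leaf = [_get_leaf_diagnostics(leaf, "grad_")
            for leaf in jax.tree_util.tree_leaves(grads)]  # one max/norm dict per leaf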
Example #18
 def func(x: jnp.DeviceArray) -> jnp.DeviceArray:
     return jnp.ravel(x)
Example #19
def _ravel_list(*lst):
    return jnp.concatenate([jnp.ravel(elt)
                            for elt in lst]) if lst else jnp.array([])
Example #20
 def reshape(x, y):
     return jnp.ravel(jnp.array([x, y]), order="F").reshape(10, 2)
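For reference, order="F" ravels column-major, so the two inputs end up interleaved before the reshape; a small sketch assuming only jnp is imported:

x, y = jnp.arange(10), jnp.arange(10, 20)
jnp.ravel(jnp.array([x, y]), order="F")                 # [0, 10, 1, 11, 2, 12, ...]
jnp.ravel(jnp.array([x, y]), order="F").reshape(10, 2)  # rows are (x[i], y[i]) pairs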
Example #21
for i in range(1,len(nums)):
	epsilon = eps[i]
	n_iter = n_iters[i]
	theta_i = np.tile(np.linspace(0, 2 * np.pi, NS+1)[:-1], (NC,1))
	_, params_new = get_all_coil_data("../../../tests/w7x/scanold2/w7x_l{}.hdf5".format(nums[i]))
	fc_new, _ = params_new
	theta_i = find_minimum_theta_all_coils(fc_new, r_fil, theta_i)


	print("Size is {}".format(sizes[i]))
	print("The original delta r 1 is:")
	print(np.mean(np.linalg.norm(r_fil - filament_real_space(fc_new, np.tile(np.linspace(0, 2 * np.pi, NS+1)[:-1], (NC,1))), axis=-1)))
	print("The new delta r 1 is (with minimization):")
	print(np.mean(np.linalg.norm(r_fil - filament_real_space(fc_new, theta_i), axis=-1)))
	mean_delta_rs.append(np.mean(np.linalg.norm(r_fil - filament_real_space(fc_new, theta_i), axis=-1)))
	print("The max distance is")
	print(np.max(np.linalg.norm(r_fil - filament_real_space(fc_new, theta_i), axis=-1)))
	max_delta_rs.append(np.max(np.linalg.norm(r_fil - filament_real_space(fc_new, theta_i), axis=-1)))


	difference = np.linalg.norm(r_fil - filament_real_space(fc_new, np.tile(np.linspace(0, 2 * np.pi, NS+1)[:-1], (NC,1))), axis=-1) - np.linalg.norm(r_fil - filament_real_space(fc_new, theta_i), axis=-1)
	new_diff = difference[difference > 0]
	larger = np.ravel(difference).shape[0] - new_diff.shape[0]
	if larger > 0:
		print(larger)

np.save("w7x_sizes.npy", np.asarray(sizes))
np.save("w7x_mean_delta_rs.npy", np.asarray(mean_delta_rs))
np.save("w7x_max_delta_rs.npy", np.asarray(max_delta_rs))

Example #22
def ravel_list(*lst):
  return np.concatenate([np.ravel(elt) for elt in lst]) if lst else np.array([])
Example #23
def test_logmarglike_lineargaussianmodel_onetransfer_basics():

    theta_truth = jax.random.normal(key, (n_components,))

    M_T = design_matrix_polynomials(n_components, n_pix_y)  # (n_components, n_pix_y)
    y_truth = np.matmul(theta_truth, M_T)  # (nobj, n_pix_y)
    y, yinvvar, logyinvvar = make_masked_noisy_data(y_truth)  # (nobj, n_pix_y)
    assert_equal_shape([y_truth, y, yinvvar, logyinvvar])

    logfml, theta_map, theta_cov = logmarglike_lineargaussianmodel_onetransfer(
        M_T, y, yinvvar
    )

    # check result is finite and shapes are correct
    assert_shape(theta_map, (n_components,))
    assert_shape(theta_cov, (n_components, n_components))
    assert np.isfinite(logfml)
    assert np.all(np.isfinite(theta_map))
    assert np.all(np.isfinite(theta_cov))

    # check that result isn't too far off the truth, in chi2 sense
    dt = theta_map - theta_truth
    chi2 = 0.5 * np.ravel(np.matmul(dt.T, np.linalg.solve(theta_cov, dt)))
    assert chi2 < 100

    # check that normalised posterior distribution factorises into product of gaussians
    def log_posterior(theta):
        y_mod = np.matmul(theta, M_T)  # (n_samples, n_pix_y)
        return batch_gaussian_loglikelihood(y_mod - y, yinvvar)

    def log_posterior2(theta):
        dt = theta - theta_map
        s, logdet = np.linalg.slogdet(theta_cov * 2 * np.pi)
        chi2 = np.dot(dt.T, np.linalg.solve(theta_cov, dt))
        return logfml - 0.5 * (s * logdet + chi2)

    logpostv = log_posterior(theta_truth)
    logpostv2 = log_posterior2(theta_truth)
    assert abs(logpostv2 / logpostv - 1) < 0.01

    # now trying jit version of function
    logfml2, theta_map2, theta_cov2 = logmarglike_lineargaussianmodel_onetransfer_jit(
        M_T, y, yinvvar, logyinvvar
    )
    # check that outputs match original implementation
    assert_fml_thetamap_thetacov(
        logfml, theta_map, theta_cov, logfml2, theta_map2, theta_cov2, relative_accuracy
    )

    # now running simple optimiser to check that result is indeed optimum
    def loss_fn(theta):
        return -log_posterior(theta)

    params = [1 * theta_map]
    learning_rate = 1e-5
    for n in range(10):
        grads = grad(loss_fn)(*params)
        params = [param - learning_rate * grad for param, grad in zip(params, grads)]
        # print(n, loss_fn(*params), params[0] - theta_map)
    assert np.allclose(theta_map, params[0], rtol=1e-6)

    # Testing analytic covariance is correct and equals inverse of hessian
    theta_cov2 = np.linalg.inv(np.reshape(hessian(loss_fn)(theta_map), theta_cov.shape))
    assert np.allclose(theta_cov, theta_cov2, rtol=1e-6)

    # create vectorised loss
    loss_fn_vmap = jit(vmap(loss_fn))

    # now computes the evidence numerically
    n = 15
    theta_std = np.diag(theta_cov) ** 0.5
    theta_samples, vol_element = generate_sample_grid(theta_map, theta_std, n)
    loglikelihoods = -loss_fn_vmap(theta_samples)
    logfml_numerical = logsumexp(np.log(vol_element) + loglikelihoods)
    # print("logfml, logfml_numerical", logfml, logfml_numerical)
    assert abs(logfml_numerical / logfml - 1) < 0.01

    # Compare with case including gaussian prior with large variance
    mu = theta_map * 0
    muinvvar = 1 / (1e4 * np.diag(theta_cov) ** 0.5)
    logfml2, theta_map2, theta_cov2 = logmarglike_lineargaussianmodel_twotransfers(
        M_T, y, yinvvar, mu, muinvvar
    )
    assert_fml_thetamap_thetacov(
        logfml, theta_map, theta_cov, logfml2, theta_map2, theta_cov2, 0.2
    )
Example #24
def initialize_odefilter_with_taylormode(f, y0, t0, prior, initrv):
    """Initialize an ODE filter with Taylor-mode automatic differentiation.

    This requires JAX. For an explanation of what happens ``under the hood``, see [1]_.

    References
    ----------
    .. [1] Krämer, N. and Hennig, P., Stable implementation of probabilistic ODE solvers,
       *arXiv:2012.10106*, 2020.


    The implementation is inspired by the implementation in
    https://github.com/jacobjinkelly/easy-neural-ode/blob/master/latent_ode.py

    Parameters
    ----------
    f
        ODE vector field.
    y0
        Initial value.
    t0
        Initial time point.
    prior
        Prior distribution used for the ODE solver. For instance an integrated Brownian motion prior (``IBM``).
    initrv
        Initial random variable.

    Returns
    -------
    Normal
        Estimated initial random variable. Compatible with the specified prior.


    Examples
    --------

    >>> import sys, pytest
    >>> if not sys.platform.startswith('linux'):
    ...     pytest.skip()

    >>> from dataclasses import astuple
    >>> from probnum.randvars import Normal
    >>> from probnum.problems.zoo.diffeq import threebody_jax, vanderpol_jax
    >>> from probnum.statespace import IBM

    Compute the initial values of the restricted three-body problem as follows

    >>> f, t0, tmax, y0, df, *_ = astuple(threebody_jax())
    >>> print(y0)
    [ 0.994       0.          0.         -2.00158511]

    >>> prior = IBM(ordint=3, spatialdim=4)
    >>> initrv = Normal(mean=np.zeros(prior.dimension), cov=np.eye(prior.dimension))
    >>> improved_initrv = initialize_odefilter_with_taylormode(f, y0, t0, prior, initrv)
    >>> print(prior.proj2coord(0) @ improved_initrv.mean)
    [ 0.994       0.          0.         -2.00158511]
    >>> print(improved_initrv.mean)
    [ 9.94000000e-01  0.00000000e+00 -3.15543023e+02  0.00000000e+00
      0.00000000e+00 -2.00158511e+00  0.00000000e+00  9.99720945e+04
      0.00000000e+00 -3.15543023e+02  0.00000000e+00  6.39028111e+07
     -2.00158511e+00  0.00000000e+00  9.99720945e+04  0.00000000e+00]

    Compute the initial values of the van-der-Pol oscillator as follows

    >>> f, t0, tmax, y0, df, *_ = astuple(vanderpol_jax())
    >>> print(y0)
    [2. 0.]
    >>> prior = IBM(ordint=3, spatialdim=2)
    >>> initrv = Normal(mean=np.zeros(prior.dimension), cov=np.eye(prior.dimension))
    >>> improved_initrv = initialize_odefilter_with_taylormode(f, y0, t0, prior, initrv)
    >>> print(prior.proj2coord(0) @ improved_initrv.mean)
    [2. 0.]
    >>> print(improved_initrv.mean)
    [    2.     0.    -2.    60.     0.    -2.    60. -1798.]
    >>> print(improved_initrv.std)
    [0. 0. 0. 0. 0. 0. 0. 0.]
    """

    try:
        import jax.numpy as jnp
        from jax.config import config
        from jax.experimental.jet import jet

        config.update("jax_enable_x64", True)
    except ImportError as err:
        raise ImportError(
            "Cannot perform Taylor-mode initialisation without optional "
            "dependencies jax and jaxlib. Try installing them via `pip install jax jaxlib`."
        ) from err

    order = prior.ordint

    def total_derivative(z_t):
        """Total derivative."""
        z, t = jnp.reshape(z_t[:-1], z_shape), z_t[-1]
        dz = jnp.ravel(f(t, z))
        dt = jnp.array([1.0])
        dz_t = jnp.concatenate((dz, dt))
        return dz_t

    z_shape = y0.shape
    z_t = jnp.concatenate((jnp.ravel(y0), jnp.array([t0])))

    derivs = []

    derivs.extend(y0)
    if order == 0:
        all_derivs = statespace.Integrator._convert_derivwise_to_coordwise(
            np.asarray(jnp.array(derivs)), ordint=0, spatialdim=len(y0))

        return randvars.Normal(
            np.asarray(all_derivs),
            cov=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
            cov_cholesky=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
        )

    (dy0, [*yns]) = jet(total_derivative, (z_t, ), ((jnp.ones_like(z_t), ), ))
    derivs.extend(dy0[:-1])
    if order == 1:
        all_derivs = statespace.Integrator._convert_derivwise_to_coordwise(
            np.asarray(jnp.array(derivs)), ordint=1, spatialdim=len(y0))

        return randvars.Normal(
            np.asarray(all_derivs),
            cov=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
            cov_cholesky=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
        )

    for _ in range(1, order):
        (dy0, [*yns]) = jet(total_derivative, (z_t, ), ((dy0, *yns), ))
        derivs.extend(yns[-2][:-1])

    all_derivs = statespace.Integrator._convert_derivwise_to_coordwise(
        jnp.array(derivs), ordint=order, spatialdim=len(y0))

    return randvars.Normal(
        np.asarray(all_derivs),
        cov=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
        cov_cholesky=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
    )
Example #25
def rand_argmax(a):
    return np.random.choice(jnp.nonzero(jnp.ravel(a == jnp.max(a)))[0])
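A minimal usage sketch; jnp.nonzero needs concrete shapes, so this runs outside jit (the input is hypothetical):

a = jnp.array([[1, 3], [3, 2]])
rand_argmax(a)   # flat index 1 or 2, chosen uniformly among the maxima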
X = iris["data"]
y = (iris["target"] == 2).astype(np.int)  # 1 if Iris-Virginica, else 0'
N, D = X.shape  # 150, 4

X_train, X_test, y_train, y_test = train_test_split(X,
                                                    y,
                                                    test_size=0.33,
                                                    random_state=42)

from sklearn.linear_model import LogisticRegression

# We set C to a large number to turn off regularization.
# We don't fit the bias term to simplify the comparison below.
log_reg = LogisticRegression(solver="lbfgs", C=1e5, fit_intercept=False)
log_reg.fit(X_train, y_train)
w_mle_sklearn = jnp.ravel(log_reg.coef_)

set_seed(0)
w = w_mle_sklearn

## Compute gradient of loss "by hand" using numpy


def BCE_with_logits(logits, targets):
    N = logits.shape[0]
    logits = logits.reshape(N, 1)
    logits_plus = jnp.hstack([np.zeros((N, 1)), logits])  # e^0=1
    logits_minus = jnp.hstack([np.zeros((N, 1)), -logits])
    logp1 = -logsumexp(logits_minus, axis=1)
    logp0 = -logsumexp(logits_plus, axis=1)
    logprobs = logp1 * targets + logp0 * (1 - targets)
Example #27
def gaussian_kde_cdf(x_eval, samples, factor):

    low = jnp.ravel((-jnp.inf - samples) / factor)
    hi = jnp.ravel((x_eval - samples) / factor)

    # integrate each kernel from -inf to x_eval; ndtr(low) is 0 for the -inf bound
    return (jax.scipy.special.ndtr(hi) - jax.scipy.special.ndtr(low)).mean(axis=0)
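A minimal usage sketch with a standard-normal sample and an illustrative bandwidth (assumes jax is imported):

samples = jax.random.normal(jax.random.PRNGKey(0), (1000,))
gaussian_kde_cdf(0.0, samples, factor=0.2)   # close to 0.5 for a roughly symmetric sample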
Example #28
 def _equality_constraints(variables):
     return np.ravel(equality_constraints(*unravel(variables)))
Example #29
def _coo_fromdense_impl(mat, *, nnz, index_dtype):
    mat = jnp.asarray(mat)
    m, n = mat.shape
    mat_flat = jnp.ravel(mat)
    ind = jnp.nonzero(mat_flat, size=nnz)[0].astype(index_dtype)
    return mat_flat[ind], ind // n, ind % n
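A minimal usage sketch (nnz should match the number of nonzeros for exact results; assumes jnp is imported):

mat = jnp.array([[0., 2., 0.], [3., 0., 4.]])
data, row, col = _coo_fromdense_impl(mat, nnz=3, index_dtype=jnp.int32)
# data -> [2., 3., 4.], row -> [0, 1, 1], col -> [1, 0, 2]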
Example #30
 def add_fxy_coord(X, Y, coefficient):
     X_fxy_agg.append(jnp.ravel(X))
     Y_fxy_agg.append(jnp.ravel(Y))
     fxy_coefficient.append(coefficient * jnp.ones_like(Y_fxy_agg[-1]))