def _ravel_list(*leaves):
    leaves_metadata = tree_map(
        lambda l: pytree_metadata(
            jnp.ravel(l), jnp.shape(l), jnp.size(l), canonicalize_dtype(lax.dtype(l))),
        leaves)
    leaves_idx = jnp.cumsum(jnp.array((0,) + tuple(d.size for d in leaves_metadata)))

    def unravel_list(arr):
        return [jnp.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.size),
                            m.shape).astype(m.dtype)
                for i, m in enumerate(leaves_metadata)]

    flat = jnp.concatenate([m.flat for m in leaves_metadata]) if leaves_metadata else jnp.array([])
    return flat, unravel_list
def dmrg_solve(A, L, R, mpo, n_krylov, tol, maxiter):
    """
    The local ground state step of single-site DMRG.
    """
    mpo_map = jax_dmrg.map.SingleMPOHeffMap(mpo, L, R)
    A_vec = jnp.ravel(A)
    # E, eV, err = minimum_eigenpair_jit(mpo_map.matvec, mpo_map.data,
    #                                    n_krylov, maxiter, tol,
    #                                    A_vec)
    E, eV, err = minimum_eigenpair(mpo_map, n_krylov, maxiter=maxiter,
                                   tol=tol, v0=A_vec)
    # print(err)
    newA = eV.reshape(A.shape)
    return (E, newA, err)
def update(self, params, grad):
    """
    Description: Updates parameters based on the gradient and learning rate.

    Args:
        params (list/numpy.ndarray): parameters of the controller
        grad (list/numpy.ndarray): gradient of the loss with respect to params

    Returns:
        Updated parameters in the same shape as the input
    """
    # Make everything a list for generality
    is_list = True
    if type(params) is not list:
        params = [params]
        grad = [grad]
        is_list = False

    # used to compute inverse matrix with respect to each parameter vector
    flat_grad = [np.ravel(dw) for dw in grad]

    # initialize A
    if self.A is None:
        self.A = [np.eye(dw.shape[0]) * self.eps for dw in flat_grad]
        self.Ainv = [np.eye(dw.shape[0]) * (1 / self.eps) for dw in flat_grad]

    # compute max norm and normalize accordingly
    eta = self.lr
    if self.max_norm:
        self.max_norm = np.maximum(
            self.max_norm,
            np.linalg.norm([self.general_norm(dw) for dw in flat_grad]))
        eta = eta * self.max_norm

    # partial_update automatically reshapes flat_grad into correct params shape
    new_values = [self.partial_update(A, Ainv, g, w)
                  for (A, Ainv, g, w) in zip(self.A, self.Ainv, flat_grad, params)]
    self.A, self.Ainv, new_grad = list(map(list, zip(*new_values)))
    new_params = [w - eta * dw for (w, dw) in zip(params, new_grad)]

    if self.project:
        self.min_radius = np.maximum(self.min_radius, self.general_norm(y))
        norm = 5. * self.min_radius
        new_params = [self.norm_project(p, A, norm)
                      for (p, A) in zip(new_params, self.A)]

    return new_params if is_list else new_params[0]
def ravel(self, tensor):
    """
    Return a flattened view of the tensor, not a copy.

    Example:

        >>> import pyhf
        >>> pyhf.set_backend("jax")
        >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        >>> pyhf.tensorlib.ravel(tensor)
        DeviceArray([1., 2., 3., 4., 5., 6.], dtype=float64)

    Args:
        tensor (Tensor): Tensor object

    Returns:
        `jaxlib.xla_extension.DeviceArray`: A flattened array.
    """
    return jnp.ravel(tensor)
def tree_ravel(pytree):
    r"""
    Flatten and concatenate all leaves into a single flat ndarray.

    Parameters
    ----------
    pytree : a pytree with ndarray leaves
        A typical example is a pytree of model parameters (weights)
        or gradients with respect to such model params.

    Returns
    -------
    arr : ndarray with ndim=1
        A single flat array.

    """
    return jnp.concatenate([jnp.ravel(leaf) for leaf in jax.tree_leaves(pytree)])
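# A minimal usage sketch of tree_ravel above; the params dict is made up for
# illustration and assumes the usual `import jax` / `import jax.numpy as jnp`.
_example_params = {"b": jnp.zeros(3), "w": jnp.ones((2, 3))}
_example_flat = tree_ravel(_example_params)   # shape (9,): all leaves concatenated into one 1-D array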
def kmer_kernel_fn(theta, kmers1, kmers2):
    mapping = device_put(self.mapping)
    pos_kernel = partial(distance_kernel, theta)
    rev_complement = np.reshape(mapping[np.ravel(kmers2)],
                                kmers2.shape)[:, ::-1]
    same_kmer = np.all(kmers1[:, None, :] == kmers2[None, :, :], axis=2)
    same_rev_comp = np.all(
        kmers1[:, None, :] == rev_complement[None, :, :], axis=2)
    offsets1 = np.arange(kmers1.shape[0])
    offsets2 = np.arange(kmers2.shape[0])
    weight = vmap(lambda i: vmap(lambda j: pos_kernel(i, j))(offsets2))(offsets1)
    weight_rev_comp = vmap(
        lambda i: vmap(lambda j: pos_kernel(i, j))(offsets2[::-1]))(offsets1)
    return (np.sum(same_kmer * weight) +
            np.sum(same_rev_comp * weight_rev_comp))
def _ravel_list(lst):
    if not lst:
        return jnp.array([], jnp.float32), lambda _: []
    from_dtypes = [dtypes.dtype(l) for l in lst]
    to_dtype = dtypes.result_type(*from_dtypes)
    sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst)
    indices = np.cumsum(sizes)

    def unravel(arr):
        chunks = jnp.split(arr, indices[:-1])
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # ignore complex-to-real cast warning
            return [lax.convert_element_type(chunk.reshape(shape), dtype)
                    for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)]

    ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype))
    raveled = jnp.concatenate([ravel(e) for e in lst])
    return raveled, unravel
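# The public API built on this kind of helper is jax.flatten_util.ravel_pytree,
# which flattens a whole pytree and returns a matching unravel function.
# A small round-trip sketch (example pytree made up for illustration):
from jax.flatten_util import ravel_pytree

_example_tree = {"w": jnp.ones((2, 2)), "b": jnp.arange(3.0)}
_example_flat, _example_unravel = ravel_pytree(_example_tree)   # _example_flat has shape (7,)
_example_restored = _example_unravel(_example_flat)             # same structure, shapes, and dtypes as the input tree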
def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
    out = lax.gather_p.bind(operand, start_indices,
                            dimension_numbers=dimension_numbers,
                            slice_sizes=slice_sizes,
                            unique_indices=unique_indices,
                            indices_are_sorted=indices_are_sorted,
                            mode=mode, fill_value=fill_value)
    if ErrorCategory.OOB not in enabled_errors:
        return out, error

    # compare to OOB masking logic in lax._gather_translation_rule
    dnums = dimension_numbers
    operand_dims = np.array(operand.shape)
    num_batch_dims = len(start_indices.shape) - 1

    upper_bound = operand_dims[np.array(dnums.start_index_map)]
    upper_bound -= np.array(slice_sizes)[np.array(dnums.start_index_map)]
    upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims)))
    in_bounds = (start_indices >= 0) & (start_indices <= upper_bound.astype(start_indices.dtype))

    # Get first OOB index, axis and axis size so it can be added to the error msg.
    flat_idx = jnp.argmin(in_bounds)
    multi_idx = jnp.unravel_index(flat_idx, start_indices.shape)
    oob_axis = jnp.array(dnums.start_index_map)[multi_idx[-1]]
    oob_axis_size = jnp.array(operand.shape)[oob_axis]
    oob_index = jnp.ravel(start_indices)[flat_idx]
    payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32)

    msg = (f'out-of-bounds indexing at {summary()} for array of '
           f'shape {operand.shape}: '
           'index {payload0} is out of bounds for axis {payload1} '
           'with size {payload2}.')

    return out, assert_func(error, jnp.all(in_bounds), msg, payload)
def sol_recursive(f, z, t):
    """
    Recursively compute higher order derivatives of dynamics of ODE.
    """
    z_shape = z.shape
    z_t = jnp.concatenate((jnp.ravel(z), jnp.array([t])))

    def g(z_t):
        """
        Closure to expand z.
        """
        z, t = jnp.reshape(z_t[:-1], z_shape), z_t[-1]
        dz = jnp.ravel(f(z, t))
        dt = jnp.array([1.])
        dz_t = jnp.concatenate((dz, dt))
        return dz_t

    (y0, [y1h]) = jet(g, (z_t, ), ((jnp.ones_like(z_t), ), ))
    (y0, [y1, y2h]) = jet(g, (z_t, ), ((y0, y1h,), ))

    return (jnp.reshape(y0[:-1], z_shape), [jnp.reshape(y1[:-1], z_shape)])
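# A minimal sketch of the jax.experimental.jet calling convention used above;
# the sine input below is made up for illustration, not from the original code.
from jax.experimental.jet import jet

_example_primals = (2.0,)
_example_series = ((1.0, 0.0),)   # truncated Taylor coefficients of the input path
_y0, (_y1, _y2) = jet(jnp.sin, _example_primals, _example_series)
# _y0 is sin(2.0); _y1 and _y2 are the first- and second-order Taylor
# coefficients of sin along that input path.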
def evaluate_ode_for_extended_state(extended_state, ivp=ivp, dt=dt):
    r"""Evaluate the ODE for an extended state (x(t), t).

    More precisely, compute the derivative of the stacked state (x(t), t)
    according to the ODE. This function implements a rewriting of
    non-autonomous as autonomous ODEs. This means that

    .. math:: \dot x(t) = f(t, x(t))

    becomes

    .. math:: \dot z(t) = \frac{d}{dt} (x(t), t) = (f(t, x(t)), 1).

    Only considering autonomous ODEs makes the jet-implementation
    (and automatic differentiation in general) easier.
    """
    x, t = jnp.reshape(extended_state[:-1], ivp.y0.shape), extended_state[-1]
    dx = ivp.f(t, x)
    dx_ravelled = jnp.ravel(dx)
    stacked_ode_eval = jnp.concatenate((dx_ravelled, dt))
    return stacked_ode_eval
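# A small illustrative sketch (made-up scalar ODE, not from the original code)
# of the non-autonomous-to-autonomous rewrite described in the docstring above:
# dx/dt = f(t, x) becomes dz/dt = (f(t, x), 1) on the extended state z = (x, t).
def _example_f(t, x):
    return -t * x

def _example_extended_rhs(z):
    x, t = z[:-1], z[-1]
    return jnp.concatenate((jnp.ravel(_example_f(t, x)), jnp.array([1.0])))

_example_z0 = jnp.array([1.0, 0.0])   # (x0, t0)
# _example_extended_rhs(_example_z0) returns the state derivative (0 here, since t0 = 0)
# stacked with dt/dt = 1.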
def setup_sklearn():
    import sklearn.datasets
    from sklearn.model_selection import train_test_split

    iris = sklearn.datasets.load_iris()
    X = iris["data"]
    y = (iris["target"] == 2).astype(int)  # 1 if Iris-Virginica, else 0
    N, D = X.shape  # 150, 4

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.33, random_state=42)

    from sklearn.linear_model import LogisticRegression

    # We set C to a large number to turn off regularization.
    # We don't fit the bias term to simplify the comparison below.
    log_reg = LogisticRegression(solver="lbfgs", C=1e5, fit_intercept=False)
    log_reg.fit(X_train, y_train)
    w_mle_sklearn = jnp.ravel(log_reg.coef_)

    set_seed(0)
    w = w_mle_sklearn
    return w, X_test, y_test
def jacres_j(acc, pe, sep, ne, zcc):
    Mp = peq.M
    Mn = neq.M
    Ms = sepq.M
    Ma = accq.M
    Mz = zccq.M
    Np = peq.N
    Nn = neq.N

    arg_jp = [
        pe.jvec[0:Mp], pe.uvec[1:Mp + 1], pe.Tvec[1:Mp + 1], pe.etavec[0:Mp],
        pe.cmat[Np, :], pe.cmat[Np + 1, :], pe.cmax * np.ones([Mp, 1])
    ]

    res_jp = vmap(peq.ionic_flux)(*arg_jp)
    A_jp = vmap((grad(peq.ionic_flux, range(0, len(arg_jp) - 1))))(*arg_jp)

    J_jp = build_diag(Mp, A_jp[0], "square")
    J_jup = build_diag(Mp, A_jp[1], "wide")
    J_jTp = build_diag(Mp, A_jp[2], "wide")
    J_jetap = build_diag(Mp, A_jp[3], "square")

    col_cp = []
    data_cp = []
    # row_cp = np.repeat( np.arange(0,Mp), 2)
    row_cp = np.repeat(np.arange(0, Mp), 2)
    for i in range(0, Mp):
        # col_cp.append([Np + (Np+2)*(i), Np+1 + (Np+2)*(i) ])
        # data_cp.append([A_jp[4][i], A_jp[5][i]])
        col_cp.append([Np + (Np + 2) * i, Np + 1 + (Np + 2) * (i)])
        data_cp.append([A_jp[4][i], A_jp[5][i]])

    data_cp = np.ravel(np.array(data_cp))
    col_cp = np.ravel(np.array(col_cp))
    J_cp = coo_matrix((data_cp, (row_cp, col_cp)),
                      shape=(Mp, Mp * (Np + 2) + Mn * (Nn + 2)))

    """ Negative Electrode """

    arg_jn = [
        ne.jvec[0:Mn], ne.uvec[1:Mn + 1], ne.Tvec[1:Mn + 1], ne.etavec[0:Mn],
        ne.cmat[Nn, :], ne.cmat[Nn + 1, :], ne.cmax * np.ones([Mn, 1])
    ]

    res_jn = vmap(neq.ionic_flux)(*arg_jn)
    A_jn = vmap((grad(neq.ionic_flux, range(0, len(arg_jn) - 1))))(*arg_jn)

    J_jn = build_diag(Mn, A_jn[0], "square")
    J_jun = build_diag(Mn, A_jn[1], "wide")
    J_jTn = build_diag(Mn, A_jn[2], "wide")
    J_jetan = build_diag(Mn, A_jn[3], "square")

    col_cn = []
    data_cn = []
    offset = (Np + 2) * Mp
    row_cn = np.repeat(np.arange(0, Mn), 2)
    for i in range(0, Mn):
        col_cn.append(
            [Nn + (Nn + 2) * i + offset, Nn + 1 + (Nn + 2) * (i) + offset])
        data_cn.append([A_jn[4][i], A_jn[5][i]])

    data_cn = np.ravel(np.array(data_cn))
    col_cn = np.ravel(np.array(col_cn))
    J_cn = coo_matrix((data_cn, (row_cn, col_cn)),
                      shape=(Mn, Mp * (Np + 2) + Mn * (Nn + 2)))

    """ total """

    J_ju = hstack([
        vstack([J_jup, empty_rec(Mn, Mp + 2)]),
        empty_rec(Mp + Mn, Ms + 2),
        vstack([empty_rec(Mp, Mn + 2), J_jun])
    ])

    J_jT = hstack([
        empty_rec(Mp + Mn, Ma + 2),
        vstack([J_jTp, empty_rec(Mn, Mp + 2)]),
        empty_rec(Mp + Mn, Ms + 2),
        vstack([empty_rec(Mp, Mn + 2), J_jTn]),
        empty_rec(Mp + Mn, Mz + 2)
    ])

    J_jj = block_diag((J_jp, J_jn))
    J_jeta = block_diag((J_jetap, J_jetan))
    J_jc = vstack([J_cp, J_cn])

    res_j = np.hstack((res_jp, res_jn))
    J_j = hstack([
        J_jc, J_ju, J_jT,
        empty_rec(Mp + Mn, Mp + 2 + Ms + 2 + Mn + 2),
        empty_rec(Mp + Mn, Mp + 2 + Mn + 2),
        J_jj, J_jeta
    ])

    return res_j, J_j
def ravel(x, order="C"):
    if isinstance(x, JaxArray):
        x = x.value
    return JaxArray(jnp.ravel(x, order=order))
def makeInputs(OMap, r_cent, contrasts, X, Y,
               gridsizedeg=4, gridperdeg=5, AngWidth=32):
    '''
    Makes the input arrays for the various stimulus conditions:
    all radii at the highest contrast - to test surround suppression
    all contrasts at the highest radius - to test the contrast effect
    highest contrast and radius with a Gabor filter - to test the Ray-Maunsell effect

    OMap = orientation preference across the cortex
    r_cent = array of stimulus radii
    contrasts = array of stimulus contrasts
    X, Y = matrices of distances in X and Y in degrees
    various parameters of the network

    Outputs
    StimConds = array of dim Ne x stimCondition (the name is short for Stimulus Conditions)
    stimCondition = [max radii * varying contrasts, max contrast * varying radii, Gabor]
    '''

    rads = np.hstack(
        (np.max(r_cent) * np.ones(len(contrasts) - 1), r_cent)
    )  # because I don't want to double up the Contrast = 100 condition
    Contrasts = np.hstack(
        (contrasts, np.ones(len(r_cent)) * np.max(contrasts))
    )  # need to add one for the Gabor condition, but subtract one to not double up the C = 100, R = max condition

    gridsize = OMap.shape
    dx = gridsizedeg / gridsize[0]  # dx is degrees between neurons

    Mid1 = int(np.floor(gridsize[0] / 2))
    Mid2 = int(np.floor(gridsize[1] / 2))
    # Python does linear indexing weird, just going to use the found midpts
    # trgt = onp.ravel_multi_index((Mid1, Mid2), (Len[0], Len[1]))

    Orientation = OMap[Mid1, Mid2]

    dOri = np.abs(OMap - Orientation)
    dOri = np.where(dOri > 90, 180 - dOri, dOri)
    In0 = np.ravel(np.exp(-dOri**2 / (2 * AngWidth**2)))

    RFdecay = 0.8 / 2  # biological decay is 0.8 mm, magfactor = 2 mm/deg
    RFdecay = RFdecay / 10  # trying to find a good SI; this parameter has a large impact on the suppression curve
    RFdecay = 0.04
    # RFdecay = dx
    # GaborSigma = 0.3*np.max(r_cent)
    GaborSigma = 0.5

    x0 = X[Mid1, Mid2]
    y0 = Y[Mid1, Mid2]

    x_space = X - x0
    y_space = Y - y0

    # find the distances across the cortex
    r_space = np.ravel(np.sqrt(x_space**2 + y_space**2))

    # find the spatial input for a constant grating
    InSr = (1 - (1 / (1 + np.exp(-(r_space - rads[:, None]) / RFdecay))))
    # find the spatial input for a Gabor
    InGabor = np.exp(-r_space**2 / 2 / GaborSigma**2)

    # include the contrasts with it
    if len(contrasts) > 1:
        StimConds = Contrasts[:, None] * np.vstack((InSr, InGabor))
    else:
        StimConds = Contrasts[:, None] * InSr
    StimConds = StimConds * In0

    # include the relative drive between E and I cells -- nixing this because gE and gI are parameters
    # InSpace = np.hstack( (StimConds, gI*StimConds)).T  # .T makes it neurons by stimcond

    # array to reference to find max contrasts, etc.
    stimulus_condition = np.vstack((Contrasts, np.hstack((rads, np.max(rads)))))

    return StimConds.T, stimulus_condition, InSr
def ssn_FP(pos_params, OLDSTYLE):
    '''
    Fcn that finds the fixed point and PS of the given SSN network.
    Returns spect, the frequencies used to find that spect, peak frequencies (f0),
    fixed point rates (r_fp), and CONVG -- 5 outputs.

    Inputs:
    pos_params are params that once sigmoided are always positive.
    OLDSTYLE = old ranges for gE/gI in the sigmoid_params fcn
    '''

    params = sigmoid_params(pos_params, MULTI=True, OLDSTYLE=OLDSTYLE)

    # unpack parameters
    Jee = params[0] * np.pi * psi
    Jei = params[1] * np.pi * psi
    Jie = params[2] * np.pi * psi
    Jii = params[3] * np.pi * psi

    if len(params) < 8:
        i2e = params[4]
        Plocal = params[5]
        sigR = params[6]
        gE = 1
        gI = 1 * i2e
        NMDAratio = 0.1
    elif len(params) == 9:
        i2e = 1
        gE = params[4]
        gI = params[5]
        NMDAratio = params[6]
        Plocal = params[7]
        PlocalIE = Plocal
        sigR = params[8]
        sigEE = 0.35 * np.sqrt(sigR)
        sigIE = 0.35 / np.sqrt(sigR)
    elif len(params) == 10:
        i2e = 1
        gE = params[4]
        gI = params[5]
        NMDAratio = params[6]
        Plocal = params[7]
        sigEE = params[8]
        sigIE = params[9]
        PlocalIE = Plocal
    else:
        i2e = 1
        gE = params[4]
        gI = params[5]
        NMDAratio = params[6]
        Plocal = params[7]
        PlocalIE = params[8]
        sigEE = params[9]
        sigIE = params[10]

    sigEE = sigEE / magnFactor  # sigEE now in degrees
    sigIE = sigIE / magnFactor  # sigIE now in degrees

    W = make_conn.make_full_W(Plocal, Jee, Jei, Jie, Jii, sigEE, sigIE,
                              deltaD, OMap, sigXI=0.09, PlocalIE=PlocalIE)

    ssn = SSN_classes._SSN_AMPAGABA(tau_s, NMDAratio, n, k, Ne, Ni, tau_vec, W)
    ssn.topos_vec = np.ravel(OMap)

    r_init = np.zeros([ssn.N, len(Contrasts)])

    # inp_vec is Ne+Ni x stimulus conditions (typically 8) to find FP and such things
    inp_vec = np.vstack((gE * Inp, gI * Inp))

    r_fp, CONVG = ssn.fixed_point_r(inp_vec, r_init=r_init, Tmax=Tmax,
                                    dt=dt, xtol=xtol)

    # calculate power spectrum - find PS for each stimulus condition and concatenate them together
    for cc in range(cons):
        if cc == 0:
            spect, fs, _ = SSN_power_spec.linear_power_spect(
                ssn, r_fp[:, cc], noise_pars, freq_range, fnums,
                LFPrange=[LFPtarget[0]])
        elif cc == 1:
            spect_2, _, _ = SSN_power_spec.linear_power_spect(
                ssn, r_fp[:, cc], noise_pars, freq_range, fnums,
                LFPrange=[LFPtarget[0]])
            spect = np.concatenate((spect[:, None], spect_2[:, None]), axis=1)
        else:
            spect_2, _, _ = SSN_power_spec.linear_power_spect(
                ssn, r_fp[:, cc], noise_pars, freq_range, fnums,
                LFPrange=[LFPtarget[0]])
            spect = np.concatenate((spect, spect_2[:, None]), axis=1)

    # My one-shot way of finding the PS. The above version works better.
    # if cons == 1:
    #     spect, fs, f0, _ = SSN_power_spec.linear_power_spect(ssn, r_fp, noise_pars, freq_range, fnums, cons, LFPrange=[LFPtarget[0]])
    #     if np.max(np.abs(np.imag(spect))) > 0.01:
    #         print("Spectrum is dangerously imaginary")
    # else:
    #     spect, fs, f0, _ = SSN_power_spec.linear_PS_sameTime(ssn, r_fp[:, con_inds], noise_pars, freq_range, fnums, cons, LFPrange=[LFPtarget[0]])

    # find PS at the outer neurons
    outer_spect = make_outer_spect(ssn, r_fp[:, gabor_inds], probes)
    spect = np.concatenate((spect, outer_spect), axis=1)

    if np.max(np.abs(np.imag(spect))) > 0.01:
        print("Spectrum is dangerously imaginary")

    spect = np.real(spect)

    # I'm keeping f0 in the outputs just for congruency between the topo SSN and 2-D SSN code
    f0 = 0

    # print(spect.shape)

    return spect, fs, f0, r_fp, CONVG
gen_inds = np.arange(len(Contrasts))
rad_inds = np.arange(len(contrasts) - 1, len(r_cent) + len(contrasts) - 1)
# np.where(stimCon[0, :] == np.max(Contrasts), gen_inds, 0)
con_inds = np.hstack(
    (np.arange(0, len(contrasts) - 1), len(r_cent) + len(contrasts) - 2))
# np.where(stimCon[1, :] == np.max(stimCon[1, :]), gen_inds, 0)
gabor_inds = -1
trgt = np.floor(Ne / 2)

con_inds = np.hstack((con_inds, gabor_inds))
cons = len(con_inds)

ssn_Ampa.topos_vec = np.ravel(OMap)

if ssn_Ampa.N > 2:
    LFPtarget = trgt + np.array(
        [ii * gridsize for ii in range(int(np.floor(gridsize / 2)))])
else:
    LFPtarget = None

spect, fs, f0, _ = SSN_power_spec.linear_PS_sameTime(
    ssn_Ampa, r_fp[:, con_inds], SSN_power_spec.NoisePars(), freq_range,
    fnums, cons, LFPrange=[LFPtarget[0]])
def _get_leaf_diagnostics(leaf, key_prefix):
    # update this to add more grads diagnostics
    return {
        f'{key_prefix}max': jnp.max(jnp.abs(leaf)),
        f'{key_prefix}norm': jnp.linalg.norm(jnp.ravel(leaf)),
    }
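# A hypothetical usage sketch of _get_leaf_diagnostics; the grads dict below is
# made up for illustration.
_example_grads = {"w": jnp.array([[1.0, -4.0], [2.0, 0.0]])}
_example_diags = {k: _get_leaf_diagnostics(v, f"grads/{k}/")
                  for k, v in _example_grads.items()}
# _example_diags["w"] -> {'grads/w/max': 4.0, 'grads/w/norm': ~4.58}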
def func(x: jnp.DeviceArray) -> jnp.DeviceArray:
    return jnp.ravel(x)
def _ravel_list(*lst):
    return jnp.concatenate([jnp.ravel(elt) for elt in lst]) if lst else jnp.array([])
def reshape(x, y):
    return jnp.ravel(jnp.array([x, y]), order="F").reshape(10, 2)
for i in range(1, len(nums)):
    epsilon = eps[i]
    n_iter = n_iters[i]
    theta_i = np.tile(np.linspace(0, 2 * np.pi, NS + 1)[:-1], (NC, 1))
    _, params_new = get_all_coil_data(
        "../../../tests/w7x/scanold2/w7x_l{}.hdf5".format(nums[i]))
    fc_new, _ = params_new
    theta_i = find_minimum_theta_all_coils(fc_new, r_fil, theta_i)

    print("Size is {}".format(sizes[i]))
    print("The original delta r 1 is:")
    print(np.mean(np.linalg.norm(
        r_fil - filament_real_space(
            fc_new, np.tile(np.linspace(0, 2 * np.pi, NS + 1)[:-1], (NC, 1))),
        axis=-1)))
    print("The new delta r 1 is (with minimization):")
    print(np.mean(np.linalg.norm(r_fil - filament_real_space(fc_new, theta_i), axis=-1)))
    mean_delta_rs.append(
        np.mean(np.linalg.norm(r_fil - filament_real_space(fc_new, theta_i), axis=-1)))
    print("The max distance is")
    print(np.max(np.linalg.norm(r_fil - filament_real_space(fc_new, theta_i), axis=-1)))
    max_delta_rs.append(
        np.max(np.linalg.norm(r_fil - filament_real_space(fc_new, theta_i), axis=-1)))

    difference = (
        np.linalg.norm(
            r_fil - filament_real_space(
                fc_new, np.tile(np.linspace(0, 2 * np.pi, NS + 1)[:-1], (NC, 1))),
            axis=-1)
        - np.linalg.norm(r_fil - filament_real_space(fc_new, theta_i), axis=-1))
    new_diff = difference[difference > 0]
    larger = np.ravel(difference).shape[0] - new_diff.shape[0]
    if larger > 0:
        print(larger)

np.save("w7x_sizes.npy", np.asarray(sizes))
np.save("w7x_mean_delta_rs.npy", np.asarray(mean_delta_rs))
np.save("w7x_max_delta_rs.npy", np.asarray(max_delta_rs))
def ravel_list(*lst):
    return np.concatenate([np.ravel(elt) for elt in lst]) if lst else np.array([])
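# A quick illustrative call for ravel_list (arguments made up for illustration;
# `np` here is whatever array module the surrounding code aliases, numpy or jax.numpy):
_example_flat = ravel_list(np.ones((2, 2)), np.arange(3.0))
# -> array([1., 1., 1., 1., 0., 1., 2.]); ravel_list() with no args -> array([])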
def test_logmarglike_lineargaussianmodel_onetransfer_basics():

    theta_truth = jax.random.normal(key, (n_components,))
    M_T = design_matrix_polynomials(n_components, n_pix_y)  # (n_components, n_pix_y)
    y_truth = np.matmul(theta_truth, M_T)  # (nobj, n_pix_y)
    y, yinvvar, logyinvvar = make_masked_noisy_data(y_truth)  # (nobj, n_pix_y)
    assert_equal_shape([y_truth, y, yinvvar, logyinvvar])

    logfml, theta_map, theta_cov = logmarglike_lineargaussianmodel_onetransfer(
        M_T, y, yinvvar
    )

    # check result is finite and shapes are correct
    assert_shape(theta_map, (n_components,))
    assert_shape(theta_cov, (n_components, n_components))
    assert np.isfinite(logfml)
    assert np.all(np.isfinite(theta_map))
    assert np.all(np.isfinite(theta_cov))

    # check that result isn't too far off the truth, in chi2 sense
    dt = theta_map - theta_truth
    chi2 = 0.5 * np.ravel(np.matmul(dt.T, np.linalg.solve(theta_cov, dt)))
    assert chi2 < 100

    # check that normalised posterior distribution factorises into product of gaussians
    def log_posterior(theta):
        y_mod = np.matmul(theta, M_T)  # (n_samples, n_pix_y)
        return batch_gaussian_loglikelihood(y_mod - y, yinvvar)

    def log_posterior2(theta):
        dt = theta - theta_map
        s, logdet = np.linalg.slogdet(theta_cov * 2 * np.pi)
        chi2 = np.dot(dt.T, np.linalg.solve(theta_cov, dt))
        return logfml - 0.5 * (s * logdet + chi2)

    logpostv = log_posterior(theta_truth)
    logpostv2 = log_posterior2(theta_truth)
    assert abs(logpostv2 / logpostv - 1) < 0.01

    # now trying jit version of function
    logfml2, theta_map2, theta_cov2 = logmarglike_lineargaussianmodel_onetransfer_jit(
        M_T, y, yinvvar, logyinvvar
    )
    # check that outputs match original implementation
    assert_fml_thetamap_thetacov(
        logfml, theta_map, theta_cov, logfml2, theta_map2, theta_cov2, relative_accuracy
    )

    # now running simple optimiser to check that result is indeed optimum
    def loss_fn(theta):
        return -log_posterior(theta)

    params = [1 * theta_map]
    learning_rate = 1e-5
    for n in range(10):
        grads = grad(loss_fn)(*params)
        params = [param - learning_rate * grad for param, grad in zip(params, grads)]
        # print(n, loss_fn(*params), params[0] - theta_map)
    assert np.allclose(theta_map, params[0], rtol=1e-6)

    # Testing analytic covariance is correct and equals inverse of hessian
    theta_cov2 = np.linalg.inv(np.reshape(hessian(loss_fn)(theta_map), theta_cov.shape))
    assert np.allclose(theta_cov, theta_cov2, rtol=1e-6)

    # create vectorised loss
    loss_fn_vmap = jit(vmap(loss_fn))

    # now computes the evidence numerically
    n = 15
    theta_std = np.diag(theta_cov) ** 0.5
    theta_samples, vol_element = generate_sample_grid(theta_map, theta_std, n)
    loglikelihoods = -loss_fn_vmap(theta_samples)
    logfml_numerical = logsumexp(np.log(vol_element) + loglikelihoods)
    # print("logfml, logfml_numerical", logfml, logfml_numerical)
    assert abs(logfml_numerical / logfml - 1) < 0.01

    # Compare with case including gaussian prior with large variance
    mu = theta_map * 0
    muinvvar = 1 / (1e4 * np.diag(theta_cov) ** 0.5)
    logfml2, theta_map2, theta_cov2 = logmarglike_lineargaussianmodel_twotransfers(
        M_T, y, yinvvar, mu, muinvvar
    )
    assert_fml_thetamap_thetacov(
        logfml, theta_map, theta_cov, logfml2, theta_map2, theta_cov2, 0.2
    )
def initialize_odefilter_with_taylormode(f, y0, t0, prior, initrv):
    """Initialize an ODE filter with Taylor-mode automatic differentiation.

    This requires JAX. For an explanation of what happens ``under the hood``, see [1]_.

    References
    ----------
    .. [1] Krämer, N. and Hennig, P., Stable implementation of probabilistic ODE solvers,
       *arXiv:2012.10106*, 2020.

    The implementation is inspired by the implementation in
    https://github.com/jacobjinkelly/easy-neural-ode/blob/master/latent_ode.py

    Parameters
    ----------
    f
        ODE vector field.
    y0
        Initial value.
    t0
        Initial time point.
    prior
        Prior distribution used for the ODE solver. For instance an integrated Brownian motion prior (``IBM``).
    initrv
        Initial random variable.

    Returns
    -------
    Normal
        Estimated initial random variable. Compatible with the specified prior.

    Examples
    --------
    >>> import sys, pytest
    >>> if not sys.platform.startswith('linux'):
    ...     pytest.skip()

    >>> from dataclasses import astuple
    >>> from probnum.randvars import Normal
    >>> from probnum.problems.zoo.diffeq import threebody_jax, vanderpol_jax
    >>> from probnum.statespace import IBM

    Compute the initial values of the restricted three-body problem as follows

    >>> f, t0, tmax, y0, df, *_ = astuple(threebody_jax())
    >>> print(y0)
    [ 0.994 0. 0. -2.00158511]
    >>> prior = IBM(ordint=3, spatialdim=4)
    >>> initrv = Normal(mean=np.zeros(prior.dimension), cov=np.eye(prior.dimension))
    >>> improved_initrv = initialize_odefilter_with_taylormode(f, y0, t0, prior, initrv)
    >>> print(prior.proj2coord(0) @ improved_initrv.mean)
    [ 0.994 0. 0. -2.00158511]
    >>> print(improved_initrv.mean)
    [ 9.94000000e-01 0.00000000e+00 -3.15543023e+02 0.00000000e+00
      0.00000000e+00 -2.00158511e+00 0.00000000e+00 9.99720945e+04
      0.00000000e+00 -3.15543023e+02 0.00000000e+00 6.39028111e+07
      -2.00158511e+00 0.00000000e+00 9.99720945e+04 0.00000000e+00]

    Compute the initial values of the van-der-Pol oscillator as follows

    >>> f, t0, tmax, y0, df, *_ = astuple(vanderpol_jax())
    >>> print(y0)
    [2. 0.]
    >>> prior = IBM(ordint=3, spatialdim=2)
    >>> initrv = Normal(mean=np.zeros(prior.dimension), cov=np.eye(prior.dimension))
    >>> improved_initrv = initialize_odefilter_with_taylormode(f, y0, t0, prior, initrv)
    >>> print(prior.proj2coord(0) @ improved_initrv.mean)
    [2. 0.]
    >>> print(improved_initrv.mean)
    [ 2. 0. -2. 60. 0. -2. 60. -1798.]
    >>> print(improved_initrv.std)
    [0. 0. 0. 0. 0. 0. 0. 0.]
    """
    try:
        import jax.numpy as jnp
        from jax.config import config
        from jax.experimental.jet import jet

        config.update("jax_enable_x64", True)
    except ImportError as err:
        raise ImportError(
            "Cannot perform Taylor-mode initialisation without optional "
            "dependencies jax and jaxlib. Try installing them via `pip install jax jaxlib`."
        ) from err

    order = prior.ordint

    def total_derivative(z_t):
        """Total derivative."""
        z, t = jnp.reshape(z_t[:-1], z_shape), z_t[-1]
        dz = jnp.ravel(f(t, z))
        dt = jnp.array([1.0])
        dz_t = jnp.concatenate((dz, dt))
        return dz_t

    z_shape = y0.shape
    z_t = jnp.concatenate((jnp.ravel(y0), jnp.array([t0])))

    derivs = []

    derivs.extend(y0)
    if order == 0:
        all_derivs = statespace.Integrator._convert_derivwise_to_coordwise(
            np.asarray(jnp.array(derivs)), ordint=0, spatialdim=len(y0))
        return randvars.Normal(
            np.asarray(all_derivs),
            cov=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
            cov_cholesky=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
        )

    (dy0, [*yns]) = jet(total_derivative, (z_t, ), ((jnp.ones_like(z_t), ), ))
    derivs.extend(dy0[:-1])
    if order == 1:
        all_derivs = statespace.Integrator._convert_derivwise_to_coordwise(
            np.asarray(jnp.array(derivs)), ordint=1, spatialdim=len(y0))
        return randvars.Normal(
            np.asarray(all_derivs),
            cov=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
            cov_cholesky=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
        )

    for _ in range(1, order):
        (dy0, [*yns]) = jet(total_derivative, (z_t, ), ((dy0, *yns), ))
        derivs.extend(yns[-2][:-1])

    all_derivs = statespace.Integrator._convert_derivwise_to_coordwise(
        jnp.array(derivs), ordint=order, spatialdim=len(y0))

    return randvars.Normal(
        np.asarray(all_derivs),
        cov=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
        cov_cholesky=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
    )
def rand_argmax(a):
    return np.random.choice(jnp.nonzero(jnp.ravel(a == jnp.max(a)))[0])
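# A hypothetical usage sketch of rand_argmax (the array is made up for
# illustration): ties for the maximum are broken uniformly at random.
_example_q = jnp.array([1.0, 3.0, 3.0, 2.0])
_example_action = rand_argmax(_example_q)   # returns index 1 or 2, chosen at random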
X = iris["data"]
y = (iris["target"] == 2).astype(int)  # 1 if Iris-Virginica, else 0
N, D = X.shape  # 150, 4

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

from sklearn.linear_model import LogisticRegression

# We set C to a large number to turn off regularization.
# We don't fit the bias term to simplify the comparison below.
log_reg = LogisticRegression(solver="lbfgs", C=1e5, fit_intercept=False)
log_reg.fit(X_train, y_train)
w_mle_sklearn = jnp.ravel(log_reg.coef_)

set_seed(0)
w = w_mle_sklearn


## Compute gradient of loss "by hand" using numpy
def BCE_with_logits(logits, targets):
    N = logits.shape[0]
    logits = logits.reshape(N, 1)
    logits_plus = jnp.hstack([np.zeros((N, 1)), logits])  # e^0=1
    logits_minus = jnp.hstack([np.zeros((N, 1)), -logits])
    logp1 = -logsumexp(logits_minus, axis=1)
    logp0 = -logsumexp(logits_plus, axis=1)
    logprobs = logp1 * targets + logp0 * (1 - targets)
def gaussian_kde_cdf(x_eval, samples, factor):
    low = jnp.ravel((-jnp.inf - samples) / factor)
    hi = jnp.ravel((x_eval - samples) / factor)
    # integrate each Gaussian kernel from -inf up to x_eval and average over samples
    # (cf. scipy.stats.gaussian_kde.integrate_box_1d)
    return (jax.scipy.special.ndtr(hi) - jax.scipy.special.ndtr(low)).mean(axis=0)
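# A hypothetical usage sketch of gaussian_kde_cdf; samples and bandwidth are
# made up for illustration.
_example_key = jax.random.PRNGKey(0)
_example_samples = jax.random.normal(_example_key, (1000,))
_example_bandwidth = 0.2
# KDE CDF evaluated at 0.0; roughly 0.5 for standard-normal samples.
_example_cdf_at_zero = gaussian_kde_cdf(0.0, _example_samples, _example_bandwidth)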
def _equality_constraints(variables):
    return np.ravel(equality_constraints(*unravel(variables)))
def _coo_fromdense_impl(mat, *, nnz, index_dtype):
    mat = jnp.asarray(mat)
    m, n = mat.shape
    mat_flat = jnp.ravel(mat)
    ind = jnp.nonzero(mat_flat, size=nnz)[0].astype(index_dtype)
    return mat_flat[ind], ind // n, ind % n
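# A small worked example for the dense-to-COO conversion above (matrix made up
# for illustration): flat nonzero positions map to (row, col) via // and %.
_example_mat = jnp.array([[0., 2., 0.],
                          [3., 0., 4.]])
_example_data, _example_row, _example_col = _coo_fromdense_impl(
    _example_mat, nnz=3, index_dtype=jnp.int32)
# _example_data -> [2., 3., 4.], _example_row -> [0, 1, 1], _example_col -> [1, 0, 2]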
def add_fxy_coord(X, Y, coefficient):
    X_fxy_agg.append(jnp.ravel(X))
    Y_fxy_agg.append(jnp.ravel(Y))
    fxy_coefficient.append(coefficient * jnp.ones_like(Y_fxy_agg[-1]))