def make_neur_distances(gridsizedeg, gridperdeg, hyper_col, PERIODIC=False):
    '''
    Makes a matrix of distances between neurons on a square grid.

    gridsizedeg = side length of the grid in degrees
    gridperdeg = number of grid points per degree
    hyper_col = hypercolumn length of the network (currently unused)
    PERIODIC = if True, index distances wrap around the grid edges

    The grid has gridsize = 1 + round(gridperdeg * gridsizedeg) points per
    side, spaced dx = gridsizedeg / (gridsize - 1) degrees apart.

    outputs:
    X = (gridsize, gridsize) x coordinates of the grid points in degrees
    Y = (gridsize, gridsize) y coordinates of the grid points in degrees
    deltaD = (gridsize**2, gridsize**2) matrix of distances between neurons in
             degrees, neurons ordered by column-major (Fortran) flattening of
             the grid; used to make W
    '''
    gridsize = 1 + round(gridperdeg * gridsizedeg)
    Lx = gridsizedeg
    Ly = gridsizedeg
    dx = Lx / (gridsize - 1)
    # linspace guarantees exactly `gridsize` samples; a float-step arange such
    # as np.arange(0, Lx + dx, dx) can emit an extra point through rounding.
    [X, Y] = np.meshgrid(np.linspace(0, Lx, gridsize), np.linspace(0, Ly, gridsize))

    [indX, indY] = np.meshgrid(np.arange(gridsize), np.arange(gridsize))
    flat_x = np.ravel(indX, 'F')
    flat_y = np.ravel(indY, 'F')
    XDist = np.abs(flat_x - flat_x[:, None])
    YDist = np.abs(flat_y - flat_y[:, None])
    if PERIODIC:
        # Take the shorter way around the grid.
        XDist = np.where(XDist > gridsize / 2, gridsize - XDist, XDist)
        YDist = np.where(YDist > gridsize / 2, gridsize - YDist, YDist)
    deltaD = np.sqrt(XDist**2 + YDist**2) * dx

    return X, Y, deltaD
def _make_xy_arrays(shape, On): x = jnp.arange(shape[1], dtype=jnp.float64) y = jnp.arange(shape[0], dtype=jnp.float64) cx, cy = jnp.meshgrid(x, y) x_super = jnp.linspace(0.5 / On - 0.5, shape[1] - 0.5 - 0.5 / On, shape[1] * On) y_super = jnp.linspace(0.5 / On - 0.5, shape[0] - 0.5 - 0.5 / On, shape[0] * On) cx_super, cy_super = jnp.meshgrid(x_super, y_super) return (cx, cy), (cx_super, cy_super)
def steady_probes(filepath, dt=1, dx=0.01, real_size=12, scaled_size=256):
    '''
    Calculates the contact electrogram on a fixed 4x4 grid of probes spread
    over the field and writes it, with the conductivity field, to a new file.

    The probes sit at x, y in {2.4, 4.8, 7.2, 9.6} cm.

    - filepath: path to the simulation HDF5 file; it must contain the datasets
      'states' (channel 2 of each frame is used) and 'params/D'. The output is
      written to filepath[:-5] + '_ecg.hdf5', i.e. filepath is expected to end
      in a 5-character extension such as '.hdf5'.
    - dt: timestep between states set to 1 ms (currently unused)
    - dx: grid discretization set to 0.01 cm, used to scale the gradient
    - real_size: physical side length of the field in cm
    - scaled_size: side length in pixels (currently unused)

    The output file holds 'electrogram' (n_frames, 4, 4), where each pixel is
    the contact electrogram at one probe, and 'conductivity'.
    Returns True once the output file has been written.
    '''
    import numpy as onp  # mutable host buffer; `np` in this file may be jax.numpy

    p = [2.4, 4.8, 7.2, 9.6]
    points = np.meshgrid(p, p, indexing='ij')
    filename = filepath[:-5] + '_ecg.hdf5'

    with h5py.File(filepath, 'r') as states_file:
        states = states_file['states']
        conductivity_field = states_file["params/D"][:]
        n_frames = states.shape[0]
        n_grid = states.shape[-1]
        field_shape = states.shape[-2:]
        phi = onp.zeros((n_frames, 4, 4))

        x = np.linspace(0, real_size, n_grid)
        y = np.linspace(0, real_size, n_grid)
        xv, yv = np.meshgrid(x, y, sparse=False, indexing='ij')

        # Field-to-probe offsets broadcast against the 16 probes: (N, N, 16).
        egm_x = points[0].reshape((1, 16))
        egm_y = points[1].reshape((1, 16))
        rx = egm_x - xv[:, :, np.newaxis]
        ry = egm_y - yv[:, :, np.newaxis]
        # Regularised cubed distance; frame-independent, so computed once.
        denom = np.pi * 4 * 2.36 * 1 * np.sqrt(rx**2 + ry**2 + 10**(-2))**3

        for i in range(n_frames):
            u = states[i, 2]
            u_x = gradient(u, 0) / dx
            u_y = gradient(u, 1) / dx
            val = -np.sum((u_y[:, :, np.newaxis] * ry + u_x[:, :, np.newaxis] * rx) * 0.0001 / denom,
                          axis=(0, 1)).reshape((4, 4))
            phi[i] = val

    with h5py.File(filename, "w") as hdf5:
        ecg_dset = hdf5.create_dataset('electrogram', shape=phi.shape, dtype="float32")
        ecg_dset[:] = phi
        conductivity = hdf5.create_dataset('conductivity', shape=field_shape, dtype='float32')
        conductivity[:] = conductivity_field

    return True
def ks93(g1, g2):
    """Direct inversion of weak-lensing shear to convergence.

    This function is an implementation of the Kaiser & Squires (1993) mass
    mapping algorithm. Due to the mass sheet degeneracy, the convergence is
    recovered only up to an overall additive constant. It is chosen here to
    produce output maps of mean zero. The inversion is performed in Fourier
    space for speed.

    Parameters
    ----------
    g1, g2 : array_like
        2D input arrays corresponding to the first and second (i.e., real and
        imaginary) components of shear, binned spatially to a regular grid.

    Returns
    -------
    kE, kB : tuple of numpy arrays
        E-mode and B-mode maps of convergence.

    Raises
    ------
    AssertionError
        For input arrays of different sizes.

    See Also
    --------
    bin2d
        For binning a galaxy shear catalog.

    Examples
    --------
    >>> # (g1, g2) should in practice be measurements from a real galaxy survey
    >>> g1, g2 = 0.1 * np.random.randn(2, 32, 32) + 0.1 * np.ones((2, 32, 32))
    >>> kE, kB = ks93(g1, g2)
    >>> kE.shape
    (32, 32)
    >>> bool(abs(kE.mean()) < 1e-10)
    True
    """
    # Check consistency of input maps. An explicit raise (rather than `assert`)
    # keeps the check active under `python -O`.
    if g1.shape != g2.shape:
        raise AssertionError(
            "g1 and g2 must have the same shape, got {} and {}".format(g1.shape, g2.shape))

    # Compute Fourier space grids
    (nx, ny) = g1.shape
    k1, k2 = np.meshgrid(np.fft.fftfreq(ny), np.fft.fftfreq(nx))

    # Compute Fourier transforms of g1 and g2
    g1hat = np.fft.fft2(g1)
    g2hat = np.fft.fft2(g2)

    # Apply Fourier space inversion operator
    p1 = k1 * k1 - k2 * k2
    p2 = 2 * k1 * k2
    ksq = k1 * k1 + k2 * k2
    # |k|^2 vanishes only at the DC mode, where p1 = p2 = 0; replace it by 1 to
    # avoid 0/0. np.where works for NumPy and jax.numpy alike (jax.ops.index_update
    # is no longer available in JAX).
    ksq = np.where(ksq == 0, 1.0, ksq)
    kEhat = (p1 * g1hat + p2 * g2hat) / ksq
    kBhat = -(p2 * g1hat - p1 * g2hat) / ksq

    # Transform back to pixel space
    kE = np.fft.ifft2(kEhat).real
    kB = np.fft.ifft2(kBhat).real

    return kE, kB
def get_ray_bundle(height, width, focal_length, tfrom_cam2world):
    """Compute one world-space ray per pixel of a pinhole camera.

    In camera space the ray through pixel (row, col) has direction
    ((col - width/2) / f, -(row - height/2) / f, -1). It is rotated by the
    upper-left 3x3 block of ``tfrom_cam2world``, and every ray starts at that
    matrix's last-column translation.

    Returns:
        (ray_origins, ray_directions), both of shape (height, width, 3).
    """
    cols = jnp.arange(width, dtype=jnp.float32)
    rows = jnp.arange(height, dtype=jnp.float32)
    px, py = jnp.meshgrid(cols, rows, indexing="xy")

    dir_x = (px - width * 0.5) / focal_length
    dir_y = -(py - height * 0.5) / focal_length
    dir_z = -jnp.ones_like(px)
    cam_dirs = jnp.stack([dir_x, dir_y, dir_z], axis=-1)

    rotation = tfrom_cam2world[:3, :3]
    translation = tfrom_cam2world[:3, -1]
    world_dirs = jnp.sum(cam_dirs[..., None, :] * rotation, axis=-1)
    origins = jnp.broadcast_to(translation, world_dirs.shape)
    return origins, world_dirs
def show3d(state, rcount=200, ccount=200, zlim=None, figsize=None):
    """Render a 2D voltage field as a 3D surface plot.

    Args:
        state: 2D array of voltages (mV), indexed [row, col]; need not be square.
        rcount, ccount: maximum row/column sample counts for plot_surface.
        zlim: optional (low, high) limits for the voltage axis.
        figsize: optional matplotlib figure size.

    Returns:
        (fig, ax) of the created figure and its 3D axes.
    """
    # setup figure
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(projection="3d")

    # make surface plot on a pixel-index grid shaped like `state`
    n_rows, n_cols = np.shape(state)[:2]
    x, y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    plot = ax.plot_surface(x, y, state, rcount=rcount, ccount=ccount, cmap="magma")

    # add colorbar
    cbar = fig.colorbar(plot)
    cbar.ax.set_title("mV")

    if zlim is not None:
        ax.set_zlim3d(zlim[0], zlim[1])

    # format axes: white panes (Matplotlib 3.8 removed the `w_*axis` aliases)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color((1.0, 1.0, 1.0, 1.0))
    # Tick labels divide pixel indices by 100 — assumes 100 pixels per cm; TODO confirm.
    ax.xaxis.set_major_formatter(
        FuncFormatter(lambda y, _: '{:.1f}'.format(y / 100)))
    ax.set_xlabel("x [cm]")
    ax.yaxis.set_major_formatter(
        FuncFormatter(lambda y, _: '{:.1f}'.format(y / 100)))
    ax.set_ylabel("y [cm]")
    ax.set_zlabel("Voltage [mV]")

    # crop image
    fig.tight_layout()

    return fig, ax
def evaluate_density(log_p_fn, bounds, num_points):
    """Evaluate exp(log_p_fn) on a regular 2D grid.

    Args:
        log_p_fn: maps an (num_points, num_points, 2) array of (x, y) points to
            log-densities.
        bounds: ((x_min, x_max), (y_min, y_max)).
        num_points: samples per axis.

    Returns:
        (X, Y, density), each of shape (num_points, num_points).
    """
    x_axis = jnp.linspace(bounds[0][0], bounds[0][1], num=num_points)
    y_axis = jnp.linspace(bounds[1][0], bounds[1][1], num=num_points)
    grid_x, grid_y = jnp.meshgrid(x_axis, y_axis)
    points = jnp.stack([grid_x, grid_y], axis=-1)
    density = jnp.exp(log_p_fn(points))
    return grid_x, grid_y, density
def _get_1d_latent_grid(self, paths_x): num_points_grid = 20 max_x = np.amax(paths_x) min_x = np.amin(paths_x) x_array = np.linspace(min_x, max_x, 20)[:, None] xx, tt = np.meshgrid(np.linspace(min_x, max_x, 20), np.linspace(0, 1, paths_x.shape[1])) txpairs = np.transpose(np.vstack([tt.reshape(-1), xx.reshape(-1)])) def scan_fn(carry, paths): drift, diffusion, index = carry time = index * self.config["delta_t"] - 0.5 gp_matrices, temp_drift_function, temp_diffusion_function = self.model.build( self.model.model_vars()) temp_drift = temp_drift_function(x_array, time) temp_diffusion = np.linalg.det( temp_diffusion_function(x_array, time)) drift = ops.index_add(drift, ops.index[index], temp_drift) diffusion = ops.index_add(diffusion, ops.index[index], temp_diffusion) index += 1 return (drift, diffusion, index), np.array([0.]) drift_grid = np.zeros((paths_x.shape[1], num_points_grid, 1)) diffusion_grid = np.zeros((paths_x.shape[1], num_points_grid)) (drift_grid, diffusion_grid, index), _ = lax.scan(scan_fn, (drift_grid, diffusion_grid, 0), np.transpose(paths_x, (1, 0, 2))) return txpairs, drift_grid.reshape(-1), diffusion_grid.reshape(-1)
def extract_weighted_patches(x, weights, kernel, stride, padding):
    """Accumulate a weighted sum of strided patches of ``x`` via jax.lax.scan.

    Args:
        x: (batch, H, W, C) input; padded by ``padding`` on both sides and by an
            extra ``kernel`` on the bottom/right edges.
        weights: (batch, k, out_h, out_w); weights[:, :, i, j] scales the patch
            starting at (i * stride[0], j * stride[1]).
        kernel, stride, padding: (h, w) pairs.

    Returns:
        (batch, k, kernel[0], kernel[1], C) sum over all patch positions of
        weight * patch.
    """
    logging.info("recompiling for kernel=%s and stride=%s and padding=%s",
                 kernel, stride, padding)
    pad_h = (padding[0], padding[0] + kernel[0])
    pad_w = (padding[1], padding[1] + kernel[1])
    padded = jnp.pad(x, ((0, 0), pad_h, pad_w, (0, 0)))
    batch_size, _, _, channels = padded.shape
    _, num_weights, out_h, out_w = weights.shape
    slice_sizes = (batch_size, kernel[0], kernel[1], channels)

    def step(total, ij):
        row, col = ij[0], ij[1]
        window = jax.lax.dynamic_slice(
            padded, (0, row * stride[0], col * stride[1], 0), slice_sizes)
        coeffs = weights[:, :, row, col]
        # Outer product over k for each batch element: (b, k) x (b, h, w, c).
        return total + coeffs[:, :, None, None, None] * window[:, None], None

    grid = jnp.meshgrid(jnp.arange(out_h), jnp.arange(out_w), indexing="ij")
    positions = jnp.stack(grid, axis=-1).reshape((-1, 2))
    start = jnp.zeros((batch_size, num_weights, kernel[0], kernel[1], channels))
    patches, _ = jax.lax.scan(step, start, positions)
    return patches
def gen_julia(*, im_width=5000, im_height=5000, max_it=300, max_z=100,
              min_x=-1, max_x=1, min_y=-1, max_y=1, c=complex(-0.5, 0.61)):
    """Render the escape-time image of the Julia set for z -> z**2 + c.

    Each pixel starts at z = x + iy on [min_x, max_x) x [min_y, max_y) and is
    iterated while |z| < max_z and fewer than max_it iterations have run.

    Returns:
        (im_height, im_width) array of iteration counts divided by max_it, in
        [0, 1]; 1 means the pixel did not escape within max_it iterations.
    """
    ix_array, iy_array = np.meshgrid(np.arange(im_width), np.arange(im_height))
    real_part = ((max_x - min_x) / im_width) * ix_array + min_x
    im_part = ((max_y - min_y) / im_height) * iy_array + min_y
    z = real_part + (1j) * im_part

    it = np.zeros(z.shape)
    while True:
        # Pixels still iterating: not escaped and under the iteration budget.
        not_done = np.logical_and(np.abs(z) < max_z, it < max_it)
        if not np.any(not_done):
            break
        # Masked update that must be reassigned: works for NumPy and for
        # immutable jax.numpy arrays alike.
        z = np.where(not_done, z**2 + c, z)
        it = np.where(not_done, it + 1, it)

    julia_set = it / max_it
    return julia_set
def generate_sample_grid(theta_mean, theta_std, n):
    """
    Create a meshgrid of n ** n_dim samples, tiling
    [theta_mean[i] - 5 * theta_std[i], theta_mean[i] + 5 * theta_std[i]]
    with n evenly spaced points along each dimension i.
    Also returns the volume element.

    Parameters
    ----------
    theta_mean, theta_std : ndarray (n_dim)
    n : int
        Points per dimension; must be at least 2 so the spacing is defined.

    Returns
    -------
    theta_samples : ndarray (n ** n_dim, n_dim)
    vol_element: scalar
        Volume element (product of the per-dimension grid spacings).
    """
    n_components = theta_mean.size
    xs = [
        np.linspace(
            theta_mean[i] - 5 * theta_std[i],
            theta_mean[i] + 5 * theta_std[i],
            n,
        )
        for i in range(n_components)
    ]
    mxs = np.meshgrid(*xs)
    theta_samples = np.vstack([m.ravel() for m in mxs]).T
    # linspace is uniform, so the first interval is the spacing of each axis.
    dxs = np.array([np.diff(x)[0] for x in xs])
    vol_element = np.prod(dxs)
    return theta_samples, vol_element
def cartesian_product(*arrays):
    '''
    JAX-friendly version of cartesian product.

    Returns an array of shape (product of the input lengths, len(arrays))
    whose rows enumerate every combination, with the last input varying
    fastest (the same order as itertools.product).
    '''
    axes = jnp.meshgrid(*arrays, indexing='ij')
    return jnp.stack(axes, axis=-1).reshape(-1, len(arrays))
def get_pixel_grid(self, scene: Scene) -> Any:
    """Constructs a spatial grid of points to shoot rays through.

    x runs from -1 to 1 across ``scene.width`` columns; y runs from
    height/width down to -height/width across ``scene.height`` rows, so
    row 0 is the top of the image.

    Returns:
        (X, Y) coordinate arrays of shape (scene.height, scene.width).
    """
    aspect = scene.width / scene.height
    y_extent = 1 / aspect
    xs = jnp.linspace(-1., 1., scene.width)
    ys = jnp.linspace(y_extent, -y_extent, scene.height)
    grid = jnp.meshgrid(xs, ys)
    return grid[0], grid[1]
def sampler(img_target, pose, intrinsics, rng, options):
    """
    Given a single image, samples rays

    Draws ``options.num_random_rays`` distinct pixels uniformly (without
    replacement) and returns their ray origins, ray directions and the
    matching ``img_target`` values, each indexed as [row, col, :].
    """
    height = intrinsics.height
    width = intrinsics.width
    origins, directions = get_ray_bundle(
        height, width, intrinsics.focal_length, pose[:3, :4]
    )

    # Every (row, col) pixel coordinate, flattened to a (height * width, 2) list.
    grid = jnp.meshgrid(jnp.arange(height), jnp.arange(width), indexing="xy")
    pixel_coords = jnp.stack(grid, axis=-1).reshape((-1, 2))

    picks = jax.random.choice(
        rng, pixel_coords.shape[0], shape=(options.num_random_rays,), replace=False
    )
    chosen = pixel_coords[picks]
    rows, cols = chosen[:, 0], chosen[:, 1]

    return origins[rows, cols, :], directions[rows, cols, :], img_target[rows, cols, :]
def _make_gabor(params: jnp.ndarray, rf_dim: Tuple[int, int]) -> jnp.ndarray:
    """Build a batch of z-scored Gabor filters.

    Args:
        params: (n, >=7) array; columns 0-6 are σ (envelope width), θ
            (orientation), λ (wavelength), γ (aspect ratio), φ (phase),
            pos_x and pos_y (centre).
        rf_dim: half-extents; the sampling grid is arange(-rf_dim[0], rf_dim[0])
            along x and arange(-rf_dim[1], rf_dim[1]) along y.

    Returns:
        zscore_img applied to the (n, 2*rf_dim[1], 2*rf_dim[0]) filter stack.
    """
    σ, θ, λ, γ, φ, pos_x, pos_y = [
        params[:, k][:, jnp.newaxis, jnp.newaxis] for k in range(7)
    ]

    x, y = jnp.meshgrid(jnp.arange(-rf_dim[0], rf_dim[0]),
                        jnp.arange(-rf_dim[1], rf_dim[1]))
    # (H, W) grids broadcast against (n, 1, 1) parameters -> (n, H, W).
    dx = pos_x - x
    dy = pos_y - y
    xp = dx * jnp.cos(θ) - dy * jnp.sin(θ)
    yp = dx * jnp.sin(θ) + dy * jnp.cos(θ)

    envelope = jnp.exp(-(xp**2 + (γ * yp)**2) / (2 * σ**2))
    # Re(exp(1j * phase)) == cos(phase); avoids a complex intermediate.
    carrier = jnp.cos(2 * jnp.pi * xp / λ + φ)
    return zscore_img(envelope * carrier)
def __init__(self,
             variance=1.0,
             lengthscale_periodic=1.0,
             period=1.0,
             lengthscale_matern=1.0,
             order=6):
    """Quasi-periodic Matern-3/2 kernel (per ``self.name``).

    Hyper-parameters are stored as objax.TrainVar in softplus-inverse space.

    Args:
        variance, lengthscale_periodic, period, lengthscale_matern: initial
            hyper-parameter values (passed through softplus_inv).
        order: the coefficient tables are (order + 1) x (order + 1);
            presumably the truncation order of the periodic expansion — confirm.
    """
    import math

    self.transformed_lengthscale_periodic = objax.TrainVar(
        np.array(softplus_inv(lengthscale_periodic)))
    self.transformed_variance = objax.TrainVar(
        np.array(softplus_inv(variance)))
    self.transformed_period = objax.TrainVar(np.array(
        softplus_inv(period)))
    self.transformed_lengthscale_matern = objax.TrainVar(
        np.array(softplus_inv(lengthscale_matern)))
    super().__init__()
    self.name = 'Quasi-periodic Matern-3/2'
    self.order = order
    size = self.order + 1
    self.igrid = np.meshgrid(np.arange(size), np.arange(size))[1]  # igrid[i, j] = i

    # factorial_mesh_K[i, j] = i!
    factorial_mesh_K = np.array(
        [[float(math.factorial(i))] * size for i in range(size)])

    def _coeff(i, j):
        # Coefficient of cos(j*x) in (2*cos(x))**i: 2*C(i, (i-j)/2) for
        # 0 < j <= i with i - j even, C(i, i/2) for j == 0, else 0.
        if j > i or (i - j) % 2:
            return 0.
        return float(math.comb(i, (i - j) // 2)) * (2. if j > 0 else 1.)

    b = np.array([[_coeff(i, j) for j in range(size)] for i in range(size)])
    self.b_fmK_2igrid = b * (1. / factorial_mesh_K) * (2.**-self.igrid)
def make_2d_quadrature_grid(x_quad, w_quad):
    """Tensor-product 2D quadrature rule from a 1D rule.

    Args:
        x_quad: 1D quadrature nodes.
        w_quad: matching 1D weights.

    Returns:
        (nodes, weights): nodes is (2, n*n) with row 0 the x- and row 1 the
        y-coordinates; weights is (n*n,) holding w[i] * w[j] for each node.
    """
    xx, yy = jnp.meshgrid(x_quad, x_quad)
    nodes = jnp.stack([xx.reshape(-1), yy.reshape(-1)], axis=0)
    weights = jnp.outer(w_quad, w_quad).reshape(-1)
    return nodes, weights
def calc_round_egm(params, points, dt=1, dx=0.01, real_size=12, scaled_size=256, radius=3):
    '''
    Calculates the contact electrogram at the given points, summing only over
    pixels inside a disc of radius ``radius`` (cm) around each probe.

    - params: dict whose 'file' entry is the path of an HDF5 file containing a
      'states' dataset; channel 2 of each frame is used
    - points: list of tuples, each tuple contains (x,y) of location (in cm) of unipolar
    - dt: timestep between states set to 1 ms (currently unused)
    - dx: grid discretization set to 0.01 cm, used to scale the gradient
    - real_size: in cm
    - scaled_size: in pixels - matrix dimensions; maps cm to pixel indices
    - radius: disc radius in cm

    Returns an array of shape (len(points), n_frames).
    '''
    import numpy as onp  # mutable host buffer; `np` in this file may be jax.numpy

    with h5py.File(params['file'], 'r') as states_file:
        states = states_file['states']
        n_frames = states.shape[0]
        n_grid = states.shape[-1]
        phi = onp.zeros((len(points), n_frames))
        x = np.linspace(0, real_size, n_grid)
        y = np.linspace(0, real_size, n_grid)
        xv, yv = np.meshgrid(x, y, sparse=False, indexing='ij')

        # Integer pixel offsets (di, dj) lying inside a disc of radius `radi` pixels.
        radi = radius * scaled_size / real_size
        X = int(radi)
        circle = np.array([[ci, cj] for ci in range(-X, X + 1)
                           for cj in range(-int((radi * radi - ci * ci)**0.5),
                                           int((radi * radi - ci * ci)**0.5) + 1)])

        # Frame-independent geometry for each probe.
        probes = []
        for point in points:
            egm_x = point[0]
            egm_y = point[1]
            circ_mask = circle + np.array([
                int(scaled_size * egm_x / real_size),
                int(scaled_size * egm_y / real_size)
            ])
            # Keep in-bounds pixels only; negative indices would otherwise wrap
            # around to the opposite edge of the field.
            inside = ((circ_mask[:, 0] >= 0) & (circ_mask[:, 0] < scaled_size) &
                      (circ_mask[:, 1] >= 0) & (circ_mask[:, 1] < scaled_size))
            circ_mask = circ_mask[inside]
            ii = circ_mask[:, 0]
            jj = circ_mask[:, 1]
            rx = egm_x - xv[ii, jj]
            ry = egm_y - yv[ii, jj]
            denom = np.pi * 4 * 2.36 * 1 * np.sqrt(rx**2 + ry**2 + 10**(-2))**3
            probes.append((ii, jj, rx, ry, denom))

        for i in range(n_frames):
            u = states[i, 2, :, :]
            u_x = gradient(u, 0) / dx
            u_y = gradient(u, 1) / dx
            for j, (ii, jj, rx, ry, denom) in enumerate(probes):
                val = (u_y[ii, jj] * ry + u_x[ii, jj] * rx) * 0.0001 / denom
                phi[j, i] = -np.sum(val)

    return np.asarray(phi)
def __init__(self,
             temporal_kernel,
             spatial_kernel,
             z=None,
             conditional=None,
             sparse=True,
             opt_z=False,
             spatial_dims=None):
    """Pair a temporal and a spatial kernel with spatial inducing inputs z.

    Args:
        temporal_kernel: kernel over time.
        spatial_kernel: kernel over space.
        z: spatial inducing inputs, shape (M,) or (M, D). If None, a default
            grid on [-3, 3] is built (15 points for 1D, a 5x5 grid for 2D).
        conditional: 'DTC'/'dtc', 'FIC'/'FITC'/'fic'/'fitc' or 'Full'/'full';
            defaults to 'Full' when sparse, else 'DTC'.
        sparse: whether the model is sparse; z is never optimised otherwise.
        opt_z: store z as objax.TrainVar (trainable) instead of StateVar.
        spatial_dims: number of spatial dimensions D; inferred from z if None.

    Raises:
        NotImplementedError: z is None and spatial_dims is not 1 or 2, or the
            conditional name is not recognised.
        AssertionError: spatial_dims disagrees with the number of columns of z.
    """
    self.temporal_kernel = temporal_kernel
    self.spatial_kernel = spatial_kernel
    if conditional is None:
        conditional = 'Full' if sparse else 'DTC'
    if opt_z and (not sparse):  # z should not be optimised if the model is not sparse
        warn(
            "spatial inducing inputs z will not be optimised because sparse=False"
        )
        opt_z = False
    self.sparse = sparse
    if z is None:  # initialise z
        # TODO: smart initialisation
        if spatial_dims == 1:
            z = np.linspace(-3., 3., num=15)
        elif spatial_dims == 2:
            z1 = np.linspace(-3., 3., num=5)
            zA, zB = np.meshgrid(z1, z1)  # 5x5 grid of 2D inducing points
            # Flatten to (25, 2) for use in kernel functions
            z = np.hstack((zA.reshape(-1, 1), zB.reshape(-1, 1)))
        else:
            raise NotImplementedError(
                'please provide an initialisation for inducing inputs z')
    if z.ndim < 2:
        z = z[:, np.newaxis]
    # z is (M, D), so the number of spatial dimensions is its column count.
    if spatial_dims is None:
        spatial_dims = z.shape[1]
    assert spatial_dims == z.shape[1]
    self.M = z.shape[0]
    if opt_z:
        self.z = objax.TrainVar(z)  # .reshape(-1, 1)
    else:
        self.z = objax.StateVar(z)
    if conditional in ['DTC', 'dtc']:
        self.conditional_covariance = self.deterministic_training_conditional
    elif conditional in ['FIC', 'FITC', 'fic', 'fitc']:
        self.conditional_covariance = self.fully_independent_conditional
    elif conditional in ['Full', 'full']:
        self.conditional_covariance = self.full_conditional
    else:
        raise NotImplementedError('conditional method not recognised')
    if (not sparse) and (conditional not in ['DTC', 'dtc']):
        warn(
            "You chose a non-deterministic conditional, but \'DTC\' will be used because the model is not sparse"
        )
def get_sign2(f, *xyz, args=()):
    """Evaluate ``sign(f)`` on the full Cartesian grid spanned by ``xyz``.

    Args:
        f: called as ``f(*coords, *args)``; vectorised with ``bm.vmap`` over
            the coordinate arguments only.
        *xyz: 1D coordinate vectors (``bm.JaxArray`` values are unwrapped).
        args: extra arguments passed unchanged to every evaluation.

    Returns:
        Array of shape ``(len(xyz[0]), len(xyz[1]), ...)`` with the sign of
        ``f`` at each grid point, axis k running along ``xyz[k]``.
    """
    # Every flattened coordinate vector is 1D, so each is mapped over axis 0.
    in_axes = (0,) * len(xyz) + (None,) * len(args)
    f = bm.jit(bm.vmap(f_without_jaxarray_return(f), in_axes=in_axes))
    xyz = tuple((v.value if isinstance(v, bm.JaxArray) else v) for v in xyz)
    # 'ij' indexing places axis k along xyz[k], matching `shape` below.
    XYZ = tuple(v.flatten() for v in jnp.meshgrid(*xyz, indexing='ij'))
    shape = tuple(len(v) for v in xyz)
    return jnp.sign(f(*(XYZ + args))).reshape(shape)
def mesh_eval(func, x_limits, y_limits, params, num_ticks=101):
    """Evaluate ``func(xy, params)`` on a 2D grid.

    Args:
        func: called with one (2,) point and ``params``; vectorised with vmap
            over the points only.
        x_limits, y_limits: (low, high) ranges of the grid.
        params: passed unchanged to every call.
        num_ticks: samples per axis.

    Returns:
        (X, Y, Z), each of shape (num_ticks, num_ticks).
    """
    grid_x, grid_y = np.meshgrid(np.linspace(*x_limits, num=num_ticks),
                                 np.linspace(*y_limits, num=num_ticks))
    flat_points = np.column_stack([grid_x.ravel(), grid_y.ravel()])
    batched = vmap(func, in_axes=(0, None))
    values = batched(flat_points, params)
    return grid_x, grid_y, values.reshape(grid_x.shape)
def calc_local_egm(params, points, dt=1, dx=0.01, real_size=12, scaled_size=256, radius=3):
    '''
    Calculates the contact electrogram at the given points, summing only over
    a square window of half-width ``radius`` (cm) around each probe.

    - params: dict whose 'file' entry is the path of an HDF5 file containing a
      'states' dataset; channel 2 of each frame is used
    - points: list of tuples, each tuple contains (x,y) of location (in cm) of unipolar
    - dt: timestep between states set to 1 ms (currently unused)
    - dx: grid discretization set to 0.01 cm, used to scale the gradient
    - real_size: in cm
    - scaled_size: in pixels - matrix dimensions; maps cm to pixel indices
    - radius: half-width of the integration window in cm

    Returns an array of shape (len(points), n_frames).
    '''
    import numpy as onp  # mutable host buffer; `np` in this file may be jax.numpy

    with h5py.File(params['file'], 'r') as states_file:
        states = states_file['states']
        n_frames = states.shape[0]
        n_grid = states.shape[-1]
        phi = onp.zeros((len(points), n_frames))
        x = np.linspace(0, real_size, n_grid)
        y = np.linspace(0, real_size, n_grid)
        xv, yv = np.meshgrid(x, y, sparse=False, indexing='ij')

        # Frame-independent geometry for each probe: window slices and the
        # lead-field terms restricted to that window.
        probes = []
        for point in points:
            egm_x = point[0]
            egm_y = point[1]
            x1 = int(max(scaled_size * (egm_x - radius) / real_size, 0))
            x2 = int(min(scaled_size * (egm_x + radius) / real_size, scaled_size))
            y1 = int(max(scaled_size * (egm_y - radius) / real_size, 0))
            y2 = int(min(scaled_size * (egm_y + radius) / real_size, scaled_size))
            window = (slice(x1, x2), slice(y1, y2))
            rx = egm_x - xv[window]
            ry = egm_y - yv[window]
            denom = np.pi * 4 * 2.36 * 1 * np.sqrt(rx**2 + ry**2 + 10**(-2))**3
            probes.append((window, rx, ry, denom))

        for i in range(n_frames):
            u = states[i, 2, :, :]
            u_x = gradient(u, 0) / dx
            u_y = gradient(u, 1) / dx
            for j, (window, rx, ry, denom) in enumerate(probes):
                # Pixels outside the window contribute nothing, so only the
                # window is evaluated.
                val = (u_y[window] * ry + u_x[window] * rx) * 0.0001 / denom
                phi[j, i] = -np.sum(val)

    return np.asarray(phi)
def generate_mesh(self):
    """Build the plotting mesh over [plot_min_x1, plot_max_x1] x [plot_min_x2, plot_max_x2].

    Each axis gets ``mesh_resolution`` samples per unit of its range, rounded
    to the nearest integer because np.linspace requires an integer count.

    Returns:
        (mesh_x, mesh_y), each of shape (num_y, num_x).
    """
    num_x = int(round((self.plot_max_x1 - self.plot_min_x1) * self.mesh_resolution))
    num_y = int(round((self.plot_max_x2 - self.plot_min_x2) * self.mesh_resolution))
    mesh_x, mesh_y = np.meshgrid(
        np.linspace(self.plot_min_x1, self.plot_max_x1, num_x),
        np.linspace(self.plot_min_x2, self.plot_max_x2, num_y))
    return mesh_x, mesh_y
def blackman_kernel(dims, M):
    """Build a normalised, radially symmetric ``dims``-D kernel of side M - 2.

    Offsets are the odd integers 1-n, 3-n, ..., n-1 (n = M - 2) along every
    axis. Each grid point is weighted by ``blackman(M, |offset| / 2)``, and the
    result is normalised to sum to 1. Assumes ``blackman(M, r)`` evaluates a
    window at radius r — TODO confirm against its definition.
    """
    size = M - 2
    offsets = np.arange(1 - size, size, 2)
    grids = np.meshgrid(*([offsets] * dims))
    coords = np.stack(grids, axis=-1).reshape(-1, dims)

    radial_window = jax.vmap(lambda v: blackman(M, norm(np.float64(v)) / 2))
    weights = radial_window(coords)
    weights = weights / weights.sum()
    return weights.reshape((size,) * dims)
def get_all_indices(L):
    """
    get combinatorial indices of an LxL square (zero_indexed)

    Returns an (L*L, 2) int32 array of every (row, col) pair in row-major
    order: (0, 0), (0, 1), ..., (L-1, L-1).
    """
    axis = jnp.arange(L, dtype=jnp.int32)
    rows, cols = jnp.meshgrid(axis, axis, indexing='ij')
    return jnp.stack([rows, cols], axis=-1).reshape(-1, 2)
def plot_gradient_field(ax, func, xlimits, ylimits, numticks=30):
    """Draw a slope field on ``ax``.

    At each point (x, y) of a numticks x numticks grid, an arrow with
    direction (1, func(y, x)) is drawn. Note the argument order: ``func``
    receives y first. The axis limits are then set to the given ranges.
    """
    X, Y = np.meshgrid(np.linspace(*xlimits, num=numticks),
                       np.linspace(*ylimits, num=numticks))
    slopes = vmap(func)(Y.ravel(), X.ravel()).reshape(X.shape)
    ax.quiver(X, Y, np.ones(slopes.shape), slopes)
    ax.set_xlim(xlimits)
    ax.set_ylim(ylimits)
def dct_coefficients(N):
    """Return the N x N orthonormal DCT-II matrix.

    C[j, k] = a_j * cos(pi * (2k + 1) * j / (2N)), where a_0 = sqrt(1/N) and
    a_j = sqrt(2/N) for j > 0.
    """
    idx = jnp.arange(N)
    cols, rows = jnp.meshgrid(idx, idx)
    first_row = jnp.ones((1, N))
    other_rows = jnp.ones((N - 1, N)) + 1
    scale = jnp.sqrt(jnp.vstack([first_row, other_rows]) / N)
    return scale * jnp.cos(jnp.pi * (2 * cols + 1) * rows / (2 * N))
def make_pupil(scale, npix=128):
    """Binary circular pupil mask on an npix x npix grid.

    Coordinates span [-0.5, 0.5] on both axes. Pixels whose radius exceeds
    scale / 2 are 0 and all others are 1. The result is a mutable ``onp`` array.
    """
    coords = np.linspace(-0.5, 0.5, npix)
    gx, gy = np.meshgrid(coords, coords)
    radius = np.sqrt(gx**2 + gy**2)
    outside = radius > (scale / 2.)
    pupil = onp.ones_like(radius)
    pupil[outside] = 0
    return pupil
def precompute_log_prob_components_without_wind(kernel,
                                                X,
                                                dtec,
                                                dtec_uncert,
                                                bottom_array,
                                                width_array,
                                                lengthscale_array,
                                                sigma_array,
                                                chunksize=2):
    """
    Precompute the log-probability of every observation at every point of a
    (bottom, width, lengthscale, sigma) hyper-parameter grid.

    Args:
        kernel: called as ``kernel(X, X, bottom, width, lengthscale, 1., wind_velocity=None)``
            to build the (N, N) covariance, which is then scaled by sigma**2.
        X: coordinates passed to ``kernel``.
        dtec: observations; mapped over their leading axis (M entries).
        dtec_uncert: per-observation uncertainties, mapped alongside ``dtec``.
        bottom_array, width_array, lengthscale_array, sigma_array: 1D grids
            of sizes Nb, Nw, Nl, Ns.
        chunksize: chunk size of the outer ``chunked_pmap`` over the
            flattened (bottom, width, lengthscale) grid.

    Returns:
        Array of shape (M, Nb, Nw, Nl, Ns).
    """
    grids = jnp.meshgrid(bottom_array, width_array, lengthscale_array, indexing='ij')
    flat_bottom, flat_width, flat_length = [g.ravel() for g in grids]

    def components_for(bottom, width, lengthscale):
        # N, N
        cov = kernel(X, X, bottom, width, lengthscale, 1., wind_velocity=None)

        def per_sigma(sigma):
            def per_obs(obs, obs_uncert):
                return log_normal_with_outliers(obs, 0., sigma**2 * cov, obs_uncert)

            # M
            return chunked_pmap(per_obs, dtec, dtec_uncert, chunksize=1)

        # Ns, M
        return chunked_pmap(per_sigma, sigma_array, chunksize=1)

    n_bottom = bottom_array.shape[0]
    n_width = width_array.shape[0]
    n_length = lengthscale_array.shape[0]
    n_sigma = sigma_array.shape[0]
    n_obs = dtec.shape[0]

    # Nb*Nw*Nl, Ns, M
    log_prob = chunked_pmap(components_for, flat_bottom, flat_width, flat_length,
                            chunksize=chunksize)
    flat = log_prob.reshape((n_bottom * n_width * n_length * n_sigma, n_obs))
    # M, Nb, Nw, Nl, Ns
    return flat.T.reshape((n_obs, n_bottom, n_width, n_length, n_sigma))
def gridworld_plot_sa(env, data, title, ax=None, frame=(0, 0, 0, 0), step=None, log_plot=False):
    """
    This is going to generate a quiver plot to visualize the policy graphically.
    It is useful to see all the probabilities assigned to the four possible
    actions in each state.

    env: gridworld exposing ``ncol``/``nrow``, or a square ``size``.
    data: (num_states, num_actions) array; states follow row-major grid order
        (state = row * num_cols + col). Positive entries are drawn as green
        arrows, negative entries as red arrows pointing the opposite way.
    title: plot title; also used to build the ``plots/<title>`` logging tag.
    ax: axes to draw on (defaults to the current axes).
    frame: (x0, x1, y0, y1) number of cells cropped from the left, right, top
        and bottom of the view.
    step: logging step; required when ``log_plot`` is True.
    log_plot: save the figure, log it via ``config.tb.plot`` and clear it.
    """
    if ax is None:
        ax = plt.gca()

    num_cols = env.ncol if hasattr(env, "ncol") else env.size
    num_rows = env.nrow if hasattr(env, "nrow") else env.size

    _, num_actions = data.shape
    direction = [
        np.array((-1, 0)),  # left
        np.array((1, 0)),  # right
        np.array((0, 1)),  # up
        np.array((0, -1)),  # down
    ]

    # One arrow anchor per state on the (num_rows, num_cols) grid.
    x, y = np.meshgrid(np.arange(num_cols), np.arange(num_rows))
    x, y = x.flatten(), y.flatten()

    for base, a in zip(direction, range(num_actions)):
        quivers = np.einsum("d,m->md", base, data[:, a])

        pos = data[:, a] > 0
        ax.quiver(x[pos], y[pos], *quivers[pos].T, units='xy', scale=2.0, color='g')

        pos = data[:, a] < 0
        ax.quiver(x[pos], y[pos], *-quivers[pos].T, units='xy', scale=2.0, color='r')

    x0, x1, y0, y1 = frame
    # set axis limits / ticks / etc... so we have a nice grid overlay
    ax.set_xlim((x0 - 0.5, num_cols - x1 - 0.5))
    ax.set_ylim((y0 - 0.5, num_rows - y1 - 0.5)[::-1])  # row 0 at the top
    ax.set_xticks(np.arange(x0, num_cols - x1, 1))
    ax.xaxis.set_tick_params(labelsize=5)
    ax.set_yticks(np.arange(y0, num_rows - y1, 1))
    ax.yaxis.set_tick_params(labelsize=5)

    # minor ticks on cell borders for the grid overlay
    ax.set_xticks(np.arange(*ax.get_xlim(), 1), minor=True)
    ax.set_yticks(np.arange(*ax.get_ylim()[::-1], 1), minor=True)

    ax.grid(which='minor', color='gray', linestyle='-', linewidth=1)
    ax.set_aspect(1)

    tag = f"plots/{title}"
    ax.set_title(title, fontdict={'fontsize': 8, 'fontweight': 'medium'})
    if log_plot:
        plt.savefig(tag.replace(".", "_"))
        assert step is not None
        config.tb.plot(tag, plt, step)
        plt.clf()