def make_neur_distances(gridsizedeg, gridperdeg, hyper_col, PERIODIC=False):
    '''
    Build the grid coordinates and the pairwise neuron distance matrix.

    gridsizedeg = size of the (square) grid in degrees
    gridperdeg = number of grid points per degree
    hyper_col = hypercolumn length of the network (currently unused)
    PERIODIC = if True, distances wrap around the grid edges (torus)

    outputs:
    X = (gridsize, gridsize) x coordinate, in degrees, of every grid point
    Y = (gridsize, gridsize) y coordinate, in degrees, of every grid point
    deltaD = (gridsize**2, gridsize**2) matrix of distances (degrees)
             between neurons, used to make W; neurons are numbered in
             column-major (Fortran) order of the grid
    '''
    gridsize = 1 + round(gridperdeg * gridsizedeg)

    Lx = gridsizedeg
    Ly = gridsizedeg

    dx = Lx / (gridsize - 1)

    # linspace guarantees exactly `gridsize` samples; arange(0, L + dx, dx)
    # with a float step can produce one extra point through rounding, which
    # made X/Y disagree in size with deltaD.
    X, Y = np.meshgrid(np.linspace(0, Lx, gridsize),
                       np.linspace(0, Ly, gridsize))

    indX, indY = np.meshgrid(np.arange(gridsize), np.arange(gridsize))
    flat_x = np.ravel(indX, 'F')
    flat_y = np.ravel(indY, 'F')
    # Pairwise separations in grid steps along each axis.
    XDist = np.abs(flat_x - flat_x[:, None])
    YDist = np.abs(flat_y - flat_y[:, None])

    if PERIODIC:
        # Take the shorter way around the torus.
        XDist = np.where(XDist > gridsize / 2, gridsize - XDist, XDist)
        YDist = np.where(YDist > gridsize / 2, gridsize - YDist, YDist)
    # Grid is square, so one spacing converts steps to degrees on both axes.
    deltaD = np.sqrt(XDist**2 + YDist**2) * dx

    return X, Y, deltaD
Пример #2
0
def _make_xy_arrays(shape, On):
    x = jnp.arange(shape[1], dtype=jnp.float64)
    y = jnp.arange(shape[0], dtype=jnp.float64)
    cx, cy = jnp.meshgrid(x, y)
    x_super = jnp.linspace(0.5 / On - 0.5, shape[1] - 0.5 - 0.5 / On,
                          shape[1] * On)
    y_super = jnp.linspace(0.5 / On - 0.5, shape[0] - 0.5 - 0.5 / On,
                          shape[0] * On)
    cx_super, cy_super = jnp.meshgrid(x_super, y_super)
    return (cx, cy), (cx_super, cy_super)
Пример #3
0
def steady_probes(filepath, dt=1, dx=0.01, real_size=12, scaled_size=256):
    '''
    Calculates the contact electrogram on a fixed 4x4 grid of probes and
    writes it, with the conductivity field, to a new HDF5 file.

    The probes sit at every (x, y) with x, y in {2.4, 4.8, 7.2, 9.6} cm.
    - filepath: input HDF5 file with datasets 'states' and 'params/D';
      the output is written to filepath[:-5] + '_ecg.hdf5' (assumes a
      '.hdf5' extension - TODO confirm)
    - dt: timestep between states set to 1 ms (currently unused)
    - dx: grid discretization set to 0.01 cm
    - real_size: in cm
    - scaled_size: in pixels - matrix dimensions (currently unused)

    The output file holds 'electrogram' (n_states, 4, 4) float32, where
    entry [t, a, b] is the signal at probe (p[a], p[b]), and 'conductivity'
    (a float32 copy of params/D). Returns True.
    '''
    p = [2.4, 4.8, 7.2, 9.6]
    probe_x, probe_y = np.meshgrid(p, p, indexing='ij')
    egm_x = probe_x.reshape((1, 16))
    egm_y = probe_y.reshape((1, 16))
    filename = filepath[:-5] + '_ecg.hdf5'

    rows = []
    # Context managers guarantee both files are closed (and the output
    # flushed) even if an error occurs part-way through.
    with h5py.File(filepath, 'r') as states_file:
        states = states_file['states']
        conductivity_field = states_file["params/D"][:]
        n_states = states.shape[0]
        conductivity_shape = states.shape[-2:]
        x = np.linspace(0, real_size, states.shape[-1])
        y = np.linspace(0, real_size, states.shape[-1])

        xv, yv = np.meshgrid(x, y, sparse=False, indexing='ij')

        # Probe-to-pixel offsets and the distance term do not depend on
        # time, so they are computed once: shape (nx, ny, 16).
        dist_x = egm_x - xv[:, :, np.newaxis]
        dist_y = egm_y - yv[:, :, np.newaxis]
        denom = (np.pi * 4 * 2.36 * 1 * np.sqrt(
            dist_x**2 + dist_y**2 + 10**(-2))**3)

        for i in range(n_states):
            u = states[i, 2]
            u_x = gradient(u, 0) / dx
            u_y = gradient(u, 1) / dx
            val = -np.sum((u_y[:, :, np.newaxis] * dist_y +
                           u_x[:, :, np.newaxis] * dist_x) * 0.0001 / denom,
                          axis=(0, 1)).reshape((4, 4))
            rows.append(val)

    # Stacking avoids the removed jax.ops.index_update API.
    phi = np.stack(rows) if rows else np.zeros((n_states, 4, 4))

    with h5py.File(filename, "w") as hdf5:
        ecg_dset = hdf5.create_dataset('electrogram',
                                       shape=phi.shape,
                                       dtype="float32")
        ecg_dset[:] = phi
        conductivity = hdf5.create_dataset('conductivity',
                                           shape=conductivity_shape,
                                           dtype='float32')
        conductivity[:] = conductivity_field
    return True
Пример #4
0
def ks93(g1, g2):
    """Direct inversion of weak-lensing shear to convergence.
    This function is an implementation of the Kaiser & Squires (1993) mass
    mapping algorithm. Due to the mass sheet degeneracy, the convergence is
    recovered only up to an overall additive constant. It is chosen here to
    produce output maps of mean zero. The inversion is performed in Fourier
    space for speed.
    Parameters
    ----------
    g1, g2 : array_like
        2D input arrays corresponding to the first and second (i.e., real and
        imaginary) components of shear, binned spatially to a regular grid.
    Returns
    -------
    kE, kB : tuple of numpy arrays
        E-mode and B-mode maps of convergence.
    Raises
    ------
    AssertionError
        For input arrays of different sizes.
    See Also
    --------
    bin2d
        For binning a galaxy shear catalog.
    Examples
    --------
    >>> # (g1, g2) should in practice be measurements from a real galaxy survey
    >>> g1, g2 = 0.1 * np.random.randn(2, 32, 32) + 0.1 * np.ones((2, 32, 32))
    >>> kE, kB = ks93(g1, g2)
    >>> kE.shape
    (32, 32)
    >>> kE.mean()
    1.0842021724855044e-18
    """
    # Check consistency of input maps
    assert g1.shape == g2.shape

    # Compute Fourier space grids
    (nx, ny) = g1.shape
    k1, k2 = np.meshgrid(np.fft.fftfreq(ny), np.fft.fftfreq(nx))

    # Compute Fourier transforms of g1 and g2
    g1hat = np.fft.fft2(g1)
    g2hat = np.fft.fft2(g2)

    # Apply Fourier space inversion operator
    p1 = k1 * k1 - k2 * k2
    p2 = 2 * k1 * k2
    k2 = k1 * k1 + k2 * k2
    #k2[0, 0] = 1  # avoid division by 0
    k2 = jax.ops.index_update(k2, jax.ops.index[0, 0],
                              1.)  # avoid division by 0
    kEhat = (p1 * g1hat + p2 * g2hat) / k2
    kBhat = -(p2 * g1hat - p1 * g2hat) / k2

    # Transform back to pixel space
    kE = np.fft.ifft2(kEhat).real
    kB = np.fft.ifft2(kBhat).real

    return kE, kB
Пример #5
0
def get_ray_bundle(height, width, focal_length, tfrom_cam2world):
    """Compute one camera ray per pixel of a pinhole camera.

    Pixel (row, col) gets the camera-space direction
    ((col - width/2) / f, -(row - height/2) / f, -1). That direction is
    rotated by the upper-left 3x3 block of ``tfrom_cam2world``. Every ray
    starts at the translation column ``tfrom_cam2world[:3, -1]``.

    Returns (ray_origins, ray_directions), both of shape (height, width, 3).
    """
    cols = jnp.arange(width, dtype=jnp.float32)
    rows = jnp.arange(height, dtype=jnp.float32)
    px, py = jnp.meshgrid(cols, rows, indexing="xy")

    dir_x = (px - width * 0.5) / focal_length
    dir_y = -(py - height * 0.5) / focal_length
    dir_z = -jnp.ones_like(px)
    cam_dirs = jnp.stack([dir_x, dir_y, dir_z], axis=-1)

    rotation = tfrom_cam2world[:3, :3]
    translation = tfrom_cam2world[:3, -1]
    # Row-wise dot product with the rotation matrix, i.e. rotation @ d.
    world_dirs = jnp.sum(cam_dirs[..., None, :] * rotation, axis=-1)
    world_origins = jnp.broadcast_to(translation, world_dirs.shape)
    return world_origins, world_dirs
Пример #6
0
def show3d(state, rcount=200, ccount=200, zlim=None, figsize=None):
    """Plot a 2-D voltage map as a 3-D surface and return ``(fig, ax)``.

    state: 2-D array (rows x columns) of values, labelled in mV.
    rcount, ccount: maximum number of rows/columns sampled by plot_surface.
    zlim: optional (low, high) limits for the z axis.
    figsize: passed through to ``plt.figure``.
    """
    # setup figure
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(projection="3d")
    # make surface plot; use the real row/column counts so non-square
    # maps also work (square maps give the same grid as before)
    n_rows, n_cols = np.shape(state)
    x, y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    plot = ax.plot_surface(x,
                           y,
                           state,
                           rcount=rcount,
                           ccount=ccount,
                           cmap="magma")
    # add colorbar
    cbar = fig.colorbar(plot)
    cbar.ax.set_title("mV")
    if zlim is not None:
        ax.set_zlim3d(zlim[0], zlim[1])
    # white panes: ax.w_xaxis/w_yaxis/w_zaxis were deprecated aliases that
    # recent Matplotlib removed; the axis objects expose set_pane_color
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color((1.0, 1.0, 1.0, 1.0))
    # tick labels divide grid indices by 100, i.e. assume 0.01 cm per grid
    # point (TODO confirm for the data being plotted)
    ax.xaxis.set_major_formatter(
        FuncFormatter(lambda val, _: '{:.1f}'.format(val / 100)))
    ax.set_xlabel("x [cm]")
    ax.yaxis.set_major_formatter(
        FuncFormatter(lambda val, _: '{:.1f}'.format(val / 100)))
    ax.set_ylabel("y [cm]")
    ax.set_zlabel("Voltage [mV]")
    # crop image
    fig.tight_layout()
    return fig, ax
Пример #7
0
def evaluate_density(log_p_fn, bounds, num_points):
    """Evaluate ``exp(log_p_fn)`` on a regular 2-D grid.

    bounds: ((x_lo, x_hi), (y_lo, y_hi)); each axis gets num_points samples.
    log_p_fn receives an array of shape (num_points, num_points, 2) holding
    the (x, y) coordinates and must return log-densities for every point.

    Returns (X, Y, density), each of shape (num_points, num_points).
    """
    axes = [jnp.linspace(bounds[d][0], bounds[d][1], num=num_points)
            for d in (0, 1)]
    grid_x, grid_y = jnp.meshgrid(*axes)
    points = jnp.stack((grid_x, grid_y), axis=-1)
    density = jnp.exp(log_p_fn(points))
    return grid_x, grid_y, density
Пример #8
0
    def _get_1d_latent_grid(self, paths_x):
        """Evaluate the model's drift and diffusion on a (time, x) grid.

        The x axis uses ``num_points_grid`` evenly spaced values between the
        min and max of ``paths_x``. The time axis has one entry per step,
        i.e. ``paths_x.shape[1]`` entries.

        Returns:
            txpairs: (T * num_points_grid, 2) array of (t, x) grid points,
                with t spanning [0, 1].
            drift: flattened drift values, one per grid point.
            diffusion: flattened determinant of the diffusion, one per point.
        """
        num_points_grid = 20
        n_steps = paths_x.shape[1]

        max_x = np.amax(paths_x)
        min_x = np.amin(paths_x)
        x_grid = np.linspace(min_x, max_x, num_points_grid)
        x_array = x_grid[:, None]
        xx, tt = np.meshgrid(x_grid, np.linspace(0, 1, n_steps))
        txpairs = np.transpose(np.vstack([tt.reshape(-1), xx.reshape(-1)]))

        def scan_fn(carry, paths):
            # `paths` (one time slice) is not used; the scan only steps the
            # time index carried in `carry`.
            drift, diffusion, index = carry
            # NOTE(review): the -0.5 offset is kept as-is; its meaning is not
            # visible here - confirm against the model's time convention.
            time = index * self.config["delta_t"] - 0.5

            gp_matrices, temp_drift_function, temp_diffusion_function = self.model.build(
                self.model.model_vars())
            temp_drift = temp_drift_function(x_array, time)
            temp_diffusion = np.linalg.det(
                temp_diffusion_function(x_array, time))

            # `.at[...].add` replaces the removed jax.ops.index_add API.
            drift = drift.at[index].add(temp_drift)
            diffusion = diffusion.at[index].add(temp_diffusion)

            return (drift, diffusion, index + 1), np.array([0.])

        drift_grid = np.zeros((n_steps, num_points_grid, 1))
        diffusion_grid = np.zeros((n_steps, num_points_grid))
        (drift_grid, diffusion_grid,
         _), _ = lax.scan(scan_fn, (drift_grid, diffusion_grid, 0),
                          np.transpose(paths_x, (1, 0, 2)))
        return txpairs, drift_grid.reshape(-1), diffusion_grid.reshape(-1)
Пример #9
0
def extract_weighted_patches(x, weights, kernel, stride, padding):
    """Weighted average of patches using jax.lax.scan.

    Args:
      x: (batch, height, width, channels) input. It is padded by `padding`
        on both sides, plus an extra `kernel` rows/cols at the bottom/right.
      weights: (batch, k, out_h, out_w). The patch starting at
        (i * stride[0], j * stride[1]) is weighted by weights[:, :, i, j].
      kernel, stride, padding: (rows, cols) pairs.

    Returns:
      (batch, k, kernel[0], kernel[1], channels): the sum over all (i, j)
      of weight[b, k] * patch[b, ...].
    """
    logging.info("recompiling for kernel=%s and stride=%s and padding=%s",
                 kernel, stride, padding)
    pad_h, pad_w = padding[0], padding[1]
    padded = jnp.pad(x, ((0, 0), (pad_h, pad_h + kernel[0]),
                         (pad_w, pad_w + kernel[1]), (0, 0)))
    batch_size, channels = padded.shape[0], padded.shape[3]
    _, num_weights, out_h, out_w = weights.shape
    patch_shape = (batch_size, kernel[0], kernel[1], channels)

    def step(total, ij):
        row, col = ij
        start = (0, row * stride[0], col * stride[1], 0)
        patch = jax.lax.dynamic_slice(padded, start, patch_shape)
        coeff = weights[:, :, row, col]
        return total + jnp.einsum("bk, bijc -> bkijc", coeff, patch), None

    # Row-major list of (i, j) output positions, scanned in order.
    grid = jnp.meshgrid(jnp.arange(out_h), jnp.arange(out_w), indexing="ij")
    positions = jnp.stack(grid, axis=-1).reshape((-1, 2))

    initial = jnp.zeros((batch_size, num_weights, kernel[0], kernel[1],
                         channels))
    result, _ = jax.lax.scan(step, initial, positions)
    return result
Пример #10
0
def gen_julia(im_width=5000, im_height=5000, max_it=300, max_z=100,
              c=complex(-0.5, 0.61), x_range=(-1, 1), y_range=(-1, 1)):
    """Render an escape-time image of the Julia set of z -> z**2 + c.

    Pixel (row, col) starts at z = x + i*y, where
    x = min_x + col * (max_x - min_x) / im_width and
    y = min_y + row * (max_y - min_y) / im_height.
    It is iterated until |z| >= max_z or max_it iterations are reached.
    All arguments are optional; the defaults reproduce the original
    5000 x 5000 image.

    Returns an (im_height, im_width) array of iteration counts divided by
    max_it (1.0 means the pixel never escaped).
    """
    min_x, max_x = x_range
    min_y, max_y = y_range

    ix_array, iy_array = np.meshgrid(np.arange(im_width), np.arange(im_height))

    real_part = ((max_x - min_x) / im_width) * ix_array + min_x
    im_part = ((max_y - min_y) / im_height) * iy_array + min_y
    z = real_part + (1j) * im_part

    # Shapes follow the meshgrid output: (im_height, im_width).
    not_done = np.ones(z.shape, dtype=bool)  # pixels still being iterated
    it = np.zeros(z.shape)

    while np.any(not_done):
        not_done = np.logical_and(np.abs(z) < max_z, it < max_it)
        # Out-of-place updates that are rebound each step, so both z and the
        # counters actually advance (works for NumPy and jax.numpy alike).
        z = np.where(not_done, z**2 + c, z)
        it = np.where(not_done, it + 1, it)

    julia_set = it / max_it
    return julia_set
Пример #11
0
def generate_sample_grid(theta_mean, theta_std, n):
    """
    Create a meshgrid of n ** n_dim samples.

    Axis i tiles
    [theta_mean[i] - 5 * theta_std[i], theta_mean[i] + 5 * theta_std[i]]
    with n evenly spaced points (endpoints included).
    Also returns the volume element.

    Parameters
    ----------
    theta_mean, theta_std : ndarray (n_dim)
    n : int
        Number of points per dimension (needs n >= 2 to define a spacing).

    Returns
    -------
    theta_samples : ndarray (n ** n_dim, n_dim)

    vol_element: scalar
        Volume element: product over dimensions of the grid spacing.

    """
    n_components = theta_mean.size
    xs = [
        np.linspace(
            theta_mean[i] - 5 * theta_std[i],
            theta_mean[i] + 5 * theta_std[i],
            n,
        )
        for i in range(n_components)
    ]
    mxs = np.meshgrid(*xs)
    theta_samples = np.vstack([grid.ravel() for grid in mxs]).T
    # Each axis is uniform, so its first difference is the spacing. (Indexing
    # the differences by the dimension number raised IndexError for
    # n_dim >= n.)
    dxs = np.vstack([np.diff(axis)[0] for axis in xs])
    vol_element = np.prod(dxs)
    return theta_samples, vol_element
Пример #12
0
def cartesian_product(*arrays):
    '''
    JAX-friendly Cartesian product of 1-D arrays.

    Returns an array of shape (prod(len(a) for a in arrays), len(arrays)).
    Rows are ordered lexicographically, with the last array varying fastest.
    '''
    grids = jnp.meshgrid(*arrays, indexing='ij')
    return jnp.asarray(grids).reshape(len(arrays), -1).transpose()
 def get_pixel_grid(self, scene: Scene) -> Any:
   """Construct the (X, Y) grid of image-plane points that rays pass through.

   X spans [-1, 1] over scene.width columns. Y runs from 1/aspect (top row)
   to -1/aspect (bottom row) over scene.height rows, with
   aspect = width / height. Both arrays have shape
   (scene.height, scene.width).
   """
   aspect_ratio = scene.width / scene.height
   top = 1 / aspect_ratio
   bottom = -1. / aspect_ratio
   xs = jnp.linspace(-1., 1., scene.width)
   ys = jnp.linspace(top, bottom, scene.height)
   grid_x, grid_y = jnp.meshgrid(xs, ys)
   return grid_x, grid_y
Пример #14
0
def sampler(img_target, pose, intrinsics, rng, options):
    """
    Given a single image, sample options.num_random_rays pixels without
    replacement and return their rays and target values.

    Returns (ray_origins, ray_directions, target_s). Each is indexed by the
    sampled (row, col) pairs, so the leading axis has length
    options.num_random_rays.
    """
    cam2world = pose[:3, :4]
    all_origins, all_directions = get_ray_bundle(
        intrinsics.height, intrinsics.width, intrinsics.focal_length, cam2world
    )

    # Flat list of (row, col) pixel pairs. The meshgrid/reshape order fixes
    # which pair each sampled flat index maps to, so it must stay as is.
    pixel_grid = jnp.meshgrid(
        jnp.arange(intrinsics.height), jnp.arange(intrinsics.width), indexing="xy"
    )
    pixel_coords = jnp.stack(pixel_grid, axis=-1).reshape((-1, 2))

    flat_ids = jax.random.choice(
        rng, pixel_coords.shape[0], shape=(options.num_random_rays,), replace=False
    )
    picked = pixel_coords[flat_ids]
    rows, cols = picked[:, 0], picked[:, 1]

    return (
        all_origins[rows, cols, :],
        all_directions[rows, cols, :],
        img_target[rows, cols, :],
    )
Пример #15
0
    def _make_gabor(params: jnp.ndarray,
                    rf_dim: Tuple[int, int]) -> jnp.ndarray:
        """Build one Gabor filter per row of ``params`` and z-score it.

        params: (n, >=7) array with columns
            (sigma, theta, lambda, gamma, phi, pos_x, pos_y).
        rf_dim: half-extent of the sampling grid. x runs over
            [-rf_dim[0], rf_dim[0]) and y over [-rf_dim[1], rf_dim[1]).

        Returns ``zscore_img`` of the real part of the complex Gabor stack.
        The stack has shape (n, 2 * rf_dim[1], 2 * rf_dim[0]).
        """
        # (The return annotation uses jnp.ndarray: jnp.DeviceArray no longer
        # exists in current JAX, and annotations are evaluated at def time.)
        # Each parameter column becomes (n, 1, 1) so it broadcasts over pixels.
        σ, θ, λ, γ, φ, pos_x, pos_y = [
            params[:, col][:, jnp.newaxis, jnp.newaxis] for col in range(7)
        ]

        x, y = jnp.meshgrid(jnp.arange(-rf_dim[0], rf_dim[0]),
                            jnp.arange(-rf_dim[1], rf_dim[1]))
        # A leading singleton axis broadcasts the grid against all n filters;
        # no need to materialise n copies with jnp.repeat.
        x = x[jnp.newaxis, :, :]
        y = y[jnp.newaxis, :, :]

        # Offsets from each filter centre, rotated by θ.
        dx = pos_x - x
        dy = pos_y - y
        xp = dx * cos(θ) - dy * sin(θ)
        yp = dx * sin(θ) + dy * cos(θ)

        output = exp(-(xp**2 + (γ * yp)**2) /
                     (2 * σ**2)) * exp(1j * (2 * π * xp / λ + φ))

        return zscore_img(output.real)
Пример #16
0
 def __init__(self,
              variance=1.0,
              lengthscale_periodic=1.0,
              period=1.0,
              lengthscale_matern=1.0,
              order=6):
     """Quasi-periodic Matern-3/2 kernel.

     Hyperparameters are stored as objax.TrainVar values in softplus_inv
     (unconstrained) space.

     Args:
         variance: kernel variance.
         lengthscale_periodic: lengthscale of the periodic component.
         period: period of the periodic component.
         lengthscale_matern: lengthscale of the Matern-3/2 component.
         order: truncation order of the periodic series expansion. The
             coefficient table self.b_fmK_2igrid is
             (order + 1) x (order + 1).
     """
     self.transformed_lengthscale_periodic = objax.TrainVar(
         np.array(softplus_inv(lengthscale_periodic)))
     self.transformed_variance = objax.TrainVar(
         np.array(softplus_inv(variance)))
     self.transformed_period = objax.TrainVar(np.array(
         softplus_inv(period)))
     self.transformed_lengthscale_matern = objax.TrainVar(
         np.array(softplus_inv(lengthscale_matern)))
     super().__init__()
     self.name = 'Quasi-periodic Matern-3/2'
     self.order = order
     # igrid[n, m] == n (the row index).
     self.igrid = np.meshgrid(np.arange(self.order + 1),
                              np.arange(self.order + 1))[1]

     from math import factorial  # local import: only needed for the tables
     size = self.order + 1

     def _binom(n, k):
         return factorial(n) // (factorial(k) * factorial(n - k))

     def _cos_power_coeff(n, m):
         # Coefficient of cos(m x) in 2**n * cos(x)**n (binomial expansion):
         # C(n, n/2) for m == 0, 2 * C(n, (n - m) / 2) for 0 < m <= n with
         # n - m even, and 0 otherwise.
         if m > n or (n - m) % 2:
             return 0.
         c = _binom(n, (n - m) // 2)
         return float(c if m == 0 else 2 * c)

     # The tables are built for any `order`; for order=6 they are exactly the
     # previous hard-coded 7 x 7 tables (which broke for any other order).
     # Row n of factorial_mesh_K is n! in every column.
     factorial_mesh_K = np.array(
         [[float(factorial(n))] * size for n in range(size)])
     b = np.array([[_cos_power_coeff(n, m) for m in range(size)]
                   for n in range(size)])
     self.b_fmK_2igrid = b * (1. / factorial_mesh_K) * (2.**-self.igrid)
Пример #17
0
def make_2d_quadrature_grid(x_quad, w_quad):

    grid = jnp.meshgrid(x_quad, x_quad)
    weights = jnp.outer(w_quad, w_quad).reshape(-1)
    stacked = jnp.stack([grid[0].reshape(-1), grid[1].reshape(-1)], axis=1)

    return stacked.T, weights
Пример #18
0
def calc_round_egm(params,
                   points,
                   dt=1,
                   dx=0.01,
                   real_size=12,
                   scaled_size=256,
                   radius=3):
    '''
    Calculates the contact electrogram at the given points, integrating only
    over a disc of `radius` cm around each point.
    - params: dict; params['file'] is the HDF5 file whose 'states' are read
    - points: list of tuples, each tuple contains (x,y) of location (in cm) of unipolar
    - dt: timestep between states set to 1 ms (currently unused)
    - dx: grid discretization set to 0.01 cm
    - real_size: in cm
    - scaled_size: in pixels - matrix dimensions
    - radius: disc radius, in cm

    Returns an array of shape (len(points), n_states).
    '''
    # Integer pixel offsets covering a disc of `radi` pixels.
    radi = radius * scaled_size / real_size
    X = int(radi)
    circle = np.array([[cx, cy] for cx in range(-X, X + 1)
                       for cy in range(-int((radi * radi - cx * cx)**0.5),
                                       int((radi * radi - cx * cx)**0.5) + 1)])

    phi_rows = [[] for _ in points]
    # The context manager closes the input file even on error.
    with h5py.File(params['file'], 'r') as states_file:
        states = states_file['states']
        n_states = states.shape[0]
        x = np.linspace(0, real_size, states.shape[-1])
        y = np.linspace(0, real_size, states.shape[-1])

        xv, yv = np.meshgrid(x, y, sparse=False, indexing='ij')

        for i in range(n_states):
            u = states[i, 2, :, :]
            u_x = gradient(u, 0) / dx
            u_y = gradient(u, 1) / dx
            for row, point in zip(phi_rows, points):
                egm_x, egm_y = point[0], point[1]

                circ_mask = circle + np.array([
                    int(scaled_size * egm_x / real_size),
                    int(scaled_size * egm_y / real_size)
                ])
                # Keep only disc pixels inside the grid. Negative indices
                # would otherwise wrap around to the opposite edge.
                inside = ((circ_mask[:, 0] >= 0)
                          & (circ_mask[:, 0] < scaled_size)
                          & (circ_mask[:, 1] >= 0)
                          & (circ_mask[:, 1] < scaled_size))
                circ_mask = circ_mask[inside]
                ii = circ_mask[:, 0]
                jj = circ_mask[:, 1]

                val = (u_y[ii, jj] * (egm_y - yv[ii, jj]) + u_x[ii, jj] *
                       (egm_x - xv[ii, jj])) * 0.0001 / (
                           np.pi * 4 * 2.36 * 1 *
                           np.sqrt((egm_x - xv[ii, jj])**2 +
                                   (egm_y - yv[ii, jj])**2 + 10**(-2))**3)
                row.append(-np.sum(val))

    # Building the result at the end avoids the removed jax.ops.index_update.
    return np.array(phi_rows).reshape((len(points), n_states))
Пример #19
0
 def __init__(self,
              temporal_kernel,
              spatial_kernel,
              z=None,
              conditional=None,
              sparse=True,
              opt_z=False,
              spatial_dims=None):
     """Combine a temporal and a spatial kernel with spatial inducing inputs.

     Args:
         temporal_kernel: kernel over time (stored as-is).
         spatial_kernel: kernel over space (stored as-is).
         z: spatial inducing inputs, shape (M,) or (M, D). If None, a default
             grid on [-3, 3] is built for spatial_dims 1 (15 points) or
             2 (5 x 5 points).
         conditional: 'DTC'/'dtc', 'FIC'/'FITC'/'fic'/'fitc' or
             'Full'/'full'. Defaults to 'Full' when sparse, else 'DTC'.
         sparse: whether the model is sparse in space.
         opt_z: if True (and sparse), z is an objax.TrainVar; otherwise it
             is an objax.StateVar.
         spatial_dims: number of spatial input dimensions D; inferred from
             z when None.

     Raises:
         NotImplementedError: z is None and spatial_dims is not 1 or 2, or
             the conditional is not recognised.
         AssertionError: spatial_dims disagrees with the number of columns
             of z.
     """
     self.temporal_kernel = temporal_kernel
     self.spatial_kernel = spatial_kernel
     if conditional is None:
         conditional = 'Full' if sparse else 'DTC'
     if opt_z and not sparse:  # z should not be optimised if the model is not sparse
         warn(
             "spatial inducing inputs z will not be optimised because sparse=False"
         )
         opt_z = False
     self.sparse = sparse
     if z is None:  # initialise z
         # TODO: smart initialisation
         if spatial_dims == 1:
             z = np.linspace(-3., 3., num=15)
         elif spatial_dims == 2:
             z1 = np.linspace(-3., 3., num=5)
             zA, zB = np.meshgrid(z1, z1)
             # Flatten the grid to (25, 2) for use in kernel functions.
             z = np.hstack((zA.reshape(-1, 1), zB.reshape(-1, 1)))
         else:
             raise NotImplementedError(
                 'please provide an initialisation for inducing inputs z')
     if z.ndim < 2:
         z = z[:, np.newaxis]
     # z is (M, D), so the spatial dimensionality is its column count.
     # z.ndim - 1 is always 1 for a 2-D array, which made spatial_dims=2 fail.
     if spatial_dims is None:
         spatial_dims = z.shape[-1]
     assert spatial_dims == z.shape[-1]
     self.M = z.shape[0]
     if opt_z:
         self.z = objax.TrainVar(z)
     else:
         self.z = objax.StateVar(z)
     if conditional in ['DTC', 'dtc']:
         self.conditional_covariance = self.deterministic_training_conditional
     elif conditional in ['FIC', 'FITC', 'fic', 'fitc']:
         self.conditional_covariance = self.fully_independent_conditional
     elif conditional in ['Full', 'full']:
         self.conditional_covariance = self.full_conditional
     else:
         raise NotImplementedError('conditional method not recognised')
     # NOTE(review): the chosen conditional is still assigned above; this only
     # warns. Confirm that non-sparse models really ignore it.
     if (not sparse) and (conditional not in ['DTC', 'dtc']):
         warn(
             "You chose a non-deterministic conditional, but \'DTC\' will be used because the model is not sparse"
         )
Пример #20
0
def get_sign2(f, *xyz, args=()):
  """Return ``jnp.sign(f(...))`` evaluated on the Cartesian grid of ``xyz``.

  f: function called as ``f(x, y, ..., *args)`` on single grid points. It is
     vmapped over the flattened grid (``args`` are broadcast, not mapped)
     and jitted.
  xyz: 1-D coordinate arrays (bm.JaxArray or jax arrays), one per axis.
  args: extra positional arguments passed unchanged to every call.

  Returns an array of shape ``(len(xyz[0]), len(xyz[1]), ...)``, where axis
  k corresponds to ``xyz[k]``.
  """
  # Every grid argument below is a flattened 1-D array, so each is mapped
  # along axis 0 (mapping argument k along axis k fails for k >= 1).
  in_axes = (0,) * len(xyz) + (None,) * len(args)
  f = bm.jit(bm.vmap(f_without_jaxarray_return(f), in_axes=in_axes))
  xyz = tuple((v.value if isinstance(v, bm.JaxArray) else v) for v in xyz)
  # 'ij' indexing directly maps axis k to xyz[k] for any number of axes
  # (the previous 'xy' + moveaxis trick failed for a single axis).
  XYZ = tuple(v.flatten() for v in jnp.meshgrid(*xyz, indexing='ij'))
  shape = tuple(len(v) for v in xyz)
  return jnp.sign(f(*(XYZ + tuple(args)))).reshape(shape)
Пример #21
0
def mesh_eval(func, x_limits, y_limits, params, num_ticks=101):
    """Evaluate ``func(point, params)`` on a regular 2-D grid.

    x_limits, y_limits: (low, high) pairs; each axis gets num_ticks samples.
    func is vmapped over points of shape (2,) = (x, y); params is shared.

    Returns (X, Y, Z), each of shape (num_ticks, num_ticks).
    """
    grid_x, grid_y = np.meshgrid(np.linspace(*x_limits, num=num_ticks),
                                 np.linspace(*y_limits, num=num_ticks))
    flat_points = np.column_stack((grid_x.ravel(), grid_y.ravel()))
    values = vmap(func, in_axes=(0, None))(flat_points, params)
    return grid_x, grid_y, values.reshape(grid_x.shape)
Пример #22
0
def calc_local_egm(params,
                   points,
                   dt=1,
                   dx=0.01,
                   real_size=12,
                   scaled_size=256,
                   radius=3):
    '''
    Calculates the contact electrogram at the given points, summing only over
    a square window of half-width `radius` cm around each point.
    - params: dict; params['file'] is the HDF5 file whose 'states' are read
    - points: list of tuples, each tuple contains (x,y) of location (in cm) of unipolar
    - dt: timestep between states set to 1 ms (currently unused)
    - dx: grid discretization set to 0.01 cm
    - real_size: in cm
    - scaled_size: in pixels - matrix dimensions
    - radius: half-width of the integration window, in cm

    Returns an array of shape (len(points), n_states).
    '''
    phi_rows = [[] for _ in points]
    # The context manager closes the input file even on error.
    with h5py.File(params['file'], 'r') as states_file:
        states = states_file['states']
        n_states = states.shape[0]
        x = np.linspace(0, real_size, states.shape[-1])
        y = np.linspace(0, real_size, states.shape[-1])

        xv, yv = np.meshgrid(x, y, sparse=False, indexing='ij')

        for i in range(n_states):
            u = states[i, 2, :, :]
            u_x = gradient(u, 0) / dx
            u_y = gradient(u, 1) / dx
            for row, point in zip(phi_rows, points):
                egm_x, egm_y = point[0], point[1]
                # Pixel window around the probe, clipped to [0, scaled_size].
                x1 = int(max(scaled_size * (egm_x - radius) / real_size, 0))
                x2 = int(min(scaled_size * (egm_x + radius) / real_size,
                             scaled_size))
                y1 = int(max(scaled_size * (egm_y - radius) / real_size, 0))
                y2 = int(min(scaled_size * (egm_y + radius) / real_size,
                             scaled_size))
                # Only the window contributes, so evaluate the kernel there
                # alone. This is equivalent to zeroing everything outside
                # [x1:x2, y1:y2] and summing.
                win = (slice(x1, x2), slice(y1, y2))
                wx, wy = xv[win], yv[win]
                val = (u_y[win] * (egm_y - wy) + u_x[win] *
                       (egm_x - wx)) * 0.0001 / (np.pi * 4 * 2.36 * 1 * np.sqrt(
                           (egm_x - wx)**2 + (egm_y - wy)**2 + 10**(-2))**3)
                row.append(-np.sum(val))

    # Building the result at the end avoids the removed jax.ops.index_update.
    return np.array(phi_rows).reshape((len(points), n_states))
Пример #23
0
 def generate_mesh(self):
     """Return (mesh_x, mesh_y) grids covering the plot region.

     Each axis gets round(range * self.mesh_resolution) evenly spaced
     points, endpoints included, so mesh_resolution is points per unit
     length. Both grids have shape (n_y, n_x).
     """
     range_x = self.plot_max_x1 - self.plot_min_x1
     range_y = self.plot_max_x2 - self.plot_min_x2
     # linspace needs an integer sample count; float limits or resolutions
     # would otherwise raise TypeError.
     num_x = int(round(range_x * self.mesh_resolution))
     num_y = int(round(range_y * self.mesh_resolution))
     mesh_x, mesh_y = np.meshgrid(
         np.linspace(self.plot_min_x1, self.plot_max_x1, num_x),
         np.linspace(self.plot_min_x2, self.plot_max_x2, num_y))
     return mesh_x, mesh_y
Пример #24
0
def blackman_kernel(dims, M):
    """Build a normalised, radially symmetric ``dims``-dimensional kernel.

    With n = M - 2, each axis samples the odd offsets 1 - n, 3 - n, ..., n - 1.
    Every grid point gets ``blackman(M, r / 2)``, where r is the Euclidean
    norm of its offset vector. The kernel is divided by its sum and returned
    with shape (n,) * dims.
    """
    n = M - 2
    offsets = np.arange(1 - n, n, 2)
    grids = np.meshgrid(*[offsets] * dims)
    points = np.stack(grids, axis=-1).reshape(-1, dims)
    weigh = jax.vmap(lambda p: blackman(M, norm(np.float64(p)) / 2))
    kernel = weigh(points)
    return (kernel / kernel.sum()).reshape((n,) * dims)
Пример #25
0
def get_all_indices(L):
    """
    Get all (row, col) indices of an LxL square (zero-indexed).

    Returns an int32 array of shape (L * L, 2), in row-major order:
    [[0, 0], [0, 1], ..., [L - 1, L - 1]].
    """
    axis = jnp.arange(L, dtype=jnp.int32)
    rows, cols = jnp.meshgrid(axis, axis, indexing='ij')
    return jnp.stack([rows, cols], axis=-1).reshape(-1, 2)
Пример #26
0
def plot_gradient_field(ax, func, xlimits, ylimits, numticks=30):
    """Draw a slope field of ``func`` on ``ax``.

    func is vmapped and called as ``func(y, x)`` for every grid point. At
    each point an arrow with components (1, func(y, x)) is drawn, and the
    axes are limited to xlimits / ylimits.
    """
    grid_x, grid_y = np.meshgrid(np.linspace(*xlimits, num=numticks),
                                 np.linspace(*ylimits, num=numticks))
    slopes = vmap(func)(grid_y.ravel(), grid_x.ravel()).reshape(grid_x.shape)
    ax.quiver(grid_x, grid_y, np.ones(slopes.shape), slopes)
    ax.set_xlim(xlimits)
    ax.set_ylim(ylimits)
Пример #27
0
Файл: dct.py Проект: nmheim/esn
def dct_coefficients(N):
    """Return the N x N orthonormal DCT-II matrix.

    C[k, j] = a_k * cos(pi * (2 j + 1) * k / (2 N)), where a_0 = sqrt(1 / N)
    and a_k = sqrt(2 / N) for k > 0. C @ signal therefore gives the DCT of
    a length-N signal, and C @ C.T is the identity.
    """
    freq = jnp.arange(N)[:, None]     # row index k
    sample = jnp.arange(N)[None, :]   # column index j
    scale = jnp.sqrt(jnp.where(freq == 0, 1.0, 2.0) / N)
    return scale * jnp.cos(jnp.pi * (2 * sample + 1) * freq / (2 * N))
Пример #28
0
def make_pupil(scale, npix=128):
    """Return an npix x npix circular pupil mask on the square [-0.5, 0.5]^2.

    Pixels whose distance from the centre exceeds scale / 2 are 0; all
    others are 1. The mask is a mutable (original-numpy) array.
    """
    coords = np.linspace(-0.5, 0.5, npix)
    grid_x, grid_y = np.meshgrid(coords, coords)
    radius = np.sqrt(grid_x**2 + grid_y**2)
    pupil = onp.ones_like(radius)
    pupil[radius > (scale / 2.)] = 0
    return pupil
def precompute_log_prob_components_without_wind(kernel,
                                                X,
                                                dtec,
                                                dtec_uncert,
                                                bottom_array,
                                                width_array,
                                                lengthscale_array,
                                                sigma_array,
                                                chunksize=2):
    """
    Precompute the log_prob for every combination of hyperparameters.

    Args:
        kernel: called as kernel(X, X, bottom, width, lengthscale, 1.,
            wind_velocity=None) to build the covariance.
        X: inputs passed to `kernel`.
        dtec: observations; its leading dimension M is the first output axis.
        dtec_uncert: uncertainties, passed alongside `dtec`.
        bottom_array, width_array, lengthscale_array: 1-D grids
            (Nb, Nw, Nl), combined as a Cartesian product.
        sigma_array: 1-D grid (Ns) of amplitudes; the covariance is scaled
            by sigma**2.
        chunksize: chunk size of the outer chunked_pmap over the
            Nb*Nw*Nl combinations.

    Returns:
        Array of shape (M, Nb, Nw, Nl, Ns).
    """
    Nb = bottom_array.shape[0]
    Nw = width_array.shape[0]
    Nl = lengthscale_array.shape[0]
    Ns = sigma_array.shape[0]

    # Flattened Cartesian product in 'ij' order: (b, w, l) -> b*Nw*Nl + w*Nl + l.
    flat_params = [grid.ravel() for grid in jnp.meshgrid(bottom_array,
                                                          width_array,
                                                          lengthscale_array,
                                                          indexing='ij')]

    def compute_log_prob_components(bottom, width, lengthscale):
        # N, N
        cov = kernel(X, X, bottom, width, lengthscale, 1., wind_velocity=None)

        def _per_sigma(sigma):
            def _per_obs(obs, obs_uncert):
                return log_normal_with_outliers(obs, 0., sigma**2 * cov,
                                                obs_uncert)

            # M
            return chunked_pmap(_per_obs, dtec, dtec_uncert, chunksize=1)

        # Ns, M
        return chunked_pmap(_per_sigma, sigma_array, chunksize=1)

    # Nb*Nw*Nl, Ns, M
    log_prob = chunked_pmap(compute_log_prob_components,
                            *flat_params,
                            chunksize=chunksize)
    num_obs = dtec.shape[0]
    # -> M, Nb, Nw, Nl, Ns
    flat = log_prob.reshape((Nb * Nw * Nl * Ns, num_obs))
    return flat.transpose((1, 0)).reshape((num_obs, Nb, Nw, Nl, Ns))
Пример #30
0
def gridworld_plot_sa(env, data, title, ax=None, frame=(0, 0, 0, 0), step=None, log_plot=False):
    """
    Generate a quiver plot that visualises a gridworld policy graphically.

    For every state one arrow is drawn per action (up to four), scaled by
    that action's value in `data`. Positive values are drawn in green and
    negative values in red, pointing the opposite way.

    env: object with `ncol`/`nrow` attributes, or a square `size`.
    data: (num_states, num_actions) array; state s is drawn at
        column s % num_cols, row s // num_cols.
    title: axes title; also used as "plots/<title>" for the saved file and
        the logging tag.
    ax: axes to draw on (defaults to plt.gca()).
    frame: (x0, x1, y0, y1) = cells cropped from the left, right, top and
        bottom.
    step: logging step; required when log_plot is True.
    log_plot: if True, save the current pyplot figure, log it via
        config.tb.plot and clear it.
    """
    # Validate before any file is written.
    if log_plot and step is None:
        raise AssertionError("step is required when log_plot=True")
    if ax is None:
        ax = plt.gca()
    num_cols = env.ncol if hasattr(env, "ncol") else env.size
    num_rows = env.nrow if hasattr(env, "nrow") else env.size

    _, num_actions = data.shape

    # NOTE(review): the y axis is inverted below, so the "up" vector (0, 1)
    # is drawn pointing down on screen - confirm the env's action convention.
    direction = [
        np.array((-1, 0)),  # left
        np.array((1, 0)),  # right
        np.array((0, 1)),  # up
        np.array((0, -1)),  # down
    ]

    # Cell centres in row-major state order (s = row * num_cols + col).
    x, y = np.meshgrid(np.arange(num_cols), np.arange(num_rows))
    x, y = x.flatten(), y.flatten()

    for a, base in enumerate(direction[:num_actions]):
        quivers = np.einsum("d,m->md", base, data[:, a])

        pos = data[:, a] > 0
        ax.quiver(x[pos], y[pos], *quivers[pos].T, units='xy', scale=2.0, color='g')

        pos = data[:, a] < 0
        ax.quiver(x[pos], y[pos], *-quivers[pos].T, units='xy', scale=2.0, color='r')

    x0, x1, y0, y1 = frame
    # set axis limits / ticks / etc... so we have a nice grid overlay
    ax.set_xlim((x0 - 0.5, num_cols - x1 - 0.5))
    ax.set_ylim((y0 - 0.5, num_rows - y1 - 0.5)[::-1])

    ax.set_xticks(np.arange(x0, num_cols - x1, 1))
    ax.xaxis.set_tick_params(labelsize=5)
    ax.set_yticks(np.arange(y0, num_rows - y1, 1))
    ax.yaxis.set_tick_params(labelsize=5)

    # minor ticks on cell borders draw the grid lines
    ax.set_xticks(np.arange(*ax.get_xlim(), 1), minor=True)
    ax.set_yticks(np.arange(*ax.get_ylim()[::-1], 1), minor=True)

    ax.grid(which='minor', color='gray', linestyle='-', linewidth=1)
    ax.set_aspect(1)

    tag = f"plots/{title}"
    ax.set_title(title, fontdict={'fontsize': 8, 'fontweight': 'medium'})
    if log_plot:
        plt.savefig(tag.replace(".", "_"))
        config.tb.plot(tag, plt, step)
        plt.clf()