コード例 #1
0
def test_jax_nonbonded_block():
    """Check that nonbonded_block agrees with nonbonded_v3_on_specific_pairs.

    The water box is split into a left group (the first ``split`` particles)
    and a right group (the rest). The block kernel is evaluated on the two
    groups directly, and the pairwise kernel on every explicit (left, right)
    pair between them; the two total energies must match.
    """
    system, positions, box, _ = builders.build_water_system(3.0)
    bps, _masses = openmm_deserializer.deserialize_system(system, cutoff=1.2)
    nonbonded = bps[-1]
    params = nonbonded.params
    conf = positions.value_in_unit(unit.nanometer)

    n_atoms = conf.shape[0]
    beta = nonbonded.get_beta()
    cutoff = nonbonded.get_cutoff()
    split = 70

    # Enumerate all (i, j) with i in the left group and j in the right group;
    # right-group indices are shifted by `split` into global particle indices.
    left_grid, right_grid = np.indices((split, n_atoms - split))
    indices_left = left_grid.flatten()
    indices_right = right_grid.flatten() + split

    def block_energy(x, box, params):
        return nonbonded_block(x[:split], x[split:], box,
                               params[:split], params[split:], beta, cutoff)

    def pairwise_energy(x, box, params):
        vdw, es = nonbonded_v3_on_specific_pairs(x, params, box, indices_left,
                                                 indices_right, beta, cutoff)
        return np.sum(vdw + es)

    actual = block_energy(conf, box, params)
    expected = pairwise_energy(conf, box, params)
    onp.testing.assert_almost_equal(actual, expected)
コード例 #2
0
def filled_circle_aa(shape,
                     xcenter,
                     ycenter,
                     radius,
                     xarray=None,
                     yarray=None,
                     fillvalue=1,
                     clip=True,
                     cliprange=(0, 1)):
    """Draw a filled circle with subpixel antialiasing into a new array.

    Parameters
    ----------
    shape : tuple of int
        Shape of the 2-D array to create and return.
    xcenter, ycenter : float
        (X, Y) coordinates of the circle center, in the coordinate system
        given by ``xarray``/``yarray`` if supplied, otherwise in pixel indices.
    radius : float
        Radius of the circle.
    xarray, yarray : 2d ndarrays, optional
        X and Y coordinates of the center of each pixel. If either is
        omitted, integer pixel indices (``np.indices(shape)``) are used for
        both. WARNING - border-pixel values are approximate when the pixel
        scale is not 1.
    fillvalue : float
        Value written into pixels whose centers lie inside the radius
        (the output starts from zeros, so this is the final value there).
        Pixels within one pixel scale of the edge are instead overwritten with
        ``pixwt(...) * fillvalue / pixscale**2``, where ``pixwt`` presumably
        returns the circle/pixel overlap area.
    clip : bool
        If True, clip the output values to ``cliprange``.
    cliprange : sequence of two numbers
        (min, max) bounds used when ``clip`` is True.

    Returns
    -------
    ndarray
        Array of shape ``shape`` containing the antialiased circle.

    Raises
    ------
    ValueError
        If ``clip`` is True and ``cliprange`` does not have exactly 2 items.
    """
    import warnings

    # Validate up front (an ``assert`` would be stripped under ``python -O``).
    if clip and len(cliprange) != 2:
        raise ValueError("cliprange must contain exactly two values (min, max)")

    array = np.zeros(shape)

    if xarray is None or yarray is None:
        yarray, xarray = np.indices(shape)

    # Distance of every pixel center from the circle center.
    r = np.sqrt((xarray - xcenter)**2 + (yarray - ycenter)**2)
    # First pass: pixels whose centers fall inside the circle get the full value.
    array = index_update(array, r < radius, fillvalue)

    # Pixel scale is inferred from the spacing of the first row of xarray.
    pixscale = np.abs(xarray[0, 1] - xarray[0, 0])
    area_per_pix = pixscale**2

    # The condition fires for any scale differing from 1, so the message says so.
    if np.abs(pixscale - 1.0) > 0.01:
        warnings.warn(
            'filled_circle_aa may not yield exact results for grey pixels when pixel scale != 1'
        )

    # Second pass: pixels near the edge are overwritten with a fractional weight.
    border = np.where(np.abs(r - radius) < pixscale)
    weights = pixwt(xcenter, ycenter, radius, xarray[border], yarray[border])
    array = index_update(array, border, weights * fillvalue / area_per_pix)

    if clip:
        return np.asarray(array).clip(*cliprange)
    return array
コード例 #3
0
ファイル: jax_utils.py プロジェクト: fehomi/timemachine
def get_group_group_indices(n: int, m: int) -> Tuple[Array, Array]:
    """Enumerate every index pair (i, j) with 0 <= i < n and 0 <= j < m.

    Returns two flat arrays of length ``n * m`` in row-major order, so ``i``
    varies slowest: ``(0, 0), (0, 1), ..., (n - 1, m - 1)``.
    """
    row_grid, col_grid = np.indices((n, m))
    flat_rows = row_grid.flatten()
    flat_cols = col_grid.flatten()

    # Sanity check: the grid covers exactly n * m interactions.
    assert len(flat_rows) == n * m

    return flat_rows, flat_cols
コード例 #4
0
ファイル: spectral.py プロジェクト: CosmoStat/jax-lensing
def radial_profile(data):
    """
    Compute the azimuthally averaged radial profile of a 2d image.

    Pixels are binned by the integer part of their distance from the image
    center ``(data.shape[0] / 2, data.shape[1] / 2)`` (row, column), and the
    pixel values in each bin are averaged. Each axis uses its own center, so
    non-square images are handled correctly.

    :param data: 2d image
    :return: 1d array whose entry ``k`` is the mean of the pixels with
        ``k <= r < k + 1``; a bin with no pixels would yield NaN (0 / 0)
    """
    center_y = data.shape[0] / 2
    center_x = data.shape[1] / 2
    y, x = jnp.indices(data.shape)
    r = jnp.sqrt((x - center_x)**2 + (y - center_y)**2)
    # Truncate distances to integer radial bin indices.
    r = r.astype('int32')

    tbin = jnp.bincount(r.ravel(), data.ravel())  # sum of pixel values per bin
    nr = jnp.bincount(r.ravel())                  # pixel count per bin
    radialprofile = tbin / nr
    return radialprofile
コード例 #5
0
def airy_2d(diameter=1.0,
            wavelength=1e-6,
            shape=(512, 512),
            pixelscale=0.010,
            obscuration=0.0,
            center=None):
    """ 2-dimensional Airy function PSF calculator

    Parameters
    ----------
    diameter : float
        Aperture diameter in meters.
    wavelength : float
        Wavelength in meters.
    shape : tuple
        Shape of the output array.
    pixelscale : float
        Pixel scale in arcseconds per pixel.
    obscuration : float, optional
        Linear central-obscuration ratio (secondary diameter divided by the
        aperture diameter); must satisfy ``0 <= obscuration < 1``.
    center : tuple, optional
        (y, x) pixel coordinates of the PSF center. Defaults to the center
        of the array, ``(shape - 1) / 2``.

    Returns
    -------
    ndarray
        Annular-aperture Airy pattern, normalized so that it tends to 1 at
        zero radius.

    Raises
    ------
    ValueError
        If ``obscuration`` is outside ``[0, 1)``; the normalization factor
        ``1 / (1 - obscuration**2)**2`` is undefined at 1.
    """
    if not 0.0 <= obscuration < 1.0:
        raise ValueError(
            "obscuration must be a linear ratio in [0, 1), got {!r}".format(
                obscuration))

    if center is None:
        center = (np.asarray(shape) - 1.) / 2
    y, x = np.indices(shape, dtype=float)
    y -= center[0]
    x -= center[1]
    y *= pixelscale
    x *= pixelscale
    r = np.sqrt(x**2 + y**2)  # radial offset in arcseconds

    radius = float(diameter) / 2.0

    k = 2 * np.pi / wavelength  # wavenumber
    v = k * radius * r * _ARCSECtoRAD  # dimensionless Airy argument
    e = obscuration

    # Avoid 0/0 at v == 0 by substituting machine epsilon (not the smallest
    # positive float); the expression below tends to 1 there.
    v[v == 0] = np.finfo(v.dtype).eps

    # Annular aperture: [2 J1(v)/v - e^2 * 2 J1(e v)/(e v)]^2 / (1 - e^2)^2
    # see e.g. Schroeder, Astronomical Optics, 2nd ed. page 248
    airy = 1. / (1 - e**2)**2 * (
        (2 * scipy.special.jn(1, v) - e * 2 * scipy.special.jn(1, e * v)) /
        v)**2
    return airy
コード例 #6
0
def sinc2_2d(width=1.0,
             height=None,
             wavelength=1e-6,
             shape=(512, 512),
             pixelscale=0.010,
             center=None):
    """
    Create a 2D sinc function PSF, representing the PSF of a square or rectangular aperture

    Parameters
    -----------
    width : float
        Width in meters of the aperture.
    height : float, optional
        height in meters of the aperture. If not specified, the aperture is assumed
        to be a square so height=width
    wavelength : float
        wavelength in meters
    shape : tuple with 2 elements
        shape of array to create
    pixelscale : float
        pixel scale in arcseconds per pixel
    center : tuple with 2 elements, optional
        (y, x) center coordinates of the PSF. Defaults to center of array.

    Returns
    -------
    ndarray
        ``[sin(a)/a]^2 * [sin(b)/b]^2`` with ``a = pi * width * theta_x / wavelength``
        (likewise for ``b``); equals 1 at the center and has its first nulls
        at ``theta = wavelength / width`` (resp. ``/ height``).
    """

    if height is None:
        height = width
    halfwidth = float(width) / 2
    halfheight = float(height) / 2

    if center is None:
        center = (np.asarray(shape) - 1.) / 2
    y, x = np.indices(shape, float)
    y -= center[0]
    x -= center[1]
    y *= pixelscale
    x *= pixelscale

    k = 2 * np.pi / wavelength  # wavenumber
    # alpha = k * (width / 2) * theta = pi * width * theta / wavelength
    alpha = k * x * halfwidth * _ARCSECtoRAD
    beta = k * y * halfheight * _ARCSECtoRAD

    # np.sinc is the *normalized* sinc, sin(pi t) / (pi t). alpha and beta
    # already contain the factor pi, so divide it back out; otherwise the
    # pattern is compressed by a factor of pi.
    psf = (np.sinc(alpha / np.pi))**2 * (np.sinc(beta / np.pi))**2

    return psf
コード例 #7
0
ファイル: fresnel.py プロジェクト: quesmax/morphine
    def __init__(self,
                 beam_radius,
                 units=u.m,
                 rayleigh_factor=2.0,
                 oversample=2,
                 **kwargs):
        """
        Wavefront for Fresnel diffraction calculation.

        This class inherits from and extends the Fraunhofer-domain
        morphine.Wavefront class.


        Parameters
        --------------------
        beam_radius : float
            Radius of the illuminated beam at the initial optical plane,
            presumably in meters (no unit conversion is performed here).
            I.e. this would be the pupil aperture radius in an entrance pupil.
            Passed to the parent as ``diam=2 * beam_radius`` and stored as ``w_0``.
        units : astropy.units.Unit
            Currently ignored: ``self.units`` is always set to ``'m'``.
            Kept for interface compatibility.
        rayleigh_factor : float
            Threshold for considering a wave spherical, in units of the
            Rayleigh distance.
        oversample : int or float
            Padding factor to apply to the wavefront array, multiplying on top
            of the beam radius. Values below 2 are not recommended for
            reliable Fresnel propagation.

        Raises
        ------
        ValueError
            If ``self.planetype`` is ``PlaneType.image`` after the parent
            constructor runs (i.e. a pixelscale was given instead of a diameter).


        References
        -------------------
        - Lawrence, G. N. (1992), Optical Modeling, in Applied Optics and Optical Engineering., vol. XI,
            edited by R. R. Shannon and J. C. Wyant., Academic Press, New York.

        - https://en.wikipedia.org/wiki/Gaussian_beam

        - IDEX Optics and Photonics(n.d.), Gaussian Beam Optics,
            [online] Available from:
            https://marketplace.idexop.com/store/SupportDocuments/All_About_Gaussian_Beam_OpticsWEB.pdf

        - Krist, J. E. (2007), PROPER: an optical propagation library for IDL,
            vol. 6675, p. 66750P-66750P-9.
            [online] Available from: http://dx.doi.org/10.1117/12.731179

        - Andersen, T., and A. Enmark (2011), Integrated Modeling of Telescopes, Springer Science & Business Media.

        """
        super(FresnelWavefront, self).__init__(diam=beam_radius * 2.0,
                                               oversample=oversample,
                                               **kwargs)
        # Unit for measuring distance; always meters regardless of `units`.
        self.units = 'm'

        # Beam waist radius at the initial plane (stored as given).
        self.w_0 = beam_radius
        # Current wavefront coordinate along the optical axis.
        self.z = 0
        # Coordinate along the optical axis of the latest beam waist.
        self.z_w0 = 0
        # Beam waist radii, in series as encountered during propagation.
        self.waists_w0 = [self.w_0]
        # Beam waist distances along the optical axis, in series as encountered
        # during propagation.
        self.waists_z = [self.z_w0]
        # Is this wavefront spherical (True) or planar (False)?
        self.spherical = False
        # Wavenumber.
        self.k = np.pi * 2.0 / self.wavelength
        # Threshold for considering a wave spherical, in units of Rayleigh distance.
        self.rayleigh_factor = rayleigh_factor
        # Focal length of the current beam, or infinity if not a focused beam.
        self.focal_length = np.inf

        # Add padding for oversampling, unless the parent already padded.
        if self.oversample > 1 and not self.ispadded:
            self.wavefront = utils.pad_to_oversample(self.wavefront,
                                                     self.oversample)
            self.ispadded = True
            # '{0}' rather than '{0:d}': the latter raises ValueError for a
            # float oversample, and both render integers identically.
            logmsg = "Padded WF array for oversampling by {0}, to {1}.".format(
                self.oversample, self.wavefront.shape)
            self.history.append(logmsg)
        # NOTE: oversample < 2 is not recommended for Fresnel propagation, but
        # no warning is emitted (the log calls in this module are disabled).

        # Pixel coordinates relative to the array center, cached for later use.
        self._y, self._x = np.indices(self.shape, dtype=float)
        self._y = self._y - (self.wavefront.shape[0]) / 2.0
        self._x = self._x - (self.wavefront.shape[1]) / 2.0

        # FIXME MP: this self.n attribute appears unnecessary?
        if self.shape[0] == self.shape[1]:
            self.n = self.shape[0]
        else:
            self.n = self.shape

        if self.planetype == PlaneType.image:
            raise ValueError(
                "Input wavefront needs to be a pupil plane in units of m/pix. Specify a diameter not a pixelscale."
            )
コード例 #8
0
def position_offsets(height, width):
    """Return a (height, width, 2) integer grid of pixel indices.

    Entry ``[i, j]`` holds ``[i, j]``: the row index followed by the column index.
    """
    row_idx, col_idx = jnp.indices((height, width))
    return jnp.stack((row_idx, col_idx), axis=-1)
コード例 #9
0
def indices(dimensions, dtype=None, sparse=False):
  """Wrap ``jnp.indices``, returning JaxArray results (default dtype int32).

  With ``sparse=False`` the single dense index array is wrapped. With
  ``sparse=True``, ``jnp.indices`` returns a tuple of broadcastable arrays of
  different shapes, so each one is wrapped individually.
  """
  dtype = jnp.int32 if dtype is None else dtype
  res = jnp.indices(dimensions, dtype=dtype, sparse=sparse)
  if isinstance(res, tuple):
    return tuple(JaxArray(r) for r in res)
  return JaxArray(res)
コード例 #10
0
ix = 128
iy = 128

# Initialise the PRNG key and the 3x3 4-neighbour kernel, laid out as HWIO
# with a single input and a single output channel (float32).
key = random.PRNGKey(12234)
kernel = jnp.array([[0, 1, 0],
                    [1, 0, 1],
                    [0, 1, 0]], dtype=jnp.float32)[:, :, jnp.newaxis, jnp.newaxis]

# Only the rank of the two shapes matters here; the layout strings determine
# the resulting dimension numbers.
dn = lax.conv_dimension_numbers((K, iy, ix, 1),
                                kernel.shape,
                                ('NHWC', 'HWIO', 'NHWC'))

# Checkerboard of shape (K, iy, ix, 1): parity of the sum of all four indices.
mask = jnp.indices((K, iy, ix, 1)).sum(axis=0) % 2

def checkerboard_pattern1(x):
  """Return the checkerboard plane that is 0 at the origin pixel.

  ``x`` is unused; it only gives ``vmap`` something to map over.
  """
  return mask[0, ..., 0]

def checkerboard_pattern2(x):
  """Return the checkerboard plane that is 1 at the origin pixel.

  ``x`` is unused; it only gives ``vmap`` something to map over.
  """
  return mask[1, ..., 0]

def make_checkerboard_pattern1():
  """Return K stacked copies of checkerboard_pattern1 with a trailing channel axis."""
  dummy = jnp.array(K * [1])
  stacked = vmap(checkerboard_pattern1, in_axes=0)(dummy)
  return stacked[..., jnp.newaxis]

def make_checkerboard_pattern2():
  """Return K stacked copies of checkerboard_pattern2 with a trailing channel axis."""
  dummy = jnp.array(K * [1])
  stacked = vmap(checkerboard_pattern2, in_axes=0)(dummy)
  return stacked[..., jnp.newaxis]
コード例 #11
0
ファイル: camera.py プロジェクト: dukebw/nerfies
 def get_pixel_centers(self):
     """Return the coordinates of every pixel center.

     The result has shape ``(*self.image_shape, ndim)``. Along the last axis
     the coordinates are listed last-axis first (x before y for a 2-D image),
     and each is offset by 0.5 to point at the pixel center.
     """
     grid = jnp.indices(self.image_shape, dtype=self.dtype)
     flipped = grid[::-1]  # put the last (column / x) index first
     return jnp.moveaxis(flipped, 0, -1) + 0.5