Ejemplo n.º 1
0
 def _validate_args(self, nlamD, npix, offset):
     """Reject arguments that the current centering mode cannot handle.

     Raises:
         RuntimeError: if the centering is ``SYMMETRIC`` and ``nlamD`` or
             ``npix`` is not a scalar, or if the centering is ``FFTSTYLE``
             or ``SYMMETRIC`` and ``offset`` is not ``None``.
     """
     mode = self.centering

     # Non-scalar sizes describe a rectangular array.
     if mode == SYMMETRIC and not (np.isscalar(nlamD)
                                   and np.isscalar(npix)):
         raise RuntimeError(
             'The selected centering mode, {}, does not support '
             'rectangular arrays.'.format(mode))

     if (mode == FFTSTYLE or mode == SYMMETRIC) and offset is not None:
         raise RuntimeError(
             'The selected centering mode, {}, does not support '
             'position offsets.'.format(mode))
Ejemplo n.º 2
0
def BatchNorm(axis=(0, 1, 2),
              epsilon=1e-5,
              center=True,
              scale=True,
              beta_init=zeros,
              gamma_init=ones):
    """Build the ``(init_fun, apply_fun)`` pair of a batch normalization layer.

    ``axis`` lists the axes that are normalized over (a scalar is wrapped into
    a 1-tuple). ``center`` / ``scale`` switch the learned offset ``beta`` and
    gain ``gamma`` on or off; a disabled parameter is stored as ``()``.
    """
    if np.isscalar(axis):
        axis = (axis, )

    def _beta_init(shape):
        return beta_init(shape) if center else ()

    def _gamma_init(shape):
        return gamma_init(shape) if scale else ()

    def init_fun(input_shape):
        # Collapse the normalized axes to size 1, then strip the leading 1s.
        collapsed = [1 if idx in axis else size
                     for idx, size in enumerate(input_shape)]
        param_shape = tuple(itertools.dropwhile(lambda s: s == 1, collapsed))
        beta = _beta_init(param_shape)
        gamma = _gamma_init(param_shape)
        return input_shape, (beta, gamma)

    def apply_fun(params, x, rng=None):
        beta, gamma = params
        mean = np.mean(x, axis, keepdims=True)
        var = fastvar(x, axis, keepdims=True)
        normalized = (x - mean) / np.sqrt(var + epsilon)
        out = gamma * normalized if scale else normalized
        return out + beta if center else out

    return init_fun, apply_fun
Ejemplo n.º 3
0
    def dot(self, x):
        """Apply this operator to another operator, a scalar or an array.

        Parameters
        ----------
        x : LinearOperator, scalar or array_like
            A 1-d or 2-d array stands for a vector or a matrix.

        Returns
        -------
        Ax : LinearOperator or array
            A product operator when ``x`` is a ``LinearOperator``, a scaled
            operator when ``x`` is a scalar, otherwise the result of
            ``matvec`` (1-d input or a single column) or ``matmat`` (any
            other 2-d input).

        Raises
        ------
        ValueError
            If ``x`` is an array that is neither 1-d nor 2-d.

        """
        if isinstance(x, LinearOperator):
            return _ProductLinearOperator(self, x)
        if np.isscalar(x):
            return _ScaledLinearOperator(self, x)

        x = np.asarray(x)
        looks_like_vector = x.ndim == 1 or (x.ndim == 2 and x.shape[1] == 1)
        if looks_like_vector:
            return self.matvec(x)
        if x.ndim == 2:
            return self.matmat(x)
        raise ValueError('expected 1-d or 2-d array or matrix, got %r'
                         % x)
Ejemplo n.º 4
0
 def __init__(self,
              loc=0.,
              covariance_matrix=None,
              precision_matrix=None,
              scale_tril=None,
              validate_args=None):
     """Set up the distribution from a location and one scale description.

     The first of ``covariance_matrix``, ``precision_matrix`` and
     ``scale_tril`` that is not ``None`` (checked in that order) is used;
     ``self.scale_tril`` is always populated from it.

     Raises:
         ValueError: if all three matrix arguments are ``None``.
     """
     if np.isscalar(loc):
         loc = np.expand_dims(loc, axis=-1)
     # Give loc a temporary trailing axis so it can be shape-promoted
     # together with a matrix; the axis is squeezed away again below.
     loc = loc[..., np.newaxis]

     if covariance_matrix is not None:
         loc, promoted = promote_shapes(loc, covariance_matrix)
         self.covariance_matrix = promoted
         self.scale_tril = np.linalg.cholesky(self.covariance_matrix)
     elif precision_matrix is not None:
         loc, promoted = promote_shapes(loc, precision_matrix)
         self.precision_matrix = promoted
         self.scale_tril = cholesky_of_inverse(self.precision_matrix)
     elif scale_tril is not None:
         loc, promoted = promote_shapes(loc, scale_tril)
         self.scale_tril = promoted
     else:
         raise ValueError(
             'One of `covariance_matrix`, `precision_matrix`, `scale_tril`'
             ' must be specified.')

     tril_shape = np.shape(self.scale_tril)
     batch_shape = lax.broadcast_shapes(np.shape(loc)[:-2], tril_shape[:-2])
     event_shape = tril_shape[-1:]
     squeezed_loc = np.squeeze(loc, axis=-1)
     self.loc = np.broadcast_to(squeezed_loc, batch_shape + event_shape)
     super(MultivariateNormal, self).__init__(batch_shape=batch_shape,
                                              event_shape=event_shape,
                                              validate_args=validate_args)
Ejemplo n.º 5
0
def minibatch(batch_or_batchsize, num_obs_total=None):
    """Create a context that treats all samples as a minibatch of a larger
    data set.

    The context is a ``scale`` handler whose factor is
    ``num_obs_total / batch_size``, so the (log)likelihood of the sampled
    examples is scaled up to the whole data set.

    :param batch_or_batchsize: Either an integer batch size, or an array whose
        first-axis length is taken as the batch size.
    :param num_obs_total: Total number of examples/observations in the full
        data set. Optional, defaults to the batch size.
    :raises TypeError: if an integer batch size is not a plain scalar (for
        example because it is traced by jit), or if the argument is neither
        an integer nor an array.
    """
    given_as_int = is_int_scalar(batch_or_batchsize)
    if given_as_int and not jnp.isscalar(batch_or_batchsize):
        raise TypeError(
            "if a scalar is given for batch_or_batchsize, it "
            "can't be traced through jit. consider using static_argnums "
            "for the jit invocation.")

    if given_as_int:
        batch_size = batch_or_batchsize
    elif is_array(batch_or_batchsize):
        batch_size = example_count(batch_or_batchsize)
    else:
        raise TypeError("batch_or_batchsize must be an array or an integer")

    total = batch_size if num_obs_total is None else num_obs_total
    return scale(scale=total / batch_size)
Ejemplo n.º 6
0
 def initialize(cls,
                key,
                in_spec,
                axis=(0, 1),
                momentum=0.99,
                epsilon=1e-5,
                center=True,
                scale=True,
                beta_init=stax.zeros,
                gamma_init=stax.ones):
     """Create the parameters, static info and state of a batch-norm layer.

     ``beta`` / ``gamma`` take the input shape with the normalized axes
     removed (or are ``()`` when ``center`` / ``scale`` is off). The moving
     mean and variance keep every axis, with normalized axes set to size 1.
     The stored decay is ``1.0 - momentum``.
     """
     in_shape = in_spec.shape
     if np.isscalar(axis):
         axis = (axis, )
     decay = 1.0 - momentum

     param_shape = tuple(size for idx, size in enumerate(in_shape)
                         if idx not in axis)
     stats_shape = tuple(1 if idx in axis else size
                         for idx, size in enumerate(in_shape))

     beta_key, gamma_key, mean_key, var_key = random.split(key, 4)

     beta = ()
     if center:
         beta = base.create_parameter(beta_key, param_shape, init=beta_init)
     gamma = ()
     if scale:
         gamma = base.create_parameter(gamma_key, param_shape,
                                       init=gamma_init)
     moving_mean = base.create_parameter(mean_key, stats_shape,
                                         init=stax.zeros)
     moving_var = base.create_parameter(var_key, stats_shape, init=stax.ones)

     return base.LayerParams(
         BatchNormParams(beta, gamma),
         BatchNormInfo(axis, epsilon, center, scale, decay, in_shape),
         BatchNormState(moving_mean, moving_var))
Ejemplo n.º 7
0
def BatchNorm(axis=(0, 1, 2),
              epsilon=1e-5,
              center=True,
              scale=True,
              beta_init=zeros,
              gamma_init=ones):
    """Build the ``(init_fun, apply_fun)`` pair of a batch normalization layer.

    ``axis`` lists the axes that are normalized over (a scalar is wrapped into
    a 1-tuple). ``center`` / ``scale`` switch the learned offset ``beta`` and
    gain ``gamma`` on or off; a disabled parameter is stored as ``()``.
    """
    if jnp.isscalar(axis):
        axis = (axis, )

    def _beta_init(rng, shape):
        return beta_init(rng, shape) if center else ()

    def _gamma_init(rng, shape):
        return gamma_init(rng, shape) if scale else ()

    def init_fun(rng, input_shape):
        # Parameters live on the axes that are not normalized over.
        param_shape = tuple(size for idx, size in enumerate(input_shape)
                            if idx not in axis)
        beta_key, gamma_key = random.split(rng)
        beta = _beta_init(beta_key, param_shape)
        gamma = _gamma_init(gamma_key, param_shape)
        return input_shape, (beta, gamma)

    def apply_fun(params, x, **kwargs):
        beta, gamma = params
        # TODO(phawkins): jnp.expand_dims should accept an axis tuple.
        # (https://github.com/numpy/numpy/issues/12290)
        # Until then, re-insert the normalized axes through indexing.
        expand = tuple(None if idx in axis else slice(None)
                       for idx in range(jnp.ndim(x)))
        normalized = standardize(x, axis, epsilon=epsilon)
        out = gamma[expand] * normalized if scale else normalized
        return out + beta[expand] if center else out

    return init_fun, apply_fun
Ejemplo n.º 8
0
def BatchNorm(axis=(0, 1, 2),
              epsilon=1e-5,
              center=True,
              scale=True,
              beta_init=zeros,
              gamma_init=ones):
    """Build a parametrized batch normalization layer.

    ``axis`` lists the axes that are normalized over (a scalar is wrapped into
    a 1-tuple). The ``'gamma'`` parameter is only created when ``scale`` is
    true and the ``'beta'`` parameter only when ``center`` is true.
    """
    if np.isscalar(axis):
        axis = (axis, )

    @parametrized
    def batch_norm(x):
        # Index that re-inserts the normalized axes into a parameter.
        expand = tuple(None if idx in axis else slice(None)
                       for idx in range(np.ndim(x)))
        mean = np.mean(x, axis, keepdims=True)
        var = fastvar(x, axis, keepdims=True)
        out = (x - mean) / np.sqrt(var + epsilon)
        param_shape = tuple(size for idx, size in enumerate(x.shape)
                            if idx not in axis)

        if scale:
            out = out * parameter(param_shape, gamma_init, x, 'gamma')[expand]
        if center:
            out = out + parameter(param_shape, beta_init, x, 'beta')[expand]
        return out

    return batch_norm
Ejemplo n.º 9
0
def BatchNorm(axis=(0, 1, 2),
              epsilon=1e-5,
              center=True,
              scale=True,
              beta_init=zeros,
              gamma_init=ones):
    """Layer construction function for a batch normalization layer.

    Args:
        axis: axis or tuple of axes to normalize over; a scalar is wrapped
            into a 1-tuple.
        epsilon: value added to the variance before taking the square root.
        center: if true, a learned offset ``beta`` is added.
        scale: if true, the output is multiplied by a learned ``gamma``.
        beta_init: initializer called as ``beta_init(rng, shape)``.
        gamma_init: initializer called as ``gamma_init(rng, shape)``.

    Returns:
        An ``(init_fun, apply_fun)`` pair. A disabled parameter is stored as
        the empty tuple ``()``.
    """
    _beta_init = lambda rng, shape: beta_init(rng, shape) if center else ()
    _gamma_init = lambda rng, shape: gamma_init(rng, shape) if scale else ()
    axis = (axis, ) if np.isscalar(axis) else axis

    def init_fun(rng, input_shape):
        # Parameters live on the axes that are not normalized over.
        shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
        # NOTE(review): the same rng feeds both initializers, so random
        # initializers would see identical keys - consider splitting it.
        beta, gamma = _beta_init(rng, shape), _gamma_init(rng, shape)
        return input_shape, (beta, gamma)

    def apply_fun(params, x, **kwargs):
        beta, gamma = params
        # TODO(phawkins): np.expand_dims should accept an axis tuple.
        # (https://github.com/numpy/numpy/issues/12290)
        ed = tuple(None if i in axis else slice(None)
                   for i in range(np.ndim(x)))
        mean, var = np.mean(x, axis, keepdims=True), fastvar(x,
                                                             axis,
                                                             keepdims=True)
        z = (x - mean) / np.sqrt(var + epsilon)
        # Index a parameter only when it is enabled: a disabled beta/gamma is
        # the empty tuple `()`, and `()[ed]` raises TypeError.
        if center and scale: return gamma[ed] * z + beta[ed]
        if center: return z + beta[ed]
        if scale: return gamma[ed] * z
        return z

    return init_fun, apply_fun
Ejemplo n.º 10
0
    def init_fn(key, R, box, mass=f32(1.0), **kwargs):
        """Build the initial ``NPTNoseHooverState`` for positions ``R``.

        Velocities are drawn from a normal distribution scaled by
        ``sqrt(kT / mass)`` and then shifted by the mass-weighted mean over
        axis 0. A ``kT`` entry in ``kwargs`` overrides the enclosing ``kT``
        for the velocity draw and for the barostat/thermostat chains.
        """
        n_particles, spatial_dim = R.shape

        _kT = kwargs.get('kT', kT)

        mass = quantity.canonicalize_mass(mass)
        velocity = jnp.sqrt(_kT / mass) * random.normal(
            key, R.shape, dtype=R.dtype)
        velocity = velocity - jnp.mean(
            velocity * mass, axis=0, keepdims=True) / mass
        kinetic = quantity.kinetic_energy(velocity, mass)

        # The box position is defined via pos = (1 / d) log V / V_0.
        zero = jnp.zeros((), dtype=R.dtype)
        one = jnp.ones((), dtype=R.dtype)
        box_position = zero
        box_velocity = zero
        # NOTE(review): this uses the enclosing `kT`, not the possibly
        # overridden `_kT` - confirm that is intended.
        tau = barostat_kwargs['tau']
        box_mass = spatial_dim * (n_particles + 1) * kT * tau**2 * one
        kinetic_box = quantity.kinetic_energy(box_velocity, box_mass)

        if jnp.isscalar(box) or box.ndim == 0:
            # TODO(schsam): This is necessary because of JAX issue #5849.
            box = jnp.eye(R.shape[-1]) * box

        force = force_fn(R, box=box, **kwargs)
        barostat_state = barostat.initialize(1, kinetic_box, _kT)
        thermostat_state = thermostat.initialize(R.size, kinetic, _kT)
        return NPTNoseHooverState(R, velocity, force, mass,
                                  box, box_position, box_velocity, box_mass,
                                  barostat_state,
                                  thermostat_state)  # pytype: disable=wrong-arg-count
Ejemplo n.º 11
0
def initialize_dim_names(variable_names: List[Text], state: State):
    """
    Derive one column name per scalar component of a sample state.

    A variable whose value is a scalar (or has length 1) contributes its own
    name. A variable of length ``d > 1`` contributes ``name_1`` ... ``name_d``.

    Args:
        variable_names: Names of the state variables in the ODE integration run.
        state: Sample state; its entries are paired with ``variable_names``
            in order.

    Returns:
        A flat list of dimension names.
    """
    dim_names = []

    for name, value in zip(variable_names, state):
        width = 1 if jnp.isscalar(value) else len(value)
        if width == 1:
            dim_names.append(name)
            continue
        # Component suffixes are 1-based.
        dim_names.extend("{0}_{1}".format(name, component)
                         for component in range(1, width + 1))

    return dim_names
Ejemplo n.º 12
0
def convert_to_dict(state: State, model_metadata: Dict[Text, Any], dim_names: List[Text]):
    """
    Flatten a state into a ``{dimension name: value}`` mapping.

    The variables listed under ``model_metadata["variable_names"]`` are walked
    in order. A scalar entry uses the next name in ``dim_names``; a
    non-scalar entry uses the next ``len(entry)`` names, one per component.

    Args:
        state: ODE state obtained in the numerical integration run.
        model_metadata: Model metadata saved in the run.
        dim_names: Names of dimensions in the ODE.

    Returns:
        A dict with dimension names as keys and the matching values.
    """
    flattened = {}
    cursor = 0  # position of the next unused entry in dim_names

    for position, _ in enumerate(model_metadata["variable_names"]):
        value = state[position]

        if jnp.isscalar(value):
            flattened[dim_names[cursor]] = value
            cursor += 1
            continue

        count = len(value)
        flattened.update(zip(dim_names[cursor:cursor + count], value))
        cursor += count

    return flattened
Ejemplo n.º 13
0
 def __init__(self, A, alpha):
     """Store the operator ``A`` together with the scalar factor ``alpha``.

     Raises:
         ValueError: if ``A`` is not a ``LinearOperator`` or ``alpha`` is not
             a scalar.
     """
     operator_ok = isinstance(A, LinearOperator)
     if not operator_ok:
         raise ValueError('LinearOperator expected as A')
     if not np.isscalar(alpha):
         raise ValueError('scalar expected as alpha')

     # The result dtype is derived from A and the type of alpha.
     result_dtype = _get_dtype([A], [type(alpha)])
     super(_ScaledLinearOperator, self).__init__(result_dtype, A.shape)
     self.args = (A, alpha)
Ejemplo n.º 14
0
def make_schedule(step_size):
    """Turn a step-size specification into a step-size schedule.

    Args:
        step_size: either a callable, which is returned unchanged, or a
            scalar, which is wrapped with ``make_constant_schedule``.

    Returns:
        The schedule for ``step_size``.

    Raises:
        ValueError: if ``step_size`` is neither callable nor a scalar.
    """
    # `type(x) is callable` compares a type with the builtin function and is
    # never true; `callable(x)` is the actual "can this be called" check.
    if callable(step_size):
        return step_size
    if np.isscalar(step_size):
        return make_constant_schedule(step_size)
    raise ValueError(f'Unsupported type for `step_size`: {type(step_size)}')
Ejemplo n.º 15
0
def volume(dimension: int, box: Box) -> float:
    """Return the volume described by ``box``.

    A scalar (or 0-d array) gives ``box**dimension``, a vector gives the
    product of its entries, and a matrix gives its determinant.

    Raises:
        ValueError: if ``box`` has more than two dimensions.
    """
    if jnp.isscalar(box):
        return box**dimension

    rank = box.ndim
    if rank == 0:
        return box**dimension
    if rank == 1:
        return jnp.prod(box)
    if rank == 2:
        return jnp.linalg.det(box)
    raise ValueError(('Box must be either: a scalar, a vector, or a matrix. '
                      f'Found {box}.'))
Ejemplo n.º 16
0
def inverse(box: Box) -> Box:
    """Invert the affine transformation described by ``box``.

    A scalar, a single-element array or a vector is inverted elementwise as
    ``1 / box``; a matrix is inverted with ``jnp.linalg.inv``.

    Raises:
        ValueError: if ``box`` is an array with more than two dimensions and
            more than one element.
    """
    elementwise = jnp.isscalar(box) or box.size == 1 or box.ndim == 1
    if elementwise:
        return 1 / box
    if box.ndim == 2:
        return jnp.linalg.inv(box)
    raise ValueError(('Box must be either: a scalar, a vector, or a matrix. '
                      f'Found {box}.'))
Ejemplo n.º 17
0
def errorfill(x, y, yerr, color="r", alpha_fill=0.3, ax=None):
    """Plot ``y`` against ``x`` as a line with a shaded error band.

    Args:
        x: x coordinates.
        y: y coordinates; must support arithmetic with ``yerr`` when ``yerr``
            is a symmetric error.
        yerr: a scalar or a sequence as long as ``y`` (symmetric error, band
            is ``y - yerr`` to ``y + yerr``), or a pair ``(ymin, ymax)``
            giving the band edges directly.
        color: color for both the line and the band. ``None`` lets the axes
            pick the line color, which the band then reuses.
        alpha_fill: opacity of the band.
        ax: axes to draw on; defaults to ``plt.gca()``.

    Raises:
        ValueError: if ``yerr`` matches none of the supported shapes.
    """
    ax = ax if ax is not None else plt.gca()

    # Work out the band before drawing, so a bad `yerr` leaves the axes
    # untouched instead of failing halfway through.
    if np.isscalar(yerr) or len(yerr) == len(y):
        ymin = y - yerr
        ymax = y + yerr
    elif len(yerr) == 2:
        ymin, ymax = yerr
    else:
        raise ValueError(
            "yerr must be a scalar, have the same length as y, "
            "or be a (ymin, ymax) pair")

    if color is None:
        # Let the axes choose the next color, then reuse it for the band.
        # This avoids the private, removed `_get_lines.color_cycle` API.
        line, = ax.plot(x, y)
        color = line.get_color()
    else:
        ax.plot(x, y, color=color)
    ax.fill_between(x, ymax, ymin, color=color, alpha=alpha_fill)
    def collect_metrics(self, batch, env_ids, logits, logs, lr, model_params):
        """Assemble the metrics dict for one step.

        Starts from ``self.metrics_fn(logits, batch, env_ids, model_params)``,
        records ``lr`` under ``'learning_rate'`` and, when ``logs`` is a dict,
        merges it in: scalar entries keep their key, every other entry is
        stored as its mean under ``'mean_<key>'``.
        """
        metrics_dict = self.metrics_fn(logits, batch, env_ids, model_params)
        metrics_dict['learning_rate'] = lr

        if not isinstance(logs, dict):
            return metrics_dict

        for name, value in logs.items():
            if jnp.isscalar(value):
                metrics_dict[name] = value
            else:
                metrics_dict[f'mean_{name}'] = jnp.mean(value)
        return metrics_dict
Ejemplo n.º 19
0
    def _check_for_aliasing(self, wavelengths):
        """ Check for spatial frequency aliasing and warn if the
        user is requesting a FOV which is larger than supported based on
        the available pupil resolution in the optical system entrance pupil.
        If the requested FOV of the output PSF exceeds that which is Nyquist
        sampled in the entrance pupil, raise a warning to the user.

        The check implemented here is fairly simple, designed to catch the most
        common cases, and makes assumptions about the optical system which are
        not necessarily true in all cases, specifically that it starts with a
        pupil plane with fixed spatial resolution and ends with a detector
        plane. If either of those assumptions is violated, this check is skipped.

        See https://github.com/mperrin/morphine/issues/135 and
        https://github.com/mperrin/morphine/issues/180 for more background on the
        relevant Fourier optics.
        """
        # Note this must be called after self.optsys is defined in calc_psf()

        # compute spatial sampling in the entrance pupil
        if not hasattr(
                self.optsys.planes[0],
                'pixelscale') or self.optsys.planes[0].pixelscale is None:
            return  # analytic entrance pupil, no sampling limitations.
        if not isinstance(self.optsys.planes[-1], morphine_core.Detector):
            return  # optical system doesn't end on some fixed sampling detector, not sure how to check sampling limit

        # determine the spatial frequency which is Nyquist sampled by the input pupil.
        # convert this to units of cycles per meter and make it not a Quantity
        sf = (1. / (self.optsys.planes[0].pixelscale * 2))

        det_fov_arcsec = self.optsys.planes[-1].fov_arcsec
        if np.isscalar(
                det_fov_arcsec):  # FOV can be scalar (square) or rectangular
            det_fov_arcsec = (det_fov_arcsec, det_fov_arcsec)

        # determine the angular scale that corresponds to for the given wavelength
        for wl in wavelengths:
            critical_angle_arcsec = wl * sf * morphine_core._RADIANStoARCSEC
            if (critical_angle_arcsec < det_fov_arcsec[0] / 2) or (
                    critical_angle_arcsec < det_fov_arcsec[1] / 2):
                import warnings
                warnings.warn((
                    "For wavelength {:.3f} microns, a FOV of {:.3f} * {:.3f} arcsec exceeds the maximum "
                    +
                    " spatial frequency well sampled by the input pupil. Your computed PSF will suffer from "
                    +
                    "aliasing for angles beyond {:.3f} arcsec radius.").format(
                        wl * 1e6, det_fov_arcsec[0], det_fov_arcsec[1],
                        critical_angle_arcsec))
Ejemplo n.º 20
0
def _contains_query(vals, query):
    if isinstance(query, tuple):
        return map(partial(_contains_query, vals), query)

    if np.isnan(query):
        if np.any(np.isnan(vals)):
            raise FoundValue('NaN')
    elif np.isinf(query):
        if np.any(np.isinf(vals)):
            raise FoundValue('Found Inf')
    elif np.isscalar(query):
        if np.any(vals == query):
            raise FoundValue(str(query))
    else:
        raise ValueError('Malformed Query: {}'.format(query))
Ejemplo n.º 21
0
Archivo: util.py Proyecto: byzhang/d3p
def is_scalar(x):
    """Tell whether the input can be interpreted as a scalar.

    Besides actual scalars this accepts anything with a shape whose
    dimensions multiply to one, whatever the number of dimensions. A
    (jax.)numpy array of shape (1,1,1,1) therefore counts as a scalar.

    Works with jax.jit.
    """
    if jnp.isscalar(x):
        return True

    # note(lumip): under jax.jit a scalar argument arrives as a jax.numpy
    #   array without dimensions but with a shape attribute. For such a
    #   value jnp.isscalar(x) and is_array(x) are both False, so the shape
    #   has to be inspected through has_shape(x).
    shaped = has_shape(x)
    if not shaped:
        return shaped

    element_count = 1
    for extent in jnp.shape(x):
        element_count = element_count * extent
    return element_count == 1
Ejemplo n.º 22
0
Archivo: kde.py Proyecto: romanngg/jax
    def __init__(self, dataset, bw_method=None, weights=None):
        """Validate the inputs and precompute the kernel covariance.

        Args:
            dataset: data to estimate from. It is promoted to at least 2-d
                and unpacked as ``(d, n)``; ``n`` is the length required of
                ``weights``.
            bw_method: ``'scott'`` (also used for ``None``), ``'silverman'``,
                a non-string scalar used directly as the bandwidth factor,
                or a callable invoked as ``bw_method(self)``.
            weights: optional 1-d weights of length ``n``. They are
                normalised to sum to one. Defaults to uniform ``1 / n``.

        Raises:
            NotImplementedError: if ``dataset`` has a complex dtype.
            ValueError: if ``dataset`` does not have more than one element,
                if ``weights`` is not one-dimensional or not of length ``n``,
                or if ``bw_method`` is of an unsupported kind.
        """
        _check_arraylike("gaussian_kde", dataset)
        dataset = jnp.atleast_2d(dataset)
        if jnp.issubdtype(lax.dtype(dataset), jnp.complexfloating):
            raise NotImplementedError(
                "gaussian_kde does not support complex data")
        if not dataset.size > 1:
            raise ValueError("`dataset` input should have multiple elements.")

        d, n = dataset.shape
        if weights is not None:
            _check_arraylike("gaussian_kde", weights)
            dataset, weights = _promote_dtypes_inexact(dataset, weights)
            weights = jnp.atleast_1d(weights)
            # Normalise so the weights sum to one (done before the shape
            # checks below).
            weights /= jnp.sum(weights)
            if weights.ndim != 1:
                raise ValueError("`weights` input should be one-dimensional.")
            if len(weights) != n:
                raise ValueError("`weights` input should be of length n")
        else:
            dataset, = _promote_dtypes_inexact(dataset)
            # No weights given: uniform weights in the dataset's dtype.
            weights = jnp.full(n, 1.0 / n, dtype=dataset.dtype)

        self._setattr("dataset", dataset)
        self._setattr("weights", weights)
        # neff = 1 / sum(weights**2). This relies on `_setattr` returning
        # the stored value; `_setattr` is defined outside this view.
        neff = self._setattr("neff", 1 / jnp.sum(weights**2))

        bw_method = "scott" if bw_method is None else bw_method
        if bw_method == "scott":
            factor = jnp.power(neff, -1. / (d + 4))
        elif bw_method == "silverman":
            factor = jnp.power(neff * (d + 2) / 4.0, -1. / (d + 4))
        elif jnp.isscalar(bw_method) and not isinstance(bw_method, str):
            # A numeric scalar is used directly as the bandwidth factor;
            # strings are excluded explicitly.
            factor = bw_method
        elif callable(bw_method):
            factor = bw_method(self)
        else:
            raise ValueError(
                "`bw_method` should be 'scott', 'silverman', a scalar, or a callable."
            )

        # Weighted covariance of the data and its inverse, then scaled by
        # the squared bandwidth factor.
        data_covariance = jnp.atleast_2d(
            jnp.cov(dataset, rowvar=1, bias=False, aweights=weights))
        data_inv_cov = jnp.linalg.inv(data_covariance)
        covariance = data_covariance * factor**2
        inv_cov = data_inv_cov / factor**2
        self._setattr("covariance", covariance)
        self._setattr("inv_cov", inv_cov)
Ejemplo n.º 23
0
def sph_harm(m: jnp.ndarray,
             n: jnp.ndarray,
             theta: jnp.ndarray,
             phi: jnp.ndarray,
             n_max: Optional[int] = None) -> jnp.ndarray:
    r"""Computes the spherical harmonics.

    Compared with `scipy.special.sph_harm`, the JAX version takes one extra
    argument, `n_max`, the maximum value in `n`.

    The spherical harmonic of degree `n` and order `m` is
    :math:`Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \phi) * \exp(i m \theta)`,
    with the normalization factor
    :math:`N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!}
    {4 \pi \left(n+m\right)!}}`. Here :math:`\phi` is the colatitude and
    :math:`\theta` the longitude. :math:`N_n^m` is chosen so that the
    spherical harmonics form an orthonormal basis of :math:`L^2(S^2)`.

    Args:
      m: The order of the harmonic; must have `|m| <= n`. Return values for
        `|m| > n` are undefined.
      n: The degree of the harmonic; must have `n >= 0`. Descriptions of
        spherical harmonics usually call the degree `l` (lower case L); `n`
        is used here to match `scipy.special.sph_harm`. Return values for
        `n < 0` are undefined.
      theta: The azimuthal (longitudinal) coordinate; must be in [0, 2*pi].
      phi: The polar (colatitudinal) coordinate; must be in [0, pi]. A scalar
        is wrapped into a one-element array.
      n_max: The maximum degree `max(n)`; defaults to `jnp.max(n)`. If the
        supplied `n_max` is not the true maximum of `n`, results are clipped
        to `n_max`. For example,
        `sph_harm(m=jnp.array([2]), n=jnp.array([10]), theta, phi, n_max=6)`
        actually returns
        `sph_harm(m=jnp.array([2]), n=jnp.array([6]), theta, phi, n_max=6)`.

    Returns:
      A 1D array containing the spherical harmonics at (m, n, theta, phi).
    """
    phi = jnp.array([phi]) if jnp.isscalar(phi) else phi

    requested_max = jnp.max(n) if n_max is None else n_max
    # n_max has to be a concrete Python int so it can be used statically.
    static_max = core.concrete_or_error(
        int, requested_max,
        'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
        'be statically specified to use `sph_harm` within JAX transformations.'
    )

    return _sph_harm(m, n, theta, phi, static_max)
Ejemplo n.º 24
0
    def __init__(self, existing_optical_system, oversample=8, occulter_box=1.0,
                 fpm_index=1, **kwargs):
        """Wrap an existing optical system as a semi-analytic coronagraph.

        Args:
            existing_optical_system: system whose name, verbosity, source
                offsets, planes, npix and pupil diameter are copied onto
                this instance.
            oversample: divisor applied to the detector pixel scale when
                building the high-resolution occulter plane.
            occulter_box: scalar or sequence; the oversampled occulter plane
                is built with ``fov_arcsec`` equal to twice this value.
            fpm_index: index into ``planes`` of the occulter (focal plane
                mask). The Lyot plane is taken to be the plane right after.
            **kwargs: forwarded to the parent class constructor.

        Raises:
            ValueError: if the occulter, Lyot or detector plane does not
                have the expected plane type (image, pupil, detector).
        """
        from . import optics
        super(SemiAnalyticCoronagraph, self).__init__(**kwargs)

        # Mirror the relevant attributes of the wrapped system.
        self.name = "SemiAnalyticCoronagraph for " + existing_optical_system.name
        self.verbose = existing_optical_system.verbose
        self.source_offset_r = existing_optical_system.source_offset_r
        self.source_offset_theta = existing_optical_system.source_offset_theta
        self.planes = existing_optical_system.planes
        self.npix = existing_optical_system.npix
        self.pupil_diameter = existing_optical_system.pupil_diameter

        # SemiAnalyticCoronagraphs have some mandatory planes, so give them reasonable names:
        self.fpm_index = fpm_index
        self.occulter = self.planes[fpm_index]
        self.lyotplane = self.planes[fpm_index + 1]
        self.detector = self.planes[-1]

        # some tweaks for display
        self.occulter.wavefront_display_hint = 'intensity'
        self.lyotplane.wavefront_display_hint = 'intensity'

        self.mask_function = optics.InverseTransmission(self.occulter)

        # Check that each mandatory plane has the plane type it needs.
        pt = morphine_core.PlaneType
        for label, plane, typecode in zip(["Occulter (plane {})".format(fpm_index),
                                           "Lyot (plane {})".format(fpm_index + 1),
                                           "Detector (last plane)"],
                                          [self.occulter, self.lyotplane, self.detector],
                                          [pt.image, pt.pupil, pt.detector]):
            if not plane.planetype == typecode:
                # NOTE(review): the backslash continuation keeps the next
                # line's leading spaces inside the message text.
                raise ValueError("Plane {0} is not of the right type for a semianalytic \
                        coronagraph calculation: should be {1:s} but is {2:s}.".format(label,
                                                                                       typecode, plane.planetype))

        self.oversample = oversample

        if not np.isscalar(occulter_box):
            occulter_box = np.array(occulter_box)  # cast to numpy array so the multiplication by 2
                                                  # just below will work
        self.occulter_box = occulter_box

        # High-resolution detector standing in for the occulter plane: the
        # pixel scale is the detector's divided by `oversample`.
        self.occulter_highres = morphine_core.Detector(self.detector.pixelscale / self.oversample,
                                                    fov_arcsec=self.occulter_box * 2,
                                                    name='Oversampled Occulter Plane')
Ejemplo n.º 25
0
    def new_parameters(self, input_shape, input_dtype, rng):
        """Create the initial batch norm parameters and running statistics.

        Returns ``((beta, gamma), (running_mean, running_var, num_batches))``.
        ``beta`` / ``gamma`` cover the axes that are not normalized over, or
        are ``()`` when centering / scaling is disabled. The running
        statistics keep every axis, with normalized axes set to size 1.
        """
        del input_dtype, rng
        reduce_axes = self._axis
        if np.isscalar(reduce_axes):
            reduce_axes = (reduce_axes, )

        kept_dims = []
        stats_dims = []
        for index, size in enumerate(input_shape):
            if index in reduce_axes:
                stats_dims.append(1)
            else:
                kept_dims.append(size)
                stats_dims.append(size)
        param_shape = tuple(kept_dims)
        stats_shape = tuple(stats_dims)

        beta = ()
        if self._center:
            beta = np.zeros(param_shape, dtype='float32')
        gamma = ()
        if self._scale:
            gamma = np.ones(param_shape, dtype='float32')

        running_mean = np.zeros(stats_shape, dtype=np.float32)
        running_var = np.zeros(stats_shape, dtype=np.float32)
        num_batches = np.zeros((), dtype=np.int32)
        return (beta, gamma), (running_mean, running_var, num_batches)
Ejemplo n.º 26
0
def isscalar(num):
    r"""

    Check for a scalar with a looser definition than :func:`numpy.isscalar` and
    :func:`jax.numpy.isscalar`: single-item arrays count as scalars as well.

    Parameters
    ----------
    num : number or ndarray

        Value to inspect.

    Returns
    -------
    isscalar : bool

        Whether the input is either a number or a single-item array.

    """
    if jnp.isscalar(num):
        return True
    if not isinstance(num, (jnp.ndarray, onp.ndarray)):
        return False
    return jnp.size(num) == 1
Ejemplo n.º 27
0
    def __init__(self, existing_optical_system, oversample=4, occulter_box=1.0,
                 **kwargs):
        """Wrap an existing optical system as a MatrixFTCoronagraph.

        The name, verbosity, source offsets, planes, npix and pupil diameter
        are copied from ``existing_optical_system``. A non-scalar
        ``occulter_box`` is stored as a numpy array.

        Raises:
            ValueError: if the existing system has fewer than 4 planes.
        """
        super(MatrixFTCoronagraph, self).__init__(**kwargs)

        source = existing_optical_system
        if len(source.planes) < 4:
            raise ValueError("Input optical system must have at least 4 planes "
                             "to be convertible into a MatrixFTCoronagraph")

        self.name = "MatrixFTCoronagraph for " + source.name
        for attribute in ('verbose', 'source_offset_r', 'source_offset_theta',
                          'planes', 'npix', 'pupil_diameter'):
            setattr(self, attribute, getattr(source, attribute))

        self.oversample = oversample

        # Sequences are cast to a numpy array so they support arithmetic
        # such as multiplication by 2.
        self.occulter_box = (occulter_box if np.isscalar(occulter_box)
                             else np.array(occulter_box))
Ejemplo n.º 28
0
def transform(box: Box, R: Array) -> Array:
    """Apply an affine transformation to positions.

    See `periodic_general` for a description of the semantics of `box`.

    Args:
      box: An affine transformation described in `periodic_general`: a scalar
        (or single-element array), a vector, or a matrix.
      R: Array of positions. Should have shape `(..., spatial_dimension)`.

    Returns:
      The transformed positions, of shape `(..., spatial_dimension)`.

    Raises:
      ValueError: if `box` is an array with more than two dimensions and more
        than one element.
    """
    if jnp.isscalar(box) or box.size == 1:
        return R * box

    rank = box.ndim
    if rank == 1:
        # Scale each coordinate by the matching box entry.
        coords = _get_free_indices(R.ndim - 1) + 'i'
        return jnp.einsum(f'i,{coords}->{coords}', box, R)
    if rank == 2:
        # Contract the matrix with the last axis of R.
        batch = _get_free_indices(R.ndim - 1)
        source = batch + 'j'
        target = batch + 'i'
        return jnp.einsum(f'ij,{source}->{target}', box, R)
    raise ValueError(('Box must be either: a scalar, a vector, or a matrix. '
                      f'Found {box}.'))
Ejemplo n.º 29
0
def convex_interpolate(x, y, lmbdas):
    """Interpolate between `x` and `y` with convex combinations.

    Each interpolation is `x * (1 - lmbda) + y * lmbda`; the interpolation
    axis and the batch axis are flattened into one leading axis.

    Args:
      x: float jnp array; `[bs, ...]`
      y: float jnp array; `[bs, ...]`
      lmbdas: float array; `[num_of_interpolations, bs]`, or a scalar

    Returns:
      z (interpolated states) with shape `[bs x num_of_interpolations, ...]`
    """
    # TODO(samiraabnar): Make sure this method is not redefined else where and is
    # just reused.
    assert x.shape == y.shape, f'x.shape != y.shape, {x.shape} != {y.shape}'

    if not jnp.isscalar(lmbdas):
        if len(x.shape) > (len(lmbdas.shape) - 1):
            # Append x's trailing axes so the weights broadcast against it.
            target_shape = lmbdas.shape + x.shape[1:]
            lmbdas = jax.lax.broadcast_in_dim(lmbdas,
                                              shape=target_shape,
                                              broadcast_dimensions=(0, 1))

    from_x = x[None, Ellipsis] * (1 - lmbdas)
    from_y = y[None, Ellipsis] * (lmbdas)
    mixed = from_x + from_y
    return mixed.reshape((-1, ) + mixed.shape[2:])
Ejemplo n.º 30
0
 def parametric_beizer_curve(t, bends):  # bends: (num_bends, nL)
     """Combine ``bends`` with weights ``coef * t**n * (1 - t)**n_rev``.

     ``coef``, ``n`` and ``n_rev`` come from the enclosing scope. A
     non-scalar ``t`` is reshaped into a column vector first.
     """
     column_t = t if jnp.isscalar(t) else t.reshape(-1, 1)
     rising = jnp.power(column_t, n)
     falling = jnp.power(1. - column_t, n_rev)
     weights = coef * rising * falling
     return jnp.dot(weights, bends)