def loop_body(inputs):
     rng, parameters, summaries, distances, n_accepted, iteration = \
         inputs
     rng, key = jax.random.split(rng)
     parameter_samples = self.prior.sample(n_simulations, seed=key)
     rng, key = jax.random.split(rng)
     summary_samples = self.compressor(
         self.simulator(key, parameter_samples))
     distance_samples = jax.vmap(
         lambda target, F: self.distance_measure(
             summary_samples, target, F))(self.target_summaries, self.F)
     indices = jax.lax.dynamic_slice(
         np.arange(n_simulations * max_iterations),
         [n_simulations * iteration], [n_simulations])
     # Write this iteration's samples into the preallocated buffers.
     # (jax.ops.index_update has been removed from JAX; .at[].set() is
     # the equivalent current API.)
     parameters = parameters.at[indices].set(parameter_samples)
     summaries = summaries.at[indices].set(summary_samples)
     distances = distances.at[:, indices].set(distance_samples)
     n_accepted = np.int32(np.less(distances, ϵ).sum(1))
     return rng, parameters, summaries, distances, n_accepted, \
         iteration + np.int32(1)
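A loop body with this shape is typically driven by `jax.lax.while_loop`. A minimal sketch of such a driver, with hypothetical buffer shapes and a purely iteration-count-based stopping rule (none of these names come from the original source):

import jax
import jax.numpy as jnp

def run_abc(rng, loop_body, n_targets, n_params, n_summaries,
            n_simulations, max_iterations):
    # Preallocate flat buffers for every possible simulation, then loop.
    total = n_simulations * max_iterations
    init = (rng,
            jnp.zeros((total, n_params)),            # parameters
            jnp.zeros((total, n_summaries)),         # summaries
            jnp.full((n_targets, total), jnp.inf),   # distances
            jnp.zeros(n_targets, dtype=jnp.int32),   # n_accepted
            jnp.int32(0))                            # iteration

    def cond_fn(state):
        # Stop once the iteration budget is exhausted; an acceptance-count
        # condition could be combined in here as well.
        return state[-1] < max_iterations

    return jax.lax.while_loop(cond_fn, loop_body, init)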
Example #2
 def testEnumPromotion(self):
   class AnEnum(enum.IntEnum):
     A = 42
     B = 101
   np.testing.assert_equal(np.array(42), np.array(AnEnum.A))
   np.testing.assert_equal(jnp.array(42), jnp.array(AnEnum.A))
   np.testing.assert_equal(np.int32(101), np.int32(AnEnum.B))
   np.testing.assert_equal(jnp.int32(101), jnp.int32(AnEnum.B))
Example #3
 def init(θ, data):
     y = data[1]
     e = errors(θ, x, y)
     i = np.int32(0)  # Iterations counter
     k = np.int32(1)  # Convergence counter
     C = obj.cost(e)
     R = obj.regularizer(θ)
     G = np.float32(β * C + α * R)  # Objective function
     return LMBTrainingState(θ, e, G, C, R, (α, β), μi, τ, i, k)
Example #4
 def testEnumPromotion(self):
   class AnEnum(enum.IntEnum):
     A = 42
     B = 101
   onp.testing.assert_equal(onp.array(42), onp.array(AnEnum.A))
   with core.skipping_checks():
     # Passing AnEnum.A to np.array fails the type check in bind
     onp.testing.assert_equal(np.array(42), np.array(AnEnum.A))
   onp.testing.assert_equal(onp.int32(101), onp.int32(AnEnum.B))
   onp.testing.assert_equal(np.int32(101), np.int32(AnEnum.B))
Example #5
def rescale_jax(x: ep.JAXTensor, target_shape: List[int]) -> ep.JAXTensor:
    # img must be in channel_last format

    # modified according to https://github.com/google/jax/issues/862
    import jax.numpy as np

    img = x.raw

    resize_rates = (target_shape[1] / x.shape[1], target_shape[2] / x.shape[2])

    def interpolate_bilinear(  # type: ignore
            im: np.ndarray, rows: np.ndarray, cols: np.ndarray) -> np.ndarray:
        # based on http://stackoverflow.com/a/12729229
        col_lo = np.floor(cols).astype(int)
        col_hi = col_lo + 1
        row_lo = np.floor(rows).astype(int)
        row_hi = row_lo + 1

        def cclip(cols: np.ndarray) -> np.ndarray:  # type: ignore
            return np.clip(cols, 0, ncols - 1)

        def rclip(rows: np.ndarray) -> np.ndarray:  # type: ignore
            return np.clip(rows, 0, nrows - 1)

        nrows, ncols = im.shape[-3:-1]

        Ia = im[..., rclip(row_lo), cclip(col_lo), :]
        Ib = im[..., rclip(row_hi), cclip(col_lo), :]
        Ic = im[..., rclip(row_lo), cclip(col_hi), :]
        Id = im[..., rclip(row_hi), cclip(col_hi), :]

        wa = np.expand_dims((col_hi - cols) * (row_hi - rows), -1)
        wb = np.expand_dims((col_hi - cols) * (rows - row_lo), -1)
        wc = np.expand_dims((cols - col_lo) * (row_hi - rows), -1)
        wd = np.expand_dims((cols - col_lo) * (rows - row_lo), -1)

        return wa * Ia + wb * Ib + wc * Ic + wd * Id

    nrows, ncols = img.shape[-3:-1]
    deltas = (0.5 / resize_rates[0], 0.5 / resize_rates[1])

    rows = np.linspace(deltas[0], nrows - deltas[0],
                       np.int32(resize_rates[0] * nrows))
    cols = np.linspace(deltas[1], ncols - deltas[1],
                       np.int32(resize_rates[1] * ncols))
    rows_grid, cols_grid = np.meshgrid(rows - 0.5, cols - 0.5, indexing="ij")

    img_resize_vec = interpolate_bilinear(img, rows_grid.flatten(),
                                          cols_grid.flatten())
    img_resize = img_resize_vec.reshape(img.shape[:-3] +
                                        (len(rows), len(cols)) +
                                        img.shape[-1:])

    return ep.JAXTensor(img_resize)
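A hedged usage sketch for `rescale_jax` (the batch shape here is illustrative; the function expects channel-last images):

import eagerpy as ep
import jax.numpy as jnp

# Downscale a 1x32x32x3 batch to 16x16 via bilinear interpolation.
x = ep.JAXTensor(jnp.ones((1, 32, 32, 3)))
y = rescale_jax(x, [1, 16, 16, 3])
print(y.shape)  # (1, 16, 16, 3)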
Example #6
 def test_pushes_and_pops(self):
     stack = Stack.create(7, jnp.zeros((), jnp.int32))
     stack = stack.push(jnp.int32(7))
     self.assertFalse(stack.empty())
     stack = stack.push(jnp.int32(8))
     self.assertFalse(stack.empty())
     x, stack = stack.pop()
     self.assertFalse(stack.empty())
     self.assertEqual(8, x)
     stack = stack.push(jnp.int32(9))
     x, stack = stack.pop()
     self.assertFalse(stack.empty())
     self.assertEqual(9, x)
     x, stack = stack.pop()
     self.assertTrue(stack.empty())
     self.assertEqual(7, x)
Example #7
  def test_check_jaxpr_eqn_mismatch(self):
    def f(x):
      return jnp.sin(x) + jnp.cos(x)

    def new_jaxpr():
      return make_jaxpr(f)(jnp.float32(1.)).jaxpr

    # jaxpr is:
    #
    # { lambda  ; a.
    #   let b = sin a
    #       c = cos a
    #       d = add b c
    #   in (d,) }
    #
    # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'

    jaxpr = new_jaxpr()
    # int, not float!
    jaxpr.eqns[0].outvars[0].aval = make_shaped_array(jnp.int32(2))
    self.assertRaisesRegex(
        core.JaxprTypeError,
        r"Variable 'b' inconsistently typed as f32\[\], "
        r"bound as i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a",
        lambda: core.check_jaxpr(jaxpr))

    jaxpr = new_jaxpr()
    jaxpr.eqns[0].outvars[0].aval = make_shaped_array(
      np.ones((2, 3), dtype=jnp.float32))
    self.assertRaisesRegex(
        core.JaxprTypeError,
        r"Variable 'b' inconsistently typed as f32\[\], "
        r"bound as f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin a",
        lambda: core.check_jaxpr(jaxpr))
Example #8
def multinomial_mode(
    distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray]
) -> jnp.DeviceArray:
    """Calculates the (one-hot) mode of a multinomial distribution.

  Args:
    distribution_or_probs:
      `tfp.distributions.Distribution` | List[tensors].
      If the former, it is assumed that it has a `probs` property, and
      represents a distribution over categories. If the latter, these are
      taken to be the probabilities of categories directly.
      In either case, it is assumed that `probs` will be shape
      (batch_size, dim).

  Returns:
    `DeviceArray`, float32, (batch_size, dim).
    The mode of the distribution - this will be in one-hot form, but contain
    multiple non-zero entries in the event that more than one probability is
    joint-highest.
  """
    if isinstance(distribution_or_probs, tfd.Distribution):
        probs = distribution_or_probs.probs_parameter()
    else:
        probs = distribution_or_probs
    max_prob = jnp.max(probs, axis=1, keepdims=True)
    mode = jnp.int32(jnp.equal(probs, max_prob))
    return jnp.float32(mode / jnp.sum(mode, axis=1, keepdims=True))
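For example (an illustrative check, passing probabilities directly and assuming the module's `tfd` import is in scope for the isinstance test):

import jax.numpy as jnp

probs = jnp.array([[0.2, 0.5, 0.3],
                   [0.4, 0.4, 0.2]])
print(multinomial_mode(probs))
# [[0.  1.  0. ]
#  [0.5 0.5 0. ]]  <- a tie splits the mass evenly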
Example #9
 def update(data, state):
     y = data[1]
     H, Je = differentiate(state.θ, state.e, y)
     # Inner Levenberg-Marquardt update
     lm_state = LMState(state.θ, state.e, state.G, state.C, state.R, state.μ)
     lm_cond = partial(_lm_cond, state.G)
     lm_update = partial(_lm_update, state.θ, H, Je, y, state.Λ)
     θ, e, G, C, R, μ = while_loop(
         lm_cond,
         lm_update,
         lm_state
     )
     μ = np.where(μ < μmax, μ / μs, μ)
     μ = np.where(μmin < μ, μ, μmin)
     # Bayesian hyperparameter learning
     bl_state = (G, state.Λ, μ, state.τ)
     bl_update = partial(_bl_update, H, C, R)
     bl_restart = partial(_bl_restart, state.G)
     # Restart if the objective got worse, otherwise do the Bayesian update.
     # (The removed five-argument form of cond is rewritten in the current
     # cond(pred, true_fun, false_fun, operand) API.)
     G, Λ, μ, τ = cond(
         G > state.G,
         bl_restart,
         bl_update,
         bl_state
     )
     k = np.where(G >= state.G, state.k + 1, np.int32(1))
     return LMBTrainingState(θ, e, G, C, R, Λ, μ, τ, state.i + 1, k)
Example #10
def hsv_to_rgb(hsv_image):
  # Adapted from the numpy implementation here: https://gist.github.com/PolarNick239/691387158ff1c41ad73c#file-rgb_to_hsv_np-py
  input_shape = hsv_image.shape
  hsv_image = hsv_image.reshape(-1, 3)
  h, s, v = hsv_image[:, 0], hsv_image[:, 1], hsv_image[:, 2]

  i = jnp.int32(h * 6.0)
  f = (h * 6.0) - i
  p = v * (1.0 - s)
  q = v * (1.0 - s * f)
  t = v * (1.0 - s * (1.0 - f))
  i = i % 6

  rgb_image = jnp.zeros_like(hsv_image)
  v, t, p, q = v.reshape(-1, 1), t.reshape(-1, 1), p.reshape(-1, 1), q.reshape(-1, 1)

  i = jnp.tile(i.reshape(-1,1), (1,3))
  rgb_image = jnp.where(i==0, jnp.hstack([v, t, p]), rgb_image)
  rgb_image = jnp.where(i==1, jnp.hstack([q, v, p]), rgb_image)
  rgb_image = jnp.where(i==2, jnp.hstack([p, v, t]), rgb_image)
  rgb_image = jnp.where(i==3, jnp.hstack([p, q, v]), rgb_image)
  rgb_image = jnp.where(i==4, jnp.hstack([t, p, v]), rgb_image)
  rgb_image = jnp.where(i==5, jnp.hstack([v, p, q]), rgb_image)

  s = jnp.tile(s.reshape(-1,1), (1,3))
  rgb_image = jnp.where(s==0, jnp.hstack([v, v, v]), rgb_image)

  return rgb_image.reshape(input_shape)
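A quick illustrative sanity check: pure hues map to the expected primaries, and an `s == 0` row exercises the grayscale branch.

import jax.numpy as jnp

hsv = jnp.array([[0.0, 1.0, 1.0],    # pure red
                 [0.5, 1.0, 1.0],    # cyan
                 [0.0, 0.0, 0.5]])   # gray (s == 0)
print(hsv_to_rgb(hsv))
# [[1.  0.  0. ]
#  [0.  1.  1. ]
#  [0.5 0.5 0.5]]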
Example #11
def interpolate1d(x, values, tangents):
  r"""Perform cubic hermite spline interpolation on a 1D spline.

  The x coordinates of the spline knots are at [0 : len(values)-1].
  Queries outside of the range of the spline are computed using linear
  extrapolation. See https://en.wikipedia.org/wiki/Cubic_Hermite_spline
  for details, where "x" corresponds to `x`, "p" corresponds to `values`, and
  "m" corresponds to `tangents`.

  Args:
    x: A tensor containing the set of values to be used for interpolation into
      the spline.
    values: A vector containing the value of each knot of the spline being
      interpolated into. Must be the same length as `tangents`.
    tangents: A vector containing the tangent (derivative) of each knot of the
      spline being interpolated into. Must be the same length as `values` and
      the same type as `x`.

  Returns:
    The result of interpolating along the spline defined by `values`, and
    `tangents`, using `x` as the query values. Will be the same shape as `x`.
  """
  assert len(values.shape) == 1
  assert len(tangents.shape) == 1
  assert values.shape[0] == tangents.shape[0]

  # Find the indices of the knots below and above each x.
  x_lo = jnp.int32(jnp.floor(jnp.clip(x, 0., values.shape[0] - 2)))
  x_hi = x_lo + 1

  # Compute the relative distance between each `x` and the knot below it.
  t = x - x_lo

  # Compute the cubic hermite expansion of `t`.
  t_sq = t**2
  t_cu = t * t_sq
  h01 = -2 * t_cu + 3 * t_sq
  h00 = 1 - h01
  h11 = t_cu - t_sq
  h10 = h11 - t_sq + t

  # Linearly extrapolate above and below the extents of the spline for all
  # values.
  value_before = tangents[0] * t + values[0]
  value_after = tangents[-1] * (t - 1) + values[-1]

  # Cubically interpolate between the knots below and above each query point.
  neighbor_values_lo = jnp.take(values, x_lo)
  neighbor_values_hi = jnp.take(values, x_hi)
  neighbor_tangents_lo = jnp.take(tangents, x_lo)
  neighbor_tangents_hi = jnp.take(tangents, x_hi)

  value_mid = (
      neighbor_values_lo * h00 + neighbor_values_hi * h01 +
      neighbor_tangents_lo * h10 + neighbor_tangents_hi * h11)

  # Return the interpolated or extrapolated values for each query point,
  # depending on whether or not the query lies within the span of the spline.
  return jnp.where(t < 0., value_before,
                   jnp.where(t > 1., value_after, value_mid))
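For instance (illustrative): a three-knot spline with zero tangents peaks at the middle knot, and queries past the last knot extrapolate linearly.

import jax.numpy as jnp

values = jnp.array([0.0, 1.0, 0.0])
tangents = jnp.zeros(3)
print(interpolate1d(jnp.array([0.5, 1.0, 2.5]), values, tangents))
# [0.5 1.  0. ]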
Example #12
def apply_scatter_op(
    scatter_agg_op,
    n: int,
    values: jnp.ndarray,
    targets: jnp.ndarray,
    active: jnp.ndarray = None,
) -> jnp.ndarray:
    """
    Apply given scatter aggregate operation on `values` with their target indices `targets`

    `scatter_agg_op` is one of `jax.lax.scatter_*`. `n` is the result size;
    target indices outside the range are dropped.
    If `active` is given, only positions where `active[i] == True` are taken
    into account.
    """
    if np.issubdtype(values.dtype, np.bool_) and scatter_agg_op in (
            jax.lax.scatter_add,
            jax.lax.scatter_mul,
    ):
        values = jnp.int32(values)
    neutral_value = _op_neutral(scatter_agg_op, values.dtype)
    # Array of neutral values
    z = jnp.full((n, ) + values.shape[1:], neutral_value, dtype=values.dtype)
    if active is not None:
        targets = jnp.where(active, targets, n + 1)
    targets = jnp.expand_dims(targets, 1)
    dims = jax.lax.ScatterDimensionNumbers(tuple(range(1, len(values.shape))),
                                           (0, ), (0, ))
    return scatter_agg_op(z, targets, values, dims, mode="drop")
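A hedged usage sketch, assuming the module's `_op_neutral` helper (not shown above) returns 0 for `jax.lax.scatter_add`:

import jax
import jax.numpy as jnp

values = jnp.array([1, 2, 3])
targets = jnp.array([0, 0, 2])
print(apply_scatter_op(jax.lax.scatter_add, 3, values, targets))
# [3 0 3]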
Example #13
def epi_demo(edge_beta, gamma, infect, nodes, steps):
    k = 3
    np = NetworkProcess([
        epidemics.SIRUpdateOp(),
        # operations.CountNodeStatesOp(states=3, key="compartment"),
        # operations.CountNodeTransitionsOp(states=3, key="compartment"),
    ])
    params = {"edge_infection_rate": edge_beta, "recovery_rate": gamma}

    log.info(
        f"Network: Barabasi-Albert. n={nodes}, k={k}, cca {nodes*k*2:.2e} directed edges"
    )
    with utils.logged_time("  Creating graph", logger=log):
        g = nx.random_graphs.barabasi_albert_graph(nodes, k)
    with utils.logged_time("  Creating state", logger=log):
        net = Network.from_graph(g)
        state = np.new_state(net, props=params, seed=42)
        rng = jax.random.PRNGKey(43)
        comp = jnp.int32(
            jax.random.bernoulli(rng, infect / nodes, shape=[nodes]))
        state.node["compartment"] = comp
    with utils.logged_time("  Running model", logger=log):
        t0 = time.time()
        state2 = np.run(state, steps=steps)
        state2.block_on_all()
        t1 = time.time()

    log.info(np.trace_log())
    sps = steps / (t1 - t0)
    log.info(f"{steps} steps took {t1-t0:.2g} s,  {sps:.3g} steps/s,  " +
             f"{sps*state.m:.3g} edge_ops/s,  {sps * state.n:.3g} node_ops/s")
Example #14
def visualize_coord_fix(coords, acc, percentile=99.):
    """Visualize the "cell" each coordinate lives in, and highlight its edges."""

    # Round towards zero.
    coords_fix = jnp.int32(jnp.fix(coords))

    # A very hacky plus-shaped edge detector.
    coords_fix_pad = jnp.pad(coords_fix, [(1, 1), (1, 1), (0, 0)], 'edge')
    mask = ((coords_fix == coords_fix_pad[2:, 1:-1, :]) &
            (coords_fix == coords_fix_pad[:-2, 1:-1, :])
            & (coords_fix == coords_fix_pad[1:-1, 2:, :])
            & (coords_fix == coords_fix_pad[1:-1, :-2, :]))

    # Scale according to `acc` and clip to lie in [-1, 1].
    max_val = jnp.maximum(
        1,
        math.weighted_percentile(jnp.max(jnp.abs(coords_fix), axis=2), acc,
                                 percentile))
    coords_fix_unit = jnp.clip(coords_fix / max_val, -1, 1)

    # The [-1, 1] center cube is gray, and every other integer boundary gets
    # colored with xyz \propto rgb - gray. Edge pixels are highlighted.
    return matte(
        jnp.where(mask, (coords_fix_unit + 1) / 2,
                  1 - jnp.abs(coords_fix_unit)), acc)
Example #15
 def _generate_partition(decays, decay_distribution, length):
     # Generates length-sized array split according to decay_distribution.
     decays = jnp.array(decays)
     decay_distribution = jnp.array(decay_distribution)
     multiples = jnp.int32(jnp.floor(decay_distribution * length))
     multiples = multiples.at[-1].set(multiples[-1] + length -
                                      jnp.sum(multiples))
     return jnp.repeat(decays, multiples)
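For example (illustrative): splitting a length-5 schedule between two decay values; floor([0.5, 0.5] * 5) = [2, 2], and the remainder is folded into the last bucket.

print(_generate_partition([0.9, 0.99], [0.5, 0.5], 5))
# [0.9  0.9  0.99 0.99 0.99]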
Example #16
def fill_triangular_inverse(x, upper=False):
    n = x.shape[-1]
    n = np.int32(n)
    m = np.int32((n * (n + 1)) // 2)
    final_shape = list(x.shape[:-2]) + [m]
    if upper:
        initial_elements = x[..., 0, :]
        triangular_portion = x[..., 1:, :]
    else:
        # Reverse along the last axis (axis=-2 would be out of range here
        # after the row has been sliced out).
        initial_elements = np.flip(x[..., -1, :], axis=-1)
        triangular_portion = x[..., :-1, :]
    rotated_triangular_portion = np.flip(
        np.flip(triangular_portion, axis=-1), axis=-2)
    consolidated_matrix = triangular_portion + rotated_triangular_portion
    end_sequence = np.reshape(
        consolidated_matrix,
        list(x.shape[:-2]) + [n * (n - 1)])
    y = np.concatenate([initial_elements, end_sequence[..., :m - n]], axis=-1)
    return y
Example #17
    def test_primitive_compilation_cache(self):
        devices = self.get_devices()

        x = jax.device_put(jnp.int32(1), devices[1])

        with jtu.count_primitive_compiles() as count:
            y = lax.add(x, x)
            z = lax.add(y, y)

        self.assertEqual(count[0], 1)
        self.assert_committed_to_device(y, devices[1])
        self.assert_committed_to_device(z, devices[1])
Example #18
def fill_triangular(x, upper=False):
    m = x.shape[-1]
    if len(x.shape) != 1:
        raise ValueError("Only handles 1D to 2D transformation, because tril/u")
    m = np.int32(m)
    n = np.sqrt(0.25 + 2. * m) - 0.5
    if n != np.floor(n):
        raise ValueError('Input right-most shape ({}) does not '
                         'correspond to a triangular matrix.'.format(m))
    n = np.int32(n)
    final_shape = list(x.shape[:-1]) + [n, n]
    if upper:
        x_list = [x, np.flip(x[..., n:], -1)]

    else:
        x_list = [x[..., n:], np.flip(x, -1)]
    x = np.reshape(np.concatenate(x_list, axis=-1), final_shape)
    if upper:
        x = np.triu(x)
    else:
        x = np.tril(x)
    return x
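An illustrative round trip between `fill_triangular` and `fill_triangular_inverse` above (assuming, as in those snippets, that `np` is `jax.numpy`):

import jax.numpy as np

x = np.arange(1.0, 7.0)      # 6 == 3 * (3 + 1) / 2 entries
tril = fill_triangular(x)    # [[4. 0. 0.]
                             #  [6. 5. 0.]
                             #  [3. 2. 1.]]
print(fill_triangular_inverse(tril))  # [1. 2. 3. 4. 5. 6.]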
Example #19
def _for_impl(*args, jaxpr, nsteps, reverse, which_linear):
    del which_linear
    discharged_jaxpr, consts = discharge_state(jaxpr, ())

    def cond(carry):
        i, _ = carry
        return i < nsteps

    def body(carry):
        i, state = carry
        i_ = nsteps - i - 1 if reverse else i
        next_state = core.eval_jaxpr(discharged_jaxpr, consts, i_, *state)
        return i + 1, next_state

    _, state = lax.while_loop(cond, body, (jnp.int32(0), list(args)))
    return state
Example #20
def bi_tempered_logistic_loss_fwd(activations,
                                  labels,
                                  t1,
                                  t2,
                                  label_smoothing=0.0,
                                  num_iters=5):
  """Forward pass function for bi-tempered logistic loss.

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    labels: An array with the same shape and dtype as activations.
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
    label_smoothing: Label smoothing parameter between [0, 1).
    num_iters: Number of iterations to run the method.

  Returns:
    A loss array, residuals.
  """
  num_classes = jnp.int32(labels.shape[-1])
  labels = cond(
      label_smoothing > 0.0,
      lambda u:  # pylint: disable=g-long-lambda
      (1 - num_classes /
       (num_classes - 1) * label_smoothing) * u + label_smoothing /
      (num_classes - 1),
      lambda u: u,
      labels)
  probabilities = tempered_softmax(activations, t2, num_iters)

  def _tempered_cross_entropy_loss(unused_activations):
    loss_values = jnp.multiply(
        labels,
        log_t(labels + 1e-10, t1) -
        log_t(probabilities, t1)) - 1.0 / (2.0 - t1) * (
            jnp.power(labels, 2.0 - t1) - jnp.power(probabilities, 2.0 - t1))
    loss_values = jnp.sum(loss_values, -1)
    return loss_values

  loss_values = cond(
      jnp.logical_and(
          jnp.less(jnp.abs(t1 - 1.0), 1e-15),
          jnp.less(jnp.abs(t2 - 1.0), 1e-15)),
      functools.partial(_cross_entropy_loss, labels=labels),
      _tempered_cross_entropy_loss,
      activations)
  return loss_values, (labels, t1, t2, probabilities)
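Both this loss and `compute_normalization_binary_search` below rely on the tempered logarithm and exponential from the bi-tempered loss paper (Amid et al., 2019). A minimal sketch of those helpers, assuming t != 1 (the real implementations also special-case t == 1):

import jax.numpy as jnp

def log_t(u, t):
    # Tempered logarithm; reduces to log(u) as t -> 1.
    return (u ** (1.0 - t) - 1.0) / (1.0 - t)

def exp_t(u, t):
    # Tempered exponential; inverse of log_t on its support.
    return jnp.maximum(1.0 + (1.0 - t) * u, 0.0) ** (1.0 / (1.0 - t))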
Example #21
    def rvs(self, nsamps: int = 1) -> np.ndarray:
        assert np.all(self.prefactors >= 0.)

        # Use residual resampling from SMC theory.
        if nsamps is None:
            nsamps = len(pop)
        prop_w = log(self.normalized().prefactors)
        mult = exp(prop_w + log(nsamps))
        count = np.int32(np.floor(mult))
        resid = log(mult - count)
        resid = resid - logsumexp(resid)
        count = count + onp.random.multinomial(nsamps - count.sum(),
                                               exp(resid))

        rval = np.repeat(self.inspace_points, count, 0) + self.k.rvs(
            nsamps, self.inspace_points.shape[1])
        return rval
Example #22
def compute_normalization_binary_search(activations,
                                        t,
                                        num_iters = 10):
  """Returns the normalization value for each example (t < 1.0).

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    t: Temperature 2 (< 1.0 for finite support).
    num_iters: Number of iterations to run the method.
  Returns:
    An array of the same rank as activations with the last dimension being 1.
  """
  mu = jnp.max(activations, -1, keepdims=True)
  normalized_activations = activations - mu
  shape_activations = activations.shape
  effective_dim = jnp.float32(
      jnp.sum(
          jnp.int32(normalized_activations > -1.0 / (1.0 - t)),
          -1,
          keepdims=True))
  shape_partition = list(shape_activations[:-1]) + [1]
  lower = jnp.zeros(shape_partition)
  upper = -log_t(1.0 / effective_dim, t) * jnp.ones(shape_partition)

  def cond_fun(carry):
    _, _, iters = carry
    return iters < num_iters

  def body_fun(carry):
    lower, upper, iters = carry
    logt_partition = (upper + lower) / 2.0
    sum_probs = jnp.sum(
        exp_t(normalized_activations - logt_partition, t), -1, keepdims=True)
    update = jnp.float32(sum_probs < 1.0)
    lower = jnp.reshape(lower * update + (1.0 - update) * logt_partition,
                        shape_partition)
    upper = jnp.reshape(upper * (1.0 - update) + update * logt_partition,
                        shape_partition)
    return lower, upper, iters + 1

  lower, upper, _ = while_loop(cond_fun, body_fun, (lower, upper, 0))

  logt_partition = (upper + lower) / 2.0
  return logt_partition + mu
Example #23
    def sample(self, key: jnp.ndarray) -> jnp.ndarray:
        """Sample from the distribution.

        Args:
            key: JAX random key.
        """
        shape = self.p.shape
        keys = jax.random.split(key, self.n.size)
        n_sample = self.n.reshape((self.n.size))
        p_sample = self.p.reshape((-1, self.p.shape[-1]))
        samp = []
        for n, p, k in zip(n_sample, p_sample, keys):
            samples = jnp.where(
                jnp.isnan(n), jnp.nan,
                jax.random.categorical(k, jnp.log(p), shape=(jnp.int32(n), )))
            samp.append(jnp.sum(jax.nn.one_hot(samples, p.shape[-1]), 0))

        is_nan = jnp.isnan(self.p)
        return jnp.where(is_nan, jnp.full(shape, jnp.nan),
                         jnp.stack(samp).reshape(shape))
Example #24
    def sample(self, key: jnp.ndarray) -> jnp.ndarray:
        """Sample from the distribution.

        Args:
            key: JAX random key.
        """
        shape = self.n.shape
        keys = jax.random.split(key, self.n.size)
        n_sample = self.n.reshape((self.n.size))
        p_sample = self.p.reshape((self.p.size))
        samp = []
        for n, p, k in zip(n_sample, p_sample, keys):
            samples = jnp.where(
                jnp.isnan(n),
                jnp.nan,
                jax.random.bernoulli(k, p, shape=(jnp.int32(n), )))
            samp.append(jnp.sum(samples))

        is_nan = jnp.isnan(self.p)
        return jnp.where(is_nan, jnp.full(shape, jnp.nan),
                         jnp.stack(samp).reshape(shape))
Example #25
def compute_shift(lattice, cutoff):
    """ calculate neccessary repetitions in each direction with reciprocal lattice
    Args:
        lattice: takes lattice matrix as 2D-Array , e.g.: jnp.diag(jnp.ones(3))
        cutoff: cutoff distance as float
    Returns:
        shifts matrix as 2D matrix of shift vectors
    """
    n_repeat = jnp.int32(jnp.ceil(jnp.linalg.norm(jnp.linalg.inv(lattice), axis=0) * cutoff))
    # n_repeat = np.int32(np.asarray([1, 1, 1]))
    print("Repeat", n_repeat)
    relative_shifts = jnp.array([[el, el2, el3] for el in range(-n_repeat[0], n_repeat[0] + 1, 1)
                                 for el2 in range(-n_repeat[1], n_repeat[1] + 1, 1)
                                 for el3 in range(-n_repeat[2], n_repeat[2] + 1, 1)])
    # Shrink each nonzero component one step toward zero; the norm of the
    # shrunk shift lower-bounds the distance to the shifted image cell.
    relative_shifts2 = jnp.where(
        jnp.where(relative_shifts > 0, relative_shifts - 1, relative_shifts) < 0,
        relative_shifts + 1,
        jnp.where(relative_shifts > 0, relative_shifts - 1, relative_shifts))
    shifts = jnp.matmul(jnp.expand_dims(lattice.T, axis=0).repeat(relative_shifts2.shape[0], axis=0),
                        jnp.expand_dims(relative_shifts2, -1)).squeeze()
    relative_shifts = relative_shifts[jnp.where(jnp.linalg.norm(shifts, axis=1) < cutoff)]
    shifts = jnp.matmul(jnp.expand_dims(lattice.T, axis=0).repeat(relative_shifts.shape[0], axis=0),
                            jnp.expand_dims(relative_shifts, -1)).squeeze()
    return shifts
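For example (illustrative): with a unit cubic lattice and cutoff 1.0, every one of the 27 nearest image cells touches the cutoff sphere, so all of their shift vectors are kept.

import jax.numpy as jnp

shifts = compute_shift(jnp.diag(jnp.ones(3)), 1.0)
print(shifts.shape)  # (27, 3)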
Example #26
 def test_attributes_create_weights_op_fp(
     self,
     weight_range,
     weight_shape,
     fp_quant,
 ):
     weights = jnp.array(
         fp32(onp.random.uniform(*weight_range, size=weight_shape)))
     axis = None if weight_shape[1] == 1 else 0
     weights_quant_op = QuantOps.create_weights_ops(
         w=weights,
         weight_params=QuantOps.WeightParams(prec=fp_quant,
                                             axis=axis,
                                             half_shift=False))
     max_weight = onp.max(abs(weights), axis=0)
     onp.testing.assert_array_equal(
         jnp.squeeze(weights_quant_op._scale),
         jnp.exp2(-jnp.floor(jnp.log2(max_weight))))
     self.assertEqual(weights_quant_op._symmetric, True)
     self.assertIs(weights_quant_op._prec, fp_quant)
     weights_scaled = (weights * weights_quant_op._scale).astype(
         weights.dtype)
     weights_quant_expected = fp_cast.downcast_sat_ftz(
         weights_scaled,
         fp_quant.fp_spec.exp_min,
         fp_quant.fp_spec.exp_max,
         fp_quant.fp_spec.sig_bits,
     )
     weights_quant_calculated = weights_quant_op.to_quantized(
         weights, dtype=SCALE_DTYPE)
     onp.testing.assert_array_equal(weights_quant_expected,
                                    weights_quant_calculated)
     # Test the lower (23 - fp_quant.fp_spec.sig_bits) bits of the calculated
     # quantized weights are zero.
     sig_mask = jnp.int32((1 << (23 - fp_quant.fp_spec.sig_bits)) - 1)
     onp.testing.assert_array_equal(
         weights_quant_calculated.view(jnp.int32) & sig_mask,
         jnp.zeros_like(weights))
Example #27
def prepare_filter_functions(metadata, background_avg):

    # Convolution kernel
    kernel_width = np.max(
        np.array([
            np.int32(
                np.floor(metadata["padded_frame_width"] /
                         metadata["output_frame_width"])), 1
        ]))
    kernel_box = np.ones((kernel_width, kernel_width))

    cleanXraw_vmap = jax.vmap(lambda x: cleanXraw(x - background_avg))
    cleanXraw_vmap_d1 = jax.vmap(lambda x: cleanXraw(x - background_avg[0]))
    cleanXraw_vmap_d2 = jax.vmap(lambda x: cleanXraw(x - background_avg[1]))

    combine_double_exposure_vmapf = jax.vmap(
        lambda x, y: combine_double_exposure(x, y, metadata[
            "double_exp_time_ratio"]))

    # Single- and double-exposure functions
    f_cleanframes = jax.jit(lambda x: cleanXraw_vmap(x))
    f_cleanframes_d = jax.jit(lambda x, y: combine_double_exposure_vmapf(
        cleanXraw_vmap_d1(x), cleanXraw_vmap_d2(y)))

    def f(clean_frame):
        filtered_frame = filter_frame(clean_frame, kernel_box)
        centered_rescaled_frame = shift_rescale(
            filtered_frame, metadata["center_of_mass"],
            metadata["output_frame_width"], metadata["output_padded_ratio"])
        return centered_rescaled_frame

    process_batch_vmapf = jax.vmap(f)

    f_all = jax.jit(lambda x: process_batch_vmapf(f_cleanframes(x)))
    f_all_d = jax.jit(lambda x, y: process_batch_vmapf(f_cleanframes_d(x, y)))

    return f_all, f_all_d
Example #28
 def testScanTypeErrors(self):
   """Test typing error messages for scan."""
   a = np.arange(5)
   # Body output not a tuple
   with self.assertRaisesRegex(TypeError,
       re.escape("scan body output must be a pair, got ShapedArray(float32[]).")):
     lax.scan(lambda c, x: np.float32(0.), 0, a)
    with self.assertRaisesRegex(TypeError,
       re.escape("scan carry output and input must have same type structure, "
                 "got PyTreeDef(tuple, [*,*,*]) and PyTreeDef(tuple, [*,PyTreeDef(tuple, [*,*])])")):
     lax.scan(lambda c, x: ((0, 0, 0), x), (1, (2, 3)), a)
   with self.assertRaisesRegex(TypeError,
       re.escape("scan carry output and input must have same type structure, got * and PyTreeDef(None, []).")):
     lax.scan(lambda c, x: (0, x), None, a)
   with self.assertRaisesWithLiteralMatch(
       TypeError,
       "scan carry output and input must have identical types, got\n"
       "ShapedArray(int32[])\n"
       "and\n"
       "ShapedArray(float32[])."):
     lax.scan(lambda c, x: (np.int32(0), x), np.float32(1.0), a)
   with self.assertRaisesRegex(TypeError,
       re.escape("scan carry output and input must have same type structure, got * and PyTreeDef(tuple, [*,*]).")):
     lax.scan(lambda c, x: (0, x), (1, 2), np.arange(5))
Example #29
 def testScalarCastInsideJitWorks(self):
     # jnp.int32(tracer) should work.
     self.assertEqual(jnp.int32(101),
                      jax.jit(lambda x: jnp.int32(x))(jnp.float32(101.4)))
Example #30
 def testForiLoopErrors(self):
   """Test typing error messages for while."""
   with self.assertRaisesRegex(
     TypeError, "arguments to fori_loop must have equal types"):
     lax.fori_loop(onp.int16(0), np.int32(10), (lambda i, c: c), np.float32(7))
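The corresponding fix is to give `lower` and `upper` matching dtypes, e.g.:

import jax.numpy as jnp
from jax import lax

# lower and upper now share a dtype, so fori_loop type-checks.
lax.fori_loop(jnp.int32(0), jnp.int32(10), lambda i, c: c, jnp.float32(7))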