Example No. 1
    def can_condition(self, val: jnp.ndarray):
        in_range = jnp.logical_and(
            jnp.all(jnp.greater_equal(val, jnp.array(0.))),
            jnp.all(jnp.greater_equal(self.n, val)))
        correct_size = val.shape == self.p.shape

        return jnp.logical_and(
            jnp.logical_and(utils.is_integer(val), correct_size), in_range)
Example No. 2
def stimulate(t, X, stimuli):
    stimulated = np.zeros_like(X)
    for stimulus in stimuli:
        active = np.greater_equal(t, stimulus["start"])
        active &= (np.mod(stimulus["start"] - t + 1, stimulus["period"]) < stimulus["duration"])
        stimulated = np.where(stimulus["field"] * (active), stimulus["field"], stimulated)
    return np.where(stimulated != 0, stimulated, X)
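A small usage sketch for the stimulate helper above, with made-up values (here `np` can be NumPy or `jax.numpy`, matching whichever the source module aliases):

import numpy as np

X = np.zeros((4, 4))
stimuli = [{
    "start": 10.0,     # first activation time
    "period": 100.0,   # repeat every 100 time units
    "duration": 2.0,   # stay active for 2 time units
    "field": np.full((4, 4), 0.5),
}]

# During a pulse the stimulus field replaces X; between pulses X passes through.
print(stimulate(10.0, X, stimuli))  # all 0.5
print(stimulate(50.0, X, stimuli))  # all 0.0 (X unchanged)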
Example No. 3
 def topk_mask_internal(value):
     assert value.ndim == 1
     indices = jnp.argsort(value)
     k = jnp.round(density_fraction * jnp.size(value)).astype(jnp.int32)
     mask = jnp.greater_equal(jnp.arange(value.size), value.size - k)
     mask = jnp.zeros_like(mask).at[indices].set(mask)
     return mask.astype(jnp.int32)
Example No. 4
        def np_fn(input_np, v_current, gamma, tau_m, Vth, dt):
            spike = jnp.greater_equal(input_np + v_current,
                                      Vth).astype('float32')
            v_current = input_np + v_current - spike

            return spike, jnp.multiply(jnp.exp(
                -1 / tau_m), v_current), gamma + spike.astype('int32')
Example No. 5
 def _sample(self, rng, loc):
     rng, key = jax.random.split(rng)
     _, loc, counter = jax.lax.while_loop(self.w_cond, self.__sample,
                                          (rng, self.mvn(key, loc), 0))
     return jax.lax.cond(np.greater_equal(counter, self.max_counter),
                         lambda _: np.nan * np.ones(
                             (self.n_params, )), lambda _: loc, None)
Example No. 6
 def multiplier(self, step: float):
     """Returns step decay learning rate multiplier."""
     if isinstance(self.step_size, (tuple, list)):
         exponent = jn.sum(jn.greater_equal(step, jn.array(self.step_size)))
     else:
         exponent = step // self.step_size
     return self.gamma**exponent
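A standalone sketch of the same step-decay logic, with hypothetical gamma and step_size values, showing that the list and scalar branches agree for evenly spaced boundaries:

import jax.numpy as jn

def step_decay_multiplier(step, step_size, gamma=0.1):
    if isinstance(step_size, (tuple, list)):
        # Count how many decay boundaries the current step has passed.
        exponent = jn.sum(jn.greater_equal(step, jn.array(step_size)))
    else:
        exponent = step // step_size
    return gamma ** exponent

print(step_decay_multiplier(25.0, [10, 20, 30]))  # 0.1 ** 2 = 0.01
print(step_decay_multiplier(25.0, 10))            # 0.1 ** 2 = 0.01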
Example No. 7
def _ibp_integer_pow(x: PrimitiveInput, y: int) -> IntervalBound:
    """Propagation of IBP bounds through integer_pow.

  Args:
    x: Argument to be raised to a power, element-wise.
    y: Fixed integer exponent.

  Returns:
    out_bounds: integer_pow output or its bounds.
  """
    if y < 0:
        raise NotImplementedError
    l_pow = lax.integer_pow(x.lower, y)
    u_pow = lax.integer_pow(x.upper, y)

    if y % 2 == 0:
        # Even powers
        contains_zero = jnp.logical_and(jnp.less_equal(x.lower, 0),
                                        jnp.greater_equal(x.upper, 0))
        lower = jnp.where(contains_zero, jnp.zeros_like(x.lower),
                          jnp.minimum(l_pow, u_pow))
        upper = jnp.maximum(l_pow, u_pow)
        return IntervalBound(lower, upper)
    else:
        # Odd powers
        return IntervalBound(l_pow, u_pow)
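A short usage sketch, assuming the function above is pasted into the same file along with minimal stand-ins for the library's types (the real IntervalBound and PrimitiveInput come from the surrounding verification library):

from collections import namedtuple
import jax.numpy as jnp
from jax import lax  # the function above uses lax.integer_pow

# Stand-ins so the annotations and constructor above resolve when pasted here.
IntervalBound = namedtuple("IntervalBound", ["lower", "upper"])
PrimitiveInput = IntervalBound

x = IntervalBound(lower=jnp.array([-2.0, 1.0]), upper=jnp.array([3.0, 4.0]))

even = _ibp_integer_pow(x, 2)  # lower=[0., 1.], upper=[9., 16.]
odd = _ibp_integer_pow(x, 3)   # lower=[-8., 1.], upper=[27., 64.]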
Example No. 8
def step(state, t, params, D, stimuli, dt, dx):
    v, w, u = state

    # apply stimulus
    u = stimulate(t, u, stimuli)

    # apply boundary conditions
    v = neumann(v)
    w = neumann(w)
    u = neumann(u)

    # gate variables
    p = np.greater_equal(u, params["V_c"])
    q = np.greater_equal(u, params["V_v"])
    tau_v_minus = (1 - q) * params["tau_v1_minus"] + q * params["tau_v2_minus"]

    d_v = ((1 - p) * (1 - v) / tau_v_minus) - ((p * v) / params["tau_v_plus"])
    d_w = ((1 - p) * (1 - w) / params["tau_w_minus"]) - ((p * w) / params["tau_w_plus"])

    # currents
    J_fi = - v * p * (u - params["V_c"]) * (1 - u) / params["tau_d"]
    J_so = (u * (1 - p) / params["tau_0"]) + (p / params["tau_r"])
    J_si = - (w * (1 + np.tanh(params["k"] * (u - params["V_csi"])))) / (2 * params["tau_si"])

    I_ion = -(J_fi + J_so + J_si) / params["Cm"]

    # voltage01
#     u_x, u_y = np.gradient(u)
    u_x, u_y = gradient(u, 0), gradient(u, 1)
    u_x /= dx
    u_y /= dx
#     u_xx = np.gradient(u_x, axis=0)
#     u_yy = np.gradient(u_y, axis=1)
    u_xx = gradient(u_x, 0)
    u_yy = gradient(u_y, 1)
    u_xx /= dx
    u_yy /= dx
#     D_x, D_y = np.gradient(D)
#     D_x /= dx
#     D_y /= dx
    d_u = D * (u_xx + u_yy) + I_ion
    
    # euler update
    v += d_v * dt
    w += d_w * dt
    u += d_u * dt
    return np.asarray((v, w, u))
Example No. 9
def get_top_k_weights(
    top_k_fraction: float,
    restarting_weights: Array,
    scaled_advantages: Array,
    axis_name: Optional[str] = None,
    use_stop_gradient: bool = True,
):
  """Get the weights for the top top_k_fraction of advantages.

  Args:
    top_k_fraction: The fraction of weights to use.
    restarting_weights: Restarting weights, shape E*, 0 means that this step is
      the start of a new episode and we ignore losses at this step because the
      agent cannot influence these.
    scaled_advantages: The advantages for each example (shape E*), scaled by
      temperature.
    axis_name: Optional axis name for `pmap`. If `None`, computations are
      performed locally on each device.
    use_stop_gradient: bool indicating whether or not to apply stop gradient.

  Returns:
    Weights for the top top_k_fraction of advantages
  """
  chex.assert_equal_shape([scaled_advantages, restarting_weights])
  chex.assert_type([scaled_advantages, restarting_weights], float)

  if not 0.0 < top_k_fraction <= 1.0:
    raise ValueError(
        f"`top_k_fraction` must be in (0, 1], got {top_k_fraction}")
  logging.info("[vmpo_e_step] top_k_fraction: %f", top_k_fraction)

  if top_k_fraction < 1.0:
    # Don't include the restarting samples in the determination of top-k.
    valid_scaled_advantages = scaled_advantages - (
        1.0 - restarting_weights) * _INFINITY
    # Determine the minimum top-k value across all devices.
    if axis_name:
      all_valid_scaled_advantages = jax.lax.all_gather(
          valid_scaled_advantages, axis_name=axis_name)
    else:
      all_valid_scaled_advantages = valid_scaled_advantages
    top_k = int(top_k_fraction * jnp.size(all_valid_scaled_advantages))
    if top_k == 0:
      raise ValueError(
          "top_k_fraction too low to get any valid scaled advantages.")
    # TODO(b/160450251): Use jnp.partition(all_valid_scaled_advantages, top_k)
    #   when this is implemented in jax.
    top_k_min = jnp.sort(jnp.reshape(all_valid_scaled_advantages, [-1]))[-top_k]
    # Fold the top-k into the restarting weights.
    top_k_weights = jnp.greater_equal(valid_scaled_advantages,
                                      top_k_min).astype(jnp.float32)
    top_k_weights = jax.lax.select(
        use_stop_gradient, jax.lax.stop_gradient(top_k_weights), top_k_weights)
    top_k_restarting_weights = restarting_weights * top_k_weights
  else:
    top_k_restarting_weights = restarting_weights

  return top_k_restarting_weights
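A single-device usage sketch (no pmap, so axis_name stays None) with hand-picked values; it assumes the function and its module-level _INFINITY constant are in scope:

import jax.numpy as jnp

# Eight advantages; index 4 marks an episode restart and must be ignored.
scaled_advantages = jnp.array([0.1, 2.0, -1.0, 0.5, 3.0, 0.2, -0.3, 1.5])
restarting_weights = jnp.array([1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0])

# Keep the top 25% of valid advantages (top-2 out of 8 here).
weights = get_top_k_weights(0.25, restarting_weights, scaled_advantages)
print(weights)  # 1.0 only at indices 1 and 7 (the two largest valid values)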
Example No. 10
def step(state, t, params, diffusivity, stimuli, dt, dx):
    # neumann boundary conditions
    v = jnp.pad(state.v, 1, mode="edge")
    w = jnp.pad(state.w, 1, mode="edge")
    u = jnp.pad(state.u, 1, mode="edge")
    diffusivity = jnp.pad(diffusivity, 1, mode="edge")

    # reaction term
    p = jnp.greater_equal(u, params.V_c)
    q = jnp.greater_equal(u, params.V_v)
    tau_v_minus = (1 - q) * params.tau_v1_minus + q * params.tau_v2_minus

    j_fi = -v * p * (u - params.V_c) * (1 - u) / params.tau_d
    j_so = (u * (1 - p) / params.tau_0) + (p / params.tau_r)
    j_si = -(w * (1 + jnp.tanh(params.k *
                               (u - params.V_csi)))) / (2 * params.tau_si)
    j_ion = -(j_fi + j_so + j_si) / params.Cm

    # apply stimulus by introducing fictitious current
    stimuli = [
        s._replace(field=jnp.pad(s.field, 1, mode="edge")) for s in stimuli
    ]
    j_ion = stimulate(t, j_ion, stimuli)

    # diffusion term
    u_x = gradient(u, 0) / dx
    u_y = gradient(u, 1) / dx
    u_xx = gradient(u_x, 0) / dx
    u_yy = gradient(u_y, 1) / dx
    D_x = gradient(diffusivity, 0) / dx
    D_y = gradient(diffusivity, 1) / dx
    del_u = diffusivity * (u_xx + u_yy) + (D_x * u_x) + (D_y * u_y)

    d_v = ((1 - p) * (1 - v) / tau_v_minus) - ((p * v) / params.tau_v_plus)
    d_w = ((1 - p) *
           (1 - w) / params.tau_w_minus) - ((p * w) / params.tau_w_plus)
    d_u = del_u + j_ion

    # euler update and unpadding
    v = state.v + d_v[1:-1, 1:-1] * dt
    w = state.w + d_w[1:-1, 1:-1] * dt
    u = state.u + d_u[1:-1, 1:-1] * dt
    del_u = del_u[1:-1, 1:-1]
    j_ion = j_ion[1:-1, 1:-1]
    return State(v, w, u, del_u, j_ion)
Example No. 11
        def np_fn(input_np, v_current, gamma, tau_m, Vth, dt):
            v_current = ((input_np - v_current) / tau_m) * dt
            spike = np.greater_equal(
                v_current + np.multiply(
                    np.divide(np.subtract(input_np, v_current), tau_m), dt),
                Vth).astype('float32')

            gamma += np.where(spike >= Vth, 1, 0)
            return spike, v_current, gamma
Example No. 12
    def forward(self, x, v_current):
        dV_tau = jnp.multiply(jnp.subtract(x, v_current), self.dt)
        dV = jnp.divide(dV_tau, self.tau_m)
        v_current = index_add(v_current, index[:], dV)
        spike_list = jnp.greater_equal(v_current, self.Vth).astype('int32')
        v_current = jnp.where(v_current >= self.Vth, 0,
                              v_current * jnp.exp(-1 / self.tau_m))

        return spike_list, v_current
Example No. 13
def current_stimulate(t, X, stimuli):
    stimulated = np.zeros_like(X)
    for stimulus in stimuli:
        # active = np.greater_equal(t, stimulus["start"])
        # active &= (np.mod(stimulus["start"] - t + 1, stimulus["period"]) < stimulus["duration"])
        active = np.greater_equal(t, stimulus["start"])
        # active &= np.greater_equal(stimulus["start"] + stimulus["duration"],t)
        active &= (np.mod(t - stimulus["start"], stimulus["period"]) < stimulus["duration"]) # this works for cyclics
        stimulated = np.where(stimulus["field"] * (active), stimulus["field"], stimulated)
    return np.where(stimulated != 0, stimulated, X)
Example No. 14
def _create_over_capacity_ratio_summary(mask, position_in_expert, capacity,
                                        name):
    _ = name  # TODO(lepikhin): consider inlined summary
    masked_position_in_expert = mask * position_in_expert
    ge_capacity = jnp.greater_equal(masked_position_in_expert, capacity)
    over_capacity = jnp.sum(ge_capacity).astype(jnp.float32)
    denom = jnp.sum(mask).astype(jnp.float32)
    over_capacity_ratio = over_capacity / jnp.maximum(
        jnp.array(1.0, dtype=jnp.float32), denom)
    return over_capacity_ratio
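A hand-worked call, assuming the function above is in scope (all values are made up for illustration):

import jax.numpy as jnp

# Four tokens routed to one expert with capacity 2; positions run 0..3.
mask = jnp.array([1.0, 1.0, 1.0, 1.0])
position_in_expert = jnp.array([0.0, 1.0, 2.0, 3.0])

ratio = _create_over_capacity_ratio_summary(mask, position_in_expert,
                                            capacity=2.0, name="demo")
print(ratio)  # positions 2 and 3 are >= capacity -> 2 / 4 = 0.5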
Example No. 15
def stimulate(t, X, stimuli):
    stimulated = jnp.zeros_like(X)
    for stimulus in stimuli:
        # check if stimulus is in the past
        active = jnp.greater_equal(t, stimulus.protocol.start)
        # check if stimulus is active at the current time
        active &= (jnp.mod(stimulus.protocol.start - t + 1,
                           stimulus.protocol.period) <
                   stimulus.protocol.duration)
        # build the stimulus field
        stimulated = jnp.where(stimulus.field * (active), stimulus.field,
                               stimulated)
    # set the field to the stimulus
    return jnp.where(stimulated != 0, stimulated, X)
Example No. 16
def top_k_accuracy(top_k: int,
                   logits: JTensor,
                   label_ids: Optional[JTensor] = None,
                   label_probs: Optional[JTensor] = None,
                   weights: Optional[JTensor] = None) -> JTensor:
    """Computes the top-k accuracy given the logits and labels.

  Args:
    top_k: An int scalar, specifying the value of top-k.
    logits: A [..., C] float tensor corresponding to the logits.
    label_ids: A [...] int vector corresponding to the class labels. One of
      label_ids and label_probs should be provided.
    label_probs: A [..., C] float vector corresponding to the class
      probabilities. Must be provided if label_ids is None.
    weights: A [...] float vector corresponding to the weight to assign to each
      example.

  Returns:
    The top-k accuracy represented as a `JTensor`.

  Raises:
    ValueError if neither `label_ids` nor `label_probs` are provided.
  """
    if label_ids is None and label_probs is None:
        raise ValueError("One of label_ids and label_probs should be given.")
    if label_ids is None:
        label_ids = jnp.argmax(label_probs, axis=-1)

    values, _ = jax.lax.top_k(logits, k=top_k)
    threshold = jnp.min(values, axis=-1)

    # Reshape logits to [-1, C].
    logits_reshaped = jnp.reshape(logits, [-1, logits.shape[-1]])

    # Reshape label_ids to [-1, 1].
    label_ids_reshaped = jnp.reshape(label_ids, [-1, 1])
    logits_slice = jnp.take_along_axis(logits_reshaped,
                                       label_ids_reshaped,
                                       axis=-1)[..., 0]

    # Reshape logits_slice back to original shape to be compatible with weights.
    logits_slice = jnp.reshape(logits_slice, label_ids.shape)
    correct = jnp.greater_equal(logits_slice, threshold)
    correct_sum = jnp.sum(correct * weights)
    all_sum = jnp.maximum(1.0, jnp.sum(weights))
    return correct_sum / all_sum
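A small usage sketch with made-up logits; it assumes the function above is in scope (JTensor is just an alias for a JAX array in the source library):

import jax.numpy as jnp

logits = jnp.array([[2.0, 1.0, 0.1],
                    [0.2, 0.3, 2.5]])
label_ids = jnp.array([1, 0])
weights = jnp.ones([2])

print(top_k_accuracy(1, logits, label_ids=label_ids, weights=weights))  # 0.0
print(top_k_accuracy(2, logits, label_ids=label_ids, weights=weights))  # 0.5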
Example No. 17
def scatter_in_bounds(operand, indices, updates, dnums):
    # Ref: see clamping code used in scatter_translation_rule
    slice_sizes = []
    pos = 0
    for i in range(len(operand.shape)):
        if i in dnums.inserted_window_dims:
            slice_sizes.append(1)
        else:
            slice_sizes.append(updates.shape[dnums.update_window_dims[pos]])
            pos += 1

    upper_bound = np.array([
        operand.shape[i] - slice_sizes[i]
        for i in dnums.scatter_dims_to_operand_dims
    ], np.int64)
    upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max)
    upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
                                       (len(indices.shape) - 1, ))

    lower_in_bounds = jnp.all(jnp.greater_equal(indices, 0))
    upper_in_bounds = jnp.all(jnp.less_equal(indices, upper_bound))
    return jnp.logical_and(lower_in_bounds, upper_in_bounds)
Example No. 18
def general_loss_with_squared_residual(squared_x, alpha, scale):
    r"""The general loss that takes a squared residual.

  This fuses the sqrt operation done to compute many residuals while preserving
  the square in the loss formulation.

  This implements the rho(x, \alpha, c) function described in "A General and
  Adaptive Robust Loss Function", Jonathan T. Barron,
  https://arxiv.org/abs/1701.03077.

  Args:
    squared_x: The residual for which the loss is being computed. x can have
      any shape, and alpha and scale will be broadcasted to match x's shape if
      necessary.
    alpha: The shape parameter of the loss (\alpha in the paper), where more
      negative values produce a loss with more robust behavior (outliers "cost"
      less), and more positive values produce a loss with less robust behavior
      (outliers are penalized more heavily). Alpha can be any value in
      [-infinity, infinity], but the gradient of the loss with respect to alpha
      is 0 at -infinity, infinity, 0, and 2. Varying alpha allows for smooth
      interpolation between several discrete robust losses:
        alpha=-Infinity: Welsch/Leclerc Loss.
        alpha=-2: Geman-McClure loss.
        alpha=0: Cauchy/Lorentzian loss.
        alpha=1: Charbonnier/pseudo-Huber loss.
        alpha=2: L2 loss.
    scale: The scale parameter of the loss. When |x| < scale, the loss is an
      L2-like quadratic bowl, and when |x| > scale the loss function takes on a
      different shape according to alpha.

  Returns:
    The losses for each element of x, in the same shape as x.
  """
    eps = jnp.finfo(jnp.float32).eps

    # This will be used repeatedly.
    squared_scaled_x = squared_x / (scale**2)

    # The loss when alpha == 2.
    loss_two = 0.5 * squared_scaled_x
    # The loss when alpha == 0.
    loss_zero = log1p_safe(0.5 * squared_scaled_x)
    # The loss when alpha == -infinity.
    loss_neginf = -jnp.expm1(-0.5 * squared_scaled_x)
    # The loss when alpha == +infinity.
    loss_posinf = expm1_safe(0.5 * squared_scaled_x)

    # The loss when not in one of the above special cases.
    # Clamp |2-alpha| to be >= machine epsilon so that it's safe to divide by.
    beta_safe = jnp.maximum(eps, jnp.abs(alpha - 2.))
    # Clamp |alpha| to be >= machine epsilon so that it's safe to divide by.
    alpha_safe = jnp.where(jnp.greater_equal(alpha, 0.), jnp.ones_like(alpha),
                           -jnp.ones_like(alpha)) * jnp.maximum(
                               eps, jnp.abs(alpha))
    loss_otherwise = (beta_safe / alpha_safe) * (
        jnp.power(squared_scaled_x / beta_safe + 1., 0.5 * alpha) - 1.)

    # Select which of the cases of the loss to return.
    loss = jnp.where(
        alpha == -jnp.inf, loss_neginf,
        jnp.where(
            alpha == 0, loss_zero,
            jnp.where(alpha == 2, loss_two,
                      jnp.where(alpha == jnp.inf, loss_posinf,
                                loss_otherwise))))

    return scale * loss
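A quick numerical check, assuming the function and its log1p_safe / expm1_safe helpers from the same module are in scope: with scale = 1, alpha = 2 reproduces the 0.5 * x^2 quadratic, and alpha = 0 reproduces the Cauchy/Lorentzian log1p form.

import jax.numpy as jnp

squared_x = jnp.array([0.0, 1.0, 4.0])

print(general_loss_with_squared_residual(squared_x, alpha=2.0, scale=1.0))
# [0.   0.5  2.  ]
print(general_loss_with_squared_residual(squared_x, alpha=0.0, scale=1.0))
# [0.     0.4055 1.0986]  (log1p(0.5 * x^2))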
Example No. 19
 def _greater_equal(a, b):
     return jnp.greater_equal(a, b)
Example No. 20
erfc = utils.copy_docstring(tf.math.erfc,
                            lambda x, name=None: scipy_special.erfc(x))

exp = utils.copy_docstring(tf.math.exp, lambda x, name=None: np.exp(x))

expm1 = utils.copy_docstring(tf.math.expm1, lambda x, name=None: np.expm1(x))

floor = utils.copy_docstring(tf.math.floor, lambda x, name=None: np.floor(x))

floordiv = utils.copy_docstring(tf.math.floordiv,
                                lambda x, y, name=None: np.floor_divide(x, y))

greater = utils.copy_docstring(tf.math.greater,
                               lambda x, y, name=None: np.greater(x, y))

greater_equal = utils.copy_docstring(
    tf.math.greater_equal, lambda x, y, name=None: np.greater_equal(x, y))

igamma = utils.copy_docstring(
    tf.math.igamma, lambda a, x, name=None: scipy_special.gammainc(a, x))

igammac = utils.copy_docstring(
    tf.math.igammac, lambda a, x, name=None: scipy_special.gammaincc(a, x))

imag = utils.copy_docstring(tf.math.imag,
                            lambda input, name=None: np.imag(input))

# in_top_k = utils.copy_docstring(
#     tf.math.in_top_k,
#     lambda targets, predictions, k, name=None: np.in_top_k)

# TODO(b/256095991): Add unit-test.
Example No. 21
def step(state, t, params, D, stimuli, dt, dx):
    # v, w, u, at, max_du = state
    v, w, u = state

    # apply stimulus
    u = np.where(params["current_stimulus"], u, stimulate(t, u, stimuli))

    # apply boundary conditions
    v = neumann(v)
    w = neumann(w)
    u = neumann(u)

    # gate variables
    p = np.greater_equal(u, params["V_c"])
    q = np.greater_equal(u, params["V_v"])
    tau_v_minus = (1 - q) * params["tau_v1_minus"] + q * params["tau_v2_minus"]

    d_v = ((1 - p) * (1 - v) / tau_v_minus) - ((p * v) / params["tau_v_plus"])
    d_w = ((1 - p) * (1 - w) / params["tau_w_minus"]) - ((p * w) / params["tau_w_plus"])

    # currents
    J_fi = - v * p * (u - params["V_c"]) * (1 - u) / params["tau_d"]
    J_so = (u * (1 - p) / params["tau_0"]) + (p / params["tau_r"])
    J_si = - (w * (1 + np.tanh(params["k"] * (u - params["V_csi"])))) / (2 * params["tau_si"])

    I_ion = -(J_fi + J_so + J_si) / params["Cm"]

    # voltage01
#     u_x, u_y = np.gradient(u)
    u_x, u_y = gradient(u, 0), gradient(u, 1)
    u_x /= dx
    u_y /= dx
#     u_xx = np.gradient(u_x, axis=0)
#     u_yy = np.gradient(u_y, axis=1)
    u_xx = gradient(u_x, 0)
    u_yy = gradient(u_y, 1)
    u_xx /= dx
    u_yy /= dx
#     D_x, D_y = np.gradient(D)
#     D_x /= dx
#     D_y /= dx

# Kostas ---------
    D_x, D_y = gradient(D, 0), gradient(D, 1)
    D_x /= dx
    D_y /= dx
    extra_term = D_x * u_x + D_y * u_y


    current_stimuli = np.zeros(u.shape)
    current_stimuli = np.where(params["current_stimulus"], stimulate(t, current_stimuli, stimuli), current_stimuli)

# Kostas ---------


    d_u = D * (u_xx + u_yy) + extra_term + I_ion + current_stimuli
    
    # checking du for activation time update
    # at = np.where(np.greater_equal(d_u,max_du), t, at)
    # max_du = np.where(np.greater_equal(d_u,max_du), d_u, max_du)

    # euler update
    v += d_v * dt
    w += d_w * dt
    u += d_u * dt

    
    return np.asarray((v, w, u))
Example No. 22
 def sigmoid_upper_concave(x):
     return jnp.where(jnp.greater_equal(x, up_tangent_point), sigmoid(x),
                      up_lin_slope * x + up_lin_offset)
Example No. 23
 def np_fn(input_np, v_current, gamma, tau_m, Vth, dt):
     dV = jnp.multiply(
         jnp.divide(jnp.subtract(input_np, v_current), tau_m), dt)
     return (jnp.greater_equal(v_current + dV, Vth).astype('float32'),
             jnp.exp(-1 / tau_m) *
             jnp.less_equal(v_current + dV, Vth).astype('float32') *
             jnp.where((v_current + dV) < 0, 0, v_current + dV),
             gamma + jnp.greater_equal(v_current + dV, Vth).astype('int32'))
Example No. 24
def ge(a: Numeric, b: Numeric):
    return jnp.greater_equal(a, b)
Example No. 25
def greater_equal(x1, x2):
  if isinstance(x1, JaxArray): x1 = x1.value
  if isinstance(x2, JaxArray): x2 = x2.value
  return JaxArray(jnp.greater_equal(x1, x2))
Example No. 26
 def can_condition(self, val: jnp.ndarray):
     in_range = jnp.logical_and(
         jnp.all(jnp.greater_equal(val, jnp.array(0.))),
         jnp.all(jnp.greater_equal(jnp.array(1.), val)))
     return in_range
Example No. 27
 def can_condition(self, val: jnp.ndarray):
     in_range = jnp.logical_and(
         jnp.all(jnp.greater_equal(val, jnp.array(0.))),
         jnp.all(jnp.greater_equal(self.n, val)))
     return jnp.logical_and(utils.is_integer(val), in_range)