Beispiel #1
0
def _ibp_integer_pow(x: PrimitiveInput, y: int) -> IntervalBound:
    """Propagation of IBP bounds through integer_pow.

    Args:
      x: Argument to be raised to a power, element-wise. Its `.lower` and
        `.upper` attributes hold the interval endpoints.
      y: Fixed non-negative integer exponent.

    Returns:
      out_bounds: IntervalBound enclosing `x ** y` for every point of the
        input interval.

    Raises:
      NotImplementedError: if `y` is negative.
    """
    if y < 0:
        raise NotImplementedError(
            f"IBP bounds for integer_pow require a non-negative exponent, got {y}.")
    l_pow = lax.integer_pow(x.lower, y)
    u_pow = lax.integer_pow(x.upper, y)

    # y == 0 gives the constant 1 everywhere (including 0 ** 0), and odd powers
    # are monotonically increasing, so the endpoint images are already the
    # tightest bounds. Routing y == 0 here avoids the loose lower bound of 0
    # that the even-power branch would produce for intervals containing zero.
    if y == 0 or y % 2 == 1:
        return IntervalBound(l_pow, u_pow)

    # Even powers (y >= 2) decrease on (-inf, 0] and increase on [0, inf): if
    # the interval straddles 0 the minimum is 0, otherwise it is attained at
    # one of the endpoints. The maximum is always at an endpoint.
    contains_zero = jnp.logical_and(jnp.less_equal(x.lower, 0),
                                    jnp.greater_equal(x.upper, 0))
    lower = jnp.where(contains_zero, jnp.zeros_like(x.lower),
                      jnp.minimum(l_pow, u_pow))
    upper = jnp.maximum(l_pow, u_pow)
    return IntervalBound(lower, upper)
Beispiel #2
0
 def log_cdf(self, value: Array) -> Array:
   """See `Distribution.log_cdf`.

   With `z = self._standardize(value)`, this returns the log of the piecewise
   CDF `0.5 * exp(z)` for `z <= 0` and `1 - 0.5 * exp(-z)` for `z > 0`
   (the formulas of a standard Laplace CDF — inferred from the math, confirm).
   """
   z = self._standardize(value)
   # jnp.where evaluates both branches; exp(-|z|) keeps the unused right-tail
   # branch finite for very negative z instead of producing log1p(-inf).
   right_tail = jnp.log1p(-0.5 * jnp.exp(-jnp.abs(z)))
   left_tail = z - math.log(2.)
   return jnp.where(jnp.less_equal(z, 0.), left_tail, right_tail)
Beispiel #3
0
    def _outgoing_edges_in_bounds_pixel(self, node_id, relation_ids):
        """Returns the outgoing edges from `node_id`.

    Pixels outside of the image are mapped to node_id -1.
    This "out of bounds pixel" node points to the start node.
    The considered initial node must be inside the image bounds.

    Args:
      node_id: ID of start node
      relation_ids: IDs of the relations to neighbors.

    Returns:
      A list of neighbor ids, in the order given by relation_ids.
    """
        node_coordinates = self._pixel_id_to_coordinates(node_id)
        neighbor_coordinates = jnp.repeat(
            jnp.array(node_coordinates)[jnp.newaxis, :],
            len(self.RELATION_OFFSETS), 0)
        offsets = jnp.array(self.RELATION_OFFSETS)
        neighbor_coordinates = neighbor_coordinates + offsets
        # make sure coords are within bounds, or -1
        neighbor_coordinates = jnp.where(
            jnp.less_equal(neighbor_coordinates,
                           jnp.array(self.image.shape) - 1),
            neighbor_coordinates, -1)
        # make sure coordinates are >= 0, or -1
        neighbor_coordinates = jnp.where(
            jnp.all(neighbor_coordinates >= 0, axis=-1, keepdims=True),
            neighbor_coordinates, -1)

        neighbor_ids = jax.vmap(
            self._pixel_coordinates_to_id)(neighbor_coordinates)
        return neighbor_ids
Beispiel #4
0
    def __call__(self, value, update_stats=True):
        """Folds `value` into the exponential moving average and returns it.

        Args:
          value: Array-like input to fold into the exponentially decayed
            average.
          update_stats: When True, the "counter", "hidden" and "average" states
            are written back. When False, the stored state is left unchanged
            and only the computed average is returned.

        Returns:
          The exponentially weighted average including `value` (divided by
          `1 - decay ** counter` when `self._zero_debias` is set).
        """
        value = jnp.asarray(value)  # Guarantees `.shape` and `.dtype` exist.
        prev_counter = base.get_state(
            "counter", shape=(), dtype=jnp.int32,
            init=initializers.Constant(-self._warmup_length))
        prev_hidden = base.get_state(
            "hidden", shape=value.shape, dtype=value.dtype, init=jnp.zeros)

        step_decay = jnp.asarray(self._decay).astype(value.dtype)
        counter = prev_counter + 1
        # The counter starts at -warmup_length, so it stays <= 0 during the
        # first calls; `_cond` presumably selects 0.0 then (hidden := value).
        step_decay = self._cond(jnp.less_equal(counter, 0), 0.0, step_decay,
                                value.dtype)
        hidden = prev_hidden * step_decay + value * (1 - step_decay)

        average = (hidden / (1. - jnp.power(step_decay, counter))
                   if self._zero_debias else hidden)

        if update_stats:
            for key, new_state in (("counter", counter), ("hidden", hidden),
                                   ("average", average)):
                base.set_state(key, new_state)
        return average
Beispiel #5
0
def scatter_in_bounds(operand, indices, updates, dnums):
    """Returns whether every scatter start index keeps its window inside `operand`.

    Mirrors the clamping code used in scatter_translation_rule: a start index
    along a scattered dimension is valid when it lies in
    `[0, operand_dim - slice_size]`.
    """
    # Window extent per operand dimension: inserted dims contribute a size-1
    # slice; the remaining dims take, in order, the sizes of the update
    # window dims.
    window_dims = iter(dnums.update_window_dims)
    slice_sizes = [
        1 if dim in dnums.inserted_window_dims
        else updates.shape[next(window_dims)]
        for dim in range(len(operand.shape))
    ]

    # Largest legal start along each scattered dimension, capped to what the
    # index dtype can represent.
    max_start = np.array(
        [operand.shape[d] - slice_sizes[d]
         for d in dnums.scatter_dims_to_operand_dims],
        np.int64)
    max_start = np.minimum(max_start, np.iinfo(indices.dtype).max)
    # The index vector occupies the last dimension of `indices`.
    max_start = lax.broadcast_in_dim(max_start, indices.shape,
                                     (len(indices.shape) - 1, ))

    not_below = jnp.all(jnp.greater_equal(indices, 0))
    not_above = jnp.all(jnp.less_equal(indices, max_start))
    return jnp.logical_and(not_below, not_above)
Beispiel #6
0
 def sigmoid_lower_convex(x):
     """Piecewise map: `sigmoid(x)` where `x <= low_tang_point`, else a line.

     The linear piece is `low_lin_slope * x + low_lin_offset`. All of
     `low_tang_point`, `sigmoid`, `low_lin_slope` and `low_lin_offset` come from
     the enclosing scope (presumably a tangent line making this a convex
     lower piece of the sigmoid, per the name — confirm with the caller).
     """
     use_sigmoid = jnp.less_equal(x, low_tang_point)
     return jnp.where(use_sigmoid, sigmoid(x),
                      low_lin_slope * x + low_lin_offset)
Beispiel #7
0
def less_equal(x1, x2):
  """Element-wise `x1 <= x2` that accepts plain arrays or `JaxArray` wrappers.

  `JaxArray` operands are unwrapped through their `.value` attribute before
  calling `jnp.less_equal`; the result is always wrapped in a new `JaxArray`.
  """
  lhs = x1.value if isinstance(x1, JaxArray) else x1
  rhs = x2.value if isinstance(x2, JaxArray) else x2
  return JaxArray(jnp.less_equal(lhs, rhs))
Beispiel #8
0
 def cond_func(args):
     """Returns `args[0] <= total_count` for a two-item `args`.

     `total_count` is read from the enclosing scope. `args` must unpack into
     exactly two items; the second is ignored. (Presumably the `cond_fun` of a
     while-loop — confirm with the caller.)
     """
     counter, _unused = args
     return jnp.less_equal(counter, total_count)
Beispiel #9
0
def le(a: Numeric, b: Numeric):
    """Element-wise `a <= b`, delegating to `jnp.less_equal`."""
    is_le = jnp.less_equal(a, b)
    return is_le
Beispiel #10
0
def sequence_mask(lengths: jnp.ndarray, maxlen):
  """Builds a mask marking the valid positions of each sequence.

  Args:
    lengths: Array of sequence lengths. The 1-D `(batch,)` case is the
      common one; any leading batch shape is supported.
    maxlen: Size of the position axis; must be a concrete (static) value.

  Returns:
    Array of shape `lengths.shape + (maxlen,)` with dtype `lengths.dtype`.
    Entry `[..., j]` is 1 when `j + 1 <= lengths[...]` (i.e. `j < lengths` for
    integer lengths) and 0 otherwise.
  """
  # Compare 1-based integer positions directly against the lengths. This
  # avoids materialising a (batch, maxlen) float array and cumsum-ing it,
  # whose positions stop being exact past 2**24 in float32.
  positions = jnp.arange(1, maxlen + 1)
  return jnp.less_equal(positions,
                        lengths[..., jnp.newaxis]).astype(lengths.dtype)
Beispiel #11
0
 def _less_equal(a, b):
     return jnp.less_equal(a, b)
Beispiel #12
0
 def _inverse(self, y):
     # We perform clipping in the _inverse function, as is done in TF-Agents.
     y = jnp.where(jnp.less_equal(jnp.abs(y), 1.),
                   tf.clip(y, -0.99999997, 0.99999997), y)
     return jnp.arctanh(y)
Beispiel #13
0
def _is_strictly_increasing(x, name=None):  # pylint: disable=unused-argument
  """NumPy version of `tf.math.is_strictly_increasing`.

  Elements are compared in row-major (flattened) order, as in TensorFlow;
  inputs with fewer than two elements are trivially strictly increasing.
  """
  flat = np.ravel(x)
  return np.all(flat[1:] > flat[:-1])


is_strictly_increasing = utils.copy_docstring(
    tf.math.is_strictly_increasing, _is_strictly_increasing)


def _l2_normalize(x, axis=None, epsilon=1e-12, name=None):  # pylint: disable=unused-argument
  """NumPy version of `tf.math.l2_normalize`.

  Computes `x / sqrt(max(sum(|x| ** 2, axis), epsilon))`, normalizing over all
  elements when `axis` is None. An explicit sum of squares is used because
  `np.linalg.norm(ord=2)` on 2-D input with `axis=None` computes the spectral
  norm, not the element-wise L2 norm.
  """
  x = np.asarray(x)
  if isinstance(axis, list):
    axis = tuple(axis)  # np.sum accepts an int or a tuple of ints only.
  square_sum = np.sum(np.square(np.abs(x)), axis=axis, keepdims=True)
  # `epsilon` lower-bounds the squared norm so all-zero slices don't divide by 0.
  return x / np.sqrt(np.maximum(square_sum, epsilon))


l2_normalize = utils.copy_docstring(tf.math.l2_normalize, _l2_normalize)

lbeta = utils.copy_docstring(tf.math.lbeta, _lbeta)

less = utils.copy_docstring(tf.math.less,
                            lambda x, y, name=None: np.less(x, y))

less_equal = utils.copy_docstring(tf.math.less_equal,
                                  lambda x, y, name=None: np.less_equal(x, y))

lgamma = utils.copy_docstring(tf.math.lgamma,
                              lambda x, name=None: scipy_special.gammaln(x))

log = utils.copy_docstring(tf.math.log, lambda x, name=None: np.log(x))

log1p = utils.copy_docstring(tf.math.log1p, lambda x, name=None: np.log1p(x))

# log_sigmoid(x) = -log(1 + exp(-x)); logaddexp avoids overflow of exp(-x)
# for large negative x, where the naive form returns -inf instead of ~x.
log_sigmoid = utils.copy_docstring(tf.math.log_sigmoid,
                                   lambda x, name=None: -np.logaddexp(0., -x))

log_softmax = utils.copy_docstring(
    tf.math.log_softmax,
    lambda logits, axis=None, name=None: (
        np.subtract(  # pylint: disable=g-long-lambda
Beispiel #14
0
 def np_fn(input_np, v_current, gamma, tau_m, Vth, dt):
     """One threshold-and-reset step of a membrane-potential update.

     (Looks like a leaky integrate-and-fire style neuron — inferred, confirm.)
     The candidate potential is `v_next = v_current + (input_np - v_current) /
     tau_m * dt`.

     Returns:
       A 3-tuple:
         * float32 mask of where `v_next >= Vth`;
         * `exp(-1 / tau_m) * [v_next <= Vth] * where(v_next < 0, 0, v_next)`;
         * `gamma` plus the int32 version of the first mask.
     """
     delta_v = jnp.multiply(
         jnp.divide(jnp.subtract(input_np, v_current), tau_m), dt)
     v_next = v_current + delta_v

     fired = jnp.greater_equal(v_next, Vth)
     not_above = jnp.less_equal(v_next, Vth).astype('float32')
     rectified = jnp.where(v_next < 0, 0, v_next)
     new_potential = jnp.exp(-1 / tau_m) * not_above * rectified

     return (fired.astype('float32'),
             new_potential,
             gamma + fired.astype('int32'))