Example #1
def value_loss_given_predictions(value_prediction,
                                 rewards,
                                 reward_mask,
                                 gamma=0.99,
                                 epsilon=0.2,
                                 value_prediction_old=None):
  """Computes the value loss given the prediction of the value function.

  Args:
    value_prediction: np.ndarray of shape (B, T+1, 1)
    rewards: np.ndarray of shape (B, T) of rewards.
    reward_mask: np.ndarray of shape (B, T), the mask over rewards.
    gamma: float, discount factor.
    epsilon: float, clip-fraction, used if `value_prediction_old` isn't None.
    value_prediction_old: np.ndarray of shape (B, T+1, 1) of value predictions
      using the old parameters. If provided, we incorporate this in the loss as
      well. This is from the OpenAI baselines implementation.

  Returns:
    The average L2 value loss, averaged over instances where reward_mask is 1.
  """

  B, T = rewards.shape  # pylint: disable=invalid-name
  assert (B, T) == reward_mask.shape
  assert (B, T + 1, 1) == value_prediction.shape

  value_prediction = np.squeeze(value_prediction, axis=2)  # (B, T+1)
  value_prediction = value_prediction[:, :-1] * reward_mask  # (B, T)
  r2g = rewards_to_go(rewards, reward_mask, gamma=gamma)  # (B, T)
  loss = (value_prediction - r2g)**2

  # From the baselines implementation.
  if value_prediction_old is not None:
    value_prediction_old = np.squeeze(value_prediction_old, axis=2)  # (B, T+1)
    value_prediction_old = value_prediction_old[:, :-1] * reward_mask  # (B, T)

    v_clipped = value_prediction_old + np.clip(
        value_prediction - value_prediction_old, -epsilon, epsilon)
    v_clipped_loss = (v_clipped - r2g)**2
    loss = np.maximum(v_clipped_loss, loss)

  # Take an average on only the points where mask != 0.
  return np.sum(loss) / np.sum(reward_mask)
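The loss above leans on a `rewards_to_go` helper defined elsewhere in the same project. A minimal sketch of what it computes, assuming the usual masked discounted-return recursion r2g[t] = r[t] + gamma * r2g[t+1] (the project's own version may differ in details):

import numpy as np

def rewards_to_go(rewards, reward_mask, gamma=0.99):
  """Sketch: masked discounted returns, r2g[t] = r[t] + gamma * r2g[t+1]."""
  masked = rewards * reward_mask
  r2g = np.zeros_like(masked, dtype=float)
  acc = np.zeros(rewards.shape[0])
  for t in reversed(range(rewards.shape[1])):
    acc = masked[:, t] + gamma * acc
    r2g[:, t] = acc
  return r2g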
Example #2
File: privacy.py Project: ksachdeva/optax
  def update_fn(updates, state, params=None):
    del params
    grads_flat, grads_treedef = jax.tree_flatten(updates)
    bsize = grads_flat[0].shape[0]

    if any(g.ndim == 0 or bsize != g.shape[0] for g in grads_flat):
      raise ValueError(
          'Unlike other transforms, `differentially_private_aggregate` expects'
          ' `updates` to have a batch dimension in the 0th axis. That is, this'
          ' function expects per-example gradients as input.')

    new_key, *rngs = jax.random.split(state.rng_key, len(grads_flat)+1)
    global_grad_norms = jax.vmap(utils.global_norm)(grads_flat)
    divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0)
    clipped = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1) for g in grads_flat]
    noised = [(g + noise_std * jax.random.normal(r, g.shape, g.dtype)) / bsize
              for g, r in zip(clipped, rngs)]
    return (jax.tree_unflatten(grads_treedef, noised),
            DifferentiallyPrivateAggregateState(rng_key=new_key))
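A usage sketch of the transform this `update_fn` belongs to, via the public optax API (the hyperparameter values here are made up). Note the mandatory leading batch axis on the per-example gradients:

import jax.numpy as jnp
import optax

tx = optax.differentially_private_aggregate(
    l2_norm_clip=1.0, noise_multiplier=1.1, seed=0)
params = {'w': jnp.zeros(3)}
state = tx.init(params)
per_example_grads = {'w': jnp.ones((8, 3))}  # shape (batch, ...) required
updates, state = tx.update(per_example_grads, state)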
Example #3
File: ode.py Project: xf05888/jax
def optimal_step_size(last_step,
                      mean_error_ratio,
                      safety=0.9,
                      ifactor=10.0,
                      dfactor=0.2,
                      order=5.0):
    """Compute optimal Runge-Kutta stepsize."""
    mean_error_ratio = np.max(mean_error_ratio)
    dfactor = np.where(mean_error_ratio < 1, 1.0, dfactor)

    err_ratio = np.sqrt(mean_error_ratio)
    factor = np.maximum(
        1.0 / ifactor,
        np.minimum(err_ratio**(1.0 / order) / safety, 1.0 / dfactor))
    return np.where(
        mean_error_ratio == 0,
        last_step * ifactor,
        last_step / factor,
    )
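For intuition, two calls with the defaults: a zero error ratio grows the step by `ifactor`, while an error ratio at the target of 1 shrinks it slightly via the safety factor:

optimal_step_size(0.1, 0.0)  # -> 1.0: no error, step grows by ifactor
optimal_step_size(0.1, 1.0)  # -> 0.09: at the error target, safety shrinks it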
Example #4
    def assert_close(expected, actual):
        self.assertEqual(expected.shape, actual.shape)
        relative_error = (np.linalg.norm(actual - expected) /
                          np.maximum(np.linalg.norm(expected), 1e-12))

        absolute_error = np.mean(np.abs(actual - expected))

        if (np.isnan(relative_error) or relative_error > rtol
                or absolute_error > atol):
            _log(relative_error, absolute_error, expected, actual, False)
            self.fail(
                self.failureException('Relative ERROR: ',
                                      float(relative_error),
                                      'EXPECTED:' + ' ' * 50, expected,
                                      'ACTUAL:' + ' ' * 50, actual,
                                      ' ' * 50, 'Absolute ERROR: ',
                                      float(absolute_error)))
        else:
            _log(relative_error, absolute_error, expected, actual, True)
Example #5
 def optimize(state,
              grad,
              warmup=config.optim.warmup,
              grad_clip=config.optim.grad_clip):
     """Optimizes with warmup and gradient clipping (disabled if negative)."""
     lr = state.lr
     if warmup > 0:
         lr = lr * jnp.minimum(state.step / warmup, 1.0)
     if grad_clip >= 0:
         # Compute global gradient norm
         grad_norm = jnp.sqrt(
             sum([jnp.sum(jnp.square(x)) for x in jax.tree_leaves(grad)]))
         # Clip gradient
         clipped_grad = jax.tree_map(
             lambda x: x * grad_clip / jnp.maximum(grad_norm, grad_clip),
             grad)
     else:  # disabling gradient clipping if grad_clip < 0
         clipped_grad = grad
     return state.optimizer.apply_gradient(clipped_grad, learning_rate=lr)
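The warmup line is a plain linear ramp on the learning rate. A self-contained sketch of the schedule it implements:

import jax.numpy as jnp

base_lr, warmup = 1e-3, 1000
lr_at = lambda step: base_lr * jnp.minimum(step / warmup, 1.0)
print(lr_at(500))   # 5e-4: halfway through warmup
print(lr_at(2000))  # 1e-3: warmup finished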
Example #6
File: utils_test.py Project: BwRy/jraph
    def test_segment_max_negatives(self, indices_are_sorted, unique_indices):
        neg_inf = jnp.iinfo(jnp.int32).min
        if unique_indices:
            data = -1 - jnp.arange(6)  # [-1, -2, -3, -4, -5, -6]
            if indices_are_sorted:
                segment_ids = jnp.array([0, 1, 2, 3, 4, 5])
                expected_out = jnp.array([-1, -2, -3, -4, -5, -6])
                num_segments = 6
            else:
                segment_ids = jnp.array([1, 0, 2, 4, 3, -5])
                expected_out = jnp.array([-2, -1, -3, -5, -4])
                num_segments = 5
        else:
            data = -1 - jnp.arange(9)  # [-1, -2, -3, -4, -5, -6, -7, -8, -9]
            if indices_are_sorted:
                segment_ids = jnp.array([0, 0, 0, 1, 1, 1, 2, 3, 4])
                expected_out = jnp.array([-1, -4, -7, -8, -9, neg_inf])
            else:
                segment_ids = jnp.array([0, 1, 2, 0, 4, 0, 1, 1, -6])
                expected_out = jnp.array([-1, -2, -3, neg_inf, -5, neg_inf])
            num_segments = 6

        with self.subTest('nojit'):
            result = utils.segment_max(data, segment_ids, num_segments,
                                       indices_are_sorted, unique_indices)
            self.assertAllClose(result, expected_out, check_dtypes=True)
            result = utils.segment_max(data,
                                       segment_ids,
                                       indices_are_sorted=indices_are_sorted,
                                       unique_indices=unique_indices)
            num_unique_segments = jnp.maximum(
                jnp.max(segment_ids) + 1, jnp.max(-segment_ids))
            self.assertAllClose(result,
                                expected_out[:num_unique_segments],
                                check_dtypes=True)
        with self.subTest('jit'):
            result = jax.jit(utils.segment_max,
                             static_argnums=(2, 3, 4))(data, segment_ids,
                                                       num_segments,
                                                       indices_are_sorted,
                                                       unique_indices)
            self.assertAllClose(result, expected_out, check_dtypes=True)
Example #7
def norm_projection(delta, norm_type, eps=1.):
  """Projects to a norm-ball centered at 0.

  Args:
    delta: An array of size dim x num containing vectors to be projected.
    norm_type: A string denoting the type of the norm-ball.
    eps: A float denoting the radius of the norm-ball.

  Returns:
    An array of size dim x num, the projection of delta to the norm-ball.
  """
  shape = delta.shape
  if len(delta.shape) == 1:
    delta = delta.reshape(-1, 1)
  if norm_type == 'linf':
    delta = jnp.clip(delta, -eps, eps)
  elif norm_type == 'l2':
    # Euclidean projection: divide all elements by a constant factor
    avoid_zero_div = 1e-12
    norm2 = jnp.sum(delta**2, axis=0, keepdims=True)
    norm = jnp.sqrt(jnp.maximum(avoid_zero_div, norm2))
    # only decrease the norm, never increase
    delta = delta * jnp.clip(eps / norm, a_min=None, a_max=1)
  elif norm_type == 'l1':
    delta = l1_unit_projection(delta / eps) * eps
  elif norm_type == 'dftinf':
    # transform to DFT, project using known projections, then transform back
    # dft = np.matrix(scipy.linalg.dft(delta.shape[0]) / np.sqrt(delta.shape[0]))
    dft = np.matrix(scipy.linalg.dft(delta.shape[0], scale='sqrtn'))
    dftxdelta = dft @ delta
    # dftxdelta = np.matrix(scipy.fft.fft(delta, axis=0, norm='ortho'))
    # L2 projection of each coordinate to the L2-ball in the complex plane
    dftz = dftxdelta.reshape(1, -1)
    dftz = jnp.concatenate((jnp.real(dftz), jnp.imag(dftz)), axis=0)
    dftz = norm_projection(dftz, 'l2', eps)
    dftz = (dftz[0, :] + 1j * dftz[1, :]).reshape(delta.shape)
    # project back from DFT
    delta = dft.getH() @ dftz
    # delta = np.matrix(scipy.fft.ifft(dftz, axis=0, norm='ortho'))
    # Projected vector can have an imaginary part
    delta = jnp.real(delta)
  return delta.reshape(shape)
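A quick check of the `l2` branch (inputs are `dim x num`, and the projection only ever shrinks vectors, never grows them):

import jax.numpy as jnp

v = jnp.array([[3.0], [4.0]])                 # one column with norm 5
p = norm_projection(v, 'l2', eps=1.0)         # scaled down to norm 1
q = norm_projection(0.5 * v, 'l2', eps=10.0)  # already inside the ball: unchanged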
Example #8
def _internal_bi_tempered_logistic_loss(activations, labels, t1, t2):
  """Computes the Bi-Tempered logistic loss.

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    labels: An array of label probabilities with the same shape as
      `activations`.
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness).

  Returns:
    A loss array for robust loss.
  """
  normalization_constants = compute_normalization(activations, t2, num_iters=5)
  if t2 == 1.0:
    if t1 == 1.0:
      return normalization_constants + jnp.sum(
          jnp.multiply(labels,
                       jnp.log(labels + 1e-10) - activations), -1)
    else:
      shifted_activations = jnp.exp(activations - normalization_constants)
      one_minus_t1 = (1.0 - t1)
      one_minus_t2 = 1.0
  else:
    one_minus_t1 = (1.0 - t1)
    one_minus_t2 = (1.0 - t2)
    shifted_activations = jnp.maximum(
        1.0 + one_minus_t2 * (activations - normalization_constants), 0.0)

  if t1 == 1.0:
    return jnp.sum(
        jnp.multiply(
            jnp.log(labels + 1e-10) -
            jnp.log(jnp.power(shifted_activations, 1.0 / one_minus_t2)),
            labels), -1)
  else:
    beta = 1.0 + one_minus_t1
    logt_probs = (jnp.power(shifted_activations, one_minus_t1 / one_minus_t2) -
                  1.0) / one_minus_t1
    return jnp.sum(
        jnp.multiply(log_t(labels, t1) - logt_probs, labels) - 1.0 / beta *
        (jnp.power(labels, beta) -
         jnp.power(shifted_activations, beta / one_minus_t2)), -1)
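The `log_t` helper referenced above is the tempered logarithm from the bi-tempered loss paper; a minimal sketch of the standard definition (the project's own version may differ in details):

import jax.numpy as jnp

def log_t(u, t):
  """Tempered logarithm: (u^(1-t) - 1) / (1 - t); reduces to log(u) at t = 1."""
  if t == 1.0:
    return jnp.log(u)
  return (jnp.power(u, 1.0 - t) - 1.0) / (1.0 - t)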
Example #9
def sample_bounds(key: jnp.ndarray,
                  shape: Tuple[int, ...],
                  minval: float = -2.,
                  maxval: float = 2.) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Sample some bounds of the required shape.

  Args:
    key: Random number generator.
    shape: Shape of the bounds to generate.
    minval: Optional, smallest value that the bounds could take.
    maxval: Optional, largest value that the bounds could take.
  Returns:
    lb, ub: Lower and upper bound tensors of the desired shape.
  """
  key_0, key_1 = jax.random.split(key)
  bound_1 = jax.random.uniform(key_0, shape, minval=minval, maxval=maxval)
  bound_2 = jax.random.uniform(key_1, shape, minval=minval, maxval=maxval)
  lb = jnp.minimum(bound_1, bound_2)
  ub = jnp.maximum(bound_1, bound_2)
  return lb, ub
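Usage is straightforward; the elementwise min/max guarantees an ordered pair of bounds:

import jax
import jax.numpy as jnp

lb, ub = sample_bounds(jax.random.PRNGKey(0), (4, 3))
assert bool(jnp.all(lb <= ub))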
Example #10
def _clip_by_l2_norm(x: Array, max_norm: float) -> Array:
    """Clip gradients to maximum l2 norm `max_norm`."""
    # Compute the sum of squares and find out where things are zero.
    sum_sq = jnp.sum(jnp.vdot(x, x))
    nonzero = sum_sq > 0

    # Compute the norm wherever sum_sq > 0 and leave it at zero otherwise. This
    # makes use of the "double where" trick; see
    # https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
    # for more info. In short this is necessary because although norm ends up
    # computed correctly where nonzero is true, if we ignored this we'd end up
    # with nans on the off-branches, which would leak through when computing
    # gradients in the backward pass.
    sum_sq_ones = jnp.where(nonzero, sum_sq, jnp.ones_like(sum_sq))
    norm = jnp.where(nonzero, jnp.sqrt(sum_sq_ones), sum_sq)

    # Normalize by max_norm. Whenever norm < max_norm we're left with x (this
    # happens trivially for indices where nonzero is false). Otherwise we're left
    # with the desired x * max_norm / norm.
    return (x * max_norm) / jnp.maximum(norm, max_norm)
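To see why the double `where` is needed: a single `where` around the same computation already yields a `nan` gradient at zero, because the untaken `sqrt` branch still participates in the backward pass:

import jax
import jax.numpy as jnp

def naive_norm(x):
  sum_sq = jnp.sum(jnp.vdot(x, x))
  return jnp.where(sum_sq > 0, jnp.sqrt(sum_sq), 0.0)

print(jax.grad(naive_norm)(jnp.zeros(3)))  # [nan nan nan]: the sqrt branch leaks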
Example #11
        def _dynamics(state, action):
            self.nsamples += 1
            position = state[0]
            velocity = state[1]

            force = jnp.minimum(jnp.maximum(action, self.min_action), self.max_action)

            velocity += force * self.power - 0.0025 * jnp.cos(3 * position)
            velocity = jnp.clip(velocity, -self.max_speed, self.max_speed)

            position += velocity
            position = jnp.clip(position, self.min_position, self.max_position)
            reset_velocity = (position == self.min_position) & (velocity < 0)
            # Old-style five-argument lax.cond: (pred, true_operand, true_fun,
            # false_operand, false_fun). Resets velocity to zero at the left wall.
            velocity = jax.lax.cond(reset_velocity[0], velocity, lambda x: jnp.zeros((1,)), velocity, lambda x: x)
            return jnp.reshape(jnp.array([position, velocity]), (2,))
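The five-argument `lax.cond` form above is the pre-0.2 JAX API. On current JAX the same shape-stable reset can be written with `jnp.where` (an equivalent rewrite suggested here, not the project's code):

velocity = jnp.where(reset_velocity, jnp.zeros_like(velocity), velocity)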
Example #12
def convex_fn_relaxation(primitive: bound_propagation.Primitive, inp: Bound,
                         **params) -> Tuple[TensorFunction, TensorFunction]:
    """Relaxation of an element-wise convex primitive.

  Args:
    primitive: Convex primitive to relax.
    inp: Bounds on the input.
    **params: Params of the quadratic operation, mainly the jaxpr defining it.
  Returns:
    lb_fun, ub_fun
  """
    prim_fun = functools.partial(primitive.bind, **params)
    x_lb, x_ub = inp.lower, inp.upper
    y_lb, y_ub = prim_fun(x_lb), prim_fun(x_ub)

    chord_slope_safe_denom = jnp.maximum(x_ub - x_lb, 1e-12)
    chord_slope = (y_ub - y_lb) / chord_slope_safe_denom
    chord_intercept = y_lb - chord_slope * x_lb
    chord_fun = lambda x: chord_slope * x + chord_intercept
    return prim_fun, chord_fun
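The chord construction in isolation: for a convex f on [x_lb, x_ub], f itself is a valid lower bound and the line through the endpoint values is a valid upper bound. A self-contained check with f = exp:

import jax.numpy as jnp

x_lb, x_ub = -1.0, 2.0
slope = (jnp.exp(x_ub) - jnp.exp(x_lb)) / jnp.maximum(x_ub - x_lb, 1e-12)
chord = lambda x: jnp.exp(x_lb) + slope * (x - x_lb)
xs = jnp.linspace(x_lb, x_ub, 7)
assert bool(jnp.all(jnp.exp(xs) <= chord(xs) + 1e-6))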
Example #13
File: _lnp.py Project: lschmors/RFEst
    def cost(self, p, extra=None, precomputed=None):
        """
        Negative log-likelihood.
        """
        y = self.y if extra is None else extra['y']
        r = self.forward_pass(p, extra) if precomputed is None else precomputed
        r = np.maximum(r, 1e-20)  # remove zero to avoid nan in log.
        dt = self.dt

        term0 = -np.log(r / dt) @ y  # spike term from poisson log-likelihood
        term1 = np.sum(r)  # non-spike term

        neglogli = term0 + term1

        if self.beta and extra is None:
            l1 = np.linalg.norm(p['w'], 1)
            l2 = np.linalg.norm(p['w'], 2)
            neglogli += self.beta * ((1 - self.alpha) * l2 + self.alpha * l1)

        return neglogli
Example #14
File: utils.py Project: jejjohnson/fundl
def l2_normalize(arr, axis, epsilon=1e-12):
    """
    L2 normalize along a particular axis.

    Doc taken from tf.nn.l2_normalize:
    https://www.tensorflow.org/api_docs/python/tf/math/l2_normalize

    output = x / (
        sqrt(
            max(
                sum(x**2),
                epsilon
            )
        )
    )
    """
    sq_arr = np.power(arr, 2)
    square_sum = np.sum(sq_arr, axis=axis, keepdims=True)
    max_weights = np.maximum(square_sum, epsilon)
    return np.divide(arr, np.sqrt(max_weights))
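A one-line sanity check (this module aliases jax.numpy as np):

import jax.numpy as np

v = np.array([[3.0, 4.0]])
print(l2_normalize(v, axis=1))  # [[0.6 0.8]]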
Example #15
def log(q, eps=1e-8):
    """Computes the quaternion logarithm.

    References:
      https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions

    Args:
      q: the quaternion in (x,y,z,w) format.
      eps: an epsilon value for numerical stability.

    Returns:
      The logarithm of q.
    """
    mag = linalg.norm(q, axis=-1, keepdims=True)
    v = im(q)
    s = re(q)
    w = jnp.log(mag)
    denom = jnp.maximum(linalg.norm(v, axis=-1, keepdims=True),
                        eps * jnp.ones_like(v))
    xyz = v / denom * safe_acos(s / mag)  # angle is arccos(re(q) / |q|), not s / eps
    return jnp.concatenate((xyz, w), axis=-1)
Example #16
def solve_vfi(money, EV_list, umult, kf, km, sigma, beta, i, wn, wt, sgrid,
              psi):

    # dim is (na,ns,nexo,ntheta)

    EV_stretch_list = [(wt[:,None,None]*x[i,:,:] + \
                              wn[:,None,None]*x[i+1,:,:])
                        for x in EV_list]

    consumption = money[:, None, :] - sgrid[None, :, None]
    consumption_negative = (consumption <= 0)
    uc = (np.maximum(consumption, 1e-8))**(1 - sigma) / (1 - sigma)
    utility = umult[None,None,None,:]*(uc[:,:,:,None]) - \
                                1e9*consumption_negative[:,:,:,None]

    EVs, EVFs, EVMs = EV_stretch_list

    mega_matrix = utility + beta * EVs[None, :, :, :]
    #print(mega_matrix.shape)

    ind_s = mega_matrix.argmax(axis=1)
    V = np.take_along_axis(mega_matrix,ind_s[:,None,:,:],1)\
                                                .squeeze(axis=1) + psi

    s = sgrid[ind_s]
    c = money[:, :, None] - s

    V_check = umult[None,None,:]*(c**(1-sigma)/(1-sigma)) + \
                            psi + beta*np.take_along_axis(EVs,ind_s,0)

    VF = ((kf[None,None,:]*c)**(1-sigma)/(1-sigma)) + \
                            psi + beta*np.take_along_axis(EVFs,ind_s,0)
    VM = ((km[None,None,:]*c)**(1-sigma)/(1-sigma)) + \
                            psi + beta*np.take_along_axis(EVMs,ind_s,0)

    assert np.allclose(V_check, V, atol=1e-5)

    return V, VF, VM, s
Example #17
def mean_and_var(
    x: Optional[np.ndarray],
    axis: Optional[Axes] = None,
    dtype: Optional[np.dtype] = None,
    out: Optional[None] = None,
    ddof: int = 0,
    keepdims: bool = False,
    mask: Optional[np.ndarray] = None,
    get_var: bool = False
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """`np.mean` and `np.var` taking the `mask` information into account."""
    var = None
    if x is None:
        return x, var

    if mask is None:
        mean = np.mean(x, axis, dtype, out, keepdims)
        if get_var:
            var = np.var(x, axis, dtype, out, ddof, keepdims)

    else:
        axis = tuple(utils.canonicalize_axis(axis, x))
        size = utils.size_at(x, axis)
        mask = np.broadcast_to(mask, x.shape)
        mask_size = np.count_nonzero(mask, axis)
        for i in axis:
            mask_size = np.expand_dims(mask_size, i)
        size -= mask_size
        size = np.maximum(size, 1)

        mean = np.sum(x, axis=axis, keepdims=True) / size
        if not keepdims:
            mean = np.squeeze(mean, axis)

        if get_var:
            var = np.sum(
                (x - mean)**2, axis=axis, keepdims=True) / (size - ddof)
            if not keepdims:
                var = np.squeeze(var, axis)

    return mean, var
Example #18
    def __init__(self,
                 space,
                 vocab_size,
                 precision=2,
                 max_range=(-100.0, 100.0)):
        self._precision = precision

        # Some gym envs (e.g. CartPole) have unreasonably high bounds for
        # observations. We clip so we can represent them.
        bounded_space = copy.copy(space)
        (min_low, max_high) = max_range
        bounded_space.low = np.maximum(space.low, min_low)
        bounded_space.high = np.minimum(space.high, max_high)
        if (not np.allclose(bounded_space.low, space.low)
                or not np.allclose(bounded_space.high, space.high)):
            logging.warning(
                'Space limits %s, %s out of bounds %s. Clipping to %s, %s.',
                str(space.low), str(space.high), str(max_range),
                str(bounded_space.low), str(bounded_space.high))

        super().__init__(bounded_space, vocab_size)
Example #19
def lerp_weight(x, xs):
    """Linear interpolation weight from a sample at x to xs.

  Returns the linear interpolation weight of a "query point" at coordinate `x`
  with respect to a "sample" at coordinate `xs`.

  The integer coordinates `x` are at pixel centers.
  The floating point coordinates `xs` are at pixel edges.
  (OpenGL convention).

  Args:
    x: "Query" point position.
    xs: "Sample" position.

  Returns:
    - 1 when x = xs.
    - 0 when |x - xs| > 1.
  """
    dx = x - xs
    abs_dx = abs(dx)
    return jnp.maximum(1.0 - abs_dx, 0.0)
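The weight is just a unit triangle ("hat") kernel centered at `xs`:

import jax.numpy as jnp

print(lerp_weight(2.0, jnp.array([1.5, 2.0, 3.5])))  # [0.5 1.  0. ]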
Example #20
def initial_step_size(fun, t0, y0, order, rtol, atol, f0):
  # Algorithm from:
  # E. Hairer, S. P. Norsett G. Wanner,
  # Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
  y0, f0 = _promote_dtypes_inexact(y0, f0)
  dtype = y0.dtype

  scale = atol + jnp.abs(y0) * rtol
  d0 = jnp.linalg.norm(y0 / scale.astype(dtype))
  d1 = jnp.linalg.norm(f0 / scale.astype(dtype))

  h0 = jnp.where((d0 < 1e-5) | (d1 < 1e-5), 1e-6, 0.01 * d0 / d1)
  y1 = y0 + h0.astype(dtype) * f0
  f1 = fun(y1, t0 + h0)
  d2 = jnp.linalg.norm((f1 - f0) / scale.astype(dtype)) / h0

  h1 = jnp.where((d1 <= 1e-15) & (d2 <= 1e-15),
                 jnp.maximum(1e-6, h0 * 1e-3),
                 (0.01 / jnp.max(d1 + d2)) ** (1. / (order + 1.)))

  return jnp.minimum(100. * h0, h1)
Example #21
    def log_factor_k(cluster_id, log_maha_k, num_k, logdetC_k):
        """
        Computes f_k such that,
            u_k @ f0_k n_k C_k @ u_k <= 1
        and
            f_k^d V(n_k C_k) = max(V(S_k), V(f0_k n_k C_k))
            log_f_k = (log max(V(S)*n_k/n_S, V(f0_k n_k C_k)) - log V(n_k C_k))/D
            log_f_k = (max(log(V(S)*n_k/n_S), logV(n_k C_k)) - log V(n_k C_k))/D
        """
        # K
        log_f_expand_k = -jnp.max(jnp.where(cluster_id == a_k[:, None],
                                            log_maha_k, -jnp.inf),
                                  axis=-1)
        log_VE_expand_k = log_ellipsoid_volume(logdetC_k, num_k,
                                               log_f_expand_k)
        log_VE_k = log_ellipsoid_volume(logdetC_k, num_k, 0.)

        log_scale_k = (jnp.maximum(log_VS + jnp.log(num_k) - jnp.log(num_S),
                                   log_VE_expand_k) - log_VE_k) / D
        # K
        return log_scale_k
Example #22
File: mip.py Project: wx-b/mipnerf
def lift_gaussian(d, t_mean, t_var, r_var, diag):
    """Lift a Gaussian defined along a ray to 3D coordinates."""
    mean = d[..., None, :] * t_mean[..., None]

    d_mag_sq = jnp.maximum(1e-10, jnp.sum(d**2, axis=-1, keepdims=True))

    if diag:
        d_outer_diag = d**2
        null_outer_diag = 1 - d_outer_diag / d_mag_sq
        t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :]
        xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :]
        cov_diag = t_cov_diag + xy_cov_diag
        return mean, cov_diag
    else:
        d_outer = d[..., :, None] * d[..., None, :]
        eye = jnp.eye(d.shape[-1])
        null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
        t_cov = t_var[..., None, None] * d_outer[..., None, :, :]
        xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]
        cov = t_cov + xy_cov
        return mean, cov
Example #23
def predict_cnn(params, inputs, include_preactivations=False):
    """Forward pass for a CNN given parameters.

  Args:
    params: Parameters for the CNN. See make_cnn_params for syntax.
    inputs: Inputs to CNN.
    include_preactivations: bool. If True, also return pre-activations after
      each matmul layer.
  Returns:
    act: Output from forward pass through CNN.
    (Optional) layer_preacts: Pre-activations at each layer (appended before
      the relu is applied).
  """
    act = inputs
    layer_preacts = []
    for counter, layer_params in enumerate(params):
        act = fwd(act, layer_params)
        layer_preacts.append(act)
        if counter < len(params) - 1:
            # no relu on final layer
            act = jnp.maximum(act, 0)
    return act if not include_preactivations else (act, layer_preacts)
Example #24
def dists_to_samples(rays, t):
    """Convert mipnerf frustums to gaussians."""
    t_mids = .5 * (t[..., 1:] + t[..., :-1])
    mean = rays[0][..., None, :] + rays[1][..., None, :] * t_mids[..., None]

    d = rays[1]
    d_mag_sq = np.maximum(1e-10, np.sum(d**2, axis=-1, keepdims=True))
    t_half = .5 * (t[..., 1:] - t[..., :-1])
    t_var = t_half**2 / 3.
    r_var = (rays[2] * t_mids)**2 / 12.

    d_outer = d[..., :, None] * d[..., None, :]
    eye = np.eye(d.shape[-1])
    null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
    t_cov = t_var[..., None, None] * d_outer[..., None, :, :]
    xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]
    cov = t_cov + xy_cov

    return mean, cov
Example #25
    def unconstrained_proposal(self, rng_key, x, grad_, hess_):
        ndim = np.ndim(x)
        if ndim == 0:
            inv_hess = 1 / hess_
            dist_type = dist.Normal
        else:
            inv_hess = np.linalg.inv(hess_)
            dist_type = dist.MultivariateNormal

        loc = x - np.dot(inv_hess, grad_)
        sigma = -inv_hess

        # Reconstruct sigma if not positive definite
        if not ndim == 0 and not np.all(np.linalg.eigvals(sigma) > 0):
            lam, vec = np.linalg.eigh(sigma)
            sigma = vec @ np.diag(np.maximum(
                lam, UNCONSTRAINED_RECONSTRUCTION)) @ vec.T

        dist_ = dist_type(loc, sigma + MU_CORRECTION)

        return dist_.sample(rng_key).reshape(x.shape), dist_
Example #26
 def from_params(
         cls,
         fixed_params,
         opt_params,
         scale=None,
         traceable=True):  # FIXME: traceable; why sometimes no Scale?
     if not scale:
         scale = Scale(0.0, 1.0)
     floor = fixed_params.get("floor", -np.inf)
     ceiling = fixed_params.get("ceiling", np.inf)
     # Allow logistic center to exceed the range by 20%
     loc_min = np.maximum(scale.low, floor) - 0.2 * scale.width
     loc_max = np.minimum(scale.high, ceiling) + 0.2 * scale.width
     loc_range = loc_max - loc_min
     structured_params = opt_params.reshape((-1, 3))
     locs = loc_min + scipy.special.expit(structured_params[:,
                                                            0]) * loc_range
     # Allow logistic scales between 0.01 and 0.5
     # Don't allow tiny scales outside of the visible range
     s_min = 0.01 + 0.1 * np.where(
         (locs < scale.low),
         scale.low - locs,
         np.where(locs > scale.high, locs - scale.high, 0.0),
     )
     s_max = 0.5
     s_range = s_max - s_min
     ss = s_min + scipy.special.expit(structured_params[:, 1]) * s_range
     # Allow probs > 0.01
     probs = list(0.01 + nn.softmax(structured_params[:, 2]) *
                  (1 - 0.01 * structured_params[:, 2].size))
     # Bundle up components
     component_logistics = [
         Logistic(l, s, scale, normalized=True) for (l, s) in zip(locs, ss)
     ]
     components = [
         Truncate(base_dist=cl, floor=floor, ceiling=ceiling)
         for cl in component_logistics
     ]
     mixture = cls(components=components, probs=probs)
     return mixture
Example #27
    def light(self,
              origin,
              direction,
              intersection,
              light_position,
              eye_position,
              scene_objects,
              bounce=0,
              far=1.0e15):
        '''
        Basic light model with ambient, diffuse (Lambert) and Phong terms.
        '''
        rayhit = origin + direction * intersection
        normal = ((rayhit - self.center) * (1. / self.radius))
        direction_to_light = (light_position - rayhit).norm()
        direction_to_eye = (eye_position - rayhit).norm()
        nudged = rayhit + normal * 0.001  # To avoid shadow acne

        # Create shadow mask
        light_distances = [
            o.intersect(nudged, direction_to_light, far=far)
            for o in scene_objects
        ]
        light_nearest = reduce(jnp.minimum, light_distances)
        light_mask = light_distances[scene_objects.index(
            self)] == light_nearest

        # Ambient light
        color = Vec3(0.05, 0.05, 0.05)

        # Lambert shading (diffuse)
        light_hit = jnp.maximum(normal.dot(direction_to_light), 0)
        color += self.diffusecolor(rayhit) * light_hit * light_mask

        # Phong light
        phong = normal.dot((direction_to_light + direction_to_eye).norm())
        color += Vec3(1., 1., 1.) * jnp.power(jnp.clip(phong, 0, 1),
                                              50) * light_mask

        return color
Example #28
def segment_max(data,
                segment_ids,
                num_segments=None,
                indices_are_sorted=False,
                unique_indices=False):
    """Computes the max within segments of an array.

  Similar to TensorFlow's segment_max:
  https://www.tensorflow.org/api_docs/python/tf/math/segment_max

  Args:
    data: an array with the values to be maxed over.
    segment_ids: an array with integer dtype that indicates the segments of
      `data` (along its leading axis) to be maxed over. Values can be repeated
      and need not be sorted. Values outside of the range [0, num_segments) are
      wrapped into that range by applying jnp.mod.
    num_segments: optional, an int with positive value indicating the number of
      segments. The default is ``jnp.maximum(jnp.max(segment_ids) + 1,
      jnp.max(-segment_ids))`` but since `num_segments` determines the size of
      the output, a static value must be provided to use ``segment_max`` in a
      ``jit``-compiled function.
    indices_are_sorted: whether ``segment_ids`` is known to be sorted
    unique_indices: whether ``segment_ids`` is known to be free of duplicates

  Returns:
    An array with shape ``(num_segments,) + data.shape[1:]`` representing
    the segment maxs.
  """
    if num_segments is None:
        num_segments = jnp.maximum(
            jnp.max(segment_ids) + 1, jnp.max(-segment_ids))
    num_segments = int(num_segments)

    min_value = dtype_min_value(data.dtype)
    out = jnp.full((num_segments, ) + data.shape[1:],
                   min_value,
                   dtype=data.dtype)
    segment_ids = jnp.mod(segment_ids, num_segments)
    return jax.ops.index_max(out, segment_ids, data, indices_are_sorted,
                             unique_indices)
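A quick usage check (note this excerpt relies on `jax.ops.index_max`, so it runs on the older JAX versions that still provide it):

import jax.numpy as jnp

data = jnp.arange(6)                         # [0, 1, 2, 3, 4, 5]
segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
print(segment_max(data, segment_ids, num_segments=3))  # [1 3 5]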
Example #29
    def compute_loss(self, predictions: NestedMap,
                     input_batch: NestedMap) -> Tuple[Metrics, Dict[str, Any]]:
        """Computes the loss and other metrics for the given predictions.

    Args:
      predictions: The output of `compute_predictions`.
      input_batch: A `.NestedMap` object containing input tensors to this tower.

    Returns:
      - A dict or NestedMap containing str keys and (metric, weight) pairs as
        values, where one of the entries is expected to correspond to the loss.
      - A dict containing arbitrary tensors describing something about each
        training example, where the first dimension of each tensor is the batch
        index.
    """
        labels = input_batch.labels
        num_tokens = jnp.sum(1.0 - input_batch.paddings.astype(jnp.float32))
        num_seqs = jnp.sum(
            jnp.amax(input_batch.segment_ids.astype(jnp.float32), axis=1))
        weights = predictions.augmented_pos.astype(jnp.float32)
        predicted_labels = predictions.per_example_argmax.astype(labels.dtype)
        num_preds = predictions.total_weight.astype(jnp.float32)
        mean_acc = jnp.sum(
            (labels == predicted_labels) * weights) / jnp.maximum(
                num_preds, 1)
        metric_weight = jnp.array(num_preds, predictions.avg_xent.dtype)
        metrics = py_utils.NestedMap(
            total_loss=(predictions.total_loss, metric_weight),
            avg_xent=(predictions.avg_xent, metric_weight),
            aux_loss=(predictions.aux_loss, metric_weight),
            log_pplx=(predictions.avg_xent, metric_weight),
            fraction_of_correct_preds=(mean_acc,
                                       jnp.array(num_preds, mean_acc.dtype)),
            num_predictions=(num_preds, jnp.array(1.0, num_preds.dtype)),
            num_tokens=(num_tokens, jnp.array(1.0, num_tokens.dtype)),
            num_seqs=(num_seqs, jnp.array(1.0, num_seqs.dtype)),
        )

        per_example_output = py_utils.NestedMap()
        return metrics, per_example_output
Example #30
    def test_clipping_norm(self, l2_norm_clip):
        dp_agg = privacy.differentially_private_aggregate(
            l2_norm_clip=l2_norm_clip, noise_multiplier=0., seed=42)
        state = dp_agg.init(self.params)
        update_fn = self.variant(dp_agg.update)

        # Shape of the three arrays below is (self.batch_size, )
        norms = [
            jnp.linalg.norm(g.reshape(self.batch_size, -1), axis=1)
            for g in jax.tree_leaves(self.per_eg_grads)
        ]
        global_norms = jnp.linalg.norm(jnp.stack(norms), axis=0)
        divisors = jnp.maximum(global_norms / l2_norm_clip, 1.)
        # Since the values of all the parameters are the same within each example,
        # we can easily compute what the values should be:
        expected_val = jnp.mean(jnp.arange(self.batch_size) / divisors)
        expected_tree = jax.tree_map(
            lambda p: jnp.broadcast_to(expected_val, p.shape), self.params)

        for _ in range(3):
            updates, state = update_fn(self.per_eg_grads, state, self.params)
            chex.assert_tree_all_close(updates, expected_tree)