Exemplo n.º 1
def value_loss_given_predictions(value_prediction,
  """Computes the value loss given the prediction of the value function.

    value_prediction: np.ndarray of shape (B, T+1, 1)
    rewards: np.ndarray of shape (B, T) of rewards.
    reward_mask: np.ndarray of shape (B, T), the mask over rewards.
    gamma: float, discount factor.
    epsilon: float, clip-fraction, used if value_value_prediction_old isn't None
    value_prediction_old: np.ndarray of shape (B, T+1, 1) of value predictions
      using the old parameters. If provided, we incorporate this in the loss as
      well. This is from the OpenAI baselines implementation.

    The average L2 value loss, averaged over instances where reward_mask is 1.

  B, T = rewards.shape  # pylint: disable=invalid-name
  assert (B, T) == reward_mask.shape
  assert (B, T + 1, 1) == value_prediction.shape

  value_prediction = np.squeeze(value_prediction, axis=2)  # (B, T+1)
  value_prediction = value_prediction[:, :-1] * reward_mask  # (B, T)
  r2g = rewards_to_go(rewards, reward_mask, gamma=gamma)  # (B, T)
  loss = (value_prediction - r2g)**2

  # From the baselines implementation.
  if value_prediction_old is not None:
    value_prediction_old = np.squeeze(value_prediction_old, axis=2)  # (B, T+1)
    value_prediction_old = value_prediction_old[:, :-1] * reward_mask  # (B, T)

    v_clipped = value_prediction_old + np.clip(
        value_prediction - value_prediction_old, -epsilon, epsilon)
    v_clipped_loss = (v_clipped - r2g)**2
    loss = np.maximum(v_clipped_loss, loss)

  # Take an average on only the points where mask != 0.
  return np.sum(loss) / np.sum(reward_mask)
Exemplo n.º 2
  def update_fn(updates, state, params=None):
    del params
    grads_flat, grads_treedef = jax.tree_flatten(updates)
    bsize = grads_flat[0].shape[0]

    if any(g.ndim == 0 or bsize != g.shape[0] for g in grads_flat):
      raise ValueError(
          'Unlike other transforms, `differentially_private_aggregate` expects'
          ' `updates` to have a batch dimension in the 0th axis. That is, this'
          ' function expects per-example gradients as input.')

    new_key, *rngs = jax.random.split(state.rng_key, len(grads_flat)+1)
    global_grad_norms = jax.vmap(utils.global_norm)(grads_flat)
    divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0)
    clipped = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1) for g in grads_flat]
    noised = [(g + noise_std * jax.random.normal(r, g.shape, g.dtype)) / bsize
              for g, r in zip(clipped, rngs)]
    return (jax.tree_unflatten(grads_treedef, noised),
Exemplo n.º 3
Arquivo: ode.py Projeto: xf05888/jax
def optimal_step_size(last_step,
    """Compute optimal Runge-Kutta stepsize."""
    mean_error_ratio = np.max(mean_error_ratio)
    dfactor = np.where(mean_error_ratio < 1, 1.0, dfactor)

    err_ratio = np.sqrt(mean_error_ratio)
    factor = np.maximum(
        1.0 / ifactor,
        np.minimum(err_ratio**(1.0 / order) / safety, 1.0 / dfactor))
    return np.where(
        mean_error_ratio == 0,
        last_step * ifactor,
        last_step / factor,
Exemplo n.º 4
    def assert_close(expected, actual):
        self.assertEqual(expected.shape, actual.shape)
        relative_error = (np.linalg.norm(actual - expected) /
                          np.maximum(np.linalg.norm(expected), 1e-12))

        absolute_error = np.mean(np.abs(actual - expected))

        if (np.isnan(relative_error) or relative_error > rtol
                or absolute_error > atol):
            _log(relative_error, absolute_error, expected, actual, False)
                self.failureException('Relative ERROR: ',
                                      'EXPECTED:' + ' ' * 50, expected,
                                      'ACTUAL:' + ' ' * 50, actual,
                                      ' ' * 50, 'Absolute ERROR: ',
            _log(relative_error, absolute_error, expected, actual, True)
Exemplo n.º 5
 def optimize(state,
     """Optimizes with warmup and gradient clipping (disabled if negative)."""
     lr = state.lr
     if warmup > 0:
         lr = lr * jnp.minimum(state.step / warmup, 1.0)
     if grad_clip >= 0:
         # Compute global gradient norm
         grad_norm = jnp.sqrt(
             sum([jnp.sum(jnp.square(x)) for x in jax.tree_leaves(grad)]))
         # Clip gradient
         clipped_grad = jax.tree_map(
             lambda x: x * grad_clip / jnp.maximum(grad_norm, grad_clip),
     else:  # disabling gradient clipping if grad_clip < 0
         clipped_grad = grad
     return state.optimizer.apply_gradient(clipped_grad, learning_rate=lr)
Exemplo n.º 6
    def test_segment_max_negatives(self, indices_are_sorted, unique_indices):
        neg_inf = jnp.iinfo(jnp.int32).min
        if unique_indices:
            data = -1 - jnp.arange(6)  # [-1, -2, -3, -4, -5, -6]
            if indices_are_sorted:
                segment_ids = jnp.array([0, 1, 2, 3, 4, 5])
                expected_out = jnp.array([-1, -2, -3, -4, -5, -6])
                num_segments = 6
                segment_ids = jnp.array([1, 0, 2, 4, 3, -5])
                expected_out = jnp.array([-2, -1, -3, -5, -4])
                num_segments = 5
            data = -1 - jnp.arange(9)  # [-1, -2, -3, -4, -5, -6, -7, -8, -9]
            if indices_are_sorted:
                segment_ids = jnp.array([0, 0, 0, 1, 1, 1, 2, 3, 4])
                expected_out = jnp.array([-1, -4, -7, -8, -9, neg_inf])
                segment_ids = jnp.array([0, 1, 2, 0, 4, 0, 1, 1, -6])
                expected_out = jnp.array([-1, -2, -3, neg_inf, -5, neg_inf])
            num_segments = 6

        with self.subTest('nojit'):
            result = utils.segment_max(data, segment_ids, num_segments,
                                       indices_are_sorted, unique_indices)
            self.assertAllClose(result, expected_out, check_dtypes=True)
            result = utils.segment_max(data,
            num_unique_segments = jnp.maximum(
                jnp.max(segment_ids) + 1, jnp.max(-segment_ids))
        with self.subTest('jit'):
            result = jax.jit(utils.segment_max,
                             static_argnums=(2, 3, 4))(data, segment_ids,
            self.assertAllClose(result, expected_out, check_dtypes=True)
Exemplo n.º 7
def norm_projection(delta, norm_type, eps=1.):
  """Projects to a norm-ball centered at 0.

  Args:
    delta: An array of size dim x num containing vectors to be projected, one
      vector per column. A 1-D array is treated as a single column.
    norm_type: A string denoting the type of the norm-ball: 'linf', 'l2',
      'l1', or 'dftinf' (an l-infinity ball on the moduli of the orthonormal
      DFT coefficients). Any other value returns `delta` unchanged.
    eps: A float denoting the radius of the norm-ball.

  Returns:
    An array with the same shape as `delta`, the projection of delta to the
    norm-ball.
  """
  shape = delta.shape
  if len(delta.shape) == 1:
    delta = delta.reshape(-1, 1)
  if norm_type == 'linf':
    delta = jnp.clip(delta, -eps, eps)
  elif norm_type == 'l2':
    # Euclidean projection: scale every column by a constant factor.
    avoid_zero_div = 1e-12
    norm2 = jnp.sum(delta**2, axis=0, keepdims=True)
    norm = jnp.sqrt(jnp.maximum(avoid_zero_div, norm2))
    # Only decrease the norm, never increase it (factor capped at 1).
    delta = delta * jnp.minimum(eps / norm, 1.)
  elif norm_type == 'l1':
    delta = l1_unit_projection(delta / eps) * eps
  elif norm_type == 'dftinf':
    # Transform to the unitary DFT basis, project, then transform back.
    dft = jnp.asarray(scipy.linalg.dft(delta.shape[0], scale='sqrtn'))
    dftxdelta = dft @ delta
    # View every complex coefficient as a 2-vector (real, imag) so that the
    # 'l2' branch projects its modulus onto the disc of radius eps.
    dftz = dftxdelta.reshape(1, -1)
    dftz = jnp.concatenate((jnp.real(dftz), jnp.imag(dftz)), axis=0)
    dftz = norm_projection(dftz, 'l2', eps)
    dftz = (dftz[0, :] + 1j * dftz[1, :]).reshape(delta.shape)
    # The 'sqrtn'-scaled DFT matrix is unitary: its inverse is its conjugate
    # transpose.
    delta = dft.conj().T @ dftz
    # The projected vector can have an imaginary part; keep the real part.
    delta = jnp.real(delta)
  return delta.reshape(shape)
Exemplo n.º 8
def _internal_bi_tempered_logistic_loss(activations, labels, t1, t2):
  """Computes the Bi-Tempered logistic loss.

    activations: A multi-dimensional array with last dimension `num_classes`.
    labels: batch_size
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness).

    A loss array for robust loss.
  normalization_constants = compute_normalization(activations, t2, num_iters=5)
  if t2 == 1.0:
    if t1 == 1.0:
      return normalization_constants + jnp.sum(
                       jnp.log(labels + 1e-10) - activations), -1)
      shifted_activations = jnp.exp(activations - normalization_constants)
      one_minus_t1 = (1.0 - t1)
      one_minus_t2 = 1.0
    one_minus_t1 = (1.0 - t1)
    one_minus_t2 = (1.0 - t2)
    shifted_activations = jnp.maximum(
        1.0 + one_minus_t2 * (activations - normalization_constants), 0.0)

  if t1 == 1.0:
    return jnp.sum(
            jnp.log(labels + 1e-10) -
            jnp.log(jnp.power(shifted_activations, 1.0 / one_minus_t2)),
            labels), -1)
    beta = 1.0 + one_minus_t1
    logt_probs = (jnp.power(shifted_activations, one_minus_t1 / one_minus_t2) -
                  1.0) / one_minus_t1
    return jnp.sum(
        jnp.multiply(log_t(labels, t1) - logt_probs, labels) - 1.0 / beta *
        (jnp.power(labels, beta) -
         jnp.power(shifted_activations, beta / one_minus_t2)), -1)
Exemplo n.º 9
def sample_bounds(key: jnp.ndarray,
                  shape: Tuple[int, ...],
                  minval: float = -2.,
                  maxval: float = 2.) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Sample some bounds of the required shape.

  Two independent uniform samples are drawn between `minval` and `maxval` and
  ordered elementwise, so that ``lb <= ub`` holds everywhere.

  Args:
    key: Random number generator.
    shape: Shape of the bounds to generate.
    minval: Optional, smallest value that the bounds could take.
    maxval: Optional, largest value that the bounds could take.

  Returns:
    lb, ub: Lower and upper bound tensors of the desired shape.
  """
  key_0, key_1 = jax.random.split(key)
  bound_1 = jax.random.uniform(key_0, shape, minval=minval, maxval=maxval)
  bound_2 = jax.random.uniform(key_1, shape, minval=minval, maxval=maxval)
  # Sorting the two samples elementwise guarantees a valid interval.
  lb = jnp.minimum(bound_1, bound_2)
  ub = jnp.maximum(bound_1, bound_2)
  return lb, ub
Exemplo n.º 10
def _clip_by_l2_norm(x: Array, max_norm: float) -> Array:
    """Clip gradients to maximum l2 norm `max_norm`."""
    # Compute the sum of squares and find out where things are zero.
    sum_sq = jnp.sum(jnp.vdot(x, x))
    nonzero = sum_sq > 0

    # Compute the norm wherever sum_sq > 0 and leave it <= 0 otherwise. This makes
    # use of the the "double where" trick; see
    # https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
    # for more info. In short this is necessary because although norm ends up
    # computed correctly where nonzero is true if we ignored this we'd end up with
    # nans on the off-branches which would leak through when computed gradients in
    # the backward pass.
    sum_sq_ones = jnp.where(nonzero, sum_sq, jnp.ones_like(sum_sq))
    norm = jnp.where(nonzero, jnp.sqrt(sum_sq_ones), sum_sq)

    # Normalize by max_norm. Whenever norm < max_norm we're left with x (this
    # happens trivially for indices where nonzero is false). Otherwise we're left
    # with the desired x * max_norm / norm.
    return (x * max_norm) / jnp.maximum(norm, max_norm)
Exemplo n.º 11
        def _dynamics(state, action):
            """One transition of the (position, velocity) dynamics.

            Increments ``self.nsamples`` on every call and returns the next
            state packed into an array of shape (2,).
            """
            self.nsamples += 1
            position, velocity = state[0], state[1]

            # Saturate the action to [self.min_action, self.max_action].
            lower_clamped = jnp.maximum(action, self.min_action)
            force = jnp.minimum(lower_clamped, self.max_action)

            velocity += force * self.power - 0.0025 * jnp.cos(3 * position)
            velocity = jnp.clip(velocity, -self.max_speed, self.max_speed)

            position += velocity
            position = jnp.clip(position, self.min_position, self.max_position)

            # Zero the velocity when pinned at the lower position limit while
            # still moving towards it.
            pinned_at_min = (position == self.min_position) & (velocity < 0)
            velocity = jax.lax.cond(
                pinned_at_min[0],
                lambda _: jnp.zeros((1,)),
                lambda v: v,
                velocity,
            )
            return jnp.reshape(jnp.array([position, velocity]), (2,))
Exemplo n.º 12
def convex_fn_relaxation(primitive: bound_propagation.Primitive, inp: Bound,
                         **params) -> Tuple[TensorFunction, TensorFunction]:
    """Relaxation of an element-wise convex primitive.

    The lower bound function is the primitive itself. The upper bound function
    is the chord joining the primitive's values at ``inp.lower`` and
    ``inp.upper``.

    Args:
      primitive: Convex primitive to relax.
      inp: Bounds on the input.
      **params: Params of the primitive operation, mainly the jaxpr defining it.

    Returns:
      lb_fun, ub_fun: functions computing the lower and upper bound of the
        primitive's output.
    """
    prim_fun = functools.partial(primitive.bind, **params)
    x_lb, x_ub = inp.lower, inp.upper
    y_lb, y_ub = prim_fun(x_lb), prim_fun(x_ub)

    # Clamp the interval width so degenerate intervals (x_lb == x_ub) do not
    # divide by zero.
    chord_slope_safe_denom = jnp.maximum(x_ub - x_lb, 1e-12)
    chord_slope = (y_ub - y_lb) / chord_slope_safe_denom
    chord_intercept = y_lb - chord_slope * x_lb
    chord_fun = lambda x: chord_slope * x + chord_intercept
    return prim_fun, chord_fun
Exemplo n.º 13
    def cost(self, p, extra=None, precomputed=None):
        """Negative log-likelihood (Poisson spiking model).

        Args:
          p: parameter dict; ``p['w']`` enters the elastic-net penalty.
          extra: optional dict of alternative data; when given, ``extra['y']``
            replaces ``self.y`` and the penalty is not added.
          precomputed: optional precomputed output of ``self.forward_pass``;
            when given, the forward pass is skipped.

        Returns:
          The negative log-likelihood, plus the penalty when ``self.beta`` is
          truthy and ``extra`` is None.
        """
        y = self.y if extra is None else extra['y']
        r = self.forward_pass(p, extra) if precomputed is None else precomputed
        r = np.maximum(r, 1e-20)  # clamp away from zero so the log stays finite.
        dt = self.dt

        term0 = -np.log(r / dt) @ y  # spike term from poisson log-likelihood
        term1 = np.sum(r)  # non-spike term

        neglogli = term0 + term1

        if self.beta and extra is None:
            # Elastic-net penalty mixing the L2 and L1 norms of the weights.
            l1 = np.linalg.norm(p['w'], 1)
            l2 = np.linalg.norm(p['w'], 2)
            neglogli += self.beta * ((1 - self.alpha) * l2 + self.alpha * l1)

        return neglogli
Exemplo n.º 14
def l2_normalize(arr, axis, epsilon=1e-12):
    """L2 normalize along a particular axis.

    Doc taken from tf.nn.l2_normalize:

    For a 1-D array with axis = 0, computes

        output = x / sqrt(max(sum(x**2), epsilon))

    For arrays with more dimensions, each 1-D slice along `axis` is normalized
    independently.

    Args:
      arr: array to normalize.
      axis: axis along which to normalize.
      epsilon: lower bound on the squared norm, to avoid dividing by zero.

    Returns:
      An array with the same shape as `arr`.
    """
    sq_arr = np.power(arr, 2)
    square_sum = np.sum(sq_arr, axis=axis, keepdims=True)
    max_weights = np.maximum(square_sum, epsilon)
    return np.divide(arr, np.sqrt(max_weights))
Exemplo n.º 15
def log(q, eps=1e-8):
    """Computes the quaternion logarithm.

    For q with vector part v and scalar part s:
        log(q) = (v / |v| * arccos(s / |q|), log|q|)

    Args:
      q: the quaternion in (x,y,z,w) format.
      eps: an epsilon value for numerical stability.

    Returns:
      The logarithm of q, with the vector part first and log|q| last.
    """
    mag = linalg.norm(q, axis=-1, keepdims=True)
    v = im(q)
    s = re(q)
    w = jnp.log(mag)
    denom = jnp.maximum(linalg.norm(v, axis=-1, keepdims=True), eps * jnp.ones_like(v))
    # The angle comes from the scalar part normalized by |q|. |q| is clamped by
    # eps so a zero quaternion does not divide by zero.
    xyz = v / denom * safe_acos(s / jnp.maximum(mag, eps))
    return jnp.concatenate((xyz, w), axis=-1)
Exemplo n.º 16
def solve_vfi(money, EV_list, umult, kf, km, sigma, beta, i, wn, wt, sgrid,

    #ts = 0.01
    # dim is (na,ns,nexo,ntheta)

    EV_stretch_list = [(wt[:,None,None]*x[i,:,:] + \
                        for x in EV_list]

    consumption = money[:, None, :] - sgrid[None, :, None]
    consumption_negative = (consumption <= 0)
    uc = (np.maximum(consumption, 1e-8))**(1 - sigma) / (1 - sigma)
    utility = umult[None,None,None,:]*(uc[:,:,:,None]) - \

    EVs, EVFs, EVMs = EV_stretch_list

    mega_matrix = utility + beta * EVs[None, :, :, :]

    ind_s = mega_matrix.argmax(axis=1)
    V = np.take_along_axis(mega_matrix,ind_s[:,None,:,:],1)\
                                                .squeeze(axis=1) + psi

    s = sgrid[ind_s]
    c = money[:, :, None] - s

    V_check = umult[None,None,:]*(c**(1-sigma)/(1-sigma)) + \
                            psi + beta*np.take_along_axis(EVs,ind_s,0)

    VF = ((kf[None,None,:]*c)**(1-sigma)/(1-sigma)) + \
                            psi + beta*np.take_along_axis(EVFs,ind_s,0)
    VM = ((km[None,None,:]*c)**(1-sigma)/(1-sigma)) + \
                            psi + beta*np.take_along_axis(EVMs,ind_s,0)

    assert np.allclose(V_check, V, atol=1e-5)

    return V, VF, VM, s
Exemplo n.º 17
def mean_and_var(
    x: Optional[np.ndarray],
    axis: Optional[Axes] = None,
    dtype: Optional[np.dtype] = None,
    out: Optional[None] = None,
    ddof: int = 0,
    keepdims: bool = False,
    mask: Optional[np.ndarray] = None,
    get_var: bool = False
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """`np.mean` and `np.var` taking the `mask` information into account.

    Args:
      x: input array, or None (then `(None, None)` is returned).
      axis: axis or axes to reduce over.
      dtype: forwarded to `np.mean` / `np.var` when `mask` is None; unused
        otherwise.
      out: forwarded to `np.mean` / `np.var` when `mask` is None.
      ddof: delta degrees of freedom used for the variance.
      keepdims: whether to keep the reduced axes with size 1.
      mask: optional array broadcastable to `x.shape`; nonzero entries mark
        elements excluded from the statistics.
      get_var: whether to also compute the variance.

    Returns:
      `(mean, var)`, where `var` is None unless `get_var` is True.
    """
    var = None
    if x is None:
        return x, var

    if mask is None:
        mean = np.mean(x, axis, dtype, out, keepdims)
        if get_var:
            var = np.var(x, axis, dtype, out, ddof, keepdims)
    else:
        axis = tuple(utils.canonicalize_axis(axis, x))
        size = utils.size_at(x, axis)
        mask = np.broadcast_to(mask, x.shape)
        mask_size = np.count_nonzero(mask, axis)
        for i in axis:
            mask_size = np.expand_dims(mask_size, i)
        # Number of unmasked elements along `axis`, clamped to avoid 0-division.
        size = np.maximum(size - mask_size, 1)

        # Exclude masked-out entries from the sums explicitly, so the result
        # does not depend on the caller having zeroed them.
        x = np.where(mask, 0, x)
        mean = np.sum(x, axis=axis, keepdims=True) / size

        if get_var:
            # Use the keepdims mean: a squeezed mean would not broadcast
            # against `x` along the reduced axes.
            sq_dev = np.where(mask, 0, (x - mean)**2)
            var = np.sum(sq_dev, axis=axis, keepdims=True) / (size - ddof)
            if not keepdims:
                var = np.squeeze(var, axis)

        if not keepdims:
            mean = np.squeeze(mean, axis)

    return mean, var
Exemplo n.º 18
    def __init__(self,
                 max_range=(-100.0, 100.0)):
        self._precision = precision

        # Some gym envs (e.g. CartPole) have unreasonably high bounds for
        # observations. We clip so we can represent them.
        bounded_space = copy.copy(space)
        (min_low, max_high) = max_range
        bounded_space.low = np.maximum(space.low, min_low)
        bounded_space.high = np.minimum(space.high, max_high)
        if (not np.allclose(bounded_space.low, space.low)
                or not np.allclose(bounded_space.high, space.high)):
                'Space limits %s, %s out of bounds %s. Clipping to %s, %s.',
                str(space.low), str(space.high), str(max_range),
                str(bounded_space.low), str(bounded_space.high))

        super().__init__(bounded_space, vocab_size)
Exemplo n.º 19
def lerp_weight(x, xs):
    """Linear interpolation weight from a sample at x to xs.

    Returns the linear interpolation weight of a "query point" at coordinate `x`
    with respect to a "sample" at coordinate `xs`.

    The integer coordinates `x` are at pixel centers.
    The floating point coordinates `xs` are at pixel edges.
    (OpenGL convention).

    Args:
      x: "Query" point position.
      xs: "Sample" position.

    Returns:
      The tent weight ``max(1 - |x - xs|, 0)``:
      - 1 when x = xs.
      - 0 when |x - xs| >= 1.
    """
    dx = x - xs
    abs_dx = abs(dx)
    return jnp.maximum(1.0 - abs_dx, 0.0)
Exemplo n.º 20
def initial_step_size(fun, t0, y0, order, rtol, atol, f0):
  """Heuristically choose the first step size of an adaptive ODE solver.

  Algorithm from E. Hairer, S. P. Norsett, G. Wanner, "Solving Ordinary
  Differential Equations I: Nonstiff Problems", Sec. II.4.

  Args:
    fun: the ODE right-hand side, called here as ``fun(y, t)``.
    t0: initial time.
    y0: initial state.
    order: order of the integration method.
    rtol: relative error tolerance.
    atol: absolute error tolerance.
    f0: derivative at the initial point (presumably ``fun(y0, t0)``).

  Returns:
    The proposed initial step size.
  """
  y0, f0 = _promote_dtypes_inexact(y0, f0)
  dtype = y0.dtype

  # Per-component error scale, cast to the state dtype for every division.
  err_scale = (atol + jnp.abs(y0) * rtol).astype(dtype)
  norm_y0 = jnp.linalg.norm(y0 / err_scale)
  norm_f0 = jnp.linalg.norm(f0 / err_scale)

  # First guess; fall back to a tiny step when either norm is negligible.
  too_small = (norm_y0 < 1e-5) | (norm_f0 < 1e-5)
  h_first = jnp.where(too_small, 1e-6, 0.01 * norm_y0 / norm_f0)

  # One explicit Euler probe to estimate the size of the second derivative.
  y_probe = y0 + h_first.astype(dtype) * f0
  f_probe = fun(y_probe, t0 + h_first)
  norm_second = jnp.linalg.norm((f_probe - f0) / err_scale) / h_first

  # NOTE(review): Hairer et al. use max(d1, d2) here. jnp.max(d1 + d2) reduces
  # the scalar sum, which yields a somewhat smaller step. Confirm that this is
  # intended.
  both_negligible = (norm_f0 <= 1e-15) & (norm_second <= 1e-15)
  h_fallback = jnp.maximum(1e-6, h_first * 1e-3)
  h_from_error = (0.01 / jnp.max(norm_f0 + norm_second)) ** (1. / (order + 1.))
  h_second = jnp.where(both_negligible, h_fallback, h_from_error)

  return jnp.minimum(100. * h_first, h_second)
Exemplo n.º 21
    def log_factor_k(cluster_id, log_maha_k, num_k, logdetC_k):
        Computes f_k such that,
            u_k @ f0_k n_k C_k @ u_k <= 1
            f_k^d V(n_k C_k) = max(V(S_k), V(f0_k n_k C_k))
            log_f_k = (log max(V(S)*n_k/n_S, V(f0_k n_k C_k)) - log V(n_k C_k))/D
            log_f_k = (max(log(V(S)*n_k/n_S), logV(n_k C_k)) - log V(n_k C_k))/D
        # K
        log_f_expand_k = -jnp.max(jnp.where(cluster_id == a_k[:, None],
                                            log_maha_k, -jnp.inf),
        log_VE_expand_k = log_ellipsoid_volume(logdetC_k, num_k,
        log_VE_k = log_ellipsoid_volume(logdetC_k, num_k, 0.)

        log_scale_k = (jnp.maximum(log_VS + jnp.log(num_k) - jnp.log(num_S),
                                   log_VE_expand_k) - log_VE_k) / D
        # K
        return log_scale_k
Exemplo n.º 22
Arquivo: mip.py Projeto: wx-b/mipnerf
def lift_gaussian(d, t_mean, t_var, r_var, diag):
    """Lift a Gaussian defined along a ray to 3D coordinates.

    Args:
      d: ray directions; the last axis holds the coordinates.
      t_mean: mean distance along the ray of each Gaussian.
      t_var: variance along the ray direction.
      r_var: variance perpendicular to the ray direction.
      diag: if True, only the diagonal of the covariance is computed.

    Returns:
      ``(mean, cov_diag)`` when `diag` is True, otherwise ``(mean, cov)`` with
      full covariance matrices over the last two axes.
    """
    mean = d[..., None, :] * t_mean[..., None]

    # Clamp |d|^2 away from zero before dividing by it.
    d_mag_sq = jnp.maximum(1e-10, jnp.sum(d**2, axis=-1, keepdims=True))

    if diag:
        d_outer_diag = d**2
        null_outer_diag = 1 - d_outer_diag / d_mag_sq
        t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :]
        xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :]
        cov_diag = t_cov_diag + xy_cov_diag
        return mean, cov_diag
    else:
        d_outer = d[..., :, None] * d[..., None, :]
        eye = jnp.eye(d.shape[-1])
        # Projector onto the plane orthogonal to d.
        null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
        t_cov = t_var[..., None, None] * d_outer[..., None, :, :]
        xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]
        cov = t_cov + xy_cov
        return mean, cov
Exemplo n.º 23
def predict_cnn(params, inputs, include_preactivations=False):
    """Forward pass for a CNN given parameters.

    Args:
      params: Parameters for the CNN, one entry per layer. See make_cnn_params
        for syntax.
      inputs: Inputs to CNN.
      include_preactivations: bool. If True, also return the pre-activation of
        every layer, i.e. the output of `fwd` before the ReLU.

    Returns:
      act: Output from forward pass through CNN.
      (Optional) layer_preacts: list with the pre-activation of each layer,
        returned only when `include_preactivations` is True.
    """
    act = inputs
    layer_preacts = []
    last_layer = len(params) - 1
    for counter, layer_params in enumerate(params):
        act = fwd(act, layer_params)
        if include_preactivations:
            layer_preacts.append(act)
        if counter < last_layer:
            # no relu on final layer
            act = jnp.maximum(act, 0)
    if include_preactivations:
        return act, layer_preacts
    return act
Exemplo n.º 24
def dists_to_samples(rays, t):
    """Convert mipnerf frustums to gaussians."""
    t_mids = .5 * (t[Ellipsis, 1:] + t[Ellipsis, :-1])
    mean = rays[0][Ellipsis,
                   None, :] + rays[1][Ellipsis, None, :] * t_mids[Ellipsis,

    d = rays[1]
    d_mag_sq = np.maximum(1e-10, np.sum(d**2, axis=-1, keepdims=True))
    t_half = .5 * (t[Ellipsis, 1:] - t[Ellipsis, :-1])
    t_var = t_half**2 / 3.
    r_var = (rays[2] * t_mids)**2 / 12.

    d_outer = d[Ellipsis, :, None] * d[Ellipsis, None, :]
    eye = np.eye(d.shape[-1])
    null_outer = eye - d[Ellipsis, :, None] * (d / d_mag_sq)[Ellipsis, None, :]
    t_cov = t_var[Ellipsis, None, None] * d_outer[Ellipsis, None, :, :]
    xy_cov = r_var[Ellipsis, None, None] * null_outer[Ellipsis, None, :, :]
    cov = t_cov + xy_cov

    return mean, cov
Exemplo n.º 25
    def unconstrained_proposal(self, rng_key, x, grad_, hess_):
        """Sample a Newton-step proposal centred at ``x - inv(hess_) @ grad_``.

        Args:
          rng_key: random key used to draw the sample.
          x: current point, either a scalar or an array.
          grad_: gradient at `x` (presumably of the log-target; confirm).
          hess_: Hessian at `x`; a scalar when `x` is a scalar.

        Returns:
          A tuple ``(sample, distribution)``, where `sample` is reshaped to
          ``x.shape``.
        """
        ndim = np.ndim(x)
        if ndim == 0:
            inv_hess = 1 / hess_
            dist_type = dist.Normal
        else:
            inv_hess = np.linalg.inv(hess_)
            dist_type = dist.MultivariateNormal

        loc = x - np.dot(inv_hess, grad_)
        sigma = -inv_hess
        # NOTE(review): in the scalar case `sigma` is a variance, yet it is
        # passed as the second argument of dist.Normal, which in common
        # libraries (e.g. NumPyro) is the standard deviation. Confirm this.

        # Reconstruct sigma if not positive definite
        if ndim != 0 and not np.all(np.linalg.eigvals(sigma) > 0):
            lam, vec = np.linalg.eigh(sigma)
            sigma = vec @ np.diag(np.maximum(
                lam, UNCONSTRAINED_RECONSTRUCTION)) @ vec.T

        dist_ = dist_type(loc, sigma + MU_CORRECTION)

        return dist_.sample(rng_key).reshape(x.shape), dist_
Exemplo n.º 26
 def from_params(
         traceable=True):  # FIXME: traceable; why sometimes no Scale?
     if not scale:
         scale = Scale(0.0, 1.0)
     floor = fixed_params.get("floor", -np.inf)
     ceiling = fixed_params.get("ceiling", np.inf)
     # Allow logistic center to exceed the range by 20%
     loc_min = np.maximum(scale.low, floor) - 0.2 * scale.width
     loc_max = np.minimum(scale.high, ceiling) + 0.2 * scale.width
     loc_range = loc_max - loc_min
     structured_params = opt_params.reshape((-1, 3))
     locs = loc_min + scipy.special.expit(structured_params[:,
                                                            0]) * loc_range
     # Allow logistic scales between 0.01 and 0.5
     # Don't allow tiny scales outside of the visible range
     s_min = 0.01 + 0.1 * np.where(
         (locs < scale.low),
         scale.low - locs,
         np.where(locs > scale.high, locs - scale.high, 0.0),
     s_max = 0.5
     s_range = s_max - s_min
     ss = s_min + scipy.special.expit(structured_params[:, 1]) * s_range
     # Allow probs > 0.01
     probs = list(0.01 + nn.softmax(structured_params[:, 2]) *
                  (1 - 0.01 * structured_params[:, 2].size))
     # Bundle up components
     component_logistics = [
         Logistic(l, s, scale, normalized=True) for (l, s) in zip(locs, ss)
     components = [
         Truncate(base_dist=cl, floor=floor, ceiling=ceiling)
         for cl in component_logistics
     mixture = cls(components=components, probs=probs)
     return mixture
Exemplo n.º 27
    def light(self,
        Basic light model using a only diffuse lighting
        rayhit = origin + direction * intersection
        normal = ((rayhit - self.center) * (1. / self.radius))
        direction_to_light = (light_position - rayhit).norm()
        direction_to_eye = (eye_position - rayhit).norm()
        nudged = rayhit + normal * 0.001  # To avoid shadow acne

        # Create shadow mask
        light_distances = [
            o.intersect(nudged, direction_to_light, far=far)
            for o in scene_objects
        light_nearest = reduce(jnp.minimum, light_distances)
        light_mask = light_distances[scene_objects.index(
            self)] == light_nearest

        # Ambient light
        color = Vec3(0.05, 0.05, 0.05)

        # Lambert shading (diffuse)
        light_hit = jnp.maximum(normal.dot(direction_to_light), 0)
        color += self.diffusecolor(rayhit) * light_hit * light_mask

        # Phong light
        phong = normal.dot((direction_to_light + direction_to_eye).norm())
        color += Vec3(1., 1., 1.) * jnp.power(jnp.clip(phong, 0, 1),
                                              50) * light_mask

        return color
Exemplo n.º 28
def segment_max(data,
    """Computes the max within segments of an array.

  Similar to TensorFlow's segment_max:

    data: an array with the values to be maxed over.
    segment_ids: an array with integer dtype that indicates the segments of
      `data` (along its leading axis) to be maxed over. Values can be repeated
      and need not be sorted. Values outside of the range [0, num_segments) are
      wrapped into that range by applying jnp.mod.
    num_segments: optional, an int with positive value indicating the number of
      segments. The default is ``jnp.maximum(jnp.max(segment_ids) + 1,
      jnp.max(-segment_ids))`` but since `num_segments` determines the size of
      the output, a static value must be provided to use ``segment_max`` in a
      ``jit``-compiled function.
    indices_are_sorted: whether ``segment_ids`` is known to be sorted
    unique_indices: whether ``segment_ids`` is known to be free of duplicates

    An array with shape ``(num_segments,) + data.shape[1:]`` representing
    the segment maxs.
    if num_segments is None:
        num_segments = jnp.maximum(
            jnp.max(segment_ids) + 1, jnp.max(-segment_ids))
    num_segments = int(num_segments)

    min_value = dtype_min_value(data.dtype)
    out = jnp.full((num_segments, ) + data.shape[1:],
    segment_ids = jnp.mod(segment_ids, num_segments)
    return jax.ops.index_max(out, segment_ids, data, indices_are_sorted,
Exemplo n.º 29
    def compute_loss(self, predictions: NestedMap,
                     input_batch: NestedMap) -> Tuple[Metrics, Dict[str, Any]]:
        """Computes the loss and other metrics for the given predictions.

      predictions: The output of `compute_predictions`.
      input_batch: A `.NestedMap` object containing input tensors to this tower.

      - A dict or NestedMap containing str keys and (metric, weight) pairs as
        values, where one of the entries is expected to corresponds to the loss.
      - A dict containing arbitrary tensors describing something about each
        training example, where the first dimension of each tensor is the batch
        labels = input_batch.labels
        num_tokens = jnp.sum(1.0 - input_batch.paddings.astype(jnp.float32))
        num_seqs = jnp.sum(
            jnp.amax(input_batch.segment_ids.astype(jnp.float32), axis=1))
        weights = predictions.augmented_pos.astype(jnp.float32)
        predicted_labels = predictions.per_example_argmax.astype(labels.dtype)
        num_preds = predictions.total_weight.astype(jnp.float32)
        mean_acc = jnp.sum(
            (labels == predicted_labels) * weights) / jnp.maximum(
                num_preds, 1)
        metric_weight = jnp.array(num_preds, predictions.avg_xent.dtype)
        metrics = py_utils.NestedMap(
            total_loss=(predictions.total_loss, metric_weight),
            avg_xent=(predictions.avg_xent, metric_weight),
            aux_loss=(predictions.aux_loss, metric_weight),
            log_pplx=(predictions.avg_xent, metric_weight),
                                       jnp.array(num_preds, mean_acc.dtype)),
            num_predictions=(num_preds, jnp.array(1.0, num_preds.dtype)),
            num_tokens=(num_tokens, jnp.array(1.0, num_tokens.dtype)),
            num_seqs=(num_seqs, jnp.array(1.0, num_seqs.dtype)),

        per_example_output = py_utils.NestedMap()
        return metrics, per_example_output
Exemplo n.º 30
    def test_clipping_norm(self, l2_norm_clip):
        dp_agg = privacy.differentially_private_aggregate(
            l2_norm_clip=l2_norm_clip, noise_multiplier=0., seed=42)
        state = dp_agg.init(self.params)
        update_fn = self.variant(dp_agg.update)

        # Shape of the three arrays below is (self.batch_size, )
        norms = [
            jnp.linalg.norm(g.reshape(self.batch_size, -1), axis=1)
            for g in jax.tree_leaves(self.per_eg_grads)
        global_norms = jnp.linalg.norm(jnp.stack(norms), axis=0)
        divisors = jnp.maximum(global_norms / l2_norm_clip, 1.)
        # Since the values of all the parameters are the same within each example,
        # we can easily compute what the values should be:
        expected_val = jnp.mean(jnp.arange(self.batch_size) / divisors)
        expected_tree = jax.tree_map(
            lambda p: jnp.broadcast_to(expected_val, p.shape), self.params)

        for _ in range(3):
            updates, state = update_fn(self.per_eg_grads, state, self.params)
            chex.assert_tree_all_close(updates, expected_tree)