def value_loss_given_predictions(value_prediction,
                                 rewards,
                                 reward_mask,
                                 gamma=0.99,
                                 epsilon=0.2,
                                 value_prediction_old=None):
  """Computes the value loss given the prediction of the value function.

  Args:
    value_prediction: np.ndarray of shape (B, T+1, 1)
    rewards: np.ndarray of shape (B, T) of rewards.
    reward_mask: np.ndarray of shape (B, T), the mask over rewards.
    gamma: float, discount factor.
    epsilon: float, clip-fraction, used if `value_prediction_old` isn't None.
    value_prediction_old: np.ndarray of shape (B, T+1, 1) of value predictions
      using the old parameters. If provided, we incorporate this in the loss
      as well. This is from the OpenAI baselines implementation.

  Returns:
    The average L2 value loss, averaged over instances where reward_mask is 1.
  """
  B, T = rewards.shape  # pylint: disable=invalid-name
  assert (B, T) == reward_mask.shape
  assert (B, T + 1, 1) == value_prediction.shape

  value_prediction = np.squeeze(value_prediction, axis=2)  # (B, T+1)
  value_prediction = value_prediction[:, :-1] * reward_mask  # (B, T)
  r2g = rewards_to_go(rewards, reward_mask, gamma=gamma)  # (B, T)
  loss = (value_prediction - r2g)**2

  # From the baselines implementation.
  if value_prediction_old is not None:
    value_prediction_old = np.squeeze(value_prediction_old, axis=2)  # (B, T+1)
    value_prediction_old = value_prediction_old[:, :-1] * reward_mask  # (B, T)

    v_clipped = value_prediction_old + np.clip(
        value_prediction - value_prediction_old, -epsilon, epsilon)
    v_clipped_loss = (v_clipped - r2g)**2
    loss = np.maximum(v_clipped_loss, loss)

  # Take an average on only the points where mask != 0.
  return np.sum(loss) / np.sum(reward_mask)
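
# A minimal worked example of the PPO-style value clipping above, on scalars
# (illustrative only; the hyperparameters are the function defaults): the
# clipped prediction may move at most `epsilon` away from the old value, and
# the loss takes the elementwise max of clipped and unclipped squared errors.
import numpy as np

v_old, v_new, target, epsilon = 1.0, 2.0, 1.2, 0.2
v_clipped = v_old + np.clip(v_new - v_old, -epsilon, epsilon)  # 1.2
loss = max((v_clipped - target)**2, (v_new - target)**2)       # 0.64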
def update_fn(updates, state, params=None):
  del params
  grads_flat, grads_treedef = jax.tree_flatten(updates)
  bsize = grads_flat[0].shape[0]

  if any(g.ndim == 0 or bsize != g.shape[0] for g in grads_flat):
    raise ValueError(
        'Unlike other transforms, `differentially_private_aggregate` expects'
        ' `updates` to have a batch dimension in the 0th axis. That is, this'
        ' function expects per-example gradients as input.')

  new_key, *rngs = jax.random.split(state.rng_key, len(grads_flat) + 1)
  global_grad_norms = jax.vmap(utils.global_norm)(grads_flat)
  divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0)
  clipped = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1) for g in grads_flat]
  noised = [(g + noise_std * jax.random.normal(r, g.shape, g.dtype)) / bsize
            for g, r in zip(clipped, rngs)]
  return (jax.tree_unflatten(grads_treedef, noised),
          DifferentiallyPrivateAggregateState(rng_key=new_key))
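
# Per-example gradients of the kind this transform expects are typically
# produced by vmap-ing grad over the batch axis. A sketch (`loss_fn`, `params`
# and `batch` are hypothetical placeholders, not part of the code above):
#
#   per_eg_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0))(params, batch)
#
# Each leaf of `per_eg_grads` then carries the batch dimension in axis 0, as
# the check at the top of `update_fn` requires.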
def optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0,
                      dfactor=0.2, order=5.0):
  """Compute optimal Runge-Kutta stepsize."""
  mean_error_ratio = np.max(mean_error_ratio)
  dfactor = np.where(mean_error_ratio < 1, 1.0, dfactor)

  err_ratio = np.sqrt(mean_error_ratio)
  factor = np.maximum(1.0 / ifactor,
                      np.minimum(err_ratio**(1.0 / order) / safety,
                                 1.0 / dfactor))
  return np.where(mean_error_ratio == 0,
                  last_step * ifactor,
                  last_step / factor)
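
# A worked example of the controller (a sketch using the defaults above; `np`
# is the module's numpy-compatible alias):
#   optimal_step_size(0.1, 4.0)   # error too large -> step shrinks (~0.078)
#   optimal_step_size(0.1, 0.25)  # error small -> step grows slightly (~0.103)
#   optimal_step_size(0.1, 0.0)   # zero error -> last_step * ifactor = 1.0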
def assert_close(expected, actual):
  self.assertEqual(expected.shape, actual.shape)
  relative_error = (np.linalg.norm(actual - expected) /
                    np.maximum(np.linalg.norm(expected), 1e-12))
  absolute_error = np.mean(np.abs(actual - expected))
  if (np.isnan(relative_error) or relative_error > rtol or
      absolute_error > atol):
    _log(relative_error, absolute_error, expected, actual, False)
    self.fail(self.failureException('Relative ERROR: ',
                                    float(relative_error),
                                    'EXPECTED:' + ' ' * 50, expected,
                                    'ACTUAL:' + ' ' * 50, actual,
                                    ' ' * 50,
                                    'Absolute ERROR: ',
                                    float(absolute_error)))
  else:
    _log(relative_error, absolute_error, expected, actual, True)
def optimize(state, grad, warmup=config.optim.warmup,
             grad_clip=config.optim.grad_clip):
  """Optimizes with warmup and gradient clipping (disabled if negative)."""
  lr = state.lr
  if warmup > 0:
    lr = lr * jnp.minimum(state.step / warmup, 1.0)
  if grad_clip >= 0:
    # Compute the global gradient norm.
    grad_norm = jnp.sqrt(
        sum([jnp.sum(jnp.square(x)) for x in jax.tree_leaves(grad)]))
    # Clip the gradient.
    clipped_grad = jax.tree_map(
        lambda x: x * grad_clip / jnp.maximum(grad_norm, grad_clip), grad)
  else:  # Disable gradient clipping if grad_clip < 0.
    clipped_grad = grad
  return state.optimizer.apply_gradient(clipped_grad, learning_rate=lr)
def test_segment_max_negatives(self, indices_are_sorted, unique_indices):
  neg_inf = jnp.iinfo(jnp.int32).min
  if unique_indices:
    data = -1 - jnp.arange(6)  # [-1, -2, -3, -4, -5, -6]
    if indices_are_sorted:
      segment_ids = jnp.array([0, 1, 2, 3, 4, 5])
      expected_out = jnp.array([-1, -2, -3, -4, -5, -6])
      num_segments = 6
    else:
      segment_ids = jnp.array([1, 0, 2, 4, 3, -5])
      expected_out = jnp.array([-2, -1, -3, -5, -4])
      num_segments = 5
  else:
    data = -1 - jnp.arange(9)  # [-1, -2, -3, -4, -5, -6, -7, -8, -9]
    if indices_are_sorted:
      segment_ids = jnp.array([0, 0, 0, 1, 1, 1, 2, 3, 4])
      expected_out = jnp.array([-1, -4, -7, -8, -9, neg_inf])
    else:
      segment_ids = jnp.array([0, 1, 2, 0, 4, 0, 1, 1, -6])
      expected_out = jnp.array([-1, -2, -3, neg_inf, -5, neg_inf])
    # Both non-unique cases use six segments.
    num_segments = 6

  with self.subTest('nojit'):
    result = utils.segment_max(data, segment_ids, num_segments,
                               indices_are_sorted, unique_indices)
    self.assertAllClose(result, expected_out, check_dtypes=True)

    result = utils.segment_max(data, segment_ids,
                               indices_are_sorted=indices_are_sorted,
                               unique_indices=unique_indices)
    num_unique_segments = jnp.maximum(jnp.max(segment_ids) + 1,
                                      jnp.max(-segment_ids))
    self.assertAllClose(result, expected_out[:num_unique_segments],
                        check_dtypes=True)

  with self.subTest('jit'):
    result = jax.jit(utils.segment_max, static_argnums=(2, 3, 4))(
        data, segment_ids, num_segments, indices_are_sorted, unique_indices)
    self.assertAllClose(result, expected_out, check_dtypes=True)
def norm_projection(delta, norm_type, eps=1.):
  """Projects to a norm-ball centered at 0.

  Args:
    delta: An array of size dim x num containing vectors to be projected.
    norm_type: A string denoting the type of the norm-ball.
    eps: A float denoting the radius of the norm-ball.

  Returns:
    An array of size dim x num, the projection of delta to the norm-ball.
  """
  shape = delta.shape
  if len(delta.shape) == 1:
    delta = delta.reshape(-1, 1)
  if norm_type == 'linf':
    delta = jnp.clip(delta, -eps, eps)
  elif norm_type == 'l2':
    # Euclidean projection: divide all elements by a constant factor.
    avoid_zero_div = 1e-12
    norm2 = jnp.sum(delta**2, axis=0, keepdims=True)
    norm = jnp.sqrt(jnp.maximum(avoid_zero_div, norm2))
    # Only decrease the norm, never increase.
    delta = delta * jnp.clip(eps / norm, a_min=None, a_max=1)
  elif norm_type == 'l1':
    delta = l1_unit_projection(delta / eps) * eps
  elif norm_type == 'dftinf':
    # Transform to DFT, project using known projections, then transform back.
    # dft = np.matrix(scipy.linalg.dft(delta.shape[0]) / np.sqrt(delta.shape[0]))
    dft = np.matrix(scipy.linalg.dft(delta.shape[0], scale='sqrtn'))
    dftxdelta = dft @ delta
    # dftxdelta = np.matrix(scipy.fft.fft(delta, axis=0, norm='ortho'))
    # L2 projection of each coordinate to the L2-ball in the complex plane.
    dftz = dftxdelta.reshape(1, -1)
    dftz = jnp.concatenate((jnp.real(dftz), jnp.imag(dftz)), axis=0)
    dftz = norm_projection(dftz, 'l2', eps)
    dftz = (dftz[0, :] + 1j * dftz[1, :]).reshape(delta.shape)
    # Project back from DFT.
    delta = dft.getH() @ dftz
    # delta = np.matrix(scipy.fft.ifft(dftz, axis=0, norm='ortho'))
    # The projected vector can have an imaginary part.
    delta = jnp.real(delta)
  return delta.reshape(shape)
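
# Usage sketch: project two 3-d vectors (stored as columns, per the dim x num
# convention above) onto the unit l2 ball.
import jax.numpy as jnp

v = jnp.array([[3.0, 0.1],
               [4.0, 0.2],
               [0.0, 0.0]])
proj = norm_projection(v, 'l2', eps=1.0)
# The first column (norm 5) is rescaled to norm 1; the second column already
# lies inside the ball and is returned unchanged.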
def _internal_bi_tempered_logistic_loss(activations, labels, t1, t2):
  """Computes the Bi-Tempered logistic loss.

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    labels: Label probabilities (e.g. one-hot) with the same shape as
      `activations`.
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness).

  Returns:
    A loss array for robust loss.
  """
  normalization_constants = compute_normalization(activations, t2, num_iters=5)
  if t2 == 1.0:
    if t1 == 1.0:
      return normalization_constants + jnp.sum(
          jnp.multiply(labels, jnp.log(labels + 1e-10) - activations), -1)
    else:
      shifted_activations = jnp.exp(activations - normalization_constants)
      one_minus_t1 = (1.0 - t1)
      one_minus_t2 = 1.0
  else:
    one_minus_t1 = (1.0 - t1)
    one_minus_t2 = (1.0 - t2)
    shifted_activations = jnp.maximum(
        1.0 + one_minus_t2 * (activations - normalization_constants), 0.0)

  if t1 == 1.0:
    return jnp.sum(
        jnp.multiply(
            jnp.log(labels + 1e-10) -
            jnp.log(jnp.power(shifted_activations, 1.0 / one_minus_t2)),
            labels), -1)
  else:
    beta = 1.0 + one_minus_t1
    logt_probs = (jnp.power(shifted_activations, one_minus_t1 / one_minus_t2) -
                  1.0) / one_minus_t1
    return jnp.sum(
        jnp.multiply(log_t(labels, t1) - logt_probs, labels) -
        1.0 / beta * (jnp.power(labels, beta) -
                      jnp.power(shifted_activations, beta / one_minus_t2)), -1)
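
# For reference, a sketch of the tempered logarithm `log_t` used above (the
# module's own definition may handle edge cases differently): it reduces to
# the ordinary log as t -> 1.
def log_t_reference(u, t):
  if t == 1.0:
    return jnp.log(u)
  return (jnp.power(u, 1.0 - t) - 1.0) / (1.0 - t)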
def sample_bounds(key: jnp.ndarray,
                  shape: Tuple[int, ...],
                  minval: float = -2.,
                  maxval: float = 2.) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Sample some bounds of the required shape.

  Args:
    key: Random number generator.
    shape: Shape of the bounds to generate.
    minval: Optional, smallest value that the bounds could take.
    maxval: Optional, largest value that the bounds could take.

  Returns:
    lb, ub: Lower and upper bound tensors of the desired shape.
  """
  key_0, key_1 = jax.random.split(key)
  bound_1 = jax.random.uniform(key_0, shape, minval=minval, maxval=maxval)
  bound_2 = jax.random.uniform(key_1, shape, minval=minval, maxval=maxval)
  lb = jnp.minimum(bound_1, bound_2)
  ub = jnp.maximum(bound_1, bound_2)
  return lb, ub
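
# Usage sketch: draw a (2, 3) box of bounds and check the ordering invariant.
key = jax.random.PRNGKey(0)
lb, ub = sample_bounds(key, (2, 3), minval=-1., maxval=1.)
assert bool(jnp.all(lb <= ub))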
def _clip_by_l2_norm(x: Array, max_norm: float) -> Array:
  """Clip gradients to maximum l2 norm `max_norm`."""
  # Compute the sum of squares and find out where things are zero.
  sum_sq = jnp.sum(jnp.vdot(x, x))
  nonzero = sum_sq > 0

  # Compute the norm wherever sum_sq > 0 and leave it <= 0 otherwise. This
  # makes use of the "double where" trick; see
  # https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
  # for more info. In short this is necessary because although the norm ends
  # up computed correctly where nonzero is true, if we ignored this we'd end
  # up with nans on the off-branches which would leak through when computing
  # gradients in the backward pass.
  sum_sq_ones = jnp.where(nonzero, sum_sq, jnp.ones_like(sum_sq))
  norm = jnp.where(nonzero, jnp.sqrt(sum_sq_ones), sum_sq)

  # Normalize by max_norm. Whenever norm < max_norm we're left with x (this
  # happens trivially for indices where nonzero is false). Otherwise we're
  # left with the desired x * max_norm / norm.
  return (x * max_norm) / jnp.maximum(norm, max_norm)
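
# Usage sketch: inputs with norm <= `max_norm` pass through unchanged, while
# larger inputs are rescaled to have norm exactly `max_norm`.
x = jnp.array([3.0, 4.0])              # l2 norm 5
_clip_by_l2_norm(x, max_norm=1.0)      # -> [0.6, 0.8], norm 1
_clip_by_l2_norm(x, max_norm=10.0)     # unchanged: norm already below 10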
def _dynamics(state, action):
  self.nsamples += 1
  position = state[0]
  velocity = state[1]
  force = jnp.minimum(jnp.maximum(action, self.min_action), self.max_action)
  velocity += force * self.power - 0.0025 * jnp.cos(3 * position)
  velocity = jnp.clip(velocity, -self.max_speed, self.max_speed)
  position += velocity
  position = jnp.clip(position, self.min_position, self.max_position)
  # Zero out the velocity when the car is pressed against the left wall.
  reset_velocity = (position == self.min_position) & (velocity < 0)
  velocity = jax.lax.cond(reset_velocity[0],
                          velocity, lambda x: jnp.zeros((1,)),
                          velocity, lambda x: x)
  return jnp.reshape(jnp.array([position, velocity]), (2,))
def convex_fn_relaxation(primitive: bound_propagation.Primitive,
                         inp: Bound,
                         **params) -> Tuple[TensorFunction, TensorFunction]:
  """Relaxation of an element-wise convex primitive.

  Args:
    primitive: Convex primitive to relax.
    inp: Bounds on the input.
    **params: Params of the quadratic operation, mainly the jaxpr defining it.

  Returns:
    lb_fun, ub_fun
  """
  prim_fun = functools.partial(primitive.bind, **params)
  x_lb, x_ub = inp.lower, inp.upper
  y_lb, y_ub = prim_fun(x_lb), prim_fun(x_ub)
  chord_slope_safe_denom = jnp.maximum(x_ub - x_lb, 1e-12)
  chord_slope = (y_ub - y_lb) / chord_slope_safe_denom
  chord_intercept = y_lb - chord_slope * x_lb
  chord_fun = lambda x: chord_slope * x + chord_intercept
  return prim_fun, chord_fun
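
# The chord construction above, spelled out for a concrete convex function
# (jnp.exp on [0, 1]; the primitive/Bound plumbing is elided in this sketch):
x_lb, x_ub = 0.0, 1.0
y_lb, y_ub = jnp.exp(x_lb), jnp.exp(x_ub)
slope = (y_ub - y_lb) / jnp.maximum(x_ub - x_lb, 1e-12)
chord = lambda x: slope * x + (y_lb - slope * x_lb)
# By convexity, exp(x) <= chord(x) on [0, 1], so exp itself is the lower
# bound and the chord is the upper bound.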
def cost(self, p, extra=None, precomputed=None):
  """Negative log-likelihood."""
  y = self.y if extra is None else extra['y']
  r = self.forward_pass(p, extra) if precomputed is None else precomputed
  r = np.maximum(r, 1e-20)  # Floor the rate at a tiny value to avoid nan in log.
  dt = self.dt

  term0 = -np.log(r / dt) @ y  # Spike term from the Poisson log-likelihood.
  term1 = np.sum(r)  # Non-spike term.
  neglogli = term0 + term1

  if self.beta and extra is None:
    l1 = np.linalg.norm(p['w'], 1)
    l2 = np.linalg.norm(p['w'], 2)
    neglogli += self.beta * ((1 - self.alpha) * l2 + self.alpha * l1)

  return neglogli
def l2_normalize(arr, axis, epsilon=1e-12):
  """L2 normalize along a particular axis.

  Doc taken from tf.nn.l2_normalize:
  https://www.tensorflow.org/api_docs/python/tf/math/l2_normalize

    output = x / (sqrt(max(sum(x**2), epsilon)))
  """
  sq_arr = np.power(arr, 2)
  square_sum = np.sum(sq_arr, axis=axis, keepdims=True)
  max_weights = np.maximum(square_sum, epsilon)
  return np.divide(arr, np.sqrt(max_weights))
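
# Usage sketch: rows of the result have unit l2 norm, and the epsilon floor
# maps all-zero rows to zeros instead of NaNs.
a = np.array([[3.0, 4.0],
              [0.0, 0.0]])
l2_normalize(a, axis=1)  # -> [[0.6, 0.8], [0.0, 0.0]]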
def log(q, eps=1e-8):
  """Computes the quaternion logarithm.

  References:
    https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions

  Args:
    q: the quaternion in (x,y,z,w) format.
    eps: an epsilon value for numerical stability.

  Returns:
    The logarithm of q.
  """
  mag = linalg.norm(q, axis=-1, keepdims=True)
  v = im(q)
  s = re(q)
  w = jnp.log(mag)
  denom = jnp.maximum(
      linalg.norm(v, axis=-1, keepdims=True), eps * jnp.ones_like(v))
  # log q = (v / |v| * arccos(s / |q|), log |q|). The original divided s by
  # eps here, which appears to be a typo for the magnitude.
  xyz = v / denom * safe_acos(s / mag)
  return jnp.concatenate((xyz, w), axis=-1)
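
# For reference, a common form of the `safe_acos` helper assumed above (a
# sketch; the module's own version may differ): clamp the argument slightly
# away from +/-1 so arccos and its gradient stay finite.
def safe_acos_reference(t, eps=1e-7):
  return jnp.arccos(jnp.clip(t, -1.0 + eps, 1.0 - eps))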
def solve_vfi(money, EV_list, umult, kf, km, sigma, beta, i, wn, wt,
              sgrid, psi):
  # dim is (na, ns, nexo, ntheta)
  EV_stretch_list = [(wt[:, None, None] * x[i, :, :] +
                      wn[:, None, None] * x[i + 1, :, :]) for x in EV_list]

  consumption = money[:, None, :] - sgrid[None, :, None]
  consumption_negative = (consumption <= 0)
  uc = (np.maximum(consumption, 1e-8))**(1 - sigma) / (1 - sigma)
  utility = umult[None, None, None, :] * (uc[:, :, :, None]) - \
      1e9 * consumption_negative[:, :, :, None]

  EVs, EVFs, EVMs = EV_stretch_list
  mega_matrix = utility + beta * EVs[None, :, :, :]

  ind_s = mega_matrix.argmax(axis=1)
  V = np.take_along_axis(mega_matrix, ind_s[:, None, :, :], 1) \
      .squeeze(axis=1) + psi

  s = sgrid[ind_s]
  c = money[:, :, None] - s

  V_check = umult[None, None, :] * (c**(1 - sigma) / (1 - sigma)) + \
      psi + beta * np.take_along_axis(EVs, ind_s, 0)
  VF = ((kf[None, None, :] * c)**(1 - sigma) / (1 - sigma)) + \
      psi + beta * np.take_along_axis(EVFs, ind_s, 0)
  VM = ((km[None, None, :] * c)**(1 - sigma) / (1 - sigma)) + \
      psi + beta * np.take_along_axis(EVMs, ind_s, 0)

  assert np.allclose(V_check, V, atol=1e-5)
  return V, VF, VM, s
def mean_and_var(
    x: Optional[np.ndarray],
    axis: Optional[Axes] = None,
    dtype: Optional[np.dtype] = None,
    out: Optional[None] = None,
    ddof: int = 0,
    keepdims: bool = False,
    mask: Optional[np.ndarray] = None,
    get_var: bool = False
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
  """`np.mean` and `np.var` taking the `mask` information into account."""
  var = None
  if x is None:
    return x, var

  if mask is None:
    mean = np.mean(x, axis, dtype, out, keepdims)
    if get_var:
      var = np.var(x, axis, dtype, out, ddof, keepdims)
  else:
    axis = tuple(utils.canonicalize_axis(axis, x))
    size = utils.size_at(x, axis)
    mask = np.broadcast_to(mask, x.shape)
    mask_size = np.count_nonzero(mask, axis)
    for i in axis:
      mask_size = np.expand_dims(mask_size, i)
    size -= mask_size
    size = np.maximum(size, 1)

    mean = np.sum(x, axis=axis, keepdims=True) / size
    # Compute the variance before squeezing `mean`, so `x - mean` broadcasts.
    if get_var:
      var = np.sum((x - mean)**2, axis=axis, keepdims=True) / (size - ddof)
      if not keepdims:
        var = np.squeeze(var, axis)
    if not keepdims:
      mean = np.squeeze(mean, axis)

  return mean, var
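
# Usage sketch: `mask` marks padded entries, which this function appears to
# assume have already been zeroed out in `x`; only the divisor is shrunk.
x = np.array([[1.0, 2.0, 0.0]])          # last entry is padding
mask = np.array([[False, False, True]])
mean, _ = mean_and_var(x, axis=1, mask=mask)  # mean == 1.5, not 1.0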
def __init__(self, space, vocab_size, precision=2, max_range=(-100.0, 100.0)):
  self._precision = precision

  # Some gym envs (e.g. CartPole) have unreasonably high bounds for
  # observations. We clip so we can represent them.
  bounded_space = copy.copy(space)
  (min_low, max_high) = max_range
  bounded_space.low = np.maximum(space.low, min_low)
  bounded_space.high = np.minimum(space.high, max_high)
  if (not np.allclose(bounded_space.low, space.low) or
      not np.allclose(bounded_space.high, space.high)):
    logging.warning(
        'Space limits %s, %s out of bounds %s. Clipping to %s, %s.',
        str(space.low), str(space.high), str(max_range),
        str(bounded_space.low), str(bounded_space.high))

  super().__init__(bounded_space, vocab_size)
def lerp_weight(x, xs):
  """Linear interpolation weight from a sample at x to xs.

  Returns the linear interpolation weight of a "query point" at coordinate `x`
  with respect to a "sample" at coordinate `xs`.

  The integer coordinates `x` are at pixel centers.
  The floating point coordinates `xs` are at pixel edges.
  (OpenGL convention).

  Args:
    x: "Query" point position.
    xs: "Sample" position.

  Returns:
    - 1 when x = xs.
    - 0 when |x - xs| > 1.
  """
  dx = x - xs
  abs_dx = abs(dx)
  return jnp.maximum(1.0 - abs_dx, 0.0)
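
# Usage sketch: the weight is a unit-width "tent" centered at the sample.
lerp_weight(2.0, 2.0)  # -> 1.0
lerp_weight(2.6, 2.0)  # -> 0.4 (since |dx| = 0.6)
lerp_weight(3.5, 2.0)  # -> 0.0 (more than one pixel away)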
def initial_step_size(fun, t0, y0, order, rtol, atol, f0):
  # Algorithm from:
  # E. Hairer, S. P. Norsett, G. Wanner,
  # Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
  y0, f0 = _promote_dtypes_inexact(y0, f0)
  dtype = y0.dtype

  scale = atol + jnp.abs(y0) * rtol
  d0 = jnp.linalg.norm(y0 / scale.astype(dtype))
  d1 = jnp.linalg.norm(f0 / scale.astype(dtype))

  h0 = jnp.where((d0 < 1e-5) | (d1 < 1e-5), 1e-6, 0.01 * d0 / d1)
  y1 = y0 + h0.astype(dtype) * f0
  f1 = fun(y1, t0 + h0)
  d2 = jnp.linalg.norm((f1 - f0) / scale.astype(dtype)) / h0

  # Hairer's rule uses the larger of d1 and d2 here.
  h1 = jnp.where((d1 <= 1e-15) & (d2 <= 1e-15),
                 jnp.maximum(1e-6, h0 * 1e-3),
                 (0.01 / jnp.maximum(d1, d2))**(1. / (order + 1.)))

  return jnp.minimum(100. * h0, h1)
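
# Usage sketch on the toy problem dy/dt = -y (same `fun(y, t)` convention as
# above; assumes JAX's internal `_promote_dtypes_inexact` is importable):
fun = lambda y, t: -y
y0 = jnp.array([1.0])
h0 = initial_step_size(fun, 0.0, y0, order=5, rtol=1e-3, atol=1e-6,
                       f0=fun(y0, 0.0))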
def log_factor_k(cluster_id, log_maha_k, num_k, logdetC_k):
  """Computes the per-cluster log expansion factor f_k.

  f_k is chosen such that u_k @ (f0_k n_k C_k)^(-1) @ u_k <= 1 and

    f_k^D V(n_k C_k) = max(V(S_k), V(f0_k n_k C_k)),

  i.e.

    log f_k = (max(log(V(S) n_k / n_S), log V(f0_k n_k C_k))
               - log V(n_k C_k)) / D.
  """
  # K
  log_f_expand_k = -jnp.max(
      jnp.where(cluster_id == a_k[:, None], log_maha_k, -jnp.inf), axis=-1)
  log_VE_expand_k = log_ellipsoid_volume(logdetC_k, num_k, log_f_expand_k)
  log_VE_k = log_ellipsoid_volume(logdetC_k, num_k, 0.)
  # K
  log_scale_k = (jnp.maximum(log_VS + jnp.log(num_k) - jnp.log(num_S),
                             log_VE_expand_k) - log_VE_k) / D
  return log_scale_k
def lift_gaussian(d, t_mean, t_var, r_var, diag):
  """Lift a Gaussian defined along a ray to 3D coordinates."""
  mean = d[..., None, :] * t_mean[..., None]

  d_mag_sq = jnp.maximum(1e-10, jnp.sum(d**2, axis=-1, keepdims=True))

  if diag:
    d_outer_diag = d**2
    null_outer_diag = 1 - d_outer_diag / d_mag_sq
    t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :]
    xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :]
    cov_diag = t_cov_diag + xy_cov_diag
    return mean, cov_diag
  else:
    d_outer = d[..., :, None] * d[..., None, :]
    eye = jnp.eye(d.shape[-1])
    null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
    t_cov = t_var[..., None, None] * d_outer[..., None, :, :]
    xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]
    cov = t_cov + xy_cov
    return mean, cov
def predict_cnn(params, inputs, include_preactivations=False):
  """Forward pass for a CNN given parameters.

  Args:
    params: Parameters for the CNN. See make_cnn_params for syntax.
    inputs: Inputs to CNN.
    include_preactivations: bool. If True, also return pre-activations after
      each matmul layer.

  Returns:
    act: Output from forward pass through CNN.
    (Optional) layer_preacts: Pre-relu activations at each layer.
  """
  act = inputs
  layer_preacts = []
  for counter, layer_params in enumerate(params):
    act = fwd(act, layer_params)
    layer_preacts.append(act)
    if counter < len(params) - 1:
      # No relu on the final layer.
      act = jnp.maximum(act, 0)
  return act if not include_preactivations else (act, layer_preacts)
def dists_to_samples(rays, t):
  """Convert mipnerf frustums to gaussians."""
  t_mids = .5 * (t[..., 1:] + t[..., :-1])
  mean = rays[0][..., None, :] + rays[1][..., None, :] * t_mids[..., None]

  d = rays[1]
  d_mag_sq = np.maximum(1e-10, np.sum(d**2, axis=-1, keepdims=True))
  t_half = .5 * (t[..., 1:] - t[..., :-1])
  t_var = t_half**2 / 3.
  r_var = (rays[2] * t_mids)**2 / 12.

  d_outer = d[..., :, None] * d[..., None, :]
  eye = np.eye(d.shape[-1])
  null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
  t_cov = t_var[..., None, None] * d_outer[..., None, :, :]
  xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]
  cov = t_cov + xy_cov

  return mean, cov
def unconstrained_proposal(self, rng_key, x, grad_, hess_):
  ndim = np.ndim(x)
  if ndim == 0:
    inv_hess = 1 / hess_
    dist_type = dist.Normal
  else:
    inv_hess = np.linalg.inv(hess_)
    dist_type = dist.MultivariateNormal

  loc = x - np.dot(inv_hess, grad_)
  sigma = -inv_hess

  # Reconstruct sigma if it is not positive definite.
  if not ndim == 0 and not np.all(np.linalg.eigvals(sigma) > 0):
    lam, vec = np.linalg.eigh(sigma)
    sigma = vec @ np.diag(
        np.maximum(lam, UNCONSTRAINED_RECONSTRUCTION)) @ vec.T

  dist_ = dist_type(loc, sigma + MU_CORRECTION)
  return dist_.sample(rng_key).reshape(x.shape), dist_
def from_params(cls, fixed_params, opt_params, scale=None, traceable=True):
  # FIXME: traceable; why sometimes no Scale?
  if not scale:
    scale = Scale(0.0, 1.0)
  floor = fixed_params.get("floor", -np.inf)
  ceiling = fixed_params.get("ceiling", np.inf)

  # Allow the logistic center to exceed the range by 20%.
  loc_min = np.maximum(scale.low, floor) - 0.2 * scale.width
  loc_max = np.minimum(scale.high, ceiling) + 0.2 * scale.width
  loc_range = loc_max - loc_min

  structured_params = opt_params.reshape((-1, 3))
  locs = loc_min + scipy.special.expit(structured_params[:, 0]) * loc_range

  # Allow logistic scales between 0.01 and 0.5.
  # Don't allow tiny scales outside of the visible range.
  s_min = 0.01 + 0.1 * np.where(
      (locs < scale.low),
      scale.low - locs,
      np.where(locs > scale.high, locs - scale.high, 0.0),
  )
  s_max = 0.5
  s_range = s_max - s_min
  ss = s_min + scipy.special.expit(structured_params[:, 1]) * s_range

  # Allow probs > 0.01.
  probs = list(0.01 + nn.softmax(structured_params[:, 2]) *
               (1 - 0.01 * structured_params[:, 2].size))

  # Bundle up the components.
  component_logistics = [
      Logistic(l, s, scale, normalized=True) for (l, s) in zip(locs, ss)
  ]
  components = [
      Truncate(base_dist=cl, floor=floor, ceiling=ceiling)
      for cl in component_logistics
  ]
  mixture = cls(components=components, probs=probs)
  return mixture
def light(self, origin, direction, intersection, light_position,
          eye_position, scene_objects, bounce=0, far=1.0e15):
  '''Basic light model with ambient, diffuse (Lambert), and Phong terms.'''
  rayhit = origin + direction * intersection
  normal = (rayhit - self.center) * (1. / self.radius)
  direction_to_light = (light_position - rayhit).norm()
  direction_to_eye = (eye_position - rayhit).norm()
  nudged = rayhit + normal * 0.001  # To avoid shadow acne.

  # Create the shadow mask.
  light_distances = [
      o.intersect(nudged, direction_to_light, far=far) for o in scene_objects
  ]
  light_nearest = reduce(jnp.minimum, light_distances)
  light_mask = light_distances[scene_objects.index(self)] == light_nearest

  # Ambient light.
  color = Vec3(0.05, 0.05, 0.05)

  # Lambert shading (diffuse).
  light_hit = jnp.maximum(normal.dot(direction_to_light), 0)
  color += self.diffusecolor(rayhit) * light_hit * light_mask

  # Phong highlight.
  phong = normal.dot((direction_to_light + direction_to_eye).norm())
  color += Vec3(1., 1., 1.) * jnp.power(jnp.clip(phong, 0, 1), 50) * light_mask

  return color
def segment_max(data, segment_ids, num_segments=None,
                indices_are_sorted=False, unique_indices=False):
  """Computes the max within segments of an array.

  Similar to TensorFlow's segment_max:
  https://www.tensorflow.org/api_docs/python/tf/math/segment_max

  Args:
    data: an array with the values to be maxed over.
    segment_ids: an array with integer dtype that indicates the segments of
      `data` (along its leading axis) to be maxed over. Values can be repeated
      and need not be sorted. Values outside of the range [0, num_segments)
      are wrapped into that range by applying jnp.mod.
    num_segments: optional, an int with positive value indicating the number
      of segments. The default is
      ``jnp.maximum(jnp.max(segment_ids) + 1, jnp.max(-segment_ids))`` but
      since `num_segments` determines the size of the output, a static value
      must be provided to use ``segment_max`` in a ``jit``-compiled function.
    indices_are_sorted: whether ``segment_ids`` is known to be sorted.
    unique_indices: whether ``segment_ids`` is known to be free of duplicates.

  Returns:
    An array with shape ``(num_segments,) + data.shape[1:]`` representing the
    segment maxs.
  """
  if num_segments is None:
    num_segments = jnp.maximum(jnp.max(segment_ids) + 1,
                               jnp.max(-segment_ids))
  num_segments = int(num_segments)

  min_value = dtype_min_value(data.dtype)
  out = jnp.full((num_segments,) + data.shape[1:], min_value,
                 dtype=data.dtype)
  segment_ids = jnp.mod(segment_ids, num_segments)
  return jax.ops.index_max(out, segment_ids, data, indices_are_sorted,
                           unique_indices)
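
# Usage sketch (note this relies on the older `jax.ops.index_max` API used
# above): three segments over five values.
data = jnp.array([5, 1, 7, 2, 3])
segment_ids = jnp.array([0, 0, 1, 1, 2])
segment_max(data, segment_ids, num_segments=3)  # -> [5, 7, 3]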
def compute_loss(self, predictions: NestedMap,
                 input_batch: NestedMap) -> Tuple[Metrics, Dict[str, Any]]:
  """Computes the loss and other metrics for the given predictions.

  Args:
    predictions: The output of `compute_predictions`.
    input_batch: A `.NestedMap` object containing input tensors to this tower.

  Returns:
    - A dict or NestedMap containing str keys and (metric, weight) pairs as
      values, where one of the entries is expected to correspond to the loss.
    - A dict containing arbitrary tensors describing something about each
      training example, where the first dimension of each tensor is the batch
      index.
  """
  labels = input_batch.labels
  num_tokens = jnp.sum(1.0 - input_batch.paddings.astype(jnp.float32))
  num_seqs = jnp.sum(
      jnp.amax(input_batch.segment_ids.astype(jnp.float32), axis=1))
  weights = predictions.augmented_pos.astype(jnp.float32)
  predicted_labels = predictions.per_example_argmax.astype(labels.dtype)
  num_preds = predictions.total_weight.astype(jnp.float32)
  mean_acc = jnp.sum(
      (labels == predicted_labels) * weights) / jnp.maximum(num_preds, 1)
  metric_weight = jnp.array(num_preds, predictions.avg_xent.dtype)
  metrics = py_utils.NestedMap(
      total_loss=(predictions.total_loss, metric_weight),
      avg_xent=(predictions.avg_xent, metric_weight),
      aux_loss=(predictions.aux_loss, metric_weight),
      log_pplx=(predictions.avg_xent, metric_weight),
      fraction_of_correct_preds=(mean_acc,
                                 jnp.array(num_preds, mean_acc.dtype)),
      num_predictions=(num_preds, jnp.array(1.0, num_preds.dtype)),
      num_tokens=(num_tokens, jnp.array(1.0, num_tokens.dtype)),
      num_seqs=(num_seqs, jnp.array(1.0, num_seqs.dtype)),
  )
  per_example_output = py_utils.NestedMap()
  return metrics, per_example_output
def test_clipping_norm(self, l2_norm_clip):
  dp_agg = privacy.differentially_private_aggregate(
      l2_norm_clip=l2_norm_clip,
      noise_multiplier=0.,
      seed=42)
  state = dp_agg.init(self.params)
  update_fn = self.variant(dp_agg.update)

  # Shape of the three arrays below is (self.batch_size,).
  norms = [
      jnp.linalg.norm(g.reshape(self.batch_size, -1), axis=1)
      for g in jax.tree_leaves(self.per_eg_grads)
  ]
  global_norms = jnp.linalg.norm(jnp.stack(norms), axis=0)
  divisors = jnp.maximum(global_norms / l2_norm_clip, 1.)
  # Since the values of all the parameters are the same within each example,
  # we can easily compute what the values should be:
  expected_val = jnp.mean(jnp.arange(self.batch_size) / divisors)
  expected_tree = jax.tree_map(
      lambda p: jnp.broadcast_to(expected_val, p.shape), self.params)

  for _ in range(3):
    updates, state = update_fn(self.per_eg_grads, state, self.params)
    chex.assert_tree_all_close(updates, expected_tree)