def extract_max_elts(x):
    H, W, C = x.shape
    assert H % 2 == 0 and W % 2 == 0
    # Squeeze so that the grid elements are aligned on the last axis
    x_squeeze = util.pixel_squeeze(x)
    # Sort each grid
    x_sq_argsorted = x_squeeze.argsort(axis=-1)
    # Find the max index of each grid
    max_idx = x_sq_argsorted[..., -1:]
    # Get all of the elements that aren't the max.
    non_max_idx = x_sq_argsorted[..., :-1]
    # Sort the non-max indices so that we can pass the decoder consistent information.
    non_max_idx = non_max_idx.sort(axis=-1)
    # Take the max elements
    max_elts = jnp.take_along_axis(x_squeeze, max_idx, axis=-1).squeeze(axis=-1)
    assert max_elts.shape == (H // 2, W // 2, C)
    # Take the remaining elements
    non_max_elts = jnp.take_along_axis(x_squeeze, non_max_idx, axis=-1)
    # Subtract elts from the max so that we are left with positive elements
    non_max_elts = max_elts[..., None] - non_max_elts
    non_max_elts = non_max_elts.reshape((H // 2, W // 2, 3 * C))
    return max_elts, non_max_elts, max_idx.squeeze(axis=-1), non_max_idx
def sparse_categorical_crossentropy(
    y_true: jnp.ndarray,
    y_pred: jnp.ndarray,
    from_logits: bool = False,
    check_bounds: bool = True,
) -> jnp.ndarray:
    n_classes = y_pred.shape[-1]
    if from_logits:
        y_pred = jax.nn.log_softmax(y_pred)
        loss = -jnp.take_along_axis(y_pred, y_true[..., None], axis=-1)[..., 0]
    else:
        # select output value
        y_pred = jnp.take_along_axis(y_pred, y_true[..., None], axis=-1)[..., 0]
        # calculate log
        y_pred = jnp.maximum(y_pred, types.EPSILON)
        y_pred = jnp.log(y_pred)
        loss = -y_pred
    if check_bounds:
        # set NaN where y_true is negative or larger/equal to the number of y_pred channels
        loss = jnp.where(y_true < 0, jnp.nan, loss)
        loss = jnp.where(y_true >= n_classes, jnp.nan, loss)
    return loss
def solve_singles_vfi(EV_in, money, sigma, beta, i, wn, wt, sgrid):
    EV = wt[:, None] * EV_in[i, :] + wn[:, None] * EV_in[i + 1, :]
    consumption = money[:, None, :] - sgrid[None, :, None]
    consumption_negative = (consumption <= 0)
    utility = (np.maximum(consumption, 1e-8))**(1 - sigma) / (1 - sigma) - \
        1e9 * consumption_negative
    mega_matrix = utility + beta * EV
    # print(mega_matrix.shape)
    ind_s = mega_matrix.argmax(axis=1)
    V = np.take_along_axis(mega_matrix, ind_s[:, None, :], 1).squeeze(axis=1)
    s = sgrid[ind_s]
    c = money - s
    V_check = (c**(1 - sigma) / (1 - sigma)) + beta * np.take_along_axis(EV, ind_s, 0)
    assert np.allclose(V_check, V, atol=1e-5)
    return V, s
def interp_manygrids(grids, xs, axis=0, return_wnext=True, trim=False):
    # this routine interpolates xs on many grids, defined along
    # the axis in an array grids. (so for axis=0 the grids are
    # grids[:, i, j, k] for all i, j, k)
    assert np.all(np.diff(grids, axis=axis) > 0)

    '''
    if trim:
        xs = np.clip(xs[:, None, None],
                     grids.min(axis=axis, keepdims=True),
                     grids.max(axis=axis, keepdims=True))
    '''

    # this requires everything to be sorted
    mat = grids[..., None] < xs[(None, ) * grids.ndim + (slice(None), )]
    ng = grids.shape[axis]
    j = np.clip(np.sum(mat, axis=axis)[None, ...] - 1, 0, ng - 2)
    j = np.swapaxes(j, -1, axis).squeeze(axis=-1)
    grid_j = np.take_along_axis(grids, j, axis=axis)
    grid_jp = np.take_along_axis(grids, j + 1, axis=axis)
    xs_r = xs.reshape((1, ) * (axis - 1) + (xs.size, ) +
                      (1, ) * (grids.ndim - 1 - axis))
    wnext = (xs_r - grid_j) / (grid_jp - grid_j)
    return j, (wnext if return_wnext else 1 - wnext)
def _compute_loss_and_stats(params, model_out, use_elastic_loss=False):
    rgb_loss = ((model_out['rgb'] - batch['rgb'][..., :3])**2).mean()
    stats = {
        'loss/rgb': rgb_loss,
    }
    loss = rgb_loss
    if use_elastic_loss:
        elastic_fn = functools.partial(compute_elastic_loss,
                                       loss_type=elastic_loss_type)
        v_elastic_fn = jax.jit(vmap(vmap(jax.jit(elastic_fn))))
        weights = lax.stop_gradient(model_out['weights'])
        jacobian = model_out['warp_jacobian']
        # Pick the median point Jacobian.
        if elastic_reduce_method == 'median':
            depth_indices = model_utils.compute_depth_index(weights)
            jacobian = jnp.take_along_axis(
                # Unsqueeze axes: sample axis, Jacobian row, Jacobian col.
                jacobian, depth_indices[..., None, None, None], axis=-3)
        # Compute loss using Jacobian.
        elastic_loss, elastic_residual = v_elastic_fn(jacobian)
        # Multiply weight if weighting by density.
        if elastic_reduce_method == 'weight':
            elastic_loss = weights * elastic_loss
        elastic_loss = elastic_loss.sum(axis=-1).mean()
        stats['loss/elastic'] = elastic_loss
        stats['residual/elastic'] = jnp.mean(elastic_residual)
        loss += scalar_params.elastic_loss_weight * elastic_loss

    if use_warp_reg_loss:
        weights = lax.stop_gradient(model_out['weights'])
        depth_indices = model_utils.compute_depth_index(weights)
        warp_mag = ((model_out['points'] - model_out['warped_points'])**2).sum(axis=-1)
        warp_reg_residual = jnp.take_along_axis(warp_mag,
                                                depth_indices[..., None],
                                                axis=-1)
        warp_reg_loss = utils.general_loss_with_squared_residual(
            warp_reg_residual,
            alpha=scalar_params.warp_reg_loss_alpha,
            scale=scalar_params.warp_reg_loss_scale).mean()
        stats['loss/warp_reg'] = warp_reg_loss
        stats['residual/warp_reg'] = jnp.mean(jnp.sqrt(warp_reg_residual))
        loss += scalar_params.warp_reg_loss_weight * warp_reg_loss

    if 'warp_jacobian' in model_out:
        jacobian = model_out['warp_jacobian']
        jacobian_det = jnp.linalg.det(jacobian)
        jacobian_div = utils.jacobian_to_div(jacobian)
        jacobian_curl = utils.jacobian_to_curl(jacobian)
        stats['metric/jacobian_det'] = jnp.mean(jacobian_det)
        stats['metric/jacobian_div'] = jnp.mean(jacobian_div)
        stats['metric/jacobian_curl'] = jnp.mean(
            jnp.linalg.norm(jacobian_curl, axis=-1))

    stats['loss/total'] = loss
    stats['metric/psnr'] = utils.compute_psnr(rgb_loss)
    return loss, stats
def sample(self, key, sample_shape=()):
    ps = Dirichlet(self.weights).sample(key, sample_shape=sample_shape)
    zs = np.expand_dims(Categorical(ps).sample(key), axis=-1)
    locs = np.broadcast_to(
        self.locs,
        sample_shape + self.batch_shape + self.event_shape + self.mixture_shape)
    scales = np.broadcast_to(
        self.scales,
        sample_shape + self.batch_shape + self.event_shape + self.mixture_shape)
    mlocs = np.squeeze(np.take_along_axis(locs, zs, axis=-1), axis=-1)
    mscales = np.squeeze(np.take_along_axis(scales, zs, axis=-1), axis=-1)
    return Normal(mlocs, mscales).sample(key)
def conditional_params_to_sample(rng, conditional_params):
    means, inv_scales, logit_probs = conditional_params
    _, h, w, c = means.shape
    rng_mix, rng_logistic = random.split(rng)
    mix_idx = np.broadcast_to(
        _gumbel_max(rng_mix, logit_probs)[..., np.newaxis], (h, w, c))[np.newaxis]
    means = np.take_along_axis(means, mix_idx, 0)[0]
    inv_scales = np.take_along_axis(inv_scales, mix_idx, 0)[0]
    return (means
            + random.logistic(rng_logistic, means.shape, means.dtype) / inv_scales)
def solve_singles_egm(EV, EMU, li, agrid, sigma, beta, R, i, wn, wt, last):
    if not last:
        c_prescribed = (beta * R * EMU)**(-1 / sigma)
        m_implied = c_prescribed + agrid[:, None]
        a_implied = (1 / R) * (m_implied - li)
        a_i_min = a_implied[0, ...]
        a_i_max = a_implied[-1, ...]
        j, wn = interp_manygrids(a_implied, agrid, axis=0, trim=True)
        s_egm = agrid[j] * (1 - wn) + agrid[j + 1] * wn
        EV_egm = np.take_along_axis(EV, j, axis=0) * (1 - wn) + \
            np.take_along_axis(EV, j + 1, axis=0) * wn
        agrid_r = agrid.reshape((agrid.size, ) + a_i_min.ndim * (1, ))
        i_above = (agrid_r >= a_i_max)
        i_below = (agrid_r <= a_i_min)
        i_egm = (~i_above) & (~i_below)
        s_below = 0.0
        s_above = agrid[-1]
        EV_below = EV[:1, ...]
        EV_above = EV[-1:, ...]
        s = s_egm * i_egm + s_above * i_above  # + 0.0*i_below
        EV_int = EV_egm * i_egm + EV_above * i_above + EV_below * i_below
        c = R * agrid_r + li - s
        assert np.all(s >= s_below)
        assert np.all(s <= s_above)
        assert np.all(c >= 0)
    else:
        c = R * agrid[:, None] + li
        s = 0.0 * c
        EV_int = EV  # it is 0 anyways
    MUc = (c**(-sigma))
    # this EV should be interpolated as well
    V = (c**(1 - sigma) / (1 - sigma)) + beta * EV_int
    return V, s, MUc
def topk(self: TensorType, k: int, sorted: bool = True) -> Tuple[TensorType, TensorType]:
    # argpartition not yet implemented
    # wrapping indexing not supported in take()
    n = self.raw.shape[-1]
    idx = np.take(np.argsort(self.raw), np.arange(n - k, n), axis=-1)
    val = np.take_along_axis(self.raw, idx, axis=-1)
    if sorted:
        perm = np.flip(np.argsort(val, axis=-1), axis=-1)
        idx = np.take_along_axis(idx, perm, axis=-1)
        val = np.take_along_axis(self.raw, idx, axis=-1)
    return type(self)(val), type(self)(idx)
def logpmf(self, x, p):
    batch_shape = lax.broadcast_shapes(x.shape, p.shape[:-1])
    # append a dimension to x
    # TODO: consider to convert x.dtype to int
    x = np.expand_dims(x, axis=-1)
    x = np.broadcast_to(x, batch_shape + (1, ))
    p = np.broadcast_to(p, batch_shape + p.shape[-1:])
    if self.is_logits:
        # normalize log prob
        p = p - logsumexp(p, axis=-1, keepdims=True)
        # gather and remove the trailing dimension
        return np.take_along_axis(p, x, axis=-1)[..., 0]
    else:
        return np.take_along_axis(np.log(p), x, axis=-1)[..., 0]
def log_prob(self, value):
    batch_shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape)
    value = jnp.expand_dims(value, axis=-1)
    value = jnp.broadcast_to(value, batch_shape + (1,))
    logits = _to_logits_multinom(self.probs)
    log_pmf = jnp.broadcast_to(logits, batch_shape + jnp.shape(logits)[-1:])
    return jnp.take_along_axis(log_pmf, value, axis=-1)[..., 0]
def evaluation_fn(params, images, labels):
    tiled_logits, out = model.apply({'params': flax.core.freeze(params)},
                                    images,
                                    train=False)
    loss_name = config.get('loss', 'sigmoid_xent')
    # TODO(dusenberrymw,zmariet): Clean up and generalize this.
    if loss_name == 'sigmoid_xent':
        ens_logits = batchensemble_utils.log_average_sigmoid_probs(
            jnp.asarray(jnp.split(tiled_logits, ens_size)))
        pre_logits = batchensemble_utils.log_average_sigmoid_probs(
            jnp.asarray(jnp.split(out['pre_logits'], ens_size)))
    else:  # softmax
        ens_logits = batchensemble_utils.log_average_softmax_probs(
            jnp.asarray(jnp.split(tiled_logits, ens_size)))
        pre_logits = batchensemble_utils.log_average_softmax_probs(
            jnp.asarray(jnp.split(out['pre_logits'], ens_size)))

    losses = getattr(train_utils, loss_name)(
        logits=ens_logits,
        labels=labels[:, :config.num_classes],
        reduction=False)
    loss = jax.lax.psum(losses, axis_name='batch')

    top1_idx = jnp.argmax(ens_logits, axis=1)
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
    n = batch_size_eval

    metric_args = jax.lax.all_gather([ens_logits, labels, pre_logits],
                                     axis_name='batch')
    return ncorrect, loss, n, metric_args
def select(sequences, indices):
    """Select one element from each sequence in a batch.

    Given an array of shape (number_of_sequences, sequence_length,
    element_dimension) and a 1D array specifying which index of each sequence
    to select, return a (number_of_sequences, element_dimension)-shaped array
    with the selected elements.

    Args:
        sequences: array with shape
            (number_of_sequences, sequence_length, element_dimension)
        indices: 1D array with length number_of_sequences

    Returns:
        selected_elements: array with shape
            (number_of_sequences, element_dimension)
    """
    assert len(indices) == sequences.shape[0]
    # shape indices properly
    indices_shaped = indices[:, jnp.newaxis, jnp.newaxis]
    # select element
    selected_elements = jnp.take_along_axis(sequences, indices_shaped, axis=1)
    # remove sequence dimension
    selected_elements = jnp.squeeze(selected_elements, axis=1)
    return selected_elements
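# Illustrative usage of `select` above (a sketch added for clarity; the toy
# arrays and shapes below are assumptions, not from the original source).
import jax.numpy as jnp

sequences = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)  # (num_sequences=2, seq_len=3, dim=4)
indices = jnp.array([0, 2])                         # position 0 of seq 0, position 2 of seq 1

selected = select(sequences, indices)               # shape (2, 4)
# selected[0] == sequences[0, 0] and selected[1] == sequences[1, 2]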
def index_Q_at_action(Q_values, actions):
    # Q_values [bsz, n_actions]
    # actions  [bsz,]
    idx = jnp.expand_dims(actions, -1)
    # pred_Q_values [bsz,]
    pred_Q_values = jnp.take_along_axis(Q_values, idx, -1).squeeze()
    return pred_Q_values
def take_along_axis(self: TensorType, index: TensorType, axis: int) -> TensorType:
    if axis % self.ndim != self.ndim - 1:
        raise NotImplementedError(
            "take_along_axis is currently only supported for the last axis"
        )
    return type(self)(np.take_along_axis(self.raw, index.raw, axis=axis))
def cifar_10h_evaluation_fn(params, images, labels, mask):
    loss_as_str = config.get('loss', 'softmax_xent')
    ens_logits, ens_prelogits = ensemble_pred_fn(params, images, loss_as_str)

    label_indices = config.get('label_indices')
    if label_indices:
        ens_logits = ens_logits[:, label_indices]

    losses = getattr(train_utils, loss_as_str)(logits=ens_logits,
                                               labels=labels,
                                               reduction=False)
    loss = jax.lax.psum(losses, axis_name='batch')

    top1_idx = jnp.argmax(ens_logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

    top1_correct = jnp.take_along_axis(one_hot_labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
    n = jax.lax.psum(one_hot_labels, axis_name='batch')

    metric_args = jax.lax.all_gather([ens_logits, labels, ens_prelogits, mask],
                                     axis_name='batch')
    return ncorrect, loss, n, metric_args
def Bellman_loss(self, Qnet_params, batch, actions):
    inputs, targets = batch
    preds = self.predict(Qnet_params, inputs)
    preds_select = jnp.take_along_axis(preds,
                                       jnp.expand_dims(actions, axis=1),
                                       axis=1)
    return jnp.mean(huber_loss(preds_select - targets))
def _take_along_axis(array, indices, axis):
    """Takes values from the input array by matching 1D index and data slices.

    This function serves the same purpose as jax.numpy.take_along_axis, except
    that it uses one-hot matrix multiplications under the hood on TPUs:
    (1) On TPUs, we use one-hot matrix multiplications to select elements from
        the array.
    (2) Otherwise, we fall back to jax.numpy.take_along_axis.

    Notes:
      - To simplify matters in case (1), we only support slices along the
        second or last dimensions.
      - We may wish to revisit (1) for very large arrays.

    Args:
      array: Source array.
      indices: Indices to take along each 1D slice of array.
      axis: Axis along which to take 1D slices.

    Returns:
      The indexed result.
    """
    if array.ndim != indices.ndim:
        raise ValueError(
            "indices and array must have the same number of dimensions; "
            f"{indices.ndim} vs. {array.ndim}.")

    if (axis != -1 and axis != array.ndim - 1 and  # Not last dimension
            axis != 1 and axis != -array.ndim + 1):  # Not second dimension
        raise ValueError(
            "Only slices along the second or last dimension are supported; "
            f"array.ndim = {array.ndim}, while axis = {axis}.")

    if _favor_one_hot_slices():
        one_hot_length = array.shape[axis]
        one_hot_indices = jax.nn.one_hot(indices, one_hot_length, axis=axis)

        if axis == -1 or array.ndim == 1:
            # Take i elements from last dimension (s).
            # We must use HIGHEST precision to accurately reproduce indexing
            # operations with matrix multiplications.
            result = jnp.einsum(
                "...s,...is->...i",
                array,
                one_hot_indices,
                precision=jax.lax.Precision.HIGHEST)
        else:
            # Take i elements from second dimension (s). We assume here that we
            # always want to slice along the second dimension.
            # We must use HIGHEST precision to accurately reproduce indexing
            # operations with matrix multiplications.
            result = jnp.einsum(
                "ns...,nis...->ni...",
                array,
                one_hot_indices,
                precision=jax.lax.Precision.HIGHEST)
        return jax.lax.convert_element_type(result, array.dtype)
    else:
        return jnp.take_along_axis(array, indices, axis=axis)
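# Sketch (added for illustration, not from the original source) checking that
# the one-hot einsum used in case (1) above reproduces jnp.take_along_axis for
# a last-axis gather on toy data.
import jax
import jax.numpy as jnp

array = jnp.arange(2 * 5, dtype=jnp.float32).reshape(2, 5)
indices = jnp.array([[4, 0], [1, 3]])  # two gathered positions per row

one_hot = jax.nn.one_hot(indices, array.shape[-1], axis=-1)      # (2, 2, 5)
via_einsum = jnp.einsum("...s,...is->...i", array, one_hot,
                        precision=jax.lax.Precision.HIGHEST)     # (2, 2)
via_gather = jnp.take_along_axis(array, indices, axis=-1)        # (2, 2)
assert jnp.allclose(via_einsum, via_gather)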
def cifar_10h_evaluation_fn(params, states, images, labels, mask):
    variable_dict = {'params': flax.core.freeze(params), **states}
    logits, out = model.apply(
        variable_dict,
        images,
        train=False,
        rngs={
            'dropout': rng_dropout,
            'diag_noise_samples': diag_noise_rng,
            'standard_norm_noise_samples': standard_noise_rng
        })

    losses = getattr(train_utils, config.get('loss', 'softmax_xent'))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses, axis_name='batch')

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

    top1_correct = jnp.take_along_axis(one_hot_labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
    n = jax.lax.psum(one_hot_labels, axis_name='batch')

    metric_args = jax.lax.all_gather([logits, labels, out['pre_logits'], mask],
                                     axis_name='batch')
    return ncorrect, loss, n, metric_args
def _extract_signal_patches(signal, window_length, hop=1, data_format="NCW"):
    # window_length must be a scalar (either a 0-d array or a plain number).
    if hasattr(window_length, "shape"):
        assert window_length.shape == ()
    else:
        assert not hasattr(window_length, "__len__")
    if data_format == "NCW":
        if signal.ndim == 2:
            signal_3d = signal[:, None, :]
        elif signal.ndim == 1:
            signal_3d = signal[None, None, :]
        else:
            signal_3d = signal
        N = (signal_3d.shape[2] - window_length) // hop + 1
        indices = jnp.arange(window_length) + jnp.expand_dims(jnp.arange(N) * hop, 1)
        indices = jnp.reshape(indices, [1, 1, N * window_length])
        patches = jnp.take_along_axis(signal_3d, indices, 2)
        output = jnp.reshape(patches, signal_3d.shape[:2] + (N, window_length))
        if signal.ndim == 1:
            return output[0, 0]
        elif signal.ndim == 2:
            return output[:, 0, :]
        else:
            return output
    else:
        raise ValueError(f"Unsupported data_format: {data_format}")
def sample_with_intermediates(self, key, sample_shape=()):
    """
    Same as ``sample`` except that the sampled mixture components are also
    returned.

    :param jax.random.PRNGKey key: the rng key to be used for the distribution.
    :param tuple sample_shape: the sample shape for the distribution.
    :return: Tuple (samples, indices)
    :rtype: tuple
    """
    assert is_prng_key(key)
    key_comp, key_ind = jax.random.split(key)
    # Samples from the component distribution will have shape:
    # (*sample_shape, *batch_shape, mixture_size, *event_shape)
    samples = self.component_distribution.expand(
        sample_shape + self.batch_shape + (self.mixture_size, )).sample(key_comp)
    # Sample selection indices from the categorical (shape will be sample_shape)
    indices = self.mixing_distribution.expand(
        sample_shape + self.batch_shape).sample(key_ind)
    n_expand = self.event_dim + 1
    indices_expanded = indices.reshape(indices.shape + (1, ) * n_expand)
    # Select samples according to the indices sampled from the categorical
    samples_selected = jnp.take_along_axis(samples,
                                           indices=indices_expanded,
                                           axis=self.mixture_dim)
    # Final sample shape: (*sample_shape, *batch_shape, *event_shape)
    return jnp.squeeze(samples_selected, axis=self.mixture_dim), indices
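# A standalone sketch of the selection step above (the toy shapes and index
# values are assumptions, not from the original source): one mixture component
# is picked per batch element with jnp.take_along_axis.
import jax.numpy as jnp

batch, mixture_size, event_size = 2, 4, 3
samples = jnp.arange(batch * mixture_size * event_size, dtype=jnp.float32)
samples = samples.reshape(batch, mixture_size, event_size)
indices = jnp.array([1, 3])                                  # chosen component per batch element
# With event_dim == 1, the indices get event_dim + 1 = 2 trailing singleton axes.
indices_expanded = indices.reshape(indices.shape + (1, 1))
selected = jnp.take_along_axis(samples, indices_expanded, axis=-2)  # (2, 1, 3)
selected = jnp.squeeze(selected, axis=-2)                           # (2, 3)
# selected[0] == samples[0, 1] and selected[1] == samples[1, 3]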
def cross_entropy(logits, targets, axis=-1):
    logprobs = nn.log_softmax(logits, axis=axis)
    nll = np.take_along_axis(logprobs, np.expand_dims(targets, axis=axis), axis=axis)
    ce = -np.mean(nll)
    return ce
def l1_unit_projection(x):
    """Euclidean projection to L1 unit ball, i.e. argmin_{|v|_1 <= 1} |x - v|_2.

    Args:
        x: An array of size dim x num.

    Returns:
        An array of size dim x num, the projection to the unit L1 ball.
    """
    # https://dl.acm.org/citation.cfm?id=1390191
    xshape = x.shape
    if len(x.shape) == 1:
        x = x.reshape(-1, 1)
    eshape = x.shape
    v = jnp.abs(x.reshape((-1, eshape[-1])))
    u = jnp.sort(v, axis=0)
    u = u[::-1, :]  # descending
    arange = (1 + jnp.arange(eshape[0])).reshape((-1, 1))
    usum = (jnp.cumsum(u, axis=0) - 1) / arange
    rho = jnp.max(((u - usum) > 0) * arange - 1, axis=0, keepdims=True)
    thx = jnp.take_along_axis(usum, rho, axis=0)
    w = (v - thx).clip(a_min=0)
    w = jnp.where(jnp.linalg.norm(v, ord=1, axis=0, keepdims=True) > 1, w, v)
    x = w.reshape(eshape) * jnp.sign(x)
    return x.reshape(xshape)
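# Hedged usage sketch for `l1_unit_projection` above (the example vector is an
# assumption, not from the original source).
import jax.numpy as jnp

x = jnp.array([0.9, -0.6, 0.3])          # |x|_1 = 1.8 > 1, so it gets projected
v = l1_unit_projection(x)
# The result lies on (or inside) the L1 unit ball and keeps the signs of x.
assert jnp.linalg.norm(v, ord=1) <= 1 + 1e-6
assert jnp.all(jnp.sign(v) * jnp.sign(x) >= 0)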
def sampling_loop_body_fn(state):
    """Sampling loop state update."""
    i, sequences, cache, cur_token, ended, rng = state
    # Split RNG for sampling.
    rng1, rng2 = random.split(rng)
    # Call fast-decoder model on current tokens to get next-position logits.
    logits, new_cache = tokens_to_logits(cur_token, cache)
    # Sample next token from logits.
    # TODO(levskaya): add top-p "nucleus" sampling option.
    if topk:
        # Get top-k logits and their indices, sample within these top-k tokens.
        topk_logits, topk_idxs = lax.top_k(logits, topk)
        topk_token = jnp.expand_dims(
            random.categorical(rng1, topk_logits / temperature).astype(jnp.int32),
            axis=-1)
        # Return the original indices corresponding to the sampled top-k tokens.
        next_token = jnp.squeeze(
            jnp.take_along_axis(topk_idxs, topk_token, axis=-1), axis=-1)
    else:
        next_token = random.categorical(rng1, logits / temperature).astype(jnp.int32)
    # Only use sampled tokens if we're past provided prefix tokens.
    out_of_prompt = (sequences[:, i + 1] == 0)
    next_token = (next_token * out_of_prompt +
                  sequences[:, i + 1] * ~out_of_prompt)
    # If end-marker reached for batch item, only emit padding tokens.
    next_token_or_endpad = next_token * ~ended
    ended |= (next_token_or_endpad == end_marker)
    # Add current sampled tokens to recorded sequences.
    new_sequences = lax.dynamic_update_slice(sequences, next_token_or_endpad,
                                             (0, i + 1))
    return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended, rng2)
def cifar_10h_evaluation_fn(params, images, labels, mask):
    logits, out = model.apply({'params': flax.core.freeze(params)},
                              images,
                              train=False)
    label_indices = config.get('label_indices')
    if label_indices:
        logits = logits[:, label_indices]

    losses = getattr(train_utils, config.get('loss', 'softmax_xent'))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses, axis_name='batch')

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

    top1_correct = jnp.take_along_axis(one_hot_labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
    n = jax.lax.psum(one_hot_labels, axis_name='batch')

    metric_args = jax.lax.all_gather([logits, labels, out['pre_logits'], mask],
                                     axis_name='batch')
    return ncorrect, loss, n, metric_args
def evaluation_fn(params, images, labels, mask):
    # Ignore the entries with all zero labels for evaluation.
    mask *= labels.max(axis=1)
    logits, out = model.apply({'params': flax.core.freeze(params)},
                              images,
                              train=False)
    label_indices = config.get('label_indices')
    logging.info('!!! mask %s, label_indices %s', mask, label_indices)
    if label_indices:
        logits = logits[:, label_indices]

    # Note that logits and labels are usually of shape [batch, num_classes].
    # But for OOD data, when num_classes_ood > num_classes_ind, we need to
    # adjust labels to labels[:, :config.num_classes] to match the shape of
    # logits. That is just to avoid a shape mismatch. The resulting losses have
    # no meaning for OOD data, because OOD data do not belong to any IND class.
    losses = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
        logits=logits,
        labels=labels[:, :(len(label_indices) if label_indices else config.num_classes)],
        reduction=False)
    loss = jax.lax.psum(losses * mask, axis_name='batch')

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
    n = jax.lax.psum(mask, axis_name='batch')

    metric_args = jax.lax.all_gather([logits, labels, out['pre_logits'], mask],
                                     axis_name='batch')
    return ncorrect, loss, n, metric_args
def log_prob(self, value):
    batch_shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape)
    value = jnp.expand_dims(value, -1)
    value = jnp.broadcast_to(value, batch_shape + (1,))
    log_pmf = self.logits - logsumexp(self.logits, axis=-1, keepdims=True)
    log_pmf = jnp.broadcast_to(log_pmf, batch_shape + jnp.shape(log_pmf)[-1:])
    return jnp.take_along_axis(log_pmf, value, -1)[..., 0]
def sparse_categorical_crossentropy(y_true: jnp.ndarray,
                                    y_pred: jnp.ndarray,
                                    from_logits: bool = False) -> jnp.ndarray:
    if from_logits:
        y_pred = jax.nn.log_softmax(y_pred)
        return -jnp.take_along_axis(y_pred, y_true[..., None], axis=-1)[..., 0]
    else:
        # select output value
        y_pred = jnp.take_along_axis(y_pred, y_true[..., None], axis=-1)[..., 0]
        # calculate log
        y_pred = jnp.maximum(y_pred, utils.EPSILON)
        y_pred = jnp.log(y_pred)
        return -y_pred
def log_prob(self, value):
    if self._validate_args:
        self._validate_sample(value)
    value = np.expand_dims(value, -1)
    log_pmf = self.logits - logsumexp(self.logits, axis=-1, keepdims=True)
    value, log_pmf = promote_shapes(value, log_pmf)
    value = value[..., :1]
    return np.take_along_axis(log_pmf, value, -1)[..., 0]
def sampling_loop_body_fn(state):
    """Sampling loop state update."""
    i, sequences, cache, cur_token, ended, rng, tokens_to_logits_state = state

    # Split RNG for sampling.
    rng1, rng2 = random.split(rng)

    # Call fast-decoder model on current tokens to get raw next-position logits.
    logits, new_cache, new_tokens_to_logits_state = tokens_to_logits(
        cur_token, cache, internal_state=tokens_to_logits_state)
    logits = logits / temperature

    # Mask out the BOS token.
    if masked_tokens is not None:
        mask = common_utils.onehot(jnp.array(masked_tokens),
                                   num_classes=logits.shape[-1],
                                   on_value=LARGE_NEGATIVE)
        mask = jnp.sum(mask, axis=0)[None, :]  # Combine multiple masks together.
        logits = logits + mask

    # Apply the repetition penalty.
    if repetition_penalty != 1:
        logits = apply_repetition_penalty(
            sequences, logits, i,
            repetition_penalty=repetition_penalty,
            repetition_window=repetition_window,
            repetition_penalty_normalize=repetition_penalty_normalize)

    # Mask out everything but the top-k entries.
    if top_k is not None:
        # Compute top_k_index and top_k_threshold with shapes (batch_size, 1).
        top_k_index = jnp.argsort(logits, axis=-1)[:, ::-1][:, top_k - 1:top_k]
        top_k_threshold = jnp.take_along_axis(logits, top_k_index, axis=-1)
        logits = jnp.where(logits < top_k_threshold,
                           jnp.full_like(logits, LARGE_NEGATIVE),
                           logits)

    # Sample next token from logits.
    sample = multinomial(rng1, logits)
    next_token = sample.astype(jnp.int32)

    # Only use sampled tokens if we are past the out_of_prompt_marker.
    out_of_prompt = (sequences[:, i + 1] == out_of_prompt_marker)
    next_token = (next_token * out_of_prompt +
                  sequences[:, i + 1] * ~out_of_prompt)

    # If end-marker reached for batch item, only emit padding tokens.
    next_token = next_token[:, None]
    next_token_or_endpad = jnp.where(ended,
                                     jnp.full_like(next_token, pad_token),
                                     next_token)
    ended |= (next_token_or_endpad == end_marker)

    # Add current sampled tokens to recorded sequences.
    new_sequences = lax.dynamic_update_slice(sequences, next_token_or_endpad,
                                             (0, i + 1))

    return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended, rng2,
            new_tokens_to_logits_state)