Example #1
def extract_max_elts(x):
    H, W, C = x.shape
    assert H % 2 == 0 and W % 2 == 0

    # Squeeze so that the grid elements are aligned on the last axis
    x_squeeze = util.pixel_squeeze(x)

    # Sort each grid
    x_sq_argsorted = x_squeeze.argsort(axis=-1)

    # Find the max index of each grid
    max_idx = x_sq_argsorted[..., -1:]

    # Get all of the elements that aren't the max.
    non_max_idx = x_sq_argsorted[..., :-1]

    # Sort the non-max indices so that we can pass the decoder consistent information.
    non_max_idx = non_max_idx.sort(axis=-1)

    # Take the max elements
    max_elts = jnp.take_along_axis(x_squeeze, max_idx,
                                   axis=-1).squeeze(axis=-1)
    assert max_elts.shape == (H // 2, W // 2, C)

    # Take the remaining elements
    non_max_elts = jnp.take_along_axis(x_squeeze, non_max_idx, axis=-1)

    # Subtract elts from the max so that we are left with positive elements
    non_max_elts = max_elts[..., None] - non_max_elts
    non_max_elts = non_max_elts.reshape((H // 2, W // 2, 3 * C))

    return max_elts, non_max_elts, max_idx.squeeze(axis=-1), non_max_idx
Example #2
def sparse_categorical_crossentropy(
    y_true: jnp.ndarray,
    y_pred: jnp.ndarray,
    from_logits: bool = False,
    check_bounds: bool = True,
) -> jnp.ndarray:

    n_classes = y_pred.shape[-1]

    if from_logits:
        y_pred = jax.nn.log_softmax(y_pred)
        loss = -jnp.take_along_axis(y_pred, y_true[..., None], axis=-1)[..., 0]
    else:
        # select output value
        y_pred = jnp.take_along_axis(y_pred, y_true[..., None], axis=-1)[..., 0]

        # calculate log
        y_pred = jnp.maximum(y_pred, types.EPSILON)
        y_pred = jnp.log(y_pred)
        loss = -y_pred

    if check_bounds:
        # set NaN where y_true is negative or larger/equal to the number of y_pred channels
        loss = jnp.where(y_true < 0, jnp.nan, loss)
        loss = jnp.where(y_true >= n_classes, jnp.nan, loss)

    return loss
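A minimal sketch (toy values of my own, not from the source) of the gather at the core of this loss: jnp.take_along_axis selects each example's predicted probability for its true class along the last axis.

import jax.numpy as jnp

y_pred = jnp.array([[0.7, 0.2, 0.1],
                    [0.1, 0.8, 0.1]])   # predicted probabilities, shape (2, 3)
y_true = jnp.array([0, 1])              # integer class labels, shape (2,)

# Pick the probability of the true class per example, then take the negative log.
picked = jnp.take_along_axis(y_pred, y_true[..., None], axis=-1)[..., 0]
loss = -jnp.log(picked)
print(loss)  # approximately [0.357, 0.223]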
Example #3
def solve_singles_vfi(EV_in, money, sigma, beta, i, wn, wt, sgrid):
    EV = wt[:, None] * EV_in[i, :] + wn[:, None] * EV_in[i + 1, :]

    consumption = money[:, None, :] - sgrid[None, :, None]

    consumption_negative = (consumption <= 0)

    utility = (np.maximum(consumption, 1e-8))**(1 - sigma) / (1 - sigma) - \
              1e9 * consumption_negative

    mega_matrix = utility + beta * EV
    #print(mega_matrix.shape)

    ind_s = mega_matrix.argmax(axis=1)
    V = np.take_along_axis(mega_matrix, ind_s[:, None, :], 1).squeeze(axis=1)

    s = sgrid[ind_s]
    c = money - s

    V_check = (c**(1 - sigma) /
               (1 - sigma)) + beta * np.take_along_axis(EV, ind_s, 0)

    assert np.allclose(V_check, V, atol=1e-5)

    return V, s
Example #4
def interp_manygrids(grids, xs, axis=0, return_wnext=True, trim=False):
    # This routine interpolates xs on many grids defined along `axis` in the
    # array grids (so for axis=0 the grids are grids[:, i, j, k] for all
    # i, j, k).

    assert np.all(np.diff(grids, axis=axis) > 0)
    '''
    if trim: xs = np.clip(xs[:,None,None],
                          grids.min(axis=axis,keepdims=True),
                          grids.max(axis=axis,keepdims=True))
    '''

    # this requires everything to be sorted
    mat = grids[..., None] < xs[(None, ) * grids.ndim + (slice(None), )]
    ng = grids.shape[axis]
    j = np.clip(np.sum(mat, axis=axis)[None, ...] - 1, 0, ng - 2)
    j = np.swapaxes(j, -1, axis).squeeze(axis=-1)
    grid_j = np.take_along_axis(grids, j, axis=axis)
    grid_jp = np.take_along_axis(grids, j + 1, axis=axis)

    xs_r = xs.reshape((1, ) * (axis - 1) + (xs.size, ) + (1, ) *
                      (grids.ndim - 1 - axis))

    wnext = (xs_r - grid_j) / (grid_jp - grid_j)
    return j, (wnext if return_wnext else 1 - wnext)
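As a standalone illustration (toy grid and queries of my own, not part of the routine above), the same bracketing idea in one dimension: count how many grid points lie below each query, clip to a valid interval, and gather the interval endpoints with take_along_axis.

import numpy as np

grid = np.array([0.0, 1.0, 2.0, 4.0])           # one sorted grid
xs = np.array([0.5, 3.0, 5.0])                  # query points
j = np.clip(np.sum(grid[:, None] < xs[None, :], axis=0) - 1, 0, grid.size - 2)
lo = np.take_along_axis(grid, j, axis=0)        # left endpoints
hi = np.take_along_axis(grid, j + 1, axis=0)    # right endpoints
w = (xs - lo) / (hi - lo)                       # interpolation weights
print(j, lo, hi, w)  # [0 2 2] [0. 2. 2.] [1. 4. 4.] [0.5 0.5 1.5]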
Example #5
    def _compute_loss_and_stats(params, model_out, use_elastic_loss=False):
        rgb_loss = ((model_out['rgb'] - batch['rgb'][..., :3])**2).mean()
        stats = {
            'loss/rgb': rgb_loss,
        }
        loss = rgb_loss
        if use_elastic_loss:
            elastic_fn = functools.partial(compute_elastic_loss,
                                           loss_type=elastic_loss_type)
            v_elastic_fn = jax.jit(vmap(vmap(jax.jit(elastic_fn))))
            weights = lax.stop_gradient(model_out['weights'])
            jacobian = model_out['warp_jacobian']
            # Pick the median point Jacobian.
            if elastic_reduce_method == 'median':
                depth_indices = model_utils.compute_depth_index(weights)
                jacobian = jnp.take_along_axis(
                    # Unsqueeze axes: sample axis, Jacobian row, Jacobian col.
                    jacobian,
                    depth_indices[..., None, None, None],
                    axis=-3)
            # Compute loss using Jacobian.
            elastic_loss, elastic_residual = v_elastic_fn(jacobian)
            # Multiply weight if weighting by density.
            if elastic_reduce_method == 'weight':
                elastic_loss = weights * elastic_loss
            elastic_loss = elastic_loss.sum(axis=-1).mean()
            stats['loss/elastic'] = elastic_loss
            stats['residual/elastic'] = jnp.mean(elastic_residual)
            loss += scalar_params.elastic_loss_weight * elastic_loss

        if use_warp_reg_loss:
            weights = lax.stop_gradient(model_out['weights'])
            depth_indices = model_utils.compute_depth_index(weights)
            warp_mag = ((model_out['points'] -
                         model_out['warped_points'])**2).sum(axis=-1)
            warp_reg_residual = jnp.take_along_axis(warp_mag,
                                                    depth_indices[..., None],
                                                    axis=-1)
            warp_reg_loss = utils.general_loss_with_squared_residual(
                warp_reg_residual,
                alpha=scalar_params.warp_reg_loss_alpha,
                scale=scalar_params.warp_reg_loss_scale).mean()
            stats['loss/warp_reg'] = warp_reg_loss
            stats['residual/warp_reg'] = jnp.mean(jnp.sqrt(warp_reg_residual))
            loss += scalar_params.warp_reg_loss_weight * warp_reg_loss

        if 'warp_jacobian' in model_out:
            jacobian = model_out['warp_jacobian']
            jacobian_det = jnp.linalg.det(jacobian)
            jacobian_div = utils.jacobian_to_div(jacobian)
            jacobian_curl = utils.jacobian_to_curl(jacobian)
            stats['metric/jacobian_det'] = jnp.mean(jacobian_det)
            stats['metric/jacobian_div'] = jnp.mean(jacobian_div)
            stats['metric/jacobian_curl'] = jnp.mean(
                jnp.linalg.norm(jacobian_curl, axis=-1))

        stats['loss/total'] = loss
        stats['metric/psnr'] = utils.compute_psnr(rgb_loss)
        return loss, stats
Example #6
 def sample(self, key, sample_shape=()):
     ps = Dirichlet(self.weights).sample(key, sample_shape=sample_shape)
     zs = np.expand_dims(Categorical(ps).sample(key), axis=-1)
     locs = np.broadcast_to(self.locs, sample_shape + self.batch_shape + self.event_shape + self.mixture_shape)
     scales = np.broadcast_to(self.scales, sample_shape + self.batch_shape + self.event_shape + self.mixture_shape)
     mlocs = np.squeeze(np.take_along_axis(locs, zs, axis=-1), axis=-1)
     mscales = np.squeeze(np.take_along_axis(scales, zs, axis=-1), axis=-1)
     return Normal(mlocs, mscales).sample(key)
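A small sketch (toy values of my own, shown with plain NumPy; the jax.numpy call behaves the same) of the component-selection step: take_along_axis picks each row's sampled mixture component from the stacked parameters.

import numpy as np

locs = np.array([[0.0, 1.0, 2.0],
                 [10.0, 11.0, 12.0]])   # component means, shape (2, 3)
zs = np.array([[2], [0]])               # sampled component index per row, shape (2, 1)
mlocs = np.squeeze(np.take_along_axis(locs, zs, axis=-1), axis=-1)
print(mlocs)  # [ 2. 10.]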
Example #7
def conditional_params_to_sample(rng, conditional_params):
    means, inv_scales, logit_probs = conditional_params
    _, h, w, c = means.shape
    rng_mix, rng_logistic = random.split(rng)
    mix_idx = np.broadcast_to(
        _gumbel_max(rng_mix, logit_probs)[..., np.newaxis],
        (h, w, c))[np.newaxis]
    means = np.take_along_axis(means, mix_idx, 0)[0]
    inv_scales = np.take_along_axis(inv_scales, mix_idx, 0)[0]
    return (
        means +
        random.logistic(rng_logistic, means.shape, means.dtype) / inv_scales)
Example #8
def solve_singles_egm(EV, EMU, li, agrid, sigma, beta, R, i, wn, wt, last):

    if not last:

        c_prescribed = (beta * R * EMU)**(-1 / sigma)
        m_implied = c_prescribed + agrid[:, None]
        a_implied = (1 / R) * (m_implied - li)

        a_i_min = a_implied[0, ...]
        a_i_max = a_implied[-1, ...]

        j, wn = interp_manygrids(a_implied, agrid, axis=0, trim=True)
        s_egm = agrid[j] * (1 - wn) + agrid[j + 1] * wn



        EV_egm = np.take_along_axis(EV, j, axis=0) * (1 - wn) + \
                 np.take_along_axis(EV, j + 1, axis=0) * wn

        agrid_r = agrid.reshape((agrid.size, ) + a_i_min.ndim * (1, ))
        i_above = (agrid_r >= a_i_max)
        i_below = (agrid_r <= a_i_min)
        i_egm = (~i_above) & (~i_below)

        s_below = 0.0
        s_above = agrid[-1]

        EV_below = EV[:1, ...]
        EV_above = EV[-1:, ...]

        s = s_egm * i_egm + s_above * i_above  # + 0.0*i_below

        EV_int = EV_egm * i_egm + EV_above * i_above + EV_below * i_below

        c = R * agrid_r + li - s

        assert np.all(s >= s_below)
        assert np.all(s <= s_above)
        assert np.all(c >= 0)

    else:

        c = R * agrid[:, None] + li
        s = 0.0 * c
        EV_int = EV  # it is 0 anyways

    MUc = (c**(-sigma))

    # this EV should be interpolated as well

    V = (c**(1 - sigma) / (1 - sigma)) + beta * EV_int

    return V, s, MUc
Example #9
 def topk(self: TensorType,
          k: int,
          sorted: bool = True) -> Tuple[TensorType, TensorType]:
     # argpartition not yet implemented
     # wrapping indexing not supported in take()
     n = self.raw.shape[-1]
     idx = np.take(np.argsort(self.raw), np.arange(n - k, n), axis=-1)
     val = np.take_along_axis(self.raw, idx, axis=-1)
     if sorted:
         perm = np.flip(np.argsort(val, axis=-1), axis=-1)
         idx = np.take_along_axis(idx, perm, axis=-1)
         val = np.take_along_axis(self.raw, idx, axis=-1)
     return type(self)(val), type(self)(idx)
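The same selection outside the wrapper class, as a quick sketch with toy values of my own: argsort orders each slice, np.take keeps the last k positions, and take_along_axis gathers their values.

import numpy as np

x = np.array([[3.0, 1.0, 4.0, 1.0, 5.0]])
k = 2
n = x.shape[-1]
idx = np.take(np.argsort(x), np.arange(n - k, n), axis=-1)   # indices of the k largest
val = np.take_along_axis(x, idx, axis=-1)
print(idx, val)  # [[2 4]] [[4. 5.]]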
Example #10
 def logpmf(self, x, p):
     batch_shape = lax.broadcast_shapes(x.shape, p.shape[:-1])
     # append a dimension to x
     # TODO: consider to convert x.dtype to int
     x = np.expand_dims(x, axis=-1)
     x = np.broadcast_to(x, batch_shape + (1, ))
     p = np.broadcast_to(p, batch_shape + p.shape[-1:])
     if self.is_logits:
         # normalize log prob
         p = p - logsumexp(p, axis=-1, keepdims=True)
         # gather and remove the trailing dimension
         return np.take_along_axis(p, x, axis=-1)[..., 0]
     else:
         return np.take_along_axis(np.log(p), x, axis=-1)[..., 0]
Example #11
 def log_prob(self, value):
     batch_shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape)
     value = jnp.expand_dims(value, axis=-1)
     value = jnp.broadcast_to(value, batch_shape + (1,))
     logits = _to_logits_multinom(self.probs)
     log_pmf = jnp.broadcast_to(logits, batch_shape + jnp.shape(logits)[-1:])
     return jnp.take_along_axis(log_pmf, value, axis=-1)[..., 0]
Example #12
    def evaluation_fn(params, images, labels):
        tiled_logits, out = model.apply({'params': flax.core.freeze(params)},
                                        images,
                                        train=False)

        loss_name = config.get('loss', 'sigmoid_xent')
        # TODO(dusenberrymw,zmariet): Clean up and generalize this.
        if loss_name == 'sigmoid_xent':
            ens_logits = batchensemble_utils.log_average_sigmoid_probs(
                jnp.asarray(jnp.split(tiled_logits, ens_size)))
            pre_logits = batchensemble_utils.log_average_sigmoid_probs(
                jnp.asarray(jnp.split(out['pre_logits'], ens_size)))
        else:  # softmax
            ens_logits = batchensemble_utils.log_average_softmax_probs(
                jnp.asarray(jnp.split(tiled_logits, ens_size)))
            pre_logits = batchensemble_utils.log_average_softmax_probs(
                jnp.asarray(jnp.split(out['pre_logits'], ens_size)))

        losses = getattr(train_utils,
                         loss_name)(logits=ens_logits,
                                    labels=labels[:, :config.num_classes],
                                    reduction=False)
        loss = jax.lax.psum(losses, axis_name='batch')

        top1_idx = jnp.argmax(ens_logits, axis=1)
        top1_correct = jnp.take_along_axis(labels, top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = batch_size_eval

        metric_args = jax.lax.all_gather([ens_logits, labels, pre_logits],
                                         axis_name='batch')
        return ncorrect, loss, n, metric_args
Example #13
def select(sequences, indices):
  """Given an array of shape (number_of_sequences, sequence_length, element_dimension),
  and a 1D array specifying which indices of each sequence to select, return
  a (number_of_sequences, element_dimension)-shaped array with the selected elements.

  Args:
    sequences: array with shape (number_of_sequences, sequence_length, element_dimension)
    indices: 1D array with length number_of_sequences

  Returns:
    selected_elements: array with shape (number_of_sequences, element_dimension)
  """

  assert len(indices) == sequences.shape[0]

  # shape indices properly
  indices_shaped = indices[:, jnp.newaxis, jnp.newaxis]

  # select element
  selected_elements = jnp.take_along_axis(sequences, indices_shaped, axis=1)

  # remove sequence dimension
  selected_elements = jnp.squeeze(selected_elements, axis=1)

  return selected_elements
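A quick usage sketch (toy shapes of my own, assuming the select function above is in scope):

import jax.numpy as jnp

sequences = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)   # (num_sequences=2, seq_len=3, dim=4)
indices = jnp.array([0, 2])                          # one timestep per sequence
selected = select(sequences, indices)
print(selected.shape)  # (2, 4)
print(selected)        # [[ 0  1  2  3]
                       #  [20 21 22 23]]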
Example #14
def index_Q_at_action(Q_values, actions):
    # Q_values [bsz, n_actions]
    # Actions [bsz,]
    idx = jnp.expand_dims(actions, -1)
    # pred_Q_values [bsz,]
    pred_Q_values = jnp.take_along_axis(Q_values, idx, -1).squeeze()
    return pred_Q_values
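A usage sketch with toy values of my own, assuming index_Q_at_action above is in scope:

import jax.numpy as jnp

Q_values = jnp.array([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])   # [bsz=2, n_actions=3]
actions = jnp.array([2, 0])               # [bsz]
print(index_Q_at_action(Q_values, actions))  # [3. 4.]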
Example #15
File: jax.py Project: yibit/eagerpy
 def take_along_axis(self: TensorType, index: TensorType,
                     axis: int) -> TensorType:
     if axis % self.ndim != self.ndim - 1:
         raise NotImplementedError(
             "take_along_axis is currently only supported for the last axis"
         )
     return type(self)(np.take_along_axis(self.raw, index.raw, axis=axis))
Example #16
    def cifar_10h_evaluation_fn(params, images, labels, mask):
        loss_as_str = config.get('loss', 'softmax_xent')
        ens_logits, ens_prelogits = ensemble_pred_fn(params, images,
                                                     loss_as_str)
        label_indices = config.get('label_indices')
        if label_indices:
            ens_logits = ens_logits[:, label_indices]

        losses = getattr(train_utils, loss_as_str)(logits=ens_logits,
                                                   labels=labels,
                                                   reduction=False)
        loss = jax.lax.psum(losses, axis_name='batch')

        top1_idx = jnp.argmax(ens_logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

        top1_correct = jnp.take_along_axis(one_hot_labels,
                                           top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = jax.lax.psum(one_hot_labels, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [ens_logits, labels, ens_prelogits, mask], axis_name='batch')
        return ncorrect, loss, n, metric_args
Example #17
 def Bellman_loss(self, Qnet_params, batch, actions):
     inputs, targets = batch
     preds = self.predict(Qnet_params, inputs)
     preds_select = jnp.take_along_axis(preds,
                                        jnp.expand_dims(actions, axis=1),
                                        axis=1)
     return jnp.mean(huber_loss(preds_select - targets))
Example #18
def _take_along_axis(array, indices,
                     axis):
  """Takes values from the input array by matching 1D index and data slices.

  This function serves the same purpose as jax.numpy.take_along_axis, except
  that it uses one-hot matrix multiplications under the hood on TPUs:
  (1) On TPUs, we use one-hot matrix multiplications to select elements from the
      array.
  (2) Otherwise, we fall back to jax.numpy.take_along_axis.

  Notes:
    - To simplify matters in case (1), we only support slices along the second
      or last dimensions.
    - We may wish to revisit (1) for very large arrays.

  Args:
    array: Source array.
    indices: Indices to take along each 1D slice of array.
    axis: Axis along which to take 1D slices.

  Returns:
    The indexed result.
  """
  if array.ndim != indices.ndim:
    raise ValueError(
        "indices and array must have the same number of dimensions; "
        f"{indices.ndim} vs. {array.ndim}.")

  if (axis != -1 and axis != array.ndim - 1 and  # Not last dimension
      axis != 1 and axis != -array.ndim + 1):  # Not second dimension
    raise ValueError(
        "Only slices along the second or last dimension are supported; "
        f"array.ndim = {array.ndim}, while axis = {axis}.")

  if _favor_one_hot_slices():
    one_hot_length = array.shape[axis]
    one_hot_indices = jax.nn.one_hot(indices, one_hot_length, axis=axis)

    if axis == -1 or array.ndim == 1:
      # Take i elements from last dimension (s).
      # We must use HIGHEST precision to accurately reproduce indexing
      # operations with matrix multiplications.
      result = jnp.einsum(
          "...s,...is->...i",
          array,
          one_hot_indices,
          precision=jax.lax.Precision.HIGHEST)
    else:
      # Take i elements from second dimension (s). We assume here that we always
      # want to slice along the second dimension.
      # We must use HIGHEST precision to accurately reproduce indexing
      # operations with matrix multiplications.
      result = jnp.einsum(
          "ns...,nis...->ni...",
          array,
          one_hot_indices,
          precision=jax.lax.Precision.HIGHEST)
    return jax.lax.convert_element_type(result, array.dtype)
  else:
    return jnp.take_along_axis(array, indices, axis=axis)
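A standalone check (toy values of my own) of the equivalence the docstring describes for the last-axis case: the one-hot einsum reproduces jnp.take_along_axis.

import jax
import jax.numpy as jnp

array = jnp.arange(12.0).reshape(3, 4)          # shape (n=3, s=4)
indices = jnp.array([[1, 3], [0, 0], [2, 1]])   # shape (n=3, i=2)

one_hot = jax.nn.one_hot(indices, array.shape[-1])           # (n, i, s)
via_onehot = jnp.einsum("...s,...is->...i", array, one_hot,
                        precision=jax.lax.Precision.HIGHEST)
via_gather = jnp.take_along_axis(array, indices, axis=-1)
print(jnp.allclose(via_onehot, via_gather))  # True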
Example #19
    def cifar_10h_evaluation_fn(params, states, images, labels, mask):
        variable_dict = {'params': flax.core.freeze(params), **states}
        logits, out = model.apply(variable_dict,
                                  images,
                                  train=False,
                                  rngs={
                                      'dropout':
                                      rng_dropout,
                                      'diag_noise_samples':
                                      diag_noise_rng,
                                      'standard_norm_noise_samples':
                                      standard_noise_rng
                                  })

        losses = getattr(train_utils,
                         config.get('loss', 'softmax_xent'))(logits=logits,
                                                             labels=labels,
                                                             reduction=False)
        loss = jax.lax.psum(losses, axis_name='batch')

        top1_idx = jnp.argmax(logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

        top1_correct = jnp.take_along_axis(one_hot_labels,
                                           top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = jax.lax.psum(one_hot_labels, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [logits, labels, out['pre_logits'], mask], axis_name='batch')
        return ncorrect, loss, n, metric_args
Example #20
def _extract_signal_patches(signal, window_length, hop=1, data_format="NCW"):
    if hasattr(window_length, "shape"):
        assert window_length.shape == ()
    else:
        assert not hasattr(window_length, "__len__")

    if data_format == "NCW":
        if signal.ndim == 2:
            signal_3d = signal[:, None, :]
        elif signal.ndim == 1:
            signal_3d = signal[None, None, :]
        else:
            signal_3d = signal

        N = (signal_3d.shape[2] - window_length) // hop + 1
        indices = jnp.arange(window_length) + jnp.expand_dims(jnp.arange(N) * hop, 1)
        indices = jnp.reshape(indices, [1, 1, N * window_length])
        patches = jnp.take_along_axis(signal_3d, indices, 2)
        output = jnp.reshape(patches, signal_3d.shape[:2] + (N, window_length))
        if signal.ndim == 1:
            return output[0, 0]
        elif signal.ndim == 2:
            return output[:, 0, :]
        else:
            return output
    else:
        raise ValueError(f"Unsupported data_format: {data_format}")
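A usage sketch (toy signal of my own, assuming the function above is in scope): a length-10 signal with window_length=4 and hop=2 yields four overlapping patches.

import jax.numpy as jnp

signal = jnp.arange(10.0)
patches = _extract_signal_patches(signal, window_length=4, hop=2)
print(patches.shape)            # (4, 4)
print(patches[0], patches[1])   # [0. 1. 2. 3.] [2. 3. 4. 5.]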
Example #21
    def sample_with_intermediates(self, key, sample_shape=()):
        """
        Same as ``sample`` except that the sampled mixture components are also returned.

        :param jax.random.PRNGKey key: the rng_key key to be used for the distribution.
        :param tuple sample_shape: the sample shape for the distribution.
        :return: Tuple (samples, indices)
        :rtype: tuple
        """
        assert is_prng_key(key)
        key_comp, key_ind = jax.random.split(key)
        # Samples from component distribution will have shape:
        #  (*sample_shape, *batch_shape, mixture_size, *event_shape)
        samples = self.component_distribution.expand(
            sample_shape + self.batch_shape +
            (self.mixture_size, )).sample(key_comp)
        # Sample selection indices from the categorical (shape will be sample_shape)
        indices = self.mixing_distribution.expand(
            sample_shape + self.batch_shape).sample(key_ind)
        n_expand = self.event_dim + 1
        indices_expanded = indices.reshape(indices.shape + (1, ) * n_expand)
        # Select samples according to indices samples from categorical
        samples_selected = jnp.take_along_axis(samples,
                                               indices=indices_expanded,
                                               axis=self.mixture_dim)
        # Final sample shape (*sample_shape, *batch_shape, *event_shape)
        return jnp.squeeze(samples_selected, axis=self.mixture_dim), indices
Example #22
def cross_entropy(logits, targets, axis=-1):
    logprobs = nn.log_softmax(logits, axis=axis)
    nll = np.take_along_axis(logprobs,
                             np.expand_dims(targets, axis=axis),
                             axis=axis)
    ce = -np.mean(nll)
    return ce
Example #23
def l1_unit_projection(x):
  """Euclidean projection to L1 unit ball i.e. argmin_{|v|_1<= 1} |x-v|_2.

  Args:
    x: An array of size dim x num.

  Returns:
    An array of size dim x num, the projection to the unit L1 ball.
  """
  # https://dl.acm.org/citation.cfm?id=1390191
  xshape = x.shape
  if len(x.shape) == 1:
    x = x.reshape(-1, 1)
  eshape = x.shape
  v = jnp.abs(x.reshape((-1, eshape[-1])))
  u = jnp.sort(v, axis=0)
  u = u[::-1, :]  # descending
  arange = (1 + jnp.arange(eshape[0])).reshape((-1, 1))
  usum = (jnp.cumsum(u, axis=0) - 1) / arange
  rho = jnp.max(((u - usum) > 0) * arange - 1, axis=0, keepdims=True)
  thx = jnp.take_along_axis(usum, rho, axis=0)
  w = (v - thx).clip(a_min=0)
  w = jnp.where(jnp.linalg.norm(v, ord=1, axis=0, keepdims=True) > 1, w, v)
  x = w.reshape(eshape) * jnp.sign(x)
  return x.reshape(xshape)
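A quick sanity check (toy points of my own, assuming the function above is in scope): columns outside the L1 ball are projected onto it, and columns already inside are returned unchanged.

import jax.numpy as jnp

x = jnp.array([[0.5, 0.3],
               [1.0, 0.2]])                 # dim x num = 2 x 2
v = l1_unit_projection(x)
print(jnp.linalg.norm(v, ord=1, axis=0))    # [1.  0.5]: every column now has L1 norm <= 1
print(v[:, 1])                              # [0.3 0.2]: the in-ball column is untouched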
Example #24
 def sampling_loop_body_fn(state):
   """Sampling loop state update."""
   i, sequences, cache, cur_token, ended, rng = state
   # Split RNG for sampling.
   rng1, rng2 = random.split(rng)
   # Call fast-decoder model on current tokens to get next-position logits.
   logits, new_cache = tokens_to_logits(cur_token, cache)
   # Sample next token from logits.
   # TODO(levskaya): add top-p "nucleus" sampling option.
   if topk:
     # Get top-k logits and their indices, sample within these top-k tokens.
     topk_logits, topk_idxs = lax.top_k(logits, topk)
     topk_token = jnp.expand_dims(
         random.categorical(rng1, topk_logits / temperature).astype(jnp.int32),
         axis=-1)
     # Return the original indices corresponding to the sampled top-k tokens.
     next_token = jnp.squeeze(
         jnp.take_along_axis(topk_idxs, topk_token, axis=-1), axis=-1)
   else:
     next_token = random.categorical(rng1,
                                     logits / temperature).astype(jnp.int32)
   # Only use sampled tokens if we're past provided prefix tokens.
   out_of_prompt = (sequences[:, i + 1] == 0)
   next_token = (
       next_token * out_of_prompt + sequences[:, i + 1] * ~out_of_prompt)
   # If end-marker reached for batch item, only emit padding tokens.
   next_token_or_endpad = next_token * ~ended
   ended |= (next_token_or_endpad == end_marker)
   # Add current sampled tokens to recorded sequences.
   new_sequences = lax.dynamic_update_slice(sequences, next_token_or_endpad,
                                            (0, i + 1))
   return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended, rng2)
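A standalone sketch (toy logits and PRNG key of my own) of the top-k trick used above: sample among the k best logits, then map the winner back to its original vocabulary index with take_along_axis.

import jax.numpy as jnp
from jax import lax, random

logits = jnp.array([[0.1, 2.0, -1.0, 3.0, 0.5]])   # (batch=1, vocab=5)
topk_logits, topk_idxs = lax.top_k(logits, 2)       # two best tokens per row
topk_token = random.categorical(random.PRNGKey(0), topk_logits)[..., None]
next_token = jnp.take_along_axis(topk_idxs, topk_token, axis=-1)[..., 0]
print(topk_idxs)    # [[3 1]]
print(next_token)   # [3] or [1], depending on the draw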
Example #25
    def cifar_10h_evaluation_fn(params, images, labels, mask):
        logits, out = model.apply({'params': flax.core.freeze(params)},
                                  images,
                                  train=False)
        label_indices = config.get('label_indices')
        if label_indices:
            logits = logits[:, label_indices]

        losses = getattr(train_utils,
                         config.get('loss', 'softmax_xent'))(logits=logits,
                                                             labels=labels,
                                                             reduction=False)
        loss = jax.lax.psum(losses, axis_name='batch')

        top1_idx = jnp.argmax(logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

        top1_correct = jnp.take_along_axis(one_hot_labels,
                                           top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = jax.lax.psum(one_hot_labels, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [logits, labels, out['pre_logits'], mask], axis_name='batch')
        return ncorrect, loss, n, metric_args
Example #26
    def evaluation_fn(params, images, labels, mask):
        # Ignore the entries with all zero labels for evaluation.
        mask *= labels.max(axis=1)
        logits, out = model.apply({'params': flax.core.freeze(params)},
                                  images,
                                  train=False)
        label_indices = config.get('label_indices')
        logging.info('!!! mask %s, label_indices %s', mask, label_indices)
        if label_indices:
            logits = logits[:, label_indices]

        # Note that logits and labels are usually of the shape [batch,num_classes].
        # But for OOD data, when num_classes_ood > num_classes_ind, we need to
        # adjust labels to labels[:, :config.num_classes] to match the shape of
        # logits. That is just to avoid a shape mismatch. The output losses do not
        # have any meaning for OOD data, because OOD samples do not belong to any IND class.
        losses = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
            logits=logits,
            labels=labels[:, :(
                len(label_indices) if label_indices else config.num_classes)],
            reduction=False)
        loss = jax.lax.psum(losses * mask, axis_name='batch')

        top1_idx = jnp.argmax(logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        top1_correct = jnp.take_along_axis(labels, top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
        n = jax.lax.psum(mask, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [logits, labels, out['pre_logits'], mask], axis_name='batch')
        return ncorrect, loss, n, metric_args
Example #27
 def log_prob(self, value):
     batch_shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape)
     value = jnp.expand_dims(value, -1)
     value = jnp.broadcast_to(value, batch_shape + (1,))
     log_pmf = self.logits - logsumexp(self.logits, axis=-1, keepdims=True)
     log_pmf = jnp.broadcast_to(log_pmf, batch_shape + jnp.shape(log_pmf)[-1:])
     return jnp.take_along_axis(log_pmf, value, -1)[..., 0]
Example #28
def sparse_categorical_crossentropy(y_true: jnp.ndarray,
                                    y_pred: jnp.ndarray,
                                    from_logits: bool = False) -> jnp.ndarray:

    if from_logits:
        y_pred = jax.nn.log_softmax(y_pred)
        return -jnp.take_along_axis(y_pred, y_true[..., None], axis=-1)[..., 0]

    else:
        # select output value
        y_pred = jnp.take_along_axis(y_pred, y_true[..., None], axis=-1)[..., 0]

        # calculate log
        y_pred = jnp.maximum(y_pred, utils.EPSILON)
        y_pred = jnp.log(y_pred)
        return -y_pred
Example #29
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     value = np.expand_dims(value, -1)
     log_pmf = self.logits - logsumexp(self.logits, axis=-1, keepdims=True)
     value, log_pmf = promote_shapes(value, log_pmf)
     value = value[..., :1]
     return np.take_along_axis(log_pmf, value, -1)[..., 0]
Example #30
    def sampling_loop_body_fn(state):
        """Sampling loop state update."""
        i, sequences, cache, cur_token, ended, rng, tokens_to_logits_state = state

        # Split RNG for sampling.
        rng1, rng2 = random.split(rng)

        # Call fast-decoder model on current tokens to get raw next-position logits.
        logits, new_cache, new_tokens_to_logits_state = tokens_to_logits(
            cur_token, cache, internal_state=tokens_to_logits_state)
        logits = logits / temperature

        # Mask out any explicitly masked tokens (e.g. the BOS token).
        if masked_tokens is not None:
            mask = common_utils.onehot(jnp.array(masked_tokens),
                                       num_classes=logits.shape[-1],
                                       on_value=LARGE_NEGATIVE)
            mask = jnp.sum(mask,
                           axis=0)[None, :]  # Combine multiple masks together
            logits = logits + mask

        # Apply the repetition penalty.
        if repetition_penalty != 1:
            logits = apply_repetition_penalty(
                sequences,
                logits,
                i,
                repetition_penalty=repetition_penalty,
                repetition_window=repetition_window,
                repetition_penalty_normalize=repetition_penalty_normalize)

        # Mask out everything but the top-k entries.
        if top_k is not None:
            # Compute top_k_index and top_k_threshold with shapes (batch_size, 1).
            top_k_index = jnp.argsort(logits,
                                      axis=-1)[:, ::-1][:, top_k - 1:top_k]
            top_k_threshold = jnp.take_along_axis(logits, top_k_index, axis=-1)
            logits = jnp.where(logits < top_k_threshold,
                               jnp.full_like(logits, LARGE_NEGATIVE), logits)
        # Sample next token from logits.
        sample = multinomial(rng1, logits)
        next_token = sample.astype(jnp.int32)
        # Only use sampled tokens if we are past the out_of_prompt_marker.
        out_of_prompt = (sequences[:, i + 1] == out_of_prompt_marker)
        next_token = (next_token * out_of_prompt +
                      sequences[:, i + 1] * ~out_of_prompt)
        # If end-marker reached for batch item, only emit padding tokens.
        next_token = next_token[:, None]
        next_token_or_endpad = jnp.where(ended,
                                         jnp.full_like(next_token, pad_token),
                                         next_token)
        ended |= (next_token_or_endpad == end_marker)
        # Add current sampled tokens to recorded sequences.
        new_sequences = lax.dynamic_update_slice(sequences,
                                                 next_token_or_endpad,
                                                 (0, i + 1))
        return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended,
                rng2, new_tokens_to_logits_state)