Example #1
    def apply(self,
              input_ids,
              input_mask,
              type_ids,
              labels=None,
              *,
              config,
              n_classes,
              deterministic=False):
        """Applies BERT for sequence classification."""
        unused_sequence_output, pooled_output = BertModel(
            input_ids,
            input_mask,
            type_ids,
            config=config,
            deterministic=deterministic,
            name="bert")
        # TODO(kitaev): I think I'm missing dropout here
        logits = layers.OutputProjection(pooled_output,
                                         n_out=n_classes,
                                         kernel_init=kernel_initializer,
                                         name="classification")

        if labels is None:
            return logits
        elif logits.shape[-1] == 1:
            # Regression task
            loss = jnp.mean((logits[..., 0] - labels)**2)
            return {"loss": loss}
        else:
            # Classification task
            logits = nn.log_softmax(logits)
            loss = -jnp.mean(
                jnp.sum(onehot(labels, logits.shape[-1]) * logits, axis=-1))
            return {"loss": loss}
Example #2
def weighted_unnormalized_cross_entropy(logits, one_hot_targets, weights=None):
    """Compute weighted cross entropy and entropy for log probs and targets.

  This computes sum_(x,y) ce(x, y) for a single, potentially padded minibatch.
  If the minibatch is padded (that is, it contains null examples), it is assumed
  that weights is a binary mask where 0 indicates that the example is null.

  Args:
   logits: [batch, length, num_classes] float array.
   one_hot_targets: one hot vector of shape [batch, ..., num_classes].
   weights: None or array of shape [batch x ...] (rank of one_hot_targets -1).

  Returns:
    Cross entropy loss computed per example, shape [batch, ...].
  """
    if logits.ndim != one_hot_targets.ndim:
        raise ValueError(
            'Incorrect shapes. Got shape %s logits and %s one_hot_targets' %
            (str(logits.shape), str(one_hot_targets.shape)))

    loss = -jnp.sum(one_hot_targets * nn.log_softmax(logits), axis=-1)
    if weights is not None:
        if weights.ndim != one_hot_targets.ndim - 1:
            raise ValueError(
                'Incorrect shapes. Got shape %s weights and %s one_hot_targets'
                % (str(weights.shape), str(one_hot_targets.shape)))
        loss = loss * weights

    return loss
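
A minimal usage sketch for the function above (my own illustration, not part of the source; it assumes the `nn.log_softmax` used inside resolves to the Flax/JAX implementation, which behaves like `jax.nn.log_softmax`): a padded batch where zero weights mask out the padding positions.

import jax
import jax.numpy as jnp

logits = jnp.zeros((2, 3, 5))                       # [batch=2, length=3, num_classes=5]
one_hot_targets = jax.nn.one_hot(
    jnp.array([[1, 2, 0], [4, 0, 0]]), num_classes=5)
weights = jnp.array([[1., 1., 1.],                  # every position is a real token
                     [1., 0., 0.]])                 # last two positions are padding
per_token_loss = weighted_unnormalized_cross_entropy(
    logits, one_hot_targets, weights)               # shape [2, 3]; zeros at padded slots

With all-zero logits every unmasked entry equals log(5) ≈ 1.609, and the masked entries are exactly zero.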
Example #3
def apply_activation(intermediate_output, intermediate_activation):
    """Applies selected activation function to intermediate output."""
    if intermediate_activation is None:
        return intermediate_output

    if intermediate_activation == 'gelu':
        intermediate_output = nn.gelu(intermediate_output)
    elif intermediate_activation == 'relu':
        intermediate_output = nn.relu(intermediate_output)
    elif intermediate_activation == 'sigmoid':
        intermediate_output = nn.sigmoid(intermediate_output)
    elif intermediate_activation == 'softmax':
        intermediate_output = nn.softmax(intermediate_output)
    elif intermediate_activation == 'celu':
        intermediate_output = nn.celu(intermediate_output)
    elif intermediate_activation == 'elu':
        intermediate_output = nn.elu(intermediate_output)
    elif intermediate_activation == 'log_sigmoid':
        intermediate_output = nn.log_sigmoid(intermediate_output)
    elif intermediate_activation == 'log_softmax':
        intermediate_output = nn.log_softmax(intermediate_output)
    elif intermediate_activation == 'soft_sign':
        intermediate_output = nn.soft_sign(intermediate_output)
    elif intermediate_activation == 'softplus':
        intermediate_output = nn.softplus(intermediate_output)
    elif intermediate_activation == 'swish':
        intermediate_output = nn.swish(intermediate_output)
    elif intermediate_activation == 'tanh':
        intermediate_output = jnp.tanh(intermediate_output)
    else:
        raise NotImplementedError(
            '%s activation function is not yet supported.' %
            intermediate_activation)

    return intermediate_output
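
A quick sanity check of the dispatcher above (illustration only; it assumes the enclosing module imports `jax.numpy` as `jnp` and Flax's or JAX's `nn`, both of which provide the listed activations):

import jax.numpy as jnp

x = jnp.array([-1.0, 0.0, 2.0])
apply_activation(x, 'relu')    # [0., 0., 2.]
apply_activation(x, 'tanh')    # jnp.tanh applied elementwise
apply_activation(x, None)      # input returned unchanged
apply_activation(x, 'mish')    # raises NotImplementedError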
Example #4
 def rnn_cell(carry, x):
     newCarry, logits = jax.vmap(eval_cell)(carry[0], carry[1])
     sampleOut = jax.random.categorical(x, logits)
     sample = jax.nn.one_hot(sampleOut, inputDim)
     logProb = jnp.sum(nn.log_softmax(logits) * sample, axis=1)
     return (newCarry, sample), (jnp.nan_to_num(logProb,
                                                nan=-35), sampleOut)
Example #5
 def apply(self,
           x,
           num_classes,
           num_filters=64,
           num_layers=50,
           train=True,
           dtype=jnp.float32):
     if num_layers not in _block_size_options:
         raise ValueError('Please provide a valid number of layers')
     block_sizes = _block_size_options[num_layers]
     x = nn.Conv(x,
                 num_filters, (7, 7), (2, 2),
                 padding=[(3, 3), (3, 3)],
                 bias=False,
                 dtype=dtype,
                 name='init_conv')
     x = nn.BatchNorm(x,
                      use_running_average=not train,
                      momentum=0.9,
                      epsilon=1e-5,
                      dtype=dtype,
                      name='init_bn')
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
     for i, block_size in enumerate(block_sizes):
         for j in range(block_size):
             strides = (2, 2) if i > 0 and j == 0 else (1, 1)
             x = ResidualBlock(x,
                               num_filters * 2**i,
                               strides=strides,
                               train=train,
                               dtype=dtype)
     x = jnp.mean(x, axis=(1, 2))
     x = nn.Dense(x, num_classes)
     x = nn.log_softmax(x)
     return x
Example #6
  def apply(self, x, num_classes, *,
            stage_sizes,
            block_cls,
            num_filters=64,
            dtype=jnp.float32,
            act=nn.relu,
            train=True):
    conv = nn.Conv.partial(bias=False, dtype=dtype)
    norm = nn.BatchNorm.partial(
        use_running_average=not train,
        momentum=0.9, epsilon=1e-5,
        dtype=dtype)

    x = conv(x, num_filters, (7, 7), (2, 2),
             padding=[(3, 3), (3, 3)],
             name='conv_init')
    x = norm(x, name='bn_init')
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    for i, block_size in enumerate(stage_sizes):
      for j in range(block_size):
        strides = (2, 2) if i > 0 and j == 0 else (1, 1)
        x = block_cls(x, num_filters * 2 ** i,
                      strides=strides,
                      conv=conv,
                      norm=norm,
                      act=act)
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(x, num_classes, dtype=dtype)
    x = jnp.asarray(x, dtype)
    x = nn.log_softmax(x)
    return x
Example #7
 def apply(self, x, hidden_layers, hidden_dim, n_classes):
     x = jnp.reshape(x, (x.shape[0], -1))
     for layer in range(hidden_layers):
         x = nn.Dense(x, hidden_dim, name=f'fc{layer}')
         x = nn.relu(x)
     x = nn.Dense(x, n_classes, name=f'fc{hidden_layers}')
     preds = nn.log_softmax(x)
     return preds
Example #8
 def apply(self,
           x,
           num_classes,
           num_filters=64,
           num_layers=50,
           train=True,
           axis_name=None,
           axis_index_groups=None,
           dtype=jnp.float32,
           conv0_space_to_depth=False):
     if num_layers not in _block_size_options:
         raise ValueError('Please provide a valid number of layers')
     block_sizes = _block_size_options[num_layers]
     if conv0_space_to_depth:
         conv = SpaceToDepthConv.partial(block_size=(2, 2),
                                         padding=[(2, 1), (2, 1)])
     else:
         conv = nn.Conv.partial(padding=[(3, 3), (3, 3)])
     x = conv(x,
              num_filters,
              kernel_size=(7, 7),
              strides=(2, 2),
              bias=False,
              dtype=dtype,
              name='conv0')
     x = nn.BatchNorm(x,
                      use_running_average=not train,
                      momentum=0.9,
                      epsilon=1e-5,
                      name='init_bn',
                      axis_name=axis_name,
                      axis_index_groups=axis_index_groups,
                      dtype=dtype)
     x = nn.relu(x)  # MLPerf-required
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
     for i, block_size in enumerate(block_sizes):
         for j in range(block_size):
             strides = (2, 2) if i > 0 and j == 0 else (1, 1)
             x = ResidualBlock(x,
                               num_filters * 2**i,
                               strides=strides,
                               train=train,
                               axis_name=axis_name,
                               axis_index_groups=axis_index_groups,
                               dtype=dtype)
     x = jnp.mean(x, axis=(1, 2))
     x = nn.Dense(x,
                  num_classes,
                  kernel_init=nn.initializers.normal(),
                  dtype=dtype)
     x = nn.log_softmax(x)
     return x
Example #9
 def apply(self, x):
     x = nn.Conv(x, features=32, kernel_size=(3, 3))
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = nn.Conv(x, features=64, kernel_size=(3, 3))
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = x.reshape((x.shape[0], -1))  # flatten
     x = nn.Dense(x, features=256)
     x = nn.relu(x)
     x = nn.Dense(x, features=10)
     x = nn.log_softmax(x)
     return x
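
The `apply` methods in these examples belong to pre-Linen `flax.nn.Module` subclasses. As a rough sketch of how such a module was typically instantiated and called (following the pattern of the old Flax MNIST example; `CNN` is my own name for a module whose `apply` is the one above, and this assumes an old Flax release where `flax.nn` is still the module API):

import jax
import jax.numpy as jnp
from flax import nn

# Initialize parameters from an input shape, then wrap them in a callable Model.
_, initial_params = CNN.init_by_shape(
    jax.random.PRNGKey(0), [((1, 28, 28, 1), jnp.float32)])
model = nn.Model(CNN, initial_params)
log_probs = model(jnp.ones((8, 28, 28, 1)))   # invokes CNN.apply, shape [8, 10]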
Example #10
  def apply(self, x, num_classes, parameters, num_filters=64,
            train=True, axis_name=None, num_layers='34'):
    block_sizes = [3, 4, 6]
    data_format = 'channels_last'
    if ('conv0_space_to_depth' in parameters and
        parameters['conv0_space_to_depth']):
      # conv0 uses space-to-depth transform for TPU performance.
      x = func_conv0_space_to_depth(inputs=x, data_format=data_format,
                                    dtype=parameters['dtype'])
    else:
      x = conv2d_fixed_padding(
          inputs=x,
          filters=num_filters,
          kernel_size=7,
          strides=2,
          data_format=data_format,
          name='init_conv')

    replica_groups = _make_replica_groups(parameters)
    x = nn.BatchNorm(
        x,
        use_running_average=not train,
        momentum=0.9,
        epsilon=1e-5,
        name='init_bn',
        axis_name=axis_name,
        dtype=parameters['dtype'],
        axis_index_groups=replica_groups)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    for i, block_size in enumerate(block_sizes):
      for j in range(block_size):
        strides = 2 if i == 1 and j == 0 else 1
        use_projection = i > 0 and j == 0
        x = ResidualBlock(
            x,
            num_filters * 2**i,
            parameters,
            strides=strides,
            train=train,
            axis_name=axis_name,
            use_projection=use_projection,
            data_format=data_format)
    if num_layers == '34':
      x = jnp.mean(x, axis=(1, 2))
      x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
                   dtype=jnp.float32)  # TODO(deveci): dtype=dtype
      x = nn.log_softmax(x)
    return x
Example #11
    def _compute_metrics(self, masked_lm_logits, next_sentence_logits,
                         masked_lm_labels, masked_lm_weights,
                         next_sentence_labels, **unused_kwargs):
        """Computes the pre-training loss and its components."""
        masked_lm_logits = nn.log_softmax(masked_lm_logits)
        masked_lm_labels = onehot(masked_lm_labels.reshape((-1, )),
                                  masked_lm_logits.shape[-1])
        masked_lm_weights = masked_lm_weights.reshape((-1, ))
        masked_lm_loss = -jnp.sum(
            jnp.sum(masked_lm_logits * masked_lm_labels, axis=-1) *
            masked_lm_weights) / jnp.sum(masked_lm_weights)

        next_sentence_logits = nn.log_softmax(next_sentence_logits)
        next_sentence_labels = next_sentence_labels.reshape((-1, ))
        next_sentence_loss = -jnp.mean(
            jnp.sum(
                onehot(next_sentence_labels, next_sentence_logits.shape[-1]) *
                next_sentence_logits,
                axis=-1))
        return {
            'loss': masked_lm_loss + next_sentence_loss,
            'masked_lm_loss': masked_lm_loss,
            'next_sentence_loss': next_sentence_loss,
        }
Example #12
def softmax_cross_entropy(logits, onehot_labels):
    """Returns a cross-entropy loss tensor.

  Note that `onehot_labels` and `logits` must have the same shape,
  e.g. `[batch_size, num_classes]`. This does not perform a reduction on loss,
  so the returned loss is a `Tensor` of shape `[batch_size]`.

  Args:
    logits: Float-like logits outputs of the network.
    onehot_labels: Float-like one-hot-encoded labels.
  Returns:
    Loss `Tensor` of the same type as `logits`, which has shape `[batch_size]`.
  """
    log_softmax = -nn.log_softmax(logits)
    return jnp.sum(log_softmax * onehot_labels, axis=-1)
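
A concrete check of the relationship this implements (my own sketch, using `jax.nn.log_softmax`, which matches the `nn.log_softmax` used above): with one-hot labels the loss is just the negative log-probability assigned to the true class.

import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 1.0, 0.1]])
onehot_labels = jax.nn.one_hot(jnp.array([0]), 3)
loss = softmax_cross_entropy(logits, onehot_labels)          # shape [1]
assert jnp.allclose(loss, -jax.nn.log_softmax(logits)[0, 0])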
Example #13
def compute_weighted_cross_entropy(logits,
                                   labels):
  """Compute weighted cross entropy and entropy for log probs and labels.

  Args:
   logits: [batch, length, num_classes] float array.
   labels: categorical targets [batch, length] int array.
  Returns:
    Tuple of scalars of loss and per example loss.
  """
  log_probs = nn.log_softmax(logits)
  labels = jnp.reshape(labels, [-1])
  one_hot_labels = common_utils.onehot(labels, num_classes=2)
  per_example_loss = -jnp.sum(one_hot_labels * log_probs, axis=-1)
  loss = jnp.mean(per_example_loss)
  return (loss, per_example_loss)
Example #14
def get_masked_lm_output(logits, label_ids, label_weights):
  """Calculate masked_lm loss for pretrain task."""
  vocab_size = logits.shape[-1]

  label_ids = jnp.reshape(label_ids, (-1))
  label_weights = jnp.reshape(label_weights, (-1))
  one_hot_labels = common_utils.onehot(
      label_ids, vocab_size, on_value=1.0, off_value=0.0)

  log_probs = nn.log_softmax(logits)
  per_example_loss = -jnp.sum(log_probs * one_hot_labels, axis=-1)

  numerator = jnp.sum(label_weights * per_example_loss)
  denominator = jnp.sum(label_weights) + 1e-5
  loss = numerator / denominator
  return loss, per_example_loss, log_probs
Example #15
 def apply(self,
           x,
           num_classes,
           train=True,
           batch_stats=None,
           axis_name=None,
           dtype=jnp.float32):
   x = nn.BatchNorm(x,
                    batch_stats=batch_stats,
                    use_running_average=not train,
                    momentum=0.9, epsilon=1e-5,
                    name='init_bn', axis_name=axis_name, dtype=dtype)
   x = jnp.mean(x, axis=(1, 2))
   x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
                dtype=dtype)
   x = nn.log_softmax(x)
   return x
Example #16
def compute_logprob(inputs, model, mask_token=None):
    """Returns an array of log probabilities for the input sequences."""

    assert inputs.ndim == 2

    targets = inputs
    weights = jnp.where(targets != model.pad_token, 1, 0)
    if mask_token is not None:
        weights *= jnp.where(targets != mask_token, 1, 0)
    logits = model.score(inputs)
    assert logits.ndim == 3

    onehot_targets = common_utils.onehot(targets, logits.shape[-1])
    log_lik = jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
    log_lik *= weights
    log_prob = jnp.sum(log_lik, axis=-1)

    return log_prob
Example #17
def compute_weighted_cross_entropy(logits,
                                   targets,
                                   token_weights=None,
                                   example_weights=None):
    """Compute weighted cross entropy and entropy for log probs and targets.

  The loss is assumed to be sum_i example_weights[i] * logprob[i], where
  i indexes elements in the batch.

  logprob[i] is the log probability of sequence i, which is a weighted
  average of per-token log probabilities with weights according
  to token_weights. Typically token_weights is a mask for whether tokens are
  padding or not.

  Maximum likelihood training sets example_weights[i] = 1.
  Training with a REINFORCE-style objective may set example_weights[i]
  to any positive or negative number.

  Args:
   logits: [batch, length, num_classes] float array.
   targets: categorical targets [batch, length] int array.
   token_weights: None or array of shape [batch x length]
   example_weights: None or array of shape [batch_size]
  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            'Incorrect shapes. Got shape %s logits and %s targets' %
            (str(logits.shape), str(targets.shape)))
    onehot_targets = common_utils.onehot(targets, logits.shape[-1])
    loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
    normalizing_factor = onehot_targets.sum()
    if token_weights is not None:
        loss = loss * token_weights
        normalizing_factor = token_weights.sum()

    if example_weights is not None:
        loss = loss.sum(axis=1)
        loss *= example_weights

    return loss.sum(), normalizing_factor
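
A usage sketch (illustration only; it assumes `common_utils` is `flax.training.common_utils`, as in the other examples here): masking padding tokens with `token_weights` and averaging by the returned normalizing factor.

import jax.numpy as jnp

logits = jnp.zeros((2, 4, 8))                        # [batch, length, num_classes]
targets = jnp.array([[3, 1, 0, 0],
                     [5, 2, 7, 0]])                  # 0 is the padding id here
token_weights = (targets > 0).astype(jnp.float32)    # 1 for real tokens, 0 for padding
loss_sum, norm = compute_weighted_cross_entropy(
    logits, targets, token_weights=token_weights)
mean_loss = loss_sum / norm                          # average over the 5 real tokens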
Example #18
def compute_weighted_cross_entropy(logits, targets, num_classes, weights=None):
    """Compute weighted cross entropy and entropy for log probs and targets.

  Args:
   logits: [batch, num_classes] float array.
   targets: categorical targets [batch, length] int array.
   num_classes: int, num classes of problem.
   weights: None or array of shape [batch x length]

  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
    onehot_targets = common_utils.onehot(targets, num_classes)
    loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
    normalizing_factor = onehot_targets.sum()
    if weights is not None:
        loss = loss * weights
        normalizing_factor = weights.sum()

    return loss.sum(), normalizing_factor
Example #19
    def apply(self, node_x, edge_x, sources, targets):
        """Computes GNN forward pass.

    Args:
      node_x: node features with shape of `[num_nodes, num_features]`.
      edge_x: `None` or edge features with shape of `[num_edges, num_features]`.
      sources: Array of source node indices with shape of `[num_edges]`.
      targets: Array of target node indices with shape of `[num_edges]`.

    Returns:
      Output of shape `[num_nodes, num_features]`.
    """

        node_x = GraphConvBlock(node_x, edge_x, sources, targets, features=32)
        node_x = nn.relu(node_x)

        node_x = GraphConvBlock(node_x, edge_x, sources, targets, features=2)
        node_x = nn.log_softmax(node_x)

        return node_x
Example #20
def compute_weighted_cross_entropy(logits,
                                   targets,
                                   weights=None,
                                   label_smoothing=0.0):
    """Compute weighted cross entropy and entropy for log probs and targets.

  Args:
   logits: [batch, length, num_classes] float array.
   targets: categorical targets [batch, length] int array.
   weights: None or array of shape [batch, length].
   label_smoothing: label smoothing constant, used to determine the on and
     off values.

  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            'Incorrect shapes. Got shape %s logits and %s targets' %
            (str(logits.shape), str(targets.shape)))
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) +
        (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
    soft_targets = common_utils.onehot(targets,
                                       vocab_size,
                                       on_value=confidence,
                                       off_value=low_confidence)

    loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
    loss = loss - normalizing_constant

    normalizing_factor = jnp.prod(targets.shape)
    if weights is not None:
        loss = loss * weights
        normalizing_factor = weights.sum()

    return loss.sum(), normalizing_factor
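
To make the smoothing constants concrete (a worked example of my own, not from the source): with a 5-token vocabulary and `label_smoothing=0.1`, the true class gets probability 0.9, every other class gets 0.025, and `normalizing_constant` is the entropy of that smoothed target distribution, so a model that exactly matches the smoothed targets incurs (approximately) zero loss.

import jax.numpy as jnp

vocab_size = 5
label_smoothing = 0.1
confidence = 1.0 - label_smoothing                        # 0.9 on the true class
low_confidence = (1.0 - confidence) / (vocab_size - 1)    # 0.025 on each other class
normalizing_constant = -(
    confidence * jnp.log(confidence) +
    (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
print(float(normalizing_constant))                        # ~0.4637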
Example #21
    def apply(self, x, num_outputs):
        """Define the convolutional network architecture.

    Architecture originates from "Human-level control through deep reinforcement
    learning.", Nature 518, no. 7540 (2015): 529-533.
    Note that this is different from the one in "Playing atari with deep
    reinforcement learning.", arxiv.org/abs/1312.5602 (2013).
    """
        dtype = jnp.float32
        x = x.astype(dtype) / 255.
        x = nn.Conv(x,
                    features=32,
                    kernel_size=(8, 8),
                    strides=(4, 4),
                    name='conv1',
                    dtype=dtype)
        x = nn.relu(x)
        x = nn.Conv(x,
                    features=64,
                    kernel_size=(4, 4),
                    strides=(2, 2),
                    name='conv2',
                    dtype=dtype)
        x = nn.relu(x)
        x = nn.Conv(x,
                    features=64,
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    name='conv3',
                    dtype=dtype)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(x, features=512, name='hidden', dtype=dtype)
        x = nn.relu(x)
        # Network used to both estimate policy (logits) and expected state value.
        # See github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py
        logits = nn.Dense(x, features=num_outputs, name='logits', dtype=dtype)
        policy_log_probabilities = nn.log_softmax(logits)
        value = nn.Dense(x, features=1, name='value', dtype=dtype)
        return policy_log_probabilities, value
Example #22
def compute_weighted_cross_entropy(logits, targets, weights=None):
  """Compute weighted cross entropy and entropy for log probs and targets.

  Args:
   logits: [batch, length, num_classes] float array.
   targets: categorical targets [batch, length] int array.
   weights: None or array of shape [batch x length]

  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  onehot_targets = common_utils.onehot(targets, logits.shape[-1])
  loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
  normalizing_factor = onehot_targets.sum()
  if weights is not None:
    loss = loss * weights
    normalizing_factor = weights.sum()

  return loss.sum(), normalizing_factor
Example #23
    def apply(self,
              input_ids,
              input_mask,
              type_ids,
              labels=None,
              *,
              config,
              n_classes,
              deterministic=False):
        """Applies BERT for sequence classification."""
        unused_sequence_output, pooled_output = BertModel(
            input_ids,
            input_mask,
            type_ids,
            config=config,
            deterministic=deterministic,
            name='bert')
        pooled_output = nn.dropout(pooled_output,
                                   rate=config.hidden_dropout_prob,
                                   deterministic=deterministic)
        logits = layers.OutputProjection(pooled_output,
                                         n_out=n_classes,
                                         kernel_init=get_kernel_init(config),
                                         name='classification')

        if labels is None:
            return logits
        elif logits.shape[-1] == 1:
            # Regression task
            loss = jnp.mean((logits[..., 0] - labels)**2)
            return {'loss': loss}
        else:
            # Classification task
            logits = nn.log_softmax(logits)
            loss = -jnp.mean(
                jnp.sum(onehot(labels, logits.shape[-1]) * logits, axis=-1))
            return {'loss': loss}
Example #24
File: train.py Project: us/flax
def cross_entropy_loss(logits, labels, lengths):
  """Returns cross-entropy loss."""
  xe = jnp.sum(nn.log_softmax(logits) * labels, axis=-1)
  masked_xe = jnp.mean(mask_sequences(xe, lengths))
  return -masked_xe
Example #25
def compute_cross_entropy(logits, targets):
    onehot_targets = common_utils.onehot(targets, logits.shape[-1])
    loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
    return loss
Example #26
 def rnn_cell(carry, x):
     newCarry, out = rnnCell(carry[0], carry[1])
     logProb = nn.log_softmax(actFun(probDense(out)))
     logProb = jnp.sum(logProb * x, axis=-1)
     return (newCarry, x), (jnp.nan_to_num(logProb, nan=-35), out)
Example #27
 def batch_ce_loss(logits, targets):
   one_hot_targets = np.eye(4)[targets]
   loss = -np.sum(one_hot_targets * nn.log_softmax(logits), axis=-1)
   return loss
Example #28
def cross_entropy_loss(logits, labels):
    """Returns cross-entropy loss."""
    return -jnp.mean(jnp.sum(nn.log_softmax(logits) * labels[:, 1:], axis=-1))
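
One reason all of these examples call `nn.log_softmax` instead of composing `jnp.log` with `nn.softmax` is numerical stability: `log_softmax` computes `logits - logsumexp(logits)` directly, so large-magnitude logits do not underflow to `log(0)`. A quick illustration with the equivalent `jax.nn.log_softmax`:

import jax
import jax.numpy as jnp

logits = jnp.array([1000.0, 0.0, -1000.0])
print(jnp.log(jax.nn.softmax(logits)))   # [  0., -inf, -inf]  (underflow)
print(jax.nn.log_softmax(logits))        # [    0., -1000., -2000.]  (stable)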