Example no. 1
 def cross_entropy_loss(logits, labels):
     start_loss = optax.softmax_cross_entropy(
         logits[0], onehot(labels[0], num_classes=num_labels))
     end_loss = optax.softmax_cross_entropy(
         logits[1], onehot(labels[1], num_classes=num_labels))
     xentropy = (start_loss + end_loss) / 2.0
     return jnp.mean(xentropy)
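Example no. 1 leaves `onehot` and `num_labels` as free names; in the Flax question-answering scripts this loss pattern appears in, `onehot` usually comes from `flax.training.common_utils` and `logits` is a `(start_logits, end_logits)` pair. A minimal usage sketch under those assumptions, with hypothetical shapes:

import jax.numpy as jnp
import optax
from flax.training.common_utils import onehot  # assumed source of `onehot`

num_labels = 384   # hypothetical: maximum sequence length of the span task
batch_size = 8
start_logits = jnp.zeros((batch_size, num_labels))
end_logits = jnp.zeros((batch_size, num_labels))
start_positions = jnp.zeros((batch_size,), dtype=jnp.int32)
end_positions = jnp.zeros((batch_size,), dtype=jnp.int32)

loss = cross_entropy_loss((start_logits, end_logits),
                          (start_positions, end_positions))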
Example no. 2
    def test_small_byt5_integration_test(self):
        """
        For comparison run:
        >>> import t5  # pip install t5==0.9.1

        >>> path_to_byt5_small_checkpoint = '<fill_in>'
        >>> t5_model = t5.models.MtfModel(model_dir=path_to_byt5_small_checkpoint, batch_size=1, tpu=None)
        >>> vocab = t5.data.ByteVocabulary()
        >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
        """

        model = FlaxT5ForConditionalGeneration.from_pretrained(
            "google/byt5-small")
        tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")

        input_ids = tokenizer("Hello there", return_tensors="np").input_ids
        labels = tokenizer("Hi I am", return_tensors="np").input_ids

        decoder_input_ids = shift_tokens_right(
            labels, model.config.pad_token_id,
            model.config.decoder_start_token_id)

        logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits
        loss = optax.softmax_cross_entropy(logits,
                                           onehot(labels,
                                                  logits.shape[-1])).mean()

        mtf_score = -(labels.shape[-1] * loss.item())

        EXPECTED_SCORE = -60.7397
        self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
Example no. 3
    def eval_step(params, batch):
        labels = batch.pop("labels")

        outputs = model(**batch,
                        output_attentions=True,
                        params=params,
                        train=False)
        logits = outputs["logits"]

        # compute loss
        loss = optax.softmax_cross_entropy(logits,
                                           onehot(labels, logits.shape[-1]))

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)

        # compute head specialization
        specialization = compute_specialization_metric(
            jnp.swapaxes(jnp.stack(outputs["encoder_attentions"]), 0, 1))

        # summarize metrics
        metrics = {
            "loss": loss.mean(),
            "accuracy": accuracy.mean(),
            "specialization": specialization
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics
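Because the metrics are reduced with `jax.lax.pmean` over an axis named "batch", a step like this can only run inside a parallel-mapped function with a matching axis name. A sketch of the usual wrapping, with the Flax replication/sharding helpers taken as assumptions:

import jax
from flax import jax_utils
from flax.training.common_utils import shard

p_eval_step = jax.pmap(eval_step, axis_name="batch")

# Hypothetical driver loop: parameters are replicated across local devices and
# each host batch is sharded so its leading axis equals jax.local_device_count().
# replicated_params = jax_utils.replicate(params)
# for batch in eval_loader:
#     metrics = p_eval_step(replicated_params, shard(batch))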
Example no. 4
def compute_contrastive_loss(
    quantized_features, transformer_features, negative_indices, mask_time_indices, logits_temp, num_negatives
):
    batch_size, sequence_length, hidden_size = quantized_features.shape

    # take negative vectors from sampled indices
    quantized_negatives = quantized_features.reshape(-1, hidden_size)[negative_indices.reshape(-1)]
    quantized_negatives = quantized_negatives.reshape(
        batch_size, sequence_length, num_negatives, hidden_size
    ).transpose(2, 0, 1, 3)

    target_features = jnp.concatenate([quantized_features[None, :], quantized_negatives], axis=0)
    loss_logits = optax.cosine_similarity(transformer_features, target_features)
    loss_logits = loss_logits / logits_temp

    neg_is_pos = (quantized_features == quantized_negatives).all(-1)
    neg_is_pos = jnp.concatenate([jnp.full((1,) + loss_logits.shape[1:], False), neg_is_pos], axis=0)

    # make sure incorrectly sampled vectors don't contribute to loss
    loss_logits = jnp.where(neg_is_pos, -1e9, loss_logits)

    predictions = loss_logits.transpose(2, 1, 0).reshape(-1, loss_logits.shape[0])
    targets = ((1 - mask_time_indices) * -100).transpose(1, 0).flatten()

    target_mask = jnp.where(targets >= 0, 1.0, 0.0)
    contrastive_loss = optax.softmax_cross_entropy(predictions, onehot(targets, predictions.shape[-1])) * target_mask

    contrastive_loss = contrastive_loss.sum()

    return contrastive_loss
Example no. 5
 def loss_fn(self, label_batch: spec.Tensor,
             logits_batch: spec.Tensor) -> spec.Tensor:  # differentiable
     """Cross Entropy Loss"""
     one_hot_labels = jax.nn.one_hot(label_batch, num_classes=1000)
     xentropy = optax.softmax_cross_entropy(logits=logits_batch,
                                            labels=one_hot_labels)
     return xentropy
Example no. 6
def compute_loss(logits, labels):
    """Computes the mean softmax cross-entropy loss."""
    assert labels.shape == logits.shape, f'Got incompatible shapes: logits as {logits.shape}, labels as {labels.shape}'

    loss = optax.softmax_cross_entropy(logits=logits, labels=labels)
    loss = jnp.mean(loss)
    return loss
Example no. 7
 def loss_fn(_logits):  # pylint: disable=invalid-name
     if config.use_cpc:
         return (optax.softmax_cross_entropy(logits=_logits,
                                             labels=I) +
                 0.01 * jax.nn.logsumexp(_logits, axis=1)**2)
     else:
         return optax.sigmoid_binary_cross_entropy(
             logits=_logits, labels=I)
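`config.use_cpc` and the target matrix `I` come from the enclosing scope; in a CPC-style contrastive setup `I` is typically an identity matrix, so each row of `_logits` is scored against its own index, and the `0.01 * logsumexp(...)**2` term keeps the logit scale bounded. A sketch under the identity-target assumption:

import jax.numpy as jnp

batch_size = 16                               # hypothetical
I = jnp.eye(batch_size)                       # assumed targets: positives on the diagonal
scores = jnp.zeros((batch_size, batch_size))  # dummy pairwise similarity logits
per_example_loss = loss_fn(scores)            # shape (batch_size,) when config.use_cpc is True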
Example no. 8
def compute_metrics(logits, labels):
    loss = jnp.mean(optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])))
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics
Example no. 9
def eval_step(train_state, batch):
    images = rearrange(batch['images'], 'H W C N -> N H W C')
    variables = {'params': train_state.params}
    logits = train_state.apply_fn(variables, images, is_training=False)
    logits = logits.astype(jnp.float32)
    y = one_hot(batch['labels'])
    loss = jnp.mean(optax.softmax_cross_entropy(logits, y))
    loss = loss / jax.device_count()
    return jax.lax.psum(loss, axis_name='batch')
Example no. 10
        def loss_fn(params):
            labels = batch.pop("labels")

            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]

            # compute loss
            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()

            return loss
Example no. 11
def evaluate_predictions(logits, labels, mask):
    """Evaluates the model on the given dataset."""
    loss = optax.softmax_cross_entropy(logits, labels)
    loss = jnp.where(mask, loss, 0.)
    loss = jnp.sum(loss) / jnp.sum(mask)

    logits_match_labels = (jnp.argmax(logits, -1) == jnp.argmax(labels, -1))
    logits_match_labels = jnp.where(mask, logits_match_labels, 0.)
    accuracy = jnp.sum(logits_match_labels) / jnp.sum(mask)
    return loss, accuracy
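A small usage sketch for Example no. 11: `labels` are one-hot targets and `mask` flags the valid (non-padding) rows, so both metrics are averaged only over real examples:

import jax.numpy as jnp

logits = jnp.array([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
labels = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])  # one-hot targets
mask = jnp.array([True, True, False])                     # third row is padding
loss, accuracy = evaluate_predictions(logits, labels, mask)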
Example no. 12
def softmax_cross_entropy(logits, labels):
  """Computes softmax cross entropy given logits and one-hot class labels.

  Args:
    logits: Logit output values.
    labels: Ground truth one-hot-encoded labels.

  Returns:
    Loss values with shape `labels.shape[:-1]` (one entry per example).
  """
  return jnp.asarray(optax.softmax_cross_entropy(logits, labels))
Example no. 13
        def loss_fn(params):
            labels = batch.pop("labels")

            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]

            # compute loss, ignore padded input tokens
            label_mask = jnp.where(labels > 0, 1.0, 0.0)
            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask

            # take average
            loss = loss.sum() / label_mask.sum()

            return loss
Example no. 14
 def loss_fn(params):
     images = rearrange(batch['images'], 'H W C N -> N H W C')
     images = images.astype(jnp.bfloat16)
     logits = train_state.apply_fn(params, images, is_training=True)
     y = one_hot(batch['labels'])
     if 'mix_labels' in batch:
         y1 = one_hot(batch['mix_labels'])
         y = batch['ratio'][:, None] * y + (1. - batch['ratio'][:, None]) * y1
     y = optax.smooth_labels(y, label_smoothing)
     logits = logits.astype(jnp.float32)
     loss = jnp.mean(optax.softmax_cross_entropy(logits, y))
     scaled_loss = loss / jax.device_count()
     return scaled_loss, logits
Example no. 15
def loss_fn(forward, params, state, batch, l2=True):
    """Computes a regularized loss for the given batch."""
    logits, state = forward.apply(
        params, state, None, batch, is_training=True)
    labels = jax.nn.one_hot(batch[1], CLASS_NUM)
    logits = logits.reshape(len(labels), CLASS_NUM)  # match labels shape
    loss = optax.softmax_cross_entropy(logits=logits, labels=labels).mean()
    acc = (labels.argmax(1) == logits.argmax(1)).mean()

    if l2:
        l2_params = [p for ((mod_name, _), p) in tree.flatten_with_path(params)
                     if 'batchnorm' not in mod_name]
        loss = loss + 5e-4 * l2_loss(l2_params)
    return loss, (loss, state, acc)
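`l2_loss` is not shown in Example no. 15 (or in Example no. 19 below); in the Haiku training examples this pattern comes from, it is usually half the sum of squared entries over the filtered parameter leaves. A hedged sketch of that helper:

import jax.numpy as jnp

def l2_loss(params_list):
    # assumed definition: 0.5 * sum of squared entries over the given arrays
    return 0.5 * sum(jnp.sum(jnp.square(p)) for p in params_list)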
Example no. 16
    def eval_step(params, batch):
        labels = batch.pop("labels")

        logits = model(**batch, params=params, train=False)[0]

        # compute loss
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)

        # summarize metrics
        metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics
Example no. 17
        def loss_fn(params):
            labels = batch.pop("labels")

            outputs = state.apply_fn(**batch,
                                     output_attentions=True,
                                     params=params,
                                     dropout_rng=dropout_rng,
                                     train=True)
            logits = outputs["logits"]

            # compute loss
            loss = optax.softmax_cross_entropy(
                logits, onehot(labels, logits.shape[-1])).mean()

            return loss, jnp.swapaxes(jnp.stack(outputs["encoder_attentions"]),
                                      0, 1)
Example no. 18
    def eval_step(params, batch):
        labels = batch.pop("labels")

        logits = model(**batch, params=params, train=False)[0]

        # compute loss, ignore padded input tokens
        label_mask = jnp.where(labels > 0, 1.0, 0.0)
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask

        # summarize metrics
        metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
        metrics = jax.lax.psum(metrics, axis_name="batch")

        return metrics
Example no. 19
def loss_fn(
    params: hk.Params,
    state: hk.State,
    loss_scale: jmp.LossScale,
    batch: dataset.Batch,
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, hk.State]]:
    """Computes a regularized loss for the given batch."""
    logits, state = forward.apply(params, state, None, batch, is_training=True)
    labels = jax.nn.one_hot(batch['labels'], 1000)
    if FLAGS.train_smoothing:
        labels = optax.smooth_labels(labels, FLAGS.train_smoothing)
    loss = optax.softmax_cross_entropy(logits=logits, labels=labels).mean()
    l2_params = [
        p for ((mod_name, _), p) in tree.flatten_with_path(params)
        if 'batchnorm' not in mod_name
    ]
    loss = loss + FLAGS.train_weight_decay * l2_loss(l2_params)
    return loss_scale.scale(loss), (loss, state)
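Example no. 19 pairs the loss with mixed-precision loss scaling from jmp: the scaled loss is what gets differentiated, so the gradients must be unscaled before the optimizer update. A sketch of the surrounding update step, assuming a jmp loss scale and an optax optimizer (names other than the jmp/optax calls are hypothetical):

import jax
import jmp
import optax

loss_scale = jmp.StaticLossScale(2.0 ** 15)  # hypothetical; jmp.DynamicLossScale also works

grad_fn = jax.grad(loss_fn, has_aux=True)
# grads, (loss, new_state) = grad_fn(params, state, loss_scale, batch)
# grads = loss_scale.unscale(grads)            # undo the scaling before the update
# updates, new_opt_state = optimizer.update(grads, opt_state, params)
# new_params = optax.apply_updates(params, updates)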
Example no. 20
    def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
        """
        The label smoothing implementation is adapted from Flax's official example:
        https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
        """
        vocab_size = logits.shape[-1]
        confidence = 1.0 - label_smoothing_factor
        low_confidence = (1.0 - confidence) / (vocab_size - 1)
        normalizing_constant = -(
            confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
        )
        soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

        loss = optax.softmax_cross_entropy(logits, soft_labels)
        loss = loss - normalizing_constant

        # ignore padded tokens from loss
        loss = loss * padding_mask
        loss = loss.sum() / padding_mask.sum()
        return loss
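The subtracted `normalizing_constant` is exactly the entropy of the smoothed target distribution, so the loss bottoms out at zero when the predicted distribution matches the smoothed labels. A quick sanity check with hypothetical shapes, calling the `loss_fn` above (and assuming `onehot` is `flax.training.common_utils.onehot`):

import jax.numpy as jnp
from flax.training.common_utils import onehot

vocab_size, smoothing = 5, 0.1
confidence = 1.0 - smoothing
low_confidence = smoothing / (vocab_size - 1)
labels = jnp.array([[2, 3]])
padding_mask = jnp.ones(labels.shape, dtype=jnp.float32)
soft = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
# logits equal to the log of the smoothed target probabilities -> loss ~ 0
print(loss_fn(jnp.log(soft), labels, padding_mask, label_smoothing_factor=smoothing))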
Example no. 21
    def test_step(
        self: M,
        inputs,
        labels,
    ) -> eg.TestStepOutput[M]:
        model: M = self
        # flatten + scale
        inputs = jnp.reshape(inputs, (inputs.shape[0], -1)) / 255

        # forward
        logits, model = model.pred_step(inputs)

        # crossentropy loss
        target = jax.nn.one_hot(labels["target"], self.features_out)
        loss = optax.softmax_cross_entropy(logits, target).mean()

        # metrics
        logs = dict(
            acc=jnp.mean(jnp.argmax(logits, axis=-1) == labels["target"]),
            loss=loss,
        )

        return loss, logs, model
Example no. 22
def loss_fn(params, batch):
    logits = predict(params, batch['image'])
    return optax.softmax_cross_entropy(logits, batch['label']).mean(), logits
Example no. 23
def cross_entropy_loss(logits, labels):
    one_hot_labels = common_utils.onehot(labels, num_classes=NUM_CLASSES)
    xentropy = optax.softmax_cross_entropy(logits=logits,
                                           labels=one_hot_labels)
    return jnp.mean(xentropy)
Example no. 24
 def cross_entropy_loss(logits, labels):
     xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
     return jnp.mean(xentropy)
Example no. 25
 def loss_fn(logits, labels):
     shift_logits = logits[..., :-1, :]
     shift_labels = labels[..., 1:]
     loss = optax.softmax_cross_entropy(
         shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
     return loss.mean()
Example no. 26
 def call(self, y_true, y_preds):
     loss = jnp.sum(optax.softmax_cross_entropy(y_preds,
                                                y_true)) / y_preds.shape[0]
     return loss
Example no. 27
 def loss_fn(params):
     y_pred = model.apply({"params": params}, x)
     y_one_hot = jax.nn.one_hot(y, 10)
     loss = optax.softmax_cross_entropy(y_pred, y_one_hot).mean()
     return loss, y_pred
Example no. 28
 def loss_fn(params):
     logits = CNN().apply({'params': params}, batch['image'])
     loss = jnp.mean(
         optax.softmax_cross_entropy(logits=logits,
                                     labels=onehot(batch['label'], logits.shape[-1])))
     return loss, logits
Example no. 29
 def loss_fn(logits, labels):
     loss = optax.softmax_cross_entropy(logits,
                                        onehot(labels, logits.shape[-1]))
     return loss.mean()
Example no. 30
def loss_fn(logits: np.ndarray, labels: np.ndarray) -> np.ndarray:
    onehot_labels = onehot(labels, num_classes=logits.shape[-1])
    return optax.softmax_cross_entropy(logits, onehot_labels).mean()
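Most of the examples above build one-hot targets only to feed them to `optax.softmax_cross_entropy`. Recent optax releases also provide `optax.softmax_cross_entropy_with_integer_labels`, which takes integer class ids directly and skips the one-hot step (availability depends on the installed optax version):

import jax.numpy as jnp
import optax

def integer_label_loss_fn(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    # labels are integer class ids with shape logits.shape[:-1]
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()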