def cross_entropy_loss(logits, labels):
    start_loss = optax.softmax_cross_entropy(
        logits[0], onehot(labels[0], num_classes=num_labels))
    end_loss = optax.softmax_cross_entropy(
        logits[1], onehot(labels[1], num_classes=num_labels))
    xentropy = (start_loss + end_loss) / 2.0
    return jnp.mean(xentropy)
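# Hedged usage sketch for cross_entropy_loss above, for a QA-style model where
# `logits` is a (start_logits, end_logits) pair and `labels` a matching
# (start_positions, end_positions) pair. The module-level names it relies on
# (`optax`, `onehot`, `jnp`, `num_labels`) are assumed here; `num_labels` is a
# hypothetical constant standing in for the number of answer positions.
import jax.numpy as jnp
import optax
from flax.training.common_utils import onehot

num_labels = 16  # assumed global used inside cross_entropy_loss
batch = 4
start_logits = jnp.zeros((batch, num_labels))
end_logits = jnp.zeros((batch, num_labels))
start_positions = jnp.zeros((batch,), dtype=jnp.int32)
end_positions = jnp.ones((batch,), dtype=jnp.int32)

qa_loss = cross_entropy_loss((start_logits, end_logits),
                             (start_positions, end_positions))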
def compute_metrics(
    masked_lm_logits: jnp.ndarray,
    next_sentence_logits: jnp.ndarray,
    masked_lm_labels: jnp.ndarray,
    masked_lm_weights: jnp.ndarray,
    next_sentence_labels: jnp.ndarray,
):
    """Computes the pre-training loss and its components."""
    masked_lm_logits = nn.log_softmax(masked_lm_logits)
    masked_lm_labels = onehot(
        masked_lm_labels.reshape((-1,)), masked_lm_logits.shape[-1])
    masked_lm_weights = masked_lm_weights.reshape((-1,))
    masked_lm_loss = -jnp.sum(
        jnp.sum(masked_lm_logits * masked_lm_labels, axis=-1) *
        masked_lm_weights) / jnp.sum(masked_lm_weights)

    next_sentence_logits = nn.log_softmax(next_sentence_logits)
    next_sentence_labels = next_sentence_labels.reshape((-1,))
    next_sentence_loss = -jnp.mean(
        jnp.sum(
            onehot(next_sentence_labels, next_sentence_logits.shape[-1]) *
            next_sentence_logits,
            axis=-1,
        ))
    return {
        "loss": masked_lm_loss + next_sentence_loss,
        "masked_lm_loss": masked_lm_loss,
        "next_sentence_loss": next_sentence_loss,
    }
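# Minimal sketch exercising compute_metrics above on dummy data. The shapes
# are illustrative only, and it assumes `nn` is flax.linen and `onehot` is
# flax.training.common_utils.onehot; the snippet itself does not show its
# imports or the pre-training pipeline that produces these logits.
import jax.numpy as jnp
from flax import linen as nn
from flax.training.common_utils import onehot

batch, max_predictions, vocab = 2, 5, 32
dummy_metrics = compute_metrics(
    masked_lm_logits=jnp.zeros((batch * max_predictions, vocab)),
    next_sentence_logits=jnp.zeros((batch, 2)),
    masked_lm_labels=jnp.ones((batch, max_predictions), dtype=jnp.int32),
    masked_lm_weights=jnp.ones((batch, max_predictions)),
    next_sentence_labels=jnp.zeros((batch,), dtype=jnp.int32),
)
# dummy_metrics["loss"] == masked_lm_loss + next_sentence_loss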
def apply(self,
          input_ids,
          input_mask,
          type_ids,
          labels=None,
          *,
          config,
          n_classes,
          deterministic=False):
    """Applies BERT for sequence classification."""
    unused_sequence_output, pooled_output = BertModel(
        input_ids,
        input_mask,
        type_ids,
        config=config,
        deterministic=deterministic,
        name="bert")
    # TODO(kitaev): I think I'm missing dropout here
    logits = layers.OutputProjection(
        pooled_output,
        n_out=n_classes,
        kernel_init=kernel_initializer,
        name="classification")
    if labels is None:
        return logits
    elif logits.shape[-1] == 1:
        # Regression task
        loss = jnp.mean((logits[Ellipsis, 0] - labels)**2)
        return {"loss": loss}
    else:
        # Classification task
        logits = nn.log_softmax(logits)
        loss = -jnp.mean(
            jnp.sum(onehot(labels, logits.shape[-1]) * logits, axis=-1))
        return {"loss": loss}
def sample(inputs, optimizer):
    next_inputs = inputs
    output = []
    batch_size = 1
    carry1 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0),
                                          (batch_size,), 512)
    carry2 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0),
                                          (batch_size,), 512)
    carry3 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0),
                                          (batch_size,), 512)
    carry = [carry1, carry2, carry3]

    def inference(model, carry):
        carry, rnn_output = model(inputs=next_inputs, train=False,
                                  carry_pred=carry)
        return carry, rnn_output

    for i in range(200):
        carry, rnn_output = inference(optimizer.target, carry)
        output.append(jnp.argmax(rnn_output, axis=-1))
        # Select the argmax as the next input.
        next_inputs = jnp.expand_dims(
            common_utils.onehot(jnp.argmax(rnn_output),
                                params['vocab_length']),
            axis=0)
    return output
def __call__(
    self,
    input_ids: jnp.ndarray,
    input_mask: jnp.ndarray,
    type_ids: jnp.ndarray,
    labels: jnp.ndarray = None,
    *,
    deterministic: bool = False,
):
    """Applies BERT for sequence classification."""
    bert = BertModel(config=self.config, name="bert")
    _, pooled_output = bert(
        input_ids, input_mask, type_ids, deterministic=deterministic)
    pooled_output = nn.Dropout(
        rate=self.config.hidden_dropout_prob,
        deterministic=deterministic)(pooled_output)
    logits = layers.OutputProjection(
        n_out=self.n_classes,
        kernel_init=get_kernel_init(self.config),
        name="classification",
    )(pooled_output)
    if labels is None:
        return logits
    elif logits.shape[-1] == 1:
        # Regression task
        loss = jnp.mean((logits[..., 0] - labels)**2)
        return {"loss": loss}
    else:
        # Classification task
        logits = nn.log_softmax(logits)
        loss = -jnp.mean(
            jnp.sum(onehot(labels, logits.shape[-1]) * logits, axis=-1))
        return {"loss": loss}
def loss_fn(params): """loss function used for training.""" logits = models.Transformer(config).apply( {"params": params}, inputs, targets, inputs_positions=inputs_positions, targets_positions=targets_positions, inputs_segmentation=inputs_segmentation, targets_segmentation=targets_segmentation, rngs={"dropout": dropout_rng}) vocab_size = logits.shape[-1] confidence = 1.0 - label_smoothing low_confidence = (1.0 - confidence) / (vocab_size - 1) normalizing_constant = -( confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)) soft_targets = common_utils.onehot( targets, vocab_size, on_value=confidence, off_value=low_confidence) loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1) loss = loss - normalizing_constant loss = loss * weights normalizing_factor = weights.sum() mean_loss = loss.sum() / normalizing_factor return mean_loss, logits
def cross_entropy(logits, targets, weights=None, label_smoothing=0.0):
    """Compute cross entropy and entropy for log probs and targets.

    Args:
      logits: [batch, length, num_classes] float array.
      targets: categorical targets [batch, length] int array.
      weights: None or array of shape [batch, length].
      label_smoothing: label smoothing constant, used to determine the on and
        off values.

    Returns:
      Tuple of scalar loss and batch normalizing factor.
    """
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            "Incorrect shapes. Got shape %s logits and %s targets" %
            (str(logits.shape), str(targets.shape)))
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) +
        (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
    soft_targets = common_utils.onehot(
        targets, vocab_size, on_value=confidence, off_value=low_confidence)
    loss = -jnp.sum(soft_targets * log_softmax(logits), axis=-1)
    loss = loss - normalizing_constant
    if weights is not None:
        loss = loss * weights
        normalizing_factor = weights.sum()
    else:
        normalizing_factor = np.prod(targets.shape)
    return loss.sum(), normalizing_factor
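# Hedged usage sketch for the label-smoothed cross_entropy above. It assumes
# `log_softmax` resolves to jax.nn.log_softmax, `np` is numpy, and
# `common_utils` is flax.training.common_utils; the snippet itself does not
# show its imports, and the shapes here are illustrative only.
import numpy as np
import jax.numpy as jnp
from jax.nn import log_softmax
from flax.training import common_utils

batch, length, vocab = 2, 3, 7
example_logits = jnp.zeros((batch, length, vocab))      # uniform predictions
example_targets = jnp.zeros((batch, length), dtype=jnp.int32)
example_weights = jnp.ones((batch, length))             # no padding

total_loss, denom = cross_entropy(
    example_logits, example_targets, weights=example_weights,
    label_smoothing=0.1)
mean_loss = total_loss / denom  # per-token loss, shifted by the smoothed-target
                                # entropy (normalizing_constant)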
def compute_weighted_cross_entropy(logits, targets, weights=None):
    """Compute weighted cross entropy and entropy for log probs and targets.

    Args:
      logits: `[batch, length, num_classes]` float array.
      targets: categorical targets `[batch, length]` int array.
      weights: None or array of shape [batch, length, 1].

    Returns:
      Tuple of scalar loss and batch normalizing factor.
    """
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            'Incorrect shapes. Got shape %s logits and %s targets' %
            (str(logits.shape), str(targets.shape)))
    if logits.shape[1] != targets.shape[1]:
        # Truncate logits.
        logits = logits[:, :targets.shape[1]]
    onehot_targets = common_utils.onehot(targets, logits.shape[-1])
    loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
    normalizing_factor = jnp.prod(jnp.asarray(targets.shape))
    if weights is not None:
        loss = loss * weights
        normalizing_factor = weights.sum()
    return loss.sum(), normalizing_factor
def test_small_byt5_integration_test(self):
    """
    For comparison run:
    >>> import t5  # pip install t5==0.9.1
    >>> path_to_byt5_small_checkpoint = '<fill_in>'
    >>> t5_model = t5.models.MtfModel(model_dir=path_to_byt5_small_checkpoint, batch_size=1, tpu=None)
    >>> vocab = t5.data.ByteVocabulary()
    >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
    """
    model = FlaxT5ForConditionalGeneration.from_pretrained("google/byt5-small")
    tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")

    input_ids = tokenizer("Hello there", return_tensors="np").input_ids
    labels = tokenizer("Hi I am", return_tensors="np").input_ids

    decoder_input_ids = shift_tokens_right(
        labels, model.config.pad_token_id, model.config.decoder_start_token_id)

    logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits
    loss = optax.softmax_cross_entropy(
        logits, onehot(labels, logits.shape[-1])).mean()

    mtf_score = -(labels.shape[-1] * loss.item())

    EXPECTED_SCORE = -60.7397
    self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
def compute_contrastive_loss(quantized_features, transformer_features,
                             negative_indices, mask_time_indices, logits_temp,
                             num_negatives):
    batch_size, sequence_length, hidden_size = quantized_features.shape

    # take negative vectors from sampled indices
    quantized_negatives = quantized_features.reshape(
        -1, hidden_size)[negative_indices.reshape(-1)]
    quantized_negatives = quantized_negatives.reshape(
        batch_size, sequence_length, num_negatives,
        hidden_size).transpose(2, 0, 1, 3)

    target_features = jnp.concatenate(
        [quantized_features[None, :], quantized_negatives], axis=0)
    loss_logits = optax.cosine_similarity(transformer_features, target_features)
    loss_logits = loss_logits / logits_temp

    neg_is_pos = (quantized_features == quantized_negatives).all(-1)
    neg_is_pos = jnp.concatenate(
        [jnp.full((1,) + loss_logits.shape[1:], False), neg_is_pos], axis=0)

    # make sure incorrectly sampled vectors don't contribute to loss
    loss_logits = jnp.where(neg_is_pos, -1e9, loss_logits)

    predictions = loss_logits.transpose(2, 1, 0).reshape(-1, loss_logits.shape[0])
    targets = ((1 - mask_time_indices) * -100).transpose(1, 0).flatten()

    target_mask = jnp.where(targets >= 0, 1.0, 0.0)
    contrastive_loss = optax.softmax_cross_entropy(
        predictions, onehot(targets, predictions.shape[-1])) * target_mask
    contrastive_loss = contrastive_loss.sum()
    return contrastive_loss
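# Shape-check sketch for compute_contrastive_loss above (wav2vec2-style).
# Assumes the module-level names the function relies on are `jnp`, `optax`,
# and `onehot` (flax.training.common_utils.onehot); real negative sampling and
# feature extraction are not shown, and all values here are dummies.
import jax.numpy as jnp
import optax
from flax.training.common_utils import onehot

B, T, H, K = 2, 4, 8, 3
quantized = jnp.ones((B, T, H))
transformer = jnp.ones((B, T, H))
# Flat indices into the (B*T, H) quantized features: K negatives per position.
negative_idx = jnp.zeros((B, T, K), dtype=jnp.int32)
mask_time = jnp.ones((B, T), dtype=jnp.int32)  # 1 = masked (loss) position

dummy_loss = compute_contrastive_loss(quantized, transformer, negative_idx,
                                      mask_time, logits_temp=0.1,
                                      num_negatives=K)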
def eval_step(params, batch):
    labels = batch.pop("labels")
    outputs = model(**batch, output_attentions=True, params=params, train=False)
    logits = outputs["logits"]

    # compute loss
    loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))

    # compute accuracy
    accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)

    # compute head specialization
    specialization = compute_specialization_metric(
        jnp.swapaxes(jnp.stack(outputs["encoder_attentions"]), 0, 1))

    # summarize metrics
    metrics = {
        "loss": loss.mean(),
        "accuracy": accuracy.mean(),
        "specialization": specialization,
    }
    metrics = jax.lax.pmean(metrics, axis_name="batch")
    return metrics
def __call__(self, input_ids, type_ids, labels=None, deterministic=False):
    """Applies model for sequence classification.

    Args:
      input_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] tokenized inputs.
      type_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] Ids partitioning input into
        different types.
      labels: True labels associated with inputs. Generally only required for
        training. Shape depends on task type:
        * Classification: <int>[BATCH_SIZE],
        * Regression: <float>[BATCH_SIZE].
      deterministic: Whether to apply dropout to input.

    Returns:
      * If labels supplied (training mode): Model loss and metrics.
      * If no labels supplied (prediction / evaluation mode): Logits with shape
        <float>[BATCH_SIZE, n_classes].
    """
    encoder_output = EncoderModel(self.config, name="encoder")(
        input_ids, type_ids, deterministic=deterministic)

    # All other classification and regression tasks use the pooled output.
    output = encoder_output.pooled_output

    # TODO(jamesleethorp): For WiC, the original SuperGLUE paper
    # (https://arxiv.org/abs/1905.00537) concatenates the "CLS" and "word"
    # output representations. We only use the pooled output.
    logits = layers.OutputProjection(
        n_out=self.n_classes, kernel_init=default_kernel_init,
        name="classification")(output)

    if labels is None:
        # Code path used during evaluation or prediction; metrics can be
        # computed from logits by the caller.
        return logits

    # Code path used during training.
    if (self.config.dataset_name == "glue/stsb" or  # Regression task
        self.config.dataset_name == "super_glue/copa" or  # "Regression" task
        self.config.dataset_name == "super_glue/record"):  # "Regression" task
        # Logits have shape: [BATCH_SIZE, 1].
        per_example_loss = jnp.sum((logits[Ellipsis, 0] - labels)**2, axis=-1)
        batch_loss = jnp.mean(per_example_loss)
        return ClassificationStats(batch_loss=batch_loss,
                                   num_labels=labels.size)
    else:  # Classification task
        # Logits have shape: [BATCH_SIZE, self.n_classes].
        logits = nn.log_softmax(logits, axis=-1)
        per_example_loss = -jnp.sum(
            onehot(labels, logits.shape[-1]) * logits, axis=-1)
        batch_loss = jnp.mean(per_example_loss)
        correct_predictions = jnp.sum(logits.argmax(-1) == labels)
        return ClassificationStats(batch_loss=batch_loss,
                                   num_labels=labels.size,
                                   correct_predictions=correct_predictions)
def sampling_loop_body_fn(state):
    """Sampling loop state update."""
    i, sequences, cache, cur_token, ended, rng, tokens_to_logits_state = state

    # Split RNG for sampling.
    rng1, rng2 = random.split(rng)

    # Call fast-decoder model on current tokens to get raw next-position logits.
    logits, new_cache, new_tokens_to_logits_state = tokens_to_logits(
        cur_token, cache, internal_state=tokens_to_logits_state)
    logits = logits / temperature

    # Mask out the BOS token.
    if masked_tokens is not None:
        mask = common_utils.onehot(
            jnp.array(masked_tokens),
            num_classes=logits.shape[-1],
            on_value=LARGE_NEGATIVE)
        mask = jnp.sum(mask, axis=0)[None, :]  # Combine multiple masks together.
        logits = logits + mask

    # Apply the repetition penalty.
    if repetition_penalty != 1:
        logits = apply_repetition_penalty(
            sequences, logits, i,
            repetition_penalty=repetition_penalty,
            repetition_window=repetition_window,
            repetition_penalty_normalize=repetition_penalty_normalize)

    # Mask out everything but the top-k entries.
    if top_k is not None:
        # Compute top_k_index and top_k_threshold with shapes (batch_size, 1).
        top_k_index = jnp.argsort(logits, axis=-1)[:, ::-1][:, top_k - 1:top_k]
        top_k_threshold = jnp.take_along_axis(logits, top_k_index, axis=-1)
        logits = jnp.where(logits < top_k_threshold,
                           jnp.full_like(logits, LARGE_NEGATIVE),
                           logits)

    # Sample next token from logits.
    sample = multinomial(rng1, logits)
    next_token = sample.astype(jnp.int32)

    # Only use sampled tokens if we have passed the out_of_prompt_marker.
    out_of_prompt = (sequences[:, i + 1] == out_of_prompt_marker)
    next_token = (next_token * out_of_prompt +
                  sequences[:, i + 1] * ~out_of_prompt)

    # If the end-marker has been reached for a batch item, only emit padding
    # tokens.
    next_token = next_token[:, None]
    next_token_or_endpad = jnp.where(
        ended, jnp.full_like(next_token, pad_token), next_token)
    ended |= (next_token_or_endpad == end_marker)

    # Add current sampled tokens to recorded sequences.
    new_sequences = lax.dynamic_update_slice(
        sequences, next_token_or_endpad, (0, i + 1))

    return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended, rng2,
            new_tokens_to_logits_state)
def loss_fn(params):
    labels = batch.pop("labels")
    logits = state.apply_fn(
        **batch, params=params, dropout_rng=dropout_rng, train=True)[0]

    # compute loss
    loss = optax.softmax_cross_entropy(
        logits, onehot(labels, logits.shape[-1])).mean()

    return loss
def __call__(
    self,
    input_ids,
    input_mask,
    type_ids,
    labels=None,
    deterministic=False,
):
    """Applies model for sequence classification.

    Args:
      input_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] tokenized inputs.
      input_mask: <bool>[BATCH_SIZE, MAX_SEQ_LENGTH] mask separating actual
        inputs from padding. Only used by BERT.
      type_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] Ids partitioning input into
        different types.
      labels: True labels associated with inputs. Generally only required for
        training. Shape depends on task type:
        * Classification: <int>[BATCH_SIZE]
        * Regression: <float>[BATCH_SIZE]
      deterministic: Whether or not to apply dropout to input.

    Returns:
      * If labels supplied (training mode): Model loss and metrics.
      * If no labels supplied (prediction / evaluation mode): Logits of shape
        <float>[BATCH_SIZE, n_classes].
    """
    _, pooled_output = EncoderModel(self.config, name="encoder")(
        input_ids, input_mask, type_ids, deterministic=deterministic)

    logits = layers.OutputProjection(
        n_out=self.n_classes, kernel_init=default_kernel_init,
        name="classification")(pooled_output)

    if labels is None:
        # Code path used during evaluation or prediction; metrics can be
        # computed from logits by the caller.
        return logits

    # Code path used during training.
    if self.config.dataset_name == "glue/stsb":  # Regression task
        loss = jnp.mean((logits[Ellipsis, 0] - labels)**2)
        return {"loss": loss, "num_labels": labels.size}
    else:  # Classification task
        logits = nn.log_softmax(logits)
        loss = -jnp.mean(
            jnp.sum(onehot(labels, logits.shape[-1]) * logits, axis=-1))
        correct_predictions = jnp.sum(logits.argmax(-1) == labels)
        return {
            "loss": loss,
            "correct_predictions": correct_predictions,
            "num_labels": labels.size,
        }
def cross_entropy_loss(log_softmax_logits, labels):
    """Returns the cross-entropy classification loss.

    Args:
      log_softmax_logits: The log of the softmax of the logits for the
        mini-batch, e.g. as output by jax.nn.log_softmax(logits).
      labels: The labels for the mini-batch.
    """
    num_classes = log_softmax_logits.shape[-1]
    one_hot_labels = common_utils.onehot(labels, num_classes)
    return -jnp.sum(one_hot_labels * log_softmax_logits) / labels.size
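# Quick self-check for cross_entropy_loss above, assuming `common_utils` is
# flax.training.common_utils and the usual jax/optax imports (none of which
# are shown in the snippet itself): the result should match
# optax.softmax_cross_entropy applied to the raw logits.
import jax
import jax.numpy as jnp
import optax
from flax.training import common_utils

key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (4, 10))   # [batch, num_classes]
labels = jnp.array([1, 2, 3, 4])

loss_a = cross_entropy_loss(jax.nn.log_softmax(logits), labels)
loss_b = optax.softmax_cross_entropy(
    logits, common_utils.onehot(labels, 10)).mean()
# loss_a and loss_b should agree up to numerical precision.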
def loss(params):
    x_in = shift_right(x)
    logits = model.apply(
        params, x_in, rngs={"permute": permute_key, "dropout": dropout_key})
    log_prob = nn.log_softmax(logits)
    x_onehot = onehot(x, num_classes=10)
    nll = -jnp.sum(x_onehot * log_prob, axis=-1)
    return jnp.mean(nll)
def sample(self, masked_inputs, rng):
    """Fill in MASK positions in inputs."""
    mask_positions = masked_inputs == self.domain.vocab.mask
    logits = self.score(masked_inputs)

    # Mask out MASK token.
    mask = common_utils.onehot(
        jnp.array([self.domain.vocab.mask]),
        num_classes=logits.shape[-1],
        on_value=sampling.LARGE_NEGATIVE)
    logits = logits + mask

    samples = jax.random.categorical(rng, logits=logits)
    infilled = onp.where(mask_positions, samples, masked_inputs)
    return infilled
def cross_entropy_loss(logprobs, label, num_classes):
    """Computes the cross entropy loss for one datapoint.

    Args:
      logprobs: log probabilities predicted by the model.
      label: true class label.
      num_classes: number of classes in the task.

    Returns:
      loss: value of the loss.
    """
    one_hot_labels = common_utils.onehot(label, num_classes=num_classes)
    return -jnp.sum(one_hot_labels * logprobs)
def loss_fn(params):
    labels = batch.pop("labels")
    logits = state.apply_fn(
        **batch, params=params, dropout_rng=dropout_rng, train=True)[0]

    # compute loss, ignore padded input tokens
    label_mask = jnp.where(labels > 0, 1.0, 0.0)
    loss = optax.softmax_cross_entropy(
        logits, onehot(labels, logits.shape[-1])) * label_mask

    # take average
    loss = loss.sum() / label_mask.sum()

    return loss
def metrics_fn(self, logits, batch):
    """Calculates metrics for the classification task.

    Args:
      logits: float array; Output of the model -> [batch, length, num_classes].
      batch: dict; Batch of data that has 'label' and optionally 'weights'.

    Returns:
      A dict of metrics.
    """
    target_is_onehot = logits.shape == batch['label'].shape
    if target_is_onehot:
        one_hot_targets = batch['label']
    else:
        one_hot_targets = common_utils.onehot(batch['label'], logits.shape[-1])

    if self.dataset.meta_data['num_classes'] == 1:
        # If this is a binary classification task, make sure the shape of labels
        # is (bs, 1) and is the same as the shape of logits.
        one_hot_targets = jnp.reshape(one_hot_targets, logits.shape)

    if self.task_params.get('class_indices'):
        possible_labels_indices = self.task_params.get('class_indices')
        one_hot_targets = one_hot_targets[:, possible_labels_indices]
        logits = logits[:, possible_labels_indices]

    weights = batch.get('weights')  # weights might not be defined
    metrics_dic = {}
    for key in self._METRICS:
        metric_val, metric_normalizer = self._METRICS[key](logits,
                                                           one_hot_targets,
                                                           weights)
        metrics_dic[key] = (jax.lax.psum(metric_val, 'batch'),
                            jax.lax.psum(metric_normalizer, 'batch'))

    # Store dataset related factors.
    for key in batch:
        if 'factor' in key:
            factors = batch[key]
            if weights is not None:
                val = jnp.sum(metrics.apply_weights(factors, weights))
                norm = jnp.sum(weights)
            else:
                val = jnp.sum(factors)
                norm = len(factors)
            metrics_dic[key] = (jax.lax.psum(val, 'batch'),
                                jax.lax.psum(norm, 'batch'))

    return metrics_dic
def loss_fn(logits, labels, z_loss=0):
    # Shift so that each position predicts the next token.
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]
    shift_labels = onehot(shift_labels, shift_logits.shape[-1])

    # Numerically stable log-softmax: subtract the (stop-gradient) max, then
    # compute the log-partition term explicitly so it can be reused below.
    shift_logits = shift_logits - jax.lax.stop_gradient(
        shift_logits.max(axis=-1, keepdims=True))
    log_z = jnp.log(jnp.sum(jnp.exp(shift_logits), axis=-1, keepdims=True))
    log_softmax = shift_logits - log_z

    loss = -jnp.sum(shift_labels * log_softmax, axis=-1)
    # Auxiliary z-loss term penalizing large log-partition values.
    loss += (1e-4 * jnp.square(log_z.squeeze(-1))) * z_loss
    return loss.mean()
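# Minimal usage sketch for loss_fn above, assuming `onehot` is
# flax.training.common_utils.onehot and the usual jax/jnp imports; shapes and
# values are illustrative only.
import jax
import jax.numpy as jnp
from flax.training.common_utils import onehot

batch, seq_len, vocab = 2, 6, 11
lm_logits = jax.random.normal(jax.random.PRNGKey(0), (batch, seq_len, vocab))
token_ids = jnp.zeros((batch, seq_len), dtype=jnp.int32)

plain_loss = loss_fn(lm_logits, token_ids)              # next-token CE only
zreg_loss = loss_fn(lm_logits, token_ids, z_loss=1.0)   # adds 1e-4 * log_z**2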
def _compute_metrics(self, masked_lm_logits, next_sentence_logits,
                     masked_lm_labels, masked_lm_weights,
                     next_sentence_labels, **unused_kwargs):
    """Computes the pre-training loss and its components."""
    masked_lm_logits = nn.log_softmax(masked_lm_logits)
    masked_lm_labels = onehot(
        masked_lm_labels.reshape((-1,)), masked_lm_logits.shape[-1])
    masked_lm_weights = masked_lm_weights.reshape((-1,))
    masked_lm_loss = -jnp.sum(
        jnp.sum(masked_lm_logits * masked_lm_labels, axis=-1) *
        masked_lm_weights) / jnp.sum(masked_lm_weights)

    next_sentence_logits = nn.log_softmax(next_sentence_logits)
    next_sentence_labels = next_sentence_labels.reshape((-1,))
    next_sentence_loss = -jnp.mean(
        jnp.sum(
            onehot(next_sentence_labels, next_sentence_logits.shape[-1]) *
            next_sentence_logits,
            axis=-1))

    return {
        'loss': masked_lm_loss + next_sentence_loss,
        'masked_lm_loss': masked_lm_loss,
        'next_sentence_loss': next_sentence_loss,
    }
def get_masked_lm_output(logits, label_ids, label_weights):
    """Calculates the masked LM loss for the pre-training task."""
    vocab_size = logits.shape[-1]

    label_ids = jnp.reshape(label_ids, (-1))
    label_weights = jnp.reshape(label_weights, (-1))
    one_hot_labels = common_utils.onehot(
        label_ids, vocab_size, on_value=1.0, off_value=0.0)

    log_probs = nn.log_softmax(logits)
    per_example_loss = -jnp.sum(log_probs * one_hot_labels, axis=-1)
    numerator = jnp.sum(label_weights * per_example_loss)
    denominator = jnp.sum(label_weights) + 1e-5
    loss = numerator / denominator

    return loss, per_example_loss, log_probs
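# Hedged usage sketch for get_masked_lm_output above. Assumes `nn` is
# flax.linen, `common_utils` is flax.training.common_utils, and that the
# masked-position logits have already been gathered to [num_predictions, vocab];
# none of this is shown in the snippet itself.
import jax.numpy as jnp
from flax import linen as nn
from flax.training import common_utils

num_predictions, vocab = 6, 32
mlm_logits = jnp.zeros((num_predictions, vocab))
label_ids = jnp.arange(num_predictions, dtype=jnp.int32)
label_weights = jnp.array([1., 1., 1., 1., 0., 0.])  # last two are padding

mlm_loss, per_example_loss, log_probs = get_masked_lm_output(
    mlm_logits, label_ids, label_weights)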
def compute_weighted_cross_entropy(logits, labels):
    """Compute weighted cross entropy and entropy for log probs and labels.

    Args:
      logits: [batch, length, num_classes] float array.
      labels: categorical targets [batch, length] int array.

    Returns:
      Tuple of scalar loss and per-example loss array.
    """
    log_probs = nn.log_softmax(logits)
    labels = jnp.reshape(labels, [-1])
    one_hot_labels = common_utils.onehot(labels, num_classes=2)
    per_example_loss = -jnp.sum(one_hot_labels * log_probs, axis=-1)
    loss = jnp.mean(per_example_loss)
    return (loss, per_example_loss)
def eval_step(params, batch):
    labels = batch.pop("labels")
    logits = model(**batch, params=params, train=False)[0]

    # compute loss
    loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))

    # compute accuracy
    accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)

    # summarize metrics
    metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
    metrics = jax.lax.pmean(metrics, axis_name="batch")
    return metrics
def cross_entropy(logits, targets, weights=None, label_smoothing=0.0):
    """Computes label-smoothed cross entropy; returns (total loss, normalizer)."""
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) +
        (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    ).astype(logits.dtype)
    # Smoothed targets; the on/off values follow the label_smoothing setting.
    soft_targets = common_utils.onehot(
        targets, vocab_size, on_value=confidence, off_value=low_confidence)
    loss = -jnp.sum(
        soft_targets * nn.log_softmax(logits), axis=-1, dtype=logits.dtype)
    loss -= normalizing_constant
    if weights is not None:
        loss *= weights
        normalizing_factor = weights.sum()
    else:
        normalizing_factor = np.prod(targets.shape, dtype=logits.dtype)
    return loss.sum(), normalizing_factor
def loss_fn(params):
    labels = batch.pop("labels")
    outputs = state.apply_fn(
        **batch,
        output_attentions=True,
        params=params,
        dropout_rng=dropout_rng,
        train=True)
    logits = outputs["logits"]

    # compute loss
    loss = optax.softmax_cross_entropy(
        logits, onehot(labels, logits.shape[-1])).mean()

    return loss, jnp.swapaxes(jnp.stack(outputs["encoder_attentions"]), 0, 1)
def eval_step(params, batch):
    labels = batch.pop("labels")
    logits = model(**batch, params=params, train=False)[0]

    # compute loss, ignore padded input tokens
    label_mask = jnp.where(labels > 0, 1.0, 0.0)
    loss = optax.softmax_cross_entropy(
        logits, onehot(labels, logits.shape[-1])) * label_mask

    # compute accuracy
    accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask

    # summarize metrics
    metrics = {
        "loss": loss.sum(),
        "accuracy": accuracy.sum(),
        "normalizer": label_mask.sum(),
    }
    metrics = jax.lax.psum(metrics, axis_name="batch")
    return metrics
def compute_weighted_cross_entropy(logits,
                                   targets,
                                   weights=None,
                                   label_smoothing=0.0,
                                   z_loss=0.0):
    """Compute weighted cross entropy and entropy for log probs and targets.

    Args:
      logits: [batch, length, num_classes] float array.
      targets: categorical one-hot targets [batch, length, category] int array.
      weights: None or array of shape [batch, length].
      label_smoothing: label smoothing constant, used to determine the on and
        off values.
      z_loss: coefficient for the auxiliary z-loss term.

    Returns:
      Tuple of scalar loss and batch normalizing factor.
    """
    targets = targets.reshape((-1))
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            'Incorrect shapes. Got shape %s logits and %s targets' %
            (str(logits.shape), str(targets.shape)))
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) +
        (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
    soft_targets = common_utils.onehot(
        targets, vocab_size, on_value=confidence, off_value=low_confidence)
    loss = cross_entropy_with_logits(logits, soft_targets, z_loss=z_loss)
    loss = loss - normalizing_constant

    normalizing_factor = np.prod(targets.shape)
    if weights is not None:
        weights = weights.reshape((-1))
        loss = loss * weights
        normalizing_factor = jnp.sum(weights)

    # HACK: T5's "loss_denominator" correction for batch size 2048 * 114 target
    # length.
    # normalizing_factor = 233472.0

    return jnp.sum(loss), normalizing_factor