def loss_fn(model):
    """Return the normalized training loss for `model` together with its logits.

    The model is applied to `inputs` in training mode inside an
    `nn.stochastic(dropout_rng)` scope. `inputs`, `targets` and
    `dropout_rng` are not parameters; they are taken from the enclosing
    scope.

    Args:
        model: Callable invoked as `model(inputs, train=True)`.

    Returns:
        A `(mean_loss, logits)` pair, where `mean_loss` is the loss returned
        by `train_utils.compute_weighted_cross_entropy` divided by the
        weight sum it returns alongside.
    """
    # Only the forward pass runs inside the stochastic scope.
    with nn.stochastic(dropout_rng):
        logits = model(inputs, train=True)

    # The class count is fixed at 10 and no per-example weights are supplied.
    total_loss, normalizer = train_utils.compute_weighted_cross_entropy(
        logits, targets, num_classes=10, weights=None)
    return total_loss / normalizer, logits
def compute_metrics(logits, labels, num_classes, weights):
    """Build the summary-metric dictionary and sum it over the 'batch' axis.

    Args:
        logits: Model outputs, forwarded to the `train_utils` helpers.
        labels: Target labels, forwarded to the `train_utils` helpers.
        num_classes: Class count passed to
            `train_utils.compute_weighted_cross_entropy`.
        weights: Weights passed to both the cross-entropy and the accuracy
            helper.

    Returns:
        A dict with keys 'loss', 'accuracy' and 'denominator', each value
        reduced with `jax.lax.psum` over the axis named 'batch'.
        'denominator' carries the weight sum returned by the cross-entropy
        helper. Because of the named-axis reduction, this presumably has to
        be called where an axis called 'batch' is bound (e.g. under
        `jax.pmap(..., axis_name='batch')`).
    """
    xent, denom = train_utils.compute_weighted_cross_entropy(
        logits, labels, num_classes, weights=weights)
    # The accuracy helper also returns a second value, which is not needed here.
    acc_value, _ = train_utils.compute_weighted_accuracy(
        logits, labels, weights)

    local_metrics = {
        'loss': xent,
        'accuracy': acc_value,
        'denominator': denom,
    }
    return jax.lax.psum(local_metrics, 'batch')