def compute_metrics(
    masked_lm_logits: jnp.ndarray,
    next_sentence_logits: jnp.ndarray,
    masked_lm_labels: jnp.ndarray,
    masked_lm_weights: jnp.ndarray,
    next_sentence_labels: jnp.ndarray,
):
  """Computes the pre-training loss and its components."""
  masked_lm_logits = nn.log_softmax(masked_lm_logits)
  masked_lm_labels = onehot(
      masked_lm_labels.reshape((-1,)), masked_lm_logits.shape[-1])
  masked_lm_weights = masked_lm_weights.reshape((-1,))
  masked_lm_loss = -jnp.sum(
      jnp.sum(masked_lm_logits * masked_lm_labels, axis=-1) *
      masked_lm_weights) / jnp.sum(masked_lm_weights)

  next_sentence_logits = nn.log_softmax(next_sentence_logits)
  next_sentence_labels = next_sentence_labels.reshape((-1,))
  next_sentence_loss = -jnp.mean(
      jnp.sum(
          onehot(next_sentence_labels, next_sentence_logits.shape[-1]) *
          next_sentence_logits,
          axis=-1,
      ))

  return {
      "loss": masked_lm_loss + next_sentence_loss,
      "masked_lm_loss": masked_lm_loss,
      "next_sentence_loss": next_sentence_loss,
  }
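
# Hedged usage sketch for compute_metrics above (added, not from the original
# source). Assumes `onehot` is flax.training.common_utils.onehot and that the
# masked-LM logits have already been gathered to the predicted positions.
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training.common_utils import onehot

batch, num_preds, vocab = 2, 3, 11
rng = jax.random.PRNGKey(0)
masked_lm_logits = jax.random.normal(rng, (batch * num_preds, vocab))
masked_lm_labels = jnp.zeros((batch, num_preds), dtype=jnp.int32)
masked_lm_weights = jnp.ones((batch, num_preds), dtype=jnp.float32)
next_sentence_logits = jax.random.normal(rng, (batch, 2))
next_sentence_labels = jnp.zeros((batch,), dtype=jnp.int32)

metrics = compute_metrics(masked_lm_logits, next_sentence_logits,
                          masked_lm_labels, masked_lm_weights,
                          next_sentence_labels)
# metrics["loss"] == metrics["masked_lm_loss"] + metrics["next_sentence_loss"]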
def __call__(self, x): """Define the convolutional network architecture. Architecture originates from "Human-level control through deep reinforcement learning.", Nature 518, no. 7540 (2015): 529-533. Note that this is different than the one from "Playing atari with deep reinforcement learning." arxiv.org/abs/1312.5602 (2013) Network is used to both estimate policy (logits) and expected state value; in other words, hidden layers' params are shared between policy and value networks, see e.g.: github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py """ dtype = jnp.float32 x = x.astype(dtype) / 255. x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4), name='conv1', dtype=dtype)(x) x = nn.relu(x) x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2), name='conv2', dtype=dtype)(x) x = nn.relu(x) x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), name='conv3', dtype=dtype)(x) x = nn.relu(x) x = x.reshape((x.shape[0], -1)) # flatten x = nn.Dense(features=512, name='hidden', dtype=dtype)(x) x = nn.relu(x) logits = nn.Dense(features=self.num_outputs, name='logits', dtype=dtype)(x) policy_log_probabilities = nn.log_softmax(logits) value = nn.Dense(features=1, name='value', dtype=dtype)(x) return policy_log_probabilities, value
def compute_per_pos_loss(logits, targets, weights=None, label_smoothing=0.0):
  """Compute weighted cross entropy and entropy for log probs and targets.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical targets [batch, length] int array.
    weights: None or array of shape [batch, length].
    label_smoothing: label smoothing constant, used to determine the on and
      off values.

  Returns:
    Per-position loss array of shape [batch, length] (weighted by `weights`
    if provided).
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  vocab_size = logits.shape[-1]
  confidence = 1.0 - label_smoothing
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
  normalizing_constant = -(
      confidence * jnp.log(confidence) +
      (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
  soft_targets = common_utils.onehot(
      targets, vocab_size, on_value=confidence, off_value=low_confidence)

  loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
  loss = loss - normalizing_constant

  if weights is not None:
    loss = loss * weights

  return loss
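
# Note (added, not from the original source): with label smoothing, the target
# distribution puts `confidence` on the true class and `low_confidence` on
# each of the remaining (vocab_size - 1) classes. `normalizing_constant` is
# the entropy of that distribution,
#   H = -(confidence * log(confidence)
#         + (vocab_size - 1) * low_confidence * log(low_confidence)),
# and since cross_entropy(p, q) = H(p) + KL(p || q), subtracting it turns the
# loss into KL(soft_targets || softmax(logits)), which is zero when the model
# exactly matches the smoothed targets. The 1e-20 guards log(0) when
# label_smoothing == 0.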
def weighted_unnormalized_cross_entropy(logits, targets, weights=None):
  """Compute weighted cross entropy and entropy for log probs and targets.

  This computes sum_(x,y) ce(x, y) for a single, potentially padded minibatch.
  If the minibatch is padded (that is it contains null examples) it is assumed
  that weights is a binary mask where 0 indicates that the example is null.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: one hot vector of shape [batch, ..., num_classes].
    weights: None or array of shape [batch x ...] (rank of one_hot_targets -1).

  Returns:
    Cross entropy loss computed per example, shape [batch, ...].
  """
  if logits.ndim != targets.ndim:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))

  loss = -jnp.sum(targets * nn.log_softmax(logits), axis=-1)
  if weights is not None:
    if weights.ndim != targets.ndim - 1:
      raise ValueError(
          'Incorrect shapes. Got shape %s weights and %s targets' %
          (str(weights.shape), str(targets.shape)))
    loss = loss * weights

  return loss
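
# Hedged usage sketch (added, not from the original source). Unlike the
# integer-label helpers in the other snippets, this variant expects one-hot
# targets and returns an unnormalized per-example loss; the caller still
# divides by the weight sum to get a mean.
import jax
import jax.numpy as jnp
from flax import linen as nn

logits = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
targets = jax.nn.one_hot(jnp.array([0, 1, 2, 0]), 3)
weights = jnp.array([1., 1., 1., 0.])  # last example is padding
per_example = weighted_unnormalized_cross_entropy(logits, targets, weights)
mean_loss = per_example.sum() / weights.sum()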
def compute_weighted_cross_entropy(logits, targets, weights=None):
  """Compute weighted cross entropy and entropy for log probs and targets.

  Args:
    logits: `[batch, length, num_classes]` float array.
    targets: categorical targets `[batch, length]` int array.
    weights: None or array of shape `[batch, length]`.

  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  if logits.shape[1] != targets.shape[1]:  # Truncate logits.
    logits = logits[:, :targets.shape[1]]

  onehot_targets = common_utils.onehot(targets, logits.shape[-1])
  loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
  normalizing_factor = jnp.prod(jnp.asarray(targets.shape))
  if weights is not None:
    loss = loss * weights
    normalizing_factor = weights.sum()

  return loss.sum(), normalizing_factor
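
# Hedged usage sketch (added, not from the original source): the returned
# (loss sum, normalizing factor) pair is typically combined into a mean loss
# by the caller. Assumes common_utils is flax.training.common_utils.
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import common_utils

batch, length, num_classes = 2, 5, 7
logits = jax.random.normal(jax.random.PRNGKey(0), (batch, length, num_classes))
targets = jnp.ones((batch, length), dtype=jnp.int32)
weights = jnp.ones((batch, length))  # e.g. 0.0 at padding positions
loss_sum, norm = compute_weighted_cross_entropy(logits, targets, weights)
mean_loss = loss_sum / norm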
def cross_entropy(logits, targets, axis=-1):
  logprobs = nn.log_softmax(logits, axis=axis)
  nll = np.take_along_axis(logprobs, np.expand_dims(targets, axis=axis),
                           axis=axis)
  ce = -np.mean(nll)
  return ce
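
# Hedged usage sketch (added, not from the original source). As in the snippet
# above, `np` is numpy and `nn` is flax.linen; the take_along_axis gather
# picks out the log-probability of each target class, which is equivalent to
# the one-hot-and-sum formulation used in the other snippets.
import numpy as np
from flax import linen as nn

logits = np.random.randn(4, 10).astype(np.float32)
targets = np.array([1, 2, 3, 4])
ce = cross_entropy(logits, targets)  # scalar mean negative log-likelihood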
def __call__(
    self,
    input_ids: jnp.ndarray,
    input_mask: jnp.ndarray,
    type_ids: jnp.ndarray,
    labels: jnp.ndarray = None,
    *,
    deterministic: bool = False,
):
  """Applies BERT for sequence classification."""
  bert = BertModel(config=self.config, name="bert")
  _, pooled_output = bert(
      input_ids, input_mask, type_ids, deterministic=deterministic)
  pooled_output = nn.Dropout(
      rate=self.config.hidden_dropout_prob,
      deterministic=deterministic)(pooled_output)
  logits = layers.OutputProjection(
      n_out=self.n_classes,
      kernel_init=get_kernel_init(self.config),
      name="classification",
  )(pooled_output)

  if labels is None:
    return logits
  elif logits.shape[-1] == 1:
    # Regression task
    loss = jnp.mean((logits[..., 0] - labels)**2)
    return {"loss": loss}
  else:
    # Classification task
    logits = nn.log_softmax(logits)
    loss = -jnp.mean(
        jnp.sum(onehot(labels, logits.shape[-1]) * logits, axis=-1))
    return {"loss": loss}
def __call__(self, x, train: bool = True):
  conv = partial(nn.Conv, use_bias=False, dtype=self.dtype)
  norm = partial(nn.BatchNorm,
                 use_running_average=not train,
                 momentum=0.9,
                 epsilon=1e-5,
                 dtype=self.dtype)

  x = conv(self.num_filters, (7, 7), (2, 2),
           padding=[(3, 3), (3, 3)],
           name='conv_init')(x)
  x = norm(name='bn_init')(x)
  x = nn.relu(x)
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(self.stage_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = self.block_cls(self.num_filters * 2**i,
                         strides=strides,
                         conv=conv,
                         norm=norm,
                         act=self.act)(x)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
  x = jnp.asarray(x, self.dtype)
  x = nn.log_softmax(x)
  return x
def __call__(self, x):
  # `first` and `rest` (layer widths) are assumed to be defined in the
  # enclosing scope.
  x = OGDense(first, name="first")(x)
  for n in rest:
    x = OGDense(n)(x)
  x = nn.Dense(10, name="last")(x)
  x = nn.log_softmax(x)
  return x
def __call__(self, x):
  # `activation` is assumed to be defined in the enclosing scope
  # (e.g. nn.relu).
  x = nn.Conv(features=8, kernel_size=(3, 3), strides=(1, 1))(x)
  x = activation(x)
  x = jnp.reshape(x, (x.shape[0], -1))
  x = nn.Dense(10)(x)
  x = nn.log_softmax(x)
  return x
def __call__(self, x):
  if self.nhidden > 0:
    x = nn.Dense(self.nhidden)(x)
    x = nn.relu(x)
  x = nn.Dense(self.nclasses)(x)
  x = nn.log_softmax(x)
  return x
def __call__(self, x):
  # Helper macro.
  R_ = lambda hidden_: ResidualUnit(hidden_features=hidden_,
                                    norm=self.norm,
                                    training=self.training,
                                    activation=nn.gelu)

  # First filter to make features.
  h = nn.Conv(features=self.hidden * self.alpha,
              use_bias=False,
              kernel_size=(3, 3),
              kernel_init=INITS[self.kernel_init])(x)
  h = NORMS[self.norm](use_running_average=not self.training)(h)
  h = nn.gelu(h)

  # 2 stages of continuous segments:
  h = ResidualStitch(hidden_features=self.hidden * self.alpha,
                     output_features=self.hidden * self.alpha,
                     strides=(1, 1),
                     norm=self.norm,
                     training=self.training,
                     activation=nn.gelu)(h)
  h = StatefulContinuousBlock(R=R_(self.hidden * self.alpha),
                              scheme=self.scheme,
                              n_step=self.n_step,
                              n_basis=self.n_basis,
                              basis=self.basis,
                              training=self.training)(h)

  # Pool and linearly classify:
  h = NORMS[self.norm](use_running_average=not self.training)(h)
  h = nn.gelu(h)
  h = nn.avg_pool(h, window_shape=(8, 8), strides=(8, 8))
  h = h.reshape((h.shape[0], -1))
  h = nn.Dense(features=self.n_classes)(h)
  return nn.log_softmax(h)  # no softmax
def __call__(self, x):
  # `activation` is assumed to be defined in the enclosing scope
  # (e.g. nn.relu).
  x = jnp.reshape(x, (-1, 28 * 28))
  x = nn.Dense(1024)(x)
  x = activation(x)
  x = nn.Dense(10)(x)
  x = nn.log_softmax(x)
  return x
def loss_fn(params):
  """loss function used for training."""
  logits = models.Transformer(config).apply(
      {"params": params},
      inputs,
      targets,
      inputs_positions=inputs_positions,
      targets_positions=targets_positions,
      inputs_segmentation=inputs_segmentation,
      targets_segmentation=targets_segmentation,
      rngs={"dropout": dropout_rng})

  vocab_size = logits.shape[-1]
  confidence = 1.0 - label_smoothing
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
  normalizing_constant = -(
      confidence * jnp.log(confidence) +
      (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
  soft_targets = common_utils.onehot(
      targets, vocab_size, on_value=confidence, off_value=low_confidence)

  loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
  loss = loss - normalizing_constant

  loss = loss * weights
  normalizing_factor = weights.sum()

  mean_loss = loss.sum() / normalizing_factor
  return mean_loss, logits
def __call__(self, input_ids, type_ids, labels=None, deterministic=False):
  """Applies model for sequence classification.

  Args:
    input_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] tokenized inputs.
    type_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] Ids partitioning input into
      different types.
    labels: True labels associated with inputs. Generally only required for
      training. Shape depends on task type:
      * Classification: <int>[BATCH_SIZE],
      * Regression: <float>[BATCH_SIZE].
    deterministic: Whether to apply dropout to input.

  Returns:
    * If labels supplied (training mode): Model loss and metrics.
    * If no labels supplied (prediction / evaluation mode): Logits with shape
      <float>[BATCH_SIZE, n_classes].
  """
  encoder_output = EncoderModel(self.config, name="encoder")(
      input_ids, type_ids, deterministic=deterministic)

  # All other classification and regression tasks use the pooled output.
  output = encoder_output.pooled_output

  # TODO(jamesleethorp): For WiC, the original SuperGLUE paper
  # (https://arxiv.org/abs/1905.00537) concatenates the "CLS" and "word"
  # output representations. We only use the pooled output.
  logits = layers.OutputProjection(
      n_out=self.n_classes,
      kernel_init=default_kernel_init,
      name="classification")(output)

  if labels is None:
    # Code path used during evaluation or prediction; metrics can be computed
    # from logits by the caller.
    return logits

  # Code path used during training.
  if (self.config.dataset_name == "glue/stsb" or  # Regression task
      self.config.dataset_name == "super_glue/copa" or  # "Regression" task
      self.config.dataset_name == "super_glue/record"):  # "Regression" task
    # Logits have shape: [BATCH_SIZE, 1].
    per_example_loss = jnp.sum((logits[..., 0] - labels)**2, axis=-1)
    batch_loss = jnp.mean(per_example_loss)
    return ClassificationStats(batch_loss=batch_loss, num_labels=labels.size)

  else:  # Classification task
    # Logits have shape: [BATCH_SIZE, self.n_classes].
    logits = nn.log_softmax(logits, axis=-1)
    per_example_loss = -jnp.sum(
        onehot(labels, logits.shape[-1]) * logits, axis=-1)
    batch_loss = jnp.mean(per_example_loss)
    correct_predictions = jnp.sum(logits.argmax(-1) == labels)
    return ClassificationStats(batch_loss=batch_loss,
                               num_labels=labels.size,
                               correct_predictions=correct_predictions)
def __call__(
    self,
    inputs,
):
  """Applies ResNet model. Number of residual blocks inferred from hparams."""
  num_classes = self.num_classes
  hparams = self.hparams
  num_filters = self.num_filters
  dtype = self.dtype

  x = aqt_flax_layers.ConvAqt(
      features=num_filters,
      kernel_size=(7, 7),
      strides=(2, 2),
      padding=[(3, 3), (3, 3)],
      use_bias=False,
      dtype=dtype,
      name='init_conv',
      train=self.train,
      quant_context=self.quant_context,
      paxis_name='batch',
      hparams=hparams.conv_init,
  )(inputs)

  x = nn.BatchNorm(use_running_average=not self.train,
                   momentum=0.9,
                   epsilon=1e-5,
                   dtype=dtype,
                   name='init_bn')(x)
  x = nn.relu(x)
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

  filter_multiplier = hparams.filter_multiplier
  for i, block_hparams in enumerate(hparams.residual_blocks):
    proj = block_hparams.conv_proj
    # For projection layers (unless it is the first layer), strides = (2, 2)
    if i > 0 and proj is not None:
      filter_multiplier *= 2
      strides = (2, 2)
    else:
      strides = (1, 1)
    x = ResidualBlock(filters=int(num_filters * filter_multiplier),
                      hparams=block_hparams,
                      quant_context=self.quant_context,
                      strides=strides,
                      train=self.train,
                      dtype=dtype)(x)

  x = jnp.mean(x, axis=(1, 2))
  x = aqt_flax_layers.DenseAqt(
      features=num_classes,
      dtype=dtype,
      train=self.train,
      quant_context=self.quant_context,
      paxis_name='batch',
      hparams=hparams.dense_layer,
  )(x, padding_mask=None)

  x = jnp.asarray(x, dtype)
  output = nn.log_softmax(x)
  return output
def __call__(self, x):
  x = nn.Dense(1024, bias_init=nn.initializers.normal(stddev=1.0))(x)
  x = nn.relu(x)
  x = nn.Dense(1024, bias_init=nn.initializers.normal(stddev=1.0))(x)
  x = nn.relu(x)
  x = nn.Dense(10)(x)
  x = nn.log_softmax(x)
  return x
def __call__(
    self,
    input_ids,
    input_mask,
    type_ids,
    labels=None,
    deterministic=False
):
  """Applies model for sequence classification.

  Args:
    input_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] tokenized inputs.
    input_mask: <bool>[BATCH_SIZE, MAX_SEQ_LENGTH] mask separating actual
      inputs from padding. Only used by BERT.
    type_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] Ids partitioning input into
      different types.
    labels: True labels associated with inputs. Generally only required for
      training. Shape depends on task type:
      * Classification: <int>[BATCH_SIZE]
      * Regression: <float>[BATCH_SIZE]
    deterministic: Whether or not to apply dropout to input.

  Returns:
    * If labels supplied (training mode): Model loss and metrics.
    * If no labels supplied (prediction / evaluation mode): Logits of shape
      <float>[BATCH_SIZE, n_classes].
  """
  _, pooled_output = EncoderModel(self.config, name="encoder")(
      input_ids, input_mask, type_ids, deterministic=deterministic)

  logits = layers.OutputProjection(
      n_out=self.n_classes,
      kernel_init=default_kernel_init,
      name="classification")(pooled_output)

  if labels is None:
    # Code path used during evaluation or prediction; metrics can be computed
    # from logits by the caller.
    return logits

  # Code path used during training.
  if self.config.dataset_name == "glue/stsb":  # Regression task
    loss = jnp.mean((logits[..., 0] - labels)**2)
    return {"loss": loss, "num_labels": labels.size}
  else:  # Classification task
    logits = nn.log_softmax(logits)
    loss = -jnp.mean(
        jnp.sum(onehot(labels, logits.shape[-1]) * logits, axis=-1))
    correct_predictions = jnp.sum(logits.argmax(-1) == labels)
    return {
        "loss": loss,
        "correct_predictions": correct_predictions,
        "num_labels": labels.size
    }
def __call__(self, x: spec.Tensor, train: bool):
  del train
  input_size = 28 * 28
  num_hidden = 128
  num_classes = 10
  x = x.reshape((x.shape[0], input_size))  # Flatten.
  x = nn.Dense(features=num_hidden, use_bias=True)(x)
  x = nn.sigmoid(x)
  x = nn.Dense(features=num_classes, use_bias=True)(x)
  x = nn.log_softmax(x)
  return x
def loss(params):
  x_in = shift_right(x)
  logits = model.apply(params, x_in,
                       rngs={"permute": permute_key, "dropout": dropout_key})
  log_prob = nn.log_softmax(logits)
  x_onehot = onehot(x, num_classes=10)
  nll = -jnp.sum(x_onehot * log_prob, axis=-1)
  return jnp.mean(nll)
def __call__(self, x):
  x = nn.Conv(features=32, kernel_size=(3, 3))(x)
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = nn.Conv(features=64, kernel_size=(3, 3))(x)
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(features=256)(x)
  x = nn.relu(x)
  x = nn.Dense(features=10)(x)
  x = nn.log_softmax(x)
  return x
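
# Hedged usage sketch (added, not from the original source). Assumes the
# __call__ above belongs to an @nn.compact flax.linen Module, hypothetically
# named CNN. The exp of the output sums to ~1 per example, since the model
# returns log-probabilities rather than raw logits.
import jax
import jax.numpy as jnp

model = CNN()
x = jnp.ones((1, 28, 28, 1))  # dummy MNIST-shaped batch
params = model.init(jax.random.PRNGKey(0), x)
log_probs = model.apply(params, x)           # shape (1, 10)
prob_mass = jnp.exp(log_probs).sum(axis=-1)  # ~1.0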
def __call__(self, inputs, train: bool = True):
  """Passes the input through the network.

  Arguments:
    inputs: [batch_size, height, width, channels]
    train: bool

  Returns:
    output: [batch_size, config.num_classes]
  """
  cfg = self.config
  conv = partial(nn.Conv,
                 use_bias=False,
                 dtype=cfg.dtype,
                 precision=cfg.precision,
                 kernel_init=cfg.kernel_init)
  norm = partial(nn.BatchNorm,
                 use_running_average=not train,
                 momentum=cfg.bn_momentum,
                 epsilon=cfg.bn_epsilon,
                 dtype=cfg.dtype)

  y = conv(cfg.initial_filters,
           kernel_size=(7, 7),
           strides=(2, 2),
           padding=[(3, 3), (3, 3)])(inputs)
  y = norm()(y)
  y = cfg.activation_fn(y)
  y = nn.max_pool(y, (3, 3), strides=(2, 2), padding='SAME')

  for i, block_size in enumerate(cfg.stage_sizes[:-1]):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      y = BottleneckResNetBlock(filters=cfg.initial_filters * 2**i,
                                strides=strides,
                                config=cfg,
                                conv=conv,
                                norm=norm)(y)

  for j in range(cfg.stage_sizes[-1]):
    strides = (2, 2) if j == 0 and cfg.stride_one is False else (1, 1)
    y = BoTBlock(filters=cfg.initial_filters * 2**(i + 1),
                 strides=strides,
                 config=cfg,
                 conv=conv,
                 norm=norm)(y)

  y = jnp.mean(y, axis=(1, 2))
  y = nn.Dense(cfg.num_classes,
               dtype=cfg.dtype,
               kernel_init=cfg.kernel_init,
               bias_init=cfg.bias_init)(y)
  y = jnp.asarray(y, dtype=cfg.dtype)
  y = nn.log_softmax(y)
  return y
def __call__(self, x_dict: _InputBatch) -> _LogitBatch:
  # Each feature is of shape f32[B, 1]
  x_tuple = tuple(x_dict[feature] for feature in _FEATURE_KEYS_XF)
  x_array = jnp.concatenate(x_tuple, axis=-1)  # shape: f32[B, 4]
  assert x_array.ndim == 2
  assert x_array.shape[1] == 4

  x = x_array
  x = nn.Dense(features=8)(x)
  x = nn.relu(x)
  x = nn.Dense(features=8)(x)
  x = nn.relu(x)
  x = nn.Dense(features=3)(x)
  x = nn.log_softmax(x, axis=-1)
  return x
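
# Hedged usage sketch (added, not from the original source). Assumes
# _FEATURE_KEYS_XF lists the four transformed feature names and that the
# __call__ above belongs to an nn.Module subclass, hypothetically named
# IrisClassifier.
import jax
import jax.numpy as jnp

batch = {k: jnp.zeros((8, 1), dtype=jnp.float32) for k in _FEATURE_KEYS_XF}
model = IrisClassifier()
params = model.init(jax.random.PRNGKey(0), batch)
log_probs = model.apply(params, batch)  # shape (8, 3)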
def __call__(self, x, with_classifier=True):
  x = nn.Conv(features=32, kernel_size=(3, 3))(x)
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = nn.Conv(features=64, kernel_size=(3, 3))(x)
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  # TODO: replace np.prod(x.shape[1:]) with -1 once we fix shape_polymorphism
  x = x.reshape((x.shape[0], np.prod(x.shape[1:])))  # flatten
  x = nn.Dense(features=256)(x)
  x = nn.relu(x)
  if not with_classifier:
    return x
  x = nn.Dense(features=10)(x)
  x = nn.log_softmax(x)
  return x
def cross_entropy(logits, targets, weights=None, label_smoothing=0.0):
  vocab_size = logits.shape[-1]
  confidence = 1.0 - label_smoothing
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
  normalizing_constant = -(
      confidence * jnp.log(confidence) +
      (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
  ).astype(logits.dtype)
  # Note: hard one-hot targets are used here; the on/off smoothing values are
  # not passed to onehot, so label_smoothing only affects the subtracted
  # normalizing constant.
  soft_targets = common_utils.onehot(targets, vocab_size)

  loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1,
                  dtype=logits.dtype)
  loss -= normalizing_constant

  if weights is not None:
    loss *= weights
    normalizing_factor = weights.sum()
  else:
    normalizing_factor = np.prod(targets.shape, dtype=logits.dtype)

  return loss.sum(), normalizing_factor
def _compute_weighted_cross_entropy(logits, targets, weights=None):
  """Computes weighted cross entropy and entropy for log probs and targets.

  Args:
    logits: <float>[NUM_EXAMPLES, NUM_CLASSES] predicted logits.
    targets: <int>[NUM_EXAMPLES] true labels.
    weights: <float>[NUM_EXAMPLES] relative weights for labels.

  Returns:
    Loss and normalizing factor for input batch.
  """
  onehot_targets = onehot(targets, logits.shape[-1])
  loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
  normalizing_factor = onehot_targets.sum()
  if weights is not None:
    loss = loss * weights
    normalizing_factor = weights.sum()

  return loss.sum(), normalizing_factor
def __call__(self, x):
  alpha = self.alpha
  n_step = self.n_step

  h = nn.Conv(features=alpha,
              kernel_size=(3, 3),
              use_bias=False,
              kernel_init=INITS[self.kernel_init])(x)
  h = NORMS[self.norm](use_running_average=not self.training)(h)
  h = nn.relu(h)

  for i in range(n_step):
    h += ResidualUnit(hidden_features=alpha,
                      norm=self.norm,
                      kernel_init=self.kernel_init,
                      training=self.training)(h)
  h = ResidualStitch(hidden_features=alpha,
                     output_features=2 * alpha,
                     strides=(2, 2),
                     norm=self.norm,
                     kernel_init=self.kernel_init,
                     training=self.training)(h)
  for i in range(n_step):
    h += ResidualUnit(hidden_features=2 * alpha,
                      norm=self.norm,
                      kernel_init=self.kernel_init,
                      training=self.training)(h)
  h = ResidualStitch(hidden_features=2 * alpha,
                     output_features=4 * alpha,
                     strides=(2, 2),
                     norm=self.norm,
                     kernel_init=self.kernel_init,
                     training=self.training)(h)
  for i in range(n_step):
    h += ResidualUnit(hidden_features=4 * alpha,
                      norm=self.norm,
                      kernel_init=self.kernel_init,
                      training=self.training)(h)

  h = NORMS[self.norm](use_running_average=not self.training)(h)
  # h = nn.pooling.avg_pool(h, (h.shape[-3], h.shape[-2]))
  # h = h.reshape(h.shape[0], -1)
  h = jnp.mean(h, axis=(1, 2))
  h = nn.Dense(features=self.n_classes)(h)
  return nn.log_softmax(h)  # no softmax
def __call__(self, x): """Define the model architecture. Network is used to both estimate policy (logits) and expected state value; in other words, hidden layers' params are shared between policy and value networks. Args: x: input of shape N, H, W(1) Returns: policy_log_probabilities: logits value: state value """ x = x.astype(self.dtype) x = nn.Conv( features=self.chan1, kernel_size=[3], strides=1, name='conv1', dtype=self.dtype)( x) x = nn.relu(x) x = nn.Conv( features=self.chan2, kernel_size=[3], strides=1, name='conv2', dtype=self.dtype)( x) x = nn.relu(x) x = x.reshape((x.shape[0], -1)) # flatten logits = nn.Dense( features=self.num_actions, name='logits', dtype=self.dtype)( x) policy_log_probabilities = nn.log_softmax(logits) value = nn.Dense(features=1, name='value', dtype=self.dtype)(x) return policy_log_probabilities, value
def compute_weighted_cross_entropy(self,
                                   logits,
                                   targets,
                                   weights,
                                   label_smoothing=0.0):
  """Compute weighted cross entropy and entropy for log probs and targets.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical targets [batch, length] int array.
    weights: array of shape [batch, length].
    label_smoothing: label smoothing constant, used to determine the on and
      off values.

  Returns:
    Per-position loss array of shape [batch, length] (weighted by `weights`
    if provided).
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError(
        f'Incorrect shapes. Got shape {str(logits.shape)} logits '
        f'and {str(targets.shape)} targets')

  confidence = 1.0 - label_smoothing
  low_confidence = (1.0 - confidence) / (self._vocab_size - 1)
  normalizing_constant = -(
      confidence * jnp.log(confidence) +
      ((self._vocab_size - 1) * low_confidence *
       jnp.log(low_confidence + 1e-20)))
  soft_targets = common_utils.onehot(targets,
                                     self._vocab_size,
                                     on_value=confidence,
                                     off_value=low_confidence)

  loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
  loss = loss - normalizing_constant

  if weights is not None:
    loss = loss * weights

  return loss
def cross_entropy(logits, targets):
  """Compute cross entropy for logits and targets, in bits per dimension.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical targets [batch, length] int array.

  Returns:
    Per-example loss in bits per dimension, shape [batch].
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  vocab_size = logits.shape[-1]
  onehot_targets = common_utils.onehot(targets, vocab_size)
  loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)

  d = np.prod(targets.shape[1:])
  loss = util_fns.sum_except_batch(loss) / d / np.log(2)
  return loss
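
# Note (added, not from the original source): summing the per-position
# negative log-likelihood over all non-batch dimensions, then dividing by
# d = prod(targets.shape[1:]) and by log(2), converts nats per example into
# bits per dimension:
#   bits_per_dim = nll_nats / (d * ln 2)
# which is the usual reporting unit for discrete density models.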