def apply(self,
          input_ids,
          input_mask,
          type_ids,
          labels=None,
          *,
          config,
          n_classes,
          deterministic=False):
  """Applies BERT for sequence classification."""
  unused_sequence_output, pooled_output = BertModel(
      input_ids,
      input_mask,
      type_ids,
      config=config,
      deterministic=deterministic,
      name="bert")
  # TODO(kitaev): I think I'm missing dropout here
  logits = layers.OutputProjection(
      pooled_output,
      n_out=n_classes,
      kernel_init=kernel_initializer,
      name="classification")
  if labels is None:
    return logits
  elif logits.shape[-1] == 1:
    # Regression task
    loss = jnp.mean((logits[Ellipsis, 0] - labels)**2)
    return {"loss": loss}
  else:
    # Classification task
    logits = nn.log_softmax(logits)
    loss = -jnp.mean(
        jnp.sum(onehot(labels, logits.shape[-1]) * logits, axis=-1))
    return {"loss": loss}
def weighted_unnormalized_cross_entropy(logits, one_hot_targets, weights=None):
  """Compute weighted cross entropy and entropy for log probs and targets.

  This computes sum_(x,y) ce(x, y) for a single, potentially padded minibatch.
  If the minibatch is padded (that is it contains null examples) it is assumed
  that weights is a binary mask where 0 indicates that the example is null.

  Args:
    logits: [batch, length, num_classes] float array.
    one_hot_targets: one hot vector of shape [batch, ..., num_classes].
    weights: None or array of shape [batch x ...] (rank of one_hot_targets -1).

  Returns:
    Cross entropy loss computed per example, shape [batch, ...].
  """
  if logits.ndim != one_hot_targets.ndim:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s one_hot_targets' %
        (str(logits.shape), str(one_hot_targets.shape)))

  loss = -jnp.sum(one_hot_targets * nn.log_softmax(logits), axis=-1)
  if weights is not None:
    if weights.ndim != one_hot_targets.ndim - 1:
      raise ValueError(
          'Incorrect shapes. Got shape %s weights and %s one_hot_targets' %
          (str(weights.shape), str(one_hot_targets.shape)))
    loss = loss * weights

  return loss
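# Usage sketch (added for illustration, not part of the original code): applies
# the function above to a toy padded batch. The names and values below are made
# up; it assumes `weighted_unnormalized_cross_entropy` and `jax` are in scope.
import jax
import jax.numpy as jnp

example_logits = jnp.zeros((2, 3, 5))                      # [batch, length, classes]
example_targets = jax.nn.one_hot(jnp.zeros((2, 3), jnp.int32), 5)
example_weights = jnp.array([[1., 1., 1.], [1., 0., 0.]])  # zeros mark padded positions
per_position_loss = weighted_unnormalized_cross_entropy(
    example_logits, example_targets, example_weights)
# per_position_loss has shape [2, 3]; padded positions contribute zero.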
def apply_activation(intermediate_output, intermediate_activation):
  """Applies selected activation function to intermediate output."""
  if intermediate_activation is None:
    return intermediate_output

  if intermediate_activation == 'gelu':
    intermediate_output = nn.gelu(intermediate_output)
  elif intermediate_activation == 'relu':
    intermediate_output = nn.relu(intermediate_output)
  elif intermediate_activation == 'sigmoid':
    intermediate_output = nn.sigmoid(intermediate_output)
  elif intermediate_activation == 'softmax':
    intermediate_output = nn.softmax(intermediate_output)
  elif intermediate_activation == 'celu':
    intermediate_output = nn.celu(intermediate_output)
  elif intermediate_activation == 'elu':
    intermediate_output = nn.elu(intermediate_output)
  elif intermediate_activation == 'log_sigmoid':
    intermediate_output = nn.log_sigmoid(intermediate_output)
  elif intermediate_activation == 'log_softmax':
    intermediate_output = nn.log_softmax(intermediate_output)
  elif intermediate_activation == 'soft_sign':
    intermediate_output = nn.soft_sign(intermediate_output)
  elif intermediate_activation == 'softplus':
    intermediate_output = nn.softplus(intermediate_output)
  elif intermediate_activation == 'swish':
    intermediate_output = nn.swish(intermediate_output)
  elif intermediate_activation == 'tanh':
    intermediate_output = jnp.tanh(intermediate_output)
  else:
    raise NotImplementedError(
        '%s activation function is not yet supported.' %
        intermediate_activation)

  return intermediate_output
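# Usage sketch (illustrative, not from the original code): exercising the
# string dispatch above. Assumes `apply_activation` and the `nn` module it uses
# are in scope; an unknown activation name raises NotImplementedError.
import jax.numpy as jnp

hidden = jnp.linspace(-2.0, 2.0, 8).reshape(2, 4)
gelu_out = apply_activation(hidden, 'gelu')
tanh_out = apply_activation(hidden, 'tanh')
identity_out = apply_activation(hidden, None)  # None is a pass-through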
def rnn_cell(carry, x):
  # `x` is the per-step PRNG key used to sample from the cell's output
  # distribution; `eval_cell` and `inputDim` are closed over from the
  # surrounding scope.
  newCarry, logits = jax.vmap(eval_cell)(carry[0], carry[1])
  sampleOut = jax.random.categorical(x, logits)
  sample = jax.nn.one_hot(sampleOut, inputDim)
  logProb = jnp.sum(nn.log_softmax(logits) * sample, axis=1)
  # Replace NaN log probabilities with a large negative constant so that
  # downstream sums stay finite.
  return (newCarry, sample), (jnp.nan_to_num(logProb, nan=-35), sampleOut)
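# Sketch (illustrative, hypothetical helper names): a sampling cell in the form
# above is typically driven by jax.lax.scan, with one PRNG key per time step as
# the scanned-over input. This stand-alone version replaces the closed-over
# `eval_cell`/`inputDim` with a trivial cell so only the control flow is shown.
import jax
import jax.numpy as jnp

_DIM = 4

def _toy_cell(carry, key):
  logits = jnp.zeros((carry.shape[0], _DIM))          # stand-in for the RNN output
  sample_ids = jax.random.categorical(key, logits)
  sample = jax.nn.one_hot(sample_ids, _DIM)
  log_prob = jnp.sum(jax.nn.log_softmax(logits) * sample, axis=1)
  return carry, (log_prob, sample_ids)

step_keys = jax.random.split(jax.random.PRNGKey(0), 5)  # one key per time step
init_carry = jnp.zeros((2, 3))                          # [batch, hidden]
_, (step_log_probs, step_samples) = jax.lax.scan(_toy_cell, init_carry, step_keys)
# step_log_probs: [5, 2] per-step log probabilities; step_samples: [5, 2] ids.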
def apply(self, x, num_classes, num_filters=64, num_layers=50,
          train=True, dtype=jnp.float32):
  if num_layers not in _block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = _block_size_options[num_layers]
  x = nn.Conv(x, num_filters, (7, 7), (2, 2),
              padding=[(3, 3), (3, 3)],
              bias=False,
              dtype=dtype,
              name='init_conv')
  x = nn.BatchNorm(x,
                   use_running_average=not train,
                   momentum=0.9,
                   epsilon=1e-5,
                   dtype=dtype,
                   name='init_bn')
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = ResidualBlock(x, num_filters * 2**i, strides=strides,
                        train=train, dtype=dtype)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_classes)
  x = nn.log_softmax(x)
  return x
def apply(self, x, num_classes, *,
          stage_sizes, block_cls, num_filters=64,
          dtype=jnp.float32, act=nn.relu, train=True):
  conv = nn.Conv.partial(bias=False, dtype=dtype)
  norm = nn.BatchNorm.partial(
      use_running_average=not train, momentum=0.9, epsilon=1e-5, dtype=dtype)

  x = conv(x, num_filters, (7, 7), (2, 2),
           padding=[(3, 3), (3, 3)],
           name='conv_init')
  x = norm(x, name='bn_init')
  x = nn.relu(x)
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(stage_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = block_cls(x, num_filters * 2 ** i,
                    strides=strides,
                    conv=conv,
                    norm=norm,
                    act=act)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_classes, dtype=dtype)
  x = jnp.asarray(x, dtype)
  x = nn.log_softmax(x)
  return x
def apply(self, x, hidden_layers, hidden_dim, n_classes):
  x = jnp.reshape(x, (x.shape[0], -1))
  for layer in range(hidden_layers):
    x = nn.Dense(x, hidden_dim, name=f'fc{layer}')
    x = nn.relu(x)
  x = nn.Dense(x, n_classes, name=f'fc{hidden_layers}')
  preds = nn.log_softmax(x)
  return preds
def apply(self, x, num_classes, num_filters=64, num_layers=50,
          train=True, axis_name=None, axis_index_groups=None,
          dtype=jnp.float32, conv0_space_to_depth=False):
  if num_layers not in _block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = _block_size_options[num_layers]

  if conv0_space_to_depth:
    conv = SpaceToDepthConv.partial(block_size=(2, 2),
                                    padding=[(2, 1), (2, 1)])
  else:
    conv = nn.Conv.partial(padding=[(3, 3), (3, 3)])
  x = conv(x, num_filters, kernel_size=(7, 7), strides=(2, 2),
           bias=False, dtype=dtype, name='conv0')
  x = nn.BatchNorm(x,
                   use_running_average=not train,
                   momentum=0.9,
                   epsilon=1e-5,
                   name='init_bn',
                   axis_name=axis_name,
                   axis_index_groups=axis_index_groups,
                   dtype=dtype)
  x = nn.relu(x)  # MLPerf-required
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = ResidualBlock(x, num_filters * 2**i, strides=strides,
                        train=train, axis_name=axis_name,
                        axis_index_groups=axis_index_groups, dtype=dtype)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
               dtype=dtype)
  x = nn.log_softmax(x)
  return x
def apply(self, x):
  x = nn.Conv(x, features=32, kernel_size=(3, 3))
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = nn.Conv(x, features=64, kernel_size=(3, 3))
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(x, features=256)
  x = nn.relu(x)
  x = nn.Dense(x, features=10)
  x = nn.log_softmax(x)
  return x
def apply(self, x, num_classes, parameters, num_filters=64, train=True,
          axis_name=None, num_layers='34'):
  block_sizes = [3, 4, 6]
  data_format = 'channels_last'
  if ('conv0_space_to_depth' in parameters and
      parameters['conv0_space_to_depth']):
    # conv0 uses space-to-depth transform for TPU performance.
    x = func_conv0_space_to_depth(inputs=x, data_format=data_format,
                                  dtype=parameters['dtype'])
  else:
    x = conv2d_fixed_padding(
        inputs=x,
        filters=num_filters,
        kernel_size=7,
        strides=2,
        data_format=data_format,
        name='init_conv')
  replica_groups = _make_replica_groups(parameters)
  x = nn.BatchNorm(
      x,
      use_running_average=not train,
      momentum=0.9,
      epsilon=1e-5,
      name='init_bn',
      axis_name=axis_name,
      dtype=parameters['dtype'],
      axis_index_groups=replica_groups)
  x = nn.relu(x)
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = 2 if i == 1 and j == 0 else 1
      use_projection = False if i == 0 or j > 0 else True
      x = ResidualBlock(
          x,
          num_filters * 2**i,
          parameters,
          strides=strides,
          train=train,
          axis_name=axis_name,
          use_projection=use_projection,
          data_format=data_format)
  if num_layers == '34':
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
                 dtype=jnp.float32)  # TODO(deveci): dtype=dtype
    x = nn.log_softmax(x)
  return x
def _compute_metrics(self, masked_lm_logits, next_sentence_logits,
                     masked_lm_labels, masked_lm_weights,
                     next_sentence_labels, **unused_kwargs):
  """Computes the pre-training loss and its components."""
  masked_lm_logits = nn.log_softmax(masked_lm_logits)
  masked_lm_labels = onehot(
      masked_lm_labels.reshape((-1,)), masked_lm_logits.shape[-1])
  masked_lm_weights = masked_lm_weights.reshape((-1,))
  masked_lm_loss = -jnp.sum(
      jnp.sum(masked_lm_logits * masked_lm_labels, axis=-1) *
      masked_lm_weights) / jnp.sum(masked_lm_weights)

  next_sentence_logits = nn.log_softmax(next_sentence_logits)
  next_sentence_labels = next_sentence_labels.reshape((-1,))
  next_sentence_loss = -jnp.mean(
      jnp.sum(
          onehot(next_sentence_labels, next_sentence_logits.shape[-1]) *
          next_sentence_logits,
          axis=-1))
  return {
      'loss': masked_lm_loss + next_sentence_loss,
      'masked_lm_loss': masked_lm_loss,
      'next_sentence_loss': next_sentence_loss,
  }
def softmax_cross_entropy(logits, onehot_labels):
  """Returns a cross-entropy loss tensor.

  Note that `onehot_labels` and `logits` must have the same shape,
  e.g. `[batch_size, num_classes]`. This does not perform a reduction on loss,
  e.g., loss is a `Tensor` of shape `[batch_size]`.

  Args:
    logits: Float-like logits outputs of the network.
    onehot_labels: Float-like one-hot-encoded labels.

  Returns:
    Loss `Tensor` of the same type as `logits`, which has shape `[batch_size]`.
  """
  log_softmax = -nn.log_softmax(logits)
  return jnp.sum(log_softmax * onehot_labels, axis=-1)
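# Usage sketch (illustrative, not from the original code): unlike the weighted
# variants elsewhere in this file, softmax_cross_entropy returns one loss value
# per example with no reduction, so the caller chooses how to aggregate.
import jax.numpy as jnp

demo_logits = jnp.array([[2.0, 0.5, -1.0],
                         [0.1, 0.2, 0.3]])
demo_labels = jnp.array([[1.0, 0.0, 0.0],
                         [0.0, 0.0, 1.0]])
per_example = softmax_cross_entropy(demo_logits, demo_labels)  # shape [2]
mean_loss = jnp.mean(per_example)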
def compute_weighted_cross_entropy(logits, labels):
  """Compute weighted cross entropy and entropy for log probs and labels.

  Args:
    logits: [batch, length, num_classes] float array.
    labels: categorical targets [batch, length] int array.

  Returns:
    Tuple of scalar loss and per-example loss.
  """
  log_probs = nn.log_softmax(logits)
  labels = jnp.reshape(labels, [-1])
  one_hot_labels = common_utils.onehot(labels, num_classes=2)
  per_example_loss = -jnp.sum(one_hot_labels * log_probs, axis=-1)
  loss = jnp.mean(per_example_loss)
  return (loss, per_example_loss)
def get_masked_lm_output(logits, label_ids, label_weights):
  """Calculate masked_lm loss for pretrain task."""
  vocab_size = logits.shape[-1]

  label_ids = jnp.reshape(label_ids, (-1))
  label_weights = jnp.reshape(label_weights, (-1))
  one_hot_labels = common_utils.onehot(
      label_ids, vocab_size, on_value=1.0, off_value=0.0)

  log_probs = nn.log_softmax(logits)
  per_example_loss = -jnp.sum(log_probs * one_hot_labels, axis=-1)
  numerator = jnp.sum(label_weights * per_example_loss)
  denominator = jnp.sum(label_weights) + 1e-5
  loss = numerator / denominator

  return loss, per_example_loss, log_probs
def apply(self, x, num_classes, train=True, batch_stats=None,
          axis_name=None, dtype=jnp.float32):
  x = nn.BatchNorm(x,
                   batch_stats=batch_stats,
                   use_running_average=not train,
                   momentum=0.9,
                   epsilon=1e-5,
                   name='init_bn',
                   axis_name=axis_name,
                   dtype=dtype)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
               dtype=dtype)
  x = nn.log_softmax(x)
  return x
def compute_logprob(inputs, model, mask_token=None):
  """Returns an array of log probabilities for the input sequences."""
  assert inputs.ndim == 2

  targets = inputs
  weights = jnp.where(targets != model.pad_token, 1, 0)
  if mask_token is not None:
    weights *= jnp.where(targets != mask_token, 1, 0)
  logits = model.score(inputs)
  assert logits.ndim == 3

  onehot_targets = common_utils.onehot(targets, logits.shape[-1])
  log_lik = jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
  log_lik *= weights
  log_prob = jnp.sum(log_lik, axis=-1)

  return log_prob
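# Usage sketch (illustrative): compute_logprob only needs an object exposing
# `pad_token` and `score(inputs) -> [batch, length, vocab]` logits, so a small
# stand-in is enough to show the call. `_UniformScorer` is hypothetical and not
# part of the original code.
import jax.numpy as jnp

class _UniformScorer:
  pad_token = 0
  vocab_size = 6

  def score(self, inputs):
    # Uniform logits over the vocabulary for every position.
    return jnp.zeros(inputs.shape + (self.vocab_size,))

seqs = jnp.array([[3, 4, 5, 0, 0]])            # trailing zeros are padding
lp = compute_logprob(seqs, _UniformScorer())   # log(1/6) summed over 3 real tokens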
def compute_weighted_cross_entropy(logits, targets, token_weights=None,
                                   example_weights=None):
  """Compute weighted cross entropy and entropy for log probs and targets.

  The loss is assumed to be sum_i example_weights[i] * logprob[i], where i
  indexes elements in the batch. logprob[i] is the log probability of sequence
  i, which is a weighted average of per-token log probabilities with weights
  according to token_weights. Typically token_weights is a mask for whether
  tokens are padding or not.

  Maximum likelihood training sets example_weights[i] = 1.
  Training with a REINFORCE-style objective may set example_weights[i] to any
  positive or negative number.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical targets [batch, length] int array.
    token_weights: None or array of shape [batch x length]
    example_weights: None or array of shape [batch_size]

  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s targets' %
        (str(logits.shape), str(targets.shape)))

  onehot_targets = common_utils.onehot(targets, logits.shape[-1])
  loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
  normalizing_factor = onehot_targets.sum()
  if token_weights is not None:
    loss = loss * token_weights
    normalizing_factor = token_weights.sum()
  if example_weights is not None:
    loss = loss.sum(axis=1)
    loss *= example_weights

  return loss.sum(), normalizing_factor
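# Usage sketch (illustrative, not from the original code): with per-example
# weights, the summed loss can act as a REINFORCE-style surrogate objective, as
# the docstring above describes. Shapes and reward values here are made up.
import jax.numpy as jnp

seq_logits = jnp.zeros((2, 4, 8))                        # [batch, length, vocab]
seq_targets = jnp.array([[1, 2, 3, 0], [4, 5, 0, 0]])
pad_mask = (seq_targets != 0).astype(jnp.float32)        # token_weights
rewards = jnp.array([1.0, -0.5])                         # example_weights
loss_sum, norm = compute_weighted_cross_entropy(
    seq_logits, seq_targets, token_weights=pad_mask, example_weights=rewards)
# Callers typically report loss_sum / norm as the per-token average.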
def compute_weighted_cross_entropy(logits, targets, num_classes, weights=None):
  """Compute weighted cross entropy and entropy for log probs and targets.

  Args:
    logits: [batch, num_classes] float array.
    targets: categorical targets [batch, length] int array.
    num_classes: int, num classes of problem.
    weights: None or array of shape [batch x length]

  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
  onehot_targets = common_utils.onehot(targets, num_classes)
  loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
  normalizing_factor = onehot_targets.sum()
  if weights is not None:
    loss = loss * weights
    normalizing_factor = weights.sum()

  return loss.sum(), normalizing_factor
def apply(self, node_x, edge_x, sources, targets):
  """Computes GNN forward pass.

  Args:
    node_x: node features with shape of `[num_nodes, num_features]`.
    edge_x: `None` or edge features with shape of `[num_edges, num_features]`.
    sources: Array of source node indices with shape of `[num_edges]`.
    targets: Array of target node indices with shape of `[num_edges]`.

  Returns:
    Output of shape `[num_nodes, num_features]`.
  """
  node_x = GraphConvBlock(node_x, edge_x, sources, targets, features=32)
  node_x = nn.relu(node_x)
  node_x = GraphConvBlock(node_x, edge_x, sources, targets, features=2)
  node_x = nn.log_softmax(node_x)
  return node_x
def compute_weighted_cross_entropy(logits, targets, weights=None,
                                   label_smoothing=0.0):
  """Compute weighted cross entropy and entropy for log probs and targets.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical targets [batch, length] int array.
    weights: None or array of shape [batch, length].
    label_smoothing: label smoothing constant, used to determine the on and off
      values.

  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s targets' %
        (str(logits.shape), str(targets.shape)))
  vocab_size = logits.shape[-1]
  confidence = 1.0 - label_smoothing
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
  normalizing_constant = -(
      confidence * jnp.log(confidence) +
      (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
  soft_targets = common_utils.onehot(
      targets, vocab_size, on_value=confidence, off_value=low_confidence)

  loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
  loss = loss - normalizing_constant

  normalizing_factor = jnp.prod(targets.shape)
  if weights is not None:
    loss = loss * weights
    normalizing_factor = weights.sum()

  return loss.sum(), normalizing_factor
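# Worked example (illustrative, values made up): with label_smoothing=0.1 and
# vocab_size=5, the soft targets put confidence 0.9 on the true class and
# (1 - 0.9) / (5 - 1) = 0.025 on each other class; normalizing_constant is the
# entropy of that smoothed distribution, so a model that predicts the smoothed
# distribution exactly gets (close to) zero loss.
import jax.numpy as jnp

toy_logits = jnp.zeros((1, 2, 5))
toy_targets = jnp.array([[3, 1]])
smoothed_loss, norm = compute_weighted_cross_entropy(
    toy_logits, toy_targets, label_smoothing=0.1)
mean_loss = smoothed_loss / norm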
def apply(self, x, num_outputs):
  """Define the convolutional network architecture.

  Architecture originates from "Human-level control through deep reinforcement
  learning.", Nature 518, no. 7540 (2015): 529-533.
  Note that this is different than the one from "Playing atari with deep
  reinforcement learning." arxiv.org/abs/1312.5602 (2013)
  """
  dtype = jnp.float32
  x = x.astype(dtype) / 255.
  x = nn.Conv(x, features=32, kernel_size=(8, 8), strides=(4, 4),
              name='conv1', dtype=dtype)
  x = nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(4, 4), strides=(2, 2),
              name='conv2', dtype=dtype)
  x = nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(3, 3), strides=(1, 1),
              name='conv3', dtype=dtype)
  x = nn.relu(x)
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(x, features=512, name='hidden', dtype=dtype)
  x = nn.relu(x)
  # Network used to both estimate policy (logits) and expected state value.
  # See github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py
  logits = nn.Dense(x, features=num_outputs, name='logits', dtype=dtype)
  policy_log_probabilities = nn.log_softmax(logits)
  value = nn.Dense(x, features=1, name='value', dtype=dtype)
  return policy_log_probabilities, value
def compute_weighted_cross_entropy(logits, targets, weights=None):
  """Compute weighted cross entropy and entropy for log probs and targets.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical targets [batch, length] int array.
    weights: None or array of shape [batch x length]

  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s targets' %
        (str(logits.shape), str(targets.shape)))
  onehot_targets = common_utils.onehot(targets, logits.shape[-1])
  loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
  normalizing_factor = onehot_targets.sum()
  if weights is not None:
    loss = loss * weights
    normalizing_factor = weights.sum()

  return loss.sum(), normalizing_factor
def apply(self, input_ids, input_mask, type_ids, labels=None, *,
          config, n_classes, deterministic=False):
  """Applies BERT for sequence classification."""
  unused_sequence_output, pooled_output = BertModel(
      input_ids,
      input_mask,
      type_ids,
      config=config,
      deterministic=deterministic,
      name='bert')
  pooled_output = nn.dropout(
      pooled_output,
      rate=config.hidden_dropout_prob,
      deterministic=deterministic)
  logits = layers.OutputProjection(
      pooled_output,
      n_out=n_classes,
      kernel_init=get_kernel_init(config),
      name='classification')
  if labels is None:
    return logits
  elif logits.shape[-1] == 1:
    # Regression task
    loss = jnp.mean((logits[..., 0] - labels)**2)
    return {'loss': loss}
  else:
    # Classification task
    logits = nn.log_softmax(logits)
    loss = -jnp.mean(
        jnp.sum(onehot(labels, logits.shape[-1]) * logits, axis=-1))
    return {'loss': loss}
def cross_entropy_loss(logits, labels, lengths):
  """Returns cross-entropy loss."""
  xe = jnp.sum(nn.log_softmax(logits) * labels, axis=-1)
  masked_xe = jnp.mean(mask_sequences(xe, lengths))
  return -masked_xe
def compute_cross_entropy(logits, targets):
  """Returns per-position cross entropy between logits and integer targets."""
  onehot_targets = common_utils.onehot(targets, logits.shape[-1])
  loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
  return loss
def rnn_cell(carry, x):
  # `x` is the (assumed one-hot) input/target for this step; its per-step log
  # probability under the model is returned, and `x` is fed back as the next
  # carry element. `rnnCell`, `actFun`, and `probDense` are closed over from
  # the surrounding scope.
  newCarry, out = rnnCell(carry[0], carry[1])
  logProb = nn.log_softmax(actFun(probDense(out)))
  logProb = jnp.sum(logProb * x, axis=-1)
  return (newCarry, x), (jnp.nan_to_num(logProb, nan=-35), out)
def batch_ce_loss(logits, targets):
  # One-hot encode the integer targets over the (hard-coded) 4 classes, then
  # take the cross entropy against the log-softmax of the logits.
  one_hot_targets = np.eye(4)[targets]
  loss = -np.sum(one_hot_targets * nn.log_softmax(logits), axis=-1)
  return loss
def cross_entropy_loss(logits, labels):
  """Returns cross-entropy loss."""
  return -jnp.mean(
      jnp.sum(nn.log_softmax(logits) * labels[:, 1:], axis=-1))