# Assumed module-level imports for the snippets below (these functions use
# the pre-Linen `flax.nn` API; `apply_weights`, `zeroone`, `_GroupStd`,
# `Communication`, `LAYER_EVONORM_*`, `sample_patches`, and `transformer`
# are defined elsewhere in this repo):
import einops
import jax
import jax.numpy as jnp
import numpy as np
from flax import nn
from flax.nn import initializers
from jax import lax


def weighted_binary_accuracy(logits, targets, weights=None):
  """Computes the weighted number of correctly classified examples in a batch.

  This computes the accuracy statistics over a single, potentially padded
  minibatch. If the minibatch is padded (that is, it contains null examples),
  it is assumed that `weights` is a binary mask where 0 indicates that the
  example is null. We assume the trainer will aggregate and divide by the
  number of samples.

  Args:
    logits: float array; Output of model in shape [batch, ..., 1].
    targets: float array; Target labels of shape [batch, ..., 1].
    weights: None or array of shape [batch, ...] (rank of targets - 1).

  Returns:
    A tuple of the (weighted) number of correctly classified examples and the
    normalization factor, i.e. the number of non-null examples.
  """
  if logits.ndim != targets.ndim:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  logits = nn.sigmoid(logits)
  preds = logits > 0.5
  correct = jnp.equal(preds, targets)
  if weights is not None:
    correct = apply_weights(correct, weights)
  if weights is None:
    normalization = np.prod(targets.shape[:-1])
  else:
    normalization = weights.sum()
  return jnp.sum(correct), normalization
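
# A minimal usage sketch for `weighted_binary_accuracy` (illustration only,
# not part of the library). It assumes `apply_weights` broadcasts the [batch]
# mask against `correct`, and shows the trainer-side aggregation: divide the
# summed correct count by the returned normalization.
def _example_weighted_binary_accuracy():
  logits = jnp.array([[2.0], [-1.0], [0.5], [3.0]])  # [batch, 1]
  targets = jnp.array([[1.0], [0.0], [0.0], [1.0]])  # [batch, 1]
  weights = jnp.array([1.0, 1.0, 1.0, 0.0])  # last example is padding
  correct, norm = weighted_binary_accuracy(logits, targets, weights)
  return correct / norm  # 2/3: two of the three real examples are correct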
def apply(self, x, reduction=16):
  """Squeeze-and-excite: gates each channel by a globally pooled statistic."""
  num_channels = x.shape[-1]
  # Squeeze: global average pool over the spatial dimensions.
  y = x.mean(axis=(1, 2))
  # Excite: bottleneck MLP producing per-channel gates in (0, 1).
  y = nn.Dense(y, features=num_channels // reduction, bias=False)
  y = nn.relu(y)
  y = nn.Dense(y, features=num_channels, bias=False)
  y = nn.sigmoid(y)
  # Rescale each channel of the input by its gate.
  return x * y[:, None, None, :]
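
# Composition sketch for the squeeze-excite layer above (assumption: the
# module is `SqueezeExciteLayer`, the name under which it is referenced
# elsewhere in this file; in the pre-Linen flax API, calling a module inside
# another module's `apply` invokes it).
class _ExampleSqueezeExciteBlock(nn.Module):

  def apply(self, x):
    x = nn.relu(nn.Conv(x, 64, kernel_size=(3, 3), bias=True))
    # Output shape equals input shape; only channel-wise scaling is applied.
    return SqueezeExciteLayer(x, reduction=16)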
def sigmoid_hinge_loss(logits, targets):
  """Computes a hinge loss on sigmoid probabilities given logits and labels.

  Args:
    logits: float array; Output of model in shape `[..., num_classes]`.
    targets: int array; Labels with shape `[..., num_classes]`.

  Returns:
    Loss value, summed over the class dimension.
  """
  probs = nn.sigmoid(logits)
  loss = jnp.sum(jnp.maximum(0, 1. - jnp.multiply(probs, targets)), axis=-1)
  return loss
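
# Numeric sketch for `sigmoid_hinge_loss` (illustration only). Because
# `probs` lies in (0, 1), a zero label contributes a constant
# max(0, 1 - 0) = 1 to the sum, so only positive labels receive gradient
# under a {0, 1} encoding.
def _example_sigmoid_hinge_loss():
  logits = jnp.array([[4.0, -4.0]])  # confident positive, confident negative
  targets = jnp.array([[1, 0]])
  # sigmoid(4.0) ~ 0.982, so the loss is ~ 0.018 + 1.0 = 1.018.
  return sigmoid_hinge_loss(logits, targets)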
def multilabel_accuracy(logits, labels, stats):
  """Strict multilabel classification accuracy (all labels must be correct).

  Args:
    logits: Tensor, shape [batch_size, num_classes].
    labels: Tensor, shape [batch_size, num_classes], values in {0, 1}.
    stats: Dict of statistics output by the model (unused).

  Returns:
    per_sample_success: Tensor of shape [batch_size] with values in {0, 1}.
  """
  del stats
  # The max over per-class absolute errors is 0 iff every label matches.
  error = jnp.abs(labels - jnp.round(nn.sigmoid(logits)))
  return 1. - jnp.max(error, axis=-1)
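
# Usage sketch for `multilabel_accuracy` (illustration only): the per-sample
# score is 1 only when every label is predicted correctly.
def _example_multilabel_accuracy():
  logits = jnp.array([[3.0, -2.0], [3.0, 2.0]])
  labels = jnp.array([[1.0, 0.0], [1.0, 0.0]])
  # First sample matches both labels; the second gets the second label wrong.
  return multilabel_accuracy(logits, labels, stats={})  # -> [1.0, 0.0]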
def apply(self, x, communication=Communication.NONE, train=True):
  """Forward pass."""
  batch_size = x.shape[0]

  if communication is Communication.SQUEEZE_EXCITE_X:
    x = sample_patches.SqueezeExciteLayer(x)
  # end if squeeze excite x

  d1 = nn.relu(
      nn.Conv(x, 128, kernel_size=(3, 3), strides=(1, 1), bias=True,
              name="down1"))
  d2 = nn.relu(
      nn.Conv(d1, 128, kernel_size=(3, 3), strides=(2, 2), bias=True,
              name="down2"))
  d3 = nn.relu(
      nn.Conv(d2, 128, kernel_size=(3, 3), strides=(2, 2), bias=True,
              name="down3"))

  if communication is Communication.SQUEEZE_EXCITE_D:
    d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
    d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
    d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

    nd1 = d1_flatten.shape[1]
    nd2 = d2_flatten.shape[1]

    d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten], axis=1)

    # Squeeze-excite over the concatenated tokens of all three scales.
    num_channels = d_together.shape[-1]
    y = d_together.mean(axis=1)
    y = nn.Dense(y, features=num_channels // 4, bias=False)
    y = nn.relu(y)
    y = nn.Dense(y, features=num_channels, bias=False)
    y = nn.sigmoid(y)
    d_together = d_together * y[:, None, :]

    # Split and reshape back to the three spatial resolutions.
    d1 = d_together[:, :nd1].reshape(d1.shape)
    d2 = d_together[:, nd1:nd1 + nd2].reshape(d2.shape)
    d3 = d_together[:, nd1 + nd2:].reshape(d3.shape)

  elif communication is Communication.TRANSFORMER:
    d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
    d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
    d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

    nd1 = d1_flatten.shape[1]
    nd2 = d2_flatten.shape[1]

    d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten], axis=1)

    positional_encodings = self.param(
        "scale_ratio_position_encodings",
        shape=(1,) + d_together.shape[1:],
        initializer=jax.nn.initializers.normal(1. / d_together.shape[-1]))
    d_together = transformer.Transformer(
        d_together + positional_encodings,
        num_layers=2,
        num_heads=8,
        is_training=train)

    # Split and reshape back to the three spatial resolutions.
    d1 = d_together[:, :nd1].reshape(d1.shape)
    d2 = d_together[:, nd1:nd1 + nd2].reshape(d2.shape)
    d3 = d_together[:, nd1 + nd2:].reshape(d3.shape)

  t1 = nn.Conv(d1, 6, kernel_size=(1, 1), strides=(1, 1), bias=True,
               name="tidy1")
  t2 = nn.Conv(d2, 6, kernel_size=(1, 1), strides=(1, 1), bias=True,
               name="tidy2")
  t3 = nn.Conv(d3, 9, kernel_size=(1, 1), strides=(1, 1), bias=True,
               name="tidy3")

  raw_scores = (jnp.split(t1, 6, axis=-1) + jnp.split(t2, 6, axis=-1) +
                jnp.split(t3, 9, axis=-1))

  # Min-max normalization of the scores, per example, over all heads.
  t = jnp.concatenate(
      (jnp.reshape(t1, [batch_size, -1]), jnp.reshape(t2, [batch_size, -1]),
       jnp.reshape(t3, [batch_size, -1])),
      axis=1)
  t_min = jnp.reshape(jnp.min(t, axis=-1), [batch_size, 1, 1, 1])
  t_max = jnp.reshape(jnp.max(t, axis=-1), [batch_size, 1, 1, 1])
  normalized_scores = zeroone(raw_scores, t_min, t_max)

  stats = {
      "scores": normalized_scores,
      "raw_scores": t,
  }

  # Remove the split dimension; scores are now shaped [b, h', w'].
  normalized_scores = [s.squeeze(-1) for s in normalized_scores]

  return normalized_scores, stats
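
# Invocation sketch for the scorer above (illustration only; `ScoreNet` is a
# hypothetical name for the module that owns this `apply`). In the pre-Linen
# flax API, `Module.init` returns the outputs together with the initialized
# parameters. The scorer yields 21 score maps in total (6 + 6 + 9 heads
# across the three resolutions).
def _example_scorer_init(rng, images):
  (scores, stats), params = ScoreNet.init(
      rng, images, communication=Communication.TRANSFORMER, train=False)
  return scores, stats, params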
def apply(self,
          x,
          layer=LAYER_EVONORM_B0,
          nonlinearity=True,
          num_groups=32,
          group_size=None,
          batch_stats=None,
          use_running_average=False,
          axis=-1,
          momentum=0.99,
          epsilon=1e-5,
          dtype=jnp.float32,
          bias=True,
          scale=True,
          bias_init=initializers.zeros,
          scale_init=initializers.ones,
          axis_name=None,
          axis_index_groups=None):
  """Applies EvoNorm, normalizing with batch (B0) or group (S0) statistics.

  Args:
    x: the input to be normalized.
    layer: LAYER_EVONORM_B0 or LAYER_EVONORM_S0.
    nonlinearity: use the EvoNorm nonlinearity.
    num_groups: number of groups to use for group statistics.
    group_size: size of groups, see nn.GroupNorm.
    batch_stats: a `flax.nn.Collection` used to store an exponential moving
      average of the batch statistics (default: None).
    use_running_average: if true, the statistics stored in batch_stats will be
      used instead of computing the batch statistics on the input.
    axis: the feature or non-batch axis of the input.
    momentum: decay rate for the exponential moving average of the batch
      statistics.
    epsilon: a small float added to variance to avoid dividing by zero.
    dtype: the dtype of the computation (default: float32).
    bias: if True, bias (beta) is added.
    scale: if True, multiply by scale (gamma). When the next layer is linear
      (also e.g. nn.relu), this can be disabled since the scaling will be done
      by the next layer.
    bias_init: initializer for bias, by default, zero.
    scale_init: initializer for scale, by default, one.
    axis_name: the axis name used to combine batch statistics from multiple
      devices. See `jax.pmap` for a description of axis names (default: None).
    axis_index_groups: groups of axis indices within that named axis
      representing subsets of devices to reduce over (default: None). For
      example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
      examples on the first two and last two devices. See `jax.lax.psum` for
      more details.

  Returns:
    Normalized inputs (the same shape as inputs).
  """
  x = jnp.asarray(x, jnp.float32)
  axis = axis if isinstance(axis, tuple) else (axis,)
  # pylint: disable=protected-access
  axis = nn.normalization._absolute_dims(x.ndim, axis)
  # pylint: enable=protected-access
  feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
  reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
  instance_reduction_axis = tuple(
      i for i in range(x.ndim) if i not in axis and i > 0)
  batch_reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

  if nonlinearity:
    v = self.param('v', reduced_feature_shape,
                   jax.nn.initializers.ones).reshape(feature_shape)
    if layer == LAYER_EVONORM_S0:
      den, group_shape, input_shape = _GroupStd(
          x,
          num_groups=num_groups,
          group_size=group_size,
          epsilon=epsilon,
          dtype=dtype,
      )
      x = x * nn.sigmoid(v * x)
      x = x.reshape(group_shape)
      x /= den
      x = x.reshape(input_shape)
    elif layer == LAYER_EVONORM_B0:
      if self.is_stateful() or batch_stats:
        ra_var = self.state(
            'var',
            reduced_feature_shape,
            initializers.ones,
            collection=batch_stats)
      else:
        ra_var = None
      if use_running_average:
        if ra_var is None:
          raise ValueError(
              'when use_running_average is True '
              'either use a stateful context or provide batch_stats')
        var = ra_var.value
      else:
        mean = jnp.mean(x, axis=batch_reduction_axis, keepdims=False)
        mean2 = jnp.mean(
            lax.square(x), axis=batch_reduction_axis, keepdims=False)
        if axis_name is not None and not self.is_initializing():
          concatenated_mean = jnp.concatenate([mean, mean2])
          mean, mean2 = jnp.split(
              lax.pmean(
                  concatenated_mean,
                  axis_name=axis_name,
                  axis_index_groups=axis_index_groups), 2)
        var = mean2 - lax.square(mean)

        if ra_var and not self.is_initializing():
          ra_var.value = momentum * ra_var.value + (1 - momentum) * var

      # Denominator is the max of the batch std and the instance-level term.
      left = lax.sqrt(var + epsilon)
      instance_std = jnp.sqrt(
          x.var(axis=instance_reduction_axis, keepdims=True) + epsilon)
      right = v * x + instance_std
      x = x / jnp.maximum(left, right)
    else:
      raise ValueError('Unknown EvoNorm layer: {}'.format(layer))

  if scale:
    x *= self.param('scale', reduced_feature_shape,
                    scale_init).reshape(feature_shape)
  if bias:
    x = x + self.param('bias', reduced_feature_shape,
                       bias_init).reshape(feature_shape)
  return jnp.asarray(x, dtype)
def apply(self, x):
  # Swish/SiLU-style gate with a learnable per-channel slope `tau`
  # (initialized to ones, recovering x * sigmoid(x)).
  tau = self.param('tau', x.shape[-1:], nn.initializers.ones)
  return x * nn.sigmoid(x * tau)
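
# Sketch for the activation above (illustration only; `TauSwish` is a
# hypothetical name for the owning module): at initialization it computes
# exactly swish/SiLU, and training adjusts the gate's sharpness per channel.
def _example_tau_swish(rng):
  x = jnp.linspace(-3.0, 3.0, 8).reshape(1, 8)
  y, params = TauSwish.init(rng, x)  # y.shape == x.shape
  return y, params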
def apply(self, x):
  # sigmoid(x) * sigmoid(-x) is the derivative of the sigmoid; multiplying by
  # tanh(x) gives a zero-centered bump whose peak magnitude is ~0.15, and the
  # 1 / 0.15 factor renormalizes the peak to approximately 1.
  return nn.sigmoid(x) * nn.sigmoid(-x) * nn.tanh(x) * (1 / 0.15)
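
# Numeric sketch (illustration only; `SigmoidBump` is a hypothetical name for
# the parameter-free module above). The activation is odd and bounded: its
# unscaled peak magnitude is roughly 0.15 near x ~ +/-1.1, so the rescaled
# outputs lie roughly within [-1, 1].
def _example_bump_activation(rng):
  x = jnp.linspace(-5.0, 5.0, 11).reshape(1, 11)
  y, _ = SigmoidBump.init(rng, x)
  return y  # values roughly within [-1, 1]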