Example #1
def weighted_binary_accuracy(logits, targets, weights=None):
    """Compute weighted accuracy over the given batch.

  This computes the accuracy over a single, potentially padded minibatch.
  If the minibatch is padded (that is it contains null examples) it is assumed
  that weights is a binary mask where 0 indicates that the example is null.
  We assume the trainer will aggregate and divide by number of samples.

  Args:
   logits: float array; Output of model in shape [batch, ..., 1].
   targets: float array; Target labels of shape [batch, ..., 1].
   weights: None or array of shape [batch, ...] (rank of one_hot_targets -1).

  Returns:
    The mean accuracy of the examples in the given batch as a scalar.
  """

    if logits.ndim != targets.ndim:
        raise ValueError(
            'Incorrect shapes. Got shape %s logits and %s one_hot_targets' %
            (str(logits.shape), str(targets.shape)))

    logits = nn.sigmoid(logits)
    preds = logits > 0.5
    correct = jnp.equal(preds, targets)

    if weights is not None:
        correct = apply_weights(correct, weights)

    if weights is None:
        normalization = np.prod(targets.shape[:-1])
    else:
        normalization = weights.sum()
    return jnp.sum(correct), normalization
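A minimal usage sketch for the function above, assuming `weighted_binary_accuracy` and its `apply_weights` helper are in scope (these snippets use the pre-Linen `flax.nn` API); the toy arrays below are hypothetical:

import jax.numpy as jnp

logits = jnp.array([[2.3], [-1.7], [0.4], [5.0]])   # shape [batch, 1]
targets = jnp.array([[1.], [0.], [0.], [1.]])       # shape [batch, 1]
weights = jnp.array([1., 1., 1., 0.])                # last example is padding

num_correct, norm = weighted_binary_accuracy(logits, targets, weights)
accuracy = num_correct / norm   # the trainer would normally do this division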
Example #2
    def apply(self, x, reduction=16):
        num_channels = x.shape[-1]
        # Squeeze: global average pool over the spatial dimensions.
        y = x.mean(axis=(1, 2))
        # Excite: bottleneck MLP followed by a per-channel sigmoid gate.
        y = nn.Dense(y, features=num_channels // reduction, bias=False)
        y = nn.relu(y)
        y = nn.Dense(y, features=num_channels, bias=False)
        y = nn.sigmoid(y)
        # Rescale the input feature map channel-wise.
        return x * y[:, None, None, :]
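For reference, the same squeeze-and-excite gating can be written as a self-contained function in plain jax.numpy. This is only an illustrative sketch with random, untrained projection matrices, not the Flax module above:

import jax
import jax.numpy as jnp

def squeeze_excite(x, w1, w2):
    # x: [batch, height, width, channels]; w1/w2 are the bottleneck projections.
    y = x.mean(axis=(1, 2))            # squeeze: [batch, channels]
    y = jax.nn.relu(y @ w1)            # reduce to channels // reduction
    y = jax.nn.sigmoid(y @ w2)         # expand back and gate in (0, 1)
    return x * y[:, None, None, :]     # channel-wise rescaling

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (2, 8, 8, 16))
w1 = jax.random.normal(key, (16, 4)) * 0.1
w2 = jax.random.normal(key, (4, 16)) * 0.1
out = squeeze_excite(x, w1, w2)        # same shape as x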
Example #3
def sigmoid_hinge_loss(logits, targets):
    """Computes hinge loss given predictions and labels.

  Args:
    logits: float array; Output of model in shape `[ ..., num_classes]`.
    targets: int array; Labels with shape  `[..., num_classes]`.

  Returns:
    Loss value.
  """
    probs = nn.sigmoid(logits)
    loss = jnp.sum(jnp.maximum(0, 1. - jnp.multiply(probs, targets)), axis=-1)

    return loss
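A small, hypothetical usage sketch of the loss above, assuming the function and its pre-Linen `nn` import are in scope; the label convention ({0, 1} vs. ±1) is not stated in the snippet, so the toy targets are purely illustrative:

import jax.numpy as jnp

logits = jnp.array([[2.0, -1.0, 0.5]])
targets = jnp.array([[1, 0, 1]])               # illustrative labels only
loss = sigmoid_hinge_loss(logits, targets)     # one loss value per example, shape [1]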
Example #4
def multilabel_accuracy(logits, labels, stats):
  """Strict multilabel classification accuracy (all labels have to be correct).

  Args:
    logits: Tensor, shape [batch_size, num_classes]
    labels: Tensor, shape [batch_size, num_classes], values in {0, 1}
    stats: Dict of statistics output by the model.

  Returns:
    per_sample_success: Tensor of shape [batch_size]
  """
  del stats
  error = jnp.abs(labels - jnp.round(nn.sigmoid(logits)))
  return 1. - jnp.max(error, axis=-1)
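A hypothetical usage sketch of the strict metric above: a sample only counts as correct when every one of its labels is predicted correctly (assuming multilabel_accuracy and the pre-Linen `nn` module are in scope):

import jax.numpy as jnp

logits = jnp.array([[3.0, -2.0, 1.5],     # rounds to [1, 0, 1]
                    [0.8, 0.6, -4.0]])    # rounds to [1, 1, 0]
labels = jnp.array([[1., 0., 1.],          # all labels match -> 1.0
                    [1., 0., 0.]])         # one label wrong  -> 0.0
per_sample = multilabel_accuracy(logits, labels, stats=None)   # stats is unused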
Example #5
    def apply(self, x, communication=Communication.NONE, train=True):
        """Forward pass."""
        batch_size = x.shape[0]

        if communication is Communication.SQUEEZE_EXCITE_X:
            x = sample_patches.SqueezeExciteLayer(x)
        # end if squeeze excite x

        d1 = nn.relu(
            nn.Conv(x,
                    128,
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    bias=True,
                    name="down1"))
        d2 = nn.relu(
            nn.Conv(d1,
                    128,
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    bias=True,
                    name="down2"))
        d3 = nn.relu(
            nn.Conv(d2,
                    128,
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    bias=True,
                    name="down3"))

        if communication is Communication.SQUEEZE_EXCITE_D:
            d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
            d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
            d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

            nd1 = d1_flatten.shape[1]
            nd2 = d2_flatten.shape[1]

            d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten],
                                         axis=1)

            num_channels = d_together.shape[-1]
            y = d_together.mean(axis=1)
            y = nn.Dense(y, features=num_channels // 4, bias=False)
            y = nn.relu(y)
            y = nn.Dense(y, features=num_channels, bias=False)
            y = nn.sigmoid(y)

            d_together = d_together * y[:, None, :]

            # split and reshape
            d1 = d_together[:, :nd1].reshape(d1.shape)
            d2 = d_together[:, nd1:nd1 + nd2].reshape(d2.shape)
            d3 = d_together[:, nd1 + nd2:].reshape(d3.shape)

        elif communication is Communication.TRANSFORMER:
            d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
            d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
            d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

            nd1 = d1_flatten.shape[1]
            nd2 = d2_flatten.shape[1]

            d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten],
                                         axis=1)

            positional_encodings = self.param(
                "scale_ratio_position_encodings",
                shape=(1, ) + d_together.shape[1:],
                initializer=jax.nn.initializers.normal(1. /
                                                       d_together.shape[-1]))
            d_together = transformer.Transformer(d_together +
                                                 positional_encodings,
                                                 num_layers=2,
                                                 num_heads=8,
                                                 is_training=train)

            # split and reshape
            d1 = d_together[:, :nd1].reshape(d1.shape)
            d2 = d_together[:, nd1:nd1 + nd2].reshape(d2.shape)
            d3 = d_together[:, nd1 + nd2:].reshape(d3.shape)

        t1 = nn.Conv(d1,
                     6,
                     kernel_size=(1, 1),
                     strides=(1, 1),
                     bias=True,
                     name="tidy1")
        t2 = nn.Conv(d2,
                     6,
                     kernel_size=(1, 1),
                     strides=(1, 1),
                     bias=True,
                     name="tidy2")
        t3 = nn.Conv(d3,
                     9,
                     kernel_size=(1, 1),
                     strides=(1, 1),
                     bias=True,
                     name="tidy3")

        raw_scores = (jnp.split(t1, 6, axis=-1) + jnp.split(t2, 6, axis=-1) +
                      jnp.split(t3, 9, axis=-1))

        # The following is for normalization.
        t = jnp.concatenate((jnp.reshape(
            t1, [batch_size, -1]), jnp.reshape(
                t2, [batch_size, -1]), jnp.reshape(t3, [batch_size, -1])),
                            axis=1)
        t_min = jnp.reshape(jnp.min(t, axis=-1), [batch_size, 1, 1, 1])
        t_max = jnp.reshape(jnp.max(t, axis=-1), [batch_size, 1, 1, 1])
        normalized_scores = zeroone(raw_scores, t_min, t_max)

        stats = {
            "scores": normalized_scores,
            "raw_scores": t,
        }
        # removes the split dimension. scores are now b x h' x w' shaped
        normalized_scores = [s.squeeze(-1) for s in normalized_scores]

        return normalized_scores, stats
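The zeroone helper is not shown in this excerpt; based on how it is called with per-example minima and maxima, it presumably rescales each raw score map into [0, 1]. A hypothetical sketch of such a helper:

import jax.numpy as jnp

def zeroone(score_maps, t_min, t_max):
    # Min-max normalize every per-scale score map to [0, 1] using the
    # per-example extrema computed over all scales (hypothetical helper).
    return [(s - t_min) / (t_max - t_min) for s in score_maps]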
Example #6
    def apply(self,
              x,
              layer=LAYER_EVONORM_B0,
              nonlinearity=True,
              num_groups=32,
              group_size=None,
              batch_stats=None,
              use_running_average=False,
              axis=-1,
              momentum=0.99,
              epsilon=1e-5,
              dtype=jnp.float32,
              bias=True,
              scale=True,
              bias_init=initializers.zeros,
              scale_init=initializers.ones,
              axis_name=None,
              axis_index_groups=None):
        """Normalizes the input using batch statistics.

    Args:
      x: the input to be normalized.
      layer: LAYER_EVONORM_B0 or LAYER_EVONORM_S0.
      nonlinearity: use the EvoNorm nonlinearity.
      num_groups: number of groups to use for group statistics.
      group_size: size of groups, see nn.GroupNorm.
      batch_stats: a `flax.nn.Collection` used to store an exponential moving
        average of the batch statistics (default: None).
      use_running_average: if true, the statistics stored in batch_stats will be
        used instead of computing the batch statistics on the input.
      axis: the feature or non-batch axis of the input.
      momentum: decay rate for the exponential moving average of the batch
        statistics.
      epsilon: a small float added to variance to avoid dividing by zero.
      dtype: the dtype of the computation (default: float32).
      bias:  if True, bias (beta) is added.
      scale: if True, multiply by scale (gamma). When the next layer is linear
        (also e.g. nn.relu), this can be disabled since the scaling will be done
        by the next layer.
      bias_init: initializer for bias, by default, zero.
      scale_init: initializer for scale, by default, one.
      axis_name: the axis name used to combine batch statistics from multiple
        devices. See `jax.pmap` for a description of axis names (default: None).
      axis_index_groups: groups of axis indices within that named axis
        representing subsets of devices to reduce over (default: None). For
          example, `[[0, 1], [2, 3]]` would independently batch-normalize over
          the examples on the first two and last two devices. See `jax.lax.psum`
          for more details.

    Returns:
      Normalized inputs (the same shape as inputs).
    """
        x = jnp.asarray(x, jnp.float32)

        axis = axis if isinstance(axis, tuple) else (axis, )
        # pylint: disable=protected-access
        axis = nn.normalization._absolute_dims(x.ndim, axis)
        # pylint: enable=protected-access
        feature_shape = tuple(d if i in axis else 1
                              for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape)
                                      if i in axis)
        instance_reduction_axis = tuple(i for i in range(x.ndim)
                                        if i not in axis and i > 0)
        batch_reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

        if nonlinearity:
            v = self.param('v', reduced_feature_shape,
                           jax.nn.initializers.ones).reshape(feature_shape)
            if layer == LAYER_EVONORM_S0:
                den, group_shape, input_shape = _GroupStd(
                    x,
                    num_groups=num_groups,
                    group_size=group_size,
                    epsilon=epsilon,
                    dtype=dtype,
                )
                x = x * nn.sigmoid(v * x)
                x = x.reshape(group_shape)
                x /= den
                x = x.reshape(input_shape)
            elif layer == LAYER_EVONORM_B0:
                if self.is_stateful() or batch_stats:
                    ra_var = self.state('var',
                                        reduced_feature_shape,
                                        initializers.ones,
                                        collection=batch_stats)
                else:
                    ra_var = None

                if use_running_average:
                    if ra_var is None:
                        raise ValueError(
                            'when use_running_averages is True '
                            'either use a stateful context or provide batch_stats'
                        )
                    var = ra_var.value
                else:
                    mean = jnp.mean(x,
                                    axis=batch_reduction_axis,
                                    keepdims=False)
                    mean2 = jnp.mean(lax.square(x),
                                     axis=batch_reduction_axis,
                                     keepdims=False)
                    if axis_name is not None and not self.is_initializing():
                        concatenated_mean = jnp.concatenate([mean, mean2])
                        mean, mean2 = jnp.split(
                            lax.pmean(concatenated_mean,
                                      axis_name=axis_name,
                                      axis_index_groups=axis_index_groups), 2)
                    var = mean2 - lax.square(mean)

                    if ra_var and not self.is_initializing():
                        ra_var.value = momentum * ra_var.value + (
                            1 - momentum) * var

                left = lax.sqrt(var + epsilon)

                instance_std = jnp.sqrt(
                    x.var(axis=instance_reduction_axis, keepdims=True) +
                    epsilon)
                right = v * x + instance_std
                x = x / jnp.maximum(left, right)
            else:
                raise ValueError('Unknown EvoNorm layer: {}'.format(layer))

        if scale:
            x *= self.param('scale', reduced_feature_shape,
                            scale_init).reshape(feature_shape)
        if bias:
            x = x + self.param('bias', reduced_feature_shape,
                               bias_init).reshape(feature_shape)
        return jnp.asarray(x, dtype)
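A minimal, self-contained sketch of the EvoNorm-B0 computation implemented above (the stateless path, without the running-average bookkeeping), assuming NHWC inputs and treating v, scale and bias as given arrays rather than learned parameters:

import jax
import jax.numpy as jnp

def evonorm_b0(x, v, scale, bias, epsilon=1e-5):
    # Per-channel batch statistics (reduce over batch and spatial axes).
    batch_var = jnp.var(x, axis=(0, 1, 2))
    left = jnp.sqrt(batch_var + epsilon)
    # Per-example, per-channel instance statistics (reduce over spatial axes).
    instance_std = jnp.sqrt(jnp.var(x, axis=(1, 2), keepdims=True) + epsilon)
    right = v * x + instance_std
    # The denominator is the element-wise max of the two statistics.
    x = x / jnp.maximum(left, right)
    return x * scale + bias

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 16, 16, 32))
y = evonorm_b0(x, v=jnp.ones((32,)), scale=jnp.ones((32,)), bias=jnp.zeros((32,)))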
Example #7
    def apply(self, x):
        # Swish/SiLU with a learnable per-feature slope tau (initialized to 1).
        tau = self.param('tau', x.shape[-1:], nn.initializers.ones)
        return x * nn.sigmoid(x * tau)
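With tau held at its initial value of 1 this is exactly the standard SiLU/Swish activation; a quick hypothetical check against jax.nn.silu:

import jax
import jax.numpy as jnp

x = jnp.linspace(-3., 3., 7)
assert jnp.allclose(x * jax.nn.sigmoid(x), jax.nn.silu(x))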
Example #8
    def apply(self, x):
        # sigmoid(x) * sigmoid(-x) is the derivative of the sigmoid, so this
        # activation is sigmoid'(x) * tanh(x), rescaled by 1 / 0.15.
        return nn.sigmoid(x) * nn.sigmoid(-x) * nn.tanh(x) * (1 / 0.15)
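A short numerical check of the identity noted above (sigmoid(x) * sigmoid(-x) equals the derivative of the sigmoid), written as a hypothetical standalone snippet:

import jax
import jax.numpy as jnp

x = jnp.linspace(-3., 3., 7)
dsigmoid = jax.vmap(jax.grad(jax.nn.sigmoid))(x)
assert jnp.allclose(jax.nn.sigmoid(x) * jax.nn.sigmoid(-x), dsigmoid, atol=1e-6)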