Example #1
    def compute_batch_outputs(params, images):
        logits, out = model.apply({'params': flax.core.freeze(params)},
                                  images,
                                  train=False)
        pre_logits = out['pre_logits']
        if config and config.model_type == 'batchensemble':
            ens_size = config.model.transformer.ens_size
            loss_name = config.get('loss', 'sigmoid_xent')
            # The ensemble members' logits are stacked along the batch (leading) axis.
            tiled_logits = logits
            if loss_name == 'sigmoid_xent':
                ens_logits = batchensemble_utils.log_average_sigmoid_probs(
                    jnp.asarray(jnp.split(tiled_logits, ens_size)))
                pre_logits = batchensemble_utils.log_average_sigmoid_probs(
                    jnp.asarray(jnp.split(out['pre_logits'], ens_size)))
            else:  # softmax
                ens_logits = batchensemble_utils.log_average_softmax_probs(
                    jnp.asarray(jnp.split(tiled_logits, ens_size)))
                pre_logits = batchensemble_utils.log_average_softmax_probs(
                    jnp.asarray(jnp.split(out['pre_logits'], ens_size)))
            logits = ens_logits

        if use_pre_logits:
            output = pre_logits
        else:
            output = logits

        # TODO(joost,andreas): For multi-host this requires:
        # output = jax.lax.all_gather(output, axis_name='batch')
        return output
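These snippets all rely on the same reshaping step: the BatchEnsemble model returns the members' predictions concatenated along the batch axis, and `jnp.split` followed by `jnp.asarray` restacks them into a leading ensemble dimension before averaging. A minimal, self-contained sketch of that shape transformation (the sizes below are made up for illustration):

    import jax.numpy as jnp

    ens_size, batch, num_classes = 3, 4, 10  # hypothetical sizes
    tiled_logits = jnp.zeros((ens_size * batch, num_classes))  # members stacked along axis 0
    stacked = jnp.asarray(jnp.split(tiled_logits, ens_size))   # list of ens_size arrays -> one array
    assert stacked.shape == (ens_size, batch, num_classes)
    # log_average_*_probs then reduces the leading ensemble axis, giving (batch, num_classes).
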
    def evaluation_fn(params, images, labels):
        tiled_logits, out = model.apply({'params': flax.core.freeze(params)},
                                        images,
                                        train=False)

        loss_name = config.get('loss', 'sigmoid_xent')
        # TODO(dusenberrymw,zmariet): Clean up and generalize this.
        if loss_name == 'sigmoid_xent':
            ens_logits = batchensemble_utils.log_average_sigmoid_probs(
                jnp.asarray(jnp.split(tiled_logits, ens_size)))
            pre_logits = batchensemble_utils.log_average_sigmoid_probs(
                jnp.asarray(jnp.split(out['pre_logits'], ens_size)))
        else:  # softmax
            ens_logits = batchensemble_utils.log_average_softmax_probs(
                jnp.asarray(jnp.split(tiled_logits, ens_size)))
            pre_logits = batchensemble_utils.log_average_softmax_probs(
                jnp.asarray(jnp.split(out['pre_logits'], ens_size)))

        losses = getattr(train_utils,
                         loss_name)(logits=ens_logits,
                                    labels=labels[:, :config.num_classes],
                                    reduction=False)
        loss = jax.lax.psum(losses, axis_name='batch')

        top1_idx = jnp.argmax(ens_logits, axis=1)
        top1_correct = jnp.take_along_axis(labels, top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = batch_size_eval

        metric_args = jax.lax.all_gather([ens_logits, labels, pre_logits],
                                         axis_name='batch')
        return ncorrect, loss, n, metric_args
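Note that `evaluation_fn` calls `jax.lax.psum` and `jax.lax.all_gather` with `axis_name='batch'`, so it is meant to run under a transformation that defines that named axis, typically `jax.pmap`. A hedged sketch of how it might be wired up, assuming `evaluation_fn` is in scope and the arguments carry a leading device axis (`replicated_params`, `batch_images`, and `batch_labels` are placeholder names, not from the original code):

    import jax

    # pmap maps over the leading device axis and names it 'batch' for the collectives.
    pmapped_evaluation_fn = jax.pmap(evaluation_fn, axis_name='batch')
    # ncorrect, loss, n, metric_args = pmapped_evaluation_fn(
    #     replicated_params, batch_images, batch_labels)
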
Example #3
    def cifar_10h_evaluation_fn(params, images, labels, mask):
        tiled_logits, out = model.apply({'params': flax.core.freeze(params)},
                                        images,
                                        train=False)
        loss_name = config.get('loss', 'softmax_xent')
        if loss_name == 'sigmoid_xent':
            ens_logits = batchensemble_utils.log_average_sigmoid_probs(
                jnp.asarray(jnp.split(tiled_logits, ens_size)))
            pre_logits = batchensemble_utils.log_average_sigmoid_probs(
                jnp.asarray(jnp.split(out['pre_logits'], ens_size)))
        else:  # softmax
            ens_logits = batchensemble_utils.log_average_softmax_probs(
                jnp.asarray(jnp.split(tiled_logits, ens_size)))
            pre_logits = batchensemble_utils.log_average_softmax_probs(
                jnp.asarray(jnp.split(out['pre_logits'], ens_size)))

        label_indices = config.get('label_indices')
        if label_indices:
            ens_logits = ens_logits[:, label_indices]

        losses = getattr(train_utils,
                         config.get('loss', 'softmax_xent'))(logits=ens_logits,
                                                             labels=labels,
                                                             reduction=False)
        loss = jax.lax.psum(losses, axis_name='batch')

        top1_idx = jnp.argmax(ens_logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

        top1_correct = jnp.take_along_axis(one_hot_labels,
                                           top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = jax.lax.psum(one_hot_labels, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [ens_logits, labels, pre_logits, mask], axis_name='batch')
        return ncorrect, loss, n, metric_args
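CIFAR-10H labels are soft (per-image distributions over human annotator votes), which is why this snippet first collapses them to one-hot vectors via `argmax` before counting correct predictions. A tiny illustration with made-up labels:

    import jax.numpy as jnp

    # Two hypothetical soft label vectors over 10 classes (rows sum to 1).
    labels = jnp.array([[0.0, 0.7, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9]])
    one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]
    # one_hot_labels picks class 1 for the first image and class 9 for the second.
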
Example #4
    def test_log_average_probs(self, ensemble_size):
        batch_size, num_classes = 16, 3
        logits_shape = (ensemble_size, batch_size, num_classes)
        np.random.seed(42)
        ensemble_logits = jnp.asarray(np.random.normal(size=logits_shape))

        actual_logits = batchensemble_utils.log_average_softmax_probs(
            ensemble_logits)
        self.assertAllEqual(actual_logits.shape, (batch_size, num_classes))

        expected_probs = jnp.mean(jax.nn.softmax(ensemble_logits), axis=0)
        self.assertAllClose(jax.nn.softmax(actual_logits), expected_probs)

        actual_logits = batchensemble_utils.log_average_sigmoid_probs(
            ensemble_logits)
        self.assertAllEqual(actual_logits.shape, (batch_size, num_classes))

        expected_probs = jnp.mean(jax.nn.sigmoid(ensemble_logits), axis=0)
        self.assertAllClose(jax.nn.sigmoid(actual_logits), expected_probs)
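This test pins down the contract of the two averaging utilities: the returned array, pushed back through softmax (or sigmoid), must equal the mean of the members' softmax (or sigmoid) probabilities. One way to satisfy that contract, shown purely as a sketch and not necessarily the actual `batchensemble_utils` implementation, is to average in log-probability space with `logsumexp`:

    import jax
    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    def log_average_softmax_probs(ensemble_logits):
        # (ens_size, batch, classes) -> (batch, classes): log of the mean softmax probability.
        log_probs = jax.nn.log_softmax(ensemble_logits, axis=-1)
        return logsumexp(log_probs, axis=0) - jnp.log(ensemble_logits.shape[0])

    def log_average_sigmoid_probs(ensemble_logits):
        # Average p and (1 - p) across members, then convert back to a logit;
        # the 1/ens_size normalisation cancels in the difference.
        log_p = logsumexp(jax.nn.log_sigmoid(ensemble_logits), axis=0)
        log_not_p = logsumexp(jax.nn.log_sigmoid(-ensemble_logits), axis=0)
        return log_p - log_not_p
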
Example #5
  def evaluation_fn(params, images, labels, mask):
    # Ignore the entries with all zero labels for evaluation.
    mask *= labels.max(axis=1)
    # Repeat the input ensemble_size times along the channel axis.
    tiled_images = jnp.tile(images, (1, 1, 1, config.model.ensemble_size,))
    tiled_logits, out = model.apply({'params': flax.core.freeze(params)},
                                    tiled_images,
                                    train=False)
    loss_name = config.get('loss', 'sigmoid_xent')
    # TODO(dusenberrymw,zmariet): Clean up and generalize this.
    if loss_name == 'sigmoid_xent':
      logits = batchensemble_utils.log_average_sigmoid_probs(
          jnp.asarray(jnp.split(tiled_logits, config.model.ensemble_size)))
    else:  # softmax
      logits = batchensemble_utils.log_average_softmax_probs(
          jnp.asarray(jnp.split(tiled_logits, config.model.ensemble_size)))

    label_indices = config.get('label_indices')
    logging.info('!!! mask %s, label_indices %s', mask, label_indices)
    if label_indices:
      logits = logits[:, label_indices]

    # Note that logits and labels are usually of shape [batch, num_classes].
    # But for OOD data, when num_classes_ood > num_classes_ind, labels must be
    # sliced to labels[:, :config.num_classes] to match the shape of logits.
    # That only avoids a shape mismatch: the resulting losses are meaningless
    # for OOD data, because OOD samples do not belong to any IND class.
    losses = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
        logits=logits,
        labels=labels[:, :(len(label_indices) if label_indices
                           else config.num_classes)], reduction=False)
    loss = jax.lax.psum(losses * mask, axis_name='batch')

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
    n = jax.lax.psum(mask, axis_name='batch')

    metric_args = jax.lax.all_gather([logits, labels, out['pre_logits'], mask],
                                     axis_name='batch')
    return ncorrect, loss, n, metric_args
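Unlike the earlier snippets, this version also tiles the input itself: `jnp.tile` with a repeat factor on the last dimension stacks `ensemble_size` copies of each image along the channel axis. A quick shape check with made-up sizes:

    import jax.numpy as jnp

    batch, height, width, channels, ens_size = 2, 32, 32, 3, 4  # hypothetical sizes
    images = jnp.zeros((batch, height, width, channels))
    tiled_images = jnp.tile(images, (1, 1, 1, ens_size))
    assert tiled_images.shape == (batch, height, width, channels * ens_size)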