Example 1
 def concatenate(
     self, supervised_inputs: chex.ArrayTree, extra_inputs: chex.ArrayTree
 ) -> Tuple[chex.Array, chex.Array, chex.Array]:
     """Concatenate inputs."""
     num_classes = self.config.num_classes
     supervised_images = supervised_inputs['image']
     supervised_labels = supervised_inputs['label']
     if extra_inputs is None:
         images = supervised_images
         labels = supervised_labels
         target_probs = hk.one_hot(labels, num_classes)
     else:
         extra_images = extra_inputs['image']
         images = jnp.concatenate([supervised_images, extra_images], axis=0)
         extra_labels = extra_inputs['label']
         labels = jnp.concatenate([supervised_labels, extra_labels], axis=0)
         supervised_one_hot_labels = hk.one_hot(supervised_labels,
                                                num_classes)
         extra_one_hot_labels = hk.one_hot(extra_labels, num_classes)
         if self.config.training.extra_label_smoothing > 0:
             pos = 1. - self.config.training.extra_label_smoothing
             neg = self.config.training.extra_label_smoothing / num_classes
             extra_one_hot_labels = pos * extra_one_hot_labels + neg
         target_probs = jnp.concatenate(
             [supervised_one_hot_labels, extra_one_hot_labels], axis=0)
     return images, labels, target_probs
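A small numeric sketch of the label-smoothing branch above, assuming extra_label_smoothing = 0.1 and num_classes = 10 (illustrative values, not taken from the original config):

import jax.numpy as jnp
import haiku as hk

num_classes = 10
smoothing = 0.1                   # assumed value of config.training.extra_label_smoothing
pos = 1. - smoothing              # 0.9
neg = smoothing / num_classes     # 0.01
one_hot = hk.one_hot(jnp.array([3]), num_classes)
smoothed = pos * one_hot + neg
# smoothed[0, 3] == 0.91, every other entry == 0.01, and smoothed.sum() == 1.0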
Example 2
    def _eval_batch(
        self,
        params: hk.Params,
        state: hk.State,
        batch: dataset.Batch,
    ) -> Mapping[Text, jnp.ndarray]:
        """Evaluates a batch.

        Args:
          params: Parameters of the model to evaluate. Typically Byol's online
            parameters.
          state: State of the model to evaluate. Typically Byol's online state.
          batch: Batch of data to evaluate (must contain keys images and labels).

        Returns:
          Unreduced evaluation loss and top-1 / top-5 accuracy on the batch.
        """
        if self._should_transpose_images():
            batch = dataset.transpose_images(batch)

        outputs, _ = self.forward.apply(params,
                                        state,
                                        batch,
                                        is_training=False)
        logits = outputs['logits']
        labels = hk.one_hot(batch['labels'], self._num_classes)
        loss = helpers.softmax_cross_entropy(logits, labels, reduction=None)
        top1_correct = helpers.topk_accuracy(logits, batch['labels'], topk=1)
        top5_correct = helpers.topk_accuracy(logits, batch['labels'], topk=5)
        # NOTE: Returned values will be summed and finally divided by num_samples.
        return {
            'eval_loss': loss,
            'top1_accuracy': top1_correct,
            'top5_accuracy': top5_correct,
        }
Example 3
    def _loss_fn(
        self,
        backbone_params: hk.Params,
        classif_params: hk.Params,
        backbone_state: hk.State,
        inputs: dataset.Batch,
    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, hk.State]]:
        """Compute the classification loss function.

        Args:
          backbone_params: parameters of the encoder network.
          classif_params: parameters of the linear classifier.
          backbone_state: internal state of encoder network.
          inputs: inputs, containing `images` and `labels`.

        Returns:
          The classification loss and various logs.
        """
        embeddings, backbone_state = self.forward_backbone.apply(
            backbone_params,
            backbone_state,
            inputs,
            is_training=not self._freeze_backbone)

        logits = self.forward_classif.apply(classif_params, embeddings)
        labels = hk.one_hot(inputs['labels'], self._num_classes)
        loss = helpers.softmax_cross_entropy(logits, labels, reduction='mean')
        scaled_loss = loss / jax.device_count()

        return scaled_loss, (loss, backbone_state)
Example 4
def untargeted_margin(logits: chex.Array, labels: chex.Array) -> chex.Array:
    """Make the highest non-correct logits higher than the true class logits."""
    batch_size = logits.shape[0]
    num_classes = logits.shape[-1]
    label_logits = logits[jnp.arange(batch_size), labels]
    logit_mask = hk.one_hot(labels, num_classes).astype(logits.dtype)
    highest_logits = jnp.max(logits - 1e8 * logit_mask, axis=-1)
    return label_logits - highest_logits
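A minimal usage sketch with made-up numbers (not part of the original code): a batch of two examples over three classes, with true labels [0, 2].

import jax.numpy as jnp

logits = jnp.array([[2.0, 1.0, 0.5],
                    [0.1, 0.3, 0.2]])
labels = jnp.array([0, 2])
margins = untargeted_margin(logits, labels)
# margins ~= [1.0, -0.1]: the true-class logit minus the best other logit,
# so a negative margin means the example is already misclassified.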
Example 5
def untargeted_cross_entropy(logits: chex.Array,
                             labels: chex.Array) -> chex.Array:
    """Maximize the cross-entropy of the true class (make it less likely)."""
    num_classes = logits.shape[-1]
    log_probs = jax.nn.log_softmax(logits)
    return jnp.sum(hk.one_hot(labels, num_classes).astype(logits.dtype) *
                   log_probs,
                   axis=-1)
Example 6
 def loss_fun(params, batch):
     """Training loss to optimize."""
     outputs = model.apply(params, batch)
     labels = hk.one_hot(batch['label'], len(meta.classes))
     softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(outputs['scores']))
     softmax_xent /= labels.shape[0]
     disperse = embedding_regularizer(outputs['embedded'], batch['label'],
                                      meta)
     return softmax_xent + disperse + outputs['penalties']
Example 7
def make_network() -> hk.RNNCore:
    """Defines the network architecture."""
    model = hk.DeepRNN([
        lambda x: hk.one_hot(x, num_classes=dataset.NUM_CHARS),
        hk.LSTM(FLAGS.hidden_size),
        jax.nn.relu,
        hk.LSTM(FLAGS.hidden_size),
        hk.nets.MLP([FLAGS.hidden_size, dataset.NUM_CHARS]),
    ])
    return model
Example 8
    def loss(params: hk.Params, batch: Batch) -> jnp.ndarray:
        """Compute the loss of the network, including L2."""
        logits = net.apply(params, batch)
        labels = hk.one_hot(batch["label"], 10)

        l2_loss = 0.5 * sum(
            jnp.sum(jnp.square(p)) for p in jax.tree_leaves(params))
        softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
        softmax_xent /= labels.shape[0]

        return softmax_xent + 1e-4 * l2_loss
Example 9
def lm_loss_fn(forward_fn, vocab_size, params, rng, data, is_training=True):
    """Compute the loss on data wrt params."""
    logits = forward_fn(params, rng, data, is_training)
    targets = hk.one_hot(data['target'], vocab_size)
    assert logits.shape == targets.shape

    mask = jnp.greater(data['obs'], 0)
    loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
    loss = jnp.sum(loss * mask) / jnp.sum(mask)

    return loss
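A small numeric sketch of the padding mask above (values made up): positions where data['obs'] is 0 are treated as padding and excluded from the average.

import jax.numpy as jnp

obs = jnp.array([[5, 7, 0, 0]])                      # two real tokens, two padding tokens
mask = jnp.greater(obs, 0)                           # [[True, True, False, False]]
per_token_loss = jnp.array([[1.2, 0.8, 3.0, 3.0]])   # losses at padded positions are ignored
loss = jnp.sum(per_token_loss * mask) / jnp.sum(mask)  # (1.2 + 0.8) / 2 = 1.0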
Example 10
def sequence_loss(batch: dataset.Batch) -> jnp.ndarray:
    """Unrolls the network over a sequence of inputs & targets, gets loss."""
    # Note: this function is impure; we hk.transform() it below.
    core = make_network()
    sequence_length, batch_size = batch['input'].shape
    initial_state = core.initial_state(batch_size)
    logits, _ = hk.dynamic_unroll(core, batch['input'], initial_state)
    log_probs = jax.nn.log_softmax(logits)
    one_hot_labels = hk.one_hot(batch['target'], num_classes=logits.shape[-1])
    return -jnp.sum(one_hot_labels * log_probs) / (sequence_length *
                                                   batch_size)
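As the comment in the snippet notes, sequence_loss is impure and is meant to be wrapped with hk.transform. A minimal sketch of that wrapping, assuming the surrounding module (make_network, dataset, FLAGS) is in scope and that the batch arrays are integer ids of shape [sequence_length, batch_size] (shapes below are assumed):

loss_fn = hk.without_apply_rng(hk.transform(sequence_loss))
dummy_batch = {
    'input': jnp.zeros((32, 8), dtype=jnp.int32),   # assumed sequence_length=32, batch_size=8
    'target': jnp.zeros((32, 8), dtype=jnp.int32),
}
params = loss_fn.init(jax.random.PRNGKey(0), dummy_batch)
loss = loss_fn.apply(params, dummy_batch)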
Example 11
File: train.py  Project: ibab/haiku
def loss_fn(
    params: hk.Params,
    state: hk.State,
    batch: dataset.Batch,
) -> Tuple[jnp.ndarray, hk.State]:
  """Computes a regularized loss for the given batch."""
  logits, state = forward.apply(params, state, None, batch, is_training=True)
  labels = hk.one_hot(batch['labels'], 1000)
  cat_loss = jnp.mean(softmax_cross_entropy(logits=logits, labels=labels))
  l2_params = [p for ((mod_name, _), p) in tree.flatten_with_path(params)
               if 'batchnorm' not in mod_name]
  reg_loss = FLAGS.train_weight_decay * l2_loss(l2_params)
  loss = cat_loss + reg_loss
  return loss, state
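The l2_loss helper is not shown in this snippet; a definition consistent with the call above (an assumption mirroring the pattern in Example 8, not necessarily the project's exact code) would be:

def l2_loss(params_list):
    """Sum of 0.5 * ||p||^2 over a list of parameter arrays."""
    return 0.5 * sum(jnp.sum(jnp.square(p)) for p in params_list)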
Example 12
    def update(params, state, step, image, label):
        """One-vs-all classifier update fn."""

        # Learning rate schedules.
        learning_rate = jnp.minimum(
            FLAGS.max_lr, FLAGS.lr_constant / (1. + FLAGS.lr_decay * step))

        # Update weights and report log-loss.
        targets = hk.one_hot(jnp.asarray(label), num_classes)

        fn = jax.vmap(update_, in_axes=(0, 0, None, 0, None))
        out = fn(params, state, image, targets, learning_rate)
        (params, unused_predictions, log_loss), state = out
        return (jnp.mean(log_loss), params), state
Example 13
    def update(params, state, step, image, label):
        """One-vs-all classifier update fn."""

        # Learning rate schedules.
        learning_rate = jnp.minimum(
            MAX_LR.value, LR_CONSTANT.value / (1. + LR_DECAY.value * step))

        # Update weights and report log-loss.
        targets = hk.one_hot(jnp.asarray(label), num_classes)

        fn = jax.vmap(update_, in_axes=(0, 0, None, 0, None))
        out = fn(params, state, image, targets, learning_rate)
        (params, unused_predictions, log_loss), state = out
        return (jnp.mean(log_loss), params), state
Example 14
    def loss(
        params: hk.Params,
        inputs: np.ndarray,
        targets: np.ndarray,
    ) -> jnp.DeviceArray:
        """Compute the loss of the network, including L2."""
        assert targets.dtype == np.int32
        batch_size = inputs.shape[0]
        log_probs = net.apply(params, inputs)

        l2_loss = 0.5 * sum(
            jnp.sum(jnp.square(p)) for p in jax.tree_leaves(params))
        softmax_xent = -jnp.sum(hk.one_hot(targets, NUM_DIGITS) * log_probs)
        softmax_xent = softmax_xent / batch_size

        return softmax_xent + 1e-4 * l2_loss
Example 15
def sequence_prediction_metrics(
        logits: jnp.ndarray,
        labels: jnp.ndarray,
        mask: Optional[jnp.ndarray] = None) -> Dict[str, float]:
    """Compute the metrics for sequence prediction.

    Args:
      logits: [B, T, V] array of logits.
      labels: [B, T] array of labels.
      mask: [B, T] array of binary masks, if provided.

    Returns:
      metrics: a dictionary of metrics.
    """
    vocab_size = logits.shape[-1]
    logps = jax.nn.log_softmax(logits)
    labels_one_hot = hk.one_hot(labels, vocab_size)
    class_logps = jnp.sum(logps * labels_one_hot, axis=-1)
    prediction_correct = jnp.argmax(logits, axis=-1) == labels
    if mask is not None:
        masked_logps = mask * class_logps
        total_count = jnp.sum(mask)
        tokens_correct = jnp.sum(prediction_correct * mask)
        seq_correct = jnp.all(jnp.logical_or(prediction_correct,
                                             jnp.logical_not(mask)),
                              axis=-1)
    else:
        masked_logps = class_logps
        total_count = np.prod(class_logps.shape)
        tokens_correct = jnp.sum(prediction_correct)
        seq_correct = jnp.all(prediction_correct, axis=-1)

    token_accuracy = tokens_correct.astype(jnp.float32) / total_count
    seq_accuracy = jnp.mean(seq_correct)
    log_probs = jnp.mean(jnp.sum(masked_logps, axis=-1))
    total_loss = -jnp.sum(masked_logps)
    loss = total_loss / total_count
    return dict(
        loss=loss,
        total_loss=total_loss,
        total_count=total_count,
        token_accuracy=token_accuracy,
        seq_accuracy=seq_accuracy,
        log_probs=log_probs,
    )
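A minimal usage sketch with made-up shapes: a batch of 2 sequences of length 3 over a vocabulary of 5 tokens, with the final step of the second sequence masked out as padding.

import jax
import jax.numpy as jnp

logits = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 5))  # [B, T, V]
labels = jnp.array([[1, 2, 3], [4, 0, 0]])                    # [B, T]
mask = jnp.array([[1., 1., 1.], [1., 1., 0.]])                # [B, T]; last step of sequence 2 is padding
metrics = sequence_prediction_metrics(logits, labels, mask)
# metrics['loss'] and metrics['token_accuracy'] are averaged over the 5 unmasked positions, while
# metrics['seq_accuracy'] ignores the masked step when deciding whether a sequence is fully correct.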
Example 16
    def __call__(self, inputs: observation_action_reward.OAR,
                 state: LSTMState) -> Tuple[Tuple[Logits, Value], LSTMState]:
        reward = jnp.tanh(inputs.reward)  # [B, 1]
        if not reward.shape:
            reward = jnp.expand_dims(reward, axis=0)

        action = hk.one_hot(inputs.action, self._num_actions)  # [B, A]

        expand = len(inputs.observation.shape) == 3
        if expand:
            inputs = inputs._replace(
                observation=jnp.expand_dims(inputs.observation, axis=0))
        embedding = self._torso(inputs.observation)
        if expand:
            embedding = jnp.squeeze(embedding, axis=0)
        embedding = jnp.concatenate([embedding, action, reward], axis=-1)

        embedding, new_state = self._core(embedding, state)
        logits, value = self._head(embedding)  # [B, A]

        return (logits, value), new_state
Example 17
 def _eval_batch(
     self,
     backbone_params: hk.Params,
     classif_params: hk.Params,
     backbone_state: hk.State,
     inputs: dataset.Batch,
 ) -> LogsDict:
     """Evaluates a batch."""
     embeddings, backbone_state = self.forward_backbone.apply(
         backbone_params, backbone_state, inputs, is_training=False)
     logits = self.forward_classif.apply(classif_params, embeddings)
     labels = hk.one_hot(inputs['labels'], self._num_classes)
     loss = helpers.softmax_cross_entropy(logits, labels, reduction=None)
     top1_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=1)
     top5_correct = helpers.topk_accuracy(logits, inputs['labels'], topk=5)
     # NOTE: Returned values will be summed and finally divided by num_samples.
     return {
         'eval_loss': loss,
         'top1_accuracy': top1_correct,
         'top5_accuracy': top5_correct
     }
Example 18
 def logp(self, params, obs, act):
     logits = self._net_apply(params, obs)
     all_logps = nn.log_softmax(logits)
     return (hk.one_hot(act, self.act_dim) * all_logps).sum(-1)
Example 19
def loss_fn_flat(param_vector, hidden, data, targets):
    params = params_unravel_pytree(param_vector)
    result, hidden, outputs = apply_jit(params, data, hidden, None, False)
    loss = cross_entropy(result.reshape(-1, ntokens),
                         hk.one_hot(targets, ntokens))
    return loss, hidden
Example 20
def loss_fn(params, data, targets, hidden):
    # result, hidden, outputs = apply_jit(params, data, hidden)
    result, hidden, outputs = apply_jit(params, data, hidden, None, False)
    loss = cross_entropy(result.reshape(-1, ntokens),
                         hk.one_hot(targets, ntokens))
    return loss, hidden
Example 21
def softmax_cross_entropy(logits, labels):
    """
  Cross-entropy loss applied to softmax.
  """
    one_hot = hk.one_hot(labels, logits.shape[-1])
    return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)
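A minimal usage sketch with made-up numbers: per-example losses for a batch of two examples over three classes; reducing to a scalar is left to the caller.

import jax.numpy as jnp

logits = jnp.array([[2.0, 0.0, 0.0],
                    [0.0, 1.0, 0.0]])
labels = jnp.array([0, 2])
per_example = softmax_cross_entropy(logits, labels)  # shape [2], one loss per example
mean_loss = jnp.mean(per_example)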
Example 22
 def loss(params, inputs, targets):
   log_probs = net.apply(params, inputs)
   return -jnp.mean(hk.one_hot(targets, NUM_CLASSES) * log_probs)