Exemplo n.º 1
0
def moco_key_step(model_key, state_key, batch):
    """MoCo train step part 1; predict embeddings given key network.

    We separate our MoCo training step into two parts.
    This first part uses the key network to predict embeddings.
    The samples that are used have to be shuffled to prevent the network
    from cheating using differing batch stats between devices
    (see https://arxiv.org/abs/1911.05722, sec 3.3). This function does not
    shuffle anything itself; the caller is expected to do so.

    Args:
      model_key: key network
      state_key: batch stats and state for key network
      batch: batch of samples; the key-network inputs are read from
        `batch['x_key']`

    Returns:
      A tuple `(emb_key, new_state_key)`:
        emb_key: embeddings for the samples in `batch`, with gradients
          stopped and passed through `normalize_embeddings`.
        new_state_key: the key-network state collected by
          `flax.nn.stateful` while running the network with `train=True`.
    """
    # Average batch stats across devices/hosts
    state_key = common_utils.pmean(state_key)

    x_key = batch['x_key']
    # Run the key network in training mode; `new_state_key` captures the
    # state updated during this forward pass.
    with flax.nn.stateful(state_key) as new_state_key:
        # emb_key.shape = (n_samples, emb_size)
        emb_key, _ = model_key(x_key, train=True)
    # No gradient may flow back into the key network through its embeddings.
    emb_key = jax.lax.stop_gradient(emb_key)
    emb_key = normalize_embeddings(emb_key)
    return emb_key, new_state_key
Exemplo n.º 2
0
def compute_metrics(logits, labels):
    """Compute classification loss and error rate, averaged across devices.

    Args:
      logits: class scores; the predicted class is the argmax over the last
        axis.
      labels: class labels compared against the predicted class.

    Returns:
      Dict with keys 'loss' and 'error_rate', passed through
      `common_utils.pmean`.
    """
    mistakes = jnp.argmax(logits, -1) != labels
    return common_utils.pmean({
        'loss': cross_entropy_loss(logits, labels),
        'error_rate': jnp.mean(mistakes),
    })
Exemplo n.º 3
0
def eval_step(model_moco, state_moco, feat_clf_model, batch):
    """Linear classifier evaluation step.

    Extracts features from `batch['image']` with the MoCo network (inference
    mode, state not mutated) and scores the linear classifier's predictions
    against `batch['label']`.

    Returns:
      The metrics dict produced by `compute_metrics`.
    """
    averaged_state = common_utils.pmean(state_moco)  # sync batch stats first
    # mutable=False: evaluation must leave the MoCo state untouched.
    with flax.nn.stateful(averaged_state, mutable=False):
        _embeddings, features = model_moco(batch['image'], train=False)
    logits = feat_clf_model(features)
    return compute_metrics(logits, batch['label'])
Exemplo n.º 4
0
def classifier_train_step(clf_feat_optimizer, model_moco, state_moco,
                          batch, learning_rate_fn, l2_reg):
  """Linear classifier training step.

  Extracts features with the MoCo network (inference mode, state not mutated,
  gradients stopped) and takes one optimizer step on the linear classifier
  held by `clf_feat_optimizer`.

  Args:
    clf_feat_optimizer: optimizer whose target is the linear classifier.
    model_moco: MoCo network used as a frozen feature extractor.
    state_moco: batch stats and state for `model_moco`.
    batch: dict providing 'image' inputs and 'label' targets.
    learning_rate_fn: maps the optimizer step to a learning rate.
    l2_reg: L2 penalty coefficient; applied only when > 0, and only to
      classifier parameters with ndim > 1 (typically weight matrices rather
      than biases).

  Returns:
    `(new_clf_feat_optimizer, feat_metrics)`, where `feat_metrics` is the
    output of `compute_metrics` plus the 'learning_rate' used for this step.
  """
  # Average batch stats across devices/hosts
  state_moco = common_utils.pmean(state_moco)

  # Predict features (ignore embeddings); the MoCo network is not updated.
  with flax.nn.stateful(state_moco, mutable=False):
    _, features = model_moco(batch['image'], train=False)
  features = jax.lax.stop_gradient(features)

  def features_loss_fn(model_clf):
    """Cross-entropy loss (plus optional L2 penalty) for the classifier."""
    logits = model_clf(features)
    loss = cross_entropy_loss(logits, batch['label'])
    if l2_reg > 0:
      # Penalize only multi-dimensional parameters.
      weight_l2 = sum(jnp.sum(x ** 2)
                      for x in jax.tree_leaves(model_clf.params)
                      if x.ndim > 1)
      loss = loss + l2_reg * 0.5 * weight_l2
    return loss, (logits,)

  # Feature classifier
  feat_lr = learning_rate_fn(clf_feat_optimizer.state.step)

  grad_fn = jax.value_and_grad(features_loss_fn, has_aux=True)
  (_, (feat_logits,)), grad = grad_fn(clf_feat_optimizer.target)
  # This step runs data-parallel (state and metrics are pmean'd), so each
  # replica computes a gradient on its own shard. Average the gradients
  # before applying them; otherwise the replicated classifiers drift apart.
  grad = common_utils.pmean(grad)
  new_clf_feat_optimizer = clf_feat_optimizer.apply_gradient(
      grad, learning_rate=feat_lr)

  feat_metrics = compute_metrics(feat_logits, batch['label'])
  feat_metrics['learning_rate'] = feat_lr

  return new_clf_feat_optimizer, feat_metrics
Exemplo n.º 5
0
def compute_train_moco_metrics(moco_loss_per_sample):
    """Summarize the MoCo training loss.

    Args:
      moco_loss_per_sample: per-sample MoCo loss values (any array exposing
        `.mean()`).

    Returns:
      Dict with key 'moco_loss' holding the mean loss, passed through
      `common_utils.pmean`.
    """
    mean_loss = moco_loss_per_sample.mean()
    return common_utils.pmean({'moco_loss': mean_loss})