def moco_key_step(model_key, state_key, batch):
  """First half of a MoCo training step: embed the key batch.

  The MoCo step is split in two; this half runs the (momentum) key network
  on the shuffled key samples. Shuffling across devices prevents the model
  from exploiting per-device batch statistics ("batch-norm cheating",
  see https://arxiv.org/abs/1911.05722, sec 3.3).

  Args:
    model_key: the key (momentum) network.
    state_key: batch stats / mutable state of the key network.
    batch: sample batch; must contain the key inputs under 'x_key'.

  Returns:
    Tuple of (normalized key embeddings with gradients stopped,
    updated key-network state).
  """
  # Synchronize batch statistics across all devices/hosts before use.
  state_key = common_utils.pmean(state_key)

  inputs = batch['x_key']
  # Run the key network statefully so batch stats get updated.
  with flax.nn.stateful(state_key) as new_state_key:
    raw_emb, _ = model_key(inputs, train=True)

  # Keys receive no gradient in MoCo; normalize onto the unit sphere.
  # raw_emb has shape (n_samples, emb_size).
  emb_key = normalize_embeddings(jax.lax.stop_gradient(raw_emb))
  return emb_key, new_state_key
def compute_metrics(logits, labels):
  """Compute classification metrics, averaged across devices/hosts.

  Args:
    logits: predicted class logits, shape (batch, n_classes).
    labels: integer ground-truth labels, shape (batch,).

  Returns:
    Dict with cross-device means of 'loss' (cross-entropy) and
    'error_rate' (fraction of top-1 mispredictions).
  """
  predictions = jnp.argmax(logits, -1)
  return common_utils.pmean({
      'loss': cross_entropy_loss(logits, labels),
      'error_rate': jnp.mean(predictions != labels),
  })
def eval_step(model_moco, state_moco, feat_clf_model, batch):
  """Evaluate the linear classifier on top of frozen MoCo features.

  Args:
    model_moco: the MoCo encoder network.
    state_moco: batch stats / mutable state of the MoCo encoder.
    feat_clf_model: linear classifier applied to encoder features.
    batch: dict with 'image' inputs and integer 'label' targets.

  Returns:
    Metrics dict from `compute_metrics` for the linear classifier.
  """
  # Synchronize batch statistics across devices/hosts before use.
  state_moco = common_utils.pmean(state_moco)

  # Extract features with the frozen encoder (state not mutated).
  with flax.nn.stateful(state_moco, mutable=False):
    _, feats = model_moco(batch['image'], train=False)

  # Classify features and score against the labels.
  logits = feat_clf_model(feats)
  return compute_metrics(logits, batch['label'])
def classifier_train_step(clf_feat_optimizer, model_moco, state_moco, batch,
                          learning_rate_fn, l2_reg):
  """One optimization step of the linear classifier on frozen MoCo features.

  Args:
    clf_feat_optimizer: optimizer wrapping the linear classifier.
    model_moco: the MoCo encoder network (not trained here).
    state_moco: batch stats / mutable state of the MoCo encoder.
    batch: dict with 'image' inputs and integer 'label' targets.
    learning_rate_fn: maps optimizer step -> learning rate.
    l2_reg: L2 weight-decay coefficient; <= 0 disables the penalty.

  Returns:
    Tuple of (updated classifier optimizer, metrics dict including
    'learning_rate').
  """
  # Synchronize batch statistics across devices/hosts before use.
  state_moco = common_utils.pmean(state_moco)

  images = batch['image']
  # Compute encoder features without mutating encoder state; the
  # encoder itself is frozen, so block gradients through it.
  with flax.nn.stateful(state_moco, mutable=False):
    _, feats = model_moco(images, train=False)
  feats = jax.lax.stop_gradient(feats)

  def _clf_loss(model_clf):
    """Cross-entropy (+ optional L2 penalty) of the linear classifier."""
    logits = model_clf(feats)
    loss = cross_entropy_loss(logits, batch['label'])
    if l2_reg > 0:
      # Penalize weight matrices only (ndim > 1 skips biases/scales).
      params = jax.tree_leaves(model_clf.params)
      l2_term = sum(jnp.sum(p ** 2) for p in params if p.ndim > 1)
      loss = loss + l2_reg * 0.5 * l2_term
    return loss, (logits,)

  lr = learning_rate_fn(clf_feat_optimizer.state.step)
  new_clf_feat_optimizer, _, (logits,) = clf_feat_optimizer.optimize(
      _clf_loss, learning_rate=lr)

  metrics = compute_metrics(logits, batch['label'])
  metrics['learning_rate'] = lr
  return new_clf_feat_optimizer, metrics
def compute_train_moco_metrics(moco_loss_per_sample):
  """Reduce per-sample MoCo losses to a cross-device mean metric.

  Args:
    moco_loss_per_sample: array of per-sample contrastive losses.

  Returns:
    Dict with 'moco_loss' averaged over samples and over devices/hosts.
  """
  return common_utils.pmean({'moco_loss': moco_loss_per_sample.mean()})