def testPswapaxes(self):
  device_count = xla_bridge.device_count()
  # TODO: AllToAll not yet implemented on XLA:CPU
  if jtu.device_under_test() == "cpu":
    device_count = 1
  shape = (device_count, 3, device_count, 5)
  x = onp.arange(prod(shape)).reshape(shape)

  # Inside the pmap, each device sees an array of shape (3, device_count, 5);
  # swapping the mapped axis with local axis 1 is equivalent to swapping
  # global axes 0 and 2.
  ans = pmap(lambda x: lax.pswapaxes(x, 'i', 1), axis_name='i')(x)
  expected = onp.swapaxes(x, 0, 2)
  self.assertAllClose(ans, expected, check_dtypes=False)
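# The test above exercises `lax.pswapaxes`, which swaps the pmapped axis with
# a local positional axis. One consequence: tiling a per-device value
# n_devices times along a new leading axis and then swapping that axis with
# the mapped axis yields an all-gather, which is the trick `moco_train_step`
# below relies on. A minimal sketch, assuming a pmapped context; the helper
# name `all_gather_via_pswapaxes` is ours, not a JAX API (recent JAX versions
# provide `lax.all_gather` directly):

import jax
import jax.numpy as jnp
from jax import lax

def all_gather_via_pswapaxes(x, axis_name, n_devices):
  # Replicate the per-device value along a new leading axis of size
  # n_devices, then swap that axis with the pmapped axis. Afterwards each
  # device holds the values from all devices stacked along axis 0.
  x = jnp.tile(x[None, ...], [n_devices] + [1] * x.ndim)
  return lax.pswapaxes(x, axis_name, 0)

# Hypothetical usage:
#   n = jax.device_count()
#   x = jnp.arange(n * 4.0).reshape((n, 4))
#   y = jax.pmap(lambda v: all_gather_via_pswapaxes(v, 'i', n),
#                axis_name='i')(x)
#   # y[d] == x for every device d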
def moco_train_step(optimizer_query, state_query, model_key, batch,
                    moco_dictionary, n_devices, moco_temperature,
                    learning_rate_fn, l2_reg, moco_momentum):
  """MoCo training step part 2.

  Given the keys generated in part 1, part 2 uses the query network to
  predict embeddings for the same samples as in part 1. The MoCo loss
  encourages the query network to predict an embedding that is more similar
  to the corresponding key network embedding than to any of the embeddings
  in the MoCo dictionary.

  Args:
    optimizer_query: query network optimizer/model
    state_query: query network state / batch stats
    model_key: key network
    batch: data batch
    moco_dictionary: dictionary of embeddings from key network
    n_devices: number of devices in use
    moco_temperature: softmax temperature for computing MoCo loss
    learning_rate_fn: function fn(step) -> lr that defines the learning rate
      schedule
    l2_reg: L2 regularization coefficient
    moco_momentum: MoCo key network momentum parameter

  Returns:
    (new_optimizer_query, new_state_query, metrics, model_key, emb_key_all)
    new_optimizer_query: query network optimizer and model after step
    new_state_query: query network state / batch stats after step
    metrics: MoCo training metrics
    model_key: key network model (used to update query network)
    emb_key_all: key network embeddings concatenated across devices
  """
  def loss_fn(model_query):
    """Loss function used for training."""
    emb_key = batch['emb_key']
    x_query = batch['query_image']
    # Get predicted embeddings from the query network
    with flax.nn.stateful(state_query) as new_state_query:
      emb_query, _ = model_query(x_query, train=True)
    emb_query = normalize_embeddings(emb_query)
    # emb_query.shape = (n_samples, emb_size)

    # Compute per-sample MoCo loss
    moco_loss_per_sample = moco_loss(emb_query, emb_key, moco_dictionary,
                                     moco_temperature)
    loss = moco_loss_per_sample.mean()

    # Apply L2 regularization to weight matrices (ndim > 1), skipping
    # biases and scale parameters
    if l2_reg > 0:
      weight_penalty_params = jax.tree_leaves(model_query.params)
      weight_l2 = sum(
          [jnp.sum(x ** 2) for x in weight_penalty_params if x.ndim > 1])
      weight_penalty = l2_reg * 0.5 * weight_l2
      loss = loss + weight_penalty
    return loss, (new_state_query, moco_loss_per_sample, emb_key)

  step = optimizer_query.state.step
  lr = learning_rate_fn(step)
  new_optimizer_query, _, (new_state_query, moco_loss_per_sample, emb_key) = \
      optimizer_query.optimize(loss_fn, learning_rate=lr)

  # Update key network: exponential moving average of the query network
  model_key_params = jax.tree_multimap(
      lambda p_k, p_q: p_k * moco_momentum + p_q * (1.0 - moco_momentum),
      model_key.params, new_optimizer_query.target.params)
  model_key = model_key.replace(params=model_key_params)

  # Compute metrics
  metrics = compute_train_moco_metrics(moco_loss_per_sample)
  metrics['learning_rate'] = lr

  # Use `lax.pswapaxes` to concatenate the embeddings generated by the key
  # network across all devices: tile the per-device embeddings n_devices
  # times along a new leading axis, then swap that axis with the pmapped
  # 'batch' axis so that every device receives every device's embeddings.
  emb_rep = [n_devices] + [1] * emb_key.ndim
  emb_key = emb_key[None, Ellipsis]
  emb_key = jnp.tile(emb_key, emb_rep)
  emb_key_all = lax.pswapaxes(emb_key, 'batch', 0)

  # Return the concatenated key embeddings
  return new_optimizer_query, new_state_query, metrics, model_key, emb_key_all
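# `moco_loss` is defined elsewhere in the repo and not shown here. For
# reference, a minimal sketch of the standard MoCo / InfoNCE objective the
# docstring describes; `moco_loss_sketch` is a hypothetical stand-in under
# our own assumptions, not the repo's actual implementation:

import jax
import jax.numpy as jnp

def moco_loss_sketch(emb_query, emb_key, dictionary, temperature):
  # emb_query, emb_key: (n_samples, emb_size), assumed L2-normalized.
  # dictionary: (dict_size, emb_size) queue of past key embeddings.
  l_pos = jnp.sum(emb_query * emb_key, axis=-1, keepdims=True)  # (n, 1)
  l_neg = emb_query @ dictionary.T                              # (n, dict_size)
  logits = jnp.concatenate([l_pos, l_neg], axis=-1) / temperature
  # The positive pair sits at index 0, so the per-sample loss is
  # cross-entropy against class 0.
  return -jax.nn.log_softmax(logits, axis=-1)[:, 0]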