Example #1
    def testPswapaxes(self):
        device_count = xla_bridge.device_count()
        shape = (device_count, 3, device_count, 5)
        x = onp.arange(prod(shape)).reshape(shape)

        ans = pmap(lambda x: lax.pswapaxes(x, 'i', 1), axis_name='i')(x)
        expected = onp.swapaxes(x, 0, 2)
        self.assertAllClose(ans, expected, check_dtypes=False)
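
For reference, here is the same check as a self-contained sketch against current public JAX APIs (an assumption of this note: `jax.device_count` and `jax.pmap` stand in for the legacy `xla_bridge.device_count` and bare `pmap` used above):

import jax
import numpy as np
from jax import lax

n = jax.device_count()
shape = (n, 3, n, 5)
x = np.arange(np.prod(shape)).reshape(shape)

# Under pmap, each device holds one (3, n, 5) slice of x. pswapaxes
# swaps the mapped axis 'i' with positional axis 1 of the per-device
# value, which is equivalent to swapping axes 0 and 2 of the full array.
ans = jax.pmap(lambda v: lax.pswapaxes(v, 'i', 1), axis_name='i')(x)
np.testing.assert_allclose(ans, np.swapaxes(x, 0, 2))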
Example #2
  def testPswapaxes(self):
    device_count = xla_bridge.device_count()
    # TODO: AllToAll not yet implemented on XLA:CPU
    if jtu.device_under_test() == "cpu":
      device_count = 1
    shape = (device_count, 3, device_count, 5)
    x = onp.arange(prod(shape)).reshape(shape)

    ans = pmap(lambda x: lax.pswapaxes(x, 'i', 1), axis_name='i')(x)
    expected = onp.swapaxes(x, 0, 2)
    self.assertAllClose(ans, expected, check_dtypes=False)
Example #3
def moco_train_step(optimizer_query, state_query, model_key, batch,
                    moco_dictionary, n_devices, moco_temperature,
                    learning_rate_fn, l2_reg, moco_momentum):
    """MoCo training step part 2.

  Given the keys generated in part 1, part 2
  uses the query network to predict embeddings for the same samples as in
  part 1.
  The MoCo loss encourages the query network to predict an
  embedding that is more similar to the corresponding key network
  embedding than to any of the embeddings in the MoCo dictionary
  (the paper uses the term dictionary).

  Args:
    optimizer_query: query network optimizer/model
    state_query: query network state / batch stats
    model_key: key network
    batch: data batch
    moco_dictionary: dictionary of embeddings from key network
    n_devices: number of devices in use
    moco_temperature: softmax temperature for computing MoCo loss
    learning_rate_fn: function fn(step) -> lr that defines learning rate
      schedule
    l2_reg: L2 regularization coefficient
    moco_momentum: MoCo key network momentum parameter

  Returns:
    (new_optimizer_query, new_state_query, metrics, model_key, emb_key_all)
      new_optimizer_query: query network optimizer and model after step
      new_state_query: query network state / batch stats after step
      metrics: MoCo training metrics
      model_key: key network model (used to update query network)
      emb_key_all: key network embeddings concatenated across devices
  """
    def loss_fn(model_query):
        """loss function used for training."""

        emb_key = batch['emb_key']
        x_query = batch['query_image']

        # Get predicted embeddings from query network
        with flax.nn.stateful(state_query) as new_state_query:
            emb_query, _ = model_query(x_query, train=True)
        emb_query = normalize_embeddings(emb_query)
        # emb_query.shape = (n_samples, emb_size)

        # Compute per-sample MoCo loss
        moco_loss_per_sample = moco_loss(emb_query, emb_key, moco_dictionary,
                                         moco_temperature)
        loss = moco_loss_per_sample.mean()

        # Apply L2 regularization
        if l2_reg > 0:
            weight_penalty_params = jax.tree_leaves(model_query.params)
            weight_l2 = sum(
                [jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1])
            weight_penalty = l2_reg * 0.5 * weight_l2
            loss = loss + weight_penalty

        return loss, (new_state_query, moco_loss_per_sample, emb_key)

    step = optimizer_query.state.step
    lr = learning_rate_fn(step)
    new_optimizer_query, _, (new_state_query, moco_loss_per_sample, emb_key) = (
        optimizer_query.optimize(loss_fn, learning_rate=lr))

    # Update key network - exponential moving average of query network
    model_key_params = jax.tree_multimap(
        lambda p_k, p_q: p_k * moco_momentum + p_q * (1.0 - moco_momentum),
        model_key.params, new_optimizer_query.target.params)
    model_key = model_key.replace(params=model_key_params)

    # Compute metrics
    metrics = compute_train_moco_metrics(moco_loss_per_sample)
    metrics['learning_rate'] = lr

    # Concatenate the key-network embeddings across all devices in the
    # pmap (and thus across hosts in a multi-host setup): replicate the
    # local embeddings n_devices times along a new leading axis, then use
    # `lax.pswapaxes` to swap the mapped 'batch' axis with that axis, so
    # that every device ends up holding the keys from all devices.
    emb_rep = [n_devices] + [1] * emb_key.ndim
    emb_key = emb_key[None, Ellipsis]
    emb_key = jnp.tile(emb_key, emb_rep)
    emb_key_all = lax.pswapaxes(emb_key, 'batch', 0)

    # Return the concatenated key embeddings
    return new_optimizer_query, new_state_query, metrics, model_key, emb_key_all
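
The tile-then-pswapaxes step at the end is effectively an all-gather: because all n_devices copies along the new leading axis are identical, swapping the mapped 'batch' axis with that axis leaves every device holding the key embeddings from all devices, stacked along axis 0. A minimal sketch of the pattern in isolation (names here are illustrative, not from the repository; current JAX also provides `lax.all_gather`, which expresses this in one call):

import jax
import jax.numpy as jnp
from jax import lax

n_devices = jax.device_count()

def gather_keys(emb_key):
    # emb_key: this device's embeddings, shape (batch, emb_size).
    # Replicate once per device along a new leading axis ...
    tiled = jnp.tile(emb_key[None, Ellipsis],
                     [n_devices] + [1] * emb_key.ndim)
    # ... then swap the mapped axis with axis 0: the result on every
    # device is (n_devices, batch, emb_size), holding all devices' keys.
    return lax.pswapaxes(tiled, 'batch', 0)

x = jnp.arange(n_devices * 2 * 4.0).reshape(n_devices, 2, 4)
gathered = jax.pmap(gather_keys, axis_name='batch')(x)
# gathered[d] is identical for every device d; calling
# lax.all_gather(emb_key, 'batch') inside the pmapped function
# computes the same result directly.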