Example #1
 def mapped_update(i, opt_state, batch, state, rng):
   """This is a multi-device version of the update function above."""
   # We assume all tensors have the first dimension = n_devices.
   weights, slots, opt_params = opt_state
   rng, subrng = jax_random.split(rng)
   grad_fn = backend.grad(model_and_loss_call, has_aux=True)
   grads, state = grad_fn(weights, batch, state, rng)
    # We use psum(1.0) here instead of `n_devices` because `n_devices` is only
    # the number of devices on this host machine, whereas psum reduces over all
    # devices of all hosts (e.g., an entire TPU pod) and we need to average
    # over all of them.
   grads = jax.tree_util.tree_map(
       lambda g: backend.psum(g, 'batch') / backend.psum(1.0, 'batch'), grads)
   return optimizer.tree_update(
       i, grads, weights, slots, opt_params), state, subrng
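
For context, this function only reduces gradients correctly when it runs under jax.pmap with the collective axis named 'batch'; outside of pmap the psum calls have nothing to reduce over. The following is a minimal, self-contained sketch of the same pattern, data-parallel gradient averaging normalized by psum(1.0, 'batch') rather than by a host-local device count, using a toy linear model and a plain SGD step in place of the library's model_and_loss_call and optimizer; every name in it is illustrative, not part of the library.

  import jax
  import jax.numpy as jnp

  def loss_fn(weights, batch):
    inputs, targets = batch
    preds = inputs @ weights
    return jnp.mean((preds - targets) ** 2)

  def update(weights, batch):
    grads = jax.grad(loss_fn)(weights, batch)
    # Average over every device participating in the 'batch' axis, which may
    # span multiple hosts; psum(1.0, 'batch') counts exactly those devices.
    grads = jax.tree_util.tree_map(
        lambda g: jax.lax.psum(g, 'batch') / jax.lax.psum(1.0, 'batch'), grads)
    return weights - 0.1 * grads  # plain SGD step for illustration

  n_devices = jax.local_device_count()
  p_update = jax.pmap(update, axis_name='batch')

  # Replicate the weights and shard the batch along the leading device axis.
  weights = jnp.broadcast_to(jnp.zeros((4, 1)), (n_devices, 4, 1))
  inputs = jnp.ones((n_devices, 8, 4))
  targets = jnp.ones((n_devices, 8, 1))
  weights = p_update(weights, (inputs, targets))
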
Example #2
  def mapped_update(i, opt_state, batch, state, rng):
   """This is a multi-device version of the update function above."""
   # We assume all tensors have the first dimension = n_devices.
   weights, slots, opt_params = opt_state
   rng, subrng = jax_random.split(rng)
   grad_fn = backend.grad(model_and_loss_call, has_aux=True)
   grads, state = grad_fn(weights, batch, state, rng)
    # Sum gradients over all devices participating in the 'batch' axis.
    grads = jax.tree_util.tree_map(
        lambda g: backend.psum(g, 'batch'), grads)
   return optimizer.tree_update(
       i, grads, weights, slots, opt_params), state, subrng
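
Note that this second variant only sums the gradients across devices; it is the division by psum(1.0, 'batch') in Example #1 that turns the sum into a mean over every participating device, including devices on other hosts of a TPU pod, which is why the first version divides by the global count rather than by the host-local n_devices.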