def _accelerate(f, n_devices):
  """Wraps ``f`` so that it runs on ``n_devices`` devices.

  Args:
    f: The function to compile.
    n_devices: Device count. ``0`` compiles for the first CPU device, ``1``
      compiles with plain ``jit``, and any other value maps ``f`` across
      devices with ``pmap`` over the ``'batch'`` axis.

  Returns:
    The compiled or device-mapped version of ``f``.
  """
  if n_devices not in (0, 1):
    return fastmath.pmap(f, axis_name='batch')
  if n_devices == 0:
    # No accelerators available: pin execution to the first CPU device.
    cpu_device = jax.devices('cpu')[0]
    return fastmath.jit(f, device=cpu_device)
  return fastmath.jit(f)
def __init__(self, model, task, eval_model=None, eval_task=None,
             output_dir=None, checkpoint_at=None, eval_at=None):
  """Configures a training `Loop`, including a random initialization.

  Args:
    model: Trax layer, representing the core model to be trained. Loss
        functions and eval functions (a.k.a. metrics) are considered to be
        outside the core model, taking core model output and data labels as
        their two inputs.
    task: TrainTask instance, which defines the training data, loss function,
        and optimizer to be used in this training loop.
    eval_model: Optional Trax layer, representing model used for evaluation,
        e.g., with dropout turned off. If None, the training model (model)
        will be used.
    eval_task: EvalTask instance or None. If None, don't do any evals.
    output_dir: Path telling where to save outputs (evals and checkpoints).
        Can be None if both `eval_task` and `checkpoint_at` are None.
    checkpoint_at: Function (integer --> boolean) telling, for step n, whether
        that step should have its checkpoint saved. If None, the default is
        periodic checkpointing at `task.n_steps_per_checkpoint`.
    eval_at: Function (integer --> boolean) that says, for training step n,
        whether that step should run evals. If None, uses the same default
        schedule as checkpointing
        (`_at_step_1_and_periodically_at(task.n_steps_per_checkpoint)`);
        note that this default does not follow a custom `checkpoint_at`.
        Ignored (evals never run) when `eval_task` is None.
  """
  self._task = task
  self._model = model
  self._model_in_training = tl.Serial(model, task.loss_layer)
  self._eval_model = model if eval_model is None else eval_model
  self._eval_task = eval_task
  self._output_dir = os.path.expanduser(output_dir) if output_dir else None

  default_fn = _at_step_1_and_periodically_at(task.n_steps_per_checkpoint)
  self._checkpoint_at = checkpoint_at or default_fn
  # Without an eval task there is nothing to evaluate, whatever eval_at says.
  self._eval_at = _never if eval_task is None else (eval_at or default_fn)
  self._step = 0

  batch_signature = shapes.signature(task.sample_batch)
  self._batch_signature = batch_signature
  # Initialize the model and the optimizer; discard the return values
  # (model weights/state, optimizer slots/params), since they're available
  # from the model and optimizer objects.
  _, _ = self._model_in_training.init(batch_signature)
  _, _ = task.optimizer.tree_init(self._model_in_training.weights)

  self._gradients_and_state_fn = (
      fastmath.jit(fastmath.grad(self._model_in_training.pure_fn,
                                 argnums=1,  # arg1 of pure_fn: weights
                                 has_aux=True)))  # return (gradients, state)

  if eval_task is not None:
    model_with_metrics = _model_with_metrics(self._eval_model, eval_task)
    self._eval_weights = model_with_metrics.weights[1]  # just the eval part
    self._eval_state = model_with_metrics.state[1]  # just the eval part
    self._metrics_fn = fastmath.jit(model_with_metrics.pure_fn)
def _jit_compute_loss_fn(predict_fn, loss_fn, n_devices, jit=True):
  """Builds a function that evaluates the training loss for one step.

  Args:
    predict_fn: Model passed through to `loss_fn`.
    loss_fn: Callable `(opt_state[0], batch, predict_fn, state, rng)` that
      returns `(loss, new_state)`.
    n_devices: Number of devices. Exactly 1 selects the single-device path;
      any other value uses a `pmap`-ed path that first splits `batch` across
      devices with `_reshape_by_device`.
    jit: Whether to JIT-compile the single-device function. Not consulted on
      the multi-device path, which is always `pmap`-ed.

  Returns:
    A function `(opt_state, batch, state, rng) -> (loss, state, next_rng)`.
    On the single-device path `rng` is indexed with `[0]` and `next_rng` is
    returned wrapped in a one-element list.
  """
  if n_devices == 1:
    # TODO(lukaszkaiser): remove branch when not needed.
    def loss_on_one_device(opt_state, batch, state, rng):
      step_rng, next_rng = jax_random.split(rng[0])
      loss, new_state = loss_fn(
          opt_state[0], batch, predict_fn, state, step_rng)
      return loss, new_state, [next_rng]

    if not jit:
      return loss_on_one_device
    return fastmath.jit(loss_on_one_device)

  # Multi-device path: all tensors are assumed to carry a leading
  # n_devices axis.
  def loss_per_device(opt_state, batch, state, rng):
    step_rng, next_rng = jax_random.split(rng)
    loss, new_state = loss_fn(opt_state[0], batch, predict_fn, state, step_rng)
    return loss, new_state, next_rng

  pmapped_loss = fastmath.pmap(loss_per_device, axis_name='batch')

  def compute_loss(opt_state, batch, state, rng):
    return pmapped_loss(
        opt_state, _reshape_by_device(batch, n_devices), state, rng)

  return compute_loss
def _pjit(self, f, donate_argnums=()):
  """Compiles `f` for this trainer's devices.

  Uses `jit` when exactly one device is configured and `pmap` over the
  'batch' axis otherwise; `donate_argnums` is forwarded to either transform.
  """
  single_device = self._n_devices == 1
  if not single_device:
    return fastmath.pmap(f, axis_name='batch', donate_argnums=donate_argnums)
  return fastmath.jit(f, donate_argnums=donate_argnums)
def _accelerate_update_fn(forward_and_backward_fn,
                          optimizer,
                          n_devices,
                          accelerate=True,
                          adasum=False):
  """Accelerates the given forward_and_backward_fn function.

  Combines `forward_and_backward_fn` with `optimizer.tree_update` into a
  single update-step function, compiled for the given number of devices.

  Args:
    forward_and_backward_fn: Callable `(batch, weights, state, rng)` that
      returns `((loss, state), gradients)`.
    optimizer: Object whose `tree_update(step, gradients, weights, slots,
      opt_params, store_slots=False)` returns `(weights, slots, stats)`.
    n_devices: Number of local devices. 1 selects the single-device path;
      any other value must be greater than 1 (checked with `assert`).
    accelerate: Whether to `jit` the single-device function. The
      multi-device function is always `pmap`-ed.
    adasum: Forwarded to `_average_multidevice_gradients` on the multi-device
      path (NOTE(review): presumably selects adaptive summation of gradients
      -- confirm against that helper).

  Returns:
    A function `(weights_and_slots, step, opt_params, batch, state, rng)`
    returning `((weights, slots), state, stats)`, with `stats['loss']` set to
    the loss of this step.
  """
  if n_devices == 1:
    def single_device_update_fn(
        weights_and_slots, step, opt_params, batch, state, rng):
      step = jnp.array(step, dtype=jnp.int32)  # Needed in TFNP backend.
      weights, slots = weights_and_slots
      (loss, state), gradients = forward_and_backward_fn(
          batch, weights, state, rng)
      # NOTE(review): semantics of store_slots=False live in the optimizer;
      # presumably it avoids storing slots on the optimizer object -- confirm.
      weights, slots, stats = optimizer.tree_update(
          step, gradients, weights, slots, opt_params, store_slots=False)
      stats['loss'] = loss
      return (weights, slots), state, stats
    if accelerate:
      # TODO(afrozm): Find out the status of buffer donation on GPUs, then do
      #  donate_argnums=(0,).
      single_device_update_fn = fastmath.jit(single_device_update_fn)
    return single_device_update_fn

  # More than one device (core), i.e. all of TPU configurations etc.
  # NOTE(review): `assert` is stripped under `python -O`, in which case a
  # non-positive n_devices would fall through to the pmap path below.
  assert n_devices > 1, f'{n_devices} should be greater than 1.'

  # Argument 0 (weights_and_slots) is donated, so its device buffers may be
  # reused for the outputs.
  @functools.partial(fastmath.pmap, axis_name='batch', donate_argnums=(0,))
  def _multi_device_update_fn(
      weights_and_slots, step, opt_params, batch, state, rng):
    # All tensors should have the first dimension = n_devices.
    weights, slots = weights_and_slots
    (loss, state), gradients = (
        forward_and_backward_fn(batch, weights, state, rng))
    # Combine per-device gradients before the optimizer step.
    gradients = _average_multidevice_gradients(gradients, adasum=adasum)
    weights, slots, stats = optimizer.tree_update(
        step, gradients, weights, slots, opt_params, store_slots=False)
    stats['loss'] = loss
    return (weights, slots), state, stats

  def multi_device_update_fn(
      weights_and_slots, step, opt_params, batch, state, rng):
    # Need to replicate step to n_devices leading dimension.
    return _multi_device_update_fn(weights_and_slots,
                                   jnp.repeat(step, n_devices), opt_params,
                                   batch, state, rng)

  return multi_device_update_fn
def _accelerate_update_fn(forward_and_backward_fn,
                          optimizer,
                          n_devices,
                          accelerate=True):
  """Accelerate the given forward_and_backward_fn function.

  Combines `forward_and_backward_fn` with `optimizer.tree_update` into a
  single update-step function, compiled for the given number of devices.

  Args:
    forward_and_backward_fn: Callable `(batch, weights, state, rng)` that
      returns `((loss, state), gradients)`.
    optimizer: Object whose `tree_update(step, gradients, weights, slots,
      opt_params)` returns `(weights, slots, stats)`.
    n_devices: Number of devices on this host. 1 selects the single-device
      path; any other value must be greater than 1 (checked with `assert`).
    accelerate: Whether to `jit` the single-device function. The
      multi-device function is always `pmap`-ed.

  Returns:
    A function `(weights_and_slots, step, opt_params, batch, state, rng)`
    returning `((weights, slots), state, stats)`, with `stats['loss']` set to
    the loss of this step.
  """
  if n_devices == 1:
    def single_device_update_fn(
        weights_and_slots, step, opt_params, batch, state, rng):
      step = jnp.array(step, dtype=jnp.int32)  # Needed in TFNP backend.
      weights, slots = weights_and_slots
      (loss, state), gradients = forward_and_backward_fn(
          batch, weights, state, rng)
      weights, slots, stats = optimizer.tree_update(
          step, gradients, weights, slots, opt_params)
      stats['loss'] = loss
      return (weights, slots), state, stats
    if accelerate:
      # TODO(afrozm): Find out the status of buffer donation on GPUs, then do
      #  donate_argnums=(0,).
      single_device_update_fn = fastmath.jit(single_device_update_fn)
    return single_device_update_fn

  # More than one device (core), i.e. all of TPU configurations etc.
  # NOTE(review): `assert` is stripped under `python -O`, in which case a
  # non-positive n_devices would fall through to the pmap path below.
  assert n_devices > 1, f'{n_devices} should be greater than 1.'

  # Argument 0 (weights_and_slots) is donated, so its device buffers may be
  # reused for the outputs.
  @functools.partial(fastmath.pmap, axis_name='batch', donate_argnums=(0,))
  def _multi_device_update_fn(
      weights_and_slots, step, opt_params, batch, state, rng):
    # We assume all tensors have the first dimension = n_devices.
    weights, slots = weights_and_slots
    (loss, state), gradients = forward_and_backward_fn(
        batch, weights, state, rng)

    # gradients now need to be summed over all the devices across different
    # host machines, n_devices is only the number of devices on *this* host
    # machine.
    gradients = fastmath.psum(gradients, 'batch')
    # psum of 1.0 over the 'batch' axis yields the total device count.
    n_devices_total = fastmath.psum(jnp.array(1.0), 'batch')
    # Average across hosts.
    gradients = jax.tree_util.tree_map(lambda g: g / n_devices_total, gradients)

    weights, slots, stats = optimizer.tree_update(
        step, gradients, weights, slots, opt_params)
    stats['loss'] = loss
    return (weights, slots), state, stats

  def multi_device_update_fn(
      weights_and_slots, step, opt_params, batch, state, rng):
    # Need to replicate step to n_devices leading dimension.
    return _multi_device_update_fn(weights_and_slots,
                                   jnp.repeat(step, n_devices), opt_params,
                                   batch, state, rng)

  return multi_device_update_fn
def _jit_update_fn(predict_fn, loss_fn, optimizer, n_devices, jit=True):
  """Returns a (JIT-compiled) function that computes updates for one step.

  Args:
    predict_fn: Model layer; combined with `loss_fn` via `tl.Serial`.
    loss_fn: Loss layer applied to the model output.
    optimizer: Object whose `tree_update(i, grads, weights, slots,
      opt_params)` returns `(weights, slots, stats)`.
    n_devices: Number of devices on this host. 1 selects the single-device
      path; any other value uses a `pmap`-ed path.
    jit: Whether to JIT-compile the single-device function. Not consulted on
      the multi-device path.

  Returns:
    A function `(weights_and_slots, i, opt_params, batch, state, rng)` that
    returns `((new_weights, new_slots), stats, state, new_rng)`. On the
    single-device path `rng` is indexed with `[0]` and `new_rng` is returned
    wrapped in a one-element list.
  """
  model_and_loss = tl.Serial(predict_fn, loss_fn)
  # Gradients are always wrt. the first argument, so putting weights first.
  def model_and_loss_call(weights, batch, state, rng):
    res = model_and_loss(batch, weights=weights, state=state, rng=rng)
    # Returned as the auxiliary output of `grad(..., has_aux=True)`.
    return res, model_and_loss.state

  if n_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.
    def single_update(weights_and_slots, i, opt_params, batch, state, rng):
      weights, slots = weights_and_slots
      rng, subrng = jax_random.split(rng[0])
      grad_fn = fastmath.grad(model_and_loss_call, has_aux=True)
      grads, state = grad_fn(weights, batch, state, rng)
      new_weights, new_slots, stats = optimizer.tree_update(
          i, grads, weights, slots, opt_params)
      return (new_weights, new_slots), stats, state, [subrng]
    if jit:
      # TODO(lukaszkaiser): donate_argnums=(0,) when XLA supports it on GPU
      return fastmath.jit(single_update)
    else:
      return single_update

  # Else, for n_devices > 1:
  # Buffer donation (donate_argnums=(0,)) is intentionally not enabled here.
  @functools.partial(fastmath.pmap, axis_name='batch')
  def mapped_update(weights_and_slots, i, opt_params, batch, state, rng):
    """This is a multi-device version of the update function above."""
    # We assume all tensors have the first dimension = n_devices.
    weights, slots = weights_and_slots
    rng, subrng = jax_random.split(rng)
    grad_fn = fastmath.grad(model_and_loss_call, has_aux=True)
    grads, state = grad_fn(weights, batch, state, rng)
    # We do a psum(1.0) here instead of `n_devices` since `n_devices` is just
    # the number of devices on this host machine, however psum goes over all
    # devices of all hosts (ex: a TPU pod) and we need to be averaging over all
    # of them.
    #
    # Collect all gradients.
    grads = fastmath.psum(grads, 'batch')
    n_devices_total = fastmath.psum(np.array(1.0), 'batch')
    # Average across hosts.
    grads = jax.tree_util.tree_map(lambda g: g / n_devices_total, grads)

    new_weights, new_slots, stats = optimizer.tree_update(
        i, grads, weights, slots, opt_params)
    return (new_weights, new_slots), stats, state, subrng

  def update(weights_and_slots, i, opt_params, batch, state, rng):
    # Replicate the step index along the leading device axis for pmap.
    return mapped_update(weights_and_slots, np.repeat(i, n_devices),
                         opt_params, batch, state, rng)

  return update
def _pjit(self, f, memory_key=None, donate_argnums=()):
  """Compiles `f` for the available devices, optionally memoizing the result.

  Args:
    f: Function to compile.
    memory_key: Optional cache key. When given and `self._jit_memory` is not
      None, a function previously stored under this key is returned as-is;
      otherwise the newly compiled function is stored under it.
    donate_argnums: Argument indices whose buffers may be donated; forwarded
      to `jit`/`pmap`.

  Returns:
    `jit(f)` when exactly one device is configured, otherwise `pmap(f)` over
    the 'batch' axis (or the memoized function, if one was found).
  """
  cache = self._jit_memory
  use_cache = cache is not None and memory_key is not None
  if use_cache and memory_key in cache:
    logging.info('Found JITed function in memory for: %s', memory_key)
    return cache[memory_key]

  if self._n_devices == 1:
    compiled = fastmath.jit(f, donate_argnums=donate_argnums)
  else:
    compiled = fastmath.pmap(
        f, axis_name='batch', donate_argnums=donate_argnums)

  if use_cache:
    cache[memory_key] = compiled
  return compiled
def _pjit(self, f):
  """Returns `f` compiled for the trainer's devices.

  With exactly one device this is `jit(f)`; otherwise `f` is `pmap`-ed over
  a 'batch' axis.
  """
  if not self._n_devices == 1:
    return fastmath.pmap(f, axis_name='batch')
  return fastmath.jit(f)
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             checkpoints_at=None, should_save_checkpoints=True,
             should_write_summaries=True,
             metrics=None, checkpoint_highest=None, checkpoint_lowest=None):
  """Sets up models, optimizer, metrics and compiled step functions.

  Args:
    model: Callable taking `mode=`; called with 'train' and 'eval' to build
      the training and evaluation models.
    loss_fn: Loss layer appended to the training model.
    optimizer: Optimizer factory, called here with `learning_rate=0.0` (the
      real learning rate is set in `reset()`).
    lr_schedule: Learning-rate schedule, stored on the instance.
    inputs: An Inputs instance, or a callable returning one.
    output_dir: Passed to `reset()`. If falsy, checkpoint saving and summary
      writing are both disabled.
    random_seed: Forwarded to `training.init_host_and_devices`.
    n_devices: Forwarded to `training.init_host_and_devices`.
    checkpoints_at: Steps at which to checkpoint; defaults to `[]`.
    should_save_checkpoints: Whether to save checkpoints; only honored on the
      chief host.
    should_write_summaries: Whether to write summaries.
    metrics: Dict from metric name to metric layer; defaults to
      `_DEFAULT_METRICS`.
    checkpoint_highest: Stored on the instance.
    checkpoint_lowest: Stored on the instance.
  """
  self._is_chief, _, self._n_devices, rng = (
      training.init_host_and_devices(n_devices, random_seed))
  self._should_save_checkpoints = should_save_checkpoints and self._is_chief
  self._checkpoints_at = checkpoints_at if checkpoints_at is not None else []
  self._should_write_summaries = should_write_summaries
  if not output_dir:
    self._should_save_checkpoints = False
    self._should_write_summaries = False
  self._checkpoint_highest = checkpoint_highest
  self._checkpoint_lowest = checkpoint_lowest
  self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
  # Inputs is either an Inputs instance or a function that returns it.
  self._inputs = inputs
  if callable(inputs):  # If we pass a function, e.g., through gin, call it.
    self._inputs = inputs()
  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode='train')
  model_predict_eval = model(mode='eval')
  self._model_with_loss = tl.Serial(model_train, loss_fn)

  # Setup state.
  rng, init_rng = jax_random.split(rng)
  self._rngs = np.stack(jax_random.split(rng, self._n_devices))
  # Build the input signature once; it is shared by training-state init and
  # eval-model init. (Locals named to avoid shadowing a `shapes` module.)
  input_shapes, input_dtypes = self._inputs.example_shape_dtype
  input_signature = tuple(
      ShapeDtype(s, d) for (s, d) in zip(input_shapes, input_dtypes))

  def new_opt_state_and_model_state(rng):
    """Returns optimizer and model states suitable for training a model."""
    weights, state = self._model_with_loss.init(input_signature, rng=rng)
    (slots, opt_params) = opt.tree_init(weights)
    return (OptState(weights, slots, opt_params), state)

  if fastmath.is_backend(fastmath.Backend.JAX):
    # JIT parameter initialization to avoid memory fragmentation
    new_opt_state_and_model_state = (
        fastmath.jit(new_opt_state_and_model_state))
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(init_rng))

  # Arrange and initialize metrics layers.
  self._metrics = sorted(self._metrics_dict)
  metrics_layers = [self._metrics_dict[m] for m in self._metrics]
  metrics_in_parallel = tl.Branch(*metrics_layers)
  metrics_in_parallel.rng = init_rng
  model_predict_eval.init(input_signature)
  self._input_signature = input_signature
  output_signature = model_predict_eval.output_signature(input_signature)
  m_weights, m_state = metrics_in_parallel.init(output_signature)
  self._metrics_weights = self._for_n_devices(m_weights)
  self._metrics_state = self._for_n_devices(m_state)

  # Jit model_predict and update so they're fast.
  self._jit_eval = _jit_predict_fn(
      model_predict_eval, metrics_in_parallel, self._n_devices)
  self._jit_update_fn = _jit_update_fn(
      model_train, loss_fn, opt, self._n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  self._lr_schedule = lr_schedule

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._opt_state = None
  self._step = None
  self._model_state = None
  self.reset(output_dir)
def __init__(self, blocks, loss_layer, optimizer_fn, n_devices=None,
             memoize_jit=True, free_accelerators_on_step=False, adasum=False):
  """Creates a ReversibleSerialTrainer and the needed optimizers.

  This trainer performs updates equivalent to using the default Trainer on::

    tl.Serial(blocks + [loss_layer]).

  It is more memory-efficient though since weights are stored on CPU and only
  sent to accelerator layer-by-layer. Blocks are pairs consisting of a list
  of standard (arbitrary) layers and a list of reversible layers which help
  save memory thanks to being reversible.

  Args:
    blocks: A list of pairs of lists of standard and reversible layers.
    loss_layer: The final layer of the model; it can have trainable weights
      but should end with a loss: it is required to produce a scalar output.
    optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`.
    n_devices: An optional integer, number of accelerator devices to use;
      by default (None or 0), all available local accelerators will be used.
    memoize_jit: Whether to memoize JITed functions; this significantly speeds
      up XLA compilation of larger models, but it uses `repr(layer)` as keys
      to memoize so it could fail if two layers with different functionality
      had the same string representation. We have not encountered such a
      case yet so this is turned on by default, but consider turning it off
      or reviewing your model if you use custom layers and encounter a
      problem.
    free_accelerators_on_step: If true, frees memory on accelerators when
      starting a step. All layers and arguments must be on host for that,
      otherwise it can lead to failures. Can prevent memory fragmentation.
    adasum: if True, use adaptive summation to gather multi-device gradients.
  """
  self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks]
  self._loss_layer = loss_layer
  self._optimizer_fn = optimizer_fn
  self._n_devices = n_devices or fastmath.local_device_count()
  self._adasum = adasum
  # One rng slot per standard layer and per reversible layer, plus one for
  # the loss layer.
  self._n_layers = 1 + sum(len(revs) + 1 for (_, revs) in self._blocks)
  self._n_steps_per_log = 100  # Log layers and stats every 100 steps.
  self._n_async_layers = 1  # How many layers to run asynchronously.
  self._jit_memory = {} if memoize_jit else None
  self._do_free = free_accelerators_on_step
  self._jit_per_device_rngs = fastmath.jit(
      self._per_device_rngs, backend='cpu')

  # Create accelerated versions of layers as pmapped/jitted pure_fn.
  self._accelerated_layer_fns = fastmath.nested_map(
      lambda layer: self._pjit(layer.pure_fn, f'fwd {repr(layer)}'),
      self._blocks)

  # Create per-layer optimizers and replicate opt_params.
  def _make_optimizer(layer):
    opt = optimizer_fn()
    opt.tree_init(layer.weights)
    opt.slots = tl.on_cpu(opt.slots)
    return opt

  self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks)
  self._replicated_opt_params = fastmath.nested_map(
      lambda opt: self._replicate_cpu(opt.opt_params), self._optimizers)

  self._loss_opt = _make_optimizer(loss_layer)
  self._replicated_loss_opt_params = self._replicate_cpu(
      self._loss_opt.opt_params)

  # Forward + backward + optimizer-update functions for all layers.
  # We call them in short FBO for "Forward + Backward + Optimizer update".
  # Reversible layers define a reverse_and_fbo function that also reverses.
  self._fbos = []
  for i, (std_layer, rev_layers) in enumerate(self._blocks):
    (std_opt, rev_opts) = self._optimizers[i]
    std_fbo = _fbo_with_layer_and_opt(
        std_layer, std_opt, self._n_devices, adasum=self._adasum)
    rev_and_fbos = []
    for layer, opt in zip(rev_layers, rev_opts):
      rev_and_fbo = _reverse_and_fbo_with_layer_and_opt(
          layer, opt, self._n_devices, self._adasum)
      # The donated args are (outputs, weights, grads) and we can donate
      # them because weights and grads are immediately replaced and in
      # case of reversible layers, the outputs are never used again.
      rev_and_fbos.append(self._pjit(
          rev_and_fbo, f'rev+bwd {repr(layer)}', donate_argnums=(0, 1, 2)))
    # In standard layers, the inputs cannot be donated as they may be used
    # as outputs for the reversible block below, but weights and grads can.
    jit_std_fbo = self._pjit(
        std_fbo, f'bwd {repr(std_layer)}', donate_argnums=(1, 2))
    self._fbos.append((jit_std_fbo, rev_and_fbos))

  loss_fbo = _fbo_with_layer_and_opt(
      self._loss_layer, self._loss_opt, self._n_devices, 'loss', self._adasum)
  self._loss_fbo = self._pjit(loss_fbo, donate_argnums=(1, 2))
def __init__(self, model, task, eval_model=None, eval_task=None,
             output_dir=None, checkpoint_at=None, eval_at=None):
  """Configures a training `Loop`, including a random initialization.

  Args:
    model: Trax layer, representing the core model to be trained. Loss
        functions and eval functions (a.k.a. metrics) are considered to be
        outside the core model, taking core model output and data labels as
        their two inputs.
    task: TrainTask instance, which defines the training data, loss function,
        and optimizer to be used in this training loop.
    eval_model: Optional Trax layer, representing model used for evaluation,
        e.g., with dropout turned off. If None, the training model (model)
        will be used.
    eval_task: EvalTask instance or None. If None, don't do any evals.
    output_dir: Path telling where to save outputs (evals and checkpoints).
        Can be None if both `eval_task` and `checkpoint_at` are None. If
        given, the directory is created.
    checkpoint_at: Function (integer --> boolean) telling, for step n, whether
        that step should have its checkpoint saved. If None, the default is
        periodic checkpointing at `task.n_steps_per_checkpoint`.
    eval_at: Function (integer --> boolean) that says, for training step n,
        whether that step should run evals. If None, uses the same default
        schedule as checkpointing
        (`_at_step_1_and_every_nth_step(task.n_steps_per_checkpoint)`);
        note that this default does not follow a custom `checkpoint_at`.
        Ignored (evals never run) when `eval_task` is None.
  """
  self._task = task
  self._model = model
  # Explicit None check: `eval_model or model` would also fall back to
  # `model` for any falsy eval_model object, not only a missing one.
  self._eval_model = model if eval_model is None else eval_model
  # Always define the attribute (None when there is no eval task) so later
  # code can test it instead of hitting an AttributeError.
  self._eval_task = eval_task
  default_at = (
      _at_step_1_and_every_nth_step(self._task.n_steps_per_checkpoint))
  if output_dir is not None:
    self._output_dir = os.path.expanduser(output_dir)
    tf.io.gfile.makedirs(self._output_dir)
  else:
    self._output_dir = None

  # Prepare training components.
  self._step = 0
  self._checkpoint_at = checkpoint_at or default_at
  self._model_in_training = tl.Serial(self._model, self._task.loss_layer)
  self._batch_signature = shapes.signature(self._task.sample_batch)
  self._eval_model.init(self._batch_signature)
  self._model_in_training.init(self._batch_signature)
  self._task.optimizer.tree_init(self._model_in_training.weights)
  self._forward_and_backward_fn = (
      fastmath.jit(fastmath.value_and_grad(
          self._model_in_training.pure_fn,
          argnums=1,  # arg1 of pure_fn: weights
          has_aux=True)))  # return (loss, state), gradients

  # Prepare eval components.
  if eval_task is None:
    self._eval_at = _never
  else:
    self._eval_at = eval_at or default_at
    metric_name_lengths = [len(name) for name in eval_task.metric_names]
    # Width used to right-justify metric names (and the loss name) in logs.
    self._rjust_len = max(
        [len(self._task.loss_layer.name)] + metric_name_lengths)
    model_with_metrics = _model_with_metrics(self._eval_model, eval_task)
    self._eval_weights = model_with_metrics.weights[1]  # just the eval part
    self._eval_state = model_with_metrics.state[1]  # just the eval part
    self._metrics_fn = fastmath.jit(model_with_metrics.pure_fn)
    if self._output_dir is None:
      _log('Will not write evaluation metrics, because output_dir is None.')
def one_step(self, batch, rng, step=0, learning_rate=None):
  """Updates layers weights/state and optimizers slots by running one step.

  Args:
    batch: Batch of data to use for optimization.
    rng: Random number generator to use for running this step.
    step: Which step of the training are we running.
    learning_rate: Learning rate to use instead of the default one.

  Returns:
    Tuple (loss, stats) with new values from one step of training,
    where stats are all optimizer statistics.
  """
  # Update the learning rate if needed.
  if learning_rate is not None:
    self._replicated_loss_opt_params['learning_rate'] = tl.for_n_devices(
        learning_rate, self._n_devices)
    for (std_op, rev_ops) in self._replicated_opt_params:
      std_op['learning_rate'] = tl.for_n_devices(
          learning_rate, self._n_devices)
      for op in rev_ops:
        op['learning_rate'] = tl.for_n_devices(
            learning_rate, self._n_devices)

  # Batch needs to be split across the local devices -- the difference
  # between _for_n_devices and _reshape_by_device is that the latter splits
  # the batch dim to batch // n_devices, vs _for_n_devices
  # broadcasts/replicates to n_devices dimension.
  step_int = step
  if self._n_devices > 1:
    batch = tl.reshape_by_device(batch, self._n_devices)
    step = np.repeat(step, self._n_devices)

  # Create separate rng for each device and layer.
  if self._n_devices == 1:
    rngs = fastmath.random.split(rng, self._n_layers)
  else:
    # Splitting by device first to be identical with default trainer.
    # `jit` caches compiled code per function object, so wrapping a fresh
    # closure on every call re-traces and re-compiles each step. Keep the
    # jitted splitter on the instance, keyed on the sizes baked into it.
    split_key = (self._n_devices, self._n_layers)
    cached_split = getattr(self, '_one_step_rng_split_cache', None)
    if cached_split is None or cached_split[0] != split_key:
      n_devices, n_layers = split_key

      def per_device_rngs(rng):  # A function to JIT to not fragment memory.
        per_device_rng = fastmath.random.split(rng, n_devices)
        per_device_layer_rngs = [
            fastmath.random.split(r, n_layers) for r in per_device_rng]
        return [jnp.stack([r[i] for r in per_device_layer_rngs])
                for i in range(n_layers)]

      # JIT the function and run it on CPU to avoid memory fragmentation.
      cached_split = (split_key,
                      fastmath.jit(per_device_rngs, backend='cpu'))
      self._one_step_rng_split_cache = cached_split
    rngs = cached_split[1](tl.on_cpu(rng))

  # Group rngs by layer blocks: one for the standard layer, then one per
  # reversible layer. The last rng (rngs[-1]) is left for the loss layer.
  rng_blocks, rng_i = [], 0
  for _, rev_layers in self._blocks:
    n_rev = len(rev_layers)
    rng_blocks.append((rngs[rng_i], rngs[rng_i + 1: rng_i + n_rev + 1]))
    rng_i += n_rev + 1

  # Run the layers forward up to the loss layer.
  process = psutil.Process(os.getpid())
  logging.info('running step %d', step_int)
  should_log_memory = step_int % self._n_steps_per_log == 1
  if should_log_memory:
    logging.info('run fwd: cpu memory use (MB): %.2f',
                 process.memory_info().rss / float(1024 * 1024))
  stack = batch
  block_inputs_states = []
  for i, (std_layer, rev_layers) in enumerate(self._blocks):
    acc_std_layer_fn, acc_rev_layer_fns = self._accelerated_layer_fns[i]
    std_rng, rev_rngs = rng_blocks[i]
    # Run the standard layer.
    stack, std_inputs, std_state = self._run_forward_standard(
        stack, std_layer, acc_std_layer_fn, std_rng, step_int)

    # Run the reversible layers and collect old and new states.
    stack, rev_old_states, rev_new_states = self._run_forward_reversible(
        stack, rev_layers, acc_rev_layer_fns, rev_rngs, step_int)
    block_inputs_states.append(
        ((std_inputs, std_state), (rev_old_states, rev_new_states)))

  # Run the loss layer forward and backward with optimizer update.
  if should_log_memory:
    logging.info('run loss: cpu memory use (MB): %.2f',
                 process.memory_info().rss / float(1024 * 1024))
  loss_state = self._replicate(self._loss_layer.state)
  loss_inputs = cb.inputs_from_stack(stack, self._loss_layer.n_in)
  loss_stats, grad_stack = self._run_backward_standard(
      None, step, self._loss_layer, loss_inputs,
      loss_state, self._loss_fbo, rngs[-1], self._loss_opt,
      self._replicated_loss_opt_params)
  stats = [loss_stats]

  # Run the layers backward and run optimizer updates.
  if should_log_memory:
    logging.info('run bwd: cpu memory use (MB): %.2f',
                 process.memory_info().rss / float(1024 * 1024))
  for i in range(len(self._blocks) - 1, -1, -1):
    std_layer, rev_layers = self._blocks[i]
    (std_inputs, std_state), (rev_old_states,
                              rev_new_states) = block_inputs_states[i]
    std_fbo, rev_fbos = self._fbos[i]
    std_opt, rev_opts = self._optimizers[i]
    std_rng, rev_rngs = rng_blocks[i]
    repl_std_opt_params, repl_rev_opts_params = self._replicated_opt_params[i]

    # Run reversible layers backward with optimizer update.
    stack, grad_stack, new_stats = self._run_backward_reversible(
        stack, grad_stack, step, rev_layers, rev_fbos, rev_old_states,
        rev_new_states, rev_rngs, rev_opts, repl_rev_opts_params)
    stats.extend(new_stats)

    # Run the standard layer forward-and-backward pass and optimizer update.
    std_layer_stats, grad_stack = self._run_backward_standard(
        grad_stack, step, std_layer, std_inputs, std_state, std_fbo, std_rng,
        std_opt, repl_std_opt_params)
    stack = cb.outputs_onto_stack(  # Put layer inputs on the stack.
        std_inputs, stack, std_layer.n_out)
    stats.append(std_layer_stats)

  # Join stats from different optimizers into one.
  joint_stats = {}
  for i, stat in enumerate(reversed(stats)):
    for k, v in stat.items():
      joint_stats[f'layer{i}/' + k] = v
  return stats[0]['loss'], joint_stats
def _accelerate(f, n_devices):
  """Returns `f` compiled for `n_devices` devices.

  A single device gets plain `jit`; any other device count maps `f` over the
  'batch' axis with `pmap`.
  """
  use_plain_jit = n_devices == 1
  if use_plain_jit:
    return fastmath.jit(f)
  return fastmath.pmap(f, axis_name='batch')