def reverse_and_grad(self, output, grad, weights=(), state=(), new_state=(),
                     rng=None):
  """Reconstructs the inputs of this combinator and backpropagates gradients.

  Sublayers are visited from last to first. Each one recovers its own inputs
  from its outputs through its `reverse_and_grad` method, which also returns
  the cotangents of those inputs and of its weights.

  Args:
    output: Output stack of this layer.
    grad: Cotangent stack matching `output`.
    weights: Per-sublayer weights.
    state: Per-sublayer state before the forward pass.
    new_state: Per-sublayer state after the forward pass.
    rng: Optional random key; when given it is split into one key per layer.

  Returns:
    Tuple `(inputs, (input_grads, weight_grads))`, where `weight_grads` is a
    tuple ordered like `self.sublayers`.
  """
  if rng is None:
    layer_rngs = (None,) * self._n_layers
  else:
    layer_rngs = fastmath.random.split(rng, self._n_layers)
  per_layer = list(zip(self.sublayers, weights, state, new_state, layer_rngs))
  val_stack, ct_stack = output, grad
  grads_last_first = []  # Weight cotangents, collected last sublayer first.
  for sublayer, w, st, new_st, sub_rng in reversed(per_layer):
    sub_out = cb.inputs_from_stack(val_stack, sublayer.n_out)
    sub_out_ct = cb.inputs_from_stack(ct_stack, sublayer.n_out)
    sub_in, (sub_in_ct, w_ct) = sublayer.reverse_and_grad(
        sub_out, sub_out_ct, w, st, new_st, rng=sub_rng)
    grads_last_first.append(w_ct)
    val_stack = cb.outputs_onto_stack(sub_in, val_stack, sublayer.n_out)
    ct_stack = cb.outputs_onto_stack(sub_in_ct, ct_stack, sublayer.n_out)
  return val_stack, (ct_stack, tuple(reversed(grads_last_first)))
def _run_backward_one_reversible(self, layer, stack, grad_stack, step, rng,
                                 optimizer, opt_params, reverse_and_fbo,
                                 old_state, new_state):
  """Runs one reversible layer backwards and applies its optimizer update.

  Args:
    layer: The reversible layer to run in reverse.
    stack: Data stack; its top `layer.n_out` items are this layer's outputs.
    grad_stack: Cotangent stack aligned with `stack`.
    step: Training step passed through to `reverse_and_fbo`.
    rng: Random key for this layer.
    optimizer: Optimizer whose `slots` are read and replaced.
    opt_params: Optimizer hyper-parameters.
    reverse_and_fbo: Compiled function returning
      `(new_weights, new_slots, inputs, input_grads, stats)`.
    old_state: Layer state before the forward pass.
    new_state: Layer state after the forward pass.

  Returns:
    Tuple `(stack, grad_stack, layer_stats)` with the layer's reconstructed
    inputs and their cotangents placed back on the stacks.
  """
  # Running in reverse: the stack currently holds this layer's *outputs*.
  layer_outputs = cb.inputs_from_stack(stack, layer.n_out)
  layer_grads = cb.inputs_from_stack(grad_stack, layer.n_out)
  replicated_slots = self._replicate(optimizer.slots)
  replicated_weights = self._replicate(layer.weights)  # cpu -> accelerator
  # Every argument of the accelerated function must live on the accelerator.
  (layer_outputs, layer_grads, old_state, new_state, opt_params,
   rng) = [tl.on_accelerator(arg) for arg in (layer_outputs, layer_grads,
                                              old_state, new_state,
                                              opt_params, rng)]
  (updated_weights, updated_slots, layer_inputs, input_grads,
   layer_stats) = reverse_and_fbo(layer_outputs, replicated_weights,
                                  layer_grads, old_state, new_state,
                                  replicated_slots, opt_params, rng, step)
  # accelerator -> cpu (weights are unreplicated lazily).
  layer.weights = self._lazy_unreplicate(updated_weights)
  layer.state = self._unreplicate(new_state)
  optimizer.slots = self._unreplicate(updated_slots)
  stack = cb.outputs_onto_stack(layer_inputs, stack, layer.n_out)
  grad_stack = cb.outputs_onto_stack(input_grads, grad_stack, layer.n_out)
  return stack, grad_stack, layer_stats
def _run_backward_reversible(self, stack, grad_stack, step, rev_layers,
                             rev_and_fbos, old_states, new_states, rngs,
                             optimizers, replicated_opt_params):
  """Runs reversible layers backwards, updating each layer as it goes.

  Layers are processed from last to first. For each, its inputs are
  reconstructed from its outputs and gradients are propagated. Its weights,
  state and optimizer slots are then replaced with the updated values.

  Returns:
    Tuple `(stack, grad_stack, stats)` where `stats` holds per-layer
    statistics in processing order (last layer first).
  """
  stats = []
  per_layer = list(
      zip(rev_layers, rev_and_fbos, old_states, new_states, rngs))
  for offset, (layer, reverse_and_fbo, old_state, new_state,
               rng) in enumerate(reversed(per_layer), start=1):
    # Optimizers and their params are indexed from the end, mirroring the
    # reversed traversal of the layers.
    opt_idx = -offset
    optimizer = optimizers[opt_idx]
    # Running in reverse: the stack holds this layer's *outputs*.
    layer_outputs = cb.inputs_from_stack(stack, layer.n_out)
    output_grads = cb.inputs_from_stack(grad_stack, layer.n_out)
    slots = self._replicate(optimizer.slots)
    opt_params = replicated_opt_params[opt_idx]
    weights = self._replicate(layer.weights)  # cpu -> accelerator
    (updated_weights, updated_slots, layer_inputs, input_grads,
     layer_stats) = reverse_and_fbo(layer_outputs, weights, old_state,
                                    new_state, slots, opt_params, rng, step,
                                    output_grads)
    # accelerator -> cpu
    layer.weights = self._unreplicate(updated_weights)
    layer.state = self._unreplicate(new_state)
    optimizer.slots = self._unreplicate(updated_slots)
    stats.append(layer_stats)
    stack = cb.outputs_onto_stack(layer_inputs, stack, layer.n_out)
    grad_stack = cb.outputs_onto_stack(input_grads, grad_stack, layer.n_out)
  return stack, grad_stack, stats
def init_reversible_blocks(blocks, loss_layer, input_signature, rng):
  """Initializes reversible blocks and the loss layer, keeping params on CPU.

  Layers are initialized in order, threading each layer's output signature
  into the next. After a layer is initialized, its weights and state are moved
  to CPU memory and the growth of the process RSS is logged.

  Args:
    blocks: List of reversible blocks (pairs of layer lists).
    loss_layer: The final loss layer to initialize.
    input_signature: The signature of the input to the blocks.
    rng: Random key used to initialize the layers.
  """
  bytes_per_mb = float(1024 * 1024)
  proc = psutil.Process(os.getpid())
  last_rss = proc.memory_info().rss
  sig_stack = input_signature
  for std_layers, rev_layers in blocks:
    block_layers = std_layers + rev_layers
    keys = fastmath.random.split(rng, len(block_layers) + 1)
    rng = keys[0]  # Carried over to the next block / the loss layer.
    for layer, layer_key in zip(block_layers, keys[1:]):
      in_sig = cb.inputs_from_stack(sig_stack, layer.n_in)
      layer.init(in_sig, rng=layer_key)
      # Park weights and state in CPU memory.
      layer.weights = tl.on_cpu(layer.weights)
      layer.state = tl.on_cpu(layer.state)
      logging.info('init: layer %s\nadded cpu memory (MB): %.2f', str(layer),
                   (proc.memory_info().rss - last_rss) / bytes_per_mb)
      last_rss = proc.memory_info().rss
      logging.info('init: cpu memory use (MB): %.2f', last_rss / bytes_per_mb)
      sig_stack = cb.outputs_onto_stack(
          layer.output_signature(in_sig), sig_stack, layer.n_in)
  loss_sig = cb.inputs_from_stack(sig_stack, loss_layer.n_in)
  loss_layer.init(loss_sig, rng=rng)
  loss_layer.weights = tl.on_cpu(loss_layer.weights)
  loss_layer.state = tl.on_cpu(loss_layer.state)
def _run_backward_standard(self, grad_stack, step, layer, inp, state, fbo_fn,
                           rng, optimizer, replicated_opt_params):
  """Runs a standard (non-reversible) layer forward+backward with an update.

  Args:
    grad_stack: Cotangent stack from later layers, or None when this layer
      starts backpropagation (callers pass None for the loss layer).
    step: Training step; a scalar, or one entry per device when
      `self._n_devices >= 2`.
    layer: The standard layer to update.
    inp: Inputs the layer saw in the forward pass.
    state: Layer state from the forward pass.
    fbo_fn: Compiled function returning
      `(new_weights, new_state, new_slots, input_grads, stats)`.
    rng: Random key for this layer.
    optimizer: Optimizer whose `slots` are read and replaced.
    replicated_opt_params: Optimizer hyper-parameters.

  Returns:
    Tuple `(stats, grad_stack)`; `grad_stack` now carries the cotangents of
    the layer's inputs.
  """
  step_int = int(step[0]) if self._n_devices >= 2 else int(step)
  if step_int % self._n_steps_per_log == 1:
    logging.info('running backward standard layer %s', str(layer))
  has_downstream = grad_stack is not None
  out_grads = (cb.inputs_from_stack(grad_stack, layer.n_out)
               if has_downstream else None)
  slots = self._replicate(optimizer.slots)
  weights = self._replicate(layer.weights)
  # Move every argument onto the accelerator before the compiled call.
  state, replicated_opt_params, rng, out_grads, inp = (
      tl.on_accelerator(arg)
      for arg in (state, replicated_opt_params, rng, out_grads, inp))
  (updated_weights, updated_state, updated_slots, input_grads,
   layer_stats) = fbo_fn(inp, weights, out_grads, state, slots,
                         replicated_opt_params, rng, step)
  layer.weights = self._lazy_unreplicate(updated_weights)
  layer.state = self._unreplicate(updated_state)
  optimizer.slots = self._unreplicate(updated_slots)
  if not has_downstream:
    return layer_stats, input_grads
  return layer_stats, cb.outputs_onto_stack(input_grads, grad_stack,
                                            layer.n_out)
def _run_forward_standard(self, stack, layer, accelerated_fn, rng):
  """Runs one standard (non-reversible) layer forward.

  Args:
    stack: Current data stack; `layer.n_in` items are read from it.
    layer: The standard layer to run.
    accelerated_fn: Compiled forward function called as
      `accelerated_fn(inputs, weights, state, rng)`.
    rng: Random key for this layer.

  Returns:
    Tuple `(stack, layer_inputs, layer_new_state)`.
  """
  n_in = layer.n_in
  inputs = cb.inputs_from_stack(stack, n_in)
  outputs, new_state = accelerated_fn(
      inputs, self._replicate(layer.weights), self._replicate(layer.state), rng)
  return cb.outputs_onto_stack(outputs, stack, n_in), inputs, new_state
def _run_forward_standard(self, stack, layer, accelerated_fn, rng, step):
  """Runs one standard (non-reversible) layer forward.

  Args:
    stack: Current data stack; `layer.n_in` items are read from it.
    layer: The standard layer to run.
    accelerated_fn: Compiled forward function called as
      `accelerated_fn(inputs, weights, state, rng)`.
    rng: Random key for this layer.
    step: Training step, used here only to throttle logging.

  Returns:
    Tuple `(stack, layer_inputs, layer_new_state)`.
  """
  should_log = step % self._n_steps_per_log == 1
  if should_log:
    logging.info('running forward standard layer %s', str(layer))
  n_in = layer.n_in
  inputs = cb.inputs_from_stack(stack, n_in)
  outputs, new_state = accelerated_fn(
      inputs, self._replicate(layer.weights), self._replicate(layer.state), rng)
  return cb.outputs_onto_stack(outputs, stack, n_in), inputs, new_state
def reverse(self, output, weights=(), state=(), new_state=(), rng=None):
  """Reconstructs this combinator's inputs from its outputs.

  Sublayers are reversed from last to first, each restoring its own inputs
  from its outputs on the stack.

  Args:
    output: Output stack of this layer.
    weights: Per-sublayer weights.
    state: Per-sublayer state before the forward pass.
    new_state: Per-sublayer state after the forward pass.
    rng: Optional random key; when given it is split into one key per layer.

  Returns:
    The reconstructed input stack.
  """
  layer_rngs = ((None,) * self._n_layers if rng is None
                else fastmath.random.split(rng, self._n_layers))
  per_layer = list(zip(self.sublayers, weights, state, new_state, layer_rngs))
  stack = output
  for sublayer, w, st, new_st, sub_rng in per_layer[::-1]:
    sub_outputs = cb.inputs_from_stack(stack, sublayer.n_out)
    sub_inputs = sublayer.reverse(sub_outputs, w, st, new_st, rng=sub_rng)
    stack = cb.outputs_onto_stack(sub_inputs, stack, sublayer.n_out)
  return stack
def _run_forward_reversible(self, stack, rev_layers, accelerated_fns, rngs):
  """Runs reversible layers forward, collecting states for the backward pass.

  Args:
    stack: Current data stack.
    rev_layers: Reversible layers, run in order.
    accelerated_fns: Compiled forward function for each layer, indexed like
      `rev_layers`.
    rngs: Random key for each layer, indexed like `rev_layers`.

  Returns:
    Tuple `(stack, old_states, new_states)`. `old_states[i]` is the replicated
    state fed to layer `i` and `new_states[i]` is the state it produced.
  """
  old_states = []
  new_states = []
  for idx, layer in enumerate(rev_layers):
    n_in = layer.n_in
    # Replication also copies the parameters from cpu to the accelerator.
    replicated_weights = self._replicate(layer.weights)
    replicated_state = self._replicate(layer.state)
    old_states.append(replicated_state)
    layer_inputs = cb.inputs_from_stack(stack, n_in)
    layer_outputs, produced_state = accelerated_fns[idx](
        layer_inputs, replicated_weights, replicated_state, rngs[idx])
    stack = cb.outputs_onto_stack(layer_outputs, stack, n_in)
    new_states.append(produced_state)
  return stack, old_states, new_states
def forward(self, xs):
  """Adds a residual computed from the context to the accumulator.

  Args:
    xs: Stack `(accumulator, *context)`.

  Returns:
    `(accumulator + residual, *context)`, where `residual` is the first item
    left on the stack after running the sublayers on `context`.
  """
  sub_rngs = _split_rngs(self.rng, len(self.sublayers))
  accumulator, *rest = xs
  context = tuple(rest)
  stack = context
  collected_states = []
  for sublayer, w, st, sub_rng in zip(self.sublayers, self.weights,
                                      self.state, sub_rngs):
    n_in = sublayer.n_in
    sub_outputs, st = sublayer.pure_fn(
        cb.inputs_from_stack(stack, n_in), w, st, sub_rng)
    stack = cb.outputs_onto_stack(sub_outputs, stack, n_in)
    collected_states.append(st)
  if isinstance(stack, (tuple, list)):
    residual = stack[0]
  else:
    residual = stack
  summed = accumulator + residual
  self.state = tuple(collected_states)
  return (summed,) + context
def init_weights_and_state(self, input_signature):
  """Initializes weights/state of the residual and optional attention layer.

  The first element of `input_signature` (the accumulator) is skipped; the
  remaining context signature feeds `compute_residual`. If an attention layer
  is present, it is then initialized from the resulting stack signature.

  Args:
    input_signature: Signature of `(accumulator, *context)`.
  """
  residual_fn = self.compute_residual
  context_sig = input_signature[1:]
  if len(context_sig) == 1:
    context_sig = context_sig[0]
  residual_in = cb.inputs_from_stack(context_sig, residual_fn.n_in)
  residual_weights, residual_state = residual_fn.init(residual_in)
  residual_out, _ = residual_fn._forward_abstract(residual_in)
  context_sig = cb.outputs_onto_stack(residual_out, context_sig,
                                      residual_fn.n_in)
  attention = self.attention_layer
  if attention is not None:
    attn_in = cb.inputs_from_stack(context_sig, attention.n_in)
    attn_weights, attn_state = attention.init(attn_in)
    self.state = (residual_state, attn_state)
    self.weights = (residual_weights, attn_weights)
  else:
    self.state = (residual_state,)
    self.weights = (residual_weights,)
def _run_backward_standard(self, grad_stack, step, layer, inp, state, fbo_fn,
                           rng, optimizer, replicated_opt_params):
  """Runs a standard (non-reversible) layer forward+backward with an update.

  Args:
    grad_stack: Cotangent stack from later layers, or None when this layer
      starts backpropagation (callers pass None for the loss layer).
    step: Training step passed through to `fbo_fn`.
    layer: The standard layer to update.
    inp: Inputs the layer saw in the forward pass.
    state: Layer state from the forward pass.
    fbo_fn: Compiled function returning
      `(new_weights, new_state, new_slots, input_grads, stats)`.
    rng: Random key for this layer.
    optimizer: Optimizer whose `slots` are read and replaced.
    replicated_opt_params: Optimizer hyper-parameters.

  Returns:
    Tuple `(stats, grad_stack)`; `grad_stack` now carries the cotangents of
    the layer's inputs.
  """
  has_downstream = grad_stack is not None
  out_grads = (cb.inputs_from_stack(grad_stack, layer.n_out)
               if has_downstream else None)
  slots = self._replicate(optimizer.slots)
  weights = self._replicate(layer.weights)
  (updated_weights, updated_state, updated_slots, input_grads,
   layer_stats) = fbo_fn(inp, weights, state, slots, replicated_opt_params,
                         rng, step, out_grads)
  layer.weights = self._unreplicate(updated_weights)
  layer.state = self._unreplicate(updated_state)
  optimizer.slots = self._unreplicate(updated_slots)
  if not has_downstream:
    return layer_stats, input_grads
  return layer_stats, cb.outputs_onto_stack(input_grads, grad_stack,
                                            layer.n_out)
def forward(self, xs):
  """Adds a residual computed from the context to the accumulator.

  When weights are sharded (`base.N_WEIGHTS_SHARDS > 1`), each sublayer runs
  under `jax.remat` so its concatenated weights are not kept in memory on each
  device.

  Args:
    xs: Stack `(accumulator, *context)`.

  Returns:
    `(accumulator + residual, *context)`, where `residual` is the first item
    left on the stack after running the sublayers on `context`.
  """
  sub_rngs = _split_rngs(self.rng, len(self.sublayers))
  accumulator, *rest = xs
  context = tuple(rest)
  stack = context
  collected_states = []
  for sublayer, w, st, sub_rng in zip(self.sublayers, self.weights,
                                      self.state, sub_rngs):
    sub_inputs = cb.inputs_from_stack(stack, sublayer.n_in)
    run_fn = sublayer.pure_fn
    if base.N_WEIGHTS_SHARDS > 1:
      # Rematerialize so sharded weights are not held concatenated.
      run_fn = jax.remat(run_fn)
    sub_outputs, st = run_fn(sub_inputs, w, st, sub_rng)
    stack = cb.outputs_onto_stack(sub_outputs, stack, sublayer.n_in)
    collected_states.append(st)
  residual = stack[0] if isinstance(stack, (tuple, list)) else stack
  summed = accumulator + residual
  self.state = tuple(collected_states)
  return (summed,) + context
def init_reversible_blocks(blocks, loss_layer, input_signature, rng):
  """Initializes reversible blocks and the loss layer, keeping weights on CPU.

  Layers are initialized in order, threading each layer's output signature
  into the next; after initialization each layer's weights are moved to CPU.

  Args:
    blocks: List of reversible blocks (pairs of layer lists).
    loss_layer: The final loss layer to initialize.
    input_signature: The signature of the input to the blocks.
    rng: Random key used to initialize the layers.
  """
  sig_stack = input_signature
  for std_layers, rev_layers in blocks:
    block_layers = std_layers + rev_layers
    keys = fastmath.random.split(rng, len(block_layers) + 1)
    rng = keys[0]  # Carried over to the next block / the loss layer.
    for layer, layer_key in zip(block_layers, keys[1:]):
      in_sig = cb.inputs_from_stack(sig_stack, layer.n_in)
      layer.init(in_sig, rng=layer_key)
      layer.weights = tl.on_cpu(layer.weights)  # keep weights in cpu memory
      sig_stack = cb.outputs_onto_stack(
          layer.output_signature(in_sig), sig_stack, layer.n_in)
  loss_sig = cb.inputs_from_stack(sig_stack, loss_layer.n_in)
  loss_layer.init(loss_sig, rng=rng)
  loss_layer.weights = tl.on_cpu(loss_layer.weights)
def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                     rng=None):
  """Reconstructs this layer's inputs from its outputs and backpropagates.

  The forward pass maps `(x, *context)` to `(x + residual, *context)`, where
  `residual` is derived from `context` alone. Here `x` is recovered as
  `accumulator_output - residual`, and cotangents are propagated through
  `compute_residual` (and through the attention layer, if any).

  Args:
    output: Output stack `(accumulator_output, *context)`.
    ct: Cotangent stack with the same structure as `output`.
    weights: `(compute_residual_weights,)`, or
      `(compute_residual_weights, attention_weights)` with an attention layer.
    state: Sublayer states before the forward pass.
    new_state: Sublayer states after the forward pass.
    rng: Random key, split across the sublayers.

  Returns:
    `(inputs, (inputs_ct, weights_ct))` where `inputs` is the reconstructed
    stack `(x, *context)`.
  """
  rngs = _split_rngs(rng, len(self.sublayers))

  accumulator_output, *context = output
  context = tuple(context)
  accumulator_output_ct, *context_ct = ct
  context_ct = tuple(context_ct)

  # Forward pass through self.compute_residual. Outputs that will not receive
  # a gradient signal from subsequent layers are moved to aux.
  def call_compute_residual(x, weights):
    res, _ = self.compute_residual.pure_fn(
        x, weights=weights, state=state[0], rng=rngs[0])
    if not isinstance(res, (tuple, list)):
      return res, None
    n_differentiable = 1
    if self.attention_layer is not None:
      n_differentiable = min(len(res), self.attention_layer.n_in)
    return res[:n_differentiable], res[n_differentiable:]

  stack = context
  inputs = cb.inputs_from_stack(stack, self.compute_residual.n_in)
  outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp(
      call_compute_residual, inputs, weights[0], has_aux=True)
  n_differentiable_outputs = None
  if outputs_aux is not None:
    n_differentiable_outputs = len(outputs)
    outputs = outputs + outputs_aux
  stack = cb.outputs_onto_stack(outputs, stack, self.compute_residual.n_in)

  stack_ct = accumulator_output_ct
  if self.attention_layer is None:
    residual = stack[0] if isinstance(stack, (tuple, list)) else stack
  else:
    inputs = cb.inputs_from_stack(stack, self.attention_layer.n_in)
    (residual, _, attn_inputs_ct, attn_weights_ct
    ) = self._forward_and_or_backward(
        inputs, weights[1], new_state[1], rngs[1],
        output_grad=accumulator_output_ct,
        compute_output=True, update_state=False)
    stack_ct = cb.outputs_onto_stack(
        attn_inputs_ct, stack_ct, self.attention_layer.n_out)

  compute_residual_ct = cb.inputs_from_stack(
      stack_ct, self.compute_residual.n_out)
  if outputs_aux is not None:
    # Only the differentiable (non-aux) outputs receive cotangents.
    if not isinstance(compute_residual_ct, (tuple, list)):
      compute_residual_ct = (compute_residual_ct,)
    compute_residual_ct = compute_residual_ct[:n_differentiable_outputs]
    assert len(compute_residual_ct) == n_differentiable_outputs
  (compute_residual_inputs_ct,
   compute_residual_weights_ct) = compute_residual_vjpfun(compute_residual_ct)
  stack_ct = cb.outputs_onto_stack(
      compute_residual_inputs_ct, stack_ct, self.compute_residual.n_out)
  if not isinstance(stack_ct, (tuple, list)):
    stack_ct = (stack_ct,)

  def _add(x, y):
    # Cotangents of int/bool inputs carry no signal. JAX represents them as
    # `float0` arrays, which do not support `+`; the TFNP backend uses `None`.
    # Treat either as zero instead of adding.
    if x is None or x.dtype == jax.float0:
      return y
    if y is None or y.dtype == jax.float0:
      return x
    return x + y

  # Context cotangents flowing in from above are summed with those coming
  # back through the residual branch.
  stack_ct = (accumulator_output_ct,) + fastmath.nested_map_multiarg(
      _add, context_ct[:len(stack_ct)], stack_ct
  ) + context_ct[len(stack_ct):]

  reconstructed_x = accumulator_output - residual
  stack = (reconstructed_x,) + context
  if self.attention_layer is None:
    weights_ct = (compute_residual_weights_ct,)
  else:
    weights_ct = (compute_residual_weights_ct, attn_weights_ct)
  return stack, (stack_ct, weights_ct)
def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                     rng=None):
  """Reconstructs this layer's inputs from its outputs and backpropagates.

  The output stack is `(accumulator_output, *context)`. The residual is
  recomputed from `context` (through `self._compute_residual` and, if present,
  `self._attention_layer`), and the input accumulator is recovered as
  `accumulator_output - residual`.

  Args:
    output: Output stack `(accumulator_output, *context)`.
    ct: Cotangent stack with the same structure as `output`.
    weights: `(compute_residual_weights,)`, or
      `(compute_residual_weights, attention_weights)` with an attention layer.
    state: Sublayer states before the forward pass; `state[0]` is passed to
      `self._compute_residual`.
    new_state: Sublayer states after the forward pass; `new_state[1]` is used
      for the attention layer.
    rng: Random key, split across the sublayers.

  Returns:
    `(inputs, (inputs_ct, weights_ct))` where `inputs` is the reconstructed
    stack `(reconstructed_x, *context)`.
  """
  rngs = _split_rngs(rng, len(self.sublayers))

  accumulator_output, *context = output
  context = tuple(context)
  accumulator_output_ct, *context_ct = ct
  context_ct = tuple(context_ct)

  # Forward pass through self._compute_residual. Outputs that will not receive
  # a gradient signal from subsequent layers are moved to aux.
  def call_compute_residual(x, weights):
    state_to_pass = state[0]  # old_state

    # _replace_second_time is currently used exclusively in _RememberInReverse
    # layer to combat numerical instability in Reformer2 when quantizing
    # the mask in SparseFF.
    # It walks `stt` (old state) and `nstt` (new state) in parallel. Any
    # 2-tuple whose second item is a dict containing 'running_second_time'
    # is replaced by the new state's first item, paired with a
    # {'running_second_time_yes': ()} marker.
    def _replace_second_time(stt, nstt):
      if (isinstance(stt, tuple) and len(stt) == 2 and
          isinstance(stt[1], dict) and 'running_second_time' in stt[1]):
        return (nstt[0], {'running_second_time_yes': ()})
      elif isinstance(stt, (tuple, list)):
        assert isinstance(nstt, (tuple, list)) and len(nstt) == len(stt)
        return type(stt)([
            _replace_second_time(s, ns) for s, ns in zip(stt, nstt)
        ])
      else:
        return stt

    state_to_pass = _replace_second_time(state_to_pass, new_state[0])
    res, _ = self._compute_residual.pure_fn(
        x, weights=weights, state=state_to_pass, rng=rngs[0])
    if not isinstance(res, (tuple, list)):
      return res, None
    else:
      # Only the outputs consumed by the attention layer (or just the first
      # output when there is none) are differentiated; the rest go to aux.
      n_differentiable = 1
      if self._attention_layer is not None:
        n_differentiable = min(len(res), self._attention_layer.n_in)
      return res[:n_differentiable], res[n_differentiable:]

  stack = context
  inputs = cb.inputs_from_stack(stack, self._compute_residual.n_in)
  outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp(
      call_compute_residual, inputs, weights[0], has_aux=True)
  if outputs_aux is not None:
    # Remember the split point so the cotangents can be trimmed to match.
    n_differentiable_outputs = len(outputs)
    outputs = outputs + outputs_aux
  stack = cb.outputs_onto_stack(outputs, stack, self._compute_residual.n_in)

  stack_ct = accumulator_output_ct
  if self._attention_layer is None:
    residual = stack[0] if isinstance(stack, (tuple, list)) else stack
  else:
    inputs = cb.inputs_from_stack(stack, self._attention_layer.n_in)
    # Recompute the attention output (the residual) and its input/weight
    # cotangents in a single call, without updating state.
    (residual, _, attn_inputs_ct,
     attn_weights_ct) = self._forward_and_or_backward(
         inputs, weights[1], new_state[1], rngs[1],
         output_grad=accumulator_output_ct,
         compute_output=True, update_state=False)
    stack_ct = cb.outputs_onto_stack(attn_inputs_ct, stack_ct,
                                     self._attention_layer.n_out)

  compute_residual_ct = cb.inputs_from_stack(
      stack_ct, self._compute_residual.n_out)
  if outputs_aux is not None:
    if not isinstance(compute_residual_ct, (tuple, list)):
      compute_residual_ct = (compute_residual_ct, )
    compute_residual_ct = compute_residual_ct[: n_differentiable_outputs]
    assert len(compute_residual_ct) == n_differentiable_outputs
  (compute_residual_inputs_ct, compute_residual_weights_ct
  ) = compute_residual_vjpfun(compute_residual_ct)
  stack_ct = cb.outputs_onto_stack(compute_residual_inputs_ct, stack_ct,
                                   self._compute_residual.n_out)
  if not isinstance(stack_ct, (tuple, list)):
    stack_ct = (stack_ct, )

  def _add(x, y):
    # `None` is for TFNP backend, which uses `None` as the gradient of
    # int/bool instead of an array of dtype `float0`.
    if x is None or x.dtype == jax.float0:
      return y
    if y is None or y.dtype == jax.float0:
      return x
    return x + y

  # Sum incoming context cotangents with those from the residual branch; any
  # context entries beyond the branch's cotangents pass through unchanged.
  stack_ct = (accumulator_output_ct, ) + fastmath.nested_map_multiarg(
      _add, context_ct[:len(stack_ct)], stack_ct) + context_ct[len(stack_ct):]

  # Invert the forward pass `output = x + residual`.
  reconstructed_x = accumulator_output - residual
  stack = (reconstructed_x, ) + context
  if self._attention_layer is None:
    weights_ct = (compute_residual_weights_ct, )
  else:
    weights_ct = (compute_residual_weights_ct, attn_weights_ct)
  return stack, (stack_ct, weights_ct)
def one_step(self, batch, rng, step=0, learning_rate=None):
  """Updates layers weights/state and optimizers slots by running one step.

  All blocks are run forward (standard layer, then its reversible layers).
  Next, the loss layer is run forward and backward with an optimizer update.
  Finally the blocks are walked backward: reversible layers reconstruct their
  inputs, and every layer's weights and optimizer slots are updated.

  Args:
    batch: Batch of data to use for optimization.
    rng: Random number generator to use for running this step.
    step: Which step of the training are we running.
    learning_rate: Learning rate to use instead of the default one.

  Returns:
    Tuple (loss, stats) with new values from one step of training, where stats
    are all optimizer statistics.
  """
  # Update the learning rate if needed.
  if learning_rate is not None:
    self._replicated_loss_opt_params['learning_rate'] = self._replicate_cpu(
        learning_rate)
    for (std_op, rev_ops) in self._replicated_opt_params:
      std_op['learning_rate'] = self._replicate_cpu(learning_rate)
      for op in rev_ops:
        op['learning_rate'] = self._replicate_cpu(learning_rate)

  # Batch needs to be split across the local devices -- the difference
  # between _for_n_devices and _reshape_by_device is that the latter splits
  # the batch dim to batch // n_devices, vs _for_n_devices
  # broadcasts/replicates to n_devices dimension.
  step_int = step  # Scalar copy kept for logging; `step` may become per-device.
  if self._n_devices > 1:
    batch = tl.reshape_by_device(batch, self._n_devices, pure_np=True)
    step = np.repeat(step, self._n_devices)

  # Create separate rng for each device and layer.
  if self._n_devices == 1:
    rngs = fastmath.random.split(rng, self._n_layers)
  else:
    # JIT the function and run it on CPU to avoid memory fragmentation.
    rngs = self._jit_per_device_rngs(tl.on_cpu(rng))
  # Group rngs by layer blocks: one for the standard layer, then one per
  # reversible layer.
  rng_blocks, rng_i = [], 0
  for _, rev_layers in self._blocks:
    n_rev = len(rev_layers)
    rng_blocks.append((rngs[rng_i], rngs[rng_i + 1:rng_i + n_rev + 1]))
    rng_i += n_rev + 1

  bytes_per_mb = float(1024 * 1024)
  # Memory usage is only logged every `_n_steps_per_log` steps.
  log_memory = step_int % self._n_steps_per_log == 1

  # Run the layers forward up to the loss layer.
  if self._do_free:
    self._free_accelerators()
  process = psutil.Process(os.getpid())
  if isinstance(batch, (list, tuple)):
    batch_shapes = [x.shape for x in batch]
  else:
    batch_shapes = batch.shape
  logging.info('running step %d on shapes %s', step_int, str(batch_shapes))
  if log_memory:
    logging.info('run fwd: cpu memory use (MB): %.2f',
                 process.memory_info().rss / bytes_per_mb)

  stack = batch
  block_inputs_states = []
  for i, (std_layer, rev_layers) in enumerate(self._blocks):
    acc_std_layer_fn, acc_rev_layer_fns = self._accelerated_layer_fns[i]
    std_rng, rev_rngs = rng_blocks[i]
    # Run the standard layer.
    stack, std_inputs, std_state = self._run_forward_standard(
        stack, std_layer, acc_std_layer_fn, std_rng, step_int)

    # Run the reversible layers and collect old and new states.
    stack, rev_old_states, rev_new_states = self._run_forward_reversible(
        stack, rev_layers, acc_rev_layer_fns, rev_rngs, step_int)
    # Saved on CPU for the backward pass.
    block_inputs_states.append(
        tl.on_cpu(((std_inputs, std_state), (rev_old_states, rev_new_states))))

  # Run the loss layer forward and backward with optimizer update.
  if log_memory:
    logging.info('run loss: cpu memory use (MB): %.2f',
                 process.memory_info().rss / bytes_per_mb)
  loss_state = self._replicate(self._loss_layer.state)
  loss_inputs = cb.inputs_from_stack(stack, self._loss_layer.n_in)
  loss_stats, grad_stack = self._run_backward_standard(
      None, step, self._loss_layer, loss_inputs, loss_state, self._loss_fbo,
      rngs[-1], self._loss_opt, self._replicated_loss_opt_params)
  self._collect_weights(self._loss_layer)
  stats = [tl.on_cpu(loss_stats)]

  # De-fragment memory.
  if self._do_free:
    stack, grad_stack = tl.on_cpu(stack), tl.on_cpu(grad_stack)
    self._free_accelerators()

  # Run the layers backward and run optimizer updates.
  if log_memory:
    logging.info('run bwd: cpu memory use (MB): %.2f',
                 process.memory_info().rss / bytes_per_mb)
  for i in range(len(self._blocks) - 1, -1, -1):
    std_layer, rev_layers = self._blocks[i]
    (std_inputs, std_state), (rev_old_states,
                              rev_new_states) = block_inputs_states[i]
    std_fbo, rev_fbos = self._fbos[i]
    std_opt, rev_opts = self._optimizers[i]
    std_rng, rev_rngs = rng_blocks[i]
    repl_std_opt_params, repl_rev_opts_params = self._replicated_opt_params[i]

    # Run reversible layers backward with optimizer update.
    stack, grad_stack, new_stats = self._run_backward_reversible(
        stack, grad_stack, step, rev_layers, rev_fbos, rev_old_states,
        rev_new_states, rev_rngs, rev_opts, repl_rev_opts_params)
    stats.extend(tl.on_cpu(new_stats))

    # Run the standard layer forward-and-backward pass and optimizer update.
    std_layer_stats, grad_stack = self._run_backward_standard(
        grad_stack, step, std_layer, std_inputs, std_state, std_fbo, std_rng,
        std_opt, repl_std_opt_params)
    stack = cb.outputs_onto_stack(  # Put layer inputs on the stack.
        std_inputs, stack, std_layer.n_out)
    stats.append(tl.on_cpu(std_layer_stats))

    # Collect lazily unreplicated layer weights. Slicing (rather than indexing
    # with range(self._n_async_layers)) avoids an IndexError in blocks with
    # fewer reversible layers than `_n_async_layers`.
    for rev_layer in rev_layers[:self._n_async_layers]:
      self._collect_weights(rev_layer)
    self._collect_weights(std_layer)

  # Join stats from different optimizers into one. `stats` was accumulated in
  # backward order (loss layer first), so reversing numbers layers from the
  # input side.
  joint_stats = {}
  for i, stat in enumerate(reversed(stats)):
    for k, v in stat.items():
      joint_stats[f'layer{i}/' + k] = v
  return stats[0]['loss'], joint_stats
def one_step(self, batch, rng, step=0, learning_rate=None):
  """Updates layers weights/state and optimizers slots by running one step.

  All blocks are run forward (standard layer, then its reversible layers).
  Next, the loss layer is run forward and backward with an optimizer update.
  Finally the blocks are walked backward: reversible layers reconstruct their
  inputs, and every layer's weights and optimizer slots are updated.

  Args:
    batch: Batch of data to use for optimization.
    rng: Random number generator to use for running this step.
    step: Which step of the training are we running.
    learning_rate: Learning rate to use instead of the default one.

  Returns:
    Tuple (loss, stats) with new values from one step of training, where stats
    are all optimizer statistics.
  """
  # Update the learning rate if needed.
  if learning_rate is not None:
    self._replicated_loss_opt_params['learning_rate'] = tl.for_n_devices(
        learning_rate, self._n_devices)
    for (std_op, rev_ops) in self._replicated_opt_params:
      std_op['learning_rate'] = tl.for_n_devices(
          learning_rate, self._n_devices)
      for op in rev_ops:
        op['learning_rate'] = tl.for_n_devices(
            learning_rate, self._n_devices)

  # Batch needs to be split across the local devices -- the difference
  # between _for_n_devices and _reshape_by_device is that the latter splits
  # the batch dim to batch // n_devices, vs _for_n_devices
  # broadcasts/replicates to n_devices dimension.
  if self._n_devices > 1:
    batch = tl.reshape_by_device(batch, self._n_devices)
    # One copy of the step per device.
    step = jnp.repeat(step, self._n_devices)

  # Create separate rng for each device and layer.
  if self._n_devices == 1:
    rngs = fastmath.random.split(rng, self._n_layers)
  else:
    # Splitting by device first to be identical with default trainer.
    per_device_rng = fastmath.random.split(rng, self._n_devices)
    per_device_rngs = [
        fastmath.random.split(r, self._n_layers) for r in per_device_rng]
    # Transpose to one stacked (per-device) rng entry per layer.
    rngs = [jnp.stack([r[i] for r in per_device_rngs])
            for i in range(self._n_layers)]
  # Group rngs by layer blocks: one for the standard layer, then one per
  # reversible layer. The last rng (rngs[-1]) is used for the loss layer.
  rng_blocks, rng_i = [], 0
  for _, rev_layers in self._blocks:
    l = len(rev_layers)
    rng_blocks.append((rngs[rng_i], rngs[rng_i + 1: rng_i + l + 1]))
    rng_i += l + 1

  # Run the layers forward up to the loss layer.
  stack = batch
  block_inputs_states = []
  for i, (std_layer, rev_layers) in enumerate(self._blocks):
    acc_std_layer_fn, acc_rev_layer_fns = self._accelerated_layer_fns[i]
    std_rng, rev_rngs = rng_blocks[i]
    # Run the standard layer.
    stack, std_inputs, std_state = self._run_forward_standard(
        stack, std_layer, acc_std_layer_fn, std_rng)

    # Run the reversible layers and collect old and new states.
    stack, rev_old_states, rev_new_states = self._run_forward_reversible(
        stack, rev_layers, acc_rev_layer_fns, rev_rngs)
    # Saved for the backward pass below.
    block_inputs_states.append(
        ((std_inputs, std_state), (rev_old_states, rev_new_states)))

  # Run the loss layer forward and backward with optimizer update.
  loss_state = self._replicate(self._loss_layer.state)
  loss_inputs = cb.inputs_from_stack(stack, self._loss_layer.n_in)
  loss_stats, grad_stack = self._run_backward_standard(
      None, step, self._loss_layer, loss_inputs, loss_state,
      self._loss_fbo, rngs[-1], self._loss_opt,
      self._replicated_loss_opt_params)
  stats = [loss_stats]

  # Run the layers backward and run optimizer updates.
  for i in range(len(self._blocks) - 1, -1, -1):
    std_layer, rev_layers = self._blocks[i]
    (std_inputs, std_state), (rev_old_states,
                              rev_new_states) = block_inputs_states[i]
    std_fbo, rev_fbos = self._fbos[i]
    std_opt, rev_opts = self._optimizers[i]
    std_rng, rev_rngs = rng_blocks[i]
    repl_std_opt_params, repl_rev_opts_params = self._replicated_opt_params[i]

    # Run reversible layers backward with optimizer update.
    stack, grad_stack, new_stats = self._run_backward_reversible(
        stack, grad_stack, step, rev_layers, rev_fbos, rev_old_states,
        rev_new_states, rev_rngs, rev_opts, repl_rev_opts_params)
    stats.extend(new_stats)

    # Run the standard layer forward-and-backward pass and optimizer update.
    std_layer_stats, grad_stack = self._run_backward_standard(
        grad_stack, step, std_layer, std_inputs, std_state, std_fbo, std_rng,
        std_opt, repl_std_opt_params)
    stack = cb.outputs_onto_stack(  # Put layer inputs on the stack.
        std_inputs, stack, std_layer.n_out)
    stats.append(std_layer_stats)

  # Join stats from different optimizers into one. `stats` was accumulated in
  # backward order (loss layer first), so reversing numbers layers from the
  # input side; stats[0] below is still the loss layer's stats.
  joint_stats = {}
  for i, stat in enumerate(reversed(stats)):
    for k, v in stat.items():
      joint_stats[f'layer{i}/' + k] = v
  return stats[0]['loss'], joint_stats