def _run_backward_standard(self, grad_stack, step, layer, inp, state,
                           fbo_fn, rng, optimizer, replicated_opt_params):
  """Run reversible layers backwards."""
  # Log occasionally; `step` is replicated per-device when n_devices > 1,
  # so take the first replica's value in that case.
  current_step = int(step) if self._n_devices < 2 else int(step[0])
  if current_step % self._n_steps_per_log == 1:
    logging.info('running backward standard layer %s', str(layer))

  # Pop this layer's incoming gradients off the gradient stack
  # (there is no stack yet on the very first backward call).
  layer_grads = None
  if grad_stack is not None:
    layer_grads = cb.inputs_from_stack(grad_stack, layer.n_out)

  # Move everything the fbo function consumes onto the accelerator.
  opt_slots = self._replicate(optimizer.slots)
  layer_weights = self._replicate(layer.weights)
  state = tl.on_accelerator(state)
  replicated_opt_params = tl.on_accelerator(replicated_opt_params)
  rng = tl.on_accelerator(rng)
  layer_grads = tl.on_accelerator(layer_grads)
  inp = tl.on_accelerator(inp)

  # Fused forward-backward-optimizer step for this (non-reversible) layer.
  new_weights, new_state, new_slots, new_grads, stats = fbo_fn(
      inp, layer_weights, layer_grads, state, opt_slots,
      replicated_opt_params, rng, step)

  # Copy the updated weights, state and optimizer slots back to the host.
  layer.weights = self._lazy_unreplicate(new_weights)
  layer.state = self._unreplicate(new_state)
  optimizer.slots = self._unreplicate(new_slots)

  # Push the gradients flowing to earlier layers back onto the stack.
  if grad_stack is None:
    grad_stack = new_grads
  else:
    grad_stack = cb.outputs_onto_stack(new_grads, grad_stack, layer.n_out)
  return stats, grad_stack
def _run_backward_one_reversible(self, layer, stack, grad_stack, step,
                                 rng, optimizer, opt_params,
                                 reverse_and_fbo, old_state, new_state):
  """Run one reversible layer backwards."""
  # Running backwards and reversing, so the stack holds this layer's
  # *outputs*, from which the inputs will be reconstructed.
  layer_outputs = cb.inputs_from_stack(stack, layer.n_out)
  layer_grads = cb.inputs_from_stack(grad_stack, layer.n_out)

  # Host -> accelerator for everything reverse_and_fbo consumes.
  opt_slots = self._replicate(optimizer.slots)
  layer_weights = self._replicate(layer.weights)
  layer_outputs = tl.on_accelerator(layer_outputs)
  layer_grads = tl.on_accelerator(layer_grads)
  old_state = tl.on_accelerator(old_state)
  new_state = tl.on_accelerator(new_state)
  opt_params = tl.on_accelerator(opt_params)
  rng = tl.on_accelerator(rng)

  # Reconstruct the inputs from the outputs and run the fused
  # backward/optimizer step in a single accelerator call.
  new_weights, new_slots, inputs, grads, layer_stats = reverse_and_fbo(
      layer_outputs, layer_weights, layer_grads, old_state, new_state,
      opt_slots, opt_params, rng, step)

  # Accelerator -> host for the updated weights, state and slots.
  layer.weights = self._lazy_unreplicate(new_weights)
  layer.state = self._unreplicate(new_state)
  optimizer.slots = self._unreplicate(new_slots)

  # Replace the outputs on the stack with the reconstructed inputs and
  # push the gradients for earlier layers onto the gradient stack.
  stack = cb.outputs_onto_stack(inputs, stack, layer.n_out)
  grad_stack = cb.outputs_onto_stack(grads, grad_stack, layer.n_out)
  return stack, grad_stack, layer_stats
def _replicate(self, x):
  """Replicate `x` across devices, or just move it to the accelerator.

  With a single device there is nothing to replicate, so `x` is only
  placed on the accelerator; with several devices it is copied to each.
  """
  if self._n_devices <= 1:
    return tl.on_accelerator(x)
  return tl.for_n_devices(x, self._n_devices)