Code example #1
0
 def _run_backward_standard(self, grad_stack, step, layer, inp, state,
                            fbo_fn, rng, optimizer, replicated_opt_params):
     """Run one standard (non-reversible) layer backwards.

     Pops this layer's output gradients off ``grad_stack`` (if any), moves
     every argument onto the accelerator, invokes the fused
     forward-backward-optimizer function ``fbo_fn``, writes the updated
     weights/state/slots back to ``layer`` and ``optimizer``, and pushes the
     resulting input gradients back onto the gradient stack.

     Returns:
       A pair ``(stats, grad_stack)``: optimizer statistics from ``fbo_fn``
       and the updated gradient stack.
     """
     # `step` is per-device-replicated when running on 2+ devices, so read
     # a single entry to get the scalar step number.
     current_step = int(step) if self._n_devices < 2 else int(step[0])
     if current_step % self._n_steps_per_log == 1:
         logging.info('running backward standard layer %s', str(layer))
     grads = (cb.inputs_from_stack(grad_stack, layer.n_out)
              if grad_stack is not None else None)
     slots = self._replicate(optimizer.slots)
     weights = self._replicate(layer.weights)
     # Move all remaining arguments onto the accelerator before the call.
     state, replicated_opt_params, rng, grads, inp = [
         tl.on_accelerator(t)
         for t in (state, replicated_opt_params, rng, grads, inp)]
     new_weights, new_state, new_slots, new_grads, stats = fbo_fn(
         inp, weights, grads, state, slots, replicated_opt_params, rng,
         step)
     # Write results back onto the layer and optimizer.
     layer.weights = self._lazy_unreplicate(new_weights)
     layer.state = self._unreplicate(new_state)
     optimizer.slots = self._unreplicate(new_slots)
     if grad_stack is None:
         grad_stack = new_grads
     else:
         grad_stack = cb.outputs_onto_stack(new_grads, grad_stack,
                                            layer.n_out)
     return stats, grad_stack
Code example #2
0
 def _run_backward_one_reversible(self, layer, stack, grad_stack, step, rng,
                                  optimizer, opt_params, reverse_and_fbo,
                                  old_state, new_state):
     """Run one reversible layer backwards.

     Because the layer is reversible, ``reverse_and_fbo`` reconstructs the
     layer's *inputs* from its outputs while also computing gradients and
     applying the optimizer update in a single fused call.

     Returns:
       A triple ``(stack, grad_stack, layer_stats)``: the activation stack
       with the reconstructed inputs pushed on, the gradient stack with the
       input gradients pushed on, and this layer's optimizer statistics.
     """
     # We are running backwards and reversing, so the stack currently holds
     # this layer's *outputs*.
     outputs = cb.inputs_from_stack(stack, layer.n_out)
     grads = cb.inputs_from_stack(grad_stack, layer.n_out)
     slots = self._replicate(optimizer.slots)
     weights = self._replicate(layer.weights)  # cpu -> accelerator
     # Move all remaining arguments onto the accelerator before the call.
     outputs, grads, old_state, new_state, opt_params, rng = [
         tl.on_accelerator(t)
         for t in (outputs, grads, old_state, new_state, opt_params, rng)]
     new_weights, new_slots, inputs, grads, layer_stats = reverse_and_fbo(
         outputs, weights, grads, old_state, new_state, slots, opt_params,
         rng, step)
     # Write results back onto the layer and optimizer.
     layer.weights = self._lazy_unreplicate(
         new_weights)  # accelerator -> cpu
     layer.state = self._unreplicate(new_state)
     optimizer.slots = self._unreplicate(new_slots)
     # Push reconstructed inputs and their gradients back onto the stacks.
     stack = cb.outputs_onto_stack(inputs, stack, layer.n_out)
     grad_stack = cb.outputs_onto_stack(grads, grad_stack, layer.n_out)
     return stack, grad_stack, layer_stats
Code example #3
0
File: trainer.py  Project: wangdongya/trax
 def _replicate(self, x):
   """Put x on the accelerator(s), replicating it when 2+ devices exist."""
   if self._n_devices <= 1:
     return tl.on_accelerator(x)
   return tl.for_n_devices(x, self._n_devices)