def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.WeightedCategoryCrossEntropy(),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule_fn=lr.multifactor,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          permanent_checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          permanent_checkpoint_frequency=None,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          use_loop=True,
          loss_chunk_size=0,
          use_memory_efficient_trainer=False):
  """Train the model on the inputs.

  Two code paths exist: with `use_loop=True` (default) a `training.Loop` is
  built and run; otherwise the legacy `trainer_class` is used.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train. On the Loop path it is called as
      `model(mode='train')` and `model(mode='eval')`; on the legacy path it is
      passed to `trainer_class` unchanged.
    loss_fn: The loss; on the Loop path it is used as the `loss_layer` of the
      training task (default: `tl.WeightedCategoryCrossEntropy()`).
    inputs: A `trax.inputs.Inputs` instance or a callable returning one (the
      Loop path calls it if it is callable).
    optimizer: The optimizer class/factory. On the Loop path it is
      instantiated with `optimizer()` unless `use_memory_efficient_trainer`
      is set, in which case the factory itself is passed on.
    lr_schedule_fn: A learning rate schedule function, that when called
      returns a function from step to learning rate (a float).
    trainer_class: The trainer class to use (legacy path only).
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    permanent_checkpoints_at: list of integers. Save a permanent checkpoint
      for each training step in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled
      on the legacy path.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled on the legacy path. On the Loop path
      it is also used as the checkpoint frequency.
    permanent_checkpoint_frequency: int, how often to save permanent
      checkpoints (every permanent_checkpoint_frequency steps).
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file (legacy path
      only).
    metrics: optionally override the default metrics dictionary
      (name -> metric).
    checkpoint_highest: save the checkpoint highest at this metric (legacy
      path only).
    checkpoint_lowest: save the checkpoint lowest at this metric (legacy path
      only).
    use_loop: whether to use training.Loop instead of Trainer.
    loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory
      (Loop path only).
    use_memory_efficient_trainer: whether to use memory-efficient trainer
      (Loop path only).

  Returns:
    trax.TrainerState or training.Loop if use_loop is True

  Raises:
    ValueError: if both `permanent_checkpoint_frequency` and
      `permanent_checkpoints_at` are set.
  """
  if (permanent_checkpoint_frequency is not None
      and permanent_checkpoints_at is not None):
    raise ValueError('Only one of ["permanent_checkpoint_frequency", '
                     '"permanent_checkpoints_at"] should be set.')
  if use_loop:
    n_devices = num_devices() or fastmath.device_count()

    # Prepare the training task.
    # Inputs is either an Inputs instance or a function that returns it.
    if callable(inputs):  # If we pass a function, e.g., through gin, call it.
      inputs = inputs()
    # The memory-efficient path receives the optimizer factory itself
    # (presumably to build per-layer optimizers); others get an instance.
    opt = optimizer if use_memory_efficient_trainer else optimizer()
    train_task = training.TrainTask(
        inputs.train_stream(n_devices),
        loss_layer=loss_fn,
        optimizer=opt,
        lr_schedule=lr_schedule_fn(),
        n_steps_per_checkpoint=eval_frequency,
        n_steps_per_permanent_checkpoint=permanent_checkpoint_frequency)

    # Prepare the evaluation.
    metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
    names, metrics = zip(*metrics_dict.items())
    eval_task = training.EvalTask(inputs.eval_stream(n_devices),
                                  metrics,
                                  metric_names=names,
                                  n_eval_batches=eval_steps)

    # Prepare the training loop.
    checkpoint_at = None
    if checkpoints_at is not None:
      checkpoint_at = lambda step: step in checkpoints_at
    permanent_checkpoint_at = None
    if permanent_checkpoints_at is not None:
      permanent_checkpoint_at = (
          lambda step: step in permanent_checkpoints_at)

    loop = training.Loop(
        model(mode='train'),
        [train_task],
        eval_model=model(mode='eval'),
        eval_tasks=[eval_task],
        output_dir=output_dir,
        checkpoint_at=checkpoint_at,
        permanent_checkpoint_at=permanent_checkpoint_at,
        n_devices=n_devices,
        loss_chunk_size=loss_chunk_size,
        use_memory_efficient_trainer=use_memory_efficient_trainer,
        random_seed=random_seed)

    # A restored loop may already be at or past the requested step count.
    steps_to_go = steps - loop.step
    if steps_to_go <= 0:
      log('Stop training, already reached the total training steps %d' % steps)
      return loop

    # Train and return the loop.
    loop.run(steps_to_go)
    return loop

  # Legacy (non-Loop) path.
  n_devices = num_devices()
  trainer = trainer_class(model, loss_fn, optimizer, lr_schedule_fn(), inputs,
                          output_dir,
                          random_seed=random_seed,
                          n_devices=n_devices,
                          checkpoints_at=checkpoints_at,
                          metrics=metrics,
                          checkpoint_lowest=checkpoint_lowest,
                          checkpoint_highest=checkpoint_highest)

  # The docstring treats eval_steps=None like 0 (eval disabled); normalize it
  # so the comparison below cannot raise TypeError on None.
  if eval_steps is None:
    eval_steps = 0

  epoch_steps = [steps]  # Only training if eval_frequency is 0 or None
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  trainer.log_step('Starting training using %d devices' % trainer.n_devices)
  trainer.print_n_weights()

  try:
    for n_epoch_steps in epochs(steps, trainer.step, epoch_steps):
      trainer.train_epoch(n_epoch_steps, eval_steps)

      # Bookkeeping we do at the first step
      if trainer.step == 1:
        # Save computation graph (single-device only for now)
        if (save_graphs and fastmath.is_backend(fastmath.Backend.JAX)):
          trainer.save_computation_graphs()

        # Save Gin config
        trainer.save_gin()

    trainer.log_step('Training done')
  finally:
    # Always release trainer resources, even if training raised.
    trainer.close()
  return trainer.state
def __init__(
    self,
    task,
    value_body=None,
    value_optimizer=None,
    value_lr_schedule=lr.multifactor,
    value_batch_size=64,
    value_train_steps_per_epoch=500,
    value_evals_per_epoch=1,
    value_eval_steps=1,
    exploration_rate=functools.partial(
        lr.multifactor,
        factors='constant * decay_every',
        constant=1.,  # pylint: disable=redefined-outer-name
        decay_factor=0.99,
        steps_per_decay=1,
        minimum=0.1),
    n_eval_episodes=0,
    only_eval=False,
    n_replay_epochs=1,
    max_slice_length=1,
    sync_freq=1000,
    scale_value_targets=True,
    output_dir=None,
    **kwargs):
  """Configures the value trainer.

  Args:
    task: RLTask instance, which defines the environment to train on.
    value_body: Trax layer representing the body of the value model; it is
      wrapped as `models.Quality(body=value_body, n_actions=...)` with
      `n_actions` taken from the task's action space.
    value_optimizer: the optimizer to use to train the value model.
    value_lr_schedule: learning rate schedule factory for value training; it
      is called once here (`value_lr_schedule()`) to build the schedule.
    value_batch_size: batch size used to train the value model.
    value_train_steps_per_epoch: how long to train the value model in each RL
      epoch.
    value_evals_per_epoch: number of value trainer evaluations per RL epoch -
      only affects metric reporting.
    value_eval_steps: number of value trainer steps per evaluation - only
      affects metric reporting.
    exploration_rate: factory for the exploration rate schedule, used in the
      policy method; it is called once here (`exploration_rate()`). The
      default is an `lr.multifactor` schedule configured with constant=1.0,
      decay_factor=0.99, steps_per_decay=1 and minimum=0.1.
    n_eval_episodes: number of episodes to play with policy at temperature 0
      in each epoch -- used for evaluation only. Forwarded to the superclass.
    only_eval: If set to True, then trajectories are collected only for
      evaluation purposes, but they are not recorded.
    n_replay_epochs: Number of last epochs to take into the replay buffer;
      only makes sense for off-policy algorithms.
    max_slice_length: the maximum length of trajectory slices to use; it is
      the second dimension of the value network output:
      (batch, max_slice_length, number of actions).
      Higher max_slice_length implies that the network has to predict more
      values into the future.
    sync_freq: frequency when to synchronize the target network with the
      trained network; synchronization is due at every step for which
      `step % sync_freq == 0`. This is necessary for training the network on
      bootstrapped targets, e.g. using n-step returns.
    scale_value_targets: If `True`, scale value function targets by
      `1 / (1 - gamma)`. We are trying to fix the problem with very large
      returns in some games in a way which does not introduce an additional
      hyperparameter.
    output_dir: Path telling where to save outputs (evals and checkpoints).
      Also used as the output directory of the value Trainer.
    **kwargs: arguments for the superclass RLTrainer.
  """
  super(ValueAgent, self).__init__(task, n_eval_episodes=n_eval_episodes,
                                   output_dir=output_dir, **kwargs)
  self._value_batch_size = value_batch_size
  self._value_train_steps_per_epoch = value_train_steps_per_epoch
  self._value_evals_per_epoch = value_evals_per_epoch
  self._value_eval_steps = value_eval_steps
  self._only_eval = only_eval
  self._max_slice_length = max_slice_length
  self._policy_dist = distributions.create_distribution(task.action_space)
  self._n_replay_epochs = n_replay_epochs

  # The argument is a schedule factory; instantiate it once.
  self._exploration_rate = exploration_rate()
  # Predicate: is the target network due for synchronization at `step`?
  self._sync_at = (lambda step: step % sync_freq == 0)

  if scale_value_targets:
    self._value_network_scale = 1 / (1 - self._task.gamma)
  else:
    self._value_network_scale = 1

  value_model = functools.partial(models.Quality, body=value_body,
                                  n_actions=self.task.action_space.n)

  # Eval-mode copy of the value model, initialized from the value model
  # signature, with a forward pass jitted over all devices (no averaging).
  self._value_eval_model = value_model(mode='eval')
  self._value_eval_model.init(self._value_model_signature)
  self._value_eval_jit = tl.jit_forward(self._value_eval_model.pure_fn,
                                        fastmath.device_count(),
                                        do_mean=False)

  # Inputs to the value model are produced by self.value_batches_stream.
  self._inputs = data.inputs.Inputs(
      train_stream=lambda _: self.value_batches_stream())

  # This is the value Trainer that will be used to train the value model.
  # * inputs to the trainer come from self.value_batches_stream
  # * outputs, targets and weights are passed to self.value_loss
  self._value_trainer = supervised.Trainer(
      model=value_model,
      optimizer=value_optimizer,
      lr_schedule=value_lr_schedule(),
      loss_fn=self.value_loss,
      inputs=self._inputs,
      output_dir=output_dir,
      metrics={'value_loss': self.value_loss,
               'value_mean': self.value_mean,
               'returns_mean': self.returns_mean})

  # Draw one value batch only to obtain the input signature used to
  # initialize the collect-mode model below.
  value_batch = next(self.value_batches_stream())
  self._eval_model = tl.Accelerate(value_model(mode='collect'), n_devices=1)
  self._eval_model.init(shapes.signature(value_batch))

  # NOTE(review): if the task was created without initial trajectories, epoch
  # 0 is removed and trajectories are collected with the freshly initialized
  # model -- presumably epoch 0 is only a placeholder then; confirm in RLTask.
  if self._task._initial_trajectories == 0:
    self._task.remove_epoch(0)
    self._collect_trajectories()
def _get_conditionally_synced_rng(self):
  """Returns `self.rng`, made identical across devices when syncing is on.

  If `self._sync` is truthy and more than one device is available, the rng
  is reduced with `fastmath.psum` over the 'batch' axis so that all devices
  end up holding the same value. Otherwise the rng is returned untouched.
  """
  needs_cross_device_sync = self._sync and fastmath.device_count() > 1
  if not needs_cross_device_sync:
    return self.rng
  return fastmath.psum(self.rng, 'batch')
def __init__(self, task,
             value_model=None,
             value_optimizer=None,
             value_lr_schedule=lr.multifactor,
             value_batch_size=64,
             value_train_steps_per_epoch=500,
             value_evals_per_epoch=1,
             value_eval_steps=1,
             n_shared_layers=0,
             added_policy_slice_length=0,
             n_replay_epochs=1,
             scale_value_targets=False,
             q_value=False,
             q_value_aggregate='logsumexp',
             q_value_temperature=1.0,
             q_value_n_samples=1,
             q_value_normalization=False,
             **kwargs):  # Arguments of PolicyAgent come here.
  """Configures the actor-critic trainer.

  Args:
    task: `RLTask` instance to use.
    value_model: Model to use for the value function; called as
      `value_model(mode='eval')` (and, when `q_value` is set, additionally
      bound with `inject_actions`, `is_discrete` and `vocab_size`). It is
      effectively required: the default `None` is not callable and fails when
      the eval model is built.
    value_optimizer: Optimizer to train the value model.
    value_lr_schedule: lr schedule factory for value model training; called
      once (`value_lr_schedule()`).
    value_batch_size: Batch size for value model training.
    value_train_steps_per_epoch: Number of steps used to train the value
      model in each epoch.
    value_evals_per_epoch: Number of value trainer evaluations per RL epoch.
      Every evaluation, we also synchronize the weights of the target network.
    value_eval_steps: Number of value trainer steps per evaluation; only
      affects metric reporting.
    n_shared_layers: Number of layers to share between value and policy
      models.
    added_policy_slice_length: How much longer should slices of trajectories
      be for policy than for value training; this is useful for TD
      calculations and only affects the length of elements produced for
      policy batches; value batches have maximum length set by
      `max_slice_length` in `**kwargs`.
    n_replay_epochs: Number of last epochs to take into the replay buffer;
      only makes sense for off-policy algorithms.
    scale_value_targets: If `True`, scale value function targets by
      `1 / (1 - gamma)`.
    q_value: If `True`, use Q-values as baselines.
    q_value_aggregate: How to aggregate Q-values. Options: 'mean', 'max',
      'softmax', 'logsumexp'.
    q_value_temperature: Temperature parameter for the 'softmax' and
      'logsumexp' aggregation methods.
    q_value_n_samples: Number of samples to average over when calculating
      baselines based on Q-values. For discrete actions with `q_value` set,
      choosing it equal to the number of actions enables sampling all
      discrete actions.
    q_value_normalization: How to normalize Q-values before aggregation.
      Allowed values: 'std', 'abs', `None`. If `None`, don't normalize.
      NOTE(review): the default is `False`, not `None`; it is stored as-is
      here and presumably treated as "don't normalize" downstream -- confirm.
    **kwargs: Arguments for `PolicyAgent` superclass.
  """
  self._n_shared_layers = n_shared_layers
  self._value_batch_size = value_batch_size
  self._value_train_steps_per_epoch = value_train_steps_per_epoch
  self._value_evals_per_epoch = value_evals_per_epoch
  self._value_eval_steps = value_eval_steps

  # `_task` and `_max_slice_length` will be initialized in super.__init__
  # anyway, but are needed to construct value batches, which are needed
  # before PolicyAgent init since policy input creation calls the value
  # model -- hence this code.
  self._task = task
  self._max_slice_length = kwargs.get('max_slice_length', 1)
  self._added_policy_slice_length = added_policy_slice_length
  self._n_replay_epochs = n_replay_epochs
  task.set_n_replay_epochs(n_replay_epochs)

  if scale_value_targets:
    self._value_network_scale = 1 / (1 - self._task.gamma)
  else:
    self._value_network_scale = 1

  self._q_value = q_value
  self._q_value_aggregate = q_value_aggregate
  self._q_value_temperature = q_value_temperature
  self._q_value_n_samples = q_value_n_samples
  self._q_value_normalization = q_value_normalization

  is_discrete = isinstance(self._task.action_space, gym.spaces.Discrete)
  self._is_discrete = is_discrete
  # Stays None unless Q-values are used with a discrete action space.
  self._vocab_size = None
  self._sample_all_discrete_actions = False
  if q_value and is_discrete:
    self._vocab_size = self.task.action_space.n
    # TODO(lukaszkaiser): the code below is specific to AWR, move it.
    # If n_samples = n_actions, we'll take them all in actor and reweight.
    if self._q_value_n_samples == self._vocab_size:
      # TODO(lukaszkaiser): set this explicitly once it's in AWR Trainer.
      self._sample_all_discrete_actions = True

  if q_value:
    # Q-value models additionally take the actions as input.
    value_model = functools.partial(value_model,
                                    inject_actions=True,
                                    is_discrete=is_discrete,
                                    vocab_size=self._vocab_size)
  # Eval-mode value model with a forward pass jitted over all devices
  # (no averaging); built before super().__init__ because policy input
  # creation needs it (see comment above).
  self._value_eval_model = value_model(mode='eval')
  self._value_eval_model.init(self._value_model_signature)
  self._value_eval_jit = tl.jit_forward(
      self._value_eval_model.pure_fn, fastmath.device_count(), do_mean=False)

  # Initialize policy training.
  super().__init__(task, **kwargs)

  # Initialize training of the value function.
  value_output_dir = kwargs.get('output_dir', None)
  if value_output_dir is not None:
    value_output_dir = os.path.join(value_output_dir, 'value')
    # If needed, create value_output_dir and missing parent directories.
    if not tf.io.gfile.isdir(value_output_dir):
      tf.io.gfile.makedirs(value_output_dir)
  self._value_inputs = data.inputs.Inputs(
      train_stream=lambda _: self.value_batches_stream())
  self._value_trainer = supervised.Trainer(
      model=value_model,
      optimizer=value_optimizer,
      lr_schedule=value_lr_schedule(),
      loss_fn=tl.L2Loss(),
      inputs=self._value_inputs,
      output_dir=value_output_dir,
      metrics={'value_loss': tl.L2Loss(),
               'value_mean': self.value_mean})
def __init__(self, blocks, loss_layer, optimizer_fn, n_devices=None,
             memoize_jit=True):
  """Creates a ReversibleSerialTrainer and the needed optimizers.

  This trainer performs updates equivalent to using the default Trainer on::

    tl.Serial(blocks + [loss_layer]).

  It is more memory-efficient though since weights are stored on CPU and only
  sent to accelerator layer-by-layer. Blocks are pairs consisting of a list
  of standard (arbitrary) layers and a list of reversible layers which help
  save memory thanks to being reversible.

  Args:
    blocks: A list of pairs of lists of standard and reversible layers.
    loss_layer: The final layer of the model; it can have trainable weights
      but should end with a loss: it is required to produce a scalar output.
    optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`.
      It is called once per layer (including the loss layer), so every layer
      gets its own optimizer instance.
    n_devices: An optional integer, number of accelerator devices to use;
      by default, all available accelerators will be used.
    memoize_jit: Whether to memoize JITed functions; this significantly speeds
      up XLA compilation of larger models, but it uses `repr(layer)` as keys
      to memoize so it could fail if two layers with different functionality
      had the same string representation. We have not encountered such case
      yet so this is turned on by default, but consider turning it off or
      reviewing your model if you use custom layers and encounter a problem.
  """
  # Each block's standard layers are combined into one tl.Serial; the
  # reversible layers stay a list.
  self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks]
  self._loss_layer = loss_layer
  self._optimizer_fn = optimizer_fn
  self._n_devices = n_devices or fastmath.device_count()
  # One Serial-wrapped standard layer per block plus its reversible layers,
  # plus one for the loss layer.
  self._n_layers = 1 + sum([len(revs) + 1 for (_, revs) in self._blocks])
  self._n_steps_per_log = 100  # Log layers and stats every 100 steps.
  # Memoization storage for jitted functions (None when disabled);
  # presumably consulted by self._pjit -- confirm there.
  self._jit_memory = {} if memoize_jit else None
  self._jit_per_device_rngs = fastmath.jit(
      self._per_device_rngs, backend='cpu')

  # Create accelerated versions of layers as pmaped/jited pure_fn.
  # The string argument is presumably the memoization key for self._pjit.
  self._accelerated_layer_fns = fastmath.nested_map(
      lambda layer: self._pjit(layer.pure_fn, f'fwd {repr(layer)}'),
      self._blocks)

  # Create per-layer optimizers and replicate opt_params.
  def _make_optimizer(layer):
    opt = optimizer_fn()
    opt.tree_init(layer.weights)
    return opt

  self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks)
  self._replicated_opt_params = fastmath.nested_map(
      lambda opt: self._replicate_cpu(opt.opt_params), self._optimizers)

  self._loss_opt = _make_optimizer(loss_layer)
  self._replicated_loss_opt_params = self._replicate_cpu(
      self._loss_opt.opt_params)

  # Forward + backward + optimizer-update functions for all layers.
  # We call them in short FBO for "Forward + Backward + Optimizer update".
  # Reversible layers define a reverse_and_fbo function that also reverses.
  self._fbos = []
  for i, (std_layer, rev_layers) in enumerate(self._blocks):
    (std_opt, rev_opts) = self._optimizers[i]
    std_fbo = _fbo_with_layer_and_opt(std_layer, std_opt, self._n_devices)
    rev_and_fbos = []
    for layer, opt in zip(rev_layers, rev_opts):
      rev_and_fbo = _reverse_and_fbo_with_layer_and_opt(
          layer, opt, self._n_devices)
      # The donated args are (outputs, weights, grads) and we can donate
      # them because weights and grads are immediately replaced and in
      # case of reversible layers, the outputs are never used again.
      rev_and_fbos.append(self._pjit(
          rev_and_fbo, f'rev+bwd {repr(layer)}', donate_argnums=(0, 1, 2)))
    # In standard layers, the inputs cannot be donated as they may be used
    # as outputs for the reversible block below, but weights and grads can.
    jit_std_fbo = self._pjit(
        std_fbo, f'bwd {repr(std_layer)}', donate_argnums=(1, 2))
    self._fbos.append((jit_std_fbo, rev_and_fbos))

  loss_fbo = _fbo_with_layer_and_opt(
      self._loss_layer, self._loss_opt, self._n_devices, 'loss')
  # NOTE(review): unlike the per-layer functions, no key string is passed
  # here, so the loss FBO is presumably not memoized -- confirm in _pjit.
  self._loss_fbo = self._pjit(loss_fbo, donate_argnums=(1, 2))
def __init__(self, blocks, loss_layer, optimizer_fn, n_devices=None):
  """Creates a ReversibleSerialTrainer and the needed optimizers.

  This trainer performs updates equivalent to using the default Trainer on::

    tl.Serial(blocks + [loss_layer]).

  It is more memory-efficient though since weights are stored on CPU and only
  sent to accelerator layer-by-layer. Blocks are pairs consisting of a list
  of standard (arbitrary) layers and a list of reversible layers which help
  save memory thanks to being reversible.

  Args:
    blocks: A list of pairs of lists of standard and reversible layers.
    loss_layer: The final layer of the model; it can have trainable weights
      but should end with a loss: it is required to produce a scalar output.
    optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`.
      It is called once per layer (including the loss layer), so every layer
      gets its own optimizer instance.
    n_devices: An optional integer, number of accelerator devices to use;
      by default, all available accelerators will be used.
  """
  # Each block's standard layers are combined into one tl.Serial; the
  # reversible layers stay a list.
  self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks]
  self._loss_layer = loss_layer
  self._optimizer_fn = optimizer_fn
  self._n_devices = n_devices or fastmath.device_count()
  # One Serial-wrapped standard layer per block plus its reversible layers,
  # plus one for the loss layer.
  self._n_layers = 1 + sum([len(revs) + 1 for (_, revs) in self._blocks])
  self._n_steps_per_log = 100  # Log layers and stats every 100 steps.

  # Create accelerated versions of layers as pmaped/jited pure_fn.
  self._accelerated_layer_fns = fastmath.nested_map(
      lambda layer: self._pjit(layer.pure_fn), self._blocks)

  # Create per-layer optimizers and replicate opt_params.
  def _make_optimizer(layer):
    opt = optimizer_fn()
    opt.tree_init(layer.weights)
    return opt

  self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks)
  self._replicated_opt_params = fastmath.nested_map(
      lambda opt: self._replicate(opt.opt_params), self._optimizers)

  self._loss_opt = _make_optimizer(loss_layer)
  self._replicated_loss_opt_params = self._replicate(
      self._loss_opt.opt_params)

  # Forward + backward + optimizer-update functions for all layers.
  # We call them in short FBO for "Forward + Backward + Optimizer update".
  # Reversible layers define a reverse_and_fbo function that also reverses.
  self._fbos = []
  for i, (std_layer, rev_layers) in enumerate(self._blocks):
    (std_opt, rev_opts) = self._optimizers[i]
    std_fbo = _fbo_with_layer_and_opt(std_layer, std_opt, self._n_devices)
    rev_and_fbos = []
    for layer, opt in zip(rev_layers, rev_opts):
      # Argument 1 is donated to XLA (presumably the weights, which are
      # replaced by the update) -- confirm against the FBO signature.
      rev_and_fbos.append(self._pjit(_reverse_and_fbo_with_layer_and_opt(
          layer, opt, self._n_devices), donate_argnums=(1,)))
    self._fbos.append((self._pjit(std_fbo, donate_argnums=(1,)),
                       rev_and_fbos))

  loss_fbo = _fbo_with_layer_and_opt(
      self._loss_layer, self._loss_opt, self._n_devices, 'loss')
  self._loss_fbo = self._pjit(loss_fbo, donate_argnums=(1,))
def __init__(self, layer, n_devices=None):
  """Wraps `layer` so that its forward pass is JIT-compiled across devices.

  Args:
    layer: the single sublayer to accelerate; its `n_in`/`n_out` become the
      input/output counts of this wrapper.
    n_devices: number of devices to use; when falsy (e.g. `None` or 0), the
      count reported by `fastmath.device_count()` is used.
  """
  super().__init__(n_in=layer.n_in, n_out=layer.n_out)
  self._sublayers = [layer]
  device_total = n_devices if n_devices else fastmath.device_count()
  self._n_devices = device_total
  # Per-device outputs are kept (do_mean=False) rather than averaged.
  self._jit_pure_fn = jit_forward(layer.pure_fn, device_total, do_mean=False)
def forward(self, xs):
  """Runs the sublayers in order, skipping a random suffix in train mode.

  In 'train' mode a float `n_forward_layers` is sampled uniformly from
  `[low, high)`. Layer `i` runs only if `i < n_forward_layers`; a skipped
  layer passes its inputs through unchanged. This presumably requires skipped
  layers to map inputs to outputs of the same structure, since both
  `fastmath.cond` branches must agree. In other modes every layer runs.

  Args:
    xs: inputs to the first sublayer (the initial stack).

  Returns:
    The final stack after all (run or skipped) sublayers.

  Raises:
    ValueError: if the number of weights or state entries does not match the
      number of layers.
  """
  self._validate_forward_inputs(xs)
  (step, layers_state) = self.state
  # Get N+1 rngs, N for running layers and one extra.
  rngs = _split_rngs(self.rng, self._n_layers + 1)
  rng0, rngs = rngs[0], rngs[1:]
  if not self.sublayers:  # No-op: leave args unchanged.
    self.state = (step + 1, layers_state)
    return xs

  # Prepare the stack and do some safety checks as in the parent class.
  stack = xs
  new_state = []
  n_layers = self._n_layers
  weights = self.weights
  if n_layers != 1 and len(weights) != n_layers:
    raise ValueError(
        'number of weights ({}) not equal to number of layers '
        '({})'.format(len(weights), n_layers))
  if n_layers != 1 and len(layers_state) != n_layers:
    raise ValueError(
        'length of state ({}) not equal to number of layers '
        '({})'.format(len(layers_state), n_layers))

  # TODO(chowdhery): try different strategies, also try running not all
  # layers backwards by using fastmath.stop_gradient where needed.

  # Calculate how many layers to run forward.
  # skip_fraction == 0 is the limit high -> infinity below, i.e. never skip;
  # handle it here instead of dividing by zero.
  if self._mode == 'train' and self._skip_fraction > 0.0:
    # warmup goes from 1.0 at start to 0.0 at skipping_warmup_steps and after
    w_steps = float(self._skipping_warmup_steps)
    if w_steps > 0.0:
      f_step = jnp.array(step, dtype=jnp.float32)
      warmup = jnp.maximum(0.0, (w_steps - f_step) / w_steps)
    else:
      # No warmup configured. Dividing by zero here would give 0/0 = NaN at
      # step 0, and that NaN would make every layer be skipped.
      warmup = 0.0
    # low is the minimum number of layers to *not* skip, from n_layers to 0
    low = warmup * float(n_layers)
    # high should be so that (high - n_layers) / high = 1.0 - skip_fraction
    # because (high - n_layers) / high is the probability we're not skipping
    # (after warmup); so high - n_layers = high - high * skip_fraction
    high = float(n_layers) / self._skip_fraction
    # We want the same rng0 on all cores.
    if fastmath.device_count() > 1:
      rng0 = fastmath.psum(rng0, 'batch')
    n_forward_layers = random.uniform(rng0, (), jnp.float32, low, high)
  else:
    n_forward_layers = float(n_layers)

  # Run layers skipping after a certain number.
  cur_layer_idx = 0.0
  for layer, p, s, rng in zip(self.sublayers, weights, layers_state, rngs):
    inputs = _inputs_from_stack(layer, stack)

    def CondF(t):
      o, s = layer.pure_fn(t[0], t[1], t[2], t[3])  # pylint: disable=cell-var-from-loop
      return o, t[1], s, t[3]

    # Run the layer if its index is below the sampled count; otherwise the
    # identity branch returns (inputs, weights, state, rng) unchanged.
    outputs, _, s, _ = fastmath.cond(
        fastmath.lt(cur_layer_idx, n_forward_layers),
        CondF,
        lambda x: x,
        (inputs, p, s, rng))
    stack = _outputs_onto_stack(layer, outputs, stack)
    new_state.append(s)
    cur_layer_idx += 1.0
  self.state = (step + 1, new_state)
  return stack