def test_train_save_restore_sharded(self):
  """Saves a sharded checkpoint and then builds a second loop on the same dir.

  Only runs on systems with at least two local accelerators. The global
  `base.N_WEIGHTS_SHARDS` is set to the local device count for the duration of
  the test and reset to 1 in a `finally` clause, so a failure inside the test
  cannot leak the sharding setting into other tests.

  NOTE(review): despite the original description ("check for equivalence"),
  no equivalence assertion is made; the test only checks that saving and
  re-creating/running a loop over the saved checkpoint succeeds.
  """
  if fastmath.local_device_count() < 2:
    return  # multi-accelerator only
  base.N_WEIGHTS_SHARDS = fastmath.local_device_count()
  try:
    train_data = data.Serial(lambda _: _very_simple_data(2, 2),
                             data.CountAndSkip('simple_data'))
    task = training.TrainTask(
        train_data(), tl.L2Loss(), optimizers.Adam(.0001))
    eval_task = training.EvalTask(
        _very_simple_data(2, 2),  # deliberately re-using training data
        [tl.L2Loss()],
        metric_names=['SGD.L2Loss'])
    tmp_dir = self.create_tempdir().full_path

    def _make_model_and_session():
      m = tl.Serial(tl.Dense(2))
      ts = training.Loop(m, [task], eval_tasks=[eval_task],
                         eval_at=lambda step_n: step_n % 2 == 0,
                         output_dir=tmp_dir)
      return m, ts

    _, training_session = _make_model_and_session()
    self.assertEqual(0, training_session.step)
    training_session.run(n_steps=1)
    training_session.save_checkpoint('model')
    # The second loop shares output_dir with the first; presumably it picks
    # up the checkpoint saved above before running -- TODO confirm.
    _, training_session2 = _make_model_and_session()
    training_session2.run(n_steps=1)
  finally:
    # Reset shards back to default even if the test body raised.
    base.N_WEIGHTS_SHARDS = 1
def test_autoregressive_sample_transformerlm_tfnp(self):
  """Samples from a predict-mode TransformerLM under the TFNP backend.

  Covers three cases: a single unconditional sample with `eos_id=-1`, a batch
  of two samples (batch split by the local device count at init time), and a
  single sample conditioned on a fixed prefix.
  """
  with fastmath.use_backend(fastmath.Backend.TFNP):
    model = models.TransformerLM(
        10, d_model=32, d_ff=64, n_layers=1, n_heads=2, mode='predict')

    def _init_for_batch(batch):
      # Predict-mode models are (re)initialized for the per-call batch size.
      model.init(shapes.ShapeDtype((batch, 1), dtype=np.int32))

    # Single unconditional sample; with eos_id=-1 the full length is expected.
    _init_for_batch(1)
    sample = decoding.autoregressive_sample(
        model, batch_size=1, eos_id=-1, max_length=10)
    self.assertEqual(sample.shape[0], 1)
    self.assertEqual(sample.shape[1], 10)

    # Two samples; no eos_id is given, so only an upper length bound is checked.
    _init_for_batch(2 // fastmath.local_device_count())
    sample = decoding.autoregressive_sample(model, batch_size=2, max_length=10)
    self.assertEqual(sample.shape[0], 2)
    self.assertLess(sample.shape[1], 11)

    # Single sample continuing a fixed prefix.
    _init_for_batch(1)
    sample = decoding.autoregressive_sample(
        model, np.array([[1, 2, 3]]), eos_id=-1, max_length=10, batch_size=1)
    self.assertEqual(sample.shape[0], 1)
    self.assertEqual(sample.shape[1], 10)
def __init__(self, layer, n_devices=None):
  """Wraps `layer` so its forward pass goes through `jit_forward`.

  Args:
    layer: The layer to wrap; its `n_in`/`n_out` become this layer's and it
      is stored as the only sublayer.
    n_devices: Number of devices handed to `jit_forward`; when None (or any
      falsy value) `fastmath.local_device_count()` is used instead.
  """
  super().__init__(n_in=layer.n_in, n_out=layer.n_out)
  self._sublayers = [layer]
  device_count = n_devices if n_devices else fastmath.local_device_count()
  self._n_devices = device_count
  # do_mean=False: outputs are not averaged across devices.
  self._jit_pure_fn = jit_forward(layer.pure_fn, device_count, do_mean=False)
def _run_value_model(self, obs, use_eval_model=True):
  """Runs the value model on a batch of observations.

  Args:
    obs: Batch of observations; axis 0 is the batch axis.
    use_eval_model: If True, use the weights/state/rng of the eval model.
      Otherwise use those of the value trainer, falling back to the eval
      model when the trainer's attributes are not available yet.

  Returns:
    Values for the `obs.shape[0]` input observations, multiplied by
    `self._value_network_scale`.
  """
  n_devices = fastmath.local_device_count()
  if use_eval_model:
    weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
    state = tl.for_n_devices(self._value_eval_model.state, n_devices)
    rng = self._value_eval_model.rng
  else:
    # TODO(henrykm): this strangely looking solution address the problem that
    # value_batches_stream calls _run_value_model _once_ before
    # the trainer is initialized.
    try:
      weights = tl.for_n_devices(self._value_trainer.model_weights, n_devices)
      state = tl.for_n_devices(self._value_trainer.model_state, n_devices)
      rng = self._value_trainer._rng  # pylint: disable=protected-access
    except AttributeError:
      weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
      state = tl.for_n_devices(self._value_eval_model.state, n_devices)
      rng = self._value_eval_model.rng

  # On multiple devices the batch size must be divisible by the number of
  # devices (otherwise: "Number of devices (8) does not evenly divide batch
  # size (1)"). Pad by APPENDING copies of the last observation up to the next
  # multiple of n_devices. Repeating each row in place (jnp.repeat on obs)
  # would interleave duplicates, so values[:obs_batch] would not line up
  # with the original observations.
  obs_batch = obs.shape[0]
  remainder = obs_batch % n_devices
  if remainder:
    padding = jnp.repeat(obs[-1:], n_devices - remainder, axis=0)
    obs = jnp.concatenate([obs, padding], axis=0)
  values, _ = self._value_eval_jit(obs, weights, state, rng)
  values = values[:obs_batch]  # Drop outputs for the padding rows.
  values *= self._value_network_scale
  return values
def test_run_sharded_terraformer(self):
  """Runs Terraformer with sharded weights (only on 2+-device systems).

  Sets the global `base.N_WEIGHTS_SHARDS` to the local device count and
  always resets it to 1 afterwards (in `finally`), so a failure here cannot
  leak the sharding setting into other tests.
  """
  if fastmath.local_device_count() == 1:
    return
  base.N_WEIGHTS_SHARDS = fastmath.local_device_count()
  try:
    inputs_batch = np.arange(8).reshape((2, 4)) + 1
    targets_batch = 2 * inputs_batch
    labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
    int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32)
    input_sig = (int_sig, int_sig, int_sig)
    # Explicit rngs are used for init and for the training step below.
    # NOTE(review): the original comment said dropout layers were added to
    # test rng propagation, but dropout=0.0 is passed -- confirm whether a
    # non-zero rate was intended.
    model = terraformer.ConfigurableTerraformer(
        20, d_model=8, d_ff=32, n_heads=1, dropout=0.0,
        n_encoder_layers=2, n_decoder_layers=2,
        ff_sparsity=(4, 8, 0.0, 1.0),
        encoder_attention_type=tl.Attention,
        encoder_decoder_attention_type=tl.CausalAttention,
        pos_type=None, reversible_encoder=True)
    loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss())
    model_with_loss = tl.Serial(model, loss)
    rng_init = fastmath.random.get_prng(12)
    model_with_loss.init(input_sig, rng=rng_init)

    # Make a step with the trainer. The optimizer is initialized on the first
    # (index-0) slice of each sharded weight.
    optimizer = optimizers.Adafactor(0.01)
    split_w = fastmath.nested_map(
        lambda x: x[0],
        tl.shard(model_with_loss.weights, base.N_WEIGHTS_SHARDS))
    optimizer.tree_init(split_w)
    trainer = optimizers.Trainer(model_with_loss, optimizer)
    rng_step1 = fastmath.random.get_prng(7)
    trainer.one_step(labeled_batch, rng_step1)
  finally:
    # Reset shards back to default.
    base.N_WEIGHTS_SHARDS = 1
def shard(tensors, n_shards=None):
  """Splits each tensor into `n_shards` pieces laid out per local device.

  For every leaf `x` of the nested structure `tensors`, an axis is chosen by
  `_axis_to_shard_heuristic`, `x` is split into `n_shards` equal pieces along
  it, and the piece `index % n_shards` is picked for each entry of
  `_axis_index(...)` over the local devices. The picked pieces are stacked
  along a new leading axis 0.

  Args:
    tensors: Nested structure of arrays to shard.
    n_shards: Number of pieces; defaults to the module-level
      `N_WEIGHTS_SHARDS` when None.

  Returns:
    The same nested structure with each leaf replaced by its stacked pieces.

  Raises:
    ValueError: If the chosen axis is not divisible by `n_shards`.
  """
  if n_shards is None:
    n_shards = N_WEIGHTS_SHARDS
  device_indices = _axis_index(np.zeros(fastmath.local_device_count()))

  def _split_and_stack(tensor):
    split_axis = _axis_to_shard_heuristic(tensor.shape)
    if int(tensor.shape[split_axis]) % n_shards:
      raise ValueError(
          f'Cannot split x with shape {tensor.shape} into {n_shards}.')
    pieces = jnp.split(tensor, n_shards, axis=split_axis)
    return np.stack([pieces[idx % n_shards] for idx in device_indices], axis=0)

  return fastmath.nested_map(_split_and_stack, tensors)
def __init__(self, model_with_loss, optimizer, n_devices=None, adasum=False):
  """Sets up accelerated forward/backward and optimizer-update functions.

  Args:
    model_with_loss: Layer computing the loss; its `pure_fn` is
      differentiated with respect to its argument 1 (the weights).
    optimizer: Optimizer whose `slots` and `opt_params` are prepared for
      `n_devices` devices via `tl.for_n_devices` and then `tl.on_cpu`.
    n_devices: Number of devices to use; when None (or 0) it defaults to
      `fastmath.local_device_count()`.
    adasum: Forwarded to `_accelerate_update_fn` (adaptive summation of
      multi-device gradients).
  """
  self._model_with_loss = model_with_loss
  self._optimizer = optimizer
  self._n_devices = n_devices or fastmath.local_device_count()
  self._adasum = adasum

  # optimizer slots and opt_params may need to be replicated
  self._slots, self._opt_params = tl.on_cpu(
      tl.for_n_devices(
          (self._optimizer.slots, self._optimizer.opt_params),
          self._n_devices))

  # accelerated version of model+loss to replicate weights and state.
  # Use the already-resolved device count so every component of the trainer
  # agrees on the number of devices.
  self._accelerated_model_with_loss = tl.Accelerate(
      model_with_loss, n_devices=self._n_devices)

  # Signature:
  # (batch, weights, state, rng) -> ((loss, state), gradients)
  self._forward_and_backward_fn = (
      fastmath.value_and_grad(
          model_with_loss.pure_fn,
          argnums=1,  # arg1 of pure_fn: weights
          has_aux=True))  # return (loss, state), gradients

  # Signature:
  # (weights, slots), step, opt_params, batch, state, rng ->
  # (weights, slots), state, stats
  self._accelerated_update_fn = (
      _accelerate_update_fn(
          self._forward_and_backward_fn,
          self._optimizer,
          n_devices=self._n_devices,
          accelerate=True,
          adasum=self._adasum))
def __init__(self, blocks, loss_layer, optimizer_fn, n_devices=None,
             memoize_jit=True, free_accelerators_on_step=False, adasum=False):
  """Creates a ReversibleSerialTrainer and the needed optimizers.

  This trainer performs updates equivalent to using the default Trainer on::

    tl.Serial(blocks + [loss_layer]).

  It is more memory-efficient though since weights are stored on CPU and only
  sent to accelerator layer-by-layer. Blocks are pairs consisting of a list
  of standard (arbitrary) layers and a list of reversible layers which help
  save memory thanks to being reversible.

  Args:
    blocks: A list of pairs of lists of standard and reversible layers.
    loss_layer: The final layer of the model; it can have trainable weights
      but should end with a loss: it is required to produce a scalar output.
    optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`.
      It is called once per entry of the blocks structure (via `nested_map`)
      and once more for `loss_layer`, so each gets its own optimizer.
    n_devices: An optional integer, number of accelerator devices to use;
      by default, all available accelerators will be used.
    memoize_jit: Whether to memoize JITed functions; this significantly speeds
      up XLA compilation of larger models, but it uses `repr(layer)` as keys
      to memoize so it could fail if two layers with different functionality
      had the same string representation. We have not encountered such case
      yet so this is turned on by default, but consider turning it off or
      reviewing your model if you use custom layers and encounter a problem.
    free_accelerators_on_step: If true, frees memory on accelerators when
      starting a step. All layers and arguments must be on host for that,
      otherwise it can lead to failures. Can prevent memory fragmentation.
    adasum: if True, use adaptive summation to gather multi-device gradients.
  """
  # Standard layers of each block are wrapped into a single tl.Serial.
  self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks]
  self._loss_layer = loss_layer
  self._optimizer_fn = optimizer_fn
  self._n_devices = n_devices or fastmath.local_device_count()
  self._adasum = adasum
  # One per standard Serial and per reversible layer in every block, plus 1
  # (presumably for the loss layer -- TODO confirm).
  self._n_layers = 1 + sum([len(revs) + 1 for (_, revs) in self._blocks])
  self._n_steps_per_log = 100  # Log layers and stats every 100 steps.
  self._n_async_layers = 1  # How many layers to run asynchronously.
  # A dict enables memoization of jitted functions; None disables it.
  self._jit_memory = {} if memoize_jit else None
  self._do_free = free_accelerators_on_step
  self._jit_per_device_rngs = fastmath.jit(
      self._per_device_rngs, backend='cpu')

  # Create accelerated versions of layers as pmaped/jited pure_fn.
  self._accelerated_layer_fns = fastmath.nested_map(
      lambda layer: self._pjit(layer.pure_fn, f'fwd {repr(layer)}'),
      self._blocks)

  # Create per-layer optimizers and replicate opt_params.
  def _make_optimizer(layer):
    opt = optimizer_fn()
    opt.tree_init(layer.weights)
    # Optimizer slots are kept on CPU, like the weights.
    opt.slots = tl.on_cpu(opt.slots)
    return opt

  # Same nested (std, [revs]) structure as self._blocks.
  self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks)
  self._replicated_opt_params = fastmath.nested_map(
      lambda opt: self._replicate_cpu(opt.opt_params), self._optimizers)

  self._loss_opt = _make_optimizer(loss_layer)
  self._replicated_loss_opt_params = self._replicate_cpu(
      self._loss_opt.opt_params)

  # Forward + backward + optimizer-update functions for all layers.
  # We call them in short FBO for "Forward + Backward + Optimizer update".
  # Reversible layers define a reverse_and_fbo function that also reverses.
  self._fbos = []
  for i, (std_layer, rev_layers) in enumerate(self._blocks):
    (std_opt, rev_opts) = self._optimizers[i]
    std_fbo = _fbo_with_layer_and_opt(
        std_layer, std_opt, self._n_devices, adasum=self._adasum)
    rev_and_fbos = []
    for layer, opt in zip(rev_layers, rev_opts):
      rev_and_fbo = _reverse_and_fbo_with_layer_and_opt(
          layer, opt, self._n_devices, self._adasum)
      # The donated args are (outputs, weights, grads) and we can donate
      # them because weights and grads are immediately replaced and in
      # case of reversible layers, the outputs are never used again.
      rev_and_fbos.append(self._pjit(
          rev_and_fbo, f'rev+bwd {repr(layer)}', donate_argnums=(0, 1, 2)))
    # In standard layers, the inputs cannot be donated as they may be used
    # as outputs for the reversible block below, but weights and grads can.
    jit_std_fbo = self._pjit(
        std_fbo, f'bwd {repr(std_layer)}', donate_argnums=(1, 2))
    self._fbos.append((jit_std_fbo, rev_and_fbos))

  # The loss layer gets its own FBO; args 1 and 2 are donated, presumably
  # weights and grads as for standard layers -- TODO confirm. No memo key is
  # passed to _pjit here.
  loss_fbo = _fbo_with_layer_and_opt(
      self._loss_layer, self._loss_opt, self._n_devices, 'loss', self._adasum)
  self._loss_fbo = self._pjit(loss_fbo, donate_argnums=(1, 2))
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.WeightedCategoryCrossEntropy(),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule_fn=lr.multifactor,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          permanent_checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          permanent_checkpoint_frequency=None,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          use_loop=True,
          loss_chunk_size=0,
          use_memory_efficient_trainer=False,
          adasum=False,
          init_checkpoint=None,
          callbacks=None,
          additional_train_tasks=None,
          additional_eval_tasks=None,
          additional_eval_streams=None):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train, as a callable. With `use_loop` it is called as
      `model(mode='train')` and `model(mode='eval')` to build the training
      and evaluation models; otherwise it is passed as-is to `trainer_class`.
    loss_fn: The loss. With `use_loop` it is used as the training task's
      `loss_layer`; otherwise it is passed as-is to `trainer_class`.
      NOTE: the default instance is created once, when this function is
      defined, and is shared by every call that relies on the default.
    inputs: callable returning trax.inputs.Inputs (with `use_loop`, an Inputs
      instance is accepted as well).
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule_fn: A learning rate schedule function, that when called returns
      a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    permanent_checkpoints_at: list of integers. Save a permanent checkpoint
      for each training step in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    permanent_checkpoint_frequency: int, how often to save permanent
      checkpoints (every permanent_checkpoint_frequency steps).
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save the checkpoint highest at this metric.
    checkpoint_lowest: save the checkpoint lowest at this metric.
    use_loop: whether to use training.Loop instead of Trainer.
    loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory.
    use_memory_efficient_trainer: whether to use memory-efficient trainer.
    adasum: if True, use adaptive summation for multi-device gradients.
    init_checkpoint: a checkpoint for fine tuning.
    callbacks: a list of callbacks to call during training.
    additional_train_tasks: additional tasks which should be performed during
      training.
    additional_eval_tasks: additional tasks which should be performed during
      evaluation.
    additional_eval_streams: List[NamedStream], additional data streams that
      should be used during evaluation. Can be provided independently of
      additional_eval_tasks.

  Returns:
    trax.TrainerState or training.Loop if use_loop is True

  Raises:
    ValueError: if both `permanent_checkpoint_frequency` and
      `permanent_checkpoints_at` are set.
  """
  if (permanent_checkpoint_frequency is not None and
      permanent_checkpoints_at is not None):
    raise ValueError('Only one of ["permanent_checkpoint_frequency", '
                     '"permanent_checkpoints_at"] should be set.')
  if use_loop:
    n_devices = num_devices() or fastmath.local_device_count()

    # Prepare the training task.
    # Inputs is either an Inputs instance or a function that returns it.
    if callable(inputs):  # If we pass a function, e.g., through gin, call it.
      inputs = inputs()
    opt = optimizer if use_memory_efficient_trainer else optimizer()
    train_task = training.TrainTask(
        inputs.train_stream(n_devices),
        loss_layer=loss_fn,
        optimizer=opt,
        lr_schedule=lr_schedule_fn(),
        n_steps_per_checkpoint=eval_frequency,
        n_steps_per_permanent_checkpoint=permanent_checkpoint_frequency)

    if additional_train_tasks is None:
      additional_train_tasks = []

    # Prepare the evaluation.
    metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
    names, metric_layers = zip(*metrics_dict.items())
    eval_task = training.EvalTask(inputs.eval_stream(n_devices),
                                  metric_layers,
                                  metric_names=names,
                                  n_eval_batches=eval_steps)

    if additional_eval_tasks is None:
      additional_eval_tasks = []

    additional_eval_tasks_from_streams = []
    if additional_eval_streams is not None:
      for stream in additional_eval_streams:
        additional_eval_tasks_from_streams.append(
            training.EvalTask(stream.stream,
                              metric_layers,
                              metric_names=names,
                              n_eval_batches=eval_steps,
                              export_prefix=stream.name))

    # Prepare the training loop.
    checkpoint_at = None
    if checkpoints_at is not None:
      checkpoint_at = lambda step: step in checkpoints_at

    permanent_checkpoint_at = None
    if permanent_checkpoints_at is not None:
      permanent_checkpoint_at = (lambda step: step in permanent_checkpoints_at)

    # Setup the model.
    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')
    if init_checkpoint:
      model_train.init_from_file(init_checkpoint, weights_only=True)
      model_predict_eval.init_from_file(init_checkpoint, weights_only=True)
    loop = training.Loop(
        model_train,
        [train_task] + additional_train_tasks,
        eval_model=model_predict_eval,
        eval_tasks=([eval_task] + additional_eval_tasks +
                    additional_eval_tasks_from_streams),
        output_dir=output_dir,
        checkpoint_at=checkpoint_at,
        permanent_checkpoint_at=permanent_checkpoint_at,
        n_devices=n_devices,
        loss_chunk_size=loss_chunk_size,
        use_memory_efficient_trainer=use_memory_efficient_trainer,
        adasum=adasum,
        random_seed=random_seed,
        callbacks=callbacks,
    )

    steps_to_go = steps - loop.step
    if steps_to_go <= 0:
      log('Stop training, already reached the total training steps %d' % steps)
      return loop

    # Train and return the loop.
    loop.run(steps_to_go)
    return loop

  # Non-loop path: train with `trainer_class`.
  n_devices = num_devices()
  trainer = trainer_class(model, loss_fn, optimizer, lr_schedule_fn(), inputs,
                          output_dir,
                          random_seed=random_seed,
                          n_devices=n_devices,
                          checkpoints_at=checkpoints_at,
                          metrics=metrics,
                          checkpoint_lowest=checkpoint_lowest,
                          checkpoint_highest=checkpoint_highest,
                          init_checkpoint=init_checkpoint)

  epoch_steps = [steps]  # Only training if eval_frequency is 0 or None
  # eval_steps may be None (documented as "eval disabled"); check truthiness
  # first so that `None > 0` cannot raise a TypeError.
  if eval_frequency and eval_steps and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  trainer.log_step('Starting training using %d devices' % trainer.n_devices)
  trainer.print_n_weights()

  try:
    for n_steps_in_epoch in epochs(steps, trainer.step, epoch_steps):
      trainer.train_epoch(n_steps_in_epoch, eval_steps)

      # Bookkeeping we do at the first step
      if trainer.step == 1:
        # Save computation graph (single-device only for now)
        if (save_graphs and fastmath.is_backend(fastmath.Backend.JAX)):
          trainer.save_computation_graphs()

        # Save Gin config
        trainer.save_gin()

    trainer.log_step('Training done')
  finally:
    # Exceptions propagate unchanged; the trainer is closed either way.
    trainer.close()
  return trainer.state
def __init__(
    self,
    task,
    value_body=None,
    value_optimizer=None,
    value_lr_schedule=lr.multifactor,
    value_batch_size=64,
    value_train_steps_per_epoch=500,
    value_evals_per_epoch=1,
    value_eval_steps=1,
    exploration_rate=functools.partial(
        lr.multifactor,
        factors='constant * decay_every',
        constant=1.,  # pylint: disable=redefined-outer-name
        decay_factor=0.99,
        steps_per_decay=1,
        minimum=0.1),
    n_eval_episodes=0,
    only_eval=False,
    n_replay_epochs=1,
    max_slice_length=1,
    sync_freq=1000,
    scale_value_targets=True,
    output_dir=None,
    **kwargs):
  """Configures the value trainer.

  Args:
    task: RLTask instance, which defines the environment to train on.
    value_body: Trax layer, representing the body of the value model; it is
      passed as `body` to `models.Quality`. Loss functions and eval functions
      (a.k.a. metrics) are considered to be outside the core model, taking
      core model output and data labels as their two inputs.
    value_optimizer: the optimizer to use to train the value model.
    value_lr_schedule: learning rate schedule to use to train the value
      model; it is called once here with no arguments.
    value_batch_size: batch size used to train the value model.
    value_train_steps_per_epoch: how long to train the value model in each
      RL epoch.
    value_evals_per_epoch: number of value trainer evaluations per RL epoch -
      only affects metric reporting.
    value_eval_steps: number of value trainer steps per evaluation - only
      affects metric reporting.
    exploration_rate: exploration rate schedule - used in the policy method.
      It is called once here with no arguments and the result is stored in
      `self._exploration_rate`.
    n_eval_episodes: number of episodes to play with policy at
      temperature 0 in each epoch -- used for evaluation only
    only_eval: If set to True, then trajectories are collected only for
      evaluation purposes, but they are not recorded.
    n_replay_epochs: Number of last epochs to take into the replay buffer;
      only makes sense for off-policy algorithms.
    max_slice_length: the maximum length of trajectory slices to use; it is
      the second dimension of the value network output:
      (batch, max_slice_length, number of actions)
      Higher max_slice_length implies that the network has to predict more
      values into the future.
    sync_freq: frequency when to synchronize the target network with the
      trained network; stored as the predicate
      `self._sync_at(step) == (step % sync_freq == 0)`. This is necessary for
      training the network on bootstrapped targets, e.g. using n-step returns.
    scale_value_targets: If `True`, scale value function targets by
      `1 / (1 - gamma)`. We are trying to fix the problem with very large
      returns in some games in a way which does not introduce an additional
      hyperparameter.
    output_dir: Path telling where to save outputs (evals and checkpoints).
    **kwargs: arguments for the superclass RLTrainer.
  """
  super(ValueAgent, self).__init__(task, n_eval_episodes=n_eval_episodes,
                                   output_dir=output_dir, **kwargs)
  self._value_batch_size = value_batch_size
  self._value_train_steps_per_epoch = value_train_steps_per_epoch
  self._value_evals_per_epoch = value_evals_per_epoch
  self._value_eval_steps = value_eval_steps
  self._only_eval = only_eval
  self._max_slice_length = max_slice_length
  self._policy_dist = distributions.create_distribution(
      task.action_space)
  self._n_replay_epochs = n_replay_epochs

  self._exploration_rate = exploration_rate()
  self._sync_at = (lambda step: step % sync_freq == 0)

  if scale_value_targets:
    self._value_network_scale = 1 / (1 - self._task.gamma)
  else:
    self._value_network_scale = 1

  # Factory for the value model; it is instantiated below in 'eval' and
  # 'collect' modes and handed to the value Trainer.
  value_model = functools.partial(models.Quality, body=value_body,
                                  n_actions=self.task.action_space.n)

  self._value_eval_model = value_model(mode='eval')
  self._value_eval_model.init(self._value_model_signature)
  # Jitted forward pass of the eval model over all local devices;
  # do_mean=False, so per-example outputs are returned.
  self._value_eval_jit = tl.jit_forward(
      self._value_eval_model.pure_fn, fastmath.local_device_count(),
      do_mean=False)

  # Inputs to the value model are produced by self.value_batches_stream.
  self._inputs = data.inputs.Inputs(
      train_stream=lambda _: self.value_batches_stream())

  # This is the value Trainer that will be used to train the value model.
  # * inputs to the trainer come from self.value_batches_stream
  # * outputs, targets and weights are passed to self.value_loss
  self._value_trainer = supervised.Trainer(
      model=value_model,
      optimizer=value_optimizer,
      lr_schedule=value_lr_schedule(),
      loss_fn=self.value_loss,
      inputs=self._inputs,
      output_dir=output_dir,
      metrics={'value_loss': self.value_loss,
               'value_mean': self.value_mean,
               'returns_mean': self.returns_mean})
  # One batch is drawn only to obtain the input signature of the
  # single-device 'collect' model.
  value_batch = next(self.value_batches_stream())
  self._eval_model = tl.Accelerate(
      value_model(mode='collect'), n_devices=1)
  self._eval_model.init(shapes.signature(value_batch))
  # NOTE(review): reads a protected attribute of the task; when no initial
  # trajectories were requested, epoch 0 is removed and trajectories are
  # collected right away -- confirm this is the intended bootstrap.
  if self._task._initial_trajectories == 0:
    self._task.remove_epoch(0)
    self._collect_trajectories()