def reset(self, output_dir, init_checkpoint=None):
  """Reinitializes the trainer, restoring from a checkpoint when possible.

  Calls `self.close()` first, then rebuilds the summary writers (if summaries
  are enabled) and the train / eval / train-eval input streams. After that it
  restores the trainer state with `load_trainer_state` when `output_dir` is
  given, or starts from an empty state otherwise. If no optimizer state was
  restored, fresh (randomly initialized) optimizer and model states are
  created. If checkpoint saving is enabled, those fresh states are then saved
  once. The model is not re-jitted.

  Args:
    output_dir: Output directory, or None. When None, both checkpoint saving
      and summary writing must be disabled on this trainer.
    init_checkpoint: Initial checkpoint (default $output_dir/model.pkl.gz).
      Only consulted when `output_dir` is not None.
  """
  self.close()
  self._output_dir = output_dir
  has_output_dir = output_dir is not None
  if not has_output_dir:
    # Nowhere to write checkpoints or summaries without an output directory.
    assert not self._should_save_checkpoints
    assert not self._should_write_summaries
  else:
    tf.io.gfile.makedirs(output_dir)

  # Summary writers live in <output_dir>/train and <output_dir>/eval and are
  # enabled only when this process is the chief.
  if self._should_write_summaries:
    train_dir = os.path.join(output_dir, 'train')
    self._train_sw = jaxboard.SummaryWriter(train_dir, enable=self._is_chief)
    eval_dir = os.path.join(output_dir, 'eval')
    self._eval_sw = jaxboard.SummaryWriter(eval_dir, enable=self._is_chief)

  # Rebuild every input stream, each wrapped for the configured device count.
  data = self._inputs
  n_devices = self._n_devices
  self._train_stream = _repeat_stream(data.train_stream, n_devices)
  # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval
  #   set by adding a padding and stopping the stream when too large.
  self._eval_stream = _repeat_stream(data.eval_stream, n_devices)
  self._train_eval_stream = _repeat_stream(data.train_eval_stream, n_devices)

  # Restore (or create an empty) training state.
  if has_output_dir:
    state = load_trainer_state(
        output_dir, self._model_with_loss, init_checkpoint)
  else:
    state = TrainerState(
        step=None,
        opt_state=None,
        history=trax_history.History(),
        model_state=None)
  self._step = state.step or 0
  self._history = state.history

  restored = bool(state.opt_state)
  if restored:
    opt_state, model_state = state.opt_state, state.model_state
  else:
    opt_state, model_state = self._new_opt_state_and_model_state()
    # Only freshly created model state is replicated here; restored model
    # state is used as loaded.
    model_state = self._for_n_devices(model_state)
  self._opt_state = OptState(*self._for_n_devices(opt_state))
  self._model_state = model_state

  if not restored and self._should_save_checkpoints:
    # Checkpoint the freshly initialized state right away.
    self.save_state(keep=False)
def load_trainer_state(output_dir, model, weights_file=None):
  """Returns a TrainerState instance loaded from the given `output_dir`.

  Args:
    output_dir: Directory holding the default checkpoint `model.pkl.gz`. Only
      used to build the checkpoint path when `weights_file` is None.
    model: Passed through to `trainer_state_from_dict` together with the
      unpickled checkpoint dict.
    weights_file: Optional explicit checkpoint path. When given, the file must
      exist.

  Returns:
    The TrainerState built from the checkpoint. If `weights_file` is None and
    the default checkpoint does not exist yet, returns an empty state instead:
    step=None, opt_state=None, a fresh History and model_state=None.

  Raises:
    ValueError: if an explicitly given `weights_file` does not exist.
  """
  use_default_path = weights_file is None
  if use_default_path:
    checkpoint_path = os.path.join(output_dir, 'model.pkl.gz')
  else:
    checkpoint_path = weights_file

  if not tf.io.gfile.exists(checkpoint_path):
    if not use_default_path:
      # The caller asked for a specific file; its absence is an error.
      raise ValueError('File not found: %s' % checkpoint_path)
    # No checkpoint at the default location yet: start from scratch.
    return TrainerState(
        step=None,
        opt_state=None,
        history=trax_history.History(),
        model_state=None)

  # NOTE(review): `unpickle_from_file` presumably deserializes with pickle;
  # only load checkpoints from trusted locations.
  state_dict = unpickle_from_file(checkpoint_path, gzip=True)
  trainer_state = trainer_state_from_dict(state_dict, model)
  log('Model loaded from %s at step %d' % (checkpoint_path, trainer_state.step))
  logging.debug('From loaded model : history = %s', trainer_state.history)
  return trainer_state
def __init__(
    self,
    model,
    tasks,
    eval_model=None,
    eval_tasks=None,
    output_dir=None,
    checkpoint_at=None,
    permanent_checkpoint_at=None,
    eval_at=None,
    which_task=None,
    n_devices=None,
    random_seed=None,
    loss_chunk_size=0,
    use_memory_efficient_trainer=False,
    callbacks=None,
):
  """Configures a training ``Loop``, including a random initialization.

  Args:
    model: Trax layer, representing the core model to be trained. Loss
        functions and eval functions (a.k.a. metrics) are considered to be
        outside the core model, taking core model output and data labels as
        their two inputs.
    tasks: List of :py:class:`TrainTask` instances, which define the training
        data, loss function, and optimizer to be used in respective tasks in
        this training loop. It can also be a single :py:class:`TrainTask`
        instance which is treated in the same way as a singleton list. Must
        not be empty. The first task's ``sample_batch`` defines the batch
        signature used to initialize the models.
    eval_model: Optional Trax layer, representing model used for evaluation,
        e.g., with dropout turned off. If ``None``, the training model
        (``model``) will be used. When ``use_memory_efficient_trainer`` is
        set, this argument is ignored and ``model`` is used for evaluation.
    eval_tasks: List of :py:class:`EvalTask` instances which define how to
        evaluate the model: which validation data to use and which metrics to
        report. Evaluation on each of the tasks will run and be reported
        separately, which allows scoring a model on different subtasks. This
        argument can also be ``None``, in which case no evals will be run, or
        a single :py:class:`EvalTask`, which will be treated in the same way
        as a singleton list.
    output_dir: Path telling where to save outputs (evals and checkpoints).
        A leading ``~`` is expanded and the directory is created if missing.
        Can be ``None`` if both ``eval_tasks`` and ``checkpoint_at`` are
        ``None``.
    checkpoint_at: Function (integer --> boolean) telling, for step n, whether
        that step should have its checkpoint saved. If ``None``, the default
        is periodic checkpointing at ``task.n_steps_per_checkpoint`` of the
        first task.
    permanent_checkpoint_at: Function (integer --> boolean) telling, for step
        n, whether that step should have its checkpoint saved permanently. If
        ``None``, the default is periodic checkpointing at
        ``task.n_steps_per_permanent_checkpoint`` of the first task.
    eval_at: Function (integer --> boolean) that says, for training step n,
        whether that step should run evals. If ``None``, evals follow the
        default checkpoint schedule (the first task's
        ``n_steps_per_checkpoint``). This is the default schedule, not a
        custom ``checkpoint_at``. If ``eval_tasks`` is ``None``, this is
        overridden with ``_never``.
    which_task: Function (integer --> integer) indicating which task should be
        used at which training step. If ``None``, tasks are used round-robin:
        step n uses task ``n % len(tasks)``. In single-task training this is
        always task 0.
    n_devices: integer or ``None``, the number of devices for this
        computation.
    random_seed: the random seed to use; time/os dependent if ``None``
        (default). With ``None`` on more than one host (and without the
        memory-efficient trainer), initial weights/state are synced across
        hosts.
    loss_chunk_size: int, if > 0 use chunks of this size to make loss
        computation more memory-efficient.
    use_memory_efficient_trainer: whether to use a special memory-efficient
        trainer; if set to 2, the memory efficiency is very aggressive. Only a
        single training task is supported in this mode. Model initialization
        is deferred to trainer creation.
    callbacks: List of subclasses of StepCallback to call on training steps.
        Each class is instantiated with this loop as its only argument.

  Raises:
    ValueError: if ``tasks`` is empty.
    AssertionError: if ``use_memory_efficient_trainer`` is set and more than
        one training task is given.
  """
  self._is_chief, self._n_hosts, self._n_devices, self._rng = (
      init_host_and_devices(n_devices, random_seed))
  if use_memory_efficient_trainer:
    self._rng = tl.on_cpu(self._rng)

  # Handle single task case without lists too.
  if not isinstance(tasks, (list, tuple)):
    tasks = [tasks]

  if not tasks:
    raise ValueError('Must provide at least one training task.')
  if eval_tasks is None:
    # No eval tasks: never schedule evals, regardless of the eval_at argument.
    eval_tasks = []
    eval_at = _never
  else:
    if not isinstance(eval_tasks, (list, tuple)):
      eval_tasks = [eval_tasks]

  self._tasks = tasks
  self._model = model
  self._eval_model = eval_model or model

  self._use_memory_efficient_trainer = use_memory_efficient_trainer
  self._loss_chunk_size = loss_chunk_size
  # TODO(lukaszkaiser): can we have different eval models and save memory?
  if use_memory_efficient_trainer:
    assert len(tasks) == 1, 'only single task supported for now'
    self._eval_model = model

  # Default schedules come from the first task only.
  default_at = _at_step_1_and_every_nth_step(tasks[0].n_steps_per_checkpoint)
  permanent_default_at = _at_step_1_and_every_nth_step(
      tasks[0].n_steps_per_permanent_checkpoint)
  if output_dir is not None:
    self._output_dir = os.path.expanduser(output_dir)
    tf.io.gfile.makedirs(self._output_dir)
    inputs.load_data_counters(self._output_dir)
  else:
    self._output_dir = None

  # Prepare training components.
  self._step = 0
  self._history = trax_history.History()
  self._checkpoint_at = checkpoint_at or default_at
  self._permanent_checkpoint_at = (
      permanent_checkpoint_at or permanent_default_at)
  if which_task is None:
    # If which_task is not passed, cycle through the tasks round-robin:
    # step n uses task n % len(tasks). With a single task this is always 0.
    which_task = lambda n: n % len(tasks)
  self._which_task = which_task

  # Initialize using the given random seed.
  # NOTE: If random_seed is None then self._rng will be different on
  # different hosts, leading to different weights on the different hosts.
  self._batch_signature = shapes.signature(tasks[0].sample_batch)
  self._model.rng = self.new_rng()
  # In the memory-efficient case, we initialize in init_trainer.
  if not use_memory_efficient_trainer:
    if _is_uninitialized(self._model):
      self._model.init(self._batch_signature)
    self._eval_model.rng = self.new_rng()
    if _is_uninitialized(self._eval_model):
      self._eval_model.init(self._batch_signature)

  # To handle the above case (i.e. random_seed = None), we psum the weights
  # and state and average them.
  # NOTE: This adds time (how much?) so we prefer not to do it if it is
  # unnecessary, i.e. random_seed was set.
  # NOTE: Averaging the weights across devices can screw up the initial weight
  # statistics.
  # TODO(pkozakowski): Broadcast from one of the devices instead?
  # TODO(lukaszkaiser): make it work for the memory-efficient trainer too.
  if (random_seed is None and self._n_hosts > 1 and
      not use_memory_efficient_trainer):
    logging.info('Syncing weights/state across %d hosts.', self._n_hosts)
    self._sync_weights_and_state_across_hosts()

  # Create the optimizer for the training loss function (one trainer per
  # training task).
  self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
  # NOTE(review): presumably restores weights/state from output_dir when a
  # checkpoint exists; confirm in load_checkpoint.
  self.load_checkpoint()

  # Prepare eval components.
  self._eval_at = eval_at or default_at
  self._eval_tasks = eval_tasks
  loss_names = [task.loss_name for task in self._tasks]
  metric_names = [
      name  # pylint: disable=g-complex-comprehension
      for eval_task in self._eval_tasks
      for name in eval_task.metric_names
  ]
  # Length of the longest loss/metric name; presumably used to right-justify
  # names when logging.
  self._rjust_len = max(map(len, loss_names + metric_names))
  self._evaluator_per_task = tuple(
      self._init_evaluator(eval_task) for eval_task in self._eval_tasks)

  if self._output_dir is None:
    _log('Will not write evaluation metrics, because output_dir is None.')

  def task_output_dir(task_index, task_list):
    """Returns (and creates) the output dir for one task, or None.

    With fewer than two tasks the loop's output dir is used directly;
    otherwise each task gets a subdirectory named by its index.
    """
    if self._output_dir is not None:
      if len(task_list) < 2:
        output_dir = self._output_dir
      else:
        output_dir = os.path.join(self._output_dir, str(task_index))
      tf.io.gfile.makedirs(output_dir)
      return output_dir
    else:
      return None
  self._output_dir_per_eval_task = [
      task_output_dir(i, eval_tasks) for i in range(len(eval_tasks))]
  self._output_dir_per_train_task = [
      task_output_dir(i, tasks) for i in range(len(tasks))]

  # Each callback class is instantiated with this loop.
  callbacks = callbacks or []
  self._callbacks = [
      callback_class(self) for callback_class in callbacks
  ]
def test_metrics_for_mode(self):
  """Every metric recorded under a mode is listed for that mode, in order."""
  hist = trax_history.History()
  recorded = (('metric1', 1, 0.1), ('metric2', 2, 0.2))
  for metric_name, step, value in recorded:
    hist.append('train', metric_name, step, value)
  self.assertEqual(hist.metrics_for_mode('train'), ['metric1', 'metric2'])
def test_modes(self):
  """`modes` lists each mode that received a metric.

  The expectation is in sorted order, not insertion order ('train' is added
  first).
  """
  hist = trax_history.History()
  entries = (
      ('train', 'metric1', 1, 0.1),
      ('test', 'metric2', 2, 0.2),
  )
  for mode, metric_name, step, value in entries:
    hist.append(mode, metric_name, step, value)
  expected_modes = ['test', 'train']
  self.assertEqual(hist.modes, expected_modes)
def test_serializer_and_deserializer(self):
  """A History survives a to_dict() / from_dict() round trip."""
  source = trax_history.History()
  source.append('train', 'metric1', 1, 0.1)
  restored = trax_history.History.from_dict(source.to_dict())
  self.assertEqual(restored.get('train', 'metric1'), [(1, 0.1)])
def test_unknown_metric(self):
  """Looking up a metric that was never recorded yields an empty list."""
  hist = trax_history.History()
  hist.append('train', 'metric1', 1, 0.1)
  missing = hist.get('train', 'unknown_metric')
  self.assertEqual(missing, [])