def __init__(self, loss_layer, optimizer, n_devices=None):
  """Sets up replicated optimizer state and the accelerated update function.

  Args:
    loss_layer: Layer whose `pure_fn` is differentiated with respect to its
      weights (argument 1) to produce the loss and its gradients.
    optimizer: Optimizer whose `slots` and `opt_params` are replicated, and
      which is handed to the accelerated update function.
    n_devices: Number of devices to use. When falsy (`None` or 0),
      `fastmath.device_count()` is used instead.
  """
  self._loss_layer = loss_layer
  self._optimizer = optimizer
  if n_devices:
    self._n_devices = n_devices
  else:
    self._n_devices = fastmath.device_count()

  # Optimizer slots and opt_params may need to be replicated per device.
  optimizer_state = (self._optimizer.slots, self._optimizer.opt_params)
  replicated_state = tl.for_n_devices(optimizer_state, self._n_devices)
  self._slots, self._opt_params = replicated_state

  # Accelerated version of the loss layer, used to replicate weights and
  # state. It is given the caller's `n_devices` exactly as passed in
  # (possibly None), not the resolved `self._n_devices`.
  self._accelerated_loss_layer = tl.Accelerate(loss_layer, n_devices=n_devices)

  # Signature:
  #   (batch, weights, state, rng) -> ((loss, state), gradients)
  forward_and_backward = fastmath.value_and_grad(
      loss_layer.pure_fn,
      argnums=1,  # arg1 of pure_fn: weights
      has_aux=True)  # return (loss, state), gradients
  self._forward_and_backward_fn = forward_and_backward

  # Signature:
  #   (weights, slots), step, opt_params, batch, state, rng ->
  #   (weights, slots), state, stats
  self._accelerated_update_fn = _accelerate_update_fn(
      self._forward_and_backward_fn,
      self._optimizer,
      n_devices=self._n_devices,
      accelerate=True)
def __init__(self, model_with_loss, optimizer, n_devices=None, adasum=False):
  """Sets up replicated optimizer state and the accelerated update function.

  Args:
    model_with_loss: Layer whose `pure_fn` is differentiated with respect to
      its weights (argument 1) to produce the loss and its gradients.
    optimizer: Optimizer whose `slots` and `opt_params` are replicated, and
      which is handed to the accelerated update function.
    n_devices: Number of devices to use. When falsy (`None` or 0),
      `fastmath.local_device_count()` is used instead.
    adasum: Flag stored on the instance and forwarded unchanged to
      `_accelerate_update_fn`.
  """
  self._model_with_loss = model_with_loss
  self._optimizer = optimizer
  if n_devices:
    self._n_devices = n_devices
  else:
    self._n_devices = fastmath.local_device_count()
  self._adasum = adasum

  # Optimizer slots and opt_params may need to be replicated per device; the
  # replicated structure is then passed through `tl.on_cpu`.
  optimizer_state = (self._optimizer.slots, self._optimizer.opt_params)
  replicated_state = tl.for_n_devices(optimizer_state, self._n_devices)
  self._slots, self._opt_params = tl.on_cpu(replicated_state)

  # Accelerated version of model+loss, used to replicate weights and state.
  # It is given the caller's `n_devices` exactly as passed in (possibly
  # None), not the resolved `self._n_devices`.
  self._accelerated_model_with_loss = tl.Accelerate(
      model_with_loss, n_devices=n_devices)

  # Signature:
  #   (batch, weights, state, rng) -> ((loss, state), gradients)
  forward_and_backward = fastmath.value_and_grad(
      model_with_loss.pure_fn,
      argnums=1,  # arg1 of pure_fn: weights
      has_aux=True)  # return (loss, state), gradients
  self._forward_and_backward_fn = forward_and_backward

  # Signature:
  #   (weights, slots), step, opt_params, batch, state, rng ->
  #   (weights, slots), state, stats
  self._accelerated_update_fn = _accelerate_update_fn(
      self._forward_and_backward_fn,
      self._optimizer,
      n_devices=self._n_devices,
      accelerate=True,
      adasum=self._adasum)
def __init__(self, model, task, eval_model=None, eval_task=None,
             output_dir=None, checkpoint_at=None, eval_at=None):
  """Configures a training `Loop`, including a random initialization.

  Args:
    model: Trax layer, representing the core model to be trained. Loss
      functions and eval functions (a.k.a. metrics) are considered to be
      outside the core model, taking core model output and data labels as
      their two inputs.
    task: TrainTask instance, which defines the training data, loss function,
      and optimizer to be used in this training loop.
    eval_model: Optional Trax layer, representing model used for evaluation,
      e.g., with dropout turned off. If None, the training model (model)
      will be used.
    eval_task: EvalTask instance or None. If None, don't do any evals.
    output_dir: Path telling where to save outputs (evals and checkpoints).
      Can be None if both `eval_task` and `checkpoint_at` are None. A leading
      `~` is expanded and the directory is created if needed.
    checkpoint_at: Function (integer --> boolean) telling, for step n, whether
      that step should have its checkpoint saved. If None, the default is
      periodic checkpointing at `task.n_steps_per_checkpoint`.
    eval_at: Function (integer --> boolean) that says, for training step n,
      whether that step should run evals. If None, run when checkpointing.
  """
  self._task = task
  self._model = model
  self._eval_model = eval_model or model
  default_at = _at_step_1_and_every_nth_step(
      self._task.n_steps_per_checkpoint)
  if output_dir is not None:
    self._output_dir = os.path.expanduser(output_dir)
    tf.io.gfile.makedirs(self._output_dir)
  else:
    self._output_dir = None

  # Prepare training components.
  self._step = 0
  self._checkpoint_at = checkpoint_at or default_at
  self._model_in_training = tl.Serial(self._model, self._task.loss_layer)
  self._batch_signature = shapes.signature(self._task.sample_batch)
  self._eval_model.init(self._batch_signature)
  self._model_in_training.init(self._batch_signature)
  self._task.optimizer.tree_init(self._model_in_training.weights)
  # Signature:
  #   (batch, weights, state, rng) -> ((loss, state), gradients)
  self._forward_and_backward_fn = fastmath.jit(
      fastmath.value_and_grad(
          self._model_in_training.pure_fn,
          argnums=1,  # arg1 of pure_fn: weights
          has_aux=True))  # return (loss, state), gradients

  # Prepare eval components.
  # Always define `_eval_task` (None when evals are disabled) so that reading
  # the attribute never raises AttributeError.
  self._eval_task = eval_task
  if eval_task is None:
    self._eval_at = _never
  else:
    self._eval_at = eval_at or default_at
    metric_name_lengths = [len(name) for name in self._eval_task.metric_names]
    # Width used to right-justify metric names; covers the loss name too.
    self._rjust_len = max(
        [len(self._task.loss_layer.name)] + metric_name_lengths)
    model_with_metrics = _model_with_metrics(
        self._eval_model, self._eval_task)
    self._eval_weights = model_with_metrics.weights[1]  # just the eval part
    self._eval_state = model_with_metrics.state[1]  # just the eval part
    self._metrics_fn = fastmath.jit(model_with_metrics.pure_fn)
    if self._output_dir is None:
      _log('Will not write evaluation metrics, because output_dir is None.')
def __init__(self, model, tasks, eval_model=None, eval_tasks=None,
             output_dir=None, checkpoint_at=None, eval_at=None,
             n_devices=None, random_seed=None):
  """Configures a training `Loop`, including a random initialization.

  Args:
    model: Trax layer, representing the core model to be trained. Loss
      functions and eval functions (a.k.a. metrics) are considered to be
      outside the core model, taking core model output and data labels as
      their two inputs.
    tasks: List of TrainTask instances (or a single TrainTask), which define
      the training data, loss function, and optimizer to be used in
      respective tasks in this training loop. Exactly one task is supported.
    eval_model: Optional Trax layer, representing model used for evaluation,
      e.g., with dropout turned off. If None, the training model (model)
      will be used.
    eval_tasks: List of EvalTask instances (or a single EvalTask) or None.
      If None, don't do any evals. Exactly one eval task is supported.
    output_dir: Path telling where to save outputs (evals and checkpoints).
      Can be None if both `eval_tasks` and `checkpoint_at` are None.
    checkpoint_at: Function (integer --> boolean) telling, for step n, whether
      that step should have its checkpoint saved. If None, the default is
      periodic checkpointing at `task.n_steps_per_checkpoint`.
    eval_at: Function (integer --> boolean) that says, for training step n,
      whether that step should run evals. If None, run when checkpointing.
    n_devices: integer or None, the number of devices for this computation.
    random_seed: the random seed to use; time/os dependent if None (default).
  """
  self._is_chief, self._n_hosts, self._n_devices, self._rng = (
      init_host_and_devices(n_devices, random_seed))

  # Handle single task case without lists too.
  if not isinstance(tasks, (list, tuple)):
    tasks = [tasks]
  assert len(tasks) == 1, 'Multitask training not supported yet.'
  task = tasks[0]

  if eval_tasks is None:
    eval_task = None
  else:
    # Accept a bare EvalTask the same way a bare TrainTask is accepted above.
    if not isinstance(eval_tasks, (list, tuple)):
      eval_tasks = [eval_tasks]
    assert len(eval_tasks) == 1, 'Multitask training not supported yet.'
    eval_task = eval_tasks[0]

  self._task = task
  self._model = model
  self._eval_model = eval_model or model
  default_at = _at_step_1_and_every_nth_step(
      self._task.n_steps_per_checkpoint)
  if output_dir is not None:
    self._output_dir = os.path.expanduser(output_dir)
    tf.io.gfile.makedirs(self._output_dir)
  else:
    self._output_dir = None

  # Prepare training components.
  self._step = 0
  self._checkpoint_at = checkpoint_at or default_at
  self._batch_signature = shapes.signature(self._task.sample_batch)
  self._model_in_training = tl.Serial(self._model, self._task.loss_layer)

  # Initialize using the given random seed.
  # NOTE: If `random_seed` is `None` then `self._rng` will be different on
  # different hosts, leading to different weights on the different hosts.
  self._model_in_training.rng = self.new_rng()
  self._model_in_training.init(self._batch_signature)
  self._eval_model.rng = self.new_rng()
  self._eval_model.init(self._batch_signature)

  # To handle the above case (i.e. random_seed = None), we psum the weights
  # and state and average them.
  # NOTE: This adds time (how much?) so we prefer not to do it if it is
  # unnecessary, i.e. random_seed was set.
  if random_seed is None and self._n_hosts > 1:
    logging.info('Syncing weights/state across %d hosts.', self._n_hosts)
    self._sync_weights_and_state_across_hosts()

  self._task.optimizer.tree_init(self._model_in_training.weights)

  # Signature:
  #   (batch, weights, state, rng) -> ((loss, state), gradients)
  self._forward_and_backward_fn = fastmath.value_and_grad(
      self._model_in_training.pure_fn,
      argnums=1,  # arg1 of pure_fn: weights
      has_aux=True)  # return (loss, state), gradients

  # Signature:
  #   (weights, slots), step, opt_params, batch, state, rng ->
  #   (weights, slots), state, stats
  self._accelerated_update_fn = _accelerate_update_fn(
      self._forward_and_backward_fn,
      self._task.optimizer,
      n_devices=self.n_devices,
      accelerate=True)

  # Restore from checkpoint if there is one.
  self.load_checkpoint()

  # Prepare eval components.
  # Always define `_eval_task` (None when evals are disabled) so that reading
  # the attribute never raises AttributeError.
  self._eval_task = eval_task
  if eval_task is None:
    self._eval_at = _never
  else:
    self._eval_at = eval_at or default_at
    metric_name_lengths = [len(name) for name in self._eval_task.metric_names]
    # Width used to right-justify metric names; covers the loss name too.
    self._rjust_len = max(
        [len(self._task.loss_layer.name)] + metric_name_lengths)
    model_with_metrics = _model_with_metrics(
        self._eval_model, self._eval_task)
    # Keep self._eval_{weights/state} replicated.
    self._eval_weights = self._for_n_devices(
        model_with_metrics.weights[1])  # just the eval part
    self._eval_state = self._for_n_devices(
        model_with_metrics.state[1])  # just the eval part
    self._metrics_fn = _accelerate_model_with_metrics(
        model_with_metrics, self.n_devices)
    if self._output_dir is None:
      _log('Will not write evaluation metrics, because output_dir is None.')
def __init__(self, model, task, eval_model=None, eval_task=None,
             output_dir=None, checkpoint_at=None, eval_at=None):
  """Configures a training `Loop`, including a random initialization.

  Args:
    model: Trax layer, representing the core model to be trained. Loss
      functions and eval functions (a.k.a. metrics) are considered to be
      outside the core model, taking core model output and data labels as
      their two inputs.
    task: TrainTask instance, which defines the training data, loss function,
      and optimizer to be used in this training loop.
    eval_model: Optional Trax layer, representing model used for evaluation,
      e.g., with dropout turned off. If None, the training model (model)
      will be used.
    eval_task: EvalTask instance or None. If None, don't do any evals.
    output_dir: Path telling where to save outputs (evals and checkpoints).
      Can be None if both `eval_task` and `checkpoint_at` are None. A leading
      `~` is expanded and the expanded directory is created if needed; an
      empty string is treated like None.
    checkpoint_at: Function (integer --> boolean) telling, for step n, whether
      that step should have its checkpoint saved. If None, the default is
      periodic checkpointing at `task.n_steps_per_checkpoint`.
    eval_at: Function (integer --> boolean) that says, for training step n,
      whether that step should run evals. If None, run when checkpointing.
  """
  self._task = task
  self._model = model
  self._model_in_training = tl.Serial(model, task.loss_layer)
  # NOTE(review): unlike the training model, `self._eval_model` is not
  # explicitly `init`-ed in this method -- confirm that is intended when a
  # distinct `eval_model` is supplied.
  self._eval_model = model if eval_model is None else eval_model
  self._eval_task = eval_task

  # `eval_task` may be None (evals disabled); then there are no metric names
  # and the justification width falls back to 0.
  metric_names = [] if eval_task is None else eval_task.metric_names
  self._rjust_len = max([0] + [len(name) for name in metric_names])

  # Create the same (user-expanded) directory that is stored on the instance,
  # so a path such as '~/run' does not create a literal '~' directory.
  self._output_dir = os.path.expanduser(output_dir) if output_dir else None
  if self._output_dir is not None:
    tf.io.gfile.makedirs(self._output_dir)

  default_fn = _at_step_1_and_periodically_at(task.n_steps_per_checkpoint)
  self._checkpoint_at = checkpoint_at or default_fn
  self._eval_at = eval_at or default_fn
  if eval_task is None:
    self._eval_at = _never
  self._step = 0

  batch_signature = shapes.signature(task.sample_batch)
  self._batch_signature = batch_signature
  # Initialize the model and the optimizer; discard the return values
  # (model weights/state, optimizer slots/params), since they're available
  # from the model and optimizer objects.
  _, _ = self._model_in_training.init(batch_signature)
  _, _ = task.optimizer.tree_init(self._model_in_training.weights)

  # Signature:
  #   (batch, weights, state, rng) -> ((loss, state), gradients)
  self._gradients_and_state_fn = fastmath.jit(
      fastmath.value_and_grad(
          self._model_in_training.pure_fn,
          argnums=1,  # arg1 of pure_fn: weights
          has_aux=True))  # return (loss, state), gradients

  if eval_task is not None:
    model_with_metrics = _model_with_metrics(self._eval_model, eval_task)
    self._eval_weights = model_with_metrics.weights[1]  # just the eval part
    self._eval_state = model_with_metrics.state[1]  # just the eval part
    self._metrics_fn = fastmath.jit(model_with_metrics.pure_fn)