def __init__(self, model, task, eval_task=None, output_dir=None,
             checkpoint_at=None):
    """Sets up a training `Loop` and randomly initializes model and optimizer.

    Args:
      model: Trax layer holding the core model to train. Loss functions and
          eval functions (metrics) live outside the core model; each takes
          the core model's output and the data labels as its two inputs.
      task: TrainTask instance supplying the training data, the loss
          function and the optimizer used by this loop.
      eval_task: EvalTask instance, or None to skip evals entirely.
      output_dir: Path under which outputs (evals and checkpoints) are
          saved. May be None when both `eval_task` and `checkpoint_at` are
          None.
      checkpoint_at: Function (integer --> boolean) that says, for step n,
          whether a checkpoint should be saved at that step. None means no
          checkpoints are saved.
    """
    # Plain bookkeeping first.
    self._task = task
    self._eval_task = eval_task
    self._output_dir = output_dir
    self._step = None
    self._checkpoint_at = checkpoint_at or _never

    # The trained computation is the core model followed by the task's loss.
    self._model_in_training = tl.Serial(model, task.loss_layer)

    # Initialize model and optimizer. Their return values (weights/state and
    # slots/params) are dropped on purpose: the same data stays reachable
    # through the model and optimizer objects themselves.
    batch_signature = shapes.signature(task.sample_batch)
    _, _ = self._model_in_training.init(batch_signature)
    _, _ = task.optimizer.tree_init(self._model_in_training.weights)

    # Differentiate w.r.t. argument 1 of `pure_fn` (the weights); with
    # `has_aux=True` the compiled function returns (gradients, state).
    grad_fn = math.grad(
        self._model_in_training.pure_fn, argnums=1, has_aux=True)
    self._gradients_and_state_fn = math.jit(grad_fn)

    if eval_task is None:
        return

    # Only index 1 of the combined weights/state is kept: the eval part.
    model_with_metrics = _model_with_metrics(model, eval_task)
    self._eval_weights = model_with_metrics.weights[1]
    self._eval_state = model_with_metrics.state[1]
    self._metrics_fn = math.jit(model_with_metrics.pure_fn)
def _jit_compute_loss_fn(predict_fn, loss_fn, n_devices, jit=True):
    """Builds the function that evaluates the loss for a single step.

    Args:
      predict_fn: Forwarded unchanged as the third argument of `loss_fn`.
      loss_fn: Called as `loss_fn(opt_state[0], batch, predict_fn, state,
          rng)`; must return a `(loss_value, new_state)` pair.
      n_devices: Number of devices. 1 selects the single-device variant,
          anything else the `pmap`-based variant.
      jit: Single-device variant only; whether to wrap the result in
          `math.jit`. Ignored when `n_devices != 1`.

    Returns:
      A function `(opt_state, batch, state, rng)` returning
      `(loss_value, new_state, next_rng)`. In the single-device variant
      `rng` is indexed at 0 on the way in and `next_rng` comes back as a
      one-element list; in the multi-device variant the batch is first
      passed through `_reshape_by_device`.
    """
    if n_devices != 1:
        def mapped_compute_loss(opt_state, batch, state, rng):
            """Per-device loss computation, run under `pmap`."""
            # All tensors are assumed to have first dimension = n_devices.
            rng, next_rng = jax_random.split(rng)
            loss_val, new_state = loss_fn(
                opt_state[0], batch, predict_fn, state, rng)
            return loss_val, new_state, next_rng

        mapped_compute_loss = math.pmap(mapped_compute_loss,
                                        axis_name='batch')

        def compute_loss(opt_state, batch, state, rng):
            per_device_batch = _reshape_by_device(batch, n_devices)
            return mapped_compute_loss(
                opt_state, per_device_batch, state, rng)

        return compute_loss

    # Single-device variant.
    # TODO(lukaszkaiser): remove branch when not needed.
    def single_compute_loss(opt_state, batch, state, rng):
        rng, next_rng = jax_random.split(rng[0])
        loss_val, new_state = loss_fn(
            opt_state[0], batch, predict_fn, state, rng)
        return loss_val, new_state, [next_rng]

    if jit:
        return math.jit(single_compute_loss)
    return single_compute_loss
def __init__(self, model, batch_size, observation_space, action_space,
             reward_range, discrete_rewards, history_stream, output_dir,
             model_predict_kwargs=None):
    """Builds the simulated environment around a Trax model.

    Args:
      model: Trax model.
      batch_size: (int) Number of simulated environments run in parallel.
      observation_space: (gym.Space) Observation space.
      action_space: (gym.Space) Action space.
      reward_range: (tuple) Pair (min_reward, max_reward).
      discrete_rewards: (bool) Whether to discretize the rewards.
      history_stream: Iterator yielding batches of initial input data for
          the model. The format is implementation-specific.
      output_dir: (str) Output dir.
      model_predict_kwargs: (dict) Extra keyword arguments used when
          building the model for inference. Handy when training and
          inference need different configs, e.g. training with memory
          efficient attention and predicting with the regular one.
    """
    self._model = model
    predict_kwargs = (
        {} if model_predict_kwargs is None else model_predict_kwargs)
    model_predict = self._model(mode='predict', **predict_kwargs)
    # NOTE: can set non-default PRNG key: model_predict._set_rng_recursive(...)

    def predict_with_state(*args, **kwargs):
        # The state is read after the call, so it reflects that call.
        prediction = model_predict(*args, **kwargs)
        return (prediction, model_predict.state)

    self._model_predict = math.jit(predict_with_state)
    self._model_initialize = model_predict.init

    # Environment description.
    self._observation_space = observation_space
    self._action_space = action_space
    self._reward_range = reward_range
    self._output_dir = output_dir

    # Placeholders; nothing here fills them in.
    self._predict_fn = None
    self._rng = None
    self._model_state = None
    self._history_stream = None

    # The parent constructor uses some of the fields assigned above, which
    # is why it is invoked last.
    super(SimulatedEnvProblem, self).__init__(
        batch_size=batch_size,
        discrete_rewards=discrete_rewards,
        history_stream=history_stream,
    )

    self.seed()
def _jit_update_fn(predict_fn, loss_fn, optimizer, n_devices, jit=True):
    """Builds the function that performs one optimizer update step.

    Args:
      predict_fn: Layer combined with `loss_fn` via `tl.Serial`.
      loss_fn: Loss layer appended after `predict_fn`.
      optimizer: Object whose `tree_update(i, grads, weights, slots,
          opt_params)` returns `(new_weights, new_slots, stats)`.
      n_devices: Number of devices. 1 selects the single-device variant,
          anything else the `pmap`-based variant.
      jit: Single-device variant only; whether to wrap the result in
          `math.jit`. Ignored when `n_devices != 1`.

    Returns:
      A function `(weights_and_slots, i, opt_params, batch, state, rng)`
      returning `((new_weights, new_slots), stats, new_state, next_rng)`.
      In the single-device variant `rng` is indexed at 0 on the way in and
      `next_rng` comes back as a one-element list. Argument 0 is donated
      whenever the function is jitted or pmapped.
    """
    model_and_loss = tl.Serial(predict_fn, loss_fn)

    # Weights go first because gradients are taken w.r.t. the first argument.
    def model_and_loss_call(weights, batch, state, rng):
        loss = model_and_loss(batch, weights=weights, state=state, rng=rng)
        return loss, model_and_loss.state

    def grads_and_state(weights, batch, state, rng):
        """Returns (gradients w.r.t. weights, updated model state)."""
        differentiate = math.grad(model_and_loss_call, has_aux=True)
        return differentiate(weights, batch, state, rng)

    if n_devices != 1:
        def cross_device_mean(g):
            # Divide by psum(1.0) rather than `n_devices`: `n_devices` only
            # counts the devices on this host, whereas psum spans all
            # devices of all hosts (ex: a TPU pod), and the average has to
            # be taken over all of them.
            return math.psum(g, 'batch') / math.psum(np.array(1.0), 'batch')

        def mapped_update(weights_and_slots, i, opt_params, batch, state,
                          rng):
            """Per-device update step, run under `pmap`."""
            # All tensors are assumed to have first dimension = n_devices.
            weights, slots = weights_and_slots
            rng, next_rng = jax_random.split(rng)
            grads, new_state = grads_and_state(weights, batch, state, rng)
            grads = jax.tree_util.tree_map(cross_device_mean, grads)
            new_weights, new_slots, stats = optimizer.tree_update(
                i, grads, weights, slots, opt_params)
            return (new_weights, new_slots), stats, new_state, next_rng

        mapped_update = math.pmap(
            mapped_update, axis_name='batch', donate_argnums=(0,))

        def update(weights_and_slots, i, opt_params, batch, state, rng):
            # The step counter is replicated so each device gets a copy.
            steps = np.repeat(i, n_devices)
            return mapped_update(
                weights_and_slots, steps, opt_params, batch, state, rng)

        return update

    # Single-device variant.
    # TODO(lukaszkaiser): remove branch when not needed.
    def single_update(weights_and_slots, i, opt_params, batch, state, rng):
        weights, slots = weights_and_slots
        rng, next_rng = jax_random.split(rng[0])
        grads, new_state = grads_and_state(weights, batch, state, rng)
        new_weights, new_slots, stats = optimizer.tree_update(
            i, grads, weights, slots, opt_params)
        return (new_weights, new_slots), stats, new_state, [next_rng]

    if not jit:
        return single_update
    return math.jit(single_update, donate_argnums=(0,))
def _accelerate(f, n_devices):
    """Wraps `f` for fast execution on `n_devices` devices.

    Args:
      f: Function to accelerate.
      n_devices: Number of devices `f` will run on.

    Returns:
      `math.jit(f)` when there is exactly one device; otherwise
      `math.pmap(f, axis_name='batch')`.
    """
    if n_devices != 1:
        # Multi-device: map over devices along the 'batch' axis name.
        return math.pmap(f, axis_name='batch')
    return math.jit(f)
def _accelerate(f, n_devices):
    """Returns an accelerated version of `f` for `n_devices` devices.

    A single device gets `math.jit(f)`; any other device count gets
    `math.pmap(f, axis_name='batch')`.
    """
    single_device = n_devices == 1
    return math.jit(f) if single_device else math.pmap(f, axis_name='batch')
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             checkpoints_at=None, should_save_checkpoints=True,
             should_write_summaries=True, has_weights=False,
             nontrainable_param_map=None, id_to_mask=None, metrics=None,
             checkpoint_highest=None, checkpoint_lowest=None):
    """Configures the trainer: model, optimizer, metrics and jitted steps.

    Args:
      model: Callable invoked as `model(mode='train')` and
          `model(mode='eval')` to build the train and eval layers.
      loss_fn: Callable invoked as `loss_fn(has_weights=..., id_to_mask=...)`
          to build the loss layer.
      optimizer: Callable invoked as `optimizer(learning_rate=0.0)`; the
          real learning rate is expected to be set in `reset()`.
      lr_schedule: Stored as `self._lr_schedule` for later use.
      inputs: Either an Inputs instance or a function that returns one
          (e.g. when passed through gin); a callable is called once here.
      output_dir: Forwarded to `reset()`. When falsy, checkpoint saving and
          summary writing are both switched off.
      random_seed: Forwarded to `_init_host_and_devices`.
      n_devices: Forwarded to `_init_host_and_devices`.
      checkpoints_at: Stored as `self._checkpoints_at`; defaults to `[]`.
      should_save_checkpoints: Checkpoints are saved only if this is true,
          this host is the chief, and `output_dir` is set.
      should_write_summaries: Summaries are written only if this is true
          and `output_dir` is set.
      has_weights: Forwarded to the loss and metric constructors.
      nontrainable_param_map: Stored as `self._nontrainable_param_map`;
          defaults to `{}`.
      id_to_mask: Forwarded to the loss and metric constructors.
      metrics: Dict from metric name to a metric constructor taking
          `has_weights` and `id_to_mask`; defaults to `_DEFAULT_METRICS`.
      checkpoint_highest: Stored as `self._checkpoint_highest`.
      checkpoint_lowest: Stored as `self._checkpoint_lowest`.
    """
    self._is_chief, self._n_devices, rng = self._init_host_and_devices(
        n_devices, random_seed)
    self._should_save_checkpoints = should_save_checkpoints and self._is_chief
    self._checkpoints_at = checkpoints_at or []
    self._should_write_summaries = should_write_summaries
    if not output_dir:
        # Nowhere to write to, so disable everything that writes.
        self._should_save_checkpoints = False
        self._should_write_summaries = False
    self._checkpoint_highest = checkpoint_highest
    self._checkpoint_lowest = checkpoint_lowest
    self._has_weights = has_weights
    self._id_to_mask = id_to_mask
    self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
    loss_fn = loss_fn(has_weights=has_weights, id_to_mask=id_to_mask)

    # Inputs is either an Inputs instance or a function that returns it.
    self._inputs = inputs
    if callable(inputs):
        # If we pass a function, e.g., through gin, call it.
        self._inputs = inputs()

    # Initialize the learning rate to a dummy value. It will be set in
    # reset().
    opt = optimizer(learning_rate=0.0)

    # Setup the model.
    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')

    # Setup state: one rng for initialization, one per device for training.
    rng, init_rng = jax_random.split(rng)
    self._rngs = np.stack(jax_random.split(rng, self._n_devices))

    # NOTE: the input/target shapes formerly derived here from
    # `inputs.input_shape` / `inputs.target_shape` were never read; model
    # initialization relies solely on `inputs.example_shape_dtype`.

    def new_opt_state_and_model_state(shape_dtype, rng):
        """Returns optimizer and model states suitable for training."""
        # Combine inputs and targets on the stack.
        example_shapes, example_dtypes = shape_dtype
        input_signature = tuple(
            ShapeDtype(s, d) for (s, d) in zip(example_shapes, example_dtypes))
        # We need to create a new model instance and not reuse `model_train`
        # here, because `m.initialize` puts cached parameter values in `m`
        # and hence the next call of `m.initialize` will give wrong results.
        m = tl.Serial(model(mode='train'), loss_fn)
        m._set_rng_recursive(rng)  # pylint: disable=protected-access
        weights, state = m.init(input_signature)
        (slots, opt_params) = opt.tree_init(weights)
        return (OptState(weights, slots, opt_params), state)

    if _is_jit_init():
        # JIT parameter initialization to avoid memory fragmentation.
        # `shape_dtype` (argument 0) is static, so it must be hashable.
        new_opt_state_and_model_state = math.jit(
            new_opt_state_and_model_state, static_argnums=(0,))
    # `example_shape_dtype` is read lazily, at call time.
    self._new_opt_state_and_model_state = (
        lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
            self._inputs.example_shape_dtype, init_rng))

    # Arrange and initialize metrics layers, in sorted-name order.
    self._metrics = sorted(self._metrics_dict)
    metrics_layers = [
        self._metrics_dict[name](has_weights=self._has_weights,
                                 id_to_mask=self._id_to_mask)
        for name in self._metrics
    ]
    metrics_in_parallel = tl.Branch(*metrics_layers)
    metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
    example_signature = tuple(
        ShapeDtype(s, d)
        for (s, d) in zip(*self._inputs.example_shape_dtype))
    model_predict_eval.init(example_signature)
    output_signature = model_predict_eval.output_signature(example_signature)
    m_weights, m_state = metrics_in_parallel.init(output_signature)
    self._metrics_weights = self._for_n_devices(m_weights)
    self._metrics_state = self._for_n_devices(m_state)

    # Jit model_predict and update so they're fast.
    self._jit_eval = _jit_predict_fn(model_predict_eval, metrics_in_parallel,
                                     self._n_devices)
    self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                         self._n_devices)

    self._model_train = model_train
    self._model_predict_eval = model_predict_eval
    self._loss_fn = loss_fn
    # TODO(pkozakowski): "Learning rate schedules" are currently able to
    # control all optimizer parameters and model state, so let's rename them
    # accordingly.
    self._lr_schedule = lr_schedule

    if nontrainable_param_map is None:
        nontrainable_param_map = {}
    self._nontrainable_param_map = nontrainable_param_map

    # Those fields will be set in reset().
    self._output_dir = None
    self._train_sw = None
    self._eval_sw = None
    self._history = None
    self._lr_fn = None
    self._opt_state = None
    self._step = None
    self._model_state = None

    self.reset(output_dir)
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             checkpoints_at=None, should_save_checkpoints=True,
             should_write_summaries=True, metrics=None,
             checkpoint_highest=None, checkpoint_lowest=None):
    """Configures the trainer: model, optimizer, metrics and jitted steps.

    Args:
      model: Callable invoked as `model(mode='train')` and
          `model(mode='eval')` to build the train and eval layers.
      loss_fn: Loss layer; appended to the train model and, when `metrics`
          is None, also registered as the 'loss' metric.
      optimizer: Callable invoked as `optimizer(learning_rate=0.0)`; the
          real learning rate is expected to be set in `reset()`.
      lr_schedule: Stored as `self._lr_schedule` for later use.
      inputs: Either an Inputs instance or a function that returns one
          (e.g. when passed through gin); a callable is called once here.
      output_dir: Forwarded to `reset()`. When falsy, checkpoint saving and
          summary writing are both switched off.
      random_seed: Forwarded to `_init_host_and_devices`.
      n_devices: Forwarded to `_init_host_and_devices`.
      checkpoints_at: Stored as `self._checkpoints_at`; defaults to `[]`.
      should_save_checkpoints: Checkpoints are saved only if this is true,
          this host is the chief, and `output_dir` is set.
      should_write_summaries: Summaries are written only if this is true
          and `output_dir` is set.
      metrics: Dict from metric name to metric layer, used as given. When
          None, a copy of `_DEFAULT_METRICS` plus a 'loss' entry is used.
      checkpoint_highest: Stored as `self._checkpoint_highest`.
      checkpoint_lowest: Stored as `self._checkpoint_lowest`.
    """
    self._is_chief, self._n_devices, rng = self._init_host_and_devices(
        n_devices, random_seed)
    self._should_save_checkpoints = should_save_checkpoints and self._is_chief
    self._checkpoints_at = checkpoints_at or []
    self._should_write_summaries = should_write_summaries
    if not output_dir:
        # Nowhere to write to, so disable everything that writes.
        self._should_save_checkpoints = False
        self._should_write_summaries = False
    self._checkpoint_highest = checkpoint_highest
    self._checkpoint_lowest = checkpoint_lowest

    if metrics is not None:
        self._metrics_dict = metrics
    else:
        # Work on a copy: writing 'loss' into the shared module-level
        # defaults would leak this trainer's loss into every other trainer
        # built afterwards.
        self._metrics_dict = dict(_DEFAULT_METRICS)
        self._metrics_dict['loss'] = loss_fn

    # Inputs is either an Inputs instance or a function that returns it.
    self._inputs = inputs
    if callable(inputs):
        # If we pass a function, e.g., through gin, call it.
        self._inputs = inputs()

    # Initialize the learning rate to a dummy value. It will be set in
    # reset().
    opt = optimizer(learning_rate=0.0)

    # Setup the model.
    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')
    self._model_with_loss = tl.Serial(model_train, loss_fn)

    # Setup state: one rng for initialization, one per device for training.
    rng, init_rng = jax_random.split(rng)
    self._rngs = np.stack(jax_random.split(rng, self._n_devices))

    # One signature serves both the train model and the eval model.
    example_shapes, example_dtypes = self._inputs.example_shape_dtype
    example_signature = tuple(
        ShapeDtype(s, d) for (s, d) in zip(example_shapes, example_dtypes))
    self._input_signature = example_signature

    def new_opt_state_and_model_state(rng):
        """Returns optimizer and model states suitable for training."""
        weights, state = self._model_with_loss.init(example_signature,
                                                    rng=rng)
        (slots, opt_params) = opt.tree_init(weights)
        return (OptState(weights, slots, opt_params), state)

    if math.backend_name() == 'jax':
        # JIT parameter initialization to avoid memory fragmentation.
        new_opt_state_and_model_state = math.jit(
            new_opt_state_and_model_state)
    self._new_opt_state_and_model_state = (
        lambda: new_opt_state_and_model_state(init_rng))

    # Arrange and initialize metrics layers, in sorted-name order.
    self._metrics = sorted(self._metrics_dict)
    metrics_layers = [self._metrics_dict[name] for name in self._metrics]
    metrics_in_parallel = tl.Branch(*metrics_layers)
    metrics_in_parallel.rng = init_rng
    model_predict_eval.init(example_signature)
    output_signature = model_predict_eval.output_signature(example_signature)
    m_weights, m_state = metrics_in_parallel.init(output_signature)
    self._metrics_weights = self._for_n_devices(m_weights)
    self._metrics_state = self._for_n_devices(m_state)

    # Jit model_predict and update so they're fast.
    self._jit_eval = _jit_predict_fn(model_predict_eval, metrics_in_parallel,
                                     self._n_devices)
    self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                         self._n_devices)

    self._model_train = model_train
    self._model_predict_eval = model_predict_eval
    self._loss_fn = loss_fn
    # TODO(pkozakowski): "Learning rate schedules" are currently able to
    # control all optimizer parameters and model state, so let's rename them
    # accordingly.
    self._lr_schedule = lr_schedule

    # Those fields will be set in reset().
    self._output_dir = None
    self._train_sw = None
    self._eval_sw = None
    self._history = None
    self._lr_fn = None
    self._opt_state = None
    self._step = None
    self._model_state = None

    self.reset(output_dir)