Example #1
    def __init__(self,
                 model,
                 task,
                 eval_task=None,
                 output_dir=None,
                 checkpoint_at=None):
        """Configures a training `Loop`, including a random initialization.

    Args:
      model: Trax layer, representing the core model to be trained. Loss
          functions and eval functions (a.k.a. metrics) are considered to be
          outside the core model, taking core model output and data labels as
          their two inputs.
      task: TrainTask instance, which defines the training data, loss function,
          and optimizer to be used in this training loop.
      eval_task: EvalTask instance or None. If None, don't do any evals.
      output_dir: Path telling where to save outputs (evals and checkpoints).
          Can be None if both `eval_task` and `checkpoint_at` are None.
      checkpoint_at: Function (integer --> boolean) telling, for step n, whether
          that step should have its checkpoint saved. If None, don't save any
          checkpoints.
    """
        self._task = task
        self._model_in_training = tl.Serial(model, task.loss_layer)
        self._eval_task = eval_task
        self._output_dir = output_dir
        self._checkpoint_at = checkpoint_at or _never
        self._step = None

        batch_signature = shapes.signature(task.sample_batch)
        # Initialize the model and the optimizer; discard the return values
        # (model weights/state, optimizer slots/params), since they're available
        # from the model and optimizer objects.
        _, _ = self._model_in_training.init(batch_signature)
        _, _ = task.optimizer.tree_init(self._model_in_training.weights)

        self._gradients_and_state_fn = (
            math.jit(
                math.grad(
                    self._model_in_training.pure_fn,
                    argnums=1,  # arg1 of pure_fn: weights
                    has_aux=True)))  # return (gradients, state)

        if eval_task is not None:
            model_with_metrics = _model_with_metrics(model, eval_task)
            # Keep just the eval parts of the combined weights and state.
            self._eval_weights = model_with_metrics.weights[1]
            self._eval_state = model_with_metrics.state[1]
            self._metrics_fn = math.jit(model_with_metrics.pure_fn)
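A minimal sketch of the same pattern in plain JAX, outside of Trax: gradients are taken with respect to a non-first argument via `argnums`, the auxiliary output (here the new state) comes back through `has_aux=True`, and the result is wrapped in `jit`. The toy `pure_loss`, its shapes, and the step-counter "state" below are invented for illustration.

import jax
import jax.numpy as jnp

def pure_loss(batch, weights, state, rng):
    """Toy pure function: (batch, weights, state, rng) -> (loss, new_state)."""
    del rng  # unused in this toy example
    x, y = batch
    loss = jnp.mean((x @ weights - y) ** 2)
    new_state = state + 1  # e.g. a step counter carried as "state"
    return loss, new_state

# Gradients w.r.t. argument 1 (weights); has_aux=True makes the call return
# (gradients, new_state) instead of differentiating through the state.
grad_fn = jax.jit(jax.grad(pure_loss, argnums=1, has_aux=True))

rng = jax.random.PRNGKey(0)
batch = (jax.random.normal(rng, (8, 4)), jnp.ones((8, 1)))
grads, new_state = grad_fn(batch, jnp.zeros((4, 1)), jnp.array(0), rng)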
Example #2
def _jit_compute_loss_fn(predict_fn, loss_fn, n_devices, jit=True):
    """Returns a (JIT-compiled) function that computes the loss for one step."""
    if n_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.

        def single_compute_loss(opt_state, batch, state, rng):
            rng, subrng = jax_random.split(rng[0])
            loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state,
                                      rng)
            return loss_val, state, [subrng]

        return math.jit(single_compute_loss) if jit else single_compute_loss

    # Else, for n_devices > 1:
    @functools.partial(math.pmap, axis_name='batch')
    def mapped_compute_loss(opt_state, batch, state, rng):
        """This is a multi-device version of the update function above."""
        # We assume all tensors have the first dimension = n_devices.
        rng, subrng = jax_random.split(rng)
        loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng)
        return loss_val, state, subrng

    def compute_loss(opt_state, batch, state, rng):
        return mapped_compute_loss(opt_state,
                                   _reshape_by_device(batch, n_devices), state,
                                   rng)

    return compute_loss
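The single-device branch above threads the PRNG key through the step: it splits the incoming key, uses one half for this step's loss, and returns the other half so the caller can feed it into the next step. A self-contained sketch of that idea in plain JAX (the `loss_fn`, shapes, and names are made up for illustration):

import jax
import jax.numpy as jnp

def loss_fn(weights, batch, rng):
    x, y = batch
    noise = 0.01 * jax.random.normal(rng, y.shape)  # toy use of the rng
    return jnp.mean((x @ weights + noise - y) ** 2)

@jax.jit
def compute_loss_step(weights, batch, rng):
    # Consume `use_rng` in this step and hand `next_rng` back to the caller.
    next_rng, use_rng = jax.random.split(rng)
    return loss_fn(weights, batch, use_rng), next_rng

rng = jax.random.PRNGKey(0)
batch = (jnp.ones((8, 4)), jnp.ones((8, 1)))
loss, rng = compute_loss_step(jnp.zeros((4, 1)), batch, rng)  # reuse `rng` next step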
Example #3
    def __init__(self,
                 model,
                 batch_size,
                 observation_space,
                 action_space,
                 reward_range,
                 discrete_rewards,
                 history_stream,
                 output_dir,
                 model_predict_kwargs=None):
        """Initializes the env.

    Args:
      model: Trax model.
      batch_size: (int) Number of simulated environments run in parallel.
      observation_space: (gym.Space) Observation space.
      action_space: (gym.Space) Action space.
      reward_range: (tuple) Pair (min_reward, max_reward).
      discrete_rewards: (bool) Whether to discretize the rewards.
      history_stream: Iterator yielding batches of initial input data for the
        model. The format is implementation-specific.
      output_dir: (str) Output dir.
      model_predict_kwargs: (dict) Additional model keyword arguments for
        inference. Useful when different config is needed for training and
        inference, e.g. train with memory efficient attention and predict with
        the regular one.
    """
        self._model = model
        if model_predict_kwargs is None:
            model_predict_kwargs = {}
        model_predict = self._model(mode='predict', **model_predict_kwargs)

        # NOTE: can set non-default PRNG key: model_predict._set_rng_recursive(...)
        def predict_with_state(*args, **kwargs):
            output = model_predict(*args, **kwargs)
            return (output, model_predict.state)

        self._model_predict = math.jit(predict_with_state)
        self._model_initialize = model_predict.init

        self._observation_space = observation_space
        self._action_space = action_space
        self._reward_range = reward_range
        self._output_dir = output_dir

        self._predict_fn = None
        self._rng = None
        self._model_state = None
        self._history_stream = None

        # Call the superclass ctor last, since it uses some of the member fields
        # set above.
        super(SimulatedEnvProblem, self).__init__(
            batch_size=batch_size,
            discrete_rewards=discrete_rewards,
            history_stream=history_stream,
        )

        self.seed()
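The `predict_with_state` wrapper exists so that the model's updated state comes back out of the jitted call as an explicit value rather than being read as a side effect. A backend-agnostic sketch of the same idea in plain JAX, with a hypothetical `apply_fn` that threads state in and out explicitly:

import jax
import jax.numpy as jnp

def apply_fn(params, state, x):
    """Hypothetical stateful forward pass: returns (output, new_state)."""
    mean = 0.9 * state['running_mean'] + 0.1 * jnp.mean(x)
    out = params['scale'] * (x - mean)
    return out, {'running_mean': mean}

# State goes in and comes out as explicit values, so jit stays purely functional.
predict_with_state = jax.jit(apply_fn)

params = {'scale': jnp.array(2.0)}
state = {'running_mean': jnp.array(0.0)}
out, state = predict_with_state(params, state, jnp.ones((4,)))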
Example #4
def _jit_update_fn(predict_fn, loss_fn, optimizer, n_devices, jit=True):
    """Returns a (JIT-compiled) function that computes updates for one step."""
    model_and_loss = tl.Serial(predict_fn, loss_fn)

    # Gradients are always taken w.r.t. the first argument, so we put weights
    # first.
    def model_and_loss_call(weights, batch, state, rng):
        res = model_and_loss(batch, weights=weights, state=state, rng=rng)
        return res, model_and_loss.state

    if n_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.

        def single_update(weights_and_slots, i, opt_params, batch, state, rng):
            weights, slots = weights_and_slots
            rng, subrng = jax_random.split(rng[0])
            grad_fn = math.grad(model_and_loss_call, has_aux=True)
            grads, state = grad_fn(weights, batch, state, rng)
            new_weights, new_slots, stats = optimizer.tree_update(
                i, grads, weights, slots, opt_params)
            return (new_weights, new_slots), stats, state, [subrng]

        if jit:
            return math.jit(single_update, donate_argnums=(0, ))
        else:
            return single_update

    # Else, for n_devices > 1:
    @functools.partial(math.pmap, axis_name='batch', donate_argnums=(0, ))
    def mapped_update(weights_and_slots, i, opt_params, batch, state, rng):
        """This is a multi-device version of the update function above."""
        # We assume all tensors have the first dimension = n_devices.
        weights, slots = weights_and_slots
        rng, subrng = jax_random.split(rng)
        grad_fn = math.grad(model_and_loss_call, has_aux=True)
        grads, state = grad_fn(weights, batch, state, rng)
        # We do a psum(1.0) here instead of dividing by `n_devices`, because
        # `n_devices` is just the number of devices on this host machine, whereas
        # psum sums over all devices of all hosts (e.g. a TPU pod), and we need to
        # average over all of them.
        grads = jax.tree_util.tree_map(
            lambda g: math.psum(g, 'batch') / math.psum(np.array(1.0), 'batch'),
            grads)
        new_weights, new_slots, stats = optimizer.tree_update(
            i, grads, weights, slots, opt_params)
        return (new_weights, new_slots), stats, state, subrng

    def update(weights_and_slots, i, opt_params, batch, state, rng):
        return mapped_update(weights_and_slots, np.repeat(i, n_devices),
                             opt_params, batch, state, rng)

    return update
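The `psum(g) / psum(1.0)` trick is worth a standalone illustration: `psum(1.0)` counts every device participating in the named axis, across all hosts, which is why the code above does not simply divide by the local `n_devices`. A minimal sketch in plain JAX with a toy loss (names and shapes are invented):

import functools
import jax
import jax.numpy as jnp

def loss_fn(weights, batch):
    x, y = batch
    return jnp.mean((x @ weights - y) ** 2)

@functools.partial(jax.pmap, axis_name='batch')
def averaged_grads(weights, batch):
    grads = jax.grad(loss_fn)(weights, batch)
    # psum(1.0) counts all devices on the 'batch' axis, including other hosts.
    return jax.tree_util.tree_map(
        lambda g: jax.lax.psum(g, 'batch') / jax.lax.psum(jnp.array(1.0), 'batch'),
        grads)

n = jax.local_device_count()
weights = jnp.zeros((n, 4, 1))                      # replicated: one copy per device
batch = (jnp.ones((n, 8, 4)), jnp.ones((n, 8, 1)))  # sharded: one shard per device
grads = averaged_grads(weights, batch)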
Example #5
def _accelerate(f, n_devices):
    """JIT-compiled version of `f` running on `n_devices`."""
    if n_devices == 1:
        return math.jit(f)

    return math.pmap(f, axis_name='batch')
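One practical note on this dispatch, sketched below in plain JAX with a made-up `step_fn`: the jitted and pmapped variants have different calling conventions, since `pmap` expects every array argument to carry an extra leading axis of size `n_devices`.

import jax
import jax.numpy as jnp

def step_fn(x):
    return 2.0 * x

def accelerate(f, n_devices):
    return jax.jit(f) if n_devices == 1 else jax.pmap(f, axis_name='batch')

out = accelerate(step_fn, 1)(jnp.ones((8, 4)))     # plain batch of data

n = jax.local_device_count()
out = accelerate(step_fn, n)(jnp.ones((n, 8, 4)))  # extra leading device axis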
Example #6
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 has_weights=False,
                 nontrainable_param_map=None,
                 id_to_mask=None,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):

        self._is_chief, self._n_devices, rng = (self._init_host_and_devices(
            n_devices, random_seed))
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at or []
        self._should_write_summaries = should_write_summaries
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        self._has_weights = has_weights
        self._id_to_mask = id_to_mask
        self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        loss_fn = loss_fn(has_weights=has_weights, id_to_mask=id_to_mask)
        # Inputs is either an Inputs instance or a function that returns it.
        self._inputs = inputs
        # If we pass a function, e.g. through gin, call it.
        if callable(inputs):
            self._inputs = inputs()

        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')

        # Setup state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        # If the inputs are a tuple/list, add [None] (batch) to each element.
        if self._inputs.input_shape and isinstance(self._inputs.input_shape[0],
                                                   (list, tuple)):
            model_input_shape = tuple(
                tuple([None] + list(shape))
                for shape in self._inputs.input_shape)
        else:  # Otherwise just add [None] to the input shape.
            model_input_shape = tuple([None] + list(self._inputs.input_shape))
        # Same for targets.
        if self._inputs.target_shape and isinstance(
                self._inputs.target_shape[0], (list, tuple)):
            model_target_shape = tuple(
                tuple([None] + list(shape))
                for shape in self._inputs.target_shape)
        else:
            model_target_shape = tuple([None] +
                                       list(self._inputs.target_shape))
        # Change all None to 1 in input and target shape.
        model_input_shape = math.nested_map(lambda x: x or 1,
                                            model_input_shape)
        model_target_shape = math.nested_map(lambda x: x or 1,
                                             model_target_shape)

        def new_opt_state_and_model_state(shape_dtype, rng):
            """Returns optimizer and model states suitable for training a model."""
            # Combine inputs and targets on the stack.
            shapes, dtypes = shape_dtype
            input_signature = tuple(
                ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))
            # We need to create a new model instance and not reuse `model_train`
            # here, because `m.init` caches parameter values in `m`, and a second
            # call to `m.init` would give wrong results.
            m = tl.Serial(model(mode='train'), loss_fn)
            m._set_rng_recursive(rng)  # pylint: disable=protected-access
            weights, state = m.init(input_signature)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if _is_jit_init():
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = math.jit(
                new_opt_state_and_model_state, static_argnums=(0, ))
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
                self._inputs.example_shape_dtype, init_rng))

        # Arrange and initialize metrics layers.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [
            self._metrics_dict[m](has_weights=self._has_weights,
                                  id_to_mask=self._id_to_mask)
            for m in self._metrics
        ]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
        example_signature = tuple(
            ShapeDtype(s, d)
            for (s, d) in zip(*self._inputs.example_shape_dtype))
        model_predict_eval.init(example_signature)
        output_signature = model_predict_eval.output_signature(
            example_signature)
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        # TODO(pkozakowski): "Learning rate schedules" are currently able to
        # control all optimizer parameters and model state, so let's rename them
        # accordingly.
        self._lr_schedule = lr_schedule

        if nontrainable_param_map is None:
            nontrainable_param_map = {}
        self._nontrainable_param_map = nontrainable_param_map

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._lr_fn = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        self.reset(output_dir)
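The `static_argnums=(0,)` above tells the backend to treat the shape/dtype specification as a compile-time constant, so array creation inside the initializer can depend on it. A minimal sketch of that pattern in plain JAX (the initializer and shapes are invented for illustration):

import functools
import jax

@functools.partial(jax.jit, static_argnums=(0,))
def init_weights(shape, rng):
    # `shape` is static (a hashable tuple), so it can drive array creation under jit.
    return 0.01 * jax.random.normal(rng, shape)

w = init_weights((4, 1), jax.random.PRNGKey(0))  # recompiles only if the shape changes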
Example #7
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):

        self._is_chief, self._n_devices, rng = (self._init_host_and_devices(
            n_devices, random_seed))
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at or []
        self._should_write_summaries = should_write_summaries
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        if metrics is not None:
            self._metrics_dict = metrics
        else:
            self._metrics_dict = _DEFAULT_METRICS
            self._metrics_dict['loss'] = loss_fn
        # Inputs is either an Inputs instance or a function that returns it.
        self._inputs = inputs
        # If we pass a function, e.g. through gin, call it.
        if callable(inputs):
            self._inputs = inputs()
        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')
        self._model_with_loss = tl.Serial(model_train, loss_fn)

        # Setup state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        shapes, dtypes = self._inputs.example_shape_dtype
        input_signature = tuple(
            ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))

        def new_opt_state_and_model_state(rng):
            """Returns optimizer and model states suitable for training a model."""
            weights, state = self._model_with_loss.init(input_signature,
                                                        rng=rng)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if math.backend_name() == 'jax':
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = math.jit(
                new_opt_state_and_model_state)
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(init_rng))

        # Arrange and initialize metrics layers.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [self._metrics_dict[m] for m in self._metrics]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel.rng = init_rng
        example_signature = tuple(
            ShapeDtype(s, d)
            for (s, d) in zip(*self._inputs.example_shape_dtype))
        model_predict_eval.init(example_signature)
        self._input_signature = example_signature
        output_signature = model_predict_eval.output_signature(
            example_signature)
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        # TODO(pkozakowski): "Learning rate schedules" are currently able to
        # control all optimizer parameters and model state, so let's rename them
        # accordingly.
        self._lr_schedule = lr_schedule

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._lr_fn = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        self.reset(output_dir)
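The `np.stack(jax_random.split(rng, self._n_devices))` line above prepares one PRNG key per device. A short sketch of how such a per-device key array is typically consumed by a pmapped function (plain JAX, illustrative step function):

import jax
import jax.numpy as jnp

n = jax.local_device_count()
per_device_rngs = jax.random.split(jax.random.PRNGKey(0), n)  # one key per device

@jax.pmap
def noisy_step(x, rng):
    return x + jax.random.normal(rng, x.shape)

out = noisy_step(jnp.zeros((n, 8)), per_device_rngs)  # each device sees its own key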