Example #1
  def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                       **kwargs):
    rng = kwargs.pop('rng', None)
    rngs = (None,) * self._n_layers
    if rng is not None:
      rngs = random.split(rng, self._n_layers)

    def call_compute_residual(x, weights):
      res = self.compute_residual(x, weights=weights, state=state[0],
                                  rng=rngs[0], **kwargs)
      return res

    # Backprop through adding the residual: both summands of the top-of-stack
    # sum receive the same cotangent.
    assert len(ct) == self.n_preserve + 1
    ct = (ct[0], ct[0]) + ct[1:]

    stack_with_residual, vjpfun = jax.vjp(
        call_compute_residual, output, weights[0])
    reconstructed_x = self.subtract_top(
        stack_with_residual, weights=weights[-1], state=state[-1], rng=rngs[-1],
        **kwargs)

    x_ct, residual_weights_ct = vjpfun(ct)
    assert not jax.tree_util.tree_leaves(weights[-1])
    add_top_weights_ct = weights[-1]  # Valid because these weights are empty.
    return reconstructed_x, (x_ct, [residual_weights_ct, add_top_weights_ct])
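
The jax.vjp call above re-runs the forward computation and returns a function that maps output cotangents back to input and weight cotangents. A minimal, self-contained sketch of the same pattern (f, x and w here are hypothetical stand-ins, not Trax code):

import jax
import jax.numpy as jnp

def f(x, w):
  # Stand-in for compute_residual: any differentiable function works.
  return jnp.tanh(x @ w)

x = jnp.ones((2, 3))
w = jnp.ones((3, 3))
out, vjpfun = jax.vjp(f, x, w)           # forward pass plus a backprop closure
x_ct, w_ct = vjpfun(jnp.ones_like(out))  # cotangents w.r.t. x and w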
Example #2
def run_policy(
    policy_and_value_net_apply,
    observations,
    lengths,
    weights,
    state,
    rng,
    action_space,
):
    """Runs the policy network."""
    # TODO(pkozakowski): Pass the actual actions here, to enable autoregressive
    # action sampling.
    (B, T_plus_1) = observations.shape[:2]  # pylint: disable=invalid-name
    dummy_actions = onp.zeros((B, T_plus_1 - 1) + action_space.shape,
                              dtype=action_space.dtype)
    policy_input = (observations, dummy_actions)
    (rng, subrng) = trax_random.split(rng)
    (log_probs, value_preds) = policy_and_value_net_apply(policy_input,
                                                          weights=weights,
                                                          state=state,
                                                          rng=subrng)
    # We need the log_probs of those actions that correspond to the last actual
    # time-step.
    index = lengths - 1  # Convert 1-based lengths to 0-based time indices.
    log_probs = log_probs[np.arange(B), index]
    value_preds = value_preds[np.arange(B), index]
    return (log_probs, value_preds, state, rng)
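
The advanced indexing log_probs[np.arange(B), index] selects, for each sequence in the batch, the network output at that sequence's last real time-step. A small stand-alone sketch of this batched gather (shapes are hypothetical):

import numpy as np

B, T = 3, 5
outputs = np.arange(B * T).reshape(B, T)   # (batch, time)
lengths = np.array([2, 5, 4])              # actual length of each sequence
last = outputs[np.arange(B), lengths - 1]  # one value per batch row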
Example #3
 def single_update(i, opt_state, batch, state, rng):
     weights, slots, opt_params = opt_state
     rng, subrng = jax_random.split(rng[0])
     grad_fn = math.grad(model_and_loss_call, has_aux=True)
     grads, state = grad_fn(weights, batch, state, rng)
     return optimizer.tree_update(i, grads, weights, slots,
                                  opt_params), state, [subrng]
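
math.grad here dispatches to jax.grad on the JAX backend; with has_aux=True the differentiated function must return (loss, aux), and the gradient function returns (grads, aux), which is how the model state is threaded through the update. A minimal sketch with a hypothetical loss:

import jax
import jax.numpy as jnp

def loss_fn(w, x):
  loss = jnp.sum((x @ w) ** 2)
  new_state = {'count': 1}  # auxiliary output carried alongside the loss
  return loss, new_state

grad_fn = jax.grad(loss_fn, has_aux=True)
grads, state = grad_fn(jnp.ones((3, 2)), jnp.ones((4, 3)))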
Example #4
    def evaluation_round(self, inputs_stream, weights, state, rng):
        """Evaluate.

    Args:
      inputs_stream: iterable of inputs to evaluate on.
      weights: weights for each f in eval_fns.
      state: state for each f in eval_fns.
      rng: random number generator.

    Returns:
      metrics: dict from metric name to metric value averaged over the number of
        inputs.
      state: end state for `predict_fn`.
    """
        metrics = collections.defaultdict(float)
        count = 0
        for inp in inputs_stream:
            count += 1
            rng, subrng = jax_random.split(rng)
            metric_values, _ = self._jit_eval(inp, weights, state, subrng)
            try:
                metric_values = list(metric_values)
            except TypeError:
                metric_values = [float(metric_values)]
            for m, v in zip(self._metrics, metric_values):
                metrics[m] += v
        return {m: v / count for (m, v) in six.iteritems(metrics)}, state
Example #5
 def _consume_act(self, actions, predict_fn, rng):
     act_repr = self._action_serializer.serialize(actions)
     for (i, subrng) in enumerate(
         jax_random.split(rng, self._act_repr_length)):
         # Run the network to update the inference buffers, but ignore the result.
         predict_fn(self._last_symbols, rng=subrng)
         self._last_symbols = act_repr[:, i:(i + 1)]
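
jax_random.split(rng, n) yields n independent subkeys (one per loop iteration here), so each serialized symbol is processed with fresh randomness. A minimal sketch of the pattern:

import jax

rng = jax.random.PRNGKey(0)
for i, subrng in enumerate(jax.random.split(rng, 4)):
  # Each draw consumes its own subkey; reusing a key would repeat samples.
  sample = jax.random.uniform(subrng, ())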
Example #6
 def _predict_obs(self, predict_fn, rng):
   obs_repr = np.zeros(
       (self._steps.shape[0], self._obs_repr_length), dtype=np.int32,
   )
   for (i, subrng) in enumerate(jax_random.split(rng, self._obs_repr_length)):
     log_probs = predict_fn(self._last_symbols, rng=subrng)
     self._last_symbols = utils.gumbel_sample(log_probs)
     obs_repr[:, i] = self._last_symbols[:, 0]
   return self._obs_serializer.deserialize(obs_repr)
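
utils.gumbel_sample draws a symbol from the categorical distribution defined by log_probs, presumably via the Gumbel-max trick; a hedged sketch of that trick (the real Trax helper may differ in signature and noise handling):

import jax
import jax.numpy as jnp

def gumbel_sample(log_probs, rng):
  # Gumbel-max: argmax(log p + Gumbel noise) is an exact categorical sample.
  u = jax.random.uniform(rng, log_probs.shape, minval=1e-6, maxval=1.0 - 1e-6)
  gumbel_noise = -jnp.log(-jnp.log(u))
  return jnp.argmax(log_probs + gumbel_noise, axis=-1)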
Example #7
    def _reset(self, indices):
        """Resets environments at the given indices.

    Args:
      indices: list of indices of underlying envs to call reset on.

    Returns:
      np.ndarray of batched observations from the reset envs.
    """
        history = next(self._history_stream)
        (subrng, self._rng) = jax_random.split(self._rng)
        return self._reset_model(self._predict_fn, indices, history, subrng)
Example #8
 def reverse(self, output, weights=(), state=(), new_state=(), rng=None):
   reconstructed_x = output
   rngs = (None,) * self._n_layers
   if rng is not None:
     rngs = random.split(rng, self._n_layers)
   # Note that self.sublayers aligns exactly with self.reverse_layers in
   # terms of parameter and rng usage, so no re-ordering is required.
   for layer, p, s, ns, layer_rng in zip(
       self.reverse_layers, weights, state, new_state, rngs):
     reconstructed_x = layer(reconstructed_x, weights=p,
                             state=s, new_state=ns, rng=layer_rng)
   return reconstructed_x
Example #9
  def forward_and_backward(self, inputs, ct, state, new_state, rng=None):
    # Simultaneous forward pass and backprop through the attention mechanism.
    qkv = inputs[:3]
    passthrough = inputs[3:]
    out_ct = ct[0]
    passthrough_ct = ct[1:]
    if rng is not None:
      # Adjust RNG to match the forward pass.
      rng = random.split(rng, self._n_layers)[0]

    out, qkv_ct = self.attention.forward_and_backward(
        qkv, out_ct, state[0], new_state[0], rng)
    return (out,) + passthrough, qkv_ct + passthrough_ct
Example #10
    def _step(self, actions):
        """Takes a step in all environments.

    Args:
      actions: (np.ndarray) with first dimension equal to the batch size.

    Returns:
      a tuple of batched raw observations, raw rewards, dones and infos.
    """
        # Predict the next observation.
        (subrng, self._rng) = jax_random.split(self._rng)
        (observation, reward, done) = self._step_model(self._predict_fn,
                                                       actions, subrng)
        return (observation, reward, done, {})
Example #11
 def mapped_update(i, opt_state, batch, state, rng):
     """This is a multi-device version of the update function above."""
     # We assume all tensors have the first dimension = n_devices.
     weights, slots, opt_params = opt_state
     rng, subrng = jax_random.split(rng)
     grad_fn = math.grad(model_and_loss_call, has_aux=True)
     grads, state = grad_fn(weights, batch, state, rng)
     # We do a psum(1.0) here instead of `n_devices` since `n_devices` is just
     # the number of devices on this host machine, however psum goes over all
     # devices of all hosts (ex: a TPU pod) and we need to be averaging over all
     # of them.
     grads = jax.tree_util.tree_map(
         lambda g: math.psum(g, 'batch') / math.psum(1.0, 'batch'), grads)
     return optimizer.tree_update(i, grads, weights, slots,
                                  opt_params), state, subrng
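
The psum(1.0) denominator counts every device participating in the named axis, which on multi-host setups (e.g. a TPU pod) exceeds the per-host n_devices. A stand-alone sketch of gradient averaging over a pmapped axis:

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name='batch')
def global_mean(g):
  # Sum over all devices on the axis, divide by the global device count.
  return jax.lax.psum(g, 'batch') / jax.lax.psum(1.0, 'batch')

per_device = jnp.arange(float(jax.local_device_count()))
print(global_mean(per_device))  # every entry equals the global mean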
Example #12
def run_policy_all_timesteps(
    policy_and_value_net_apply,
    observations,
    weights,
    state,
    rng,
    action_space,
):
    """Runs the policy network."""
    # TODO(pkozakowski): Pass the actual actions here, to enable autoregressive
    # action sampling.
    (B, T_plus_1) = observations.shape[:2]  # pylint: disable=invalid-name
    dummy_actions = onp.zeros((B, T_plus_1 - 1) + action_space.shape,
                              dtype=action_space.dtype)
    policy_input = (observations, dummy_actions)
    (rng, subrng) = trax_random.split(rng)
    (log_probs, value_preds) = policy_and_value_net_apply(policy_input,
                                                          weights=weights,
                                                          state=state,
                                                          rng=subrng)
    return log_probs, value_preds, state, rng
Example #13
  def evaluate(self, n_eval_steps):
    """Evaluate the model and log metrics."""
    _, rng = jax_random.split(self._rngs[0])
    # TODO(lukaszkaiser): both model state and parameters by default include
    # the loss layer. Currently, we access the pure-model parameters by just
    # indexing, [0] here. But we should make it more explicit in a better API.
    weights = (self._opt_state[0][0], self._metrics_weights)
    state = (self._model_state[0], self._metrics_state)
    self.log_step('Evaluation')
    train_eval_slice = itertools.islice(self._train_eval_stream, n_eval_steps)
    train_metrics, _ = self.evaluation_round(train_eval_slice, weights, state,
                                             rng)
    self.log_metrics(train_metrics, self._train_sw, 'train')
    eval_slice = itertools.islice(self._eval_stream, n_eval_steps)
    eval_metrics, _ = self.evaluation_round(eval_slice, weights, state, rng)
    self.log_metrics(eval_metrics, self._eval_sw, 'eval')
    self.log_step('Finished evaluation')

    # Save the nontrainable parameters (e.g. the learning rate) in the history.
    for (name, value) in self.nontrainable_params.items():
      self._history.append('train', 'training/{}'.format(name), self._step,
                           value)
Example #14
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 has_weights=False,
                 nontrainable_param_map=None,
                 id_to_mask=None,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):

        self._is_chief, self._n_devices, rng = (self._init_host_and_devices(
            n_devices, random_seed))
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at or []
        self._should_write_summaries = should_write_summaries
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        self._has_weights = has_weights
        self._id_to_mask = id_to_mask
        self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        loss_fn = loss_fn(has_weights=has_weights, id_to_mask=id_to_mask)
        # Inputs is either an Inputs instance or a function that returns it.
        self._inputs = inputs
        # If we pass a function, e.g., through gin, call it.
        if callable(inputs):
            self._inputs = inputs()

        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')

        # Setup state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        # If the inputs are a tuple/list, add [None] (batch) to each element.
        if self._inputs.input_shape and isinstance(self._inputs.input_shape[0],
                                                   (list, tuple)):
            model_input_shape = tuple(
                tuple([None] + list(shape))
                for shape in self._inputs.input_shape)
        else:  # Otherwise just add [None] to the input shape.
            model_input_shape = tuple([None] + list(self._inputs.input_shape))
        # Same for targets.
        if self._inputs.target_shape and isinstance(
                self._inputs.target_shape[0], (list, tuple)):
            model_target_shape = tuple(
                tuple([None] + list(shape))
                for shape in self._inputs.target_shape)
        else:
            model_target_shape = tuple([None] +
                                       list(self._inputs.target_shape))
        # Change all None to 1 in input and target shape.
        model_input_shape = math.nested_map(lambda x: x or 1,
                                            model_input_shape)
        model_target_shape = math.nested_map(lambda x: x or 1,
                                             model_target_shape)

        def new_opt_state_and_model_state(shape_dtype, rng):
            """Returns optimizer and model states suitable for training a model."""
            # Combine inputs and targets on the stack.
            shapes, dtypes = shape_dtype
            input_signature = tuple(
                ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))
            # We need to create a new model instance and not reuse `model_train` here,
            # because `m.initialize` puts cached parameter values in `m` and hence the
            # next call of `m.initialize` will give wrong results.
            m = tl.Serial(model(mode='train'), loss_fn)
            m._set_rng_recursive(rng)  # pylint: disable=protected-access
            weights, state = m.init(input_signature)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if _is_jit_init():
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = math.jit(
                new_opt_state_and_model_state, static_argnums=(0,))
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
                self._inputs.example_shape_dtype, init_rng))

        # Arrange and initialize metrics layers.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [
            self._metrics_dict[m](has_weights=self._has_weights,
                                  id_to_mask=self._id_to_mask)
            for m in self._metrics
        ]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
        example_signature = tuple(
            ShapeDtype(s, d)
            for (s, d) in zip(*self._inputs.example_shape_dtype))
        model_predict_eval.init(example_signature)
        output_signature = model_predict_eval.output_signature(
            example_signature)
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        # TODO(pkozakowski): "Learning rate schedules" are currently able to
        # control all optimizer parameters and model state, so let's rename them
        # accordingly.
        self._lr_schedule = lr_schedule

        if nontrainable_param_map is None:
            nontrainable_param_map = {}
        self._nontrainable_param_map = nontrainable_param_map

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._lr_fn = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        self.reset(output_dir)
Example #15
 def predict(x, weights, state, rng):
     """Predict function jited and parallelized as requested."""
     res, state = _combine_devices(
         model_predict(_reshape_by_device(x, n_devices), weights, state,
                       np.stack(jax_random.split(rng, n_devices))))
     return math.nested_map(lambda y: np.mean(y, axis=0), res), state
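
np.stack(jax_random.split(rng, n_devices)) lays out one subkey per device along the leading axis, matching pmap's convention that axis 0 indexes devices. A small sketch of the layout (shapes assume the classic two-word key format):

import jax
import numpy as onp

rng = jax.random.PRNGKey(0)
n_devices = jax.local_device_count()
device_rngs = onp.stack(jax.random.split(rng, n_devices))
# device_rngs has shape (n_devices, 2): row i seeds the model replica on
# device i when passed through jax.pmap.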
Example #16
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):

        self._is_chief, self._n_devices, rng = (self._init_host_and_devices(
            n_devices, random_seed))
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at or []
        self._should_write_summaries = should_write_summaries
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        if metrics is not None:
            self._metrics_dict = metrics
        else:
            self._metrics_dict = _DEFAULT_METRICS
            self._metrics_dict['loss'] = loss_fn
        # Inputs is either an Inputs instance or a function that returns it.
        self._inputs = inputs
        # If we pass a function, e.g., through gin, call it.
        if callable(inputs):
            self._inputs = inputs()
        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')
        self._model_with_loss = tl.Serial(model_train, loss_fn)

        # Setup state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        shapes, dtypes = self._inputs.example_shape_dtype
        input_signature = tuple(
            ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))

        def new_opt_state_and_model_state(rng):
            """Returns optimizer and model states suitable for training a model."""
            weights, state = self._model_with_loss.init(input_signature,
                                                        rng=rng)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if math.backend_name() == 'jax':
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = math.jit(
                new_opt_state_and_model_state)
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(init_rng))

        # Arrange and initialize metrics layers.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [self._metrics_dict[m] for m in self._metrics]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel.rng = init_rng
        example_signature = tuple(
            ShapeDtype(s, d)
            for (s, d) in zip(*self._inputs.example_shape_dtype))
        model_predict_eval.init(example_signature)
        self._input_signature = example_signature
        output_signature = model_predict_eval.output_signature(
            example_signature)
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        # TODO(pkozakowski): "Learning rate schedules" are currently able to
        # control all optimizer parameters and model state, so let's rename them
        # accordingly.
        self._lr_schedule = lr_schedule

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._lr_fn = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        self.reset(output_dir)
Example #17
 def single_compute_loss(opt_state, batch, state, rng):
     rng, subrng = jax_random.split(rng[0])
     loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng)
     return loss_val, state, [subrng]
Example #18
 def mapped_compute_loss(opt_state, batch, state, rng):
     """This is a multi-device version of the update function above."""
     # We assume all tensors have the first dimension = n_devices.
     rng, subrng = jax_random.split(rng)
     loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng)
     return loss_val, state, subrng
Example #19
  def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(),
                       **kwargs):
    rng = kwargs.pop('rng', None)
    rngs = (None,) * self._n_layers
    if rng is not None:
      rngs = random.split(rng, self._n_layers)

    # Forward pass through self.pre_attention, while preparing for
    # later backprop.
    def call_pre_attention(x, weights):
      res = self.pre_attention(x, weights=weights, state=state[0], rng=rngs[0],
                               **kwargs)
      return res
    stack, pre_attention_vjpfun = jax.vjp(call_pre_attention,
                                          output, weights[0])

    # Backprop through adding the residual
    assert len(ct) == 2
    ct = saved_ct = (ct[0], ct[0], ct[1])

    # Backprop through self.post_attention with respect to the inputs only
    def call_post_attention(x):
      res = self.post_attention(x, weights=weights[2], state=state[2],
                                rng=rngs[2], **kwargs)
      return res
    # Note: these are *not* the actual inputs to self.post_attention.
    # If self.post_attention is not linear, we will get incorrect gradients.
    dummy_inputs = (stack[-3], stack[-2], stack[-1])
    _, post_attention_vjpfun = jax.vjp(call_post_attention, dummy_inputs)
    (ct,) = post_attention_vjpfun(ct)

    # Simultaneous forward pass and backprop through the attention mechanism
    stack, ct = self.attention.forward_and_backward(
        stack, ct, rng=rngs[1], state=state[1], new_state=new_state[1],
        **kwargs)
    assert not jax.tree_util.tree_leaves(weights[1])
    attention_weights_ct = weights[1]  # This is valid when weights is empty.

    # Backprop through self.pre_attention
    x_ct, pre_attention_weights_ct = pre_attention_vjpfun(ct)

    # Forward pass for self.post_attention, and backprop with respect to the
    # parameters only
    def call_post_attention2(weights):
      res = self.post_attention(stack, weights=weights, state=state[2],
                                rng=rngs[2], **kwargs)
      return res
    stack, post_attention_vjpfun = jax.vjp(call_post_attention2, weights[2])
    (post_attention_weights_ct,) = post_attention_vjpfun(saved_ct)

    # Forward pass through subtracting the residual
    reconstructed_x = self.subtract_top(
        stack, weights=weights[-1], state=state[-1], rng=rngs[-1], **kwargs)

    assert not jax.tree_util.tree_leaves(weights[-1])
    add_top_weights_ct = weights[-1]
    weights_ct = [
        pre_attention_weights_ct,
        attention_weights_ct,
        post_attention_weights_ct,
        add_top_weights_ct,
    ]

    return reconstructed_x, (x_ct, weights_ct)