Example #1
 def test_computes(self):
     rng_key = jax_random.get_prng(0)
     hidden_size = (4, 4)
     output_size = 6
     model = atari_cnn.FrameStackMLP(hidden_sizes=hidden_size,
                                     output_size=output_size)
     B, T, OBS = 2, 2, 3  # pylint: disable=invalid-name
     rng_key, key = jax_random.split(rng_key)
     params, state = model.initialize((1, 1, OBS), onp.float32, key)
     x = onp.arange(B * (T + 1) * OBS).reshape(B, T + 1, OBS)
     rng_key, key = jax_random.split(rng_key)
     y, _ = model(x, params, state=state, rng=key)
     self.assertEqual((B, T + 1, output_size), y.shape)
Example #2
 def test_computes(self):
     rng_key = jax_random.get_prng(0)
     hidden_size = (4, 4)
     output_size = 6
     model = atari_cnn.AtariCnn(hidden_sizes=hidden_size,
                                output_size=output_size)
     B, T, OBS = 2, 2, (28, 28, 3)  # pylint: disable=invalid-name
     rng_key, key = jax_random.split(rng_key)
     params = model.initialize((1, 1) + OBS, onp.float32, key)
     x = onp.arange(B * (T + 1) * functools.reduce(op.mul, OBS)).reshape(
         B, T + 1, *OBS)
     rng_key, key = jax_random.split(rng_key)
     y = model(x, params, rng=key)
     self.assertEqual((B, T + 1, output_size), y.shape)
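
Note: both tests above follow the same PRNG discipline: split the root key once per consumer, so initialization and the forward pass each get an independent key (get_prng here is the library's backend helper playing the role of jax.random.PRNGKey). A minimal sketch of that pattern in plain JAX; the init_fn/apply_fn pair below is a hypothetical stand-in for the model:

import jax
import jax.numpy as jnp

def init_fn(key, input_dim, output_dim):
    # Draw random weights for a single dense layer.
    w_key, b_key = jax.random.split(key)
    w = jax.random.normal(w_key, (input_dim, output_dim))
    b = jax.random.normal(b_key, (output_dim,))
    return (w, b)

def apply_fn(params, x):
    w, b = params
    return jnp.dot(x, w) + b

rng_key = jax.random.PRNGKey(0)                 # root key
rng_key, init_key = jax.random.split(rng_key)   # sub-key for initialization
params = init_fn(init_key, 3, 6)
x = jnp.arange(2 * 3, dtype=jnp.float32).reshape(2, 3)
rng_key, apply_key = jax.random.split(rng_key)  # fresh sub-key for the call
y = apply_fn(params, x)  # deterministic here; a stochastic layer would consume apply_key
assert y.shape == (2, 6)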
Example #3
def evaluation_round(inputs_stream, metric_names, eval_fn, params, state, rng):
    """Evaluate.

  Args:
    inputs_stream: iterable of inputs to evaluate on.
    metric_names: list of strings, the order in which eval_fn returns metrics.
    eval_fn: metric function, which takes inputs and predictions (and
      params, state, rng) and returns a tuple of scalar metric values.
    params: params for each f in eval_fns.
    state: state for each f in eval_fns.
    rng: random number generator.

  Returns:
    metrics: dict from metric name to metric value averaged over the number of
      inputs.
    state: end state for `predict_fn`.
  """
    metrics = collections.defaultdict(float)
    count = 0
    for inp in inputs_stream:
        count += 1
        rng, subrng = jax_random.split(rng)
        metric_values = eval_fn(inp, params=params, state=state, rng=subrng)
        try:
            metric_values = list(metric_values)
        except TypeError:
            metric_values = [float(metric_values)]
        for m, v in zip(metric_names, metric_values):
            metrics[m] += v
    return {m: v / count for (m, v) in six.iteritems(metrics)}, state
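
Note: the averaging above accumulates a per-metric running sum and divides by the number of batches at the end. A self-contained sketch of just that loop; toy_eval_fn, the metric names and the three-batch stream are made up for illustration:

import collections
import jax
import jax.numpy as jnp

def toy_eval_fn(inp, rng):
    # Pretend these two scalars are real metrics computed from the batch.
    noise = jax.random.uniform(rng)
    return (jnp.mean(inp) + noise, jnp.sum(inp))

metric_names = ["accuracy", "loss"]
metrics = collections.defaultdict(float)
count = 0
rng = jax.random.PRNGKey(0)
for inp in [jnp.ones((4,)), jnp.zeros((4,)), 2.0 * jnp.ones((4,))]:
    count += 1
    rng, subrng = jax.random.split(rng)
    for name, value in zip(metric_names, toy_eval_fn(inp, subrng)):
        metrics[name] += value
averages = {name: total / count for name, total in metrics.items()}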
Example #4
def check_shape_agreement(test_case, init_fun, apply_fun, input_shape):
    rng_key1, rng_key2 = random.split(random.get_prng(0))
    result_shape, params = init_fun(rng_key1, input_shape)
    inputs = random_inputs(onp.random.RandomState(0), input_shape)
    result = apply_fun(params, inputs, rng=rng_key2)
    test_case.assertEqual(result.shape, result_shape)
    return result_shape
Example #5
 def single_update(i, opt_state, batch, state, rng):
     params, slots, opt_params = opt_state
     rng, subrng = jax_random.split(rng[0])
     grad_fn = backend.grad(loss_fn, has_aux=True)
     grads, state = grad_fn(params, batch, predict_fn, state, rng)
     return optimizer.tree_update(i, grads, params, slots,
                                  opt_params), state, [subrng]
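
Note: backend.grad(loss_fn, has_aux=True) above mirrors jax.grad with has_aux=True: the loss function must return (loss, aux), and the gradient call returns (grads, aux), which is how the updated model state is threaded out alongside the gradients. A minimal sketch with a hypothetical stateful loss:

import jax
import jax.numpy as jnp

def loss_fn(params, batch, state):
    x, y = batch
    pred = x * params["w"]
    loss = jnp.mean((pred - y) ** 2)
    new_state = {"step": state["step"] + 1}  # auxiliary output, not differentiated
    return loss, new_state

params = {"w": jnp.array(1.0)}
state = {"step": 0}
batch = (jnp.arange(4.0), 2.0 * jnp.arange(4.0))

grad_fn = jax.grad(loss_fn, has_aux=True)
grads, state = grad_fn(params, batch, state)  # gradients w.r.t. params only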
Example #6
    def evaluate(self, eval_steps):
        """Evaluate the model and log metrics."""
        _, rng = jax_random.split(self._rngs[0])
        # TODO(lukaszkaiser): both model state and parameters by default include
        # the loss layer. Currently, we access the pure-model parameters by just
        # indexing, [0] here. But we should make it more explicit in a better API.
        params = (self._opt_state[0][0], self._metrics_params)
        state = (self._model_state[0], self._metrics_state)
        step_log(self._step, "Evaluation")
        train_eval_slice = itertools.islice(self._train_eval_stream,
                                            eval_steps)
        train_metrics, _ = evaluation_round(train_eval_slice, self._metrics,
                                            self._jit_eval, params, state, rng)
        if self._train_sw:
            log_metrics(train_metrics,
                        self._train_sw,
                        "train",
                        self._step,
                        history=self._history)
        eval_slice = itertools.islice(self._eval_stream, eval_steps)
        eval_metrics, _ = evaluation_round(eval_slice, self._metrics,
                                           self._jit_eval, params, state, rng)
        if self._eval_sw:
            log_metrics(eval_metrics,
                        self._eval_sw,
                        "eval",
                        self._step,
                        history=self._history)
        step_log(self._step, "Finished evaluation")

        # Save the optimizer params in the history
        for (name, value) in self.nontrainable_params.items():
            self._history.append("train", "training/{}".format(name),
                                 self._step, value)
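
Note: itertools.islice(stream, eval_steps) above is what bounds evaluation on an otherwise unbounded input stream; it lazily yields at most eval_steps batches without materializing the rest. A tiny sketch with a made-up infinite generator:

import itertools

def infinite_batches():
    i = 0
    while True:
        yield i
        i += 1

eval_steps = 3
eval_slice = itertools.islice(infinite_batches(), eval_steps)
assert list(eval_slice) == [0, 1, 2]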
Example #7
def evaluate(inputs_stream, predict_fn, metric_fns, state, rng):
    """Evaluate.

  Args:
    inputs_stream: iterable of inputs to evaluate on.
    predict_fn: function from inputs to predictions. params should already be
      partially applied.
    metric_fns: dict from metric name to metric function, which takes inputs
      and predictions and returns a scalar metric value.
    state: start state for `predict_fn`.
    rng: random number generator.

  Returns:
    metrics: dict from metric name to metric value averaged over the number of
      inputs.
    state: end state for `predict_fn`.
  """
    metrics = collections.defaultdict(float)
    count = 0
    for inp in inputs_stream:
        count += 1
        rng, subrng = jax_random.split(rng)
        preds, state = predict_fn(inp[0], state=state, rng=subrng)
        for m, f in six.iteritems(metric_fns):
            metrics[m] += f(inp, preds)
    return {m: v / count for (m, v) in six.iteritems(metrics)}, state
Example #8
def evaluate(inputs_stream, predict_fn, metric_fns, state, rng, has_weights):
    """Evaluate.

  Args:
    inputs_stream: iterable of inputs to evaluate on.
    predict_fn: function from inputs to predictions. params should already be
      partially applied.
    metric_fns: dict from metric name to metric function, which takes inputs
      and predictions and returns a scalar metric value.
    state: start state for `predict_fn`.
    rng: random number generator.
    has_weights: bool, whether weights are included in the inputs.

  Returns:
    metrics: dict from metric name to metric value averaged over the number of
      inputs.
    state: end state for `predict_fn`.
  """
    metrics = collections.defaultdict(float)
    count = 0
    for inp in inputs_stream:
        count += 1
        rng, subrng = jax_random.split(rng)
        model_inp, get_preds = _stack_inputs_targets_and_get_predictions(inp)
        # Call model, preds will be the returned stack, usually (pred, targets).
        preds, state = predict_fn(model_inp, state=state, rng=subrng)
        pred = get_preds(preds)
        for m, f in six.iteritems(metric_fns):
            metrics[m] += f(inp, pred, has_weights=has_weights)
    return {m: v / count for (m, v) in six.iteritems(metrics)}, state
Example #9
 def single_update(i, opt_state, batch, rng):
     rng, subrng = jax_random.split(rng[0])
     params, opt_slots = opt_state
     return optimizer.tree_update(
         i,
         backend.grad(loss_fn)(params, batch, predict_fn, rng), params,
         opt_slots), [subrng]
Example #10
    def _step(self, actions):
        """Takes a step in all environments.

    Args:
      actions: (np.ndarray) with first dimension equal to the batch size.

    Returns:
      a tuple of batched raw observations, raw rewards, dones and infos.
    """
        # Predict the next observation.
        (subrng, self._rng) = jax_random.split(self._rng)
        (observation, reward) = self._model_predict((self._history, actions),
                                                    params=self._model_params,
                                                    rng=subrng)

        # Roll the history one timestep back and append the new observation.
        self._history = np.roll(self._history, shift=-1, axis=1)
        self._history[:, -1, ...] = observation

        # Increment the step counters and determine which envs are done.
        self._steps += 1
        done = self._steps == self._trajectory_length

        # Call copy() to get the data as numpy arrays.
        observation = observation.copy()
        # Reshape the rewards to get rid of the extra dimension.
        reward = np.squeeze(reward.copy(), axis=1)
        return (observation, reward, done, {})
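
Note: the history update above keeps a fixed-length window of recent observations by shifting every frame one slot toward the front and writing the new observation into the last slot. A small NumPy-only sketch of that buffer update (the shapes are made up):

import numpy as np

batch_size, history_length, obs_dim = 2, 4, 3
history = np.zeros((batch_size, history_length, obs_dim))
new_observation = np.ones((batch_size, obs_dim))

# Drop the oldest frame (index 0) and append the newest at the end.
history = np.roll(history, shift=-1, axis=1)
history[:, -1, ...] = new_observation
assert history.shape == (batch_size, history_length, obs_dim)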
Example #11
 def mapped_update(i, opt_state, batch, rng):
     """This is a multi-device version of the update function above."""
     # We assume all tensors have the first dimension = n_devices.
     rng, subrng = jax_random.split(rng)
     params, opt_slots = opt_state
     grads = backend.grad(loss_fn)(params, batch, predict_fn, rng)
     grads = jax.tree_util.tree_map(lambda g: lax.psum(g, "batch"), grads)
     return optimizer.tree_update(i, grads, params, opt_slots), subrng
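
Note: backend.pmap and lax.psum above correspond to jax.pmap with an axis_name and jax.lax.psum over that axis: each device computes gradients on its own shard, and the psum sums them so every replica applies the same update. A minimal, self-contained sketch (the model and learning rate are made up; it runs on however many local devices are visible):

import functools
import jax
import jax.numpy as jnp
from jax import lax

n_devices = jax.local_device_count()

def loss_fn(w, x, y):
    return jnp.mean((x * w - y) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def update(w, x, y):
    grads = jax.grad(loss_fn)(w, x, y)
    # Sum gradients across devices so every replica sees the same update.
    grads = lax.psum(grads, "batch")
    return w - 0.01 * grads

# Replicate the parameter and shard the data: leading dimension == n_devices.
w = jnp.ones((n_devices,))
x = jnp.arange(n_devices * 8, dtype=jnp.float32).reshape(n_devices, 8)
y = 2.0 * x
w = update(w, x, y)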
Example #12
def check_shape_agreement(test_case, layer, input_shape):
    rng_key1, rng_key2 = random.split(random.get_prng(0))
    result_shape = layer.output_shape(input_shape)
    params = layer.initialize(input_shape, rng_key1)
    inputs = random_inputs(onp.random.RandomState(0), input_shape)
    result = layer(inputs, params, rng=rng_key2)
    test_case.assertEqual(result.shape, result_shape)
    return result_shape
Example #13
 def single_update(i, opt_state, batch, rng):
     rng, subrng = jax_random.split(rng[0])
     _, opt_update = optimizer(lr_fun)
     params = trax_opt.get_params(opt_state)
     return opt_update(
         i,
         backend.grad(loss_fun)(params, batch, predict_fun, rng),
         opt_state), [subrng]
Example #14
 def _consume_act(self, actions, predict_fn, rng):
     act_repr = self._action_serializer.serialize(actions)
     for (i,
          subrng) in enumerate(jax_random.split(rng,
                                                self._act_repr_length)):
         # Run the network to update the inference buffers, but ignore the result.
         predict_fn(self._last_symbols, rng=subrng)
         self._last_symbols = act_repr[:, i:(i + 1)]
Example #15
 def _predict_obs(self, predict_fn, rng):
   obs_repr = np.zeros(
       (self._steps.shape[0], self._obs_repr_length), dtype=np.int32,
   )
   for (i, subrng) in enumerate(jax_random.split(rng, self._obs_repr_length)):
     log_probs = predict_fn(self._last_symbols, rng=subrng)
     self._last_symbols = utils.gumbel_sample(log_probs)
     obs_repr[:, i] = self._last_symbols[:, 0]
   return self._obs_serializer.deserialize(obs_repr)
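
Note: both decoding helpers above split one key into a fixed number of sub-keys up front and consume exactly one per autoregressive step. A minimal sketch of that per-step key schedule; the random draw stands in for predict_fn, and vocab_size/num_steps are made up:

import jax
import jax.numpy as jnp

vocab_size, num_steps = 10, 4
rng = jax.random.PRNGKey(0)

last_symbol = jnp.zeros((1,), dtype=jnp.int32)  # analogue of self._last_symbols
decoded = []
for subrng in jax.random.split(rng, num_steps):
    # Stand-in for predict_fn(last_symbol, rng=subrng): draw the next symbol.
    last_symbol = jax.random.randint(subrng, (1,), 0, vocab_size)
    decoded.append(int(last_symbol[0]))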
Example #16
 def mapped_update(i, opt_state, batch, state, rng):
     """This is a multi-device version of the update function above."""
     # We assume all tensors have the first dimension = n_devices.
     params, slots, opt_params = opt_state
     rng, subrng = jax_random.split(rng)
     grad_fn = backend.grad(model_and_loss_call, has_aux=True)
     grads, state = grad_fn(params, batch, state, rng)
     grads = jax.tree_util.tree_map(lambda g: lax.psum(g, "batch"), grads)
     return optimizer.tree_update(i, grads, params, slots,
                                  opt_params), state, subrng
Example #17
  def train_epoch(self, epoch_steps, eval_steps):
    """Train for one epoch."""
    # Log separator
    print()

    # Timer
    start_time = time.time()

    for _ in range(epoch_steps):
      # Train
      next_train_batch = next(self._train_stream)
      if self._n_devices > 1:  # TODO(lukaszkaiser): use everywhere if possible.
        next_train_batch = reshape_by_device(next_train_batch, self._n_devices)
      self._opt_state, self._rngs = self._jit_update_fn(
          self._step, self._opt_state, next_train_batch, self._rngs)
      self._step += 1

      if self._step in self._save_steps:
        _save_replicated(self._opt_state, self._step, self._history,
                         self._n_devices, self._output_dir, True)

      # LR log
      if self._step == 1 or self._step % 10 == 0:
        self._train_sw.scalar("training/learning rate",
                              self._lr_fn(self._step), step=self._step)

    # Timer
    epoch_time = time.time() - start_time
    step_log(self._step, "Ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))
    if epoch_steps > 1:
      self._train_sw.scalar("training/steps per second",
                            epoch_steps / epoch_time, step=self._step)

    # Evaluate in parallel
    _, rng = jax_random.split(self._rngs[0])
    evaluate_train_and_eval(
        step=self._step,
        inputs=self._inputs,
        predict_fn=functools.partial(self._jit_model_predict_eval,
                                     params=self._opt_state[0]),
        eval_steps=eval_steps,
        rng=rng,
        train_sw=self._train_sw,
        eval_sw=self._eval_sw,
        history=self._history)

    # Save state
    _save_replicated(self._opt_state, self._step, self._history,
                     self._n_devices, self._output_dir, False)

    # Flush summary writers
    self._train_sw.flush()
    self._eval_sw.flush()
Example #18
  def _predict_obs(self, predict_fn, rng):
    for (i, subrng) in enumerate(jax_random.split(rng, self._obs_repr_length)):
      symbol_index = self._steps * self._step_repr_length + i
      log_probs, self._model_state = predict_fn(self._history,
                                                state=self._model_state,
                                                rng=subrng)
      log_probs = log_probs[:, symbol_index, :]
      self._history[:, symbol_index] = utils.gumbel_sample(log_probs)

    obs_repr = self._history[self._obs_repr_indices]
    return self._obs_serializer.deserialize(obs_repr)
Example #19
 def initialize_environments(self, batch_size=1, **kwargs):
     """Initializes the environments."""
     self._steps = np.zeros(batch_size, dtype=np.int32)
     self._last_observations = np.full(
         (batch_size, ) + self._observation_space.shape, np.nan)
     self._last_symbols = np.zeros((batch_size, 1), dtype=np.int32)
     super(SerializedSequenceSimulatedEnvProblem,
           self).initialize_environments(batch_size=batch_size, **kwargs)
     (subrng, self._rng) = jax_random.split(self._rng)
     (_, self._init_model_state) = self._model_initialize(
         input_shapes=(batch_size, 1), input_dtype=np.int32, rng=subrng)
Example #20
 def predict(x, params=(), rng=None):
     """Predict function jited and parallelized as requested."""
     # Run the prediction on all devices in parallel.
     pred = mapped_predict(reshape_by_device(x, n_devices), params,
                           jax_random.split(rng, n_devices))
     # Need to reduce the [device, per-device-batch, ...] tensors back to
     # a [batch, ...] tensor. The tensors may be nested.
     if not isinstance(pred, (list, tuple)):  # Not nested.
         batch_size = pred.shape[0] * pred.shape[1]
         return np.reshape(pred, [batch_size] + list(pred.shape[2:]))
     batch_size = pred[0].shape[0] * pred[0].shape[1]
     return [np.reshape(p, [batch_size] + list(p.shape[2:])) for p in pred]
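
Note: after a pmapped call the leading axis is the device axis, so predictions come back shaped [n_devices, per_device_batch, ...]; the reshape above merges the first two axes back into a single batch axis. A tiny NumPy sketch of that round trip (shapes are made up, and the pmapped call itself is elided):

import numpy as np

n_devices, batch_size, feature_dim = 2, 8, 5
x = np.arange(batch_size * feature_dim, dtype=np.float32).reshape(batch_size, feature_dim)

# reshape_by_device: split the batch across devices.
per_device = batch_size // n_devices
sharded = x.reshape(n_devices, per_device, feature_dim)

# ... the pmapped prediction would run here, keeping the leading device axis ...
pred = sharded  # stand-in for the per-device predictions

# Merge [device, per-device-batch, ...] back into [batch, ...].
merged = pred.reshape(pred.shape[0] * pred.shape[1], *pred.shape[2:])
assert merged.shape == (batch_size, feature_dim)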
Example #21
  def _reset(self, indices):
    """Resets environments at the given indices.

    Args:
      indices: list of indices of underlying envs to call reset on.

    Returns:
      np.ndarray of batched observations from the reset envs.
    """
    history = next(self._history_stream)
    (subrng, self._rng) = jax_random.split(self._rng)
    return self._reset_model(self._predict_fn, indices, history, subrng)
Example #22
 def evaluate(self, eval_steps):
     _, rng = jax_random.split(self._rngs[0])
     evaluate_train_and_eval(step=self._step,
                             inputs=self._inputs,
                             predict_fn=functools.partial(
                                 self._jit_model_predict_eval,
                                 params=self._opt_state[0]),
                             eval_steps=eval_steps,
                             rng=rng,
                             train_sw=self._train_sw,
                             eval_sw=self._eval_sw,
                             history=self._history)
Example #23
    def predict(x, params=(), state=(), rng=None):
        """Predict function jited and parallelized as requested."""
        pred, state = mapped_predict(reshape_by_device(x, n_devices), params,
                                     state, jax_random.split(rng, n_devices))

        # Need to reduce the [device, per-device-batch, ...] tensors back to
        # a [batch, ...] tensor. The tensors may be nested.
        def combine(x):
            batch_size = x.shape[0] * x.shape[1]
            return np.reshape(x, [batch_size] + list(x.shape[2:]))

        return layers.nested_map(pred, combine), state
Example #24
  def _predict_obs(self, predict_fn, rng):
    def gumbel_sample(log_probs):
      u = np.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
      g = -np.log(-np.log(u))
      return np.argmax(log_probs + g, axis=-1)

    for (i, subrng) in enumerate(jax_random.split(rng, self._obs_repr_length)):
      symbol_index = self._steps * self._step_repr_length + i
      log_probs = predict_fn(self._history, rng=subrng)[:, symbol_index, :]
      self._history[:, symbol_index] = gumbel_sample(log_probs)

    obs_repr = self._history[self._obs_repr_indices]
    return self._obs_serializer.deserialize(obs_repr)
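
Note: gumbel_sample above is the Gumbel-max trick: adding i.i.d. Gumbel(0, 1) noise to log-probabilities and taking the argmax draws an index from the corresponding categorical distribution. A small sketch of the same trick using JAX's PRNG instead of np.random (the log-probs are made up):

import jax
import jax.numpy as jnp

def gumbel_sample(rng, log_probs):
    # Gumbel(0, 1) noise: -log(-log(U)) for U ~ Uniform(0, 1).
    u = jax.random.uniform(rng, log_probs.shape, minval=1e-6, maxval=1.0 - 1e-6)
    g = -jnp.log(-jnp.log(u))
    return jnp.argmax(log_probs + g, axis=-1)

rng = jax.random.PRNGKey(0)
log_probs = jnp.log(jnp.array([[0.1, 0.2, 0.7]]))  # one categorical distribution
sample = gumbel_sample(rng, log_probs)             # index in {0, 1, 2}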
Example #25
def check_shape_agreement(layer_instance, input_shape, integer_inputs=False):
    """Check if layer.output_shape agrees with the actual output shape."""
    rng1, rng2, rng3 = random.split(random.get_prng(0), 3)
    output_shape = layer_instance.output_shape(input_shape)
    output_shape = nested_map(output_shape, int)  # Make non-numpy.
    params = layer_instance.initialize(input_shape, rng1)
    inputs = _random_inputs(input_shape, rng2, integer_inputs=integer_inputs)
    result = layer_instance(inputs, params, rng=rng3)
    result_shape = shapes(result)
    msg = 'output shape %s != real result shape %s' % (output_shape,
                                                       result_shape)
    assert output_shape == result_shape, msg
    return output_shape
Example #26
  def _step(self, actions):
    """Takes a step in all environments.

    Args:
      actions: (np.ndarray) with first dimension equal to the batch size.

    Returns:
      a tuple of batched raw observations, raw rewards, dones and infos.
    """
    # Predict the next observation.
    (subrng, self._rng) = jax_random.split(self._rng)
    (observation, reward, done) = self._step_model(
        self._predict_fn, actions, subrng)
    return (observation, reward, done, {})
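
Note: the (subrng, self._rng) = jax_random.split(self._rng) idiom used by these environment methods replaces the stored key on every call, so repeated steps never reuse randomness. A minimal sketch of threading a key through object state (the KeyHolder class is hypothetical):

import jax

class KeyHolder:
    def __init__(self, seed=0):
        self._rng = jax.random.PRNGKey(seed)

    def draw(self):
        # Split off a sub-key for this call; keep the other half for next time.
        subrng, self._rng = jax.random.split(self._rng)
        return jax.random.uniform(subrng)

holder = KeyHolder()
values = [float(holder.draw()) for _ in range(3)]  # three distinct draws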
Example #27
 def evaluate(self, eval_steps):
     _, rng = jax_random.split(self._rngs[0])
     _, _, self._model_state = evaluate_train_and_eval(
         step=self._step,
         eval_stream=self._eval_stream,
         train_eval_stream=self._train_eval_stream,
         predict_fn=functools.partial(self._jit_model_predict_eval,
                                      params=self._opt_state[0]),
         eval_steps=eval_steps,
         state=self._model_state,
         rng=rng,
         train_sw=self._train_sw,
         eval_sw=self._eval_sw,
         history=self._history,
         has_weights=self._has_weights)
Example #28
    def predict(x, params=(), rng=None):
        """Predict function jited and parallelized as requested."""
        # On one device, jit and run.
        if num_devices == 1:
            return backend.jit(model_predict)(x, params, rng=rng)

        # Multi-devices, pmap and run.
        @functools.partial(backend.pmap, axis_name="batch")
        def mapped_predict(x, params, rng):
            return model_predict(x, params, rng=rng)

        pred = mapped_predict(reshape_by_device(x, num_devices), params,
                              jax_random.split(rng, num_devices))
        batch_size = x.shape[0]
        return np.reshape(pred, [batch_size] + list(pred.shape[2:]))
Example #29
    def predict(x, params=(), state=(), rng=None):
        """Predict function jited and parallelized as requested."""
        pred = mapped_predict(reshape_by_device(x, n_devices), params, state,
                              jax_random.split(rng, n_devices))

        # Need to reduce the [device, per-device-batch, ...] tensors back to
        # a [batch, ...] tensor. The tensors may be nested.
        def combine(x):
            if len(x.shape) > 1:
                batch_size = x.shape[0] * x.shape[1]
                return np.reshape(x, [batch_size] + list(x.shape[2:]))
            # TODO(lukaszkaiser): is returning averages for scalars the right choice?
            # If it is only scalar, return the average.
            return np.mean(x, axis=0)

        return layers.nested_map(pred, combine)
Example #30
    def evaluate(self, eval_steps):
        """Evaluate the model and log metrics."""
        _, rng = jax_random.split(self._rngs[0])
        _, _, self._model_state = evaluate_train_and_eval(
            step=self._step,
            eval_stream=self._eval_stream,
            train_eval_stream=self._train_eval_stream,
            predict_fn=functools.partial(self._jit_model_predict_eval,
                                         params=self._opt_state[0]),
            eval_steps=eval_steps,
            state=self._model_state,
            rng=rng,
            train_sw=self._train_sw,
            eval_sw=self._eval_sw,
            history=self._history,
            has_weights=self._has_weights)

        # Save the learning rate in the history
        self._history.append("train", "training/learning_rate", self._step,
                             self.learning_rate)