コード例 #1
0
  def evaluate(self, eval_steps):
    """Evaluates on the train-eval and eval streams and logs the metrics.

    At most `eval_steps` batches are drawn (via `itertools.islice`) from
    `self._train_eval_stream` and then from `self._eval_stream`. Each batch
    set is scored with `evaluation_round` and logged under 'train' / 'eval'.
    Afterwards the current non-trainable parameter values are appended to the
    history.

    NOTE: both rounds use the same rng, derived from `self._rngs[0]`, which
    this method does not advance.

    Args:
      eval_steps: maximum number of batches taken from each stream.
    """
    _, eval_rng = jax_random.split(self._rngs[0])
    # TODO(lukaszkaiser): both model state and parameters by default include
    # the loss layer. Currently, we access the pure-model parameters by just
    # indexing, [0] here. But we should make it more explicit in a better API.
    model_weights = self._opt_state[0][0]
    model_state = self._model_state[0]
    weights = (model_weights, self._metrics_weights)
    state = (model_state, self._metrics_state)

    def log_split(stream, summary_writer, split_name):
      # Score a bounded prefix of `stream` and log it under `split_name`.
      split_metrics, _ = evaluation_round(
          itertools.islice(stream, eval_steps), self._metrics, self._jit_eval,
          weights, state, eval_rng)
      log_metrics(split_metrics, summary_writer, split_name, self._step,
                  history=self._history)

    step_log(self._step, 'Evaluation')
    log_split(self._train_eval_stream, self._train_sw, 'train')
    log_split(self._eval_stream, self._eval_sw, 'eval')
    step_log(self._step, 'Finished evaluation')

    # Record each non-trainable parameter's current value in the history.
    for param_name, param_value in self.nontrainable_params.items():
      self._history.append('train', 'training/{}'.format(param_name),
                           self._step, param_value)
コード例 #2
0
    def evaluation_round(self, inputs_stream, weights, state, rng):
        """Evaluates `self._jit_eval` on a stream of inputs and averages metrics.

        Args:
          inputs_stream: iterable of inputs to evaluate on.
          weights: weights passed unchanged to every `self._jit_eval` call.
          state: state passed unchanged to every `self._jit_eval` call.
          rng: random number generator; split once per input.

        Returns:
          A pair `(metrics, state)`:
            metrics: dict from each name in `self._metrics` to that metric's
              value summed over all inputs and divided by the number of
              inputs. Values beyond `len(self._metrics)` are dropped; an
              empty stream yields `{}`.
            state: the `state` argument, returned unchanged. The state
              produced by `self._jit_eval` is discarded.
        """
        metrics = collections.defaultdict(float)
        count = 0
        for inp in inputs_stream:
            count += 1
            rng, subrng = jax_random.split(rng)
            # `_jit_eval` returns (metric_values, new_state); new_state is unused.
            metric_values, _ = self._jit_eval(inp, weights, state, subrng)
            try:
                metric_values = list(metric_values)
            except TypeError:
                # A single (non-iterable) scalar metric.
                metric_values = [float(metric_values)]
            for m, v in zip(self._metrics, metric_values):
                metrics[m] += v
        return {m: v / count for m, v in metrics.items()}, state
コード例 #3
0
def evaluation_round(inputs_stream, metric_names, eval_fn, weights, state, rng):
  """Evaluates `eval_fn` on a stream of inputs and averages the metrics.

  Args:
    inputs_stream: iterable of inputs to evaluate on.
    metric_names: list of strings, the order in which eval_fn returns metrics.
    eval_fn: function called as `eval_fn(inp, weights, state, rng)` that
      returns a pair `(metric_values, new_state)`. `metric_values` is either
      an iterable of scalar metric values or a single scalar; `new_state` is
      ignored.
    weights: weights passed unchanged to every `eval_fn` call.
    state: state passed unchanged to every `eval_fn` call.
    rng: random number generator; split once per input.

  Returns:
    A pair `(metrics, state)`:
      metrics: dict from metric name to that metric's value summed over all
        inputs and divided by the number of inputs. Values beyond
        `len(metric_names)` are dropped; an empty stream yields `{}`.
      state: the `state` argument, returned unchanged.
  """
  metrics = collections.defaultdict(float)
  count = 0
  for inp in inputs_stream:
    count += 1
    rng, subrng = jax_random.split(rng)
    metric_values, _ = eval_fn(inp, weights, state, subrng)
    try:
      metric_values = list(metric_values)
    except TypeError:
      # A single (non-iterable) scalar metric.
      metric_values = [float(metric_values)]
    for m, v in zip(metric_names, metric_values):
      metrics[m] += v
  return {m: v / count for m, v in metrics.items()}, state
コード例 #4
0
 def single_update(i, opt_state, batch, state, rng):
     """Single-device training step.

     `opt_state` is unpacked as (weights, slots, opt_params) and `rng` is read
     at index 0. Gradients of `model_and_loss_call` (with auxiliary state) are
     passed to `optimizer.tree_update`.

     Returns:
       (result of optimizer.tree_update, updated model state, one-element list
        holding a fresh rng).
     """
     weights, slots, opt_params = opt_state
     step_rng, next_rng = jax_random.split(rng[0])
     compute_grads = backend.grad(model_and_loss_call, has_aux=True)
     gradients, new_state = compute_grads(weights, batch, state, step_rng)
     new_opt_state = optimizer.tree_update(i, gradients, weights, slots,
                                           opt_params)
     return new_opt_state, new_state, [next_rng]
 def _consume_act(self, actions, predict_fn, rng):
     act_repr = self._action_serializer.serialize(actions)
     for (i,
          subrng) in enumerate(jax_random.split(rng,
                                                self._act_repr_length)):
         # Run the network to update the inference buffers, but ignore the result.
         predict_fn(self._last_symbols, rng=subrng)
         self._last_symbols = act_repr[:, i:(i + 1)]
コード例 #6
0
 def predict(x, weights, state, rng):
   """Predict function jited and parallelized as requested.

   `x` is reshaped per device and one rng is drawn per device. The outputs of
   `model_predict` are combined with `backend.combine_devices`, and each output
   leaf is averaged over axis 0.

   Returns:
     (averaged outputs, new state).
   """
   per_device_x = backend.reshape_by_device(x, n_devices)
   per_device_rngs = np.stack(jax_random.split(rng, n_devices))
   outputs, new_state = backend.combine_devices(
       model_predict(per_device_x, weights, state, per_device_rngs))
   averaged = layers.nested_map(lambda y: np.mean(y, axis=0), outputs)
   return averaged, new_state
コード例 #7
0
 def mapped_update(i, opt_state, batch, state, rng):
   """This is a multi-device version of the update function above.

   Gradients are summed across the 'batch' axis with `backend.psum` and then
   divided by the total number of participating devices, so the optimizer
   sees the average gradient rather than the sum.

   Returns:
     (result of optimizer.tree_update, updated model state, fresh rng).
   """
   # We assume all tensors have the first dimension = n_devices.
   weights, slots, opt_params = opt_state
   rng, subrng = jax_random.split(rng)
   grad_fn = backend.grad(model_and_loss_call, has_aux=True)
   grads, state = grad_fn(weights, batch, state, rng)
   # We do a psum(1.0) here instead of `n_devices` since `n_devices` is just
   # the number of devices on this host machine, however psum goes over all
   # devices of all hosts (ex: a TPU pod) and we need to be averaging over all
   # of them. Without the division, the effective step size would scale with
   # the device count.
   grads = jax.tree_util.tree_map(
       lambda g: backend.psum(g, 'batch') / backend.psum(1.0, 'batch'), grads)
   return optimizer.tree_update(
       i, grads, weights, slots, opt_params), state, subrng
コード例 #8
0
ファイル: simulated_env_problem.py プロジェクト: koz4k2/trax
 def initialize_environments(self, batch_size=1, **kwargs):
     """Initializes the environments.

     Resets per-environment step counters, last observations (NaN-filled) and
     last symbols (zeros of shape (batch_size, 1)), delegates to the parent
     class, then initializes the model and keeps its state in
     `self._init_model_state`.
     """
     self._steps = np.zeros(batch_size, dtype=np.int32)
     observation_shape = (batch_size,) + self._observation_space.shape
     self._last_observations = np.full(observation_shape, np.nan)
     self._last_symbols = np.zeros((batch_size, 1), dtype=np.int32)
     parent = super(SerializedSequenceSimulatedEnvProblem, self)
     parent.initialize_environments(batch_size=batch_size, **kwargs)
     init_rng, self._rng = jax_random.split(self._rng)
     _, self._init_model_state = self._model_initialize(
         input_shapes=(batch_size, 1), input_dtype=np.int32, rng=init_rng)
 def _predict_obs(self, predict_fn, rng):
     """Samples a serialized observation symbol by symbol and deserializes it.

     For each of the `self._obs_repr_length` positions, `predict_fn` is applied
     to the previous symbols and the result is passed to `utils.gumbel_sample`.
     The sampled column is stored and becomes the next input.
     """
     batch_size = self._steps.shape[0]
     obs_repr = np.zeros((batch_size, self._obs_repr_length), dtype=np.int32)
     step_rngs = jax_random.split(rng, self._obs_repr_length)
     for position, step_rng in enumerate(step_rngs):
         log_probs = predict_fn(self._last_symbols, rng=step_rng)
         self._last_symbols = utils.gumbel_sample(log_probs)
         obs_repr[:, position] = self._last_symbols[:, 0]
     return self._obs_serializer.deserialize(obs_repr)
コード例 #10
0
 def test_computes(self):
     """FrameStackMLP maps (batch, steps + 1, obs) input to (batch, steps + 1, out)."""
     seed_key = jax_random.get_prng(0)
     output_size = 6
     model = atari_cnn.FrameStackMLP(hidden_sizes=(4, 4),
                                     output_size=output_size)
     batch_size, n_steps, obs_dim = 2, 2, 3
     n_frames = n_steps + 1
     _, init_key = jax_random.split(seed_key)
     _, _ = model.initialize_once((1, 1, obs_dim), onp.float32, init_key)
     x = onp.arange(batch_size * n_frames * obs_dim).reshape(
         batch_size, n_frames, obs_dim)
     y = model(x)
     self.assertEqual((batch_size, n_frames, output_size), y.shape)
    def _reset(self, indices):
        """Resets environments at the given indices.

    Args:
      indices: list of indices of underlying envs to call reset on.

    Returns:
      np.ndarray of batched observations from the reset envs.
    """
        history = next(self._history_stream)
        (subrng, self._rng) = jax_random.split(self._rng)
        return self._reset_model(self._predict_fn, indices, history, subrng)
コード例 #12
0
 def test_computes(self):
     """AtariCnn maps (batch, steps + 1, *obs) input to (batch, steps + 1, out)."""
     seed_key = jax_random.get_prng(0)
     output_size = 6
     model = atari_cnn.AtariCnn(hidden_sizes=(4, 4),
                                output_size=output_size)
     batch_size, n_steps, obs_shape = 2, 2, (28, 28, 3)
     n_frames = n_steps + 1
     _, init_key = jax_random.split(seed_key)
     _, _ = model.initialize_once((1, 1) + obs_shape, onp.float32, init_key)
     obs_size = functools.reduce(op.mul, obs_shape)
     x = onp.arange(batch_size * n_frames * obs_size).reshape(
         batch_size, n_frames, *obs_shape)
     y = model(x)
     self.assertEqual((batch_size, n_frames, output_size), y.shape)
    def _step(self, actions):
        """Takes a step in all environments.

    Args:
      actions: (np.ndarray) with first dimension equal to the batch size.

    Returns:
      a tuple of batched raw observations, raw rewards, dones and infos.
    """
        # Predict the next observation.
        (subrng, self._rng) = jax_random.split(self._rng)
        (observation, reward, done) = self._step_model(self._predict_fn,
                                                       actions, subrng)
        return (observation, reward, done, {})
コード例 #14
0
ファイル: trainer_lib.py プロジェクト: vkataev/trax
 def mapped_update(i, opt_state, batch, state, rng):
   """Multi-device training step with gradients averaged over all devices.

   Returns:
     (result of optimizer.tree_update, updated model state, fresh rng).
   """
   # We assume all tensors have the first dimension = n_devices.
   weights, slots, opt_params = opt_state
   loss_rng, next_rng = jax_random.split(rng)
   compute_grads = backend.grad(model_and_loss_call, has_aux=True)
   grads, new_state = compute_grads(weights, batch, state, loss_rng)
   # psum(1.0) counts devices across all hosts (e.g. a TPU pod); `n_devices`
   # only counts this host's devices, so it would give the wrong average.
   total_devices = backend.psum(1.0, 'batch')
   averaged_grads = jax.tree_util.tree_map(
       lambda g: backend.psum(g, 'batch') / total_devices, grads)
   new_opt_state = optimizer.tree_update(
       i, averaged_grads, weights, slots, opt_params)
   return new_opt_state, new_state, next_rng
コード例 #15
0
 def single_compute_loss(opt_state, batch, state, rng):
     """Single-device loss: `loss_fn` on the weights held in `opt_state[0]`.

     `rng` is read at index 0. Returns (loss value, new state, one-element
     list holding a fresh rng).
     """
     loss_rng, next_rng = jax_random.split(rng[0])
     weights = opt_state[0]
     loss_val, new_state = loss_fn(weights, batch, predict_fn, state, loss_rng)
     return loss_val, new_state, [next_rng]
コード例 #16
0
 def mapped_compute_loss(opt_state, batch, state, rng):
     """Multi-device loss: `loss_fn` on the weights held in `opt_state[0]`.

     Returns (loss value, new state, fresh rng).
     """
     # We assume all tensors have the first dimension = n_devices.
     loss_rng, next_rng = jax_random.split(rng)
     weights = opt_state[0]
     loss_val, new_state = loss_fn(weights, batch, predict_fn, state, loss_rng)
     return loss_val, new_state, next_rng
コード例 #17
0
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 save_steps=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 has_weights=False,
                 nontrainable_param_map=None,
                 mask_id=None,
                 metrics=None):
        """Sets up hosts/devices, RNGs, model and metric layers, and jitted fns.

        Args:
          model: callable invoked as `model(mode=...)` with 'train' and 'eval'
            to build the training and evaluation models.
          loss_fn: callable invoked as `loss_fn(has_weights=..., mask_id=...)`;
            the returned layer is used as the loss.
          optimizer: callable invoked as `optimizer(learning_rate=0.0)`; the
            learning rate is a placeholder that is set later in reset().
          lr_schedule: stored as `self._lr_schedule`. Per the TODO below it
            may control all optimizer parameters and model state, not just
            the learning rate.
          inputs: callable invoked as `inputs(n_devices)`. Its result must
            expose `input_shape`, `target_shape`, `input_dtype` and
            `target_dtype`.
          output_dir: if not None, `self.reset(output_dir)` is called at the
            end of construction.
          random_seed: seed passed to
            `get_random_number_generator_and_set_seed`. When None on more than
            one host, a seed is derived from the current time and host id.
          n_devices: number of devices; defaults to `backend.device_count()`.
            With the JAX backend it must equal the device count.
          save_steps: stored as `self._save_steps`; defaults to [].
          should_save_checkpoints: stored as `self._should_save_checkpoints`.
          should_write_summaries: stored as `self._should_write_summaries`.
          has_weights: if True, one extra float32 weights input per target
            (with the target's shape) is added to the model input signature.
            Also forwarded to the loss and metric constructors.
          nontrainable_param_map: stored as `self._nontrainable_param_map`;
            defaults to {}.
          mask_id: forwarded to the loss and metric constructors.
          metrics: dict from metric name to metric-layer constructor;
            defaults to `_METRICS`.

        Raises:
          ValueError: with the JAX backend, if `n_devices` differs from the
            number of available devices.
        """
        # Non-JAX backends are treated as a single host with id 0.
        if backend.get_name() == 'jax':
            self._host_id = jax.host_id()
            self._host_count = jax.host_count()
        else:
            self._host_id = 0
            self._host_count = 1
        self._is_chief = (self._host_id == 0)

        if save_steps is None:
            save_steps = []
        self._save_steps = save_steps
        self._should_save_checkpoints = should_save_checkpoints
        self._should_write_summaries = should_write_summaries
        self._has_weights = has_weights
        self._mask_id = mask_id
        self._metrics_dict = _METRICS if metrics is None else metrics
        # Instantiate the loss layer; `loss_fn` is rebound to the instance.
        loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
        device_count = backend.device_count()
        n_devices = n_devices or device_count
        # TODO(lukaszkaiser): remove this restriction when possible.
        if n_devices != device_count and backend.get_name() == 'jax':
            raise ValueError(
                'JAX cannot work yet with n_devices != all devices: '
                '%d != %d' % (n_devices, device_count))
        self._n_devices = n_devices

        # Simple differential seeding of RNG across hosts by host_id and time.
        if random_seed is None and self._host_count > 1:
            _, random_seed = divmod(
                int(time.time() * 1e6) + int(self._host_id * 1e6), 2**32)
        rng = get_random_number_generator_and_set_seed(random_seed)
        inputs = inputs(n_devices)
        self._inputs = inputs

        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')

        # Setup state.
        # `init_rng` seeds weight/metric initialization; the other half of the
        # split becomes one RNG per device, stacked along a leading axis.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, n_devices))
        first_shape = inputs.input_shape[0]
        # If the inputs are a tuple/list, add [None] (batch) to each element.
        if isinstance(first_shape, (list, tuple)):
            model_input_shape = tuple(
                tuple([None] + list(shape)) for shape in inputs.input_shape)
            model_target_shape = tuple(
                tuple([None] + list(shape)) for shape in inputs.target_shape)
        else:  # Otherwise just add [None] to the input shape.
            model_input_shape = tuple([None] + list(inputs.input_shape))
            model_target_shape = tuple([None] + list(inputs.target_shape))
        # Change all None to 1 in input and target shape.
        # (`x or 1` also maps a 0 dimension to 1.)
        model_input_shape = backend.nested_map(lambda x: x or 1,
                                               model_input_shape)
        model_target_shape = backend.nested_map(lambda x: x or 1,
                                                model_target_shape)

        def new_opt_state_and_model_state(input_shape, input_dtype,
                                          target_shape, target_dtype, rng):
            """Returns optimizer and model states suitable for training a model."""
            # Combine inputs and targets on the stack.
            # Wrap a single dtype/shape in a one-element list so inputs and
            # targets can be concatenated uniformly.
            if not isinstance(input_dtype, (list, tuple)):
                input_dtype = [input_dtype]
                input_shape = [input_shape]
            if not isinstance(target_dtype, (list, tuple)):
                target_dtype = [target_dtype]
                target_shape = [target_shape]
            dtypes = list(input_dtype) + list(target_dtype)
            shapes = list(input_shape) + list(target_shape)
            if self._has_weights:
                # One float32 weights input per target, shaped like the target.
                shapes += list(target_shape)
                dtypes += [np.float32 for _ in target_dtype]
            input_signature = tuple(
                ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))
            # We need to create a new model instance and not reuse `model_train` here,
            # because `m.initialize` puts cached parameter values in `m` and hence the
            # next call of `m.initialize` will give wrong results.
            m = tl.Serial(model(mode='train'), loss_fn)
            m._set_rng_recursive(rng)  # pylint: disable=protected-access
            weights, state = m.init(input_signature)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if _is_jit_init():
            # JIT parameter initialization to avoid memory fragmentation
            # Shapes and dtypes (args 0-3) are static; only the rng is traced.
            new_opt_state_and_model_state = backend.jit(
                new_opt_state_and_model_state, static_argnums=(0, 1, 2, 3))
        # Initialization is deferred: nothing is built until this is called.
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
                model_input_shape, self._inputs.input_dtype,
                model_target_shape, self._inputs.target_dtype, init_rng))

        # Arrange and initialize metrics layers.
        # Sorting fixes the order of the metric layers (and presumably of the
        # values the eval function returns for them — verify with consumers).
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [
            self._metrics_dict[m](has_weights=self._has_weights,
                                  mask_id=self._mask_id) for m in self._metrics
        ]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        # TODO(lukaszkaiser): clean this up once layer API stabilizes.
        # For now, we need to initialize metric layers somehow, so here we go.
        # We assume that they do not have any parameters, so this is a dummy.
        dummy_shapes = ((1, 2), (1, ),
                        (1, )) if self._has_weights else ((1, 2), (1, ))
        dummy_signature = tuple(ShapeDtype(s) for s in dummy_shapes)
        metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
        m_weights, m_state = metrics_in_parallel.init(dummy_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        # TODO(pkozakowski): "Learning rate schedules" are currently able to
        # control all optimizer parameters and model state, so let's rename them
        # accordingly.
        self._lr_schedule = lr_schedule

        if nontrainable_param_map is None:
            nontrainable_param_map = {}
        self._nontrainable_param_map = nontrainable_param_map

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._lr_fn = None
        self._opt_state = None
        self._step = None
        self._model_state = None

        # Without an output_dir, the fields above stay None until reset() is
        # called explicitly.
        if output_dir is not None:
            self.reset(output_dir)