Example #1
    def train_step(self, batch):
        """Run one training step and update self._opt_state."""
        # Calculate the current optimizer parameters.
        # TODO(pkozakowski): Optimizer parameters get polluted with model state,
        # which doesn't break anything but is weird. Filter it out.
        opt_param_updates = self._for_n_devices(
            math.nested_map(np.array, self.nontrainable_params))
        opt_state = self._opt_state
        opt_state.opt_params.update(opt_param_updates)

        # Run the update.
        weights, slots, opt_params = opt_state
        (weights,
         slots), stat, self._model_state, self._rngs = self._jit_update_fn(
             (weights, slots), self._step, opt_params, batch,
             self._model_state, self._rngs)
        self._model_state = self._map_to_state_dicts(self._state_dicts_update)
        self._opt_state = opt_state._replace(weights=weights, slots=slots)
        if self._should_log_now():
            for name, value in stat.items():
                scalar_value = np.mean(value)  # On multiple devices, take the mean.
                self._train_sw.scalar('training/' + name,
                                      scalar_value,
                                      step=self._step)
        self._step += 1
Example #2
 def model_state(self):
     # Currently we need to pick [0] as we ignore loss state (empty).
     state = self._model_state[0]
     if self.n_devices > 1:
         unreplicate = lambda x: x[0]
         state = math.nested_map(unreplicate, state)
     return state
Example #3
    def save_state(self, keep, prefix='model'):
        """Save trainer state given a possibly replicated opt_state."""
        opt_state = self._opt_state
        if self.n_devices > 1:
            first_replica = lambda x: x[0]
            opt_state = OptState(*math.nested_map(first_replica, opt_state))
        # This line, while optional, allows JAX to transfer arrays from the device
        # to the host in parallel, which is particularly important for cloud TPU.
        if math.backend_name() == 'jax':
            opt_state = jax.device_get(opt_state)
        step, history, model_state = self._step, self._history, self._model_state
        output_dir = self._output_dir

        weights_file = os.path.join(output_dir, prefix + '.pkl.gz')

        # This dict will be stored as the model.
        trainer_state_dict = make_trainer_state_dict(step, opt_state, history,
                                                     model_state,
                                                     self._input_signature)
        self._save_state_dict(trainer_state_dict, weights_file)

        if keep:
            weights_file = os.path.join(output_dir,
                                        '{}_{}.pkl.gz'.format(prefix, step))
            self._save_state_dict(trainer_state_dict, weights_file)
Example #4
 def model_weights(self):
     # Currently we need to pick [0] as we ignore loss weights (empty).
     weights = self._opt_state.weights[0]
     if self.n_devices > 1:
         unreplicate = lambda x: x[0]
         weights = math.nested_map(unreplicate, weights)
     return weights
Example #5
    def _forward_abstract(self, input_signature):
        """Computes shapes and dtypes this layer would produce in a forward pass.

        Args:
          input_signature: A ShapeDtype instance (if this layer takes one input)
              or a list/tuple of ShapeDtype instances; signatures of inputs.

        Returns:
          A tuple of (output, state).

          The output part of the tuple is a ShapeDtype instance representing the
          shape and type of the output (if this layer has one output) or a tuple
          of ShapeDtype instances (if this layer has more than one output).
        """
        try:
            # Beware: using an actual RNG (as opposed to this ShapeDtype stub) would
            # cause a large number of dropout masks to be computed and permanently
            # stored in global memory.
            rng = ShapeDtype((2, ), onp.uint32)

            def call_on_input(x, weights, state, rng):
                return self.forward_with_state(x,
                                               weights=weights,
                                               state=state,
                                               rng=rng)

            weight_signature = nested_map(signature, self.weights)
            s = math.abstract_eval(call_on_input)(input_signature,
                                                  weight_signature, self.state,
                                                  rng)
            return s
        except Exception:
            name, trace = self.__class__.__name__, _short_traceback(skip=3)
            raise LayerError(name, '_forward_abstract', self._caller,
                             input_signature, trace)
Example #6
    def _forward_abstract(self, input_signature):
        """Computes shapes and dtypes this layer would produce in a forward pass.

        Args:
          input_signature: `ShapeDtype` instance (if this layer takes one input)
              or list/tuple of `ShapeDtype` instances.

        Returns:
          Tuple of (output, state).

          The output part of the tuple is a `ShapeDtype` instance representing the
          shape and type of the output (if this layer has one output) or a tuple
          of `ShapeDtype` instances (if this layer has more than one output).
        """
        try:
            # Note: By using rng_signature in place of an rng, we avoid computing and
            # permanently storing in global memory a large number of dropout masks.
            # TODO(jonni): Check if using an rng still carries this cost.
            dummy_rng = math.random.get_prng(0)
            rng_signature = ShapeDtype(dummy_rng.shape, dummy_rng.dtype)
            weight_signature = nested_map(signature, self.weights)
            forward_infer_shapes = math.abstract_eval(self.pure_fn)
            return forward_infer_shapes(input_signature, weight_signature,
                                        self.state, rng_signature)
        except Exception:
            # Skipping 13 lines which are all JAX abstract'ifying wrappers.
            name, trace = self._name, _short_traceback(skip=13)
            raise LayerError(name, '_forward_abstract', self._caller,
                             input_signature, trace) from None
Example #7
    def _forward_abstract(self, input_signature):
        """Computes shapes and dtypes this layer would produce in a forward pass.

        Args:
          input_signature: ShapeDtype instance (if this layer takes one input)
              or list/tuple of ShapeDtype instances.

        Returns:
          Tuple of (output, state).

          The output part of the tuple is a ShapeDtype instance representing the
          shape and type of the output (if this layer has one output) or a tuple
          of ShapeDtype instances (if this layer has more than one output).
        """
        try:
            # Note: By using rng_signature in place of an rng, we avoid computing and
            # permanently storing in global memory a large number of dropout masks.
            # TODO(jonni): Check if using an rng still carries this cost.
            rng_signature = ShapeDtype((2, ), np.uint32)
            weight_signature = nested_map(signature, self.weights)
            forward_infer_shapes = math.abstract_eval(self.forward_with_state)
            return forward_infer_shapes(input_signature, weight_signature,
                                        self.state, rng_signature)
        except Exception as e:
            name, trace = self._name, _short_traceback(skip=3)
            raise LayerError(name, '_forward_abstract', self._caller,
                             input_signature, trace) from e
Example #8
def _sizes(x):
  """Get a structure of sizes for a structure of nested arrays."""
  def size(x):
    try:
      return x.size
    except Exception:  # pylint: disable=broad-except
      return 0
  return math.nested_map(size, x)
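The helper above relies on math.nested_map applying its first argument to every array leaf inside nested tuples, lists, and dicts. Below is a minimal plain-Python sketch of that pattern; nested_map_sketch and sizes_sketch are hypothetical names used purely for illustration, not Trax's actual implementation.

import numpy as np

def nested_map_sketch(f, obj):
  """Apply f to every leaf of a nested tuple/list/dict structure."""
  if isinstance(obj, (list, tuple)):
    return type(obj)(nested_map_sketch(f, x) for x in obj)
  if isinstance(obj, dict):
    return {k: nested_map_sketch(f, v) for k, v in obj.items()}
  return f(obj)

def sizes_sketch(x):
  """Same idea as _sizes above: element counts for a structure of arrays."""
  def size(a):
    try:
      return a.size
    except Exception:  # Non-array leaves (e.g. ints) count as 0.
      return 0
  return nested_map_sketch(size, x)

weights = (np.zeros((3, 4)), {'bias': np.zeros(4), 'step': 7})
print(sizes_sketch(weights))  # -> (12, {'bias': 4, 'step': 0})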
Example #9
    def _test_train_eval_predict(self, backend_name):
        if xla_bridge.device_count() > 1 and backend_name == 'tf':
            self.skipTest("tf-numpy backend doesn't support multiple devices yet.")
        with math.use_backend(backend_name), self.tmp_dir() as output_dir:
            # Prepare model and inputs
            n_classes = 4
            steps = 2
            eval_steps = 2

            # Adds Dropout and BatchNorm to test state handling.
            def model_fn(mode='train'):
                return layers.Serial(
                    layers.Dropout(mode=mode, rate=0.1),
                    layers.BatchNorm(mode=mode),
                    models.MLP(d_hidden=16,
                               n_output_classes=n_classes,
                               mode=mode))

            inputs = test_inputs(n_classes)

            # Train and evaluate
            state = trainer_lib.train(output_dir,
                                      model=model_fn,
                                      inputs=inputs,
                                      steps=steps,
                                      eval_steps=eval_steps)

            # Assert total train steps
            self.assertEqual(steps, state.step)

            # Assert 2 evaluations ran
            train_acc = state.history.get('train', 'metrics/accuracy')
            eval_acc = state.history.get('eval', 'metrics/accuracy')
            self.assertEqual(len(train_acc), len(eval_acc))
            self.assertLen(eval_acc, 2)

            # Predict with final weights
            inputs = inputs.train_stream(1)
            model = model_fn()
            weights = state.opt_state.weights[0]
            state = state.model_state[0]
            if xla_bridge.device_count() > 1:
                unreplicate = lambda x: x[0]
                weights = math.nested_map(unreplicate, weights)
                state = math.nested_map(unreplicate, state)
            model(next(inputs)[0], weights=weights, state=state)
Example #10
def _shapes(x):
  """Gets a structure of shapes for a structure of nested arrays."""
  def shape(x):
    try:
      return tuple([int(i) for i in x.shape])
    except Exception:  # pylint: disable=broad-except
      return ()
  return tuple(nested_map(shape, x))
Example #11
def _combine_devices(x_tuple):
  """Combines multi-device tensors into a single batch."""
  def f(x):
    if len(x.shape) < 2:
      return x  # No extra batch dimension: use devices as batch, so return.
    batch_size = x.shape[0] * x.shape[1]
    return math.numpy.reshape(x, [batch_size] + list(x.shape[2:]))
  return math.nested_map(f, x_tuple)
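_combine_devices merges the leading per-device axis back into the batch axis. A numpy-only sketch of that reshape, with made-up shapes, shows the intended before/after:

import numpy as np

n_devices, per_device_batch = 4, 8
sharded = np.ones((n_devices, per_device_batch, 128))   # per-device model outputs
combined = np.reshape(sharded, (n_devices * per_device_batch,) + sharded.shape[2:])
print(combined.shape)  # (32, 128): device axis folded back into the batch axis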
Example #12
 def predict(x, weights, state, rng):
   """Predict function jited and parallelized as requested."""
   res, state = _combine_devices(model_predict(
       reshape_by_device(x, n_devices),
       weights,
       state,
       np.stack(math.random.split(rng, n_devices))))
   return math.nested_map(lambda y: np.mean(y, axis=0), res), state
Example #13
 def _default_timestep_to_np(self, ts):
     """Default way to convert timestep to numpy."""
     return math.nested_map(np.array, (
         ts.observation,
         ts.action,
         ts.dist_inputs,
         ts.reward,
         ts.discounted_return,
     ))
Example #14
 def predict(x, weights, state, rng):
     """Predict function JIT-compileds and parallelized as requested."""
     res, state = _combine_devices(
         model_predict(reshape_by_device(x, n_devices), weights, state,
                       jnp.stack(math.random.split(rng, n_devices))))
     if do_mean:
         return math.nested_map(lambda y: jnp.mean(y, axis=0), res), state
     else:
         return res, state
Example #15
def for_n_devices(x, n_devices):
  """Replicates/broadcasts `x` for `n_devices`."""
  def f(x):
    if n_devices > 1 and math.backend_name() == 'jax':
      return _multi_device_put(x)
    elif n_devices > 1:
      return jnp.broadcast_to(x, (n_devices,) + x.shape)
    else:
      return x
  return math.nested_map(f, x)
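Outside the JAX backend, the replication above reduces to a broadcast that prepends a device axis (the _multi_device_put path is JAX-specific and not shown). A small numpy sketch of the broadcast branch, with illustrative shapes:

import numpy as np

n_devices = 2
w = np.arange(6.0).reshape(2, 3)                         # one copy of some weights
replicated = np.broadcast_to(w, (n_devices,) + w.shape)
print(replicated.shape)                                  # (2, 2, 3): one copy per device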
Example #16
 def print_n_weights(self):
     """Prints the total count of trainable weights."""
     opt_state = self._opt_state
     sizes = _sizes(opt_state.weights)
     if self.n_devices > 1:
         unreplicate = lambda x: x[0]
         single_weights = math.nested_map(unreplicate, opt_state.weights)
         sizes = _sizes(single_weights)
     total_size = _nested_reduce(sum, sizes)
     self.log_step('Total number of trainable weights: %d' % total_size)
Example #17
def dummy_inputs(rng, input_sig):
    def f(sig):
        shape = sig.shape
        if shape and shape[0] is None:
            shape = (2, ) + tuple(shape[1:])
        if onp.issubdtype(sig.dtype, onp.integer):
            minval = None
        else:
            minval = 0
        return rng.uniform(shape=shape, dtype=sig.dtype, minval=minval)

    return math_lib.nested_map(f, input_sig)
Example #18
def reshape_by_device(x, n_devices):
  """Reshapes possibly nested `x` into a shape `(n_devices, ...)`."""
  def f(x):
    x_shape = list(x.shape)
    batch_size = x_shape[0]
    batch_size_per_device = batch_size // n_devices
    if batch_size_per_device * n_devices != batch_size:
      raise ValueError(f'Number of devices ({n_devices}) does not evenly '
                       f'divide batch size ({batch_size}).')
    new_shape_prefix = [n_devices, batch_size_per_device]
    return math.numpy.reshape(x, new_shape_prefix + x_shape[1:])
  return math.nested_map(f, x)
Example #19
def reshape_by_device(x, n_devices):
  """Reshapes possibly nested x into a shape (n_devices, ...)."""
  def f(x):
    x_shape = list(x.shape)
    batch_size = x_shape[0]
    batch_size_per_device = batch_size // n_devices
    if batch_size_per_device * n_devices != batch_size:
      raise ValueError(
          'We require that n_devices[%d] divides batch_size[%d] evenly.' %
          (n_devices, batch_size))
    new_shape_prefix = [n_devices, batch_size_per_device]
    return math.numpy.reshape(x, new_shape_prefix + x_shape[1:])
  return math.nested_map(f, x)
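reshape_by_device is the inverse of _combine_devices: it splits the batch axis into (n_devices, batch_per_device) and raises if the batch does not divide evenly. A quick numpy check of that arithmetic, using made-up sizes:

import numpy as np

n_devices, batch_size = 4, 32
x = np.zeros((batch_size, 10))
assert batch_size % n_devices == 0   # otherwise the functions above raise ValueError
sharded = np.reshape(x, [n_devices, batch_size // n_devices] + list(x.shape[1:]))
print(sharded.shape)                 # (4, 8, 10)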
Example #20
    def _for_n_devices(self, x):
        """Replicates/broadcasts `x` for n devices if `self.n_devicess > 1`."""
        n = self.n_devices

        def f(x):
            if n > 1 and math.backend_name() == 'jax':
                return _multi_device_put(x)
            elif n > 1:
                return np.broadcast_to(x, (n, ) + x.shape)
            else:
                return x

        return math.nested_map(f, x)
Example #21
 def build(self, input_shape):
     with math_lib.use_backend("tf"):
         # Using `is` instead of `==` following Trax's practice
         if self._trax_layer.weights is base.EMPTY_WEIGHTS:
             sanitized_input_shape = math_lib.nested_map(
                 functools.partial(_replace_none_batch,
                                   batch_size=self._batch_size),
                 input_shape)
             weights, state = self._trax_layer.init(
                 tensor_shapes_to_shape_dtypes(sanitized_input_shape,
                                               self.dtype),
                 rng=self._initializer_rng)
         else:
             weights = self._trax_layer.weights
             state = self._trax_layer.state
         # Note: `weights` may contain `EMPTY_WEIGHTS`
         self._weights = math_lib.nested_map(
             functools.partial(tf.Variable, trainable=True), weights)
         self._state = math_lib.nested_map(
             functools.partial(tf.Variable, trainable=False), state)
         self._rng = tf.Variable(self._forward_rng_init, trainable=False)
     super(TraxKerasLayer, self).build(input_shape)
Example #22
 def _default_timestep_to_np(self, ts):
     """Default way to convert timestep to numpy."""
     return math.nested_map(
         np.array,
         TimeStepNp(
             observation=ts.observation,
             action=ts.action,
             dist_inputs=ts.dist_inputs,
             reward=ts.reward,
             done=ts.done,
             return_=ts.discounted_return,
             mask=ts.mask,
         ))
Example #23
 def policy(self, trajectory):
     """Chooses an action to play after a trajectory."""
     model = self._policy_collect_model
     model.weights = self._policy_trainer.model_weights
     tr_slice = trajectory[-self._max_slice_length:]
     trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np)
     # Add batch dimension to trajectory_np and run the model.
     pred = model(trajectory_np.observations[None, ...], n_accelerators=1)
     # Pick element 0 from the batch (the only one), last (current) timestep.
     pred = pred[0, -1, :]
     sample = self._policy_dist.sample(pred)
     result = (sample, pred)
     if math.backend_name() == 'jax':
         result = math.nested_map(lambda x: x.copy(), result)
     return result
Example #24
    def train_step(self, batch):
        """Run one training step and update self._opt_state."""
        # Calculate the current optimizer parameters.
        # TODO(pkozakowski): Optimizer parameters get polluted with model state,
        # which doesn't break anything but is weird. Filter it out.
        opt_param_updates = self._for_n_devices(
            math.nested_map(np.array, self.nontrainable_params))
        opt_state = self._opt_state
        opt_state.opt_params.update(opt_param_updates)

        # Run the update.
        (weights, slots), self._model_state, self._rngs = self._jit_update_fn(
            self._step, opt_state, batch, self._model_state, self._rngs)
        self._model_state = self._map_to_state_dicts(self._state_dicts_update)
        self._opt_state = opt_state._replace(weights=weights, slots=slots)
        self._step += 1
Example #25
    def __call__(self, x, **kwargs):
        """Makes Layer instances callable; for use in tests or interactive settings.

        This convenience method helps library users play with, test, or otherwise
        probe the behavior of layers outside of a full training environment. It
        presents the layer as a callable function from inputs to outputs, with the
        option of manually specifying weights and non-parameter state per individual
        call. For convenience, weights and non-parameter state are cached per layer
        instance, starting from default values of `EMPTY_WEIGHTS` and `EMPTY_STATE`,
        and acquiring non-empty values either by initialization or from values
        explicitly provided via the weights and state keyword arguments.

        Args:
          x: 0 or more input tensors, formatted the same as the inputs to
              Layer.forward.
          **kwargs: Additional keyword arguments if needed/desired for this layer.
              Three possible keyword arguments are especially relevant:
                - weights=... will override any cached weights values
                - state=... will override any cached state values
                - rng=... will supply a PRNG key for use by the layer

        Returns:
          0 or more output tensors, formatted the same as the outputs from
              Layer.forward.
        """
        weights = kwargs.pop('weights', self.weights)
        state = kwargs.pop('state', self.state)
        rng = kwargs.pop('rng', self._rng)
        rng = math.random.get_prng(0) if rng is None else rng
        forward = self._forward_internal
        # TODO(lukaszkaiser): the following arguments are experimental, decide which
        #   are really useful after a number of experiments and finalize the API.
        n_accelerators = kwargs.pop('n_accelerators', 0)
        replicate = kwargs.pop('replicate', True)
        if n_accelerators > 1 and replicate:
            weights = for_n_devices(weights, n_accelerators)
            state = for_n_devices(state, n_accelerators)
        if n_accelerators:
            forward = jit_forward(forward, n_accelerators)
        outputs, new_state = forward(x, weights, state, rng)
        if n_accelerators > 1 and replicate:  # Unreplicate state if needed.
            new_state = math.nested_map(lambda x: x[0], new_state)
        self.state = new_state
        self.weights = weights
        return outputs
Example #26
 def call(self, inputs):
     with math_lib.use_backend("tf"):
         inputs = math_lib.nested_map(
             functools.partial(_replace_none_batch,
                               batch_size=self._batch_size), inputs)
         weights, state, rng = read_values(
             [self._weights, self._state, self._rng])
         inputs, weights, state, rng = to_arrays(
             [inputs, weights, state, rng])
         outputs, new_state = self._trax_layer.pure_fn(inputs,
                                                       weights=weights,
                                                       state=state,
                                                       rng=rng)
         tf.nest.map_structure(lambda v, t: v.assign(t), self._state,
                               new_state)
         self._rng.assign(self._rng_updater(rng))
         outputs = to_tensors(outputs)
         return outputs
Example #27
    def save_state(self, keep):
        """Save trainer state given a possibly replicated opt_state."""
        opt_state = self._opt_state
        if self.n_devices > 1:
            first_replica = lambda x: x[0]
            opt_state = OptState(*math.nested_map(first_replica, opt_state))
        # This line, while optional, allows JAX to transfer arrays from the device
        # to the host in parallel, which is particularly important for cloud TPU.
        if math.backend_name() == 'jax':
            opt_state = jax.device_get(opt_state)
        step, history, model_state = self._step, self._history, self._model_state
        output_dir = self._output_dir

        pkl_module = utils.get_pickle_module()
        weights_file = os.path.join(output_dir, 'model.pkl')
        with tf.io.gfile.GFile(weights_file, 'wb') as f:
            pkl_module.dump((tuple(opt_state), step, history, model_state), f)
        if keep:
            weights_file = os.path.join(output_dir,
                                        'model_{}.pkl'.format(step))
            with tf.io.gfile.GFile(weights_file, 'wb') as f:
                pkl_module.dump((tuple(opt_state), step, history, model_state),
                                f)
        log('Model saved to %s' % weights_file, stdout=False)
Example #28
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 has_weights=False,
                 nontrainable_param_map=None,
                 id_to_mask=None,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):

        self._is_chief, self._n_devices, rng = (self._init_host_and_devices(
            n_devices, random_seed))
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at or []
        self._should_write_summaries = should_write_summaries
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        self._has_weights = has_weights
        self._id_to_mask = id_to_mask
        self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        loss_fn = loss_fn(has_weights=has_weights, id_to_mask=id_to_mask)
        # Inputs is either an Inputs instance or a function that returns it.
        self._inputs = inputs
        if callable(
                inputs):  # If we pass a function, e.g., through gin, call it.
            self._inputs = inputs()

        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')

        # Setup state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        # If the inputs are a tuple/list, add [None] (batch) to each element.
        if self._inputs.input_shape and isinstance(self._inputs.input_shape[0],
                                                   (list, tuple)):
            model_input_shape = tuple(
                tuple([None] + list(shape))
                for shape in self._inputs.input_shape)
        else:  # Otherwise just add [None] to the input shape.
            model_input_shape = tuple([None] + list(self._inputs.input_shape))
        # Same for targets.
        if self._inputs.target_shape and isinstance(
                self._inputs.target_shape[0], (list, tuple)):
            model_target_shape = tuple(
                tuple([None] + list(shape))
                for shape in self._inputs.target_shape)
        else:
            model_target_shape = tuple([None] +
                                       list(self._inputs.target_shape))
        # Change all None to 1 in input and target shape.
        model_input_shape = math.nested_map(lambda x: x or 1,
                                            model_input_shape)
        model_target_shape = math.nested_map(lambda x: x or 1,
                                             model_target_shape)

        def new_opt_state_and_model_state(shape_dtype, rng):
            """Returns optimizer and model states suitable for training a model."""
            # Combine inputs and targets on the stack.
            shapes, dtypes = shape_dtype
            input_signature = tuple(
                ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))
            # We need to create a new model instance and not reuse `model_train` here,
            # because `m.init` puts cached parameter values in `m` and hence the
            # next call of `m.init` would give wrong results.
            m = tl.Serial(model(mode='train'), loss_fn)
            m._set_rng_recursive(rng)  # pylint: disable=protected-access
            weights, state = m.init(input_signature)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if _is_jit_init():
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = math.jit(
                new_opt_state_and_model_state, static_argnums=(0, ))
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
                self._inputs.example_shape_dtype, init_rng))

        # Arrange and initialize metrics layers.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [
            self._metrics_dict[m](has_weights=self._has_weights,
                                  id_to_mask=self._id_to_mask)
            for m in self._metrics
        ]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
        example_signature = tuple(
            ShapeDtype(s, d)
            for (s, d) in zip(*self._inputs.example_shape_dtype))
        model_predict_eval.init(example_signature)
        output_signature = model_predict_eval.output_signature(
            example_signature)
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        # TODO(pkozakowski): "Learning rate schedules" are currently able to
        # control all optimizer parameters and model state, so let's rename them
        # accordingly.
        self._lr_schedule = lr_schedule

        if nontrainable_param_map is None:
            nontrainable_param_map = {}
        self._nontrainable_param_map = nontrainable_param_map

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._lr_fn = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        self.reset(output_dir)
Example #29
 def unreplicate(self, unreplicate_state=False):
     """Unreplicate weights and optionally state. Experimental."""
     self.weights = math.nested_map(lambda x: x[0], self.weights)
     if unreplicate_state:
         self.state = math.nested_map(lambda x: x[0], self.state)
Example #30
    def trajectory_batch_stream(self,
                                batch_size,
                                epochs=None,
                                max_slice_length=None,
                                min_slice_length=None,
                                margin=0,
                                include_final_state=False,
                                sample_trajectories_uniformly=False):
        """Return a stream of trajectory batches from the specified epochs.

        This function returns a stream of tuples of numpy arrays (tensors).
        If tensors have different lengths, they will be padded by 0.

        Args:
          batch_size: the size of the batches to return
          epochs: a list of epochs to use; we use all epochs if None
          max_slice_length: maximum length of the slices of trajectories to return
          min_slice_length: minimum length of the slices of trajectories to return
          margin: number of extra steps after "done" that should be included in
            slices, so that networks see the terminal states in the training data
          include_final_state: whether to include slices with the final state of
            the trajectory, which may have no action and reward
          sample_trajectories_uniformly: whether to sample trajectories uniformly,
            or proportionally to the number of slices in each trajectory (default)

        Yields:
          batches of trajectory slices sampled uniformly from all slices of length
          at least min_slice_length and up to max_slice_length in all specified
          epochs
        """
        def pad(tensor_list):
            # Replace Nones with valid tensors.
            not_none_tensors = [t for t in tensor_list if t is not None]
            assert not_none_tensors, 'All tensors to pad are None.'
            prototype = np.zeros_like(not_none_tensors[0])
            tensor_list = [
                t if t is not None else prototype for t in tensor_list
            ]

            max_len = max([t.shape[0] for t in tensor_list])
            min_len = min([t.shape[0] for t in tensor_list])
            if max_len == min_len:  # No padding needed.
                return np.array(tensor_list)

            pad_len = 2**int(np.ceil(np.log2(max_len)))
            return np.array([
                _zero_pad(t, (0, pad_len - t.shape[0]), axis=0)
                for t in tensor_list
            ])

        cur_batch = []
        for t in self.trajectory_stream(epochs,
                                        max_slice_length,
                                        include_final_state,
                                        sample_trajectories_uniformly,
                                        margin=margin):
            # TODO(pkozakowski): Instead sample the trajectories out of those with
            # the minimum length.
            if min_slice_length is not None and len(t) < min_slice_length:
                continue

            cur_batch.append(t)
            if len(cur_batch) == batch_size:
                # TODO(pkozakowski): Unpack based on name instead of position in the
                # tuple (how?).
                obs, act, dinp, rew, ret, done, mask = zip(
                    *[t.to_np(self._timestep_to_np) for t in cur_batch])
                # Where act, rew and ret will usually have the following shape:
                # [batch_size, trajectory_length-1], which we call [B, L-1].
                # Observations are more complex and will usually be [B, L] + S, where S
                # is the shape of the observation space (self.observation_space.shape).
                # We stop the recursion at level 1, so we pass lists of arrays into
                # pad().
                yield math.nested_map(pad,
                                      TrajectoryNp(
                                          observations=obs,
                                          actions=act,
                                          dist_inputs=dinp,
                                          rewards=rew,
                                          dones=done,
                                          returns=ret,
                                          mask=mask,
                                      ),
                                      level=1)
                cur_batch = []
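The pad helper above rounds the padded length up to the next power of two, which keeps padded batch shapes within a small set of bucket sizes. The sketch below reproduces that length rule with a hypothetical zero_pad_sketch standing in for _zero_pad (which is not shown in this example):

import numpy as np

def zero_pad_sketch(t, pad_width, axis):
  """Pad t with zeros along one axis; illustrative stand-in for _zero_pad."""
  widths = [(0, 0)] * t.ndim
  widths[axis] = pad_width
  return np.pad(t, widths, mode='constant')

lengths = [5, 9, 9]                                  # per-trajectory slice lengths
pad_len = 2 ** int(np.ceil(np.log2(max(lengths))))   # 9 -> 16
batch = np.array(
    [zero_pad_sketch(np.ones(n), (0, pad_len - n), axis=0) for n in lengths])
print(batch.shape)  # (3, 16)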