Example #1
# Imports assumed from the trax/JAX versions these snippets target; `backend`
# is trax's own numpy/JAX backend wrapper.
import numpy as np
from jax.lib import xla_bridge

from trax import backend
from trax import inputs as inputs_lib


def test_inputs(n_classes, with_weights=False, input_shape=(6, 6, 3)):
    """Make trainer_lib.inputs.Inputs."""
    batch_size = 2 * xla_bridge.device_count()  # Two examples per device.

    def input_stream():
        # Infinite stream of random (inputs, targets[, weights]) batches,
        # generated from a fixed PRNG seed so the data is reproducible.
        key = backend.random.get_prng(0)
        while True:
            keys = backend.random.split(key, 4)
            key = keys[0]  # Carry one subkey forward for the next iteration.
            inputs = backend.random.uniform(keys[1],
                                            [batch_size] + list(input_shape))
            targets = backend.random.randint(keys[2], [batch_size],
                                             dtype=np.int32,
                                             minval=0,
                                             maxval=n_classes)
            weights = backend.random.uniform(keys[3], [batch_size])
            if with_weights:
                yield inputs, targets, weights
            else:
                yield inputs, targets

    # The same random stream serves training and both evaluation modes.
    return inputs_lib.Inputs(train_stream=input_stream,
                             train_eval_stream=input_stream,
                             eval_stream=input_stream,
                             input_shape=input_shape,
                             input_dtype=np.float32,
                             target_shape=(),
                             target_dtype=np.int32)
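As Example #3 below shows, trainer_lib.train expects `inputs` to be a callable taking the device count, so this factory is typically wrapped in a lambda before being handed to the trainer. A minimal usage sketch, assuming `trainer_lib` is imported as in the original test file; the model constructor, step count, and output directory are hypothetical placeholders:

state = trainer_lib.train(
    model=my_model,                              # hypothetical model constructor
    inputs=lambda _: test_inputs(n_classes=10),  # ignore n_devices, as above
    train_steps=10,
    output_dir='/tmp/test_inputs_demo',          # hypothetical output directory
    has_weights=False,                           # True if with_weights=True above
)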
Example #2
# `itertools`, `numpy as np` and `trax_inputs` (trax's inputs module) are
# assumed to be imported in the enclosing file; the *_shape variables come
# from the surrounding scope.
def inputs(n_devices):
    del n_devices  # The constant stream below is independent of device count.
    # Repeat a single all-zeros example forever: (history, action) form the
    # inputs, (observation, reward) form the targets.
    stream = itertools.repeat(
        (np.zeros(history_shape), np.zeros(action_shape, dtype=np.int32),
         np.zeros(obs_shape), np.zeros(reward_shape)))
    return trax_inputs.Inputs(
        train_stream=lambda: stream,
        train_eval_stream=lambda: stream,
        eval_stream=lambda: stream,
        input_shape=(history_shape[1:], action_shape[1:]),
        input_dtype=(np.float32, np.int32),
        target_shape=(obs_shape[1:], reward_shape[1:]),
        target_dtype=(np.float32, np.float32),
    )
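Unlike Example #1, each stream element here is a tuple of four arrays, so `input_shape`/`input_dtype` and `target_shape`/`target_dtype` are themselves tuples with one entry per array, and the `[1:]` slices drop the leading batch dimension from the per-example shapes.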
Example #3
    # `logging` (absl), `time`, `gin`, `trainer_lib` and `trax_inputs` are
    # assumed to be imported at module level in the original file.
    def train_model(self):
        """Train the model.

        Returns:
          whether the training was skipped due to a restart.
        """
        logging.info('SimPLe epoch [% 6d]: training model.',
                     self._simple_epoch)
        start_time = time.time()

        (train_stream, eval_stream) = self._make_input_streams()
        # Ignore n_devices for now.
        inputs = lambda _: trax_inputs.Inputs(  # pylint: disable=g-long-lambda
            train_stream=(lambda: train_stream),
            train_eval_stream=(lambda: train_stream),
            eval_stream=(lambda: eval_stream),
            input_shape=self._sim_env.model_input_shape,
            input_dtype=self._sim_env.model_input_dtype,
            # TODO(lukaszkaiser): correct those, they may differ from inputs.
            target_shape=self._sim_env.model_input_shape,
            target_dtype=self._sim_env.model_input_dtype)

        if self._simple_epoch == 0:
            train_steps = self._n_model_initial_train_steps
        else:
            train_steps = self._n_model_train_steps_per_epoch
        self._model_train_step += train_steps
        with gin.config_scope('world_model'):
            state = trainer_lib.train(
                model=self._sim_env.model,
                inputs=inputs,
                train_steps=self._model_train_step,
                output_dir=self._model_dir,
                has_weights=True,
            )

        logging.vlog(1, 'Training model took %0.2f sec.',
                     time.time() - start_time)
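        # A checkpoint step count already past the requested total means this
        # epoch's model training was skipped because of a restart.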
        return state.step > self._model_train_step