Ejemplo n.º 1
0
    def train_model(self):
        """Train the world model for this epoch on all dumped trajectories."""
        logging.info("Epoch %d: training model", self._simple_epoch)

        # Load data from all epochs.
        # TODO(pkozakowski): Handle the case when the data won't fit in the memory.
        (train_trajectories, eval_trajectories) = self._load_trajectories(
            self._trajectory_dump_root_dir)

        def train_stream():
            # Fresh batch stream over the training trajectories.
            return self._data_stream(
                train_trajectories, self._model_train_batch_size)

        def eval_stream():
            # Fresh batch stream over the evaluation trajectories.
            return self._data_stream(
                eval_trajectories, self._model_train_batch_size)

        def make_inputs(n_devices):
            # Ignore n_devices for now.
            del n_devices
            return trax_inputs.Inputs(
                train_stream=train_stream,
                train_eval_stream=train_stream,
                eval_stream=eval_stream,
                input_shape=self._sim_env.model_input_shape,
                input_dtype=self._sim_env.model_input_dtype,
            )

        # Train up to a cumulative step count that grows each epoch.
        self._model_train_step += self._n_model_train_steps
        trax.train(
            model=self._sim_env.model,
            inputs=make_inputs,
            train_steps=self._model_train_step,
            output_dir=self._model_dir,
            has_weights=True,
        )
Ejemplo n.º 2
0
def test_inputs(n_classes, with_weights=False, input_shape=(6, 6, 3)):
    """Make trax.inputs.Inputs.

    Streams random uniform inputs, random integer targets in
    [0, n_classes), and (optionally) random per-example weights.
    """
    batch_size = 2 * xla_bridge.device_count()

    def input_stream():
        key = backend.random.get_prng(0)
        batch_dims = [batch_size] + list(input_shape)
        while True:
            # Split off three fresh subkeys; carry the first key forward.
            key, inp_key, tgt_key, w_key = backend.random.split(key, 4)
            batch = backend.random.uniform(inp_key, batch_dims)
            labels = backend.random.randint(tgt_key, [batch_size],
                                            dtype=np.int32,
                                            minval=0,
                                            maxval=n_classes)
            # Weights are drawn unconditionally to keep the key usage stable.
            sample_weights = backend.random.uniform(w_key, [batch_size])
            if with_weights:
                yield batch, labels, sample_weights
            else:
                yield batch, labels

    return inputs_lib.Inputs(train_stream=input_stream,
                             train_eval_stream=input_stream,
                             eval_stream=input_stream,
                             input_shape=input_shape,
                             input_dtype=np.float32,
                             target_shape=(),
                             target_dtype=np.int32)
Ejemplo n.º 3
0
 def inputs(n_devices):
   """Make Inputs whose every stream repeats a single all-zeros example."""
   del n_devices  # Ignored for now.
   zero_example = (
       (np.zeros(history_shape), np.zeros(action_shape, dtype=np.int32)),
       (np.zeros(obs_shape), np.zeros(reward_shape)),
   )
   # One shared infinite iterator backs all three streams.
   stream = itertools.repeat(zero_example)
   return trax_inputs.Inputs(
       train_stream=lambda: stream,
       train_eval_stream=lambda: stream,
       eval_stream=lambda: stream,
       input_shape=(history_shape[1:], action_shape[1:]),
       input_dtype=(np.float32, np.int32),
   )
Ejemplo n.º 4
0
def test_inputs(num_classes):
    """Make trax.inputs.Inputs.

    Yields random (input, target) batches forever.
    """
    batch_size = 2
    input_shape = (6, 6, 3)

    def input_stream():
        # Hoist the invariant batch shape out of the loop.
        batch_dims = [batch_size] + list(input_shape)
        while True:
            batch = np.random.rand(*batch_dims)
            labels = np.random.randint(num_classes, size=batch_size)
            yield batch, labels

    return inputs.Inputs(train_stream=input_stream,
                         eval_stream=input_stream,
                         input_shape=input_shape)
Ejemplo n.º 5
0
def test_inputs(n_classes, with_weights=False):
    """Make trax.inputs.Inputs.

    Yields random (input, target[, weight]) batches forever.
    """
    batch_size = 2
    input_shape = (6, 6, 3)

    def input_stream():
        batch_dims = [batch_size] + list(input_shape)
        while True:
            # Draw all three in a fixed order so the global NumPy RNG
            # sequence is identical with and without weights.
            batch = np.random.rand(*batch_dims)
            labels = np.random.randint(n_classes, size=batch_size)
            sample_weights = np.random.rand(batch_size)
            yield ((batch, labels, sample_weights) if with_weights
                   else (batch, labels))

    return inputs_lib.Inputs(train_stream=input_stream,
                             train_eval_stream=input_stream,
                             eval_stream=input_stream,
                             input_shape=input_shape,
                             input_dtype=np.float32)
Ejemplo n.º 6
0
  def train_model(self):
    """Train the world model on the collected train/eval trajectories."""
    logging.info("Epoch %d: training model", self._epoch)

    def train_stream():
      # Fresh batch stream over the training trajectories.
      return self._data_stream(
          self._train_trajectories, self._model_train_batch_size)

    def eval_stream():
      # Fresh batch stream over the evaluation trajectories.
      return self._data_stream(
          self._eval_trajectories, self._model_train_batch_size)

    def make_inputs(n_devices):
      # Ignore n_devices for now.
      del n_devices
      return trax_inputs.Inputs(
          train_stream=train_stream,
          train_eval_stream=train_stream,
          eval_stream=eval_stream,
          input_shape=self._sim_env.model_input_shape,
          input_dtype=self._sim_env.model_input_dtype,
      )

    trax.train(
        model=self._sim_env.model,
        inputs=make_inputs,
        output_dir=self._model_dir,
        has_weights=True,
    )
Ejemplo n.º 7
0
    def train_model(self):
        """Run one round of world-model training for this SimPLe epoch."""
        logging.info("SimPLe epoch [% 6d]: training model.",
                     self._simple_epoch)

        (train_stream, eval_stream) = self._make_input_streams()

        def make_inputs(n_devices):
            # Ignore n_devices for now.
            del n_devices
            # The streams are already-built iterators, so each accessor
            # just hands back the same captured object.
            return trax_inputs.Inputs(
                train_stream=(lambda: train_stream),
                train_eval_stream=(lambda: train_stream),
                eval_stream=(lambda: eval_stream),
                input_shape=self._sim_env.model_input_shape,
                input_dtype=self._sim_env.model_input_dtype,
            )

        # Train up to a cumulative step count that grows each epoch.
        self._model_train_step += self._n_model_train_steps
        trax.train(
            model=self._sim_env.model,
            inputs=make_inputs,
            train_steps=self._model_train_step,
            output_dir=self._model_dir,
            has_weights=True,
        )
Ejemplo n.º 8
0
    def train_model(self):
        """Train the model.

        Returns:
          whether the training was skipped due to a restart.
        """
        logging.info("SimPLe epoch [% 6d]: training model.",
                     self._simple_epoch)
        start_time = time.time()

        (train_stream, eval_stream) = self._make_input_streams()

        def make_inputs(n_devices):
            # Ignore n_devices for now.
            del n_devices
            return trax_inputs.Inputs(
                train_stream=(lambda: train_stream),
                train_eval_stream=(lambda: train_stream),
                eval_stream=(lambda: eval_stream),
                input_shape=self._sim_env.model_input_shape,
                input_dtype=self._sim_env.model_input_dtype,
                # TODO(lukaszkaiser): correct those, they may differ from inputs.
                target_shape=self._sim_env.model_input_shape,
                target_dtype=self._sim_env.model_input_dtype)

        # The first epoch gets a (typically longer) initial training budget.
        train_steps = (self._n_model_initial_train_steps
                       if self._simple_epoch == 0
                       else self._n_model_train_steps_per_epoch)
        self._model_train_step += train_steps
        with gin.config_scope("world_model"):
            state = trax.train(
                model=self._sim_env.model,
                inputs=make_inputs,
                train_steps=self._model_train_step,
                output_dir=self._model_dir,
                has_weights=True,
            )

        logging.vlog(1, "Training model took %0.2f sec.",
                     time.time() - start_time)
        # If the checkpointed step already exceeds this epoch's target,
        # training was skipped due to a restart.
        return state.step > self._model_train_step