Example #1
    def test_reset_twice(self, backend_name):
        if xla_bridge.device_count() > 1 and backend_name == "tf":
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with backend.use_backend(backend_name), self.tmp_dir() as output_dir1, \
              self.tmp_dir() as output_dir2:
            n_classes = 4
            model_fn = functools.partial(models.MLP,
                                         d_hidden=16,
                                         n_output_classes=n_classes)
            inputs = lambda _: test_inputs(n_classes)

            trainer = trax.Trainer(
                model=model_fn,
                loss_fn=trax.loss,
                optimizer=trax_opt.SM3,
                lr_schedule=lr.MultifactorSchedule,
                inputs=inputs,
            )

            # Reset into a fresh output dir, evaluate, then repeat with a
            # second dir; both resets should succeed without stale state.
            trainer.reset(output_dir1)
            trainer.evaluate(1)
            trainer.reset(output_dir2)
            trainer.evaluate(1)
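Stripped of the test scaffolding, Example #1 boils down to a single pattern: build the Trainer once, then call reset(output_dir) whenever later work should go to a fresh directory. Below is a minimal sketch of that pattern, assuming tensor2tensor-era import paths (adjust to your installed Trax version) and a placeholder make_inputs factory standing in for the test_inputs() helper used above:

import functools
import tempfile

# Import paths are assumed; the test above only shows the module aliases
# (trax, models, trax_opt, lr), so match these to your Trax installation.
from tensor2tensor.trax import learning_rate as lr
from tensor2tensor.trax import models
from tensor2tensor.trax import optimizers as trax_opt
from tensor2tensor.trax import trax


def make_inputs(n_devices):
    # Placeholder: return an inputs object, as test_inputs() does in the test.
    raise NotImplementedError


trainer = trax.Trainer(
    model=functools.partial(models.MLP, d_hidden=16, n_output_classes=4),
    loss_fn=trax.loss,
    optimizer=trax_opt.SM3,
    lr_schedule=lr.MultifactorSchedule,
    inputs=make_inputs,
)

# reset() retargets checkpoints, summaries and training state; it can be
# called repeatedly on the same Trainer instance.
for run_dir in (tempfile.mkdtemp(), tempfile.mkdtemp()):
    trainer.reset(run_dir)
    trainer.evaluate(1)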
Example #2
  def test_training_loop_simulated(self):
    n_actions = 5
    history_shape = (3, 2, 3)
    action_shape = (3,)
    obs_shape = (3, 3)
    reward_shape = (3, 1)

    def model(mode):
      del mode
      return layers.Serial(
          layers.Parallel(
              layers.Flatten(),  # Observation stack.
              layers.Embedding(d_feature=1, vocab_size=n_actions),  # Action.
          ),
          layers.Concatenate(),
          layers.Dense(n_units=1),
          layers.Dup(),
          layers.Parallel(
              layers.Dense(n_units=obs_shape[1]),  # New observation.
              None,  # Reward.
          )
      )

    def inputs(n_devices):
      del n_devices
      stream = itertools.repeat((
          (np.zeros(history_shape), np.zeros(action_shape, dtype=np.int32)),
          (np.zeros(obs_shape), np.zeros(reward_shape)),
      ))
      return trax_inputs.Inputs(
          train_stream=lambda: stream,
          train_eval_stream=lambda: stream,
          eval_stream=lambda: stream,
          input_shape=(history_shape[1:], action_shape[1:]),
          input_dtype=(np.float32, np.int32),
      )

    def loss(*args, **kwargs):
      del args
      del kwargs
      return 0.0

    with self.tmp_dir() as output_dir:
      # Run fake training just to save the parameters.
      trainer = trax.Trainer(
          model=model,
          loss_fn=loss,
          inputs=inputs,
          optimizer=trax_opt.SM3,
          lr_schedule=lr.MultifactorSchedule,
          output_dir=output_dir,
      )
      trainer.train_epoch(epoch_steps=1, eval_steps=1)

      # Repeat the history over and over again.
      stream = itertools.repeat(np.zeros(history_shape))
      env_fn = functools.partial(
          simulated_env_problem.RawSimulatedEnvProblem,
          model=model,
          history_length=history_shape[1],
          trajectory_length=3,
          batch_size=history_shape[0],
          observation_space=gym.spaces.Box(
              low=-np.inf, high=np.inf, shape=(obs_shape[1],)),
          action_space=gym.spaces.Discrete(n=n_actions),
          reward_range=(-1, 1),
          discrete_rewards=False,
          history_stream=stream,
          output_dir=output_dir,
      )

      self._run_training_loop(
          env=env_fn(),
          eval_env=env_fn(),
          output_dir=output_dir,
      )
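The distinctive piece of this example is the env_fn factory: one functools.partial captures the model and the output_dir holding its saved parameters, and calling it twice yields identically configured training and evaluation environments. The sketch below drives such an environment directly, assuming the batched reset()/step(actions) interface of Trax's EnvProblem classes (an assumption; verify against your Trax version):

import numpy as np

# env_fn is the functools.partial(simulated_env_problem.RawSimulatedEnvProblem, ...)
# factory built in the example above.
train_env = env_fn()
eval_env = env_fn()

# Assumed batched interface: reset() returns one observation per environment,
# step(actions) returns (observations, rewards, dones, infos).
observations = train_env.reset()
actions = np.zeros(3, dtype=np.int32)  # batch_size == history_shape[0] == 3
observations, rewards, dones, infos = train_env.step(actions)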
Example #3
    def test_training_loop_simulated(self):
        n_actions = 5
        history_shape = (3, 2, 3)
        action_shape = (3, )
        obs_shape = (3, 3)
        reward_shape = (3, 1)

        def model(mode):
            del mode
            return layers.Serial(
                layers.Parallel(
                    layers.Flatten(),  # Observation stack.
                    layers.Embedding(d_feature=1,
                                     vocab_size=n_actions),  # Action.
                ),
                layers.Concatenate(),
                layers.Dense(n_units=1),
                layers.Dup(),
                layers.Parallel(
                    layers.Dense(n_units=obs_shape[1]),  # New observation.
                    None,  # Reward.
                ))

        def inputs(n_devices):
            del n_devices
            stream = itertools.repeat(
                (np.zeros(history_shape), np.zeros(action_shape,
                                                   dtype=np.int32),
                 np.zeros(obs_shape), np.zeros(reward_shape)))
            return trax_inputs.Inputs(
                train_stream=lambda: stream,
                train_eval_stream=lambda: stream,
                eval_stream=lambda: stream,
                input_shape=(history_shape[1:], action_shape[1:]),
                input_dtype=(np.float32, np.int32),
                target_shape=(obs_shape[1:], reward_shape[1:]),
                target_dtype=(np.float32, np.float32),
            )

        def loss(mask_id=None, has_weights=False):
            """Cross-entropy loss as scalar compatible with Trax masking."""
            return layers.Serial(
                # Swap from (pred-obs, pred-reward, target-obs, target-reward)
                # to (pred-obs, target-obs, pred-reward, target-reward).
                layers.Parallel([], layers.Swap()),
                # Cross-entropy loss for obs, L2 loss on reward.
                layers.Parallel(
                    layers.CrossEntropyLossScalar(mask_id, has_weights),
                    layers.L2LossScalar(mask_id, has_weights)),
                # Add both losses.
                layers.Add(),
                # Zero out in this test.
                layers.MulConstant(constant=0.0))

        with self.tmp_dir() as output_dir:
            # Run fake training just to save the parameters.
            trainer = trax.Trainer(
                model=model,
                loss_fn=loss,
                inputs=inputs,
                optimizer=trax_opt.SM3,
                lr_schedule=lr.MultifactorSchedule,
                output_dir=output_dir,
            )
            trainer.train_epoch(epoch_steps=1, eval_steps=1)

            # Repeat the history over and over again.
            stream = itertools.repeat(np.zeros(history_shape))
            env_fn = functools.partial(
                simulated_env_problem.RawSimulatedEnvProblem,
                model=model,
                history_length=history_shape[1],
                trajectory_length=3,
                batch_size=history_shape[0],
                observation_space=gym.spaces.Box(low=-np.inf,
                                                 high=np.inf,
                                                 shape=(obs_shape[1], )),
                action_space=gym.spaces.Discrete(n=n_actions),
                reward_range=(-1, 1),
                discrete_rewards=False,
                history_stream=stream,
                output_dir=output_dir,
            )

            trainer = self._make_trainer(
                train_env=env_fn(),
                eval_env=env_fn(),
                output_dir=output_dir,
            )
            trainer.training_loop(n_epochs=2)
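Ignoring the layer machinery, the Serial(...) loss in this variant is a simple dataflow: regroup the four streams into (prediction, target) pairs, apply one scalar loss per pair, add the two scalars, and multiply by zero so the training step in this test is a deliberate no-op. Here is a plain-Python mirror of that flow, with hypothetical obs_loss and reward_loss callables standing in for CrossEntropyLossScalar and L2LossScalar:

def combined_loss(pred_obs, pred_reward, target_obs, target_reward,
                  obs_loss, reward_loss, scale=0.0):
    """Illustrative mirror of the Serial(...) loss above; not Trax code."""
    # Parallel([], Swap()): regroup into (pred_obs, target_obs) and
    # (pred_reward, target_reward) pairs.
    obs_term = obs_loss(pred_obs, target_obs)              # cross-entropy term
    reward_term = reward_loss(pred_reward, target_reward)  # L2 term
    # Add() followed by MulConstant(constant=0.0): sum, then zero out.
    return scale * (obs_term + reward_term)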