def test_no_int32_or_uint32_returned(self):
        """Tests that Trainer._jit_update_fn doesn't return int32 or uint32.

        TF pins int32/uint32 tensors to CPU, which will cause XLA-forced-compiled
        computation to copy int32/uint32 outputs to CPU. This test makes sure
        that won't happen.
        """
        if xla_bridge.device_count() > 1:
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with backend.use_backend('tf'), self.tmp_dir() as output_dir:
            n_classes = 1001
            model_fn = functools.partial(models.Resnet50,
                                         n_output_classes=n_classes)
            inputs = lambda _: test_inputs(n_classes,
                                           input_shape=(224, 224, 3))
            trainer = trainer_lib.Trainer(
                model=model_fn,
                loss_fn=layers.CrossEntropyLossScalar,
                optimizer=trax_opt.SM3,
                lr_schedule=lr.MultifactorSchedule,
                inputs=inputs,
            )
            trainer.reset(output_dir)
            trainer.train_epoch(1, 0)
            # Those are the things returned by Trainer._jit_update_fn.
            arrays = (trainer._opt_state.weights, trainer._opt_state.slots,
                      trainer._model_state, trainer._rngs)
            # Only leaves that are `np.ndarray` instances are inspected;
            # any other leaf type is skipped.
            forbidden_dtypes = (np.int32, np.uint32)
            offending = [
                x for x in tf.nest.flatten(arrays)
                if isinstance(x, np.ndarray) and x.dtype in forbidden_dtypes
            ]
            if offending:
                # self.fail reports a test *failure* (AssertionError) rather
                # than an unexpected *error* as a raised ValueError would.
                self.fail('Found %d array(s) of int32 or uint32, e.g.: %s' %
                          (len(offending), offending[0]))
# Example #2 (0)
  def test_reset_twice(self, backend_name):
    """Resets one Trainer onto two output dirs in turn, evaluating after each."""
    if xla_bridge.device_count() > 1 and backend_name == "tf":
      self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
    with backend.use_backend(backend_name):
      with self.tmp_dir() as first_dir, self.tmp_dir() as second_dir:
        num_classes = 4
        trainer = trainer_lib.Trainer(
            model=functools.partial(
                models.MLP, d_hidden=16, n_output_classes=num_classes),
            loss_fn=layers.CrossEntropyLossScalar,
            optimizer=trax_opt.SM3,
            lr_schedule=lr.MultifactorSchedule,
            inputs=lambda _: test_inputs(num_classes),
        )
        # The second reset re-targets the already-used trainer at a new dir.
        for out_dir in (first_dir, second_dir):
          trainer.reset(out_dir)
          trainer.evaluate(1)
# Example #3 (0)
    def test_training_loop_simulated(self):
        """Runs the RL training loop against simulated environments.

        A tiny world model is first trained for one step with a loss that is
        multiplied by zero; that run exists only to save the model's
        parameters into `output_dir`. Simulated env problems are then built
        from the same `model` and `output_dir`, and the RL trainer's
        `training_loop` is run for two epochs on them.
        """
        n_actions = 5
        # Leading dims are used as (batch_size, history_length) in env_fn.
        history_shape = (3, 2, 3)
        action_shape = (3, )
        obs_shape = (3, 3)
        reward_shape = (3, 1)

        def model(mode):
            """World model: (observation stack, action) -> (new obs, reward)."""
            del mode  # The same architecture is built for every mode.
            return layers.Serial(
                layers.Parallel(
                    layers.Flatten(),  # Observation stack.
                    layers.Embedding(d_feature=1,
                                     vocab_size=n_actions),  # Action.
                ),
                layers.Concatenate(),
                layers.Dense(n_units=1),
                layers.Dup(),
                layers.Parallel(
                    layers.Dense(n_units=obs_shape[1]),  # New observation.
                    None,  # Reward.
                ))

        def inputs(n_devices):
            """Endless all-zero (history, action, obs, reward) batches."""
            del n_devices
            stream = itertools.repeat(
                (np.zeros(history_shape),
                 np.zeros(action_shape, dtype=np.int32),
                 np.zeros(obs_shape),
                 np.zeros(reward_shape)))
            return trax_inputs.Inputs(
                train_stream=lambda: stream,
                train_eval_stream=lambda: stream,
                eval_stream=lambda: stream,
                input_shape=(history_shape[1:], action_shape[1:]),
                input_dtype=(np.float32, np.int32),
                target_shape=(obs_shape[1:], reward_shape[1:]),
                target_dtype=(np.float32, np.float32),
            )

        def loss(mask_id=None, has_weights=False):
            """Observation cross-entropy plus reward L2 loss, scaled to zero.

            Expects (pred-obs, pred-reward, target-obs, target-reward) and
            forwards `mask_id` / `has_weights` to both component losses. The
            summed loss is multiplied by 0.0 in this test (presumably because
            the test only needs training to run, not to learn).
            """
            return layers.Serial(
                # Swap from (pred-obs, pred-reward, target-obs, target-reward)
                # to (pred-obs, target-obs, pred-reward, target-reward).
                layers.Parallel([], layers.Swap()),
                # Cross-entropy loss for obs, L2 loss on reward.
                layers.Parallel(
                    layers.CrossEntropyLossScalar(mask_id, has_weights),
                    layers.L2LossScalar(mask_id, has_weights)),
                # Add both losses.
                layers.Add(),
                # Zero out in this test.
                layers.MulConstant(constant=0.0))

        with self.tmp_dir() as output_dir:
            # Run fake training just to save the parameters.
            model_trainer = trainer_lib.Trainer(
                model=model,
                loss_fn=loss,
                inputs=inputs,
                optimizer=trax_opt.SM3,
                lr_schedule=lr.MultifactorSchedule,
                output_dir=output_dir,
            )
            model_trainer.train_epoch(epoch_steps=1, eval_steps=1)

            # Repeat the (all-zero) history over and over again.
            history_stream = itertools.repeat(np.zeros(history_shape))
            env_fn = functools.partial(
                simulated_env_problem.RawSimulatedEnvProblem,
                model=model,
                history_length=history_shape[1],
                trajectory_length=3,
                batch_size=history_shape[0],
                observation_space=gym.spaces.Box(low=-np.inf,
                                                 high=np.inf,
                                                 shape=(obs_shape[1], )),
                action_space=gym.spaces.Discrete(n=n_actions),
                reward_range=(-1, 1),
                discrete_rewards=False,
                history_stream=history_stream,
                # NOTE(review): presumably the env restores the parameters
                # saved above from this directory -- confirm.
                output_dir=output_dir,
            )

            rl_trainer = self._make_trainer(
                train_env=env_fn(),
                eval_env=env_fn(),
                output_dir=output_dir,
            )
            rl_trainer.training_loop(n_epochs=2)