Example #1
  def test_no_int32_or_uint32_returned(self):
    """Tests that Trainer._jit_update_fn doesn't return int32 or uint32.

    TF pins int32/uint32 tensors to CPU, which will cause XLA-forced-compiled
    computation to copy int32/uint32 outputs to CPU. This test makes sure that
    won't happen.
    """
    if xla_bridge.device_count() > 1:
      self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
    with fastmath.use_backend(fastmath.Backend.TFNP), \
          self.tmp_dir() as output_dir:
      n_classes = 1001
      model_fn = functools.partial(models.Resnet50,
                                   n_output_classes=n_classes)
      inputs = _test_inputs(n_classes, input_shape=(224, 224, 3))
      trainer = trainer_lib.Trainer(
          model=model_fn,
          loss_fn=tl.CrossEntropyLoss(),
          optimizer=trax_opt.SM3,
          lr_schedule=lr.multifactor(),
          inputs=inputs,
      )
      trainer.reset(output_dir)
      trainer.train_epoch(1, 0)
      # These are the values returned by Trainer._jit_update_fn.
      arrays = (trainer._opt_state.weights, trainer._opt_state.slots,
                trainer._model_state, trainer._rngs)
      arrays = tf.nest.flatten(arrays)
      for x in arrays:
        if isinstance(x, jnp.ndarray) and (x.dtype == jnp.int32 or
                                           x.dtype == jnp.uint32):
          raise ValueError('Found an array of int32 or uint32: %s' % x)
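
Examples #1-#4 call a `_test_inputs` helper that is not shown on this page. Below is a minimal sketch of what such a helper could look like, assuming the `Inputs(stream_fn)` construction pattern that Example #5 uses and random data; the module path, batch size, and default `input_shape` are assumptions, not the library's actual helper:

import numpy as np
from trax.supervised import inputs as trax_inputs  # module path is an assumption

def _test_inputs(n_classes, input_shape=(6, 6, 3)):
  """Hypothetical stand-in for the `_test_inputs` helper the examples call."""
  batch_size = 2

  def stream(n_devices):  # Trax passes the device count to stream functions.
    del n_devices
    while True:
      # Random float features and integer labels in [0, n_classes).
      x = np.random.uniform(size=(batch_size,) + input_shape).astype(np.float32)
      y = np.random.randint(n_classes, size=(batch_size,), dtype=np.int32)
      yield (x, y)

  return trax_inputs.Inputs(stream)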
Example #2
    def test_reset_twice(self, backend):
        if xla_bridge.device_count() > 1 and backend == fastmath.Backend.TFNP:
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with fastmath.use_backend(backend):
            n_classes = 4
            model_fn = functools.partial(models.MLP,
                                         d_hidden=16,
                                         n_output_classes=n_classes)
            inputs = _test_inputs(n_classes)

            trainer = trainer_lib.Trainer(
                model=model_fn,
                loss_fn=tl.CrossEntropyLoss(),
                optimizer=trax_opt.SM3,
                lr_schedule=lr.multifactor(),
                inputs=inputs,
            )

            output_dir1 = self.create_tempdir(name='output_dir1').full_path
            trainer.reset(output_dir1)
            trainer.evaluate(1)
            output_dir2 = self.create_tempdir(name='output_dir2').full_path
            trainer.reset(output_dir2)
            trainer.evaluate(1)
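
Examples #2-#4 receive the backend as a test parameter, which suggests the test class is parameterized. A sketch of how that could be wired up with absl's `parameterized` module; the test-case names and the exact set of backends are assumptions:

from absl.testing import parameterized
from trax import fastmath

class TrainerLibTest(parameterized.TestCase):

  @parameterized.named_parameters(
      ('jax', fastmath.Backend.JAX),
      ('tfnp', fastmath.Backend.TFNP),
  )
  def test_reset_twice(self, backend):
    ...  # body as in Example #2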
Example #3
    def test_reset_twice(self, backend):
        with fastmath.use_backend(backend):
            n_classes = 4
            model_fn = functools.partial(models.MLP,
                                         layer_widths=(16, 16, n_classes))
            inputs = _test_inputs(n_classes)

            trainer = trainer_lib.Trainer(
                model=model_fn,
                loss_fn=tl.WeightedCategoryCrossEntropy(),
                optimizer=trax_opt.SM3,
                lr_schedule=lr.multifactor(),
                inputs=inputs,
            )

            output_dir1 = self.create_tempdir(name='output_dir1').full_path
            trainer.reset(output_dir1)
            trainer.evaluate(1)
            output_dir2 = self.create_tempdir(name='output_dir2').full_path
            trainer.reset(output_dir2)
            trainer.evaluate(1)
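
Examples #2 and #3 differ mainly in the `models.MLP` signature, which changed across Trax versions. Both forms, exactly as they appear above, configure a network ending in `n_classes` outputs:

# Older signature (Examples #2 and #4): hidden width and output size as kwargs.
model_fn = functools.partial(models.MLP, d_hidden=16, n_output_classes=n_classes)

# Newer signature (Example #3): the full tuple of layer widths, output last.
model_fn = functools.partial(models.MLP, layer_widths=(16, 16, n_classes))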
Example #4
  def test_reset_twice(self, backend_name):
    if xla_bridge.device_count() > 1 and backend_name == 'tf':
      self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
    with math.use_backend(backend_name), self.tmp_dir() as output_dir1, \
          self.tmp_dir() as output_dir2:
      n_classes = 4
      model_fn = functools.partial(
          models.MLP, d_hidden=16, n_output_classes=n_classes)
      inputs = test_inputs(n_classes)

      trainer = trainer_lib.Trainer(
          model=model_fn,
          loss_fn=layers.CrossEntropyLoss(),
          optimizer=trax_opt.SM3,
          lr_schedule=lr.MultifactorSchedule,
          inputs=inputs,
      )

      trainer.reset(output_dir1)
      trainer.evaluate(1)
      trainer.reset(output_dir2)
      trainer.evaluate(1)
Example #5
    def test_training_loop_simulated(self):
        n_actions = 5
        history_shape = (3, 2, 3)
        action_shape = (3,)
        obs_shape = (3, 3)
        reward_shape = (3, 1)

        def model(mode):
            del mode
            return layers.Serial(
                layers.Parallel(
                    layers.Flatten(),  # Observation stack.
                    layers.Embedding(d_feature=1,
                                     vocab_size=n_actions),  # Action.
                ),
                layers.Concatenate(),
                layers.Dense(n_units=1),
                layers.Dup(),
                layers.Parallel(
                    layers.Dense(n_units=obs_shape[1]),  # New observation.
                    None,  # Reward.
                ))

        stream = itertools.repeat(
            (np.zeros(history_shape), np.zeros(action_shape, dtype=np.int32),
             np.zeros(obs_shape), np.zeros(reward_shape)))
        inp = trax_inputs.Inputs(lambda _: stream)
        inp._input_shape = (history_shape[1:], action_shape[1:])
        inp._input_dtype = (np.float32, np.int32)
        inp._target_shape = (obs_shape[1:], reward_shape[1:])
        inp._target_dtype = (np.float32, np.float32)
        inputs = inp

        def loss(id_to_mask=None, has_weights=False):
            """Cross-entropy loss as scalar compatible with Trax masking."""
            return layers.Serial(
                # Swap from (pred-obs, pred-reward, target-obs, target-reward)
                # to (pred-obs, target-obs, pred-reward, target-reward).
                layers.Parallel([], layers.Swap()),
                # Cross-entropy loss for obs, L2 loss on reward.
                layers.Parallel(
                    layers.CrossEntropyLoss(id_to_mask, has_weights),
                    layers.L2Loss(id_to_mask, has_weights)),
                # Add both losses.
                layers.Add(),
                # Zero out in this test.
                layers.Fn(lambda x: x * 0.0),
            )

        with self.tmp_dir() as output_dir:
            # Run fake training just to save the parameters.
            trainer = trainer_lib.Trainer(
                model=model,
                loss_fn=loss,
                inputs=inputs,
                optimizer=trax_opt.SM3,
                lr_schedule=lr.MultifactorSchedule,
                output_dir=output_dir,
            )
            trainer.train_epoch(n_steps=1, n_eval_steps=1)

            # Repeat the history over and over again.
            stream = itertools.repeat(np.zeros(history_shape))
            env_fn = functools.partial(
                simulated_env_problem.RawSimulatedEnvProblem,
                model=model,
                history_length=history_shape[1],
                trajectory_length=3,
                batch_size=history_shape[0],
                observation_space=gym.spaces.Box(low=-np.inf,
                                                 high=np.inf,
                                                 shape=(obs_shape[1], )),
                action_space=gym.spaces.Discrete(n=n_actions),
                reward_range=(-1, 1),
                discrete_rewards=False,
                history_stream=stream,
                output_dir=output_dir,
            )

            trainer = self._make_trainer(
                train_env=env_fn(),
                eval_env=env_fn(),
                output_dir=output_dir,
            )
            trainer.training_loop(n_epochs=2)
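
Example #5 also relies on a `self._make_trainer` helper that is not shown. Assuming it wraps one of the `trax.rl` policy trainers, a rough, non-authoritative sketch of such a test-class method might be:

from trax.rl import ppo_trainer  # assumed; the real helper is not shown above

def _make_trainer(self, train_env, eval_env, output_dir):
  """Hypothetical sketch of the helper called in Example #5."""
  # Using PPO here is an assumption; any trainer that exposes
  # training_loop(n_epochs=...) would fit the call site above.
  return ppo_trainer.PPO(
      train_env=train_env,
      eval_env=eval_env,
      output_dir=output_dir,
  )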