def test_no_int32_or_uint32_returned(self):
  """Tests that Trainer._jit_update_fn doesn't return int32 or uint32.

  TF pins int32/uint32 tensors to CPU, which will cause XLA-forced-compiled
  computation to copy int32/uint32 outputs to CPU. This test makes sure that
  won't happen.
  """
  if xla_bridge.device_count() > 1:
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend('tf'), self.tmp_dir() as output_dir:
    n_classes = 1001
    model_fn = functools.partial(models.Resnet50,
                                 n_output_classes=n_classes)
    inputs = lambda _: test_inputs(n_classes, input_shape=(224, 224, 3))

    trainer = trainer_lib.Trainer(
        model=model_fn,
        loss_fn=layers.CrossEntropyLossScalar,
        optimizer=trax_opt.SM3,
        lr_schedule=lr.MultifactorSchedule,
        inputs=inputs,
    )

    trainer.reset(output_dir)
    trainer.train_epoch(1, 0)

    # These are the things returned by Trainer._jit_update_fn.
    arrays = (trainer._opt_state.weights, trainer._opt_state.slots,
              trainer._model_state, trainer._rngs)
    arrays = tf.nest.flatten(arrays)
    for x in arrays:
      if isinstance(x, np.ndarray) and (x.dtype == np.int32 or
                                        x.dtype == np.uint32):
        raise ValueError('Found an array of int32 or uint32: %s' % x)
def test_reset_twice(self, backend_name):
  if xla_bridge.device_count() > 1 and backend_name == "tf":
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name), self.tmp_dir() as output_dir1, \
      self.tmp_dir() as output_dir2:
    n_classes = 4
    model_fn = functools.partial(
        models.MLP, d_hidden=16, n_output_classes=n_classes)
    inputs = lambda _: test_inputs(n_classes)

    trainer = trainer_lib.Trainer(
        model=model_fn,
        loss_fn=layers.CrossEntropyLossScalar,
        optimizer=trax_opt.SM3,
        lr_schedule=lr.MultifactorSchedule,
        inputs=inputs,
    )

    trainer.reset(output_dir1)
    trainer.evaluate(1)

    trainer.reset(output_dir2)
    trainer.evaluate(1)
def test_training_loop_simulated(self):
  n_actions = 5
  history_shape = (3, 2, 3)
  action_shape = (3,)
  obs_shape = (3, 3)
  reward_shape = (3, 1)

  def model(mode):
    del mode
    return layers.Serial(
        layers.Parallel(
            layers.Flatten(),  # Observation stack.
            layers.Embedding(d_feature=1, vocab_size=n_actions),  # Action.
        ),
        layers.Concatenate(),
        layers.Dense(n_units=1),
        layers.Dup(),
        layers.Parallel(
            layers.Dense(n_units=obs_shape[1]),  # New observation.
            None,  # Reward.
        ),
    )

  def inputs(n_devices):
    del n_devices
    stream = itertools.repeat(
        (np.zeros(history_shape), np.zeros(action_shape, dtype=np.int32),
         np.zeros(obs_shape), np.zeros(reward_shape)))
    return trax_inputs.Inputs(
        train_stream=lambda: stream,
        train_eval_stream=lambda: stream,
        eval_stream=lambda: stream,
        input_shape=(history_shape[1:], action_shape[1:]),
        input_dtype=(np.float32, np.int32),
        target_shape=(obs_shape[1:], reward_shape[1:]),
        target_dtype=(np.float32, np.float32),
    )

  def loss(mask_id=None, has_weights=False):
    """Cross-entropy loss as scalar compatible with Trax masking."""
    return layers.Serial(
        # Swap from (pred-obs, pred-reward, target-obs, target-reward)
        # to (pred-obs, target-obs, pred-reward, target-reward).
        layers.Parallel([], layers.Swap()),
        # Cross-entropy loss for obs, L2 loss on reward.
        layers.Parallel(
            layers.CrossEntropyLossScalar(mask_id, has_weights),
            layers.L2LossScalar(mask_id, has_weights)),
        # Add both losses.
        layers.Add(),
        # Zero out in this test.
        layers.MulConstant(constant=0.0),
    )

  with self.tmp_dir() as output_dir:
    # Run fake training just to save the parameters.
    trainer = trainer_lib.Trainer(
        model=model,
        loss_fn=loss,
        inputs=inputs,
        optimizer=trax_opt.SM3,
        lr_schedule=lr.MultifactorSchedule,
        output_dir=output_dir,
    )
    trainer.train_epoch(epoch_steps=1, eval_steps=1)

    # Repeat the history over and over again.
    stream = itertools.repeat(np.zeros(history_shape))
    env_fn = functools.partial(
        simulated_env_problem.RawSimulatedEnvProblem,
        model=model,
        history_length=history_shape[1],
        trajectory_length=3,
        batch_size=history_shape[0],
        observation_space=gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(obs_shape[1],)),
        action_space=gym.spaces.Discrete(n=n_actions),
        reward_range=(-1, 1),
        discrete_rewards=False,
        history_stream=stream,
        output_dir=output_dir,
    )

    trainer = self._make_trainer(
        train_env=env_fn(),
        eval_env=env_fn(),
        output_dir=output_dir,
    )
    trainer.training_loop(n_epochs=2)