def test_no_int32_or_uint32_returned(self):
  """Tests that Trainer._jit_update_fn doesn't return int32 or uint32.

  TF pins int32/uint32 tensors to CPU, which will cause XLA-forced-compiled
  computation to copy int32/uint32 outputs to CPU. This test makes sure that
  won't happen.
  """
  if xla_bridge.device_count() > 1:
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with fastmath.use_backend(fastmath.Backend.TFNP), \
        self.tmp_dir() as output_dir:
    n_classes = 1001
    model_fn = functools.partial(models.Resnet50,
                                 n_output_classes=n_classes)
    inputs = _test_inputs(n_classes, input_shape=(224, 224, 3))

    trainer = trainer_lib.Trainer(
        model=model_fn,
        loss_fn=tl.CrossEntropyLoss(),
        optimizer=trax_opt.SM3,
        lr_schedule=lr.multifactor(),
        inputs=inputs,
    )

    trainer.reset(output_dir)
    trainer.train_epoch(1, 0)

    # Those are the things returned by Trainer._jit_update_fn.
    arrays = (trainer._opt_state.weights, trainer._opt_state.slots,
              trainer._model_state, trainer._rngs)
    arrays = tf.nest.flatten(arrays)
    for x in arrays:
      if isinstance(x, jnp.ndarray) and (x.dtype == jnp.int32 or
                                         x.dtype == jnp.uint32):
        raise ValueError('Found an array of int32 or uint32: %s' % x)
def test_reset_twice(self, backend):
  if xla_bridge.device_count() > 1 and backend == fastmath.Backend.TFNP:
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with fastmath.use_backend(backend):
    n_classes = 4
    model_fn = functools.partial(models.MLP,
                                 d_hidden=16,
                                 n_output_classes=n_classes)
    inputs = _test_inputs(n_classes)

    trainer = trainer_lib.Trainer(
        model=model_fn,
        loss_fn=tl.CrossEntropyLoss(),
        optimizer=trax_opt.SM3,
        lr_schedule=lr.multifactor(),
        inputs=inputs,
    )

    output_dir1 = self.create_tempdir(name='output_dir1').full_path
    trainer.reset(output_dir1)
    trainer.evaluate(1)

    output_dir2 = self.create_tempdir(name='output_dir2').full_path
    trainer.reset(output_dir2)
    trainer.evaluate(1)
def test_reset_twice(self, backend):
  with fastmath.use_backend(backend):
    n_classes = 4
    model_fn = functools.partial(models.MLP,
                                 layer_widths=(16, 16, n_classes))
    inputs = _test_inputs(n_classes)

    trainer = trainer_lib.Trainer(
        model=model_fn,
        loss_fn=tl.WeightedCategoryCrossEntropy(),
        optimizer=trax_opt.SM3,
        lr_schedule=lr.multifactor(),
        inputs=inputs,
    )

    output_dir1 = self.create_tempdir(name='output_dir1').full_path
    trainer.reset(output_dir1)
    trainer.evaluate(1)

    output_dir2 = self.create_tempdir(name='output_dir2').full_path
    trainer.reset(output_dir2)
    trainer.evaluate(1)
def test_reset_twice(self, backend_name):
  if xla_bridge.device_count() > 1 and backend_name == 'tf':
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with math.use_backend(backend_name), self.tmp_dir() as output_dir1, \
        self.tmp_dir() as output_dir2:
    n_classes = 4
    model_fn = functools.partial(
        models.MLP, d_hidden=16, n_output_classes=n_classes)
    inputs = test_inputs(n_classes)

    trainer = trainer_lib.Trainer(
        model=model_fn,
        loss_fn=layers.CrossEntropyLoss(),
        optimizer=trax_opt.SM3,
        lr_schedule=lr.MultifactorSchedule,
        inputs=inputs,
    )

    trainer.reset(output_dir1)
    trainer.evaluate(1)

    trainer.reset(output_dir2)
    trainer.evaluate(1)
def test_training_loop_simulated(self):
  n_actions = 5
  history_shape = (3, 2, 3)
  action_shape = (3,)
  obs_shape = (3, 3)
  reward_shape = (3, 1)

  def model(mode):
    del mode
    return layers.Serial(
        layers.Parallel(
            layers.Flatten(),  # Observation stack.
            layers.Embedding(d_feature=1, vocab_size=n_actions),  # Action.
        ),
        layers.Concatenate(),
        layers.Dense(n_units=1),
        layers.Dup(),
        layers.Parallel(
            layers.Dense(n_units=obs_shape[1]),  # New observation.
            None,  # Reward.
        ),
    )

  stream = itertools.repeat(
      (np.zeros(history_shape), np.zeros(action_shape, dtype=np.int32),
       np.zeros(obs_shape), np.zeros(reward_shape)))
  inp = trax_inputs.Inputs(lambda _: stream)
  inp._input_shape = (history_shape[1:], action_shape[1:])
  inp._input_dtype = (np.float32, np.int32)
  inp._target_shape = (obs_shape[1:], reward_shape[1:])
  inp._target_dtype = (np.float32, np.float32)
  inputs = inp

  def loss(id_to_mask=None, has_weights=False):
    """Cross-entropy loss as scalar compatible with Trax masking."""
    return layers.Serial(
        # Swap from (pred-obs, pred-reward, target-obs, target-reward)
        # to (pred-obs, target-obs, pred-reward, target-reward).
        layers.Parallel([], layers.Swap()),
        # Cross-entropy loss for obs, L2 loss on reward.
        layers.Parallel(
            layers.CrossEntropyLoss(id_to_mask, has_weights),
            layers.L2Loss(id_to_mask, has_weights)),
        # Add both losses.
        layers.Add(),
        # Zero out in this test.
        layers.Fn(lambda x: x * 0.0),
    )

  with self.tmp_dir() as output_dir:
    # Run fake training just to save the parameters.
    trainer = trainer_lib.Trainer(
        model=model,
        loss_fn=loss,
        inputs=inputs,
        optimizer=trax_opt.SM3,
        lr_schedule=lr.MultifactorSchedule,
        output_dir=output_dir,
    )
    trainer.train_epoch(n_steps=1, n_eval_steps=1)

    # Repeat the history over and over again.
    stream = itertools.repeat(np.zeros(history_shape))
    env_fn = functools.partial(
        simulated_env_problem.RawSimulatedEnvProblem,
        model=model,
        history_length=history_shape[1],
        trajectory_length=3,
        batch_size=history_shape[0],
        observation_space=gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(obs_shape[1],)),
        action_space=gym.spaces.Discrete(n=n_actions),
        reward_range=(-1, 1),
        discrete_rewards=False,
        history_stream=stream,
        output_dir=output_dir,
    )

    trainer = self._make_trainer(
        train_env=env_fn(),
        eval_env=env_fn(),
        output_dir=output_dir,
    )
    trainer.training_loop(n_epochs=2)
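# The classification tests above call a module-level `_test_inputs` helper
# that is not shown in this excerpt. Below is a minimal sketch of what such a
# helper could look like; it is an assumption for illustration, not the
# original implementation. It assumes `trax_inputs.Inputs` accepts a callable
# mapping `n_devices` to a generator of (inputs, targets) batches, matching
# the `trax_inputs.Inputs(lambda _: stream)` usage in
# `test_training_loop_simulated` above. The `batch_size` and `input_shape`
# defaults are hypothetical.
def _test_inputs(n_classes, input_shape=(6, 6, 3), batch_size=2):
  """Returns dummy classification inputs for the trainer tests (sketch)."""
  def input_stream(n_devices):
    del n_devices
    while True:
      x = np.random.uniform(size=(batch_size,) + input_shape)
      x = x.astype(np.float32)
      y = np.random.randint(n_classes, size=(batch_size,), dtype=np.int32)
      yield x, y
  return trax_inputs.Inputs(input_stream)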