def test_sanity_a2ctrainer_cartpole(self):
  """Test-runs a2c on cartpole."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=0, max_steps=2)
  body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
  policy_model = functools.partial(models.Policy, body=body)
  value_model = functools.partial(models.Value, body=body)
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-4, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic.A2C(
      task,
      n_shared_layers=1,
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=2,
      value_train_steps_per_epoch=2,
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=2,
      policy_train_steps_per_epoch=2,
      n_trajectories_per_epoch=2)
  trainer.run(2)
  self.assertEqual(2, trainer.current_epoch)

def test_policytrainer_cartpole(self):
  """Trains a policy on cartpole."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=1,
                        max_steps=200)
  model = functools.partial(
      models.Policy,
      body=lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
          tl.Dense(64), tl.Relu(), tl.Dense(64), tl.Relu()
      ),
  )
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  max_avg_returns = -math.inf
  for _ in range(5):
    trainer = training.PolicyGradient(
        task,
        policy_model=model,
        policy_optimizer=opt.Adam,
        policy_lr_schedule=lr,
        policy_batch_size=128,
        policy_train_steps_per_epoch=1,
        n_trajectories_per_epoch=2)
    # Assert that we get to 200 at some point and then exit so the test is as
    # fast as possible.
    for ep in range(200):
      trainer.run(1)
      self.assertEqual(trainer.current_epoch, ep + 1)
      if trainer.avg_returns[-1] == 200.0:
        return
    max_avg_returns = max(max_avg_returns, trainer.avg_returns[-1])
  self.fail(
      'The expected score of 200 has not been reached. '
      'Maximum at end was {}.'.format(max_avg_returns))

def test_sanity_awrtrainer_transformer_cartpole(self):
  """Test-runs AWR on cartpole with Transformer."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=2, max_steps=2)
  body = lambda mode: models.TransformerDecoder(  # pylint: disable=g-long-lambda
      d_model=2, d_ff=2, n_layers=1, n_heads=1, mode=mode)
  policy_model = functools.partial(models.Policy, body=body)
  value_model = functools.partial(models.Value, body=body)
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic.AWR(
      task,
      n_shared_layers=0,
      max_slice_length=2,
      added_policy_slice_length=1,
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=2,
      value_train_steps_per_epoch=2,
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=2,
      policy_train_steps_per_epoch=2,
      n_trajectories_per_epoch=1,
      n_eval_episodes=1)
  trainer.run(2)
  self.assertEqual(2, trainer.current_epoch)

def test_sampling_awrtrainer_mountain_car(self):
  """Test-runs Sampling AWR on MountainCarContinuous."""
  task = rl_task.RLTask('MountainCarContinuous-v0', initial_trajectories=0,
                        max_steps=2)
  body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu())
  policy_model = functools.partial(models.Policy, body=body)
  value_model = functools.partial(models.Value, body=body)
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic.SamplingAWR(
      task,
      n_shared_layers=0,
      added_policy_slice_length=1,
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=2,
      value_train_steps_per_epoch=2,
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=2,
      policy_train_steps_per_epoch=2,
      n_trajectories_per_epoch=2,
      advantage_estimator=advantages.monte_carlo,
      advantage_normalization=False,
      q_value_n_samples=3,
  )
  trainer.run(1)
  self.assertEqual(1, trainer.current_epoch)

def test_no_int32_or_uint32_returned(self):
  """Tests that Trainer._jit_update_fn doesn't return int32 or uint32.

  TF pins int32/uint32 tensors to CPU, which will cause XLA-forced-compiled
  computation to copy int32/uint32 outputs to CPU. This test makes sure that
  won't happen.
  """
  if xla_bridge.device_count() > 1:
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with fastmath.use_backend(fastmath.Backend.TFNP), \
      self.tmp_dir() as output_dir:
    n_classes = 1001
    model_fn = functools.partial(models.Resnet50,
                                 n_output_classes=n_classes)
    inputs = _test_inputs(n_classes, input_shape=(224, 224, 3))
    trainer = trainer_lib.Trainer(
        model=model_fn,
        loss_fn=tl.CrossEntropyLoss(),
        optimizer=trax_opt.SM3,
        lr_schedule=lr.multifactor(),
        inputs=inputs,
    )
    trainer.reset(output_dir)
    trainer.train_epoch(1, 0)
    # Those are the things returned by Trainer._jit_update_fn.
    arrays = (trainer._opt_state.weights, trainer._opt_state.slots,
              trainer._model_state, trainer._rngs)
    arrays = tf.nest.flatten(arrays)
    for x in arrays:
      if isinstance(x, jnp.ndarray) and (x.dtype == jnp.int32 or
                                         x.dtype == jnp.uint32):
        raise ValueError('Found an array of int32 or uint32: %s' % x)

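# Several tests in this section call a `_test_inputs` helper that is defined
# elsewhere in the test module. The sketch below is a hedged reconstruction,
# not the actual helper: it wraps an infinite synthetic-batch generator in
# `trax_inputs.Inputs`. The batch size, dtypes, and the use of numpy here are
# illustrative assumptions.
import numpy as np  # assumed available; used only for this sketch


def _test_inputs(n_classes, input_shape=(6, 6, 3)):
  """Returns an Inputs object yielding random (input, target, weight) batches."""
  batch_size = 2

  def input_stream(n_devices):
    del n_devices  # The synthetic stream is the same for any device count.
    rng = np.random.default_rng(0)
    while True:
      x = rng.uniform(size=(batch_size,) + tuple(input_shape))
      x = x.astype(np.float32)
      y = rng.integers(0, n_classes, size=(batch_size,)).astype(np.int32)
      w = np.ones((batch_size,), dtype=np.float32)
      yield x, y, w

  return trax_inputs.Inputs(input_stream)
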
def test_sanity_ppo_cartpole(self):
  """Runs PPO and checks that it correctly completes 2 epochs."""
  task = rl_task.RLTask('CartPole-v1', initial_trajectories=0,
                        max_steps=200)
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-3, warmup_steps=100, factors='constant * linear_warmup')
  body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
  policy_model = functools.partial(models.Policy, body=body)
  value_model = functools.partial(models.Value, body=body)
  trainer = actor_critic.PPO(
      task,
      n_shared_layers=1,
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=128,
      value_train_steps_per_epoch=10,
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=128,
      policy_train_steps_per_epoch=10,
      n_trajectories_per_epoch=10)
  trainer.run(2)
  self.assertEqual(2, trainer.current_epoch)

def test_reset_twice(self, backend):
  if xla_bridge.device_count() > 1 and backend == fastmath.Backend.TFNP:
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with fastmath.use_backend(backend):
    n_classes = 4
    model_fn = functools.partial(models.MLP,
                                 d_hidden=16,
                                 n_output_classes=n_classes)
    inputs = _test_inputs(n_classes)
    trainer = trainer_lib.Trainer(
        model=model_fn,
        loss_fn=tl.CrossEntropyLoss(),
        optimizer=trax_opt.SM3,
        lr_schedule=lr.multifactor(),
        inputs=inputs,
    )
    output_dir1 = self.create_tempdir(name='output_dir1').full_path
    trainer.reset(output_dir1)
    trainer.evaluate(1)
    output_dir2 = self.create_tempdir(name='output_dir2').full_path
    trainer.reset(output_dir2)
    trainer.evaluate(1)

def test_awrtrainer_cartpole(self):
  """Test-runs AWR on cartpole."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=1000,
                        max_steps=200)
  body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
  policy_model = functools.partial(models.Policy, body=body)
  value_model = functools.partial(models.Value, body=body)
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic.AWRTrainer(
      task,
      n_shared_layers=0,
      added_policy_slice_length=1,
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=32,
      value_train_steps_per_epoch=200,
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=32,
      policy_train_steps_per_epoch=200,
      n_trajectories_per_epoch=10,
      advantage_estimator=advantages.monte_carlo,
      advantage_normalization=False,
  )
  trainer.run(1)
  self.assertEqual(1, trainer.current_epoch)
  self.assertGreater(trainer.avg_returns[-1], 35.0)

def test_policy_gradient_cartpole(self):
  """Trains a policy on cartpole."""
  task = rl_task.RLTask('CartPole-v0', max_steps=200)
  lr = lambda: lr_schedules.multifactor(constant=1e-2, factors='constant')
  max_avg_returns = -math.inf
  for _ in range(2):
    agent = training.PolicyGradient(
        task,
        model_fn=self._model_fn,
        optimizer=opt.Adam,
        lr_schedule=lr,
        batch_size=128,
        n_trajectories_per_epoch=2,
    )
    # Assert that we get to 200 at some point and then exit so the test is as
    # fast as possible.
    for ep in range(200):
      agent.run(1)
      self.assertEqual(agent.current_epoch, ep + 1)
      if agent.avg_returns[-1] == 200.0:
        return
    max_avg_returns = max(max_avg_returns, agent.avg_returns[-1])
  self.fail('The expected score of 200 has not been reached. '
            'Maximum at end was {}.'.format(max_avg_returns))

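# `test_policy_gradient_cartpole` above references `self._model_fn`, a helper
# defined elsewhere on the test class. The sketch below is an assumption about
# its shape, not the original helper: a Policy over a distribution created
# from the task's action space (`self._task` is likewise assumed), with a
# small dense body.
def _model_fn(self, mode='train'):
  policy_dist = distributions.create_distribution(self._task.action_space)
  return models.Policy(
      policy_dist,
      body=lambda mode: tl.Serial(tl.Dense(64), tl.Relu()),
      mode=mode,
  )
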
def test_jointppotrainer_cartpole(self):
  """Test-runs joint PPO on CartPole."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=0, max_steps=2)
  joint_model = functools.partial(
      models.PolicyAndValue,
      body=lambda mode: tl.Serial(tl.Dense(2), tl.Relu()),
  )
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic_joint.PPOJointTrainer(
      task,
      joint_model=joint_model,
      optimizer=opt.Adam,
      lr_schedule=lr,
      batch_size=4,
      train_steps_per_epoch=2,
      n_trajectories_per_epoch=5)
  trainer.run(2)
  self.assertEqual(2, trainer.current_epoch)

def test_awrtrainer_cartpole_shared(self):
  """Test-runs AWR on cartpole with shared layers."""
  # This test is flaky, and this is the simplest way to retry in OSS.
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=1000,
                        max_steps=200)
  body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
  policy_model = functools.partial(models.Policy, body=body)
  value_model = functools.partial(models.Value, body=body)
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  max_avg_returns = -math.inf
  for _ in range(5):
    trainer = actor_critic.AWRTrainer(
        task,
        n_shared_layers=1,
        added_policy_slice_length=1,
        value_model=value_model,
        value_optimizer=opt.Adam,
        value_lr_schedule=lr,
        value_batch_size=32,
        value_train_steps_per_epoch=200,
        policy_model=policy_model,
        policy_optimizer=opt.Adam,
        policy_lr_schedule=lr,
        policy_batch_size=32,
        policy_train_steps_per_epoch=200,
        n_trajectories_per_epoch=10,
        advantage_estimator=advantages.monte_carlo,
        advantage_normalization=False,
    )
    trainer.run(1)
    self.assertEqual(1, trainer.current_epoch)
    max_avg_returns = max(max_avg_returns, trainer.avg_returns[-1])
    if trainer.avg_returns[-1] > 35.0:
      return
  self.fail(
      f'We did not reach a score > 35.0, max was {max_avg_returns}.')

def test_reset_twice(self, backend):
  with fastmath.use_backend(backend):
    n_classes = 4
    model_fn = functools.partial(models.MLP,
                                 layer_widths=(16, 16, n_classes))
    inputs = _test_inputs(n_classes)
    trainer = trainer_lib.Trainer(
        model=model_fn,
        loss_fn=tl.WeightedCategoryCrossEntropy(),
        optimizer=trax_opt.SM3,
        lr_schedule=lr.multifactor(),
        inputs=inputs,
    )
    output_dir1 = self.create_tempdir(name='output_dir1').full_path
    trainer.reset(output_dir1)
    trainer.evaluate(1)
    output_dir2 = self.create_tempdir(name='output_dir2').full_path
    trainer.reset(output_dir2)
    trainer.evaluate(1)

def test_dqntrainer_cartpole(self):
  """Test-runs DQN on CartPole."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=0, max_steps=2)
  value_body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  trainer = training.DQN(
      task,
      value_body=value_body,
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=4,
      value_train_steps_per_epoch=2,
      n_trajectories_per_epoch=5)
  trainer.run(2)
  self.assertEqual(2, trainer.current_epoch)

def test_reset_twice(self, backend_name):
  if xla_bridge.device_count() > 1 and backend_name == 'tf':
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with fastmath.use_backend(backend_name), self.tmp_dir() as output_dir1, \
      self.tmp_dir() as output_dir2:
    n_classes = 4
    model_fn = functools.partial(models.MLP,
                                 d_hidden=16,
                                 n_output_classes=n_classes)
    inputs = _test_inputs(n_classes)
    trainer = trainer_lib.Trainer(
        model=model_fn,
        loss_fn=layers.CrossEntropyLoss(),
        optimizer=trax_opt.SM3,
        lr_schedule=lr.multifactor(),
        inputs=inputs,
    )
    trainer.reset(output_dir1)
    trainer.evaluate(1)
    trainer.reset(output_dir2)
    trainer.evaluate(1)

def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.CrossEntropyLoss(),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule=lr.multifactor(),
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          custom_train_fn=None):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: weights, trax.inputs.Inputs, model,
      state, rng -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save the checkpoint with the highest value of this
      metric.
    checkpoint_lowest: save the checkpoint with the lowest value of this
      metric.
    custom_train_fn: custom train function to call, entirely bypassing this
      one.

  Returns:
    trax.TrainerState
  """
  if custom_train_fn is not None:
    return custom_train_fn(output_dir, model=model)

  n_devices = num_devices()
  trainer = trainer_class(model, loss_fn, optimizer, lr_schedule, inputs,
                          output_dir,
                          random_seed=random_seed,
                          n_devices=n_devices,
                          checkpoints_at=checkpoints_at,
                          metrics=metrics,
                          checkpoint_lowest=checkpoint_lowest,
                          checkpoint_highest=checkpoint_highest)

  epoch_steps = [steps]  # Only training if eval_frequency is 0 or None.
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  trainer.log_step('Starting training using %d devices' % trainer.n_devices)
  trainer.print_n_weights()

  try:
    for epoch_steps in epochs(steps, trainer.step, epoch_steps):
      trainer.train_epoch(epoch_steps, eval_steps)

      # Bookkeeping we do at the first step.
      if trainer.step == 1:
        # Save computation graph (single-device only for now).
        if save_graphs and fastmath.backend_name() == 'jax':
          trainer.save_computation_graphs()

        # Save Gin config.
        trainer.save_gin()

    trainer.log_step('Training done')
  finally:
    trainer.close()
  return trainer.state
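

# A minimal usage sketch for `train`, with illustrative (assumed) values; in
# practice most of these arguments are supplied through gin configuration:
def _example_train_run():  # hypothetical helper, not part of the library
  state = train(
      output_dir='/tmp/trax_mlp_example',  # assumed scratch directory
      model=functools.partial(models.MLP, layer_widths=(128, 64, 10)),
      inputs=trax_inputs.batcher,
      steps=200,
      eval_steps=10,
      eval_frequency=100,
  )
  return state.step  # TrainerState records the final step reached.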