def test_integration_with_policy_tasks(self):
  # Integration test for policy + value training and eval.
  optimizer = opt.Adam()
  lr_schedule = lr_schedules.constant(1e-3)
  advantage_estimator = advantages.td_k(gamma=self._task.gamma, margin=1)
  policy_dist = distributions.create_distribution(self._task.action_space)
  body = lambda mode: tl.Dense(64)
  train_model = models.PolicyAndValue(policy_dist, body=body)
  eval_model = models.PolicyAndValue(policy_dist, body=body)

  head_selector = tl.Select([1])
  value_train_task = value_tasks.ValueTrainTask(
      self._trajectory_batch_stream,
      optimizer,
      lr_schedule,
      advantage_estimator,
      model=train_model,
      target_model=eval_model,
      head_selector=head_selector,
  )
  value_eval_task = value_tasks.ValueEvalTask(
      value_train_task, head_selector=head_selector
  )

  # Drop the value head - just tl.Select([0]) would pass it, and it would
  # override the targets.
  head_selector = tl.Select([0], n_in=2)
  policy_train_task = policy_tasks.PolicyTrainTask(
      self._trajectory_batch_stream,
      optimizer,
      lr_schedule,
      policy_dist,
      advantage_estimator,
      # Plug a trained critic as our value estimate.
      value_fn=value_train_task.value,
      head_selector=head_selector,
  )
  policy_eval_task = policy_tasks.PolicyEvalTask(
      policy_train_task, head_selector=head_selector
  )

  loop = training.Loop(
      model=train_model,
      eval_model=eval_model,
      tasks=[policy_train_task, value_train_task],
      eval_tasks=[policy_eval_task, value_eval_task],
      eval_at=(lambda _: True),
      # Switch the task every step.
      which_task=(lambda step: step % 2),
  )
  # Run for a couple of steps to make sure there are a few task switches.
  loop.run(n_steps=10)
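
# Illustrative aside (a sketch, not part of the test): how the head selectors
# above route the two outputs of PolicyAndValue. Assuming the model emits
# (policy_head, value_head), tl.Select picks out the head each task should see;
# the arrays below are placeholders chosen only to show the routing:
#
#   import numpy as np
#   from trax import layers as tl
#
#   policy_head = np.zeros((2, 3))  # (batch, n_actions) policy logits
#   value_head = np.zeros((2, 1))   # (batch, 1) value estimates
#   keep_value = tl.Select([1])           # value task: pass only the value head
#   keep_policy = tl.Select([0], n_in=2)  # policy task: drop the value head
#   assert keep_value((policy_head, value_head)).shape == (2, 1)
#   assert keep_policy((policy_head, value_head)).shape == (2, 3)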
def __init__(
    self,
    task,
    model_fn,
    optimizer=adam.Adam,
    lr_schedule=lr.multifactor,
    batch_size=64,
    network_eval_at=None,
    n_eval_batches=1,
    max_slice_length=1,
    **kwargs
):
  """Initializes PolicyGradient.

  Args:
    task: Instance of trax.rl.task.RLTask.
    model_fn: Function (policy_distribution, mode) -> policy_model.
    optimizer: Optimizer for network training.
    lr_schedule: Learning rate schedule for network training.
    batch_size: Batch size for network training.
    network_eval_at: Function step -> bool indicating the training steps, when
      network evaluation should be performed.
    n_eval_batches: Number of batches to run during network evaluation.
    max_slice_length: The length of trajectory slices to run the network on.
    **kwargs: Keyword arguments passed to the superclass.
  """
  super().__init__(task, **kwargs)

  self._max_slice_length = max_slice_length
  trajectory_batch_stream = task.trajectory_batch_stream(
      batch_size,
      epochs=[-1],
      max_slice_length=self._max_slice_length,
      sample_trajectories_uniformly=True,
  )
  self._policy_dist = distributions.create_distribution(task.action_space)
  train_task = policy_tasks.PolicyTrainTask(
      trajectory_batch_stream,
      optimizer(),
      lr_schedule(),
      self._policy_dist,
      # Policy gradient uses the MC estimator. No need for margin - the MC
      # estimator only uses empirical returns.
      advantage_estimator=advantages.monte_carlo(task.gamma, margin=0),
      value_fn=self._value_fn,
  )
  eval_task = policy_tasks.PolicyEvalTask(train_task, n_eval_batches)
  model_fn = functools.partial(
      model_fn,
      policy_distribution=self._policy_dist,
  )

  if self._output_dir is not None:
    policy_output_dir = os.path.join(self._output_dir, 'policy')
  else:
    policy_output_dir = None
  # Checkpoint every epoch. We do one step per epoch, so that's every step.
  checkpoint_at = lambda _: True
  self._loop = supervised.training.Loop(
      model=model_fn(mode='train'),
      tasks=[train_task],
      eval_model=model_fn(mode='eval'),
      eval_tasks=[eval_task],
      output_dir=policy_output_dir,
      eval_at=network_eval_at,
      checkpoint_at=checkpoint_at,
  )
  self._collect_model = model_fn(mode='collect')
  self._collect_model.init(shapes.signature(train_task.sample_batch))

  # Validate the restored checkpoints. The number of network training steps
  # (self.loop.step) should be equal to the number of epochs (self._epoch),
  # because we do exactly one gradient step per epoch.
  # TODO(pkozakowski): Move this to the base class once all Agents use Loop.
  if self.loop.step != self._epoch:
    raise ValueError(
        'The number of Loop steps must equal the number of Agent epochs, '
        'got {} and {}.'.format(self.loop.step, self._epoch)
    )
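
# Usage sketch (hypothetical, hedged): one plausible way to drive this agent,
# assuming an RLTask instance and a model_fn with the
# (policy_distribution, mode) -> policy_model signature documented above.
# Names like `my_task`, `my_body` and `output_dir` are placeholders, not part
# of this module, and the exact run() signature should be checked against the
# Agent base class:
#
#   model_fn = functools.partial(models.Policy, body=my_body)
#   agent = PolicyGradient(
#       my_task,
#       model_fn=model_fn,
#       batch_size=128,
#       network_eval_at=(lambda step: step % 10 == 0),
#       output_dir=output_dir,
#   )
#   agent.run(n_epochs=5)  # One gradient step (and one checkpoint) per epoch.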
def __init__(self, task, model_fn, value_fn, weight_fn, n_replay_epochs,
             n_train_steps_per_epoch, optimizer=adam.Adam,
             lr_schedule=lr.multifactor, batch_size=64, network_eval_at=None,
             n_eval_batches=1, max_slice_length=1, **kwargs):
  """Initializes LoopPolicyAgent.

  Args:
    task: Instance of trax.rl.task.RLTask.
    model_fn: Function (policy_distribution, mode) -> policy_model.
    value_fn: Function TimeStepBatch -> array (batch_size, seq_len)
      calculating the baseline for advantage calculation.
    weight_fn: Function float -> float to apply to advantages when
      calculating policy loss.
    n_replay_epochs: Number of last epochs to take into the replay buffer;
      only makes sense for off-policy algorithms.
    n_train_steps_per_epoch: Number of steps to train the policy network for
      in each epoch.
    optimizer: Optimizer for network training.
    lr_schedule: Learning rate schedule for network training.
    batch_size: Batch size for network training.
    network_eval_at: Function step -> bool indicating the training steps, when
      network evaluation should be performed.
    n_eval_batches: Number of batches to run during network evaluation.
    max_slice_length: The length of trajectory slices to run the network on.
    **kwargs: Keyword arguments passed to the superclass.
  """
  self._n_train_steps_per_epoch = n_train_steps_per_epoch
  super().__init__(task, **kwargs)

  task.set_n_replay_epochs(n_replay_epochs)
  self._max_slice_length = max_slice_length
  trajectory_batch_stream = task.trajectory_batch_stream(
      batch_size,
      epochs=[-(ep + 1) for ep in range(n_replay_epochs)],
      max_slice_length=self._max_slice_length,
      sample_trajectories_uniformly=True,
  )
  self._policy_dist = distributions.create_distribution(task.action_space)
  train_task = policy_tasks.PolicyTrainTask(
      trajectory_batch_stream,
      optimizer(),
      lr_schedule(),
      self._policy_dist,
      # Without a value network it doesn't make a lot of sense to use
      # a better advantage estimator than MC.
      advantage_estimator=advantages.monte_carlo(task.gamma, margin=0),
      value_fn=value_fn,
      weight_fn=weight_fn,
  )
  eval_task = policy_tasks.PolicyEvalTask(train_task, n_eval_batches)
  model_fn = functools.partial(
      model_fn,
      policy_distribution=self._policy_dist,
  )

  if self._output_dir is not None:
    policy_output_dir = os.path.join(self._output_dir, 'policy')
  else:
    policy_output_dir = None
  # Checkpoint every epoch.
  checkpoint_at = lambda step: step % n_train_steps_per_epoch == 0
  self._loop = supervised.training.Loop(
      model=model_fn(mode='train'),
      tasks=[train_task],
      eval_model=model_fn(mode='eval'),
      eval_tasks=[eval_task],
      output_dir=policy_output_dir,
      eval_at=network_eval_at,
      checkpoint_at=checkpoint_at,
  )
  self._collect_model = model_fn(mode='collect')
  self._collect_model.init(shapes.signature(train_task.sample_batch))

  # Validate the restored checkpoints.
  # TODO(pkozakowski): Move this to the base class once all Agents use Loop.
  if self._loop.step != self._epoch * self._n_train_steps_per_epoch:
    raise ValueError(
        'The number of Loop steps must equal the number of Agent epochs '
        'times the number of steps per epoch, got {}, {} and {}.'.format(
            self._loop.step, self._epoch, self._n_train_steps_per_epoch
        )
    )
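
# Illustrative sketch (assumptions, not part of this module): plausible
# `value_fn` and `weight_fn` arguments for this agent. The zero baseline
# assumes the TimeStepBatch exposes a `reward` array of shape
# (batch_size, seq_len); the exponential weighting is just one example of an
# AWR-style weight_fn, and `my_task` / `my_model_fn` are placeholders:
#
#   import numpy as np
#
#   def zero_value_fn(trajectory_batch):
#     # Trivial baseline of zeros, shaped like the per-step rewards.
#     return np.zeros_like(trajectory_batch.reward)
#
#   def awr_weight_fn(advantage, beta=1.0):
#     # Exponential advantage weighting, applied elementwise.
#     return np.exp(advantage / beta)
#
#   agent = LoopPolicyAgent(
#       my_task, model_fn=my_model_fn,
#       value_fn=zero_value_fn, weight_fn=awr_weight_fn,
#       n_replay_epochs=10, n_train_steps_per_epoch=100,
#   )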