def __init__(
    self,
    trajectory_batch_stream,
    optimizer,
    lr_schedule,
    policy_distribution,
    advantage_estimator,
    value_fn,
    weight_fn=(lambda x: x),
    advantage_normalization=True,
    advantage_normalization_epsilon=1e-5,
    head_selector=(),
):
  """Initializes PolicyTrainTask.

  Args:
    trajectory_batch_stream: Generator of trax.rl.task.TimeStepBatch.
    optimizer: Optimizer for network training.
    lr_schedule: Learning rate schedule for network training.
    policy_distribution: Distribution over actions.
    advantage_estimator: Function
      (rewards, returns, values, dones) -> advantages, created by one of
      the functions from trax.rl.advantages.
    value_fn: Function TimeStepBatch -> array (batch_size, seq_len)
      calculating the baseline for advantage calculation. Can be used to
      implement actor-critic algorithms, by substituting a call to the
      value network as value_fn.
    weight_fn: Function float -> float to apply to advantages. Examples:
      - A2C: weight_fn = id
      - AWR: weight_fn = exp
      - behavioral cloning: weight_fn(_) = 1
    advantage_normalization: Whether to normalize advantages.
    advantage_normalization_epsilon: Epsilon to use when normalizing
      advantages.
    head_selector: Layer to apply to the network output to select the value
      head. Only needed in multitask training. By default, use a no-op
      layer, signified by an empty sequence of layers, ().
  """
  self.trajectory_batch_stream = trajectory_batch_stream
  self._value_fn = value_fn
  self._advantage_estimator = advantage_estimator
  self._weight_fn = weight_fn
  self._advantage_normalization = advantage_normalization
  self._advantage_normalization_epsilon = advantage_normalization_epsilon
  self.policy_distribution = policy_distribution

  labeled_data = map(self.policy_batch, trajectory_batch_stream)
  sample_batch = self.policy_batch(
      next(trajectory_batch_stream), shape_only=True
  )
  loss_layer = distributions.LogLoss(distribution=policy_distribution)
  loss_layer = tl.Serial(head_selector, loss_layer)
  super().__init__(
      labeled_data, loss_layer, optimizer,
      sample_batch=sample_batch,
      lr_schedule=lr_schedule,
      loss_name='policy_loss',
  )
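# --- Standalone sketch, not part of the class above ---
# The weight_fn examples from the docstring, written out explicitly: the
# identity recovers A2C-style advantage weighting, exp recovers AWR, and a
# constant weight reduces the objective to behavioral cloning. The `beta`
# temperature variant is an assumption added here for illustration; it does
# not appear in the code above.
import numpy as np

a2c_weight_fn = lambda advantage: advantage  # A2C: identity.
awr_weight_fn = np.exp                       # AWR: exponential.
bc_weight_fn = lambda advantage: 1.0         # Behavioral cloning: constant.

# AWR is commonly run with a temperature beta, i.e. exp(advantage / beta);
# beta = 1.0 recovers awr_weight_fn above.
beta = 1.0
awr_tempered_weight_fn = lambda advantage: np.exp(advantage / beta)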
def __init__(
    self,
    trajectory_batch_stream,
    optimizer,
    lr_schedule,
    policy_distribution,
    advantage_estimator,
    value_fn,
    weight_fn=(lambda x: x),
    advantage_normalization=True,
    advantage_normalization_epsilon=1e-5,
):
  """Initializes PolicyTrainTask.

  Args:
    trajectory_batch_stream: Generator of trax.rl.task.TrajectoryNp.
    optimizer: Optimizer for network training.
    lr_schedule: Learning rate schedule for network training.
    policy_distribution: Distribution over actions.
    advantage_estimator: Function
      (rewards, returns, values, dones) -> advantages, created by one of
      the functions from trax.rl.advantages.
    value_fn: Function TrajectoryNp -> array (batch_size, seq_len)
      calculating the baseline for advantage calculation. Can be used to
      implement actor-critic algorithms, by substituting a call to the
      value network as value_fn.
    weight_fn: Function float -> float to apply to advantages. Examples:
      - A2C: weight_fn = id
      - AWR: weight_fn = exp
      - behavioral cloning: weight_fn(_) = 1
    advantage_normalization: Whether to normalize advantages.
    advantage_normalization_epsilon: Epsilon to use when normalizing
      advantages.
  """
  self._value_fn = value_fn
  self._advantage_estimator = advantage_estimator
  self._weight_fn = weight_fn
  self._advantage_normalization = advantage_normalization
  self._advantage_normalization_epsilon = advantage_normalization_epsilon
  self.policy_distribution = policy_distribution

  labeled_data = map(self.policy_batch, trajectory_batch_stream)
  loss_layer = distributions.LogLoss(distribution=policy_distribution)
  super().__init__(
      labeled_data, loss_layer, optimizer,
      lr_schedule=lr_schedule,
  )
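# --- Standalone sketch, not part of the class above ---
# How advantage normalization with advantage_normalization_epsilon is
# typically applied before weight_fn: shift to zero mean, scale to unit
# standard deviation, with epsilon guarding against division by zero. The
# exact placement inside policy_batch is an assumption, not taken from the
# code above.
import numpy as np

def normalize_advantages(advantages, epsilon=1e-5):
  """Returns advantages shifted to zero mean and scaled to unit std."""
  return (advantages - np.mean(advantages)) / (np.std(advantages) + epsilon)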
def policy_loss(self):
  """Policy loss."""
  return distributions.LogLoss(distribution=self._policy_dist)
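# --- Standalone sketch, not part of the class above ---
# The quantity a log loss over a policy distribution corresponds to: the
# negative log-probability of the actions actually taken, averaged with
# per-example weights (e.g. weight_fn applied to advantages). This is a
# plain-numpy illustration of that objective, not the LogLoss layer's
# implementation.
import numpy as np

def weighted_policy_log_loss(action_log_probs, weights):
  """Returns -mean(weights * log pi(a | s)) over a batch."""
  return -np.mean(weights * action_log_probs)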