Example #1
    def __init__(
            self,
            trajectory_batch_stream,
            optimizer,
            lr_schedule,
            policy_distribution,
            advantage_estimator,
            value_fn,
            weight_fn=(lambda x: x),
            advantage_normalization=True,
            advantage_normalization_epsilon=1e-5,
            head_selector=(),
    ):
        """Initializes PolicyTrainTask.

    Args:
      trajectory_batch_stream: Generator of trax.rl.task.TimeStepBatch.
      optimizer: Optimizer for network training.
      lr_schedule: Learning rate schedule for network training.
      policy_distribution: Distribution over actions.
      advantage_estimator: Function
        (rewards, returns, values, dones) -> advantages, created by one of the
        functions from trax.rl.advantages.
      value_fn: Function TimeStepBatch -> array (batch_size, seq_len)
        calculating the baseline for advantage calculation. Can be used to
        implement actor-critic algorithms, by substituting a call to the value
        network as value_fn.
      weight_fn: Function float -> float to apply to advantages. Examples:
        - A2C: weight_fn = id
        - AWR: weight_fn = exp
        - behavioral cloning: weight_fn(_) = 1
      advantage_normalization: Whether to normalize advantages.
      advantage_normalization_epsilon: Epsilon to use when normalizing
        advantages.
      head_selector: Layer to apply to the network output to select the value
        head. Only needed in multitask training. By default, use a no-op layer,
        signified by an empty sequence of layers, ().
    """
        self.trajectory_batch_stream = trajectory_batch_stream
        self._value_fn = value_fn
        self._advantage_estimator = advantage_estimator
        self._weight_fn = weight_fn
        self._advantage_normalization = advantage_normalization
        self._advantage_normalization_epsilon = advantage_normalization_epsilon
        self.policy_distribution = policy_distribution

        labeled_data = map(self.policy_batch, trajectory_batch_stream)
        sample_batch = self.policy_batch(next(trajectory_batch_stream),
                                         shape_only=True)
        loss_layer = distributions.LogLoss(distribution=policy_distribution)
        loss_layer = tl.Serial(head_selector, loss_layer)
        super().__init__(
            labeled_data,
            loss_layer,
            optimizer,
            sample_batch=sample_batch,
            lr_schedule=lr_schedule,
            loss_name='policy_loss',
        )
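
The docstring above enumerates several choices of weight_fn and mentions epsilon-stabilized advantage normalization. The following is a minimal NumPy sketch (not trax code) of what those options compute; the array values are made up for illustration.

import numpy as np

# Fake advantages of shape (batch_size, seq_len), purely for illustration.
advantages = np.array([[0.5, -1.2, 2.0],
                       [0.1, 0.0, -0.3]])

# advantage_normalization=True: subtract the mean and divide by the standard
# deviation, with a small epsilon guarding against division by zero.
epsilon = 1e-5
normalized = (advantages - advantages.mean()) / (advantages.std() + epsilon)

# weight_fn examples from the docstring:
a2c_weights = normalized                  # A2C: identity
awr_weights = np.exp(normalized)          # AWR: exponential
bc_weights = np.ones_like(normalized)     # behavioral cloning: constant 1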
Example #2
    def __init__(
        self,
        trajectory_batch_stream,
        optimizer,
        lr_schedule,
        policy_distribution,
        advantage_estimator,
        value_fn,
        weight_fn=(lambda x: x),
        advantage_normalization=True,
        advantage_normalization_epsilon=1e-5,
    ):
        """Initializes PolicyTrainTask.

    Args:
      trajectory_batch_stream: Generator of trax.rl.task.TrajectoryNp.
      optimizer: Optimizer for network training.
      lr_schedule: Learning rate schedule for network training.
      policy_distribution: Distribution over actions.
      advantage_estimator: Function
        (rewards, returns, values, dones) -> advantages, created by one of the
        functions from trax.rl.advantages.
      value_fn: Function TrajectoryNp -> array (batch_size, seq_len) calculating
        the baseline for advantage calculation. Can be used to implement
        actor-critic algorithms, by substituting a call to the value network
        as value_fn.
      weight_fn: Function float -> float to apply to advantages. Examples:
        - A2C: weight_fn = id
        - AWR: weight_fn = exp
        - behavioral cloning: weight_fn(_) = 1
      advantage_normalization: Whether to normalize advantages.
      advantage_normalization_epsilon: Epsilon to use when normalizing
        advantages.
    """
        self._value_fn = value_fn
        self._advantage_estimator = advantage_estimator
        self._weight_fn = weight_fn
        self._advantage_normalization = advantage_normalization
        self._advantage_normalization_epsilon = advantage_normalization_epsilon
        self.policy_distribution = policy_distribution

        labeled_data = map(self.policy_batch, trajectory_batch_stream)
        loss_layer = distributions.LogLoss(distribution=policy_distribution)
        super().__init__(
            labeled_data,
            loss_layer,
            optimizer,
            lr_schedule=lr_schedule,
        )
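
For reference, here is a hedged sketch (not from trax) of the two function-valued arguments, matching the signatures documented above: a trivial zero baseline as value_fn and a returns-minus-values stand-in for advantage_estimator (in practice the estimator would come from trax.rl.advantages). It assumes the trajectory batch exposes a rewards array of shape (batch_size, seq_len).

import numpy as np

def zero_value_fn(trajectory_batch):
    """Trivial baseline: value 0 at every timestep (REINFORCE-style).

    Assumes `trajectory_batch.rewards` has shape (batch_size, seq_len).
    """
    return np.zeros_like(trajectory_batch.rewards, dtype=np.float32)

def naive_advantage_estimator(rewards, returns, values, dones):
    """Simplest stand-in estimator: advantage = return - baseline value."""
    del rewards, dones  # Unused by this minimal variant.
    return returns - values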
Example #3
    def policy_loss(self):
        """Policy loss."""
        return distributions.LogLoss(distribution=self._policy_dist)
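
As a conceptual aid only (this is not the trax LogLoss layer), the quantity such a log loss computes for a categorical policy is the weighted negative log-probability of the actions actually taken; the function below is a self-contained NumPy sketch of that idea.

import numpy as np

def categorical_policy_log_loss(logits, actions, weights):
    """logits: (batch, num_actions); actions, weights: (batch,)."""
    # Log-softmax with max-subtraction for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Negative log-probability of the taken actions, weighted (e.g. by advantages).
    chosen = np.take_along_axis(log_probs, actions[:, None], axis=-1)[:, 0]
    return -np.mean(weights * chosen)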