Code example #1
File: value_tasks_test.py  Project: rizwandel/trax
  def test_integration_with_policy_tasks(self):
    # Integration test for policy + value training and eval.
    optimizer = opt.Adam()
    lr_schedule = lr_schedules.constant(1e-3)
    advantage_estimator = advantages.td_k(gamma=self._task.gamma, margin=1)
    policy_dist = distributions.create_distribution(self._task.action_space)
    body = lambda mode: tl.Dense(64)
    train_model = models.PolicyAndValue(policy_dist, body=body)
    eval_model = models.PolicyAndValue(policy_dist, body=body)

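    # Select the value head (output 1 of PolicyAndValue) for the value tasks.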
    head_selector = tl.Select([1])
    value_train_task = value_tasks.ValueTrainTask(
        self._trajectory_batch_stream,
        optimizer,
        lr_schedule,
        advantage_estimator,
        model=train_model,
        target_model=eval_model,
        head_selector=head_selector,
    )
    value_eval_task = value_tasks.ValueEvalTask(
        value_train_task, head_selector=head_selector
    )

    # Drop the value head - a bare tl.Select([0]) would let it pass through,
    # and it would override the targets.
    head_selector = tl.Select([0], n_in=2)
    policy_train_task = policy_tasks.PolicyTrainTask(
        self._trajectory_batch_stream,
        optimizer,
        lr_schedule,
        policy_dist,
        advantage_estimator,
        # Plug a trained critic as our value estimate.
        value_fn=value_train_task.value,
        head_selector=head_selector,
    )
    policy_eval_task = policy_tasks.PolicyEvalTask(
        policy_train_task, head_selector=head_selector
    )

    loop = training.Loop(
        model=train_model,
        eval_model=eval_model,
        tasks=[policy_train_task, value_train_task],
        eval_tasks=[policy_eval_task, value_eval_task],
        eval_at=(lambda _: True),
        # Switch the task every step.
        which_task=(lambda step: step % 2),
    )
    # Run for a couple of steps to make sure there are a few task switches.
    loop.run(n_steps=10)
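
A note on the head-selector pattern used in this test: PolicyAndValue produces two outputs (the policy head and the value head), and each task uses tl.Select to keep only the head it trains on. The sketch below is not part of the test; it uses a generic two-output layer as a stand-in for PolicyAndValue to illustrate why n_in=2 matters when dropping the value head.

# A minimal sketch (separate from the test above) of the head-selector
# pattern: a two-output stand-in for PolicyAndValue, plus tl.Select layers
# that keep one head and drop the other.
from trax import layers as tl

two_heads = tl.Branch(tl.Dense(3), tl.Dense(1))  # (policy-like head, value-like head)
pick_value = tl.Select([1])            # consumes 2 inputs, keeps only output 1
pick_policy = tl.Select([0], n_in=2)   # consumes 2 inputs, keeps only output 0
# Without n_in=2, tl.Select([0]) would read just one input, so the value head
# would stay on the stack and override the targets downstream - exactly the
# situation the comment in the test warns about.

value_path = tl.Serial(two_heads, pick_value)  # what the value tasks train on
# The policy path is built analogously with pick_policy.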
Code example #2
File: training.py  Project: yliu45/trax
  def __init__(
      self, task, model_fn,
      optimizer=adam.Adam,
      lr_schedule=lr.multifactor,
      batch_size=64,
      network_eval_at=None,
      n_eval_batches=1,
      max_slice_length=1,
      **kwargs
  ):
    """Initializes PolicyGradient.

    Args:
      task: Instance of trax.rl.task.RLTask.
      model_fn: Function (policy_distribution, mode) -> policy_model.
      optimizer: Optimizer for network training.
      lr_schedule: Learning rate schedule for network training.
      batch_size: Batch size for network training.
      network_eval_at: Function step -> bool indicating the training steps at
        which network evaluation should be performed.
      n_eval_batches: Number of batches to run during network evaluation.
      max_slice_length: The length of trajectory slices to run the network on.
      **kwargs: Keyword arguments passed to the superclass.
    """
    super().__init__(task, **kwargs)

    self._max_slice_length = max_slice_length
    trajectory_batch_stream = task.trajectory_batch_stream(
        batch_size,
        epochs=[-1],
        max_slice_length=self._max_slice_length,
        sample_trajectories_uniformly=True,
    )
    self._policy_dist = distributions.create_distribution(task.action_space)
    train_task = policy_tasks.PolicyTrainTask(
        trajectory_batch_stream,
        optimizer(),
        lr_schedule(),
        self._policy_dist,
        # Policy gradient uses the MC estimator. No need for margin - the MC
        # estimator only uses empirical returns.
        advantage_estimator=advantages.monte_carlo(task.gamma, margin=0),
        value_fn=self._value_fn,
    )
    eval_task = policy_tasks.PolicyEvalTask(train_task, n_eval_batches)
    model_fn = functools.partial(
        model_fn,
        policy_distribution=self._policy_dist,
    )

    if self._output_dir is not None:
      policy_output_dir = os.path.join(self._output_dir, 'policy')
    else:
      policy_output_dir = None
    # Checkpoint every epoch. We do one step per epoch, so that's every step.
    checkpoint_at = lambda _: True
    self._loop = supervised.training.Loop(
        model=model_fn(mode='train'),
        tasks=[train_task],
        eval_model=model_fn(mode='eval'),
        eval_tasks=[eval_task],
        output_dir=policy_output_dir,
        eval_at=network_eval_at,
        checkpoint_at=checkpoint_at,
    )
    self._collect_model = model_fn(mode='collect')
    self._collect_model.init(shapes.signature(train_task.sample_batch))

    # Validate the restored checkpoints. The number of network training steps
    # (self.loop.step) should be equal to the number of epochs (self._epoch),
    # because we do exactly one gradient step per epoch.
    # TODO(pkozakowski): Move this to the base class once all Agents use Loop.
    if self.loop.step != self._epoch:
      raise ValueError(
          'The number of Loop steps must equal the number of Agent epochs, '
          'got {} and {}.'.format(self.loop.step, self._epoch)
      )
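
For context, model_fn in this constructor must follow the (policy_distribution, mode) -> policy_model contract from the docstring; the code binds policy_distribution with functools.partial and then builds one model per mode ('train', 'eval', 'collect'). Below is a minimal sketch of a compatible model_fn. The body is an illustrative toy network, and it assumes the distribution exposes the head size it needs as policy_distribution.n_inputs, as trax's built-in policy models do.

# A minimal sketch (not from training.py) of a model_fn matching the
# (policy_distribution, mode) -> policy_model contract documented above.
# The body is illustrative; policy_distribution.n_inputs is assumed to give
# the output size the distribution expects for its parameters.
from trax import layers as tl


def model_fn(policy_distribution, mode='train'):
  del mode  # This toy policy network behaves the same in every mode.
  return tl.Serial(
      tl.Dense(64),
      tl.Relu(),
      tl.Dense(policy_distribution.n_inputs),  # policy head
  )

Because the constructor applies functools.partial(model_fn, policy_distribution=...) and afterwards only calls model_fn(mode=...), any function with this signature slots in unchanged.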
Code example #3
    def __init__(self,
                 task,
                 model_fn,
                 value_fn,
                 weight_fn,
                 n_replay_epochs,
                 n_train_steps_per_epoch,
                 optimizer=adam.Adam,
                 lr_schedule=lr.multifactor,
                 batch_size=64,
                 network_eval_at=None,
                 n_eval_batches=1,
                 max_slice_length=1,
                 **kwargs):
        """Initializes LoopPolicyAgent.

    Args:
      task: Instance of trax.rl.task.RLTask.
      model_fn: Function (policy_distribution, mode) -> policy_model.
      value_fn: Function TimeStepBatch -> array (batch_size, seq_len)
        calculating the baseline for advantage calculation.
      weight_fn: Function float -> float to apply to advantages when calculating
        policy loss.
      n_replay_epochs: Number of last epochs to take into the replay buffer;
        only makes sense for off-policy algorithms.
      n_train_steps_per_epoch: Number of steps to train the policy network for
        in each epoch.
      optimizer: Optimizer for network training.
      lr_schedule: Learning rate schedule for network training.
      batch_size: Batch size for network training.
      network_eval_at: Function step -> bool indicating the training steps, when
        network evaluation should be performed.
      n_eval_batches: Number of batches to run during network evaluation.
      max_slice_length: The length of trajectory slices to run the network on.
      **kwargs: Keyword arguments passed to the superclass.
    """
        self._n_train_steps_per_epoch = n_train_steps_per_epoch
        super().__init__(task, **kwargs)

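        # Use the last n_replay_epochs epochs of trajectories as the replay
        # buffer, as described in the docstring above.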
        task.set_n_replay_epochs(n_replay_epochs)
        self._max_slice_length = max_slice_length
        trajectory_batch_stream = task.trajectory_batch_stream(
            batch_size,
            epochs=[-(ep + 1) for ep in range(n_replay_epochs)],
            max_slice_length=self._max_slice_length,
            sample_trajectories_uniformly=True,
        )
        self._policy_dist = distributions.create_distribution(
            task.action_space)
        train_task = policy_tasks.PolicyTrainTask(
            trajectory_batch_stream,
            optimizer(),
            lr_schedule(),
            self._policy_dist,
            # Without a value network it doesn't make a lot of sense to use
            # a better advantage estimator than MC.
            advantage_estimator=advantages.monte_carlo(task.gamma, margin=0),
            value_fn=value_fn,
            weight_fn=weight_fn,
        )
        eval_task = policy_tasks.PolicyEvalTask(train_task, n_eval_batches)
        model_fn = functools.partial(
            model_fn,
            policy_distribution=self._policy_dist,
        )

        if self._output_dir is not None:
            policy_output_dir = os.path.join(self._output_dir, 'policy')
        else:
            policy_output_dir = None
        # Checkpoint every epoch.
        checkpoint_at = lambda step: step % n_train_steps_per_epoch == 0
        self._loop = supervised.training.Loop(
            model=model_fn(mode='train'),
            tasks=[train_task],
            eval_model=model_fn(mode='eval'),
            eval_tasks=[eval_task],
            output_dir=policy_output_dir,
            eval_at=network_eval_at,
            checkpoint_at=checkpoint_at,
        )
        self._collect_model = model_fn(mode='collect')
        self._collect_model.init(shapes.signature(train_task.sample_batch))

        # Validate the restored checkpoints.
        # TODO(pkozakowski): Move this to the base class once all Agents use Loop.
        if self._loop.step != self._epoch * self._n_train_steps_per_epoch:
            raise ValueError(
                'The number of Loop steps must equal the number of Agent epochs '
                'times the number of steps per epoch, got {}, {} and {}.'.
                format(self._loop.step, self._epoch,
                       self._n_train_steps_per_epoch))
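
Finally, value_fn and weight_fn in this constructor are ordinary callables with the signatures given in the docstring; they are passed straight into PolicyTrainTask. The sketch below shows one possible pair. Both choices are assumptions made for illustration rather than trax defaults: a zero baseline and an exponential, advantage-weighted weighting.

# Illustrative value_fn and weight_fn matching the signatures documented
# above. Both are assumptions for this sketch, not trax defaults.
import numpy as np


def zero_value_fn(trajectory_batch):
  # TimeStepBatch -> array (batch_size, seq_len): a trivial zero baseline,
  # assuming observations are shaped (batch_size, seq_len, ...).
  (batch_size, seq_len) = trajectory_batch.observation.shape[:2]
  return np.zeros((batch_size, seq_len))


def exp_weight_fn(advantage, beta=1.0):
  # float -> float: larger advantages get exponentially larger weights
  # (beta is a hypothetical temperature parameter).
  return np.exp(advantage / beta)

Any callables with these signatures can be plugged in; the constructor treats them as black boxes when building the PolicyTrainTask.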