Example #1
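An integration test from Trax's RL test suite: a shared-body PolicyAndValue network is trained with separate policy and value tasks inside a single training.Loop, with head selectors routing each task to its own output head. (The fixtures the test relies on, such as self._task and self._trajectory_batch_stream, are sketched after Example #2.)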
  def test_integration_with_policy_tasks(self):
    # Integration test for policy + value training and eval.
    optimizer = opt.Adam()
    lr_schedule = lr_schedules.constant(1e-3)
    advantage_estimator = advantages.td_k(gamma=self._task.gamma, margin=1)
    policy_dist = distributions.create_distribution(self._task.action_space)
    body = lambda mode: tl.Dense(64)
    train_model = models.PolicyAndValue(policy_dist, body=body)
    eval_model = models.PolicyAndValue(policy_dist, body=body)

    # Pick the value head (output 1) for the value train and eval tasks.
    head_selector = tl.Select([1])
    value_train_task = value_tasks.ValueTrainTask(
        self._trajectory_batch_stream,
        optimizer,
        lr_schedule,
        advantage_estimator,
        model=train_model,
        target_model=eval_model,
        head_selector=head_selector,
    )
    value_eval_task = value_tasks.ValueEvalTask(
        value_train_task, head_selector=head_selector
    )

    # Drop the value head: a plain tl.Select([0]) would default to n_in=1 and
    # pass the value output through, where it would override the targets.
    head_selector = tl.Select([0], n_in=2)
    policy_train_task = policy_tasks.PolicyTrainTask(
        self._trajectory_batch_stream,
        optimizer,
        lr_schedule,
        policy_dist,
        advantage_estimator,
        # Plug a trained critic as our value estimate.
        value_fn=value_train_task.value,
        head_selector=head_selector,
    )
    policy_eval_task = policy_tasks.PolicyEvalTask(
        policy_train_task, head_selector=head_selector
    )

    loop = training.Loop(
        model=train_model,
        eval_model=eval_model,
        tasks=[policy_train_task, value_train_task],
        eval_tasks=[policy_eval_task, value_eval_task],
        eval_at=(lambda _: True),
        # Switch the task every step.
        which_task=(lambda step: step % 2),
    )
    # Run for a couple of steps to make sure there are a few task switches.
    loop.run(n_steps=10)
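Note the multi-task wiring: tasks[0] is the policy task and tasks[1] the value task, so which_task=(lambda step: step % 2) alternates between them on each step, while eval_at=(lambda _: True) evaluates both heads at every step. Running for 10 steps therefore exercises several switches in each direction.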
Example #2
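A minimal smoke test for value-function training on its own: a single ValueTrainTask paired with its ValueEvalTask, run for one step of training.Loop.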
  def test_value_tasks_smoke(self):
    # Smoke test for train + eval.
    model = self._model_fn(mode='train')
    train_task = value_tasks.ValueTrainTask(
        self._trajectory_batch_stream,
        optimizer=opt.Adam(),
        lr_schedule=lr_schedules.constant(1e-3),
        advantage_estimator=advantages.td_k(gamma=self._task.gamma,
                                            margin=1),
        model=model,
    )
    eval_task = value_tasks.ValueEvalTask(train_task)
    loop = training.Loop(
        model=model,
        tasks=[train_task],
        eval_tasks=[eval_task],
        eval_at=(lambda _: True),
    )
    loop.run(n_steps=1)
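Both tests are methods of a test class and depend on fixtures (self._task, self._trajectory_batch_stream, self._model_fn) and module imports defined elsewhere in the file. Below is a minimal sketch of what they assume, based on Trax's public RL API; the concrete fixture values (environment, batch size, slice lengths) are illustrative assumptions, not the original test's settings.

import functools

from absl.testing import absltest

from trax import layers as tl
from trax import models
from trax import optimizers as opt
from trax.rl import advantages
from trax.rl import distributions
from trax.rl import policy_tasks
from trax.rl import task as rl_task
from trax.rl import value_tasks
from trax.supervised import lr_schedules
from trax.supervised import training


class ValueTasksTest(absltest.TestCase):

  def setUp(self):
    super().setUp()
    # A tiny CartPole task keeps the tests fast; these values are assumptions.
    self._task = rl_task.RLTask(
        'CartPole-v0', initial_trajectories=2, max_steps=10
    )
    # Stream of trajectory slices; margin=1 leaves room for td_k targets.
    self._trajectory_batch_stream = self._task.trajectory_batch_stream(
        batch_size=2, margin=1, min_slice_length=2
    )
    # Value-network factory used by the smoke test (hypothetical body).
    self._model_fn = functools.partial(
        models.Value, body=lambda mode: tl.Dense(64)
    )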