Example #1
def after_optimizer_iteration(trainer, fetches):
    """Update the policies pool in each policy."""
    update_kl(trainer, fetches)  # original PPO KL-coefficient update

    # Record each policy's current novelty target in its fetches.
    for pid in fetches.keys():
        fetches[pid]["novelty_target"] = \
            trainer.get_policy(pid)._novelty_target

    # DELAY_UPDATE and I_AM_CLONE are project-level config keys.
    if trainer.config[DELAY_UPDATE] and not trainer.config[I_AM_CLONE]:
        if trainer.workers.remote_workers():
            # Put the latest local weights into the object store once and
            # broadcast that single shared copy to every remote worker.
            weights = ray.put(trainer.workers.local_worker().get_weights())
            for e in trainer.workers.remote_workers():
                e.set_weights.remote(weights)

            def _delay_update_for_worker(worker, worker_index):
                worker.foreach_policy(lambda p, _: p.update_clone_network())

            trainer.workers.foreach_worker_with_index(_delay_update_for_worker)

    # CONSTRAIN_NOVELTY is likewise a project-level config key.
    if trainer.config[CONSTRAIN_NOVELTY] is not None:

        def update(pi, pi_id):
            if pi_id in fetches:
                pi.update_alpha(fetches[pi_id]["novelty_reward_mean"])
            else:
                logger.debug(
                    "No data for {}, not updating alpha".format(pi_id))

        trainer.workers.foreach_worker(
            lambda w: w.foreach_trainable_policy(update))
Example #2
def update_target_and_kl(trainer, fetches):
    # Update the target network (and, if enabled, the KL coefficient) once
    # the LearnerThread has stepped through enough batches.
    learner_steps = trainer.optimizer.learner.num_steps
    if learner_steps >= trainer.target_update_frequency:

        # Update Target Network
        trainer.optimizer.learner.num_steps = 0
        trainer.workers.local_worker().foreach_trainable_policy(
            lambda p, _: p.update_target())

        # Also update KL Coeff
        if trainer.config["use_kl_loss"]:
            update_kl(trainer, trainer.optimizer.learner.stats)
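
For context, hooks like update_target_and_kl were wired into a trainer through the pre-1.0 RLlib trainer template. Below is a minimal sketch of the registration, assuming 0.8-era RLlib where build_trainer / Trainer.with_updates accept an after_optimizer_step hook; the trainer name is hypothetical:

# Sketch only: registering update_target_and_kl as the after_optimizer_step
# hook of a PPO-derived trainer (0.8-era RLlib trainer-template API).
from ray.rllib.agents.ppo import PPOTrainer

# update_target_and_kl reads trainer.optimizer.learner, so this assumes a
# trainer built with an AsyncSamplesOptimizer owning a LearnerThread,
# as APPO is.
MyAPPOTrainer = PPOTrainer.with_updates(
    name="MyAPPO",
    after_optimizer_step=update_target_and_kl,
)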
Example #3
def after_optimizer_iteration(trainer, fetches):
    """Update the policies pool in each policy."""
    update_kl(trainer, fetches)  # original PPO procedure

    # Only update the policies pool if DELAY_UPDATE is used; otherwise the
    # policies_pool in each policy is simply never read, so there is nothing
    # to refresh.
    if trainer.config[DELAY_UPDATE]:
        if trainer.workers.remote_workers():
            # Broadcast the latest local weights to all remote workers.
            weights = ray.put(trainer.workers.local_worker().get_weights())
            for e in trainer.workers.remote_workers():
                e.set_weights.remote(weights)

            def _delay_update_for_worker(worker, worker_index):
                worker.foreach_policy(lambda p, _: p.update_target())

            trainer.workers.foreach_worker_with_index(_delay_update_for_worker)
Example #4
def wrap_after_train_result(trainer, fetches):
    """Run the novelty update, then the standard PPO KL update."""
    update_novelty(trainer, fetches)
    update_kl(trainer, fetches)
Example #5
def after_optimizer_step(trainer, fetches):
    """Standard KL update followed by project-specific custom evaluations."""
    update_kl(trainer=trainer, fetches=fetches)
    perform_relevant_custom_evals(trainer=trainer, fetches=fetches)
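
Every example above delegates to update_kl. A minimal sketch of what that helper does, assuming 0.8-era RLlib where each PPO policy exposes an update_kl(sampled_kl) method through its KL-coefficient mixin; this is a paraphrase of the behavior, not the verbatim library source:

import logging

logger = logging.getLogger(__name__)

# Sketch only: adjust each policy's KL coefficient from the sampled KL
# reported in the learner fetches.
def update_kl(trainer, fetches):
    if "kl" in fetches:
        # Single-agent case: fetches hold the stats of the lone policy.
        trainer.workers.local_worker().for_policy(
            lambda pi: pi.update_kl(fetches["kl"]))
    else:
        # Multi-agent case: fetches are keyed by policy id.
        def update(pi, pi_id):
            if pi_id in fetches:
                pi.update_kl(fetches[pi_id]["kl"])
            else:
                logger.debug(
                    "No data for {}, not updating kl".format(pi_id))

        trainer.workers.local_worker().foreach_trainable_policy(update)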