Example #1
    results = tune.run(
        "PPO",
        config=config,
        stop={"training_iteration": args.pre_training_iters},
        verbose=1,
        checkpoint_freq=1,
        checkpoint_at_end=True,
    )
    print("Pre-training done.")

    best_checkpoint = results.get_best_checkpoint(results.trials[0],
                                                  mode="max")
    print(f".. best checkpoint was: {best_checkpoint}")

    # Create a new dummy Trainer to "fix" our checkpoint.
    new_trainer = PPO(config=config)
    # Get untrained weights for all policies.
    untrained_weights = new_trainer.get_weights()
    # Restore all policies from checkpoint.
    new_trainer.restore(best_checkpoint)
    # Reset the weights of all policies (except the 1st agent's) to the
    # original, untrained weights.
    new_trainer.set_weights(
        {pid: w
         for pid, w in untrained_weights.items() if pid != "policy_0"})
    # Create the checkpoint from which tune can pick up the
    # experiment.
    new_checkpoint = new_trainer.save()
    new_trainer.stop()
    print(".. checkpoint to restore from (all policies reset, "
          f"except policy_0): {new_checkpoint}")

    print("Starting new tune.run")
Example #2
    results = tune.run(
        "PPO",
        config=config,
        stop={"training_iteration": args.pre_training_iters},
        verbose=1,
        checkpoint_freq=1,
        checkpoint_at_end=True,
    )
    print("Pre-training done.")

    best_checkpoint = results.get_best_checkpoint(results.trials[0],
                                                  mode="max")
    print(f".. best checkpoint was: {best_checkpoint}")

    # Create a new dummy Algorithm to "fix" our checkpoint.
    new_algo = PPO(config=config)
    # Get untrained weights for all policies.
    untrained_weights = new_algo.get_weights()
    # Restore all policies from checkpoint.
    new_algo.restore(best_checkpoint)
    # Reset the weights of all policies (except the 1st agent's) to the
    # original, untrained weights.
    new_algo.set_weights(
        {pid: w
         for pid, w in untrained_weights.items() if pid != "policy_0"})
    # Create the checkpoint from which tune can pick up the
    # experiment.
    new_checkpoint = new_algo.save()
    new_algo.stop()
    print(".. checkpoint to restore from (all policies reset, "
          f"except policy_0): {new_checkpoint}")

    print("Starting new tune.run")
Example #3
    #     dqn_policy: X
    #     ppo_policy: Y
    for i in range(args.stop_iters):
        print("== Iteration", i, "==")

        # improve the DQN policy
        print("-- DQN --")
        result_dqn = dqn_trainer.train()
        print(pretty_print(result_dqn))

        # improve the PPO policy
        print("-- PPO --")
        result_ppo = ppo_trainer.train()
        print(pretty_print(result_ppo))

        # In test mode, stop once both policies exceed the requested reward.
        if (args.as_test
                and result_dqn["episode_reward_mean"] > args.stop_reward
                and result_ppo["episode_reward_mean"] > args.stop_reward):
            print("test passed (both agents above requested reward)")
            quit(0)

        # swap weights to synchronize
        dqn_trainer.set_weights(ppo_trainer.get_weights(["ppo_policy"]))
        ppo_trainer.set_weights(dqn_trainer.get_weights(["dqn_policy"]))

    # Desired reward not reached.
    if args.as_test:
        raise ValueError("Desired reward ({}) not reached!".format(
            args.stop_reward))
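
The loop above assumes two separately built trainers that share the same environment and policy map but each train only their own policy via "policies_to_train", so the get_weights()/set_weights() calls keep the other trainer's untrained copy in sync. A hypothetical setup along those lines (env name, mapping, and spaces are illustrative):

from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.policy.policy import PolicySpec

policies = {
    "ppo_policy": PolicySpec(),
    "dqn_policy": PolicySpec(),
}

def policy_mapping_fn(agent_id, *args, **kwargs):
    # Illustrative: even agent ids act with PPO, odd ones with DQN.
    return "ppo_policy" if agent_id % 2 == 0 else "dqn_policy"

ppo_trainer = PPOTrainer(
    env="multi_agent_cartpole",  # assumes a registered multi-agent env
    config={
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": policy_mapping_fn,
            # Only optimize the PPO policy; its copy of the DQN policy is
            # refreshed from dqn_trainer via set_weights() each iteration.
            "policies_to_train": ["ppo_policy"],
        },
    },
)

dqn_trainer = DQNTrainer(
    env="multi_agent_cartpole",
    config={
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": policy_mapping_fn,
            "policies_to_train": ["dqn_policy"],
        },
    },
)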
Example #4
    #     dqn_policy: X
    #     ppo_policy: Y
    for i in range(args.stop_iters):
        print("== Iteration", i, "==")

        # improve the DQN policy
        print("-- DQN --")
        result_dqn = dqn.train()
        print(pretty_print(result_dqn))

        # improve the PPO policy
        print("-- PPO --")
        result_ppo = ppo.train()
        print(pretty_print(result_ppo))

        # In test mode, stop once both policies exceed the requested reward.
        if (args.as_test
                and result_dqn["episode_reward_mean"] > args.stop_reward
                and result_ppo["episode_reward_mean"] > args.stop_reward):
            print("test passed (both agents above requested reward)")
            quit(0)

        # swap weights to synchronize
        dqn.set_weights(ppo.get_weights(["ppo_policy"]))
        ppo.set_weights(dqn.get_weights(["dqn_policy"]))

    # Desired reward not reached.
    if args.as_test:
        raise ValueError("Desired reward ({}) not reached!".format(
            args.stop_reward))
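
Example #4 is the same loop written against the newer Algorithm API, so ppo and dqn are built Algorithm objects rather than Trainers. A hypothetical way to construct them with AlgorithmConfig builders (again, the env name and the mapping are illustrative assumptions):

from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.policy.policy import PolicySpec

policies = {"ppo_policy": PolicySpec(), "dqn_policy": PolicySpec()}

def policy_mapping_fn(agent_id, *args, **kwargs):
    # Illustrative: even agent ids act with PPO, odd ones with DQN.
    return "ppo_policy" if agent_id % 2 == 0 else "dqn_policy"

ppo = (
    PPOConfig()
    .environment("multi_agent_cartpole")  # assumes a registered multi-agent env
    .multi_agent(
        policies=policies,
        policy_mapping_fn=policy_mapping_fn,
        policies_to_train=["ppo_policy"],
    )
    .build()
)

dqn = (
    DQNConfig()
    .environment("multi_agent_cartpole")
    .multi_agent(
        policies=policies,
        policy_mapping_fn=policy_mapping_fn,
        policies_to_train=["dqn_policy"],
    )
    .build()
)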