Exemple #1
0
def my_train_fn(config, reporter):
    """Train PPO on CartPole-v0 in two phases, lowering the LR in between.

    Phase 1 trains a fresh PPO agent with the settings given in ``config``.
    Phase 2 builds a second agent with ``lr`` set to 0.0001, restores the
    phase-1 checkpoint into it and trains for the same number of iterations.
    Every training result is forwarded to ``reporter`` with an added
    ``"phase"`` entry (1 or 2).

    Args:
        config: Dict-like PPO configuration. The optional key
            ``"train-iterations"`` (default 10) gives the number of
            iterations per phase and is removed before the config is handed
            to PPO. The caller's object is not modified.
        reporter: Callable invoked as ``reporter(**result)`` once per
            training iteration.
    """
    # Work on a shallow copy: popping "train-iterations" and overriding "lr"
    # on the caller's dict would change the outcome of any later call that
    # reuses the same config object.
    config = dict(config)
    iterations = config.pop("train-iterations", 10)

    # Train for n iterations with high LR
    phase1_time = 0  # stays 0 if no phase-1 iteration runs
    agent1 = PPO(env="CartPole-v0", config=config)
    try:
        for _ in range(iterations):
            result = agent1.train()
            result["phase"] = 1
            reporter(**result)
            phase1_time = result["timesteps_total"]
        state = agent1.save()
    finally:
        # Always release the agent, even if training or saving failed.
        agent1.stop()

    # Train for n iterations with low LR
    config["lr"] = 0.0001
    agent2 = PPO(env="CartPole-v0", config=config)
    try:
        agent2.restore(state)
        for _ in range(iterations):
            result = agent2.train()
            result["phase"] = 2
            result["timesteps_total"] += phase1_time  # keep time moving forward
            reporter(**result)
    finally:
        # Always release the agent, even if restoring or training failed.
        agent2.stop()
Exemple #2
0
    # Create a new dummy Trainer to "fix" our checkpoint.
    new_trainer = PPO(config=config)
    # Get untrained weights for all policies.
    untrained_weights = new_trainer.get_weights()
    # Restore all policies from checkpoint.
    new_trainer.restore(best_checkpoint)
    # Set back all weights (except for 1st agent) to original
    # untrained weights.
    new_trainer.set_weights(
        {pid: w
         for pid, w in untrained_weights.items() if pid != "policy_0"})
    # Create the checkpoint from which tune can pick up the
    # experiment.
    new_checkpoint = new_trainer.save()
    new_trainer.stop()
    print(".. checkpoint to restore from (all policies reset, "
          f"except policy_0): {new_checkpoint}")

    print("Starting new tune.run")

    # Start our actual experiment.
    stop = {
        "episode_reward_mean": args.stop_reward,
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
    }

    # Make sure the 1st policy (policy_0) is not updated anymore; only the
    # other policies are listed for training.
    config["multiagent"]["policies_to_train"] = [
        pid for pid in policy_ids if pid != "policy_0"
Exemple #3
0
    # Create a new dummy Algorithm to "fix" our checkpoint.
    new_algo = PPO(config=config)
    # Get untrained weights for all policies.
    untrained_weights = new_algo.get_weights()
    # Restore all policies from checkpoint.
    new_algo.restore(best_checkpoint)
    # Set back all weights (except for 1st agent) to original
    # untrained weights.
    new_algo.set_weights(
        {pid: w
         for pid, w in untrained_weights.items() if pid != "policy_0"})
    # Create the checkpoint from which tune can pick up the
    # experiment.
    new_checkpoint = new_algo.save()
    new_algo.stop()
    print(".. checkpoint to restore from (all policies reset, "
          f"except policy_0): {new_checkpoint}")

    print("Starting new tune.run")

    # Start our actual experiment.
    stop = {
        "episode_reward_mean": args.stop_reward,
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
    }

    # Make sure the 1st policy (policy_0) is not updated anymore; only the
    # other policies are listed for training.
    config["multiagent"]["policies_to_train"] = [
        pid for pid in policy_ids if pid != "policy_0"