# Example #1
def my_policy_1155152886(num_envs=1):
    """Load the trained agent so the grading program can evaluate it.

    The checkpoint is read from ``data/cCarRacing-v0_PPO_12-01_21-39`` with
    suffix ``final``, relative to the working directory
    "ierg5350-assignment/assignment5/".

    :param num_envs: number of vectorized environments the returned policy
        should act in (defaults to 1 for single-env evaluation).
    :return: a callable policy produced by ``PolicyAPI``.
    """
    # NOTE: the function has already been renamed to include the student ID,
    # as required by the automatic agent-detection program.
    my_agent_log_dir = 'data/cCarRacing-v0_PPO_12-01_21-39'
    my_agent_suffix = 'final'

    return PolicyAPI("cCarRacing-v0",
                     num_envs=num_envs,
                     log_dir=my_agent_log_dir,
                     suffix=my_agent_suffix)
# Example #2
def student_compute_action_function(num_envs=1):
    """Entry point used by the grader to load the student agent for testing.

    Make sure this function can run bug-free when the working directory is
    "ierg6130-assignment/assignment4/".

    You can rewrite this function completely if you have custom agents, but
    you need to make sure the code is bug-free and add necessary description
    in report_SID.md. Run this file directly to make sure everything is fine.

    :param num_envs: number of vectorized environments the returned policy
        should act in.
    :return: a callable policy produced by ``PolicyAPI``.
    """
    # [TODO] rewrite this function
    log_dir = ""
    suffix = ""

    # Report whether the expected checkpoint file is present (best-effort;
    # loading itself is delegated to PolicyAPI below).
    ckpt = osp.join(log_dir, "checkpoint-{}.pkl".format(suffix))
    if osp.exists(ckpt):
        print("Found your checkpoint at {}!".format(ckpt))
    else:
        print("Can't find anything at {}!".format(ckpt))

    return PolicyAPI(num_envs=num_envs, log_dir=log_dir, suffix=suffix)
# Example #3
def my_policy_zhenghao(num_envs=1):
    """Load the "zhenghao" checkpoint from ``data/alphacar`` as a policy."""
    log_dir = "data/alphacar"
    suffix = "zhenghao"
    return PolicyAPI("cCarRacing-v0", num_envs=num_envs,
                     log_dir=log_dir, suffix=suffix)
# Example #4
def my_policy_1155156694(num_envs=1):
    """Load the "alphacar" checkpoint from ``data/alphacar`` as a policy."""
    log_dir = "data/alphacar"
    suffix = "alphacar"
    return PolicyAPI("cCarRacing-v0", num_envs=num_envs,
                     log_dir=log_dir, suffix=suffix)
# Example #5
def _parse_restore_checkpoint(restore_path):
    """Split a ``.../checkpoint-<suffix>.pkl`` path into (log_dir, suffix)."""
    log_dir = os.path.dirname(restore_path)
    suffix = os.path.basename(
        restore_path).split("checkpoint-")[1].split(".pkl")[0]
    return log_dir, suffix


def train(args):
    """Train a PPO agent on the configured environment.

    Sets up seeding, the (possibly competitive) vectorized environments,
    optional evaluation environments and optional checkpoint restoration,
    then runs the training loop until completion or a KeyboardInterrupt,
    and finally saves the weights with suffix "final".

    :param args: parsed CLI namespace; the fields read here are ``algo``,
        ``num_envs``, ``lr``, ``entropy``, ``env_id``, ``seed``, ``opponent``,
        ``log_dir``, ``restore``, ``num_eval_envs``, ``test`` and
        ``action_repeat``.
    :raises ValueError: if ``args.algo`` is unsupported or restoring the
        checkpoint fails.
    """
    # Verify algorithm and config
    algo = args.algo
    if algo == "PPO":
        config = ppo_config
    else:
        raise ValueError("args.algo must be in [PPO]")
    config.num_envs = args.num_envs
    config.lr = args.lr
    config.entropy_loss_weight = args.entropy
    assert args.env_id in ["cPong-v0", "cCarRacing-v0"], args.env_id

    # Seed the environments and setup torch
    seed = args.seed
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_num_threads(1)

    # Create vectorized environments
    num_envs = args.num_envs
    # Self-play against an opponent requires the two-player variant.
    env_id = args.env_id if not args.opponent else "cCarRacingDouble-v0"

    # Clean log directory; a timestamp keeps runs from clobbering each other.
    log_dir = verify_log_dir(
        args.log_dir,
        "{}_{}_{}".format(env_id, algo,
                          datetime.datetime.now().strftime("%m-%d_%H-%M")))

    if args.opponent:
        assert args.num_eval_envs == 0

        from competitive_rl.car_racing import make_competitive_car_racing
        from load_agents import PolicyAPI

        # The frozen opponent is restored from the same checkpoint path that
        # `--restore` points at.
        restore_log_dir, restore_suffix = _parse_restore_checkpoint(
            args.restore)
        opponent_policy = PolicyAPI("cCarRacing-v0",
                                    num_envs=1,
                                    log_dir=restore_log_dir,
                                    suffix=restore_suffix)
        envs = make_competitive_car_racing(opponent_policy=opponent_policy,
                                           num_envs=num_envs,
                                           asynchronous=not args.test)
    else:
        envs = make_envs(env_id=env_id,
                         seed=seed,
                         log_dir=log_dir,
                         num_envs=num_envs,
                         asynchronous=not args.test,
                         resized_dim=config.resized_dim,
                         action_repeat=args.action_repeat)

    if args.num_eval_envs > 0:
        eval_envs = make_envs(env_id=env_id,
                              seed=seed,
                              log_dir=log_dir,
                              num_envs=args.num_eval_envs,
                              asynchronous=not args.test,
                              resized_dim=config.resized_dim,
                              action_repeat=args.action_repeat)
    else:
        eval_envs = None

    # Setup trainer
    if algo == "PPO":
        trainer = PPOTrainer(envs, config)
    else:
        raise ValueError("Unknown algorithm {}".format(algo))

    if args.restore:
        restore_log_dir, restore_suffix = _parse_restore_checkpoint(
            args.restore)
        success = trainer.load_w(restore_log_dir, restore_suffix)
        if not success:
            raise ValueError(
                "We can't restore your agent. The log_dir is {} and the suffix is {}"
                .format(restore_log_dir, restore_suffix))

    # Start training
    print("Start training!")
    obs = envs.reset()
    # Prime the rollout storage with the first (processed) observation.
    raw_obs = trainer.process_obs(obs)
    processed_obs = trainer.model.world_model(raw_obs)
    trainer.rollouts.before_update(obs, processed_obs)

    try:
        _train(trainer, envs, eval_envs, config, num_envs, algo, log_dir,
               False, False)
    except KeyboardInterrupt:
        print(
            "The training is stopped by user. The log directory is {}. Now we finish the training."
            .format(log_dir))

    # Always persist final weights, even after a user interrupt.
    trainer.save_w(log_dir, "final")
    envs.close()