Exemplo n.º 1
0
def test(env, actor_model):
    """
    Load a saved actor model and evaluate it on the given environment.

    Parameters:
        env - the environment to test the policy on; its observation_space
              and action_space must each expose a `shape`
        actor_model - the saved actor state dict to load (anything
                      torch.load accepts, typically a file path)

    Return:
        None. If no actor model is specified, a message is printed and the
        process exits with status 0.
    """
    print(f"Testing {actor_model}", flush=True)

    # Without a model file there is nothing to evaluate, so exit early.
    # `not actor_model` also covers None, which would otherwise only fail
    # later inside torch.load with a less helpful error.
    if not actor_model:
        print("Didn't specify model file. Exiting.", flush=True)
        sys.exit(0)

    # Extract out dimensions of observation and action spaces
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    # Build our policy the same way we build our actor model in PPO
    policy = FeedForwardNN(obs_dim, act_dim)

    # Load in the actor model saved by the PPO algorithm.
    # map_location="cpu" lets a checkpoint that was saved from a GPU be
    # deserialized on a machine without CUDA; load_state_dict then copies the
    # values into the policy's existing parameters, so the result is the same
    # regardless of where the checkpoint was written.
    state_dict = torch.load(actor_model, map_location="cpu")
    policy.load_state_dict(state_dict)

    # Evaluate our policy with a separate module, eval_policy, to demonstrate
    # that once we are done training the model/policy with ppo.py, we no longer
    # need ppo.py since it only contains the training algorithm. The
    # model/policy itself exists independently as a binary file that can be
    # loaded in with torch.
    eval_policy(policy=policy, env=env, render=True)
Exemplo n.º 2
0
def test(env, datapath, actor_model, mode):
    print(f"Testing {actor_model}", flush=True)

    if actor_model == '':
        print(f"Didn't specify model file. Exiting.", flush=True)
        sys.exit(0)

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    entries = os.listdir(datapath)
    entries = [int(x) for x in entries]
    entries.sort()

    if mode == 'test':
        policy = FeedForwardNN(obs_dim, act_dim)
        actor_model = datapath + "/" + str(entries[-1]) + "/" + actor_model
        policy.load_state_dict(torch.load(actor_model))
        print("Iteration " + str(entries[-1]))
        eval_policy(policy=policy, env=env, render=True)

    if mode == 'progress':
        eval_progress(list_dir = entries, file_dir = datapath, actor_model = actor_model,\
          obs_dim = obs_dim, act_dim = act_dim, env = env, render = True)
        '''