Example #1
import os

import gym
import numpy as np
import torch
from gym import wrappers
from rlpyt.agents.pg.mujoco import MujocoFfAgent
from rlpyt.envs.gym import GymEnvWrapper


def record_test(directory, video_path, n):
    # ENV_ID is assumed to be defined at module level in the original source.
    os.makedirs(video_path + n + '/', exist_ok=True)
    env = gym.make(ENV_ID)
    env = wrappers.Monitor(env,
                           video_path + n + '/',
                           video_callable=lambda episode_id: True,
                           force=True)
    env = GymEnvWrapper(env)  # rlpyt wrapper; provides the .spaces attribute used below
    env.seed(int(n) * 7)
    np.random.seed(int(n) * 7)
    torch.manual_seed(int(n) * 7)

    agent = MujocoFfAgent()
    agent.initialize(env.spaces)

    try:
        network_state_dict = torch.load(directory + 'agent_model.pth')
    except FileNotFoundError:
        print("No data found for the PPO agent (no existing model).")
        return

    agent.load_state_dict(network_state_dict)

    agent.to_device(0)  # move the agent's model to GPU 0

    frame_idx = 0
    print("Start Test Episode for {}".format(n))
    done = False
    ### Interaction
    step = 0
    state = env.reset()
    prev_action = env.action_space.sample()
    prev_reward = 0.
    while not done:  # or step < MAX_STEPS:
        env.render()
        state = torch.FloatTensor(state)
        prev_action = torch.FloatTensor(prev_action)
        prev_reward = torch.FloatTensor([prev_reward])
        # agent.eval_mode(step)  # deterministic distribution; the std is ignored.
        action = agent.step(state, prev_action, prev_reward).action
        action = action.detach().cpu().numpy()
        next_state, reward, done, _ = env.step(action)

        state = next_state
        prev_action = action
        prev_reward = reward
        frame_idx += 1
        step += 1

    env.close()
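
For reference, a minimal invocation of this helper could look like the sketch below; the checkpoint directory, video directory, and run id '3' are hypothetical, and ENV_ID must already be defined.

record_test('./checkpoints/', './videos/', '3')  # loads ./checkpoints/agent_model.pth, records to ./videos/3/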

Example #2

import argparse

import gym
import torch
from rlpyt.agents.pg.mujoco import MujocoFfAgent
from rlpyt.agents.qpg.sac_agent import SacAgent
from rlpyt.envs.gym import GymEnvWrapper
# The model classes (PpoMcpVisionModel, PPOMcpModel, PiVisionModel, QofMuVisionModel,
# PiMCPModel, QofMCPModel) and simulate_policy belong to the parkour-learning
# project; their module paths are not shown in this excerpt.


def main():  # the enclosing definition is truncated in the source; 'main' is a placeholder name
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--path',
                        help='path to params.pkl',
                        default='/home/alex/parkour-learning/data/params.pkl')
    parser.add_argument(
        '--env',
        default='HumanoidPrimitivePretraining-v0',
        choices=['HumanoidPrimitivePretraining-v0', 'TrackEnv-v0'])
    parser.add_argument('--algo', default='ppo', choices=['sac', 'ppo'])
    args = parser.parse_args()

    snapshot = torch.load(args.path, map_location=torch.device('cpu'))
    agent_state_dict = snapshot['agent_state_dict']
    env = GymEnvWrapper(gym.make(args.env, render=True))
    if args.algo == 'ppo':
        if args.env == 'TrackEnv-v0':
            agent = MujocoFfAgent(ModelCls=PpoMcpVisionModel)
        else:
            agent = MujocoFfAgent(ModelCls=PPOMcpModel)
    else:
        if args.env == 'TrackEnv-v0':
            agent = SacAgent(ModelCls=PiVisionModel,
                             QModelCls=QofMuVisionModel)
        else:
            agent = SacAgent(ModelCls=PiMCPModel, QModelCls=QofMCPModel)

    agent.initialize(env_spaces=env.spaces)
    agent.load_state_dict(agent_state_dict)
    agent.eval_mode(0)
    simulate_policy(env, agent)
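
With the placeholder main() above, the script can be closed with a standard entry point; the script name in the invocation below is an assumption, not given in the excerpt.

if __name__ == '__main__':
    main()

# Hypothetical invocation:
#   python simulate_policy.py --path /path/to/params.pkl --env TrackEnv-v0 --algo sac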