Example #1
File: env.py Project: my87432122/pso_td3
 def fn():
     # Build one environment instance; env_id, seed, and rank are
     # closed over from the enclosing factory's scope.
     env = gym.make(env_id)
     if 'SuperMarioBros' in env_id:
         # Apply the project's Mario-specific wrappers.
         env = ReshapeReward(env, monitor=None)
         env = SkipObs(env)
     env = SingleEnv(env)
     # Offset the seed by rank so parallel workers draw distinct streams.
     env.seed(seed + rank)
     env.action_space.seed(seed + rank)
     return env
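fn reads env_id, seed, and rank from an enclosing scope, which suggests it is a closure returned by an environment factory for vectorized training. A minimal sketch of such a factory, assuming the old (pre-0.26) gym seeding API used throughout this page; make_env, CartPole-v1, and the list comprehension are illustrative, not from the source:

import gym

def make_env(env_id, seed, rank):
    # Return a thunk so a vectorized-env launcher can build each
    # worker's environment lazily in its own process.
    def fn():
        env = gym.make(env_id)
        env.seed(seed + rank)
        env.action_space.seed(seed + rank)
        return env
    return fn

env_fns = [make_env('CartPole-v1', seed=0, rank=i) for i in range(4)]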
Example #2
    def eval_policy(self, env_name, seed, eval_episodes=5):
        '''Run the policy for eval_episodes episodes and log each
        episode's return and length.
        A fixed seed is used for the eval environment.
        Only off-policy algorithms need this.'''

        eval_env = gym.make(env_name)
        eval_env = SingleEnv(eval_env)
        eval_env.seed(seed + 100)
        # Assert env is wrapped by TimeLimit and env has attr _max_episode_steps
        for _ in range(eval_episodes):
            obs, done, ep_rew, ep_len = eval_env.reset(), False, 0., 0
            while not done:
                # Take deterministic actions at test time (noise=0)
                obs, rew, done, _ = eval_env.step(self.select_action(obs))
                ep_rew += rew
                ep_len += 1
            self.logger.store(TestEpRet=ep_rew, TestEplen=ep_len)
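The loop above depends on self.select_action being deterministic and on the class's logger for bookkeeping. The same evaluation loop can be written stand-alone against plain gym; a sketch, where evaluate and policy_fn are hypothetical stand-ins for the method and self.select_action:

import gym

def evaluate(policy_fn, env_name, seed, eval_episodes=5):
    # Offset the seed so evaluation never shares a stream with training.
    env = gym.make(env_name)
    env.seed(seed + 100)
    returns = []
    for _ in range(eval_episodes):
        obs, done, ep_rew = env.reset(), False, 0.0
        while not done:
            obs, rew, done, _ = env.step(policy_fn(obs))
            ep_rew += rew
        returns.append(ep_rew)
    return sum(returns) / len(returns)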
Example #3
    parser.add_argument('--policy', type=str, default='mp_ppo')
    parser.add_argument('--env', type=str, default='SuperMarioBros-8-1-v0')
    parser.add_argument('--policy_type', type=str, default='cnn')
    args = parser.parse_args()

    exp_name = args.policy + '_' + args.env

    output = f'./data/{exp_name}/{exp_name}_seed_0'

    env = gym.make(args.env)

    if 'SuperMario' in args.env:
        env = JoypadSpace(env, SIMPLE_MOVEMENT)
        env = ReshapeReward(env, monitor=None)
        env = SkipObs(env)
    env = SingleEnv(env)
    env = gym.wrappers.Monitor(env,
                               DEFAULT_VIDEO_DIR,
                               force=True,
                               video_callable=lambda episode_id: True)
    model_name = os.path.join(output, 'pytorch_save', 'model.pth')
    save_name = os.path.join(output, 'environment_vars.pkl')

    kwargs = {
        'env': env,
        'actor_critic': (core.CNNActorCritic
                         if args.policy_type == 'cnn'
                         else core.MLPActorCritic),
        'ac_kwargs': dict(hidden_sizes=[64] * 2),
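The snippet is cut off mid-dict in the source. Worth noting is the gym.wrappers.Monitor call above: video_callable=lambda episode_id: True records every episode rather than the default sampled subset, and force=True clears earlier recordings in the target directory. A minimal stand-alone use, assuming an older gym release that still ships Monitor (newer releases replace it with RecordVideo); the environment and paths are illustrative:

import gym

env = gym.make('CartPole-v1')
# Record every episode to ./videos; force=True overwrites old runs.
env = gym.wrappers.Monitor(env, './videos', force=True,
                           video_callable=lambda episode_id: True)
obs = env.reset()
done = False
while not done:
    obs, _, done, _ = env.step(env.action_space.sample())
env.close()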
Example #4
if __name__ == '__main__' and '__file__' in globals():
    parser = argparse.ArgumentParser()
    parser.add_argument('--policy_name', type=str, default='td3')
    parser.add_argument('--env_name', type=str, default='HalfCheetahPyBulletEnv-v0')
    parser.add_argument('--policy_type', type=str, default='mlp')
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()


    exp_name = args.policy_name + '_' + args.env_name

    output = f'./data/F_{exp_name}/{exp_name}_seed_{args.seed}'

    env = gym.make(args.env_name)

    env = SingleEnv(env)
    env = gym.wrappers.Monitor(env, DEFAULT_VIDEO_DIR, force=True, video_callable=lambda episode_id: True)
    model_name = os.path.join(output, 'pytorch_save', 'model.pth')
    save_name = os.path.join(output, 'environment_vars.pkl')

    kwargs = {
        'env': env,
        'actor_critic': core.MLPDetActorCritic,
        'ac_kwargs': dict(hidden_sizes=[400, 300]),
        'device': 'cpu',
    }

    policy = TD3(**kwargs)
    policy.ac.load_state_dict(torch.load(model_name, map_location='cpu'))

    # This command makes the video slower (setpts > 1.0) or faster (setpts < 1.0)
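The command the comment refers to is truncated in the source. A hypothetical reconstruction matching the comment, using ffmpeg's setpts video filter; the file names are placeholders, not from the source:

import subprocess

# setpts=2.0*PTS doubles every presentation timestamp, i.e. plays the
# recording at half speed; values below 1.0 speed it up.
subprocess.run(['ffmpeg', '-i', 'openaigym.video.mp4',
                '-filter:v', 'setpts=2.0*PTS', 'slow.mp4'])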