Example #1
def wrap_env(env_id, max_frames=5000, clip_rewards=True, episode_life=True):
    # Dispatch on the environment ID prefix and forward the caller's flags.
    if env_id.startswith('MiniGrid'):
        env = mini_grid_wrapper(
            env_id, max_frames=max_frames, clip_rewards=clip_rewards)
    elif env_id.startswith('GDY'):
        env = griddly_wrapper(
            env_id, max_frames=max_frames, clip_rewards=clip_rewards)
    else:
        env = atari_wrappers.wrap_deepmind(
            atari_wrappers.make_atari(env_id, max_frames=max_frames),
            episode_life=episode_life, clip_rewards=clip_rewards)
    return env
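For context, wrap_env only dispatches on the environment ID prefix, so callers just pass the ID. A minimal usage sketch (the MiniGrid ID below is illustrative and not taken from the project):

# Sketch only: env IDs are illustrative.
atari_env = wrap_env('PongNoFrameskip-v4', max_frames=5000)            # falls through to the Atari branch
minigrid_env = wrap_env('MiniGrid-Empty-5x5-v0', clip_rewards=False)   # handled by mini_grid_wrapper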
Example #2
def mini_grid_wrapper(env_id, max_frames=0, clip_rewards=True):
    env = gym.make(env_id)
    env = ReseedWrapper(env, seeds=[0])
    env = RGBImgObsWrapper(env)
    env = ImgObsWrapper(env)
    if max_frames:
        env = pfrl.wrappers.ContinuingTimeLimit(
            env, max_episode_steps=max_frames)
    # env = atari_wrappers.MaxAndSkipEnv(env, skip=0)
    env = atari_wrappers.wrap_deepmind(
        env, episode_life=False, clip_rewards=clip_rewards)
    return env
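A quick call sketch for the wrapper above, assuming gym-minigrid is installed (the environment ID is illustrative); note that ReseedWrapper(seeds=[0]) makes every reset reuse the same seed:

# Sketch only: the MiniGrid ID is an example, not taken from the project.
env = mini_grid_wrapper('MiniGrid-Empty-5x5-v0', max_frames=1000, clip_rewards=False)
obs = env.reset()  # image observation after the wrap_deepmind preprocessing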
Example #3
File: train_a3c.py, Project: lin826/pfrl
 def make_env(process_idx, test):
     # Use different random seeds for train and test envs
     process_seed = process_seeds[process_idx]
     env_seed = 2**31 - 1 - process_seed if test else process_seed
     env = atari_wrappers.wrap_deepmind(
         atari_wrappers.make_atari(args.env, max_frames=args.max_frames),
         episode_life=not test,
         clip_rewards=not test,
     )
     env.seed(int(env_seed))
     if args.monitor:
         env = pfrl.wrappers.Monitor(
             env, args.outdir, mode="evaluation" if test else "training")
     if args.render:
         env = pfrl.wrappers.Render(env)
     return env
Example #4
 def make_env(test):
     # Use different random seeds for train and test envs
     env_seed = test_seed if test else train_seed
     env = atari_wrappers.wrap_deepmind(
         atari_wrappers.make_atari(args.env, max_frames=args.max_frames),
         episode_life=not test,
         clip_rewards=not test,
     )
     env.seed(int(env_seed))
     if test:
         # Randomize actions like epsilon-greedy in evaluation as well
         env = pfrl.wrappers.RandomizeAction(env, args.eval_epsilon)
     if args.monitor:
         env = pfrl.wrappers.Monitor(
             env, args.outdir, mode="evaluation" if test else "training")
     if args.render:
         env = pfrl.wrappers.Render(env)
     return env
Example #5
 def make_env(idx, test):
     # Use different random seeds for train and test envs
     process_seed = int(process_seeds[idx])
     env_seed = 2**32 - 1 - process_seed if test else process_seed
     env = atari_wrappers.wrap_deepmind(
         atari_wrappers.make_atari(args.env, max_frames=args.max_frames),
         episode_life=not test,
         clip_rewards=not test,
         frame_stack=False,
     )
     if test:
         # Randomize actions like epsilon-greedy in evaluation as well
         env = pfrl.wrappers.RandomizeAction(env, args.eval_epsilon)
     env.seed(env_seed)
     if args.monitor:
         env = pfrl.wrappers.Monitor(
             env, args.outdir, mode="evaluation" if test else "training")
     if args.render:
         env = pfrl.wrappers.Render(env)
     return env
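In pfrl's batch-training examples, a per-index constructor like this is typically turned into a vectorized environment. A minimal sketch, assuming pfrl.envs.MultiprocessVectorEnv is used and that the script defines args.num_envs (that argument is an assumption, not shown above):

import functools

def make_batch_env(test):
    # One subprocess environment per index; each gets its own seed via make_env.
    return pfrl.envs.MultiprocessVectorEnv(
        [functools.partial(make_env, idx, test) for idx in range(args.num_envs)])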
Example #6
    wrap_packed_sequences_recursive,
)

logging.basicConfig(level=20)

ap = argparse.ArgumentParser(description="pfrl RNN DQN")
ap.add_argument("--model")
ap.add_argument("--demo", action="store_true")
ap.add_argument("--sleep", default=.02, type=float)
ap.add_argument("--env", default="PongNoFrameskip-v4")
ap.add_argument("--steps", default=1e7, type=int)
ns = ap.parse_args()

env = atari_wrappers.wrap_deepmind(
    atari_wrappers.make_atari(ns.env, max_frames=10000),
    episode_life=True,
    clip_rewards=True,
    frame_stack=False)
test_env = atari_wrappers.wrap_deepmind(
    atari_wrappers.make_atari(ns.env, max_frames=10000),
    episode_life=False,
    clip_rewards=False,
    frame_stack=False)


class MyNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Conv2d(1, 32, 8, stride=4)
        self.l2 = nn.ReLU()
        self.l3 = nn.Conv2d(32, 64, 4, stride=2)
Example #7
def griddly_wrapper(env_id, max_frames=5000, clip_rewards=True):
    env = gym.make(env_id)
    env.reset()
    env = UnswapChannel(env)
    if max_frames:
        # Apply the same step limit as mini_grid_wrapper.
        env = pfrl.wrappers.ContinuingTimeLimit(
            env, max_episode_steps=max_frames)
    env = atari_wrappers.wrap_deepmind(
        env, episode_life=False, clip_rewards=clip_rewards)
    return env
Example #8
import argparse

import gym
from stable_baselines3 import DQN
from pfrl.wrappers import atari_wrappers

ap = argparse.ArgumentParser(description="DQN")
ap.add_argument("--env", default="PongNoFrameskip-v4")
ap.add_argument("--frame_stacks", default=4)
ap.add_argument("--learning_starts", default=100000, type=int)
ap.add_argument("--total_timesteps", default=1000000, type=int)
ap.add_argument("--save_path")
ap.add_argument("--tensorboard_log")
ns = ap.parse_args()

env = atari_wrappers.wrap_deepmind(
    atari_wrappers.make_atari(ns.env, max_frames=10000),
    episode_life=True,
    clip_rewards=True,
)
model = DQN('CnnPolicy',
            env,
            verbose=1,
            buffer_size=10000,
            learning_rate=.0001,
            learning_starts=ns.learning_starts,
            target_update_interval=1000,
            tensorboard_log=ns.tensorboard_log)
model.learn(total_timesteps=ns.total_timesteps)
if ns.save_path is not None:
    model.save(ns.save_path)
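As a follow-up, the model saved above can be reloaded for evaluation with stable_baselines3's DQN.load. A minimal sketch (the greedy rollout below is illustrative, not part of the script):

# Sketch only: reload the saved model and run one greedy episode.
model = DQN.load(ns.save_path, env=env)
obs = env.reset()
done = False
while not done:
    # Greedy action from the loaded policy
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)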