Exemplo n.º 1
0
    def __init__(self, env_name, atari_wrapper=False, test=False):
        """Create the underlying environment and expose its spaces.

        Args:
            env_name: Environment id forwarded to ``make_wrap_atari`` or
                ``gym.make``.
            atari_wrapper: When truthy, build the environment with
                ``make_wrap_atari`` instead of plain ``gym.make``.
            test: Negated and passed as the second positional argument of
                ``make_wrap_atari`` (the reward-clipping switch), so
                clipping is requested only when not testing.
        """
        if not atari_wrapper:
            env = gym.make(env_name)
        else:
            # Reward clipping is the inverse of the test flag.
            env = make_wrap_atari(env_name, not test)

        self.env = env
        self.action_space = env.action_space
        self.observation_space = env.observation_space
Exemplo n.º 2
0
def run(args):
    """Train a DQN agent when ``args.train`` is set.

    Builds an Atari environment from ``args.env_id`` with reward clipping
    enabled, constructs ``Agent_DQN`` and starts ``nsteps_train`` when
    ``args.n_steps`` is greater than 1, otherwise ``train``. Does nothing
    when ``args.train`` is falsy.

    Args:
        args: Object exposing ``train``, ``env_id`` and ``n_steps``; it is
            also handed to the agent constructor unchanged.
    """
    if args.train:
        # Local import: the agent module is only needed on the training path.
        from agent_dir.agent_dqn import Agent_DQN

        dqn_agent = Agent_DQN(
            args,
            make_wrap_atari(args.env_id, clip_rewards=True),
        )
        # Pick the training entry point based on the requested step count.
        start_training = (dqn_agent.nsteps_train
                          if args.n_steps > 1
                          else dqn_agent.train)
        start_training()
Exemplo n.º 3
0
    def __init__(self, env_name, args, atari_wrapper=False, test=False):
        """Build the environment, record its spaces, then wrap it in a Monitor.

        Args:
            env_name: Environment id forwarded to ``make_wrap_atari`` or
                ``gym.make``.
            args: Accepted as part of the signature; not read in this method.
            atari_wrapper: When truthy, build the environment with
                ``make_wrap_atari`` instead of ``gym.make``.
            test: Negated and passed as the second positional argument of
                ``make_wrap_atari`` (the reward-clipping switch).
        """
        # Reward clipping is requested only when not testing.
        base_env = (make_wrap_atari(env_name, not test)
                    if atari_wrapper
                    else gym.make(env_name))
        self.env = base_env

        # The spaces are read from the environment before the Monitor is applied.
        self.action_space = base_env.action_space
        self.observation_space = base_env.observation_space

        # force=True is passed through to Monitor together with the fixed
        # './render' output directory.
        self.env = gym.wrappers.Monitor(base_env, './render', force=True)
Exemplo n.º 4
0
def run(args):
    """Run the test routine for each agent type requested on ``args``.

    When ``args.test_pg`` is set, a ``Pong-v0`` environment and an
    ``Agent_PG`` are created and passed to ``test``. When ``args.test_dqn``
    is set, an Atari environment for ``args.env_id`` is created without
    reward clipping, an ``Agent_DQN`` is built, and ``test`` is run with
    ``total_episodes=100``. Both branches may run in the same call.

    Args:
        args: Object exposing ``test_pg``, ``test_dqn`` and ``env_id``; it
            is also handed to the agent constructors unchanged.
    """
    if args.test_pg:
        pong_env = gym.make('Pong-v0')
        # Local import: only needed when the policy-gradient test is requested.
        from agent_dir.agent_pg import Agent_PG
        test(Agent_PG(pong_env, args), pong_env)

    if args.test_dqn:
        atari_env = make_wrap_atari(args.env_id, clip_rewards=False)
        # Local import: only needed when the DQN test is requested.
        from agent_dir.agent_dqn import Agent_DQN
        test(Agent_DQN(args, atari_env), atari_env, total_episodes=100)
Exemplo n.º 5
0
    def __init__(self, env_name, args, atari_wrapper=False, test=False):
        """Create the environment, expose its spaces, and optionally record video.

        Args:
            env_name: Environment id forwarded to ``make_wrap_atari`` or
                ``gym.make``.
            args: Object exposing ``do_render`` and ``video_dir``.
            atari_wrapper: When truthy, build the environment with
                ``make_wrap_atari`` instead of ``gym.make``.
            test: Negated and passed as the second positional argument of
                ``make_wrap_atari`` (the reward-clipping switch): rewards
                are clipped when not testing and left unclipped otherwise.
        """
        if not atari_wrapper:
            self.env = gym.make(env_name)
        else:
            self.env = make_wrap_atari(env_name, not test)

        # Spaces are taken from the environment before any Monitor wrapping.
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space

        self.do_render = args.do_render

        # A truthy video_dir turns on the Monitor wrapper writing to that path.
        video_dir = args.video_dir
        if video_dir:
            self.env = gym.wrappers.Monitor(self.env, video_dir, force=True)
Exemplo n.º 6
0
    def __init__(self, env_name, args, atari_wrapper=False, test=False):
        """Create a Mario, Atari-wrapped, or plain gym environment.

        The factory is chosen in this order: ``create_mario_env`` when
        ``env_name`` contains ``'Mario'``, ``make_wrap_atari`` when
        ``atari_wrapper`` is truthy, otherwise ``gym.make``.

        Args:
            env_name: Environment id; also inspected for the substring
                ``'Mario'``.
            args: Object exposing ``do_render`` and ``video_dir``.
            atari_wrapper: Selects ``make_wrap_atari`` for non-Mario ids.
            test: Negated and passed as the second positional argument of
                ``make_wrap_atari`` (the reward-clipping switch).
        """
        if 'Mario' in env_name:
            # Imported lazily so mario_env is only required for Mario ids.
            from mario_env import create_mario_env
            env = create_mario_env(env_name)
        elif atari_wrapper:
            env = make_wrap_atari(env_name, not test)
        else:
            env = gym.make(env_name)
        self.env = env

        # Spaces are taken from the environment before any Monitor wrapping.
        self.action_space = env.action_space
        self.observation_space = env.observation_space

        self.do_render = args.do_render

        video_dir = args.video_dir
        if video_dir:
            self.env = gym.wrappers.Monitor(env, video_dir, force=True)
Exemplo n.º 7
0
    def __init__(self,
                 env_name,
                 args,
                 atari_wrapper=False,
                 test=False,
                 frame_stack_and_origin=False):
        """Create the environment, expose its spaces, and optionally record video.

        Args:
            env_name: Environment id forwarded to ``make_wrap_atari`` or
                ``gym.make``.
            args: Object exposing ``do_render`` and ``video_dir``.
            atari_wrapper: When truthy, build the environment with
                ``make_wrap_atari`` instead of ``gym.make``.
            test: Negated to obtain the reward-clipping switch passed to
                ``make_wrap_atari``.
            frame_stack_and_origin: Stored on the instance and forwarded as
                the third positional argument of ``make_wrap_atari``; it is
                not used when the environment comes from ``gym.make``.
        """
        # Store the flag unconditionally so the attribute exists on every
        # instance; previously it was only set on the Atari path and reading
        # it on a plain gym environment raised AttributeError.
        self.frame_stack_and_origin = frame_stack_and_origin

        if atari_wrapper:
            # Reward clipping is requested only when not testing.
            clip_rewards = not test
            self.env = make_wrap_atari(env_name, clip_rewards,
                                       frame_stack_and_origin)
        else:
            self.env = gym.make(env_name)

        # Spaces are taken from the environment before any Monitor wrapping.
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space

        self.do_render = args.do_render

        # A truthy video_dir turns on the Monitor wrapper writing to that path.
        if args.video_dir:
            self.env = gym.wrappers.Monitor(self.env,
                                            args.video_dir,
                                            force=True)
                print(" ")
                print("--------------")
                print(" ")
                agent.flag = False




# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="DQN Breakout")
parser.add_argument(
    '--train_dqn',
    action='store_true',
    help='whether train DQN',
)
parser.add_argument(
    '--test_dqn',
    action='store_true',
    help='whether test DQN',
)
parser.add_argument(
    '--render',
    action='store_true',
    help='whether render environment or not',
)
parser.add_argument(
    '--episodes',
    type=int,
    default=50000,
    help='Number of episodes to run',
)
args = parser.parse_args()

# ---------------------------------------------------------------------------
# Module-level configuration
# These names are deliberately kept as globals; other code in this file may
# read them, so none of them is renamed or moved into a function.
# ---------------------------------------------------------------------------
env = make_wrap_atari('BreakoutNoFrameskip-v4')
batch_size = 32
model_check = 10
EPSILON_START = 1.0
video_width = video_height = 84
stack_images = 4

# Build the agent from the environment's observation shape and action count,
# then hand the parsed command-line options to its experiment runner.
agent = DQN_Agent(env.observation_space.shape, env.action_space.n)
agent.run_experient(args)