コード例 #1
0
def run(args):
    """Train and/or evaluate a policy-gradient agent and a DQN agent.

    Each flag on *args* is checked independently and in this order:
    ``train_pg``, ``train_dqn``, ``test_pg``, ``test_dqn``. Several stages
    can therefore run in one call. Agent classes are imported lazily from
    ``agent_dir``, only when their stage is enabled.

    Training stages honour ``args.env_name`` when it is truthy and otherwise
    default to Pong (PG) or Breakout (DQN). Evaluation stages always use the
    hard-coded game id.
    """
    if args.train_pg:
        # Policy gradient: environment built without the atari wrapper.
        pg_env = Environment(args.env_name or 'Pong-v0', args)
        from agent_dir.agent_pg import Agent_PG
        Agent_PG(pg_env, args).train()

    if args.train_dqn:
        # DQN: environment built with the atari wrapper.
        dqn_env = Environment(args.env_name or 'BreakoutNoFrameskip-v4',
                              args, atari_wrapper=True)
        from agent_dir.agent_dqn import Agent_DQN
        Agent_DQN(dqn_env, args).train()

    if args.test_pg:
        pg_env = Environment('Pong-v0', args, test=True)
        from agent_dir.agent_pg import Agent_PG
        test(Agent_PG(pg_env, args), pg_env)

    if args.test_dqn:
        dqn_env = Environment('BreakoutNoFrameskip-v4', args,
                              atari_wrapper=True, test=True)
        from agent_dir.agent_dqn import Agent_DQN
        test(Agent_DQN(dqn_env, args), dqn_env, total_episodes=100)
コード例 #2
0
def run(args):
    """Run ``play_game`` for one Seaquest episode with the DQN agent.

    Does nothing unless ``args.test_dqn`` is set. The environment is built
    with the atari wrapper in test mode.
    """
    if not args.test_dqn:
        return
    seaquest = Environment('SeaquestNoFrameskip-v0', args,
                           atari_wrapper=True, test=True)
    from agent_dir.agent_dqn import Agent_DQN
    player = Agent_DQN(seaquest, args)
    play_game(player, seaquest, total_episodes=1)
コード例 #3
0
def run(args):
    """Evaluate the DQN agent on Breakout for 100 episodes.

    Does nothing unless ``args.test_dqn`` is set. The environment is built
    with the atari wrapper in test mode.
    """
    if not args.test_dqn:
        return
    breakout = Environment('BreakoutNoFrameskip-v4', args,
                           atari_wrapper=True, test=True)
    from agent_dir.agent_dqn import Agent_DQN
    evaluated = Agent_DQN(breakout, args)
    test(evaluated, breakout, total_episodes=100)
コード例 #4
0
def run(args):
    """Evaluate a PG agent on Pong and/or a DQN agent on ``args.env_id``.

    Both flags (``test_pg``, ``test_dqn``) are checked independently. The PG
    stage uses a plain ``gym.make`` environment. The DQN stage uses
    ``make_wrap_atari`` with reward clipping disabled and runs 100 episodes.
    """
    if args.test_pg:
        pong = gym.make('Pong-v0')
        from agent_dir.agent_pg import Agent_PG
        test(Agent_PG(pong, args), pong)

    if args.test_dqn:
        # clip_rewards=False: presumably so test scores are unclipped game
        # scores -- confirm against make_wrap_atari.
        atari_env = make_wrap_atari(args.env_id, clip_rewards=False)
        from agent_dir.agent_dqn import Agent_DQN
        # NOTE(review): Agent_DQN is built as (args, env) here, while
        # Agent_PG above takes (env, args) -- confirm against agent_dqn.
        test(Agent_DQN(args, atari_env), atari_env, total_episodes=100)
def run(args):
    """Train and/or evaluate a DQN agent, defaulting to Seaquest.

    Both stages build their environment with the atari wrapper and honour
    ``args.env_name`` when it is truthy. For evaluation, ``args.visualize``
    chooses between a single ``play_game`` episode and a regular ``test``
    run. The evaluation environment is also created with
    ``frame_stack_and_origin=True``.
    """
    default_game = 'SeaquestNoFrameskip-v0'

    if args.train_dqn:
        train_env = Environment(args.env_name or default_game, args,
                                atari_wrapper=True)
        from agent_dir.agent_dqn import Agent_DQN
        Agent_DQN(train_env, args).train()

    if not args.test_dqn:
        return

    eval_env = Environment(args.env_name or default_game, args,
                           atari_wrapper=True, test=True,
                           frame_stack_and_origin=True)
    from agent_dir.agent_dqn import Agent_DQN
    agent = Agent_DQN(eval_env, args)

    if not args.visualize:
        print("<< test >>\n")
        test(agent, eval_env)
    else:
        print("<< visualization >>\n")
        play_game(args, agent, eval_env, total_episodes=1)
コード例 #6
0
ファイル: main_dqn.py プロジェクト: exe1023/GA-final
def run(args):
    """Train a DQN agent on ``args.env_id`` when ``args.train`` is set.

    Calls ``nsteps_train`` when ``args.n_steps > 1``; otherwise calls
    ``train``. The environment comes from ``make_wrap_atari`` with reward
    clipping enabled.
    """
    if not args.train:
        return
    from agent_dir.agent_dqn import Agent_DQN
    atari_env = make_wrap_atari(args.env_id, clip_rewards=True)
    learner = Agent_DQN(args, atari_env)
    fit = learner.nsteps_train if args.n_steps > 1 else learner.train
    fit()
コード例 #7
0
def run(args):
    """Train and/or evaluate a DQN agent, defaulting to Breakout.

    All frames are preprocessed with the atari wrapper.

    Args:
        args: parsed command-line namespace. Reads ``train`` and ``test``
            (stage flags), ``env_name`` (training only; falls back to
            Breakout when falsy) and ``testEpisodes`` (evaluation only).

    NOTE(review): ``Agent_DQN``, ``Environment`` and ``test`` are not
    imported in this block and are assumed to be module-level names.
    """
    if args.train:
        env_name = args.env_name or 'BreakoutNoFrameskip-v4'
        env = Environment(env_name, args, atari_wrapper=True)
        agent = Agent_DQN(env, args)
        agent.train()

    if args.test:
        env = Environment('BreakoutNoFrameskip-v4', args, atari_wrapper=True, test=True)
        agent = Agent_DQN(env, args)
        # The episode count comes from the parsed CLI namespace (`args`).
        test(agent, env, total_episodes=args.testEpisodes)
コード例 #8
0
def run(args):
    """Train and/or evaluate a DQN agent, defaulting to Space Invaders.

    Both stages build their environment with the atari wrapper. Training
    honours ``args.env_name`` when it is truthy. Evaluation always uses
    Space Invaders and runs 100 episodes.
    """
    space_invaders = 'SpaceInvadersNoFrameskip-v4'

    if args.train_dqn:
        train_env = Environment(args.env_name or space_invaders, args,
                                atari_wrapper=True)
        from agent_dir.agent_dqn import Agent_DQN
        Agent_DQN(train_env, args).train()

    if args.test_dqn:
        eval_env = Environment(space_invaders, args, atari_wrapper=True,
                               test=True)
        from agent_dir.agent_dqn import Agent_DQN
        test(Agent_DQN(eval_env, args), eval_env, total_episodes=100)
コード例 #9
0
def run(args):
    """Run every training / evaluation stage whose flag is set on *args*.

    Stages are checked in a fixed order and are not mutually exclusive, so a
    single call may train and test several agents. For each enabled stage:

    1. an ``Environment`` is built;
    2. the matching agent class is imported lazily from ``agent_dir``;
    3. the agent is either trained (``.train()``) or evaluated with ``test``.

    Training stages honour ``args.env_name`` when it is truthy. Otherwise
    they fall back to Pong (PG, Pong actor-critic) or Breakout (DQN
    variants, Breakout actor-critic). Evaluation stages always use the
    hard-coded game. Pong stages build the environment without the atari
    wrapper; Breakout stages use it. DQN-family evaluations run 100
    episodes, while PG evaluations use ``test``'s default episode count.
    """
    pong = 'Pong-v0'
    breakout = 'BreakoutNoFrameskip-v4'

    def fit_env(default_game, atari):
        # Training environment: --env_name overrides the per-algorithm default.
        game = args.env_name or default_game
        if atari:
            return Environment(game, args, atari_wrapper=True)
        return Environment(game, args)

    def pong_eval_env():
        return Environment(pong, args, test=True)

    def breakout_eval_env():
        return Environment(breakout, args, atari_wrapper=True, test=True)

    if args.train_pg:
        env = fit_env(pong, atari=False)
        from agent_dir.agent_pg import Agent_PG
        Agent_PG(env, args).train()

    if args.train_pg_improved:
        env = fit_env(pong, atari=False)
        from agent_dir.agent_pg_improved import Agent_PG
        Agent_PG(env, args).train()

    if args.train_dqn:
        env = fit_env(breakout, atari=True)
        from agent_dir.agent_dqn import Agent_DQN
        Agent_DQN(env, args).train()

    if args.train_dqn_improved:
        env = fit_env(breakout, atari=True)
        from agent_dir.agent_dqn_improved import Agent_DQN
        Agent_DQN(env, args).train()

    if args.train_pong_ac:
        env = fit_env(pong, atari=False)
        from agent_dir.agent_pong_ac import Agent_AC
        Agent_AC(env, args).train()

    if args.train_break_ac:
        env = fit_env(breakout, atari=True)
        from agent_dir.agent_break_ac import Agent_AC
        Agent_AC(env, args).train()

    if args.train_pong_ac_improved:
        env = fit_env(pong, atari=False)
        from agent_dir.agent_pong_ac_improved import Agent_AC
        Agent_AC(env, args).train()

    if args.train_break_ac_improved:
        env = fit_env(breakout, atari=True)
        from agent_dir.agent_break_ac_improved import Agent_AC
        Agent_AC(env, args).train()

    if args.test_pg:
        env = pong_eval_env()
        from agent_dir.agent_pg import Agent_PG
        test(Agent_PG(env, args), env)

    if args.test_pg_improved:
        env = pong_eval_env()
        from agent_dir.agent_pg_improved import Agent_PG
        test(Agent_PG(env, args), env)

    if args.test_dqn:
        env = breakout_eval_env()
        from agent_dir.agent_dqn import Agent_DQN
        test(Agent_DQN(env, args), env, total_episodes=100)

    if args.test_dqn_improved:
        env = breakout_eval_env()
        from agent_dir.agent_dqn_improved import Agent_DQN
        test(Agent_DQN(env, args), env, total_episodes=100)

    if args.train_ddqn:
        env = fit_env(breakout, atari=True)
        from agent_dir.agent_ddqn import Agent_DQN
        Agent_DQN(env, args).train()

    if args.test_ddqn:
        env = breakout_eval_env()
        from agent_dir.agent_ddqn import Agent_DQN
        test(Agent_DQN(env, args), env, total_episodes=100)

    if args.train_dddqn:
        env = fit_env(breakout, atari=True)
        from agent_dir.agent_dddqn import Agent_DQN
        Agent_DQN(env, args).train()

    if args.test_dddqn:
        env = breakout_eval_env()
        from agent_dir.agent_dddqn import Agent_DQN
        test(Agent_DQN(env, args), env, total_episodes=100)