コード例 #1
0
def run(args):
    """Run the policy-gradient and/or DQN jobs selected by ``args``.

    Flags are checked in this order and every truthy one is executed:
    ``train_pg``, ``train_dqn``, ``test_pg``, ``test_dqn``.  Each agent
    module is imported only inside the branch that uses it.

    Args:
        args: parsed options exposing the four flags above and, for the
            training branches, ``env_name`` (a falsy value falls back to
            that branch's default environment).
    """
    # Policy gradient on Pong -- plain Environment, no atari_wrapper.
    if args.train_pg:
        environment = Environment(args.env_name or 'Pong-v0', args)
        from agent_dir.agent_pg import Agent_PG
        learner = Agent_PG(environment, args)
        learner.train()

    # DQN on Breakout -- Environment built with atari_wrapper=True.
    if args.train_dqn:
        environment = Environment(args.env_name or 'BreakoutNoFrameskip-v4',
                                  args,
                                  atari_wrapper=True)
        from agent_dir.agent_dqn import Agent_DQN
        learner = Agent_DQN(environment, args)
        learner.train()

    # Evaluation always uses the fixed environments; args.env_name is not read.
    if args.test_pg:
        environment = Environment('Pong-v0', args, test=True)
        from agent_dir.agent_pg import Agent_PG
        learner = Agent_PG(environment, args)
        test(learner, environment)

    if args.test_dqn:
        environment = Environment('BreakoutNoFrameskip-v4', args,
                                  atari_wrapper=True, test=True)
        from agent_dir.agent_dqn import Agent_DQN
        learner = Agent_DQN(environment, args)
        test(learner, environment, total_episodes=100)
コード例 #2
0
ファイル: main_pg.py プロジェクト: exe1023/GA-final
def run(args):
    """Train a policy-gradient agent on an MK-trap environment (m=20, k=5).

    Earlier revisions swapped in gym's CartPole-v0, LunarLander-v2 or
    Pong-v0 here instead of ``trap.MKTrap``.

    Args:
        args: parsed options, forwarded unchanged to ``Agent_PG``.
    """
    environment = trap.MKTrap(m=20, k=5)
    # Presumably (target reward, episode window): 195 over 100 episodes is the
    # CartPole "solved" rule the original comment cites.
    # NOTE(review): confirm this threshold is meaningful for MKTrap.
    solve_criterion = (195, 100)
    learner = Agent_PG(args, environment, solve_criterion)
    learner.train()
コード例 #3
0
ファイル: test40.py プロジェクト: sinogaorui/DQ
def run(args):
    """Evaluate the Pong PG agent and/or the Breakout DQN agent.

    Args:
        args: parsed options exposing ``test_pg`` and ``test_dqn``; every
            truthy flag runs its evaluation, PG first.
    """
    if args.test_pg:
        environment = Environment('Pong-v0', args, test=True)
        from agent_dir.agent_pg import Agent_PG
        learner = Agent_PG(environment, args)
        test(learner, environment)

    if args.test_dqn:
        environment = Environment('BreakoutNoFrameskip-v4', args,
                                  atari_wrapper=True, test=True)
        from agent_dir.agent_dqn import Agent_DQN
        learner = Agent_DQN(environment, args)
        # Only 5 episodes; the earlier "#100" remark suggests 100 was the
        # usual count -- presumably shortened for quick runs.
        test(learner, environment, total_episodes=5)
コード例 #4
0
def run(args):
    """Evaluate a Pong PG agent and/or an Atari DQN agent.

    Args:
        args: parsed options exposing ``test_pg``, ``test_dqn`` and, for the
            DQN branch only, ``env_id`` (passed to ``make_wrap_atari``).
    """
    if args.test_pg:
        game = gym.make('Pong-v0')
        from agent_dir.agent_pg import Agent_PG
        player = Agent_PG(game, args)
        test(player, game)

    if args.test_dqn:
        # clip_rewards=False is passed, presumably so evaluation reports raw
        # game scores.
        game = make_wrap_atari(args.env_id, clip_rewards=False)
        from agent_dir.agent_dqn import Agent_DQN
        # NOTE(review): Agent_DQN receives (args, env), the reverse of
        # Agent_PG(env, args) above -- confirm against Agent_DQN.__init__.
        player = Agent_DQN(args, game)
        test(player, game, total_episodes=100)
コード例 #5
0
def run(args):
    """Train and/or evaluate the Pong policy-gradient agent.

    Args:
        args: parsed options exposing ``train`` and ``test``; training also
            reads ``env_name`` (a falsy value selects ``'Pong-v0'``).
            Training, if requested, runs before testing.
    """
    if args.train:
        environment = Environment(args.env_name or 'Pong-v0', args)
        from agent_dir.agent_pg import Agent_PG
        learner = Agent_PG(environment, args)
        learner.train()

    # Evaluation is pinned to Pong-v0 regardless of args.env_name.
    if args.test:
        environment = Environment('Pong-v0', args, test=True)
        from agent_dir.agent_pg import Agent_PG
        learner = Agent_PG(environment, args)
        test(learner, environment)
コード例 #6
0
ファイル: main.py プロジェクト: ced211/INFO8003-1
def run(args):
    """Run the PG, actor-critic and CartPole-PG jobs selected by ``args``.

    Flags are checked in the order ``train_pg``, ``test_pg``, ``train_ac``,
    ``train_pgc`` and every truthy one is executed.

    Args:
        args: parsed options exposing the four flags above and, for the
            training branches, ``env_name`` (a falsy value falls back to
            that branch's default environment).
    """
    if args.train_pg:
        environment = Environment(args.env_name or 'Pong-v0', args)
        from agent_dir.agent_pg import Agent_PG
        learner = Agent_PG(environment, args)
        learner.train()

    if args.test_pg:
        environment = Environment('Pong-v0', args, test=True)
        from agent_dir.agent_pg import Agent_PG
        learner = Agent_PG(environment, args)
        test(learner, environment)

    # The two jobs below are CartPole experiments: training only, there is
    # no corresponding test flag.
    if args.train_ac:
        environment = Environment(args.env_name or 'CartPole-v0', args)
        from agent_dir.agent_actorcritic import Agent_ActorCritic
        learner = Agent_ActorCritic(environment, args)
        learner.train()

    if args.train_pgc:
        environment = Environment(args.env_name or 'CartPole-v0', args)
        from agent_dir.agent_pg_cart import Agent_PGC
        learner = Agent_PGC(environment, args)
        learner.train()
コード例 #7
0
def run(args):
    """Evaluate the Pong policy-gradient agent when ``args.test`` is truthy.

    Args:
        args: parsed options exposing ``test``; otherwise this is a no-op.
    """
    if not args.test:
        return
    environment = Environment('Pong-v0', args, test=True)
    from agent_dir.agent_pg import Agent_PG
    learner = Agent_PG(environment, args)
    test(learner, environment)
コード例 #8
0
def run(args):
    """Train and/or test every agent whose flag is set on ``args``.

    The jobs are described by one ordered table instead of sixteen
    copy-pasted ``if`` blocks. For each entry whose ``args.<flag>`` is truthy:

    1. an ``Environment`` is built (see the env specs below);
    2. the agent class is imported lazily from ``agent_dir.<module>``, so only
       the selected agents' dependencies are loaded;
    3. the agent is constructed as ``AgentClass(env, args)``;
    4. it is either trained (``agent.train()``) or evaluated
       (``test(agent, env, **test_kwargs)``).

    Several flags may be set at once; they run in table order, which matches
    the order of the original if-chain.

    Args:
        args: parsed options. Every flag named in the table is read, so a
            missing one raises ``AttributeError``. Training runs also read
            ``env_name``, where a falsy value selects the entry's default
            environment. Test runs always use the default environment.
    """
    import importlib

    # (default environment name, extra Environment kwargs). Pong is built
    # without passing atari_wrapper at all; Breakout passes atari_wrapper=True.
    pong = ('Pong-v0', {})
    breakout = ('BreakoutNoFrameskip-v4', {'atari_wrapper': True})
    # test() kwargs for the DQN-family evaluations. PG evaluations pass none
    # and so rely on test()'s own default episode count.
    dqn_test = {'total_episodes': 100}

    # (flag, env spec, module under agent_dir, agent class name,
    #  test() kwargs -- or None for a training run).
    jobs = (
        ('train_pg', pong, 'agent_pg', 'Agent_PG', None),
        ('train_pg_improved', pong, 'agent_pg_improved', 'Agent_PG', None),
        ('train_dqn', breakout, 'agent_dqn', 'Agent_DQN', None),
        ('train_dqn_improved', breakout, 'agent_dqn_improved', 'Agent_DQN',
         None),
        ('train_pong_ac', pong, 'agent_pong_ac', 'Agent_AC', None),
        ('train_break_ac', breakout, 'agent_break_ac', 'Agent_AC', None),
        ('train_pong_ac_improved', pong, 'agent_pong_ac_improved', 'Agent_AC',
         None),
        ('train_break_ac_improved', breakout, 'agent_break_ac_improved',
         'Agent_AC', None),
        ('test_pg', pong, 'agent_pg', 'Agent_PG', {}),
        ('test_pg_improved', pong, 'agent_pg_improved', 'Agent_PG', {}),
        ('test_dqn', breakout, 'agent_dqn', 'Agent_DQN', dqn_test),
        ('test_dqn_improved', breakout, 'agent_dqn_improved', 'Agent_DQN',
         dqn_test),
        ('train_ddqn', breakout, 'agent_ddqn', 'Agent_DQN', None),
        ('test_ddqn', breakout, 'agent_ddqn', 'Agent_DQN', dqn_test),
        ('train_dddqn', breakout, 'agent_dddqn', 'Agent_DQN', None),
        ('test_dddqn', breakout, 'agent_dddqn', 'Agent_DQN', dqn_test),
    )

    for flag, (default_env, env_kwargs), module, class_name, test_kwargs in jobs:
        # No default: a missing flag must raise, exactly like ``args.<flag>``.
        if not getattr(args, flag):
            continue
        is_test = test_kwargs is not None
        if is_test:
            # Evaluation is pinned to the default env; args.env_name is ignored.
            env = Environment(default_env, args, test=True, **env_kwargs)
        else:
            env = Environment(args.env_name or default_env, args, **env_kwargs)
        agent_module = importlib.import_module('agent_dir.' + module)
        agent = getattr(agent_module, class_name)(env, args)
        if is_test:
            test(agent, env, **test_kwargs)
        else:
            agent.train()