예제 #1
0
파일: main.py 프로젝트: AJ1897/DQN-Breakout
def run(args):
    """Run DQN training and/or evaluation as selected by the flags on ``args``.

    Args:
        args: Parsed options object. ``train_dqn`` and ``test_dqn`` choose
            which phases execute; ``env_name`` (training only) overrides the
            default environment. The object is also passed through unchanged
            to ``Environment`` and ``Agent_DQN``.

    The two phases are independent: either, both, or neither may run.
    """
    if args.train_dqn:
        # Use Breakout when no environment name was supplied.
        chosen_env_name = args.env_name or 'BreakoutNoFrameskip-v4'
        training_env = Environment(chosen_env_name, args, atari_wrapper=True)
        # Imported here so agent_dqn is only loaded when a phase needs it.
        from agent_dqn import Agent_DQN
        learner = Agent_DQN(training_env, args)
        learner.train()

    if args.test_dqn:
        # Evaluation is fixed to Breakout; args.env_name is not consulted here.
        evaluation_env = Environment(
            'BreakoutNoFrameskip-v4',
            args,
            atari_wrapper=True,
            test=True,
        )
        from agent_dqn import Agent_DQN
        evaluated_agent = Agent_DQN(evaluation_env, args)
        test(evaluated_agent, evaluation_env, total_episodes=100)
예제 #2
0
def run(args):
    """Train and/or evaluate a DQN agent on a gym environment.

    Args:
        args: Parsed options object. ``train_dqn`` / ``test_dqn`` select the
            phases, ``env_name`` optionally overrides the default environment
            id, and ``n_heads`` is printed before training. The object is
            also forwarded to ``Agent_DQN``.
    """
    # Environment id used by both phases when args.env_name is empty.
    fallback_env_name = 'SEVN-Test-AllObs-Shaped-v1'

    if args.train_dqn:
        training_env = gym.make(args.env_name or fallback_env_name)
        # Imported lazily so agent_dqn is only loaded when a phase runs.
        from agent_dqn import Agent_DQN
        learner = Agent_DQN(training_env, args)
        print("Device: ", device)
        print("n_heads: ", args.n_heads)
        learner.train()

    if args.test_dqn:
        evaluation_env = gym.make(args.env_name or fallback_env_name)
        from agent_dqn import Agent_DQN
        evaluated_agent = Agent_DQN(evaluation_env, args)
        # Runs test() with 10 as the third argument and a hard-coded timestamp
        # as the fourth. NOTE(review): presumably the timestamp identifies a
        # previously saved run - confirm against test().
        test(evaluated_agent, evaluation_env, 10, '2020-12-22-00:03:03')
예제 #3
0
def run(args):
    """Evaluate a DQN agent on Breakout for five episodes.

    Args:
        args: Parsed options object, forwarded unchanged to ``Environment``
            and ``Agent_DQN``.
    """
    breakout_env = Environment(
        'BreakoutNoFrameskip-v4', args, atari_wrapper=True, test=True
    )
    # Imported at call time, after the environment has been constructed.
    from agent_dqn import Agent_DQN
    evaluated_agent = Agent_DQN(breakout_env, args)
    test(evaluated_agent, breakout_env, total_episodes=5)
예제 #4
0
def run(args):
    """Run DQN training and/or evaluation as selected by the flags on ``args``.

    Args:
        args: Parsed options object. ``train_dqn`` and ``test_dqn`` choose
            which phases execute; ``env_name`` (training only) overrides the
            default environment. The object is also passed through unchanged
            to ``Environment`` and ``Agent_DQN``.

    The two phases are independent: either, both, or neither may run.
    """
    if args.train_dqn:
        # Use Breakout when no environment name was supplied.
        env_name = args.env_name or 'BreakoutNoFrameskip-v4'
        env = Environment(env_name, args, atari_wrapper=True)
        # Only switch the default tensor type to CUDA when a CUDA device is
        # actually usable; otherwise keep the CPU default so training can
        # still start on CPU-only hosts instead of failing.
        if torch.cuda.is_available():
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        # Imported here so agent_dqn is only loaded when a phase needs it.
        from agent_dqn import Agent_DQN
        agent = Agent_DQN(env, args)
        agent.train()

    if args.test_dqn:
        # Evaluation is fixed to Breakout; args.env_name is not consulted here.
        env = Environment('BreakoutNoFrameskip-v4',
                          args,
                          atari_wrapper=True,
                          test=True)
        from agent_dqn import Agent_DQN
        agent = Agent_DQN(env, args)
        test(agent, env, total_episodes=100)
예제 #5
0
파일: Playground.py 프로젝트: bucky1995/P3
    try:
        from argument import add_arguments
        parser = add_arguments(parser)
    except:
        pass
    args = parser.parse_args()
    return args


# Scratch script: builds a Breakout environment and a DQN agent, round-trips a
# newly constructed DQN's state dict through "checkpoint.pth", then trains.
args = parse()
# NOTE(review): the second positional argument is an empty string rather than
# the parsed `args` - confirm Environment accepts that.
env = Environment('BreakoutNoFrameskip-v4', "", atari_wrapper=True, test=False)
n = env.action_space  # not used again in the visible part of this script
state = env.reset()
device = torch.device("cpu")
# NOTE(review): `input` shadows the builtin of the same name; it is only
# referenced by commented-out code in the visible part of this script.
input = torch.tensor(state, device=device)
agent = Agent_DQN(env, args)
dqn = DQN()
# Write the new network's state dict to disk and immediately read it back.
torch.save(dqn.state_dict(), "checkpoint.pth")
# The loaded dict is not applied to any model in the visible code.
state_dict = torch.load("checkpoint.pth")

# Runs at import time: this module has no `if __name__ == "__main__"` guard.
agent.train()
# Experience = namedtuple(
#             'Experience',
#             ('state','action','next_state','reward')
#         )
# e = Experience(state,action,next_state,reward)
# print(e)
# agent.train()

# print("shape:",input.shape)
#