def run(args):
    """Evaluate a DQN agent on Breakout for 5 episodes.

    Builds a wrapped ``BreakoutNoFrameskip-v4`` environment (``test=True``),
    binds a fresh ``Agent_DQN`` to it and hands both to ``test``.

    Args:
        args: parsed command-line namespace, forwarded unchanged to
            ``Environment`` and ``Agent_DQN``.
    """
    breakout = Environment(
        'BreakoutNoFrameskip-v4',
        args,
        atari_wrapper=True,
        test=True,
    )

    # Imported lazily, only once the environment exists.
    from agent_dqn import Agent_DQN

    evaluator = Agent_DQN(breakout, args)
    test(evaluator, breakout, total_episodes=5)
def run(args):
    """Train and/or evaluate a DQN agent, selected by command-line flags.

    Args:
        args: parsed command-line namespace. Reads:
            ``train_dqn`` -- when truthy, build a training environment and
            call ``agent.train()``;
            ``test_dqn`` -- when truthy, build an evaluation environment
            (``test=True``) and run ``test`` for 100 episodes;
            ``env_name`` -- optional environment id; when missing, ``None``
            or empty, ``'BreakoutNoFrameskip-v4'`` is used.
            The whole namespace is also forwarded to ``Environment`` and
            ``Agent_DQN``.
    """
    # Resolve the environment id once so that training and evaluation run on
    # the same game. Previously the test branch always used Breakout, even
    # when ``args.env_name`` selected a different environment for training.
    # ``getattr`` keeps namespaces that lack ``env_name`` working.
    env_name = getattr(args, 'env_name', None) or 'BreakoutNoFrameskip-v4'

    if args.train_dqn:
        env = Environment(env_name, args, atari_wrapper=True)
        from agent_dqn import Agent_DQN
        agent = Agent_DQN(env, args)
        agent.train()

    if args.test_dqn:
        env = Environment(env_name, args, atari_wrapper=True, test=True)
        from agent_dqn import Agent_DQN
        agent = Agent_DQN(env, args)
        test(agent, env, total_episodes=100)
def run(args):
    """Train and/or evaluate a DQN agent on a SEVN gym environment.

    Args:
        args: parsed command-line namespace. Reads ``train_dqn`` and
            ``test_dqn`` (which phases to run), ``env_name`` (gym id; falls
            back to ``'SEVN-Test-AllObs-Shaped-v1'`` when falsy) and, when
            training, ``n_heads`` (only printed here). The namespace is also
            forwarded to ``Agent_DQN``.
    """
    fallback_env = 'SEVN-Test-AllObs-Shaped-v1'

    def make_env_and_agent():
        # Build a fresh gym environment and a DQN agent bound to it.
        sevn_env = gym.make(args.env_name or fallback_env)
        from agent_dqn import Agent_DQN
        return sevn_env, Agent_DQN(sevn_env, args)

    if args.train_dqn:
        _, learner = make_env_and_agent()
        # ``device`` is resolved from module scope; it is not defined here.
        print("Device: ", device)
        print("n_heads: ", args.n_heads)
        learner.train()

    if args.test_dqn:
        # Hard-coded run timestamp passed through to ``test``.
        run_stamp = '2020-12-22-00:03:03'
        sevn_env, evaluator = make_env_and_agent()
        test(evaluator, sevn_env, 10, run_stamp)
def run(args):
    """Train and/or evaluate a DQN agent, preferring the GPU for training.

    Args:
        args: parsed command-line namespace. Reads:
            ``train_dqn`` -- when truthy, build a training environment,
            make CUDA float tensors the default if CUDA is available, and
            call ``agent.train()``;
            ``test_dqn`` -- when truthy, build an evaluation environment
            (``test=True``) and run ``test`` for 100 episodes;
            ``env_name`` -- optional environment id; when missing, ``None``
            or empty, ``'BreakoutNoFrameskip-v4'`` is used.
            The whole namespace is also forwarded to ``Environment`` and
            ``Agent_DQN``.
    """
    # Use the same environment id for both phases. The test branch used to
    # hard-code Breakout regardless of ``args.env_name``.
    env_name = getattr(args, 'env_name', None) or 'BreakoutNoFrameskip-v4'

    if args.train_dqn:
        env = Environment(env_name, args, atari_wrapper=True)
        # Setting a CUDA default tensor type on a torch build or machine
        # without CUDA raises, so only switch the default when a GPU is
        # actually usable; otherwise training stays on the CPU.
        if torch.cuda.is_available():
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        from agent_dqn import Agent_DQN
        agent = Agent_DQN(env, args)
        agent.train()

    if args.test_dqn:
        env = Environment(env_name, args, atari_wrapper=True, test=True)
        from agent_dqn import Agent_DQN
        agent = Agent_DQN(env, args)
        test(agent, env, total_episodes=100)
try: from argument import add_arguments parser = add_arguments(parser) except: pass args = parser.parse_args() return args args = parse() env = Environment('BreakoutNoFrameskip-v4', "", atari_wrapper=True, test=False) n = env.action_space state = env.reset() device = torch.device("cpu") input = torch.tensor(state, device=device) agent = Agent_DQN(env, args) dqn = DQN() torch.save(dqn.state_dict(), "checkpoint.pth") state_dict = torch.load("checkpoint.pth") agent.train() # Experience = namedtuple( # 'Experience', # ('state','action','next_state','reward') # ) # e = Experience(state,action,next_state,reward) # print(e) # agent.train() # print("shape:",input.shape) #
import argparse
# from test import test  -- disabled; ``test`` is not used by this script.
from environment import Environment
from agent_dqn import Agent_DQN
from main import parse
from dqn_model import DQN
import torch

if __name__ == '__main__':
    # Manual smoke check: exercise the agent's replay buffer, then push one
    # random batch through a freshly built DQN and report the output size.
    args = parse()
    env_name = 'BreakoutNoFrameskip-v4'
    env = Environment(env_name, args, atari_wrapper=True)
    agent = Agent_DQN(env, args)

    print(agent.buffer)
    # NOTE(review): the original indentation was lost; this assumes the
    # buffer is printed after every push rather than once after the loop.
    for item in range(8):
        agent.push(item)
        print(agent.buffer)

    dqn = DQN()
    print(dqn)
    # One random sample shaped (1, 4, 84, 84); renamed from ``input`` so the
    # builtin is not shadowed.
    dummy_batch = torch.randn(1, 4, 84, 84)
    logits = dqn(dummy_batch)
    print(logits.size())