import json
import sys

import gym
import numpy as np
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

# ActorCritic, QNet, Memory, ShmemVecEnv, get_action, update_target_model, and the
# module-level globals (lr, batch_size, replay_memory_capacity, update_target,
# initial_exploration, cartpole_test, evalution_mode, need_ui, device) are assumed
# to be defined elsewhere in this repo.


def eval_model(model: ActorCritic, env: ShmemVecEnv, history_size: int,
               emb_size: int, device: str, num_ep=100):
    """Run the policy until `num_ep` episodes finish; return per-episode rewards."""
    model.eval()
    # Frame-history buffer: (num_envs, history_size, C=1, 84, 84).
    # Note: `emb_size` is currently unused in this function.
    obs_emb = torch.zeros(env.num_envs, history_size, 1, 84, 84,
                          device=device, dtype=torch.uint8)
    obs = env.reset().to(device=device)
    obs_emb[:, -1] = obs[:, -1:]
    ep_reward = []
    while True:
        with torch.no_grad():
            a = model(obs, obs_emb)[0].sample().unsqueeze(1)
        obs, rewards, terms, infos = env.step(a)
        obs = obs.to(device=device)
        # Shift the history left by one frame. The `.clone()` avoids the undefined
        # behaviour of copy_() on overlapping views of the same storage.
        obs_emb[:, :-1].copy_(obs_emb[:, 1:].clone())
        # Zero the history of any env whose episode just terminated. The mask must
        # broadcast over (history_size, C, H, W), hence four trailing axes.
        obs_emb *= (~terms)[:, None, None, None, None].to(device=device, dtype=torch.uint8)
        obs_emb[:, -1] = obs[:, -1:]
        for info in infos:
            if 'episode' in info:
                ep_reward.append(info['episode']['r'])
                if len(ep_reward) == num_ep:
                    return torch.tensor(ep_reward)
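# Example usage (sketch; `make_arc_env`, the history/embedding sizes, and the
# number of eval environments are illustrative assumptions, not this repo's API):
#     envs = ShmemVecEnv([make_arc_env(i) for i in range(8)])
#     model = ActorCritic(...).to(device)
#     rewards = eval_model(model, envs, history_size=4, emb_size=256,
#                          device=device, num_ep=100)
#     print('mean eval reward:', rewards.mean().item())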
def main():
    global initial_exploration

    if cartpole_test:
        # CartPole sanity check: three identical environments.
        envs_fun = [lambda: gym.make('CartPole-v0')] * 3
        envs = ShmemVecEnv(envs_fun)
        dummy_env = envs_fun[0]()
    else:
        INPUT_FILE = '../data/05f2a901.json'
        with open(INPUT_FILE, 'r') as f:
            puzzle = json.load(f)
        # Bind `task` as a default argument so each lambda captures its own task
        # instead of the loop variable's final value.
        envs_fun = [
            lambda task=task: gym.make('arc-v0', input=task['input'],
                                       output=task['output'], need_ui=need_ui)
            for task in puzzle['train']
        ]
        envs_fun = envs_fun[0:1]
        envs = ShmemVecEnv(envs_fun)
        dummy_env = envs_fun[0]()

    env_num = len(envs_fun)
    torch.manual_seed(500)

    num_inputs = dummy_env.observation_space.shape[0]
    num_actions = dummy_env.action_space.n
    print('state size:', num_inputs)
    print('action size:', num_actions)

    online_net = QNet(num_inputs, num_actions, cartpole_test, evalution_mode)
    target_net = QNet(num_inputs, num_actions, cartpole_test, evalution_mode)
    if evalution_mode:
        online_net = torch.load('../result/arc0.model')
        target_net = torch.load('../result/arc0.model')
    update_target_model(online_net, target_net)

    optimizer = optim.Adam(online_net.parameters(), lr=lr)
    writer = SummaryWriter('logs')

    online_net.to(device)
    target_net.to(device)
    online_net.train()
    target_net.train()
    memory = Memory(replay_memory_capacity)
    epsilon = 1.0
    steps = 0
    loss = 0

    states = envs.reset()
    try:
        while True:
            if need_ui:
                envs.render()
            steps += 1
            if initial_exploration > 0:
                initial_exploration -= 1

            # Epsilon-greedy action per environment (greedy when evaluating).
            actions = []
            for state in states:
                state = torch.Tensor(state).to(device).unsqueeze(0)
                action = get_action(state, target_net,
                                    0 if evalution_mode else epsilon, dummy_env)
                if evalution_mode:
                    print(action)
                actions.append(action)

            next_states, rewards, dones, info = envs.step(actions)

            # mask = 0 for terminated episodes, 1 otherwise.
            masks = np.where(dones, 0.0, 1.0)

            for i in range(envs.num_envs):
                action_one_hot = np.zeros(dummy_env.action_space.n)
                action_one_hot[actions[i]] = 1
                memory.push(states[i], next_states[i], action_one_hot,
                            rewards[i], masks[i])

            states = next_states

            if not evalution_mode and steps > initial_exploration:
                # Linearly anneal epsilon down to 0.1, then train on one batch.
                epsilon -= 0.00003
                epsilon = max(epsilon, 0.1)

                batch = memory.sample(batch_size)
                loss = QNet.train_model(online_net, target_net, optimizer,
                                        batch, device)

                if steps % update_target == 0:
                    update_target_model(online_net, target_net)

            if steps > 1028:
                states = envs.reset()
                steps = 0
                print('new episode ------------------------------------------')
    except KeyboardInterrupt:
        print('save model')
        torch.save(target_net, '../result/arc.model')
        sys.exit(0)
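# Standard script entry-point guard (assumed; not shown in the original section).
if __name__ == '__main__':
    main()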