def test(rank, args, shared_model):
    """Run the shared model greedily in its own environment and log each episode.

    The loop never returns. Whenever an episode has just ended (and on the
    very first iteration) the local model is re-synchronised from
    ``shared_model`` and the shared weights are written to
    ``model-a3c-aux/model-<args.model_name>.pth``.

    Actions are chosen as the argmax of the policy softmax. An action index
    below ``model.n_real_acts`` is passed to ``env.step`` directly; any other
    index is treated as an auxiliary "skip" action that repeats
    ``env.step(0)`` for ``model.get_skip(action)`` steps, stopping early when
    the episode ends or ``is_dead(info)`` is truthy.

    Args:
        rank: Integer offset added to ``args.seed`` for the torch and
            environment seeds.
        args: Settings object; this function reads ``seed``, ``env_name``,
            ``num_skips``, ``model_name``, ``testing``,
            ``max_episode_length`` and ``num_processes``.
        shared_model: Model whose ``state_dict()`` is loaded into the local
            copy and saved to disk.
    """
    torch.manual_seed(args.seed + rank)

    env = create_atari_env(args.env_name)
    env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], env.action_space,
                        args.num_skips)
    model.eval()

    # The network input is four copies of the first observation joined along
    # axis 0; later observations are rolled in along the same axis.
    first_obs = env.reset()
    state = torch.from_numpy(np.concatenate([first_obs] * 4, axis=0))

    reward_sum = 0
    episode_length = 0
    done = True  # forces a weight sync + checkpoint on the first iteration
    action_stat = [0] * (model.n_real_acts + model.n_aux_acts)
    start_time = time.time()

    for ep_counter in itertools.count(1):
        # Sync with the shared model
        if done:
            model.load_state_dict(shared_model.state_dict())

            if not os.path.exists('model-a3c-aux'):
                os.makedirs('model-a3c-aux')
            torch.save(shared_model.state_dict(),
                       'model-a3c-aux/model-{}.pth'.format(args.model_name))
            print('saved model')

        value, logit = model(Variable(state.unsqueeze(0), volatile=True))
        prob = F.softmax(logit)
        # Greedy policy: index of the largest probability.
        action = prob.max(1)[1].data.numpy()
        action_np = action[0, 0]
        action_stat[action_np] += 1

        if action_np < model.n_real_acts:
            state_new, reward, done, info = env.step(action_np)
            # NOTE(review): the result is not read in this branch.
            dead = is_dead(info)

            if args.testing:
                print('episode', episode_length, 'normal action', action_np,
                      'lives', info['ale.lives'])
                env.render()

            # Drop the oldest slice along axis 0 and append the new one.
            state = np.append(state.numpy()[1:, :, :], state_new, axis=0)
            # The length cap is checked against the pre-increment counter.
            done = done or episode_length >= args.max_episode_length

            reward_sum += reward
            episode_length += 1
        else:
            state = state.numpy()

            for _ in range(model.get_skip(action_np)):
                state_new, rew, done, info = env.step(
                    0)  # instead of random perform NOOP=0
                dead = is_dead(info)

                if args.testing:
                    print('episode', episode_length, 'random action',
                          action_np, 'lives', info['ale.lives'])
                    env.render()

                state = np.append(state[1:, :, :], state_new, axis=0)
                done = done or episode_length >= args.max_episode_length

                reward_sum += rew
                episode_length += 1
                if done or dead:
                    break

        if done:
            # Format elapsed time by hand: going through time.gmtime() would
            # wrap the hour field back to 00 after 24 hours of running.
            elapsed = int(time.time() - start_time)
            hours, remainder = divmod(elapsed, 3600)
            minutes, seconds = divmod(remainder, 60)
            print("Time {}, episode reward {}, episode length {}".format(
                "{:02d}h {:02d}m {:02d}s".format(hours, minutes, seconds),
                reward_sum, episode_length))
            print("actions stats real {}, aux {}".format(
                action_stat[:model.n_real_acts],
                action_stat[model.n_real_acts:]))

            reward_sum = 0
            episode_length = 0
            state = env.reset()
            # Re-seed after the reset, with a seed that varies per loop
            # iteration.
            env.seed(args.seed + rank + (args.num_processes + 1) * ep_counter)
            state = np.concatenate([state] * 4, axis=0)
            action_stat = [0] * (model.n_real_acts + model.n_aux_acts)
            if not args.testing:
                # Pause for a minute between evaluation episodes.
                time.sleep(60)

        state = torch.from_numpy(state)
def train(rank, args, shared_model, optimizer=None):
    """Run one actor-critic training worker against ``shared_model``, forever.

    Each outer iteration copies the shared weights into a local model,
    collects a rollout of at most ``args.num_steps`` sampled actions, builds
    the policy and value losses from discounted returns, back-propagates
    through the local model, and hands the gradients to the shared model via
    ``ensure_shared_grads`` before stepping the optimizer.

    An action index below ``model.n_real_acts`` is passed to ``env.step``
    directly; any other index is an auxiliary "skip" action that repeats
    ``env.step(0)`` for ``model.get_skip(action)`` steps and sums the clipped
    rewards. A rollout stops early when the episode ends or
    ``is_dead(info)`` is truthy; in both cases the return is bootstrapped
    from zero instead of from the value estimate.

    Args:
        rank: Integer offset added to ``args.seed`` for the torch and
            environment seeds.
        args: Settings object; this function reads ``seed``, ``env_name``,
            ``num_skips``, ``lr``, ``num_steps``, ``max_episode_length``,
            ``num_processes`` and ``gamma``.
        shared_model: Model whose parameters are optimised and whose
            ``state_dict()`` is loaded at the start of every rollout.
        optimizer: Optimizer to step. When ``None``, an ``optim.Adam`` over
            ``shared_model.parameters()`` with ``lr=args.lr`` is created.
    """
    torch.manual_seed(args.seed + rank)

    env = create_atari_env(args.env_name)
    env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], env.action_space,
                        args.num_skips)

    if optimizer is None:
        optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)

    model.train()

    # The network input is four copies of the first observation joined along
    # axis 0; later observations are rolled in along the same axis.
    state = env.reset()
    state = np.concatenate([state] * 4, axis=0)
    state = torch.from_numpy(state)

    done = True
    # Bound up front so the flag is defined no matter which action branch
    # runs first (the skip loop below may execute zero iterations).
    dead = False
    episode_length = 0

    for ep_counter in itertools.count(1):
        # Sync with the shared model
        model.load_state_dict(shared_model.state_dict())

        values = []
        log_probs = []
        rewards = []
        entropies = []

        for _ in range(args.num_steps):
            value, logit = model(Variable(state.unsqueeze(0)))
            prob = F.softmax(logit)
            log_prob = F.log_softmax(logit)
            entropy = -(log_prob * prob).sum(1)
            entropies.append(entropy)

            # Sample an action and keep the log-probability of that choice.
            action = prob.multinomial().data
            log_prob = log_prob.gather(1, Variable(action))

            action_np = action.numpy()[0][0]
            if action_np < model.n_real_acts:
                state_new, reward, done, info = env.step(action_np)
                dead = is_dead(info)
                # Drop the oldest slice along axis 0 and append the new one.
                state = np.append(state.numpy()[1:, :, :], state_new, axis=0)
                # The length cap is checked against the pre-increment counter.
                done = done or episode_length >= args.max_episode_length

                reward = max(min(reward, 1), -1)  # clip to [-1, 1]
                episode_length += 1
            else:
                state = state.numpy()
                reward = 0.
                for _ in range(model.get_skip(action_np)):
                    state_new, rew, done, info = env.step(
                        0)  # instead of random perform NOOP=0
                    dead = is_dead(info)
                    state = np.append(state[1:, :, :], state_new, axis=0)
                    done = done or episode_length >= args.max_episode_length

                    rew = max(min(rew, 1), -1)  # clip each step to [-1, 1]
                    reward += rew
                    episode_length += 1
                    if done or dead:
                        break

            if done:
                episode_length = 0
                state = env.reset()
                # Re-seed after the reset, with a seed that varies per
                # rollout.
                env.seed(args.seed + rank
                         + (args.num_processes + 1) * ep_counter)
                state = np.concatenate([state] * 4, axis=0)
            elif dead:
                # No env reset here: refill the whole stack with the latest
                # observation instead.
                state = np.concatenate([state_new] * 4, axis=0)

            state = torch.from_numpy(state)
            values.append(value)
            log_probs.append(log_prob)
            rewards.append(reward)

            if done or dead:
                break

        # Bootstrap from the value estimate only when the rollout was cut
        # short by the step limit rather than by done/dead.
        R = torch.zeros(1, 1)
        if not done and not dead:
            value, _ = model(Variable(state.unsqueeze(0)))
            R = value.data
        R = Variable(R)

        policy_loss = 0
        value_loss = 0
        # Walk the rollout backwards, accumulating the discounted return.
        rollout = list(zip(rewards, values, log_probs, entropies))
        for step_reward, step_value, step_log_prob, step_entropy in reversed(
                rollout):
            R = args.gamma * R + step_reward
            advantage = R - step_value
            value_loss = value_loss + 0.5 * advantage.pow(2)

            # The advantage is re-wrapped from .data so the policy term does
            # not back-propagate into the value estimate.
            policy_loss = policy_loss - \
                step_log_prob * Variable(advantage.data) - 0.01 * step_entropy

        optimizer.zero_grad()

        (policy_loss + 0.5 * value_loss).backward()
        torch.nn.utils.clip_grad_norm(model.parameters(), 40.)

        ensure_shared_grads(model, shared_model)
        optimizer.step()