def test(rank, args, shared_model, counter):
    log = Log('a3c_baselines_testing')

    env = gym.make(args.env_name)
    env.seed(args.seed + rank)
    torch.manual_seed(args.seed + rank)

    model = Policy(2, action_map)
    model.eval()

    state = env.reset()
    reward_sum = 0
    done = True

    start_time = time.time()

    # a quick hack to prevent the agent from getting stuck
    # actions = deque(maxlen=100)
    episode_length = 0
    while True:
        episode_length += 1
        env.render()

        # Sync with the shared model at the start of every episode
        if done:
            model.load_state_dict(shared_model.state_dict())
            cx = torch.zeros(1, 64)
            hx = torch.zeros(1, 64)

        # The test worker never backpropagates, so skip building the autograd graph
        with torch.no_grad():
            action, hx, cx = model(state, hx, cx)
        state, reward, done, _ = env.step(action)
        reward_sum += reward

        if done:
            log_string = "Time {}, num steps {}, FPS {:.0f}, episode reward {}, episode length {}".format(
                time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - start_time)),
                counter.value, counter.value / (time.time() - start_time),
                reward_sum, episode_length)
            # print(log_string)
            log.log(log_string)

            reward_sum = 0
            episode_length = 0
            # actions.clear()
            state = env.reset()
            time.sleep(5)
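# Illustrative sketch only: the loops in test() and train() assume a recurrent
# Policy whose forward pass takes (state, hx, cx), returns (action, hx, cx),
# and caches .v, .log_prob and .entropy for the A3C loss. The real Policy is
# defined elsewhere in the project; the class name below, the layer sizes and
# the flat 2-dimensional observation are assumptions made for illustration,
# not the actual architecture.
class PolicySketch(torch.nn.Module):
    def __init__(self, num_inputs, action_map):
        super().__init__()
        self.action_map = action_map                       # index -> env action (assumed indexable)
        self.fc = torch.nn.Linear(num_inputs, 64)
        self.lstm = torch.nn.LSTMCell(64, 64)              # hidden size 64 matches hx/cx above
        self.actor = torch.nn.Linear(64, len(action_map))
        self.critic = torch.nn.Linear(64, 1)

    def forward(self, state, hx, cx):
        x = torch.as_tensor(state, dtype=torch.float32).view(1, -1)
        x = torch.relu(self.fc(x))
        hx, cx = self.lstm(x, (hx, cx))
        logits = self.actor(hx)
        probs = torch.softmax(logits, dim=-1)
        log_probs = torch.log_softmax(logits, dim=-1)
        idx = probs.multinomial(num_samples=1)             # sampled action index
        self.v = self.critic(hx)                           # state value V(s)
        self.log_prob = log_probs.gather(1, idx)           # log pi(a|s) of the sampled action
        self.entropy = -(log_probs * probs).sum(1, keepdim=True)
        return self.action_map[idx.item()], hx, cx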
def train(rank, args, shared_model, optimizer, counter, lock):
    env = gym.make(args.env_name)
    env.seed(args.seed + rank)
    torch.manual_seed(args.seed + rank)

    model = Policy(2, action_map)
    model.train()

    state = env.reset()
    # state = tensor_state(state)
    done = True

    episode_length = 0
    while True:
        # Sync with the shared model
        model.load_state_dict(shared_model.state_dict())
        if done:
            cx = torch.zeros(1, 64)
            hx = torch.zeros(1, 64)
        else:
            # Detach the LSTM state so the graph does not span rollouts
            cx = cx.data
            hx = hx.data

        values = []
        log_probs = []
        rewards = []
        entropies = []

        # Collect a rollout of up to num_steps transitions
        for step in range(args.num_steps):
            episode_length += 1
            action, hx, cx = model(state, hx, cx)
            entropies.append(model.entropy)

            state, reward, done, _ = env.step(action)
            reward = max(min(reward, 1), -1)  # clip reward to [-1, 1]

            with lock:
                counter.value += 1

            if done:
                episode_length = 0
                state = env.reset()

            values.append(model.v)
            log_probs.append(model.log_prob)
            rewards.append(reward)

            if done:
                break

        # Bootstrap the return with the value of the last state
        R = torch.zeros(1, 1)
        if not done:
            model(state, hx, cx)
            R = model.v.data
        values.append(R)

        policy_loss = 0
        value_loss = 0
        gae = torch.zeros(1, 1)
        for i in reversed(range(len(rewards))):
            R = args.gamma * R + rewards[i]
            advantage = R - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)

            # Generalized Advantage Estimation
            delta_t = rewards[i] + args.gamma * values[i + 1].data - values[i].data
            gae = gae * args.gamma * args.tau + delta_t

            policy_loss = policy_loss - log_probs[i] * gae - args.entropy_coef * entropies[i]

        loss = policy_loss + args.value_loss_coef * value_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

        # Copy local gradients into the shared model before stepping the shared optimizer
        ensure_shared_grads(model, shared_model)
        optimizer.step()
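# ensure_shared_grads() is called in train() but not defined in this section.
# A common implementation (as in ikostrikov's pytorch-a3c) binds the worker's
# local gradient tensors to the shared model's parameters so the shared
# optimizer can step on them; this is an assumption about the helper used here.
def ensure_shared_grads(model, shared_model):
    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
        if shared_param.grad is not None:
            # Already bound on an earlier iteration; backward() updates the
            # local grad tensors in place, so nothing more to do.
            return
        shared_param._grad = param.grad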