def inference(flags, inference_batcher, model, lock=threading.Lock()):
    with torch.no_grad():
        for batch in inference_batcher:
            batched_env_outputs, agent_state = batch.get_inputs()
            obs, _, done, *_ = batched_env_outputs
            obs, done, agent_state = nest.map(
                lambda t: t.to(flags.actor_device, non_blocking=True),
                [obs, done, agent_state],
            )
            with lock:
                outputs = model(obs, done, agent_state)
            outputs = nest.map(lambda t: t.cpu(), outputs)
            batch.set_outputs(outputs)
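
# A minimal sketch (not from this codebase) of how several inference threads
# could be launched against the shared model. Because `lock=threading.Lock()`
# is evaluated once at definition time, every thread that relies on the
# default argument shares a single lock, which serializes the forward passes.
# `num_inference_threads` is a hypothetical count; `flags`,
# `inference_batcher` and `model` are assumed to be set up as above, and
# `threading` to be imported as in the signature.
num_inference_threads = 2  # hypothetical
inference_threads = [
    threading.Thread(
        target=inference,
        name="inference-%d" % i,
        args=(flags, inference_batcher, model),
    )
    for i in range(num_inference_threads)
]
for t in inference_threads:
    t.start()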

def inference(inference_batcher, lock=threading.Lock()):
    nonlocal step
    for batch in inference_batcher:
        batched_env_outputs, agent_state = batch.get_inputs()
        obs, _, done, *_ = batched_env_outputs
        B = done.shape[1]
        with lock:
            step += B
        # Sample one random action per environment in the batch.
        actions = [action_space.sample() for _ in range(B)]
        action = torch.from_numpy(np.concatenate(actions)).view(1, B, -1)
        # Match the (agent_outputs, agent_state) structure the batcher expects.
        outputs = ((action,), ())
        batch.set_outputs(outputs)

def forward(self, inputs, core_state):
    x = inputs["frame"]
    T, B, *_ = x.shape
    x = torch.flatten(x, 0, 1)  # Merge time and batch.
    x = x.float() / 255.0

    res_input = None
    for i, fconv in enumerate(self.feat_convs):
        x = fconv(x)
        res_input = x
        x = self.resnet1[i](x)
        x += res_input
        res_input = x
        x = self.resnet2[i](x)
        x += res_input

    x = F.relu(x)
    x = x.view(T * B, -1)
    x = F.relu(self.fc(x))

    clipped_reward = torch.clamp(inputs["reward"], -1, 1).view(T * B, 1)
    core_input = torch.cat([x, clipped_reward], dim=-1)

    if self.use_lstm:
        core_input = core_input.view(T, B, -1)
        core_output_list = []
        notdone = (~inputs["done"]).float()
        for input, nd in zip(core_input.unbind(), notdone.unbind()):
            # Reset core state to zero whenever an episode ended.
            # Make `done` broadcastable with (num_layers, B, hidden_size)
            # states:
            nd = nd.view(1, -1, 1)
            core_state = nest.map(nd.mul, core_state)
            output, core_state = self.core(input.unsqueeze(0), core_state)
            core_output_list.append(output)
        core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
    else:
        core_output = core_input

    policy_logits = self.policy(core_output)
    baseline = self.baseline(core_output)

    if self.training:
        action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
    else:
        # Don't sample when testing.
        action = torch.argmax(policy_logits, dim=1)

    policy_logits = policy_logits.view(T, B, self.num_actions)
    baseline = baseline.view(T, B)
    action = action.view(T, B)

    return (action, policy_logits, baseline), core_state
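
# Standalone illustration (toy shapes, not part of the model) of the
# episode-boundary reset used in the LSTM loop above: multiplying the
# (num_layers, B, hidden_size) states by a (1, B, 1) "not done" mask zeroes
# the state exactly for the environments whose episode just ended, while
# leaving the others untouched.
import torch

num_layers, B, hidden_size = 1, 3, 4
h = torch.ones(num_layers, B, hidden_size)
c = torch.ones(num_layers, B, hidden_size)
done = torch.tensor([False, True, False])  # env 1 just finished an episode

nd = (~done).float().view(1, -1, 1)
h, c = nd.mul(h), nd.mul(c)  # what `nest.map(nd.mul, core_state)` does
assert h[:, 1].eq(0).all() and h[:, 0].eq(1).all()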

def test_nest_map(self):
    t1 = torch.tensor(0)
    t2 = torch.tensor(1)
    d = {"hey": t2}

    n = nest.map(lambda t: t + 42, (t1, t2))

    self.assertSequenceEqual(n, [t1 + 42, t2 + 42])
    self.assertSequenceEqual(n, nest.flatten(n))

    n1 = (d, n, t1)
    n2 = nest.map(lambda t: t * 2, n1)

    self.assertEqual(n2[0], {"hey": torch.tensor(2)})
    self.assertEqual(n2[1], (torch.tensor(84), torch.tensor(86)))
    self.assertEqual(n2[2], torch.tensor(0))

    t = torch.tensor(42)
    # Doesn't work with pybind11/functional.h, but does with py::function.
    self.assertEqual(nest.map(t.add, t2), torch.tensor(43))

def tca_reward_function(flags, obs, new_frame, D):
    frame = obs["canvas"][:-1]
    frame, new_frame = nest.map(lambda t: torch.flatten(t, 0, 1), (frame, new_frame))
    with torch.no_grad():
        # Reward each step by the change in discriminator score between the
        # canvas before and after the action.
        reward = torch.zeros(flags.unroll_length + 1, flags.batch_size,
                             device=flags.learner_device)
        reward[1:] += (D(new_frame) - D(frame)).view(-1, flags.batch_size)
    return reward
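
# Illustrative check of the telescoping structure (assuming `new_frame` holds
# the canvases that directly follow `frame`): each step's reward is the change
# in discriminator score, D(c_t) - D(c_{t-1}), so the undiscounted sum over an
# episode collapses to D(c_final) - D(c_first). For example, with per-canvas
# scores [0.1, 0.4, 0.3, 0.8] the step rewards are [0.3, -0.1, 0.5], summing
# to 0.8 - 0.1 = 0.7, the same total the non-TCA variant pays out in a single
# end-of-episode reward.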

def forward(self, input, core_state):
    T, B, *_ = input["obs"].shape
    grid = self.grid.repeat(T * B, 1, 1, 1)

    notdone = (~input["done"]).float()
    # Zero out the previous action at episode boundaries.
    action = torch.flatten(input["action"] * notdone.unsqueeze(dim=2), 0, 1)
    obs = torch.flatten(input["obs"].float(), 0, 1)
    noise = torch.flatten(input["noise"], 0, 1)

    spatial = self.conv5x5(torch.cat([obs, grid], dim=1))
    noise_embedding = self.fc(noise).view(-1, 32, 1, 1)
    mlp = self.action_fc(action)

    embedding = self.relu(spatial + noise_embedding + mlp)
    h = self.conv(embedding)
    h = self.resblock(h)
    h = self.flatten_fc(h)
    h = self.relu(h)

    core_input = h.view(T, B, 256)
    core_output_list = []
    for input, nd in zip(core_input.unbind(), notdone.unbind()):
        # Reset the LSTM state to zero whenever an episode ended.
        nd = nd.view(1, -1, 1)
        core_state = nest.map(nd.mul, core_state)
        output, core_state = self.lstm(input.unsqueeze(0), core_state)
        core_output_list.append(output)
    core_output = torch.flatten(torch.cat(core_output_list), 0, 1)

    action, policy_logits = self.policy(core_output, action)
    baseline = self.baseline(core_output)

    action = action.view(T, B, self._num_actions)
    baseline = baseline.view(T, B)
    policy_logits = nest.map(lambda t: t.view(T, B, -1), policy_logits)
    return (action, policy_logits, baseline), core_state

def forward(self, obs, done, core_state):
    T, B, C, H, W = obs["canvas"].shape
    grid = self._grid(T * B, H, W)

    notdone = (~done).float()
    obs["prev_action"] = obs["prev_action"] * notdone.unsqueeze(dim=2)
    obs = nest.map(lambda t: torch.flatten(t, 0, 1), obs)
    canvas, action_mask, action, noise = (
        obs[k] for k in ["canvas", "action_mask", "prev_action", "noise_sample"])

    features = self.obs(torch.cat([canvas, grid], dim=1))
    condition = (self.noise(noise) +
                 self.action(self.mask_mlp(action, action_mask))).view(-1, 32, 1, 1)
    embedding = self.base(self.relu(features + condition)).view(T, B, 256)

    core_output_list = []
    for core_input, nd in zip(embedding.unbind(), notdone.unbind()):
        nd = nd.view(1, -1, 1)
        core_state = nest.map(nd.mul, core_state)
        output, core_state = self.lstm(core_input.unsqueeze(0), core_state)
        core_output_list.append(output)
    seed = torch.flatten(torch.cat(core_output_list), 0, 1)

    action, logits = self.policy(seed, action)
    baseline = self.baseline(seed)

    action = action.view(T, B, self._num_actions)
    baseline = baseline.view(T, B)
    logits = nest.map(lambda t: t.view(T, B, -1), logits)
    return (action, logits, baseline), core_state

def __init__(self, timestep, unroll_length, num_actors, num_overlapping_steps=0):
    self._full_length = num_overlapping_steps + unroll_length + 1
    self._num_overlapping_steps = num_overlapping_steps

    N = num_actors
    L = self._full_length
    # One (N, L, *leaf_shape) buffer per tensor leaf in `timestep`, plus a
    # per-actor write index.
    self._state = nest.map(
        lambda t: torch.zeros((N, L) + t.shape, dtype=t.dtype), timestep)
    self._index = torch.zeros([N], dtype=torch.int64)

def _complete_unrolls(self, actor_ids):
    """Obtain unrolls that have reached the desired length."""
    actor_indices = self._index[actor_ids]
    actor_ids = actor_ids[actor_indices == self._full_length]
    unrolls = nest.map(lambda s: s[actor_ids], self._state)
    # Reset state of completed actors to start from the end of the previous
    # ones (NB: since `unrolls` is a copy it is ok to do it in place).
    j = self._num_overlapping_steps + 1
    for s in nest.flatten(self._state):
        s[actor_ids, :j] = s[actor_ids, -j:]
    self._index.scatter_(0, actor_ids, 1 + self._num_overlapping_steps)
    return actor_ids, unrolls
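
# Worked example of the index arithmetic above (illustrative values): with
# unroll_length=20 and num_overlapping_steps=1, each actor's row holds
# _full_length = 1 + 20 + 1 = 22 timesteps (the extra `+ 1` presumably being
# the step kept for bootstrapping). Once a row is full, its last j = 2 entries
# are copied back to positions [0, 2) and the actor's write index rewinds to
# 2, so consecutive unrolls from the same actor share one overlapping step.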

def compute_metrics(
    flags,
    learner_outputs,
    actor_outputs,
    env_outputs,
    last_actions,
    reward_stats=None,
    end_of_episode_bootstrap=False,
):
    """
    Compute various metrics (including in particular the loss being optimized).

    :param learner_outputs: A dictionary holding the following tensors, where
        `T` is the unroll length and `N` the batch size (number of actors):
            * "action": (T + 1, N)
            * "baseline": (T + 1, N)
            * "policy_logits": (T + 1, N, num_actions)
    :param actor_outputs: Similar to `learner_outputs`.
    :param env_outputs: A triplet of observation (T + 1, N, obs_dim),
        reward (T + 1, N) and done (T + 1, N) tensors.
    :param last_actions: Not used.
    :param reward_stats: Must be provided when reward normalization is
        enabled. This dictionary holds statistics on the observed rewards.
    """
    del last_actions  # Only used in model.

    # Estimated value of the last state in the rollout (N).
    bootstrap_value = learner_outputs["baseline"][-1]

    # Move from obs[t] -> action[t] to action[t] -> obs[t].
    # After this step all tensors have shape (T, N, ...).
    actor_outputs = nest.map(lambda t: t[:-1], actor_outputs)
    rewards, done = nest.map(lambda t: t[1:], env_outputs[1:])
    learner_outputs = nest.map(lambda t: t[:-1], learner_outputs)

    if flags.reward_clipping == "abs_one":
        rewards = torch.clamp(rewards, -1, 1)
    elif flags.reward_clipping == "soft_asymmetric":
        squeezed = torch.tanh(rewards / 5.0)
        # Negative rewards are given less weight than positive rewards.
        rewards = torch.where(rewards < 0, 0.3 * squeezed, squeezed) * 5.0
    elif flags.reward_clipping != "none":
        raise NotImplementedError(flags.reward_clipping)

    if flags.reward_normalization:
        train_job_id = env_outputs[0][1:][:, :, -1].long()
        normalize_rewards(flags, train_job_id, rewards, reward_stats)

    discounts = (~done).float() * flags.discounting

    actor_logits = actor_outputs["policy_logits"]  # (T, N, num_actions)
    learner_logits = learner_outputs["policy_logits"]  # (T, N, num_actions)
    actions = actor_outputs["action"]  # (T, N)
    baseline = learner_outputs["baseline"]  # (T, N)

    vtrace_returns = vtrace.from_logits(
        behavior_policy_logits=actor_logits,
        target_policy_logits=learner_logits,
        actions=actions,
        discounts=discounts,
        rewards=rewards,
        values=baseline,
        bootstrap_value=bootstrap_value,
        end_of_episode_bootstrap=end_of_episode_bootstrap,
        done=done,
    )

    pg_loss = compute_policy_gradient_loss(
        learner_logits, actions, vtrace_returns.pg_advantages)
    baseline_loss = compute_baseline_loss(
        # NB: if `end_of_episode_bootstrap` is True, then `vs` at end of episode
        # is equal to the baseline, contributing to a zero cost. This is what we
        # want, as we cannot learn from that step without knowing the next state.
        vtrace_returns.vs - baseline)
    entropy_loss = compute_entropy_loss(learner_logits)

    # Total entropy if the learner were outputting a uniform distribution over
    # actions (used for normalization in the reported metrics).
    batch_total_size = learner_logits.shape[0] * learner_logits.shape[1]
    uniform_entropy = batch_total_size * np.log(flags.num_actions)

    return {
        "loss/total": pg_loss
        + flags.baseline_cost * baseline_loss
        + flags.entropy_cost * entropy_loss,
        "loss/pg": pg_loss,
        "loss/baseline": baseline_loss,
        "loss/normalized_neg_entropy": entropy_loss / uniform_entropy,
        "critic/baseline/mean": torch.mean(baseline),
        "critic/baseline/min": torch.min(baseline),
        "critic/baseline/max": torch.max(baseline),
    }
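
# Quick numeric check of the "soft_asymmetric" clipping branch above
# (illustrative only): rewards are squashed by 5 * tanh(r / 5), with the
# negative side additionally scaled by 0.3.
import torch

r = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])
squeezed = torch.tanh(r / 5.0)
clipped = torch.where(r < 0, 0.3 * squeezed, squeezed) * 5.0
# clipped ~= [-1.45, -0.30, 0.00, 0.99, 4.82]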

def learn(
    flags,
    learner_queue,
    model,
    actor_model,
    D,
    optimizer,
    scheduler,
    stats,
    plogger,
    lock=threading.Lock(),
):
    for tensors in learner_queue:
        new_obs = tensors[1]
        tensors = edit_tuple(tensors, 1, new_obs["canvas"])
        tensors = nest.map(
            lambda t: t.to(flags.learner_device, non_blocking=True), tensors)

        batch, new_frame, initial_agent_state = tensors
        env_outputs, actor_outputs = batch
        obs, reward, done, step, _ = env_outputs

        lock.acquire()  # Only one thread learning at a time.

        # Add the discriminator reward to the environment reward. With TCA,
        # every step is rewarded; without it, a reward is only added once an
        # episode has ended. Initializing to zeros keeps `discriminator_reward`
        # defined for the stats below when neither branch fires.
        discriminator_reward = torch.zeros_like(reward)
        if flags.use_tca:
            discriminator_reward = tca_reward_function(flags, obs, new_frame, D)
        elif done.any().item():
            discriminator_reward = reward_function(flags, done, new_frame, D)
        env_outputs = edit_tuple(env_outputs, 1, reward + discriminator_reward)
        batch = edit_tuple(batch, 0, env_outputs)

        optimizer.zero_grad()

        actor_outputs = AgentOutput._make(actor_outputs)
        learner_outputs, agent_state = model(obs, done, initial_agent_state)

        # Take final value function slice for bootstrapping.
        learner_outputs = AgentOutput._make(learner_outputs)
        bootstrap_value = learner_outputs.baseline[-1]

        # Move from obs[t] -> action[t] to action[t] -> obs[t].
        batch = nest.map(lambda t: t[1:], batch)
        learner_outputs = nest.map(lambda t: t[:-1], learner_outputs)

        # Turn into namedtuples again.
        env_outputs, actor_outputs = batch
        env_outputs = EnvOutput._make(env_outputs)
        actor_outputs = AgentOutput._make(actor_outputs)
        learner_outputs = AgentOutput._make(learner_outputs)

        discounts = (~env_outputs.done).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=actor_outputs.policy_logits,
            target_policy_logits=learner_outputs.policy_logits,
            actions=actor_outputs.action,
            discounts=discounts,
            rewards=env_outputs.reward,
            values=learner_outputs.baseline,
            bootstrap_value=bootstrap_value,
        )
        vtrace_returns = vtrace.VTraceFromLogitsReturns._make(vtrace_returns)

        pg_loss = compute_policy_gradient_loss(
            learner_outputs.policy_logits,
            actor_outputs.action,
            vtrace_returns.pg_advantages,
        )
        baseline_loss = flags.baseline_cost * compute_baseline_loss(
            vtrace_returns.vs - learner_outputs.baseline)
        entropy_loss = flags.entropy_cost * compute_entropy_loss(
            learner_outputs.policy_logits)
        total_loss = pg_loss + baseline_loss + entropy_loss

        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping)
        optimizer.step()
        scheduler.step()

        actor_model.load_state_dict(model.state_dict())

        episode_returns = env_outputs.episode_return[env_outputs.done]
        stats["step"] = stats.get("step", 0) + flags.unroll_length * flags.batch_size
        stats["episode_returns"] = tuple(episode_returns.cpu().numpy())
        stats["mean_environment_return"] = episode_returns.mean().item()
        stats["mean_discriminator_return"] = discriminator_reward.mean().item()
        stats["mean_episode_return"] = (stats["mean_environment_return"]
                                        + stats["mean_discriminator_return"])
        stats["total_loss"] = total_loss.item()
        stats["pg_loss"] = pg_loss.item()
        stats["baseline_loss"] = baseline_loss.item()
        stats["entropy_loss"] = entropy_loss.item()
        stats["learner_queue_size"] = learner_queue.size()

        if flags.condition and new_frame.nelement() != 0:
            stats["l2_loss"] = F.mse_loss(
                *new_frame.split(split_size=new_frame.shape[1] // 2, dim=1)).item()

        plogger.log(stats)
        lock.release()