class QMixTorchPolicy(Policy):
    """QMix impl. Assumes homogeneous agents for now.

    You must use MultiAgentEnv.with_agent_groups() to group agents
    together for QMix. This creates the proper Tuple obs/action spaces and
    populates the '_group_rewards' info field.

    Action masking: to specify an action mask for individual agents, use a
    dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}.
    The mask space must be `Box(0, 1, (n_actions,))`.
    """

    def __init__(self, obs_space, action_space, config):
        _validate(obs_space, action_space)
        config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
        self.framework = "torch"
        super().__init__(obs_space, action_space, config)
        self.n_agents = len(obs_space.original_space.spaces)
        self.n_actions = action_space.spaces[0].n
        self.h_size = config["model"]["lstm_cell_size"]
        self.has_env_global_state = False
        self.has_action_mask = False
        self.device = (torch.device("cuda")
                       if torch.cuda.is_available() else torch.device("cpu"))

        agent_obs_space = obs_space.original_space.spaces[0]
        if isinstance(agent_obs_space, Dict):
            space_keys = set(agent_obs_space.spaces.keys())
            if "obs" not in space_keys:
                raise ValueError(
                    "Dict obs space must have subspace labeled `obs`")
            self.obs_size = _get_size(agent_obs_space.spaces["obs"])
            if "action_mask" in space_keys:
                mask_shape = tuple(
                    agent_obs_space.spaces["action_mask"].shape)
                if mask_shape != (self.n_actions, ):
                    raise ValueError(
                        "Action mask shape must be {}, got {}".format(
                            (self.n_actions, ), mask_shape))
                self.has_action_mask = True
            if ENV_STATE in space_keys:
                self.env_global_state_shape = _get_size(
                    agent_obs_space.spaces[ENV_STATE])
                self.has_env_global_state = True
            else:
                self.env_global_state_shape = (self.obs_size, self.n_agents)
            # The real agent obs space is nested inside the dict
            config["model"]["full_obs_space"] = agent_obs_space
            agent_obs_space = agent_obs_space.spaces["obs"]
        else:
            self.obs_size = _get_size(agent_obs_space)

        self.model = ModelCatalog.get_model_v2(
            agent_obs_space,
            action_space.spaces[0],
            self.n_actions,
            config["model"],
            framework="torch",
            name="model",
            default_model=RNNModel).to(self.device)
        self.target_model = ModelCatalog.get_model_v2(
            agent_obs_space,
            action_space.spaces[0],
            self.n_actions,
            config["model"],
            framework="torch",
            name="target_model",
            default_model=RNNModel).to(self.device)

        self.exploration = self._create_exploration()

        # Setup the mixer network.
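        # "qmix" mixes the per-agent Q-values monotonically via a
        # hypernetwork conditioned on the global state; "vdn" simply sums
        # them; None skips mixing and trains independent Q-learners.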
        if config["mixer"] is None:
            self.mixer = None
            self.target_mixer = None
        elif config["mixer"] == "qmix":
            self.mixer = QMixer(self.n_agents, self.env_global_state_shape,
                                config["mixing_embed_dim"]).to(self.device)
            self.target_mixer = QMixer(
                self.n_agents, self.env_global_state_shape,
                config["mixing_embed_dim"]).to(self.device)
        elif config["mixer"] == "vdn":
            self.mixer = VDNMixer().to(self.device)
            self.target_mixer = VDNMixer().to(self.device)
        else:
            raise ValueError("Unknown mixer type {}".format(config["mixer"]))

        self.cur_epsilon = 1.0
        self.update_target()  # initial sync

        # Setup optimizer
        self.params = list(self.model.parameters())
        if self.mixer:
            self.params += list(self.mixer.parameters())
        self.loss = QMixLoss(self.model, self.target_model, self.mixer,
                             self.target_mixer, self.n_agents, self.n_actions,
                             self.config["double_q"], self.config["gamma"])
        self.optimiser = RMSprop(
            params=self.params,
            lr=config["lr"],
            alpha=config["optim_alpha"],
            eps=config["optim_eps"])

    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        explore=None,
                        **kwargs):
        explore = explore if explore is not None else self.config["explore"]
        obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
        # We need to ensure we do not use the env global state
        # to compute actions

        # Compute actions
        with torch.no_grad():
            q_values, hiddens = _mac(
                self.model,
                torch.as_tensor(
                    obs_batch, dtype=torch.float, device=self.device), [
                        torch.as_tensor(
                            np.array(s), dtype=torch.float,
                            device=self.device) for s in state_batches
                    ])
            avail = torch.as_tensor(
                action_mask, dtype=torch.float, device=self.device)
            masked_q_values = q_values.clone()
            masked_q_values[avail == 0.0] = -float("inf")
            # epsilon-greedy action selector
            random_numbers = torch.rand_like(q_values[:, :, 0])
            pick_random = (random_numbers < (self.cur_epsilon
                                             if explore else 0.0)).long()
            random_actions = Categorical(avail).sample().long()
            actions = (pick_random * random_actions +
                       (1 - pick_random) * masked_q_values.argmax(dim=2))
            actions = actions.cpu().numpy()
            hiddens = [s.cpu().numpy() for s in hiddens]

        return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}

    @override(Policy)
    def compute_log_likelihoods(self,
                                actions,
                                obs_batch,
                                state_batches=None,
                                prev_action_batch=None,
                                prev_reward_batch=None):
        obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
        # The unpacked obs is a numpy array, so use `shape` rather than a
        # torch-style `size()` call.
        return np.zeros(obs_batch.shape[0])

    @override(Policy)
    def learn_on_batch(self, samples):
        obs_batch, action_mask, env_global_state = self._unpack_observation(
            samples[SampleBatch.CUR_OBS])
        (next_obs_batch, next_action_mask,
         next_env_global_state) = self._unpack_observation(
             samples[SampleBatch.NEXT_OBS])
        group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])

        input_list = [
            group_rewards, action_mask, next_action_mask,
            samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES],
            obs_batch, next_obs_batch
        ]
        if self.has_env_global_state:
            input_list.extend([env_global_state, next_env_global_state])

        output_list, _, seq_lens = \
            chop_into_sequences(
                samples[SampleBatch.EPS_ID],
                samples[SampleBatch.UNROLL_ID],
                samples[SampleBatch.AGENT_INDEX],
                input_list,
                [],  # RNN states not used here
                max_seq_len=self.config["model"]["max_seq_len"],
                dynamic_max=True)
        # These will be padded to shape [B * T, ...]
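        # chop_into_sequences() splits the flat rows into per-episode
        # sequences of at most max_seq_len steps and zero-pads them to a
        # common length; seq_lens records each sequence's true length so the
        # padded steps can be masked out of the loss below.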
        if self.has_env_global_state:
            (rew, action_mask, next_action_mask, act, dones, obs, next_obs,
             env_global_state, next_env_global_state) = output_list
        else:
            (rew, action_mask, next_action_mask, act, dones, obs,
             next_obs) = output_list
        B, T = len(seq_lens), max(seq_lens)

        def to_batches(arr, dtype):
            new_shape = [B, T] + list(arr.shape[1:])
            return torch.as_tensor(
                np.reshape(arr, new_shape), dtype=dtype, device=self.device)

        rewards = to_batches(rew, torch.float)
        actions = to_batches(act, torch.long)
        obs = to_batches(obs, torch.float).reshape(
            [B, T, self.n_agents, self.obs_size])
        action_mask = to_batches(action_mask, torch.float)
        next_obs = to_batches(next_obs, torch.float).reshape(
            [B, T, self.n_agents, self.obs_size])
        next_action_mask = to_batches(next_action_mask, torch.float)
        if self.has_env_global_state:
            env_global_state = to_batches(env_global_state, torch.float)
            next_env_global_state = to_batches(next_env_global_state,
                                               torch.float)

        # TODO(ekl) this treats group termination as individual termination
        terminated = to_batches(dones, torch.float).unsqueeze(2).expand(
            B, T, self.n_agents)

        # Create mask for where index is < unpadded sequence length
        filled = np.reshape(
            np.tile(np.arange(T, dtype=np.float32), B),
            [B, T]) < np.expand_dims(seq_lens, 1)
        mask = torch.as_tensor(
            filled, dtype=torch.float,
            device=self.device).unsqueeze(2).expand(B, T, self.n_agents)

        # Compute loss
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = (
            self.loss(rewards, actions, terminated, mask, obs, next_obs,
                      action_mask, next_action_mask, env_global_state,
                      next_env_global_state))

        # Optimise
        self.optimiser.zero_grad()
        loss_out.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.params, self.config["grad_norm_clipping"])
        self.optimiser.step()

        mask_elems = mask.sum().item()
        stats = {
            "loss": loss_out.item(),
            "grad_norm": grad_norm
            if isinstance(grad_norm, float) else grad_norm.item(),
            "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean": (chosen_action_qvals * mask).sum().item() /
            mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        return {LEARNER_STATS_KEY: stats}

    @override(Policy)
    def get_initial_state(self):  # initial RNN state
        return [
            s.expand([self.n_agents, -1]).cpu().numpy()
            for s in self.model.get_initial_state()
        ]

    @override(Policy)
    def get_weights(self):
        return {
            "model": self._cpu_dict(self.model.state_dict()),
            "target_model": self._cpu_dict(self.target_model.state_dict()),
            "mixer": self._cpu_dict(self.mixer.state_dict())
            if self.mixer else None,
            "target_mixer": self._cpu_dict(self.target_mixer.state_dict())
            if self.mixer else None,
        }

    @override(Policy)
    def set_weights(self, weights):
        self.model.load_state_dict(self._device_dict(weights["model"]))
        self.target_model.load_state_dict(
            self._device_dict(weights["target_model"]))
        if weights["mixer"] is not None:
            self.mixer.load_state_dict(self._device_dict(weights["mixer"]))
            self.target_mixer.load_state_dict(
                self._device_dict(weights["target_mixer"]))

    @override(Policy)
    def get_state(self):
        state = self.get_weights()
        state["cur_epsilon"] = self.cur_epsilon
        return state

    @override(Policy)
    def set_state(self, state):
        self.set_weights(state)
        self.set_epsilon(state["cur_epsilon"])

    def update_target(self):
        self.target_model.load_state_dict(self.model.state_dict())
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        logger.debug("Updated target networks")

    def set_epsilon(self, epsilon):
        self.cur_epsilon = epsilon
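    # The `_group_rewards` info field is filled in by
    # MultiAgentEnv.with_agent_groups(); when it is absent from an info dict,
    # a zero reward per agent is substituted.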
    def _get_group_rewards(self, info_batch):
        group_rewards = np.array([
            info.get(GROUP_REWARDS, [0.0] * self.n_agents)
            for info in info_batch
        ])
        return group_rewards

    def _device_dict(self, state_dict):
        return {
            k: torch.as_tensor(v, device=self.device)
            for k, v in state_dict.items()
        }

    @staticmethod
    def _cpu_dict(state_dict):
        return {k: v.cpu().detach().numpy() for k, v in state_dict.items()}

    def _unpack_observation(self, obs_batch):
        """Unpacks the observation, action mask, and state (if present)
        from agent grouping.

        Returns:
            obs (np.ndarray): obs tensor of shape [B, n_agents, obs_size]
            mask (np.ndarray): action mask, if any
            state (np.ndarray or None): state tensor of shape [B, state_size]
                or None if it is not in the batch
        """
        unpacked = _unpack_obs(
            np.array(obs_batch, dtype=np.float32),
            self.observation_space.original_space,
            tensorlib=np)
        if self.has_action_mask:
            obs = np.concatenate(
                [o["obs"] for o in unpacked],
                axis=1).reshape([len(obs_batch), self.n_agents,
                                 self.obs_size])
            action_mask = np.concatenate(
                [o["action_mask"] for o in unpacked], axis=1).reshape(
                    [len(obs_batch), self.n_agents, self.n_actions])
        else:
            if isinstance(unpacked[0], dict):
                unpacked_obs = [u["obs"] for u in unpacked]
            else:
                unpacked_obs = unpacked
            obs = np.concatenate(
                unpacked_obs,
                axis=1).reshape([len(obs_batch), self.n_agents,
                                 self.obs_size])
            action_mask = np.ones(
                [len(obs_batch), self.n_agents, self.n_actions],
                dtype=np.float32)

        if self.has_env_global_state:
            state = unpacked[0][ENV_STATE]
        else:
            state = None
        return obs, action_mask, state
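# The sketch below is illustrative only and is not used by QMixTorchPolicy
# itself. It shows the grouping step the class docstring requires:
# MultiAgentEnv.with_agent_groups() merges the individual agents into one
# logical agent with Tuple obs/action spaces and a `_group_rewards` info
# field. `MyTwoAgentEnv` is a hypothetical two-agent MultiAgentEnv with
# agent ids "agent_0" and "agent_1"; substitute a real env class.


def example_grouped_env_creator(env_config):
    """Illustrative helper: wrap a hypothetical two-agent env for QMix."""
    from gym.spaces import Tuple

    env = MyTwoAgentEnv(env_config)  # hypothetical MultiAgentEnv
    grouping = {"group_1": ["agent_0", "agent_1"]}
    # with_agent_groups() builds the Tuple spaces that __init__ above
    # unpacks via obs_space.original_space.spaces.
    return env.with_agent_groups(
        grouping,
        obs_space=Tuple([env.observation_space] * 2),
        act_space=Tuple([env.action_space] * 2))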
class QMixTorchPolicy(Policy):
    """QMix impl. Assumes homogeneous agents for now.

    You must use MultiAgentEnv.with_agent_groups() to group agents
    together for QMix. This creates the proper Tuple obs/action spaces and
    populates the '_group_rewards' info field.

    Action masking: to specify an action mask for individual agents, use a
    dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}.
    The mask space must be `Box(0, 1, (n_actions,))`.
    """

    def __init__(self, obs_space, action_space, config):
        _validate(obs_space, action_space)
        config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
        self.config = config
        self.observation_space = obs_space
        self.action_space = action_space
        self.n_agents = len(obs_space.original_space.spaces)
        self.n_actions = action_space.spaces[0].n
        self.h_size = config["model"]["lstm_cell_size"]

        agent_obs_space = obs_space.original_space.spaces[0]
        if isinstance(agent_obs_space, Dict):
            space_keys = set(agent_obs_space.spaces.keys())
            if space_keys != {"obs", "action_mask"}:
                raise ValueError(
                    "Dict obs space for agent must have keyset "
                    "['obs', 'action_mask'], got {}".format(space_keys))
            mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
            if mask_shape != (self.n_actions, ):
                raise ValueError(
                    "Action mask shape must be {}, got {}".format(
                        (self.n_actions, ), mask_shape))
            self.has_action_mask = True
            self.obs_size = _get_size(agent_obs_space.spaces["obs"])
            # The real agent obs space is nested inside the dict
            agent_obs_space = agent_obs_space.spaces["obs"]
        else:
            self.has_action_mask = False
            self.obs_size = _get_size(agent_obs_space)

        self.model = ModelCatalog.get_torch_model(
            agent_obs_space,
            self.n_actions,
            config["model"],
            default_model_cls=RNNModel)
        self.target_model = ModelCatalog.get_torch_model(
            agent_obs_space,
            self.n_actions,
            config["model"],
            default_model_cls=RNNModel)
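        # The target network shares the model architecture; update_target()
        # below keeps it as a lagged copy of self.model, which QMixLoss uses
        # to compute the bootstrap targets.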
        # Setup the mixer network.
        # The global state is just the stacked agent observations for now.
        self.state_shape = [self.obs_size, self.n_agents]
        if config["mixer"] is None:
            self.mixer = None
            self.target_mixer = None
        elif config["mixer"] == "qmix":
            self.mixer = QMixer(self.n_agents, self.state_shape,
                                config["mixing_embed_dim"])
            self.target_mixer = QMixer(self.n_agents, self.state_shape,
                                       config["mixing_embed_dim"])
        elif config["mixer"] == "vdn":
            self.mixer = VDNMixer()
            self.target_mixer = VDNMixer()
        else:
            raise ValueError("Unknown mixer type {}".format(config["mixer"]))

        self.cur_epsilon = 1.0
        self.update_target()  # initial sync

        # Setup optimizer
        self.params = list(self.model.parameters())
        if self.mixer:
            self.params += list(self.mixer.parameters())
        self.loss = QMixLoss(self.model, self.target_model, self.mixer,
                             self.target_mixer, self.n_agents, self.n_actions,
                             self.config["double_q"], self.config["gamma"])
        self.optimiser = RMSprop(
            params=self.params,
            lr=config["lr"],
            alpha=config["optim_alpha"],
            eps=config["optim_eps"])

    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        obs_batch, action_mask = self._unpack_observation(obs_batch)

        # Compute actions
        with th.no_grad():
            q_values, hiddens = _mac(
                self.model, th.from_numpy(obs_batch),
                [th.from_numpy(np.array(s)) for s in state_batches])
            avail = th.from_numpy(action_mask).float()
            masked_q_values = q_values.clone()
            masked_q_values[avail == 0.0] = -float("inf")
            # epsilon-greedy action selector
            random_numbers = th.rand_like(q_values[:, :, 0])
            pick_random = (random_numbers < self.cur_epsilon).long()
            random_actions = Categorical(avail).sample().long()
            actions = (pick_random * random_actions +
                       (1 - pick_random) * masked_q_values.max(dim=2)[1])
            actions = actions.numpy()
            hiddens = [s.numpy() for s in hiddens]

        return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}

    @override(Policy)
    def learn_on_batch(self, samples):
        obs_batch, action_mask = self._unpack_observation(
            samples[SampleBatch.CUR_OBS])
        next_obs_batch, next_action_mask = self._unpack_observation(
            samples[SampleBatch.NEXT_OBS])
        group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])

        # These will be padded to shape [B * T, ...]
        [rew, action_mask, next_action_mask, act, dones, obs, next_obs], \
            initial_states, seq_lens = \
            chop_into_sequences(
                samples[SampleBatch.EPS_ID],
                samples[SampleBatch.UNROLL_ID],
                samples[SampleBatch.AGENT_INDEX], [
                    group_rewards, action_mask, next_action_mask,
                    samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES],
                    obs_batch, next_obs_batch
                ],
                [samples["state_in_{}".format(k)]
                 for k in range(len(self.get_initial_state()))],
                max_seq_len=self.config["model"]["max_seq_len"],
                dynamic_max=True)
        B, T = len(seq_lens), max(seq_lens)

        def to_batches(arr):
            new_shape = [B, T] + list(arr.shape[1:])
            return th.from_numpy(np.reshape(arr, new_shape))

        rewards = to_batches(rew).float()
        actions = to_batches(act).long()
        obs = to_batches(obs).reshape([B, T, self.n_agents,
                                       self.obs_size]).float()
        action_mask = to_batches(action_mask)
        next_obs = to_batches(next_obs).reshape(
            [B, T, self.n_agents, self.obs_size]).float()
        next_action_mask = to_batches(next_action_mask)

        # TODO(ekl) this treats group termination as individual termination
        terminated = to_batches(dones.astype(np.float32)).unsqueeze(2).expand(
            B, T, self.n_agents)

        # Create mask for where index is < unpadded sequence length
        filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
                  np.expand_dims(seq_lens, 1)).astype(np.float32)
        mask = th.from_numpy(filled).unsqueeze(2).expand(B, T, self.n_agents)

        # Compute loss
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = \
            self.loss(rewards, actions, terminated, mask, obs, next_obs,
                      action_mask, next_action_mask)

        # Optimise
        self.optimiser.zero_grad()
        loss_out.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(
            self.params, self.config["grad_norm_clipping"])
        self.optimiser.step()

        mask_elems = mask.sum().item()
        stats = {
            "loss": loss_out.item(),
            "grad_norm": grad_norm
            if isinstance(grad_norm, float) else grad_norm.item(),
            "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean": (chosen_action_qvals * mask).sum().item() /
            mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        return {LEARNER_STATS_KEY: stats}

    @override(Policy)
    def get_initial_state(self):
        return [
            s.expand([self.n_agents, -1]).numpy()
            for s in self.model.state_init()
        ]

    @override(Policy)
    def get_weights(self):
        return {"model": self.model.state_dict()}

    @override(Policy)
    def set_weights(self, weights):
        self.model.load_state_dict(weights["model"])

    @override(Policy)
    def get_state(self):
        return {
            "model": self.model.state_dict(),
            "target_model": self.target_model.state_dict(),
            "mixer": self.mixer.state_dict() if self.mixer else None,
            "target_mixer": self.target_mixer.state_dict()
            if self.mixer else None,
            "cur_epsilon": self.cur_epsilon,
        }

    @override(Policy)
    def set_state(self, state):
        self.model.load_state_dict(state["model"])
        self.target_model.load_state_dict(state["target_model"])
        if state["mixer"] is not None:
            self.mixer.load_state_dict(state["mixer"])
            self.target_mixer.load_state_dict(state["target_mixer"])
        self.set_epsilon(state["cur_epsilon"])
        self.update_target()

    def update_target(self):
        self.target_model.load_state_dict(self.model.state_dict())
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        logger.debug("Updated target networks")

    def set_epsilon(self, epsilon):
        self.cur_epsilon = epsilon

    def _get_group_rewards(self, info_batch):
        group_rewards = np.array([
            info.get(GROUP_REWARDS, [0.0] * self.n_agents)
            for info in info_batch
        ])
        return group_rewards

    def _unpack_observation(self, obs_batch):
        """Unpacks the action mask / tuple obs from agent grouping.
        Returns:
            obs (Tensor): flattened obs tensor of shape [B, n_agents,
                obs_size]
            mask (Tensor): action mask, if any
        """
        unpacked = _unpack_obs(
            np.array(obs_batch),
            self.observation_space.original_space,
            tensorlib=np)
        if self.has_action_mask:
            obs = np.concatenate(
                [o["obs"] for o in unpacked],
                axis=1).reshape([len(obs_batch), self.n_agents,
                                 self.obs_size])
            action_mask = np.concatenate(
                [o["action_mask"] for o in unpacked], axis=1).reshape(
                    [len(obs_batch), self.n_agents, self.n_actions])
        else:
            obs = np.concatenate(
                unpacked,
                axis=1).reshape([len(obs_batch), self.n_agents,
                                 self.obs_size])
            action_mask = np.ones(
                [len(obs_batch), self.n_agents, self.n_actions])
        return obs, action_mask
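# A minimal sketch of the per-agent observation layout described in the
# docstring above, assuming a hypothetical discrete problem with 5 actions
# and a 10-dim observation. Each agent's obs from reset()/step() would then
# be a dict {"obs": ..., "action_mask": ...}, with mask entries of 1.0 for
# legal actions and 0.0 for illegal ones.

from gym.spaces import Box, Dict as GymDict, Discrete

N_ACTIONS_EXAMPLE = 5  # hypothetical action count for the sketch

example_agent_obs_space = GymDict({
    "obs": Box(float("-inf"), float("inf"), shape=(10, )),
    "action_mask": Box(0.0, 1.0, shape=(N_ACTIONS_EXAMPLE, )),
})
example_agent_action_space = Discrete(N_ACTIONS_EXAMPLE)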
class QMixPolicyGraph(PolicyGraph):
    """QMix impl. Assumes homogeneous agents for now.

    You must use MultiAgentEnv.with_agent_groups() to group agents
    together for QMix. This creates the proper Tuple obs/action spaces and
    populates the '_group_rewards' info field.

    Action masking: to specify an action mask for individual agents, use a
    dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}.
    The mask space must be `Box(0, 1, (n_actions,))`.
    """

    def __init__(self, obs_space, action_space, config):
        _validate(obs_space, action_space)
        config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
        self.config = config
        self.observation_space = obs_space
        self.action_space = action_space
        self.n_agents = len(obs_space.original_space.spaces)
        self.n_actions = action_space.spaces[0].n
        self.h_size = config["model"]["lstm_cell_size"]

        agent_obs_space = obs_space.original_space.spaces[0]
        if isinstance(agent_obs_space, Dict):
            space_keys = set(agent_obs_space.spaces.keys())
            if space_keys != {"obs", "action_mask"}:
                raise ValueError(
                    "Dict obs space for agent must have keyset "
                    "['obs', 'action_mask'], got {}".format(space_keys))
            mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
            if mask_shape != (self.n_actions, ):
                raise ValueError(
                    "Action mask shape must be {}, got {}".format(
                        (self.n_actions, ), mask_shape))
            self.has_action_mask = True
            self.obs_size = _get_size(agent_obs_space.spaces["obs"])
            # The real agent obs space is nested inside the dict
            agent_obs_space = agent_obs_space.spaces["obs"]
        else:
            self.has_action_mask = False
            self.obs_size = _get_size(agent_obs_space)

        self.model = ModelCatalog.get_torch_model(
            agent_obs_space,
            self.n_actions,
            config["model"],
            default_model_cls=RNNModel)
        self.target_model = ModelCatalog.get_torch_model(
            agent_obs_space,
            self.n_actions,
            config["model"],
            default_model_cls=RNNModel)

        # Setup the mixer network.
        # The global state is just the stacked agent observations for now.
        self.state_shape = [self.obs_size, self.n_agents]
        if config["mixer"] is None:
            self.mixer = None
            self.target_mixer = None
        elif config["mixer"] == "qmix":
            self.mixer = QMixer(self.n_agents, self.state_shape,
                                config["mixing_embed_dim"])
            self.target_mixer = QMixer(self.n_agents, self.state_shape,
                                       config["mixing_embed_dim"])
        elif config["mixer"] == "vdn":
            self.mixer = VDNMixer()
            self.target_mixer = VDNMixer()
        else:
            raise ValueError("Unknown mixer type {}".format(config["mixer"]))

        self.cur_epsilon = 1.0
        self.update_target()  # initial sync

        # Setup optimizer
        self.params = list(self.model.parameters())
        self.loss = QMixLoss(self.model, self.target_model, self.mixer,
                             self.target_mixer, self.n_agents, self.n_actions,
                             self.config["double_q"], self.config["gamma"])
        self.optimiser = RMSprop(
            params=self.params,
            lr=config["lr"],
            alpha=config["optim_alpha"],
            eps=config["optim_eps"])

    @override(PolicyGraph)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        obs_batch, action_mask = self._unpack_observation(obs_batch)

        # Compute actions
        with th.no_grad():
            q_values, hiddens = _mac(
                self.model, th.from_numpy(obs_batch),
                [th.from_numpy(np.array(s)) for s in state_batches])
            avail = th.from_numpy(action_mask).float()
            masked_q_values = q_values.clone()
            masked_q_values[avail == 0.0] = -float("inf")
            # epsilon-greedy action selector
            random_numbers = th.rand_like(q_values[:, :, 0])
            pick_random = (random_numbers < self.cur_epsilon).long()
            random_actions = Categorical(avail).sample().long()
            actions = (pick_random * random_actions +
                       (1 - pick_random) * masked_q_values.max(dim=2)[1])
            actions = actions.numpy()
            hiddens = [s.numpy() for s in hiddens]

        return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}

    @override(PolicyGraph)
    def learn_on_batch(self, samples):
        obs_batch, action_mask = self._unpack_observation(samples["obs"])
        group_rewards = self._get_group_rewards(samples["infos"])

        # These will be padded to shape [B * T, ...]
        [rew, action_mask, act, dones, obs], initial_states, seq_lens = \
            chop_into_sequences(
                samples["eps_id"],
                samples["agent_index"], [
                    group_rewards, action_mask, samples["actions"],
                    samples["dones"], obs_batch
                ],
                [samples["state_in_{}".format(k)]
                 for k in range(len(self.get_initial_state()))],
                max_seq_len=self.config["model"]["max_seq_len"],
                dynamic_max=True,
                _extra_padding=1)
        # TODO(ekl) adding 1 extra unit of padding here, since otherwise we
        # lose the terminating reward and the Q-values will be unanchored!
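        # Because of that extra padding, T below is max(seq_lens) + 1: obs is
        # kept at full length so Q-values for the following step can still be
        # formed, while rewards, actions, terminated and mask are sliced back
        # with [:, :-1].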
        B, T = len(seq_lens), max(seq_lens) + 1

        def to_batches(arr):
            new_shape = [B, T] + list(arr.shape[1:])
            return th.from_numpy(np.reshape(arr, new_shape))

        rewards = to_batches(rew)[:, :-1].float()
        actions = to_batches(act)[:, :-1].long()
        obs = to_batches(obs).reshape([B, T, self.n_agents,
                                       self.obs_size]).float()
        action_mask = to_batches(action_mask)

        # TODO(ekl) this treats group termination as individual termination
        terminated = to_batches(dones.astype(np.float32)).unsqueeze(2).expand(
            B, T, self.n_agents)[:, :-1]

        filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
                  np.expand_dims(seq_lens, 1)).astype(np.float32)
        mask = th.from_numpy(filled).unsqueeze(2).expand(
            B, T, self.n_agents)[:, :-1]
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])

        # Compute loss
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = \
            self.loss(rewards, actions, terminated, mask, obs, action_mask)

        # Optimise
        self.optimiser.zero_grad()
        loss_out.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(
            self.params, self.config["grad_norm_clipping"])
        self.optimiser.step()

        mask_elems = mask.sum().item()
        stats = {
            "loss": loss_out.item(),
            "grad_norm": grad_norm
            if isinstance(grad_norm, float) else grad_norm.item(),
            "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean": (chosen_action_qvals * mask).sum().item() /
            mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        return {"stats": stats}, {}

    @override(PolicyGraph)
    def get_initial_state(self):
        return [
            s.expand([self.n_agents, -1]).numpy()
            for s in self.model.state_init()
        ]

    @override(PolicyGraph)
    def get_weights(self):
        return {"model": self.model.state_dict()}

    @override(PolicyGraph)
    def set_weights(self, weights):
        self.model.load_state_dict(weights["model"])

    @override(PolicyGraph)
    def get_state(self):
        return {
            "model": self.model.state_dict(),
            "target_model": self.target_model.state_dict(),
            "mixer": self.mixer.state_dict() if self.mixer else None,
            "target_mixer": self.target_mixer.state_dict()
            if self.mixer else None,
            "cur_epsilon": self.cur_epsilon,
        }

    @override(PolicyGraph)
    def set_state(self, state):
        self.model.load_state_dict(state["model"])
        self.target_model.load_state_dict(state["target_model"])
        if state["mixer"] is not None:
            self.mixer.load_state_dict(state["mixer"])
            self.target_mixer.load_state_dict(state["target_mixer"])
        self.set_epsilon(state["cur_epsilon"])
        self.update_target()

    def update_target(self):
        self.target_model.load_state_dict(self.model.state_dict())
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        logger.debug("Updated target networks")

    def set_epsilon(self, epsilon):
        self.cur_epsilon = epsilon

    def _get_group_rewards(self, info_batch):
        group_rewards = np.array([
            info.get(GROUP_REWARDS, [0.0] * self.n_agents)
            for info in info_batch
        ])
        return group_rewards

    def _unpack_observation(self, obs_batch):
        """Unpacks the action mask / tuple obs from agent grouping.
        Returns:
            obs (Tensor): flattened obs tensor of shape [B, n_agents,
                obs_size]
            mask (Tensor): action mask, if any
        """
        unpacked = _unpack_obs(
            np.array(obs_batch),
            self.observation_space.original_space,
            tensorlib=np)
        if self.has_action_mask:
            obs = np.concatenate(
                [o["obs"] for o in unpacked],
                axis=1).reshape([len(obs_batch), self.n_agents,
                                 self.obs_size])
            action_mask = np.concatenate(
                [o["action_mask"] for o in unpacked], axis=1).reshape(
                    [len(obs_batch), self.n_agents, self.n_actions])
        else:
            obs = np.concatenate(
                unpacked,
                axis=1).reshape([len(obs_batch), self.n_agents,
                                 self.obs_size])
            action_mask = np.ones(
                [len(obs_batch), self.n_agents, self.n_actions])
        return obs, action_mask
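# End-to-end usage sketch (illustrative only, never executed on import):
# register the grouped env creator from the sketch above under a name and
# launch the QMIX trainer through tune. The env name "grouped_example_env"
# and the stopping criterion are placeholders; "mixer" accepts "qmix", "vdn"
# or None, matching the handling in __init__ above.


def example_train_qmix():
    """Illustrative only; assumes example_grouped_env_creator() is usable."""
    import ray
    from ray import tune
    from ray.tune.registry import register_env

    ray.init()
    register_env("grouped_example_env", example_grouped_env_creator)
    tune.run(
        "QMIX",
        stop={"training_iteration": 50},
        config={
            "env": "grouped_example_env",
            "mixer": "qmix",
            "mixing_embed_dim": 32,
            "double_q": True,
        })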