def __init__(self, obs_space, action_space, config):
    """Build the agent Q-networks, the optional mixer, and the optimizer.

    Args:
        obs_space: Observation space exposing ``original_space.spaces``
            (one entry per agent). The first agent's space is used to
            derive the observation size and to detect an action mask.
        action_space: Action space; ``action_space.spaces[0].n`` is used
            as the number of discrete actions.
        config: Settings layered on top of the QMIX ``DEFAULT_CONFIG``.
            Reads ``model.lstm_cell_size``, ``mixer``,
            ``mixing_embed_dim``, ``double_q``, ``gamma``, ``lr``,
            ``optim_alpha`` and ``optim_eps``.

    Raises:
        ValueError: If a Dict agent observation space does not have
            exactly the keys ``obs`` and ``action_mask``, if the action
            mask shape is not ``(n_actions,)``, or if ``config["mixer"]``
            is not ``None``, ``"qmix"`` or ``"vdn"``.
    """
    _validate(obs_space, action_space)
    config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
    self.config = config
    self.observation_space = obs_space
    self.action_space = action_space
    self.n_agents = len(obs_space.original_space.spaces)
    self.n_actions = action_space.spaces[0].n
    self.h_size = config["model"]["lstm_cell_size"]

    agent_obs_space = obs_space.original_space.spaces[0]
    if isinstance(agent_obs_space, Dict):
        space_keys = set(agent_obs_space.spaces.keys())
        if space_keys != {"obs", "action_mask"}:
            raise ValueError(
                "Dict obs space for agent must have keyset "
                "['obs', 'action_mask'], got {}".format(space_keys))
        mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
        if mask_shape != (self.n_actions, ):
            raise ValueError("Action mask shape must be {}, got {}".format(
                (self.n_actions, ), mask_shape))
        self.has_action_mask = True
        self.obs_size = _get_size(agent_obs_space.spaces["obs"])
    else:
        self.has_action_mask = False
        self.obs_size = _get_size(agent_obs_space)

    self.model = RNNModel(self.obs_size, self.h_size, self.n_actions)
    self.target_model = RNNModel(self.obs_size, self.h_size,
                                 self.n_actions)

    # Setup the mixer network.
    # The global state is just the stacked agent observations for now.
    self.state_shape = [self.obs_size, self.n_agents]
    if config["mixer"] is None:
        self.mixer = None
        self.target_mixer = None
    elif config["mixer"] == "qmix":
        self.mixer = QMixer(self.n_agents, self.state_shape,
                            config["mixing_embed_dim"])
        self.target_mixer = QMixer(self.n_agents, self.state_shape,
                                   config["mixing_embed_dim"])
    elif config["mixer"] == "vdn":
        self.mixer = VDNMixer()
        self.target_mixer = VDNMixer()
    else:
        raise ValueError("Unknown mixer type {}".format(config["mixer"]))

    self.cur_epsilon = 1.0
    self.update_target()  # initial sync

    # Setup optimizer
    self.params = list(self.model.parameters())
    if self.mixer is not None:
        # The mixer's weights must be handed to the optimizer as well;
        # otherwise only the agent network would ever be updated.
        self.params += list(self.mixer.parameters())
    self.loss = QMixLoss(self.model, self.target_model, self.mixer,
                         self.target_mixer, self.n_agents, self.n_actions,
                         self.config["double_q"], self.config["gamma"])
    self.optimiser = RMSprop(
        params=self.params,
        lr=config["lr"],
        alpha=config["optim_alpha"],
        eps=config["optim_eps"])
def __init__(self, obs_space, action_space, config):
    """Build the torch agent models, the optional mixer, and the optimizer.

    Args:
        obs_space: Observation space exposing ``original_space.spaces``
            (one entry per agent). The first agent's space is used to
            derive the observation size and to detect the optional
            ``action_mask`` and ``ENV_STATE`` sub-spaces.
        action_space: Action space; ``action_space.spaces[0].n`` is used
            as the number of discrete actions.
        config: Settings layered on top of the QMIX ``DEFAULT_CONFIG``.
            Reads ``model``, ``mixer``, ``mixing_embed_dim``,
            ``double_q``, ``gamma``, ``lr``, ``optim_alpha`` and
            ``optim_eps``.

    Raises:
        ValueError: If a Dict agent observation space has no ``obs`` key,
            if the action mask shape is not ``(n_actions,)``, or if
            ``config["mixer"]`` is not ``None``, ``"qmix"`` or ``"vdn"``.
    """
    _validate(obs_space, action_space)
    config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
    self.framework = "torch"
    super().__init__(obs_space, action_space, config)
    self.n_agents = len(obs_space.original_space.spaces)
    self.n_actions = action_space.spaces[0].n
    self.h_size = config["model"]["lstm_cell_size"]
    self.has_env_global_state = False
    self.has_action_mask = False
    # Use the GPU whenever torch reports one, otherwise stay on the CPU.
    self.device = (torch.device("cuda")
                   if torch.cuda.is_available() else torch.device("cpu"))

    agent_obs_space = obs_space.original_space.spaces[0]
    if isinstance(agent_obs_space, Dict):
        space_keys = set(agent_obs_space.spaces.keys())
        if "obs" not in space_keys:
            raise ValueError(
                "Dict obs space must have subspace labeled `obs`")
        self.obs_size = _get_size(agent_obs_space.spaces["obs"])
        if "action_mask" in space_keys:
            mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
            if mask_shape != (self.n_actions, ):
                raise ValueError(
                    "Action mask shape must be {}, got {}".format(
                        (self.n_actions, ), mask_shape))
            self.has_action_mask = True
        if ENV_STATE in space_keys:
            self.env_global_state_shape = _get_size(
                agent_obs_space.spaces[ENV_STATE])
            self.has_env_global_state = True
        else:
            # No env-provided state: fall back to stacked observations.
            self.env_global_state_shape = (self.obs_size, self.n_agents)
        # The real agent obs space is nested inside the dict
        config["model"]["full_obs_space"] = agent_obs_space
        agent_obs_space = agent_obs_space.spaces["obs"]
    else:
        self.obs_size = _get_size(agent_obs_space)
        # A plain (non-Dict) observation space has no ENV_STATE entry, so
        # the mixer state shape must still be defined here; otherwise the
        # QMixer construction below fails with an AttributeError.
        self.env_global_state_shape = (self.obs_size, self.n_agents)

    self.model = ModelCatalog.get_model_v2(
        agent_obs_space,
        action_space.spaces[0],
        self.n_actions,
        config["model"],
        framework="torch",
        name="model",
        default_model=RNNModel).to(self.device)

    self.target_model = ModelCatalog.get_model_v2(
        agent_obs_space,
        action_space.spaces[0],
        self.n_actions,
        config["model"],
        framework="torch",
        name="target_model",
        default_model=RNNModel).to(self.device)

    self.exploration = self._create_exploration()

    # Setup the mixer network.
    if config["mixer"] is None:
        self.mixer = None
        self.target_mixer = None
    elif config["mixer"] == "qmix":
        self.mixer = QMixer(self.n_agents, self.env_global_state_shape,
                            config["mixing_embed_dim"]).to(self.device)
        self.target_mixer = QMixer(
            self.n_agents, self.env_global_state_shape,
            config["mixing_embed_dim"]).to(self.device)
    elif config["mixer"] == "vdn":
        self.mixer = VDNMixer().to(self.device)
        self.target_mixer = VDNMixer().to(self.device)
    else:
        raise ValueError("Unknown mixer type {}".format(config["mixer"]))

    self.cur_epsilon = 1.0
    self.update_target()  # initial sync

    # Setup optimizer
    self.params = list(self.model.parameters())
    # Explicit None check: do not rely on the truthiness of a module object.
    if self.mixer is not None:
        self.params += list(self.mixer.parameters())
    self.loss = QMixLoss(self.model, self.target_model, self.mixer,
                         self.target_mixer, self.n_agents, self.n_actions,
                         self.config["double_q"], self.config["gamma"])
    self.optimiser = RMSprop(
        params=self.params,
        lr=config["lr"],
        alpha=config["optim_alpha"],
        eps=config["optim_eps"])