def test_starmap(self):
    pool = Pool(processes=2)
    x = [(t.ones([10]) * i, t.ones([10]) * i) for i in range(5)]
    assert all(out == expect_out
               for out, expect_out in zip(pool.starmap(func2, x),
                                          [0, 20, 40, 60, 80]))
    pool.close()
    pool.join()

class MADDPG(TorchFramework):
    """
    MADDPG is a centralized multi-agent training framework. It alleviates
    the unstable reward problem caused by the disturbance of other agents
    by gathering the observations of all agents and training a global
    critic. This global critic observes all actions and all states from
    all agents.
    """
    # Since the number of sub-policies is automatically determined,
    # they are not considered here.
    _is_top = ["actor_target", "critic_target"]
    _is_restorable = ["actor_target", "critic_target"]

    def __init__(self,
                 agent_num,
                 actor: Union[NeuralNetworkModule, nn.Module],
                 actor_target: Union[NeuralNetworkModule, nn.Module],
                 critic: Union[NeuralNetworkModule, nn.Module],
                 critic_target: Union[NeuralNetworkModule, nn.Module],
                 optimizer: Callable,
                 criterion: Callable,
                 available_devices: Union[list, None] = None,
                 sub_policy_num: int = 1,
                 lr_scheduler: Callable = None,
                 lr_scheduler_args: Tuple[Tuple, Tuple, Tuple] = (),
                 lr_scheduler_kwargs: Tuple[Dict, Dict, Dict] = (),
                 batch_size: int = 100,
                 update_rate: float = 0.005,
                 learning_rate: float = 0.001,
                 discount: float = 0.99,
                 replay_size: int = 500000,
                 replay_device: Union[str, t.device] = "cpu",
                 replay_buffer: Buffer = None,
                 reward_func: Callable = None,
                 action_concat_func: Callable = None,
                 action_alter_func: Callable = None,
                 state_split_func: Callable = None,
                 visualize: bool = False,
                 pool: Any = None):
        """
        See Also:
            :class:`.DDPG`

        Hint:
            Please refer to:

            - :meth:`MADDPG.action_concat_function`
            - :meth:`MADDPG.action_alter_function`
            - :meth:`MADDPG.state_split_function`

            for the design of ``action_concat_func``, ``action_alter_func``
            and ``state_split_func``, if your actor does not output a simple
            action of shape ``[batch_size, action_dim]`` that can be directly
            accepted by your critic.

        Args:
            agent_num: Number of agents.
            actor: Actor network module.
            actor_target: Target actor network module.
            critic: Critic network module.
            critic_target: Target critic network module.
            optimizer: Optimizer used to optimize ``actor`` and ``critic``.
            criterion: Criterion used to evaluate the value loss.
            available_devices: Devices to distribute the sub-policy and
                sub-critic models on. If ``None`` or empty, no assignment
                is performed.
            sub_policy_num: Number of sub-policies (actor-critic pairs).
            lr_scheduler: Learning rate scheduler of ``optimizer``.
            lr_scheduler_args: Arguments of the learning rate scheduler.
            lr_scheduler_kwargs: Keyword arguments of the learning
                rate scheduler.
            batch_size: Batch size used during training.
            update_rate: :math:`\\tau` used to update target networks.
                Target parameters are updated as:
                :math:`\\theta_t = \\theta * \\tau + \\theta_t * (1 - \\tau)`
            learning_rate: Learning rate of the optimizer, not compatible
                with ``lr_scheduler``.
            discount: :math:`\\gamma` used in the bellman function.
            replay_size: Replay buffer size. Not compatible with
                ``replay_buffer``.
            replay_device: Device where the replay buffer locates on.
                Not compatible with ``replay_buffer``.
            replay_buffer: Custom replay buffer.
            reward_func: Reward function used in training.
            action_concat_func: Action concatenation function.
            action_alter_func: Action alteration function.
            state_split_func: All-states split function.
            visualize: Whether to visualize the network flow in the
                first pass.
            pool: Pool used to parallelize per-sub-policy updates.
                A new :class:`Pool` is created if ``None``.
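        Examples:
            A minimal construction sketch; ``SomeActor``, ``SomeCritic`` and
            ``AGENT_NUM`` are user-supplied placeholders and not part of this
            module::

                maddpg = MADDPG(AGENT_NUM,
                                SomeActor(), SomeActor(),
                                SomeCritic(), SomeCritic(),
                                t.optim.Adam,
                                nn.MSELoss(reduction='sum'))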
""" self.batch_size = batch_size self.update_rate = update_rate self.discount = discount self.visualize = visualize self.agent_num = agent_num self.actors = [copy.deepcopy(actor) for _ in range(sub_policy_num)] self.actor_targets = [copy.deepcopy(actor_target) for _ in range(sub_policy_num)] self.critics = [copy.deepcopy(critic) for _ in range(sub_policy_num)] self.critic_targets = [copy.deepcopy(critic_target) for _ in range(sub_policy_num)] self.actor_optims = [optimizer(ac.parameters(), lr=learning_rate) for ac in self.actors] self.critic_optims = [optimizer(cr.parameters(), lr=learning_rate) for cr in self.critics] self.sub_policy_num = sub_policy_num self.pool = Pool() if pool is None else pool self.replay_buffer = (Buffer(replay_size, replay_device) if replay_buffer is None else replay_buffer) # Distribute models to available devices. if available_devices is not None and available_devices: nets = self.actors + self.actor_targets + \ self.critics + self.critic_targets # Only actors and critics are related connections = {(i, i + self.sub_policy_num * 2): 1 for i in range(self.sub_policy_num)} # We do not need the assigner after construction. _assigner = ModelAssigner(nets, connections, available_devices) # For debugging: (act_assign, act_target_assign, critic_assign, critic_target_assign) = \ np.array_split(_assigner.assignment, 4) default_logger.log("Actors assigned to:") default_logger.log(act_assign) default_logger.log("Actors (target) assigned to:") default_logger.log(act_target_assign) default_logger.log("Critics assigned to:") default_logger.log(critic_assign) default_logger.log("Critics (target) assigned to:") default_logger.log(critic_target_assign) # Create wrapper for target actors and target critics. # So their parameters can be saved. self.actor_target = nn.Module() self.critic_target = nn.Module() for actor_t, idx in zip(self.actor_targets, range(self.sub_policy_num)): self.actor_target.add_module("actor_{}".format(idx), actor_t) for critic_t, idx in zip(self.critic_targets, range(self.sub_policy_num)): self.critic_target.add_module("critic_{}".format(idx), critic_t) # Make sure target and online networks have the same weight with t.no_grad(): self.pool.starmap(hard_update, zip(self.actors, self.actor_targets)) self.pool.starmap(hard_update, zip(self.critics, self.critic_targets)) if lr_scheduler is not None: self.actor_lr_schs = [lr_scheduler(ac_opt, *lr_scheduler_args[0], *lr_scheduler_kwargs[0]) for ac_opt in self.actor_optims] self.critic_lr_schs = [lr_scheduler(cr_opt, *lr_scheduler_args[1], *lr_scheduler_kwargs[1]) for cr_opt in self.critic_optims] self.criterion = criterion self.reward_func = (MADDPG.bellman_function if reward_func is None else reward_func) self.action_alter_func = (MADDPG.action_alter_function if action_alter_func is None else action_alter_func) self.action_concat_func = (MADDPG.action_concat_function if action_concat_func is None else action_concat_func) self.state_split_func = (MADDPG.state_split_function if state_split_func is None else state_split_func) super(MADDPG, self).__init__() def act(self, state: Dict[str, Any], use_target: bool = False, index: int = -1, **__): """ Use actor network to produce an action for the current state. Args: state: Current state. use_target: Whether use the target network. index: The sub-policy index to use. Returns: Action of shape ``[batch_size, action_dim]``. 
""" if index not in range(self.sub_policy_num): index = np.random.randint(0, self.sub_policy_num) if use_target: return safe_call(self.actor_targets[index], state) else: return safe_call(self.actors[index], state) def act_with_noise(self, state: Dict[str, Any], noise_param: Tuple = (0.0, 1.0), ratio: float = 1.0, mode: str = "uniform", use_target: bool = False, index: int = -1, **__): """ Use actor network to produce a noisy action for the current state. See Also: :mod:`machin.frame.noise.action_space_noise` Args: state: Current state. noise_param: Noise params. ratio: Noise ratio. mode: Noise mode. Supported are: ``"uniform", "normal", "clipped_normal", "ou"`` use_target: Whether use the target network. index: The sub-policy index to use. Returns: Noisy action of shape ``[batch_size, action_dim]``. """ if mode == "uniform": return add_uniform_noise_to_action( self.act(state, use_target, index), noise_param, ratio ) if mode == "normal": return add_normal_noise_to_action( self.act(state, use_target, index), noise_param, ratio ) if mode == "clipped_normal": return add_clipped_normal_noise_to_action( self.act(state, use_target, index), noise_param, ratio ) if mode == "ou": return add_ou_noise_to_action( self.act(state, use_target, index), noise_param, ratio ) raise RuntimeError("Unknown noise type: " + str(mode)) def act_discreet(self, state: Dict[str, Any], use_target: bool = False, index: int = -1): """ Use actor network to produce a discreet action for the current state. Notes: actor network must output a probability tensor, of shape (batch_size, action_dims), and has a sum of 1 for each row in dimension 1. Args: state: Current state. use_target: Whether to use the target network. index: The sub-policy index to use. Returns: Action of shape ``[batch_size, 1]``. """ if index not in range(self.sub_policy_num): index = np.random.randint(0, self.sub_policy_num) if use_target: result = safe_call(self.actor_targets[index], state) else: result = safe_call(self.actors[index], state) assert_output_is_probs(result) batch_size = result.shape[0] result = t.argmax(result, dim=1).view(batch_size, 1) return result def act_discreet_with_noise(self, state: Dict[str, Any], use_target: bool = False, index=-1): """ Use actor network to produce a noisy discreet action for the current state. Notes: actor network must output a probability tensor, of shape (batch_size, action_dims), and has a sum of 1 for each row in dimension 1. Args: state: Current state. use_target: Whether to use the target network. index: The sub-policy index to use. Returns: Noisy action of shape ``[batch_size, 1]``. """ if index not in range(self.sub_policy_num): index = np.random.randint(0, self.sub_policy_num) if use_target: result = safe_call(self.actor_targets[index], state) else: result = safe_call(self.actors[index], state) assert_output_is_probs(result) dist = Categorical(result) batch_size = result.shape[0] return dist.sample([batch_size, 1]) def criticize(self, all_states, all_actions, use_target=False, index=-1): """ Use critic network to evaluate current value. Args: all_states: Current states of all actors. all_actions: Current actions of all actors. use_target: Whether to use the target network. index: The sub-critic index to use. Returns: Value of shape ``[batch_size, 1]``. 
""" if index not in range(self.sub_policy_num): index = np.random.randint(0, self.sub_policy_num) if use_target: return safe_call(self.critic_targets[index], all_states, all_actions) else: return safe_call(self.critics[index], all_states, all_actions) def store_transition(self, transition: Union[MultiAgentTransition, Dict]): """ Add a transition sample to the replay buffer. """ self.replay_buffer.append(transition, required_attrs=( "state", "all_states", "all_actions", "all_next_states", "reward", "terminal", "index" )) def store_episode(self, episode: List[Union[MultiAgentTransition, Dict]]): """ Add a full episode of transition samples to the replay buffer. """ for trans in episode: self.replay_buffer.append(trans, required_attrs=( "state", "all_states", "all_actions", "all_next_states", "reward", "terminal", "index" )) def update(self, update_value=True, update_policy=True, update_target=True, concatenate_samples=True, average_target_parameter=False): """ Update network weights by sampling from replay buffer. Args: update_value: Whether to update the Q network. update_policy: Whether to update the actor network. update_target: Whether to update targets. concatenate_samples: Whether to concatenate the samples. average_target_parameter: Whether to average sub target networks, including actors and critics. Returns: mean value of estimated policy value, value loss """ batch_size, (state, all_states, all_actions, all_next_states, reward, terminal, agent_indexes, *others) = \ self.replay_buffer.sample_batch(self.batch_size, concatenate_samples, sample_attrs=[ "state", "all_states", "all_actions", "all_next_states", "reward", "terminal", "index", "*" ]) def update_inner(i): with t.no_grad(): # Produce all_next_actions for all_next_states, using target i, # so the target critic can evaluate the value of the next step. all_next_actions_t = \ self.action_concat_func( self.act( self.state_split_func( all_next_states, batch_size, self.agent_num, others ), True, i ), batch_size, self.agent_num, others ) # Update critic network first # Generate target value using target critic. 
            with t.no_grad():
                next_value = self.criticize(all_next_states,
                                            all_next_actions_t,
                                            True, i)
                next_value = next_value.view(batch_size, -1)
                y_i = self.reward_func(reward, self.discount,
                                       next_value, terminal, others)

            # all_actions contains the actions of all agents;
            # the same holds for all_states.
            cur_value = self.criticize(all_states, all_actions, index=i)
            value_loss = self.criterion(cur_value, y_i.to(cur_value.device))

            if update_value:
                self.critics[i].zero_grad()
                value_loss.backward()
                self.critic_optims[i].step()

            # Update the actor network.
            cur_all_actions = copy.deepcopy(all_actions)
            cur_all_actions = self.action_alter_func(
                self.act(state, index=i), cur_all_actions,
                agent_indexes, batch_size, self.agent_num, others
            )
            act_value = self.criticize(state, cur_all_actions, index=i)

            # "-" is applied because we want to maximize J_b(u),
            # but the optimizer works by minimizing the target.
            act_policy_loss = -act_value.mean()

            if update_policy:
                self.actors[i].zero_grad()
                act_policy_loss.backward()
                self.actor_optims[i].step()

            # Update target networks.
            if update_target:
                soft_update(self.actor_targets[i],
                            self.actors[i],
                            self.update_rate)
                soft_update(self.critic_targets[i],
                            self.critics[i],
                            self.update_rate)

            return act_policy_loss.item(), value_loss.item()

        all_loss = self.pool.map(update_inner, range(self.sub_policy_num))
        mean_loss = t.tensor(all_loss).mean(dim=0)

        if average_target_parameter:
            self.average_target_parameters()

        # Return the mean estimated policy value and the mean value loss.
        return -mean_loss[0].item(), mean_loss[1].item()

    def update_lr_scheduler(self):
        """
        Update learning rate schedulers.
        """
        if hasattr(self, "actor_lr_schs"):
            for actor_lr_sch in self.actor_lr_schs:
                actor_lr_sch.step()
        if hasattr(self, "critic_lr_schs"):
            for critic_lr_sch in self.critic_lr_schs:
                critic_lr_sch.step()

    def load(self, model_dir, network_map=None, version=-1):
        # DOC INHERITED
        super(MADDPG, self).load(model_dir, network_map, version)
        with t.no_grad():
            self.pool.starmap(hard_update,
                              zip(self.actors, self.actor_targets))
            self.pool.starmap(hard_update,
                              zip(self.critics, self.critic_targets))

    def average_target_parameters(self):
        """
        Average parameters of sub-policies and sub-critics.
        Averaging is performed on target networks.
        """
        with t.no_grad():
            actor_params = [net.parameters() for net in self.actor_targets]
            critic_params = [net.parameters() for net in self.critic_targets]
            self.pool.starmap(
                _average_parameters,
                itertools.chain(zip(*actor_params), zip(*critic_params))
            )

    @staticmethod
    def action_alter_function(raw_output_action, all_actions, indexes,
                              batch_size, agent_num, *_):
        """
        This function is used to replace an agent's action inside all
        actions, using the output from the online actor network.

        Args:
            raw_output_action: Raw output of the actor.
            all_actions: All actions of all agents.
            indexes: Agent indexes among all agents.
            batch_size: Sampled batch size.
            agent_num: Number of agents.

        Returns:
            Altered all actions.
        """
        all_actions["action"][indexes] = \
            raw_output_action.view(batch_size, agent_num, -1)
        return all_actions

    @staticmethod
    def state_split_function(all_states: Dict[str, t.Tensor],
                             batch_size: int,
                             agent_num: int,
                             *_):
        """
        This function is used to split states from multiple agents into
        batched single states, usable by the actor network.

        Args:
            all_states: All states of all agents.
            batch_size: Sampled batch size.
            agent_num: Number of agents.

        Returns:
            Split states.
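        Example:
            For illustration only, assuming each agent observes a state of
            dimension 5, the default implementation reshapes::

                all_states["state"]   # [batch_size, agent_num, 5]

            into::

                all_states["state"]   # [batch_size * agent_num, 5]

            so the actor can process the states of all agents as one batch.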
""" all_states["state"] = all_states["state"]\ .view(batch_size * agent_num, -1) return all_states @staticmethod def action_concat_function(raw_output_action, batch_size, agent_num, *_): """ This function is used to transform the actions produced by actor from the splitted states, to the final output of all actions of all agents. Args: raw_output_action: Raw output of actor. batch_size: Sampled batch size. agent_num: Number of agents. Returns: Concatenated actions. """ return {"action": raw_output_action.view(batch_size, agent_num)} @staticmethod def bellman_function(reward, discount, next_value, terminal, *_): next_value = next_value.to(reward.device) terminal = terminal.to(reward.device) return reward + discount * (1 - terminal) * next_value