def __init__(self, actors: List[Union[NeuralNetworkModule, nn.Module]], actor_targets: List[Union[NeuralNetworkModule, nn.Module]], critics: List[Union[NeuralNetworkModule, nn.Module]], critic_targets: List[Union[NeuralNetworkModule, nn.Module]], critic_visible_actors: List[List[int]], optimizer: Callable, criterion: Callable, *_, sub_policy_num: int = 0, lr_scheduler: Callable = None, lr_scheduler_args: Tuple[Tuple, Tuple] = None, lr_scheduler_kwargs: Tuple[Dict, Dict] = None, batch_size: int = 100, update_rate: float = 0.001, actor_learning_rate: float = 0.0005, critic_learning_rate: float = 0.001, discount: float = 0.99, gradient_max: float = np.inf, replay_size: int = 500000, replay_device: Union[str, t.device] = "cpu", replay_buffer: Buffer = None, visualize: bool = False, visualize_dir: str = "", use_jit: bool = True, pool_type: str = "thread", pool_size: int = None): """ See Also: :class:`.DDPG` Note: In order to parallelize agent inference, a process pool is used internally. However, in order to minimize memory copy / CUDA memory copy, the location of all of your models must be either "cpu", or "cuda" (Using multiple CUDA devices is supported). Note: MADDPG framework **does not require** all of your actors are homogeneous. Each pair of your actors and critcs could be heterogeneous. Note: Suppose you have three pair of actors and critics, with index 0, 1, 2. If critic 0 can observe the action of actor 0 and 1, critic 1 can observe the action of actor 1 and 2, critic 2 can observe the action of actor 2 and 0, the ``critic_visible_actors`` should be:: [[0, 1], [1, 2], [2, 0]] Note: This implementation contains: - Ensemble Training This implementation does not contain: - Inferring other agents' policies - Mixed continuous/discrete action spaces Args: actors: Actor network modules. actor_targets: Target actor network modules. critics: Critic network modules. critic_targets: Target critic network modules. critic_visible_actors: Indexes of visible actors for each critic. optimizer: Optimizer used to optimize ``actors`` and ``critics``. criterion: Criterion used to evaluate the value loss. sub_policy_num: Times to replicate each actor. Equals to `ensemble_policy_num - 1` lr_scheduler: Learning rate scheduler of ``optimizer``. lr_scheduler_args: Arguments of the learning rate scheduler. lr_scheduler_kwargs: Keyword arguments of the learning rate scheduler. batch_size: Batch size used during training. update_rate: :math:`\\tau` used to update target networks. Target parameters are updated as: :math:`\\theta_t = \\theta * \\tau + \\theta_t * (1 - \\tau)` actor_learning_rate: Learning rate of the actor optimizer, not compatible with ``lr_scheduler``. critic_learning_rate: Learning rate of the critic optimizer, not compatible with ``lr_scheduler``. discount: :math:`\\gamma` used in the bellman function. replay_size: Replay buffer size for each actor. Not compatible with ``replay_buffer``. replay_device: Device where the replay buffer locates on, Not compatible with ``replay_buffer``. replay_buffer: Custom replay buffer. Will be replicated for actor. visualize: Whether visualize the network flow in the first pass. visualize_dir: Visualized graph save directory. use_jit: Whether use torch jit to perform the forward pass in parallel instead of using the internal pool. Provides significant speed and efficiency advantage, but requires actors and critics convertible to TorchScript. pool_type: Type of the internal execution pool, either "process" or "thread". pool_size: Size of the internal execution pool. """ assert pool_type in ("process", "thread") self.batch_size = batch_size self.update_rate = update_rate self.discount = discount self.has_visualized = False self.visualize = visualize self.visualize_dir = visualize_dir self.grad_max = gradient_max self.critic_visible_actors = critic_visible_actors # create ensembles of policies self.actors = [[actor] + [copy.deepcopy(actor) for _ in range(sub_policy_num)] for actor in actors] self.actor_targets = [ [actor_target] + [copy.deepcopy(actor_target) for _ in range(sub_policy_num)] for actor_target in actor_targets ] self.critics = critics self.critic_targets = critic_targets self.actor_optims = [[ optimizer(acc.parameters(), lr=actor_learning_rate) for acc in ac ] for ac in self.actors] self.critic_optims = [ optimizer(cr.parameters(), lr=critic_learning_rate) for cr in self.critics ] self.ensemble_size = sub_policy_num + 1 self.replay_buffers = [ SHMBuffer(replay_size, replay_device) if replay_buffer is None else copy.deepcopy(replay_buffer) for _ in range(len(actors)) ] # create the pool used to update() # check devices of all parameters, # determine the pool process starting method. device = self._check_parameters_device( itertools.chain(*self.actors, self.critics)) self.device = device self.pool_type = pool_type if pool_type == "process": self.pool = P2PPool(processes=pool_size, is_recursive=False, is_copy_tensor=False, share_method=device) elif pool_type == "thread": self.pool = ThreadPool(processes=pool_size) # Create wrapper for target actors and target critics. # So their parameters can be saved. self.all_actor_target = nn.Module() self.all_critic_target = nn.Module() for ac, idx in zip(self.actor_targets, range(len(actors))): for acc, idxx in zip(ac, range(self.ensemble_size)): acc.share_memory() self.all_actor_target.add_module( "actor_{}_{}".format(idx, idxx), acc) for cr, idx in zip(self.critic_targets, range(len(critics))): cr.share_memory() self.all_critic_target.add_module("critic_{}".format(idx), cr) # Make sure target and online networks have the same weight with t.no_grad(): self.pool.starmap( hard_update, zip(itertools.chain(*self.actors), itertools.chain(*self.actor_targets))) self.pool.starmap(hard_update, zip(self.critics, self.critic_targets)) if lr_scheduler is not None: if lr_scheduler_args is None: lr_scheduler_args = ((), ()) if lr_scheduler_kwargs is None: lr_scheduler_kwargs = ({}, {}) self.actor_lr_schs = [ lr_scheduler(ac_opt, *lr_scheduler_args[0], *lr_scheduler_kwargs[0]) for acc_opt in self.actor_optims for ac_opt in acc_opt ] self.critic_lr_schs = [ lr_scheduler(cr_opt, *lr_scheduler_args[1], *lr_scheduler_kwargs[1]) for cr_opt in self.critic_optims ] self.criterion = criterion # make preparations if use jit # jit modules will share the same parameter memory with original # modules, therefore it is safe to use them together. self.use_jit = use_jit self.jit_actors = [] self.jit_actor_targets = [] if use_jit: # only compile actors, since critics will not be # launched in parallel for ac in self.actors: jit_actors = [] jit_actor_targets = [] for acc in ac: # exclude "self" by truncating element 0 actor_arg_spec = inspect.getfullargspec(acc.forward) jit_actor = t.jit.script(acc) jit_actor.arg_spec = actor_arg_spec jit_actor.model_type = type(acc) jit_actors.append(jit_actor) jit_actor_target = t.jit.script(acc) jit_actor_target.arg_spec = actor_arg_spec jit_actor_target.model_type = type(acc) jit_actor_targets.append(jit_actor_target) self.jit_actors.append(jit_actors) self.jit_actor_targets.append(jit_actor_targets) super(MADDPG, self).__init__()
class MADDPG(TorchFramework): """ MADDPG is a centralized multi-agent training framework, it alleviates the unstable reward problem caused by the disturbance of other agents by gathering all agents observations and train a global critic. This global critic observes all actions and all states from all agents. """ # Since the number of sub-policies is automatically determined, # they are not considered here. _is_top = ["all_actor_target", "all_critic_target"] _is_restorable = ["all_actor_target", "all_critic_target"] def __init__(self, actors: List[Union[NeuralNetworkModule, nn.Module]], actor_targets: List[Union[NeuralNetworkModule, nn.Module]], critics: List[Union[NeuralNetworkModule, nn.Module]], critic_targets: List[Union[NeuralNetworkModule, nn.Module]], critic_visible_actors: List[List[int]], optimizer: Callable, criterion: Callable, *_, sub_policy_num: int = 0, lr_scheduler: Callable = None, lr_scheduler_args: Tuple[Tuple, Tuple] = None, lr_scheduler_kwargs: Tuple[Dict, Dict] = None, batch_size: int = 100, update_rate: float = 0.001, actor_learning_rate: float = 0.0005, critic_learning_rate: float = 0.001, discount: float = 0.99, gradient_max: float = np.inf, replay_size: int = 500000, replay_device: Union[str, t.device] = "cpu", replay_buffer: Buffer = None, visualize: bool = False, visualize_dir: str = "", use_jit: bool = True, pool_type: str = "thread", pool_size: int = None): """ See Also: :class:`.DDPG` Note: In order to parallelize agent inference, a process pool is used internally. However, in order to minimize memory copy / CUDA memory copy, the location of all of your models must be either "cpu", or "cuda" (Using multiple CUDA devices is supported). Note: MADDPG framework **does not require** all of your actors are homogeneous. Each pair of your actors and critcs could be heterogeneous. Note: Suppose you have three pair of actors and critics, with index 0, 1, 2. If critic 0 can observe the action of actor 0 and 1, critic 1 can observe the action of actor 1 and 2, critic 2 can observe the action of actor 2 and 0, the ``critic_visible_actors`` should be:: [[0, 1], [1, 2], [2, 0]] Note: This implementation contains: - Ensemble Training This implementation does not contain: - Inferring other agents' policies - Mixed continuous/discrete action spaces Args: actors: Actor network modules. actor_targets: Target actor network modules. critics: Critic network modules. critic_targets: Target critic network modules. critic_visible_actors: Indexes of visible actors for each critic. optimizer: Optimizer used to optimize ``actors`` and ``critics``. criterion: Criterion used to evaluate the value loss. sub_policy_num: Times to replicate each actor. Equals to `ensemble_policy_num - 1` lr_scheduler: Learning rate scheduler of ``optimizer``. lr_scheduler_args: Arguments of the learning rate scheduler. lr_scheduler_kwargs: Keyword arguments of the learning rate scheduler. batch_size: Batch size used during training. update_rate: :math:`\\tau` used to update target networks. Target parameters are updated as: :math:`\\theta_t = \\theta * \\tau + \\theta_t * (1 - \\tau)` actor_learning_rate: Learning rate of the actor optimizer, not compatible with ``lr_scheduler``. critic_learning_rate: Learning rate of the critic optimizer, not compatible with ``lr_scheduler``. discount: :math:`\\gamma` used in the bellman function. replay_size: Replay buffer size for each actor. Not compatible with ``replay_buffer``. replay_device: Device where the replay buffer locates on, Not compatible with ``replay_buffer``. replay_buffer: Custom replay buffer. Will be replicated for actor. visualize: Whether visualize the network flow in the first pass. visualize_dir: Visualized graph save directory. use_jit: Whether use torch jit to perform the forward pass in parallel instead of using the internal pool. Provides significant speed and efficiency advantage, but requires actors and critics convertible to TorchScript. pool_type: Type of the internal execution pool, either "process" or "thread". pool_size: Size of the internal execution pool. """ assert pool_type in ("process", "thread") self.batch_size = batch_size self.update_rate = update_rate self.discount = discount self.has_visualized = False self.visualize = visualize self.visualize_dir = visualize_dir self.grad_max = gradient_max self.critic_visible_actors = critic_visible_actors # create ensembles of policies self.actors = [[actor] + [copy.deepcopy(actor) for _ in range(sub_policy_num)] for actor in actors] self.actor_targets = [ [actor_target] + [copy.deepcopy(actor_target) for _ in range(sub_policy_num)] for actor_target in actor_targets ] self.critics = critics self.critic_targets = critic_targets self.actor_optims = [[ optimizer(acc.parameters(), lr=actor_learning_rate) for acc in ac ] for ac in self.actors] self.critic_optims = [ optimizer(cr.parameters(), lr=critic_learning_rate) for cr in self.critics ] self.ensemble_size = sub_policy_num + 1 self.replay_buffers = [ SHMBuffer(replay_size, replay_device) if replay_buffer is None else copy.deepcopy(replay_buffer) for _ in range(len(actors)) ] # create the pool used to update() # check devices of all parameters, # determine the pool process starting method. device = self._check_parameters_device( itertools.chain(*self.actors, self.critics)) self.device = device self.pool_type = pool_type if pool_type == "process": self.pool = P2PPool(processes=pool_size, is_recursive=False, is_copy_tensor=False, share_method=device) elif pool_type == "thread": self.pool = ThreadPool(processes=pool_size) # Create wrapper for target actors and target critics. # So their parameters can be saved. self.all_actor_target = nn.Module() self.all_critic_target = nn.Module() for ac, idx in zip(self.actor_targets, range(len(actors))): for acc, idxx in zip(ac, range(self.ensemble_size)): acc.share_memory() self.all_actor_target.add_module( "actor_{}_{}".format(idx, idxx), acc) for cr, idx in zip(self.critic_targets, range(len(critics))): cr.share_memory() self.all_critic_target.add_module("critic_{}".format(idx), cr) # Make sure target and online networks have the same weight with t.no_grad(): self.pool.starmap( hard_update, zip(itertools.chain(*self.actors), itertools.chain(*self.actor_targets))) self.pool.starmap(hard_update, zip(self.critics, self.critic_targets)) if lr_scheduler is not None: if lr_scheduler_args is None: lr_scheduler_args = ((), ()) if lr_scheduler_kwargs is None: lr_scheduler_kwargs = ({}, {}) self.actor_lr_schs = [ lr_scheduler(ac_opt, *lr_scheduler_args[0], *lr_scheduler_kwargs[0]) for acc_opt in self.actor_optims for ac_opt in acc_opt ] self.critic_lr_schs = [ lr_scheduler(cr_opt, *lr_scheduler_args[1], *lr_scheduler_kwargs[1]) for cr_opt in self.critic_optims ] self.criterion = criterion # make preparations if use jit # jit modules will share the same parameter memory with original # modules, therefore it is safe to use them together. self.use_jit = use_jit self.jit_actors = [] self.jit_actor_targets = [] if use_jit: # only compile actors, since critics will not be # launched in parallel for ac in self.actors: jit_actors = [] jit_actor_targets = [] for acc in ac: # exclude "self" by truncating element 0 actor_arg_spec = inspect.getfullargspec(acc.forward) jit_actor = t.jit.script(acc) jit_actor.arg_spec = actor_arg_spec jit_actor.model_type = type(acc) jit_actors.append(jit_actor) jit_actor_target = t.jit.script(acc) jit_actor_target.arg_spec = actor_arg_spec jit_actor_target.model_type = type(acc) jit_actor_targets.append(jit_actor_target) self.jit_actors.append(jit_actors) self.jit_actor_targets.append(jit_actor_targets) super(MADDPG, self).__init__() def act(self, states: List[Dict[str, Any]], use_target: bool = False, **__): """ Use all actor networks to produce actions for the current state. A random sub-policy from the policy ensemble of each actor will be chosen. Args: states: A list of current states of each actor. use_target: Whether use the target network. Returns: A list of anything returned by your actor. If your actor returns multiple values, they will be wrapped in a tuple. """ return [ safe_return(act) for act in self._act_api_general(states, use_target) ] def act_with_noise(self, states: List[Dict[str, Any]], noise_param: Any = (0.0, 1.0), ratio: float = 1.0, mode: str = "uniform", use_target: bool = False, **__): """ Use all actor networks to produce noisy actions for the current state. A random sub-policy from the policy ensemble of each actor will be chosen. See Also: :mod:`machin.frame.noise.action_space_noise` Args: states: A list of current states of each actor. noise_param: Noise params. ratio: Noise ratio. mode: Noise mode. Supported are: ``"uniform", "normal", "clipped_normal", "ou"`` use_target: Whether use the target network. Returns: A list of noisy actions of shape ``[batch_size, action_dim]``. """ actions = self._act_api_general(states, use_target) result = [] for action, *others in actions: if mode == "uniform": action = add_uniform_noise_to_action(action, noise_param, ratio) elif mode == "normal": action = add_normal_noise_to_action(action, noise_param, ratio) elif mode == "clipped_normal": action = add_clipped_normal_noise_to_action( action, noise_param, ratio) elif mode == "ou": action = add_ou_noise_to_action(action, noise_param, ratio) else: raise ValueError("Unknown noise type: " + str(mode)) if len(others) == 0: result.append(action) else: result.append((action, *others)) return result def act_discrete(self, states: List[Dict[str, Any]], use_target: bool = False): """ Use all actor networks to produce discrete actions for the current state. A random sub-policy from the policy ensemble of each actor will be chosen. Notes: actor network must output a probability tensor, of shape (batch_size, action_dims), and has a sum of 1 for each row in dimension 1. Args: states: A list of current states of each actor. use_target: Whether use the target network. Returns: A list of tuples containing: 1. Integer discrete actions of shape ``[batch_size, 1]``. 2. Action probability tensors of shape ``[batch_size, action_num]``. 3. Any other things returned by your actor. """ actions = self._act_api_general(states, use_target) result = [] for action, *others in actions: assert_output_is_probs(action) batch_size = action.shape[0] action_disc = t.argmax(action, dim=1).view(batch_size, 1) result.append((action_disc, action, *others)) return result def act_discrete_with_noise(self, states: List[Dict[str, Any]], use_target: bool = False): """ Use all actor networks to produce discrete actions for the current state. A random sub-policy from the policy ensemble of each actor will be chosen. Notes: actor network must output a probability tensor, of shape (batch_size, action_dims), and has a sum of 1 for each row in dimension 1. Args: states: A list of current states of each actor. use_target: Whether use the target network. Returns: A list of tuples containing: 1. Integer noisy discrete actions. 2. Action probability tensors of shape ``[batch_size, action_num]``. 3. Any other things returned by your actor. """ actions = self._act_api_general(states, use_target) result = [] for action, *others in actions: assert_output_is_probs(action) batch_size = action.shape[0] dist = Categorical(action) action_disc = dist.sample([batch_size, 1]).view(batch_size, 1) result.append((action_disc, action, *others)) return result def _act_api_general(self, states, use_target): if self.use_jit: if use_target: actors = [ choice(sub_actors) for sub_actors in self.jit_actor_targets ] else: actors = [choice(sub_actors) for sub_actors in self.jit_actors] future = [ self._jit_safe_call(ac, st) for ac, st in zip(actors, states) ] result = [t.jit._wait(fut) for fut in future] result = [ res if isinstance(res, tuple) else (res, ) for res in result ] else: if use_target: actors = [ choice(sub_actors) for sub_actors in self.actor_targets ] else: actors = [choice(sub_actors) for sub_actors in self.actors] result = self.pool.starmap(self._no_grad_safe_call, zip(actors, states)) result = [res for res in result] return result def _criticize(self, states: List[Dict[str, Any]], actions: List[Dict[str, Any]], index: int, use_target=False): """ Use critic network to evaluate current value. Args: states: Current states of all actors. actions: Current actions of all actors. use_target: Whether to use the target network. index: Index of the used critic. Returns: Q Value of shape ``[batch_size, 1]``. """ if use_target: return safe_call(self.critic_targets[index], self.state_concat_function(states), self.action_concat_function(actions)) else: return safe_call(self.critics[index], self.state_concat_function(states), self.action_concat_function(actions)) def store_transitions(self, transitions: List[Union[Transition, Dict]]): """ Add a list of transition samples, from all actors at the same time step, to the replay buffers. Args: transitions: List of transition objects. """ assert len(transitions) == len(self.replay_buffers) for buff, trans in zip(self.replay_buffers, transitions): buff.append(trans, required_attrs=("state", "action", "next_state", "reward", "terminal")) def store_episodes(self, episodes: List[List[Union[Transition, Dict]]]): """ Add a List of full episodes, from all actors, to the replay buffers. Each episode is a list of transition samples. """ assert len(episodes) == len(self.replay_buffers) all_length = [len(ep) for ep in episodes] assert len(set(all_length)) == 1, \ "All episodes must have the same length!" for buff, ep in zip(self.replay_buffers, episodes): for trans in ep: buff.append(trans, required_attrs=("state", "action", "next_state", "reward", "terminal")) def update(self, update_value=True, update_policy=True, update_target=True, concatenate_samples=True): """ Update network weights by sampling from replay buffer. Args: update_value: Whether to update the Q network. update_policy: Whether to update the actor network. update_target: Whether to update targets. concatenate_samples: Whether to concatenate the samples. Returns: mean value of estimated policy value, value loss """ # All buffers should have the same length now. # Create a sample method per update # this sample method will sample the same indexes # (different for each update() call) on all buffers. buffer_length = self.replay_buffers[0].size() if buffer_length == 0: return batch_size = min(buffer_length, self.batch_size) sample_indexes = [[ randint(0, buffer_length - 1) for _ in range(batch_size) ] for __ in range(self.ensemble_size)] sample_methods = [ self._create_sample_method(indexes) for indexes in sample_indexes ] # Now sample from buffer for each sub-policy in the ensemble. # To reduce memory usage, for each sub-policy "i" of each actor, # the same sample "i" will be used for training. # Tensors in the sampled batch will be moved to shared memory. # size: [ensemble size, num of actors] batches = [] next_actions_t = [] for e_idx in range(self.ensemble_size): ensemble_batch = [] for a_idx in range(len(self.actors)): batch_size_, batch = \ self.replay_buffers[a_idx].sample_batch( self.batch_size, concatenate_samples, sample_method=sample_methods[e_idx], sample_attrs=[ "state", "action", "reward", "next_state", "terminal", "*"] ) ensemble_batch.append(batch) assert batch_size_ == batch_size batches.append(ensemble_batch) next_actions_t.append([ self.action_transform_function(act) for act in self.act([batch[3] for batch in ensemble_batch], target=True) ]) if self.pool_type == "process": batches = self._move_to_shared_mem(batches) next_actions_t = self._move_to_shared_mem(next_actions_t) args = [] for e_idx in range(self.ensemble_size): for a_idx in range(len(self.actors)): args.append( (batch_size, batches, next_actions_t, a_idx, e_idx, self.actors, self.actor_targets, self.critics, self.critic_targets, self.critic_visible_actors, self.actor_optims, self.critic_optims, update_value, update_policy, update_target, self.action_transform_function, self.action_concat_function, self.state_concat_function, self.reward_function, self.criterion, self.discount, self.update_rate, self.grad_max, self.visualize and not self.has_visualized, self.visualize_dir)) all_loss = self.pool.starmap(self._update_sub_policy, args) mean_loss = t.tensor(all_loss).mean(dim=0) # returns action value and policy loss return -mean_loss[0].item(), mean_loss[1].item() def update_lr_scheduler(self): """ Update learning rate schedulers. """ if hasattr(self, "actor_lr_schs"): for actor_lr_sch in self.actor_lr_schs: actor_lr_sch.step() if hasattr(self, "critic_lr_schs"): for critic_lr_sch in self.critic_lr_schs: critic_lr_sch.step() def load(self, model_dir, network_map=None, version=-1): # DOC INHERITED super(MADDPG, self).load(model_dir, network_map, version) with t.no_grad(): self.pool.starmap( hard_update, zip(itertools.chain(*self.actors), itertools.chain(*self.actor_targets))) self.pool.starmap(hard_update, zip(self.critics, self.critic_targets)) @staticmethod def _no_grad_safe_call(model, *named_args): with t.no_grad(): result = safe_call(model, *named_args) return result @staticmethod def _jit_safe_call(model, *named_args): if (not hasattr(model, "input_device") or not hasattr(model, "output_device")): # try to automatically determine the input & output # device of the model model_type = type(model) device = determine_device(model) if len(device) > 1: raise RuntimeError( "Failed to automatically determine i/o device " "of your model: {}\n" "Detected multiple devices: {}\n" "You need to manually specify i/o device of " "your model.\n" "Wrap your model of type nn.Module with one " "of: \n" "1. static_module_wrapper " "from machin.model.nets.base \n" "1. dynamic_module_wrapper " "from machin.model.nets.base \n" "Or construct your own module & model with: \n" "NeuralNetworkModule from machin.model.nets.base".format( model_type, device)) else: # assume that i/o devices are the same as parameter device # print a warning default_logger.warning( "You have not specified the i/o device of" "your model {}, automatically determined and" " set to: {}\n" "The framework is not responsible for any " "un-matching device issues caused by this" "operation.".format(model_type, device[0])) model = static_module_wrapper(model, device[0], device[0]) input_device = model.input_device # set in __init__ args = model.arg_spec.args[1:] + model.arg_spec.kwonlyargs if model.arg_spec.defaults is not None: args_with_defaults = args[-len(model.arg_spec.defaults):] else: args_with_defaults = [] required_args = (set(args) - set(args_with_defaults) - set(model.arg_spec.kwonlydefaults.keys() if model. arg_spec.kwonlydefaults is not None else [])) model_type = model.model_type # t.jit._fork does not support keyword args # fill arguments in by their positions. args_list = [None for _ in args] args_filled = [False for _ in args] for na in named_args: for k, v in na.items(): if k in args: if k not in args: pass args_filled[args.index(k)] = True if t.is_tensor(v): args_list[args.index(k)] = v.to(input_device) else: args_list[args.index(k)] = v if not all(args_filled): not_filled = [ arg for filled, arg in zip(args_filled, args) if not filled ] required_not_filled = set(not_filled).intersection(required_args) if len(required_not_filled) > 0: raise RuntimeError("\n" "The signature of the forward function " "of Model {} is {}\n" "Missing required arguments: {}, " "check your storage functions.".format( model_type, required_args, required_not_filled)) return t.jit._fork(model, *args_list) @staticmethod def _update_sub_policy(batch_size, batches, next_actions_t, actor_index, policy_index, actors, actor_targets, critics, critic_targets, critic_visible_actors, actor_optims, critic_optims, update_value, update_policy, update_target, atf, acf, scf, rf, criterion, discount, update_rate, grad_max, visualize, visualize_dir): # atf: action transform function, used to transform the # raw output of a single actor to a arg dict like: # {"action": tensor}, where "action" is the keyword argument # name of the critic. # # acf: action concatenation function, used to concatenate # a list of action dicts into a single arg dict readable # by critic. # scf: state concatenation function, used to concatenate # a list of state dicts into a single arg dict readable # by critic. # rf: reward function # The innermost element of ``batches``: # (state, action, reward, next_state, terminal, *) # ``batches`` size: [ensemble_size, actor_num] # select the batch for this sub-policy in the ensemble ensemble_batch = batches[policy_index] ensemble_n_act_t = next_actions_t[policy_index] visible_actors = critic_visible_actors[actor_index] actors[actor_index][policy_index].train() critics[actor_index].train() with t.no_grad(): # only select visible actors all_next_actions_t = [ ensemble_n_act_t[a_idx] if a_idx != actor_index else atf( safe_call(actor_targets[actor_index][policy_index], ensemble_batch[a_idx][3])[0], ensemble_batch[a_idx][5]) for a_idx in visible_actors ] all_next_actions_t = acf(all_next_actions_t) all_actions = [ ensemble_batch[a_idx][1] for a_idx in visible_actors ] all_actions = acf(all_actions) all_next_states = [ ensemble_batch[a_idx][3] for a_idx in visible_actors ] all_next_states = scf(all_next_states) all_states = [ensemble_batch[a_idx][0] for a_idx in visible_actors] all_states = scf(all_states) # Update critic network first # Generate target value using target critic. with t.no_grad(): reward = ensemble_batch[actor_index][2] terminal = ensemble_batch[actor_index][4] next_value = safe_call(critic_targets[actor_index], all_next_states, all_next_actions_t)[0] next_value = next_value.view(batch_size, -1) y_i = rf(reward, discount, next_value, terminal, ensemble_batch[actor_index][5]) cur_value = safe_call(critics[actor_index], all_states, all_actions)[0] value_loss = criterion(cur_value, y_i.to(cur_value.device)) if visualize: # only invoked if not running by pool MADDPG._visualize(value_loss, "critic_{}".format(actor_index), visualize_dir) if update_value: critics[actor_index].zero_grad() value_loss.backward() nn.utils.clip_grad_norm_(critics[actor_index].parameters(), grad_max) critic_optims[actor_index].step() # Update actor network all_actions = [ensemble_batch[a_idx][1] for a_idx in visible_actors] # find the actor index in the view range of critic # Eg: there are 4 actors in total: a_0, a_1, a_2, a_3 # critic may have access to actor a_1 and a_2 # then: # visible_actors.index(a_1) = 0 # visible_actors.index(a_2) = 1 # visible_actors.index returns the (critic-)local position of actor # in the view range of its corresponding critic. all_actions[visible_actors.index(actor_index)] = atf( safe_call(actors[actor_index][policy_index], ensemble_batch[actor_index][3])[0], ensemble_batch[actor_index][5]) all_actions = acf(all_actions) act_value = safe_call(critics[actor_index], all_states, all_actions)[0] # "-" is applied because we want to maximize J_b(u), # but optimizer workers by minimizing the target act_policy_loss = -act_value.mean() if visualize: # only invoked if not running by pool MADDPG._visualize(act_policy_loss, "actor_{}_{}".format(actor_index, policy_index), visualize_dir) if update_policy: actors[actor_index][policy_index].zero_grad() act_policy_loss.backward() nn.utils.clip_grad_norm_( actors[actor_index][policy_index].parameters(), grad_max) actor_optims[actor_index][policy_index].step() # Update target networks if update_target: soft_update(actor_targets[actor_index][policy_index], actors[actor_index][policy_index], update_rate) soft_update(critic_targets[actor_index], critics[actor_index], update_rate) actors[actor_index][policy_index].eval() critics[actor_index].eval() return -act_policy_loss.item(), value_loss.item() @staticmethod def _visualize(final_tensor, name, directory): g = make_dot(final_tensor) g.render(filename=name, directory=directory, view=False, cleanup=False, quiet=True) @staticmethod def _move_to_shared_mem(obj): if t.is_tensor(obj): obj = obj.detach() obj.share_memory_() return obj elif isinstance(obj, list): for idx, sub_obj in enumerate(obj): obj[idx] = MADDPG._move_to_shared_mem(sub_obj) return obj elif isinstance(obj, tuple): obj = list(obj) for idx, sub_obj in enumerate(obj): obj[idx] = MADDPG._move_to_shared_mem(sub_obj) return tuple(obj) elif isinstance(obj, dict): for k, v in obj.items(): obj[k] = MADDPG._move_to_shared_mem(v) return obj @staticmethod def _check_parameters_device(models): devices = set() for model in models: for k, v in model.named_parameters(): devices.add(v.device.type) if len(devices) > 1: raise RuntimeError("All of your models should either" "locate on GPUs or on your CPU!") return list(devices)[0] @staticmethod def _create_sample_method(indexes): def sample_method(buffer, _len): nonlocal indexes batch = [buffer[i] for i in indexes if i < len(buffer)] return len(batch), batch return sample_method @staticmethod def action_transform_function(raw_output_action: Any, *_): return {"action": raw_output_action} @staticmethod def action_concat_function(actions: List[Dict], *_): # Assume an atom action is [batch_size, action_dim] # concatenate actions in the second dimension. # becomes [batch_size, actor_num * action_dim] keys = actions[0].keys() all_actions = {} for k in keys: all_actions[k] = t.cat([act[k].cpu() for act in actions], dim=1) return all_actions @staticmethod def state_concat_function(states: List[Dict], *_): # Assume an atom state is [batch_size, state_dim] # concatenate states in the second dimension. # becomes [batch_size, actor_num * state_dim] keys = states[0].keys() all_states = {} for k in keys: all_states[k] = t.cat([st[k].cpu() for st in states], dim=1) return all_states @staticmethod def reward_function(reward, discount, next_value, terminal, *_): next_value = next_value.to(reward.device) terminal = terminal.to(reward.device) return reward + discount * ~terminal * next_value
def test_size(self): pool = ThreadPool(processes=2) assert pool.size() == 2 pool.close() pool.join()
def test_reduce(self): with pytest.raises(RuntimeError, match="not reducible"): dill.dumps(ThreadPool(processes=2))
def test_reduce(self): with pytest.raises(NotImplementedError, match="cannot be passed"): dill.dumps(ThreadPool(processes=2))