import numpy as np

from ray.rllib.policy.sample_batch import MultiAgentBatch


def _collect_joint_dataset(trainer, worker, sample_size):
    """Collect `sample_size` observations per policy into one joint dataset."""
    joint_obs = []
    if hasattr(trainer.optimizer, "replay_buffers"):
        # If we are using MADDPG, it uses a ReplayOptimizer, which has this
        # attribute, so observations can be sampled straight from the buffers.
        for policy_id, replay_buffer in \
                trainer.optimizer.replay_buffers.items():
            obs = replay_buffer.sample(sample_size)[0]
            joint_obs.append(obs)
    else:
        # If we are using individual PPO, there is no replay buffer, so we
        # have to roll out here and keep sampling until every policy has
        # collected enough observations.
        tmp_batch = worker.sample()
        count_dict = {k: v.count for k, v in tmp_batch.policy_batches.items()}
        for k in worker.policy_map.keys():
            if k not in count_dict:
                count_dict[k] = 0
        samples = [tmp_batch]
        while any(c < sample_size for c in count_dict.values()):
            tmp_batch = worker.sample()
            for k, v in tmp_batch.policy_batches.items():
                assert k in count_dict, count_dict
                count_dict[k] += v.count
            samples.append(tmp_batch)
        multi_agent_batch = MultiAgentBatch.concat_samples(samples)
        for pid, batch in multi_agent_batch.policy_batches.items():
            batch.shuffle()
            assert batch.count >= sample_size, (batch, batch.count, [
                b.count for b in multi_agent_batch.policy_batches.values()
            ])
            joint_obs.append(batch.slice(0, sample_size)["obs"])
    joint_obs = np.concatenate(joint_obs)
    return joint_obs
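# A minimal usage sketch for _collect_joint_dataset above, assuming an
# RLlib-style `trainer` built elsewhere whose local rollout worker is
# reachable via trainer.workers.local_worker(); that attribute chain, the
# helper name, and the sample_size value are illustrative assumptions, not
# part of the original code.
def _example_collect_joint_obs(trainer, sample_size=200):
    worker = trainer.workers.local_worker()  # assumed accessor (see lead-in)
    joint_obs = _collect_joint_dataset(trainer, worker, sample_size)
    # joint_obs stacks `sample_size` observations from every policy into a
    # single np.ndarray, ready to be replayed through each policy.
    return joint_obs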
def next(self):
    """Return the next batch of experiences read.

    Returns:
        SampleBatch or MultiAgentBatch read.
    """
    batches = []
    for dp in self.data_processors:
        batches.append(ray_get_and_free(dp.next.remote()))
    batch = MultiAgentBatch.concat_samples(samples=batches)
    return batch
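# A self-contained sketch of the concatenation step used in next() above: two
# per-policy SampleBatches are wrapped as MultiAgentBatches and concatenated,
# mirroring what happens to the batches gathered from the remote data
# processors. Policy IDs and values are illustrative.
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch

b1 = MultiAgentBatch({"policy_0": SampleBatch({"obs": [1, 2], "actions": [0, 1]})}, 2)
b2 = MultiAgentBatch({"policy_0": SampleBatch({"obs": [3], "actions": [1]})}, 1)
merged = MultiAgentBatch.concat_samples([b1, b2])
assert merged.policy_batches["policy_0"].count == 3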
def get_cross_policy_object(multi_agent_batch, self_optimizer):
    """Build the cross_policy_object that is passed to each policy."""
    config = self_optimizer.workers._remote_config
    if not config["use_joint_dataset"]:
        joint_obs = SampleBatch.concat_samples(
            list(multi_agent_batch.policy_batches.values()))[
                SampleBatch.CUR_OBS]
    else:
        sample_size = config.get("joint_dataset_sample_batch_size")
        assert sample_size is not None, \
            "You should specify joint_dataset_sample_batch_size in the config!"
        samples = [multi_agent_batch]
        count_dict = {
            k: v.count
            for k, v in multi_agent_batch.policy_batches.items()
        }
        for k in self_optimizer.workers.local_worker().policy_map.keys():
            if k not in count_dict:
                count_dict[k] = 0
        while any(v < sample_size for v in count_dict.values()):
            tmp_batch = self_optimizer.workers.local_worker().sample()
            samples.append(tmp_batch)
            for k, v in tmp_batch.policy_batches.items():
                assert k in count_dict, count_dict
                count_dict[k] += v.count
        multi_agent_batch = MultiAgentBatch.concat_samples(samples)
        joint_obs = []
        pid_list = []
        for pid, batch in multi_agent_batch.policy_batches.items():
            batch.shuffle()
            assert batch.count >= sample_size, batch
            joint_obs.append(batch.slice(0, sample_size)["obs"])
            pid_list.append(pid)
        joint_obs = np.concatenate(joint_obs)

    def _replay(policy, pid):
        act, _, infos = policy.compute_actions(joint_obs)
        return pid, act, infos

    # ATTENTION: every policy replays the JOINT OBSERVATION itself here.
    ret = {
        pid: act
        for pid, act, infos in
        self_optimizer.workers.local_worker().foreach_policy(_replay)
    }
    return {JOINT_OBS: joint_obs, PEER_ACTION: ret}
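# A hedged consumer sketch for get_cross_policy_object above: one way the
# returned dict could be used is to measure how far a given policy's actions
# on the shared joint observations are from each peer's recorded actions.
# JOINT_OBS and PEER_ACTION are the same module-level keys used above; the
# helper name and the distance metric (which assumes continuous action arrays
# of matching shape) are illustrative assumptions.
import numpy as np

def _peer_action_distances(policy, my_pid, cross_policy_object):
    joint_obs = cross_policy_object[JOINT_OBS]
    my_actions, _, _ = policy.compute_actions(joint_obs)
    return {
        pid: float(np.linalg.norm(my_actions - peer_actions, axis=-1).mean())
        for pid, peer_actions in cross_policy_object[PEER_ACTION].items()
        if pid != my_pid
    }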
def sample(self,
           num_items: int,
           policy_id: PolicyID = DEFAULT_POLICY_ID,
           **kwargs) -> Optional[SampleBatchType]:
    """Samples a batch of size `num_items` from a specified buffer.

    Concatenates old samples to new ones according to self.replay_ratio.
    If not enough new samples are available, mixes in fewer old samples
    to retain self.replay_ratio on average. Returns an empty batch if
    there are no items in the buffer.

    Args:
        num_items: Number of items to sample from this buffer.
        policy_id: ID of the policy that produced the experiences to be
            sampled.
        **kwargs: Forward compatibility kwargs.

    Returns:
        Concatenated MultiAgentBatch of items.
    """
    # Merge kwargs, overwriting standard call arguments.
    kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args,
                                      kwargs)

    def mix_batches(_policy_id):
        """Mixes old with new samples.

        Tries to mix according to self.replay_ratio on average.
        If not enough new samples are available, mixes in fewer old
        samples to retain self.replay_ratio on average.
        """

        def round_up_or_down(value, ratio):
            """Returns an integer averaging to value*ratio."""
            product = value * ratio
            ceil_prob = product % 1
            if random.uniform(0, 1) < ceil_prob:
                return int(np.ceil(product))
            else:
                return int(np.floor(product))

        max_num_new = round_up_or_down(num_items, 1 - self.replay_ratio)
        # If num_items * self.replay_ratio is not round, we need one more
        # sample with a probability of (num_items * self.replay_ratio) % 1.
        _buffer = self.replay_buffers[_policy_id]
        output_batches = self.last_added_batches[_policy_id][:max_num_new]
        self.last_added_batches[_policy_id] = self.last_added_batches[
            _policy_id][max_num_new:]

        # No replay desired.
        if self.replay_ratio == 0.0:
            return SampleBatch.concat_samples(output_batches)
        # Only replay desired.
        elif self.replay_ratio == 1.0:
            return _buffer.sample(num_items, **kwargs)

        num_new = len(output_batches)

        if np.isclose(num_new, num_items * (1 - self.replay_ratio)):
            # The optimal case: we can mix in a round number of old
            # samples on average.
            num_old = num_items - max_num_new
        else:
            # We never want to return more elements than num_items.
            num_old = min(
                num_items - max_num_new,
                round_up_or_down(
                    num_new, self.replay_ratio / (1 - self.replay_ratio)),
            )

        output_batches.append(_buffer.sample(num_old, **kwargs))
        # Depending on the implementation of the underlying buffers,
        # samples might be SampleBatches.
        output_batches = [batch.as_multi_agent() for batch in output_batches]
        return MultiAgentBatch.concat_samples(output_batches)

    def check_buffer_is_ready(_policy_id):
        if ((len(self.replay_buffers[_policy_id]) == 0)
                and self.replay_ratio > 0.0) or (
                    len(self.last_added_batches[_policy_id]) == 0
                    and self.replay_ratio < 1.0):
            return False
        return True

    with self.replay_timer:
        samples = []

        if self.replay_mode == ReplayMode.LOCKSTEP:
            assert (
                policy_id is None
            ), "`policy_id` specifier not allowed in `lockstep` mode!"
            if check_buffer_is_ready(_ALL_POLICIES):
                samples.append(mix_batches(_ALL_POLICIES).as_multi_agent())
        elif policy_id is not None:
            if check_buffer_is_ready(policy_id):
                samples.append(mix_batches(policy_id).as_multi_agent())
        else:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if check_buffer_is_ready(policy_id):
                    samples.append(mix_batches(policy_id).as_multi_agent())

        return MultiAgentBatch.concat_samples(samples)
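# A self-contained sketch of the stochastic rounding used by round_up_or_down
# inside sample() above: repeatedly rounding num_items * (1 - replay_ratio) up
# or down averages out to that product, which is how the mix of new and
# replayed samples matches replay_ratio on average even when the split is
# fractional. The concrete values below are illustrative.
import random

import numpy as np

def round_up_or_down(value, ratio):
    """Return an integer that averages to value * ratio over many calls."""
    product = value * ratio
    ceil_prob = product % 1
    if random.uniform(0, 1) < ceil_prob:
        return int(np.ceil(product))
    return int(np.floor(product))

num_items, replay_ratio = 10, 0.25
draws = [round_up_or_down(num_items, 1 - replay_ratio) for _ in range(10_000)]
print(np.mean(draws))  # ~7.5 new samples per call, i.e. ~25% replayed on average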