def step(self):
    with self.update_weights_timer:
        if self.workers.remote_workers():
            weights = ray.put(self.workers.local_worker().get_weights())
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights)

    with self.sample_timer:
        samples = []
        while sum(s.count for s in samples) < self.train_batch_size:
            if self.workers.remote_workers():
                samples.extend(
                    ray_get_and_free([
                        e.sample.remote()
                        for e in self.workers.remote_workers()
                    ]))
            else:
                samples.append(self.workers.local_worker().sample())
        samples = SampleBatch.concat_samples(samples)
        self.sample_timer.push_units_processed(samples.count)

    # Handle everything as if multiagent.
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                  samples.count)

    fetches = {}
    with self.grad_timer:
        for policy_id, policy in self.policies.items():
            if policy_id not in samples.policy_batches:
                continue

            batch = samples.policy_batches[policy_id]
            for field in self.standardize_fields:
                value = batch[field]
                standardized = (value - value.mean()) / max(
                    1e-4, value.std())
                batch[field] = standardized

            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                for minibatch in self._minibatches(batch):
                    batch_fetches = (
                        self.workers.local_worker().learn_on_batch(
                            MultiAgentBatch({policy_id: minibatch},
                                            minibatch.count)))[policy_id]
                    for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                        iter_extra_fetches[k].append(v)
                logger.debug("{} {}".format(i,
                                            _averaged(iter_extra_fetches)))
            fetches[policy_id] = _averaged(iter_extra_fetches)

    self.grad_timer.push_units_processed(samples.count)
    if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
        self.learner_stats = fetches[DEFAULT_POLICY_ID]
    else:
        self.learner_stats = fetches
    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
    return self.learner_stats
def replay(self) -> SampleBatchType:
    """If this buffer was given a fake batch, return it; otherwise
    return a MultiAgentBatch with samples.
    """
    if self._fake_batch:
        fake_batch = SampleBatch(self._fake_batch)
        return MultiAgentBatch({
            DEFAULT_POLICY_ID: fake_batch
        }, fake_batch.count)

    if self.num_added < self.replay_starts:
        return None

    with self.replay_timer:
        # Lockstep mode: Sample an equal number of steps from all
        # policies at the same time.
        if self.replay_mode == "lockstep":
            return self.replay_buffers[_ALL_POLICIES].sample(
                self.replay_batch_size, beta=self.prioritized_replay_beta)
        else:
            samples = {}
            for policy_id, replay_buffer in self.replay_buffers.items():
                samples[policy_id] = replay_buffer.sample(
                    self.replay_batch_size,
                    beta=self.prioritized_replay_beta)
            return MultiAgentBatch(samples, self.replay_batch_size)
def replay(self):
    if self._fake_batch:
        fake_batch = SampleBatch(self._fake_batch)
        return MultiAgentBatch({DEFAULT_POLICY_ID: fake_batch},
                               fake_batch.count)

    if self.num_added < self.replay_starts:
        return None

    with self.replay_timer:
        samples = {}
        idxes = None
        for policy_id, replay_buffer in self.replay_buffers.items():
            if self.multiagent_sync_replay:
                # Sample one set of indices and reuse it for all
                # policies, so experiences stay aligned across agents.
                if idxes is None:
                    idxes = replay_buffer.sample_idxes(
                        self.replay_batch_size)
            else:
                idxes = replay_buffer.sample_idxes(self.replay_batch_size)
            (obses_t, actions, rewards, obses_tp1, dones, weights,
             batch_indexes) = replay_buffer.sample_with_idxes(
                 idxes, beta=self.prioritized_replay_beta)
            samples[policy_id] = SampleBatch({
                "obs": obses_t,
                "actions": actions,
                "rewards": rewards,
                "new_obs": obses_tp1,
                "dones": dones,
                "weights": weights,
                "batch_indexes": batch_indexes
            })
        return MultiAgentBatch(samples, self.replay_batch_size)
def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter,
                     sgd_minibatch_size, standardize_fields):
    """Execute minibatch SGD.

    Args:
        samples (SampleBatch): Batch of samples to optimize.
        policies (dict): Dictionary of policies to optimize.
        local_worker (RolloutWorker): Master rollout worker instance.
        num_sgd_iter (int): Number of epochs of optimization to take.
        sgd_minibatch_size (int): Size of minibatches to use for
            optimization.
        standardize_fields (list): List of sample field names that should
            be normalized prior to optimization.

    Returns:
        averaged info fetches over the last SGD epoch taken.
    """
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                  samples.count)

    # Use LearnerInfoBuilder as a unified way to build the final
    # results dict from `learn_on_loaded_batch` call(s).
    # This makes sure results dicts always have the same structure
    # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
    # tf vs torch).
    learner_info_builder = LearnerInfoBuilder(num_devices=1)
    for policy_id, policy in policies.items():
        if policy_id not in samples.policy_batches:
            continue

        batch = samples.policy_batches[policy_id]
        for field in standardize_fields:
            batch[field] = standardized(batch[field])

        # Check that `sgd_minibatch_size` is not smaller than
        # `max_seq_len`. Otherwise, this would cause indexing errors
        # during SGD when using an RNN or attention model.
        if policy.is_recurrent() and \
                policy.config["model"]["max_seq_len"] > sgd_minibatch_size:
            raise ValueError("`sgd_minibatch_size` ({}) cannot be smaller "
                             "than `max_seq_len` ({}).".format(
                                 sgd_minibatch_size,
                                 policy.config["model"]["max_seq_len"]))

        for i in range(num_sgd_iter):
            for minibatch in minibatches(batch, sgd_minibatch_size):
                results = (local_worker.learn_on_batch(
                    MultiAgentBatch({policy_id: minibatch},
                                    minibatch.count)))[policy_id]
                learner_info_builder.add_learn_on_batch_results(
                    results, policy_id)

    learner_info = learner_info_builder.finalize()
    return learner_info
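# Hedged sketch, not from the source: illustrates the per-field
# standardization that `standardized()` above applies (assumed here to be
# zero-mean/unit-std with a small epsilon, matching the inline version in
# step() earlier in this section). "advantages" is a made-up field name
# used only for illustration.
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch


def standardized_sketch(array):
    # Zero-mean, unit-variance scaling; the epsilon guards against
    # division by a near-zero std.
    return (array - array.mean()) / max(1e-4, array.std())


example = SampleBatch({"advantages": np.random.random(8).astype(np.float32)})
example["advantages"] = standardized_sketch(example["advantages"])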
def add_batch(self, batch):
    # Make a copy so the replay buffer doesn't pin plasma memory.
    batch = batch.copy()
    # Handle everything as if multiagent.
    if isinstance(batch, SampleBatch):
        batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
    with self.add_batch_timer:
        if self.replay_mode == "lockstep":
            for s in batch.timeslices(self.replay_sequence_length):
                self.replay_buffers[_ALL_POLICIES].add(s)
        else:
            for policy_id, b in batch.policy_batches.items():
                for s in b.timeslices(self.replay_sequence_length):
                    self.replay_buffers[policy_id].add(s)
    self.num_added += batch.count
def _add_multi_agent_batch_to_buffer(
    self, buffer, num_policies, num_batches=5, seq_lens=False, **kwargs
):
    def _generate_data(policy_id):
        batch = SampleBatch(
            {
                SampleBatch.T: [0, 1],
                SampleBatch.ACTIONS: 2 * [np.random.choice([0, 1])],
                SampleBatch.REWARDS: 2 * [np.random.rand()],
                SampleBatch.OBS: 2 * [np.random.random((4,))],
                SampleBatch.NEXT_OBS: 2 * [np.random.random((4,))],
                SampleBatch.DONES: [False, True],
                SampleBatch.EPS_ID: 2 * [self.batch_id],
                SampleBatch.AGENT_INDEX: 2 * [0],
                SampleBatch.SEQ_LENS: [2],
                "batch_id": 2 * [self.batch_id],
                "policy_id": 2 * [policy_id],
            }
        )
        if not seq_lens:
            del batch[SampleBatch.SEQ_LENS]
        self.batch_id += 1
        return batch

    for i in range(num_batches):
        # Generate a few policy batches.
        policy_batches = {
            idx: _generate_data(idx) for idx in range(num_policies)
        }
        batch = MultiAgentBatch(policy_batches, num_batches * 2)
        buffer.add(batch, **kwargs)
def _from_json(batch: str) -> SampleBatchType:
    if isinstance(batch, bytes):  # smart_open S3 doesn't respect "r"
        batch = batch.decode("utf-8")

    data = json.loads(batch)
    if "type" in data:
        data_type = data.pop("type")
    else:
        raise ValueError("JSON record missing 'type' field")

    if data_type == "SampleBatch":
        for k, v in data.items():
            data[k] = unpack_if_needed(v)
        return SampleBatch(data)
    elif data_type == "MultiAgentBatch":
        policy_batches = {}
        for policy_id, policy_batch in data["policy_batches"].items():
            inner = {}
            for k, v in policy_batch.items():
                inner[k] = unpack_if_needed(v)
            policy_batches[policy_id] = SampleBatch(inner)
        return MultiAgentBatch(policy_batches, data["count"])
    else:
        raise ValueError(
            "Type field must be one of ['SampleBatch', 'MultiAgentBatch']",
            data_type)
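# Hedged sketch, not from the source: the record layouts that _from_json()
# above expects, inferred from its two branches. Values are illustrative
# and stored unpacked; real logs may base64-pack columns, which
# unpack_if_needed() would then decode.
import json

sample_record = json.dumps({
    "type": "SampleBatch",
    "obs": [[0.1, 0.2]],
    "actions": [0],
    "rewards": [1.0],
})
multi_agent_record = json.dumps({
    "type": "MultiAgentBatch",
    "count": 1,
    "policy_batches": {"policy_0": {"obs": [[0.1, 0.2]], "actions": [0]}},
})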
def gen_replay(timeout):
    while True:
        samples = {}
        idxes = None
        for policy_id, replay_buffer in replay_buffers.items():
            if synchronize_sampling:
                if idxes is None:
                    idxes = replay_buffer.sample_idxes(train_batch_size)
            else:
                idxes = replay_buffer.sample_idxes(train_batch_size)

            if isinstance(replay_buffer, PrioritizedReplayBuffer):
                metrics = LocalIterator.get_metrics()
                num_steps_trained = metrics.counters[STEPS_TRAINED_COUNTER]
                (obses_t, actions, rewards, obses_tp1, dones, weights,
                 batch_indexes) = replay_buffer.sample_with_idxes(
                     idxes,
                     beta=prioritized_replay_beta.value(num_steps_trained))
            else:
                (obses_t, actions, rewards, obses_tp1,
                 dones) = replay_buffer.sample_with_idxes(idxes)
                weights = np.ones_like(rewards)
                batch_indexes = -np.ones_like(rewards)
            samples[policy_id] = SampleBatch({
                "obs": obses_t,
                "actions": actions,
                "rewards": rewards,
                "new_obs": obses_tp1,
                "dones": dones,
                "weights": weights,
                "batch_indexes": batch_indexes
            })
        yield MultiAgentBatch(samples, train_batch_size)
def _collect_joint_dataset(trainer, worker, sample_size):
    joint_obs = []
    if hasattr(trainer.optimizer, "replay_buffers"):
        # If we are using MADDPG, it uses a ReplayOptimizer, which has
        # this attribute.
        for policy_id, replay_buffer in \
                trainer.optimizer.replay_buffers.items():
            obs = replay_buffer.sample(sample_size)[0]
            joint_obs.append(obs)
    else:
        # If we are using individual PPO, there is no replay buffer, so
        # we have to roll out here to collect the observations. Force
        # sampling until we have collected enough data.
        tmp_batch = worker.sample()
        count_dict = {k: v.count for k, v in tmp_batch.policy_batches.items()}
        for k in worker.policy_map.keys():
            if k not in count_dict:
                count_dict[k] = 0

        samples = [tmp_batch]
        while any(c < sample_size for c in count_dict.values()):
            tmp_batch = worker.sample()
            for k, v in tmp_batch.policy_batches.items():
                assert k in count_dict, count_dict
                count_dict[k] += v.count
            samples.append(tmp_batch)
        multi_agent_batch = MultiAgentBatch.concat_samples(samples)

        for pid, batch in multi_agent_batch.policy_batches.items():
            batch.shuffle()
            assert batch.count >= sample_size, (batch, batch.count, [
                b.count for b in multi_agent_batch.policy_batches.values()
            ])
            joint_obs.append(batch.slice(0, sample_size)["obs"])
    joint_obs = np.concatenate(joint_obs)
    return joint_obs
def build_and_reset(
        self,
        episode: Optional[MultiAgentEpisode] = None) -> MultiAgentBatch:
    """Returns the accumulated sample batches for each policy.

    Any unprocessed rows will be first postprocessed with a policy
    postprocessor. The internal state of this builder will be reset.

    Args:
        episode (Optional[MultiAgentEpisode]): The Episode object that
            holds this MultiAgentBatchBuilder object or None.

    Returns:
        MultiAgentBatch: Returns the accumulated sample batches for each
            policy.
    """
    self.postprocess_batch_so_far(episode)
    policy_batches = {}
    for policy_id, builder in self.policy_builders.items():
        if builder.count > 0:
            policy_batches[policy_id] = builder.build_and_reset()
    old_count = self.count
    self.count = 0
    return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)
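# Hedged sketch, not from the source: wrap_as_needed() (RLlib 1.x
# semantics) unwraps to a plain SampleBatch when only the default policy
# is present, and builds a MultiAgentBatch otherwise.
from ray.rllib.policy.sample_batch import (DEFAULT_POLICY_ID,
                                           MultiAgentBatch, SampleBatch)

sb = SampleBatch({"obs": [1, 2]})
out = MultiAgentBatch.wrap_as_needed({DEFAULT_POLICY_ID: sb}, sb.count)
assert isinstance(out, SampleBatch)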
def before_learn_on_batch(multi_agent_batch: MultiAgentBatch, policies,
                          train_batch_size):
    samples = {}

    # Modify keys.
    for pid, p in policies.items():
        i = p.agent_idx
        keys = multi_agent_batch.policy_batches[pid].data.keys()
        keys = ["_".join([k, str(i)]) for k in keys]
        samples.update(
            dict(
                zip(keys,
                    multi_agent_batch.policy_batches[pid].data.values())))

    # Make ops and feed_dict to get "new_obs" from target action sampler.
    new_obs_ph_n = [p.new_obs_ph for p in policies.values()]
    new_obs_n = list()
    for k, v in samples.items():
        if "new_obs" in k:
            new_obs_n.append(v)

    # target_act_sampler_n = [p.target_act_sampler for p in policies.values()]
    feed_dict = dict(zip(new_obs_ph_n, new_obs_n))
    new_act_n = [
        p.sess.run(p.target_act_sampler, feed_dict)
        for p in policies.values()
    ]
    samples.update(
        {"new_actions_%d" % i: new_act
         for i, new_act in enumerate(new_act_n)})

    # Share samples among agents.
    policy_batches = {pid: SampleBatch(samples) for pid in policies.keys()}
    return MultiAgentBatch(policy_batches, train_batch_size)
def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch,
                  **kwargs):
    super().on_sample_end(worker=worker, samples=samples, **kwargs)
    assert isinstance(samples, MultiAgentBatch)

    for policy_samples in samples.policy_batches.values():
        if "action_prob" in policy_samples.data:
            del policy_samples.data["action_prob"]
        if "action_logp" in policy_samples.data:
            del policy_samples.data["action_logp"]

    for average_policy_id, br_policy_id in [("average_policy_0",
                                             "best_response_0"),
                                            ("average_policy_1",
                                             "best_response_1")]:
        for policy_id, policy_samples in samples.policy_batches.items():
            if policy_id == br_policy_id:
                store_to_avg_policy_buffer(
                    MultiAgentBatch(
                        policy_batches={
                            average_policy_id: policy_samples
                        },
                        env_steps=policy_samples.count))
        if average_policy_id in samples.policy_batches:
            if br_policy_id in samples.policy_batches:
                all_policies_samples = samples.policy_batches[
                    br_policy_id].concat(
                        other=samples.policy_batches[average_policy_id])
            else:
                all_policies_samples = samples.policy_batches[
                    average_policy_id]
            del samples.policy_batches[average_policy_id]
            samples.policy_batches[br_policy_id] = all_policies_samples
def from_json_data(json_data: Any, worker: Optional["RolloutWorker"]):
    # Try to infer the SampleBatchType (SampleBatch or MultiAgentBatch).
    if "type" in json_data:
        data_type = json_data.pop("type")
    else:
        raise ValueError("JSON record missing 'type' field")

    if data_type == "SampleBatch":
        if worker is not None and len(worker.policy_map) != 1:
            raise ValueError(
                "Found single-agent SampleBatch in input file, but our "
                "PolicyMap contains more than 1 policy!")
        for k, v in json_data.items():
            json_data[k] = unpack_if_needed(v)
        if worker is not None:
            policy = next(iter(worker.policy_map.values()))
            json_data = _adjust_obs_actions_for_policy(json_data, policy)
        return SampleBatch(json_data)
    elif data_type == "MultiAgentBatch":
        policy_batches = {}
        for policy_id, policy_batch in json_data["policy_batches"].items():
            inner = {}
            for k, v in policy_batch.items():
                inner[k] = unpack_if_needed(v)
            if worker is not None:
                policy = worker.policy_map[policy_id]
                inner = _adjust_obs_actions_for_policy(inner, policy)
            policy_batches[policy_id] = SampleBatch(inner)
        return MultiAgentBatch(policy_batches, json_data["count"])
    else:
        raise ValueError(
            "Type field must be one of ['SampleBatch', 'MultiAgentBatch']",
            data_type)
def add_batch(self, batch: SampleBatchType) -> None:
    # Make a copy so the replay buffer doesn't pin plasma memory.
    batch = batch.copy()
    # Handle everything as if multiagent.
    if isinstance(batch, SampleBatch):
        batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)

    with self.add_batch_timer:
        # Lockstep mode: Store under _ALL_POLICIES key (we will always
        # only sample from all policies at the same time).
        if self.replay_mode == "lockstep":
            # Note that prioritization is not supported in this mode.
            for s in batch.timeslices(self.replay_sequence_length):
                self.replay_buffers[_ALL_POLICIES].add(s, weight=None)
        else:
            for policy_id, sample_batch in batch.policy_batches.items():
                if self.replay_sequence_length == 1:
                    timeslices = sample_batch.timeslices(1)
                else:
                    timeslices = timeslice_along_seq_lens_with_overlap(
                        sample_batch=sample_batch,
                        zero_pad_max_seq_len=self.replay_sequence_length,
                        pre_overlap=self.replay_burn_in,
                        zero_init_states=self.replay_zero_init_states,
                    )
                for time_slice in timeslices:
                    # If SampleBatch has prio-replay weights, average
                    # over these to use as a weight for the entire
                    # sequence.
                    if "weights" in time_slice:
                        weight = np.mean(time_slice["weights"])
                    else:
                        weight = None
                    self.replay_buffers[policy_id].add(
                        time_slice, weight=weight)
    self.num_added += batch.count
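# Hedged sketch, not RLlib's implementation: the timeslicing idea used in
# add_batch() above. A 4-step batch of column data is cut into 1-step
# slices before each slice goes into the underlying replay buffer.
def timeslices_sketch(columns, k):
    # Split every column into consecutive chunks of (at most) k rows.
    n = len(next(iter(columns.values())))
    return [{name: col[i:i + k]
             for name, col in columns.items()} for i in range(0, n, k)]


slices = timeslices_sketch({"obs": [1, 2, 3, 4], "dones": [0, 0, 0, 1]}, 1)
assert len(slices) == 4 and slices[0] == {"obs": [1], "dones": [0]}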
def _add_multi_agent_batch_to_buffer(
    self, buffer, num_policies, num_batches=5, **kwargs
):
    def _generate_data(policy_id):
        batch = SampleBatch(
            {
                SampleBatch.T: [0],
                SampleBatch.ACTIONS: [np.random.choice([0, 1])],
                SampleBatch.REWARDS: [np.random.rand()],
                SampleBatch.OBS: [np.random.random((4,))],
                SampleBatch.NEXT_OBS: [np.random.random((4,))],
                SampleBatch.DONES: [np.random.choice([False, True])],
                SampleBatch.EPS_ID: [self.batch_id],
                SampleBatch.AGENT_INDEX: [self.batch_id],
                "batch_id": [self.batch_id],
                "policy_id": [policy_id],
            }
        )
        return batch

    for i in range(num_batches):
        # Generate a few policy batches.
        policy_batches = {
            idx: _generate_data(idx) for idx in range(num_policies)
        }
        self.batch_id += 1
        batch = MultiAgentBatch(policy_batches, 1)
        buffer.add(batch, **kwargs)
def __call__(self, batch: MultiAgentBatch) -> List[SampleBatchType]:
    _check_sample_batch_type(batch)
    batch_count = batch.policy_batches[self.policy_id_to_count_for].count
    if self.drop_samples_for_other_agents:
        batch = MultiAgentBatch(
            policy_batches={
                self.policy_id_to_count_for: batch.policy_batches[
                    self.policy_id_to_count_for]
            },
            env_steps=batch.policy_batches[
                self.policy_id_to_count_for].count)

    self.buffer.append(batch)
    self.count += batch_count

    if self.count >= self.min_batch_size:
        if self.count > self.min_batch_size * 2:
            logger.info("Collected more training samples than expected "
                        "(actual={}, expected={}). ".format(
                            self.count, self.min_batch_size) +
                        "This may be because you have many workers or "
                        "long episodes in 'complete_episodes' batch mode.")
        out = SampleBatch.concat_samples(self.buffer)
        timer = _get_shared_metrics().timers[SAMPLE_TIMER]
        timer.push(time.perf_counter() - self.batch_start_time)
        timer.push_units_processed(self.count)
        self.batch_start_time = None
        self.buffer = []
        self.count = 0
        return [out]
    return []
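# Hedged sketch, not from the source: SampleBatch.concat_samples (a static
# method in RLlib 1.x) is what builds `out` above once min_batch_size is
# reached. Two 2-step batches concatenate into one 4-step batch.
from ray.rllib.policy.sample_batch import SampleBatch

a = SampleBatch({"obs": [1, 2]})
b = SampleBatch({"obs": [3, 4]})
assert SampleBatch.concat_samples([a, b]).count == 4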
def _optimize(self):
    if self._fake_batch:
        fake_batch = SampleBatch(self._fake_batch)
        samples = MultiAgentBatch({
            DEFAULT_POLICY_ID: fake_batch
        }, fake_batch.count)
    else:
        samples = self._replay()

    with self.grad_timer:
        if self.before_learn_on_batch:
            samples = self.before_learn_on_batch(
                samples,
                self.workers.local_worker().policy_map,
                self.train_batch_size)
        info_dict = self.workers.local_worker().learn_on_batch(samples)
        for policy_id, info in info_dict.items():
            self.learner_stats[policy_id] = get_learner_stats(info)
            replay_buffer = self.replay_buffers[policy_id]
            if isinstance(replay_buffer, PrioritizedReplayBuffer):
                # TODO(sven): This is currently structured differently for
                #  torch/tf. Clean up these results/info dicts across
                #  policies (note: fixing this in torch_policy.py will
                #  break e.g. DDPPO!).
                td_error = info.get("td_error",
                                    info["learner_stats"].get("td_error"))
                new_priorities = (
                    np.abs(td_error) + self.prioritized_replay_eps)
                replay_buffer.update_priorities(
                    samples.policy_batches[policy_id]["batch_indexes"],
                    new_priorities)
        self.grad_timer.push_units_processed(samples.count)
    self.num_steps_trained += samples.count
def replay(self, policy_id: Optional[PolicyID] = None) -> SampleBatchType:
    """If this buffer was given a fake batch, return it; otherwise
    return a MultiAgentBatch with samples.
    """
    if self._fake_batch:
        if not isinstance(self._fake_batch, MultiAgentBatch):
            self._fake_batch = SampleBatch(
                self._fake_batch).as_multi_agent()
        return self._fake_batch

    if self.num_added < self.replay_starts:
        return None

    with self.replay_timer:
        # Lockstep mode: Sample an equal number of steps from all
        # policies at the same time.
        if self.replay_mode == "lockstep":
            assert (
                policy_id is None
            ), "`policy_id` specifier not allowed in `lockstep` mode!"
            return self.replay_buffers[_ALL_POLICIES].sample(
                self.replay_batch_size, beta=self.prioritized_replay_beta)
        elif policy_id is not None:
            return self.replay_buffers[policy_id].sample(
                self.replay_batch_size, beta=self.prioritized_replay_beta)
        else:
            samples = {}
            for policy_id, replay_buffer in self.replay_buffers.items():
                samples[policy_id] = replay_buffer.sample(
                    self.replay_batch_size,
                    beta=self.prioritized_replay_beta)
            return MultiAgentBatch(samples, self.replay_batch_size)
def _replay(self):
    samples = {}
    idxes = None
    with self.replay_timer:
        for policy_id, replay_buffer in self.replay_buffers.items():
            if self.synchronize_sampling:
                if idxes is None:
                    idxes = replay_buffer.sample_idxes(
                        self.train_batch_size)
            else:
                idxes = replay_buffer.sample_idxes(self.train_batch_size)

            if isinstance(replay_buffer, PrioritizedReplayBuffer):
                (obses_t, actions, rewards, obses_tp1, dones, weights,
                 batch_indexes) = replay_buffer.sample_with_idxes(
                     idxes,
                     beta=self.prioritized_replay_beta.value(
                         self.num_steps_trained))
            else:
                (obses_t, actions, rewards, obses_tp1,
                 dones) = replay_buffer.sample_with_idxes(idxes)
                weights = np.ones_like(rewards)
                batch_indexes = -np.ones_like(rewards)
            samples[policy_id] = SampleBatch({
                "obs": obses_t,
                "actions": actions,
                "rewards": rewards,
                "new_obs": obses_tp1,
                "dones": dones,
                "weights": weights,
                "batch_indexes": batch_indexes
            })
    return MultiAgentBatch(samples, self.train_batch_size)
def before_learn_on_batch(multi_agent_batch, policies, train_batch_size):
    samples = {}

    # Modify keys.
    for pid, p in policies.items():
        i = p.config["agent_id"]
        keys = multi_agent_batch.policy_batches[pid].keys()
        keys = ["_".join([k, str(i)]) for k in keys]
        samples.update(
            dict(zip(keys,
                     multi_agent_batch.policy_batches[pid].values())))

    # Make ops and feed_dict to get "new_obs" from target action sampler.
    new_obs_ph_n = [p.new_obs_ph for p in policies.values()]
    new_obs_n = list()
    for k, v in samples.items():
        if "new_obs" in k:
            new_obs_n.append(v)

    for i, p in enumerate(policies.values()):
        feed_dict = {new_obs_ph_n[i]: new_obs_n[i]}
        new_act = p.get_session().run(p.target_act_sampler, feed_dict)
        samples.update({"new_actions_%d" % i: new_act})

    # Share samples among agents.
    policy_batches = {pid: SampleBatch(samples) for pid in policies.keys()}
    return MultiAgentBatch(policy_batches, train_batch_size)
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            batch = SampleBatch.concat_samples(
                ray_get_and_free(
                    [e.sample.remote() for e in self.remote_evaluators]))
        else:
            batch = self.local_evaluator.sample()

        # Handle everything as if multiagent.
        if isinstance(batch, SampleBatch):
            batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                    batch.count)

        for policy_id, s in batch.policy_batches.items():
            for row in s.rows():
                self.replay_buffers[policy_id].add(
                    pack_if_needed(row["obs"]),
                    row["actions"],
                    row["rewards"],
                    pack_if_needed(row["new_obs"]),
                    row["dones"],
                    weight=None)

    if self.num_steps_sampled >= self.replay_starts:
        self._optimize()

    self.num_steps_sampled += batch.count
def sample(self,
           num_items: int,
           policy_id: Optional[PolicyID] = None,
           **kwargs) -> Optional[SampleBatchType]:
    """Samples a MultiAgentBatch of `num_items` per one policy's buffer.

    If less than `num_items` records are in the policy's buffer, some
    samples in the results may be repeated to fulfil the batch size
    `num_items` request. Returns an empty batch if there are no items in
    the buffer.

    Args:
        num_items: Number of items to sample from a policy's buffer.
        policy_id: ID of the policy that created the experiences we
            sample. If none is given, sample from all policies.
        **kwargs: Forward compatibility kwargs.

    Returns:
        Concatenated MultiAgentBatch of items.
    """
    # Merge kwargs, overwriting standard call arguments.
    kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args,
                                      kwargs)

    if self._num_added < self.replay_starts:
        return MultiAgentBatch({}, 0)

    with self.replay_timer:
        # Lockstep mode: Sample an equal number of steps from all
        # policies at the same time.
        if self.replay_mode == ReplayMode.LOCKSTEP:
            assert (
                policy_id is None
            ), "`policy_id` specifier not allowed in `lockstep` mode!"
            # In lockstep mode we sample MultiAgentBatches.
            return self.replay_buffers[_ALL_POLICIES].sample(
                num_items, **kwargs)
        elif policy_id is not None:
            sample = self.replay_buffers[policy_id].sample(
                num_items, **kwargs)
            return MultiAgentBatch({policy_id: sample}, sample.count)
        else:
            samples = {}
            for policy_id, replay_buffer in self.replay_buffers.items():
                samples[policy_id] = replay_buffer.sample(
                    num_items, **kwargs)
            return MultiAgentBatch(samples,
                                   sum(s.count for s in samples.values()))
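# Hedged usage sketch, not from the source (assumes RLlib 2.x's
# MultiAgentReplayBuffer in its default independent replay mode; class
# path, constructor arguments, and `some_multi_agent_batch` are
# assumptions that may differ between versions):
#
#   from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
#       MultiAgentReplayBuffer)
#
#   buffer = MultiAgentReplayBuffer(capacity=10000)
#   buffer.add(some_multi_agent_batch)                 # hypothetical batch
#   ma_batch = buffer.sample(32)                       # all policies
#   one_pol = buffer.sample(32, policy_id="policy_0")  # one policy only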
def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter,
                     sgd_minibatch_size, standardize_fields):
    """Execute minibatch SGD.

    Args:
        samples (SampleBatch): Batch of samples to optimize.
        policies (dict): Dictionary of policies to optimize.
        local_worker (RolloutWorker): Master rollout worker instance.
        num_sgd_iter (int): Number of epochs of optimization to take.
        sgd_minibatch_size (int): Size of minibatches to use for
            optimization.
        standardize_fields (list): List of sample field names that should
            be normalized prior to optimization.

    Returns:
        averaged info fetches over the last SGD epoch taken.
    """
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                  samples.count)

    fetches = defaultdict(dict)
    for policy_id in policies.keys():
        if policy_id not in samples.policy_batches:
            continue

        batch = samples.policy_batches[policy_id]
        for field in standardize_fields:
            batch[field] = standardized(batch[field])

        learner_stats = defaultdict(list)
        model_stats = defaultdict(list)
        custom_callbacks_stats = defaultdict(list)

        for i in range(num_sgd_iter):
            for minibatch in minibatches(batch, sgd_minibatch_size):
                batch_fetches = (local_worker.learn_on_batch(
                    MultiAgentBatch({policy_id: minibatch},
                                    minibatch.count)))[policy_id]
                for k, v in batch_fetches.get(LEARNER_STATS_KEY,
                                              {}).items():
                    learner_stats[k].append(v)
                for k, v in batch_fetches.get("model", {}).items():
                    model_stats[k].append(v)
                for k, v in batch_fetches.get("custom_metrics",
                                              {}).items():
                    custom_callbacks_stats[k].append(v)

        fetches[policy_id][LEARNER_STATS_KEY] = averaged(learner_stats)
        fetches[policy_id]["model"] = averaged(model_stats)
        fetches[policy_id]["custom_metrics"] = averaged(
            custom_callbacks_stats)
    return fetches
def __call__(self, samples: SampleBatchType) -> SampleBatchType:
    _check_sample_batch_type(samples)
    if isinstance(samples, MultiAgentBatch):
        if self.local_worker:
            samples = MultiAgentBatch(
                {
                    pid: batch
                    for pid, batch in samples.policy_batches.items()
                    if self.local_worker.is_policy_to_train(pid, batch)
                }, samples.count)
        else:
            samples = MultiAgentBatch(
                {
                    k: v
                    for k, v in samples.policy_batches.items()
                    if k in self.policy_ids
                }, samples.count)
    return samples
def add_batch(self, batch):
    # Handle everything as if multiagent.
    if isinstance(batch, SampleBatch):
        batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
    self.buffer.append(batch)
    self.cur_size += batch.count
    self.num_added += batch.count
    while self.cur_size > self.buffer_size:
        self.cur_size -= self.buffer.pop(0).count
def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter,
                     sgd_minibatch_size, standardize_fields):
    """Execute minibatch SGD.

    Args:
        samples (SampleBatch): Batch of samples to optimize.
        policies (dict): Dictionary of policies to optimize.
        local_worker (RolloutWorker): Master rollout worker instance.
        num_sgd_iter (int): Number of epochs of optimization to take.
        sgd_minibatch_size (int): Size of minibatches to use for
            optimization.
        standardize_fields (list): List of sample field names that should
            be normalized prior to optimization.

    Returns:
        averaged info fetches over the last SGD epoch taken.
    """
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                  samples.count)

    # Use LearnerInfoBuilder as a unified way to build the final
    # results dict from `learn_on_loaded_batch` call(s).
    # This makes sure results dicts always have the same structure
    # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
    # tf vs torch).
    learner_info_builder = LearnerInfoBuilder(num_devices=1)
    for policy_id in policies.keys():
        if policy_id not in samples.policy_batches:
            continue

        batch = samples.policy_batches[policy_id]
        for field in standardize_fields:
            batch[field] = standardized(batch[field])

        for i in range(num_sgd_iter):
            for minibatch in minibatches(batch, sgd_minibatch_size):
                results = (local_worker.learn_on_batch(
                    MultiAgentBatch({policy_id: minibatch},
                                    minibatch.count)))[policy_id]
                learner_info_builder.add_learn_on_batch_results(
                    results, policy_id)

    learner_info = learner_info_builder.finalize()
    return learner_info
def mix_batches(_policy_id):
    """Mixes old with new samples.

    Tries to mix according to self.replay_ratio on average. If not
    enough new samples are available, mixes in fewer old samples to
    retain self.replay_ratio on average.
    """

    def round_up_or_down(value, ratio):
        """Returns an integer averaging to value*ratio."""
        product = value * ratio
        ceil_prob = product % 1
        if random.uniform(0, 1) < ceil_prob:
            return int(np.ceil(product))
        else:
            return int(np.floor(product))

    max_num_new = round_up_or_down(num_items, 1 - self.replay_ratio)
    # If num_items * (1 - self.replay_ratio) is not round, we need one
    # more new sample with a probability of
    # (num_items * (1 - self.replay_ratio)) % 1.

    _buffer = self.replay_buffers[_policy_id]
    output_batches = self.last_added_batches[_policy_id][:max_num_new]
    self.last_added_batches[_policy_id] = \
        self.last_added_batches[_policy_id][max_num_new:]

    # No replay desired.
    if self.replay_ratio == 0.0:
        return SampleBatch.concat_samples(output_batches)
    # Only replay desired.
    elif self.replay_ratio == 1.0:
        return _buffer.sample(num_items, **kwargs)

    num_new = len(output_batches)

    if np.isclose(num_new, num_items * (1 - self.replay_ratio)):
        # The optimal case: we can mix in a round number of old
        # samples on average.
        num_old = num_items - max_num_new
    else:
        # We never want to return more elements than num_items.
        num_old = min(
            num_items - max_num_new,
            round_up_or_down(num_new,
                             self.replay_ratio / (1 - self.replay_ratio)),
        )

    output_batches.append(_buffer.sample(num_old, **kwargs))
    # Depending on the implementation of underlying buffers, samples
    # might be SampleBatches.
    output_batches = [batch.as_multi_agent() for batch in output_batches]
    return MultiAgentBatch.concat_samples(output_batches)
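# Hedged numeric check, not from the source: round_up_or_down() (copied
# from mix_batches above) returns an integer whose expected value is
# value * ratio. With num_items=10 and replay_ratio=0.25, max_num_new
# averages 7.5, so roughly 25% of the returned items are replayed old
# samples.
import random

import numpy as np


def round_up_or_down(value, ratio):
    product = value * ratio
    ceil_prob = product % 1
    if random.uniform(0, 1) < ceil_prob:
        return int(np.ceil(product))
    return int(np.floor(product))


draws = [round_up_or_down(10, 0.75) for _ in range(100000)]
print(sum(draws) / len(draws))  # ~7.5 on average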
def add_batch(self, batch):
    # Make a copy so the replay buffer doesn't pin plasma memory.
    batch = batch.copy()
    # Handle everything as if multiagent.
    if isinstance(batch, SampleBatch):
        batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
    with self.add_batch_timer:
        if self.replay_mode == "lockstep":
            # Note that prioritization is not supported in this mode.
            for s in batch.timeslices(self.replay_sequence_length):
                self.replay_buffers[_ALL_POLICIES].add(s, weight=None)
        else:
            for policy_id, b in batch.policy_batches.items():
                for s in b.timeslices(self.replay_sequence_length):
                    if "weights" in s:
                        weight = np.mean(s["weights"])
                    else:
                        weight = None
                    self.replay_buffers[policy_id].add(s, weight=weight)
    self.num_added += batch.count
def add_batch(self, batch):
    # Handle everything as if multiagent.
    if isinstance(batch, SampleBatch):
        batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
    with self.add_batch_timer:
        for policy_id, s in batch.policy_batches.items():
            for row in s.rows():
                self.replay_buffers[policy_id].add(
                    row["obs"], row["actions"], row["rewards"],
                    row["new_obs"], row["dones"], row["weights"])
    self.num_added += batch.count
def replay(self):
    if self._fake_batch:
        fake_batch = SampleBatch(self._fake_batch)
        return MultiAgentBatch({DEFAULT_POLICY_ID: fake_batch},
                               fake_batch.count)

    if self.num_added < self.replay_starts:
        return None

    with self.replay_timer:
        if self.replay_mode == "lockstep":
            return self.replay_buffers[_ALL_POLICIES].sample(
                self.replay_batch_size, beta=self.prioritized_replay_beta)
        else:
            samples = {}
            for policy_id, replay_buffer in self.replay_buffers.items():
                samples[policy_id] = replay_buffer.sample(
                    self.replay_batch_size,
                    beta=self.prioritized_replay_beta)
            return MultiAgentBatch(samples, self.replay_batch_size)