def replay(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> \
        Optional[SampleBatchType]:
    buffer = self.replay_buffers[policy_id]
    # Return None, if:
    # - Buffer empty or
    # - `replay_ratio` < 1.0 (new samples required in returned batch)
    #   and no new samples to mix with replayed ones.
    if len(buffer) == 0 or (
            len(self.last_added_batches[policy_id]) == 0
            and self.replay_ratio < 1.0):
        return None

    # Mix buffer's last added batches with older replayed batches.
    with self.replay_timer:
        output_batches = self.last_added_batches[policy_id]
        self.last_added_batches[policy_id] = []

        # No replay desired -> Return here.
        if self.replay_ratio == 0.0:
            return SampleBatch.concat_samples(output_batches)
        # Only replay desired -> Return a (replayed) sample from the
        # buffer.
        elif self.replay_ratio == 1.0:
            return buffer.replay()

        # Replay ratio = old / [old + new]
        # Replay proportion: old / new
        num_new = len(output_batches)
        replay_proportion = self.replay_proportion
        while random.random() < num_new * replay_proportion:
            replay_proportion -= 1
            output_batches.append(buffer.replay())
        return SampleBatch.concat_samples(output_batches)
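The while-loop at the end of replay() controls how many old batches get mixed in. A standalone simulation (the function name and trial count below are illustrative, not part of the buffer class) shows how many replayed batches that loop appends on average for a given num_new and replay_proportion:

import random

def simulate_mixing(num_new, replay_proportion, trials=100_000):
    """Average number of old (replayed) batches the loop appends."""
    total = 0
    for _ in range(trials):
        p = replay_proportion
        while random.random() < num_new * p:
            p -= 1
            total += 1
    return total / trials

# With one new batch and replay_proportion=2.5, roughly 2.5 old batches
# are mixed in on average (2 always, a third one with probability 0.5).
print(simulate_mixing(1, 2.5))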
def test_concat_max_seq_len(self):
    """Tests SampleBatch.concat_samples() w.r.t. max_seq_len."""
    s1 = SampleBatch({
        "a": np.array([1, 2, 3]),
        "b": {"c": np.array([4, 5, 6])},
        SampleBatch.SEQ_LENS: [1, 2],
    })
    s2 = SampleBatch({
        "a": np.array([2, 3, 4]),
        "b": {"c": np.array([5, 6, 7])},
        SampleBatch.SEQ_LENS: [3],
    })
    s3 = SampleBatch({
        "a": np.array([2, 3, 4]),
        "b": {"c": np.array([5, 6, 7])},
    })

    concatd = SampleBatch.concat_samples([s1, s2])
    check(concatd.max_seq_len, s2.max_seq_len)

    # Concatenating a batch without SEQ_LENS with batches that have them
    # must raise a ValueError.
    with self.assertRaises(ValueError):
        SampleBatch.concat_samples([s1, s2, s3])
def step(self):
    with self.update_weights_timer:
        if self.workers.remote_workers():
            weights = ray.put(self.workers.local_worker().get_weights())
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights)

    with self.sample_timer:
        samples = []
        while sum(s.count for s in samples) < self.train_batch_size:
            if self.workers.remote_workers():
                samples.extend(
                    ray_get_and_free([
                        e.sample.remote()
                        for e in self.workers.remote_workers()
                    ]))
            else:
                samples.append(self.workers.local_worker().sample())
        samples = SampleBatch.concat_samples(samples)
        self.sample_timer.push_units_processed(samples.count)

    # Unfortunate to have to hack it like this, but not sure how else to
    # do it. Setting the phase to zeros results in policy optimization,
    # and to ones results in aux optimization. These have to be added
    # prior to the policy sgd.
    samples["phase"] = np.zeros(samples.count)

    with self.grad_timer:
        fetches = do_minibatch_sgd(samples, self.policies,
                                   self.workers.local_worker(),
                                   self.num_sgd_iter,
                                   self.sgd_minibatch_size,
                                   self.standardize_fields)
    self.grad_timer.push_units_processed(samples.count)

    if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
        self.learner_stats = fetches[DEFAULT_POLICY_ID]
    else:
        self.learner_stats = fetches

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count

    if self.num_steps_sampled > self.aux_loss_start_after_num_steps:
        # Add samples to the memory to be provided to the aux loss.
        self._remove_unnecessary_data(samples)
        self.memory.append(samples)

        # Optionally run the aux optimization.
        if len(self.memory) >= self.aux_loss_every_k:
            samples = SampleBatch.concat_samples(self.memory)
            self._add_policy_logits(samples)
            # Ones indicate aux phase.
            samples["phase"] = np.ones_like(samples["phase"])
            do_minibatch_sgd(samples, self.policies,
                             self.workers.local_worker(),
                             self.aux_loss_num_sgd_iter,
                             self.sgd_minibatch_size, [])
            self.memory = []

    return self.learner_stats
def step(self):
    with self.update_weights_timer:
        if self.workers.remote_workers():
            weights = ray.put(self.workers.local_worker().get_weights())
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights)

    with self.sample_timer:
        samples = []
        while sum(s.count for s in samples) < self.train_batch_size:
            if self.workers.remote_workers():
                samples.extend(
                    ray_get_and_free([
                        e.sample.remote()
                        for e in self.workers.remote_workers()
                    ]))
            else:
                samples.append(self.workers.local_worker().sample())
        samples = SampleBatch.concat_samples(samples)
        self.sample_timer.push_units_processed(samples.count)

    # Handle everything as if multiagent
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                  samples.count)

    fetches = {}
    with self.grad_timer:
        for policy_id, policy in self.policies.items():
            if policy_id not in samples.policy_batches:
                continue

            batch = samples.policy_batches[policy_id]
            for field in self.standardize_fields:
                value = batch[field]
                standardized = (value - value.mean()) / max(
                    1e-4, value.std())
                batch[field] = standardized

            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                for minibatch in self._minibatches(batch):
                    batch_fetches = (
                        self.workers.local_worker().learn_on_batch(
                            MultiAgentBatch({policy_id: minibatch},
                                            minibatch.count)))[policy_id]
                    for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                        iter_extra_fetches[k].append(v)
                logger.debug("{} {}".format(i,
                                            _averaged(iter_extra_fetches)))
            fetches[policy_id] = _averaged(iter_extra_fetches)

    self.grad_timer.push_units_processed(samples.count)
    if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
        self.learner_stats = fetches[DEFAULT_POLICY_ID]
    else:
        self.learner_stats = fetches
    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
    return self.learner_stats
def estimate(
    self,
    batch: SampleBatchType,
) -> List[OffPolicyEstimate]:
    self.check_can_estimate_for(batch)
    estimates = []
    # Split data into train and test batches.
    for train_episodes, test_episodes in train_test_split(
        batch,
        self.train_test_split_val,
        self.k,
    ):

        # Train Q-function.
        if train_episodes:
            # Reinitialize model.
            self.model.reset()
            train_batch = SampleBatch.concat_samples(train_episodes)
            losses = self.train(train_batch)
            self.losses.append(losses)

        # Calculate doubly robust OPE estimates.
        for episode in test_episodes:
            rewards, old_prob = episode["rewards"], episode["action_prob"]
            new_prob = np.exp(self.action_log_likelihood(episode))

            v_old = 0.0
            v_new = 0.0
            q_values = self.model.estimate_q(episode[SampleBatch.OBS],
                                             episode[SampleBatch.ACTIONS])
            q_values = convert_to_numpy(q_values)

            all_actions = np.zeros(
                [episode.count, self.policy.action_space.n])
            all_actions[:] = np.arange(self.policy.action_space.n)
            # Two transposes required for torch.distributions to work.
            tmp_episode = episode.copy()
            tmp_episode[SampleBatch.ACTIONS] = all_actions.T
            action_probs = np.exp(
                self.action_log_likelihood(tmp_episode)).T
            v_values = self.model.estimate_v(episode[SampleBatch.OBS],
                                             action_probs)
            v_values = convert_to_numpy(v_values)

            for t in reversed(range(episode.count)):
                v_old = rewards[t] + self.gamma * v_old
                v_new = v_values[t] + (new_prob[t] / old_prob[t]) * (
                    rewards[t] + self.gamma * v_new - q_values[t])
            v_new = v_new.item()

            estimates.append(
                OffPolicyEstimate(
                    self.name,
                    {
                        "v_old": v_old,
                        "v_new": v_new,
                        "v_gain": v_new / max(1e-8, v_old),
                    },
                ))
    return estimates
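The inner `for t in reversed(...)` loop above is the doubly robust backup, V_t = V_hat(s_t) + (pi_new(a_t|s_t) / pi_old(a_t|s_t)) * (r_t + gamma * V_{t+1} - Q_hat(s_t, a_t)). A tiny self-contained numeric sketch of just that recursion, with all array values made up for illustration:

import numpy as np

gamma = 0.99
rewards = np.array([1.0, 0.0, 1.0])     # r_t (illustrative)
old_prob = np.array([0.5, 0.5, 0.5])    # behavior-policy action probs (illustrative)
new_prob = np.array([0.6, 0.4, 0.7])    # target-policy action probs (illustrative)
q_values = np.array([1.8, 0.9, 0.95])   # Q-hat(s_t, a_t) (illustrative)
v_values = np.array([1.7, 1.0, 0.9])    # V-hat(s_t) (illustrative)

v_new = 0.0
for t in reversed(range(len(rewards))):
    v_new = v_values[t] + (new_prob[t] / old_prob[t]) * (
        rewards[t] + gamma * v_new - q_values[t])

# Doubly robust estimate of the episode's start-state value under the
# target policy.
print(v_new)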
def inner_adaptation_steps(itr):
    buf = []
    split = []
    metrics = {}
    for samples in itr:
        # Processing Samples (Standardize Advantages)
        split_lst = []
        for sample in samples:
            sample["advantages"] = standardized(sample["advantages"])
            split_lst.append(sample.count)

        buf.extend(samples)
        split.append(split_lst)

        adapt_iter = len(split) - 1
        metrics = post_process_metrics(adapt_iter, workers, metrics)

        if len(split) > inner_steps:
            out = SampleBatch.concat_samples(buf)
            out["split"] = np.array(split)
            buf = []
            split = []

            # Reporting Adaptation Rew Diff
            ep_rew_pre = metrics["episode_reward_mean"]
            ep_rew_post = metrics["episode_reward_mean_adapt_" +
                                  str(inner_steps)]
            metrics["adaptation_delta"] = ep_rew_post - ep_rew_pre
            yield out, metrics
            metrics = {}
        else:
            inner_adaptation(workers, samples)
def collect_samples(agents, sample_batch_size, num_envs_per_worker,
                    train_batch_size):
    """Collects at least train_batch_size samples, never discarding any."""

    num_timesteps_so_far = 0
    trajectories = []
    agent_dict = {}

    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent

    while agent_dict:
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        next_sample = ray_get_and_free(fut_sample)
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)

        # Only launch more tasks if we don't already have enough pending.
        pending = len(agent_dict) * sample_batch_size * num_envs_per_worker
        if num_timesteps_so_far + pending < train_batch_size:
            fut_sample2 = agent.sample.remote()
            agent_dict[fut_sample2] = agent

    return SampleBatch.concat_samples(trajectories)
def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
    _check_sample_batch_type(batch)
    self.buffer.append(batch)

    if self.count_steps_by == "env_steps":
        self.count += batch.count
    else:
        assert isinstance(batch, MultiAgentBatch), \
            "`count_steps_by=agent_steps` only allowed in multi-agent " \
            "environments!"
        self.count += batch.agent_steps()

    if self.count >= self.min_batch_size:
        if self.count > self.min_batch_size * 2:
            logger.info("Collected more training samples than expected "
                        "(actual={}, expected={}). ".format(
                            self.count, self.min_batch_size) +
                        "This may be because you have many workers or "
                        "long episodes in 'complete_episodes' batch mode.")
        out = SampleBatch.concat_samples(self.buffer)
        timer = _get_shared_metrics().timers[SAMPLE_TIMER]
        timer.push(time.perf_counter() - self.batch_start_time)
        timer.push_units_processed(self.count)
        self.batch_start_time = None
        self.buffer = []
        self.count = 0
        return [out]
    return []
def test_sequence_size(self):
    # Seq-len=1.
    buffer = PrioritizedReplayBuffer(
        capacity=100, alpha=0.1, storage_unit="fragments"
    )
    for _ in range(200):
        buffer.add(self._generate_data())
    assert len(buffer._storage) == 100, len(buffer._storage)
    assert buffer.stats()["added_count"] == 200, buffer.stats()

    # Test get_state/set_state.
    state = buffer.get_state()
    new_memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
    new_memory.set_state(state)
    assert len(new_memory._storage) == 100, len(new_memory._storage)
    assert new_memory.stats()["added_count"] == 200, new_memory.stats()

    # Seq-len=5.
    buffer = PrioritizedReplayBuffer(
        capacity=100, alpha=0.1, storage_unit="fragments"
    )
    for _ in range(40):
        buffer.add(
            SampleBatch.concat_samples(
                [self._generate_data() for _ in range(5)])
        )
    assert len(buffer._storage) == 20, len(buffer._storage)
    assert buffer.stats()["added_count"] == 200, buffer.stats()

    # Test get_state/set_state.
    state = buffer.get_state()
    new_memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
    new_memory.set_state(state)
    assert len(new_memory._storage) == 20, len(new_memory._storage)
    assert new_memory.stats()["added_count"] == 200, new_memory.stats()
def postprocess_with_HER(policy,
                         sample_batch,
                         _other_agent_batches=None,
                         _episode=None):
    """Postprocesses the sampled batch, injecting trajectories with
    modified goal conditions (Hindsight Experience Replay)."""

    # Hindsight Experience Replay trajectory augmentation.
    if type(sample_batch) is SampleBatch:
        # Init list of new trajectories.
        augmented_trajs = [sample_batch]
        # Init HER sampling strategy.
        her_sampler = SamplingStrategy(policy, sample_batch)
        # Sample n new trajectories using the sampling strategy.
        for i in range(policy.config['num_her_traj']):
            augmented_trajs.append(her_sampler.sample_trajectory())
        # Concatenate sampled trajectories.
        sample_batch = SampleBatch.concat_samples(augmented_trajs)

    # RLlib original DQN postprocess_fn implementation.
    sample_batch = postprocess_nstep_and_prio(policy, sample_batch,
                                              _other_agent_batches,
                                              _episode)
    return sample_batch
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            batch = SampleBatch.concat_samples(
                ray_get_and_free(
                    [e.sample.remote() for e in self.remote_evaluators]))
        else:
            batch = self.local_evaluator.sample()

        # Handle everything as if multiagent
        if isinstance(batch, SampleBatch):
            batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                    batch.count)

        for policy_id, s in batch.policy_batches.items():
            for row in s.rows():
                self.replay_buffers[policy_id].add(
                    pack_if_needed(row["obs"]),
                    row["actions"],
                    row["rewards"],
                    pack_if_needed(row["new_obs"]),
                    row["dones"],
                    weight=None)

    if self.num_steps_sampled >= self.replay_starts:
        self._optimize()

    self.num_steps_sampled += batch.count
def sample_min_n_steps_from_buffer(
        replay_buffer: ReplayBuffer, min_steps: int,
        count_by_agent_steps: bool) -> Optional[SampleBatchType]:
    """Samples a minimum of n timesteps from a given replay buffer.

    This utility method is primarily used by the QMIX algorithm and helps
    with sampling a given number of timesteps from a buffer that stores
    samples in units of sequences or complete episodes. It samples batches
    from the replay buffer until the total number of timesteps reaches
    `min_steps`.

    Args:
        replay_buffer: The replay buffer to sample from.
        min_steps: The minimum number of timesteps to sample.
        count_by_agent_steps: Whether to count agent steps or env steps.

    Returns:
        A concatenated SampleBatch or MultiAgentBatch with samples from the
        buffer.
    """
    train_batch_size = 0
    train_batches = []
    while train_batch_size < min_steps:
        batch = replay_buffer.sample(num_items=1)
        batch_len = batch.agent_steps(
        ) if count_by_agent_steps else batch.env_steps()
        if batch_len == 0:
            # Replay has not started, so we can't accumulate timesteps here.
            return batch
        train_batches.append(batch)
        train_batch_size += batch_len
    # All batches are of the same type, so any class's concat_samples()
    # can be used.
    train_batch = SampleBatch.concat_samples(train_batches)
    return train_batch
def test_sequence_size(self):
    # Seq-len=1.
    memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
    for _ in range(200):
        memory.add(self._generate_data(), weight=None)
    assert len(memory._storage) == 100, len(memory._storage)
    assert memory.stats()["added_count"] == 200, memory.stats()

    # Test get_state/set_state.
    state = memory.get_state()
    new_memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
    new_memory.set_state(state)
    assert len(new_memory._storage) == 100, len(new_memory._storage)
    assert new_memory.stats()["added_count"] == 200, new_memory.stats()

    # Seq-len=5.
    memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
    for _ in range(40):
        memory.add(
            SampleBatch.concat_samples(
                [self._generate_data() for _ in range(5)]),
            weight=None,
        )
    assert len(memory._storage) == 20, len(memory._storage)
    assert memory.stats()["added_count"] == 200, memory.stats()

    # Test get_state/set_state.
    state = memory.get_state()
    new_memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
    new_memory.set_state(state)
    assert len(new_memory._storage) == 20, len(new_memory._storage)
    assert new_memory.stats()["added_count"] == 200, new_memory.stats()
def _add_sample_batch_to_buffer(self, buffer, batch_size, num_batches=5,
                                **kwargs):
    self.eps_id = 0

    def _generate_data():
        self.eps_id += 1
        return SampleBatch({
            SampleBatch.T: [0, 1],
            SampleBatch.ACTIONS: 2 * [np.random.choice([0, 1])],
            SampleBatch.REWARDS: 2 * [np.random.rand()],
            SampleBatch.OBS: 2 * [np.random.random((4, ))],
            SampleBatch.NEXT_OBS: 2 * [np.random.random((4, ))],
            SampleBatch.DONES: 2 * [np.random.choice([False, True])],
            SampleBatch.EPS_ID: 2 * [self.eps_id],
            SampleBatch.AGENT_INDEX: 2 * [0],
            "batch_id": 2 * [self.batch_id],
        })

    for i in range(num_batches):
        data = [_generate_data() for _ in range(batch_size)]
        self.batch_id += 1
        batch = SampleBatch.concat_samples(data)
        buffer.add(batch, **kwargs)
def postprocess_with_HER(policy,
                         sample_batch,
                         _other_agent_batches=None,
                         _episode=None):
    """Postprocesses the sampled batch, injecting trajectories with
    modified goal conditions (Hindsight Experience Replay)."""
    import numpy as np

    # Hindsight Experience Replay trajectory augmentation.
    if (type(sample_batch) is SampleBatch) and (
            policy.config['use_HER']) and (
                sample_batch['obs'].shape[0] > 0):
        # Init list of new trajectories.
        augmented_trajs = [sample_batch]
        # Init HER sampling strategy.
        her_sampler = SamplingStrategy(policy, sample_batch)
        # Sample n new trajectories using the sampling strategy.
        for i in range(policy.config['num_HER_traj']):
            augmented_trajs.append(her_sampler.sample_trajectory())
        # Concatenate sampled trajectories.
        sample_batch = SampleBatch.concat_samples(augmented_trajs)

    # Original postprocess_fn implementation.
    sample_batch = postprocess_fn(policy, sample_batch,
                                  _other_agent_batches, _episode)
    # code.interact(local=locals())
    return sample_batch
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        samples = []
        while sum(s.count for s in samples) < self.train_batch_size:
            if self.remote_evaluators:
                samples.extend(
                    ray_get_and_free([
                        e.sample.remote() for e in self.remote_evaluators
                    ]))
            else:
                samples.append(self.local_evaluator.sample())
        samples = SampleBatch.concat_samples(samples)
        self.sample_timer.push_units_processed(samples.count)

    with self.grad_timer:
        for i in range(self.num_sgd_iter):
            fetches = self.local_evaluator.learn_on_batch(samples)
            self.learner_stats = get_learner_stats(fetches)
            if self.num_sgd_iter > 1:
                logger.debug("{} {}".format(i, fetches))
        self.grad_timer.push_units_processed(samples.count)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
    return self.learner_stats
def inner_adaptation_steps(itr):
    buf = []
    split = []
    metrics = {}
    for samples in itr:
        print("Collecting Samples, Inner Adaptation {}".format(len(split)))
        # Processing Samples (Standardize Advantages)
        samples, split_lst = post_process_samples(samples, config)

        buf.extend(samples)
        split.append(split_lst)

        adapt_iter = len(split) - 1
        prefix = "DynaTrajInner_" + str(adapt_iter)
        metrics = post_process_metrics(prefix, workers, metrics)

        if len(split) > num_inner_steps:
            out = SampleBatch.concat_samples(buf)
            out["split"] = np.array(split)
            buf = []
            split = []

            yield out, metrics
            metrics = {}
        else:
            inner_adaptation(workers, samples)
def collect_samples_straggler_mitigation(agents, train_batch_size):
    """Collects at least train_batch_size samples.

    This is the legacy behavior as of 0.6, and launches extra sample tasks
    to potentially improve performance but can result in many wasted
    samples.
    """

    num_timesteps_so_far = 0
    trajectories = []
    agent_dict = {}

    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent

    while num_timesteps_so_far < train_batch_size:
        # TODO(pcm): Make wait support arbitrary iterators and remove the
        # conversion to list here.
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        # Start task with next trajectory and record it in the dictionary.
        fut_sample2 = agent.sample.remote()
        agent_dict[fut_sample2] = agent

        next_sample = ray_get_and_free(fut_sample)
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)

    logger.info("Discarding {} sample tasks".format(len(agent_dict)))
    return SampleBatch.concat_samples(trajectories)
def ParallelRollouts(workers: WorkerSet,
                     mode="bulk_sync") -> LocalIterator[SampleBatch]:
    """Operator to collect experiences in parallel from rollout workers.

    If there are no remote workers, experiences will be collected serially
    from the local worker instance instead.

    Arguments:
        workers (WorkerSet): set of rollout workers to use.
        mode (str): One of {'async', 'bulk_sync'}.
            - In 'async' mode, batches are returned as soon as they are
              computed by rollout workers with no order guarantees.
            - In 'bulk_sync' mode, we collect one batch from each worker
              and concatenate them together into a large batch to return.

    Returns:
        A local iterator over experiences collected in parallel.

    Examples:
        >>> rollouts = ParallelRollouts(workers, mode="async")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        50  # config.sample_batch_size

        >>> rollouts = ParallelRollouts(workers, mode="bulk_sync")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        200  # config.sample_batch_size * config.num_workers

    Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context.
    """

    def report_timesteps(batch):
        metrics = LocalIterator.get_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    if not workers.remote_workers():
        # Handle the serial sampling case.
        def sampler(_):
            while True:
                yield workers.local_worker().sample()

        return (LocalIterator(sampler,
                              MetricsContext()).for_each(report_timesteps))

    # Create a parallel iterator over generated experiences.
    rollouts = from_actors(workers.remote_workers())

    if mode == "bulk_sync":
        return rollouts \
            .batch_across_shards() \
            .for_each(lambda batches: SampleBatch.concat_samples(batches)) \
            .for_each(report_timesteps)
    elif mode == "async":
        return rollouts.gather_async().for_each(report_timesteps)
    else:
        raise ValueError(
            "mode must be one of 'bulk_sync', 'async', got '{}'".format(
                mode))
def _sgd_step(self):
    samples = [random.choice(self.replay_buffer)]
    while sum(s.count for s in samples) < self.train_batch_size:
        samples.append(random.choice(self.replay_buffer))
    samples = SampleBatch.concat_samples(samples)

    info_dict = self.workers.local_worker().learn_on_batch(samples)
    for policy_id, info in info_dict.items():
        self.learner_stats[policy_id] = get_learner_stats(info)
    self.num_steps_trained += samples.count
    return info_dict
def test_concat(self):
    b1 = SampleBatch({"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])})
    b2 = SampleBatch({"a": np.array([1]), "b": np.array([4])})
    b3 = SampleBatch({"a": np.array([1]), "b": np.array([5])})

    b12 = b1.concat(b2)
    self.assertEqual(b12["a"].tolist(), [1, 2, 3, 1])
    self.assertEqual(b12["b"].tolist(), [4, 5, 6, 4])

    b = SampleBatch.concat_samples([b1, b2, b3])
    self.assertEqual(b["a"].tolist(), [1, 2, 3, 1, 1])
    self.assertEqual(b["b"].tolist(), [4, 5, 6, 4, 5])
def mix_batches(_policy_id):
    """Mixes old with new samples.

    Tries to mix according to self.replay_ratio on average.
    If not enough new samples are available, mixes in fewer old samples
    to retain self.replay_ratio on average.
    """

    def round_up_or_down(value, ratio):
        """Returns an integer averaging to value*ratio."""
        product = value * ratio
        ceil_prob = product % 1
        if random.uniform(0, 1) < ceil_prob:
            return int(np.ceil(product))
        else:
            return int(np.floor(product))

    max_num_new = round_up_or_down(num_items, 1 - self.replay_ratio)
    # If num_items * (1 - self.replay_ratio) is not a whole number, one
    # extra new sample is taken with probability equal to its fractional
    # part.
    _buffer = self.replay_buffers[_policy_id]
    output_batches = self.last_added_batches[_policy_id][:max_num_new]
    self.last_added_batches[_policy_id] = self.last_added_batches[_policy_id][
        max_num_new:
    ]

    # No replay desired.
    if self.replay_ratio == 0.0:
        return SampleBatch.concat_samples(output_batches)
    # Only replay desired.
    elif self.replay_ratio == 1.0:
        return _buffer.sample(num_items, **kwargs)

    num_new = len(output_batches)

    if np.isclose(num_new, num_items * (1 - self.replay_ratio)):
        # The optimal case: we can mix in a round number of old
        # samples on average.
        num_old = num_items - max_num_new
    else:
        # We never want to return more elements than num_items.
        num_old = min(
            num_items - max_num_new,
            round_up_or_down(
                num_new, self.replay_ratio / (1 - self.replay_ratio)
            ),
        )

    output_batches.append(_buffer.sample(num_old, **kwargs))
    # Depending on the implementation of the underlying buffers, samples
    # might be SampleBatches.
    output_batches = [batch.as_multi_agent() for batch in output_batches]
    return MultiAgentBatch.concat_samples(output_batches)
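round_up_or_down is the piece that makes the mixing hit replay_ratio in expectation: stochastic rounding returns an integer whose mean is exactly value * ratio. A self-contained sketch (the demonstration loop and names are illustrative, the helper body is the one from the snippet above) shows that property:

import random
import numpy as np

def round_up_or_down(value, ratio):
    """Returns an integer averaging to value * ratio."""
    product = value * ratio
    ceil_prob = product % 1
    if random.uniform(0, 1) < ceil_prob:
        return int(np.ceil(product))
    return int(np.floor(product))

# With value=10 and ratio=0.25 the result is 2 or 3, and the empirical
# mean over many draws approaches 10 * 0.25 = 2.5.
draws = [round_up_or_down(10, 0.25) for _ in range(100_000)]
print(np.mean(draws))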
def replay(
    self, policy_id: PolicyID = DEFAULT_POLICY_ID
) -> Optional[SampleBatchType]:
    if self.replay_mode == ReplayMode.LOCKSTEP and policy_id != _ALL_POLICIES:
        raise ValueError(
            "Trying to sample from single policy's buffer in lockstep "
            "mode. In lockstep mode, all policies' experiences are "
            "sampled from a single replay buffer which is accessed "
            "with the policy id `{}`".format(_ALL_POLICIES)
        )

    buffer = self.replay_buffers[policy_id]
    # Return None, if:
    # - Buffer empty or
    # - `replay_ratio` < 1.0 (new samples required in returned batch)
    #   and no new samples to mix with replayed ones.
    if len(buffer) == 0 or (
        len(self.last_added_batches[policy_id]) == 0 and self.replay_ratio < 1.0
    ):
        return None

    # Mix buffer's last added batches with older replayed batches.
    with self.replay_timer:
        output_batches = self.last_added_batches[policy_id]
        self.last_added_batches[policy_id] = []

        # No replay desired -> Return here.
        if self.replay_ratio == 0.0:
            return SampleBatch.concat_samples(output_batches)
        # Only replay desired -> Return a (replayed) sample from the
        # buffer.
        elif self.replay_ratio == 1.0:
            return buffer.replay()

        # Replay ratio = old / [old + new]
        # Replay proportion: old / new
        num_new = len(output_batches)
        replay_proportion = self.replay_proportion
        while random.random() < num_new * replay_proportion:
            replay_proportion -= 1
            output_batches.append(buffer.replay())
        return SampleBatch.concat_samples(output_batches)
def _optimize(self):
    samples = [random.choice(self.replay_buffer)]
    while sum(s.count for s in samples) < self.train_batch_size:
        samples.append(random.choice(self.replay_buffer))
    samples = SampleBatch.concat_samples(samples)

    with self.grad_timer:
        info_dict = self.local_evaluator.learn_on_batch(samples)
        for policy_id, info in info_dict.items():
            self.learner_stats[policy_id] = get_learner_stats(info)
        self.grad_timer.push_units_processed(samples.count)

    self.num_steps_trained += samples.count
    return info_dict
def __call__(self, batch: SampleBatch) -> List[SampleBatch]:
    if not isinstance(batch, SampleBatch):
        raise ValueError("Expected type SampleBatch, got {}: {}".format(
            type(batch), batch))
    self.buffer.append(batch)
    self.count += batch.count
    if self.count >= self.min_batch_size:
        out = SampleBatch.concat_samples(self.buffer)
        self.buffer = []
        self.count = 0
        return [out]
    return []
def mix_batches(_policy_id):
    _buffer = self.replay_buffers[policy_id]
    output_batches = self.last_added_batches[_policy_id]
    self.last_added_batches[_policy_id] = []

    # No replay desired.
    if self.replay_ratio == 0.0:
        return SampleBatch.concat_samples(output_batches)
    # Only replay desired.
    elif self.replay_ratio == 1.0:
        return _buffer.sample(num_items,
                              beta=self.prioritized_replay_beta)

    # Replay ratio = old / [old + new]
    # Replay proportion: old / new
    num_new = len(output_batches)
    replay_proportion = self.replay_proportion
    while random.random() < num_new * replay_proportion:
        replay_proportion -= 1
        output_batches.append(_buffer.sample(num_items))
    return SampleBatch.concat_samples(output_batches)
def training_step(self) -> ResultDict:
    """TODO:

    Returns:
        The results dict from executing the training iteration.
    """
    # Sample n MultiAgentBatches from n workers.
    new_sample_batches = synchronous_parallel_sample(
        worker_set=self.workers, concat=False)

    for batch in new_sample_batches:
        # Update sampling step counters.
        self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
        self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
        # Store new samples in the replay buffer.
        # Use deprecated add_batch() to support old replay buffers for now.
        if self.local_replay_buffer is not None:
            self.local_replay_buffer.add(batch)

    if self.local_replay_buffer is not None:
        train_batch = self.local_replay_buffer.sample(
            self.config["train_batch_size"])
    else:
        train_batch = SampleBatch.concat_samples(new_sample_batches)

    # Learn on the training batch.
    # Use simple optimizer (only for multi-agent or tf-eager; all other
    # cases should use the multi-GPU optimizer, even if only using 1 GPU).
    train_results = {}
    if train_batch is not None:
        if self.config.get("simple_optimizer") is True:
            train_results = train_one_step(self, train_batch)
        else:
            train_results = multi_gpu_train_one_step(self, train_batch)

    # TODO: Move training steps counter update outside of
    # `train_one_step()` method.
    # # Update train step counters.
    # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
    # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    # Update weights and global_vars - after learning on the local worker -
    # on all remote workers.
    global_vars = {
        "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
    }
    with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
        self.workers.sync_weights(global_vars=global_vars)

    # Return all collected metrics for the iteration.
    return train_results
def get_cross_policy_object(multi_agent_batch, self_optimizer):
    """Adds contents to the cross_policy_object, which is passed to each policy."""
    config = self_optimizer.workers._remote_config

    if not config["use_joint_dataset"]:
        joint_obs = SampleBatch.concat_samples(
            list(multi_agent_batch.policy_batches.values()))[
                SampleBatch.CUR_OBS]
    else:
        sample_size = config.get("joint_dataset_sample_batch_size")
        assert sample_size is not None, \
            "You should specify the value of joint_dataset_sample_batch_size " \
            "in the config!"
        samples = [multi_agent_batch]
        count_dict = {
            k: v.count
            for k, v in multi_agent_batch.policy_batches.items()
        }
        for k in self_optimizer.workers.local_worker().policy_map.keys():
            if k not in count_dict:
                count_dict[k] = 0
        while any([v < sample_size for v in count_dict.values()]):
            tmp_batch = self_optimizer.workers.local_worker().sample()
            samples.append(tmp_batch)
            for k, v in tmp_batch.policy_batches.items():
                assert k in count_dict, count_dict
                count_dict[k] += v.count
        multi_agent_batch = MultiAgentBatch.concat_samples(samples)

        joint_obs = []
        pid_list = []
        for pid, batch in multi_agent_batch.policy_batches.items():
            batch.shuffle()
            assert batch.count >= sample_size, batch
            joint_obs.append(batch.slice(0, sample_size)['obs'])
            pid_list.append(pid)
        joint_obs = np.concatenate(joint_obs)

    def _replay(policy, pid):
        act, _, infos = policy.compute_actions(joint_obs)
        return pid, act, infos

    # ATTENTION!!! Here each policy itself replays the JOINT OBSERVATION.
    ret = {
        pid: act
        for pid, act, infos in
        self_optimizer.workers.local_worker().foreach_policy(_replay)
    }
    return {JOINT_OBS: joint_obs, PEER_ACTION: ret}
def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
    _check_sample_batch_type(batch)
    self.buffer.append(batch)
    self.count += batch.count
    if self.count >= self.min_batch_size:
        out = SampleBatch.concat_samples(self.buffer)
        timer = LocalIterator.get_metrics().timers[SAMPLE_TIMER]
        timer.push(time.perf_counter() - self.batch_start_time)
        timer.push_units_processed(self.count)
        self.batch_start_time = None
        self.buffer = []
        self.count = 0
        return [out]
    return []
def training_workflow(config, reporter):
    # Setup policy and policy evaluation actors.
    env = gym.make("CartPole-v0")
    policy = CustomPolicy(env.observation_space, env.action_space, {})
    workers = [
        RolloutWorker.as_remote().remote(
            env_creator=lambda c: gym.make("CartPole-v0"),
            policy=CustomPolicy) for _ in range(config["num_workers"])
    ]

    for _ in range(config["num_iters"]):
        # Broadcast weights to the policy evaluation workers.
        weights = ray.put({DEFAULT_POLICY_ID: policy.get_weights()})
        for w in workers:
            w.set_weights.remote(weights)

        # Gather a batch of samples.
        T1 = SampleBatch.concat_samples(
            ray.get([w.sample.remote() for w in workers]))

        # Update the remote policy replicas before gathering another batch
        # of samples.
        new_value = policy.w * 2.0
        for w in workers:
            w.for_policy.remote(lambda p: p.update_some_value(new_value))

        # Gather another batch of samples.
        T2 = SampleBatch.concat_samples(
            ray.get([w.sample.remote() for w in workers]))

        # Improve the policy using the T1 batch.
        policy.learn_on_batch(T1)

        # Do some arbitrary updates based on the T2 batch.
        policy.update_some_value(sum(T2["rewards"]))

        reporter(**collect_metrics(remote_workers=workers))