def collect_samples(agents, sample_batch_size, num_envs_per_worker,
                    train_batch_size):
    """Collects at least train_batch_size samples, never discarding any."""

    num_timesteps_so_far = 0
    trajectories = []
    agent_dict = {}

    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent

    while agent_dict:
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        next_sample = ray.get(fut_sample)
        assert next_sample.count >= sample_batch_size * num_envs_per_worker
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)

        # Only launch more tasks if we don't already have enough pending
        pending = len(agent_dict) * sample_batch_size * num_envs_per_worker
        if num_timesteps_so_far + pending < train_batch_size:
            fut_sample2 = agent.sample.remote()
            agent_dict[fut_sample2] = agent

    return SampleBatch.concat_samples(trajectories)
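# --- Illustrative sketch (not part of the original module) ---
# A minimal, runnable example of the ray.wait() pattern used by
# collect_samples above: launch one task per worker, then repeatedly take
# whichever result finishes first and decide whether a replacement task is
# still needed. `fake_rollout` and the fixed 100-step count are hypothetical
# stand-ins for a real evaluator's sample() task.
import random
import time

import ray


@ray.remote
def fake_rollout(worker_id):
    # Simulate rollout latency jitter across workers.
    time.sleep(random.uniform(0.01, 0.1))
    return worker_id, 100  # (who produced it, timestep count)


def demo_collect(num_workers=4, target_timesteps=1000):
    pending = {fake_rollout.remote(i): i for i in range(num_workers)}
    collected = 0
    while pending:
        [ready], _ = ray.wait(list(pending))
        worker_id = pending.pop(ready)
        _, count = ray.get(ready)
        collected += count
        # Relaunch only while the in-flight budget still falls short of
        # the target, mirroring the "pending" check above.
        if collected + len(pending) * 100 < target_timesteps:
            pending[fake_rollout.remote(worker_id)] = worker_id
    return collected


# Usage: ray.init(); print(demo_collect())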
def collect_samples_straggler_mitigation(agents, train_batch_size):
    """Collects at least train_batch_size samples.

    This is the legacy behavior as of 0.6, and launches extra sample tasks
    to potentially improve performance but can result in many wasted samples.
    """

    num_timesteps_so_far = 0
    trajectories = []
    agent_dict = {}

    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent

    while num_timesteps_so_far < train_batch_size:
        # TODO(pcm): Make wait support arbitrary iterators and remove the
        # conversion to list here.
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        # Start task with next trajectory and record it in the dictionary.
        fut_sample2 = agent.sample.remote()
        agent_dict[fut_sample2] = agent
        next_sample = ray.get(fut_sample)
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)

    logger.info("Discarding {} sample tasks".format(len(agent_dict)))
    return SampleBatch.concat_samples(trajectories)
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            samples = SampleBatch.concat_samples(
                ray.get(
                    [e.sample.remote() for e in self.remote_evaluators]))
        else:
            samples = self.local_evaluator.sample()

    with self.grad_timer:
        for i in range(self.num_sgd_iter):
            fetches = self.local_evaluator.compute_apply(samples)
            if self.num_sgd_iter > 1:
                print(i, fetches)
        self.grad_timer.push_units_processed(samples.count)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
    return fetches
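# --- Illustrative sketch (not part of the original module) ---
# The update_weights_timer / sample_timer / grad_timer objects above are
# used both as context managers and as throughput counters. This minimal
# stand-in, assuming a TimerStat-like interface, shows how such a timer
# could work; the real class may differ.
import time


class SimpleTimer(object):
    def __init__(self):
        self.total_time = 0.0
        self.units_processed = 0

    def __enter__(self):
        self._start = time.time()
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.total_time += time.time() - self._start

    def push_units_processed(self, n):
        self.units_processed += n

    @property
    def mean_throughput(self):
        # Units per second across all timed sections.
        return self.units_processed / max(self.total_time, 1e-8)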
def collect_samples(agents, timesteps_per_batch):
    num_timesteps_so_far = 0
    trajectories = []
    # This variable maps the object IDs of trajectories that are currently
    # computed to the agent that they are computed on; we start some initial
    # tasks here.
    agent_dict = {}

    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent

    while num_timesteps_so_far < timesteps_per_batch:
        # TODO(pcm): Make wait support arbitrary iterators and remove the
        # conversion to list here.
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        # Start task with next trajectory and record it in the dictionary.
        fut_sample2 = agent.sample.remote()
        agent_dict[fut_sample2] = agent
        next_sample = ray.get(fut_sample)
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)

    return SampleBatch.concat_samples(trajectories)
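# --- Illustrative sketch (not part of the original module) ---
# collect_samples() relies on two SampleBatch properties: a `count` of
# timesteps and column-wise concatenation. This toy stand-in (assumed
# interface, not RLlib's actual class) makes those semantics concrete.
import numpy as np


class ToySampleBatch(object):
    def __init__(self, **columns):
        self.data = columns
        self.count = len(next(iter(columns.values())))

    @staticmethod
    def concat_samples(batches):
        keys = batches[0].data.keys()
        return ToySampleBatch(**{
            k: np.concatenate([b.data[k] for b in batches])
            for k in keys
        })


# Two 3-step rollouts concatenate into one 6-step training batch.
a = ToySampleBatch(obs=np.zeros(3), rewards=np.ones(3))
b = ToySampleBatch(obs=np.ones(3), rewards=np.zeros(3))
assert ToySampleBatch.concat_samples([a, b]).count == 6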
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            batch = SampleBatch.concat_samples(
                ray.get(
                    [e.sample.remote() for e in self.remote_evaluators]))
        else:
            batch = self.local_evaluator.sample()

        # Handle everything as if multiagent
        if isinstance(batch, SampleBatch):
            batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)

        for policy_id, s in batch.policy_batches.items():
            for row in s.rows():
                self.replay_buffers[policy_id].add(
                    pack_if_needed(row["obs"]),
                    row["actions"],
                    row["rewards"],
                    pack_if_needed(row["new_obs"]),
                    row["dones"],
                    weight=None)

    if self.num_steps_sampled >= self.replay_starts:
        self._optimize()
    self.num_steps_sampled += batch.count
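# --- Illustrative sketch (not part of the original module) ---
# The per-policy replay_buffers above are fed one transition row at a time.
# A minimal uniform replay buffer with the same add() shape, assuming no
# prioritization (the real buffer may weight and prioritize samples).
import random
from collections import deque


class UniformReplayBuffer(object):
    def __init__(self, capacity):
        # Oldest rows are evicted first once capacity is reached.
        self._storage = deque(maxlen=capacity)

    def add(self, obs, action, reward, new_obs, done, weight=None):
        self._storage.append((obs, action, reward, new_obs, done))

    def sample(self, batch_size):
        # Copy to a list so random.sample gets O(1) indexing.
        return random.sample(list(self._storage), batch_size)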
def update_episode_buffer(self, samples):
    if self.config["episode_mode"] == "episode_buffer":
        new_batches = list(separate_sample_batch(samples).values())
        if self._episode_buffer:
            old_batches = \
                list(separate_sample_batch(self._episode_buffer).values())
        else:
            old_batches = []
        # Keep the highest-return episodes, up to buffer_size of them.
        all_batches = old_batches + new_batches
        all_batches = sorted(
            all_batches, key=lambda x: x["rewards"].sum(), reverse=True)
        buffer_size = self.config["buffer_size"]
        self._episode_buffer = SampleBatch.concat_samples(
            all_batches[:buffer_size])
    elif self.config["episode_mode"] == "last_episodes":
        self._episode_buffer = samples
    elif self.config["episode_mode"] == "all_episodes":
        if self._episode_buffer:
            self._episode_buffer = self._episode_buffer.concat(samples)
        else:
            self._episode_buffer = samples
    else:
        raise NotImplementedError
    self._rnn_state_out = None
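# --- Illustrative sketch (not part of the original module) ---
# separate_sample_batch() is assumed to split a batch into one batch per
# episode, keyed by episode id. A plausible reconstruction of that helper
# over a dict-like, column-oriented batch with an "eps_id" column
# (hypothetical implementation; the real helper may differ):
import numpy as np


def separate_sample_batch_sketch(batch):
    out = {}
    eps_ids = batch["eps_id"]
    for eps_id in np.unique(eps_ids):
        mask = eps_ids == eps_id
        # Slice every column by the episode's row mask.
        out[eps_id] = {k: v[mask] for k, v in batch.items()}
    return out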
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        samples = []
        while sum(s.count for s in samples) < self.train_batch_size:
            if self.remote_evaluators:
                samples.extend(
                    ray.get([
                        e.sample.remote() for e in self.remote_evaluators
                    ]))
            else:
                samples.append(self.local_evaluator.sample())
        samples = SampleBatch.concat_samples(samples)
        self.sample_timer.push_units_processed(samples.count)

    with self.grad_timer:
        for i in range(self.num_sgd_iter):
            fetches = self.local_evaluator.learn_on_batch(samples)
            if "stats" in fetches:
                self.learner_stats = fetches["stats"]
            if self.num_sgd_iter > 1:
                logger.debug("{} {}".format(i, fetches))
        self.grad_timer.push_units_processed(samples.count)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
    return fetches
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        samples = []
        while sum(s.count for s in samples) < self.train_batch_size:
            if self.remote_evaluators:
                samples.extend(
                    ray.get([
                        e.sample.remote() for e in self.remote_evaluators
                    ]))
            else:
                samples.append(self.local_evaluator.sample())
        samples = SampleBatch.concat_samples(samples)
        self.sample_timer.push_units_processed(samples.count)

    with self.grad_timer:
        for i in range(self.num_sgd_iter):
            fetches = self.local_evaluator.compute_apply(samples)
            if "stats" in fetches:
                self.learner_stats = fetches["stats"]
            if self.num_sgd_iter > 1:
                logger.debug("{} {}".format(i, fetches))
        self.grad_timer.push_units_processed(samples.count)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
    return fetches
def postprocess_trajectory(samples, baseline, gamma, lambda_, use_gae):
    separated_samples = separate_sample_batch(samples)
    baseline.fit(separated_samples.values())
    for eps_id, values in separated_samples.items():
        values["vf_preds"] = baseline.predict(values)
        separated_samples[eps_id] = compute_advantages(
            values, 0.0, gamma, lambda_, use_gae)
    samples = SampleBatch.concat_samples(list(separated_samples.values()))
    return samples
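# --- Illustrative sketch (not part of the original module) ---
# compute_advantages() above is assumed to implement generalized advantage
# estimation (GAE): with TD errors
#     delta_t = r_t + gamma * V(s_{t+1}) - V(s_t),
# the advantage is the discounted sum
#     A_t = sum_l (gamma * lambda)^l * delta_{t+l}.
# A minimal reference implementation over one episode, bootstrapping with
# 0.0 after the final step (matching the 0.0 passed above):
import numpy as np


def gae_sketch(rewards, vf_preds, gamma, lambda_):
    # One value prediction per step; terminal bootstrap value is 0.
    values = np.append(vf_preds, 0.0)
    deltas = rewards + gamma * values[1:] - values[:-1]
    advantages = np.zeros_like(rewards, dtype=np.float64)
    running = 0.0
    # Accumulate the discounted TD errors from the back of the episode.
    for t in reversed(range(len(rewards))):
        running = deltas[t] + gamma * lambda_ * running
        advantages[t] = running
    return advantages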
def _optimize(self):
    samples = [random.choice(self.replay_buffer)]
    while sum(s.count for s in samples) < self.train_batch_size:
        samples.append(random.choice(self.replay_buffer))
    samples = SampleBatch.concat_samples(samples)
    with self.grad_timer:
        info_dict = self.local_evaluator.compute_apply(samples)
        for policy_id, info in info_dict.items():
            if "stats" in info:
                self.learner_stats[policy_id] = info["stats"]
        self.grad_timer.push_units_processed(samples.count)
    self.num_steps_trained += samples.count
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        samples = []
        while sum(s.count for s in samples) < self.train_batch_size:
            if self.remote_evaluators:
                samples.extend(
                    ray.get([
                        e.sample.remote() for e in self.remote_evaluators
                    ]))
            else:
                samples.append(self.local_evaluator.sample())
        samples = SampleBatch.concat_samples(samples)
        self.sample_timer.push_units_processed(samples.count)

    with self.grad_timer:
        # Let the environment model transform the sampled batch before SGD.
        new_samples = self.env_model.process(samples)
        for i in range(self.num_sgd_iter):
            fetches = self.local_evaluator.compute_apply(new_samples)
            if "stats" in fetches:
                self.learner_stats = fetches["stats"]
            if self.num_sgd_iter > 1:
                print(i, fetches)
        # Count the original (not model-processed) samples as units of work.
        self.grad_timer.push_units_processed(len(samples["obs"]))

    self.num_steps_sampled += len(samples["obs"])
    self.num_steps_trained += len(samples["obs"])
    return fetches
def _postprocess_if_needed(self, batch):
    if not self.ioctx.config.get("postprocess_inputs"):
        return batch

    if isinstance(batch, SampleBatch):
        out = []
        for sub_batch in batch.split_by_episode():
            out.append(self.ioctx.evaluator.policy_map[DEFAULT_POLICY_ID]
                       .postprocess_trajectory(sub_batch))
        return SampleBatch.concat_samples(out)
    else:
        # TODO(ekl) this is trickier since the alignments between agent
        # trajectories in the episode are not available any more.
        raise NotImplementedError(
            "Postprocessing of multi-agent data not implemented yet.")
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            samples = SampleBatch.concat_samples(
                ray.get(
                    [e.sample.remote() for e in self.remote_evaluators]))
        else:
            samples = self.local_evaluator.sample()

    with self.grad_timer:
        grad, _ = self.local_evaluator.compute_gradients(samples)
        self.local_evaluator.apply_gradients(grad)
        self.grad_timer.push_units_processed(samples.count)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
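# --- Illustrative sketch (not part of the original module) ---
# Splitting compute_gradients() from apply_gradients(), as in the step()
# above, lets a caller combine gradients from several batches before a
# single update. A toy illustration with plain NumPy "parameters" and a
# hypothetical evaluator interface:
import numpy as np


class ToyEvaluator(object):
    def __init__(self):
        self.w = np.zeros(4)

    def compute_gradients(self, batch):
        # Gradient of 0.5 * ||w - batch||^2 w.r.t. w; the second slot
        # mirrors the (grad, info) tuple unpacked in step() above.
        return self.w - batch, {}

    def apply_gradients(self, grad, lr=0.1):
        self.w -= lr * grad


ev = ToyEvaluator()
grads = [ev.compute_gradients(np.ones(4))[0] for _ in range(3)]
ev.apply_gradients(np.mean(grads, axis=0))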
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            if self.straggler_mitigation:
                samples = collect_samples_straggler_mitigation(
                    self.remote_evaluators, self.train_batch_size)
            else:
                samples = collect_samples(
                    self.remote_evaluators, self.sample_batch_size,
                    self.num_envs_per_worker, self.train_batch_size)
            if samples.count > self.train_batch_size * 2:
                logger.info(
                    "Collected more training samples than expected "
                    "(actual={}, train_batch_size={}). ".format(
                        samples.count, self.train_batch_size) +
                    "This may be because you have many workers or "
                    "long episodes in 'complete_episodes' batch mode.")
        else:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                samples.append(self.local_evaluator.sample())
            samples = SampleBatch.concat_samples(samples)

        # Handle everything as if multiagent
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({
                DEFAULT_POLICY_ID: samples
            }, samples.count)

    for policy_id, policy in self.policies.items():
        if policy_id not in samples.policy_batches:
            continue

        batch = samples.policy_batches[policy_id]
        for field in self.standardize_fields:
            value = batch[field]
            standardized = (value - value.mean()) / max(1e-4, value.std())
            batch[field] = standardized

        # Important: don't shuffle RNN sequence elements
        if not policy._state_inputs:
            batch.shuffle()

    num_loaded_tuples = {}
    with self.load_timer:
        for policy_id, batch in samples.policy_batches.items():
            if policy_id not in self.policies:
                continue

            policy = self.policies[policy_id]
            tuples = policy._get_loss_inputs_dict(batch)
            data_keys = [ph for _, ph in policy._loss_inputs]
            if policy._state_inputs:
                state_keys = policy._state_inputs + [policy._seq_lens]
            else:
                state_keys = []
            num_loaded_tuples[policy_id] = (
                self.optimizers[policy_id].load_data(
                    self.sess, [tuples[k] for k in data_keys],
                    [tuples[k] for k in state_keys]))

    fetches = {}
    with self.grad_timer:
        for policy_id, tuples_per_device in num_loaded_tuples.items():
            optimizer = self.optimizers[policy_id]
            num_batches = max(
                1,
                int(tuples_per_device) // int(self.per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    batch_fetches = optimizer.optimize(
                        self.sess, permutation[batch_index] *
                        self.per_device_batch_size)
                    for k, v in batch_fetches.items():
                        iter_extra_fetches[k].append(v)
                logger.debug("{} {}".format(i, _averaged(iter_extra_fetches)))
            fetches[policy_id] = _averaged(iter_extra_fetches)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += tuples_per_device * len(self.devices)
    return fetches
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            if self.straggler_mitigation:
                samples = collect_samples_straggler_mitigation(
                    self.remote_evaluators, self.train_batch_size)
            else:
                samples = collect_samples(
                    self.remote_evaluators, self.sample_batch_size,
                    self.num_envs_per_worker, self.train_batch_size)
            if samples.count > self.train_batch_size * 2:
                logger.info(
                    "Collected more training samples than expected "
                    "(actual={}, train_batch_size={}). ".format(
                        samples.count, self.train_batch_size) +
                    "This may be because you have many workers or "
                    "long episodes in 'complete_episodes' batch mode.")
        else:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                samples.append(self.local_evaluator.sample())
            samples = SampleBatch.concat_samples(samples)

        # Handle everything as if multiagent
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({
                DEFAULT_POLICY_ID: samples
            }, samples.count)

    for policy_id, policy in self.policies.items():
        if policy_id not in samples.policy_batches:
            continue

        batch = samples.policy_batches[policy_id]
        for field in self.standardize_fields:
            value = batch[field]
            standardized = (value - value.mean()) / max(1e-4, value.std())
            batch[field] = standardized

        # Important: don't shuffle RNN sequence elements
        if not policy._state_inputs:
            batch.shuffle()

    num_loaded_tuples = {}
    with self.load_timer:
        for policy_id, batch in samples.policy_batches.items():
            if policy_id not in self.policies:
                continue

            policy = self.policies[policy_id]
            tuples = policy._get_loss_inputs_dict(batch)
            data_keys = [ph for _, ph in policy._loss_inputs]
            if policy._state_inputs:
                state_keys = policy._state_inputs + [policy._seq_lens]
            else:
                state_keys = []
            num_loaded_tuples[policy_id] = (
                self.optimizers[policy_id].load_data(
                    self.sess, [tuples[k] for k in data_keys],
                    [tuples[k] for k in state_keys]))

    fetches = {}
    with self.grad_timer:
        for policy_id, tuples_per_device in num_loaded_tuples.items():
            optimizer = self.optimizers[policy_id]
            num_batches = max(
                1,
                int(tuples_per_device) // int(self.per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    batch_fetches = optimizer.optimize(
                        self.sess, permutation[batch_index] *
                        self.per_device_batch_size)
                    for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                        iter_extra_fetches[k].append(v)
                logger.debug("{} {}".format(i, _averaged(iter_extra_fetches)))
            fetches[policy_id] = _averaged(iter_extra_fetches)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += tuples_per_device * len(self.devices)
    self.learner_stats = fetches
    return fetches
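# --- Illustrative sketch (not part of the original module) ---
# Two details from the step() above, shown standalone. First, the field
# standardization: values are shifted to zero mean and unit variance, with
# the 1e-4 floor guarding against division by ~0 on near-constant fields.
# Second, the SGD loop visits the preloaded minibatches in a fresh random
# order each epoch.
import numpy as np


def standardize(value):
    return (value - value.mean()) / max(1e-4, value.std())


def sgd_epoch_offsets(tuples_per_device, per_device_batch_size):
    # Offsets into the preloaded device buffers, one per minibatch, in
    # randomized order (mirrors the permutation logic above).
    num_batches = max(1, tuples_per_device // per_device_batch_size)
    return np.random.permutation(num_batches) * per_device_batch_size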