def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            batches = ray.get(
                [e.sample.remote() for e in self.remote_evaluators])
        else:
            batches = [self.local_evaluator.sample()]

        # Handle everything as if multiagent
        tmp = []
        for batch in batches:
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({
                    DEFAULT_POLICY_ID: batch
                }, batch.count)
            tmp.append(batch)
        batches = tmp

    for batch in batches:
        self.replay_buffer.append(batch)
        self.num_steps_sampled += batch.count
        self.buffer_size += batch.count
        while self.buffer_size > self.max_buffer_size:
            evicted = self.replay_buffer.pop(0)
            self.buffer_size -= evicted.count

    if self.num_steps_sampled >= self.replay_starts:
        self._optimize()
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            batch = SampleBatch.concat_samples(
                ray.get(
                    [e.sample.remote() for e in self.remote_evaluators]))
        else:
            batch = self.local_evaluator.sample()

        # Handle everything as if multiagent
        if isinstance(batch, SampleBatch):
            batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                    batch.count)

        # Observations are packed (compressed) before insertion to keep
        # the buffer's memory footprint down.
        for policy_id, s in batch.policy_batches.items():
            for row in s.rows():
                self.replay_buffers[policy_id].add(
                    pack_if_needed(row["obs"]),
                    row["actions"],
                    row["rewards"],
                    pack_if_needed(row["new_obs"]),
                    row["dones"],
                    weight=None)

    if self.num_steps_sampled >= self.replay_starts:
        self._optimize()

    self.num_steps_sampled += batch.count
def _replay(self):
    samples = {}
    with self.replay_timer:
        for policy_id, replay_buffer in self.replay_buffers.items():
            if isinstance(replay_buffer, PrioritizedReplayBuffer):
                (obses_t, actions, rewards, obses_tp1, dones, weights,
                 batch_indexes) = replay_buffer.sample(
                     self.train_batch_size,
                     beta=self.prioritized_replay_beta.value(
                         self.num_steps_trained))
            else:
                # Uniform replay: use constant weights and dummy indexes.
                (obses_t, actions, rewards, obses_tp1,
                 dones) = replay_buffer.sample(self.train_batch_size)
                weights = np.ones_like(rewards)
                batch_indexes = -np.ones_like(rewards)
            samples[policy_id] = SampleBatch({
                "obs": obses_t,
                "actions": actions,
                "rewards": rewards,
                "new_obs": obses_tp1,
                "dones": dones,
                "weights": weights,
                "batch_indexes": batch_indexes
            })
    return MultiAgentBatch(samples, self.train_batch_size)
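# Illustrative sketch (not part of the original source): the annealed `beta`
# fetched above via `self.prioritized_replay_beta.value(...)` and the
# `weights` returned by the prioritized buffer follow the standard
# importance-sampling correction of Schaul et al. (2016). All names and
# numbers below are hypothetical.
import numpy as np

def linear_beta_schedule(t, schedule_timesteps, initial_beta=0.4,
                         final_beta=1.0):
    # Anneal beta linearly toward final_beta, then hold it there.
    fraction = min(float(t) / schedule_timesteps, 1.0)
    return initial_beta + fraction * (final_beta - initial_beta)

def importance_weights(priorities, beta, alpha=0.6):
    # P(i) = p_i^alpha / sum_k p_k^alpha; w_i = (N * P(i))^-beta,
    # normalized by the largest weight so all weights are <= 1.
    probs = priorities ** alpha
    probs = probs / probs.sum()
    weights = (len(priorities) * probs) ** (-beta)
    return weights / weights.max()

priorities = np.array([1.0, 0.5, 2.0, 0.1])
beta = linear_beta_schedule(t=50000, schedule_timesteps=100000)
print(importance_weights(priorities, beta))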
def _from_json(batch):
    if isinstance(batch, bytes):  # smart_open S3 doesn't respect "r"
        batch = batch.decode("utf-8")
    data = json.loads(batch)

    if "type" in data:
        data_type = data.pop("type")
    else:
        raise ValueError("JSON record missing 'type' field")

    if data_type == "SampleBatch":
        for k, v in data.items():
            data[k] = unpack_if_needed(v)
        return SampleBatch(data)
    elif data_type == "MultiAgentBatch":
        policy_batches = {}
        for policy_id, policy_batch in data["policy_batches"].items():
            inner = {}
            for k, v in policy_batch.items():
                inner[k] = unpack_if_needed(v)
            policy_batches[policy_id] = SampleBatch(inner)
        return MultiAgentBatch(policy_batches, data["count"])
    else:
        raise ValueError(
            "Type field must be one of ['SampleBatch', 'MultiAgentBatch']",
            data_type)
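# Illustrative sketch (not part of the original source): minimal JSON records
# that `_from_json` would accept, one object per line. Real files may carry
# compressed fields, which is what `unpack_if_needed` reverses; the arrays
# here are shown already unpacked for readability, and the field and policy
# names are hypothetical.
import json

sample_record = json.dumps({
    "type": "SampleBatch",
    "obs": [[0.1, 0.2], [0.3, 0.4]],
    "actions": [0, 1],
})
multiagent_record = json.dumps({
    "type": "MultiAgentBatch",
    "policy_batches": {
        "pol1": {"obs": [[0.1, 0.2]], "actions": [0]},
    },
    "count": 1,
})

# The "type" field is popped first and drives the dispatch above.
assert json.loads(sample_record).pop("type") == "SampleBatch"
assert json.loads(multiagent_record)["count"] == 1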
def add_batch(self, batch):
    # Handle everything as if multiagent
    if isinstance(batch, SampleBatch):
        batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
    self.buffer.append(batch)
    self.cur_size += batch.count
    self.num_added += batch.count
    # Evict the oldest batches (FIFO) once over capacity.
    while self.cur_size > self.buffer_size:
        self.cur_size -= self.buffer.pop(0).count
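# Illustrative sketch (not part of the original source): the eviction loop
# above, modeled with a plain list where each element stands in for a
# batch's `count`. The invariant is that `cur_size` never exceeds the
# configured maximum once an add returns; sizes are hypothetical.
buffer, cur_size, max_size = [], 0, 1000

def append_with_eviction(count):
    global cur_size
    buffer.append(count)
    cur_size += count
    # Evict oldest batches first until the size bound holds again.
    while cur_size > max_size:
        cur_size -= buffer.pop(0)

for c in [400, 400, 400]:
    append_with_eviction(c)
assert cur_size <= max_size and buffer == [400, 400]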
def add_batch(self, batch):
    # Handle everything as if multiagent
    if isinstance(batch, SampleBatch):
        batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
    with self.add_batch_timer:
        for policy_id, s in batch.policy_batches.items():
            for row in s.rows():
                self.replay_buffers[policy_id].add(
                    row["obs"], row["actions"], row["rewards"],
                    row["new_obs"], row["dones"], row["weights"])
    self.num_added += batch.count
def build_and_reset(self, episode):
    """Returns the accumulated sample batches for each policy.

    Any unprocessed rows will be first postprocessed with a policy
    postprocessor. The internal state of this builder will be reset.

    Arguments:
        episode: current MultiAgentEpisode object or None
    """

    self.postprocess_batch_so_far(episode)
    policy_batches = {}
    for policy_id, builder in self.policy_builders.items():
        if builder.count > 0:
            policy_batches[policy_id] = builder.build_and_reset()
    old_count = self.count
    self.count = 0
    return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)
def replay(self):
    # Not enough samples collected yet to start replay.
    if self.num_added < self.replay_starts:
        return None

    with self.replay_timer:
        samples = {}
        for policy_id, replay_buffer in self.replay_buffers.items():
            (obses_t, actions, rewards, obses_tp1, dones, weights,
             batch_indexes) = replay_buffer.sample(
                 self.train_batch_size, beta=self.prioritized_replay_beta)
            samples[policy_id] = SampleBatch({
                "obs": obses_t,
                "actions": actions,
                "rewards": rewards,
                "new_obs": obses_tp1,
                "dones": dones,
                "weights": weights,
                "batch_indexes": batch_indexes
            })
        return MultiAgentBatch(samples, self.train_batch_size)
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            batches = ray_get_and_free(
                [e.sample.remote() for e in self.remote_evaluators])
        else:
            batches = [self.local_evaluator.sample()]

        # Handle everything as if multiagent
        tmp = []
        for batch in batches:
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({
                    DEFAULT_POLICY_ID: batch
                }, batch.count)
            tmp.append(batch)
        batches = tmp

    for batch in batches:
        if batch.count > self.max_buffer_size:
            raise ValueError(
                "The size of a single sample batch exceeds the replay "
                "buffer size ({} > {})".format(batch.count,
                                               self.max_buffer_size))
        self.replay_buffer.append(batch)
        self.num_steps_sampled += batch.count
        self.buffer_size += batch.count
        while self.buffer_size > self.max_buffer_size:
            evicted = self.replay_buffer.pop(0)
            self.buffer_size -= evicted.count

    if self.num_steps_sampled >= self.replay_starts:
        return self._optimize()
    else:
        return {}
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            # TODO(rliaw): remove when refactoring
            from ray.rllib.agents.ppo.rollout import collect_samples
            samples = collect_samples(self.remote_evaluators,
                                      self.train_batch_size)
        else:
            samples = self.local_evaluator.sample()

        # Handle everything as if multiagent
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                      samples.count)

    for _, batch in samples.policy_batches.items():
        for field in self.standardize_fields:
            value = batch[field]
            standardized = (value - value.mean()) / max(1e-4, value.std())
            batch[field] = standardized

    for policy_id, policy in self.policies.items():
        # Important: don't shuffle RNN sequence elements
        if (policy_id in samples.policy_batches
                and not policy._state_inputs):
            samples.policy_batches[policy_id].shuffle()

    num_loaded_tuples = {}
    with self.load_timer:
        for policy_id, batch in samples.policy_batches.items():
            policy = self.policies[policy_id]
            tuples = policy._get_loss_inputs_dict(batch)
            data_keys = [ph for _, ph in policy._loss_inputs]
            if policy._state_inputs:
                state_keys = policy._state_inputs + [policy._seq_lens]
            else:
                state_keys = []
            num_loaded_tuples[policy_id] = (
                self.optimizers[policy_id].load_data(
                    self.sess, [tuples[k] for k in data_keys],
                    [tuples[k] for k in state_keys]))

    fetches = {}
    with self.grad_timer:
        for policy_id, tuples_per_device in num_loaded_tuples.items():
            optimizer = self.optimizers[policy_id]
            num_batches = (
                int(tuples_per_device) // int(self.per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    batch_fetches = optimizer.optimize(
                        self.sess, permutation[batch_index] *
                        self.per_device_batch_size)
                    for k, v in batch_fetches.items():
                        iter_extra_fetches[k].append(v)
                logger.debug("{} {}".format(i,
                                            _averaged(iter_extra_fetches)))
            fetches[policy_id] = _averaged(iter_extra_fetches)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
    return fetches
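# Illustrative sketch (not part of the original source): how the permuted
# offsets passed to `optimizer.optimize(...)` above walk the loaded data in
# a different minibatch order on each SGD epoch. The sizes are hypothetical.
import numpy as np

tuples_per_device = 512
per_device_batch_size = 128
num_batches = tuples_per_device // per_device_batch_size  # 4 minibatches

for epoch in range(2):  # stands in for num_sgd_iter
    permutation = np.random.permutation(num_batches)
    for batch_index in range(num_batches):
        offset = permutation[batch_index] * per_device_batch_size
        # One SGD step on rows [offset, offset + per_device_batch_size)
        # of the data already loaded onto each device.
        print(epoch, offset)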
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            if self.straggler_mitigation:
                samples = collect_samples_straggler_mitigation(
                    self.remote_evaluators, self.train_batch_size)
            else:
                samples = collect_samples(
                    self.remote_evaluators, self.sample_batch_size,
                    self.num_envs_per_worker, self.train_batch_size)
            if samples.count > self.train_batch_size * 2:
                logger.info(
                    "Collected more training samples than expected "
                    "(actual={}, train_batch_size={}). ".format(
                        samples.count, self.train_batch_size) +
                    "This may be because you have many workers or "
                    "long episodes in 'complete_episodes' batch mode.")
        else:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                samples.append(self.local_evaluator.sample())
            samples = SampleBatch.concat_samples(samples)

        # Handle everything as if multiagent
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({
                DEFAULT_POLICY_ID: samples
            }, samples.count)

    for policy_id, policy in self.policies.items():
        if policy_id not in samples.policy_batches:
            continue

        batch = samples.policy_batches[policy_id]
        for field in self.standardize_fields:
            value = batch[field]
            standardized = (value - value.mean()) / max(1e-4, value.std())
            batch[field] = standardized

        # Important: don't shuffle RNN sequence elements
        if not policy._state_inputs:
            batch.shuffle()

    num_loaded_tuples = {}
    with self.load_timer:
        for policy_id, batch in samples.policy_batches.items():
            if policy_id not in self.policies:
                continue
            policy = self.policies[policy_id]
            tuples = policy._get_loss_inputs_dict(batch)
            data_keys = [ph for _, ph in policy._loss_inputs]
            if policy._state_inputs:
                state_keys = policy._state_inputs + [policy._seq_lens]
            else:
                state_keys = []
            num_loaded_tuples[policy_id] = (
                self.optimizers[policy_id].load_data(
                    self.sess, [tuples[k] for k in data_keys],
                    [tuples[k] for k in state_keys]))

    fetches = {}
    with self.grad_timer:
        for policy_id, tuples_per_device in num_loaded_tuples.items():
            optimizer = self.optimizers[policy_id]
            num_batches = max(
                1,
                int(tuples_per_device) // int(self.per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    batch_fetches = optimizer.optimize(
                        self.sess, permutation[batch_index] *
                        self.per_device_batch_size)
                    for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                        iter_extra_fetches[k].append(v)
                logger.debug("{} {}".format(i,
                                            _averaged(iter_extra_fetches)))
            fetches[policy_id] = _averaged(iter_extra_fetches)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += tuples_per_device * len(self.devices)
    self.learner_stats = fetches
    return fetches
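# Illustrative sketch (not part of the original source): the
# `standardize_fields` transform used by both step() variants above. It
# rescales a field (typically advantages) to zero mean and roughly unit
# standard deviation, with a 1e-4 floor guarding against division by zero
# on a constant-valued field. The values are hypothetical.
import numpy as np

advantages = np.array([1.0, -2.0, 0.5, 3.0])
standardized = (advantages - advantages.mean()) / max(1e-4, advantages.std())
print(standardized.mean(), standardized.std())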