def add(self, samples: SampleBatch):
    """Add a SampleBatch to storage.

    Optimized to avoid several queries for large sample batches.

    Args:
        samples: The sample batch
    """
    if samples.count >= self._maxsize:
        samples = samples.slice(samples.count - self._maxsize, None)
        end_idx = 0
        assign = [(slice(0, self._maxsize), samples)]
    else:
        start_idx = self._next_idx
        end_idx = (self._next_idx + samples.count) % self._maxsize
        if end_idx < start_idx:
            tailcount = self._maxsize - start_idx
            assign = [
                (slice(start_idx, None), samples.slice(0, tailcount)),
                (slice(end_idx), samples.slice(tailcount, None)),
            ]
        else:
            assign = [(slice(start_idx, end_idx), samples)]

    for field in self.fields:
        for slc, smp in assign:
            self._storage[field.name][slc] = smp[field.name]

    self._next_idx = end_idx
    self._curr_size = min(self._curr_size + samples.count, self._maxsize)

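# A minimal standalone sketch (not part of the class above, names are
# illustrative) of the same wraparound assignment, using a plain numpy array
# as storage and integers as the "samples":
import numpy as np

maxsize, next_idx = 5, 3
storage = np.zeros(maxsize, dtype=int)
new = np.array([1, 2, 3, 4])  # 4 new items starting at index 3 -> wraps to 0..1

end_idx = (next_idx + len(new)) % maxsize  # (3 + 4) % 5 == 2
tailcount = maxsize - next_idx             # 2 items fit at the tail
storage[next_idx:] = new[:tailcount]       # fills indices 3, 4
storage[:end_idx] = new[tailcount:]        # fills indices 0, 1
assert list(storage) == [3, 4, 0, 1, 2]
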
def __call__(self, batch: MultiAgentBatch) -> List[SampleBatchType]:
    _check_sample_batch_type(batch)
    batch_count = batch.policy_batches[self.policy_id_to_count_for].count
    if self.drop_samples_for_other_agents:
        batch = MultiAgentBatch(
            policy_batches={
                self.policy_id_to_count_for: batch.policy_batches[
                    self.policy_id_to_count_for]
            },
            env_steps=batch.policy_batches[self.policy_id_to_count_for].count)

    self.buffer.append(batch)
    self.count += batch_count

    if self.count >= self.min_batch_size:
        if self.count > self.min_batch_size * 2:
            logger.info("Collected more training samples than expected "
                        "(actual={}, expected={}). ".format(
                            self.count, self.min_batch_size) +
                        "This may be because you have many workers or "
                        "long episodes in 'complete_episodes' batch mode.")
        out = SampleBatch.concat_samples(self.buffer)
        timer = _get_shared_metrics().timers[SAMPLE_TIMER]
        timer.push(time.perf_counter() - self.batch_start_time)
        timer.push_units_processed(self.count)
        self.batch_start_time = None
        self.buffer = []
        self.count = 0
        return [out]
    return []

def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
    _check_sample_batch_type(batch)
    if self.done:
        # Warmup phase done, simply return batch.
        return [batch]

    metrics = _get_shared_metrics()
    timesteps_total = metrics.counters[STEPS_SAMPLED_COUNTER]
    self.buffer.append(batch)
    self.count += batch.count
    assert self.count == timesteps_total

    if timesteps_total < self.learning_starts:
        # Return empty if still in warmup.
        return []

    # Warmup just done.
    if self.count > self.learning_starts * 2:
        logger.info(  # pylint:disable=logging-fstring-interpolation
            "Collected more training samples than expected "
            f"(actual={self.count}, expected={self.learning_starts}). "
            "This may be because you have many workers or "
            "long episodes in 'complete_episodes' batch mode.")
    out = SampleBatch.concat_samples(self.buffer)
    self.buffer = []
    self.count = 0
    self.done = True
    return [out]

def replay(self) -> SampleBatchType:
    """If this buffer was given a fake batch, return it, otherwise
    return a MultiAgentBatch with samples.
    """
    if self._fake_batch:
        fake_batch = SampleBatch(self._fake_batch)
        return MultiAgentBatch({
            DEFAULT_POLICY_ID: fake_batch
        }, fake_batch.count)

    if self.num_added < self.replay_starts:
        return None
    with self.replay_timer:
        # Lockstep mode: Sample an equal number of steps from all
        # policies at the same time.
        if self.replay_mode == "lockstep":
            return self.replay_buffers[_ALL_POLICIES].sample(
                self.replay_batch_size, beta=self.prioritized_replay_beta)
        else:
            samples = {}
            for policy_id, replay_buffer in self.replay_buffers.items():
                samples[policy_id] = replay_buffer.sample(
                    self.replay_batch_size,
                    beta=self.prioritized_replay_beta)
            return MultiAgentBatch(samples, self.replay_batch_size)

def sample_with_idxes(self, idxes: np.ndarray) -> SampleBatch:
    """Transition batch corresponding to the given indexes."""
    batch = {
        k: self._storage[k][idxes]
        for k in (f.name for f in self.fields)
    }
    return SampleBatch(batch)

def replay(self, policy_id: Optional[PolicyID] = None) -> SampleBatchType:
    """If this buffer was given a fake batch, return it, otherwise
    return a MultiAgentBatch with samples.
    """
    if self._fake_batch:
        if not isinstance(self._fake_batch, MultiAgentBatch):
            self._fake_batch = SampleBatch(
                self._fake_batch).as_multi_agent()
        return self._fake_batch

    if self.num_added < self.replay_starts:
        return None
    with self.replay_timer:
        # Lockstep mode: Sample an equal number of steps from all
        # policies at the same time.
        if self.replay_mode == "lockstep":
            assert (
                policy_id is None
            ), "`policy_id` specifier not allowed in `lockstep` mode!"
            return self.replay_buffers[_ALL_POLICIES].sample(
                self.replay_batch_size, beta=self.prioritized_replay_beta)
        elif policy_id is not None:
            return self.replay_buffers[policy_id].sample(
                self.replay_batch_size, beta=self.prioritized_replay_beta)
        else:
            samples = {}
            for policy_id, replay_buffer in self.replay_buffers.items():
                samples[policy_id] = replay_buffer.sample(
                    self.replay_batch_size,
                    beta=self.prioritized_replay_beta)
            return MultiAgentBatch(samples, self.replay_batch_size)

def aggregate_into_larger_batch():
    if (sum(b.count for b in self.batch_being_built) >=
            self.config["train_batch_size"]):
        batch_to_add = SampleBatch.concat_samples(self.batch_being_built)
        self.batches_to_place_on_learner.append(batch_to_add)
        self.batch_being_built = []

def improve_policy(self, num_improvements: int) -> Dict[str, float]:
    """Call the policy to perform policy improvement using the augmented replay.

    Args:
        num_improvements: Number of times to call `policy.learn_on_batch`

    Returns:
        A dictionary of training and exploration statistics
    """
    policy = self.get_policy()
    batch_size = self.config["train_batch_size"]
    env_batch_size = int(batch_size * self.config["real_data_ratio"])
    model_batch_size = batch_size - env_batch_size

    stats = {}
    for _ in range(num_improvements):
        samples = []
        if env_batch_size:
            samples += [self.replay.sample(env_batch_size)]
        if model_batch_size:
            samples += [self.virtual_replay.sample(model_batch_size)]
        batch = SampleBatch.concat_samples(samples)
        stats = get_learner_stats(policy.learn_on_batch(batch))
        self.tracker.num_steps_trained += batch.count

    stats.update(policy.get_exploration_info())
    return stats

def generate_virtual_sample_batch(self, samples: SampleBatch) -> SampleBatch:
    """Rollout model with latest policy.

    Produces samples for populating the virtual buffer, hence no gradient
    information is retained.

    If a transition is terminal, the next transition, if any, is generated
    from the initial state passed through `samples`.

    Args:
        samples: the transitions to extract initial states from

    Returns:
        A batch of transitions sampled from the model
    """
    virtual_samples = []
    obs = init_obs = self.convert_to_tensor(samples[SampleBatch.CUR_OBS])

    rollout_length = round(self.rollout_schedule(self.global_timestep))
    for _ in range(rollout_length):
        model = self.rng.choice(self.elite_models)

        action, _ = self.module.actor.sample(obs)
        next_obs, _ = model.sample(model(obs, action))
        reward = self.reward_fn(obs, action, next_obs)
        done = self.termination_fn(obs, action, next_obs)

        transition = {
            SampleBatch.CUR_OBS: obs,
            SampleBatch.ACTIONS: action,
            SampleBatch.NEXT_OBS: next_obs,
            SampleBatch.REWARDS: reward,
            SampleBatch.DONES: done,
        }
        virtual_samples += [
            SampleBatch({k: v.cpu().numpy() for k, v in transition.items()})
        ]
        obs = torch.where(done.unsqueeze(-1), init_obs, next_obs)

    return SampleBatch.concat_samples(virtual_samples)

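# A minimal sketch (made-up shapes) of the terminal-state reset used above:
# rows whose `done` flag is True restart the rollout from the corresponding
# initial observation instead of the model's predicted next state.
import torch

init_obs = torch.zeros(3, 2)  # 3 parallel rollouts, observation dim 2
next_obs = torch.ones(3, 2)
done = torch.tensor([True, False, True])

obs = torch.where(done.unsqueeze(-1), init_obs, next_obs)
assert obs.tolist() == [[0.0, 0.0], [1.0, 1.0], [0.0, 0.0]]
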
def fake_batch(obs_space, action_space, batch_size=1):
    """Create a fake SampleBatch compatible with Policy.learn_on_batch."""
    samples = {
        SampleBatch.CUR_OBS: fake_space_samples(obs_space, batch_size),
        SampleBatch.ACTIONS: fake_space_samples(action_space, batch_size),
        SampleBatch.REWARDS: np.random.randn(batch_size),
        SampleBatch.NEXT_OBS: fake_space_samples(obs_space, batch_size),
        SampleBatch.DONES: np.random.randn(batch_size) > 0,
    }
    return SampleBatch(samples)

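# `fake_space_samples` is referenced above but not shown here. One plausible
# implementation, assuming Gym-style spaces, simply stacks independent
# `space.sample()` draws along a new leading batch axis:
import numpy as np
from gym.spaces import Space


def fake_space_samples(space: Space, batch_size: int) -> np.ndarray:
    # Draw `batch_size` random elements from the space and batch them.
    return np.stack([space.sample() for _ in range(batch_size)])
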
def transition_dataset(trajs: list[SampleBatch]) -> TensorDataset:
    """Convert a list of trajectories into a transition tensor dataset."""
    transitions = SampleBatch.concat_samples(trajs)
    dataset = TensorDataset(
        torch.from_numpy(transitions[SampleBatch.CUR_OBS]),
        torch.from_numpy(transitions[SampleBatch.ACTIONS]),
        torch.from_numpy(transitions[SampleBatch.NEXT_OBS]),
    )
    assert len(dataset) == transitions.count
    return dataset

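# A short, self-contained usage sketch with made-up data: the resulting
# TensorDataset can be minibatched with a standard DataLoader, e.g. when
# fitting a dynamics model on (obs, action) -> next_obs pairs.
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

obs = np.random.randn(100, 4).astype(np.float32)
actions = np.random.randn(100, 2).astype(np.float32)
next_obs = np.random.randn(100, 4).astype(np.float32)

dataset = TensorDataset(*map(torch.from_numpy, (obs, actions, next_obs)))
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for obs_b, act_b, next_obs_b in loader:
    assert obs_b.shape[-1] == next_obs_b.shape[-1] == 4
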
def _train_dual_policies(self, samples: SampleBatch):
    learner_stats = {"learner_stats": {}}
    for policy_n, policy in enumerate(self.algorithms):
        if policy_n in self.DUAL_POLICIES:
            logger.debug(f"train policy {policy}")
            samples_copy = samples.copy()
            samples_copy = self._modify_batch_for_policy(
                policy_n, samples_copy)
            learner_stats_one_policy = policy.learn_on_batch(samples_copy)
            learner_stats["learner_stats"][
                f"algo{policy_n}"] = learner_stats_one_policy
    return learner_stats

def group_batch_episodes(samples: SampleBatch) -> SampleBatch:
    """Return the sample batch with rows grouped by episode id.

    Moreover, rows are sorted by timestep.

    Warning:
        Modifies the sample batch in-place
    """
    # Assume "t" is the timestep key in the sample batch.
    sorted_timestep_idxs = np.argsort(samples["t"])
    for key, val in samples.items():
        samples[key] = val[sorted_timestep_idxs]

    # Stable sort is important so that we don't alter the order
    # of timesteps.
    sorted_episode_idxs = np.argsort(samples[SampleBatch.EPS_ID],
                                     kind="stable")
    for key, val in samples.items():
        samples[key] = val[sorted_episode_idxs]

    return samples

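# A minimal sketch (plain numpy, made-up data) of why the second sort must be
# stable: after sorting by timestep, a stable sort on episode id groups the
# episodes while keeping each episode's timestep order intact.
import numpy as np

t = np.array([1, 0, 0, 1])
eps_id = np.array([7, 5, 7, 5])

by_t = np.argsort(t)                        # rows ordered by timestep
t, eps_id = t[by_t], eps_id[by_t]
by_eps = np.argsort(eps_id, kind="stable")  # group episodes, keep t order
assert eps_id[by_eps].tolist() == [5, 5, 7, 7]
assert t[by_eps].tolist() == [0, 1, 0, 1]
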
def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
    _check_sample_batch_type(batch)
    self.buffer.append(batch)
    self.count += 1

    if self.count >= self.num_episodes:
        out = SampleBatch.concat_samples(self.buffer)
        timer = _get_shared_metrics().timers[SAMPLE_TIMER]
        timer.push(time.perf_counter() - self.batch_start_time)
        timer.push_units_processed(self.count)
        self.batch_start_time = None
        self.buffer = []
        self.count = 0
        return [out]
    return []

def update_policy(self, times: int) -> StatDict:
    batch_size = self.config["batch_size"]
    env_batch_size = int(batch_size * self.config["real_data_ratio"])
    model_batch_size = batch_size - env_batch_size

    for _ in range(times):
        samples = []
        if env_batch_size:
            samples += [self.replay.sample(env_batch_size)]
        if model_batch_size:
            samples += [self.virtual_replay.sample(model_batch_size)]
        batch = SampleBatch.concat_samples(samples)
        batch = self.lazy_tensor_dict(batch)
        info = self.improve_policy(batch)

    return info

def _learn_on_policy(self, samples: SampleBatch) -> dict:
    """Update on-policy components."""
    batch = self.lazy_tensor_dict(samples)
    episodes = [
        self.lazy_tensor_dict(s) for s in samples.split_by_episode()
    ]

    with self.optimizers.optimize("on_policy"):
        loss, info = self.loss_actor(episodes)
        kl_div = self._avg_kl_divergence(batch)
        loss = loss + kl_div * self.curr_kl_coeff
        loss.backward()

    info.update(self.extra_grad_info(batch, on_policy=True))
    info.update(self.update_kl_coeff(samples))
    return info

def sample(self,
           num_items: int,
           policy_id: Optional[PolicyID] = None) -> Optional[SampleBatchType]:
    """Samples a batch of size `num_items` from a policy's buffer.

    If this buffer was given a fake batch, return it, otherwise
    return a MultiAgentBatch with samples.

    If fewer than `num_items` records are in the policy's buffer, some
    samples in the results may be repeated to fulfil the batch size
    (`num_items`) request.

    Args:
        num_items: Number of items to sample from a policy's buffer.
        policy_id: ID of the policy that created the experiences we sample.

    Returns:
        Concatenated batch of items. None if buffer is empty.
    """
    if self._fake_batch:
        if not isinstance(self._fake_batch, MultiAgentBatch):
            self._fake_batch = SampleBatch(
                self._fake_batch).as_multi_agent()
        return self._fake_batch

    if self._num_added < self.replay_starts:
        return None
    with self.replay_timer:
        # Lockstep mode: Sample an equal number of steps from all
        # policies at the same time.
        if self.replay_mode == "lockstep":
            assert (
                policy_id is None
            ), "`policy_id` specifier not allowed in `lockstep` mode!"
            return self.replay_buffers[_ALL_POLICIES].sample(
                self.replay_batch_size, beta=self.prioritized_replay_beta)
        elif policy_id is not None:
            return self.replay_buffers[policy_id].sample(
                self.replay_batch_size, beta=self.prioritized_replay_beta)
        else:
            samples = {}
            for policy_id, replay_buffer in self.replay_buffers.items():
                samples[policy_id] = replay_buffer.sample(
                    self.replay_batch_size,
                    beta=self.prioritized_replay_beta)
            return MultiAgentBatch(samples, self.replay_batch_size)

def test_getitem(numpy_replay: NumpyReplayBuffer, sample_batch: SampleBatch,
                 idx):
    replay = numpy_replay

    batch = replay[idx]
    assert isinstance(batch, dict)
    assert all([
        np.allclose(batch[k], sample_batch[k][idx])
        for k in sample_batch.keys()
    ])

    mean = np.mean(sample_batch[SampleBatch.CUR_OBS], axis=0)
    std = np.std(sample_batch[SampleBatch.CUR_OBS], axis=0)
    replay.update_obs_stats()
    batch = replay[idx]
    for key in SampleBatch.CUR_OBS, SampleBatch.NEXT_OBS:
        expected = (sample_batch[key][idx] - mean) / (std + 1e-7)
        assert np.allclose(batch[key], expected)

def test_getitem(filled_replay: NumpyReplayBuffer, sample_batch: SampleBatch,
                 idx):
    replay = filled_replay

    batch = replay[idx]
    assert isinstance(batch, dict)
    assert all([
        np.allclose(batch[k], sample_batch[k][idx])
        for k in sample_batch.keys()
    ])

    mean = np.mean(sample_batch[SampleBatch.CUR_OBS], axis=0)
    std = np.std(sample_batch[SampleBatch.CUR_OBS], axis=0)
    std[std < 1e-12] = 1.0
    replay.compute_stats = True
    batch = replay[idx]
    for key in SampleBatch.CUR_OBS, SampleBatch.NEXT_OBS:
        expected = (sample_batch[key][idx] - mean) / std
        assert np.allclose(batch[key], expected)

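# A small standalone illustration (made-up data) of the standardization the
# test above checks, including the floor on near-zero standard deviations so
# that constant observation dimensions don't blow up the division.
import numpy as np

obs = np.array([[0.0, 1.0], [0.0, 3.0], [0.0, 5.0]])  # first dim is constant
mean = obs.mean(axis=0)
std = obs.std(axis=0)
std[std < 1e-12] = 1.0  # constant dims keep their (centered) raw values
normalized = (obs - mean) / std
assert np.allclose(normalized[:, 0], 0.0)
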
def sample_with_idxes(self, idxes: np.ndarray) -> SampleBatch:
    """Sample a batch of experiences corresponding to the given indexes."""
    self._num_sampled += len(idxes)
    data = self._encode_sample(idxes)
    return SampleBatch(dict(zip([f.name for f in self.fields], data)))

def sample(self, batch_size: int) -> SampleBatch:
    """Transition batch uniformly sampled with replacement."""
    return SampleBatch(self[self.sample_idxes(batch_size)])

def all_samples(self) -> SampleBatch:
    """All stored transitions."""
    return SampleBatch(self[:len(self)])

def sample_with_idxes(self, idxes: np.ndarray) -> SampleBatch:
    self._num_sampled += len(idxes)
    data = self._encode_sample(idxes)
    return SampleBatch(dict(zip([f.name for f in self.fields], data)))

def all_samples(self) -> SampleBatch:
    """All stored transitions."""
    return SampleBatch({
        k: self._storage[k][:len(self)]
        for k in (f.name for f in self.fields)
    })

def _initialize_loss(self):
    def fake_array(tensor, none_shape):
        shape = tensor.shape.as_list()
        non_none_shape = [s for s in shape if s is not None]
        none_shape = none_shape if isinstance(none_shape,
                                              list) else [none_shape]
        shape = none_shape + non_none_shape
        return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

    T = self.config["model"]["max_seq_len"]
    B = self.config["train_batch_size"] // T
    dummy_batch = {
        SampleBatch.CUR_OBS: fake_array(self._obs_input, B * T),
        SampleBatch.NEXT_OBS: fake_array(self._obs_input, B * T),
        SampleBatch.DONES: np.array([False] * B * T, dtype=np.bool),
        SampleBatch.ACTIONS: fake_array(
            ModelCatalog.get_action_placeholder(self.action_space), B * T),
        SampleBatch.REWARDS: np.array([0] * B * T, dtype=np.float32),
        SampleBatch.INFOS: np.array([self.sample_info] * B * T),
    }
    if self._obs_include_prev_action_reward:
        dummy_batch.update({
            SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input,
                                                 B * T),
            SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input,
                                                 B * T),
        })
    state_init = self.get_initial_state()
    state_batches = []
    for i, h in enumerate(state_init):
        dummy_batch["state_in_{}".format(i)] = np.repeat(
            np.expand_dims(h, 0), B * T, 0)
        dummy_batch["state_out_{}".format(i)] = np.repeat(
            np.expand_dims(h, 0), B * T, 0)
        state_batches.append(np.repeat(np.expand_dims(h, 0), B * T, 0))
    if state_init:
        dummy_batch["seq_lens"] = np.array([T] * B * T, dtype=np.int32)
    for k, v in self.extra_compute_action_fetches().items():
        dummy_batch[k] = fake_array(v, B * T)

    # Postprocessing might depend on variable init, so run it first here.
    self._sess.run(tf.global_variables_initializer())
    postprocessed_batch = self.postprocess_trajectory(
        SampleBatch(dummy_batch))

    # Model forward pass for the loss (needed after postprocess to
    # overwrite any tensor state from that call).
    self.model(self._input_dict, self._state_in, self._seq_lens)

    if self._obs_include_prev_action_reward:
        train_batch = UsageTrackingDict({
            SampleBatch.PREV_ACTIONS: self._prev_action_input,
            SampleBatch.PREV_REWARDS: self._prev_reward_input,
            SampleBatch.CUR_OBS: self._obs_input,
        })
        loss_inputs = [
            (SampleBatch.PREV_ACTIONS, self._prev_action_input),
            (SampleBatch.PREV_REWARDS, self._prev_reward_input),
            (SampleBatch.CUR_OBS, self._obs_input),
        ]
    else:
        train_batch = UsageTrackingDict({
            SampleBatch.CUR_OBS: self._obs_input,
        })
        loss_inputs = [
            (SampleBatch.CUR_OBS, self._obs_input),
        ]

    for k, v in postprocessed_batch.items():
        if k in train_batch:
            continue
        elif v.dtype == np.object:
            continue  # can't handle arbitrary objects in TF
        elif k == "seq_lens" or k.startswith("state_in_"):
            continue
        shape = (None, ) + v.shape[1:]
        dtype = np.float32 if v.dtype == np.float64 else v.dtype
        placeholder = tf.placeholder(dtype, shape=shape, name=k)
        train_batch[k] = placeholder

    for i, si in enumerate(self._state_in):
        train_batch["state_in_{}".format(i)] = si
    train_batch["seq_lens"] = self._seq_lens

    if log_once("loss_init"):
        logger.debug(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(train_batch)))

    self._loss_input_dict = train_batch
    loss = self._do_loss_init(train_batch)
    for k in sorted(train_batch.accessed_keys):
        if k != "seq_lens" and not k.startswith("state_in_"):
            loss_inputs.append((k, train_batch[k]))

    TFPolicy._initialize_loss(self, loss, loss_inputs)
    if self._grad_stats_fn:
        self._stats_fetches.update(
            self._grad_stats_fn(self, train_batch, self._grads))
    self._sess.run(tf.global_variables_initializer())

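# A standalone sketch (TF1-style, illustrative names) of the `fake_array`
# trick above: replace the unknown (None) batch dimension of a placeholder
# with a concrete size to build a zero-filled dummy input of matching shape
# and dtype for tracing the loss.
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
obs_ph = tf.placeholder(tf.float32, shape=(None, 4), name="obs")


def fake_array(tensor, batch_size):
    # Keep only the known dimensions and prepend the concrete batch size.
    shape = [s for s in tensor.shape.as_list() if s is not None]
    return np.zeros([batch_size] + shape, dtype=tensor.dtype.as_numpy_dtype)


assert fake_array(obs_ph, 32).shape == (32, 4)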