def _obtain_samples(self, trainer, epoch):
    """Obtain samples from self._source.

    Args:
        trainer (Trainer): Experiment trainer, which may be used to
            obtain samples.
        epoch (int): The current epoch.

    Returns:
        TimeStepBatch: Batch of samples.

    """
    if isinstance(self._source, Policy):
        # The expert is a live policy: roll out fresh episodes.
        batch = trainer.obtain_episodes(epoch)
        log_performance(epoch, batch, 1.0, prefix='Expert')
        return batch
    else:
        # The expert is a batch iterator: pull batches until at least
        # self._batch_size actions are accumulated. Track the running
        # total instead of re-summing the whole list each iteration
        # (the original condition was accidentally O(n^2)).
        batches = []
        n_actions = 0
        while n_actions < self._batch_size:
            batch = next(self._source)
            batches.append(batch)
            n_actions += len(batch.actions)
        return TimeStepBatch.concatenate(*batches)
def _obtain_samples(self, runner, epoch):
    """Obtain samples from self._source.

    Args:
        runner (LocalRunner): LocalRunner to which may be used to obtain
            samples.
        epoch (int): The current epoch.

    Returns:
        TimeStepBatch: Batch of samples.

    """
    if isinstance(self._source, Policy):
        # The expert is a live policy: sample trajectories via the
        # runner and wrap them for performance logging.
        batch = TrajectoryBatch.from_trajectory_list(
            self.env_spec, runner.obtain_samples(epoch))
        log_performance(epoch, batch, 1.0, prefix='Expert')
        return batch
    else:
        # The expert is a batch iterator: pull batches until at least
        # self._batch_size actions are accumulated. Track the running
        # total instead of re-summing the whole list each iteration
        # (the original condition was accidentally O(n^2)).
        batches = []
        n_actions = 0
        while n_actions < self._batch_size:
            batch = next(self._source)
            batches.append(batch)
            n_actions += len(batch.actions)
        return TimeStepBatch.concatenate(*batches)
def test_concatenate_batch(batch_data):
    """Concatenating a batch with itself doubles every field."""
    single = TimeStepBatch(**batch_data)
    parts = [single, single]
    result = TimeStepBatch.concatenate(*parts)

    assert result.env_spec == batch_data['env_spec']

    # Plain array fields: each must equal the source data stacked with
    # itself along the first axis.
    array_fields = ('observations', 'next_observations', 'actions',
                    'rewards', 'step_types')
    for field in array_fields:
        expected = np.concatenate([batch_data[field], batch_data[field]])
        assert np.array_equal(getattr(result, field), expected)

    # Info dictionaries: every key from the first part must survive, with
    # its values concatenated across all parts.
    for attr in ('env_infos', 'agent_infos'):
        combined = getattr(result, attr)
        for key in getattr(parts[0], attr).keys():
            expected = np.concatenate(
                [getattr(part, attr)[key] for part in parts])
            assert key in combined
            assert np.array_equal(expected, combined[key])
def test_concatenate_empty_batch():
    """Concatenating zero batches raises a ValueError.

    Calling ``concatenate()`` with no arguments is identical to splatting
    an empty list; the intermediate list, dead-store assignment, and the
    ``del`` used to silence the linter added nothing.
    """
    with pytest.raises(ValueError, match='at least one'):
        TimeStepBatch.concatenate()