def test_split_batch(batch_data):
    """Check that ``split()`` returns one length-one batch per time step."""
    ctor_fields = ('env_spec', 'observations', 'actions', 'rewards',
                   'next_observations', 'step_types', 'env_infos',
                   'agent_infos')
    full_batch = TimeStepBatch(
        **{name: batch_data[name] for name in ctor_fields})

    pieces = full_batch.split()
    # The batch_data fixture is expected to hold a batch of two steps.
    assert len(pieces) == 2

    array_fields = ('observations', 'next_observations', 'actions',
                    'rewards', 'step_types')
    for idx, piece in enumerate(pieces):
        assert piece.env_spec == batch_data['env_spec']
        # Each piece must equal the idx-th row wrapped in a length-1 batch.
        for name in array_fields:
            assert np.array_equal(getattr(piece, name),
                                  [batch_data[name][idx]])
        for key, value in piece.env_infos.items():
            assert key in batch_data['env_infos']
            assert np.array_equal(value,
                                  [batch_data['env_infos'][key][idx]])
        for key, value in piece.agent_infos.items():
            assert key in batch_data['agent_infos']
            assert np.array_equal(value,
                                  [batch_data['agent_infos'][key][idx]])
def test_to_time_step_list_batch(batch_data):
    """Check that ``to_time_step_list()`` returns one dict per time step."""
    ctor_fields = ('env_spec', 'observations', 'actions', 'rewards',
                   'next_observations', 'step_types', 'env_infos',
                   'agent_infos')
    full_batch = TimeStepBatch(
        **{name: batch_data[name] for name in ctor_fields})

    step_dicts = full_batch.to_time_step_list()
    # The batch_data fixture is expected to hold a batch of two steps.
    assert len(step_dicts) == 2

    array_fields = ('observations', 'next_observations', 'actions',
                    'rewards', 'step_types')
    for idx, step in enumerate(step_dicts):
        # Each dict entry must equal the idx-th row as a length-1 batch.
        for name in array_fields:
            assert np.array_equal(step[name], [batch_data[name][idx]])
        for key, value in step['env_infos'].items():
            assert key in batch_data['env_infos']
            assert np.array_equal(value,
                                  [batch_data['env_infos'][key][idx]])
        for key, value in step['agent_infos'].items():
            assert key in batch_data['agent_infos']
            assert np.array_equal(value,
                                  [batch_data['agent_infos'][key][idx]])
def test_observations_env_spec_mismatch_batch(batch_data):
    """Observations that disagree with the env spec must raise ValueError.

    Fixture mutation is kept outside the ``pytest.raises`` blocks so that
    only the ``TimeStepBatch`` constructor can satisfy the expectation.
    """
    # Slice the fourth axis of the observations down to length one.
    batch_data['observations'] = batch_data['observations'][:, :, :, :1]
    with pytest.raises(ValueError, match='Each observation has shape'):
        TimeStepBatch(**batch_data)

    # Repeat the check against a freshly built spec with a Box observation
    # space of shape (4, 5, 2).
    obs_space = akro.Box(low=1, high=10, shape=(4, 5, 2), dtype=np.float32)
    act_space = gym.spaces.MultiDiscrete([2, 5])
    batch_data['env_spec'] = EnvSpec(obs_space, act_space)
    batch_data['observations'] = batch_data['observations'][:, :, :, :1]
    with pytest.raises(ValueError, match='Each observation has shape'):
        TimeStepBatch(**batch_data)
def test_next_observations_batch_mismatch_batch(batch_data):
    """A next_observations array with one entry removed must be rejected.

    The fixture is modified before entering ``pytest.raises`` so that only
    the ``TimeStepBatch`` constructor can satisfy the expectation.
    """
    # Drop the last entry along the first axis.
    batch_data['next_observations'] = batch_data['next_observations'][:-1]
    with pytest.raises(ValueError,
                       match='batch dimension of next_observations'):
        TimeStepBatch(**batch_data)
def expert_source(env, goal, max_episode_length, n_eps):
    """Generate batches of time steps sampled from an optimal policy.

    Args:
        env: Environment to sample from; its ``spec`` is handed to the
            expert policy.
        goal: Goal forwarded to ``OptimalPolicy``.
        max_episode_length: Episode length limit given to the worker
            factory and used as the sample count for each sampling call.
        n_eps: Number of batches to yield.

    Yields:
        TimeStepBatch: One batch per sampling call, converted from the
            sampler's episode batch.

    """
    policy = OptimalPolicy(env.spec, goal=goal)
    factory = WorkerFactory(seed=100,
                            max_episode_length=max_episode_length)
    sampler = LocalSampler.from_worker_factory(factory, policy, env)
    for _ in range(n_eps):
        episodes = sampler.obtain_samples(0, max_episode_length, None)
        yield TimeStepBatch.from_episode_batch(episodes)
def sample_timesteps(self, batch_size):
    """Draw transitions from the buffer and wrap them in a TimeStepBatch.

    Args:
        batch_size (int): Number of timesteps to sample.

    Returns:
        TimeStepBatch: The batch of timesteps.

    """
    transitions = self.sample_transitions(batch_size)
    # Flatten the terminal flags and map each one onto a StepType:
    # truthy -> TERMINAL, falsy -> MID.
    terminal_flags = transitions['terminals'].reshape(-1)
    step_types = np.array(
        [
            StepType.TERMINAL if flag else StepType.MID
            for flag in terminal_flags
        ],
        dtype=StepType)
    batch_fields = dict(
        env_spec=self._env_spec,
        episode_infos={},
        observations=transitions['observations'],
        actions=transitions['actions'],
        rewards=transitions['rewards'].flatten(),
        next_observations=transitions['next_observations'],
        step_types=step_types,
        env_infos={},
        agent_infos={},
    )
    return TimeStepBatch(**batch_fields)
def test_agent_infos_batch_mismatch_batch(batch_data):
    """An agent_infos entry with one row removed must raise ValueError.

    The fixture is modified before entering ``pytest.raises`` so that only
    the ``TimeStepBatch`` constructor can satisfy the expectation.
    """
    # Drop the last row of the 'hidden' entry.
    hidden = batch_data['agent_infos']['hidden']
    batch_data['agent_infos']['hidden'] = hidden[:-1]
    with pytest.raises(
            ValueError,
            match="Entry 'hidden' in agent_infos has batch size 1"):
        TimeStepBatch(**batch_data)
def test_from_time_step_list_batch(batch_data):
    """Check ``from_time_step_list`` against field-wise concatenation."""
    step_dicts = [batch_data, batch_data]
    merged = TimeStepBatch.from_time_step_list(batch_data['env_spec'],
                                               step_dicts)

    # Expected arrays: each field of the fixture concatenated with itself.
    array_fields = ('observations', 'next_observations', 'actions',
                    'rewards', 'step_types')
    expected_arrays = {
        name: np.concatenate([batch_data[name], batch_data[name]])
        for name in array_fields
    }
    expected_env_infos = {
        key: np.concatenate([d['env_infos'][key] for d in step_dicts])
        for key in step_dicts[0]['env_infos']
    }
    expected_agent_infos = {
        key: np.concatenate([d['agent_infos'][key] for d in step_dicts])
        for key in step_dicts[0]['agent_infos']
    }

    assert merged.env_spec == batch_data['env_spec']
    for name in array_fields:
        assert np.array_equal(getattr(merged, name), expected_arrays[name])
    for key, value in expected_env_infos.items():
        assert key in merged.env_infos
        assert np.array_equal(value, merged.env_infos[key])
    for key, value in expected_agent_infos.items():
        assert key in merged.agent_infos
        assert np.array_equal(value, merged.agent_infos[key])
def test_act_box_env_spec_mismatch_batch(batch_data):
    """Actions that do not fit a Box action space must raise ValueError.

    The replacement spec is built before entering ``pytest.raises`` so that
    only the ``TimeStepBatch`` constructor can satisfy the expectation.
    """
    # Keep the observation space, but swap in a Box action space of shape
    # (4, 3, 2).
    box_action_space = akro.Box(low=1,
                                high=np.inf,
                                shape=(4, 3, 2),
                                dtype=np.float32)
    batch_data['env_spec'] = EnvSpec(
        batch_data['env_spec'].observation_space, box_action_space)
    with pytest.raises(ValueError, match='Each action has'):
        TimeStepBatch(**batch_data)
def test_time_step_batch_from_episode_batch(eps_data):
    """Check observations produced by ``TimeStepBatch.from_episode_batch``."""
    episodes = EpisodeBatch(**eps_data)
    steps = TimeStepBatch.from_episode_batch(episodes)

    # Observations carry over unchanged.
    same_obs = steps.observations == episodes.observations
    assert same_obs.all()

    # Within the first episode, each next_observation is the following
    # observation.
    first_len = episodes.lengths[0]
    shifted_match = (steps.next_observations[:first_len - 1] ==
                     episodes.observations[1:first_len])
    assert shifted_match.all()

    # The entry at index lengths[0] is compared with the first episode's
    # last observation.
    last_match = (steps.next_observations[episodes.lengths[0]] ==
                  episodes.last_observations[0])
    assert last_match.all()
def test_act_box_env_spec_mismatch_batch(batch_data):
    """Actions that do not fit a Box action space must raise ValueError.

    The action space is replaced before entering ``pytest.raises`` so that
    only the ``TimeStepBatch`` constructor can satisfy the expectation.
    """
    # Overwrite the spec's action space in place with a Box of shape
    # (4, 3, 2).
    batch_data['env_spec'].action_space = akro.Box(low=1,
                                                   high=np.inf,
                                                   shape=(4, 3, 2),
                                                   dtype=np.float32)
    with pytest.raises(ValueError, match='actions should have'):
        TimeStepBatch(**batch_data)
def test_agent_infos_batch_mismatch_batch(batch_data):
    """An agent_infos entry with one row removed must raise ValueError.

    The fixture is modified before entering ``pytest.raises`` so that only
    the ``TimeStepBatch`` constructor can satisfy the expectation.
    """
    # Drop the last row of the 'hidden' entry.
    hidden = batch_data['agent_infos']['hidden']
    batch_data['agent_infos']['hidden'] = hidden[:-1]
    with pytest.raises(
            ValueError,
            match='entry in agent_infos must have a batch dimension'):
        TimeStepBatch(**batch_data)
def test_time_step_batch_from_trajectory_batch(traj_data):
    """Check observations from ``TimeStepBatch.from_trajectory_batch``."""
    trajectories = TrajectoryBatch(**traj_data)
    steps = TimeStepBatch.from_trajectory_batch(trajectories)

    # Observations carry over unchanged.
    same_obs = steps.observations == trajectories.observations
    assert same_obs.all()

    # Within the first trajectory, each next_observation is the following
    # observation.
    first_len = trajectories.lengths[0]
    shifted_match = (steps.next_observations[:first_len - 1] ==
                     trajectories.observations[1:first_len])
    assert shifted_match.all()

    # The entry at index lengths[0] is compared with the first trajectory's
    # last observation.
    last_match = (steps.next_observations[trajectories.lengths[0]] ==
                  trajectories.last_observations[0])
    assert last_match.all()
def test_new_ts_batch(batch_data):
    """Check that the constructor stores each argument object unchanged."""
    batch = TimeStepBatch(**batch_data)
    stored_fields = ('env_spec', 'observations', 'next_observations',
                     'actions', 'rewards', 'env_infos', 'agent_infos',
                     'step_types')
    for name in stored_fields:
        # Identity, not equality: the attribute must be the same object.
        assert getattr(batch, name) is batch_data[name]
def test_next_observations_env_spec_mismatch_batch(batch_data):
    """next_observations that disagree with the env spec must raise.

    Fixture mutation is kept outside the ``pytest.raises`` blocks so that
    only the ``TimeStepBatch`` constructor can satisfy the expectation.
    """
    # Slice the fourth axis of next_observations down to length one.
    batch_data['next_observations'] = batch_data[
        'next_observations'][:, :, :, :1]
    with pytest.raises(ValueError, match='next_observations must conform'):
        TimeStepBatch(**batch_data)

    # Repeat the check against a freshly built spec with a Box observation
    # space of shape (4, 3, 2); a different message is expected here.
    obs_space = akro.Box(low=1, high=10, shape=(4, 3, 2), dtype=np.float32)
    act_space = gym.spaces.MultiDiscrete([2, 5])
    batch_data['env_spec'] = EnvSpec(obs_space, act_space)
    batch_data['next_observations'] = batch_data[
        'next_observations'][:, :, :, :1]
    with pytest.raises(
            ValueError,
            match='next_observations should have the same dimensionality'):
        TimeStepBatch(**batch_data)
def test_terminals(batch_data):
    """Check that ``terminals`` has the same shape as ``rewards``."""
    ctor_fields = ('env_spec', 'observations', 'actions', 'rewards',
                   'next_observations', 'step_types', 'env_infos',
                   'agent_infos')
    batch = TimeStepBatch(
        **{name: batch_data[name] for name in ctor_fields})
    terminals_shape = batch.terminals.shape
    assert terminals_shape == batch.rewards.shape
def test_concatenate_batch(batch_data):
    """Check ``TimeStepBatch.concatenate`` against field-wise concatenation."""
    one_batch = TimeStepBatch(**batch_data)
    parts = [one_batch, one_batch]
    merged = TimeStepBatch.concatenate(*parts)

    # Expected arrays: each field of the fixture concatenated with itself.
    array_fields = ('observations', 'next_observations', 'actions',
                    'rewards', 'step_types')
    expected_arrays = {
        name: np.concatenate([batch_data[name], batch_data[name]])
        for name in array_fields
    }
    expected_env_infos = {
        key: np.concatenate([part.env_infos[key] for part in parts])
        for key in parts[0].env_infos
    }
    expected_agent_infos = {
        key: np.concatenate([part.agent_infos[key] for part in parts])
        for key in parts[0].agent_infos
    }

    assert merged.env_spec == batch_data['env_spec']
    for name in array_fields:
        assert np.array_equal(getattr(merged, name), expected_arrays[name])
    for key, value in expected_env_infos.items():
        assert key in merged.env_infos
        assert np.array_equal(value, merged.env_infos[key])
    for key, value in expected_agent_infos.items():
        assert key in merged.agent_infos
        assert np.array_equal(value, merged.agent_infos[key])
def _obtain_samples(self, trainer, epoch):
    """Obtain samples from self._source.

    When the source is a Policy, episodes are collected through the trainer
    and their performance is logged under the 'Expert' prefix. Otherwise
    the source is consumed with ``next()`` until at least
    ``self._batch_size`` actions have been gathered.

    Args:
        trainer (Trainer): Experiment trainer, which may be used to obtain
            samples.
        epoch (int): The current epoch.

    Returns:
        TimeStepBatch: Batch of samples. NOTE(review): in the Policy branch
            the value returned is whatever ``trainer.obtain_episodes``
            produces (presumably an episode batch rather than a
            TimeStepBatch) -- confirm callers accept both.

    """
    if isinstance(self._source, Policy):
        batch = trainer.obtain_episodes(epoch)
        log_performance(epoch, batch, 1.0, prefix='Expert')
        return batch

    # Track the number of collected steps incrementally instead of
    # re-summing every batch on each loop iteration.
    batches = []
    n_steps = 0
    while n_steps < self._batch_size:
        batch = next(self._source)
        batches.append(batch)
        n_steps += len(batch.actions)
    return TimeStepBatch.concatenate(*batches)
def _obtain_samples(self, runner, epoch):
    """Obtain samples from self._source.

    When the source is a Policy, trajectories are collected through the
    runner and their performance is logged under the 'Expert' prefix.
    Otherwise the source is consumed with ``next()`` until at least
    ``self._batch_size`` actions have been gathered.

    Args:
        runner (LocalRunner): LocalRunner which may be used to obtain
            samples.
        epoch (int): The current epoch.

    Returns:
        TrajectoryBatch or TimeStepBatch: A TrajectoryBatch built from the
            runner's samples when the source is a Policy; otherwise the
            concatenation of the batches drawn from the source.

    """
    if isinstance(self._source, Policy):
        batch = TrajectoryBatch.from_trajectory_list(
            self.env_spec, runner.obtain_samples(epoch))
        log_performance(epoch, batch, 1.0, prefix='Expert')
        return batch

    # Track the number of collected steps incrementally instead of
    # re-summing every batch on each loop iteration.
    batches = []
    n_steps = 0
    while n_steps < self._batch_size:
        batch = next(self._source)
        batches.append(batch)
        n_steps += len(batch.actions)
    return TimeStepBatch.concatenate(*batches)
def test_agent_infos_not_ndarray_batch(batch_data):
    """A plain list stored in agent_infos must raise ValueError.

    The fixture is modified before entering ``pytest.raises`` so that only
    the ``TimeStepBatch`` constructor can satisfy the expectation.
    """
    # Insert an empty Python list under a new key.
    batch_data['agent_infos']['bar'] = []
    with pytest.raises(ValueError, match="Entry 'bar' in agent_infos"):
        TimeStepBatch(**batch_data)
def test_from_empty_time_step_list_batch(batch_data):
    """``from_time_step_list`` must reject an empty list with ValueError.

    Only the call under test sits inside ``pytest.raises`` so nothing else
    can satisfy the expectation.
    """
    batches = []
    with pytest.raises(ValueError, match='at least one dict'):
        TimeStepBatch.from_time_step_list(batch_data['env_spec'], batches)
def test_rewards_batch_mismatch_batch(batch_data):
    """A rewards array with one entry removed must raise ValueError.

    The fixture is modified before entering ``pytest.raises`` so that only
    the ``TimeStepBatch`` constructor can satisfy the expectation.
    """
    # Drop the last reward.
    batch_data['rewards'] = batch_data['rewards'][:-1]
    with pytest.raises(ValueError, match='batch dimension of rewards'):
        TimeStepBatch(**batch_data)
def test_empty_terminals__batch(batch_data):
    """An empty terminals value must raise ValueError.

    The fixture is modified before entering ``pytest.raises`` so that only
    the ``TimeStepBatch`` constructor can satisfy the expectation.
    """
    # Replace terminals with an empty list.
    batch_data['terminals'] = []
    with pytest.raises(ValueError, match='batch dimension of terminals'):
        TimeStepBatch(**batch_data)
def test_concatenate_empty_batch():
    """``concatenate`` called with no batches must raise ValueError.

    Only the call under test sits inside ``pytest.raises`` so nothing else
    can satisfy the expectation.
    """
    batches = []
    with pytest.raises(ValueError, match='at least one'):
        TimeStepBatch.concatenate(*batches)
def test_step_types_dtype_mismatch_batch(batch_data):
    """step_types cast to float32 must raise ValueError.

    The fixture is modified before entering ``pytest.raises`` so that only
    the ``TimeStepBatch`` constructor can satisfy the expectation.
    """
    # Cast step_types to float32.
    batch_data['step_types'] = batch_data['step_types'].astype(np.float32)
    with pytest.raises(ValueError, match='step_types must be a StepType'):
        TimeStepBatch(**batch_data)
def test_step_types_batch_mismatch_batch(batch_data):
    """An empty step_types array must raise ValueError.

    The fixture is modified before entering ``pytest.raises`` so that only
    the ``TimeStepBatch`` constructor can satisfy the expectation.
    """
    # Replace step_types with an empty array.
    batch_data['step_types'] = np.array([])
    with pytest.raises(ValueError, match='batch dimension of step_types'):
        TimeStepBatch(**batch_data)
def test_terminals_dtype_mismatch_batch(batch_data):
    """terminals cast to float32 must raise ValueError.

    The fixture is modified before entering ``pytest.raises`` so that only
    the ``TimeStepBatch`` constructor can satisfy the expectation.
    """
    # Cast terminals to float32.
    batch_data['terminals'] = batch_data['terminals'].astype(np.float32)
    with pytest.raises(ValueError, match='terminals tensor must be dtype'):
        TimeStepBatch(**batch_data)
def test_agent_infos_not_ndarray_batch(batch_data):
    """A plain list stored in agent_infos must raise ValueError.

    The fixture is modified before entering ``pytest.raises`` so that only
    the ``TimeStepBatch`` constructor can satisfy the expectation.
    """
    # Insert an empty Python list under a new key.
    batch_data['agent_infos']['bar'] = []
    with pytest.raises(ValueError,
                       match='entry in agent_infos must be a numpy array'):
        TimeStepBatch(**batch_data)
def test_act_env_spec_mismatch_batch(batch_data):
    """Actions with their second axis removed must raise ValueError.

    The fixture is modified before entering ``pytest.raises`` so that only
    the ``TimeStepBatch`` constructor can satisfy the expectation.
    """
    # Keep only index 0 along the second axis, dropping that axis.
    batch_data['actions'] = batch_data['actions'][:, 0]
    with pytest.raises(ValueError, match='actions must conform'):
        TimeStepBatch(**batch_data)
def test_invalid_inferred_batch_size(batch_data):
    """An empty rewards list must raise ValueError.

    The fixture is modified before entering ``pytest.raises`` so that only
    the ``TimeStepBatch`` constructor can satisfy the expectation.
    """
    # Replace rewards with an empty list.
    batch_data['rewards'] = []
    with pytest.raises(ValueError, match='batch dimension of rewards'):
        TimeStepBatch(**batch_data)