Example #1
0
def test_split_batch(batch_data):
    """split() yields one single-step TimeStepBatch per original step.

    Each resulting batch must carry the original env_spec and, for every
    field and every info entry, the i-th row of the source data wrapped in
    a length-1 batch.
    """
    s = TimeStepBatch(
        env_spec=batch_data['env_spec'],
        observations=batch_data['observations'],
        actions=batch_data['actions'],
        rewards=batch_data['rewards'],
        next_observations=batch_data['next_observations'],
        step_types=batch_data['step_types'],
        env_infos=batch_data['env_infos'],
        agent_infos=batch_data['agent_infos'],
    )
    batches = s.split()

    assert len(batches) == 2  # original batch_data is a batch of 2
    for i, batch in enumerate(batches):
        assert batch.env_spec == batch_data['env_spec']
        assert np.array_equal(batch.observations,
                              [batch_data['observations'][i]])
        assert np.array_equal(batch.next_observations,
                              [batch_data['next_observations'][i]])
        assert np.array_equal(batch.actions, [batch_data['actions'][i]])
        assert np.array_equal(batch.rewards, [batch_data['rewards'][i]])
        assert np.array_equal(batch.step_types, [batch_data['step_types'][i]])
        # Compare key sets so that a key dropped by split() fails the test
        # instead of silently skipping its value check.
        assert set(batch.env_infos) == set(batch_data['env_infos'])
        for key in batch.env_infos:
            assert np.array_equal(batch.env_infos[key],
                                  [batch_data['env_infos'][key][i]])
        assert set(batch.agent_infos) == set(batch_data['agent_infos'])
        for key in batch.agent_infos:
            assert np.array_equal(batch.agent_infos[key],
                                  [batch_data['agent_infos'][key][i]])
Example #2
0
def test_to_time_step_list_batch(batch_data):
    """to_time_step_list() yields one dict per step with length-1 fields.

    Every field and every info entry of the i-th dict must equal the i-th
    row of the source data wrapped in a length-1 batch.
    """
    s = TimeStepBatch(
        env_spec=batch_data['env_spec'],
        observations=batch_data['observations'],
        actions=batch_data['actions'],
        rewards=batch_data['rewards'],
        next_observations=batch_data['next_observations'],
        step_types=batch_data['step_types'],
        env_infos=batch_data['env_infos'],
        agent_infos=batch_data['agent_infos'],
    )
    batches = s.to_time_step_list()

    assert len(batches) == 2  # original batch_data is a batch of 2
    for i, batch in enumerate(batches):
        assert np.array_equal(batch['observations'],
                              [batch_data['observations'][i]])
        assert np.array_equal(batch['next_observations'],
                              [batch_data['next_observations'][i]])
        assert np.array_equal(batch['actions'], [batch_data['actions'][i]])
        assert np.array_equal(batch['rewards'], [batch_data['rewards'][i]])
        assert np.array_equal(batch['step_types'],
                              [batch_data['step_types'][i]])
        # Compare key sets so a dropped info key fails the test instead of
        # silently skipping its value check.
        assert set(batch['env_infos']) == set(batch_data['env_infos'])
        for key in batch['env_infos']:
            assert np.array_equal(batch['env_infos'][key],
                                  [batch_data['env_infos'][key][i]])
        assert set(batch['agent_infos']) == set(batch_data['agent_infos'])
        for key in batch['agent_infos']:
            assert np.array_equal(batch['agent_infos'][key],
                                  [batch_data['agent_infos'][key][i]])
Example #3
0
def test_observations_env_spec_mismatch_batch(batch_data):
    """Observations whose shape disagrees with the env spec are rejected.

    Truncating the last observation axis must raise against the fixture's
    own observation space and against a freshly built (4, 5, 2) Box space.
    """
    # Build the bad input outside ``pytest.raises`` so only the constructor
    # is under test; a ValueError raised by setup would otherwise make the
    # test pass spuriously.
    batch_data['observations'] = batch_data['observations'][:, :, :, :1]
    with pytest.raises(ValueError, match='Each observation has shape'):
        TimeStepBatch(**batch_data)

    obs_space = akro.Box(low=1, high=10, shape=(4, 5, 2), dtype=np.float32)
    act_space = gym.spaces.MultiDiscrete([2, 5])
    env_spec = EnvSpec(obs_space, act_space)
    batch_data['env_spec'] = env_spec

    # The observations are still the truncated ones from above (slicing
    # ``[:, :, :, :1]`` again would be a no-op), so they mismatch here too.
    with pytest.raises(ValueError, match='Each observation has shape'):
        TimeStepBatch(**batch_data)
Example #4
0
def test_next_observations_batch_mismatch_batch(batch_data):
    """next_observations one row short of the batch size are rejected."""
    # Prepare the bad input outside ``pytest.raises`` so only the
    # constructor call can satisfy the expectation.
    batch_data['next_observations'] = batch_data['next_observations'][:-1]
    with pytest.raises(ValueError,
                       match='batch dimension of '
                       'next_observations'):
        TimeStepBatch(**batch_data)
Example #5
0
def expert_source(env, goal, max_episode_length, n_eps):
    """Generate TimeStepBatches sampled from an OptimalPolicy expert.

    Args:
        env: Environment to sample from; ``env.spec`` is handed to the policy.
        goal: Goal forwarded to ``OptimalPolicy``.
        max_episode_length (int): Maximum episode length for the workers;
            also passed as the sample count of each ``obtain_samples`` call.
        n_eps (int): Number of batches to yield (one sampler call each).

    Yields:
        TimeStepBatch: Timesteps converted from each sampled episode batch.

    """
    policy = OptimalPolicy(env.spec, goal=goal)
    factory = WorkerFactory(seed=100, max_episode_length=max_episode_length)
    sampler = LocalSampler.from_worker_factory(factory, policy, env)
    for _ in range(n_eps):
        # Iteration index 0 and no agent update on every call.
        episodes = sampler.obtain_samples(0, max_episode_length, None)
        yield TimeStepBatch.from_episode_batch(episodes)
Example #6
0
    def sample_timesteps(self, batch_size):
        """Draw transitions from the buffer and wrap them as timesteps.

        Args:
            batch_size (int): Number of timesteps to sample.

        Returns:
            TimeStepBatch: The batch of timesteps. Its episode, env and
                agent info dicts are empty.

        """
        transitions = self.sample_transitions(batch_size)
        # Only the stored terminal flags are available, so each step is
        # labelled either TERMINAL or MID.
        terminal_flags = transitions['terminals'].reshape(-1)
        step_types = np.array(
            [StepType.TERMINAL if flag else StepType.MID
             for flag in terminal_flags],
            dtype=StepType)
        return TimeStepBatch(
            env_spec=self._env_spec,
            episode_infos={},
            observations=transitions['observations'],
            actions=transitions['actions'],
            rewards=transitions['rewards'].flatten(),
            next_observations=transitions['next_observations'],
            step_types=step_types,
            env_infos={},
            agent_infos={},
        )
Example #7
0
def test_agent_infos_batch_mismatch_batch(batch_data):
    """An agent_infos entry one row short of the batch is rejected."""
    # Truncate outside ``pytest.raises`` so only the constructor is tested.
    batch_data['agent_infos']['hidden'] = (
        batch_data['agent_infos']['hidden'][:-1])
    with pytest.raises(ValueError,
                       match="Entry 'hidden' in agent_infos has batch size 1"):
        TimeStepBatch(**batch_data)
Example #8
0
def test_from_time_step_list_batch(batch_data):
    """from_time_step_list concatenates every field of the dicts in order."""
    time_steps = [batch_data, batch_data]
    s = TimeStepBatch.from_time_step_list(batch_data['env_spec'], time_steps)

    array_fields = ('observations', 'next_observations', 'actions', 'rewards',
                    'step_types')
    expected = {
        field: np.concatenate([batch_data[field], batch_data[field]])
        for field in array_fields
    }
    expected_infos = {
        info_name: {
            k: np.concatenate([step[info_name][k] for step in time_steps])
            for k in time_steps[0][info_name].keys()
        }
        for info_name in ('env_infos', 'agent_infos')
    }

    assert s.env_spec == batch_data['env_spec']
    for field in array_fields:
        assert np.array_equal(getattr(s, field), expected[field])
    for info_name, infos in expected_infos.items():
        actual = getattr(s, info_name)
        for key, value in infos.items():
            assert key in actual
            assert np.array_equal(value, actual[key])
Example #9
0
def test_act_box_env_spec_mismatch_batch(batch_data):
    """Actions that do not match a (4, 3, 2) Box action space are rejected."""
    # Swap in the incompatible action space outside ``pytest.raises`` so only
    # the constructor call can satisfy the expectation.
    batch_data['env_spec'] = EnvSpec(
        batch_data['env_spec'].observation_space,
        akro.Box(low=1, high=np.inf, shape=(4, 3, 2), dtype=np.float32))
    with pytest.raises(ValueError, match='Each action has'):
        TimeStepBatch(**batch_data)
Example #10
0
def test_time_step_batch_from_episode_batch(eps_data):
    """from_episode_batch shifts observations to build next_observations.

    Within each episode, the next observation of step t is observation t+1.
    The episode's final step (index ``stop - 1``, not ``stop``) takes the
    episode's last observation.
    """
    eps = EpisodeBatch(**eps_data)
    timestep_batch = TimeStepBatch.from_episode_batch(eps)
    assert (timestep_batch.observations == eps.observations).all()
    start = 0
    for i, length in enumerate(eps.lengths):
        stop = start + length
        assert (timestep_batch.next_observations[start:stop - 1] ==
                eps.observations[start + 1:stop]).all()
        # Index ``stop`` would already be the first step of the next episode.
        assert (timestep_batch.next_observations[stop - 1] ==
                eps.last_observations[i]).all()
        start = stop
Example #11
0
def test_act_box_env_spec_mismatch_batch(batch_data):
    """Actions that do not match a (4, 3, 2) Box action space are rejected."""
    # NOTE(review): this mutates the fixture's env_spec object in place;
    # confirm the fixture is function-scoped so the change cannot leak.
    # Done outside ``pytest.raises`` so only the constructor is under test.
    batch_data['env_spec'].action_space = akro.Box(low=1,
                                                   high=np.inf,
                                                   shape=(4, 3, 2),
                                                   dtype=np.float32)
    with pytest.raises(ValueError, match='actions should have'):
        TimeStepBatch(**batch_data)
Example #12
0
def test_agent_infos_batch_mismatch_batch(batch_data):
    """An agent_infos entry one row short of the batch is rejected."""
    # Truncate outside ``pytest.raises`` so only the constructor is tested.
    batch_data['agent_infos']['hidden'] = (
        batch_data['agent_infos']['hidden'][:-1])
    with pytest.raises(
            ValueError,
            match='entry in agent_infos must have a batch dimension'):
        TimeStepBatch(**batch_data)
Example #13
0
def test_time_step_batch_from_trajectory_batch(traj_data):
    """from_trajectory_batch shifts observations to build next_observations.

    Within each trajectory, the next observation of step t is observation
    t+1. The trajectory's final step (index ``stop - 1``, not ``stop``)
    takes the trajectory's last observation.
    """
    traj = TrajectoryBatch(**traj_data)
    timestep_batch = TimeStepBatch.from_trajectory_batch(traj)
    assert (timestep_batch.observations == traj.observations).all()
    start = 0
    for i, length in enumerate(traj.lengths):
        stop = start + length
        assert (timestep_batch.next_observations[start:stop - 1] ==
                traj.observations[start + 1:stop]).all()
        # Index ``stop`` would already be the first step of the next one.
        assert (timestep_batch.next_observations[stop - 1] ==
                traj.last_observations[i]).all()
        start = stop
Example #14
0
def test_new_ts_batch(batch_data):
    """The constructor keeps each argument by identity, without copying."""
    batch = TimeStepBatch(**batch_data)
    stored_fields = ('env_spec', 'observations', 'next_observations',
                     'actions', 'rewards', 'env_infos', 'agent_infos',
                     'step_types')
    for field in stored_fields:
        assert getattr(batch, field) is batch_data[field]
Example #15
0
def test_next_observations_env_spec_mismatch_batch(batch_data):
    """next_observations whose shape disagrees with the env spec are rejected.

    Truncating the last axis must raise against the fixture's own
    observation space and against a freshly built (4, 3, 2) Box space.
    """
    # Build the bad input outside ``pytest.raises`` so only the constructor
    # is under test.
    batch_data['next_observations'] = batch_data[
        'next_observations'][:, :, :, :1]
    with pytest.raises(ValueError, match='next_observations must conform'):
        TimeStepBatch(**batch_data)

    obs_space = akro.Box(low=1, high=10, shape=(4, 3, 2), dtype=np.float32)
    act_space = gym.spaces.MultiDiscrete([2, 5])
    env_spec = EnvSpec(obs_space, act_space)
    batch_data['env_spec'] = env_spec

    # next_observations are still truncated from above (slicing
    # ``[:, :, :, :1]`` again would be a no-op).
    with pytest.raises(
            ValueError,
            match='next_observations should have the same dimensionality'):
        TimeStepBatch(**batch_data)
Example #16
0
def test_terminals(batch_data):
    """The derived ``terminals`` array has the same shape as ``rewards``."""
    constructor_args = ('env_spec', 'observations', 'actions', 'rewards',
                        'next_observations', 'step_types', 'env_infos',
                        'agent_infos')
    batch = TimeStepBatch(**{name: batch_data[name]
                             for name in constructor_args})
    assert batch.terminals.shape == batch.rewards.shape
Example #17
0
def test_concatenate_batch(batch_data):
    """concatenate() stacks every field of the given batches in order."""
    single_batch = TimeStepBatch(**batch_data)
    parts = [single_batch, single_batch]
    s = TimeStepBatch.concatenate(*parts)

    array_fields = ('observations', 'next_observations', 'actions', 'rewards',
                    'step_types')
    expected = {
        field: np.concatenate([batch_data[field], batch_data[field]])
        for field in array_fields
    }
    expected_infos = {
        info_name: {
            k: np.concatenate([getattr(part, info_name)[k] for part in parts])
            for k in getattr(parts[0], info_name).keys()
        }
        for info_name in ('env_infos', 'agent_infos')
    }

    assert s.env_spec == batch_data['env_spec']
    for field in array_fields:
        assert np.array_equal(getattr(s, field), expected[field])
    for info_name, infos in expected_infos.items():
        actual = getattr(s, info_name)
        for key, value in infos.items():
            assert key in actual
            assert np.array_equal(value, actual[key])
Example #18
0
    def _obtain_samples(self, trainer, epoch):
        """Obtain samples from self._source.

        If the source is a Policy, episodes are collected through the
        trainer and their performance is logged with the 'Expert' prefix.
        Otherwise the source is treated as an iterator of TimeStepBatches.
        Batches are drawn until at least ``self._batch_size`` timesteps have
        been gathered, then concatenated.

        Args:
            trainer (Trainer): Experiment trainer, which may be used to
                obtain samples.
            epoch (int): The current epoch.

        Returns:
            EpisodeBatch or TimeStepBatch: The result of
                ``trainer.obtain_episodes`` when the source is a Policy
                (presumably an EpisodeBatch); otherwise a TimeStepBatch.

        """
        if isinstance(self._source, Policy):
            batch = trainer.obtain_episodes(epoch)
            log_performance(epoch, batch, 1.0, prefix='Expert')
            return batch
        batches = []
        # Keep a running total instead of re-summing every collected batch
        # on each iteration (which was quadratic in the number of batches).
        n_timesteps = 0
        while n_timesteps < self._batch_size:
            batch = next(self._source)
            batches.append(batch)
            n_timesteps += len(batch.actions)
        return TimeStepBatch.concatenate(*batches)
Example #19
0
    def _obtain_samples(self, runner, epoch):
        """Obtain samples from self._source.

        If the source is a Policy, trajectories are collected through the
        runner and their performance is logged with the 'Expert' prefix.
        Otherwise the source is treated as an iterator of TimeStepBatches.
        Batches are drawn until at least ``self._batch_size`` timesteps have
        been gathered, then concatenated.

        Args:
            runner (LocalRunner): LocalRunner, which may be used to obtain
                samples.
            epoch (int): The current epoch.

        Returns:
            TrajectoryBatch or TimeStepBatch: A TrajectoryBatch when the
                source is a Policy; otherwise a TimeStepBatch.

        """
        if isinstance(self._source, Policy):
            batch = TrajectoryBatch.from_trajectory_list(
                self.env_spec, runner.obtain_samples(epoch))
            log_performance(epoch, batch, 1.0, prefix='Expert')
            return batch
        batches = []
        # Keep a running total instead of re-summing every collected batch
        # on each iteration (which was quadratic in the number of batches).
        n_timesteps = 0
        while n_timesteps < self._batch_size:
            batch = next(self._source)
            batches.append(batch)
            n_timesteps += len(batch.actions)
        return TimeStepBatch.concatenate(*batches)
Example #20
0
def test_agent_infos_not_ndarray_batch(batch_data):
    """A non-ndarray agent_infos entry is rejected."""
    # Inject the bad entry outside ``pytest.raises`` so only the constructor
    # is under test.
    batch_data['agent_infos']['bar'] = []
    with pytest.raises(ValueError, match="Entry 'bar' in agent_infos"):
        TimeStepBatch(**batch_data)
Example #21
0
def test_from_empty_time_step_list_batch(batch_data):
    """from_time_step_list refuses an empty list of time steps."""
    env_spec = batch_data['env_spec']
    with pytest.raises(ValueError, match='at least one dict'):
        TimeStepBatch.from_time_step_list(env_spec, [])
Example #22
0
def test_rewards_batch_mismatch_batch(batch_data):
    """Rewards one row short of the batch size are rejected."""
    # Truncate outside ``pytest.raises`` so only the constructor is tested.
    batch_data['rewards'] = batch_data['rewards'][:-1]
    with pytest.raises(ValueError, match='batch dimension of rewards'):
        TimeStepBatch(**batch_data)
Example #23
0
def test_empty_terminals__batch(batch_data):
    """An empty terminals sequence is rejected for its batch dimension."""
    # Set up outside ``pytest.raises`` so only the constructor is tested.
    batch_data['terminals'] = []
    with pytest.raises(ValueError, match='batch dimension of terminals'):
        TimeStepBatch(**batch_data)
Example #24
0
def test_concatenate_empty_batch():
    """concatenate() with no batches raises ValueError."""
    with pytest.raises(ValueError, match='at least one'):
        TimeStepBatch.concatenate()
Example #25
0
def test_step_types_dtype_mismatch_batch(batch_data):
    """step_types cast to float32 are rejected for their dtype."""
    # Cast outside ``pytest.raises`` so only the constructor is tested.
    batch_data['step_types'] = batch_data['step_types'].astype(np.float32)
    with pytest.raises(ValueError, match='step_types must be a StepType'):
        TimeStepBatch(**batch_data)
Example #26
0
def test_step_types_batch_mismatch_batch(batch_data):
    """An empty step_types array is rejected for its batch dimension."""
    # Set up outside ``pytest.raises`` so only the constructor is tested.
    batch_data['step_types'] = np.array([])
    with pytest.raises(ValueError, match='batch dimension of step_types'):
        TimeStepBatch(**batch_data)
Example #27
0
def test_terminals_dtype_mismatch_batch(batch_data):
    """terminals cast to float32 are rejected for their dtype."""
    # Cast outside ``pytest.raises`` so only the constructor is tested.
    batch_data['terminals'] = batch_data['terminals'].astype(np.float32)
    with pytest.raises(ValueError, match='terminals tensor must be dtype'):
        TimeStepBatch(**batch_data)
Example #28
0
def test_agent_infos_not_ndarray_batch(batch_data):
    """A non-ndarray agent_infos entry is rejected."""
    # Inject the bad entry outside ``pytest.raises`` so only the constructor
    # is under test.
    batch_data['agent_infos']['bar'] = []
    with pytest.raises(ValueError,
                       match='entry in agent_infos must be a numpy array'):
        TimeStepBatch(**batch_data)
Example #29
0
def test_act_env_spec_mismatch_batch(batch_data):
    """Actions with their second axis indexed away are rejected."""
    # Reshape outside ``pytest.raises`` so only the constructor is tested.
    batch_data['actions'] = batch_data['actions'][:, 0]
    with pytest.raises(ValueError, match='actions must conform'):
        TimeStepBatch(**batch_data)
Example #30
0
def test_invalid_inferred_batch_size(batch_data):
    """Empty rewards are rejected for their batch dimension."""
    # Set up outside ``pytest.raises`` so only the constructor is tested.
    batch_data['rewards'] = []
    with pytest.raises(ValueError, match='batch dimension of rewards'):
        TimeStepBatch(**batch_data)