Example 1
    def sample(self, batch_size, buckets=False):
        """Samples a batch of datapoints.

        Args:
            batch_size (int): Number of datapoints to sample.
            buckets (bool): Whether bucket indices should be returned as well.

        Returns:
            Datapoint object with sampled datapoints stacked along axis 0,
            with an array of bucket indices appended when buckets is True.

        Raises:
            ValueError: If the buffer is empty.
        """
        if self._hierarchy_depth > 1:
            samples = [self._sample_one() for _ in range(batch_size)]
        else:
            p = np.ones((len(self._buffer_hierarchy), ), dtype=np.float32)
            p = np.atleast_1d(p) / p.sum()
            samples = []
            distribution = multinomial(batch_size, p=p)
            rvs = distribution.rvs(1).squeeze(axis=0)
            for bucket, n_samples in zip(self._buffer_hierarchy, rvs):
                if self._hierarchy_depth > 0:
                    buffer = self._buffer_hierarchy[bucket]
                else:
                    buffer = self._buffer_hierarchy
                samples_b = buffer.sample(n_samples)
                if buckets:
                    samples_b = samples_b + (np.full(n_samples, bucket), )
                samples.append(samples_b)

        return data.nested_concatenate(samples)
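All of these examples revolve around data.nested_concatenate, which merges a list of identically-structured nested objects by concatenating their leaf arrays along axis 0. Its implementation is not shown in this listing; the sketch below is a minimal reconstruction, assuming containers are dicts, lists, and (named) tuples, and every leaf is a NumPy array.

import numpy as np

def nested_concatenate(structures):
    """Concatenates same-layout nested structures along axis 0.

    Minimal sketch (assumption), not the library's actual implementation.
    """
    first = structures[0]
    if isinstance(first, dict):
        return {key: nested_concatenate([s[key] for s in structures])
                for key in first}
    if isinstance(first, (list, tuple)):
        fields = [nested_concatenate([s[i] for s in structures])
                  for i in range(len(first))]
        # Preserve the container type; namedtuples expose _make.
        if hasattr(first, '_make'):
            return first._make(fields)
        return type(first)(fields)
    # Leaf: concatenate the arrays along the batch axis.
    return np.concatenate(structures, axis=0)

For instance, nested_concatenate([{'a': np.zeros(2)}, {'a': np.ones(3)}]) returns a dict whose 'a' entry is an array of shape (5,).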
Example 2
    def sample(self, batch_size):
        """Samples a batch of datapoints.

        Args:
            batch_size (int): Number of datapoints to sample.

        Returns:
            Datapoint object with sampled datapoints stacked along axis 0.

        Raises:
            ValueError: If the buffer is empty.
        """
        return data.nested_concatenate(
            [self._sample_one() for _ in range(batch_size)]
        )
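Examples 1, 2, and 8 all delegate per-datapoint sampling to a private _sample_one helper that this listing never shows. Given the _buffer_hierarchy and _hierarchy_depth attributes used above, a plausible (purely hypothetical) reading is a uniform random walk down the bucket hierarchy followed by a single draw from the leaf buffer:

    def _sample_one(self):
        # Hypothetical sketch: descend _hierarchy_depth levels, picking a
        # bucket uniformly at random at each level (random is the standard
        # library module), then draw a single datapoint from the leaf buffer.
        hierarchy = self._buffer_hierarchy
        for _ in range(self._hierarchy_depth):
            bucket = random.choice(list(hierarchy.keys()))
            hierarchy = hierarchy[bucket]
        return hierarchy.sample(1)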
Example 3
    def compute_metrics(episodes):
        """Computes scalar statistics of agent info over a batch of episodes."""
        agent_info_batch = data.nested_concatenate(
            [episode.transition_batch.agent_info for episode in episodes])
        metrics = {}

        metrics.update(
            metric_utils.compute_scalar_statistics(agent_info_batch['value'],
                                                   prefix='value',
                                                   with_min_and_max=True))
        metrics.update(
            metric_utils.compute_scalar_statistics(agent_info_batch['logits'],
                                                   prefix='logits',
                                                   with_min_and_max=True))

        return metrics
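metric_utils.compute_scalar_statistics is not defined anywhere in these snippets; judging from how its result is merged into a metrics dict, it presumably reduces an array to prefixed scalar summaries. A minimal sketch under that assumption:

import numpy as np

def compute_scalar_statistics(x, prefix=None, with_min_and_max=False):
    # Sketch (assumption): summarize an array as prefixed scalar metrics.
    prefix = prefix + '_' if prefix else ''
    stats = {
        prefix + 'mean': np.mean(x),
        prefix + 'std': np.std(x),
    }
    if with_min_and_max:
        stats[prefix + 'min'] = np.min(x)
        stats[prefix + 'max'] = np.max(x)
    return stats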
Example 4
    def batched_request(self):
        """Concatenates the stored requests into one batched request, caching it."""
        if self._batched_request is not None:
            return self._batched_request

        def assert_not_scalar(x):
            assert np.array(x).shape, (
                'All arrays in a PredictRequest must be at least rank 1.'
            )

        data.nested_map(assert_not_scalar, self._requests)

        self._batched_request = data.nested_concatenate(self._requests)

        self._batch_sizes = [
            data.choose_leaf(request).shape[0]
            for request in self._requests
        ]
        return self._batched_request
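The _batch_sizes list recorded above hints at the inverse operation: once a batched prediction comes back, it has to be split into per-request chunks. That step is not shown in the source; below is a hedged sketch reusing the data.nested_map seen above (the method name and splitting details are assumptions):

    def unbatch_responses(self, batched_response):
        # Hypothetical: np.split takes cumulative split points, so drop the
        # final cumulative sum. Each leaf of the result becomes a list of
        # per-request arrays.
        boundaries = np.cumsum(self._batch_sizes)[:-1]
        return data.nested_map(
            lambda leaf: np.split(leaf, boundaries), batched_response)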
Example 5
    def compute_metrics(episodes):
        # Calculate simulation policy entropy.
        agent_info_batch = data.nested_concatenate(
            [episode.transition_batch.agent_info for episode in episodes])
        metrics = {}

        if 'sim_pi_entropy' in agent_info_batch:
            metrics.update(
                metric_logging.compute_scalar_statistics(
                    agent_info_batch['sim_pi_entropy'],
                    prefix='simulation_entropy',
                    with_min_and_max=True))

        if 'sim_pi_value' in agent_info_batch:
            metrics.update(
                metric_logging.compute_scalar_statistics(
                    agent_info_batch['sim_pi_value'],
                    prefix='network_value',
                    with_min_and_max=True))

        if 'sim_pi_logits' in agent_info_batch:
            metrics.update(
                metric_logging.compute_scalar_statistics(
                    agent_info_batch['sim_pi_logits'],
                    prefix='network_logits',
                    with_min_and_max=True))

        if 'value' in agent_info_batch:
            metrics.update(
                metric_logging.compute_scalar_statistics(
                    agent_info_batch['value'],
                    prefix='simulation_value',
                    with_min_and_max=True))

        if 'qualities' in agent_info_batch:
            metrics.update(
                metric_logging.compute_scalar_statistics(
                    agent_info_batch['qualities'],
                    prefix='simulation_qualities',
                    with_min_and_max=True))

        return metrics
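The five near-identical guards above all follow one pattern: check for a key in agent_info_batch, then log statistics under a prefix. A table-driven rewrite with the same behavior (a stylistic alternative, assuming the same module context as the original):

_KEY_TO_PREFIX = {
    'sim_pi_entropy': 'simulation_entropy',
    'sim_pi_value': 'network_value',
    'sim_pi_logits': 'network_logits',
    'value': 'simulation_value',
    'qualities': 'simulation_qualities',
}

def compute_metrics(episodes):
    agent_info_batch = data.nested_concatenate(
        [episode.transition_batch.agent_info for episode in episodes])
    metrics = {}
    # Same guards as above, driven by a key -> prefix table.
    for key, prefix in _KEY_TO_PREFIX.items():
        if key in agent_info_batch:
            metrics.update(
                metric_logging.compute_scalar_statistics(
                    agent_info_batch[key],
                    prefix=prefix,
                    with_min_and_max=True))
    return metrics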
Example 6
    def act(self, observation):
        """Runs n_rollouts simulations and chooses the best action."""
        assert self._model is not None, 'Reset ShootingAgent first.'
        del observation

        # Lazily initialize the batch stepper.
        if self._batch_stepper is None:
            self._batch_stepper = self._batch_stepper_class(
                env_class=type(self._model),
                agent_class=self._agent_class,
                network_fn=self._network_fn,
                n_envs=self._n_envs,
                output_dir=None,
            )

        # Run the batch stepper until n_rollouts episodes are collected
        # (n_envs episodes per batch).
        episodes = []
        for _ in range(math.ceil(self._n_rollouts / self._n_envs)):
            episodes.extend(
                self._batch_stepper.run_episode_batch(
                    agent_params=self._params,
                    epoch=self._epoch,
                    init_state=self._model.clone_state(),
                    time_limit=self._rollout_time_limit,
                ))

        # Compute episode returns and put them in a map.
        returns_ = yield from self._estimate_fn(episodes, self._discount)
        act_to_rets_map = {key: [] for key in range(self._action_space.n)}
        for episode, return_ in zip(episodes, returns_):
            act_to_rets_map[episode.transition_batch.action[0]].append(return_)

        # Aggregate episodes into action scores.
        action_scores = np.empty(self._action_space.n)
        for action, returns in act_to_rets_map.items():
            action_scores[action] = (self._aggregate_fn(returns)
                                     if returns else np.nan)

        # Calculate the estimate of a state value.
        value = sum(returns_) / len(returns_)

        # Choose greedy action (ignore NaN scores).
        action = np.nanargmax(action_scores)
        onehot_action = np.zeros_like(action_scores)
        onehot_action[action] = 1

        # Pack statistics into agent info.
        agent_info = {
            'action_histogram': onehot_action,
            'value': value,
            'qualities': np.nan_to_num(action_scores),
        }

        # Calculate simulation policy entropy, average value and logits.
        agent_info_batch = data.nested_concatenate(
            [episode.transition_batch.agent_info for episode in episodes])
        if 'entropy' in agent_info_batch:
            agent_info['sim_pi_entropy'] = np.mean(agent_info_batch['entropy'])
        if 'value' in agent_info_batch:
            agent_info['sim_pi_value'] = np.mean(agent_info_batch['value'])
        if 'logits' in agent_info_batch:
            agent_info['sim_pi_logits'] = np.mean(agent_info_batch['logits'])

        self._run_agent_callbacks(episodes)
        return action, agent_info
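The greedy-action step above leans on np.nanargmax to skip actions that were never sampled, since their scores stay NaN. A tiny standalone illustration:

import numpy as np

action_scores = np.array([np.nan, 1.5, np.nan, 0.7])
action = np.nanargmax(action_scores)       # 1; NaN entries are ignored
onehot_action = np.zeros_like(action_scores)
onehot_action[action] = 1                  # [0., 1., 0., 0.]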
Example 7
def test_batch_steppers_run_episode_batch(max_n_requests, batch_stepper_cls):
    n_envs = 8
    max_n_steps = 4
    n_total_steps = n_envs * max_n_steps
    n_total_requests = n_total_steps * max_n_requests

    # Generate some random data.
    def sample_seq(n):
        return [np.random.randint(1, 999) for _ in range(n)]

    def setup_seq(n):
        expected = sample_seq(n)
        to_return = copy.copy(expected)
        actual = []
        return (expected, to_return, actual)

    (expected_rew, rew_to_return, _) = setup_seq(n_total_steps)
    (expected_obs, obs_to_return, actual_obs) = setup_seq(n_total_steps)
    (expected_act, act_to_return, actual_act) = setup_seq(n_total_steps)
    (expected_req, req_to_return, actual_req) = setup_seq(n_total_requests)
    (expected_res, res_to_return, actual_res) = setup_seq(n_total_requests)

    # Connect all pipes together.
    stepper = batch_stepper_cls(
        env_class=functools.partial(
            _TestEnv,
            actions=actual_act,
            n_steps=max_n_steps,
            observations=obs_to_return,
            rewards=rew_to_return,
        ),
        agent_class=functools.partial(
            _TestAgent,
            observations=actual_obs,
            max_n_requests=max_n_requests,
            requests=req_to_return,
            responses=actual_res,
            actions=act_to_return,
        ),
        network_fn=functools.partial(
            _TestNetwork,
            inputs=actual_req,
            outputs=res_to_return,
        ),
        n_envs=n_envs,
        output_dir=None,
    )
    episodes = stepper.run_episode_batch(params=None)
    transition_batch = data.nested_concatenate(
        # pylint: disable=not-an-iterable
        [episode.transition_batch for episode in episodes])

    # Assert that all data got passed around correctly.
    assert len(actual_obs) >= n_envs
    np.testing.assert_array_equal(actual_obs, expected_obs[:len(actual_obs)])
    np.testing.assert_array_equal(actual_req, expected_req[:len(actual_req)])
    np.testing.assert_array_equal(actual_res, expected_res[:len(actual_req)])
    np.testing.assert_array_equal(actual_act, expected_act[:len(actual_obs)])

    # Assert that we collected the correct transitions (order is mixed up).
    assert set(transition_batch.observation.tolist()) == set(actual_obs)
    assert set(transition_batch.action.tolist()) == set(actual_act)
    assert set(transition_batch.reward.tolist()) == set(
        expected_rew[:len(actual_obs)])
    assert transition_batch.done.sum() == n_envs
Example 8
    def sample(self, batch_size):
        return data.nested_concatenate(
            [self._sample_one() for _ in range(batch_size)])
Example 9
    def compute_metrics(episodes):
        agent_info_batch = data.nested_concatenate(
            [episode.transition_batch.agent_info for episode in episodes])

        return metric_logging.compute_scalar_statistics(
            agent_info_batch['logits'], prefix='logits', with_min_and_max=True)