示例#1
0
def test_compute_max_with_n_actions(
    observation_shape,
    action_size,
    encoder_factory,
    q_func_factory,
    n_ensembles,
    batch_size,
    n_quantiles,
    n_actions,
    lam,
):
    """Check the output shape of ``compute_max_with_n_actions``.

    A continuous Q-function is built from ``encoder_factory`` and
    ``q_func_factory`` with ``n_ensembles`` members. Then ``n_actions``
    random candidate actions per observation are passed to
    ``compute_max_with_n_actions``. With a ``MeanQFunctionFactory`` the
    result must have shape ``(batch_size, 1)``. Otherwise it must have
    shape ``(batch_size, q_func_factory.n_quantiles)``.

    Note: the ``n_quantiles`` argument is accepted but not read; the expected
    width is taken from the factory itself.
    """
    q_function = create_continuous_q_function(
        observation_shape,
        action_size,
        encoder_factory,
        q_func_factory,
        n_ensembles=n_ensembles,
    )
    observations = torch.rand(batch_size, *observation_shape)
    candidate_actions = torch.rand(batch_size, n_actions, action_size)

    result = compute_max_with_n_actions(
        observations, candidate_actions, q_function, lam
    )

    if isinstance(q_func_factory, MeanQFunctionFactory):
        expected_shape = (batch_size, 1)
    else:
        expected_shape = (batch_size, q_func_factory.n_quantiles)
    assert result.shape == expected_shape
示例#2
0
    def compute_target(self, x):
        """Return target values for ``x`` computed without gradient tracking.

        ``x`` is repeated with ``self._repeat_observation``, and actions are
        drawn for the repeated batch with ``self._sample_action`` (second
        argument ``True``). The result of ``compute_max_with_n_actions`` is
        then returned, evaluated on ``x``, those actions,
        ``self.targ_q_func`` and ``self.lam``.
        """
        # TODO: this seems to be slow with image observation
        with torch.no_grad():
            sampled = self._sample_action(self._repeat_observation(x), True)
            return compute_max_with_n_actions(
                x, sampled, self.targ_q_func, self.lam
            )
示例#3
0
    def compute_target(self, x):
        """Return a BCQ-like target penalised by the policy's log-probability.

        Under ``torch.no_grad()``, ``self.n_action_samples`` actions (and
        their log-probs) are sampled per observation from ``self.policy``.
        The sampled actions are passed to ``compute_max_with_n_actions`` on
        the target Q-function, which returns values together with the index
        of the selected sample. The log-prob at that index, scaled by
        ``exp(self.log_temp)``, is subtracted from the values.
        """
        with torch.no_grad():
            # BCQ-like target computation
            sampled_actions, sampled_log_probs = self.policy.sample_n(
                x, self.n_action_samples, True
            )
            selected_values, selected_idx = compute_max_with_n_actions(
                x, sampled_actions, self.targ_q_func, self.lam, True
            )

            # gather the log-prob of the selected sample for every row:
            # (batch, n, 1) -> (batch, 1)
            row_idx = torch.arange(x.shape[0])
            selected_log_prob = sampled_log_probs[row_idx, selected_idx]

            temperature = self.log_temp.exp()
            return selected_values - temperature * selected_log_prob
示例#4
0
def test_compute_max_with_n_actions(observation_shape, action_size,
                                    n_ensembles, batch_size, n_quantiles,
                                    n_actions, lam, q_func_type):
    """Check the output shape of ``compute_max_with_n_actions``.

    A continuous Q-function of kind ``q_func_type`` is built, and
    ``n_actions`` random candidate actions per observation are passed to
    ``compute_max_with_n_actions``. For ``q_func_type == 'mean'`` the result
    must be ``(batch_size, 1)``; for any other type it must be
    ``(batch_size, n_quantiles)``.
    """
    q_function = create_continuous_q_function(observation_shape,
                                              action_size,
                                              n_ensembles,
                                              n_quantiles,
                                              q_func_type=q_func_type)
    observations = torch.rand(batch_size, *observation_shape)
    candidate_actions = torch.rand(batch_size, n_actions, action_size)

    result = compute_max_with_n_actions(observations, candidate_actions,
                                        q_function, lam)

    expected_width = 1 if q_func_type == 'mean' else n_quantiles
    assert result.shape == (batch_size, expected_width)