Example #1
def _build_critic(self):
    # Build the critic as an ensemble of continuous Q-functions.
    self.q_func = create_continuous_q_function(
        self.observation_shape,
        self.action_size,
        self.critic_encoder_factory,
        n_ensembles=self.n_critics,
        q_func_type=self.q_func_type,
        bootstrap=self.bootstrap,
        share_encoder=self.share_encoder)
Example #2
File: ddpg_impl.py Project: kintatta/d3rl
def _build_critic(self):
    # Build the critic as an ensemble of continuous Q-functions for DDPG.
    self.q_func = create_continuous_q_function(
        self.observation_shape,
        self.action_size,
        n_ensembles=self.n_critics,
        use_batch_norm=self.use_batch_norm,
        q_func_type=self.q_func_type,
        bootstrap=self.bootstrap,
        share_encoder=self.share_encoder,
        encoder_params=self.encoder_params)
Example #3
def test_compute_max_with_n_actions(observation_shape, action_size,
                                    n_ensembles, batch_size, n_quantiles,
                                    n_actions, lam, q_func_type):
    q_func = create_continuous_q_function(observation_shape,
                                          action_size,
                                          n_ensembles,
                                          n_quantiles,
                                          q_func_type=q_func_type)
    x = torch.rand(batch_size, *observation_shape)
    actions = torch.rand(batch_size, n_actions, action_size)

    y = compute_max_with_n_actions(x, actions, q_func, lam)

    if q_func_type == 'mean':
        assert y.shape == (batch_size, 1)
    else:
        assert y.shape == (batch_size, n_quantiles)
Example #4
def test_create_continuous_q_function(observation_shape, action_size,
                                      batch_size, n_ensembles, n_quantiles,
                                      embed_size, use_batch_norm,
                                      share_encoder, q_func_type):

    q_func = create_continuous_q_function(observation_shape,
                                          action_size,
                                          n_ensembles,
                                          n_quantiles,
                                          embed_size,
                                          use_batch_norm,
                                          q_func_type,
                                          share_encoder=share_encoder)

    assert isinstance(q_func, EnsembleContinuousQFunction)
    for f in q_func.q_funcs:
        if q_func_type == 'mean':
            assert isinstance(f, ContinuousQFunction)
        elif q_func_type == 'qr':
            assert isinstance(f, ContinuousQRQFunction)
        elif q_func_type == 'iqn':
            assert isinstance(f, ContinuousIQNQFunction)
        elif q_func_type == 'fqf':
            assert isinstance(f, ContinuousFQFQFunction)
        assert f.encoder.use_batch_norm == use_batch_norm

    # check share_encoder
    encoder = q_func.q_funcs[0].encoder
    for f in q_func.q_funcs[1:]:
        if share_encoder:
            assert encoder is f.encoder
        else:
            assert encoder is not f.encoder

    x = torch.rand((batch_size, ) + observation_shape)
    action = torch.rand(batch_size, action_size)
    y = q_func(x, action)
    assert y.shape == (batch_size, 1)
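
The call signatures above differ slightly between project versions, so the following is only an illustrative sketch of standalone usage. It mirrors the positional pattern of Example #4; the import path, observation shape, and hyperparameter values are assumptions rather than part of any snippet above.

import torch

# Assumed import path: d3rlpy-style codebases expose this helper from a torch
# models module; adjust to wherever create_continuous_q_function lives in the
# version you are using.
from d3rlpy.models.torch.q_functions import create_continuous_q_function

observation_shape = (8,)  # illustrative vector observation
action_size = 2
batch_size = 32

# Build an ensemble of two standard ('mean') continuous Q-functions, following
# the (observation_shape, action_size, n_ensembles, ...) pattern shown above.
q_func = create_continuous_q_function(observation_shape,
                                      action_size,
                                      n_ensembles=2,
                                      q_func_type='mean')

x = torch.rand(batch_size, *observation_shape)
action = torch.rand(batch_size, action_size)

# As in Example #4, the ensemble returns one value per sample.
y = q_func(x, action)
assert y.shape == (batch_size, 1)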