コード例 #1
0
ファイル: test_imitators.py プロジェクト: mchetouani/d3rlpy
def test_create_deterministic_regressor(observation_shape, action_size,
                                        batch_size, encoder_factory):
    """Check that the factory builds a DeterministicRegressor and that a
    forward pass on a random observation batch yields one action vector of
    length ``action_size`` per batch element."""
    regressor = create_deterministic_regressor(
        observation_shape, action_size, encoder_factory)
    assert isinstance(regressor, DeterministicRegressor)

    # Prepend the batch dimension to the per-observation shape.
    input_shape = (batch_size, ) + observation_shape
    observations = torch.rand(input_shape)

    predicted = regressor(observations)
    expected_shape = (batch_size, action_size)
    assert predicted.shape == expected_shape
コード例 #2
0
 def _build_network(self):
     """Create the deterministic imitator network and store it on
     ``self.imitator``, using this object's observation shape, action size
     and encoder parameters."""
     obs_shape = self.observation_shape
     n_actions = self.action_size
     enc_params = self.encoder_params
     self.imitator = create_deterministic_regressor(
         obs_shape, n_actions, encoder_params=enc_params)