Exemplo n.º 1
0
def test_compute_max_with_n_actions(
    observation_shape,
    action_size,
    encoder_factory,
    q_func_factory,
    n_ensembles,
    batch_size,
    n_quantiles,
    n_actions,
    lam,
):
    """Check the output shape of ``compute_max_with_n_actions``.

    A continuous Q-function is built with ``create_continuous_q_function``,
    then evaluated on random observations with ``n_actions`` random candidate
    actions per sample.  The result must have ``batch_size`` rows and either a
    single column (when ``q_func_factory`` is a ``MeanQFunctionFactory``) or
    ``q_func_factory.n_quantiles`` columns (any other factory).

    ``n_quantiles`` is not read in the body; it is kept in the signature
    because the test's parametrization (outside this view, presumably) may
    supply it.
    """
    ensemble = create_continuous_q_function(
        observation_shape,
        action_size,
        encoder_factory,
        q_func_factory,
        n_ensembles=n_ensembles,
    )

    # Random inputs: one observation and ``n_actions`` candidate actions per
    # batch element.
    observations = torch.rand(batch_size, *observation_shape)
    candidate_actions = torch.rand(batch_size, n_actions, action_size)

    result = compute_max_with_n_actions(
        observations, candidate_actions, ensemble, lam
    )

    # Only touch ``n_quantiles`` on the factory when it is not the mean
    # factory, exactly as the two-branch form would.
    expected_columns = (
        1
        if isinstance(q_func_factory, MeanQFunctionFactory)
        else q_func_factory.n_quantiles
    )
    assert result.shape == (batch_size, expected_columns)
Exemplo n.º 2
0
def test_create_continuous_q_function(
    observation_shape,
    action_size,
    batch_size,
    n_ensembles,
    encoder_factory,
    q_func_factory,
    share_encoder,
):
    """Check the object built by ``create_continuous_q_function``.

    Verifies three things:

    * the returned object is an ``EnsembleContinuousQFunction``;
    * the encoder of the first member is the same object as every other
      member's encoder when ``share_encoder`` is true, and a different
      object otherwise;
    * calling the ensemble on a random observation/action batch yields a
      result of shape ``(batch_size, 1)``.
    """
    q_func = create_continuous_q_function(
        observation_shape,
        action_size,
        encoder_factory,
        q_func_factory,
        n_ensembles,
        share_encoder=share_encoder,
    )

    assert isinstance(q_func, EnsembleContinuousQFunction)

    # check share_encoder
    # The loop variable is deliberately NOT called ``q_func``: a ``for``
    # target stays bound after the loop, so reusing the name would replace
    # the ensemble with its last member and the forward-pass check below
    # would no longer exercise the ensemble itself.
    encoder = q_func.q_funcs[0].encoder
    for member in q_func.q_funcs[1:]:
        if share_encoder:
            assert encoder is member.encoder
        else:
            assert encoder is not member.encoder

    # Forward pass through the ensemble on random inputs.
    x = torch.rand(batch_size, *observation_shape)
    action = torch.rand(batch_size, action_size)
    y = q_func(x, action)
    assert y.shape == (batch_size, 1)