コード例 #1
0
ファイル: test_bcq_impl.py プロジェクト: wx-b/d3rlpy
def test_discrete_bcq_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    target_reduction_type,
    action_flexibility,
    beta,
    scaler,
):
    """Build a DiscreteBCQImpl from the given hyperparameters and run the
    shared torch implementation checks on it in discrete-action mode.

    The arguments are presumably supplied by pytest parametrization or
    fixtures defined outside this view (TODO confirm). ``q_func_factory``
    is resolved through ``create_q_func_factory`` before being handed to
    the impl, and is also compared against ``"iqn"`` below.
    """
    # Every constructor argument is passed by keyword.
    impl_kwargs = dict(
        observation_shape=observation_shape,
        action_size=action_size,
        learning_rate=learning_rate,
        optim_factory=optim_factory,
        encoder_factory=encoder_factory,
        q_func_factory=create_q_func_factory(q_func_factory),
        gamma=gamma,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        action_flexibility=action_flexibility,
        beta=beta,
        use_gpu=None,
        scaler=scaler,
    )
    impl = DiscreteBCQImpl(**impl_kwargs)

    # Best-action determinism is only requested for non-"iqn" Q-functions;
    # presumably IQN's quantile sampling makes greedy actions vary between
    # calls (TODO confirm against torch_impl_tester).
    expect_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=True,
        deterministic_best_action=expect_deterministic,
    )
コード例 #2
0
ファイル: test_bcq_impl.py プロジェクト: navidmdn/d3rlpy
def test_discrete_bcq_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    target_reduction_type,
    action_flexibility,
    beta,
    scaler,
    augmentation,
):
    """Construct a DiscreteBCQImpl and exercise it with torch_impl_tester.

    The parameters are presumably injected by pytest parametrization or
    fixtures declared outside this view (TODO confirm). The impl is built
    on the CPU (``use_gpu=False``) with the given scaler and augmentation.
    """
    # These are passed positionally, so the tuple order must match the
    # DiscreteBCQImpl constructor's parameter order (not visible here).
    ctor_args = (
        observation_shape,
        action_size,
        learning_rate,
        optim_factory,
        encoder_factory,
        create_q_func_factory(q_func_factory),
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
        target_reduction_type,
        action_flexibility,
        beta,
    )
    impl = DiscreteBCQImpl(
        *ctor_args,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )

    # Greedy actions are only asserted deterministic when the Q-function
    # is not "iqn" (presumably because IQN samples quantiles — TODO confirm).
    deterministic = q_func_factory != "iqn"
    torch_impl_tester(impl,
                      discrete=True,
                      deterministic_best_action=deterministic)
コード例 #3
0
ファイル: test_bcq_impl.py プロジェクト: mchetouani/d3rlpy
def test_discrete_bcq_impl(observation_shape, action_size, learning_rate,
                           optim_factory, encoder_factory, gamma, n_critics,
                           bootstrap, share_encoder, action_flexibility, beta,
                           q_func_type, scaler, augmentation, n_augmentations):
    """Run the shared torch impl checks against a CPU DiscreteBCQImpl.

    Arguments presumably come from pytest parametrization or fixtures
    outside this view (TODO confirm). In this API version ``q_func_type``
    is passed to the constructor as-is, positionally after ``beta``.
    """
    # Positional order must mirror DiscreteBCQImpl's constructor signature
    # for this version of the library (not visible here).
    positional = (
        observation_shape,
        action_size,
        learning_rate,
        optim_factory,
        encoder_factory,
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
        action_flexibility,
        beta,
        q_func_type,
    )
    impl = DiscreteBCQImpl(*positional,
                           use_gpu=False,
                           scaler=scaler,
                           augmentation=augmentation,
                           n_augmentations=n_augmentations)

    # Deterministic best actions are expected for every Q-function type
    # except 'iqn' (presumably stochastic quantile sampling — TODO confirm).
    greedy_is_stable = q_func_type != 'iqn'
    torch_impl_tester(impl,
                      discrete=True,
                      deterministic_best_action=greedy_is_stable)