Ejemplo n.º 1
0
def test_awac_impl(observation_shape, action_size, actor_learning_rate,
                   critic_learning_rate, actor_optim_factory,
                   critic_optim_factory, encoder_factory, gamma, tau, lam,
                   n_action_samples, max_weight, n_critics, bootstrap,
                   share_encoder, q_func_type, scaler, augmentation,
                   n_augmentations):
    """Build an ``AWACImpl`` from the test parameters and run the shared tester.

    The implementation is always created with ``use_gpu=False`` and handed to
    ``torch_impl_tester`` with ``discrete=False``. The tester is told to expect
    a deterministic best action unless ``q_func_type`` equals ``'iqn'``.
    """
    # Positional arguments, in the exact order the constructor receives them.
    # ``encoder_factory`` is intentionally supplied twice (presumably once for
    # the actor and once for the critic -- confirm against AWACImpl).
    ordered_args = (
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        encoder_factory,
        encoder_factory,
        gamma,
        tau,
        lam,
        n_action_samples,
        max_weight,
        n_critics,
        bootstrap,
        share_encoder,
    )
    named_args = dict(
        q_func_type=q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
    )
    awac = AWACImpl(*ordered_args, **named_args)

    expect_deterministic = q_func_type != 'iqn'
    torch_impl_tester(
        awac,
        discrete=False,
        deterministic_best_action=expect_deterministic,
    )
Ejemplo n.º 2
0
def test_awac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    lam,
    n_action_samples,
    max_weight,
    n_critics,
    target_reduction_type,
    scaler,
    action_scaler,
    augmentation,
):
    """Build an ``AWACImpl`` from the test parameters and run the shared tester.

    ``q_func_factory`` is converted with ``create_q_func_factory`` before being
    given to the implementation, and the same ``encoder_factory`` is used for
    both the actor and the critic encoders. The implementation is created with
    ``use_gpu=None`` and checked by ``torch_impl_tester`` with
    ``discrete=False``; a deterministic best action is expected unless
    ``q_func_factory`` equals ``"iqn"``.
    """
    built_q_func_factory = create_q_func_factory(q_func_factory)

    impl_kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        # One encoder factory is shared by the actor and the critic.
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "q_func_factory": built_q_func_factory,
        "gamma": gamma,
        "tau": tau,
        "lam": lam,
        "n_action_samples": n_action_samples,
        "max_weight": max_weight,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": action_scaler,
        "augmentation": augmentation,
    }
    awac = AWACImpl(**impl_kwargs)

    expect_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        awac,
        discrete=False,
        deterministic_best_action=expect_deterministic,
    )