Exemple #1
0
def test_cql_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    n_action_samples,
    eps,
    use_batch_norm,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
    encoder_params,
):
    """Construct a ``CQLImpl`` and run it through ``torch_impl_tester``.

    Every argument is forwarded unchanged to ``CQLImpl``:
    ``observation_shape`` through ``q_func_type`` positionally, in signature
    order, and ``scaler``, ``augmentation``, ``n_augmentations`` and
    ``encoder_params`` by keyword. ``use_gpu`` is fixed to ``False``.

    The instance is then checked with ``discrete=False``;
    ``deterministic_best_action`` is true unless ``q_func_type`` equals
    ``'iqn'``.
    """
    # Order matters here: these are consumed positionally by CQLImpl.
    positional_args = (
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        temp_learning_rate,
        alpha_learning_rate,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        initial_temperature,
        initial_alpha,
        alpha_threshold,
        n_action_samples,
        eps,
        use_batch_norm,
        q_func_type,
    )
    keyword_args = dict(
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
        encoder_params=encoder_params,
    )
    impl = CQLImpl(*positional_args, **keyword_args)

    expect_deterministic = q_func_type != 'iqn'
    torch_impl_tester(
        impl,
        discrete=False,
        deterministic_best_action=expect_deterministic,
    )
Exemple #2
0
def test_cql_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    alpha_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    n_action_samples,
    soft_q_backup,
    scaler,
    action_scaler,
    augmentation,
):
    """Construct a ``CQLImpl`` and run it through ``torch_impl_tester``.

    Arguments up to ``soft_q_backup`` are forwarded positionally in
    signature order, with two adjustments: ``encoder_factory`` fills two
    consecutive positional slots, and ``q_func_factory`` is first passed
    through ``create_q_func_factory``. ``scaler``, ``action_scaler`` and
    ``augmentation`` are forwarded by keyword and ``use_gpu`` is fixed to
    ``False``.

    The instance is then checked with ``discrete=False``;
    ``deterministic_best_action`` is true unless the raw ``q_func_factory``
    argument equals ``"iqn"``.
    """
    built_q_func_factory = create_q_func_factory(q_func_factory)

    # Order matters here: these are consumed positionally by CQLImpl.
    positional_args = (
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        temp_learning_rate,
        alpha_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        temp_optim_factory,
        alpha_optim_factory,
        # The same factory is supplied twice (presumably for the actor and
        # the critic encoders -- confirm against the CQLImpl signature).
        encoder_factory,
        encoder_factory,
        built_q_func_factory,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        initial_temperature,
        initial_alpha,
        alpha_threshold,
        n_action_samples,
        soft_q_backup,
    )
    keyword_args = dict(
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    impl = CQLImpl(*positional_args, **keyword_args)

    expect_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=False,
        deterministic_best_action=expect_deterministic,
    )
Exemple #3
0
def test_cql_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    alpha_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    target_reduction_type,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    conservative_weight,
    n_action_samples,
    soft_q_backup,
    scaler,
    action_scaler,
    reward_scaler,
):
    """Construct a ``CQLImpl`` and run it through ``torch_impl_tester``.

    Every argument is forwarded to ``CQLImpl`` by keyword. Two are adapted
    on the way: ``encoder_factory`` is used for both
    ``actor_encoder_factory`` and ``critic_encoder_factory``, and
    ``q_func_factory`` is passed through ``create_q_func_factory``.
    ``use_gpu`` is fixed to ``None``.

    The instance is then checked with ``discrete=False``;
    ``deterministic_best_action`` is true unless the raw ``q_func_factory``
    argument equals ``"iqn"``.
    """
    impl_kwargs = dict(
        observation_shape=observation_shape,
        action_size=action_size,
        actor_learning_rate=actor_learning_rate,
        critic_learning_rate=critic_learning_rate,
        temp_learning_rate=temp_learning_rate,
        alpha_learning_rate=alpha_learning_rate,
        actor_optim_factory=actor_optim_factory,
        critic_optim_factory=critic_optim_factory,
        temp_optim_factory=temp_optim_factory,
        alpha_optim_factory=alpha_optim_factory,
        # One encoder factory serves both the actor and the critic.
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        q_func_factory=create_q_func_factory(q_func_factory),
        gamma=gamma,
        tau=tau,
        n_critics=n_critics,
        target_reduction_type=target_reduction_type,
        initial_temperature=initial_temperature,
        initial_alpha=initial_alpha,
        alpha_threshold=alpha_threshold,
        conservative_weight=conservative_weight,
        n_action_samples=n_action_samples,
        soft_q_backup=soft_q_backup,
        use_gpu=None,
        scaler=scaler,
        action_scaler=action_scaler,
        reward_scaler=reward_scaler,
    )
    impl = CQLImpl(**impl_kwargs)

    expect_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=False,
        deterministic_best_action=expect_deterministic,
    )