Esempio n. 1
0
def test_awr_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    scaler,
    action_scaler,
    augmentation,
):
    """Construct an ``AWRImpl`` from the test parameters and run the shared checks.

    Every constructor argument is passed by keyword. ``encoder_factory`` is
    used for both ``actor_encoder_factory`` and ``critic_encoder_factory``
    and ``use_gpu`` is fixed to ``None``. The object is then handed to
    ``torch_impl_tester`` with ``discrete=False`` and ``test_with_std=False``.
    """
    impl_kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        # The same encoder factory serves both networks.
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": action_scaler,
        "augmentation": augmentation,
    }
    awr = AWRImpl(**impl_kwargs)
    torch_impl_tester(awr, discrete=False, test_with_std=False)
Esempio n. 2
0
def test_ddpg_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    reguralizing_rate,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
):
    """Construct a ``DDPGImpl`` from the test parameters and run the shared checks.

    ``encoder_factory`` is passed in two consecutive positional slots and
    ``use_gpu`` is fixed to ``False``. ``torch_impl_tester`` is called with
    ``discrete=False``; ``deterministic_best_action`` is ``True`` unless
    ``q_func_type`` equals ``'iqn'``.
    """
    ddpg = DDPGImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        encoder_factory,
        encoder_factory,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        reguralizing_rate,
        q_func_type=q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
    )
    # The flag is False only for the 'iqn' Q-function type.
    is_deterministic = q_func_type != 'iqn'
    torch_impl_tester(
        ddpg,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 3
0
def test_discrete_cql_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    target_reduction_type,
    alpha,
    scaler,
    reward_scaler,
):
    """Construct a ``DiscreteCQLImpl`` by keyword and run the shared checks.

    ``q_func_factory`` is converted with ``create_q_func_factory`` before
    being passed on, and ``use_gpu`` is fixed to ``None``.
    ``torch_impl_tester`` is called with ``discrete=True``;
    ``deterministic_best_action`` is ``True`` unless ``q_func_factory``
    equals ``"iqn"``.
    """
    impl_kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "learning_rate": learning_rate,
        "optim_factory": optim_factory,
        "encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "alpha": alpha,
        "use_gpu": None,
        "scaler": scaler,
        "reward_scaler": reward_scaler,
    }
    cql = DiscreteCQLImpl(**impl_kwargs)
    # The flag is False only for the "iqn" Q-function.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        cql,
        discrete=True,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 4
0
def test_awac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    gamma,
    tau,
    lam,
    n_action_samples,
    max_weight,
    n_critics,
    bootstrap,
    share_encoder,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
):
    """Construct an ``AWACImpl`` from the test parameters and run the shared checks.

    ``encoder_factory`` is passed in two consecutive positional slots and
    ``use_gpu`` is fixed to ``False``. ``torch_impl_tester`` is called with
    ``discrete=False``; ``deterministic_best_action`` is ``True`` unless
    ``q_func_type`` equals ``'iqn'``.
    """
    awac = AWACImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        encoder_factory,
        encoder_factory,
        gamma,
        tau,
        lam,
        n_action_samples,
        max_weight,
        n_critics,
        bootstrap,
        share_encoder,
        q_func_type=q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
    )
    # The flag is False only for the 'iqn' Q-function type.
    is_deterministic = q_func_type != 'iqn'
    torch_impl_tester(
        awac,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 5
0
def test_discrete_awr_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    scaler,
    augmentation,
):
    """Construct a ``DiscreteAWRImpl`` and run the shared checks on it.

    The leading constructor arguments are positional, with
    ``encoder_factory`` occupying two consecutive slots; ``use_gpu`` is fixed
    to ``False``. ``torch_impl_tester`` is called with ``discrete=True`` and
    ``test_with_std=False``.
    """
    positional = (
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        encoder_factory,
        encoder_factory,
    )
    discrete_awr = DiscreteAWRImpl(
        *positional,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )
    torch_impl_tester(discrete_awr, discrete=True, test_with_std=False)
Esempio n. 6
0
def test_sac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    initial_temperature,
    eps,
    use_batch_norm,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
    encoder_params,
):
    """Construct a ``SACImpl`` from the test parameters and run the shared checks.

    ``q_func_type`` is the last positional argument and ``use_gpu`` is fixed
    to ``False``. ``torch_impl_tester`` is called with ``discrete=False``;
    ``deterministic_best_action`` is ``True`` unless ``q_func_type`` equals
    ``'iqn'``.
    """
    sac = SACImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        temp_learning_rate,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        initial_temperature,
        eps,
        use_batch_norm,
        q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
        encoder_params=encoder_params,
    )
    # The flag is False only for the 'iqn' Q-function type.
    is_deterministic = q_func_type != 'iqn'
    torch_impl_tester(
        sac,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 7
0
def test_double_dqn_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    target_reduction_type,
    scaler,
    augmentation,
):
    """Construct a ``DoubleDQNImpl`` by keyword and run the shared checks.

    ``q_func_factory`` is converted with ``create_q_func_factory`` before
    being passed on, and ``use_gpu`` is fixed to ``None``.
    ``torch_impl_tester`` is called with ``discrete=True``;
    ``deterministic_best_action`` is ``True`` unless ``q_func_factory``
    equals ``"iqn"``.
    """
    impl_kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "learning_rate": learning_rate,
        "optim_factory": optim_factory,
        "encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "use_gpu": None,
        "scaler": scaler,
        "augmentation": augmentation,
    }
    double_dqn = DoubleDQNImpl(**impl_kwargs)
    # The flag is False only for the "iqn" Q-function.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        double_dqn,
        discrete=True,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 8
0
def test_td3_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    gamma,
    tau,
    reguralizing_rate,
    n_critics,
    bootstrap,
    share_encoder,
    target_smoothing_sigma,
    target_smoothing_clip,
    eps,
    use_batch_norm,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
    encoder_params,
):
    """Construct a ``TD3Impl`` from the test parameters and run the shared checks.

    Arguments up to ``use_batch_norm`` are positional, the rest are keywords,
    and ``use_gpu`` is fixed to ``False``. ``torch_impl_tester`` is called
    with ``discrete=False``; ``deterministic_best_action`` is ``True`` unless
    ``q_func_type`` equals ``'iqn'``.
    """
    td3 = TD3Impl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        gamma,
        tau,
        reguralizing_rate,
        n_critics,
        bootstrap,
        share_encoder,
        target_smoothing_sigma,
        target_smoothing_clip,
        eps,
        use_batch_norm,
        q_func_type=q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
        encoder_params=encoder_params,
    )
    # The flag is False only for the 'iqn' Q-function type.
    is_deterministic = q_func_type != 'iqn'
    torch_impl_tester(
        td3,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 9
0
def test_discrete_bcq_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    target_reduction_type,
    action_flexibility,
    beta,
    scaler,
):
    """Construct a ``DiscreteBCQImpl`` by keyword and run the shared checks.

    ``q_func_factory`` is converted with ``create_q_func_factory`` before
    being passed on, and ``use_gpu`` is fixed to ``None``.
    ``torch_impl_tester`` is called with ``discrete=True``;
    ``deterministic_best_action`` is ``True`` unless ``q_func_factory``
    equals ``"iqn"``.
    """
    impl_kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "learning_rate": learning_rate,
        "optim_factory": optim_factory,
        "encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "action_flexibility": action_flexibility,
        "beta": beta,
        "use_gpu": None,
        "scaler": scaler,
    }
    bcq = DiscreteBCQImpl(**impl_kwargs)
    # The flag is False only for the "iqn" Q-function.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        bcq,
        discrete=True,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 10
0
def test_discrete_cql_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    scaler,
    augmentation,
):
    """Construct a ``DiscreteCQLImpl`` positionally and run the shared checks.

    ``q_func_factory`` is converted with ``create_q_func_factory`` before
    being passed on, and ``use_gpu`` is fixed to ``False``.
    ``torch_impl_tester`` is called with ``discrete=True``;
    ``deterministic_best_action`` is ``True`` unless ``q_func_factory``
    equals ``"iqn"``.
    """
    q_func = create_q_func_factory(q_func_factory)
    positional = (
        observation_shape,
        action_size,
        learning_rate,
        optim_factory,
        encoder_factory,
        q_func,
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
    )
    cql = DiscreteCQLImpl(
        *positional,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )
    # The flag is False only for the "iqn" Q-function.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        cql,
        discrete=True,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 11
0
def test_discrete_sac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    encoder_factory,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    initial_temperature,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
):
    """Construct a ``DiscreteSACImpl`` and run the shared checks on it.

    ``encoder_factory`` is passed in two consecutive positional slots,
    ``q_func_type`` is the last positional argument, and ``use_gpu`` is fixed
    to ``False``. ``torch_impl_tester`` is called with ``discrete=True``;
    ``deterministic_best_action`` is ``True`` unless ``q_func_type`` equals
    ``'iqn'``.
    """
    discrete_sac = DiscreteSACImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        temp_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        temp_optim_factory,
        encoder_factory,
        encoder_factory,
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
        initial_temperature,
        q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
    )
    # The flag is False only for the 'iqn' Q-function type.
    is_deterministic = q_func_type != 'iqn'
    torch_impl_tester(
        discrete_sac,
        discrete=True,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 12
0
def test_discrete_awr_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    scaler,
    reward_scaler,
):
    """Construct a ``DiscreteAWRImpl`` by keyword and run the shared checks.

    ``encoder_factory`` is used for both ``actor_encoder_factory`` and
    ``critic_encoder_factory``; ``use_gpu`` and ``action_scaler`` are fixed
    to ``None``. ``torch_impl_tester`` is called with ``discrete=True`` and
    ``test_with_std=False``.
    """
    impl_kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        # The same encoder factory serves both networks.
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": None,
        "reward_scaler": reward_scaler,
    }
    discrete_awr = DiscreteAWRImpl(**impl_kwargs)
    torch_impl_tester(discrete_awr, discrete=True, test_with_std=False)
Esempio n. 13
0
def test_cql_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    alpha_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    target_reduction_type,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    conservative_weight,
    n_action_samples,
    soft_q_backup,
    scaler,
    action_scaler,
    reward_scaler,
):
    """Construct a ``CQLImpl`` by keyword and run the shared checks on it.

    ``encoder_factory`` is used for both ``actor_encoder_factory`` and
    ``critic_encoder_factory``, ``q_func_factory`` is converted with
    ``create_q_func_factory``, and ``use_gpu`` is fixed to ``None``.
    ``torch_impl_tester`` is called with ``discrete=False``;
    ``deterministic_best_action`` is ``True`` unless ``q_func_factory``
    equals ``"iqn"``.
    """
    impl_kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "temp_learning_rate": temp_learning_rate,
        "alpha_learning_rate": alpha_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        "temp_optim_factory": temp_optim_factory,
        "alpha_optim_factory": alpha_optim_factory,
        # The same encoder factory serves both networks.
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "tau": tau,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "initial_temperature": initial_temperature,
        "initial_alpha": initial_alpha,
        "alpha_threshold": alpha_threshold,
        "conservative_weight": conservative_weight,
        "n_action_samples": n_action_samples,
        "soft_q_backup": soft_q_backup,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": action_scaler,
        "reward_scaler": reward_scaler,
    }
    cql = CQLImpl(**impl_kwargs)
    # The flag is False only for the "iqn" Q-function.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        cql,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 14
0
def test_cql_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    alpha_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    n_action_samples,
    soft_q_backup,
    scaler,
    action_scaler,
    augmentation,
):
    """Construct a ``CQLImpl`` positionally and run the shared checks on it.

    ``encoder_factory`` occupies two consecutive positional slots,
    ``q_func_factory`` is converted with ``create_q_func_factory``, and
    ``use_gpu`` is fixed to ``False``. ``torch_impl_tester`` is called with
    ``discrete=False``; ``deterministic_best_action`` is ``True`` unless
    ``q_func_factory`` equals ``"iqn"``.
    """
    q_func = create_q_func_factory(q_func_factory)
    positional = (
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        temp_learning_rate,
        alpha_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        temp_optim_factory,
        alpha_optim_factory,
        encoder_factory,
        encoder_factory,
        q_func,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        initial_temperature,
        initial_alpha,
        alpha_threshold,
        n_action_samples,
        soft_q_backup,
    )
    cql = CQLImpl(
        *positional,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    # The flag is False only for the "iqn" Q-function.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        cql,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 15
0
def test_bear_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    temp_optim_factory,
    alpha_optim_factory,
    encoder_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    lam,
    n_action_samples,
    mmd_sigma,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
):
    """Construct a ``BEARImpl``, check ``compute_target`` shapes, run shared checks.

    ``encoder_factory`` occupies three consecutive positional slots and
    ``use_gpu`` is fixed to ``False``. After ``build()``, ``compute_target``
    is called on a random batch of 32 observations. The result must have
    shape ``(32, 1)`` when ``q_func_type`` is ``'mean'``, otherwise
    ``(32, n_quantiles)`` with ``n_quantiles`` read from the first Q-function.
    Finally ``torch_impl_tester`` is called with ``discrete=False``;
    ``deterministic_best_action`` is ``True`` unless ``q_func_type`` equals
    ``'iqn'``.
    """
    bear = BEARImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        imitator_learning_rate,
        temp_learning_rate,
        alpha_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        imitator_optim_factory,
        temp_optim_factory,
        alpha_optim_factory,
        encoder_factory,
        encoder_factory,
        encoder_factory,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        initial_temperature,
        initial_alpha,
        alpha_threshold,
        lam,
        n_action_samples,
        mmd_sigma,
        q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
    )
    bear.build()

    batch_size = 32
    observations = torch.rand(batch_size, *observation_shape)
    target = bear.compute_target(observations)
    # Trailing dimension is 1 for 'mean', else the quantile count.
    if q_func_type == 'mean':
        expected_width = 1
    else:
        expected_width = bear.q_func.q_funcs[0].n_quantiles
    assert target.shape == (batch_size, expected_width)

    is_deterministic = q_func_type != 'iqn'
    torch_impl_tester(
        bear,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 16
0
def test_plas_with_perturbation_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    target_reduction_type,
    lam,
    beta,
    action_flexibility,
    scaler,
    action_scaler,
    augmentation,
):
    """Construct a ``PLASWithPerturbationImpl`` positionally and run shared checks.

    ``encoder_factory`` occupies three consecutive positional slots,
    ``q_func_factory`` is converted with ``create_q_func_factory``, and
    ``use_gpu`` is fixed to ``False``. ``torch_impl_tester`` is called with
    ``discrete=False``; ``deterministic_best_action`` is ``True`` unless
    ``q_func_factory`` equals ``"iqn"``.
    """
    q_func = create_q_func_factory(q_func_factory)
    positional = (
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        imitator_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        imitator_optim_factory,
        encoder_factory,
        encoder_factory,
        encoder_factory,
        q_func,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        target_reduction_type,
        lam,
        beta,
        action_flexibility,
    )
    plas = PLASWithPerturbationImpl(
        *positional,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    # The flag is False only for the "iqn" Q-function.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        plas,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 17
0
def test_bc_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    scaler,
    augmentation,
    n_augmentations,
):
    """Construct a ``BCImpl`` from the test parameters and run the shared checks.

    ``use_gpu`` is fixed to ``False``. ``torch_impl_tester`` is called with
    ``discrete=False`` and ``imitator=True``.
    """
    bc = BCImpl(
        observation_shape,
        action_size,
        learning_rate,
        optim_factory,
        encoder_factory,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
    )
    torch_impl_tester(bc, discrete=False, imitator=True)
Esempio n. 18
0
def test_plas_with_perturbation_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    target_reduction_type,
    lam,
    beta,
    action_flexibility,
    scaler,
    action_scaler,
    reward_scaler,
):
    """Construct a ``PLASWithPerturbationImpl`` by keyword and run shared checks.

    ``encoder_factory`` is used for the actor, critic and imitator encoder
    factories, ``q_func_factory`` is converted with
    ``create_q_func_factory``, and ``use_gpu`` is fixed to ``None``.
    ``torch_impl_tester`` is called with ``discrete=False``;
    ``deterministic_best_action`` is ``True`` unless ``q_func_factory``
    equals ``"iqn"``.
    """
    impl_kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "imitator_learning_rate": imitator_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        "imitator_optim_factory": imitator_optim_factory,
        # The same encoder factory serves all three networks.
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "imitator_encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "tau": tau,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "lam": lam,
        "beta": beta,
        "action_flexibility": action_flexibility,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": action_scaler,
        "reward_scaler": reward_scaler,
    }
    plas = PLASWithPerturbationImpl(**impl_kwargs)
    # The flag is False only for the "iqn" Q-function.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        plas,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 19
0
def test_bc_impl(
    observation_shape,
    action_size,
    learning_rate,
    eps,
    use_batch_norm,
    scaler,
    augmentation,
    n_augmentations,
    encoder_params,
):
    """Construct a ``BCImpl`` from the test parameters and run the shared checks.

    ``use_gpu`` is fixed to ``False``. ``torch_impl_tester`` is called with
    ``discrete=False`` and ``imitator=True``.
    """
    bc = BCImpl(
        observation_shape,
        action_size,
        learning_rate,
        eps,
        use_batch_norm,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
        encoder_params=encoder_params,
    )
    torch_impl_tester(bc, discrete=False, imitator=True)
Esempio n. 20
0
def test_sac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    target_reduction_type,
    initial_temperature,
    scaler,
    action_scaler,
    augmentation,
):
    """Construct a ``SACImpl`` positionally and run the shared checks on it.

    ``encoder_factory`` occupies two consecutive positional slots,
    ``q_func_factory`` is converted with ``create_q_func_factory``, and
    ``use_gpu`` is fixed to ``False``. ``torch_impl_tester`` is called with
    ``discrete=False``; ``deterministic_best_action`` is ``True`` unless
    ``q_func_factory`` equals ``"iqn"``.
    """
    q_func = create_q_func_factory(q_func_factory)
    positional = (
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        temp_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        temp_optim_factory,
        encoder_factory,
        encoder_factory,
        q_func,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        target_reduction_type,
        initial_temperature,
    )
    sac = SACImpl(
        *positional,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    # The flag is False only for the "iqn" Q-function.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        sac,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 21
0
def test_awr_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    momentum,
    use_batch_norm,
    scaler,
    augmentation,
    n_augmentations,
    encoder_params,
):
    """Construct an ``AWRImpl`` from the test parameters and run the shared checks.

    ``use_gpu`` is fixed to ``False``. ``torch_impl_tester`` is called with
    ``discrete=False`` and ``test_with_std=False``.
    """
    awr = AWRImpl(
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        momentum,
        use_batch_norm,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
        encoder_params=encoder_params,
    )
    torch_impl_tester(awr, discrete=False, test_with_std=False)
Esempio n. 22
0
def test_plas_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    lam,
    beta,
    scaler,
    augmentation,
):
    """Construct a ``PLASImpl`` positionally and run the shared checks on it.

    ``encoder_factory`` occupies three consecutive positional slots,
    ``q_func_factory`` is converted with ``create_q_func_factory``, and
    ``use_gpu`` is fixed to ``False``. ``torch_impl_tester`` is called with
    ``discrete=False``; ``deterministic_best_action`` is ``True`` unless
    ``q_func_factory`` equals ``"iqn"``.
    """
    q_func = create_q_func_factory(q_func_factory)
    positional = (
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        imitator_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        imitator_optim_factory,
        encoder_factory,
        encoder_factory,
        encoder_factory,
        q_func,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        lam,
        beta,
    )
    plas = PLASImpl(
        *positional,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )
    # The flag is False only for the "iqn" Q-function.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        plas,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 23
0
def test_crr_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    beta,
    n_action_samples,
    advantage_type,
    weight_type,
    max_weight,
    n_critics,
    target_reduction_type,
    scaler,
    action_scaler,
    augmentation,
):
    """Construct a ``CRRImpl`` by keyword and run the shared checks on it.

    ``encoder_factory`` is used for both ``actor_encoder_factory`` and
    ``critic_encoder_factory``, ``q_func_factory`` is converted with
    ``create_q_func_factory``, and ``use_gpu`` is fixed to ``None``.
    ``torch_impl_tester`` is called with ``discrete=False`` and
    ``deterministic_best_action=False``.
    """
    impl_kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        # The same encoder factory serves both networks.
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "beta": beta,
        "n_action_samples": n_action_samples,
        "advantage_type": advantage_type,
        "weight_type": weight_type,
        "max_weight": max_weight,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": action_scaler,
        "augmentation": augmentation,
    }
    crr = CRRImpl(**impl_kwargs)
    torch_impl_tester(
        crr,
        discrete=False,
        deterministic_best_action=False,
    )
Esempio n. 24
0
def test_td3_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    target_smoothing_sigma,
    target_smoothing_clip,
    scaler,
    action_scaler,
    augmentation,
):
    """Construct a ``TD3Impl`` positionally and run the shared checks on it.

    ``encoder_factory`` occupies two consecutive positional slots,
    ``q_func_factory`` is converted with ``create_q_func_factory``, and
    ``use_gpu`` is fixed to ``False``. ``torch_impl_tester`` is called with
    ``discrete=False``; ``deterministic_best_action`` is ``True`` unless
    ``q_func_factory`` equals ``"iqn"``.
    """
    q_func = create_q_func_factory(q_func_factory)
    positional = (
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        encoder_factory,
        encoder_factory,
        q_func,
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        target_smoothing_sigma,
        target_smoothing_clip,
    )
    td3 = TD3Impl(
        *positional,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    # The flag is False only for the "iqn" Q-function.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        td3,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 25
0
def test_td3_plus_bc_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    target_reduction_type,
    target_smoothing_sigma,
    target_smoothing_clip,
    alpha,
    scaler,
    action_scaler,
    reward_scaler,
):
    """Construct a ``TD3PlusBCImpl`` by keyword and run the shared checks on it.

    ``encoder_factory`` is used for both ``actor_encoder_factory`` and
    ``critic_encoder_factory``, ``q_func_factory`` is converted with
    ``create_q_func_factory``, and ``use_gpu`` is fixed to ``None``.
    ``torch_impl_tester`` is called with ``discrete=False``;
    ``deterministic_best_action`` is ``True`` unless ``q_func_factory``
    equals ``"iqn"``.
    """
    impl_kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        # The same encoder factory serves both networks.
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "tau": tau,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "target_smoothing_sigma": target_smoothing_sigma,
        "target_smoothing_clip": target_smoothing_clip,
        "alpha": alpha,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": action_scaler,
        "reward_scaler": reward_scaler,
    }
    td3_plus_bc = TD3PlusBCImpl(**impl_kwargs)
    # The flag is False only for the "iqn" Q-function.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        td3_plus_bc,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 26
0
def test_sac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    target_reduction_type,
    initial_temperature,
    scaler,
    action_scaler,
    reward_scaler,
):
    """Construct a ``SACImpl`` by keyword and run the shared checks on it.

    ``encoder_factory`` is used for both ``actor_encoder_factory`` and
    ``critic_encoder_factory``, ``q_func_factory`` is converted with
    ``create_q_func_factory``, and ``use_gpu`` is fixed to ``None``.
    ``torch_impl_tester`` is called with ``discrete=False``;
    ``deterministic_best_action`` is ``True`` unless ``q_func_factory``
    equals ``"iqn"``.
    """
    impl_kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "temp_learning_rate": temp_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        "temp_optim_factory": temp_optim_factory,
        # The same encoder factory serves both networks.
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "tau": tau,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "initial_temperature": initial_temperature,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": action_scaler,
        "reward_scaler": reward_scaler,
    }
    sac = SACImpl(**impl_kwargs)
    # The flag is False only for the "iqn" Q-function.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        sac,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 27
0
def test_awac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    lam,
    n_action_samples,
    max_weight,
    n_critics,
    target_reduction_type,
    scaler,
    action_scaler,
    augmentation,
):
    """Construct an ``AWACImpl`` by keyword and run the shared checks on it.

    ``encoder_factory`` is used for both ``actor_encoder_factory`` and
    ``critic_encoder_factory``, ``q_func_factory`` is converted with
    ``create_q_func_factory``, and ``use_gpu`` is fixed to ``None``.
    ``torch_impl_tester`` is called with ``discrete=False``;
    ``deterministic_best_action`` is ``True`` unless ``q_func_factory``
    equals ``"iqn"``.
    """
    impl_kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        # The same encoder factory serves both networks.
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "tau": tau,
        "lam": lam,
        "n_action_samples": n_action_samples,
        "max_weight": max_weight,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": action_scaler,
        "augmentation": augmentation,
    }
    awac = AWACImpl(**impl_kwargs)
    # The flag is False only for the "iqn" Q-function.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        awac,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 28
0
def test_dqn_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
):
    """Construct a ``DQNImpl`` from the test parameters and run the shared checks.

    ``q_func_type`` is the last positional argument and ``use_gpu`` is fixed
    to ``False``. ``torch_impl_tester`` is called with ``discrete=True``;
    ``deterministic_best_action`` is ``True`` unless ``q_func_type`` equals
    ``'iqn'``.
    """
    dqn = DQNImpl(
        observation_shape,
        action_size,
        learning_rate,
        optim_factory,
        encoder_factory,
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
        q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
    )
    # The flag is False only for the 'iqn' Q-function type.
    is_deterministic = q_func_type != 'iqn'
    torch_impl_tester(
        dqn,
        discrete=True,
        deterministic_best_action=is_deterministic,
    )
Esempio n. 29
0
def test_discrete_bc_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    beta,
    scaler,
):
    """Construct a ``DiscreteBCImpl`` by keyword and run the shared checks on it.

    ``use_gpu`` is fixed to ``None``. ``torch_impl_tester`` is called with
    ``discrete=True`` and ``imitator=True``.
    """
    impl_kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "learning_rate": learning_rate,
        "optim_factory": optim_factory,
        "encoder_factory": encoder_factory,
        "beta": beta,
        "use_gpu": None,
        "scaler": scaler,
    }
    discrete_bc = DiscreteBCImpl(**impl_kwargs)
    torch_impl_tester(discrete_bc, discrete=True, imitator=True)
Esempio n. 30
0
def test_double_dqn_impl(
    observation_shape,
    action_size,
    learning_rate,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    eps,
    use_batch_norm,
    q_func_type,
    scaler,
    augmentation,
    n_augmentations,
    encoder_params,
):
    """Construct a ``DoubleDQNImpl`` from the test parameters and run shared checks.

    Arguments up to ``use_batch_norm`` are positional, the rest are keywords,
    and ``use_gpu`` is fixed to ``False``. ``torch_impl_tester`` is called
    with ``discrete=True``; ``deterministic_best_action`` is ``True`` unless
    ``q_func_type`` equals ``'iqn'``.
    """
    double_dqn = DoubleDQNImpl(
        observation_shape,
        action_size,
        learning_rate,
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
        eps,
        use_batch_norm,
        q_func_type=q_func_type,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
        n_augmentations=n_augmentations,
        encoder_params=encoder_params,
    )
    # The flag is False only for the 'iqn' Q-function type.
    is_deterministic = q_func_type != 'iqn'
    torch_impl_tester(
        double_dqn,
        discrete=True,
        deterministic_best_action=is_deterministic,
    )