コード例 #1
0
def test_discrete_cql_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    target_reduction_type,
    alpha,
    scaler,
    reward_scaler,
):
    """Build a DiscreteCQLImpl from the fixtures and run the shared torch tester."""
    kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "learning_rate": learning_rate,
        "optim_factory": optim_factory,
        "encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "alpha": alpha,
        "use_gpu": None,
        "scaler": scaler,
        "reward_scaler": reward_scaler,
    }
    impl = DiscreteCQLImpl(**kwargs)
    # IQN-style Q-functions sample quantiles, so greedy actions are stochastic.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=True,
        deterministic_best_action=is_deterministic,
    )
コード例 #2
0
def test_discrete_fqe_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    scaler,
):
    """Build a DiscreteFQEImpl from the fixtures and run the shared torch tester."""
    kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "learning_rate": learning_rate,
        "optim_factory": optim_factory,
        "encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "n_critics": n_critics,
        "use_gpu": None,
        "scaler": scaler,
        # discrete FQE takes no action scaling
        "action_scaler": None,
    }
    fqe = DiscreteFQEImpl(**kwargs)
    torch_impl_tester(fqe, True)
コード例 #3
0
ファイル: test_fqe_impl.py プロジェクト: navidmdn/d3rlpy
def test_discrete_fqe_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    scaler,
):
    """Build a DiscreteFQEImpl (positional-arg API) and run the shared tester."""
    # Positional constructor arguments, in declaration order.
    ctor_args = [
        observation_shape,
        action_size,
        learning_rate,
        optim_factory,
        encoder_factory,
        create_q_func_factory(q_func_factory),
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
    ]
    fqe = DiscreteFQEImpl(
        *ctor_args,
        use_gpu=False,
        scaler=scaler,
        # discrete FQE takes no action scaling
        action_scaler=None,
    )
    torch_impl_tester(fqe, True)
コード例 #4
0
def test_double_dqn_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    target_reduction_type,
    scaler,
    augmentation,
):
    """Build a DoubleDQNImpl from the fixtures and run the shared torch tester."""
    kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "learning_rate": learning_rate,
        "optim_factory": optim_factory,
        "encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "use_gpu": None,
        "scaler": scaler,
        "augmentation": augmentation,
    }
    impl = DoubleDQNImpl(**kwargs)
    # IQN-style Q-functions sample quantiles, so greedy actions are stochastic.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=True,
        deterministic_best_action=is_deterministic,
    )
コード例 #5
0
ファイル: test_bcq_impl.py プロジェクト: wx-b/d3rlpy
def test_discrete_bcq_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    target_reduction_type,
    action_flexibility,
    beta,
    scaler,
):
    """Build a DiscreteBCQImpl from the fixtures and run the shared torch tester."""
    kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "learning_rate": learning_rate,
        "optim_factory": optim_factory,
        "encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "action_flexibility": action_flexibility,
        "beta": beta,
        "use_gpu": None,
        "scaler": scaler,
    }
    impl = DiscreteBCQImpl(**kwargs)
    # IQN-style Q-functions sample quantiles, so greedy actions are stochastic.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=True,
        deterministic_best_action=is_deterministic,
    )
コード例 #6
0
def test_discrete_cql_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    scaler,
    augmentation,
):
    """Build a DiscreteCQLImpl (positional-arg API) and run the shared tester."""
    # Positional constructor arguments, in declaration order.
    ctor_args = [
        observation_shape,
        action_size,
        learning_rate,
        optim_factory,
        encoder_factory,
        create_q_func_factory(q_func_factory),
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
    ]
    impl = DiscreteCQLImpl(
        *ctor_args,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )
    # IQN-style Q-functions sample quantiles, so greedy actions are stochastic.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=True,
        deterministic_best_action=is_deterministic,
    )
コード例 #7
0
def test_cql_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    alpha_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    target_reduction_type,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    conservative_weight,
    n_action_samples,
    soft_q_backup,
    scaler,
    action_scaler,
    reward_scaler,
):
    """Build a continuous-action CQLImpl from the fixtures and run the shared tester."""
    kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "temp_learning_rate": temp_learning_rate,
        "alpha_learning_rate": alpha_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        "temp_optim_factory": temp_optim_factory,
        "alpha_optim_factory": alpha_optim_factory,
        # actor and critic share the same encoder fixture
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "tau": tau,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "initial_temperature": initial_temperature,
        "initial_alpha": initial_alpha,
        "alpha_threshold": alpha_threshold,
        "conservative_weight": conservative_weight,
        "n_action_samples": n_action_samples,
        "soft_q_backup": soft_q_backup,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": action_scaler,
        "reward_scaler": reward_scaler,
    }
    impl = CQLImpl(**kwargs)
    # IQN-style Q-functions sample quantiles, so greedy actions are stochastic.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
コード例 #8
0
def test_cql_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    alpha_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    n_action_samples,
    soft_q_backup,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a continuous-action CQLImpl (positional-arg API) and run the shared tester."""
    # Positional constructor arguments, in declaration order; the encoder
    # fixture is passed twice (actor encoder, then critic encoder).
    ctor_args = [
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        temp_learning_rate,
        alpha_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        temp_optim_factory,
        alpha_optim_factory,
        encoder_factory,
        encoder_factory,
        create_q_func_factory(q_func_factory),
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        initial_temperature,
        initial_alpha,
        alpha_threshold,
        n_action_samples,
        soft_q_backup,
    ]
    impl = CQLImpl(
        *ctor_args,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    # IQN-style Q-functions sample quantiles, so greedy actions are stochastic.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
コード例 #9
0
ファイル: test_q_functions.py プロジェクト: vmbbc/d3rlpy
def test_create_q_func_factory(name):
    """create_q_func_factory should return the factory type matching *name*."""
    expected_types = {
        "mean": MeanQFunctionFactory,
        "qr": QRQFunctionFactory,
        "iqn": IQNQFunctionFactory,
        "fqf": FQFQFunctionFactory,
    }
    factory = create_q_func_factory(name)
    expected = expected_types.get(name)
    if expected is not None:
        assert isinstance(factory, expected)
コード例 #10
0
ファイル: test_plas_impl.py プロジェクト: navidmdn/d3rlpy
def test_plas_with_perturbation_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    target_reduction_type,
    lam,
    beta,
    action_flexibility,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a PLASWithPerturbationImpl (positional-arg API) and run the shared tester."""
    # Positional constructor arguments, in declaration order; the encoder
    # fixture is reused for actor, critic, and imitator encoders.
    ctor_args = [
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        imitator_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        imitator_optim_factory,
        encoder_factory,
        encoder_factory,
        encoder_factory,
        create_q_func_factory(q_func_factory),
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        target_reduction_type,
        lam,
        beta,
        action_flexibility,
    ]
    impl = PLASWithPerturbationImpl(
        *ctor_args,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    # IQN-style Q-functions sample quantiles, so greedy actions are stochastic.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
コード例 #11
0
def test_plas_with_perturbation_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    target_reduction_type,
    lam,
    beta,
    action_flexibility,
    scaler,
    action_scaler,
    reward_scaler,
):
    """Build a PLASWithPerturbationImpl from the fixtures and run the shared tester."""
    kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "imitator_learning_rate": imitator_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        "imitator_optim_factory": imitator_optim_factory,
        # one encoder fixture serves actor, critic, and imitator
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "imitator_encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "tau": tau,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "lam": lam,
        "beta": beta,
        "action_flexibility": action_flexibility,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": action_scaler,
        "reward_scaler": reward_scaler,
    }
    impl = PLASWithPerturbationImpl(**kwargs)
    # IQN-style Q-functions sample quantiles, so greedy actions are stochastic.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
コード例 #12
0
def test_awac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    lam,
    n_action_samples,
    max_weight,
    n_critics,
    bootstrap,
    share_encoder,
    target_reduction_type,
    scaler,
    action_scaler,
    augmentation,
):
    """Build an AWACImpl (positional-arg API) and run the shared tester."""
    # Positional constructor arguments, in declaration order; the encoder
    # fixture is passed twice (actor encoder, then critic encoder).
    ctor_args = [
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        encoder_factory,
        encoder_factory,
        create_q_func_factory(q_func_factory),
        gamma,
        tau,
        lam,
        n_action_samples,
        max_weight,
        n_critics,
        bootstrap,
        share_encoder,
        target_reduction_type,
    ]
    impl = AWACImpl(
        *ctor_args,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    # IQN-style Q-functions sample quantiles, so greedy actions are stochastic.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
コード例 #13
0
ファイル: test_td3_impl.py プロジェクト: navidmdn/d3rlpy
def test_td3_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    target_reduction_type,
    target_smoothing_sigma,
    target_smoothing_clip,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a TD3Impl (positional-arg API) and run the shared tester."""
    # Positional constructor arguments, in declaration order; the encoder
    # fixture is passed twice (actor encoder, then critic encoder).
    ctor_args = [
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        encoder_factory,
        encoder_factory,
        create_q_func_factory(q_func_factory),
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        target_reduction_type,
        target_smoothing_sigma,
        target_smoothing_clip,
    ]
    impl = TD3Impl(
        *ctor_args,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    # IQN-style Q-functions sample quantiles, so greedy actions are stochastic.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
コード例 #14
0
def test_sac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    initial_temperature,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a SACImpl (positional-arg API) and run the shared tester."""
    # Positional constructor arguments, in declaration order; the encoder
    # fixture is passed twice (actor encoder, then critic encoder).
    ctor_args = [
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        temp_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        temp_optim_factory,
        encoder_factory,
        encoder_factory,
        create_q_func_factory(q_func_factory),
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        initial_temperature,
    ]
    impl = SACImpl(
        *ctor_args,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    # IQN-style Q-functions sample quantiles, so greedy actions are stochastic.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
コード例 #15
0
def test_crr_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    beta,
    n_action_samples,
    advantage_type,
    weight_type,
    max_weight,
    n_critics,
    target_reduction_type,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a CRRImpl from the fixtures and run the shared torch tester."""
    kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        # actor and critic share the same encoder fixture
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "beta": beta,
        "n_action_samples": n_action_samples,
        "advantage_type": advantage_type,
        "weight_type": weight_type,
        "max_weight": max_weight,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": action_scaler,
        "augmentation": augmentation,
    }
    impl = CRRImpl(**kwargs)
    # CRR samples actions, so the best action is never deterministic here.
    torch_impl_tester(impl, discrete=False, deterministic_best_action=False)
コード例 #16
0
def test_td3_plus_bc_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    target_reduction_type,
    target_smoothing_sigma,
    target_smoothing_clip,
    alpha,
    scaler,
    action_scaler,
    reward_scaler,
):
    """Build a TD3PlusBCImpl from the fixtures and run the shared torch tester."""
    kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        # actor and critic share the same encoder fixture
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "tau": tau,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "target_smoothing_sigma": target_smoothing_sigma,
        "target_smoothing_clip": target_smoothing_clip,
        "alpha": alpha,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": action_scaler,
        "reward_scaler": reward_scaler,
    }
    impl = TD3PlusBCImpl(**kwargs)
    # IQN-style Q-functions sample quantiles, so greedy actions are stochastic.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
コード例 #17
0
def test_awac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    lam,
    n_action_samples,
    max_weight,
    n_critics,
    target_reduction_type,
    scaler,
    action_scaler,
    reward_scaler,
):
    """Build an AWACImpl from the fixtures and run the shared torch tester."""
    kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        # actor and critic share the same encoder fixture
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "tau": tau,
        "lam": lam,
        "n_action_samples": n_action_samples,
        "max_weight": max_weight,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": action_scaler,
        "reward_scaler": reward_scaler,
    }
    impl = AWACImpl(**kwargs)
    # IQN-style Q-functions sample quantiles, so greedy actions are stochastic.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
コード例 #18
0
ファイル: test_sac_impl.py プロジェクト: jkbjh/d3rlpy
def test_sac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    target_reduction_type,
    initial_temperature,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a SACImpl from the fixtures and run the shared torch tester."""
    kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "temp_learning_rate": temp_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        "temp_optim_factory": temp_optim_factory,
        # actor and critic share the same encoder fixture
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "tau": tau,
        "n_critics": n_critics,
        "target_reduction_type": target_reduction_type,
        "initial_temperature": initial_temperature,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": action_scaler,
        "augmentation": augmentation,
    }
    impl = SACImpl(**kwargs)
    # IQN-style Q-functions sample quantiles, so greedy actions are stochastic.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
コード例 #19
0
def test_discrete_sac_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    temp_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    temp_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    initial_temperature,
    scaler,
    reward_scaler,
):
    """Build a DiscreteSACImpl from the fixtures and run the shared torch tester."""
    kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "temp_learning_rate": temp_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        "temp_optim_factory": temp_optim_factory,
        # actor and critic share the same encoder fixture
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "n_critics": n_critics,
        "initial_temperature": initial_temperature,
        "use_gpu": None,
        "scaler": scaler,
        "reward_scaler": reward_scaler,
    }
    impl = DiscreteSACImpl(**kwargs)
    # IQN-style Q-functions sample quantiles, so greedy actions are stochastic.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=True,
        deterministic_best_action=is_deterministic,
    )
コード例 #20
0
ファイル: test_bcq_impl.py プロジェクト: navidmdn/d3rlpy
def test_discrete_bcq_impl(
    observation_shape,
    action_size,
    learning_rate,
    optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    n_critics,
    bootstrap,
    share_encoder,
    target_reduction_type,
    action_flexibility,
    beta,
    scaler,
    augmentation,
):
    """Build a DiscreteBCQImpl (positional-arg API) and run the shared tester."""
    # Positional constructor arguments, in declaration order.
    ctor_args = [
        observation_shape,
        action_size,
        learning_rate,
        optim_factory,
        encoder_factory,
        create_q_func_factory(q_func_factory),
        gamma,
        n_critics,
        bootstrap,
        share_encoder,
        target_reduction_type,
        action_flexibility,
        beta,
    ]
    impl = DiscreteBCQImpl(
        *ctor_args,
        use_gpu=False,
        scaler=scaler,
        augmentation=augmentation,
    )
    # IQN-style Q-functions sample quantiles, so greedy actions are stochastic.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=True,
        deterministic_best_action=is_deterministic,
    )
コード例 #21
0
ファイル: test_bcq_impl.py プロジェクト: wx-b/d3rlpy
def test_bcq_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    lam,
    n_action_samples,
    action_flexibility,
    latent_size,
    beta,
    scaler,
    action_scaler,
):
    """Build a BCQImpl, check its internal helper shapes, and run the shared tester."""
    kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "imitator_learning_rate": imitator_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        "imitator_optim_factory": imitator_optim_factory,
        # one encoder fixture serves actor, critic, and imitator
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "imitator_encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "tau": tau,
        "n_critics": n_critics,
        "lam": lam,
        "n_action_samples": n_action_samples,
        "action_flexibility": action_flexibility,
        "latent_size": latent_size,
        "beta": beta,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": action_scaler,
    }
    impl = BCQImpl(**kwargs)
    impl.build()

    # exercise the internal helpers on a random batch of 32 observations
    batch_size = 32
    obs = torch.rand(batch_size, *observation_shape)

    repeated_obs = impl._repeat_observation(obs)
    assert repeated_obs.shape == (batch_size, n_action_samples) + observation_shape

    sampled_actions = impl._sample_repeated_action(repeated_obs)
    assert sampled_actions.shape == (batch_size, n_action_samples, action_size)

    values = impl._predict_value(repeated_obs, sampled_actions)
    assert values.shape == (n_critics, batch_size * n_action_samples, 1)

    greedy_action = impl._predict_best_action(obs)
    assert greedy_action.shape == (batch_size, action_size)

    # BCQ's action selection samples from the VAE, so it is never deterministic.
    torch_impl_tester(impl, discrete=False, deterministic_best_action=False)
コード例 #22
0
ファイル: test_bear_impl.py プロジェクト: YangRui2015/d3rlpy
def test_bear_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    temp_optim_factory,
    alpha_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    lam,
    n_action_samples,
    mmd_kernel,
    mmd_sigma,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a BEARImpl, check the target shape, and run the shared tester."""
    kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "imitator_learning_rate": imitator_learning_rate,
        "temp_learning_rate": temp_learning_rate,
        "alpha_learning_rate": alpha_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        "imitator_optim_factory": imitator_optim_factory,
        "temp_optim_factory": temp_optim_factory,
        "alpha_optim_factory": alpha_optim_factory,
        # one encoder fixture serves actor, critic, and imitator
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "imitator_encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "tau": tau,
        "n_critics": n_critics,
        "initial_temperature": initial_temperature,
        "initial_alpha": initial_alpha,
        "alpha_threshold": alpha_threshold,
        "lam": lam,
        "n_action_samples": n_action_samples,
        "mmd_kernel": mmd_kernel,
        "mmd_sigma": mmd_sigma,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": action_scaler,
        "augmentation": augmentation,
    }
    impl = BEARImpl(**kwargs)
    impl.build()

    # target shape depends on the Q-function: mean yields a scalar per sample,
    # distributional variants yield one value per quantile
    batch_size = 32
    obs = torch.rand(batch_size, *observation_shape)
    target = impl.compute_target(obs)
    if q_func_factory == "mean":
        assert target.shape == (batch_size, 1)
    else:
        n_quantiles = impl._q_func.q_funcs[0]._n_quantiles
        assert target.shape == (batch_size, n_quantiles)

    # IQN-style Q-functions sample quantiles, so greedy actions are stochastic.
    is_deterministic = q_func_factory != "iqn"
    torch_impl_tester(
        impl,
        discrete=False,
        deterministic_best_action=is_deterministic,
    )
コード例 #23
0
def test_bear_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    temp_learning_rate,
    alpha_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    temp_optim_factory,
    alpha_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    initial_temperature,
    initial_alpha,
    alpha_threshold,
    lam,
    n_action_samples,
    n_target_samples,
    n_mmd_action_samples,
    mmd_kernel,
    mmd_sigma,
    vae_kl_weight,
    scaler,
    action_scaler,
):
    """Build a BEARImpl from the fixtures and run the shared torch tester."""
    kwargs = {
        "observation_shape": observation_shape,
        "action_size": action_size,
        "actor_learning_rate": actor_learning_rate,
        "critic_learning_rate": critic_learning_rate,
        "imitator_learning_rate": imitator_learning_rate,
        "temp_learning_rate": temp_learning_rate,
        "alpha_learning_rate": alpha_learning_rate,
        "actor_optim_factory": actor_optim_factory,
        "critic_optim_factory": critic_optim_factory,
        "imitator_optim_factory": imitator_optim_factory,
        "temp_optim_factory": temp_optim_factory,
        "alpha_optim_factory": alpha_optim_factory,
        # one encoder fixture serves actor, critic, and imitator
        "actor_encoder_factory": encoder_factory,
        "critic_encoder_factory": encoder_factory,
        "imitator_encoder_factory": encoder_factory,
        "q_func_factory": create_q_func_factory(q_func_factory),
        "gamma": gamma,
        "tau": tau,
        "n_critics": n_critics,
        "initial_temperature": initial_temperature,
        "initial_alpha": initial_alpha,
        "alpha_threshold": alpha_threshold,
        "lam": lam,
        "n_action_samples": n_action_samples,
        "n_target_samples": n_target_samples,
        "n_mmd_action_samples": n_mmd_action_samples,
        "mmd_kernel": mmd_kernel,
        "mmd_sigma": mmd_sigma,
        "vae_kl_weight": vae_kl_weight,
        "use_gpu": None,
        "scaler": scaler,
        "action_scaler": action_scaler,
    }
    impl = BEARImpl(**kwargs)
    impl.build()

    # BEAR samples actions stochastically, so best action is never deterministic.
    torch_impl_tester(impl, discrete=False, deterministic_best_action=False)
コード例 #24
0
ファイル: test_bcq_impl.py プロジェクト: navidmdn/d3rlpy
def test_bcq_impl(
    observation_shape,
    action_size,
    actor_learning_rate,
    critic_learning_rate,
    imitator_learning_rate,
    actor_optim_factory,
    critic_optim_factory,
    imitator_optim_factory,
    encoder_factory,
    q_func_factory,
    gamma,
    tau,
    n_critics,
    bootstrap,
    share_encoder,
    lam,
    n_action_samples,
    action_flexibility,
    latent_size,
    beta,
    scaler,
    action_scaler,
    augmentation,
):
    """Build a BCQImpl (positional-arg API), check internal helper shapes, and run the shared tester."""
    # Positional constructor arguments, in declaration order; the encoder
    # fixture is reused for actor, critic, and imitator encoders.
    ctor_args = [
        observation_shape,
        action_size,
        actor_learning_rate,
        critic_learning_rate,
        imitator_learning_rate,
        actor_optim_factory,
        critic_optim_factory,
        imitator_optim_factory,
        encoder_factory,
        encoder_factory,
        encoder_factory,
        create_q_func_factory(q_func_factory),
        gamma,
        tau,
        n_critics,
        bootstrap,
        share_encoder,
        lam,
        n_action_samples,
        action_flexibility,
        latent_size,
        beta,
    ]
    impl = BCQImpl(
        *ctor_args,
        use_gpu=False,
        scaler=scaler,
        action_scaler=action_scaler,
        augmentation=augmentation,
    )
    impl.build()

    # exercise the internal helpers on a random batch of 32 observations
    batch_size = 32
    obs = torch.rand(batch_size, *observation_shape)

    repeated_obs = impl._repeat_observation(obs)
    assert repeated_obs.shape == (batch_size, n_action_samples) + observation_shape

    sampled_actions = impl._sample_repeated_action(repeated_obs)
    assert sampled_actions.shape == (batch_size, n_action_samples, action_size)

    values = impl._predict_value(repeated_obs, sampled_actions)
    assert values.shape == (n_critics, batch_size * n_action_samples, 1)

    # target shape depends on the Q-function: mean yields a scalar per sample,
    # distributional variants yield one value per quantile
    target = impl.compute_target(obs)
    if q_func_factory == "mean":
        assert target.shape == (batch_size, 1)
    else:
        assert target.shape == (batch_size, impl._q_func.q_funcs[0]._n_quantiles)

    greedy_action = impl._predict_best_action(obs)
    assert greedy_action.shape == (batch_size, action_size)

    # BCQ's action selection samples from the VAE, so it is never deterministic.
    torch_impl_tester(impl, discrete=False, deterministic_best_action=False)