コード例 #1
0
def test_ceppo(local_mode=False):
    """Launch a CEPPO sweep on ``Pendulum-v0`` through ``_base``.

    The ``mode`` setting is grid-searched over ``DISABLE_AND_EXPAND`` and
    ``REPLAY_VALUES``. The run uses one worker with half a CPU each and sets
    ``clip_action_prob_kl`` to 0.0. ``t=10000`` is forwarded to ``_base``
    unchanged (presumably the training budget -- confirm against ``_base``).

    Args:
        local_mode: Forwarded positionally to ``_base`` as its second argument.
    """
    # Other modes that can be swapped into the sweep:
    #   DISABLE, NO_REPLAY_VALUES,
    #   DIVERSITY_ENCOURAGING, DIVERSITY_ENCOURAGING_NO_RV,
    #   DIVERSITY_ENCOURAGING_DISABLE, DIVERSITY_ENCOURAGING_DISABLE_AND_EXPAND,
    #   CURIOSITY, CURIOSITY_NO_RV, CURIOSITY_DISABLE,
    #   CURIOSITY_DISABLE_AND_EXPAND,
    #   CURIOSITY_KL, CURIOSITY_KL_NO_RV, CURIOSITY_KL_DISABLE,
    #   CURIOSITY_KL_DISABLE_AND_EXPAND
    swept_modes = [DISABLE_AND_EXPAND, REPLAY_VALUES]
    ceppo_config = {
        "mode": tune.grid_search(swept_modes),
        "num_cpus_per_worker": 0.5,
        "num_workers": 1,
        # Newer config option.
        "clip_action_prob_kl": 0.0,
    }
    _base(
        CEPPOTrainer,
        local_mode,
        extra_config=ceppo_config,
        env_name="Pendulum-v0",
        t=10000,
    )
コード例 #2
0
def test_dece(config=None, local_mode=False, t=2000, **kwargs):
    """Invoke ``_base`` with ``DECETrainer`` on ``Pendulum-v0``.

    Args:
        config: Extra trainer config forwarded to ``_base`` as
            ``extra_config``. When omitted, a fresh empty dict is used
            for every call.
        local_mode: Forwarded to ``_base`` unchanged.
        t: Forwarded to ``_base`` as ``t``. Elsewhere in this file it is an
            int or a stop-criteria dict (presumably the run budget).
        **kwargs: Any further keyword arguments, forwarded to ``_base``.
    """
    # The previous ``config={}`` default was a single dict object created
    # once at definition time and shared by every call. If ``_base`` or
    # anything downstream mutated ``extra_config``, later calls would
    # silently inherit those keys. Build a new dict per call instead.
    if config is None:
        config = {}
    _base(
        trainer=DECETrainer,
        local_mode=local_mode,
        extra_config=config,
        env_name="Pendulum-v0",
        t=t,
        **kwargs
    )
コード例 #3
0
def test_vtrace_single_agent(local_mode=False):
    """Run DECE on ``FourWayGridWorld`` with a single agent via ``_base``.

    ``REPLAY_VALUES`` is grid-searched over True/False, and fixed
    sampling/SGD batch settings are added. ``t=20000`` is forwarded to
    ``_base`` unchanged.

    Args:
        local_mode: Forwarded to ``_base``.
    """
    batch_settings = dict(
        sample_batch_size=50,
        train_batch_size=200,
        num_sgd_iter=10,
        sgd_minibatch_size=50,
    )
    single_agent_config = {REPLAY_VALUES: tune.grid_search([True, False])}
    single_agent_config.update(batch_settings)
    _base(
        trainer=DECETrainer,
        local_mode=local_mode,
        extra_config=single_agent_config,
        env_name=FourWayGridWorld,
        t=20000,
        num_agents=1,
    )
コード例 #4
0
def mock_experiment(lm=False):
    """Sweep DECE on ``FourWayGridWorld`` via ``_base``.

    Grid axes:

    - ``DELAY_UPDATE``: True/False.
    - ``REPLAY_VALUES``: True/False.
    - Number of agents: 1 or 5.

    ``{'timesteps_total': 5000}`` is forwarded to ``_base`` as ``t``.

    Args:
        lm: Forwarded to ``_base`` as ``local_mode``.
    """
    stop_criteria = {'timesteps_total': 5000}
    sweep_config = {
        DELAY_UPDATE: tune.grid_search([True, False]),
        REPLAY_VALUES: tune.grid_search([True, False]),
        'sample_batch_size': 20,
        'sgd_minibatch_size': 100,
        'train_batch_size': 500,
    }
    agent_counts = tune.grid_search([1, 5])
    _base(
        trainer=DECETrainer,
        local_mode=lm,
        extra_config=sweep_config,
        env_name=FourWayGridWorld,
        t=stop_criteria,
        num_agents=agent_counts,
    )
コード例 #5
0
def test_vtrace(local_mode=False, hard=False):
    """Run DECE with ``REPLAY_VALUES`` enabled on ``FourWayGridWorld``.

    A small two-layer (16, 16) fully connected model is used with seed 0.
    ``t=100000`` is forwarded to ``_base`` unchanged.

    Args:
        local_mode: Forwarded to ``_base``.
        hard: If truthy, use the larger batch/SGD settings
            (50 / 450 / 10 iters / 150). Otherwise use a small, quick
            configuration (8 / 96 / 2 iters / 24).
    """
    if hard:
        sample_bs, train_bs, sgd_iters, minibatch = 50, 450, 10, 150
    else:
        sample_bs, train_bs, sgd_iters, minibatch = 8, 96, 2, 3 * 8

    vtrace_config = {
        REPLAY_VALUES: True,
        'sample_batch_size': sample_bs,
        'train_batch_size': train_bs,
        'num_sgd_iter': sgd_iters,
        'sgd_minibatch_size': minibatch,
        'model': {'fcnet_hiddens': [16, 16]},
        'seed': 0,
        # 'lr': 5e-3,
    }
    _base(
        trainer=DECETrainer,
        local_mode=local_mode,
        extra_config=vtrace_config,
        env_name=FourWayGridWorld,
        t=100000,
    )
コード例 #6
0
def no_replay_values_batch_size_bug(lm=False):
    """Sweep DECE on ``FourWayGridWorld`` with 5 agents via ``_base``.

    Grid axes:

    - ``REPLAY_VALUES``: True/False.
    - ``CONSTRAIN_NOVELTY``: 'soft', 'hard', or None.

    The function name suggests this reproduces a batch-size issue when
    replay values are off. That is unconfirmed from this code alone.

    Args:
        lm: Forwarded to ``_base`` as ``local_mode``.
    """
    sweep_axes = {
        REPLAY_VALUES: tune.grid_search([True, False]),
        CONSTRAIN_NOVELTY: tune.grid_search(['soft', 'hard', None]),
    }
    sampling = {
        'num_envs_per_worker': 4,
        'sample_batch_size': 20,
        'sgd_minibatch_size': 100,
        'train_batch_size': 1000,
    }
    resources = {
        "num_cpus_per_worker": 1,
        "num_cpus_for_driver": 1,
        'num_workers': 2,
    }
    _base(
        trainer=DECETrainer,
        local_mode=lm,
        extra_config={**sweep_axes, **sampling, **resources},
        env_name=FourWayGridWorld,
        t=1000000,
        num_agents=tune.grid_search([5]),
    )
コード例 #7
0
def regression_test(local_mode=False):
    """Single-agent DECE regression sweep on ``FourWayGridWorld``.

    Grid axes:

    - ``REPLAY_VALUES``: True/False.
    - ``seed``: 432 or 1920.

    ``USE_BISECTOR`` is disabled. ``{'time_total_s': 300}`` is forwarded to
    ``_base`` as ``t``.

    Args:
        local_mode: Forwarded to ``_base``.
    """
    regression_config = {REPLAY_VALUES: tune.grid_search([True, False])}
    # Other knobs that have been swept here:
    #   "normalize_advantage": tune.grid_search([True, False])
    #   'use_vtrace': tune.grid_search([True])
    regression_config['sample_batch_size'] = 128
    regression_config['train_batch_size'] = 512
    regression_config['sgd_minibatch_size'] = 128
    regression_config['num_sgd_iter'] = 10
    regression_config[USE_BISECTOR] = False
    regression_config['seed'] = tune.grid_search([432, 1920])
    # regression_config['lr'] = 5e-3

    # Alternative environments: "Pendulum-v0", "CartPole-v0".
    # Alternative budget: t={'timesteps_total': 300000}.
    _base(
        trainer=DECETrainer,
        local_mode=local_mode,
        extra_config=regression_config,
        env_name=FourWayGridWorld,
        t={'time_total_s': 300},
        num_agents=1,
    )
コード例 #8
0
def test_single_agent(local_mode=False):
    """Invoke ``_base`` for CEPPO in ``DISABLE`` mode with one agent.

    Args:
        local_mode: Forwarded positionally to ``_base`` as its second argument.
    """
    # Passed as the third positional argument, exactly as the caller did
    # before; the name of the matching ``_base`` parameter is not visible here.
    disabled_mode_config = {"mode": DISABLE}
    _base(CEPPOTrainer, local_mode, disabled_mode_config, num_agents=1)