예제 #1
0
    def test_r2d2_compilation(self):
        """Test whether R2D2 can be built on all frameworks.

        For every framework yielded by ``framework_iterator``, an R2D2
        algorithm is built on ``CartPole-v0`` with an LSTM-wrapped model,
        trained for ``num_iterations`` iterations, and its training results
        and batch sizes are validated. Finally, single-action computation is
        checked with ``include_state=True``.

        The algorithm is always stopped once the checks for a framework are
        done (even when one of them fails), so that resources held by one
        build do not leak into the next framework's run.
        """
        config = (
            r2d2.R2D2Config().rollouts(num_rollout_workers=0).training(
                model={
                    # Wrap with an LSTM and use a very simple base-model.
                    "use_lstm": True,
                    "max_seq_len": 20,
                    "fcnet_hiddens": [32],
                    "lstm_cell_size": 64,
                },
                dueling=False,
                lr=5e-4,
                zero_init_states=True,
                replay_buffer_config={
                    "replay_burn_in": 20
                },
            ).exploration(exploration_config={"epsilon_timesteps": 100000}))

        num_iterations = 1

        # Test building an R2D2 agent in all frameworks.
        for _ in framework_iterator(config, with_eager_tracing=True):
            algo = config.build(env="CartPole-v0")
            try:
                for _iteration in range(num_iterations):
                    results = algo.train()
                    check_train_results(results)
                    check_batch_sizes(results)
                    print(results)

                check_compute_single_action(algo, include_state=True)
            finally:
                # Release the algorithm's resources before the next
                # framework is built, regardless of test outcome.
                algo.stop()
예제 #2
0
파일: registry.py 프로젝트: parasj/ray
def _import_r2d2():
    """Import R2D2 on demand and return it with its default config.

    The import happens inside the function, so the R2D2 module is only
    loaded when this function is actually called.

    Returns:
        A 2-tuple ``(R2D2, config_dict)`` where ``config_dict`` is the result
        of ``R2D2Config().to_dict()``.
    """
    import ray.rllib.algorithms.r2d2 as r2d2_module

    algo_cls = r2d2_module.R2D2
    default_config = r2d2_module.R2D2Config().to_dict()
    return algo_cls, default_config