    def test_simple_q_compilation(self):
        """Test whether a SimpleQTrainer can be built on all frameworks."""
        config = dqn.SIMPLE_Q_DEFAULT_CONFIG.copy()
        config["num_workers"] = 0  # Run locally.

        for _ in framework_iterator(config):
            trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0")
            num_iterations = 2
            for i in range(num_iterations):
                results = trainer.train()
                print(results)
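A hedged sketch of the same compilation check written against the SimpleQConfig builder API that Example 5 uses (assumes the RLlib version that ships dqn.simple_q.SimpleQConfig and its config.build() method):

import ray.rllib.agents.dqn as dqn

# Build a SimpleQTrainer from the config-builder API and run two iterations.
config = dqn.simple_q.SimpleQConfig().rollouts(num_rollout_workers=0)
trainer = config.build(env="CartPole-v0")
for _ in range(2):
    print(trainer.train())
trainer.stop()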
Example 2
    def test_simple_q_compilation(self):
        """Test whether a SimpleQTrainer can be built on all frameworks."""
        config = dqn.SIMPLE_Q_DEFAULT_CONFIG.copy()
        config["num_workers"] = 0  # Run locally.
        num_iterations = 2

        for _ in framework_iterator(config):
            trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0")
            rw = trainer.workers.local_worker()
            for i in range(num_iterations):
                sb = rw.sample()
                assert sb.count == config["rollout_fragment_length"]
                results = trainer.train()
                print(results)

            check_compute_single_action(trainer)
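What check_compute_single_action exercises can be shown more directly. Continuing with the trainer built above, query one action for a fresh observation (a minimal sketch, assuming gym's CartPole-v0 and the Trainer.compute_single_action API):

import gym

env = gym.make("CartPole-v0")
obs = env.reset()
# Greedy (non-exploratory) action from the current policy.
action = trainer.compute_single_action(obs, explore=False)
assert env.action_space.contains(action)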
Example 3
    def test_simple_q_fake_multi_gpu_learning(self):
        """Test whether SimpleQTrainer learns CartPole w/ fake GPUs."""
        config = copy.deepcopy(dqn.SIMPLE_Q_DEFAULT_CONFIG)

        # Fake GPU setup.
        config["num_gpus"] = 2
        config["_fake_gpus"] = True

        for _ in framework_iterator(config, frameworks=("tf", "torch")):
            trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0")
            num_iterations = 200
            learnt = False
            for i in range(num_iterations):
                results = trainer.train()
                print("reward={}".format(results["episode_reward_mean"]))
                if results["episode_reward_mean"] > 75.0:
                    learnt = True
                    break
            assert learnt, "SimpleQ multi-GPU (with fake-GPUs) did not " \
                           "learn CartPole!"
            trainer.stop()
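The same fake-GPU setup, expressed with the config-builder API from Example 5; a sketch, assuming .resources() and .framework() accept these arguments in that RLlib version:

import ray.rllib.agents.dqn as dqn

config = (
    dqn.simple_q.SimpleQConfig()
    .framework("torch")
    # Two fake (CPU-backed) GPUs, as in the test above.
    .resources(num_gpus=2, _fake_gpus=True)
)
trainer = config.build(env="CartPole-v0")
print(trainer.train()["episode_reward_mean"])
trainer.stop()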
Example 4
def test_rllib_integration(ray_start_regular_shared):
    with ray_start_client_server():
        import ray.rllib.agents.dqn as dqn
        # Confirming the behavior of this context manager.
        # (Client mode hook not yet enabled.)
        assert not client_mode_should_convert()
        # Need to enable this for client APIs to be used.
        with enable_client_mode():
            # Confirming mode hook is enabled.
            assert client_mode_should_convert()

            config = dqn.SIMPLE_Q_DEFAULT_CONFIG.copy()
            # Run locally.
            config["num_workers"] = 0
            # Test with compression.
            config["compress_observations"] = True
            num_iterations = 2
            trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v1")
            rw = trainer.workers.local_worker()
            for i in range(num_iterations):
                sb = rw.sample()
                assert sb.count == config["rollout_fragment_length"]
                trainer.train()
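Outside the ray_start_client_server fixture, the same checks would run against a real Ray Client endpoint; a sketch with a placeholder address (ray://127.0.0.1:10001 is illustrative, not taken from the test):

import ray
import ray.rllib.agents.dqn as dqn

ray.init("ray://127.0.0.1:10001")  # placeholder Ray Client address
config = dqn.SIMPLE_Q_DEFAULT_CONFIG.copy()
config["num_workers"] = 0           # Run locally.
config["compress_observations"] = True
trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v1")
trainer.train()
trainer.stop()
ray.shutdown()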
Example 5
    def test_simple_q_loss_function(self):
        """Tests the Simple-Q loss function results on all frameworks."""
        config = dqn.simple_q.SimpleQConfig().rollouts(num_rollout_workers=0)
        # Use very simple net (layer0=10 nodes, q-layer=2 nodes (2 actions)).
        config.training(model={
            "fcnet_hiddens": [10],
            "fcnet_activation": "linear",
        })

        for fw in framework_iterator(config):
            # Generate Trainer and get its default Policy object.
            trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0")
            policy = trainer.get_policy()
            # Batch of size=2.
            input_ = SampleBatch({
                SampleBatch.CUR_OBS: np.random.random(size=(2, 4)),
                SampleBatch.ACTIONS: np.array([0, 1]),
                SampleBatch.REWARDS: np.array([0.4, -1.23]),
                SampleBatch.DONES: np.array([False, False]),
                SampleBatch.NEXT_OBS: np.random.random(size=(2, 4)),
                SampleBatch.EPS_ID: np.array([1234, 1234]),
                SampleBatch.AGENT_INDEX: np.array([0, 0]),
                SampleBatch.ACTION_LOGP: np.array([-0.1, -0.1]),
                SampleBatch.ACTION_DIST_INPUTS: np.array(
                    [[0.1, 0.2], [-0.1, -0.2]]),
                SampleBatch.ACTION_PROB: np.array([0.1, 0.2]),
                "q_values": np.array([[0.1, 0.2], [0.2, 0.1]]),
            })
            # Get model vars for computing expected model outs (q-vals).
            # 0=layer-kernel; 1=layer-bias; 2=q-val-kernel; 3=q-val-bias
            vars = policy.get_weights()
            if isinstance(vars, dict):
                vars = list(vars.values())

            vars_t = policy.target_model.variables()
            if fw == "tf":
                vars_t = policy.get_session().run(vars_t)

            # Q(s,a) outputs.
            q_t = np.sum(
                one_hot(input_[SampleBatch.ACTIONS], 2) * fc(
                    fc(
                        input_[SampleBatch.CUR_OBS],
                        vars[0 if fw != "torch" else 2],
                        vars[1 if fw != "torch" else 3],
                        framework=fw,
                    ),
                    vars[2 if fw != "torch" else 0],
                    vars[3 if fw != "torch" else 1],
                    framework=fw,
                ),
                1,
            )
            # max[a'](Qtarget(s',a')) outputs.
            q_target_tp1 = np.max(
                fc(
                    fc(
                        input_[SampleBatch.NEXT_OBS],
                        vars_t[0 if fw != "torch" else 2],
                        vars_t[1 if fw != "torch" else 3],
                        framework=fw,
                    ),
                    vars_t[2 if fw != "torch" else 0],
                    vars_t[3 if fw != "torch" else 1],
                    framework=fw,
                ),
                1,
            )
            # TD-errors (Bellman equation).
            td_error = q_t - (
                input_[SampleBatch.REWARDS] + config.gamma * q_target_tp1
            )
            # Huber/Square loss on TD-error.
            expected_loss = huber_loss(td_error).mean()

            if fw == "torch":
                input_ = policy._lazy_tensor_dict(input_)
            # Get actual out and compare.
            if fw == "tf":
                out = policy.get_session().run(
                    policy._loss,
                    feed_dict=policy._get_loss_inputs_dict(input_,
                                                           shuffle=False),
                )
            else:
                out = (loss_torch if fw == "torch" else loss_tf)(
                    policy, policy.model, None, input_)
            check(out, expected_loss, decimals=1)
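The arithmetic the loss test reproduces can be followed with plain numpy and made-up Q-values; a sketch of the Bellman TD-error and the Huber loss on it (huber here is a local helper, not RLlib's huber_loss):

import numpy as np

gamma = 0.99
rewards = np.array([0.4, -1.23])
q_t = np.array([0.05, -0.02])          # Q(s, a) for the actions taken
q_target_tp1 = np.array([0.07, 0.01])  # max[a'](Qtarget(s', a'))

# Bellman TD-error: Q(s,a) - (r + gamma * max[a'](Qtarget(s',a'))).
td_error = q_t - (rewards + gamma * q_target_tp1)

def huber(x, delta=1.0):
    # Quadratic inside |x| <= delta, linear outside.
    return np.where(np.abs(x) <= delta,
                    0.5 * np.square(x),
                    delta * (np.abs(x) - 0.5 * delta))

print(huber(td_error).mean())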
Example 6
    def test_simple_q_loss_function(self):
        """Tests the Simple-Q loss function results on all frameworks."""
        config = dqn.SIMPLE_Q_DEFAULT_CONFIG.copy()
        # Run locally.
        config["num_workers"] = 0
        # Use very simple net (layer0=10 nodes, q-layer=2 nodes (2 actions)).
        config["model"]["fcnet_hiddens"] = [10]
        config["model"]["fcnet_activation"] = "linear"

        for fw in framework_iterator(config):
            # Generate Trainer and get its default Policy object.
            trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0")
            policy = trainer.get_policy()
            # Batch of size=2.
            input_ = {
                SampleBatch.CUR_OBS: np.random.random(size=(2, 4)),
                SampleBatch.ACTIONS: np.array([0, 1]),
                SampleBatch.REWARDS: np.array([0.4, -1.23]),
                SampleBatch.DONES: np.array([False, False]),
                SampleBatch.NEXT_OBS: np.random.random(size=(2, 4))
            }
            # Get model vars for computing expected model outs (q-vals).
            # 0=layer-kernel; 1=layer-bias; 2=q-val-kernel; 3=q-val-bias
            vars = policy.get_weights()
            if isinstance(vars, dict):
                vars = list(vars.values())
            vars_t = policy.target_q_func_vars
            if fw == "tf":
                vars_t = policy.get_session().run(vars_t)

            # Q(s,a) outputs.
            q_t = np.sum(
                one_hot(input_[SampleBatch.ACTIONS], 2) *
                fc(fc(input_[SampleBatch.CUR_OBS],
                      vars[0 if fw != "torch" else 2],
                      vars[1 if fw != "torch" else 3],
                      framework=fw),
                   vars[2 if fw != "torch" else 0],
                   vars[3 if fw != "torch" else 1],
                   framework=fw), 1)
            # max[a'](Qtarget(s',a')) outputs.
            q_target_tp1 = np.max(
                fc(fc(input_[SampleBatch.NEXT_OBS],
                      vars_t[0 if fw != "torch" else 2],
                      vars_t[1 if fw != "torch" else 3],
                      framework=fw),
                   vars_t[2 if fw != "torch" else 0],
                   vars_t[3 if fw != "torch" else 1],
                   framework=fw), 1)
            # TD-errors (Bellman equation).
            td_error = q_t - (input_[SampleBatch.REWARDS] +
                              config["gamma"] * q_target_tp1)
            # Huber/Square loss on TD-error.
            expected_loss = huber_loss(td_error).mean()

            if fw == "torch":
                input_ = policy._lazy_tensor_dict(input_)
            # Get actual out and compare.
            if fw == "tf":
                out = policy.get_session().run(
                    policy._loss,
                    feed_dict=policy._get_loss_inputs_dict(input_,
                                                           shuffle=False))
            else:
                out = (loss_torch if fw == "torch" else loss_tf)(
                    policy, policy.model, None, input_)
            check(out, expected_loss, decimals=1)
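The fc() helper used above to compute expected Q-values is assumed to be a plain dense-layer forward pass; a numpy sketch of that assumption with the [10]-unit, linear-activation net from the config:

import numpy as np

def dense(x, kernel, bias):
    # x: (batch, in_dim), kernel: (in_dim, out_dim), bias: (out_dim,)
    return x @ kernel + bias

obs = np.random.random((2, 4))
w0, b0 = np.random.random((4, 10)), np.zeros(10)  # layer0 kernel/bias
wq, bq = np.random.random((10, 2)), np.zeros(2)   # q-layer kernel/bias
q_values = dense(dense(obs, w0, b0), wq, bq)      # linear activations
print(q_values.shape)  # (2, 2): one Q-value per action, per batch row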