Example #1
    def run_re3(self, rl_algorithm):
        """Tests RE3 for PPO and SAC.

        Both the on-policy and off-policy setups are validated.
        """
        if rl_algorithm == "PPO":
            config = ppo.PPOConfig().to_dict()
            algo_cls = ppo.PPO
            beta_schedule = "constant"
        elif rl_algorithm == "SAC":
            config = sac.SACConfig().to_dict()
            algo_cls = sac.SAC
            beta_schedule = "linear_decay"
        else:
            raise ValueError(f"Unknown algorithm: {rl_algorithm}")

        class RE3Callbacks(RE3UpdateCallbacks, config["callbacks"]):
            pass

        config["env"] = "Pendulum-v1"
        config["callbacks"] = RE3Callbacks
        config["exploration_config"] = {
            "type": "RE3",
            "embeds_dim": 128,
            "beta_schedule": beta_schedule,
            "sub_exploration": {
                "type": "StochasticSampling",
            },
        }

        num_iterations = 30
        algo = algo_cls(config=config)
        learnt = False
        for i in range(num_iterations):
            result = algo.train()
            print(result)
            if result["episode_reward_max"] > -900.0:
                print("Reached goal after {} iters!".format(i))
                learnt = True
                break
        algo.stop()
        self.assertTrue(learnt)
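The helper above is parameterized by the algorithm name; a minimal driver sketch, assuming it lives in a standard unittest.TestCase (the class and test-method names below are illustrative, not part of the snippet):

# Hypothetical test-case wrapper around run_re3; names are illustrative.
import unittest

class TestRE3(unittest.TestCase):

    def run_re3(self, rl_algorithm):
        ...  # body as in the example above

    def test_re3_ppo(self):
        self.run_re3("PPO")

    def test_re3_sac(self):
        self.run_re3("SAC")

if __name__ == "__main__":
    unittest.main()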
Example #2
def _import_sac():
    import ray.rllib.algorithms.sac as sac

    return sac.SAC, sac.SACConfig().to_dict()
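Because _import_sac defers the heavy ray.rllib import until call time, such helpers are typically stored in a name-to-loader mapping and invoked on demand; a minimal sketch under that assumption (ALGORITHMS and get_algorithm_class are illustrative names, not from the snippet):

# Hypothetical lazy-import registry; the rllib module is imported only
# when the stored loader is actually called.
ALGORITHMS = {
    "SAC": _import_sac,
}

def get_algorithm_class(name):
    # Calling the stored importer triggers the actual module import.
    algo_cls, default_config = ALGORITHMS[name]()
    return algo_cls, default_config

# Usage: cls, cfg = get_algorithm_class("SAC")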
Example #3
    def test_sac_compilation(self):
        """Tests whether SAC can be built with all frameworks."""
        config = (
            sac.SACConfig()
            .training(
                n_step=3,
                twin_q=True,
                replay_buffer_config={
                    "learning_starts": 0,
                    "capacity": 40000,
                },
                store_buffer_in_checkpoints=True,
                train_batch_size=10,
            )
            .rollouts(num_rollout_workers=0, rollout_fragment_length=10)
        )
        num_iterations = 1

        ModelCatalog.register_custom_model("batch_norm", KerasBatchNormModel)
        ModelCatalog.register_custom_model("batch_norm_torch",
                                           TorchBatchNormModel)

        image_space = Box(-1.0, 1.0, shape=(84, 84, 3))
        simple_space = Box(-1.0, 1.0, shape=(3,))

        tune.register_env(
            "random_dict_env",
            lambda _: RandomEnv({
                "observation_space": Dict({
                    "a": simple_space,
                    "b": Discrete(2),
                    "c": image_space,
                }),
                "action_space": Box(-1.0, 1.0, shape=(1,)),
            }),
        )
        tune.register_env(
            "random_tuple_env",
            lambda _: RandomEnv({
                "observation_space": Tuple([simple_space, Discrete(2), image_space]),
                "action_space": Box(-1.0, 1.0, shape=(1,)),
            }),
        )

        for fw in framework_iterator(config, with_eager_tracing=True):
            # Test different env types (discrete w/ and w/o image observations, plus continuous).
            for env in [
                    "random_dict_env",
                    "random_tuple_env",
                    # "MsPacmanNoFrameskip-v4",
                    "CartPole-v0",
            ]:
                print("Env={}".format(env))
                # Test making the Q-model a custom one for CartPole;
                # otherwise use the default model.
                config.q_model_config["custom_model"] = (
                    "batch_norm{}".format("_torch" if fw == "torch" else "")
                    if env == "CartPole-v0"
                    else None
                )
                trainer = config.build(env=env)
                for i in range(num_iterations):
                    results = trainer.train()
                    check_train_results(results)
                    print(results)
                check_compute_single_action(trainer)

                # Test whether the replay buffer is saved along with the
                # checkpoint (no point in doing this for all frameworks
                # since it is framework agnostic).
                if fw == "tf" and env == "CartPole-v0":
                    checkpoint = trainer.save()
                    new_trainer = sac.SAC(config, env=env)
                    new_trainer.restore(checkpoint)
                    # Get some data from the buffer and compare.
                    data = trainer.local_replay_buffer.replay_buffers[
                        "default_policy"]._storage[:42 + 42]
                    new_data = new_trainer.local_replay_buffer.replay_buffers[
                        "default_policy"]._storage[:42 + 42]
                    check(data, new_data)
                    new_trainer.stop()

                trainer.stop()
Example #4
    def test_sac_loss_function(self):
        """Tests SAC loss function results across all frameworks."""
        config = (
            sac.SACConfig()
            .training(
                twin_q=False,
                gamma=0.99,
                _deterministic_loss=True,
                q_model_config={"fcnet_hiddens": [10]},
                policy_model_config={"fcnet_hiddens": [10]},
                replay_buffer_config={"learning_starts": 0},
            )
            .rollouts(num_rollout_workers=0)
            .reporting(min_time_s_per_iteration=0)
            .environment(env_config={"simplex_actions": True})
            .debugging(seed=42)
        )

        map_ = {
            # Action net.
            "default_policy/fc_1/kernel": "action_model._hidden_layers.0._model.0.weight",
            "default_policy/fc_1/bias": "action_model._hidden_layers.0._model.0.bias",
            "default_policy/fc_out/kernel": "action_model._logits._model.0.weight",
            "default_policy/fc_out/bias": "action_model._logits._model.0.bias",
            "default_policy/value_out/kernel": "action_model._value_branch._model.0.weight",
            "default_policy/value_out/bias": "action_model._value_branch._model.0.bias",
            # Q-net.
            "default_policy/fc_1_1/kernel": "q_net._hidden_layers.0._model.0.weight",
            "default_policy/fc_1_1/bias": "q_net._hidden_layers.0._model.0.bias",
            "default_policy/fc_out_1/kernel": "q_net._logits._model.0.weight",
            "default_policy/fc_out_1/bias": "q_net._logits._model.0.bias",
            "default_policy/value_out_1/kernel": "q_net._value_branch._model.0.weight",
            "default_policy/value_out_1/bias": "q_net._value_branch._model.0.bias",
            "default_policy/log_alpha": "log_alpha",
            # Target action-net.
            "default_policy/fc_1_2/kernel": "action_model._hidden_layers.0._model.0.weight",
            "default_policy/fc_1_2/bias": "action_model._hidden_layers.0._model.0.bias",
            "default_policy/fc_out_2/kernel": "action_model._logits._model.0.weight",
            "default_policy/fc_out_2/bias": "action_model._logits._model.0.bias",
            "default_policy/value_out_2/kernel": "action_model._value_branch._model.0.weight",
            "default_policy/value_out_2/bias": "action_model._value_branch._model.0.bias",
            # Target Q-net.
            "default_policy/fc_1_3/kernel": "q_net._hidden_layers.0._model.0.weight",
            "default_policy/fc_1_3/bias": "q_net._hidden_layers.0._model.0.bias",
            "default_policy/fc_out_3/kernel": "q_net._logits._model.0.weight",
            "default_policy/fc_out_3/bias": "q_net._logits._model.0.bias",
            "default_policy/value_out_3/kernel": "q_net._value_branch._model.0.weight",
            "default_policy/value_out_3/bias": "q_net._value_branch._model.0.bias",
            "default_policy/log_alpha_1": "log_alpha",
        }

        env = SimpleEnv
        batch_size = 64
        obs_size = (batch_size, 1)
        actions = np.random.random(size=(batch_size, 2))

        # Batch of size=n.
        input_ = self._get_batch_helper(obs_size, actions, batch_size)

        # Simply compare loss values AND grads of all frameworks with each
        # other.
        prev_fw_loss = weights_dict = None
        expect_c, expect_a, expect_e, expect_t = None, None, None, None
        # History of tf-updated NN-weights over n training steps.
        tf_updated_weights = []
        # History of input batches used.
        tf_inputs = []
        for fw, sess in framework_iterator(config,
                                           frameworks=("tf", "torch"),
                                           session=True):
            # Generate Trainer and get its default Policy object.
            trainer = config.build(env=env)
            policy = trainer.get_policy()
            p_sess = None
            if sess:
                p_sess = policy.get_session()

            # Set all weights (of all nets) to fixed values.
            if weights_dict is None:
                # Start with the tf vars-dict.
                assert fw in ["tf2", "tf", "tfe"]
                weights_dict = policy.get_weights()
                if fw == "tfe":
                    log_alpha = weights_dict[10]
                    weights_dict = self._translate_tfe_weights(
                        weights_dict, map_)
            else:
                assert fw == "torch"  # Then transfer that to torch Model.
                model_dict = self._translate_weights_to_torch(
                    weights_dict, map_)
                # Have to add this here (not a parameter in tf, but must be
                # one in torch, so it gets properly copied to the GPU(s)).
                model_dict["target_entropy"] = policy.model.target_entropy
                policy.model.load_state_dict(model_dict)
                policy.target_model.load_state_dict(model_dict)

            if fw == "tf":
                log_alpha = weights_dict["default_policy/log_alpha"]
            elif fw == "torch":
                # Actually convert to torch tensors (by accessing everything).
                input_ = policy._lazy_tensor_dict(input_)
                input_ = {k: input_[k] for k in input_.keys()}
                log_alpha = policy.model.log_alpha.detach().cpu().numpy()[0]

            # Only compute the expected values once; they should be the
            # same for all frameworks anyway.
            if expect_c is None:
                expect_c, expect_a, expect_e, expect_t = self._sac_loss_helper(
                    input_,
                    weights_dict,
                    sorted(weights_dict.keys()),
                    log_alpha,
                    fw,
                    gamma=config.gamma,
                    sess=sess,
                )

            # Get actual outs and compare to expectation AND previous
            # framework. c=critic, a=actor, e=entropy, t=td-error.
            if fw == "tf":
                c, a, e, t, tf_c_grads, tf_a_grads, tf_e_grads = p_sess.run(
                    [
                        policy.critic_loss,
                        policy.actor_loss,
                        policy.alpha_loss,
                        policy.td_error,
                        policy.optimizer().compute_gradients(
                            policy.critic_loss[0],
                            [
                                v for v in policy.model.q_variables()
                                if "value_" not in v.name
                            ],
                        ),
                        policy.optimizer().compute_gradients(
                            policy.actor_loss,
                            [
                                v for v in policy.model.policy_variables()
                                if "value_" not in v.name
                            ],
                        ),
                        policy.optimizer().compute_gradients(
                            policy.alpha_loss, policy.model.log_alpha),
                    ],
                    feed_dict=policy._get_loss_inputs_dict(input_,
                                                           shuffle=False),
                )
                tf_c_grads = [g for g, v in tf_c_grads]
                tf_a_grads = [g for g, v in tf_a_grads]
                tf_e_grads = [g for g, v in tf_e_grads]

            elif fw == "tfe":
                with tf.GradientTape() as tape:
                    tf_loss(policy, policy.model, None, input_)
                c, a, e, t = (
                    policy.critic_loss,
                    policy.actor_loss,
                    policy.alpha_loss,
                    policy.td_error,
                )
                watched_vars = tape.watched_variables()
                tf_c_grads = tape.gradient(c[0], watched_vars[6:10])
                tf_a_grads = tape.gradient(a, watched_vars[2:6])
                tf_e_grads = tape.gradient(e, watched_vars[10])

            elif fw == "torch":
                loss_torch(policy, policy.model, None, input_)
                c, a, e, t = (
                    policy.get_tower_stats("critic_loss")[0],
                    policy.get_tower_stats("actor_loss")[0],
                    policy.get_tower_stats("alpha_loss")[0],
                    policy.get_tower_stats("td_error")[0],
                )

                # Test actor gradients.
                policy.actor_optim.zero_grad()
                assert all(v.grad is None for v in policy.model.q_variables())
                assert all(v.grad is None
                           for v in policy.model.policy_variables())
                assert policy.model.log_alpha.grad is None
                a.backward()
                # `actor_loss` depends on Q-net vars (but these grads must
                # be ignored and overridden in critic_loss.backward!).
                assert not all(
                    torch.mean(v.grad) == 0
                    for v in policy.model.policy_variables())
                assert not all(
                    torch.min(v.grad) == 0
                    for v in policy.model.policy_variables())
                assert policy.model.log_alpha.grad is None
                # Compare with tf ones.
                torch_a_grads = [
                    v.grad for v in policy.model.policy_variables()
                    if v.grad is not None
                ]
                check(tf_a_grads[2],
                      np.transpose(torch_a_grads[0].detach().cpu()))

                # Test critic gradients.
                policy.critic_optims[0].zero_grad()
                assert all(
                    torch.mean(v.grad) == 0.0
                    for v in policy.model.q_variables() if v.grad is not None)
                assert all(
                    torch.min(v.grad) == 0.0
                    for v in policy.model.q_variables() if v.grad is not None)
                assert policy.model.log_alpha.grad is None
                c[0].backward()
                assert not all(
                    torch.mean(v.grad) == 0
                    for v in policy.model.q_variables() if v.grad is not None)
                assert not all(
                    torch.min(v.grad) == 0
                    for v in policy.model.q_variables() if v.grad is not None)
                assert policy.model.log_alpha.grad is None
                # Compare with tf ones.
                torch_c_grads = [v.grad for v in policy.model.q_variables()]
                check(tf_c_grads[0],
                      np.transpose(torch_c_grads[2].detach().cpu()))
                # Compare (unchanged(!) actor grads) with tf ones.
                torch_a_grads = [
                    v.grad for v in policy.model.policy_variables()
                ]
                check(tf_a_grads[2],
                      np.transpose(torch_a_grads[0].detach().cpu()))

                # Test alpha gradient.
                policy.alpha_optim.zero_grad()
                assert policy.model.log_alpha.grad is None
                e.backward()
                assert policy.model.log_alpha.grad is not None
                check(policy.model.log_alpha.grad, tf_e_grads)

            check(c, expect_c)
            check(a, expect_a)
            check(e, expect_e)
            check(t, expect_t)

            # Store this framework's losses in prev_fw_loss to compare with
            # next framework's outputs.
            if prev_fw_loss is not None:
                check(c, prev_fw_loss[0])
                check(a, prev_fw_loss[1])
                check(e, prev_fw_loss[2])
                check(t, prev_fw_loss[3])

            prev_fw_loss = (c, a, e, t)

            # Update weights from our batch (n times).
            for update_iteration in range(5):
                print("train iteration {}".format(update_iteration))
                if fw == "tf":
                    in_ = self._get_batch_helper(obs_size, actions, batch_size)
                    tf_inputs.append(in_)
                    # Set a fake-batch to use
                    # (instead of sampling from replay buffer).
                    buf = trainer.local_replay_buffer
                    patch_buffer_with_fake_sampling_method(buf, in_)
                    trainer.train()
                    updated_weights = policy.get_weights()
                    # Net must have changed.
                    if tf_updated_weights:
                        check(
                            updated_weights["default_policy/fc_1/kernel"],
                            tf_updated_weights[-1]
                            ["default_policy/fc_1/kernel"],
                            false=True,
                        )
                    tf_updated_weights.append(updated_weights)

                # Compare with updated tf-weights. Must all be the same.
                else:
                    tf_weights = tf_updated_weights[update_iteration]
                    in_ = tf_inputs[update_iteration]
                    # Set a fake-batch to use
                    # (instead of sampling from replay buffer).
                    buf = trainer.local_replay_buffer
                    patch_buffer_with_fake_sampling_method(buf, in_)
                    trainer.train()
                    # Compare updated model.
                    for tf_key in sorted(tf_weights.keys()):
                        if re.search("_[23]|alpha", tf_key):
                            continue
                        tf_var = tf_weights[tf_key]
                        torch_var = policy.model.state_dict()[map_[tf_key]]
                        if tf_var.shape != torch_var.shape:
                            check(
                                tf_var,
                                np.transpose(torch_var.detach().cpu()),
                                atol=0.003,
                            )
                        else:
                            check(tf_var, torch_var, atol=0.003)
                    # And alpha.
                    check(policy.model.log_alpha,
                          tf_weights["default_policy/log_alpha"])
                    # Compare target nets.
                    for tf_key in sorted(tf_weights.keys()):
                        if not re.search("_[23]", tf_key):
                            continue
                        tf_var = tf_weights[tf_key]
                        torch_var = policy.target_model.state_dict()[
                            map_[tf_key]]
                        if tf_var.shape != torch_var.shape:
                            check(
                                tf_var,
                                np.transpose(torch_var.detach().cpu()),
                                atol=0.003,
                            )
                        else:
                            check(tf_var, torch_var, atol=0.003)
            trainer.stop()
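The np.transpose calls scattered through this test come from a layout difference: tf Dense kernels are stored as (in_features, out_features), while torch nn.Linear weights are (out_features, in_features). A standalone toy illustration (the shapes are made up, not from the test):

import numpy as np

# A tf Dense layer mapping 3 inputs to 5 units stores its kernel as
# (3, 5); the equivalent torch nn.Linear stores its weight as (5, 3).
# Comparing the two therefore requires a transpose, exactly as the
# check(...) calls above do.
tf_kernel = np.zeros((3, 5))
torch_weight = np.zeros((5, 3))
assert tf_kernel.shape == np.transpose(torch_weight).shape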