Example #1
File: test_sac.py  Project: yongjun823/ray
    def test_sac_loss_function(self):
        """Tests SAC loss function results across all frameworks."""
        config = sac.DEFAULT_CONFIG.copy()
        # Run locally.
        config["num_workers"] = 0
        config["learning_starts"] = 0
        config["twin_q"] = False
        config["gamma"] = 0.99
        # Switch on deterministic loss so we can compare the loss values.
        config["_deterministic_loss"] = True
        # Use very simple nets.
        config["Q_model"]["fcnet_hiddens"] = [10]
        config["policy_model"]["fcnet_hiddens"] = [10]
        # Make sure timing differences do not affect trainer.train().
        config["min_iter_time_s"] = 0

        map_ = {
            # Normal net.
            "default_policy/sequential/action_1/kernel": "action_model."
            "action_0._model.0.weight",
            "default_policy/sequential/action_1/bias": "action_model."
            "action_0._model.0.bias",
            "default_policy/sequential/action_out/kernel": "action_model."
            "action_out._model.0.weight",
            "default_policy/sequential/action_out/bias": "action_model."
            "action_out._model.0.bias",
            "default_policy/sequential_1/q_hidden_0/kernel": "q_net."
            "q_hidden_0._model.0.weight",
            "default_policy/sequential_1/q_hidden_0/bias": "q_net."
            "q_hidden_0._model.0.bias",
            "default_policy/sequential_1/q_out/kernel": "q_net."
            "q_out._model.0.weight",
            "default_policy/sequential_1/q_out/bias": "q_net."
            "q_out._model.0.bias",
            "default_policy/value_out/kernel": "_value_branch."
            "_model.0.weight",
            "default_policy/value_out/bias": "_value_branch."
            "_model.0.bias",
            # Target net.
            "default_policy/sequential_2/action_1/kernel": "action_model."
            "action_0._model.0.weight",
            "default_policy/sequential_2/action_1/bias": "action_model."
            "action_0._model.0.bias",
            "default_policy/sequential_2/action_out/kernel": "action_model."
            "action_out._model.0.weight",
            "default_policy/sequential_2/action_out/bias": "action_model."
            "action_out._model.0.bias",
            "default_policy/sequential_3/q_hidden_0/kernel": "q_net."
            "q_hidden_0._model.0.weight",
            "default_policy/sequential_3/q_hidden_0/bias": "q_net."
            "q_hidden_0._model.0.bias",
            "default_policy/sequential_3/q_out/kernel": "q_net."
            "q_out._model.0.weight",
            "default_policy/sequential_3/q_out/bias": "q_net."
            "q_out._model.0.bias",
            "default_policy/value_out_1/kernel": "_value_branch."
            "_model.0.weight",
            "default_policy/value_out_1/bias": "_value_branch."
            "_model.0.bias",
        }
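        # NOTE: tf stores dense-layer kernels as [in, out], while torch
        # nn.Linear weights are [out, in]; this is why np.transpose is
        # applied below whenever tf and torch tensors with mismatching
        # shapes are compared.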

        env = SimpleEnv
        batch_size = 100
        if env is SimpleEnv:
            obs_size = (batch_size, 1)
            actions = np.random.random(size=(batch_size, 1))
        elif env == "CartPole-v0":
            obs_size = (batch_size, 4)
            actions = np.random.randint(0, 2, size=(batch_size, ))
        else:
            obs_size = (batch_size, 3)
            actions = np.random.random(size=(batch_size, 1))

        # Batch of size=n.
        input_ = self._get_batch_helper(obs_size, actions, batch_size)

        # Simply compare loss values AND grads of all frameworks with each
        # other.
        prev_fw_loss = weights_dict = None
        expect_c, expect_a, expect_e, expect_t = None, None, None, None
        # History of tf-updated NN-weights over n training steps.
        tf_updated_weights = []
        # History of input batches used.
        tf_inputs = []
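        # NOTE: framework_iterator yields tf first, then torch; the tf pass
        # records its initial weights (weights_dict), input batches
        # (tf_inputs), and per-iteration weights (tf_updated_weights) so the
        # torch pass can run on identical data and be checked against them.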
        for fw, sess in framework_iterator(config,
                                           frameworks=("tf", "torch"),
                                           session=True):
            # Generate Trainer and get its default Policy object.
            trainer = sac.SACTrainer(config=config, env=env)
            policy = trainer.get_policy()
            p_sess = None
            if sess:
                p_sess = policy.get_session()

            # Set all weights (of all nets) to fixed values.
            if weights_dict is None:
                assert fw == "tf"  # Start with the tf vars-dict.
                weights_dict = policy.get_weights()
            else:
                assert fw == "torch"  # Then transfer that to torch Model.
                model_dict = self._translate_weights_to_torch(
                    weights_dict, map_)
                policy.model.load_state_dict(model_dict)
                policy.target_model.load_state_dict(model_dict)

            if fw == "tf":
                log_alpha = weights_dict["default_policy/log_alpha"]
            elif fw == "torch":
                # Actually convert to torch tensors.
                input_ = policy._lazy_tensor_dict(input_)
                input_ = {k: input_[k] for k in input_.keys()}
                log_alpha = policy.model.log_alpha.detach().numpy()[0]

            # Only run the expectation once; it should be the same anyway
            # for all frameworks.
            if expect_c is None:
                expect_c, expect_a, expect_e, expect_t = \
                    self._sac_loss_helper(input_, weights_dict,
                                          sorted(weights_dict.keys()),
                                          log_alpha, fw,
                                          gamma=config["gamma"], sess=sess)

            # Get actual outs and compare to expectation AND previous
            # framework. c=critic, a=actor, e=entropy, t=td-error.
            if fw == "tf":
                c, a, e, t, tf_c_grads, tf_a_grads, tf_e_grads = \
                    p_sess.run([
                        policy.critic_loss,
                        policy.actor_loss,
                        policy.alpha_loss,
                        policy.td_error,
                        policy.optimizer().compute_gradients(
                            policy.critic_loss[0],
                            policy.model.q_variables()),
                        policy.optimizer().compute_gradients(
                            policy.actor_loss,
                            policy.model.policy_variables()),
                        policy.optimizer().compute_gradients(
                            policy.alpha_loss, policy.model.log_alpha)],
                        feed_dict=policy._get_loss_inputs_dict(
                            input_, shuffle=False))
                tf_c_grads = [g for g, v in tf_c_grads]
                tf_a_grads = [g for g, v in tf_a_grads]
                tf_e_grads = [g for g, v in tf_e_grads]

            elif fw == "torch":
                loss_torch(policy, policy.model, None, input_)
                c, a, e, t = policy.critic_loss, policy.actor_loss, \
                    policy.alpha_loss, policy.td_error

                # Test actor gradients.
                policy.actor_optim.zero_grad()
                assert all(v.grad is None for v in policy.model.q_variables())
                assert all(v.grad is None
                           for v in policy.model.policy_variables())
                assert policy.model.log_alpha.grad is None
                a.backward()
                # `actor_loss` depends on Q-net vars (but these grads must
                # be ignored and overridden in critic_loss.backward!).
                assert not any(v.grad is None
                               for v in policy.model.q_variables())
                assert not all(
                    torch.mean(v.grad) == 0
                    for v in policy.model.policy_variables())
                assert not all(
                    torch.min(v.grad) == 0
                    for v in policy.model.policy_variables())
                assert policy.model.log_alpha.grad is None
                # Compare with tf ones.
                torch_a_grads = [
                    v.grad for v in policy.model.policy_variables()
                ]
                for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
                    if tf_g.shape != torch_g.shape:
                        check(tf_g, np.transpose(torch_g))
                    else:
                        check(tf_g, torch_g)

                # Test critic gradients.
                policy.critic_optims[0].zero_grad()
                assert all(
                    torch.mean(v.grad) == 0.0
                    for v in policy.model.q_variables())
                assert all(
                    torch.min(v.grad) == 0.0
                    for v in policy.model.q_variables())
                assert policy.model.log_alpha.grad is None
                c[0].backward()
                assert not all(
                    torch.mean(v.grad) == 0
                    for v in policy.model.q_variables())
                assert not all(
                    torch.min(v.grad) == 0 for v in policy.model.q_variables())
                assert policy.model.log_alpha.grad is None
                # Compare with tf ones.
                torch_c_grads = [v.grad for v in policy.model.q_variables()]
                for tf_g, torch_g in zip(tf_c_grads, torch_c_grads):
                    if tf_g.shape != torch_g.shape:
                        check(tf_g, np.transpose(torch_g))
                    else:
                        check(tf_g, torch_g)
                # Compare the (unchanged!) actor grads with the tf ones.
                torch_a_grads = [
                    v.grad for v in policy.model.policy_variables()
                ]
                for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
                    if tf_g.shape != torch_g.shape:
                        check(tf_g, np.transpose(torch_g))
                    else:
                        check(tf_g, torch_g)

                # Test alpha gradient.
                policy.alpha_optim.zero_grad()
                assert policy.model.log_alpha.grad is None
                e.backward()
                assert policy.model.log_alpha.grad is not None
                check(policy.model.log_alpha.grad, tf_e_grads)

            check(c, expect_c)
            check(a, expect_a)
            check(e, expect_e)
            check(t, expect_t)

            # Store this framework's losses in prev_fw_loss to compare with
            # next framework's outputs.
            if prev_fw_loss is not None:
                check(c, prev_fw_loss[0])
                check(a, prev_fw_loss[1])
                check(e, prev_fw_loss[2])
                check(t, prev_fw_loss[3])

            prev_fw_loss = (c, a, e, t)

            # Update weights from our batch (n times).
            for update_iteration in range(10):
                print("train iteration {}".format(update_iteration))
                if fw == "tf":
                    in_ = self._get_batch_helper(obs_size, actions, batch_size)
                    tf_inputs.append(in_)
                    # Set a fake-batch to use
                    # (instead of sampling from replay buffer).
                    trainer.optimizer._fake_batch = in_
                    trainer.train()
                    updated_weights = policy.get_weights()
                    # Net must have changed.
                    if tf_updated_weights:
                        check(updated_weights[
                            "default_policy/sequential/action_1/kernel"],
                              tf_updated_weights[-1]
                              ["default_policy/sequential/action_1/kernel"],
                              false=True)
                    tf_updated_weights.append(updated_weights)

                # Compare with updated tf-weights. Must all be the same.
                else:
                    tf_weights = tf_updated_weights[update_iteration]
                    in_ = tf_inputs[update_iteration]
                    # Set a fake-batch to use
                    # (instead of sampling from replay buffer).
                    trainer.optimizer._fake_batch = in_
                    trainer.train()
                    # Compare updated model.
                    for tf_key in sorted(tf_weights.keys())[2:10]:
                        tf_var = tf_weights[tf_key]
                        torch_var = policy.model.state_dict()[map_[tf_key]]
                        if tf_var.shape != torch_var.shape:
                            check(tf_var, np.transpose(torch_var), rtol=0.05)
                        else:
                            check(tf_var, torch_var, rtol=0.05)
                    # And alpha.
                    check(policy.model.log_alpha,
                          tf_weights["default_policy/log_alpha"])
                    # Compare target nets.
                    for tf_key in sorted(tf_weights.keys())[10:18]:
                        tf_var = tf_weights[tf_key]
                        torch_var = policy.target_model.state_dict()[
                            map_[tf_key]]
                        if tf_var.shape != torch_var.shape:
                            check(tf_var, np.transpose(torch_var), rtol=0.05)
                        else:
                            check(tf_var, torch_var, rtol=0.05)
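The `_translate_weights_to_torch` helper called above is not part of this snippet. Below is a minimal sketch of what such a translation could look like, assuming `map_` maps tf variable names to torch `state_dict` keys and that 2-D tf kernels only need transposing; the real helper in the test file may differ (for instance, it may filter out the target-net entries, since the same dict is loaded into both `model` and `target_model`).

import numpy as np
import torch


def _translate_weights_to_torch(weights_dict, map_):
    """Hypothetical sketch: rename tf variables to torch state_dict keys."""
    model_dict = {}
    for tf_key, torch_key in map_.items():
        if tf_key not in weights_dict:
            continue
        value = np.asarray(weights_dict[tf_key])
        # tf dense kernels are [in, out]; torch.nn.Linear weights are
        # [out, in], so 2-D matrices get transposed.
        if value.ndim == 2:
            value = np.transpose(value)
        model_dict[torch_key] = torch.as_tensor(value, dtype=torch.float32)
    return model_dict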
Example #2
File: test_sac.py  Project: hngenc/ray
    def test_sac_loss_function(self):
        """Tests SAC loss function results across all frameworks."""
        config = sac.DEFAULT_CONFIG.copy()
        # Run locally.
        config["num_workers"] = 0
        config["learning_starts"] = 0
        config["twin_q"] = False
        config["gamma"] = 0.99
        # Switch on deterministic loss so we can compare the loss values.
        config["_deterministic_loss"] = True
        # Use very simple nets.
        config["Q_model"]["fcnet_hiddens"] = [10]
        config["policy_model"]["fcnet_hiddens"] = [10]
        # Make sure timing differences do not affect trainer.train().
        config["min_iter_time_s"] = 0
        # Test SAC with Simplex action space.
        config["env_config"] = {"simplex_actions": True}

        map_ = {
            # Action net.
            "default_policy/fc_1/kernel": "action_model._hidden_layers.0."
            "_model.0.weight",
            "default_policy/fc_1/bias": "action_model._hidden_layers.0."
            "_model.0.bias",
            "default_policy/fc_out/kernel": "action_model."
            "_logits._model.0.weight",
            "default_policy/fc_out/bias": "action_model._logits._model.0.bias",
            "default_policy/value_out/kernel": "action_model."
            "_value_branch._model.0.weight",
            "default_policy/value_out/bias": "action_model."
            "_value_branch._model.0.bias",
            # Q-net.
            "default_policy/fc_1_1/kernel": "q_net."
            "_hidden_layers.0._model.0.weight",
            "default_policy/fc_1_1/bias": "q_net."
            "_hidden_layers.0._model.0.bias",
            "default_policy/fc_out_1/kernel": "q_net._logits._model.0.weight",
            "default_policy/fc_out_1/bias": "q_net._logits._model.0.bias",
            "default_policy/value_out_1/kernel": "q_net."
            "_value_branch._model.0.weight",
            "default_policy/value_out_1/bias": "q_net."
            "_value_branch._model.0.bias",
            "default_policy/log_alpha": "log_alpha",
            # Target action-net.
            "default_policy/fc_1_2/kernel": "action_model."
            "_hidden_layers.0._model.0.weight",
            "default_policy/fc_1_2/bias": "action_model."
            "_hidden_layers.0._model.0.bias",
            "default_policy/fc_out_2/kernel": "action_model."
            "_logits._model.0.weight",
            "default_policy/fc_out_2/bias": "action_model."
            "_logits._model.0.bias",
            "default_policy/value_out_2/kernel": "action_model."
            "_value_branch._model.0.weight",
            "default_policy/value_out_2/bias": "action_model."
            "_value_branch._model.0.bias",
            # Target Q-net.
            "default_policy/fc_1_3/kernel": "q_net."
            "_hidden_layers.0._model.0.weight",
            "default_policy/fc_1_3/bias": "q_net."
            "_hidden_layers.0._model.0.bias",
            "default_policy/fc_out_3/kernel": "q_net."
            "_logits._model.0.weight",
            "default_policy/fc_out_3/bias": "q_net."
            "_logits._model.0.bias",
            "default_policy/value_out_3/kernel": "q_net."
            "_value_branch._model.0.weight",
            "default_policy/value_out_3/bias": "q_net."
            "_value_branch._model.0.bias",
            "default_policy/log_alpha_1": "log_alpha",
        }

        env = SimpleEnv
        batch_size = 100
        obs_size = (batch_size, 1)
        actions = np.random.random(size=(batch_size, 2))

        # Batch of size=n.
        input_ = self._get_batch_helper(obs_size, actions, batch_size)

        # Simply compare loss values AND grads of all frameworks with each
        # other.
        prev_fw_loss = weights_dict = None
        expect_c, expect_a, expect_e, expect_t = None, None, None, None
        # History of tf-updated NN-weights over n training steps.
        tf_updated_weights = []
        # History of input batches used.
        tf_inputs = []
        for fw, sess in framework_iterator(config,
                                           frameworks=("tf", "torch"),
                                           session=True):
            # Generate Trainer and get its default Policy object.
            trainer = sac.SACTrainer(config=config, env=env)
            policy = trainer.get_policy()
            p_sess = None
            if sess:
                p_sess = policy.get_session()

            # Set all weights (of all nets) to fixed values.
            if weights_dict is None:
                # Start with the tf vars-dict.
                assert fw in ["tf2", "tf", "tfe"]
                weights_dict = policy.get_weights()
                if fw == "tfe":
                    log_alpha = weights_dict[10]
                    weights_dict = self._translate_tfe_weights(
                        weights_dict, map_)
            else:
                assert fw == "torch"  # Then transfer that to torch Model.
                model_dict = self._translate_weights_to_torch(
                    weights_dict, map_)
                # Have to add this here (not a parameter in tf, but must be
                # one in torch, so it gets properly copied to the GPU(s)).
                model_dict["target_entropy"] = policy.model.target_entropy
                policy.model.load_state_dict(model_dict)
                policy.target_model.load_state_dict(model_dict)

            if fw == "tf":
                log_alpha = weights_dict["default_policy/log_alpha"]
            elif fw == "torch":
                # Actually convert to torch tensors (by accessing everything).
                input_ = policy._lazy_tensor_dict(input_)
                input_ = {k: input_[k] for k in input_.keys()}
                log_alpha = policy.model.log_alpha.detach().cpu().numpy()[0]

            # Only run the expectation once; it should be the same anyway
            # for all frameworks.
            if expect_c is None:
                expect_c, expect_a, expect_e, expect_t = \
                    self._sac_loss_helper(input_, weights_dict,
                                          sorted(weights_dict.keys()),
                                          log_alpha, fw,
                                          gamma=config["gamma"], sess=sess)

            # Get actual outs and compare to expectation AND previous
            # framework. c=critic, a=actor, e=entropy, t=td-error.
            if fw == "tf":
                c, a, e, t, tf_c_grads, tf_a_grads, tf_e_grads = \
                    p_sess.run([
                        policy.critic_loss,
                        policy.actor_loss,
                        policy.alpha_loss,
                        policy.td_error,
                        policy.optimizer().compute_gradients(
                            policy.critic_loss[0],
                            [v for v in policy.model.q_variables() if
                             "value_" not in v.name]),
                        policy.optimizer().compute_gradients(
                            policy.actor_loss,
                            [v for v in policy.model.policy_variables() if
                             "value_" not in v.name]),
                        policy.optimizer().compute_gradients(
                            policy.alpha_loss, policy.model.log_alpha)],
                        feed_dict=policy._get_loss_inputs_dict(
                            input_, shuffle=False))
                tf_c_grads = [g for g, v in tf_c_grads]
                tf_a_grads = [g for g, v in tf_a_grads]
                tf_e_grads = [g for g, v in tf_e_grads]

            elif fw == "tfe":
                with tf.GradientTape() as tape:
                    tf_loss(policy, policy.model, None, input_)
                c, a, e, t = policy.critic_loss, policy.actor_loss, \
                    policy.alpha_loss, policy.td_error
                vars = tape.watched_variables()
                tf_c_grads = tape.gradient(c[0], vars[6:10])
                tf_a_grads = tape.gradient(a, vars[2:6])
                tf_e_grads = tape.gradient(e, vars[10])

            elif fw == "torch":
                loss_torch(policy, policy.model, None, input_)
                c, a, e, t = policy.critic_loss, policy.actor_loss, \
                    policy.alpha_loss, policy.model.td_error

                # Test actor gradients.
                policy.actor_optim.zero_grad()
                assert all(v.grad is None for v in policy.model.q_variables())
                assert all(v.grad is None
                           for v in policy.model.policy_variables())
                assert policy.model.log_alpha.grad is None
                a.backward()
                # `actor_loss` depends on Q-net vars (but these grads must
                # be ignored and overridden in critic_loss.backward!).
                assert not all(
                    torch.mean(v.grad) == 0
                    for v in policy.model.policy_variables())
                assert not all(
                    torch.min(v.grad) == 0
                    for v in policy.model.policy_variables())
                assert policy.model.log_alpha.grad is None
                # Compare with tf ones.
                torch_a_grads = [
                    v.grad for v in policy.model.policy_variables()
                    if v.grad is not None
                ]
                check(tf_a_grads[2],
                      np.transpose(torch_a_grads[0].detach().cpu()))

                # Test critic gradients.
                policy.critic_optims[0].zero_grad()
                assert all(
                    torch.mean(v.grad) == 0.0
                    for v in policy.model.q_variables() if v.grad is not None)
                assert all(
                    torch.min(v.grad) == 0.0
                    for v in policy.model.q_variables() if v.grad is not None)
                assert policy.model.log_alpha.grad is None
                c[0].backward()
                assert not all(
                    torch.mean(v.grad) == 0
                    for v in policy.model.q_variables() if v.grad is not None)
                assert not all(
                    torch.min(v.grad) == 0
                    for v in policy.model.q_variables() if v.grad is not None)
                assert policy.model.log_alpha.grad is None
                # Compare with tf ones.
                torch_c_grads = [v.grad for v in policy.model.q_variables()]
                check(tf_c_grads[0],
                      np.transpose(torch_c_grads[2].detach().cpu()))
                # Compare the (unchanged!) actor grads with the tf ones.
                torch_a_grads = [
                    v.grad for v in policy.model.policy_variables()
                ]
                check(tf_a_grads[2],
                      np.transpose(torch_a_grads[0].detach().cpu()))

                # Test alpha gradient.
                policy.alpha_optim.zero_grad()
                assert policy.model.log_alpha.grad is None
                e.backward()
                assert policy.model.log_alpha.grad is not None
                check(policy.model.log_alpha.grad, tf_e_grads)

            check(c, expect_c)
            check(a, expect_a)
            check(e, expect_e)
            check(t, expect_t)

            # Store this framework's losses in prev_fw_loss to compare with
            # next framework's outputs.
            if prev_fw_loss is not None:
                check(c, prev_fw_loss[0])
                check(a, prev_fw_loss[1])
                check(e, prev_fw_loss[2])
                check(t, prev_fw_loss[3])

            prev_fw_loss = (c, a, e, t)

            # Update weights from our batch (n times).
            for update_iteration in range(5):
                print("train iteration {}".format(update_iteration))
                if fw == "tf":
                    in_ = self._get_batch_helper(obs_size, actions, batch_size)
                    tf_inputs.append(in_)
                    # Set a fake-batch to use
                    # (instead of sampling from replay buffer).
                    buf = LocalReplayBuffer.get_instance_for_testing()
                    buf._fake_batch = in_
                    trainer.train()
                    updated_weights = policy.get_weights()
                    # Net must have changed.
                    if tf_updated_weights:
                        check(updated_weights["default_policy/fc_1/kernel"],
                              tf_updated_weights[-1]
                              ["default_policy/fc_1/kernel"],
                              false=True)
                    tf_updated_weights.append(updated_weights)

                # Compare with updated tf-weights. Must all be the same.
                else:
                    tf_weights = tf_updated_weights[update_iteration]
                    in_ = tf_inputs[update_iteration]
                    # Set a fake-batch to use
                    # (instead of sampling from replay buffer).
                    buf = LocalReplayBuffer.get_instance_for_testing()
                    buf._fake_batch = in_
                    trainer.train()
                    # Compare updated model.
                    for tf_key in sorted(tf_weights.keys()):
                        if re.search("_[23]|alpha", tf_key):
                            continue
                        tf_var = tf_weights[tf_key]
                        torch_var = policy.model.state_dict()[map_[tf_key]]
                        if tf_var.shape != torch_var.shape:
                            check(tf_var,
                                  np.transpose(torch_var.detach().cpu()),
                                  atol=0.003)
                        else:
                            check(tf_var, torch_var, atol=0.003)
                    # And alpha.
                    check(policy.model.log_alpha,
                          tf_weights["default_policy/log_alpha"])
                    # Compare target nets.
                    for tf_key in sorted(tf_weights.keys()):
                        if not re.search("_[23]", tf_key):
                            continue
                        tf_var = tf_weights[tf_key]
                        torch_var = policy.target_model.state_dict()[
                            map_[tf_key]]
                        if tf_var.shape != torch_var.shape:
                            check(tf_var,
                                  np.transpose(torch_var.detach().cpu()),
                                  atol=0.003)
                        else:
                            check(tf_var, torch_var, atol=0.003)
            trainer.stop()
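A small standalone illustration of the shape convention behind the `np.transpose` branches in both examples: tf dense kernels are stored as [num_inputs, num_outputs], while torch.nn.Linear keeps its weight as [out_features, in_features], so one side has to be transposed before `check()` can compare them elementwise. This is only a sketch of the convention, not part of the test itself.

import numpy as np
import torch

# A Linear(1, 10) layer in torch stores its weight as shape (10, 1);
# the equivalent tf.keras Dense(10) kernel on a 1-d input has shape (1, 10).
torch_w = torch.nn.Linear(1, 10).weight.detach().numpy()   # (10, 1)
tf_like_kernel = np.transpose(torch_w)                      # (1, 10)

assert torch_w.shape == (10, 1)
assert tf_like_kernel.shape == (1, 10)
# Transposing one side makes the two layouts directly comparable,
# mirroring the check(tf_g, np.transpose(torch_g)) branches above.
assert np.allclose(tf_like_kernel, torch_w.T)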