def test_sac_loss_function(self):
    """Tests SAC loss function results across all frameworks.

    A SACTrainer is built first for tf and then for torch; the torch
    model (and target model) receives the tf weights, translated via
    `map_`. For each framework, the critic/actor/alpha losses and the
    td-error are checked against the values returned by
    `self._sac_loss_helper` and against the previous framework's values.
    The torch run additionally compares its gradients with the tf ones.
    Finally, the tf trainer trains on 10 freshly generated fake batches
    (recording its weights after every step), and the torch trainer
    replays the same batches and must reproduce those weights
    (rtol=0.05).
    """

    def _check_tf_vs_torch(tf_val, torch_val, **kwargs):
        # Checks a tf value against its torch counterpart. The two
        # frameworks may lay out the same kernel transposed, so the torch
        # value is transposed whenever the shapes disagree.
        if tf_val.shape != torch_val.shape:
            torch_val = np.transpose(torch_val.detach().cpu())
        check(tf_val, torch_val, **kwargs)

    config = sac.DEFAULT_CONFIG.copy()
    # Run locally.
    config["num_workers"] = 0
    config["learning_starts"] = 0
    config["twin_q"] = False
    config["gamma"] = 0.99
    # Switch on deterministic loss so we can compare the loss values.
    config["_deterministic_loss"] = True
    # Use very simple nets. `.copy()` above is shallow, so copy the nested
    # model dicts before editing them; assigning into them directly would
    # mutate the shared `sac.DEFAULT_CONFIG` for every other test.
    config["Q_model"] = dict(config["Q_model"], fcnet_hiddens=[10])
    config["policy_model"] = dict(config["policy_model"], fcnet_hiddens=[10])
    # Make sure timing differences do not affect trainer.train().
    config["min_iter_time_s"] = 0

    # Maps tf variable names to torch state-dict keys. Used both to load
    # the tf weights into the torch (target) model and to compare weights
    # after training. Target-net tf vars map onto the same torch keys
    # because they are looked up in `policy.target_model` instead.
    map_ = {
        # Normal net.
        "default_policy/sequential/action_1/kernel":
            "action_model.action_0._model.0.weight",
        "default_policy/sequential/action_1/bias":
            "action_model.action_0._model.0.bias",
        "default_policy/sequential/action_out/kernel":
            "action_model.action_out._model.0.weight",
        "default_policy/sequential/action_out/bias":
            "action_model.action_out._model.0.bias",
        "default_policy/sequential_1/q_hidden_0/kernel":
            "q_net.q_hidden_0._model.0.weight",
        "default_policy/sequential_1/q_hidden_0/bias":
            "q_net.q_hidden_0._model.0.bias",
        "default_policy/sequential_1/q_out/kernel":
            "q_net.q_out._model.0.weight",
        "default_policy/sequential_1/q_out/bias":
            "q_net.q_out._model.0.bias",
        "default_policy/value_out/kernel":
            "_value_branch._model.0.weight",
        "default_policy/value_out/bias":
            "_value_branch._model.0.bias",
        # Target net.
        "default_policy/sequential_2/action_1/kernel":
            "action_model.action_0._model.0.weight",
        "default_policy/sequential_2/action_1/bias":
            "action_model.action_0._model.0.bias",
        "default_policy/sequential_2/action_out/kernel":
            "action_model.action_out._model.0.weight",
        "default_policy/sequential_2/action_out/bias":
            "action_model.action_out._model.0.bias",
        "default_policy/sequential_3/q_hidden_0/kernel":
            "q_net.q_hidden_0._model.0.weight",
        "default_policy/sequential_3/q_hidden_0/bias":
            "q_net.q_hidden_0._model.0.bias",
        "default_policy/sequential_3/q_out/kernel":
            "q_net.q_out._model.0.weight",
        "default_policy/sequential_3/q_out/bias":
            "q_net.q_out._model.0.bias",
        "default_policy/value_out_1/kernel":
            "_value_branch._model.0.weight",
        "default_policy/value_out_1/bias":
            "_value_branch._model.0.bias",
    }

    env = SimpleEnv
    batch_size = 100
    if env is SimpleEnv:
        obs_size = (batch_size, 1)
        actions = np.random.random(size=(batch_size, 1))
    elif env == "CartPole-v0":
        obs_size = (batch_size, 4)
        actions = np.random.randint(0, 2, size=(batch_size, ))
    else:
        obs_size = (batch_size, 3)
        actions = np.random.random(size=(batch_size, 1))

    # Batch of size=n.
    input_ = self._get_batch_helper(obs_size, actions, batch_size)

    # Simply compare loss values AND grads of all frameworks with each
    # other.
    prev_fw_loss = weights_dict = None
    expect_c, expect_a, expect_e, expect_t = None, None, None, None
    # History of tf-updated NN-weights over n training steps.
    tf_updated_weights = []
    # History of input batches used.
    tf_inputs = []
    for fw, sess in framework_iterator(
            config, frameworks=("tf", "torch"), session=True):
        # Generate Trainer and get its default Policy object.
        trainer = sac.SACTrainer(config=config, env=env)
        policy = trainer.get_policy()
        p_sess = None
        if sess:
            p_sess = policy.get_session()

        # Set all weights (of all nets) to fixed values.
        if weights_dict is None:
            # Start with the tf vars-dict.
            assert fw == "tf"
            weights_dict = policy.get_weights()
        else:
            # Then transfer that to torch Model.
            assert fw == "torch"
            model_dict = self._translate_weights_to_torch(weights_dict, map_)
            policy.model.load_state_dict(model_dict)
            policy.target_model.load_state_dict(model_dict)

        if fw == "tf":
            log_alpha = weights_dict["default_policy/log_alpha"]
        elif fw == "torch":
            # Actually convert to torch tensors (by accessing everything).
            input_ = policy._lazy_tensor_dict(input_)
            input_ = {k: input_[k] for k in input_.keys()}
            log_alpha = policy.model.log_alpha.detach().cpu().numpy()[0]

        # Only run the expectation once, should be the same anyways
        # for all frameworks.
        if expect_c is None:
            expect_c, expect_a, expect_e, expect_t = \
                self._sac_loss_helper(input_, weights_dict,
                                      sorted(weights_dict.keys()),
                                      log_alpha, fw,
                                      gamma=config["gamma"], sess=sess)

        # Get actual outs and compare to expectation AND previous
        # framework. c=critic, a=actor, e=entropy, t=td-error.
        if fw == "tf":
            c, a, e, t, tf_c_grads, tf_a_grads, tf_e_grads = \
                p_sess.run([
                    policy.critic_loss,
                    policy.actor_loss,
                    policy.alpha_loss,
                    policy.td_error,
                    policy.optimizer().compute_gradients(
                        policy.critic_loss[0],
                        policy.model.q_variables()),
                    policy.optimizer().compute_gradients(
                        policy.actor_loss,
                        policy.model.policy_variables()),
                    policy.optimizer().compute_gradients(
                        policy.alpha_loss, policy.model.log_alpha)],
                    feed_dict=policy._get_loss_inputs_dict(
                        input_, shuffle=False))
            # `compute_gradients` returns (grad, var) pairs; keep grads.
            tf_c_grads = [g for g, v in tf_c_grads]
            tf_a_grads = [g for g, v in tf_a_grads]
            tf_e_grads = [g for g, v in tf_e_grads]

        elif fw == "torch":
            loss_torch(policy, policy.model, None, input_)
            c, a, e, t = policy.critic_loss, policy.actor_loss, \
                policy.alpha_loss, policy.td_error

            # Test actor gradients.
            policy.actor_optim.zero_grad()
            assert all(v.grad is None for v in policy.model.q_variables())
            assert all(
                v.grad is None for v in policy.model.policy_variables())
            assert policy.model.log_alpha.grad is None
            a.backward()
            # `actor_loss` depends on Q-net vars (but these grads must
            # be ignored and overridden in critic_loss.backward!).
            assert not any(
                v.grad is None for v in policy.model.q_variables())
            assert not all(
                torch.mean(v.grad) == 0
                for v in policy.model.policy_variables())
            assert not all(
                torch.min(v.grad) == 0
                for v in policy.model.policy_variables())
            assert policy.model.log_alpha.grad is None
            # Compare with tf ones.
            torch_a_grads = [
                v.grad for v in policy.model.policy_variables()
            ]
            for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
                _check_tf_vs_torch(tf_g, torch_g)

            # Test critic gradients.
            policy.critic_optims[0].zero_grad()
            assert all(
                torch.mean(v.grad) == 0.0
                for v in policy.model.q_variables())
            assert all(
                torch.min(v.grad) == 0.0
                for v in policy.model.q_variables())
            assert policy.model.log_alpha.grad is None
            c[0].backward()
            assert not all(
                torch.mean(v.grad) == 0
                for v in policy.model.q_variables())
            assert not all(
                torch.min(v.grad) == 0
                for v in policy.model.q_variables())
            assert policy.model.log_alpha.grad is None
            # Compare with tf ones.
            torch_c_grads = [v.grad for v in policy.model.q_variables()]
            for tf_g, torch_g in zip(tf_c_grads, torch_c_grads):
                _check_tf_vs_torch(tf_g, torch_g)
            # Compare (unchanged(!) actor grads) with tf ones.
            torch_a_grads = [
                v.grad for v in policy.model.policy_variables()
            ]
            for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
                _check_tf_vs_torch(tf_g, torch_g)

            # Test alpha gradient.
            policy.alpha_optim.zero_grad()
            assert policy.model.log_alpha.grad is None
            e.backward()
            assert policy.model.log_alpha.grad is not None
            check(policy.model.log_alpha.grad, tf_e_grads)

        check(c, expect_c)
        check(a, expect_a)
        check(e, expect_e)
        check(t, expect_t)

        # Store this framework's losses in prev_fw_loss to compare with
        # next framework's outputs.
        if prev_fw_loss is not None:
            check(c, prev_fw_loss[0])
            check(a, prev_fw_loss[1])
            check(e, prev_fw_loss[2])
            check(t, prev_fw_loss[3])

        prev_fw_loss = (c, a, e, t)

        # Update weights from our batch (n times).
        for update_iteration in range(10):
            print("train iteration {}".format(update_iteration))
            if fw == "tf":
                in_ = self._get_batch_helper(obs_size, actions, batch_size)
                tf_inputs.append(in_)
                # Set a fake-batch to use
                # (instead of sampling from replay buffer).
                trainer.optimizer._fake_batch = in_
                trainer.train()
                updated_weights = policy.get_weights()
                # Net must have changed.
                if tf_updated_weights:
                    check(
                        updated_weights[
                            "default_policy/sequential/action_1/kernel"],
                        tf_updated_weights[-1][
                            "default_policy/sequential/action_1/kernel"],
                        false=True)
                tf_updated_weights.append(updated_weights)

            # Compare with updated tf-weights. Must all be the same.
            else:
                tf_weights = tf_updated_weights[update_iteration]
                in_ = tf_inputs[update_iteration]
                # Set a fake-batch to use
                # (instead of sampling from replay buffer).
                trainer.optimizer._fake_batch = in_
                trainer.train()
                # NOTE(review): the [2:10] / [10:18] slices presumably
                # select the online-net and target-net keys of the sorted
                # tf weights — TODO confirm against `policy.get_weights()`.
                sorted_tf_keys = sorted(tf_weights.keys())
                # Compare updated model.
                model_state = policy.model.state_dict()
                for tf_key in sorted_tf_keys[2:10]:
                    _check_tf_vs_torch(
                        tf_weights[tf_key], model_state[map_[tf_key]],
                        rtol=0.05)
                # And alpha.
                check(policy.model.log_alpha,
                      tf_weights["default_policy/log_alpha"])
                # Compare target nets.
                target_state = policy.target_model.state_dict()
                for tf_key in sorted_tf_keys[10:18]:
                    _check_tf_vs_torch(
                        tf_weights[tf_key], target_state[map_[tf_key]],
                        rtol=0.05)

        # Release this framework's trainer before moving on to the next.
        trainer.stop()
def test_sac_loss_function(self):
    """Tests SAC loss function results across all frameworks.

    Runs with a Simplex action space
    (`env_config={"simplex_actions": True}`). A SACTrainer is built
    first for tf and then for torch; the torch model (and target model)
    receives the tf weights, translated via `map_`. For each framework,
    the critic/actor/alpha losses and the td-error are checked against
    the values returned by `self._sac_loss_helper` and against the
    previous framework's values. The torch run additionally compares
    selected gradients with the tf ones. Finally, the tf trainer trains
    on 5 freshly generated fake batches (recording its weights after
    every step), and the torch trainer replays the same batches and must
    reproduce those weights (atol=0.003).
    """

    def _check_tf_vs_torch(tf_val, torch_val, **kwargs):
        # Checks a tf weight against its torch counterpart. The two
        # frameworks may lay out the same kernel transposed, so the torch
        # value is transposed whenever the shapes disagree.
        if tf_val.shape != torch_val.shape:
            torch_val = np.transpose(torch_val.detach().cpu())
        check(tf_val, torch_val, **kwargs)

    config = sac.DEFAULT_CONFIG.copy()
    # Run locally.
    config["num_workers"] = 0
    config["learning_starts"] = 0
    config["twin_q"] = False
    config["gamma"] = 0.99
    # Switch on deterministic loss so we can compare the loss values.
    config["_deterministic_loss"] = True
    # Use very simple nets. `.copy()` above is shallow, so copy the nested
    # model dicts before editing them; assigning into them directly would
    # mutate the shared `sac.DEFAULT_CONFIG` for every other test.
    config["Q_model"] = dict(config["Q_model"], fcnet_hiddens=[10])
    config["policy_model"] = dict(config["policy_model"], fcnet_hiddens=[10])
    # Make sure timing differences do not affect trainer.train().
    config["min_iter_time_s"] = 0
    # Test SAC with Simplex action space.
    config["env_config"] = {"simplex_actions": True}

    # Maps tf variable names to torch state-dict keys. Used both to load
    # the tf weights into the torch (target) model and to compare weights
    # after training. The `_2`/`_3` (target-net) tf vars map onto the same
    # torch keys because they are looked up in `policy.target_model`.
    map_ = {
        # Action net.
        "default_policy/fc_1/kernel":
            "action_model._hidden_layers.0._model.0.weight",
        "default_policy/fc_1/bias":
            "action_model._hidden_layers.0._model.0.bias",
        "default_policy/fc_out/kernel":
            "action_model._logits._model.0.weight",
        "default_policy/fc_out/bias":
            "action_model._logits._model.0.bias",
        "default_policy/value_out/kernel":
            "action_model._value_branch._model.0.weight",
        "default_policy/value_out/bias":
            "action_model._value_branch._model.0.bias",
        # Q-net.
        "default_policy/fc_1_1/kernel":
            "q_net._hidden_layers.0._model.0.weight",
        "default_policy/fc_1_1/bias":
            "q_net._hidden_layers.0._model.0.bias",
        "default_policy/fc_out_1/kernel":
            "q_net._logits._model.0.weight",
        "default_policy/fc_out_1/bias":
            "q_net._logits._model.0.bias",
        "default_policy/value_out_1/kernel":
            "q_net._value_branch._model.0.weight",
        "default_policy/value_out_1/bias":
            "q_net._value_branch._model.0.bias",
        "default_policy/log_alpha": "log_alpha",
        # Target action-net.
        "default_policy/fc_1_2/kernel":
            "action_model._hidden_layers.0._model.0.weight",
        "default_policy/fc_1_2/bias":
            "action_model._hidden_layers.0._model.0.bias",
        "default_policy/fc_out_2/kernel":
            "action_model._logits._model.0.weight",
        "default_policy/fc_out_2/bias":
            "action_model._logits._model.0.bias",
        "default_policy/value_out_2/kernel":
            "action_model._value_branch._model.0.weight",
        "default_policy/value_out_2/bias":
            "action_model._value_branch._model.0.bias",
        # Target Q-net.
        "default_policy/fc_1_3/kernel":
            "q_net._hidden_layers.0._model.0.weight",
        "default_policy/fc_1_3/bias":
            "q_net._hidden_layers.0._model.0.bias",
        "default_policy/fc_out_3/kernel":
            "q_net._logits._model.0.weight",
        "default_policy/fc_out_3/bias":
            "q_net._logits._model.0.bias",
        "default_policy/value_out_3/kernel":
            "q_net._value_branch._model.0.weight",
        "default_policy/value_out_3/bias":
            "q_net._value_branch._model.0.bias",
        "default_policy/log_alpha_1": "log_alpha",
    }

    env = SimpleEnv
    batch_size = 100
    obs_size = (batch_size, 1)
    actions = np.random.random(size=(batch_size, 2))

    # Batch of size=n.
    input_ = self._get_batch_helper(obs_size, actions, batch_size)

    # Simply compare loss values AND grads of all frameworks with each
    # other.
    prev_fw_loss = weights_dict = None
    expect_c, expect_a, expect_e, expect_t = None, None, None, None
    # History of tf-updated NN-weights over n training steps.
    tf_updated_weights = []
    # History of input batches used.
    tf_inputs = []
    for fw, sess in framework_iterator(
            config, frameworks=("tf", "torch"), session=True):
        # Generate Trainer and get its default Policy object.
        trainer = sac.SACTrainer(config=config, env=env)
        policy = trainer.get_policy()
        p_sess = None
        if sess:
            p_sess = policy.get_session()

        # Set all weights (of all nets) to fixed values.
        if weights_dict is None:
            # Start with the tf vars-dict.
            assert fw in ["tf2", "tf", "tfe"]
            weights_dict = policy.get_weights()
            if fw == "tfe":
                log_alpha = weights_dict[10]
                weights_dict = self._translate_tfe_weights(
                    weights_dict, map_)
        else:
            # Then transfer that to torch Model.
            assert fw == "torch"
            model_dict = self._translate_weights_to_torch(weights_dict, map_)
            # Have to add this here (not a parameter in tf, but must be
            # one in torch, so it gets properly copied to the GPU(s)).
            model_dict["target_entropy"] = policy.model.target_entropy
            policy.model.load_state_dict(model_dict)
            policy.target_model.load_state_dict(model_dict)

        if fw == "tf":
            log_alpha = weights_dict["default_policy/log_alpha"]
        elif fw == "torch":
            # Actually convert to torch tensors (by accessing everything).
            input_ = policy._lazy_tensor_dict(input_)
            input_ = {k: input_[k] for k in input_.keys()}
            log_alpha = policy.model.log_alpha.detach().cpu().numpy()[0]

        # Only run the expectation once, should be the same anyways
        # for all frameworks.
        if expect_c is None:
            expect_c, expect_a, expect_e, expect_t = \
                self._sac_loss_helper(input_, weights_dict,
                                      sorted(weights_dict.keys()),
                                      log_alpha, fw,
                                      gamma=config["gamma"], sess=sess)

        # Get actual outs and compare to expectation AND previous
        # framework. c=critic, a=actor, e=entropy, t=td-error.
        if fw == "tf":
            # Gradients are taken w.r.t. all Q/policy vars except the
            # `value_*` (value-branch) ones.
            c, a, e, t, tf_c_grads, tf_a_grads, tf_e_grads = \
                p_sess.run([
                    policy.critic_loss,
                    policy.actor_loss,
                    policy.alpha_loss,
                    policy.td_error,
                    policy.optimizer().compute_gradients(
                        policy.critic_loss[0],
                        [v for v in policy.model.q_variables()
                         if "value_" not in v.name]),
                    policy.optimizer().compute_gradients(
                        policy.actor_loss,
                        [v for v in policy.model.policy_variables()
                         if "value_" not in v.name]),
                    policy.optimizer().compute_gradients(
                        policy.alpha_loss, policy.model.log_alpha)],
                    feed_dict=policy._get_loss_inputs_dict(
                        input_, shuffle=False))
            # `compute_gradients` returns (grad, var) pairs; keep grads.
            tf_c_grads = [g for g, v in tf_c_grads]
            tf_a_grads = [g for g, v in tf_a_grads]
            tf_e_grads = [g for g, v in tf_e_grads]

        elif fw == "tfe":
            # NOTE: only "tf" and "torch" are requested from
            # `framework_iterator` above, so this branch is not exercised
            # by this test. `persistent=True` is required because three
            # separate gradients are taken from the same tape below.
            with tf.GradientTape(persistent=True) as tape:
                tf_loss(policy, policy.model, None, input_)
            c, a, e, t = policy.critic_loss, policy.actor_loss, \
                policy.alpha_loss, policy.td_error
            # NOTE(review): the slice indices presumably follow the tape's
            # variable watch order — TODO confirm.
            watched_vars = tape.watched_variables()
            tf_c_grads = tape.gradient(c[0], watched_vars[6:10])
            tf_a_grads = tape.gradient(a, watched_vars[2:6])
            tf_e_grads = tape.gradient(e, watched_vars[10])

        elif fw == "torch":
            loss_torch(policy, policy.model, None, input_)
            c, a, e, t = policy.critic_loss, policy.actor_loss, \
                policy.alpha_loss, policy.model.td_error

            # Test actor gradients.
            policy.actor_optim.zero_grad()
            assert all(v.grad is None for v in policy.model.q_variables())
            assert all(
                v.grad is None for v in policy.model.policy_variables())
            assert policy.model.log_alpha.grad is None
            a.backward()
            # `actor_loss` depends on Q-net vars (but these grads must
            # be ignored and overridden in critic_loss.backward!).
            assert not all(
                torch.mean(v.grad) == 0
                for v in policy.model.policy_variables())
            assert not all(
                torch.min(v.grad) == 0
                for v in policy.model.policy_variables())
            assert policy.model.log_alpha.grad is None
            # Compare with tf ones. Only one gradient pair is compared;
            # the indices presumably pick the matching kernel in each
            # framework's variable order — TODO confirm.
            torch_a_grads = [
                v.grad for v in policy.model.policy_variables()
                if v.grad is not None
            ]
            check(tf_a_grads[2],
                  np.transpose(torch_a_grads[0].detach().cpu()))

            # Test critic gradients.
            policy.critic_optims[0].zero_grad()
            assert all(
                torch.mean(v.grad) == 0.0
                for v in policy.model.q_variables() if v.grad is not None)
            assert all(
                torch.min(v.grad) == 0.0
                for v in policy.model.q_variables() if v.grad is not None)
            assert policy.model.log_alpha.grad is None
            c[0].backward()
            assert not all(
                torch.mean(v.grad) == 0
                for v in policy.model.q_variables() if v.grad is not None)
            assert not all(
                torch.min(v.grad) == 0
                for v in policy.model.q_variables() if v.grad is not None)
            assert policy.model.log_alpha.grad is None
            # Compare with tf ones.
            torch_c_grads = [v.grad for v in policy.model.q_variables()]
            check(tf_c_grads[0],
                  np.transpose(torch_c_grads[2].detach().cpu()))
            # Compare (unchanged(!) actor grads) with tf ones.
            torch_a_grads = [
                v.grad for v in policy.model.policy_variables()
            ]
            check(tf_a_grads[2],
                  np.transpose(torch_a_grads[0].detach().cpu()))

            # Test alpha gradient.
            policy.alpha_optim.zero_grad()
            assert policy.model.log_alpha.grad is None
            e.backward()
            assert policy.model.log_alpha.grad is not None
            check(policy.model.log_alpha.grad, tf_e_grads)

        check(c, expect_c)
        check(a, expect_a)
        check(e, expect_e)
        check(t, expect_t)

        # Store this framework's losses in prev_fw_loss to compare with
        # next framework's outputs.
        if prev_fw_loss is not None:
            check(c, prev_fw_loss[0])
            check(a, prev_fw_loss[1])
            check(e, prev_fw_loss[2])
            check(t, prev_fw_loss[3])

        prev_fw_loss = (c, a, e, t)

        # Update weights from our batch (n times).
        for update_iteration in range(5):
            print("train iteration {}".format(update_iteration))
            if fw == "tf":
                in_ = self._get_batch_helper(obs_size, actions, batch_size)
                tf_inputs.append(in_)
                # Set a fake-batch to use
                # (instead of sampling from replay buffer).
                buf = LocalReplayBuffer.get_instance_for_testing()
                buf._fake_batch = in_
                trainer.train()
                updated_weights = policy.get_weights()
                # Net must have changed.
                if tf_updated_weights:
                    check(
                        updated_weights["default_policy/fc_1/kernel"],
                        tf_updated_weights[-1]["default_policy/fc_1/kernel"],
                        false=True)
                tf_updated_weights.append(updated_weights)

            # Compare with updated tf-weights. Must all be the same.
            else:
                tf_weights = tf_updated_weights[update_iteration]
                in_ = tf_inputs[update_iteration]
                # Set a fake-batch to use
                # (instead of sampling from replay buffer).
                buf = LocalReplayBuffer.get_instance_for_testing()
                buf._fake_batch = in_
                trainer.train()
                sorted_tf_keys = sorted(tf_weights.keys())
                # Compare updated model (skip target-net `_2`/`_3` vars
                # and alpha, which are checked separately).
                model_state = policy.model.state_dict()
                for tf_key in sorted_tf_keys:
                    if re.search("_[23]|alpha", tf_key):
                        continue
                    _check_tf_vs_torch(
                        tf_weights[tf_key], model_state[map_[tf_key]],
                        atol=0.003)
                # And alpha.
                check(policy.model.log_alpha,
                      tf_weights["default_policy/log_alpha"])
                # Compare target nets (only the `_2`/`_3` vars).
                target_state = policy.target_model.state_dict()
                for tf_key in sorted_tf_keys:
                    if not re.search("_[23]", tf_key):
                        continue
                    _check_tf_vs_torch(
                        tf_weights[tf_key], target_state[map_[tf_key]],
                        atol=0.003)

        # Release this framework's trainer before moving on to the next.
        trainer.stop()