def setup(self, config):
    # Call super's `setup` to create rollout workers.
    super().setup(config)
    # Create local replay buffer.
    self.local_replay_buffer = MultiAgentReplayBuffer(
        num_shards=1,
        learning_starts=1000,
        capacity=50000,
        replay_batch_size=64,
    )
def setup(self, config: PartialTrainerConfigDict):
    super().setup(config)
    # `training_iteration` implementation: Set up the buffer in `setup`, not
    # in `execution_plan` (deprecated).
    if self.config["_disable_execution_plan_api"] is True:
        self.local_replay_buffer = MultiAgentReplayBuffer(
            learning_starts=self.config["learning_starts"],
            capacity=self.config["replay_buffer_size"],
            replay_batch_size=self.config["train_batch_size"],
            replay_sequence_length=1,
        )
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    assert len(kwargs) == 0, (
        "MARWIL execution_plan does NOT take any additional parameters")

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    replay_buffer = MultiAgentReplayBuffer(
        learning_starts=config["learning_starts"],
        capacity=config["replay_buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_sequence_length=1,
    )

    store_op = rollouts \
        .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

    replay_op = Replay(local_buffer=replay_buffer) \
        .combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )) \
        .for_each(TrainOneStep(workers))

    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
def custom_training_workflow(workers: WorkerSet, config: dict):
    local_replay_buffer = MultiAgentReplayBuffer(
        num_shards=1,
        learning_starts=1000,
        capacity=50000,
        replay_batch_size=64,
    )

    def add_ppo_metrics(batch):
        print("PPO policy learning on samples from",
              batch.policy_batches.keys(), "env steps", batch.env_steps(),
              "agent steps", batch.agent_steps())
        metrics = _get_shared_metrics()
        metrics.counters["agent_steps_trained_PPO"] += batch.agent_steps()
        return batch

    def add_dqn_metrics(batch):
        print("DQN policy learning on samples from",
              batch.policy_batches.keys(), "env steps", batch.env_steps(),
              "agent steps", batch.agent_steps())
        metrics = _get_shared_metrics()
        metrics.counters["agent_steps_trained_DQN"] += batch.agent_steps()
        return batch

    # Generate common experiences.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    r1, r2 = rollouts.duplicate(n=2)

    # DQN sub-flow.
    dqn_store_op = r1.for_each(SelectExperiences(["dqn_policy"])) \
        .for_each(
            StoreToReplayBuffer(local_buffer=local_replay_buffer))
    dqn_replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(add_dqn_metrics) \
        .for_each(TrainOneStep(workers, policies=["dqn_policy"])) \
        .for_each(UpdateTargetNetwork(
            workers, target_update_freq=500, policies=["dqn_policy"]))
    dqn_train_op = Concurrently(
        [dqn_store_op, dqn_replay_op],
        mode="round_robin",
        output_indexes=[1])

    # PPO sub-flow.
    ppo_train_op = r2.for_each(SelectExperiences(["ppo_policy"])) \
        .combine(ConcatBatches(
            min_batch_size=200, count_steps_by="env_steps")) \
        .for_each(add_ppo_metrics) \
        .for_each(StandardizeFields(["advantages"])) \
        .for_each(TrainOneStep(
            workers,
            policies=["ppo_policy"],
            num_sgd_iter=10,
            sgd_minibatch_size=128))

    # Combined training flow.
    train_op = Concurrently(
        [ppo_train_op, dqn_train_op], mode="async", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
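# An execution-plan function like `custom_training_workflow` is usually turned
# into a runnable Trainer class via the (legacy) `build_trainer` helper. The
# snippet below is only a minimal sketch: the trainer name is arbitrary, and
# the PPO/DQN policy classes are assumed to be supplied through the
# `multiagent` part of the config, as in the fuller Tune example further below.
from ray.rllib.agents.trainer_template import build_trainer

# Wrap the custom execution plan in a Trainer class. No default policy is
# defined here; policies come entirely from the multiagent config.
CustomWorkflowTrainer = build_trainer(
    name="PPO_DQN_MultiAgent",
    default_policy=None,
    execution_plan=custom_training_workflow)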
def test_store_to_replay_local(ray_start_regular_shared):
    buf = MultiAgentReplayBuffer(
        num_shards=1,
        learning_starts=200,
        capacity=1000,
        replay_batch_size=100,
        prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4,
        prioritized_replay_eps=0.0001)
    assert buf.replay() is None

    workers = make_workers(0)
    a = ParallelRollouts(workers, mode="bulk_sync")
    b = a.for_each(StoreToReplayBuffer(local_buffer=buf))

    next(b)
    assert buf.replay() is None  # Learning hasn't started yet.
    next(b)
    assert buf.replay().count == 100

    replay_op = Replay(local_buffer=buf)
    assert next(replay_op).count == 100
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    """Execution plan of the MARWIL/BC algorithm. Defines the distributed
    dataflow.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """
    assert len(kwargs) == 0, (
        "MARWIL execution_plan does NOT take any additional parameters")

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    replay_buffer = MultiAgentReplayBuffer(
        learning_starts=config["learning_starts"],
        capacity=config["replay_buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_sequence_length=1,
    )

    store_op = rollouts \
        .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

    replay_op = Replay(local_buffer=replay_buffer) \
        .combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )) \
        .for_each(TrainOneStep(workers))

    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
def test_ddpg_loss_function(self):
    """Tests DDPG loss function results across all frameworks."""
    config = ddpg.DEFAULT_CONFIG.copy()
    # Run locally.
    config["seed"] = 42
    config["num_workers"] = 0
    config["learning_starts"] = 0
    config["twin_q"] = True
    config["use_huber"] = True
    config["huber_threshold"] = 1.0
    config["gamma"] = 0.99
    # Make this small (seems to introduce errors).
    config["l2_reg"] = 1e-10
    config["prioritized_replay"] = False
    # Use very simple nets.
    config["actor_hiddens"] = [10]
    config["critic_hiddens"] = [10]
    # Make sure, timing differences do not affect trainer.train().
    config["min_time_s_per_reporting"] = 0
    config["timesteps_per_iteration"] = 100

    map_ = {
        # Normal net.
        "default_policy/actor_hidden_0/kernel": "policy_model.action_0._model.0.weight",
        "default_policy/actor_hidden_0/bias": "policy_model.action_0._model.0.bias",
        "default_policy/actor_out/kernel": "policy_model.action_out._model.0.weight",
        "default_policy/actor_out/bias": "policy_model.action_out._model.0.bias",
        "default_policy/sequential/q_hidden_0/kernel": "q_model.q_hidden_0._model.0.weight",
        "default_policy/sequential/q_hidden_0/bias": "q_model.q_hidden_0._model.0.bias",
        "default_policy/sequential/q_out/kernel": "q_model.q_out._model.0.weight",
        "default_policy/sequential/q_out/bias": "q_model.q_out._model.0.bias",
        # -- twin.
        "default_policy/sequential_1/twin_q_hidden_0/kernel": "twin_q_model.twin_q_hidden_0._model.0.weight",
        "default_policy/sequential_1/twin_q_hidden_0/bias": "twin_q_model.twin_q_hidden_0._model.0.bias",
        "default_policy/sequential_1/twin_q_out/kernel": "twin_q_model.twin_q_out._model.0.weight",
        "default_policy/sequential_1/twin_q_out/bias": "twin_q_model.twin_q_out._model.0.bias",
        # Target net.
        "default_policy/actor_hidden_0_1/kernel": "policy_model.action_0._model.0.weight",
        "default_policy/actor_hidden_0_1/bias": "policy_model.action_0._model.0.bias",
        "default_policy/actor_out_1/kernel": "policy_model.action_out._model.0.weight",
        "default_policy/actor_out_1/bias": "policy_model.action_out._model.0.bias",
        "default_policy/sequential_2/q_hidden_0/kernel": "q_model.q_hidden_0._model.0.weight",
        "default_policy/sequential_2/q_hidden_0/bias": "q_model.q_hidden_0._model.0.bias",
        "default_policy/sequential_2/q_out/kernel": "q_model.q_out._model.0.weight",
        "default_policy/sequential_2/q_out/bias": "q_model.q_out._model.0.bias",
        # -- twin.
        "default_policy/sequential_3/twin_q_hidden_0/kernel": "twin_q_model.twin_q_hidden_0._model.0.weight",
        "default_policy/sequential_3/twin_q_hidden_0/bias": "twin_q_model.twin_q_hidden_0._model.0.bias",
        "default_policy/sequential_3/twin_q_out/kernel": "twin_q_model.twin_q_out._model.0.weight",
        "default_policy/sequential_3/twin_q_out/bias": "twin_q_model.twin_q_out._model.0.bias",
    }

    env = SimpleEnv
    batch_size = 100
    obs_size = (batch_size, 1)
    actions = np.random.random(size=(batch_size, 1))

    # Batch of size=n.
    input_ = self._get_batch_helper(obs_size, actions, batch_size)

    # Simply compare loss values AND grads of all frameworks with each
    # other.
    prev_fw_loss = weights_dict = None
    expect_c, expect_a, expect_t = None, None, None
    # History of tf-updated NN-weights over n training steps.
    tf_updated_weights = []
    # History of input batches used.
    tf_inputs = []

    for fw, sess in framework_iterator(
            config, frameworks=("tf", "torch"), session=True):
        # Generate Trainer and get its default Policy object.
        trainer = ddpg.DDPGTrainer(config=config, env=env)
        policy = trainer.get_policy()
        p_sess = None
        if sess:
            p_sess = policy.get_session()

        # Set all weights (of all nets) to fixed values.
        if weights_dict is None:
            assert fw == "tf"  # Start with the tf vars-dict.
            weights_dict = policy.get_weights()
        else:
            assert fw == "torch"  # Then transfer that to torch Model.
            model_dict = self._translate_weights_to_torch(
                weights_dict, map_)
            policy.model.load_state_dict(model_dict)
            policy.target_model.load_state_dict(model_dict)

        if fw == "torch":
            # Actually convert to torch tensors.
            input_ = policy._lazy_tensor_dict(input_)
            input_ = {k: input_[k] for k in input_.keys()}

        # Only run the expectation once, should be the same anyways
        # for all frameworks.
        if expect_c is None:
            expect_c, expect_a, expect_t = self._ddpg_loss_helper(
                input_,
                weights_dict,
                sorted(weights_dict.keys()),
                fw,
                gamma=config["gamma"],
                huber_threshold=config["huber_threshold"],
                l2_reg=config["l2_reg"],
                sess=sess,
            )

        # Get actual outs and compare to expectation AND previous
        # framework. c=critic, a=actor, e=entropy, t=td-error.
        if fw == "tf":
            c, a, t, tf_c_grads, tf_a_grads = p_sess.run(
                [
                    policy.critic_loss,
                    policy.actor_loss,
                    policy.td_error,
                    policy._critic_optimizer.compute_gradients(
                        policy.critic_loss, policy.model.q_variables()),
                    policy._actor_optimizer.compute_gradients(
                        policy.actor_loss,
                        policy.model.policy_variables()),
                ],
                feed_dict=policy._get_loss_inputs_dict(
                    input_, shuffle=False),
            )
            # Check pure loss values.
            check(c, expect_c)
            check(a, expect_a)
            check(t, expect_t)

            tf_c_grads = [g for g, v in tf_c_grads]
            tf_a_grads = [g for g, v in tf_a_grads]

        elif fw == "torch":
            loss_torch(policy, policy.model, None, input_)
            c, a, t = (
                policy.get_tower_stats("critic_loss")[0],
                policy.get_tower_stats("actor_loss")[0],
                policy.get_tower_stats("td_error")[0],
            )
            # Check pure loss values.
            check(c, expect_c)
            check(a, expect_a)
            check(t, expect_t)

            # Test actor gradients.
            policy._actor_optimizer.zero_grad()
            assert all(v.grad is None for v in policy.model.q_variables())
            assert all(
                v.grad is None for v in policy.model.policy_variables())
            a.backward()
            # `actor_loss` depends on Q-net vars
            # (but not twin-Q-net vars!).
            assert not any(
                v.grad is None for v in policy.model.q_variables()[:4])
            assert all(
                v.grad is None for v in policy.model.q_variables()[4:])
            assert not all(
                torch.mean(v.grad) == 0
                for v in policy.model.policy_variables())
            assert not all(
                torch.min(v.grad) == 0
                for v in policy.model.policy_variables())
            # Compare with tf ones.
            torch_a_grads = [
                v.grad for v in policy.model.policy_variables()
            ]
            for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
                if tf_g.shape != torch_g.shape:
                    check(tf_g, np.transpose(torch_g.cpu()))
                else:
                    check(tf_g, torch_g)

            # Test critic gradients.
            policy._critic_optimizer.zero_grad()
            assert all(
                v.grad is None or torch.mean(v.grad) == 0.0
                for v in policy.model.q_variables())
            assert all(
                v.grad is None or torch.min(v.grad) == 0.0
                for v in policy.model.q_variables())
            c.backward()
            assert not all(
                torch.mean(v.grad) == 0
                for v in policy.model.q_variables())
            assert not all(
                torch.min(v.grad) == 0
                for v in policy.model.q_variables())
            # Compare with tf ones.
            torch_c_grads = [v.grad for v in policy.model.q_variables()]
            for tf_g, torch_g in zip(tf_c_grads, torch_c_grads):
                if tf_g.shape != torch_g.shape:
                    check(tf_g, np.transpose(torch_g.cpu()))
                else:
                    check(tf_g, torch_g)
            # Compare (unchanged(!) actor grads) with tf ones.
            torch_a_grads = [
                v.grad for v in policy.model.policy_variables()
            ]
            for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
                if tf_g.shape != torch_g.shape:
                    check(tf_g, np.transpose(torch_g.cpu()))
                else:
                    check(tf_g, torch_g)

        # Store this framework's losses in prev_fw_loss to compare with
        # next framework's outputs.
        if prev_fw_loss is not None:
            check(c, prev_fw_loss[0])
            check(a, prev_fw_loss[1])
            check(t, prev_fw_loss[2])

        prev_fw_loss = (c, a, t)

        # Update weights from our batch (n times).
        for update_iteration in range(6):
            print("train iteration {}".format(update_iteration))
            if fw == "tf":
                in_ = self._get_batch_helper(
                    obs_size, actions, batch_size)
                tf_inputs.append(in_)
                # Set a fake-batch to use
                # (instead of sampling from replay buffer).
                buf = MultiAgentReplayBuffer.get_instance_for_testing()
                buf._fake_batch = in_
                trainer.train()
                updated_weights = policy.get_weights()
                # Net must have changed.
                if tf_updated_weights:
                    check(
                        updated_weights[
                            "default_policy/actor_hidden_0/kernel"],
                        tf_updated_weights[-1][
                            "default_policy/actor_hidden_0/kernel"],
                        false=True,
                    )
                tf_updated_weights.append(updated_weights)

            # Compare with updated tf-weights. Must all be the same.
            else:
                tf_weights = tf_updated_weights[update_iteration]
                in_ = tf_inputs[update_iteration]
                # Set a fake-batch to use
                # (instead of sampling from replay buffer).
                buf = MultiAgentReplayBuffer.get_instance_for_testing()
                buf._fake_batch = in_
                trainer.train()
                # Compare updated model and target weights.
                for tf_key in tf_weights.keys():
                    tf_var = tf_weights[tf_key]
                    # Model.
                    if re.search(
                            "actor_out_1|actor_hidden_0_1|sequential_[23]",
                            tf_key):
                        torch_var = policy.target_model.state_dict()[
                            map_[tf_key]]
                    # Target model.
                    else:
                        torch_var = policy.model.state_dict()[map_[tf_key]]
                    if tf_var.shape != torch_var.shape:
                        check(
                            tf_var, np.transpose(torch_var.cpu()), atol=0.1)
                    else:
                        check(tf_var, torch_var, atol=0.1)

        trainer.stop()
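# The test above relies on a `_translate_weights_to_torch` helper to port the
# tf weights into the torch model via the `map_` dict. The actual helper lives
# in the RLlib test utilities; the function below is only a rough sketch of
# the idea, assuming 2-D dense-layer kernels are the only tensors that need
# transposing between tf's (in, out) and torch's nn.Linear (out, in) layout.
import numpy as np
import torch


def translate_weights_to_torch_sketch(weights_dict, key_map):
    """Convert a {tf_var_name: ndarray} dict into a torch state_dict."""
    model_dict = {}
    for tf_key, torch_key in key_map.items():
        if tf_key not in weights_dict:
            continue
        array = np.asarray(weights_dict[tf_key])
        if array.ndim == 2:
            # Dense kernels: transpose from (in, out) to (out, in).
            array = np.transpose(array)
        model_dict[torch_key] = torch.from_numpy(array).float()
    return model_dict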
def test_sac_loss_function(self):
    """Tests SAC loss function results across all frameworks."""
    config = sac.DEFAULT_CONFIG.copy()
    # Run locally.
    config["num_workers"] = 0
    config["learning_starts"] = 0
    config["twin_q"] = False
    config["gamma"] = 0.99
    # Switch on deterministic loss so we can compare the loss values.
    config["_deterministic_loss"] = True
    # Use very simple nets.
    config["Q_model"]["fcnet_hiddens"] = [10]
    config["policy_model"]["fcnet_hiddens"] = [10]
    # Make sure, timing differences do not affect trainer.train().
    config["min_time_s_per_reporting"] = 0
    # Test SAC with Simplex action space.
    config["env_config"] = {"simplex_actions": True}

    map_ = {
        # Action net.
        "default_policy/fc_1/kernel": "action_model._hidden_layers.0._model.0.weight",
        "default_policy/fc_1/bias": "action_model._hidden_layers.0._model.0.bias",
        "default_policy/fc_out/kernel": "action_model._logits._model.0.weight",
        "default_policy/fc_out/bias": "action_model._logits._model.0.bias",
        "default_policy/value_out/kernel": "action_model._value_branch._model.0.weight",
        "default_policy/value_out/bias": "action_model._value_branch._model.0.bias",
        # Q-net.
        "default_policy/fc_1_1/kernel": "q_net._hidden_layers.0._model.0.weight",
        "default_policy/fc_1_1/bias": "q_net._hidden_layers.0._model.0.bias",
        "default_policy/fc_out_1/kernel": "q_net._logits._model.0.weight",
        "default_policy/fc_out_1/bias": "q_net._logits._model.0.bias",
        "default_policy/value_out_1/kernel": "q_net._value_branch._model.0.weight",
        "default_policy/value_out_1/bias": "q_net._value_branch._model.0.bias",
        "default_policy/log_alpha": "log_alpha",
        # Target action-net.
        "default_policy/fc_1_2/kernel": "action_model._hidden_layers.0._model.0.weight",
        "default_policy/fc_1_2/bias": "action_model._hidden_layers.0._model.0.bias",
        "default_policy/fc_out_2/kernel": "action_model._logits._model.0.weight",
        "default_policy/fc_out_2/bias": "action_model._logits._model.0.bias",
        "default_policy/value_out_2/kernel": "action_model._value_branch._model.0.weight",
        "default_policy/value_out_2/bias": "action_model._value_branch._model.0.bias",
        # Target Q-net.
        "default_policy/fc_1_3/kernel": "q_net._hidden_layers.0._model.0.weight",
        "default_policy/fc_1_3/bias": "q_net._hidden_layers.0._model.0.bias",
        "default_policy/fc_out_3/kernel": "q_net._logits._model.0.weight",
        "default_policy/fc_out_3/bias": "q_net._logits._model.0.bias",
        "default_policy/value_out_3/kernel": "q_net._value_branch._model.0.weight",
        "default_policy/value_out_3/bias": "q_net._value_branch._model.0.bias",
        "default_policy/log_alpha_1": "log_alpha",
    }

    env = SimpleEnv
    batch_size = 100
    obs_size = (batch_size, 1)
    actions = np.random.random(size=(batch_size, 2))

    # Batch of size=n.
    input_ = self._get_batch_helper(obs_size, actions, batch_size)

    # Simply compare loss values AND grads of all frameworks with each
    # other.
    prev_fw_loss = weights_dict = None
    expect_c, expect_a, expect_e, expect_t = None, None, None, None
    # History of tf-updated NN-weights over n training steps.
    tf_updated_weights = []
    # History of input batches used.
    tf_inputs = []

    for fw, sess in framework_iterator(
            config, frameworks=("tf", "torch"), session=True):
        # Generate Trainer and get its default Policy object.
        trainer = sac.SACTrainer(config=config, env=env)
        policy = trainer.get_policy()
        p_sess = None
        if sess:
            p_sess = policy.get_session()

        # Set all weights (of all nets) to fixed values.
        if weights_dict is None:
            # Start with the tf vars-dict.
            assert fw in ["tf2", "tf", "tfe"]
            weights_dict = policy.get_weights()
            if fw == "tfe":
                log_alpha = weights_dict[10]
                weights_dict = self._translate_tfe_weights(
                    weights_dict, map_)
        else:
            assert fw == "torch"  # Then transfer that to torch Model.
            model_dict = self._translate_weights_to_torch(
                weights_dict, map_)
            # Have to add this here (not a parameter in tf, but must be
            # one in torch, so it gets properly copied to the GPU(s)).
            model_dict["target_entropy"] = policy.model.target_entropy
            policy.model.load_state_dict(model_dict)
            policy.target_model.load_state_dict(model_dict)

        if fw == "tf":
            log_alpha = weights_dict["default_policy/log_alpha"]
        elif fw == "torch":
            # Actually convert to torch tensors (by accessing everything).
            input_ = policy._lazy_tensor_dict(input_)
            input_ = {k: input_[k] for k in input_.keys()}
            log_alpha = policy.model.log_alpha.detach().cpu().numpy()[0]

        # Only run the expectation once, should be the same anyways
        # for all frameworks.
        if expect_c is None:
            expect_c, expect_a, expect_e, expect_t = \
                self._sac_loss_helper(
                    input_,
                    weights_dict,
                    sorted(weights_dict.keys()),
                    log_alpha,
                    fw,
                    gamma=config["gamma"],
                    sess=sess,
                )

        # Get actual outs and compare to expectation AND previous
        # framework. c=critic, a=actor, e=entropy, t=td-error.
        if fw == "tf":
            c, a, e, t, tf_c_grads, tf_a_grads, tf_e_grads = p_sess.run(
                [
                    policy.critic_loss,
                    policy.actor_loss,
                    policy.alpha_loss,
                    policy.td_error,
                    policy.optimizer().compute_gradients(
                        policy.critic_loss[0],
                        [
                            v for v in policy.model.q_variables()
                            if "value_" not in v.name
                        ],
                    ),
                    policy.optimizer().compute_gradients(
                        policy.actor_loss,
                        [
                            v for v in policy.model.policy_variables()
                            if "value_" not in v.name
                        ],
                    ),
                    policy.optimizer().compute_gradients(
                        policy.alpha_loss, policy.model.log_alpha),
                ],
                feed_dict=policy._get_loss_inputs_dict(
                    input_, shuffle=False),
            )
            tf_c_grads = [g for g, v in tf_c_grads]
            tf_a_grads = [g for g, v in tf_a_grads]
            tf_e_grads = [g for g, v in tf_e_grads]

        elif fw == "tfe":
            with tf.GradientTape() as tape:
                tf_loss(policy, policy.model, None, input_)
            c, a, e, t = (
                policy.critic_loss,
                policy.actor_loss,
                policy.alpha_loss,
                policy.td_error,
            )
            vars = tape.watched_variables()
            tf_c_grads = tape.gradient(c[0], vars[6:10])
            tf_a_grads = tape.gradient(a, vars[2:6])
            tf_e_grads = tape.gradient(e, vars[10])

        elif fw == "torch":
            loss_torch(policy, policy.model, None, input_)
            c, a, e, t = (
                policy.get_tower_stats("critic_loss")[0],
                policy.get_tower_stats("actor_loss")[0],
                policy.get_tower_stats("alpha_loss")[0],
                policy.get_tower_stats("td_error")[0],
            )

            # Test actor gradients.
            policy.actor_optim.zero_grad()
            assert all(v.grad is None for v in policy.model.q_variables())
            assert all(
                v.grad is None for v in policy.model.policy_variables())
            assert policy.model.log_alpha.grad is None
            a.backward()
            # `actor_loss` depends on Q-net vars (but these grads must
            # be ignored and overridden in critic_loss.backward!).
            assert not all(
                torch.mean(v.grad) == 0
                for v in policy.model.policy_variables())
            assert not all(
                torch.min(v.grad) == 0
                for v in policy.model.policy_variables())
            assert policy.model.log_alpha.grad is None
            # Compare with tf ones.
            torch_a_grads = [
                v.grad for v in policy.model.policy_variables()
                if v.grad is not None
            ]
            check(tf_a_grads[2],
                  np.transpose(torch_a_grads[0].detach().cpu()))

            # Test critic gradients.
            policy.critic_optims[0].zero_grad()
            assert all(
                torch.mean(v.grad) == 0.0
                for v in policy.model.q_variables() if v.grad is not None)
            assert all(
                torch.min(v.grad) == 0.0
                for v in policy.model.q_variables() if v.grad is not None)
            assert policy.model.log_alpha.grad is None
            c[0].backward()
            assert not all(
                torch.mean(v.grad) == 0
                for v in policy.model.q_variables() if v.grad is not None)
            assert not all(
                torch.min(v.grad) == 0
                for v in policy.model.q_variables() if v.grad is not None)
            assert policy.model.log_alpha.grad is None
            # Compare with tf ones.
            torch_c_grads = [v.grad for v in policy.model.q_variables()]
            check(tf_c_grads[0],
                  np.transpose(torch_c_grads[2].detach().cpu()))
            # Compare (unchanged(!) actor grads) with tf ones.
            torch_a_grads = [
                v.grad for v in policy.model.policy_variables()
            ]
            check(tf_a_grads[2],
                  np.transpose(torch_a_grads[0].detach().cpu()))

            # Test alpha gradient.
            policy.alpha_optim.zero_grad()
            assert policy.model.log_alpha.grad is None
            e.backward()
            assert policy.model.log_alpha.grad is not None
            check(policy.model.log_alpha.grad, tf_e_grads)

        check(c, expect_c)
        check(a, expect_a)
        check(e, expect_e)
        check(t, expect_t)

        # Store this framework's losses in prev_fw_loss to compare with
        # next framework's outputs.
        if prev_fw_loss is not None:
            check(c, prev_fw_loss[0])
            check(a, prev_fw_loss[1])
            check(e, prev_fw_loss[2])
            check(t, prev_fw_loss[3])

        prev_fw_loss = (c, a, e, t)

        # Update weights from our batch (n times).
        for update_iteration in range(5):
            print("train iteration {}".format(update_iteration))
            if fw == "tf":
                in_ = self._get_batch_helper(
                    obs_size, actions, batch_size)
                tf_inputs.append(in_)
                # Set a fake-batch to use
                # (instead of sampling from replay buffer).
                buf = MultiAgentReplayBuffer.get_instance_for_testing()
                buf._fake_batch = in_
                trainer.train()
                updated_weights = policy.get_weights()
                # Net must have changed.
                if tf_updated_weights:
                    check(
                        updated_weights["default_policy/fc_1/kernel"],
                        tf_updated_weights[-1][
                            "default_policy/fc_1/kernel"],
                        false=True,
                    )
                tf_updated_weights.append(updated_weights)

            # Compare with updated tf-weights. Must all be the same.
            else:
                tf_weights = tf_updated_weights[update_iteration]
                in_ = tf_inputs[update_iteration]
                # Set a fake-batch to use
                # (instead of sampling from replay buffer).
                buf = MultiAgentReplayBuffer.get_instance_for_testing()
                buf._fake_batch = in_
                trainer.train()
                # Compare updated model.
                for tf_key in sorted(tf_weights.keys()):
                    if re.search("_[23]|alpha", tf_key):
                        continue
                    tf_var = tf_weights[tf_key]
                    torch_var = policy.model.state_dict()[map_[tf_key]]
                    if tf_var.shape != torch_var.shape:
                        check(
                            tf_var,
                            np.transpose(torch_var.detach().cpu()),
                            atol=0.003,
                        )
                    else:
                        check(tf_var, torch_var, atol=0.003)
                # And alpha.
                check(policy.model.log_alpha,
                      tf_weights["default_policy/log_alpha"])
                # Compare target nets.
                for tf_key in sorted(tf_weights.keys()):
                    if not re.search("_[23]", tf_key):
                        continue
                    tf_var = tf_weights[tf_key]
                    torch_var = policy.target_model.state_dict()[
                        map_[tf_key]]
                    if tf_var.shape != torch_var.shape:
                        check(
                            tf_var,
                            np.transpose(torch_var.detach().cpu()),
                            atol=0.003,
                        )
                    else:
                        check(tf_var, torch_var, atol=0.003)

        trainer.stop()
class MyTrainer(Trainer):
    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        # Run this Trainer with the new `training_iteration` API and set
        # some PPO-specific parameters.
        return with_common_config({
            "_disable_execution_plan_api": True,
            "num_sgd_iter": 10,
            "sgd_minibatch_size": 128,
        })

    @override(Trainer)
    def setup(self, config):
        # Call super's `setup` to create rollout workers.
        super().setup(config)
        # Create local replay buffer.
        self.local_replay_buffer = MultiAgentReplayBuffer(
            num_shards=1,
            learning_starts=1000,
            capacity=50000,
            replay_batch_size=64,
        )

    @override(Trainer)
    def training_iteration(self) -> ResultDict:
        # Generate common experiences, collect batch for PPO, store every
        # (DQN) batch into replay buffer.
        ppo_batches = []
        num_env_steps = 0
        # PPO batch size fixed at 200.
        while num_env_steps < 200:
            ma_batches = synchronous_parallel_sample(self.workers)
            # Loop through the (in parallel collected) ma-batches.
            for ma_batch in ma_batches:
                # Update sampled counters.
                self._counters[NUM_ENV_STEPS_SAMPLED] += ma_batch.count
                self._counters[
                    NUM_AGENT_STEPS_SAMPLED] += ma_batch.agent_steps()
                ppo_batch = ma_batch.policy_batches.pop("ppo_policy")
                # Add collected batches (only for DQN policy) to replay
                # buffer.
                self.local_replay_buffer.add_batch(ma_batch)

                ppo_batches.append(ppo_batch)
                num_env_steps += ppo_batch.count

        # DQN sub-flow.
        dqn_train_results = {}
        dqn_train_batch = self.local_replay_buffer.replay()
        if dqn_train_batch is not None:
            dqn_train_results = train_one_step(
                self, dqn_train_batch, ["dqn_policy"])
            self._counters[
                "agent_steps_trained_DQN"] += dqn_train_batch.agent_steps()
            print(
                "DQN policy learning on samples from",
                "agent steps trained",
                dqn_train_batch.agent_steps(),
            )
        # Update DQN's target net every 500 train steps.
        if (self._counters["agent_steps_trained_DQN"] -
                self._counters[LAST_TARGET_UPDATE_TS] >= 500):
            self.workers.local_worker().get_policy(
                "dqn_policy").update_target()
            self._counters[NUM_TARGET_UPDATES] += 1
            self._counters[LAST_TARGET_UPDATE_TS] = self._counters[
                "agent_steps_trained_DQN"]

        # PPO sub-flow.
        ppo_train_batch = SampleBatch.concat_samples(ppo_batches)
        self._counters[
            "agent_steps_trained_PPO"] += ppo_train_batch.agent_steps()
        # Standardize advantages.
        ppo_train_batch[Postprocessing.ADVANTAGES] = standardized(
            ppo_train_batch[Postprocessing.ADVANTAGES])
        print(
            "PPO policy learning on samples from",
            "agent steps trained",
            ppo_train_batch.agent_steps(),
        )
        ppo_train_batch = MultiAgentBatch(
            {"ppo_policy": ppo_train_batch}, ppo_train_batch.count)
        ppo_train_results = train_one_step(
            self, ppo_train_batch, ["ppo_policy"])

        # Combine results for PPO and DQN into one results dict.
        results = dict(ppo_train_results, **dqn_train_results)
        return results
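# A custom Trainer like `MyTrainer` is normally launched through Tune against
# a multi-agent env. The snippet below is a minimal sketch only: it assumes
# the `MultiAgentCartPole` example env and the TF policy classes bundled with
# this RLlib generation, the env name "multi_agent_cartpole" is illustrative,
# and the even/odd agent-ID split simply sends half the agents to PPO and
# half to DQN.
import ray
from ray import tune
from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG as DQN_CONFIG
from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG as PPO_CONFIG
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.tune.registry import register_env

ray.init()

# Two independent CartPole agents sharing one env.
register_env("multi_agent_cartpole",
             lambda _: MultiAgentCartPole({"num_agents": 2}))

config = {
    "env": "multi_agent_cartpole",
    "num_workers": 0,
    "multiagent": {
        # Obs-/action-spaces are inferred from the env (hence None); each
        # policy gets its algorithm's full default config as override.
        "policies": {
            "ppo_policy": (PPOTFPolicy, None, None, PPO_CONFIG),
            "dqn_policy": (DQNTFPolicy, None, None, DQN_CONFIG),
        },
        # Even agent IDs train with PPO, odd ones with DQN.
        "policy_mapping_fn": (
            lambda agent_id, *args, **kwargs: "ppo_policy"
            if agent_id % 2 == 0 else "dqn_policy"),
    },
}

tune.run(MyTrainer, config=config, stop={"training_iteration": 10})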
class MARWILTrainer(Trainer):
    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG

    @override(Trainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        if config["num_gpus"] > 1:
            raise ValueError("`num_gpus` > 1 not yet supported for MARWIL!")

        if config["postprocess_inputs"] is False and config["beta"] > 0.0:
            raise ValueError(
                "`postprocess_inputs` must be True for MARWIL (to "
                "calculate accum., discounted returns)!")

    @override(Trainer)
    def get_default_policy_class(
            self, config: TrainerConfigDict) -> Type[Policy]:
        if config["framework"] == "torch":
            from ray.rllib.agents.marwil.marwil_torch_policy import \
                MARWILTorchPolicy
            return MARWILTorchPolicy
        else:
            return MARWILTFPolicy

    @override(Trainer)
    def setup(self, config: PartialTrainerConfigDict):
        super().setup(config)
        # `training_iteration` implementation: Set up the buffer in `setup`,
        # not in `execution_plan` (deprecated).
        if self.config["_disable_execution_plan_api"] is True:
            self.local_replay_buffer = MultiAgentReplayBuffer(
                learning_starts=self.config["learning_starts"],
                capacity=self.config["replay_buffer_size"],
                replay_batch_size=self.config["train_batch_size"],
                replay_sequence_length=1,
            )

    @override(Trainer)
    def training_iteration(self) -> ResultDict:
        # Collect SampleBatches from sample workers.
        batch = synchronous_parallel_sample(worker_set=self.workers)
        batch = batch.as_multi_agent()
        self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
        self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
        # Add batch to replay buffer.
        self.local_replay_buffer.add_batch(batch)

        # Pull batch from replay buffer and train on it.
        train_batch = self.local_replay_buffer.replay()
        # Train.
        if self.config["simple_optimizer"]:
            train_results = train_one_step(self, train_batch)
        else:
            train_results = multi_gpu_train_one_step(self, train_batch)

        self._counters[NUM_AGENT_STEPS_TRAINED] += batch.agent_steps()
        self._counters[NUM_ENV_STEPS_TRAINED] += batch.env_steps()

        global_vars = {
            "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
        }

        # Update weights - after learning on the local worker - on all
        # remote workers.
        if self.workers.remote_workers():
            with self._timers[WORKER_UPDATE_TIMER]:
                self.workers.sync_weights(global_vars=global_vars)

        # Update global vars on local worker as well.
        self.workers.local_worker().set_global_vars(global_vars)

        return train_results

    @staticmethod
    @override(Trainer)
    def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                       **kwargs) -> LocalIterator[dict]:
        assert (
            len(kwargs) == 0
        ), "MARWIL execution_plan does NOT take any additional parameters"

        rollouts = ParallelRollouts(workers, mode="bulk_sync")

        replay_buffer = MultiAgentReplayBuffer(
            learning_starts=config["learning_starts"],
            capacity=config["replay_buffer_size"],
            replay_batch_size=config["train_batch_size"],
            replay_sequence_length=1,
        )

        store_op = rollouts.for_each(
            StoreToReplayBuffer(local_buffer=replay_buffer))

        replay_op = (Replay(local_buffer=replay_buffer).combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )).for_each(TrainOneStep(workers)))

        train_op = Concurrently(
            [store_op, replay_op], mode="round_robin", output_indexes=[1])

        return StandardMetricsReporting(train_op, workers, config)
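# For reference, MARWIL is typically trained from previously recorded offline
# data via the `input` config key. A minimal usage sketch, assuming offline
# SampleBatch JSON files already exist locally (the path below is
# illustrative, e.g. produced by another trainer run with
# `"output": "/tmp/cartpole-out"`).
import ray
from ray.rllib.agents import marwil

ray.init()

config = marwil.DEFAULT_CONFIG.copy()
# Illustrative path to previously recorded offline episodes.
config["input"] = "/tmp/cartpole-out"
# beta=1.0 -> MARWIL; beta=0.0 -> plain behavioral cloning (BC).
config["beta"] = 1.0
# Evaluate the learned policy on the actual env, not on the offline data.
config["evaluation_interval"] = 1
config["evaluation_num_workers"] = 1
config["evaluation_config"] = {"input": "sampler"}

trainer = marwil.MARWILTrainer(config=config, env="CartPole-v0")
for i in range(3):
    print(trainer.train())
trainer.stop()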