def training_iteration(self) -> ResultDict:
    # Collect SampleBatches from sample workers.
    batch = synchronous_parallel_sample(worker_set=self.workers)
    batch = batch.as_multi_agent()
    self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
    self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
    # Add batch to replay buffer.
    self.local_replay_buffer.add_batch(batch)

    # Pull batch from replay buffer and train on it.
    train_batch = self.local_replay_buffer.replay()
    # Train.
    if self.config["simple_optimizer"]:
        train_results = train_one_step(self, train_batch)
    else:
        train_results = multi_gpu_train_one_step(self, train_batch)

    self._counters[NUM_AGENT_STEPS_TRAINED] += batch.agent_steps()
    self._counters[NUM_ENV_STEPS_TRAINED] += batch.env_steps()

    global_vars = {
        "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
    }

    # Update weights - after learning on the local worker - on all remote
    # workers.
    if self.workers.remote_workers():
        with self._timers[WORKER_UPDATE_TIMER]:
            self.workers.sync_weights(global_vars=global_vars)

    # Update global vars on local worker as well.
    self.workers.local_worker().set_global_vars(global_vars)

    return train_results
def training_iteration(self) -> ResultDict: """Simple Q training iteration function. Simple Q consists of the following steps: - (1) Sample (MultiAgentBatch) from workers... - (2) Store new samples in replay buffer. - (3) Sample training batch (MultiAgentBatch) from replay buffer. - (4) Learn on training batch. - (5) Update target network every target_network_update_freq steps. - (6) Return all collected metrics for the iteration. Returns: The results dict from executing the training iteration. """ batch_size = self.config["train_batch_size"] local_worker = self.workers.local_worker() # (1) Sample (MultiAgentBatches) from workers new_sample_batches = synchronous_parallel_sample( worker_set=self.workers, concat=False ) for s in new_sample_batches: # Update counters self._counters[NUM_ENV_STEPS_SAMPLED] += len(s) self._counters[NUM_AGENT_STEPS_SAMPLED] += ( len(s) if isinstance(s, SampleBatch) else s.agent_steps() ) # (2) Store new samples in replay buffer self.local_replay_buffer.add(s) # (3) Sample training batch (MultiAgentBatch) from replay buffer. train_batch = self.local_replay_buffer.sample(batch_size) # (4) Learn on training batch. # Use simple optimizer (only for multi-agent or tf-eager; all other # cases should use the multi-GPU optimizer, even if only using 1 GPU) if self.config.get("simple_optimizer") is True: train_results = train_one_step(self, train_batch) else: train_results = multi_gpu_train_one_step(self, train_batch) # (5) Update target network every target_network_update_freq steps cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED] last_update = self._counters[LAST_TARGET_UPDATE_TS] if cur_ts - last_update >= self.config["target_network_update_freq"]: to_update = local_worker.get_policies_to_train() local_worker.foreach_policy_to_train( lambda p, pid: pid in to_update and p.update_target() ) self._counters[NUM_TARGET_UPDATES] += 1 self._counters[LAST_TARGET_UPDATE_TS] = cur_ts # Update remote workers' weights after learning on local worker if self.workers.remote_workers(): with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: self.workers.sync_weights() # (6) Return all collected metrics for the iteration. return train_results
def training_step(self) -> ResultDict:
    with self._timers[SAMPLE_TIMER]:
        train_batch = synchronous_parallel_sample(worker_set=self.workers)
    train_batch = train_batch.as_multi_agent()
    self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
    self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()

    # Postprocess batch before we learn on it.
    post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
    train_batch = post_fn(train_batch, self.workers, self.config)

    # Learn on training batch.
    # Use simple optimizer (only for multi-agent or tf-eager; all other
    # cases should use the multi-GPU optimizer, even if only using 1 GPU).
    if self.config.get("simple_optimizer", False):
        train_results = train_one_step(self, train_batch)
    else:
        train_results = multi_gpu_train_one_step(self, train_batch)

    # Update target network every `target_network_update_freq` training steps.
    cur_ts = self._counters[
        NUM_AGENT_STEPS_TRAINED if self._by_agent_steps else NUM_ENV_STEPS_TRAINED
    ]
    last_update = self._counters[LAST_TARGET_UPDATE_TS]
    if cur_ts - last_update >= self.config["target_network_update_freq"]:
        with self._timers[TARGET_NET_UPDATE_TIMER]:
            to_update = self.workers.local_worker().get_policies_to_train()
            self.workers.local_worker().foreach_policy_to_train(
                lambda p, pid: pid in to_update and p.update_target()
            )
        self._counters[NUM_TARGET_UPDATES] += 1
        self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

    self._counters[NUM_GRADIENT_UPDATES] += 1

    return train_results
def training_step(self) -> ResultDict:
    # Collect SampleBatches from sample workers.
    batch = synchronous_parallel_sample(worker_set=self.workers)
    batch = batch.as_multi_agent()
    self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
    self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
    # Add batch to replay buffer.
    self.local_replay_buffer.add(batch)

    # Sample training batch from replay buffer.
    train_batch = sample_min_n_steps_from_buffer(
        self.local_replay_buffer,
        self.config["train_batch_size"],
        count_by_agent_steps=self._by_agent_steps,
    )

    # Old-style replay buffers return None if learning has not started.
    if not train_batch:
        return {}

    # Postprocess batch before we learn on it.
    post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
    train_batch = post_fn(train_batch, self.workers, self.config)

    # Learn on training batch.
    # Use simple optimizer (only for multi-agent or tf-eager; all other
    # cases should use the multi-GPU optimizer, even if only using 1 GPU).
    if self.config.get("simple_optimizer") is True:
        train_results = train_one_step(self, train_batch)
    else:
        train_results = multi_gpu_train_one_step(self, train_batch)

    # Update replay buffer priorities.
    update_priorities_in_replay_buffer(
        self.local_replay_buffer,
        self.config,
        train_batch,
        train_results,
    )

    # Update target network every `target_network_update_freq` training steps.
    cur_ts = self._counters[
        NUM_AGENT_STEPS_TRAINED if self._by_agent_steps else NUM_ENV_STEPS_TRAINED
    ]
    last_update = self._counters[LAST_TARGET_UPDATE_TS]
    if cur_ts - last_update >= self.config["target_network_update_freq"]:
        with self._timers[TARGET_NET_UPDATE_TIMER]:
            to_update = self.workers.local_worker().get_policies_to_train()
            self.workers.local_worker().foreach_policy_to_train(
                lambda p, pid: pid in to_update and p.update_target()
            )
        self._counters[NUM_TARGET_UPDATES] += 1
        self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

    # Update remote workers' weights after learning on local worker.
    if self.workers.remote_workers():
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            self.workers.sync_weights()

    # Return all collected metrics for the iteration.
    return train_results
def training_step(self) -> ResultDict: """TODO: Returns: The results dict from executing the training iteration. """ # Sample n MultiAgentBatches from n workers. new_sample_batches = synchronous_parallel_sample( worker_set=self.workers, concat=False) for batch in new_sample_batches: # Update sampling step counters. self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps() self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps() # Store new samples in the replay buffer # Use deprecated add_batch() to support old replay buffers for now if self.local_replay_buffer is not None: self.local_replay_buffer.add(batch) if self.local_replay_buffer is not None: train_batch = self.local_replay_buffer.sample( self.config["train_batch_size"]) else: train_batch = SampleBatch.concat_samples(new_sample_batches) # Learn on the training batch. # Use simple optimizer (only for multi-agent or tf-eager; all other # cases should use the multi-GPU optimizer, even if only using 1 GPU) train_results = {} if train_batch is not None: if self.config.get("simple_optimizer") is True: train_results = train_one_step(self, train_batch) else: train_results = multi_gpu_train_one_step(self, train_batch) # TODO: Move training steps counter update outside of `train_one_step()` method. # # Update train step counters. # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps() # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps() # Update weights and global_vars - after learning on the local worker - on all # remote workers. global_vars = { "timestep": self._counters[NUM_ENV_STEPS_SAMPLED], } with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: self.workers.sync_weights(global_vars=global_vars) # Return all collected metrics for the iteration. return train_results
def training_step(self) -> ResultDict:
    total_transitions = len(self.local_replay_buffer)
    bsize = self.config["train_batch_size"]
    n_batches_per_epoch = total_transitions // bsize

    results = []
    for batch_iter in range(n_batches_per_epoch):
        # Sample training batch from replay buffer.
        train_batch = self.local_replay_buffer.sample(bsize)

        # Postprocess batch before we learn on it.
        post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
        train_batch = post_fn(train_batch, self.workers, self.config)

        # Learn on training batch.
        # Use simple optimizer (only for multi-agent or tf-eager; all other
        # cases should use the multi-GPU optimizer, even if only using 1 GPU).
        if self.config.get("simple_optimizer", False):
            train_results = train_one_step(self, train_batch)
        else:
            train_results = multi_gpu_train_one_step(self, train_batch)

        # Update target every few gradient updates.
        cur_ts = self._counters[NUM_GRADIENT_UPDATES]
        last_update = self._counters[LAST_TARGET_UPDATE_TS]
        if cur_ts - last_update >= self.config["target_update_grad_intervals"]:
            with self._timers[TARGET_NET_UPDATE_TIMER]:
                to_update = self.workers.local_worker().get_policies_to_train()
                self.workers.local_worker().foreach_policy_to_train(
                    lambda p, pid: pid in to_update and p.update_target()
                )
            self._counters[NUM_TARGET_UPDATES] += 1
            self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

        self._counters[NUM_GRADIENT_UPDATES] += 1
        results.append(train_results)

    summary = tree.map_structure_with_path(
        lambda path, *v: float(np.mean(v)), *results
    )
    return summary
def training_step(self) -> ResultDict:
    # Collect SampleBatches from sample workers.
    with self._timers[SAMPLE_TIMER]:
        batch = synchronous_parallel_sample(worker_set=self.workers)
    batch = batch.as_multi_agent()
    self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
    self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
    # Add batch to replay buffer.
    self.local_replay_buffer.add(batch)

    # Pull batch from replay buffer and train on it.
    train_batch = sample_min_n_steps_from_buffer(
        self.local_replay_buffer,
        self.config["train_batch_size"],
        count_by_agent_steps=self._by_agent_steps,
    )
    # Train.
    if self.config["simple_optimizer"]:
        train_results = train_one_step(self, train_batch)
    else:
        train_results = multi_gpu_train_one_step(self, train_batch)

    # TODO: Move training steps counter update outside of `train_one_step()` method.
    # # Update train step counters.
    # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
    # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    global_vars = {
        "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
    }

    # Update weights - after learning on the local worker - on all remote
    # workers.
    if self.workers.remote_workers():
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            self.workers.sync_weights(global_vars=global_vars)

    # Update global vars on local worker as well.
    self.workers.local_worker().set_global_vars(global_vars)

    return train_results
def training_step(self) -> ResultDict:
    # Collect SampleBatches from sample workers until we have a full batch.
    if self._by_agent_steps:
        train_batch = synchronous_parallel_sample(
            worker_set=self.workers,
            max_agent_steps=self.config["train_batch_size"],
        )
    else:
        train_batch = synchronous_parallel_sample(
            worker_set=self.workers,
            max_env_steps=self.config["train_batch_size"],
        )
    train_batch = train_batch.as_multi_agent()
    self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
    self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()

    # Standardize advantages.
    train_batch = standardize_fields(train_batch, ["advantages"])

    # Train.
    if self.config["simple_optimizer"]:
        train_results = train_one_step(self, train_batch)
    else:
        train_results = multi_gpu_train_one_step(self, train_batch)

    global_vars = {
        "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
    }

    # Update weights - after learning on the local worker - on all remote
    # workers.
    if self.workers.remote_workers():
        with self._timers[WORKER_UPDATE_TIMER]:
            self.workers.sync_weights(global_vars=global_vars)

    # For each policy: update KL scale and warn about possible issues.
    for policy_id, policy_info in train_results.items():
        # Update KL loss with dynamic scaling for each (possibly multi-agent)
        # policy we are training.
        kl_divergence = policy_info[LEARNER_STATS_KEY].get("kl")
        self.get_policy(policy_id).update_kl(kl_divergence)

        # Warn about excessively high value function loss.
        scaled_vf_loss = (
            self.config["vf_loss_coeff"] * policy_info[LEARNER_STATS_KEY]["vf_loss"]
        )
        policy_loss = policy_info[LEARNER_STATS_KEY]["policy_loss"]
        if (
            log_once("ppo_warned_lr_ratio")
            and self.config.get("model", {}).get("vf_share_layers")
            and scaled_vf_loss > 100
        ):
            logger.warning(
                "The magnitude of your value function loss for policy: {} is "
                "extremely large ({}) compared to the policy loss ({}). This "
                "can prevent the policy from learning. Consider scaling down "
                "the VF loss by reducing vf_loss_coeff, or disabling "
                "vf_share_layers.".format(policy_id, scaled_vf_loss, policy_loss)
            )

        # Warn about bad clipping configs.
        train_batch.policy_batches[policy_id].set_get_interceptor(None)
        mean_reward = train_batch.policy_batches[policy_id]["rewards"].mean()
        if (
            log_once("ppo_warned_vf_clip")
            and mean_reward > self.config["vf_clip_param"]
        ):
            self.warned_vf_clip = True
            logger.warning(
                f"The mean reward returned from the environment is {mean_reward}"
                f" but the vf_clip_param is set to {self.config['vf_clip_param']}."
                f" Consider increasing it for policy: {policy_id} to improve"
                " value function convergence."
            )

    # Update global vars on local worker as well.
    self.workers.local_worker().set_global_vars(global_vars)

    return train_results
def training_iteration(self) -> ResultDict: """Simple Q training iteration function. Simple Q consists of the following steps: - Sample n MultiAgentBatches from n workers synchronously. - Store new samples in the replay buffer. - Sample one training MultiAgentBatch from the replay buffer. - Learn on the training batch. - Update the target network every `target_network_update_freq` steps. - Return all collected training metrics for the iteration. Returns: The results dict from executing the training iteration. """ batch_size = self.config["train_batch_size"] local_worker = self.workers.local_worker() # Sample n MultiAgentBatches from n workers. new_sample_batches = synchronous_parallel_sample( worker_set=self.workers, concat=False) for batch in new_sample_batches: # Update sampling step counters. self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps() self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps() # Store new samples in the replay buffer self.local_replay_buffer.add(batch) # Sample one training MultiAgentBatch from replay buffer. train_batch = self.local_replay_buffer.sample(batch_size) # Learn on the training batch. # Use simple optimizer (only for multi-agent or tf-eager; all other # cases should use the multi-GPU optimizer, even if only using 1 GPU) if self.config.get("simple_optimizer") is True: train_results = train_one_step(self, train_batch) else: train_results = multi_gpu_train_one_step(self, train_batch) # TODO: Move training steps counter update outside of `train_one_step()` method. # # Update train step counters. # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps() # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps() # Update target network every `target_network_update_freq` steps. cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED] last_update = self._counters[LAST_TARGET_UPDATE_TS] if cur_ts - last_update >= self.config["target_network_update_freq"]: with self._timers[TARGET_NET_UPDATE_TIMER]: to_update = local_worker.get_policies_to_train() local_worker.foreach_policy_to_train( lambda p, pid: pid in to_update and p.update_target()) self._counters[NUM_TARGET_UPDATES] += 1 self._counters[LAST_TARGET_UPDATE_TS] = cur_ts # Update remote workers' weights after learning on local worker. if self.workers.remote_workers(): with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: self.workers.sync_weights() # Return all collected metrics for the iteration. return train_results
def training_iteration(self) -> ResultDict:
    # Generate common experiences, collect batch for PPO, store every (DQN) batch
    # into replay buffer.
    ppo_batches = []
    num_env_steps = 0
    # PPO batch size fixed at 200.
    while num_env_steps < 200:
        ma_batches = synchronous_parallel_sample(
            worker_set=self.workers, concat=False
        )
        # Loop through the ma-batches (collected in parallel).
        for ma_batch in ma_batches:
            # Update sampled counters.
            self._counters[NUM_ENV_STEPS_SAMPLED] += ma_batch.count
            self._counters[NUM_AGENT_STEPS_SAMPLED] += ma_batch.agent_steps()
            ppo_batch = ma_batch.policy_batches.pop("ppo_policy")
            # Add collected batches (only for DQN policy) to replay buffer.
            self.local_replay_buffer.add(ma_batch)

            ppo_batches.append(ppo_batch)
            num_env_steps += ppo_batch.count

    # DQN sub-flow.
    dqn_train_results = {}
    dqn_train_batch = self.local_replay_buffer.sample(num_items=64)
    if dqn_train_batch is not None:
        dqn_train_results = train_one_step(self, dqn_train_batch, ["dqn_policy"])
        self._counters["agent_steps_trained_DQN"] += dqn_train_batch.agent_steps()
        print(
            "DQN policy learning on samples from",
            "agent steps trained",
            dqn_train_batch.agent_steps(),
        )
    # Update DQN's target net every 500 train steps.
    if (
        self._counters["agent_steps_trained_DQN"]
        - self._counters[LAST_TARGET_UPDATE_TS]
        >= 500
    ):
        self.workers.local_worker().get_policy("dqn_policy").update_target()
        self._counters[NUM_TARGET_UPDATES] += 1
        self._counters[LAST_TARGET_UPDATE_TS] = self._counters[
            "agent_steps_trained_DQN"
        ]

    # PPO sub-flow.
    ppo_train_batch = SampleBatch.concat_samples(ppo_batches)
    self._counters["agent_steps_trained_PPO"] += ppo_train_batch.agent_steps()
    # Standardize advantages.
    ppo_train_batch[Postprocessing.ADVANTAGES] = standardized(
        ppo_train_batch[Postprocessing.ADVANTAGES]
    )
    print(
        "PPO policy learning on samples from",
        "agent steps trained",
        ppo_train_batch.agent_steps(),
    )
    ppo_train_batch = MultiAgentBatch(
        {"ppo_policy": ppo_train_batch}, ppo_train_batch.count
    )
    ppo_train_results = train_one_step(self, ppo_train_batch, ["ppo_policy"])

    # Combine results for PPO and DQN into one results dict.
    results = dict(ppo_train_results, **dqn_train_results)
    return results
def training_iteration(self) -> ResultDict: """DQN training iteration function. Each training iteration, we: - Sample (MultiAgentBatch) from workers. - Store new samples in replay buffer. - Sample training batch (MultiAgentBatch) from replay buffer. - Learn on training batch. - Update remote workers' new policy weights. - Update target network every `target_network_update_freq` sample steps. - Return all collected metrics for the iteration. Returns: The results dict from executing the training iteration. """ train_results = {} # We alternate between storing new samples and sampling and training store_weight, sample_and_train_weight = calculate_rr_weights(self.config) for _ in range(store_weight): # Sample (MultiAgentBatch) from workers. new_sample_batch = synchronous_parallel_sample( worker_set=self.workers, concat=True ) # Update counters self._counters[NUM_AGENT_STEPS_SAMPLED] += new_sample_batch.agent_steps() self._counters[NUM_ENV_STEPS_SAMPLED] += new_sample_batch.env_steps() # Store new samples in replay buffer. self.local_replay_buffer.add_batch(new_sample_batch) global_vars = { "timestep": self._counters[NUM_ENV_STEPS_SAMPLED], } for _ in range(sample_and_train_weight): # Sample training batch (MultiAgentBatch) from replay buffer. train_batch = sample_min_n_steps_from_buffer( self.local_replay_buffer, self.config["train_batch_size"], count_by_agent_steps=self._by_agent_steps, ) # Old-style replay buffers return None if learning has not started if train_batch is None or len(train_batch) == 0: self.workers.local_worker().set_global_vars(global_vars) break # Postprocess batch before we learn on it post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b) train_batch = post_fn(train_batch, self.workers, self.config) # for policy_id, sample_batch in train_batch.policy_batches.items(): # print(len(sample_batch["obs"])) # print(sample_batch.count) # Learn on training batch. # Use simple optimizer (only for multi-agent or tf-eager; all other # cases should use the multi-GPU optimizer, even if only using 1 GPU) if self.config.get("simple_optimizer") is True: train_results = train_one_step(self, train_batch) else: train_results = multi_gpu_train_one_step(self, train_batch) # Update replay buffer priorities. update_priorities_in_replay_buffer( self.local_replay_buffer, self.config, train_batch, train_results, ) # Update target network every `target_network_update_freq` sample steps. cur_ts = self._counters[ NUM_AGENT_STEPS_SAMPLED if self._by_agent_steps else NUM_ENV_STEPS_SAMPLED ] last_update = self._counters[LAST_TARGET_UPDATE_TS] if cur_ts - last_update >= self.config["target_network_update_freq"]: to_update = self.workers.local_worker().get_policies_to_train() self.workers.local_worker().foreach_policy_to_train( lambda p, pid: pid in to_update and p.update_target() ) self._counters[NUM_TARGET_UPDATES] += 1 self._counters[LAST_TARGET_UPDATE_TS] = cur_ts # Update weights and global_vars - after learning on the local worker - # on all remote workers. with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: self.workers.sync_weights(global_vars=global_vars) # Return all collected metrics for the iteration. return train_results
def training_iteration(self) -> ResultDict: """DQN training iteration function. Each training iteration, we: - Sample (MultiAgentBatch) from workers. - Store new samples in replay buffer. - Sample training batch (MultiAgentBatch) from replay buffer. - Learn on training batch. - Update remote workers' new policy weights. - Update target network every target_network_update_freq steps. - Return all collected metrics for the iteration. Returns: The results dict from executing the training iteration. """ local_worker = self.workers.local_worker() train_results = {} # We alternate between storing new samples and sampling and training store_weight, sample_and_train_weight = calculate_rr_weights( self.config) for _ in range(store_weight): # (1) Sample (MultiAgentBatch) from workers new_sample_batch = synchronous_parallel_sample( worker_set=self.workers, concat=True) # Update counters self._counters[ NUM_AGENT_STEPS_SAMPLED] += new_sample_batch.agent_steps() self._counters[ NUM_ENV_STEPS_SAMPLED] += new_sample_batch.env_steps() # (2) Store new samples in replay buffer self.local_replay_buffer.add_batch(new_sample_batch) for _ in range(sample_and_train_weight): # (3) Sample training batch (MultiAgentBatch) from replay buffer. train_batch = self.local_replay_buffer.replay() # Old-style replay buffers return None if learning has not started if not train_batch: continue # Postprocess batch before we learn on it post_fn = self.config.get("before_learn_on_batch") or ( lambda b, *a: b) train_batch = post_fn(train_batch, self.workers, self.config) # (4) Learn on training batch. # Use simple optimizer (only for multi-agent or tf-eager; all other # cases should use the multi-GPU optimizer, even if only using 1 GPU) if self.config.get("simple_optimizer") is True: train_results = train_one_step(self, train_batch) else: train_results = multi_gpu_train_one_step(self, train_batch) # Update replay buffer priorities. update_priorities_in_replay_buffer( self.local_replay_buffer, self.config, train_batch, train_results, ) # (6) Update target network every target_network_update_freq steps cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED] last_update = self._counters[LAST_TARGET_UPDATE_TS] if cur_ts - last_update >= self.config[ "target_network_update_freq"]: to_update = local_worker.get_policies_to_train() local_worker.foreach_policy_to_train( lambda p, pid: pid in to_update and p.update_target()) self._counters[NUM_TARGET_UPDATES] += 1 self._counters[LAST_TARGET_UPDATE_TS] = cur_ts # Update remote workers's weights after learning on local worker if self.workers.remote_workers(): with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: self.workers.sync_weights() # (7) Return all collected metrics for the iteration. return train_results
def training_step(self) -> ResultDict: """QMIX training iteration function. - Sample n MultiAgentBatches from n workers synchronously. - Store new samples in the replay buffer. - Sample one training MultiAgentBatch from the replay buffer. - Learn on the training batch. - Update the target network every `target_network_update_freq` sample steps. - Return all collected training metrics for the iteration. Returns: The results dict from executing the training iteration. """ # Sample n batches from n workers. new_sample_batches = synchronous_parallel_sample( worker_set=self.workers, concat=False) for batch in new_sample_batches: # Update counters. self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps() self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps() # Store new samples in the replay buffer. self.local_replay_buffer.add(batch) # Sample n batches from replay buffer until the total number of timesteps # reaches `train_batch_size`. train_batch = sample_min_n_steps_from_buffer( replay_buffer=self.local_replay_buffer, min_steps=self.config["train_batch_size"], count_by_agent_steps=self._by_agent_steps, ) if train_batch is None: return {} # Learn on the training batch. # Use simple optimizer (only for multi-agent or tf-eager; all other # cases should use the multi-GPU optimizer, even if only using 1 GPU) if self.config.get("simple_optimizer") is True: train_results = train_one_step(self, train_batch) else: train_results = multi_gpu_train_one_step(self, train_batch) # TODO: Move training steps counter update outside of `train_one_step()` method. # # Update train step counters. # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps() # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps() # Update target network every `target_network_update_freq` sample steps. cur_ts = self._counters[NUM_AGENT_STEPS_SAMPLED if self. _by_agent_steps else NUM_ENV_STEPS_SAMPLED] last_update = self._counters[LAST_TARGET_UPDATE_TS] if cur_ts - last_update >= self.config["target_network_update_freq"]: to_update = self.workers.local_worker().get_policies_to_train() self.workers.local_worker().foreach_policy_to_train( lambda p, pid: pid in to_update and p.update_target()) self._counters[NUM_TARGET_UPDATES] += 1 self._counters[LAST_TARGET_UPDATE_TS] = cur_ts update_priorities_in_replay_buffer(self.local_replay_buffer, self.config, train_batch, train_results) # Update weights and global_vars - after learning on the local worker - on all # remote workers. global_vars = { "timestep": self._counters[NUM_ENV_STEPS_SAMPLED], } # Update remote workers' weights and global vars after learning on local worker. with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: self.workers.sync_weights(global_vars=global_vars) # Return all collected metrics for the iteration. return train_results
def training_iteration(self) -> ResultDict: """DQN training iteration function. Each training iteration, we: - Sample (MultiAgentBatch) from workers. - Store new samples in replay buffer. - Sample training batch (MultiAgentBatch) from replay buffer. - Learn on training batch. - Update remote workers' new policy weights. - Update target network every target_network_update_freq steps. - Return all collected metrics for the iteration. Returns: The results dict from executing the training iteration. """ local_worker = self.workers.local_worker() train_results = {} # We alternate between storing new samples and sampling and training store_weight, sample_and_train_weight = calculate_rr_weights(self.config) for _ in range(store_weight): # (1) Sample (MultiAgentBatch) from workers new_sample_batch = synchronous_parallel_sample( worker_set=self.workers, concat=True ) # Update counters self._counters[NUM_AGENT_STEPS_SAMPLED] += new_sample_batch.agent_steps() self._counters[NUM_ENV_STEPS_SAMPLED] += new_sample_batch.env_steps() # (2) Store new samples in replay buffer self.local_replay_buffer.add_batch(new_sample_batch) for _ in range(sample_and_train_weight): # (3) Sample training batch (MultiAgentBatch) from replay buffer. train_batch = self.local_replay_buffer.replay() # Old-style replay buffers return None if learning has not started if not train_batch: continue # Postprocess batch before we learn on it post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b) train_batch = post_fn(train_batch, self.workers, self.config) # (4) Learn on training batch. # Use simple optimizer (only for multi-agent or tf-eager; all other # cases should use the multi-GPU optimizer, even if only using 1 GPU) if self.config.get("simple_optimizer") is True: train_results = train_one_step(self, train_batch) else: train_results = multi_gpu_train_one_step(self, train_batch) # Update priorities if ( type(self.local_replay_buffer) is LegacyMultiAgentReplayBuffer and self.config["replay_buffer_config"].get( "prioritized_replay_alpha", 0.0 ) > 0.0 ) or isinstance( self.local_replay_buffer, MultiAgentPrioritizedReplayBuffer ): prio_dict = {} for policy_id, info in train_results.items(): # TODO(sven): This is currently structured differently for # torch/tf. Clean up these results/info dicts across # policies (note: fixing this in torch_policy.py will # break e.g. DDPPO!). td_error = info.get( "td_error", info[LEARNER_STATS_KEY].get("td_error") ) train_batch.policy_batches[policy_id].set_get_interceptor(None) batch_indices = train_batch.policy_batches[policy_id].get( "batch_indexes" ) # In case the buffer stores sequences, TD-error could # already be calculated per sequence chunk. 
if len(batch_indices) != len(td_error): T = self.local_replay_buffer.replay_sequence_length assert ( len(batch_indices) > len(td_error) and len(batch_indices) % T == 0 ) batch_indices = batch_indices.reshape([-1, T])[:, 0] assert len(batch_indices) == len(td_error) prio_dict[policy_id] = (batch_indices, td_error) self.local_replay_buffer.update_priorities(prio_dict) # (6) Update target network every target_network_update_freq steps cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED] last_update = self._counters[LAST_TARGET_UPDATE_TS] if cur_ts - last_update >= self.config["target_network_update_freq"]: to_update = local_worker.get_policies_to_train() local_worker.foreach_policy_to_train( lambda p, pid: pid in to_update and p.update_target() ) self._counters[NUM_TARGET_UPDATES] += 1 self._counters[LAST_TARGET_UPDATE_TS] = cur_ts # Update remote workers's weights after learning on local worker if self.workers.remote_workers(): with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: self.workers.sync_weights() # (7) Return all collected metrics for the iteration. return train_results
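# ---------------------------------------------------------------------------
# A minimal usage sketch (not part of the snippets above): how a custom
# `training_step()` override like the ones in this file is typically plugged
# into an RLlib Algorithm subclass and executed. The subclass name `MyDQN`,
# the CartPole environment, and the config values are illustrative
# assumptions, not taken from the original code.
# ---------------------------------------------------------------------------
from ray.rllib.algorithms.dqn import DQN, DQNConfig
from ray.rllib.utils.typing import ResultDict


class MyDQN(DQN):
    def training_step(self) -> ResultDict:
        # A real override would replace this body with custom sampling,
        # replay-buffer, and learning logic, as shown in the snippets above.
        return super().training_step()


if __name__ == "__main__":
    config = (
        DQNConfig()
        .environment("CartPole-v1")
        .rollouts(num_rollout_workers=1)
        .training(train_batch_size=32)
    )
    # Build the custom Algorithm from the config's dict representation and
    # run a few training iterations; each `train()` call invokes
    # `training_step()` one or more times and returns the collected metrics.
    algo = MyDQN(config=config.to_dict())
    for _ in range(3):
        print(algo.train()["episode_reward_mean"])
    algo.stop()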