def training_step(self) -> ResultDict:
    # Collect SampleBatches from sample workers.
    batch = synchronous_parallel_sample(worker_set=self.workers)
    batch = batch.as_multi_agent()
    self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
    self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()

    # Add batch to replay buffer.
    self.local_replay_buffer.add(batch)

    # Sample training batch from replay buffer.
    train_batch = sample_min_n_steps_from_buffer(
        self.local_replay_buffer,
        self.config["train_batch_size"],
        count_by_agent_steps=self._by_agent_steps,
    )

    # Old-style replay buffers return None if learning has not started.
    if not train_batch:
        return {}

    # Postprocess batch before we learn on it.
    post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
    train_batch = post_fn(train_batch, self.workers, self.config)

    # Learn on training batch.
    # Use simple optimizer (only for multi-agent or tf-eager; all other
    # cases should use the multi-GPU optimizer, even if only using 1 GPU).
    if self.config.get("simple_optimizer") is True:
        train_results = train_one_step(self, train_batch)
    else:
        train_results = multi_gpu_train_one_step(self, train_batch)

    # Update replay buffer priorities.
    update_priorities_in_replay_buffer(
        self.local_replay_buffer,
        self.config,
        train_batch,
        train_results,
    )

    # Update target network every `target_network_update_freq` training steps.
    cur_ts = self._counters[
        NUM_AGENT_STEPS_TRAINED if self._by_agent_steps else NUM_ENV_STEPS_TRAINED
    ]
    last_update = self._counters[LAST_TARGET_UPDATE_TS]
    if cur_ts - last_update >= self.config["target_network_update_freq"]:
        with self._timers[TARGET_NET_UPDATE_TIMER]:
            to_update = self.workers.local_worker().get_policies_to_train()
            self.workers.local_worker().foreach_policy_to_train(
                lambda p, pid: pid in to_update and p.update_target()
            )
        self._counters[NUM_TARGET_UPDATES] += 1
        self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

    # Update remote workers' weights after learning on the local worker.
    if self.workers.remote_workers():
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            self.workers.sync_weights()

    # Return all collected metrics for the iteration.
    return train_results
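# NOTE: Hedged sketch only. The call `p.update_target()` above is what triggers
# the target-network refresh for each trainable policy. The class below is a
# minimal, hypothetical illustration of that pattern (Polyak/hard update of a
# frozen target copy); it is not RLlib's Policy implementation, and the
# attribute names (`q_weights`, `target_q_weights`, `tau`) are assumptions.
import numpy as np


class _SketchTargetNetworkPolicy:
    def __init__(self, num_weights: int, tau: float = 1.0):
        self.q_weights = np.random.randn(num_weights)   # main Q-network weights
        self.target_q_weights = self.q_weights.copy()   # frozen target copy
        self.tau = tau                                   # 1.0 => hard copy (vanilla DQN)

    def update_target(self) -> None:
        # Polyak averaging: target <- tau * main + (1 - tau) * target.
        self.target_q_weights = (
            self.tau * self.q_weights + (1.0 - self.tau) * self.target_q_weights
        )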
def training_step(self) -> ResultDict:
    # Collect SampleBatches from sample workers.
    with self._timers[SAMPLE_TIMER]:
        batch = synchronous_parallel_sample(worker_set=self.workers)
    batch = batch.as_multi_agent()
    self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
    self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()

    # Add batch to replay buffer.
    self.local_replay_buffer.add(batch)

    # Pull batch from replay buffer and train on it.
    train_batch = sample_min_n_steps_from_buffer(
        self.local_replay_buffer,
        self.config["train_batch_size"],
        count_by_agent_steps=self._by_agent_steps,
    )

    # Train.
    if self.config["simple_optimizer"]:
        train_results = train_one_step(self, train_batch)
    else:
        train_results = multi_gpu_train_one_step(self, train_batch)

    # TODO: Move training steps counter update outside of `train_one_step()` method.
    # # Update train step counters.
    # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
    # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    global_vars = {
        "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
    }

    # Update weights - after learning on the local worker - on all remote
    # workers.
    if self.workers.remote_workers():
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            self.workers.sync_weights(global_vars=global_vars)

    # Update global vars on the local worker as well.
    self.workers.local_worker().set_global_vars(global_vars)

    return train_results
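# NOTE: Hedged sketch only. The `global_vars["timestep"]` broadcast above is
# commonly consumed by per-worker exploration schedules that anneal with the
# global step count. The function below is a hypothetical linear-decay example,
# not RLlib's exploration API; all parameter names are assumptions.
def _sketch_epsilon_schedule(
    timestep: int,
    initial_eps: float = 1.0,
    final_eps: float = 0.02,
    anneal_timesteps: int = 10_000,
) -> float:
    # Linearly anneal epsilon from `initial_eps` down to `final_eps` over
    # `anneal_timesteps` global steps, then hold it constant.
    frac = min(timestep / anneal_timesteps, 1.0)
    return initial_eps + frac * (final_eps - initial_eps)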
def training_iteration(self) -> ResultDict:
    """DQN training iteration function.

    Each training iteration, we:
    - Sample (MultiAgentBatch) from workers.
    - Store new samples in the replay buffer.
    - Sample a training batch (MultiAgentBatch) from the replay buffer.
    - Learn on the training batch.
    - Update the remote workers' policy weights.
    - Update the target network every `target_network_update_freq` sample steps.
    - Return all collected metrics for the iteration.

    Returns:
        The results dict from executing the training iteration.
    """
    train_results = {}

    # We alternate between storing new samples and sampling-and-training.
    store_weight, sample_and_train_weight = calculate_rr_weights(self.config)

    for _ in range(store_weight):
        # Sample (MultiAgentBatch) from workers.
        new_sample_batch = synchronous_parallel_sample(
            worker_set=self.workers, concat=True
        )

        # Update counters.
        self._counters[NUM_AGENT_STEPS_SAMPLED] += new_sample_batch.agent_steps()
        self._counters[NUM_ENV_STEPS_SAMPLED] += new_sample_batch.env_steps()

        # Store new samples in the replay buffer.
        self.local_replay_buffer.add_batch(new_sample_batch)

    global_vars = {
        "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
    }

    for _ in range(sample_and_train_weight):
        # Sample training batch (MultiAgentBatch) from replay buffer.
        train_batch = sample_min_n_steps_from_buffer(
            self.local_replay_buffer,
            self.config["train_batch_size"],
            count_by_agent_steps=self._by_agent_steps,
        )

        # Old-style replay buffers return None if learning has not started.
        if train_batch is None or len(train_batch) == 0:
            self.workers.local_worker().set_global_vars(global_vars)
            break

        # Postprocess batch before we learn on it.
        post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
        train_batch = post_fn(train_batch, self.workers, self.config)

        # Learn on training batch.
        # Use simple optimizer (only for multi-agent or tf-eager; all other
        # cases should use the multi-GPU optimizer, even if only using 1 GPU).
        if self.config.get("simple_optimizer") is True:
            train_results = train_one_step(self, train_batch)
        else:
            train_results = multi_gpu_train_one_step(self, train_batch)

        # Update replay buffer priorities.
        update_priorities_in_replay_buffer(
            self.local_replay_buffer,
            self.config,
            train_batch,
            train_results,
        )

        # Update target network every `target_network_update_freq` sample steps.
        cur_ts = self._counters[
            NUM_AGENT_STEPS_SAMPLED if self._by_agent_steps else NUM_ENV_STEPS_SAMPLED
        ]
        last_update = self._counters[LAST_TARGET_UPDATE_TS]
        if cur_ts - last_update >= self.config["target_network_update_freq"]:
            to_update = self.workers.local_worker().get_policies_to_train()
            self.workers.local_worker().foreach_policy_to_train(
                lambda p, pid: pid in to_update and p.update_target()
            )
            self._counters[NUM_TARGET_UPDATES] += 1
            self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

        # Update weights and global_vars - after learning on the local worker -
        # on all remote workers.
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            self.workers.sync_weights(global_vars=global_vars)

    # Return all collected metrics for the iteration.
    return train_results
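# NOTE: Hedged sketch only. `calculate_rr_weights` above decides how often the
# "store new samples" loop runs relative to the "sample from buffer and train"
# loop. The function below shows one plausible way to derive such integer
# weights from a desired trained-steps-to-sampled-steps ratio; the name
# `training_intensity` and the exact formula are assumptions for illustration,
# not RLlib's implementation.
def _sketch_rr_weights(
    rollout_fragment_length: int,
    num_workers: int,
    train_batch_size: int,
    training_intensity: float,
):
    # Timesteps entering the buffer per store pass vs. timesteps consumed
    # per train pass.
    sampled_per_store = rollout_fragment_length * max(num_workers, 1)
    trained_per_train = train_batch_size
    # Number of train passes per store pass needed to hit the requested ratio.
    ratio = training_intensity * sampled_per_store / trained_per_train
    if ratio >= 1.0:
        return 1, max(round(ratio), 1)        # store once, train several times
    return max(round(1.0 / ratio), 1), 1      # store several times, train once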
def training_iteration(self) -> ResultDict:
    """QMIX training iteration function.

    - Sample n MultiAgentBatches from n workers synchronously.
    - Store new samples in the replay buffer.
    - Sample one training MultiAgentBatch from the replay buffer.
    - Learn on the training batch.
    - Update the target network every `target_network_update_freq` sample steps.
    - Return all collected training metrics for the iteration.

    Returns:
        The results dict from executing the training iteration.
    """
    # Sample n batches from n workers.
    new_sample_batches = synchronous_parallel_sample(
        worker_set=self.workers, concat=False
    )

    for batch in new_sample_batches:
        # Update counters.
        self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
        self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
        # Store new samples in the replay buffer.
        self.local_replay_buffer.add(batch)

    # Sample n batches from the replay buffer until the total number of
    # timesteps reaches `train_batch_size`.
    train_batch = sample_min_n_steps_from_buffer(
        replay_buffer=self.local_replay_buffer,
        min_steps=self.config["train_batch_size"],
        count_by_agent_steps=self._by_agent_steps,
    )
    if train_batch is None:
        return {}

    # Learn on the training batch.
    # Use simple optimizer (only for multi-agent or tf-eager; all other
    # cases should use the multi-GPU optimizer, even if only using 1 GPU).
    if self.config.get("simple_optimizer") is True:
        train_results = train_one_step(self, train_batch)
    else:
        train_results = multi_gpu_train_one_step(self, train_batch)

    # TODO: Move training steps counter update outside of `train_one_step()` method.
    # # Update train step counters.
    # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
    # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    # Update target network every `target_network_update_freq` sample steps.
    cur_ts = self._counters[
        NUM_AGENT_STEPS_SAMPLED if self._by_agent_steps else NUM_ENV_STEPS_SAMPLED
    ]
    last_update = self._counters[LAST_TARGET_UPDATE_TS]
    if cur_ts - last_update >= self.config["target_network_update_freq"]:
        to_update = self.workers.local_worker().get_policies_to_train()
        self.workers.local_worker().foreach_policy_to_train(
            lambda p, pid: pid in to_update and p.update_target()
        )
        self._counters[NUM_TARGET_UPDATES] += 1
        self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

    # Update weights and global_vars - after learning on the local worker -
    # on all remote workers.
    global_vars = {
        "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
    }
    with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
        self.workers.sync_weights(global_vars=global_vars)

    # Return all collected metrics for the iteration.
    return train_results
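# NOTE: Hedged sketch only. `sample_min_n_steps_from_buffer` above keeps drawing
# from the replay buffer until at least `min_steps` timesteps are collected (or
# returns None if nothing can be sampled yet). The toy version below captures
# that contract with a plain list-of-batches buffer; it is an illustration, not
# RLlib's utility.
import random
from typing import List, Optional


def _sketch_sample_min_n_steps(
    buffer: List[List[int]], min_steps: int
) -> Optional[List[int]]:
    # `buffer` is a toy replay buffer: a list of stored batches, each a list of
    # timesteps. Real buffers count env- or agent-steps via the sampled batch.
    if not buffer:
        return None  # Learning has not started; nothing stored yet.
    collected: List[int] = []
    while len(collected) < min_steps:
        collected.extend(random.choice(buffer))  # draw one stored batch at a time
    return collected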