def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
    _check_sample_batch_type(batch)
    self.buffer.append(batch)

    if self.count_steps_by == "env_steps":
        self.count += batch.count
    else:
        assert isinstance(batch, MultiAgentBatch), \
            "`count_steps_by=agent_steps` only allowed in multi-agent " \
            "environments!"
        self.count += batch.agent_steps()

    if self.count >= self.min_batch_size:
        if self.count > self.min_batch_size * 2:
            logger.info("Collected more training samples than expected "
                        "(actual={}, expected={}). ".format(
                            self.count, self.min_batch_size) +
                        "This may be because you have many workers or "
                        "long episodes in 'complete_episodes' batch mode.")
        out = SampleBatch.concat_samples(self.buffer)
        timer = _get_shared_metrics().timers[SAMPLE_TIMER]
        timer.push(time.perf_counter() - self.batch_start_time)
        timer.push_units_processed(self.count)
        self.batch_start_time = None
        self.buffer = []
        self.count = 0
        return [out]
    return []
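# Usage sketch (hedged): an operator like the above is typically chained
# onto a rollout iterator with `combine()`, merging many small sample
# batches into one train batch of at least `min_batch_size` steps before
# training. `ParallelRollouts`, `workers`, and `config` below are assumed
# names for illustration, not taken from this file.
#
#     rollouts = ParallelRollouts(workers, mode="bulk_sync")
#     train_batches = rollouts.combine(
#         ConcatBatches(min_batch_size=config["train_batch_size"]))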
def __call__(self, batch: SampleBatchType) -> (SampleBatchType, List[dict]):
    _check_sample_batch_type(batch)
    metrics = _get_shared_metrics()
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
    lw = self.local_worker
    with learn_timer:
        # Subsample minibatches (size=`sgd_minibatch_size`) from the
        # train batch and loop through train batch `num_sgd_iter` times.
        if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
            learner_info = do_minibatch_sgd(
                batch,
                {
                    pid: lw.get_policy(pid)
                    for pid in self.policies
                    or lw.get_policies_to_train(batch)
                },
                lw,
                self.num_sgd_iter,
                self.sgd_minibatch_size,
                [],
            )
        # Single update step using train batch.
        else:
            learner_info = lw.learn_on_batch(batch)

    metrics.info[LEARNER_INFO] = learner_info
    learn_timer.push_units_processed(batch.count)

    metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
    metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = batch.count
    if isinstance(batch, MultiAgentBatch):
        metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += batch.agent_steps()

    # Update weights - after learning on the local worker - on all remote
    # workers.
    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(
                lw.get_weights(self.policies
                               or lw.get_policies_to_train(batch)))
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())

    # Also update global vars of the local worker.
    lw.set_global_vars(_get_global_vars())
    return batch, learner_info
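# Usage sketch (hedged): `TrainOneStep` is normally applied per train batch
# via `for_each()`; each call runs either a single `learn_on_batch()` update
# or `num_sgd_iter` epochs of minibatch SGD, then broadcasts the updated
# weights to all remote workers. `train_batches` and `workers` are assumed
# names, continuing the sketch above.
#
#     train_op = train_batches.for_each(
#         TrainOneStep(workers, num_sgd_iter=1, sgd_minibatch_size=0))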
def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
    if not batch:
        return []
    _check_sample_batch_type(batch)

    if self.count_steps_by == "env_steps":
        size = batch.count
    else:
        assert isinstance(batch, MultiAgentBatch), (
            "`count_steps_by=agent_steps` only allowed in multi-agent "
            "environments!")
        size = batch.agent_steps()

    # Incoming batch is an empty dummy batch -> Ignore.
    # Possibly produced automatically by a PolicyServer to unblock
    # an external env waiting for inputs from unresponsive/disconnected
    # client(s).
    if size == 0:
        return []

    self.count += size
    self.buffer.append(batch)

    if self.count >= self.min_batch_size:
        if self.count > self.min_batch_size * 2:
            logger.info("Collected more training samples than expected "
                        "(actual={}, expected={}). ".format(
                            self.count, self.min_batch_size) +
                        "This may be because you have many workers or "
                        "long episodes in 'complete_episodes' batch mode.")
        out = SampleBatch.concat_samples(self.buffer)

        perf_counter = time.perf_counter()
        if self.using_iterators:
            timer = _get_shared_metrics().timers[SAMPLE_TIMER]
            timer.push(perf_counter - self.last_batch_time)
            timer.push_units_processed(self.count)

        self.last_batch_time = perf_counter
        self.buffer = []
        self.count = 0
        return [out]
    return []
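# Worked example (hedged): the two counting modes can reach `min_batch_size`
# at very different rates. Numbers below are illustrative only.
#
#     # A MultiAgentBatch from a 2-agent env over 100 env steps:
#     #   count_steps_by="env_steps"   -> size = batch.count         = 100
#     #   count_steps_by="agent_steps" -> size = batch.agent_steps() = 200
#     # (assuming both agents act on every env step)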
def __call__(self, batch: SampleBatchType) -> (SampleBatchType, List[dict]):
    _check_sample_batch_type(batch)
    metrics = _get_shared_metrics()
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
            lw = self.workers.local_worker()
            info = do_minibatch_sgd(
                batch, {
                    pid: lw.get_policy(pid)
                    for pid in self.policies
                    or self.local_worker.policies_to_train
                }, lw, self.num_sgd_iter, self.sgd_minibatch_size, [])
            # TODO(ekl) shouldn't be returning learner stats directly here
            # TODO(sven): Skips `custom_metrics` key from on_learn_on_batch
            #  callback (shouldn't).
            metrics.info[LEARNER_INFO] = info
        else:
            info = self.workers.local_worker().learn_on_batch(batch)
            metrics.info[LEARNER_INFO] = extract_stats(
                info, LEARNER_STATS_KEY)
            metrics.info["custom_metrics"] = extract_stats(
                info, "custom_metrics")
        learn_timer.push_units_processed(batch.count)

    metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
    if isinstance(batch, MultiAgentBatch):
        metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += \
            batch.agent_steps()

    # Update weights - after learning on the local worker - on all remote
    # workers.
    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(self.workers.local_worker().get_weights(
                self.policies or self.local_worker.policies_to_train))
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())

    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return batch, info
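# Worked example (hedged): the minibatch-SGD branch performs roughly
# `num_sgd_iter * (batch.count // sgd_minibatch_size)` learner updates per
# train batch. With PPO-style settings, purely as an illustration:
#
#     # batch.count=4000, sgd_minibatch_size=128, num_sgd_iter=30
#     # -> 30 * (4000 // 128) = 30 * 31 = 930 minibatch updates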
def __call__(self, samples: SampleBatchType) -> (SampleBatchType, List[dict]):
    _check_sample_batch_type(samples)

    # Handle everything as if multi-agent.
    samples = samples.as_multi_agent()

    metrics = _get_shared_metrics()
    load_timer = metrics.timers[LOAD_BATCH_TIMER]
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]

    # Load data into GPUs.
    with load_timer:
        num_loaded_samples = {}
        for policy_id, batch in samples.policy_batches.items():
            # Not a policy-to-train.
            if not self.local_worker.is_policy_to_train(policy_id, samples):
                continue

            # Decompress SampleBatch, in case some columns are compressed.
            batch.decompress_if_needed()

            # Load the entire train batch into the Policy's only buffer
            # (idx=0). Policies only have >1 buffer if we are training
            # asynchronously.
            num_loaded_samples[policy_id] = self.local_worker.policy_map[
                policy_id].load_batch_into_buffer(batch, buffer_index=0)

    # Execute minibatch SGD on loaded data.
    with learn_timer:
        # Use LearnerInfoBuilder as a unified way to build the final
        # results dict from `learn_on_loaded_batch` call(s).
        # This makes sure results dicts always have the same structure
        # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
        # tf vs torch).
        learner_info_builder = LearnerInfoBuilder(
            num_devices=len(self.devices))
        for policy_id, samples_per_device in num_loaded_samples.items():
            policy = self.local_worker.policy_map[policy_id]
            num_batches = max(
                1, int(samples_per_device) // int(self.per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for _ in range(self.num_sgd_iter):
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    # Learn on the pre-loaded data in the buffer.
                    # Note: For minibatch SGD, the data is an offset into
                    # the pre-loaded entire train batch.
                    results = policy.learn_on_loaded_batch(
                        permutation[batch_index] * self.per_device_batch_size,
                        buffer_index=0,
                    )
                    learner_info_builder.add_learn_on_batch_results(
                        results, policy_id)

        # Tower reduce and finalize results.
        learner_info = learner_info_builder.finalize()

    load_timer.push_units_processed(samples.count)
    learn_timer.push_units_processed(samples.count)

    metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
    metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = samples.count
    metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
    metrics.info[LEARNER_INFO] = learner_info

    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(self.workers.local_worker().get_weights(
                self.local_worker.get_policies_to_train()))
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())

    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return samples, learner_info
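# Worked example (hedged): the buffer is consumed in fixed per-device
# minibatches. Assuming `per_device_batch_size = sgd_minibatch_size //
# len(devices)` (an assumption, not verified in this file):
#
#     # 2 GPUs, sgd_minibatch_size=500 -> per_device_batch_size = 250
#     # A 4000-sample train batch -> ~2000 samples loaded per device
#     # -> num_batches = max(1, 2000 // 250) = 8 offsets, each visited
#     #    once per SGD iteration in a fresh random permutation.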
def __call__(self, samples: SampleBatchType) -> (SampleBatchType, List[dict]):
    _check_sample_batch_type(samples)

    # Handle everything as if multi-agent.
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                  samples.count)

    metrics = _get_shared_metrics()
    load_timer = metrics.timers[LOAD_BATCH_TIMER]
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]

    # Load data into GPUs.
    with load_timer:
        num_loaded_tuples = {}
        for policy_id, batch in samples.policy_batches.items():
            # Not a policy-to-train.
            if policy_id not in self.policies:
                continue

            # Decompress SampleBatch, in case some columns are compressed.
            batch.decompress_if_needed()

            policy = self.workers.local_worker().get_policy(policy_id)
            policy._debug_vars()
            tuples = policy._get_loss_inputs_dict(
                batch, shuffle=self.shuffle_sequences)
            data_keys = list(policy._loss_input_dict_no_rnn.values())
            if policy._state_inputs:
                state_keys = policy._state_inputs + [policy._seq_lens]
            else:
                state_keys = []
            num_loaded_tuples[policy_id] = (
                self.optimizers[policy_id].load_data(
                    self.sess, [tuples[k] for k in data_keys],
                    [tuples[k] for k in state_keys]))

    # Execute minibatch SGD on loaded data.
    with learn_timer:
        fetches = {}
        for policy_id, tuples_per_device in num_loaded_tuples.items():
            optimizer = self.optimizers[policy_id]
            num_batches = max(
                1, int(tuples_per_device) // int(self.per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            # Collect stats over *all* SGD steps, so the reduction below
            # really averages across every minibatch update rather than
            # only the last epoch.
            batch_fetches_all_towers = []
            for _ in range(self.num_sgd_iter):
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    batch_fetches = optimizer.optimize(
                        self.sess,
                        permutation[batch_index] * self.per_device_batch_size)

                    batch_fetches_all_towers.append(
                        tree.map_structure_with_path(
                            lambda p, *s: all_tower_reduce(p, *s),
                            *(batch_fetches["tower_{}".format(tower_num)]
                              for tower_num in range(len(self.devices)))))

            # Reduce mean across all minibatch SGD steps (axis=0 to keep
            # all shapes as-is).
            fetches[policy_id] = tree.map_structure(
                lambda *s: np.nanmean(s, axis=0), *batch_fetches_all_towers)

    load_timer.push_units_processed(samples.count)
    learn_timer.push_units_processed(samples.count)

    metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
    metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
    metrics.info[LEARNER_INFO] = fetches

    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(self.workers.local_worker().get_weights(
                self.policies))
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())

    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return samples, fetches
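# Mini-example (hedged) of the final reduction above: `tree.map_structure`
# with `np.nanmean` averages each leaf across the per-step stats structures
# while leaving leaf shapes intact (requires `numpy` and `dm-tree`).
#
#     import numpy as np
#     import tree  # dm-tree
#     step_stats = [{"policy_loss": 0.5, "vf_loss": 1.0},
#                   {"policy_loss": 1.5, "vf_loss": np.nan}]
#     tree.map_structure(lambda *s: np.nanmean(s, axis=0), *step_stats)
#     # -> {"policy_loss": 1.0, "vf_loss": 1.0}  (NaNs ignored per leaf)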
def __call__(self, samples: SampleBatchType) -> (SampleBatchType, List[dict]):
    _check_sample_batch_type(samples)

    # Handle everything as if multi-agent.
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                  samples.count)

    metrics = _get_shared_metrics()
    load_timer = metrics.timers[LOAD_BATCH_TIMER]
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]

    # Load data into GPUs.
    with load_timer:
        num_loaded_tuples = {}
        for policy_id, batch in samples.policy_batches.items():
            # Not a policy-to-train.
            if policy_id not in self.local_worker.policies_to_train:
                continue

            # Decompress SampleBatch, in case some columns are compressed.
            batch.decompress_if_needed()

            # Load the entire train batch into the Policy's only buffer
            # (idx=0). Policies only have >1 buffer if we are training
            # asynchronously.
            num_loaded_tuples[policy_id] = self.local_worker.policy_map[
                policy_id].load_batch_into_buffer(batch, buffer_index=0)

    # Execute minibatch SGD on loaded data.
    with learn_timer:
        fetches = {}
        for policy_id, tuples_per_device in num_loaded_tuples.items():
            policy = self.local_worker.policy_map[policy_id]
            num_batches = max(
                1, int(tuples_per_device) // int(self.per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))

            batch_fetches_all_towers = []
            for _ in range(self.num_sgd_iter):
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    # Learn on the pre-loaded data in the buffer.
                    # Note: For minibatch SGD, the data is an offset into
                    # the pre-loaded entire train batch.
                    batch_fetches = policy.learn_on_loaded_batch(
                        permutation[batch_index] * self.per_device_batch_size,
                        buffer_index=0)

                    # No towers: Single CPU.
                    if "tower_0" not in batch_fetches:
                        batch_fetches_all_towers.append(batch_fetches)
                    else:
                        batch_fetches_all_towers.append(
                            tree.map_structure_with_path(
                                lambda p, *s: all_tower_reduce(p, *s),
                                *(batch_fetches["tower_{}".format(tower_num)]
                                  for tower_num in range(len(self.devices)))))

            # Reduce mean across all minibatch SGD steps (axis=0 to keep
            # all shapes as-is).
            fetches[policy_id] = tree.map_structure(
                lambda *s: np.nanmean(s, axis=0), *batch_fetches_all_towers)

    load_timer.push_units_processed(samples.count)
    learn_timer.push_units_processed(samples.count)

    metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
    metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
    metrics.info[LEARNER_INFO] = fetches

    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(self.workers.local_worker().get_weights(
                self.local_worker.policies_to_train))
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())

    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return samples, fetches
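# Mini-example (hedged): `learn_on_loaded_batch` receives a row offset into
# the pre-loaded buffer rather than the data itself, so one SGD iteration
# visits every offset exactly once, in random order. Illustrative numbers:
#
#     import numpy as np
#     per_device_batch_size = 250
#     num_batches = 8
#     permutation = np.random.permutation(num_batches)  # e.g. [3, 0, 7, ...]
#     offsets = [i * per_device_batch_size for i in permutation]
#     # offsets covers {0, 250, ..., 1750}, shuffled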