def __call__(self, item: Tuple[ModelGradients, int]) -> None:
    if not isinstance(item, tuple) or len(item) != 2:
        raise ValueError(
            "Input must be a tuple of (grad_dict, count), got {}".format(
                item))
    gradients, count = item
    metrics = _get_shared_metrics()
    metrics.counters[STEPS_TRAINED_COUNTER] += count

    apply_timer = metrics.timers[APPLY_GRADS_TIMER]
    with apply_timer:
        self.workers.local_worker().apply_gradients(gradients)
        apply_timer.push_units_processed(count)

    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())

    if self.update_all:
        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(self.workers.local_worker().get_weights(
                    self.policies))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())
    else:
        if metrics.current_actor is None:
            raise ValueError(
                "Could not find actor to update. When "
                "update_all=False, `current_actor` must be set "
                "in the iterator context.")
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = self.workers.local_worker().get_weights(
                self.policies)
            metrics.current_actor.set_weights.remote(
                weights, _get_global_vars())

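# Hypothetical usage sketch (not part of the source above): an op with this
# __call__ signature is assumed to terminate an asynchronous-gradients
# execution plan, consuming (gradients, count) tuples produced upstream.
# AsyncGradients, ApplyGradients and StandardMetricsReporting are taken from
# ray.rllib.execution (Ray 1.x-era API); exact constructor arguments may
# differ between RLlib versions.
def async_gradients_execution_plan(workers, config):
    from ray.rllib.execution.rollout_ops import AsyncGradients
    from ray.rllib.execution.train_ops import ApplyGradients
    from ray.rllib.execution.metric_ops import StandardMetricsReporting

    # Compute gradients asynchronously on the rollout workers ...
    grads = AsyncGradients(workers)
    # ... and apply each (grad, count) tuple on the local (learner) worker,
    # pushing fresh weights only to the actor that produced the gradient.
    train_op = grads.for_each(ApplyGradients(workers, update_all=False))
    return StandardMetricsReporting(train_op, workers, config)
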
def __call__(self,
             batch: SampleBatchType) -> (SampleBatchType, List[dict]):
    _check_sample_batch_type(batch)
    metrics = _get_shared_metrics()
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
            lw = self.workers.local_worker()
            info = do_minibatch_sgd(
                batch, {pid: lw.get_policy(pid) for pid in self.policies},
                lw, self.num_sgd_iter, self.sgd_minibatch_size, [])
            # TODO(ekl) shouldn't be returning learner stats directly here
            # TODO(sven): Skips `custom_metrics` key from on_learn_on_batch
            #  callback (shouldn't).
            metrics.info[LEARNER_INFO] = info
        else:
            info = self.workers.local_worker().learn_on_batch(batch)
            metrics.info[LEARNER_INFO] = extract_stats(
                info, LEARNER_STATS_KEY)
            metrics.info["custom_metrics"] = extract_stats(
                info, "custom_metrics")
        learn_timer.push_units_processed(batch.count)
    metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
    # Update weights - after learning on the local worker - on all remote
    # workers.
    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(self.workers.local_worker().get_weights(
                self.policies))
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())
    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return batch, info

def __call__(self,
             batch: SampleBatchType) -> (SampleBatchType, List[dict]):
    _check_sample_batch_type(batch)
    metrics = _get_shared_metrics()
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
            w = self.workers.local_worker()
            info = do_minibatch_sgd(
                batch, {p: w.get_policy(p) for p in self.policies}, w,
                self.num_sgd_iter, self.sgd_minibatch_size, [])
            # TODO(ekl) shouldn't be returning learner stats directly here
            metrics.info[LEARNER_INFO] = info
        else:
            info = self.workers.local_worker().learn_on_batch(batch)
            metrics.info[LEARNER_INFO] = get_learner_stats(info)
        learn_timer.push_units_processed(batch.count)
    metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(self.workers.local_worker().get_weights(
                self.policies))
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())
    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return batch, info

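# Hypothetical usage sketch (not part of the source above): the two __call__
# variants above are assumed to belong to a TrainOneStep-style op used in a
# synchronous execution plan. ParallelRollouts, ConcatBatches, TrainOneStep
# and StandardMetricsReporting are taken from ray.rllib.execution
# (Ray 1.x-era API); signatures may differ by version.
def bulk_sync_execution_plan(workers, config):
    from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
    from ray.rllib.execution.train_ops import TrainOneStep
    from ray.rllib.execution.metric_ops import StandardMetricsReporting

    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    train_op = rollouts \
        .combine(ConcatBatches(min_batch_size=config["train_batch_size"])) \
        .for_each(TrainOneStep(workers))
    return StandardMetricsReporting(train_op, workers, config)
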
def __call__(self, item):
    actor, batch = item
    self.steps_since_broadcast += 1
    if (self.steps_since_broadcast >= self.broadcast_interval
            and self.learner_thread.weights_updated):
        self.weights = ray.put(self.workers.local_worker().get_weights())
        self.steps_since_broadcast = 0
        self.learner_thread.weights_updated = False
        # Update metrics.
        metrics = _get_shared_metrics()
        metrics.counters["num_weight_broadcasts"] += 1
    actor.set_weights.remote(self.weights, _get_global_vars())
    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())

def __call__(self, item):
    actor, (info, samples, training_steps) = item
    metrics = _get_shared_metrics()
    metrics.counters[STEPS_TRAINED_COUNTER] += training_steps
    metrics.counters[STEPS_SAMPLED_COUNTER] += samples
    self.counters[actor] += 1
    metrics.counters[
        f"WorkerIteration/Worker{self.worker_idx[actor]}"] += 1
    global_vars = _get_global_vars()
    self.workers.local_worker().set_global_vars(global_vars)
    actor.set_global_vars.remote(global_vars)

    if self.counters[actor] % self.broadcast_interval == 0:
        metrics.counters["num_weight_broadcasts"] += 1
        with metrics.timers[WORKER_UPDATE_TIMER]:
            for pid, gw in self.global_weights.items():

                def update_worker(w, alpha):
                    return w.policy_map[pid].easgd_update(gw, alpha)

                diff = ray.get(
                    actor.apply.remote(update_worker, self.alpha))
                self.global_weights[pid] = EASGDUpdate.easgd_add(
                    gw, diff, self.alpha)

    return info

def __call__(self, item: Tuple[ActorHandle, SampleBatchType]):
    actor, batch = item
    self.steps_since_update[actor] += batch.count
    if self.steps_since_update[actor] >= self.max_weight_sync_delay:
        # Note that it's important to pull new weights once
        # updated to avoid excessive correlation between actors.
        if self.weights is None or self.learner_thread.weights_updated:
            self.learner_thread.weights_updated = False
            self.weights = ray.put(
                self.workers.local_worker().get_weights())
        actor.set_weights.remote(self.weights, _get_global_vars())
        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())
        self.steps_since_update[actor] = 0
        # Update metrics.
        metrics = _get_shared_metrics()
        metrics.counters["num_weight_syncs"] += 1

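# Standalone sketch in plain Ray (not RLlib) of the weight-broadcast pattern
# used above: the learner serializes its weights once with ray.put() and
# hands the same ObjectRef to every rollout actor, so the payload is stored
# in the object store once instead of being re-pickled per actor. The
# RolloutActor class and the weights dict are hypothetical.
import ray

@ray.remote
class RolloutActor:
    def __init__(self):
        self.weights = None
        self.global_vars = None

    def set_weights(self, weights, global_vars):
        # `weights` arrives already resolved from the shared ObjectRef.
        self.weights = weights
        self.global_vars = global_vars

if __name__ == "__main__":
    ray.init()
    actors = [RolloutActor.remote() for _ in range(4)]
    weights_ref = ray.put({"default_policy": [0.1, 0.2, 0.3]})
    ray.get([
        a.set_weights.remote(weights_ref, {"timestep": 1000}) for a in actors
    ])
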
def __call__(self, batch: SampleBatchType) -> List[dict]:
    _check_sample_batch_type(batch)
    metrics = LocalIterator.get_metrics()
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        info = self.workers.local_worker().learn_on_batch(batch)
        learn_timer.push_units_processed(batch.count)
    metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
    metrics.info[LEARNER_INFO] = get_learner_stats(info)
    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(self.workers.local_worker().get_weights())
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())
    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return info

def __call__(self,
             batch: SampleBatchType) -> (SampleBatchType, List[dict]):
    _check_sample_batch_type(batch)
    metrics = _get_shared_metrics()
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
    lw = self.local_worker
    with learn_timer:
        # Subsample minibatches (size=`sgd_minibatch_size`) from the
        # train batch and loop through train batch `num_sgd_iter` times.
        if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
            learner_info = do_minibatch_sgd(
                batch,
                {
                    pid: lw.get_policy(pid)
                    for pid in self.policies
                    or lw.get_policies_to_train(batch)
                },
                lw,
                self.num_sgd_iter,
                self.sgd_minibatch_size,
                [],
            )
        # Single update step using train batch.
        else:
            learner_info = lw.learn_on_batch(batch)

        metrics.info[LEARNER_INFO] = learner_info
        learn_timer.push_units_processed(batch.count)
    metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
    metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = batch.count
    if isinstance(batch, MultiAgentBatch):
        metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += batch.agent_steps()
    # Update weights - after learning on the local worker - on all remote
    # workers.
    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(
                lw.get_weights(self.policies
                               or lw.get_policies_to_train(batch)))
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())
    # Also update global vars of the local worker.
    lw.set_global_vars(_get_global_vars())
    return batch, learner_info

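# Schematic (hypothetical helper, not RLlib's do_minibatch_sgd) of the
# minibatch-SGD loop taken in the first branch above: reshuffle the train
# batch each epoch, slice it into `minibatch_size` chunks, and call a
# learn-on-batch function on every chunk for `num_sgd_iter` epochs.
import numpy as np

def minibatch_sgd_sketch(train_batch, learn_on_batch, num_sgd_iter,
                         minibatch_size):
    """train_batch: dict of equal-length numpy arrays (a flattened batch).
    learn_on_batch: callable taking a minibatch dict, returning a stats dict.
    """
    n = len(next(iter(train_batch.values())))
    stats = {}
    for _ in range(num_sgd_iter):           # epochs over the same train batch
        perm = np.random.permutation(n)     # reshuffle each epoch
        for start in range(0, n, minibatch_size):
            idx = perm[start:start + minibatch_size]
            minibatch = {k: v[idx] for k, v in train_batch.items()}
            stats = learn_on_batch(minibatch)
    return stats
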
def __call__(self, item):
    actor, (updates, info, samples, training_steps) = item
    if self.counters[actor] > min(self.counters.values()) + 50:
        return {}
    lw = self.workers.local_worker()
    metrics = _get_shared_metrics()
    metrics.counters[STEPS_TRAINED_COUNTER] += training_steps
    metrics.counters[STEPS_SAMPLED_COUNTER] += samples
    self.counters[actor] += 1
    metrics.counters[
        f"WorkerIteration/Worker{self.worker_idx[actor]}"] += 1
    global_vars = _get_global_vars()
    lw.set_global_vars(global_vars)
    actor.set_global_vars.remote(global_vars)

    with metrics.timers[WORKER_UPDATE_TIMER]:
        for pid, update in updates.items():

            def sync_update(w, update):
                w.policy_map[pid].asp_sync_updates(update)

            update, num_significant = update
            if lw != actor:
                sync_update(lw, update)
            if self.workers.remote_workers():
                for e in self.workers.remote_workers():
                    if e != actor:
                        e.apply.remote(sync_update, update)
            metrics.counters[
                "significant_weight_updates"] += num_significant

    return info

def __call__(self,
             samples: SampleBatchType) -> (SampleBatchType, List[dict]):
    _check_sample_batch_type(samples)

    # Handle everything as if multiagent
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                  samples.count)

    metrics = _get_shared_metrics()
    load_timer = metrics.timers[LOAD_BATCH_TIMER]
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]

    # Load data into GPUs.
    with load_timer:
        num_loaded_tuples = {}
        for policy_id, batch in samples.policy_batches.items():
            # Not a policy-to-train.
            if policy_id not in self.policies:
                continue

            # Decompress SampleBatch, in case some columns are compressed.
            batch.decompress_if_needed()

            policy = self.workers.local_worker().get_policy(policy_id)
            policy._debug_vars()
            tuples = policy._get_loss_inputs_dict(
                batch, shuffle=self.shuffle_sequences)
            data_keys = list(policy._loss_input_dict_no_rnn.values())
            if policy._state_inputs:
                state_keys = policy._state_inputs + [policy._seq_lens]
            else:
                state_keys = []
            num_loaded_tuples[policy_id] = (
                self.optimizers[policy_id].load_data(
                    self.sess, [tuples[k] for k in data_keys],
                    [tuples[k] for k in state_keys]))

    # Execute minibatch SGD on loaded data.
    with learn_timer:
        fetches = {}
        for policy_id, tuples_per_device in num_loaded_tuples.items():
            optimizer = self.optimizers[policy_id]
            num_batches = max(
                1,
                int(tuples_per_device) // int(self.per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for _ in range(self.num_sgd_iter):
                permutation = np.random.permutation(num_batches)
                batch_fetches_all_towers = []
                for batch_index in range(num_batches):
                    batch_fetches = optimizer.optimize(
                        self.sess, permutation[batch_index] *
                        self.per_device_batch_size)

                    batch_fetches_all_towers.append(
                        tree.map_structure_with_path(
                            lambda p, *s: self._all_tower_reduce(p, *s),
                            *(batch_fetches["tower_{}".format(tower_num)]
                              for tower_num in range(len(self.devices)))))

            # Reduce mean across all minibatch SGD steps (axis=0 to keep
            # all shapes as-is).
            fetches[policy_id] = tree.map_structure(
                lambda *s: np.nanmean(s, axis=0),
                *batch_fetches_all_towers)

    load_timer.push_units_processed(samples.count)
    learn_timer.push_units_processed(samples.count)

    metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
    metrics.info[LEARNER_INFO] = fetches
    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(self.workers.local_worker().get_weights(
                self.policies))
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())
    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return samples, fetches

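# Minimal illustration (assuming the `dm-tree` package, imported as `tree`)
# of the nanmean reduction applied to the per-tower / per-minibatch stats
# dicts above. The two tower dicts below are made up for the example.
import numpy as np
import tree  # dm-tree

tower_0 = {"learner_stats": {"policy_loss": 0.4, "vf_loss": 1.2}}
tower_1 = {"learner_stats": {"policy_loss": 0.6, "vf_loss": 0.8}}

reduced = tree.map_structure(lambda *vals: float(np.nanmean(vals)),
                             tower_0, tower_1)
# -> {"learner_stats": {"policy_loss": 0.5, "vf_loss": 1.0}}
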
def __call__(self,
             samples: SampleBatchType) -> (SampleBatchType, List[dict]):
    _check_sample_batch_type(samples)

    # Handle everything as if multiagent
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({
            DEFAULT_POLICY_ID: samples
        }, samples.count)

    metrics = _get_shared_metrics()
    load_timer = metrics.timers[LOAD_BATCH_TIMER]
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
    with load_timer:
        # (1) Load data into GPUs.
        num_loaded_tuples = {}
        for policy_id, batch in samples.policy_batches.items():
            if policy_id not in self.policies:
                continue

            policy = self.workers.local_worker().get_policy(policy_id)
            policy._debug_vars()
            tuples = policy._get_loss_inputs_dict(
                batch, shuffle=self.shuffle_sequences)
            data_keys = list(policy._loss_input_dict_no_rnn.values())
            if policy._state_inputs:
                state_keys = policy._state_inputs + [policy._seq_lens]
            else:
                state_keys = []
            num_loaded_tuples[policy_id] = (
                self.optimizers[policy_id].load_data(
                    self.sess, [tuples[k] for k in data_keys],
                    [tuples[k] for k in state_keys]))

    with learn_timer:
        # (2) Execute minibatch SGD on loaded data.
        fetches = {}
        for policy_id, tuples_per_device in num_loaded_tuples.items():
            optimizer = self.optimizers[policy_id]
            num_batches = max(
                1,
                int(tuples_per_device) // int(self.per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))

            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    batch_fetches = optimizer.optimize(
                        self.sess, permutation[batch_index] *
                        self.per_device_batch_size)
                    for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                        iter_extra_fetches[k].append(v)
                if logger.getEffectiveLevel() <= logging.DEBUG:
                    avg = averaged(iter_extra_fetches)
                    logger.debug("{} {}".format(i, avg))
            fetches[policy_id] = averaged(iter_extra_fetches, axis=0)

    load_timer.push_units_processed(samples.count)
    learn_timer.push_units_processed(samples.count)

    metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
    metrics.info[LEARNER_INFO] = fetches
    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(self.workers.local_worker().get_weights(
                self.policies))
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())
    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return samples, fetches

def __call__(self,
             samples: SampleBatchType) -> (SampleBatchType, List[dict]):
    _check_sample_batch_type(samples)

    # Handle everything as if multi agent.
    samples = samples.as_multi_agent()

    metrics = _get_shared_metrics()
    load_timer = metrics.timers[LOAD_BATCH_TIMER]
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]

    # Load data into GPUs.
    with load_timer:
        num_loaded_samples = {}
        for policy_id, batch in samples.policy_batches.items():
            # Not a policy-to-train.
            if not self.local_worker.is_policy_to_train(policy_id, samples):
                continue

            # Decompress SampleBatch, in case some columns are compressed.
            batch.decompress_if_needed()

            # Load the entire train batch into the Policy's only buffer
            # (idx=0). Policies only have >1 buffers, if we are training
            # asynchronously.
            num_loaded_samples[policy_id] = self.local_worker.policy_map[
                policy_id].load_batch_into_buffer(batch, buffer_index=0)

    # Execute minibatch SGD on loaded data.
    with learn_timer:
        # Use LearnerInfoBuilder as a unified way to build the final
        # results dict from `learn_on_loaded_batch` call(s).
        # This makes sure results dicts always have the same structure
        # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
        # tf vs torch).
        learner_info_builder = LearnerInfoBuilder(
            num_devices=len(self.devices))
        for policy_id, samples_per_device in num_loaded_samples.items():
            policy = self.local_worker.policy_map[policy_id]
            num_batches = max(
                1,
                int(samples_per_device) // int(self.per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for _ in range(self.num_sgd_iter):
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    # Learn on the pre-loaded data in the buffer.
                    # Note: For minibatch SGD, the data is an offset into
                    # the pre-loaded entire train batch.
                    results = policy.learn_on_loaded_batch(
                        permutation[batch_index] *
                        self.per_device_batch_size,
                        buffer_index=0,
                    )

                    learner_info_builder.add_learn_on_batch_results(
                        results, policy_id)

        # Tower reduce and finalize results.
        learner_info = learner_info_builder.finalize()

    load_timer.push_units_processed(samples.count)
    learn_timer.push_units_processed(samples.count)

    metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
    metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = samples.count
    metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
    metrics.info[LEARNER_INFO] = learner_info
    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(self.workers.local_worker().get_weights(
                self.local_worker.get_policies_to_train()))
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())
    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return samples, learner_info

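# Hypothetical numeric sketch of the "load once, then index by offset"
# pattern above: the whole train batch is loaded into a per-policy device
# buffer a single time, and each SGD step only passes an integer offset into
# that buffer (here simulated with a numpy array).
import numpy as np

per_device_batch_size = 128
loaded = np.arange(1024)                      # stands in for the loaded data
samples_per_device = len(loaded)
num_sgd_iter = 2

num_batches = max(1, samples_per_device // per_device_batch_size)  # -> 8
for _ in range(num_sgd_iter):
    for b in np.random.permutation(num_batches):
        offset = b * per_device_batch_size
        minibatch = loaded[offset:offset + per_device_batch_size]
        # ... one SGD step would run here on `minibatch`, per device ...
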
def update_worker_global_vars(item):
    global_vars = _get_global_vars()
    for w in workers.remote_workers():
        w.set_global_vars.remote(global_vars)
    return item

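# Hypothetical placement of the helper above: because it returns `item`
# unchanged, it can be chained as a pass-through step on a LocalIterator
# pipeline (`workers` is captured from the enclosing scope), e.g.:
#
#     train_op = train_op.for_each(update_worker_global_vars)
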
def __call__(self,
             samples: SampleBatchType) -> (SampleBatchType, List[dict]):
    _check_sample_batch_type(samples)

    # Handle everything as if multi agent.
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                  samples.count)

    metrics = _get_shared_metrics()
    load_timer = metrics.timers[LOAD_BATCH_TIMER]
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]

    # Load data into GPUs.
    with load_timer:
        num_loaded_tuples = {}
        for policy_id, batch in samples.policy_batches.items():
            # Not a policy-to-train.
            if policy_id not in self.local_worker.policies_to_train:
                continue

            # Decompress SampleBatch, in case some columns are compressed.
            batch.decompress_if_needed()

            # Load the entire train batch into the Policy's only buffer
            # (idx=0). Policies only have >1 buffers, if we are training
            # asynchronously.
            num_loaded_tuples[policy_id] = self.local_worker.policy_map[
                policy_id].load_batch_into_buffer(batch, buffer_index=0)

    # Execute minibatch SGD on loaded data.
    with learn_timer:
        fetches = {}
        for policy_id, tuples_per_device in num_loaded_tuples.items():
            policy = self.local_worker.policy_map[policy_id]
            num_batches = max(
                1,
                int(tuples_per_device) // int(self.per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))

            batch_fetches_all_towers = []
            for _ in range(self.num_sgd_iter):
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    # Learn on the pre-loaded data in the buffer.
                    # Note: For minibatch SGD, the data is an offset into
                    # the pre-loaded entire train batch.
                    batch_fetches = policy.learn_on_loaded_batch(
                        permutation[batch_index] *
                        self.per_device_batch_size,
                        buffer_index=0)

                    # No towers: Single CPU.
                    if "tower_0" not in batch_fetches:
                        batch_fetches_all_towers.append(batch_fetches)
                    else:
                        batch_fetches_all_towers.append(
                            tree.map_structure_with_path(
                                lambda p, *s: all_tower_reduce(p, *s),
                                *(batch_fetches["tower_{}".format(tower_num)]
                                  for tower_num in range(
                                      len(self.devices)))))

            # Reduce mean across all minibatch SGD steps (axis=0 to keep
            # all shapes as-is).
            fetches[policy_id] = tree.map_structure(
                lambda *s: np.nanmean(s, axis=0),
                *batch_fetches_all_towers)

    load_timer.push_units_processed(samples.count)
    learn_timer.push_units_processed(samples.count)

    metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
    metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
    metrics.info[LEARNER_INFO] = fetches
    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(self.workers.local_worker().get_weights(
                self.local_worker.policies_to_train))
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())
    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return samples, fetches