def add(self, batch: SampleBatchType):
    """Splits a SampleBatch into episodes and adds these to the episode buffer.

    Args:
        batch: SampleBatch to be added.
    """
    self.timesteps += batch.count
    episodes = batch.split_by_episode()
    for i, e in enumerate(episodes):
        episodes[i] = self.preprocess_episode(e)
    self.episodes.extend(episodes)
    if len(self.episodes) > self.max_length:
        delta = len(self.episodes) - self.max_length
        # Drop the oldest episodes to stay within `max_length`.
        self.episodes = self.episodes[delta:]
def estimate(self, batch: SampleBatchType) -> List[OffPolicyEstimate]:
    self.check_can_estimate_for(batch)
    estimates = []
    for sub_batch in batch.split_by_episode():
        rewards, old_prob = sub_batch["rewards"], sub_batch["action_prob"]
        new_prob = np.exp(self.action_log_likelihood(sub_batch))

        # Calculate cumulative importance ratios.
        p = []
        for t in range(sub_batch.count):
            if t == 0:
                pt_prev = 1.0
            else:
                pt_prev = p[t - 1]
            p.append(pt_prev * new_prob[t] / old_prob[t])
        for t, v in enumerate(p):
            if t >= len(self.filter_values):
                self.filter_values.append(v)
                self.filter_counts.append(1.0)
            else:
                self.filter_values[t] += v
                self.filter_counts[t] += 1.0

        # Calculate the stepwise weighted IS estimate.
        v_old = 0.0
        v_new = 0.0
        for t in range(sub_batch.count):
            v_old += rewards[t] * self.gamma ** t
            w_t = self.filter_values[t] / self.filter_counts[t]
            v_new += p[t] / w_t * rewards[t] * self.gamma ** t

        estimates.append(
            OffPolicyEstimate(
                self.name,
                {
                    "v_old": v_old,
                    "v_new": v_new,
                    "v_gain": v_new / max(1e-8, v_old),
                },
            )
        )
    return estimates
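# A minimal, self-contained sketch (hypothetical numbers, no RLlib
# dependency) of the cumulative importance-ratio computation in
# `estimate()` above: p[t] is the running product of new_prob/old_prob
# up to timestep t.
import numpy as np

old_prob = np.array([0.5, 0.4, 0.8])  # behavior policy action probs
new_prob = np.array([0.6, 0.5, 0.7])  # target policy action probs
p = np.cumprod(new_prob / old_prob)
# p == [1.2, 1.5, 1.3125], identical to the explicit loop above.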
def __call__(self, batch: SampleBatchType) -> (SampleBatchType, List[dict]):
    _check_sample_batch_type(batch)
    metrics = _get_shared_metrics()
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
    lw = self.local_worker
    with learn_timer:
        # Subsample minibatches (size=`sgd_minibatch_size`) from the
        # train batch and loop through train batch `num_sgd_iter` times.
        if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
            learner_info = do_minibatch_sgd(
                batch,
                {
                    pid: lw.get_policy(pid)
                    for pid in self.policies or lw.get_policies_to_train(batch)
                },
                lw,
                self.num_sgd_iter,
                self.sgd_minibatch_size,
                [],
            )
        # Single update step using train batch.
        else:
            learner_info = lw.learn_on_batch(batch)

    metrics.info[LEARNER_INFO] = learner_info
    learn_timer.push_units_processed(batch.count)
    metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
    metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = batch.count
    if isinstance(batch, MultiAgentBatch):
        metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += batch.agent_steps()

    # Update weights - after learning on the local worker - on all remote
    # workers.
    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(
                lw.get_weights(self.policies or lw.get_policies_to_train(batch))
            )
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())

    # Also update global vars of the local worker.
    lw.set_global_vars(_get_global_vars())
    return batch, learner_info
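# A simplified, dependency-free sketch of the index scheme that
# `do_minibatch_sgd` above encapsulates (names here are illustrative,
# not RLlib's API): cover the train batch `num_sgd_iter` times in
# shuffled minibatches of size `minibatch_size`.
import numpy as np

def minibatch_indices(batch_size, minibatch_size, num_sgd_iter):
    for _ in range(num_sgd_iter):
        perm = np.random.permutation(batch_size)
        for start in range(0, batch_size, minibatch_size):
            yield perm[start : start + minibatch_size]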
def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
    _check_sample_batch_type(batch)
    if self.count_steps_by == "env_steps":
        size = batch.count
    else:
        assert isinstance(batch, MultiAgentBatch), (
            "`count_steps_by=agent_steps` only allowed in multi-agent "
            "environments!"
        )
        size = batch.agent_steps()

    # Incoming batch is an empty dummy batch -> Ignore.
    # Possibly produced automatically by a PolicyServer to unblock
    # an external env waiting for inputs from unresponsive/disconnected
    # client(s).
    if size == 0:
        return []

    self.count += size
    self.buffer.append(batch)

    if self.count >= self.min_batch_size:
        if self.count > self.min_batch_size * 2:
            logger.info(
                "Collected more training samples than expected "
                "(actual={}, expected={}). ".format(self.count, self.min_batch_size)
                + "This may be because you have many workers or "
                "long episodes in 'complete_episodes' batch mode."
            )
        out = SampleBatch.concat_samples(self.buffer)

        perf_counter = time.perf_counter()
        timer = _get_shared_metrics().timers[SAMPLE_TIMER]
        timer.push(perf_counter - self.last_batch_time)
        timer.push_units_processed(self.count)
        self.last_batch_time = perf_counter

        self.buffer = []
        self.count = 0
        return [out]
    return []
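# Usage sketch for the concatenation step above (assumes an RLlib
# version matching these snippets; hypothetical data):
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

b1 = SampleBatch({"obs": np.array([1, 2]), "rewards": np.array([0.0, 1.0])})
b2 = SampleBatch({"obs": np.array([3]), "rewards": np.array([0.5])})
out = SampleBatch.concat_samples([b1, b2])
assert out.count == 3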
def standardize_fields(samples: SampleBatchType, fields: List[str]) -> SampleBatchType:
    """Standardizes the given fields of the given SampleBatch.

    Args:
        samples: The SampleBatch (or MultiAgentBatch) to operate on.
        fields: The list of column names to standardize.

    Returns:
        The batch with the given fields standardized (zero mean, unit variance).
    """
    _check_sample_batch_type(samples)
    wrapped = False

    if isinstance(samples, SampleBatch):
        samples = samples.as_multi_agent()
        wrapped = True

    for policy_id in samples.policy_batches:
        batch = samples.policy_batches[policy_id]
        for field in fields:
            if field in batch:
                batch[field] = standardized(batch[field])

    if wrapped:
        samples = samples.policy_batches[DEFAULT_POLICY_ID]

    return samples
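# For reference, the `standardized` helper used above is essentially the
# following (the epsilon guards against a near-zero std); e.g., PPO runs
# standardize_fields(train_batch, ["advantages"]) before each update.
import numpy as np

def standardized_sketch(array: np.ndarray) -> np.ndarray:
    return (array - array.mean()) / max(1e-4, array.std())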
def _add_to_policy_buffer(self, policy_id: PolicyID, batch: SampleBatchType) -> None:
    if self.replay_sequence_length == 1:
        timeslices = batch.timeslices(1)
    else:
        timeslices = timeslice_along_seq_lens_with_overlap(
            sample_batch=batch,
            zero_pad_max_seq_len=self.replay_sequence_length,
            pre_overlap=self.replay_burn_in,
            zero_init_states=self.replay_zero_init_states,
        )
    for time_slice in timeslices:
        # If the SampleBatch has prio-replay weights, average over these
        # to use as a weight for the entire sequence.
        if "weights" in time_slice and len(time_slice["weights"]):
            weight = np.mean(time_slice["weights"])
        else:
            weight = None
        self.replay_buffers[policy_id].add(time_slice, weight=weight)
def estimate(self, batch: SampleBatchType) -> Dict[str, Any]:
    """Compute off-policy estimates.

    Args:
        batch: The SampleBatch to run off-policy estimation on.

    Returns:
        A dict with the following metrics:
        - v_behavior: The discounted return averaged over episodes in the
          batch.
        - v_behavior_std: The standard deviation corresponding to v_behavior.
        - v_target: The estimated discounted return for `self.policy`,
          averaged over episodes in the batch.
        - v_target_std: The standard deviation corresponding to v_target.
        - v_gain: v_target / max(v_behavior, 1e-8), averaged over episodes.
        - v_gain_std: The standard deviation corresponding to v_gain.
    """
    batch = self.convert_ma_batch_to_sample_batch(batch)
    self.check_action_prob_in_batch(batch)
    estimates = {"v_behavior": [], "v_target": [], "v_gain": []}
    # Calculate Direct Method OPE estimates.
    for episode in batch.split_by_episode():
        rewards = episode["rewards"]
        v_behavior = 0.0
        for t in range(episode.count):
            v_behavior += rewards[t] * self.gamma ** t

        init_step = episode[0:1]
        v_target = self.model.estimate_v(init_step)
        v_target = convert_to_numpy(v_target).item()

        estimates["v_behavior"].append(v_behavior)
        estimates["v_target"].append(v_target)
        estimates["v_gain"].append(v_target / max(v_behavior, 1e-8))
    estimates["v_behavior_std"] = np.std(estimates["v_behavior"])
    estimates["v_behavior"] = np.mean(estimates["v_behavior"])
    estimates["v_target_std"] = np.std(estimates["v_target"])
    estimates["v_target"] = np.mean(estimates["v_target"])
    estimates["v_gain_std"] = np.std(estimates["v_gain"])
    estimates["v_gain"] = np.mean(estimates["v_gain"])
    return estimates
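# Worked numeric check (hypothetical values) for the v_behavior term
# above: a discounted return summed over one episode.
import numpy as np

rewards, gamma = np.array([1.0, 0.0, 2.0]), 0.9
v_behavior = float(np.sum(rewards * gamma ** np.arange(len(rewards))))
# 1.0 * 0.9**0 + 0.0 * 0.9**1 + 2.0 * 0.9**2 = 2.62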
def __call__(self, batch: SampleBatchType) -> (SampleBatchType, List[dict]):
    _check_sample_batch_type(batch)
    metrics = _get_shared_metrics()
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
            lw = self.workers.local_worker()
            info = do_minibatch_sgd(
                batch,
                {
                    pid: lw.get_policy(pid)
                    for pid in self.policies or self.local_worker.policies_to_train
                },
                lw,
                self.num_sgd_iter,
                self.sgd_minibatch_size,
                [],
            )
            # TODO(ekl) shouldn't be returning learner stats directly here
            # TODO(sven): Skips `custom_metrics` key from on_learn_on_batch
            #  callback (shouldn't).
            metrics.info[LEARNER_INFO] = info
        else:
            info = self.workers.local_worker().learn_on_batch(batch)
            metrics.info[LEARNER_INFO] = extract_stats(info, LEARNER_STATS_KEY)
            metrics.info["custom_metrics"] = extract_stats(info, "custom_metrics")
        learn_timer.push_units_processed(batch.count)
    metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
    if isinstance(batch, MultiAgentBatch):
        metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += batch.agent_steps()

    # Update weights - after learning on the local worker - on all remote
    # workers.
    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(
                self.workers.local_worker().get_weights(
                    self.policies or self.local_worker.policies_to_train
                )
            )
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())

    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return batch, info
def add(self, batch: SampleBatchType, weight: float) -> None:
    """Add a batch of experiences.

    Args:
        batch: SampleBatch to add to this buffer's storage.
        weight: The weight of the added sample used in subsequent
            sampling steps.
    """
    idx = self._next_idx

    assert batch.count > 0, batch
    warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

    # Update our timesteps counts.
    self._num_timesteps_added += batch.count
    self._num_timesteps_added_wrap += batch.count

    if self._next_idx >= len(self._storage):
        self._storage.append(batch)
        self._est_size_bytes += batch.size_bytes()
    else:
        self._storage[self._next_idx] = batch

    # Wrap around storage as a circular buffer once we hit capacity.
    if self._num_timesteps_added_wrap >= self.capacity:
        self._eviction_started = True
        self._num_timesteps_added_wrap = 0
        self._next_idx = 0
    else:
        self._next_idx += 1

    # Eviction of older samples has already started (buffer is "full").
    if self._eviction_started:
        self._evicted_hit_stats.push(self._hit_count[self._next_idx])
        self._hit_count[self._next_idx] = 0

    if weight is None:
        weight = self._max_priority
    self._it_sum[idx] = weight ** self._alpha
    self._it_min[idx] = weight ** self._alpha
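# Sketch of how the stored priority translates into a sampling
# distribution (hypothetical values; the buffer itself keeps these in
# segment trees, `_it_sum`/`_it_min`, for O(log n) prefix-sum queries):
import numpy as np

weights, alpha = np.array([1.0, 2.0, 4.0]), 0.6
priorities = weights ** alpha          # as in `weight**self._alpha` above
probs = priorities / priorities.sum()  # P(i) proportional to priority_i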
def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None:
    """Add a SampleBatch of experiences to self._storage.

    An item consists of either one or more timesteps, a sequence or an
    episode. Differs from add() in that it does not consider the storage
    unit or type of batch and simply stores it.

    Args:
        item: The batch to be added.
        **kwargs: Forward compatibility kwargs.
    """
    # Update our timesteps counts.
    self._num_timesteps_added += item.count
    self._num_timesteps_added_wrap += item.count
    # Update add counts.
    self._num_add_calls += 1

    if self._num_timesteps_added < self.capacity:
        self._storage.append(item)
        self._est_size_bytes += item.size_bytes()
    else:
        # Eviction of older samples has already started (buffer is "full").
        self._eviction_started = True
        idx = random.randint(0, self._num_add_calls - 1)
        if idx < len(self._storage):
            self._num_evicted += 1
            self._evicted_hit_stats.push(self._hit_count[idx])
            self._hit_count[idx] = 0
            # This is a bit of a hack: ReplayBuffer always inserts at
            # self._next_idx.
            self._next_idx = idx
            self._storage[idx] = item

    assert item.count > 0, item
    warn_replay_capacity(item=item, num_items=self.capacity / item.count)
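# The eviction rule above is reservoir sampling (Algorithm R): after n
# add calls, every item ever added has the same k/n chance of still
# being in storage (k = capacity). A standalone sketch:
import random

def reservoir_sample(stream, k):
    reservoir = []
    for n, item in enumerate(stream, start=1):
        if n <= k:
            reservoir.append(item)
        else:
            # The n-th item replaces a uniformly chosen slot with
            # probability k/n, exactly the randint + idx < k check above.
            idx = random.randint(0, n - 1)
            if idx < k:
                reservoir[idx] = item
    return reservoir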
def add(self, item: SampleBatchType, weight: float):
    warn_replay_buffer_size(item=item, num_items=self._maxsize / item.count)
    assert item.count > 0, item
    self._num_timesteps_added += item.count
    self._num_timesteps_added_wrap += item.count

    if self._next_idx >= len(self._storage):
        self._storage.append(item)
        self._est_size_bytes += item.size_bytes()
    else:
        self._storage[self._next_idx] = item

    # Wrap around storage as a circular buffer once we hit maxsize.
    if self._num_timesteps_added_wrap >= self._maxsize:
        self._eviction_started = True
        self._num_timesteps_added_wrap = 0
        self._next_idx = 0
    else:
        self._next_idx += 1

    if self._eviction_started:
        self._evicted_hit_stats.push(self._hit_count[self._next_idx])
        self._hit_count[self._next_idx] = 0
def add(self, batch: SampleBatchType, **kwargs) -> None:
    """Adds a batch to the appropriate policy's replay buffer.

    Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
    it is not a MultiAgentBatch. Subsequently, adds the individual policy
    batches to the storage.

    Args:
        batch: The batch to be added.
        **kwargs: Forward compatibility kwargs.
    """
    # Make a copy so the replay buffer doesn't pin plasma memory.
    batch = batch.copy()
    # Handle everything as if multi-agent.
    batch = batch.as_multi_agent()

    kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)

    # We need to split batches into timesteps, sequences or episodes here
    # already to properly keep track of self.last_added_batches; the
    # underlying buffers should not split up the batch any further.
    with self.add_batch_timer:
        if self._storage_unit == StorageUnit.TIMESTEPS:
            for policy_id, sample_batch in batch.policy_batches.items():
                if self.replay_sequence_length == 1:
                    timeslices = sample_batch.timeslices(1)
                else:
                    timeslices = timeslice_along_seq_lens_with_overlap(
                        sample_batch=sample_batch,
                        zero_pad_max_seq_len=self.replay_sequence_length,
                        pre_overlap=self.replay_burn_in,
                        zero_init_states=self.replay_zero_init_states,
                    )
                for time_slice in timeslices:
                    self.replay_buffers[policy_id].add(time_slice, **kwargs)
                    self.last_added_batches[policy_id].append(time_slice)
        elif self._storage_unit == StorageUnit.SEQUENCES:
            timestep_count = 0
            for policy_id, sample_batch in batch.policy_batches.items():
                for seq_len in sample_batch.get(SampleBatch.SEQ_LENS):
                    start_seq = timestep_count
                    end_seq = timestep_count + seq_len
                    self.replay_buffers[policy_id].add(
                        sample_batch[start_seq:end_seq], **kwargs
                    )
                    self.last_added_batches[policy_id].append(
                        sample_batch[start_seq:end_seq]
                    )
                    timestep_count = end_seq
        elif self._storage_unit == StorageUnit.EPISODES:
            for policy_id, sample_batch in batch.policy_batches.items():
                for eps in sample_batch.split_by_episode():
                    # Only add full episodes to the buffer.
                    if (
                        eps.get(SampleBatch.T)[0] == 0
                        and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
                    ):
                        self.replay_buffers[policy_id].add(eps, **kwargs)
                        self.last_added_batches[policy_id].append(eps)
                    else:
                        if log_once("only_full_episodes"):
                            logger.info(
                                "This buffer uses episodes as a storage "
                                "unit and thus allows only full episodes "
                                "to be added to it. Some samples may be "
                                "dropped."
                            )

    self._num_added += batch.count
def __call__(self, samples: SampleBatchType) -> (SampleBatchType, List[dict]):
    _check_sample_batch_type(samples)

    # Handle everything as if multi-agent.
    samples = samples.as_multi_agent()

    metrics = _get_shared_metrics()
    load_timer = metrics.timers[LOAD_BATCH_TIMER]
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]

    # Load data into GPUs.
    with load_timer:
        num_loaded_samples = {}
        for policy_id, batch in samples.policy_batches.items():
            # Not a policy-to-train.
            if not self.local_worker.is_policy_to_train(policy_id, samples):
                continue

            # Decompress SampleBatch, in case some columns are compressed.
            batch.decompress_if_needed()

            # Load the entire train batch into the Policy's only buffer
            # (idx=0). Policies only have more than one buffer if we are
            # training asynchronously.
            num_loaded_samples[policy_id] = self.local_worker.policy_map[
                policy_id
            ].load_batch_into_buffer(batch, buffer_index=0)

    # Execute minibatch SGD on loaded data.
    with learn_timer:
        # Use LearnerInfoBuilder as a unified way to build the final
        # results dict from `learn_on_loaded_batch` call(s).
        # This makes sure results dicts always have the same structure
        # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
        # tf vs torch).
        learner_info_builder = LearnerInfoBuilder(num_devices=len(self.devices))
        for policy_id, samples_per_device in num_loaded_samples.items():
            policy = self.local_worker.policy_map[policy_id]
            num_batches = max(
                1, int(samples_per_device) // int(self.per_device_batch_size)
            )
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for _ in range(self.num_sgd_iter):
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    # Learn on the pre-loaded data in the buffer.
                    # Note: For minibatch SGD, the data is an offset into
                    # the pre-loaded entire train batch.
                    results = policy.learn_on_loaded_batch(
                        permutation[batch_index] * self.per_device_batch_size,
                        buffer_index=0,
                    )
                    learner_info_builder.add_learn_on_batch_results(
                        results, policy_id
                    )

        # Tower reduce and finalize results.
        learner_info = learner_info_builder.finalize()

    load_timer.push_units_processed(samples.count)
    learn_timer.push_units_processed(samples.count)

    metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
    metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = samples.count
    metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
    metrics.info[LEARNER_INFO] = learner_info

    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(
                self.workers.local_worker().get_weights(
                    self.local_worker.get_policies_to_train()
                )
            )
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())

    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return samples, learner_info
def process_episodes(self, batch: SampleBatchType) -> SampleBatchType:
    batch = batch.decompress_if_needed()
    self._mixin_buffer.add_batch(batch)
    processed_batches = self._mixin_buffer.replay(_ALL_POLICIES)
    return processed_batches
def _add_to_underlying_buffer(
    self, policy_id: PolicyID, batch: SampleBatchType, **kwargs
) -> None:
    """Add a batch of experiences to the underlying buffer of a policy.

    If the storage unit is `timesteps`, cut the batch into timeslices
    before adding them to the appropriate buffer. Otherwise, let the
    underlying buffer decide how to slice batches.

    Args:
        policy_id: ID of the policy that corresponds to the underlying buffer.
        batch: SampleBatch to add to the underlying buffer.
        **kwargs: Forward compatibility kwargs.
    """
    # For the storage unit `timesteps`, the underlying buffer will
    # simply store the samples how they arrive. For sequences and
    # episodes, the underlying buffer may split them itself.
    if self._storage_unit is StorageUnit.TIMESTEPS:
        if self.replay_sequence_length == 1:
            timeslices = batch.timeslices(1)
        else:
            timeslices = timeslice_along_seq_lens_with_overlap(
                sample_batch=batch,
                zero_pad_max_seq_len=self.replay_sequence_length,
                pre_overlap=self.replay_burn_in,
                zero_init_states=self.replay_zero_init_states,
            )
        for time_slice in timeslices:
            # If the SampleBatch has prio-replay weights, average over
            # these to use as a weight for the entire sequence.
            if self.replay_mode is ReplayMode.INDEPENDENT:
                if "weights" in time_slice and len(time_slice["weights"]):
                    weight = np.mean(time_slice["weights"])
                else:
                    weight = None

                if "weight" in kwargs and weight is not None:
                    if log_once("overwrite_weight"):
                        logger.warning(
                            "Adding batches with column `weights` to this "
                            "buffer while providing weights as a call "
                            "argument to the add method results in the "
                            "column being overwritten."
                        )

                kwargs = {"weight": weight, **kwargs}
            else:
                if "weight" in kwargs:
                    if log_once("lockstep_no_weight_allowed"):
                        logger.warning(
                            "Setting weights for batches in lockstep mode "
                            "is not allowed. Weights are being ignored."
                        )

                kwargs = {**kwargs, "weight": None}
            self.replay_buffers[policy_id].add(time_slice, **kwargs)
    else:
        self.replay_buffers[policy_id].add(batch, **kwargs)
def __call__(self, samples: SampleBatchType) -> (SampleBatchType, List[dict]):
    _check_sample_batch_type(samples)

    # Handle everything as if multi-agent.
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples}, samples.count)

    metrics = _get_shared_metrics()
    load_timer = metrics.timers[LOAD_BATCH_TIMER]
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]

    # Load data into GPUs.
    with load_timer:
        num_loaded_tuples = {}
        for policy_id, batch in samples.policy_batches.items():
            # Not a policy-to-train.
            if policy_id not in self.policies:
                continue

            # Decompress SampleBatch, in case some columns are compressed.
            batch.decompress_if_needed()

            policy = self.workers.local_worker().get_policy(policy_id)
            policy._debug_vars()
            tuples = policy._get_loss_inputs_dict(
                batch, shuffle=self.shuffle_sequences
            )
            data_keys = list(policy._loss_input_dict_no_rnn.values())
            if policy._state_inputs:
                state_keys = policy._state_inputs + [policy._seq_lens]
            else:
                state_keys = []
            num_loaded_tuples[policy_id] = self.optimizers[policy_id].load_data(
                self.sess,
                [tuples[k] for k in data_keys],
                [tuples[k] for k in state_keys],
            )

    # Execute minibatch SGD on loaded data.
    with learn_timer:
        fetches = {}
        for policy_id, tuples_per_device in num_loaded_tuples.items():
            optimizer = self.optimizers[policy_id]
            num_batches = max(
                1, int(tuples_per_device) // int(self.per_device_batch_size)
            )
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for _ in range(self.num_sgd_iter):
                permutation = np.random.permutation(num_batches)
                batch_fetches_all_towers = []
                for batch_index in range(num_batches):
                    batch_fetches = optimizer.optimize(
                        self.sess,
                        permutation[batch_index] * self.per_device_batch_size,
                    )
                    batch_fetches_all_towers.append(
                        tree.map_structure_with_path(
                            lambda p, *s: all_tower_reduce(p, *s),
                            *(
                                batch_fetches["tower_{}".format(tower_num)]
                                for tower_num in range(len(self.devices))
                            ),
                        )
                    )

            # Reduce mean across all minibatch SGD steps (axis=0 to keep
            # all shapes as-is).
            fetches[policy_id] = tree.map_structure(
                lambda *s: np.nanmean(s, axis=0), *batch_fetches_all_towers
            )

    load_timer.push_units_processed(samples.count)
    learn_timer.push_units_processed(samples.count)

    metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
    metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
    metrics.info[LEARNER_INFO] = fetches

    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(
                self.workers.local_worker().get_weights(self.policies)
            )
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())

    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return samples, fetches
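# The tree-based mean reduction above can be illustrated with dm-tree
# alone (hypothetical stats; `tree` is the dm-tree package):
import numpy as np
import tree  # pip install dm-tree

fetches_0 = {"learner_stats": {"policy_loss": 0.5, "vf_loss": 1.0}}
fetches_1 = {"learner_stats": {"policy_loss": 0.7, "vf_loss": 0.8}}
mean_stats = tree.map_structure(
    lambda *s: np.nanmean(s, axis=0), fetches_0, fetches_1
)
# {"learner_stats": {"policy_loss": 0.6, "vf_loss": 0.9}}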
def timeslice_along_seq_lens_with_overlap(
    sample_batch: SampleBatchType,
    seq_lens: Optional[List[int]] = None,
    zero_pad_max_seq_len: int = 0,
    pre_overlap: int = 0,
    zero_init_states: bool = True,
) -> List["SampleBatch"]:
    """Slices batch along `seq_lens` (each seq-len item produces one batch).

    Args:
        sample_batch: The SampleBatch to timeslice.
        seq_lens: An optional list of seq_lens to slice at. If None, use
            `sample_batch[SampleBatch.SEQ_LENS]`.
        zero_pad_max_seq_len: If >0, already zero-pad the resulting slices
            up to this length. NOTE: This max-len will include the
            additional timesteps gained via setting pre_overlap (see
            Example).
        pre_overlap: If >0, will overlap each two consecutive slices by
            this many timesteps (toward the left side). This will cause
            zero-padding at the very beginning of the batch.
        zero_init_states: Whether initial states should always be zero'd.
            If False, will use the state_outs of the batch to populate
            state_in values.

    Returns:
        List[SampleBatch]: The list of (new) SampleBatches.

    Examples:
        assert seq_lens == [5, 5, 2]
        assert sample_batch.count == 12
        # self = 0 1 2 3 4 | 5 6 7 8 9 | 10 11 <- timesteps
        slices = timeslice_along_seq_lens_with_overlap(
            sample_batch=sample_batch,
            zero_pad_max_seq_len=10,
            pre_overlap=3)
        # Z = zero padding (at beginning or end).
        #             |pre (3)|     seq        | max-seq-len (up to 10)
        # slices[0] = | Z Z Z |  0  1  2  3  4 | Z Z
        # slices[1] = | 2 3 4 |  5  6  7  8  9 | Z Z
        # slices[2] = | 7 8 9 | 10 11  Z  Z  Z | Z Z
        # Note that `zero_pad_max_seq_len=10` includes the 3 pre-overlaps
        # count (makes sure each slice has exactly length 10).
    """
    if seq_lens is None:
        seq_lens = sample_batch.get(SampleBatch.SEQ_LENS)
    else:
        if sample_batch.get(SampleBatch.SEQ_LENS) is not None and log_once(
            "overriding_sequencing_information"
        ):
            logger.warning(
                "Found sequencing information in a batch that will be "
                "ignored when slicing. Ignore this warning if you know "
                "what you are doing."
            )

    if seq_lens is None:
        max_seq_len = zero_pad_max_seq_len - pre_overlap
        if log_once("no_sequence_lengths_available_for_time_slicing"):
            logger.warning(
                "Trying to slice a batch along sequences without "
                "sequence lengths being provided in the batch. Batch will "
                "be sliced into slices of size "
                "{} = {} - {} = zero_pad_max_seq_len - pre_overlap.".format(
                    max_seq_len, zero_pad_max_seq_len, pre_overlap
                )
            )
        num_seq_lens, last_seq_len = divmod(len(sample_batch), max_seq_len)
        seq_lens = [max_seq_len] * num_seq_lens + (
            [last_seq_len] if last_seq_len else []
        )

    assert (
        seq_lens is not None and len(seq_lens) > 0
    ), "Cannot timeslice along `seq_lens` when `seq_lens` is empty or None!"

    # Generate n slices based on seq_lens.
    start = 0
    slices = []
    for seq_len in seq_lens:
        pre_begin = start - pre_overlap
        slice_begin = start
        end = start + seq_len
        slices.append((pre_begin, slice_begin, end))
        start += seq_len

    timeslices = []
    for begin, slice_begin, end in slices:
        zero_length = None
        data_begin = 0
        zero_init_states_ = zero_init_states
        if begin < 0:
            zero_length = pre_overlap
            data_begin = slice_begin
            zero_init_states_ = True
        else:
            eps_ids = sample_batch[SampleBatch.EPS_ID][begin if begin >= 0 else 0 : end]
            is_last_episode_ids = eps_ids == eps_ids[-1]
            if not is_last_episode_ids[0]:
                zero_length = int(sum(1.0 - is_last_episode_ids))
                data_begin = begin + zero_length
                zero_init_states_ = True

        if zero_length is not None:
            data = {
                k: np.concatenate(
                    [
                        np.zeros(shape=(zero_length,) + v.shape[1:], dtype=v.dtype),
                        v[data_begin:end],
                    ]
                )
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }
        else:
            data = {
                k: v[begin:end]
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }

        if zero_init_states_:
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                data[key] = np.zeros_like(sample_batch[key][0:1])
                # Del state_out_n from data if exists.
                data.pop("state_out_{}".format(i), None)
                i += 1
                key = "state_in_{}".format(i)
        # TODO: This will not work with attention nets as their state_outs are
        #  not compatible with state_ins.
        else:
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                data[key] = sample_batch["state_out_{}".format(i)][begin - 1 : begin]
                del data["state_out_{}".format(i)]
                i += 1
                key = "state_in_{}".format(i)

        timeslices.append(SampleBatch(data, seq_lens=[end - begin]))

    # Zero-pad each slice if necessary.
    if zero_pad_max_seq_len > 0:
        for ts in timeslices:
            ts.right_zero_pad(max_seq_len=zero_pad_max_seq_len, exclude_states=True)

    return timeslices
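# Usage sketch mirroring the docstring example (assumes the function
# above is in scope and an RLlib version matching these snippets;
# EPS_ID is included because slicing consults it for non-initial
# slices):
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch({
    "obs": np.arange(12, dtype=np.float32),
    SampleBatch.EPS_ID: np.zeros(12, dtype=np.int64),
})
slices = timeslice_along_seq_lens_with_overlap(
    sample_batch=batch,
    seq_lens=[5, 5, 2],
    zero_pad_max_seq_len=10,
    pre_overlap=3,
)
# Three slices, each right-zero-padded to length 10.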
def __call__(self, samples: SampleBatchType) -> (SampleBatchType, List[dict]):
    _check_sample_batch_type(samples)

    # Handle everything as if multi-agent.
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples}, samples.count)

    metrics = _get_shared_metrics()
    load_timer = metrics.timers[LOAD_BATCH_TIMER]
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]

    # Load data into GPUs.
    with load_timer:
        num_loaded_tuples = {}
        for policy_id, batch in samples.policy_batches.items():
            # Not a policy-to-train.
            if policy_id not in self.local_worker.policies_to_train:
                continue

            # Decompress SampleBatch, in case some columns are compressed.
            batch.decompress_if_needed()

            # Load the entire train batch into the Policy's only buffer
            # (idx=0). Policies only have more than one buffer if we are
            # training asynchronously.
            num_loaded_tuples[policy_id] = self.local_worker.policy_map[
                policy_id
            ].load_batch_into_buffer(batch, buffer_index=0)

    # Execute minibatch SGD on loaded data.
    with learn_timer:
        fetches = {}
        for policy_id, tuples_per_device in num_loaded_tuples.items():
            policy = self.local_worker.policy_map[policy_id]
            num_batches = max(
                1, int(tuples_per_device) // int(self.per_device_batch_size)
            )
            logger.debug("== sgd epochs for {} ==".format(policy_id))

            batch_fetches_all_towers = []
            for _ in range(self.num_sgd_iter):
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    # Learn on the pre-loaded data in the buffer.
                    # Note: For minibatch SGD, the data is an offset into
                    # the pre-loaded entire train batch.
                    batch_fetches = policy.learn_on_loaded_batch(
                        permutation[batch_index] * self.per_device_batch_size,
                        buffer_index=0,
                    )

                    # No towers: Single CPU.
                    if "tower_0" not in batch_fetches:
                        batch_fetches_all_towers.append(batch_fetches)
                    else:
                        batch_fetches_all_towers.append(
                            tree.map_structure_with_path(
                                lambda p, *s: all_tower_reduce(p, *s),
                                *(
                                    batch_fetches["tower_{}".format(tower_num)]
                                    for tower_num in range(len(self.devices))
                                ),
                            )
                        )

            # Reduce mean across all minibatch SGD steps (axis=0 to keep
            # all shapes as-is).
            fetches[policy_id] = tree.map_structure(
                lambda *s: np.nanmean(s, axis=0), *batch_fetches_all_towers
            )

    load_timer.push_units_processed(samples.count)
    learn_timer.push_units_processed(samples.count)

    metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
    metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
    metrics.info[LEARNER_INFO] = fetches

    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(
                self.workers.local_worker().get_weights(
                    self.local_worker.policies_to_train
                )
            )
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())

    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return samples, fetches
def _add_to_underlying_buffer(
    self, policy_id: PolicyID, batch: SampleBatchType, **kwargs
) -> None:
    """Add a batch of experiences to the underlying buffer of a policy.

    If the storage unit is `timesteps`, cut the batch into timeslices
    before adding them to the appropriate buffer. Otherwise, let the
    underlying buffer decide how to slice batches.

    Args:
        policy_id: ID of the policy that corresponds to the underlying buffer.
        batch: SampleBatch to add to the underlying buffer.
        **kwargs: Forward compatibility kwargs.
    """
    # Merge kwargs, overwriting standard call arguments.
    kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)

    # For the storage unit `timesteps`, the underlying buffer will
    # simply store the samples how they arrive. For sequences and
    # episodes, the underlying buffer may split them itself.
    if self.storage_unit is StorageUnit.TIMESTEPS:
        timeslices = batch.timeslices(1)
    elif self.storage_unit is StorageUnit.SEQUENCES:
        timeslices = timeslice_along_seq_lens_with_overlap(
            sample_batch=batch,
            seq_lens=batch.get(SampleBatch.SEQ_LENS)
            if self.replay_sequence_override
            else None,
            zero_pad_max_seq_len=self.replay_sequence_length,
            pre_overlap=self.replay_burn_in,
            zero_init_states=self.replay_zero_init_states,
        )
    elif self.storage_unit == StorageUnit.EPISODES:
        timeslices = []
        for eps in batch.split_by_episode():
            # Only add full episodes to the buffer.
            if (
                eps.get(SampleBatch.T)[0] == 0
                and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
            ):
                timeslices.append(eps)
            else:
                if log_once("only_full_episodes"):
                    logger.info(
                        "This buffer uses episodes as a storage "
                        "unit and thus allows only full episodes "
                        "to be added to it. Some samples may be "
                        "dropped."
                    )
    elif self.storage_unit == StorageUnit.FRAGMENTS:
        timeslices = [batch]
    else:
        raise ValueError("Unknown `storage_unit={}`".format(self.storage_unit))

    for time_slice in timeslices:
        # If the SampleBatch has prio-replay weights, average over these
        # to use as a weight for the entire sequence.
        if self.replay_mode is ReplayMode.INDEPENDENT:
            if "weights" in time_slice and len(time_slice["weights"]):
                weight = np.mean(time_slice["weights"])
            else:
                weight = None

            if "weight" in kwargs and weight is not None:
                if log_once("overwrite_weight"):
                    logger.warning(
                        "Adding batches with column `weights` to this "
                        "buffer while providing weights as a call "
                        "argument to the add method results in the "
                        "column being overwritten."
                    )

            kwargs = {"weight": weight, **kwargs}
        else:
            if "weight" in kwargs:
                if log_once("lockstep_no_weight_allowed"):
                    logger.warning(
                        "Setting weights for batches in lockstep mode "
                        "is not allowed. Weights are being ignored."
                    )

            kwargs = {**kwargs, "weight": None}
        self.replay_buffers[policy_id].add(time_slice, **kwargs)