class LegacyApexMemory(Specifiable):
    """
    Apex prioritized replay implementing compression.

    Records are held in a ring buffer of at most `capacity` entries. Each record is a
    5-element sequence `(state, action, reward, terminal, weight)`. Next states are not
    stored explicitly; `read_records` looks them up in the records that follow the
    requested index.
    """

    def __init__(self, capacity=1000, alpha=1.0, beta=1.0, n_step_adjustment=1):
        """
        Args:
            capacity (int): Maximum number of records kept before old ones are overwritten.
            alpha (float): Exponent applied to losses in `update_records` to obtain priorities.
            beta (float): Exponent used when computing importance weights in `get_records`.
            n_step_adjustment (int): Number of records to step forward when looking up the
                next state of a record. Must be at least 1.
        """
        super(LegacyApexMemory, self).__init__()

        self.memory_values = []
        self.index = 0
        self.capacity = capacity
        self.size = 0
        self.max_priority = 1.0
        self.alpha = alpha
        self.beta = beta

        assert n_step_adjustment > 0, "ERROR: n-step adjustment must be at least 1 where 1 corresponds " \
                                      "to the direct next state."
        self.n_step_adjustment = n_step_adjustment

        # Priority assigned to records that arrive without an explicit weight.
        self.default_new_weight = np.power(self.max_priority, self.alpha)

        # The segment trees use the smallest power of two that is >= capacity.
        self.priority_capacity = 1
        while self.priority_capacity < self.capacity:
            self.priority_capacity *= 2

        # Create segment trees, initialize with neutral elements.
        sum_values = [0.0 for _ in range_(2 * self.priority_capacity)]
        sum_segment_tree = MemSegmentTree(sum_values, self.priority_capacity, operator.add)
        min_values = [float('inf') for _ in range_(2 * self.priority_capacity)]
        min_segment_tree = MemSegmentTree(min_values, self.priority_capacity, min)

        self.merged_segment_tree = MinSumSegmentTree(
            sum_tree=sum_segment_tree,
            min_tree=min_segment_tree,
            capacity=self.priority_capacity
        )

    def insert_records(self, record):
        """
        Stores a single record at the current write position and registers its priority.

        Args:
            record (sequence): `(state, action, reward, terminal, weight)`. If `weight`
                (element 4) is None, `default_new_weight` is used as the priority.
        """
        # TODO: This has the record interface, but actually expects a specific structure anyway, so
        # may as well change API?
        if self.index >= self.size:
            self.memory_values.append(record)
        else:
            self.memory_values[self.index] = record

        # Weights. TODO this is problematic due to index not existing.
        if record[4] is not None:
            self.merged_segment_tree.insert(self.index, record[4])
        else:
            self.merged_segment_tree.insert(self.index, self.default_new_weight)

        # Update indices.
        self.index = (self.index + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def read_records(self, indices):
        """
        Obtains record values for the provided indices.

        Args:
            indices (ndarray): Indices to read. Assumed to be not contiguous.

        Returns:
            dict: Record value dict with keys "states", "actions", "rewards", "terminals"
                and "next_states", each an ndarray.
        """
        states = []
        actions = []
        rewards = []
        terminals = []
        next_states = []

        for index in indices:
            state, action, reward, terminal, _ = self.memory_values[index]
            # np.asarray instead of np.array(..., copy=False): under NumPy >= 2.0 the latter
            # raises ValueError whenever a copy is unavoidable (e.g. for Python scalars).
            decompressed_state = np.asarray(ray_decompress(state))
            states.append(decompressed_state)
            actions.append(np.asarray(action))
            rewards.append(reward)
            terminals.append(terminal)

            if terminal:
                # If terminal -> just use same state, already decompressed.
                next_states.append(decompressed_state)
            else:
                # Otherwise advance until correct next state or terminal.
                next_state = decompressed_state
                for i in range_(self.n_step_adjustment):
                    next_index = (index + i + 1) % self.size
                    next_state, _, _, terminal, _ = self.memory_values[next_index]
                    if terminal:
                        break
                next_states.append(np.asarray(ray_decompress(next_state)))

        return dict(
            states=np.array(states),
            actions=np.array(actions),
            rewards=np.array(rewards),
            terminals=np.array(terminals),
            next_states=np.array(next_states)
        )

    def get_records(self, num_records):
        """
        Samples records according to the priorities held in the sum segment tree.

        Args:
            num_records (int): Number of records to sample.

        Returns:
            tuple: (record dict as returned by `read_records`, ndarray of sampled indices,
                ndarray of importance weights normalized by the maximum weight).
        """
        sum_tree = self.merged_segment_tree.sum_segment_tree
        min_tree = self.merged_segment_tree.min_segment_tree

        # Ensure we always have n-next states.
        # TODO potentially block this if size - 1 - nstep < 1?
        prob_sum = sum_tree.get_sum(0, self.size - 1 - self.n_step_adjustment)
        samples = np.random.random(size=(num_records,)) * prob_sum
        indices = [sum_tree.index_of_prefixsum(prefix_sum=sample) for sample in samples]

        sum_prob = sum_tree.get_sum()
        min_prob = min_tree.get_min_value() / sum_prob
        max_weight = (min_prob * self.size) ** (-self.beta)

        weights = []
        for index in indices:
            sample_prob = sum_tree.get(index) / sum_prob
            weight = (sample_prob * self.size) ** (-self.beta)
            weights.append(weight / max_weight)

        # `indices` and `weights` are Python lists, so a conversion copy is always needed;
        # np.array(..., copy=False) would raise on NumPy >= 2.0.
        indices = np.asarray(indices)
        return self.read_records(indices=indices), indices, np.asarray(weights)

    def update_records(self, indices, update):
        """
        Overwrites the priorities of the given indices with `loss ** alpha`.

        Args:
            indices (iterable): Record indices to update.
            update (iterable): Loss values, one per index.
        """
        for index, loss in zip(indices, update):
            priority = np.power(loss, self.alpha)
            self.merged_segment_tree.insert(index, priority)
            self.max_priority = max(self.max_priority, priority)
class ApexMemory(Specifiable):
    """
    Apex prioritized replay implementing compression.

    Each stored record is a 6-element sequence
    `(state, action, reward, terminal, next_state, weight)` kept in a ring buffer
    of at most `capacity` entries.
    """

    def __init__(self, capacity=1000, alpha=1.0, beta=1.0, n_step_adjustment=1):
        """
        Args:
            capacity (int): Maximum number of records before old entries are overwritten.
            alpha (float): Exponent applied to priorities when they are written to the trees.
            beta (float): Exponent used for the importance weights in `get_records`.
            n_step_adjustment (int): Stored on the instance; not read by any method of
                this class.
        """
        super(ApexMemory, self).__init__()

        self.memory_values = []
        self.index = 0
        self.capacity = capacity
        self.size = 0
        self.max_priority = 1.0
        self.alpha = alpha
        self.beta = beta
        # Kept for interface compatibility only (no validation is done any more).
        self.n_step_adjustment = n_step_adjustment
        self.default_new_weight = np.power(self.max_priority, self.alpha)

        # Smallest power of two that is >= capacity.
        tree_capacity = 1
        while tree_capacity < self.capacity:
            tree_capacity *= 2
        self.priority_capacity = tree_capacity

        # Both trees start out filled with the neutral element of their operation.
        node_count = 2 * self.priority_capacity
        sum_segment_tree = MemSegmentTree([0.0] * node_count, self.priority_capacity, operator.add)
        min_segment_tree = MemSegmentTree([float('inf')] * node_count, self.priority_capacity, min)
        self.merged_segment_tree = MinSumSegmentTree(
            sum_tree=sum_segment_tree,
            min_tree=min_segment_tree,
            capacity=self.priority_capacity
        )

    def insert_records(self, record):
        """
        Writes one record at the current position and registers its priority.

        Args:
            record (sequence): `(state, action, reward, terminal, next_state, weight)`.
                A `weight` of None falls back to the current `max_priority`.
        """
        slot = self.index
        if slot < self.size:
            self.memory_values[slot] = record
        else:
            self.memory_values.append(record)

        # Priority is the given weight (or max priority) raised to alpha.
        base_priority = self.max_priority if record[5] is None else record[5]
        self.merged_segment_tree.insert(slot, base_priority ** self.alpha)

        # Advance the write position and grow the size up to capacity.
        self.index = (slot + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def read_records(self, indices):
        """
        Obtains record values for the provided indices.

        Args:
            indices (ndarray): Indices to read. Assumed to be not contiguous.

        Returns:
            dict: Record value dict with keys "states", "actions", "rewards",
                "terminals" and "next_states".
        """
        batch = dict(states=[], actions=[], rewards=[], terminals=[], next_states=[])

        for idx in indices:
            state, action, reward, terminal, next_state, _ = self.memory_values[idx]
            batch["states"].append(ray_decompress(state))
            batch["actions"].append(action)
            batch["rewards"].append(reward)
            batch["terminals"].append(terminal)
            batch["next_states"].append(ray_decompress(next_state))

        return dict(
            states=np.asarray(batch["states"]),
            actions=np.asarray(batch["actions"]),
            rewards=np.asarray(batch["rewards"]),
            terminals=np.asarray(batch["terminals"]),
            next_states=np.asarray(batch["next_states"])
        )

    def get_records(self, num_records):
        """
        Samples `num_records` records proportionally to the priorities in the sum tree.

        Args:
            num_records (int): Number of records to sample.

        Returns:
            tuple: (record dict from `read_records`, ndarray of indices, ndarray of
                importance weights divided by the maximum weight).
        """
        sum_tree = self.merged_segment_tree.sum_segment_tree
        min_tree = self.merged_segment_tree.min_segment_tree

        total_mass = sum_tree.get_sum(0, self.size)
        draws = np.random.random(size=(num_records,)) * total_mass
        indices = [sum_tree.index_of_prefixsum(prefix_sum=draw) for draw in draws]

        sum_prob = sum_tree.get_sum()
        min_prob = min_tree.get_min_value() / sum_prob + SMALL_NUMBER
        max_weight = (min_prob * self.size) ** (-self.beta)

        weights = [
            ((sum_tree.get(idx) / sum_prob) * self.size) ** (-self.beta) / max_weight
            for idx in indices
        ]

        sampled = self.read_records(indices=indices)
        return sampled, np.asarray(indices), np.asarray(weights)

    def update_records(self, indices, update):
        """
        Sets the priority of each given index to `loss ** alpha` and tracks the largest loss.

        Args:
            indices (iterable): Record indices to update.
            update (iterable): Loss values, one per index.
        """
        for idx, loss_value in zip(indices, update):
            self.merged_segment_tree.insert(idx, loss_value ** self.alpha)
            self.max_priority = max(self.max_priority, loss_value)
class MemPrioritizedReplay(Memory):
    """
    Implements an in-memory prioritized replay.

    Records are kept as a Python list of dicts (`memory_values`) used as a ring buffer;
    priorities live in a combined min/sum segment tree.

    API:
        update_records(indices, update) -> Updates the given indices with the given priority scores.
    """

    def __init__(self, capacity=1000, next_states=True, alpha=1.0, beta=0.0):
        """
        Args:
            capacity (int): Maximum number of records before old entries are overwritten.
            next_states (bool): Stored on the instance; not read by the methods of this class.
            alpha (float): Exponent applied to losses in `update_records` to obtain priorities.
            beta (float): Exponent used for the importance weights in `get_records`.
        """
        super(MemPrioritizedReplay, self).__init__()

        self.memory_values = []
        self.index = 0
        self.capacity = capacity
        self.size = 0
        self.max_priority = 1.0
        self.alpha = alpha
        self.beta = beta
        self.next_states = next_states

        # Priority given to every newly inserted record.
        self.default_new_weight = np.power(self.max_priority, self.alpha)

    def create_variables(self, input_spaces, action_space=None):
        """
        Derives the flat record space from `input_spaces["records"]` and builds the segment trees.

        Args:
            input_spaces (dict): Must contain the key "records".
            action_space: Not used by this method.
        """
        # Store our record-space for convenience.
        self.record_space = input_spaces["records"]
        self.record_space_flat = Dict(
            self.record_space.flatten(custom_scope_separator="/", scope_separator_at_start=False),
            add_batch_rank=True
        )

        # The segment trees use the smallest power of two that is >= capacity.
        self.priority_capacity = 1
        while self.priority_capacity < self.capacity:
            self.priority_capacity *= 2

        # Create segment trees, initialize with neutral elements.
        sum_values = [0.0 for _ in range_(2 * self.priority_capacity)]
        sum_segment_tree = MemSegmentTree(sum_values, self.priority_capacity, operator.add)
        min_values = [float('inf') for _ in range_(2 * self.priority_capacity)]
        min_segment_tree = MemSegmentTree(min_values, self.priority_capacity, min)

        self.merged_segment_tree = MinSumSegmentTree(
            sum_tree=sum_segment_tree,
            min_tree=min_segment_tree,
            capacity=self.priority_capacity
        )

    def _read_records(self, indices):
        """
        Obtains record values for the provided indices.

        Args:
            indices (ndarray): Indices to read. Assumed to be not contiguous.

        Returns:
            dict: Record value dict keyed by the flat record-space names.
        """
        records = {}
        for name in self.record_space_flat.keys():
            records[name] = []

        if self.size > 0:
            for index in indices:
                record = self.memory_values[index]
                for name in self.record_space_flat.keys():
                    records[name].append(record[name])
        else:
            # TODO figure out how to do default handling in pytorch builds.
            # Fill with default vals for build.
            for name in self.record_space_flat.keys():
                if get_backend() == "pytorch":
                    records[name] = torch.zeros(
                        self.record_space_flat[name].shape,
                        dtype=dtype_(self.record_space_flat[name].dtype, "pytorch")
                    )
                else:
                    records[name] = np.zeros(self.record_space_flat[name].shape)

        # Convert if necessary: list of tensors fails at space inference otherwise.
        if get_backend() == "pytorch":
            for name in self.record_space_flat.keys():
                records[name] = torch.squeeze(torch.stack(records[name]))

        return records

    @rlgraph_api(flatten_ops=True)
    def _graph_fn_insert_records(self, records):
        """
        Inserts a batch of records, giving each the default priority.

        Args:
            records (dict): Flat dict of batched values; must contain the key 'rewards',
                whose length determines the number of records. None or a rank-0
                'rewards' entry makes this a no-op.
        """
        if records is None or get_rank(records['rewards']) == 0:
            return

        num_records = len(records['rewards'])

        if num_records == 1:
            # NOTE(review): the single-record path stores the batched dict as-is, whereas
            # the multi-record path stores per-record slices — confirm this is intended.
            if self.index >= self.size:
                self.memory_values.append(records)
            else:
                self.memory_values[self.index] = records
            self.merged_segment_tree.insert(self.index, self.default_new_weight)
        else:
            insert_indices = np.arange(start=self.index, stop=self.index + num_records) % self.capacity
            for i, insert_index in enumerate(insert_indices):
                self.merged_segment_tree.insert(insert_index, self.default_new_weight)
                record = {name: record_values[i] for name, record_values in records.items()}
                if insert_index >= self.size:
                    self.memory_values.append(record)
                else:
                    self.memory_values[insert_index] = record

        # Update indices.
        self.index = (self.index + num_records) % self.capacity
        self.size = min(self.size + num_records, self.capacity)

    @rlgraph_api
    def _graph_fn_get_records(self, num_records=1):
        """
        Samples records according to the priorities held in the sum segment tree.

        Args:
            num_records (int): Number of records to sample.

        Returns:
            tuple: (record dict from `_read_records`, sampled indices, importance weights
                divided by the maximum weight). Indices and weights are torch tensors on
                the "pytorch" backend and ndarrays otherwise.
        """
        sum_tree = self.merged_segment_tree.sum_segment_tree
        min_tree = self.merged_segment_tree.min_segment_tree

        prob_sum = sum_tree.get_sum(0, self.size - 1)
        samples = np.random.random(size=(num_records,)) * prob_sum
        indices = [sum_tree.index_of_prefixsum(prefix_sum=sample) for sample in samples]

        # SMALL_NUMBER keeps the divisions below defined when the tree is still empty.
        sum_prob = sum_tree.get_sum() + SMALL_NUMBER
        min_prob = min_tree.get_min_value() / sum_prob
        max_weight = (min_prob * self.size) ** (-self.beta)

        weights = []
        for index in indices:
            sample_prob = sum_tree.get(index) / sum_prob
            weight = (sample_prob * self.size) ** (-self.beta)
            weights.append(weight / max_weight)

        if get_backend() == "pytorch":
            indices = torch.tensor(indices)
            weights = torch.tensor(weights)
        else:
            indices = np.asarray(indices)
            weights = np.asarray(weights)

        return self._read_records(indices=indices), indices, weights

    @rlgraph_api(must_be_complete=False)
    def _graph_fn_update_records(self, indices, update):
        """
        Overwrites the priorities of the given indices with `loss ** alpha`.

        Args:
            indices (sequence): Record indices to update. An empty sequence, or one whose
                first element is None, makes this a no-op.
            update (sequence): Loss values, one per index.
        """
        # Test for None explicitly: a plain truthiness check on `indices[0]` would skip
        # the whole update whenever the first sampled index happens to be 0.
        if len(indices) > 0 and indices[0] is not None:
            for index, loss in zip(indices, update):
                priority = np.power(loss, self.alpha)
                self.merged_segment_tree.insert(index, priority)
                self.max_priority = max(self.max_priority, priority)
class MemPrioritizedReplay(Memory):
    """
    Implements an in-memory prioritized replay.

    API:
        update_records(indices, update) -> Updates the given indices with the given priority scores.
    """

    def __init__(self, capacity=1000, next_states=True, alpha=1.0, beta=0.0):
        """
        Args:
            capacity (int): Maximum number of records before old entries are overwritten.
            next_states (bool): Stored on the instance; not read by the methods of this class.
            alpha (float): Exponent applied to losses in `update_records` to obtain priorities.
            beta (float): Exponent used for the importance weights in `get_records`.
        """
        super(MemPrioritizedReplay, self).__init__()

        self.memory_values = []
        self.index = 0
        self.capacity = capacity
        self.size = 0
        self.max_priority = 1.0
        self.alpha = alpha
        self.beta = beta
        self.next_states = next_states
        # Priority assigned to every freshly inserted record.
        self.default_new_weight = np.power(self.max_priority, self.alpha)

    def create_variables(self, input_spaces, action_space=None):
        """
        Lets the parent class create its variables, then builds the priority segment trees.

        Args:
            input_spaces (dict): Forwarded to the parent implementation.
            action_space: Forwarded to the parent implementation.
        """
        super(MemPrioritizedReplay, self).create_variables(input_spaces, action_space)

        # Smallest power of two that is >= capacity.
        tree_capacity = 1
        while tree_capacity < self.capacity:
            tree_capacity *= 2
        self.priority_capacity = tree_capacity

        # Both trees start out filled with the neutral element of their operation.
        node_count = 2 * self.priority_capacity
        sum_segment_tree = MemSegmentTree([0.0] * node_count, self.priority_capacity, operator.add)
        min_segment_tree = MemSegmentTree([float('inf')] * node_count, self.priority_capacity, min)
        self.merged_segment_tree = MinSumSegmentTree(
            sum_tree=sum_segment_tree,
            min_tree=min_segment_tree,
            capacity=self.priority_capacity
        )

    @rlgraph_api(flatten_ops=True)
    def _graph_fn_insert_records(self, records):
        """
        Inserts a batch of records, giving each the default priority.

        Args:
            records (dict): Flat dict of batched values containing `self.terminal_key`,
                whose length determines the number of records. None or a rank-0 entry
                under that key makes this a no-op.
        """
        if records is None:
            return
        terminal_batch = records[self.terminal_key]
        if get_rank(terminal_batch) == 0:
            return

        num_records = len(terminal_batch)

        if num_records != 1:
            positions = np.arange(start=self.index, stop=self.index + num_records) % self.capacity
            for row, position in enumerate(positions):
                self.merged_segment_tree.insert(position, self.default_new_weight)
                single = {name: values[row] for name, values in records.items()}
                if position < self.size:
                    self.memory_values[position] = single
                else:
                    self.memory_values.append(single)
        else:
            # A batch of one is stored as the batched dict itself.
            if self.index < self.size:
                self.memory_values[self.index] = records
            else:
                self.memory_values.append(records)
            self.merged_segment_tree.insert(self.index, self.default_new_weight)

        # Advance the write position and grow the size up to capacity.
        self.index = (self.index + num_records) % self.capacity
        self.size = min(self.size + num_records, self.capacity)

    @rlgraph_api
    def _graph_fn_get_records(self, num_records=1):
        """
        Samples at most `num_records` records (capped at the current size) by priority.

        Args:
            num_records (int): Requested number of records.

        Returns:
            tuple: (unflattened records read from `self.memory`, sampled indices, importance
                weights divided by the maximum weight). Indices and weights are torch
                tensors on the "pytorch" backend and ndarrays otherwise.
        """
        sum_tree = self.merged_segment_tree.sum_segment_tree
        min_tree = self.merged_segment_tree.min_segment_tree

        available_records = min(num_records, self.size)
        total_mass = sum_tree.get_sum(0, self.size - 1)
        draws = np.random.random(size=(available_records,)) * total_mass
        indices = [sum_tree.index_of_prefixsum(prefix_sum=draw) for draw in draws]

        sum_prob = sum_tree.get_sum() + SMALL_NUMBER
        min_prob = min_tree.get_min_value() / sum_prob
        max_weight = (min_prob * self.size) ** (-self.beta)

        weights = [
            ((sum_tree.get(idx) / sum_prob) * self.size) ** (-self.beta) / max_weight
            for idx in indices
        ]

        if get_backend() != "pytorch":
            indices = np.asarray(indices)
            weights = np.asarray(weights)
        else:
            indices = torch.tensor(indices)
            weights = torch.tensor(weights)

        # Read every memory variable at the sampled indices, then restore the nested structure.
        flat_records = DataOpDict()
        for name, variable in self.memory.items():
            target_dtype = util.convert_dtype(self.flat_record_space[name].dtype, to="pytorch")
            flat_records[name] = self.read_variable(variable, indices, dtype=target_dtype)

        return define_by_run_unflatten(flat_records), indices, weights

    @rlgraph_api(must_be_complete=False)
    def _graph_fn_update_records(self, indices, update):
        """
        Overwrites the priorities of the given indices with `loss ** alpha`.

        Args:
            indices (iterable): Record indices to update.
            update (iterable): Loss values, one per index.
        """
        for idx, loss_value in zip(indices, update):
            new_priority = np.power(loss_value, self.alpha)
            self.merged_segment_tree.insert(idx, new_priority)
            self.max_priority = max(self.max_priority, new_priority)

    def get_state(self):
        """
        Returns:
            dict: Current "size", "index" and "max_priority" of the memory.
        """
        return dict(size=self.size, index=self.index, max_priority=self.max_priority)
class ApexMemory(Specifiable):
    """
    Apex prioritized replay implementing compression.

    Each stored record is a 6-element sequence
    `(state, action, reward, terminal, next_state, weight)` kept in a ring buffer
    of at most `capacity` entries.
    """

    def __init__(self, state_space=None, action_space=None, capacity=1000, alpha=1.0, beta=1.0):
        """
        Args:
            state_space (dict): State spec.
            action_space (dict): Actions spec. If it is a dict, actions are treated as
                containers and returned per key by `read_records`.
            capacity (int): Max capacity.
            alpha (float): Exponent applied to priorities when they are written to the trees.
            beta (float): Exponent used for the importance weights in `get_records`.
        """
        super(ApexMemory, self).__init__()

        self.state_space = state_space
        self.action_space = action_space
        self.container_actions = isinstance(action_space, dict)

        self.memory_values = []
        self.index = 0
        self.capacity = capacity
        self.size = 0
        self.max_priority = 1.0
        self.alpha = alpha
        self.beta = beta
        self.default_new_weight = np.power(self.max_priority, self.alpha)

        # Smallest power of two that is >= capacity.
        tree_capacity = 1
        while tree_capacity < self.capacity:
            tree_capacity *= 2
        self.priority_capacity = tree_capacity

        # Both trees start out filled with the neutral element of their operation.
        node_count = 2 * self.priority_capacity
        sum_segment_tree = MemSegmentTree([0.0] * node_count, self.priority_capacity, operator.add)
        min_segment_tree = MemSegmentTree([float('inf')] * node_count, self.priority_capacity, min)
        self.merged_segment_tree = MinSumSegmentTree(
            sum_tree=sum_segment_tree,
            min_tree=min_segment_tree,
            capacity=self.priority_capacity
        )

    def insert_records(self, record):
        """
        Writes one record at the current position and registers its priority.

        Args:
            record (sequence): `(state, action, reward, terminal, next_state, weight)`.
                A `weight` of None falls back to the current `max_priority`.
        """
        slot = self.index
        if slot < self.size:
            self.memory_values[slot] = record
        else:
            self.memory_values.append(record)

        # Priority is the given weight (or max priority) raised to alpha.
        base_priority = self.max_priority if record[5] is None else record[5]
        self.merged_segment_tree.insert(slot, base_priority ** self.alpha)

        # Advance the write position and grow the size up to capacity.
        self.index = (slot + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def read_records(self, indices):
        """
        Obtains record values for the provided indices.

        Args:
            indices (ndarray): Indices to read. Assumed to be not contiguous.

        Returns:
            dict: Record value dict with keys "states", "actions", "rewards", "terminals"
                and "next_states". For container actions, "actions" is a dict of squeezed
                arrays keyed like the action space; otherwise it is a single array.
        """
        states, rewards, terminals, next_states = [], [], [], []
        if self.container_actions:
            actions = {key: [] for key in self.action_space.keys()}
        else:
            actions = []

        for idx in indices:
            state, action, reward, terminal, next_state, _ = self.memory_values[idx]
            states.append(ray_decompress(state))
            if self.container_actions:
                for key, bucket in actions.items():
                    bucket.append(action[key])
            else:
                actions.append(action)
            rewards.append(reward)
            terminals.append(terminal)
            next_states.append(ray_decompress(next_state))

        if self.container_actions:
            actions = {key: np.squeeze(np.array(bucket)) for key, bucket in actions.items()}
        else:
            actions = np.array(actions)

        return dict(
            states=np.asarray(states),
            actions=actions,
            rewards=np.asarray(rewards),
            terminals=np.asarray(terminals),
            next_states=np.asarray(next_states)
        )

    def get_records(self, num_records):
        """
        Samples `num_records` records proportionally to the priorities in the sum tree.

        Args:
            num_records (int): Number of records to sample.

        Returns:
            tuple: (record dict from `read_records`, ndarray of indices, ndarray of
                importance weights divided by the maximum weight).
        """
        sum_tree = self.merged_segment_tree.sum_segment_tree
        min_tree = self.merged_segment_tree.min_segment_tree

        total_mass = sum_tree.get_sum(0, self.size)
        draws = np.random.random(size=(num_records,)) * total_mass
        indices = [sum_tree.index_of_prefixsum(prefix_sum=draw) for draw in draws]

        sum_prob = sum_tree.get_sum()
        min_prob = min_tree.get_min_value() / sum_prob + SMALL_NUMBER
        max_weight = (min_prob * self.size) ** (-self.beta)

        weights = [
            ((sum_tree.get(idx) / sum_prob) * self.size) ** (-self.beta) / max_weight
            for idx in indices
        ]

        sampled = self.read_records(indices=indices)
        return sampled, np.asarray(indices), np.asarray(weights)

    def update_records(self, indices, update):
        """
        Sets the priority of each given index to `loss ** alpha` and tracks the largest loss.

        Args:
            indices (iterable): Record indices to update.
            update (iterable): Loss values, one per index.
        """
        for idx, loss_value in zip(indices, update):
            self.merged_segment_tree.insert(idx, loss_value ** self.alpha)
            self.max_priority = max(self.max_priority, loss_value)