def _graph_fn_insert_records(self, records):
    num_records = get_batch_size(records[self.terminal_key])

    if get_backend() == "tf":
        # List of indices to update (insert from `index` forward and roll over at `self.capacity`).
        update_indices = tf.range(start=self.index, limit=self.index + num_records) % self.capacity

        # Updates all the necessary sub-variables in the record.
        record_updates = []
        for key in self.memory:
            record_updates.append(self.scatter_update_variable(
                variable=self.memory[key],
                indices=update_indices,
                updates=records[key]
            ))

        # Update indices and size.
        with tf.control_dependencies(control_inputs=record_updates):
            index_updates = [self.assign_variable(ref=self.index, value=(self.index + num_records) % self.capacity)]
            update_size = tf.minimum(x=(self.read_variable(self.size) + num_records), y=self.capacity)
            index_updates.append(self.assign_variable(self.size, value=update_size))

        # Nothing to return.
        with tf.control_dependencies(control_inputs=index_updates):
            return tf.no_op()

    elif get_backend() == "pytorch":
        update_indices = torch.arange(self.index, self.index + num_records) % self.capacity
        for key in self.memory:
            for i, val in zip(update_indices, records[key]):
                self.memory[key][i] = val
        self.index = (self.index + num_records) % self.capacity
        self.size = min(self.size + num_records, self.capacity)
        return None
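Both backends above implement the same ring-buffer arithmetic: write positions wrap around at `capacity`, and `size` saturates once the buffer is full. A minimal pure-Python sketch of that semantics, for illustration only (the `RingIndex` class and its `insert` method are hypothetical, not part of the library):

class RingIndex:
    """Tracks the write head and fill level of a fixed-capacity ring buffer."""
    def __init__(self, capacity):
        self.capacity = capacity
        self.index = 0
        self.size = 0

    def insert(self, num_records):
        # Indices to write to, rolling over at `capacity`.
        indices = [(self.index + i) % self.capacity for i in range(num_records)]
        self.index = (self.index + num_records) % self.capacity
        self.size = min(self.size + num_records, self.capacity)
        return indices

ring = RingIndex(capacity=4)
assert ring.insert(3) == [0, 1, 2]
assert ring.insert(3) == [3, 0, 1]   # Rolls over and overwrites the oldest slots.
assert ring.size == 4                # Size saturates at capacity.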
def _graph_fn_update_records(self, indices, update):
    num_records = get_batch_size(indices)
    max_priority = 0.0

    # Update has to be sequential.
    def insert_body(i, max_priority_):
        priority = tf.pow(x=update[i], y=self.alpha)

        sum_insert = self.sum_segment_tree.insert(
            index=indices[i],
            element=priority,
            insert_op=tf.add
        )
        min_insert = self.min_segment_tree.insert(
            index=indices[i],
            element=priority,
            insert_op=tf.minimum
        )
        # Keep track of current max priority element.
        max_priority_ = tf.maximum(x=max_priority_, y=priority)

        with tf.control_dependencies(control_inputs=[tf.group(sum_insert, min_insert)]):
            # TODO: This confuses the auto-return value detector.
            return i + 1, max_priority_

    def cond(i, max_priority_):
        # Loop over all `num_records` records.
        return i < num_records

    _, max_priority = tf.while_loop(cond=cond, body=insert_body, loop_vars=(0, max_priority))

    assignment = self.assign_variable(ref=self.max_priority, value=max_priority)
    with tf.control_dependencies(control_inputs=[assignment]):
        return tf.no_op()
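The while-loop mirrors what a sequential priority update does eagerly: raise each new priority to `alpha`, write it into both the sum and the min segment tree, and track the running maximum. A hedged sketch of that semantics, with a plain list standing in for the segment trees (the `NaivePriorityStore` class is illustrative only; real segment trees answer the same queries in O(log n)):

class NaivePriorityStore:
    """Plain-list stand-in for the sum/min segment trees (O(n) queries instead of O(log n))."""
    def __init__(self, capacity, alpha):
        self.alpha = alpha
        self.priorities = [0.0] * capacity
        self.max_priority = 1.0

    def update(self, indices, loss_values):
        for idx, loss in zip(indices, loss_values):
            priority = loss ** self.alpha
            self.priorities[idx] = priority
            self.max_priority = max(self.max_priority, priority)

    def priority_sum(self):   # What the sum tree answers for proportional sampling.
        return sum(self.priorities)

    def priority_min(self):   # What the min tree answers for importance-weight normalization.
        return min(p for p in self.priorities if p > 0.0)

store = NaivePriorityStore(capacity=8, alpha=0.6)
store.update(indices=[0, 3], loss_values=[2.0, 0.5])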
def _graph_fn_insert_records(self, records): num_records = get_batch_size(records["terminals"]) # List of indices to update (insert from `index` forward and roll over at `self.capacity`). update_indices = tf.range(start=self.index, limit=self.index + num_records) % self.capacity # Updates all the necessary sub-variables in the record. # update_indices = tf.Print(update_indices, [update_indices, index, num_records], summarize=100, # message='Update indices / index / num records = ') record_updates = list() for key in self.record_registry: record_updates.append(self.scatter_update_variable( variable=self.record_registry[key], indices=update_indices, updates=records[key] )) # Update indices and size. with tf.control_dependencies(control_inputs=record_updates): index_updates = list() index_updates.append(self.assign_variable(ref=self.index, value=(self.index + num_records) % self.capacity)) update_size = tf.minimum(x=(self.read_variable(self.size) + num_records), y=self.capacity) index_updates.append(self.assign_variable(self.size, value=update_size)) # Nothing to return. with tf.control_dependencies(control_inputs=index_updates): return tf.no_op()
def _graph_fn_sample(self, sample_size, inputs):
    """
    Takes a set of input tensors and uniformly samples a subset of the specified size from them.

    Args:
        sample_size (SingleDataOp[int]): Sub-sample size.
        inputs (FlattenedDataOp): Input tensors (in a FlattenedDataOp) to sample from.
            All values (tensors) must have the same batch size.

    Returns:
        FlattenedDataOp: The sub-sampled records (will be unflattened automatically).
    """
    batch_size = get_batch_size(next(iter(inputs.values())))

    if get_backend() == "tf":
        sample_indices = tf.random_uniform(shape=(sample_size,), maxval=batch_size, dtype=tf.int32)
        sample = FlattenedDataOp()
        for key, tensor in inputs.items():
            sample[key] = tf.gather(params=tensor, indices=sample_indices)
        return sample
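Because the indices are drawn independently, this is sampling with replacement: duplicates are possible, but every tensor is gathered with the same index vector, so the sampled records stay aligned across keys. A hedged NumPy sketch of the same gather-based sub-sampling (the `uniform_subsample` helper is hypothetical, not library API):

import numpy as np

def uniform_subsample(inputs, sample_size, rng=None):
    """Uniformly sample `sample_size` rows (with replacement) from each tensor in `inputs`."""
    rng = rng or np.random.default_rng()
    batch_size = len(next(iter(inputs.values())))
    sample_indices = rng.integers(0, batch_size, size=sample_size)
    # Reusing the same index vector for every key keeps records aligned across tensors.
    return {key: tensor[sample_indices] for key, tensor in inputs.items()}

batch = {"states": np.arange(10).reshape(5, 2), "terminals": np.zeros(5, dtype=bool)}
sub = uniform_subsample(batch, sample_size=3)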
def _graph_fn_insert_records(self, records): num_records = get_batch_size(records["terminals"]) index = self.read_variable(self.index) update_indices = tf.range(start=index, limit=index + num_records) % self.capacity # Updates all the necessary sub-variables in the record. record_updates = list() for key in self.record_registry: record_updates.append( self.scatter_update_variable( variable=self.record_registry[key], indices=update_indices, updates=records[key])) # Update indices and size. with tf.control_dependencies(control_inputs=record_updates): index_updates = list() index_updates.append( self.assign_variable(ref=self.index, value=(index + num_records) % self.capacity)) update_size = tf.minimum(x=(self.read_variable(self.size) + num_records), y=self.capacity) index_updates.append( self.assign_variable(self.size, value=update_size)) weight = tf.pow(x=self.max_priority, y=self.alpha) # Insert new priorities into segment tree. def insert_body(i): sum_insert = self.sum_segment_tree.insert(update_indices[i], weight, tf.add) with tf.control_dependencies(control_inputs=[sum_insert]): return i + 1 def cond(i): return i < num_records with tf.control_dependencies(control_inputs=index_updates): sum_insert = tf.while_loop(cond=cond, body=insert_body, loop_vars=[0]) def insert_body(i): min_insert = self.min_segment_tree.insert(update_indices[i], weight, tf.minimum) with tf.control_dependencies(control_inputs=[min_insert]): return i + 1 def cond(i): return i < num_records with tf.control_dependencies(control_inputs=[sum_insert]): min_insert = tf.while_loop(cond=cond, body=insert_body, loop_vars=[0]) # Nothing to return. with tf.control_dependencies(control_inputs=[min_insert]): return tf.no_op()
def _graph_fn_insert_records(self, records): if get_backend() == "tf": num_records = get_batch_size(records["terminals"]) index = self.read_variable(self.index) # Episodes before inserting these records. prev_num_episodes = self.read_variable(self.num_episodes) update_indices = tf.range( start=index, limit=index + num_records) % self.capacity # Episodes previously existing in the range we inserted to as indicated # by count of terminals in the that slice. insert_terminal_slice = self.read_variable( self.record_registry['terminals'], update_indices) # Shift episode indices. with tf.control_dependencies([ update_indices, index, prev_num_episodes, insert_terminal_slice ]): index_updates = [] # Newly inserted episodes. inserted_episodes = tf.reduce_sum(input_tensor=tf.cast( records['terminals'], dtype=tf.int32), axis=0) episodes_in_insert_range = tf.reduce_sum(input_tensor=tf.cast( insert_terminal_slice, dtype=tf.int32), axis=0) num_episode_update = prev_num_episodes - episodes_in_insert_range + inserted_episodes # Shift contiguous episode indices. index_updates.append( self.assign_variable( ref=self.episode_indices[:prev_num_episodes - episodes_in_insert_range], value=self.episode_indices[ episodes_in_insert_range:prev_num_episodes])) # Insert new episodes starting at previous count minus the ones we removed, # ending at previous count minus removed + inserted. slice_start = prev_num_episodes - episodes_in_insert_range slice_end = num_episode_update # update_indices = tf.Print(update_indices, [update_indices, tf.shape(update_indices)], # summarize=100, message='\n update indices / shape = ') # Update indices and size. with tf.control_dependencies(index_updates): index_updates = [] # Actually update indices. mask = tf.boolean_mask(tensor=update_indices, mask=records['terminals']) index_updates.append( self.assign_variable( ref=self.episode_indices[slice_start:slice_end], value=mask)) # Assign final new episode count. index_updates.append( self.assign_variable(self.num_episodes, num_episode_update)) index_updates.append( self.assign_variable(ref=self.index, value=(index + num_records) % self.capacity)) update_size = tf.minimum(x=(self.read_variable(self.size) + num_records), y=self.capacity) index_updates.append( self.assign_variable(self.size, value=update_size)) # Updates all the necessary sub-variables in the record. with tf.control_dependencies(index_updates): record_updates = [] for key in self.record_registry: record_updates.append( self.scatter_update_variable( variable=self.record_registry[key], indices=update_indices, updates=records[key])) # Nothing to return. with tf.control_dependencies(control_inputs=record_updates): return tf.no_op() elif get_backend() == "pytorch": # TODO: Unclear if we should do this in numpy and then convert to torch once we sample. num_records = get_batch_size(records["terminals"]) update_indices = torch.arange( self.index, self.index + num_records) % self.capacity # Newly inserted episodes. inserted_episodes = torch.sum(records['terminals'].int(), 0) # Episodes previously existing in the range we inserted to as indicated # by count of terminals in the that slice. episodes_in_insert_range = 0 # Count terminals in inserted range. 
for index in update_indices: episodes_in_insert_range += int( self.record_registry["terminals"][index]) num_episode_update = self.num_episodes - episodes_in_insert_range + inserted_episodes self.episode_indices[:self.num_episodes - episodes_in_insert_range] = \ self.episode_indices[episodes_in_insert_range:self.num_episodes] # Insert new episodes starting at previous count minus the ones we removed, # ending at previous count minus removed + inserted. slice_start = self.num_episodes - episodes_in_insert_range slice_end = num_episode_update byte_terminals = records["terminals"].byte() mask = torch.masked_select(update_indices, byte_terminals) self.episode_indices[slice_start:slice_end] = mask # Update indices. self.num_episodes = int(num_episode_update) self.index = (self.index + num_records) % self.capacity self.size = min(self.size + num_records, self.capacity) # Updates all the necessary sub-variables in the record. for key in self.record_registry: for i, val in zip(update_indices, records[key]): self.record_registry[key][i] = val # The TF version returns no-op, return None so return-val inference system does not throw error. return None
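Both backends follow the same episode-bookkeeping recipe: count terminals in the slice being overwritten (episodes that fall out of the buffer), shift the surviving episode-end indices to the front, and append the end indices of the newly inserted episodes. A minimal eager sketch under those assumptions (`update_episode_indices` is an illustrative helper, not library API):

def update_episode_indices(episode_indices, num_episodes, old_terminals_in_slice,
                           update_indices, new_terminals):
    """Shift out overwritten episode ends, then append the new ones; return new count."""
    removed = sum(old_terminals_in_slice)   # Episodes overwritten by this insert.
    inserted = sum(new_terminals)           # Episodes ending within the new records.
    surviving = num_episodes - removed
    # Shift surviving episode-end indices to the front of the index array.
    episode_indices[:surviving] = episode_indices[removed:num_episodes]
    # Append end indices of the newly inserted episodes.
    new_ends = [idx for idx, term in zip(update_indices, new_terminals) if term]
    episode_indices[surviving:surviving + inserted] = new_ends
    return surviving + inserted

eps = [2, 5, 0, 0, 0, 0]   # End indices of 2 stored episodes (capacity 6).
n = update_episode_indices(eps, 2, old_terminals_in_slice=[True, False],
                           update_indices=[0, 1], new_terminals=[False, True])
assert eps[:2] == [5, 1] and n == 2   # Episode ending at 2 was overwritten; a new one ends at 1.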
def _graph_fn_insert_records(self, records): num_records = get_batch_size(records["terminals"]) index = self.read_variable(self.index) update_indices = tf.range(start=index, limit=index + num_records) % self.capacity # update_indices = tf.Print(update_indices, [index, num_records, update_indices], # summarize=100, message='index|num|indices') # update_indices = tf.Print(update_indices, [tf.shape(update_indices), # tf.shape(records["terminals"])], # summarize=100, message='shape indices|shape recods') # Update indices and size. with tf.control_dependencies([update_indices]): index_updates = [] if self.episode_semantics: # Episodes before inserting these records. prev_num_episodes = self.read_variable(self.num_episodes) # Newly inserted episodes. inserted_episodes = tf.reduce_sum(input_tensor=tf.cast(records['terminals'], dtype=tf.int32), axis=0) # Episodes previously existing in the range we inserted to as indicated # by count of terminals in the that slice. insert_terminal_slice = self.read_variable(self.record_registry['terminals'], update_indices) episodes_in_insert_range = tf.reduce_sum( input_tensor=tf.cast(insert_terminal_slice, dtype=tf.int32), axis=0 ) # prev_num_episodes = tf.Print(prev_num_episodes, [ # prev_num_episodes, # episodes_in_insert_range, # inserted_episodes], # summarize=100, message='previous num eps / prev episodes in insert range / inserted eps = ' # ) num_episode_update = prev_num_episodes - episodes_in_insert_range + inserted_episodes # prev_num_episodes = tf.Print(prev_num_episodes, [prev_num_episodes, episodes_in_insert_range], # summarize=100, message='num eps, eps in insert range =') # Remove previous episodes in inserted range. index_updates.append(self.assign_variable( ref=self.episode_indices[:prev_num_episodes + 1 - episodes_in_insert_range], value=self.episode_indices[episodes_in_insert_range:prev_num_episodes + 1] )) # Insert new episodes starting at previous count minus the ones we removed, # ending at previous count minus removed + inserted. slice_start = prev_num_episodes - episodes_in_insert_range slice_end = num_episode_update # update_indices = tf.Print(update_indices, [update_indices, tf.shape(update_indices)], # summarize=100, message='\n update indices / shape = ') # slice_start = tf.Print( # slice_start, [slice_start, slice_end, self.episode_indices], # summarize=100, # message='\n slice start/ slice end / episode indices before = ' # ) with tf.control_dependencies(index_updates): index_updates = [] mask = tf.boolean_mask(tensor=update_indices, mask=records['terminals']) # mask = tf.Print(mask, [mask, update_indices, records['/terminals']], summarize=100, # message='\n mask / update indices / records-terminal') index_updates.append(self.assign_variable( ref=self.episode_indices[slice_start:slice_end], value=mask )) # num_episode_update = tf.Print(num_episode_update, [num_episode_update, self.episode_indices], # summarize=100, message='\n num episodes / episode indices after: ') # Assign final new episode count. index_updates.append(self.assign_variable(self.num_episodes, num_episode_update)) index_updates.append(self.assign_variable(ref=self.index, value=(index + num_records) % self.capacity)) update_size = tf.minimum(x=(self.read_variable(self.size) + num_records), y=self.capacity) index_updates.append(self.assign_variable(self.size, value=update_size)) # Updates all the necessary sub-variables in the record. 
with tf.control_dependencies(index_updates): record_updates = [] for key in self.record_registry: record_updates.append(self.scatter_update_variable( variable=self.record_registry[key], indices=update_indices, updates=records[key] )) # Nothing to return. with tf.control_dependencies(control_inputs=record_updates): return tf.no_op()