def testAddOuterDimNoneToSpecs(self, dtype):
  """Checks `add_outer_dim` when the requested outer dimension is `None`.

  The result must equal the example nested spec built directly with
  `(None,)` as its second argument. String dtypes are skipped.

  Args:
    dtype: The dtype used to build the example nested spec.
  """
  if dtype == tf.string:
    self.skipTest("Not compatible with string type.")

  unknown_dim = None
  base_spec = example_nested_tensor_spec(dtype)
  actual_spec = tensor_spec.add_outer_dim(base_spec, unknown_dim)
  expected_spec = example_nested_tensor_spec(dtype, (unknown_dim,))
  self.assertEqual(actual_spec, expected_spec)
def create_reverb_server_for_replay_buffer_and_variable_container(
    collect_policy, train_step, replay_buffer_capacity, port):
  """Builds a Reverb server with replay buffer and variable container tables.

  Args:
    collect_policy: Policy whose `variables()` and `collect_data_spec` are used
      to derive the two table signatures.
    train_step: Train step stored next to the policy variables. It must expose
      `shape` and `dtype`, as must every policy variable.
    replay_buffer_capacity: `max_size` of the replay buffer table.
    port: Port handed to `reverb.Server`.

  Returns:
    The `reverb.Server` instance hosting both tables.
  """
  # Signature of the variable container: one TensorSpec per stored variable
  # (policy weights plus the train step).
  variable_container_signature = tf.nest.map_structure(
      lambda var: tf.TensorSpec(var.shape, dtype=var.dtype),
      {
          reverb_variable_container.POLICY_KEY: collect_policy.variables(),
          reverb_variable_container.TRAIN_STEP_KEY: train_step,
      })

  # Signature of the observed experience, with an extra outer dimension.
  replay_buffer_signature = tensor_spec.add_outer_dim(
      tensor_spec.from_spec(collect_policy.collect_data_spec))

  # Replay buffer storing experience.
  experience_table = reverb.Table(
      name=reverb_replay_buffer.DEFAULT_TABLE,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      # TODO(b/159073060): Set rate limiter for SAC properly.
      rate_limiter=reverb.rate_limiters.MinSize(1),
      max_size=replay_buffer_capacity,
      max_times_sampled=0,
      signature=replay_buffer_signature,
  )

  # Variable container storing policy parameters; holds a single item.
  variable_table = reverb.Table(
      name=reverb_variable_container.DEFAULT_TABLE,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1),
      max_size=1,
      max_times_sampled=0,
      signature=variable_container_signature,
  )

  # Create the server that hosts both tables.
  return reverb.Server(tables=[experience_table, variable_table], port=port)
while not time_step.is_last(): action_step = policy.action(time_step) time_step = environment.step(action_step.action) episode_return += time_step.reward total_return += episode_return avg_return = total_return / num_episodes return avg_return.numpy()[0] # Standard implementations for evaluation metrics in the metrics module. table_name = 'uniform_table' replay_buffer_signature = tensor_spec.from_spec(tf_agent.collect_data_spec) replay_buffer_signature = tensor_spec.add_outer_dim(replay_buffer_signature) table = reverb.Table(table_name, max_size=replay_buffer_capacity, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1), signature=replay_buffer_signature) reverb_server = reverb.Server([table]) replay_buffer = reverb_replay_buffer.ReverbReplayBuffer( tf_agent.collect_data_spec, table_name=table_name, sequence_length=None, local_server=reverb_server)
def main(_):
  """Hosts the replay buffer and variable container on one Reverb server.

  Loads the collect policy from under `FLAGS.root_dir`, derives the table
  signatures from it, picks the experience rate limiter from the flags, then
  creates the server on `FLAGS.port` and waits on it.

  Args:
    _: Unused positional argument.
  """
  logging.set_verbosity(logging.INFO)

  # Wait for the collect policy to become available, then load it.
  collect_policy = train_utils.wait_for_policy(
      os.path.join(FLAGS.root_dir,
                   learner.POLICY_SAVED_MODEL_DIR,
                   learner.COLLECT_POLICY_SAVED_MODEL_DIR),
      load_specs_from_pbtxt=True)

  samples_per_insert = FLAGS.samples_per_insert
  min_table_size_before_sampling = FLAGS.min_table_size_before_sampling

  # Signature of the variable container: one TensorSpec per stored variable
  # (policy weights plus the train step).
  train_step = train_utils.create_train_step()
  variable_container_signature = tf.nest.map_structure(
      lambda var: tf.TensorSpec(var.shape, dtype=var.dtype),
      {
          reverb_variable_container.POLICY_KEY: collect_policy.variables(),
          reverb_variable_container.TRAIN_STEP_KEY: train_step,
      })
  logging.info('Signature of variables: \n%s', variable_container_signature)

  # Signature of the observed experience, with an extra outer dimension.
  replay_buffer_signature = tensor_spec.add_outer_dim(
      tensor_spec.from_spec(collect_policy.collect_data_spec))
  logging.info('Signature of experience: \n%s', replay_buffer_signature)

  if samples_per_insert is None:
    # No ratio requested: only require a minimum table size before sampling.
    experience_rate_limiter = reverb.rate_limiters.MinSize(
        min_table_size_before_sampling)
  else:
    # Keep the sample/insert ratio near the requested value; the permitted
    # deviation scales with both the ratio and the minimum table size.
    tolerance = _SAMPLES_PER_INSERT_TOLERANCE_RATIO * samples_per_insert
    experience_rate_limiter = reverb.rate_limiters.SampleToInsertRatio(
        min_size_to_sample=min_table_size_before_sampling,
        samples_per_insert=samples_per_insert,
        error_buffer=min_table_size_before_sampling * tolerance)

  # Replay buffer storing experience.
  experience_table = reverb.Table(
      name=reverb_replay_buffer.DEFAULT_TABLE,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=experience_rate_limiter,
      max_size=FLAGS.replay_buffer_capacity,
      max_times_sampled=0,
      signature=replay_buffer_signature,
  )

  # Variable container storing policy parameters; holds a single item.
  variable_table = reverb.Table(
      name=reverb_variable_container.DEFAULT_TABLE,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1),
      max_size=1,
      max_times_sampled=0,
      signature=variable_container_signature,
  )

  # Create the server hosting both tables, then wait on it.
  server = reverb.Server(
      tables=[experience_table, variable_table], port=FLAGS.port)
  server.wait()
def get_reverb_buffer(data_spec,
                      sequence_length=None,
                      table_name='uniform_table',
                      table=None,
                      reverb_server_address=None,
                      port=None,
                      replay_capacity=1000,
                      min_size_limiter_size=1):
  """Returns a Reverb replay buffer backed by a local or a remote server.

  When `reverb_server_address` is set, the buffer connects to that remote
  server. Otherwise a local server is created, backed by `table` or, if
  `table` is None, by an automatically created uniform sampling table.

  Args:
    data_spec: Spec of the data elements to be stored in the replay buffer.
    sequence_length: Integer sequence length used to write to the given table.
    table_name: Name of the table to create.
    table: Optional table for the backing local server. If None, a uniform
      sampling table is created automatically.
    reverb_server_address: Address of the remote Reverb server. If None, a
      local server is created.
    port: Port to launch the local server on.
    replay_capacity: Optional capacity of the default uniform sampling table.
      Only used for the local server when `table` is None.
    min_size_limiter_size: Optional minimum number of items required in the
      default uniform sampling table before sampling can begin. Only used for
      the local server when `table` is None.

  Returns:
    The Reverb replay buffer instance. If a local server was created it is not
    returned; it can be retrieved by calling `local_server()` on the returned
    replay buffer.
  """
  table_signature = tensor_spec.add_outer_dim(data_spec, sequence_length)

  # Remote server: nothing to create locally, just connect to it.
  if reverb_server_address is not None:
    return reverb_replay_buffer.ReverbReplayBuffer(
        data_spec,
        sequence_length=sequence_length,
        table_name=table_name,
        server_address=reverb_server_address)

  # Local server: fall back to a default uniform table when none was given.
  if table is None:
    table = _create_uniform_table(
        table_name,
        table_signature,
        table_capacity=replay_capacity,
        min_size_limiter_size=min_size_limiter_size)

  local_server = reverb.Server([table], port=port)
  return reverb_replay_buffer.ReverbReplayBuffer(
      data_spec,
      sequence_length=sequence_length,
      table_name=table_name,
      local_server=local_server)