コード例 #1
0
ファイル: tensor_spec_test.py プロジェクト: sparshag21/agents
 def testAddOuterDimNoneToSpecs(self, dtype):
   """add_outer_dim(spec, None) must equal the spec built with outer shape (None,)."""
   if dtype == tf.string:
     self.skipTest("Not compatible with string type.")
   # Use None as the outer dimension being prepended.
   missing_dim = None
   actual_spec = tensor_spec.add_outer_dim(
       example_nested_tensor_spec(dtype), missing_dim)
   expected_spec = example_nested_tensor_spec(dtype, (missing_dim,))
   self.assertEqual(actual_spec, expected_spec)
コード例 #2
0
ファイル: test_utils.py プロジェクト: tensorflow/agents
def create_reverb_server_for_replay_buffer_and_variable_container(
        collect_policy, train_step, replay_buffer_capacity, port):
    """Creates and starts one Reverb server hosting two tables.

    The first table is the replay buffer for observed experience; the second
    is the variable container holding the collect policy's variables together
    with the train step.

    Args:
      collect_policy: Policy whose `variables()` and `collect_data_spec`
        determine the two table signatures.
      train_step: Train-step variable stored alongside the policy variables.
      replay_buffer_capacity: `max_size` of the experience table.
      port: Port handed to `reverb.Server`.

    Returns:
      The `reverb.Server` instance.
    """
    def _spec_of(variable):
        # Shape/dtype-only description of a variable for the table signature.
        return tf.TensorSpec(variable.shape, dtype=variable.dtype)

    # Variable container signature: policy weights plus the train step.
    weights_signature = tf.nest.map_structure(
        _spec_of,
        {
            reverb_variable_container.POLICY_KEY: collect_policy.variables(),
            reverb_variable_container.TRAIN_STEP_KEY: train_step,
        })

    # Experience signature: the collect data spec with an added outer dimension.
    experience_signature = tensor_spec.add_outer_dim(
        tensor_spec.from_spec(collect_policy.collect_data_spec))

    experience_table = reverb.Table(
        name=reverb_replay_buffer.DEFAULT_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        # TODO(b/159073060): Set rate limiter for SAC properly.
        rate_limiter=reverb.rate_limiters.MinSize(1),
        max_size=replay_buffer_capacity,
        max_times_sampled=0,
        signature=experience_signature,
    )
    # Holds a single item: the most recently pushed policy parameters.
    weights_table = reverb.Table(
        name=reverb_variable_container.DEFAULT_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        max_size=1,
        max_times_sampled=0,
        signature=weights_signature,
    )

    return reverb.Server(tables=[experience_table, weights_table], port=port)
コード例 #3
0
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return

    avg_return = total_return / num_episodes
    return avg_return.numpy()[0]


# See also the metrics module for standard implementations of evaluation metrics.

# Replay buffer backed by a local Reverb server with one uniform-sampling,
# FIFO-eviction table. The table signature is the agent's collect data spec
# with an extra outer dimension.
table_name = 'uniform_table'
replay_buffer_signature = tensor_spec.add_outer_dim(
    tensor_spec.from_spec(tf_agent.collect_data_spec))

table = reverb.Table(
    table_name,
    max_size=replay_buffer_capacity,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=replay_buffer_signature)

# The server hosts only `table`; the replay buffer is attached to it locally.
reverb_server = reverb.Server([table])

replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
    tf_agent.collect_data_spec,
    table_name=table_name,
    sequence_length=None,
    local_server=reverb_server)
コード例 #4
0
def main(_):
    """Serves the replay buffer and the variable container from one Reverb server.

    Waits for the collect policy saved model under `FLAGS.root_dir`, derives
    both table signatures from it, picks the experience rate limiter from
    `FLAGS.samples_per_insert`, then blocks in `server.wait()`.
    """
    logging.set_verbosity(logging.INFO)

    # Block until the collect policy has been exported, then load it.
    policy_path = os.path.join(FLAGS.root_dir,
                               learner.POLICY_SAVED_MODEL_DIR,
                               learner.COLLECT_POLICY_SAVED_MODEL_DIR)
    collect_policy = train_utils.wait_for_policy(
        policy_path, load_specs_from_pbtxt=True)

    ratio = FLAGS.samples_per_insert
    min_size = FLAGS.min_table_size_before_sampling

    # Variable container signature: policy weights plus the train step.
    step = train_utils.create_train_step()
    weights_signature = tf.nest.map_structure(
        lambda var: tf.TensorSpec(var.shape, dtype=var.dtype),
        {
            reverb_variable_container.POLICY_KEY: collect_policy.variables(),
            reverb_variable_container.TRAIN_STEP_KEY: step,
        })
    logging.info('Signature of variables: \n%s', weights_signature)

    # Experience signature: the collect data spec with an added outer dimension.
    experience_signature = tensor_spec.add_outer_dim(
        tensor_spec.from_spec(collect_policy.collect_data_spec))
    logging.info('Signature of experience: \n%s', experience_signature)

    if ratio is None:
        # No ratio requested: only require a minimum table size before sampling.
        experience_limiter = reverb.rate_limiters.MinSize(min_size)
    else:
        # Samples-per-insert limiting; error_buffer = min_size * (ratio tolerance).
        tolerance = _SAMPLES_PER_INSERT_TOLERANCE_RATIO * ratio
        experience_limiter = reverb.rate_limiters.SampleToInsertRatio(
            min_size_to_sample=min_size,
            samples_per_insert=ratio,
            error_buffer=min_size * tolerance)

    experience_table = reverb.Table(
        name=reverb_replay_buffer.DEFAULT_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=experience_limiter,
        max_size=FLAGS.replay_buffer_capacity,
        max_times_sampled=0,
        signature=experience_signature,
    )
    # Holds a single item: the most recently pushed policy parameters.
    weights_table = reverb.Table(
        name=reverb_variable_container.DEFAULT_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        max_size=1,
        max_times_sampled=0,
        signature=weights_signature,
    )

    # Create and start the server, then block on it.
    server = reverb.Server(tables=[experience_table, weights_table],
                           port=FLAGS.port)
    server.wait()
コード例 #5
0
def get_reverb_buffer(data_spec,
                      sequence_length=None,
                      table_name='uniform_table',
                      table=None,
                      reverb_server_address=None,
                      port=None,
                      replay_capacity=1000,
                      min_size_limiter_size=1):
    """Returns an instance of a Reverb replay buffer.

    Either creates a local Reverb server or connects to a remote Reverb server
    at `reverb_server_address` (if set).

    If `reverb_server_address` is None, a local server is created on `port`.
    It hosts `table` if one is given; otherwise it hosts a uniform sampling
    table created by `_create_uniform_table`.

    Args:
      data_spec: Spec of the data elements to be stored in the replay buffer.
      sequence_length: Integer specifying the sequence length used to write to
        the given table. It is also used as the outer dimension of the default
        table's signature.
      table_name: Name of the table the replay buffer reads from and writes to.
        When `table` is None, this is also the name of the table that is
        created. When `table` is given, it must equal that table's name,
        because the buffer addresses the table by `table_name`.
      table: Optional table for the backing local server. If None, a uniform
        sampling table is created automatically. Ignored when
        `reverb_server_address` is set.
      reverb_server_address: Address of the remote Reverb server. If None, a
        local server is created.
      port: Port to launch the local server on (local server only).
      replay_capacity: Optional (default uniform sampling table only, i.e. if
        table=None) capacity of the uniform sampling table for the local
        replay server.
      min_size_limiter_size: Optional (default uniform sampling table only,
        i.e. if table=None) minimum number of items required in the replay
        buffer before sampling can begin. Used for the local server only.

    Returns:
      A Reverb replay buffer instance.

      Note: if a local server is created, it is not returned. It can be
      retrieved by calling local_server() on the returned replay buffer.
    """
    if reverb_server_address is not None:
        # Remote mode: `table`, `port`, `replay_capacity` and
        # `min_size_limiter_size` do not apply.
        return reverb_replay_buffer.ReverbReplayBuffer(
            data_spec,
            sequence_length=sequence_length,
            table_name=table_name,
            server_address=reverb_server_address)

    if table is None:
        # The signature is only needed when we build the default table.
        table_signature = tensor_spec.add_outer_dim(data_spec, sequence_length)
        table = _create_uniform_table(
            table_name,
            table_signature,
            table_capacity=replay_capacity,
            min_size_limiter_size=min_size_limiter_size)

    reverb_server = reverb.Server([table], port=port)
    return reverb_replay_buffer.ReverbReplayBuffer(
        data_spec,
        sequence_length=sequence_length,
        table_name=table_name,
        local_server=reverb_server)