예제 #1
0
def main(unused_argv: Sequence[Text]) -> None:
    """Entry point that evaluates the greedy policy.

    Parses the gin configuration, waits until the greedy policy SavedModel is
    available under `FLAGS.root_dir`, connects to the Reverb variable
    container and hands everything to `evaluate`.

    Args:
      unused_argv: Command-line arguments remaining after flag parsing;
        ignored.
    """
    logging.set_verbosity(logging.INFO)
    tf.enable_v2_behavior()

    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)

    root = FLAGS.root_dir

    # Block until the greedy policy has been exported, then load it.
    greedy_policy = train_utils.wait_for_policy(
        os.path.join(root, learner.POLICY_SAVED_MODEL_DIR,
                     learner.GREEDY_POLICY_SAVED_MODEL_DIR),
        load_specs_from_pbtxt=True)

    # The greedy policy's weights are periodically refreshed from this
    # container.
    weights_container = reverb_variable_container.ReverbVariableContainer(
        FLAGS.variable_container_server_address,
        table_names=[reverb_variable_container.DEFAULT_TABLE])

    eval_summary_dir = os.path.join(root, learner.TRAIN_DIR, 'eval')

    # `gin.REQUIRED` means the environment name must come from the gin config.
    evaluate(summary_dir=eval_summary_dir,
             environment_name=gin.REQUIRED,
             policy=greedy_policy,
             variable_container=weights_container)
예제 #2
0
def run_eval(
    root_dir: Text,
    # TODO(b/178225158): Deprecate in favor of the reporting library when ready.
    return_reporting_fn: Optional[Callable[[int, float], None]] = None
) -> None:
    """Load the greedy policy and evaluate it.

    Args:
      root_dir: The root directory for this experiment. Both the greedy policy
        SavedModel and the evaluation summary directory are resolved relative
        to it.
      return_reporting_fn: Optional callback function of the form
        `fn(train_step, average_return)` which reports the average return to a
        custom destination.
    """
    # Wait for the greedy policy to become available, then load it.
    greedy_policy_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR,
                                     learner.GREEDY_POLICY_SAVED_MODEL_DIR)
    policy = train_utils.wait_for_policy(greedy_policy_dir,
                                         load_specs_from_pbtxt=True)

    # Create the variable container. The weights of the greedy policy are
    # updated from it periodically.
    variable_container = reverb_variable_container.ReverbVariableContainer(
        FLAGS.variable_container_server_address,
        table_names=[reverb_variable_container.DEFAULT_TABLE])

    # Prepare the per-task summary directory. Build it from the `root_dir`
    # argument rather than `FLAGS.root_dir`, so the summaries belong to the
    # same experiment as the policy loaded above even when the caller passes
    # a root that differs from the flag.
    summary_dir = os.path.join(root_dir, learner.TRAIN_DIR, 'eval',
                               str(FLAGS.task))

    # Run the evaluation.
    evaluate(summary_dir=summary_dir,
             environment_name=gin.REQUIRED,
             policy=policy,
             variable_container=variable_container,
             return_reporting_fn=return_reporting_fn)
예제 #3
0
def run_eval(root_dir: Text) -> None:
    """Wait for the greedy policy under `root_dir`, then evaluate it.

    Args:
      root_dir: Experiment root directory. The greedy policy SavedModel path
        and the evaluation summary directory are both derived from it.
    """
    eval_policy = train_utils.wait_for_policy(
        os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR,
                     learner.GREEDY_POLICY_SAVED_MODEL_DIR),
        load_specs_from_pbtxt=True)

    # Source of the periodic weight updates for the greedy policy.
    weights_source = reverb_variable_container.ReverbVariableContainer(
        FLAGS.variable_container_server_address,
        table_names=[reverb_variable_container.DEFAULT_TABLE])

    evaluate(
        summary_dir=os.path.join(root_dir, learner.TRAIN_DIR, 'eval'),
        environment_name=gin.REQUIRED,  # Must be supplied through gin.
        policy=eval_policy,
        variable_container=weights_source)
예제 #4
0
def main(unused_argv: Sequence[Text]) -> None:
  """Runs an `EvaluatorWorker` for this node against the greedy policy.

  Args:
    unused_argv: Command-line arguments remaining after flag parsing; ignored.
  """
  logging.set_verbosity(logging.INFO)

  root = FLAGS.root_dir
  node_id = FLAGS.node_id

  # Each node writes its summaries into its own `eval/<node_id>` subdirectory.
  summary_dir = os.path.join(root, learner.TRAIN_DIR, 'eval', str(node_id))
  checkpoint_dir = os.path.join(root, learner.TRAIN_DIR,
                                learner.POLICY_CHECKPOINT_DIR)

  # Block until the greedy policy has been exported, then load it.
  greedy_policy = train_utils.wait_for_policy(
      os.path.join(root, learner.POLICY_SAVED_MODEL_DIR,
                   learner.GREEDY_POLICY_SAVED_MODEL_DIR),
      load_specs_from_pbtxt=True)

  # NOTE(review): episode count and step limit are hard-coded here; consider
  # exposing them as flags if they need tuning.
  worker = EvaluatorWorker(
      summary_dir,
      checkpoint_dir,
      greedy_policy,
      node_id=node_id,
      env_name=FLAGS.env_name,
      num_eval_episodes=5,
      max_train_step=1000,
  )
  worker.run()
예제 #5
0
def main(_):
  """Entry point for a collect job.

  Parses the gin configuration, waits until the collect policy SavedModel is
  available under `FLAGS.root_dir`, then runs `collect`. The summary directory
  is per task, and the Reverb server addresses are taken from flags.
  """
  logging.set_verbosity(logging.INFO)
  tf.enable_v2_behavior()

  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)

  root = FLAGS.root_dir

  # Block until the collect policy has been exported, then load it.
  policy_path = os.path.join(root, learner.POLICY_SAVED_MODEL_DIR,
                             learner.COLLECT_POLICY_SAVED_MODEL_DIR)
  policy = train_utils.wait_for_policy(policy_path, load_specs_from_pbtxt=True)

  collect(
      summary_dir=os.path.join(root, learner.TRAIN_DIR, str(FLAGS.task)),
      environment_name=gin.REQUIRED,  # Must be supplied through gin.
      collect_policy=policy,
      replay_buffer_server_address=FLAGS.replay_buffer_server_address,
      variable_container_server_address=(
          FLAGS.variable_container_server_address))
예제 #6
0
def main(_):
    """Starts a Reverb server with a replay-buffer table and a variable table.

    Waits for the collect policy so the table signatures can be derived from
    its variables and its `collect_data_spec`. Then chooses a rate limiter for
    the experience table, serves both tables on `FLAGS.port`, and blocks in
    `server.wait()`.
    """
    logging.set_verbosity(logging.INFO)

    # The collect policy provides the variables and data spec that define the
    # table signatures.
    policy_path = os.path.join(FLAGS.root_dir,
                               learner.POLICY_SAVED_MODEL_DIR,
                               learner.COLLECT_POLICY_SAVED_MODEL_DIR)
    collect_policy = train_utils.wait_for_policy(policy_path,
                                                 load_specs_from_pbtxt=True)

    spi = FLAGS.samples_per_insert
    min_size = FLAGS.min_table_size_before_sampling

    # Variable-container signature: the policy weights plus the train step.
    step_counter = train_utils.create_train_step()
    container_variables = {
        reverb_variable_container.POLICY_KEY: collect_policy.variables(),
        reverb_variable_container.TRAIN_STEP_KEY: step_counter,
    }

    def _spec_of(var):
        return tf.TensorSpec(var.shape, dtype=var.dtype)

    variable_container_signature = tf.nest.map_structure(_spec_of,
                                                         container_variables)
    logging.info('Signature of variables: \n%s', variable_container_signature)

    # Experience signature: every collect-data spec gains an unknown leading
    # dimension.
    def _prepend_unknown_dim(spec):
        return tf.TensorSpec((None, ) + spec.shape, spec.dtype, spec.name)

    replay_buffer_signature = tf.nest.map_structure(
        _prepend_unknown_dim,
        tensor_spec.from_spec(collect_policy.collect_data_spec))
    logging.info('Signature of experience: \n%s', replay_buffer_signature)

    if spi is None:
        # Sampling is only gated on the table holding `min_size` items.
        experience_rate_limiter = reverb.rate_limiters.MinSize(min_size)
    else:
        # Keep the samples-per-insert ratio near `spi`. The allowed deviation
        # scales with both `spi` and the minimum table size.
        # NOTE(review): Reverb may reject an error_buffer that is too small
        # relative to samples_per_insert (e.g. for a tiny min_size); confirm.
        tolerance = _SAMPLES_PER_INSERT_TOLERANCE_RATIO * spi
        experience_rate_limiter = reverb.rate_limiters.SampleToInsertRatio(
            min_size_to_sample=min_size,
            samples_per_insert=spi,
            error_buffer=min_size * tolerance)

    # Replay buffer storing experience; max_times_sampled=0 means no limit.
    experience_table = reverb.Table(
        name=reverb_replay_buffer.DEFAULT_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=experience_rate_limiter,
        max_size=FLAGS.replay_buffer_capacity,
        max_times_sampled=0,
        signature=replay_buffer_signature,
    )
    # Variable container storing policy parameters; holds a single entry.
    variable_table = reverb.Table(
        name=reverb_variable_container.DEFAULT_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        max_size=1,
        max_times_sampled=0,
        signature=variable_container_signature,
    )

    # Create and start the server hosting both tables.
    server = reverb.Server(tables=[experience_table, variable_table],
                           port=FLAGS.port)
    server.wait()