# Imports assumed by the snippets in this section; module paths follow the
# tf_agents, reverb, absl, and gin packages these identifiers come from.
import os
from typing import Callable, Optional, Sequence, Text

from absl import flags
from absl import logging
import gin
import reverb
import tensorflow.compat.v2 as tf

from tf_agents.experimental.distributed import reverb_variable_container
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.specs import tensor_spec
from tf_agents.train import learner
from tf_agents.train import train_utils

FLAGS = flags.FLAGS


def main(unused_argv: Sequence[Text]) -> None:
  logging.set_verbosity(logging.INFO)
  tf.enable_v2_behavior()
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)

  # Wait for the greedy policy to become available, then load it.
  greedy_policy_dir = os.path.join(FLAGS.root_dir,
                                   learner.POLICY_SAVED_MODEL_DIR,
                                   learner.GREEDY_POLICY_SAVED_MODEL_DIR)
  policy = train_utils.wait_for_policy(
      greedy_policy_dir, load_specs_from_pbtxt=True)

  # Create the variable container. The weights of the greedy policy are
  # updated from it periodically.
  variable_container = reverb_variable_container.ReverbVariableContainer(
      FLAGS.variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])

  # Run the evaluation.
  evaluate(
      summary_dir=os.path.join(FLAGS.root_dir, learner.TRAIN_DIR, 'eval'),
      environment_name=gin.REQUIRED,
      policy=policy,
      variable_container=variable_container)

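
# `evaluate` itself is not shown in this section. The sketch below is an
# illustrative assumption inferred from the arguments passed to it, not the
# actual implementation: pull the latest weights and train step from the
# variable container, run the greedy policy for a few episodes with an eval
# Actor, and write the resulting metrics as summaries. `suite_gym`,
# `num_eval_episodes`, and `evaluate_sketch` are illustrative names.
from tf_agents.environments import suite_gym
from tf_agents.train import actor


def evaluate_sketch(summary_dir, environment_name, policy, variable_container,
                    num_eval_episodes=10):
  environment = suite_gym.load(environment_name)
  train_step = train_utils.create_train_step()
  variables = {
      reverb_variable_container.POLICY_KEY: policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step,
  }
  eval_actor = actor.Actor(
      environment,
      policy,
      train_step,
      episodes_per_run=num_eval_episodes,
      summary_dir=summary_dir,
      metrics=actor.eval_metrics(num_eval_episodes))
  # A real implementation would stop once `train_step` reaches its final
  # value; this sketch simply evaluates forever.
  while True:
    variable_container.update(variables)  # Refresh weights and train step.
    eval_actor.run()
    eval_actor.log_metrics()
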
def run_eval(
    root_dir: Text,
    # TODO(b/178225158): Deprecate in favor of the reporting library when
    # ready.
    return_reporting_fn: Optional[Callable[[int, float], None]] = None
) -> None:
  """Load the policy and evaluate it.

  Args:
    root_dir: the root directory for this experiment.
    return_reporting_fn: Optional callback function of the form
      `fn(train_step, average_return)` which reports the average return to a
      custom destination.
  """
  # Wait for the greedy policy to become available, then load it.
  greedy_policy_dir = os.path.join(root_dir,
                                   learner.POLICY_SAVED_MODEL_DIR,
                                   learner.GREEDY_POLICY_SAVED_MODEL_DIR)
  policy = train_utils.wait_for_policy(
      greedy_policy_dir, load_specs_from_pbtxt=True)

  # Create the variable container. The weights of the greedy policy are
  # updated from it periodically.
  variable_container = reverb_variable_container.ReverbVariableContainer(
      FLAGS.variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])

  # Prepare the summary directory.
  summary_dir = os.path.join(root_dir, learner.TRAIN_DIR, 'eval',
                             str(FLAGS.task))

  # Run the evaluation.
  evaluate(
      summary_dir=summary_dir,
      environment_name=gin.REQUIRED,
      policy=policy,
      variable_container=variable_container,
      return_reporting_fn=return_reporting_fn)

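
# An example of driving `run_eval` with a custom reporting callback. The
# callback signature matches the docstring above; logging at INFO is just a
# stand-in for a real reporting destination, and `report_average_return` and
# `main_with_reporting` are illustrative names.
def report_average_return(train_step: int, average_return: float) -> None:
  logging.info('step=%d average_return=%.3f', train_step, average_return)


def main_with_reporting(unused_argv: Sequence[Text]) -> None:
  run_eval(FLAGS.root_dir, return_reporting_fn=report_average_return)
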
def run_eval(root_dir: Text) -> None:
  """Load the policy and evaluate it."""
  # Wait for the greedy policy to become available, then load it.
  greedy_policy_dir = os.path.join(root_dir,
                                   learner.POLICY_SAVED_MODEL_DIR,
                                   learner.GREEDY_POLICY_SAVED_MODEL_DIR)
  policy = train_utils.wait_for_policy(
      greedy_policy_dir, load_specs_from_pbtxt=True)

  # Create the variable container. The weights of the greedy policy are
  # updated from it periodically.
  variable_container = reverb_variable_container.ReverbVariableContainer(
      FLAGS.variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])

  # Run the evaluation.
  evaluate(
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR, 'eval'),
      environment_name=gin.REQUIRED,
      policy=policy,
      variable_container=variable_container)

def main(unused_argv: Sequence[Text]) -> None:
  logging.set_verbosity(logging.INFO)

  summary_dir = os.path.join(FLAGS.root_dir, learner.TRAIN_DIR, 'eval',
                             str(FLAGS.node_id))
  policy_dir = os.path.join(FLAGS.root_dir,
                            learner.POLICY_SAVED_MODEL_DIR,
                            learner.GREEDY_POLICY_SAVED_MODEL_DIR)
  checkpoint_dir = os.path.join(FLAGS.root_dir, learner.TRAIN_DIR,
                                learner.POLICY_CHECKPOINT_DIR)

  # Wait for the greedy policy to become available, then load it.
  policy = train_utils.wait_for_policy(policy_dir, load_specs_from_pbtxt=True)

  eval_worker = EvaluatorWorker(
      summary_dir,
      checkpoint_dir,
      policy,
      node_id=FLAGS.node_id,
      env_name=FLAGS.env_name,
      num_eval_episodes=5,
      max_train_step=1000)
  eval_worker.run()

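
# All of these entry points read configuration from absl flags. A sketch of
# the definitions they rely on follows; the flag names are taken from the
# `FLAGS.*` usages in this section, but the defaults and help strings are
# illustrative assumptions.
flags.DEFINE_string('root_dir', None, 'Root directory of the experiment.')
flags.DEFINE_string('env_name', None, 'Name of the environment.')
flags.DEFINE_integer('node_id', 0, 'Unique id of this evaluator node.')
flags.DEFINE_integer('task', 0, 'Identifier of this collect/eval task.')
flags.DEFINE_multi_string('gin_file', None, 'Paths to gin config files.')
flags.DEFINE_multi_string('gin_bindings', None, 'Gin parameter bindings.')
flags.DEFINE_string('replay_buffer_server_address', None,
                    'Address of the replay buffer server.')
flags.DEFINE_string('variable_container_server_address', None,
                    'Address of the variable container server.')
flags.DEFINE_integer('replay_buffer_capacity', 1000000,
                     'Maximum number of items in the replay buffer table.')
flags.DEFINE_integer('port', None, 'Port to serve the Reverb tables on.')
flags.DEFINE_float('samples_per_insert', None,
                   'Targeted ratio of samples to inserts; None disables the '
                   'SampleToInsertRatio limiter.')
flags.DEFINE_integer('min_table_size_before_sampling', 1000,
                     'Minimum table size before sampling may start.')
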
def main(_):
  logging.set_verbosity(logging.INFO)
  tf.enable_v2_behavior()
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)

  # Wait for the collect policy to become available, then load it.
  collect_policy_dir = os.path.join(FLAGS.root_dir,
                                    learner.POLICY_SAVED_MODEL_DIR,
                                    learner.COLLECT_POLICY_SAVED_MODEL_DIR)
  collect_policy = train_utils.wait_for_policy(
      collect_policy_dir, load_specs_from_pbtxt=True)

  # Prepare the summary directory.
  summary_dir = os.path.join(FLAGS.root_dir, learner.TRAIN_DIR,
                             str(FLAGS.task))

  # Perform collection.
  collect(
      summary_dir=summary_dir,
      environment_name=gin.REQUIRED,
      collect_policy=collect_policy,
      replay_buffer_server_address=FLAGS.replay_buffer_server_address,
      variable_container_server_address=FLAGS.variable_container_server_address)

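
# Like `evaluate` above, `collect` is not shown. A minimal sketch under the
# same assumptions (`collect_sketch` is an illustrative name; `suite_gym` and
# `actor` are imported in the eval sketch above): stream trajectories into
# the replay buffer table via a Reverb observer, and keep the collect
# policy's weights fresh from the variable container between steps.
from tf_agents.replay_buffers import reverb_utils


def collect_sketch(summary_dir, environment_name, collect_policy,
                   replay_buffer_server_address,
                   variable_container_server_address):
  environment = suite_gym.load(environment_name)
  train_step = train_utils.create_train_step()
  variables = {
      reverb_variable_container.POLICY_KEY: collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step,
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])

  # Each finished trajectory is written to the replay buffer table served by
  # the Reverb server.
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb.Client(replay_buffer_server_address),
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      sequence_length=2)

  collect_actor = actor.Actor(
      environment,
      collect_policy,
      train_step,
      steps_per_run=1,
      observers=[rb_observer],
      summary_dir=summary_dir)

  # A real implementation would stop at a maximum train step.
  while True:
    variable_container.update(variables)  # Pull the newest policy weights.
    collect_actor.run()
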
def main(_):
  logging.set_verbosity(logging.INFO)

  # Wait for the collect policy to become available, then load it.
  collect_policy_dir = os.path.join(FLAGS.root_dir,
                                    learner.POLICY_SAVED_MODEL_DIR,
                                    learner.COLLECT_POLICY_SAVED_MODEL_DIR)
  collect_policy = train_utils.wait_for_policy(
      collect_policy_dir, load_specs_from_pbtxt=True)

  samples_per_insert = FLAGS.samples_per_insert
  min_table_size_before_sampling = FLAGS.min_table_size_before_sampling

  # Create the signature for the variable container holding the policy
  # weights.
  train_step = train_utils.create_train_step()
  variables = {
      reverb_variable_container.POLICY_KEY: collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container_signature = tf.nest.map_structure(
      lambda variable: tf.TensorSpec(variable.shape, dtype=variable.dtype),
      variables)
  logging.info('Signature of variables: \n%s', variable_container_signature)

  # Create the signature for the replay buffer holding observed experience.
  replay_buffer_signature = tensor_spec.from_spec(
      collect_policy.collect_data_spec)
  replay_buffer_signature = tf.nest.map_structure(
      lambda s: tf.TensorSpec((None,) + s.shape, s.dtype, s.name),
      replay_buffer_signature)
  logging.info('Signature of experience: \n%s', replay_buffer_signature)

  if samples_per_insert is not None:
    # Use a SampleToInsertRatio limiter, which couples how fast the learner
    # may sample to how fast collect jobs insert.
    # `_SAMPLES_PER_INSERT_TOLERANCE_RATIO` is a module-level constant
    # defined elsewhere in this file.
    samples_per_insert_tolerance = (
        _SAMPLES_PER_INSERT_TOLERANCE_RATIO * samples_per_insert)
    error_buffer = (
        min_table_size_before_sampling * samples_per_insert_tolerance)
    experience_rate_limiter = reverb.rate_limiters.SampleToInsertRatio(
        min_size_to_sample=min_table_size_before_sampling,
        samples_per_insert=samples_per_insert,
        error_buffer=error_buffer)
  else:
    # Use a MinSize limiter: sampling may start once the table reaches the
    # minimum size, with no further coupling between sampling and inserting.
    experience_rate_limiter = reverb.rate_limiters.MinSize(
        min_table_size_before_sampling)

  # Create and start the replay buffer and variable container server.
  server = reverb.Server(
      tables=[
          reverb.Table(  # Replay buffer storing experience.
              name=reverb_replay_buffer.DEFAULT_TABLE,
              sampler=reverb.selectors.Uniform(),
              remover=reverb.selectors.Fifo(),
              rate_limiter=experience_rate_limiter,
              max_size=FLAGS.replay_buffer_capacity,
              max_times_sampled=0,
              signature=replay_buffer_signature,
          ),
          reverb.Table(  # Variable container storing policy parameters.
              name=reverb_variable_container.DEFAULT_TABLE,
              sampler=reverb.selectors.Uniform(),
              remover=reverb.selectors.Fifo(),
              rate_limiter=reverb.rate_limiters.MinSize(1),
              max_size=1,
              max_times_sampled=0,
              signature=variable_container_signature,
          ),
      ],
      port=FLAGS.port)
  server.wait()
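
# Clients in the other jobs reach both tables through this server's address.
# A sketch of the two sides (the address literal is illustrative, and since
# `server.wait()` blocks, these calls live in separate processes): the
# learner pushes fresh weights into the variable container after each
# training iteration, while collect/eval jobs pull them into a `variables`
# dict like the one built above.
client_container = reverb_variable_container.ReverbVariableContainer(
    'localhost:{}'.format(FLAGS.port),
    table_names=[reverb_variable_container.DEFAULT_TABLE])

client_container.push(variables)    # Learner side: publish new weights.
client_container.update(variables)  # Collect/eval side: fetch latest weights.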