Example #1
def main(unused_argv: Sequence[Text]) -> None:
    logging.set_verbosity(logging.INFO)
    tf.enable_v2_behavior()

    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)

    # Wait for the greedy policy to become available, then load it.
    greedy_policy_dir = os.path.join(FLAGS.root_dir,
                                     learner.POLICY_SAVED_MODEL_DIR,
                                     learner.GREEDY_POLICY_SAVED_MODEL_DIR)
    policy = train_utils.wait_for_policy(greedy_policy_dir,
                                         load_specs_from_pbtxt=True)

    # Create the variable container. The weights of the greedy policy are
    # updated from it periodically.
    variable_container = reverb_variable_container.ReverbVariableContainer(
        FLAGS.variable_container_server_address,
        table_names=[reverb_variable_container.DEFAULT_TABLE])

    # Run the evaluation.
    evaluate(summary_dir=os.path.join(FLAGS.root_dir, learner.TRAIN_DIR,
                                      'eval'),
             environment_name=gin.REQUIRED,
             policy=policy,
             variable_container=variable_container)
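The `evaluate` helper called above is not shown in these examples. A minimal sketch of what such a helper might look like (hypothetical, not the actual implementation): it mirrors the variable structure held in the container and pulls fresh greedy-policy weights before each evaluation round, much like the collect loops further below.

import gin

from absl import logging
from tf_agents.environments import suite_mujoco
from tf_agents.experimental.distributed import reverb_variable_container
from tf_agents.metrics import py_metrics
from tf_agents.train import actor
from tf_agents.train.utils import train_utils


@gin.configurable
def evaluate(summary_dir,
             environment_name,
             policy,
             variable_container,
             return_reporting_fn=None,
             eval_episodes_per_run=10,
             max_train_steps=2000000):
  """Periodically pulls greedy-policy weights and evaluates them."""
  eval_env = suite_mujoco.load(environment_name)

  # Mirror the variable structure stored in the variable container so that
  # `update` can refresh the policy weights and the train step in place.
  train_step = train_utils.create_train_step()
  variables = {
      reverb_variable_container.POLICY_KEY: policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }

  eval_metrics = actor.eval_metrics(eval_episodes_per_run)
  eval_actor = actor.Actor(
      eval_env,
      policy,
      train_step,
      episodes_per_run=eval_episodes_per_run,
      metrics=eval_metrics,
      summary_dir=summary_dir)

  while train_step.numpy() < max_train_steps:
    # Pull the latest weights pushed by the learner, then run the evaluation.
    variable_container.update(variables)
    logging.info('Evaluating with policy at step: %d', train_step.numpy())
    eval_actor.run()
    if return_reporting_fn:
      for metric in eval_metrics:
        if isinstance(metric, py_metrics.AverageReturnMetric):
          return_reporting_fn(train_step.numpy(), metric.result())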
Example #2
def run_eval(
    root_dir: Text,
    # TODO(b/178225158): Deprecate in favor of the reporting library when ready.
    return_reporting_fn: Optional[Callable[[int, float], None]] = None
) -> None:
    """Load the policy and evaluate it.

  Args:
    root_dir: the root directory for this experiment.
    return_reporting_fn: Optional callback function of the form `fn(train_step,
      average_return)` which reports the average return to a custom destination.
  """
    # Wait for the greedy policy to become available, then load it.
    greedy_policy_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR,
                                     learner.GREEDY_POLICY_SAVED_MODEL_DIR)
    policy = train_utils.wait_for_policy(greedy_policy_dir,
                                         load_specs_from_pbtxt=True)

    # Create the variable container. The weights of the greedy policy are
    # updated from it periodically.
    variable_container = reverb_variable_container.ReverbVariableContainer(
        FLAGS.variable_container_server_address,
        table_names=[reverb_variable_container.DEFAULT_TABLE])

    # Prepare summary directory.
    summary_dir = os.path.join(FLAGS.root_dir, learner.TRAIN_DIR, 'eval',
                               str(FLAGS.task))

    # Run the evaluation.
    evaluate(summary_dir=summary_dir,
             environment_name=gin.REQUIRED,
             policy=policy,
             variable_container=variable_container,
             return_reporting_fn=return_reporting_fn)
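A small usage sketch of the `return_reporting_fn` hook (the reporting destination here is just logging; the flag and gin setup from the surrounding examples is assumed to already be in place):

def _log_average_return(train_step, average_return):
  # Report to any custom destination; plain logging keeps the sketch minimal.
  logging.info('Average return at step %d: %f', train_step, average_return)


run_eval(root_dir=FLAGS.root_dir, return_reporting_fn=_log_average_return)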
Example #3
  def test_update_raises_value_error_if_variable_struct_not_match(self) -> None:
    # Prepare some data in the Reverb server.
    self._push_nested_data()

    variable_container = reverb_variable_container.ReverbVariableContainer(
        self._server_address)
    with self.assertRaises(ValueError):
      variable_container.update(tf.Variable(1))
Example #4
def collect(summary_dir: Text,
            environment_name: Text,
            collect_policy: py_tf_eager_policy.PyTFEagerPolicyBase,
            replay_buffer_server_address: Text,
            variable_container_server_address: Text,
            suite_load_fn: Callable[
                [Text], py_environment.PyEnvironment] = suite_mujoco.load,
            initial_collect_steps: int = 10000,
            max_train_steps: int = 2000000) -> None:
  """Collects experience using a policy updated after every episode."""
  # Create the environment. For now, only single-environment collection is supported.
  collect_env = suite_load_fn(environment_name)

  # Create the variable container.
  train_step = train_utils.create_train_step()
  variables = {
      reverb_variable_container.POLICY_KEY: collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.update(variables)

  # Create the replay buffer observer.
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb.Client(replay_buffer_server_address),
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      sequence_length=2,
      stride_length=1)

  random_policy = random_py_policy.RandomPyPolicy(collect_env.time_step_spec(),
                                                  collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  env_step_metric = py_metrics.EnvironmentSteps()
  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=summary_dir,
      observers=[rb_observer, env_step_metric])

  # Run the experience collection loop.
  while train_step.numpy() < max_train_steps:
    logging.info('Collecting with policy at step: %d', train_step.numpy())
    collect_actor.run()
    variable_container.update(variables)
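The collect examples assume a replay-buffer Reverb server is already running at `replay_buffer_server_address`. A minimal sketch of starting one (hypothetical helper; the capacity and rate limiter values are illustrative), using the standard Reverb API and a table signature derived from the agent's `collect_data_spec`:

import reverb

from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.specs import tensor_spec


def start_replay_buffer_server(collect_data_spec, capacity=100000):
  """Starts a Reverb server with a single uniform-sampling trajectory table."""
  # Trajectories are written as [sequence_length, ...] tensors, so add an
  # (unknown-sized) outer dimension to the data spec for the table signature.
  signature = tensor_spec.add_outer_dim(tensor_spec.from_spec(collect_data_spec))
  table = reverb.Table(
      reverb_replay_buffer.DEFAULT_TABLE,
      max_size=capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1),
      signature=signature)
  server = reverb.Server([table], port=None)  # Let Reverb pick a free port.
  return server, 'localhost:{}'.format(server.port)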
Example #5
  def test_push(self) -> None:
    # Prepare nested variables to push into the server.
    variables = _create_nested_variable()

    # Push the input to the server.
    variable_container = reverb_variable_container.ReverbVariableContainer(
        self._server_address)
    variable_container.push(variables)  # pytype: disable=wrong-arg-types

    # Check the content of the server.
    self._assert_nested_variable_in_server()
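The variable-container tests reference a module-level `_create_nested_variable` helper that is not shown. Its structure is implied by the other tests in this file; a plausible sketch (the concrete values are illustrative):

def _create_nested_variable():
  # Nested structure matching the table signature used in these tests: a scalar
  # int64, a dict with a one-element tuple holding a float64 vector, and an
  # int32 matrix.
  return (tf.Variable(1, dtype=tf.int64, shape=()), {
      'var1': (tf.Variable([1.0, 1.0], dtype=tf.float64, shape=(2,)),),
      'var2': tf.Variable([[1], [1]], dtype=tf.int32, shape=(2, 1))
  })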
Example #6
  def test_push_raises_error_if_variable_type_is_wrong(self) -> None:
    variable_container = reverb_variable_container.ReverbVariableContainer(
        self._server_address)
    # The first element has a type `tf.int64` in the signature, but here we
    # declare `tf.int32`.
    variables_with_wrong_type = (tf.Variable(-1, dtype=tf.int32, shape=()), {
        'var1': (tf.Variable([0, 0], dtype=tf.float64, shape=(2,)),),
        'var2': tf.Variable([[0], [0]], dtype=tf.int32, shape=(2, 1))
    })
    with self.assertRaises(tf.errors.InvalidArgumentError):
      variable_container.push(variables_with_wrong_type)
Example #7
  def test_push_under_distribute_strategy(
      self, strategy: tf.distribute.Strategy) -> None:
    # Prepare nested variables under strategy scope to push into the server.
    with strategy.scope():
      variables = _create_nested_variable()
    logging.info('Variables: %s', variables)

    # Push the input to the server.
    variable_container = reverb_variable_container.ReverbVariableContainer(
        self._server_address)
    variable_container.push(variables)  # pytype: disable=wrong-arg-types

    # Check the content of the server.
    self._assert_nested_variable_in_server()
Example #8
  def test_update_raises_value_error_if_variable_type_is_wrong(self) -> None:
    # Prepare some data in the Reverb server.
    self._push_nested_data()

    variable_container = reverb_variable_container.ReverbVariableContainer(
        self._server_address)
    # The first element has a type `tf.int64` in the signature, but here we
    # declare `tf.int32`.
    variables_with_wrong_type = (tf.Variable(-1, dtype=tf.int32, shape=()), {
        'var1': (tf.Variable([0, 0], dtype=tf.float64, shape=(2,)),),
        'var2': tf.Variable([[0], [0]], dtype=tf.int32, shape=(2, 1))
    })
    with self.assertRaises(ValueError):
      variable_container.update(variables_with_wrong_type)
Example #9
  def test_update(self) -> None:
    # Prepare some data in the Reverb server.
    self._push_nested_data()

    # Get the values from the server.
    variables = (tf.Variable(-1, dtype=tf.int64, shape=()), {
        'var1': (tf.Variable([0, 0], dtype=tf.float64, shape=(2,)),),
        'var2': tf.Variable([[0], [0]], dtype=tf.int32, shape=(2, 1))
    })

    # Update variables based on value pulled from the server.
    variable_container = reverb_variable_container.ReverbVariableContainer(
        self._server_address)
    variable_container.update(variables)

    # Check the values of the `variables`.
    self._assert_nested_variable_updated(variables)
Example #10
  def test_push_with_not_exact_sequence_type_matching(self) -> None:
    # The second element (i.e. the value of `var1`) was in a tuple in the
    # original signature; here we place it into a list.
    variables = (tf.Variable(0, dtype=tf.int64, shape=()), {
        'var1': [tf.Variable([1, 1], dtype=tf.float64, shape=(2,))],
        'var2': tf.Variable([[2], [3]], dtype=tf.int32, shape=(2, 1))
    })

    # Sequence type check is turned off by default, allowing sequence type
    # differences in the signature. This is required to be able to work with
    # policies loaded from file, which often change tuples to e.g. `ListWrapper`.
    variable_container = reverb_variable_container.ReverbVariableContainer(
        self._server_address)
    variable_container.push(variables)

    # Check the content of the server.
    self._assert_nested_variable_in_server()
Example #11
def run_eval(root_dir: Text) -> None:
    """Load the policy and evaluate it."""
    # Wait for the greedy policy to become available, then load it.
    greedy_policy_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR,
                                     learner.GREEDY_POLICY_SAVED_MODEL_DIR)
    policy = train_utils.wait_for_policy(greedy_policy_dir,
                                         load_specs_from_pbtxt=True)

    # Create the variable container. The weights of the greedy policy are
    # updated from it periodically.
    variable_container = reverb_variable_container.ReverbVariableContainer(
        FLAGS.variable_container_server_address,
        table_names=[reverb_variable_container.DEFAULT_TABLE])

    # Run the evaluation.
    evaluate(summary_dir=os.path.join(root_dir, learner.TRAIN_DIR, 'eval'),
             environment_name=gin.REQUIRED,
             policy=policy,
             variable_container=variable_container)
Example #12
  def build_and_run_actor():
    root_dir = test_case.create_tempdir().full_path
    env, action_tensor_spec, time_step_tensor_spec = (
        get_cartpole_env_and_specs())

    train_step = train_utils.create_train_step()

    q_net = build_dummy_sequential_net(fc_layer_params=(100,),
                                       action_spec=action_tensor_spec)

    agent = dqn_agent.DqnAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        q_network=q_net,
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        train_step_counter=train_step)

    _, rb_observer = (
        replay_buffer_utils.get_reverb_buffer_and_observer(
            agent.collect_data_spec,
            table_name=reverb_replay_buffer.DEFAULT_TABLE,
            sequence_length=2,
            reverb_server_address='localhost:{}'.format(reverb_server_port)))

    variable_container = reverb_variable_container.ReverbVariableContainer(
        server_address='localhost:{}'.format(reverb_server_port),
        table_names=[reverb_variable_container.DEFAULT_TABLE])

    test_actor = build_actor(
        root_dir, env, agent, rb_observer, train_step)

    variables_dict = {
        reverb_variable_container.POLICY_KEY: agent.collect_policy.variables(),
        reverb_variable_container.TRAIN_STEP_KEY: train_step
    }
    variable_container.update(variables_dict)

    for _ in range(num_iterations):
      test_actor.run()
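Example #12 relies on test helpers (`get_cartpole_env_and_specs`, `build_dummy_sequential_net`, `build_actor`) defined elsewhere in the test module. A rough sketch of what `build_actor` might look like (hypothetical; the real helper may differ), wrapping the collect policy for eager execution in Python:

import os

from tf_agents.policies import py_tf_eager_policy
from tf_agents.train import actor
from tf_agents.train import learner


def build_actor(root_dir, env, agent, rb_observer, train_step):
  """Builds a collect actor that drives `env` with the agent's collect policy."""
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      agent.collect_policy, use_tf_function=True)
  return actor.Actor(
      env,
      collect_policy,
      train_step,
      steps_per_run=1,
      observers=[rb_observer],
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR))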
Example #13
  def test_update_with_not_exact_sequence_type_matching(self) -> None:
    # Prepare some data in the Reverb server.
    self._push_nested_data()

    # The second element (i.e. the value of `var1`) was in a tuple in the
    # original signature; here we place it into a list.
    variables = (tf.Variable(-1, dtype=tf.int64, shape=()), {
        'var1': [tf.Variable([0, 0], dtype=tf.float64, shape=(2,))],
        'var2': tf.Variable([[0], [0]], dtype=tf.int32, shape=(2, 1))
    })

    # Sequence type check is turned off by default, allowing sequence type
    # differences in the signature. This is required to be able to work with
    # policies loaded from file, which often change tuples to e.g. `ListWrapper`.
    variable_container = reverb_variable_container.ReverbVariableContainer(
        self._server_address)
    variable_container.update(variables)  # pytype: disable=wrong-arg-types

    # Check the values of the `variables`.
    self._assert_nested_variable_updated(variables,
                                         check_nest_seq_types=False)  # pytype: disable=wrong-arg-types
Example #14
  def test_init_raises_key_error_if_undefined_table_passed(self):
    server, server_address = _create_server(table='no_variables_table')
    with self.assertRaises(KeyError):
      reverb_variable_container.ReverbVariableContainer(server_address)
    server.stop()
Example #15
  def test_init_raises_type_error_if_no_signature_of_a_table(self):
    server, server_address = _create_server(signature=None)  # pytype: disable=wrong-arg-types
    with self.assertRaises(TypeError):
      reverb_variable_container.ReverbVariableContainer(server_address)
    server.stop()
Example #16
  def test_pull_raises_key_error_on_unknown_table(self) -> None:
    variable_container = reverb_variable_container.ReverbVariableContainer(
        self._server_address)
    with self.assertRaises(KeyError):
      variable_container.pull('unknown_table')
Example #17
  def test_push_raises_error_if_variable_struct_not_match(self) -> None:
    variable_container = reverb_variable_container.ReverbVariableContainer(
        self._server_address)
    with self.assertRaises(tf.errors.InvalidArgumentError):
      variable_container.push(tf.Variable(1))
Example #18
  def test_init_raises_value_error_if_max_size_is_different_than_one(self):
    server, server_address = _create_server(max_size=2)
    with self.assertRaises(ValueError):
      reverb_variable_container.ReverbVariableContainer(server_address)
    server.stop()
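The init tests above imply the requirements `ReverbVariableContainer` places on its server: the table must exist, must carry a signature, and must have `max_size=1`. They also assume a `_create_server` helper; a minimal sketch of such a helper (hypothetical; the default signature below mirrors the nested structure used in the other tests):

import reverb
import tensorflow as tf

# Import path as used by the TF-Agents distributed examples; it may vary by
# version.
from tf_agents.experimental.distributed import reverb_variable_container

_NESTED_SIGNATURE = (tf.TensorSpec((), tf.int64), {
    'var1': (tf.TensorSpec((2,), tf.float64),),
    'var2': tf.TensorSpec((2, 1), tf.int32)
})


def _create_server(table=reverb_variable_container.DEFAULT_TABLE,
                   signature=_NESTED_SIGNATURE,
                   max_size=1):
  """Starts a local Reverb server holding a single variable table."""
  server = reverb.Server(
      tables=[
          reverb.Table(
              name=table,
              sampler=reverb.selectors.Uniform(),
              remover=reverb.selectors.Fifo(),
              rate_limiter=reverb.rate_limiters.MinSize(1),
              max_size=max_size,
              signature=signature)
      ],
      port=None)  # Let Reverb pick a free port.
  return server, 'localhost:{}'.format(server.port)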
Example #19
def collect(task,
            root_dir,
            replay_buffer_server_address,
            variable_container_server_address,
            create_env_fn,
            initial_collect_steps=10000,
            num_iterations=10000000):
  """Collects experience using a policy updated after every episode."""
  # Create the environment. For now, only single-environment collection is supported.
  collect_env = create_env_fn()

  # Create the path for the serialized collect policy.
  collect_policy_saved_model_path = os.path.join(
      root_dir, learner.POLICY_SAVED_MODEL_DIR,
      learner.COLLECT_POLICY_SAVED_MODEL_DIR)
  saved_model_pb_path = os.path.join(collect_policy_saved_model_path,
                                     'saved_model.pb')
  try:
    # Wait for the collect policy to be output by the learner (timeout after 2
    # days), then load it.
    train_utils.wait_for_file(
        saved_model_pb_path, sleep_time_secs=2, num_retries=86400)
    collect_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
        collect_policy_saved_model_path, load_specs_from_pbtxt=True)
  except TimeoutError as e:
    # If the collect policy does not become available during the wait time of
    # the `wait_for_file` call, that probably means the learner is not running.
    logging.error('Could not get the file %s. Exiting.', saved_model_pb_path)
    raise e

  # Create the variable container.
  train_step = train_utils.create_train_step()
  variables = {
      reverb_variable_container.POLICY_KEY: collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.update(variables)

  # Create the replay buffer observer.
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb.Client(replay_buffer_server_address),
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      sequence_length=2,
      stride_length=1)

  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  env_step_metric = py_metrics.EnvironmentSteps()
  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR, str(task)),
      observers=[rb_observer, env_step_metric])

  # Run the experience collection loop.
  for _ in range(num_iterations):
    logging.info('Collecting with policy at step: %d', train_step.numpy())
    collect_actor.run()
    variable_container.update(variables)
Example #20
def train(
    root_dir: Text,
    environment_name: Text,
    strategy: tf.distribute.Strategy,
    replay_buffer_server_address: Text,
    variable_container_server_address: Text,
    suite_load_fn: Callable[[Text],
                            py_environment.PyEnvironment] = suite_mujoco.load,
    # Training params
    learning_rate: float = 3e-4,
    batch_size: int = 256,
    num_iterations: int = 2000000,
    learner_iterations_per_call: int = 1) -> None:
  """Trains a DQN agent."""
  # Get the specs from the environment.
  logging.info('Training SAC with learning rate: %f', learning_rate)
  env = suite_load_fn(environment_name)
  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(env))

  # Create the agent.
  with strategy.scope():
    train_step = train_utils.create_train_step()
    agent = _create_agent(
        train_step=train_step,
        observation_tensor_spec=observation_tensor_spec,
        action_tensor_spec=action_tensor_spec,
        time_step_tensor_spec=time_step_tensor_spec,
        learning_rate=learning_rate)

  # Create the policy saver, which saves the initial model immediately and then
  # periodically checkpoints the policy weights.
  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  save_model_trigger = triggers.PolicySavedModelTrigger(
      saved_model_dir, agent, train_step, interval=1000)

  # Create the variable container.
  variables = {
      reverb_variable_container.POLICY_KEY: agent.collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.push(variables)

  # Create the replay buffer.
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      server_address=replay_buffer_server_address)

  # Initialize the dataset.
  def experience_dataset_fn():
    with strategy.scope():
      return reverb_replay.as_dataset(
          sample_batch_size=batch_size, num_steps=2).prefetch(3)

  # Create the learner.
  learning_triggers = [
      save_model_trigger,
      triggers.StepPerSecondLogTrigger(train_step, interval=1000)
  ]
  sac_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers,
      strategy=strategy)

  # Run the training loop.
  while train_step.numpy() < num_iterations:
    sac_learner.run(iterations=learner_iterations_per_call)
    variable_container.push(variables)
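Example #20 calls a `_create_agent` factory that is not shown. A hedged sketch of a SAC agent factory matching that call signature, following the standard TF-Agents SAC setup (the network sizes, initializers, and single shared learning rate are illustrative):

import tensorflow as tf

from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.sac import sac_agent
from tf_agents.networks import actor_distribution_network
from tf_agents.networks import tanh_normal_projection_network


def _create_agent(train_step, observation_tensor_spec, action_tensor_spec,
                  time_step_tensor_spec, learning_rate):
  """Builds and initializes a SAC agent."""
  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_tensor_spec,
      action_tensor_spec,
      fc_layer_params=(256, 256),
      continuous_projection_net=(
          tanh_normal_projection_network.TanhNormalProjectionNetwork))
  critic_net = critic_network.CriticNetwork(
      (observation_tensor_spec, action_tensor_spec),
      joint_fc_layer_params=(256, 256),
      kernel_initializer='glorot_uniform',
      last_kernel_initializer='glorot_uniform')
  agent = sac_agent.SacAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      actor_network=actor_net,
      critic_network=critic_net,
      actor_optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
      critic_optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
      alpha_optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
      train_step_counter=train_step)
  agent.initialize()
  return agent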
Example #21
def train(
    root_dir,
    strategy,
    replay_buffer_server_address,
    variable_container_server_address,
    create_agent_fn,
    create_env_fn,
    # Training params
    learning_rate=3e-4,
    batch_size=256,
    num_iterations=32000,
    learner_iterations_per_call=100):
  """Trains a DQN agent."""
  # Get the specs from the environment.
  logging.info('Training SAC with learning rate: %f', learning_rate)
  env = create_env_fn()
  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(env))

  # Create the agent.
  with strategy.scope():
    train_step = train_utils.create_train_step()
    agent = create_agent_fn(train_step, observation_tensor_spec,
                            action_tensor_spec, time_step_tensor_spec,
                            learning_rate)
    agent.initialize()

  # Create the policy saver, which saves the initial model immediately and then
  # periodically checkpoints the policy weights.
  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  save_model_trigger = triggers.PolicySavedModelTrigger(
      saved_model_dir, agent, train_step, interval=1000)

  # Create the variable container.
  variables = {
      reverb_variable_container.POLICY_KEY: agent.collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.push(variables)

  # Create the replay buffer.
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      server_address=replay_buffer_server_address)

  # Initialize the dataset.
  def experience_dataset_fn():
    with strategy.scope():
      return reverb_replay.as_dataset(
          sample_batch_size=batch_size, num_steps=2).prefetch(3)

  # Create the learner.
  learning_triggers = [
      save_model_trigger,
      triggers.StepPerSecondLogTrigger(train_step, interval=1000)
  ]
  sac_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers,
      strategy=strategy)

  # Run the training loop.
  # TODO(b/162440911): Change the loop to use train_step to handle preemptions.
  for _ in range(num_iterations):
    sac_learner.run(iterations=learner_iterations_per_call)
    variable_container.push(variables)
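The collect, train, and eval functions above are meant to run as separate jobs that share the same Reverb server addresses. A hypothetical launch snippet for the training job (the flag names, environment, and strategy call are assumptions, not part of the original examples):

from absl import app
from absl import flags

from tf_agents.environments import suite_mujoco
from tf_agents.train.utils import strategy_utils

flags.DEFINE_string('root_dir', None, 'Root directory for the experiment.')
flags.DEFINE_string('replay_buffer_server_address', None,
                    'Replay buffer server address.')
flags.DEFINE_string('variable_container_server_address', None,
                    'Variable container server address.')
FLAGS = flags.FLAGS


def main(_):
  # MirroredStrategy on local GPUs; pass use_gpu=False for the default strategy.
  strategy = strategy_utils.get_strategy(tpu=False, use_gpu=True)
  train(
      root_dir=FLAGS.root_dir,
      strategy=strategy,
      replay_buffer_server_address=FLAGS.replay_buffer_server_address,
      variable_container_server_address=FLAGS.variable_container_server_address,
      create_agent_fn=_create_agent,  # e.g. a SAC factory as sketched above
      create_env_fn=lambda: suite_mujoco.load('HalfCheetah-v2'))


if __name__ == '__main__':
  app.run(main)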