Example #1
    def _setup_policy(self, time_step_spec, action_spec, boltzmann_temperature,
                      emit_log_probability):
        policy = categorical_q_policy.CategoricalQPolicy(
            time_step_spec,
            action_spec,
            self._q_network,
            self._min_q_value,
            self._max_q_value,
            observation_and_action_constraint_splitter=(
                self._observation_and_action_constraint_splitter))

        if boltzmann_temperature is not None:
            collect_policy = boltzmann_policy.BoltzmannPolicy(
                policy, temperature=boltzmann_temperature)
        else:
            collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
                policy, epsilon=self._epsilon_greedy)
        policy = greedy_policy.GreedyPolicy(policy)

        target_policy = categorical_q_policy.CategoricalQPolicy(
            time_step_spec,
            action_spec,
            self._target_q_network,
            self._min_q_value,
            self._max_q_value,
            observation_and_action_constraint_splitter=(
                self._observation_and_action_constraint_splitter))
        self._target_greedy_policy = greedy_policy.GreedyPolicy(target_policy)

        return policy, collect_policy
Example #2
File: dqn_agent.py  Project: wuzh07/agents
  def _setup_policy(self, time_step_spec, action_spec,
                    boltzmann_temperature, emit_log_probability):

    policy = q_policy.QPolicy(
        time_step_spec,
        action_spec,
        q_network=self._q_network,
        emit_log_probability=emit_log_probability,
        observation_and_action_constraint_splitter=(
            self._observation_and_action_constraint_splitter))

    if boltzmann_temperature is not None:
      collect_policy = boltzmann_policy.BoltzmannPolicy(
          policy, temperature=self._boltzmann_temperature)
    else:
      collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
          policy, epsilon=self._epsilon_greedy)
    policy = greedy_policy.GreedyPolicy(policy)

    # Create self._target_greedy_policy in order to compute target Q-values.
    target_policy = q_policy.QPolicy(
        time_step_spec,
        action_spec,
        q_network=self._target_q_network,
        observation_and_action_constraint_splitter=(
            self._observation_and_action_constraint_splitter))
    self._target_greedy_policy = greedy_policy.GreedyPolicy(target_policy)

    return policy, collect_policy
Example #3
 def get_eval_policy(self):
     """Returns the greedy policy of the agent
 
 Returns:
     GreedyPolicy -- Always returns best suitable action
 """
     return greedy_policy.GreedyPolicy(self._agent.policy)
Example #4
 def _get_policies(self, time_step_spec, action_spec, cloning_network):
   policy = q_policy.QPolicy(
       time_step_spec, action_spec, q_network=self._cloning_network)
   collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
       policy, epsilon=self._epsilon_greedy)
   policy = greedy_policy.GreedyPolicy(policy)
   return policy, collect_policy
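The examples above share a single pattern: one base Q-policy is wrapped twice, with EpsilonGreedyPolicy (or BoltzmannPolicy) for data collection and GreedyPolicy for evaluation. Below is a minimal standalone sketch of that pattern; it is not taken from any project above, and the specs and network sizes are purely illustrative assumptions.

import tensorflow as tf
from tf_agents.networks import q_network
from tf_agents.policies import epsilon_greedy_policy, greedy_policy, q_policy
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

# Illustrative specs and a small Q-network (assumptions, not from the examples).
observation_spec = tensor_spec.TensorSpec([4], tf.float32)
action_spec = tensor_spec.BoundedTensorSpec((), tf.int32, minimum=0, maximum=1)
time_step_spec = ts.time_step_spec(observation_spec)
net = q_network.QNetwork(observation_spec, action_spec, fc_layer_params=(32,))

# One base policy, wrapped twice: exploration for collection, greedy for eval.
base_policy = q_policy.QPolicy(time_step_spec, action_spec, q_network=net)
collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(base_policy, epsilon=0.1)
eval_policy = greedy_policy.GreedyPolicy(base_policy)

# Greedy action for a batch with a single observation.
time_step = ts.restart(tf.constant([[0.1, 0.2, 0.3, 0.4]]), batch_size=1)
print(eval_policy.action(time_step).action)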
Example #5
    def __init__(self, policy, epsilon, name=None):
        """Builds an epsilon-greedy MixturePolicy wrapping the given policy.

    Args:
      policy: A policy implementing the tf_policy.Base interface.
      epsilon: The probability of taking the random action represented as a
        float scalar, a scalar Tensor of shape=(), or a callable that returns a
        float scalar or Tensor.
      name: The name of this policy.

    Raises:
      ValueError: If epsilon is invalid.
    """
        self._greedy_policy = greedy_policy.GreedyPolicy(policy)
        self._epsilon = epsilon
        self._random_policy = random_tf_policy.RandomTFPolicy(
            policy.time_step_spec,
            policy.action_spec,
            emit_log_probability=policy.emit_log_probability)
        super(EpsilonGreedyPolicy,
              self).__init__(policy.time_step_spec,
                             policy.action_spec,
                             policy.policy_state_spec,
                             policy.info_spec,
                             emit_log_probability=policy.emit_log_probability,
                             name=name)
Example #6
    def policy(self):
        """Return the current policy held by the agent.

    Returns:
      A subclass of tf_policy.Base.
    """
        return greedy_policy.GreedyPolicy(self._make_policy(collect=False))
Example #7
    def __init__(self,
                 policy,
                 temperature=10.0,
                 epsilon=0.1,
                 remove_neg_inf=False,
                 name=None):
        """Builds a BoltzmannPolicy wrapping the given policy.

        Args:
          policy: A policy implementing the tf_policy.Base interface, using
            a distribution parameterized by logits.
          temperature: Tensor or function that returns the temperature for sampling
            when `action` is called. This parameter applies when the action spec is
            discrete. If the temperature is close to 0.0 this is equivalent to
            calling `tf.argmax` on the output of the network.
          name: The name of this policy. All variables in this module will fall
            under that name. Defaults to the class name.
        """

        self._greedy_policy = greedy_policy.GreedyPolicy(
            policy, remove_neg_inf=remove_neg_inf)
        super(EpsilonBoltzmannPolicy,
              self).__init__(policy.time_step_spec,
                             policy.action_spec,
                             policy.policy_state_spec,
                             policy.info_spec,
                             emit_log_probability=policy.emit_log_probability,
                             clip=False,
                             name=name)
        self._temperature = temperature
        self._epsilon = epsilon
        self._wrapped_policy = policy
Example #8
  def testCategoricalActions(self, action_probs):
    action_spec = [
        tensor_spec.BoundedTensorSpec((1,), tf.int32, 0, len(action_probs)-1),
        tensor_spec.BoundedTensorSpec((), tf.int32, 0, len(action_probs)-1)]
    wrapped_policy = DistributionPolicy([
        tfp.distributions.Categorical(probs=[action_probs]),
        tfp.distributions.Categorical(probs=action_probs)
    ], self._time_step_spec, action_spec)
    policy = greedy_policy.GreedyPolicy(wrapped_policy)

    self.assertEqual(policy.time_step_spec(), self._time_step_spec)
    self.assertEqual(policy.action_spec(), action_spec)

    observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
    time_step = ts.restart(observations, batch_size=2)
    action_step = policy.action(time_step)
    nest.assert_same_structure(action_spec, action_step.action)

    action_ = self.evaluate(action_step.action)
    self.assertEqual(action_[0][0], np.argmax(action_probs))
    self.assertEqual(action_[1], np.argmax(action_probs))
    self.assertAllEqual(action_[0].shape, [
        1,
    ] + action_spec[0].shape.as_list())
    self.assertAllEqual(action_[1].shape, [
        1,
    ] + action_spec[1].shape.as_list())
Example #9
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 actor_network,
                 optimizer,
                 normalize_returns=True,
                 gradient_clipping=None,
                 debug_summaries=False,
                 summarize_grads_and_vars=False,
                 entropy_regularization=None,
                 train_step_counter=None,
                 name=None):
        """Creates a REINFORCE Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      actor_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call(observation, step_type).
      optimizer: Optimizer for the actor network.
      normalize_returns: Whether to normalize returns across episodes when
        computing the loss.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      entropy_regularization: Coefficient for entropy regularization loss term.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall
        under that name. Defaults to the class name.
    """
        tf.Module.__init__(self, name=name)

        self._actor_network = actor_network

        collect_policy = actor_policy.ActorPolicy(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            clip=True)

        policy = greedy_policy.GreedyPolicy(collect_policy)

        self._optimizer = optimizer
        self._normalize_returns = normalize_returns
        self._gradient_clipping = gradient_clipping
        self._entropy_regularization = entropy_regularization

        super(ReinforceAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             policy,
                             collect_policy,
                             train_sequence_length=None,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter)
Example #10
 def get_option_policies(self):
     return [
         greedy_policy.GreedyPolicy(
             actor_policy.ActorPolicy(
                 time_step_spec=self.time_step_spec,
                 action_spec=self.action_spec,
                 actor_network=option_net
             )
         )
         for option_net in self.actor_net.get_options()
     ]
Example #11
def load_policy(policy, env_name, load_dir, ckpt_file=None):
    policy = greedy_policy.GreedyPolicy(policy)
    checkpoint = tf.train.Checkpoint(policy=policy)
    if ckpt_file is None:
        checkpoint_filename = tf.train.latest_checkpoint(load_dir)
    else:
        checkpoint_filename = os.path.join(load_dir, ckpt_file)
    print('Loading policy from %s.' % checkpoint_filename)
    checkpoint.restore(checkpoint_filename).assert_existing_objects_matched()
    # Unwrap greedy wrapper.
    return policy.wrapped_policy
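The loader above wraps the raw policy in GreedyPolicy before restoring, which suggests the training job saved the checkpoint with the greedy wrapper in the object graph (compare the policy checkpointers in the train_eval examples below). A minimal companion sketch of that saving side follows; the function name and arguments are hypothetical and not part of the example.

import os
import tensorflow as tf

def save_policy(eval_policy, save_dir):
    # `eval_policy` is assumed to be a greedy_policy.GreedyPolicy; it is stored
    # under the same 'policy' key that load_policy above restores.
    checkpoint = tf.train.Checkpoint(policy=eval_policy)
    return checkpoint.save(file_prefix=os.path.join(save_dir, 'policy'))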
Example #12
    def __init__(self,
                 policy: tf_policy.TFPolicy,
                 epsilon: types.FloatOrReturningFloat,
                 exploration_mask: Optional[Sequence[int]] = None,
                 info_fields_to_inherit_from_greedy: Sequence[Text] = (),
                 name: Optional[Text] = None):
        """Builds an epsilon-greedy MixturePolicy wrapping the given policy.

    Args:
      policy: A policy implementing the tf_policy.TFPolicy interface.
      epsilon: The probability of taking the random action represented as a
        float scalar, a scalar Tensor of shape=(), or a callable that returns a
        float scalar or Tensor.
      exploration_mask: A `[0, 1]` vector describing which actions should be in
        the set of exploratory actions.
      info_fields_to_inherit_from_greedy: A list of policy info fields which
        should be copied over from the greedy action's info, even if the random
        action was taken.
      name: The name of this policy.

    Raises:
      ValueError: If epsilon is invalid.
    """
        try:
            observation_and_action_constraint_splitter = (
                policy.observation_and_action_constraint_splitter)
        except AttributeError:
            observation_and_action_constraint_splitter = None
        try:
            accepts_per_arm_features = policy.accepts_per_arm_features
        except AttributeError:
            accepts_per_arm_features = False
        self._greedy_policy = greedy_policy.GreedyPolicy(policy)
        self._epsilon = epsilon
        self._exploration_mask = exploration_mask
        self.info_fields_to_inherit_from_greedy = info_fields_to_inherit_from_greedy
        self._random_policy = random_tf_policy.RandomTFPolicy(
            policy.time_step_spec,
            policy.action_spec,
            emit_log_probability=policy.emit_log_probability,
            observation_and_action_constraint_splitter=(
                observation_and_action_constraint_splitter),
            accepts_per_arm_features=accepts_per_arm_features,
            stationary_mask=exploration_mask,
            info_spec=policy.info_spec)
        super(EpsilonGreedyPolicy,
              self).__init__(policy.time_step_spec,
                             policy.action_spec,
                             policy.policy_state_spec,
                             policy.info_spec,
                             emit_log_probability=policy.emit_log_probability,
                             observation_and_action_constraint_splitter=(
                                 observation_and_action_constraint_splitter),
                             name=name)
Example #13
 def setUp(self):
     super(PolicyLoaderTest, self).setUp()
     self.root_dir = self.get_temp_dir()
     tf_observation_spec = tensor_spec.TensorSpec((), np.float32)
     tf_time_step_spec = ts.time_step_spec(tf_observation_spec)
     tf_action_spec = tensor_spec.BoundedTensorSpec((), np.float32, 0, 3)
     self.net = AddNet()
     self.policy = greedy_policy.GreedyPolicy(
         q_policy.QPolicy(tf_time_step_spec, tf_action_spec, self.net))
     self.train_step = common.create_variable('train_step', initial_value=0)
     self.saver = policy_saver.PolicySaver(self.policy,
                                           train_step=self.train_step)
Example #14
 def _get_policies(self, time_step_spec, action_spec, cloning_network):
     policy = q_policy.QPolicy(
         time_step_spec,
         action_spec,
         q_network=self._cloning_network,
         # Unlike DQN, we support continuous action spaces - in which case
         # the policy just emits the network output.  In that case, we
         # don't care if the action_spec is a scalar integer value.
         validate_action_spec_and_network=False,
     )
     collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
         policy, epsilon=self._epsilon_greedy)
     policy = greedy_policy.GreedyPolicy(policy)
     return policy, collect_policy
Example #15
 def _setup_as_discrete(self, time_step_spec, action_spec, loss_fn,
                        epsilon_greedy):
     self._loss_fn = loss_fn or self._discrete_loss
     # Unlike DQN, we support continuous action spaces - in which case
     # the policy just emits the network output.  In that case, we
     # don't care if the action_spec is a scalar integer value.
     policy = q_policy.QPolicy(
         time_step_spec,
         action_spec,
         q_network=self._cloning_network,
         validate_action_spec_and_network=False,
     )
     collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
         policy, epsilon=epsilon_greedy)
     policy = greedy_policy.GreedyPolicy(policy)
     return policy, collect_policy
Example #16
    def _setup_policy(self, time_step_spec, action_spec, boltzmann_temperature,
                      emit_log_probability):

        policy = q_policy.QPolicy(time_step_spec,
                                  action_spec,
                                  q_network=self._q_network,
                                  emit_log_probability=emit_log_probability)

        if boltzmann_temperature is not None:
            collect_policy = boltzmann_policy.BoltzmannPolicy(
                policy, temperature=self._boltzmann_temperature)
        else:
            collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
                policy, epsilon=self._epsilon_greedy)
        policy = greedy_policy.GreedyPolicy(policy)

        return policy, collect_policy
Example #17
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 actor_network,
                 optimizer,
                 normalize_returns=True,
                 gradient_clipping=None,
                 debug_summaries=False,
                 summarize_grads_and_vars=False):
        """Creates a REINFORCE Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      actor_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call(observation, step_type).
      optimizer: Optimizer for the actor network.
      normalize_returns: Whether to normalize returns across episodes when
        computing the loss.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
    """

        self._actor_network = actor_network
        collect_policy = actor_policy.ActorPolicy(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            clip=True)
        policy = greedy_policy.GreedyPolicy(collect_policy)

        self._optimizer = optimizer
        self._normalize_returns = normalize_returns
        self._gradient_clipping = gradient_clipping

        super(ReinforceAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             policy,
                             collect_policy,
                             train_sequence_length=None,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars)
Example #18
  def testNormalActions(self, loc, scale):
    action_spec = tensor_spec.BoundedTensorSpec(
        [1], tf.float32, tf.float32.min, tf.float32.max)
    wrapped_policy = DistributionPolicy(
        tfp.distributions.Normal([loc], [scale]), self._time_step_spec,
        action_spec)
    policy = greedy_policy.GreedyPolicy(wrapped_policy)

    self.assertEqual(policy.time_step_spec, self._time_step_spec)
    self.assertEqual(policy.action_spec, action_spec)

    observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
    time_step = ts.restart(observations, batch_size=2)
    action_step = policy.action(time_step)
    tf.nest.assert_same_structure(action_spec, action_step.action)

    action_ = self.evaluate(action_step.action)
    self.assertAlmostEqual(action_[0], loc)
Example #19
    def __init__(self, policy, epsilon):
        """Builds an epsilon-greedy MixturePolicy wrapping the given policy.

    Args:
      policy: A policy implementing the tf_policy.Base interface.
      epsilon: A float scalar or a scalar Tensor of shape=(), corresponding to
        the probability of taking the random action.
    Raises:
      ValueError: If epsilon is invalid.
    """
        self._greedy_policy = greedy_policy.GreedyPolicy(policy)
        self._epsilon = epsilon
        self._random_policy = random_tf_policy.RandomTFPolicy(
            policy.time_step_spec(), policy.action_spec())
        super(EpsilonGreedyPolicy, self).__init__(policy.time_step_spec(),
                                                  policy.action_spec(),
                                                  policy.policy_state_spec(),
                                                  policy.info_spec())
Example #20
  def __init__(self,
               policy: tf_policy.TFPolicy,
               epsilon: types.FloatOrReturningFloat,
               name: Optional[Text] = None):
    """Builds an epsilon-greedy MixturePolicy wrapping the given policy.

    Args:
      policy: A policy implementing the tf_policy.TFPolicy interface.
      epsilon: The probability of taking the random action represented as a
        float scalar, a scalar Tensor of shape=(), or a callable that returns a
        float scalar or Tensor.
      name: The name of this policy.

    Raises:
      ValueError: If epsilon is invalid.
    """
    try:
      observation_and_action_constraint_splitter = (
          policy.observation_and_action_constraint_splitter)
    except AttributeError:
      observation_and_action_constraint_splitter = None
    try:
      accepts_per_arm_features = policy.accepts_per_arm_features
    except AttributeError:
      accepts_per_arm_features = False
    self._greedy_policy = greedy_policy.GreedyPolicy(policy)
    self._epsilon = epsilon
    self._random_policy = random_tf_policy.RandomTFPolicy(
        policy.time_step_spec,
        policy.action_spec,
        emit_log_probability=policy.emit_log_probability,
        observation_and_action_constraint_splitter=(
            observation_and_action_constraint_splitter),
        accepts_per_arm_features=accepts_per_arm_features,
        info_spec=policy.info_spec)
    super(EpsilonGreedyPolicy, self).__init__(
        policy.time_step_spec,
        policy.action_spec,
        policy.policy_state_spec,
        policy.info_spec,
        emit_log_probability=policy.emit_log_probability,
        observation_and_action_constraint_splitter=(
            observation_and_action_constraint_splitter),
        name=name)
Example #21
  def _setup_as_discrete(self, time_step_spec, action_spec, loss_fn,
                         epsilon_greedy):
    self._bc_loss_fn = loss_fn or self._discrete_loss

    if any(isinstance(d, distribution_utils.DistributionSpecV2) for
           d in tf.nest.flatten([self._network_output_spec])):
      # If the output of the cloning network contains a distribution.
      base_policy = actor_policy.ActorPolicy(time_step_spec, action_spec,
                                             self._cloning_network)
    else:
      # If the output of the cloning network is logits.
      base_policy = q_policy.QPolicy(
          time_step_spec,
          action_spec,
          q_network=self._cloning_network,
          validate_action_spec_and_network=False)
    policy = greedy_policy.GreedyPolicy(base_policy)
    collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
        base_policy, epsilon=epsilon_greedy)
    return policy, collect_policy
Example #22
def get_target_policy(load_dir, env_name):
    """Gets target policy."""
    env = tf_py_environment.TFPyEnvironment(suites.load_mujoco(env_name))
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        env.observation_spec(),
        env.action_spec(),
        fc_layer_params=(256, 256),
        continuous_projection_net=tanh_normal_projection_network.
        TanhNormalProjectionNetwork)
    policy = actor_policy.ActorPolicy(time_step_spec=env.time_step_spec(),
                                      action_spec=env.action_spec(),
                                      actor_network=actor_net,
                                      training=False)
    policy = greedy_policy.GreedyPolicy(policy)

    checkpoint = tf.train.Checkpoint(policy=policy)

    directory = os.path.join(load_dir, env_name, 'train/policy')
    checkpoint_filename = tf.train.latest_checkpoint(directory)
    print('Loading policy from %s' % checkpoint_filename)
    checkpoint.restore(checkpoint_filename).assert_existing_objects_matched()
    policy = policy.wrapped_policy

    return policy, env
Example #23
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        # Training params
        initial_collect_steps=10000,
        num_iterations=3200000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Agent params
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        gamma=0.99,
        target_update_tau=0.005,
        target_update_period=1,
        reward_scale_factor=0.1,
        # Replay params
        reverb_port=None,
        replay_capacity=1000000,
        # Others
        # Defaults to not checkpointing saved policy. If you wish to enable this,
        # please note the caveat explained in README.md.
        policy_save_interval=-1,
        eval_interval=10000,
        eval_episodes=30,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """Trains and evaluates SAC."""
    logging.info('Training SAC on: %s', env_name)
    collect_env = suite_mujoco.load(env_name)
    eval_env = suite_mujoco.load(env_name)

    observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    train_step = train_utils.create_train_step()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_tensor_spec,
        action_tensor_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=tanh_normal_projection_network.
        TanhNormalProjectionNetwork)
    critic_net = critic_network.CriticNetwork(
        (observation_tensor_spec, action_tensor_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

    agent = sac_agent.SacAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step)
    agent.initialize()

    table_name = 'uniform_table'
    table = reverb.Table(table_name,
                         max_size=replay_capacity,
                         sampler=reverb.selectors.Uniform(),
                         remover=reverb.selectors.Fifo(),
                         rate_limiter=reverb.rate_limiters.MinSize(1))

    reverb_server = reverb.Server([table], port=reverb_port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=2,
        table_name=table_name,
        local_server=reverb_server)
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=2,
        stride_length=1)

    dataset = reverb_replay.as_dataset(sample_batch_size=batch_size,
                                       num_steps=2).prefetch(50)
    experience_dataset_fn = lambda: dataset

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    env_step_metric = py_metrics.EnvironmentSteps()
    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY:
                              env_step_metric}),
        triggers.StepPerSecondLogTrigger(train_step, interval=1000),
    ]

    agent_learner = learner.Learner(root_dir,
                                    train_step,
                                    agent,
                                    experience_dataset_fn,
                                    triggers=learning_triggers)

    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(collect_env,
                                        random_policy,
                                        train_step,
                                        steps_per_run=initial_collect_steps,
                                        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    collect_actor = actor.Actor(collect_env,
                                collect_policy,
                                train_step,
                                steps_per_run=1,
                                metrics=actor.collect_metrics(10),
                                summary_dir=os.path.join(
                                    root_dir, learner.TRAIN_DIR),
                                observers=[rb_observer, env_step_metric])

    tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
    eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_greedy_policy, use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        eval_greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        agent_learner.run(iterations=1)

        if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
Example #24
  def __init__(
      self,
      time_step_spec,
      action_spec,
      q_network,
      optimizer,
      epsilon_greedy=0.1,
      boltzmann_temperature=None,
      # Params for target network updates
      target_update_tau=1.0,
      target_update_period=1,
      # Params for training.
      td_errors_loss_fn=None,
      gamma=1.0,
      reward_scale_factor=1.0,
      gradient_clipping=None,
      # Params for debugging
      debug_summaries=False,
      summarize_grads_and_vars=False,
      train_step_counter=None,
      name=None):
    """Creates a DQN Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      q_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call(observation, step_type).
      optimizer: The optimizer to use for training.
      epsilon_greedy: probability of choosing a random action in the default
        epsilon-greedy collect policy (used only if a wrapper is not provided to
        the collect_policy method).
      boltzmann_temperature: Temperature value to use for Boltzmann sampling of
        the actions during data collection. The closer to 0.0, the higher the
        probability of choosing the best action.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn: A function for computing the TD errors loss. If None, a
        default value of element_wise_huber_loss is used. This function takes as
        input the target and the estimated Q values and returns the loss for
        each element of the batch.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      ValueError: If the action spec contains more than one action or action
        spec minimum is not equal to 0.
    """
    tf.Module.__init__(self, name=name)

    flat_action_spec = tf.nest.flatten(action_spec)
    self._num_actions = [
        spec.maximum - spec.minimum + 1 for spec in flat_action_spec
    ]

    # TODO(oars): Get DQN working with more than one dim in the actions.
    if len(flat_action_spec) > 1 or flat_action_spec[0].shape.ndims > 1:
      raise ValueError('Only one dimensional actions are supported now.')

    if not all(spec.minimum == 0 for spec in flat_action_spec):
      raise ValueError(
          'Action specs should have minimum of 0, but saw: {0}'.format(
              [spec.minimum for spec in flat_action_spec]))

    if epsilon_greedy is not None and boltzmann_temperature is not None:
      raise ValueError(
          'Configured both epsilon_greedy value {} and temperature {}, '
          'however only one of them can be used for exploration.'.format(
              epsilon_greedy, boltzmann_temperature))

    self._q_network = q_network
    self._target_q_network = self._q_network.copy(name='TargetQNetwork')
    self._epsilon_greedy = epsilon_greedy
    self._boltzmann_temperature = boltzmann_temperature
    self._optimizer = optimizer
    self._td_errors_loss_fn = td_errors_loss_fn or element_wise_huber_loss
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._gradient_clipping = gradient_clipping
    self._update_target = self._get_target_updater(
        target_update_tau, target_update_period)

    policy = q_policy.QPolicy(
        time_step_spec, action_spec, q_network=self._q_network)

    if boltzmann_temperature is not None:
      collect_policy = boltzmann_policy.BoltzmannPolicy(
          policy, temperature=self._boltzmann_temperature)
    else:
      collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
          policy, epsilon=self._epsilon_greedy)
    policy = greedy_policy.GreedyPolicy(policy)

    super(DqnAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy,
        collect_policy,
        train_sequence_length=2 if not q_network.state_spec else None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter)
Example #25
 def _setup_as_continuous(self, time_step_spec, action_spec, loss_fn):
     self._bc_loss_fn = loss_fn or self._continuous_loss_fn
     collect_policy = actor_policy.ActorPolicy(
         time_step_spec, action_spec, actor_network=self._cloning_network)
     policy = greedy_policy.GreedyPolicy(collect_policy)
     return policy, collect_policy
Example #26
def train_eval(
    root_dir,
    environment_name="broken_reacher",
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    initial_collect_steps=10000,
    real_initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    real_collect_interval=10,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    classifier_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    train_on_real=False,
    delta_r_warmup=0,
    random_seed=0,
    checkpoint_dir=None,
):
    """A simple train and eval for SAC."""
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, "train")
    eval_dir = os.path.join(root_dir, "eval")

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)

    if environment_name == "broken_reacher":
        get_env_fn = darc_envs.get_broken_reacher_env
    elif environment_name == "half_cheetah_obstacle":
        get_env_fn = darc_envs.get_half_cheetah_direction_env
    elif environment_name == "inverted_pendulum":
        get_env_fn = darc_envs.get_inverted_pendulum_env
    elif environment_name.startswith("broken_joint"):
        base_name = environment_name.split("broken_joint_")[1]
        get_env_fn = functools.partial(darc_envs.get_broken_joint_env,
                                       env_name=base_name)
    elif environment_name.startswith("falling"):
        base_name = environment_name.split("falling_")[1]
        get_env_fn = functools.partial(darc_envs.get_falling_env,
                                       env_name=base_name)
    else:
        raise NotImplementedError("Unknown environment: %s" % environment_name)

    eval_name_list = ["sim", "real"]
    eval_env_list = [get_env_fn(mode) for mode in eval_name_list]

    eval_metrics_list = []
    for name in eval_name_list:
        eval_metrics_list.append([
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           name="AverageReturn_%s" % name),
        ])

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env_real = get_env_fn("real")
        if train_on_real:
            tf_env = get_env_fn("real")
        else:
            tf_env = get_env_fn("sim")

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=(
                tanh_normal_projection_network.TanhNormalProjectionNetwork),
        )
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer="glorot_uniform",
            last_kernel_initializer="glorot_uniform",
        )

        classifier = classifiers.build_classifier(observation_spec,
                                                  action_spec)

        tf_agent = darc_agent.DarcAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            classifier=classifier,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            classifier_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=classifier_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
        )
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
        )
        replay_observer = [replay_buffer.add_batch]

        real_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
        )
        real_replay_observer = [real_replay_buffer.add_batch]

        sim_train_metrics = [
            tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesSim"),
            tf_metrics.EnvironmentSteps(name="EnvironmentStepsSim"),
            tf_metrics.AverageReturnMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageReturnSim",
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageEpisodeLengthSim",
            ),
        ]
        real_train_metrics = [
            tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesReal"),
            tf_metrics.EnvironmentSteps(name="EnvironmentStepsReal"),
            tf_metrics.AverageReturnMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageReturnReal",
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageEpisodeLengthReal",
            ),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(
                sim_train_metrics + real_train_metrics, "train_metrics"),
        )
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "policy"),
            policy=eval_policy,
            global_step=global_step,
        )
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "replay_buffer"),
            max_to_keep=1,
            replay_buffer=(replay_buffer, real_replay_buffer),
        )

        if checkpoint_dir is not None:
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
            assert checkpoint_path is not None
            train_checkpointer._load_status = train_checkpointer._checkpoint.restore(  # pylint: disable=protected-access
                checkpoint_path)
            train_checkpointer._load_status.initialize_or_restore()  # pylint: disable=protected-access
        else:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if replay_buffer.num_frames() == 0:
            initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env,
                initial_collect_policy,
                observers=replay_observer + sim_train_metrics,
                num_steps=initial_collect_steps,
            )
            real_initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env_real,
                initial_collect_policy,
                observers=real_replay_observer + real_train_metrics,
                num_steps=real_initial_collect_steps,
            )

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + sim_train_metrics,
            num_steps=collect_steps_per_iteration,
        )

        real_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env_real,
            collect_policy,
            observers=real_replay_observer + real_train_metrics,
            num_steps=collect_steps_per_iteration,
        )

        config_str = gin.operative_config_str()
        logging.info(config_str)
        with tf.compat.v1.gfile.Open(os.path.join(root_dir, "operative.gin"),
                                     "w") as f:
            f.write(config_str)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            real_initial_collect_driver.run = common.function(
                real_initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            real_collect_driver.run = common.function(real_collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if replay_buffer.num_frames() == 0:
            logging.info(
                "Initializing replay buffer by collecting experience for %d steps with "
                "a random policy.",
                initial_collect_steps,
            )
            initial_collect_driver.run()
            real_initial_collect_driver.run()

        for eval_name, eval_env, eval_metrics in zip(eval_name_list,
                                                     eval_env_list,
                                                     eval_metrics_list):
            metric_utils.eager_compute(
                eval_metrics,
                eval_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix="Metrics-%s" % eval_name,
            )
            metric_utils.log_metrics(eval_metrics)

        time_step = None
        real_time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = (replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))
        real_dataset = (real_replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))

        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)
        real_iterator = iter(real_dataset)

        def train_step():
            experience, _ = next(iterator)
            real_experience, _ = next(real_iterator)
            return tf_agent.train(experience, real_experience=real_experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            assert not policy_state  # We expect policy_state == ().
            if (global_step.numpy() % real_collect_interval == 0
                    and global_step.numpy() >= delta_r_warmup):
                real_time_step, policy_state = real_collect_driver.run(
                    time_step=real_time_step,
                    policy_state=policy_state,
                )

            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info("step = %d, loss = %f", global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info("%.3f steps/sec", steps_per_sec)
                tf.compat.v2.summary.scalar(name="global_steps_per_sec",
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in sim_train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=sim_train_metrics[:2])
            for train_metric in real_train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=real_train_metrics[:2])

            if global_step_val % eval_interval == 0:
                for eval_name, eval_env, eval_metrics in zip(
                        eval_name_list, eval_env_list, eval_metrics_list):
                    metric_utils.eager_compute(
                        eval_metrics,
                        eval_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix="Metrics-%s" % eval_name,
                    )
                    metric_utils.log_metrics(eval_metrics)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
Example #27
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=actorLearningRate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=criticLearningRate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=alphaLearningRate),
        target_update_tau=target_update_tau,
        gamma=gamma,
        gradient_clipping=gradientClipping,
        train_step_counter=global_step,
    )
    tf_agent.initialize()
    print('SAC Agent Created.')


    # policies
    evaluate_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    collect_policy = tf_agent.collect_policy

    # metrics and evaluation
    def compute_avg_return(environment, policy, num_episodes=2):
        total_return = 0.0
        for _ in range(num_episodes):
            time_step = environment.reset()
            episode_return = 0.0
            while not time_step.is_last():
                action_step = policy.action(time_step)
                time_step = environment.step(action_step.action)
                episode_return += time_step.reward
            total_return += episode_return
        avg_return = total_return / num_episodes
        return avg_return.numpy()[0]
Example #28
    def __init__(
        self,
        saved_model_dir: Text,
        agent: tf_agent.TFAgent,
        train_step: tf.Variable,
        interval: int,
        async_saving: bool = False,
        metadata_metrics: Optional[Mapping[Text, py_metric.PyMetric]] = None,
        start: int = 0,
        extra_concrete_functions: Optional[Sequence[Tuple[
            str, policy_saver.def_function.Function]]] = None,
        batch_size: Optional[int] = None,
        use_nest_path_signatures: bool = True,
        save_greedy_policy=True,
        save_collect_policy=True,
        input_fn_and_spec: Optional[policy_saver.InputFnAndSpecType] = None,
    ):
        """Initializes a PolicySavedModelTrigger.

    Args:
      saved_model_dir: Base dir where checkpoints will be saved.
      agent: Agent to extract policies from.
      train_step: `tf.Variable` which keeps track of the number of train steps.
      interval: How often, in train_steps, the trigger will save. Note that the
        event fires whenever at least `interval` steps have passed since the
        last trigger, so the current value is not necessarily exactly
        `interval` steps away from the last triggered value.
      async_saving: If True saving will be done asynchronously in a separate
        thread. Note if this is on the variable values in the saved
        checkpoints/models are not deterministic.
      metadata_metrics: A dictionary of metrics, whose `result()` method returns
        a scalar to be saved along with the policy. Currently only supported
        when async_saving is False.
      start: Initial value for the trigger passed directly to the base class. It
        helps control from which train step the weights of the model are saved.
      extra_concrete_functions: Optional sequence of extra concrete functions to
        register in the policy savers. The sequence should consist of tuples
        with string name for the function and the tf.function to register. Note
        this does not support adding extra assets.
      batch_size: The number of batch entries the policy will process at a time.
        This must be either `None` (unknown batch size) or a python integer.
      use_nest_path_signatures: SavedModel spec signatures will be created based
        on the structure of the specs. Otherwise all specs must have unique
        names.
      save_greedy_policy: Disable when an agent's policy distribution method
        does not support mode.
      save_collect_policy: Disable when not saving collect policy.
      input_fn_and_spec: A `(input_fn, tensor_spec)` tuple where input_fn is a
        function that takes inputs according to tensor_spec and converts them to
        the `(time_step, policy_state)` tuple that is used as the input to the
        action_fn. When `input_fn_and_spec` is set, `tensor_spec` is the input
        for the action signature. When `input_fn_and_spec is None`, the action
        signature takes as input `(time_step, policy_state)`.
    """
        if async_saving and metadata_metrics:
            raise NotImplementedError('Support for metadata_metrics is not '
                                      'implemented for async policy saver.')

        self._agent = agent
        self._train_step = train_step
        self._async_saving = async_saving
        self._metadata_metrics = metadata_metrics or {}
        self._metadata = {
            k: tf.Variable(0, dtype=v.result().dtype, shape=v.result().shape)
            for k, v in self._metadata_metrics.items()
        }
        self._input_fn_and_spec = input_fn_and_spec

        greedy = None
        if isinstance(agent.policy, greedy_policy.GreedyPolicy):
            raw_policy = agent.policy.wrapped_policy
            greedy = agent.policy
        else:
            raw_policy = agent.policy
            if save_greedy_policy:
                greedy = greedy_policy.GreedyPolicy(agent.policy)

        self._raw_policy_saver = self._build_saver(raw_policy, batch_size,
                                                   use_nest_path_signatures)
        savers = [(self._raw_policy_saver, learner.RAW_POLICY_SAVED_MODEL_DIR)]

        if save_collect_policy:
            collect_policy_saver = self._build_saver(agent.collect_policy,
                                                     batch_size,
                                                     use_nest_path_signatures)

            savers.append(
                (collect_policy_saver, learner.COLLECT_POLICY_SAVED_MODEL_DIR))

        if save_greedy_policy:
            greedy_policy_saver = self._build_saver(greedy, batch_size,
                                                    use_nest_path_signatures)
            savers.append(
                (greedy_policy_saver, learner.GREEDY_POLICY_SAVED_MODEL_DIR))

        extra_concrete_functions = extra_concrete_functions or []
        for saver, _ in savers:
            for name, fn in extra_concrete_functions:
                saver.register_concrete_function(name, fn)

        self._checkpoint_dir = os.path.join(saved_model_dir,
                                            learner.POLICY_CHECKPOINT_DIR)

        # TODO(b/173815037): Use a TF-Agents util to check for whether a saved
        # policy already exists.
        for saver, path in savers:
            spec_path = os.path.join(saved_model_dir, path,
                                     'policy_specs.pbtxt')
            if not tf.io.gfile.exists(spec_path):
                saver.save(os.path.join(saved_model_dir, path))

        super(PolicySavedModelTrigger, self).__init__(interval,
                                                      self._save_fn,
                                                      start=start)
Example #29
    def __init__(
            self,
            time_step_spec,
            action_spec,
            categorical_q_network,
            optimizer,
            min_q_value=-10.0,
            max_q_value=10.0,
            epsilon_greedy=0.1,
            n_step_update=1,
            boltzmann_temperature=None,
            # Params for target network updates
            target_update_tau=1.0,
            target_update_period=1,
            # Params for training.
            td_errors_loss_fn=None,
            gamma=1.0,
            reward_scale_factor=1.0,
            gradient_clipping=None,
            # Params for debugging
            debug_summaries=False,
            summarize_grads_and_vars=False,
            train_step_counter=None,
            name=None):
        """Creates a Categorical DQN Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A `BoundedTensorSpec` representing the actions.
      categorical_q_network: A categorical_q_network.CategoricalQNetwork that
        returns the q_distribution for each action.
      optimizer: The optimizer to use for training.
      min_q_value: A float specifying the minimum Q-value, used for setting up
        the support.
      max_q_value: A float specifying the maximum Q-value, used for setting up
        the support.
      epsilon_greedy: probability of choosing a random action in the default
        epsilon-greedy collect policy (used only if a wrapper is not provided to
        the collect_policy method).
      n_step_update: The number of steps to consider when computing TD error and
        TD loss. Defaults to single-step updates. Note that this requires the
        user to call train on Trajectory objects with a time dimension of
        `n_step_update + 1`. However, note that we do not yet support
        `n_step_update > 1` in the case of RNNs (i.e., non-empty
        `q_network.state_spec`).
      boltzmann_temperature: Temperature value to use for Boltzmann sampling of
        the actions during data collection. The closer to 0.0, the higher the
        probability of choosing the best action.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn: A function for computing the TD errors loss. If None, a
        default value of element_wise_huber_loss is used. This function takes as
        input the target and the estimated Q values and returns the loss for
        each element of the batch.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      TypeError: If the action spec contains more than one action.
    """
        num_atoms = getattr(categorical_q_network, 'num_atoms', None)
        if num_atoms is None:
            raise TypeError(
                'Expected categorical_q_network to have property '
                '`num_atoms`, but it doesn\'t (note: you likely want to '
                'use a CategoricalQNetwork). Network is: %s' %
                (categorical_q_network, ))

        self._num_atoms = num_atoms
        self._min_q_value = min_q_value
        self._max_q_value = max_q_value
        self._support = tf.linspace(min_q_value, max_q_value, num_atoms)

        super(CategoricalDqnAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             categorical_q_network,
                             optimizer,
                             epsilon_greedy=epsilon_greedy,
                             n_step_update=n_step_update,
                             boltzmann_temperature=boltzmann_temperature,
                             target_update_tau=target_update_tau,
                             target_update_period=target_update_period,
                             td_errors_loss_fn=td_errors_loss_fn,
                             gamma=gamma,
                             reward_scale_factor=reward_scale_factor,
                             gradient_clipping=gradient_clipping,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter,
                             name=name)

        policy = categorical_q_policy.CategoricalQPolicy(
            min_q_value, max_q_value, self._q_network, self._action_spec)
        if boltzmann_temperature is not None:
            self._collect_policy = boltzmann_policy.BoltzmannPolicy(
                policy, temperature=self._boltzmann_temperature)
        else:
            self._collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
                policy, epsilon=self._epsilon_greedy)
        self._policy = greedy_policy.GreedyPolicy(policy)
Example #30
 def GetEvalPolicy(self):
     return greedy_policy.GreedyPolicy(self._agent.policy)