Example #1
    def test_none_output(self):
        model = tf2_utils.to_sonnet_module(lambda x: None)
        input_spec = specs.Array(shape=(10,), dtype=np.float32)
        expected_spec = None
        output_spec = tf2_utils.create_variables(model, [input_spec])
        self.assertEqual(model.variables, ())
        self.assertEqual(output_spec, expected_spec)
Example #2
    def test_scalar_output(self):
        model = tf2_utils.to_sonnet_module(tf.reduce_sum)
        input_spec = specs.Array(shape=(10,), dtype=np.float32)
        expected_spec = tf.TensorSpec(shape=(), dtype=tf.float32)
        output_spec = tf2_utils.create_variables(model, [input_spec])
        self.assertEqual(model.variables, ())
        self.assertEqual(output_spec, expected_spec)
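These two tests exercise tf2_utils.to_sonnet_module and tf2_utils.create_variables: the former wraps a plain callable as a Sonnet module, the latter builds the module's variables by tracing a dummy (batched) forward pass from the given input spec and returns the resulting output spec. A minimal sketch of the same pattern on a network that actually owns variables (the snt.Linear(32) layer and the expected outputs noted in the comments are illustrative assumptions, not part of the tests above):

import numpy as np
import sonnet as snt

from acme import specs
from acme.tf import utils as tf2_utils

# Wrap the layer as a Sonnet module (a no-op here, since snt.Linear already is one).
network = tf2_utils.to_sonnet_module(snt.Linear(32))

# create_variables adds a dummy batch dimension, runs the module once to build
# its variables, and returns the spec of the (unbatched) output.
input_spec = specs.Array(shape=(10,), dtype=np.float32)
output_spec = tf2_utils.create_variables(network, [input_spec])

print(output_spec)             # expected: TensorSpec(shape=(32,), dtype=tf.float32)
print(len(network.variables))  # expected: 2 (the layer's weights and bias)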
Example #3
    def __init__(
        self,
        policy_network: snt.Module,
        critic_network: snt.Module,
        target_policy_network: snt.Module,
        target_critic_network: snt.Module,
        discount: float,
        num_samples: int,
        target_policy_update_period: int,
        target_critic_update_period: int,
        dataset: tf.data.Dataset,
        observation_network: types.TensorTransformation = tf.identity,
        target_observation_network: types.TensorTransformation = tf.identity,
        policy_loss_module: Optional[snt.Module] = None,
        policy_optimizer: Optional[snt.Optimizer] = None,
        critic_optimizer: Optional[snt.Optimizer] = None,
        dual_optimizer: Optional[snt.Optimizer] = None,
        clipping: bool = True,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
        checkpoint: bool = True,
    ):

        # Store online and target networks.
        self._policy_network = policy_network
        self._critic_network = critic_network
        self._target_policy_network = target_policy_network
        self._target_critic_network = target_critic_network

        # Make sure observation networks are snt.Modules so they have variables.
        self._observation_network = tf2_utils.to_sonnet_module(
            observation_network)
        self._target_observation_network = tf2_utils.to_sonnet_module(
            target_observation_network)

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger('learner')

        # Other learner parameters.
        self._discount = discount
        self._num_samples = num_samples
        self._clipping = clipping

        # Necessary to track when to update target networks.
        self._num_steps = tf.Variable(0, dtype=tf.int32)
        self._target_policy_update_period = target_policy_update_period
        self._target_critic_update_period = target_critic_update_period

        # Create an iterator over the dataset (it arrives already batched).
        # TODO(b/155086959): Fix type stubs and remove.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        self._policy_loss_module = policy_loss_module or losses.MPO(
            epsilon=1e-1,
            epsilon_penalty=1e-3,
            epsilon_mean=1e-3,
            epsilon_stddev=1e-6,
            init_log_temperature=1.,
            init_log_alpha_mean=1.,
            init_log_alpha_stddev=10.)

        # Create the optimizers.
        self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
        self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
        self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2)

        # Expose the variables.
        policy_network_to_expose = snt.Sequential(
            [self._target_observation_network, self._target_policy_network])
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': policy_network_to_expose.variables,
        }

        # Create checkpointer and snapshotter objects.
        self._checkpointer = None
        self._snapshotter = None

        if checkpoint:
            self._checkpointer = tf2_savers.Checkpointer(
                subdirectory='dmpo_learner',
                objects_to_save={
                    'counter': self._counter,
                    'policy': self._policy_network,
                    'critic': self._critic_network,
                    'observation': self._observation_network,
                    'target_policy': self._target_policy_network,
                    'target_critic': self._target_critic_network,
                    'target_observation': self._target_observation_network,
                    'policy_optimizer': self._policy_optimizer,
                    'critic_optimizer': self._critic_optimizer,
                    'dual_optimizer': self._dual_optimizer,
                    'policy_loss_module': self._policy_loss_module,
                    'num_steps': self._num_steps,
                })

            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save={
                    'policy':
                    snt.Sequential([
                        self._target_observation_network,
                        self._target_policy_network
                    ]),
                })

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
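The _num_steps variable and the two *_update_period values stored above are used later, in the learner's step function, to decide when to copy the online weights into the target networks. A minimal, self-contained sketch of that idiom (the helper name is an assumption; the actual learner inlines this logic):

import tensorflow as tf

def maybe_update_targets(online_variables, target_variables,
                         num_steps: tf.Variable, period: int) -> None:
    # Copy the online variables into the targets every `period` learner steps.
    if tf.math.mod(num_steps, period) == 0:
        for src, dest in zip(online_variables, target_variables):
            dest.assign(src)

# Inside a learner step one would call, for example:
#   maybe_update_targets(self._policy_network.variables,
#                        self._target_policy_network.variables,
#                        self._num_steps, self._target_policy_update_period)
#   maybe_update_targets(self._critic_network.variables,
#                        self._target_critic_network.variables,
#                        self._num_steps, self._target_critic_update_period)
#   self._num_steps.assign_add(1)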
Example #4
  def __init__(self,
               environment_spec: specs.EnvironmentSpec,
               policy_network: snt.Module,
               critic_network: snt.Module,
               observation_network: types.TensorTransformation = tf.identity,
               discount: float = 0.99,
               batch_size: int = 256,
               prefetch_size: int = 4,
               target_update_period: int = 100,
               min_replay_size: int = 1000,
               max_replay_size: int = 1000000,
               samples_per_insert: float = 32.0,
               n_step: int = 5,
               sigma: float = 0.3,
               clipping: bool = True,
               logger: loggers.Logger = None,
               counter: counting.Counter = None,
               checkpoint: bool = True,
               replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
    """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      observation_network: optional network to transform the observations before
        they are fed into any network.
      discount: discount to use for TD updates.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      n_step: number of steps to squash into a single transition.
      sigma: standard deviation of zero-mean, Gaussian exploration noise.
      clipping: whether to clip gradients by global norm.
      logger: logger object to be used by learner.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner.
      replay_table_name: string indicating what name to give the replay table.
    """
    # Create a replay server to add data to. No rate-limiting behavior is used
    # here so that the Agent interface can control the insert/sample ratio
    # itself.
    replay_table = reverb.Table(
        name=replay_table_name,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(1))
    self._server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        priority_fns={replay_table_name: lambda x: 1.},
        client=reverb.Client(address),
        n_step=n_step,
        discount=discount)

    # The dataset provides an interface to sample from replay.
    dataset = datasets.make_reverb_dataset(
        table=replay_table_name,
        client=reverb.TFClient(address),
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        environment_spec=environment_spec,
        transition_adder=True)

    # Make sure observation network is a Sonnet Module.
    observation_network = tf2_utils.to_sonnet_module(observation_network)

    # Create target networks.
    target_policy_network = copy.deepcopy(policy_network)
    target_critic_network = copy.deepcopy(critic_network)
    target_observation_network = copy.deepcopy(observation_network)

    # Get observation and action specs.
    act_spec = environment_spec.actions
    obs_spec = environment_spec.observations
    emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

    # Create the behavior policy.
    behavior_network = snt.Sequential([
        observation_network,
        policy_network,
        networks.ClippedGaussian(sigma),
        networks.ClipToSpec(act_spec),
    ])

    # Create variables.
    tf2_utils.create_variables(policy_network, [emb_spec])
    tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
    tf2_utils.create_variables(target_policy_network, [emb_spec])
    tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
    tf2_utils.create_variables(target_observation_network, [obs_spec])

    # Create the actor which defines how we take actions.
    actor = actors.FeedForwardActor(behavior_network, adder=adder)

    # Create optimizers.
    policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4)
    critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4)

    # The learner updates the parameters (and initializes them).
    learner = learning.D4PGLearner(
        policy_network=policy_network,
        critic_network=critic_network,
        observation_network=observation_network,
        target_policy_network=target_policy_network,
        target_critic_network=target_critic_network,
        target_observation_network=target_observation_network,
        policy_optimizer=policy_optimizer,
        critic_optimizer=critic_optimizer,
        clipping=clipping,
        discount=discount,
        target_update_period=target_update_period,
        dataset=dataset,
        counter=counter,
        logger=logger,
        checkpoint=checkpoint,
    )

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(batch_size, min_replay_size),
        observations_per_step=float(batch_size) / samples_per_insert)
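Constructed this way, the agent owns its own Reverb server, adder, dataset, actor and learner, so driving it only requires an environment loop. A hedged usage sketch (the environment placeholder and the two network arguments are assumptions, and the class name D4PG is inferred from the D4PGLearner it builds):

import acme
from acme import specs

environment = ...  # assumed: any dm_env.Environment instance
environment_spec = specs.make_environment_spec(environment)

agent = D4PG(  # the agent class whose __init__ is shown above (name inferred)
    environment_spec=environment_spec,
    policy_network=my_policy_network,  # assumed: user-provided snt.Module
    critic_network=my_critic_network)  # assumed: user-provided snt.Module

# The loop alternates acting in the environment and feeding the agent.
loop = acme.EnvironmentLoop(environment, agent)
loop.run(num_episodes=100)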
Example #5
    def __init__(self,
                 environment_spec: specs.EnvironmentSpec,
                 policy_network: snt.Module,
                 critic_network: snt.Module,
                 observation_network: types.TensorTransformation = tf.identity,
                 discount: float = 0.99,
                 batch_size: int = 256,
                 prefetch_size: int = 4,
                 target_policy_update_period: int = 100,
                 target_critic_update_period: int = 100,
                 min_replay_size: int = 1000,
                 max_replay_size: int = 1000000,
                 samples_per_insert: float = 32.0,
                 policy_loss_module: snt.Module = None,
                 policy_optimizer: snt.Optimizer = None,
                 critic_optimizer: snt.Optimizer = None,
                 n_step: int = 5,
                 num_samples: int = 20,
                 clipping: bool = True,
                 logger: loggers.Logger = None,
                 counter: counting.Counter = None,
                 checkpoint: bool = True,
                 replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      observation_network: optional network to transform the observations before
        they are fed into any network.
      discount: discount to use for TD updates.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_policy_update_period: number of updates to perform before updating
        the target policy network.
      target_critic_update_period: number of updates to perform before updating
        the target critic network.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      policy_loss_module: configured MPO loss function for the policy
        optimization; defaults to sensible values on the control suite.
        See `acme/tf/losses/mpo.py` for more details.
      policy_optimizer: optimizer to be used on the policy.
      critic_optimizer: optimizer to be used on the critic.
      n_step: number of steps to squash into a single transition.
      num_samples: number of actions to sample when doing a Monte Carlo
        integration with respect to the policy.
      clipping: whether to clip gradients by global norm.
      logger: logging object used to write to logs.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner.
      replay_table_name: string indicating what name to give the replay table.
    """

        # Create a replay server to add data to.
        replay_table = reverb.Table(
            name=replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(
            priority_fns={replay_table_name: lambda x: 1.},
            client=reverb.Client(address),
            n_step=n_step,
            discount=discount)

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(
            table=replay_table_name,
            client=reverb.TFClient(address),
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            environment_spec=environment_spec,
            transition_adder=True)

        # Make sure observation network is a Sonnet Module.
        observation_network = tf2_utils.to_sonnet_module(observation_network)

        # Create target networks before creating online/target network variables.
        target_policy_network = copy.deepcopy(policy_network)
        target_critic_network = copy.deepcopy(critic_network)
        target_observation_network = copy.deepcopy(observation_network)

        # Get observation and action specs.
        act_spec = environment_spec.actions
        obs_spec = environment_spec.observations
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

        # Create the behavior policy.
        behavior_network = snt.Sequential([
            observation_network,
            policy_network,
            networks.StochasticSamplingHead(),
        ])

        # Create variables.
        tf2_utils.create_variables(policy_network, [emb_spec])
        tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_policy_network, [emb_spec])
        tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_observation_network, [obs_spec])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network=behavior_network,
                                        adder=adder)

        # Create optimizers.
        policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
        critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)

        # The learner updates the parameters (and initializes them).
        learner = learning.DistributionalMPOLearner(
            policy_network=policy_network,
            critic_network=critic_network,
            observation_network=observation_network,
            target_policy_network=target_policy_network,
            target_critic_network=target_critic_network,
            target_observation_network=target_observation_network,
            policy_loss_module=policy_loss_module,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            clipping=clipping,
            discount=discount,
            num_samples=num_samples,
            target_policy_update_period=target_policy_update_period,
            target_critic_update_period=target_critic_update_period,
            dataset=dataset,
            logger=logger,
            counter=counter,
            checkpoint=checkpoint)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
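Because policy_loss_module and the optimizers default to None, they can be overridden at construction time; otherwise the learner falls back to the MPO defaults shown in Example #3. A hedged sketch of passing a custom MPO loss and optimizers (the agent class name and the network/spec arguments are assumptions; the epsilon values mirror Example #3):

import sonnet as snt
from acme.tf import losses

custom_policy_loss = losses.MPO(
    epsilon=1e-1,
    epsilon_penalty=1e-3,
    epsilon_mean=1e-3,
    epsilon_stddev=1e-6,
    init_log_temperature=1.,
    init_log_alpha_mean=1.,
    init_log_alpha_stddev=10.)

agent = DMPO(  # assumed class name for the __init__ shown above
    environment_spec=environment_spec,  # assumed: from specs.make_environment_spec
    policy_network=my_policy_network,   # assumed: user-provided snt.Module
    critic_network=my_critic_network,   # assumed: user-provided snt.Module
    policy_loss_module=custom_policy_loss,
    policy_optimizer=snt.optimizers.Adam(1e-4),
    critic_optimizer=snt.optimizers.Adam(1e-4))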
Example #6
  def __init__(
      self,
      policy_network: snt.Module,
      critic_network: snt.Module,
      target_policy_network: snt.Module,
      target_critic_network: snt.Module,
      discount: float,
      target_update_period: int,
      dataset: tf.data.Dataset,
      observation_network: types.TensorTransformation = lambda x: x,
      target_observation_network: types.TensorTransformation = lambda x: x,
      policy_optimizer: snt.Optimizer = None,
      critic_optimizer: snt.Optimizer = None,
      clipping: bool = True,
      counter: counting.Counter = None,
      logger: loggers.Logger = None,
      checkpoint: bool = True,
  ):
    """Initializes the learner.

    Args:
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      target_policy_network: the target policy (which lags behind the online
        policy).
      target_critic_network: the target critic.
      discount: discount to use for TD updates.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      dataset: dataset to learn from, whether fixed or from a replay buffer
        (see `acme.datasets.reverb.make_dataset` documentation).
      observation_network: an optional online network to process observations
        before the policy and the critic.
      target_observation_network: the target observation network.
      policy_optimizer: the optimizer to be applied to the DPG (policy) loss.
      critic_optimizer: the optimizer to be applied to the distributional
        Bellman loss.
      clipping: whether to clip gradients by global norm.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      checkpoint: boolean indicating whether to checkpoint the learner.
    """

    # Store online and target networks.
    self._policy_network = policy_network
    self._critic_network = critic_network
    self._target_policy_network = target_policy_network
    self._target_critic_network = target_critic_network

    # Make sure observation networks are snt.Modules so they have variables.
    self._observation_network = tf2_utils.to_sonnet_module(observation_network)
    self._target_observation_network = tf2_utils.to_sonnet_module(
        target_observation_network)

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger('learner')

    # Other learner parameters.
    self._discount = discount
    self._clipping = clipping

    # Necessary to track when to update target networks.
    self._num_steps = tf.Variable(0, dtype=tf.int32)
    self._target_update_period = target_update_period

    # Create an iterator over the dataset (it arrives already batched).
    # TODO(b/155086959): Fix type stubs and remove.
    self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

    # Create optimizers if they aren't given.
    self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
    self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)

    # Expose the variables.
    policy_network_to_expose = snt.Sequential(
        [self._target_observation_network, self._target_policy_network])
    self._variables = {
        'critic': self._target_critic_network.variables,
        'policy': policy_network_to_expose.variables,
    }

    # Create checkpointer and snapshotter objects.
    self._checkpointer = None
    self._snapshotter = None

    if checkpoint:
      self._checkpointer = tf2_savers.Checkpointer(
          subdirectory='d4pg_learner',
          objects_to_save={
              'counter': self._counter,
              'policy': self._policy_network,
              'critic': self._critic_network,
              'observation': self._observation_network,
              'target_policy': self._target_policy_network,
              'target_critic': self._target_critic_network,
              'target_observation': self._target_observation_network,
              'policy_optimizer': self._policy_optimizer,
              'critic_optimizer': self._critic_optimizer,
              'num_steps': self._num_steps,
          })
      critic_mean = snt.Sequential(
          [self._critic_network, acme_nets.StochasticMeanHead()])
      self._snapshotter = tf2_savers.Snapshotter(
          objects_to_save={
              'policy': self._policy_network,
              'critic': critic_mean,
          })

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None
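The self._variables dictionary built above is what remote actors read through the learner's variable-source interface. A hedged sketch of the accessor that typically accompanies it in Acme learners (shown as a method body; the exact implementation may differ):

from typing import List
import numpy as np
from acme.tf import utils as tf2_utils

def get_variables(self, names: List[str]) -> List[List[np.ndarray]]:
  # Look up each requested collection ('policy' or 'critic') and convert the
  # TF variables to numpy arrays before returning them.
  return [tf2_utils.to_numpy(self._variables[name]) for name in names]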