Ejemplo n.º 1
0
  def __init__(self,
               environment_spec: specs.EnvironmentSpec,
               policy_network: snt.Module,
               critic_network: snt.Module,
               observation_network: types.TensorTransformation = tf.identity,
               discount: float = 0.99,
               batch_size: int = 256,
               prefetch_size: int = 4,
               target_policy_update_period: int = 100,
               target_critic_update_period: int = 100,
               min_replay_size: int = 1000,
               max_replay_size: int = 1000000,
               samples_per_insert: float = 32.0,
               policy_loss_module: snt.Module = None,
               policy_optimizer: snt.Optimizer = None,
               critic_optimizer: snt.Optimizer = None,
               n_step: int = 5,
               num_samples: int = 20,
               clipping: bool = True,
               logger: loggers.Logger = None,
               counter: counting.Counter = None,
               checkpoint: bool = True,
               replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
    """Initialize the agent.

    Builds an in-process replay server, an adder that writes to it, a dataset
    that samples from it, the online/target networks, an actor and a
    `DistributionalMPOLearner`, then hands the actor and learner to the parent
    class.

    Args:
      environment_spec: description of the actions, observations, etc.
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      observation_network: optional network to transform the observations before
        they are fed into any network.
      discount: discount to use for TD updates.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_policy_update_period: number of updates to perform before updating
        the target policy network.
      target_critic_update_period: number of updates to perform before updating
        the target critic network.
      min_replay_size: minimum replay size before updating; the larger of this
        and `batch_size` is passed to the parent class as `min_observations`.
      max_replay_size: maximum replay size.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      policy_loss_module: configured MPO loss function for the policy
        optimization; defaults to sensible values on the control suite.
        See `acme/tf/losses/mpo.py` for more details.
      policy_optimizer: optimizer to be used on the policy; defaults to Adam
        with learning rate 1e-4.
      critic_optimizer: optimizer to be used on the critic; defaults to Adam
        with learning rate 1e-4.
      n_step: number of steps to squash into a single transition.
      num_samples: number of actions to sample when doing a Monte Carlo
        integration with respect to the policy.
      clipping: whether to clip gradients by global norm.
      logger: logging object used to write to logs.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner.
      replay_table_name: string indicating what name to give the replay table.
        The same name is used for the table that is created, the table the
        adder inserts into and the table the dataset samples from.
    """

    # Create a replay server to add data to. The table must be registered under
    # `replay_table_name`, since that is the table the dataset samples from
    # below; using a different (hard-coded) name would leave the dataset
    # pointing at a table that does not exist on the server.
    replay_table = reverb.Table(
        name=replay_table_name,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
        signature=adders.NStepTransitionAdder.signature(environment_spec))
    # Keep a reference to the server on the agent.
    self._server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay. Route its inserts
    # to the table created above; the priority is constant because the table
    # samples uniformly.
    address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        priority_fns={replay_table_name: lambda x: 1.},
        client=reverb.Client(address),
        n_step=n_step,
        discount=discount)

    # The dataset object to learn from.
    dataset = datasets.make_reverb_dataset(
        table=replay_table_name,
        server_address=address,
        batch_size=batch_size,
        prefetch_size=prefetch_size)

    # Make sure observation network is a Sonnet Module.
    observation_network = tf2_utils.to_sonnet_module(observation_network)

    # Create target networks before creating online/target network variables.
    target_policy_network = copy.deepcopy(policy_network)
    target_critic_network = copy.deepcopy(critic_network)
    target_observation_network = copy.deepcopy(observation_network)

    # Get observation and action specs. Creating the observation network's
    # variables also yields the spec of the embedding it produces.
    act_spec = environment_spec.actions
    obs_spec = environment_spec.observations
    emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

    # Create the behavior policy.
    behavior_network = snt.Sequential([
        observation_network,
        policy_network,
        networks.StochasticSamplingHead(),
    ])

    # Create variables.
    tf2_utils.create_variables(policy_network, [emb_spec])
    tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
    tf2_utils.create_variables(target_policy_network, [emb_spec])
    tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
    tf2_utils.create_variables(target_observation_network, [obs_spec])

    # Create the actor which defines how we take actions.
    actor = actors.FeedForwardActor(
        policy_network=behavior_network, adder=adder)

    # Create optimizers.
    policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
    critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)

    # The learner updates the parameters (and initializes them).
    learner = learning.DistributionalMPOLearner(
        policy_network=policy_network,
        critic_network=critic_network,
        observation_network=observation_network,
        target_policy_network=target_policy_network,
        target_critic_network=target_critic_network,
        target_observation_network=target_observation_network,
        policy_loss_module=policy_loss_module,
        policy_optimizer=policy_optimizer,
        critic_optimizer=critic_optimizer,
        clipping=clipping,
        discount=discount,
        num_samples=num_samples,
        target_policy_update_period=target_policy_update_period,
        target_critic_update_period=target_critic_update_period,
        dataset=dataset,
        logger=logger,
        counter=counter,
        checkpoint=checkpoint)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(batch_size, min_replay_size),
        observations_per_step=float(batch_size) / samples_per_insert)
Ejemplo n.º 2
0
  def learner(
      self,
      replay: reverb.Client,
      counter: counting.Counter,
  ):
    """Builds the learner: networks, replay dataset and the MPO learner itself.

    Args:
      replay: reverb client; its `server_address` is used to build the dataset.
      counter: parent counter, wrapped in a new `counting.Counter` created with
        the label 'learner'.

    Returns:
      A `learning.DistributionalMPOLearner` wired to freshly built online and
      target networks and to the replay dataset.
    """

    environment_spec = self._environment_spec
    action_spec = environment_spec.actions
    observation_spec = environment_spec.observations

    # Two separate factory calls give independent online and target networks.
    online = self._network_factory(action_spec)
    target = self._network_factory(action_spec)

    # Observation networks are optional (identity by default) and are always
    # wrapped so that they are Sonnet Modules.
    online_observation = tf2_utils.to_sonnet_module(
        online.get('observation', tf.identity))
    target_observation = tf2_utils.to_sonnet_module(
        target.get('observation', tf.identity))

    # Creating the online observation network's variables also provides the
    # embedding spec that the policy and critic consume.
    embedding_spec = tf2_utils.create_variables(
        online_observation, [observation_spec])

    # Create the remaining variables.
    policy_inputs = [embedding_spec]
    critic_inputs = [embedding_spec, action_spec]
    tf2_utils.create_variables(online['policy'], policy_inputs)
    tf2_utils.create_variables(online['critic'], critic_inputs)
    tf2_utils.create_variables(target['policy'], policy_inputs)
    tf2_utils.create_variables(target['critic'], critic_inputs)
    tf2_utils.create_variables(target_observation, [observation_spec])

    # Replay pipeline: batch -> (optional augmentation) -> prefetch.
    batches = datasets.make_reverb_dataset(server_address=replay.server_address)
    batches = batches.batch(self._batch_size, drop_remainder=True)
    if self._observation_augmentation:
      augment = image_augmentation.make_transform(
          observation_transform=self._observation_augmentation)
      batches = batches.map(
          augment, num_parallel_calls=16, deterministic=False)
    batches = batches.prefetch(self._prefetch_size)

    learner_counter = counting.Counter(counter, 'learner')
    learner_logger = loggers.make_default_logger(
        'learner', time_delta=self._log_every, steps_key='learner_steps')

    # Only build a policy loss module when a factory was supplied.
    loss_module = (
        self._policy_loss_factory() if self._policy_loss_factory else None)

    # Assemble and return the learner.
    return learning.DistributionalMPOLearner(
        policy_network=online['policy'],
        critic_network=online['critic'],
        observation_network=online_observation,
        target_policy_network=target['policy'],
        target_critic_network=target['critic'],
        target_observation_network=target_observation,
        discount=self._additional_discount,
        num_samples=self._num_samples,
        target_policy_update_period=self._target_policy_update_period,
        target_critic_update_period=self._target_critic_update_period,
        policy_loss_module=loss_module,
        dataset=batches,
        counter=learner_counter,
        logger=learner_logger)