Пример #1
0
    def learner(
        self,
        replay: reverb.Client,
        counter: counting.Counter,
    ):
        """Build the MPO learner that trains from the given replay client.

        Online and target network collections are both produced by
        `self._network_factory`, their variables are created from the
        environment spec, and everything is handed to `learning.MPOLearner`
        together with a batched, prefetched replay dataset.

        Args:
          replay: reverb client; its `server_address` is used to build the
            dataset the learner samples from.
          counter: parent counter; a child counter prefixed with 'learner' is
            created from it.

        Returns:
          The constructed `learning.MPOLearner`.
        """

        env_spec = self._environment_spec
        action_spec = env_spec.actions
        observation_spec = env_spec.observations

        # Two independent network collections: one optimized, one target.
        online = self._network_factory(action_spec)
        target = self._network_factory(action_spec)

        def _ensure_observation_module(nets):
            # Fall back to the identity transform when no observation network
            # is supplied, and always store a Sonnet module back in the mapping.
            module = tf2_utils.to_sonnet_module(
                nets.get('observation', tf.identity))
            nets['observation'] = module
            return module

        online_observation = _ensure_observation_module(online)
        target_observation = _ensure_observation_module(target)

        # Creating the observation network variables also yields the spec of
        # the embedding that the policy and critic consume.
        embedding_spec = tf2_utils.create_variables(online_observation,
                                                    [observation_spec])

        # Remaining networks, in creation order, with their input specs.
        pending = (
            (online['policy'], [embedding_spec]),
            (online['critic'], [embedding_spec, action_spec]),
            (target['observation'], [observation_spec]),
            (target['policy'], [embedding_spec]),
            (target['critic'], [embedding_spec, action_spec]),
        )
        for network, input_specs in pending:
            tf2_utils.create_variables(network, input_specs)

        # Replay dataset: full batches only, then prefetch.
        replay_dataset = (
            datasets.make_reverb_dataset(server_address=replay.server_address)
            .batch(self._batch_size, drop_remainder=True)
            .prefetch(self._prefetch_size))

        # Bookkeeping for learner steps and performance logging.
        learner_counter = counting.Counter(counter, 'learner')
        learner_logger = loggers.make_default_logger(
            'learner',
            time_delta=self._log_every,
            steps_key='learner_steps')

        # Only build a policy loss module when a factory was provided.
        loss_module = (self._policy_loss_factory()
                       if self._policy_loss_factory else None)

        return learning.MPOLearner(
            policy_network=online['policy'],
            critic_network=online['critic'],
            observation_network=online_observation,
            target_policy_network=target['policy'],
            target_critic_network=target['critic'],
            target_observation_network=target_observation,
            discount=self._additional_discount,
            num_samples=self._num_samples,
            target_policy_update_period=self._target_policy_update_period,
            target_critic_update_period=self._target_critic_update_period,
            policy_loss_module=loss_module,
            dataset=replay_dataset,
            counter=learner_counter,
            logger=learner_logger)
Пример #2
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        policy_network: snt.Module,
        critic_network: snt.Module,
        observation_network: types.TensorTransformation = tf.identity,
        discount: float = 0.99,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_policy_update_period: int = 100,
        target_critic_update_period: int = 100,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
        policy_loss_module: snt.Module = None,
        policy_optimizer: snt.Optimizer = None,
        critic_optimizer: snt.Optimizer = None,
        n_step: int = 5,
        num_samples: int = 20,
        clipping: bool = True,
        logger: loggers.Logger = None,
        counter: counting.Counter = None,
        checkpoint: bool = True,
        replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
    ):
        """Initialize the agent.

        Builds a local replay server, an n-step adder writing into it, a
        dataset sampling from it, the target copies of the supplied networks,
        a feed-forward actor and an `MPOLearner`, then passes actor and
        learner to the parent constructor.

        Args:
          environment_spec: description of the actions, observations, etc.
          policy_network: the online (optimized) policy.
          critic_network: the online critic.
          observation_network: optional network to transform the observations
            before they are fed into any network.
          discount: discount to use for TD updates.
          batch_size: batch size for updates.
          prefetch_size: size to prefetch from replay.
          target_policy_update_period: number of updates to perform before
            updating the target policy network.
          target_critic_update_period: number of updates to perform before
            updating the target critic network.
          min_replay_size: minimum replay size before updating.
          max_replay_size: maximum replay size.
          samples_per_insert: number of samples to take from replay for every
            insert that is made.
          policy_loss_module: configured MPO loss function for the policy
            optimization; defaults to sensible values on the control suite.
            See `acme/tf/losses/mpo.py` for more details.
          policy_optimizer: optimizer to be used on the policy; Adam(1e-4) is
            used when omitted.
          critic_optimizer: optimizer to be used on the critic; Adam(1e-4) is
            used when omitted.
          n_step: number of steps to squash into a single transition.
          num_samples: number of actions to sample when doing a Monte Carlo
            integration with respect to the policy.
          clipping: whether to clip gradients by global norm.
          logger: logging object used to write to logs.
          counter: counter object used to keep track of steps.
          checkpoint: boolean indicating whether to checkpoint the learner.
          replay_table_name: name of the replay table. The same name is used
            for the table that is created, the table the adder writes to and
            the table the dataset samples from.
        """

        # Create a replay server to add data to. The table is named after
        # `replay_table_name` because the dataset below samples from that
        # name; hard-coding the default name here would leave the dataset
        # pointing at a table that does not exist whenever a custom name is
        # passed.
        replay_table = reverb.Table(
            name=replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.NStepTransitionAdder.signature(environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay. It is pointed
        # at the same table explicitly (constant priority of 1) so inserts and
        # samples agree on the table for any `replay_table_name`.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(
            priority_fns={replay_table_name: lambda x: 1.},
            client=reverb.Client(address),
            n_step=n_step,
            discount=discount)

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(
            table=replay_table_name,
            client=reverb.TFClient(address),
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            environment_spec=environment_spec,
            transition_adder=True)

        # Make sure observation network is a Sonnet Module.
        observation_network = tf2_utils.to_sonnet_module(observation_network)

        # Create target networks before creating online/target network variables.
        target_policy_network = copy.deepcopy(policy_network)
        target_critic_network = copy.deepcopy(critic_network)
        target_observation_network = copy.deepcopy(observation_network)

        # Get observation and action specs. Creating the observation network
        # variables also returns the spec of its output (the embedding).
        act_spec = environment_spec.actions
        obs_spec = environment_spec.observations
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

        # Create the behavior policy.
        behavior_network = snt.Sequential([
            observation_network,
            policy_network,
            networks.StochasticSamplingHead(),
        ])

        # Create variables.
        tf2_utils.create_variables(policy_network, [emb_spec])
        tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_policy_network, [emb_spec])
        tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_observation_network, [obs_spec])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network=behavior_network,
                                        adder=adder)

        # Create optimizers.
        policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
        critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)

        # The learner updates the parameters (and initializes them).
        learner = learning.MPOLearner(
            policy_network=policy_network,
            critic_network=critic_network,
            observation_network=observation_network,
            target_policy_network=target_policy_network,
            target_critic_network=target_critic_network,
            target_observation_network=target_observation_network,
            policy_loss_module=policy_loss_module,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            clipping=clipping,
            discount=discount,
            num_samples=num_samples,
            target_policy_update_period=target_policy_update_period,
            target_critic_update_period=target_critic_update_period,
            dataset=dataset,
            logger=logger,
            counter=counter,
            checkpoint=checkpoint)

        # Learning starts once at least a full batch (and the configured
        # minimum) has been observed; the observations-per-step ratio is
        # derived from batch_size and samples_per_insert.
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)