Example 1
    def make_dataset_iterator(
        self,
        replay_client: reverb.Client,
    ) -> Iterator[reverb.ReplaySample]:
        """Create a dataset iterator to use for training/updating the system.

        Args:
            replay_client (reverb.Client): Reverb Client which points to the
                replay server.

        Returns:
            Iterator[reverb.ReplaySample]: an iterator over data samples from
                the dataset.
        """

        sequence_length = (
            self._config.sequence_length
            if issubclass(self._executor_fn, executors.RecurrentExecutor)
            else None
        )

        dataset = datasets.make_reverb_dataset(
            table=self._config.replay_table_name,
            server_address=replay_client.server_address,
            batch_size=self._config.batch_size,
            prefetch_size=self._config.prefetch_size,
            sequence_length=sequence_length,
        )
        return iter(dataset)
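
For context, here is a condensed, self-contained sketch of the setup these snippets assume (a Reverb table with an explicit signature and a running server) and of the iterator they return. The table name, sizes and the toy observation spec are placeholders, not taken from the examples above.

import reverb
import tensorflow as tf
from acme import datasets

# Placeholder signature; real agents typically derive it from an adder, e.g.
# adders.NStepTransitionAdder.signature(environment_spec).
signature = {'observation': tf.TensorSpec(shape=(3,), dtype=tf.float32)}

table = reverb.Table(
    name='replay_table',
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    max_size=10_000,
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=signature,
)
server = reverb.Server([table], port=None)

# Same call as in the examples above; batching and prefetching are optional.
dataset = datasets.make_reverb_dataset(
    table='replay_table',
    server_address=f'localhost:{server.port}',
    batch_size=32,
    prefetch_size=4,
)
iterator = iter(dataset)
# Each next(iterator) yields a reverb.ReplaySample whose .info field carries
# keys/priorities and whose .data field carries the batched items; it blocks
# until the table contains data.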
Example 2
  def learner(self, queue: reverb.Client, counter: counting.Counter):
    """The Learning part of the agent."""
    # Create the networks.
    network = self._network_factory(self._environment_spec.actions)
    tf2_utils.create_variables(network, [self._environment_spec.observations])

    # The dataset object to learn from.
    dataset = datasets.make_reverb_dataset(
        server_address=queue.server_address,
        batch_size=self._batch_size,
        prefetch_size=self._prefetch_size)

    logger = loggers.make_default_logger('learner', steps_key='learner_steps')
    counter = counting.Counter(counter, 'learner')

    # Return the learning agent.
    learner = learning.IMPALALearner(
        environment_spec=self._environment_spec,
        network=network,
        dataset=dataset,
        discount=self._discount,
        learning_rate=self._learning_rate,
        entropy_cost=self._entropy_cost,
        baseline_cost=self._baseline_cost,
        max_abs_reward=self._max_abs_reward,
        max_gradient_norm=self._max_gradient_norm,
        counter=counter,
        logger=logger,
    )

    return tf2_savers.CheckpointingRunner(learner,
                                          time_delta_minutes=5,
                                          subdirectory='impala_learner')
Example 3
    def make_dataset_iterator(
        self,
        reverb_client: reverb.Client,
    ) -> Iterator[reverb.ReplaySample]:
        """Create a dataset iterator to use for learning/updating the agent."""
        # The dataset provides an interface to sample from replay.
        dataset = datasets.make_reverb_dataset(
            table=self._config.replay_table_name,
            server_address=reverb_client.server_address,
            batch_size=self._config.batch_size,
            prefetch_size=self._config.prefetch_size)

        # TODO(b/155086959): Fix type stubs and remove.
        return iter(dataset)  # pytype: disable=wrong-arg-types
Example 4
    def learner(
        self,
        replay: reverb.Client,
        counter: counting.Counter,
    ):
        """The Learning part of the agent."""

        # Create online and target networks.
        online_networks = self._network_factory(self._environment_spec)
        target_networks = self._network_factory(self._environment_spec)

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(
            server_address=replay.server_address,
            batch_size=self._batch_size,
            prefetch_size=self._prefetch_size,
        )

        counter = counting.Counter(counter, 'learner')
        logger = loggers.make_default_logger('learner',
                                             time_delta=self._log_every)

        # Create policy loss module if a factory is passed.
        if self._policy_loss_factory:
            policy_loss_module = self._policy_loss_factory()
        else:
            policy_loss_module = None

        # Return the learning agent.
        return learning.MoGMPOLearner(
            policy_network=online_networks['policy'],
            critic_network=online_networks['critic'],
            observation_network=online_networks['observation'],
            target_policy_network=target_networks['policy'],
            target_critic_network=target_networks['critic'],
            target_observation_network=target_networks['observation'],
            discount=self._additional_discount,
            num_samples=self._num_samples,
            policy_evaluation_config=self._policy_evaluation_config,
            target_policy_update_period=self._target_policy_update_period,
            target_critic_update_period=self._target_critic_update_period,
            policy_loss_module=policy_loss_module,
            dataset=dataset,
            counter=counter,
            logger=logger)
Example 5
    def reset_replay_table(self, name='new_replay_table'):
        """Replaces the agent's replay table and rebuilds the learner's iterator."""
        replay_table = reverb.Table(
            name=name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=1000000,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.NStepTransitionAdder.signature(self._environment_spec))
        # Restart the replay server on the same port with the new (empty) table.
        port = self._agent._server.port
        del self._agent._server
        self._agent._server = reverb.Server([replay_table], port=port)
        # Point the learner at a dataset backed by the new table.
        dataset = datasets.make_reverb_dataset(
            table=name,
            server_address=f'localhost:{port}',
            batch_size=256,
            prefetch_size=4,
        )
        self._agent._learner._iterator = iter(dataset)
Example 6
  def learner(self, replay: reverb.Client, counter: counting.Counter):
    """The Learning part of the agent."""
    # Create the networks.
    network = self._network_factory(self._environment_spec.actions)
    target_network = copy.deepcopy(network)

    tf2_utils.create_variables(network, [self._obs_spec])
    tf2_utils.create_variables(target_network, [self._obs_spec])

    # The dataset object to learn from.
    reverb_client = reverb.TFClient(replay.server_address)
    sequence_length = self._burn_in_length + self._trace_length + 1
    dataset = datasets.make_reverb_dataset(
        server_address=replay.server_address,
        batch_size=self._batch_size,
        prefetch_size=self._prefetch_size)

    counter = counting.Counter(counter, 'learner')
    logger = loggers.make_default_logger(
        'learner', save_data=True, steps_key='learner_steps')
    # Return the learning agent.
    learner = learning.R2D2Learner(
        environment_spec=self._environment_spec,
        network=network,
        target_network=target_network,
        burn_in_length=self._burn_in_length,
        sequence_length=sequence_length,
        dataset=dataset,
        reverb_client=reverb_client,
        counter=counter,
        logger=logger,
        discount=self._discount,
        target_update_period=self._target_update_period,
        importance_sampling_exponent=self._importance_sampling_exponent,
        learning_rate=self._learning_rate,
        max_replay_size=self._max_replay_size)
    return tf2_savers.CheckpointingRunner(
        wrapped=learner, time_delta_minutes=60, subdirectory='r2d2_learner')
Example 7
    def learner(self, replay: reverb.Client, counter: counting.Counter):
        """The Learning part of the agent."""

        # Create the networks.
        network = self._network_factory(self._env_spec.actions)
        target_network = copy.deepcopy(network)

        tf2_utils.create_variables(network, [self._env_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [self._env_spec.observations])

        # The dataset object to learn from.
        replay_client = reverb.Client(replay.server_address)
        dataset = datasets.make_reverb_dataset(
            server_address=replay.server_address,
            batch_size=self._batch_size,
            prefetch_size=self._prefetch_size)

        logger = loggers.make_default_logger('learner',
                                             steps_key='learner_steps')

        # Return the learning agent.
        counter = counting.Counter(counter, 'learner')

        learner = learning.DQNLearner(
            network=network,
            target_network=target_network,
            discount=self._discount,
            importance_sampling_exponent=self._importance_sampling_exponent,
            learning_rate=self._learning_rate,
            target_update_period=self._target_update_period,
            dataset=dataset,
            replay_client=replay_client,
            counter=counter,
            logger=logger)
        return tf2_savers.CheckpointingRunner(learner,
                                              subdirectory='dqn_learner',
                                              time_delta_minutes=60)
Example 8
    def learner(self, replay: reverb.Client, counter: counting.Counter):
        """The learning part of the agent."""
        # Create the networks.
        network = self._network_factory(self._env_spec.actions)

        tf2_utils.create_variables(network, [self._env_spec.observations])

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(
            server_address=replay.server_address,
            batch_size=self._batch_size,
            prefetch_size=self._prefetch_size)

        # Create the optimizer.
        optimizer = snt.optimizers.Adam(self._learning_rate)

        # Return the learning agent.
        return learning.AZLearner(
            network=network,
            discount=self._discount,
            dataset=dataset,
            optimizer=optimizer,
            counter=counter,
        )
Example 9
    def learner(
        self,
        replay: reverb.Client,
        counter: counting.Counter,
    ):
        """The Learning part of the agent."""

        act_spec = self._environment_spec.actions
        obs_spec = self._environment_spec.observations

        # Create online and target networks.
        online_networks = self._network_factory(act_spec)
        target_networks = self._network_factory(act_spec)

        # Make sure observation networks are Sonnet Modules.
        observation_network = online_networks.get('observation', tf.identity)
        observation_network = tf2_utils.to_sonnet_module(observation_network)
        online_networks['observation'] = observation_network
        target_observation_network = target_networks.get(
            'observation', tf.identity)
        target_observation_network = tf2_utils.to_sonnet_module(
            target_observation_network)
        target_networks['observation'] = target_observation_network

        # Get embedding spec and create observation network variables.
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

        tf2_utils.create_variables(online_networks['policy'], [emb_spec])
        tf2_utils.create_variables(online_networks['critic'],
                                   [emb_spec, act_spec])
        tf2_utils.create_variables(target_networks['observation'], [obs_spec])
        tf2_utils.create_variables(target_networks['policy'], [emb_spec])
        tf2_utils.create_variables(target_networks['critic'],
                                   [emb_spec, act_spec])

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(
            server_address=replay.server_address)
        dataset = dataset.batch(self._batch_size, drop_remainder=True)
        dataset = dataset.prefetch(self._prefetch_size)

        # Create a counter and logger for bookkeeping steps and performance.
        counter = counting.Counter(counter, 'learner')
        logger = loggers.make_default_logger('learner',
                                             time_delta=self._log_every,
                                             steps_key='learner_steps')

        # Create policy loss module if a factory is passed.
        if self._policy_loss_factory:
            policy_loss_module = self._policy_loss_factory()
        else:
            policy_loss_module = None

        # Return the learning agent.
        return learning.MPOLearner(
            policy_network=online_networks['policy'],
            critic_network=online_networks['critic'],
            observation_network=observation_network,
            target_policy_network=target_networks['policy'],
            target_critic_network=target_networks['critic'],
            target_observation_network=target_observation_network,
            discount=self._additional_discount,
            num_samples=self._num_samples,
            target_policy_update_period=self._target_policy_update_period,
            target_critic_update_period=self._target_critic_update_period,
            policy_loss_module=policy_loss_module,
            dataset=dataset,
            counter=counter,
            logger=logger)
Example 10
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.Module,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        n_step: int = 5,
        epsilon: tf.Tensor = None,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
        cql_alpha: float = 1.,
        logger: loggers.Logger = None,
        counter: counting.Counter = None,
        checkpoint_subpath: str = '~/acme/',
    ):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      network: the online Q network (the one being optimized)
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      importance_sampling_exponent: power to which importance weights are raised
        before normalizing.
      priority_exponent: exponent used in prioritized sampling.
      n_step: number of steps to squash into a single transition.
      epsilon: probability of taking a random action; ignored if a policy
        network is given.
      learning_rate: learning rate for the q-network update.
      discount: discount to use for TD updates.
      cql_alpha: weight on the CQL regularizer added to the TD loss.
      logger: logger object to be used by learner.
      counter: counter object used to keep track of steps.
      checkpoint_subpath: directory for the checkpoint.
    """

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset provides an interface to sample from replay.
        replay_client = reverb.TFClient(address)
        dataset = datasets.make_reverb_dataset(
            client=replay_client,
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            transition_adder=True)

        # Use constant 0.05 epsilon greedy policy by default.
        if epsilon is None:
            epsilon = tf.Variable(0.05, trainable=False)
        policy_network = snt.Sequential([
            network,
            lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
        ])

        # Create a target network.
        target_network = copy.deepcopy(network)

        # Ensure that we create the variables before proceeding (maybe not needed).
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network, adder)

        # The learner updates the parameters (and initializes them).
        learner = CQLLearner(
            network=network,
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            learning_rate=learning_rate,
            cql_alpha=cql_alpha,
            target_update_period=target_update_period,
            dataset=dataset,
            replay_client=replay_client,
            logger=logger,
            counter=counter,
            checkpoint_subpath=checkpoint_subpath)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
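
Since the constructor above ends with super().__init__(actor=..., learner=...), the class presumably follows Acme's single-process agent interface, so it can be driven with Acme's EnvironmentLoop. A hypothetical driver sketch: the CQLAgent class name, the environment factory and the network factory are assumptions and are not shown in the snippet.

import acme
from acme import specs

environment = make_environment()          # hypothetical environment factory
environment_spec = specs.make_environment_spec(environment)
network = make_q_network(environment_spec.actions.num_values)  # hypothetical

agent = CQLAgent(environment_spec=environment_spec, network=network)

# Acme's standard actor/environment interaction loop; the agent's internal
# learner steps are triggered by the observations_per_step logic above.
loop = acme.EnvironmentLoop(environment, agent)
loop.run(num_episodes=100)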
Example 11
    def __init__(self,
                 environment_spec: specs.EnvironmentSpec,
                 policy_network: snt.Module,
                 critic_network: snt.Module,
                 encoder_network: types.TensorTransformation = tf.identity,
                 entropy_coeff: float = 0.01,
                 target_update_period: int = 0,
                 discount: float = 0.99,
                 batch_size: int = 256,
                 policy_learn_rate: float = 3e-4,
                 critic_learn_rate: float = 5e-4,
                 prefetch_size: int = 4,
                 min_replay_size: int = 1000,
                 max_replay_size: int = 250000,
                 samples_per_insert: float = 64.0,
                 n_step: int = 5,
                 sigma: float = 0.5,
                 clipping: bool = True,
                 logger: loggers.Logger = None,
                 counter: counting.Counter = None,
                 checkpoint: bool = True,
                 replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      encoder_network: optional network to transform the observations before
        they are fed into any network.
      entropy_coeff: weight of the policy entropy regularization term.
      policy_learn_rate: learning rate for the policy optimizer.
      critic_learn_rate: learning rate for the critic optimizer.
      discount: discount to use for TD updates.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks; if 0 (the default), no separate target networks
        are used.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      n_step: number of steps to squash into a single transition.
      sigma: standard deviation of zero-mean, Gaussian exploration noise.
      clipping: whether to clip gradients by global norm.
      logger: logger object to be used by learner.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner.
      replay_table_name: string indicating what name to give the replay table.
    """
        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.

        dim_actions = np.prod(environment_spec.actions.shape, dtype=int)
        extra_spec = {
            'logP': tf.ones(shape=(1), dtype=tf.float32),
            'policy': tf.ones(shape=(1, dim_actions), dtype=tf.float32)
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)

        replay_table = reverb.Table(
            name=replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(
                environment_spec, extras_spec=extra_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(
            priority_fns={replay_table_name: lambda x: 1.},
            client=reverb.Client(address),
            n_step=n_step,
            discount=discount)

        # The dataset provides an interface to sample from replay.
        dataset = datasets.make_reverb_dataset(table=replay_table_name,
                                               server_address=address,
                                               batch_size=batch_size,
                                               prefetch_size=prefetch_size)

        # Make sure observation network is a Sonnet Module.
        observation_network = model.MDPNormalization(environment_spec,
                                                     encoder_network)

        # Get observation and action specs.
        act_spec = environment_spec.actions
        obs_spec = environment_spec.observations

        # Create the behavior policy.
        sampling_head = model.SquashedGaussianSamplingHead(act_spec, sigma)
        self._behavior_network = model.PolicyValueBehaviorNet(
            snt.Sequential([observation_network, policy_network]),
            sampling_head)

        # Create variables.
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])
        tf2_utils.create_variables(policy_network, [emb_spec])
        tf2_utils.create_variables(critic_network, [emb_spec, act_spec])

        # Create the actor which defines how we take actions.
        actor = model.SACFeedForwardActor(self._behavior_network, adder)

        if target_update_period > 0:
            target_policy_network = copy.deepcopy(policy_network)
            target_critic_network = copy.deepcopy(critic_network)
            target_observation_network = copy.deepcopy(observation_network)

            tf2_utils.create_variables(target_policy_network, [emb_spec])
            tf2_utils.create_variables(target_critic_network,
                                       [emb_spec, act_spec])
            tf2_utils.create_variables(target_observation_network, [obs_spec])
        else:
            target_policy_network = policy_network
            target_critic_network = critic_network
            target_observation_network = observation_network

        # Create optimizers.
        policy_optimizer = snt.optimizers.Adam(learning_rate=policy_learn_rate)
        critic_optimizer = snt.optimizers.Adam(learning_rate=critic_learn_rate)

        # The learner updates the parameters (and initializes them).
        learner = learning.SACLearner(
            policy_network=policy_network,
            critic_network=critic_network,
            sampling_head=sampling_head,
            observation_network=observation_network,
            target_policy_network=target_policy_network,
            target_critic_network=target_critic_network,
            target_observation_network=target_observation_network,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            target_update_period=target_update_period,
            learning_rate=policy_learn_rate,
            clipping=clipping,
            entropy_coeff=entropy_coeff,
            discount=discount,
            dataset=dataset,
            counter=counter,
            logger=logger,
            checkpoint=checkpoint,
        )

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
Example 12
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], networks.RNNState],
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        seed: int = 0,
        max_abs_reward: float = np.inf,
        max_gradient_norm: float = np.inf,
    ):

        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')
        queue = reverb.Table.queue(name=adders.DEFAULT_PRIORITY_TABLE,
                                   max_size=max_queue_size)
        self._server = reverb.Server([queue], port=None)
        self._can_sample = lambda: queue.can_sample(batch_size)
        address = f'localhost:{self._server.port}'

        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=sequence_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        extra_spec = {
            'core_state': hk.transform(initial_state_fn).apply(None),
            'logits': np.ones(shape=(num_actions, ), dtype=np.float32)
        }
        dataset = datasets.make_reverb_dataset(
            client=reverb.TFClient(address),
            environment_spec=environment_spec,
            batch_size=batch_size,
            extra_spec=extra_spec,
            sequence_length=sequence_length)

        rng = hk.PRNGSequence(seed)

        optimizer = optix.chain(
            optix.clip_by_global_norm(max_gradient_norm),
            optix.adam(learning_rate),
        )
        self._learner = learning.IMPALALearner(
            obs_spec=environment_spec.observations,
            network=network,
            initial_state_fn=initial_state_fn,
            iterator=dataset.as_numpy_iterator(),
            rng=rng,
            counter=counter,
            logger=logger,
            optimizer=optimizer,
            discount=discount,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_abs_reward=max_abs_reward,
        )

        variable_client = jax_variable_utils.VariableClient(self._learner,
                                                            key='policy')
        self._actor = acting.IMPALAActor(
            network=network,
            initial_state_fn=initial_state_fn,
            rng=rng,
            adder=adder,
            variable_client=variable_client,
        )
Example 13
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        policy_network: snt.Module,
        critic_network: snt.Module,
        observation_network: types.TensorTransformation = tf.identity,
        discount: float = 0.99,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_policy_update_period: int = 100,
        target_critic_update_period: int = 100,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
        policy_loss_module: snt.Module = None,
        policy_optimizer: snt.Optimizer = None,
        critic_optimizer: snt.Optimizer = None,
        n_step: int = 5,
        num_samples: int = 20,
        clipping: bool = True,
        logger: loggers.Logger = None,
        counter: counting.Counter = None,
        checkpoint: bool = True,
        replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
    ):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      observation_network: optional network to transform the observations before
        they are fed into any network.
      discount: discount to use for TD updates.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_policy_update_period: number of updates to perform before updating
        the target policy network.
      target_critic_update_period: number of updates to perform before updating
        the target critic network.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      policy_loss_module: configured MPO loss function for the policy
        optimization; defaults to sensible values on the control suite.
        See `acme/tf/losses/mpo.py` for more details.
      policy_optimizer: optimizer to be used on the policy.
      critic_optimizer: optimizer to be used on the critic.
      n_step: number of steps to squash into a single transition.
      num_samples: number of actions to sample when doing a Monte Carlo
        integration with respect to the policy.
      clipping: whether to clip gradients by global norm.
      logger: logging object used to write to logs.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner.
      replay_table_name: string indicating what name to give the replay table.
    """

        # Create a replay server to add data to.
        replay_table = reverb.Table(
            name=replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.NStepTransitionAdder.signature(environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(
            table=replay_table_name,
            client=reverb.TFClient(address),
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            environment_spec=environment_spec,
            transition_adder=True)

        # Make sure observation network is a Sonnet Module.
        observation_network = tf2_utils.to_sonnet_module(observation_network)

        # Create target networks before creating online/target network variables.
        target_policy_network = copy.deepcopy(policy_network)
        target_critic_network = copy.deepcopy(critic_network)
        target_observation_network = copy.deepcopy(observation_network)

        # Get observation and action specs.
        act_spec = environment_spec.actions
        obs_spec = environment_spec.observations
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

        # Create the behavior policy.
        behavior_network = snt.Sequential([
            observation_network,
            policy_network,
            networks.StochasticSamplingHead(),
        ])

        # Create variables.
        tf2_utils.create_variables(policy_network, [emb_spec])
        tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_policy_network, [emb_spec])
        tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_observation_network, [obs_spec])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network=behavior_network,
                                        adder=adder)

        # Create optimizers.
        policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
        critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)

        # The learner updates the parameters (and initializes them).
        learner = learning.MPOLearner(
            policy_network=policy_network,
            critic_network=critic_network,
            observation_network=observation_network,
            target_policy_network=target_policy_network,
            target_critic_network=target_critic_network,
            target_observation_network=target_observation_network,
            policy_loss_module=policy_loss_module,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            clipping=clipping,
            discount=discount,
            num_samples=num_samples,
            target_policy_update_period=target_policy_update_period,
            target_critic_update_period=target_critic_update_period,
            dataset=dataset,
            logger=logger,
            counter=counter,
            checkpoint=checkpoint)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
Example 14
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.Module,
        demonstration_dataset: tf.data.Dataset,
        demonstration_ratio: float,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        n_step: int = 5,
        epsilon: tf.Tensor = None,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
    ):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      network: the online Q network (the one being optimized)
      demonstration_dataset: tf.data.Dataset producing (timestep, action)
        tuples containing full episodes.
      demonstration_ratio: Ratio of transitions coming from demonstrations.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      importance_sampling_exponent: power to which importance weights are raised
        before normalizing.
      n_step: number of steps to squash into a single transition.
      epsilon: probability of taking a random action; ignored if a policy
        network is given.
      learning_rate: learning rate for the q-network update.
      discount: discount to use for TD updates.
    """

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset provides an interface to sample from replay.
        replay_client = reverb.TFClient(address)
        dataset = datasets.make_reverb_dataset(
            client=replay_client,
            environment_spec=environment_spec,
            transition_adder=True)

        # Combine with demonstration dataset.
        transition = functools.partial(_n_step_transition_from_episode,
                                       n_step=n_step,
                                       discount=discount)
        dataset_demos = demonstration_dataset.map(transition)
        dataset = tf.data.experimental.sample_from_datasets(
            [dataset, dataset_demos],
            [1 - demonstration_ratio, demonstration_ratio])

        # Batch and prefetch.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(prefetch_size)

        # Use constant 0.05 epsilon greedy policy by default.
        if epsilon is None:
            epsilon = tf.Variable(0.05, trainable=False)
        policy_network = snt.Sequential([
            network,
            lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
        ])

        # Create a target network.
        target_network = copy.deepcopy(network)

        # Ensure that we create the variables before proceeding (maybe not needed).
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network, adder)

        # The learner updates the parameters (and initializes them).
        learner = dqn.DQNLearner(
            network=network,
            target_network=target_network,
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            learning_rate=learning_rate,
            target_update_period=target_update_period,
            dataset=dataset,
            replay_client=replay_client)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
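
The demonstration mixing above relies on tf.data's weighted dataset sampling. Below is a tiny standalone illustration of that one call, with toy datasets standing in for replay and demonstration transitions; the 0/1 element values and the ratio are arbitrary.

import tensorflow as tf

replay = tf.data.Dataset.from_tensor_slices(tf.zeros([1000]))  # stands in for replay samples
demos = tf.data.Dataset.from_tensor_slices(tf.ones([1000]))    # stands in for demonstrations
demonstration_ratio = 0.25

mixed = tf.data.experimental.sample_from_datasets(
    [replay, demos], weights=[1.0 - demonstration_ratio, demonstration_ratio])

# In the long run roughly demonstration_ratio of the elements come from demos.
fraction = sum(int(x) for x in mixed.take(1000).as_numpy_iterator()) / 1000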
Example 15
    def __init__(self,
                 environment_spec: specs.EnvironmentSpec,
                 network: snt.RNNCore,
                 target_network: snt.RNNCore,
                 burn_in_length: int,
                 trace_length: int,
                 replay_period: int,
                 demonstration_dataset: tf.data.Dataset,
                 demonstration_ratio: float,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None,
                 discount: float = 0.99,
                 batch_size: int = 32,
                 target_update_period: int = 100,
                 importance_sampling_exponent: float = 0.2,
                 epsilon: float = 0.01,
                 learning_rate: float = 1e-3,
                 log_to_bigtable: bool = False,
                 log_name: str = 'agent',
                 checkpoint: bool = True,
                 min_replay_size: int = 1000,
                 max_replay_size: int = 1000000,
                 samples_per_insert: float = 32.0):

        extra_spec = {
            'core_state': network.initial_state(1),
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))
        self._server = reverb.Server([replay_table], port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1
        # Component to add things into replay.
        sequence_kwargs = dict(
            period=replay_period,
            sequence_length=sequence_length,
        )
        adder = adders.SequenceAdder(client=reverb.Client(address),
                                     **sequence_kwargs)

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(server_address=address,
                                               sequence_length=sequence_length)

        # Combine with demonstration dataset.
        transition = functools.partial(_sequence_from_episode,
                                       extra_spec=extra_spec,
                                       **sequence_kwargs)
        dataset_demos = demonstration_dataset.map(transition)
        dataset = tf.data.experimental.sample_from_datasets(
            [dataset, dataset_demos],
            [1 - demonstration_ratio, demonstration_ratio])

        # Batch and prefetch.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            dataset=dataset,
            reverb_client=reverb.TFClient(address),
            counter=counter,
            logger=logger,
            sequence_length=sequence_length,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=False,
        )

        self._checkpointer = tf2_savers.Checkpointer(
            subdirectory='r2d2_learner',
            time_delta_minutes=60,
            objects_to_save=learner.state,
            enable_checkpointing=checkpoint,
        )

        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save={'network': network}, time_delta_minutes=60.)

        policy_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
        ])

        actor = actors.RecurrentActor(policy_network, adder)
        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
Example 16
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        max_abs_reward: Optional[float] = None,
        max_gradient_norm: Optional[float] = None,
    ):

        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')

        extra_spec = {
            'core_state': network.initial_state(1),
            'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32)
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)

        queue = reverb.Table.queue(name=adders.DEFAULT_PRIORITY_TABLE,
                                   max_size=max_queue_size,
                                   signature=adders.SequenceAdder.signature(
                                       environment_spec,
                                       extras_spec=extra_spec,
                                       sequence_length=sequence_length))
        self._server = reverb.Server([queue], port=None)
        self._can_sample = lambda: queue.can_sample(batch_size)
        address = f'localhost:{self._server.port}'

        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=sequence_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(server_address=address,
                                               batch_size=batch_size)

        tf2_utils.create_variables(network, [environment_spec.observations])

        self._actor = acting.IMPALAActor(network, adder)
        self._learner = learning.IMPALALearner(
            environment_spec=environment_spec,
            network=network,
            dataset=dataset,
            counter=counter,
            logger=logger,
            discount=discount,
            learning_rate=learning_rate,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_gradient_norm=max_gradient_norm,
            max_abs_reward=max_abs_reward,
        )
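
    # (Not part of the original snippet.) A hedged sketch of the update method
    # such an agent typically pairs with the _can_sample check defined above:
    def update(self):
        # Step the learner as long as the queue holds at least one full batch.
        while self._can_sample():
            self._learner.step()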
Example 17
    def learner(
        self,
        replay: reverb.Client,
        counter: counting.Counter,
    ):
        """The Learning part of the agent."""

        act_spec = self._environment_spec.actions
        obs_spec = self._environment_spec.observations

        # Create the networks to optimize (online) and target networks.
        online_networks = self._network_factory(act_spec)
        target_networks = self._network_factory(act_spec)

        # Make sure observation network is a Sonnet Module.
        observation_network = online_networks.get('observation', tf.identity)
        target_observation_network = target_networks.get(
            'observation', tf.identity)
        observation_network = tf2_utils.to_sonnet_module(observation_network)
        target_observation_network = tf2_utils.to_sonnet_module(
            target_observation_network)

        # Get embedding spec and create observation network variables.
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

        # Create variables.
        tf2_utils.create_variables(online_networks['policy'], [emb_spec])
        tf2_utils.create_variables(online_networks['critic'],
                                   [emb_spec, act_spec])
        tf2_utils.create_variables(target_networks['policy'], [emb_spec])
        tf2_utils.create_variables(target_networks['critic'],
                                   [emb_spec, act_spec])
        tf2_utils.create_variables(target_observation_network, [obs_spec])

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(
            server_address=replay.server_address,
            batch_size=self._batch_size,
            prefetch_size=self._prefetch_size)

        # Create optimizers.
        policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4)
        critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4)

        counter = counting.Counter(counter, 'learner')
        logger = loggers.make_default_logger('learner',
                                             time_delta=self._log_every,
                                             steps_key='learner_steps')

        # Return the learning agent.
        return learning.DDPGLearner(
            policy_network=online_networks['policy'],
            critic_network=online_networks['critic'],
            observation_network=observation_network,
            target_policy_network=target_networks['policy'],
            target_critic_network=target_networks['critic'],
            target_observation_network=target_observation_network,
            discount=self._discount,
            target_update_period=self._target_update_period,
            dataset=dataset,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            clipping=self._clipping,
            counter=counter,
            logger=logger,
        )
Example 18
    """Creates a single-process replay infrastructure from an environment spec."""
    # Create a replay server to add data to. This uses no limiter behavior in
    # order to allow the Agent interface to handle it.
    replay_table = reverb.Table(
        name=replay_table_name,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_replay_size),
        signature=adders.NStepTransitionAdder.signature(
            environment_spec=environment_spec))
    server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{server.port}'
    client = reverb.Client(address)
    adder = adders.NStepTransitionAdder(client=client,
                                        n_step=n_step,
                                        discount=discount)

    # The dataset provides an interface to sample from replay.
    data_iterator = datasets.make_reverb_dataset(
        table=replay_table_name,
        server_address=address,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        environment_spec=environment_spec,
        transition_adder=True,
    ).as_numpy_iterator()
    return ReverbReplay(server, adder, data_iterator, client)
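
A sketch of how the returned ReverbReplay bundle might be driven, assuming a dm_env-style environment and the field names visible in the return statement above (server, adder, data_iterator, client); the random-action policy is a placeholder and is not part of the snippet.

def collect_and_sample(environment, replay, num_steps):
    """Writes num_steps of experience via the adder, then samples one batch."""
    timestep = environment.reset()
    replay.adder.add_first(timestep)
    for _ in range(num_steps):
        action = environment.action_spec().generate_value()  # placeholder policy
        timestep = environment.step(action)
        replay.adder.add(action, next_timestep=timestep)
        if timestep.last():
            timestep = environment.reset()
            replay.adder.add_first(timestep)
    # Blocks until the rate limiter (MinSize(min_replay_size)) allows sampling.
    return next(replay.data_iterator)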
Example 19
            environment_spec, extra_spec),
    )
    server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{server.port}'
    client = reverb.Client(address)
    adder = adders.NStepTransitionAdder(client,
                                        n_step,
                                        discount,
                                        priority_fns=priority_fns)

    # The dataset provides an interface to sample from replay.
    data_iterator = datasets.make_reverb_dataset(
        table=replay_table_name,
        server_address=address,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
    ).as_numpy_iterator()
    return ReverbReplay(server, adder, data_iterator, client=client)


def make_reverb_online_queue(
    environment_spec: specs.EnvironmentSpec,
    extra_spec: Dict[str, Any],
    max_queue_size: int,
    sequence_length: int,
    sequence_period: int,
    batch_size: int,
    replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
) -> ReverbReplay:
    """Creates a single process queue from an environment spec and extra_spec."""
Example 20
    def __init__(self,
                 environment_spec: specs.EnvironmentSpec,
                 policy_network: snt.Module,
                 critic_network: snt.Module,
                 observation_network: types.TensorTransformation = tf.identity,
                 discount: float = 0.99,
                 batch_size: int = 256,
                 prefetch_size: int = 4,
                 target_update_period: int = 100,
                 min_replay_size: int = 1000,
                 max_replay_size: int = 1000000,
                 samples_per_insert: float = 32.0,
                 n_step: int = 5,
                 sigma: float = 0.3,
                 clipping: bool = True,
                 logger: loggers.Logger = None,
                 counter: counting.Counter = None,
                 checkpoint: bool = True,
                 replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      observation_network: optional network to transform the observations before
        they are fed into any network.
      discount: discount to use for TD updates.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      n_step: number of steps to squash into a single transition.
      sigma: standard deviation of zero-mean, Gaussian exploration noise.
      clipping: whether to clip gradients by global norm.
      logger: logger object to be used by learner.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner.
      replay_table_name: string indicating what name to give the replay table.
    """
        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(
            priority_fns={replay_table_name: lambda x: 1.},
            client=reverb.Client(address),
            n_step=n_step,
            discount=discount)

        # The dataset provides an interface to sample from replay.
        dataset = datasets.make_reverb_dataset(
            table=replay_table_name,
            client=reverb.TFClient(address),
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            transition_adder=True)

        # Get observation and action specs.
        act_spec = environment_spec.actions
        obs_spec = environment_spec.observations
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])  # pytype: disable=wrong-arg-types

        # Make sure observation network is a Sonnet Module.
        observation_network = tf2_utils.to_sonnet_module(observation_network)

        # Create target networks.
        target_policy_network = copy.deepcopy(policy_network)
        target_critic_network = copy.deepcopy(critic_network)
        target_observation_network = copy.deepcopy(observation_network)

        # Create the behavior policy.
        behavior_network = snt.Sequential([
            observation_network,
            policy_network,
            networks.ClippedGaussian(sigma),
            networks.ClipToSpec(act_spec),
        ])

        # Create variables.
        tf2_utils.create_variables(policy_network, [emb_spec])
        tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_policy_network, [emb_spec])
        tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_observation_network, [obs_spec])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(behavior_network, adder=adder)

        # Create optimizers.
        policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4)
        critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4)

        # The learner updates the parameters (and initializes them).
        learner = learning.DDPGLearner(
            policy_network=policy_network,
            critic_network=critic_network,
            observation_network=observation_network,
            target_policy_network=target_policy_network,
            target_critic_network=target_critic_network,
            target_observation_network=target_observation_network,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            clipping=clipping,
            discount=discount,
            target_update_period=target_update_period,
            dataset=dataset,
            counter=counter,
            logger=logger,
            checkpoint=checkpoint,
        )

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
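
# A minimal usage sketch, assuming the __init__ above belongs to a DDPG-style
# agent class (called `DDPG` here) and that `environment` (a dm_env.Environment),
# `policy_network` and `critic_network` are defined elsewhere; all other
# arguments fall back to the defaults documented above.
import acme
from acme import specs

environment_spec = specs.make_environment_spec(environment)
agent = DDPG(
    environment_spec=environment_spec,
    policy_network=policy_network,
    critic_network=critic_network,
)

# Drive the actor/learner with the standard environment loop; the agent adds
# transitions via its adder and trains from the replay dataset built in __init__.
loop = acme.EnvironmentLoop(environment, agent)
loop.run(num_episodes=100)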
Esempio n. 21
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.Module,
        params=None,
        logger: loggers.Logger = None,
        checkpoint: bool = True,
        paths: Save_paths = None,
    ):
        """Initialize the agent.

        Args:
          environment_spec: description of the actions, observations, etc.
          network: the online Q network (the one being optimized).
          params: optional dict of hyperparameters; if None, the defaults
            defined below are used. Recognized keys:
              batch_size: batch size for updates.
              prefetch_size: size to prefetch from replay.
              target_update_period: number of learner steps to perform before
                updating the target network.
              samples_per_insert: number of samples to take from replay for
                every insert that is made.
              min_replay_size: minimum replay size before updating.
              max_replay_size: maximum replay size.
              importance_sampling_exponent: power to which importance weights
                are raised before normalizing.
              priority_exponent: exponent used in prioritized sampling.
              n_step: number of steps to squash into a single transition.
              epsilon: probability of taking a random action.
              learning_rate: learning rate for the q-network update.
              discount: discount to use for TD updates.
          logger: logger object to be used by learner.
          checkpoint: boolean indicating whether to checkpoint the learner.
          paths: data_dir and experiment_name used by the learner checkpointer;
            must be provided when checkpoint is True.
        """

        if params is None:
            params = {
                'batch_size': 256,
                'prefetch_size': 4,
                'target_update_period': 100,
                'samples_per_insert': 32.0,
                'min_replay_size': 1000,
                'max_replay_size': 1000000,
                'importance_sampling_exponent': 0.2,
                'priority_exponent': 0.6,
                'n_step': 5,
                'epsilon': 0.05,
                'learning_rate': 1e-3,
                'discount': 0.99,
            }

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(params['priority_exponent']),
            remover=reverb.selectors.Fifo(),
            max_size=params['max_replay_size'],
            rate_limiter=reverb.rate_limiters.MinSize(1))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=params['n_step'],
                                            discount=params['discount'])

        # The dataset provides an interface to sample from replay.
        replay_client = reverb.TFClient(address)
        dataset = datasets.make_reverb_dataset(
            client=replay_client,
            environment_spec=environment_spec,
            batch_size=params['batch_size'],
            prefetch_size=params['prefetch_size'],
            transition_adder=True)

        # Use constant 0.05 epsilon greedy policy by default.
        epsilon = tf.Variable(params['epsilon'], trainable=False)

        policy_network = snt.Sequential([
            network,
            lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
        ])

        # Create a target network.
        target_network = copy.deepcopy(network)

        # Ensure that we create the variables before proceeding (maybe not needed).
        # tf2_utils.create_variables(network, [environment_spec.observations])
        # tf2_utils.create_variables(target_network, [environment_spec.observations])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network, adder)

        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            target_network=target_network,
            discount=params['discount'],
            importance_sampling_exponent=params[
                'importance_sampling_exponent'],
            learning_rate=params['learning_rate'],
            target_update_period=params['target_update_period'],
            dataset=dataset,
            replay_client=replay_client,
            logger=logger,
            checkpoint=checkpoint)

        if checkpoint:
            self._checkpointer = tf2_savers.Checkpointer(
                add_uid=False,
                objects_to_save=learner.state,
                directory=paths.data_dir,
                subdirectory=paths.experiment_name,
                time_delta_minutes=60.)
        else:
            self._checkpointer = None

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(params['batch_size'],
                                              params['min_replay_size']),
                         observations_per_step=float(params['batch_size']) /
                         params['samples_per_insert'])
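
# A hedged usage sketch: `DQNAgent`, `environment_spec` and `q_network` are
# assumed names, not defined above. The defaults are only used when `params`
# is None, so a custom dict must supply every key the constructor reads.
custom_params = {
    'batch_size': 64,
    'prefetch_size': 4,
    'target_update_period': 200,
    'samples_per_insert': 16.0,
    'min_replay_size': 500,
    'max_replay_size': 100000,
    'importance_sampling_exponent': 0.2,
    'priority_exponent': 0.6,
    'n_step': 3,
    'epsilon': 0.1,
    'learning_rate': 5e-4,
    'discount': 0.99,
}
agent = DQNAgent(
    environment_spec=environment_spec,
    network=q_network,
    params=custom_params,
    checkpoint=False,  # no `paths` provided, so skip the checkpointer
)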
Esempio n. 22
0
    def __init__(
        self,
        network: snt.Module,
        model: models.Model,
        optimizer: snt.Optimizer,
        n_step: int,
        discount: float,
        replay_capacity: int,
        num_simulations: int,
        environment_spec: specs.EnvironmentSpec,
        batch_size: int,
    ):

        # Create a replay server for storing transitions.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=replay_capacity,
            rate_limiter=reverb.rate_limiters.MinSize(1))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset provides an interface to sample from replay.
        replay_client = reverb.TFClient(address)
        action_spec: specs.DiscreteArray = environment_spec.actions
        dataset = datasets.make_reverb_dataset(
            client=replay_client,
            environment_spec=environment_spec,
            extra_spec={
                'pi': specs.Array(shape=(action_spec.num_values,),
                                  dtype=np.float32)
            },
            transition_adder=True)

        dataset = dataset.batch(batch_size, drop_remainder=True)

        tf2_utils.create_variables(network, [environment_spec.observations])

        # Now create the agent components: actor & learner.
        actor = acting.MCTSActor(
            environment_spec=environment_spec,
            model=model,
            network=network,
            discount=discount,
            adder=adder,
            num_simulations=num_simulations,
        )

        learner = learning.AZLearner(
            network=network,
            optimizer=optimizer,
            dataset=dataset,
            discount=discount,
        )

        # The parent class combines these together into one 'agent'.
        super().__init__(
            actor=actor,
            learner=learner,
            min_observations=10,
            observations_per_step=1,
        )
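
# A hedged construction sketch: `MCTS`, `PolicyValueNetwork`, `env_model` and
# `environment_spec` are assumed names (any snt.Module returning a
# (logits, value) pair and any models.Model simulator would do); the keyword
# arguments mirror the __init__ above.
agent = MCTS(
    network=PolicyValueNetwork(environment_spec.actions.num_values),
    model=env_model,
    optimizer=snt.optimizers.Adam(learning_rate=1e-3),
    n_step=1,
    discount=0.99,
    replay_capacity=10000,
    num_simulations=50,
    environment_spec=environment_spec,
    batch_size=16,
)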
Esempio n. 23
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.Module,
        batch_size: int = 32,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 100000,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        n_step: int = 5,
        epsilon: Optional[float] = 0.05,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
        logger: loggers.Logger = None,
        max_gradient_norm: Optional[float] = None,
        expert_data: List[Dict] = None,
    ) -> None:
        """ Initialize the agent. """

        # Create a replay server to add data to. This uses no limiter behavior
        # in order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # Adding expert data to the replay memory:
        if expert_data is not None:
            for d in expert_data:
                adder.add_first(d["first"])
                for (action, next_ts) in d["mid"]:
                    adder.add(np.int32(action), next_ts)

        # The dataset provides an interface to sample from replay.
        replay_client = reverb.TFClient(address)
        dataset = datasets.make_reverb_dataset(server_address=address,
                                               batch_size=batch_size,
                                               prefetch_size=prefetch_size)

        # Creating the epsilon greedy policy network:
        epsilon = tf.Variable(epsilon)
        policy_network = snt.Sequential([
            network,
            lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
        ])

        # Create a target network.
        target_network = copy.deepcopy(network)

        # Ensure that we create the variables before proceeding (maybe not
        # needed).
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network, adder)

        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            target_network=target_network,
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            learning_rate=learning_rate,
            target_update_period=target_update_period,
            dataset=dataset,
            replay_client=replay_client,
            max_gradient_norm=max_gradient_norm,
            logger=logger,
        )

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
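
# A hedged sketch of how `expert_data` might be assembled, based on how the
# loop above consumes it: each entry is a dict holding the episode's first
# dm_env TimeStep under "first" and a list of (action, next_timestep) pairs
# under "mid". `expert_env` and `expert_policy` are assumed names.
def collect_expert_episode(expert_env, expert_policy):
    timestep = expert_env.reset()
    episode = {'first': timestep, 'mid': []}
    while not timestep.last():
        action = expert_policy(timestep.observation)
        timestep = expert_env.step(action)
        episode['mid'].append((action, timestep))
    return episode

expert_data = [collect_expert_episode(expert_env, expert_policy)
               for _ in range(10)]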
Esempio n. 24
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: hk.Transformed,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        n_step: int = 5,
        epsilon: float = 0.,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
        seed: int = 1,
    ):
        """Initialize the agent."""

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(
                environment_spec=environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset provides an interface to sample from replay.
        dataset = datasets.make_reverb_dataset(
            server_address=address,
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            transition_adder=True)

        def policy(params: hk.Params, key: jnp.ndarray,
                   observation: jnp.ndarray) -> jnp.ndarray:
            action_values = network.apply(params, observation)
            return rlax.epsilon_greedy(epsilon).sample(key, action_values)

        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            obs_spec=environment_spec.observations,
            rng=hk.PRNGSequence(seed),
            optimizer=optax.adam(learning_rate),
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            target_update_period=target_update_period,
            iterator=dataset.as_numpy_iterator(),
            replay_client=reverb.Client(address),
        )

        variable_client = variable_utils.VariableClient(learner, '')

        actor = actors.FeedForwardActor(policy=policy,
                                        rng=hk.PRNGSequence(seed),
                                        variable_client=variable_client,
                                        adder=adder)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
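
# A hedged sketch of how the `network: hk.Transformed` argument might be built,
# assuming a discrete action spec; the MLP sizes are placeholders and
# `environment_spec` is assumed to exist. `hk.without_apply_rng` matches the
# `network.apply(params, observation)` call used in the policy above.
import haiku as hk
import jax.numpy as jnp

num_actions = environment_spec.actions.num_values

def q_network(observation: jnp.ndarray) -> jnp.ndarray:
    return hk.nets.MLP([64, 64, num_actions])(observation)

network = hk.without_apply_rng(hk.transform(q_network))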
Esempio n. 25
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        burn_in_length: int,
        trace_length: int,
        replay_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        batch_size: int = 32,
        prefetch_size: int = tf.data.experimental.AUTOTUNE,
        target_update_period: int = 100,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        epsilon: float = 0.01,
        learning_rate: float = 1e-3,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
        store_lstm_state: bool = True,
        max_priority_weight: float = 0.9,
        checkpoint: bool = True,
    ):

        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1))
        self._server = reverb.Server([replay_table], port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1
        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=replay_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        reverb_client = reverb.TFClient(address)
        extra_spec = {
            'core_state': network.initial_state(1),
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        dataset = datasets.make_reverb_dataset(
            client=reverb_client,
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            extra_spec=extra_spec,
            sequence_length=sequence_length)

        target_network = copy.deepcopy(network)
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            sequence_length=sequence_length,
            dataset=dataset,
            reverb_client=reverb_client,
            counter=counter,
            logger=logger,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=store_lstm_state,
            max_priority_weight=max_priority_weight,
        )

        self._checkpointer = tf2_savers.Checkpointer(
            subdirectory='r2d2_learner',
            time_delta_minutes=60,
            objects_to_save=learner.state,
            enable_checkpointing=checkpoint,
        )
        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save={'network': network}, time_delta_minutes=60.)

        policy_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
        ])

        actor = actors.RecurrentActor(policy_network, adder)
        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
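
# A hedged sketch of a recurrent Q-network compatible with the snt.RNNCore
# interface used above (including the `initial_state` call); the layer sizes
# are placeholders and `environment_spec` is assumed to be available.
import sonnet as snt

num_actions = environment_spec.actions.num_values
network = snt.DeepRNN([
    snt.Flatten(),
    snt.LSTM(64),
    snt.nets.MLP([64, num_actions]),
])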