def learner(self, replay: reverb.Client, counter: counting.Counter):
  """The Learning part of the agent."""

  # Create the networks.
  network = self._network_factory(self._env_spec.actions)
  target_network = copy.deepcopy(network)

  tf2_utils.create_variables(network, [self._env_spec.observations])
  tf2_utils.create_variables(target_network, [self._env_spec.observations])

  # The dataset object to learn from.
  replay_client = reverb.Client(replay.server_address)
  dataset = datasets.make_reverb_dataset(
      server_address=replay.server_address,
      batch_size=self._batch_size,
      prefetch_size=self._prefetch_size)

  logger = loggers.make_default_logger('learner', steps_key='learner_steps')

  # Return the learning agent.
  counter = counting.Counter(counter, 'learner')

  learner = learning.DQNLearner(
      network=network,
      target_network=target_network,
      discount=self._discount,
      importance_sampling_exponent=self._importance_sampling_exponent,
      learning_rate=self._learning_rate,
      target_update_period=self._target_update_period,
      dataset=dataset,
      replay_client=replay_client,
      counter=counter,
      logger=logger)

  return tf2_savers.CheckpointingRunner(
      learner, subdirectory='dqn_learner', time_delta_minutes=60)
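
# A minimal sketch of a network factory compatible with the call
# `self._network_factory(self._env_spec.actions)` above. The helper name
# `make_network_factory` and the layer sizes are illustrative assumptions,
# not part of the original listing.
import sonnet as snt
from acme import specs


def make_network_factory(hidden_sizes=(64, 64)):

  def network_factory(action_spec: specs.DiscreteArray) -> snt.Module:
    # A simple feed-forward Q-network with one output per discrete action.
    return snt.Sequential([
        snt.Flatten(),
        snt.nets.MLP(list(hidden_sizes) + [action_spec.num_values]),
    ])

  return network_factory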
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.Module,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 32.0,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    n_step: int = 5,
    epsilon: tf.Tensor = None,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
    logger: loggers.Logger = None,
    checkpoint: bool = True,
    checkpoint_subpath: str = '~/acme/',
):
  """Initialize the agent.

  Args:
    environment_spec: description of the actions, observations, etc.
    network: the online Q network (the one being optimized).
    batch_size: batch size for updates.
    prefetch_size: size to prefetch from replay.
    target_update_period: number of learner steps to perform before updating
      the target networks.
    samples_per_insert: number of samples to take from replay for every insert
      that is made.
    min_replay_size: minimum replay size before updating. This and all
      following arguments are related to dataset construction and will be
      ignored if a dataset argument is passed.
    max_replay_size: maximum replay size.
    importance_sampling_exponent: power to which importance weights are raised
      before normalizing.
    priority_exponent: exponent used in prioritized sampling.
    n_step: number of steps to squash into a single transition.
    epsilon: probability of taking a random action; ignored if a policy
      network is given.
    learning_rate: learning rate for the q-network update.
    discount: discount to use for TD updates.
    logger: logger object to be used by learner.
    checkpoint: boolean indicating whether to checkpoint the learner.
    checkpoint_subpath: directory for the checkpoint.
  """

  # Create a replay server to add data to. This uses no limiter behavior in
  # order to allow the Agent interface to handle it.
  replay_table = reverb.Table(
      name=adders.DEFAULT_PRIORITY_TABLE,
      sampler=reverb.selectors.Prioritized(priority_exponent),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(1),
      signature=adders.NStepTransitionAdder.signature(environment_spec))
  self._server = reverb.Server([replay_table], port=None)

  # The adder is used to insert observations into replay.
  address = f'localhost:{self._server.port}'
  adder = adders.NStepTransitionAdder(
      client=reverb.Client(address),
      n_step=n_step,
      discount=discount)

  # The dataset provides an interface to sample from replay.
  replay_client = reverb.TFClient(address)
  dataset = datasets.make_reverb_dataset(
      client=replay_client,
      environment_spec=environment_spec,
      batch_size=batch_size,
      prefetch_size=prefetch_size,
      transition_adder=True)

  # Use a constant 0.05 epsilon-greedy policy by default.
  if epsilon is None:
    epsilon = tf.Variable(0.05, trainable=False)
  policy_network = snt.Sequential([
      network,
      lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
  ])

  # Create a target network.
  target_network = copy.deepcopy(network)

  # Ensure that we create the variables before proceeding (maybe not needed).
  tf2_utils.create_variables(network, [environment_spec.observations])
  tf2_utils.create_variables(target_network, [environment_spec.observations])

  # Create the actor which defines how we take actions.
  actor = actors.FeedForwardActor(policy_network, adder)

  # The learner updates the parameters (and initializes them).
  learner = learning.DQNLearner(
      network=network,
      target_network=target_network,
      discount=discount,
      importance_sampling_exponent=importance_sampling_exponent,
      learning_rate=learning_rate,
      target_update_period=target_update_period,
      dataset=dataset,
      replay_client=replay_client,
      logger=logger,
      checkpoint=checkpoint)

  if checkpoint:
    self._checkpointer = tf2_savers.Checkpointer(
        directory=checkpoint_subpath,
        objects_to_save=learner.state,
        subdirectory='dqn_learner',
        time_delta_minutes=60.)
  else:
    self._checkpointer = None

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=max(batch_size, min_replay_size),
      observations_per_step=float(batch_size) / samples_per_insert)
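
# A minimal usage sketch for the agent constructed above, assuming it is
# exposed as a class named `DQN` and that `environment` is an existing
# dm_env.Environment; both names are assumptions and do not appear in the
# listing itself.
import acme
from acme import specs
import sonnet as snt

environment = ...  # any dm_env.Environment
environment_spec = specs.make_environment_spec(environment)

# A small feed-forward Q-network with one output per discrete action.
network = snt.Sequential([
    snt.Flatten(),
    snt.nets.MLP([64, 64, environment_spec.actions.num_values]),
])

agent = DQN(environment_spec=environment_spec, network=network)

# The agent acts in the environment and learns from replay as it goes.
loop = acme.EnvironmentLoop(environment, agent)
loop.run(num_episodes=100)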
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.Module,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 32.0,
    min_replay_size: int = 20,
    max_replay_size: int = 1000000,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    n_step: int = 5,
    epsilon_init: float = 1.0,
    epsilon_final: float = 0.01,
    epsilon_schedule_timesteps: int = 20000,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
    max_gradient_norm: Optional[float] = None,
    logger: loggers.Logger = None,
):
  """Initialize the agent.

  Args:
    environment_spec: description of the actions, observations, etc.
    network: the online Q network (the one being optimized).
    batch_size: batch size for updates.
    prefetch_size: size to prefetch from replay.
    target_update_period: number of learner steps to perform before updating
      the target networks.
    samples_per_insert: number of samples to take from replay for every insert
      that is made.
    min_replay_size: minimum replay size before updating. This and all
      following arguments are related to dataset construction and will be
      ignored if a dataset argument is passed.
    max_replay_size: maximum replay size.
    importance_sampling_exponent: power to which importance weights are raised
      before normalizing (beta). See https://arxiv.org/pdf/1710.02298.pdf
    priority_exponent: exponent used in prioritized sampling (omega). See
      https://arxiv.org/pdf/1710.02298.pdf
    n_step: number of steps to squash into a single transition.
    epsilon_init: initial epsilon value (probability of taking a random
      action).
    epsilon_final: final epsilon value (probability of taking a random
      action).
    epsilon_schedule_timesteps: timesteps to decay epsilon from 'epsilon_init'
      to 'epsilon_final'.
    learning_rate: learning rate for the q-network update.
    discount: discount to use for TD updates.
    max_gradient_norm: used for gradient clipping.
    logger: logger object to be used by learner.
  """

  # Create a replay server to add data to. This uses no limiter behavior in
  # order to allow the Agent interface to handle it.
  replay_table = reverb.Table(
      name=adders.DEFAULT_PRIORITY_TABLE,
      sampler=reverb.selectors.Prioritized(priority_exponent),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(1),
      signature=adders.NStepTransitionAdder.signature(environment_spec))
  self._server = reverb.Server([replay_table], port=None)

  # The adder is used to insert observations into replay.
  address = f'localhost:{self._server.port}'
  self._adder = adders.NStepTransitionAdder(
      client=reverb.Client(address),
      n_step=n_step,
      discount=discount)

  # The dataset provides an interface to sample from replay.
  replay_client = reverb.TFClient(address)
  dataset = make_reverb_dataset(
      server_address=address,
      batch_size=batch_size,
      prefetch_size=prefetch_size)

  # The behavior policy samples actions with a decaying epsilon-greedy scheme.
  policy_network = snt.Sequential([
      network,
      EpsilonGreedyExploration(
          epsilon_init=epsilon_init,
          epsilon_final=epsilon_final,
          epsilon_schedule_timesteps=epsilon_schedule_timesteps)
  ])

  # Create a target network.
  target_network = copy.deepcopy(network)

  # Ensure that we create the variables before proceeding (maybe not needed).
  tf2_utils.create_variables(network, [environment_spec.observations])
  tf2_utils.create_variables(target_network, [environment_spec.observations])

  # Create the actor which defines how we take actions.
  actor = actors_tf2.FeedForwardActor(policy_network, self._adder)

  # The learner updates the parameters (and initializes them).
  learner = learning.DQNLearner(
      network=network,
      target_network=target_network,
      discount=discount,
      importance_sampling_exponent=importance_sampling_exponent,
      learning_rate=learning_rate,
      target_update_period=target_update_period,
      dataset=dataset,
      replay_client=replay_client,
      max_gradient_norm=max_gradient_norm,
      logger=logger,
      checkpoint=False)

  self._saver = tf2_savers.Saver(learner.state)

  # Deterministic (max-Q) actor, used for evaluation without exploration.
  max_Q_network = snt.Sequential([
      network,
      lambda q: trfl.epsilon_greedy(q, epsilon=0.0).sample(),
  ])
  self._deterministic_actor = actors_tf2.FeedForwardActor(max_Q_network)

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=max(batch_size, min_replay_size),
      observations_per_step=float(batch_size) / samples_per_insert)
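
# The listing above relies on an `EpsilonGreedyExploration` module that is not
# shown here. Below is a minimal sketch of one way such a module could work,
# assuming a linear decay from `epsilon_init` to `epsilon_final` over
# `epsilon_schedule_timesteps` calls and trfl's epsilon-greedy sampling; the
# actual module may differ, so the class name here is deliberately distinct.
import sonnet as snt
import tensorflow as tf
import trfl


class LinearEpsilonGreedy(snt.Module):
  """Samples epsilon-greedy actions with a linearly decaying epsilon."""

  def __init__(self, epsilon_init, epsilon_final, epsilon_schedule_timesteps,
               name='linear_epsilon_greedy'):
    super().__init__(name=name)
    self._init = epsilon_init
    self._final = epsilon_final
    self._steps = epsilon_schedule_timesteps
    self._t = tf.Variable(0, dtype=tf.int32, trainable=False)

  def __call__(self, q_values):
    # Fraction of the schedule elapsed, clipped to [0, 1].
    frac = tf.minimum(tf.cast(self._t, tf.float32) / self._steps, 1.0)
    epsilon = self._init + frac * (self._final - self._init)
    self._t.assign_add(1)
    # Sample from an epsilon-greedy distribution over the Q-values.
    return trfl.epsilon_greedy(q_values, epsilon=epsilon).sample()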