def make_dataset_iterator(
    self,
    replay_client: reverb.Client,
) -> Iterator[reverb.ReplaySample]:
    """Create a dataset iterator to use for training/updating the system.

    Args:
        replay_client (reverb.Client): Reverb client which points to the
            replay server.

    Returns:
        Iterator[reverb.ReplaySample]: dataset iterator yielding data samples
            from the replay dataset.
    """
    sequence_length = (
        self._config.sequence_length
        if issubclass(self._executor_fn, executors.RecurrentExecutor)
        else None
    )
    dataset = datasets.make_reverb_dataset(
        table=self._config.replay_table_name,
        server_address=replay_client.server_address,
        batch_size=self._config.batch_size,
        prefetch_size=self._config.prefetch_size,
        sequence_length=sequence_length,
    )
    return iter(dataset)
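
# A sketch of how the iterator returned above is typically consumed. The
# `builder`, `replay_client`, `num_steps` and `learner_step` names are
# assumptions for illustration only, not part of the snippet.
iterator = builder.make_dataset_iterator(replay_client)
for _ in range(num_steps):
    sample = next(iterator)   # reverb.ReplaySample(info=..., data=...)
    keys = sample.info.key    # per-item keys, usable for priority updates
    batch = sample.data       # the batched transitions/sequences themselves
    learner_step(batch)       # assumed update function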

def learner(self, queue: reverb.Client, counter: counting.Counter):
    """The Learning part of the agent."""
    # Create the networks.
    network = self._network_factory(self._environment_spec.actions)
    tf2_utils.create_variables(network, [self._environment_spec.observations])

    # The dataset object to learn from.
    dataset = datasets.make_reverb_dataset(
        server_address=queue.server_address,
        batch_size=self._batch_size,
        prefetch_size=self._prefetch_size)

    logger = loggers.make_default_logger('learner', steps_key='learner_steps')
    counter = counting.Counter(counter, 'learner')

    # Return the learning agent.
    learner = learning.IMPALALearner(
        environment_spec=self._environment_spec,
        network=network,
        dataset=dataset,
        discount=self._discount,
        learning_rate=self._learning_rate,
        entropy_cost=self._entropy_cost,
        baseline_cost=self._baseline_cost,
        max_abs_reward=self._max_abs_reward,
        max_gradient_norm=self._max_gradient_norm,
        counter=counter,
        logger=logger,
    )

    return tf2_savers.CheckpointingRunner(
        learner, time_delta_minutes=5, subdirectory='impala_learner')

def make_dataset_iterator(
    self,
    reverb_client: reverb.Client,
) -> Iterator[reverb.ReplaySample]:
    """Create a dataset iterator to use for learning/updating the agent."""
    # The dataset provides an interface to sample from replay.
    dataset = datasets.make_reverb_dataset(
        table=self._config.replay_table_name,
        server_address=reverb_client.server_address,
        batch_size=self._config.batch_size,
        prefetch_size=self._config.prefetch_size)

    # TODO(b/155086959): Fix type stubs and remove.
    return iter(dataset)  # pytype: disable=wrong-arg-types

def learner(
    self,
    replay: reverb.Client,
    counter: counting.Counter,
):
    """The Learning part of the agent."""
    # Create online and target networks.
    online_networks = self._network_factory(self._environment_spec)
    target_networks = self._network_factory(self._environment_spec)

    # The dataset object to learn from.
    dataset = datasets.make_reverb_dataset(
        server_address=replay.server_address,
        batch_size=self._batch_size,
        prefetch_size=self._prefetch_size,
    )

    counter = counting.Counter(counter, 'learner')
    logger = loggers.make_default_logger('learner', time_delta=self._log_every)

    # Create policy loss module if a factory is passed.
    if self._policy_loss_factory:
        policy_loss_module = self._policy_loss_factory()
    else:
        policy_loss_module = None

    # Return the learning agent.
    return learning.MoGMPOLearner(
        policy_network=online_networks['policy'],
        critic_network=online_networks['critic'],
        observation_network=online_networks['observation'],
        target_policy_network=target_networks['policy'],
        target_critic_network=target_networks['critic'],
        target_observation_network=target_networks['observation'],
        discount=self._additional_discount,
        num_samples=self._num_samples,
        policy_evaluation_config=self._policy_evaluation_config,
        target_policy_update_period=self._target_policy_update_period,
        target_critic_update_period=self._target_critic_update_period,
        policy_loss_module=policy_loss_module,
        dataset=dataset,
        counter=counter,
        logger=logger)

def reset_replay_table(self, name='new_replay_table'):
    replay_table = reverb.Table(
        name=name,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=1000000,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
        signature=adders.NStepTransitionAdder.signature(self._environment_spec))

    port = self._agent._server.port
    del self._agent._server
    self._agent._server = reverb.Server([replay_table], port=port)

    dataset = datasets.make_reverb_dataset(
        table=name,
        server_address=f'localhost:{port}',
        batch_size=256,
        prefetch_size=4,
    )
    self._agent._learner._iterator = iter(dataset)

def learner(self, replay: reverb.Client, counter: counting.Counter):
    """The Learning part of the agent."""
    # Create the networks.
    network = self._network_factory(self._environment_spec.actions)
    target_network = copy.deepcopy(network)
    tf2_utils.create_variables(network, [self._obs_spec])
    tf2_utils.create_variables(target_network, [self._obs_spec])

    # The dataset object to learn from.
    reverb_client = reverb.TFClient(replay.server_address)
    sequence_length = self._burn_in_length + self._trace_length + 1
    dataset = datasets.make_reverb_dataset(
        server_address=replay.server_address,
        batch_size=self._batch_size,
        prefetch_size=self._prefetch_size)

    counter = counting.Counter(counter, 'learner')
    logger = loggers.make_default_logger(
        'learner', save_data=True, steps_key='learner_steps')

    # Return the learning agent.
    learner = learning.R2D2Learner(
        environment_spec=self._environment_spec,
        network=network,
        target_network=target_network,
        burn_in_length=self._burn_in_length,
        sequence_length=sequence_length,
        dataset=dataset,
        reverb_client=reverb_client,
        counter=counter,
        logger=logger,
        discount=self._discount,
        target_update_period=self._target_update_period,
        importance_sampling_exponent=self._importance_sampling_exponent,
        learning_rate=self._learning_rate,
        max_replay_size=self._max_replay_size)

    return tf2_savers.CheckpointingRunner(
        wrapped=learner, time_delta_minutes=60, subdirectory='r2d2_learner')

def learner(self, replay: reverb.Client, counter: counting.Counter):
    """The Learning part of the agent."""
    # Create the networks.
    network = self._network_factory(self._env_spec.actions)
    target_network = copy.deepcopy(network)
    tf2_utils.create_variables(network, [self._env_spec.observations])
    tf2_utils.create_variables(target_network, [self._env_spec.observations])

    # The dataset object to learn from.
    replay_client = reverb.Client(replay.server_address)
    dataset = datasets.make_reverb_dataset(
        server_address=replay.server_address,
        batch_size=self._batch_size,
        prefetch_size=self._prefetch_size)

    logger = loggers.make_default_logger('learner', steps_key='learner_steps')

    # Return the learning agent.
    counter = counting.Counter(counter, 'learner')
    learner = learning.DQNLearner(
        network=network,
        target_network=target_network,
        discount=self._discount,
        importance_sampling_exponent=self._importance_sampling_exponent,
        learning_rate=self._learning_rate,
        target_update_period=self._target_update_period,
        dataset=dataset,
        replay_client=replay_client,
        counter=counter,
        logger=logger)

    return tf2_savers.CheckpointingRunner(
        learner, subdirectory='dqn_learner', time_delta_minutes=60)

def learner(self, replay: reverb.Client, counter: counting.Counter):
    """The learning part of the agent."""
    # Create the networks.
    network = self._network_factory(self._env_spec.actions)
    tf2_utils.create_variables(network, [self._env_spec.observations])

    # The dataset object to learn from.
    dataset = datasets.make_reverb_dataset(
        server_address=replay.server_address,
        batch_size=self._batch_size,
        prefetch_size=self._prefetch_size)

    # Create the optimizer.
    optimizer = snt.optimizers.Adam(self._learning_rate)

    # Return the learning agent.
    return learning.AZLearner(
        network=network,
        discount=self._discount,
        dataset=dataset,
        optimizer=optimizer,
        counter=counter,
    )

def learner(
    self,
    replay: reverb.Client,
    counter: counting.Counter,
):
    """The Learning part of the agent."""
    act_spec = self._environment_spec.actions
    obs_spec = self._environment_spec.observations

    # Create online and target networks.
    online_networks = self._network_factory(act_spec)
    target_networks = self._network_factory(act_spec)

    # Make sure observation networks are Sonnet Modules.
    observation_network = online_networks.get('observation', tf.identity)
    observation_network = tf2_utils.to_sonnet_module(observation_network)
    online_networks['observation'] = observation_network
    target_observation_network = target_networks.get('observation', tf.identity)
    target_observation_network = tf2_utils.to_sonnet_module(
        target_observation_network)
    target_networks['observation'] = target_observation_network

    # Get embedding spec and create observation network variables.
    emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

    tf2_utils.create_variables(online_networks['policy'], [emb_spec])
    tf2_utils.create_variables(online_networks['critic'], [emb_spec, act_spec])
    tf2_utils.create_variables(target_networks['observation'], [obs_spec])
    tf2_utils.create_variables(target_networks['policy'], [emb_spec])
    tf2_utils.create_variables(target_networks['critic'], [emb_spec, act_spec])

    # The dataset object to learn from.
    dataset = datasets.make_reverb_dataset(
        server_address=replay.server_address)
    dataset = dataset.batch(self._batch_size, drop_remainder=True)
    dataset = dataset.prefetch(self._prefetch_size)

    # Create a counter and logger for bookkeeping steps and performance.
    counter = counting.Counter(counter, 'learner')
    logger = loggers.make_default_logger(
        'learner', time_delta=self._log_every, steps_key='learner_steps')

    # Create policy loss module if a factory is passed.
    if self._policy_loss_factory:
        policy_loss_module = self._policy_loss_factory()
    else:
        policy_loss_module = None

    # Return the learning agent.
    return learning.MPOLearner(
        policy_network=online_networks['policy'],
        critic_network=online_networks['critic'],
        observation_network=observation_network,
        target_policy_network=target_networks['policy'],
        target_critic_network=target_networks['critic'],
        target_observation_network=target_observation_network,
        discount=self._additional_discount,
        num_samples=self._num_samples,
        target_policy_update_period=self._target_policy_update_period,
        target_critic_update_period=self._target_critic_update_period,
        policy_loss_module=policy_loss_module,
        dataset=dataset,
        counter=counter,
        logger=logger)
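
# The learner above expects `network_factory(action_spec)` to return a dict
# with 'policy' and 'critic' entries and an optional 'observation' torso.
# A hedged sketch of such a factory follows; the particular heads chosen here
# (MultivariateNormalDiagHead, CriticMultiplexer) are illustrative stand-ins,
# not necessarily the architecture used with the code above.
import numpy as np
import sonnet as snt
import tensorflow as tf

from acme import specs
from acme.tf import networks


def make_networks(action_spec: specs.BoundedArray):
    """Returns the dict layout consumed by the learner above (sketch only)."""
    num_dimensions = np.prod(action_spec.shape, dtype=int)
    return {
        # Optional observation torso; tf.identity is the trivial choice.
        'observation': tf.identity,
        # Stochastic policy head producing a diagonal Gaussian.
        'policy': snt.Sequential([
            snt.nets.MLP([256, 256], activate_final=True),
            networks.MultivariateNormalDiagHead(num_dimensions),
        ]),
        # Critic consumes (embedding, action); CriticMultiplexer concatenates.
        'critic': snt.Sequential([
            networks.CriticMultiplexer(),
            snt.nets.MLP([256, 256, 1]),
        ]),
    }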

def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.Module,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 32.0,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    n_step: int = 5,
    epsilon: tf.Tensor = None,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
    cql_alpha: float = 1.,
    logger: loggers.Logger = None,
    counter: counting.Counter = None,
    checkpoint_subpath: str = '~/acme/',
):
    """Initialize the agent.

    Args:
        environment_spec: description of the actions, observations, etc.
        network: the online Q network (the one being optimized).
        batch_size: batch size for updates.
        prefetch_size: size to prefetch from replay.
        target_update_period: number of learner steps to perform before
            updating the target networks.
        samples_per_insert: number of samples to take from replay for every
            insert that is made.
        min_replay_size: minimum replay size before updating.
        max_replay_size: maximum replay size.
        importance_sampling_exponent: power to which importance weights are
            raised before normalizing.
        priority_exponent: exponent used in prioritized sampling.
        n_step: number of steps to squash into a single transition.
        epsilon: probability of taking a random action; defaults to a constant
            0.05 if not given.
        learning_rate: learning rate for the q-network update.
        discount: discount to use for TD updates.
        cql_alpha: weight of the conservative (CQL) regularization term.
        logger: logger object to be used by learner.
        counter: counter object used to keep track of learner steps.
        checkpoint_subpath: directory for the checkpoint.
    """
    # Create a replay server to add data to. This uses no limiter behavior in
    # order to allow the Agent interface to handle it.
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=adders.NStepTransitionAdder.signature(environment_spec))
    self._server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        client=reverb.Client(address), n_step=n_step, discount=discount)

    # The dataset provides an interface to sample from replay.
    replay_client = reverb.TFClient(address)
    dataset = datasets.make_reverb_dataset(
        client=replay_client,
        environment_spec=environment_spec,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        transition_adder=True)

    # Use constant 0.05 epsilon greedy policy by default.
    if epsilon is None:
        epsilon = tf.Variable(0.05, trainable=False)
    policy_network = snt.Sequential([
        network,
        lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
    ])

    # Create a target network.
    target_network = copy.deepcopy(network)

    # Ensure that we create the variables before proceeding (maybe not needed).
    tf2_utils.create_variables(network, [environment_spec.observations])
    tf2_utils.create_variables(target_network, [environment_spec.observations])

    # Create the actor which defines how we take actions.
    actor = actors.FeedForwardActor(policy_network, adder)

    # The learner updates the parameters (and initializes them).
    learner = CQLLearner(
        network=network,
        discount=discount,
        importance_sampling_exponent=importance_sampling_exponent,
        learning_rate=learning_rate,
        cql_alpha=cql_alpha,
        target_update_period=target_update_period,
        dataset=dataset,
        replay_client=replay_client,
        logger=logger,
        counter=counter,
        checkpoint_subpath=checkpoint_subpath)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(batch_size, min_replay_size),
        observations_per_step=float(batch_size) / samples_per_insert)
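
# For context, a minimal sketch of how an agent built by the __init__ above is
# usually driven. The fake environment, the MLP network and the `CQLAgent`
# class name are assumptions for illustration; substitute the actual agent
# class and a real environment in practice.
import numpy as np
import sonnet as snt

import acme
from acme import specs
from acme.testing import fakes

environment = fakes.DiscreteEnvironment(
    num_actions=5,
    num_observations=10,
    obs_shape=(10, 5),
    obs_dtype=np.float32,
    episode_length=10)
environment_spec = specs.make_environment_spec(environment)

network = snt.Sequential([
    snt.Flatten(),
    snt.nets.MLP([64, 64, environment_spec.actions.num_values]),
])

agent = CQLAgent(environment_spec=environment_spec, network=network)
acme.EnvironmentLoop(environment, agent).run(num_episodes=10)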

def __init__(self,
             environment_spec: specs.EnvironmentSpec,
             policy_network: snt.Module,
             critic_network: snt.Module,
             encoder_network: types.TensorTransformation = tf.identity,
             entropy_coeff: float = 0.01,
             target_update_period: int = 0,
             discount: float = 0.99,
             batch_size: int = 256,
             policy_learn_rate: float = 3e-4,
             critic_learn_rate: float = 5e-4,
             prefetch_size: int = 4,
             min_replay_size: int = 1000,
             max_replay_size: int = 250000,
             samples_per_insert: float = 64.0,
             n_step: int = 5,
             sigma: float = 0.5,
             clipping: bool = True,
             logger: loggers.Logger = None,
             counter: counting.Counter = None,
             checkpoint: bool = True,
             replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
    """Initialize the agent.

    Args:
        environment_spec: description of the actions, observations, etc.
        policy_network: the online (optimized) policy.
        critic_network: the online critic.
        encoder_network: optional network to transform the observations before
            they are fed into any network.
        entropy_coeff: weight of the policy entropy regularization term.
        target_update_period: number of learner steps to perform before
            updating the target networks; if 0, the online networks are used
            as their own targets.
        discount: discount to use for TD updates.
        batch_size: batch size for updates.
        policy_learn_rate: learning rate for the policy optimizer.
        critic_learn_rate: learning rate for the critic optimizer.
        prefetch_size: size to prefetch from replay.
        min_replay_size: minimum replay size before updating.
        max_replay_size: maximum replay size.
        samples_per_insert: number of samples to take from replay for every
            insert that is made.
        n_step: number of steps to squash into a single transition.
        sigma: standard deviation of zero-mean, Gaussian exploration noise.
        clipping: whether to clip gradients by global norm.
        logger: logger object to be used by learner.
        counter: counter object used to keep track of steps.
        checkpoint: boolean indicating whether to checkpoint the learner.
        replay_table_name: string indicating what name to give the replay
            table.
    """
    # Create a replay server to add data to. This uses no limiter behavior in
    # order to allow the Agent interface to handle it.
    dim_actions = np.prod(environment_spec.actions.shape, dtype=int)
    extra_spec = {
        'logP': tf.ones(shape=(1), dtype=tf.float32),
        'policy': tf.ones(shape=(1, dim_actions), dtype=tf.float32)
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
    replay_table = reverb.Table(
        name=replay_table_name,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=adders.NStepTransitionAdder.signature(
            environment_spec, extras_spec=extra_spec))
    self._server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        priority_fns={replay_table_name: lambda x: 1.},
        client=reverb.Client(address),
        n_step=n_step,
        discount=discount)

    # The dataset provides an interface to sample from replay.
    dataset = datasets.make_reverb_dataset(
        table=replay_table_name,
        server_address=address,
        batch_size=batch_size,
        prefetch_size=prefetch_size)

    # Make sure observation network is a Sonnet Module.
    observation_network = model.MDPNormalization(environment_spec,
                                                 encoder_network)

    # Get observation and action specs.
    act_spec = environment_spec.actions
    obs_spec = environment_spec.observations

    # Create the behavior policy.
    sampling_head = model.SquashedGaussianSamplingHead(act_spec, sigma)
    self._behavior_network = model.PolicyValueBehaviorNet(
        snt.Sequential([observation_network, policy_network]), sampling_head)

    # Create variables.
    emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])
    tf2_utils.create_variables(policy_network, [emb_spec])
    tf2_utils.create_variables(critic_network, [emb_spec, act_spec])

    # Create the actor which defines how we take actions.
    actor = model.SACFeedForwardActor(self._behavior_network, adder)

    if target_update_period > 0:
        target_policy_network = copy.deepcopy(policy_network)
        target_critic_network = copy.deepcopy(critic_network)
        target_observation_network = copy.deepcopy(observation_network)

        tf2_utils.create_variables(target_policy_network, [emb_spec])
        tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_observation_network, [obs_spec])
    else:
        target_policy_network = policy_network
        target_critic_network = critic_network
        target_observation_network = observation_network

    # Create optimizers.
    policy_optimizer = snt.optimizers.Adam(learning_rate=policy_learn_rate)
    critic_optimizer = snt.optimizers.Adam(learning_rate=critic_learn_rate)

    # The learner updates the parameters (and initializes them).
    learner = learning.SACLearner(
        policy_network=policy_network,
        critic_network=critic_network,
        sampling_head=sampling_head,
        observation_network=observation_network,
        target_policy_network=target_policy_network,
        target_critic_network=target_critic_network,
        target_observation_network=target_observation_network,
        policy_optimizer=policy_optimizer,
        critic_optimizer=critic_optimizer,
        target_update_period=target_update_period,
        learning_rate=policy_learn_rate,
        clipping=clipping,
        entropy_coeff=entropy_coeff,
        discount=discount,
        dataset=dataset,
        counter=counter,
        logger=logger,
        checkpoint=checkpoint,
    )

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(batch_size, min_replay_size),
        observations_per_step=float(batch_size) / samples_per_insert)

def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: networks.PolicyValueRNN,
    initial_state_fn: Callable[[], networks.RNNState],
    sequence_length: int,
    sequence_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    max_queue_size: int = 100000,
    batch_size: int = 16,
    learning_rate: float = 1e-3,
    entropy_cost: float = 0.01,
    baseline_cost: float = 0.5,
    seed: int = 0,
    max_abs_reward: float = np.inf,
    max_gradient_norm: float = np.inf,
):
    num_actions = environment_spec.actions.num_values
    self._logger = logger or loggers.TerminalLogger('agent')

    queue = reverb.Table.queue(
        name=adders.DEFAULT_PRIORITY_TABLE, max_size=max_queue_size)
    self._server = reverb.Server([queue], port=None)
    self._can_sample = lambda: queue.can_sample(batch_size)
    address = f'localhost:{self._server.port}'

    # Component to add things into replay.
    adder = adders.SequenceAdder(
        client=reverb.Client(address),
        period=sequence_period,
        sequence_length=sequence_length,
    )

    # The dataset object to learn from.
    extra_spec = {
        'core_state': hk.transform(initial_state_fn).apply(None),
        'logits': np.ones(shape=(num_actions,), dtype=np.float32)
    }
    dataset = datasets.make_reverb_dataset(
        client=reverb.TFClient(address),
        environment_spec=environment_spec,
        batch_size=batch_size,
        extra_spec=extra_spec,
        sequence_length=sequence_length)

    rng = hk.PRNGSequence(seed)

    optimizer = optix.chain(
        optix.clip_by_global_norm(max_gradient_norm),
        optix.adam(learning_rate),
    )
    self._learner = learning.IMPALALearner(
        obs_spec=environment_spec.observations,
        network=network,
        initial_state_fn=initial_state_fn,
        iterator=dataset.as_numpy_iterator(),
        rng=rng,
        counter=counter,
        logger=logger,
        optimizer=optimizer,
        discount=discount,
        entropy_cost=entropy_cost,
        baseline_cost=baseline_cost,
        max_abs_reward=max_abs_reward,
    )

    variable_client = jax_variable_utils.VariableClient(
        self._learner, key='policy')

    self._actor = acting.IMPALAActor(
        network=network,
        initial_state_fn=initial_state_fn,
        rng=rng,
        adder=adder,
        variable_client=variable_client,
    )

def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    policy_network: snt.Module,
    critic_network: snt.Module,
    observation_network: types.TensorTransformation = tf.identity,
    discount: float = 0.99,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_policy_update_period: int = 100,
    target_critic_update_period: int = 100,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    samples_per_insert: float = 32.0,
    policy_loss_module: snt.Module = None,
    policy_optimizer: snt.Optimizer = None,
    critic_optimizer: snt.Optimizer = None,
    n_step: int = 5,
    num_samples: int = 20,
    clipping: bool = True,
    logger: loggers.Logger = None,
    counter: counting.Counter = None,
    checkpoint: bool = True,
    replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
):
    """Initialize the agent.

    Args:
        environment_spec: description of the actions, observations, etc.
        policy_network: the online (optimized) policy.
        critic_network: the online critic.
        observation_network: optional network to transform the observations
            before they are fed into any network.
        discount: discount to use for TD updates.
        batch_size: batch size for updates.
        prefetch_size: size to prefetch from replay.
        target_policy_update_period: number of updates to perform before
            updating the target policy network.
        target_critic_update_period: number of updates to perform before
            updating the target critic network.
        min_replay_size: minimum replay size before updating.
        max_replay_size: maximum replay size.
        samples_per_insert: number of samples to take from replay for every
            insert that is made.
        policy_loss_module: configured MPO loss function for the policy
            optimization; defaults to sensible values on the control suite.
            See `acme/tf/losses/mpo.py` for more details.
        policy_optimizer: optimizer to be used on the policy.
        critic_optimizer: optimizer to be used on the critic.
        n_step: number of steps to squash into a single transition.
        num_samples: number of actions to sample when doing a Monte Carlo
            integration with respect to the policy.
        clipping: whether to clip gradients by global norm.
        logger: logging object used to write to logs.
        counter: counter object used to keep track of steps.
        checkpoint: boolean indicating whether to checkpoint the learner.
        replay_table_name: string indicating what name to give the replay
            table.
    """
    # Create a replay server to add data to. The table name must match the one
    # the dataset below reads from.
    replay_table = reverb.Table(
        name=replay_table_name,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
        signature=adders.NStepTransitionAdder.signature(environment_spec))
    self._server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        client=reverb.Client(address), n_step=n_step, discount=discount)

    # The dataset object to learn from.
    dataset = datasets.make_reverb_dataset(
        table=replay_table_name,
        client=reverb.TFClient(address),
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        environment_spec=environment_spec,
        transition_adder=True)

    # Make sure observation network is a Sonnet Module.
    observation_network = tf2_utils.to_sonnet_module(observation_network)

    # Create target networks before creating online/target network variables.
    target_policy_network = copy.deepcopy(policy_network)
    target_critic_network = copy.deepcopy(critic_network)
    target_observation_network = copy.deepcopy(observation_network)

    # Get observation and action specs.
    act_spec = environment_spec.actions
    obs_spec = environment_spec.observations
    emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

    # Create the behavior policy.
    behavior_network = snt.Sequential([
        observation_network,
        policy_network,
        networks.StochasticSamplingHead(),
    ])

    # Create variables.
    tf2_utils.create_variables(policy_network, [emb_spec])
    tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
    tf2_utils.create_variables(target_policy_network, [emb_spec])
    tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
    tf2_utils.create_variables(target_observation_network, [obs_spec])

    # Create the actor which defines how we take actions.
    actor = actors.FeedForwardActor(
        policy_network=behavior_network, adder=adder)

    # Create optimizers.
    policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
    critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)

    # The learner updates the parameters (and initializes them).
    learner = learning.MPOLearner(
        policy_network=policy_network,
        critic_network=critic_network,
        observation_network=observation_network,
        target_policy_network=target_policy_network,
        target_critic_network=target_critic_network,
        target_observation_network=target_observation_network,
        policy_loss_module=policy_loss_module,
        policy_optimizer=policy_optimizer,
        critic_optimizer=critic_optimizer,
        clipping=clipping,
        discount=discount,
        num_samples=num_samples,
        target_policy_update_period=target_policy_update_period,
        target_critic_update_period=target_critic_update_period,
        dataset=dataset,
        logger=logger,
        counter=counter,
        checkpoint=checkpoint)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(batch_size, min_replay_size),
        observations_per_step=float(batch_size) / samples_per_insert)

def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.Module,
    demonstration_dataset: tf.data.Dataset,
    demonstration_ratio: float,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 32.0,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    importance_sampling_exponent: float = 0.2,
    n_step: int = 5,
    epsilon: tf.Tensor = None,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
):
    """Initialize the agent.

    Args:
        environment_spec: description of the actions, observations, etc.
        network: the online Q network (the one being optimized).
        demonstration_dataset: tf.data.Dataset producing (timestep, action)
            tuples containing full episodes.
        demonstration_ratio: ratio of transitions coming from demonstrations.
        batch_size: batch size for updates.
        prefetch_size: size to prefetch from replay.
        target_update_period: number of learner steps to perform before
            updating the target networks.
        samples_per_insert: number of samples to take from replay for every
            insert that is made.
        min_replay_size: minimum replay size before updating.
        max_replay_size: maximum replay size.
        importance_sampling_exponent: power to which importance weights are
            raised before normalizing.
        n_step: number of steps to squash into a single transition.
        epsilon: probability of taking a random action; defaults to a constant
            0.05 if not given.
        learning_rate: learning rate for the q-network update.
        discount: discount to use for TD updates.
    """
    # Create a replay server to add data to. This uses no limiter behavior in
    # order to allow the Agent interface to handle it.
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(1))
    self._server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        client=reverb.Client(address), n_step=n_step, discount=discount)

    # The dataset provides an interface to sample from replay.
    replay_client = reverb.TFClient(address)
    dataset = datasets.make_reverb_dataset(
        client=replay_client,
        environment_spec=environment_spec,
        transition_adder=True)

    # Combine with demonstration dataset.
    transition = functools.partial(
        _n_step_transition_from_episode, n_step=n_step, discount=discount)
    dataset_demos = demonstration_dataset.map(transition)
    dataset = tf.data.experimental.sample_from_datasets(
        [dataset, dataset_demos],
        [1 - demonstration_ratio, demonstration_ratio])

    # Batch and prefetch.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(prefetch_size)

    # Use constant 0.05 epsilon greedy policy by default.
    if epsilon is None:
        epsilon = tf.Variable(0.05, trainable=False)
    policy_network = snt.Sequential([
        network,
        lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
    ])

    # Create a target network.
    target_network = copy.deepcopy(network)

    # Ensure that we create the variables before proceeding (maybe not needed).
    tf2_utils.create_variables(network, [environment_spec.observations])
    tf2_utils.create_variables(target_network, [environment_spec.observations])

    # Create the actor which defines how we take actions.
    actor = actors.FeedForwardActor(policy_network, adder)

    # The learner updates the parameters (and initializes them).
    learner = dqn.DQNLearner(
        network=network,
        target_network=target_network,
        discount=discount,
        importance_sampling_exponent=importance_sampling_exponent,
        learning_rate=learning_rate,
        target_update_period=target_update_period,
        dataset=dataset,
        replay_client=replay_client)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(batch_size, min_replay_size),
        observations_per_step=float(batch_size) / samples_per_insert)
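
# The demonstration mixing above is plain tf.data machinery. A tiny
# self-contained illustration of the same mechanism, using toy integer
# datasets in place of replay and demonstration transitions:
import tensorflow as tf

replay_like = tf.data.Dataset.range(0, 100)
demos_like = tf.data.Dataset.range(1000, 1100)

demonstration_ratio = 0.25
mixed = tf.data.experimental.sample_from_datasets(
    [replay_like, demos_like],
    weights=[1 - demonstration_ratio, demonstration_ratio])

# Each element now comes from replay_like with probability ~0.75 and from
# demos_like with probability ~0.25; batching then works as usual.
batch = mixed.batch(8, drop_remainder=True)
print(next(iter(batch)))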

def __init__(self,
             environment_spec: specs.EnvironmentSpec,
             network: snt.RNNCore,
             target_network: snt.RNNCore,
             burn_in_length: int,
             trace_length: int,
             replay_period: int,
             demonstration_dataset: tf.data.Dataset,
             demonstration_ratio: float,
             counter: counting.Counter = None,
             logger: loggers.Logger = None,
             discount: float = 0.99,
             batch_size: int = 32,
             target_update_period: int = 100,
             importance_sampling_exponent: float = 0.2,
             epsilon: float = 0.01,
             learning_rate: float = 1e-3,
             log_to_bigtable: bool = False,
             log_name: str = 'agent',
             checkpoint: bool = True,
             min_replay_size: int = 1000,
             max_replay_size: int = 1000000,
             samples_per_insert: float = 32.0):
    extra_spec = {
        'core_state': network.initial_state(1),
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
        signature=adders.SequenceAdder.signature(environment_spec, extra_spec))
    self._server = reverb.Server([replay_table], port=None)
    address = f'localhost:{self._server.port}'

    sequence_length = burn_in_length + trace_length + 1

    # Component to add things into replay.
    sequence_kwargs = dict(
        period=replay_period,
        sequence_length=sequence_length,
    )
    adder = adders.SequenceAdder(
        client=reverb.Client(address), **sequence_kwargs)

    # The dataset object to learn from.
    dataset = datasets.make_reverb_dataset(
        server_address=address, sequence_length=sequence_length)

    # Combine with demonstration dataset.
    transition = functools.partial(
        _sequence_from_episode, extra_spec=extra_spec, **sequence_kwargs)
    dataset_demos = demonstration_dataset.map(transition)
    dataset = tf.data.experimental.sample_from_datasets(
        [dataset, dataset_demos],
        [1 - demonstration_ratio, demonstration_ratio])

    # Batch and prefetch.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    tf2_utils.create_variables(network, [environment_spec.observations])
    tf2_utils.create_variables(target_network, [environment_spec.observations])

    learner = learning.R2D2Learner(
        environment_spec=environment_spec,
        network=network,
        target_network=target_network,
        burn_in_length=burn_in_length,
        dataset=dataset,
        reverb_client=reverb.TFClient(address),
        counter=counter,
        logger=logger,
        sequence_length=sequence_length,
        discount=discount,
        target_update_period=target_update_period,
        importance_sampling_exponent=importance_sampling_exponent,
        max_replay_size=max_replay_size,
        learning_rate=learning_rate,
        store_lstm_state=False,
    )

    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='r2d2_learner',
        time_delta_minutes=60,
        objects_to_save=learner.state,
        enable_checkpointing=checkpoint,
    )

    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={'network': network}, time_delta_minutes=60.)

    policy_network = snt.DeepRNN([
        network,
        lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
    ])

    actor = actors.RecurrentActor(policy_network, adder)
    observations_per_step = (
        float(replay_period * batch_size) / samples_per_insert)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=replay_period * max(batch_size, min_replay_size),
        observations_per_step=observations_per_step)

def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    sequence_length: int,
    sequence_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    max_queue_size: int = 100000,
    batch_size: int = 16,
    learning_rate: float = 1e-3,
    entropy_cost: float = 0.01,
    baseline_cost: float = 0.5,
    max_abs_reward: Optional[float] = None,
    max_gradient_norm: Optional[float] = None,
):
    num_actions = environment_spec.actions.num_values
    self._logger = logger or loggers.TerminalLogger('agent')

    extra_spec = {
        'core_state': network.initial_state(1),
        'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32)
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)

    queue = reverb.Table.queue(
        name=adders.DEFAULT_PRIORITY_TABLE,
        max_size=max_queue_size,
        signature=adders.SequenceAdder.signature(
            environment_spec,
            extras_spec=extra_spec,
            sequence_length=sequence_length))
    self._server = reverb.Server([queue], port=None)
    self._can_sample = lambda: queue.can_sample(batch_size)
    address = f'localhost:{self._server.port}'

    # Component to add things into replay.
    adder = adders.SequenceAdder(
        client=reverb.Client(address),
        period=sequence_period,
        sequence_length=sequence_length,
    )

    # The dataset object to learn from.
    dataset = datasets.make_reverb_dataset(
        server_address=address, batch_size=batch_size)

    tf2_utils.create_variables(network, [environment_spec.observations])

    self._actor = acting.IMPALAActor(network, adder)
    self._learner = learning.IMPALALearner(
        environment_spec=environment_spec,
        network=network,
        dataset=dataset,
        counter=counter,
        logger=logger,
        discount=discount,
        learning_rate=learning_rate,
        entropy_cost=entropy_cost,
        baseline_cost=baseline_cost,
        max_gradient_norm=max_gradient_norm,
        max_abs_reward=max_abs_reward,
    )

def learner(
    self,
    replay: reverb.Client,
    counter: counting.Counter,
):
    """The Learning part of the agent."""
    act_spec = self._environment_spec.actions
    obs_spec = self._environment_spec.observations

    # Create the networks to optimize (online) and target networks.
    online_networks = self._network_factory(act_spec)
    target_networks = self._network_factory(act_spec)

    # Make sure observation network is a Sonnet Module.
    observation_network = online_networks.get('observation', tf.identity)
    target_observation_network = target_networks.get('observation', tf.identity)
    observation_network = tf2_utils.to_sonnet_module(observation_network)
    target_observation_network = tf2_utils.to_sonnet_module(
        target_observation_network)

    # Get embedding spec and create observation network variables.
    emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

    # Create variables.
    tf2_utils.create_variables(online_networks['policy'], [emb_spec])
    tf2_utils.create_variables(online_networks['critic'], [emb_spec, act_spec])
    tf2_utils.create_variables(target_networks['policy'], [emb_spec])
    tf2_utils.create_variables(target_networks['critic'], [emb_spec, act_spec])
    tf2_utils.create_variables(target_observation_network, [obs_spec])

    # The dataset object to learn from.
    dataset = datasets.make_reverb_dataset(
        server_address=replay.server_address,
        batch_size=self._batch_size,
        prefetch_size=self._prefetch_size)

    # Create optimizers.
    policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4)
    critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4)

    counter = counting.Counter(counter, 'learner')
    logger = loggers.make_default_logger(
        'learner', time_delta=self._log_every, steps_key='learner_steps')

    # Return the learning agent.
    return learning.DDPGLearner(
        policy_network=online_networks['policy'],
        critic_network=online_networks['critic'],
        observation_network=observation_network,
        target_policy_network=target_networks['policy'],
        target_critic_network=target_networks['critic'],
        target_observation_network=target_observation_network,
        discount=self._discount,
        target_update_period=self._target_update_period,
        dataset=dataset,
        policy_optimizer=policy_optimizer,
        critic_optimizer=critic_optimizer,
        clipping=self._clipping,
        counter=counter,
        logger=logger,
    )
"""Creates a single-process replay infrastructure from an environment spec.""" # Create a replay server to add data to. This uses no limiter behavior in # order to allow the Agent interface to handle it. replay_table = reverb.Table( name=replay_table_name, sampler=reverb.selectors.Prioritized(priority_exponent), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(min_replay_size), signature=adders.NStepTransitionAdder.signature( environment_spec=environment_spec)) server = reverb.Server([replay_table], port=None) # The adder is used to insert observations into replay. address = f'localhost:{server.port}' client = reverb.Client(address) adder = adders.NStepTransitionAdder(client=client, n_step=n_step, discount=discount) # The dataset provides an interface to sample from replay. data_iterator = datasets.make_reverb_dataset( table=replay_table_name, server_address=address, batch_size=batch_size, prefetch_size=prefetch_size, environment_spec=environment_spec, transition_adder=True, ).as_numpy_iterator() return ReverbReplay(server, adder, data_iterator, client)

            environment_spec, extra_spec),
    )
    server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{server.port}'
    client = reverb.Client(address)
    adder = adders.NStepTransitionAdder(
        client, n_step, discount, priority_fns=priority_fns)

    # The dataset provides an interface to sample from replay.
    data_iterator = datasets.make_reverb_dataset(
        table=replay_table_name,
        server_address=address,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
    ).as_numpy_iterator()
    return ReverbReplay(server, adder, data_iterator, client=client)


def make_reverb_online_queue(
    environment_spec: specs.EnvironmentSpec,
    extra_spec: Dict[str, Any],
    max_queue_size: int,
    sequence_length: int,
    sequence_period: int,
    batch_size: int,
    replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
) -> ReverbReplay:
    """Creates a single process queue from an environment spec and extra_spec."""

def __init__(self,
             environment_spec: specs.EnvironmentSpec,
             policy_network: snt.Module,
             critic_network: snt.Module,
             observation_network: types.TensorTransformation = tf.identity,
             discount: float = 0.99,
             batch_size: int = 256,
             prefetch_size: int = 4,
             target_update_period: int = 100,
             min_replay_size: int = 1000,
             max_replay_size: int = 1000000,
             samples_per_insert: float = 32.0,
             n_step: int = 5,
             sigma: float = 0.3,
             clipping: bool = True,
             logger: loggers.Logger = None,
             counter: counting.Counter = None,
             checkpoint: bool = True,
             replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
    """Initialize the agent.

    Args:
        environment_spec: description of the actions, observations, etc.
        policy_network: the online (optimized) policy.
        critic_network: the online critic.
        observation_network: optional network to transform the observations
            before they are fed into any network.
        discount: discount to use for TD updates.
        batch_size: batch size for updates.
        prefetch_size: size to prefetch from replay.
        target_update_period: number of learner steps to perform before
            updating the target networks.
        min_replay_size: minimum replay size before updating.
        max_replay_size: maximum replay size.
        samples_per_insert: number of samples to take from replay for every
            insert that is made.
        n_step: number of steps to squash into a single transition.
        sigma: standard deviation of zero-mean, Gaussian exploration noise.
        clipping: whether to clip gradients by global norm.
        logger: logger object to be used by learner.
        counter: counter object used to keep track of steps.
        checkpoint: boolean indicating whether to checkpoint the learner.
        replay_table_name: string indicating what name to give the replay
            table.
    """
    # Create a replay server to add data to. This uses no limiter behavior in
    # order to allow the Agent interface to handle it.
    replay_table = reverb.Table(
        name=replay_table_name,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=adders.NStepTransitionAdder.signature(environment_spec))
    self._server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        priority_fns={replay_table_name: lambda x: 1.},
        client=reverb.Client(address),
        n_step=n_step,
        discount=discount)

    # The dataset provides an interface to sample from replay.
    dataset = datasets.make_reverb_dataset(
        table=replay_table_name,
        client=reverb.TFClient(address),
        environment_spec=environment_spec,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        transition_adder=True)

    # Get observation and action specs.
    act_spec = environment_spec.actions
    obs_spec = environment_spec.observations
    emb_spec = tf2_utils.create_variables(
        observation_network, [obs_spec])  # pytype: disable=wrong-arg-types

    # Make sure observation network is a Sonnet Module.
    observation_network = tf2_utils.to_sonnet_module(observation_network)

    # Create target networks.
    target_policy_network = copy.deepcopy(policy_network)
    target_critic_network = copy.deepcopy(critic_network)
    target_observation_network = copy.deepcopy(observation_network)

    # Create the behavior policy.
    behavior_network = snt.Sequential([
        observation_network,
        policy_network,
        networks.ClippedGaussian(sigma),
        networks.ClipToSpec(act_spec),
    ])

    # Create variables.
    tf2_utils.create_variables(policy_network, [emb_spec])
    tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
    tf2_utils.create_variables(target_policy_network, [emb_spec])
    tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
    tf2_utils.create_variables(target_observation_network, [obs_spec])

    # Create the actor which defines how we take actions.
    actor = actors.FeedForwardActor(behavior_network, adder=adder)

    # Create optimizers.
    policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4)
    critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4)

    # The learner updates the parameters (and initializes them).
    learner = learning.DDPGLearner(
        policy_network=policy_network,
        critic_network=critic_network,
        observation_network=observation_network,
        target_policy_network=target_policy_network,
        target_critic_network=target_critic_network,
        target_observation_network=target_observation_network,
        policy_optimizer=policy_optimizer,
        critic_optimizer=critic_optimizer,
        clipping=clipping,
        discount=discount,
        target_update_period=target_update_period,
        dataset=dataset,
        counter=counter,
        logger=logger,
        checkpoint=checkpoint,
    )

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(batch_size, min_replay_size),
        observations_per_step=float(batch_size) / samples_per_insert)

def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.Module,
    params=None,
    logger: loggers.Logger = None,
    checkpoint: bool = True,
    paths: Save_paths = None,
):
    """Initialize the agent.

    Args:
        environment_spec: description of the actions, observations, etc.
        network: the online Q network (the one being optimized).
        params: optional dictionary of hyperparameters; the defaults below are
            used when it is None. Keys include batch_size, prefetch_size,
            target_update_period, samples_per_insert, min_replay_size,
            max_replay_size, importance_sampling_exponent, priority_exponent,
            n_step, epsilon, learning_rate and discount.
        logger: logger object to be used by learner.
        checkpoint: boolean indicating whether to checkpoint the learner.
        paths: directory and experiment name used for checkpointing.
    """
    # Create a replay server to add data to. This uses no limiter behavior in
    # order to allow the Agent interface to handle it.
    if params is None:
        params = {
            'batch_size': 256,
            'prefetch_size': 4,
            'target_update_period': 100,
            'samples_per_insert': 32.0,
            'min_replay_size': 1000,
            'max_replay_size': 1000000,
            'importance_sampling_exponent': 0.2,
            'priority_exponent': 0.6,
            'n_step': 5,
            'epsilon': 0.05,
            'learning_rate': 1e-3,
            'discount': 0.99,
        }

    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(params['priority_exponent']),
        remover=reverb.selectors.Fifo(),
        max_size=params['max_replay_size'],
        rate_limiter=reverb.rate_limiters.MinSize(1))
    self._server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        client=reverb.Client(address),
        n_step=params['n_step'],
        discount=params['discount'])

    # The dataset provides an interface to sample from replay.
    replay_client = reverb.TFClient(address)
    dataset = datasets.make_reverb_dataset(
        client=replay_client,
        environment_spec=environment_spec,
        batch_size=params['batch_size'],
        prefetch_size=params['prefetch_size'],
        transition_adder=True)

    # Use constant 0.05 epsilon greedy policy by default.
    epsilon = tf.Variable(params['epsilon'], trainable=False)
    policy_network = snt.Sequential([
        network,
        lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
    ])

    # Create a target network.
    target_network = copy.deepcopy(network)

    # Ensure that we create the variables before proceeding (maybe not needed).
    # tf2_utils.create_variables(network, [environment_spec.observations])
    # tf2_utils.create_variables(target_network, [environment_spec.observations])

    # Create the actor which defines how we take actions.
    actor = actors.FeedForwardActor(policy_network, adder)

    # The learner updates the parameters (and initializes them).
    learner = learning.DQNLearner(
        network=network,
        target_network=target_network,
        discount=params['discount'],
        importance_sampling_exponent=params['importance_sampling_exponent'],
        learning_rate=params['learning_rate'],
        target_update_period=params['target_update_period'],
        dataset=dataset,
        replay_client=replay_client,
        logger=logger,
        checkpoint=checkpoint)

    if checkpoint:
        self._checkpointer = tf2_savers.Checkpointer(
            add_uid=False,
            objects_to_save=learner.state,
            directory=paths.data_dir,
            subdirectory=paths.experiment_name,
            time_delta_minutes=60.)
    else:
        self._checkpointer = None

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(params['batch_size'], params['min_replay_size']),
        observations_per_step=float(params['batch_size']) /
        params['samples_per_insert'])

def __init__(
    self,
    network: snt.Module,
    model: models.Model,
    optimizer: snt.Optimizer,
    n_step: int,
    discount: float,
    replay_capacity: int,
    num_simulations: int,
    environment_spec: specs.EnvironmentSpec,
    batch_size: int,
):
    # Create a replay server for storing transitions.
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=replay_capacity,
        rate_limiter=reverb.rate_limiters.MinSize(1))
    self._server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        client=reverb.Client(address), n_step=n_step, discount=discount)

    # The dataset provides an interface to sample from replay.
    replay_client = reverb.TFClient(address)
    action_spec: specs.DiscreteArray = environment_spec.actions
    dataset = datasets.make_reverb_dataset(
        client=replay_client,
        environment_spec=environment_spec,
        extra_spec={
            'pi': specs.Array(
                shape=(action_spec.num_values,), dtype=np.float32)
        },
        transition_adder=True)
    dataset = dataset.batch(batch_size, drop_remainder=True)

    tf2_utils.create_variables(network, [environment_spec.observations])

    # Now create the agent components: actor & learner.
    actor = acting.MCTSActor(
        environment_spec=environment_spec,
        model=model,
        network=network,
        discount=discount,
        adder=adder,
        num_simulations=num_simulations,
    )

    learner = learning.AZLearner(
        network=network,
        optimizer=optimizer,
        dataset=dataset,
        discount=discount,
    )

    # The parent class combines these together into one 'agent'.
    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=10,
        observations_per_step=1,
    )

def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.Module,
    batch_size: int = 32,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 32.0,
    min_replay_size: int = 1000,
    max_replay_size: int = 100000,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    n_step: int = 5,
    epsilon: Optional[float] = 0.05,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
    logger: loggers.Logger = None,
    max_gradient_norm: Optional[float] = None,
    expert_data: List[Dict] = None,
) -> None:
    """Initialize the agent."""
    # Create a replay server to add data to. This uses no limiter behavior
    # in order to allow the Agent interface to handle it.
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=adders.NStepTransitionAdder.signature(environment_spec))
    self._server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        client=reverb.Client(address), n_step=n_step, discount=discount)

    # Adding expert data to the replay memory:
    if expert_data is not None:
        for d in expert_data:
            adder.add_first(d["first"])
            for (action, next_ts) in d["mid"]:
                adder.add(np.int32(action), next_ts)

    # The dataset provides an interface to sample from replay.
    replay_client = reverb.TFClient(address)
    dataset = datasets.make_reverb_dataset(
        server_address=address,
        batch_size=batch_size,
        prefetch_size=prefetch_size)

    # Creating the epsilon greedy policy network:
    epsilon = tf.Variable(epsilon)
    policy_network = snt.Sequential([
        network,
        lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
    ])

    # Create a target network.
    target_network = copy.deepcopy(network)

    # Ensure that we create the variables before proceeding (maybe not needed).
    tf2_utils.create_variables(network, [environment_spec.observations])
    tf2_utils.create_variables(target_network, [environment_spec.observations])

    # Create the actor which defines how we take actions.
    actor = actors.FeedForwardActor(policy_network, adder)

    # The learner updates the parameters (and initializes them).
    learner = learning.DQNLearner(
        network=network,
        target_network=target_network,
        discount=discount,
        importance_sampling_exponent=importance_sampling_exponent,
        learning_rate=learning_rate,
        target_update_period=target_update_period,
        dataset=dataset,
        replay_client=replay_client,
        max_gradient_norm=max_gradient_norm,
        logger=logger,
    )

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(batch_size, min_replay_size),
        observations_per_step=float(batch_size) / samples_per_insert)
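
# The expert_data loop above implies a specific episode format: a dict with a
# 'first' restart timestep and a 'mid' list of (action, next_timestep) pairs.
# A hedged sketch of building that structure with dm_env helpers; the dummy
# observations, actions and rewards are placeholders only.
import dm_env
import numpy as np

observations = [np.zeros(4, dtype=np.float32) for _ in range(4)]
actions = [0, 1, 2]
rewards = [0.0, 0.0, 1.0]

episode = {'first': dm_env.restart(observations[0]), 'mid': []}
for t, action in enumerate(actions):
    if t < len(actions) - 1:
        next_ts = dm_env.transition(
            reward=rewards[t], observation=observations[t + 1])
    else:
        next_ts = dm_env.termination(
            reward=rewards[t], observation=observations[t + 1])
    episode['mid'].append((action, next_ts))

expert_data = [episode]  # matches the format consumed by __init__ above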

def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: hk.Transformed,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 32.0,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    n_step: int = 5,
    epsilon: float = 0.,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
    seed: int = 1,
):
    """Initialize the agent."""
    # Create a replay server to add data to. This uses no limiter behavior in
    # order to allow the Agent interface to handle it.
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=adders.NStepTransitionAdder.signature(
            environment_spec=environment_spec))
    self._server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        client=reverb.Client(address), n_step=n_step, discount=discount)

    # The dataset provides an interface to sample from replay.
    dataset = datasets.make_reverb_dataset(
        server_address=address,
        environment_spec=environment_spec,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        transition_adder=True)

    def policy(params: hk.Params, key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
        action_values = network.apply(params, observation)
        return rlax.epsilon_greedy(epsilon).sample(key, action_values)

    # The learner updates the parameters (and initializes them).
    learner = learning.DQNLearner(
        network=network,
        obs_spec=environment_spec.observations,
        rng=hk.PRNGSequence(seed),
        optimizer=optax.adam(learning_rate),
        discount=discount,
        importance_sampling_exponent=importance_sampling_exponent,
        target_update_period=target_update_period,
        iterator=dataset.as_numpy_iterator(),
        replay_client=reverb.Client(address),
    )

    variable_client = variable_utils.VariableClient(learner, '')

    actor = actors.FeedForwardActor(
        policy=policy,
        rng=hk.PRNGSequence(seed),
        variable_client=variable_client,
        adder=adder)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(batch_size, min_replay_size),
        observations_per_step=float(batch_size) / samples_per_insert)

def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    burn_in_length: int,
    trace_length: int,
    replay_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    batch_size: int = 32,
    prefetch_size: int = tf.data.experimental.AUTOTUNE,
    target_update_period: int = 100,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    epsilon: float = 0.01,
    learning_rate: float = 1e-3,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    samples_per_insert: float = 32.0,
    store_lstm_state: bool = True,
    max_priority_weight: float = 0.9,
    checkpoint: bool = True,
):
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1))
    self._server = reverb.Server([replay_table], port=None)
    address = f'localhost:{self._server.port}'

    sequence_length = burn_in_length + trace_length + 1

    # Component to add things into replay.
    adder = adders.SequenceAdder(
        client=reverb.Client(address),
        period=replay_period,
        sequence_length=sequence_length,
    )

    # The dataset object to learn from.
    reverb_client = reverb.TFClient(address)
    extra_spec = {
        'core_state': network.initial_state(1),
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
    dataset = datasets.make_reverb_dataset(
        client=reverb_client,
        environment_spec=environment_spec,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        extra_spec=extra_spec,
        sequence_length=sequence_length)

    target_network = copy.deepcopy(network)
    tf2_utils.create_variables(network, [environment_spec.observations])
    tf2_utils.create_variables(target_network, [environment_spec.observations])

    learner = learning.R2D2Learner(
        environment_spec=environment_spec,
        network=network,
        target_network=target_network,
        burn_in_length=burn_in_length,
        sequence_length=sequence_length,
        dataset=dataset,
        reverb_client=reverb_client,
        counter=counter,
        logger=logger,
        discount=discount,
        target_update_period=target_update_period,
        importance_sampling_exponent=importance_sampling_exponent,
        max_replay_size=max_replay_size,
        learning_rate=learning_rate,
        store_lstm_state=store_lstm_state,
        max_priority_weight=max_priority_weight,
    )

    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='r2d2_learner',
        time_delta_minutes=60,
        objects_to_save=learner.state,
        enable_checkpointing=checkpoint,
    )

    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={'network': network}, time_delta_minutes=60.)

    policy_network = snt.DeepRNN([
        network,
        lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
    ])

    actor = actors.RecurrentActor(policy_network, adder)
    observations_per_step = (
        float(replay_period * batch_size) / samples_per_insert)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=replay_period * max(batch_size, min_replay_size),
        observations_per_step=observations_per_step)