def __init__(
    self,
    actor_id,
    environment_module,
    environment_fn_name,
    environment_kwargs,
    network_module,
    network_fn_name,
    network_kwargs,
    adder_module,
    adder_fn_name,
    adder_kwargs,
    replay_server_address,
    variable_server_name,
    variable_server_address,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
):
    """Build an actor from dynamically looked-up factory functions.

    The environment, network and adder are each produced by resolving
    `<x>_fn_name` on `<x>_module` with `getattr` and calling the result with
    `<x>_kwargs`.

    Args:
      actor_id: identifier stored on the actor and used in the default
        logger label (`actor_<id>`).
      environment_module: object exposing the environment factory.
      environment_fn_name: attribute name of the environment factory.
      environment_kwargs: keyword arguments for the environment factory.
      network_module: object exposing the network factory.
      network_fn_name: attribute name of the network factory.
      network_kwargs: keyword arguments for the network factory.
      adder_module: object exposing the adder factory.
      adder_fn_name: attribute name of the adder factory.
      adder_kwargs: extra keyword arguments for the adder factory; a
        `reverb.Client` is always passed as first positional argument.
      replay_server_address: address of the Reverb server receiving data.
      variable_server_name: table name used for the variable dataset.
      variable_server_address: address of the Reverb server holding variables.
      counter: optional counter; a fresh `counting.Counter` is used if falsy.
      logger: optional logger; a default logger is created if falsy.
    """
    # Bookkeeping helpers.
    self._actor_id = actor_id
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger(f'actor_{actor_id}')

    # Environment and its spec.
    make_environment = getattr(environment_module, environment_fn_name)
    self._environment = make_environment(**environment_kwargs)
    env_spec = acme.make_environment_spec(self._environment)

    # Policy network; variables are created eagerly so they can be flattened.
    make_network = getattr(network_module, network_fn_name)
    self._network = make_network(**network_kwargs)
    tf2_utils.create_variables(self._network, [env_spec.observations])
    self._variables = tree.flatten(self._network.variables)
    self._policy = tf.function(self._network)

    # Adder that writes experience to the replay server.
    make_adder = getattr(adder_module, adder_fn_name)
    replay_client = reverb.Client(replay_server_address)
    self._adder = make_adder(replay_client, **adder_kwargs)

    # Dataset over the variable table: one float32 entry per flattened
    # variable, shaped like that variable.
    variable_client = reverb.TFClient(variable_server_address)
    flat_dtypes = [tf.float32] * len(self._variables)
    flat_shapes = [variable.shape for variable in self._variables]
    self._variable_dataset = variable_client.dataset(
        table=variable_server_name,
        dtypes=flat_dtypes,
        shapes=flat_shapes)
def __init__(
    self,
    # ---
    db_server: str,
    # ---
    table: str,
    variables: dict,
):
    """Set up a Reverb TF client and derive the variables' signature.

    Args:
      db_server: address of the Reverb server to connect to.
      table: table name, stored on the instance as `_table`.
      variables: nested structure of objects exposing `.shape` and `.dtype`.
    """

    def _spec_of(variable):
        # One TensorSpec per variable, mirroring its shape and dtype.
        return tf.TensorSpec(variable.shape, dtype=variable.dtype)

    def _dtype_of(spec):
        return spec.dtype

    self._table = table
    self._variables = variables

    # Client used to talk to the Reverb server.
    self.tf_client = reverb.TFClient(server_address=db_server)

    # Signature (and the matching dtype structure) for the variable table.
    self.signature = tf.nest.map_structure(_spec_of, self._variables)
    self.dtypes = tf.nest.map_structure(_dtype_of, self.signature)
def learner(self, replay: reverb.Client, counter: counting.Counter):
    """Build the learner and wrap it in a checkpointing runner.

    Args:
      replay: Reverb client; only its `server_address` is read here.
      counter: parent counter; wrapped in a child counter prefixed 'learner'.

    Returns:
      A `tf2_savers.CheckpointingRunner` wrapping the `R2D2Learner`.
    """
    # Online and target networks, both with variables created up front.
    online_network = self._network_factory(self._environment_spec.actions)
    target_network = copy.deepcopy(online_network)
    for net in (online_network, target_network):
        tf2_utils.create_variables(net, [self._obs_spec])

    # Replay access: a TF client for the learner plus the sampling dataset.
    server_address = replay.server_address
    reverb_client = reverb.TFClient(server_address)
    sequence_length = self._burn_in_length + self._trace_length + 1
    dataset = datasets.make_reverb_dataset(
        server_address=server_address,
        batch_size=self._batch_size,
        prefetch_size=self._prefetch_size)

    learner_counter = counting.Counter(counter, 'learner')
    learner_logger = loggers.make_default_logger(
        'learner', save_data=True, steps_key='learner_steps')

    r2d2_learner = learning.R2D2Learner(
        environment_spec=self._environment_spec,
        network=online_network,
        target_network=target_network,
        burn_in_length=self._burn_in_length,
        sequence_length=sequence_length,
        dataset=dataset,
        reverb_client=reverb_client,
        counter=learner_counter,
        logger=learner_logger,
        discount=self._discount,
        target_update_period=self._target_update_period,
        importance_sampling_exponent=self._importance_sampling_exponent,
        learning_rate=self._learning_rate,
        max_replay_size=self._max_replay_size)

    return tf2_savers.CheckpointingRunner(
        wrapped=r2d2_learner,
        time_delta_minutes=60,
        subdirectory='r2d2_learner')
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: networks.PolicyValueRNN,
    initial_state_fn: Callable[[], networks.RNNState],
    sequence_length: int,
    sequence_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    max_queue_size: int = 100000,
    batch_size: int = 16,
    learning_rate: float = 1e-3,
    entropy_cost: float = 0.01,
    baseline_cost: float = 0.5,
    seed: int = 0,
    max_abs_reward: float = np.inf,
    max_gradient_norm: float = np.inf,
):
    """Wire up a single-process IMPALA agent backed by a Reverb queue.

    Args:
      environment_spec: description of the actions, observations, etc.
      network: recurrent policy/value network shared by actor and learner.
      initial_state_fn: callable producing the network's initial state.
      sequence_length: length of the sequences written to and read from
        the queue.
      sequence_period: period passed to the sequence adder.
      counter: counter forwarded to the learner.
      logger: logger forwarded to the learner; a terminal logger is kept on
        the agent itself when this is falsy.
      discount: discount forwarded to the learner.
      max_queue_size: maximum size of the Reverb queue.
      batch_size: batch size used for sampling and for `_can_sample`.
      learning_rate: learning rate of the Adam optimizer.
      entropy_cost: entropy cost forwarded to the learner.
      baseline_cost: baseline cost forwarded to the learner.
      seed: seed of the `hk.PRNGSequence` shared by learner and actor.
      max_abs_reward: reward bound forwarded to the learner.
      max_gradient_norm: global-norm threshold for gradient clipping.
    """
    num_actions = environment_spec.actions.num_values
    self._logger = logger or loggers.TerminalLogger('agent')

    # In-process Reverb server exposing a single queue.
    queue = reverb.Table.queue(
        name=adders.DEFAULT_PRIORITY_TABLE, max_size=max_queue_size)
    self._server = reverb.Server([queue], port=None)
    self._can_sample = lambda: queue.can_sample(batch_size)
    address = f'localhost:{self._server.port}'

    # Writer side: sequences of experience go into the queue.
    adder = adders.SequenceAdder(
        client=reverb.Client(address),
        period=sequence_period,
        sequence_length=sequence_length,
    )

    # Reader side: besides the environment data every item carries the
    # recurrent state and the policy logits.
    initial_core_state = hk.transform(initial_state_fn).apply(None)
    example_logits = np.ones(shape=(num_actions,), dtype=np.float32)
    extra_spec = {
        'core_state': initial_core_state,
        'logits': example_logits,
    }
    dataset = datasets.make_reverb_dataset(
        client=reverb.TFClient(address),
        environment_spec=environment_spec,
        batch_size=batch_size,
        extra_spec=extra_spec,
        sequence_length=sequence_length)

    rng = hk.PRNGSequence(seed)

    # Clip by global norm first, then apply Adam.
    gradient_transforms = (
        optix.clip_by_global_norm(max_gradient_norm),
        optix.adam(learning_rate),
    )
    optimizer = optix.chain(*gradient_transforms)

    self._learner = learning.IMPALALearner(
        obs_spec=environment_spec.observations,
        network=network,
        initial_state_fn=initial_state_fn,
        iterator=dataset.as_numpy_iterator(),
        rng=rng,
        counter=counter,
        logger=logger,
        optimizer=optimizer,
        discount=discount,
        entropy_cost=entropy_cost,
        baseline_cost=baseline_cost,
        max_abs_reward=max_abs_reward,
    )

    # The actor pulls the 'policy' variables straight from the learner.
    variable_client = jax_variable_utils.VariableClient(
        self._learner, key='policy')
    self._actor = acting.IMPALAActor(
        network=network,
        initial_state_fn=initial_state_fn,
        rng=rng,
        adder=adder,
        variable_client=variable_client,
    )
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    burn_in_length: int,
    trace_length: int,
    replay_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    batch_size: int = 32,
    prefetch_size: int = tf.data.experimental.AUTOTUNE,
    target_update_period: int = 100,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    epsilon_init: float = 1.0,
    epsilon_final: float = 0.01,
    epsilon_schedule_timesteps: float = 20000,
    learning_rate: float = 1e-3,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    samples_per_insert: float = 32.0,
    store_lstm_state: bool = True,
    max_priority_weight: float = 0.9,
    checkpoint: bool = True,
):
    """Assemble a single-process R2D2 agent.

    Builds a prioritized Reverb table, a sequence adder, the learner, an
    exploring actor driven by `EpsilonGreedyExploration`, and a greedy
    (epsilon=0) actor stored as `_deterministic_actor`.

    Args:
      environment_spec: description of the actions, observations, etc.
      network: recurrent Q-network being optimized.
      burn_in_length: burn-in steps; part of the stored sequence length.
      trace_length: trace steps; part of the stored sequence length.
      replay_period: period of the sequence adder.
      counter: counter forwarded to the learner.
      logger: logger forwarded to the learner.
      discount: discount forwarded to the learner.
      batch_size: batch size for sampling.
      prefetch_size: prefetch size of the dataset.
      target_update_period: forwarded to the learner.
      importance_sampling_exponent: forwarded to the learner.
      priority_exponent: exponent of the prioritized sampler.
      epsilon_init: initial epsilon of the exploration schedule.
      epsilon_final: final epsilon of the exploration schedule.
      epsilon_schedule_timesteps: length of the exploration schedule.
      learning_rate: forwarded to the learner.
      min_replay_size: lower bound used when computing `min_observations`.
      max_replay_size: capacity of the replay table.
      samples_per_insert: samples drawn per insert; sets
        `observations_per_step`.
      store_lstm_state: whether the recurrent state is stored in replay.
      max_priority_weight: forwarded to the learner.
      checkpoint: accepted but not read anywhere in this constructor.
    """
    # Extra per-step data stored next to the environment transition.
    extra_spec = (
        {'core_state': tf2_utils.squeeze_batch_dim(network.initial_state(1))}
        if store_lstm_state else ())

    table_signature = adders.SequenceAdder.signature(
        environment_spec, extra_spec)
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
        signature=table_signature)
    self._server = reverb.Server([replay_table], port=None)
    address = f'localhost:{self._server.port}'

    sequence_length = burn_in_length + trace_length + 1

    # Writer side of replay.
    self._adder = adders.SequenceAdder(
        client=reverb.Client(address),
        period=replay_period,
        sequence_length=sequence_length,
    )

    # Reader side of replay.
    dataset = make_reverb_dataset(
        server_address=address,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        sequence_length=sequence_length)

    # Copy the network before any variables exist, then create them on both.
    target_network = copy.deepcopy(network)
    for net in (network, target_network):
        tf2_utils.create_variables(net, [environment_spec.observations])

    learner = learning.R2D2Learner(
        environment_spec=environment_spec,
        network=network,
        target_network=target_network,
        burn_in_length=burn_in_length,
        sequence_length=sequence_length,
        dataset=dataset,
        reverb_client=reverb.TFClient(address),
        counter=counter,
        logger=logger,
        discount=discount,
        target_update_period=target_update_period,
        importance_sampling_exponent=importance_sampling_exponent,
        max_replay_size=max_replay_size,
        learning_rate=learning_rate,
        store_lstm_state=store_lstm_state,
        max_priority_weight=max_priority_weight,
    )

    self._saver = tf2_savers.Saver(learner.state)

    # Exploring behaviour policy.
    exploration = EpsilonGreedyExploration(
        epsilon_init=epsilon_init,
        epsilon_final=epsilon_final,
        epsilon_schedule_timesteps=epsilon_schedule_timesteps)
    policy_network = snt.DeepRNN([network, exploration])
    actor = actors.RecurrentActor(
        policy_network, self._adder, store_recurrent_state=store_lstm_state)

    # Greedy policy: epsilon fixed at zero.
    def _greedy_action(qs):
        return trfl.epsilon_greedy(qs, epsilon=0.0).sample()

    max_Q_network = snt.DeepRNN([network, _greedy_action])
    self._deterministic_actor = actors.RecurrentActor(
        max_Q_network, self._adder, store_recurrent_state=store_lstm_state)

    min_observations = replay_period * max(batch_size, min_replay_size)
    observations_per_step = (
        float(replay_period * batch_size) / samples_per_insert)
    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=min_observations,
        observations_per_step=observations_per_step)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    policy_network: snt.Module,
    critic_network: snt.Module,
    observation_network: types.TensorTransformation = tf.identity,
    discount: float = 0.99,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_policy_update_period: int = 100,
    target_critic_update_period: int = 100,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    samples_per_insert: float = 32.0,
    policy_loss_module: snt.Module = None,
    policy_optimizer: snt.Optimizer = None,
    critic_optimizer: snt.Optimizer = None,
    n_step: int = 5,
    num_samples: int = 20,
    clipping: bool = True,
    logger: loggers.Logger = None,
    counter: counting.Counter = None,
    checkpoint: bool = True,
    replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
):
    """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      observation_network: optional network to transform the observations
        before they are fed into any network.
      discount: discount to use for TD updates.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_policy_update_period: number of updates to perform before
        updating the target policy network.
      target_critic_update_period: number of updates to perform before
        updating the target critic network.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      samples_per_insert: number of samples to take from replay for every
        insert that is made.
      policy_loss_module: configured MPO loss function for the policy
        optimization; defaults to sensible values on the control suite.
        See `acme/tf/losses/mpo.py` for more details.
      policy_optimizer: optimizer to be used on the policy.
      critic_optimizer: optimizer to be used on the critic.
      n_step: number of steps to squash into a single transition.
      num_samples: number of actions to sample when doing a Monte Carlo
        integration with respect to the policy.
      clipping: whether to clip gradients by global norm.
      logger: logging object used to write to logs.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner.
      replay_table_name: string indicating what name to give the replay
        table. The table, the adder and the dataset all use this name.
    """
    # Create a replay server to add data to. The table must be registered
    # under `replay_table_name`, because that is the table the dataset below
    # samples from; creating it under the default name would leave the
    # dataset pointing at a non-existent table for any custom name.
    replay_table = reverb.Table(
        name=replay_table_name,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
        signature=adders.NStepTransitionAdder.signature(environment_spec))
    self._server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay. It is pointed at
    # the same table with a constant priority, since sampling is uniform.
    address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        client=reverb.Client(address),
        priority_fns={replay_table_name: lambda x: 1.},
        n_step=n_step,
        discount=discount)

    # The dataset object to learn from.
    dataset = datasets.make_reverb_dataset(
        table=replay_table_name,
        client=reverb.TFClient(address),
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        environment_spec=environment_spec,
        transition_adder=True)

    # Make sure observation network is a Sonnet Module.
    observation_network = tf2_utils.to_sonnet_module(observation_network)

    # Create target networks before creating online/target network variables.
    target_policy_network = copy.deepcopy(policy_network)
    target_critic_network = copy.deepcopy(critic_network)
    target_observation_network = copy.deepcopy(observation_network)

    # Get observation and action specs.
    act_spec = environment_spec.actions
    obs_spec = environment_spec.observations
    emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

    # Create the behavior policy.
    behavior_network = snt.Sequential([
        observation_network,
        policy_network,
        networks.StochasticSamplingHead(),
    ])

    # Create variables.
    tf2_utils.create_variables(policy_network, [emb_spec])
    tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
    tf2_utils.create_variables(target_policy_network, [emb_spec])
    tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
    tf2_utils.create_variables(target_observation_network, [obs_spec])

    # Create the actor which defines how we take actions.
    actor = actors.FeedForwardActor(
        policy_network=behavior_network, adder=adder)

    # Create optimizers.
    policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
    critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)

    # The learner updates the parameters (and initializes them).
    learner = learning.MPOLearner(
        policy_network=policy_network,
        critic_network=critic_network,
        observation_network=observation_network,
        target_policy_network=target_policy_network,
        target_critic_network=target_critic_network,
        target_observation_network=target_observation_network,
        policy_loss_module=policy_loss_module,
        policy_optimizer=policy_optimizer,
        critic_optimizer=critic_optimizer,
        clipping=clipping,
        discount=discount,
        num_samples=num_samples,
        target_policy_update_period=target_policy_update_period,
        target_critic_update_period=target_critic_update_period,
        dataset=dataset,
        logger=logger,
        counter=counter,
        checkpoint=checkpoint)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(batch_size, min_replay_size),
        observations_per_step=float(batch_size) / samples_per_insert)
# Initalize Reverb Server server = reverb.Server(tables=[ reverb.Table(name='ReplayBuffer', sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), max_size=buffer_size, rate_limiter=reverb.rate_limiters.MinSize(1)), reverb.Table(name='PrioritizedReplayBuffer', sampler=reverb.selectors.Prioritized(alpha), remover=reverb.selectors.Fifo(), max_size=buffer_size, rate_limiter=reverb.rate_limiters.MinSize(1)) ]) client = reverb.Client(f"localhost:{server.port}") tf_client = reverb.TFClient(f"localhost:{server.port}") # Helper Function def env(n): e = { "obs": np.ones((n, obs_shape)), "act": np.zeros((n, act_shape)), "next_obs": np.ones((n, obs_shape)), "rew": np.zeros(n), "done": np.zeros(n) } return e def add_client(_rb, table):
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.Module,
    demonstration_dataset: tf.data.Dataset,
    demonstration_ratio: float,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 32.0,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    importance_sampling_exponent: float = 0.2,
    n_step: int = 5,
    epsilon: tf.Tensor = None,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
):
    """Initialize the agent.

    Replay samples and demonstration transitions are mixed element-wise
    before batching, so each batch contains both sources.

    Args:
      environment_spec: description of the actions, observations, etc.
      network: the online Q network (the one being optimized).
      demonstration_dataset: tf.data.Dataset producing (timestep, action)
        tuples containing full episodes.
      demonstration_ratio: ratio of transitions coming from demonstrations.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before
        updating the target networks.
      samples_per_insert: number of samples to take from replay for every
        insert that is made.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      importance_sampling_exponent: power to which importance weights are
        raised before normalizing.
      n_step: number of steps to squash into a single transition.
      epsilon: probability of taking a random action; a constant 0.05 is
        used when this is None.
      learning_rate: learning rate for the q-network update.
      discount: discount to use for TD updates.
    """
    # Replay server. The rate limiter is kept minimal; pacing is left to the
    # Agent interface.
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(1))
    self._server = reverb.Server([replay_table], port=None)
    address = f'localhost:{self._server.port}'

    # Writer: n-step transitions go into replay.
    adder = adders.NStepTransitionAdder(
        client=reverb.Client(address), n_step=n_step, discount=discount)

    # Reader: unbatched stream of replay transitions.
    replay_client = reverb.TFClient(address)
    replay_stream = datasets.make_reverb_dataset(
        client=replay_client,
        environment_spec=environment_spec,
        transition_adder=True)

    # Demonstration episodes are converted to n-step transitions and then
    # interleaved with replay according to `demonstration_ratio`.
    to_transitions = functools.partial(
        _n_step_transition_from_episode, n_step=n_step, discount=discount)
    demo_stream = demonstration_dataset.map(to_transitions)
    mixing_weights = [1 - demonstration_ratio, demonstration_ratio]
    dataset = tf.data.experimental.sample_from_datasets(
        [replay_stream, demo_stream], mixing_weights)

    # Batch and prefetch the mixed stream.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(prefetch_size)

    # Constant 0.05 epsilon-greedy behaviour unless the caller supplies one.
    if epsilon is None:
        epsilon = tf.Variable(0.05, trainable=False)

    def _epsilon_greedy_action(q):
        return trfl.epsilon_greedy(q, epsilon=epsilon).sample()

    policy_network = snt.Sequential([network, _epsilon_greedy_action])

    # Target network is copied first; variables are then created on both.
    target_network = copy.deepcopy(network)
    for net in (network, target_network):
        tf2_utils.create_variables(net, [environment_spec.observations])

    actor = actors.FeedForwardActor(policy_network, adder)

    learner = dqn.DQNLearner(
        network=network,
        target_network=target_network,
        discount=discount,
        importance_sampling_exponent=importance_sampling_exponent,
        learning_rate=learning_rate,
        target_update_period=target_update_period,
        dataset=dataset,
        replay_client=replay_client)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(batch_size, min_replay_size),
        observations_per_step=float(batch_size) / samples_per_insert)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.Module,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 32.0,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    n_step: int = 5,
    epsilon: tf.Tensor = None,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
    cql_alpha: float = 1.,
    logger: loggers.Logger = None,
    counter: counting.Counter = None,
    checkpoint_subpath: str = '~/acme/',
):
    """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      network: the online Q network (the one being optimized).
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before
        updating the target networks.
      samples_per_insert: number of samples to take from replay for every
        insert that is made.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      importance_sampling_exponent: power to which importance weights are
        raised before normalizing.
      priority_exponent: exponent used in prioritized sampling.
      n_step: number of steps to squash into a single transition.
      epsilon: probability of taking a random action; a constant 0.05 is
        used when this is None.
      learning_rate: learning rate for the q-network update.
      discount: discount to use for TD updates.
      cql_alpha: forwarded unchanged to `CQLLearner`.
      logger: logger object to be used by learner.
      counter: counter object forwarded to the learner.
      checkpoint_subpath: directory for the checkpoint.
    """
    # Replay server. The rate limiter is kept minimal; pacing is left to the
    # Agent interface.
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=adders.NStepTransitionAdder.signature(environment_spec))
    self._server = reverb.Server([replay_table], port=None)
    address = f'localhost:{self._server.port}'

    # Writer: n-step transitions go into replay.
    adder = adders.NStepTransitionAdder(
        client=reverb.Client(address), n_step=n_step, discount=discount)

    # Reader: batched, prefetched dataset over the same table.
    replay_client = reverb.TFClient(address)
    dataset = datasets.make_reverb_dataset(
        client=replay_client,
        environment_spec=environment_spec,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        transition_adder=True)

    # Constant 0.05 epsilon-greedy behaviour unless the caller supplies one.
    if epsilon is None:
        epsilon = tf.Variable(0.05, trainable=False)

    def _epsilon_greedy_action(q):
        return trfl.epsilon_greedy(q, epsilon=epsilon).sample()

    policy_network = snt.Sequential([network, _epsilon_greedy_action])

    # A target copy is made and initialized here, although it is not handed
    # to the learner below.
    target_network = copy.deepcopy(network)
    for net in (network, target_network):
        tf2_utils.create_variables(net, [environment_spec.observations])

    actor = actors.FeedForwardActor(policy_network, adder)

    learner = CQLLearner(
        network=network,
        discount=discount,
        importance_sampling_exponent=importance_sampling_exponent,
        learning_rate=learning_rate,
        cql_alpha=cql_alpha,
        target_update_period=target_update_period,
        dataset=dataset,
        replay_client=replay_client,
        logger=logger,
        counter=counter,
        checkpoint_subpath=checkpoint_subpath)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(batch_size, min_replay_size),
        observations_per_step=float(batch_size) / samples_per_insert)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: networks.QNetwork,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 32.0,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    n_step: int = 5,
    epsilon: float = 0.,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
):
    """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      network: Q-network function, wrapped with `hk.transform` for acting.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: forwarded to the learner.
      samples_per_insert: samples drawn per insert; sets
        `observations_per_step`.
      min_replay_size: lower bound used when computing `min_observations`.
      max_replay_size: capacity of the replay table.
      importance_sampling_exponent: forwarded to the learner.
      priority_exponent: exponent of the prioritized sampler.
      n_step: number of steps to squash into a single transition.
      epsilon: epsilon of the epsilon-greedy behaviour policy.
      learning_rate: learning rate of the Adam optimizer.
      discount: discount used by the adder and the learner.
    """
    # Replay server. The rate limiter is kept minimal; pacing is left to the
    # Agent interface.
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(1))
    self._server = reverb.Server([replay_table], port=None)
    address = f'localhost:{self._server.port}'

    # Writer: n-step transitions go into replay.
    adder = adders.NStepTransitionAdder(
        client=reverb.Client(address), n_step=n_step, discount=discount)

    # Reader: batched, prefetched dataset over the same table.
    dataset = datasets.make_reverb_dataset(
        client=reverb.TFClient(address),
        environment_spec=environment_spec,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        transition_adder=True)

    def policy(params: hk.Params, key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
        # Epsilon-greedy action selection on top of the Q-values.
        q_values = hk.transform(network).apply(params, observation)
        sampler = rlax.epsilon_greedy(epsilon)
        return sampler.sample(key, q_values)

    # The learner owns (and initializes) the parameters.
    learner = learning.DQNLearner(
        network=network,
        obs_spec=environment_spec.observations,
        rng=hk.PRNGSequence(1),
        optimizer=optix.adam(learning_rate),
        discount=discount,
        importance_sampling_exponent=importance_sampling_exponent,
        target_update_period=target_update_period,
        iterator=dataset.as_numpy_iterator(),
        replay_client=reverb.Client(address),
    )

    # The actor reads its parameters from the learner through this client.
    variable_client = variable_utils.VariableClient(learner, 'foo')
    actor = actors.FeedForwardActor(
        policy=policy,
        rng=hk.PRNGSequence(1),
        variable_client=variable_client,
        adder=adder)

    min_observations = max(batch_size, min_replay_size)
    observations_per_step = float(batch_size) / samples_per_insert
    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=min_observations,
        observations_per_step=observations_per_step)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    sequence_length: int,
    sequence_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    max_queue_size: int = 100000,
    batch_size: int = 16,
    learning_rate: float = 1e-3,
    entropy_cost: float = 0.01,
    baseline_cost: float = 0.5,
    max_abs_reward: Optional[float] = None,
    max_gradient_norm: Optional[float] = None,
):
    """Wire up a single-process IMPALA agent backed by a Reverb queue.

    Args:
      environment_spec: description of the actions, observations, etc.
      network: recurrent network shared by the actor and the learner.
      sequence_length: length of the sequences written to and read from
        the queue.
      sequence_period: period passed to the sequence adder.
      counter: counter forwarded to the learner.
      logger: logger forwarded to the learner; a terminal logger is kept on
        the agent itself when this is falsy.
      discount: discount forwarded to the learner.
      max_queue_size: maximum size of the Reverb queue.
      batch_size: batch size used for sampling and for `_can_sample`.
      learning_rate: forwarded to the learner.
      entropy_cost: forwarded to the learner.
      baseline_cost: forwarded to the learner.
      max_abs_reward: forwarded to the learner.
      max_gradient_norm: forwarded to the learner.
    """
    num_actions = environment_spec.actions.num_values
    self._logger = logger or loggers.TerminalLogger('agent')

    # In-process Reverb server exposing a single queue.
    queue = reverb.Table.queue(
        name=adders.DEFAULT_PRIORITY_TABLE, max_size=max_queue_size)
    self._server = reverb.Server([queue], port=None)
    self._can_sample = lambda: queue.can_sample(batch_size)
    address = f'localhost:{self._server.port}'

    # Writer side: sequences of experience go into the queue.
    adder = adders.SequenceAdder(
        client=reverb.Client(address),
        period=sequence_period,
        sequence_length=sequence_length,
    )

    # Extras stored with every step. They are built with a leading batch
    # dimension of one, which is squeezed away right after.
    batched_extras = {
        'core_state': network.initial_state(1),
        'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32),
    }
    extra_spec = tf2_utils.squeeze_batch_dim(batched_extras)

    # Reader side of the queue.
    dataset = datasets.make_reverb_dataset(
        client=reverb.TFClient(address),
        environment_spec=environment_spec,
        batch_size=batch_size,
        extra_spec=extra_spec,
        sequence_length=sequence_length)

    tf2_utils.create_variables(network, [environment_spec.observations])

    self._actor = acting.IMPALAActor(network, adder)
    self._learner = learning.IMPALALearner(
        environment_spec=environment_spec,
        network=network,
        dataset=dataset,
        counter=counter,
        logger=logger,
        discount=discount,
        learning_rate=learning_rate,
        entropy_cost=entropy_cost,
        baseline_cost=baseline_cost,
        max_gradient_norm=max_gradient_norm,
        max_abs_reward=max_abs_reward,
    )
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.Module,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 32.0,
    min_replay_size: int = 20,
    max_replay_size: int = 1000000,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    n_step: int = 5,
    epsilon_init: float = 1.0,
    epsilon_final: float = 0.01,
    epsilon_schedule_timesteps: int = 20000,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
    max_gradient_norm: Optional[float] = None,
    logger: loggers.Logger = None,
):
    """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      network: the online Q network (the one being optimized).
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before
        updating the target networks.
      samples_per_insert: number of samples to take from replay for every
        insert that is made.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      importance_sampling_exponent: power to which importance weights are
        raised before normalizing (beta).
        See https://arxiv.org/pdf/1710.02298.pdf
      priority_exponent: exponent used in prioritized sampling (omega).
        See https://arxiv.org/pdf/1710.02298.pdf
      n_step: number of steps to squash into a single transition.
      epsilon_init: initial epsilon value (probability of taking a random
        action).
      epsilon_final: final epsilon value (probability of taking a random
        action).
      epsilon_schedule_timesteps: timesteps to decay epsilon from
        'epsilon_init' to 'epsilon_final'.
      learning_rate: learning rate for the q-network update.
      discount: discount to use for TD updates.
      max_gradient_norm: used for gradient clipping.
      logger: logger object to be used by learner.
    """
    # Replay server. The rate limiter is kept minimal; pacing is left to the
    # Agent interface.
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=adders.NStepTransitionAdder.signature(environment_spec))
    self._server = reverb.Server([replay_table], port=None)
    address = f'localhost:{self._server.port}'

    # Writer: n-step transitions go into replay.
    self._adder = adders.NStepTransitionAdder(
        client=reverb.Client(address), n_step=n_step, discount=discount)

    # Reader: TF client for the learner plus the sampling dataset.
    replay_client = reverb.TFClient(address)
    dataset = make_reverb_dataset(
        server_address=address,
        batch_size=batch_size,
        prefetch_size=prefetch_size)

    # Exploring behaviour policy with a scheduled epsilon.
    exploration = EpsilonGreedyExploration(
        epsilon_init=epsilon_init,
        epsilon_final=epsilon_final,
        epsilon_schedule_timesteps=epsilon_schedule_timesteps)
    policy_network = snt.Sequential([network, exploration])

    # Target network is copied first; variables are then created on both.
    target_network = copy.deepcopy(network)
    for net in (network, target_network):
        tf2_utils.create_variables(net, [environment_spec.observations])

    actor = actors_tf2.FeedForwardActor(policy_network, self._adder)

    # The learner updates the parameters; its own checkpointing is off and
    # a saver over the learner state is kept on the agent instead.
    learner = learning.DQNLearner(
        network=network,
        target_network=target_network,
        discount=discount,
        importance_sampling_exponent=importance_sampling_exponent,
        learning_rate=learning_rate,
        target_update_period=target_update_period,
        dataset=dataset,
        replay_client=replay_client,
        max_gradient_norm=max_gradient_norm,
        logger=logger,
        checkpoint=False)
    self._saver = tf2_savers.Saver(learner.state)

    # Greedy (max-Q) actor: epsilon fixed at zero, no adder attached.
    def _greedy_action(q):
        return trfl.epsilon_greedy(q, epsilon=0.0).sample()

    max_Q_network = snt.Sequential([network, _greedy_action])
    self._deterministic_actor = actors_tf2.FeedForwardActor(max_Q_network)

    min_observations = max(batch_size, min_replay_size)
    observations_per_step = float(batch_size) / samples_per_insert
    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=min_observations,
        observations_per_step=observations_per_step)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.Module,
    params=None,
    logger: loggers.Logger = None,
    checkpoint: bool = True,
    paths: Save_paths = None,
):
    """Initialize the agent.

    Args:
        environment_spec: description of the actions, observations, etc.
        network: the online Q network (the one being optimized).
        params: optional dict of hyperparameters. Any key that is missing
            (or the whole dict, when None) falls back to the default shown
            in parentheses:
              batch_size (256): batch size for updates.
              prefetch_size (4): size to prefetch from replay.
              target_update_period (100): number of learner steps to
                perform before updating the target networks.
              samples_per_insert (32.0): number of samples to take from
                replay for every insert that is made.
              min_replay_size (1000): minimum replay size before updating.
              max_replay_size (1000000): maximum replay size.
              importance_sampling_exponent (0.2): power to which importance
                weights are raised before normalizing.
              priority_exponent (0.6): exponent used in prioritized
                sampling.
              n_step (5): number of steps to squash into a single
                transition.
              epsilon (0.05): probability of taking a random action.
              learning_rate (1e-3): learning rate for the q-network update.
              discount (0.99): discount to use for TD updates.
        logger: logger object to be used by learner.
        checkpoint: boolean indicating whether to checkpoint the learner.
        paths: object exposing `data_dir` and `experiment_name`; they are
            used as the checkpoint directory and subdirectory. Required
            when `checkpoint` is True.

    Raises:
        ValueError: if `checkpoint` is True but `paths` is None.
    """
    default_params = {
        'batch_size': 256,
        'prefetch_size': 4,
        'target_update_period': 100,
        'samples_per_insert': 32.0,
        'min_replay_size': 1000,
        'max_replay_size': 1000000,
        'importance_sampling_exponent': 0.2,
        'priority_exponent': 0.6,
        'n_step': 5,
        'epsilon': 0.05,
        'learning_rate': 1e-3,
        'discount': 0.99,
    }
    # Overlay the caller's values on the defaults so a partial dict works
    # and the caller's dict is never mutated.
    params = {**default_params, **(params or {})}

    # Fail before any server is started instead of hitting an
    # AttributeError on `paths.data_dir` further down.
    if checkpoint and paths is None:
        raise ValueError(
            '`paths` must be provided when `checkpoint` is True.')

    # Create a replay server to add data to. This uses no limiter behavior
    # in order to allow the Agent interface to handle it.
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(params['priority_exponent']),
        remover=reverb.selectors.Fifo(),
        max_size=params['max_replay_size'],
        rate_limiter=reverb.rate_limiters.MinSize(1))
    self._server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        client=reverb.Client(address),
        n_step=params['n_step'],
        discount=params['discount'])

    # The dataset provides an interface to sample from replay.
    replay_client = reverb.TFClient(address)
    dataset = datasets.make_reverb_dataset(
        client=replay_client,
        environment_spec=environment_spec,
        batch_size=params['batch_size'],
        prefetch_size=params['prefetch_size'],
        transition_adder=True)

    # Epsilon-greedy behaviour policy with a constant, non-trainable epsilon.
    epsilon = tf.Variable(params['epsilon'], trainable=False)
    policy_network = snt.Sequential([
        network,
        lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
    ])

    # Create a target network.
    target_network = copy.deepcopy(network)

    # NOTE(review): the tf2_utils.create_variables(...) calls for `network`
    # and `target_network` were commented out in the original, so this
    # constructor does not create variables itself -- presumably the caller
    # passes an already-built network; confirm.

    # Create the actor which defines how we take actions.
    actor = actors.FeedForwardActor(policy_network, adder)

    # The learner updates the parameters (and initializes them).
    learner = learning.DQNLearner(
        network=network,
        target_network=target_network,
        discount=params['discount'],
        importance_sampling_exponent=params['importance_sampling_exponent'],
        learning_rate=params['learning_rate'],
        target_update_period=params['target_update_period'],
        dataset=dataset,
        replay_client=replay_client,
        logger=logger,
        checkpoint=checkpoint)

    if checkpoint:
        self._checkpointer = tf2_savers.Checkpointer(
            add_uid=False,
            objects_to_save=learner.state,
            directory=paths.data_dir,
            subdirectory=paths.experiment_name,
            time_delta_minutes=60.)
    else:
        self._checkpointer = None

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(params['batch_size'],
                             params['min_replay_size']),
        observations_per_step=(float(params['batch_size']) /
                               params['samples_per_insert']))
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.Module,
    batch_size: int = 32,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 32.0,
    min_replay_size: int = 1000,
    max_replay_size: int = 100000,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    n_step: int = 5,
    epsilon: Optional[float] = 0.05,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
    logger: loggers.Logger = None,
    max_gradient_norm: Optional[float] = None,
    expert_data: List[Dict] = None,
) -> None:
    """Initialize the agent.

    Args:
        environment_spec: description of the actions, observations, etc.
        network: the online Q network (the one being optimized).
        batch_size: batch size for updates.
        prefetch_size: size to prefetch from replay.
        target_update_period: forwarded to the learner as its target
            network update period.
        samples_per_insert: number of samples to take from replay for every
            insert that is made.
        min_replay_size: minimum replay size before updating.
        max_replay_size: maximum replay size.
        importance_sampling_exponent: forwarded to the learner.
        priority_exponent: exponent used in prioritized sampling.
        n_step: number of steps to squash into a single transition.
        epsilon: probability of taking a random action. None selects the
            default of 0.05.
        learning_rate: forwarded to the learner.
        discount: discount used by the adder and the learner.
        logger: logger object to be used by learner.
        max_gradient_norm: forwarded to the learner.
        expert_data: optional list of dicts. For each dict, `d["first"]` is
            passed to `adder.add_first` and every `(action, next_timestep)`
            pair in `d["mid"]` is passed to `adder.add`, so the
            demonstrations are in replay before acting starts.
    """
    # Create a replay server to add data to. This uses no limiter behavior
    # in order to allow the Agent interface to handle it.
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=adders.NStepTransitionAdder.signature(environment_spec))
    self._server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        client=reverb.Client(address), n_step=n_step, discount=discount)

    # Adding expert data to the replay memory.
    # NOTE(review): presumably every expert episode ends on a terminal
    # timestep so the adder is left clean for the actor -- confirm.
    if expert_data is not None:
        for episode in expert_data:
            adder.add_first(episode["first"])
            for action, next_ts in episode["mid"]:
                adder.add(np.int32(action), next_ts)

    # The dataset provides an interface to sample from replay.
    replay_client = reverb.TFClient(address)
    dataset = datasets.make_reverb_dataset(
        server_address=address,
        batch_size=batch_size,
        prefetch_size=prefetch_size)

    # Creating the epsilon greedy policy network. The signature allows
    # epsilon=None, which tf.Variable cannot wrap, so fall back to the
    # documented default. Epsilon is a constant, hence non-trainable.
    if epsilon is None:
        epsilon = 0.05
    epsilon = tf.Variable(epsilon, trainable=False)
    policy_network = snt.Sequential([
        network,
        lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
    ])

    # Create a target network.
    target_network = copy.deepcopy(network)

    # Ensure that we create the variables before proceeding (maybe not
    # needed).
    tf2_utils.create_variables(network, [environment_spec.observations])
    tf2_utils.create_variables(target_network,
                               [environment_spec.observations])

    # Create the actor which defines how we take actions.
    actor = actors.FeedForwardActor(policy_network, adder)

    # The learner updates the parameters (and initializes them).
    learner = learning.DQNLearner(
        network=network,
        target_network=target_network,
        discount=discount,
        importance_sampling_exponent=importance_sampling_exponent,
        learning_rate=learning_rate,
        target_update_period=target_update_period,
        dataset=dataset,
        replay_client=replay_client,
        max_gradient_norm=max_gradient_norm,
        logger=logger,
    )

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(batch_size, min_replay_size),
        observations_per_step=float(batch_size) / samples_per_insert)
def __init__(self,
             environment_spec: specs.EnvironmentSpec,
             network: snt.RNNCore,
             target_network: snt.RNNCore,
             burn_in_length: int,
             trace_length: int,
             replay_period: int,
             demonstration_dataset: tf.data.Dataset,
             demonstration_ratio: float,
             counter: counting.Counter = None,
             logger: loggers.Logger = None,
             discount: float = 0.99,
             batch_size: int = 32,
             target_update_period: int = 100,
             importance_sampling_exponent: float = 0.2,
             epsilon: float = 0.01,
             learning_rate: float = 1e-3,
             log_to_bigtable: bool = False,
             log_name: str = 'agent',
             checkpoint: bool = True,
             min_replay_size: int = 1000,
             max_replay_size: int = 1000000,
             samples_per_insert: float = 32.0):
    """Build a recurrent agent that learns from replay mixed with demos.

    A local Reverb server with one uniformly-sampled table stores the
    agent's own sequences. Samples from it are interleaved with sequences
    derived from `demonstration_dataset`; `demonstration_ratio` is the
    sampling weight of the demonstrations and `1 - demonstration_ratio`
    that of replay. The mix is batched and fed to an `R2D2Learner`, while a
    `RecurrentActor` with an epsilon-greedy policy writes new sequences.

    `log_to_bigtable` and `log_name` are accepted but not used in this
    constructor.
    """
    # Recurrent state stored next to each step; batch dimension removed.
    core_state_spec = tf2_utils.squeeze_batch_dim({
        'core_state': network.initial_state(1),
    })

    table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
        signature=adders.SequenceAdder.signature(environment_spec,
                                                 core_state_spec))
    self._server = reverb.Server([table], port=None)
    server_address = f'localhost:{self._server.port}'

    sequence_length = burn_in_length + trace_length + 1

    # Shared by the adder and by the demonstration conversion below.
    sequence_options = {
        'period': replay_period,
        'sequence_length': sequence_length,
    }

    # Component to add things into replay.
    adder = adders.SequenceAdder(
        client=reverb.Client(server_address), **sequence_options)

    # The agent's own experience, read back as whole sequences.
    replay_dataset = datasets.make_reverb_dataset(
        server_address=server_address, sequence_length=sequence_length)

    # Turn demonstration episodes into sequences of the same layout.
    to_sequences = functools.partial(
        _sequence_from_episode,
        extra_spec=core_state_spec,
        **sequence_options)
    demo_dataset = demonstration_dataset.map(to_sequences)

    mixed_dataset = tf.data.experimental.sample_from_datasets(
        [replay_dataset, demo_dataset],
        [1 - demonstration_ratio, demonstration_ratio])

    # Batch and prefetch.
    mixed_dataset = mixed_dataset.batch(batch_size, drop_remainder=True)
    mixed_dataset = mixed_dataset.prefetch(tf.data.experimental.AUTOTUNE)

    observation_spec = [environment_spec.observations]
    tf2_utils.create_variables(network, observation_spec)
    tf2_utils.create_variables(target_network, observation_spec)

    learner = learning.R2D2Learner(
        environment_spec=environment_spec,
        network=network,
        target_network=target_network,
        burn_in_length=burn_in_length,
        dataset=mixed_dataset,
        reverb_client=reverb.TFClient(server_address),
        counter=counter,
        logger=logger,
        sequence_length=sequence_length,
        discount=discount,
        target_update_period=target_update_period,
        importance_sampling_exponent=importance_sampling_exponent,
        max_replay_size=max_replay_size,
        learning_rate=learning_rate,
        store_lstm_state=False,
    )

    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='r2d2_learner',
        time_delta_minutes=60,
        objects_to_save=learner.state,
        enable_checkpointing=checkpoint,
    )
    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={'network': network}, time_delta_minutes=60.)

    # Behaviour policy: epsilon-greedy over the network's Q-values.
    behavior_network = snt.DeepRNN([
        network,
        lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
    ])
    actor = actors.RecurrentActor(behavior_network, adder)

    min_observations = replay_period * max(batch_size, min_replay_size)
    observations_per_step = (
        float(replay_period * batch_size) / samples_per_insert)
    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=min_observations,
        observations_per_step=observations_per_step)
def _reverb_tf_client(port):
    """Return a `reverb.TFClient` pointed at `localhost:<port>`."""
    address = f'localhost:{port}'
    return reverb.TFClient(address)
def __init__(
    self,
    network: snt.Module,
    model: models.Model,
    optimizer: snt.Optimizer,
    n_step: int,
    discount: float,
    replay_capacity: int,
    num_simulations: int,
    environment_spec: specs.EnvironmentSpec,
    batch_size: int,
):
    """Assemble an MCTS actor and an AZ learner around a local replay server.

    Args:
        network: network shared by the actor and the learner.
        model: environment model handed to the MCTS actor.
        optimizer: optimizer handed to the learner.
        n_step: number of steps the adder squashes into one transition.
        discount: discount used by the adder, the actor and the learner.
        replay_capacity: maximum number of items in the replay table.
        num_simulations: forwarded to the MCTS actor.
        environment_spec: description of the actions, observations, etc.
        batch_size: size of the batches drawn from replay.
    """
    # Replay storage: uniform sampling, oldest items removed first.
    table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=replay_capacity,
        rate_limiter=reverb.rate_limiters.MinSize(1))
    self._server = reverb.Server([table], port=None)

    # Writer side: n-step transitions go into the table above.
    server_address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        client=reverb.Client(server_address),
        n_step=n_step,
        discount=discount)

    # Reader side: each transition carries a 'pi' extra with one float32
    # entry per discrete action.
    tf_client = reverb.TFClient(server_address)
    action_spec: specs.DiscreteArray = environment_spec.actions
    pi_spec = specs.Array(shape=(action_spec.num_values, ), dtype=np.float32)
    dataset = datasets.make_reverb_dataset(
        client=tf_client,
        environment_spec=environment_spec,
        extra_spec={'pi': pi_spec},
        transition_adder=True)
    dataset = dataset.batch(batch_size, drop_remainder=True)

    tf2_utils.create_variables(network, [environment_spec.observations])

    # Now create the agent components: actor & learner.
    actor = acting.MCTSActor(
        environment_spec=environment_spec,
        model=model,
        network=network,
        discount=discount,
        adder=adder,
        num_simulations=num_simulations,
    )
    learner = learning.AZLearner(
        network=network,
        optimizer=optimizer,
        dataset=dataset,
        discount=discount,
    )

    # The parent class combines these together into one 'agent'.
    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=10,
        observations_per_step=1,
    )
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    target_network: snt.RNNCore,
    burn_in_length: int,
    trace_length: int,
    replay_period: int,
    demonstration_generator: iter,
    demonstration_ratio: float,
    model_directory: str,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    batch_size: int = 32,
    target_update_period: int = 100,
    importance_sampling_exponent: float = 0.2,
    epsilon: float = 0.01,
    learning_rate: float = 1e-3,
    log_to_bigtable: bool = False,
    log_name: str = 'agent',
    checkpoint: bool = True,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    samples_per_insert: float = 32.0,
):
    """Build a recurrent agent with separate replay and demonstration tables.

    One Reverb server hosts two prioritized tables with the same signature:
    the default priority table for the agent's own sequences and
    'demonstration_table' for demonstrations. `demonstration_generator` is
    consumed completely inside this constructor; it must yield
    `(timestep, action)` pairs, which are written to the demonstration
    table with a constant priority of 1. The learner then samples from both
    tables, weighting demonstrations by `demonstration_ratio` and replay by
    `1 - demonstration_ratio`.

    Checkpoints and snapshots are written under `model_directory`.
    `log_to_bigtable` and `log_name` are accepted but not used here.
    """
    # Recurrent state stored next to each step; batch dimension removed.
    core_state_spec = tf2_utils.squeeze_batch_dim({
        'core_state': network.initial_state(1),
    })
    sequence_signature = adders.SequenceAdder.signature(
        environment_spec, core_state_spec)

    # Table holding the agent's own experience.
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(0.8),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
        signature=sequence_signature)

    # Table holding the demonstrations.
    demonstration_table = reverb.Table(
        name='demonstration_table',
        sampler=reverb.selectors.Prioritized(0.8),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
        signature=sequence_signature)

    # Launch the server hosting both tables.
    self._server = reverb.Server([replay_table, demonstration_table],
                                 port=None)
    server_address = f'localhost:{self._server.port}'

    sequence_length = burn_in_length + trace_length + 1

    # Writers for replay and for demonstrations share the sequence layout.
    sequence_options = {
        'period': replay_period,
        'sequence_length': sequence_length,
    }
    adder = adders.SequenceAdder(
        client=reverb.Client(server_address), **sequence_options)
    demo_priority_fns = {demonstration_table.name: lambda x: 1.}
    demo_adder = adders.SequenceAdder(
        client=reverb.Client(server_address),
        priority_fns=demo_priority_fns,
        **sequence_options)

    # Play the demonstrations into their table, exhausting the generator.
    # TODO: MAX REPLAY SIZE
    # The initial value is overwritten before its first use; the original
    # notes that it should come from the action spec.
    previous_action = 1
    at_episode_start = True
    # Every demonstration step is stored with the initial core state, to
    # make the two datasets equivalent.
    initial_core_state = tf2_utils.to_numpy_squeeze(network.initial_state(1))
    for timestep, action in demonstration_generator:
        if at_episode_start:
            demo_adder.add_first(timestep)
            at_episode_start = False
        else:
            demo_adder.add(previous_action, timestep,
                           extras=(initial_core_state, ))
        # The action yielded with this timestep is written with the next.
        previous_action = action
        # A terminal timestep means the next pair starts a new episode.
        if timestep.last():
            previous_action = None
            at_episode_start = True

    # Replay dataset. Two workers per iterator is a memory/perf attempt,
    # see https://github.com/deepmind/acme/issues/33
    in_flight_samples = 2 * batch_size if batch_size else 100
    replay_dataset = reverb.ReplayDataset.from_table_signature(
        server_address=server_address,
        table=adders.DEFAULT_PRIORITY_TABLE,
        max_in_flight_samples_per_worker=in_flight_samples,
        num_workers_per_iterator=2,
        sequence_length=sequence_length,
        emit_timesteps=sequence_length is None)

    # Demonstration dataset, configured the same way.
    demo_dataset = reverb.ReplayDataset.from_table_signature(
        server_address=server_address,
        table=demonstration_table.name,
        max_in_flight_samples_per_worker=in_flight_samples,
        num_workers_per_iterator=2,
        sequence_length=sequence_length,
        emit_timesteps=sequence_length is None)

    mixed_dataset = tf.data.experimental.sample_from_datasets(
        [replay_dataset, demo_dataset],
        [1 - demonstration_ratio, demonstration_ratio])

    # Batch and prefetch.
    mixed_dataset = mixed_dataset.batch(batch_size, drop_remainder=True)
    mixed_dataset = mixed_dataset.prefetch(tf.data.experimental.AUTOTUNE)

    observation_spec = [environment_spec.observations]
    tf2_utils.create_variables(network, observation_spec)
    tf2_utils.create_variables(target_network, observation_spec)

    learner = learning.R2D2Learner(
        environment_spec=environment_spec,
        network=network,
        target_network=target_network,
        burn_in_length=burn_in_length,
        dataset=mixed_dataset,
        reverb_client=reverb.TFClient(server_address),
        counter=counter,
        logger=logger,
        sequence_length=sequence_length,
        discount=discount,
        target_update_period=target_update_period,
        importance_sampling_exponent=importance_sampling_exponent,
        max_replay_size=max_replay_size,
        learning_rate=learning_rate,
        store_lstm_state=False,
    )

    self._checkpointer = tf2_savers.Checkpointer(
        directory=model_directory,
        subdirectory='r2d2_learner_v1',
        time_delta_minutes=15,
        objects_to_save=learner.state,
        enable_checkpointing=checkpoint,
    )
    # NOTE(review): objects_to_save=None is passed through unchanged --
    # confirm the Snapshotter in use accepts it.
    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save=None,
        time_delta_minutes=15000.,
        directory=model_directory)

    # Behaviour policy: epsilon-greedy over the network's Q-values.
    behavior_network = snt.DeepRNN([
        network,
        lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
    ])
    actor = actors.RecurrentActor(behavior_network, adder)

    min_observations = replay_period * max(batch_size, min_replay_size)
    observations_per_step = (
        float(replay_period * batch_size) / samples_per_insert)
    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=min_observations,
        observations_per_step=observations_per_step)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    burn_in_length: int,
    trace_length: int,
    replay_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    batch_size: int = 32,
    prefetch_size: int = tf.data.experimental.AUTOTUNE,
    target_update_period: int = 100,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    epsilon: float = 0.01,
    learning_rate: float = 1e-3,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    samples_per_insert: float = 32.0,
    store_lstm_state: bool = True,
    max_priority_weight: float = 0.9,
    checkpoint: bool = True,
):
    """Build a recurrent agent backed by a local prioritized replay server.

    The constructor starts a Reverb server with one prioritized table,
    creates a `SequenceAdder` writing sequences of length
    `burn_in_length + trace_length + 1` every `replay_period` steps, and
    builds a dataset over those sequences. It copies `network` into a
    target network and wires both into an `R2D2Learner`, then acts through
    a `RecurrentActor` with an epsilon-greedy policy. A checkpointer for
    the learner state and a snapshotter for the network are kept on the
    instance.
    """
    table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1))
    self._server = reverb.Server([table], port=None)
    server_address = f'localhost:{self._server.port}'

    sequence_length = burn_in_length + trace_length + 1

    # Component to add things into replay.
    adder = adders.SequenceAdder(
        client=reverb.Client(server_address),
        period=replay_period,
        sequence_length=sequence_length,
    )

    # The dataset object to learn from. The recurrent state is stored as an
    # extra, with its batch dimension removed.
    tf_client = reverb.TFClient(server_address)
    core_state_spec = tf2_utils.squeeze_batch_dim({
        'core_state': network.initial_state(1),
    })
    dataset = datasets.make_reverb_dataset(
        client=tf_client,
        environment_spec=environment_spec,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        extra_spec=core_state_spec,
        sequence_length=sequence_length)

    # The target network starts as a copy of the online network.
    target_network = copy.deepcopy(network)
    observation_spec = [environment_spec.observations]
    tf2_utils.create_variables(network, observation_spec)
    tf2_utils.create_variables(target_network, observation_spec)

    learner = learning.R2D2Learner(
        environment_spec=environment_spec,
        network=network,
        target_network=target_network,
        burn_in_length=burn_in_length,
        sequence_length=sequence_length,
        dataset=dataset,
        reverb_client=tf_client,
        counter=counter,
        logger=logger,
        discount=discount,
        target_update_period=target_update_period,
        importance_sampling_exponent=importance_sampling_exponent,
        max_replay_size=max_replay_size,
        learning_rate=learning_rate,
        store_lstm_state=store_lstm_state,
        max_priority_weight=max_priority_weight,
    )

    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='r2d2_learner',
        time_delta_minutes=60,
        objects_to_save=learner.state,
        enable_checkpointing=checkpoint,
    )
    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={'network': network}, time_delta_minutes=60.)

    # Behaviour policy: epsilon-greedy over the network's Q-values.
    behavior_network = snt.DeepRNN([
        network,
        lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
    ])
    actor = actors.RecurrentActor(behavior_network, adder)

    min_observations = replay_period * max(batch_size, min_replay_size)
    observations_per_step = (
        float(replay_period * batch_size) / samples_per_insert)
    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=min_observations,
        observations_per_step=observations_per_step)
def __init__(self,
             environment_spec: specs.EnvironmentSpec,
             policy_network: snt.Module,
             critic_network: snt.Module,
             observation_network: types.TensorTransformation = tf.identity,
             discount: float = 0.99,
             batch_size: int = 256,
             prefetch_size: int = 4,
             target_update_period: int = 100,
             min_replay_size: int = 1000,
             max_replay_size: int = 1000000,
             samples_per_insert: float = 32.0,
             n_step: int = 5,
             sigma: float = 0.3,
             clipping: bool = True,
             logger: loggers.Logger = None,
             counter: counting.Counter = None,
             checkpoint: bool = True,
             replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
    """Initialize the agent.

    Args:
        environment_spec: description of the actions, observations, etc.
        policy_network: the online (optimized) policy.
        critic_network: the online critic.
        observation_network: optional network to transform the observations
            before they are fed into any network.
        discount: discount to use for TD updates.
        batch_size: batch size for updates.
        prefetch_size: size to prefetch from replay.
        target_update_period: number of learner steps to perform before
            updating the target networks.
        min_replay_size: minimum replay size before updating.
        max_replay_size: maximum replay size.
        samples_per_insert: number of samples to take from replay for every
            insert that is made.
        n_step: number of steps to squash into a single transition.
        sigma: standard deviation of zero-mean, Gaussian exploration noise.
        clipping: whether to clip gradients by global norm.
        logger: logger object to be used by learner.
        counter: counter object used to keep track of steps.
        checkpoint: boolean indicating whether to checkpoint the learner.
        replay_table_name: string indicating what name to give the replay
            table.
    """
    # Replay storage. The table's own rate limiter only needs one item;
    # pacing is left to the Agent interface via the arguments passed to
    # super().__init__ below.
    table = reverb.Table(
        name=replay_table_name,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=adders.NStepTransitionAdder.signature(environment_spec))
    self._server = reverb.Server([table], port=None)

    # Writer side: every transition gets a constant priority of 1.
    server_address = f'localhost:{self._server.port}'
    constant_priority = {replay_table_name: lambda x: 1.}
    adder = adders.NStepTransitionAdder(
        priority_fns=constant_priority,
        client=reverb.Client(server_address),
        n_step=n_step,
        discount=discount)

    # Reader side: sample transitions back from the same table.
    dataset = datasets.make_reverb_dataset(
        table=replay_table_name,
        client=reverb.TFClient(server_address),
        environment_spec=environment_spec,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        transition_adder=True)

    # Specs: actions, raw observations, and the observation network output.
    act_spec = environment_spec.actions
    obs_spec = environment_spec.observations
    emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])  # pytype: disable=wrong-arg-types

    # Make sure observation network is a Sonnet Module.
    observation_network = tf2_utils.to_sonnet_module(observation_network)

    # Target networks start out as copies of the online networks.
    target_policy_network = copy.deepcopy(policy_network)
    target_critic_network = copy.deepcopy(critic_network)
    target_observation_network = copy.deepcopy(observation_network)

    # Behaviour policy: the online policy plus clipped Gaussian noise, with
    # the result clipped to the action spec.
    behavior_network = snt.Sequential([
        observation_network,
        policy_network,
        networks.ClippedGaussian(sigma),
        networks.ClipToSpec(act_spec),
    ])

    # Create variables for the online and target networks.
    critic_inputs = [emb_spec, act_spec]
    tf2_utils.create_variables(policy_network, [emb_spec])
    tf2_utils.create_variables(critic_network, critic_inputs)
    tf2_utils.create_variables(target_policy_network, [emb_spec])
    tf2_utils.create_variables(target_critic_network, critic_inputs)
    tf2_utils.create_variables(target_observation_network, [obs_spec])

    # Create the actor which defines how we take actions.
    actor = actors.FeedForwardActor(behavior_network, adder=adder)

    # Policy and critic each get their own Adam optimizer.
    policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4)
    critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4)

    # The learner updates the parameters (and initializes them).
    learner = learning.DDPGLearner(
        policy_network=policy_network,
        critic_network=critic_network,
        observation_network=observation_network,
        target_policy_network=target_policy_network,
        target_critic_network=target_critic_network,
        target_observation_network=target_observation_network,
        policy_optimizer=policy_optimizer,
        critic_optimizer=critic_optimizer,
        clipping=clipping,
        discount=discount,
        target_update_period=target_update_period,
        dataset=dataset,
        counter=counter,
        logger=logger,
        checkpoint=checkpoint,
    )

    min_observations = max(batch_size, min_replay_size)
    observations_per_step = float(batch_size) / samples_per_insert
    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=min_observations,
        observations_per_step=observations_per_step)
def get_client(self):
    """Return a new `reverb.TFClient` for this buffer's local Reverb server.

    A message is printed on every call so that each execution of this
    method is visible on stdout.
    """
    print("retracing ReverbUniformReplayBuffer get_client")
    address = f'localhost:{self.reverb_server.port}'
    return reverb.TFClient(address)