def setUp(self):
  super(ReverbReplayBufferTest, self).setUp()

  # Prepare the environment (and the corresponding specs).
  self._env = test_envs.EpisodeCountingEnv(steps_per_episode=3)
  tensor_time_step_spec = tf.nest.map_structure(tensor_spec.from_spec,
                                                self._env.time_step_spec())
  tensor_action_spec = tensor_spec.from_spec(self._env.action_spec())
  self._data_spec = trajectory.Trajectory(
      step_type=tensor_time_step_spec.step_type,
      observation=tensor_time_step_spec.observation,
      action=tensor_action_spec,
      policy_info=(),
      next_step_type=tensor_time_step_spec.step_type,
      reward=tensor_time_step_spec.reward,
      discount=tensor_time_step_spec.discount,
  )
  table_spec = tf.nest.map_structure(
      lambda s: tf.TensorSpec(dtype=s.dtype, shape=(None,) + s.shape),
      self._data_spec)
  self._array_data_spec = tensor_spec.to_nest_array_spec(self._data_spec)

  # Initialize and start a Reverb server (and set up a client to it).
  self._table_name = 'test_table'
  uniform_table = reverb.Table(
      self._table_name,
      max_size=100,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1),
      signature=table_spec,
  )
  self._server = reverb.Server([uniform_table])
  self._py_client = reverb.Client('localhost:{}'.format(self._server.port))
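# A hypothetical companion to the setUp above (not from the original test):
# a matching tearDown that stops the local Reverb server so each test stays
# hermetic. It assumes only the attributes created in setUp.
def tearDown(self):
  self._server.stop()
  super(ReverbReplayBufferTest, self).tearDown()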
def test_dataset_with_variable_sequence_length_truncates(self):
  spec = tf.TensorSpec((), tf.int64)
  table_spec = tf.TensorSpec((None,), tf.int64)
  table = reverb.Table(
      name=self._table_name,
      sampler=reverb.selectors.Fifo(),
      remover=reverb.selectors.Fifo(),
      max_times_sampled=1,
      max_size=100,
      rate_limiter=reverb.rate_limiters.MinSize(1),
      signature=table_spec,
  )
  server = reverb.Server([table])
  py_client = reverb.Client('localhost:{}'.format(server.port))

  # Insert two episodes: one of length 3 and one of length 5.
  with py_client.trajectory_writer(10) as writer:
    writer.append(1)
    writer.append(2)
    writer.append(3)
    writer.create_item(
        self._table_name, trajectory=writer.history[-3:], priority=5)
  with py_client.trajectory_writer(10) as writer:
    writer.append(10)
    writer.append(20)
    writer.append(30)
    writer.append(40)
    writer.append(50)
    writer.create_item(
        self._table_name, trajectory=writer.history[-5:], priority=5)

  replay = reverb_replay_buffer.ReverbReplayBuffer(
      spec,
      self._table_name,
      local_server=server,
      sequence_length=None,
      rate_limiter_timeout_ms=100)

  ds = replay.as_dataset(single_deterministic_pass=True, num_steps=2)
  it = iter(ds)

  # Expect [1, 2].
  data, _ = next(it)
  self.assertAllEqual(data, [1, 2])

  # Expect [10, 20].
  data, _ = next(it)
  self.assertAllEqual(data, [10, 20])

  # Expect [30, 40].
  data, _ = next(it)
  self.assertAllEqual(data, [30, 40])

  with self.assertRaises(StopIteration):
    next(it)
def collect(summary_dir: Text,
            environment_name: Text,
            collect_policy: py_tf_eager_policy.PyTFEagerPolicyBase,
            replay_buffer_server_address: Text,
            variable_container_server_address: Text,
            suite_load_fn: Callable[
                [Text], py_environment.PyEnvironment] = suite_mujoco.load,
            initial_collect_steps: int = 10000,
            max_train_steps: int = 2000000) -> None:
  """Collects experience using a policy updated after every episode."""
  # Create the environment. For now support only single environment collection.
  collect_env = suite_load_fn(environment_name)

  # Create the variable container.
  train_step = train_utils.create_train_step()
  variables = {
      reverb_variable_container.POLICY_KEY: collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.update(variables)

  # Create the replay buffer observer.
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb.Client(replay_buffer_server_address),
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      sequence_length=2,
      stride_length=1)

  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  env_step_metric = py_metrics.EnvironmentSteps()
  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=summary_dir,
      observers=[rb_observer, env_step_metric])

  # Run the experience collection loop.
  while train_step.numpy() < max_train_steps:
    logging.info('Collecting with policy at step: %d', train_step.numpy())
    collect_actor.run()
    variable_container.update(variables)
def __init__( self, environment_spec: specs.EnvironmentSpec, builder: builders.ActorLearnerBuilder, networks: Any, policy_network: actors.FeedForwardPolicy, min_replay_size: int = 1000, samples_per_insert: float = 256.0, batch_size: int = 256, num_sgd_steps_per_step: int = 1, counter: counting.Counter = None, logger: loggers.Logger = None, checkpoint: bool = True, ): """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. builder: builder defining an RL algorithm to train. networks: network objects to be passed to the learner. policy_network: function that given an observation returns actions. min_replay_size: minimum replay size before updating. samples_per_insert: number of samples to take from replay for every insert that is made. batch_size: batch size for updates. num_sgd_steps_per_step: how many sgd steps a learner does per 'step' call. For performance reasons (especially to reduce TPU host-device transfer times) it is performance-beneficial to do multiple sgd updates at once, provided that it does not hurt the training, which needs to be verified empirically for each environment. counter: counter object used to keep track of steps. logger: logger object to be used by learner. checkpoint: boolean indicating whether to checkpoint the learner. """ # Create the replay server and grab its address. replay_tables = builder.make_replay_tables(environment_spec) replay_server = reverb.Server(replay_tables, port=None) replay_client = reverb.Client(f'localhost:{replay_server.port}') # Create actor, dataset, and learner for generating, storing, and consuming # data respectively. adder = builder.make_adder(replay_client) dataset = builder.make_dataset_iterator(replay_client) learner = builder.make_learner(networks=networks, dataset=dataset, replay_client=replay_client, counter=counter, logger=logger, checkpoint=checkpoint) actor = builder.make_actor(policy_network, adder, variable_source=learner) effective_batch_size = batch_size * num_sgd_steps_per_step super().__init__( actor=actor, learner=learner, min_observations=max(effective_batch_size, min_replay_size), observations_per_step=float(effective_batch_size) / samples_per_insert) # Save the replay so we don't garbage collect it. self._replay_server = replay_server
def trainer_main_tf_dataset(perwez_url, config):
    weight_send = perwez.SendSocket(perwez_url, "weight", broadcast=True)

    # init reverb
    reverb_client = reverb.Client(f"localhost:{PORT}")

    # reverb dataset
    def _make_dataset(_):
        dataset = reverb.ReplayDataset(
            f"localhost:{PORT}",
            TABLE_NAME,
            max_in_flight_samples_per_worker=config["common"]["batch_size"],
            dtypes=(tf.float32, tf.int64, tf.float32, tf.float32, tf.float32),
            shapes=(
                tf.TensorShape((4, 84, 84)),
                tf.TensorShape([]),
                tf.TensorShape([]),
                tf.TensorShape((4, 84, 84)),
                tf.TensorShape([]),
            ),
        )
        dataset = dataset.batch(config["common"]["batch_size"], drop_remainder=True)
        return dataset

    num_parallel_calls = 16
    prefetch_size = 4
    dataset = tf.data.Dataset.range(num_parallel_calls)
    dataset = dataset.interleave(
        map_func=_make_dataset,
        cycle_length=num_parallel_calls,
        num_parallel_calls=num_parallel_calls,
        deterministic=False,
    )
    dataset = dataset.prefetch(prefetch_size)
    numpy_iter = dataset.as_numpy_iterator()

    trainer = get_trainer(config)
    sync_weights_interval = config["common"]["sync_weights_interval"]

    ts = 0
    while True:
        ts += 1
        info, data = next(numpy_iter)
        indices = info.key
        weights = info.probability
        weights = (weights / weights.min()) ** (-0.4)
        loss = trainer.step(data, weights=weights)
        reverb_client.mutate_priorities(
            TABLE_NAME, updates=dict(zip(np.asarray(indices), np.asarray(loss)))
        )
        if ts % sync_weights_interval == 0:
            weight_send.send(trainer.save_weights().getbuffer())
def _push_nested_data(self, server_address: Optional[Text] = None) -> None:
  # Create Python client.
  address = server_address or self._server_address
  client = reverb.Client(address)
  with client.writer(1) as writer:
    writer.append([
        np.array(0, dtype=np.int64),
        np.array([1, 1], dtype=np.float64),
        np.array([[2], [3]], dtype=np.int32)
    ])
    writer.create_item(reverb_variable_container.DEFAULT_TABLE, 1, 1.0)
  self.assertEqual(
      client.server_info()[
          reverb_variable_container.DEFAULT_TABLE].current_size, 1)
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  config = _CONFIG.value

  # Connect to reverb.
  reverb_client = reverb.Client(FLAGS.reverb_address)

  puddles = arenas.get_arena(config.env.pw.arena)
  pw = puddle_world.PuddleWorld(
      puddles=puddles, goal_position=geometry.Point((1.0, 1.0)))
  dpw = pw_utils.DiscretizedPuddleWorld(pw, config.env.pw.num_bins)

  if FLAGS.eval:
    eval_worker(dpw, reverb_client, config)
  else:
    train_worker(dpw, reverb_client, config)
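# Hypothetical flag definitions assumed by main() above; the names mirror the
# FLAGS accesses in the snippet, but the defaults and help strings are
# illustrative only.
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string('reverb_address', 'localhost:1234',
                    'host:port of the Reverb server to connect to.')
flags.DEFINE_bool('eval', False, 'Run the eval worker instead of the trainer.')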
def __init__( self, actor_id, environment_module, environment_fn_name, environment_kwargs, network_module, network_fn_name, network_kwargs, adder_module, adder_fn_name, adder_kwargs, replay_server_address, variable_server_name, variable_server_address, counter: counting.Counter = None, logger: loggers.Logger = None, ): # Counter and Logger self._actor_id = actor_id self._counter = counter or counting.Counter() self._logger = logger or loggers.make_default_logger( f'actor_{actor_id}') # Create the environment self._environment = getattr(environment_module, environment_fn_name)(**environment_kwargs) env_spec = acme.make_environment_spec(self._environment) # Create actor's network self._network = getattr(network_module, network_fn_name)(**network_kwargs) tf2_utils.create_variables(self._network, [env_spec.observations]) self._variables = tree.flatten(self._network.variables) self._policy = tf.function(self._network) # The adder is used to insert observations into replay. self._adder = getattr(adder_module, adder_fn_name)( reverb.Client(replay_server_address), **adder_kwargs) variable_client = reverb.TFClient(variable_server_address) self._variable_dataset = variable_client.dataset( table=variable_server_name, dtypes=[tf.float32 for _ in self._variables], shapes=[v.shape for v in self._variables])
def _assert_nested_variable_in_server(
    self, server_address: Optional[Text] = None) -> None:
  # Create Python client.
  address = server_address or self._server_address
  client = reverb.Client(address)
  self.assertEqual(
      client.server_info()[
          reverb_variable_container.DEFAULT_TABLE].current_size, 1)

  # Draw one sample from the server using the Python client.
  content = next(
      iter(client.sample(reverb_variable_container.DEFAULT_TABLE, 1)))[0].data

  # Internally in Reverb the data is stored in the form of flat numpy lists.
  self.assertLen(content, 3)
  self.assertAllEqual(content[0], np.array(0, dtype=np.int64))
  self.assertAllClose(content[1], np.array([1, 1], dtype=np.float64))
  self.assertAllEqual(content[2], np.array([[2], [3]], dtype=np.int32))
def __init__(self, env_name, buffer_table_name, buffer_server_port, buffer_min_size, n_steps=2, data=None, make_sparse=False): # environments; their hyperparameters self._train_env = gym.make(env_name) self._eval_env = gym.make(env_name) self._n_outputs = self._train_env.action_space.n # number of actions self._input_shape = self._train_env.observation_space.shape # data contains weighs, masks, and a corresponding reward self._data = data self._is_sparse = make_sparse assert not (not data and make_sparse), "Making a sparse model needs data of weights and mask" # networks self._model = None self._target_model = None # fraction of random exp sampling self._epsilon = 0.1 # hyperparameters for optimization self._optimizer = keras.optimizers.Adam(lr=1e-3) self._loss_fn = keras.losses.mean_squared_error # buffer; hyperparameters for a reward calculation self._table_name = buffer_table_name # an object with a client, which is used to store data on a server self._replay_memory_client = reverb.Client(f'localhost:{buffer_server_port}') # make a batch size equal of a minimal size of a buffer self._sample_batch_size = buffer_min_size self._n_steps = n_steps # 1. amount of steps stored per item, it should be at least 2; # 2. for details see function _collect_trajectories_from_episode() # initialize a dataset to be used to sample data from a server self._dataset = storage.initialize_dataset(buffer_server_port, buffer_table_name, self._input_shape, self._sample_batch_size, self._n_steps) self._iterator = iter(self._dataset) self._discount_rate = tf.constant(0.95, dtype=tf.float32) self._items_sampled = 0
def worker_main(perwez_url, config, idx): reverb_client = reverb.Client(f"localhost:{PORT}") reverb_writer = reverb_client.writer(1) weight_recv = perwez.RecvSocket(perwez_url, "weight", broadcast=True) batch_size = config["common"]["batch_size"] num_workers = config["common"]["num_workers"] eps = 0.4 ** (1 + (idx / (num_workers - 1)) * 7) solver = get_solver(config, device="cpu") log_flag = idx >= num_workers + (-num_workers // 3) # aligned with ray worker = get_worker( config, exploration=eps, solver=solver, logger=getLogger(f"worker{idx}") if log_flag else None, ) while True: # load weights if not weight_recv.empty(): worker.load_weights(io.BytesIO(weight_recv.recv())) # step data = worker.step_batch(batch_size) loss = worker.solver.calc_loss(data) # format s0, a, r, s1, done = data s0 = np.asarray(s0, dtype="f4") a = np.asarray(a, dtype="i8") r = np.asarray(r, dtype="f4") s1 = np.asarray(s1, dtype="f4") done = np.asarray(done, dtype="f4") loss = np.asarray(loss, dtype="f4") # upload for i, _ in enumerate(s0): reverb_writer.append([s0[i], a[i], r[i], s1[i], done[i]]) reverb_writer.create_item( table=TABLE_NAME, num_timesteps=1, priority=loss[i] )
def learner(self, replay: reverb.Client, counter: counting.Counter):
  """The Learning part of the agent."""
  # Create the networks.
  network = self._network_factory(self._env_spec.actions)
  target_network = copy.deepcopy(network)
  tf2_utils.create_variables(network, [self._env_spec.observations])
  tf2_utils.create_variables(target_network, [self._env_spec.observations])

  # The dataset object to learn from.
  replay_client = reverb.Client(replay.server_address)
  dataset = datasets.make_reverb_dataset(
      server_address=replay.server_address,
      batch_size=self._batch_size,
      prefetch_size=self._prefetch_size)

  logger = loggers.make_default_logger('learner', steps_key='learner_steps')

  # Return the learning agent.
  counter = counting.Counter(counter, 'learner')

  learner = learning.DQNLearner(
      network=network,
      target_network=target_network,
      discount=self._discount,
      importance_sampling_exponent=self._importance_sampling_exponent,
      learning_rate=self._learning_rate,
      target_update_period=self._target_update_period,
      dataset=dataset,
      replay_client=replay_client,
      counter=counter,
      logger=logger)

  return tf2_savers.CheckpointingRunner(
      learner, subdirectory='dqn_learner', time_delta_minutes=60)
def trainer_main_np_client(perwez_url, config):
    weight_send = perwez.SendSocket(perwez_url, "weight", broadcast=True)

    # init reverb
    reverb_client = reverb.Client(f"localhost:{PORT}")

    trainer = get_trainer(config)
    sync_weights_interval = config["common"]["sync_weights_interval"]

    ts = 0
    while True:
        ts += 1
        samples = reverb_client.sample(TABLE_NAME, config["common"]["batch_size"])
        samples = list(samples)
        data, indices, weights = _reverb_samples_to_ndarray(samples)
        weights = (weights / weights.min()) ** (-0.4)
        loss = trainer.step(data, weights=weights)
        reverb_client.mutate_priorities(
            TABLE_NAME, updates=dict(zip(np.asarray(indices), np.asarray(loss)))
        )
        if ts % sync_weights_interval == 0:
            weight_send.send(trainer.save_weights().getbuffer())
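# A minimal, hypothetical sketch of the unpacking that a helper such as
# _reverb_samples_to_ndarray (not shown above) presumably performs.
# reverb.Client.sample yields one list of ReplaySample objects per sampled
# item; each ReplaySample carries .data plus .info.key and .info.probability.
def _unpack_samples(samples):
    data, keys, probs = [], [], []
    for item in samples:
        step = item[0]  # single-timestep items in this pipeline
        data.append(step.data)
        keys.append(step.info.key)
        probs.append(step.info.probability)
    return data, np.asarray(keys), np.asarray(probs)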
def make_reverb_online_queue(
    environment_spec: specs.EnvironmentSpec,
    extra_spec: Dict[str, Any],
    max_queue_size: int,
    sequence_length: int,
    sequence_period: int,
    batch_size: int,
    replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
) -> ReverbReplay:
  """Creates a single-process queue from an environment spec and extra_spec."""
  signature = adders.SequenceAdder.signature(environment_spec, extra_spec)
  queue = reverb.Table.queue(
      name=replay_table_name, max_size=max_queue_size, signature=signature)
  server = reverb.Server([queue], port=None)
  can_sample = lambda: queue.can_sample(batch_size)

  # Component to add things into replay.
  address = f'localhost:{server.port}'
  adder = adders.SequenceAdder(
      client=reverb.Client(address),
      period=sequence_period,
      sequence_length=sequence_length,
  )

  # The dataset object to learn from.
  # We don't use datasets.make_reverb_dataset() here to avoid interleaving and
  # prefetching, which don't work well with the can_sample() check on update.
  dataset = reverb.ReplayDataset.from_table_signature(
      server_address=address,
      table=replay_table_name,
      max_in_flight_samples_per_worker=1,
      sequence_length=sequence_length,
      emit_timesteps=False)
  dataset = dataset.batch(batch_size, drop_remainder=True)
  data_iterator = dataset.as_numpy_iterator()
  return ReverbReplay(server, adder, data_iterator, can_sample=can_sample)
def build(self):
    """Creates reverb server, client and dataset."""
    self._reverb_server = reverb.Server(
        tables=[
            reverb.Table(
                name="replay_buffer",
                sampler=reverb.selectors.Uniform(),
                remover=reverb.selectors.Fifo(),
                max_size=self._max_replay_size,
                rate_limiter=reverb.rate_limiters.MinSize(1),
                signature=self._signature,
            ),
        ],
        port=None,
    )
    self._reverb_client = reverb.Client(f"localhost:{self._reverb_server.port}")
    self._reverb_dataset = reverb.TrajectoryDataset.from_table_signature(
        server_address=f"localhost:{self._reverb_server.port}",
        table="replay_buffer",
        max_in_flight_samples_per_worker=2 * self._batch_size,
    )
    self._batched_dataset = self._reverb_dataset.batch(
        self._batch_size, drop_remainder=True
    ).as_numpy_iterator()
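# A minimal usage sketch (not part of the class above), assuming the
# attributes created by build() and a hypothetical `update_fn` that consumes
# one batch of trajectories per call.
def train_steps(self, num_steps, update_fn):
    for _ in range(num_steps):
        sample = next(self._batched_dataset)  # batched reverb.ReplaySample
        update_fn(sample.data)                # sample.info holds keys, priorities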
def main(_): environment = fakes.ContinuousEnvironment(action_dim=8, observation_dim=87, episode_length=10000000) spec = specs.make_environment_spec(environment) replay_tables = make_replay_tables(spec) replay_server = reverb.Server(replay_tables, port=None) replay_client = reverb.Client(f'localhost:{replay_server.port}') adder = make_adder(replay_client) timestep = environment.reset() adder.add_first(timestep) # TODO(raveman): Consider also filling the table to say 1M (too slow). for steps in range(10000): if steps % 1000 == 0: logging.info('Processed %s steps', steps) action = np.asarray(np.random.uniform(-1, 1, (8,)), dtype=np.float32) next_timestep = environment.step(action) adder.add(action, next_timestep, extras=()) for batch_size in [256, 256 * 8, 256 * 64]: for prefetch_size in [0, 1, 4]: print(f'Processing batch_size={batch_size} prefetch_size={prefetch_size}') ds = datasets.make_reverb_dataset( table='default', server_address=replay_client.server_address, batch_size=batch_size, prefetch_size=prefetch_size, ) it = ds.as_numpy_iterator() for iteration in range(3): t = time.time() for _ in range(1000): _ = next(it) print(f'Iteration {iteration} finished in {time.time() - t}s')
def __init__( self, environment_spec: specs.EnvironmentSpec, network: snt.RNNCore, target_network: snt.RNNCore, burn_in_length: int, trace_length: int, replay_period: int, demonstration_generator: iter, demonstration_ratio: float, model_directory: str, counter: counting.Counter = None, logger: loggers.Logger = None, discount: float = 0.99, batch_size: int = 32, target_update_period: int = 100, importance_sampling_exponent: float = 0.2, epsilon: float = 0.01, learning_rate: float = 1e-3, log_to_bigtable: bool = False, log_name: str = 'agent', checkpoint: bool = True, min_replay_size: int = 1000, max_replay_size: int = 1000000, samples_per_insert: float = 32.0, ): extra_spec = { 'core_state': network.initial_state(1), } # replay table # Remove batch dimensions. extra_spec = tf2_utils.squeeze_batch_dim(extra_spec) replay_table = reverb.Table( name=adders.DEFAULT_PRIORITY_TABLE, sampler=reverb.selectors.Prioritized(0.8), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), signature=adders.SequenceAdder.signature(environment_spec, extra_spec)) # demonstation table. demonstration_table = reverb.Table( name='demonstration_table', sampler=reverb.selectors.Prioritized(0.8), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), signature=adders.SequenceAdder.signature(environment_spec, extra_spec)) # launch server self._server = reverb.Server([replay_table, demonstration_table], port=None) address = f'localhost:{self._server.port}' sequence_length = burn_in_length + trace_length + 1 # Component to add things into replay and demo sequence_kwargs = dict( period=replay_period, sequence_length=sequence_length, ) adder = adders.SequenceAdder(client=reverb.Client(address), **sequence_kwargs) priority_function = {demonstration_table.name: lambda x: 1.} demo_adder = adders.SequenceAdder(client=reverb.Client(address), priority_fns=priority_function, **sequence_kwargs) # play demonstrations and write # exhaust the generator # TODO: MAX REPLAY SIZE _prev_action = 1 # this has to come from spec _add_first = True #include this to make datasets equivalent numpy_state = tf2_utils.to_numpy_squeeze(network.initial_state(1)) for ts, action in demonstration_generator: if _add_first: demo_adder.add_first(ts) _add_first = False else: demo_adder.add(_prev_action, ts, extras=(numpy_state, )) _prev_action = action # reset to new episode if ts.last(): _prev_action = None _add_first = True # replay dataset max_in_flight_samples_per_worker = 2 * batch_size if batch_size else 100 dataset = reverb.ReplayDataset.from_table_signature( server_address=address, table=adders.DEFAULT_PRIORITY_TABLE, max_in_flight_samples_per_worker=max_in_flight_samples_per_worker, num_workers_per_iterator= 2, # memory perf improvment attempt https://github.com/deepmind/acme/issues/33 sequence_length=sequence_length, emit_timesteps=sequence_length is None) # demonstation dataset d_dataset = reverb.ReplayDataset.from_table_signature( server_address=address, table=demonstration_table.name, max_in_flight_samples_per_worker=max_in_flight_samples_per_worker, num_workers_per_iterator=2, sequence_length=sequence_length, emit_timesteps=sequence_length is None) dataset = tf.data.experimental.sample_from_datasets( [dataset, d_dataset], [1 - demonstration_ratio, demonstration_ratio]) # Batch and prefetch. 
dataset = dataset.batch(batch_size, drop_remainder=True) dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) tf2_utils.create_variables(network, [environment_spec.observations]) tf2_utils.create_variables(target_network, [environment_spec.observations]) learner = learning.R2D2Learner( environment_spec=environment_spec, network=network, target_network=target_network, burn_in_length=burn_in_length, dataset=dataset, reverb_client=reverb.TFClient(address), counter=counter, logger=logger, sequence_length=sequence_length, discount=discount, target_update_period=target_update_period, importance_sampling_exponent=importance_sampling_exponent, max_replay_size=max_replay_size, learning_rate=learning_rate, store_lstm_state=False, ) self._checkpointer = tf2_savers.Checkpointer( directory=model_directory, subdirectory='r2d2_learner_v1', time_delta_minutes=15, objects_to_save=learner.state, enable_checkpointing=checkpoint, ) self._snapshotter = tf2_savers.Snapshotter(objects_to_save=None, time_delta_minutes=15000., directory=model_directory) policy_network = snt.DeepRNN([ network, lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(), ]) actor = actors.RecurrentActor(policy_network, adder) observations_per_step = (float(replay_period * batch_size) / samples_per_insert) super().__init__(actor=actor, learner=learner, min_observations=replay_period * max(batch_size, min_replay_size), observations_per_step=observations_per_step)
def __init__( self, environment_spec: specs.EnvironmentSpec, policy_network: snt.Module, critic_network: snt.Module, discount: float = 0.99, batch_size: int = 256, prefetch_size: int = 4, target_update_period: int = 100, prior_network: Optional[snt.Module] = None, policy_optimizer: Optional[snt.Optimizer] = None, critic_optimizer: Optional[snt.Optimizer] = None, prior_optimizer: Optional[snt.Optimizer] = None, distillation_cost: Optional[float] = 1e-3, entropy_regularizer_cost: Optional[float] = 1e-3, min_replay_size: int = 1000, max_replay_size: int = 1000000, samples_per_insert: float = 32.0, sequence_length: int = 10, sigma: float = 0.3, replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE, counter: Optional[counting.Counter] = None, logger: Optional[loggers.Logger] = None, checkpoint: bool = True, ): """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. policy_network: the online (optimized) policy. critic_network: the online critic. discount: discount to use for TD updates. batch_size: batch size for updates. prefetch_size: size to prefetch from replay. target_update_period: number of learner steps to perform before updating the target networks. prior_network: an optional `behavior prior` to regularize against. policy_optimizer: optimizer for the policy network updates. critic_optimizer: optimizer for the critic network updates. prior_optimizer: optimizer for the prior network updates. distillation_cost: a multiplier to be used when adding distillation against the prior to the losses. entropy_regularizer_cost: a multiplier used for per state sample based entropy added to the actor loss. min_replay_size: minimum replay size before updating. max_replay_size: maximum replay size. samples_per_insert: number of samples to take from replay for every insert that is made. sequence_length: number of timesteps to store for each trajectory. sigma: standard deviation of zero-mean, Gaussian exploration noise. replay_table_name: string indicating what name to give the replay table. counter: counter object used to keep track of steps. logger: logger object to be used by learner. checkpoint: boolean indicating whether to checkpoint the learner. """ # Create the Builder object which will internally create agent components. builder = SVG0Builder( # TODO(mwhoffman): pass the config dataclass in directly. # TODO(mwhoffman): use the limiter rather than the workaround below. # Right now this modifies min_replay_size and samples_per_insert so that # they are not controlled by a limiter and are instead handled by the # Agent base class (the above TODO directly references this behavior). SVG0Config( discount=discount, batch_size=batch_size, prefetch_size=prefetch_size, target_update_period=target_update_period, policy_optimizer=policy_optimizer, critic_optimizer=critic_optimizer, prior_optimizer=prior_optimizer, distillation_cost=distillation_cost, entropy_regularizer_cost=entropy_regularizer_cost, min_replay_size=1, # Let the Agent class handle this. max_replay_size=max_replay_size, samples_per_insert=None, # Let the Agent class handle this. sequence_length=sequence_length, sigma=sigma, replay_table_name=replay_table_name, )) # TODO(mwhoffman): pass the network dataclass in directly. online_networks = SVG0Networks( policy_network=policy_network, critic_network=critic_network, prior_network=prior_network, ) # Target networks are just a copy of the online networks. target_networks = copy.deepcopy(online_networks) # Initialize the networks. 
online_networks.init(environment_spec) target_networks.init(environment_spec) # TODO(mwhoffman): either make this Dataclass or pass only one struct. # The network struct passed to make_learner is just a tuple for the # time-being (for backwards compatibility). networks = (online_networks, target_networks) # Create the behavior policy. policy_network = online_networks.make_policy() # Create the replay server and grab its address. replay_tables = builder.make_replay_tables(environment_spec, sequence_length) replay_server = reverb.Server(replay_tables, port=None) replay_client = reverb.Client(f'localhost:{replay_server.port}') # Create actor, dataset, and learner for generating, storing, and consuming # data respectively. adder = builder.make_adder(replay_client) actor = builder.make_actor(policy_network, adder) dataset = builder.make_dataset_iterator(replay_client) learner = builder.make_learner(networks, dataset, counter, logger, checkpoint) super().__init__(actor=actor, learner=learner, min_observations=max(batch_size, min_replay_size), observations_per_step=float(batch_size) / samples_per_insert) # Save the replay so we don't garbage collect it. self._replay_server = replay_server
"""Creates a single-process replay infrastructure from an environment spec.""" # Create a replay server to add data to. This uses no limiter behavior in # order to allow the Agent interface to handle it. replay_table = reverb.Table( name=replay_table_name, sampler=reverb.selectors.Prioritized(priority_exponent), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(min_replay_size), signature=adders.NStepTransitionAdder.signature( environment_spec=environment_spec)) server = reverb.Server([replay_table], port=None) # The adder is used to insert observations into replay. address = f'localhost:{server.port}' client = reverb.Client(address) adder = adders.NStepTransitionAdder(client=client, n_step=n_step, discount=discount) # The dataset provides an interface to sample from replay. data_iterator = datasets.make_reverb_dataset( table=replay_table_name, server_address=address, batch_size=batch_size, prefetch_size=prefetch_size, environment_spec=environment_spec, transition_adder=True, ).as_numpy_iterator() return ReverbReplay(server, adder, data_iterator, client)
def _reverb_client(port):
  return reverb.Client('localhost:{}'.format(port))
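# Example usage of the helper above (the port value is illustrative only):
client = _reverb_client(8000)
print(client.server_info())  # maps table name -> TableInfo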
def __init__(self, environment_spec: specs.EnvironmentSpec, policy_network: snt.Module, critic_network: snt.Module, encoder_network: types.TensorTransformation = tf.identity, entropy_coeff: float = 0.01, target_update_period: int = 0, discount: float = 0.99, batch_size: int = 256, policy_learn_rate: float = 3e-4, critic_learn_rate: float = 5e-4, prefetch_size: int = 4, min_replay_size: int = 1000, max_replay_size: int = 250000, samples_per_insert: float = 64.0, n_step: int = 5, sigma: float = 0.5, clipping: bool = True, logger: loggers.Logger = None, counter: counting.Counter = None, checkpoint: bool = True, replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE): """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. policy_network: the online (optimized) policy. critic_network: the online critic. observation_network: optional network to transform the observations before they are fed into any network. discount: discount to use for TD updates. batch_size: batch size for updates. prefetch_size: size to prefetch from replay. target_update_period: number of learner steps to perform before updating the target networks. min_replay_size: minimum replay size before updating. max_replay_size: maximum replay size. samples_per_insert: number of samples to take from replay for every insert that is made. n_step: number of steps to squash into a single transition. sigma: standard deviation of zero-mean, Gaussian exploration noise. clipping: whether to clip gradients by global norm. logger: logger object to be used by learner. counter: counter object used to keep track of steps. checkpoint: boolean indicating whether to checkpoint the learner. replay_table_name: string indicating what name to give the replay table. """ # Create a replay server to add data to. This uses no limiter behavior in # order to allow the Agent interface to handle it. dim_actions = np.prod(environment_spec.actions.shape, dtype=int) extra_spec = { 'logP': tf.ones(shape=(1), dtype=tf.float32), 'policy': tf.ones(shape=(1, dim_actions), dtype=tf.float32) } # Remove batch dimensions. extra_spec = tf2_utils.squeeze_batch_dim(extra_spec) replay_table = reverb.Table( name=replay_table_name, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(1), signature=adders.NStepTransitionAdder.signature( environment_spec, extras_spec=extra_spec)) self._server = reverb.Server([replay_table], port=None) # The adder is used to insert observations into replay. address = f'localhost:{self._server.port}' adder = adders.NStepTransitionAdder( priority_fns={replay_table_name: lambda x: 1.}, client=reverb.Client(address), n_step=n_step, discount=discount) # The dataset provides an interface to sample from replay. dataset = datasets.make_reverb_dataset(table=replay_table_name, server_address=address, batch_size=batch_size, prefetch_size=prefetch_size) # Make sure observation network is a Sonnet Module. observation_network = model.MDPNormalization(environment_spec, encoder_network) # Get observation and action specs. act_spec = environment_spec.actions obs_spec = environment_spec.observations # Create the behavior policy. sampling_head = model.SquashedGaussianSamplingHead(act_spec, sigma) self._behavior_network = model.PolicyValueBehaviorNet( snt.Sequential([observation_network, policy_network]), sampling_head) # Create variables. 
emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) tf2_utils.create_variables(policy_network, [emb_spec]) tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) # Create the actor which defines how we take actions. actor = model.SACFeedForwardActor(self._behavior_network, adder) if target_update_period > 0: target_policy_network = copy.deepcopy(policy_network) target_critic_network = copy.deepcopy(critic_network) target_observation_network = copy.deepcopy(observation_network) tf2_utils.create_variables(target_policy_network, [emb_spec]) tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec]) tf2_utils.create_variables(target_observation_network, [obs_spec]) else: target_policy_network = policy_network target_critic_network = critic_network target_observation_network = observation_network # Create optimizers. policy_optimizer = snt.optimizers.Adam(learning_rate=policy_learn_rate) critic_optimizer = snt.optimizers.Adam(learning_rate=critic_learn_rate) # The learner updates the parameters (and initializes them). learner = learning.SACLearner( policy_network=policy_network, critic_network=critic_network, sampling_head=sampling_head, observation_network=observation_network, target_policy_network=target_policy_network, target_critic_network=target_critic_network, target_observation_network=target_observation_network, policy_optimizer=policy_optimizer, critic_optimizer=critic_optimizer, target_update_period=target_update_period, learning_rate=policy_learn_rate, clipping=clipping, entropy_coeff=entropy_coeff, discount=discount, dataset=dataset, counter=counter, logger=logger, checkpoint=checkpoint, ) super().__init__(actor=actor, learner=learner, min_observations=max(batch_size, min_replay_size), observations_per_step=float(batch_size) / samples_per_insert)
def __init__( self, environment_spec: specs.EnvironmentSpec, network: networks.PolicyValueRNN, initial_state_fn: Callable[[], networks.RNNState], sequence_length: int, sequence_period: int, counter: counting.Counter = None, logger: loggers.Logger = None, discount: float = 0.99, max_queue_size: int = 100000, batch_size: int = 16, learning_rate: float = 1e-3, entropy_cost: float = 0.01, baseline_cost: float = 0.5, seed: int = 0, max_abs_reward: float = np.inf, max_gradient_norm: float = np.inf, ): num_actions = environment_spec.actions.num_values self._logger = logger or loggers.TerminalLogger('agent') queue = reverb.Table.queue(name=adders.DEFAULT_PRIORITY_TABLE, max_size=max_queue_size) self._server = reverb.Server([queue], port=None) self._can_sample = lambda: queue.can_sample(batch_size) address = f'localhost:{self._server.port}' # Component to add things into replay. adder = adders.SequenceAdder( client=reverb.Client(address), period=sequence_period, sequence_length=sequence_length, ) # The dataset object to learn from. extra_spec = { 'core_state': hk.transform(initial_state_fn).apply(None), 'logits': np.ones(shape=(num_actions, ), dtype=np.float32) } # Remove batch dimensions. dataset = datasets.make_reverb_dataset( client=reverb.TFClient(address), environment_spec=environment_spec, batch_size=batch_size, extra_spec=extra_spec, sequence_length=sequence_length) rng = hk.PRNGSequence(seed) optimizer = optix.chain( optix.clip_by_global_norm(max_gradient_norm), optix.adam(learning_rate), ) self._learner = learning.IMPALALearner( obs_spec=environment_spec.observations, network=network, initial_state_fn=initial_state_fn, iterator=dataset.as_numpy_iterator(), rng=rng, counter=counter, logger=logger, optimizer=optimizer, discount=discount, entropy_cost=entropy_cost, baseline_cost=baseline_cost, max_abs_reward=max_abs_reward, ) variable_client = jax_variable_utils.VariableClient(self._learner, key='policy') self._actor = acting.IMPALAActor( network=network, initial_state_fn=initial_state_fn, rng=rng, adder=adder, variable_client=variable_client, )
def __init__(self, environment_spec: specs.EnvironmentSpec, policy_network: snt.Module, critic_network: snt.Module, observation_network: types.TensorTransformation = tf.identity, discount: float = 0.99, batch_size: int = 256, prefetch_size: int = 4, target_update_period: int = 100, min_replay_size: int = 1000, max_replay_size: int = 1000000, samples_per_insert: float = 32.0, n_step: int = 5, sigma: float = 0.3, clipping: bool = True, logger: loggers.Logger = None, counter: counting.Counter = None, checkpoint: bool = True, replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE): """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. policy_network: the online (optimized) policy. critic_network: the online critic. observation_network: optional network to transform the observations before they are fed into any network. discount: discount to use for TD updates. batch_size: batch size for updates. prefetch_size: size to prefetch from replay. target_update_period: number of learner steps to perform before updating the target networks. min_replay_size: minimum replay size before updating. max_replay_size: maximum replay size. samples_per_insert: number of samples to take from replay for every insert that is made. n_step: number of steps to squash into a single transition. sigma: standard deviation of zero-mean, Gaussian exploration noise. clipping: whether to clip gradients by global norm. logger: logger object to be used by learner. counter: counter object used to keep track of steps. checkpoint: boolean indicating whether to checkpoint the learner. replay_table_name: string indicating what name to give the replay table. """ # Create a replay server to add data to. This uses no limiter behavior in # order to allow the Agent interface to handle it. replay_table = reverb.Table( name=replay_table_name, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(1), signature=adders.NStepTransitionAdder.signature(environment_spec)) self._server = reverb.Server([replay_table], port=None) # The adder is used to insert observations into replay. address = f'localhost:{self._server.port}' adder = adders.NStepTransitionAdder( priority_fns={replay_table_name: lambda x: 1.}, client=reverb.Client(address), n_step=n_step, discount=discount) # The dataset provides an interface to sample from replay. dataset = datasets.make_reverb_dataset( table=replay_table_name, client=reverb.TFClient(address), environment_spec=environment_spec, batch_size=batch_size, prefetch_size=prefetch_size, transition_adder=True) # Get observation and action specs. act_spec = environment_spec.actions obs_spec = environment_spec.observations emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) # pytype: disable=wrong-arg-types # Make sure observation network is a Sonnet Module. observation_network = tf2_utils.to_sonnet_module(observation_network) # Create target networks. target_policy_network = copy.deepcopy(policy_network) target_critic_network = copy.deepcopy(critic_network) target_observation_network = copy.deepcopy(observation_network) # Create the behavior policy. behavior_network = snt.Sequential([ observation_network, policy_network, networks.ClippedGaussian(sigma), networks.ClipToSpec(act_spec), ]) # Create variables. 
tf2_utils.create_variables(policy_network, [emb_spec]) tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) tf2_utils.create_variables(target_policy_network, [emb_spec]) tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec]) tf2_utils.create_variables(target_observation_network, [obs_spec]) # Create the actor which defines how we take actions. actor = actors.FeedForwardActor(behavior_network, adder=adder) # Create optimizers. policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4) critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4) # The learner updates the parameters (and initializes them). learner = learning.DDPGLearner( policy_network=policy_network, critic_network=critic_network, observation_network=observation_network, target_policy_network=target_policy_network, target_critic_network=target_critic_network, target_observation_network=target_observation_network, policy_optimizer=policy_optimizer, critic_optimizer=critic_optimizer, clipping=clipping, discount=discount, target_update_period=target_update_period, dataset=dataset, counter=counter, logger=logger, checkpoint=checkpoint, ) super().__init__(actor=actor, learner=learner, min_observations=max(batch_size, min_replay_size), observations_per_step=float(batch_size) / samples_per_insert)
def __init__(self, variable_server_name, variable_server_address):
    self._variable_server_name = variable_server_name
    self._variable_client = reverb.Client(variable_server_address)
# Initialize Reverb server
server = reverb.Server(tables=[
    reverb.Table(name='ReplayBuffer',
                 sampler=reverb.selectors.Uniform(),
                 remover=reverb.selectors.Fifo(),
                 max_size=buffer_size,
                 rate_limiter=reverb.rate_limiters.MinSize(1)),
    reverb.Table(name='PrioritizedReplayBuffer',
                 sampler=reverb.selectors.Prioritized(alpha),
                 remover=reverb.selectors.Fifo(),
                 max_size=buffer_size,
                 rate_limiter=reverb.rate_limiters.MinSize(1))
])

client = reverb.Client(f"localhost:{server.port}")
tf_client = reverb.TFClient(f"localhost:{server.port}")


# Helper Function
def env(n):
    e = {
        "obs": np.ones((n, obs_shape)),
        "act": np.zeros((n, act_shape)),
        "next_obs": np.ones((n, obs_shape)),
        "rew": np.zeros(n),
        "done": np.zeros(n)
    }
    return e
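# A minimal insertion sketch (an assumption, not part of the original): write
# each transition produced by the env helper into both tables via the client,
# using a unit priority for the uniform table and an arbitrary priority for
# the prioritized one. `n_samples` is hypothetical.
batch = env(n_samples)
for i in range(n_samples):
    step = [batch[k][i] for k in ("obs", "act", "next_obs", "rew", "done")]
    client.insert(step, priorities={"ReplayBuffer": 1.0,
                                    "PrioritizedReplayBuffer": 1.0})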
def __init__( self, environment_spec: specs.EnvironmentSpec, network: snt.Module, demonstration_dataset: tf.data.Dataset, demonstration_ratio: float, batch_size: int = 256, prefetch_size: int = 4, target_update_period: int = 100, samples_per_insert: float = 32.0, min_replay_size: int = 1000, max_replay_size: int = 1000000, importance_sampling_exponent: float = 0.2, n_step: int = 5, epsilon: tf.Tensor = None, learning_rate: float = 1e-3, discount: float = 0.99, ): """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. network: the online Q network (the one being optimized) demonstration_dataset: tf.data.Dataset producing (timestep, action) tuples containing full episodes. demonstration_ratio: Ratio of transitions coming from demonstrations. batch_size: batch size for updates. prefetch_size: size to prefetch from replay. target_update_period: number of learner steps to perform before updating the target networks. samples_per_insert: number of samples to take from replay for every insert that is made. min_replay_size: minimum replay size before updating. This and all following arguments are related to dataset construction and will be ignored if a dataset argument is passed. max_replay_size: maximum replay size. importance_sampling_exponent: power to which importance weights are raised before normalizing. n_step: number of steps to squash into a single transition. epsilon: probability of taking a random action; ignored if a policy network is given. learning_rate: learning rate for the q-network update. discount: discount to use for TD updates. """ # Create a replay server to add data to. This uses no limiter behavior in # order to allow the Agent interface to handle it. replay_table = reverb.Table( name=adders.DEFAULT_PRIORITY_TABLE, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(1)) self._server = reverb.Server([replay_table], port=None) # The adder is used to insert observations into replay. address = f'localhost:{self._server.port}' adder = adders.NStepTransitionAdder(client=reverb.Client(address), n_step=n_step, discount=discount) # The dataset provides an interface to sample from replay. replay_client = reverb.TFClient(address) dataset = datasets.make_reverb_dataset( client=replay_client, environment_spec=environment_spec, transition_adder=True) # Combine with demonstration dataset. transition = functools.partial(_n_step_transition_from_episode, n_step=n_step, discount=discount) dataset_demos = demonstration_dataset.map(transition) dataset = tf.data.experimental.sample_from_datasets( [dataset, dataset_demos], [1 - demonstration_ratio, demonstration_ratio]) # Batch and prefetch. dataset = dataset.batch(batch_size, drop_remainder=True) dataset = dataset.prefetch(prefetch_size) # Use constant 0.05 epsilon greedy policy by default. if epsilon is None: epsilon = tf.Variable(0.05, trainable=False) policy_network = snt.Sequential([ network, lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(), ]) # Create a target network. target_network = copy.deepcopy(network) # Ensure that we create the variables before proceeding (maybe not needed). tf2_utils.create_variables(network, [environment_spec.observations]) tf2_utils.create_variables(target_network, [environment_spec.observations]) # Create the actor which defines how we take actions. actor = actors.FeedForwardActor(policy_network, adder) # The learner updates the parameters (and initializes them). 
learner = dqn.DQNLearner(
    network=network,
    target_network=target_network,
    discount=discount,
    importance_sampling_exponent=importance_sampling_exponent,
    learning_rate=learning_rate,
    target_update_period=target_update_period,
    dataset=dataset,
    replay_client=replay_client)

super().__init__(
    actor=actor,
    learner=learner,
    min_observations=max(batch_size, min_replay_size),
    observations_per_step=float(batch_size) / samples_per_insert)
def __init__( self, environment_spec: specs.EnvironmentSpec, policy_network: snt.Module, critic_network: snt.Module, observation_network: types.TensorTransformation = tf.identity, discount: float = 0.99, batch_size: int = 256, prefetch_size: int = 4, target_policy_update_period: int = 100, target_critic_update_period: int = 100, min_replay_size: int = 1000, max_replay_size: int = 1000000, samples_per_insert: float = 32.0, policy_loss_module: snt.Module = None, policy_optimizer: snt.Optimizer = None, critic_optimizer: snt.Optimizer = None, n_step: int = 5, num_samples: int = 20, clipping: bool = True, logger: loggers.Logger = None, counter: counting.Counter = None, checkpoint: bool = True, replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, ): """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. policy_network: the online (optimized) policy. critic_network: the online critic. observation_network: optional network to transform the observations before they are fed into any network. discount: discount to use for TD updates. batch_size: batch size for updates. prefetch_size: size to prefetch from replay. target_policy_update_period: number of updates to perform before updating the target policy network. target_critic_update_period: number of updates to perform before updating the target critic network. min_replay_size: minimum replay size before updating. max_replay_size: maximum replay size. samples_per_insert: number of samples to take from replay for every insert that is made. policy_loss_module: configured MPO loss function for the policy optimization; defaults to sensible values on the control suite. See `acme/tf/losses/mpo.py` for more details. policy_optimizer: optimizer to be used on the policy. critic_optimizer: optimizer to be used on the critic. n_step: number of steps to squash into a single transition. num_samples: number of actions to sample when doing a Monte Carlo integration with respect to the policy. clipping: whether to clip gradients by global norm. logger: logging object used to write to logs. counter: counter object used to keep track of steps. checkpoint: boolean indicating whether to checkpoint the learner. replay_table_name: string indicating what name to give the replay table. """ # Create a replay server to add data to. replay_table = reverb.Table( name=adders.DEFAULT_PRIORITY_TABLE, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), signature=adders.NStepTransitionAdder.signature(environment_spec)) self._server = reverb.Server([replay_table], port=None) # The adder is used to insert observations into replay. address = f'localhost:{self._server.port}' adder = adders.NStepTransitionAdder(client=reverb.Client(address), n_step=n_step, discount=discount) # The dataset object to learn from. dataset = datasets.make_reverb_dataset( table=replay_table_name, client=reverb.TFClient(address), batch_size=batch_size, prefetch_size=prefetch_size, environment_spec=environment_spec, transition_adder=True) # Make sure observation network is a Sonnet Module. observation_network = tf2_utils.to_sonnet_module(observation_network) # Create target networks before creating online/target network variables. target_policy_network = copy.deepcopy(policy_network) target_critic_network = copy.deepcopy(critic_network) target_observation_network = copy.deepcopy(observation_network) # Get observation and action specs. 
act_spec = environment_spec.actions obs_spec = environment_spec.observations emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) # Create the behavior policy. behavior_network = snt.Sequential([ observation_network, policy_network, networks.StochasticSamplingHead(), ]) # Create variables. tf2_utils.create_variables(policy_network, [emb_spec]) tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) tf2_utils.create_variables(target_policy_network, [emb_spec]) tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec]) tf2_utils.create_variables(target_observation_network, [obs_spec]) # Create the actor which defines how we take actions. actor = actors.FeedForwardActor(policy_network=behavior_network, adder=adder) # Create optimizers. policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) # The learner updates the parameters (and initializes them). learner = learning.MPOLearner( policy_network=policy_network, critic_network=critic_network, observation_network=observation_network, target_policy_network=target_policy_network, target_critic_network=target_critic_network, target_observation_network=target_observation_network, policy_loss_module=policy_loss_module, policy_optimizer=policy_optimizer, critic_optimizer=critic_optimizer, clipping=clipping, discount=discount, num_samples=num_samples, target_policy_update_period=target_policy_update_period, target_critic_update_period=target_critic_update_period, dataset=dataset, logger=logger, counter=counter, checkpoint=checkpoint) super().__init__(actor=actor, learner=learner, min_observations=max(batch_size, min_replay_size), observations_per_step=float(batch_size) / samples_per_insert)
def __init__( self, environment_spec: specs.EnvironmentSpec, network: snt.Module, batch_size: int = 256, prefetch_size: int = 4, target_update_period: int = 100, samples_per_insert: float = 32.0, min_replay_size: int = 1000, max_replay_size: int = 1000000, importance_sampling_exponent: float = 0.2, priority_exponent: float = 0.6, n_step: int = 5, epsilon: tf.Tensor = None, learning_rate: float = 1e-3, discount: float = 0.99, cql_alpha: float = 1., logger: loggers.Logger = None, counter: counting.Counter = None, checkpoint_subpath: str = '~/acme/', ): """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. network: the online Q network (the one being optimized) batch_size: batch size for updates. prefetch_size: size to prefetch from replay. target_update_period: number of learner steps to perform before updating the target networks. samples_per_insert: number of samples to take from replay for every insert that is made. min_replay_size: minimum replay size before updating. This and all following arguments are related to dataset construction and will be ignored if a dataset argument is passed. max_replay_size: maximum replay size. importance_sampling_exponent: power to which importance weights are raised before normalizing. priority_exponent: exponent used in prioritized sampling. n_step: number of steps to squash into a single transition. epsilon: probability of taking a random action; ignored if a policy network is given. learning_rate: learning rate for the q-network update. discount: discount to use for TD updates. logger: logger object to be used by learner. checkpoint: boolean indicating whether to checkpoint the learner. checkpoint_subpath: directory for the checkpoint. """ # Create a replay server to add data to. This uses no limiter behavior in # order to allow the Agent interface to handle it. replay_table = reverb.Table( name=adders.DEFAULT_PRIORITY_TABLE, sampler=reverb.selectors.Prioritized(priority_exponent), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(1), signature=adders.NStepTransitionAdder.signature(environment_spec)) self._server = reverb.Server([replay_table], port=None) # The adder is used to insert observations into replay. address = f'localhost:{self._server.port}' adder = adders.NStepTransitionAdder(client=reverb.Client(address), n_step=n_step, discount=discount) # The dataset provides an interface to sample from replay. replay_client = reverb.TFClient(address) dataset = datasets.make_reverb_dataset( client=replay_client, environment_spec=environment_spec, batch_size=batch_size, prefetch_size=prefetch_size, transition_adder=True) # Use constant 0.05 epsilon greedy policy by default. if epsilon is None: epsilon = tf.Variable(0.05, trainable=False) policy_network = snt.Sequential([ network, lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(), ]) # Create a target network. target_network = copy.deepcopy(network) # Ensure that we create the variables before proceeding (maybe not needed). tf2_utils.create_variables(network, [environment_spec.observations]) tf2_utils.create_variables(target_network, [environment_spec.observations]) # Create the actor which defines how we take actions. actor = actors.FeedForwardActor(policy_network, adder) # The learner updates the parameters (and initializes them). 
learner = CQLLearner(
    network=network,
    discount=discount,
    importance_sampling_exponent=importance_sampling_exponent,
    learning_rate=learning_rate,
    cql_alpha=cql_alpha,
    target_update_period=target_update_period,
    dataset=dataset,
    replay_client=replay_client,
    logger=logger,
    counter=counter,
    checkpoint_subpath=checkpoint_subpath)

super().__init__(
    actor=actor,
    learner=learner,
    min_observations=max(batch_size, min_replay_size),
    observations_per_step=float(batch_size) / samples_per_insert)
def __init__( self, environment_spec: specs.EnvironmentSpec, network: hk.Transformed, batch_size: int = 256, prefetch_size: int = 4, target_update_period: int = 100, samples_per_insert: float = 32.0, min_replay_size: int = 1000, max_replay_size: int = 1000000, importance_sampling_exponent: float = 0.2, priority_exponent: float = 0.6, n_step: int = 5, epsilon: float = 0., learning_rate: float = 1e-3, discount: float = 0.99, seed: int = 1, ): """Initialize the agent.""" # Create a replay server to add data to. This uses no limiter behavior in # order to allow the Agent interface to handle it. replay_table = reverb.Table( name=adders.DEFAULT_PRIORITY_TABLE, sampler=reverb.selectors.Prioritized(priority_exponent), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(1), signature=adders.NStepTransitionAdder.signature( environment_spec=environment_spec)) self._server = reverb.Server([replay_table], port=None) # The adder is used to insert observations into replay. address = f'localhost:{self._server.port}' adder = adders.NStepTransitionAdder(client=reverb.Client(address), n_step=n_step, discount=discount) # The dataset provides an interface to sample from replay. dataset = datasets.make_reverb_dataset( server_address=address, environment_spec=environment_spec, batch_size=batch_size, prefetch_size=prefetch_size, transition_adder=True) def policy(params: hk.Params, key: jnp.ndarray, observation: jnp.ndarray) -> jnp.ndarray: action_values = network.apply(params, observation) return rlax.epsilon_greedy(epsilon).sample(key, action_values) # The learner updates the parameters (and initializes them). learner = learning.DQNLearner( network=network, obs_spec=environment_spec.observations, rng=hk.PRNGSequence(seed), optimizer=optax.adam(learning_rate), discount=discount, importance_sampling_exponent=importance_sampling_exponent, target_update_period=target_update_period, iterator=dataset.as_numpy_iterator(), replay_client=reverb.Client(address), ) variable_client = variable_utils.VariableClient(learner, '') actor = actors.FeedForwardActor(policy=policy, rng=hk.PRNGSequence(seed), variable_client=variable_client, adder=adder) super().__init__(actor=actor, learner=learner, min_observations=max(batch_size, min_replay_size), observations_per_step=float(batch_size) / samples_per_insert)
def __init__( self, environment_spec: specs.EnvironmentSpec, network: snt.RNNCore, burn_in_length: int, trace_length: int, replay_period: int, counter: counting.Counter = None, logger: loggers.Logger = None, discount: float = 0.99, batch_size: int = 32, prefetch_size: int = tf.data.experimental.AUTOTUNE, target_update_period: int = 100, importance_sampling_exponent: float = 0.2, priority_exponent: float = 0.6, epsilon_init: float = 1.0, epsilon_final: float = 0.01, epsilon_schedule_timesteps: float = 20000, learning_rate: float = 1e-3, min_replay_size: int = 1000, max_replay_size: int = 1000000, samples_per_insert: float = 32.0, store_lstm_state: bool = True, max_priority_weight: float = 0.9, checkpoint: bool = True, ): if store_lstm_state: extra_spec = { 'core_state': tf2_utils.squeeze_batch_dim(network.initial_state(1)), } else: extra_spec = () replay_table = reverb.Table( name=adders.DEFAULT_PRIORITY_TABLE, sampler=reverb.selectors.Prioritized(priority_exponent), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), signature=adders.SequenceAdder.signature(environment_spec, extra_spec)) self._server = reverb.Server([replay_table], port=None) address = f'localhost:{self._server.port}' sequence_length = burn_in_length + trace_length + 1 # Component to add things into replay. self._adder = adders.SequenceAdder( client=reverb.Client(address), period=replay_period, sequence_length=sequence_length, ) # The dataset object to learn from. dataset = make_reverb_dataset(server_address=address, batch_size=batch_size, prefetch_size=prefetch_size, sequence_length=sequence_length) target_network = copy.deepcopy(network) tf2_utils.create_variables(network, [environment_spec.observations]) tf2_utils.create_variables(target_network, [environment_spec.observations]) learner = learning.R2D2Learner( environment_spec=environment_spec, network=network, target_network=target_network, burn_in_length=burn_in_length, sequence_length=sequence_length, dataset=dataset, reverb_client=reverb.TFClient(address), counter=counter, logger=logger, discount=discount, target_update_period=target_update_period, importance_sampling_exponent=importance_sampling_exponent, max_replay_size=max_replay_size, learning_rate=learning_rate, store_lstm_state=store_lstm_state, max_priority_weight=max_priority_weight, ) self._saver = tf2_savers.Saver(learner.state) policy_network = snt.DeepRNN([ network, EpsilonGreedyExploration( epsilon_init=epsilon_init, epsilon_final=epsilon_final, epsilon_schedule_timesteps=epsilon_schedule_timesteps) ]) actor = actors.RecurrentActor(policy_network, self._adder, store_recurrent_state=store_lstm_state) max_Q_network = snt.DeepRNN([ network, lambda qs: trfl.epsilon_greedy(qs, epsilon=0.0).sample(), ]) self._deterministic_actor = actors.RecurrentActor( max_Q_network, self._adder, store_recurrent_state=store_lstm_state) observations_per_step = (float(replay_period * batch_size) / samples_per_insert) super().__init__(actor=actor, learner=learner, min_observations=replay_period * max(batch_size, min_replay_size), observations_per_step=observations_per_step)