def __init__(self, name="ReverbUniformReplayBuffer", reverb_server=None):
    super().__init__(name=name)
    self.device = Params.DEVICE

    with tf.device(self.device), self.name_scope:

        self.buffer_size = tf.cast(Params.BUFFER_SIZE, tf.int64)
        self.batch_size = tf.cast(Params.MINIBATCH_SIZE, tf.int64)
        self.batch_size_float = tf.cast(Params.MINIBATCH_SIZE, tf.float64)
        self.sequence_length = tf.cast(Params.N_STEP_RETURNS, tf.int64)

        # Initialize the reverb server
        if not reverb_server:
            self.reverb_server = reverb.Server(
                tables=[
                    reverb.Table(
                        name=Params.BUFFER_TYPE,
                        sampler=reverb.selectors.Uniform(),
                        remover=reverb.selectors.Fifo(),
                        max_size=self.buffer_size,
                        rate_limiter=reverb.rate_limiters.MinSize(self.batch_size),
                    )
                ],
            )
        else:
            self.reverb_server = reverb_server

        dataset = reverb.ReplayDataset(
            server_address=f'localhost:{self.reverb_server.port}',
            table=Params.BUFFER_TYPE,
            max_in_flight_samples_per_worker=2 * self.batch_size,
            dtypes=Params.BUFFER_DATA_SPEC_DTYPES,
            shapes=Params.BUFFER_DATA_SPEC_SHAPES,
        )
        dataset = dataset.map(
            map_func=reduce_trajectory,
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
            deterministic=True,
        )
        dataset = dataset.batch(self.batch_size)
        dataset = dataset.prefetch(5)
        self.iterator = dataset.__iter__()
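# A minimal sketch of how a training step might consume the iterator built above.
# `buffer` is a hypothetical instance of this class; the contents of `sample.data`
# depend on Params.BUFFER_DATA_SPEC_* and on what reduce_trajectory returns.
sample = next(buffer.iterator)            # reverb.ReplaySample with batched tensors
batch = sample.data                       # structure defined by the buffer data spec
keys = sample.info.key                    # per-item Reverb keys (useful for priority updates)
probabilities = sample.info.probability   # uniform here, since the sampler is Uniform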
def _create_server(
    table: Text = reverb_variable_container.DEFAULT_TABLE,
    max_size: int = 1,
    signature: types.NestedTensorSpec = (
        tf.TensorSpec((), tf.int64),
        {
            'var1': (tf.TensorSpec((2), tf.float64),),
            'var2': tf.TensorSpec((2, 1), tf.int32),
        },
    )
) -> Tuple[reverb.Server, Text]:
    server = reverb.Server(
        tables=[
            reverb.Table(
                name=table,
                sampler=reverb.selectors.Uniform(),
                remover=reverb.selectors.Fifo(),
                rate_limiter=reverb.rate_limiters.MinSize(1),
                max_size=max_size,
                max_times_sampled=0,
                signature=signature)
        ],
        port=portpicker.pick_unused_port())
    return server, 'localhost:{}'.format(server.port)
def __init__(self,
             num_tables: int = 1,
             min_size: int = 64,
             max_size: int = 100000,
             checkpointer=None):
    self._min_size = min_size
    self._table_names = [f"uniform_table_{i}" for i in range(num_tables)]
    self._server = reverb.Server(
        tables=[
            reverb.Table(
                name=self._table_names[i],
                sampler=reverb.selectors.Uniform(),
                remover=reverb.selectors.Fifo(),
                max_size=int(max_size),
                rate_limiter=reverb.rate_limiters.MinSize(min_size),
            ) for i in range(num_tables)
        ],
        # Sets the port to None to make the server pick one automatically.
        port=None,
        checkpointer=checkpointer)
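# A minimal sketch of writing to and reading from one of the uniform tables above
# with a reverb.Client. `buffer` and the item contents are assumptions; only the
# table name and the MinSize rate limiter come from the constructor above.
import numpy as np
import reverb

client = reverb.Client(f"localhost:{buffer._server.port}")

# Insert a single item with unit priority (the Uniform sampler ignores priorities).
client.insert([np.zeros((4,), np.float32), np.int64(1)],
              priorities={"uniform_table_0": 1.0})

# Sampling blocks until the MinSize rate limiter is satisfied (min_size items).
for sample in client.sample("uniform_table_0", num_samples=1):
    print(sample)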
def create_reverb_server_for_replay_buffer_and_variable_container(
        collect_policy, train_step, replay_buffer_capacity, port):
    """Sets up one reverb server for replay buffer and variable container."""
    # Create the signature for the variable container holding the policy weights.
    variables = {
        reverb_variable_container.POLICY_KEY: collect_policy.variables(),
        reverb_variable_container.TRAIN_STEP_KEY: train_step
    }
    variable_container_signature = tf.nest.map_structure(
        lambda variable: tf.TensorSpec(variable.shape, dtype=variable.dtype),
        variables)

    # Create the signature for the replay buffer holding observed experience.
    replay_buffer_signature = tensor_spec.from_spec(
        collect_policy.collect_data_spec)

    # Create and start the replay buffer and variable container server.
    server = reverb.Server(
        tables=[
            reverb.Table(  # Replay buffer storing experience.
                name=reverb_replay_buffer.DEFAULT_TABLE,
                sampler=reverb.selectors.Uniform(),
                remover=reverb.selectors.Fifo(),
                # TODO(b/159073060): Set rate limiter for SAC properly.
                rate_limiter=reverb.rate_limiters.MinSize(1),
                max_size=replay_buffer_capacity,
                max_times_sampled=0,
                signature=replay_buffer_signature,
            ),
            reverb.Table(  # Variable container storing policy parameters.
                name=reverb_variable_container.DEFAULT_TABLE,
                sampler=reverb.selectors.Uniform(),
                remover=reverb.selectors.Fifo(),
                rate_limiter=reverb.rate_limiters.MinSize(1),
                max_size=1,
                max_times_sampled=0,
                signature=variable_container_signature,
            ),
        ],
        port=port)
    return server
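# A minimal sketch of attaching a TF-Agents replay buffer to the experience table of
# the server returned above. The port, capacity, sequence length, and batch size are
# assumptions; the table name and data spec come from the function itself.
server = create_reverb_server_for_replay_buffer_and_variable_container(
    collect_policy, train_step, replay_buffer_capacity=100000, port=8008)

reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
    collect_policy.collect_data_spec,
    table_name=reverb_replay_buffer.DEFAULT_TABLE,
    sequence_length=2,
    server_address=f'localhost:{server.port}')

dataset = reverb_replay.as_dataset(sample_batch_size=256, num_steps=2)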
def test_uniform_table(self):
    table_name = 'test_uniform_table'
    queue_table = reverb.Table(
        table_name,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=1000,
        rate_limiter=reverb.rate_limiters.MinSize(3))
    reverb_server = reverb.Server([queue_table])
    data_spec = tensor_spec.TensorSpec((), dtype=tf.int64)
    replay = reverb_replay_buffer.ReverbReplayBuffer(
        data_spec,
        table_name,
        local_server=reverb_server,
        sequence_length=1,
        dataset_buffer_size=1)

    with replay.py_client.trajectory_writer(num_keep_alive_refs=1) as writer:
        for i in range(3):
            writer.append(i)
            trajectory = writer.history[-1:]
            writer.create_item(table_name, trajectory=trajectory, priority=1)

    dataset = replay.as_dataset(
        sample_batch_size=1, num_steps=None, num_parallel_calls=1)
    iterator = iter(dataset)

    counts = [0] * 3
    for i in range(1000):
        item_0 = next(iterator)[0].numpy()  # This is a matrix shaped 1x1.
        counts[int(item_0)] += 1

    # Comparing against 200 to avoid flakiness.
    self.assertGreater(counts[0], 200)
    self.assertGreater(counts[1], 200)
    self.assertGreater(counts[2], 200)
def test_reverb(data):
    # CAPACITY, BATCH_SIZE, and TEST_CNT are expected as module-level constants.
    import time

    import numpy as np
    import reverb
    import tensorflow as tf

    print("TEST REVERB")
    print("initializing...")
    reverb_server = reverb.Server(
        tables=[
            reverb.Table(
                name="req",
                sampler=reverb.selectors.Prioritized(0.6),
                remover=reverb.selectors.Fifo(),
                max_size=CAPACITY,
                rate_limiter=reverb.rate_limiters.MinSize(100),
            )
        ],
        port=15867,
    )
    client = reverb_server.in_process_client()
    for i in range(CAPACITY):
        client.insert([col[i] for col in data], {"req": np.random.rand()})

    dataset = reverb.ReplayDataset(
        server_address="localhost:15867",
        table="req",
        dtypes=(tf.float64, tf.float64, tf.float64, tf.float64),
        shapes=(
            tf.TensorShape([1, 84, 84]),
            tf.TensorShape([1, 84, 84]),
            tf.TensorShape([]),
            tf.TensorShape([]),
        ),
        max_in_flight_samples_per_worker=10,
    )
    dataset = dataset.batch(BATCH_SIZE)

    print("ready")
    t0 = time.perf_counter()
    for sample in dataset.take(TEST_CNT):
        pass
    t1 = time.perf_counter()
    print(TEST_CNT, t1 - t0)
def test_prioritized_table_max_sample(self):
    table_name = 'test_prioritized_table'
    table = reverb.Table(
        table_name,
        sampler=reverb.selectors.Prioritized(1.0),
        remover=reverb.selectors.Fifo(),
        max_times_sampled=10,
        rate_limiter=reverb.rate_limiters.MinSize(1),
        max_size=3)
    reverb_server = reverb.Server([table])
    data_spec = tensor_spec.TensorSpec((), dtype=tf.int64)
    replay = reverb_replay_buffer.ReverbReplayBuffer(
        data_spec,
        table_name,
        sequence_length=1,
        local_server=reverb_server,
        dataset_buffer_size=1)

    with replay.py_client.trajectory_writer(1) as writer:
        for i in range(3):
            writer.append(i)
            writer.create_item(
                table_name, trajectory=writer.history[-1:], priority=i)

    dataset = replay.as_dataset(sample_batch_size=3, num_parallel_calls=3)

    self.assertTrue(table.can_sample(3))
    iterator = iter(dataset)
    counts = [0] * 3
    for i in range(10):
        item_0 = next(iterator)[0].numpy()  # This is a matrix shaped 1x3.
        for item in item_0:
            counts[int(item)] += 1
    self.assertFalse(table.can_sample(3))

    # Same number of counts due to limit on max_times_sampled
    self.assertEqual(counts[0], 10)  # priority 0
    self.assertEqual(counts[1], 10)  # priority 1
    self.assertEqual(counts[2], 10)  # priority 2
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    eval_table_size = _CONFIG.value.num_eval_points

    # TODO(joshgreaves): Choose an appropriate rate_limiter, max_size.
    server = reverb.Server(
        tables=[
            reverb.Table(
                name='successor_table',
                sampler=reverb.selectors.Uniform(),
                remover=reverb.selectors.Fifo(),
                max_size=1_000_000,
                rate_limiter=reverb.rate_limiters.MinSize(20_000)),
            reverb.Table(
                name='eval_table',
                sampler=reverb.selectors.Fifo(),
                remover=reverb.selectors.Fifo(),
                max_size=eval_table_size,
                rate_limiter=reverb.rate_limiters.MinSize(eval_table_size)),
        ],
        port=FLAGS.port)
    server.wait()
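# A minimal sketch of how a separate collector process might connect to the server
# above and add data. The payload shape is an assumption; the table names, the
# 20_000-item warm-up, and the --port flag come from the definitions above.
import numpy as np
import reverb

client = reverb.Client(f'localhost:{FLAGS.port}')

# Uniform 'successor_table': items become sampleable once 20_000 are present.
client.insert([np.zeros((8,), np.float32)], priorities={'successor_table': 1.0})

# 'eval_table' is sampled FIFO: the priority is ignored by the Fifo selector
# but must still be supplied on insert.
client.insert([np.zeros((8,), np.float32)], priorities={'eval_table': 1.0})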
def make_reverb_online_queue(
    environment_spec: specs.EnvironmentSpec,
    extra_spec: Dict[str, Any],
    max_queue_size: int,
    sequence_length: int,
    sequence_period: int,
    batch_size: int,
    replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
) -> ReverbReplay:
    """Creates a single process queue from an environment spec and extra_spec."""
    signature = adders.SequenceAdder.signature(environment_spec, extra_spec)
    queue = reverb.Table.queue(
        name=replay_table_name, max_size=max_queue_size, signature=signature)
    server = reverb.Server([queue], port=None)
    can_sample = lambda: queue.can_sample(batch_size)

    # Component to add things into replay.
    address = f'localhost:{server.port}'
    adder = adders.SequenceAdder(
        client=reverb.Client(address),
        period=sequence_period,
        sequence_length=sequence_length,
    )

    # The dataset object to learn from.
    # We don't use datasets.make_reverb_dataset() here to avoid interleaving
    # and prefetching, which don't work well with the can_sample() check on update.
    dataset = reverb.ReplayDataset.from_table_signature(
        server_address=address,
        table=replay_table_name,
        max_in_flight_samples_per_worker=1,
        sequence_length=sequence_length,
        emit_timesteps=False)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    data_iterator = dataset.as_numpy_iterator()
    return ReverbReplay(server, adder, data_iterator, can_sample=can_sample)
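# A minimal sketch of driving the returned ReverbReplay in a single-process loop.
# The environment, `actor_policy`, and `learner_step` are assumptions; the adder,
# can_sample, and data_iterator fields come from the function above.
replay = make_reverb_online_queue(
    environment_spec=spec, extra_spec={}, max_queue_size=1000,
    sequence_length=20, sequence_period=20, batch_size=16)

timestep = environment.reset()
replay.adder.add_first(timestep)
while True:
    action = actor_policy(timestep.observation)   # assumed policy callable
    timestep = environment.step(action)
    replay.adder.add(action, timestep)
    while replay.can_sample():                    # drain the queue as full batches arrive
        batch = next(replay.data_iterator)        # reverb.ReplaySample of numpy arrays
        learner_step(batch)                       # assumed learner update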
def test_prioritized_table(self):
    table_name = 'test_prioritized_table'
    queue_table = reverb.Table(
        table_name,
        sampler=reverb.selectors.Prioritized(1.0),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        max_size=3)
    reverb_server = reverb.Server([queue_table])
    data_spec = tensor_spec.TensorSpec((), dtype=tf.int64)
    replay = reverb_replay_buffer.ReverbReplayBuffer(
        data_spec,
        table_name,
        sequence_length=1,
        local_server=reverb_server,
        dataset_buffer_size=1)

    with replay.py_client.writer(max_sequence_length=1) as writer:
        for i in range(3):
            writer.append(i)
            writer.create_item(table=table_name, num_timesteps=1, priority=i)

    dataset = replay.as_dataset(
        sample_batch_size=1, num_steps=None, num_parallel_calls=None)
    iterator = iter(dataset)

    counts = [0] * 3
    for i in range(1000):
        item_0 = next(iterator)[0].numpy()  # This is a matrix shaped 1x1.
        counts[int(item_0)] += 1

    self.assertEqual(counts[0], 0)      # priority 0
    self.assertGreater(counts[1], 250)  # priority 1
    self.assertGreater(counts[2], 600)  # priority 2
def main(_):
    environment = fakes.ContinuousEnvironment(
        action_dim=8, observation_dim=87, episode_length=10000000)
    spec = specs.make_environment_spec(environment)
    replay_tables = make_replay_tables(spec)
    replay_server = reverb.Server(replay_tables, port=None)
    replay_client = reverb.Client(f'localhost:{replay_server.port}')
    adder = make_adder(replay_client)

    timestep = environment.reset()
    adder.add_first(timestep)
    # TODO(raveman): Consider also filling the table to say 1M (too slow).
    for steps in range(10000):
        if steps % 1000 == 0:
            logging.info('Processed %s steps', steps)
        action = np.asarray(np.random.uniform(-1, 1, (8,)), dtype=np.float32)
        next_timestep = environment.step(action)
        adder.add(action, next_timestep, extras=())

    for batch_size in [256, 256 * 8, 256 * 64]:
        for prefetch_size in [0, 1, 4]:
            print(f'Processing batch_size={batch_size} prefetch_size={prefetch_size}')
            ds = datasets.make_reverb_dataset(
                table='default',
                server_address=replay_client.server_address,
                batch_size=batch_size,
                prefetch_size=prefetch_size,
            )
            it = ds.as_numpy_iterator()

            for iteration in range(3):
                t = time.time()
                for _ in range(1000):
                    _ = next(it)
                print(f'Iteration {iteration} finished in {time.time() - t}s')
def build(self):
    """Creates reverb server, client and dataset."""
    self._reverb_server = reverb.Server(
        tables=[
            reverb.Table(
                name="replay_buffer",
                sampler=reverb.selectors.Uniform(),
                remover=reverb.selectors.Fifo(),
                max_size=self._max_replay_size,
                rate_limiter=reverb.rate_limiters.MinSize(1),
                signature=self._signature,
            ),
        ],
        port=None,
    )
    self._reverb_client = reverb.Client(f"localhost:{self._reverb_server.port}")
    self._reverb_dataset = reverb.TrajectoryDataset.from_table_signature(
        server_address=f"localhost:{self._reverb_server.port}",
        table="replay_buffer",
        max_in_flight_samples_per_worker=2 * self._batch_size,
    )
    self._batched_dataset = self._reverb_dataset.batch(
        self._batch_size, drop_remainder=True
    ).as_numpy_iterator()
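# A minimal sketch of the write and read paths for this buffer. `buffer` is a
# hypothetical instance whose build() has run, and the "observation"/"reward"
# keys are assumptions standing in for whatever self._signature declares.
import numpy as np

with buffer._reverb_client.trajectory_writer(num_keep_alive_refs=1) as writer:
    writer.append({"observation": np.zeros((4,), np.float32),
                   "reward": np.float32(0.0)})
    writer.create_item(
        table="replay_buffer",
        priority=1.0,
        # The trajectory structure must match the table signature.
        trajectory={
            "observation": writer.history["observation"][-1],
            "reward": writer.history["reward"][-1],
        },
    )

# Read path: each element is a reverb.ReplaySample of batched numpy arrays.
sample = next(buffer._batched_dataset)
batch = sample.data  # matches self._signature, with a leading batch dimension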
def __init__(self, environment_spec: specs.EnvironmentSpec, policy_network: snt.Module, critic_network: snt.Module, observation_network: types.TensorTransformation = tf.identity, discount: float = 0.99, batch_size: int = 256, prefetch_size: int = 4, target_update_period: int = 100, min_replay_size: int = 1000, max_replay_size: int = 1000000, samples_per_insert: float = 32.0, n_step: int = 5, sigma: float = 0.3, clipping: bool = True, logger: loggers.Logger = None, counter: counting.Counter = None, checkpoint: bool = True, replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE): """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. policy_network: the online (optimized) policy. critic_network: the online critic. observation_network: optional network to transform the observations before they are fed into any network. discount: discount to use for TD updates. batch_size: batch size for updates. prefetch_size: size to prefetch from replay. target_update_period: number of learner steps to perform before updating the target networks. min_replay_size: minimum replay size before updating. max_replay_size: maximum replay size. samples_per_insert: number of samples to take from replay for every insert that is made. n_step: number of steps to squash into a single transition. sigma: standard deviation of zero-mean, Gaussian exploration noise. clipping: whether to clip gradients by global norm. logger: logger object to be used by learner. counter: counter object used to keep track of steps. checkpoint: boolean indicating whether to checkpoint the learner. replay_table_name: string indicating what name to give the replay table. """ # Create a replay server to add data to. This uses no limiter behavior in # order to allow the Agent interface to handle it. replay_table = reverb.Table( name=replay_table_name, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(1), signature=adders.NStepTransitionAdder.signature(environment_spec)) self._server = reverb.Server([replay_table], port=None) # The adder is used to insert observations into replay. address = f'localhost:{self._server.port}' adder = adders.NStepTransitionAdder( priority_fns={replay_table_name: lambda x: 1.}, client=reverb.Client(address), n_step=n_step, discount=discount) # The dataset provides an interface to sample from replay. dataset = datasets.make_reverb_dataset( table=replay_table_name, client=reverb.TFClient(address), environment_spec=environment_spec, batch_size=batch_size, prefetch_size=prefetch_size, transition_adder=True) # Get observation and action specs. act_spec = environment_spec.actions obs_spec = environment_spec.observations emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) # pytype: disable=wrong-arg-types # Make sure observation network is a Sonnet Module. observation_network = tf2_utils.to_sonnet_module(observation_network) # Create target networks. target_policy_network = copy.deepcopy(policy_network) target_critic_network = copy.deepcopy(critic_network) target_observation_network = copy.deepcopy(observation_network) # Create the behavior policy. behavior_network = snt.Sequential([ observation_network, policy_network, networks.ClippedGaussian(sigma), networks.ClipToSpec(act_spec), ]) # Create variables. 
tf2_utils.create_variables(policy_network, [emb_spec]) tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) tf2_utils.create_variables(target_policy_network, [emb_spec]) tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec]) tf2_utils.create_variables(target_observation_network, [obs_spec]) # Create the actor which defines how we take actions. actor = actors.FeedForwardActor(behavior_network, adder=adder) # Create optimizers. policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4) critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4) # The learner updates the parameters (and initializes them). learner = learning.DDPGLearner( policy_network=policy_network, critic_network=critic_network, observation_network=observation_network, target_policy_network=target_policy_network, target_critic_network=target_critic_network, target_observation_network=target_observation_network, policy_optimizer=policy_optimizer, critic_optimizer=critic_optimizer, clipping=clipping, discount=discount, target_update_period=target_update_period, dataset=dataset, counter=counter, logger=logger, checkpoint=checkpoint, ) super().__init__(actor=actor, learner=learner, min_observations=max(batch_size, min_replay_size), observations_per_step=float(batch_size) / samples_per_insert)
def __init__( self, environment_spec: specs.EnvironmentSpec, network: networks.PolicyValueRNN, initial_state_fn: Callable[[], networks.RNNState], sequence_length: int, sequence_period: int, counter: counting.Counter = None, logger: loggers.Logger = None, discount: float = 0.99, max_queue_size: int = 100000, batch_size: int = 16, learning_rate: float = 1e-3, entropy_cost: float = 0.01, baseline_cost: float = 0.5, seed: int = 0, max_abs_reward: float = np.inf, max_gradient_norm: float = np.inf, ): num_actions = environment_spec.actions.num_values self._logger = logger or loggers.TerminalLogger('agent') queue = reverb.Table.queue(name=adders.DEFAULT_PRIORITY_TABLE, max_size=max_queue_size) self._server = reverb.Server([queue], port=None) self._can_sample = lambda: queue.can_sample(batch_size) address = f'localhost:{self._server.port}' # Component to add things into replay. adder = adders.SequenceAdder( client=reverb.Client(address), period=sequence_period, sequence_length=sequence_length, ) # The dataset object to learn from. extra_spec = { 'core_state': hk.transform(initial_state_fn).apply(None), 'logits': np.ones(shape=(num_actions, ), dtype=np.float32) } # Remove batch dimensions. dataset = datasets.make_reverb_dataset( client=reverb.TFClient(address), environment_spec=environment_spec, batch_size=batch_size, extra_spec=extra_spec, sequence_length=sequence_length) rng = hk.PRNGSequence(seed) optimizer = optix.chain( optix.clip_by_global_norm(max_gradient_norm), optix.adam(learning_rate), ) self._learner = learning.IMPALALearner( obs_spec=environment_spec.observations, network=network, initial_state_fn=initial_state_fn, iterator=dataset.as_numpy_iterator(), rng=rng, counter=counter, logger=logger, optimizer=optimizer, discount=discount, entropy_cost=entropy_cost, baseline_cost=baseline_cost, max_abs_reward=max_abs_reward, ) variable_client = jax_variable_utils.VariableClient(self._learner, key='policy') self._actor = acting.IMPALAActor( network=network, initial_state_fn=initial_state_fn, rng=rng, adder=adder, variable_client=variable_client, )
def __init__( self, environment_spec: specs.EnvironmentSpec, network: snt.Module, demonstration_dataset: tf.data.Dataset, demonstration_ratio: float, batch_size: int = 256, prefetch_size: int = 4, target_update_period: int = 100, samples_per_insert: float = 32.0, min_replay_size: int = 1000, max_replay_size: int = 1000000, importance_sampling_exponent: float = 0.2, n_step: int = 5, epsilon: tf.Tensor = None, learning_rate: float = 1e-3, discount: float = 0.99, ): """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. network: the online Q network (the one being optimized) demonstration_dataset: tf.data.Dataset producing (timestep, action) tuples containing full episodes. demonstration_ratio: Ratio of transitions coming from demonstrations. batch_size: batch size for updates. prefetch_size: size to prefetch from replay. target_update_period: number of learner steps to perform before updating the target networks. samples_per_insert: number of samples to take from replay for every insert that is made. min_replay_size: minimum replay size before updating. This and all following arguments are related to dataset construction and will be ignored if a dataset argument is passed. max_replay_size: maximum replay size. importance_sampling_exponent: power to which importance weights are raised before normalizing. n_step: number of steps to squash into a single transition. epsilon: probability of taking a random action; ignored if a policy network is given. learning_rate: learning rate for the q-network update. discount: discount to use for TD updates. """ # Create a replay server to add data to. This uses no limiter behavior in # order to allow the Agent interface to handle it. replay_table = reverb.Table( name=adders.DEFAULT_PRIORITY_TABLE, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(1)) self._server = reverb.Server([replay_table], port=None) # The adder is used to insert observations into replay. address = f'localhost:{self._server.port}' adder = adders.NStepTransitionAdder(client=reverb.Client(address), n_step=n_step, discount=discount) # The dataset provides an interface to sample from replay. replay_client = reverb.TFClient(address) dataset = datasets.make_reverb_dataset( client=replay_client, environment_spec=environment_spec, transition_adder=True) # Combine with demonstration dataset. transition = functools.partial(_n_step_transition_from_episode, n_step=n_step, discount=discount) dataset_demos = demonstration_dataset.map(transition) dataset = tf.data.experimental.sample_from_datasets( [dataset, dataset_demos], [1 - demonstration_ratio, demonstration_ratio]) # Batch and prefetch. dataset = dataset.batch(batch_size, drop_remainder=True) dataset = dataset.prefetch(prefetch_size) # Use constant 0.05 epsilon greedy policy by default. if epsilon is None: epsilon = tf.Variable(0.05, trainable=False) policy_network = snt.Sequential([ network, lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(), ]) # Create a target network. target_network = copy.deepcopy(network) # Ensure that we create the variables before proceeding (maybe not needed). tf2_utils.create_variables(network, [environment_spec.observations]) tf2_utils.create_variables(target_network, [environment_spec.observations]) # Create the actor which defines how we take actions. actor = actors.FeedForwardActor(policy_network, adder) # The learner updates the parameters (and initializes them). 
learner = dqn.DQNLearner( network=network, target_network=target_network, discount=discount, importance_sampling_exponent=importance_sampling_exponent, learning_rate=learning_rate, target_update_period=target_update_period, dataset=dataset, replay_client=replay_client) super().__init__(actor=actor, learner=learner, min_observations=max(batch_size, min_replay_size), observations_per_step=float(batch_size) / samples_per_insert)
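# A minimal sketch of running the agent assembled above with Acme's standard
# environment loop. `make_environment`, `make_network`, the demonstration dataset,
# and the agent class name are assumptions; only the constructor arguments come
# from the snippet above.
import acme

environment = make_environment()                     # assumed helper
environment_spec = specs.make_environment_spec(environment)
agent = DemonstrationDQNAgent(                       # hypothetical name for the class above
    environment_spec=environment_spec,
    network=make_network(environment_spec.actions),  # assumed helper
    demonstration_dataset=demonstration_dataset,
    demonstration_ratio=0.5)

loop = acme.EnvironmentLoop(environment, agent)
loop.run(num_episodes=100)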
def __init__( self, # --- env_name: str, port: int, # --- actor_units: list, clip_mean_min: float, clip_mean_max: float, init_noise: float, # --- min_replay_size: int, max_replay_size: int, samples_per_insert: int, # --- db_path: str, ): super(Server, self).__init__(env_name, False) self._port = port # Init actor's network self.actor = Actor( units=actor_units, n_outputs=np.prod(self._env.action_space.shape), clip_mean_min=clip_mean_min, clip_mean_max=clip_mean_max, init_noise=init_noise, ) self.actor.build((None,) + self._env.observation_space.shape) # Show models details self.actor.summary() # Variables self._train_step = tf.Variable( 0, trainable=False, dtype=tf.uint64, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, shape=(), ) self._stop_agents = tf.Variable( False, trainable=False, dtype=tf.bool, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, shape=(), ) # Table for storing variables self._variable_container = VariableContainer( db_server=f"localhost:{self._port}", table="variable", variables={ "train_step": self._train_step, "stop_agents": self._stop_agents, "policy_variables": self.actor.variables, }, ) # Load DB from checkpoint or make a new one if db_path is None: checkpointer = None else: checkpointer = reverb.checkpointers.DefaultCheckpointer(path=db_path) if samples_per_insert: # 10% tolerance in rate samples_per_insert_tolerance = 0.1 * samples_per_insert error_buffer = min_replay_size * samples_per_insert_tolerance limiter = reverb.rate_limiters.SampleToInsertRatio( min_size_to_sample=min_replay_size, samples_per_insert=samples_per_insert, error_buffer=error_buffer, ) else: limiter = reverb.rate_limiters.MinSize(min_replay_size) # Initialize the reverb server self.server = reverb.Server( tables=[ reverb.Table( # Replay buffer name="experience", sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), rate_limiter=limiter, max_size=max_replay_size, max_times_sampled=0, signature={ "observation": tf.TensorSpec( [*self._env.observation_space.shape], self._env.observation_space.dtype, ), "action": tf.TensorSpec( [*self._env.action_space.shape], self._env.action_space.dtype, ), "reward": tf.TensorSpec([1], tf.float32), "next_observation": tf.TensorSpec( [*self._env.observation_space.shape], self._env.observation_space.dtype, ), "terminal": tf.TensorSpec([1], tf.bool), }, ), reverb.Table( # Variables container name="variable", sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1), max_size=1, max_times_sampled=0, signature=self._variable_container.signature, ), ], port=self._port, checkpointer=checkpointer, ) # Init variable container in DB self._variable_container.push_variables()
def train_eval( root_dir, env_name='HalfCheetah-v2', # Training params num_iterations=1600, actor_fc_layers=(64, 64), value_fc_layers=(64, 64), learning_rate=3e-4, collect_sequence_length=2048, minibatch_size=64, num_epochs=10, # Agent params importance_ratio_clipping=0.2, lambda_value=0.95, discount_factor=0.99, entropy_regularization=0., value_pred_loss_coef=0.5, use_gae=True, use_td_lambda_return=True, gradient_clipping=0.5, value_clipping=None, # Replay params reverb_port=None, replay_capacity=10000, # Others policy_save_interval=5000, summary_interval=1000, eval_interval=10000, eval_episodes=100, debug_summaries=False, summarize_grads_and_vars=False): """Trains and evaluates PPO (Importance Ratio Clipping). Args: root_dir: Main directory path where checkpoints, saved_models, and summaries will be written to. env_name: Name for the Mujoco environment to load. num_iterations: The number of iterations to perform collection and training. actor_fc_layers: List of fully_connected parameters for the actor network, where each item is the number of units in the layer. value_fc_layers: : List of fully_connected parameters for the value network, where each item is the number of units in the layer. learning_rate: Learning rate used on the Adam optimizer. collect_sequence_length: Number of steps to take in each collect run. minibatch_size: Number of elements in each mini batch. If `None`, the entire collected sequence will be treated as one batch. num_epochs: Number of iterations to repeat over all collected data per data collection step. (Schulman,2017) sets this to 10 for Mujoco, 15 for Roboschool and 3 for Atari. importance_ratio_clipping: Epsilon in clipped, surrogate PPO objective. For more detail, see explanation at the top of the doc. lambda_value: Lambda parameter for TD-lambda computation. discount_factor: Discount factor for return computation. Default to `0.99` which is the value used for all environments from (Schulman, 2017). entropy_regularization: Coefficient for entropy regularization loss term. Default to `0.0` because no entropy bonus was used in (Schulman, 2017). value_pred_loss_coef: Multiplier for value prediction loss to balance with policy gradient loss. Default to `0.5`, which was used for all environments in the OpenAI baseline implementation. This parameters is irrelevant unless you are sharing part of actor_net and value_net. In that case, you would want to tune this coeeficient, whose value depends on the network architecture of your choice. use_gae: If True (default False), uses generalized advantage estimation for computing per-timestep advantage. Else, just subtracts value predictions from empirical return. use_td_lambda_return: If True (default False), uses td_lambda_return for training value function; here: `td_lambda_return = gae_advantage + value_predictions`. `use_gae` must be set to `True` as well to enable TD -lambda returns. If `use_td_lambda_return` is set to True while `use_gae` is False, the empirical return will be used and a warning will be logged. gradient_clipping: Norm length to clip gradients. value_clipping: Difference between new and old value predictions are clipped to this threshold. Value clipping could be helpful when training very deep networks. Default: no clipping. reverb_port: Port for reverb server, if None, use a randomly chosen unused port. replay_capacity: The maximum number of elements for the replay buffer. Items will be wasted if this is smalled than collect_sequence_length. 
policy_save_interval: How often, in train_steps, the policy will be saved. summary_interval: How often to write data into Tensorboard. eval_interval: How often to run evaluation, in train_steps. eval_episodes: Number of episodes to evaluate over. debug_summaries: Boolean for whether to gather debug summaries. summarize_grads_and_vars: If true, gradient summaries will be written. """ collect_env = suite_mujoco.load(env_name) eval_env = suite_mujoco.load(env_name) num_environments = 1 observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = ( spec_utils.get_tensor_specs(collect_env)) # TODO(b/172267869): Remove this conversion once TensorNormalizer stops # converting float64 inputs to float32. observation_tensor_spec = tf.TensorSpec( dtype=tf.float32, shape=observation_tensor_spec.shape) train_step = train_utils.create_train_step() actor_net_builder = ppo_actor_network.PPOActorNetwork() actor_net = actor_net_builder.create_sequential_actor_net( actor_fc_layers, action_tensor_spec) value_net = value_network.ValueNetwork( observation_tensor_spec, fc_layer_params=value_fc_layers, kernel_initializer=tf.keras.initializers.Orthogonal()) current_iteration = tf.Variable(0, dtype=tf.int64) def learning_rate_fn(): # Linearly decay the learning rate. return learning_rate * (1 - current_iteration / num_iterations) agent = ppo_clip_agent.PPOClipAgent( time_step_tensor_spec, action_tensor_spec, optimizer=tf.keras.optimizers.Adam( learning_rate=learning_rate_fn, epsilon=1e-5), actor_net=actor_net, value_net=value_net, importance_ratio_clipping=importance_ratio_clipping, lambda_value=lambda_value, discount_factor=discount_factor, entropy_regularization=entropy_regularization, value_pred_loss_coef=value_pred_loss_coef, # This is a legacy argument for the number of times we repeat the data # inside of the train function, incompatible with mini batch learning. # We set the epoch number from the replay buffer and tf.Data instead. num_epochs=1, use_gae=use_gae, use_td_lambda_return=use_td_lambda_return, gradient_clipping=gradient_clipping, value_clipping=value_clipping, # TODO(b/150244758): Default compute_value_and_advantage_in_train to False # after Reverb open source. compute_value_and_advantage_in_train=False, # Skips updating normalizers in the agent, as it's handled in the learner. update_normalizers_in_train=False, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=train_step) agent.initialize() reverb_server = reverb.Server( [ reverb.Table( # Replay buffer storing experience for training. name='training_table', sampler=reverb.selectors.Fifo(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1), max_size=replay_capacity, max_times_sampled=1, ), reverb.Table( # Replay buffer storing experience for normalization. name='normalization_table', sampler=reverb.selectors.Fifo(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1), max_size=replay_capacity, max_times_sampled=1, ) ], port=reverb_port) # Create the replay buffer. reverb_replay_train = reverb_replay_buffer.ReverbReplayBuffer( agent.collect_data_spec, sequence_length=collect_sequence_length, table_name='training_table', server_address='localhost:{}'.format(reverb_server.port), # The only collected sequence is used to populate the batches. 
max_cycle_length=1, rate_limiter_timeout_ms=1000) reverb_replay_normalization = reverb_replay_buffer.ReverbReplayBuffer( agent.collect_data_spec, sequence_length=collect_sequence_length, table_name='normalization_table', server_address='localhost:{}'.format(reverb_server.port), # The only collected sequence is used to populate the batches. max_cycle_length=1, rate_limiter_timeout_ms=1000) rb_observer = reverb_utils.ReverbTrajectorySequenceObserver( reverb_replay_train.py_client, ['training_table', 'normalization_table'], sequence_length=collect_sequence_length, stride_length=collect_sequence_length) saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR) collect_env_step_metric = py_metrics.EnvironmentSteps() learning_triggers = [ triggers.PolicySavedModelTrigger( saved_model_dir, agent, train_step, interval=policy_save_interval, metadata_metrics={ triggers.ENV_STEP_METADATA_KEY: collect_env_step_metric }), triggers.StepPerSecondLogTrigger(train_step, interval=summary_interval), ] def training_dataset_fn(): return reverb_replay_train.as_dataset( sample_batch_size=num_environments, sequence_preprocess_fn=agent.preprocess_sequence) def normalization_dataset_fn(): return reverb_replay_normalization.as_dataset( sample_batch_size=num_environments, sequence_preprocess_fn=agent.preprocess_sequence) agent_learner = ppo_learner.PPOLearner( root_dir, train_step, agent, experience_dataset_fn=training_dataset_fn, normalization_dataset_fn=normalization_dataset_fn, num_samples=1, num_epochs=num_epochs, minibatch_size=minibatch_size, shuffle_buffer_size=collect_sequence_length, triggers=learning_triggers) tf_collect_policy = agent.collect_policy collect_policy = py_tf_eager_policy.PyTFEagerPolicy( tf_collect_policy, use_tf_function=True) collect_actor = actor.Actor( collect_env, collect_policy, train_step, steps_per_run=collect_sequence_length, observers=[rb_observer], metrics=actor.collect_metrics(buffer_size=10) + [collect_env_step_metric], reference_metrics=[collect_env_step_metric], summary_dir=os.path.join(root_dir, learner.TRAIN_DIR), summary_interval=summary_interval) eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy( agent.policy, use_tf_function=True) if eval_interval: logging.info('Intial evaluation.') eval_actor = actor.Actor( eval_env, eval_greedy_policy, train_step, metrics=actor.eval_metrics(eval_episodes), reference_metrics=[collect_env_step_metric], summary_dir=os.path.join(root_dir, 'eval'), episodes_per_run=eval_episodes) eval_actor.run_and_log() logging.info('Training on %s', env_name) last_eval_step = 0 for i in range(num_iterations): collect_actor.run() rb_observer.flush() agent_learner.run() reverb_replay_train.clear() reverb_replay_normalization.clear() current_iteration.assign_add(1) # Eval only if `eval_interval` has been set. Then, eval if the current train # step is equal or greater than the `last_eval_step` + `eval_interval` or if # this is the last iteration. This logic exists because agent_learner.run() # does not return after every train step. if (eval_interval and (agent_learner.train_step_numpy >= eval_interval + last_eval_step or i == num_iterations - 1)): logging.info('Evaluating.') eval_actor.run_and_log() last_eval_step = agent_learner.train_step_numpy rb_observer.close() reverb_server.stop()
def __init__( self, environment_spec: specs.EnvironmentSpec, network: snt.RNNCore, target_network: snt.RNNCore, burn_in_length: int, trace_length: int, replay_period: int, demonstration_generator: iter, demonstration_ratio: float, model_directory: str, counter: counting.Counter = None, logger: loggers.Logger = None, discount: float = 0.99, batch_size: int = 32, target_update_period: int = 100, importance_sampling_exponent: float = 0.2, epsilon: float = 0.01, learning_rate: float = 1e-3, log_to_bigtable: bool = False, log_name: str = 'agent', checkpoint: bool = True, min_replay_size: int = 1000, max_replay_size: int = 1000000, samples_per_insert: float = 32.0, ): extra_spec = { 'core_state': network.initial_state(1), } # replay table # Remove batch dimensions. extra_spec = tf2_utils.squeeze_batch_dim(extra_spec) replay_table = reverb.Table( name=adders.DEFAULT_PRIORITY_TABLE, sampler=reverb.selectors.Prioritized(0.8), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), signature=adders.SequenceAdder.signature(environment_spec, extra_spec)) # demonstation table. demonstration_table = reverb.Table( name='demonstration_table', sampler=reverb.selectors.Prioritized(0.8), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), signature=adders.SequenceAdder.signature(environment_spec, extra_spec)) # launch server self._server = reverb.Server([replay_table, demonstration_table], port=None) address = f'localhost:{self._server.port}' sequence_length = burn_in_length + trace_length + 1 # Component to add things into replay and demo sequence_kwargs = dict( period=replay_period, sequence_length=sequence_length, ) adder = adders.SequenceAdder(client=reverb.Client(address), **sequence_kwargs) priority_function = {demonstration_table.name: lambda x: 1.} demo_adder = adders.SequenceAdder(client=reverb.Client(address), priority_fns=priority_function, **sequence_kwargs) # play demonstrations and write # exhaust the generator # TODO: MAX REPLAY SIZE _prev_action = 1 # this has to come from spec _add_first = True #include this to make datasets equivalent numpy_state = tf2_utils.to_numpy_squeeze(network.initial_state(1)) for ts, action in demonstration_generator: if _add_first: demo_adder.add_first(ts) _add_first = False else: demo_adder.add(_prev_action, ts, extras=(numpy_state, )) _prev_action = action # reset to new episode if ts.last(): _prev_action = None _add_first = True # replay dataset max_in_flight_samples_per_worker = 2 * batch_size if batch_size else 100 dataset = reverb.ReplayDataset.from_table_signature( server_address=address, table=adders.DEFAULT_PRIORITY_TABLE, max_in_flight_samples_per_worker=max_in_flight_samples_per_worker, num_workers_per_iterator= 2, # memory perf improvment attempt https://github.com/deepmind/acme/issues/33 sequence_length=sequence_length, emit_timesteps=sequence_length is None) # demonstation dataset d_dataset = reverb.ReplayDataset.from_table_signature( server_address=address, table=demonstration_table.name, max_in_flight_samples_per_worker=max_in_flight_samples_per_worker, num_workers_per_iterator=2, sequence_length=sequence_length, emit_timesteps=sequence_length is None) dataset = tf.data.experimental.sample_from_datasets( [dataset, d_dataset], [1 - demonstration_ratio, demonstration_ratio]) # Batch and prefetch. 
dataset = dataset.batch(batch_size, drop_remainder=True) dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) tf2_utils.create_variables(network, [environment_spec.observations]) tf2_utils.create_variables(target_network, [environment_spec.observations]) learner = learning.R2D2Learner( environment_spec=environment_spec, network=network, target_network=target_network, burn_in_length=burn_in_length, dataset=dataset, reverb_client=reverb.TFClient(address), counter=counter, logger=logger, sequence_length=sequence_length, discount=discount, target_update_period=target_update_period, importance_sampling_exponent=importance_sampling_exponent, max_replay_size=max_replay_size, learning_rate=learning_rate, store_lstm_state=False, ) self._checkpointer = tf2_savers.Checkpointer( directory=model_directory, subdirectory='r2d2_learner_v1', time_delta_minutes=15, objects_to_save=learner.state, enable_checkpointing=checkpoint, ) self._snapshotter = tf2_savers.Snapshotter(objects_to_save=None, time_delta_minutes=15000., directory=model_directory) policy_network = snt.DeepRNN([ network, lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(), ]) actor = actors.RecurrentActor(policy_network, adder) observations_per_step = (float(replay_period * batch_size) / samples_per_insert) super().__init__(actor=actor, learner=learner, min_observations=replay_period * max(batch_size, min_replay_size), observations_per_step=observations_per_step)
def train_eval( root_dir, env_name, # Training params train_sequence_length, initial_collect_steps=1000, collect_steps_per_iteration=1, num_iterations=100000, # RNN params. q_network_fn=q_lstm_network, # defaults to q_lstm_network. # Agent params epsilon_greedy=0.1, batch_size=64, learning_rate=1e-3, gamma=0.99, target_update_tau=0.05, target_update_period=5, reward_scale_factor=1.0, # Replay params reverb_port=None, replay_capacity=100000, # Others policy_save_interval=1000, eval_interval=1000, eval_episodes=10): """Trains and evaluates DQN.""" collect_env = suite_gym.load(env_name) eval_env = suite_gym.load(env_name) unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = ( spec_utils.get_tensor_specs(collect_env)) train_step = train_utils.create_train_step() num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1 q_net = q_network_fn(num_actions=num_actions) sequence_length = train_sequence_length + 1 agent = dqn_agent.DqnAgent( time_step_tensor_spec, action_tensor_spec, q_network=q_net, epsilon_greedy=epsilon_greedy, # n-step updates aren't supported with RNNs yet. n_step_update=1, target_update_tau=target_update_tau, target_update_period=target_update_period, optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), td_errors_loss_fn=common.element_wise_squared_loss, gamma=gamma, reward_scale_factor=reward_scale_factor, train_step_counter=train_step) table_name = 'uniform_table' table = reverb.Table(table_name, max_size=replay_capacity, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1)) reverb_server = reverb.Server([table], port=reverb_port) reverb_replay = reverb_replay_buffer.ReverbReplayBuffer( agent.collect_data_spec, sequence_length=sequence_length, table_name=table_name, local_server=reverb_server) rb_observer = reverb_utils.ReverbAddTrajectoryObserver( reverb_replay.py_client, table_name, sequence_length=sequence_length, stride_length=1, pad_end_of_episodes=True) def experience_dataset_fn(): return reverb_replay.as_dataset(sample_batch_size=batch_size, num_steps=sequence_length) saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR) env_step_metric = py_metrics.EnvironmentSteps() learning_triggers = [ triggers.PolicySavedModelTrigger( saved_model_dir, agent, train_step, interval=policy_save_interval, metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}), triggers.StepPerSecondLogTrigger(train_step, interval=100), ] dqn_learner = learner.Learner(root_dir, train_step, agent, experience_dataset_fn, triggers=learning_triggers) # If we haven't trained yet make sure we collect some random samples first to # fill up the Replay Buffer with some experience. 
random_policy = random_py_policy.RandomPyPolicy( collect_env.time_step_spec(), collect_env.action_spec()) initial_collect_actor = actor.Actor(collect_env, random_policy, train_step, steps_per_run=initial_collect_steps, observers=[rb_observer]) logging.info('Doing initial collect.') initial_collect_actor.run() tf_collect_policy = agent.collect_policy collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy, use_tf_function=True) collect_actor = actor.Actor( collect_env, collect_policy, train_step, steps_per_run=collect_steps_per_iteration, observers=[rb_observer, env_step_metric], metrics=actor.collect_metrics(10), summary_dir=os.path.join(root_dir, learner.TRAIN_DIR), ) tf_greedy_policy = agent.policy greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_greedy_policy, use_tf_function=True) eval_actor = actor.Actor( eval_env, greedy_policy, train_step, episodes_per_run=eval_episodes, metrics=actor.eval_metrics(eval_episodes), summary_dir=os.path.join(root_dir, 'eval'), ) if eval_interval: logging.info('Evaluating.') eval_actor.run_and_log() logging.info('Training.') for _ in range(num_iterations): collect_actor.run() dqn_learner.run(iterations=1) if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0: logging.info('Evaluating.') eval_actor.run_and_log() rb_observer.close() reverb_server.stop()
def train_agent(iterations, modeldir, logdir, policydir): """Train and convert the model using TF Agents.""" # TODO: add code to instantiate the training and evaluation environments # TODO: add code to create a reinforcement learning agent that is going to be trained tf_agent.initialize() eval_policy = tf_agent.policy collect_policy = tf_agent.collect_policy tf_policy_saver = policy_saver.PolicySaver(collect_policy) # Use reverb as replay buffer replay_buffer_signature = tensor_spec.from_spec(tf_agent.collect_data_spec) table = reverb.Table( REPLAY_BUFFER_TABLE_NAME, max_size=REPLAY_BUFFER_CAPACITY, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1), signature=replay_buffer_signature, ) # specify signature here for validation at insertion time reverb_server = reverb.Server([table]) replay_buffer = reverb_replay_buffer.ReverbReplayBuffer( tf_agent.collect_data_spec, sequence_length=None, table_name=REPLAY_BUFFER_TABLE_NAME, local_server=reverb_server, ) replay_buffer_observer = reverb_utils.ReverbAddEpisodeObserver( replay_buffer.py_client, REPLAY_BUFFER_TABLE_NAME, REPLAY_BUFFER_CAPACITY) # Optimize by wrapping some of the code in a graph using TF function. tf_agent.train = common.function(tf_agent.train) # Evaluate the agent's policy once before training. avg_return = compute_avg_return_and_steps(eval_env, tf_agent.policy, NUM_EVAL_EPISODES) summary_writer = tf.summary.create_file_writer(logdir) for i in range(iterations): # TODO: add code to collect game episodes and train the agent logger = tf.get_logger() if i % EVAL_INTERVAL == 0: avg_return, avg_episode_length = compute_avg_return_and_steps( eval_env, eval_policy, NUM_EVAL_EPISODES) with summary_writer.as_default(): tf.summary.scalar("Average return", avg_return, step=i) tf.summary.scalar("Average episode length", avg_episode_length, step=i) summary_writer.flush() logger.info( "iteration = {0}: Average Return = {1}, Average Episode Length = {2}" .format(i, avg_return, avg_episode_length)) summary_writer.close() tf_policy_saver.save(policydir)
def __init__(self, environment_spec: specs.EnvironmentSpec, policy_network: snt.Module, critic_network: snt.Module, encoder_network: types.TensorTransformation = tf.identity, entropy_coeff: float = 0.01, target_update_period: int = 0, discount: float = 0.99, batch_size: int = 256, policy_learn_rate: float = 3e-4, critic_learn_rate: float = 5e-4, prefetch_size: int = 4, min_replay_size: int = 1000, max_replay_size: int = 250000, samples_per_insert: float = 64.0, n_step: int = 5, sigma: float = 0.5, clipping: bool = True, logger: loggers.Logger = None, counter: counting.Counter = None, checkpoint: bool = True, replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE): """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. policy_network: the online (optimized) policy. critic_network: the online critic. observation_network: optional network to transform the observations before they are fed into any network. discount: discount to use for TD updates. batch_size: batch size for updates. prefetch_size: size to prefetch from replay. target_update_period: number of learner steps to perform before updating the target networks. min_replay_size: minimum replay size before updating. max_replay_size: maximum replay size. samples_per_insert: number of samples to take from replay for every insert that is made. n_step: number of steps to squash into a single transition. sigma: standard deviation of zero-mean, Gaussian exploration noise. clipping: whether to clip gradients by global norm. logger: logger object to be used by learner. counter: counter object used to keep track of steps. checkpoint: boolean indicating whether to checkpoint the learner. replay_table_name: string indicating what name to give the replay table. """ # Create a replay server to add data to. This uses no limiter behavior in # order to allow the Agent interface to handle it. dim_actions = np.prod(environment_spec.actions.shape, dtype=int) extra_spec = { 'logP': tf.ones(shape=(1), dtype=tf.float32), 'policy': tf.ones(shape=(1, dim_actions), dtype=tf.float32) } # Remove batch dimensions. extra_spec = tf2_utils.squeeze_batch_dim(extra_spec) replay_table = reverb.Table( name=replay_table_name, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(1), signature=adders.NStepTransitionAdder.signature( environment_spec, extras_spec=extra_spec)) self._server = reverb.Server([replay_table], port=None) # The adder is used to insert observations into replay. address = f'localhost:{self._server.port}' adder = adders.NStepTransitionAdder( priority_fns={replay_table_name: lambda x: 1.}, client=reverb.Client(address), n_step=n_step, discount=discount) # The dataset provides an interface to sample from replay. dataset = datasets.make_reverb_dataset(table=replay_table_name, server_address=address, batch_size=batch_size, prefetch_size=prefetch_size) # Make sure observation network is a Sonnet Module. observation_network = model.MDPNormalization(environment_spec, encoder_network) # Get observation and action specs. act_spec = environment_spec.actions obs_spec = environment_spec.observations # Create the behavior policy. sampling_head = model.SquashedGaussianSamplingHead(act_spec, sigma) self._behavior_network = model.PolicyValueBehaviorNet( snt.Sequential([observation_network, policy_network]), sampling_head) # Create variables. 
emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) tf2_utils.create_variables(policy_network, [emb_spec]) tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) # Create the actor which defines how we take actions. actor = model.SACFeedForwardActor(self._behavior_network, adder) if target_update_period > 0: target_policy_network = copy.deepcopy(policy_network) target_critic_network = copy.deepcopy(critic_network) target_observation_network = copy.deepcopy(observation_network) tf2_utils.create_variables(target_policy_network, [emb_spec]) tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec]) tf2_utils.create_variables(target_observation_network, [obs_spec]) else: target_policy_network = policy_network target_critic_network = critic_network target_observation_network = observation_network # Create optimizers. policy_optimizer = snt.optimizers.Adam(learning_rate=policy_learn_rate) critic_optimizer = snt.optimizers.Adam(learning_rate=critic_learn_rate) # The learner updates the parameters (and initializes them). learner = learning.SACLearner( policy_network=policy_network, critic_network=critic_network, sampling_head=sampling_head, observation_network=observation_network, target_policy_network=target_policy_network, target_critic_network=target_critic_network, target_observation_network=target_observation_network, policy_optimizer=policy_optimizer, critic_optimizer=critic_optimizer, target_update_period=target_update_period, learning_rate=policy_learn_rate, clipping=clipping, entropy_coeff=entropy_coeff, discount=discount, dataset=dataset, counter=counter, logger=logger, checkpoint=checkpoint, ) super().__init__(actor=actor, learner=learner, min_observations=max(batch_size, min_replay_size), observations_per_step=float(batch_size) / samples_per_insert)
def __init__( self, environment_spec: specs.EnvironmentSpec, network: snt.Module, batch_size: int = 256, prefetch_size: int = 4, target_update_period: int = 100, samples_per_insert: float = 32.0, min_replay_size: int = 1000, max_replay_size: int = 1000000, importance_sampling_exponent: float = 0.2, priority_exponent: float = 0.6, n_step: int = 5, epsilon: tf.Tensor = None, learning_rate: float = 1e-3, discount: float = 0.99, cql_alpha: float = 1., logger: loggers.Logger = None, counter: counting.Counter = None, checkpoint_subpath: str = '~/acme/', ): """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. network: the online Q network (the one being optimized) batch_size: batch size for updates. prefetch_size: size to prefetch from replay. target_update_period: number of learner steps to perform before updating the target networks. samples_per_insert: number of samples to take from replay for every insert that is made. min_replay_size: minimum replay size before updating. This and all following arguments are related to dataset construction and will be ignored if a dataset argument is passed. max_replay_size: maximum replay size. importance_sampling_exponent: power to which importance weights are raised before normalizing. priority_exponent: exponent used in prioritized sampling. n_step: number of steps to squash into a single transition. epsilon: probability of taking a random action; ignored if a policy network is given. learning_rate: learning rate for the q-network update. discount: discount to use for TD updates. logger: logger object to be used by learner. checkpoint: boolean indicating whether to checkpoint the learner. checkpoint_subpath: directory for the checkpoint. """ # Create a replay server to add data to. This uses no limiter behavior in # order to allow the Agent interface to handle it. replay_table = reverb.Table( name=adders.DEFAULT_PRIORITY_TABLE, sampler=reverb.selectors.Prioritized(priority_exponent), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(1), signature=adders.NStepTransitionAdder.signature(environment_spec)) self._server = reverb.Server([replay_table], port=None) # The adder is used to insert observations into replay. address = f'localhost:{self._server.port}' adder = adders.NStepTransitionAdder(client=reverb.Client(address), n_step=n_step, discount=discount) # The dataset provides an interface to sample from replay. replay_client = reverb.TFClient(address) dataset = datasets.make_reverb_dataset( client=replay_client, environment_spec=environment_spec, batch_size=batch_size, prefetch_size=prefetch_size, transition_adder=True) # Use constant 0.05 epsilon greedy policy by default. if epsilon is None: epsilon = tf.Variable(0.05, trainable=False) policy_network = snt.Sequential([ network, lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(), ]) # Create a target network. target_network = copy.deepcopy(network) # Ensure that we create the variables before proceeding (maybe not needed). tf2_utils.create_variables(network, [environment_spec.observations]) tf2_utils.create_variables(target_network, [environment_spec.observations]) # Create the actor which defines how we take actions. actor = actors.FeedForwardActor(policy_network, adder) # The learner updates the parameters (and initializes them). 
learner = CQLLearner( network=network, discount=discount, importance_sampling_exponent=importance_sampling_exponent, learning_rate=learning_rate, cql_alpha=cql_alpha, target_update_period=target_update_period, dataset=dataset, replay_client=replay_client, logger=logger, counter=counter, checkpoint_subpath=checkpoint_subpath) super().__init__(actor=actor, learner=learner, min_observations=max(batch_size, min_replay_size), observations_per_step=float(batch_size) / samples_per_insert)
def train_eval( root_dir, env_name='CartPole-v0', # Training params initial_collect_steps=1000, num_iterations=100000, fc_layer_params=(100, ), # Agent params epsilon_greedy=0.1, batch_size=64, learning_rate=1e-3, n_step_update=1, gamma=0.99, target_update_tau=0.05, target_update_period=5, reward_scale_factor=1.0, # Replay params reverb_port=None, replay_capacity=100000, # Others policy_save_interval=1000, eval_interval=1000, eval_episodes=10): """Trains and evaluates DQN.""" collect_env = suite_gym.load(env_name) eval_env = suite_gym.load(env_name) time_step_tensor_spec = tensor_spec.from_spec(collect_env.time_step_spec()) action_tensor_spec = tensor_spec.from_spec(collect_env.action_spec()) train_step = train_utils.create_train_step() num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1 # Define a helper function to create Dense layers configured with the right # activation and kernel initializer. def dense_layer(num_units): return tf.keras.layers.Dense( num_units, activation=tf.keras.activations.relu, kernel_initializer=tf.keras.initializers.VarianceScaling( scale=2.0, mode='fan_in', distribution='truncated_normal')) # QNetwork consists of a sequence of Dense layers followed by a dense layer # with `num_actions` units to generate one q_value per available action as # it's output. dense_layers = [dense_layer(num_units) for num_units in fc_layer_params] q_values_layer = tf.keras.layers.Dense( num_actions, activation=None, kernel_initializer=tf.keras.initializers.RandomUniform(minval=-0.03, maxval=0.03), bias_initializer=tf.keras.initializers.Constant(-0.2)) q_net = sequential.Sequential(dense_layers + [q_values_layer]) agent = dqn_agent.DqnAgent( time_step_tensor_spec, action_tensor_spec, q_network=q_net, epsilon_greedy=epsilon_greedy, n_step_update=n_step_update, target_update_tau=target_update_tau, target_update_period=target_update_period, optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), td_errors_loss_fn=common.element_wise_squared_loss, gamma=gamma, reward_scale_factor=reward_scale_factor, train_step_counter=train_step) table_name = 'uniform_table' table = reverb.Table(table_name, max_size=replay_capacity, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1)) reverb_server = reverb.Server([table], port=reverb_port) reverb_replay = reverb_replay_buffer.ReverbReplayBuffer( agent.collect_data_spec, sequence_length=2, table_name=table_name, local_server=reverb_server) rb_observer = reverb_utils.ReverbAddTrajectoryObserver( reverb_replay.py_client, table_name, sequence_length=2, stride_length=1) dataset = reverb_replay.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) experience_dataset_fn = lambda: dataset saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR) env_step_metric = py_metrics.EnvironmentSteps() learning_triggers = [ triggers.PolicySavedModelTrigger( saved_model_dir, agent, train_step, interval=policy_save_interval, metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}), triggers.StepPerSecondLogTrigger(train_step, interval=100), ] dqn_learner = learner.Learner(root_dir, train_step, agent, experience_dataset_fn, triggers=learning_triggers) # If we haven't trained yet make sure we collect some random samples first to # fill up the Replay Buffer with some experience. 
random_policy = random_py_policy.RandomPyPolicy( collect_env.time_step_spec(), collect_env.action_spec()) initial_collect_actor = actor.Actor(collect_env, random_policy, train_step, steps_per_run=initial_collect_steps, observers=[rb_observer]) logging.info('Doing initial collect.') initial_collect_actor.run() tf_collect_policy = agent.collect_policy collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy, use_tf_function=True) collect_actor = actor.Actor( collect_env, collect_policy, train_step, steps_per_run=1, observers=[rb_observer, env_step_metric], metrics=actor.collect_metrics(10), summary_dir=os.path.join(root_dir, learner.TRAIN_DIR), ) tf_greedy_policy = agent.policy greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_greedy_policy, use_tf_function=True) eval_actor = actor.Actor( eval_env, greedy_policy, train_step, episodes_per_run=eval_episodes, metrics=actor.eval_metrics(eval_episodes), summary_dir=os.path.join(root_dir, 'eval'), ) if eval_interval: logging.info('Evaluating.') eval_actor.run_and_log() logging.info('Training.') for _ in range(num_iterations): collect_actor.run() dqn_learner.run(iterations=1) if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0: logging.info('Evaluating.') eval_actor.run_and_log() rb_observer.close() reverb_server.stop()
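The DQN example above gates sampling only with `MinSize(1)` and relies on the training loop to alternate collection and learning one step at a time. If collection and learning run at different speeds, Reverb's `SampleToInsertRatio` rate limiter can enforce the ratio at the table itself. The sketch below shows the same kind of uniform table with that limiter; the numbers are illustrative, not tuned values from the example.

import reverb

samples_per_insert = 32.0
min_size_to_sample = 1000
# Allow the sample/insert ratio to drift by a comfortable margin before
# blocking either side.
error_buffer = samples_per_insert * 64

table = reverb.Table(
    name='uniform_table',
    max_size=100000,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    rate_limiter=reverb.rate_limiters.SampleToInsertRatio(
        samples_per_insert=samples_per_insert,
        min_size_to_sample=min_size_to_sample,
        error_buffer=error_buffer),
)
server = reverb.Server([table], port=None)
print(f'Serving on localhost:{server.port}')
server.stop()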
def train_agent(iterations, modeldir, logdir, policydir): """Train and convert the model using TF Agents.""" train_py_env = planestrike_py_environment.PlaneStrikePyEnvironment( board_size=BOARD_SIZE, discount=DISCOUNT, max_steps=BOARD_SIZE**2) eval_py_env = planestrike_py_environment.PlaneStrikePyEnvironment( board_size=BOARD_SIZE, discount=DISCOUNT, max_steps=BOARD_SIZE**2) train_env = tf_py_environment.TFPyEnvironment(train_py_env) eval_env = tf_py_environment.TFPyEnvironment(eval_py_env) # Alternatively you could use ActorDistributionNetwork as actor_net actor_net = tfa.networks.Sequential( [ tfa.keras_layers.InnerReshape([BOARD_SIZE, BOARD_SIZE], [BOARD_SIZE**2]), tf.keras.layers.Dense(FC_LAYER_PARAMS, activation='relu'), tf.keras.layers.Dense(BOARD_SIZE**2), tf.keras.layers.Lambda( lambda t: tfp.distributions.Categorical(logits=t)), ], input_spec=train_py_env.observation_spec()) optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE) train_step_counter = tf.Variable(0) tf_agent = reinforce_agent.ReinforceAgent( train_env.time_step_spec(), train_env.action_spec(), actor_network=actor_net, optimizer=optimizer, normalize_returns=True, train_step_counter=train_step_counter) tf_agent.initialize() eval_policy = tf_agent.policy collect_policy = tf_agent.collect_policy tf_policy_saver = policy_saver.PolicySaver(collect_policy) # Use reverb as replay buffer replay_buffer_signature = tensor_spec.from_spec(tf_agent.collect_data_spec) table = reverb.Table( REPLAY_BUFFER_TABLE_NAME, max_size=REPLAY_BUFFER_CAPACITY, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1), signature=replay_buffer_signature ) # specify signature here for validation at insertion time reverb_server = reverb.Server([table]) replay_buffer = reverb_replay_buffer.ReverbReplayBuffer( tf_agent.collect_data_spec, sequence_length=None, table_name=REPLAY_BUFFER_TABLE_NAME, local_server=reverb_server) replay_buffer_observer = reverb_utils.ReverbAddEpisodeObserver( replay_buffer.py_client, REPLAY_BUFFER_TABLE_NAME, REPLAY_BUFFER_CAPACITY) # Optimize by wrapping some of the code in a graph using TF function. tf_agent.train = common.function(tf_agent.train) # Evaluate the agent's policy once before training. avg_return = compute_avg_return_and_steps(eval_env, tf_agent.policy, NUM_EVAL_EPISODES) summary_writer = tf.summary.create_file_writer(logdir) for i in range(iterations): # Collect a few episodes using collect_policy and save to the replay buffer. collect_episode(train_py_env, collect_policy, COLLECT_EPISODES_PER_ITERATION, replay_buffer_observer) # Use data from the buffer and update the agent's network. 
iterator = iter(replay_buffer.as_dataset(sample_batch_size=1)) trajectories, _ = next(iterator) tf_agent.train(experience=trajectories) replay_buffer.clear() logger = tf.get_logger() if i % EVAL_INTERVAL == 0: avg_return, avg_episode_length = compute_avg_return_and_steps( eval_env, eval_policy, NUM_EVAL_EPISODES) with summary_writer.as_default(): tf.summary.scalar('Average return', avg_return, step=i) tf.summary.scalar('Average episode length', avg_episode_length, step=i) summary_writer.flush() logger.info( 'iteration = {0}: Average Return = {1}, Average Episode Length = {2}' .format(i, avg_return, avg_episode_length)) summary_writer.close() tf_policy_saver.save(policydir) # Convert to tflite model converter = tf.lite.TFLiteConverter.from_saved_model( policydir, signature_keys=['action']) converter.target_spec.supported_ops = [ tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops. tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops. ] tflite_policy = converter.convert() with open(os.path.join(modeldir, 'planestrike_tf_agents.tflite'), 'wb') as f: f.write(tflite_policy)
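The converted `.tflite` policy can be loaded back with the TFLite interpreter to check the exported `action` signature before shipping it. The sketch below uses the file name written by `train_agent` above; the signature's input names and shapes depend on the policy's `time_step_spec`, so they are printed rather than hard-coded.

import tensorflow as tf

# Path matches the file written by train_agent above; adjust as needed.
interpreter = tf.lite.Interpreter(model_path='planestrike_tf_agents.tflite')
interpreter.allocate_tensors()

# List exported signatures and the model's input tensors instead of assuming
# their names, which vary with the policy's time_step_spec.
print(interpreter.get_signature_list())
for detail in interpreter.get_input_details():
    print(detail['name'], detail['shape'], detail['dtype'])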
def __init__( self, environment_spec: specs.EnvironmentSpec, network: hk.Transformed, batch_size: int = 256, prefetch_size: int = 4, target_update_period: int = 100, samples_per_insert: float = 32.0, min_replay_size: int = 1000, max_replay_size: int = 1000000, importance_sampling_exponent: float = 0.2, priority_exponent: float = 0.6, n_step: int = 5, epsilon: float = 0., learning_rate: float = 1e-3, discount: float = 0.99, seed: int = 1, ): """Initialize the agent.""" # Create a replay server to add data to. This uses no limiter behavior in # order to allow the Agent interface to handle it. replay_table = reverb.Table( name=adders.DEFAULT_PRIORITY_TABLE, sampler=reverb.selectors.Prioritized(priority_exponent), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(1), signature=adders.NStepTransitionAdder.signature( environment_spec=environment_spec)) self._server = reverb.Server([replay_table], port=None) # The adder is used to insert observations into replay. address = f'localhost:{self._server.port}' adder = adders.NStepTransitionAdder(client=reverb.Client(address), n_step=n_step, discount=discount) # The dataset provides an interface to sample from replay. dataset = datasets.make_reverb_dataset( server_address=address, environment_spec=environment_spec, batch_size=batch_size, prefetch_size=prefetch_size, transition_adder=True) def policy(params: hk.Params, key: jnp.ndarray, observation: jnp.ndarray) -> jnp.ndarray: action_values = network.apply(params, observation) return rlax.epsilon_greedy(epsilon).sample(key, action_values) # The learner updates the parameters (and initializes them). learner = learning.DQNLearner( network=network, obs_spec=environment_spec.observations, rng=hk.PRNGSequence(seed), optimizer=optax.adam(learning_rate), discount=discount, importance_sampling_exponent=importance_sampling_exponent, target_update_period=target_update_period, iterator=dataset.as_numpy_iterator(), replay_client=reverb.Client(address), ) variable_client = variable_utils.VariableClient(learner, '') actor = actors.FeedForwardActor(policy=policy, rng=hk.PRNGSequence(seed), variable_client=variable_client, adder=adder) super().__init__(actor=actor, learner=learner, min_observations=max(batch_size, min_replay_size), observations_per_step=float(batch_size) / samples_per_insert)
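The prioritized table used above is easiest to see in isolation with the raw Reverb client: items are inserted with explicit priorities, sampled with probability proportional to `priority ** priority_exponent`, and re-prioritized after a learner step. The snippet below is a standalone sketch with dummy data; it assumes the single-timestep insert path, and the sample/info field names follow Reverb's Python client.

import numpy as np
import reverb

table_name = 'priority_table'
server = reverb.Server(tables=[
    reverb.Table(
        name=table_name,
        sampler=reverb.selectors.Prioritized(0.6),
        remover=reverb.selectors.Fifo(),
        max_size=1000,
        rate_limiter=reverb.rate_limiters.MinSize(1)),
])
client = reverb.Client(f'localhost:{server.port}')

# Insert a handful of dummy items with increasing priority.
for i in range(5):
    client.insert(np.array([i], dtype=np.float32),
                  priorities={table_name: float(i + 1)})

# Sample one item, then update its priority the way a learner would after
# computing new TD errors.
for sample in client.sample(table_name, num_samples=1):
    info = sample[0].info
    print('sampled key', info.key, 'with probability', info.probability)
    client.mutate_priorities(table_name, updates={info.key: 0.5})

server.stop()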
def __init__( self, environment_spec: specs.EnvironmentSpec, builder: builders.ActorLearnerBuilder, networks: Any, policy_network: Any, min_replay_size: int = 1000, samples_per_insert: float = 256.0, batch_size: int = 256, num_sgd_steps_per_step: int = 1, counter: counting.Counter = None, logger: loggers.Logger = None, checkpoint: bool = True, ): """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. builder: builder defining an RL algorithm to train. networks: network objects to be passed to the learner. policy_network: function that given an observation returns actions. min_replay_size: minimum replay size before updating. samples_per_insert: number of samples to take from replay for every insert that is made. batch_size: batch size for updates. num_sgd_steps_per_step: how many sgd steps a learner does per 'step' call. For performance reasons (especially to reduce TPU host-device transfer times) it is performance-beneficial to do multiple sgd updates at once, provided that it does not hurt the training, which needs to be verified empirically for each environment. counter: counter object used to keep track of steps. logger: logger object to be used by learner. checkpoint: boolean indicating whether to checkpoint the learner. """ # Create the replay server and grab its address. replay_tables = builder.make_replay_tables(environment_spec) replay_server = reverb.Server(replay_tables, port=None) replay_client = reverb.Client(f'localhost:{replay_server.port}') # Create actor, dataset, and learner for generating, storing, and consuming # data respectively. adder = builder.make_adder(replay_client) dataset = builder.make_dataset_iterator(replay_client) learner = builder.make_learner(networks=networks, dataset=dataset, replay_client=replay_client, counter=counter, logger=logger, checkpoint=checkpoint) actor = builder.make_actor(policy_network, adder, variable_source=learner) effective_batch_size = batch_size * num_sgd_steps_per_step super().__init__(actor=actor, learner=learner, min_observations=max(effective_batch_size, min_replay_size), observations_per_step=float(effective_batch_size) / samples_per_insert) # Save the replay so we don't garbage collect it. self._replay_server = replay_server
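A quick worked example of the effective-batch-size arithmetic above (the values are illustrative): with `num_sgd_steps_per_step > 1` the learner consumes several minibatches per `step` call, so the base class schedules learner calls less often while each call performs proportionally more SGD updates.

batch_size = 256
num_sgd_steps_per_step = 4
samples_per_insert = 256.0
min_replay_size = 1000

effective_batch_size = batch_size * num_sgd_steps_per_step          # 1024
min_observations = max(effective_batch_size, min_replay_size)        # 1024
observations_per_step = effective_batch_size / samples_per_insert    # 4.0

print(min_observations, observations_per_step)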
return avg_return.numpy()[0] # Standard implementations for evaluation metrics in the metrics module. table_name = 'uniform_table' replay_buffer_signature = tensor_spec.from_spec(tf_agent.collect_data_spec) replay_buffer_signature = tensor_spec.add_outer_dim(replay_buffer_signature) table = reverb.Table(table_name, max_size=replay_buffer_capacity, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1), signature=replay_buffer_signature) reverb_server = reverb.Server([table]) replay_buffer = reverb_replay_buffer.ReverbReplayBuffer( tf_agent.collect_data_spec, table_name=table_name, sequence_length=None, local_server=reverb_server) rb_observer = reverb_utils.ReverbAddEpisodeObserver(replay_buffer.py_client, table_name, replay_buffer_capacity) def collect_episode(environment, policy, num_episodes): driver = py_driver.PyDriver(environment,
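The `add_outer_dim` call above is needed because whole, variable-length episodes are stored as single table items, so every spec in the signature gains a leading time axis. The standalone sketch below shows the resulting table signature with stand-in specs (not the agent's real `collect_data_spec`).

import tensorflow as tf
import reverb

step_spec = {
    'observation': tf.TensorSpec([4], tf.float32),
    'action': tf.TensorSpec([], tf.int32),
    'reward': tf.TensorSpec([], tf.float32),
}
# Prepend an unknown-length time axis to every leaf, mirroring what
# tensor_spec.add_outer_dim does for the episode table.
episode_signature = tf.nest.map_structure(
    lambda spec: tf.TensorSpec([None] + spec.shape.as_list(), spec.dtype),
    step_spec)

table = reverb.Table(
    name='uniform_table',
    max_size=1000,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=episode_signature)
server = reverb.Server([table], port=None)
print(episode_signature)
server.stop()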
discount: float = 1., replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, prefetch_size: int = 4, ) -> ReverbReplay: """Creates a single-process replay infrastructure from an environment spec.""" # Create a replay server to add data to. This uses no limiter behavior in # order to allow the Agent interface to handle it. replay_table = reverb.Table( name=replay_table_name, sampler=reverb.selectors.Prioritized(priority_exponent), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(min_replay_size), signature=adders.NStepTransitionAdder.signature( environment_spec=environment_spec)) server = reverb.Server([replay_table], port=None) # The adder is used to insert observations into replay. address = f'localhost:{server.port}' client = reverb.Client(address) adder = adders.NStepTransitionAdder(client=client, n_step=n_step, discount=discount) # The dataset provides an interface to sample from replay. data_iterator = datasets.make_reverb_dataset( table=replay_table_name, server_address=address, batch_size=batch_size, prefetch_size=prefetch_size, environment_spec=environment_spec,
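The `NStepTransitionAdder` above compresses `n_step` consecutive environment steps into one transition whose reward is the discounted sum of the intermediate rewards and whose discount is the accumulated bootstrap factor. The helper below is a library-free sketch of just that arithmetic; the real adder additionally handles episode boundaries, padding, and priorities.

def n_step_transition(rewards, discounts, gamma):
    """Collapses n steps into (n_step_return, bootstrap_discount).

    rewards[i] and discounts[i] are the reward and environment discount seen
    on step i; gamma is the agent's discount factor.
    """
    n_step_return = 0.0
    cumulative_discount = 1.0
    for reward, env_discount in zip(rewards, discounts):
        n_step_return += cumulative_discount * reward
        cumulative_discount *= gamma * env_discount
    return n_step_return, cumulative_discount


# Example: 3 steps, no termination, gamma = 0.99.
print(n_step_transition([1.0, 1.0, 1.0], [1.0, 1.0, 1.0], gamma=0.99))
# -> (1.0 + 0.99 + 0.99**2, 0.99**3)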
def train_eval( root_dir, env_name='HalfCheetah-v2', # Training params initial_collect_steps=10000, num_iterations=3200000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Agent params batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, gamma=0.99, target_update_tau=0.005, target_update_period=1, reward_scale_factor=0.1, # Replay params reverb_port=None, replay_capacity=1000000, # Others # Defaults to not checkpointing saved policy. If you wish to enable this, # please note the caveat explained in README.md. policy_save_interval=-1, eval_interval=10000, eval_episodes=30, debug_summaries=False, summarize_grads_and_vars=False): """Trains and evaluates SAC.""" logging.info('Training SAC on: %s', env_name) collect_env = suite_mujoco.load(env_name) eval_env = suite_mujoco.load(env_name) observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = ( spec_utils.get_tensor_specs(collect_env)) train_step = train_utils.create_train_step() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_tensor_spec, action_tensor_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=tanh_normal_projection_network. TanhNormalProjectionNetwork) critic_net = critic_network.CriticNetwork( (observation_tensor_spec, action_tensor_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, kernel_initializer='glorot_uniform', last_kernel_initializer='glorot_uniform') agent = sac_agent.SacAgent( time_step_tensor_spec, action_tensor_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=tf.math.squared_difference, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=None, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=train_step) agent.initialize() table_name = 'uniform_table' table = reverb.Table(table_name, max_size=replay_capacity, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1)) reverb_server = reverb.Server([table], port=reverb_port) reverb_replay = reverb_replay_buffer.ReverbReplayBuffer( agent.collect_data_spec, sequence_length=2, table_name=table_name, local_server=reverb_server) rb_observer = reverb_utils.ReverbAddTrajectoryObserver( reverb_replay.py_client, table_name, sequence_length=2, stride_length=1) dataset = reverb_replay.as_dataset(sample_batch_size=batch_size, num_steps=2).prefetch(50) experience_dataset_fn = lambda: dataset saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR) env_step_metric = py_metrics.EnvironmentSteps() learning_triggers = [ triggers.PolicySavedModelTrigger( saved_model_dir, agent, train_step, interval=policy_save_interval, metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}), triggers.StepPerSecondLogTrigger(train_step, interval=1000), ] agent_learner = learner.Learner(root_dir, train_step, agent, experience_dataset_fn, triggers=learning_triggers) random_policy = random_py_policy.RandomPyPolicy( 
collect_env.time_step_spec(), collect_env.action_spec()) initial_collect_actor = actor.Actor(collect_env, random_policy, train_step, steps_per_run=initial_collect_steps, observers=[rb_observer]) logging.info('Doing initial collect.') initial_collect_actor.run() tf_collect_policy = agent.collect_policy collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy, use_tf_function=True) collect_actor = actor.Actor(collect_env, collect_policy, train_step, steps_per_run=1, metrics=actor.collect_metrics(10), summary_dir=os.path.join( root_dir, learner.TRAIN_DIR), observers=[rb_observer, env_step_metric]) tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy) eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy( tf_greedy_policy, use_tf_function=True) eval_actor = actor.Actor( eval_env, eval_greedy_policy, train_step, episodes_per_run=eval_episodes, metrics=actor.eval_metrics(eval_episodes), summary_dir=os.path.join(root_dir, 'eval'), ) if eval_interval: logging.info('Evaluating.') eval_actor.run_and_log() logging.info('Training.') for _ in range(num_iterations): collect_actor.run() agent_learner.run(iterations=1) if eval_interval and agent_learner.train_step_numpy % eval_interval == 0: logging.info('Evaluating.') eval_actor.run_and_log() rb_observer.close() reverb_server.stop()
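While the SAC loop above runs, the Reverb client can be polled to confirm that collection keeps up with sampling: `server_info()` returns per-table metadata such as the current number of items. The sketch below assumes access to the `reverb_server` created in `train_eval` (or at least its address), and the exact metadata fields may differ between Reverb versions.

import reverb

# `reverb_server` is the server created inside train_eval above; outside that
# scope, substitute the known 'localhost:<port>' address.
client = reverb.Client(f'localhost:{reverb_server.port}')
table_info = client.server_info()['uniform_table']
print('items currently stored:', table_info.current_size)
print('table capacity:        ', table_info.max_size)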
def __init__( self, environment_spec: specs.EnvironmentSpec, policy_network: snt.Module, critic_network: snt.Module, discount: float = 0.99, batch_size: int = 256, prefetch_size: int = 4, target_update_period: int = 100, prior_network: Optional[snt.Module] = None, policy_optimizer: Optional[snt.Optimizer] = None, critic_optimizer: Optional[snt.Optimizer] = None, prior_optimizer: Optional[snt.Optimizer] = None, distillation_cost: Optional[float] = 1e-3, entropy_regularizer_cost: Optional[float] = 1e-3, min_replay_size: int = 1000, max_replay_size: int = 1000000, samples_per_insert: float = 32.0, sequence_length: int = 10, sigma: float = 0.3, replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE, counter: Optional[counting.Counter] = None, logger: Optional[loggers.Logger] = None, checkpoint: bool = True, ): """Initialize the agent. Args: environment_spec: description of the actions, observations, etc. policy_network: the online (optimized) policy. critic_network: the online critic. discount: discount to use for TD updates. batch_size: batch size for updates. prefetch_size: size to prefetch from replay. target_update_period: number of learner steps to perform before updating the target networks. prior_network: an optional `behavior prior` to regularize against. policy_optimizer: optimizer for the policy network updates. critic_optimizer: optimizer for the critic network updates. prior_optimizer: optimizer for the prior network updates. distillation_cost: a multiplier to be used when adding distillation against the prior to the losses. entropy_regularizer_cost: a multiplier used for per state sample based entropy added to the actor loss. min_replay_size: minimum replay size before updating. max_replay_size: maximum replay size. samples_per_insert: number of samples to take from replay for every insert that is made. sequence_length: number of timesteps to store for each trajectory. sigma: standard deviation of zero-mean, Gaussian exploration noise. replay_table_name: string indicating what name to give the replay table. counter: counter object used to keep track of steps. logger: logger object to be used by learner. checkpoint: boolean indicating whether to checkpoint the learner. """ # Create the Builder object which will internally create agent components. builder = SVG0Builder( # TODO(mwhoffman): pass the config dataclass in directly. # TODO(mwhoffman): use the limiter rather than the workaround below. # Right now this modifies min_replay_size and samples_per_insert so that # they are not controlled by a limiter and are instead handled by the # Agent base class (the above TODO directly references this behavior). SVG0Config( discount=discount, batch_size=batch_size, prefetch_size=prefetch_size, target_update_period=target_update_period, policy_optimizer=policy_optimizer, critic_optimizer=critic_optimizer, prior_optimizer=prior_optimizer, distillation_cost=distillation_cost, entropy_regularizer_cost=entropy_regularizer_cost, min_replay_size=1, # Let the Agent class handle this. max_replay_size=max_replay_size, samples_per_insert=None, # Let the Agent class handle this. sequence_length=sequence_length, sigma=sigma, replay_table_name=replay_table_name, )) # TODO(mwhoffman): pass the network dataclass in directly. online_networks = SVG0Networks( policy_network=policy_network, critic_network=critic_network, prior_network=prior_network, ) # Target networks are just a copy of the online networks. target_networks = copy.deepcopy(online_networks) # Initialize the networks. 
online_networks.init(environment_spec) target_networks.init(environment_spec) # TODO(mwhoffman): either make this Dataclass or pass only one struct. # The network struct passed to make_learner is just a tuple for the # time-being (for backwards compatibility). networks = (online_networks, target_networks) # Create the behavior policy. policy_network = online_networks.make_policy() # Create the replay server and grab its address. replay_tables = builder.make_replay_tables(environment_spec, sequence_length) replay_server = reverb.Server(replay_tables, port=None) replay_client = reverb.Client(f'localhost:{replay_server.port}') # Create actor, dataset, and learner for generating, storing, and consuming # data respectively. adder = builder.make_adder(replay_client) actor = builder.make_actor(policy_network, adder) dataset = builder.make_dataset_iterator(replay_client) learner = builder.make_learner(networks, dataset, counter, logger, checkpoint) super().__init__(actor=actor, learner=learner, min_observations=max(batch_size, min_replay_size), observations_per_step=float(batch_size) / samples_per_insert) # Save the replay so we don't garbage collect it. self._replay_server = replay_server
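The `copy.deepcopy` of the online networks above is the usual target-network pattern: the target is a structural copy whose weights are refreshed from the online network every `target_update_period` learner steps. Below is a minimal sketch of that update rule, with Keras layers standing in for the Sonnet modules; Acme's learners perform this bookkeeping internally.

import tensorflow as tf

target_update_period = 100

online = tf.keras.layers.Dense(4)
target = tf.keras.layers.Dense(4)
online.build((None, 8))
target.build((None, 8))


def maybe_update_target(step: int) -> None:
    # Hard update: copy the online variables into the target copy.
    if step % target_update_period == 0:
        for src, dst in zip(online.variables, target.variables):
            dst.assign(src)


for step in range(1, 301):
    # ... a learner update of `online` would happen here ...
    maybe_update_target(step)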