def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    replay_capacity: int,
    batch_size: int,
    hidden_sizes: Tuple[int, ...],
    learning_rate: float = 1e-3,
    terminal_tol: float = 1e-3,
):
  """Builds the replay buffer, the transition model and its optimizer.

  Args:
    environment_spec: spec whose `observations` and `actions` entries are
      stored and used to create the model variables.
    replay_capacity: capacity handed to `replay.Replay`.
    batch_size: stored as `self._batch_size`.
    hidden_sizes: hidden layer sizes forwarded to `MLPTransitionModel`.
    learning_rate: learning rate of the Adam optimizer.
    terminal_tol: stored as `self._terminal_tol`.
  """
  observation_spec = environment_spec.observations
  action_spec = environment_spec.actions
  self._obs_spec = observation_spec
  self._action_spec = action_spec

  # Hyperparameters.
  self._batch_size = batch_size
  self._terminal_tol = terminal_tol

  # Modelling: replay storage, the learned model and what optimizes it.
  self._replay = replay.Replay(replay_capacity)
  self._transition_model = MLPTransitionModel(environment_spec, hidden_sizes)
  self._optimizer = snt.optimizers.Adam(learning_rate)
  self._forward = tf.function(self._transition_model)

  # Variables have to exist before `trainable_variables` can be captured.
  tf2_utils.create_variables(
      self._transition_model, [observation_spec, action_spec])
  self._variables = self._transition_model.trainable_variables

  # Model state.
  self._needs_reset = True
def test_snapshot_distribution(self):
  """Checks that a saved distribution network restores with equal gradients."""

  def _loss_gradients(network, inputs):
    # Loss is the summed mean and variance of the output distribution.
    with tf.GradientTape() as tape:
      distribution = network(inputs)
      loss = tf.math.reduce_sum(
          distribution.mean() + distribution.variance())
    return tape.gradient(loss, network.trainable_variables)

  # Build the network under test and create its variables.
  spec = specs.Array([10], dtype=np.float32)
  net1 = snt.Sequential([
      networks.LayerNormMLP([10, 10]),
      networks.MultivariateNormalDiagHead(1),
  ])
  tf2_utils.create_variables(net1, [spec])

  # Snapshot it to a temporary directory.
  snapshotter = tf2_savers.Snapshotter(
      {'net': net1}, directory=self.get_tempdir())
  snapshotter.save()

  # Load the snapshot back as a separate object.
  net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))

  inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))
  grads1 = _loss_gradients(net1, inputs)
  grads2 = _loss_gradients(net2, inputs)

  assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))
def test_update(self):
  """Checks that updating the client copies the source variables over."""
  # Two independently initialised copies of the same architecture.
  actor_model = snt.nets.MLP([50, 30])
  learner_model = snt.nets.MLP([50, 30])

  # Variables must exist before they can be registered anywhere.
  input_spec = tf.TensorSpec(shape=(28,), dtype=tf.float32)
  for model in (actor_model, learner_model):
    tf2_utils.create_variables(model, [input_spec])

  # The learner's variables (as numpy) are the source; the actor's are the
  # client-side variables that get overwritten.
  source = fakes.VariableSource(
      [tf2_utils.to_numpy(v) for v in learner_model.variables])
  client = tf2_variable_utils.VariableClient(
      source, {'policy': actor_model.variables})

  # A random batch of inputs to compare the two models on.
  x = tf.random.normal(shape=(8, 28))

  # Before the update the outputs are expected to differ.
  before_actor = actor_model(x).numpy()
  before_learner = learner_model(x).numpy()
  self.assertFalse(np.allclose(before_actor, before_learner))

  client.update_and_wait()

  # After the update both models must agree.
  after_actor = actor_model(x).numpy()
  after_learner = learner_model(x).numpy()
  self.assertTrue(np.allclose(after_actor, after_learner))
def main(_):
  """Trains a behaviour-cloning learner on bsuite demonstrations forever.

  Alternates between `FLAGS.evaluate_every` learner steps and
  `FLAGS.evaluation_episodes` evaluation episodes; the loop never exits.
  """
  # Environment and its spec.
  raw_environment = bsuite.load_from_id(FLAGS.bsuite_id)
  environment = single_precision.SinglePrecisionWrapper(raw_environment)
  environment_spec = specs.make_environment_spec(environment)

  # The demonstration dataset is built from the innermost environment when
  # the loaded one exposes a `raw_env` attribute.
  if hasattr(raw_environment, 'raw_env'):
    raw_environment = raw_environment.raw_env
  episodes = bsuite_demonstrations.make_dataset(raw_environment)

  # Map every demonstration through the 1-step transition helper, then batch
  # and prefetch.
  to_transition = functools.partial(
      _n_step_transition_from_episode, n_step=1, additional_discount=1.)
  dataset = episodes.map(to_transition)
  dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

  # Network being optimized.
  policy_network = make_policy_network(environment_spec.actions)

  def _sample_action(q):
    # With epsilon=0 this reduces to a greedy policy.
    return trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample()

  evaluator_network = snt.Sequential([policy_network, _sample_action])

  # Ensure that we create the variables before proceeding (maybe not needed).
  tf2_utils.create_variables(policy_network, [environment_spec.observations])

  counter = counting.Counter()
  learner_counter = counting.Counter(counter, prefix='learner')

  # Evaluation actor and loop.
  eval_actor = actors_tf2.FeedForwardActor(evaluator_network)
  eval_logger = loggers.TerminalLogger('evaluation', time_delta=1.)
  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=eval_actor,
      counter=counter,
      logger=eval_logger)

  # The learner updates the parameters (and initializes them).
  learner = learning.BCLearner(
      network=policy_network,
      learning_rate=FLAGS.learning_rate,
      dataset=dataset,
      counter=learner_counter)

  # Train / evaluate forever.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    learner_counter.increment(learner_steps=FLAGS.evaluate_every)
    eval_loop.run(FLAGS.evaluation_episodes)
def test_feedforward(self, recurrent: bool):
  """Checks the variable shapes created for a (possibly wrapped) Linear."""
  linear = snt.Linear(42)
  model = snt.DeepRNN([linear]) if recurrent else linear

  input_spec = specs.Array(shape=(10,), dtype=np.float32)
  tf2_utils.create_variables(model, [input_spec])

  variables: Sequence[tf.Variable] = model.variables
  actual_shapes = [variable.shape.as_list() for variable in variables]
  self.assertSequenceEqual(actual_shapes, [[42], [10, 42]])
def test_none_output(self):
  """A module returning None has no variables and a None output spec."""
  input_spec = specs.Array(shape=(10,), dtype=np.float32)
  model = tf2_utils.to_sonnet_module(lambda x: None)

  output_spec = tf2_utils.create_variables(model, [input_spec])

  self.assertEqual(model.variables, ())
  self.assertEqual(output_spec, None)
def test_scalar_output(self):
  """A module wrapping `tf.reduce_sum` yields a scalar float32 output spec."""
  input_spec = specs.Array(shape=(10,), dtype=np.float32)
  model = tf2_utils.to_sonnet_module(tf.reduce_sum)

  output_spec = tf2_utils.create_variables(model, [input_spec])

  self.assertEqual(model.variables, ())
  self.assertEqual(output_spec, tf.TensorSpec(shape=(), dtype=tf.float32))
def test_output_spec_feedforward(self, recurrent: bool):
  """Checks the output spec returned by `create_variables`."""
  input_spec = specs.Array(shape=(10,), dtype=np.float32)
  linear = snt.Linear(42)
  linear_spec = tf.TensorSpec(shape=(42,), dtype=tf.float32)

  if recurrent:
    # The recurrent wrapper is expected to pair the output spec with an
    # empty state.
    model = snt.DeepRNN([linear])
    expected_spec = (linear_spec, ())
  else:
    model = linear
    expected_spec = linear_spec

  output_spec = tf2_utils.create_variables(model, [input_spec])
  self.assertEqual(output_spec, expected_spec)
def test_multiple_ouputs(self):
  """Checks variable shapes and output specs of a two-headed module."""
  input_spec = specs.Array(shape=(10,), dtype=np.float32)
  model = PolicyValueHead(42)

  output_spec = tf2_utils.create_variables(model, [input_spec])

  variables: Sequence[tf.Variable] = model.variables
  actual_shapes = [variable.shape.as_list() for variable in variables]
  self.assertSequenceEqual(actual_shapes, [[42], [10, 42], [1], [10, 1]])

  expected_spec = (
      tf.TensorSpec(shape=(42,), dtype=tf.float32),
      tf.TensorSpec(shape=(), dtype=tf.float32),
  )
  self.assertSequenceEqual(output_spec, expected_spec)
def test_force_snapshot(self):
  """Checks that `Snapshotter.save(force=True)` bypasses the time delta."""
  # Network to snapshot.
  spec = specs.Array([10], dtype=np.float32)
  net = snt.Linear(10)
  tf2_utils.create_variables(net, [spec])

  # A very long time_delta_minutes, so that only the first save and forced
  # saves are expected to succeed within the test.
  snapshotter = tf2_savers.Snapshotter(
      {'net': net},
      directory=self.get_tempdir(),
      time_delta_minutes=1000)

  # First save goes through even without forcing.
  self.assertTrue(snapshotter.save(force=False))

  # Calling without arguments also checks that the default is force=False.
  self.assertFalse(snapshotter.save())
  self.assertTrue(snapshotter.save(force=True))
def test_rnn_snapshot(self):
  """Checks that RNN snapshots reproduce outputs, states and gradients."""

  def _unroll_once(core, inputs):
    # One step from the initial state; the loss is the summed output.
    with tf.GradientTape() as tape:
      outputs, next_state = core(inputs, core.initial_state(1))
      loss = tf.math.reduce_sum(outputs)
    grads = tape.gradient(loss, core.trainable_variables)
    return outputs, next_state, grads

  # Network under test.
  spec = specs.Array([10], dtype=np.float32)
  lstm = snt.LSTM(10)
  tf2_utils.create_variables(lstm, [spec])

  # Postprocessing added afterwards, without rerunning create_variables,
  # must still work.
  wrapped_lstm = snt.DeepRNN([lstm, lambda x: x])

  for net1 in (lstm, wrapped_lstm):
    # Snapshot the network.
    snapshotter = tf2_savers.Snapshotter(
        {'net': net1}, directory=self.get_tempdir())
    snapshotter.save()

    # Load it back as a separate object.
    net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))

    inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))
    outputs1, next_state1, grads1 = _unroll_once(net1, inputs)
    outputs2, next_state2, grads2 = _unroll_once(net2, inputs)

    assert np.allclose(outputs1, outputs2)
    assert np.allclose(tree.flatten(next_state1), tree.flatten(next_state2))
    assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))
def __init__(self,
             environment_spec: specs.EnvironmentSpec,
             policy_network: snt.Module,
             critic_network: snt.Module,
             observation_network: types.TensorTransformation = tf.identity,
             discount: float = 0.99,
             batch_size: int = 256,
             prefetch_size: int = 4,
             target_policy_update_period: int = 100,
             target_critic_update_period: int = 100,
             min_replay_size: int = 1000,
             max_replay_size: int = 1000000,
             samples_per_insert: float = 32.0,
             policy_loss_module: snt.Module = None,
             policy_optimizer: snt.Optimizer = None,
             critic_optimizer: snt.Optimizer = None,
             n_step: int = 5,
             num_samples: int = 20,
             clipping: bool = True,
             logger: loggers.Logger = None,
             counter: counting.Counter = None,
             checkpoint: bool = True,
             replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
  """Initialize the agent.

  Args:
    environment_spec: description of the actions, observations, etc.
    policy_network: the online (optimized) policy.
    critic_network: the online critic.
    observation_network: optional network to transform the observations before
      they are fed into any network.
    discount: discount to use for TD updates.
    batch_size: batch size for updates.
    prefetch_size: size to prefetch from replay.
    target_policy_update_period: number of updates to perform before updating
      the target policy network.
    target_critic_update_period: number of updates to perform before updating
      the target critic network.
    min_replay_size: minimum replay size before updating.
    max_replay_size: maximum replay size.
    samples_per_insert: number of samples to take from replay for every insert
      that is made.
    policy_loss_module: configured MPO loss function for the policy
      optimization; defaults to sensible values on the control suite.
      See `acme/tf/losses/mpo.py` for more details.
    policy_optimizer: optimizer to be used on the policy; Adam with learning
      rate 1e-4 when not given.
    critic_optimizer: optimizer to be used on the critic; Adam with learning
      rate 1e-4 when not given.
    n_step: number of steps to squash into a single transition.
    num_samples: number of actions to sample when doing a Monte Carlo
      integration with respect to the policy.
    clipping: whether to clip gradients by global norm.
    logger: logging object used to write to logs.
    counter: counter object used to keep track of steps.
    checkpoint: boolean indicating whether to checkpoint the learner.
    replay_table_name: string indicating what name to give the replay table.
  """
  # Create a replay server to add data to. The table must carry the requested
  # name: the dataset below samples from `replay_table_name`, so creating the
  # table under a different (default) name would leave it with nothing to
  # read whenever a custom name is supplied.
  replay_table = reverb.Table(
      name=replay_table_name,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1))
  self._server = reverb.Server([replay_table], port=None)

  # The adder is used to insert observations into replay. The priority
  # function is keyed by the table name so that inserts target the same table
  # the dataset samples from.
  address = f'localhost:{self._server.port}'
  adder = adders.NStepTransitionAdder(
      priority_fns={replay_table_name: lambda x: 1.},
      client=reverb.Client(address),
      n_step=n_step,
      discount=discount)

  # The dataset object to learn from.
  dataset = datasets.make_reverb_dataset(
      table=replay_table_name,
      client=reverb.TFClient(address),
      batch_size=batch_size,
      prefetch_size=prefetch_size,
      environment_spec=environment_spec,
      transition_adder=True)

  # Make sure observation network is a Sonnet Module.
  observation_network = tf2_utils.to_sonnet_module(observation_network)

  # Create target networks before creating online/target network variables.
  target_policy_network = copy.deepcopy(policy_network)
  target_critic_network = copy.deepcopy(critic_network)
  target_observation_network = copy.deepcopy(observation_network)

  # Get observation and action specs.
  act_spec = environment_spec.actions
  obs_spec = environment_spec.observations
  emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

  # Create the behavior policy.
  behavior_network = snt.Sequential([
      observation_network,
      policy_network,
      networks.StochasticSamplingHead(),
  ])

  # Create variables.
  tf2_utils.create_variables(policy_network, [emb_spec])
  tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
  tf2_utils.create_variables(target_policy_network, [emb_spec])
  tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
  tf2_utils.create_variables(target_observation_network, [obs_spec])

  # Create the actor which defines how we take actions.
  actor = actors.FeedForwardActor(policy_network=behavior_network,
                                  adder=adder)

  # Create optimizers.
  policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
  critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)

  # The learner updates the parameters (and initializes them).
  learner = learning.DistributionalMPOLearner(
      policy_network=policy_network,
      critic_network=critic_network,
      observation_network=observation_network,
      target_policy_network=target_policy_network,
      target_critic_network=target_critic_network,
      target_observation_network=target_observation_network,
      policy_loss_module=policy_loss_module,
      policy_optimizer=policy_optimizer,
      critic_optimizer=critic_optimizer,
      clipping=clipping,
      discount=discount,
      num_samples=num_samples,
      target_policy_update_period=target_policy_update_period,
      target_critic_update_period=target_critic_update_period,
      dataset=dataset,
      logger=logger,
      counter=counter,
      checkpoint=checkpoint)

  super().__init__(actor=actor,
                   learner=learner,
                   min_observations=max(batch_size, min_replay_size),
                   observations_per_step=float(batch_size) /
                   samples_per_insert)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    burn_in_length: int,
    trace_length: int,
    replay_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    batch_size: int = 32,
    prefetch_size: int = tf.data.experimental.AUTOTUNE,
    target_update_period: int = 100,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    epsilon: float = 0.01,
    learning_rate: float = 1e-3,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    samples_per_insert: float = 32.0,
    store_lstm_state: bool = True,
    max_priority_weight: float = 0.9,
    checkpoint: bool = True,
):
  """Wires up replay, the recurrent actor and the R2D2 learner.

  Sequences of length `burn_in_length + trace_length + 1` are written to a
  prioritized Reverb table every `replay_period` steps and sampled back in
  batches for the learner.

  Args:
    environment_spec: description of the actions, observations, etc.
    network: the online recurrent network being optimized.
    burn_in_length: burn-in length forwarded to the learner.
    trace_length: contributes to the stored sequence length.
    replay_period: period of the sequence adder.
    counter: counter forwarded to the learner.
    logger: logger forwarded to the learner.
    discount: discount forwarded to the learner.
    batch_size: batch size of the replay dataset.
    prefetch_size: prefetch size of the replay dataset.
    target_update_period: forwarded to the learner.
    importance_sampling_exponent: forwarded to the learner.
    priority_exponent: exponent of the prioritized sampler.
    epsilon: epsilon of the epsilon-greedy acting policy.
    learning_rate: forwarded to the learner.
    min_replay_size: with `batch_size` and `replay_period`, determines the
      `min_observations` value passed to the base class.
    max_replay_size: maximum size of the replay table.
    samples_per_insert: used to derive `observations_per_step`.
    store_lstm_state: forwarded to the learner.
    max_priority_weight: forwarded to the learner.
    checkpoint: whether the checkpointer is enabled.
  """
  sequence_length = burn_in_length + trace_length + 1

  # Replay server with a single prioritized table.
  table = reverb.Table(
      name=adders.DEFAULT_PRIORITY_TABLE,
      sampler=reverb.selectors.Prioritized(priority_exponent),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1))
  self._server = reverb.Server([table], port=None)
  address = f'localhost:{self._server.port}'

  # Component to add things into replay.
  adder = adders.SequenceAdder(
      client=reverb.Client(address),
      period=replay_period,
      sequence_length=sequence_length,
  )

  # The dataset object to learn from. The extras carry the recurrent core
  # state, with its batch dimension removed.
  reverb_client = reverb.TFClient(address)
  extra_spec = tf2_utils.squeeze_batch_dim({
      'core_state': network.initial_state(1),
  })
  dataset = datasets.make_reverb_dataset(
      client=reverb_client,
      environment_spec=environment_spec,
      batch_size=batch_size,
      prefetch_size=prefetch_size,
      extra_spec=extra_spec,
      sequence_length=sequence_length)

  # Online and target networks, both with their variables created.
  target_network = copy.deepcopy(network)
  for module in (network, target_network):
    tf2_utils.create_variables(module, [environment_spec.observations])

  learner = learning.R2D2Learner(
      environment_spec=environment_spec,
      network=network,
      target_network=target_network,
      burn_in_length=burn_in_length,
      sequence_length=sequence_length,
      dataset=dataset,
      reverb_client=reverb_client,
      counter=counter,
      logger=logger,
      discount=discount,
      target_update_period=target_update_period,
      importance_sampling_exponent=importance_sampling_exponent,
      max_replay_size=max_replay_size,
      learning_rate=learning_rate,
      store_lstm_state=store_lstm_state,
      max_priority_weight=max_priority_weight,
  )

  self._checkpointer = tf2_savers.Checkpointer(
      subdirectory='r2d2_learner',
      time_delta_minutes=60,
      objects_to_save=learner.state,
      enable_checkpointing=checkpoint,
  )
  self._snapshotter = tf2_savers.Snapshotter(
      objects_to_save={'network': network}, time_delta_minutes=60.)

  # Acting policy: the network followed by epsilon-greedy sampling.
  policy_network = snt.DeepRNN([
      network,
      lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
  ])
  actor = actors.RecurrentActor(policy_network, adder)

  min_observations = replay_period * max(batch_size, min_replay_size)
  observations_per_step = (
      float(replay_period * batch_size) / samples_per_insert)
  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=min_observations,
      observations_per_step=observations_per_step)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.Module,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 32.0,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    n_step: int = 5,
    epsilon: tf.Tensor = None,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
    logger: loggers.Logger = None,
    checkpoint: bool = True,
):
  """Initialize the agent.

  Args:
    environment_spec: description of the actions, observations, etc.
    network: the online Q network (the one being optimized).
    batch_size: batch size for updates.
    prefetch_size: size to prefetch from replay.
    target_update_period: number of learner steps to perform before updating
      the target networks.
    samples_per_insert: number of samples to take from replay for every insert
      that is made.
    min_replay_size: minimum replay size before updating.
    max_replay_size: maximum replay size.
    importance_sampling_exponent: power to which importance weights are raised
      before normalizing.
    priority_exponent: exponent used in prioritized sampling.
    n_step: number of steps to squash into a single transition.
    epsilon: probability of taking a random action; a constant 0.05 is used
      when None.
    learning_rate: learning rate for the q-network update.
    discount: discount to use for TD updates.
    logger: logger object to be used by learner.
    checkpoint: boolean indicating whether to checkpoint the learner.
  """
  # Replay server. The table only requires a single item before sampling;
  # pacing is left to the Agent interface.
  table = reverb.Table(
      name=adders.DEFAULT_PRIORITY_TABLE,
      sampler=reverb.selectors.Prioritized(priority_exponent),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(1))
  self._server = reverb.Server([table], port=None)
  address = f'localhost:{self._server.port}'

  # The adder inserts n-step transitions into replay.
  adder = adders.NStepTransitionAdder(
      client=reverb.Client(address), n_step=n_step, discount=discount)

  # The dataset samples batches of transitions back out of replay.
  replay_client = reverb.TFClient(address)
  dataset = datasets.make_reverb_dataset(
      client=replay_client,
      environment_spec=environment_spec,
      batch_size=batch_size,
      prefetch_size=prefetch_size,
      transition_adder=True)

  # Use constant 0.05 epsilon greedy policy by default.
  if epsilon is None:
    epsilon = tf.Variable(0.05, trainable=False)
  policy_network = snt.Sequential([
      network,
      lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
  ])

  # Target network: a copy of the online network.
  target_network = copy.deepcopy(network)

  # Ensure that we create the variables before proceeding (maybe not needed).
  for module in (network, target_network):
    tf2_utils.create_variables(module, [environment_spec.observations])

  # The actor defines how we take actions.
  actor = actors.FeedForwardActor(policy_network, adder)

  # The learner updates the parameters (and initializes them).
  learner = learning.DQNLearner(
      network=network,
      target_network=target_network,
      discount=discount,
      importance_sampling_exponent=importance_sampling_exponent,
      learning_rate=learning_rate,
      target_update_period=target_update_period,
      dataset=dataset,
      replay_client=replay_client,
      logger=logger,
      checkpoint=checkpoint)

  # Only build a checkpointer when checkpointing was requested.
  self._checkpointer = (
      tf2_savers.Checkpointer(
          objects_to_save=learner.state,
          subdirectory='dqn_learner',
          time_delta_minutes=60.)
      if checkpoint else None)

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=max(batch_size, min_replay_size),
      observations_per_step=float(batch_size) / samples_per_insert)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    sequence_length: int,
    sequence_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    max_queue_size: int = 100000,
    batch_size: int = 16,
    learning_rate: float = 1e-3,
    entropy_cost: float = 0.01,
    baseline_cost: float = 0.5,
    max_abs_reward: Optional[float] = None,
    max_gradient_norm: Optional[float] = None,
):
  """Wires up a Reverb queue, the IMPALA actor and the IMPALA learner.

  Args:
    environment_spec: description of the actions, observations, etc.
    network: the recurrent network shared by the actor and the learner.
    sequence_length: length of the sequences written to and read from the
      queue.
    sequence_period: period of the sequence adder.
    counter: counter forwarded to the learner.
    logger: logger forwarded to the learner; `self._logger` falls back to a
      `TerminalLogger('agent')` when this is not given.
    discount: discount forwarded to the learner.
    max_queue_size: maximum size of the Reverb queue.
    batch_size: batch size of the dataset, also used by `self._can_sample`.
    learning_rate: forwarded to the learner.
    entropy_cost: forwarded to the learner.
    baseline_cost: forwarded to the learner.
    max_abs_reward: forwarded to the learner.
    max_gradient_norm: forwarded to the learner.
  """
  num_actions = environment_spec.actions.num_values
  self._logger = logger or loggers.TerminalLogger('agent')

  # Data flows through a Reverb queue rather than a replay table.
  queue = reverb.Table.queue(
      name=adders.DEFAULT_PRIORITY_TABLE, max_size=max_queue_size)
  self._server = reverb.Server([queue], port=None)
  self._can_sample = lambda: queue.can_sample(batch_size)
  address = f'localhost:{self._server.port}'

  # Component to add things into replay.
  adder = adders.SequenceAdder(
      client=reverb.Client(address),
      period=sequence_period,
      sequence_length=sequence_length,
  )

  # The dataset object to learn from. Extras hold the core state and the
  # logits, with batch dimensions removed.
  extra_spec = tf2_utils.squeeze_batch_dim({
      'core_state': network.initial_state(1),
      'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32),
  })
  dataset = datasets.make_reverb_dataset(
      client=reverb.TFClient(address),
      environment_spec=environment_spec,
      batch_size=batch_size,
      extra_spec=extra_spec,
      sequence_length=sequence_length)

  tf2_utils.create_variables(network, [environment_spec.observations])

  self._actor = acting.IMPALAActor(network, adder)
  self._learner = learning.IMPALALearner(
      environment_spec=environment_spec,
      network=network,
      dataset=dataset,
      counter=counter,
      logger=logger,
      discount=discount,
      learning_rate=learning_rate,
      entropy_cost=entropy_cost,
      baseline_cost=baseline_cost,
      max_gradient_norm=max_gradient_norm,
      max_abs_reward=max_abs_reward,
  )
def __init__(self,
             environment_spec: specs.EnvironmentSpec,
             policy_network: snt.Module,
             critic_network: snt.Module,
             observation_network: types.TensorTransformation = tf.identity,
             discount: float = 0.99,
             batch_size: int = 256,
             prefetch_size: int = 4,
             target_update_period: int = 100,
             min_replay_size: int = 1000,
             max_replay_size: int = 1000000,
             samples_per_insert: float = 32.0,
             n_step: int = 5,
             sigma: float = 0.3,
             clipping: bool = True,
             logger: loggers.Logger = None,
             counter: counting.Counter = None,
             checkpoint: bool = True,
             replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
  """Initialize the agent.

  Args:
    environment_spec: description of the actions, observations, etc.
    policy_network: the online (optimized) policy.
    critic_network: the online critic.
    observation_network: optional network to transform the observations before
      they are fed into any network.
    discount: discount to use for TD updates.
    batch_size: batch size for updates.
    prefetch_size: size to prefetch from replay.
    target_update_period: number of learner steps to perform before updating
      the target networks.
    min_replay_size: minimum replay size before updating.
    max_replay_size: maximum replay size.
    samples_per_insert: number of samples to take from replay for every insert
      that is made.
    n_step: number of steps to squash into a single transition.
    sigma: standard deviation of zero-mean, Gaussian exploration noise.
    clipping: whether to clip gradients by global norm.
    logger: logger object to be used by learner.
    counter: counter object used to keep track of steps.
    checkpoint: boolean indicating whether to checkpoint the learner.
    replay_table_name: string indicating what name to give the replay table.
  """
  # Replay server. The table only requires a single item before sampling;
  # pacing is left to the Agent interface.
  table = reverb.Table(
      name=replay_table_name,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(1))
  self._server = reverb.Server([table], port=None)
  address = f'localhost:{self._server.port}'

  # The adder inserts n-step transitions into the named table with a
  # constant priority of 1.
  adder = adders.NStepTransitionAdder(
      priority_fns={replay_table_name: lambda x: 1.},
      client=reverb.Client(address),
      n_step=n_step,
      discount=discount)

  # The dataset samples batches of transitions from the same table.
  dataset = datasets.make_reverb_dataset(
      table=replay_table_name,
      client=reverb.TFClient(address),
      batch_size=batch_size,
      prefetch_size=prefetch_size,
      environment_spec=environment_spec,
      transition_adder=True)

  # Make sure observation network is a Sonnet Module.
  observation_network = tf2_utils.to_sonnet_module(observation_network)

  # Target networks are copies taken before any variables are created.
  target_policy_network = copy.deepcopy(policy_network)
  target_critic_network = copy.deepcopy(critic_network)
  target_observation_network = copy.deepcopy(observation_network)

  # Specs; the embedding spec is what the observation network outputs.
  act_spec = environment_spec.actions
  obs_spec = environment_spec.observations
  emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

  # Behavior policy: the online policy plus Gaussian exploration noise,
  # clipped to the action spec.
  behavior_network = snt.Sequential([
      observation_network,
      policy_network,
      networks.ClippedGaussian(sigma),
      networks.ClipToSpec(act_spec),
  ])

  # Create the remaining variables, in a fixed order.
  modules_and_inputs = (
      (policy_network, [emb_spec]),
      (critic_network, [emb_spec, act_spec]),
      (target_policy_network, [emb_spec]),
      (target_critic_network, [emb_spec, act_spec]),
      (target_observation_network, [obs_spec]),
  )
  for module, input_specs in modules_and_inputs:
    tf2_utils.create_variables(module, input_specs)

  # The actor defines how we take actions.
  actor = actors.FeedForwardActor(behavior_network, adder=adder)

  # The learner updates the parameters (and initializes them). Both
  # optimizers are Adam with a learning rate of 1e-4.
  learner = learning.D4PGLearner(
      policy_network=policy_network,
      critic_network=critic_network,
      observation_network=observation_network,
      target_policy_network=target_policy_network,
      target_critic_network=target_critic_network,
      target_observation_network=target_observation_network,
      policy_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
      critic_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
      clipping=clipping,
      discount=discount,
      target_update_period=target_update_period,
      dataset=dataset,
      counter=counter,
      logger=logger,
      checkpoint=checkpoint,
  )

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=max(batch_size, min_replay_size),
      observations_per_step=float(batch_size) / samples_per_insert)
def __init__(self,
             environment_spec: specs.EnvironmentSpec,
             network: snt.RNNCore,
             target_network: snt.RNNCore,
             burn_in_length: int,
             trace_length: int,
             replay_period: int,
             demonstration_dataset: tf.data.Dataset,
             demonstration_ratio: float,
             counter: counting.Counter = None,
             logger: loggers.Logger = None,
             discount: float = 0.99,
             batch_size: int = 32,
             target_update_period: int = 100,
             importance_sampling_exponent: float = 0.2,
             epsilon: float = 0.01,
             learning_rate: float = 1e-3,
             log_to_bigtable: bool = False,
             log_name: str = 'agent',
             checkpoint: bool = True,
             min_replay_size: int = 1000,
             max_replay_size: int = 1000000,
             samples_per_insert: float = 32.0):
  """Wires up an R2D2 agent that also learns from demonstrations.

  Sequences sampled from replay are mixed with sequences derived from
  `demonstration_dataset` before batching.

  Args:
    environment_spec: description of the actions, observations, etc.
    network: the online recurrent network being optimized.
    target_network: the target network handed to the learner.
    burn_in_length: burn-in length forwarded to the learner.
    trace_length: contributes to the stored sequence length.
    replay_period: period of the sequence adder and of the demonstration
      sequence helper.
    demonstration_dataset: dataset mapped through `_sequence_from_episode`.
    demonstration_ratio: sampling weight of the demonstration data; replay
      data gets `1 - demonstration_ratio`.
    counter: counter forwarded to the learner.
    logger: logger forwarded to the learner.
    discount: discount forwarded to the learner.
    batch_size: batch size of the mixed dataset.
    target_update_period: forwarded to the learner.
    importance_sampling_exponent: forwarded to the learner.
    epsilon: epsilon of the epsilon-greedy acting policy.
    learning_rate: forwarded to the learner.
    log_to_bigtable: not used by this constructor.
    log_name: not used by this constructor.
    checkpoint: whether the checkpointer is enabled.
    min_replay_size: with `batch_size` and `replay_period`, determines the
      `min_observations` value passed to the base class.
    max_replay_size: maximum size of the replay table.
    samples_per_insert: used to derive `observations_per_step`.
  """
  sequence_length = burn_in_length + trace_length + 1

  # Replay server with a single uniformly sampled table.
  table = reverb.Table(
      name=adders.DEFAULT_PRIORITY_TABLE,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1))
  self._server = reverb.Server([table], port=None)
  address = f'localhost:{self._server.port}'

  # Component to add things into replay. The same settings are reused for
  # the demonstration sequences below.
  sequence_kwargs = {
      'period': replay_period,
      'sequence_length': sequence_length,
  }
  adder = adders.SequenceAdder(
      client=reverb.Client(address), **sequence_kwargs)

  # The dataset object to learn from. Extras carry the recurrent core state,
  # with its batch dimension removed.
  reverb_client = reverb.TFClient(address)
  extra_spec = tf2_utils.squeeze_batch_dim({
      'core_state': network.initial_state(1),
  })
  replay_dataset = datasets.make_reverb_dataset(
      client=reverb_client,
      environment_spec=environment_spec,
      extra_spec=extra_spec,
      sequence_length=sequence_length)

  # Combine with demonstration dataset.
  to_sequence = functools.partial(
      _sequence_from_episode, extra_spec=extra_spec, **sequence_kwargs)
  demo_dataset = demonstration_dataset.map(to_sequence)
  mixed = tf.data.experimental.sample_from_datasets(
      [replay_dataset, demo_dataset],
      [1 - demonstration_ratio, demonstration_ratio])

  # Batch and prefetch.
  dataset = mixed.batch(batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

  for module in (network, target_network):
    tf2_utils.create_variables(module, [environment_spec.observations])

  learner = learning.R2D2Learner(
      environment_spec=environment_spec,
      network=network,
      target_network=target_network,
      burn_in_length=burn_in_length,
      dataset=dataset,
      reverb_client=reverb_client,
      counter=counter,
      logger=logger,
      sequence_length=sequence_length,
      discount=discount,
      target_update_period=target_update_period,
      importance_sampling_exponent=importance_sampling_exponent,
      max_replay_size=max_replay_size,
      learning_rate=learning_rate,
      store_lstm_state=False,
  )

  self._checkpointer = tf2_savers.Checkpointer(
      subdirectory='r2d2_learner',
      time_delta_minutes=60,
      objects_to_save=learner.state,
      enable_checkpointing=checkpoint,
  )
  self._snapshotter = tf2_savers.Snapshotter(
      objects_to_save={'network': network}, time_delta_minutes=60.)

  # Acting policy: the network followed by epsilon-greedy sampling.
  policy_network = snt.DeepRNN([
      network,
      lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
  ])
  actor = actors.RecurrentActor(policy_network, adder)

  min_observations = replay_period * max(batch_size, min_replay_size)
  observations_per_step = (
      float(replay_period * batch_size) / samples_per_insert)
  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=min_observations,
      observations_per_step=observations_per_step)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.Module,
    demonstration_dataset: tf.data.Dataset,
    demonstration_ratio: float,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 32.0,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    importance_sampling_exponent: float = 0.2,
    n_step: int = 5,
    epsilon: tf.Tensor = None,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
):
  """Builds the replay, demonstration pipeline, actor and learner.

  Args:
    environment_spec: description of the actions, observations, etc.
    network: the online Q network (the one being optimized). A deep copy of
      it is used as the target network.
    demonstration_dataset: tf.data.Dataset producing (timestep, action)
      tuples containing full episodes.
    demonstration_ratio: sampling weight of the demonstration transitions;
      replay transitions get weight `1 - demonstration_ratio`.
    batch_size: batch size for updates.
    prefetch_size: number of batches to prefetch from the mixed dataset.
    target_update_period: number of learner steps to perform before updating
      the target network.
    samples_per_insert: divisor applied to `batch_size` to obtain the
      `observations_per_step` passed to the parent class.
    min_replay_size: together with `batch_size`, determines the
      `min_observations` passed to the parent class.
    max_replay_size: maximum size of the replay table.
    importance_sampling_exponent: forwarded to the learner.
    n_step: number of steps to squash into a single transition.
    epsilon: probability of taking a random action. When None, a
      non-trainable variable holding 0.05 is used.
    learning_rate: learning rate for the q-network update.
    discount: discount to use for TD updates.
  """
  # In-process replay server. The table only needs one item before it can
  # be sampled; pacing values are passed to the parent constructor below.
  table = reverb.Table(
      name=adders.DEFAULT_PRIORITY_TABLE,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(1))
  self._server = reverb.Server([table], port=None)
  address = f'localhost:{self._server.port}'

  # Writer side: n-step transitions are inserted through this adder.
  insert_client = reverb.Client(address)
  adder = adders.NStepTransitionAdder(
      client=insert_client, n_step=n_step, discount=discount)

  # Reader side: dataset of transitions sampled from replay.
  replay_client = reverb.TFClient(address)
  replay_dataset = datasets.make_reverb_dataset(
      client=replay_client,
      environment_spec=environment_spec,
      transition_adder=True)

  # Turn demonstration episodes into n-step transitions and mix them in.
  episode_to_transition = functools.partial(
      _n_step_transition_from_episode, n_step=n_step, discount=discount)
  demo_dataset = demonstration_dataset.map(episode_to_transition)
  mixture_weights = [1 - demonstration_ratio, demonstration_ratio]
  mixed_dataset = tf.data.experimental.sample_from_datasets(
      [replay_dataset, demo_dataset], mixture_weights)

  # Fixed-size batches with a bounded prefetch buffer.
  dataset = mixed_dataset.batch(
      batch_size, drop_remainder=True).prefetch(prefetch_size)

  # Fall back to a constant 0.05 exploration rate when none is supplied.
  if epsilon is None:
    epsilon = tf.Variable(0.05, trainable=False)

  def _sample_epsilon_greedy(q):
    # Draw an action from the epsilon-greedy distribution over `q`.
    return trfl.epsilon_greedy(q, epsilon=epsilon).sample()

  policy_network = snt.Sequential([network, _sample_epsilon_greedy])

  # The target network is copied before any variables are created.
  target_network = copy.deepcopy(network)

  # Create the variables of the online network, then of the target network.
  for net in (network, target_network):
    tf2_utils.create_variables(net, [environment_spec.observations])

  # Acting component: feeds observations through the policy and the adder.
  actor = actors.FeedForwardActor(policy_network, adder)

  # Learning component.
  learner = dqn.DQNLearner(
      network=network,
      target_network=target_network,
      discount=discount,
      importance_sampling_exponent=importance_sampling_exponent,
      learning_rate=learning_rate,
      target_update_period=target_update_period,
      dataset=dataset,
      replay_client=replay_client)

  # Pacing parameters handed to the parent class.
  min_observations = max(batch_size, min_replay_size)
  observations_per_step = float(batch_size) / samples_per_insert

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=min_observations,
      observations_per_step=observations_per_step)
def __init__(
    self,
    network: snt.Module,
    model: models.Model,
    optimizer: snt.Optimizer,
    n_step: int,
    discount: float,
    replay_capacity: int,
    num_simulations: int,
    environment_spec: specs.EnvironmentSpec,
    batch_size: int,
):
  """Assembles the replay buffer, the MCTS actor and the learner.

  Args:
    network: network shared by the actor and the learner; its variables are
      created here from the observation spec.
    model: model handed to the MCTS actor.
    optimizer: optimizer handed to the learner.
    n_step: number of steps used by the n-step transition adder.
    discount: discount used by the adder, the actor and the learner.
    replay_capacity: maximum size of the replay table.
    num_simulations: forwarded to the MCTS actor.
    environment_spec: environment description; its action spec must expose
      `num_values`.
    batch_size: batch size of the replay dataset (incomplete batches are
      dropped).
  """
  # In-process replay server: uniform sampling, FIFO eviction, and sampling
  # allowed as soon as a single item is present.
  table = reverb.Table(
      name=adders.DEFAULT_PRIORITY_TABLE,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      max_size=replay_capacity,
      rate_limiter=reverb.rate_limiters.MinSize(1))
  self._server = reverb.Server([table], port=None)
  address = f'localhost:{self._server.port}'

  # Writer side: n-step transitions are inserted through this adder.
  insert_client = reverb.Client(address)
  adder = adders.NStepTransitionAdder(
      client=insert_client, n_step=n_step, discount=discount)

  # Reader side: each transition carries an extra float32 vector 'pi' with
  # one entry per discrete action.
  replay_client = reverb.TFClient(address)
  num_actions = environment_spec.actions.num_values
  pi_spec = specs.Array(shape=(num_actions,), dtype=np.float32)
  transitions = datasets.make_reverb_dataset(
      client=replay_client,
      environment_spec=environment_spec,
      extra_spec={'pi': pi_spec},
      transition_adder=True)
  dataset = transitions.batch(batch_size, drop_remainder=True)

  # Make sure the network variables exist before building actor and learner.
  tf2_utils.create_variables(network, [environment_spec.observations])

  # Acting component.
  actor = acting.MCTSActor(
      environment_spec=environment_spec,
      model=model,
      network=network,
      discount=discount,
      adder=adder,
      num_simulations=num_simulations,
  )

  # Learning component.
  learner = learning.AZLearner(
      network=network,
      optimizer=optimizer,
      dataset=dataset,
      discount=discount,
  )

  # The parent class ties actor and learner together into a single agent.
  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=10,
      observations_per_step=1,
  )