def test_snapshot_distribution(self):
  """Test that the snapshotter correctly saves/restores snapshots."""
  # Create a test network.
  net1 = snt.Sequential([
      networks.LayerNormMLP([10, 10]),
      networks.MultivariateNormalDiagHead(1)
  ])
  spec = specs.Array([10], dtype=np.float32)
  tf2_utils.create_variables(net1, [spec])

  # Save the test network.
  directory = self.get_tempdir()
  objects_to_save = {'net': net1}
  snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory)
  snapshotter.save()

  # Reload the test network.
  net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))
  inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

  with tf.GradientTape() as tape:
    dist1 = net1(inputs)
    loss1 = tf.math.reduce_sum(dist1.mean() + dist1.variance())
    grads1 = tape.gradient(loss1, net1.trainable_variables)

  with tf.GradientTape() as tape:
    dist2 = net2(inputs)
    loss2 = tf.math.reduce_sum(dist2.mean() + dist2.variance())
    grads2 = tape.gradient(loss2, net2.trainable_variables)

  assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))
def __init__(self,
             network: snt.Module,
             learning_rate: float,
             dataset: tf.data.Dataset,
             counter: counting.Counter = None,
             logger: loggers.Logger = None):
  """Initializes the learner.

  Args:
    network: the online Q network (the one being optimized).
    learning_rate: learning rate for the q-network update.
    dataset: dataset to learn from.
    counter: Counter object for (potentially distributed) counting.
    logger: Logger object for writing logs to.
  """
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

  # Get an iterator over the dataset.
  self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
  # TODO(b/155086959): Fix type stubs and remove.

  self._network = network
  self._optimizer = snt.optimizers.Adam(learning_rate)

  self._variables: List[List[tf.Tensor]] = [network.trainable_variables]
  self._num_steps = tf.Variable(0, dtype=tf.int32)

  self._snapshotter = tf2_savers.Snapshotter(
      objects_to_save={'network': network}, time_delta_minutes=60.)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    dataset: tf.data.Dataset,
    learning_rate: float,
    discount: float = 0.99,
    entropy_cost: float = 0.,
    baseline_cost: float = 1.,
    max_abs_reward: Optional[float] = None,
    max_gradient_norm: Optional[float] = None,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
):

  # Internalise the network, optimizer, and dataset.
  self._env_spec = environment_spec
  self._optimizer = snt.optimizers.Adam(learning_rate=learning_rate)
  self._network = network
  self._variables = network.variables
  # TODO(b/155086959): Fix type stubs and remove.
  self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

  # Hyperparameters.
  self._discount = discount
  self._entropy_cost = entropy_cost
  self._baseline_cost = baseline_cost

  # Set up reward/gradient clipping.
  if max_abs_reward is None:
    max_abs_reward = np.inf
  if max_gradient_norm is None:
    max_gradient_norm = 1e10  # A very large number. Infinity results in NaNs.
  self._max_abs_reward = tf.convert_to_tensor(max_abs_reward)
  self._max_gradient_norm = tf.convert_to_tensor(max_gradient_norm)

  # Set up logging/counting.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

  self._snapshotter = tf2_savers.Snapshotter(
      objects_to_save={'network': network}, time_delta_minutes=60.)

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None
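# A minimal sketch, not part of the learner above, illustrating the comment on
# max_gradient_norm: assuming gradients are clipped with
# tf.clip_by_global_norm (the clipping itself is not shown here), passing
# np.inf as the clip norm propagates a NaN scale factor into every gradient,
# which is why a very large finite bound is used instead.
import numpy as np
import tensorflow as tf

grads = [tf.constant([3.0, 4.0])]  # Arbitrary gradients with global norm 5.
finite_clip, _ = tf.clip_by_global_norm(grads, clip_norm=1e10)  # Unchanged.
nan_clip, _ = tf.clip_by_global_norm(grads, clip_norm=np.inf)   # All NaNs.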
def test_force_snapshot(self):
  """Test that the force feature in Snapshotter.save() works correctly."""
  # Create a test network.
  net = snt.Linear(10)
  spec = specs.Array([10], dtype=np.float32)
  tf2_utils.create_variables(net, [spec])

  # Save the test network.
  directory = self.get_tempdir()
  objects_to_save = {'net': net}
  # Very long time_delta_minutes.
  snapshotter = tf2_savers.Snapshotter(
      objects_to_save, directory=directory, time_delta_minutes=1000)
  self.assertTrue(snapshotter.save(force=False))

  # Due to the long time_delta_minutes, only force=True will create a new
  # snapshot. This also checks the default is force=False.
  self.assertFalse(snapshotter.save())
  self.assertTrue(snapshotter.save(force=True))
def test_rnn_snapshot(self):
  """Test that the snapshotter correctly saves/restores snapshots of RNNs."""
  # Create a test network.
  net = snt.LSTM(10)
  spec = specs.Array([10], dtype=np.float32)
  tf2_utils.create_variables(net, [spec])

  # Test that if you add some postprocessing without rerunning
  # create_variables, it still works.
  wrapped_net = snt.DeepRNN([net, lambda x: x])

  for net1 in [net, wrapped_net]:
    # Save the test network.
    directory = self.get_tempdir()
    objects_to_save = {'net': net1}
    snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory)
    snapshotter.save()

    # Reload the test network.
    net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))
    inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

    with tf.GradientTape() as tape:
      outputs1, next_state1 = net1(inputs, net1.initial_state(1))
      loss1 = tf.math.reduce_sum(outputs1)
      grads1 = tape.gradient(loss1, net1.trainable_variables)

    with tf.GradientTape() as tape:
      outputs2, next_state2 = net2(inputs, net2.initial_state(1))
      loss2 = tf.math.reduce_sum(outputs2)
      grads2 = tape.gradient(loss2, net2.trainable_variables)

    assert np.allclose(outputs1, outputs2)
    assert np.allclose(tree.flatten(next_state1), tree.flatten(next_state2))
    assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))
def __init__(
    self,
    policy_network: snt.Module,
    critic_network: snt.Module,
    target_policy_network: snt.Module,
    target_critic_network: snt.Module,
    discount: float,
    num_samples: int,
    target_policy_update_period: int,
    target_critic_update_period: int,
    dataset: tf.data.Dataset,
    observation_network: types.TensorTransformation = tf.identity,
    target_observation_network: types.TensorTransformation = tf.identity,
    policy_loss_module: Optional[snt.Module] = None,
    policy_optimizer: Optional[snt.Optimizer] = None,
    critic_optimizer: Optional[snt.Optimizer] = None,
    dual_optimizer: Optional[snt.Optimizer] = None,
    clipping: bool = True,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
    checkpoint: bool = True,
):

  # Store online and target networks.
  self._policy_network = policy_network
  self._critic_network = critic_network
  self._target_policy_network = target_policy_network
  self._target_critic_network = target_critic_network

  # Make sure observation networks are snt.Module's so they have variables.
  self._observation_network = tf2_utils.to_sonnet_module(observation_network)
  self._target_observation_network = tf2_utils.to_sonnet_module(
      target_observation_network)

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger('learner')

  # Other learner parameters.
  self._discount = discount
  self._num_samples = num_samples
  self._clipping = clipping

  # Necessary to track when to update target networks.
  self._num_steps = tf.Variable(0, dtype=tf.int32)
  self._target_policy_update_period = target_policy_update_period
  self._target_critic_update_period = target_critic_update_period

  # Batch dataset and create iterator.
  # TODO(b/155086959): Fix type stubs and remove.
  self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

  self._policy_loss_module = policy_loss_module or losses.MPO(
      epsilon=1e-1,
      epsilon_penalty=1e-3,
      epsilon_mean=1e-3,
      epsilon_stddev=1e-6,
      init_log_temperature=1.,
      init_log_alpha_mean=1.,
      init_log_alpha_stddev=10.)

  # Create the optimizers.
  self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
  self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
  self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2)

  # Expose the variables.
  policy_network_to_expose = snt.Sequential(
      [self._target_observation_network, self._target_policy_network])
  self._variables = {
      'critic': self._target_critic_network.variables,
      'policy': policy_network_to_expose.variables,
  }

  # Create a checkpointer and snapshotter object.
  self._checkpointer = None
  self._snapshotter = None
  if checkpoint:
    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='dmpo_learner',
        objects_to_save={
            'counter': self._counter,
            'policy': self._policy_network,
            'critic': self._critic_network,
            'observation': self._observation_network,
            'target_policy': self._target_policy_network,
            'target_critic': self._target_critic_network,
            'target_observation': self._target_observation_network,
            'policy_optimizer': self._policy_optimizer,
            'critic_optimizer': self._critic_optimizer,
            'dual_optimizer': self._dual_optimizer,
            'policy_loss_module': self._policy_loss_module,
            'num_steps': self._num_steps,
        })
    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={
            'policy': snt.Sequential([
                self._target_observation_network,
                self._target_policy_network
            ]),
        })

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None
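# A minimal sketch, assumed rather than taken from the learner above, of the
# timing pattern that `self._timestamp = None` supports inside step():
# because the timestamp starts out as None, the first learning step records
# zero elapsed walltime, so time spent waiting for actors to fill the replay
# buffer is never attributed to learning. The class name is hypothetical.
import time

class _WalltimeExample:

  def __init__(self):
    self._timestamp = None  # Mirrors the learner initialisation above.

  def step(self) -> float:
    timestamp = time.time()
    elapsed_time = timestamp - self._timestamp if self._timestamp else 0.
    self._timestamp = timestamp
    # A real learner would report this alongside its step counts, e.g.
    # counts = self._counter.increment(steps=1, walltime=elapsed_time).
    return elapsed_time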
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    burn_in_length: int,
    trace_length: int,
    replay_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    batch_size: int = 32,
    prefetch_size: int = tf.data.experimental.AUTOTUNE,
    target_update_period: int = 100,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    epsilon: float = 0.01,
    learning_rate: float = 1e-3,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    samples_per_insert: float = 32.0,
    store_lstm_state: bool = True,
    max_priority_weight: float = 0.9,
    checkpoint: bool = True,
):

  replay_table = reverb.Table(
      name=adders.DEFAULT_PRIORITY_TABLE,
      sampler=reverb.selectors.Prioritized(priority_exponent),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1))
  self._server = reverb.Server([replay_table], port=None)
  address = f'localhost:{self._server.port}'

  sequence_length = burn_in_length + trace_length + 1

  # Component to add things into replay.
  adder = adders.SequenceAdder(
      client=reverb.Client(address),
      period=replay_period,
      sequence_length=sequence_length,
  )

  # The dataset object to learn from.
  reverb_client = reverb.TFClient(address)
  extra_spec = {
      'core_state': network.initial_state(1),
  }
  # Remove batch dimensions.
  extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
  dataset = datasets.make_reverb_dataset(
      client=reverb_client,
      environment_spec=environment_spec,
      batch_size=batch_size,
      prefetch_size=prefetch_size,
      extra_spec=extra_spec,
      sequence_length=sequence_length)

  target_network = copy.deepcopy(network)
  tf2_utils.create_variables(network, [environment_spec.observations])
  tf2_utils.create_variables(target_network, [environment_spec.observations])

  learner = learning.R2D2Learner(
      environment_spec=environment_spec,
      network=network,
      target_network=target_network,
      burn_in_length=burn_in_length,
      sequence_length=sequence_length,
      dataset=dataset,
      reverb_client=reverb_client,
      counter=counter,
      logger=logger,
      discount=discount,
      target_update_period=target_update_period,
      importance_sampling_exponent=importance_sampling_exponent,
      max_replay_size=max_replay_size,
      learning_rate=learning_rate,
      store_lstm_state=store_lstm_state,
      max_priority_weight=max_priority_weight,
  )

  self._checkpointer = tf2_savers.Checkpointer(
      subdirectory='r2d2_learner',
      time_delta_minutes=60,
      objects_to_save=learner.state,
      enable_checkpointing=checkpoint,
  )
  self._snapshotter = tf2_savers.Snapshotter(
      objects_to_save={'network': network}, time_delta_minutes=60.)

  policy_network = snt.DeepRNN([
      network,
      lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
  ])

  actor = actors.RecurrentActor(policy_network, adder)
  observations_per_step = (
      float(replay_period * batch_size) / samples_per_insert)
  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=replay_period * max(batch_size, min_replay_size),
      observations_per_step=observations_per_step)
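# A short worked example of the rate arithmetic at the end of __init__ above,
# using the default batch_size / samples_per_insert / min_replay_size and an
# illustrative replay_period of 40 (not a default). The interpretation
# assumes the base agent class (not shown here) uses these values to pace
# learner steps against environment observations.
replay_period = 40          # Illustrative value.
batch_size = 32             # Default above.
samples_per_insert = 32.0   # Default above.
min_replay_size = 1000      # Default above.

observations_per_step = float(replay_period * batch_size) / samples_per_insert
min_observations = replay_period * max(batch_size, min_replay_size)

assert observations_per_step == 40.0  # One learner step per 40 observations.
assert min_observations == 40000      # Observations gathered before learning.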
def __init__(
    self,
    policy_network: snt.Module,
    critic_network: snt.Module,
    target_policy_network: snt.Module,
    target_critic_network: snt.Module,
    discount: float,
    target_update_period: int,
    dataset: tf.data.Dataset,
    observation_network: types.TensorTransformation = lambda x: x,
    target_observation_network: types.TensorTransformation = lambda x: x,
    policy_optimizer: snt.Optimizer = None,
    critic_optimizer: snt.Optimizer = None,
    clipping: bool = True,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    checkpoint: bool = True,
):
  """Initializes the learner.

  Args:
    policy_network: the online (optimized) policy.
    critic_network: the online critic.
    target_policy_network: the target policy (which lags behind the online
      policy).
    target_critic_network: the target critic.
    discount: discount to use for TD updates.
    target_update_period: number of learner steps to perform before updating
      the target networks.
    dataset: dataset to learn from, whether fixed or from a replay buffer
      (see `acme.datasets.reverb.make_dataset` documentation).
    observation_network: an optional online network to process observations
      before the policy and the critic.
    target_observation_network: the target observation network.
    policy_optimizer: the optimizer to be applied to the DPG (policy) loss.
    critic_optimizer: the optimizer to be applied to the distributional
      Bellman loss.
    clipping: whether to clip gradients by global norm.
    counter: counter object used to keep track of steps.
    logger: logger object to be used by learner.
    checkpoint: boolean indicating whether to checkpoint the learner.
  """

  # Store online and target networks.
  self._policy_network = policy_network
  self._critic_network = critic_network
  self._target_policy_network = target_policy_network
  self._target_critic_network = target_critic_network

  # Make sure observation networks are snt.Module's so they have variables.
  self._observation_network = tf2_utils.to_sonnet_module(observation_network)
  self._target_observation_network = tf2_utils.to_sonnet_module(
      target_observation_network)

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger('learner')

  # Other learner parameters.
  self._discount = discount
  self._clipping = clipping

  # Necessary to track when to update target networks.
  self._num_steps = tf.Variable(0, dtype=tf.int32)
  self._target_update_period = target_update_period

  # Batch dataset and create iterator.
  # TODO(b/155086959): Fix type stubs and remove.
  self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

  # Create optimizers if they aren't given.
  self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
  self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)

  # Expose the variables.
  policy_network_to_expose = snt.Sequential(
      [self._target_observation_network, self._target_policy_network])
  self._variables = {
      'critic': self._target_critic_network.variables,
      'policy': policy_network_to_expose.variables,
  }

  # Create checkpointer and snapshotter objects.
  self._checkpointer = None
  self._snapshotter = None
  if checkpoint:
    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='d4pg_learner',
        objects_to_save={
            'counter': self._counter,
            'policy': self._policy_network,
            'critic': self._critic_network,
            'observation': self._observation_network,
            'target_policy': self._target_policy_network,
            'target_critic': self._target_critic_network,
            'target_observation': self._target_observation_network,
            'policy_optimizer': self._policy_optimizer,
            'critic_optimizer': self._critic_optimizer,
            'num_steps': self._num_steps,
        })
    critic_mean = snt.Sequential(
        [self._critic_network, acme_nets.StochasticMeanHead()])
    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={
            'policy': self._policy_network,
            'critic': critic_mean,
        })

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None
def __init__(self,
             environment_spec: specs.EnvironmentSpec,
             network: snt.RNNCore,
             target_network: snt.RNNCore,
             burn_in_length: int,
             trace_length: int,
             replay_period: int,
             demonstration_dataset: tf.data.Dataset,
             demonstration_ratio: float,
             counter: counting.Counter = None,
             logger: loggers.Logger = None,
             discount: float = 0.99,
             batch_size: int = 32,
             target_update_period: int = 100,
             importance_sampling_exponent: float = 0.2,
             epsilon: float = 0.01,
             learning_rate: float = 1e-3,
             log_to_bigtable: bool = False,
             log_name: str = 'agent',
             checkpoint: bool = True,
             min_replay_size: int = 1000,
             max_replay_size: int = 1000000,
             samples_per_insert: float = 32.0):

  replay_table = reverb.Table(
      name=adders.DEFAULT_PRIORITY_TABLE,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1))
  self._server = reverb.Server([replay_table], port=None)
  address = f'localhost:{self._server.port}'
  sequence_length = burn_in_length + trace_length + 1

  # Component to add things into replay.
  sequence_kwargs = dict(
      period=replay_period,
      sequence_length=sequence_length,
  )
  adder = adders.SequenceAdder(
      client=reverb.Client(address), **sequence_kwargs)

  # The dataset object to learn from.
  reverb_client = reverb.TFClient(address)
  extra_spec = {
      'core_state': network.initial_state(1),
  }
  # Remove batch dimensions.
  extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
  dataset = datasets.make_reverb_dataset(
      client=reverb_client,
      environment_spec=environment_spec,
      extra_spec=extra_spec,
      sequence_length=sequence_length)

  # Combine with demonstration dataset.
  transition = functools.partial(
      _sequence_from_episode, extra_spec=extra_spec, **sequence_kwargs)
  dataset_demos = demonstration_dataset.map(transition)
  dataset = tf.data.experimental.sample_from_datasets(
      [dataset, dataset_demos],
      [1 - demonstration_ratio, demonstration_ratio])

  # Batch and prefetch.
  dataset = dataset.batch(batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

  tf2_utils.create_variables(network, [environment_spec.observations])
  tf2_utils.create_variables(target_network, [environment_spec.observations])

  learner = learning.R2D2Learner(
      environment_spec=environment_spec,
      network=network,
      target_network=target_network,
      burn_in_length=burn_in_length,
      dataset=dataset,
      reverb_client=reverb_client,
      counter=counter,
      logger=logger,
      sequence_length=sequence_length,
      discount=discount,
      target_update_period=target_update_period,
      importance_sampling_exponent=importance_sampling_exponent,
      max_replay_size=max_replay_size,
      learning_rate=learning_rate,
      store_lstm_state=False,
  )

  self._checkpointer = tf2_savers.Checkpointer(
      subdirectory='r2d2_learner',
      time_delta_minutes=60,
      objects_to_save=learner.state,
      enable_checkpointing=checkpoint,
  )

  self._snapshotter = tf2_savers.Snapshotter(
      objects_to_save={'network': network}, time_delta_minutes=60.)

  policy_network = snt.DeepRNN([
      network,
      lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
  ])

  actor = actors.RecurrentActor(policy_network, adder)
  observations_per_step = (
      float(replay_period * batch_size) / samples_per_insert)
  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=replay_period * max(batch_size, min_replay_size),
      observations_per_step=observations_per_step)
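# A small illustration of the replay/demonstration mixing performed above,
# with an assumed demonstration_ratio of 0.25 (not a default) and stand-in
# datasets: sampling weights [1 - ratio, ratio] mean roughly one in four
# sampled sequences comes from the demonstration dataset.
import tensorflow as tf

replay_data = tf.data.Dataset.from_tensors(0).repeat()  # Stand-in for replay.
demo_data = tf.data.Dataset.from_tensors(1).repeat()    # Stand-in for demos.
demonstration_ratio = 0.25
mixed = tf.data.experimental.sample_from_datasets(
    [replay_data, demo_data],
    [1 - demonstration_ratio, demonstration_ratio])
# Elements drawn from `mixed` are 1 (a demo element) with probability ~0.25.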
def __init__(
    self,
    network: snt.Module,
    target_network: snt.Module,
    discount: float,
    importance_sampling_exponent: float,
    learning_rate: float,
    target_update_period: int,
    dataset: tf.data.Dataset,
    huber_loss_parameter: float = 1.,
    replay_client: reverb.TFClient = None,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
):
  """Initializes the learner.

  Args:
    network: the online Q network (the one being optimized).
    target_network: the target Q network (which lags behind the online net).
    discount: discount to use for TD updates.
    importance_sampling_exponent: power to which importance weights are
      raised before normalizing.
    learning_rate: learning rate for the q-network update.
    target_update_period: number of learner steps to perform before updating
      the target networks.
    dataset: dataset to learn from, whether fixed or from a replay buffer
      (see `acme.datasets.reverb.make_dataset` documentation).
    huber_loss_parameter: quadratic-linear boundary for the Huber loss.
    replay_client: client used to update replay priorities.
    counter: Counter object for (potentially distributed) counting.
    logger: Logger object for writing logs to.
  """

  # Internalise agent components (replay buffer, networks, optimizer).
  # TODO(b/155086959): Fix type stubs and remove.
  self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
  self._network = network
  self._target_network = target_network
  self._optimizer = snt.optimizers.Adam(learning_rate)
  self._replay_client = replay_client

  # Internalise the hyperparameters.
  self._discount = discount
  self._target_update_period = target_update_period
  self._importance_sampling_exponent = importance_sampling_exponent
  self._huber_loss_parameter = huber_loss_parameter

  # Learner state.
  self._variables: List[List[tf.Tensor]] = [network.trainable_variables]
  self._num_steps = tf.Variable(0, dtype=tf.int32)

  # Internalise logging/counting objects.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

  self._snapshotter = tf2_savers.Snapshotter(
      objects_to_save={'network': network}, time_delta_minutes=60.)

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None