def test_snapshot_distribution(self):
  """Test that the snapshotter correctly saves/restores snapshots."""
  # Create a test network.
  net1 = snt.Sequential([
      networks.LayerNormMLP([10, 10]),
      networks.MultivariateNormalDiagHead(1),
  ])
  spec = specs.Array([10], dtype=np.float32)
  tf2_utils.create_variables(net1, [spec])

  # Save the test network.
  directory = self.get_tempdir()
  objects_to_save = {'net': net1}
  snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory)
  snapshotter.save()

  # Reload the test network.
  net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))
  inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

  with tf.GradientTape() as tape:
    dist1 = net1(inputs)
    loss1 = tf.math.reduce_sum(dist1.mean() + dist1.variance())
    grads1 = tape.gradient(loss1, net1.trainable_variables)

  with tf.GradientTape() as tape:
    dist2 = net2(inputs)
    loss2 = tf.math.reduce_sum(dist2.mean() + dist2.variance())
    grads2 = tape.gradient(loss2, net2.trainable_variables)

  assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))

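# A minimal usage sketch of the save/load round trip exercised by the test
# above, outside of a test harness. Only constructor arguments shown in the
# test are relied on; the network, directory, and import paths (which follow
# Acme's conventions) are assumptions here.
import os

import numpy as np
import sonnet as snt
import tensorflow as tf

from acme import specs
from acme.tf import savers as tf2_savers
from acme.tf import utils as tf2_utils

net = snt.nets.MLP([10, 2])
spec = specs.Array([10], dtype=np.float32)
tf2_utils.create_variables(net, [spec])

snapshotter = tf2_savers.Snapshotter({'policy': net}, directory='/tmp/snaps')
snapshotter.save()

# The reloaded object is a plain SavedModel callable, usable without Sonnet.
policy = tf.saved_model.load(os.path.join(snapshotter.directory, 'policy'))
output = policy(tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec)))
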
def select_action(self, observation: types.NestedArray) -> types.NestedArray:
  # Add a dummy batch dimension and as a side effect convert numpy to TF.
  batched_obs = tf2_utils.add_batch_dim(observation)

  # Initialize the RNN state if necessary.
  if self._state is None:
    self._state = self._network.initial_state(1)

  # Forward.
  policy_output, new_state = self._policy(batched_obs, self._state)

  # If the policy network parameterises a distribution, sample from it.
  def maybe_sample(output):
    if isinstance(output, tfd.Distribution):
      output = output.sample()
    return output

  policy_output = tree.map_structure(maybe_sample, policy_output)

  self._prev_state = self._state
  self._state = new_state

  # Convert to numpy and squeeze out the batch dimension.
  action = tf2_utils.to_numpy_squeeze(policy_output)

  return action

def select_action(self, observation: types.NestedArray) -> types.NestedArray:
  # Add a dummy batch dimension and as a side effect convert numpy to TF.
  batched_obs = tf2_utils.add_batch_dim(observation)

  # Initialize the RNN state if necessary.
  if self._state is None:
    self._state = self._network.initial_state(1)

  # Forward.
  (logits, _), new_state = self._policy(batched_obs, self._state)

  self._prev_logits = logits
  self._prev_state = self._state
  self._state = new_state

  # Sample an action from the categorical policy and squeeze out the batch
  # dimension.
  action = tfd.Categorical(logits).sample()
  action = tf2_utils.to_numpy_squeeze(action)

  return action

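# A tiny standalone sketch (toy logits; not part of the actor above) of the
# sampling step: tfd.Categorical takes logits as its first positional
# argument, so sample() draws one integer action per batch row in proportion
# to softmax(logits).
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

logits = tf.constant([[2.0, 0.0, -2.0]])   # batch of one, three actions
action = tfd.Categorical(logits).sample()  # shape [1] int tensor, e.g. [0]
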
def select_action(self, observation: types.NestedArray) -> types.NestedArray:
  # Add a dummy batch dimension and as a side effect convert numpy to TF.
  batched_obs = tf2_utils.add_batch_dim(observation)

  # Forward the policy network.
  policy_output = self._policy_network(batched_obs)

  # If the policy network parameterises a distribution, sample from it.
  def maybe_sample(output):
    if isinstance(output, tfd.Distribution):
      output = output.sample()
    return output

  policy_output = tree.map_structure(maybe_sample, policy_output)

  # Convert to numpy and squeeze out the batch dimension.
  action = tf2_utils.to_numpy_squeeze(policy_output)

  return action

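# A minimal standalone sketch of the maybe_sample pattern used in two of the
# select_action variants above: tree.map_structure applies maybe_sample to
# every leaf, so distribution leaves get sampled while plain tensors pass
# through untouched. The toy output structure below is an illustrative
# assumption.
import tensorflow as tf
import tensorflow_probability as tfp
import tree

tfd = tfp.distributions

def maybe_sample(output):
  if isinstance(output, tfd.Distribution):
    output = output.sample()
  return output

mixed_output = {
    'action': tfd.Normal(loc=tf.zeros([1, 2]), scale=tf.ones([1, 2])),
    'log_std': tf.zeros([1, 2]),
}
sampled = tree.map_structure(maybe_sample, mixed_output)
# sampled['action'] is a [1, 2] sample drawn from the Normal;
# sampled['log_std'] is returned unchanged.
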
def step(self, action: types.Action):
  # The model cannot be stepped until it has been reset with an initial
  # timestep.
  if self._needs_reset:
    raise ValueError('Model must be reset with an initial timestep.')

  # Step the model.
  state, action = tf2_utils.add_batch_dim([self._state, action])
  new_state, reward, discount_logits = [
      x.numpy().squeeze(axis=0) for x in self._forward(state, action)
  ]
  discount = special.softmax(discount_logits)

  # Save the resulting state for the next step.
  self._state = new_state

  # We threshold discount on a given tolerance.
  if discount < self._terminal_tol:
    self._needs_reset = True
    return dm_env.termination(reward=reward, observation=self._state.copy())
  return dm_env.transition(reward=reward, observation=self._state.copy())

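# A toy, standalone illustration of the termination check in step() above,
# with made-up numbers: softmax maps the model's discount logits to
# probabilities, and the episode is treated as terminal once the continue
# probability drops below the tolerance. The binary-logit layout and which
# slot means "continue" are assumptions for illustration only.
import numpy as np
from scipy import special

terminal_tol = 1e-2                             # hypothetical tolerance
discount_logits = np.array([4.0, -4.0])         # strongly favours slot 0
discount = special.softmax(discount_logits)[1]  # assumed "continue" slot
needs_reset = discount < terminal_tol           # True: episode terminates
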
def test_rnn_snapshot(self):
  """Test that the snapshotter correctly saves/restores snapshots of RNNs."""
  # Create a test network.
  net = snt.LSTM(10)
  spec = specs.Array([10], dtype=np.float32)
  tf2_utils.create_variables(net, [spec])

  # Test that adding some postprocessing without rerunning create_variables
  # still works.
  wrapped_net = snt.DeepRNN([net, lambda x: x])

  for net1 in [net, wrapped_net]:
    # Save the test network.
    directory = self.get_tempdir()
    objects_to_save = {'net': net1}
    snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory)
    snapshotter.save()

    # Reload the test network.
    net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))
    inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

    with tf.GradientTape() as tape:
      outputs1, next_state1 = net1(inputs, net1.initial_state(1))
      loss1 = tf.math.reduce_sum(outputs1)
      grads1 = tape.gradient(loss1, net1.trainable_variables)

    with tf.GradientTape() as tape:
      outputs2, next_state2 = net2(inputs, net2.initial_state(1))
      loss2 = tf.math.reduce_sum(outputs2)
      grads2 = tape.gradient(loss2, net2.trainable_variables)

    assert np.allclose(outputs1, outputs2)
    assert np.allclose(tree.flatten(next_state1), tree.flatten(next_state2))
    assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))