def create_variables(
    network: snt.Module,
    input_spec: List[OLT],
) -> Optional[tf.TensorSpec]:
    """Runs `network` once on placeholder inputs so that its variables exist.

    As a side effect the batched input signature used for the call is stored
    on the network as `_input_signature`, so that snapshotting code can pick
    it up later without knowing the precise form of the inputs.

    Args:
        network: Sonnet Module whose variables are to be created.
        input_spec: list of input specs to the network, one per argument
            expected by `network`. Each entry must expose `observation`,
            `legal_actions` and `terminal` attributes.

    Returns:
        A structure mirroring the network output in which every tf.Tensor
        leaf is replaced by its tf.TensorSpec and every other leaf (e.g. a
        tfp.distributions.Distribution) is replaced by None.
    """
    # One unbatched placeholder OLT per spec: zeros for the observation and
    # terminal fields, ones for the legal-action field.
    placeholders = []
    for entry in input_spec:
        placeholders.append(
            OLT(
                observation=zeros_like(entry.observation),
                legal_actions=ones_like(entry.legal_actions),
                terminal=zeros_like(entry.terminal),
            )
        )

    # A recurrent core additionally receives its hidden state as the last
    # argument; take the batch-size-1 initial state and drop the batch dim so
    # that it matches the other (unbatched) placeholders.
    if isinstance(network, snt.RNNCore):
        placeholders.append(squeeze_batch_dim(network.initial_state(1)))

    # The forward pass is what creates the variables.
    batched_inputs = add_batch_dim(placeholders)
    network_output = network(*batched_inputs)

    def _batched_spec(tensor):
        # Same shape and dtype as the placeholder, with a leading batch
        # dimension of unknown size.
        return tf.TensorSpec((None,) + tensor.shape, tensor.dtype)

    # Stored on the network because the snapshot code may not have access to
    # the precise form of the inputs.
    signature = tree.map_structure(_batched_spec, placeholders)
    network._input_signature = signature  # pylint: disable=protected-access

    def _output_spec(value):
        if isinstance(value, tf.Tensor):
            # Only non-scalar tensors carry a batch dimension to squeeze out.
            unbatched = squeeze_batch_dim(value) if tf.rank(value) > 0 else value
            return tf.TensorSpec(unbatched.shape, unbatched.dtype)
        # A spec is ill-defined for anything that is not a Tensor.
        return None

    return tree.map_structure(_output_spec, network_output)
def test_rnn_snapshot(self):
    """Checks that a snapshotted RNN reloads with matching outputs and grads.

    Both a bare LSTM and the same LSTM wrapped in a DeepRNN are saved with the
    snapshotter and reloaded; outputs, next states and gradients of the
    reloaded network must be close to those of the network that was saved.
    """
    # Build the recurrent core and create its variables from an input spec.
    lstm = snt.LSTM(10)
    spec = specs.Array([10], dtype=np.float32)
    tf2_utils.create_variables(lstm, [spec])

    # Postprocessing is added afterwards WITHOUT rerunning create_variables;
    # snapshotting the wrapped network is expected to work regardless.
    wrapped_lstm = snt.DeepRNN([lstm, lambda x: x])

    def forward_and_grads(module, batch):
        # One step from the batch-size-1 initial state, with the summed
        # output used as a scalar loss for the gradient comparison.
        with tf.GradientTape() as tape:
            outputs, next_state = module(batch, module.initial_state(1))
            loss = tf.math.reduce_sum(outputs)
            grads = tape.gradient(loss, module.trainable_variables)
        return outputs, next_state, grads

    for saved_net in (lstm, wrapped_lstm):
        # Save the network under the 'net' key.
        target_dir = self.get_tempdir()
        snapshotter = tf2_savers.Snapshotter(
            {'net': saved_net}, directory=target_dir)
        snapshotter.save()

        # Load it back from the snapshotter's directory.
        loaded_net = tf.saved_model.load(
            os.path.join(snapshotter.directory, 'net'))
        inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

        outputs1, next_state1, grads1 = forward_and_grads(saved_net, inputs)
        outputs2, next_state2, grads2 = forward_and_grads(loaded_net, inputs)

        assert np.allclose(outputs1, outputs2)
        assert np.allclose(tree.flatten(next_state1), tree.flatten(next_state2))
        assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))
def test_snapshot_distribution(self):
    """Checks that a distribution-valued network survives a snapshot round trip.

    The network is saved with the snapshotter and reloaded; gradients of a
    loss built from the output distribution's mean and variance must be close
    between the saved and the reloaded network.
    """
    # Network whose final layer is a MultivariateNormalDiagHead.
    net1 = snt.Sequential([
        networks.LayerNormMLP([10, 10]),
        networks.MultivariateNormalDiagHead(1),
    ])
    spec = specs.Array([10], dtype=np.float32)
    tf2_utils.create_variables(net1, [spec])

    # Save the network under the 'net' key.
    target_dir = self.get_tempdir()
    snapshotter = tf2_savers.Snapshotter({'net': net1}, directory=target_dir)
    snapshotter.save()

    # Load it back from the snapshotter's directory.
    net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))
    inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

    def loss_gradients(module):
        # Scalar loss from the distribution's mean and variance, differentiated
        # with respect to the module's trainable variables.
        with tf.GradientTape() as tape:
            dist = module(inputs)
            loss = tf.math.reduce_sum(dist.mean() + dist.variance())
            grads = tape.gradient(loss, module.trainable_variables)
        return grads

    grads1 = loss_gradients(net1)
    grads2 = loss_gradients(net2)

    assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))