Example #1
0
    def test_snapshot_distribution(self):
        """Test that snapshotter correctly calls saves/restores snapshots.

        Builds a small distribution-valued network, snapshots it with
        `tf2_savers.Snapshotter`, reloads it with `tf.saved_model.load`, and
        checks that the original and reloaded networks give the same loss
        (sum of the output distribution's mean and variance) and the same
        gradients w.r.t. their trainable variables on an all-zeros input.
        """

        def assert_close(actual, desired):
            # Same tolerances as the `np.allclose` defaults used previously.
            np.testing.assert_allclose(actual, desired, rtol=1e-5, atol=1e-8)

        # Create a test network.
        net1 = snt.Sequential([
            networks.LayerNormMLP([10, 10]),
            networks.MultivariateNormalDiagHead(1)
        ])
        spec = specs.Array([10], dtype=np.float32)
        tf2_utils.create_variables(net1, [spec])

        # Save the test network.
        directory = self.get_tempdir()
        objects_to_save = {'net': net1}
        snapshotter = tf2_savers.Snapshotter(objects_to_save,
                                             directory=directory)
        snapshotter.save()

        # Reload the test network.
        net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))
        inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

        # Record only the forward pass on each tape; gradients are taken after
        # the context exits so the gradient computation is not recorded.
        with tf.GradientTape() as tape:
            dist1 = net1(inputs)
            loss1 = tf.math.reduce_sum(dist1.mean() + dist1.variance())
        grads1 = tape.gradient(loss1, net1.trainable_variables)

        with tf.GradientTape() as tape:
            dist2 = net2(inputs)
            loss2 = tf.math.reduce_sum(dist2.mean() + dist2.variance())
        grads2 = tape.gradient(loss2, net2.trainable_variables)

        # Use unittest/numpy assertions instead of bare `assert` so the checks
        # survive `python -O` and report exactly which values differ.
        assert_close(loss1, loss2)
        # NOTE(review): assumes both variable lists enumerate variables in the
        # same order, as the original element-wise comparison did.
        self.assertEqual(len(grads1), len(grads2))
        for grad1, grad2 in zip(grads1, grads2):
            assert_close(grad1, grad2)
Example #2
0
    def test_rnn_snapshot(self):
        """Test that snapshotter saves/restores snapshots of RNNs.

        Snapshots both a bare `snt.LSTM` and the same core wrapped in an
        `snt.DeepRNN` with an identity postprocessing step. Each one is
        reloaded with `tf.saved_model.load`, and the test checks that the
        original and reloaded cores agree on outputs, next state, and
        gradients of the summed outputs w.r.t. their trainable variables.
        """

        def assert_close(actual, desired):
            # Same tolerances as the `np.allclose` defaults used previously.
            np.testing.assert_allclose(actual, desired, rtol=1e-5, atol=1e-8)

        # Create a test network.
        net = snt.LSTM(10)
        spec = specs.Array([10], dtype=np.float32)
        tf2_utils.create_variables(net, [spec])

        # Test that if you add some postprocessing without rerunning
        # create_variables, it still works.
        wrapped_net = snt.DeepRNN([net, lambda x: x])

        # The batched all-zeros input does not depend on the network under
        # test, so build it once outside the loop.
        inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

        for net1 in [net, wrapped_net]:
            # Save the test network.
            directory = self.get_tempdir()
            objects_to_save = {'net': net1}
            snapshotter = tf2_savers.Snapshotter(objects_to_save,
                                                 directory=directory)
            snapshotter.save()

            # Reload the test network.
            net2 = tf.saved_model.load(
                os.path.join(snapshotter.directory, 'net'))

            # Record only the forward pass on each tape; gradients are taken
            # after the context exits so the gradient computation is not
            # recorded.
            with tf.GradientTape() as tape:
                outputs1, next_state1 = net1(inputs, net1.initial_state(1))
                loss1 = tf.math.reduce_sum(outputs1)
            grads1 = tape.gradient(loss1, net1.trainable_variables)

            with tf.GradientTape() as tape:
                outputs2, next_state2 = net2(inputs, net2.initial_state(1))
                loss2 = tf.math.reduce_sum(outputs2)
            grads2 = tape.gradient(loss2, net2.trainable_variables)

            # Use unittest/numpy assertions instead of bare `assert` so the
            # checks survive `python -O` and report which values differ.
            assert_close(outputs1, outputs2)

            states1 = tree.flatten(next_state1)
            states2 = tree.flatten(next_state2)
            self.assertEqual(len(states1), len(states2))
            for state1, state2 in zip(states1, states2):
                assert_close(state1, state2)

            # NOTE(review): assumes both variable lists enumerate variables in
            # the same order, as the original element-wise comparison did.
            self.assertEqual(len(grads1), len(grads2))
            for grad1, grad2 in zip(grads1, grads2):
                assert_close(grad1, grad2)