Exemple #1
0
def create_variables(
    network: snt.Module,
    input_spec: List[OLT],
) -> Optional[tf.TensorSpec]:
    """Runs `network` once on dummy inputs so that its variables get created.

    As a side effect, the input signature derived from the dummy inputs is
    stored on the network as `_input_signature`.

    Args:
      network: Sonnet Module whose variables are to be created.
      input_spec: list of input specs to the network, one per argument expected
        by `network`. Each spec must expose `observation`, `legal_actions` and
        `terminal` attributes.

    Returns:
      The structure of the dummy output in which every tf.Tensor leaf is
      replaced by its tf.TensorSpec and every other leaf (e.g. a
      tfp.distributions.Distribution) is replaced by None.
    """
    # Assemble one dummy (unbatched) OLT per input spec: observation and
    # terminal come from zeros_like, legal_actions from ones_like.
    dummy_input = []
    for in_spec in input_spec:
        placeholder = OLT(
            observation=zeros_like(in_spec.observation),
            legal_actions=ones_like(in_spec.legal_actions),
            terminal=zeros_like(in_spec.terminal),
        )
        dummy_input.append(placeholder)

    # An RNNCore additionally takes its hidden state as the last input.
    if isinstance(network, snt.RNNCore):
        dummy_input.append(squeeze_batch_dim(network.initial_state(1)))

    # Calling the network is what actually creates the variables.
    dummy_output = network(*add_batch_dim(dummy_input))

    def _batched_spec(t):
        # Leave the leading (batch) dimension unspecified.
        return tf.TensorSpec((None, ) + t.shape, t.dtype)

    # Remember the input signature on the network itself: snapshotting code
    # needs it later and may not know the precise form of the inputs.
    network._input_signature = tree.map_structure(  # pylint: disable=protected-access
        _batched_spec, dummy_input)

    def _output_spec(output: tf.Tensor) -> Optional[tf.TensorSpec]:
        if isinstance(output, tf.Tensor):
            # Only non-scalar tensors have their batch dim squeezed out.
            if tf.rank(output) > 0:
                output = squeeze_batch_dim(output)
            return tf.TensorSpec(output.shape, output.dtype)
        # A spec is ill-defined for anything that is not a Tensor.
        return None

    return tree.map_structure(_output_spec, dummy_output)
Exemple #2
0
  def test_rnn_snapshot(self):
    """Checks that RNN snapshots survive a save / reload round trip.

    Both a bare LSTM and a DeepRNN wrapping it are snapshotted, reloaded with
    `tf.saved_model.load`, and compared with the live network on outputs, next
    states and gradients.
    """
    # Build the recurrent core under test and create its variables.
    lstm = snt.LSTM(10)
    spec = specs.Array([10], dtype=np.float32)
    tf2_utils.create_variables(lstm, [spec])

    # Adding post-processing around the core, without calling
    # create_variables again, must still work.
    wrapped_lstm = snt.DeepRNN([lstm, lambda x: x])

    def step_with_grads(model, batch):
      # One step from the initial state, plus gradients of the summed outputs
      # with respect to the model's trainable variables.
      with tf.GradientTape() as tape:
        outputs, next_state = model(batch, model.initial_state(1))
        loss = tf.math.reduce_sum(outputs)
        grads = tape.gradient(loss, model.trainable_variables)
      return outputs, next_state, grads

    for live_net in (lstm, wrapped_lstm):
      # Snapshot the live network into a fresh temporary directory.
      snapshotter = tf2_savers.Snapshotter(
          {'net': live_net}, directory=self.get_tempdir())
      snapshotter.save()

      # Load the snapshot back and build a batched all-zeros input.
      saved_path = os.path.join(snapshotter.directory, 'net')
      restored_net = tf.saved_model.load(saved_path)
      inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

      outputs1, next_state1, grads1 = step_with_grads(live_net, inputs)
      outputs2, next_state2, grads2 = step_with_grads(restored_net, inputs)

      assert np.allclose(outputs1, outputs2)
      assert np.allclose(tree.flatten(next_state1), tree.flatten(next_state2))
      assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))
Exemple #3
0
  def test_snapshot_distribution(self):
    """Checks save / reload of a network whose output is a distribution.

    The live and reloaded networks are compared through the gradients of a
    loss built from the mean and variance of the distribution they return.
    """
    # Build a network ending in a distribution head and create its variables.
    torso = networks.LayerNormMLP([10, 10])
    head = networks.MultivariateNormalDiagHead(1)
    net1 = snt.Sequential([torso, head])
    spec = specs.Array([10], dtype=np.float32)
    tf2_utils.create_variables(net1, [spec])

    # Snapshot the network into a temporary directory.
    snapshotter = tf2_savers.Snapshotter(
        {'net': net1}, directory=self.get_tempdir())
    snapshotter.save()

    # Load the snapshot back and build a batched all-zeros input.
    saved_path = os.path.join(snapshotter.directory, 'net')
    net2 = tf.saved_model.load(saved_path)
    inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

    def loss_gradients(model):
      # Gradients of sum(mean + variance) of the output distribution with
      # respect to the model's trainable variables.
      with tf.GradientTape() as tape:
        dist = model(inputs)
        loss = tf.math.reduce_sum(dist.mean() + dist.variance())
        grads = tape.gradient(loss, model.trainable_variables)
      return grads

    grads1 = loss_gradients(net1)
    grads2 = loss_gradients(net2)

    assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))