Example #1
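Snapshotting a Sonnet network that outputs a distribution, reloading it with tf.saved_model.load, and checking that gradients computed through the original and the reloaded copy agree.
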
    def test_snapshot_distribution(self):
        """Test that snapshotter correctly calls saves/restores snapshots."""
        # Create a test network.
        net1 = snt.Sequential([
            networks.LayerNormMLP([10, 10]),
            networks.MultivariateNormalDiagHead(1)
        ])
        spec = specs.Array([10], dtype=np.float32)
        tf2_utils.create_variables(net1, [spec])

        # Save the test network.
        directory = self.get_tempdir()
        objects_to_save = {'net': net1}
        snapshotter = tf2_savers.Snapshotter(objects_to_save,
                                             directory=directory)
        snapshotter.save()

        # Reload the test network.
        net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))
        inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

        with tf.GradientTape() as tape:
            dist1 = net1(inputs)
            loss1 = tf.math.reduce_sum(dist1.mean() + dist1.variance())
            grads1 = tape.gradient(loss1, net1.trainable_variables)

        with tf.GradientTape() as tape:
            dist2 = net2(inputs)
            loss2 = tf.math.reduce_sum(dist2.mean() + dist2.variance())
            grads2 = tape.gradient(loss2, net2.trainable_variables)

        assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))
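
Note that tf2_utils.create_variables(net1, [spec]) is what builds the network's variables before the snapshot is taken: Sonnet modules create their variables lazily on first call, so the helper presumably runs a dummy forward pass constructed from the input spec.
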
Example #2
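A recurrent actor's select_action: the observation gets a batch dimension, the RNN state is initialized lazily, the policy is run forward, and the output is sampled if it is a distribution before being converted back to numpy.
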
  def select_action(self, observation: types.NestedArray) -> types.NestedArray:
    # Add a dummy batch dimension and, as a side effect, convert numpy to TF.
    batched_obs = tf2_utils.add_batch_dim(observation)

    # Initialize the RNN state if necessary.
    if self._state is None:
      self._state = self._network.initial_state(1)

    # Forward.
    policy_output, new_state = self._policy(batched_obs, self._state)

    # If the policy network parameterises a distribution, sample from it.
    def maybe_sample(output):
      if isinstance(output, tfd.Distribution):
        output = output.sample()
      return output

    policy_output = tree.map_structure(maybe_sample, policy_output)

    self._prev_state = self._state
    self._state = new_state

    # Convert to numpy and squeeze out the batch dimension.
    action = tf2_utils.to_numpy_squeeze(policy_output)

    return action
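
The batch-dimension round trip used in select_action can be shown in isolation. A minimal sketch, assuming tf2_utils refers to acme's acme.tf.utils helpers as used above:

import numpy as np
from acme.tf import utils as tf2_utils

# A single unbatched observation.
observation = np.zeros([10], dtype=np.float32)

# add_batch_dim converts to TF and prepends a batch axis of size 1.
batched_obs = tf2_utils.add_batch_dim(observation)      # shape [1, 10]

# to_numpy_squeeze converts back to numpy and removes the batch axis.
squeezed = tf2_utils.to_numpy_squeeze(batched_obs)      # shape [10]
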
Example #3
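A recurrent actor's select_action for a categorical policy: the policy returns logits (the second output is ignored) together with a new RNN state, and the action is sampled from tfd.Categorical.
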
    def select_action(self,
                      observation: types.NestedArray) -> types.NestedArray:
        # Add a dummy batch dimension and, as a side effect, convert numpy to TF.
        batched_obs = tf2_utils.add_batch_dim(observation)

        if self._state is None:
            self._state = self._network.initial_state(1)

        # Forward.
        (logits, _), new_state = self._policy(batched_obs, self._state)

        self._prev_logits = logits
        self._prev_state = self._state
        self._state = new_state

        action = tfd.Categorical(logits).sample()
        action = tf2_utils.to_numpy_squeeze(action)

        return action
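
The sampling step in isolation, as a self-contained sketch (tfd is assumed to be TensorFlow Probability's distributions module, and the logits are illustrative):

import tensorflow_probability as tfp

tfd = tfp.distributions

# A batch of one observation over three discrete actions; the first positional
# argument of tfd.Categorical is the logits.
logits = [[0.1, 2.0, -1.0]]
action = tfd.Categorical(logits).sample()  # integer tensor of shape [1]
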
Example #4
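A feed-forward actor's select_action: batch the observation, run the policy network, sample if the output is a distribution, then convert back to numpy and squeeze out the batch dimension.
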
  def select_action(self, observation: types.NestedArray) -> types.NestedArray:
    # Add a dummy batch dimension and, as a side effect, convert numpy to TF.
    batched_obs = tf2_utils.add_batch_dim(observation)

    # Forward the policy network.
    policy_output = self._policy_network(batched_obs)

    # If the policy network parameterises a distribution, sample from it.
    def maybe_sample(output):
      if isinstance(output, tfd.Distribution):
        output = output.sample()
      return output

    policy_output = tree.map_structure(maybe_sample, policy_output)

    # Convert to numpy and squeeze out the batch dimension.
    action = tf2_utils.to_numpy_squeeze(policy_output)

    return action
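
tree.map_structure applies maybe_sample to every leaf of the (possibly nested) policy output, so the same code handles plain tensors, distributions, or nested containers of either. A self-contained sketch of the pattern with dm-tree:

import tree  # dm-tree

nested = {'a': 1, 'b': (2, 3)}
doubled = tree.map_structure(lambda x: 2 * x, nested)
assert doubled == {'a': 2, 'b': (4, 6)}  # structure is preserved, leaves mapped
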
Example #5
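A learned environment model's step: the current state and the action are batched and passed through a forward model that predicts the next state, the reward, and discount logits; the discount is then thresholded to decide whether the episode terminates.
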
  def step(self, action: types.Action):
    # The model must be explicitly reset before it can be stepped.
    if self._needs_reset:
      raise ValueError('Model must be reset with an initial timestep.')

    # Step the model.
    state, action = tf2_utils.add_batch_dim([self._state, action])
    new_state, reward, discount_logits = [
        x.numpy().squeeze(axis=0) for x in self._forward(state, action)
    ]
    discount = special.softmax(discount_logits)

    # Save the resulting state for the next step.
    self._state = new_state

    # Threshold the discount against the terminal tolerance.
    if discount < self._terminal_tol:
      self._needs_reset = True
      return dm_env.termination(reward=reward, observation=self._state.copy())
    return dm_env.transition(reward=reward, observation=self._state.copy())
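
The discount head outputs logits that are converted to probabilities before thresholding; special here presumably refers to scipy.special, as in this sketch:

import numpy as np
from scipy import special

discount_logits = np.array([0.1, 1.5])       # illustrative logits
discount = special.softmax(discount_logits)  # non-negative, sums to 1
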
Example #6
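Snapshotting an LSTM (and a DeepRNN-wrapped variant) and checking that the reloaded SavedModel reproduces the original outputs, next states, and gradients.
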
    def test_rnn_snapshot(self):
        """Test that snapshotter correctly calls saves/restores snapshots on RNNs."""
        # Create a test network.
        net = snt.LSTM(10)
        spec = specs.Array([10], dtype=np.float32)
        tf2_utils.create_variables(net, [spec])

        # Test that if you add some postprocessing without rerunning
        # create_variables, it still works.
        wrapped_net = snt.DeepRNN([net, lambda x: x])

        for net1 in [net, wrapped_net]:
            # Save the test network.
            directory = self.get_tempdir()
            objects_to_save = {'net': net1}
            snapshotter = tf2_savers.Snapshotter(objects_to_save,
                                                 directory=directory)
            snapshotter.save()

            # Reload the test network.
            net2 = tf.saved_model.load(
                os.path.join(snapshotter.directory, 'net'))
            inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

            with tf.GradientTape() as tape:
                outputs1, next_state1 = net1(inputs, net1.initial_state(1))
                loss1 = tf.math.reduce_sum(outputs1)
                grads1 = tape.gradient(loss1, net1.trainable_variables)

            with tf.GradientTape() as tape:
                outputs2, next_state2 = net2(inputs, net2.initial_state(1))
                loss2 = tf.math.reduce_sum(outputs2)
                grads2 = tape.gradient(loss2, net2.trainable_variables)

            assert np.allclose(outputs1, outputs2)
            assert np.allclose(tree.flatten(next_state1),
                               tree.flatten(next_state2))
            assert all(
                tree.map_structure(np.allclose, list(grads1), list(grads2)))