def testRequiredConstructorArgs(self):
    """RNNWrapper must reject an LSTM lacking return_state/return_sequences."""
    # Each case pairs the expected error fragment with the LSTM keyword
    # arguments that are supposed to trigger it.
    rejected_configs = (
        ('with return_state==False', dict(return_sequences=True)),
        ('with return_sequences==False', dict(return_state=True)),
    )
    for expected_message, lstm_kwargs in rejected_configs:
      with self.assertRaisesRegex(NotImplementedError, expected_message):
        rnn_wrapper.RNNWrapper(tf.keras.layers.LSTM(3, **lstm_kwargs))
  def testWrapperCall(self):
    """Calls the wrapper with/without a time axis and with explicit state.

    All three call styles must produce outputs and next-states of the same
    shape and (numerically close) the same values.
    """
    num_units = 3
    batch_size = 2
    input_depth = 5
    expected_shape = (batch_size, num_units)

    lstm = tf.keras.layers.LSTM(
        num_units, return_state=True, return_sequences=True)
    wrapper = rnn_wrapper.RNNWrapper(lstm)

    inputs = np.random.rand(batch_size, input_depth).astype(np.float32)

    # Variant 1: the input carries no time dimension at all.
    outputs, next_state = wrapper(inputs)

    # Variant 2: same input with an explicit time axis inserted at position 1;
    # that axis is squeezed back out of the outputs before comparing.
    timed_outputs, timed_next_state = wrapper(tf.expand_dims(inputs, axis=1))
    timed_outputs = tf.squeeze(timed_outputs, axis=1)

    # Variant 3: no time dimension, but the initial state is passed in
    # explicitly instead of being left to the wrapper.
    initial_state = wrapper.get_initial_state(inputs)
    manual_outputs, manual_next_state = wrapper(inputs, initial_state)

    self.evaluate(tf.compat.v1.global_variables_initializer())

    all_outputs = (outputs, timed_outputs, manual_outputs)
    all_states = (next_state, timed_next_state, manual_next_state)

    for candidate_output in all_outputs:
      self.assertEqual(candidate_output.shape, expected_shape)
    for candidate_state in all_states:
      self.assertLen(candidate_state, 2)
      for index in range(2):
        self.assertEqual(candidate_state[index].shape, expected_shape)

    # The first variant is the reference; the other two must agree with it.
    for reference, alternatives in ((outputs, all_outputs[1:]),
                                    (next_state, all_states[1:])):
      for alternative in alternatives:
        self.assertAllClose(reference, alternative)
예제 #3
0
    def __init__(self,
                 layers: typing.Sequence[tf.keras.layers.Layer],
                 input_spec: typing.Optional[types.NestedTensorSpec] = None,
                 name: typing.Optional[typing.Text] = None):
        """Create a Sequential Network.

        Args:
          layers: A list or tuple of layers to compose.  Any layers that
            are subclasses of `tf.keras.layers.{RNN,LSTM,GRU,...}` are
            wrapped in `tf_agents.keras_layers.RNNWrapper`.
          input_spec: (Optional.) A nest of `tf.TypeSpec` representing the
            input observations.
          name: (Optional.) Network name.

        Raises:
          ValueError: If `layers` is empty.
          ValueError: If `layers[0]` is a generic Keras layer (not a TF-Agents
            network) and `input_spec is None`.
          TypeError: If any of the layers are not instances of keras `Layer`.
          RuntimeError: If not `tf.executing_eagerly()`; as this is required to
            be able to create deep copies of layers in `layers`.
        """
        if not tf.executing_eagerly():
            raise RuntimeError(
                'Not executing eagerly - cannot make deep copies of `layers`.')
        if not layers:
            raise ValueError(
                '`layers` must not be empty; saw: {}'.format(layers))
        for layer in layers:
            if not isinstance(layer, tf.keras.layers.Layer):
                raise TypeError(
                    'Expected all layers to be instances of keras Layer, but saw'
                    ': \'{}\''.format(layer))

        # Plain Keras RNN layers are wrapped; every other layer is kept as is.
        layers = [
            rnn_wrapper.RNNWrapper(layer) if isinstance(
                layer, tf.keras.layers.RNN) else layer for layer in layers
        ]

        state_spec = _infer_state_specs(layers)

        def _is_empty_spec(spec):
            """Return True iff `spec` is exactly the empty tuple `()`."""
            # An exact type check (rather than isinstance) is deliberate: it
            # matches what the previous `spec is ()` identity test accepted on
            # CPython, without relying on the interpreter's empty-tuple
            # singleton or triggering the Python 3.8+ SyntaxWarning for `is`
            # with a literal.  A membership test such as `spec in ((),)` is
            # avoided as well (see b/158804957: tf.function rewrites it into a
            # tensor bool expression).
            return type(spec) is tuple and not spec

        # Now we remove all of the empty state specs so if there are no RNN
        # layers, our state spec is empty.  layer_has_state is a list of bools
        # telling us which layers have a state and which don't.
        layer_has_state = [not _is_empty_spec(s) for s in state_spec]
        state_spec = tuple(s for s in state_spec if not _is_empty_spec(s))
        super(Sequential, self).__init__(input_tensor_spec=input_spec,
                                         state_spec=state_spec,
                                         name=name)
        self._sequential_layers = layers
        self._layer_has_state = layer_has_state
 def testWrapperBuild(self):
   """Builds the wrapper from a rank-2 shape and checks variable bookkeeping."""
   expected_num_variables = 3
   lstm = tf.keras.layers.LSTM(3, return_state=True, return_sequences=True)
   wrapper = rnn_wrapper.RNNWrapper(lstm)

   # build() has to accept an input shape that carries no time dimension.
   wrapper.build((1, 4))
   self.evaluate(tf.compat.v1.global_variables_initializer())

   trainable_values = self.evaluate(wrapper.trainable_variables)
   self.assertLen(trainable_values, expected_num_variables)
   self.assertLen(wrapper.variables, expected_num_variables)
   self.assertTrue(wrapper.trainable)

   # After freezing, no variable is trainable any more, but all of them are
   # still reported through `variables`.
   wrapper.trainable = False
   self.assertFalse(wrapper.trainable)
   self.assertEmpty(wrapper.trainable_variables)
   self.assertLen(wrapper.variables, expected_num_variables)
예제 #5
0
    def __init__(self,
                 layers: Sequence[tf.keras.layers.Layer],
                 input_spec: Optional[types.NestedTensorSpec] = None,
                 name: Optional[Text] = None):
        """Build a network that chains the given layers.

        Args:
          layers: A list or tuple of layers to compose.  Any layer that is an
            instance of `tf.keras.layers.RNN` (which covers LSTM, GRU, ...) is
            wrapped in `tf_agents.keras_layers.RNNWrapper`.
          input_spec: (Optional.) A nest of `tf.TypeSpec` representing the
            input observations to the first layer.
          name: (Optional.) Network name.

        Raises:
          ValueError: If `layers` is empty.
          ValueError: If `layers[0]` is a generic Keras layer (not a TF-Agents
            network) and `input_spec is None`.
          TypeError: If any of the layers are not instances of keras `Layer`.
        """
        if not layers:
            raise ValueError(
                '`layers` must not be empty; saw: {}'.format(layers))
        for candidate in layers:
            if isinstance(candidate, tf.keras.layers.Layer):
                continue
            raise TypeError(
                "Expected all layers to be instances of keras Layer, but saw: "
                "'{}'".format(candidate))

        # Wrap plain Keras RNN layers; everything else is used unchanged.
        wrapped_layers = []
        for candidate in layers:
            if isinstance(candidate, tf.keras.layers.RNN):
                candidate = rnn_wrapper.RNNWrapper(candidate)
            wrapped_layers.append(candidate)

        per_layer_spec, state_is_list = _infer_state_specs(wrapped_layers)
        self._layer_state_is_list = state_is_list

        # A layer counts as stateful only if its state spec flattens to at
        # least one leaf.  Specs of stateless layers are dropped, so a network
        # without any stateful layer ends up with an empty state spec.
        has_state_flags = [
            bool(tf.nest.flatten(spec)) for spec in per_layer_spec
        ]
        nonempty_specs = tuple(
            spec for spec, keep in zip(per_layer_spec, has_state_flags)
            if keep)

        super(Sequential, self).__init__(input_tensor_spec=input_spec,
                                         state_spec=nonempty_specs,
                                         name=name)
        self._sequential_layers = wrapped_layers
        self._layer_has_state = has_state_flags
예제 #6
0
class CreateVariablesTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for `network.create_variables` on a MockNetwork and Keras layers."""

  def testNetworkCreate(self):
    """create_variables builds a MockNetwork and returns its output spec."""
    observation_spec = specs.TensorSpec([1], tf.float32, 'observation')
    action_spec = specs.TensorSpec([2], tf.float32, 'action')
    net = MockNetwork(observation_spec, action_spec)
    self.assertFalse(net.built)
    # Accessing variables before the network is built is expected to fail.
    with self.assertRaises(ValueError):
      net.variables  # pylint: disable=pointless-statement
    output_spec = network.create_variables(net)
    # MockNetwork adds some variables to observation, which has shape [bs, 1]
    self.assertEqual(output_spec, tf.TensorSpec([1], dtype=tf.float32))
    self.assertTrue(net.built)
    self.assertLen(net.variables, 2)
    self.assertLen(net.trainable_variables, 1)

  # Each case below is a positional tuple matching testKerasLayerCreate's
  # arguments: (test name, layer_fn, input_spec, expected_output_spec,
  # expected_state_spec).  Layers are created through a lambda so that every
  # test run gets a fresh layer.
  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      (
          'Dense',
          lambda: tf.keras.layers.Dense(3),
          tf.TensorSpec((5,), tf.float32),  # input_spec
          tf.TensorSpec((3,), tf.float32),  # expected_output_spec
          (),  # expected_state_spec
      ),
      (
          'LSTMCell',
          lambda: tf.keras.layers.LSTMCell(3),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          [tf.TensorSpec((3,), tf.float32),
           tf.TensorSpec((3,), tf.float32)],
      ),
      (
          'LSTMCellInRNN',
          lambda: rnn_wrapper.RNNWrapper(
              tf.keras.layers.RNN(
                  tf.keras.layers.LSTMCell(3),
                  return_state=True,
                  return_sequences=True)
          ),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          [tf.TensorSpec((3,), tf.float32),
           tf.TensorSpec((3,), tf.float32)],
      ),
      (
          'LSTM',
          lambda: rnn_wrapper.RNNWrapper(
              tf.keras.layers.LSTM(
                  3,
                  return_state=True,
                  return_sequences=True)
          ),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          [tf.TensorSpec((3,), tf.float32),
           tf.TensorSpec((3,), tf.float32)],
      ),
      (
          'TimeDistributed',
          lambda: tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3)),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          ()
      ),
      (
          'Conv2D',
          lambda: tf.keras.layers.Conv2D(2, 3),
          tf.TensorSpec((28, 28, 5), tf.float32),
          tf.TensorSpec((26, 26, 2), tf.float32),
          ()
      ),
      (
          'SequentialOfDense',
          lambda: tf.keras.Sequential([tf.keras.layers.Dense(3)] * 2),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          ()
      ),
      (
          'NormalDistribution',
          lambda: tf.keras.Sequential(
              [tf.keras.layers.Dense(3),
               tf.keras.layers.Lambda(
                   lambda x: tfd.Normal(loc=x, scale=x**2))]),
          tf.TensorSpec((5,), tf.float32),
          distribution_utils.DistributionSpecV2(
              event_shape=tf.TensorShape(()),
              dtype=tf.float32,
              parameters=distribution_utils.Params(
                  type_=tfd.Normal,
                  params=dict(
                      loc=tf.TensorSpec((3,), tf.float32),
                      scale=tf.TensorSpec((3,), tf.float32),
                  ))),
          ()
      ),
  )
  # pylint: enable=g-long-lambda
  def testKerasLayerCreate(self, layer_fn, input_spec, expected_output_spec,
                           expected_state_spec):
    """create_variables builds a Keras layer and reports its specs.

    Args:
      layer_fn: Zero-argument callable returning the layer under test.
      input_spec: Spec passed to `network.create_variables` for the layer.
      expected_output_spec: Spec `network.create_variables` should return.
      expected_state_spec: Expected value of the layer's
        `_network_state_spec` attribute after variable creation.
    """
    layer = layer_fn()
    # Without an input_spec the call must be rejected.
    with self.assertRaisesRegex(ValueError, 'an input_spec is required'):
      network.create_variables(layer)
    output_spec = network.create_variables(layer, input_spec)
    self.assertTrue(layer.built)
    self.assertEqual(
        output_spec, expected_output_spec,
        '\n{}\nvs.\n{}\n'.format(output_spec, expected_output_spec))
    # A second call on the already-built layer must return the same spec.
    output_spec_2 = network.create_variables(layer, input_spec)
    self.assertEqual(output_spec_2, expected_output_spec)
    # Falls back to None if the attribute was never set on the layer.
    state_spec = getattr(layer, '_network_state_spec', None)
    self.assertEqual(state_spec, expected_state_spec)
 def testCopy(self):
   """A wrapper rebuilt from its config keeps the wrapped layer's settings."""
   lstm = tf.keras.layers.LSTM(3, return_state=True, return_sequences=True)
   original = rnn_wrapper.RNNWrapper(lstm)
   restored = type(original).from_config(original.get_config())
   # Compare the wrapped layers attribute by attribute.
   for attribute in ('dtype', 'units'):
     self.assertEqual(
         getattr(original.wrapped_layer, attribute),
         getattr(restored.wrapped_layer, attribute))