def testRequiredConstructorArgs(self):
  """Checks that RNNWrapper refuses layers built without a required flag.

  Each case builds an LSTM with only one of `return_state` /
  `return_sequences` set and expects the wrapper's constructor to raise
  `NotImplementedError` with a message naming the missing flag.
  """
  # (regex expected in the error message, kwargs used to build the LSTM)
  rejected_configs = (
      ('with return_state==False', {'return_sequences': True}),
      ('with return_sequences==False', {'return_state': True}),
  )
  for message_regex, lstm_kwargs in rejected_configs:
    with self.assertRaisesRegex(NotImplementedError, message_regex):
      incomplete_layer = tf.keras.layers.LSTM(3, **lstm_kwargs)
      rnn_wrapper.RNNWrapper(incomplete_layer)
def testWrapperCall(self):
  """Checks that the wrapper's call conventions agree with each other.

  The same batch is passed three ways: without a time dimension, with an
  explicit time dimension of size 1, and without a time dimension but with an
  explicitly requested initial state.  Outputs and next states must have the
  expected shapes and match across all three.
  """
  num_units = 3
  batch_size = 2
  input_depth = 5
  lstm = tf.keras.layers.LSTM(
      num_units, return_state=True, return_sequences=True)
  wrapper = rnn_wrapper.RNNWrapper(lstm)
  inputs = np.random.rand(batch_size, input_depth).astype(np.float32)

  # Make sure wrapper call works when no time dimension is passed in.
  outputs, next_state = wrapper(inputs)

  # Same data with an explicit time axis; squeeze it back out of the outputs
  # so they can be compared against the untimed call.
  timed_inputs = tf.expand_dims(inputs, axis=1)
  timed_outputs, timed_next_state = wrapper(timed_inputs)
  timed_outputs = tf.squeeze(timed_outputs, axis=1)

  # Same data again, but with the initial state supplied by the caller.
  initial_state = wrapper.get_initial_state(inputs)
  manual_outputs, manual_next_state = wrapper(inputs, initial_state)

  self.evaluate(tf.compat.v1.global_variables_initializer())

  expected_shape = (batch_size, num_units)
  for candidate_outputs in (outputs, timed_outputs, manual_outputs):
    self.assertEqual(candidate_outputs.shape, expected_shape)

  for candidate_state in (next_state, timed_next_state, manual_next_state):
    self.assertLen(candidate_state, 2)
    for state_component in candidate_state:
      self.assertEqual(state_component.shape, expected_shape)

  matching_pairs = (
      (outputs, timed_outputs),
      (outputs, manual_outputs),
      (next_state, timed_next_state),
      (next_state, manual_next_state),
  )
  for reference, alternative in matching_pairs:
    self.assertAllClose(reference, alternative)
def __init__(self,
             layers: typing.Sequence[tf.keras.layers.Layer],
             input_spec: typing.Optional[types.NestedTensorSpec] = None,
             name: typing.Optional[typing.Text] = None):
  """Create a Sequential Network.

  Args:
    layers: A list or tuple of layers to compose.  Any layers that
      are subclasses of `tf.keras.layers.{RNN,LSTM,GRU,...}` are
      wrapped in `tf_agents.keras_layers.RNNWrapper`.
    input_spec: (Optional.) A nest of `tf.TypeSpec` representing the
      input observations.
    name: (Optional.) Network name.

  Raises:
    ValueError: If `layers` is empty.
    ValueError: If `layers[0]` is a generic Keras layer (not a TF-Agents
      network) and `input_spec is None`.
    TypeError: If any of the layers are not instances of keras `Layer`.
    RuntimeError: If not `tf.executing_eagerly()`; as this is required to
      be able to create deep copies of layers in `layers`.
  """
  if not tf.executing_eagerly():
    raise RuntimeError(
        'Not executing eagerly - cannot make deep copies of `layers`.')
  if not layers:
    raise ValueError(
        '`layers` must not be empty; saw: {}'.format(layers))
  for layer in layers:
    if not isinstance(layer, tf.keras.layers.Layer):
      raise TypeError(
          'Expected all layers to be instances of keras Layer, but saw'
          ': \'{}\''.format(layer))

  # Keras RNN layers are wrapped in RNNWrapper; every other layer is used
  # as-is.
  layers = [
      rnn_wrapper.RNNWrapper(layer) if isinstance(
          layer, tf.keras.layers.RNN) else layer
      for layer in layers
  ]

  state_spec = _infer_state_specs(layers)

  # Now we remove all of the empty state specs so if there are no RNN layers,
  # our state spec is empty.  layer_has_state is a list of bools telling us
  # which layers have a state and which don't.
  #
  # A layer is stateless when its spec is the empty tuple.  The check is done
  # by value on plain Python objects rather than with `s is not ()` (identity
  # against a literal only works because CPython interns the empty tuple, and
  # triggers a SyntaxWarning on Python >= 3.8) and rather than
  # `s in ((),)`, which tf.function may turn into a tensor bool expression
  # (b/158804957).
  layer_has_state = [
      not (isinstance(s, tuple) and not s) for s in state_spec
  ]
  state_spec = tuple(
      s for s, has_state in zip(state_spec, layer_has_state) if has_state)

  super(Sequential, self).__init__(input_tensor_spec=input_spec,
                                   state_spec=state_spec,
                                   name=name)
  self._sequential_layers = layers
  self._layer_has_state = layer_has_state
def testWrapperBuild(self):
  """Checks variable creation via `build()` and the `trainable` toggle."""
  lstm = tf.keras.layers.LSTM(3, return_state=True, return_sequences=True)
  wrapper = rnn_wrapper.RNNWrapper(lstm)

  # Make sure wrapper.build() works when no time dimension is passed in.
  untimed_input_shape = (1, 4)
  wrapper.build(untimed_input_shape)
  self.evaluate(tf.compat.v1.global_variables_initializer())

  expected_variable_count = 3
  trainable_values = self.evaluate(wrapper.trainable_variables)
  self.assertLen(trainable_values, expected_variable_count)
  self.assertLen(wrapper.variables, expected_variable_count)
  self.assertTrue(wrapper.trainable)

  # Freezing the wrapper must empty the trainable list without dropping any
  # variables from the full list.
  wrapper.trainable = False
  self.assertFalse(wrapper.trainable)
  self.assertEmpty(wrapper.trainable_variables)
  self.assertLen(wrapper.variables, expected_variable_count)
def __init__(self,
             layers: Sequence[tf.keras.layers.Layer],
             input_spec: Optional[types.NestedTensorSpec] = None,
             name: Optional[Text] = None):
  """Build a network that chains the given layers.

  Args:
    layers: A list or tuple of layers to compose.  Any layer that is an
      instance of `tf.keras.layers.RNN` (e.g. `LSTM`, `GRU`) is wrapped in
      `tf_agents.keras_layers.RNNWrapper`.
    input_spec: (Optional.) A nest of `tf.TypeSpec` representing the
      input observations to the first layer.
    name: (Optional.) Network name.

  Raises:
    ValueError: If `layers` is empty.
    ValueError: If `layers[0]` is a generic Keras layer (not a TF-Agents
      network) and `input_spec is None`.
    TypeError: If any of the layers are not instances of keras `Layer`.
  """
  if not layers:
    raise ValueError(
        '`layers` must not be empty; saw: {}'.format(layers))

  # Reject the first entry that is not a Keras layer.
  for candidate in layers:
    if isinstance(candidate, tf.keras.layers.Layer):
      continue
    raise TypeError(
        'Expected all layers to be instances of keras Layer, but saw'
        ': \'{}\''.format(candidate))

  # Keras RNN layers get wrapped; everything else is kept untouched.
  prepared_layers = []
  for candidate in layers:
    if isinstance(candidate, tf.keras.layers.RNN):
      candidate = rnn_wrapper.RNNWrapper(candidate)
    prepared_layers.append(candidate)

  per_layer_specs, self._layer_state_is_list = _infer_state_specs(
      prepared_layers)

  # Drop the empty state specs, so that a network without stateful layers
  # ends up with an empty state spec.  `has_state_flags` keeps one bool per
  # layer recording whether its spec flattens to at least one element.
  has_state_flags = []
  nonempty_specs = []
  for spec in per_layer_specs:
    spec_is_nonempty = bool(tf.nest.flatten(spec))
    has_state_flags.append(spec_is_nonempty)
    if spec_is_nonempty:
      nonempty_specs.append(spec)

  super(Sequential, self).__init__(input_tensor_spec=input_spec,
                                   state_spec=tuple(nonempty_specs),
                                   name=name)
  self._sequential_layers = prepared_layers
  self._layer_has_state = has_state_flags
class CreateVariablesTest(parameterized.TestCase, tf.test.TestCase):
  """Tests `network.create_variables` on a mock network and on Keras layers."""

  def testNetworkCreate(self):
    """Checks that `create_variables` builds a not-yet-built network."""
    observation_spec = specs.TensorSpec([1], tf.float32, 'observation')
    action_spec = specs.TensorSpec([2], tf.float32, 'action')
    net = MockNetwork(observation_spec, action_spec)
    self.assertFalse(net.built)
    # Accessing the variables of the unbuilt network is expected to fail.
    with self.assertRaises(ValueError):
      net.variables  # pylint: disable=pointless-statement
    output_spec = network.create_variables(net)
    # MockNetwork adds some variables to observation, which has shape [bs, 1]
    self.assertEqual(output_spec, tf.TensorSpec([1], dtype=tf.float32))
    self.assertTrue(net.built)
    self.assertLen(net.variables, 2)
    self.assertLen(net.trainable_variables, 1)

  # Each case below is a tuple of:
  #   (test name, layer factory, input_spec, expected_output_spec,
  #    expected_state_spec)
  # The layer is created lazily through the factory, inside the test body.
  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      (
          'Dense',
          lambda: tf.keras.layers.Dense(3),
          tf.TensorSpec((5,), tf.float32),  # input_spec
          tf.TensorSpec((3,), tf.float32),  # expected_output_spec
          (),  # expected_state_spec
      ),
      (
          'LSTMCell',
          lambda: tf.keras.layers.LSTMCell(3),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          [tf.TensorSpec((3,), tf.float32),
           tf.TensorSpec((3,), tf.float32)],
      ),
      (
          'LSTMCellInRNN',
          lambda: rnn_wrapper.RNNWrapper(
              tf.keras.layers.RNN(
                  tf.keras.layers.LSTMCell(3),
                  return_state=True,
                  return_sequences=True)
          ),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          [tf.TensorSpec((3,), tf.float32),
           tf.TensorSpec((3,), tf.float32)],
      ),
      (
          'LSTM',
          lambda: rnn_wrapper.RNNWrapper(
              tf.keras.layers.LSTM(
                  3,
                  return_state=True,
                  return_sequences=True)
          ),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          [tf.TensorSpec((3,), tf.float32),
           tf.TensorSpec((3,), tf.float32)],
      ),
      (
          'TimeDistributed',
          lambda: tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3)),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          ()
      ),
      (
          'Conv2D',
          lambda: tf.keras.layers.Conv2D(2, 3),
          tf.TensorSpec((28, 28, 5), tf.float32),
          tf.TensorSpec((26, 26, 2), tf.float32),
          ()
      ),
      (
          'SequentialOfDense',
          lambda:
          tf.keras.Sequential([tf.keras.layers.Dense(3)] * 2),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          ()
      ),
      (
          'NormalDistribution',
          lambda: tf.keras.Sequential(
              [tf.keras.layers.Dense(3),
               tf.keras.layers.Lambda(
                   lambda x: tfd.Normal(loc=x, scale=x**2))]),
          tf.TensorSpec((5,), tf.float32),
          distribution_utils.DistributionSpecV2(
              event_shape=tf.TensorShape(()),
              dtype=tf.float32,
              parameters=distribution_utils.Params(
                  type_=tfd.Normal,
                  params=dict(
                      loc=tf.TensorSpec((3,), tf.float32),
                      scale=tf.TensorSpec((3,), tf.float32),
                  ))),
          ()
      ),
  )
  # pylint: enable=g-long-lambda
  def testKerasLayerCreate(self,
                           layer_fn,
                           input_spec,
                           expected_output_spec,
                           expected_state_spec):
    """Checks `create_variables` on a plain Keras layer.

    Args:
      layer_fn: Zero-argument callable returning the layer under test.
      input_spec: Spec passed to `network.create_variables` for the layer.
      expected_output_spec: Spec that `create_variables` must return.
      expected_state_spec: Value expected in the layer's
        `_network_state_spec` attribute after creation.
    """
    layer = layer_fn()
    # Without an input spec, creating variables for the layer must fail.
    with self.assertRaisesRegex(ValueError, 'an input_spec is required'):
      network.create_variables(layer)
    output_spec = network.create_variables(layer, input_spec)
    self.assertTrue(layer.built)
    self.assertEqual(
        output_spec, expected_output_spec,
        '\n{}\nvs.\n{}\n'.format(output_spec, expected_output_spec))
    # A second call with the same spec must report the same output spec.
    output_spec_2 = network.create_variables(layer, input_spec)
    self.assertEqual(output_spec_2, expected_output_spec)
    # The state spec is read back from the layer itself; `None` is used when
    # the attribute is absent, which then fails the comparison below.
    state_spec = getattr(layer, '_network_state_spec', None)
    self.assertEqual(state_spec, expected_state_spec)
def testCopy(self):
  """Checks that a wrapper rebuilt from its config keeps the layer settings."""
  lstm = tf.keras.layers.LSTM(3, return_state=True, return_sequences=True)
  original = rnn_wrapper.RNNWrapper(lstm)

  # Round-trip through the config using the wrapper's own class.
  wrapper_cls = type(original)
  clone = wrapper_cls.from_config(original.get_config())

  for attribute in ('dtype', 'units'):
    self.assertEqual(
        getattr(original.wrapped_layer, attribute),
        getattr(clone.wrapped_layer, attribute))