def testKerasLayerCreate(self, layer_fn, input_spec, expected_output_spec):
    """Checks `network.create_variables` on a layer produced by `layer_fn`.

    Verifies three things:
    * calling it without an input spec raises a ValueError;
    * calling it with `input_spec` builds the layer and yields
      `expected_output_spec`;
    * calling it a second time yields the same output spec.
    """
    keras_layer = layer_fn()

    # Omitting the input spec must be rejected.
    with self.assertRaisesRegex(ValueError, 'an input_spec is required'):
        network.create_variables(keras_layer)

    first_spec = network.create_variables(keras_layer, input_spec)
    self.assertTrue(keras_layer.built)
    self.assertEqual(first_spec, expected_output_spec)

    # A repeated call on the now-built layer must report the same spec.
    second_spec = network.create_variables(keras_layer, input_spec)
    self.assertEqual(second_spec, expected_output_spec)
def testKerasLayerFailsIfRecurrentDoesNotReturnState(self):
    """An LSTM missing `return_state` or `return_sequences` is rejected.

    Each case configures only one of the two flags and expects
    `network.create_variables` to raise a ValueError naming the other one.
    """
    # (LSTM constructor kwargs, expected error regex)
    bad_configs = (
        ({'return_sequences': True}, 'with return_state==False'),
        ({'return_state': True}, 'with return_sequences==False'),
    )
    for lstm_kwargs, error_regex in bad_configs:
        with self.assertRaisesRegex(ValueError, error_regex):
            network.create_variables(
                tf.keras.layers.LSTM(3, **lstm_kwargs),
                input_spec=tf.TensorSpec((3,), tf.float32))
def testKerasLayerCreate(self, layer_fn, input_spec, expected_output_spec,
                         expected_state_spec):
    """Checks output and state specs produced by `network.create_variables`.

    The layer built by `layer_fn` must:
    * reject a call without an input spec;
    * become built and report `expected_output_spec`;
    * report the same output spec on a repeated call;
    * expose `expected_state_spec` via its `_network_state_spec` attribute,
      where a missing attribute counts as `None`.
    """
    keras_layer = layer_fn()

    with self.assertRaisesRegex(ValueError, 'an input_spec is required'):
        network.create_variables(keras_layer)

    actual_output_spec = network.create_variables(keras_layer, input_spec)
    self.assertTrue(keras_layer.built)
    # Print both specs on mismatch; nested specs are hard to diff otherwise.
    mismatch_msg = '\n{}\nvs.\n{}\n'.format(actual_output_spec,
                                            expected_output_spec)
    self.assertEqual(actual_output_spec, expected_output_spec, mismatch_msg)

    repeated_output_spec = network.create_variables(keras_layer, input_spec)
    self.assertEqual(repeated_output_spec, expected_output_spec)

    recorded_state_spec = getattr(keras_layer, '_network_state_spec', None)
    self.assertEqual(recorded_state_spec, expected_state_spec)
def _infer_specs(
    layers: typing.Sequence[tf.keras.layers.Layer],
    input_spec: types.NestedTensorSpec
) -> typing.Tuple[types.NestedTensorSpec, types.NestedTensorSpec]:
    """Infer the output spec and state specs of a sequence of Keras layers.

    This runs `create_variables` on each layer in order. Each layer is fed
    the output spec produced by the previous layer, and each layer's state
    spec is collected. Running `create_variables` is necessary because it
    creates a `_network_state_spec` property on each generic (non-Network)
    Keras layer.

    Args:
      layers: A sequence of Keras layers and Networks.
      input_spec: The input spec of the first layer.

    Returns:
      A tuple `(output_spec, state_spec)` where:
      * `output_spec` is the output spec of the final layer, or `input_spec`
        itself if `layers` is empty;
      * `state_spec` is a tuple holding one state spec per layer, in the
        same order as `layers`.
    """
    state_specs = []
    output_spec = input_spec
    for layer in layers:
        # Chain the specs: each layer consumes its predecessor's output spec.
        output_spec = network.create_variables(layer, output_spec)
        state_specs.append(network.get_state_spec(layer))
    return output_spec, tuple(state_specs)
def testNetworkCreate(self):
    """`network.create_variables` builds a Network when no spec is passed."""
    obs_spec = specs.TensorSpec([1], tf.float32, 'observation')
    act_spec = specs.TensorSpec([2], tf.float32, 'action')
    mock_net = MockNetwork(obs_spec, act_spec)

    # Before building, the network is unbuilt and `variables` must raise.
    self.assertFalse(mock_net.built)
    with self.assertRaises(ValueError):
        mock_net.variables  # pylint: disable=pointless-statement

    # No explicit input spec is passed for a Network.
    built_spec = network.create_variables(mock_net)
    # MockNetwork adds some variables to observation, which has shape [bs, 1]
    expected_spec = tf.TensorSpec([1], dtype=tf.float32)
    self.assertEqual(built_spec, expected_spec)

    self.assertTrue(mock_net.built)
    self.assertLen(mock_net.variables, 2)
    self.assertLen(mock_net.trainable_variables, 1)
def create_variables(self, input_spec=None):
    """Builds the layer stored in `self._layer` and records its output spec.

    Args:
      input_spec: (Optional) input spec to build `self._layer` with. When
        `None`, `self._input_tensor_spec` is used instead.

    Returns:
      The output spec returned by `network.create_variables` for
      `self._layer`. It is also stored on `self._network_output_spec`.
    """
    # Test against None explicitly. An `or` fallback would also discard a
    # valid but falsy nested spec (e.g. an empty tuple or dict) that the
    # caller passed on purpose.
    if input_spec is None:
        input_spec = self._input_tensor_spec
    output_spec = network.create_variables(self._layer, input_spec)
    self._network_output_spec = output_spec
    self.built = True
    return output_spec