def _infer_state_specs(
    layers: Sequence[tf.keras.layers.Layer]
) -> Tuple[types.NestedTensorSpec, List[bool]]:
  """Infer the state spec of a sequence of Keras Layers and Networks.

  A layer whose state spec comes back as a Python `list` has that spec
  converted to a `tuple`. The returned flags record which layers this
  happened to, so callers can convert back if needed.

  Args:
    layers: A list of Keras layers and Networks.

  Returns:
    A tuple `(state_spec, layer_state_is_list)`:
      state_spec: a tuple of length `len(layers)` holding each layer's state
        spec (list specs converted to tuples).
      layer_state_is_list: a list of bools, one per layer, that is `True`
        when the corresponding layer's state spec was a list.
  """
  raw_specs = [network.get_state_spec(layer) for layer in layers]
  is_list_flags = [isinstance(spec, list) for spec in raw_specs]
  normalized_specs = tuple(
      tuple(spec) if was_list else spec
      for spec, was_list in zip(raw_specs, is_list_flags))
  return normalized_specs, is_list_flags
def _infer_specs(
    layers: typing.Sequence[tf.keras.layers.Layer],
    input_spec: types.NestedTensorSpec
) -> typing.Tuple[types.NestedTensorSpec, types.NestedTensorSpec]:
  """Infer the output spec and state spec of a sequence of Keras Layers/Networks.

  This runs `create_variables` on each layer in order, feeding the output
  spec produced by each layer in as the input spec of the next one, and
  collects the state spec of every layer. Running `create_variables` is
  necessary because this creates a `_network_state_spec` property on each
  generic (non-Network) Keras layer.

  Args:
    layers: A list of Keras layers and Networks.
    input_spec: The input spec to the first layer.

  Returns:
    A tuple `(output_spec, state_spec)` where `output_spec` is the output spec
    from the final layer (or `input_spec` itself if `layers` is empty) and
    `state_spec` is a tuple of the per-layer state specs, of length
    `len(layers)`.
  """
  state_specs = []
  output_spec = input_spec
  for layer in layers:
    # The output spec of this layer is the input spec of the next one.
    output_spec = network.create_variables(layer, output_spec)
    # Queried only after create_variables, which is what makes the state
    # spec available on generic (non-Network) Keras layers.
    state_specs.append(network.get_state_spec(layer))
  return output_spec, tuple(state_specs)
def _infer_state_specs(
    layers: Sequence[tf.keras.layers.Layer]) -> types.NestedTensorSpec:
  """Infer the state spec of a sequence of Keras Layers and Networks.

  Args:
    layers: A list of Keras layers and Networks.

  Returns:
    `state_spec`: a tuple with one state spec per entry of `layers`, in the
    same order (so its length is `len(layers)`).
  """
  per_layer_specs = [network.get_state_spec(layer) for layer in layers]
  # The annotation says NestedTensorSpec, but a tuple of specs is returned;
  # the pytype pragma below silences that mismatch.
  return tuple(per_layer_specs)  # pytype: disable=bad-return-type