Beispiel #1
0
def _infer_state_specs(
    layers: Sequence[tf.keras.layers.Layer]
) -> Tuple[types.NestedTensorSpec, List[bool]]:
    """Infer the state spec of a sequence of keras Layers and Networks.

    Each layer's state spec is fetched with `network.get_state_spec`. A spec
    returned as a Python `list` is stored as a `tuple` instead, and a parallel
    flag records that the conversion happened.

    Args:
      layers: A list of Keras layers and Networks.

    Returns:
      A tuple `(state_specs, layer_state_is_list)`. `state_specs` is a tuple
      of length `len(layers)` holding each layer's state spec, with list specs
      converted to tuples. `layer_state_is_list` is a list of bools, one per
      layer, that is `True` when the spec returned for that layer was a list.
    """
    specs = []
    was_list = []
    for lyr in layers:
        lyr_spec = network.get_state_spec(lyr)
        is_list = isinstance(lyr_spec, list)
        was_list.append(is_list)
        # NOTE(review): the flag presumably lets callers restore the original
        # list structure later — confirm against callers.
        specs.append(tuple(lyr_spec) if is_list else lyr_spec)
    return tuple(specs), was_list
Beispiel #2
0
def _infer_specs(
    layers: typing.Sequence[tf.keras.layers.Layer],
    input_spec: types.NestedTensorSpec
) -> typing.Tuple[
    types.NestedTensorSpec,
    types.NestedTensorSpec
]:
  """Infer the output spec and state specs of a sequence of keras Layers.

  The layers are processed in order. Each layer's `create_variables` is called
  with the output spec of the previous layer (the first layer receives
  `input_spec`), and only after that is the layer's state spec read.
  Calling `create_variables` first matters because it creates a
  `_network_state_spec` property on each generic (non-Network) Keras layer.

  Args:
    layers: A list of Keras layers and Networks.
    input_spec: The input spec fed to the first layer.

  Returns:
    A tuple `(output_spec, state_spec)`. `output_spec` is the output spec of
    the final layer (or `input_spec` itself when `layers` is empty), and
    `state_spec` is a tuple with one state spec per layer.
  """
  spec = input_spec
  collected = []
  for lyr in layers:
    # Variables must exist before the state spec can be queried.
    spec = network.create_variables(lyr, spec)
    collected.append(network.get_state_spec(lyr))
  return spec, tuple(collected)
Beispiel #3
0
def _infer_state_specs(
        layers: Sequence[tf.keras.layers.Layer]) -> types.NestedTensorSpec:
    """Infer the state spec of a sequence of keras Layers and Networks.

    Args:
      layers: A list of Keras layers and Networks.

    Returns:
      `state_spec`, a tuple holding `network.get_state_spec(layer)` for each
      layer, in order; its length is `len(layers)`.
    """
    per_layer_specs = [network.get_state_spec(lyr) for lyr in layers]
    return tuple(per_layer_specs)  # pytype: disable=bad-return-type