Example #1
  def _testDynamicUnrollResetsStateOnReset(self, cell_type):
    cell = cell_type()
    batch_size = 4
    max_time = 7
    inputs = tf.random.uniform((batch_size, max_time, 1))
    reset_mask = (tf.random.normal((batch_size, max_time)) > 0)

    layer = dynamic_unroll_layer.DynamicUnroll(cell, dtype=tf.float32)
    outputs, final_state = layer(inputs, reset_mask=reset_mask)

    tf.nest.assert_same_structure(outputs, cell.output_size)
    tf.nest.assert_same_structure(final_state, cell.state_size)

    reset_mask, inputs, outputs, final_state = self.evaluate(
        (reset_mask, inputs, outputs, final_state))

    self.assertAllClose(outputs[:, -1, :], final_state)

    # Outputs hold running sums of the inputs, restarting wherever reset_mask fires.
    expected_outputs = []
    state = np.zeros_like(final_state)
    for i, frame in enumerate(np.transpose(inputs, [1, 0, 2])):
      state = state * np.reshape(~reset_mask[:, i], state.shape) + frame
      expected_outputs.append(np.array(state))
    expected_outputs = np.transpose(expected_outputs, [1, 0, 2])
    self.assertAllClose(outputs, expected_outputs)
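This test is parameterized over cell_type and relies on a cell whose output equals its new state and whose state is a running sum of the inputs (zeroed wherever the reset mask fires). A minimal sketch of a cell satisfying that contract, assuming the standard Keras RNN-cell interface; the actual fixture in the test suite may be named and built differently:

import tensorflow as tf

class AccumulatorCell(tf.keras.layers.AbstractRNNCell):
  """Illustrative cell: new_state = state + input, output = new_state."""

  @property
  def state_size(self):
    return 1

  @property
  def output_size(self):
    return 1

  def call(self, inputs, states):
    new_state = states[0] + inputs
    return new_state, [new_state]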
Example #2
    def testMixOfNonRecurrentAndRecurrent(self):
        sequential = sequential_lib.Sequential(
            [
                tf.keras.layers.Dense(2),
                tf.keras.layers.LSTM(
                    2, return_state=True, return_sequences=True),
                tf.keras.layers.RNN(
                    tf.keras.layers.StackedRNNCells([
                        tf.keras.layers.LSTMCell(1),
                        tf.keras.layers.LSTMCell(32),
                    ]),
                    return_state=True,
                    return_sequences=True,
                ),
                # Convert inner dimension to [4, 4, 2] for convolution.
                inner_reshape.InnerReshape([32], [4, 4, 2]),
                tf.keras.layers.Conv2D(2, 3),
                # Convert 3 inner dimensions to [?] for RNN.
                inner_reshape.InnerReshape([None] * 3, [-1]),
                tf.keras.layers.GRU(
                    2, return_state=True, return_sequences=True),
                dynamic_unroll_layer.DynamicUnroll(
                    tf.keras.layers.LSTMCell(2)),
            ],
            input_spec=tf.TensorSpec((3, ), tf.float32))
        self.assertEqual(sequential.input_tensor_spec,
                         tf.TensorSpec((3, ), tf.float32))

        output_spec = sequential.create_variables()
        self.assertEqual(output_spec, tf.TensorSpec((2, ), dtype=tf.float32))

        tf.nest.map_structure(
            self.assertEqual,
            sequential.state_spec,
            (
                [  # LSTM
                    tf.TensorSpec((2, ), tf.float32),
                    tf.TensorSpec((2, ), tf.float32),
                ],
                (  # RNN(StackedRNNCells)
                    [
                        tf.TensorSpec((1, ), tf.float32),
                        tf.TensorSpec((1, ), tf.float32),
                    ],
                    [
                        tf.TensorSpec((32, ), tf.float32),
                        tf.TensorSpec((32, ), tf.float32),
                    ],
                ),
                # GRU
                tf.TensorSpec((2, ), tf.float32),
                [  # DynamicUnroll
                    tf.TensorSpec((2, ), tf.float32),
                    tf.TensorSpec((2, ), tf.float32),
                ]))

        inputs = tf.ones((8, 10, 3), dtype=tf.float32)
        outputs, _ = sequential(inputs)
        self.assertEqual(outputs.shape, tf.TensorShape([8, 10, 2]))
Example #3
def create_recurrent_network(input_fc_layer_units, lstm_size,
                             output_fc_layer_units, num_actions):
    rnn_cell = tf.keras.layers.StackedRNNCells(
        [fused_lstm_cell(s) for s in lstm_size])
    return sequential.Sequential(
        [dense(num_units) for num_units in input_fc_layer_units] +
        [dynamic_unroll_layer.DynamicUnroll(rnn_cell)] +
        [dense(num_units)
         for num_units in output_fc_layer_units] + [logits(num_actions)])
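The dense, fused_lstm_cell, and logits helpers referenced above are defined elsewhere in the source module. Plausible stand-ins so the snippet can be read in isolation; the real definitions may choose different activations and initializers:

import tensorflow as tf

def dense(num_units):
  return tf.keras.layers.Dense(num_units, activation=tf.keras.activations.relu)

def fused_lstm_cell(num_units):
  # implementation=2 selects the fused (batched-matmul) Keras LSTM kernel.
  return tf.keras.layers.LSTMCell(num_units, implementation=2)

def logits(num_actions):
  return tf.keras.layers.Dense(num_actions, activation=None)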
Example #4
    def testMixOfNonRecurrentAndRecurrent(self):
        sequential = sequential_lib.Sequential([
            tf.keras.layers.Dense(2),
            tf.keras.layers.LSTM(2, return_state=True, return_sequences=True),
            tf.keras.layers.RNN(
                tf.keras.layers.StackedRNNCells([
                    tf.keras.layers.LSTMCell(1),
                    tf.keras.layers.LSTMCell(32),
                ]),
                return_state=True,
                return_sequences=True,
            ),
            tf.keras.layers.Reshape((-1, 4, 4, 2)),
            tf.keras.layers.Conv2D(2, 3),
            tf.keras.layers.TimeDistributed(tf.keras.layers.Flatten()),
            tf.keras.layers.GRU(2, return_state=True, return_sequences=True),
            dynamic_unroll_layer.DynamicUnroll(tf.keras.layers.LSTMCell(2)),
        ],
            input_spec=tf.TensorSpec((3, ), tf.float32))
        self.assertEqual(sequential.input_tensor_spec,
                         tf.TensorSpec((3, ), tf.float32))

        output_spec = sequential.create_variables()
        self.assertEqual(output_spec, tf.TensorSpec((2, ), dtype=tf.float32))

        tf.nest.map_structure(
            self.assertEqual,
            sequential.state_spec,
            (
                [  # LSTM
                    tf.TensorSpec((2, ), tf.float32),
                    tf.TensorSpec((2, ), tf.float32),
                ],
                [  # RNN(StackedRNNCells)
                    [
                        tf.TensorSpec((1, ), tf.float32),
                        tf.TensorSpec((1, ), tf.float32),
                    ],
                    [
                        tf.TensorSpec((32, ), tf.float32),
                        tf.TensorSpec((32, ), tf.float32),
                    ],
                ],
                [  # GRU
                    tf.TensorSpec((2, ), tf.float32),
                ],
                [  # DynamicUnroll
                    tf.TensorSpec((2, ), tf.float32),
                    tf.TensorSpec((2, ), tf.float32),
                ]))

        inputs = tf.ones((8, 10, 3), dtype=tf.float32)
        outputs, _ = sequential(inputs)
        self.assertEqual(outputs.shape, tf.TensorShape([8, 10, 2]))
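As a hypothetical follow-up, the recurrent state returned by the network can be threaded through successive calls. This sketch assumes the TF-Agents network_state keyword of Sequential.__call__:

# Split the 10-step sequence into two 5-step calls, carrying state across.
first_half, state = sequential(inputs[:, :5])
second_half, _ = sequential(inputs[:, 5:], network_state=state)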
Example #5
 def testNoTimeDimensionMatchesSingleStep(self):
   cell = tf.keras.layers.LSTMCell(3)
   batch_size = 4
   max_time = 1
   inputs = tf.random.uniform((batch_size, max_time, 2), dtype=tf.float32)
   inputs_no_time = tf.squeeze(inputs, axis=1)
   layer = dynamic_unroll_layer.DynamicUnroll(cell, dtype=tf.float32)
   outputs, next_state = layer(inputs)
   outputs_squeezed_time = tf.squeeze(outputs, axis=1)
   outputs_no_time, next_state_no_time = layer(inputs_no_time)
   self.evaluate(tf.compat.v1.global_variables_initializer())
   outputs_squeezed_time, next_state, outputs_no_time, next_state_no_time = (
       self.evaluate((outputs_squeezed_time, next_state,
                      outputs_no_time, next_state_no_time)))
   self.assertAllEqual(outputs_squeezed_time, outputs_no_time)
   self.assertAllEqual(next_state, next_state_no_time)
Example #6
 def testDynamicUnrollMatchesDynamicRNNWhenNoResetSingleTimeStep(self):
   cell = tf.compat.v1.nn.rnn_cell.LSTMCell(3)
   batch_size = 4
   max_time = 1
   inputs = tf.random.uniform((batch_size, max_time, 2), dtype=tf.float32)
   reset_mask = tf.zeros((batch_size, max_time), dtype=tf.bool)
   layer = dynamic_unroll_layer.DynamicUnroll(cell, dtype=tf.float32)
   outputs_dun, final_state_dun = layer(inputs, reset_mask)
   outputs_drnn, final_state_drnn = tf.compat.v1.nn.dynamic_rnn(
       cell, inputs, dtype=tf.float32)
   self.evaluate(tf.compat.v1.global_variables_initializer())
   outputs_dun, final_state_dun, outputs_drnn, final_state_drnn = (
       self.evaluate(
           (outputs_dun, final_state_dun, outputs_drnn, final_state_drnn)))
   self.assertAllClose(outputs_dun, outputs_drnn)
   self.assertAllClose(final_state_dun, final_state_drnn)
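The reference here is deliberately the legacy tf.compat.v1.nn.dynamic_rnn. For intuition, a hedged TF2 sketch of the same equivalence against tf.keras.layers.RNN, using a Keras LSTMCell shared between both wrappers so they compute with identical weights:

cell = tf.keras.layers.LSTMCell(3)
layer = dynamic_unroll_layer.DynamicUnroll(cell, dtype=tf.float32)
rnn = tf.keras.layers.RNN(cell, return_sequences=True, return_state=True)
outputs_dun, final_state_dun = layer(inputs)   # no reset_mask: state never resets
outputs_rnn, h, c = rnn(inputs)                # LSTM state unpacks to (h, c)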
Example #7
 def testFromConfigLSTM(self):
   l1 = dynamic_unroll_layer.DynamicUnroll(
       tf.keras.layers.LSTMCell(units=3), parallel_iterations=10)
   l2 = dynamic_unroll_layer.DynamicUnroll.from_config(l1.get_config())
   self.assertEqual(l1.get_config(), l2.get_config())
Example #8
    def testMixOfNonRecurrentAndRecurrent(self):
        sequential = sequential_lib.Sequential(
            [
                tf.keras.layers.Dense(2),
                tf.keras.layers.LSTM(
                    2, return_state=True, return_sequences=True),
                tf.keras.layers.RNN(
                    tf.keras.layers.StackedRNNCells([
                        tf.keras.layers.LSTMCell(1),
                        tf.keras.layers.LSTMCell(32),
                    ]),
                    return_state=True,
                    return_sequences=True,
                ),
                # Convert inner dimension to [4, 4, 2] for convolution.
                inner_reshape.InnerReshape([32], [4, 4, 2]),
                tf.keras.layers.Conv2D(2, 3),
                # Convert 3 inner dimensions to [?] for RNN.
                inner_reshape.InnerReshape([None] * 3, [-1]),
                tf.keras.layers.GRU(
                    2, return_state=True, return_sequences=True),
                dynamic_unroll_layer.DynamicUnroll(
                    tf.keras.layers.LSTMCell(2)),
                tf.keras.layers.Lambda(
                    lambda x: tfd.MultivariateNormalDiag(loc=x, scale_diag=x)),
            ],
            input_spec=tf.TensorSpec((3, ), tf.float32))  # pytype: disable=wrong-arg-types
        self.assertEqual(sequential.input_tensor_spec,
                         tf.TensorSpec((3, ), tf.float32))

        output_spec = sequential.create_variables()
        self.assertIsInstance(output_spec,
                              distribution_utils.DistributionSpecV2)
        output_event_spec = output_spec.event_spec
        self.assertEqual(output_event_spec,
                         tf.TensorSpec((2, ), dtype=tf.float32))

        tf.nest.map_structure(
            self.assertEqual,
            sequential.state_spec,
            (
                [  # LSTM
                    tf.TensorSpec((2, ), tf.float32),
                    tf.TensorSpec((2, ), tf.float32),
                ],
                (  # RNN(StackedRNNCells)
                    [
                        tf.TensorSpec((1, ), tf.float32),
                        tf.TensorSpec((1, ), tf.float32),
                    ],
                    [
                        tf.TensorSpec((32, ), tf.float32),
                        tf.TensorSpec((32, ), tf.float32),
                    ],
                ),
                # GRU
                tf.TensorSpec((2, ), tf.float32),
                [  # DynamicUnroll
                    tf.TensorSpec((2, ), tf.float32),
                    tf.TensorSpec((2, ), tf.float32),
                ]))

        inputs = tf.ones((8, 10, 3), dtype=tf.float32)
        dist, _ = sequential(inputs)
        outputs = dist.sample()
        self.assertEqual(outputs.shape, tf.TensorShape([8, 10, 2]))
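A hypothetical follow-up for the distribution head above: the emitted MultivariateNormalDiag has a 2-dimensional event, so log-probabilities of the sampled actions sum out the trailing axis:

# One log-probability per (batch, time) position; event dims are collapsed.
log_probs = dist.log_prob(outputs)
self.assertEqual(log_probs.shape, tf.TensorShape([8, 10]))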
Example #9
    def __init__(self,
                 input_tensor_spec,
                 output_tensor_spec,
                 conv_layer_params=None,
                 input_fc_layer_params=(200, 100),
                 lstm_size=(40, ),
                 output_fc_layer_params=(200, 100),
                 activation_fn=tf.keras.activations.relu,
                 name='ActorRnnNetwork'):
        """Creates an instance of `ActorRnnNetwork`.

    Args:
      input_tensor_spec: A nest of `tensor_spec.TensorSpec` representing the
        input observations.
      output_tensor_spec: A nest of `tensor_spec.BoundedTensorSpec` representing
        the actions.
      conv_layer_params: Optional list of convolution layers parameters, where
        each item is a length-three tuple indicating (filters, kernel_size,
        stride).
      input_fc_layer_params: Optional list of fully_connected parameters, where
        each item is the number of units in the layer. This is applied before
        the LSTM cell.
      lstm_size: An iterable of ints specifying the LSTM cell sizes to use.
      output_fc_layer_params: Optional list of fully_connected parameters, where
        each item is the number of units in the layer. This is applied after the
        LSTM cell.
      activation_fn: Activation function, e.g. tf.nn.relu, slim.leaky_relu, ...
      name: A string representing name of the network.

    Raises:
      ValueError: If `input_tensor_spec` contains more than one observation.
    """
        if len(tf.nest.flatten(input_tensor_spec)) > 1:
            raise ValueError(
                'Only a single observation is supported by this network')

        input_layers = utils.mlp_layers(
            conv_layer_params,
            input_fc_layer_params,
            activation_fn=activation_fn,
            kernel_initializer=tf.compat.v1.keras.initializers.glorot_uniform(),
            name='input_mlp')

        # Create RNN cell
        if len(lstm_size) == 1:
            cell = tf.keras.layers.LSTMCell(lstm_size[0])
        else:
            cell = tf.keras.layers.StackedRNNCells(
                [tf.keras.layers.LSTMCell(size) for size in lstm_size])

        state_spec = tf.nest.map_structure(
            functools.partial(tensor_spec.TensorSpec,
                              dtype=tf.float32,
                              name='network_state_spec'),
            list(cell.state_size))

        output_layers = utils.mlp_layers(
            fc_layer_params=output_fc_layer_params, name='output')

        flat_action_spec = tf.nest.flatten(output_tensor_spec)
        action_layers = [
            tf.keras.layers.Dense(
                single_action_spec.shape.num_elements(),
                activation=tf.keras.activations.tanh,
                kernel_initializer=tf.keras.initializers.RandomUniform(
                    minval=-0.003, maxval=0.003),
                name='action') for single_action_spec in flat_action_spec
        ]

        super(ActorRnnNetwork, self).__init__(
            input_tensor_spec=input_tensor_spec,
            state_spec=state_spec,
            name=name)

        self._output_tensor_spec = output_tensor_spec
        self._flat_action_spec = flat_action_spec
        self._conv_layer_params = conv_layer_params
        self._input_layers = input_layers
        self._dynamic_unroll = dynamic_unroll_layer.DynamicUnroll(cell)
        self._output_layers = output_layers
        self._action_layers = action_layers
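A hypothetical instantiation of the network above; the specs and sizes are illustrative, not taken from the source:

from tf_agents.specs import tensor_spec

observation_spec = tensor_spec.TensorSpec([5], tf.float32)
action_spec = tensor_spec.BoundedTensorSpec([2], tf.float32, minimum=-1.0, maximum=1.0)
actor_net = ActorRnnNetwork(observation_spec, action_spec, lstm_size=(40,))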
Example #10
    def testMixOfNonRecurrentAndRecurrent(self):
        def reshape_inner_dims(tensor, ndims, new_inner_shape):
            """Reshapes tensor to: shape(tensor)[:-ndims] + new_inner_shape."""
            tensor_shape = tf.shape(tensor)
            new_shape = tf.concat((tensor_shape[:-ndims], new_inner_shape),
                                  axis=0)
            new_tensor = tf.reshape(tensor, new_shape)
            new_tensor.set_shape(tensor.shape[:-ndims] + new_inner_shape)
            return new_tensor

        sequential = sequential_lib.Sequential([
            tf.keras.layers.Dense(2),
            tf.keras.layers.LSTM(2, return_state=True, return_sequences=True),
            tf.keras.layers.RNN(
                tf.keras.layers.StackedRNNCells([
                    tf.keras.layers.LSTMCell(1),
                    tf.keras.layers.LSTMCell(32),
                ]),
                return_state=True,
                return_sequences=True,
            ),
            tf.keras.layers.Lambda(
                lambda t: reshape_inner_dims(t, 1, [4, 4, 2])),
            tf.keras.layers.Conv2D(2, 3),
            tf.keras.layers.Lambda(lambda t: reshape_inner_dims(t, 3, [8])),
            tf.keras.layers.GRU(2, return_state=True, return_sequences=True),
            dynamic_unroll_layer.DynamicUnroll(tf.keras.layers.LSTMCell(2)),
        ],
            input_spec=tf.TensorSpec((3, ), tf.float32))
        self.assertEqual(sequential.input_tensor_spec,
                         tf.TensorSpec((3, ), tf.float32))

        output_spec = sequential.create_variables()
        self.assertEqual(output_spec, tf.TensorSpec((2, ), dtype=tf.float32))

        tf.nest.map_structure(
            self.assertEqual,
            sequential.state_spec,
            (
                [  # LSTM
                    tf.TensorSpec((2, ), tf.float32),
                    tf.TensorSpec((2, ), tf.float32),
                ],
                (  # RNN(StackedRNNCells)
                    [
                        tf.TensorSpec((1, ), tf.float32),
                        tf.TensorSpec((1, ), tf.float32),
                    ],
                    [
                        tf.TensorSpec((32, ), tf.float32),
                        tf.TensorSpec((32, ), tf.float32),
                    ],
                ),
                # GRU
                tf.TensorSpec((2, ), tf.float32),
                [  # DynamicUnroll
                    tf.TensorSpec((2, ), tf.float32),
                    tf.TensorSpec((2, ), tf.float32),
                ]))

        inputs = tf.ones((8, 10, 3), dtype=tf.float32)
        outputs, _ = sequential(inputs)
        self.assertEqual(outputs.shape, tf.TensorShape([8, 10, 2]))
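A quick illustrative check of the reshape_inner_dims helper defined at the top of this test: it swaps the last ndims axes for new_inner_shape while leaving the leading batch/time axes alone:

t = tf.zeros([8, 10, 32])
r = reshape_inner_dims(t, 1, [4, 4, 2])   # -> shape [8, 10, 4, 4, 2]
back = reshape_inner_dims(r, 3, [32])     # -> shape [8, 10, 32]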
Example #11
    def __init__(
        self,
        input_tensor_spec,
        preprocessing_layers=None,
        preprocessing_combiner=None,
        conv_layer_params=None,
        input_fc_layer_params=(75, 40),
        lstm_size=None,
        output_fc_layer_params=(75, 40),
        activation_fn=tf.keras.activations.relu,
        rnn_construction_fn=None,
        rnn_construction_kwargs=None,
        dtype=tf.float32,
        name='LSTMEncodingNetwork',
    ):
        """Creates an instance of `LSTMEncodingNetwork`.

    Input preprocessing is possible via `preprocessing_layers` and
    `preprocessing_combiner` Layers.  If the `preprocessing_layers` nest is
    shallower than `input_tensor_spec`, then the layers will get the subnests.
    For example, if:

    ```python
    input_tensor_spec = ([TensorSpec(3)] * 2, [TensorSpec(3)] * 5)
    preprocessing_layers = (Layer1(), Layer2())
    ```

    then preprocessing will call:

    ```python
    preprocessed = [preprocessing_layers[0](observations[0]),
                    preprocessing_layers[1](observations[1])]
    ```

    However if

    ```python
    preprocessing_layers = ([Layer1() for _ in range(2)],
                            [Layer2() for _ in range(5)])
    ```

    then preprocessing will call:
    ```python
    preprocessed = [
      layer(obs) for layer, obs in zip(flatten(preprocessing_layers),
                                       flatten(observations))
    ]
    ```

    Args:
      input_tensor_spec: A nest of `tensor_spec.TensorSpec` representing the
        observations.
      preprocessing_layers: (Optional.) A nest of `tf.keras.layers.Layer`
        representing preprocessing for the different observations. All of these
        layers must not be already built.
      preprocessing_combiner: (Optional.) A keras layer that takes a flat list
        of tensors and combines them.  Good options include
        `tf.keras.layers.Add` and `tf.keras.layers.Concatenate(axis=-1)`. This
        layer must not be already built.
      conv_layer_params: Optional list of convolution layers parameters, where
        each item is a length-three tuple indicating (filters, kernel_size,
        stride).
      input_fc_layer_params: Optional list of fully connected parameters, where
        each item is the number of units in the layer. These feed into the
        recurrent layer.
      lstm_size: An iterable of ints specifying the LSTM cell sizes to use.
      output_fc_layer_params: Optional list of fully connected parameters, where
        each item is the number of units in the layer. These are applied on top
        of the recurrent layer.
      activation_fn: Activation function, e.g. tf.keras.activations.relu.
      rnn_construction_fn: (Optional.) Alternate RNN construction function, e.g.
        tf.keras.layers.LSTM, tf.keras.layers.CuDNNLSTM. It is invalid to
        provide both rnn_construction_fn and lstm_size.
      rnn_construction_kwargs: (Optional.) Dictionary of arguments to pass to
        rnn_construction_fn.

        The RNN will be constructed via:

        ```
        rnn_layer = rnn_construction_fn(**rnn_construction_kwargs)
        ```
      dtype: The dtype to use by the convolution, LSTM, and fully connected
        layers.
      name: A string representing name of the network.

    Raises:
      ValueError: If any of `preprocessing_layers` is already built.
      ValueError: If `preprocessing_combiner` is already built.
      ValueError: If neither `lstm_size` nor `rnn_construction_fn` is provided.
      ValueError: If both `lstm_size` and `rnn_construction_fn` are provided.
    """
        if lstm_size is None and rnn_construction_fn is None:
            raise ValueError(
                'Need to provide either custom rnn_construction_fn or '
                'lstm_size.')
        if lstm_size and rnn_construction_fn:
            raise ValueError(
                'Cannot provide both custom rnn_construction_fn and '
                'lstm_size.')

        kernel_initializer = tf.compat.v1.variance_scaling_initializer(
            scale=2.0, mode='fan_in', distribution='truncated_normal')

        input_encoder = encoding_network.EncodingNetwork(
            input_tensor_spec,
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            conv_layer_params=conv_layer_params,
            fc_layer_params=input_fc_layer_params,
            activation_fn=activation_fn,
            kernel_initializer=kernel_initializer,
            dtype=dtype)

        # Create RNN cell
        if rnn_construction_fn:
            rnn_construction_kwargs = rnn_construction_kwargs or {}
            lstm_network = rnn_construction_fn(**rnn_construction_kwargs)
        else:
            if len(lstm_size) == 1:
                cell = tf.keras.layers.LSTMCell(
                    lstm_size[0], dtype=dtype, implementation=KERAS_LSTM_FUSED)
            else:
                cell = tf.keras.layers.StackedRNNCells([
                    tf.keras.layers.LSTMCell(size,
                                             dtype=dtype,
                                             implementation=KERAS_LSTM_FUSED)
                    for size in lstm_size
                ])
            lstm_network = dynamic_unroll_layer.DynamicUnroll(cell)

        output_encoder = []
        if output_fc_layer_params:
            output_encoder = [
                tf.keras.layers.Dense(num_units,
                                      activation=activation_fn,
                                      kernel_initializer=kernel_initializer,
                                      dtype=dtype)
                for num_units in output_fc_layer_params
            ]

        counter = [-1]

        def create_spec(size):
            counter[0] += 1
            return tensor_spec.TensorSpec(size,
                                          dtype=dtype,
                                          name='network_state_%d' % counter[0])

        state_spec = tf.nest.map_structure(create_spec,
                                           lstm_network.cell.state_size)

        super(LSTMEncodingNetwork, self).__init__(
            input_tensor_spec=input_tensor_spec,
            state_spec=state_spec,
            name=name)

        self._conv_layer_params = conv_layer_params
        self._input_encoder = input_encoder
        self._lstm_network = lstm_network
        self._output_encoder = output_encoder
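Two hypothetical ways to instantiate the encoder above, matching the mutually exclusive lstm_size / rnn_construction_fn paths (spec and sizes illustrative):

from tf_agents.specs import tensor_spec

spec = tensor_spec.TensorSpec([7], tf.float32)
encoder = LSTMEncodingNetwork(spec, input_fc_layer_params=(32,), lstm_size=(16,))

# Or supply a custom RNN layer; state_spec is then derived from its `.cell`.
# The return_state/return_sequences kwargs here are an assumption about what
# the surrounding network expects from the constructed layer.
encoder2 = LSTMEncodingNetwork(
    spec,
    rnn_construction_fn=tf.keras.layers.LSTM,
    rnn_construction_kwargs={
        'units': 16, 'return_state': True, 'return_sequences': True})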
Example #12
    def __init__(self,
                 input_tensor_spec,
                 observation_conv_layer_params=None,
                 observation_fc_layer_params=(200, ),
                 action_fc_layer_params=(200, ),
                 joint_fc_layer_params=(100, ),
                 lstm_size=None,
                 output_fc_layer_params=(200, 100),
                 activation_fn=tf.keras.activations.relu,
                 kernel_initializer=None,
                 last_kernel_initializer=None,
                 rnn_construction_fn=None,
                 rnn_construction_kwargs=None,
                 name='CriticRnnNetwork'):
        """Creates an instance of `CriticRnnNetwork`.

    Args:
      input_tensor_spec: A tuple of (observation, action) each of type
        `tensor_spec.TensorSpec` representing the inputs.
      observation_conv_layer_params: Optional list of convolution layers
        parameters to apply to the observations, where each item is a
        length-three tuple indicating (filters, kernel_size, stride).
      observation_fc_layer_params: Optional list of fully_connected parameters,
        where each item is the number of units in the layer. This is applied
        after the observation convolutional layers.
      action_fc_layer_params: Optional list of parameters for a fully_connected
        layer to apply to the actions, where each item is the number of units
        in the layer.
      joint_fc_layer_params: Optional list of parameters for a fully_connected
        layer to apply after merging observations and actions, where each item
        is the number of units in the layer.
      lstm_size: An iterable of ints specifying the LSTM cell sizes to use.
      output_fc_layer_params: Optional list of fully_connected parameters, where
        each item is the number of units in the layer. This is applied after the
        LSTM cell.
      activation_fn: Activation function, e.g. tf.nn.relu, slim.leaky_relu, ...
      kernel_initializer: kernel initializer for all layers except for the value
        regression layer. If None, a VarianceScaling initializer will be used.
      last_kernel_initializer: kernel initializer for the value regression
        layer. If None, a RandomUniform initializer will be used.
      rnn_construction_fn: (Optional.) Alternate RNN construction function, e.g.
        tf.keras.layers.LSTM, tf.keras.layers.CuDNNLSTM. It is invalid to
        provide both rnn_construction_fn and lstm_size.
      rnn_construction_kwargs: (Optional.) Dictionary of arguments to pass to
        rnn_construction_fn.

        The RNN will be constructed via:

        ```
        rnn_layer = rnn_construction_fn(**rnn_construction_kwargs)
        ```
      name: A string representing name of the network.

    Raises:
      ValueError: If `observation_spec` or `action_spec` contains more than one
        item.
      ValueError: If neither `lstm_size` nor `rnn_construction_fn` is provided.
      ValueError: If both `lstm_size` and `rnn_construction_fn` are provided.
    """
        if lstm_size is None and rnn_construction_fn is None:
            raise ValueError(
                'Need to provide either custom rnn_construction_fn or '
                'lstm_size.')
        if lstm_size and rnn_construction_fn:
            raise ValueError(
                'Cannot provide both custom rnn_construction_fn and '
                'lstm_size.')

        observation_spec, action_spec = input_tensor_spec

        if len(tf.nest.flatten(observation_spec)) > 1:
            raise ValueError(
                'Only a single observation is supported by this network.')

        if len(tf.nest.flatten(action_spec)) > 1:
            raise ValueError(
                'Only a single action is supported by this network.')

        if kernel_initializer is None:
            kernel_initializer = tf.compat.v1.keras.initializers.VarianceScaling(
                scale=1. / 3., mode='fan_in', distribution='uniform')
        if last_kernel_initializer is None:
            last_kernel_initializer = tf.keras.initializers.RandomUniform(
                minval=-0.003, maxval=0.003)

        observation_layers = utils.mlp_layers(
            observation_conv_layer_params,
            observation_fc_layer_params,
            activation_fn=activation_fn,
            kernel_initializer=kernel_initializer,
            name='observation_encoding')

        action_layers = utils.mlp_layers(None,
                                         action_fc_layer_params,
                                         activation_fn=activation_fn,
                                         kernel_initializer=kernel_initializer,
                                         name='action_encoding')

        joint_layers = utils.mlp_layers(None,
                                        joint_fc_layer_params,
                                        activation_fn=activation_fn,
                                        kernel_initializer=kernel_initializer,
                                        name='joint_mlp')

        # Create RNN cell
        if rnn_construction_fn:
            rnn_construction_kwargs = rnn_construction_kwargs or {}
            lstm_network = rnn_construction_fn(**rnn_construction_kwargs)
        else:
            if len(lstm_size) == 1:
                cell = tf.keras.layers.LSTMCell(lstm_size[0])
            else:
                cell = tf.keras.layers.StackedRNNCells(
                    [tf.keras.layers.LSTMCell(size) for size in lstm_size])
            lstm_network = dynamic_unroll_layer.DynamicUnroll(cell)

        counter = [-1]

        def create_spec(size):
            counter[0] += 1
            return tensor_spec.TensorSpec(size,
                                          dtype=tf.float32,
                                          name='network_state_%d' % counter[0])

        state_spec = tf.nest.map_structure(create_spec,
                                           lstm_network.cell.state_size)

        output_layers = utils.mlp_layers(
            fc_layer_params=output_fc_layer_params, name='output')

        output_layers.append(
            tf.keras.layers.Dense(1,
                                  activation=None,
                                  kernel_initializer=last_kernel_initializer,
                                  name='value'))

        super(CriticRnnNetwork, self).__init__(
            input_tensor_spec=input_tensor_spec,
            state_spec=state_spec,
            name=name)

        self._observation_layers = observation_layers
        self._action_layers = action_layers
        self._joint_layers = joint_layers
        self._lstm_network = lstm_network
        self._output_layers = output_layers
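Finally, a hypothetical instantiation of the critic; it consumes an (observation, action) tuple of specs (values illustrative, not from the source):

from tf_agents.specs import tensor_spec

obs_spec = tensor_spec.TensorSpec([5], tf.float32)
act_spec = tensor_spec.TensorSpec([2], tf.float32)
critic_net = CriticRnnNetwork((obs_spec, act_spec), lstm_size=(40,))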