def _testDynamicUnrollResetsStateOnReset(self, cell_type):
  cell = cell_type()
  batch_size = 4
  max_time = 7
  inputs = tf.random.uniform((batch_size, max_time, 1))
  reset_mask = tf.random.normal((batch_size, max_time)) > 0
  layer = dynamic_unroll_layer.DynamicUnroll(cell, dtype=tf.float32)
  outputs, final_state = layer(inputs, reset_mask=reset_mask)
  tf.nest.assert_same_structure(outputs, cell.output_size)
  tf.nest.assert_same_structure(final_state, cell.state_size)

  reset_mask, inputs, outputs, final_state = self.evaluate(
      (reset_mask, inputs, outputs, final_state))

  self.assertAllClose(outputs[:, -1, :], final_state)

  # Outputs will contain cumulative sums up until a reset.
  expected_outputs = []
  state = np.zeros_like(final_state)
  for i, frame in enumerate(np.transpose(inputs, [1, 0, 2])):
    state = state * np.reshape(~reset_mask[:, i], state.shape) + frame
    expected_outputs.append(np.array(state))
  expected_outputs = np.transpose(expected_outputs, [1, 0, 2])
  self.assertAllClose(outputs, expected_outputs)
def testMixOfNonRecurrentAndRecurrent(self):
  sequential = sequential_lib.Sequential(
      [
          tf.keras.layers.Dense(2),
          tf.keras.layers.LSTM(2, return_state=True, return_sequences=True),
          tf.keras.layers.RNN(
              tf.keras.layers.StackedRNNCells([
                  tf.keras.layers.LSTMCell(1),
                  tf.keras.layers.LSTMCell(32),
              ]),
              return_state=True,
              return_sequences=True,
          ),
          # Convert inner dimension to [4, 4, 2] for convolution.
          inner_reshape.InnerReshape([32], [4, 4, 2]),
          tf.keras.layers.Conv2D(2, 3),
          # Convert 3 inner dimensions to [?] for RNN.
          inner_reshape.InnerReshape([None] * 3, [-1]),
          tf.keras.layers.GRU(2, return_state=True, return_sequences=True),
          dynamic_unroll_layer.DynamicUnroll(tf.keras.layers.LSTMCell(2)),
      ],
      input_spec=tf.TensorSpec((3,), tf.float32))

  self.assertEqual(sequential.input_tensor_spec,
                   tf.TensorSpec((3,), tf.float32))

  output_spec = sequential.create_variables()
  self.assertEqual(output_spec, tf.TensorSpec((2,), dtype=tf.float32))

  tf.nest.map_structure(
      self.assertEqual,
      sequential.state_spec,
      (
          [  # LSTM
              tf.TensorSpec((2,), tf.float32),
              tf.TensorSpec((2,), tf.float32),
          ],
          (  # RNN(StackedRNNCells)
              [
                  tf.TensorSpec((1,), tf.float32),
                  tf.TensorSpec((1,), tf.float32),
              ],
              [
                  tf.TensorSpec((32,), tf.float32),
                  tf.TensorSpec((32,), tf.float32),
              ],
          ),
          # GRU
          tf.TensorSpec((2,), tf.float32),
          [  # DynamicUnroll
              tf.TensorSpec((2,), tf.float32),
              tf.TensorSpec((2,), tf.float32),
          ]))

  inputs = tf.ones((8, 10, 3), dtype=tf.float32)
  outputs, _ = sequential(inputs)
  self.assertEqual(outputs.shape, tf.TensorShape([8, 10, 2]))
def create_recurrent_network(input_fc_layer_units,
                             lstm_size,
                             output_fc_layer_units,
                             num_actions):
  rnn_cell = tf.keras.layers.StackedRNNCells(
      [fused_lstm_cell(s) for s in lstm_size])
  return sequential.Sequential(
      [dense(num_units) for num_units in input_fc_layer_units]
      + [dynamic_unroll_layer.DynamicUnroll(rnn_cell)]
      + [dense(num_units) for num_units in output_fc_layer_units]
      + [logits(num_actions)])
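# A minimal usage sketch for create_recurrent_network. The layer sizes here
# are hypothetical, and the dense, fused_lstm_cell, and logits helpers are
# assumed to be the ones defined alongside this function in the same module.
def _example_create_recurrent_network():
  # Builds: input MLP -> DynamicUnroll(stacked fused LSTM cells) ->
  # output MLP -> action logits.
  return create_recurrent_network(
      input_fc_layer_units=(75, 40),
      lstm_size=(40,),
      output_fc_layer_units=(75, 40),
      num_actions=4)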
def testMixOfNonRecurrentAndRecurrent(self):
  sequential = sequential_lib.Sequential(
      [
          tf.keras.layers.Dense(2),
          tf.keras.layers.LSTM(2, return_state=True, return_sequences=True),
          tf.keras.layers.RNN(
              tf.keras.layers.StackedRNNCells([
                  tf.keras.layers.LSTMCell(1),
                  tf.keras.layers.LSTMCell(32),
              ]),
              return_state=True,
              return_sequences=True,
          ),
          tf.keras.layers.Reshape((-1, 4, 4, 2)),
          tf.keras.layers.Conv2D(2, 3),
          tf.keras.layers.TimeDistributed(tf.keras.layers.Flatten()),
          tf.keras.layers.GRU(2, return_state=True, return_sequences=True),
          dynamic_unroll_layer.DynamicUnroll(tf.keras.layers.LSTMCell(2)),
      ],
      input_spec=tf.TensorSpec((3,), tf.float32))

  self.assertEqual(sequential.input_tensor_spec,
                   tf.TensorSpec((3,), tf.float32))

  output_spec = sequential.create_variables()
  self.assertEqual(output_spec, tf.TensorSpec((2,), dtype=tf.float32))

  tf.nest.map_structure(
      self.assertEqual,
      sequential.state_spec,
      (
          [  # LSTM
              tf.TensorSpec((2,), tf.float32),
              tf.TensorSpec((2,), tf.float32),
          ],
          [  # RNN(StackedRNNCells)
              [
                  tf.TensorSpec((1,), tf.float32),
                  tf.TensorSpec((1,), tf.float32),
              ],
              [
                  tf.TensorSpec((32,), tf.float32),
                  tf.TensorSpec((32,), tf.float32),
              ],
          ],
          [  # GRU
              tf.TensorSpec((2,), tf.float32),
          ],
          [  # DynamicUnroll
              tf.TensorSpec((2,), tf.float32),
              tf.TensorSpec((2,), tf.float32),
          ]))

  inputs = tf.ones((8, 10, 3), dtype=tf.float32)
  outputs, _ = sequential(inputs)
  self.assertEqual(outputs.shape, tf.TensorShape([8, 10, 2]))
def testNoTimeDimensionMatchesSingleStep(self):
  cell = tf.keras.layers.LSTMCell(3)
  batch_size = 4
  max_time = 1
  inputs = tf.random.uniform((batch_size, max_time, 2), dtype=tf.float32)
  inputs_no_time = tf.squeeze(inputs, axis=1)
  layer = dynamic_unroll_layer.DynamicUnroll(cell, dtype=tf.float32)
  outputs, next_state = layer(inputs)
  outputs_squeezed_time = tf.squeeze(outputs, axis=1)
  outputs_no_time, next_state_no_time = layer(inputs_no_time)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  outputs_squeezed_time, next_state, outputs_no_time, next_state_no_time = (
      self.evaluate((outputs_squeezed_time, next_state,
                     outputs_no_time, next_state_no_time)))
  self.assertAllEqual(outputs_squeezed_time, outputs_no_time)
  self.assertAllEqual(next_state, next_state_no_time)
def testDynamicUnrollMatchesDynamicRNNWhenNoResetSingleTimeStep(self):
  cell = tf.compat.v1.nn.rnn_cell.LSTMCell(3)
  batch_size = 4
  max_time = 1
  inputs = tf.random.uniform((batch_size, max_time, 2), dtype=tf.float32)
  reset_mask = tf.zeros((batch_size, max_time), dtype=tf.bool)
  layer = dynamic_unroll_layer.DynamicUnroll(cell, dtype=tf.float32)
  outputs_dun, final_state_dun = layer(inputs, reset_mask)
  outputs_drnn, final_state_drnn = tf.compat.v1.nn.dynamic_rnn(
      cell, inputs, dtype=tf.float32)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  outputs_dun, final_state_dun, outputs_drnn, final_state_drnn = (
      self.evaluate(
          (outputs_dun, final_state_dun, outputs_drnn, final_state_drnn)))
  self.assertAllClose(outputs_dun, outputs_drnn)
  self.assertAllClose(final_state_dun, final_state_drnn)
def testFromConfigLSTM(self):
  l1 = dynamic_unroll_layer.DynamicUnroll(
      tf.keras.layers.LSTMCell(units=3), parallel_iterations=10)
  l2 = dynamic_unroll_layer.DynamicUnroll.from_config(l1.get_config())
  self.assertEqual(l1.get_config(), l2.get_config())
def testMixOfNonRecurrentAndRecurrent(self):
  sequential = sequential_lib.Sequential(
      [
          tf.keras.layers.Dense(2),
          tf.keras.layers.LSTM(2, return_state=True, return_sequences=True),
          tf.keras.layers.RNN(
              tf.keras.layers.StackedRNNCells([
                  tf.keras.layers.LSTMCell(1),
                  tf.keras.layers.LSTMCell(32),
              ]),
              return_state=True,
              return_sequences=True,
          ),
          # Convert inner dimension to [4, 4, 2] for convolution.
          inner_reshape.InnerReshape([32], [4, 4, 2]),
          tf.keras.layers.Conv2D(2, 3),
          # Convert 3 inner dimensions to [?] for RNN.
          inner_reshape.InnerReshape([None] * 3, [-1]),
          tf.keras.layers.GRU(2, return_state=True, return_sequences=True),
          dynamic_unroll_layer.DynamicUnroll(tf.keras.layers.LSTMCell(2)),
          tf.keras.layers.Lambda(
              lambda x: tfd.MultivariateNormalDiag(loc=x, scale_diag=x)),
      ],
      input_spec=tf.TensorSpec((3,), tf.float32))  # pytype: disable=wrong-arg-types

  self.assertEqual(sequential.input_tensor_spec,
                   tf.TensorSpec((3,), tf.float32))

  output_spec = sequential.create_variables()
  self.assertIsInstance(output_spec, distribution_utils.DistributionSpecV2)
  output_event_spec = output_spec.event_spec
  self.assertEqual(output_event_spec, tf.TensorSpec((2,), dtype=tf.float32))

  tf.nest.map_structure(
      self.assertEqual,
      sequential.state_spec,
      (
          [  # LSTM
              tf.TensorSpec((2,), tf.float32),
              tf.TensorSpec((2,), tf.float32),
          ],
          (  # RNN(StackedRNNCells)
              [
                  tf.TensorSpec((1,), tf.float32),
                  tf.TensorSpec((1,), tf.float32),
              ],
              [
                  tf.TensorSpec((32,), tf.float32),
                  tf.TensorSpec((32,), tf.float32),
              ],
          ),
          # GRU
          tf.TensorSpec((2,), tf.float32),
          [  # DynamicUnroll
              tf.TensorSpec((2,), tf.float32),
              tf.TensorSpec((2,), tf.float32),
          ]))

  inputs = tf.ones((8, 10, 3), dtype=tf.float32)
  dist, _ = sequential(inputs)
  outputs = dist.sample()
  self.assertEqual(outputs.shape, tf.TensorShape([8, 10, 2]))
def __init__(self,
             input_tensor_spec,
             output_tensor_spec,
             conv_layer_params=None,
             input_fc_layer_params=(200, 100),
             lstm_size=(40,),
             output_fc_layer_params=(200, 100),
             activation_fn=tf.keras.activations.relu,
             name='ActorRnnNetwork'):
  """Creates an instance of `ActorRnnNetwork`.

  Args:
    input_tensor_spec: A nest of `tensor_spec.TensorSpec` representing the
      input observations.
    output_tensor_spec: A nest of `tensor_spec.BoundedTensorSpec` representing
      the actions.
    conv_layer_params: Optional list of convolution layer parameters, where
      each item is a length-three tuple indicating (filters, kernel_size,
      stride).
    input_fc_layer_params: Optional list of fully_connected parameters, where
      each item is the number of units in the layer. This is applied before
      the LSTM cell.
    lstm_size: An iterable of ints specifying the LSTM cell sizes to use.
    output_fc_layer_params: Optional list of fully_connected parameters, where
      each item is the number of units in the layer. This is applied after
      the LSTM cell.
    activation_fn: Activation function, e.g. tf.nn.relu or
      tf.keras.activations.relu.
    name: A string representing the name of the network.

  Raises:
    ValueError: If `input_tensor_spec` contains more than one observation.
  """
  if len(tf.nest.flatten(input_tensor_spec)) > 1:
    raise ValueError('Only a single observation is supported by this network.')

  input_layers = utils.mlp_layers(
      conv_layer_params,
      input_fc_layer_params,
      activation_fn=activation_fn,
      kernel_initializer=tf.compat.v1.keras.initializers.glorot_uniform(),
      name='input_mlp')

  # Create the RNN cell.
  if len(lstm_size) == 1:
    cell = tf.keras.layers.LSTMCell(lstm_size[0])
  else:
    cell = tf.keras.layers.StackedRNNCells(
        [tf.keras.layers.LSTMCell(size) for size in lstm_size])

  state_spec = tf.nest.map_structure(
      functools.partial(
          tensor_spec.TensorSpec, dtype=tf.float32,
          name='network_state_spec'),
      list(cell.state_size))

  output_layers = utils.mlp_layers(
      fc_layer_params=output_fc_layer_params, name='output')

  flat_action_spec = tf.nest.flatten(output_tensor_spec)
  action_layers = [
      tf.keras.layers.Dense(
          single_action_spec.shape.num_elements(),
          activation=tf.keras.activations.tanh,
          kernel_initializer=tf.keras.initializers.RandomUniform(
              minval=-0.003, maxval=0.003),
          name='action')
      for single_action_spec in flat_action_spec
  ]

  super(ActorRnnNetwork, self).__init__(
      input_tensor_spec=input_tensor_spec,
      state_spec=state_spec,
      name=name)

  self._output_tensor_spec = output_tensor_spec
  self._flat_action_spec = flat_action_spec
  self._conv_layer_params = conv_layer_params
  self._input_layers = input_layers
  self._dynamic_unroll = dynamic_unroll_layer.DynamicUnroll(cell)
  self._output_layers = output_layers
  self._action_layers = action_layers
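# A minimal construction sketch for ActorRnnNetwork. The observation and
# action specs below are hypothetical; tensor_spec is assumed imported from
# tf_agents.specs, as elsewhere in this module.
def _example_actor_rnn_network():
  observation_spec = tensor_spec.TensorSpec((8,), tf.float32)
  action_spec = tensor_spec.BoundedTensorSpec(
      (2,), tf.float32, minimum=-1.0, maximum=1.0)
  # The tanh action head maps into [-1, 1], matching the bounded spec above.
  return ActorRnnNetwork(
      input_tensor_spec=observation_spec,
      output_tensor_spec=action_spec,
      lstm_size=(40,))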
def testMixOfNonRecurrentAndRecurrent(self):

  def reshape_inner_dims(tensor, ndims, new_inner_shape):
    """Reshapes tensor to: shape(tensor)[:-ndims] + new_inner_shape."""
    tensor_shape = tf.shape(tensor)
    new_shape = tf.concat((tensor_shape[:-ndims], new_inner_shape), axis=0)
    new_tensor = tf.reshape(tensor, new_shape)
    new_tensor.set_shape(tensor.shape[:-ndims] + new_inner_shape)
    return new_tensor

  sequential = sequential_lib.Sequential(
      [
          tf.keras.layers.Dense(2),
          tf.keras.layers.LSTM(2, return_state=True, return_sequences=True),
          tf.keras.layers.RNN(
              tf.keras.layers.StackedRNNCells([
                  tf.keras.layers.LSTMCell(1),
                  tf.keras.layers.LSTMCell(32),
              ]),
              return_state=True,
              return_sequences=True,
          ),
          tf.keras.layers.Lambda(
              lambda t: reshape_inner_dims(t, 1, [4, 4, 2])),
          tf.keras.layers.Conv2D(2, 3),
          tf.keras.layers.Lambda(lambda t: reshape_inner_dims(t, 3, [8])),
          tf.keras.layers.GRU(2, return_state=True, return_sequences=True),
          dynamic_unroll_layer.DynamicUnroll(tf.keras.layers.LSTMCell(2)),
      ],
      input_spec=tf.TensorSpec((3,), tf.float32))

  self.assertEqual(sequential.input_tensor_spec,
                   tf.TensorSpec((3,), tf.float32))

  output_spec = sequential.create_variables()
  self.assertEqual(output_spec, tf.TensorSpec((2,), dtype=tf.float32))

  tf.nest.map_structure(
      self.assertEqual,
      sequential.state_spec,
      (
          [  # LSTM
              tf.TensorSpec((2,), tf.float32),
              tf.TensorSpec((2,), tf.float32),
          ],
          (  # RNN(StackedRNNCells)
              [
                  tf.TensorSpec((1,), tf.float32),
                  tf.TensorSpec((1,), tf.float32),
              ],
              [
                  tf.TensorSpec((32,), tf.float32),
                  tf.TensorSpec((32,), tf.float32),
              ],
          ),
          # GRU
          tf.TensorSpec((2,), tf.float32),
          [  # DynamicUnroll
              tf.TensorSpec((2,), tf.float32),
              tf.TensorSpec((2,), tf.float32),
          ]))

  inputs = tf.ones((8, 10, 3), dtype=tf.float32)
  outputs, _ = sequential(inputs)
  self.assertEqual(outputs.shape, tf.TensorShape([8, 10, 2]))
def __init__(
    self,
    input_tensor_spec,
    preprocessing_layers=None,
    preprocessing_combiner=None,
    conv_layer_params=None,
    input_fc_layer_params=(75, 40),
    lstm_size=None,
    output_fc_layer_params=(75, 40),
    activation_fn=tf.keras.activations.relu,
    rnn_construction_fn=None,
    rnn_construction_kwargs=None,
    dtype=tf.float32,
    name='LSTMEncodingNetwork',
):
  """Creates an instance of `LSTMEncodingNetwork`.

  Input preprocessing is possible via `preprocessing_layers` and
  `preprocessing_combiner` layers. If the `preprocessing_layers` nest is
  shallower than `input_tensor_spec`, then the layers will get the subnests.
  For example, if:

  ```python
  input_tensor_spec = ([TensorSpec(3)] * 2, [TensorSpec(3)] * 5)
  preprocessing_layers = (Layer1(), Layer2())
  ```

  then preprocessing will call:

  ```python
  preprocessed = [preprocessing_layers[0](observations[0]),
                  preprocessing_layers[1](observations[1])]
  ```

  However, if

  ```python
  preprocessing_layers = ([Layer1() for _ in range(2)],
                          [Layer2() for _ in range(5)])
  ```

  then preprocessing will call:

  ```python
  preprocessed = [
      layer(obs) for layer, obs in zip(flatten(preprocessing_layers),
                                       flatten(observations))
  ]
  ```

  Args:
    input_tensor_spec: A nest of `tensor_spec.TensorSpec` representing the
      observations.
    preprocessing_layers: (Optional.) A nest of `tf.keras.layers.Layer`
      representing preprocessing for the different observations. None of
      these layers may be already built.
    preprocessing_combiner: (Optional.) A keras layer that takes a flat list
      of tensors and combines them. Good options include
      `tf.keras.layers.Add` and `tf.keras.layers.Concatenate(axis=-1)`. This
      layer must not be already built.
    conv_layer_params: Optional list of convolution layer parameters, where
      each item is a length-three tuple indicating (filters, kernel_size,
      stride).
    input_fc_layer_params: Optional list of fully connected parameters, where
      each item is the number of units in the layer. These feed into the
      recurrent layer.
    lstm_size: An iterable of ints specifying the LSTM cell sizes to use.
    output_fc_layer_params: Optional list of fully connected parameters, where
      each item is the number of units in the layer. These are applied on top
      of the recurrent layer.
    activation_fn: Activation function, e.g. tf.keras.activations.relu.
    rnn_construction_fn: (Optional.) Alternate RNN construction function, e.g.
      tf.keras.layers.LSTM or tf.keras.layers.CuDNNLSTM. It is invalid to
      provide both rnn_construction_fn and lstm_size.
    rnn_construction_kwargs: (Optional.) Dictionary of arguments to pass to
      rnn_construction_fn. The RNN will be constructed via:

      ```
      rnn_layer = rnn_construction_fn(**rnn_construction_kwargs)
      ```
    dtype: The dtype to use by the convolution, LSTM, and fully connected
      layers.
    name: A string representing the name of the network.

  Raises:
    ValueError: If any of `preprocessing_layers` is already built.
    ValueError: If `preprocessing_combiner` is already built.
    ValueError: If neither `lstm_size` nor `rnn_construction_fn` are provided.
    ValueError: If both `lstm_size` and `rnn_construction_fn` are provided.
  """
  if lstm_size is None and rnn_construction_fn is None:
    raise ValueError(
        'Need to provide either custom rnn_construction_fn or lstm_size.')
  if lstm_size and rnn_construction_fn:
    raise ValueError(
        'Cannot provide both custom rnn_construction_fn and lstm_size.')

  kernel_initializer = tf.compat.v1.variance_scaling_initializer(
      scale=2.0, mode='fan_in', distribution='truncated_normal')

  input_encoder = encoding_network.EncodingNetwork(
      input_tensor_spec,
      preprocessing_layers=preprocessing_layers,
      preprocessing_combiner=preprocessing_combiner,
      conv_layer_params=conv_layer_params,
      fc_layer_params=input_fc_layer_params,
      activation_fn=activation_fn,
      kernel_initializer=kernel_initializer,
      dtype=dtype)

  # Create the RNN: either from rnn_construction_fn, or as a DynamicUnroll
  # over (stacked) LSTM cells.
  if rnn_construction_fn:
    rnn_construction_kwargs = rnn_construction_kwargs or {}
    lstm_network = rnn_construction_fn(**rnn_construction_kwargs)
  else:
    if len(lstm_size) == 1:
      cell = tf.keras.layers.LSTMCell(
          lstm_size[0], dtype=dtype, implementation=KERAS_LSTM_FUSED)
    else:
      cell = tf.keras.layers.StackedRNNCells([
          tf.keras.layers.LSTMCell(
              size, dtype=dtype, implementation=KERAS_LSTM_FUSED)
          for size in lstm_size
      ])
    lstm_network = dynamic_unroll_layer.DynamicUnroll(cell)

  output_encoder = []
  if output_fc_layer_params:
    output_encoder = [
        tf.keras.layers.Dense(
            num_units,
            activation=activation_fn,
            kernel_initializer=kernel_initializer,
            dtype=dtype) for num_units in output_fc_layer_params
    ]

  counter = [-1]

  def create_spec(size):
    counter[0] += 1
    return tensor_spec.TensorSpec(
        size, dtype=dtype, name='network_state_%d' % counter[0])

  state_spec = tf.nest.map_structure(create_spec,
                                     lstm_network.cell.state_size)

  super(LSTMEncodingNetwork, self).__init__(
      input_tensor_spec=input_tensor_spec,
      state_spec=state_spec,
      name=name)

  self._conv_layer_params = conv_layer_params
  self._input_encoder = input_encoder
  self._lstm_network = lstm_network
  self._output_encoder = output_encoder
def __init__(self,
             input_tensor_spec,
             observation_conv_layer_params=None,
             observation_fc_layer_params=(200,),
             action_fc_layer_params=(200,),
             joint_fc_layer_params=(100,),
             lstm_size=None,
             output_fc_layer_params=(200, 100),
             activation_fn=tf.keras.activations.relu,
             kernel_initializer=None,
             last_kernel_initializer=None,
             rnn_construction_fn=None,
             rnn_construction_kwargs=None,
             name='CriticRnnNetwork'):
  """Creates an instance of `CriticRnnNetwork`.

  Args:
    input_tensor_spec: A tuple of (observation, action), each of type
      `tensor_spec.TensorSpec`, representing the inputs.
    observation_conv_layer_params: Optional list of convolution layer
      parameters to apply to the observations, where each item is a
      length-three tuple indicating (filters, kernel_size, stride).
    observation_fc_layer_params: Optional list of fully_connected parameters,
      where each item is the number of units in the layer. This is applied
      after the observation convolutional layers.
    action_fc_layer_params: Optional list of parameters for a fully_connected
      layer to apply to the actions, where each item is the number of units
      in the layer.
    joint_fc_layer_params: Optional list of parameters for a fully_connected
      layer to apply after merging observations and actions, where each item
      is the number of units in the layer.
    lstm_size: An iterable of ints specifying the LSTM cell sizes to use.
    output_fc_layer_params: Optional list of fully_connected parameters, where
      each item is the number of units in the layer. This is applied after
      the LSTM cell.
    activation_fn: Activation function, e.g. tf.nn.relu or
      tf.keras.activations.relu.
    kernel_initializer: Kernel initializer for all layers except for the value
      regression layer. If None, a VarianceScaling initializer will be used.
    last_kernel_initializer: Kernel initializer for the value regression
      layer. If None, a RandomUniform initializer will be used.
    rnn_construction_fn: (Optional.) Alternate RNN construction function, e.g.
      tf.keras.layers.LSTM or tf.keras.layers.CuDNNLSTM. It is invalid to
      provide both rnn_construction_fn and lstm_size.
    rnn_construction_kwargs: (Optional.) Dictionary of arguments to pass to
      rnn_construction_fn. The RNN will be constructed via:

      ```
      rnn_layer = rnn_construction_fn(**rnn_construction_kwargs)
      ```
    name: A string representing the name of the network.

  Raises:
    ValueError: If `observation_spec` or `action_spec` contains more than one
      item.
    ValueError: If neither `lstm_size` nor `rnn_construction_fn` are provided.
    ValueError: If both `lstm_size` and `rnn_construction_fn` are provided.
  """
  if lstm_size is None and rnn_construction_fn is None:
    raise ValueError(
        'Need to provide either custom rnn_construction_fn or lstm_size.')
  if lstm_size and rnn_construction_fn:
    raise ValueError(
        'Cannot provide both custom rnn_construction_fn and lstm_size.')

  observation_spec, action_spec = input_tensor_spec

  if len(tf.nest.flatten(observation_spec)) > 1:
    raise ValueError('Only a single observation is supported by this network.')
  if len(tf.nest.flatten(action_spec)) > 1:
    raise ValueError('Only a single action is supported by this network.')

  if kernel_initializer is None:
    kernel_initializer = tf.compat.v1.keras.initializers.VarianceScaling(
        scale=1. / 3., mode='fan_in', distribution='uniform')
  if last_kernel_initializer is None:
    last_kernel_initializer = tf.keras.initializers.RandomUniform(
        minval=-0.003, maxval=0.003)

  observation_layers = utils.mlp_layers(
      observation_conv_layer_params,
      observation_fc_layer_params,
      activation_fn=activation_fn,
      kernel_initializer=kernel_initializer,
      name='observation_encoding')

  action_layers = utils.mlp_layers(
      None,
      action_fc_layer_params,
      activation_fn=activation_fn,
      kernel_initializer=kernel_initializer,
      name='action_encoding')

  joint_layers = utils.mlp_layers(
      None,
      joint_fc_layer_params,
      activation_fn=activation_fn,
      kernel_initializer=kernel_initializer,
      name='joint_mlp')

  # Create the RNN: either from rnn_construction_fn, or as a DynamicUnroll
  # over (stacked) LSTM cells.
  if rnn_construction_fn:
    rnn_construction_kwargs = rnn_construction_kwargs or {}
    lstm_network = rnn_construction_fn(**rnn_construction_kwargs)
  else:
    if len(lstm_size) == 1:
      cell = tf.keras.layers.LSTMCell(lstm_size[0])
    else:
      cell = tf.keras.layers.StackedRNNCells(
          [tf.keras.layers.LSTMCell(size) for size in lstm_size])
    lstm_network = dynamic_unroll_layer.DynamicUnroll(cell)

  counter = [-1]

  def create_spec(size):
    counter[0] += 1
    return tensor_spec.TensorSpec(
        size, dtype=tf.float32, name='network_state_%d' % counter[0])

  state_spec = tf.nest.map_structure(create_spec,
                                     lstm_network.cell.state_size)

  output_layers = utils.mlp_layers(
      fc_layer_params=output_fc_layer_params, name='output')
  output_layers.append(
      tf.keras.layers.Dense(
          1,
          activation=None,
          kernel_initializer=last_kernel_initializer,
          name='value'))

  super(CriticRnnNetwork, self).__init__(
      input_tensor_spec=input_tensor_spec,
      state_spec=state_spec,
      name=name)

  self._observation_layers = observation_layers
  self._action_layers = action_layers
  self._joint_layers = joint_layers
  self._lstm_network = lstm_network
  self._output_layers = output_layers
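# A minimal construction sketch for CriticRnnNetwork. The specs and sizes
# are hypothetical; input_tensor_spec is the (observation, action) tuple
# this class expects.
def _example_critic_rnn_network():
  observation_spec = tensor_spec.TensorSpec((8,), tf.float32)
  action_spec = tensor_spec.TensorSpec((2,), tf.float32)
  # Observation and action encoders are merged, run through the joint MLP
  # and LSTM, then regressed to a single Q-value by the 'value' head.
  return CriticRnnNetwork(
      input_tensor_spec=(observation_spec, action_spec),
      lstm_size=(40,))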