def conv_lstm_2d(inputs, state, output_channels,
                 kernel_size=5, name=None, spatial_dims=None):
  """2D Convolutional LSTM."""
  input_shape = common_layers.shape_list(inputs)
  batch_size, input_channels = input_shape[0], input_shape[-1]
  # Infer the cell's input shape ([height, width, channels]) either from the
  # shape of `inputs` or from the explicitly provided spatial_dims.
  if spatial_dims is None:
    input_shape = input_shape[1:]
  else:
    input_shape = spatial_dims + [input_channels]

  cell = contrib.rnn().ConvLSTMCell(
      2, input_shape, output_channels,
      [kernel_size, kernel_size], name=name)
  # Build a zero state on the first step so callers can pass state=None.
  if state is None:
    state = cell.zero_state(batch_size, tf.float32)
  outputs, new_state = cell(inputs, state)
  return outputs, new_state
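
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of applying conv_lstm_2d to a single frame batch, assuming
# TF1-style graph mode and that `common_layers` / `contrib` from tensor2tensor
# are importable. Tensor shapes and names below are made up for illustration.
def _example_conv_lstm_step():
  frame = tf.zeros([8, 64, 64, 3])  # [batch, height, width, channels]
  hidden, lstm_state = conv_lstm_2d(
      frame, state=None, output_channels=32, kernel_size=5,
      name="example_conv_lstm")
  # `hidden` is [8, 64, 64, 32]; `lstm_state` can be fed back on the next step.
  return hidden, lstm_state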
def build_controller(self):
  """Create the RNN and output projections for controlling the stack."""
  with tf.name_scope("controller"):
    self.rnn = contrib.rnn().BasicRNNCell(self._num_units)
    self._input_proj = self.add_variable(
        "input_projection_weights",
        shape=[self._embedding_size * (self._num_read_heads + 1),
               self._num_units],
        dtype=self.dtype)
    self._input_bias = self.add_variable(
        "input_projection_bias",
        shape=[self._num_units],
        initializer=tf.zeros_initializer(dtype=self.dtype))
    self._push_proj, self._push_bias = self.add_scalar_projection(
        "push", self._num_write_heads)
    self._pop_proj, self._pop_bias = self.add_scalar_projection(
        "pop", self._num_write_heads)
    self._value_proj, self._value_bias = self.add_vector_projection(
        "value", self._num_write_heads)
    self._output_proj, self._output_bias = self.add_vector_projection(
        "output", 1)
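
# --- Illustrative controller step (an assumption, not the class's actual code) ---
# A sketch of how the projections created in build_controller are typically
# applied on one time step of a neural-stack controller: the controller input
# is the input embedding concatenated with the read-head values (matching the
# input projection's shape above), push/pop strengths are squashed with a
# sigmoid, and the written value and emitted output with a tanh. The method
# name, argument names, and exact projection shapes below are hypothetical.
def _example_controller_step(self, embedded_input, read_values, prev_rnn_state):
  controller_input = tf.concat([embedded_input, read_values], axis=-1)
  rnn_input = tf.tanh(
      tf.nn.bias_add(tf.matmul(controller_input, self._input_proj),
                     self._input_bias))
  rnn_output, rnn_state = self.rnn(rnn_input, prev_rnn_state)
  push = tf.sigmoid(tf.nn.bias_add(
      tf.matmul(rnn_output, self._push_proj), self._push_bias))
  pop = tf.sigmoid(tf.nn.bias_add(
      tf.matmul(rnn_output, self._pop_proj), self._pop_bias))
  value = tf.tanh(tf.nn.bias_add(
      tf.matmul(rnn_output, self._value_proj), self._value_bias))
  output = tf.tanh(tf.nn.bias_add(
      tf.matmul(rnn_output, self._output_proj), self._output_bias))
  return push, pop, value, output, rnn_state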
def _rnn(self, inputs, name, initial_state=None, sequence_length=None):
  """A helper method to build tf.nn.dynamic_rnn.

  Args:
    inputs: The inputs to the RNN. A tensor of shape
        [batch_size, max_seq_length, embedding_size].
    name: A namespace for the RNN.
    initial_state: An optional initial state for the RNN.
    sequence_length: An optional sequence length for the RNN.

  Returns:
    The (outputs, final_state) pair returned by tf.nn.dynamic_rnn.
  """
  # Stack one cell per entry in hparams.controller_layer_sizes.
  layers = [self.cell(layer_size)
            for layer_size in self._hparams.controller_layer_sizes]
  with tf.variable_scope(name):
    return tf.nn.dynamic_rnn(
        contrib.rnn().MultiRNNCell(layers),
        inputs,
        initial_state=initial_state,
        sequence_length=sequence_length,
        dtype=tf.float32,
        time_major=False)
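
# --- Usage sketch (hypothetical, for illustration only) ---
# How another method of the same class might call the _rnn helper to encode a
# batch of embedded sequences. `embedded_tokens` and `lengths` are placeholder
# names, not identifiers from the original model.
def _example_encode(self, embedded_tokens, lengths):
  # embedded_tokens: [batch_size, max_seq_length, embedding_size]
  # lengths: [batch_size] true sequence lengths, so padded steps are ignored.
  outputs, final_state = self._rnn(
      embedded_tokens, "example_encoder", sequence_length=lengths)
  # `outputs` is [batch_size, max_seq_length, last_layer_size]; `final_state`
  # is a tuple with one entry per layer in hparams.controller_layer_sizes.
  return outputs, final_state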