import tensorflow as tf

# IndRNNCell (the independently recurrent RNN cell) and the batchwise_conv_2
# helper are assumed to be defined or imported elsewhere in this module.


class AttentionCell(tf.nn.rnn_cell.RNNCell):
    def __init__(self, features, recurrent_max_abs):
        super(AttentionCell, self).__init__()
        self._in_channels = features.get_shape()[2].value  # DICT_SIZE
        self._features = features
        self._bias = self.add_variable("bias", shape=[1],
                                       initializer=tf.zeros_initializer())
        self._filt_shape = (-1, 5, self._in_channels, 1)
        self._indrnn = IndRNNCell(self._filt_shape[1] * self._in_channels,
                                  recurrent_max_abs=recurrent_max_abs)

    @property
    def state_size(self):
        return self._indrnn.state_size

    @property
    def output_size(self):
        return self._in_channels + 2

    def build(self, inputs_shape):
        self._indrnn.build(inputs_shape)

    def __call__(self, inputs, state, scope=None):
        # The IndRNN output is reshaped into a per-example convolution filter.
        filt, new_state = self._indrnn(inputs, state, scope)
        filt = tf.reshape(filt, self._filt_shape)
        # filt has shape (B, filter_width, in_channels, out_channels)
        conv = batchwise_conv_2(self._features, filt)
        # conv has shape (B, width, out_channels)
        conv = tf.nn.relu(conv + self._bias)             # (B, width, 1)
        # TODO: try other methods for squashing or normalizing, etc.
        attention = tf.nn.softmax(conv, axis=1)          # (B, width, 1)
        # Weight the features by the attention scores so the reduced output has
        # DICT_SIZE channels, matching the declared output_size of DICT_SIZE + 2.
        output = tf.multiply(attention, self._features)  # (B, width, dict_size)
        output = tf.reduce_mean(output, axis=1)          # (B, dict_size)
        output = tf.concat([output, inputs], axis=1)     # (B, dict_size + 2)
        return output, new_state
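# A minimal usage sketch, not part of the original code: AttentionCell driven by
# tf.nn.dynamic_rnn over a sequence of 2-d step inputs, attending over a fixed
# `features` tensor at every step. All sizes below (BATCH_SIZE, WIDTH, DICT_SIZE,
# TIME_STEPS) and the recurrent_max value are illustrative assumptions.
def _attention_cell_usage_sketch():
    BATCH_SIZE, WIDTH, DICT_SIZE, TIME_STEPS = 32, 20, 64, 10
    recurrent_max = pow(2.0, 1.0 / TIME_STEPS)  # common IndRNN clipping choice

    # Features attended over at every step: (B, width, DICT_SIZE).
    features = tf.placeholder(tf.float32, [BATCH_SIZE, WIDTH, DICT_SIZE])
    # Per-step 2-d inputs, matching the "+ 2" in AttentionCell.output_size.
    steps = tf.placeholder(tf.float32, [BATCH_SIZE, TIME_STEPS, 2])

    cell = AttentionCell(features, recurrent_max_abs=recurrent_max)
    outputs, final_state = tf.nn.dynamic_rnn(cell, steps, dtype=tf.float32)
    # outputs: (BATCH_SIZE, TIME_STEPS, DICT_SIZE + 2)
    return outputs, final_state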
class IndCatCell(tf.nn.rnn_cell.RNNCell):
    def __init__(self, num_units, recurrent_max_abs):
        super(IndCatCell, self).__init__()
        self._indrnn = IndRNNCell(num_units,
                                  recurrent_max_abs=recurrent_max_abs)

    @property
    def state_size(self):
        return self._indrnn.state_size

    @property
    def output_size(self):
        return self._indrnn.output_size

    def build(self, inputs_shape):
        self._indrnn.build(inputs_shape)

    def __call__(self, inputs, state, scope=None):
        out, state = self._indrnn(inputs, state, scope)
        # Residual connection: pad the inputs up to the IndRNN output width and
        # add them to the recurrent output.
        pad_size = self._indrnn.output_size - tf.shape(inputs)[1]
        out = out + tf.pad(inputs, [[0, 0], [0, pad_size]])
        return out, state
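# A minimal usage sketch, not part of the original code: two IndCatCell layers
# stacked with MultiRNNCell and unrolled with tf.nn.dynamic_rnn. All sizes below
# are illustrative assumptions.
def _ind_cat_cell_usage_sketch():
    BATCH_SIZE, TIME_STEPS, INPUT_SIZE, NUM_UNITS = 32, 10, 2, 128
    recurrent_max = pow(2.0, 1.0 / TIME_STEPS)

    inputs = tf.placeholder(tf.float32, [BATCH_SIZE, TIME_STEPS, INPUT_SIZE])
    stacked = tf.nn.rnn_cell.MultiRNNCell(
        [IndCatCell(NUM_UNITS, recurrent_max_abs=recurrent_max)
         for _ in range(2)])
    outputs, final_state = tf.nn.dynamic_rnn(stacked, inputs, dtype=tf.float32)
    # outputs: (BATCH_SIZE, TIME_STEPS, NUM_UNITS)
    return outputs, final_state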