Example 1: an attention cell that runs an IndRNN to produce a per-example convolution filter at each step, applies it to a fixed feature map, and returns the attention-pooled features concatenated with the step inputs.
import tensorflow as tf

# IndRNNCell is assumed to come from the reference TF implementation
# (ind_rnn_cell.py in https://github.com/batzner/indrnn); adjust the
# import path to wherever the class lives in your project.
from ind_rnn_cell import IndRNNCell


class AttentionCell(tf.nn.rnn_cell.RNNCell):
    def __init__(self, features, recurrent_max_abs):
        super(AttentionCell, self).__init__()

        self._in_channels = features.get_shape()[2].value  # DICT_SIZE
        self._features = features
        self._bias = self.add_variable("bias",
                                       shape=[1],
                                       initializer=tf.zeros_initializer())
        # Per-example 1-D conv filter: (batch, filter_width, in_channels, out_channels).
        self._filt_shape = (-1, 5, self._in_channels, 1)
        # The IndRNN emits filter_width * in_channels units per step, which
        # __call__ reshapes into that filter.
        self._indrnn = IndRNNCell(self._filt_shape[1] * self._in_channels,
                                  recurrent_max_abs=recurrent_max_abs)

    @property
    def state_size(self):
        return self._indrnn.state_size

    @property
    def output_size(self):
        return self._in_channels + 2

    def build(self, inputs_shape):
        self._indrnn.build(inputs_shape)

    def __call__(self, inputs, state, scope=None):
        filt, new_state = self._indrnn(inputs, state, scope)

        filt = tf.reshape(filt, self._filt_shape)
        # filt has shape (B, filter_width, in_channels, out_channels) = (B, 5, in_channels, 1)

        conv = batchwise_conv_2(self._features, filt)
        # conv has shape (B, width, out_channels)

        conv = tf.nn.relu(conv + self._bias)  # (B, width, 1)

        # TODO: try other methods for squashing or normalizing, etc
        attention = tf.nn.softmax(conv, axis=1)  # (B, width, 1)

        output = tf.multiply(attention, self._features)  # broadcasts to (B, width, dict_size)
        output = tf.reduce_mean(output, axis=1)  # (B, dict_size)

        output = tf.concat([output, inputs], axis=1)  # (B, dict_size + 2)

        return output, new_state
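
The batchwise_conv_2 helper used above is not shown in the example. Below is a minimal sketch, assuming it convolves each batch element with that element's own filter (features of shape (B, width, in_channels) against filters of shape (B, filter_width, in_channels, out_channels)); the map_fn formulation is illustrative, not the original implementation:

def batchwise_conv_2(features, filters):
    """Convolve each batch element with its own 1-D filter.

    features: (B, width, in_channels)
    filters:  (B, filter_width, in_channels, out_channels)
    returns:  (B, width, out_channels)
    """
    def conv_single(args):
        x, f = args
        x = tf.expand_dims(x, 0)                          # (1, width, in_channels)
        y = tf.nn.conv1d(x, f, stride=1, padding="SAME")  # (1, width, out_channels)
        return tf.squeeze(y, axis=0)                      # (width, out_channels)

    return tf.map_fn(conv_single, (features, filters), dtype=features.dtype)

Mapping over the batch is the simplest correct formulation; if the per-element loop becomes a bottleneck, the same result can be obtained with a depthwise-convolution reshaping trick that processes the whole batch in one convolution.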
Example 2: a thin IndRNN wrapper that adds a residual connection by zero-padding the step inputs up to the cell's output size.
import tensorflow as tf

from ind_rnn_cell import IndRNNCell  # same assumed import path as in Example 1


class IndCatCell(tf.nn.rnn_cell.RNNCell):
    def __init__(self, num_units, recurrent_max_abs):
        super(IndCatCell, self).__init__()
        self._indrnn = IndRNNCell(
            num_units,
            recurrent_max_abs=recurrent_max_abs)

    @property
    def state_size(self):
        return self._indrnn.state_size

    @property
    def output_size(self):
        return self._indrnn.output_size

    def build(self, inputs_shape):
        self._indrnn.build(inputs_shape)

    def __call__(self, inputs, state, scope=None):
        out, state = self._indrnn(inputs, state, scope)
        # Residual connection: zero-pad the inputs to the cell's output size
        # and add them to the IndRNN output.
        pad_size = self._indrnn.output_size - tf.shape(inputs)[1]
        out = out + tf.pad(inputs, [[0, 0], [0, pad_size]])
        return out, state
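
Both cells plug into the standard TF1 RNN machinery. A hypothetical usage sketch for AttentionCell follows; the shapes, DICT_SIZE, and RECURRENT_MAX values are illustrative placeholders, not taken from the original code:

# Illustrative shapes and constants only.
batch_size, seq_len, width, DICT_SIZE = 32, 100, 50, 64
RECURRENT_MAX = pow(2, 1 / seq_len)  # recurrent-weight clip suggested by the IndRNN paper

features = tf.placeholder(tf.float32, [batch_size, width, DICT_SIZE])
steps = tf.placeholder(tf.float32, [batch_size, seq_len, 2])  # 2-dim input per step

cell = AttentionCell(features, recurrent_max_abs=RECURRENT_MAX)
outputs, final_state = tf.nn.dynamic_rnn(cell, steps, dtype=tf.float32)
# outputs: (batch_size, seq_len, DICT_SIZE + 2), matching the cell's output_size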