def call(self, inputs, training=None, **kwargs):
        """Create the content-stream and query-stream attention masks.

        `inputs` is a pair `(tokens, memory)`. Only the first two dimensions
        of `tokens` are read (used as batch size and sequence length), and
        `memory` is indexed as a rank-3 tensor whose second axis is taken as
        the memory length.

        Position indices are compared against a per-sample ordering:
        the (optionally shuffled) indices in the training phase, or a
        fallback chosen by `self.directional` / `self.enabled` otherwise.
        The query mask equals the content mask with its diagonal zeroed.

        Returns a list `[content_mask, query_mask]`; each is built as
        (batch, memory_len + seq_len, seq_len) and then has its last two
        axes swapped before being returned.
        """
        tokens, memory = inputs
        batch_size = K.shape(tokens)[0]
        seq_len = K.shape(tokens)[1]

        # All-ones block for the memory slots, one column per input position
        # (presumably so cached memory is never hidden -- confirm downstream).
        memory_ones = K.ones_like(memory[:, :, :1], dtype=K.floatx())
        mem_mask = K.tile(memory_ones, [1, 1, seq_len])

        # Position indices laid out as (seq_len, batch_size).
        index_column = K.expand_dims(K.arange(0, seq_len), axis=-1)
        positions = K.tile(index_column, [1, batch_size])
        permuted = random_shuffle(positions) if self.enabled else positions

        if self.directional:
            order = K.in_train_phase(permuted, positions, training)
        elif self.enabled:
            # Outside training every value is shifted past the largest index,
            # so the `<=` comparison below holds everywhere.
            order = K.in_train_phase(permuted, positions + seq_len, training)
        else:
            order = positions + seq_len

        # (batch, seq_len, 1) <= (batch, 1, seq_len) -> (batch, seq_len, seq_len)
        lhs = K.expand_dims(K.permute_dimensions(positions, [1, 0]), axis=-1)
        rhs = K.expand_dims(K.permute_dimensions(order, [1, 0]), axis=1)
        content_mask = K.cast(lhs <= rhs, dtype=K.floatx())

        # The query mask is the content mask with the diagonal removed.
        index = K.arange(0, seq_len)
        diagonal = K.equal(
            K.expand_dims(index, axis=0),
            K.expand_dims(index, axis=-1),
        )
        diagonal = K.expand_dims(K.cast(diagonal, dtype=K.floatx()), axis=0)
        query_mask = content_mask * (1.0 - diagonal)

        # Prepend the memory block, then swap the last two axes of each mask.
        return [
            K.permute_dimensions(
                K.concatenate([mem_mask, stream_mask], axis=1), [0, 2, 1])
            for stream_mask in (content_mask, query_mask)
        ]
Beispiel #2
0
 def _reshape_to_batches(self, x):
     """Fold the attention heads of `x` into the batch axis.

     `x` is reshaped to (batch, seq_len, num_head, units_head), the head
     axis is moved in front of the sequence axis, and the result is
     flattened to (batch * num_head, seq_len, units_head).

     The feature size of `x` is not read here; the reshape relies on it
     being `self.num_head * self.units_head`.
     """
     input_shape = K.shape(x)
     batch_size, seq_len = input_shape[0], input_shape[1]
     x = K.reshape(x, (batch_size, seq_len, self.num_head, self.units_head))
     # Heads must sit next to the batch axis before the two are merged.
     x = K.permute_dimensions(x, [0, 2, 1, 3])
     return K.reshape(
         x, (batch_size * self.num_head, seq_len, self.units_head))
Beispiel #3
0
 def _reshape_to_batches(x, head_num):
     """Split the feature axis of `x` into heads and merge them into the batch.

     The three leading dynamic dimensions of `x` are read as
     (batch, seq_len, feature_dim). The result has shape
     (batch * head_num, seq_len, feature_dim // head_num).
     """
     dims = K.shape(x)
     batch_size = dims[0]
     seq_len = dims[1]
     feature_dim = dims[2]
     head_dim = feature_dim // head_num

     per_head = K.reshape(x, (batch_size, seq_len, head_num, head_dim))
     # Bring the head axis next to the batch axis so the two can be merged.
     heads_first = K.permute_dimensions(per_head, [0, 2, 1, 3])
     merged_shape = (batch_size * head_num, seq_len, head_dim)
     return K.reshape(heads_first, merged_shape)
Beispiel #4
0
 def _attention_regularizer(self, attention):
     """Return the penalty `weight * sum((A A^T - I)^2) / batch`.

     `A` is `attention` (indexed as a rank-3 tensor), `A^T` swaps its last
     two axes, and `I` is an identity matrix sized by the last axis of
     `attention`. The sum runs over every element and is divided by the
     size of the first axis.
     """
     num_samples = K.cast(K.shape(attention)[0], K.floatx())
     length = K.shape(attention)[-1]

     # Identity matrix built by comparing a row of indices with a column.
     positions = K.arange(0, length)
     as_row = K.expand_dims(positions, axis=0)
     as_column = K.expand_dims(positions, axis=-1)
     identity = K.cast(K.equal(as_row, as_column), K.floatx())

     transposed = K.permute_dimensions(attention, (0, 2, 1))
     gram = K.batch_dot(attention, transposed)
     penalty = K.sum(K.square(gram - identity))
     return self.attention_regularizer_weight * penalty / num_samples
Beispiel #5
0
    def call(self, inputs, mask=None, **kwargs):
        """Apply self-attention along axis 1 of `inputs`.

        Emission scores come from the additive or multiplicative scorer
        selected by `self.attention_type`. They are optionally passed through
        `self.attention_activation`, exponentiated, limited to a local
        window when `self.attention_width` is set, masked, and normalised
        into attention weights that are used to average `inputs`.

        Args:
            inputs: Input tensor; axis 1 is used as the sequence length
                (presumably (batch, seq_len, features) -- confirm with caller).
            mask: Optional mask; it is cast to float and given a trailing
                axis (assumes shape (batch, seq_len) -- confirm with caller).

        Returns:
            The attended tensor `v`, or `[v, a]` with the attention weights
            when `self.return_attention` is set.

        Raises:
            ValueError: If `self.attention_type` is neither the additive nor
                the multiplicative type.
        """
        input_len = K.shape(inputs)[1]

        if self.attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
            e = self._call_additive_emission(inputs)
        elif self.attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
            e = self._call_multiplicative_emission(inputs)
        else:
            # Without this branch `e` would be unbound and the failure would
            # surface later as an opaque UnboundLocalError.
            raise ValueError(
                'Unsupported attention type: {!r}'.format(self.attention_type))

        if self.attention_activation is not None:
            e = self.attention_activation(e)
        # Subtract the row maximum before exponentiating to avoid overflow.
        e = K.exp(e - K.max(e, axis=-1, keepdims=True))
        if self.attention_width is not None:
            # Keep only scores whose key index lies in [lower, lower + width).
            if self.history_only:
                # Window ends at the current position.
                lower = K.arange(0, input_len) - (self.attention_width - 1)
            else:
                # Window starts width // 2 positions before the current one.
                lower = K.arange(0, input_len) - self.attention_width // 2
            lower = K.expand_dims(lower, axis=-1)
            upper = lower + self.attention_width
            indices = K.expand_dims(K.arange(0, input_len), axis=0)
            e = e * K.cast(lower <= indices, K.floatx()) * K.cast(indices < upper, K.floatx())
        if mask is not None:
            mask = K.cast(mask, K.floatx())
            mask = K.expand_dims(mask)
            # Apply the mask along axis 1, transpose, apply it again along the
            # other axis, then transpose back.
            e = K.permute_dimensions(K.permute_dimensions(e * mask, (0, 2, 1)) * mask, (0, 2, 1))

        # a_{t} = \text{softmax}(e_t)
        s = K.sum(e, axis=-1, keepdims=True)
        # Epsilon keeps the division finite when a whole row was zeroed out.
        a = e / (s + K.epsilon())

        # l_t = \sum_{t'} a_{t, t'} x_{t'}
        v = K.batch_dot(a, inputs)
        if self.attention_regularizer_weight > 0.0:
            self.add_loss(self._attention_regularizer(a))

        if self.return_attention:
            return [v, a]
        return v
Beispiel #6
0
 def _call_multiplicative_emission(self, inputs):
     """Compute multiplicative emission scores.

     Implements e_{t, t'} = x_t^T W_a x_{t'} + b_a, where the bias term
     `self.ba[0]` is added only when `self.use_attention_bias` is set.
     """
     projected = K.dot(inputs, self.Wa)
     transposed = K.permute_dimensions(inputs, (0, 2, 1))
     scores = K.batch_dot(projected, transposed)
     if not self.use_attention_bias:
         return scores
     return scores + self.ba[0]
Beispiel #7
0
    def call(self, inputs, mask=None, training=None):
        """Multi-head attention with context, relative and segment score terms.

        `inputs` is a 10-element sequence unpacked as (inputs, content,
        memories, segment_mat, segment_embed, relatives, bias_context,
        bias_relative, bias_segment, permutation). Queries are projected from
        `inputs`; keys and values are projected from `memories` concatenated
        with `content` along axis 1.

        Three score tensors are computed (context, relative and segment),
        summed and divided by sqrt(units_head). The scores are exponentiated,
        multiplied by the permutation mask and, when given, by the padding
        mask, then normalised. Optional dropout is applied to the weights
        before they are used to average the values, and the result goes
        through the output projection.

        `mask`, when not None, is indexed as a sequence and only `mask[0]`
        is used. Returns the projected attention output; the inline shape
        comment gives it as (batch, seq_len, units).
        """
        (inputs, content, memories, segment_mat, segment_embed, relatives,
         bias_context, bias_relative, bias_segment, permutation) = inputs
        full = K.concatenate([memories, content],
                             axis=1)  # (batch, prev_len + seq_len, units)

        # The fused kernel holds five `units`-wide column blocks:
        # query | key + value (two blocks) | relative | output.
        kernel_q = self.kernel[:, :self.units]
        kernel_kv = self.kernel[:, self.units:self.units * 3]
        kernel_r = self.kernel[:, self.units * 3:self.units * 4]
        kernel_o = self.kernel[:, self.units * 4:self.units * 5]

        # The fused bias is sliced with the same layout as the kernel.
        bias_q, bias_kv, bias_r, bias_o = (None, ) * 4
        if self.use_bias:
            bias_q = self.bias[:self.units]
            bias_kv = self.bias[self.units:self.units * 3]
            bias_r = self.bias[self.units * 3:self.units * 4]
            bias_o = self.bias[self.units * 4:self.units * 5]

        w_q = K.dot(inputs, kernel_q)  # (batch, seq_len, units)
        w_kv = K.dot(full, kernel_kv)  # (batch, prev_len + seq_len, units * 2)
        w_r = K.dot(relatives, kernel_r)  # (batch, prev_len + seq_len, units)
        if self.use_bias:
            w_q = K.bias_add(w_q, bias_q)
            w_kv = K.bias_add(w_kv, bias_kv)
            w_r = K.bias_add(w_r, bias_r)
        if self.activation is not None:
            w_q = self.activation(w_q)
            w_kv = self.activation(w_kv)
            w_r = self.activation(w_r)

        # Split the joint key/value projection into its two halves.
        w_k = w_kv[:, :, :self.units]  # (batch, prev_len + seq_len, units)
        w_v = w_kv[:, :, self.units:]  # (batch, prev_len + seq_len, units)
        batch_size, q_len, k_len = K.shape(inputs)[0], K.shape(
            w_q)[1], K.shape(w_k)[1]

        # Context term: (query + bias_context) against the keys.
        w_qc = K.bias_add(w_q, bias_context)
        w_qc = self._reshape_to_batches(
            w_qc)  # (batch * n_head, seq_len, units_head)
        w_k = self._reshape_to_batches(
            w_k)  # (batch * n_head, prev_len + seq_len, units_head)
        a_context = K.batch_dot(
            w_qc, w_k, axes=2)  # (batch * n_head, seq_len, prev_len + seq_len)

        # Relative term: (query + bias_relative) against the projected
        # `relatives`, then passed through `_relative_shift` (defined
        # elsewhere) with the key length of the context term.
        w_qr = K.bias_add(w_q, bias_relative)
        w_qr = self._reshape_to_batches(
            w_qr)  # (batch * n_head, seq_len, units_head)
        w_r = self._reshape_to_batches(
            w_r)  # (batch * n_head, prev_len + seq_len, units_head)
        a_relative = K.batch_dot(
            w_qr, w_r, axes=2)  # (batch * n_head, seq_len, prev_len + seq_len)
        a_relative = self._relative_shift(  # (batch * n_head, seq_len, prev_len + seq_len)
            a_relative,
            key_len_expected=K.shape(a_context)[-1],
        )

        # Segment term: (query + bias_segment) is scored against the two
        # columns of the reshaped `segment_embed`, and `segment_mat` then
        # combines those two scores for every query/key pair.
        w_qs = K.bias_add(w_q, bias_segment)
        w_qs = K.reshape(w_qs, (-1, q_len, self.num_head, self.units_head))
        w_qs = K.permute_dimensions(
            w_qs, (2, 0, 1, 3))  # (n_head, batch, seq_len, units_head)
        segment_embed = K.reshape(K.transpose(segment_embed),
                                  (self.num_head, 1, self.units_head, 2))
        segment_embed = K.tile(segment_embed, (1, batch_size, 1, 1))
        a_segment = K.batch_dot(w_qs, segment_embed,
                                axes=(3, 2))  # (n_head, batch, seq_len, 2)
        a_segment = K.permute_dimensions(
            a_segment, (1, 2, 3, 0))  # (batch, seq_len, 2, n_head)
        a_segment = K.batch_dot(
            segment_mat, a_segment,
            axes=(3, 2))  # (batch, seq_len, prev_len + seq_len, n_head)
        # Move the head axis next to batch and fold it in to match the other terms.
        a_segment = K.reshape(K.permute_dimensions(a_segment, (0, 3, 1, 2)),
                              (-1, q_len, k_len))

        # Sum the three terms and scale by sqrt(units_head); the row maximum
        # is subtracted before exponentiating to avoid overflow.
        att = (a_context + a_relative + a_segment) / K.sqrt(
            K.constant(self.units_head, dtype=K.floatx()))
        exp = K.exp(att - K.max(att, axis=-1, keepdims=True))

        # Repeat the permutation mask once per head and fold the head axis
        # into the batch axis so it lines up with `exp`.
        permutation = K.tile(K.expand_dims(permutation, axis=1),
                             [1, self.num_head, 1, 1])
        permutation = K.reshape(permutation, (-1, q_len, k_len))
        exp *= permutation
        if mask is not None and mask[0] is not None:
            mask = K.cast(mask[0], K.floatx())
            # Memory positions get a mask of ones prepended along the key axis.
            mask = K.concatenate([K.ones_like(memories[:, :, 0]), mask],
                                 axis=1)
            # `_reshape_mask` is defined elsewhere; presumably it repeats the
            # mask per head -- TODO confirm.
            exp *= K.expand_dims(self._reshape_mask(mask), axis=1)

        # Normalise; epsilon keeps the division finite when a row is all zeros.
        att = exp / (K.sum(exp, axis=-1, keepdims=True) + K.epsilon())
        if self.att_drop_layer is not None:
            att = self.att_drop_layer(att, training=training)
        w_v = self._reshape_to_batches(
            w_v)  # (batch * n_head, prev_len + seq_len, units_head)
        w_o = K.batch_dot(att, w_v)  # (batch * n_head, seq_len, units_head)

        # Undo the head folding and apply the output projection.
        w_o = self._reshape_from_batches(w_o)  # (batch, seq_len, units)
        w_o = K.dot(w_o, kernel_o)  # (batch, seq_len, units)
        if self.use_bias:
            w_o = K.bias_add(w_o, bias_o)
        if self.activation is not None:
            w_o = self.activation(w_o)

        if TF_KERAS:
            # Add shape information to tensor when using `tf.keras`
            input_shape = K.int_shape(inputs)
            if input_shape[1] is not None:
                w_o = K.reshape(w_o, (-1, ) + input_shape[1:])
        return w_o