Ejemplo n.º 1
0
    def call(self, inputs, **kwargs):
        """Compute adaptive-softmax probabilities over the whole vocabulary.

        ``inputs`` is a list laid out as
        ``[hidden, *bound_embeddings, *bound_projections]``: the
        ``cluster_num + 1`` entries after ``hidden`` are embedding matrices
        and everything after them are projection matrices.  A bound tensor
        is used only where the layer's own ``self.embeddings`` /
        ``self.projections`` entry is ``None``.

        With ``div_val == 1`` this is a single softmax over
        ``hidden . embedding^T``.  Otherwise cluster 0 (the head) is
        normalized jointly with ``cluster_num`` extra cluster logits, and
        every later cluster ``i`` gets its own softmax scaled by the head's
        probability column ``i - 1``.  The per-cluster results are
        concatenated on the last axis.
        """
        split = 1 + (self.cluster_num + 1)
        bound_embeddings = inputs[1:split]
        bound_projections = inputs[split:]
        hidden = inputs[0]

        def resolve(own, bound, index):
            # Prefer the layer's own weight; fall back to the bound input.
            if own is None:
                return bound[index]
            return own

        if self.div_val == 1:
            if self.embed_dim != self.input_dim or self.force_projection:
                projection = resolve(self.projections, bound_projections, 0)
                hidden = K.dot(hidden, K.transpose(projection))
            embedding = resolve(self.embeddings, bound_embeddings, 0)
            logits = K.dot(hidden, K.transpose(embedding))
            if self.use_bias:
                logits = K.bias_add(logits, self.biases)
            return keras.activations.softmax(logits, axis=-1)

        tail_probs = None
        cluster_outputs = []
        for index in range(len(self.cutoffs) - 1):
            cluster_dim = self.embed_dim // (self.div_val ** index)
            if cluster_dim != self.input_dim or self.force_projection:
                projection = resolve(self.projections[index], bound_projections, index)
                cluster_input = K.dot(hidden, K.transpose(projection))
            else:
                cluster_input = hidden
            embedding = resolve(self.embeddings[index], bound_embeddings, index)
            logits = K.dot(cluster_input, K.transpose(embedding))
            if self.use_bias:
                logits = K.bias_add(logits, self.biases[index])
            if tail_probs is None:
                # Head cluster: one softmax over its own logits plus the
                # cluster logits, then split the two parts apart again.
                cluster_logits = K.dot(cluster_input, self.kernel_cluster)
                if self.use_bias:
                    cluster_logits = K.bias_add(cluster_logits, self.bias_cluster)
                joint = keras.activations.softmax(
                    K.concatenate([logits, cluster_logits], axis=-1), axis=-1)
                tail_probs = joint[..., -self.cluster_num:]
                probs = joint[..., :-self.cluster_num]
            else:
                # Later cluster: its own softmax, scaled by the head's
                # probability column for this cluster.
                probs = keras.activations.softmax(logits, axis=-1)
                probs = probs * K.expand_dims(tail_probs[..., index - 1])
            cluster_outputs.append(probs)
        return K.concatenate(cluster_outputs, axis=-1)
Ejemplo n.º 2
0
 def call(self, inputs, mask=None):
     """Run multi-head attention.

     ``inputs`` is either a single tensor (used as query, key and value) or
     a list ``[query, key, value]``; ``mask`` follows the same convention.
     Query/key/value are projected with ``Wq``/``Wk``/``Wv`` (plus
     ``bq``/``bk``/``bv`` when ``use_bias`` and ``activation`` when set),
     split into ``head_num`` heads, passed through
     ``ScaledDotProductAttention``, merged back and projected with ``Wo``
     (plus ``bo`` and ``activation``).
     """
     if isinstance(inputs, list):
         query, key, value = inputs
     else:
         query = key = value = inputs
     if isinstance(mask, list):
         query_mask, key_mask, value_mask = mask
     else:
         query_mask = key_mask = value_mask = mask

     kernels = (self.Wq, self.Wk, self.Wv)
     projected = [K.dot(tensor, kernel)
                  for tensor, kernel in zip((query, key, value), kernels)]
     if self.use_bias:
         biases = (self.bq, self.bk, self.bv)
         projected = [tensor + bias for tensor, bias in zip(projected, biases)]
     if self.activation is not None:
         projected = [self.activation(tensor) for tensor in projected]
     query, key, value = projected

     attention = ScaledDotProductAttention(
         history_only=self.history_only,
         name='%s-Attention' % self.name,
     )
     y = attention(
         inputs=[self._reshape_to_batches(tensor, self.head_num)
                 for tensor in (query, key, value)],
         mask=[self._reshape_mask(head_mask, self.head_num)
               for head_mask in (query_mask, key_mask, value_mask)],
     )
     y = K.dot(self._reshape_from_batches(y, self.head_num), self.Wo)
     if self.use_bias:
         y = y + self.bo
     if self.activation is not None:
         y = self.activation(y)
     if TF_KERAS:
         # Re-attach static shape information to the tensor under `tf.keras`.
         shapes = [K.int_shape(tensor) for tensor in (query, key, value)]
         output_shape = self.compute_output_shape(shapes)
         if output_shape[1] is not None:
             y = K.reshape(y, (-1, ) + output_shape[1:])
     return y
Ejemplo n.º 3
0
    def _call_additive_emission(self, inputs):
        """Compute additive attention scores for every pair of time steps.

        h_{t, t'} = tanh(x_t^T W_t + x_{t'}^T W_x + b_h)
        e_{t, t'} = W_a h_{t, t'} + b_a

        ``b_h`` is added only when ``use_additive_bias`` is set and ``b_a``
        only when ``use_attention_bias`` is set.  The scores are reshaped to
        ``(batch_size, input_len, input_len)``, taking both sizes from the
        first two axes of ``inputs``.
        """
        shape = K.shape(inputs)
        batch_size, input_len = shape[0], shape[1]

        # Expanding on different axes makes the sum pair every t with every t'.
        query_part = K.expand_dims(K.dot(inputs, self.Wt), 2)
        key_part = K.expand_dims(K.dot(inputs, self.Wx), 1)
        pre_activation = query_part + key_part
        if self.use_additive_bias:
            pre_activation = pre_activation + self.bh
        hidden = K.tanh(pre_activation)

        scores = K.dot(hidden, self.Wa)
        if self.use_attention_bias:
            scores = scores + self.ba
        return K.reshape(scores, (batch_size, input_len, input_len))
Ejemplo n.º 4
0
 def call(self, inputs, mask=None, **kwargs):
     """Score features against every row of an embedding matrix.

     ``inputs`` is a pair ``[features, embeddings]``.  The logits are
     ``features . embeddings^T``, plus ``self.bias`` when ``use_bias``.
     Gradients into ``embeddings`` are blocked when ``stop_gradient`` is set.
     Returns the raw logits when ``return_logits`` is true, otherwise their
     softmax.  ``mask`` is accepted but not used.
     """
     features, embeddings = inputs
     if self.stop_gradient:
         embeddings = K.stop_gradient(embeddings)
     logits = K.dot(features, K.transpose(embeddings))
     if self.use_bias:
         logits = K.bias_add(logits, self.bias)
     return logits if self.return_logits else keras.activations.softmax(logits)
Ejemplo n.º 5
0
 def call(self, inputs, mask=None):
     """Sinusoidal (trigonometric) position embedding.

     Depending on ``self.mode``:
       * ``MODE_ADD``: positions ``0 .. seq_len - 1`` are generated, embedded
         with the width of the last axis of ``inputs``, and added to
         ``inputs``;
       * ``MODE_CONCAT``: positions are generated, embedded with
         ``self.output_dim`` features, and concatenated to ``inputs`` on the
         last axis;
       * otherwise: ``inputs`` themselves are used as positions (cast to
         floatx) and only the embedding is returned.

     Feature ``2i`` is ``sin(pos / 10000^(2i / d))`` and feature ``2i + 1``
     is ``cos(pos / 10000^(2i / d))``, where ``d`` is the output width.
     ``mask`` is accepted but not used.
     """
     input_shape = K.shape(inputs)
     if self.mode in (self.MODE_ADD, self.MODE_CONCAT):
         batch_size, seq_len = input_shape[0], input_shape[1]
         # MODE_ADD must match the input width; MODE_CONCAT uses output_dim.
         if self.mode == self.MODE_ADD:
             output_dim = input_shape[2]
         else:
             output_dim = self.output_dim
         # Positions 0 .. seq_len - 1, repeated for every batch row.
         pos_input = K.tile(K.expand_dims(K.arange(0, seq_len), axis=0),
                            [batch_size, 1])
     else:
         output_dim = self.output_dim
         pos_input = inputs
     if K.dtype(pos_input) != K.floatx():
         pos_input = K.cast(pos_input, K.floatx())

     # Exponents 2i / d for i in [0, d // 2). The sin (even) and cos (odd)
     # features share the same frequency, so the angles are computed once.
     evens = K.cast(K.arange(0, output_dim // 2) * 2, K.floatx())
     inv_freq = 1.0 / K.pow(10000.0, evens / K.cast(output_dim, K.floatx()))
     angles = K.dot(K.expand_dims(pos_input, -1), K.expand_dims(inv_freq, 0))
     # Stacking on a new last axis and flattening interleaves the features:
     # index 2i -> sin(angle_i), index 2i + 1 -> cos(angle_i).
     embd = K.stack([K.sin(angles), K.cos(angles)], axis=-1)
     output = K.reshape(embd, [-1, K.shape(inputs)[1], output_dim])
     if self.mode == self.MODE_CONCAT:
         output = K.concatenate([inputs, output], axis=-1)
     if self.mode == self.MODE_ADD:
         output += inputs
     return output
Ejemplo n.º 6
0
 def call(self, inputs, **kwargs):
     """Embed integer token ids with adaptive, cluster-dependent embeddings.

     With ``div_val == 1`` a single embedding matrix is gathered, then
     multiplied by ``projections`` when ``embed_dim != output_dim`` or
     ``force_projection`` is set.  Otherwise ids are bucketed by
     ``cutoffs``: cluster ``i`` covers ``[cutoffs[i], cutoffs[i + 1])`` and is
     projected with ``projections[i]`` when its width
     ``embed_dim // div_val**i`` differs from ``output_dim`` (or
     ``force_projection`` is set).  The clusters' contributions are summed,
     each masked to its own id range.

     Returns the embedded tensor, or a list ``[embedded, ...]`` that also
     carries the embedding and/or projection weights when
     ``return_embeddings`` / ``return_projections`` are set.
     """
     if K.dtype(inputs) != 'int32':
         inputs = K.cast(inputs, 'int32')

     if self.div_val == 1:
         embedded = K.gather(self.embeddings, inputs)
         if self.embed_dim != self.output_dim or self.force_projection:
             embedded = K.dot(embedded, self.projections)
     else:
         # Start from an all-zero (..., output_dim) float tensor and add each
         # cluster's masked contribution.
         zeros = K.expand_dims(K.zeros_like(inputs, dtype=K.floatx()), axis=-1)
         embedded = K.tile(zeros, (1, ) * K.ndim(inputs) + (self.output_dim, ))
         bounds = zip(self.cutoffs[:-1], self.cutoffs[1:])
         for cluster, (low, high) in enumerate(bounds):
             cluster_dim = self.embed_dim // (self.div_val**cluster)
             in_cluster = K.cast(low <= inputs, K.floatx()) * K.cast(
                 inputs < high, K.floatx())
             # Ids outside this cluster are mapped to row 0 so the gather is
             # valid; their vectors are zeroed by the mask afterwards.
             local_ids = (inputs - low) * K.cast(in_cluster, 'int32')
             vectors = K.gather(self.embeddings[cluster], local_ids)
             if cluster_dim != self.output_dim or self.force_projection:
                 vectors = K.dot(vectors, self.projections[cluster])
             embedded += vectors * K.expand_dims(in_cluster, axis=-1)

     if not (self.return_embeddings or self.return_projections):
         return embedded
     outputs = [embedded]
     if self.return_embeddings:
         if self.div_val == 1:
             outputs.append(self.embeddings)
         else:
             outputs.extend(K.identity(matrix) for matrix in self.embeddings)
     if self.return_projections:
         if self.div_val == 1:
             if self.projections is not None:
                 outputs.append(self.projections)
         else:
             outputs.extend(K.identity(matrix) for matrix in self.projections)
     return outputs
Ejemplo n.º 7
0
 def call(self, x, mask=None):
     """Attention-weighted average of ``x`` over its second axis.

     Each step gets a scalar logit ``x . W`` (plus ``b`` when ``use_bias``).
     A softmax over the steps produces weights used to average ``x``. When
     ``mask`` is given, the softmax is restricted to the unmasked steps.
     Returns the average, or ``[average, weights]`` when
     ``return_attention`` is set.
     """
     logits = K.dot(x, self.W)
     if self.use_bias:
         logits += self.b
     x_shape = K.shape(x)
     logits = K.reshape(logits, (x_shape[0], x_shape[1]))
     if mask is None:
         ai = K.exp(logits - K.max(logits, axis=-1, keepdims=True))
     else:
         # Assumes a 0/1 (boolean) mask, as Keras produces.
         mask = K.cast(mask, K.floatx())
         # Stabilize with the max over *unmasked* steps only. If a masked
         # (padding) step held the largest logit, subtracting it could
         # underflow every unmasked exp() to 0 and wipe out the weights.
         # Masked steps are replaced by the row minimum, so they never win.
         row_min = K.min(logits, axis=-1, keepdims=True)
         masked_max = K.max(logits * mask + (1.0 - mask) * row_min,
                            axis=-1, keepdims=True)
         # Scaling the exponent by the mask keeps masked steps at exp(0).
         # This cannot overflow, even for fully masked rows; the steps are
         # then zeroed.
         ai = K.exp((logits - masked_max) * mask) * mask
     att_weights = ai / (K.sum(ai, axis=1, keepdims=True) + K.epsilon())
     weighted_input = x * K.expand_dims(att_weights)
     result = K.sum(weighted_input, axis=1)
     if self.return_attention:
         return [result, att_weights]
     return result
Ejemplo n.º 8
0
 def build(self, input_shape):
     """Create the per-head copies of ``self.layer`` and this layer's weights.

     * Sets ``input_spec`` from ``input_shape``, one spec per shape when a
       list of shapes is given.
     * If ``self.layers`` is an empty list, builds ``self.layer`` and clones
       it ``layer_num`` times from its config, with names suffixed ``_1``,
       ``_2``, ...  Every head layer is then built.
     * When ``hidden_dim`` is set, adds a kernel ``W``, plus a bias ``b``
       when ``use_bias``.  Each head owns a ``hidden_dim``-wide slice of
       their columns.
     * For each ``(index, interval, weight)`` in ``reg_index`` /
       ``reg_slice`` / ``reg_weight``, adds ``weight * sum((V V^T - I)^2)``
       as a loss.  The rows of ``V`` are the flattened (optionally sliced)
       ``index``-th weight of every head.
     """
     if isinstance(input_shape, list):
         self.input_spec = [keras.engine.InputSpec(shape=shape)
                            for shape in input_shape]
     else:
         self.input_spec = keras.engine.InputSpec(shape=input_shape)
     if isinstance(self.layers, list) and not self.layers:
         self.layer.build(input_shape)
         config = self.layer.get_config()
         name = config['name']
         self.layers = []
         for i in range(self.layer_num):
             copied = copy.copy(config)
             copied['name'] = name + '_{}'.format(i + 1)
             self.layers.append(self.layer.__class__.from_config(copied))
     for layer in self.layers:
         layer.build(input_shape)
     if self.hidden_dim is not None:
         self.W = self.add_weight(
             shape=(int(input_shape[-1]), self.hidden_dim * self.layer_num),
             name='{}_W'.format(self.name),
             initializer=keras.initializers.get('uniform'),
         )
         if self.use_bias:
             self.b = self.add_weight(
                 shape=(self.hidden_dim * self.layer_num, ),
                 name='{}_b'.format(self.name),
                 initializer=keras.initializers.get('zeros'),
             )
     if self.reg_index:
         for index, interval, weight in zip(self.reg_index, self.reg_slice,
                                            self.reg_weight):
             if isinstance(interval, slice):
                 interval = (interval, )
             rows = []
             for layer in self.layers:
                 # Use the weight variable itself. `layer.weights` has the
                 # same order as get_weights(), but get_weights() returns numpy
                 # snapshots, and a penalty built from snapshots is a constant
                 # with no gradient.
                 head_weight = layer.weights[index]
                 if interval is not None:
                     head_weight = head_weight[interval]
                 rows.append(K.flatten(head_weight))
             rows = K.stack(rows)
             # Orthogonality penalty: push V V^T towards the identity.
             self.add_loss(weight * K.sum(
                 K.square(
                     K.dot(rows, K.transpose(rows)) -
                     K.eye(len(self.layers)))))
     super(MultiHead, self).build(input_shape)
Ejemplo n.º 9
0
 def call(self, inputs, training=None, mask=None):
     """Apply every head layer and stack the results on a new last axis.

     ``training`` is forwarded to the head layers' ``call`` when
     ``self.layer.call`` accepts it.  ``mask`` is forwarded likewise, but
     only when it is not ``None``.  Without ``hidden_dim``, every head sees
     ``inputs`` unchanged.  With it, head ``i`` sees ``inputs`` projected
     through columns ``[i * hidden_dim, (i + 1) * hidden_dim)`` of ``W``,
     plus the matching slice of ``b`` when ``use_bias``.
     """
     has_arg = keras.utils.generic_utils.has_arg
     call_kwargs = {}
     if has_arg(self.layer.call, 'training'):
         call_kwargs['training'] = training
     if has_arg(self.layer.call, 'mask') and mask is not None:
         call_kwargs['mask'] = mask

     def head_input(head):
         # Input tensor for the given head index.
         if self.hidden_dim is None:
             return inputs
         begin = head * self.hidden_dim
         end = begin + self.hidden_dim
         projected = K.dot(inputs, self.W[:, begin:end])
         if self.use_bias:
             projected += self.b[begin:end]
         return projected

     stacked = [
         K.expand_dims(layer.call(head_input(head), **call_kwargs))
         for head, layer in enumerate(self.layers)
     ]
     return K.concatenate(stacked, axis=-1)
Ejemplo n.º 10
0
 def _call_multiplicative_emission(self, inputs):
     """Compute multiplicative scores ``e_{t, t'} = x_t^T W_a x_{t'} + b_a``.

     ``b_a`` (``self.ba[0]``) is added only when ``use_attention_bias`` is
     set.
     """
     transformed = K.dot(inputs, self.Wa)
     transposed = K.permute_dimensions(inputs, (0, 2, 1))
     scores = K.batch_dot(transformed, transposed)
     if not self.use_attention_bias:
         return scores
     return scores + self.ba[0]
Ejemplo n.º 11
0
    def call(self, inputs, mask=None, training=None):
        """Relative-position multi-head attention with memory and segment terms.

        ``inputs`` is a 10-element sequence:
        ``(inputs, content, memories, segment_mat, segment_embed, relatives,
        bias_context, bias_relative, bias_segment, permutation)``.
        Queries come from ``inputs``.  Keys and values come from
        ``concat([memories, content], axis=1)``.

        The attention logits are the sum of three per-head terms, scaled by
        ``1 / sqrt(units_head)``:

        * ``a_context``:  (query + bias_context) . key
        * ``a_relative``: (query + bias_relative) . projected ``relatives``,
          realigned by ``self._relative_shift``
        * ``a_segment``:  (query + bias_segment) . ``segment_embed``, routed
          to key positions through ``segment_mat``

        Before normalization, the exponentiated logits are multiplied by
        ``permutation``.  When ``mask[0]`` is provided they are also
        multiplied by the key mask, with memory positions always kept.  The
        weighted values are merged across heads and projected with the
        output slice of ``self.kernel``.  The output is
        ``(batch, seq_len, units)`` per the shape annotations below.
        """
        (inputs, content, memories, segment_mat, segment_embed, relatives,
         bias_context, bias_relative, bias_segment, permutation) = inputs
        full = K.concatenate([memories, content],
                             axis=1)  # (batch, prev_len + seq_len, units)

        # self.kernel packs five projections column-wise:
        # [q | k, v | r | o], each `units` wide (k and v share one slice).
        kernel_q = self.kernel[:, :self.units]
        kernel_kv = self.kernel[:, self.units:self.units * 3]
        kernel_r = self.kernel[:, self.units * 3:self.units * 4]
        kernel_o = self.kernel[:, self.units * 4:self.units * 5]

        # Matching slices of self.bias; all stay None when use_bias is False.
        bias_q, bias_kv, bias_r, bias_o = (None, ) * 4
        if self.use_bias:
            bias_q = self.bias[:self.units]
            bias_kv = self.bias[self.units:self.units * 3]
            bias_r = self.bias[self.units * 3:self.units * 4]
            bias_o = self.bias[self.units * 4:self.units * 5]

        w_q = K.dot(inputs, kernel_q)  # (batch, seq_len, units)
        w_kv = K.dot(full, kernel_kv)  # (batch, prev_len + seq_len, units * 2)
        w_r = K.dot(relatives, kernel_r)  # (batch, prev_len + seq_len, units)
        if self.use_bias:
            w_q = K.bias_add(w_q, bias_q)
            w_kv = K.bias_add(w_kv, bias_kv)
            w_r = K.bias_add(w_r, bias_r)
        if self.activation is not None:
            w_q = self.activation(w_q)
            w_kv = self.activation(w_kv)
            w_r = self.activation(w_r)

        # Split the fused key/value projection back into keys and values.
        w_k = w_kv[:, :, :self.units]  # (batch, prev_len + seq_len, units)
        w_v = w_kv[:, :, self.units:]  # (batch, prev_len + seq_len, units)
        batch_size, q_len, k_len = K.shape(inputs)[0], K.shape(
            w_q)[1], K.shape(w_k)[1]

        # Content term: (query + bias_context) against the keys.
        w_qc = K.bias_add(w_q, bias_context)
        w_qc = self._reshape_to_batches(
            w_qc)  # (batch * n_head, seq_len, units_head)
        w_k = self._reshape_to_batches(
            w_k)  # (batch * n_head, prev_len + seq_len, units_head)
        a_context = K.batch_dot(
            w_qc, w_k, axes=2)  # (batch * n_head, seq_len, prev_len + seq_len)

        # Position term: (query + bias_relative) against the projected
        # relative encodings, then shifted to line up with a_context's keys.
        w_qr = K.bias_add(w_q, bias_relative)
        w_qr = self._reshape_to_batches(
            w_qr)  # (batch * n_head, seq_len, units_head)
        w_r = self._reshape_to_batches(
            w_r)  # (batch * n_head, prev_len + seq_len, units_head)
        a_relative = K.batch_dot(
            w_qr, w_r, axes=2)  # (batch * n_head, seq_len, prev_len + seq_len)
        a_relative = self._relative_shift(  # (batch * n_head, seq_len, prev_len + seq_len)
            a_relative,
            key_len_expected=K.shape(a_context)[-1],
        )

        # Segment term: computed with heads as the leading axis, then moved
        # back to the (batch * n_head, seq_len, k_len) layout of the others.
        w_qs = K.bias_add(w_q, bias_segment)
        w_qs = K.reshape(w_qs, (-1, q_len, self.num_head, self.units_head))
        w_qs = K.permute_dimensions(
            w_qs, (2, 0, 1, 3))  # (n_head, batch, seq_len, units_head)
        # NOTE(review): the reshape implies segment_embed holds two segment
        # vectors of n_head * units_head values each -- confirm with caller.
        segment_embed = K.reshape(K.transpose(segment_embed),
                                  (self.num_head, 1, self.units_head, 2))
        segment_embed = K.tile(segment_embed, (1, batch_size, 1, 1))
        a_segment = K.batch_dot(w_qs, segment_embed,
                                axes=(3, 2))  # (n_head, batch, seq_len, 2)
        a_segment = K.permute_dimensions(
            a_segment, (1, 2, 3, 0))  # (batch, seq_len, 2, n_head)
        a_segment = K.batch_dot(
            segment_mat, a_segment,
            axes=(3, 2))  # (batch, seq_len, prev_len + seq_len, n_head)
        a_segment = K.reshape(K.permute_dimensions(a_segment, (0, 3, 1, 2)),
                              (-1, q_len, k_len))

        # Scaled logits.  The softmax is written out by hand (max-shifted exp,
        # then normalize) so the permutation and key masks can be applied in
        # between.
        att = (a_context + a_relative + a_segment) / K.sqrt(
            K.constant(self.units_head, dtype=K.floatx()))
        exp = K.exp(att - K.max(att, axis=-1, keepdims=True))

        # Repeat the (batch, seq_len, k_len) permutation mask for every head
        # and flatten to (batch * n_head, seq_len, k_len).
        permutation = K.tile(K.expand_dims(permutation, axis=1),
                             [1, self.num_head, 1, 1])
        permutation = K.reshape(permutation, (-1, q_len, k_len))
        exp *= permutation
        if mask is not None and mask[0] is not None:
            mask = K.cast(mask[0], K.floatx())
            # Memory positions are prepended with ones so they are never masked.
            mask = K.concatenate([K.ones_like(memories[:, :, 0]), mask],
                                 axis=1)
            exp *= K.expand_dims(self._reshape_mask(mask), axis=1)

        att = exp / (K.sum(exp, axis=-1, keepdims=True) + K.epsilon())
        if self.att_drop_layer is not None:
            att = self.att_drop_layer(att, training=training)
        w_v = self._reshape_to_batches(
            w_v)  # (batch * n_head, prev_len + seq_len, units_head)
        w_o = K.batch_dot(att, w_v)  # (batch * n_head, seq_len, units_head)

        w_o = self._reshape_from_batches(w_o)  # (batch, seq_len, units)
        w_o = K.dot(w_o, kernel_o)  # (batch, seq_len, units)
        if self.use_bias:
            w_o = K.bias_add(w_o, bias_o)
        if self.activation is not None:
            w_o = self.activation(w_o)

        if TF_KERAS:
            # Add shape information to tensor when using `tf.keras`
            input_shape = K.int_shape(inputs)
            if input_shape[1] is not None:
                w_o = K.reshape(w_o, (-1, ) + input_shape[1:])
        return w_o
Ejemplo n.º 12
0
 def call(self, inputs, mask=None, **kwargs):
     """Softmax over similarities between features and embedding rows.

     ``inputs`` is a pair ``[features, embeddings]``.  The logits are
     ``features . embeddings^T + self.bias``.  ``mask`` is accepted but not
     used.
     """
     features, embeddings = inputs
     logits = K.dot(features, K.transpose(embeddings))
     logits = K.bias_add(logits, self.bias)
     return keras.activations.softmax(logits)