Example 1
    def apply(self, is_train, x, x_mask=None):
        x_word_dim = tf.shape(x)[1]

        # (batch, x_word, key_word)
        dist_matrix = self.attention.get_scores(x, x)
        dist_matrix += tf.expand_dims(tf.eye(x_word_dim) * VERY_NEGATIVE_NUMBER, 0)  # Mask out self

        joint_mask = compute_attention_mask(x_mask, x_mask, x_word_dim, x_word_dim)
        if joint_mask is not None:
            dist_matrix += VERY_NEGATIVE_NUMBER * (1 - tf.cast(joint_mask, dist_matrix.dtype))

        if not self.alignment_bias:
            select_probs = tf.nn.softmax(dist_matrix)
        else:
            # Allow zero-attention by adding a learned bias to the normalizer
            bias = tf.exp(tf.get_variable("no-alignment-bias", initializer=tf.constant(-1.0, dtype=tf.float32)))
            dist_matrix = tf.exp(dist_matrix)
            select_probs = dist_matrix / (tf.reduce_sum(dist_matrix, axis=2, keep_dims=True) + bias)

        response = tf.matmul(select_probs, x)  # (batch, x_words, x_dim)

        if self.merge is not None:
            with tf.variable_scope("merge"):
                response = self.merge.apply(is_train, response, x)
            return response
        else:
            return response
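
A minimal NumPy sketch (assumptions mine, not from the repo) of the self-masking trick above: adding a very negative number along the diagonal of the score matrix drives the softmax weight each word assigns to itself to zero.

import numpy as np

VERY_NEGATIVE_NUMBER = -1e30
scores = np.random.randn(2, 4, 4)                         # (batch, x_word, key_word)
scores += np.eye(4)[None] * VERY_NEGATIVE_NUMBER          # mask out self-matches
probs = np.exp(scores - scores.max(-1, keepdims=True))    # numerically stable softmax over keys
probs /= probs.sum(-1, keepdims=True)
assert np.allclose(np.diagonal(probs, axis1=1, axis2=2), 0.0)  # no weight on self
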
Example 2
    def apply(self, is_train, x, keys, memories, x_mask=None, mem_mask=None):
        x_word_dim = tf.shape(x)[1]
        key_word_dim = tf.shape(keys)[1]

        dist_matrix = self.sim.get_scores(x, keys)
        joint_mask = compute_attention_mask(x_mask, mem_mask, x_word_dim, key_word_dim)
        if joint_mask is not None:
            dist_matrix += VERY_NEGATIVE_NUMBER * (1 - tf.cast(joint_mask, dist_matrix.dtype))
        query_probs = tf.nn.softmax(dist_matrix)  # probability of each mem_word per x_word

        # Batch matrix multiplication to get the attended vectors
        select_query = tf.matmul(query_probs, memories)  # (batch, x_words, q_dim)

        if not self.q2c:
            if self.query_dots:
                return tf.concat([x, select_query, x * select_query], axis=2)
            else:
                return tf.concat([x, select_query], axis=2)

        # select query-to-context
        context_dist = tf.reduce_max(dist_matrix, axis=2)  # (batch, x_words)
        context_probs = tf.nn.softmax(context_dist)  # (batch, x_words)
        select_context = tf.einsum("ai,aik->ak", context_probs, x)  # (batch, x_dim)
        select_context = tf.expand_dims(select_context, 1)

        if self.query_dots:
            return tf.concat([x, select_query, x * select_query, x * select_context], axis=2)
        else:
            return tf.concat([x, select_query, x * select_context], axis=2)
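
The einsum in the query-to-context step is just a probability-weighted sum of the context vectors; a minimal NumPy sketch (my own, hypothetical shapes) makes that explicit:

import numpy as np

batch, words, dim = 2, 5, 3
x = np.random.randn(batch, words, dim)
context_probs = np.random.rand(batch, words)
context_probs /= context_probs.sum(-1, keepdims=True)        # one distribution per batch element

select_context = np.einsum("ai,aik->ak", context_probs, x)   # (batch, dim)
manual = (context_probs[:, :, None] * x).sum(axis=1)         # the same weighted sum, spelled out
assert np.allclose(select_context, manual)
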
Example 3
    def apply(self, is_train, x, keys, memories, x_mask=None, mem_mask=None):
        x_word_dim = tf.shape(x)[1]
        key_word_dim = tf.shape(keys)[1]

        # (batch, x_word, key_word)
        dist_matrix = self.attention.get_scores(x, keys)

        joint_mask = compute_attention_mask(x_mask, mem_mask, x_word_dim, key_word_dim)
        dist_matrix += VERY_NEGATIVE_NUMBER * (1 - tf.cast(joint_mask, dist_matrix.dtype))

        if self.alignment_bias is None:
            select_probs = tf.nn.softmax(dist_matrix)
        else:
            # Compute softmax with an additional bias term, this allows the model to 'ignore' the memories
            # if needed since the sum of the weights given to each memory can be < 1.
            bias = tf.exp(tf.get_variable("no-alignment-bias", initializer=tf.constant(-1.0, dtype=tf.float32)))
            dist_matrix = tf.exp(dist_matrix)
            select_probs = dist_matrix / (tf.reduce_sum(dist_matrix, axis=2, keep_dims=True) + bias)

        # To (batch, x_word, memory_dim)
        response = tf.matmul(select_probs, memories)

        if self.merge is not None:
            with tf.variable_scope("merge"):
                response = self.merge.apply(is_train, response, x)
            return response
        else:
            return response
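
A minimal NumPy sketch (values hypothetical) of the "no-alignment bias" normalisation used above: adding exp(bias) to the softmax denominator lets each row of attention weights sum to less than 1, so a word can down-weight all memories at once.

import numpy as np

scores = np.random.randn(2, 4, 6)                 # (batch, x_word, key_word)
bias = np.exp(-1.0)                               # the learned scalar, initialised at -1.0
exp_scores = np.exp(scores)
probs = exp_scores / (exp_scores.sum(-1, keepdims=True) + bias)
assert np.all(probs.sum(-1) < 1.0)                # every row keeps some mass unassigned
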
Example 4
    def apply(self, is_train, x, keys, memories, x_mask=None, mem_mask=None):
        x_word_dim = tf.shape(x)[1]
        key_word_dim = tf.shape(keys)[1]

        # (batch, x_word, key_word)
        dist_matrix = self.attention.get_scores(x, keys)

        joint_mask = compute_attention_mask(x_mask, mem_mask, x_word_dim,
                                            key_word_dim)
        dist_matrix += VERY_NEGATIVE_NUMBER * (
            1 - tf.cast(joint_mask, dist_matrix.dtype))

        if self.alignment_bias is None:
            select_probs = tf.nn.softmax(dist_matrix)
        else:
            bias = tf.exp(
                tf.get_variable("no-alignment-bias",
                                initializer=tf.constant(-1.0,
                                                        dtype=tf.float32)))
            dist_matrix = tf.exp(dist_matrix)
            select_probs = dist_matrix / (
                tf.reduce_sum(dist_matrix, axis=2, keep_dims=True) + bias)

        # To (batch, x_word, memory_dim)
        response = tf.matmul(select_probs, memories)

        with tf.variable_scope("encode_keys"):
            encoded = self.encoder_layer.apply(is_train, keys, mem_mask)

        return tf.concat(
            [x, response, x * response, x * tf.expand_dims(encoded, 1)],
            axis=2)
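
One detail worth making concrete: the pooled key encoding has no word axis, so it is expanded to (batch, 1, dim) and broadcast against x across words before the feature-wise concatenation. A minimal NumPy sketch with hypothetical shapes:

import numpy as np

batch, words, dim = 2, 5, 3
x = np.random.randn(batch, words, dim)
response = np.random.randn(batch, words, dim)
encoded = np.random.randn(batch, dim)             # one pooled vector per batch element

out = np.concatenate([x, response, x * response, x * encoded[:, None, :]], axis=2)
print(out.shape)                                  # (2, 5, 12): four dim-sized feature blocks
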
Example 5
    def apply(self, is_train, x, mask=None):
        batch_size = tf.shape(x)[0]
        x_word_dim = tf.shape(x)[1]
        x_feature_dim = x.shape.as_list()[-1]
        project_size = self.project_size
        if project_size is None:
            project_size = x_feature_dim // self.n_heads
            if x_feature_dim % self.n_heads != 0:
                raise ValueError()
        mem_size = self.memory_size
        if mem_size is None:
            mem_size = project_size

        init = get_keras_initialization(self.init)

        query_proj = tf.get_variable("query_proj", (x_feature_dim, self.n_heads, project_size), initializer=init)
        if self.shared_project:
            key_proj = query_proj
        else:
            key_proj = tf.get_variable("key_proj", (x_feature_dim, self.n_heads, project_size), initializer=init)
        mem_proj = tf.get_variable("mem_proj", (x_feature_dim, self.n_heads, mem_size), initializer=init)

        queries = tf.tensordot(x, query_proj, [[2], [0]])  # (batch, word, n_head, project_size)
        keys = tf.tensordot(x, key_proj, [[2], [0]])  # (batch, key, n_head, project_size)

        if self.project_bias:
            queries += tf.get_variable("query_bias", (1, 1, self.n_heads, project_size),
                                        initializer=tf.zeros_initializer())
            keys += tf.get_variable("key_bias", (1, 1, self.n_heads, project_size),
                                        initializer=tf.zeros_initializer())

        # dist_matrix = tf.matmul(queries, keys, transpose_b=True)
        dist_matrix = tf.einsum("bwhd,bkhd->bwkh", queries, keys)   # dots of (batch, word, key, head)

        if self.scale:
            dist_matrix /= tf.sqrt(float(project_size))

        if self.bilinear_comp:
            query_bias_proj = tf.get_variable("query_bias_proj", (x_feature_dim, self.n_heads), initializer=init)
            key_bias_proj = tf.get_variable("key_bias_proj", (x_feature_dim, self.n_heads), initializer=init)
            dist_matrix += tf.expand_dims(tf.tensordot(x, query_bias_proj, [[2], [0]]), 2)
            dist_matrix += tf.expand_dims(tf.tensordot(x, key_bias_proj, [[2], [0]]), 1)

        joint_mask = compute_attention_mask(mask, mask, x_word_dim, x_word_dim)
        if joint_mask is not None:
            dist_matrix += tf.expand_dims(VERY_NEGATIVE_NUMBER * (1 - tf.cast(joint_mask, dist_matrix.dtype)), 2)
        dist_matrix += tf.expand_dims(tf.expand_dims(tf.eye(x_word_dim) * VERY_NEGATIVE_NUMBER, 0), 2)

        if self.bias:
            bias = tf.get_variable("bias", (1, 1, self.n_heads, 1), initializer=tf.zeros_initializer())
            dist_matrix += bias

        select_probs = tf.nn.softmax(dist_matrix)   # for each (batch, word, head) probability over keys

        memories = tf.tensordot(x, mem_proj, [[2], [0]])  # (batch, memory, head, mem_size)
        response = tf.einsum("bwhk,bkhd->bwhd", select_probs, memories)  # (batch, word, head, mem_size)

        response = tf.reshape(response, (batch_size, x_word_dim, self.n_heads * mem_size))   # concat heads

        if self.merge is not None:
            with tf.variable_scope("merge"):
                response = self.merge.apply(is_train, x, response)
        return response
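
To make the per-head shapes concrete, here is a minimal NumPy sketch (not the repo's API, names are mine) of the projection and head-wise score contraction used above:

import numpy as np

batch, words, dim, n_heads, proj = 2, 4, 8, 2, 4           # dim must divide evenly by n_heads
x = np.random.randn(batch, words, dim)
query_proj = np.random.randn(dim, n_heads, proj)

queries = np.tensordot(x, query_proj, axes=([2], [0]))     # (batch, word, head, proj)
keys = queries                                             # the shared_project case
scores = np.einsum("bwhd,bkhd->bwkh", queries, keys)       # (batch, word, key, head)
print(queries.shape, scores.shape)                         # (2, 4, 2, 4) (2, 4, 4, 2)
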