        def attention(query):
            """Compute attention masks over `hidden` using `hidden_features` and the decoder `query`."""
            # Results of attention reads will be stored here.
            ds = []
            # Will store masks over encoder context
            attn_masks = []
            # Store attention logits
            attn_logits = []
            # If the query is a tuple (LSTMStateTuple), flatten it.
            if nest.is_sequence(query):
                query_list = nest.flatten(query)
                # Check that ndims == 2 if specified.
                for q in query_list:
                    ndims = q.get_shape().ndims
                    if ndims:
                        assert ndims == 2
                query = array_ops.concat(query_list, axis=1)
            with variable_scope.variable_scope("Attention"):
                if attn_type == "linear":
                    y = linear(query, attention_vec_size, True)
                    y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
                    # Attention mask is a softmax of v^T * tanh(...).
                    s = math_ops.reduce_sum(
                        v[0] * math_ops.tanh(hidden_features[0] + y), [2, 3])
                elif attn_type == "bilinear":
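                    # Bilinear scorer: each encoder position i gets the score
                    # (W q) . h_i, with W applied by batch_linear.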
                    query = tf.tile(tf.expand_dims(query, 1),
                                    [1, attn_length, 1])
                    query = batch_linear(query, attn_size, bias=True)
                    hid = tf.squeeze(hidden, [2])
                    # Elementwise product plus a sum over the feature axis gives
                    # one score per encoder position.
                    s = tf.reduce_sum(query * hid, [2])
                else:
                    # Two-layer MLP scorer: a second 1x1 convolution before the
                    # v^T tanh(...) score. Note k2's input channel count assumes
                    # attn_size == attention_vec_size.
                    y = linear(query, attention_vec_size, True)
                    y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
                    # Attention mask is a softmax of v^T * tanh(...).
                    layer1 = math_ops.tanh(hidden_features[0] + y)
                    k2 = variable_scope.get_variable(
                        "AttnW_1", [1, 1, attn_size, attention_vec_size])
                    layer2 = nn_ops.conv2d(layer1, k2, [1, 1, 1, 1], "SAME")
                    s = math_ops.reduce_sum(v[0] * math_ops.tanh(layer2),
                                            [2, 3])

                a = nn_ops.softmax(s)
                attn_masks.append(a)
                attn_logits.append(s)
                # Now calculate the attention-weighted context vector d over
                # the encoder hidden states in `hidden`.
                d = math_ops.reduce_sum(
                    array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden,
                    [1, 2])
                ds.append(array_ops.reshape(d, [-1, attn_size]))
            return ds, attn_masks, attn_logits
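        # Usage sketch (hypothetical; the surrounding decoder loop is not shown):
        #   ds, attn_masks, attn_logits = attention(state)
        # where `state` is the current decoder cell state and ds[0] is the
        # context vector combined with the decoder input at the next step.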
        def attention_kb_triple(query):
            """
            Compute attention over KB triples given the decoder hidden state as a query.

            :param query: decoder hidden state used to score each embedded KB key.
            :return: a tuple of (attention weights, attention logits) over the KB triples.
            """
            # Expand dims on the query so it broadcasts against embedded_kb_key.
            with variable_scope.variable_scope("Attention_KB_Triple"):
                if attn_type == "two-mlp":
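                    # MLP scorer: project keys and query separately, add them
                    # (broadcasting the query across all triples), then apply two
                    # tanh layers and a final linear layer to get one logit per
                    # KB triple.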
                    query = tf.expand_dims(query, 1)
                    with variable_scope.variable_scope("KB_key_W1"):
                        key_layer_1 = batch_linear(embedded_kb_key,
                                                   attention_vec_size,
                                                   bias=True)

                    with variable_scope.variable_scope("Query_W1"):
                        query_layer_1 = batch_linear(query,
                                                     attention_vec_size,
                                                     bias=True)

                    layer_1 = math_ops.tanh(key_layer_1 + query_layer_1)
                    with variable_scope.variable_scope("KB_Query_W2"):
                        layer_2 = batch_linear(layer_1,
                                               attention_vec_size,
                                               bias=True)

                    layer_2 = math_ops.tanh(layer_2)
                    with variable_scope.variable_scope("KB_Query_W3"):
                        layer_3 = batch_linear(layer_2, 1, bias=True)

                    layer_3_logits = tf.squeeze(layer_3, [2])
                    layer_3 = nn_ops.softmax(layer_3_logits)

                    return layer_3, layer_3_logits
                elif attn_type == "linear":
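                    # Single-hidden-layer scorer: project keys and query, add
                    # them, apply tanh, then map directly to one logit per triple.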
                    query = tf.expand_dims(query, 1)
                    with variable_scope.variable_scope("KB_key_W1"):
                        key_layer_1 = batch_linear(embedded_kb_key,
                                                   attention_vec_size,
                                                   bias=True)

                    with variable_scope.variable_scope("Query_W1"):
                        query_layer_1 = batch_linear(query,
                                                     attention_vec_size,
                                                     bias=True)

                    layer_1 = math_ops.tanh(key_layer_1 + query_layer_1)
                    with variable_scope.variable_scope("KB_Query_W2"):
                        layer_2 = batch_linear(layer_1, 1, bias=True)

                    layer_2_logits = tf.squeeze(layer_2, [2])
                    layer_2 = nn_ops.softmax(layer_2_logits)
                    return layer_2, layer_2_logits
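                else:
                    # Defensive guard; assumes only "two-mlp" and "linear" are
                    # valid here. Without it the function would silently return
                    # None for other attn_type values.
                    raise ValueError(
                        "Unsupported attn_type for KB attention: %s" % attn_type)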