Example #1
    def encode(self, is_training=True):
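        """Build the encoder side of the graph.

        Returns a tuple (encode_size, encode_hiddens, init_state), where
        encode_hiddens has shape [batch_size, passage_len, encode_size] and
        init_state is the LSTMStateTuple used to initialize the decoder.
        """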
        options = self.options

        # ======word representation layer======
        in_question_repres = []
        in_passage_repres = []
        input_dim = 0
        if options.with_word and self.word_vocab is not None:
            word_vec_trainable = True
            cur_device = '/gpu:0'
            if options.fix_word_vec:
                word_vec_trainable = False
                cur_device = '/cpu:0'
            with tf.variable_scope("embedding"), tf.device(cur_device):
                self.word_embedding = tf.get_variable(
                    "word_embedding",
                    trainable=word_vec_trainable,
                    initializer=tf.constant(self.word_vocab.word_vecs),
                    dtype=tf.float32)

            in_question_word_repres = tf.nn.embedding_lookup(
                self.word_embedding,
                self.in_question_words)  # [batch_size, question_len, word_dim]
            in_passage_word_repres = tf.nn.embedding_lookup(
                self.word_embedding,
                self.in_passage_words)  # [batch_size, passage_len, word_dim]
            in_question_repres.append(in_question_word_repres)
            in_passage_repres.append(in_passage_word_repres)

            input_shape = tf.shape(self.in_question_words)
            batch_size = input_shape[0]
            question_len = input_shape[1]
            input_shape = tf.shape(self.in_passage_words)
            passage_len = input_shape[1]
            input_dim += self.word_vocab.word_dim

        if options.with_char and self.char_vocab is not None:
            input_shape = tf.shape(self.in_question_chars)
            batch_size = input_shape[0]
            question_len = input_shape[1]
            q_char_len = input_shape[2]
            input_shape = tf.shape(self.in_passage_chars)
            passage_len = input_shape[1]
            p_char_len = input_shape[2]
            char_dim = self.char_vocab.word_dim
            self.char_embedding = tf.get_variable(
                "char_embedding",
                initializer=tf.constant(self.char_vocab.word_vecs),
                dtype=tf.float32)

            in_question_char_repres = tf.nn.embedding_lookup(
                self.char_embedding, self.in_question_chars
            )  # [batch_size, question_len, q_char_len, char_dim]
            in_question_char_repres = tf.reshape(
                in_question_char_repres, shape=[-1, q_char_len, char_dim])
            question_char_lengths = tf.reshape(self.question_char_lengths,
                                               [-1])
            in_passage_char_repres = tf.nn.embedding_lookup(
                self.char_embedding, self.in_passage_chars
            )  # [batch_size, passage_len, p_char_len, char_dim]
            in_passage_char_repres = tf.reshape(
                in_passage_char_repres, shape=[-1, p_char_len, char_dim])
            passage_char_lengths = tf.reshape(self.passage_char_lengths, [-1])
            with tf.variable_scope('char_lstm'):
                # lstm cell
                char_lstm_cell = tf.contrib.rnn.BasicLSTMCell(
                    options.char_lstm_dim)
                # dropout
                if is_training:
                    char_lstm_cell = tf.contrib.rnn.DropoutWrapper(
                        char_lstm_cell,
                        output_keep_prob=(1 - options.dropout_rate))
                char_lstm_cell = tf.contrib.rnn.MultiRNNCell([char_lstm_cell])

                # question_representation
                question_char_outputs = tf.nn.dynamic_rnn(
                    char_lstm_cell,
                    in_question_char_repres,
                    sequence_length=question_char_lengths,
                    dtype=tf.float32
                )[0]  # [batch_size*question_len, q_char_len, char_lstm_dim]
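                # keep only the output at the last time step as each token's
                # character-level encoding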
                question_char_outputs = question_char_outputs[:, -1, :]
                question_char_outputs = tf.reshape(
                    question_char_outputs,
                    [batch_size, question_len, options.char_lstm_dim])

                tf.get_variable_scope().reuse_variables()
                # passage representation
                passage_char_outputs = tf.nn.dynamic_rnn(
                    char_lstm_cell,
                    in_passage_char_repres,
                    sequence_length=passage_char_lengths,
                    dtype=tf.float32
                )[0]  # [batch_size*passage_len, p_char_len, char_lstm_dim]
                passage_char_outputs = passage_char_outputs[:, -1, :]
                passage_char_outputs = tf.reshape(
                    passage_char_outputs,
                    [batch_size, passage_len, options.char_lstm_dim])

            in_question_repres.append(question_char_outputs)
            in_passage_repres.append(passage_char_outputs)
            input_dim += options.char_lstm_dim

        if options.with_POS and self.POS_vocab is not None:
            self.POS_embedding = tf.get_variable("POS_embedding",
                                                 initializer=tf.constant(
                                                     self.POS_vocab.word_vecs),
                                                 dtype=tf.float32)

            in_question_POS_repres = tf.nn.embedding_lookup(
                self.POS_embedding,
                self.in_question_POSs)  # [batch_size, question_len, POS_dim]
            in_passage_POS_repres = tf.nn.embedding_lookup(
                self.POS_embedding,
                self.in_passage_POSs)  # [batch_size, passage_len, POS_dim]
            in_question_repres.append(in_question_POS_repres)
            in_passage_repres.append(in_passage_POS_repres)

            input_shape = tf.shape(self.in_question_POSs)
            batch_size = input_shape[0]
            question_len = input_shape[1]
            input_shape = tf.shape(self.in_passage_POSs)
            passage_len = input_shape[1]
            input_dim += self.POS_vocab.word_dim

        if options.with_NER and self.NER_vocab is not None:
            self.NER_embedding = tf.get_variable("NER_embedding",
                                                 initializer=tf.constant(
                                                     self.NER_vocab.word_vecs),
                                                 dtype=tf.float32)

            in_question_NER_repres = tf.nn.embedding_lookup(
                self.NER_embedding,
                self.in_question_NERs)  # [batch_size, question_len, NER_dim]
            in_passage_NER_repres = tf.nn.embedding_lookup(
                self.NER_embedding,
                self.in_passage_NERs)  # [batch_size, passage_len, NER_dim]
            in_question_repres.append(in_question_NER_repres)
            in_passage_repres.append(in_passage_NER_repres)

            input_shape = tf.shape(self.in_question_NERs)
            batch_size = input_shape[0]
            question_len = input_shape[1]
            input_shape = tf.shape(self.in_passage_NERs)
            passage_len = input_shape[1]
            input_dim += self.NER_vocab.word_dim

        in_question_repres = tf.concat(in_question_repres,
                                       2)  # [batch_size, question_len, dim]
        in_passage_repres = tf.concat(in_passage_repres,
                                      2)  # [batch_size, passage_len, dim]

        if options.compress_input:  # compress input word vector into smaller vectors
            w_compress = tf.get_variable(
                "w_compress_input", [input_dim, options.compress_input_dim],
                dtype=tf.float32)
            b_compress = tf.get_variable("b_compress_input",
                                         [options.compress_input_dim],
                                         dtype=tf.float32)

            in_question_repres = tf.reshape(in_question_repres,
                                            [-1, input_dim])
            in_question_repres = tf.matmul(in_question_repres,
                                           w_compress) + b_compress
            in_question_repres = tf.tanh(in_question_repres)
            in_question_repres = tf.reshape(
                in_question_repres,
                [batch_size, question_len, options.compress_input_dim])

            in_passage_repres = tf.reshape(in_passage_repres, [-1, input_dim])
            in_passage_repres = tf.matmul(in_passage_repres,
                                          w_compress) + b_compress
            in_passage_repres = tf.tanh(in_passage_repres)
            in_passage_repres = tf.reshape(
                in_passage_repres,
                [batch_size, passage_len, options.compress_input_dim])
            input_dim = options.compress_input_dim

        if is_training:
            in_question_repres = tf.nn.dropout(in_question_repres,
                                               (1 - options.dropout_rate))
            in_passage_repres = tf.nn.dropout(in_passage_repres,
                                              (1 - options.dropout_rate))
        else:
            in_question_repres = tf.multiply(in_question_repres,
                                             (1 - options.dropout_rate))
            in_passage_repres = tf.multiply(in_passage_repres,
                                            (1 - options.dropout_rate))

        passage_mask = tf.sequence_mask(
            self.passage_lengths, passage_len,
            dtype=tf.float32)  # [batch_size, passage_len]
        question_mask = tf.sequence_mask(
            self.question_lengths, question_len,
            dtype=tf.float32)  # [batch_size, question_len]

        # ======Highway layer======
        if options.with_highway:
            with tf.variable_scope("input_highway"):
                in_question_repres = match_utils.multi_highway_layer(
                    in_question_repres, input_dim, options.highway_layer_num)
                tf.get_variable_scope().reuse_variables()
                in_passage_repres = match_utils.multi_highway_layer(
                    in_passage_repres, input_dim, options.highway_layer_num)

        # ======Filter layer======
        cosine_matrix = match_utils.cal_relevancy_matrix(
            in_question_repres, in_passage_repres)
        cosine_matrix = match_utils.mask_relevancy_matrix(
            cosine_matrix, question_mask, passage_mask)
        #         relevancy_matrix = tf.select(tf.greater(cosine_matrix,
        #                                     tf.scalar_mul(filter_layer_threshold, tf.ones_like(cosine_matrix, dtype=tf.float32))),
        #                                     cosine_matrix, tf.zeros_like(cosine_matrix, dtype=tf.float32)) # [batch_size, passage_len, question_len]
        raw_in_passage_repres = in_passage_repres
        if options.with_filter_layer:
            relevancy_matrix = cosine_matrix  # [batch_size, passage_len, question_len]
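            # weight each passage word by its maximum cosine similarity to any
            # question word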
            relevancy_degrees = tf.reduce_max(
                relevancy_matrix, axis=2)  # [batch_size, passage_len]
            relevancy_degrees = tf.expand_dims(
                relevancy_degrees, axis=-1)  # [batch_size, passage_len, 1]
            in_passage_repres = tf.multiply(in_passage_repres,
                                            relevancy_degrees)

        # =======Context Representation Layer & Multi-Perspective matching layer=====
        all_question_aware_representatins = []
        question_aware_dim = 0
        if options.with_word_match:
            with tf.variable_scope('word_level_matching'):
                (word_match_vectors,
                 word_match_dim) = match_utils.match_passage_with_question(
                     raw_in_passage_repres,
                     None,
                     passage_mask,
                     in_question_repres,
                     None,
                     question_mask,
                     input_dim,
                     with_full_matching=False,
                     with_attentive_matching=options.with_attentive_matching,
                     with_max_attentive_matching=options.with_max_attentive_matching,
                     with_maxpooling_matching=options.with_maxpooling_matching,
                     with_local_attentive_matching=options.with_local_attentive_matching,
                     win_size=options.win_size,
                     with_forward_match=True,
                     with_backward_match=False,
                     match_options=options)
                all_question_aware_representatins.extend(word_match_vectors)
                question_aware_dim += word_match_dim
        # lex decomposition
        if options.with_lex_decomposition:
            lex_decomposition = match_utils.cal_linear_decomposition_representation(
                raw_in_passage_repres, self.passage_lengths, cosine_matrix,
                is_training, options.lex_decompsition_dim,
                options.dropout_rate)
            all_question_aware_representatins.append(lex_decomposition)
            if options.lex_decompsition_dim == -1:
                question_aware_dim += 2 * input_dim
            else:
                question_aware_dim += 2 * options.lex_decompsition_dim

        if options.with_question_passage_word_feature:
            all_question_aware_representatins.append(raw_in_passage_repres)

            att_question_representation = match_utils.calculate_cosine_weighted_question_representation(
                in_question_repres, cosine_matrix)
            all_question_aware_representatins.append(
                att_question_representation)
            question_aware_dim += 2 * input_dim

        # sequential context matching
        question_forward = None
        question_backward = None
        passage_forward = None
        passage_backward = None
        if options.with_sequential_match:
            with tf.variable_scope('context_MP_matching'):
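                # each layer: a weight-shared BiLSTM encodes question and
                # passage, then the passage is matched against the question
                # from multiple perspectives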
                cur_in_question_repres = in_question_repres
                cur_in_passage_repres = in_passage_repres
                for i in xrange(options.context_layer_num):
                    with tf.variable_scope('layer-{}'.format(i)):
                        with tf.variable_scope('context_represent'):
                            # parameters
                            context_lstm_cell_fw = tf.contrib.rnn.LSTMCell(
                                options.context_lstm_dim)
                            context_lstm_cell_bw = tf.contrib.rnn.LSTMCell(
                                options.context_lstm_dim)
                            if is_training:
                                context_lstm_cell_fw = tf.contrib.rnn.DropoutWrapper(
                                    context_lstm_cell_fw,
                                    output_keep_prob=(1 -
                                                      options.dropout_rate))
                                context_lstm_cell_bw = tf.contrib.rnn.DropoutWrapper(
                                    context_lstm_cell_bw,
                                    output_keep_prob=(1 -
                                                      options.dropout_rate))

                            # question representation
                            ((question_context_representation_fw,
                              question_context_representation_bw),
                             (question_forward, question_backward
                              )) = tf.nn.bidirectional_dynamic_rnn(
                                  context_lstm_cell_fw,
                                  context_lstm_cell_bw,
                                  cur_in_question_repres,
                                  dtype=tf.float32,
                                  sequence_length=self.question_lengths
                              )  # [batch_size, question_len, context_lstm_dim]
                            cur_in_question_repres = tf.concat([
                                question_context_representation_fw,
                                question_context_representation_bw
                            ], 2)

                            # passage representation
                            tf.get_variable_scope().reuse_variables()
                            ((passage_context_representation_fw,
                              passage_context_representation_bw),
                             (passage_forward, passage_backward
                              )) = tf.nn.bidirectional_dynamic_rnn(
                                  context_lstm_cell_fw,
                                  context_lstm_cell_bw,
                                  cur_in_passage_repres,
                                  dtype=tf.float32,
                                  sequence_length=self.passage_lengths
                              )  # [batch_size, passage_len, context_lstm_dim]
                            cur_in_passage_repres = tf.concat([
                                passage_context_representation_fw,
                                passage_context_representation_bw
                            ], 2)

                        # Multi-perspective matching
                        with tf.variable_scope('MP_matching'):
                            (matching_vectors, matching_dim
                             ) = match_utils.match_passage_with_question(
                                 passage_context_representation_fw,
                                 passage_context_representation_bw,
                                 passage_mask,
                                 question_context_representation_fw,
                                 question_context_representation_bw,
                                 question_mask,
                                 options.context_lstm_dim,
                                 with_full_matching=options.with_full_matching,
                                 with_attentive_matching=options.with_attentive_matching,
                                 with_max_attentive_matching=options.with_max_attentive_matching,
                                 with_maxpooling_matching=options.with_maxpooling_matching,
                                 with_local_attentive_matching=options.with_local_attentive_matching,
                                 win_size=options.win_size,
                                 with_forward_match=options.with_forward_match,
                                 with_backward_match=options.with_backward_match,
                                 match_options=options)
                            all_question_aware_representatins.extend(
                                matching_vectors)
                            question_aware_dim += matching_dim

        all_question_aware_representatins = tf.concat(
            all_question_aware_representatins,
            2)  # [batch_size, passage_len, dim]

        if is_training:
            all_question_aware_representatins = tf.nn.dropout(
                all_question_aware_representatins, (1 - options.dropout_rate))
        else:
            all_question_aware_representatins = tf.multiply(
                all_question_aware_representatins, (1 - options.dropout_rate))

        # ======Highway layer======
        if options.with_match_highway:
            with tf.variable_scope("matching_highway"):
                all_question_aware_representatins = match_utils.multi_highway_layer(
                    all_question_aware_representatins, question_aware_dim,
                    options.highway_layer_num)

        #========Aggregation Layer======
        if not options.with_aggregation:
            aggregation_representation = all_question_aware_representatins
            aggregation_dim = question_aware_dim
        else:
            aggregation_representation = []
            aggregation_dim = 0
            aggregation_input = all_question_aware_representatins
            with tf.variable_scope('aggregation_layer'):
                for i in xrange(options.aggregation_layer_num):
                    with tf.variable_scope('layer-{}'.format(i)):
                        aggregation_lstm_cell_fw = tf.contrib.rnn.BasicLSTMCell(
                            options.aggregation_lstm_dim)
                        aggregation_lstm_cell_bw = tf.contrib.rnn.BasicLSTMCell(
                            options.aggregation_lstm_dim)
                        if is_training:
                            aggregation_lstm_cell_fw = tf.contrib.rnn.DropoutWrapper(
                                aggregation_lstm_cell_fw,
                                output_keep_prob=(1 - options.dropout_rate))
                            aggregation_lstm_cell_bw = tf.contrib.rnn.DropoutWrapper(
                                aggregation_lstm_cell_bw,
                                output_keep_prob=(1 - options.dropout_rate))
                        aggregation_lstm_cell_fw = tf.contrib.rnn.MultiRNNCell(
                            [aggregation_lstm_cell_fw])
                        aggregation_lstm_cell_bw = tf.contrib.rnn.MultiRNNCell(
                            [aggregation_lstm_cell_bw])

                        cur_aggregation_representation, _ = tf.nn.bidirectional_dynamic_rnn(
                            aggregation_lstm_cell_fw,
                            aggregation_lstm_cell_bw,
                            aggregation_input,
                            dtype=tf.float32,
                            sequence_length=self.passage_lengths)
                        cur_aggregation_representation = tf.concat(
                            cur_aggregation_representation, 2
                        )  # [batch_size, passage_len, 2*aggregation_lstm_dim]
                        aggregation_representation.append(
                            cur_aggregation_representation)
                        aggregation_dim += 2 * options.aggregation_lstm_dim
                        aggregation_input = cur_aggregation_representation

            aggregation_representation = tf.concat(aggregation_representation,
                                                   2)
            aggregation_representation = tf.concat([
                aggregation_representation, all_question_aware_representatins
            ], 2)
            aggregation_dim += question_aware_dim

        # ======Highway layer======
        if options.with_aggregation_highway:
            with tf.variable_scope("aggregation_highway"):
                aggregation_representation = match_utils.multi_highway_layer(
                    aggregation_representation, aggregation_dim,
                    options.highway_layer_num)

        #========output Layer=========
        encode_size = aggregation_dim + input_dim
        encode_hiddens = tf.concat(
            [aggregation_representation, in_passage_repres],
            2)  # [batch_size, passage_len, enc_size]
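        # zero out encoder states at padded passage positions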
        encode_hiddens = encode_hiddens * tf.expand_dims(passage_mask, axis=-1)

        # initial state for the LSTM decoder
        with tf.variable_scope('initial_state_for_decoder'):
            # Define weights and biases to reduce the cell and reduce the state
            w_reduce_c = tf.get_variable(
                'w_reduce_c',
                [4 * options.context_lstm_dim, options.gen_hidden_size],
                dtype=tf.float32)
            w_reduce_h = tf.get_variable(
                'w_reduce_h',
                [4 * options.context_lstm_dim, options.gen_hidden_size],
                dtype=tf.float32)
            bias_reduce_c = tf.get_variable('bias_reduce_c',
                                            [options.gen_hidden_size],
                                            dtype=tf.float32)
            bias_reduce_h = tf.get_variable('bias_reduce_h',
                                            [options.gen_hidden_size],
                                            dtype=tf.float32)

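            # concatenate the final forward/backward states of the question and
            # passage context LSTMs, then project them down to the decoder size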
            old_c = tf.concat(values=[
                question_forward.c, question_backward.c, passage_forward.c,
                passage_backward.c
            ],
                              axis=1)
            old_h = tf.concat(values=[
                question_forward.h, question_backward.h, passage_forward.h,
                passage_backward.h
            ],
                              axis=1)
            new_c = tf.nn.tanh(tf.matmul(old_c, w_reduce_c) + bias_reduce_c)
            new_h = tf.nn.tanh(tf.matmul(old_h, w_reduce_h) + bias_reduce_h)

            init_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h)
        # Alternative: a zero-initialized decoder state
        # new_c = tf.zeros([batch_size, options.gen_hidden_size])
        # new_h = tf.zeros([batch_size, options.gen_hidden_size])
        # init_state = LSTMStateTuple(new_c, new_h)
        return (encode_size, encode_hiddens, init_state)
Example #2
def graph_match(in_question_repres,
                in_passage_repres,
                question_mask,
                passage_mask,
                edge_embedding,
                question_neighbor_indices,
                passage_neighbor_indices,
                question_neighbor_edges,
                passage_neighbor_edges,
                question_neighbor_size,
                passage_neighbor_size,
                neighbor_vector_dim,
                input_dim,
                edge_dim,
                num_syntax_match_layer,
                with_attentive_matching=True,
                with_max_attentive_matching=True,
                with_maxpooling_matching=True,
                match_options=None):
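    """Multi-layer graph-LSTM matching of the passage against the question.

    Node states are propagated over their graph neighbors for
    num_syntax_match_layer steps; after each step the passage nodes are matched
    against the question nodes.

    Returns (all_matching_vectors, all_matching_dim), where
    all_matching_vectors has shape [batch_size, passage_len, all_matching_dim].
    """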

    all_matching_vectors = []
    all_matching_dim = 0

    input_shape = tf.shape(question_neighbor_indices)
    batch_size = input_shape[0]
    question_len = input_shape[1]
    num_question_neighbors = input_shape[2]

    input_shape = tf.shape(passage_neighbor_indices)
    #     batch_size = input_shape[0]
    passage_len = input_shape[1]
    num_passage_neighbors = input_shape[2]

    question_neighbor_mask = tf.sequence_mask(tf.reshape(
        question_neighbor_size, [-1]),
                                              num_question_neighbors,
                                              dtype=tf.float32)
    question_neighbor_mask = tf.reshape(
        question_neighbor_mask,
        [batch_size, question_len, num_question_neighbors])

    passage_neighbor_mask = tf.sequence_mask(tf.reshape(
        passage_neighbor_size, [-1]),
                                             num_passage_neighbors,
                                             dtype=tf.float32)
    passage_neighbor_mask = tf.reshape(
        passage_neighbor_mask,
        [batch_size, passage_len, num_passage_neighbors])

    question_neighbor_edge_representations = tf.nn.embedding_lookup(
        edge_embedding, question_neighbor_edges)
    # [batch_size, question_len, num_question_neighbors, edge_dim]
    passage_neighbor_edge_representations = tf.nn.embedding_lookup(
        edge_embedding, passage_neighbor_edges)
    # [batch_size, passage_len, num_passage_neighbors, edge_dim]
    question_neighbor_node_representations = collect_neighbor_node_representations(
        in_question_repres, question_neighbor_indices)
    # [batch_size, question_len, num_question_neighbors, input_dim]
    passage_neighbor_node_representations = collect_neighbor_node_representations(
        in_passage_repres, passage_neighbor_indices)
    # [batch_size, passage_len, num_passage_neighbors, input_dim]

    question_neighbor_representations = tf.concat([
        question_neighbor_node_representations,
        question_neighbor_edge_representations
    ], 3)
    # [batch_size, question_len, num_question_neighbors, input_dim+ edge_dim]
    passage_neighbor_representations = tf.concat([
        passage_neighbor_node_representations,
        passage_neighbor_edge_representations
    ], 3)
    # [batch_size, passage_len, num_passage_neighbors, input_dim + edge_dim]

    # =====compress neighbor_representations
    compress_vector_dim = neighbor_vector_dim
    w_compress = tf.get_variable("w_compress",
                                 [input_dim + edge_dim, compress_vector_dim],
                                 dtype=tf.float32)
    b_compress = tf.get_variable("b_compress", [compress_vector_dim],
                                 dtype=tf.float32)

    question_neighbor_representations = tf.reshape(
        question_neighbor_representations, [-1, input_dim + edge_dim])
    question_neighbor_representations = tf.matmul(
        question_neighbor_representations, w_compress) + b_compress
    question_neighbor_representations = tf.tanh(
        question_neighbor_representations)
    # [batch_size*question_len*num_question_neighbors, compress_vector_dim]

    passage_neighbor_representations = tf.reshape(
        passage_neighbor_representations, [-1, input_dim + edge_dim])
    passage_neighbor_representations = tf.matmul(
        passage_neighbor_representations, w_compress) + b_compress
    passage_neighbor_representations = tf.tanh(
        passage_neighbor_representations)
    # [batch_size*passage_len*num_passage_neighbors, compress_vector_dim]

    # assume each node has a neighbor vector, and it is None at the beginning
    question_node_hidden = tf.zeros(
        [batch_size, question_len, neighbor_vector_dim])
    question_node_cell = tf.zeros(
        [batch_size, question_len, neighbor_vector_dim])

    passage_node_hidden = tf.zeros(
        [batch_size, passage_len, neighbor_vector_dim])
    passage_node_cell = tf.zeros(
        [batch_size, passage_len, neighbor_vector_dim])

    w_ingate = tf.get_variable("w_ingate",
                               [compress_vector_dim, neighbor_vector_dim],
                               dtype=tf.float32)
    u_ingate = tf.get_variable("u_ingate",
                               [neighbor_vector_dim, neighbor_vector_dim],
                               dtype=tf.float32)
    b_ingate = tf.get_variable("b_ingate", [neighbor_vector_dim],
                               dtype=tf.float32)

    w_forgetgate = tf.get_variable("w_forgetgate",
                                   [compress_vector_dim, neighbor_vector_dim],
                                   dtype=tf.float32)
    u_forgetgate = tf.get_variable("u_forgetgate",
                                   [neighbor_vector_dim, neighbor_vector_dim],
                                   dtype=tf.float32)
    b_forgetgate = tf.get_variable("b_forgetgate", [neighbor_vector_dim],
                                   dtype=tf.float32)

    w_outgate = tf.get_variable("w_outgate",
                                [compress_vector_dim, neighbor_vector_dim],
                                dtype=tf.float32)
    u_outgate = tf.get_variable("u_outgate",
                                [neighbor_vector_dim, neighbor_vector_dim],
                                dtype=tf.float32)
    b_outgate = tf.get_variable("b_outgate", [neighbor_vector_dim],
                                dtype=tf.float32)

    w_cell = tf.get_variable("w_cell",
                             [compress_vector_dim, neighbor_vector_dim],
                             dtype=tf.float32)
    u_cell = tf.get_variable("u_cell",
                             [neighbor_vector_dim, neighbor_vector_dim],
                             dtype=tf.float32)
    b_cell = tf.get_variable("b_cell", [neighbor_vector_dim], dtype=tf.float32)

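    # The w_*/u_*/b_* parameters above define one graph-LSTM step: each edge's
    # gates are computed from the compressed neighbor representation (w_*) and
    # the neighbor node's previous hidden state (u_*); the masked edge states
    # are then summed over neighbors to update the node states.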
    for i in xrange(num_syntax_match_layer):
        with tf.variable_scope('syntax_match_layer-{}'.format(i)):
            # ========for question============
            question_edge_prev_hidden = collect_neighbor_node_representations(
                question_node_hidden, question_neighbor_indices)
            # [batch_size, question_len, num_question_neighbors, neighbor_vector_dim]
            question_edge_prev_cell = collect_neighbor_node_representations(
                question_node_cell, question_neighbor_indices)
            # [batch_size, question_len, num_question_neighbors, neighbor_vector_dim]
            question_edge_prev_hidden = tf.reshape(question_edge_prev_hidden,
                                                   [-1, neighbor_vector_dim])
            question_edge_prev_cell = tf.reshape(question_edge_prev_cell,
                                                 [-1, neighbor_vector_dim])

            question_edge_ingate = tf.sigmoid(
                tf.matmul(question_neighbor_representations, w_ingate) +
                tf.matmul(question_edge_prev_hidden, u_ingate) + b_ingate)
            question_edge_forgetgate = tf.sigmoid(
                tf.matmul(question_neighbor_representations, w_forgetgate) +
                tf.matmul(question_edge_prev_hidden, u_forgetgate) +
                b_forgetgate)
            question_edge_outgate = tf.sigmoid(
                tf.matmul(question_neighbor_representations, w_outgate) +
                tf.matmul(question_edge_prev_hidden, u_outgate) + b_outgate)
            question_edge_cell_input = tf.tanh(
                tf.matmul(question_neighbor_representations, w_cell) +
                tf.matmul(question_edge_prev_hidden, u_cell) + b_cell)
            question_edge_cell = question_edge_forgetgate * question_edge_prev_cell + question_edge_ingate * question_edge_cell_input
            question_edge_hidden = question_edge_outgate * tf.tanh(
                question_edge_cell)
            question_edge_cell = tf.reshape(question_edge_cell, [
                batch_size, question_len, num_question_neighbors,
                neighbor_vector_dim
            ])
            question_edge_hidden = tf.reshape(question_edge_hidden, [
                batch_size, question_len, num_question_neighbors,
                neighbor_vector_dim
            ])
            # edge mask
            question_edge_cell = tf.multiply(
                question_edge_cell,
                tf.expand_dims(question_neighbor_mask, axis=-1))
            question_edge_hidden = tf.multiply(
                question_edge_hidden,
                tf.expand_dims(question_neighbor_mask, axis=-1))
            question_node_cell = tf.reduce_sum(question_edge_cell, axis=2)
            question_node_hidden = tf.reduce_sum(question_edge_hidden, axis=2)
            #[batch_size, question_len, neighbor_vector_dim]

            # node mask
            question_node_cell = question_node_cell * tf.expand_dims(
                question_mask, axis=-1)
            question_node_hidden = question_node_hidden * tf.expand_dims(
                question_mask, axis=-1)

            # ========for passage============
            passage_edge_prev_hidden = collect_neighbor_node_representations(
                passage_node_hidden, passage_neighbor_indices)
            passage_edge_prev_cell = collect_neighbor_node_representations(
                passage_node_cell, passage_neighbor_indices)
            # [batch_size, passage_len, num_passage_neighbors, neighbor_vector_dim]
            passage_edge_prev_hidden = tf.reshape(passage_edge_prev_hidden,
                                                  [-1, neighbor_vector_dim])
            passage_edge_prev_cell = tf.reshape(passage_edge_prev_cell,
                                                [-1, neighbor_vector_dim])

            passage_edge_ingate = tf.sigmoid(
                tf.matmul(passage_neighbor_representations, w_ingate) +
                tf.matmul(passage_edge_prev_hidden, u_ingate) + b_ingate)
            passage_edge_forgetgate = tf.sigmoid(
                tf.matmul(passage_neighbor_representations, w_forgetgate) +
                tf.matmul(passage_edge_prev_hidden, u_forgetgate) +
                b_forgetgate)
            passage_edge_outgate = tf.sigmoid(
                tf.matmul(passage_neighbor_representations, w_outgate) +
                tf.matmul(passage_edge_prev_hidden, u_outgate) + b_outgate)
            passage_edge_cell_input = tf.tanh(
                tf.matmul(passage_neighbor_representations, w_cell) +
                tf.matmul(passage_edge_prev_hidden, u_cell) + b_cell)
            passage_edge_cell = passage_edge_forgetgate * passage_edge_prev_cell + passage_edge_ingate * passage_edge_cell_input
            passage_edge_hidden = passage_edge_outgate * tf.tanh(
                passage_edge_cell)
            passage_edge_cell = tf.reshape(passage_edge_cell, [
                batch_size, passage_len, num_passage_neighbors,
                neighbor_vector_dim
            ])
            passage_edge_hidden = tf.reshape(passage_edge_hidden, [
                batch_size, passage_len, num_passage_neighbors,
                neighbor_vector_dim
            ])
            # edge mask
            passage_edge_cell = tf.multiply(
                passage_edge_cell,
                tf.expand_dims(passage_neighbor_mask, axis=-1))
            passage_edge_hidden = tf.multiply(
                passage_edge_hidden,
                tf.expand_dims(passage_neighbor_mask, axis=-1))
            passage_node_cell = tf.reduce_sum(passage_edge_cell, axis=2)
            passage_node_hidden = tf.reduce_sum(passage_edge_hidden, axis=2)

            # node mask
            passage_node_cell = passage_node_cell * tf.expand_dims(
                passage_mask, axis=-1)
            passage_node_hidden = passage_node_hidden * tf.expand_dims(
                passage_mask, axis=-1)

            #=====matching
            (node_matching_vectors,
             node_matching_dim) = match_utils.match_passage_with_question(
                 passage_node_hidden,
                 None,
                 passage_mask,
                 question_node_hidden,
                 None,
                 question_mask,
                 neighbor_vector_dim,
                 with_full_matching=False,
                 with_attentive_matching=with_attentive_matching,
                 with_max_attentive_matching=with_max_attentive_matching,
                 with_maxpooling_matching=with_maxpooling_matching,
                 with_forward_match=True,
                 with_backward_match=False,
                 match_options=match_options)
            all_matching_vectors.extend(
                node_matching_vectors
            )  #[batch_size, passage_len, node_matching_dim]
            all_matching_dim += node_matching_dim

    all_matching_vectors = tf.concat(
        all_matching_vectors,
        2)  # [batch_size, passage_len, all_matching_dim]
    return (all_matching_vectors, all_matching_dim)