Example No. 1
    def create_model_graph(self, num_classes, word_vocab=None, char_vocab=None, is_training=True, global_step=None):
        options = self.options
        # ======word representation layer======
        in_question_repres = [] # word and char
        in_passage_repres = [] # word and char
        input_dim = 0
        if word_vocab is not None:
            word_vec_trainable = True
            cur_device = '/gpu:0'
            if options.fix_word_vec:
                word_vec_trainable = False
                cur_device = '/cpu:0'
            with tf.device(cur_device):
                self.word_embedding = tf.get_variable("word_embedding", trainable=word_vec_trainable, 
                                                  initializer=tf.constant(word_vocab.word_vecs), dtype=tf.float32)

            in_question_word_repres = tf.nn.embedding_lookup(self.word_embedding, self.in_question_words) # [batch_size, question_len, word_dim]
            in_passage_word_repres = tf.nn.embedding_lookup(self.word_embedding, self.in_passage_words) # [batch_size, passage_len, word_dim]
            in_question_repres.append(in_question_word_repres)
            in_passage_repres.append(in_passage_word_repres)

            input_shape = tf.shape(self.in_question_words)
            batch_size = input_shape[0]
            question_len = input_shape[1]
            input_shape = tf.shape(self.in_passage_words)
            passage_len = input_shape[1]
            input_dim += word_vocab.word_dim
            
        if options.with_char and char_vocab is not None:
            input_shape = tf.shape(self.in_question_chars)
            batch_size = input_shape[0]
            question_len = input_shape[1]
            q_char_len = input_shape[2]
            input_shape = tf.shape(self.in_passage_chars)
            passage_len = input_shape[1]
            p_char_len = input_shape[2]
            char_dim = char_vocab.word_dim
            self.char_embedding = tf.get_variable("char_embedding", initializer=tf.constant(char_vocab.word_vecs), dtype=tf.float32)

            in_question_char_repres = tf.nn.embedding_lookup(self.char_embedding, self.in_question_chars) # [batch_size, question_len, q_char_len, char_dim]
            in_question_char_repres = tf.reshape(in_question_char_repres, shape=[-1, q_char_len, char_dim])
            question_char_lengths = tf.reshape(self.question_char_lengths, [-1])
            question_char_mask = tf.sequence_mask(question_char_lengths, q_char_len, dtype=tf.float32)  # [batch_size*question_len, q_char_len]
            in_question_char_repres = tf.multiply(in_question_char_repres, tf.expand_dims(question_char_mask, axis=-1))


            in_passage_char_repres = tf.nn.embedding_lookup(self.char_embedding, self.in_passage_chars) # [batch_size, passage_len, p_char_len, char_dim]
            in_passage_char_repres = tf.reshape(in_passage_char_repres, shape=[-1, p_char_len, char_dim])
            passage_char_lengths = tf.reshape(self.passage_char_lengths, [-1])
            passage_char_mask = tf.sequence_mask(passage_char_lengths, p_char_len, dtype=tf.float32)  # [batch_size*passage_len, p_char_len]
            in_passage_char_repres = tf.multiply(in_passage_char_repres, tf.expand_dims(passage_char_mask, axis=-1))

            (question_char_outputs_fw, question_char_outputs_bw, _) = layer_utils.my_lstm_layer(in_question_char_repres, options.char_lstm_dim,
                    input_lengths=question_char_lengths,scope_name="char_lstm", reuse=False,
                    is_training=is_training, dropout_rate=options.dropout_rate, use_cudnn=options.use_cudnn)
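            # The forward LSTM's summary is its last valid time step; the backward LSTM's summary sits at step 0.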
            question_char_outputs_fw = layer_utils.collect_final_step_of_lstm(question_char_outputs_fw, question_char_lengths - 1)
            question_char_outputs_bw = question_char_outputs_bw[:, 0, :]
            question_char_outputs = tf.concat(axis=1, values=[question_char_outputs_fw, question_char_outputs_bw])
            question_char_outputs = tf.reshape(question_char_outputs, [batch_size, question_len, 2*options.char_lstm_dim])

            (passage_char_outputs_fw, passage_char_outputs_bw, _) = layer_utils.my_lstm_layer(in_passage_char_repres, options.char_lstm_dim,
                    input_lengths=passage_char_lengths, scope_name="char_lstm", reuse=True,
                    is_training=is_training, dropout_rate=options.dropout_rate, use_cudnn=options.use_cudnn)
            passage_char_outputs_fw = layer_utils.collect_final_step_of_lstm(passage_char_outputs_fw, passage_char_lengths - 1)
            passage_char_outputs_bw = passage_char_outputs_bw[:, 0, :]
            passage_char_outputs = tf.concat(axis=1, values=[passage_char_outputs_fw, passage_char_outputs_bw])
            passage_char_outputs = tf.reshape(passage_char_outputs, [batch_size, passage_len, 2*options.char_lstm_dim])
                
            in_question_repres.append(question_char_outputs)
            in_passage_repres.append(passage_char_outputs)

            input_dim += 2*options.char_lstm_dim

        in_question_repres = tf.concat(axis=2, values=in_question_repres) # [batch_size, question_len, dim] # concat word and char
        in_passage_repres = tf.concat(axis=2, values=in_passage_repres) # [batch_size, passage_len, dim] # concat word and char

        if is_training:
            in_question_repres = tf.nn.dropout(in_question_repres, (1 - options.dropout_rate))
            in_passage_repres = tf.nn.dropout(in_passage_repres, (1 - options.dropout_rate))

        mask = tf.sequence_mask(self.passage_lengths, passage_len, dtype=tf.float32) # [batch_size, passage_len]
        question_mask = tf.sequence_mask(self.question_lengths, question_len, dtype=tf.float32) # [batch_size, question_len]

        # ======Highway layer======
        if options.with_highway:
            with tf.variable_scope("input_highway"):
                in_question_repres = match_utils.multi_highway_layer(in_question_repres, input_dim, options.highway_layer_num)
                tf.get_variable_scope().reuse_variables()
                in_passage_repres = match_utils.multi_highway_layer(in_passage_repres, input_dim, options.highway_layer_num)

        # in_question_repres = tf.multiply(in_question_repres, tf.expand_dims(question_mask, axis=-1))
        # in_passage_repres = tf.multiply(in_passage_repres, tf.expand_dims(mask, axis=-1))

        # ========Bilateral Matching=====
        (match_representation, match_dim) = match_utils.bilateral_match_func(in_question_repres, in_passage_repres,
                        self.question_lengths, self.passage_lengths, question_mask, mask, input_dim, is_training, options=options)

        #========Prediction Layer=========
        # match_dim = 4 * self.options.aggregation_lstm_dim
        w_0 = tf.get_variable("w_0", [match_dim, match_dim // 2], dtype=tf.float32)
        b_0 = tf.get_variable("b_0", [match_dim // 2], dtype=tf.float32)
        w_1 = tf.get_variable("w_1", [match_dim // 2, num_classes], dtype=tf.float32)
        b_1 = tf.get_variable("b_1", [num_classes], dtype=tf.float32)

        # if is_training: match_representation = tf.nn.dropout(match_representation, (1 - options.dropout_rate))
        logits = tf.matmul(match_representation, w_0) + b_0
        logits = tf.tanh(logits)
        if is_training: logits = tf.nn.dropout(logits, (1 - options.dropout_rate))
        logits = tf.matmul(logits, w_1) + b_1

        self.prob = tf.nn.softmax(logits)
        
        gold_matrix = tf.one_hot(self.truth, num_classes, dtype=tf.float32)
        self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=gold_matrix))

        correct = tf.nn.in_top_k(logits, self.truth, 1)
        self.eval_correct = tf.reduce_sum(tf.cast(correct, tf.int32))
        self.predictions = tf.argmax(self.prob, 1)

        if not is_training: return

        tvars = tf.trainable_variables()
        if self.options.lambda_l2>0.0:
            l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in tvars if v.get_shape().ndims > 1])
            self.loss = self.loss + self.options.lambda_l2 * l2_loss

        if self.options.optimize_type == 'adadelta':
            optimizer = tf.train.AdadeltaOptimizer(learning_rate=self.options.learning_rate)
        elif self.options.optimize_type == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate=self.options.learning_rate)

        grads = layer_utils.compute_gradients(self.loss, tvars)
        grads, _ = tf.clip_by_global_norm(grads, self.options.grad_clipper)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step)
        # self.train_op = optimizer.apply_gradients(zip(grads, tvars))

        if self.options.with_moving_average:
            # Track the moving averages of all trainable variables.
            MOVING_AVERAGE_DECAY = 0.9999  # The decay to use for the moving average.
            variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
            variables_averages_op = variable_averages.apply(tf.trainable_variables())
            train_ops = [self.train_op, variables_averages_op]
            self.train_op = tf.group(*train_ops)
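
`layer_utils.collect_final_step_of_lstm` is called above with `lengths - 1` to pick each sequence's output at its last valid time step. Its implementation is not shown in these examples; a minimal sketch of such a helper (an assumption, not the library's exact code):

import tensorflow as tf

def collect_final_step_of_lstm(lstm_outputs, positions):
    # lstm_outputs: [batch_size, time, dim]; positions: [batch_size], e.g. lengths - 1.
    # Gather, for each batch element, the output at its last valid time step.
    positions = tf.cast(positions, tf.int32)
    batch_size = tf.shape(lstm_outputs)[0]
    indices = tf.stack([tf.range(batch_size), positions], axis=1)  # [batch_size, 2]
    return tf.gather_nd(lstm_outputs, indices)                     # [batch_size, dim]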
Example No. 2
    def bilateral_match_func(self, in_question_repres, in_passage_repres,
                        question_lengths, passage_lengths, question_mask,
                        passage_mask, input_dim):
        question_aware_representatins = []
        question_aware_dim = 0
        passage_aware_representatins = []
        passage_aware_dim = 0

        # ====word level matching======
        (match_reps, match_dim) = self.match_passage_with_question(in_passage_repres,
                                    in_question_repres, passage_mask, question_mask,
                                    passage_lengths,
                                    question_lengths, input_dim, scope="word_match_forward",
                                    with_full_match=False, with_maxpool_match=self.config.with_maxpool_match,
                                    with_attentive_match=self.config.with_attentive_match,
                                    with_max_attentive_match=self.config.with_max_attentive_match,
                                    dropout_rate=self.dropout_rate, forward=True)
        question_aware_representatins.append(match_reps)
        question_aware_dim += match_dim

        (match_reps, match_dim) = self.match_passage_with_question(in_question_repres,
                                    in_passage_repres, question_mask, passage_mask,
                                    question_lengths,
                                    passage_lengths, input_dim, scope="word_match_backward",
                                    with_full_match=False, with_maxpool_match=self.config.with_maxpool_match,
                                    with_attentive_match=self.config.with_attentive_match,
                                    with_max_attentive_match=self.config.with_max_attentive_match,
                                    dropout_rate=self.dropout_rate, forward=False)
        passage_aware_representatins.append(match_reps)
        passage_aware_dim += match_dim

        with tf.variable_scope('context_MP_matching'):
            for i in range(self.config.context_layer_num): # support multiple context layers
                with tf.variable_scope('layer-{}'.format(i)):
                    # contextual lstm for both passage and question
                    in_question_repres = tf.multiply(in_question_repres, tf.expand_dims(question_mask, axis=-1))
                    in_passage_repres = tf.multiply(in_passage_repres, tf.expand_dims(passage_mask, axis=-1))
                    (question_context_representation_fw, question_context_representation_bw,
                     in_question_repres) = layer_utils.my_lstm_layer(
                            in_question_repres, self.config.context_lstm_dim, input_lengths=question_lengths, scope_name="context_represent",
                            reuse=False, dropout_rate=self.dropout_rate, use_cudnn=self.config.use_cudnn)
                    (passage_context_representation_fw, passage_context_representation_bw,
                     in_passage_repres) = layer_utils.my_lstm_layer(
                            in_passage_repres, self.config.context_lstm_dim, input_lengths=passage_lengths, scope_name="context_represent",
                            reuse=True, dropout_rate=self.dropout_rate, use_cudnn=self.config.use_cudnn)

                    # Multi-perspective matching
                    with tf.variable_scope('left_MP_matching'):
                        (match_reps, match_dim) = self.match_passage_with_question(passage_context_representation_fw,
                                    question_context_representation_fw, passage_mask, question_mask, passage_lengths,
                                    question_lengths, self.config.context_lstm_dim, scope="forward_match",
                                    with_full_match=self.config.with_full_match, with_maxpool_match=self.config.with_maxpool_match,
                                    with_attentive_match=self.config.with_attentive_match,
                                    with_max_attentive_match=self.config.with_max_attentive_match,
                                    dropout_rate=self.dropout_rate, forward=True)
                        question_aware_representatins.append(match_reps)
                        question_aware_dim += match_dim
                        (match_reps, match_dim) = self.match_passage_with_question(passage_context_representation_bw,
                                    question_context_representation_bw, passage_mask, question_mask, passage_lengths,
                                    question_lengths, self.config.context_lstm_dim, scope="backward_match",
                                    with_full_match=self.config.with_full_match, with_maxpool_match=self.config.with_maxpool_match,
                                    with_attentive_match=self.config.with_attentive_match,
                                    with_max_attentive_match=self.config.with_max_attentive_match,
                                    dropout_rate=self.dropout_rate, forward=False)
                        question_aware_representatins.append(match_reps)
                        question_aware_dim += match_dim

                    with tf.variable_scope('right_MP_matching'):
                        (match_reps, match_dim) = self.match_passage_with_question(question_context_representation_fw,
                                    passage_context_representation_fw, question_mask, passage_mask, question_lengths,
                                    passage_lengths, self.config.context_lstm_dim, scope="forward_match",
                                    with_full_match=self.config.with_full_match, with_maxpool_match=self.config.with_maxpool_match,
                                    with_attentive_match=self.config.with_attentive_match,
                                    with_max_attentive_match=self.config.with_max_attentive_match,
                                    dropout_rate=self.dropout_rate, forward=True)
                        passage_aware_representatins.append(match_reps)
                        passage_aware_dim += match_dim
                        (match_reps, match_dim) = self.match_passage_with_question(question_context_representation_bw,
                                    passage_context_representation_bw, question_mask, passage_mask, question_lengths,
                                    passage_lengths, self.config.context_lstm_dim, scope="backward_match",
                                    with_full_match=self.config.with_full_match, with_maxpool_match=self.config.with_maxpool_match,
                                    with_attentive_match=self.config.with_attentive_match,
                                    with_max_attentive_match=self.config.with_max_attentive_match,
                                    dropout_rate=self.dropout_rate, forward=False)
                        passage_aware_representatins.append(match_reps)
                        passage_aware_dim += match_dim

        question_aware_representatins = tf.concat(axis=2, values=question_aware_representatins) # [batch_size, passage_len, question_aware_dim]
        passage_aware_representatins = tf.concat(axis=2, values=passage_aware_representatins) # [batch_size, question_len, passage_aware_dim]

        question_aware_representatins = tf.nn.dropout(question_aware_representatins, (1 - self.dropout_rate))
        passage_aware_representatins = tf.nn.dropout(passage_aware_representatins, (1 - self.dropout_rate))

        # ======Highway layer======
        if self.config.with_match_highway:
            with tf.variable_scope("left_matching_highway"):
                question_aware_representatins = self.multi_highway_layer(question_aware_representatins, question_aware_dim,
                                                                    self.config.highway_layer_num)
            with tf.variable_scope("right_matching_highway"):
                passage_aware_representatins = self.multi_highway_layer(passage_aware_representatins, passage_aware_dim,
                                                               self.config.highway_layer_num)

        #========Aggregation Layer======
        aggregation_representation = []
        aggregation_dim = 0

        qa_aggregation_input = question_aware_representatins
        pa_aggregation_input = passage_aware_representatins
        with tf.variable_scope('aggregation_layer'):
            for i in range(self.config.aggregation_layer_num): # support multiple aggregation layers
                qa_aggregation_input = tf.multiply(qa_aggregation_input, tf.expand_dims(passage_mask, axis=-1))
                (fw_rep, bw_rep, cur_aggregation_representation) = layer_utils.my_lstm_layer(
                            qa_aggregation_input, self.config.aggregation_lstm_dim,
                            input_lengths=passage_lengths, scope_name='left_layer-{}'.format(i),
                            reuse=False, dropout_rate=self.dropout_rate, use_cudnn=self.config.use_cudnn)
                fw_rep = layer_utils.collect_final_step_of_lstm(fw_rep, passage_lengths - 1)
                bw_rep = bw_rep[:, 0, :]
                aggregation_representation.append(fw_rep)
                aggregation_representation.append(bw_rep)
                aggregation_dim += 2 * self.config.aggregation_lstm_dim
                # [batch_size, question_len, 2*aggregation_lstm_dim]
                qa_aggregation_input = cur_aggregation_representation

                pa_aggregation_input = tf.multiply(pa_aggregation_input, tf.expand_dims(question_mask, axis=-1))
                (fw_rep, bw_rep, cur_aggregation_representation) = layer_utils.my_lstm_layer(
                            pa_aggregation_input, self.config.aggregation_lstm_dim,
                            input_lengths=question_lengths, scope_name='right_layer-{}'.format(i),
                            reuse=False, dropout_rate=self.dropout_rate, use_cudnn=self.config.use_cudnn)
                fw_rep = layer_utils.collect_final_step_of_lstm(fw_rep, question_lengths - 1)
                bw_rep = bw_rep[:, 0, :]
                aggregation_representation.append(fw_rep)
                aggregation_representation.append(bw_rep)
                aggregation_dim += 2 * self.config.aggregation_lstm_dim
                # [batch_size, passage_len, 2*aggregation_lstm_dim]
                pa_aggregation_input = cur_aggregation_representation

        # [batch_size, 4*aggregation_lstm_dim*aggregation_layer_num]
        aggregation_representation = tf.concat(axis=1, values=aggregation_representation)

        # ======Highway layer======
        if self.config.with_aggregation_highway:
            with tf.variable_scope("aggregation_highway"):
                agg_shape = tf.shape(aggregation_representation)
                batch_size = agg_shape[0]
                aggregation_representation = tf.reshape(aggregation_representation, [1, batch_size, aggregation_dim])
                aggregation_representation = self.multi_highway_layer(aggregation_representation, aggregation_dim, self.config.highway_layer_num)
                aggregation_representation = tf.reshape(aggregation_representation, [batch_size, aggregation_dim])
        return (aggregation_representation, aggregation_dim)
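
`layer_utils.my_lstm_layer` returns the per-step forward outputs, the per-step backward outputs, and their concatenation. A minimal sketch of such a helper (an assumption; the real library also offers a cuDNN path toggled by `use_cudnn`, omitted here):

import tensorflow as tf

def my_lstm_layer(input_reps, lstm_dim, input_lengths=None, scope_name=None,
                  reuse=False, is_training=True, dropout_rate=0.0, use_cudnn=False):
    # Bi-directional LSTM over input_reps: [batch_size, time, dim].
    with tf.variable_scope(scope_name, reuse=reuse):
        cell_fw = tf.nn.rnn_cell.LSTMCell(lstm_dim)
        cell_bw = tf.nn.rnn_cell.LSTMCell(lstm_dim)
        (outputs_fw, outputs_bw), _ = tf.nn.bidirectional_dynamic_rnn(
            cell_fw, cell_bw, input_reps, sequence_length=input_lengths, dtype=tf.float32)
        outputs = tf.concat([outputs_fw, outputs_bw], axis=2)  # [batch_size, time, 2*lstm_dim]
        if is_training:
            outputs = tf.nn.dropout(outputs, 1 - dropout_rate)
    return (outputs_fw, outputs_bw, outputs)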
Example No. 3
    def match_passage_with_question(self, passage_reps, question_reps,
                                    passage_mask, question_mask,
                                    passage_lengths, question_lengths,
                                    context_lstm_dim, scope=None,
                                    with_full_match=True, with_maxpool_match=True,
                                    with_attentive_match=True, with_max_attentive_match=True,
                                    dropout_rate=0, forward=True):
        passage_reps = tf.multiply(passage_reps, tf.expand_dims(passage_mask,-1))
        question_reps = tf.multiply(question_reps, tf.expand_dims(question_mask,-1))
        all_question_aware_representatins = []
        dim = 0
        with tf.variable_scope(scope or "match_passage_with_question"):
            # relevancy_matrix: [batch_size, p_len, q_len]
            relevancy_matrix = self.cal_relevancy_matrix(question_reps, passage_reps)
            relevancy_matrix = self.mask_relevancy_matrix(relevancy_matrix, question_mask, passage_mask)

            all_question_aware_representatins.append(tf.reduce_max(relevancy_matrix, axis=2, keep_dims=True))
            all_question_aware_representatins.append(tf.reduce_mean(relevancy_matrix, axis=2, keep_dims=True))
            dim += 2
            if with_full_match:
                if forward:
                    question_full_rep = layer_utils.collect_final_step_of_lstm(question_reps, question_lengths - 1)
                else:
                    question_full_rep = question_reps[:,0,:]

                passage_len = tf.shape(passage_reps)[1]
                question_full_rep = tf.expand_dims(question_full_rep, axis=1)
                # [batch_size, passage_len, feature_dim]
                question_full_rep = tf.tile(question_full_rep, [1, passage_len, 1])
                # attentive_rep: [batch_size, passage_len, match_dim]
                (attentive_rep, match_dim) = self.multi_perspective_match(context_lstm_dim,
                                    passage_reps, question_full_rep,
                                    dropout_rate=self.dropout_rate,
                                    scope_name='mp-match-full-match')
                all_question_aware_representatins.append(attentive_rep)
                dim += match_dim

            if with_maxpool_match:
                maxpooling_decomp_params = tf.get_variable("maxpooling_matching_decomp",
                                                shape=[self.config.cosine_MP_dim, context_lstm_dim],
                                                dtype=tf.float32)
                # maxpooling_rep: [batch_size, passage_len, 2 * cosine_MP_dim]
                maxpooling_rep = self.cal_maxpooling_matching(passage_reps,
                                    question_reps, maxpooling_decomp_params)
                all_question_aware_representatins.append(maxpooling_rep)
                dim += 2 * self.config.cosine_MP_dim

            if with_attentive_match:
                # atten_scores: [batch_size, p_len, q_len]
                atten_scores = layer_utils.calcuate_attention(passage_reps, question_reps, context_lstm_dim, context_lstm_dim,
                        scope_name="attention", att_type=self.config.att_type, att_dim=self.config.att_dim,
                        remove_diagnoal=False, mask1=passage_mask, mask2=question_mask, dropout_rate=self.dropout_rate)
                att_question_contexts = tf.matmul(atten_scores, question_reps)
                (attentive_rep, match_dim) = self.multi_perspective_match(context_lstm_dim,
                        passage_reps, att_question_contexts, dropout_rate=self.dropout_rate,
                        scope_name='mp-match-att_question')
                all_question_aware_representatins.append(attentive_rep)
                dim += match_dim

            if with_max_attentive_match:
                # relevancy_matrix: [batch_size, p_len, q_len]
                # question_reps: [batch_size, q_len, dim]
                # max_att: [batch_size, p_len, dim]
                max_att = self.cal_max_question_representation(question_reps, relevancy_matrix)
                # max_attentive_rep: [batch_size, passage_len, match_dim]
                (max_attentive_rep, match_dim) = self.multi_perspective_match(context_lstm_dim,
                        passage_reps, max_att, dropout_rate=self.dropout_rate,
                        scope_name='mp-match-max-att')
                all_question_aware_representatins.append(max_attentive_rep)
                dim += match_dim

            all_question_aware_representatins = tf.concat(axis=2, values=all_question_aware_representatins)
        return (all_question_aware_representatins, dim)
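
The helpers `cal_relevancy_matrix` and `cal_max_question_representation` are not shown in these examples. Based on the shape comments above, a minimal sketch of what they likely compute (an assumption, not the library's exact code):

import tensorflow as tf

def cal_relevancy_matrix(question_reps, passage_reps):
    # Cosine similarity between every passage position and every question position.
    # question_reps: [batch_size, q_len, dim]; passage_reps: [batch_size, p_len, dim]
    q = tf.expand_dims(question_reps, axis=1)   # [batch_size, 1, q_len, dim]
    p = tf.expand_dims(passage_reps, axis=2)    # [batch_size, p_len, 1, dim]
    dot = tf.reduce_sum(q * p, axis=-1)         # [batch_size, p_len, q_len]
    norms = tf.norm(q, axis=-1) * tf.norm(p, axis=-1) + 1e-6
    return dot / norms                          # [batch_size, p_len, q_len]

def cal_max_question_representation(question_reps, relevancy_matrix):
    # For each passage position, gather the question vector with the highest relevancy.
    positions = tf.cast(tf.argmax(relevancy_matrix, axis=2), tf.int32)   # [batch_size, p_len]
    batch_size = tf.shape(positions)[0]
    p_len = tf.shape(positions)[1]
    batch_idx = tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, p_len])
    indices = tf.stack([batch_idx, positions], axis=2)                   # [batch_size, p_len, 2]
    return tf.gather_nd(question_reps, indices)                          # [batch_size, p_len, dim]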
Example No. 4
def MCAN_match_func(in_question_repres,
                    in_passage_repres,
                    question_lengths,
                    passage_lengths,
                    question_mask,
                    passage_mask,
                    input_dim,
                    is_training,
                    options=None):
    question_aware_representatins = []
    question_aware_dim = 0
    passage_aware_representatins = []
    passage_aware_dim = 0

    # ====word level matching======
    # with_full_match is always False at the word level, so the value of `forward` has no effect here.
    # match_passage_with_question(repres1, repres2, ...) matches each vector of repres1 against all of
    # repres2, so the returned match_reps has shape [batch_size, repres1_len, match_dim].
    # passage to question
    (match_reps, match_dim) = match_passage_with_question(
        in_passage_repres,
        in_question_repres,
        passage_mask,
        question_mask,
        passage_lengths,
        question_lengths,
        input_dim,
        scope="word_match_forward",
        with_full_match=False,
        with_maxpool_match=options.with_maxpool_match,
        with_attentive_match=options.with_attentive_match,
        with_max_attentive_match=options.with_max_attentive_match,
        is_training=is_training,
        options=options,
        dropout_rate=options.dropout_rate,
        forward=False)
    question_aware_representatins.append(match_reps)
    question_aware_dim += match_dim

    # add passage to passage
    (match_reps, match_dim) = match_passage_with_question(
        in_passage_repres,
        in_passage_repres,
        passage_mask,
        passage_mask,
        passage_lengths,
        passage_lengths,
        input_dim,
        scope="word_match_passage",
        with_full_match=False,
        with_maxpool_match=options.with_maxpool_match,
        with_attentive_match=options.with_attentive_match,
        with_max_attentive_match=options.with_max_attentive_match,
        is_training=is_training,
        options=options,
        dropout_rate=options.dropout_rate,
        forward=False)
    question_aware_representatins.append(match_reps)
    question_aware_dim += match_dim

    # question to passage
    (match_reps, match_dim) = match_passage_with_question(
        in_question_repres,
        in_passage_repres,
        question_mask,
        passage_mask,
        question_lengths,
        passage_lengths,
        input_dim,
        scope="word_match_backward",
        with_full_match=False,
        with_maxpool_match=options.with_maxpool_match,
        with_attentive_match=options.with_attentive_match,
        with_max_attentive_match=options.with_max_attentive_match,
        is_training=is_training,
        options=options,
        dropout_rate=options.dropout_rate,
        forward=False)
    passage_aware_representatins.append(match_reps)
    passage_aware_dim += match_dim

    # add question to question
    (match_reps, match_dim) = match_passage_with_question(
        in_question_repres,
        in_question_repres,
        question_mask,
        question_mask,
        question_lengths,
        question_lengths,
        input_dim,
        scope="word_match_question",
        with_full_match=False,
        with_maxpool_match=options.with_maxpool_match,
        with_attentive_match=options.with_attentive_match,
        with_max_attentive_match=options.with_max_attentive_match,
        is_training=is_training,
        options=options,
        dropout_rate=options.dropout_rate,
        forward=False)
    passage_aware_representatins.append(match_reps)
    passage_aware_dim += match_dim

    with tf.variable_scope('context_MP_matching'):
        for i in range(options.context_layer_num):  # support multiple context layers
            with tf.variable_scope('layer-{}'.format(i)):
                # contextual lstm for both passage and question
                in_question_repres = tf.multiply(
                    in_question_repres, tf.expand_dims(question_mask, axis=-1))
                in_passage_repres = tf.multiply(
                    in_passage_repres, tf.expand_dims(passage_mask, axis=-1))
                (question_context_representation_fw,
                 question_context_representation_bw,
                 in_question_repres) = layer_utils.my_lstm_layer(
                     in_question_repres,
                     options.context_lstm_dim,
                     input_lengths=question_lengths,
                     scope_name="context_represent",
                     reuse=False,
                     is_training=is_training,
                     dropout_rate=options.dropout_rate,
                     use_cudnn=options.use_cudnn)
                (passage_context_representation_fw,
                 passage_context_representation_bw,
                 in_passage_repres) = layer_utils.my_lstm_layer(
                     in_passage_repres,
                     options.context_lstm_dim,
                     input_lengths=passage_lengths,
                     scope_name="context_represent",
                     reuse=True,
                     is_training=is_training,
                     dropout_rate=options.dropout_rate,
                     use_cudnn=options.use_cudnn)

                # Multi-perspective matching
                with tf.variable_scope('left_MP_matching'):
                    (match_reps, match_dim) = match_passage_with_question(
                        passage_context_representation_fw,
                        question_context_representation_fw,
                        passage_mask,
                        question_mask,
                        passage_lengths,
                        question_lengths,
                        options.context_lstm_dim,
                        scope="ques_forward_match",
                        with_full_match=options.with_full_match,
                        with_maxpool_match=options.with_maxpool_match,
                        with_attentive_match=options.with_attentive_match,
                        with_max_attentive_match=options.with_max_attentive_match,
                        is_training=is_training,
                        options=options,
                        dropout_rate=options.dropout_rate,
                        forward=True)
                    question_aware_representatins.append(match_reps)
                    question_aware_dim += match_dim
                    (match_reps, match_dim) = match_passage_with_question(
                        passage_context_representation_bw,
                        question_context_representation_bw,
                        passage_mask,
                        question_mask,
                        passage_lengths,
                        question_lengths,
                        options.context_lstm_dim,
                        scope="ques_backward_match",
                        with_full_match=options.with_full_match,
                        with_maxpool_match=options.with_maxpool_match,
                        with_attentive_match=options.with_attentive_match,
                        with_max_attentive_match=options.with_max_attentive_match,
                        is_training=is_training,
                        options=options,
                        dropout_rate=options.dropout_rate,
                        forward=False)
                    question_aware_representatins.append(match_reps)
                    question_aware_dim += match_dim
                    # add passage to passage
                    (match_reps, match_dim) = match_passage_with_question(
                        passage_context_representation_fw,
                        passage_context_representation_fw,
                        passage_mask,
                        passage_mask,
                        passage_lengths,
                        passage_lengths,
                        options.context_lstm_dim,
                        scope="pass_self_forward_match",
                        with_full_match=options.with_full_match,
                        with_maxpool_match=options.with_maxpool_match,
                        with_attentive_match=options.with_attentive_match,
                        with_max_attentive_match=options.with_max_attentive_match,
                        is_training=is_training,
                        options=options,
                        dropout_rate=options.dropout_rate,
                        forward=True)

                    question_aware_representatins.append(match_reps)
                    question_aware_dim += match_dim
                    (match_reps, match_dim) = match_passage_with_question(
                        passage_context_representation_bw,
                        passage_context_representation_bw,
                        passage_mask,
                        passage_mask,
                        passage_lengths,
                        passage_lengths,
                        options.context_lstm_dim,
                        scope="pass_self_backward_match",
                        with_full_match=options.with_full_match,
                        with_maxpool_match=options.with_maxpool_match,
                        with_attentive_match=options.with_attentive_match,
                        with_max_attentive_match=options.with_max_attentive_match,
                        is_training=is_training,
                        options=options,
                        dropout_rate=options.dropout_rate,
                        forward=False)
                    question_aware_representatins.append(match_reps)
                    question_aware_dim += match_dim

                with tf.variable_scope('right_MP_matching'):
                    (match_reps, match_dim) = match_passage_with_question(
                        question_context_representation_fw,
                        passage_context_representation_fw,
                        question_mask,
                        passage_mask,
                        question_lengths,
                        passage_lengths,
                        options.context_lstm_dim,
                        scope="pass_forward_match",
                        with_full_match=options.with_full_match,
                        with_maxpool_match=options.with_maxpool_match,
                        with_attentive_match=options.with_attentive_match,
                        with_max_attentive_match=options.with_max_attentive_match,
                        is_training=is_training,
                        options=options,
                        dropout_rate=options.dropout_rate,
                        forward=True)
                    passage_aware_representatins.append(match_reps)
                    passage_aware_dim += match_dim
                    (match_reps, match_dim) = match_passage_with_question(
                        question_context_representation_bw,
                        passage_context_representation_bw,
                        question_mask,
                        passage_mask,
                        question_lengths,
                        passage_lengths,
                        options.context_lstm_dim,
                        scope="pass_backward_match",
                        with_full_match=options.with_full_match,
                        with_maxpool_match=options.with_maxpool_match,
                        with_attentive_match=options.with_attentive_match,
                        with_max_attentive_match=options.with_max_attentive_match,
                        is_training=is_training,
                        options=options,
                        dropout_rate=options.dropout_rate,
                        forward=False)
                    passage_aware_representatins.append(match_reps)
                    passage_aware_dim += match_dim
                    # add question to question
                    (match_reps, match_dim) = match_passage_with_question(
                        question_context_representation_fw,
                        question_context_representation_fw,
                        question_mask,
                        question_mask,
                        question_lengths,
                        question_lengths,
                        options.context_lstm_dim,
                        scope="ques_self_forward_match",
                        with_full_match=options.with_full_match,
                        with_maxpool_match=options.with_maxpool_match,
                        with_attentive_match=options.with_attentive_match,
                        with_max_attentive_match=options.with_max_attentive_match,
                        is_training=is_training,
                        options=options,
                        dropout_rate=options.dropout_rate,
                        forward=True)
                    passage_aware_representatins.append(match_reps)
                    passage_aware_dim += match_dim
                    (match_reps, match_dim) = match_passage_with_question(
                        question_context_representation_bw,
                        question_context_representation_bw,
                        question_mask,
                        question_mask,
                        question_lengths,
                        question_lengths,
                        options.context_lstm_dim,
                        scope="ques_self_backward_match",
                        with_full_match=options.with_full_match,
                        with_maxpool_match=options.with_maxpool_match,
                        with_attentive_match=options.with_attentive_match,
                        with_max_attentive_match=options.with_max_attentive_match,
                        is_training=is_training,
                        options=options,
                        dropout_rate=options.dropout_rate,
                        forward=False)
                    passage_aware_representatins.append(match_reps)
                    passage_aware_dim += match_dim

    question_aware_representatins = tf.concat(
        axis=2, values=question_aware_representatins
    )  # [batch_size, passage_len, question_aware_dim]
    passage_aware_representatins = tf.concat(
        axis=2, values=passage_aware_representatins
    )  # [batch_size, question_len, passage_aware_dim]

    if is_training:
        question_aware_representatins = tf.nn.dropout(
            question_aware_representatins, (1 - options.dropout_rate))
        passage_aware_representatins = tf.nn.dropout(
            passage_aware_representatins, (1 - options.dropout_rate))

    # ======Highway layer======
    if options.with_match_highway:
        with tf.variable_scope("left_matching_highway"):
            question_aware_representatins = multi_highway_layer(
                question_aware_representatins, question_aware_dim,
                options.highway_layer_num)
        with tf.variable_scope("right_matching_highway"):
            passage_aware_representatins = multi_highway_layer(
                passage_aware_representatins, passage_aware_dim,
                options.highway_layer_num)

    # ========Aggregation Layer======
    aggregation_representation = []
    aggregation_dim = 0

    qa_aggregation_input = question_aware_representatins
    pa_aggregation_input = passage_aware_representatins
    with tf.variable_scope('aggregation_layer'):
        for i in range(options.aggregation_layer_num):  # support multiple aggregation layers
            qa_aggregation_input = tf.multiply(
                qa_aggregation_input, tf.expand_dims(passage_mask, axis=-1))
            (fw_rep, bw_rep,
             cur_aggregation_representation) = layer_utils.my_lstm_layer(
                 qa_aggregation_input,
                 options.aggregation_lstm_dim,
                 input_lengths=passage_lengths,
                 scope_name='left_layer-{}'.format(i),
                 reuse=False,
                 is_training=is_training,
                 dropout_rate=options.dropout_rate,
                 use_cudnn=options.use_cudnn)
            fw_rep = layer_utils.collect_final_step_of_lstm(
                fw_rep, passage_lengths - 1)
            bw_rep = bw_rep[:, 0, :]
            aggregation_representation.append(fw_rep)
            aggregation_representation.append(bw_rep)
            aggregation_dim += 2 * options.aggregation_lstm_dim
            qa_aggregation_input = cur_aggregation_representation  # [batch_size, passage_len, 2*aggregation_lstm_dim]

            pa_aggregation_input = tf.multiply(
                pa_aggregation_input, tf.expand_dims(question_mask, axis=-1))
            (fw_rep, bw_rep,
             cur_aggregation_representation) = layer_utils.my_lstm_layer(
                 pa_aggregation_input,
                 options.aggregation_lstm_dim,
                 input_lengths=question_lengths,
                 scope_name='right_layer-{}'.format(i),
                 reuse=False,
                 is_training=is_training,
                 dropout_rate=options.dropout_rate,
                 use_cudnn=options.use_cudnn)
            fw_rep = layer_utils.collect_final_step_of_lstm(
                fw_rep, question_lengths - 1)
            bw_rep = bw_rep[:, 0, :]
            aggregation_representation.append(fw_rep)
            aggregation_representation.append(bw_rep)
            aggregation_dim += 2 * options.aggregation_lstm_dim
            pa_aggregation_input = cur_aggregation_representation  # [batch_size, passage_len, 2*aggregation_lstm_dim]

    aggregation_representation = tf.concat(
        axis=1,
        values=aggregation_representation)  # [batch_size, aggregation_dim]

    # ======Highway layer======
    if options.with_aggregation_highway:
        with tf.variable_scope("aggregation_highway"):
            agg_shape = tf.shape(aggregation_representation)
            batch_size = agg_shape[0]
            aggregation_representation = tf.reshape(
                aggregation_representation, [1, batch_size, aggregation_dim])
            aggregation_representation = multi_highway_layer(
                aggregation_representation, aggregation_dim,
                options.highway_layer_num)
            aggregation_representation = tf.reshape(
                aggregation_representation, [batch_size, aggregation_dim])

    return (aggregation_representation, aggregation_dim)
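
A hypothetical call site for MCAN_match_func, wiring its output into a prediction head in the style of Example No. 1 (all names here, including `num_classes`, are illustrative and not from the source):

# Illustrative wiring; tensor names mirror Example No. 1.
(match_rep, match_dim) = MCAN_match_func(
    in_question_repres, in_passage_repres,
    question_lengths, passage_lengths,
    question_mask, passage_mask,
    input_dim, is_training, options=options)
logits = tf.layers.dense(match_rep, num_classes)  # Example No. 1 uses a two-layer head; one layer here for brevity
prob = tf.nn.softmax(logits)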
Example No. 5
def MCAN_match_func(in_question_repres,
                    in_passage_repres,
                    question_lengths,
                    passage_lengths,
                    question_mask,
                    passage_mask,
                    input_dim,
                    is_training,
                    scope="default",
                    options=None):
    question_reps = in_question_repres
    passage_reps = in_passage_repres

    total_match_dim = 0
    final_question_repres = question_reps
    final_passage_repres = passage_reps
    #####

    (match_reps, match_dim) = match_passage_with_question(in_passage_repres, in_question_repres, passage_mask,
                                                          question_mask, passage_lengths,
                                                          question_lengths, input_dim, scope="word_match_forward",
                                                          with_full_match=False,
                                                          with_maxpool_match=options.with_maxpool_match,
                                                          with_attentive_match=options.with_attentive_match,
                                                          with_max_attentive_match=options.with_max_attentive_match,
                                                          is_training=is_training, options=options,
                                                          dropout_rate=options.dropout_rate, forward=True)

    final_passage_repres = tf.concat([final_passage_repres, match_reps],
                                      axis=-1)
    total_match_dim += match_dim
    (match_reps, match_dim) = match_passage_with_question(in_question_repres, in_passage_repres, question_mask,
                                                          passage_mask, question_lengths,
                                                          passage_lengths, input_dim, scope="word_match_backward",
                                                          with_full_match=False,
                                                          with_maxpool_match=options.with_maxpool_match,
                                                          with_attentive_match=options.with_attentive_match,
                                                          with_max_attentive_match=options.with_max_attentive_match,
                                                          is_training=is_training, options=options,
                                                          dropout_rate=options.dropout_rate, forward=False)
    final_question_repres = tf.concat([final_question_repres, match_reps],
                                       axis=-1)

    #####

    # self-attention

    # relevancy_matrix3 = cal_relevancy_matrix(question_reps, question_reps)
    # relevancy_matrix3 = mask_relevancy_matrix(relevancy_matrix3, question_mask, question_mask)
    # relevancy_matrix3 = tf.nn.softmax(relevancy_matrix3,axis=-1)
    # relevancy_matrix3 = mask_relevancy_matrix(relevancy_matrix3, question_mask, question_mask)
    # attended_question = tf.matmul(relevancy_matrix3,question_reps)
    # final_question_repres=tf.concat([final_question_repres,tf.layers.dense(attended_question, units=5)],axis=-1)
    #
    # relevancy_matrix4 = cal_relevancy_matrix(passage_reps, passage_reps)
    # relevancy_matrix4 = mask_relevancy_matrix(relevancy_matrix4, passage_mask, passage_mask)
    # relevancy_matrix4 = tf.nn.softmax(relevancy_matrix4, axis=-1)
    # relevancy_matrix4 = mask_relevancy_matrix(relevancy_matrix4, passage_mask, passage_mask)
    # attended_passage = tf.matmul(relevancy_matrix4, passage_reps)
    # final_passage_repres = tf.concat([final_passage_repres, tf.layers.dense(attended_passage, units=5)],
    #                                  axis=-1)

    # LSTM-matching
    in_question_repres_masked = tf.multiply(in_question_repres, tf.expand_dims(question_mask, axis=-1))
    in_passage_repres_masked = tf.multiply(in_passage_repres, tf.expand_dims(passage_mask, axis=-1))
    (question_context_representation_fw, question_context_representation_bw,
     in_question_repres_masked) = layer_utils.my_lstm_layer(
        in_question_repres_masked, options.context_lstm_dim, input_lengths=question_lengths,
        scope_name="context_represent",
        reuse=False, is_training=is_training, dropout_rate=options.dropout_rate,
        use_cudnn=options.use_cudnn)
    (passage_context_representation_fw, passage_context_representation_bw,
     in_passage_repres_masked) = layer_utils.my_lstm_layer(
        in_passage_repres_masked, options.context_lstm_dim, input_lengths=passage_lengths,
        scope_name="context_represent",
        reuse=True, is_training=is_training, dropout_rate=options.dropout_rate, use_cudnn=options.use_cudnn)

    # Multi-perspective matching
    with tf.variable_scope('left_MP_matching'):
        (match_reps, match_dim) = match_passage_with_question(passage_context_representation_fw,
                                                              question_context_representation_fw,
                                                              passage_mask, question_mask, passage_lengths,
                                                              question_lengths, options.context_lstm_dim,
                                                              scope="forward_match",
                                                              with_full_match=options.with_full_match,
                                                              with_maxpool_match=options.with_maxpool_match,
                                                              with_attentive_match=options.with_attentive_match,
                                                              with_max_attentive_match=options.with_max_attentive_match,
                                                              is_training=is_training, options=options,
                                                              dropout_rate=options.dropout_rate,
                                                              forward=True)
        final_passage_repres = tf.concat([final_passage_repres, match_reps],
                                         axis=-1)
        total_match_dim += match_dim
        (match_reps, match_dim) = match_passage_with_question(passage_context_representation_bw,
                                                              question_context_representation_bw,
                                                              passage_mask, question_mask, passage_lengths,
                                                              question_lengths, options.context_lstm_dim,
                                                              scope="backward_match",
                                                              with_full_match=options.with_full_match,
                                                              with_maxpool_match=options.with_maxpool_match,
                                                              with_attentive_match=options.with_attentive_match,
                                                              with_max_attentive_match=options.with_max_attentive_match,
                                                              is_training=is_training, options=options,
                                                              dropout_rate=options.dropout_rate,
                                                              forward=False)
        final_passage_repres = tf.concat([final_passage_repres, match_reps],
                                         axis=-1)
        total_match_dim += match_dim
    with tf.variable_scope('right_MP_matching'):
        (match_reps, match_dim) = match_passage_with_question(question_context_representation_fw,
                                                              passage_context_representation_fw,
                                                              question_mask, passage_mask, question_lengths,
                                                              passage_lengths, options.context_lstm_dim,
                                                              scope="forward_match",
                                                              with_full_match=options.with_full_match,
                                                              with_maxpool_match=options.with_maxpool_match,
                                                              with_attentive_match=options.with_attentive_match,
                                                              with_max_attentive_match=options.with_max_attentive_match,
                                                              is_training=is_training, options=options,
                                                              dropout_rate=options.dropout_rate,
                                                              forward=True)
        final_question_repres = tf.concat([final_question_repres, match_reps],
                                          axis=-1)
        (match_reps, match_dim) = match_passage_with_question(question_context_representation_bw,
                                                              passage_context_representation_bw,
                                                              question_mask, passage_mask, question_lengths,
                                                              passage_lengths, options.context_lstm_dim,
                                                              scope="backward_match",
                                                              with_full_match=options.with_full_match,
                                                              with_maxpool_match=options.with_maxpool_match,
                                                              with_attentive_match=options.with_attentive_match,
                                                              with_max_attentive_match=options.with_max_attentive_match,
                                                              is_training=is_training, options=options,
                                                              dropout_rate=options.dropout_rate,
                                                              forward=False)
        final_question_repres = tf.concat([final_question_repres, match_reps],
                                          axis=-1)


    if is_training:
        final_question_repres = tf.nn.dropout(final_question_repres, (1 - options.dropout_rate))
        final_passage_repres = tf.nn.dropout(final_passage_repres, (1 - options.dropout_rate))
    print(total_match_dim)
    # ======Highway layer======
    #if options.with_match_highway:
    #    with tf.variable_scope("left_matching_highway"):
    #        final_question_repres = multi_highway_layer(final_question_repres, total_match_dim,
    #                                                            options.highway_layer_num)
    #    with tf.variable_scope("right_matching_highway"):
    #        final_passage_repres = multi_highway_layer(final_passage_repres, total_match_dim,
    #                                                           options.highway_layer_num)



    # final encoder

    qa_aggregation_input = final_passage_repres
    pa_aggregation_input = final_question_repres
    aggregation_representation = []
    aggregation_dim = 0
    with tf.variable_scope('aggregation_layer'):
        for i in range(options.aggregation_layer_num):  # support multiple aggregation layers
            if passage_mask is not None:
                qa_aggregation_input = tf.multiply(qa_aggregation_input, tf.expand_dims(passage_mask, axis=-1))
            (fw_rep, bw_rep, cur_aggregation_representation) = layer_utils.my_lstm_layer(
                qa_aggregation_input, options.aggregation_lstm_dim, input_lengths=passage_lengths,
                scope_name=scope + '_left_layer-{}'.format(i),
                reuse=False, is_training=is_training, dropout_rate=options.dropout_rate, use_cudnn=options.use_cudnn)
            fw_rep = layer_utils.collect_final_step_of_lstm(fw_rep, passage_lengths - 1)
            bw_rep = bw_rep[:, 0, :]
            aggregation_representation.append(fw_rep)
            aggregation_representation.append(bw_rep)
            aggregation_dim += 2 * options.aggregation_lstm_dim
            qa_aggregation_input = cur_aggregation_representation  # [batch_size, passage_len, 2*aggregation_lstm_dim]
            if question_mask is not None:
                pa_aggregation_input = tf.multiply(pa_aggregation_input, tf.expand_dims(question_mask, axis=-1))
            (fw_rep, bw_rep, cur_aggregation_representation) = layer_utils.my_lstm_layer(
                pa_aggregation_input, options.aggregation_lstm_dim,
                input_lengths=question_lengths, scope_name=scope + '_right_layer-{}'.format(i),
                reuse=False, is_training=is_training, dropout_rate=options.dropout_rate, use_cudnn=options.use_cudnn)
            fw_rep = layer_utils.collect_final_step_of_lstm(fw_rep, question_lengths - 1)
            bw_rep = bw_rep[:, 0, :]
            aggregation_representation.append(fw_rep)
            aggregation_representation.append(bw_rep)
            aggregation_dim += 2 * options.aggregation_lstm_dim
            pa_aggregation_input = cur_aggregation_representation  # [batch_size, question_len, 2*aggregation_lstm_dim]

    aggregation_representation = tf.concat(axis=1, values=aggregation_representation)  # [batch_size, aggregation_dim]

    # ======Highway layer======
    if options.with_aggregation_highway:
        with tf.variable_scope(scope + "_aggregation_highway"):
            agg_shape = tf.shape(aggregation_representation)
            batch_size = agg_shape[0]
            aggregation_representation = tf.reshape(aggregation_representation, [1, batch_size, aggregation_dim])
            aggregation_representation = multi_highway_layer(aggregation_representation, aggregation_dim,
                                                             options.highway_layer_num)
            aggregation_representation = tf.reshape(aggregation_representation, [batch_size, aggregation_dim])

    return (aggregation_representation, aggregation_dim)
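A note on the fw_rep/bw_rep gathering in the aggregation loop above: collect_final_step_of_lstm takes the last *valid* (unpadded) time step of the forward pass, while the backward pass ends at time step 0, so its final state is simply slice [:, 0, :]. A minimal NumPy sketch of that indexing, using made-up toy values:

import numpy as np

# Toy batch: 2 padded sequences, max_len 4, dim 3 (illustrative values).
fw = np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3)
bw = fw[:, ::-1, :]          # stand-in for backward-LSTM outputs
lengths = np.array([4, 2])   # true (unpadded) sequence lengths

# Forward pass: gather the last valid step, i.e. index lengths - 1.
fw_final = fw[np.arange(fw.shape[0]), lengths - 1]   # [batch, dim]
# Backward pass: its final state sits at time step 0.
bw_final = bw[:, 0, :]                               # [batch, dim]

rep = np.concatenate([fw_final, bw_final], axis=-1)  # [batch, 2 * dim]
print(rep.shape)  # (2, 6)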
    def create_siameseLSTM_model_graph(self,
                                       num_classes,
                                       word_vocab=None,
                                       char_vocab=None,
                                       is_training=True,
                                       global_step=None):
        """
        """
        options = self.options
        # ======word representation layer======
        in_question_repres = []
        in_passage_repres = []
        input_dim = 0
        if word_vocab is not None:
            word_vec_trainable = True
            cur_device = '/gpu:0'
            if options.fix_word_vec:
                word_vec_trainable = False
                cur_device = '/cpu:0'
            with tf.device(cur_device):
                self.embedding = tf.placeholder(
                    tf.float32, shape=word_vocab.word_vecs.shape)
                self.word_embedding = tf.get_variable(
                    "word_embedding",
                    trainable=word_vec_trainable,
                    initializer=self.embedding,
                    dtype=tf.float32)  # initial value fed through self.embedding, instead of tf.constant(word_vocab.word_vecs)

            in_question_word_repres = tf.nn.embedding_lookup(
                self.word_embedding,
                self.in_question_words)  # [batch_size, question_len, word_dim]
            in_passage_word_repres = tf.nn.embedding_lookup(
                self.word_embedding,
                self.in_passage_words)  # [batch_size, passage_len, word_dim]
            in_question_repres.append(in_question_word_repres)
            in_passage_repres.append(in_passage_word_repres)

            input_shape = tf.shape(self.in_question_words)
            batch_size = input_shape[0]
            question_len = input_shape[1]
            input_shape = tf.shape(self.in_passage_words)
            passage_len = input_shape[1]
            input_dim += word_vocab.word_dim

        in_question_repres = tf.concat(
            axis=2,
            values=in_question_repres)  # [batch_size, question_len, dim]
        in_passage_repres = tf.concat(
            axis=2, values=in_passage_repres)  # [batch_size, passage_len, dim]

        if is_training:
            in_question_repres = tf.nn.dropout(in_question_repres,
                                               (1 - options.dropout_rate))
            in_passage_repres = tf.nn.dropout(in_passage_repres,
                                              (1 - options.dropout_rate))

        passage_mask = tf.sequence_mask(
            self.passage_lengths, passage_len,
            dtype=tf.float32)  # [batch_size, passage_len]
        question_mask = tf.sequence_mask(
            self.question_lengths, question_len,
            dtype=tf.float32)  # [batch_size, question_len]

        # ======Highway layer======
        if options.with_highway:
            with tf.variable_scope("input_highway"):
                in_question_repres = match_utils.multi_highway_layer(
                    in_question_repres, input_dim, options.highway_layer_num)
                tf.get_variable_scope().reuse_variables()
                in_passage_repres = match_utils.multi_highway_layer(
                    in_passage_repres, input_dim, options.highway_layer_num)

        # ======BiLSTM context layer======
        for i in range(options.context_layer_num):  # support multiple context layers
            with tf.variable_scope('bilstm-layer-{}'.format(i)):
                # contextual lstm for both passage and question
                in_question_repres = tf.multiply(
                    in_question_repres, tf.expand_dims(question_mask, axis=-1))
                (question_context_representation_fw,
                 question_context_representation_bw,
                 in_question_repres) = layer_utils.my_lstm_layer(
                     in_question_repres,
                     options.context_lstm_dim,
                     input_lengths=self.question_lengths,
                     scope_name="context_represent",
                     reuse=False,
                     is_training=is_training,
                     dropout_rate=options.dropout_rate,
                     use_cudnn=options.use_cudnn)

                # Encode the second sentence, using the same LSTM weights.
                tf.get_variable_scope().reuse_variables()
                in_passage_repres = tf.multiply(
                    in_passage_repres, tf.expand_dims(passage_mask, axis=-1))
                (passage_context_representation_fw,
                 passage_context_representation_bw,
                 in_passage_repres) = layer_utils.my_lstm_layer(
                     in_passage_repres,
                     options.context_lstm_dim,
                     input_lengths=self.passage_lengths,
                     scope_name="context_represent",
                     reuse=True,
                     is_training=is_training,
                     dropout_rate=options.dropout_rate,
                     use_cudnn=options.use_cudnn)

        if options.lstm_out_type == 'mean':
            question_context_representation_fw = layer_utils.collect_mean_step_of_lstm(
                question_context_representation_fw)
            question_context_representation_bw = layer_utils.collect_mean_step_of_lstm(
                question_context_representation_bw)
            passage_context_representation_fw = layer_utils.collect_mean_step_of_lstm(
                passage_context_representation_fw)
            passage_context_representation_bw = layer_utils.collect_mean_step_of_lstm(
                passage_context_representation_bw)
        elif options.lstm_out_type == 'end':
            question_context_representation_fw = layer_utils.collect_final_step_of_lstm(
                question_context_representation_fw, self.question_lengths - 1)
            question_context_representation_bw = question_context_representation_bw[:, 0, :]
            passage_context_representation_fw = layer_utils.collect_final_step_of_lstm(
                passage_context_representation_fw, self.passage_lengths - 1)
            passage_context_representation_bw = passage_context_representation_bw[:, 0, :]

        question_context_outputs = tf.concat(
            axis=1,
            values=[
                question_context_representation_fw,
                question_context_representation_bw
            ])
        passage_context_outputs = tf.concat(
            axis=1,
            values=[
                passage_context_representation_fw,
                passage_context_representation_bw
            ])

        (match_representation, match_dim) = match_utils.siameseLSTM_match_func(
            question_context_outputs, passage_context_outputs,
            options.context_lstm_dim)

        #========Prediction Layer=========
        w_0 = tf.get_variable("w_0", [match_dim, int(match_dim / 2)],
                              dtype=tf.float32)
        b_0 = tf.get_variable("b_0", [int(match_dim / 2)], dtype=tf.float32)
        w_1 = tf.get_variable("w_1", [int(match_dim / 2), num_classes],
                              dtype=tf.float32)
        b_1 = tf.get_variable("b_1", [num_classes], dtype=tf.float32)

        # if is_training: match_representation = tf.nn.dropout(match_representation, (1 - options.dropout_rate))
        logits = tf.matmul(match_representation, w_0) + b_0
        logits = tf.nn.relu(logits)
        if is_training:
            logits = tf.nn.dropout(logits, (1 - options.dropout_rate))
        logits = tf.matmul(logits, w_1) + b_1

        self.prob = tf.nn.softmax(logits)
        self.predictions = tf.argmax(self.prob, 1)

        gold_matrix = tf.one_hot(self.truth, num_classes, dtype=tf.float32)
        self.loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                    labels=gold_matrix))

        correct = tf.nn.in_top_k(logits, self.truth, 1)
        self.eval_correct = tf.reduce_sum(tf.cast(correct, tf.int32))

        if not is_training: return

        tvars = tf.trainable_variables()
        if self.options.lambda_l1 > 0.0:
            l1_loss = tf.add_n([
                tf.contrib.layers.l1_regularizer(self.options.lambda_l1)(v)
                for v in tvars if v.get_shape().ndims > 1
            ])
            self.loss = self.loss + l1_loss
        if self.options.lambda_l2 > 0.0:
            # l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in tvars if v.get_shape().ndims > 1])
            l2_loss = tf.add_n([
                tf.contrib.layers.l2_regularizer(self.options.lambda_l2)(v)
                for v in tvars if v.get_shape().ndims > 1
            ])
            self.loss = self.loss + l2_loss

        if self.options.optimize_type == 'adadelta':
            optimizer = tf.train.AdadeltaOptimizer(
                learning_rate=self.options.learning_rate)
        elif self.options.optimize_type == 'adam':
            optimizer = tf.train.AdamOptimizer(
                learning_rate=self.options.learning_rate)
        else:
            raise ValueError("Unsupported optimize_type: {}".format(
                self.options.optimize_type))

        grads = layer_utils.compute_gradients(self.loss, tvars)
        grads, _ = tf.clip_by_global_norm(grads, self.options.grad_clipper)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars),
                                                  global_step=global_step)
        # self.train_op = optimizer.apply_gradients(zip(grads, tvars))

        if self.options.with_moving_average:
            # Track the moving averages of all trainable variables.
            MOVING_AVERAGE_DECAY = 0.9999  # The decay to use for the moving average.
            variable_averages = tf.train.ExponentialMovingAverage(
                MOVING_AVERAGE_DECAY, global_step)
            variables_averages_op = variable_averages.apply(
                tf.trainable_variables())
            train_ops = [self.train_op, variables_averages_op]
            self.train_op = tf.group(*train_ops)
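The placeholder-backed initializer in the word-representation block above is worth a remark: unlike initializing from tf.constant(word_vocab.word_vecs), it keeps the large pretrained matrix out of the serialized graph, at the cost of having to feed the placeholder when variables are initialized. A self-contained TF 1.x sketch of the pattern, with a made-up random matrix standing in for word_vocab.word_vecs:

import numpy as np
import tensorflow as tf  # assumes TF 1.x, as in the examples above

# Pretend pretrained matrix (normally word_vocab.word_vecs).
pretrained = np.random.rand(100, 50).astype(np.float32)

embedding_ph = tf.placeholder(tf.float32, shape=pretrained.shape)
word_embedding = tf.get_variable("word_embedding_demo",
                                 initializer=embedding_ph,
                                 trainable=False)

with tf.Session() as sess:
    # The placeholder-backed initializer must be fed exactly once, here:
    sess.run(tf.global_variables_initializer(),
             feed_dict={embedding_ph: pretrained})
    print(sess.run(tf.shape(word_embedding)))  # [100  50]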
Example no. 7
0
def match_passage_with_question(passage_reps,
                                question_reps,
                                passage_mask,
                                question_mask,
                                passage_lengths,
                                question_lengths,
                                context_lstm_dim,
                                scope=None,
                                with_full_match=True,
                                with_maxpool_match=True,
                                with_attentive_match=True,
                                with_max_attentive_match=True,
                                is_training=True,
                                options=None,
                                dropout_rate=0,
                                forward=True):
    passage_mask = tf.cast(passage_mask, tf.float32)
    question_mask = tf.cast(question_mask, tf.float32)
    passage_reps = tf.multiply(passage_reps, tf.expand_dims(passage_mask, -1))
    question_reps = tf.multiply(question_reps,
                                tf.expand_dims(question_mask, -1))
    all_question_aware_representations = []
    dim = 0
    with tf.variable_scope(scope or "match_passage_with_question"):
        relevancy_matrix = cal_relevancy_matrix(question_reps, passage_reps)
        relevancy_matrix = mask_relevancy_matrix(relevancy_matrix,
                                                 question_mask, passage_mask)
        # relevancy_matrix = layer_utils.calcuate_attention(passage_reps, question_reps, context_lstm_dim, context_lstm_dim,
        #			 scope_name="fw_attention", att_type=options.att_type, att_dim=options.att_dim,
        #			 remove_diagnoal=False, mask1=passage_mask, mask2=question_mask, is_training=is_training, dropout_rate=dropout_rate)

        all_question_aware_representations.append(
            tf.reduce_max(relevancy_matrix, axis=2, keep_dims=True))
        all_question_aware_representations.append(
            tf.reduce_mean(relevancy_matrix, axis=2, keep_dims=True))
        dim += 2
        if with_full_match:
            print("-------------using full match-----------")
            if forward:
                question_full_rep = layer_utils.collect_final_step_of_lstm(
                    question_reps, question_lengths - 1)
            else:
                question_full_rep = question_reps[:, 0, :]

            passage_len = tf.shape(passage_reps)[1]
            question_full_rep = tf.expand_dims(question_full_rep, axis=1)
            question_full_rep = tf.tile(
                question_full_rep,
                [1, passage_len, 1])  # [batch_size, passage_len, feature_dim]

            (attentive_rep, match_dim) = multi_perspective_match(
                context_lstm_dim,
                passage_reps,
                question_full_rep,
                is_training=is_training,
                dropout_rate=dropout_rate,
                options=options,
                scope_name='mp-match-full-match')
            all_question_aware_representations.append(attentive_rep)
            dim += match_dim

        if with_maxpool_match:
            print("-------------using maxpool match-----------")
            maxpooling_decomp_params = tf.get_variable(
                "maxpooling_matching_decomp",
                shape=[options["cosine_MP_dim"], context_lstm_dim],
                dtype=tf.float32)
            maxpooling_rep = cal_maxpooling_matching(passage_reps,
                                                     question_reps,
                                                     maxpooling_decomp_params)
            all_question_aware_representations.append(maxpooling_rep)
            dim += 2 * options["cosine_MP_dim"]

        if with_attentive_match:
            print("-------------using attentive match-----------")
            atten_scores = layer_utils.calcuate_attention(
                passage_reps,
                question_reps,
                context_lstm_dim,
                context_lstm_dim,
                scope_name="attention",
                att_type=options["att_type"],
                att_dim=options["att_dim"],
                remove_diagnoal=False,
                mask1=passage_mask,
                mask2=question_mask,
                is_training=is_training,
                dropout_rate=dropout_rate)
            att_question_contexts = tf.matmul(atten_scores, question_reps)
            (attentive_rep, match_dim) = multi_perspective_match(
                context_lstm_dim,
                passage_reps,
                att_question_contexts,
                is_training=is_training,
                dropout_rate=dropout_rate,
                options=options,
                scope_name='mp-match-att_question')
            all_question_aware_representations.append(attentive_rep)
            dim += match_dim

        if with_max_attentive_match:
            print("-------------using max attentive match-----------")
            max_att = cal_max_question_representation(question_reps,
                                                      relevancy_matrix)
            (max_attentive_rep, match_dim) = multi_perspective_match(
                context_lstm_dim,
                passage_reps,
                max_att,
                is_training=is_training,
                dropout_rate=dropout_rate,
                options=options,
                scope_name='mp-match-max-att')
            all_question_aware_representations.append(max_attentive_rep)
            dim += match_dim

        all_question_aware_representations = tf.concat(
            axis=2, values=all_question_aware_representations)
    return (all_question_aware_representations, dim)
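cal_relevancy_matrix (used above to build the max/mean match features) computes, in the BiMPM formulation, the cosine similarity between every passage time step and every question time step. A NumPy sketch of that computation, under the assumption that the helper follows the paper rather than copying its exact source:

import numpy as np

def relevancy_matrix(question, passage, eps=1e-8):
    # question: [batch, q_len, dim]; passage: [batch, p_len, dim]
    q = question / (np.linalg.norm(question, axis=-1, keepdims=True) + eps)
    p = passage / (np.linalg.norm(passage, axis=-1, keepdims=True) + eps)
    # [batch, p_len, q_len]: cosine of every passage step vs. every question step
    return np.matmul(p, q.transpose(0, 2, 1))

q = np.random.rand(2, 5, 8).astype(np.float32)
p = np.random.rand(2, 7, 8).astype(np.float32)
print(relevancy_matrix(q, p).shape)  # (2, 7, 5)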
Example no. 8
0
def MCAN_match_func(in_question_repres,
                    in_passage_repres,
                    question_lengths,
                    passage_lengths,
                    question_mask,
                    passage_mask,
                    input_dim,
                    is_training,
                    scope="default",
                    options=None):
    question_reps = in_question_repres
    passage_reps = in_passage_repres

    relevancy_matrix = cal_relevancy_matrix(question_reps, passage_reps)
    relevancy_matrix = mask_relevancy_matrix(relevancy_matrix, question_mask,
                                             passage_mask)

    in_passage_repres = tf.concat([
        in_passage_repres,
        tf.reduce_max(relevancy_matrix, axis=2, keep_dims=True)
    ],
                                  axis=-1)
    in_passage_repres = tf.concat([
        in_passage_repres,
        tf.reduce_mean(relevancy_matrix, axis=2, keep_dims=True)
    ],
                                  axis=-1)

    qa_aggregation_input = in_passage_repres
    pa_aggregation_input = in_question_repres
    aggregation_representation = []
    aggregation_dim = 0
    with tf.variable_scope('aggregation_layer'):
        for i in range(options.aggregation_layer_num):  # support multiple aggregation layers
            if passage_mask is not None:
                qa_aggregation_input = tf.multiply(
                    qa_aggregation_input, tf.expand_dims(passage_mask,
                                                         axis=-1))
            (fw_rep, bw_rep,
             cur_aggregation_representation) = layer_utils.my_lstm_layer(
                 qa_aggregation_input,
                 options.aggregation_lstm_dim,
                 input_lengths=passage_lengths,
                 scope_name=scope + '_left_layer-{}'.format(i),
                 reuse=False,
                 is_training=is_training,
                 dropout_rate=options.dropout_rate,
                 use_cudnn=options.use_cudnn)
            fw_rep = layer_utils.collect_final_step_of_lstm(
                fw_rep, passage_lengths - 1)
            bw_rep = bw_rep[:, 0, :]
            aggregation_representation.append(fw_rep)
            aggregation_representation.append(bw_rep)
            aggregation_dim += 2 * options.aggregation_lstm_dim
            qa_aggregation_input = cur_aggregation_representation  # [batch_size, passage_len, 2*aggregation_lstm_dim]
            if question_mask is not None:
                pa_aggregation_input = tf.multiply(
                    pa_aggregation_input, tf.expand_dims(question_mask,
                                                         axis=-1))
            (fw_rep, bw_rep,
             cur_aggregation_representation) = layer_utils.my_lstm_layer(
                 pa_aggregation_input,
                 options.aggregation_lstm_dim,
                 input_lengths=question_lengths,
                 scope_name=scope + '_right_layer-{}'.format(i),
                 reuse=False,
                 is_training=is_training,
                 dropout_rate=options.dropout_rate,
                 use_cudnn=options.use_cudnn)
            fw_rep = layer_utils.collect_final_step_of_lstm(
                fw_rep, question_lengths - 1)
            bw_rep = bw_rep[:, 0, :]
            aggregation_representation.append(fw_rep)
            aggregation_representation.append(bw_rep)
            aggregation_dim += 2 * options.aggregation_lstm_dim
            pa_aggregation_input = cur_aggregation_representation  # [batch_size, question_len, 2*aggregation_lstm_dim]

    aggregation_representation = tf.concat(
        axis=1,
        values=aggregation_representation)  # [batch_size, aggregation_dim]

    # ======Highway layer======
    if options.with_aggregation_highway:
        with tf.variable_scope(scope + "_aggregation_highway"):
            agg_shape = tf.shape(aggregation_representation)
            batch_size = agg_shape[0]
            aggregation_representation = tf.reshape(
                aggregation_representation, [1, batch_size, aggregation_dim])
            aggregation_representation = multi_highway_layer(
                aggregation_representation, aggregation_dim,
                options.highway_layer_num)
            aggregation_representation = tf.reshape(
                aggregation_representation, [batch_size, aggregation_dim])

    return (aggregation_representation, aggregation_dim)
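The reshape dance around multi_highway_layer at the end (to [1, batch_size, aggregation_dim] and back) exists only so a sequence-shaped highway implementation can be reused on a flat [batch, dim] matrix. The gate itself is the standard Srivastava-style mix y = g * H(x) + (1 - g) * x; a minimal NumPy sketch of one such layer (not the repo's multi_highway_layer):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def highway_layer(x, W_t, b_t, W_h, b_h):
    """One highway layer over a flat [batch, dim] input."""
    gate = sigmoid(x @ W_t + b_t)                 # transform gate g
    candidate = np.maximum(x @ W_h + b_h, 0.0)    # H(x), ReLU here
    return gate * candidate + (1.0 - gate) * x    # gated residual mix

rng = np.random.default_rng(0)
dim = 4
x = rng.random((3, dim))
out = highway_layer(x, rng.normal(size=(dim, dim)), np.zeros(dim),
                    rng.normal(size=(dim, dim)), np.zeros(dim))
print(out.shape)  # (3, 4)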
Example no. 9
0
    def _build_graph(self):
        node_1_mask = self.batch_mask_first
        node_2_mask = self.batch_mask_second
        node_1_looking_table = self.looking_table_first
        node_2_looking_table = self.looking_table_second

        node_2_aware_representations = []
        node_2_aware_dim = 0
        node_1_aware_representations = []
        node_1_aware_dim = 0

        pad_word_embedding = tf.zeros(
            [1, self.word_embedding_dim])  # this is for the PAD symbol
        self.word_embeddings = tf.concat([
            pad_word_embedding,
            tf.get_variable(
                'pretrained_embedding',
                shape=[self.pretrained_word_size, self.word_embedding_dim],
                initializer=tf.constant_initializer(
                    self.pretrained_word_embeddings),
                trainable=True),
            tf.get_variable(
                'W_train',
                shape=[self.learned_word_size, self.word_embedding_dim],
                initializer=tf.contrib.layers.xavier_initializer(),
                trainable=True)
        ], 0)

        self.watch['word_embeddings'] = self.word_embeddings

        # ============ encode node feature by looking up word embedding =============
        with tf.variable_scope('node_rep_gen'):
            # [node_size, hidden_layer_dim]
            feature_embedded_chars_first = tf.nn.embedding_lookup(
                self.word_embeddings, self.feature_info_first)
            graph_1_size = tf.shape(feature_embedded_chars_first)[0]

            feature_embedded_chars_second = tf.nn.embedding_lookup(
                self.word_embeddings, self.feature_info_second)
            graph_2_size = tf.shape(feature_embedded_chars_second)[0]

            if self.node_vec_method == "lstm":
                cell = self.build_encoder_cell(1, self.hidden_layer_dim)

                outputs, hidden_states = tf.nn.dynamic_rnn(
                    cell=cell,
                    inputs=feature_embedded_chars_first,
                    sequence_length=self.feature_len_first,
                    dtype=tf.float32)
                node_1_rep = layer_utils.collect_final_step_of_lstm(
                    outputs, self.feature_len_first - 1)

                outputs, hidden_states = tf.nn.dynamic_rnn(
                    cell=cell,
                    inputs=feature_embedded_chars_second,
                    sequence_length=self.feature_len_second,
                    dtype=tf.float32)
                node_2_rep = layer_utils.collect_final_step_of_lstm(
                    outputs, self.feature_len_second - 1)

            elif self.node_vec_method == "word_emb":
                node_1_rep = tf.reshape(feature_embedded_chars_first,
                                        [graph_1_size, -1])
                node_2_rep = tf.reshape(feature_embedded_chars_second,
                                        [graph_2_size, -1])

            self.watch["node_1_rep_initial"] = node_1_rep

        # ============ encode node feature by GCN =============
        with tf.variable_scope('first_gcn') as first_gcn_scope:
            # shape of node embedding: [batch_size, single_graph_nodes_size, node_embedding_dim]
            # shape of node size: [batch_size]
            gcn_1_res = self.gcn_encode(
                self.batch_nodes_first,
                node_1_rep,
                self.fw_adj_info_first,
                self.bw_adj_info_first,
                input_node_dim=self.word_embedding_dim,
                output_node_dim=self.aggregator_dim_first,
                fw_aggregators=self.fw_aggregators_first,
                bw_aggregators=self.bw_aggregators_first,
                window_size=self.gcn_window_size_first,
                layer_size=self.gcn_layer_size_first,
                scope="first_gcn",
                agg_type=self.agg_type_first,
                sample_size_per_layer=self.sample_size_per_layer_first,
                keep_inter_state=self.if_use_multiple_gcn_1_state)

            node_1_rep = gcn_1_res[0]
            node_1_rep_dim = gcn_1_res[3]

            gcn_2_res = self.gcn_encode(
                self.batch_nodes_second,
                node_2_rep,
                self.fw_adj_info_second,
                self.bw_adj_info_second,
                input_node_dim=self.word_embedding_dim,
                output_node_dim=self.aggregator_dim_first,
                fw_aggregators=self.fw_aggregators_first,
                bw_aggregators=self.bw_aggregators_first,
                window_size=self.gcn_window_size_first,
                layer_size=self.gcn_layer_size_first,
                scope="first_gcn",
                agg_type=self.agg_type_first,
                sample_size_per_layer=self.sample_size_per_layer_second,
                keep_inter_state=self.if_use_multiple_gcn_1_state)

            node_2_rep = gcn_2_res[0]
            node_2_rep_dim = gcn_2_res[3]

        self.watch["node_1_rep_first_GCN"] = node_1_rep
        self.watch["node_1_mask"] = node_1_mask

        # mask
        node_1_rep = tf.multiply(node_1_rep, tf.expand_dims(node_1_mask, 2))
        node_2_rep = tf.multiply(node_2_rep, tf.expand_dims(node_2_mask, 2))

        self.watch["node_1_rep_first_GCN_masked"] = node_1_rep

        if self.pred_method == "node_level":
            entity_1_rep = tf.reshape(
                tf.nn.embedding_lookup(tf.transpose(node_1_rep, [1, 0, 2]),
                                       tf.constant(0)), [-1, node_1_rep_dim])
            entity_2_rep = tf.reshape(
                tf.nn.embedding_lookup(tf.transpose(node_2_rep, [1, 0, 2]),
                                       tf.constant(0)), [-1, node_2_rep_dim])

            entity_1_2_diff = entity_1_rep - entity_2_rep
            entity_1_2_sim = entity_1_rep * entity_2_rep

            aggregation = tf.concat(
                [entity_1_rep, entity_2_rep, entity_1_2_diff, entity_1_2_sim],
                axis=1)
            aggregation_dim = 4 * node_1_rep_dim

            w_0 = tf.get_variable("w_0",
                                  [aggregation_dim, aggregation_dim / 2],
                                  dtype=tf.float32)
            b_0 = tf.get_variable("b_0", [aggregation_dim / 2],
                                  dtype=tf.float32)
            w_1 = tf.get_variable("w_1", [aggregation_dim / 2, 2],
                                  dtype=tf.float32)
            b_1 = tf.get_variable("b_1", [2], dtype=tf.float32)

            # ====== Prediction Layer ===============
            logits = tf.matmul(aggregation, w_0) + b_0
            logits = tf.tanh(logits)
            logits = tf.matmul(logits, w_1) + b_1

        elif self.pred_method == "graph_level":
            # if the prediction method is graph_level, we perform the graph matching based prediction

            assert node_1_rep_dim == node_2_rep_dim
            input_dim = node_1_rep_dim

            with tf.variable_scope('node_level_matching') as matching_scope:
                # ========= node level matching ===============
                (match_reps,
                 match_dim) = match_graph_1_with_graph_2(node_1_rep,
                                                         node_2_rep,
                                                         node_1_mask,
                                                         node_2_mask,
                                                         input_dim,
                                                         options=self.options,  # assumed: config dict stored on the instance
                                                         watch=self.watch)

                matching_scope.reuse_variables()

                node_2_aware_representations.append(match_reps)
                node_2_aware_dim += match_dim

                (match_reps,
                 match_dim) = match_graph_1_with_graph_2(node_2_rep,
                                                         node_1_rep,
                                                         node_2_mask,
                                                         node_1_mask,
                                                         input_dim,
                                                         options=self.options,
                                                         watch=self.watch)

                node_1_aware_representations.append(match_reps)
                node_1_aware_dim += match_dim

            # TODO: add one more MP matching over the graph representation
            # with tf.variable_scope('context_MP_matching'):
            #     for i in range(options['context_layer_num']):
            #         with tf.variable_scope('layer-{}',format(i)):

            # [batch_size, single_graph_nodes_size, node_2_aware_dim]
            node_2_aware_representations = tf.concat(
                axis=2, values=node_2_aware_representations)

            # [batch_size, single_graph_nodes_size, node_1_aware_dim]
            node_1_aware_representations = tf.concat(
                axis=2, values=node_1_aware_representations)

            # if self.mode == "train":
            #     node_2_aware_representations = tf.nn.dropout(node_2_aware_representations, (1 - options['dropout_rate']))
            #     node_1_aware_representations = tf.nn.dropout(node_1_aware_representations, (1 - options['dropout_rate']))

            # ========= Highway layer ==============
            if self.with_match_highway:
                with tf.variable_scope("left_matching_highway"):
                    node_2_aware_representations = multi_highway_layer(
                        node_2_aware_representations, node_2_aware_dim,
                        self.options['highway_layer_num'])
                with tf.variable_scope("right_matching_highway"):
                    node_1_aware_representations = multi_highway_layer(
                        node_1_aware_representations, node_1_aware_dim,
                        self.options['highway_layer_num'])

            self.watch["node_1_rep_match"] = node_2_aware_representations

            # ========= Aggregation Layer ==============
            aggregation_representation = []
            aggregation_dim = 0

            node_2_aware_aggregation_input = node_2_aware_representations
            node_1_aware_aggregation_input = node_1_aware_representations

            self.watch[
                "node_1_rep_match_layer"] = node_2_aware_aggregation_input

            with tf.variable_scope('aggregation_layer'):
                # TODO: now we only have 1 aggregation layer; need to change this part if support more aggregation layers
                # [batch_size, single_graph_nodes_size, node_2_aware_dim]
                node_2_aware_aggregation_input = tf.multiply(
                    node_2_aware_aggregation_input,
                    tf.expand_dims(node_1_mask, axis=-1))

                # [batch_size, single_graph_nodes_size, node_1_aware_dim]
                node_1_aware_aggregation_input = tf.multiply(
                    node_1_aware_aggregation_input,
                    tf.expand_dims(node_2_mask, axis=-1))

                if self.agg_sim_method == "GCN":
                    # [batch_size*single_graph_nodes_size, node_2_aware_dim]
                    node_2_aware_aggregation_input = tf.reshape(
                        node_2_aware_aggregation_input,
                        shape=[-1, node_2_aware_dim])

                    # [batch_size*single_graph_nodes_size, node_1_aware_dim]
                    node_1_aware_aggregation_input = tf.reshape(
                        node_1_aware_aggregation_input,
                        shape=[-1, node_1_aware_dim])

                    # [node_1_size, node_2_aware_dim]
                    node_1_rep = tf.concat([
                        tf.nn.embedding_lookup(node_2_aware_aggregation_input,
                                               node_1_looking_table),
                        tf.zeros([1, node_2_aware_dim])
                    ], 0)

                    # [node_2_size, node_1_aware_dim]
                    node_2_rep = tf.concat([
                        tf.nn.embedding_lookup(node_1_aware_aggregation_input,
                                               node_2_looking_table),
                        tf.zeros([1, node_1_aware_dim])
                    ], 0)

                    gcn_1_res = self.gcn_encode(
                        self.batch_nodes_first,
                        node_1_rep,
                        self.fw_adj_info_first,
                        self.bw_adj_info_first,
                        input_node_dim=node_2_aware_dim,
                        output_node_dim=self.aggregator_dim_second,
                        fw_aggregators=self.fw_aggregators_second,
                        bw_aggregators=self.bw_aggregators_second,
                        window_size=self.gcn_window_size_second,
                        layer_size=self.gcn_layer_size_second,
                        scope="second_gcn",
                        agg_type=self.agg_type_second,
                        sample_size_per_layer=self.sample_size_per_layer_first,
                        keep_inter_state=self.if_use_multiple_gcn_2_state)

                    max_graph_1_rep = gcn_1_res[1]
                    mean_graph_1_rep = gcn_1_res[2]
                    graph_1_rep_dim = gcn_1_res[3]

                    gcn_2_res = self.gcn_encode(
                        self.batch_nodes_second,
                        node_2_rep,
                        self.fw_adj_info_second,
                        self.bw_adj_info_second,
                        input_node_dim=node_1_aware_dim,
                        output_node_dim=self.aggregator_dim_second,
                        fw_aggregators=self.fw_aggregators_second,
                        bw_aggregators=self.bw_aggregators_second,
                        window_size=self.gcn_window_size_second,
                        layer_size=self.gcn_layer_size_second,
                        scope="second_gcn",
                        agg_type=self.agg_type_second,
                        sample_size_per_layer=self.sample_size_per_layer_second,
                        keep_inter_state=self.if_use_multiple_gcn_2_state)

                    max_graph_2_rep = gcn_2_res[1]
                    mean_graph_2_rep = gcn_2_res[2]
                    graph_2_rep_dim = gcn_2_res[3]

                    assert graph_1_rep_dim == graph_2_rep_dim

                    if self.if_use_multiple_gcn_2_state:
                        graph_1_reps = gcn_1_res[5]
                        graph_2_reps = gcn_2_res[5]
                        inter_dims = gcn_1_res[6]
                        for idx in range(len(graph_1_reps)):
                            (max_graph_1_rep_tmp,
                             mean_graph_1_rep_tmp) = graph_1_reps[idx]
                            (max_graph_2_rep_tmp,
                             mean_graph_2_rep_tmp) = graph_2_reps[idx]
                            inter_dim = inter_dims[idx]
                            aggregation_representation.append(
                                max_graph_1_rep_tmp)
                            aggregation_representation.append(
                                mean_graph_1_rep_tmp)
                            aggregation_representation.append(
                                max_graph_2_rep_tmp)
                            aggregation_representation.append(
                                mean_graph_2_rep_tmp)
                            aggregation_dim += 4 * inter_dim

                    else:
                        aggregation_representation.append(max_graph_1_rep)
                        aggregation_representation.append(mean_graph_1_rep)
                        aggregation_representation.append(max_graph_2_rep)
                        aggregation_representation.append(mean_graph_2_rep)
                        aggregation_dim = 4 * graph_1_rep_dim

                    # aggregation_representation = tf.concat(aggregation_representation, axis=1)

                    gcn_2_window_size = len(aggregation_representation) // 4
                    aggregation_dim = aggregation_dim // gcn_2_window_size

                    w_0 = tf.get_variable(
                        "w_0", [aggregation_dim, aggregation_dim // 2],
                        dtype=tf.float32)
                    b_0 = tf.get_variable("b_0", [aggregation_dim // 2],
                                          dtype=tf.float32)
                    w_1 = tf.get_variable("w_1", [aggregation_dim // 2, 2],
                                          dtype=tf.float32)
                                          dtype=tf.float32)
                    b_1 = tf.get_variable("b_1", [2], dtype=tf.float32)

                    weights = tf.get_variable("gcn_2_window_weights",
                                              [gcn_2_window_size],
                                              dtype=tf.float32)

                    # shape: [gcn_2_window_size, batch_size, 2]
                    logits = []
                    for layer_idx in range(gcn_2_window_size):
                        max_graph_1_rep = aggregation_representation[
                            layer_idx * 4 + 0]
                        mean_graph_1_rep = aggregation_representation[
                            layer_idx * 4 + 1]
                        max_graph_2_rep = aggregation_representation[
                            layer_idx * 4 + 2]
                        mean_graph_2_rep = aggregation_representation[
                            layer_idx * 4 + 3]

                        aggregation_representation_single = tf.concat([
                            max_graph_1_rep, mean_graph_1_rep, max_graph_2_rep,
                            mean_graph_2_rep
                        ],
                                                                      axis=1)

                        # ====== Prediction Layer ===============
                        logit = tf.matmul(aggregation_representation_single,
                                          w_0) + b_0
                        logit = tf.tanh(logit)
                        logit = tf.matmul(logit, w_1) + b_1
                        logits.append(logit)

                    if len(logits) != 1:
                        logits = tf.reshape(tf.concat(logits, axis=0),
                                            [gcn_2_window_size, -1, 2])
                        logits = tf.transpose(logits, [1, 0, 2])
                        logits = tf.multiply(logits,
                                             tf.expand_dims(weights, axis=-1))
                        logits = tf.reduce_sum(logits, axis=1)
                    else:
                        logits = tf.reshape(logits, [-1, 2])

        # ====== Highway layer ============
        # if options['with_aggregation_highway']:

        with tf.name_scope("loss"):
            self.y_pred = tf.nn.softmax(logits)
            self.loss = tf.reduce_sum(
                tf.nn.softmax_cross_entropy_with_logits(
                    labels=self.y_true,
                    logits=logits, name="xentropy_loss")) / tf.cast(
                        self.batch_size, tf.float32)

        # ============  Training Objective ===========================
        if self.mode == "train" and not self.if_pred_on_dev:
            optimizer = tf.train.AdamOptimizer()
            params = tf.trainable_variables()
            gradients = tf.gradients(self.loss, params)
            clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1)
            self.training_op = optimizer.apply_gradients(
                zip(clipped_gradients, params))
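In the node_level branch above, the two entity vectors are combined into the usual matching feature [e1; e2; e1 - e2; e1 * e2] before the small prediction head, which is why aggregation_dim is 4 * node_1_rep_dim. The NumPy equivalent of that concatenation, with toy shapes:

import numpy as np

e1 = np.random.rand(4, 16).astype(np.float32)  # entity-1 reps, [batch, dim]
e2 = np.random.rand(4, 16).astype(np.float32)  # entity-2 reps, [batch, dim]

features = np.concatenate([e1, e2, e1 - e2, e1 * e2], axis=1)
print(features.shape)  # (4, 64), i.e. [batch, 4 * dim]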
Example no. 10
0
def multi_granularity_match(feature_dim,
                            passage,
                            question,
                            passage_length,
                            question_length,
                            passage_mask=None,
                            question_mask=None,
                            is_training=True,
                            dropout_rate=0.2,
                            options=None,
                            with_full_matching=False,
                            with_attentive_matching=True,
                            with_max_attentive_matching=True,
                            scope_name='mgm',
                            reuse=False):
    '''
        passage: [batch_size, passage_length, feature_dim]
        question: [batch_size, question_length, feature_dim]
        passage_length: [batch_size]
        question_length: [batch_size]
    '''
    input_shape = tf.shape(passage)
    batch_size = input_shape[0]
    passage_len = input_shape[1]

    match_reps = []
    with tf.variable_scope(scope_name, reuse=reuse):
        match_dim = 0
        if with_full_matching:
            passage_fw = passage[:, :, 0:feature_dim // 2]
            passage_bw = passage[:, :, feature_dim // 2:feature_dim]

            question_fw = question[:, :, 0:feature_dim // 2]
            question_bw = question[:, :, feature_dim // 2:feature_dim]
            question_fw = layer_utils.collect_final_step_of_lstm(
                question_fw,
                question_length - 1)  # [batch_size, feature_dim/2]
            question_bw = question_bw[:, 0, :]

            question_fw = tf.expand_dims(question_fw, axis=1)
            question_fw = tf.tile(
                question_fw,
                [1, passage_len, 1])  # [batch_size, passage_len, feature_dim]

            question_bw = tf.expand_dims(question_bw, axis=1)
            question_bw = tf.tile(
                question_bw,
                [1, passage_len, 1])  # [batch_size, passage_len, feature_dim]
            (fw_full_match_reps, fw_full_match_dim) = multi_perspective_match(
                feature_dim // 2,
                passage_fw,
                question_fw,
                is_training=is_training,
                dropout_rate=dropout_rate,
                options=options,
                scope_name='fw_full_match')
            (bw_full_match_reps, bw_full_match_dim) = multi_perspective_match(
                feature_dim // 2,
                passage_bw,
                question_bw,
                is_training=is_training,
                dropout_rate=dropout_rate,
                options=options,
                scope_name='bw_full_match')
            match_reps.append(fw_full_match_reps)
            match_reps.append(bw_full_match_reps)
            match_dim += fw_full_match_dim
            match_dim += bw_full_match_dim

        if with_attentive_matching or with_max_attentive_matching:
            atten_scores = layer_utils.calcuate_attention(
                passage,
                question,
                feature_dim,
                feature_dim,
                scope_name="attention",
                att_type=options.attn_type,
                att_dim=options.attn_depth,
                remove_diagnoal=False,
                mask1=passage_mask,
                mask2=question_mask,
                is_training=is_training,
                dropout_rate=dropout_rate)
            # match_reps.append(tf.reduce_max(atten_scores, axis=2, keep_dims=True))
            # match_reps.append(tf.reduce_mean(atten_scores, axis=2, keep_dims=True))
            # match_dim += 2

        if with_max_attentive_matching:
            atten_positions = tf.argmax(
                atten_scores, axis=2,
                output_type=tf.int32)  # [batch_size, passage_len]
            max_question_reps = layer_utils.collect_representation(
                question, atten_positions)
            (max_att_match_rep, max_att_match_dim) = multi_perspective_match(
                feature_dim,
                passage,
                max_question_reps,
                is_training=is_training,
                dropout_rate=dropout_rate,
                options=options,
                scope_name='max_att_match')
            match_reps.append(max_att_match_rep)
            match_dim += max_att_match_dim

        if with_attentive_matching:
            att_rep = tf.matmul(atten_scores, question)
            (attentive_match_rep,
             attentive_match_dim) = multi_perspective_match(
                 feature_dim,
                 passage,
                 att_rep,
                 is_training=is_training,
                 dropout_rate=dropout_rate,
                 options=options,
                 scope_name='att_match')
            match_reps.append(attentive_match_rep)
            match_dim += attentive_match_dim
    match_reps = tf.concat(axis=2, values=match_reps)
    return (match_reps, match_dim)
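For the max-attentive branch above, collect_representation presumably gathers, for each passage position, the question vector at the argmax attention position; a NumPy sketch of that gather, under this assumption about the helper:

import numpy as np

atten_scores = np.random.rand(2, 7, 5)             # [batch, passage_len, question_len]
question = np.random.rand(2, 5, 8)                 # [batch, question_len, dim]

positions = atten_scores.argmax(axis=2)            # [batch, passage_len]
batch_idx = np.arange(question.shape[0])[:, None]  # broadcast over passage_len
max_question_reps = question[batch_idx, positions] # [batch, passage_len, dim]
print(max_question_reps.shape)  # (2, 7, 8)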