Beispiel #1
0
    def _build_graph(self):
        """Build the TensorFlow graph for the reader model.

        Creates the input placeholders, word/feature embeddings, the stacked
        context and question encoders, the multi-level attention fusion
        layers, the start/end span prediction heads, the training loss, the
        feed/fetch dictionaries and the streaming loss metrics.

        Attributes assigned on ``self``: the input placeholders,
        ``start_prob``, ``end_prob``, ``start_loss``, ``end_loss``, ``loss``,
        ``input_placeholder_dict``, ``output_variable_dict``, the train/eval
        metric dicts with their update and init ops, and ``summary_op``.

        Optional inputs are selected by ``self.features`` (``'pos'``,
        ``'ner'``, ``'context_tf'``, ``'match_lower'``, ``'match_lemma'``),
        ``self.use_cove_emb`` together with ``self.cove_path``, and
        ``self.use_outer_embedding``.
        """
        # Resolve the optional-feature switches once so every section of the
        # graph agrees on which inputs are active.
        features = self.features
        use_pos = 'pos' in features
        use_ner = 'ner' in features
        use_context_tf = 'context_tf' in features
        use_match_lower = 'match_lower' in features
        use_match_lemma = 'match_lemma' in features
        # CoVe tensors only exist when a CoVe path is configured; the same
        # flag must guard every later use of them, otherwise a missing path
        # leads to an unbound-variable error further down.
        use_cove = bool(self.use_cove_emb and self.cove_path is not None)

        # build input
        self.context_word = tf.placeholder(tf.int32, [None, None])
        self.context_len = tf.placeholder(tf.int32, [None])
        self.question_word = tf.placeholder(tf.int32, [None, None])
        self.question_len = tf.placeholder(tf.int32, [None])
        self.answer_start = tf.placeholder(tf.int32, [None])
        self.answer_end = tf.placeholder(tf.int32, [None])
        self.training = tf.placeholder(tf.bool, [])
        self.match_lemma = tf.placeholder(tf.int32, [None, None])
        self.match_lower = tf.placeholder(tf.int32, [None, None])
        self.pos_feature = tf.placeholder(tf.int32, [None, None])
        self.ner_feature = tf.placeholder(tf.int32, [None, None])
        # NOTE(review): declared int32 but named "normalized" and cast to
        # float32 below -- confirm the feeder really supplies integers.
        self.normalized_tf = tf.placeholder(tf.int32, [None, None])

        # 1. Word encoding
        word_embedding = PartiallyTrainableEmbedding(
            trainable_num=self.finetune_word_size,
            pretrained_embedding=self.pretrained_word_embedding,
            embedding_shape=(len(self.vocab.get_word_vocab()) + 1,
                             self.word_embedding_size))
        # 1.1 Embedding
        context_word_repr = word_embedding(self.context_word)
        question_word_repr = word_embedding(self.question_word)

        # pos / ner embeddings (context side only)
        if use_pos:
            pos_embedding = Embedding(
                pretrained_embedding=None,
                embedding_shape=(len(self.feature_vocab['pos']) + 1,
                                 self.pos_embedding_size))
            context_pos_feature = pos_embedding(self.pos_feature)
        if use_ner:
            ner_embedding = Embedding(
                pretrained_embedding=None,
                embedding_shape=(len(self.feature_vocab['ner']) + 1,
                                 self.ner_embedding_size))
            context_ner_feature = ner_embedding(self.ner_feature)

        # add dropout
        dropout = VariationalDropout(self.dropout_keep_prob)
        context_word_repr = dropout(context_word_repr, self.training)
        question_word_repr = dropout(question_word_repr, self.training)
        # Keep the plain word embeddings; they are reused for word-level
        # fusion and for the attention "history" vectors.
        glove_word_repr = context_word_repr
        glove_question_repr = question_word_repr

        if use_cove:
            cove_embedding = CoveEmbedding(
                cove_path=self.cove_path,
                pretrained_word_embedding=self.pretrained_word_embedding,
                vocab=self.vocab)
            cove_context_repr = cove_embedding(self.context_word,
                                               self.context_len)
            cove_question_repr = cove_embedding(self.question_word,
                                                self.question_len)
            cove_context_repr = dropout(cove_context_repr, self.training)
            cove_question_repr = dropout(cove_question_repr, self.training)

        # exact_match feature: 1.0 where a context token id equals any
        # question token id of the same example.
        context_expanded = tf.tile(
            tf.expand_dims(self.context_word, axis=-1),
            [1, 1, tf.shape(self.question_word)[1]])
        query_expanded = tf.tile(tf.expand_dims(self.question_word, axis=1),
                                 [1, tf.shape(self.context_word)[1], 1])
        exact_match_feature = tf.cast(
            tf.reduce_any(tf.equal(context_expanded, query_expanded), axis=-1),
            tf.float32)
        exact_match_feature = tf.expand_dims(exact_match_feature, axis=-1)

        if use_match_lower:
            match_lower_feature = tf.expand_dims(
                tf.cast(self.match_lower, tf.float32), axis=-1)
            exact_match_feature = tf.concat(
                [exact_match_feature, match_lower_feature], axis=-1)
        if use_match_lemma:
            match_lemma_feature = tf.expand_dims(
                tf.cast(self.match_lemma, tf.float32), axis=-1)
            exact_match_feature = tf.concat(
                [exact_match_feature, match_lemma_feature], axis=-1)

        # features appended to the context word representation
        if use_pos:
            context_word_repr = tf.concat(
                [context_word_repr, context_pos_feature], axis=-1)
        if use_ner:
            context_word_repr = tf.concat(
                [context_word_repr, context_ner_feature], axis=-1)
        if use_context_tf:
            tf_feature = tf.cast(tf.expand_dims(self.normalized_tf, axis=-1),
                                 tf.float32)
            context_word_repr = tf.concat([context_word_repr, tf_feature],
                                          axis=-1)

        if use_cove:
            context_word_repr = tf.concat(
                [cove_context_repr, context_word_repr], axis=-1)
            question_word_repr = tf.concat(
                [cove_question_repr, question_word_repr], axis=-1)

        # 1.2 word fusion
        sim_function = ProjectedDotProduct(self.rnn_hidden_size,
                                           activation=tf.nn.relu,
                                           reuse_weight=True)
        word_fusion = UniAttention(sim_function)
        context2question_fusion = word_fusion(glove_word_repr,
                                              glove_question_repr,
                                              self.question_len)
        enhanced_context_repr = tf.concat(
            [context_word_repr, exact_match_feature, context2question_fusion],
            axis=-1)
        enhanced_context_repr = dropout(enhanced_context_repr, self.training)

        # 1.3.1 context encoder
        context_encoder = [
            CudnnBiLSTM(self.rnn_hidden_size),
            CudnnBiLSTM(self.rnn_hidden_size)
        ]
        context_low_repr, _ = context_encoder[0](enhanced_context_repr,
                                                 self.context_len)
        context_low_repr = dropout(context_low_repr, self.training)
        context_high_repr, _ = context_encoder[1](context_low_repr,
                                                  self.context_len)
        context_high_repr = dropout(context_high_repr, self.training)

        # 1.3.2 question encoder
        question_encoder = [
            CudnnBiLSTM(self.rnn_hidden_size),
            CudnnBiLSTM(self.rnn_hidden_size)
        ]
        question_low_repr, _ = question_encoder[0](question_word_repr,
                                                   self.question_len)
        question_low_repr = dropout(question_low_repr, self.training)
        question_high_repr, _ = question_encoder[1](question_low_repr,
                                                    self.question_len)
        question_high_repr = dropout(question_high_repr, self.training)

        # 1.4 question understanding
        question_understanding_encoder = CudnnBiLSTM(self.rnn_hidden_size)
        question_understanding, _ = question_understanding_encoder(
            tf.concat([question_low_repr, question_high_repr], axis=-1),
            self.question_len)
        question_understanding = dropout(question_understanding, self.training)

        # history of context
        context_history = tf.concat(
            [glove_word_repr, context_low_repr, context_high_repr], axis=-1)

        # history of question
        question_history = tf.concat(
            [glove_question_repr, question_low_repr, question_high_repr],
            axis=-1)

        # concat cove emb
        if use_cove:
            context_history = tf.concat([cove_context_repr, context_history],
                                        axis=-1)
            question_history = tf.concat(
                [cove_question_repr, question_history], axis=-1)

        # 1.5.1 low level fusion
        low_level_attn = UniAttention(
            SymmetricProject(self.attention_hidden_size),
            name='low_level_fusion')
        low_level_fusion = low_level_attn(context_history, question_history,
                                          self.question_len, question_low_repr)
        low_level_fusion = dropout(low_level_fusion, self.training)

        # 1.5.2 high level fusion
        high_level_attn = UniAttention(
            SymmetricProject(self.attention_hidden_size, name='high_sim_func'),
            name='high_level_fusion')
        high_level_fusion = high_level_attn(context_history, question_history,
                                            self.question_len,
                                            question_high_repr)
        high_level_fusion = dropout(high_level_fusion, self.training)

        # 1.5.3 understanding level fusion
        understanding_attn = UniAttention(
            SymmetricProject(self.attention_hidden_size,
                             name='understanding_sim_func'),
            name='understanding_level_fusion')
        understanding_fusion = understanding_attn(context_history,
                                                  question_history,
                                                  self.question_len,
                                                  question_understanding)
        understanding_fusion = dropout(understanding_fusion, self.training)

        # merge context attention
        fully_aware_encoder = CudnnBiLSTM(self.rnn_hidden_size)
        full_fusion_context = tf.concat([
            context_low_repr, context_high_repr, low_level_fusion,
            high_level_fusion, understanding_fusion
        ], axis=-1)
        full_fusion_context_repr, _ = fully_aware_encoder(
            full_fusion_context, self.context_len)

        # history of context (rebuilt from the fused representations)
        context_history = tf.concat(
            [glove_word_repr, full_fusion_context, full_fusion_context_repr],
            axis=-1)
        if use_cove:
            context_history = tf.concat([cove_context_repr, context_history],
                                        axis=-1)

        # 1.6 self boosted fusion: the context attends over itself
        self_boosted_attn = UniAttention(
            SymmetricProject(self.attention_hidden_size,
                             name='self_boosted_attn'),
            name='boosted_fusion')
        boosted_fusion = self_boosted_attn(context_history, context_history,
                                           self.context_len,
                                           full_fusion_context_repr)
        boosted_fusion = dropout(boosted_fusion, self.training)

        # 1.7 context vectors
        context_final_encoder = CudnnBiLSTM(self.rnn_hidden_size)
        context_repr, _ = context_final_encoder(
            tf.concat([full_fusion_context_repr, boosted_fusion], axis=-1),
            self.context_len)
        context_repr = dropout(context_repr, self.training)
        self_attn = SelfAttn()
        U_Q = self_attn(question_understanding, self.question_len)

        # start project
        start_project = tf.keras.layers.Dense(self.rnn_hidden_size * 2,
                                              use_bias=False)
        # NOTE(review): context_repr already went through dropout just above;
        # this second application is kept as in the original -- confirm it is
        # intended.
        context_repr = dropout(context_repr, self.training)
        start_logits = tf.squeeze(tf.matmul(start_project(context_repr),
                                            tf.expand_dims(U_Q, axis=-1)),
                                  axis=-1)
        start_logits_masked = mask_logits(start_logits, self.context_len)
        self.start_prob = tf.nn.softmax(start_logits_masked)

        # Feed the start-probability-weighted sum of context vectors through
        # one GRU step, with U_Q as the state, to get the end-pointer query.
        gru_input = tf.reduce_sum(tf.expand_dims(self.start_prob, axis=-1) *
                                  context_repr,
                                  axis=1)
        GRUCell = tf.contrib.rnn.GRUCell(self.rnn_hidden_size * 2)

        V_Q, _ = GRUCell(gru_input, U_Q)

        # end project
        end_project = tf.keras.layers.Dense(self.rnn_hidden_size * 2,
                                            use_bias=False)
        end_logits = tf.squeeze(tf.matmul(end_project(context_repr),
                                          tf.expand_dims(V_Q, axis=-1)),
                                axis=-1)
        end_logits_masked = mask_logits(end_logits, self.context_len)
        self.end_prob = tf.nn.softmax(end_logits_masked)

        # 7. Loss and input/output dict
        # Reuse the masked logits computed for the probabilities instead of
        # masking the same tensors a second time.
        self.start_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=start_logits_masked, labels=self.answer_start))
        self.end_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=end_logits_masked, labels=self.answer_end))
        self.loss = self.start_loss + self.end_loss
        # Called only to make sure the global step exists in the graph; the
        # returned tensor is not needed here.
        tf.train.get_or_create_global_step()

        input_dict = {
            "context_word": self.context_word,
            "context_len": self.context_len,
            "question_word": self.question_word,
            "question_len": self.question_len,
            "answer_start": self.answer_start,
            "answer_end": self.answer_end,
            "training": self.training,
        }
        if use_match_lower:
            input_dict['match_lower'] = self.match_lower
        if use_match_lemma:
            input_dict['match_lemma'] = self.match_lemma

        if self.use_outer_embedding:
            # A stray trailing comma here used to store a 1-tuple instead of
            # the tensor itself.
            input_dict['context'] = self.context_string
            input_dict['question'] = self.question_string
        if use_pos:
            input_dict['pos'] = self.pos_feature
        if use_ner:
            input_dict['ner'] = self.ner_feature
        if use_context_tf:
            input_dict['context_tf'] = self.normalized_tf

        self.input_placeholder_dict = OrderedDict(input_dict)

        self.output_variable_dict = OrderedDict({
            "start_prob": self.start_prob,
            "end_prob": self.end_prob
        })

        # 8. Metrics and summary
        with tf.variable_scope("train_metrics"):
            self.train_metrics = {'loss': tf.metrics.mean(self.loss)}

        # Each metric is a (value, update_op) pair; group the update ops.
        self.train_update_metrics = tf.group(
            *[op for _, op in self.train_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="train_metrics")
        self.train_metric_init_op = tf.variables_initializer(metric_variables)

        with tf.variable_scope("eval_metrics"):
            self.eval_metrics = {'loss': tf.metrics.mean(self.loss)}

        self.eval_update_metrics = tf.group(
            *[op for _, op in self.eval_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="eval_metrics")
        self.eval_metric_init_op = tf.variables_initializer(metric_variables)

        tf.summary.scalar('loss', self.loss)
        self.summary_op = tf.summary.merge_all()
    def _build_graph(self):
        """Build the TensorFlow graph for the span reader with optional
        abstractive answer classes.

        Creates the input placeholders, word/char (and optionally ELMo)
        embeddings, a shared bidirectional GRU encoder, bi-attention and
        self-attention layers, the modeling GRUs, the start/end (and optional
        abstractive-answer) logits, the training loss, the feed/fetch
        dictionaries and the streaming loss metrics.

        Attributes assigned on ``self``: the input placeholders,
        ``max_context_len``, ``max_question_len``, ``start_loss``,
        ``end_loss``, ``loss``, ``input_placeholder_dict``,
        ``output_variable_dict``, the train/eval metric dicts with their
        update and init ops, and ``summary_op``.

        Optional parts are selected by ``self.use_elmo``,
        ``self.max_pooling_mask`` and ``self.abstractive_answer_num``.
        """
        self.context_word = tf.placeholder(tf.int32, [None, None])
        self.context_len = tf.placeholder(tf.int32, [None])

        self.context_char = tf.placeholder(tf.int32, [None, None, None])
        self.context_word_len = tf.placeholder(tf.int32, [None, None])

        self.question_word = tf.placeholder(tf.int32, [None, None])
        self.question_len = tf.placeholder(tf.int32, [None])

        self.question_char = tf.placeholder(tf.int32, [None, None, None])
        self.question_word_len = tf.placeholder(tf.int32, [None, None])

        self.answer_start = tf.placeholder(tf.int32, [None])
        self.answer_end = tf.placeholder(tf.int32, [None])
        # One column per abstractive answer class.
        self.abstractive_answer_mask = tf.placeholder(tf.int32, [None, self.abstractive_answer_num])
        self.training = tf.placeholder(tf.bool, [])
        
        # Raw string tokens; only consumed by the ELMo embedding below.
        self.question_tokens = tf.placeholder(tf.string, [None, None])
        self.context_tokens = tf.placeholder(tf.string,[None,None])

        # 1. Word encoding
        word_embedding = Embedding(pretrained_embedding=self.pretrained_word_embedding,
                                   embedding_shape=(len(self.vocab.get_word_vocab()) + 1, self.word_embedding_size),
                                   trainable=self.word_embedding_trainable)
        char_embedding = Embedding(embedding_shape=(len(self.vocab.get_char_vocab()) + 1, self.char_embedding_size), trainable=True, init_scale=0.05)

        # 1.1 Embedding
        dropout = Dropout(self.keep_prob)
        context_word_repr = word_embedding(self.context_word)
        context_char_repr = char_embedding(self.context_char)
        question_word_repr = word_embedding(self.question_word)
        question_char_repr = char_embedding(self.question_char)
        if self.use_elmo:
            # The same ELMo module instance embeds both context and question.
            elmo_emb = ElmoEmbedding(local_path=self.elmo_local_path)
            context_elmo_repr = elmo_emb(self.context_tokens,self.context_len)
            context_elmo_repr = dropout(context_elmo_repr,self.training)
            question_elmo_repr = elmo_emb(self.question_tokens,self.question_len)
            question_elmo_repr = dropout(question_elmo_repr,self.training)

        # 1.2 Char convolution
        conv1d = Conv1DAndMaxPooling(self.char_conv_filters, self.char_conv_kernel_size)
        if self.max_pooling_mask:
            # Pass per-word character lengths to the conv/pooling layer.
            question_char_repr = conv1d(dropout(question_char_repr, self.training), self.question_word_len)
            context_char_repr = conv1d(dropout(context_char_repr, self.training), self.context_word_len)
        else:
            question_char_repr = conv1d(dropout(question_char_repr, self.training))
            context_char_repr = conv1d(dropout(context_char_repr, self.training))

        # 2. Phrase encoding
        context_embs = [context_word_repr, context_char_repr]
        question_embs = [question_word_repr, question_char_repr]
        if self.use_elmo:
            context_embs.append(context_elmo_repr)
            question_embs.append(question_elmo_repr)

        context_repr = tf.concat(context_embs, axis=-1)
        question_repr = tf.concat(question_embs, axis=-1)

        variational_dropout = VariationalDropout(self.keep_prob)
        # A single encoder instance is applied to both context and question.
        emb_enc_gru = CudnnBiGRU(self.rnn_hidden_size)

        context_repr = variational_dropout(context_repr, self.training)
        context_repr, _ = emb_enc_gru(context_repr, self.context_len)
        context_repr = variational_dropout(context_repr,  self.training)

        question_repr = variational_dropout(question_repr, self.training)
        question_repr, _ = emb_enc_gru(question_repr, self.question_len)
        question_repr = variational_dropout(question_repr, self.training)

        # 3. Bi-Attention
        bi_attention = BiAttention(TriLinear(bias=True, name="bi_attention_tri_linear"))
        c2q, q2c = bi_attention(context_repr, question_repr, self.context_len, self.question_len)
        context_repr = tf.concat([context_repr, c2q, context_repr * c2q, context_repr * q2c], axis=-1)

        # 4. Self-Attention layer
        dense1 = tf.keras.layers.Dense(self.rnn_hidden_size*2, use_bias=True, activation=tf.nn.relu)
        gru = CudnnBiGRU(self.rnn_hidden_size)
        dense2 = tf.keras.layers.Dense(self.rnn_hidden_size*2, use_bias=True, activation=tf.nn.relu)
        self_attention = SelfAttention(TriLinear(bias=True, name="self_attention_tri_linear"))

        inputs = dense1(context_repr)
        outputs = variational_dropout(inputs, self.training)
        outputs, _ = gru(outputs, self.context_len)
        outputs = variational_dropout(outputs, self.training)
        c2c = self_attention(outputs, self.context_len)
        # Concatenate along the last axis of c2c.
        outputs = tf.concat([c2c, outputs, c2c * outputs], axis=len(c2c.shape)-1)
        outputs = dense2(outputs)
        # Residual connection around the self-attention sub-block.
        context_repr = inputs + outputs
        context_repr = variational_dropout(context_repr, self.training)
        
        # 5. Modeling layer
        sum_max_encoding = SumMaxEncoder()
        context_modeling_gru1 = CudnnBiGRU(self.rnn_hidden_size)
        context_modeling_gru2 = CudnnBiGRU(self.rnn_hidden_size)
        question_modeling_gru1 = CudnnBiGRU(self.rnn_hidden_size)
        question_modeling_gru2 = CudnnBiGRU(self.rnn_hidden_size)
        self.max_context_len = tf.reduce_max(self.context_len)
        self.max_question_len = tf.reduce_max(self.question_len)

        modeled_context1, _ = context_modeling_gru1(context_repr, self.context_len)
        modeled_context2, _ = context_modeling_gru2(tf.concat([context_repr, modeled_context1], axis=2), self.context_len)
        # NOTE(review): the context summary is built from modeled_context1,
        # while the question summary below uses modeled_question2 -- confirm
        # the asymmetry is intentional.
        encoded_context = sum_max_encoding(modeled_context1, self.context_len, self.max_context_len)
        modeled_question1, _ = question_modeling_gru1(question_repr, self.question_len)
        modeled_question2, _ = question_modeling_gru2(tf.concat([question_repr, modeled_question1], axis=2), self.question_len)
        encoded_question = sum_max_encoding(modeled_question2, self.question_len, self.max_question_len)
        
        # 6. Predictions
        # Start logits come from the first modeling GRU, end logits from the
        # second.
        start_dense = tf.keras.layers.Dense(1, activation=None)
        start_logits = tf.squeeze(start_dense(modeled_context1), squeeze_dims=[2])
        start_logits = mask_logits(start_logits, self.context_len)

        end_dense = tf.keras.layers.Dense(1, activation=None)
        end_logits = tf.squeeze(end_dense(modeled_context2), squeeze_dims=[2])
        end_logits = mask_logits(end_logits, self.context_len)

        abstractive_answer_logits = None
        if self.abstractive_answer_num != 0:
            # One separately named TriLinear scorer per abstractive class;
            # the per-class logits are concatenated on the last axis.
            abstractive_answer_logits = []
            for i in range(self.abstractive_answer_num):
                tri_linear = TriLinear(name="cls"+str(i))
                abstractive_answer_logits.append(tf.squeeze(tri_linear(encoded_context, encoded_question), squeeze_dims=[2]))
            abstractive_answer_logits = tf.concat(abstractive_answer_logits, axis=-1)

        # 7. Loss and input/output dict
        seq_length = tf.shape(start_logits)[1]
        start_mask = tf.one_hot(self.answer_start, depth=seq_length, dtype=tf.float32) 
        end_mask = tf.one_hot(self.answer_end, depth=seq_length, dtype=tf.float32) 
        if self.abstractive_answer_num != 0:
            abstractive_answer_mask = tf.cast(self.abstractive_answer_mask, dtype=tf.float32)
            # 1 - max over the class columns; presumably the mask is 0/1 so
            # this zeroes the span targets for abstractive examples -- TODO
            # confirm against the data feeder.
            extractive_mask = 1. - tf.reduce_max(abstractive_answer_mask, axis=-1, keepdims=True)
            start_mask = extractive_mask * start_mask
            end_mask = extractive_mask * end_mask

            # Abstractive class logits/targets are appended after the span
            # positions for both the start and the end distribution.
            concated_start_masks = tf.concat([start_mask, abstractive_answer_mask], axis=1)
            concated_end_masks = tf.concat([end_mask, abstractive_answer_mask], axis=1)

            concated_start_logits = tf.concat([start_logits, abstractive_answer_logits], axis=1)
            concated_end_logits = tf.concat([end_logits, abstractive_answer_logits], axis=1)
        else:
            concated_start_masks = start_mask
            concated_end_masks = end_mask

            concated_start_logits = start_logits
            concated_end_logits = end_logits

        # Loss = -(logsumexp over target positions - logsumexp over all
        # positions); non-target positions are pushed down by adding
        # VERY_NEGATIVE_NUMBER where the mask is 0.
        start_log_norm = tf.reduce_logsumexp(concated_start_logits, axis=1)
        start_log_score = tf.reduce_logsumexp(concated_start_logits + VERY_NEGATIVE_NUMBER * (1- tf.cast(concated_start_masks, tf.float32)), axis=1)
        self.start_loss = tf.reduce_mean(-(start_log_score - start_log_norm))

        end_log_norm = tf.reduce_logsumexp(concated_end_logits, axis=1)
        end_log_score = tf.reduce_logsumexp(concated_end_logits + VERY_NEGATIVE_NUMBER * (1- tf.cast(concated_end_masks, tf.float32)), axis=1)
        self.end_loss = tf.reduce_mean(-(end_log_score - end_log_norm))

        self.loss = self.start_loss + self.end_loss
        # The local binding is not used again within this method.
        global_step = tf.train.get_or_create_global_step()

        self.input_placeholder_dict = OrderedDict({
            "context_word": self.context_word,
            "question_word": self.question_word,
            "context_char": self.context_char,
            "question_char": self.question_char,
            "context_len": self.context_len,
            "question_len": self.question_len,
            "answer_start": self.answer_start,
            "answer_end": self.answer_end,
            "training": self.training
        })
        # Optional placeholders are only exposed when the matching option is
        # enabled.
        if self.max_pooling_mask:
            self.input_placeholder_dict['context_word_len'] = self.context_word_len
            self.input_placeholder_dict['question_word_len'] = self.question_word_len
        if self.use_elmo:
            self.input_placeholder_dict['context_tokens'] = self.context_tokens
            self.input_placeholder_dict['question_tokens'] = self.question_tokens
        if self.abstractive_answer_num != 0:
            self.input_placeholder_dict["abstractive_answer_mask"] = self.abstractive_answer_mask

        # Outputs are the masked logits, not probabilities.
        self.output_variable_dict = OrderedDict({
            "start_logits": start_logits,
            "end_logits": end_logits,
        })
        if self.abstractive_answer_num != 0:
            self.output_variable_dict["abstractive_answer_logits"] = abstractive_answer_logits

        # 8. Metrics and summary
        with tf.variable_scope("train_metrics"):
            self.train_metrics = {
                'loss': tf.metrics.mean(self.loss)
            }

        # Each metric is a (value, update_op) pair; group the update ops.
        self.train_update_metrics = tf.group(*[op for _, op in self.train_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="train_metrics")
        self.train_metric_init_op = tf.variables_initializer(metric_variables)

        with tf.variable_scope("eval_metrics"):
            self.eval_metrics = {
                'loss': tf.metrics.mean(self.loss)
            }

        self.eval_update_metrics = tf.group(*[op for _, op in self.eval_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="eval_metrics")
        self.eval_metric_init_op = tf.variables_initializer(metric_variables)

        tf.summary.scalar('loss', self.loss)
        self.summary_op = tf.summary.merge_all()
Beispiel #3
0
    def _build_graph(self):
        self.context_word = tf.placeholder(tf.int32, [None, None])
        self.context_len = tf.placeholder(tf.int32, [None])

        self.context_char = tf.placeholder(tf.int32, [None, None, None])
        self.context_word_len = tf.placeholder(tf.int32, [None, None])

        self.question_word = tf.placeholder(tf.int32, [None, None])
        self.question_len = tf.placeholder(tf.int32, [None])

        self.question_char = tf.placeholder(tf.int32, [None, None, None])
        self.question_word_len = tf.placeholder(tf.int32, [None, None])

        self.answer_start = tf.placeholder(tf.int32, [None])
        self.answer_end = tf.placeholder(tf.int32, [None])
        self.training = tf.placeholder(tf.bool, [])

        # 1. Word encoding
        word_embedding = Embedding(
            pretrained_embedding=self.pretrained_word_embedding,
            embedding_shape=(len(self.vocab.get_word_vocab()) + 1,
                             self.word_embedding_size),
            trainable=self.word_embedding_trainable)
        char_embedding = Embedding(
            embedding_shape=(len(self.vocab.get_char_vocab()) + 1,
                             self.char_embedding_size),
            trainable=True,
            init_scale=0.05)

        # 1.1 Embedding
        context_word_repr = word_embedding(self.context_word)
        context_char_repr = char_embedding(self.context_char)
        question_word_repr = word_embedding(self.question_word)
        question_char_repr = char_embedding(self.question_char)

        self.context_word_repr = context_word_repr
        self.question_word_repr = question_word_repr
        self.context_char_repr = context_char_repr
        self.question_char_repr = question_char_repr

        # 1.2 Char convolution
        dropout = Dropout(self.keep_prob)
        conv1d = Conv1DAndMaxPooling(self.char_conv_filters,
                                     self.char_conv_kernel_size)
        question_char_repr = conv1d(dropout(question_char_repr, self.training),
                                    self.question_word_len)
        context_char_repr = conv1d(dropout(context_char_repr, self.training),
                                   self.context_word_len)

        # 2. Phrase encoding
        context_repr = tf.concat([context_word_repr, context_char_repr],
                                 axis=-1)
        question_repr = tf.concat([question_word_repr, question_char_repr],
                                  axis=-1)

        variational_dropout = VariationalDropout(self.keep_prob)
        emb_enc_gru = CudnnBiGRU(self.rnn_hidden_size)

        context_repr = variational_dropout(context_repr, self.training)
        context_repr, _ = emb_enc_gru(context_repr, self.context_len)
        context_repr = variational_dropout(context_repr, self.training)

        question_repr = variational_dropout(question_repr, self.training)
        question_repr, _ = emb_enc_gru(question_repr, self.question_len)
        question_repr = variational_dropout(question_repr, self.training)

        # 3. Bi-Attention
        bi_attention = BiAttention(
            TriLinear(bias=True, name="bi_attention_tri_linear"))
        c2q, q2c = bi_attention(context_repr, question_repr, self.context_len,
                                self.question_len)
        context_repr = tf.concat(
            [context_repr, c2q, context_repr * c2q, context_repr * q2c],
            axis=-1)

        # 4. Self-Attention layer
        dense1 = tf.keras.layers.Dense(self.rnn_hidden_size * 2,
                                       use_bias=True,
                                       activation=tf.nn.relu)
        gru = CudnnBiGRU(self.rnn_hidden_size)
        dense2 = tf.keras.layers.Dense(self.rnn_hidden_size * 2,
                                       use_bias=True,
                                       activation=tf.nn.relu)
        self_attention = SelfAttention(
            TriLinear(bias=True, name="self_attention_tri_linear"))

        inputs = dense1(context_repr)
        outputs = variational_dropout(inputs, self.training)
        outputs, _ = gru(outputs, self.context_len)
        outputs = variational_dropout(outputs, self.training)
        c2c = self_attention(outputs, self.context_len)
        outputs = tf.concat([c2c, outputs, c2c * outputs],
                            axis=len(c2c.shape) - 1)
        outputs = dense2(outputs)
        context_repr = inputs + outputs
        context_repr = variational_dropout(context_repr, self.training)

        # 5. Modeling layer
        sum_max_encoding = SumMaxEncoder()
        context_modeling_gru1 = CudnnBiGRU(self.rnn_hidden_size)
        context_modeling_gru2 = CudnnBiGRU(self.rnn_hidden_size)
        question_modeling_gru1 = CudnnBiGRU(self.rnn_hidden_size)
        question_modeling_gru2 = CudnnBiGRU(self.rnn_hidden_size)
        self.max_context_len = tf.reduce_max(self.context_len)
        self.max_question_len = tf.reduce_max(self.question_len)

        modeled_context1, _ = context_modeling_gru1(context_repr,
                                                    self.context_len)
        modeled_context2, _ = context_modeling_gru2(
            tf.concat([context_repr, modeled_context1], axis=2),
            self.context_len)
        encoded_context = sum_max_encoding(modeled_context1, self.context_len,
                                           self.max_context_len)
        modeled_question1, _ = question_modeling_gru1(question_repr,
                                                      self.question_len)
        modeled_question2, _ = question_modeling_gru2(
            tf.concat([question_repr, modeled_question1], axis=2),
            self.question_len)
        encoded_question = sum_max_encoding(modeled_question2,
                                            self.question_len,
                                            self.max_question_len)

        start_dense = tf.keras.layers.Dense(1, activation=None)
        start_logits = tf.squeeze(start_dense(modeled_context1),
                                  squeeze_dims=[2])
        start_logits = mask_logits(start_logits, self.context_len)
        start_prob = tf.nn.softmax(start_logits)

        end_dense = tf.keras.layers.Dense(1, activation=None)
        end_logits = tf.squeeze(end_dense(modeled_context2), squeeze_dims=[2])
        end_logits = mask_logits(end_logits, self.context_len)
        end_prob = tf.nn.softmax(end_logits)

        self.start_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=start_logits, labels=self.answer_start))
        self.end_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=end_logits, labels=self.answer_end))
        # 7. Loss and input/output dict

        self.loss = self.start_loss + self.end_loss
        global_step = tf.train.get_or_create_global_step()

        self.input_placeholder_dict = OrderedDict({
            "context_word": self.context_word,
            "context_len": self.context_len,
            "context_char": self.context_char,
            "context_word_len": self.context_word_len,
            "question_word": self.question_word,
            "question_len": self.question_len,
            "question_char": self.question_char,
            "question_word_len": self.question_word_len,
            "answer_start": self.answer_start,
            "answer_end": self.answer_end,
            "training": self.training
        })

        self.output_variable_dict = OrderedDict({
            "start_prob": start_prob,
            "end_prob": end_prob,
        })

        # 8. Metrics and summary
        with tf.variable_scope("train_metrics"):
            self.train_metrics = {'loss': tf.metrics.mean(self.loss)}

        self.train_update_metrics = tf.group(
            *[op for _, op in self.train_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="train_metrics")
        self.train_metric_init_op = tf.variables_initializer(metric_variables)

        with tf.variable_scope("eval_metrics"):
            self.eval_metrics = {'loss': tf.metrics.mean(self.loss)}

        self.eval_update_metrics = tf.group(
            *[op for _, op in self.eval_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="eval_metrics")
        self.eval_metric_init_op = tf.variables_initializer(metric_variables)

        tf.summary.scalar('loss', self.loss)
        self.summary_op = tf.summary.merge_all()
Beispiel #4
0
    def _build_graph(self):
        """Build the TF1 graph of a word+character span-extraction reader.

        Stages, in construction order:

        1. Placeholders for word ids, character ids, lengths, the gold
           answer span and the ``training`` switch.
        2. Word and character embeddings. The characters of every word are
           run through a forward and a backward ``CuDNNGRU``; the two final
           states are concatenated onto the word embedding.
        3. One multi-layer BiGRU encoder applied to both context and
           question (the same layer object is called twice).
        4. Additive (tanh) context-to-question attention, a sigmoid gate
           and a GRU.
        5. Multi-head self-attention over the context, a sigmoid gate and
           a BiGRU.
        6. Start/end pointer heads with sparse softmax cross-entropy
           losses.
        7. Input/output dictionaries, streaming loss metrics and the merged
           summary op.

        The method returns nothing; it sets the placeholders,
        ``start_prob``/``end_prob``, ``start_loss``/``end_loss``/``loss``,
        ``global_step``, ``input_placeholder_dict``,
        ``output_variable_dict``, the train/eval metric ops and
        ``summary_op`` on ``self``.

        Shape comments use B = batch, CL/QL = context/question length,
        WL = word length in characters, WD/CD = word/char embedding size
        and H = ``self.hidden_size``. Widths that depend on project layers
        (``MultiLayerRNN``, ``CudnnGRU``, ``MultiHeadAttention``) are the
        original author's annotations and cannot be confirmed from this
        method alone.

        NOTE: layers are deliberately constructed in the same order as
        before, because Keras assigns default layer names (``dense``,
        ``dense_1``, ...) at construction time and those names end up in
        variable names.
        """
        # build input

        self.context_word = tf.placeholder(tf.int32, [None, None])
        self.context_char = tf.placeholder(tf.int32, [None, None, None])
        self.context_len = tf.placeholder(tf.int32, [None])
        self.context_word_len = tf.placeholder(tf.int32, [None, None])
        self.question_word = tf.placeholder(tf.int32, [None, None])
        self.question_char = tf.placeholder(tf.int32, [None, None, None])
        self.question_len = tf.placeholder(tf.int32, [None])
        self.question_word_len = tf.placeholder(tf.int32, [None, None])
        self.answer_start = tf.placeholder(tf.int32, [None])
        self.answer_end = tf.placeholder(tf.int32, [None])
        self.training = tf.placeholder(tf.bool, [])

        max_context_len = tf.shape(self.context_word)[1]
        max_context_word_len = tf.shape(self.context_char)[2]
        max_question_len = tf.shape(self.question_word)[1]
        max_question_word_len = tf.shape(self.question_char)[2]
        # Additive masks: 0 at real positions, -100 at padded positions, so
        # that adding them to a score pushes padding towards zero weight
        # after softmax.
        context_mask = (tf.sequence_mask(
            self.context_len, max_context_len, dtype=tf.float32) - 1) * 100
        question_mask = (tf.sequence_mask(
            self.question_len, max_question_len, dtype=tf.float32) - 1) * 100

        # 1. Word encoding
        word_embedding = Embedding(
            pretrained_embedding=self.pretrained_word_embedding,
            embedding_shape=(len(self.vocab.get_word_vocab()) + 1,
                             self.word_embedding_size),
            trainable=self.word_embedding_trainable)
        char_embedding = Embedding(
            embedding_shape=(len(self.vocab.get_char_vocab()) + 1,
                             self.char_embedding_size),
            trainable=True,
            init_scale=0.2)
        dropout = VariationalDropout(self.keep_prob)
        context_word_embedding = word_embedding(self.context_word)  # B*CL*WD
        # Characters are flattened to one sequence per word before the RNN.
        context_char_embedding = dropout(
            tf.reshape(char_embedding(self.context_char),
                       [-1, max_context_word_len, self.char_embedding_size]),
            self.training)  # (B*CL)*WL*CD
        question_word_embedding = word_embedding(self.question_word)  # B*QL*WD
        question_char_embedding = dropout(
            tf.reshape(char_embedding(self.question_char),
                       [-1, max_question_word_len, self.char_embedding_size]),
            self.training)  # (B*QL)*WL*CD

        # The same two char RNN layers are applied to context and question.
        char_forward_rnn = tf.keras.layers.CuDNNGRU(self.hidden_size)
        char_backward_rnn = tf.keras.layers.CuDNNGRU(self.hidden_size,
                                                     go_backwards=True)
        context_char_forward_final_states = char_forward_rnn(
            context_char_embedding)  # (B*CL)*H
        context_char_backward_final_states = char_backward_rnn(
            context_char_embedding)  # (B*CL)*H
        context_char_final_states = tf.reshape(
            tf.concat([
                context_char_forward_final_states,
                context_char_backward_final_states
            ], -1), [-1, max_context_len, self.hidden_size * 2])  # B*CL*(H*2)
        context_repr = tf.concat(
            [context_word_embedding, context_char_final_states],
            -1)  # B*CL*(WD+H*2)

        question_char_forward_final_states = char_forward_rnn(
            question_char_embedding)  # (B*QL)*H
        question_char_backward_final_states = char_backward_rnn(
            question_char_embedding)  # (B*QL)*H
        question_char_final_states = tf.reshape(
            tf.concat([
                question_char_forward_final_states,
                question_char_backward_final_states
            ], -1), [-1, max_question_len, self.hidden_size * 2])  # B*QL*(H*2)
        question_repr = tf.concat(
            [question_word_embedding, question_char_final_states],
            -1)  # B*QL*(WD+H*2)

        # 1.2 Encoder (one stack, applied to both context and question)
        context_rnn = [
            CudnnBiGRU(self.hidden_size) for _ in range(self.doc_rnn_layers)
        ]
        encoder_multi_bigru = MultiLayerRNN(context_rnn,
                                            concat_layer_out=True,
                                            input_keep_prob=self.keep_prob)
        encoder_context = dropout(
            encoder_multi_bigru(context_repr, self.context_len, self.training),
            self.training)  # B*CL*(H*2)
        encoder_question = dropout(
            encoder_multi_bigru(question_repr, self.question_len,
                                self.training), self.training)  # B*QL*(H*2)

        # 1.3 co-attention
        # Projections are broadcast against each other to score every
        # (context position, question position) pair.
        co_attention_context = tf.expand_dims(
            tf.keras.layers.Dense(self.hidden_size)(encoder_context),
            2)  # B*CL*1*H
        co_attention_question = tf.expand_dims(
            tf.keras.layers.Dense(self.hidden_size)(encoder_question),
            1)  # B*1*QL*H
        co_attention_score = tf.keras.layers.Dense(1)(
            tf.nn.tanh(co_attention_context +
                       co_attention_question))[:, :, :, 0] + tf.expand_dims(
                           question_mask, 1)  # B*CL*QL
        co_attention_similarity = tf.nn.softmax(co_attention_score,
                                                -1)  # B*CL*QL
        co_attention_rnn_input = tf.concat([
            encoder_context,
            tf.matmul(co_attention_similarity, encoder_question)
        ], -1)  # B*CL*(H*4)
        # Element-wise sigmoid gate on the RNN input.
        co_attention_rnn_input = co_attention_rnn_input * tf.keras.layers.Dense(
            self.hidden_size * 4,
            activation=tf.nn.sigmoid)(co_attention_rnn_input)
        co_attention_rnn = CudnnGRU(self.hidden_size)
        # NOTE(review): the original comment gave B*CL*(H*2) here, but the
        # 2H-wide gate in 1.4 only fits if this output and the attention
        # output together are H*2 wide - presumably B*CL*H; TODO confirm.
        co_attention_output = dropout(
            co_attention_rnn(co_attention_rnn_input, self.context_len)[0],
            self.training)

        # 1.4 self-attention
        multi_head_attention = MultiHeadAttention(self.heads, self.hidden_size,
                                                  False)
        self_attention_repr = dropout(
            multi_head_attention(co_attention_output, co_attention_output,
                                 co_attention_output, context_mask),
            self.training)
        self_attention_rnn_input = tf.concat(
            [co_attention_output, self_attention_repr], -1)  # B*CL*(H*2)
        # Element-wise sigmoid gate; requires the input to be H*2 wide.
        self_attention_rnn_input = self_attention_rnn_input * tf.keras.layers.Dense(
            self.hidden_size * 2,
            activation=tf.nn.sigmoid)(self_attention_rnn_input)
        self_attention_rnn = CudnnBiGRU(self.hidden_size)
        self_attention_output = dropout(
            self_attention_rnn(self_attention_rnn_input, self.context_len)[0],
            self.training)  # B*CL*(H*2)

        # predict_start
        # Attention-pool the question into a single vector.
        start_question_score = tf.keras.layers.Dense(1)(
            tf.keras.layers.Dense(self.hidden_size, activation=tf.nn.tanh)
            (encoder_question)) + tf.expand_dims(question_mask, -1)  # B*QL*1
        start_question_similarity = tf.nn.softmax(start_question_score,
                                                  1)  # B*QL*1
        start_question_repr = tf.matmul(start_question_similarity,
                                        encoder_question,
                                        transpose_a=True)[:, 0]  # B*(H*2)
        start_logits = tf.keras.layers.Dense(1)(tf.nn.tanh(
            tf.keras.layers.Dense(self.hidden_size)(self_attention_output) +
            tf.expand_dims(
                tf.keras.layers.Dense(self.hidden_size)
                (start_question_repr), 1)))[:, :, 0] + context_mask  # B*CL
        self.start_prob = tf.nn.softmax(start_logits, -1)  # B*CL
        self.start_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=start_logits, labels=self.answer_start))

        # Expected context vector under the start distribution; it is fed
        # as a length-1 sequence to a GRU whose initial state is the pooled
        # question vector.
        start_repr = tf.matmul(tf.expand_dims(self.start_prob, 1),
                               self_attention_output)  # B*1*(H*2)
        start_output = Dropout(self.keep_prob)(tf.keras.layers.CuDNNGRU(
            self.hidden_size * 2)(start_repr, start_question_repr),
                                               self.training)  # B*(H*2)

        # predict_end
        end_logits = tf.keras.layers.Dense(1)(tf.nn.tanh(
            tf.keras.layers.Dense(self.hidden_size)(self_attention_output) +
            tf.expand_dims(
                tf.keras.layers.Dense(self.hidden_size)
                (start_output), 1)))[:, :, 0] + context_mask  # B*CL
        self.end_prob = tf.nn.softmax(end_logits, -1)  # B*CL
        self.end_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=end_logits, labels=self.answer_end))

        # Loss and input/output dicts
        self.loss = self.start_loss + self.end_loss
        self.global_step = tf.train.get_or_create_global_step()
        input_dict = {
            "context_word": self.context_word,
            "context_char": self.context_char,
            "context_len": self.context_len,
            "context_word_len": self.context_word_len,
            "question_word": self.question_word,
            "question_char": self.question_char,
            "question_len": self.question_len,
            "question_word_len": self.question_word_len,
            "answer_start": self.answer_start,
            "answer_end": self.answer_end,
            "training": self.training
        }

        self.input_placeholder_dict = OrderedDict(input_dict)

        self.output_variable_dict = OrderedDict({
            "start_prob": self.start_prob,
            "end_prob": self.end_prob
        })

        # 8. Metrics and summary
        # Train and eval each get their own streaming mean of the loss; the
        # variable scope lets the local accumulator variables be collected
        # and re-initialised separately.
        with tf.variable_scope("train_metrics"):
            self.train_metrics = {'loss': tf.metrics.mean(self.loss)}

        self.train_update_metrics = tf.group(
            *[op for _, op in self.train_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="train_metrics")
        self.train_metric_init_op = tf.variables_initializer(metric_variables)

        with tf.variable_scope("eval_metrics"):
            self.eval_metrics = {'loss': tf.metrics.mean(self.loss)}

        self.eval_update_metrics = tf.group(
            *[op for _, op in self.eval_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="eval_metrics")
        self.eval_metric_init_op = tf.variables_initializer(metric_variables)

        tf.summary.scalar('loss', self.loss)
        self.summary_op = tf.summary.merge_all()
Beispiel #5
0
    def _build_graph(self):
        """Build the TF1 graph of a feature-augmented span-extraction reader.

        Stages, in construction order:

        1. Placeholders for word ids, lengths, optional hand-crafted
           features, raw tokens (for ELMo), the gold answer span and the
           ``training`` switch.
        2. Word embeddings with dropout, optionally concatenated with ELMo
           representations when ``self.use_elmo`` is set.
        3. Per-context-token features: an exact-match flag computed in the
           graph, plus whichever of ``match_lower``, ``match_lemma``,
           ``pos``, ``ner`` and ``context_tf`` are listed in
           ``self.features``.
        4. Aligned question embedding (attention of each context word over
           the question words).
        5. Multi-layer BiLSTM encoders for context and question; the
           question is pooled to one vector with self-attention.
        6. Start/end scores as ``Dense(context) . question_vector`` and
           their masked cross-entropy losses.
        7. Input/output dictionaries, streaming loss metrics and the merged
           summary op.

        The method returns nothing; it sets the placeholders,
        ``start_prob``/``end_prob``, ``start_loss``/``end_loss``/``loss``,
        ``input_placeholder_dict``, ``output_variable_dict``, the
        train/eval metric ops and ``summary_op`` on ``self`` (plus
        ``context_pos_feature``/``context_ner_feature`` when those features
        are enabled).

        NOTE: layers are deliberately constructed in the same order as
        before, so automatically generated layer/variable names are
        unchanged.
        """
        # build input

        self.context_word = tf.placeholder(tf.int32, [None, None])
        self.context_len = tf.placeholder(tf.int32, [None])
        self.question_word = tf.placeholder(tf.int32, [None, None])
        self.match_lemma = tf.placeholder(tf.int32, [None, None])
        self.match_lower = tf.placeholder(tf.int32, [None, None])
        self.question_len = tf.placeholder(tf.int32, [None])
        self.answer_start = tf.placeholder(tf.int32, [None])
        self.answer_end = tf.placeholder(tf.int32, [None])
        self.pos_feature = tf.placeholder(tf.int32, [None, None])
        self.ner_feature = tf.placeholder(tf.int32, [None, None])
        # NOTE(review): an int32 placeholder truncates fractional values;
        # if the feeder supplies a normalised (0..1) term frequency this
        # should presumably be float32 - confirm against the data pipeline.
        self.normalized_tf = tf.placeholder(tf.int32, [None, None])
        self.question_tokens = tf.placeholder(tf.string, [None, None])
        self.context_tokens = tf.placeholder(tf.string, [None, None])

        self.training = tf.placeholder(tf.bool, [])

        # Which optional features are enabled. Membership in an empty
        # container is already False, so no separate length check is needed.
        features = self.features
        use_pos = 'pos' in features
        use_ner = 'ner' in features
        use_match_lower = 'match_lower' in features
        use_match_lemma = 'match_lemma' in features
        use_context_tf = 'context_tf' in features

        # 1. Word encoding
        word_embedding = PartiallyTrainableEmbedding(
            trainable_num=self.finetune_word_size,
            pretrained_embedding=self.pretrained_word_embedding,
            embedding_shape=(len(self.vocab.get_word_vocab()) + 1,
                             self.word_embedding_size))

        # 1.1 Embedding
        context_word_repr = word_embedding(self.context_word)
        question_word_repr = word_embedding(self.question_word)

        # pos feature, one-hot encoded (no learned embedding)
        if use_pos:
            self.context_pos_feature = tf.cast(
                tf.one_hot(self.pos_feature,
                           len(self.feature_vocab['pos']) + 1), tf.float32)
        # ner feature, one-hot encoded (no learned embedding)
        if use_ner:
            self.context_ner_feature = tf.cast(
                tf.one_hot(self.ner_feature,
                           len(self.feature_vocab['ner']) + 1), tf.float32)
        dropout = VariationalDropout(self.keep_prob)
        # embedding dropout
        context_word_repr = dropout(context_word_repr, self.training)
        question_word_repr = dropout(question_word_repr, self.training)

        # elmo embedding
        if self.use_elmo:
            elmo_emb = ElmoEmbedding(local_path=self.elmo_local_path)
            context_elmo_repr = elmo_emb(self.context_tokens, self.context_len)
            context_elmo_repr = dropout(context_elmo_repr, self.training)
            question_elmo_repr = elmo_emb(self.question_tokens,
                                          self.question_len)
            question_elmo_repr = dropout(question_elmo_repr, self.training)
            context_word_repr = tf.concat(
                [context_word_repr, context_elmo_repr], axis=-1)
            question_word_repr = tf.concat(
                [question_word_repr, question_elmo_repr], axis=-1)

        # 1.2 exact match feature
        # Tile both id matrices to [batch, context_len, question_len] and
        # flag every context position whose id equals any question id.
        context_expanded = tf.tile(
            tf.expand_dims(self.context_word, axis=-1),
            [1, 1, tf.shape(self.question_word)[1]])
        query_expanded = tf.tile(tf.expand_dims(self.question_word, axis=1),
                                 [1, tf.shape(self.context_word)[1], 1])
        exact_match_feature = tf.cast(
            tf.reduce_any(tf.equal(context_expanded, query_expanded), axis=-1),
            tf.float32)
        exact_match_feature = tf.expand_dims(exact_match_feature, axis=-1)
        # Enabled features are appended along the last axis in a fixed
        # order: match_lower, match_lemma, pos, ner, context_tf.
        if use_match_lower:
            match_lower_feature = tf.expand_dims(
                tf.cast(self.match_lower, tf.float32), axis=-1)
            exact_match_feature = tf.concat(
                [exact_match_feature, match_lower_feature], axis=-1)
        if use_match_lemma:
            exact_match_feature = tf.concat([
                exact_match_feature,
                tf.expand_dims(tf.cast(self.match_lemma, tf.float32), axis=-1)
            ], axis=-1)
        if use_pos:
            exact_match_feature = tf.concat(
                [exact_match_feature, self.context_pos_feature], axis=-1)
        if use_ner:
            exact_match_feature = tf.concat(
                [exact_match_feature, self.context_ner_feature], axis=-1)
        if use_context_tf:
            exact_match_feature = tf.concat([
                exact_match_feature,
                tf.cast(tf.expand_dims(self.normalized_tf, axis=-1),
                        tf.float32)
            ], axis=-1)

        # 1.3 aligned question embedding
        sim_function = ProjectedDotProduct(self.rnn_hidden_size,
                                           activation=tf.nn.relu,
                                           reuse_weight=True)
        word_fusion = UniAttention(sim_function)
        aligned_question_embedding = word_fusion(context_word_repr,
                                                 question_word_repr,
                                                 self.question_len)

        # context_repr
        context_repr = tf.concat([
            context_word_repr, exact_match_feature, aligned_question_embedding
        ], axis=-1)

        # 1.4 context encoder
        context_rnn_layers = [
            CudnnBiLSTM(self.rnn_hidden_size)
            for _ in range(self.doc_rnn_layers)
        ]
        multi_bilstm_layer_context = MultiLayerRNN(context_rnn_layers,
                                                   concat_layer_out=True,
                                                   input_keep_prob=0.7)
        context_repr = multi_bilstm_layer_context(context_repr,
                                                  self.context_len,
                                                  self.training)
        # rnn output dropout
        context_repr = dropout(context_repr, self.training)

        # 1.5 question encoder
        question_rnn_layers = [
            CudnnBiLSTM(self.rnn_hidden_size)
            for _ in range(self.question_rnn_layers)
        ]
        multi_bilstm_layer_question = MultiLayerRNN(question_rnn_layers,
                                                    concat_layer_out=True,
                                                    input_keep_prob=0.7)
        question_repr = multi_bilstm_layer_question(question_word_repr,
                                                    self.question_len,
                                                    self.training)
        # rnn output dropout
        question_repr = dropout(question_repr, self.training)
        self_attn = SelfAttn()
        weighted_question_repr = self_attn(question_repr, self.question_len)

        # predict
        # NOTE(review): the matmul below needs the projected context width
        # (doc_hidden_size) to equal the width of weighted_question_repr;
        # that presumably only holds when doc_rnn_layers equals
        # question_rnn_layers - TODO confirm.
        doc_hidden_size = self.rnn_hidden_size * self.doc_rnn_layers * 2
        start_project = tf.keras.layers.Dense(doc_hidden_size, use_bias=False)
        start_logits = tf.squeeze(
            tf.matmul(start_project(context_repr),
                      tf.expand_dims(weighted_question_repr, axis=-1)),
            axis=-1)
        self.start_prob = masked_softmax(start_logits, self.context_len)
        end_project = tf.keras.layers.Dense(doc_hidden_size, use_bias=False)
        end_logits = tf.squeeze(
            tf.matmul(end_project(context_repr),
                      tf.expand_dims(weighted_question_repr, axis=-1)),
            axis=-1)
        self.end_prob = masked_softmax(end_logits, self.context_len)

        # 7. Loss and input/output dict
        self.start_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=mask_logits(start_logits, self.context_len),
                labels=self.answer_start))
        self.end_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=mask_logits(end_logits, self.context_len),
                labels=self.answer_end))
        self.loss = self.start_loss + self.end_loss
        # Called only for its side effect of making sure the graph has a
        # global step; the returned tensor is not used in this method.
        tf.train.get_or_create_global_step()
        input_dict = {
            "context_word": self.context_word,
            "context_len": self.context_len,
            "question_word": self.question_word,
            "question_len": self.question_len,
            "answer_start": self.answer_start,
            "answer_end": self.answer_end,
            "training": self.training
        }
        # Optional placeholders are only exposed when their feature is
        # enabled; insertion order is kept as before.
        if use_match_lemma:
            input_dict['match_lemma'] = self.match_lemma
        if use_context_tf:
            input_dict['context_tf'] = self.normalized_tf
        if use_match_lower:
            input_dict['match_lower'] = self.match_lower
        if use_pos:
            input_dict['pos'] = self.pos_feature
        if use_ner:
            input_dict['ner'] = self.ner_feature
        if self.use_elmo:
            input_dict['context_tokens'] = self.context_tokens
            input_dict['question_tokens'] = self.question_tokens
        self.input_placeholder_dict = OrderedDict(input_dict)

        self.output_variable_dict = OrderedDict({
            "start_prob": self.start_prob,
            "end_prob": self.end_prob
        })

        # 8. Metrics and summary
        # Train and eval each get their own streaming mean of the loss; the
        # variable scope lets the local accumulator variables be collected
        # and re-initialised separately.
        with tf.variable_scope("train_metrics"):
            self.train_metrics = {'loss': tf.metrics.mean(self.loss)}

        self.train_update_metrics = tf.group(
            *[op for _, op in self.train_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="train_metrics")
        self.train_metric_init_op = tf.variables_initializer(metric_variables)

        with tf.variable_scope("eval_metrics"):
            self.eval_metrics = {'loss': tf.metrics.mean(self.loss)}

        self.eval_update_metrics = tf.group(
            *[op for _, op in self.eval_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="eval_metrics")
        self.eval_metric_init_op = tf.variables_initializer(metric_variables)

        tf.summary.scalar('loss', self.loss)
        self.summary_op = tf.summary.merge_all()