Example 1
    def _build_graph(self):
        # build input

        self.context_word = tf.placeholder(tf.int32, [None, None])
        self.context_len = tf.placeholder(tf.int32, [None])
        self.question_word = tf.placeholder(tf.int32, [None, None])
        self.match_lemma = tf.placeholder(tf.int32, [None, None])
        self.match_lower = tf.placeholder(tf.int32, [None, None])
        self.question_len = tf.placeholder(tf.int32, [None])
        self.answer_start = tf.placeholder(tf.int32, [None])
        self.answer_end = tf.placeholder(tf.int32, [None])
        self.pos_feature = tf.placeholder(tf.int32, [None, None])
        self.ner_feature = tf.placeholder(tf.int32, [None, None])
        self.normalized_tf = tf.placeholder(tf.int32, [None, None])
        self.question_tokens = tf.placeholder(tf.string, [None, None])
        self.context_tokens = tf.placeholder(tf.string, [None, None])

        self.training = tf.placeholder(tf.bool, [])

        # 1. Word encoding
        word_embedding = PartiallyTrainableEmbedding(
            trainable_num=self.finetune_word_size,
            pretrained_embedding=self.pretrained_word_embedding,
            embedding_shape=(len(self.vocab.get_word_vocab()) + 1,
                             self.word_embedding_size))

        # 1.1 Embedding
        context_word_repr = word_embedding(self.context_word)
        question_word_repr = word_embedding(self.question_word)

        # pos embedding
        if len(self.features) > 0 and 'pos' in self.features:
            self.context_pos_feature = tf.cast(
                tf.one_hot(self.pos_feature,
                           len(self.feature_vocab['pos']) + 1), tf.float32)
        # ner embedding
        if len(self.features) > 0 and 'ner' in self.features:
            self.context_ner_feature = tf.cast(
                tf.one_hot(self.ner_feature,
                           len(self.feature_vocab['ner']) + 1), tf.float32)
        dropout = VariationalDropout(self.keep_prob)
        # embedding dropout
        context_word_repr = dropout(context_word_repr, self.training)
        question_word_repr = dropout(question_word_repr, self.training)

        # elmo embedding
        if self.use_elmo:
            elmo_emb = ElmoEmbedding(local_path=self.elmo_local_path)
            context_elmo_repr = elmo_emb(self.context_tokens, self.context_len)
            context_elmo_repr = dropout(context_elmo_repr, self.training)
            question_elmo_repr = elmo_emb(self.question_tokens,
                                          self.question_len)
            question_elmo_repr = dropout(question_elmo_repr, self.training)
            context_word_repr = tf.concat(
                [context_word_repr, context_elmo_repr], axis=-1)
            question_word_repr = tf.concat(
                [question_word_repr, question_elmo_repr], axis=-1)

        # 1.2 exact match feature
        context_expanded = tf.tile(
            tf.expand_dims(self.context_word, axis=-1),
            [1, 1, tf.shape(self.question_word)[1]])
        query_expanded = tf.tile(tf.expand_dims(self.question_word, axis=1),
                                 [1, tf.shape(self.context_word)[1], 1])
        exact_match_feature = tf.cast(
            tf.reduce_any(tf.equal(context_expanded, query_expanded), axis=-1),
            tf.float32)
        exact_match_feature = tf.expand_dims(exact_match_feature, axis=-1)
        if len(self.features) > 0 and 'match_lower' in self.features:
            match_lower_feature = tf.expand_dims(tf.cast(
                self.match_lower, tf.float32),
                                                 axis=-1)
            exact_match_feature = tf.concat(
                [exact_match_feature, match_lower_feature], axis=-1)
        if len(self.features) > 0 and 'match_lemma' in self.features:
            exact_match_feature = tf.concat([
                exact_match_feature,
                tf.expand_dims(tf.cast(self.match_lemma, tf.float32), axis=-1)
            ],
                                            axis=-1)
        if len(self.features) > 0 and 'pos' in self.features:
            exact_match_feature = tf.concat(
                [exact_match_feature, self.context_pos_feature], axis=-1)
        if len(self.features) > 0 and 'ner' in self.features:
            exact_match_feature = tf.concat(
                [exact_match_feature, self.context_ner_feature], axis=-1)
        if len(self.features) > 0 and 'context_tf' in self.features:
            exact_match_feature = tf.concat([
                exact_match_feature,
                tf.cast(tf.expand_dims(self.normalized_tf, axis=-1),
                        tf.float32)
            ],
                                            axis=-1)
        # 1.3 aligned question embedding
        sim_function = ProjectedDotProduct(self.rnn_hidden_size,
                                           activation=tf.nn.relu,
                                           reuse_weight=True)
        word_fusion = UniAttention(sim_function)
        aligned_question_embedding = word_fusion(context_word_repr,
                                                 question_word_repr,
                                                 self.question_len)

        # context_repr
        context_repr = tf.concat([
            context_word_repr, exact_match_feature, aligned_question_embedding
        ],
                                 axis=-1)

        # 1.4 context encoder
        context_rnn_layers = [
            CudnnBiLSTM(self.rnn_hidden_size)
            for _ in range(self.doc_rnn_layers)
        ]
        multi_bilstm_layer_context = MultiLayerRNN(context_rnn_layers,
                                                   concat_layer_out=True,
                                                   input_keep_prob=0.7)
        context_repr = multi_bilstm_layer_context(context_repr,
                                                  self.context_len,
                                                  self.training)
        # rnn output dropout
        context_repr = dropout(context_repr, self.training)

        # 1.5 question encoder
        question_rnn_layers = [
            CudnnBiLSTM(self.rnn_hidden_size)
            for _ in range(self.question_rnn_layers)
        ]
        multi_bilstm_layer_question = MultiLayerRNN(question_rnn_layers,
                                                    concat_layer_out=True,
                                                    input_keep_prob=0.7)
        question_repr = multi_bilstm_layer_question(question_word_repr,
                                                    self.question_len,
                                                    self.training)
        # rnn output dropout
        question_repr = dropout(question_repr, self.training)
        self_attn = SelfAttn()
        weighted_question_repr = self_attn(question_repr, self.question_len)

        # 2. Answer span prediction (bilinear match with the question summary)
        doc_hidden_size = self.rnn_hidden_size * self.doc_rnn_layers * 2
        start_project = tf.keras.layers.Dense(doc_hidden_size, use_bias=False)
        start_logits = tf.squeeze(tf.matmul(
            start_project(context_repr),
            tf.expand_dims(weighted_question_repr, axis=-1)),
                                  axis=-1)
        self.start_prob = masked_softmax(start_logits, self.context_len)
        end_project = tf.keras.layers.Dense(doc_hidden_size, use_bias=False)
        end_logits = tf.squeeze(tf.matmul(
            end_project(context_repr),
            tf.expand_dims(weighted_question_repr, axis=-1)),
                                axis=-1)
        self.end_prob = masked_softmax(end_logits, self.context_len)
        # 7. Loss and input/output dict
        self.start_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=mask_logits(start_logits, self.context_len),
                labels=self.answer_start))
        self.end_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=mask_logits(end_logits, self.context_len),
                labels=self.answer_end))
        self.loss = self.start_loss + self.end_loss
        global_step = tf.train.get_or_create_global_step()
        input_dict = {
            "context_word": self.context_word,
            "context_len": self.context_len,
            "question_word": self.question_word,
            "question_len": self.question_len,
            "answer_start": self.answer_start,
            "answer_end": self.answer_end,
            "training": self.training
        }
        if len(self.features) > 0 and 'match_lemma' in self.features:
            input_dict['match_lemma'] = self.match_lemma
        if len(self.features) > 0 and 'context_tf' in self.features:
            input_dict['context_tf'] = self.normalized_tf
        if len(self.features) > 0 and 'match_lower' in self.features:
            input_dict['match_lower'] = self.match_lower
        if len(self.features) > 0 and 'pos' in self.features:
            input_dict['pos'] = self.pos_feature
        if len(self.features) > 0 and 'ner' in self.features:
            input_dict['ner'] = self.ner_feature
        if self.use_elmo:
            input_dict['context_tokens'] = self.context_tokens
            input_dict['question_tokens'] = self.question_tokens
        self.input_placeholder_dict = OrderedDict(input_dict)

        self.output_variable_dict = OrderedDict({
            "start_prob": self.start_prob,
            "end_prob": self.end_prob
        })

        # 8. Metrics and summary
        with tf.variable_scope("train_metrics"):
            self.train_metrics = {'loss': tf.metrics.mean(self.loss)}

        self.train_update_metrics = tf.group(
            *[op for _, op in self.train_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="train_metrics")
        self.train_metric_init_op = tf.variables_initializer(metric_variables)

        with tf.variable_scope("eval_metrics"):
            self.eval_metrics = {'loss': tf.metrics.mean(self.loss)}

        self.eval_update_metrics = tf.group(
            *[op for _, op in self.eval_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="eval_metrics")
        self.eval_metric_init_op = tf.variables_initializer(metric_variables)

        tf.summary.scalar('loss', self.loss)
        self.summary_op = tf.summary.merge_all()
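
The examples call helper functions such as masked_softmax and mask_logits that are not defined in the snippets. A minimal sketch consistent with how they are used here — logits of shape [batch, length] masked by the true sequence lengths — could look like the following; the actual implementations in the source toolkit may differ.

    import tensorflow as tf

    def mask_logits(logits, seq_len, mask_value=-1e30):
        # Build a [batch, max_len] 0/1 mask from the true lengths and push the
        # logits at padded positions toward -inf so they get ~zero probability.
        mask = tf.sequence_mask(seq_len, maxlen=tf.shape(logits)[1],
                                dtype=tf.float32)
        return logits * mask + (1.0 - mask) * mask_value

    def masked_softmax(logits, seq_len):
        # Softmax restricted to the unpadded positions of each sequence.
        return tf.nn.softmax(mask_logits(logits, seq_len), axis=-1)
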
Example 2
    def _build_graph(self):
        self.context_word = tf.placeholder(tf.int32, [None, None])
        self.context_char = tf.placeholder(tf.int32, [None, None, None])
        self.context_len = tf.placeholder(tf.int32, [None])
        self.question_word = tf.placeholder(tf.int32, [None, None])
        self.question_char = tf.placeholder(tf.int32, [None, None, None])
        self.question_len = tf.placeholder(tf.int32, [None])
        self.answer_start = tf.placeholder(tf.int32, [None])
        self.answer_end = tf.placeholder(tf.int32, [None])
        self.training = tf.placeholder(tf.bool, [])

        self.question_tokens = tf.placeholder(tf.string, [None, None])
        self.context_tokens = tf.placeholder(tf.string, [None, None])
        if self.enable_na_answer:
            self.na = tf.placeholder(tf.int32, [None])
        # 1. Word encoding
        word_embedding = Embedding(pretrained_embedding=self.pretrained_word_embedding,
                                   embedding_shape=(len(self.vocab.get_word_vocab()) + 1, self.word_embedding_size),
                                   trainable=self.word_embedding_trainable)
        char_embedding = Embedding(embedding_shape=(len(self.vocab.get_char_vocab()) + 1, self.char_embedding_size),
                                   trainable=True, init_scale=0.2)


        # 1.1 Embedding
        context_word_repr = word_embedding(self.context_word)
        context_char_repr = char_embedding(self.context_char)
        question_word_repr = word_embedding(self.question_word)
        question_char_repr = char_embedding(self.question_char)

        # 1.2 Char convolution
        dropout = Dropout(self.keep_prob)
        conv1d = Conv1DAndMaxPooling(self.char_conv_filters, self.char_conv_kernel_size)
        context_char_repr = dropout(conv1d(context_char_repr), self.training)
        question_char_repr = dropout(conv1d(question_char_repr), self.training)

        # elmo embedding
        if self.use_elmo:
            elmo_emb = ElmoEmbedding(local_path=self.elmo_local_path)
            context_elmo_repr = elmo_emb(self.context_tokens, self.context_len)
            context_elmo_repr = dropout(context_elmo_repr, self.training)
            question_elmo_repr = elmo_emb(self.question_tokens, self.question_len)
            question_elmo_repr = dropout(question_elmo_repr, self.training)
        # concat word and char
        context_repr = tf.concat([context_word_repr, context_char_repr], axis=-1)
        question_repr = tf.concat([question_word_repr, question_char_repr], axis=-1)
        if self.use_elmo:
            context_repr = tf.concat([context_repr, context_elmo_repr], axis=-1)
            question_repr = tf.concat([question_repr, question_elmo_repr], axis=-1)

        # 1.3 Highway network
        highway1 = Highway()
        highway2 = Highway()
        context_repr = highway2(highway1(context_repr))
        question_repr = highway2(highway1(question_repr))

        # 2. Phrase encoding
        phrase_lstm = CudnnBiLSTM(self.rnn_hidden_size)
        context_repr, _ = phrase_lstm(dropout(context_repr, self.training), self.context_len)
        question_repr, _ = phrase_lstm(dropout(question_repr, self.training), self.question_len)

        # 3. Bi-Attention
        bi_attention = BiAttention(TriLinear())
        c2q, q2c = bi_attention(context_repr, question_repr, self.context_len, self.question_len)

        # 4. Modeling layer
        final_merged_context = tf.concat([context_repr, c2q, context_repr * c2q, context_repr * q2c], axis=-1)
        modeling_lstm1 = CudnnBiLSTM(self.rnn_hidden_size)
        modeling_lstm2 = CudnnBiLSTM(self.rnn_hidden_size)
        modeled_context1, _ = modeling_lstm1(dropout(final_merged_context, self.training), self.context_len)
        modeled_context2, _ = modeling_lstm2(dropout(modeled_context1, self.training), self.context_len)
        modeled_context = modeled_context1 + modeled_context2

        # 5. Start prediction
        start_pred_layer = tf.keras.layers.Dense(1, use_bias=False)
        start_logits = start_pred_layer(
            dropout(tf.concat([final_merged_context, modeled_context], axis=-1), self.training))
        start_logits = tf.squeeze(start_logits, axis=-1)
        self.start_prob = masked_softmax(start_logits, self.context_len)

        # 6. End prediction
        start_repr = weighted_sum(modeled_context, self.start_prob)
        tiled_start_repr = tf.tile(tf.expand_dims(start_repr, axis=1), [1, tf.shape(modeled_context)[1], 1])
        end_lstm = CudnnBiLSTM(self.rnn_hidden_size)
        encoded_end_repr, _ = end_lstm(dropout(tf.concat(
            [final_merged_context, modeled_context, tiled_start_repr, modeled_context * tiled_start_repr], axis=-1),
            self.training),
            self.context_len)
        end_pred_layer = tf.keras.layers.Dense(1, use_bias=False)
        end_logits = end_pred_layer(dropout(tf.concat(
            [final_merged_context, encoded_end_repr], axis=-1), self.training))
        end_logits = tf.squeeze(end_logits, axis=-1)
        self.end_prob = masked_softmax(end_logits, self.context_len)

        # 7. Loss and input/output dict
        if self.enable_na_answer:
            self.na_bias = tf.get_variable("na_bias", shape=[1], dtype='float')
            self.na_bias_tiled = tf.tile(tf.reshape(self.na_bias, [1, 1]), [tf.shape(self.context_word)[0], 1])
            self.concat_start_na_logits = tf.concat([self.na_bias_tiled, start_logits], axis=-1)
            concat_start_na_prob = masked_softmax(self.concat_start_na_logits, self.context_len + 1)
            self.na_prob = tf.squeeze(tf.slice(concat_start_na_prob, [0, 0], [-1, 1]), axis=1)
            self.start_prob = tf.slice(concat_start_na_prob, [0, 1], [-1, -1])
            self.concat_end_na_logits = tf.concat([self.na_bias_tiled, end_logits], axis=-1)
            concat_end_na_prob = masked_softmax(self.concat_end_na_logits, self.context_len + 1)
            self.na_prob2 = tf.squeeze(tf.slice(concat_end_na_prob, [0, 0], [-1, 1]), axis=1)
            self.end_prob = tf.slice(concat_end_na_prob, [0, 1], [-1, -1])
            max_len = tf.reduce_max(self.context_len)
            start_label = tf.cast(tf.one_hot(self.answer_start, max_len), tf.float32)
            start_label = (1.0 - tf.cast(tf.expand_dims(self.na, axis=-1), tf.float32)) * start_label
            na = tf.cast(tf.expand_dims(self.na, axis=-1), tf.float32)
            start_na_label = tf.concat([na, start_label], axis=-1)
            self.start_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
                    logits=mask_logits(self.concat_start_na_logits, self.context_len + 1),
                    labels=start_na_label))
            end_label = tf.cast(tf.one_hot(self.answer_end, max_len), tf.float32)
            end_label = (1.0 - tf.cast(tf.expand_dims(self.na, axis=-1), tf.float32)) * end_label
            end_na_label = tf.concat([na, end_label], axis=-1)
            self.end_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
                    logits=mask_logits(self.concat_end_na_logits, self.context_len + 1),
                    labels=end_na_label))
        else:

            self.start_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(logits=mask_logits(start_logits, self.context_len),
                                                               labels=self.answer_start))
            self.end_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(logits=mask_logits(end_logits, self.context_len),
                                                               labels=self.answer_end))
        self.loss = self.start_loss + self.end_loss
        global_step = tf.train.get_or_create_global_step()
        input_dict = {
            "context_word": self.context_word,
            "context_char": self.context_char,
            "context_len": self.context_len,
            "question_word": self.question_word,
            "question_char": self.question_char,
            "question_len": self.question_len,
            "answer_start": self.answer_start,
            "answer_end": self.answer_end,
            "training": self.training
        }
        if self.use_elmo:
            input_dict['context_tokens'] = self.context_tokens
            input_dict['question_tokens'] = self.question_tokens
        if self.enable_na_answer:
            input_dict["is_impossible"] = self.na
        self.input_placeholder_dict = OrderedDict(input_dict)

        output_dict = {
            "start_prob": self.start_prob,
            "end_prob": self.end_prob
        }
        if self.enable_na_answer:
            output_dict['na_prob'] = self.na_prob * self.na_prob2

        self.output_variable_dict = OrderedDict(output_dict)

        # 8. Metrics and summary
        with tf.variable_scope("train_metrics"):
            self.train_metrics = {
                'loss': tf.metrics.mean(self.loss)
            }

        self.train_update_metrics = tf.group(*[op for _, op in self.train_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="train_metrics")
        self.train_metric_init_op = tf.variables_initializer(metric_variables)

        with tf.variable_scope("eval_metrics"):
            self.eval_metrics = {
                'loss': tf.metrics.mean(self.loss)
            }

        self.eval_update_metrics = tf.group(*[op for _, op in self.eval_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="eval_metrics")
        self.eval_metric_init_op = tf.variables_initializer(metric_variables)

        tf.summary.scalar('loss', self.loss)
        self.summary_op = tf.summary.merge_all()
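
Step 6 of this example condenses the modeled context into a single start-aware vector with weighted_sum, another helper that is not shown in the snippet. Assuming a [batch, length, hidden] sequence and a [batch, length] probability distribution, a minimal version could be:

    import tensorflow as tf

    def weighted_sum(seq, prob):
        # seq: [batch, length, hidden]; prob: [batch, length].
        # Returns the probability-weighted sum over time, shape [batch, hidden].
        return tf.reduce_sum(seq * tf.expand_dims(prob, axis=-1), axis=1)
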
Example 3
    def _build_graph(self):
        self.context_word = tf.placeholder(tf.int32, [None, None])
        self.context_char = tf.placeholder(tf.int32, [None, None, None])
        self.context_len = tf.placeholder(tf.int32, [None])
        self.question_word = tf.placeholder(tf.int32, [None, None])
        self.question_char = tf.placeholder(tf.int32, [None, None, None])
        self.question_len = tf.placeholder(tf.int32, [None])
        self.answer_start = tf.placeholder(tf.int32, [None])
        self.answer_end = tf.placeholder(tf.int32, [None])
        self.training = tf.placeholder(tf.bool, [])

        # 1. Word encoding
        word_embedding = Embedding(
            pretrained_embedding=self.pretrained_word_embedding,
            embedding_shape=(len(self.vocab.get_word_vocab()) + 1,
                             self.word_embedding_size),
            trainable=self.word_embedding_trainable)

        context_word_repr = word_embedding(self.context_word)
        question_word_repr = word_embedding(self.question_word)

        dropout = Dropout(self.keep_prob)

        # 2 inner attention between question word and context word

        inner_att = ProjectedDotProduct(self.rnn_hidden_size,
                                        activation=tf.nn.leaky_relu,
                                        reuse_weight=True)

        inner_score = inner_att(question_word_repr, context_word_repr)

        context_word_softmax = tf.nn.softmax(inner_score, axis=2)

        question_inner_representation = tf.matmul(context_word_softmax,
                                                  context_word_repr)

        question_word_softmax = tf.nn.softmax(inner_score, axis=1)

        context_inner_representation = tf.matmul(question_word_softmax,
                                                 question_word_repr,
                                                 transpose_a=True)

        highway1 = Highway()
        highway2 = Highway()

        context_repr = highway1(
            highway2(
                tf.concat([context_word_repr, context_inner_representation],
                          axis=-1)))
        question_repr = highway1(
            highway2(
                tf.concat([question_word_repr, question_inner_representation],
                          axis=-1)))

        # 2. Phrase encoding
        phrase_lstm = CudnnBiLSTM(self.rnn_hidden_size)
        context_repr, _ = phrase_lstm(dropout(context_repr, self.training),
                                      self.context_len)
        question_repr, _ = phrase_lstm(dropout(question_repr, self.training),
                                       self.question_len)

        # 3. Bi-Attention
        bi_attention = BiAttention(TriLinear())
        c2q, q2c = bi_attention(context_repr, question_repr, self.context_len,
                                self.question_len)

        # 4. Modeling layer
        final_merged_context = tf.concat(
            [context_repr, c2q, context_repr * c2q, context_repr * q2c],
            axis=-1)
        modeling_lstm1 = CudnnBiLSTM(self.rnn_hidden_size)
        modeling_lstm2 = CudnnBiLSTM(self.rnn_hidden_size)
        modeled_context1, _ = modeling_lstm1(
            dropout(final_merged_context, self.training), self.context_len)
        modeled_context2, _ = modeling_lstm2(
            dropout(modeled_context1, self.training), self.context_len)
        modeled_context = modeled_context1 + modeled_context2

        # 5. Start prediction
        start_pred_layer = tf.keras.layers.Dense(1, use_bias=False)
        start_logits = start_pred_layer(
            dropout(
                tf.concat([final_merged_context, modeled_context], axis=-1),
                self.training))
        start_logits = tf.squeeze(start_logits, axis=-1)
        self.start_prob = masked_softmax(start_logits, self.context_len)

        # 6. End prediction
        start_repr = weighted_sum(modeled_context, self.start_prob)
        tiled_start_repr = tf.tile(tf.expand_dims(start_repr, axis=1),
                                   [1, tf.shape(modeled_context)[1], 1])
        end_lstm = CudnnBiLSTM(self.rnn_hidden_size)
        encoded_end_repr, _ = end_lstm(
            dropout(
                tf.concat([
                    final_merged_context, modeled_context, tiled_start_repr,
                    modeled_context * tiled_start_repr
                ],
                          axis=-1), self.training), self.context_len)
        end_pred_layer = tf.keras.layers.Dense(1, use_bias=False)
        end_logits = end_pred_layer(
            dropout(
                tf.concat([final_merged_context, encoded_end_repr], axis=-1),
                self.training))
        end_logits = tf.squeeze(end_logits, axis=-1)
        self.end_prob = masked_softmax(end_logits, self.context_len)

        # 7. Loss and input/output dict
        self.start_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=mask_logits(start_logits, self.context_len),
                labels=self.answer_start))
        self.end_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=mask_logits(end_logits, self.context_len),
                labels=self.answer_end))
        self.loss = (self.start_loss + self.end_loss) / 2
        global_step = tf.train.get_or_create_global_step()

        self.input_placeholder_dict = OrderedDict({
            "context_word": self.context_word,
            "context_char": self.context_char,
            "context_len": self.context_len,
            "question_word": self.question_word,
            "question_char": self.question_char,
            "question_len": self.question_len,
            "answer_start": self.answer_start,
            "answer_end": self.answer_end,
            "training": self.training
        })

        self.output_variable_dict = OrderedDict({
            "start_prob": self.start_prob,
            "end_prob": self.end_prob
        })

        # 8. Metrics and summary
        with tf.variable_scope("train_metrics"):
            self.train_metrics = {'loss': tf.metrics.mean(self.loss)}

        self.train_update_metrics = tf.group(
            *[op for _, op in self.train_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="train_metrics")
        self.train_metric_init_op = tf.variables_initializer(metric_variables)

        with tf.variable_scope("eval_metrics"):
            self.eval_metrics = {'loss': tf.metrics.mean(self.loss)}

        self.eval_update_metrics = tf.group(
            *[op for _, op in self.eval_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="eval_metrics")
        self.eval_metric_init_op = tf.variables_initializer(metric_variables)

        tf.summary.scalar('loss', self.loss)
        self.summary_op = tf.summary.merge_all()
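
All three graphs expose the same interface: input_placeholder_dict maps field names to placeholders and output_variable_dict maps names to the span-probability tensors. A hypothetical evaluation step (the names model, batch, and run_eval_batch are illustrative, not part of the snippets) could wire a data batch into that interface as follows:

    def run_eval_batch(sess, model, batch):
        # `model` is an instance of one of the readers above (after _build_graph);
        # `batch` is assumed to hold a value for every named input except the
        # `training` flag, which is fed explicitly as False for evaluation.
        feed_dict = {ph: batch[name]
                     for name, ph in model.input_placeholder_dict.items()
                     if name != "training"}
        feed_dict[model.input_placeholder_dict["training"]] = False
        outputs = sess.run(model.output_variable_dict, feed_dict=feed_dict)
        return outputs["start_prob"], outputs["end_prob"]

    # Typical use:
    #   with tf.Session() as sess:
    #       sess.run(tf.global_variables_initializer())
    #       start_prob, end_prob = run_eval_batch(sess, model, batch)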