Exemplo n.º 1
0
    def _build_graph(self):
        """Build the TF1 graph: inputs, encoders, span prediction, loss, metrics.

        Context and question inputs are truncated to a maximum length that
        depends on the ``training`` placeholder (400/50 tokens when it is fed
        ``True``, 1000/100 otherwise).  Word and character embeddings are
        merged, passed through a shared embedding encoder, combined by a
        context-question co-attention, and run three times through the same
        stack of model encoder blocks.  The outputs of those passes feed the
        start/end span classifiers.

        Sets on ``self``: the input placeholders, ``start_prob``/``end_prob``,
        ``start_loss``/``end_loss``/``loss``, ``global_step``,
        ``input_placeholder_dict``, ``output_variable_dict``, the train/eval
        metric dicts with their update and init ops, and ``summary_op``.

        NOTE: the sinusoidal position table pairs ``self.filters // 2``
        divisors with ``tf.range(0, self.filters, 2)``, so ``self.filters``
        has to be even.
        """
        # build input
        self.training = tf.placeholder(tf.bool, [])
        self.context_word = tf.placeholder(tf.int32, [None, None])
        # Maximum number of tokens kept: shorter while training.
        self.slice_context_len = tf.where(self.training, 400, 1000)
        self.slice_question_len = tf.where(self.training, 50, 100)
        self.context_char = tf.placeholder(tf.int32, [None, None, None])
        context_char_shape = tf.shape(self.context_char)
        batch_size = context_char_shape[0]
        original_context_len = context_char_shape[1]
        max_context_word_len = context_char_shape[2]

        # Truncate the context along the token axis.
        kept_context_len = tf.minimum(self.slice_context_len, original_context_len)
        slice_context_word = tf.slice(self.context_word, [0, 0],
                                      [batch_size, kept_context_len])
        slice_context_char = tf.slice(self.context_char, [0, 0, 0],
                                      [batch_size, kept_context_len, max_context_word_len])
        self.context_len = tf.placeholder(tf.int32, [None])
        self.question_word = tf.placeholder(tf.int32, [None, None])
        self.question_char = tf.placeholder(tf.int32, [None, None, None])
        question_char_shape = tf.shape(self.question_char)
        original_question_len = question_char_shape[1]
        max_question_word_len = question_char_shape[2]

        # Truncate the question along the token axis.
        kept_question_len = tf.minimum(self.slice_question_len, original_question_len)
        slice_question_word = tf.slice(self.question_word, [0, 0],
                                       [batch_size, kept_question_len])
        slice_question_char = tf.slice(self.question_char, [0, 0, 0],
                                       [batch_size, kept_question_len, max_question_word_len])
        self.question_len = tf.placeholder(tf.int32, [None])
        self.answer_start = tf.placeholder(tf.int32, [None])
        self.answer_end = tf.placeholder(tf.int32, [None])

        max_context_len = tf.shape(slice_context_word)[1]
        max_question_len = tf.shape(slice_question_word)[1]
        # The fed lengths may exceed the truncated width, so clip them.
        slice_context_len = tf.clip_by_value(self.context_len, 0, max_context_len)
        slice_question_len = tf.clip_by_value(self.question_len, 0, max_question_len)
        # Additive masks: 0 at valid positions, -100 at padded ones.
        context_mask = (tf.sequence_mask(slice_context_len, max_context_len, dtype=tf.float32) - 1) * 100
        question_mask = (tf.sequence_mask(slice_question_len, max_question_len, dtype=tf.float32) - 1) * 100

        # Sinusoidal position table.  It is built for the longer of the two
        # sequences so that the question slice below is always long enough,
        # even when a batch's question is longer than its context.
        divisors = tf.pow(tf.constant([10000.0] * (self.filters // 2), dtype=tf.float32),
                          tf.range(0, self.filters, 2, dtype=tf.float32) / self.filters)
        max_position_len = tf.maximum(max_context_len, max_question_len)
        quotients = tf.cast(tf.expand_dims(tf.range(0, max_position_len), -1),
                            tf.float32) / tf.expand_dims(divisors, 0)
        position_table = tf.concat([tf.sin(quotients), tf.cos(quotients)], -1)  # max(CL, QL)*F
        position_repr = position_table[:max_context_len]  # CL*F
        question_position_repr = position_table[:max_question_len]  # QL*F

        # 1. Word encoding
        word_embedding = Embedding(pretrained_embedding=self.pretrained_word_embedding,
                                   embedding_shape=(len(self.vocab.get_word_vocab()) + 1, self.word_embedding_size),
                                   trainable=self.word_embedding_trainable)
        char_embedding = Embedding(embedding_shape=(len(self.vocab.get_char_vocab()) + 1, self.char_embedding_size),
                                   trainable=True, init_scale=0.2)
        dropout = Dropout(self.keep_prob)
        # 1.1 Embedding (the embedding dropouts use fixed rates, not self.keep_prob)
        context_word_embedding = Dropout(0.9)(word_embedding(slice_context_word), self.training)  # B*CL*WD
        context_char_embedding = Dropout(0.95)(char_embedding(slice_context_char), self.training)  # B*CL*WL*CD
        question_word_embedding = Dropout(0.9)(word_embedding(slice_question_word), self.training)  # B*QL*WD
        question_char_embedding = Dropout(0.95)(char_embedding(slice_question_char), self.training)  # B*QL*WL*CD

        # These layers are shared between context and question.
        char_cnn = Conv1DAndMaxPooling(self.char_filters, self.kernel_size1, padding='same', activation=tf.nn.relu)
        embedding_dense = tf.keras.layers.Dense(self.filters)
        highway = Highway(gate_activation=tf.nn.relu, trans_activation=tf.nn.tanh,
                          hidden_units=self.filters, keep_prob=self.keep_prob)

        context_char_repr = tf.reshape(char_cnn(context_char_embedding),
                                       [-1, max_context_len, self.char_filters])  # B*CL*CF
        context_repr = highway(
            dropout(embedding_dense(tf.concat([context_word_embedding, context_char_repr], -1)), self.training),
            self.training)  # B*CL*F

        question_char_repr = tf.reshape(char_cnn(question_char_embedding),
                                        [-1, max_question_len, self.char_filters])  # B*QL*CF
        question_repr = highway(
            dropout(embedding_dense(tf.concat([question_word_embedding, question_char_repr], -1)), self.training),
            self.training)  # B*QL*F

        # 1.2 Embedding Encoder (one block shared by context and question)
        embedding_encoder = EncoderBlock(self.kernel_size1, self.filters, self.conv_layers1, self.heads, self.keep_prob)
        embedding_context = tf.contrib.layers.layer_norm(
            embedding_encoder(context_repr + position_repr, self.training, context_mask),
            begin_norm_axis=-1)  # B*CL*F
        embedding_question = tf.contrib.layers.layer_norm(
            embedding_encoder(question_repr + question_position_repr, self.training, question_mask),
            begin_norm_axis=-1)  # B*QL*F

        # 1.3 co-attention
        co_attention_context = tf.keras.layers.Dense(1)(embedding_context)  # B*CL*1
        co_attention_question = tf.keras.layers.Dense(1)(embedding_question)  # B*QL*1
        # Outer product over positions, kept separate for every feature.
        cq = tf.matmul(tf.expand_dims(tf.transpose(embedding_context, [0, 2, 1]), -1),
                       tf.expand_dims(tf.transpose(embedding_question, [0, 2, 1]), -2))  # B*F*CL*QL
        cq_score = (tf.keras.layers.Dense(1)(tf.transpose(cq, [0, 2, 3, 1]))[:, :, :, 0]
                    + co_attention_context
                    + tf.transpose(co_attention_question, [0, 2, 1]))  # B*CL*QL
        question_similarity = tf.nn.softmax(cq_score + tf.expand_dims(question_mask, -2), 2)  # B*CL*QL
        context_similarity = tf.nn.softmax(cq_score + tf.expand_dims(context_mask, -1), 1)  # B*CL*QL
        cqa = tf.matmul(question_similarity, embedding_question)  # B*CL*F
        qca = tf.matmul(question_similarity,
                        tf.matmul(context_similarity, embedding_context, transpose_a=True))  # B*CL*F
        co_attention_output = dropout(
            tf.keras.layers.Dense(self.filters)(
                tf.concat([embedding_context, cqa, embedding_context * cqa, embedding_context * qca], -1)),
            self.training)  # B*CL*F

        # 1.4 Model Encoder: the same stack of blocks is applied three times
        # in sequence; each pass starts from the previous pass' output.
        model_encoder_blocks = [
            EncoderBlock(self.kernel_size2, self.filters, self.conv_layers2, self.heads, self.keep_prob)
            for _ in range(self.model_encoder_layers)
        ]
        pass_outputs = []
        hidden = co_attention_output
        for _ in range(3):
            for model_encoder_block in model_encoder_blocks:
                hidden = model_encoder_block(hidden + position_repr, self.training, context_mask)
            pass_outputs.append(hidden)
        m0, m1, m2 = pass_outputs
        norm_m0 = tf.contrib.layers.layer_norm(m0, begin_norm_axis=-1)
        norm_m1 = tf.contrib.layers.layer_norm(m1, begin_norm_axis=-1)
        norm_m2 = tf.contrib.layers.layer_norm(m2, begin_norm_axis=-1)

        # predict start
        # NOTE(review): labels are one-hot over the truncated context width;
        # per tf.one_hot, an answer index beyond that width should yield an
        # all-zero label row - confirm this is the intended handling.
        start_logits = tf.keras.layers.Dense(1)(tf.concat([norm_m0, norm_m1], -1))[:, :, 0] + context_mask
        self.start_prob = tf.nn.softmax(start_logits, -1)
        self.start_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
            logits=start_logits, labels=tf.one_hot(self.answer_start, max_context_len)))

        # predict end
        end_logits = tf.keras.layers.Dense(1)(tf.concat([norm_m0, norm_m2], -1))[:, :, 0] + context_mask
        self.end_prob = tf.nn.softmax(end_logits, -1)
        self.end_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
            logits=end_logits, labels=tf.one_hot(self.answer_end, max_context_len)))

        self.loss = self.start_loss + self.end_loss
        self.global_step = tf.train.get_or_create_global_step()

        # Built from pairs so the key order is fixed on every Python version.
        self.input_placeholder_dict = OrderedDict([
            ("context_word", self.context_word),
            ("context_char", self.context_char),
            ("context_len", self.context_len),
            ("question_word", self.question_word),
            ("question_char", self.question_char),
            ("question_len", self.question_len),
            ("answer_start", self.answer_start),
            ("answer_end", self.answer_end),
            ("training", self.training),
        ])

        self.output_variable_dict = OrderedDict([
            ("start_prob", self.start_prob),
            ("end_prob", self.end_prob),
        ])

        # 8. Metrics and summary
        with tf.variable_scope("train_metrics"):
            self.train_metrics = {
                'loss': tf.metrics.mean(self.loss)
            }

        self.train_update_metrics = tf.group(*[op for _, op in self.train_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="train_metrics")
        self.train_metric_init_op = tf.variables_initializer(metric_variables)

        with tf.variable_scope("eval_metrics"):
            self.eval_metrics = {
                'loss': tf.metrics.mean(self.loss)
            }

        self.eval_update_metrics = tf.group(*[op for _, op in self.eval_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="eval_metrics")
        self.eval_metric_init_op = tf.variables_initializer(metric_variables)

        tf.summary.scalar('loss', self.loss)
        self.summary_op = tf.summary.merge_all()
Exemplo n.º 2
0
    def _build_graph(self):
        """Build the TF1 graph for the bi-attention span reader.

        Pipeline, in the order the code runs: word/char embedding (plus
        ELMo features when ``self.use_elmo``), two highway layers, a shared
        phrase BiLSTM, bi-attention between context and question, two
        modeling BiLSTMs, then start and end span prediction.  When
        ``self.enable_na_answer`` is set, a learned bias logit is prepended
        to the start and end logits as a "no answer" slot, and the losses
        are computed over that extended distribution.

        Sets on ``self``: the input placeholders, ``start_prob``/``end_prob``
        (and ``na_prob``/``na_prob2`` in no-answer mode), ``start_loss``/
        ``end_loss``/``loss``, ``input_placeholder_dict``,
        ``output_variable_dict``, the train/eval metric dicts with their
        update and init ops, and ``summary_op``.
        """
        self.context_word = tf.placeholder(tf.int32, [None, None])
        self.context_char = tf.placeholder(tf.int32, [None, None, None])
        self.context_len = tf.placeholder(tf.int32, [None])
        self.question_word = tf.placeholder(tf.int32, [None, None])
        self.question_char = tf.placeholder(tf.int32, [None, None, None])
        self.question_len = tf.placeholder(tf.int32, [None])
        self.answer_start = tf.placeholder(tf.int32, [None])
        self.answer_end = tf.placeholder(tf.int32, [None])
        self.training = tf.placeholder(tf.bool, [])

        # String tokens; only read by the ELMo branch below.
        self.question_tokens = tf.placeholder(tf.string, [None, None])
        self.context_tokens = tf.placeholder(tf.string, [None, None])
        if self.enable_na_answer:
            # Used below as the target for the prepended no-answer slot.
            self.na = tf.placeholder(tf.int32, [None])
        # 1. Word encoding
        word_embedding = Embedding(pretrained_embedding=self.pretrained_word_embedding,
                                   embedding_shape=(len(self.vocab.get_word_vocab()) + 1, self.word_embedding_size),
                                   trainable=self.word_embedding_trainable)
        char_embedding = Embedding(embedding_shape=(len(self.vocab.get_char_vocab()) + 1, self.char_embedding_size),
                                   trainable=True, init_scale=0.2)

        # 1.1 Embedding
        context_word_repr = word_embedding(self.context_word)
        context_char_repr = char_embedding(self.context_char)
        question_word_repr = word_embedding(self.question_word)
        question_char_repr = char_embedding(self.question_char)

        # 1.2 Char convolution
        dropout = Dropout(self.keep_prob)
        conv1d = Conv1DAndMaxPooling(self.char_conv_filters, self.char_conv_kernel_size)
        context_char_repr = dropout(conv1d(context_char_repr), self.training)
        question_char_repr = dropout(conv1d(question_char_repr), self.training)

        # ELMo embedding
        if self.use_elmo:
            elmo_emb = ElmoEmbedding(local_path=self.elmo_local_path)
            context_elmo_repr = elmo_emb(self.context_tokens, self.context_len)
            context_elmo_repr = dropout(context_elmo_repr, self.training)
            question_elmo_repr = elmo_emb(self.question_tokens, self.question_len)
            question_elmo_repr = dropout(question_elmo_repr, self.training)
        # concat word and char
        context_repr = tf.concat([context_word_repr, context_char_repr], axis=-1)
        question_repr = tf.concat([question_word_repr, question_char_repr], axis=-1)
        if self.use_elmo:
            context_repr = tf.concat([context_repr, context_elmo_repr], axis=-1)
            question_repr = tf.concat([question_repr, question_elmo_repr], axis=-1)

        # 1.3 Highway network (both layers shared by context and question)
        highway1 = Highway()
        highway2 = Highway()
        context_repr = highway2(highway1(context_repr))
        question_repr = highway2(highway1(question_repr))

        # 2. Phrase encoding (one BiLSTM shared by context and question)
        phrase_lstm = CudnnBiLSTM(self.rnn_hidden_size)
        context_repr, _ = phrase_lstm(dropout(context_repr, self.training), self.context_len)
        question_repr, _ = phrase_lstm(dropout(question_repr, self.training), self.question_len)

        # 3. Bi-Attention
        bi_attention = BiAttention(TriLinear())
        c2q, q2c = bi_attention(context_repr, question_repr, self.context_len, self.question_len)

        # 4. Modeling layer
        final_merged_context = tf.concat([context_repr, c2q, context_repr * c2q, context_repr * q2c], axis=-1)
        modeling_lstm1 = CudnnBiLSTM(self.rnn_hidden_size)
        modeling_lstm2 = CudnnBiLSTM(self.rnn_hidden_size)
        modeled_context1, _ = modeling_lstm1(dropout(final_merged_context, self.training), self.context_len)
        modeled_context2, _ = modeling_lstm2(dropout(modeled_context1, self.training), self.context_len)
        modeled_context = modeled_context1 + modeled_context2

        # 5. Start prediction
        start_pred_layer = tf.keras.layers.Dense(1, use_bias=False)
        start_logits = start_pred_layer(
            dropout(tf.concat([final_merged_context, modeled_context], axis=-1), self.training))
        start_logits = tf.squeeze(start_logits, axis=-1)
        self.start_prob = masked_softmax(start_logits, self.context_len)

        # 6. End prediction, conditioned on the start distribution computed
        # above (before the no-answer slot is added).
        start_repr = weighted_sum(modeled_context, self.start_prob)
        tiled_start_repr = tf.tile(tf.expand_dims(start_repr, axis=1), [1, tf.shape(modeled_context)[1], 1])
        end_lstm = CudnnBiLSTM(self.rnn_hidden_size)
        encoded_end_repr, _ = end_lstm(dropout(tf.concat(
            [final_merged_context, modeled_context, tiled_start_repr, modeled_context * tiled_start_repr], axis=-1),
            self.training),
            self.context_len)
        end_pred_layer = tf.keras.layers.Dense(1, use_bias=False)
        end_logits = end_pred_layer(dropout(tf.concat(
            [final_merged_context, encoded_end_repr], axis=-1), self.training))
        end_logits = tf.squeeze(end_logits, axis=-1)
        self.end_prob = masked_softmax(end_logits, self.context_len)

        # 7. Loss and input/output dict
        if self.enable_na_answer:
            # One learned scalar, prepended to every row as the no-answer
            # logit; the valid length therefore grows by one.
            self.na_bias = tf.get_variable("na_bias", shape=[1], dtype='float')
            self.na_bias_tiled = tf.tile(tf.reshape(self.na_bias, [1, 1]), [tf.shape(self.context_word)[0], 1])
            self.concat_start_na_logits = tf.concat([self.na_bias_tiled, start_logits], axis=-1)
            concat_start_na_prob = masked_softmax(self.concat_start_na_logits, self.context_len + 1)
            self.na_prob = tf.squeeze(tf.slice(concat_start_na_prob, [0, 0], [-1, 1]), axis=1)
            self.start_prob = tf.slice(concat_start_na_prob, [0, 1], [-1, -1])
            self.concat_end_na_logits = tf.concat([self.na_bias_tiled, end_logits], axis=-1)
            concat_end_na_prob = masked_softmax(self.concat_end_na_logits, self.context_len + 1)
            self.na_prob2 = tf.squeeze(tf.slice(concat_end_na_prob, [0, 0], [-1, 1]), axis=1)
            self.end_prob = tf.slice(concat_end_na_prob, [0, 1], [-1, -1])

            # The label depth must equal the width of the logits.  Taking it
            # from the logits, not from max(context_len), keeps labels and
            # logits aligned when a batch is padded past its longest length.
            max_len = tf.shape(start_logits)[1]
            na = tf.cast(tf.expand_dims(self.na, axis=-1), tf.float32)
            # The span label is zeroed for no-answer rows; `na` fills slot 0.
            start_label = (1.0 - na) * tf.cast(tf.one_hot(self.answer_start, max_len), tf.float32)
            start_na_label = tf.concat([na, start_label], axis=-1)
            self.start_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
                logits=mask_logits(self.concat_start_na_logits, self.context_len + 1),
                labels=start_na_label))
            end_label = (1.0 - na) * tf.cast(tf.one_hot(self.answer_end, max_len), tf.float32)
            end_na_label = tf.concat([na, end_label], axis=-1)
            self.end_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
                logits=mask_logits(self.concat_end_na_logits, self.context_len + 1),
                labels=end_na_label))
        else:
            self.start_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(logits=mask_logits(start_logits, self.context_len),
                                                               labels=self.answer_start))
            self.end_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(logits=mask_logits(end_logits, self.context_len),
                                                               labels=self.answer_end))
        self.loss = self.start_loss + self.end_loss
        # Called for its side effect: ensures the global step exists in the graph.
        tf.train.get_or_create_global_step()

        # Built from pairs so the key order is fixed on every Python version.
        input_dict = OrderedDict([
            ("context_word", self.context_word),
            ("context_char", self.context_char),
            ("context_len", self.context_len),
            ("question_word", self.question_word),
            ("question_char", self.question_char),
            ("question_len", self.question_len),
            ("answer_start", self.answer_start),
            ("answer_end", self.answer_end),
            ("training", self.training),
        ])
        if self.use_elmo:
            input_dict['context_tokens'] = self.context_tokens
            input_dict['question_tokens'] = self.question_tokens
        if self.enable_na_answer:
            input_dict["is_impossible"] = self.na
        self.input_placeholder_dict = input_dict

        output_dict = OrderedDict([
            ("start_prob", self.start_prob),
            ("end_prob", self.end_prob),
        ])
        if self.enable_na_answer:
            # Product of the no-answer probabilities from the start and end
            # distributions.
            output_dict['na_prob'] = self.na_prob * self.na_prob2

        self.output_variable_dict = output_dict

        # 8. Metrics and summary
        with tf.variable_scope("train_metrics"):
            self.train_metrics = {
                'loss': tf.metrics.mean(self.loss)
            }

        self.train_update_metrics = tf.group(*[op for _, op in self.train_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="train_metrics")
        self.train_metric_init_op = tf.variables_initializer(metric_variables)

        with tf.variable_scope("eval_metrics"):
            self.eval_metrics = {
                'loss': tf.metrics.mean(self.loss)
            }

        self.eval_update_metrics = tf.group(*[op for _, op in self.eval_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="eval_metrics")
        self.eval_metric_init_op = tf.variables_initializer(metric_variables)

        tf.summary.scalar('loss', self.loss)
        self.summary_op = tf.summary.merge_all()
Exemplo n.º 3
0
    def _build_graph(self):
        """Assemble the TF1 graph for the word-level span reader.

        Steps, in execution order: word embedding, an inner attention
        between question and context words, two highway layers, a shared
        phrase BiLSTM, bi-attention, two modeling BiLSTMs, and the start and
        end span classifiers.  The loss is the mean of the start and end
        cross-entropies.

        Sets on ``self``: the input placeholders, ``start_prob``/``end_prob``,
        ``start_loss``/``end_loss``/``loss``, ``input_placeholder_dict``,
        ``output_variable_dict``, the train/eval metric dicts with their
        update and init ops, and ``summary_op``.
        """
        # Inputs.  The char placeholders are not used by this graph itself;
        # they are only exposed through the input dict.
        self.context_word = tf.placeholder(tf.int32, [None, None])
        self.context_char = tf.placeholder(tf.int32, [None, None, None])
        self.context_len = tf.placeholder(tf.int32, [None])
        self.question_word = tf.placeholder(tf.int32, [None, None])
        self.question_char = tf.placeholder(tf.int32, [None, None, None])
        self.question_len = tf.placeholder(tf.int32, [None])
        self.answer_start = tf.placeholder(tf.int32, [None])
        self.answer_end = tf.placeholder(tf.int32, [None])
        self.training = tf.placeholder(tf.bool, [])

        # 1. Word encoding
        word_vocab_size = len(self.vocab.get_word_vocab()) + 1
        word_embedder = Embedding(
            pretrained_embedding=self.pretrained_word_embedding,
            embedding_shape=(word_vocab_size, self.word_embedding_size),
            trainable=self.word_embedding_trainable)
        ctx_words = word_embedder(self.context_word)
        qst_words = word_embedder(self.question_word)

        drop = Dropout(self.keep_prob)

        # 2. Inner attention between question words and context words.
        # presumably the score is (batch, question_len, context_len) -
        # TODO confirm against ProjectedDotProduct.
        inner_attention = ProjectedDotProduct(self.rnn_hidden_size,
                                              activation=tf.nn.leaky_relu,
                                              reuse_weight=True)
        inner_scores = inner_attention(qst_words, ctx_words)

        # Normalise over axis 2 to summarise the context for each question word.
        over_context = tf.nn.softmax(inner_scores, axis=2)
        qst_inner = tf.matmul(over_context, ctx_words)

        # Normalise over axis 1 to summarise the question for each context word.
        over_question = tf.nn.softmax(inner_scores, axis=1)
        ctx_inner = tf.matmul(over_question, qst_words, transpose_a=True)

        # Two highway layers shared by both sequences; highway2 runs first.
        highway1 = Highway()
        highway2 = Highway()

        ctx_features = tf.concat([ctx_words, ctx_inner], axis=-1)
        ctx_encoded = highway1(highway2(ctx_features))
        qst_features = tf.concat([qst_words, qst_inner], axis=-1)
        qst_encoded = highway1(highway2(qst_features))

        # 2. Phrase encoding (one BiLSTM shared by context and question)
        phrase_rnn = CudnnBiLSTM(self.rnn_hidden_size)
        ctx_encoded, _ = phrase_rnn(drop(ctx_encoded, self.training),
                                    self.context_len)
        qst_encoded, _ = phrase_rnn(drop(qst_encoded, self.training),
                                    self.question_len)

        # 3. Bi-Attention
        attend = BiAttention(TriLinear())
        c2q, q2c = attend(ctx_encoded, qst_encoded, self.context_len,
                          self.question_len)

        # 4. Modeling layer
        merged = tf.concat(
            [ctx_encoded, c2q, ctx_encoded * c2q, ctx_encoded * q2c],
            axis=-1)
        modeling_rnn_a = CudnnBiLSTM(self.rnn_hidden_size)
        modeling_rnn_b = CudnnBiLSTM(self.rnn_hidden_size)
        modeled_a, _ = modeling_rnn_a(drop(merged, self.training),
                                      self.context_len)
        modeled_b, _ = modeling_rnn_b(drop(modeled_a, self.training),
                                      self.context_len)
        modeled = modeled_a + modeled_b

        # 5. Start prediction
        start_scorer = tf.keras.layers.Dense(1, use_bias=False)
        start_input = tf.concat([merged, modeled], axis=-1)
        start_logits = start_scorer(drop(start_input, self.training))
        start_logits = tf.squeeze(start_logits, axis=-1)
        self.start_prob = masked_softmax(start_logits, self.context_len)

        # 6. End prediction, conditioned on the start distribution.
        start_summary = weighted_sum(modeled, self.start_prob)
        start_tiled = tf.tile(tf.expand_dims(start_summary, axis=1),
                              [1, tf.shape(modeled)[1], 1])
        end_rnn = CudnnBiLSTM(self.rnn_hidden_size)
        end_input = tf.concat(
            [merged, modeled, start_tiled, modeled * start_tiled],
            axis=-1)
        end_encoded, _ = end_rnn(drop(end_input, self.training),
                                 self.context_len)
        end_scorer = tf.keras.layers.Dense(1, use_bias=False)
        end_features = tf.concat([merged, end_encoded], axis=-1)
        end_logits = end_scorer(drop(end_features, self.training))
        end_logits = tf.squeeze(end_logits, axis=-1)
        self.end_prob = masked_softmax(end_logits, self.context_len)

        # 7. Loss and input/output dict
        start_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=mask_logits(start_logits, self.context_len),
            labels=self.answer_start)
        self.start_loss = tf.reduce_mean(start_xent)
        end_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=mask_logits(end_logits, self.context_len),
            labels=self.answer_end)
        self.end_loss = tf.reduce_mean(end_xent)
        self.loss = (self.start_loss + self.end_loss) / 2
        # Called for its side effect: ensures the global step exists in the graph.
        tf.train.get_or_create_global_step()

        placeholders = {
            "context_word": self.context_word,
            "context_char": self.context_char,
            "context_len": self.context_len,
            "question_word": self.question_word,
            "question_char": self.question_char,
            "question_len": self.question_len,
            "answer_start": self.answer_start,
            "answer_end": self.answer_end,
            "training": self.training
        }
        self.input_placeholder_dict = OrderedDict(placeholders)

        outputs = {
            "start_prob": self.start_prob,
            "end_prob": self.end_prob
        }
        self.output_variable_dict = OrderedDict(outputs)

        # 8. Metrics and summary
        def _mean_loss_metric(scope_name):
            # Streaming mean of the loss under its own variable scope, with
            # an op to update it and an op to reset its local variables.
            with tf.variable_scope(scope_name):
                metrics = {'loss': tf.metrics.mean(self.loss)}
            update_op = tf.group(*[op for _, op in metrics.values()])
            local_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                           scope=scope_name)
            return metrics, update_op, tf.variables_initializer(local_vars)

        (self.train_metrics,
         self.train_update_metrics,
         self.train_metric_init_op) = _mean_loss_metric("train_metrics")

        (self.eval_metrics,
         self.eval_update_metrics,
         self.eval_metric_init_op) = _mean_loss_metric("eval_metrics")

        tf.summary.scalar('loss', self.loss)
        self.summary_op = tf.summary.merge_all()