示例#1
0
    def build_graph(self):
        """Builds the main part of the graph, from the input embeddings to the answer-class distribution.

        The context and question are encoded with one shared RNNEncoder, blended
        with the attention variant selected by self.FLAGS.attention_type, projected
        with a fully connected layer (ReLU by default) and fed to a
        CustomSimpleSoftmaxLayer.

        Supported self.FLAGS.attention_type values:
          'dot_product'    -- BasicAttn; blended_reps = [context_hiddens; attn_output].
          'self_attention' -- BasicAttn followed by SelfAttn, then re-encoded by a
                              second RNNEncoder built with scope_name "self_attn_encoder".
          'bidaf'          -- BidafAttn; blended_reps = [c2q; c*c2q; c*q2c], then dropout.

        Defines:
          self.logits_class, self.probdist_class: outputs of
            CustomSimpleSoftmaxLayer.build_graph, which also receives
            self.FLAGS.reduction_type. The original note gives both as shape
            (batch_size, 4) -- TODO confirm against CustomSimpleSoftmaxLayer.

        Raises:
          ValueError: if self.FLAGS.attention_type is not one of the values above.
        """

        # Use a RNN to get hidden states for the context and the question.
        # Note: here the RNNEncoder is shared (i.e. the weights are the same)
        # between the context and the question.
        encoder = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob)
        context_hiddens = encoder.build_graph(self.context_embs, self.context_mask) # (batch_size, context_len, hidden_size*2)
        question_hiddens = encoder.build_graph(self.qn_embs, self.qn_mask) # (batch_size, question_len, hidden_size*2)

        if self.FLAGS.attention_type == 'dot_product':
            print("<<<<<<<< Adding dot_poduct attention >>>")
            attn_layer = BasicAttn(self.keep_prob, self.FLAGS.hidden_size*2, self.FLAGS.hidden_size*2)
            _, attn_output = attn_layer.build_graph(question_hiddens, self.qn_mask, context_hiddens) # attn_output is shape (batch_size, context_len, hidden_size*2)

            # Concat attn_output to context_hiddens to get blended_reps
            blended_reps = tf.concat([context_hiddens, attn_output], axis=2) # (batch_size, context_len, hidden_size*4)

        elif self.FLAGS.attention_type == 'self_attention':
            print("<<<<<<<<< Adding Self attention over basic attention >>>>>>>")
            basic_attn_layer = BasicAttn(self.keep_prob, self.FLAGS.hidden_size*2, self.FLAGS.hidden_size*2)
            _, basic_attn_output = basic_attn_layer.build_graph(question_hiddens, self.qn_mask, context_hiddens) # basic_attn_output is shape (batch_size, context_len, hidden_size*2)

            # Self-attention over the basic attention output. context_hiddens itself
            # is not part of the concatenation in this variant.
            self_attn_layer = SelfAttn(self.keep_prob, self.FLAGS.self_attn_zsize, self.FLAGS.hidden_size*2)
            _, self_attn_output = self_attn_layer.build_graph(basic_attn_output, self.context_mask)
            concated_basic_self = tf.concat([basic_attn_output,self_attn_output], axis=2) #(bs,N,4h)

            # Second encoder, built under its own scope name "self_attn_encoder".
            self_attn_encoder = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob)
            blended_reps = self_attn_encoder.build_graph(concated_basic_self, self.context_mask, scope_name="self_attn_encoder") # (batch_size, N, hidden_size*2)

        elif self.FLAGS.attention_type == 'bidaf':
            print("<<<<<<<<< Adding BIDAF attention >>>>>>>")
            attn_layer = BidafAttn(self.keep_prob, self.FLAGS.hidden_size*2)
            c2q_attention, q2c_attention = attn_layer.build_graph(context_hiddens, question_hiddens, self.qn_mask, self.context_mask)

            # Combine the tensors to get the final output.
            body_c2q_attention_mult = context_hiddens*c2q_attention # (batch_size, num_keys(N), 2h)
            q2c_expanded = tf.expand_dims(q2c_attention, 1) #(bs,1,2h)
            body_q2c_attention_mult = context_hiddens*q2c_expanded # (batch_size, num_keys(N), 2h)
            # context_hiddens itself is not included in this concatenation.
            blended_reps = tf.concat([c2q_attention, body_c2q_attention_mult, body_q2c_attention_mult], axis=2) #(bs,N,6h)
            blended_reps = tf.nn.dropout(blended_reps, self.keep_prob)

        else:
            # Fail fast: otherwise blended_reps would be unbound below.
            raise ValueError(
                "Unsupported attention_type %r (expected 'dot_product', "
                "'self_attention' or 'bidaf')" % (self.FLAGS.attention_type,))

        # Apply fully connected layer to each blended representation
        # Note, blended_reps_final corresponds to b' in the handout
        # Note, tf.contrib.layers.fully_connected applies a ReLU non-linearity here by default
        blended_reps_final = tf.contrib.layers.fully_connected(blended_reps, num_outputs=self.FLAGS.hidden_size) # blended_reps_final is shape (batch_size, context_len, hidden_size)

        with vs.variable_scope("ClassProb"):
            softmax_layer_class = CustomSimpleSoftmaxLayer()

            # Both documented in the original as shape (batch_size, 4) -- TODO confirm.
            self.logits_class, self.probdist_class =  softmax_layer_class.build_graph(blended_reps_final, self.context_mask, self.FLAGS.reduction_type)
示例#2
0
 def computeBlendedReps(self,
                        context_hiddens,
                        question_hiddens,
                        newBaseline=False,
                        AttnModel=BasicAttn):
     """Compute the blended representations used for the start and end predictions.

     Args:
       context_hiddens: context encodings; passed to the attention layer as the
         attending side (presumably (batch_size, context_len, hidden_size*2) -- TODO confirm).
       question_hiddens: question encodings; the side being attended to.
       newBaseline: if truthy, build two independent BasicAttn layers so the start
         and end predictions each get their own blended representation.
         AttnModel is ignored in this mode.
       AttnModel: attention class used when newBaseline is falsy. BasicAttn is
         called as build_graph(question_hiddens, qn_mask, context_hiddens); any
         other class is called as build_graph(question_hiddens, context_hiddens,
         qn_mask, context_mask).

     Returns:
       (blended_reps_start, blended_reps_end): the raw return values of the
       attention layers' build_graph. When newBaseline is falsy both elements are
       the same object.
       NOTE(review): elsewhere BasicAttn.build_graph is unpacked as an
       (attn_dist, output) pair; here the whole return value is used -- confirm
       this BasicAttn returns a single tensor.
     """
     # Assumption: the new baseline always uses BasicAttn, while the old
     # baseline can use either attention model.
     if newBaseline:
         # Use context hidden states to attend to question hidden states, once
         # for the start prediction and once for the end prediction.
         # NOTE(review): if BasicAttn opens a fixed variable scope without reuse,
         # building it twice may clash under TF1 -- confirm its scoping.
         attn_layer_start = BasicAttn(self.keep_prob,
                                      self.FLAGS.hidden_size * 2,
                                      self.FLAGS.hidden_size * 2)
         blended_reps_start = attn_layer_start.build_graph(
             question_hiddens, self.qn_mask, context_hiddens)
         attn_layer_end = BasicAttn(self.keep_prob,
                                    self.FLAGS.hidden_size * 2,
                                    self.FLAGS.hidden_size * 2)
         blended_reps_end = attn_layer_end.build_graph(
             question_hiddens, self.qn_mask, context_hiddens)
     else:
         attn_layer_start = AttnModel(self.keep_prob,
                                      self.FLAGS.hidden_size * 2,
                                      self.FLAGS.hidden_size * 2)
         if AttnModel is BasicAttn:
             blended_reps_start = attn_layer_start.build_graph(
                 question_hiddens, self.qn_mask, context_hiddens)
         else:
             # Other attention models take both masks, in this argument order.
             blended_reps_start = attn_layer_start.build_graph(
                 question_hiddens, context_hiddens, self.qn_mask,
                 self.context_mask)
         # The old baseline shares one representation for start and end.
         blended_reps_end = blended_reps_start
     return blended_reps_start, blended_reps_end
示例#3
0
    def build_graph(self):
        """Build the model graph from the embeddings to the start/end span distributions.

        Pipeline: shared RNNEncoder over context and question -> BasicAttn
        (context attends to question) -> R_Net_Attn over the attention output ->
        fully connected projection of [attention output; R-Net output] -> two
        SimpleSoftmaxLayer heads.

        Defines:
          self.logits_start, self.probdist_start: from the "StartDist" softmax head.
          self.logits_end, self.probdist_end: from the "EndDist" softmax head.
        """
        doubled = self.FLAGS.hidden_size * 2

        # One encoder instance is reused for both sequences.
        shared_rnn = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob)
        ctx_states = shared_rnn.build_graph(self.context_embs, self.context_mask)  # (batch_size, context_len, hidden_size*2)
        qn_states = shared_rnn.build_graph(self.qn_embs, self.qn_mask)  # (batch_size, question_len, hidden_size*2)

        # Context hidden states attend to question hidden states.
        basic_attention = BasicAttn(self.keep_prob, doubled)
        _, c2q_output = basic_attention.build_graph(qn_states, self.qn_mask, ctx_states)  # (batch_size, context_len, hidden_size*2)

        # R-Net style attention applied on top of the question-aware context.
        rnet_attention = R_Net_Attn(self.keep_prob, doubled, self.FLAGS)
        rnet_output = rnet_attention.build_graph(c2q_output, self.context_mask)

        stacked = tf.concat([c2q_output, rnet_output], 2)
        projected = tf.contrib.layers.fully_connected(
            stacked, num_outputs=self.FLAGS.hidden_size
        )  # (batch_size, context_len, hidden_size)

        with vs.variable_scope("StartDist"):
            start_head = SimpleSoftmaxLayer()
            self.logits_start, self.probdist_start = start_head.build_graph(
                projected, self.context_mask)

        with vs.variable_scope("EndDist"):
            end_head = SimpleSoftmaxLayer()
            self.logits_end, self.probdist_end = end_head.build_graph(
                projected, self.context_mask)
示例#4
0
    def build_graph(self):
        """Builds the main part of the graph for the model, starting from the input embeddings to the final distributions for the answer span.

        The attention stage is chosen by self.attention: 'Baseline' (BasicAttn)
        or 'BiDAF' (BiDAFAttn).

        Defines:
          self.logits_start, self.logits_end: Both tensors shape (batch_size, context_len).
            These are the logits (i.e. values that are fed into the softmax function) for the start and end distribution.
            Important: these are -large in the pad locations. Necessary for when we feed into the cross entropy function.
          self.probdist_start, self.probdist_end: Both shape (batch_size, context_len). Each row sums to 1.
            These are the result of taking (masked) softmax of logits_start and logits_end.

        Raises:
          ValueError: if self.attention is neither 'Baseline' nor 'BiDAF'.
        """

        # Use a RNN to get hidden states for the context and the question
        # Note: here the RNNEncoder is shared (i.e. the weights are the same)
        # between the context and the question.
        encoder = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob)
        context_hiddens = encoder.build_graph(self.context_embs, self.context_mask) # (batch_size, context_len, hidden_size*2)
        question_hiddens = encoder.build_graph(self.qn_embs, self.qn_mask) # (batch_size, question_len, hidden_size*2)

        # Use context hidden states to attend to question hidden states
        if self.attention == 'Baseline':
            attn_layer = BasicAttn(self.keep_prob, self.FLAGS.hidden_size*2, self.FLAGS.hidden_size*2)
            _, attn_output = attn_layer.build_graph(question_hiddens, self.qn_mask, context_hiddens) # attn_output is shape (batch_size, context_len, hidden_size*2)

            # Concat attn_output to context_hiddens to get blended_reps
            blended_reps = tf.concat([context_hiddens, attn_output], axis=2) # (batch_size, context_len, hidden_size*4)

        elif self.attention == 'BiDAF':
            attn_layer = BiDAFAttn(self.keep_prob, self.FLAGS.hidden_size*2, self.FLAGS.hidden_size*2)
            # attn_output_c2q is shape (batch_size, context_len, hidden_size*2);
            # attn_output_q2c is shape (batch_size, 1, hidden_size*2) and broadcasts over context_len.
            _, attn_output_c2q, attn_output_q2c= attn_layer.build_graph(question_hiddens, self.qn_mask, context_hiddens)

            # Concat [c; c2q; c*c2q; c*q2c] to get blended_reps
            blended_reps = tf.concat([context_hiddens, attn_output_c2q,context_hiddens*attn_output_c2q,context_hiddens*attn_output_q2c], axis=2) # (batch_size, context_len, hidden_size*8)

        else:
            # Fail fast: otherwise blended_reps would be unbound below.
            raise ValueError(
                "Unsupported attention %r (expected 'Baseline' or 'BiDAF')"
                % (self.attention,))

        # Apply fully connected layer to each blended representation
        # Note, blended_reps_final corresponds to b' in the handout
        # Note, tf.contrib.layers.fully_connected applies a ReLU non-linearity here by default
        blended_reps_final = tf.contrib.layers.fully_connected(blended_reps, num_outputs=self.FLAGS.hidden_size) # blended_reps_final is shape (batch_size, context_len, hidden_size)

        # Use softmax layer to compute probability distribution for start location
        # Note this produces self.logits_start and self.probdist_start, both of which have shape (batch_size, context_len)
        with vs.variable_scope("StartDist"):
            softmax_layer_start = SimpleSoftmaxLayer()
            self.logits_start, self.probdist_start = softmax_layer_start.build_graph(blended_reps_final, self.context_mask)

        # Use softmax layer to compute probability distribution for end location
        # Note this produces self.logits_end and self.probdist_end, both of which have shape (batch_size, context_len)
        with vs.variable_scope("EndDist"):
            softmax_layer_end = SimpleSoftmaxLayer()
            self.logits_end, self.probdist_end = softmax_layer_end.build_graph(blended_reps_final, self.context_mask)
示例#5
0
    def build_graph(self):
        """Wire the baseline graph: embeddings -> encoder -> attention -> start/end softmax heads.

        Sets self.logits_start / self.probdist_start and self.logits_end /
        self.probdist_end from two SimpleSoftmaxLayer heads that share the same
        projected representation.
        """
        units = self.FLAGS.hidden_size

        # This encoder's build_graph returns a pair; only the second element
        # (the per-position hidden states) is used.
        rnn = RNNEncoder(units, self.keep_prob)
        _, ctx_enc = rnn.build_graph(self.context_embs, self.context_mask)  # (batch_size, context_len, hidden_size*2)
        _, qn_enc = rnn.build_graph(self.qn_embs, self.qn_mask)  # (batch_size, question_len, hidden_size*2)

        # Each context position attends over the question encodings.
        attention = BasicAttn(self.keep_prob, units * 2, units * 2)
        _, attended = attention.build_graph(qn_enc, self.qn_mask, ctx_enc)  # (batch_size, context_len, hidden_size*2)

        # [context; attention] projected back down to hidden_size.
        joined = tf.concat([ctx_enc, attended], axis=2)  # (batch_size, context_len, hidden_size*4)
        projected = tf.contrib.layers.fully_connected(joined, num_outputs=units)  # (batch_size, context_len, hidden_size)

        # Two independent softmax heads; the trailing False is forwarded
        # unchanged to SimpleSoftmaxLayer.build_graph (meaning defined there).
        with vs.variable_scope("StartDist"):
            start_head = SimpleSoftmaxLayer()
            self.logits_start, self.probdist_start = start_head.build_graph(projected, self.context_mask, False)

        with vs.variable_scope("EndDist"):
            end_head = SimpleSoftmaxLayer()
            self.logits_end, self.probdist_end = end_head.build_graph(projected, self.context_mask, False)
示例#6
0
    def get_aligned_question_embs(self):
        """Compute DrQA-style aligned question embeddings for every context position.

        A BasicAttn layer over the raw word embeddings gives, for each context
        word, an attention distribution over the question words. The aligned
        embedding is the attention-weighted sum of the question embeddings.
        See DrQA for details.

        This method only computes the aligned embeddings. It does not
        concatenate them onto the context embeddings itself; presumably the
        caller does that (TODO confirm).

        Defines:
          self.bidaf: the attention distribution returned by BasicAttn,
            (batch_size, context_len, question_len).
          self.alignedQ_embs: (batch_size, context_len, embedding_size).
        """
        with vs.variable_scope("add_alignedQ"):
            attn_layer = BasicAttn(self.keep_prob, self.FLAGS.embedding_size,
                                   self.FLAGS.embedding_size)
            # Extra features have not been appended to the *_embs tensors yet.
            attn_dist, _ = attn_layer.build_graph(
                self.qn_embs, self.qn_mask, self.context_embs,
                self.FLAGS.hidden_size)

            self.bidaf = attn_dist

            # attn_dist    : (batch_size, context_len, question_len)
            # self.qn_embs : (batch_size, question_len, embedding_size)
            # A batched matmul computes sum_m attn_dist[b, n, m] * qn_embs[b, m, :]
            # directly, without materialising the (b, N, M, d) broadcast product.
            self.alignedQ_embs = tf.matmul(attn_dist, self.qn_embs)  # (b, N, d)
示例#7
0
    def build_graph(self):
        """Builds the main part of the graph for the model, starting from the input embeddings to the final distributions for the answer span.

        Defines:
          self.logits_start, self.logits_end: Both tensors shape (batch_size, context_len).
            These are the logits (i.e. values that are fed into the softmax function) for the start and end distribution.
            Important: these are -large in the pad locations. Necessary for when we feed into the cross entropy function.
          self.probdist_start, self.probdist_end: Both shape (batch_size, context_len). Each row sums to 1.
            These are the result of taking (masked) softmax of logits_start and logits_end.
        """

        # Use RNNs to get hidden states for the context and the question.
        # Two separate RNNEncoder instances are built, with distinct name arguments
        # "rnnencoder1" and "rnnencoderQ", so unlike a shared encoder the context
        # and question presumably get separate weights (TODO confirm RNNEncoder
        # scopes its variables by that argument).

        encoder = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob)
        encoderQ = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob)
        context_hiddens = encoder.build_graph(self.context_embs, self.context_mask,"rnnencoder1") # (batch_size, context_len, hidden_size*2)
        question_hiddens = encoderQ.build_graph(self.qn_embs, self.qn_mask,"rnnencoderQ") # (batch_size, question_len, hidden_size*2)

        # Use context hidden states to attend to question hidden states.
        # This BasicAttn.build_graph returns three values; the first is discarded.
        attn_layer = BasicAttn(self.keep_prob, self.FLAGS.hidden_size*2, self.FLAGS.hidden_size*2)
        _, attn_output,new_attn = attn_layer.build_graph(question_hiddens, self.qn_mask, context_hiddens,2*self.FLAGS.hidden_size) # attn_output is shape (batch_size, context_len, hidden_size*2)

        # build_graph_middle is called as a plain function with self passed
        # explicitly (presumably a module-level helper); only its third return
        # value is used.
        _,_,blended_reps_final=build_graph_middle(self,new_attn,attn_output,context_hiddens,question_hiddens)

        # Use softmax layer to compute probability distribution for start location
        # Note this produces self.logits_start and self.probdist_start, both of which have shape (batch_size, context_len)
        with vs.variable_scope("StartDist"):
            softmax_layer_start = SimpleSoftmaxLayer()
            self.logits_start, self.probdist_start = softmax_layer_start.build_graph(blended_reps_final, self.context_mask)

        # Use softmax layer to compute probability distribution for end location
        # Note this produces self.logits_end and self.probdist_end, both of which have shape (batch_size, context_len)
        with vs.variable_scope("EndDist"):
            softmax_layer_end = SimpleSoftmaxLayer()
            self.logits_end, self.probdist_end = softmax_layer_end.build_graph(blended_reps_final, self.context_mask)

        

        '''
示例#8
0
    def build_graph(self):
        """Assemble the baseline graph from input embeddings to start/end span distributions.

        Produces:
          self.logits_start, self.logits_end: tensors of shape (batch_size, context_len)
            holding the pre-softmax scores for the span start and end. Pad positions
            hold large negative values, which the cross-entropy loss relies on.
          self.probdist_start, self.probdist_end: (batch_size, context_len) masked
            softmax of the corresponding logits; every row sums to 1.
        """
        hidden = self.FLAGS.hidden_size

        # A single encoder instance serves both sequences, so context and
        # question share RNN weights.
        rnn = RNNEncoder(hidden, self.keep_prob)
        ctx_enc = rnn.build_graph(self.context_embs, self.context_mask)  # (batch_size, context_len, hidden_size*2)
        qn_enc = rnn.build_graph(self.qn_embs, self.qn_mask)  # (batch_size, question_len, hidden_size*2)

        # Each context position attends over the question encodings.
        attention = BasicAttn(self.keep_prob, hidden * 2, hidden * 2)
        _, attended = attention.build_graph(qn_enc, self.qn_mask, ctx_enc)  # (batch_size, context_len, hidden_size*2)

        # b' from the handout: [context; attention] projected to hidden_size
        # (tf.contrib.layers.fully_connected uses a ReLU by default).
        merged = tf.concat([ctx_enc, attended], axis=2)  # (batch_size, context_len, hidden_size*4)
        projected = tf.contrib.layers.fully_connected(merged, num_outputs=hidden)  # (batch_size, context_len, hidden_size)

        # Two independent softmax heads over the same representation.
        with vs.variable_scope("StartDist"):
            start_head = SimpleSoftmaxLayer()
            self.logits_start, self.probdist_start = start_head.build_graph(projected, self.context_mask)

        with vs.variable_scope("EndDist"):
            end_head = SimpleSoftmaxLayer()
            self.logits_end, self.probdist_end = end_head.build_graph(projected, self.context_mask)
示例#9
0
文件: qa_model.py 项目: shpda/squad
    def build_graph(self):
        """Build the full QA graph (encoders, stacked attentions, output layer).

        Stages, in graph-construction order:
          1. Encoding: optionally RNNEncoder0 (when FLAGS.more_single_dir_rnn),
             then RNNEncoder1. Each encoder instance is reused for context and question.
          2. BasicAttn (context attends to question), concatenated with the
             context encoding; optionally re-encoded by RNNBasicAttn.
          3. DotAttn over that representation, concatenated with its input;
             optionally gated by GatedReps (FLAGS.gated_reps); then RNNDotAttn.
          4. Output: an answer-pointer pair (FLAGS.use_answer_pointer) or a
             fully connected layer plus two SimpleSoftmaxLayer heads.

        Defines:
          self.logits_start, self.logits_end: Both tensors shape (batch_size, context_len).
            Logits for the start and end distributions; -large in pad locations
            so they can be fed to the cross entropy function.
          self.probdist_start, self.probdist_end: Both shape (batch_size, context_len).
            Masked softmax of the logits; each row sums to 1.
        """
        flags = self.FLAGS
        h = flags.hidden_size
        extra_rnns = flags.more_single_dir_rnn

        # --- Stage 1: encoding -------------------------------------------------
        ctx_input, qn_input = self.context_embs, self.qn_embs
        if extra_rnns:
            first_encoder = RNNEncoder0(h, self.keep_prob)
            ctx_input = first_encoder.build_graph(ctx_input, self.context_mask)  # (batch_size, context_len, hidden_size)
            qn_input = first_encoder.build_graph(qn_input, self.qn_mask)  # (batch_size, question_len, hidden_size)
        second_encoder = RNNEncoder1(h, self.keep_prob)
        ctx_enc = second_encoder.build_graph(ctx_input, self.context_mask)  # (batch_size, context_len, hidden_size*2)
        qn_enc = second_encoder.build_graph(qn_input, self.qn_mask)  # (batch_size, question_len, hidden_size*2)

        # --- Stage 2: question-to-context attention ----------------------------
        basic_attention = BasicAttn(self.keep_prob, h * 2, h * 2, flags.advanced_basic_attn)
        _, c2q = basic_attention.build_graph(qn_enc, self.qn_mask, ctx_enc)  # (batch_size, context_len, hidden_size*2)
        question_aware = tf.concat([ctx_enc, c2q], axis=2)  # (batch_size, context_len, hidden_size*4)
        if extra_rnns:
            question_aware = RNNBasicAttn(h * 4, self.keep_prob).build_graph(
                question_aware, self.context_mask)  # (batch_size, context_len, hidden_size*4)

        # --- Stage 3: dot attention over the question-aware passage ------------
        dot_attention = DotAttn(self.keep_prob, h * 4, flags.advanced_dot_attn)
        _, self_matched = dot_attention.build_graph(question_aware, self.context_mask)  # (batch_size, context_len, hidden_size*4)
        # Input of DotAttn concatenated with its output.
        matched = tf.concat([question_aware, self_matched], axis=2)  # (batch_size, context_len, hidden_size*8)
        if flags.gated_reps:
            matched = GatedReps(h * 8).build_graph(matched)
        passage = RNNDotAttn(h * 8, self.keep_prob).build_graph(matched, self.context_mask)  # (batch_size, context_len, hidden_size*16)

        # --- Stage 4a: answer pointer output ------------------------------------
        if flags.use_answer_pointer:
            start_pointer = AnswerPointerLayerStart(self.keep_prob, h, h * 16)
            rq, self.logits_start, self.probdist_start = start_pointer.build_graph(
                qn_enc, self.qn_mask, passage, self.context_mask)

            # The end pointer is conditioned on the start distribution and rq.
            end_pointer = AnswerPointerLayerEnd(self.keep_prob, h * 16)
            self.logits_end, self.probdist_end = end_pointer.build_graph(
                self.probdist_start, rq, passage, self.context_mask)
            return

        # --- Stage 4b: projection + softmax heads ------------------------------
        # tf.contrib.layers.fully_connected applies a ReLU by default.
        projected = tf.contrib.layers.fully_connected(passage, num_outputs=h)  # (batch_size, context_len, hidden_size)

        with vs.variable_scope("StartDist"):
            start_head = SimpleSoftmaxLayer()
            self.logits_start, self.probdist_start = start_head.build_graph(projected, self.context_mask)

        with vs.variable_scope("EndDist"):
            end_head = SimpleSoftmaxLayer()
            self.logits_end, self.probdist_end = end_head.build_graph(projected, self.context_mask)
示例#10
0
文件: qa_model.py 项目: xuwd11/QANet
    def build_graph(self):
        """Builds the main part of the graph for the model, starting from the input embeddings to the final distributions for the answer span.

        Defines:
          self.logits_start, self.logits_end: Both tensors shape (batch_size, context_len).
            These are the logits (i.e. values that are fed into the softmax function) for the start and end distribution.
            Important: these are -large in the pad locations. Necessary for when we feed into the cross entropy function.
          self.probdist_start, self.probdist_end: Both shape (batch_size, context_len). Each row sums to 1.
            These are the result of taking (masked) softmax of logits_start and logits_end.
        """

        # Use a RNN to get hidden states for the context and the question
        # Note: here the RNNEncoder is shared (i.e. the weights are the same)
        # between the context and the question.
        if self.FLAGS.cell_type in ['rnn_gru', 'rnn_lstm']:
            encoder = RNNEncoder(self.FLAGS.hidden_size,
                                 self.keep_prob,
                                 cell_type=self.FLAGS.cell_type)
            context_hiddens = encoder.build_graph(
                self.context_embs,
                self.context_mask)  # (batch_size, context_len, hidden_size*2)
            question_hiddens = encoder.build_graph(
                self.qn_embs,
                self.qn_mask)  # (batch_size, question_len, hidden_size*2)
        elif self.FLAGS.cell_type == 'qanet':
            encoder = QAEncoder(num_blocks=self.FLAGS.emb_num_blocks, num_layers=self.FLAGS.emb_num_layers, \
                                num_heads=self.FLAGS.emb_num_heads, \
                                filters=self.FLAGS.hidden_size, kernel_size=self.FLAGS.emb_kernel_size, \
                                keep_prob=self.keep_prob, input_mapping=True)
            context_hiddens = encoder.build_graph(self.context_embs,
                                                  self.context_mask)
            question_hiddens = encoder.build_graph(self.qn_embs, self.qn_mask)

        if self.FLAGS.attention == 'basic':
            # Use context hidden states to attend to question hidden states
            attn_layer = BasicAttn(self.keep_prob, self.FLAGS.hidden_size * 2,
                                   self.FLAGS.hidden_size * 2)
            _, attn_output = attn_layer.build_graph(
                question_hiddens, self.qn_mask, context_hiddens
            )  # attn_output is shape (batch_size, context_len, hidden_size*2)

            # Concat attn_output to context_hiddens to get blended_reps
            blended_reps = tf.concat(
                [context_hiddens, attn_output],
                axis=2)  # (batch_size, context_len, hidden_size*4)

        elif self.FLAGS.attention == 'bidaf':
            attn_layer = BiDAFAttn(self.keep_prob)
            blended_reps = attn_layer.build_graph(context_hiddens,
                                                  self.context_mask,
                                                  question_hiddens,
                                                  self.qn_mask)

        if self.FLAGS.modeling_layer == 'basic':
            # Apply fully connected layer to each blended representation
            # Note, blended_reps_final corresponds to b' in the handout
            # Note, tf.contrib.layers.fully_connected applies a ReLU non-linarity here by default
            blended_reps_final = tf.contrib.layers.fully_connected(
                blended_reps,
                num_outputs=self.FLAGS.hidden_size,
                weights_initializer=initializer_relu()
            )  # blended_reps_final is shape (batch_size, context_len, hidden_size)

            # Use softmax layer to compute probability distribution for start location
            # Note this produces self.logits_start and self.probdist_start, both of which have shape (batch_size, context_len)
            with tf.variable_scope("StartDist"):
                softmax_layer_start = SimpleSoftmaxLayer()
                self.logits_start, self.probdist_start = softmax_layer_start.build_graph(
                    blended_reps_final, self.context_mask)

            # Use softmax layer to compute probability distribution for end location
            # Note this produces self.logits_end and self.probdist_end, both of which have shape (batch_size, context_len)
            with tf.variable_scope("EndDist"):
                softmax_layer_end = SimpleSoftmaxLayer()
                self.logits_end, self.probdist_end = softmax_layer_end.build_graph(
                    blended_reps_final, self.context_mask)

        elif self.FLAGS.modeling_layer == 'rnn':
            encoder_start = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob, \
                                       cell_type=self.FLAGS.cell_type, name='m1')
            m1 = encoder_start.build_graph(blended_reps, self.context_mask)
            encoder_end = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob, \
                                     cell_type=self.FLAGS.cell_type, name='m2')
            m2 = encoder_end.build_graph(m1, self.context_mask)
            with tf.variable_scope("StartDist"):
                softmax_layer_start = SimpleSoftmaxLayer()
                self.logits_start, self.probdist_start = softmax_layer_start.build_graph(
                    tf.concat([blended_reps, m1], -1), self.context_mask)
            with tf.variable_scope("EndDist"):
                softmax_layer_end = SimpleSoftmaxLayer()
                self.logits_end, self.probdist_end = softmax_layer_end.build_graph(
                    tf.concat([blended_reps, m2], -1), self.context_mask)

        elif self.FLAGS.modeling_layer == 'qanet':
            modeling_encoder = QAEncoder(num_blocks=self.FLAGS.model_num_blocks, \
                                         num_layers=self.FLAGS.model_num_layers, \
                                         num_heads=self.FLAGS.model_num_heads, \
                                         filters=self.FLAGS.hidden_size, \
                                         kernel_size=self.FLAGS.model_kernel_size, \
                                         keep_prob=self.keep_prob, input_mapping=False, \
                                         name='modeling_encoder')
            m0 = tf.layers.conv1d(blended_reps, filters=self.FLAGS.hidden_size, \
                                  kernel_size=1, padding='SAME', name='attn_mapping')
            m1 = modeling_encoder.build_graph(m0, self.context_mask)
            m2 = modeling_encoder.build_graph(m1, self.context_mask)
            m3 = modeling_encoder.build_graph(m2, self.context_mask)
            with tf.variable_scope("StartDist"):
                softmax_layer_start = SimpleSoftmaxLayer()
                self.logits_start, self.probdist_start = softmax_layer_start.build_graph(
                    tf.concat([m1, m2], -1), self.context_mask)
            with tf.variable_scope("EndDist"):
                softmax_layer_end = SimpleSoftmaxLayer()
                self.logits_end, self.probdist_end = softmax_layer_end.build_graph(
                    tf.concat([m1, m3], -1), self.context_mask)

        elif self.FLAGS.modeling_layer == 'qanet2':
            modeling_encoder1 = QAEncoder(num_blocks=self.FLAGS.model_num_blocks, \
                                          num_layers=self.FLAGS.model_num_layers, \
                                          num_heads=self.FLAGS.model_num_heads, \
                                          filters=self.FLAGS.hidden_size, \
                                          kernel_size=self.FLAGS.model_kernel_size, \
                                          keep_prob=self.keep_prob, input_mapping=False, \
                                          name='modeling_encoder1')
            '''
            modeling_encoder2 = QAEncoder(num_blocks=self.FLAGS.model_num_blocks, \
                                          num_layers=self.FLAGS.model_num_layers, \
                                          num_heads=self.FLAGS.model_num_heads, \
                                          filters=self.FLAGS.hidden_size, \
                                          kernel_size=self.FLAGS.model_kernel_size, \
                                          keep_prob=self.keep_prob, input_mapping=False, \
                                          name='modeling_encoder2')
            '''
            m0 = tf.layers.conv1d(blended_reps, filters=self.FLAGS.hidden_size, \
                                  kernel_size=1, padding='SAME', name='attn_mapping')
            m1 = modeling_encoder1.build_graph(m0, self.context_mask)
            m2 = modeling_encoder1.build_graph(m1, self.context_mask)
            with tf.variable_scope("StartDist"):
                softmax_layer_start = SimpleSoftmaxLayer()
                self.logits_start, self.probdist_start = softmax_layer_start.build_graph(
                    tf.concat([m0, m1], -1), self.context_mask)
            with tf.variable_scope("EndDist"):
                softmax_layer_end = SimpleSoftmaxLayer()
                self.logits_end, self.probdist_end = softmax_layer_end.build_graph(
                    tf.concat([m0, m2], -1), self.context_mask)
    def build_graph(self):
        """Builds the main part of the graph for the model, starting from the input embeddings to the final distributions for the answer span.

        The architecture is selected by self.FLAGS.model, which must be one of:
          "baseline", "bidaf", "bidaf_self_attn", "bidaf_dynamic", "bidaf_dynamic_self_attn",
          "coatt", "coatt_dynamic", "coatt_dynamic_self_attn".

        Defines:
          self.logits_start, self.logits_end: Both tensors shape (batch_size, context_len).
            These are the logits (i.e. values that are fed into the softmax function) for the start and end distribution.
            Important: these are -large in the pad locations. Necessary for when we feed into the cross entropy function.
          self.probdist_start, self.probdist_end: Both shape (batch_size, context_len). Each row sums to 1.
            These are the result of taking (masked) softmax of logits_start and logits_end.
          self.alpha_logits, self.beta_logits: only for the "*_dynamic*" models; tuples holding the
            masked start / end logits of every answer-decoder iteration.

        Raises:
          ValueError: if self.FLAGS.model is not one of the supported model names.
        """
        model = self.FLAGS.model
        hidden_size = self.FLAGS.hidden_size
        bidaf_models = ("bidaf", "bidaf_dynamic", "bidaf_self_attn", "bidaf_dynamic_self_attn")
        coatt_models = ("coatt", "coatt_dynamic", "coatt_dynamic_self_attn")

        # Fail fast: an unknown name would otherwise surface as an unbound `encoder` below.
        if model != "baseline" and model not in bidaf_models and model not in coatt_models:
            raise ValueError("Unsupported model type: %r" % (model,))

        # Pick the encoder. For baseline/BiDAF models the same encoder object (shared weights)
        # encodes both context and question; co-attention models encode them jointly via
        # build_graph1 inside their own branch below.
        if model == "baseline":
            encoder = RNNEncoder(hidden_size, self.keep_prob)
        elif model in bidaf_models:
            print("INSIDE the BIDAF model")
            encoder = RNNEncoder_LSTM(hidden_size, self.keep_prob)
        else:
            encoder = LSTMEncoder(hidden_size, self.keep_prob)

        if model not in coatt_models:
            context_hiddens = encoder.build_graph(self.context_embs, self.context_mask)  # (batch_size, context_len, hidden_size*2)
            question_hiddens = encoder.build_graph(self.qn_embs, self.qn_mask)  # (batch_size, question_len, hidden_size*2)

        if model == "baseline":
            # Use context hidden states to attend to question hidden states.
            attn_layer = BasicAttn(self.keep_prob, hidden_size * 2, hidden_size * 2)
            _, attn_output = attn_layer.build_graph(question_hiddens, self.qn_mask, context_hiddens)  # (batch_size, context_len, hidden_size*2)
            blended_reps = tf.concat([context_hiddens, attn_output], axis=2)  # (batch_size, context_len, hidden_size*4)
            # blended_reps_final corresponds to b' in the handout;
            # tf.contrib.layers.fully_connected applies a ReLU non-linearity by default.
            blended_reps_final = tf.contrib.layers.fully_connected(blended_reps, num_outputs=hidden_size)  # (batch_size, context_len, hidden_size)

            with vs.variable_scope("StartDist"):
                softmax_layer_start = SimpleSoftmaxLayer()
                self.logits_start, self.probdist_start = softmax_layer_start.build_graph(blended_reps_final, self.context_mask)
            with vs.variable_scope("EndDist"):
                softmax_layer_end = SimpleSoftmaxLayer()
                self.logits_end, self.probdist_end = softmax_layer_end.build_graph(blended_reps_final, self.context_mask)

        elif model == "coatt":
            context_hiddens, question_hiddens = encoder.build_graph1(self.context_embs, self.qn_embs, self.context_mask, self.qn_mask)
            attn_layer = CoAttention(self.keep_prob, hidden_size * 2, hidden_size * 2)
            attn_output = attn_layer.build_graph(question_hiddens, self.qn_mask, context_hiddens, self.context_mask)

            # Start distribution directly from the co-attention encoding.
            with vs.variable_scope("StartDist"):
                softmax_layer_start = SimpleSoftmaxLayer()
                self.logits_start, self.probdist_start = softmax_layer_start.build_graph(attn_output, self.context_mask)

            # End distribution from a bidirectional LSTM run over the co-attention encoding.
            with vs.variable_scope("EndDist"):
                # Per-example sequence length; assumes context_mask is 1 on real tokens, 0 on padding.
                contextLen = tf.reduce_sum(self.context_mask, axis=1)
                # NOTE(review): the same cell object is passed for both directions, which with
                # layer-based TF RNN cells appears to share weights between the forward and backward
                # passes. Confirm this is intended; splitting it changes the checkpoint variables.
                cell = tf.contrib.rnn.LSTMBlockCell(2 * hidden_size)
                (fw_out, bw_out), _ = tf.nn.bidirectional_dynamic_rnn(cell, cell, attn_output, contextLen, dtype=tf.float32)
                U_1 = tf.concat([fw_out, bw_out], axis=2)
                out = tf.nn.dropout(U_1, self.keep_prob)
                softmax_layer_end = SimpleSoftmaxLayer()
                self.logits_end, self.probdist_end = softmax_layer_end.build_graph(out, self.context_mask)

        elif model in ("bidaf", "bidaf_self_attn"):
            attn_output, mod_layer_out = self._bidaf_attention_and_modeling(
                context_hiddens, question_hiddens, use_self_attn=(model == "bidaf_self_attn"))
            # Start features: attention output concatenated with the modeling-layer output.
            blended_reps_start = tf.concat([attn_output, mod_layer_out], axis=2)

            with vs.variable_scope("StartDist"):
                softmax_layer_start = SimpleSoftmaxLayer()
                self.logits_start, self.probdist_start = softmax_layer_start.build_graph(blended_reps_start, self.context_mask)

            with vs.variable_scope("EndDist"):
                # Condition the end-word layer on the start logits by appending them as one extra feature.
                logits_start_expand = tf.expand_dims(self.logits_start, axis=2)  # (batch_size, context_len, 1)
                end_lstm_input = tf.concat([logits_start_expand, mod_layer_out], axis=2)
                end_layer = END_WORD_LAYER(hidden_size, self.keep_prob)
                blended_reps_end = end_layer.build_graph(end_lstm_input, self.context_mask)
                blended_reps_end_final = tf.concat([attn_output, blended_reps_end], axis=2)
                softmax_layer_end = SimpleSoftmaxLayer()
                self.logits_end, self.probdist_end = softmax_layer_end.build_graph(blended_reps_end_final, self.context_mask)

        elif model in ("bidaf_dynamic", "bidaf_dynamic_self_attn"):
            _, mod_layer_out = self._bidaf_attention_and_modeling(
                context_hiddens, question_hiddens, use_self_attn=(model == "bidaf_dynamic_self_attn"))
            self._build_dynamic_decoder_dists(mod_layer_out)

        else:  # "coatt_dynamic" or "coatt_dynamic_self_attn"
            context_hiddens, question_hiddens = encoder.build_graph1(self.context_embs, self.qn_embs, self.context_mask, self.qn_mask)
            attn_layer = CoAttention(self.keep_prob, hidden_size * 2, hidden_size * 2)
            if model == "coatt_dynamic_self_attn":
                coatt_output = attn_layer.build_graph1(question_hiddens, self.qn_mask, context_hiddens, self.context_mask)
                attn_output = self._append_self_attention(coatt_output)
            else:
                attn_output = attn_layer.build_graph(question_hiddens, self.qn_mask, context_hiddens, self.context_mask)
            self._build_dynamic_decoder_dists(attn_output)

    def _append_self_attention(self, reps):
        """Concatenate `reps` with a SelfAttn pass over it, along the feature axis (axis 2).

        SelfAttn is built with key/value sizes hidden_size*8 and masked with self.context_mask,
        so `reps` is presumably (batch_size, context_len, 8*hidden_size) and the result
        (batch_size, context_len, 16*hidden_size) -- TODO confirm against SelfAttn.
        """
        self_attn_layer = SelfAttn(self.keep_prob, self.FLAGS.hidden_size * 8, self.FLAGS.hidden_size * 8)
        _, self_attn_output = self_attn_layer.build_graph(reps, self.context_mask)
        return tf.concat([reps, self_attn_output], axis=2)

    def _bidaf_attention_and_modeling(self, context_hiddens, question_hiddens, use_self_attn):
        """Run BiDAF attention (optionally followed by self-attention) and the BiDAF modeling layer.

        Returns:
          (attn_output, mod_layer_out): the query-aware context features (optionally with
          self-attention features appended) and the modeling-layer re-encoding of them.
        """
        attn_layer = BiDafAttn(self.keep_prob, self.FLAGS.hidden_size * 2, self.FLAGS.hidden_size * 2)
        attn_output = attn_layer.build_graph(question_hiddens, self.qn_mask, context_hiddens, self.context_mask)
        if use_self_attn:
            attn_output = self._append_self_attention(attn_output)
        # The BiDAF modeling layer (described as a 2-layer LSTM) re-encodes the attention features.
        mod_layer = MODEL_LAYER_BIDAF(self.FLAGS.hidden_size, self.keep_prob)
        mod_layer_out = mod_layer.build_graph(attn_output, self.context_mask)
        return attn_output, mod_layer_out

    def _build_dynamic_decoder_dists(self, decoder_inputs):
        """Run ANSWER_DECODER over `decoder_inputs` and define the start/end outputs.

        Both decoder states are initialised from the first context position of `decoder_inputs`.
        Every iteration's start/end logits are passed through masked_softmax with
        self.context_mask and stored in self.alpha_logits / self.beta_logits; the entries at
        index num_iterations - 1 become logits_start/probdist_start and logits_end/probdist_end.
        """
        decoder = ANSWER_DECODER(self.FLAGS.hidden_size, self.keep_prob, self.FLAGS.num_iterations,
                                 self.FLAGS.max_pool, self.FLAGS.batch_size)
        u_s_init = decoder_inputs[:, 0, :]
        u_e_init = decoder_inputs[:, 0, :]
        # The decoder's predicted start/end locations are not needed here.
        _, _, alpha_logits, beta_logits = decoder.build_graph(decoder_inputs, self.context_mask, u_s_init, u_e_init)
        last = self.FLAGS.num_iterations - 1

        with vs.variable_scope("StartDist"):
            self.alpha_logits, alpha_probs = zip(*[masked_softmax(logits, self.context_mask, 1) for logits in alpha_logits])
            self.logits_start, self.probdist_start = self.alpha_logits[last], alpha_probs[last]

        with vs.variable_scope("EndDist"):
            self.beta_logits, beta_probs = zip(*[masked_softmax(logits, self.context_mask, 1) for logits in beta_logits])
            self.logits_end, self.probdist_end = self.beta_logits[last], beta_probs[last]
示例#12
0
    def build_graph(self):
        """Builds the main part of the graph for the model, starting from the input embeddings to the final distributions for the answer span.

        Which sub-networks are built is controlled by boolean flags on self.FLAGS:
          simple_attention, co_attention, self_attention: each builds its own subgraph and assigns
            blended_reps_final; when several are set they run in that order and the last one wins.
          answer_pointer, simple_softmax: output layers; when both are set the simple softmax
            outputs overwrite the answer-pointer outputs.

        Defines:
          self.logits_start, self.logits_end: Both tensors shape (batch_size, context_len).
            These are the logits (i.e. values that are fed into the softmax function) for the start and end distribution.
            Important: these are -large in the pad locations. Necessary for when we feed into the cross entropy function.
          self.probdist_start, self.probdist_end: Both shape (batch_size, context_len). Each row sums to 1.
            These are the result of taking (masked) softmax of logits_start and logits_end.

        Raises:
          ValueError: if no attention flag is set (blended_reps_final would be undefined) or no
            output-layer flag is set (the logits/probability attributes would never be defined).
        """
        if not (self.FLAGS.simple_attention or self.FLAGS.co_attention or self.FLAGS.self_attention):
            raise ValueError(
                "build_graph requires at least one of simple_attention, co_attention or self_attention")
        if not (self.FLAGS.answer_pointer or self.FLAGS.simple_softmax):
            raise ValueError(
                "build_graph requires at least one of answer_pointer or simple_softmax")

        # Use a RNN to get hidden states for the context and the question
        # Note: here the RNNEncoder is shared (i.e. the weights are the same)
        # between the context and the question.
        if self.FLAGS.self_attention:
            encoder = RNNEncoder(self.FLAGS.hidden_size_encoder,
                                 self.keep_prob)
        else:
            encoder = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob)

        context_hiddens = encoder.build_graph(
            self.context_embs,
            self.context_mask)  # (batch_size, context_len, hidden_size*2)
        question_hiddens = encoder.build_graph(
            self.qn_embs,
            self.qn_mask)  # (batch_size, question_len, hidden_size*2)

        # Use context hidden states to attend to question hidden states
        if self.FLAGS.simple_attention:
            attn_layer = BasicAttn(self.keep_prob, self.FLAGS.hidden_size * 2,
                                   self.FLAGS.hidden_size * 2)
            _, attn_output = attn_layer.build_graph(
                question_hiddens, self.qn_mask, context_hiddens
            )  # attn_output is shape (batch_size, context_len, hidden_size*2)

            # Concat attn_output to context_hiddens to get blended_reps
            blended_reps = tf.concat(
                [context_hiddens, attn_output],
                axis=2)  # (batch_size, context_len, hidden_size*4)

            # Apply fully connected layer to each blended representation
            # Note, blended_reps_final corresponds to b' in the handout
            # Note, tf.contrib.layers.fully_connected applies a ReLU non-linearity here by default
            blended_reps_final = tf.contrib.layers.fully_connected(
                blended_reps, num_outputs=self.FLAGS.hidden_size
            )  # blended_reps_final is shape (batch_size, context_len, hidden_size)

        if self.FLAGS.co_attention:
            # Pass the question encodings through a tanh fully-connected layer (applied along the
            # question_len axis) to allow variation between question and document encoding space.
            question_hiddens_t = tf.transpose(
                question_hiddens,
                perm=[0, 2, 1])  # (batch_size, hidden_size*2, question_len)
            trans_question_hiddens_t = tf.contrib.layers.fully_connected(
                question_hiddens_t,
                num_outputs=self.FLAGS.question_len,
                activation_fn=tf.nn.tanh
            )  # (batch_size, hidden_size*2, question_len)
            trans_question_hiddens = tf.transpose(
                trans_question_hiddens_t,
                perm=[0, 2, 1])  # (batch_size, question_len, hidden_size*2)

            # Compute the co-attention context.
            co_attn_layer = CoAttn(self.keep_prob, self.FLAGS.hidden_size * 2,
                                   self.FLAGS.hidden_size * 2)
            co_attn_output = co_attn_layer.build_graph(
                trans_question_hiddens, self.qn_mask, self.context_mask,
                context_hiddens)  # (batch_size, context_len, 6*hidden_size)

            # Fuse temporal information into the co-attention context with an LSTMEncoder
            # (presumably bidirectional, giving hidden_size*2 features -- confirm).
            with tf.variable_scope("co-attn-encoder"):
                co_attn_encoder = LSTMEncoder(self.FLAGS.hidden_size,
                                              self.keep_prob)
                blended_reps_final = co_attn_encoder.build_graph(
                    co_attn_output, self.context_mask)

        if self.FLAGS.self_attention:
            # Implementation of the self-attention of the R-Net paper.
            self_attention_encoder = SelfAttn(self.FLAGS.hidden_size_encoder,
                                              self.FLAGS.hidden_size_qp,
                                              self.FLAGS.hidden_size_pp,
                                              self.keep_prob)
            v_p = self_attention_encoder.build_graph_qp(
                context_hiddens, question_hiddens, self.context_mask,
                self.qn_mask, self.FLAGS.context_len, self.FLAGS.question_len)
            h_p = self_attention_encoder.build_graph_pp(
                context_hiddens, question_hiddens, self.context_mask,
                self.qn_mask, v_p, self.FLAGS.context_len,
                self.FLAGS.question_len)
            # Feature size assumed to be 2*hidden_size_encoder + hidden_size_qp + 2*hidden_size_pp
            # (the width used for hidden_size_attn below) -- confirm against SelfAttn.
            blended_reps_final = tf.concat(
                [context_hiddens, v_p, h_p],
                axis=2)

        if self.FLAGS.answer_pointer:
            # Implementation of the answer pointer used in the R-Net paper.
            # hidden_size_attn must describe the blended_reps_final actually passed in. The attention
            # branches above overwrite it in the order simple -> co -> self, so check in reverse.
            if self.FLAGS.self_attention:
                hidden_size_attn = 2 * self.FLAGS.hidden_size_encoder + self.FLAGS.hidden_size_qp + 2 * self.FLAGS.hidden_size_pp
            elif self.FLAGS.co_attention:
                hidden_size_attn = self.FLAGS.hidden_size * 2
            else:
                hidden_size_attn = self.FLAGS.hidden_size

            answer_decoder = AnswerPointer(self.FLAGS.hidden_size_encoder,
                                           hidden_size_attn,
                                           self.FLAGS.question_len,
                                           self.keep_prob)
            p, logits = answer_decoder.build_graph_answer_pointer(
                question_hiddens, context_hiddens, blended_reps_final,
                self.FLAGS.question_len, self.FLAGS.context_len, self.qn_mask,
                self.context_mask)

            # Index 0 holds the start pointer outputs, index 1 the end pointer outputs.
            self.logits_start = logits[0]
            self.probdist_start = p[0]

            self.logits_end = logits[1]
            self.probdist_end = p[1]

        if self.FLAGS.simple_softmax:
            # Use softmax layer to compute probability distribution for start location
            # Note this produces self.logits_start and self.probdist_start, both of which have shape (batch_size, context_len)
            with vs.variable_scope("StartDist"):
                softmax_layer_start = SimpleSoftmaxLayer()
                self.logits_start, self.probdist_start = softmax_layer_start.build_graph(
                    blended_reps_final, self.context_mask)

            # Use softmax layer to compute probability distribution for end location
            # Note this produces self.logits_end and self.probdist_end, both of which have shape (batch_size, context_len)
            with vs.variable_scope("EndDist"):
                softmax_layer_end = SimpleSoftmaxLayer()
                self.logits_end, self.probdist_end = softmax_layer_end.build_graph(
                    blended_reps_final, self.context_mask)
示例#13
0
    def build_graph(self):
        """Build the graph from (word + character) embeddings to the answer-span distributions.

        Character-level word vectors come from a bidirectional GRU run over each word's character
        embeddings: the final forward and backward states are concatenated and appended to the
        word embeddings. Side effect: self.context_embs and self.qn_embs are rebound to these
        concatenated tensors.

        Defines:
          self.logits_start, self.logits_end: Both tensors shape (batch_size, context_len).
            These are the logits (i.e. values that are fed into the softmax function) for the start and end distribution.
            Important: these are -large in the pad locations. Necessary for when we feed into the cross entropy function.
          self.probdist_start, self.probdist_end: Both shape (batch_size, context_len). Each row sums to 1.
            These are the result of taking (masked) softmax of logits_start and logits_end.
        """
        hidden_size = self.FLAGS.hidden_size

        def char_lengths(char_ids):
            # Count non-zero (non-padding) char ids per word, flattened to one length per word.
            is_char = tf.cast(tf.cast(char_ids, tf.bool), tf.int32)
            return tf.reshape(tf.reduce_sum(is_char, axis=2), [-1])

        context_char_lens = char_lengths(self.context_char_ids)
        question_char_lens = char_lengths(self.qn_char_ids)

        # One pair of GRU cells is used for the character RNN of both context and question words.
        gru_fw = rnn_cell.GRUCell(hidden_size)
        gru_bw = rnn_cell.GRUCell(hidden_size)

        def char_word_vectors(char_embs, char_lens, seq_len):
            # Final fw/bw states, concatenated, reshaped to (batch_size, seq_len, 2*hidden_size).
            _, (final_fw, final_bw) = tf.nn.bidirectional_dynamic_rnn(
                gru_fw, gru_bw, char_embs, char_lens, dtype=tf.float32)
            joined = tf.concat([final_fw, final_bw], axis=1)
            return tf.reshape(joined, [-1, seq_len, 2 * hidden_size])

        context_char_vecs = char_word_vectors(
            self.context_char_embs, context_char_lens, self.FLAGS.context_len)
        self.context_embs = tf.concat([self.context_embs, context_char_vecs], axis=2)

        question_char_vecs = char_word_vectors(
            self.qn_char_embs, question_char_lens, self.FLAGS.question_len)
        self.qn_embs = tf.concat([self.qn_embs, question_char_vecs], axis=2)

        # Word-level RNN encoder, applied with the same weights to context and question.
        # TODO: deep encoder.
        shared_encoder = RNNEncoder(hidden_size, self.keep_prob)
        context_hiddens = shared_encoder.build_graph(
            self.context_embs, self.context_mask)  # (batch_size, context_len, hidden_size*2)
        question_hiddens = shared_encoder.build_graph(
            self.qn_embs, self.qn_mask)  # (batch_size, question_len, hidden_size*2)

        # Each context position attends over the question hidden states.
        basic_attn = BasicAttn(self.keep_prob, hidden_size * 2, hidden_size * 2)
        _, attended = basic_attn.build_graph(
            question_hiddens, self.qn_mask, context_hiddens)  # (batch_size, context_len, hidden_size*2)
        blended = tf.concat([context_hiddens, attended], axis=2)  # (batch_size, context_len, hidden_size*4)

        # Per-position fully connected layer (ReLU by default); this is b' in the handout.
        blended_final = tf.contrib.layers.fully_connected(
            blended, num_outputs=hidden_size)  # (batch_size, context_len, hidden_size)

        # Independent masked-softmax heads for the start and end positions.
        with vs.variable_scope("StartDist"):
            start_head = SimpleSoftmaxLayer()
            self.logits_start, self.probdist_start = start_head.build_graph(
                blended_final, self.context_mask)

        with vs.variable_scope("EndDist"):
            end_head = SimpleSoftmaxLayer()
            self.logits_end, self.probdist_end = end_head.build_graph(
                blended_final, self.context_mask)
示例#14
0
    def build_graph(self):
        """Builds the main part of the graph for the model, starting from the input embeddings to the final distributions for the answer span.

        The blending of context and question is selected by self.FLAGS.attention:
          "BiDAF": bidirectional attention flow; the context, context-to-question
            attention and the two elementwise products are concatenated and fed
            through two modeling BiRNN layers.
          any other value (intended: "BasicAttn"): basic attention of the context
            over the question, concatenated with the context and passed through
            a fully connected (ReLU) layer.

        Defines:
          self.logits_start, self.logits_end: Both tensors shape (batch_size, context_len).
            These are the logits (i.e. values that are fed into the softmax function) for the start and end distribution.
            Important: these are -large in the pad locations. Necessary for when we feed into the cross entropy function.
          self.probdist_start, self.probdist_end: Both shape (batch_size, context_len). Each row sums to 1.
            These are the result of taking (masked) softmax of logits_start and logits_end.
        """
        # Function-call form with a single pre-formatted argument prints the
        # same text under both Python 2 and Python 3.
        print("Running Attention Model with... %s" % self.FLAGS.attention)

        # Both attention variants start from the same encoding, so build it once.
        # The same encoder object is applied to the context and the question,
        # i.e. they share weights.
        encoder = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob)
        context_hiddens = encoder.build_graph(
            self.context_embs,
            self.context_mask)  # (batch_size, context_len, hidden_size*2)
        question_hiddens = encoder.build_graph(
            self.qn_embs,
            self.qn_mask)  # (batch_size, question_len, hidden_size*2)

        if self.FLAGS.attention == "BiDAF":
            bidaf_attn_layer = BiDirectionalAttn(self.keep_prob,
                                                 self.FLAGS.hidden_size * 2,
                                                 self.FLAGS.hidden_size * 2,
                                                 self.FLAGS.question_len,
                                                 self.FLAGS.context_len)
            _, context_to_question, _, question_to_context = bidaf_attn_layer.build_graph(
                question_hiddens, self.qn_mask, context_hiddens,
                self.context_mask)

            # Combine attention vectors and hidden context vector
            context_c2q = tf.multiply(context_hiddens, context_to_question)
            context_q2c = tf.multiply(context_hiddens, question_to_context)
            blended_reps = tf.concat(
                [
                    context_hiddens, context_to_question, context_c2q,
                    context_q2c
                ],
                axis=2)  # (batch_size, context_len, hidden_size*8)

            # Modeling layers (two bidirectional RNN layers) encode the
            # query-aware representations of the context words.
            modeling_layer = BiRNN(self.FLAGS.hidden_size, self.keep_prob)
            blended_reps_1 = modeling_layer.build_graph(
                blended_reps,
                self.context_mask)  # (batch_size, context_len, hidden_size*2).

            modeling_layer_2 = BiRNN2(self.FLAGS.hidden_size, self.keep_prob)
            blended_reps_final = modeling_layer_2.build_graph(
                blended_reps_1,
                self.context_mask)  # (batch_size, context_len, hidden_size*2).

        else:  # Default: self.FLAGS.attention == "BasicAttn"
            # Use context hidden states to attend to question hidden states
            attn_layer = BasicAttn(self.keep_prob, self.FLAGS.hidden_size * 2,
                                   self.FLAGS.hidden_size * 2)
            _, attn_output = attn_layer.build_graph(
                question_hiddens, self.qn_mask, context_hiddens
            )  # attn_output is shape (batch_size, context_len, hidden_size*2)

            # Concat attn_output to context_hiddens to get blended_reps
            blended_reps = tf.concat(
                [context_hiddens, attn_output],
                axis=2)  # (batch_size, context_len, hidden_size*4)

            # Apply fully connected layer to each blended representation
            # Note, blended_reps_final corresponds to b' in the handout
            # Note, tf.contrib.layers.fully_connected applies a ReLU non-linearity here by default
            blended_reps_final = tf.contrib.layers.fully_connected(
                blended_reps, num_outputs=self.FLAGS.hidden_size
            )  # blended_reps_final is shape (batch_size, context_len, hidden_size)

        # Use softmax layer to compute probability distribution for start location
        # Note this produces self.logits_start and self.probdist_start, both of which have shape (batch_size, context_len)
        with vs.variable_scope("StartDist"):
            softmax_layer_start = SimpleSoftmaxLayer()
            self.logits_start, self.probdist_start = softmax_layer_start.build_graph(
                blended_reps_final, self.context_mask)

        # Use softmax layer to compute probability distribution for end location
        # Note this produces self.logits_end and self.probdist_end, both of which have shape (batch_size, context_len)
        with vs.variable_scope("EndDist"):
            softmax_layer_end = SimpleSoftmaxLayer()
            self.logits_end, self.probdist_end = softmax_layer_end.build_graph(
                blended_reps_final, self.context_mask)
示例#15
0
    def build_graph(self, multi_lstm=False, bidaf=False):
        """Builds the main part of the graph for the model, starting from the input embeddings to the final distributions for the answer span.

        Args:
          multi_lstm: If truthy, encode the context and question with
            MultiLSTMEncoder; otherwise with RNNEncoder.
          bidaf: If truthy, blend context and question with BiDAFAttn;
            otherwise with BasicAttn.

        Also reads two flags:
          self.FLAGS.start_lstm_decode: if truthy, the start distribution comes
            from StartDecodeLayer instead of SimpleSoftmaxLayer.
          self.FLAGS.cond_pred: if truthy, the end distribution comes from
            ConditionalOutputLayer, fed with blended_reps_final concatenated
            with the start logits, instead of SimpleSoftmaxLayer.

        Defines:
          self.logits_start, self.logits_end: Both tensors shape (batch_size, context_len).
            These are the logits (i.e. values that are fed into the softmax function) for the start and end distribution.
            Important: these are -large in the pad locations. Necessary for when we feed into the cross entropy function.
          self.probdist_start, self.probdist_end: Both shape (batch_size, context_len). Each row sums to 1.
            These are the result of taking (masked) softmax of logits_start and logits_end.
        """
        # Truthiness tests (instead of `is False`) so that a falsy non-bool
        # such as None or 0 selects the default path rather than the alternative.

        # Use a RNN to get hidden states for the context and the question.
        # The same encoder object (hence the same weights) encodes both.
        if multi_lstm:
            encoder = MultiLSTMEncoder(self.FLAGS.hidden_size, self.keep_prob)
        else:
            encoder = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob)
        context_hiddens = encoder.build_graph(
            self.context_embs,
            self.context_mask)  # (batch_size, context_len, hidden_size*2)
        question_hiddens = encoder.build_graph(
            self.qn_embs,
            self.qn_mask)  # (batch_size, question_len, hidden_size*2)

        if bidaf:
            attn_layer = BiDAFAttn(self.keep_prob, self.FLAGS.hidden_size * 2,
                                   self.FLAGS.hidden_size * 2)
            c2q_attn, q2c_attn = attn_layer.build_graph(
                question_hiddens, self.qn_mask, context_hiddens,
                self.context_mask)
            # Adding zeros shaped like one example of c2q_attn broadcasts
            # q2c_attn across the context axis.
            # NOTE(review): assumes q2c_attn has a singleton context axis,
            # e.g. (batch_size, 1, hidden_size*2) — confirm against BiDAFAttn.
            q2c_attn = q2c_attn + tf.zeros(
                shape=[1, c2q_attn.shape[1], c2q_attn.shape[2]])
            context_c2q = tf.multiply(context_hiddens, c2q_attn)
            context_q2c = tf.multiply(context_hiddens, q2c_attn)
            blended_reps = tf.concat(
                [context_hiddens, c2q_attn, context_c2q, context_q2c],
                axis=2)  # (batch_size, context_len, hidden_size*8)
        else:
            # Use context hidden states to attend to question hidden states
            attn_layer = BasicAttn(self.keep_prob, self.FLAGS.hidden_size * 2,
                                   self.FLAGS.hidden_size * 2)
            _, attn_output = attn_layer.build_graph(
                question_hiddens, self.qn_mask, context_hiddens
            )  # attn_output is shape (batch_size, context_len, hidden_size*2)
            blended_reps = tf.concat(
                [context_hiddens, attn_output],
                axis=2)  # (batch_size, context_len, hidden_size*4)

        # Apply fully connected layer to each blended representation
        # Note, blended_reps_final corresponds to b' in the handout
        # Note, tf.contrib.layers.fully_connected applies a ReLU non-linearity here by default
        blended_reps_final = tf.contrib.layers.fully_connected(
            blended_reps, num_outputs=self.FLAGS.hidden_size
        )  # blended_reps_final is shape (batch_size, context_len, hidden_size)

        # Compute the start distribution; both variants produce
        # self.logits_start and self.probdist_start, shape (batch_size, context_len).
        with vs.variable_scope("StartDist"):
            if self.FLAGS.start_lstm_decode:
                start_decode_layer = StartDecodeLayer(self.FLAGS.hidden_size,
                                                      self.keep_prob)
                self.logits_start, self.probdist_start = start_decode_layer.build_graph(
                    blended_reps_final, self.context_mask)
            else:
                softmax_layer_start = SimpleSoftmaxLayer()
                self.logits_start, self.probdist_start = softmax_layer_start.build_graph(
                    blended_reps_final, self.context_mask)

        # Compute the end distribution; both variants produce
        # self.logits_end and self.probdist_end, shape (batch_size, context_len).
        if self.FLAGS.cond_pred:
            # Condition the end prediction on the start logits:
            # (batch_size, context_len) -> (batch_size, context_len, 1) ...
            logits_start_float32 = tf.expand_dims(tf.cast(self.logits_start,
                                                          dtype=tf.float32),
                                                  axis=2)
            # ... then broadcast to blended_reps_final's last dimension, since
            # tf.concat does not broadcast.
            logits_start_float32 = logits_start_float32 + tf.zeros(
                shape=(1, blended_reps_final.shape[1],
                       blended_reps_final.shape[2]),
                dtype=tf.float32)
            comb_blended_reps = tf.concat(
                [blended_reps_final, logits_start_float32],
                axis=2)  # (batch_size, context_len, hidden_size*2)
            with vs.variable_scope("EndDist"):
                conditional_output_layer = ConditionalOutputLayer(
                    self.FLAGS.hidden_size, self.keep_prob)
                self.logits_end, self.probdist_end = conditional_output_layer.build_graph(
                    comb_blended_reps, self.context_mask)
        else:
            with vs.variable_scope("EndDist"):
                softmax_layer_end = SimpleSoftmaxLayer()
                self.logits_end, self.probdist_end = softmax_layer_end.build_graph(
                    blended_reps_final, self.context_mask)
    def build_graph(self):
        """Builds the main part of the graph for the model, starting from the input embeddings to the final distributions for the answer span.

        The attention layer is selected by self.FLAGS.attention, which must be
        one of "BasicAttn", "Rnet", "BiDAF" or "CoAttn".

        Defines:
          self.logits_start, self.logits_end: Both tensors shape (batch_size, context_len).
            These are the logits (i.e. values that are fed into the softmax function) for the start and end distribution.
            Important: these are -large in the pad locations. Necessary for when we feed into the cross entropy function.
          self.probdist_start, self.probdist_end: Both shape (batch_size, context_len). Each row sums to 1.
            These are the result of taking (masked) softmax of logits_start and logits_end.

        Raises:
          ValueError: if self.FLAGS.attention is not a supported attention name.
        """
        attention = self.FLAGS.attention
        supported_attention = ("BasicAttn", "Rnet", "BiDAF", "CoAttn")
        if attention not in supported_attention:
            # Fail fast: previously an unknown name left blended_reps unbound
            # and the graph build crashed later with an UnboundLocalError.
            raise ValueError(
                "Unsupported attention type %r; expected one of: %s"
                % (attention, ", ".join(supported_attention)))

        # Use a RNN to get hidden states for the context and the question
        # Note: here the RNNEncoder is shared (i.e. the weights are the same)
        # between the context and the question.
        encoder = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob)
        # Single pre-formatted argument: same output on Python 2 and 3.
        print("self.context_embs shape %s" % (self.context_embs.shape,))
        context_hiddens = encoder.build_graph(self.context_embs, self.context_mask) # (batch_size, context_len, hidden_size*2)
        print("context hiddens output of encoder %s" % (context_hiddens.shape,))
        question_hiddens = encoder.build_graph(self.qn_embs, self.qn_mask) # (batch_size, question_len, hidden_size*2)

        # Use context hidden states to attend to question hidden states
        if attention == "BasicAttn":
            attn_layer = BasicAttn(self.keep_prob, self.FLAGS.hidden_size*2, self.FLAGS.hidden_size*2)
            _, attn_output = attn_layer.build_graph(question_hiddens, self.qn_mask, context_hiddens) # attn_output is shape (batch_size, context_len, hidden_size*2)

            # Concat attn_output to context_hiddens to get blended_reps
            blended_reps = tf.concat([context_hiddens, attn_output], axis=2) # (batch_size, context_len, hidden_size*4)

        elif attention == "Rnet":
            attn_layer = Rnet(self.keep_prob, self.FLAGS.hidden_size * 2, self.FLAGS.hidden_size * 2)
            attn_output, rep_v = attn_layer.build_graph(question_hiddens, self.qn_mask, context_hiddens, self.context_mask)  # attn_output is shape (batch_size, context_len, hidden_size*2)

            # Concatenate both Rnet outputs, then re-encode them over the
            # context with a bidirectional RNN.
            rnet_reps = tf.concat([attn_output, rep_v], axis=2)  # (batch_size, context_len, hidden_size*4)
            rnet_encoder = BiRNN(self.FLAGS.hidden_size, self.keep_prob)
            # presumably (batch_size, context_len, hidden_size*2) — TODO confirm BiRNN output size
            blended_reps = rnet_encoder.build_graph(rnet_reps, self.context_mask)
            print("blended after encoder reps shape %s" % (blended_reps.shape,))

        elif attention == "BiDAF":
            attn_layer = BiDAF(self.keep_prob, self.FLAGS.hidden_size * 2, self.FLAGS.hidden_size * 2)
            attn_output_C2Q, attn_output_Q2C = attn_layer.build_graph(
                question_hiddens, self.qn_mask,
                context_hiddens, self.context_mask)  # attn_output is shape (batch_size, context_len, hidden_size*2)

            # Blend context with both attention directions
            c_c2q_dot = tf.multiply(context_hiddens, attn_output_C2Q)
            c_q2c_dot = tf.multiply(context_hiddens, attn_output_Q2C)

            blended_reps = tf.concat([context_hiddens, attn_output_C2Q, c_c2q_dot,
                                      c_q2c_dot], axis=2) # (batch_size, context_len, hidden_size*8)

        else:  # attention == "CoAttn" (validated above)
            attn_layer = CoAttn(self.keep_prob, self.FLAGS.hidden_size * 2, self.FLAGS.hidden_size * 2)
            attn_output_C2Q, attn_output_Q2C = attn_layer.build_graph(
                question_hiddens, self.qn_mask,
                context_hiddens, self.context_mask)  # attn_output is shape (batch_size, context_len, hidden_size*2)

            # Blend context with both attention directions
            c_c2q_dot = tf.multiply(context_hiddens, attn_output_C2Q)
            c_q2c_dot = tf.multiply(context_hiddens, attn_output_Q2C)

            blended_reps = tf.concat([context_hiddens, attn_output_C2Q, c_c2q_dot,
                                      c_q2c_dot], axis=2) # (batch_size, context_len, hidden_size*8)

        # Apply fully connected layer to each blended representation
        # Note, blended_reps_final corresponds to b' in the handout
        # Note, tf.contrib.layers.fully_connected applies a ReLU non-linearity here by default
        blended_reps_final = tf.contrib.layers.fully_connected(blended_reps, num_outputs=self.FLAGS.hidden_size) # blended_reps_final is shape (batch_size, context_len, hidden_size)
        print("shape of blended_reps_final %s" % (blended_reps_final.shape,))

        # Use softmax layer to compute probability distribution for start location
        # Note this produces self.logits_start and self.probdist_start, both of which have shape (batch_size, context_len)
        with vs.variable_scope("StartDist"):
            softmax_layer_start = SimpleSoftmaxLayer()
            self.logits_start, self.probdist_start = softmax_layer_start.build_graph(blended_reps_final, self.context_mask)

        # Use softmax layer to compute probability distribution for end location
        # Note this produces self.logits_end and self.probdist_end, both of which have shape (batch_size, context_len)
        with vs.variable_scope("EndDist"):
            softmax_layer_end = SimpleSoftmaxLayer()
            self.logits_end, self.probdist_end = softmax_layer_end.build_graph(blended_reps_final, self.context_mask)