Example #1
    def build_graph(self):
        """Builds the main part of the graph for the model.
        """
        with vs.variable_scope("context"):
            context_encoder = RNNEncoder(self.FLAGS.hidden_size,
                                         self.keep_prob)
            context_hiddens = context_encoder.build_graph(
                self.context_embs,
                self.context_mask)  # (batch_size, context_len, hidden_size*2)

        with vs.variable_scope("question"):
            question_encoder = RNNEncoder(self.FLAGS.hidden_size,
                                          self.keep_prob)
            question_hiddens = question_encoder.build_graph(
                self.qn_embs,
                self.qn_mask)  # (batch_size, question_len, hidden_size*2)
            question_last_hidden = tf.reshape(question_hiddens[:, -1, :],
                                              (-1, 2 * self.FLAGS.hidden_size))
            question_last_hidden = tf.contrib.layers.fully_connected(
                question_last_hidden, num_outputs=self.FLAGS.hidden_size)
        # Use context hidden states to attend to question hidden states

        # attn_output is shape (batch_size, context_len, hidden_size*2)
        # The following is BiDAF attention
        if self.FLAGS.use_bidaf:
            attn_layer = BiDAF(self.keep_prob, self.FLAGS.hidden_size * 2,
                               self.FLAGS.hidden_size * 2)
            attn_output = attn_layer.build_graph(
                question_hiddens, self.qn_mask, context_hiddens,
                self.context_mask)  # (batch_size, context_len, hidden_size * 6)
        else:  # otherwise, basic attention
            attn_layer = BasicAttn(self.keep_prob, self.FLAGS.hidden_size * 2,
                                   self.FLAGS.hidden_size * 2)
            _, attn_output = attn_layer.build_graph(question_hiddens,
                                                    self.qn_mask,
                                                    context_hiddens)
        # Concat attn_output to context_hiddens to get blended_reps
        blended_reps = tf.concat(
            [context_hiddens, attn_output],
            axis=2)  # (batch_size, context_len, hidden_size*4 with basic attention, *8 with BiDAF)

        blended_reps_final = tf.contrib.layers.fully_connected(
            blended_reps, num_outputs=self.FLAGS.hidden_size)

        decoder = RNNDecoder(self.FLAGS.batch_size,
                             self.FLAGS.hidden_size,
                             self.ans_vocab_size,
                             self.FLAGS.answer_len,
                             self.ans_embedding_matrix,
                             self.keep_prob,
                             sampling_prob=self.sampling_prob,
                             schedule_embed=self.FLAGS.schedule_embed,
                             pred_method=self.FLAGS.pred_method)
        (self.train_logits, self.train_translations, _), \
            (self.dev_logits, self.dev_translations, self.attention_results) = \
            decoder.build_graph(blended_reps_final, question_last_hidden,
                                self.ans_embs, self.ans_mask, self.ans_ids,
                                self.context_mask)
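The RNNEncoder internals are not shown in these examples. For reference, here is a minimal sketch of what such a masked bidirectional GRU encoder typically looks like in TF 1.x; the function name, the 0/1 integer mask convention, and the dropout placement are assumptions, not the project's actual implementation.

import tensorflow as tf

def bidirectional_gru_encoder(inputs, mask, hidden_size, keep_prob, scope="RNNEncoder"):
    """inputs: (batch, seq_len, emb_dim); mask: (batch, seq_len) of 0/1 ints.
    Returns hidden states of shape (batch, seq_len, hidden_size*2)."""
    with tf.variable_scope(scope):
        seq_lens = tf.reduce_sum(mask, axis=1)  # true lengths, so padding is ignored
        cell_fw = tf.nn.rnn_cell.DropoutWrapper(
            tf.nn.rnn_cell.GRUCell(hidden_size), input_keep_prob=keep_prob)
        cell_bw = tf.nn.rnn_cell.DropoutWrapper(
            tf.nn.rnn_cell.GRUCell(hidden_size), input_keep_prob=keep_prob)
        (out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(
            cell_fw, cell_bw, inputs, sequence_length=seq_lens, dtype=tf.float32)
        # concatenate forward and backward states: (batch, seq_len, hidden_size*2)
        return tf.nn.dropout(tf.concat([out_fw, out_bw], axis=2), keep_prob)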
Example #2
    def build_graph(self):
        """Builds the main part of the graph for the model, starting from the input embeddings to the final distributions for the answer span.

        Defines:
          self.logits_start, self.logits_end: Both tensors shape (batch_size, context_len).
            These are the logits (i.e. values that are fed into the softmax function) for the start and end distribution.
            Important: these are -large in the pad locations. Necessary for when we feed into the cross entropy function.
          self.probdist_start, self.probdist_end: Both shape (batch_size, context_len). Each row sums to 1.
            These are the result of taking (masked) softmax of logits_start and logits_end.
        """

        # Use a RNN to get hidden states for the context and the question
        # Note: here the RNNEncoder is shared (i.e. the weights are the same)
        # between the context and the question.
        encoder = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob)
        context_hiddens = encoder.build_graph(
            self.context_embs,
            self.context_mask)  # (batch_size, context_len, hidden_size*2)
        question_hiddens = encoder.build_graph(
            self.qn_embs,
            self.qn_mask)  # (batch_size, question_len, hidden_size*2)

        # BasicAttn
        # Use context hidden states to attend to question hidden states
        # attn_layer = BasicAttn(self.keep_prob, self.FLAGS.hidden_size*2, self.FLAGS.hidden_size*2)
        # _, attn_output = attn_layer.build_graph(question_hiddens, self.qn_mask, context_hiddens) # attn_output is shape (batch_size, context_len, hidden_size*2)
        # # Concat attn_output to context_hiddens to get blended_reps
        # blended_reps = tf.concat([context_hiddens, attn_output], axis=2) # (batch_size, context_len, hidden_size*4)

        # BiDAF attention
        attn_layer = BiDAF(self.keep_prob)
        blended_reps = attn_layer.build_graph(
            question_hiddens, self.qn_mask, context_hiddens, self.context_mask
        )  # blended_reps is shape (batch_size, context_len, hidden_size*2)

        # Apply fully connected layer to each blended representation
        # Note, blended_reps_final corresponds to b' in the handout
        # Note, tf.contrib.layers.fully_connected applies a ReLU non-linearity here by default
        blended_reps_final = tf.contrib.layers.fully_connected(
            blended_reps, num_outputs=self.FLAGS.hidden_size
        )  # blended_reps_final is shape (batch_size, context_len, hidden_size)

        # Use softmax layer to compute probability distribution for start location
        # Note this produces self.logits_start and self.probdist_start, both of which have shape (batch_size, context_len)
        with vs.variable_scope("StartDist"):
            softmax_layer_start = SimpleSoftmaxLayer()
            self.logits_start, self.probdist_start = softmax_layer_start.build_graph(
                blended_reps_final, self.context_mask)

        # Use softmax layer to compute probability distribution for end location
        # Note this produces self.logits_end and self.probdist_end, both of which have shape (batch_size, context_len)
        with vs.variable_scope("EndDist"):
            softmax_layer_end = SimpleSoftmaxLayer()
            self.logits_end, self.probdist_end = softmax_layer_end.build_graph(
                blended_reps_final, self.context_mask)
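SimpleSoftmaxLayer is used above but not defined in these snippets. Below is a minimal sketch of the masked softmax it is assumed to perform, following the docstring's note that logits are set to a large negative value in pad locations; the names and the single-unit projection are assumptions.

import tensorflow as tf

def simple_masked_softmax(reps, mask, scope="SimpleSoftmax"):
    """reps: (batch, context_len, hidden); mask: (batch, context_len) of 0/1 ints.
    Returns (logits, probdist), both (batch, context_len)."""
    with tf.variable_scope(scope):
        # project each position down to a single unmasked logit
        logits = tf.contrib.layers.fully_connected(
            reps, num_outputs=1, activation_fn=None)   # (batch, context_len, 1)
        logits = tf.squeeze(logits, axis=[2])          # (batch, context_len)
        # add -large to the pad locations so they get ~0 probability
        exp_mask = (1.0 - tf.cast(mask, tf.float32)) * (-1e30)
        masked_logits = logits + exp_mask
        probdist = tf.nn.softmax(masked_logits)
        return masked_logits, probdist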
Example #3
    def build_graph(self):
        """Builds the main part of the graph for the model, starting from the input embeddings to the final distributions for the answer span.

        Defines:
          self.logits_start, self.logits_end: Both tensors shape (batch_size, context_len).
            These are the logits (i.e. values that are fed into the softmax function) for the start and end distribution.
            Important: these are -large in the pad locations. Necessary for when we feed into the cross entropy function.
          self.probdist_start, self.probdist_end: Both shape (batch_size, context_len). Each row sums to 1.
            These are the result of taking (masked) softmax of logits_start and logits_end.
        """

        # Use a RNN to get hidden states for the context and the question
        # Note: here the RNNEncoder is shared (i.e. the weights are the same)
        # between the context and the question.
        encoder = RNNEncoder(self.FLAGS.h_hidden_size,
                             self.keep_prob,
                             num_layers=self.FLAGS.h_num_layers,
                             combiner=self.FLAGS.h_combiner,
                             cell_type=self.FLAGS.h_cell_type)
        if self.FLAGS.share_encoder:
            question_hiddens, question_states_fw, question_states_bw = encoder.build_graph(
                self.qn_embs,
                self.qn_mask)  # (batch_size, question_len, hidden_size*2)
        else:
            question_encoder = RNNEncoder(self.FLAGS.h_hidden_size,
                                          self.keep_prob,
                                          num_layers=self.FLAGS.h_num_layers,
                                          combiner=self.FLAGS.h_combiner,
                                          cell_type=self.FLAGS.h_cell_type,
                                          scope='question_encoder')
            question_hiddens, question_states_fw, question_states_bw = question_encoder.build_graph(
                self.qn_embs,
                self.qn_mask)  # (batch_size, question_len, hidden_size*2)
        if not self.FLAGS.reuse_question_states:
            question_states_fw, question_states_bw = None, None
        context_hiddens, _, _ = encoder.build_graph(
            self.context_embs,
            self.context_mask,
            initial_states_fw=question_states_fw,
            initial_states_bw=question_states_bw
        )  # (batch_size, context_len, hidden_size*2)

        if self.FLAGS.use_bidaf:
            attn_layer = BiDAF(self.keep_prob)
            context_att, question_att = attn_layer.build_graph(
                question_hiddens, self.qn_mask, context_hiddens,
                self.context_mask)
            blended_reps = tf.concat(
                [context_hiddens, context_att, context_hiddens * context_att,
                 context_hiddens * question_att],
                axis=2)
        else:
            # Use context hidden states to attend to question hidden states
            attn_layer = BasicAttn(self.keep_prob)
            _, attn_output = attn_layer.build_graph(
                question_hiddens, self.qn_mask, context_hiddens
            )  # attn_output is shape (batch_size, context_len, hidden_size*2)
            # Concat attn_output to context_hiddens to get blended_reps
            blended_reps = tf.concat(
                [context_hiddens, attn_output, context_hiddens * attn_output],
                axis=2)  # (batch_size, context_len, hidden_size*6)

        if self.FLAGS.modeling_layer_uses_rnn:
            modelling_encoder = RNNEncoder(
                self.FLAGS.h_model_size,
                self.keep_prob,
                num_layers=self.FLAGS.h_model_layers,
                combiner=self.FLAGS.h_combiner,
                cell_type=self.FLAGS.h_cell_type,
                scope='blended_reps_scope')
            blended_reps_final, model_states_fw, model_states_bw = modelling_encoder.build_graph(
                blended_reps, self.context_mask)
        else:
            # Apply fully connected layer to each blended representation
            # Note, blended_reps_final corresponds to b' in the handout
            # Note, tf.contrib.layers.fully_connected applies a ReLU non-linearity here by default
            blended_reps_final = tf.contrib.layers.fully_connected(
                blended_reps, num_outputs=self.FLAGS.h_hidden_size
            )  # blended_reps_final is shape (batch_size, context_len, hidden_size)
            # No modelling RNN ran, so there are no states to seed the end-pointer RNN with
            model_states_fw, model_states_bw = None, None

        # Use softmax layer to compute probability distribution for start location
        # Note this produces self.logits_start and self.probdist_start, both of which have shape (batch_size, context_len)
        with vs.variable_scope("StartDist"):
            softmax_layer_start = SimpleSoftmaxLayer()
            self.logits_start, self.probdist_start = softmax_layer_start.build_graph(
                blended_reps_final, self.context_mask)

        # Use softmax layer to compute probability distribution for end location
        # Note this produces self.logits_end and self.probdist_end, both of which have shape (batch_size, context_len)
        with vs.variable_scope("EndDist"):
            if self.FLAGS.use_rnn_for_ends:
                end_encoder = RNNEncoder(self.FLAGS.h_model_size,
                                         self.keep_prob,
                                         num_layers=self.FLAGS.h_model_layers,
                                         combiner=self.FLAGS.h_combiner,
                                         cell_type=self.FLAGS.h_cell_type,
                                         scope='blended_reps_final')
                blended_reps_combined = tf.concat([
                    blended_reps_final,
                    tf.expand_dims(self.probdist_start, 2)
                ], 2)
                blended_reps_final, _, _ = end_encoder.build_graph(
                    blended_reps_combined,
                    self.context_mask,
                    initial_states_fw=model_states_fw,
                    initial_states_bw=model_states_bw)
            softmax_layer_end = SimpleSoftmaxLayer()
            self.logits_end, self.probdist_end = softmax_layer_end.build_graph(
                blended_reps_final, self.context_mask)
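The BiDAF layer in this example returns context-to-question and question-to-context attention outputs. Below is a minimal sketch of that computation under the usual BiDAF formulation; the trilinear similarity parameterization, the names, and the tiled Q2C output are assumptions and may differ from the BiDAF class used above.

import tensorflow as tf

def bidaf_attention(context, context_mask, question, question_mask, scope="BiDAFSketch"):
    """context: (batch, N, 2h); question: (batch, M, 2h); masks are 0/1 ints.
    Returns (c2q, q2c), each (batch, N, 2h)."""
    with tf.variable_scope(scope):
        d = context.get_shape().as_list()[-1]
        w_c = tf.get_variable("w_c", [d], tf.float32)
        w_q = tf.get_variable("w_q", [d], tf.float32)
        w_cq = tf.get_variable("w_cq", [d], tf.float32)

        # Trilinear similarity S[b, i, j] = w_c.c_i + w_q.q_j + w_cq.(c_i * q_j)
        s_c = tf.expand_dims(tf.tensordot(context, w_c, axes=1), 2)    # (batch, N, 1)
        s_q = tf.expand_dims(tf.tensordot(question, w_q, axes=1), 1)   # (batch, 1, M)
        s_cq = tf.matmul(context * w_cq, question, transpose_b=True)   # (batch, N, M)
        S = s_c + s_q + s_cq

        # mask out padded question positions in the similarity matrix
        q_mask = tf.expand_dims(tf.cast(question_mask, tf.float32), 1)  # (batch, 1, M)
        S_masked = S + (1.0 - q_mask) * (-1e30)

        # C2Q: each context position attends over the question positions
        c2q = tf.matmul(tf.nn.softmax(S_masked), question)              # (batch, N, 2h)

        # Q2C: attend over context positions using the max similarity per context position
        c_mask = tf.cast(context_mask, tf.float32)                      # (batch, N)
        m = tf.reduce_max(S_masked, axis=2) + (1.0 - c_mask) * (-1e30)  # (batch, N)
        beta = tf.expand_dims(tf.nn.softmax(m), 1)                      # (batch, 1, N)
        q2c = tf.tile(tf.matmul(beta, context), [1, tf.shape(context)[1], 1])
        return c2q, q2c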
Example #4
    def build_graph(self):
        """Builds the main part of the graph for the model, starting from the input embeddings to the final distributions for the answer span.

        Defines:
          self.logits_start, self.logits_end: Both tensors shape (batch_size, context_len).
            These are the logits (i.e. values that are fed into the softmax function) for the start and end distribution.
            Important: these are -large in the pad locations. Necessary for when we feed into the cross entropy function.
          self.probdist_start, self.probdist_end: Both shape (batch_size, context_len). Each row sums to 1.
            These are the result of taking (masked) softmax of logits_start and logits_end.
        """

        # Use a RNN to get hidden states for the context and the question
        # Note: here the RNNEncoder is shared (i.e. the weights are the same)
        # between the context and the question.

        ########################################
        # First bidirectional GRU layer
        ########################################

        encoder = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob)
        context_hiddens = encoder.build_graph(
            self.context_embs,
            self.context_mask)  # (batch_size, context_len, hidden_size*2)
        question_hiddens = encoder.build_graph(
            self.qn_embs,
            self.qn_mask)  # (batch_size, question_len, hidden_size*2)

        ####################
        # BiDAF Attention Layer
        ####################

        # Use context hidden states to attend to question hidden states
        attn_layer = BiDAF(self.keep_prob, self.FLAGS.hidden_size * 2,
                           self.FLAGS.hidden_size * 2)
        _, attn_output = attn_layer.build_graph(
            context_hiddens, question_hiddens, self.context_mask, self.qn_mask
        )  # attn_output is shape (batch_size, context_len, hidden_size*2)
        # Concat attn_output to context_hiddens to get blended_reps
        blended_reps = tf.concat(
            [context_hiddens, attn_output],
            axis=2)  # (batch_size, context_len, hidden_size*4)

        ####################
        # BiDAF second bidirectional layer
        ####################

        # BiDAF layer applied after the context and question attention is calculated, based on the original BiDAF paper

        encoder2 = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob)
        bidaf_second_layer_hiddens = encoder2.build_graph(
            blended_reps, self.context_mask, scope_name="BidafEncoder"
        )  # (batch_size, context_len, hidden_size*2)

        ####################
        # Self Attn Layer
        ####################

        # if self.self_attn:
        # # Bidaf second attention layer, should eventually use self attention
        #     self_attn_layer = SelfAttn(self.keep_prob, self.FLAGS.hidden_size*2)
        #     _, self_attn_output = self_attn_layer.build_graph(bidaf_second_layer_hiddens, self.context_mask)
        #     self_attn_reps = tf.concat([bidaf_second_layer_hiddens, self_attn_output], axis=2)
        # else:
        #     self_attn_reps = bidaf_second_layer_hiddens

        ####################
        # BiDAF third bidirectional layer
        ####################

        encoder3 = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob)
        bidaf_third_layer = encoder3.build_graph(
            bidaf_second_layer_hiddens,
            self.context_mask,
            scope_name="SelfAttnBidaf"
        )  # (batch_size, context_len, hidden_size*2)

        final_context_reps = tf.contrib.layers.fully_connected(
            bidaf_third_layer, num_outputs=self.FLAGS.hidden_size
        )  # final_context_reps is shape (batch_size, context_len, hidden_size)

        ####################
        # Answer Pointer Layer (unused alternative)
        ####################
        # ansptr_layer = AnsPtr(self.FLAGS.hidden_size, self.keep_prob)

        # BiDAF Output Layer
        bidaf_out = BiDAFOut(self.FLAGS.hidden_size, self.keep_prob)
        self.logits_start, self.probdist_start, self.logits_end, self.probdist_end = bidaf_out.build_graph(
            attn_output, bidaf_second_layer_hiddens, self.context_mask)
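The SelfAttn block in this example is commented out. For reference, here is a minimal sketch of the kind of self-attention it would add, where each context position attends over every other context position; the bilinear scoring and the names are assumptions, not the project's SelfAttn implementation.

import tensorflow as tf

def self_attention(context_reps, context_mask, keep_prob, scope="SelfAttnSketch"):
    """context_reps: (batch, N, d); context_mask: (batch, N) of 0/1 ints.
    Returns attended representations of shape (batch, N, d)."""
    with tf.variable_scope(scope):
        d = context_reps.get_shape().as_list()[-1]
        W = tf.get_variable("W_self", [d, d], tf.float32)
        proj = tf.tensordot(context_reps, W, axes=1)                 # (batch, N, d)
        scores = tf.matmul(proj, context_reps, transpose_b=True)     # (batch, N, N)
        mask = tf.expand_dims(tf.cast(context_mask, tf.float32), 1)  # (batch, 1, N)
        scores = scores + (1.0 - mask) * (-1e30)      # block attention to pad positions
        weights = tf.nn.dropout(tf.nn.softmax(scores), keep_prob)    # (batch, N, N)
        return tf.matmul(weights, context_reps)                      # (batch, N, d)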
Example #5
    def build_graph(self):
        """Builds the main part of the graph for the model, starting from the input embeddings to the final distributions for the answer span.

        Defines:
          self.logits_start, self.logits_end: Both tensors shape (batch_size, context_len).
            These are the logits (i.e. values that are fed into the softmax function) for the start and end distribution.
            Important: these are -large in the pad locations. Necessary for when we feed into the cross entropy function.
          self.probdist_start, self.probdist_end: Both shape (batch_size, context_len). Each row sums to 1.
            These are the result of taking (masked) softmax of logits_start and logits_end.
        """

        # Use a RNN to get hidden states for the context and the question
        # Note: here the RNNEncoder is shared (i.e. the weights are the same)
        # between the context and the question.
        encoder = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob)
        context_hiddens = encoder.build_graph(
            self.context_embs,
            self.context_mask)  # (batch_size, context_len, hidden_size*2)
        question_hiddens = encoder.build_graph(
            self.qn_embs,
            self.qn_mask)  # (batch_size, question_len, hidden_size*2)

        # Use context hidden states to attend to question hidden states
        if self.FLAGS.model == 'bidaf':
            bidaf_layer = BiDAF(self.FLAGS.hidden_size, self.keep_prob,
                                self.FLAGS.hidden_size * 2,
                                self.FLAGS.hidden_size * 2)
            g_m, g_m2 = bidaf_layer.build_graph(
                question_hiddens, self.qn_mask, context_hiddens,
                self.context_mask)  # (batch_size, context_len, hidden_size*10)
            with vs.variable_scope("StartDist"):
                softmax_layer_start = SimpleSoftmaxLayer(1 -
                                                         self.FLAGS.dropout)
                self.logits_start, self.probdist_start = softmax_layer_start.build_graph(
                    g_m, self.context_mask)
            with vs.variable_scope("EndDist"):
                softmax_layer_end = SimpleSoftmaxLayer(1 - self.FLAGS.dropout)
                self.logits_end, self.probdist_end = softmax_layer_end.build_graph(
                    g_m2, self.context_mask)
        elif self.FLAGS.model == 'bicoattn':
            bicoattn_layer = BiCoattn(self.FLAGS.batch_size,
                                      self.FLAGS.context_len,
                                      self.FLAGS.hidden_size, self.keep_prob,
                                      self.FLAGS.hidden_size * 2,
                                      self.FLAGS.hidden_size * 2)
            g_m, g_m2, self.attn_output = bicoattn_layer.build_graph(
                question_hiddens, self.qn_mask, context_hiddens,
                self.context_mask)  # (batch_size, context_len, hidden_size*14)
            with vs.variable_scope("StartDist"):
                softmax_layer_start = SimpleSoftmaxLayer(1 -
                                                         self.FLAGS.dropout)
                self.logits_start, self.probdist_start = softmax_layer_start.build_graph(
                    g_m, self.context_mask)
            with vs.variable_scope("EndDist"):
                softmax_layer_end = SimpleSoftmaxLayer(1 - self.FLAGS.dropout)
                self.logits_end, self.probdist_end = softmax_layer_end.build_graph(
                    g_m2, self.context_mask)
        elif self.FLAGS.model == 'transformernetwork':
            transformernetwork_layer = TransformerNetwork(
                self.FLAGS.hidden_size, self.keep_prob,
                self.FLAGS.hidden_size * 2, self.FLAGS.hidden_size * 2, 3, 8)
            g_m, g_m2 = transformernetwork_layer.build_graph(
                question_hiddens, self.qn_mask, context_hiddens,
                self.context_mask, self.FLAGS.is_training
            )  # (batch_size, context_len, hidden_size*2)
            with vs.variable_scope("StartDist"):
                softmax_layer_start = SimpleSoftmaxLayer(1 -
                                                         self.FLAGS.dropout)
                self.logits_start, self.probdist_start = softmax_layer_start.build_graph(
                    g_m, self.context_mask)
            with vs.variable_scope("EndDist"):
                softmax_layer_end = SimpleSoftmaxLayer(1 - self.FLAGS.dropout)
                self.logits_end, self.probdist_end = softmax_layer_end.build_graph(
                    g_m2, self.context_mask)
        elif self.FLAGS.model == 'bctn':
            bctn_layer = BCTN(self.FLAGS.hidden_size, self.keep_prob,
                              self.FLAGS.hidden_size * 2,
                              self.FLAGS.hidden_size * 2, 2, 8)
            g_m, g_m2 = bctn_layer.build_graph(
                question_hiddens, self.qn_mask, context_hiddens,
                self.context_mask, self.FLAGS.is_training
            )  # (batch_size, context_len, hidden_size*2)
            with vs.variable_scope("StartDist"):
                softmax_layer_start = SimpleSoftmaxLayer(1 -
                                                         self.FLAGS.dropout)
                self.logits_start, self.probdist_start = softmax_layer_start.build_graph(
                    g_m, self.context_mask)
            with vs.variable_scope("EndDist"):
                softmax_layer_end = SimpleSoftmaxLayer(1 - self.FLAGS.dropout)
                self.logits_end, self.probdist_end = softmax_layer_end.build_graph(
                    g_m2, self.context_mask)

        elif self.FLAGS.model == 'rnet':
            with vs.variable_scope("Contextual"):
                encoder = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob)
                context_hiddens = encoder.build_graph(self.context_embs,
                                                      self.context_mask)
                question_hiddens = encoder.build_graph(self.qn_embs,
                                                       self.qn_mask)
            print "GatedAttn"
            with vs.variable_scope("GatedAttn"):
                attn_layer_gated = GatedAttn(self.keep_prob,
                                             self.FLAGS.hidden_size * 2,
                                             self.FLAGS.hidden_size * 2,
                                             self.FLAGS.hidden_size)
                context_hiddens_gated, self.a_t = attn_layer_gated.build_graph(
                    question_hiddens, self.qn_mask,
                    context_hiddens)  # (batch_size, context_len, hidden_size)
            print "SelfAttn"
            with vs.variable_scope("SelfAttn"):
                attn_layer_self = SelfAttn(self.keep_prob,
                                           self.FLAGS.hidden_size,
                                           self.FLAGS.hidden_size)
                attn_output_self, self.attn_output = attn_layer_self.build_graph(
                    context_hiddens_gated, self.context_mask
                )  # (batch_size, context_len, hidden_size * 2)

            print "Output"
            with vs.variable_scope("Output"):
                output_layer = Output_Rnet(self.keep_prob,
                                           self.FLAGS.hidden_size * 2,
                                           self.FLAGS.hidden_size * 2,
                                           self.FLAGS.hidden_size)
                self.logits_start, self.probdist_start, self.logits_end, self.probdist_end, self.a = output_layer.build_graph(
                    attn_output_self, question_hiddens, self.context_mask,
                    self.qn_mask)
        elif self.FLAGS.model == 'basicattnplusone':
            attn_layer = BasicAttnPlusOne(self.FLAGS.batch_size,
                                          self.FLAGS.context_len,
                                          self.FLAGS.hidden_size,
                                          self.keep_prob,
                                          self.FLAGS.hidden_size * 2,
                                          self.FLAGS.hidden_size * 2)
            _, attn_output = attn_layer.build_graph(
                question_hiddens, self.qn_mask, context_hiddens,
                self.context_mask
            )  # attn_output is shape (batch_size, context_len, hidden_size*4)

            blended_reps = tf.concat(
                [context_hiddens, attn_output],
                axis=2)  # (batch_size, context_len, hidden_size*6)

            blended_reps_final = tf.contrib.layers.fully_connected(
                blended_reps, num_outputs=self.FLAGS.hidden_size
            )  # blended_reps_final is shape (batch_size, context_len, hidden_size)

            with vs.variable_scope("StartDist"):
                softmax_layer_start = SimpleSoftmaxLayer(1 -
                                                         self.FLAGS.dropout)
                self.logits_start, self.probdist_start = softmax_layer_start.build_graph(
                    blended_reps_final, self.context_mask)

            with vs.variable_scope("EndDist"):
                softmax_layer_end = SimpleSoftmaxLayer(1 - self.FLAGS.dropout)
                self.logits_end, self.probdist_end = softmax_layer_end.build_graph(
                    blended_reps_final, self.context_mask)
        elif self.FLAGS.model == 'basicattnplustwo':
            attn_layer = BasicAttnPlusTwo(self.FLAGS.batch_size,
                                          self.FLAGS.context_len,
                                          self.FLAGS.hidden_size,
                                          self.keep_prob,
                                          self.FLAGS.hidden_size * 2,
                                          self.FLAGS.hidden_size * 2)
            g_m, g_m2 = attn_layer.build_graph(
                question_hiddens, self.qn_mask, context_hiddens,
                self.context_mask
            )  # g_m and g_m2 are shape (batch_size, context_len, hidden_size*4)
            with vs.variable_scope("StartDist"):
                softmax_layer_start = SimpleSoftmaxLayer(1 -
                                                         self.FLAGS.dropout)
                self.logits_start, self.probdist_start = softmax_layer_start.build_graph(
                    g_m, self.context_mask)
            with vs.variable_scope("EndDist"):
                softmax_layer_end = SimpleSoftmaxLayer(1 - self.FLAGS.dropout)
                self.logits_end, self.probdist_end = softmax_layer_end.build_graph(
                    g_m2, self.context_mask)
        else:
            attn_layer = BasicAttn(self.keep_prob, self.FLAGS.hidden_size * 2,
                                   self.FLAGS.hidden_size * 2)
            _, attn_output = attn_layer.build_graph(
                question_hiddens, self.qn_mask, context_hiddens
            )  # attn_output is shape (batch_size, context_len, hidden_size*2)

            # Concat attn_output to context_hiddens to get blended_reps
            blended_reps = tf.concat(
                [context_hiddens, attn_output],
                axis=2)  # (batch_size, context_len, hidden_size*4)

            # Apply fully connected layer to each blended representation
            # Note, blended_reps_final corresponds to b' in the handout
            # Note, tf.contrib.layers.fully_connected applies a ReLU non-linearity here by default
            blended_reps_final = tf.contrib.layers.fully_connected(
                blended_reps, num_outputs=self.FLAGS.hidden_size
            )  # blended_reps_final is shape (batch_size, context_len, hidden_size)

            # Use softmax layer to compute probability distribution for start location
            # Note this produces self.logits_start and self.probdist_start, both of which have shape (batch_size, context_len)
            with vs.variable_scope("StartDist"):
                softmax_layer_start = SimpleSoftmaxLayer(1 -
                                                         self.FLAGS.dropout)
                self.logits_start, self.probdist_start = softmax_layer_start.build_graph(
                    blended_reps_final, self.context_mask)

            # Use softmax layer to compute probability distribution for end location
            # Note this produces self.logits_end and self.probdist_end, both of which have shape (batch_size, context_len)
            with vs.variable_scope("EndDist"):
                softmax_layer_end = SimpleSoftmaxLayer(1 - self.FLAGS.dropout)
                self.logits_end, self.probdist_end = softmax_layer_end.build_graph(
                    blended_reps_final, self.context_mask)
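BasicAttn appears in the final else branch above (and again in the next snippet), but its implementation is not shown. Below is a minimal sketch of such dot-product attention of the context over the question, with the mask handling assumed to follow the same -large-in-pad convention; the names and details are assumptions.

import tensorflow as tf

def basic_attention(question_hiddens, question_mask, context_hiddens, keep_prob):
    """question_hiddens: (batch, M, 2h); context_hiddens: (batch, N, 2h);
    question_mask: (batch, M) of 0/1 ints.
    Returns (attn_dist, attn_output): (batch, N, M) and (batch, N, 2h)."""
    # dot-product scores between every context position and every question position
    scores = tf.matmul(context_hiddens, question_hiddens, transpose_b=True)  # (batch, N, M)
    q_mask = tf.expand_dims(tf.cast(question_mask, tf.float32), 1)           # (batch, 1, M)
    scores = scores + (1.0 - q_mask) * (-1e30)     # block attention to question padding
    attn_dist = tf.nn.softmax(scores)              # softmax over question positions
    attn_output = tf.matmul(attn_dist, question_hiddens)                     # (batch, N, 2h)
    return attn_dist, tf.nn.dropout(attn_output, keep_prob)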
    def build_graph(self):
        """Builds the main part of the graph for the model, starting from the input embeddings to the final distributions for the answer span.

        Defines:
          self.logits_start, self.logits_end: Both tensors shape (batch_size, context_len).
            These are the logits (i.e. values that are fed into the softmax function) for the start and end distribution.
            Important: these are -large in the pad locations. Necessary for when we feed into the cross entropy function.
          self.probdist_start, self.probdist_end: Both shape (batch_size, context_len). Each row sums to 1.
            These are the result of taking (masked) softmax of logits_start and logits_end.
        """

        # Use a RNN to get hidden states for the context and the question
        # Note: here the RNNEncoder is shared (i.e. the weights are the same)
        # between the context and the question.
        encoder = RNNEncoder(self.FLAGS.hidden_size, self.keep_prob)
        print "self.context_embs shape", self.context_embs.shape
        context_hiddens = encoder.build_graph(self.context_embs, self.context_mask) # (batch_size, context_len, hidden_size*2)
        print "context hiddens output of encoder",context_hiddens.shape
        question_hiddens = encoder.build_graph(self.qn_embs, self.qn_mask) # (batch_size, question_len, hidden_size*2)

        # Use context hidden states to attend to question hidden states

        if self.FLAGS.attention == "BasicAttn":

            attn_layer = BasicAttn(self.keep_prob, self.FLAGS.hidden_size*2, self.FLAGS.hidden_size*2)
            _, attn_output = attn_layer.build_graph(question_hiddens, self.qn_mask, context_hiddens) # attn_output is shape (batch_size, context_len, hidden_size*2)

            # Concat attn_output to context_hiddens to get blended_reps
            blended_reps = tf.concat([context_hiddens, attn_output], axis=2) # (batch_size, context_len, hidden_size*4)

        if self.FLAGS.attention == "Rnet":

            attn_layer = Rnet(self.keep_prob, self.FLAGS.hidden_size * 2, self.FLAGS.hidden_size * 2)
            attn_output, rep_v = attn_layer.build_graph(question_hiddens, self.qn_mask, context_hiddens, self.context_mask)  # attn_output is shape (batch_size, context_len, hidden_size*2)

            blended_reps_ = tf.concat([attn_output, rep_v], axis=2)  # (batch_size, context_len, hidden_size*4)
            # print "blended reps before encoder shape", blended_reps_.shape
            # print "self.context", self.context_mask.shape
            # blended_reps_ = tf.contrib.layers.fully_connected(blended_reps_,num_outputs=self.FLAGS.hidden_size * 2)  # blended_reps_final is shape (batch_size, context_len, hidden_size)
            # print "blended reps encoder input", blended_reps_.shape
            # cell_fw = tf.nn.rnn_cell.LSTMCell(self.FLAGS.hidden_size * 2)
            # cell_bw = tf.nn.rnn_cell.LSTMCell(self.FLAGS.hidden_size * 2)
            # # compute coattention encoding
            # (fw_out, bw_out), _ = tf.nn.bidirectional_dynamic_rnn(
            #     cell_fw, cell_bw, blended_reps_,
            #     dtype=tf.float32)
            encoderRnet = BiRNN(self.FLAGS.hidden_size, self.keep_prob)
            blended_reps = encoderRnet.build_graph(blended_reps_, self.context_mask)  # (batch_size, context_len, hidden_size*2??)
            # blended_reps = tf.concat([fw_out, bw_out],2)
            print "blended after encoder reps shape", blended_reps.shape


        if self.FLAGS.attention == "BiDAF":
            attn_layer = BiDAF(self.keep_prob, self.FLAGS.hidden_size * 2, self.FLAGS.hidden_size * 2)
            attn_output_C2Q, attn_output_Q2C = attn_layer.build_graph(question_hiddens, self.qn_mask,
                                                    context_hiddens, self.context_mask)  # each (batch_size, context_len, hidden_size*2)

            # Concat attn_output to context_hiddens to get blended_reps
            c_c2q_dot = tf.multiply(context_hiddens, attn_output_C2Q)
            c_q2c_dot = tf.multiply(context_hiddens, attn_output_Q2C)

            blended_reps = tf.concat([context_hiddens, attn_output_C2Q, c_c2q_dot,
                                      c_q2c_dot], axis=2)  # (batch_size, context_len, hidden_size*8)


        if self.FLAGS.attention == "CoAttn" :

            attn_layer = CoAttn(self.keep_prob, self.FLAGS.hidden_size * 2, self.FLAGS.hidden_size * 2)
            attn_output_C2Q,attn_output_Q2C = attn_layer.build_graph(question_hiddens, self.qn_mask,
                                                    context_hiddens,self.context_mask)  # attn_output is shape (batch_size, context_len, hidden_size*2)

        # Concat attn_output to context_hiddens to get blended_reps
            c_c2q_dot  = tf.multiply(context_hiddens, attn_output_C2Q)
            c_q2c_dot  = tf.multiply(context_hiddens, attn_output_Q2C)

            blended_reps = tf.concat([context_hiddens,attn_output_C2Q, c_c2q_dot,
                                      c_q2c_dot], axis=2) # (batch_size, context_len, hidden_size*4)




        # Apply fully connected layer to each blended representation
        # Note, blended_reps_final corresponds to b' in the handout
        # Note, tf.contrib.layers.fully_connected applies a ReLU non-linearity here by default
        blended_reps_final = tf.contrib.layers.fully_connected(blended_reps, num_outputs=self.FLAGS.hidden_size) # blended_reps_final is shape (batch_size, context_len, hidden_size)
        print "shape of blended_reps_final ", blended_reps_final.shape
        # Use softmax layer to compute probability distribution for start location
        # Note this produces self.logits_start and self.probdist_start, both of which have shape (batch_size, context_len)
        with vs.variable_scope("StartDist"):
            softmax_layer_start = SimpleSoftmaxLayer()
            self.logits_start, self.probdist_start = softmax_layer_start.build_graph(blended_reps_final, self.context_mask)

        # Use softmax layer to compute probability distribution for end location
        # Note this produces self.logits_end and self.probdist_end, both of which have shape (batch_size, context_len)
        with vs.variable_scope("EndDist"):
            softmax_layer_end = SimpleSoftmaxLayer()
            self.logits_end, self.probdist_end = softmax_layer_end.build_graph(blended_reps_final, self.context_mask)
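All of these graphs end at probdist_start and probdist_end. As a usage note, here is a minimal NumPy sketch of how those two distributions are commonly decoded into an answer span at prediction time (maximizing p_start[i] * p_end[j] with j >= i); this decoding step is an assumption about downstream usage and is not part of the code above.

import numpy as np

def decode_span(p_start, p_end, max_answer_len=30):
    """p_start, p_end: (context_len,) probability vectors for a single example.
    Returns (start_idx, end_idx) maximizing p_start[i] * p_end[j] with j >= i."""
    best_score, best_span = -1.0, (0, 0)
    for i, ps in enumerate(p_start):
        j_hi = min(len(p_end), i + max_answer_len)   # limit the answer length
        j = i + int(np.argmax(p_end[i:j_hi]))        # best end index at or after i
        score = ps * p_end[j]
        if score > best_score:
            best_score, best_span = score, (i, j)
    return best_span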