Example #1
 def apply(self, is_train, inputs, mask=None):
     inputs = tf.transpose(inputs, [1, 0, 2])  # to time first
     with tf.variable_scope("forward"):
         cell = LSTMBlockFusedCell(self.n_units, use_peephole=self.use_peepholes)
         fw = cell(inputs, dtype=tf.float32, sequence_length=mask)[0]
     with tf.variable_scope("backward"):
         cell = LSTMBlockFusedCell(self.n_units, use_peephole=self.use_peepholes)
         inputs = tf.reverse_sequence(inputs, mask, seq_axis=0, batch_axis=1)
         bw = cell(inputs, dtype=tf.float32, sequence_length=mask)[0]
         bw = tf.reverse_sequence(bw, mask, seq_axis=0, batch_axis=1)
     out = tf.concat([fw, bw], axis=2)
     out = tf.transpose(out, [1, 0, 2])  # back to batch first
     return out
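A minimal driver for the pattern above, assuming TF 1.x and `tf.contrib.rnn.LSTMBlockFusedCell`; the shapes are illustrative, and note that the `mask` argument above actually carries the true sequence lengths, despite its name:

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMBlockFusedCell

inputs = tf.placeholder(tf.float32, [4, 20, 8])  # [batch, time, features]
lengths = tf.placeholder(tf.int32, [4])          # true length of each sequence

x = tf.transpose(inputs, [1, 0, 2])              # fused cells are time-major
with tf.variable_scope("forward"):
    fw, _ = LSTMBlockFusedCell(16)(x, dtype=tf.float32, sequence_length=lengths)
with tf.variable_scope("backward"):
    rev = tf.reverse_sequence(x, lengths, seq_axis=0, batch_axis=1)
    bw, _ = LSTMBlockFusedCell(16)(rev, dtype=tf.float32, sequence_length=lengths)
    bw = tf.reverse_sequence(bw, lengths, seq_axis=0, batch_axis=1)
out = tf.transpose(tf.concat([fw, bw], axis=2), [1, 0, 2])  # [4, 20, 32]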
Example #2
def build_fused_bidirectional_rnn(inputs,
                                  num_units,
                                  num_layers,
                                  inputs_length,
                                  input_keep_prob=1.0,
                                  scope=None,
                                  dtype=tf.float32):
    ''' Inputs are expected batch-major; they are transposed to time-major
        internally. Dropout is applied independently at each timestep. '''
    assert num_layers > 0
    with tf.variable_scope(scope or "bidirectional_rnn"):
        inputs = flatten(inputs, 2)  # [N * M, JX, d]
        current_inputs = tf.transpose(
            inputs, [1, 0, 2])  #[time_len, batch_size, input_size]
        for layer_id in range(num_layers):
            reverse_inputs = tf.reverse_sequence(current_inputs,
                                                 inputs_length,
                                                 batch_dim=1,
                                                 seq_dim=0)
            fw_inputs = tf.nn.dropout(current_inputs, input_keep_prob)
            bw_inputs = tf.nn.dropout(reverse_inputs, input_keep_prob)
            fw_cell = LSTMBlockFusedCell(num_units, cell_clip=0)
            bw_cell = LSTMBlockFusedCell(num_units, cell_clip=0)
            fw_outputs, fw_final = fw_cell(fw_inputs,
                                           dtype=dtype,
                                           sequence_length=inputs_length,
                                           scope="fw_" + str(layer_id))
            bw_outputs, bw_final = bw_cell(bw_inputs,
                                           dtype=dtype,
                                           sequence_length=inputs_length,
                                           scope="bw_" + str(layer_id))
            bw_outputs = tf.reverse_sequence(bw_outputs,
                                             inputs_length,
                                             batch_dim=1,
                                             seq_dim=0)
            current_inputs = tf.concat((fw_outputs, bw_outputs), 2)
        output = tf.transpose(current_inputs, [1, 0, 2])
        output = tf.expand_dims(output, 1)  # [N, M, JX, 2d]

        scope_name = scope or "bidirectional_rnn"
        final_state_c = tf.concat((fw_final[0], bw_final[0]),
                                  1,
                                  name=scope_name + '_final_c')
        final_state_h = tf.concat((fw_final[1], bw_final[1]),
                                  1,
                                  name=scope_name + '_final_h')
        final_state = LSTMStateTuple(
            c=final_state_c, h=final_state_h)  # ([N, 2 * d], [N, 2 * d])
        return output, final_state
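A self-contained sketch of a single pass of the per-layer loop above, assuming TF 1.x; shapes are illustrative and dropout is omitted:

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMBlockFusedCell

time_major = tf.placeholder(tf.float32, [30, 8, 64])  # [time, batch, features]
lengths = tf.placeholder(tf.int32, [8])

# one bidirectional layer: forward on the inputs, backward on the reversed inputs
rev = tf.reverse_sequence(time_major, lengths, seq_axis=0, batch_axis=1)
fw_out, fw_final = LSTMBlockFusedCell(128)(time_major, dtype=tf.float32,
                                           sequence_length=lengths, scope="fw_0")
bw_out, bw_final = LSTMBlockFusedCell(128)(rev, dtype=tf.float32,
                                           sequence_length=lengths, scope="bw_0")
bw_out = tf.reverse_sequence(bw_out, lengths, seq_axis=0, batch_axis=1)
layer_out = tf.concat((fw_out, bw_out), 2)  # [30, 8, 256]; the next layer's input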
Example #3
 def __init__(self,
              GPU,
              num_layers,
              num_units,
              dropout=0.,
              dtype=tf.dtypes.float32,
              name=None):
     '''
     Create an LSTM adapter: `CudnnLSTM` if GPU, else a stack of
     `LSTMBlockFusedCell`s.
     '''
     base_layer.Layer.__init__(self, dtype=dtype, name=name)
     self.GPU = GPU
     self.dropout = dropout
     if GPU:
         self.model = CudnnLSTM(num_layers,
                                num_units,
                                dtype=self.dtype,
                                name=name)
     else:
         self.model = MultiFusedRNNCell([
             LSTMBlockFusedCell(num_units,
                                dtype=self.dtype,
                                name='%s_%d' % (name, i))
             for i in range(num_layers)
         ])
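The CPU branch above relies on a project-local `MultiFusedRNNCell`; a minimal sketch of the same layer stacking with plain `tf.contrib.rnn` primitives, under assumed shapes (both backends consume time-major input):

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMBlockFusedCell

inputs = tf.placeholder(tf.float32, [50, 16, 32])  # [time, batch, features]
outputs = inputs
for i in range(2):  # chain fused cells layer by layer
    cell = LSTMBlockFusedCell(256, name='enc_%d' % i)
    outputs, _ = cell(outputs, dtype=tf.float32)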
Example #4
def fused_lstm_module(input, name, train, units, recomp=False):
    """
    tensorflow LSTM implementation - for inference only
    :param input: input tensor
    :param name: name for variable scope etc.
    :param train: is_train placeholder
    :param units: number of lstm units
    :param recomp: recompute_gradient environment
    :return: output tensor after LSTM
    """
    # as written, inference-only (forget_bias=0.0)
    # CuDNN trains the forget bias without a runtime offset but initializes it to 1;
    # the fused cell trains with a runtime offset of 1 but initializes its bias to 0
    # fix: initialize the fused forget bias to 1 and keep forget_bias=0 for
    # weight compatibility with the CuDNN LSTM

    should_learn = tf.logical_and(train, tf.logical_not(recomp))

    # manually set name to reflect cudnn weight names
    # forget bias = 0 needed as otherwise 1 added to already learned bias
    lstm = LSTMBlockFusedCell(
        num_units=units,
        forget_bias=0.0,
        name=name + "/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell")
    # lstm swaps batch and time dimension
    input = tf.transpose(input, perm=[1, 0, 2])
    input, c = lstm(input, dtype=tf.float32)
    # swap back lstm batch and time dimension
    input = tf.transpose(input, perm=[1, 0, 2])

    return input
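A hypothetical inference driver for the module above; the tensor shapes, the "images_rnn" scope name, and the checkpoint path are placeholders, not values from the original project:

import tensorflow as tf

features = tf.placeholder(tf.float32, [None, 64, 512])  # [batch, time, depth]
is_train = tf.placeholder(tf.bool, [])
outputs = fused_lstm_module(features, "images_rnn", is_train, units=256)

saver = tf.train.Saver()  # variable names line up with the CuDNN-trained graph
with tf.Session() as sess:
    saver.restore(sess, "/path/to/cudnn_trained_checkpoint")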
Example #5
 def lstm(self, inputs, batch_size, num_units, swap_axes=True):
     if swap_axes:
         inputs = tf.transpose(inputs, [1, 0, 2])
     cell = LSTMBlockFusedCell(num_units)
     self.rnn_state = self.zero_state(batch_size, num_units)
     outputs, new_rnn_state = cell(inputs=inputs,
                                   initial_state=self.rnn_state,
                                   dtype=tf.float32)
     if swap_axes:
         outputs = tf.transpose(outputs, [1, 0, 2])
     return outputs, new_rnn_state
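A standalone sketch of the stateful variant, assuming `zero_state` returns an `LSTMStateTuple` of zeros (TF 1.x); sizes are illustrative:

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMBlockFusedCell, LSTMStateTuple

batch, units = 8, 64
x = tf.placeholder(tf.float32, [batch, 25, 32])  # [batch, time, depth]
state = LSTMStateTuple(c=tf.zeros([batch, units]), h=tf.zeros([batch, units]))

x_t = tf.transpose(x, [1, 0, 2])
outputs, new_state = LSTMBlockFusedCell(units)(x_t, initial_state=state,
                                               dtype=tf.float32)
outputs = tf.transpose(outputs, [1, 0, 2])
# feed new_state back as the next segment's initial_state to stay stateful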
Example #6
 def apply(self, is_train, x, mask=None):
     x = tf.transpose(x, [1, 0, 2])  # to time first
     state = LSTMBlockFusedCell(self.n_units)(x, dtype=tf.float32, sequence_length=mask)[1]
     if self.state and self.hidden:
         state = tf.concat(state, 1)
     elif self.hidden:
         state = state.h
     elif self.state:
         state = state.c
     else:
         raise ValueError("at least one of `hidden` (h) or `state` (c) must be requested")
     return state
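The same reduction outside the class, as a minimal TF 1.x sketch; the fused cell's final state is an `LSTMStateTuple`, so `state.c`, `state.h`, or their concatenation are all one step away:

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMBlockFusedCell

x = tf.placeholder(tf.float32, [16, 40, 100])  # [batch, time, depth]
lengths = tf.placeholder(tf.int32, [16])

x_t = tf.transpose(x, [1, 0, 2])  # to time first
_, final = LSTMBlockFusedCell(128)(x_t, dtype=tf.float32, sequence_length=lengths)
sentence_vec = tf.concat(final, 1)  # [16, 256]: concat of (c, h)
hidden_only = final.h               # [16, 128]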
Example #7
 def get_lstm_outputs(self, chars, last_state=None, reuse=False):
     with tf.variable_scope('char_embedding', reuse=reuse):
         self.char_embedding = tf.get_variable('char_embedding', initializer=tf.orthogonal_initializer()(
             (self.NUM_CHARS, self.CHAR_EMBEDDING_SIZE)), dtype=tf.float32)
         out = tf.nn.embedding_lookup(self.char_embedding, chars)
     with tf.variable_scope('spam_gen_rnn', reuse=reuse):
         next_state = []
         for layer in range(self.LAYERS):
             with tf.variable_scope('lstm_layer_%d' % layer, initializer=tf.orthogonal_initializer()) as scope:
                 out, next_state_part = LSTMBlockFusedCell(self.HIDDEN_LAYER_SIZE)(out,
                                                                                   last_state[layer] if last_state is not None else None,
                                                                                   dtype=tf.float32,
                                                                                   scope=scope)
                 out = tf.nn.relu(out)  # extra nonlinearity: the fused cell itself outputs o * tanh(c), no ReLU
                 next_state.append(next_state_part)
         return out, next_state
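A self-contained sketch of the per-layer state threading, with hypothetical sizes; each layer contributes one `LSTMStateTuple` that can be fed back as `last_state` at sampling time:

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMBlockFusedCell

x = tf.placeholder(tf.float32, [50, 4, 16])  # time-major embeddings
out, next_state = x, []
for layer in range(2):
    with tf.variable_scope('lstm_layer_%d' % layer) as scope:
        out, state_part = LSTMBlockFusedCell(64)(out, dtype=tf.float32, scope=scope)
        next_state.append(state_part)  # one LSTMStateTuple per layer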
Example #8
    def _init(self):
        ExtractionQAModel._init(self)
        if self._composition == "GRU":
            if self._layer_norm:
                rnn_constructor = lambda size: FusedRNNCellAdaptor(LayerNormGRUCell(size), use_dynamic_rnn=True)
            else:
                rnn_constructor = lambda size: FusedRNNCellAdaptor(GRUBlockCell(size), use_dynamic_rnn=True)
        elif self._composition == "RNN":
            rnn_constructor = lambda size: FusedRNNCellAdaptor(BasicRNNCell(size), use_dynamic_rnn=True)
        else:
            if self._layer_norm:
                rnn_constructor = lambda size: FusedRNNCellAdaptor(LayerNormLSTMCell(size), use_dynamic_rnn=True)
            else:
                rnn_constructor = lambda size: LSTMBlockFusedCell(size)

        with tf.device(self._device0):
            self._eval = tf.get_variable("is_eval", initializer=False, trainable=False)
            self._set_train = self._eval.initializer
            self._set_eval = self._eval.assign(True)

            self.context_mask = tfutil.mask_for_lengths(self.context_length, self._batch_size, self.embedder.max_length)

            question_binary_mask = tfutil.mask_for_lengths(self.question_length,
                                                           self.question_embedder.batch_size,
                                                           self.question_embedder.max_length,
                                                           value=1.0,
                                                           mask_right=False)

            with tf.variable_scope("preprocessing_layer"):

                question_binary_mask = tf.gather(question_binary_mask, self.context_partition)
                self._embedded_question_not_dropped = tf.gather(self._embedded_question_not_dropped, self.context_partition)

                # context
                if self._with_features:
                    mask = tf.get_variable("attention_mask", [1, 1, self._embedded_question_not_dropped.get_shape()[-1].value],
                                           initializer=tf.constant_initializer(1.0))
                    # compute word wise features
                    #masked_question = self.question_embedder.output * mask
                    # [B, Q, L]
                    q2c_scores = tf.matmul(self._embedded_question_not_dropped * mask,
                                                 self._embedded_context_not_dropped, adjoint_b=True)
                    q2c_scores = q2c_scores + tf.expand_dims(self.context_mask, 1)
                    #c2q_weights = tf.reduce_max(q2c_scores / (tf.reduce_max(q2c_scores, [2], keep_dims=True) + 1e-5), [1])

                    q2c_weights = tf.reduce_sum(tf.nn.softmax(q2c_scores) * \
                                                tf.expand_dims(question_binary_mask, 2), [1])

                    # [B, L , 1]
                    self.context_features = tf.concat(axis=2, values=[tf.expand_dims(self._word_in_question, 2),
                                                          #tf.expand_dims(c2q_weights, 2),
                                                          tf.expand_dims(q2c_weights,  2)])

                    embedded_ctxt = tf.concat(axis=2, values=[self.embedded_context, self.context_features])


                    in_question_feature = tf.ones(tf.stack([self.question_embedder.batch_size,
                                                           self.question_embedder.max_length, 2]))
                    embedded_question = tf.concat(axis=2, values=[self.embedded_question, in_question_feature])
                else:
                    embedded_ctxt = self.embedded_context
                    embedded_question = self.embedded_question

                if self._with_question_type_features:
                    # Need to add another zero vector so that the total number
                    # of features is even, for LSTM performance reasons.
                    question_type_features = tf.stack([self._is_factoid,
                                                      self._is_list,
                                                      self._is_yesno,
                                                      tf.zeros(tf.shape(self._is_list),
                                                               dtype=tf.bool)],
                                                     axis=1)
                    question_type_features = tf.cast(question_type_features, tf.float32)
                    question_type_features = tf.expand_dims(question_type_features, 1)

                    embedded_question = tf.concat(axis=2, values=[embedded_question,
                                                      tf.tile(question_type_features,
                                                              tf.stack([1, tf.shape(embedded_question)[1], 1]))])

                    question_type_features = tf.gather(question_type_features, self.context_partition)
                    embedded_ctxt = tf.concat(axis=2, values=[embedded_ctxt,
                                                  tf.tile(question_type_features,
                                                          tf.stack([1, tf.shape(embedded_ctxt)[1], 1]))])

                if self._with_entity_tag_features:
                    embedded_question = tf.concat(axis=2, values=[embedded_question,
                                                      tf.cast(self._question_tags, tf.float32)])
                    embedded_ctxt = tf.concat(axis=2, values=[embedded_ctxt,
                                                  tf.cast(self._context_tags, tf.float32)])

                self.encoded_question = self._preprocessing_layer(rnn_constructor, embedded_question,
                                                                  self.question_length, projection_scope="question_proj")

                self.encoded_ctxt = self._preprocessing_layer(rnn_constructor, embedded_ctxt, self.context_length,
                                                              share_rnn=True, projection_scope="context_proj",
                                                              num_fusion_layers=self._num_intrafusion_layers)

                # single time attention over question
                attention_scores = tf.contrib.layers.fully_connected(self.encoded_question, 1,
                                                                     activation_fn=None,
                                                                     weights_initializer=None,
                                                                     biases_initializer=None,
                                                                     scope="attention")
                attention_scores = attention_scores + tf.expand_dims(
                    tfutil.mask_for_lengths(self.question_length, self.question_embedder.batch_size,
                                            self.question_embedder.max_length), 2)
                attention_weights = tf.nn.softmax(attention_scores, 1)
                self.question_attention_weights = attention_weights
                self.question_representation = tf.reduce_sum(attention_weights * self.encoded_question, [1])

                # Multiply question features for each paragraph
                self.encoded_question = tf.gather(self.encoded_question, self.context_partition)
                self.question_representation_per_context = tf.gather(self.question_representation, self.context_partition)
                self.question_length = tf.gather(self.question_length, self.context_partition)

            if self._with_inter_fusion:
                with tf.variable_scope("inter_fusion"):
                    with tf.variable_scope("associative") as vs:
                        mask = tf.get_variable("attention_mask", [1, 1, self.size], initializer=tf.constant_initializer(1.0))
                        mask = tf.nn.relu(mask)
                        for i in range(1):
                            # [B, Q, L]
                            inter_scores = tf.matmul(self.encoded_question * mask, self.encoded_ctxt, adjoint_b=True)
                            inter_scores = inter_scores + tf.expand_dims(self.context_mask, 1)

                            inter_weights = tf.nn.softmax(inter_scores)
                            inter_weights = inter_weights * tf.expand_dims(question_binary_mask, 2)
                            # [B, L, Q] x [B, Q, S] -> [B, L, S]
                            co_states = tf.matmul(inter_weights, self.encoded_question, adjoint_a=True)

                            u = tf.contrib.layers.fully_connected(tf.concat(axis=2, values=[self.encoded_ctxt, co_states]), self.size,
                                                                  activation_fn=tf.sigmoid,
                                                                  biases_initializer=tf.constant_initializer(1.0),
                                                                  scope="update_gate")
                            self.encoded_ctxt = u * self.encoded_ctxt + (1.0 - u) * co_states
                            vs.reuse_variables()

                    with tf.variable_scope("recurrent") as vs:
                        self.encoded_ctxt.set_shape([None, None, self.size])
                        self.encoded_ctxt = dynamic_rnn(GatedAggregationRNNCell(self.size),
                                                        tf.reverse_sequence(self.encoded_ctxt, self.context_length, 1),
                                                        self.context_length,
                                                        dtype=tf.float32, time_major=False, scope="backward")[0]

                        self.encoded_ctxt = dynamic_rnn(GatedAggregationRNNCell(self.size),
                                                        tf.reverse_sequence(self.encoded_ctxt, self.context_length, 1),
                                                        self.context_length,
                                                        dtype=tf.float32, time_major=False, scope="forward")[0]

            # No matching layer, so set matched_output to encoded_ctxt (for compatibility)
            self.matched_output = self.encoded_ctxt

            with tf.variable_scope("pointer_layer"):
                self.predicted_context_indices, \
                self._start_scores, self._start_pointer, self.start_probs, \
                self._end_scores, self._end_pointer, self.end_probs = \
                    self._spn_answer_layer(self.question_representation_per_context, self.encoded_ctxt)

            self.yesno_added = False
            if self._with_yesno:
                self.add_yesno(add_model_scope=False)

            self._train_variables = [p for p in tf.trainable_variables() if self.name in p.name]
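A minimal sketch of the constructor pattern above (TF 1.x contrib): `FusedRNNCellAdaptor` wraps an ordinary `RNNCell` in the fused, time-major call signature, so adapted cells and `LSTMBlockFusedCell` are drop-in interchangeable; the sizes here are illustrative:

import tensorflow as tf
from tensorflow.contrib.rnn import (FusedRNNCellAdaptor, GRUBlockCell,
                                    LSTMBlockFusedCell)

def rnn_constructor(size, use_gru=False):
    if use_gru:
        return FusedRNNCellAdaptor(GRUBlockCell(size), use_dynamic_rnn=True)
    return LSTMBlockFusedCell(size)

inputs = tf.placeholder(tf.float32, [None, None, 300])  # time-major
lengths = tf.placeholder(tf.int32, [None])
outputs, _ = rnn_constructor(150)(inputs, dtype=tf.float32,
                                  sequence_length=lengths)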
Example #9
    def _build_forward(self):
        config = self.config
        N, M, JX, JQ, VW, VC, d, W = \
            config.batch_size, config.max_num_sents, config.max_sent_size, \
            config.max_ques_size, config.word_vocab_size, config.char_vocab_size, config.hidden_size, \
            config.max_word_size
        print("VW:", VW, "N:", N, "M:", M, "JX:", JX, "JQ:", JQ)
        JA = config.max_answer_length
        JX = tf.shape(self.x)[2]
        JQ = tf.shape(self.q)[1]
        M = tf.shape(self.x)[1]
        print("VW:", VW, "N:", N, "M:", M, "JX:", JX, "JQ:", JQ)
        dc, dw, dco = config.char_emb_size, config.word_emb_size, config.char_out_size

        with tf.variable_scope("emb"):
            # Char-CNN Embedding
            if config.use_char_emb:
                with tf.variable_scope("emb_var"), tf.device("/cpu:0"):
                    char_emb_mat = tf.get_variable("char_emb_mat",
                                                   shape=[VC, dc],
                                                   dtype='float')

                with tf.variable_scope("char"):
                    Acx = tf.nn.embedding_lookup(char_emb_mat,
                                                 self.cx)  # [N, M, JX, W, dc]
                    Acq = tf.nn.embedding_lookup(char_emb_mat,
                                                 self.cq)  # [N, JQ, W, dc]
                    Acx = tf.reshape(Acx, [-1, JX, W, dc])
                    Acq = tf.reshape(Acq, [-1, JQ, W, dc])

                    filter_sizes = list(
                        map(int, config.out_channel_dims.split(',')))
                    heights = list(map(int, config.filter_heights.split(',')))
                    assert sum(filter_sizes) == dco, (filter_sizes, dco)
                    with tf.variable_scope("conv"):
                        xx = multi_conv1d(Acx,
                                          filter_sizes,
                                          heights,
                                          "VALID",
                                          self.is_train,
                                          config.keep_prob,
                                          scope="xx")
                        if config.share_cnn_weights:
                            tf.get_variable_scope().reuse_variables()
                            qq = multi_conv1d(Acq,
                                              filter_sizes,
                                              heights,
                                              "VALID",
                                              self.is_train,
                                              config.keep_prob,
                                              scope="xx")
                        else:
                            qq = multi_conv1d(Acq,
                                              filter_sizes,
                                              heights,
                                              "VALID",
                                              self.is_train,
                                              config.keep_prob,
                                              scope="qq")
                        xx = tf.reshape(xx, [-1, M, JX, dco])
                        qq = tf.reshape(qq, [-1, JQ, dco])

            # Word Embedding
            if config.use_word_emb:
                with tf.variable_scope("emb_var") as scope, tf.device(
                        "/cpu:0"):
                    if config.mode == 'train':
                        word_emb_mat = tf.get_variable(
                            "word_emb_mat",
                            dtype='float',
                            shape=[VW, dw],
                            initializer=get_initializer(config.emb_mat))
                    else:
                        word_emb_mat = tf.get_variable("word_emb_mat",
                                                       shape=[VW, dw],
                                                       dtype='float')
                    tf.get_variable_scope().reuse_variables()
                    self.word_emb_scope = scope
                    if config.use_glove_for_unk:
                        word_emb_mat = tf.concat(
                            [word_emb_mat, self.new_emb_mat], 0)

                with tf.name_scope("word"):
                    Ax = tf.nn.embedding_lookup(word_emb_mat,
                                                self.x)  # [N, M, JX, d]
                    Aq = tf.nn.embedding_lookup(word_emb_mat,
                                                self.q)  # [N, JQ, d]
                    self.tensor_dict['x'] = Ax
                    self.tensor_dict['q'] = Aq
                # Concat Char-CNN Embedding and Word Embedding
                if config.use_char_emb:
                    xx = tf.concat([xx, Ax], 3)  # [N, M, JX, di]
                    qq = tf.concat([qq, Aq], 2)  # [N, JQ, di]
                else:
                    xx = Ax
                    qq = Aq

            # exact match
            if config.use_exact_match:
                emx = tf.expand_dims(tf.cast(self.emx, tf.float32), -1)
                xx = tf.concat([xx, emx], 3)  # [N, M, JX, di+1]
                emq = tf.expand_dims(tf.cast(self.emq, tf.float32), -1)
                qq = tf.concat([qq, emq], 2)  # [N, JQ, di+1]

        # 2 layer highway network on Concat Embedding
        if config.highway:
            with tf.variable_scope("highway"):
                xx = highway_network(xx,
                                     config.highway_num_layers,
                                     True,
                                     wd=config.wd,
                                     is_train=self.is_train)
                tf.get_variable_scope().reuse_variables()
                qq = highway_network(qq,
                                     config.highway_num_layers,
                                     True,
                                     wd=config.wd,
                                     is_train=self.is_train)

        self.tensor_dict['xx'] = xx
        self.tensor_dict['qq'] = qq

        # Bidirection-LSTM (3rd layer on paper)
        cell = GRUCell(d) if config.GRU else BasicLSTMCell(d,
                                                           state_is_tuple=True)
        d_cell = SwitchableDropoutWrapper(
            cell, self.is_train, input_keep_prob=config.input_keep_prob)
        x_len = tf.reduce_sum(tf.cast(self.x_mask, 'int32'), 2)  # [N, M]
        q_len = tf.reduce_sum(tf.cast(self.q_mask, 'int32'), 1)  # [N]
        flat_x_len = flatten(x_len, 0)  # [N * M]

        with tf.variable_scope("prepro"):
            if config.use_fused_lstm:
                with tf.variable_scope("u1"):
                    fw_inputs = tf.transpose(
                        qq, [1, 0, 2])  #[time_len, batch_size, input_size]
                    bw_inputs = tf.reverse_sequence(fw_inputs,
                                                    q_len,
                                                    batch_dim=1,
                                                    seq_dim=0)
                    fw_inputs = tf.nn.dropout(fw_inputs,
                                              config.input_keep_prob)
                    bw_inputs = tf.nn.dropout(bw_inputs,
                                              config.input_keep_prob)
                    prep_fw_cell = LSTMBlockFusedCell(d, cell_clip=0)
                    prep_bw_cell = LSTMBlockFusedCell(d, cell_clip=0)
                    fw_outputs, fw_final = prep_fw_cell(fw_inputs,
                                                        dtype=tf.float32,
                                                        sequence_length=q_len,
                                                        scope="fw")
                    bw_outputs, bw_final = prep_bw_cell(bw_inputs,
                                                        dtype=tf.float32,
                                                        sequence_length=q_len,
                                                        scope="bw")
                    bw_outputs = tf.reverse_sequence(bw_outputs,
                                                     q_len,
                                                     batch_dim=1,
                                                     seq_dim=0)
                    current_inputs = tf.concat((fw_outputs, bw_outputs), 2)
                    output = tf.transpose(current_inputs, [1, 0, 2])
                    u = output
                flat_xx = flatten(xx, 2)  # [N * M, JX, d]
                if config.share_lstm_weights:
                    tf.get_variable_scope().reuse_variables()
                    with tf.variable_scope("u1"):
                        fw_inputs = tf.transpose(
                            flat_xx,
                            [1, 0, 2])  #[time_len, batch_size, input_size]
                        bw_inputs = tf.reverse_sequence(fw_inputs,
                                                        flat_x_len,
                                                        batch_dim=1,
                                                        seq_dim=0)
                        # fw_inputs = tf.nn.dropout(fw_inputs, config.input_keep_prob)
                        # bw_inputs = tf.nn.dropout(bw_inputs, config.input_keep_prob)
                        fw_outputs, fw_final = prep_fw_cell(
                            fw_inputs,
                            dtype=tf.float32,
                            sequence_length=flat_x_len,
                            scope="fw")
                        bw_outputs, bw_final = prep_bw_cell(
                            bw_inputs,
                            dtype=tf.float32,
                            sequence_length=flat_x_len,
                            scope="bw")
                        bw_outputs = tf.reverse_sequence(bw_outputs,
                                                         flat_x_len,
                                                         batch_dim=1,
                                                         seq_dim=0)
                        current_inputs = tf.concat((fw_outputs, bw_outputs), 2)
                        output = tf.transpose(current_inputs, [1, 0, 2])
                else:
                    with tf.variable_scope("h1"):
                        fw_inputs = tf.transpose(
                            flat_xx,
                            [1, 0, 2])  #[time_len, batch_size, input_size]
                        bw_inputs = tf.reverse_sequence(fw_inputs,
                                                        flat_x_len,
                                                        batch_dim=1,
                                                        seq_dim=0)
                        # fw_inputs = tf.nn.dropout(fw_inputs, config.input_keep_prob)
                        # bw_inputs = tf.nn.dropout(bw_inputs, config.input_keep_prob)
                        prep_fw_cell = LSTMBlockFusedCell(d, cell_clip=0)
                        prep_bw_cell = LSTMBlockFusedCell(d, cell_clip=0)
                        fw_outputs, fw_final = prep_fw_cell(
                            fw_inputs,
                            dtype=tf.float32,
                            sequence_length=flat_x_len,
                            scope="fw")
                        bw_outputs, bw_final = prep_bw_cell(
                            bw_inputs,
                            dtype=tf.float32,
                            sequence_length=flat_x_len,
                            scope="bw")
                        bw_outputs = tf.reverse_sequence(bw_outputs,
                                                         flat_x_len,
                                                         batch_dim=1,
                                                         seq_dim=0)
                        current_inputs = tf.concat((fw_outputs, bw_outputs), 2)
                        output = tf.transpose(current_inputs, [1, 0, 2])
                h = tf.expand_dims(output, 1)  # [N, M, JX, 2d]
            else:
                (fw_u, bw_u), _ = bidirectional_dynamic_rnn(
                    d_cell, d_cell, qq, q_len, dtype='float',
                    scope='u1')  # [N, J, d], [N, d]
                u = tf.concat([fw_u, bw_u], 2)
                if config.share_lstm_weights:
                    tf.get_variable_scope().reuse_variables()
                    (fw_h, bw_h), _ = bidirectional_dynamic_rnn(
                        cell, cell, xx, x_len, dtype='float',
                        scope='u1')  # [N, M, JX, 2d]
                    h = tf.concat([fw_h, bw_h], 3)  # [N, M, JX, 2d]
                else:
                    (fw_h, bw_h), _ = bidirectional_dynamic_rnn(
                        cell, cell, xx, x_len, dtype='float',
                        scope='h1')  # [N, M, JX, 2d]
                    h = tf.concat([fw_h, bw_h], 3)  # [N, M, JX, 2d]
            self.tensor_dict['u'] = u
            self.tensor_dict['h'] = h

        # Attention Flow Layer (4th layer on paper)
        with tf.variable_scope("main"):
            if config.dynamic_att:
                p0 = h
                u = tf.reshape(tf.tile(tf.expand_dims(u, 1), [1, M, 1, 1]),
                               [N * M, JQ, 2 * d])
                q_mask = tf.reshape(
                    tf.tile(tf.expand_dims(self.q_mask, 1), [1, M, 1]),
                    [N * M, JQ])
                first_cell = AttentionCell(
                    cell,
                    u,
                    size=d,
                    mask=q_mask,
                    mapper='sim',
                    input_keep_prob=self.config.input_keep_prob,
                    is_train=self.is_train)
            else:
                p0 = attention_layer(config,
                                     self.is_train,
                                     h,
                                     u,
                                     h_mask=self.x_mask,
                                     u_mask=self.q_mask,
                                     scope="p0",
                                     tensor_dict=self.tensor_dict)
                first_cell = d_cell
            tp0 = p0

        # Modeling layer (5th layer on paper)
        with tf.variable_scope('modeling_layer'):
            if config.use_fused_lstm:
                g1, encoder_state_final = build_fused_bidirectional_rnn(
                    inputs=p0,
                    num_units=config.hidden_size,
                    num_layers=config.num_modeling_layers,
                    inputs_length=flat_x_len,
                    input_keep_prob=config.input_keep_prob,
                    scope='modeling_layer_g')

            else:
                for layer_idx in range(config.num_modeling_layers - 1):
                    (fw_g0, bw_g0), _ = bidirectional_dynamic_rnn(
                        first_cell,
                        first_cell,
                        p0,
                        x_len,
                        dtype='float',
                        scope="g_{}".format(layer_idx))  # [N, M, JX, 2d]
                    p0 = tf.concat([fw_g0, bw_g0], 3)
                (fw_g1, bw_g1), (fw_s_f, bw_s_f) = bidirectional_dynamic_rnn(
                    first_cell,
                    first_cell,
                    p0,
                    x_len,
                    dtype='float',
                    scope='g1')  # [N, M, JX, 2d]
                g1 = tf.concat([fw_g1, bw_g1], 3)  # [N, M, JX, 2d]

        # Self match layer
        if config.use_self_match:
            s0 = tf.reshape(g1, [N * M, JX, 2 * d])  # [N * M, JX, 2d]
            x_mask = tf.reshape(self.x_mask, [N * M, JX])  # [N * M, JX]
            if config.use_static_self_match:
                with tf.variable_scope(
                        "StaticSelfMatch"
                ):  # implemented follow r-net section 3.3
                    W_x_Vj = tf.contrib.layers.fully_connected(  # [N * M, JX, d]
                        s0,
                        int(d / 2),
                        scope='row_first',
                        activation_fn=None,
                        biases_initializer=None)
                    W_x_Vt = tf.contrib.layers.fully_connected(  # [N * M, JX, d]
                        s0,
                        int(d / 2),
                        scope='col_first',
                        activation_fn=None,
                        biases_initializer=None)
                    sum_rc = tf.add(  # [N * M, JX, JX, d]
                        tf.expand_dims(W_x_Vj, 1), tf.expand_dims(W_x_Vt, 2))
                    v = tf.get_variable('second',
                                        shape=[1, 1, 1, int(d / 2)],
                                        dtype=tf.float32)
                    Sj = tf.reduce_sum(tf.multiply(v, tf.tanh(sum_rc)),
                                       -1)  # [N * M, JX, JX]
                    Ai = softmax(Sj, mask=tf.expand_dims(x_mask,
                                                         1))  # [N * M, JX, JX]
                    Ai = tf.expand_dims(Ai, -1)  # [N * M, JX, JX, 1]
                    Vi = tf.expand_dims(s0, 1)  # [N * M, 1, JX, 2d]
                    Ct = tf.reduce_sum(  # [N * M, JX, 2d]
                        tf.multiply(Ai, Vi), axis=2)
                    inputs_Vt_Ct = tf.concat([s0, Ct], 2)  # [N * M, JX, 4d]
                    if config.use_fused_lstm:
                        fw_inputs = tf.transpose(
                            inputs_Vt_Ct,
                            [1, 0, 2])  # [time_len, batch_size, input_size]
                        bw_inputs = tf.reverse_sequence(fw_inputs,
                                                        flat_x_len,
                                                        batch_dim=1,
                                                        seq_dim=0)
                        fw_inputs = tf.nn.dropout(fw_inputs,
                                                  config.input_keep_prob)
                        bw_inputs = tf.nn.dropout(bw_inputs,
                                                  config.input_keep_prob)
                        prep_fw_cell = LSTMBlockFusedCell(d, cell_clip=0)
                        prep_bw_cell = LSTMBlockFusedCell(d, cell_clip=0)
                        fw_outputs, fw_s_f = prep_fw_cell(
                            fw_inputs,
                            dtype=tf.float32,
                            sequence_length=flat_x_len,
                            scope="fw")
                        bw_outputs, bw_s_f = prep_bw_cell(
                            bw_inputs,
                            dtype=tf.float32,
                            sequence_length=flat_x_len,
                            scope="bw")
                        fw_s_f = LSTMStateTuple(c=fw_s_f[0], h=fw_s_f[1])
                        bw_s_f = LSTMStateTuple(c=bw_s_f[0], h=bw_s_f[1])
                        bw_outputs = tf.reverse_sequence(bw_outputs,
                                                         flat_x_len,
                                                         batch_dim=1,
                                                         seq_dim=0)
                        current_inputs = tf.concat((fw_outputs, bw_outputs), 2)
                        s1 = tf.transpose(current_inputs, [1, 0, 2])
                    else:
                        (fw_s, bw_s), (fw_s_f,
                                       bw_s_f) = bidirectional_dynamic_rnn(
                                           first_cell,
                                           first_cell,
                                           inputs_Vt_Ct,
                                           flat_x_len,
                                           dtype='float',
                                           scope='s')  # [N, M, JX, 2d]
                        s1 = tf.concat([fw_s, bw_s],
                                       2)  # [N * M, JX, 2d], M == 1
            else:
                with tf.variable_scope("DynamicSelfMatch"):
                    first_cell = AttentionCell(cell,
                                               s0,
                                               size=d,
                                               mask=x_mask,
                                               is_train=self.is_train)
                    (fw_s, bw_s), (fw_s_f, bw_s_f) = bidirectional_dynamic_rnn(
                        first_cell,
                        first_cell,
                        s0,
                        x_len,
                        dtype='float',
                        scope='s')  # [N, M, JX, 2d]
                    s1 = tf.concat([fw_s, bw_s], 2)  # [N * M, JX, 2d], M == 1
            g1 = tf.expand_dims(s1, 1)  # [N, M, JX, 2d]

        # prepare for PtrNet
        encoder_output = g1  # [N, M, JX, 2d]
        encoder_output = tf.expand_dims(tf.cast(self.x_mask, tf.float32),
                                        -1) * encoder_output  # [N, M, JX, 2d]

        if config.use_self_match or not config.use_fused_lstm:
            if config.GRU:
                encoder_state_final = tf.concat((fw_s_f, bw_s_f),
                                                1,
                                                name='encoder_concat')
            else:
                if isinstance(fw_s_f, LSTMStateTuple):
                    encoder_state_c = tf.concat((fw_s_f.c, bw_s_f.c),
                                                1,
                                                name='encoder_concat_c')
                    encoder_state_h = tf.concat((fw_s_f.h, bw_s_f.h),
                                                1,
                                                name='encoder_concat_h')
                    encoder_state_final = LSTMStateTuple(c=encoder_state_c,
                                                         h=encoder_state_h)
                elif isinstance(fw_s_f, tf.Tensor):
                    encoder_state_final = tf.concat((fw_s_f, bw_s_f),
                                                    1,
                                                    name='encoder_concat')
                else:
                    encoder_state_final = None
                    tf.logging.error("encoder_state_final not set")

        print("encoder_state_final:", encoder_state_final)

        with tf.variable_scope("output"):
            # eos_symbol = config.eos_symbol
            # next_symbol = config.next_symbol

            tf.assert_equal(
                M,
                1)  # currently dynamic M is not supported, thus we assume M==1
            answer_string = tf.placeholder(
                shape=(N, 1, JA + 1), dtype=tf.int32,
                name='answer_string')  # [N, M, JA + 1]
            answer_string_mask = tf.placeholder(
                shape=(N, 1, JA + 1), dtype=tf.bool,
                name='answer_string_mask')  # [N, M, JA + 1]
            answer_string_length = tf.placeholder(
                shape=(N, 1),
                dtype=tf.int32,
                name='answer_string_length',
            )  # [N, M]
            self.tensor_dict['answer_string'] = answer_string
            self.tensor_dict['answer_string_mask'] = answer_string_mask
            self.tensor_dict['answer_string_length'] = answer_string_length
            self.answer_string = answer_string
            self.answer_string_mask = answer_string_mask
            self.answer_string_length = answer_string_length

            answer_string_flattened = tf.reshape(answer_string,
                                                 [N * M, JA + 1])
            self.answer_string_flattened = answer_string_flattened  # [N * M, JA+1]
            print("answer_string_flattened:", answer_string_flattened)

            answer_string_length_flattened = tf.reshape(
                answer_string_length, [N * M])
            self.answer_string_length_flattened = answer_string_length_flattened  # [N * M]
            print("answer_string_length_flattened:",
                  answer_string_length_flattened)

            decoder_cell = GRUCell(2 * d) if config.GRU else BasicLSTMCell(
                2 * d, state_is_tuple=True)

            with tf.variable_scope("Decoder"):
                decoder_train_logits = ptr_decoder(
                    decoder_cell,
                    tf.reshape(tp0, [N * M, JX, 2 * d]),  # [N * M, JX, 2d]
                    tf.reshape(encoder_output,
                               [N * M, JX, 2 * d]),  # [N * M, JX, 2d]
                    flat_x_len,
                    encoder_final_state=encoder_state_final,
                    max_encoder_length=config.sent_size_th,
                    decoder_output_length=
                    answer_string_length_flattened,  # [N * M]
                    batch_size=N,  # N * M (M=1)
                    attention_proj_dim=self.config.decoder_proj_dim,
                    scope='ptr_decoder'
                )  # [batch_size, dec_len*, enc_seq_len + 1]

                self.decoder_train_logits = decoder_train_logits
                print("decoder_train_logits:", decoder_train_logits)
                self.decoder_train_softmax = tf.nn.softmax(
                    self.decoder_train_logits)
                self.decoder_inference = tf.argmax(
                    decoder_train_logits, axis=2,
                    name='decoder_inference')  # [N, JA + 1]

            self.yp = tf.ones([N, M, JX], dtype=tf.int32) * -1
            self.yp2 = tf.ones([N, M, JX], dtype=tf.int32) * -1
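For reference, a minimal sketch of the weight-sharing trick used in the "prepro" block above, assuming TF 1.x: one fused cell is reused for question and context by re-entering the same variable scope with reuse enabled; shapes are illustrative:

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMBlockFusedCell

q = tf.placeholder(tf.float32, [30, 8, 100])   # time-major question
c = tf.placeholder(tf.float32, [400, 8, 100])  # time-major context
q_len = tf.placeholder(tf.int32, [8])
x_len = tf.placeholder(tf.int32, [8])

cell = LSTMBlockFusedCell(100)
with tf.variable_scope("u1"):
    u, _ = cell(q, dtype=tf.float32, sequence_length=q_len, scope="fw")
tf.get_variable_scope().reuse_variables()
with tf.variable_scope("u1"):
    h, _ = cell(c, dtype=tf.float32, sequence_length=x_len, scope="fw")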