def encode(cell_factory, query, query_length, document, document_length):
    """ DCN+ deep residual coattention encoder.

    Encodes query-document pairs into a single query-aware representation
    in document space.

    Args:
        cell_factory: Function of zero arguments returning an RNNCell.
        query: Tensor of rank 3, shape [N, Q, R].
        query_length: Tensor of rank 1, shape [N]. Lengths of queries.
        document: Tensor of rank 3, shape [N, D, R].
        document_length: Tensor of rank 1, shape [N]. Lengths of documents.

    Returns:
        Merged representation of query and document in document space, shape [N, D, 2H].
    """
    with tf.variable_scope('initial_encoder'):
        query_encoding, document_encoding = query_document_encoder(
            cell_factory(), cell_factory(), query, query_length, document, document_length)
        # Non-linear projection of the query encoding that keeps its width, presumably so
        # it can still be multiplied against the document encoding inside `coattention`.
        # NOTE(review): the arguments of this call were truncated in the source;
        # reconstructed as a same-width tanh projection (as in the DCN paper) -- confirm.
        query_encoding = tf.layers.dense(
            query_encoding,
            query_encoding.get_shape().as_list()[2],
            activation=tf.tanh,
        )

    with tf.variable_scope('coattention_1'):
        summary_q_1, summary_d_1, coattention_d_1 = coattention(
            query_encoding, query_length, document_encoding, document_length, sentinel=True)

    with tf.variable_scope('summary_encoder'):
        # Second encoding layer runs over the first layer's coattention summaries.
        summary_q_encoding, summary_d_encoding = query_document_encoder(
            cell_factory(), cell_factory(), summary_q_1, query_length, summary_d_1, document_length)

    with tf.variable_scope('coattention_2'):
        _, summary_d_2, coattention_d_2 = coattention(
            summary_q_encoding, query_length, summary_d_encoding, document_length)

    # Residual connections: every intermediate document-space representation is
    # concatenated along the feature axis and fed to the final encoder.
    document_representations = [
        document_encoding,   # E^D_1
        summary_d_encoding,  # E^D_2
        summary_d_1,         # S^D_1
        summary_d_2,         # S^D_2
        coattention_d_1,     # C^D_1
        coattention_d_2,     # C^D_2
    ]

    with tf.variable_scope('final_encoder'):
        document_representation = convert_gradient_to_tensor(tf.concat(document_representations, 2))
        outputs, _ = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell_factory(),
            cell_bw=cell_factory(),
            dtype=tf.float32,
            inputs=document_representation,
            sequence_length=document_length,
        )
        # Join forward and backward outputs along the feature axis.
        encoding = convert_gradient_to_tensor(tf.concat(outputs, 2))
    return encoding
def loss(logits, answer_span, max_iter):
    """ Calculates cumulative loss over the decoder iterations.

    Args:
        logits: TensorArray of size max_iter whose elements are Tensors of rank 3,
            shape [N, D, 2], holding the start and end logits of the answer span.
        answer_span: Integer placeholder containing indices of true answer spans, shape [N, 2].
        max_iter: Scalar integer. Maximum number of iterations the decoder is run.

    Returns:
        Mean cross entropy loss across iterations and batch. The mean is used instead
        of the sum to keep the loss on the same scale as more traditional methods.
    """
    # TensorArray.concat joins its elements along axis 0 in index order, giving
    # [max_iter * N, D, 2] laid out iteration-major (all of iteration 0, then 1, ...).
    logits = convert_gradient_to_tensor(logits.concat())
    # Repeat the labels in the same iteration-major order: [max_iter * N, 2].
    answer_span_repeated = tf.tile(answer_span, (max_iter, 1))

    start_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits[:, :, 0], labels=answer_span_repeated[:, 0], name='start_loss')
    end_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits[:, :, 1], labels=answer_span_repeated[:, 1], name='end_loss')

    # Undo the iteration-major flattening: max_iter chunks of [N] -> [N, max_iter].
    start_loss = tf.stack(tf.split(start_loss, max_iter), axis=1)
    end_loss = tf.stack(tf.split(end_loss, max_iter), axis=1)

    loss_per_example = tf.reduce_mean(start_loss + end_loss, axis=1)
    return tf.reduce_mean(loss_per_example)
def decoder_body(encoding, state, answer, state_size, pool_size, keep_prob=1.0):
    """ Decoder feedforward network.

    Calculates answer span start and end logits.

    Args:
        encoding: Tensor of rank 3, shape [N, D, xH]. Query-document encoding.
        state: Tensor of rank 2, shape [N, C]. Current state of decoder state machine.
        answer: Tensor of rank 2, shape [N, 2]. Current iteration's answer.
        state_size: Scalar integer. Hidden units of highway maxout network.
        pool_size: Scalar integer. Number of units that are max pooled in maxout network.
        keep_prob: Scalar float. Input dropout keep probability for maxout layers.

    Returns:
        Tensor of rank 3, shape [N, D, 2]. Answer span logits for answer start and end.
    """
    maxlen = tf.shape(encoding)[1]

    def _span_logits(current_answer):
        # Scores every document position from the decoder state plus the encodings at
        # the current answer's start and end. Called once per variable scope, so the
        # start and end networks each get their own weights.
        span_encoding = start_and_end_encoding(encoding, current_answer)
        r_input = convert_gradient_to_tensor(tf.concat([state, span_encoding], axis=1))
        r = tf.layers.dense(r_input, state_size, use_bias=False, activation=tf.tanh)  # TODO: add dropout?
        # Broadcast r to every document position: [N, state_size] -> [N, D, state_size].
        r = tf.tile(tf.expand_dims(r, 1), (1, maxlen, 1))
        highway_input = convert_gradient_to_tensor(tf.concat([encoding, r], 2))
        return highway_maxout(highway_input, state_size, pool_size, keep_prob)

    with tf.variable_scope('start'):
        alpha = _span_logits(answer)

    with tf.variable_scope('end'):
        # The end network is conditioned on the start position just predicted from
        # alpha rather than on the previous iteration's start.
        updated_start = tf.argmax(alpha, axis=1, output_type=tf.int32)
        updated_answer = tf.stack([updated_start, answer[:, 1]], axis=1)
        beta = _span_logits(updated_answer)

    return tf.stack([alpha, beta], axis=2)
def query_document_encoder(cell_fw, cell_bw, query, query_length, document, document_length):
    """ DCN+ Query Document Encoder layer.

    Forward and backward cells are shared between the bidirectional query and document encoders.

    Args:
        cell_fw: RNNCell for forward direction encoding.
        cell_bw: RNNCell for backward direction encoding.
        query: Tensor of rank 3, shape [N, Q, ?].
        query_length: Tensor of rank 1, shape [N]. Lengths of queries.
        document: Tensor of rank 3, shape [N, D, ?].
        document_length: Tensor of rank 1, shape [N]. Lengths of documents.

    Returns:
        A tuple containing
            encoding of query, shape [N, Q, 2H].
            encoding of document, shape [N, D, 2H].
    """
    def _bidirectional_encode(inputs, lengths):
        # The same cell objects are passed for both sequences, so query and
        # document are encoded with shared weights.
        fw_bw_encodings, _ = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell_fw,
            cell_bw=cell_bw,
            dtype=tf.float32,
            inputs=inputs,
            sequence_length=lengths,
        )
        # Join forward and backward outputs along the feature axis.
        return convert_gradient_to_tensor(tf.concat(fw_bw_encodings, 2))

    query_encoding = _bidirectional_encode(query, query_length)
    document_encoding = _bidirectional_encode(document, document_length)
    return query_encoding, document_encoding
def concat_sentinel(sentinel_name, other_tensor):
    """ Left concatenates a sentinel vector along `other_tensor`'s second dimension.

    Args:
        sentinel_name: Variable name of sentinel.
        other_tensor: Tensor of rank 3, shape [N, T, R], to left concatenate sentinel to.
            R must be statically known because it sizes the sentinel variable.

    Returns:
        other_tensor with the sentinel prepended, shape [N, T + 1, R].
    """
    # One trainable vector of length R, shared across the whole batch.
    sentinel = tf.get_variable(sentinel_name, [other_tensor.get_shape()[2]], tf.float32)
    sentinel = tf.reshape(sentinel, (1, 1, -1))
    sentinel = tf.tile(sentinel, (tf.shape(other_tensor)[0], 1, 1))
    return convert_gradient_to_tensor(tf.concat([sentinel, other_tensor], 1))
def start_and_end_encoding(encoding, answer):
    """ Gathers the encodings representing the start and end of the answer span passed
    and concatenates the encodings.

    Args:
        encoding: Tensor of rank 3, shape [N, D, xH]. Query-document encoding.
        answer: Integer Tensor of rank 2, shape [N, 2]. Answer span as (start, end) indices.

    Returns:
        Tensor of rank 2 [N, 2xH], containing the encodings of the start and end of the answer span.
    """
    batch_index = tf.range(tf.shape(encoding)[0])
    start, end = answer[:, 0], answer[:, 1]
    # Pair each example's row index with its position to pick one vector per example.
    # NOTE: gather_nd gradients are sparse (IndexedSlices); this may be the source of a
    # UserWarning, which the convert_gradient_to_tensor wrapper below presumably addresses.
    encoding_start = tf.gather_nd(encoding, tf.stack([batch_index, start], axis=1))
    encoding_end = tf.gather_nd(encoding, tf.stack([batch_index, end], axis=1))
    return convert_gradient_to_tensor(tf.concat([encoding_start, encoding_end], axis=1))
def highway_maxout(inputs, hidden_size, pool_size, keep_prob=1.0):
    """ Highway maxout network.

    Args:
        inputs: Tensor of rank 3, shape [N, D, ?]. Inputs to network.
        hidden_size: Scalar integer. Hidden units of highway maxout network.
        pool_size: Scalar integer. Number of units that are max pooled in maxout layer.
        keep_prob: Scalar float. Input dropout keep probability for maxout layers.

    Returns:
        Tensor of rank 2, shape [N, D]. Logits.
    """
    layer1 = maxout_layer(inputs, hidden_size, pool_size, keep_prob)
    layer2 = maxout_layer(layer1, hidden_size, pool_size, keep_prob)
    # Highway/skip connection: the output layer sees both hidden layers' activations.
    highway = convert_gradient_to_tensor(tf.concat([layer1, layer2], -1))
    # Single-unit output layer. keep_prob is deliberately not forwarded here, so this
    # layer uses maxout_layer's own default. NOTE(review): confirm that default is 1.0.
    output = maxout_layer(highway, 1, pool_size)
    # Drop the trailing unit axis: [N, D, 1] -> [N, D].
    return tf.squeeze(output, -1)