def encode(cell_factory, query, query_length, document, document_length):
    """ DCN+ deep residual coattention encoder.
    
    Encodes query-document pairs into a joint document-query representation in document space.

    Args:  
        cell_factory: Function of zero arguments returning an RNNCell.
        query: Tensor of rank 3, shape [N, Q, R].  
        query_length: Tensor of rank 1, shape [N]. Lengths of queries.  
        document: Tensor of rank 3, shape [N, D, R].  
        document_length: Tensor of rank 1, shape [N]. Lengths of documents.  
    
    Returns:  
        Merged representation of query and document in document space, shape [N, D, 2H].
    """

    with tf.variable_scope('initial_encoder'):
        query_encoding, document_encoding = query_document_encoder(
            cell_factory(), cell_factory(),
            query, query_length,
            document, document_length
        )
        query_encoding = tf.layers.dense(
            query_encoding, 
            query_encoding.get_shape()[2], 
            activation=tf.tanh
        )
    
    with tf.variable_scope('coattention_1'):
        summary_q_1, summary_d_1, coattention_d_1 = coattention(
            query_encoding, query_length,
            document_encoding, document_length,
            sentinel=True
        )
    
    with tf.variable_scope('summary_encoder'):
        summary_q_encoding, summary_d_encoding = query_document_encoder(
            cell_factory(), cell_factory(),
            summary_q_1, query_length,
            summary_d_1, document_length
        )
    
    with tf.variable_scope('coattention_2'):
        _, summary_d_2, coattention_d_2 = coattention(
            summary_q_encoding, query_length,
            summary_d_encoding, document_length
        )

    document_representations = [
        document_encoding,  # E^D_1
        summary_d_encoding, # E^D_2
        summary_d_1,        # S^D_1
        summary_d_2,        # S^D_2
        coattention_d_1,    # C^D_1
        coattention_d_2,    # C^D_2
    ]

    with tf.variable_scope('final_encoder'):
        document_representation = convert_gradient_to_tensor(tf.concat(document_representations, 2))
        outputs, _ = tf.nn.bidirectional_dynamic_rnn(
            cell_fw = cell_factory(),
            cell_bw = cell_factory(),
            dtype = tf.float32,
            inputs = document_representation,
            sequence_length = document_length,
        )
        encoding = convert_gradient_to_tensor(tf.concat(outputs, 2))
    return encoding
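
# A hypothetical usage sketch (the names below are illustrative and not part of
# this module): `query_emb` and `document_emb` would be embedded token
# sequences, and `cell_factory` builds a fresh LSTM cell for each encoder layer.
#
#     cell_factory = lambda: tf.nn.rnn_cell.LSTMCell(num_units=200)
#     encoding = encode(cell_factory, query_emb, query_length,
#                       document_emb, document_length)  # [N, D, 2H]
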
def loss(logits, answer_span, max_iter):
    """ Calulates cumulative loss over the iterations

    Args:  
        logits: TensorArray of size max_iter containing Tensors of rank 3, shape [N, D, 2]. Logits of  
        the start and end of the answer span.  
        answer_span: Integer placeholder, shape [N, 2]. Indices of the true answer spans.  
        max_iter: Scalar integer. Maximum number of iterations the decoder is run for.

    Returns:  
        Mean cross entropy loss across iterations and batch. The mean is used instead of the sum so  
        that the loss stays on the same scale as more traditional single-iteration methods.
    """
    # Concatenate the per-iteration logits along the batch axis: [max_iter * N, D, 2].
    logits = convert_gradient_to_tensor(logits.concat())
    # Repeat the labels once per iteration to match: [max_iter * N, 2].
    answer_span_repeated = tf.tile(answer_span, (max_iter, 1))

    start_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits[:, :, 0], labels=answer_span_repeated[:, 0], name='start_loss')
    end_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits[:, :, 1], labels=answer_span_repeated[:, 1], name='end_loss')
    # Regroup the per-iteration losses by example: [N, max_iter].
    start_loss = tf.stack(tf.split(start_loss, max_iter), axis=1)
    end_loss = tf.stack(tf.split(end_loss, max_iter), axis=1)

    loss_per_example = tf.reduce_mean(start_loss + end_loss, axis=1)
    loss = tf.reduce_mean(loss_per_example)
    return loss
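
# A minimal sketch of how `loss` might be fed (assumed names): the decoder
# writes one [N, D, 2] logits Tensor per iteration into a TensorArray.
#
#     logits_ta = tf.TensorArray(tf.float32, size=max_iter)
#     for i, iteration_logits in enumerate(per_iteration_logits):
#         logits_ta = logits_ta.write(i, iteration_logits)
#     cross_entropy = loss(logits_ta, answer_span, max_iter)
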
def decoder_body(encoding, state, answer, state_size, pool_size, keep_prob=1.0):
    """ Decoder feedforward network.  

    Calculates answer span start and end logits.  

    Args:  
        encoding: Tensor of rank 3, shape [N, D, xH]. Query-document encoding.  
        state: Tensor of rank 2, shape [N, C]. Current state of the decoder state machine.  
        answer: Tensor of rank 2, shape [N, 2]. Current iteration's answer.  
        state_size: Scalar integer. Hidden units of highway maxout network.  
        pool_size: Scalar integer. Number of units that are max pooled in maxout network.  
        keep_prob: Scalar float. Input dropout keep probability for maxout layers.
    
    Returns:  
        Tensor of rank 3, shape [N, D, 2]. Answer span logits for answer start and end.
    """
    maxlen = tf.shape(encoding)[1]
    
    with tf.variable_scope('start'):
        span_encoding = start_and_end_encoding(encoding, answer)
        r_input = convert_gradient_to_tensor(tf.concat([state, span_encoding], axis=1))
        r = tf.layers.dense(r_input, state_size, use_bias=False, activation=tf.tanh)  # add dropout?
        r = tf.expand_dims(r, 1)
        r = tf.tile(r, (1, maxlen, 1))  # broadcast r to every document position
        highway_input = convert_gradient_to_tensor(tf.concat([encoding, r], 2))
        alpha = highway_maxout(highway_input, state_size, pool_size, keep_prob)

    with tf.variable_scope('end'):
        # Condition the end logits on the freshly predicted start position.
        updated_start = tf.argmax(alpha, axis=1, output_type=tf.int32)
        updated_answer = tf.stack([updated_start, answer[:, 1]], axis=1)
        span_encoding = start_and_end_encoding(encoding, updated_answer)
        r_input = convert_gradient_to_tensor(tf.concat([state, span_encoding], axis=1))
        r = tf.layers.dense(r_input, state_size, use_bias=False, activation=tf.tanh)
        r = tf.expand_dims(r, 1)
        r = tf.tile(r, (1, maxlen, 1))
        highway_input = convert_gradient_to_tensor(tf.concat([encoding, r], 2))
        beta = highway_maxout(highway_input, state_size, pool_size, keep_prob)
    
    return tf.stack([alpha, beta], axis=2)
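
# A simplified sketch, not this repository's decoder, of the iterative loop the
# DCN+ paper runs around `decoder_body`: an LSTM state machine re-estimates the
# span for `max_iter` iterations, reusing the decoder variables after the first
# pass. `lstm_cell`, `initial_answer`, `logits_ta` and `batch_size` are assumed
# names.
#
#     state = lstm_cell.zero_state(batch_size, tf.float32)
#     answer = initial_answer  # e.g. tf.zeros((batch_size, 2), tf.int32)
#     for i in range(max_iter):
#         with tf.variable_scope('decoder', reuse=i > 0):
#             span_encoding = start_and_end_encoding(encoding, answer)
#             _, state = lstm_cell(span_encoding, state)
#             logits = decoder_body(encoding, state.h, answer,
#                                   state_size, pool_size, keep_prob)
#             answer = tf.argmax(logits, axis=1, output_type=tf.int32)  # [N, 2]
#             logits_ta = logits_ta.write(i, logits)
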
def query_document_encoder(cell_fw, cell_bw, query, query_length, document, document_length):
    """ DCN+ Query Document Encoder layer.
    
    Forward and backward cells are shared between the bidirectional query and document encoders.  

    Args:  
        cell_fw: RNNCell for forward direction encoding.  
        cell_bw: RNNCell for backward direction encoding.  
        query: Tensor of rank 3, shape [N, Q, ?].  
        query_length: Tensor of rank 1, shape [N]. Lengths of queries.  
        document: Tensor of rank 3, shape [N, D, ?].  
        document_length: Tensor of rank 1, shape [N]. Lengths of documents.  

    Returns:  
        A tuple containing  
            encoding of query, shape [N, Q, 2H].  
            encoding of document, shape [N, D, 2H].
    """
    query_fw_bw_encodings, _ = tf.nn.bidirectional_dynamic_rnn(
        cell_fw = cell_fw,
        cell_bw = cell_bw,
        dtype = tf.float32,
        inputs = query,
        sequence_length = query_length
    )
    query_encoding = convert_gradient_to_tensor(tf.concat(query_fw_bw_encodings, 2))

    document_fw_bw_encodings, _ = tf.nn.bidirectional_dynamic_rnn(
        cell_fw = cell_fw,
        cell_bw = cell_bw,
        dtype = tf.float32,
        inputs = document,
        sequence_length = document_length
    )

    document_encoding = convert_gradient_to_tensor(tf.concat(document_fw_bw_encodings, 2))

    return query_encoding, document_encoding
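
# `convert_gradient_to_tensor`, used throughout this file, is not defined here.
# A common definition (for example, the one in tensor2tensor) is an identity on
# the forward pass whose gradient is forced to be a dense Tensor, avoiding the
# IndexedSlices gradients that concat-heavy graphs can otherwise produce:
#
#     from tensorflow.python.framework import function
#
#     @function.Defun(
#         python_grad_func=lambda x, dy: tf.convert_to_tensor(dy),
#         shape_func=lambda op: [op.inputs[0].get_shape()])
#     def convert_gradient_to_tensor(x):
#         return x
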
def concat_sentinel(sentinel_name, other_tensor):
    """ Left concatenates a sentinel vector along `other_tensor`'s second dimension.

    Args:  
        sentinel_name: Variable name of sentinel.  
        other_tensor: Tensor of rank 3 to left concatenate sentinel to.  

    Returns:  
        `other_tensor` with the sentinel prepended along its second dimension.
    """
    # Trainable sentinel with the same feature size as `other_tensor`.
    sentinel = tf.get_variable(sentinel_name, other_tensor.get_shape()[2], tf.float32)
    # Reshape to [1, 1, feature] and tile across the batch before prepending.
    sentinel = tf.reshape(sentinel, (1, 1, -1))
    sentinel = tf.tile(sentinel, (tf.shape(other_tensor)[0], 1, 1))
    other_tensor = convert_gradient_to_tensor(tf.concat([sentinel, other_tensor], 1))
    return other_tensor
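
# The `coattention` op called from `encode` is defined elsewhere. A minimal
# sketch following the DCN+ paper (length masking and sentinel removal are
# omitted for brevity, so this is illustrative rather than drop-in):
#
#     def coattention(query, query_length, document, document_length, sentinel=False):
#         if sentinel:
#             query = concat_sentinel('query_sentinel', query)
#             document = concat_sentinel('document_sentinel', document)
#         # Affinity between every document and query position: [N, D, Q].
#         affinity = tf.matmul(document, query, transpose_b=True)
#         attention_q = tf.nn.softmax(affinity, axis=1)  # over document positions
#         attention_d = tf.nn.softmax(affinity, axis=2)  # over query positions
#         summary_q = tf.matmul(attention_q, document, transpose_a=True)  # [N, Q, H]
#         summary_d = tf.matmul(attention_d, query)                       # [N, D, H]
#         coattention_d = tf.matmul(attention_d, summary_q)               # [N, D, H]
#         return summary_q, summary_d, coattention_d
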
def start_and_end_encoding(encoding, answer):
    """ Gathers the encodings representing the start and end of the answer span passed
    and concatenates the encodings.

    Args:  
        encoding: Tensor of rank 3, shape [N, D, xH]. Query-document encoding.  
        answer: Tensor of rank 2, shape [N, 2]. Answer span start and end indices.  
    
    Returns:
        Tensor of rank 2, shape [N, 2xH], containing the concatenated encodings of the start and end of the answer span.
    """
    batch_size = tf.shape(encoding)[0]
    start, end = answer[:, 0], answer[:, 1]
    encoding_start = tf.gather_nd(encoding, tf.stack([tf.range(batch_size), start], axis=1))  # May be causing UserWarning
    encoding_end = tf.gather_nd(encoding, tf.stack([tf.range(batch_size), end], axis=1))
    return convert_gradient_to_tensor(tf.concat([encoding_start, encoding_end], axis=1))
def highway_maxout(inputs, hidden_size, pool_size, keep_prob=1.0):
    """ Highway maxout network.

    Args:  
        inputs: Tensor of rank 3, shape [N, D, ?]. Inputs to network.  
        hidden_size: Scalar integer. Hidden units of highway maxout network.  
        pool_size: Scalar integer. Number of units that are max pooled in maxout layer.  
        keep_prob: Scalar float. Input dropout keep probability for maxout layers.  

    Returns:  
        Tensor of rank 2, shape [N, D]. Logits.
    """
    layer1 = maxout_layer(inputs, hidden_size, pool_size, keep_prob)
    layer2 = maxout_layer(layer1, hidden_size, pool_size, keep_prob)
    
    highway = convert_gradient_to_tensor(tf.concat([layer1, layer2], -1))
    output = maxout_layer(highway, 1, pool_size)
    output = tf.squeeze(output, -1)
    return output
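
# `maxout_layer` is defined elsewhere; a minimal sketch of the standard maxout
# construction it presumably follows (an affine map to `outputs * pool_size`
# units, max-pooled in groups of `pool_size`), assuming rank-3 inputs as above:
#
#     def maxout_layer(inputs, outputs, pool_size, keep_prob=1.0):
#         inputs = tf.nn.dropout(inputs, keep_prob)
#         pool = tf.layers.dense(inputs, outputs * pool_size)
#         shape = tf.shape(pool)
#         pool = tf.reshape(pool, (shape[0], shape[1], outputs, pool_size))
#         return tf.reduce_max(pool, axis=-1)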