def encode(cell_factory, query, query_length, document, document_length):
    """
    DCN+ deep residual coattention encoder.

    Encodes query-document pairs into a single representation of the document
    conditioned on the query (i.e. expressed in document space).

    Args:
      cell_factory: Function of zero arguments returning an RNNCell. It is
        called twice (forward and backward) for each of the three encoder
        stages: initial, summary and final.
      query: Tensor of rank 3, shape [N, Q, R].
      query_length: Tensor of rank 1, shape [N]. Lengths of queries.
      document: Tensor of rank 3, shape [N, D, R].
      document_length: Tensor of rank 1, shape [N]. Lengths of documents.

    Returns:
      Merged representation of query and document in document space,
      shape [N, D, 2H] (concatenated forward/backward outputs of the final
      bidirectional RNN).
    """
    with tf.variable_scope('initial_encoder'):
        # One pair of cells is shared by the query and document BiRNNs.
        query_encoding, document_encoding = query_document_encoder(cell_factory(), cell_factory(), query, query_length, document, document_length)
        # tanh projection of the query encoding that keeps its last dimension
        # (units = query_encoding's static size on axis 2).
        query_encoding = tf.layers.dense(
            query_encoding,
            query_encoding.get_shape()[2],
            activation=tf.tanh
        )
    with tf.variable_scope('coattention_1'):
        # NOTE(review): sentinel=True is only passed to the first coattention
        # layer; presumably it prepends sentinel vectors (cf. concat_sentinel)
        # -- confirm in `coattention`.
        summary_q_1, summary_d_1, coattention_d_1 = coattention(query_encoding, query_length, document_encoding, document_length, sentinel=True)
    with tf.variable_scope('summary_encoder'):
        # Second encoding pass, this time over the first-layer summaries.
        summary_q_encoding, summary_d_encoding = query_document_encoder(cell_factory(), cell_factory(), summary_q_1, query_length, summary_d_1, document_length)
    with tf.variable_scope('coattention_2'):
        # The query-side summary of the second layer is not used further.
        _, summary_d_2, coattention_d_2 = coattention(summary_q_encoding, query_length, summary_d_encoding, document_length)
    # Residual connections: every intermediate document-space tensor is fed to
    # the final encoder, concatenated along the feature axis (axis 2).
    document_representations = [
        document_encoding,  # E^D_1
        summary_d_encoding,  # E^D_2
        summary_d_1,  # S^D_1
        summary_d_2,  # S^D_2
        coattention_d_1,  # C^D_1
        coattention_d_2,  # C^D_2
    ]
    with tf.variable_scope('final_encoder'):
        document_representation = convert_gradient_to_tensor(tf.concat(document_representations, 2))
        outputs, _ = tf.nn.bidirectional_dynamic_rnn(
            cell_fw = cell_factory(),
            cell_bw = cell_factory(),
            dtype = tf.float32,
            inputs = document_representation,
            sequence_length = document_length,
        )
        # Join forward and backward outputs along the feature axis.
        encoding = convert_gradient_to_tensor(tf.concat(outputs, 2))
    return encoding
def loss(logits, answer_span, max_iter):
    """
    Calculates the cumulative answer-span loss over all decoder iterations.

    Args:
      logits: TensorArray of size max_iter whose elements are Tensors of
        rank 3, shape [N, D, 2]. Index 0 of the last axis holds the logits of
        the answer-span start, index 1 those of the end.
      answer_span: Integer Tensor of rank 2, shape [N, 2]. Indices of the true
        answer span start (column 0) and end (column 1).
      max_iter: Scalar integer. Number of iterations the decoder is run; must
        equal the size of `logits`. It is used as the number of splits in
        `tf.split`, so it must be known statically.

    Returns:
      Scalar Tensor. Mean cross entropy across iterations and batch. The mean
      is used instead of the sum so the loss is on the same scale as more
      traditional single-pass methods.
    """
    # TensorArray.concat joins the per-iteration [N, D, 2] tensors along
    # axis 0, giving [max_iter * N, D, 2] ordered iteration by iteration.
    all_logits = convert_gradient_to_tensor(logits.concat())
    # Repeat the labels max_iter times in the same iteration-major order so
    # that row i of the labels matches row i of the logits.
    answer_span_repeated = tf.tile(answer_span, (max_iter, 1))
    start_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=all_logits[:, :, 0],
        labels=answer_span_repeated[:, 0],
        name='start_loss',
    )
    end_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=all_logits[:, :, 1],
        labels=answer_span_repeated[:, 1],
        name='end_loss',
    )
    # Regroup the flat [max_iter * N] losses into [N, max_iter].
    start_loss = tf.stack(tf.split(start_loss, max_iter), axis=1)
    end_loss = tf.stack(tf.split(end_loss, max_iter), axis=1)
    loss_per_example = tf.reduce_mean(start_loss + end_loss, axis=1)
    mean_loss = tf.reduce_mean(loss_per_example)
    return mean_loss
def _decoder_pointer_logits(encoding, state, answer, maxlen, state_size, pool_size, keep_prob):
    """
    Scores every document position for one boundary of the answer span.

    Builds r = tanh(W [state; u_start; u_end]) from the decoder state and the
    encodings at the current answer boundaries. It then tiles r over the
    document and runs [encoding; r] through a highway maxout network. Must be
    called inside the caller's variable scope so each boundary gets its own
    weights.

    Args:
      encoding: Tensor of rank 3, shape [N, D, xH]. Query-document encoding.
      state: Tensor of rank 2, shape [N, C]. Current decoder state.
      answer: Tensor of rank 2, shape [N, 2]. Answer span to condition on.
      maxlen: Scalar Tensor. Document length D, used to tile r.
      state_size: Scalar integer. Hidden units of the highway maxout network.
      pool_size: Scalar integer. Units max pooled in each maxout layer.
      keep_prob: Scalar float. Input dropout keep probability for maxout layers.

    Returns:
      Tensor of rank 2, shape [N, D]. Logits over document positions.
    """
    span_encoding = start_and_end_encoding(encoding, answer)
    r_input = convert_gradient_to_tensor(tf.concat([state, span_encoding], axis=1))
    # TODO(review): the original author considered adding dropout to r.
    r = tf.layers.dense(r_input, state_size, use_bias=False, activation=tf.tanh)
    # Broadcast the per-example summary r to every document position.
    r = tf.expand_dims(r, 1)
    r = tf.tile(r, (1, maxlen, 1))
    highway_input = convert_gradient_to_tensor(tf.concat([encoding, r], 2))
    return highway_maxout(highway_input, state_size, pool_size, keep_prob)


def decoder_body(encoding, state, answer, state_size, pool_size, keep_prob=1.0):
    """
    Decoder feedforward network. Calculates answer span start and end logits.

    Args:
      encoding: Tensor of rank 3, shape [N, D, xH]. Query-document encoding.
      state: Tensor of rank 2, shape [N, C]. Current state of the decoder
        state machine (concatenated with rank-2 span encodings on axis 1).
      answer: Tensor of rank 2, shape [N, 2]. Current iteration's answer.
      state_size: Scalar integer. Hidden units of highway maxout network.
      pool_size: Scalar integer. Number of units that are max pooled in the
        maxout network.
      keep_prob: Scalar float. Input dropout keep probability for maxout layers.

    Returns:
      Tensor of rank 3, shape [N, D, 2]. Answer span logits for answer start
      (index 0 of the last axis) and end (index 1).
    """
    # Computed once outside both scopes and shared by the two heads.
    maxlen = tf.shape(encoding)[1]
    with tf.variable_scope('start'):
        alpha = _decoder_pointer_logits(
            encoding, state, answer, maxlen, state_size, pool_size, keep_prob)
    with tf.variable_scope('end'):
        # The end head is conditioned on the freshly predicted start position
        # together with the previous end position.
        updated_start = tf.argmax(alpha, axis=1, output_type=tf.int32)
        updated_answer = tf.stack([updated_start, answer[:, 1]], axis=1)
        beta = _decoder_pointer_logits(
            encoding, state, updated_answer, maxlen, state_size, pool_size, keep_prob)
    return tf.stack([alpha, beta], axis=2)
def query_document_encoder(cell_fw, cell_bw, query, query_length, document, document_length):
    """
    DCN+ query/document encoder layer.

    Runs the same pair of forward/backward cells over the query and then over
    the document, so both sequences are encoded with shared weights.

    Args:
      cell_fw: RNNCell for the forward direction.
      cell_bw: RNNCell for the backward direction.
      query: Tensor of rank 3, shape [N, Q, ?].
      query_length: Tensor of rank 1, shape [N]. Lengths of queries.
      document: Tensor of rank 3, shape [N, D, ?].
      document_length: Tensor of rank 1, shape [N]. Lengths of documents.

    Returns:
      Tuple (query encoding of shape [N, Q, 2H], document encoding of
      shape [N, D, 2H]).
    """
    def run_shared_birnn(sequence, lengths):
        # Forward and backward outputs are joined on the feature axis.
        directional_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell_fw,
            cell_bw=cell_bw,
            dtype=tf.float32,
            inputs=sequence,
            sequence_length=lengths,
        )
        return convert_gradient_to_tensor(tf.concat(directional_outputs, 2))

    # Query first, then document: the cells' weights are created on the first pass.
    encoded_query = run_shared_birnn(query, query_length)
    encoded_document = run_shared_birnn(document, document_length)
    return encoded_query, encoded_document
def concat_sentinel(sentinel_name, other_tensor):
    """
    Prepends a learned sentinel vector to each sequence in `other_tensor`.

    The sentinel is a trainable float32 variable whose length equals the last
    dimension of `other_tensor`. It is tiled over the batch and concatenated
    in front of the existing entries along axis 1.

    Args:
      sentinel_name: Variable name of the sentinel.
      other_tensor: Tensor of rank 3. Its last dimension must be statically
        known because it sizes the sentinel variable.

    Returns:
      `other_tensor` with one extra leading position along axis 1.
    """
    feature_dim = other_tensor.get_shape()[2]
    sentinel_vector = tf.get_variable(sentinel_name, feature_dim, tf.float32)
    # [R] -> [1, 1, R] -> [N, 1, R]
    batch_of_sentinels = tf.tile(
        tf.reshape(sentinel_vector, (1, 1, -1)),
        (tf.shape(other_tensor)[0], 1, 1),
    )
    with_sentinel = tf.concat([batch_of_sentinels, other_tensor], 1)
    return convert_gradient_to_tensor(with_sentinel)
def start_and_end_encoding(encoding, answer):
    """
    Looks up the encodings at the answer's start and end and joins them.

    Args:
      encoding: Tensor of rank 3, shape [N, D, xH]. Query-document encoding.
      answer: Tensor of rank 2, shape [N, 2]. Column 0 holds start positions
        and column 1 holds end positions.

    Returns:
      Tensor of rank 2, shape [N, 2xH]: the start encoding followed by the
      end encoding for each example.
    """
    num_examples = tf.shape(encoding)[0]
    start_positions = answer[:, 0]
    end_positions = answer[:, 1]

    def encoding_at(positions):
        # Pair each example index with its position to address [batch, time].
        index_pairs = tf.stack([tf.range(num_examples), positions], axis=1)
        # NOTE(review): the original author suspected this gather_nd of
        # causing a UserWarning -- unverified.
        return tf.gather_nd(encoding, index_pairs)

    start_vectors = encoding_at(start_positions)
    end_vectors = encoding_at(end_positions)
    return convert_gradient_to_tensor(tf.concat([start_vectors, end_vectors], axis=1))
def highway_maxout(inputs, hidden_size, pool_size, keep_prob=1.0):
    """
    Highway maxout network producing one logit per position.

    Two stacked maxout layers are applied. Their outputs are concatenated on
    the last axis as a skip connection and scored by a final single-unit
    maxout layer. That final layer is called without `keep_prob`.

    Args:
      inputs: Tensor of rank 3, shape [N, D, ?]. Inputs to the network.
      hidden_size: Scalar integer. Hidden units of the maxout layers.
      pool_size: Scalar integer. Number of units max pooled in each layer.
      keep_prob: Scalar float. Input dropout keep probability for the two
        hidden maxout layers.

    Returns:
      Tensor of rank 2, shape [N, D]. Logits.
    """
    first_hidden = maxout_layer(inputs, hidden_size, pool_size, keep_prob)
    second_hidden = maxout_layer(first_hidden, hidden_size, pool_size, keep_prob)
    skip_features = convert_gradient_to_tensor(tf.concat([first_hidden, second_hidden], -1))
    # Drop the trailing size-1 unit axis to get [N, D].
    return tf.squeeze(maxout_layer(skip_features, 1, pool_size), -1)