def test_unflatten(self):
  with tf.Graph().as_default():
    tensor = tf.placeholder(tf.float32, [4, 7, 6, 3])
    w = tf.placeholder(tf.float32, [3, 9])
    flat_tensor, unflatten = tensor_utils.flatten(tensor)
    self.assertAllEqual(tensor_utils.shape(flat_tensor), [4 * 7 * 6, 3])
    flat_projected_tensor = tf.matmul(flat_tensor, w)
    projected_tensor = unflatten(flat_projected_tensor)
    self.assertAllEqual(tensor_utils.shape(projected_tensor), [4, 7, 6, 9])
def span_candidates(masks, max_span_width):
  """Generate span candidates.

  Args:
    masks: <int32> [num_retrievals, max_sequence_len]
    max_span_width: int

  Returns:
    starts: <int32> [num_spans]
    ends: <int32> [num_spans]
    span_masks: <int32> [num_retrievals, num_spans]
  """
  _, max_sequence_len = tensor_utils.shape(masks)

  def _spans_given_width(width):
    current_starts = tf.range(max_sequence_len - width + 1)
    current_ends = tf.range(width - 1, max_sequence_len)
    return current_starts, current_ends

  starts, ends = zip(*(_spans_given_width(w + 1)
                       for w in range(max_span_width)))

  # [num_spans]
  starts = tf.concat(starts, 0)
  ends = tf.concat(ends, 0)

  # [num_retrievals, num_spans]
  start_masks = tf.gather(masks, starts, axis=-1)
  end_masks = tf.gather(masks, ends, axis=-1)
  span_masks = start_masks * end_masks

  return starts, ends, span_masks
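
# Illustrative only: the enumeration order produced by `span_candidates` for
# a hypothetical max_sequence_len = 4 and max_span_width = 2. Width-1 spans
# come first, then width-2 spans:
#
#   starts = [0, 1, 2, 3, 0, 1, 2]
#   ends   = [0, 1, 2, 3, 1, 2, 3]
#
# A span survives in `span_masks` only when both its start and end tokens are
# unmasked, which is exactly what the `start_masks * end_masks` product
# computes.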
def test_shape_static(self):
  with tf.Graph().as_default():
    tensor = tf.placeholder(tf.int64, [4, 7])
    d0_single = tensor_utils.shape(tensor, 0)
    d1_single = tensor_utils.shape(tensor, 1)
    d0_full, d1_full = tensor_utils.shape(tensor)
    self.assertIsInstance(d0_single, int)
    self.assertIsInstance(d1_single, int)
    self.assertIsInstance(d0_full, int)
    self.assertIsInstance(d1_full, int)
    self.assertEqual(d0_single, 4)
    self.assertEqual(d1_single, 7)
    self.assertEqual(d0_full, 4)
    self.assertEqual(d1_full, 7)
def mask_attention(attention, seq_len1, seq_len2):
  """Masks an attention matrix.

  Args:
    attention: <tf.float32>[batch, seq_len1, seq_len2]
    seq_len1: <tf.int32>[batch]
    seq_len2: <tf.int32>[batch]

  Returns:
    the masked scores <tf.float32>[batch, seq_len1, seq_len2]
  """
  dim1 = tensor_utils.shape(attention, 1)
  dim2 = tensor_utils.shape(attention, 2)
  m1 = tf.sequence_mask(seq_len1, dim1)
  m2 = tf.sequence_mask(seq_len2, dim2)
  joint_mask = tf.logical_and(tf.expand_dims(m1, 2), tf.expand_dims(m2, 1))
  return ops.mask_logits(attention, joint_mask)
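
# A worked example (hypothetical shapes): with batch = 1, seq_len1 = [2],
# seq_len2 = [3], and attention of shape [1, 3, 4], `joint_mask` keeps
# entries (i, j) with i < 2 and j < 3:
#
#   joint_mask[0] = [[T, T, T, F],
#                    [T, T, T, F],
#                    [F, F, F, F]]
#
# The remaining positions are presumably pushed toward a large negative value
# by `ops.mask_logits` so they vanish under a downstream softmax.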
def test_shape_mixed(self):
  """Test for shape() with a mixture of static and dynamic sizes."""
  with tf.Graph().as_default():
    tensor = tf.placeholder(tf.int64, [4, None])
    d0_single = tensor_utils.shape(tensor, 0)
    d1_single = tensor_utils.shape(tensor, 1)
    d0_full, d1_full = tensor_utils.shape(tensor)
    self.assertIsInstance(d0_single, int)
    self.assertIsInstance(d1_single, tf.Tensor)
    self.assertIsInstance(d0_full, int)
    self.assertIsInstance(d1_full, tf.Tensor)
    self.assertEqual(d0_single, 4)
    self.assertEqual(d0_full, 4)
    with tf.Session() as sess:
      feed_dict = {tensor: np.zeros((4, 7))}
      tf_d1_single = sess.run(d1_single, feed_dict=feed_dict)
      self.assertEqual(tf_d1_single, 7)
      tf_d1_full = sess.run(d1_full, feed_dict=feed_dict)
      self.assertEqual(tf_d1_full, 7)
def test_shape_dynamic(self):
  with tf.Graph().as_default():
    tensor = tf.placeholder(tf.int64, [None, None])
    d0_single = tensor_utils.shape(tensor, 0)
    d1_single = tensor_utils.shape(tensor, 1)
    d0_full, d1_full = tensor_utils.shape(tensor)
    self.assertIsInstance(d0_single, tf.Tensor)
    self.assertIsInstance(d1_single, tf.Tensor)
    self.assertIsInstance(d0_full, tf.Tensor)
    self.assertIsInstance(d1_full, tf.Tensor)
    with tf.Session() as sess:
      feed_dict = {tensor: np.zeros((4, 7))}
      tf_d0_single, tf_d1_single = sess.run([d0_single, d1_single],
                                            feed_dict=feed_dict)
      self.assertEqual(tf_d0_single, 4)
      self.assertEqual(tf_d1_single, 7)
      tf_d0_full, tf_d1_full = sess.run([d0_full, d1_full],
                                        feed_dict=feed_dict)
      self.assertEqual(tf_d0_full, 4)
      self.assertEqual(tf_d1_full, 7)
def _bilinear_score(context_emb, question_emb):
  """Compute a bilinear score between the context and question embeddings.

  Args:
    context_emb: <float32> [batch_size, max_context_len, hidden_size]
    question_emb: <float32> [batch_size, hidden_size]

  Returns:
    scores: <float32> [batch_size, max_context_len]
  """
  # [batch_size, hidden_size]
  projected_question_emb = tf.layers.dense(
      question_emb, tensor_utils.shape(context_emb, -1))

  # [batch_size, max_context_len, 1]
  scores = tf.matmul(context_emb, tf.expand_dims(projected_question_emb, -1))
  return tf.squeeze(scores, -1)
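
# Equivalently, the bilinear score can be written with einsum (a sketch, not
# the implementation above):
#
#   scores[b, i] = sum_h context_emb[b, i, h] * projected_question_emb[b, h]
#   scores = tf.einsum("bih,bh->bi", context_emb, projected_question_emb)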
def cross_shard_pad(input_tensor):
  """Cross shard pad.

  Assuming `input_tensor` is replicated over different TPU cores across the
  zeroth dimension, this creates a global tensor with unique chunks per
  replica. This function only fills in the local `input_tensor` and pads the
  non-local part of the tensor with zeros. Does not actually do any
  cross-shard communication.

  Args:
    input_tensor: <int32|float32> [local_batch_size, dim1, dim2, ...]

  Returns:
    padded_tensor: <int32|float32>
      [local_batch_size * num_shards, dim1, dim2, ...]
  """
  num_shards = num_tpu_shards()

  # [num_shards]
  local_mask = tf.equal(tf.range(num_shards), shard_id())
  local_mask = tf.cast(local_mask, input_tensor.dtype)

  tensor_shape = tensor_utils.shape(input_tensor)
  local_batch_size = tensor_shape[0]
  global_batch_size = num_shards * local_batch_size

  # [num_shards, 1, 1, ...]
  for _ in tensor_shape:
    local_mask = tf.expand_dims(local_mask, -1)

  # [num_shards, local_batch_size, input_tensor_dim1, ...]
  padded_tensor = local_mask * tf.expand_dims(input_tensor, 0)

  # [global_batch_size, input_tensor_dim1, ...]
  padded_tensor = tf.reshape(padded_tensor,
                             [global_batch_size] + tensor_shape[1:])
  return padded_tensor
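
# A plausible use of `cross_shard_pad` (a sketch; the actual
# `cross_shard_concat` helper is defined elsewhere): since each replica
# contributes a distinct non-zero chunk and zeros everywhere else, summing
# the padded tensors across replicas is equivalent to concatenating the
# per-replica tensors along the zeroth dimension, e.g.
#
#   padded = cross_shard_pad(local_tensor)
#   global_tensor = tf.tpu.cross_replica_sum(padded)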
def batch_word_to_char_ids(words, word_length):
  """Batched version of word_to_char_ids.

  This is a deterministic function that should be computed during
  preprocessing. We pin this op to the CPU anyway to be safe, since it is
  slower on GPUs.

  Args:
    words: <string> [...]
    word_length: Number of bytes to include per word.

  Returns:
    char_ids: <int32> [..., word_length]
  """
  with tf.device("/cpu:0"):
    flat_words = tf.reshape(words, [-1])
    flat_char_ids = tf.map_fn(
        fn=partial(word_to_char_ids, word_length=word_length),
        elems=flat_words,
        dtype=tf.int32,
        back_prop=False)

  char_ids = tf.reshape(flat_char_ids,
                        tensor_utils.shape(words) + [word_length])
  return char_ids
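
# Usage sketch with hypothetical inputs: for a [batch, max_len] string
# tensor, the result appends a trailing character dimension.
#
#   words = tf.constant([["the", "cat"]])          # [1, 2]
#   char_ids = batch_word_to_char_ids(words, 10)   # [1, 2, 10]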
def model_fn(features, labels, mode, params):
  """Model function."""
  del labels

  # [local_batch_size, block_seq_len]
  block_ids = features["block_ids"]
  block_mask = features["block_mask"]
  block_segment_ids = features["block_segment_ids"]

  # [local_batch_size, query_seq_len]
  query_ids = features["query_ids"]
  query_mask = features["query_mask"]

  local_batch_size = tensor_utils.shape(block_ids, 0)
  tf.logging.info("Model batch size: %d", local_batch_size)

  ict_module = create_ict_module(params, mode)

  query_emb = ict_module(
      inputs=dict(
          input_ids=query_ids,
          input_mask=query_mask,
          segment_ids=tf.zeros_like(query_ids)),
      signature="projected")

  block_emb = ict_module(
      inputs=dict(
          input_ids=block_ids,
          input_mask=block_mask,
          segment_ids=block_segment_ids),
      signature="projected")

  if params["use_tpu"]:
    # [global_batch_size, hidden_size]
    block_emb = tpu_utils.cross_shard_concat(block_emb)

    # [global_batch_size, local_batch_size]
    labels = tpu_utils.cross_shard_pad(tf.eye(local_batch_size))

    # [local_batch_size]
    labels = tf.argmax(labels, 0)
  else:
    # [local_batch_size]
    labels = tf.range(local_batch_size)

  tf.logging.info("Global batch size: %s", tensor_utils.shape(block_emb, 0))

  # [batch_size, global_batch_size]
  logits = tf.matmul(query_emb, block_emb, transpose_b=True)

  # []
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  train_op = optimization.create_optimizer(
      loss=loss,
      init_lr=params["learning_rate"],
      num_train_steps=params["num_train_steps"],
      num_warmup_steps=min(10000,
                           max(100, int(params["num_train_steps"] / 10))),
      use_tpu=params["use_tpu"] if "use_tpu" in params else False)

  predictions = tf.argmax(logits, -1)

  metric_args = [
      query_mask, block_mask, labels, predictions, features["mask_query"]
  ]

  def metric_fn(query_mask, block_mask, labels, predictions, mask_query):
    masked_accuracy = tf.metrics.accuracy(
        labels=labels, predictions=predictions, weights=mask_query)
    unmasked_accuracy = tf.metrics.accuracy(
        labels=labels,
        predictions=predictions,
        weights=tf.logical_not(mask_query))
    return dict(
        query_non_padding=tf.metrics.mean(query_mask),
        block_non_padding=tf.metrics.mean(block_mask),
        actual_mask_ratio=tf.metrics.mean(mask_query),
        masked_accuracy=masked_accuracy,
        unmasked_accuracy=unmasked_accuracy)

  if params["use_tpu"]:
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metrics=(metric_fn, metric_args))
  else:
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=metric_fn(*metric_args),
        predictions=predictions)
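
# Illustration of the TPU label construction above (a sketch): with
# local_batch_size = 2 and num_shards = 2, the replica with shard_id = 1
# computes
#
#   tf.eye(2)            -> [[1, 0], [0, 1]]
#   cross_shard_pad(...) -> [[0, 0], [0, 0], [1, 0], [0, 1]]
#   tf.argmax(..., 0)    -> [2, 3]
#
# so each local query's positive block sits at offset
# shard_id * local_batch_size in the globally concatenated block embeddings,
# and every other row acts as an in-batch negative.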
def variational_dropout(x, dropout_rate, is_train):
  if is_train:
    shape = tensor_utils.shape(x)
    return tf.nn.dropout(x, 1.0 - dropout_rate, [shape[0], 1, shape[2]])
  else:
    return x
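
# Note on the noise shape: [shape[0], 1, shape[2]] samples one dropout mask
# per (batch, feature) pair and broadcasts it across the middle (time)
# dimension, so the same units are dropped at every time step. A minimal
# usage sketch with hypothetical shapes:
#
#   x = tf.random_normal([8, 20, 64])  # [batch, time, dim]
#   y = variational_dropout(x, dropout_rate=0.2, is_train=True)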
def decomposable_attention(emb1, len1, emb2, len2, hidden_size, hidden_layers,
                           dropout_ratio, mode, epsilon=1e-8):
  """See https://arxiv.org/abs/1606.01933.

  Args:
    emb1: A Tensor with shape [batch_size, max_len1, emb_size] representing
      the first input sequence.
    len1: A Tensor with shape [batch_size], indicating the true sequence
      length of `emb1`. This is required due to padding.
    emb2: A Tensor with shape [batch_size, max_len2, emb_size] representing
      the second input sequence.
    len2: A Tensor with shape [batch_size], indicating the true sequence
      length of `emb2`. This is required due to padding.
    hidden_size: An integer indicating the size of each hidden layer in the
      feed-forward neural networks.
    hidden_layers: An integer indicating the number of hidden layers in the
      feed-forward neural networks.
    dropout_ratio: The probability of dropping out each unit in the
      activation. This can be None, and is only applied during training.
    mode: One of the keys from tf.estimator.ModeKeys.
    epsilon: A small positive constant to add to masks for numerical
      stability.

  Returns:
    final_emb: A Tensor with shape [batch_size, hidden_size].
  """
  # [batch_size, maxlen1]
  mask1 = tf.sequence_mask(len1, tensor_utils.shape(emb1, 1), dtype=tf.float32)

  # [batch_size, maxlen2]
  mask2 = tf.sequence_mask(len2, tensor_utils.shape(emb2, 1), dtype=tf.float32)

  with tf.variable_scope("attend"):
    projected_emb1 = common_layers.ffnn(
        emb1, [hidden_size] * hidden_layers, dropout_ratio, mode)
  with tf.variable_scope("attend", reuse=True):
    projected_emb2 = common_layers.ffnn(
        emb2, [hidden_size] * hidden_layers, dropout_ratio, mode)

  # [batch_size, maxlen1, maxlen2]
  attention_scores = tf.matmul(projected_emb1, projected_emb2,
                               transpose_b=True)
  attention_weights1 = tf.nn.softmax(
      attention_scores + tf.log(tf.expand_dims(mask2, 1) + epsilon), 2)
  attention_weights2 = tf.nn.softmax(
      attention_scores + tf.log(tf.expand_dims(mask1, 2) + epsilon), 1)

  # [batch_size, maxlen1, emb_size]
  attended_emb1 = tf.matmul(attention_weights1, emb2)

  # [batch_size, maxlen2, emb_size]
  attended_emb2 = tf.matmul(attention_weights2, emb1, transpose_a=True)

  with tf.variable_scope("compare"):
    compared_emb1 = common_layers.ffnn(
        tf.concat([emb1, attended_emb1], -1),
        [hidden_size] * hidden_layers, dropout_ratio, mode)
  with tf.variable_scope("compare", reuse=True):
    compared_emb2 = common_layers.ffnn(
        tf.concat([emb2, attended_emb2], -1),
        [hidden_size] * hidden_layers, dropout_ratio, mode)

  compared_emb1 *= tf.expand_dims(mask1, -1)
  compared_emb2 *= tf.expand_dims(mask2, -1)

  # [batch_size, hidden_size]
  aggregated_emb1 = tf.reduce_sum(compared_emb1, 1)
  aggregated_emb2 = tf.reduce_sum(compared_emb2, 1)

  with tf.variable_scope("aggregate"):
    final_emb = common_layers.ffnn(
        tf.concat([aggregated_emb1, aggregated_emb2], -1),
        [hidden_size] * hidden_layers, dropout_ratio, mode)

  return final_emb
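
# The masked softmax above relies on the identity
#
#   softmax(scores + log(mask + epsilon))
#
# being approximately a softmax restricted to positions where mask = 1:
# log(0 + epsilon) is a large negative constant, so padded positions receive
# near-zero weight after exponentiation, while log(1 + epsilon) ~= 0 leaves
# valid positions untouched.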
def model_function(features, labels, mode, params, embeddings):
  """A model function satisfying the tf.estimator API.

  Args:
    features: Dictionary of feature tensors with keys:
        - question_tok: <string> [batch_size, max_question_len]
        - context_tok: <string> [batch_size, max_num_context, max_context_len]
        - question_tok_len: <int32> [batch_size]
        - num_context: <int32> [batch_size]
        - context_tok_len: <int32> [batch_size]
        - question_tok_wid: <int32> [batch_size, max_question_len]
        - context_tok_wid: <int32>
            [batch_size, max_num_context, max_context_len]
        - long_answer_indices: <int32> [batch_size]
    labels: <int32> [batch_size] for answer index (-1 = NULL).
    mode: One of the keys from tf.estimator.ModeKeys.
    params: Dictionary of hyperparameters.
    embeddings: An embedding_utils.PretrainedWordEmbeddings object.

  Returns:
    estimator_spec: A tf.estimator.EstimatorSpec object.
  """
  del params  # Unused.

  if mode == tf.estimator.ModeKeys.PREDICT:
    # Add a dummy batch dimension if we are exporting the predictor.
    features = {k: tf.expand_dims(v, 0) for k, v in features.items()}

  embedding_weights, embedding_scaffold = embeddings.get_params(
      trainable=False)

  # Features.
  question_tok_len = features["question_tok_len"]
  question_tok_wid = features["question_tok_wid"]
  context_tok_wid = features["context_tok_wid"]
  num_context = features["num_context"]
  context_tok_len = features["context_tok_len"]

  # Truncate the contexts and labels to a certain maximum length.
  context_tok_wid, num_context, context_tok_len = (
      nq_long_utils.truncate_contexts(
          context_token_ids=context_tok_wid,
          num_contexts=num_context,
          context_len=context_tok_len,
          max_contexts=FLAGS.max_contexts,
          max_context_len=FLAGS.max_context_len))

  non_null_context_scores = nq_long_decatt_model.build_model(
      question_tok_wid=question_tok_wid,
      question_lens=question_tok_len,
      context_tok_wid=context_tok_wid,
      context_lens=context_tok_len,
      embedding_weights=embedding_weights,
      mode=mode)

  # Mask out contexts that are padding.
  num_context_mask = tf.log(
      tf.sequence_mask(
          num_context,
          tensor_utils.shape(non_null_context_scores, 1),
          dtype=tf.float32))
  non_null_context_scores += num_context_mask

  # <float> [batch_size, 1]
  null_score = tf.zeros([tf.shape(question_tok_wid)[0], 1])

  # Offset everything by 1 to account for null context.
  # [batch_size, 1 + max_contexts]
  context_scores = tf.concat([null_score, non_null_context_scores], 1)

  if mode != tf.estimator.ModeKeys.PREDICT:
    labels = nq_long_utils.truncate_labels(labels, FLAGS.max_contexts)

    # In the data, NULL is given index -1 but this is not compatible with
    # softmax so shift by 1.
    labels = labels + 1

    # Reweight null examples.
    weights = nq_long_utils.compute_null_weights(labels, FLAGS.null_weight)

    # When computing the loss we take only the first label.
    loss_labels = labels[:, 0]

    # []
    loss = tf.losses.sparse_softmax_cross_entropy(
        labels=loss_labels, logits=context_scores, weights=weights)

    optimizer = tf.train.AdagradOptimizer(learning_rate=FLAGS.learning_rate)
    train_op = optimizer.minimize(
        loss=loss, global_step=tf.train.get_global_step())

    # <int32> [batch_size]
    eval_predictions = tf.to_int32(tf.argmax(context_scores, 1))

    non_null_match, non_null_gold, non_null_predictions = (
        nq_long_utils.compute_match_stats(eval_predictions, labels))

    precision, precision_op = tf.metrics.mean(
        non_null_match, weights=non_null_predictions)
    recall, recall_op = tf.metrics.mean(
        non_null_match, weights=non_null_gold)

    f1, f1_op = nq_long_utils.f1_metric(
        precision=precision,
        precision_op=precision_op,
        recall=recall,
        recall_op=recall_op)

    # Bogus metric until we figure out how to connect Ming Wei's eval code.
    eval_metric_ops = {
        "precision": (precision, precision_op),
        "recall": (recall, recall_op),
        "f1": (f1, f1_op)
    }
  else:
    loss = None
    train_op = None
    eval_metric_ops = {}

  # In the export, we never predict NULL since the eval metric will compute
  # the best possible F1.
  export_long_answer_idx = tf.to_int32(tf.argmax(non_null_context_scores, 1))
  export_long_answer_score = tf.reduce_max(non_null_context_scores, 1)
  predictions = dict(
      idx=export_long_answer_idx, score=export_long_answer_score)

  if mode == tf.estimator.ModeKeys.PREDICT:
    # Remove the dummy batch dimension if we are exporting the predictor.
    predictions = {k: tf.squeeze(v, 0) for k, v in predictions.items()}

  estimator_spec = tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      predictions=predictions,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops,
      scaffold=embedding_scaffold)

  return estimator_spec
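
# For reference, the combination implemented by `nq_long_utils.f1_metric` is
# presumably the standard harmonic mean of the streaming precision and
# recall (the helper itself is defined elsewhere):
#
#   F1 = 2 * precision * recall / (precision + recall)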
def score_endpoints(question_emb,
                    question_len,
                    context_emb,
                    context_len,
                    hidden_size,
                    num_layers,
                    dropout_ratio,
                    mode,
                    use_cudnn=None):
  """Compute two scores over context words based on the input embeddings.

  Args:
    question_emb: <float32> [batch_size, max_question_len, hidden_size]
    question_len: <int32> [batch_size]
    context_emb: <float32> [batch_size, max_context_len, hidden_size]
    context_len: <int32> [batch_size]
    hidden_size: Size of hidden layers.
    num_layers: Number of LSTM layers.
    dropout_ratio: The probability of dropping out hidden units.
    mode: Object of type tf.estimator.ModeKeys.
    use_cudnn: Specify the use of cudnn. `None` denotes automatic selection.

  Returns:
    start_scores: <float32> [batch_size, max_context_words]
    end_scores: <float32> [batch_size, max_context_words]
  """
  # [batch_size, max_question_len]
  question_mask = tf.sequence_mask(
      question_len, tensor_utils.shape(question_emb, 1), dtype=tf.float32)

  # [batch_size, max_context_len, hidden_size]
  attended_emb = _attend_to_question(
      context_emb=context_emb,
      question_emb=question_emb,
      question_mask=question_mask,
      hidden_size=hidden_size)

  # [batch_size, max_context_len, hidden_size * 2]
  context_emb = tf.concat([context_emb, attended_emb], -1)

  with tf.variable_scope("contextualize_context"):
    # [batch_size, max_context_len, hidden_size]
    contextualized_context_emb = cudnn_layers.stacked_bilstm(
        input_emb=context_emb,
        input_len=context_len,
        hidden_size=hidden_size,
        num_layers=num_layers,
        dropout_ratio=dropout_ratio,
        mode=mode,
        use_cudnn=use_cudnn)
  with tf.variable_scope("contextualize_question"):
    # [batch_size, max_question_len, hidden_size]
    contextualized_question_emb = cudnn_layers.stacked_bilstm(
        input_emb=question_emb,
        input_len=question_len,
        hidden_size=hidden_size,
        num_layers=num_layers,
        dropout_ratio=dropout_ratio,
        mode=mode,
        use_cudnn=use_cudnn)

  if mode == tf_estimator.ModeKeys.TRAIN:
    contextualized_context_emb = tf.nn.dropout(contextualized_context_emb,
                                               1.0 - dropout_ratio)
    contextualized_question_emb = tf.nn.dropout(contextualized_question_emb,
                                                1.0 - dropout_ratio)

  # [batch_size, hidden_size]
  pooled_question_emb = _attention_pool(contextualized_question_emb,
                                        question_mask)

  if mode == tf_estimator.ModeKeys.TRAIN:
    pooled_question_emb = tf.nn.dropout(pooled_question_emb,
                                        1.0 - dropout_ratio)

  # [batch_size, max_context_len]
  with tf.variable_scope("start_scores"):
    start_scores = _bilinear_score(contextualized_context_emb,
                                   pooled_question_emb)
  with tf.variable_scope("end_scores"):
    end_scores = _bilinear_score(contextualized_context_emb,
                                 pooled_question_emb)

  context_log_mask = tf.log(
      tf.sequence_mask(
          context_len, tensor_utils.shape(context_emb, 1), dtype=tf.float32))
  start_scores += context_log_mask
  end_scores += context_log_mask
  return start_scores, end_scores
def model_fn(features, labels, mode, params):
  """Model function."""
  del labels

  # ==============================
  # Input features
  # ==============================
  # [batch_size, query_seq_len]
  query_inputs = features["query_inputs"]

  # [batch_size, num_candidates, candidate_seq_len]
  candidate_inputs = features["candidate_inputs"]

  # [batch_size, num_candidates, query_seq_len + candidate_seq_len]
  joint_inputs = features["joint_inputs"]

  # [batch_size, num_masks]
  mlm_targets = features["mlm_targets"]
  mlm_positions = features["mlm_positions"]
  mlm_mask = features["mlm_mask"]

  # ==============================
  # Create modules.
  # ==============================
  bert_module = hub.Module(
      spec=params["bert_hub_module_handle"],
      name="bert",
      tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
      trainable=True)
  hub.register_module_for_export(bert_module, "bert")

  embedder_module = hub.Module(
      spec=params["embedder_hub_module_handle"],
      name="embedder",
      tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
      trainable=True)
  hub.register_module_for_export(embedder_module, "embedder")

  # ==============================
  # Retrieve.
  # ==============================
  # [batch_size, projected_size]
  query_emb = embedder_module(
      inputs=dict(
          input_ids=query_inputs.token_ids,
          input_mask=query_inputs.mask,
          segment_ids=query_inputs.segment_ids),
      signature="projected")

  # [batch_size * num_candidates, candidate_seq_len]
  flat_candidate_inputs, unflatten = flatten_bert_inputs(candidate_inputs)

  # [batch_size * num_candidates, projected_size]
  flat_candidate_emb = embedder_module(
      inputs=dict(
          input_ids=flat_candidate_inputs.token_ids,
          input_mask=flat_candidate_inputs.mask,
          segment_ids=flat_candidate_inputs.segment_ids),
      signature="projected")

  # [batch_size, num_candidates, projected_size]
  unflattened_candidate_emb = unflatten(flat_candidate_emb)

  # [batch_size, num_candidates]
  retrieval_score = tf.einsum("BD,BND->BN", query_emb,
                              unflattened_candidate_emb)

  # ==============================
  # Read.
  # ==============================
  # [batch_size * num_candidates, query_seq_len + candidate_seq_len]
  flat_joint_inputs, unflatten = flatten_bert_inputs(joint_inputs)

  # [batch_size * num_candidates, num_masks]
  flat_mlm_positions, _ = tensor_utils.flatten(
      tf.tile(
          tf.expand_dims(mlm_positions, 1), [1, params["num_candidates"], 1]))

  batch_size, num_masks = tensor_utils.shape(mlm_targets)

  # [batch_size * num_candidates, query_seq_len + candidate_seq_len]
  flat_joint_bert_outputs = bert_module(
      inputs=dict(
          input_ids=flat_joint_inputs.token_ids,
          input_mask=flat_joint_inputs.mask,
          segment_ids=flat_joint_inputs.segment_ids,
          mlm_positions=flat_mlm_positions),
      signature="mlm",
      as_dict=True)

  # [batch_size, num_candidates]
  candidate_score = retrieval_score

  # [batch_size, num_candidates]
  candidate_log_probs = tf.math.log_softmax(candidate_score)

  # ==============================
  # Compute marginal log-likelihood.
  # ==============================
  # [batch_size * num_candidates, num_masks]
  flat_mlm_logits = flat_joint_bert_outputs["mlm_logits"]

  # [batch_size, num_candidates, num_masks, vocab_size]
  mlm_logits = tf.reshape(
      flat_mlm_logits, [batch_size, params["num_candidates"], num_masks, -1])
  mlm_log_probs = tf.math.log_softmax(mlm_logits)

  # [batch_size, num_candidates, num_masks]
  tiled_mlm_targets = tf.tile(
      tf.expand_dims(mlm_targets, 1), [1, params["num_candidates"], 1])

  # [batch_size, num_candidates, num_masks, 1]
  tiled_mlm_targets = tf.expand_dims(tiled_mlm_targets, -1)

  # [batch_size, num_candidates, num_masks, 1]
  gold_log_probs = tf.batch_gather(mlm_log_probs, tiled_mlm_targets)

  # [batch_size, num_candidates, num_masks]
  gold_log_probs = tf.squeeze(gold_log_probs, -1)

  # [batch_size, num_candidates, num_masks]
  joint_gold_log_probs = (
      tf.expand_dims(candidate_log_probs, -1) + gold_log_probs)

  # [batch_size, num_masks]
  marginal_gold_log_probs = tf.reduce_logsumexp(joint_gold_log_probs, 1)

  # [batch_size, num_masks]
  float_mlm_mask = tf.cast(mlm_mask, tf.float32)

  # []
  loss = -tf.div_no_nan(
      tf.reduce_sum(marginal_gold_log_probs * float_mlm_mask),
      tf.reduce_sum(float_mlm_mask))

  # ==============================
  # Optimization
  # ==============================
  num_warmup_steps = min(10000, max(100, int(params["num_train_steps"] / 10)))
  train_op = optimization.create_optimizer(
      loss=loss,
      init_lr=params["learning_rate"],
      num_train_steps=params["num_train_steps"],
      num_warmup_steps=num_warmup_steps,
      use_tpu=params["use_tpu"])

  # ==============================
  # Evaluation
  # ==============================
  eval_metric_ops = None if params["use_tpu"] else dict()
  if mode != tf.estimator.ModeKeys.PREDICT:
    # [batch_size, num_masks]
    retrieval_utility = marginal_gold_log_probs - gold_log_probs[:, 0]
    retrieval_utility *= tf.cast(features["mlm_mask"], tf.float32)

    # []
    retrieval_utility = tf.div_no_nan(
        tf.reduce_sum(retrieval_utility), tf.reduce_sum(float_mlm_mask))
    add_mean_metric("retrieval_utility", retrieval_utility, eval_metric_ops)

    has_timestamp = tf.cast(
        tf.greater(features["export_timestamp"], 0), tf.float64)
    off_policy_delay_secs = (
        tf.timestamp() - tf.cast(features["export_timestamp"], tf.float64))
    off_policy_delay_mins = off_policy_delay_secs / 60.0
    off_policy_delay_mins *= tf.cast(has_timestamp, tf.float64)
    add_mean_metric("off_policy_delay_mins", off_policy_delay_mins,
                    eval_metric_ops)

  # Create empty predictions to avoid errors when running in prediction mode.
  predictions = dict()

  if params["use_tpu"]:
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, train_op=train_op, predictions=predictions)
  else:
    if eval_metric_ops is not None:
      # Make sure the eval metrics are updated during training so that we get
      # quick feedback from tensorboard summaries when debugging locally.
      with tf.control_dependencies([u for _, u in eval_metric_ops.values()]):
        loss = tf.identity(loss)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        predictions=predictions)
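
# The marginalization above implements, for each masked position y,
#
#   log p(y | x) = logsumexp_z [ log p(z | x) + log p(y | z, x) ],
#
# where z ranges over the retrieved candidates: `candidate_log_probs` plays
# the role of log p(z | x) and `gold_log_probs` of log p(y | z, x). The
# `retrieval_utility` metric then measures how much marginalizing over
# candidates improves on conditioning only on the first candidate.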
def model_function(features, labels, mode, params, embeddings):
  """A model function satisfying the tf.estimator API.

  Args:
    features: Dictionary of feature tensors with keys:
        - question: <string> [batch_size, max_question_len]
        - question_len: <int32> [batch_size]
        - question_cid: <int32> [batch_size, max_question_len, max_chars]
        - question_wid: <int32> [batch_size, max_question_len]
        - context: <string> [batch_size, max_context_len]
        - context_len: <int32> [batch_size]
        - context_cid: <int32> [batch_size, max_context_len, max_chars]
        - context_wid: <int32> [batch_size, max_context_len]
        - answer_start: <int32> [batch_size]
        - answer_end: <int32> [batch_size]
    labels: Pair of tensors containing the answer start and answer end.
    mode: One of the keys from tf.estimator.ModeKeys.
    params: Unused parameter dictionary.
    embeddings: An embedding_utils.PretrainedWordEmbeddings object.

  Returns:
    estimator_spec: A tf.estimator.EstimatorSpec object.
  """
  del params

  if mode == tf.estimator.ModeKeys.PREDICT:
    # Add a dummy batch dimension if we are exporting the predictor.
    features = {k: tf.expand_dims(v, 0) for k, v in features.items()}

  embedding_weights, embedding_scaffold = embeddings.get_params(
      trainable=False)

  def _embed(prefix):
    """Embed the input text based on word and character IDs."""
    word_emb = tf.nn.embedding_lookup(embedding_weights,
                                      features[prefix + "_wid"])
    char_emb = common_layers.character_cnn(
        char_ids=features[prefix + "_cid"],
        emb_size=FLAGS.char_emb_size,
        kernel_width=FLAGS.char_kernel_width,
        num_filters=FLAGS.num_char_filters)
    concat_emb = tf.concat([word_emb, char_emb], -1)
    if mode == tf.estimator.ModeKeys.TRAIN:
      concat_emb = tf.nn.dropout(concat_emb, 1.0 - FLAGS.dropout_ratio)
    return concat_emb

  with tf.variable_scope("embed"):
    # [batch_size, max_question_len, hidden_size]
    question_emb = _embed("question")
  with tf.variable_scope("embed", reuse=True):
    # [batch_size, max_context_len, hidden_size]
    context_emb = _embed("context")

  # [batch_size, max_context_len]
  start_logits, end_logits = document_reader.score_endpoints(
      question_emb=question_emb,
      question_len=features["question_len"],
      context_emb=context_emb,
      context_len=features["context_len"],
      hidden_size=FLAGS.hidden_size,
      num_layers=FLAGS.num_layers,
      dropout_ratio=FLAGS.dropout_ratio,
      mode=mode,
      use_cudnn=False if mode == tf.estimator.ModeKeys.PREDICT else None)

  if mode != tf.estimator.ModeKeys.PREDICT:
    # [batch_size]
    start_labels, end_labels = labels

    # Since we truncate long contexts, some of the labels will not be
    # recoverable. In that case, we mask these invalid labels.
    valid_start_labels = tf.less(start_labels, features["context_len"])
    valid_end_labels = tf.less(end_labels, features["context_len"])
    tf.summary.histogram("valid_start_labels",
                         tf.to_float(valid_start_labels))
    tf.summary.histogram("valid_end_labels", tf.to_float(valid_end_labels))

    dummy_labels = tf.zeros_like(start_labels)

    # []
    start_loss = tf.losses.sparse_softmax_cross_entropy(
        labels=tf.where(valid_start_labels, start_labels, dummy_labels),
        logits=start_logits,
        weights=tf.to_float(valid_start_labels),
        reduction=tf.losses.Reduction.MEAN)
    end_loss = tf.losses.sparse_softmax_cross_entropy(
        labels=tf.where(valid_end_labels, end_labels, dummy_labels),
        logits=end_logits,
        weights=tf.to_float(valid_end_labels),
        reduction=tf.losses.Reduction.MEAN)
    loss = start_loss + end_loss
  else:
    loss = None

  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer()
    gradients, variables = zip(*optimizer.compute_gradients(loss))
    gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
    train_op = optimizer.apply_gradients(
        grads_and_vars=zip(gradients, variables),
        global_step=tf.train.get_global_step())
  else:
    # Don't build the train_op unnecessarily, since the ADAM variables can
    # cause problems with loading checkpoints on CPUs.
    train_op = None

  batch_size, max_context_len = tensor_utils.shape(features["context_wid"])
  tf.summary.histogram("batch_size", batch_size)
  tf.summary.histogram("non_padding",
                       features["context_len"] / max_context_len)

  # [batch_size], [batch_size]
  start_predictions, end_predictions, predicted_score = (
      span_utils.max_scoring_span(start_logits, end_logits))

  # [batch_size, 2]
  predictions = dict(
      start_idx=start_predictions,
      end_idx=(end_predictions + 1),
      score=predicted_score)

  if mode == tf.estimator.ModeKeys.PREDICT:
    # Remove the dummy batch dimension if we are exporting the predictor.
    predictions = {k: tf.squeeze(v, 0) for k, v in predictions.items()}

  if mode == tf.estimator.ModeKeys.EVAL:
    text_summary = get_text_summary(
        question=features["question"],
        context=features["context"],
        start_predictions=start_predictions,
        end_predictions=end_predictions)

    # TODO(kentonl): Replace this with @mingweichang's official eval script.
    exact_match = tf.logical_and(
        tf.equal(start_predictions, start_labels),
        tf.equal(end_predictions, end_labels))

    eval_metric_ops = dict(
        exact_match=tf.metrics.mean(exact_match),
        text_summary=(text_summary, tf.no_op()))
  else:
    eval_metric_ops = None

  estimator_spec = tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      predictions=predictions,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops,
      scaffold=embedding_scaffold)

  return estimator_spec
def create_de_model(bert_config, is_training, input_ids_1, input_mask_1,
                    segment_ids_1, input_ids_2, input_masks_2, segment_ids_2,
                    num_candidates, labels, use_one_hot_embeddings):
  """Creates a ranking model using cosine and dual encoder representations."""
  sequence_length_query = FLAGS.max_seq_length_query
  sequence_length_passage = FLAGS.max_seq_length - FLAGS.max_seq_length_query

  input_ids_1 = tf.reshape(input_ids_1, [-1, sequence_length_query])
  segment_ids_1 = tf.reshape(segment_ids_1, [-1, sequence_length_query])
  input_masks_1 = tf.reshape(input_mask_1, [-1, sequence_length_query])
  batch_size = tf.shape(input_masks_1)[0]

  input_ids_2 = tf.reshape(input_ids_2, [-1, sequence_length_passage])
  segment_ids_2 = tf.reshape(segment_ids_2, [-1, sequence_length_passage])
  input_masks_2 = tf.reshape(input_masks_2, [-1, sequence_length_passage])

  # [batch_size, num_candidates]
  labels = tf.dtypes.cast(labels, tf.float32)

  # [batch_size, num_vec_query, hidden_size], [batch_size, num_vec_query]
  output_layer_1, mask_1 = encode_block(bert_config, input_ids_1,
                                        input_masks_1, segment_ids_1,
                                        use_one_hot_embeddings,
                                        FLAGS.num_vec_query, is_training)
  output_layer_2, mask_2 = encode_block(bert_config, input_ids_2,
                                        input_masks_2, segment_ids_2,
                                        use_one_hot_embeddings,
                                        FLAGS.num_vec_passage, is_training)

  label_mask = tf.expand_dims(tf.eye(batch_size), axis=2)
  label_mask = tf.tile(label_mask, [1, 1, num_candidates])
  label_mask = tf.reshape(label_mask, [batch_size, -1])
  label_mask = tf.cast(label_mask, tf.float32)
  labels = tf.tile(labels, [1, batch_size])
  labels = tf.multiply(labels, label_mask)

  output_layer_2_logits = tf.reshape(
      output_layer_2, [batch_size, num_candidates, FLAGS.num_vec_passage, -1])
  mask_2_logits = tf.reshape(
      mask_2, [batch_size, num_candidates, FLAGS.num_vec_passage])
  mask_logits = tf.einsum("BQ,BCP->BCQP", tf.cast(mask_1, tf.float32),
                          tf.cast(mask_2_logits, tf.float32))
  logits = tf.einsum("BQH,BCPH->BCQP", output_layer_1, output_layer_2_logits)
  logits = tf.multiply(logits, mask_logits)
  logits = tf.reduce_max(logits, axis=-1)
  logits = tf.reduce_sum(logits, axis=-1)

  if FLAGS.use_tpu and is_training:
    num_shards = tpu_utils.num_tpu_shards()
    output_layer_2 = tpu_utils.cross_shard_concat(output_layer_2)
    mask_2 = tpu_utils.cross_shard_concat(tf.cast(mask_2, tf.float32))
    mask_2 = tf.cast(mask_2, tf.bool)
    labels = tpu_utils.cross_shard_pad(labels)
    tf.logging.info("Global batch size: %s", tensor_utils.shape(labels, 0))
    tf.logging.info("Num shards: %s", num_shards)
    tf.logging.info("Number of candidates in batch: %s",
                    tensor_utils.shape(output_layer_2, 0))
    labels = tf.reshape(labels, [num_shards, batch_size, -1])
    labels = tf.transpose(labels, perm=[1, 0, 2])
    labels = tf.reshape(labels, [batch_size, -1])

  with tf.variable_scope("loss"):
    if is_training:
      output_layer_1 = tf.nn.dropout(output_layer_1, keep_prob=FLAGS.dropout)
      output_layer_2 = tf.nn.dropout(output_layer_2, keep_prob=FLAGS.dropout)

    cosine_similarity = tf.einsum("AQH,BPH->ABQP", output_layer_1,
                                  output_layer_2)
    mask = tf.cast(
        tf.logical_and(
            tf.expand_dims(tf.expand_dims(mask_1, 2), 1),
            tf.expand_dims(tf.expand_dims(mask_2, 1), 0)), tf.float32)
    cosine_similarity = tf.multiply(cosine_similarity, mask)
    cosine_similarity = tf.reduce_max(cosine_similarity, axis=-1)
    cosine_similarity = tf.reduce_sum(cosine_similarity, axis=-1)
    per_example_loss = tf.losses.softmax_cross_entropy(
        labels, cosine_similarity)
    return (per_example_loss, logits)
def build_model(question_tok_wid, question_lens, context_tok_wid,
                context_lens, embedding_weights, mode):
  """Wrapper around the Decomposable Attention model for NQ long answer scoring.

  Args:
    question_tok_wid: <int32> [batch_size, question_len]
    question_lens: <int32> [batch_size]
    context_tok_wid: <int32> [batch_size, num_context, context_len]
    context_lens: <int32> [batch_size, num_context]
    embedding_weights: <float> [vocab_size, embed_dim]
    mode: One of the keys from tf.estimator.ModeKeys.

  Returns:
    context_scores: <float> [batch_size, num_context]
  """
  # <float> [batch_size, question_len, embed_dim]
  question_emb = tf.nn.embedding_lookup(embedding_weights, question_tok_wid)

  # <float> [batch_size, num_context, context_len, embed_dim]
  context_emb = tf.nn.embedding_lookup(embedding_weights, context_tok_wid)

  question_emb = tf.layers.dense(
      inputs=question_emb,
      units=FLAGS.hidden_size,
      activation=None,
      name="reduce_emb",
      reuse=False)
  context_emb = tf.layers.dense(
      inputs=context_emb,
      units=FLAGS.hidden_size,
      activation=None,
      name="reduce_emb",
      reuse=True)

  batch_size, num_contexts, max_context_len, embed_dim = (
      tensor_utils.shape(context_emb))
  _, max_question_len, _ = tensor_utils.shape(question_emb)

  # <float> [batch_size * num_context, context_len, embed_dim]
  flat_context_emb = tf.reshape(context_emb, [-1, max_context_len, embed_dim])

  # <int32> [batch_size * num_context]
  flat_context_lens = tf.reshape(context_lens, [-1])

  # <float> [batch_size * num_context, question_len, embed_dim]
  question_emb_tiled = tf.tile(
      tf.expand_dims(question_emb, 1), [1, num_contexts, 1, 1])
  flat_question_emb_tiled = tf.reshape(question_emb_tiled,
                                       [-1, max_question_len, embed_dim])

  # <int32> [batch_size * num_context]
  question_lens_tiled = tf.tile(
      tf.expand_dims(question_lens, 1), [1, num_contexts])
  flat_question_lens_tiled = tf.reshape(question_lens_tiled, [-1])

  # <float> [batch_size * num_context, hidden_size]
  flat_decatt_emb = decatt.decomposable_attention(
      emb1=flat_question_emb_tiled,
      len1=flat_question_lens_tiled,
      emb2=flat_context_emb,
      len2=flat_context_lens,
      hidden_size=FLAGS.hidden_size,
      hidden_layers=FLAGS.hidden_layers,
      dropout_ratio=FLAGS.dropout_ratio,
      mode=mode)

  # <float> [batch_size, num_context, hidden_size]
  decatt_emb = tf.reshape(flat_decatt_emb,
                          [batch_size, num_contexts, FLAGS.hidden_size])

  weighted_num_overlap, unweighted_num_overlap, pos_embs = (
      _get_non_neural_features(
          question_tok_wid=question_tok_wid,
          question_lens=question_lens,
          context_tok_wid=context_tok_wid,
          context_lens=context_lens))

  final_emb = tf.concat(
      [decatt_emb, weighted_num_overlap, unweighted_num_overlap, pos_embs],
      -1)

  # Final linear layer to get score.
  # <float> [batch_size, num_context]
  context_scores = tf.layers.dense(inputs=final_emb, units=1, activation=None)
  context_scores = tf.squeeze(context_scores, -1)
  return context_scores
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 labels, use_one_hot_embeddings, use_tpu):
  """Creates a classification model."""
  tpu_split = FLAGS.tpu_split if use_tpu else 1
  model = modeling.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=use_one_hot_embeddings)

  # [batch_size, max_seq_len, hidden_size]
  output_final_layer = model.get_sequence_output()

  if FLAGS.emb_rep == "cls":
    embedding = tf.squeeze(output_final_layer[:, 0:1, :], axis=1)
  elif FLAGS.emb_rep == "mean":
    embedding = tf.reduce_mean(output_final_layer, axis=1)

  tf.logging.info("per tpu slice")
  tf.logging.info("embedding size: %s", embedding.shape)
  tf.logging.info("label size: %s", labels.shape)
  tf.logging.info("=======" * 10)

  if use_tpu:
    # For TPU usage: combine embeddings after splitting 8 ways.
    # [global_batch_size]
    labels = tpu_utils.cross_shard_concat(labels)
    tf.logging.info("label size: %s", labels.shape)
    tf.logging.info("=======" * 10)

    # [global_batch_size, hidden_size]
    embedding = tpu_utils.cross_shard_concat(embedding)
    tf.logging.info("Global batch size: %s",
                    tensor_utils.shape(embedding, 0))
    tf.logging.info("embedding size: %s", embedding.shape)
    tf.logging.info("label size: %s", labels.shape)
    tf.logging.info("num tpu shards: %s", tpu_utils.num_tpu_shards())
    tf.logging.info("=======" * 10)

  num_known_classes = FLAGS.num_domains * FLAGS.num_labels_per_domain
  num_unknown_classes = NUM_CLASSES - num_known_classes
  if FLAGS.continual_learning == "pretrain":
    num_classes = num_known_classes
    n_examples = FLAGS.known_num_shots
  elif FLAGS.continual_learning == "few_shot":
    num_classes = num_unknown_classes
    n_examples = FLAGS.few_shot
    if FLAGS.few_shot_known_neg:
      num_classes = NUM_CLASSES
      real_num_classes = num_unknown_classes

  # Remove padding in each batch.
  if use_tpu:
    real_shift = math.ceil(num_classes / FLAGS.batch_size) * FLAGS.batch_size
    # If using TPU, then embedding.shape[0] will be
    # (num_classes + pad_num) * 8.
    real_indices = tf.range(num_classes)
    for i in range(1, tpu_split):
      real_indices = tf.concat(
          [real_indices, tf.range(num_classes) + real_shift * i], axis=0)
    embedding = tf.gather(embedding, real_indices)
    labels = tf.gather(labels, real_indices)
    tf.logging.info("embedding size after removing padding in batch: %s",
                    embedding.shape)
    tf.logging.info("label size after removing padding in batch: %s",
                    labels.shape)

    # Remove padded batch.
    if n_examples < tpu_split:
      real_batch_total = n_examples * num_classes
      embedding = embedding[:real_batch_total]
      labels = labels[:real_batch_total]
      real_num = n_examples
    else:
      real_num = tpu_split
  else:
    # Not using TPUs.
    if n_examples < tpu_split:
      real_num = n_examples
    else:
      real_num = tpu_split
    real_batch_total = real_num * num_classes
    embedding = embedding[:real_batch_total]
    labels = labels[:real_batch_total]

  tf.logging.info("real embedding size: %s", embedding.shape)
  tf.logging.info("real label size: %s", labels.shape)

  n = embedding.shape[0].value
  assert n == real_num * num_classes, (
      "n: %d; real_num: %d: num_classes: %d" % (n, real_num, num_classes))

  with tf.variable_scope("loss", reuse=tf.AUTO_REUSE):
    if is_training:
      # I.e., 0.1 dropout.
      embedding = tf.nn.dropout(embedding, keep_prob=1 - DROPOUT_PROB)

    logits = tf.matmul(embedding, embedding, transpose_b=True)
    diagonal_matrix = tf.eye(n, n)
    logits = logits - diagonal_matrix * logits

    logits_reshape = tf.reshape(logits, [n, real_num, num_classes])
    if FLAGS.reduce_method == "mean":
      all_logits_sum = tf.reduce_sum(logits_reshape, 1)
      num_counts = tf.ones([n, num_classes]) * real_num
      label_diagonal = tf.eye(num_classes, num_classes)
      label_diagonal = tf.tile(label_diagonal, tf.constant([real_num, 1]))
      num_counts = num_counts - label_diagonal
      mean_logits = tf.divide(all_logits_sum, num_counts)
      if FLAGS.few_shot_known_neg:
        real_logits_indices = tf.range(real_num_classes)
        for i in range(1, n_examples):
          real_logits_indices = tf.concat([
              real_logits_indices,
              tf.range(real_num_classes) + num_classes * i
          ], axis=0)
        mean_logits = tf.gather(mean_logits, real_logits_indices)
        label_diagonal = tf.eye(real_num_classes, num_classes)
        label_diagonal = tf.tile(label_diagonal, tf.constant([real_num, 1]))
      probabilities = tf.nn.softmax(mean_logits, axis=-1)
      log_probs = tf.nn.log_softmax(mean_logits, axis=-1)
      return_logits = mean_logits
    elif FLAGS.reduce_method == "max":
      max_logits = tf.reduce_max(logits_reshape, 1)
      if FLAGS.min_max:
        # Because the diagonal is 0, we need to assign a large number to get
        # the true min.
        large_number = 50000
        added_logits = logits + diagonal_matrix * large_number
        added_reshape_logits = tf.reshape(added_logits,
                                          [n, real_num, num_classes])
        # [n, num_classes]
        min_logits = tf.reduce_min(added_reshape_logits, 1)
        masks = tf.tile(
            tf.eye(num_classes, num_classes), tf.constant([real_num, 1]))
        max_logits = masks * min_logits + (1 - masks) * max_logits
      label_diagonal = tf.eye(num_classes, num_classes)
      if FLAGS.few_shot_known_neg:
        real_logits_indices = tf.range(real_num_classes)
        # WARNING: the current implementation may not be correct for
        # few_shot > 8 on TPUs; in the following loop, it should be
        # `for i in range(1, real_num)` instead of `range(1, n_examples)`.
        assert n_examples < 8, (
            "current implementation may not be correct for "
            "few_shot > 8 on tpus. Need to check")
        # Note: n_examples here is 2 or 5, which is less than tpu_split.
        for i in range(1, n_examples):
          real_logits_indices = tf.concat([
              real_logits_indices,
              tf.range(real_num_classes) + num_classes * i
          ], axis=0)
        max_logits = tf.gather(max_logits, real_logits_indices)
        label_diagonal = label_diagonal[:real_num_classes]
      label_diagonal = tf.tile(label_diagonal, tf.constant([real_num, 1]))
      probabilities = tf.nn.softmax(max_logits, axis=-1)
      log_probs = tf.nn.log_softmax(max_logits, axis=-1)
      return_logits = max_logits
    elif FLAGS.reduce_method == "random":
      # [n, 1]
      indice_0 = tf.expand_dims(tf.range(n), axis=1)
      indice_1 = tf.random.uniform([n, 1],
                                   minval=0,
                                   maxval=real_num,
                                   dtype=tf.dtypes.int32)
      random_indices = tf.concat([indice_0, indice_1], axis=1)
      random_logits = tf.gather_nd(logits_reshape, random_indices)
      label_diagonal = tf.eye(num_classes, num_classes)
      if FLAGS.few_shot_known_neg:
        real_logits_indices = tf.range(real_num_classes)
        for i in range(1, n_examples):
          real_logits_indices = tf.concat([
              real_logits_indices,
              tf.range(real_num_classes) + num_classes * i
          ], axis=0)
        random_logits = tf.gather(random_logits, real_logits_indices)
        label_diagonal = label_diagonal[:real_num_classes]
      label_diagonal = tf.tile(label_diagonal, tf.constant([real_num, 1]))
      probabilities = tf.nn.softmax(random_logits, axis=-1)
      log_probs = tf.nn.log_softmax(random_logits, axis=-1)
      return_logits = random_logits

    per_example_loss = -tf.reduce_sum(label_diagonal * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)
    return (loss, per_example_loss, return_logits, probabilities)