def _get_rc_model_input(
    question_ids,
    question_mask,
    context_ids,
    context_mask,
    vocab,
):
    """Assemble batched RC-model inputs from question and context pieces.

    Each row is first laid out as `[CLS] question [SEP] context [SEP]`, then
    `tensor_utils.boolean_mask` is applied with the joined mask so that masked
    (padding) positions no longer sit between the question and the context.

    Args:
      question_ids: <int32> [batch_size, question_len]
      question_mask: <int32> [batch_size, question_len]
      context_ids: <int32> [batch_size, context_len]
      context_mask: <int32> [batch_size, context_len]
      vocab: Instance of text_utils.Vocab.

    Returns:
      input_ids: <int32> [batch_size, rc_input_len]
      input_mask: <int32> [batch_size, rc_input_len]
      segment_ids: <int32> [batch_size, rc_input_len]
    """
    batch_size = tensor_utils.shape(context_ids, 0)

    def column(value):
        """Return a [batch_size, 1] tensor filled with `value`."""
        return tf.fill([batch_size, 1], value)

    cls_column = column(vocab.t2i(vocab.CLS))
    sep_column = column(vocab.t2i(vocab.SEP))

    # [CLS] question [SEP] context [SEP]
    joined_ids = tf.concat(
        [cls_column, question_ids, sep_column, context_ids, sep_column],
        axis=1)

    # Segment 0 covers [CLS] + question + [SEP]; segment 1 covers
    # context + [SEP].
    question_segment = tf.fill(
        [batch_size, tensor_utils.shape(question_ids, 1) + 2], 0)
    context_segment = tf.fill(
        [batch_size, tensor_utils.shape(context_ids, 1) + 1], 1)
    joined_segments = tf.concat([question_segment, context_segment], axis=1)

    # Special tokens are always kept, so their mask entries are 1.
    keep_column = column(1)
    joined_mask = tf.concat(
        [keep_column, question_mask, keep_column, context_mask, keep_column],
        axis=1)
    keep = tf.cast(joined_mask, tf.bool)

    # Before: [CLS] X X X [PAD] ... [SEP] Y Y Y [PAD] ... [SEP]
    # After:  [CLS] X X X [SEP] Y Y Y [SEP] [PAD] ...
    input_ids, input_mask, segment_ids = [
        tensor_utils.boolean_mask(tensor, keep)
        for tensor in (joined_ids, joined_mask, joined_segments)
    ]
    return input_ids, input_mask, segment_ids
def update_values(old_values, current_value):
    """Fill every not-yet-written step slot with the current step's value.

    Uses the free names `batch_size`, `num_steps` and `is_written`, which must
    be supplied by the enclosing scope. Where `is_written` is true the entry of
    `old_values` is kept; everywhere else `current_value`, tiled along axis 1,
    is used.
    """
    rank = len(tensor_utils.shape(old_values))

    # Lift is_written to [batch_size, num_steps, 1, ...] and repeat it over
    # the trailing dimensions so it matches old_values elementwise.
    mask_shape = [batch_size, num_steps] + [1] * (rank - 2)
    mask_multiples = [1, 1] + list(tensor_utils.shape(old_values)[2:])
    keep_old = tf.tile(tf.reshape(is_written, mask_shape), mask_multiples)

    # Repeat the single-step value across all num_steps slots.
    step_multiples = [1] * rank
    step_multiples[1] = num_steps
    repeated_step = tf.tile(current_value, step_multiples)

    return tf.where(keep_old, old_values, repeated_step)
def compute_image_transformer(
    self,
    input_ids,
    input_image,
    input_image_mask,
    input_positions,
    reuse=None,
):
    """Build the image transformer.

    Image features are projected to the hidden size, prefixed with the
    embeddings of `input_ids`, summed with region-position, order and segment
    embeddings, normalized, and passed through `compute_transformer`.

    Returns:
      The `(sequence_output, output_cache)` pair from `compute_transformer`.
    """
    hparams = self.config
    with tf.variable_scope(self.scope_prefix + "transformer", reuse=reuse):
        with tf.variable_scope("bridge"):
            # Project image features to the transformer hidden size.
            projected_image = tf.layers.dense(
                inputs=input_image,
                units=hparams.hidden_size,
                activation=tf.nn.relu,
                kernel_initializer=modeling.create_initializer(
                    hparams.initializer_range),
                reuse=reuse)

        with tf.variable_scope("embeddings"):
            # Token embeddings come first on the sequence axis.
            token_emb = tf.gather(self.embedding_table, input_ids)
            content_emb = tf.concat([token_emb, projected_image], axis=1)
            batch_size = tensor_utils.shape(content_emb, 0)
            sequence_length = tensor_utils.shape(content_emb, 1)

            # One zero vector is prepended on the sequence axis, so the pad
            # lines up with a single leading token.
            region_emb = tf.pad(
                tf.gather(self.image_region_table, input_positions),
                [[0, 0], [1, 0], [0, 0]])

            # Order ids are 0..sequence_length-1, repeated for every row.
            order_ids = tf.tile(
                tf.expand_dims(
                    tf.range(tensor_utils.shape(content_emb, 1)), 0),
                [tensor_utils.shape(content_emb, 0), 1])
            order_emb = tf.gather(self.image_order_table, order_ids)

            # Every position is assigned the image segment id.
            segment_ids = tf.fill([batch_size, sequence_length], self.IMG)
            segment_emb = tf.gather(self.segment_table, segment_ids)

            summed_emb = content_emb + region_emb + order_emb + segment_emb
            encoder_input = modeling.layer_norm_and_dropout(
                summed_emb, hparams.hidden_dropout_prob)

        with tf.variable_scope("image/encoder"):
            sequence_output, output_cache = compute_transformer(
                input_tensor=encoder_input,
                attention_mask=tf.expand_dims(input_image_mask, 1),
                hidden_size=hparams.hidden_size,
                num_hidden_layers=hparams.num_hidden_layers,
                num_attention_heads=hparams.num_attention_heads,
                intermediate_size=hparams.intermediate_size,
                intermediate_act_fn=modeling.get_activation(
                    hparams.hidden_act),
                hidden_dropout_prob=hparams.hidden_dropout_prob,
                attention_probs_dropout_prob=(
                    hparams.attention_probs_dropout_prob),
                initializer_range=hparams.initializer_range,
                input_cache=None)

        return sequence_output, output_cache
def build_planner_inputs(question, answer, length, lookup_table):
    """Convert text to TextInputs for conditional text planner.

    The question is prefixed with "[Q]" and the answer with "[A]". Question
    tokens get segment id 2 and answer tokens segment id 1; positions restart
    at 0 for the answer. The joined sequence is truncated, then zero-padded,
    to exactly `length`.

    Args:
      question: <string>, space-separated token string.
      answer: <string>, space-separated token string.
      length: Length to pad or truncate to.
      lookup_table: Instance of contrib.lookup.index_table_from_tensor.

    Returns:
      Instance of TextInputs.
    """

    def encode(prefix, text):
        """Return (token ids, token count, positions) for prefix + text."""
        tokens = tf.concat(
            [[prefix], tf.string_split([text]).values], axis=0)
        ids = tf.cast(lookup_table.lookup(tokens), tf.int32)
        count = tensor_utils.shape(ids, 0)
        return ids, count, tf.range(count)

    q_ids, q_len, q_positions = encode("[Q]", question)
    a_ids, a_len, a_positions = encode("[A]", answer)

    # Join question and answer pieces.
    token_ids = tf.concat([q_ids, a_ids], axis=0)
    segment_ids = tf.concat(
        [tf.fill([q_len], 2), tf.fill([a_len], 1)], axis=0)
    positions = tf.concat([q_positions, a_positions], axis=0)
    mask = tf.concat([tf.ones_like(q_ids), tf.ones_like(a_ids)], axis=0)

    # Truncate every field to at most `length` entries.
    token_ids, segment_ids, mask, positions = [
        field[:length]
        for field in (token_ids, segment_ids, mask, positions)
    ]

    # Zero-pad every field on the right up to `length`.
    padding = [[0, length - tf.size(token_ids)]]
    token_ids, mask, segment_ids, positions = [
        tf.pad(field, padding)
        for field in (token_ids, mask, segment_ids, positions)
    ]

    return TextInputs(
        token_ids=tf.ensure_shape(token_ids, [length]),
        mask=tf.ensure_shape(mask, [length]),
        segment_ids=tf.ensure_shape(segment_ids, [length]),
        positions=tf.ensure_shape(positions, [length]))
def sample_from_rollouts(rollouts, baseline=None, reward_type="exact_match"):
    """Sample a single example from the given rollouts.

    Rollouts whose baseline-adjusted reward is non-zero get logit 0 and all
    others get logit -1e8, so the draw is uniform over the non-zero-reward
    rollouts whenever at least one exists.

    Args:
      rollouts: Instance of RolloutOutputs.
      baseline: <float32> [batch_size] Baseline value b for R'(y) = R(y) - b.
        Defaults to zeros.
      reward_type: Choice between indicator, exact_match, and F1.

    Returns:
      rollout: Instance of text_utils.TextInputs.
      reward: <float32> [batch_size]
    """
    batch_size = tensor_utils.shape(rollouts.token_ids, 0)
    rollout_length = tensor_utils.shape(rollouts.token_ids, 2)

    # No baseline means b = 0.
    baseline = tf.zeros([batch_size]) if baseline is None else baseline

    # [batch_size, num_rollouts]
    rewards = rollouts.rewards[reward_type] - tf.expand_dims(baseline, 1)

    # Sampling logits: 0 where the adjusted reward is non-zero, else -1e8.
    has_reward = tf.not_equal(rewards, 0)
    allowed_logits = tf.zeros_like(rollouts.scores)
    blocked_logits = tf.ones_like(rollouts.scores) * -1e8
    masked_scores = tf.where(has_reward, allowed_logits, blocked_logits)

    # [batch_size, 1]
    sampler = tf.distributions.Categorical(logits=masked_scores)
    sample_idx = tf.reshape(sampler.sample(), [batch_size, 1])

    def select(values, shape):
        """Gather the sampled rollout from `values` and reshape it."""
        return tf.reshape(tensor_utils.gather(values, sample_idx), shape)

    # [batch_size]
    reward = select(rewards, [-1])

    # [batch_size, rollout_length]
    token_ids = select(rollouts.token_ids, [batch_size, -1])
    mask = select(rollouts.mask, [batch_size, -1])
    segment_ids = tf.zeros_like(token_ids)
    positions = tf.tile(
        tf.expand_dims(tf.range(rollout_length), 0), [batch_size, 1])

    rollout = text_utils.TextInputs(
        token_ids=token_ids,
        mask=mask,
        segment_ids=segment_ids,
        positions=positions)
    return rollout, reward
def compute_attention_mask(token_mask, input_mask):
    """Compute attention mask.

    Only the shape of `token_mask` ([batch_size, num_tokens]) is used. The
    token-to-token part is lower triangular, so token t attends to tokens
    0..t. When `input_mask` is given it is repeated for every token and placed
    in front of the token-to-token part along the last axis.

    Returns:
      <int32> [batch_size, num_tokens, num_tokens] when `input_mask` is None,
      otherwise the input columns followed by the num_tokens token columns.
    """
    batch_size = tensor_utils.shape(token_mask, 0)
    num_tokens = tensor_utils.shape(token_mask, 1)

    all_pairs = tf.ones([batch_size, num_tokens, num_tokens], dtype=tf.int32)
    causal = tf.matrix_band_part(all_pairs, -1, 0)
    if input_mask is None:
        return causal

    # Every token sees the same input mask.
    input_columns = tf.tile(
        tf.expand_dims(input_mask, 1), [1, num_tokens, 1])
    return tf.concat([input_columns, causal], axis=-1)
def expand_example(features, sample_one=True):
    """Expand nested tensor protos into multiple examples.

    The serialized question ids, questions, answers and captions are parsed.
    With `sample_one` a single random index is drawn and only that entry is
    kept. The per-image features are repeated once per kept entry, and the
    result is returned as a dataset with one element per entry.
    """

    def parse(key, dtype):
        """Deserialize the tensor proto stored under `key`."""
        return tf.io.parse_tensor(features[key], out_type=dtype)

    question_ids = parse("question_ids", tf.int64)
    questions = parse("questions", tf.string)
    answers = parse("answers", tf.string)
    captions = parse("captions", tf.string)
    num_qas = tensor_utils.shape(questions, 0)

    if sample_one:
        # Keep one randomly chosen entry, as a length-1 batch.
        rid = tf.random.uniform([], maxval=num_qas, dtype=tf.int32)
        question_ids, questions, answers, captions = [
            tf.expand_dims(values[rid], 0)
            for values in (question_ids, questions, answers, captions)
        ]
        num_qas = 1

    def repeat(key):
        """Repeat the per-image feature `key` num_qas times along axis 0."""
        return tf.tile(tf.expand_dims(features[key], 0), [num_qas])

    expanded = {
        "image_id": repeat("image_id"),
        "image": repeat("image"),
        "object_features": repeat("object_features"),
        "object_positions": repeat("object_positions"),
        "question_id": question_ids,
        "question": questions,
        "answer": answers,
        "caption": captions,
    }
    return tf.data.Dataset.from_tensor_slices(expanded)
def exact_match(answer_ids, prediction_ids, vocab):
    """Compute exact match score between answer tokens and prediction tokens.

    Tokens in the normalized set are removed from both sides, the shorter side
    is zero-padded to the longer one, and rows are compared elementwise.

    Args:
      answer_ids: <int32> [batch_size, answer_length]
      prediction_ids: <int32> [batch_size, prediction_length]
      vocab: Instance of text_utils.Vocab.

    Returns:
      score: <float32> [batch_size] tensor of {0.0, 1.0}.
    """
    batch_size = tensor_utils.shape(answer_ids, 0)

    # [batch_size, 1, num_removable] ids that are ignored in the comparison.
    remove_ids = tf.tile(
        tf.reshape(list(_get_normalized_set(vocab)), [1, 1, -1]),
        [batch_size, 1, 1])

    def strip(ids):
        """Drop every token of `ids` that appears in `remove_ids`."""
        should_keep = tf.reduce_all(
            tf.not_equal(tf.expand_dims(ids, -1), remove_ids), axis=-1)
        return tensor_utils.boolean_mask(ids, should_keep)

    answer_ids = strip(answer_ids)
    prediction_ids = strip(prediction_ids)

    # Lengths after cleaning.
    answer_len = tensor_utils.shape(answer_ids, 1)
    prediction_len = tensor_utils.shape(prediction_ids, 1)

    # At most one of these is positive: pad the shorter side up to the longer.
    answer_pad = tf.maximum(0, prediction_len - answer_len)
    answer_ids = tf.pad(answer_ids, [[0, 0], [0, answer_pad]])
    prediction_pad = tf.maximum(0, answer_len - prediction_len)
    prediction_ids = tf.pad(prediction_ids, [[0, 0], [0, prediction_pad]])

    rows_match = tf.reduce_all(tf.equal(answer_ids, prediction_ids), axis=1)
    return tf.cast(rows_match, tf.float32)
def preprocess_mapper(features, params, lookup_table, vocab, mode):
    """Model-specific preprocessing of features from the dataset.

    Outside PREDICT mode a random caption is chosen and turned into shifted
    decoder inputs and outputs, plus an optional random-span condition when
    `params["conditional_decoding"]` is set. Object features are parsed in
    every mode. Only keys in KEYS are kept, and any key of the dataset
    footprint that is still missing is filled in from the footprint.
    """
    # Mark every example as a reference (captioning) example.
    features["input_type"] = tf.constant(datasets.DatasetTypes.REFERENCE)

    if mode != tf.estimator.ModeKeys.PREDICT:
        # Pick one caption uniformly at random.
        captions = tf.io.parse_tensor(features["captions"], tf.string)
        rid = tf.random.uniform(
            [], maxval=tensor_utils.shape(captions, 0), dtype=tf.int32)
        caption = text_utils.build_text_inputs(
            text=captions[rid],
            length=params["caption_length"],
            lookup_table=lookup_table,
            segment_id=0,
            start_token=vocab.CLS,
            end_token=vocab.SEP)
        assert isinstance(caption, text_utils.TextInputs)

        # Next-token setup: inputs drop the last position, outputs the first.
        features["token_inputs"] = text_utils.TextInputs(
            token_ids=caption.token_ids[:-1],
            mask=caption.mask[:-1],
            segment_ids=caption.segment_ids[:-1],
            positions=caption.positions[:-1])
        features["token_outputs"] = text_utils.TextOutputs(
            token_ids=caption.token_ids[1:],
            mask=caption.mask[1:])

        if params.get("conditional_decoding"):
            # Condition the decoder on a random span of the same caption.
            span = text_utils.get_random_span(
                text=captions[rid],
                p=params["span_sample_p"],
                max_span_len=params["span_length"])
            features["condition_inputs"] = text_utils.build_text_inputs(
                text=span,
                length=params["condition_length"],
                lookup_table=lookup_table,
                segment_id=1,
                start_token=vocab.ANS)

    features["object_features"] = image_utils.parse_object_features(
        features["object_features"], features["object_positions"], params)

    # Drop everything the model does not consume.
    kept = {name: value for name, value in features.items() if name in KEYS}

    # Fill in dummy inputs so all tasks share one feature structure.
    footprint = datasets.footprint(params)
    assert footprint
    for name, value in footprint.items():
        kept.setdefault(name, value)

    return kept
def get_token_mask(token_ids, stop_id):
    """Create a mask that keeps tokens up to and including the first stop_id.

    Positions after the first occurrence of `stop_id` in a row get 0. A row
    without `stop_id` is kept entirely.

    Args:
      token_ids: [batch_size, num_tokens] token ids.
      stop_id: Id marking the end of a sequence.

    Returns:
      <int32> [batch_size, num_tokens] mask of {0, 1}.
    """
    batch_size = tensor_utils.shape(token_ids, 0)
    num_tokens = tensor_utils.shape(token_ids, 1)

    # positions[b, t] == t
    positions = tf.tile(
        tf.expand_dims(tf.range(num_tokens), 0), [batch_size, 1])

    # Non-stop positions are replaced by num_tokens, which is past the last
    # valid position, so they never win the minimum below.
    is_stop = tf.equal(token_ids, stop_id)
    no_stop = tf.fill([batch_size, num_tokens], num_tokens)
    first_stop = tf.reduce_min(tf.where(is_stop, positions, no_stop), -1)

    keep = tf.less_equal(positions, tf.expand_dims(first_stop, -1))
    return tf.cast(keep, tf.int32)
def max_scoring_span(start_scores, end_scores, max_length, no_answer_bias=0):
    """Compute max scoring span, using the sum of start and end scores.

    Only spans (i, j) with i <= j <= i + max_length - 1 are eligible. All
    other index pairs receive a large score penalty so that they can never be
    selected by the argmax.

    Args:
      start_scores: <float32> [batch_size, seq_len]
      end_scores: <float32> [batch_size, seq_len]
      max_length: <int32> Max answer length.
      no_answer_bias: <float32> Log-odds threshold for "no-answer" selection.
        I.e. if log p(span=i,j)/p(span=NULL) > no_answer_bias, then select
        i, j as the span, and NULL otherwise.

    Returns:
      start: <int32> [batch_size]
      end: <int32> [batch_size]
      span_scores: <float32> [batch_size] Score of the selected span.
    """
    # Penalty subtracted from ineligible (i, j) pairs. It has to dominate any
    # realistic score difference; a tiny value would only act as a tie-breaker
    # and would not enforce `j >= i` or `max_length`.
    invalid_span_penalty = 1e4

    # Create sparse tensor of size [seq_len] holding the bias at position 0.
    seq_len = tensor_utils.shape(start_scores, -1)
    no_answer_bias = tf.scatter_nd([[0]], [no_answer_bias], [seq_len])
    no_answer_bias = tf.cast(no_answer_bias, tf.float32)

    # Apply bias to CLS token logits. Half goes to the start score and half
    # to the end score, so the (0, 0) span receives the full bias.
    no_answer_bias = tf.div(no_answer_bias, 2)
    start_scores += tf.expand_dims(no_answer_bias, 0)
    end_scores += tf.expand_dims(no_answer_bias, 0)

    # Compute outer sum: scores[b, i, j] = start[b, i] + end[b, j].
    scores = tf.expand_dims(start_scores, 2) + tf.expand_dims(end_scores, 1)

    # The band keeps 0 sub-diagonals and max_length - 1 super-diagonals, so
    # `mask` is 1 exactly on the ineligible pairs.
    mask = (1 - tf.matrix_band_part(tf.ones_like(scores), 0, max_length - 1))
    scores -= mask * invalid_span_penalty

    def map_fn(inputs):
        """Return the 2D argmax indices and max value of one score matrix."""
        flattened = tf.reshape(inputs, [-1])
        argmax = tf.argmax(flattened, output_type=tf.int32)
        indices = tensor_utils.unravel_index_2d(argmax, inputs.shape)
        score = flattened[argmax]
        return indices, score

    # Return i, j indices of max-scoring entry.
    with tf.device("/cpu"):
        endpoints, span_scores = tf.map_fn(
            fn=map_fn, elems=scores, dtype=(tf.int32, tf.float32))
    start = endpoints[:, 0]
    end = endpoints[:, 1]

    return start, end, span_scores
def parse_object_features(features, positions, params):
    """Parse ObjectDetectionOutput from TensorProtos.

    Features and positions are truncated to at most
    `params["num_image_regions"]` objects and then zero-padded up to exactly
    that many. The mask is 1 for real objects and 0 for padding.

    Args:
      features: Serialized float32 tensor proto of object features.
      positions: Serialized int64 tensor proto of object positions.
      params: Dictionary providing "num_image_regions" and
        "image_feature_size".

    Returns:
      ObjectDetectionOutput with `features`
      [num_image_regions, image_feature_size], `positions`
      [num_image_regions] and `mask` [num_image_regions].
    """
    num_regions = params["num_image_regions"]

    features = tf.io.parse_tensor(features, tf.float32)
    positions = tf.io.parse_tensor(positions, tf.int64)
    positions = tf.cast(positions, tf.int32)

    # Truncate BOTH tensors. If only the features were truncated, an example
    # with more than num_regions objects would keep over-long positions and
    # fail the ensure_shape check below.
    features = features[:num_regions]
    positions = positions[:num_regions]

    # Pad up to the fixed number of regions.
    num_objects = tensor_utils.shape(features, 0)
    padding = tf.maximum(0, num_regions - num_objects)
    features = tf.pad(features, [[0, padding], [0, 0]])
    positions = tf.pad(positions, [[0, padding]])
    features = tf.ensure_shape(
        features, [num_regions, params["image_feature_size"]])
    positions = tf.ensure_shape(positions, [num_regions])

    # 1 for real objects, 0 for padded slots.
    mask = tf.pad(tf.ones(num_objects, dtype=tf.int32), [[0, padding]])
    mask = tf.ensure_shape(mask, [num_regions])

    output = ObjectDetectionOutput(
        features=features, positions=positions, mask=mask)
    return output
def beam_search_decode(
    model,
    encoder_cache,
    encoder_cache_mask,
    start_id,
    stop_id,
    segment_id,
    num_steps,
    beam_size,
    alpha=0,
    reuse=tf.AUTO_REUSE,
):
    """Decode for a given number of steps using beam search.

    A zero-initialized output cache with `num_steps` slots is carried through
    beam search as part of the decode state. At each step the model attends
    over the encoder cache, the slots written so far, and the current token.

    Args:
      model: Model exposing `config`, `compute_transformer` and
        `compute_logits`.
      encoder_cache: Cache with `keys` and `values` for the encoded input.
      encoder_cache_mask: Mask for `encoder_cache`; its first dimension is
        used as the batch size.
      start_id: Token id every beam starts from.
      stop_id: Token id passed to beam search as `eos_id` and used to build
        the output mask.
      segment_id: Segment id assigned to every decoded token.
      num_steps: Number of decoding steps.
      beam_size: Number of beams per example.
      alpha: Forwarded to `beam_search.beam_search` as `alpha`.
      reuse: Variable reuse flag forwarded to the model calls.

    Returns:
      DecodeOutput built from the masked decoded ids, the token mask, and the
      beam scores (in that positional order).
    """
    true_batch_size = tensor_utils.shape(encoder_cache_mask, 0)
    num_layers = model.config.num_hidden_layers
    num_heads = model.config.num_attention_heads
    head_size = int(model.config.hidden_size / num_heads)

    def symbols_to_logits_fn(input_ids, i, state):
        """Go from ids to logits for next symbol."""
        # Size of expanded tensor (expanded by beam size).
        batch_size = tensor_utils.shape(input_ids, 0)

        # The token being decoded always attends to itself.
        # [batch_size, 1]
        current_step_mask = tf.ones([batch_size, 1], tf.int32)

        # Slots with index < i are the ones written by earlier steps.
        # [batch_size, num_steps]
        written_mask = tf.cast(tf.less(tf.range(num_steps), i), tf.int32)
        written_mask = tf.tile(tf.expand_dims(written_mask, 0), [batch_size, 1])
        is_written = tf.cast(written_mask, tf.bool)

        # [batch_size, cache_size + num_steps, num_layers, num_heads, head_size]
        input_cache = TransformerCache(
            keys=tf.concat([state.encoder_cache.keys, state.output_cache.keys], 1),
            values=tf.concat(
                [state.encoder_cache.values, state.output_cache.values], 1))

        # Encoder mask, written-slot mask, and the current-step column.
        # [batch_size, 1, cache_size + num_steps + 1]
        masks = [state.encoder_cache_mask, written_mask, current_step_mask]
        attention_mask = tf.concat(masks, axis=1)
        attention_mask = tf.expand_dims(attention_mask, 1)

        # sequence_output: [batch_size, 1, hidden_size],
        # step_cache: [batch_size, 1, num_layers, num_heads, head_size]
        sequence_output, step_cache = model.compute_transformer(
            input_ids=input_ids,
            input_segment_id=tf.fill(tensor_utils.shape(input_ids), segment_id),
            input_positions=tf.fill(tensor_utils.shape(input_ids), i),
            attention_mask=attention_mask,
            input_cache=input_cache,
            reuse=reuse)

        # [batch_size, 1, vocab_size]
        logits = model.compute_logits(sequence_output, reuse=reuse)

        def update_values(old_values, current_value):
            """Update stored values with this time step."""
            # Reshape is_written to [batch_size, num_steps, 1, ...] and tile
            # it over the trailing dimensions of old_values.
            shape = [1] * len(tensor_utils.shape(old_values))
            shape[:2] = [batch_size, num_steps]
            tile = tensor_utils.shape(old_values)
            tile[:2] = [1, 1]
            condition = tf.tile(tf.reshape(is_written, shape), tile)
            # Repeat the single-step value over all num_steps slots.
            tile = [1] * len(tensor_utils.shape(old_values))
            tile[1] = num_steps
            current_value = tf.tile(current_value, tile)
            # Written slots keep their old value; every other slot (the
            # current and all later ones) takes the current value.
            return tf.where(condition, old_values, current_value)

        # [batch_size, num_steps, num_layers, num_heads, head_size]
        beam_output_cache = TransformerCache(
            keys=update_values(state.output_cache.keys, step_cache.keys),
            values=update_values(state.output_cache.values, step_cache.values))

        # Return new state.
        state = DecodeState(encoder_cache=state.encoder_cache,
                            encoder_cache_mask=state.encoder_cache_mask,
                            output_cache=beam_output_cache)

        return tf.squeeze(logits, 1), state

    # Initialize output cache with zeros.
    shape = [true_batch_size, num_steps, num_layers, num_heads, head_size]
    output_cache = TransformerCache(keys=tf.zeros(shape), values=tf.zeros(shape))

    # Initialize state.
    state = DecodeState(encoder_cache=encoder_cache,
                        encoder_cache_mask=encoder_cache_mask,
                        output_cache=output_cache)

    # Decode using beam search.
    decoded_ids, scores, state = beam_search.beam_search(
        symbols_to_logits_fn=symbols_to_logits_fn,
        initial_ids=tf.fill([true_batch_size], start_id),
        eos_id=stop_id,
        beam_size=beam_size,
        alpha=alpha,
        decode_length=num_steps,
        vocab_size=model.config.vocab_size,
        states=state,
        use_tpu=True)

    # Postprocess: zero out every id after the first stop_id of each beam.
    # The decoded ids are reshaped with a last dimension of num_steps + 1.
    flat_mask = text_utils.get_token_mask(
        tf.reshape(decoded_ids, [-1, num_steps + 1]), stop_id)
    mask = tf.reshape(flat_mask, [true_batch_size, beam_size, num_steps + 1])
    decoded_ids *= mask

    return DecodeOutput(decoded_ids, mask, scores)
def model_fn(features, labels, mode, params, vocab):
    """Model function that satisfies the Estimator API.

    In TRAIN mode the loss is the sum of a policy-gradient term on sampled
    rollouts and a cross-entropy term on reference captions. In PREDICT mode
    captions are decoded with beam search. In any other mode loss, train_op
    and predictions are all None.

    Args:
      features: Dictionary of model input tensors.
      labels: Unused.
      mode: A tf.estimator.ModeKeys value.
      params: Dictionary of model parameters.
      vocab: A utils.text_utils.Vocab instance.

    Returns:
      spec: A tf.estimator.TPUEstimatorSpec.
    """
    del labels

    # ------------------------------------------------------------------------
    # INITIALIZATION.
    # ------------------------------------------------------------------------

    # Update model config from the pre-trained checkpoint.
    model = transformer_utils.TransformerModel(
        config=transformer_utils.TransformerConfig.from_dict(params),
        is_training=(mode == tf_estimator.ModeKeys.TRAIN))

    # Initialize QA model.
    rc_model = hub.Module(params["rc_model"])

    # image_features: [batch_size, num_regions, feature_size]
    # image_positions: [batch_size, num_regions]
    # image_mask: [batch_size, num_regions]
    image_features = features["object_features"].features
    image_positions = features["object_features"].positions
    image_mask = features["object_features"].mask

    # Expand mask by 1 to account for the leading [IMG] token.
    # [batch_size, num_regions + 1]
    batch_size = tensor_utils.shape(image_mask, 0)
    input_mask = tf.pad(image_mask, [[0, 0], [1, 0]], constant_values=1)

    # Encode the image and store the cached transformer values.
    # [batch_size, num_regions + 1, num_layers, num_heads, head_size]
    _, input_cache = model.compute_image_transformer(
        input_ids=tf.fill([batch_size, 1], vocab.t2i(vocab.IMG)),
        input_image=image_features,
        input_image_mask=input_mask,
        input_positions=image_positions)

    # ------------------------------------------------------------------------
    # TRAINING
    # ------------------------------------------------------------------------

    if mode == tf_estimator.ModeKeys.TRAIN:
        # MIXER-style training objective consists of two parts:
        #   1) Policy gradient on rewarded rollouts.
        #   2) MLE regularization on references.
        # The full loss is L_total = L_pg + L_mle.

        # Step 1: Policy gradient.
        # Compute and score policy rollouts (multiple per image).
        rollouts = reward_utils.compute_rollouts(model=model,
                                                 rc_model=rc_model,
                                                 features=features,
                                                 encoder_cache=input_cache,
                                                 encoder_cache_mask=input_mask,
                                                 vocab=vocab,
                                                 params=params)

        # Using a self-critical baseline, R'(y) = R(y) - b where
        # b = argmax p(y|x), sample a single rollout with non-zero reward.
        # The baseline is the reward of the rollout at index 0.
        rollout, reward = reward_utils.sample_from_rollouts(
            rollouts=rollouts,
            baseline=rollouts.rewards[params["reward"]][:, 0],
            reward_type=params["reward"])

        # Compute the probability of the rollout (back-propable).
        # The last token is dropped from the inputs; the first token is
        # dropped from the labels below.
        # [batch_size, decode_length, input_length + decode_length]
        rollout_attention_mask = transformer_utils.compute_attention_mask(
            token_mask=rollout.mask[:, :-1], input_mask=input_mask)

        # [batch_size, decode_length, hidden_size]
        rollout_emb, _ = model.compute_transformer(
            input_ids=rollout.token_ids[:, :-1],
            input_segment_id=rollout.segment_ids[:, :-1],
            input_positions=rollout.positions[:, :-1],
            attention_mask=rollout_attention_mask,
            input_cache=input_cache,
            reuse=tf.AUTO_REUSE)

        # [batch_size, decode_length, vocab_size]
        rollout_logits = model.compute_logits(rollout_emb, reuse=tf.AUTO_REUSE)

        # Compute the RL loss, -R(y) * log p(y|x)
        # Some elements in this batch are MLE only, mask those out from the
        # loss.
        rollout_mask = tf.cast(rollout.mask[:, 1:], tf.float32)
        pg_mask = tf.equal(features["input_type"], datasets.DatasetTypes.VQA)
        rollout_mask *= tf.expand_dims(tf.cast(pg_mask, tf.float32), 1)
        rl_loss = tf.losses.sparse_softmax_cross_entropy(
            labels=rollout.token_ids[:, 1:],
            logits=rollout_logits,
            weights=tf.expand_dims(reward, 1) * rollout_mask,
            reduction=tf.losses.Reduction.SUM)
        # Normalize by the number of unmasked rollout tokens; the result is 0
        # when there are none.
        rl_loss = tf.math.divide_no_nan(rl_loss, tf.reduce_sum(rollout_mask))

        # Step 2: MLE on references.
        # [batch_size, decode_length, input_length + decode_length]
        reference_attention_mask = transformer_utils.compute_attention_mask(
            token_mask=features["token_inputs"].mask, input_mask=input_mask)

        # [batch_size, decode_length, hidden_size]
        target_emb, _ = model.compute_transformer(
            input_ids=features["token_inputs"].token_ids,
            input_segment_id=features["token_inputs"].segment_ids,
            input_positions=features["token_inputs"].positions,
            attention_mask=reference_attention_mask,
            input_cache=input_cache,
            reuse=tf.AUTO_REUSE)

        # [batch_size, decode_length, vocab_size]
        target_logits = model.compute_logits(target_emb, reuse=tf.AUTO_REUSE)

        # Compute the MLE objective (cross-entropy loss).
        # Only examples of type REFERENCE contribute to this term.
        weights = features["token_outputs"].mask
        ref_mask = tf.equal(features["input_type"],
                            datasets.DatasetTypes.REFERENCE)
        weights *= tf.expand_dims(tf.cast(ref_mask, tf.int32), 1)
        reference_loss = tf.losses.sparse_softmax_cross_entropy(
            labels=features["token_outputs"].token_ids,
            logits=target_logits,
            weights=weights)

        # Add both losses together.
        loss = rl_loss + reference_loss

        # BERT-style optimization with linear warmup.
        train_op = optimization.create_optimizer(
            loss=loss,
            init_lr=params["learning_rate"],
            num_train_steps=params["num_train_steps"],
            num_warmup_steps=params["num_warmup_steps"],
            use_tpu=params.get("use_tpu"))

        # Book-keeping.
        summaries = tpu_summaries.TpuSummaries(params["model_dir"])
        summaries.scalar("loss", loss)

        # Check what percentage of examples have non-zero reward.
        # `reward` here is still the baseline-adjusted reward of the sampled
        # rollout.
        total_vqa = tf.reduce_sum(tf.cast(pg_mask, tf.float32))
        nonzero = tf.cast(tf.not_equal(reward, 0), tf.float32)
        nonzero *= tf.cast(pg_mask, tf.float32)
        total_nonzero = tf.reduce_sum(nonzero)
        summaries.scalar("density", tf.div_no_nan(total_nonzero, total_vqa))

        # Total (non-normalized) reward.
        # NOTE: `reward` is rebound here to the raw reward of the rollout at
        # index 0; the loss above already used the sampled reward.
        reward = rollouts.rewards[params["reward"]][:, 0]
        reward *= tf.cast(pg_mask, tf.float32)
        total_reward = tf.reduce_sum(reward)
        summaries.scalar("reward", tf.div_no_nan(total_reward, total_vqa))
        host_call = summaries.get_host_call()
    else:
        loss = None
        train_op = None
        host_call = None

    # ------------------------------------------------------------------------
    # TESTING.
    # ------------------------------------------------------------------------

    if mode == tf_estimator.ModeKeys.PREDICT:
        decode_output = transformer_utils.beam_search_decode(
            model=model,
            encoder_cache=input_cache,
            encoder_cache_mask=input_mask,
            start_id=vocab.t2i(vocab.CLS),
            stop_id=vocab.t2i(vocab.SEP),
            segment_id=0,
            num_steps=params["decode_length"],
            beam_size=params["beam_size"],
            alpha=params["beam_length_penalty"],
            reuse=tf.AUTO_REUSE)
        # The first decoded position is dropped from the predictions. Missing
        # image/question ids default to -1.
        predictions = dict(image_id=features.get("image_id", -1),
                           question_id=features.get("question_id", -1),
                           token_ids=decode_output.token_ids[:, :, 1:])
    else:
        predictions = None

    # ------------------------------------------------------------------------
    # WARM-START.
    # ------------------------------------------------------------------------

    # Initialize from pretrained model.
    def scaffold_fn():
        """Init op run on host."""
        # A non-empty warm_start_path takes precedence over base_model.
        checkpoint = params["base_model"]
        if params["warm_start_path"]:
            checkpoint = params["warm_start_path"]
        if checkpoint:
            checkpoint_utils.init_from_checkpoint(checkpoint)
        return tf.train.Scaffold()

    return tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions,
        scaffold_fn=scaffold_fn,
        host_call=host_call,
    )
def indicator_score(answer_ids, answer_mask, context_ids, vocab):
    """Compute indicator score of answer and context.

    Checks if the answer tokens are a subspan of the context, after tokens in
    the normalized set have been removed from both.

    Args:
      answer_ids: <int32> [batch_size, answer_length]
      answer_mask: <int32> [batch_size, answer_length]
      context_ids: <int32> [batch_size, context_length]
      vocab: Instance of text_utils.Vocab.

    Returns:
      score: <float32> [batch_size] tensor of {0.0, 1.0}.
    """
    batch_size = tensor_utils.shape(answer_ids, 0)

    # [batch_size, 1, num_removable] ids that are ignored in the comparison.
    remove_ids = tf.tile(
        tf.reshape(list(_get_normalized_set(vocab)), [1, 1, -1]),
        [batch_size, 1, 1])

    def keep_mask(ids):
        """True where a token of `ids` is not in `remove_ids`."""
        return tf.reduce_all(
            tf.not_equal(tf.expand_dims(ids, -1), remove_ids), axis=-1)

    # Clean the answer; its mask is filtered with the same keep decisions.
    keep_answer = keep_mask(answer_ids)
    answer_ids = tensor_utils.boolean_mask(answer_ids, keep_answer)
    answer_mask = tensor_utils.boolean_mask(answer_mask, keep_answer)

    # Clean the context.
    context_ids = tensor_utils.boolean_mask(
        context_ids, keep_mask(context_ids))

    # Cleaned lengths. The context gets one extra slot at index 0 holding 0,
    # which is what over-length window indices are redirected to.
    answer_len = tensor_utils.shape(answer_ids, 1)
    context_len = tensor_utils.shape(context_ids, 1) + 1
    context_ids = tf.pad(context_ids, [[0, 0], [1, 0]])

    # Sliding windows: window_idx[s, k] = s + k, or 0 when s + k runs past
    # the end of the context.
    # [context_len, answer_len]
    window_starts = tf.expand_dims(tf.range(context_len), 1)
    window_idx = tf.tile(
        tf.expand_dims(tf.range(answer_len), 0), [context_len, 1])
    window_idx += window_starts
    window_idx *= tf.cast(tf.less(window_idx, context_len), tf.int32)

    # Flatten to (batch, position) coordinate pairs for gather_nd.
    # [batch_size * context_len * answer_len]
    flat_window_idx = tf.reshape(
        tf.tile(tf.reshape(window_idx, [1, -1]), [batch_size, 1]), [-1])
    flat_batch_idx = tf.reshape(
        tf.tile(
            tf.expand_dims(tf.range(batch_size), 1),
            [1, context_len * answer_len]),
        [-1])
    coords = tf.stack([flat_batch_idx, flat_window_idx], axis=1)

    # [batch_size, context_len, answer_len]
    window_ids = tf.reshape(
        tf.gather_nd(context_ids, coords),
        [batch_size, context_len, answer_len])

    # Zero out window positions that line up with masked answer positions.
    window_ids *= tf.expand_dims(answer_mask, 1)

    # A window matches only if every position equals the answer; one matching
    # window is enough for a positive score.
    window_matches = tf.reduce_all(
        tf.equal(tf.expand_dims(answer_ids, 1), window_ids), axis=-1)
    return tf.cast(tf.reduce_any(window_matches, axis=-1), tf.float32)
def symbols_to_logits_fn(input_ids, i, state):
    """Go from ids to logits for next symbol.

    Uses the free names `model`, `num_steps`, `segment_id` and `reuse`, which
    must be supplied by the enclosing scope.

    Args:
      input_ids: Ids of the current step; the first dimension is the
        beam-expanded batch size.
      i: Index of the current decoding step.
      state: Decode state with `encoder_cache`, `encoder_cache_mask` and
        `output_cache`.

    Returns:
      The logits with axis 1 squeezed out, and the decode state whose output
      cache now includes this step.
    """
    # Size of expanded tensor (expanded by beam size).
    batch_size = tensor_utils.shape(input_ids, 0)

    # The current token always attends to itself.
    # [batch_size, 1]
    step_mask = tf.ones([batch_size, 1], tf.int32)

    # Slots with index < i were written by earlier steps.
    # [batch_size, num_steps]
    steps_done = tf.cast(tf.less(tf.range(num_steps), i), tf.int32)
    steps_done = tf.tile(tf.expand_dims(steps_done, 0), [batch_size, 1])
    is_written = tf.cast(steps_done, tf.bool)

    encoder_cache = state.encoder_cache
    decoder_cache = state.output_cache

    # [batch_size, cache_size + num_steps, num_layers, num_heads, head_size]
    input_cache = TransformerCache(
        keys=tf.concat([encoder_cache.keys, decoder_cache.keys], 1),
        values=tf.concat([encoder_cache.values, decoder_cache.values], 1))

    # [batch_size, 1, cache_size + num_steps + 1]
    attention_mask = tf.expand_dims(
        tf.concat([state.encoder_cache_mask, steps_done, step_mask], axis=1),
        1)

    # sequence_output: [batch_size, 1, hidden_size],
    # step_cache: [batch_size, 1, num_layers, num_heads, head_size]
    sequence_output, step_cache = model.compute_transformer(
        input_ids=input_ids,
        input_segment_id=tf.fill(tensor_utils.shape(input_ids), segment_id),
        input_positions=tf.fill(tensor_utils.shape(input_ids), i),
        attention_mask=attention_mask,
        input_cache=input_cache,
        reuse=reuse)

    # [batch_size, 1, vocab_size]
    logits = model.compute_logits(sequence_output, reuse=reuse)

    def update_values(old_values, current_value):
        """Keep written slots; fill all other slots with this step's value."""
        rank = len(tensor_utils.shape(old_values))

        # is_written lifted to [batch_size, num_steps, 1, ...] and repeated
        # over the trailing dimensions of old_values.
        mask_shape = [batch_size, num_steps] + [1] * (rank - 2)
        mask_multiples = [1, 1] + list(tensor_utils.shape(old_values)[2:])
        keep_old = tf.tile(tf.reshape(is_written, mask_shape), mask_multiples)

        # The single-step value repeated over all num_steps slots.
        step_multiples = [1] * rank
        step_multiples[1] = num_steps
        repeated_step = tf.tile(current_value, step_multiples)

        return tf.where(keep_old, old_values, repeated_step)

    # [batch_size, num_steps, num_layers, num_heads, head_size]
    updated_cache = TransformerCache(
        keys=update_values(decoder_cache.keys, step_cache.keys),
        values=update_values(decoder_cache.values, step_cache.values))

    next_state = DecodeState(
        encoder_cache=encoder_cache,
        encoder_cache_mask=state.encoder_cache_mask,
        output_cache=updated_cache)

    return tf.squeeze(logits, 1), next_state
def compute_rollouts(
    model,
    rc_model,
    features,
    encoder_cache,
    encoder_cache_mask,
    vocab,
    params,
):
  """Decodes rollouts with beam search and scores each one with the QA model.

  Args:
    model: utils.transformer_utils.TransformerModel instance.
    rc_model: TF Hub module for extractive QA.
    features: Input features; "question_inputs" and "answer_outputs" are read.
    encoder_cache: Transformer cache for encoded input.
    encoder_cache_mask: Input mask for the Transformer cache.
    vocab: Instance of text_utils.Vocab.
    params: Model parameters; reads "decode_length", "num_rollouts",
      "beam_length_penalty", "answer_length" and "no_answer_bias".

  Returns:
    rollouts: Instance of RolloutOutputs holding the decoded token ids, mask
      and scores, plus the rewards reshaped to [batch_size, num_rollouts, ...].
  """
  # Step 1: top-K beam search, one beam per requested rollout.
  rollout = transformer_utils.beam_search_decode(
      model=model,
      encoder_cache=encoder_cache,
      encoder_cache_mask=encoder_cache_mask,
      start_id=vocab.t2i(vocab.CLS),
      stop_id=vocab.t2i(vocab.SEP),
      segment_id=0,
      num_steps=params["decode_length"],
      beam_size=params["num_rollouts"],
      alpha=params["beam_length_penalty"],
      reuse=tf.AUTO_REUSE)

  # rollout.token_ids: [batch_size, num_rollouts, rollout_length]
  batch_size = tensor_utils.shape(rollout.token_ids, 0)
  num_rollouts = tensor_utils.shape(rollout.token_ids, 1)

  # Fold the rollout axis into the batch axis.
  # [batch_size * num_rollouts, rollout_length]
  rollout_length = tensor_utils.shape(rollout.token_ids, -1)
  flat_shape = [-1, rollout_length]
  flat_ids = tf.reshape(rollout.token_ids, flat_shape)
  flat_mask = tf.reshape(rollout.mask, flat_shape)

  # Step 2: QA rewards. Questions and answers are repeated per rollout so they
  # line up with the flattened rollouts.
  # [batch_size * num_rollouts, question_length]
  question = tensor_utils.tile_batch(features["question_inputs"], num_rollouts)

  # [batch_size * num_rollouts, answer_length]
  answer = tensor_utils.tile_batch(features["answer_outputs"], num_rollouts)

  # The first rollout position (the start token) is dropped from the context.
  rewards = compute_qa_rewards(
      question_ids=question.token_ids,
      question_mask=question.mask,
      answer_ids=answer.token_ids,
      answer_mask=answer.mask,
      context_ids=flat_ids[:, 1:],
      context_mask=flat_mask[:, 1:],
      rc_model=rc_model,
      vocab=vocab,
      max_answer_length=params["answer_length"],
      no_answer_bias=params["no_answer_bias"])

  # Restore the [batch_size, num_rollouts, ...] layout; rewards with more than
  # one dimension keep a trailing axis.
  reshaped_rewards = {
      name: tf.reshape(
          value,
          [batch_size, num_rollouts, -1]
          if len(value.shape) > 1 else [batch_size, num_rollouts])
      for name, value in rewards.items()
  }

  # Step 3: bundle rollouts with their rewards.
  rollouts = RolloutOutputs(
      token_ids=rollout.token_ids,
      mask=rollout.mask,
      scores=rollout.scores,
      rewards=reshaped_rewards)
  return rollouts
def model_fn(features, labels, mode, params, vocab):
  """Estimator model function covering training and beam-search prediction.

  Args:
    features: Dictionary of model input tensors.
    labels: Unused.
    mode: A tf.estimator.ModeKeys value.
    params: Dictionary of model parameters.
    vocab: A utils.text_utils.Vocab instance.

  Returns:
    spec: A tf.estimator.TPUEstimatorSpec.
  """
  del labels

  # --------------------------------------------------------------------------
  # Model construction and input encoding.
  # --------------------------------------------------------------------------
  is_training = mode == tf_estimator.ModeKeys.TRAIN
  config = transformer_utils.TransformerConfig.from_dict(params)
  model = transformer_utils.TransformerModel(
      config=config, is_training=is_training)

  # objects.features: [batch_size, num_regions, feature_size]
  # objects.positions: [batch_size, num_regions]
  # objects.mask: [batch_size, num_regions]
  objects = features["object_features"]
  image_features = objects.features
  image_positions = objects.positions
  image_mask = objects.mask

  # Prepend one always-on slot for the leading [IMG] token.
  # [batch_size, num_regions + 1]
  batch_size = tensor_utils.shape(image_mask, 0)
  input_mask = tf.pad(image_mask, [[0, 0], [1, 0]], constant_values=1)

  # Encode the image; only the cached transformer values are kept.
  # [batch_size, num_regions + 1, num_layers, num_heads, head_size]
  img_token_ids = tf.fill([batch_size, 1], vocab.t2i(vocab.IMG))
  _, input_cache = model.compute_image_transformer(
      input_ids=img_token_ids,
      input_image=image_features,
      input_image_mask=input_mask,
      input_positions=image_positions)

  if params.get("conditional_decoding"):
    # Extra (text) conditioning is encoded on top of the image cache, so it
    # can attend to the image. The decoder input then consists of the image
    # plus the encoded text. This is used for the LEARN function of Alg. 1 in
    # the paper.
    condition = features["condition_inputs"]

    # [batch_size, num_regions + condition_length + 1]
    input_mask = tf.concat([input_mask, condition.mask], 1)

    # [batch_size, condition_length, num_layers, num_heads, head_size]
    _, condition_cache = model.compute_transformer(
        input_ids=condition.token_ids,
        input_segment_id=condition.segment_ids,
        input_positions=condition.positions,
        attention_mask=tf.expand_dims(input_mask, 1),
        input_cache=input_cache,
        reuse=tf.AUTO_REUSE,
        conditional=True)

    # [batch_size, input_length, num_layers, num_heads, head_size]
    joined_keys = tf.concat([input_cache.keys, condition_cache.keys], 1)
    joined_values = tf.concat([input_cache.values, condition_cache.values], 1)
    input_cache = transformer_utils.TransformerCache(
        keys=joined_keys, values=joined_values)

  # --------------------------------------------------------------------------
  # Training.
  # --------------------------------------------------------------------------
  loss = None
  train_op = None
  host_call = None
  if is_training:
    token_inputs = features["token_inputs"]
    token_outputs = features["token_outputs"]

    # Forced decoding with a diagonal attention mask.
    # [batch_size, caption_length - 1, input_length + caption_length - 1]
    attention_mask = transformer_utils.compute_attention_mask(
        token_mask=token_inputs.mask, input_mask=input_mask)

    # [batch_size, caption_length - 1, hidden_size]
    target_emb, _ = model.compute_transformer(
        input_ids=token_inputs.token_ids,
        input_segment_id=token_inputs.segment_ids,
        input_positions=token_inputs.positions,
        attention_mask=attention_mask,
        input_cache=input_cache,
        reuse=tf.AUTO_REUSE)

    # [batch_size, caption_length - 1, vocab_size]
    target_logits = model.compute_logits(target_emb, reuse=tf.AUTO_REUSE)

    # MLE objective: cross-entropy weighted by the output token mask.
    loss = tf.losses.sparse_softmax_cross_entropy(
        labels=token_outputs.token_ids,
        logits=target_logits,
        weights=token_outputs.mask)

    # BERT-style optimization with linear warmup.
    train_op = optimization.create_optimizer(
        loss=loss,
        init_lr=params["learning_rate"],
        num_train_steps=params["num_train_steps"],
        num_warmup_steps=params["num_warmup_steps"],
        use_tpu=params.get("use_tpu"))

    summaries = tpu_summaries.TpuSummaries(params["model_dir"])
    summaries.scalar("loss", loss)
    host_call = summaries.get_host_call()

  # --------------------------------------------------------------------------
  # Prediction.
  # --------------------------------------------------------------------------
  predictions = None
  if mode == tf_estimator.ModeKeys.PREDICT:
    decode_output = transformer_utils.beam_search_decode(
        model=model,
        encoder_cache=input_cache,
        encoder_cache_mask=input_mask,
        start_id=vocab.t2i(vocab.CLS),
        stop_id=vocab.t2i(vocab.SEP),
        segment_id=0,
        num_steps=params["decode_length"],
        beam_size=params["beam_size"],
        alpha=params["beam_length_penalty"],
        reuse=tf.AUTO_REUSE)

    # The first decoded position (the start token) is sliced off.
    predictions = {
        "image_id": features.get("image_id", -1),
        "question_id": features.get("question_id", -1),
        "token_ids": decode_output.token_ids[:, :, 1:],
    }

  # --------------------------------------------------------------------------
  # Warm start.
  # --------------------------------------------------------------------------
  def scaffold_fn():
    """Init op run on host; loads `warm_start_path` when it is set."""
    warm_start_path = params.get("warm_start_path")
    if warm_start_path:
      checkpoint_utils.init_from_checkpoint(warm_start_path)
    return tf.train.Scaffold()

  return tf_estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      predictions=predictions,
      scaffold_fn=scaffold_fn,
      host_call=host_call,
  )
def model_fn(features, labels, mode, params, vocab):
  """Prediction-only model function that emits rollouts and their rewards.

  Args:
    features: Dictionary of model input tensors; must provide "image_id",
      "question_id" and "object_features" (plus "condition_inputs" when
      params["conditional_decoding"] is set).
    labels: Unused.
    mode: A tf.estimator.ModeKeys value; only PREDICT is accepted.
    params: Dictionary of model parameters.
    vocab: A utils.text_utils.Vocab instance.

  Returns:
    spec: A tf.estimator.tpu.TPUEstimatorSpec carrying the predictions.
  """
  del labels
  assert mode == tf.estimator.ModeKeys.PREDICT, "Mode should be PREDICT."

  # Build the transformer.
  config = transformer_utils.TransformerConfig.from_dict(params)
  model = transformer_utils.TransformerModel(
      config=config,
      is_training=(mode == tf.estimator.ModeKeys.TRAIN))

  # objects.features: [batch_size, num_regions, feature_size]
  # objects.positions: [batch_size, num_regions]
  # objects.mask: [batch_size, num_regions]
  objects = features["object_features"]
  image_features = objects.features
  image_positions = objects.positions
  image_mask = objects.mask

  # Prepend one always-on slot for the IMG token.
  batch_size = tensor_utils.shape(image_mask, 0)
  input_mask = tf.pad(image_mask, [[0, 0], [1, 0]], constant_values=1)

  # [batch_size, num_regions + 1, num_layers, num_heads, head_size]
  img_token_ids = tf.fill([batch_size, 1], vocab.t2i(vocab.IMG))
  _, input_cache = model.compute_image_transformer(
      input_ids=img_token_ids,
      input_image=image_features,
      input_image_mask=input_mask,
      input_positions=image_positions)

  if params.get("conditional_decoding"):
    # Extra (text) conditioning is encoded on top of the image cache, so it
    # can attend to the image. The decoder input then consists of the image
    # plus the encoded text. This is used for the LEARN function of Alg. 1 in
    # the paper.
    condition = features["condition_inputs"]

    # [batch_size, num_regions + condition_length + 1]
    input_mask = tf.concat([input_mask, condition.mask], 1)

    # [batch_size, condition_length, num_layers, num_heads, head_size]
    _, condition_cache = model.compute_transformer(
        input_ids=condition.token_ids,
        input_segment_id=condition.segment_ids,
        input_positions=condition.positions,
        attention_mask=tf.expand_dims(input_mask, 1),
        input_cache=input_cache,
        reuse=tf.AUTO_REUSE,
        conditional=True)

    # [batch_size, input_length, num_layers, num_heads, head_size]
    joined_keys = tf.concat([input_cache.keys, condition_cache.keys], 1)
    joined_values = tf.concat([input_cache.values, condition_cache.values], 1)
    input_cache = transformer_utils.TransformerCache(
        keys=joined_keys, values=joined_values)

  # QA model used to score the rollouts.
  rc_model = hub.Module(params["rc_model"])

  rollouts = reward_utils.compute_rollouts(
      model=model,
      rc_model=rc_model,
      features=features,
      encoder_cache=input_cache,
      encoder_cache_mask=input_mask,
      vocab=vocab,
      params=params)

  # Identifiers, decoded tokens (first position sliced off) and beam scores.
  predictions = {
      "image_id": features["image_id"],
      "question_id": features["question_id"],
      "token_ids": rollouts.token_ids[:, :, 1:],
      "scores": rollouts.scores,
  }

  # Every reward is exported under its own key.
  predictions.update(rollouts.rewards)

  def scaffold_fn():
    """Init op run on host; restores the base model checkpoint."""
    checkpoint_utils.init_from_checkpoint(params["checkpoint"])
    return tf.train.Scaffold()

  return tf.estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      predictions=predictions,
      scaffold_fn=scaffold_fn,
  )
def rc_span(
    question_ids,
    question_mask,
    context_ids,
    context_mask,
    rc_model,
    vocab,
    max_length=10,
    no_answer_bias=0,
):
  """Extracts the best-scoring answer span from the QA model run on context.

  Args:
    question_ids: <int32> [batch_size, question_len]
    question_mask: <int32> [batch_size, question_len]
    context_ids: <int32> [batch_size, context_len]
    context_mask: <int32> [batch_size, context_len]
    rc_model: Extractive question answering model.
    vocab: Instance of text_utils.Vocab.
    max_length: Max answer length.
    no_answer_bias: Log-odds ratio for answer span over NULL.

  Returns:
    answer_span: Ids of the joined RC input that fall inside the predicted
      [start, end] range, gathered with tensor_utils.boolean_mask.
    span_scores: Span scores as returned by max_scoring_span.
  """
  # Hide any stop id that appears inside the context.
  stop_id = vocab.t2i(vocab.SEP)
  not_stop = tf.cast(tf.not_equal(context_ids, stop_id), tf.int32)
  context_mask *= not_stop

  # Join question and context into a single RC input.
  input_ids, input_mask, segment_ids = _get_rc_model_input(
      question_ids=question_ids,
      question_mask=question_mask,
      context_ids=context_ids,
      context_mask=context_mask,
      vocab=vocab)

  # Start/end logits from the RC model.
  rc_inputs = {
      "input_ids": input_ids,
      "input_mask": input_mask,
      "segment_ids": segment_ids,
  }
  outputs = rc_model(
      inputs=rc_inputs, signature="extractive_qa", as_dict=True)

  # Dimensions of the joined RC input (not of the raw context).
  batch_size = tensor_utils.shape(input_ids, 0)
  rc_input_len = tensor_utils.shape(input_ids, 1)

  # Decode the best span.
  logits_shape = [-1, rc_input_len]
  start_logits = tf.reshape(outputs["start_logits"], logits_shape)
  end_logits = tf.reshape(outputs["end_logits"], logits_shape)
  start, end, span_scores = max_scoring_span(
      start_scores=start_logits,
      end_scores=end_logits,
      max_length=max_length,
      no_answer_bias=no_answer_bias)

  # Column vectors so they broadcast against the position grid.
  start = tf.reshape(start, [-1, 1])
  end = tf.reshape(end, [-1, 1])

  # in_span[b, j] is True iff start[b] <= j <= end[b].
  # [batch_size, rc_input_len]
  positions = tf.tile(
      tf.expand_dims(tf.range(rc_input_len), 0), [batch_size, 1])
  in_span = tf.logical_and(
      tf.greater_equal(positions, start), tf.less_equal(positions, end))

  # Gather the answer span from the joined input ids.
  answer_span = tensor_utils.boolean_mask(input_ids, in_span)

  return answer_span, span_scores