def model_fn(features, labels, mode, params, vocab):
  """Model function that satisfies the Estimator API.

  Encodes the image region features (plus a leading [IMG] token, and
  optionally extra conditioning text) into a transformer key/value cache.
  In TRAIN mode it computes a teacher-forced cross-entropy loss over the
  caption tokens; in PREDICT mode it beam-search decodes from the cache.

  Args:
    features: Dictionary of model input tensors.
    labels: Unused.
    mode: A tf.estimator.ModeKeys value.
    params: Dictionary of model parameters.
    vocab: A utils.text_utils.Vocab instance.

  Returns:
    spec: A tf.estimator.TPUEstimatorSpec.
  """
  # All targets are read from `features`; the Estimator `labels` are ignored.
  del labels

  # ----------------------------------------------------------------------------
  # INITIALIZATION.
  # ----------------------------------------------------------------------------

  # Dropout etc. are enabled only in TRAIN mode via `is_training`.
  model = transformer_utils.TransformerModel(
      config=transformer_utils.TransformerConfig.from_dict(params),
      is_training=(mode == tf_estimator.ModeKeys.TRAIN))

  # image_features: [batch_size, num_regions, feature_size]
  # image_positions: [batch_size, num_regions]
  # image_mask: [batch_size, num_regions]
  image_features = features["object_features"].features
  image_positions = features["object_features"].positions
  image_mask = features["object_features"].mask

  # Expand mask by 1 to account for the leading [IMG] token.
  # [batch_size, num_regions + 1]
  batch_size = tensor_utils.shape(image_mask, 0)
  input_mask = tf.pad(image_mask, [[0, 0], [1, 0]], constant_values=1)

  # Encode the image and store the cached transformer values.
  # [batch_size, num_regions + 1, num_layers, num_heads, head_size]
  _, input_cache = model.compute_image_transformer(
      input_ids=tf.fill([batch_size, 1], vocab.t2i(vocab.IMG)),
      input_image=image_features,
      input_image_mask=input_mask,
      input_positions=image_positions)

  # Optional: a missing or falsy "conditional_decoding" param skips this block.
  if params.get("conditional_decoding"):
    # Add additional (text) conditioning information to the input cache.
    # The conditioning information gets to see the image information.
    # The new input consists of both the image and the extra encoded text.
    # This is used for the LEARN function of Alg. 1 in the paper.

    # Extend the mask so everything downstream (training attention mask and
    # beam-search encoder mask) also attends over the conditioning tokens.
    # [batch_size, num_regions + condition_length + 1]
    input_mask = tf.concat([input_mask, features["condition_inputs"].mask], 1)

    # Every conditioning token attends to the full (image + condition) mask.
    # [batch_size, condition_length, num_layers, num_heads, head_size]
    _, condition_cache = model.compute_transformer(
        input_ids=features["condition_inputs"].token_ids,
        input_segment_id=features["condition_inputs"].segment_ids,
        input_positions=features["condition_inputs"].positions,
        attention_mask=tf.expand_dims(input_mask, 1),
        input_cache=input_cache,
        reuse=tf.AUTO_REUSE,
        conditional=True)

    # Append the conditioning keys/values after the image keys/values, along
    # the sequence axis, so their order matches `input_mask`.
    # [batch_size, input_length, num_layers, num_heads, head_size]
    input_cache = transformer_utils.TransformerCache(
        keys=tf.concat([input_cache.keys, condition_cache.keys], 1),
        values=tf.concat([input_cache.values, condition_cache.values], 1))

  # ----------------------------------------------------------------------------
  # TRAINING
  # ----------------------------------------------------------------------------

  if mode == tf_estimator.ModeKeys.TRAIN:
    # During training, apply forced decoding with a diagonal attention mask.
    # NOTE(review): judging from the "caption_length - 1" shapes, token_inputs
    # and token_outputs are presumably the caption shifted by one position —
    # confirm against the input pipeline.
    # [batch_size, caption_length - 1, input_length + caption_length - 1]
    attention_mask = transformer_utils.compute_attention_mask(
        token_mask=features["token_inputs"].mask,
        input_mask=input_mask)

    # [batch_size, caption_length - 1, hidden_size]
    target_emb, _ = model.compute_transformer(
        input_ids=features["token_inputs"].token_ids,
        input_segment_id=features["token_inputs"].segment_ids,
        input_positions=features["token_inputs"].positions,
        attention_mask=attention_mask,
        input_cache=input_cache,
        reuse=tf.AUTO_REUSE)

    # [batch_size, caption_length - 1, vocab_size]
    target_logits = model.compute_logits(target_emb, reuse=tf.AUTO_REUSE)

    # Compute the MLE objective (cross-entropy loss).
    # Padding positions are excluded by using the output mask as weights.
    loss = tf.losses.sparse_softmax_cross_entropy(
        labels=features["token_outputs"].token_ids,
        logits=target_logits,
        weights=features["token_outputs"].mask)

    # BERT-style optimization with linear warmup.
    train_op = optimization.create_optimizer(
        loss=loss,
        init_lr=params["learning_rate"],
        num_train_steps=params["num_train_steps"],
        num_warmup_steps=params["num_warmup_steps"],
        use_tpu=params.get("use_tpu"))

    # Scalar summaries are exported through the TPU host call.
    summaries = tpu_summaries.TpuSummaries(params["model_dir"])
    summaries.scalar("loss", loss)
    host_call = summaries.get_host_call()
  else:
    loss = None
    train_op = None
    host_call = None

  # ----------------------------------------------------------------------------
  # TESTING.
  # ----------------------------------------------------------------------------

  if mode == tf_estimator.ModeKeys.PREDICT:
    # Beam search starts at [CLS] and stops at [SEP].
    decode_output = transformer_utils.beam_search_decode(
        model=model,
        encoder_cache=input_cache,
        encoder_cache_mask=input_mask,
        start_id=vocab.t2i(vocab.CLS),
        stop_id=vocab.t2i(vocab.SEP),
        segment_id=0,
        num_steps=params["decode_length"],
        beam_size=params["beam_size"],
        alpha=params["beam_length_penalty"],
        reuse=tf.AUTO_REUSE)
    # `[:, :, 1:]` drops the first decoded token (presumably the [CLS] start
    # token). NOTE(review): the -1 fallbacks are Python scalars rather than
    # batched tensors; TPU prediction outputs presumably need a batch
    # dimension — confirm these keys are always present in PREDICT mode.
    predictions = dict(image_id=features.get("image_id", -1),
                       question_id=features.get("question_id", -1),
                       token_ids=decode_output.token_ids[:, :, 1:])
  else:
    predictions = None

  # ----------------------------------------------------------------------------
  # WARM-START.
  # ----------------------------------------------------------------------------

  # Initialize from pretrained model.
  def scaffold_fn():
    """Init op run on host."""
    # Warm start is optional: a missing or empty "warm_start_path" skips it.
    checkpoint = params.get("warm_start_path")
    if checkpoint:
      checkpoint_utils.init_from_checkpoint(checkpoint)
    # A Scaffold is returned whether or not a checkpoint was loaded.
    return tf.train.Scaffold()

  return tf_estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      predictions=predictions,
      scaffold_fn=scaffold_fn,
      host_call=host_call,
  )
def model_fn(features, labels, mode, params, vocab):
  """Model function that satisfies the Estimator API.

  Fine-tunes the captioning transformer with a MIXER-style objective. The
  first term is policy gradient on beam-search rollouts, scored by a
  reading-comprehension (QA) model, and applied only to VQA-typed examples.
  The second term is MLE on reference captions, applied only to
  REFERENCE-typed examples. PREDICT mode beam-search decodes as usual.

  Args:
    features: Dictionary of model input tensors.
    labels: Unused.
    mode: A tf.estimator.ModeKeys value.
    params: Dictionary of model parameters.
    vocab: A utils.text_utils.Vocab instance.

  Returns:
    spec: A tf.estimator.TPUEstimatorSpec.
  """
  # All targets are read from `features`; the Estimator `labels` are ignored.
  del labels

  # ----------------------------------------------------------------------------
  # INITIALIZATION.
  # ----------------------------------------------------------------------------

  # Build the captioning model from `params`.
  # NOTE(review): an earlier comment said the config is "updated from the
  # pre-trained checkpoint". Nothing here does that, so presumably the caller
  # merges the checkpoint config into `params` — confirm.
  model = transformer_utils.TransformerModel(
      config=transformer_utils.TransformerConfig.from_dict(params),
      is_training=(mode == tf_estimator.ModeKeys.TRAIN))

  # Initialize QA model.
  # Built in every mode, although it is only used by the TRAIN branch below
  # (through compute_rollouts).
  rc_model = hub.Module(params["rc_model"])

  # image_features: [batch_size, num_regions, feature_size]
  # image_positions: [batch_size, num_regions]
  # image_mask: [batch_size, num_regions]
  image_features = features["object_features"].features
  image_positions = features["object_features"].positions
  image_mask = features["object_features"].mask

  # Expand mask by 1 to account for the leading [IMG] token.
  # [batch_size, num_regions + 1]
  batch_size = tensor_utils.shape(image_mask, 0)
  input_mask = tf.pad(image_mask, [[0, 0], [1, 0]], constant_values=1)

  # Encode the image and store the cached transformer values.
  # [batch_size, num_regions + 1, num_layers, num_heads, head_size]
  _, input_cache = model.compute_image_transformer(
      input_ids=tf.fill([batch_size, 1], vocab.t2i(vocab.IMG)),
      input_image=image_features,
      input_image_mask=input_mask,
      input_positions=image_positions)

  # ----------------------------------------------------------------------------
  # TRAINING
  # ----------------------------------------------------------------------------

  if mode == tf_estimator.ModeKeys.TRAIN:
    # MIXER-style training objective consists of two parts:
    # 1) Policy gradient on rewarded rollouts.
    # 2) MLE regularization on references.
    # The full loss is L_total = L_pg + L_mle.

    # Step 1: Policy gradient.
    # Compute and score policy rollouts (multiple per image).
    rollouts = reward_utils.compute_rollouts(model=model,
                                             rc_model=rc_model,
                                             features=features,
                                             encoder_cache=input_cache,
                                             encoder_cache_mask=input_mask,
                                             vocab=vocab,
                                             params=params)

    # Using a self-critical baseline, R'(y) = R(y) - b where b = argmax p(y|x),
    # sample a single rollout with non-zero reward.
    # The baseline is the reward of rollout index 0 for each example.
    rollout, reward = reward_utils.sample_from_rollouts(
        rollouts=rollouts,
        baseline=rollouts.rewards[params["reward"]][:, 0],
        reward_type=params["reward"])

    # Compute the probability of the rollout (back-propable).
    # The rollout is re-encoded with teacher forcing. Its tokens [:-1] are the
    # inputs; tokens [1:] are the targets (used below).
    # [batch_size, decode_length, input_length + decode_length]
    rollout_attention_mask = transformer_utils.compute_attention_mask(
        token_mask=rollout.mask[:, :-1],
        input_mask=input_mask)

    # [batch_size, decode_length, hidden_size]
    rollout_emb, _ = model.compute_transformer(
        input_ids=rollout.token_ids[:, :-1],
        input_segment_id=rollout.segment_ids[:, :-1],
        input_positions=rollout.positions[:, :-1],
        attention_mask=rollout_attention_mask,
        input_cache=input_cache,
        reuse=tf.AUTO_REUSE)

    # [batch_size, decode_length, vocab_size]
    rollout_logits = model.compute_logits(rollout_emb, reuse=tf.AUTO_REUSE)

    # Compute the RL loss, -R(y) * log p(y|x)
    # Some elements in this batch are MLE only, mask those out from the loss.
    # Only examples whose input_type is VQA contribute to the RL term.
    rollout_mask = tf.cast(rollout.mask[:, 1:], tf.float32)
    pg_mask = tf.equal(features["input_type"], datasets.DatasetTypes.VQA)
    rollout_mask *= tf.expand_dims(tf.cast(pg_mask, tf.float32), 1)
    # The per-example reward is broadcast over time steps. The loss is SUMmed
    # and then divided by the number of unmasked tokens; divide_no_nan yields
    # 0 for batches with no VQA examples.
    rl_loss = tf.losses.sparse_softmax_cross_entropy(
        labels=rollout.token_ids[:, 1:],
        logits=rollout_logits,
        weights=tf.expand_dims(reward, 1) * rollout_mask,
        reduction=tf.losses.Reduction.SUM)
    rl_loss = tf.math.divide_no_nan(rl_loss, tf.reduce_sum(rollout_mask))

    # Step 2: MLE on references.
    # [batch_size, decode_length, input_length + decode_length]
    reference_attention_mask = transformer_utils.compute_attention_mask(
        token_mask=features["token_inputs"].mask,
        input_mask=input_mask)

    # [batch_size, decode_length, hidden_size]
    target_emb, _ = model.compute_transformer(
        input_ids=features["token_inputs"].token_ids,
        input_segment_id=features["token_inputs"].segment_ids,
        input_positions=features["token_inputs"].positions,
        attention_mask=reference_attention_mask,
        input_cache=input_cache,
        reuse=tf.AUTO_REUSE)

    # [batch_size, decode_length, vocab_size]
    target_logits = model.compute_logits(target_emb, reuse=tf.AUTO_REUSE)

    # Compute the MLE objective (cross-entropy loss).
    # Only examples whose input_type is REFERENCE contribute to this term.
    # NOTE(review): the int32 cast assumes the token mask is int32 — confirm.
    weights = features["token_outputs"].mask
    ref_mask = tf.equal(features["input_type"], datasets.DatasetTypes.REFERENCE)
    weights *= tf.expand_dims(tf.cast(ref_mask, tf.int32), 1)
    reference_loss = tf.losses.sparse_softmax_cross_entropy(
        labels=features["token_outputs"].token_ids,
        logits=target_logits,
        weights=weights)

    # Add both losses together.
    loss = rl_loss + reference_loss

    # BERT-style optimization with linear warmup.
    train_op = optimization.create_optimizer(
        loss=loss,
        init_lr=params["learning_rate"],
        num_train_steps=params["num_train_steps"],
        num_warmup_steps=params["num_warmup_steps"],
        use_tpu=params.get("use_tpu"))

    # Book-keeping.
    summaries = tpu_summaries.TpuSummaries(params["model_dir"])
    summaries.scalar("loss", loss)

    # Check what percentage of examples have non-zero reward.
    # The fraction is taken over VQA examples only; `reward` here is the
    # value returned by sample_from_rollouts (presumably baseline-adjusted).
    total_vqa = tf.reduce_sum(tf.cast(pg_mask, tf.float32))
    nonzero = tf.cast(tf.not_equal(reward, 0), tf.float32)
    nonzero *= tf.cast(pg_mask, tf.float32)
    total_nonzero = tf.reduce_sum(nonzero)
    summaries.scalar("density", tf.div_no_nan(total_nonzero, total_vqa))

    # Total (non-normalized) reward.
    # This is the raw reward of rollout index 0 (not baseline-subtracted),
    # averaged over VQA examples. `reward` is intentionally re-bound here; the
    # sampled reward is not needed after this point.
    reward = rollouts.rewards[params["reward"]][:, 0]
    reward *= tf.cast(pg_mask, tf.float32)
    total_reward = tf.reduce_sum(reward)
    summaries.scalar("reward", tf.div_no_nan(total_reward, total_vqa))
    host_call = summaries.get_host_call()
  else:
    loss = None
    train_op = None
    host_call = None

  # ----------------------------------------------------------------------------
  # TESTING.
  # ----------------------------------------------------------------------------

  if mode == tf_estimator.ModeKeys.PREDICT:
    # Beam search starts at [CLS] and stops at [SEP].
    decode_output = transformer_utils.beam_search_decode(
        model=model,
        encoder_cache=input_cache,
        encoder_cache_mask=input_mask,
        start_id=vocab.t2i(vocab.CLS),
        stop_id=vocab.t2i(vocab.SEP),
        segment_id=0,
        num_steps=params["decode_length"],
        beam_size=params["beam_size"],
        alpha=params["beam_length_penalty"],
        reuse=tf.AUTO_REUSE)
    # `[:, :, 1:]` drops the first decoded token (presumably [CLS]).
    # NOTE(review): the -1 fallbacks are unbatched Python scalars — confirm
    # these keys are always present in PREDICT mode.
    predictions = dict(image_id=features.get("image_id", -1),
                       question_id=features.get("question_id", -1),
                       token_ids=decode_output.token_ids[:, :, 1:])
  else:
    predictions = None

  # ----------------------------------------------------------------------------
  # WARM-START.
  # ----------------------------------------------------------------------------

  # Initialize from pretrained model.
  def scaffold_fn():
    """Init op run on host."""
    # Both keys are indexed directly, so both must exist in `params`.
    # A non-empty warm_start_path takes precedence over base_model.
    checkpoint = params["base_model"]
    if params["warm_start_path"]:
      checkpoint = params["warm_start_path"]
    if checkpoint:
      checkpoint_utils.init_from_checkpoint(checkpoint)
    # A Scaffold is returned whether or not a checkpoint was loaded.
    return tf.train.Scaffold()

  return tf_estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      predictions=predictions,
      scaffold_fn=scaffold_fn,
      host_call=host_call,
  )
def compute_rollouts(
    model,
    rc_model,
    features,
    encoder_cache,
    encoder_cache_mask,
    vocab,
    params,
):
  """Decodes several candidate sequences per example and scores each with QA.

  The top-K beams of a beam search serve as the rollouts. Each rollout is
  used as the context passed to the reading-comprehension model, together
  with the example's question and answer, and the resulting rewards are
  attached to the rollouts.

  Args:
    model: utils.transformer_utils.TransformerModel instance.
    rc_model: TF Hub module for extractive QA.
    features: Input features (questions and answers).
    encoder_cache: Transformer cache for encoded input.
    encoder_cache_mask: Input mask for the Transformer cache.
    vocab: Instance of text_utils.Vocab.
    params: Model parameters.

  Returns:
    rollouts: A RolloutOutputs instance. It carries the beam-search
      token_ids, mask and scores unchanged, plus a `rewards` dict whose
      entries are reshaped to [batch_size, num_rollouts] (or
      [batch_size, num_rollouts, -1] for rewards of rank > 1).
  """
  # Step 1: top-K beam search; each of the K beams is one rollout.
  beams = transformer_utils.beam_search_decode(
      model=model,
      encoder_cache=encoder_cache,
      encoder_cache_mask=encoder_cache_mask,
      start_id=vocab.t2i(vocab.CLS),
      stop_id=vocab.t2i(vocab.SEP),
      segment_id=0,
      num_steps=params["decode_length"],
      beam_size=params["num_rollouts"],
      alpha=params["beam_length_penalty"],
      reuse=tf.AUTO_REUSE)

  # beams.token_ids / beams.mask: [batch_size, num_rollouts, rollout_length]
  n_batch = tensor_utils.shape(beams.token_ids, 0)
  n_beams = tensor_utils.shape(beams.token_ids, 1)
  seq_len = tensor_utils.shape(beams.token_ids, -1)

  # Fold the beam axis into the batch axis:
  # [batch_size * num_rollouts, rollout_length]
  flat_ids = tf.reshape(beams.token_ids, [-1, seq_len])
  flat_mask = tf.reshape(beams.mask, [-1, seq_len])

  # Step 2: repeat each question/answer once per rollout so rows line up with
  # the folded rollouts, then score them with the QA model. The first rollout
  # token (presumably the [CLS] start token) is dropped from the context.
  tiled_question = tensor_utils.tile_batch(features["question_inputs"], n_beams)
  tiled_answer = tensor_utils.tile_batch(features["answer_outputs"], n_beams)
  flat_rewards = compute_qa_rewards(
      question_ids=tiled_question.token_ids,
      question_mask=tiled_question.mask,
      answer_ids=tiled_answer.token_ids,
      answer_mask=tiled_answer.mask,
      context_ids=flat_ids[:, 1:],
      context_mask=flat_mask[:, 1:],
      rc_model=rc_model,
      vocab=vocab,
      max_answer_length=params["answer_length"],
      no_answer_bias=params["no_answer_bias"])

  def _unfold(value):
    # Restore the [batch_size, num_rollouts] leading axes; any trailing axes of
    # higher-rank rewards are collapsed into a single last axis.
    if len(value.shape) > 1:
      return tf.reshape(value, [n_batch, n_beams, -1])
    return tf.reshape(value, [n_batch, n_beams])

  # Step 3: bundle the un-folded beams with their per-rollout rewards.
  return RolloutOutputs(
      token_ids=beams.token_ids,
      mask=beams.mask,
      scores=beams.scores,
      rewards={name: _unfold(value) for name, value in flat_rewards.items()})