Example #1
def model_fn(features, labels, mode, params, vocab):
    """Model function that satisfies the Estimator API.

  Args:
    features: Dictionary of model input tensors.
    labels: Ununsed.
    mode: A tf.estimator.ModeKeys value.
    params: Dictionary of model parameters.
    vocab: A utils.text_utils.Vocab instance.

  Returns:
    spec: A tf.estimator.TPUEstimatorSpec.
  """
    del labels

    # ----------------------------------------------------------------------------
    # INITIALIZATION.
    # ----------------------------------------------------------------------------

    model = transformer_utils.TransformerModel(
        config=transformer_utils.TransformerConfig.from_dict(params),
        is_training=(mode == tf_estimator.ModeKeys.TRAIN))

    # image_features: [batch_size, num_regions, feature_size]
    # image_positions: [batch_size, num_regions]
    # image_mask: [batch_size, num_regions]
    image_features = features["object_features"].features
    image_positions = features["object_features"].positions
    image_mask = features["object_features"].mask

    # Expand mask by 1 to account for the leading [IMG] token.
    # [batch_size, num_regions + 1]
    batch_size = tensor_utils.shape(image_mask, 0)
    input_mask = tf.pad(image_mask, [[0, 0], [1, 0]], constant_values=1)
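    # E.g. (illustrative): image_mask = [[1, 1, 0]] -> input_mask = [[1, 1, 1, 0]].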

    # Encode the image and store the cached transformer values.
    # [batch_size, num_regions + 1, num_layers, num_heads, head_size]
    _, input_cache = model.compute_image_transformer(
        input_ids=tf.fill([batch_size, 1], vocab.t2i(vocab.IMG)),
        input_image=image_features,
        input_image_mask=input_mask,
        input_positions=image_positions)

    if params.get("conditional_decoding"):
        # Add additional (text) conditioning information to the input cache.
        # The conditioning information gets to see the image information.
        # The new input consists of both the image and the extra encoded text.
        # This is used for the LEARN function of Alg. 1 in the paper.

        # [batch_size, num_regions + condition_length + 1]
        input_mask = tf.concat([input_mask, features["condition_inputs"].mask],
                               1)

        # [batch_size, condition_length, num_layers, num_heads, head_size]
        _, condition_cache = model.compute_transformer(
            input_ids=features["condition_inputs"].token_ids,
            input_segment_id=features["condition_inputs"].segment_ids,
            input_positions=features["condition_inputs"].positions,
            attention_mask=tf.expand_dims(input_mask, 1),
            input_cache=input_cache,
            reuse=tf.AUTO_REUSE,
            conditional=True)

        # [batch_size, input_length, num_layers, num_heads, head_size]
        input_cache = transformer_utils.TransformerCache(
            keys=tf.concat([input_cache.keys, condition_cache.keys], 1),
            values=tf.concat([input_cache.values, condition_cache.values], 1))
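
        # Through this combined cache, the decoder now attends over both the
        # image regions and the encoded conditioning text.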

    # ----------------------------------------------------------------------------
    # TRAINING
    # ----------------------------------------------------------------------------

    if mode == tf_estimator.ModeKeys.TRAIN:
        # During training, apply forced decoding with a diagonal attention mask.
        # [batch_size, caption_length - 1, input_length + caption_length - 1]
        attention_mask = transformer_utils.compute_attention_mask(
            token_mask=features["token_inputs"].mask, input_mask=input_mask)

        # [batch_size, caption_length - 1, hidden_size]
        target_emb, _ = model.compute_transformer(
            input_ids=features["token_inputs"].token_ids,
            input_segment_id=features["token_inputs"].segment_ids,
            input_positions=features["token_inputs"].positions,
            attention_mask=attention_mask,
            input_cache=input_cache,
            reuse=tf.AUTO_REUSE)

        # [batch_size, caption_length - 1, vocab_size]
        target_logits = model.compute_logits(target_emb, reuse=tf.AUTO_REUSE)

        # Compute the MLE objective (cross-entropy loss).
        loss = tf.losses.sparse_softmax_cross_entropy(
            labels=features["token_outputs"].token_ids,
            logits=target_logits,
            weights=features["token_outputs"].mask)

        # BERT-style optimization with linear warmup.
        train_op = optimization.create_optimizer(
            loss=loss,
            init_lr=params["learning_rate"],
            num_train_steps=params["num_train_steps"],
            num_warmup_steps=params["num_warmup_steps"],
            use_tpu=params.get("use_tpu"))

        summaries = tpu_summaries.TpuSummaries(params["model_dir"])
        summaries.scalar("loss", loss)
        host_call = summaries.get_host_call()
    else:
        loss = None
        train_op = None
        host_call = None

    # ----------------------------------------------------------------------------
    # TESTING.
    # ----------------------------------------------------------------------------

    if mode == tf_estimator.ModeKeys.PREDICT:
        decode_output = transformer_utils.beam_search_decode(
            model=model,
            encoder_cache=input_cache,
            encoder_cache_mask=input_mask,
            start_id=vocab.t2i(vocab.CLS),
            stop_id=vocab.t2i(vocab.SEP),
            segment_id=0,
            num_steps=params["decode_length"],
            beam_size=params["beam_size"],
            alpha=params["beam_length_penalty"],
            reuse=tf.AUTO_REUSE)
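        # Keep the decoded token ids, dropping the leading start ([CLS]) token.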
        predictions = dict(image_id=features.get("image_id", -1),
                           question_id=features.get("question_id", -1),
                           token_ids=decode_output.token_ids[:, :, 1:])
    else:
        predictions = None

    # ----------------------------------------------------------------------------
    # WARM-START.
    # ----------------------------------------------------------------------------

    # Initialize from pretrained model.
    def scaffold_fn():
        """Init op run on host."""
        checkpoint = params.get("warm_start_path")
        if checkpoint:
            checkpoint_utils.init_from_checkpoint(checkpoint)
        return tf.train.Scaffold()

    return tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions,
        scaffold_fn=scaffold_fn,
        host_call=host_call,
    )
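
For context, a minimal sketch of how a model_fn with this extra `vocab` argument might be bound to a TPUEstimator: the Estimator API only passes (features, labels, mode, params), so `vocab` has to be closed over. The helper name, `run_config`, and the literal batch sizes below are illustrative assumptions, not taken from this code.

import functools

from tensorflow.compat.v1 import estimator as tf_estimator


def make_estimator(params, vocab, run_config):
    """Illustrative only: binds `vocab` and constructs the TPUEstimator."""
    return tf_estimator.tpu.TPUEstimator(
        model_fn=functools.partial(model_fn, vocab=vocab),
        config=run_config,  # A tf_estimator.tpu.RunConfig instance.
        params=params,
        use_tpu=params.get("use_tpu", False),
        train_batch_size=32,  # Placeholder values.
        predict_batch_size=32)
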
def model_fn(features, labels, mode, params, vocab):
    """Model function that satisfies the Estimator API.

  Args:
    features: Dictionary of model input tensors.
    labels: Ununsed.
    mode: A tf.estimator.ModeKeys value.
    params: Dictionary of model parameters.
    vocab: A utils.text_utils.Vocab instance.

  Returns:
    spec: A tf.estimator.TPUEstimatorSpec.
  """
    del labels

    # ----------------------------------------------------------------------------
    # INITIALIZATION.
    # ----------------------------------------------------------------------------

    # Update model config from the pre-trained checkpoint.
    model = transformer_utils.TransformerModel(
        config=transformer_utils.TransformerConfig.from_dict(params),
        is_training=(mode == tf_estimator.ModeKeys.TRAIN))

    # Initialize QA model.
    rc_model = hub.Module(params["rc_model"])

    # image_features: [batch_size, num_regions, feature_size]
    # image_positions: [batch_size, num_regions]
    # image_mask: [batch_size, num_regions]
    image_features = features["object_features"].features
    image_positions = features["object_features"].positions
    image_mask = features["object_features"].mask

    # Expand mask by 1 to account for the leading [IMG] token.
    # [batch_size, num_regions + 1]
    batch_size = tensor_utils.shape(image_mask, 0)
    input_mask = tf.pad(image_mask, [[0, 0], [1, 0]], constant_values=1)

    # Encode the image and store the cached transformer values.
    # [batch_size, num_regions + 1, num_layers, num_heads, head_size]
    _, input_cache = model.compute_image_transformer(
        input_ids=tf.fill([batch_size, 1], vocab.t2i(vocab.IMG)),
        input_image=image_features,
        input_image_mask=input_mask,
        input_positions=image_positions)

    # ----------------------------------------------------------------------------
    # TRAINING
    # ----------------------------------------------------------------------------

    if mode == tf_estimator.ModeKeys.TRAIN:
        # MIXER-style training objective consists of two parts:
        #   1) Policy gradient on rewarded rollouts.
        #   2) MLE regularization on references.
        # The full loss is L_total = L_pg + L_mle.

        # Step 1: Policy gradient.
        # Compute and score policy rollouts (multiple per image).
        rollouts = reward_utils.compute_rollouts(model=model,
                                                 rc_model=rc_model,
                                                 features=features,
                                                 encoder_cache=input_cache,
                                                 encoder_cache_mask=input_mask,
                                                 vocab=vocab,
                                                 params=params)

        # Using a self-critical baseline, R'(y) = R(y) - b, where b = R(y*) is the
        # reward of the top beam y* = argmax_y p(y|x), sample a single rollout
        # with non-zero reward.
        rollout, reward = reward_utils.sample_from_rollouts(
            rollouts=rollouts,
            baseline=rollouts.rewards[params["reward"]][:, 0],
            reward_type=params["reward"])
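        # Illustrative numbers only: if the top beam earns reward b = 0.6 and a
        # sampled rollout earns R(y) = 0.9, its advantage is R'(y) = +0.3 and its
        # log-probability gets pushed up below; a rollout with R(y) = 0.2 gets
        # R'(y) = -0.4 and gets pushed down.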

        # Compute the probability of the rollout (differentiable for backprop).
        # [batch_size, decode_length, input_length + decode_length]
        rollout_attention_mask = transformer_utils.compute_attention_mask(
            token_mask=rollout.mask[:, :-1], input_mask=input_mask)

        # [batch_size, decode_length, hidden_size]
        rollout_emb, _ = model.compute_transformer(
            input_ids=rollout.token_ids[:, :-1],
            input_segment_id=rollout.segment_ids[:, :-1],
            input_positions=rollout.positions[:, :-1],
            attention_mask=rollout_attention_mask,
            input_cache=input_cache,
            reuse=tf.AUTO_REUSE)

        # [batch_size, decode_length, vocab_size]
        rollout_logits = model.compute_logits(rollout_emb, reuse=tf.AUTO_REUSE)

        # Compute the RL loss, -R(y) * log p(y|x); an explicit standalone sketch of
        # this weighted cross-entropy form follows after this function.
        # Some elements in this batch are MLE only, mask those out from the loss.
        rollout_mask = tf.cast(rollout.mask[:, 1:], tf.float32)
        pg_mask = tf.equal(features["input_type"], datasets.DatasetTypes.VQA)
        rollout_mask *= tf.expand_dims(tf.cast(pg_mask, tf.float32), 1)
        rl_loss = tf.losses.sparse_softmax_cross_entropy(
            labels=rollout.token_ids[:, 1:],
            logits=rollout_logits,
            weights=tf.expand_dims(reward, 1) * rollout_mask,
            reduction=tf.losses.Reduction.SUM)
        rl_loss = tf.math.divide_no_nan(rl_loss, tf.reduce_sum(rollout_mask))

        # Step 2: MLE on references.
        # [batch_size, decode_length, input_length + decode_length]
        reference_attention_mask = transformer_utils.compute_attention_mask(
            token_mask=features["token_inputs"].mask, input_mask=input_mask)

        # [batch_size, decode_length, hidden_size]
        target_emb, _ = model.compute_transformer(
            input_ids=features["token_inputs"].token_ids,
            input_segment_id=features["token_inputs"].segment_ids,
            input_positions=features["token_inputs"].positions,
            attention_mask=reference_attention_mask,
            input_cache=input_cache,
            reuse=tf.AUTO_REUSE)

        # [batch_size, decode_length, vocab_size]
        target_logits = model.compute_logits(target_emb, reuse=tf.AUTO_REUSE)

        # Compute the MLE objective (cross-entropy loss).
        weights = features["token_outputs"].mask
        ref_mask = tf.equal(features["input_type"],
                            datasets.DatasetTypes.REFERENCE)
        weights *= tf.expand_dims(tf.cast(ref_mask, tf.int32), 1)
        reference_loss = tf.losses.sparse_softmax_cross_entropy(
            labels=features["token_outputs"].token_ids,
            logits=target_logits,
            weights=weights)

        # Add both losses together.
        loss = rl_loss + reference_loss

        # BERT-style optimization with linear warmup.
        train_op = optimization.create_optimizer(
            loss=loss,
            init_lr=params["learning_rate"],
            num_train_steps=params["num_train_steps"],
            num_warmup_steps=params["num_warmup_steps"],
            use_tpu=params.get("use_tpu"))

        # Book-keeping.
        summaries = tpu_summaries.TpuSummaries(params["model_dir"])
        summaries.scalar("loss", loss)

        # Check what percentage of examples have non-zero reward.
        total_vqa = tf.reduce_sum(tf.cast(pg_mask, tf.float32))
        nonzero = tf.cast(tf.not_equal(reward, 0), tf.float32)
        nonzero *= tf.cast(pg_mask, tf.float32)
        total_nonzero = tf.reduce_sum(nonzero)
        summaries.scalar("density", tf.div_no_nan(total_nonzero, total_vqa))

        # Total (non-normalized) reward.
        reward = rollouts.rewards[params["reward"]][:, 0]
        reward *= tf.cast(pg_mask, tf.float32)
        total_reward = tf.reduce_sum(reward)
        summaries.scalar("reward", tf.div_no_nan(total_reward, total_vqa))
        host_call = summaries.get_host_call()
    else:
        loss = None
        train_op = None
        host_call = None

    # ----------------------------------------------------------------------------
    # TESTING.
    # ----------------------------------------------------------------------------

    if mode == tf_estimator.ModeKeys.PREDICT:
        decode_output = transformer_utils.beam_search_decode(
            model=model,
            encoder_cache=input_cache,
            encoder_cache_mask=input_mask,
            start_id=vocab.t2i(vocab.CLS),
            stop_id=vocab.t2i(vocab.SEP),
            segment_id=0,
            num_steps=params["decode_length"],
            beam_size=params["beam_size"],
            alpha=params["beam_length_penalty"],
            reuse=tf.AUTO_REUSE)
        predictions = dict(image_id=features.get("image_id", -1),
                           question_id=features.get("question_id", -1),
                           token_ids=decode_output.token_ids[:, :, 1:])
    else:
        predictions = None

    # ----------------------------------------------------------------------------
    # WARM-START.
    # ----------------------------------------------------------------------------

    # Initialize from pretrained model.
    def scaffold_fn():
        """Init op run on host."""
        checkpoint = params["base_model"]
        if params["warm_start_path"]:
            checkpoint = params["warm_start_path"]
        if checkpoint:
            checkpoint_utils.init_from_checkpoint(checkpoint)
        return tf.train.Scaffold()

    return tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions,
        scaffold_fn=scaffold_fn,
        host_call=host_call,
    )
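
The RL step above folds the (possibly negative) reward into the `weights` argument of `tf.losses.sparse_softmax_cross_entropy`, which amounts to -R(y) * log p(y|x) summed over non-masked rollout tokens. Below is a minimal standalone sketch of that term written out explicitly, assuming the reward already has the baseline subtracted; the function and argument names are illustrative, not part of this code.

import tensorflow.compat.v1 as tf


def self_critical_pg_loss(logits, token_ids, token_mask, advantage):
    """Explicit REINFORCE-with-baseline loss over a sampled rollout.

    logits:     [batch, length, vocab] decoder logits for the rollout.
    token_ids:  [batch, length] sampled token ids.
    token_mask: [batch, length] float mask over valid rollout tokens.
    advantage:  [batch] baselined reward R(y) - b for each rollout.
    """
    # Per-token log-probabilities of the sampled rollout.
    log_probs = -tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=token_ids, logits=logits)
    # -advantage * log p(y|x), masked and averaged over valid tokens.
    weighted = -tf.expand_dims(advantage, 1) * log_probs * token_mask
    return tf.math.divide_no_nan(
        tf.reduce_sum(weighted), tf.reduce_sum(token_mask))
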
def compute_rollouts(
    model,
    rc_model,
    features,
    encoder_cache,
    encoder_cache_mask,
    vocab,
    params,
):
    """Rollout model and compute rewards for each sample.

  Args:
    model: utils.transformer_utils.TransformerModel instance.
    rc_model: TF Hub module for extractive QA.
    features: Input features (questions and answers).
    encoder_cache: Transformer cache for encoded input.
    encoder_cache_mask: Input mask for the Transformer cache.
    vocab: Instance of text_utils.Vocab.
    params: Model parameters.

  Returns:
    rollout: Instance of RolloutOutputs.
  """
    # 1) First roll out the model with top-K beam search.
    rollout = transformer_utils.beam_search_decode(
        model=model,
        encoder_cache=encoder_cache,
        encoder_cache_mask=encoder_cache_mask,
        start_id=vocab.t2i(vocab.CLS),
        stop_id=vocab.t2i(vocab.SEP),
        segment_id=0,
        num_steps=params["decode_length"],
        beam_size=params["num_rollouts"],
        alpha=params["beam_length_penalty"],
        reuse=tf.AUTO_REUSE)

    # [batch_size, num_rollouts, rollout_length]
    batch_size = tensor_utils.shape(rollout.token_ids, 0)
    num_rollouts = tensor_utils.shape(rollout.token_ids, 1)
    rollout_ids = rollout.token_ids
    rollout_mask = rollout.mask

    # [batch_size * num_rollouts, rollout_length]
    rollout_length = tensor_utils.shape(rollout_ids, -1)
    rollout_ids = tf.reshape(rollout_ids, [-1, rollout_length])
    rollout_mask = tf.reshape(rollout_mask, [-1, rollout_length])

    # 2) Compute the QA rewards on the rollouts.
    # [batch_size * num_rollouts, question_length]
    question = tensor_utils.tile_batch(features["question_inputs"],
                                       num_rollouts)
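    # tile_batch repeats each example num_rollouts times so the questions and
    # answers line up with the flattened [batch_size * num_rollouts, ...] rollouts.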

    # [batch_size * num_rollouts, answer_length]
    answer = tensor_utils.tile_batch(features["answer_outputs"], num_rollouts)

    # [batch_size * num_rollouts]
    rewards = compute_qa_rewards(question_ids=question.token_ids,
                                 question_mask=question.mask,
                                 answer_ids=answer.token_ids,
                                 answer_mask=answer.mask,
                                 context_ids=rollout_ids[:, 1:],
                                 context_mask=rollout_mask[:, 1:],
                                 rc_model=rc_model,
                                 vocab=vocab,
                                 max_answer_length=params["answer_length"],
                                 no_answer_bias=params["no_answer_bias"])

    # [batch_size, num_rollouts, ...]
    reshaped_rewards = {}
    for k, v in rewards.items():
        if len(v.shape) > 1:
            v = tf.reshape(v, [batch_size, num_rollouts, -1])
        else:
            v = tf.reshape(v, [batch_size, num_rollouts])
        reshaped_rewards[k] = v

    # 3) Combine rollouts and rewards.
    rollouts = RolloutOutputs(token_ids=rollout.token_ids,
                              mask=rollout.mask,
                              scores=rollout.scores,
                              rewards=reshaped_rewards)

    return rollouts
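
`RolloutOutputs` is used above but not defined in this excerpt. A minimal compatible container, inferred from the fields accessed here (the namedtuple form is an assumption):

import collections

# Beam-search rollouts plus per-rollout QA rewards.
#   token_ids, mask: [batch_size, num_rollouts, rollout_length]
#   scores:          [batch_size, num_rollouts]
#   rewards:         dict of reward tensors, each [batch_size, num_rollouts, ...]
RolloutOutputs = collections.namedtuple(
    "RolloutOutputs", ["token_ids", "mask", "scores", "rewards"])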