Example #1
def _get_rc_model_input(
    question_ids,
    question_mask,
    context_ids,
    context_mask,
    vocab,
):
    """Create RC module input from separate batched components.

  Args:
    question_ids: <int32> [batch_size, question_len]
    question_mask: <int32> [batch_size, question_len]
    context_ids: <int32> [batch_size, context_len]
    context_mask: <int32> [batch_size, context_len]
    vocab: Instance of text_utils.Vocab.

  Returns:
    input_ids: <int32> [batch_size, rc_input_len]
    input_mask: <int32> [batch_size, rc_input_len]
    segment_ids: <int32> [batch_size, rc_input_len]
  """
    # Get batch size.
    batch_size = tensor_utils.shape(context_ids, 0)

    # Get special tokens.
    cls = vocab.t2i(vocab.CLS)
    sep = vocab.t2i(vocab.SEP)

    # Join question, context, and special tokens.
    cls_batch = tf.fill([batch_size, 1], cls)
    sep_batch = tf.fill([batch_size, 1], sep)
    input_ids = tf.concat(
        [cls_batch, question_ids, sep_batch, context_ids, sep_batch], axis=1)

    # Create and join segment ids.
    segment_a_ids = tf.fill(
        [batch_size, tensor_utils.shape(question_ids, 1) + 2], 0)
    segment_b_ids = tf.fill(
        [batch_size, tensor_utils.shape(context_ids, 1) + 1], 1)
    segment_ids = tf.concat([segment_a_ids, segment_b_ids], axis=1)

    # Create joined mask, accounting for the special-token gaps.
    gap_mask = tf.fill([batch_size, 1], 1)
    input_mask = tf.concat(
        [gap_mask, question_mask, gap_mask, context_mask, gap_mask], axis=1)
    bool_mask = tf.cast(input_mask, tf.bool)

    # Select unmasked items and move all padding to the end.
    # Right now this looks like this:
    #   [CLS] X X X [PAD] ... [SEP] Y Y Y [PAD] ... [SEP] [PAD] ...
    # And we want to change it to look like this:
    #   [CLS] X X X [SEP] Y Y Y [SEP] [PAD] ...
    input_ids = tensor_utils.boolean_mask(input_ids, bool_mask)
    input_mask = tensor_utils.boolean_mask(input_mask, bool_mask)
    segment_ids = tensor_utils.boolean_mask(segment_ids, bool_mask)

    return input_ids, input_mask, segment_ids
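
A minimal standalone sketch of the compaction step above (illustration only; it
assumes tensor_utils.boolean_mask applies this per-row selection while keeping
the batch dimension, and uses toy token ids):

import tensorflow.compat.v1 as tf

# One packed row before compaction: [CLS] q1 q2 [PAD] [SEP] c1 [PAD] [SEP].
row_ids = tf.constant([101, 7, 8, 0, 102, 5, 0, 102])
row_mask = tf.constant([1, 1, 1, 0, 1, 1, 0, 1])

# Drop the masked slots, then re-pad on the right so the row keeps its length.
packed = tf.boolean_mask(row_ids, tf.cast(row_mask, tf.bool))
packed = tf.pad(packed, [[0, 8 - tf.size(packed)]])
# packed == [101, 7, 8, 102, 5, 102, 0, 0]
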
Example #2
    def update_values(old_values, current_value):
        """Update stored values with this time step."""
        shape = [1] * len(tensor_utils.shape(old_values))
        shape[:2] = [batch_size, num_steps]
        tile = tensor_utils.shape(old_values)
        tile[:2] = [1, 1]
        condition = tf.tile(tf.reshape(is_written, shape), tile)
        tile = [1] * len(tensor_utils.shape(old_values))
        tile[1] = num_steps
        current_value = tf.tile(current_value, tile)
        return tf.where(condition, old_values, current_value)
Example #3
    def compute_image_transformer(
        self,
        input_ids,
        input_image,
        input_image_mask,
        input_positions,
        reuse=None,
    ):
        """Build the image transformer."""
        with tf.variable_scope(self.scope_prefix + "transformer", reuse=reuse):
            with tf.variable_scope("bridge"):
                image_emb = tf.layers.dense(
                    inputs=input_image,
                    units=self.config.hidden_size,
                    activation=tf.nn.relu,
                    kernel_initializer=modeling.create_initializer(
                        self.config.initializer_range),
                    reuse=reuse)

            with tf.variable_scope("embeddings"):
                input_emb = tf.gather(self.embedding_table, input_ids)
                image_emb = tf.concat([input_emb, image_emb], axis=1)
                batch_size = tensor_utils.shape(image_emb, 0)
                sequence_length = tensor_utils.shape(image_emb, 1)
                position_emb = tf.gather(self.image_region_table,
                                         input_positions)
                position_emb = tf.pad(position_emb, [[0, 0], [1, 0], [0, 0]])
                input_order = tf.range(tensor_utils.shape(image_emb, 1))
                input_order = tf.tile(tf.expand_dims(input_order, 0),
                                      [tensor_utils.shape(image_emb, 0), 1])
                order_emb = tf.gather(self.image_order_table, input_order)
                input_segment_id = tf.fill([batch_size, sequence_length],
                                           self.IMG)
                segment_emb = tf.gather(self.segment_table, input_segment_id)
                input_emb = image_emb + position_emb + order_emb + segment_emb
                input_emb = modeling.layer_norm_and_dropout(
                    input_emb, self.config.hidden_dropout_prob)

            with tf.variable_scope("image/encoder"):
                sequence_output, output_cache = compute_transformer(
                    input_tensor=input_emb,
                    attention_mask=tf.expand_dims(input_image_mask, 1),
                    hidden_size=self.config.hidden_size,
                    num_hidden_layers=self.config.num_hidden_layers,
                    num_attention_heads=self.config.num_attention_heads,
                    intermediate_size=self.config.intermediate_size,
                    intermediate_act_fn=modeling.get_activation(
                        self.config.hidden_act),
                    hidden_dropout_prob=self.config.hidden_dropout_prob,
                    attention_probs_dropout_prob=(
                        self.config.attention_probs_dropout_prob),
                    initializer_range=self.config.initializer_range,
                    input_cache=None)
            return sequence_output, output_cache
Example #4
def build_planner_inputs(question, answer, length, lookup_table):
    """Convert text to TextInputs for conditional text planner.

  Args:
    question: <string>, space-separated token string.
    answer: <string>, space-separated token string.
    length: Length to pad or truncate to.
    lookup_table: Instance of contrib.lookup.index_table_from_tensor.

  Returns:
    Instance of TextInputs.
  """
    # Build question.
    q_tokens = tf.string_split([question]).values
    q_tokens = tf.concat([["[Q]"], q_tokens], axis=0)
    q_token_ids = tf.cast(lookup_table.lookup(q_tokens), tf.int32)
    q_len = tensor_utils.shape(q_token_ids, 0)
    q_positions = tf.range(q_len)

    # Build answer.
    a_tokens = tf.string_split([answer]).values
    a_tokens = tf.concat([["[A]"], a_tokens], axis=0)
    a_token_ids = tf.cast(lookup_table.lookup(a_tokens), tf.int32)
    a_len = tensor_utils.shape(a_token_ids, 0)
    a_positions = tf.range(a_len)

    # Combine.
    token_ids = tf.concat([q_token_ids, a_token_ids], axis=0)
    segment_ids = tf.concat([tf.fill([q_len], 2), tf.fill([a_len], 1)], axis=0)
    positions = tf.concat([q_positions, a_positions], axis=0)
    q_mask = tf.ones_like(q_token_ids)
    mask = tf.concat([q_mask, tf.ones_like(a_token_ids)], axis=0)

    # Truncate.
    token_ids = token_ids[:length]
    segment_ids = segment_ids[:length]
    mask = mask[:length]
    positions = positions[:length]

    # Pad.
    pad = [[0, length - tf.size(token_ids)]]
    token_ids = tf.pad(token_ids, pad)
    mask = tf.pad(mask, pad)
    segment_ids = tf.pad(segment_ids, pad)
    positions = tf.pad(positions, pad)

    text_input = TextInputs(token_ids=tf.ensure_shape(token_ids, [length]),
                            mask=tf.ensure_shape(mask, [length]),
                            segment_ids=tf.ensure_shape(segment_ids, [length]),
                            positions=tf.ensure_shape(positions, [length]))

    return text_input
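
The truncate-then-pad idiom above guarantees a static [length] shape whether the
tokenized text is shorter or longer than length. A minimal sketch of the same
idiom on a toy 1-D tensor (values chosen arbitrarily):

import tensorflow.compat.v1 as tf

length = 6
token_ids = tf.constant([4, 8, 15, 16])
token_ids = token_ids[:length]                        # truncate (no-op here)
pad = [[0, length - tf.size(token_ids)]]              # pad on the right only
token_ids = tf.ensure_shape(tf.pad(token_ids, pad), [length])
# token_ids == [4, 8, 15, 16, 0, 0]
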
Example #5
def sample_from_rollouts(rollouts, baseline=None, reward_type="exact_match"):
    """Sample a single example from the given rollouts.

  Args:
    rollouts: Instance of RolloutOutputs.
    baseline: <float32> [batch_size] Baseline value b for R'(y) = R(y) - b.
    reward_type: Choice between indicator, exact_match, and F1.

  Returns:
    rollout: Instance of text_utils.TextInputs.
    reward: <float32> [batch_size]
  """
    batch_size = tensor_utils.shape(rollouts.token_ids, 0)
    rollout_length = tensor_utils.shape(rollouts.token_ids, 2)

    # Self-critical baseline.
    if baseline is None:
        baseline = tf.zeros([batch_size])

    # [batch_size, num_rollouts]
    rewards = rollouts.rewards[reward_type] - tf.expand_dims(baseline, 1)

    # Mask zero reward samples.
    masked_scores = tf.where(tf.not_equal(rewards, 0),
                             tf.zeros_like(rollouts.scores),
                             tf.ones_like(rollouts.scores) * -1e8)

    # [batch_size, 1]
    sample_idx = tf.distributions.Categorical(logits=masked_scores).sample()
    sample_idx = tf.reshape(sample_idx, [batch_size, 1])

    # [batch_size]
    reward = tf.reshape(tensor_utils.gather(rewards, sample_idx), [-1])

    # [batch_size, rollout_length]
    token_ids = tf.reshape(tensor_utils.gather(rollouts.token_ids, sample_idx),
                           [batch_size, -1])
    mask = tf.reshape(tensor_utils.gather(rollouts.mask, sample_idx),
                      [batch_size, -1])
    segment_ids = tf.zeros_like(token_ids)
    positions = tf.tile(tf.expand_dims(tf.range(rollout_length), 0),
                        [batch_size, 1])

    # Create text input.
    rollout = text_utils.TextInputs(token_ids=token_ids,
                                    mask=mask,
                                    segment_ids=segment_ids,
                                    positions=positions)

    return rollout, reward
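
The -1e8 logits above effectively remove zero-reward rollouts from the
categorical distribution while keeping shapes static; sampling is then uniform
over the remaining rollouts. A toy sketch of that masking (values are made up,
not model outputs):

import tensorflow.compat.v1 as tf

scores = tf.constant([[0.5, 1.2, 0.1]])        # [batch_size=1, num_rollouts=3]
rewards = tf.constant([[0.0, 1.0, -1.0]])      # rollout 0 has zero reward
masked = tf.where(tf.not_equal(rewards, 0),
                  tf.zeros_like(scores),       # uniform over non-zero-reward rollouts
                  tf.ones_like(scores) * -1e8) # ~zero probability otherwise
sample = tf.distributions.Categorical(logits=masked).sample()
# sample is 1 or 2 with (almost) equal probability, never 0.
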
Example #6
def compute_attention_mask(token_mask, input_mask):
    """Compute attention mask."""
    batch_size = tensor_utils.shape(token_mask, 0)
    num_tokens = tensor_utils.shape(token_mask, 1)
    token_to_token = tf.ones([batch_size, num_tokens, num_tokens],
                             dtype=tf.int32)
    token_to_token = tf.matrix_band_part(token_to_token, -1, 0)
    if input_mask is not None:
        token_to_input = tf.expand_dims(input_mask, 1)
        token_to_input = tf.tile(token_to_input, [1, num_tokens, 1])
        attention_mask = tf.concat([token_to_input, token_to_token], axis=-1)
    else:
        attention_mask = token_to_token
    return attention_mask
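
The band-part call above is what makes the token-to-token block causal:
tf.matrix_band_part(ones, -1, 0) keeps only the lower triangle, so position i
can attend to positions <= i. A worked toy example (tf.linalg.band_part is the
same op under its current name):

import tensorflow.compat.v1 as tf

ones = tf.ones([1, 3, 3], dtype=tf.int32)
causal = tf.linalg.band_part(ones, -1, 0)
# causal[0] == [[1, 0, 0],
#               [1, 1, 0],
#               [1, 1, 1]]
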
Example #7
def expand_example(features, sample_one=True):
    """Expand nested tensor protos into multiple examples."""
    question_ids = tf.io.parse_tensor(features["question_ids"],
                                      out_type=tf.int64)
    questions = tf.io.parse_tensor(features["questions"], out_type=tf.string)
    answers = tf.io.parse_tensor(features["answers"], out_type=tf.string)
    captions = tf.io.parse_tensor(features["captions"], out_type=tf.string)
    num_qas = tensor_utils.shape(questions, 0)

    if sample_one:
        rid = tf.random.uniform([], maxval=num_qas, dtype=tf.int32)
        question_ids = tf.expand_dims(question_ids[rid], 0)
        questions = tf.expand_dims(questions[rid], 0)
        answers = tf.expand_dims(answers[rid], 0)
        captions = tf.expand_dims(captions[rid], 0)
        num_qas = 1

    image_ids = tf.tile(tf.expand_dims(features["image_id"], 0), [num_qas])
    images = tf.tile(tf.expand_dims(features["image"], 0), [num_qas])
    object_features = tf.tile(tf.expand_dims(features["object_features"], 0),
                              [num_qas])
    object_positions = tf.tile(tf.expand_dims(features["object_positions"], 0),
                               [num_qas])

    features = dict(image_id=image_ids,
                    image=images,
                    object_features=object_features,
                    object_positions=object_positions,
                    question_id=question_ids,
                    question=questions,
                    answer=answers,
                    caption=captions)
    return tf.data.Dataset.from_tensor_slices(features)
Example #8
def exact_match(answer_ids, prediction_ids, vocab):
    """Compute exact match score between answer tokens and prediction tokens.

  Args:
    answer_ids: <int32> [batch_size, answer_length]
    prediction_ids: <int32> [batch_size, prediction_length]
    vocab: Instance of text_utils.Vocab.

  Returns:
    score: <float32> [batch_size] tensor of {0.0, 1.0}.
  """
    batch_size = tensor_utils.shape(answer_ids, 0)

    # Get cleanable words.
    remove_ids = list(_get_normalized_set(vocab))
    remove_ids = tf.reshape(remove_ids, [1, 1, -1])
    remove_ids = tf.tile(remove_ids, [batch_size, 1, 1])

    # Clean answer: remove tokens that are in the normalized set.
    should_keep = tf.reduce_all(tf.not_equal(tf.expand_dims(answer_ids, -1),
                                             remove_ids),
                                axis=-1)
    answer_ids = tensor_utils.boolean_mask(answer_ids, should_keep)

    # Clean prediction: remove tokens that are in the normalized set.
    should_keep = tf.reduce_all(tf.not_equal(
        tf.expand_dims(prediction_ids, -1), remove_ids),
                                axis=-1)
    prediction_ids = tensor_utils.boolean_mask(prediction_ids, should_keep)

    # Cleaned lengths.
    answer_len = tensor_utils.shape(answer_ids, 1)
    prediction_len = tensor_utils.shape(prediction_ids, 1)

    # Pad the shorter one to the length of the longer.
    padding = tf.maximum(0, prediction_len - answer_len)
    answer_ids = tf.pad(answer_ids, [[0, 0], [0, padding]])
    padding = tf.maximum(0, answer_len - prediction_len)
    prediction_ids = tf.pad(prediction_ids, [[0, 0], [0, padding]])

    # Check for equality: Padded A == Padded B?
    is_equal = tf.reduce_all(tf.equal(answer_ids, prediction_ids), axis=1)
    score = tf.cast(is_equal, tf.float32)

    return score
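
Because the normalized answer and prediction can end up with different lengths,
each is right-padded with zeros to the length of the other; the two strings
match exactly iff the padded tensors are element-wise equal. A toy sketch of
that final comparison (token ids are arbitrary):

import tensorflow.compat.v1 as tf

answer_len, prediction_len = 2, 3
answer = tf.constant([[7, 9]])
prediction = tf.constant([[7, 9, 3]])
answer = tf.pad(answer, [[0, 0], [0, max(0, prediction_len - answer_len)]])
prediction = tf.pad(prediction, [[0, 0], [0, max(0, answer_len - prediction_len)]])
score = tf.cast(tf.reduce_all(tf.equal(answer, prediction), axis=1), tf.float32)
# score == [0.0]  (the prediction has an extra token, so no exact match)
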
Example #9
def preprocess_mapper(features, params, lookup_table, vocab, mode):
    """Model-specific preprocessing of features from the dataset."""
    # Set input type.
    features["input_type"] = tf.constant(datasets.DatasetTypes.REFERENCE)

    if mode != tf.estimator.ModeKeys.PREDICT:
        # Select random caption.
        captions = tf.io.parse_tensor(features["captions"], tf.string)
        num_captions = tensor_utils.shape(captions, 0)
        rid = tf.random.uniform([], maxval=num_captions, dtype=tf.int32)

        caption = text_utils.build_text_inputs(text=captions[rid],
                                               length=params["caption_length"],
                                               lookup_table=lookup_table,
                                               segment_id=0,
                                               start_token=vocab.CLS,
                                               end_token=vocab.SEP)
        assert isinstance(caption, text_utils.TextInputs)

        features["token_inputs"] = text_utils.TextInputs(
            token_ids=caption.token_ids[:-1],
            mask=caption.mask[:-1],
            segment_ids=caption.segment_ids[:-1],
            positions=caption.positions[:-1])

        features["token_outputs"] = text_utils.TextOutputs(
            token_ids=caption.token_ids[1:], mask=caption.mask[1:])

        if params.get("conditional_decoding"):
            random_span = text_utils.get_random_span(
                text=captions[rid],
                p=params["span_sample_p"],
                max_span_len=params["span_length"])

            features["condition_inputs"] = text_utils.build_text_inputs(
                text=random_span,
                length=params["condition_length"],
                lookup_table=lookup_table,
                segment_id=1,
                start_token=vocab.ANS)

    features["object_features"] = image_utils.parse_object_features(
        features["object_features"], features["object_positions"], params)

    # Remove extra inputs.
    features = {f: features[f] for f in features if f in KEYS}

    # Add dummy inputs for standardization for multi-tasking.
    footprint = datasets.footprint(params)
    assert footprint
    for k, v in footprint.items():
        if k not in features:
            features[k] = v

    return features
Example #10
def get_token_mask(token_ids, stop_id):
    """Create mask for all ids past stop_id (inclusive)."""
    batch_size = tensor_utils.shape(token_ids, 0)
    num_tokens = tensor_utils.shape(token_ids, 1)

    # Create position matrix.
    idx_range = tf.expand_dims(tf.range(num_tokens), 0)
    idx_range = tf.tile(idx_range, [batch_size, 1])

    # Find positions of stop_id.
    stop_positions = tf.where(condition=tf.equal(token_ids, stop_id),
                              x=idx_range,
                              y=tf.fill([batch_size, num_tokens], num_tokens))

    # Find earliest stop position (length).
    stop_positions = tf.reduce_min(stop_positions, -1)

    # Mask out all tokens past the first stop position.
    mask = tf.less_equal(idx_range, tf.expand_dims(stop_positions, -1))

    return tf.cast(mask, tf.int32)
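
A toy check of this masking logic with concrete shapes (stop_id = 2, token ids
are arbitrary): positions at or before the first stop id keep mask 1, everything
after it is zeroed.

import tensorflow.compat.v1 as tf

token_ids = tf.constant([[5, 7, 2, 9, 2]])
num_tokens = 5
idx_range = tf.expand_dims(tf.range(num_tokens), 0)
stop_positions = tf.where(tf.equal(token_ids, 2), idx_range,
                          tf.fill([1, num_tokens], num_tokens))
first_stop = tf.reduce_min(stop_positions, -1)                    # == [2]
mask = tf.cast(tf.less_equal(idx_range, tf.expand_dims(first_stop, -1)), tf.int32)
# mask == [[1, 1, 1, 0, 0]]
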
Example #11
def max_scoring_span(start_scores, end_scores, max_length, no_answer_bias=0):
    """Compute max scoring span, using the sum of start and end scores.

  Args:
    start_scores: <float32> [batch_size, seq_len]
    end_scores: <float32> [batch_size, seq_len]
    max_length: <int32> Max answer length.
    no_answer_bias: <float32> Log-odds threshold for "no-answer" selection. I.e.
      if log p(span=i,j)/p(span=NULL) > no_answer_bias, then select i, j as the
      span, and NULL otherwise.

  Returns:
    start: <int32> [batch_size]
    end: <int32> [batch_size]
    span_scores: <float32> [batch_size]
  """
    # Create sparse tensor of size [seq_len].
    seq_len = tensor_utils.shape(start_scores, -1)
    no_answer_bias = tf.scatter_nd([[0]], [no_answer_bias], [seq_len])
    no_answer_bias = tf.cast(no_answer_bias, tf.float32)

    # Apply bias to CLS token logits.
    no_answer_bias = tf.div(no_answer_bias, 2)
    start_scores += tf.expand_dims(no_answer_bias, 0)
    end_scores += tf.expand_dims(no_answer_bias, 0)

    # Compute outer sum, and mask to be upper triangular.
    # This gives a matrix of start[i] + end[j] scores, where j >= i.
    scores = tf.expand_dims(start_scores, 2) + tf.expand_dims(end_scores, 1)
    mask = (1 - tf.matrix_band_part(tf.ones_like(scores), 0, max_length - 1))
    # Use a large constant so that invalid spans can never be selected.
    scores -= mask * 1e10

    def map_fn(inputs):
        flattened = tf.reshape(inputs, [-1])
        argmax = tf.argmax(flattened, output_type=tf.int32)
        indices = tensor_utils.unravel_index_2d(argmax, inputs.shape)
        score = flattened[argmax]
        return indices, score

    # Return i, j indices of max-scoring entry.
    with tf.device("/cpu"):
        endpoints, span_scores = tf.map_fn(fn=map_fn,
                                           elems=scores,
                                           dtype=(tf.int32, tf.float32))
    start = endpoints[:, 0]
    end = endpoints[:, 1]

    return start, end, span_scores
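
The outer sum builds a [seq_len, seq_len] matrix of start[i] + end[j] scores, and
the band-part mask keeps only entries with i <= j < i + max_length, so the argmax
over the flattened matrix is the best valid span. A toy sketch mirroring that
masking (scores are made up):

import tensorflow.compat.v1 as tf

start = tf.constant([[0.1, 2.0, 0.3]])
end = tf.constant([[0.2, 0.5, 1.0]])
scores = tf.expand_dims(start, 2) + tf.expand_dims(end, 1)     # [1, 3, 3]
valid = tf.linalg.band_part(tf.ones_like(scores), 0, 2 - 1)    # max_length = 2
scores -= (1 - valid) * 1e10
# Best remaining entry is (i=1, j=2): 2.0 + 1.0 = 3.0.
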
Example #12
def parse_object_features(features, positions, params):
    """Parse ObjectDetectionOutput from TensorProtos."""
    features = tf.io.parse_tensor(features, tf.float32)
    positions = tf.io.parse_tensor(positions, tf.int64)
    positions = tf.cast(positions, tf.int32)
    features = features[:params["num_image_regions"]]
    num_objects = tensor_utils.shape(features, 0)
    padding = tf.maximum(0, params["num_image_regions"] - num_objects)
    features = tf.pad(features, [[0, padding], [0, 0]])
    positions = tf.pad(positions, [[0, padding]])
    features = tf.ensure_shape(
        features, [params["num_image_regions"], params["image_feature_size"]])
    positions = tf.ensure_shape(positions, [params["num_image_regions"]])
    mask = tf.pad(tf.ones(num_objects, dtype=tf.int32), [[0, padding]])
    mask = tf.ensure_shape(mask, [params["num_image_regions"]])
    output = ObjectDetectionOutput(features=features,
                                   positions=positions,
                                   mask=mask)
    return output
Example #13
def beam_search_decode(
    model,
    encoder_cache,
    encoder_cache_mask,
    start_id,
    stop_id,
    segment_id,
    num_steps,
    beam_size,
    alpha=0,
    reuse=tf.AUTO_REUSE,
):
    """Decode for a given number of steps."""
    true_batch_size = tensor_utils.shape(encoder_cache_mask, 0)
    num_layers = model.config.num_hidden_layers
    num_heads = model.config.num_attention_heads
    head_size = int(model.config.hidden_size / num_heads)

    def symbols_to_logits_fn(input_ids, i, state):
        """Go from ids to logits for next symbol."""
        # Size of expanded tensor (expanded by beam size).
        batch_size = tensor_utils.shape(input_ids, 0)

        # [batch_size, 1]
        current_step_mask = tf.ones([batch_size, 1], tf.int32)

        # [batch_size, num_steps]
        written_mask = tf.cast(tf.less(tf.range(num_steps), i), tf.int32)
        written_mask = tf.tile(tf.expand_dims(written_mask, 0),
                               [batch_size, 1])
        is_written = tf.cast(written_mask, tf.bool)

        # [batch_size, cache_size + num_steps, num_layers, num_heads, head_size]
        input_cache = TransformerCache(
            keys=tf.concat([state.encoder_cache.keys, state.output_cache.keys],
                           1),
            values=tf.concat(
                [state.encoder_cache.values, state.output_cache.values], 1))

        # [batch_size, 1, cache_size + num_steps]
        masks = [state.encoder_cache_mask, written_mask, current_step_mask]
        attention_mask = tf.concat(masks, axis=1)
        attention_mask = tf.expand_dims(attention_mask, 1)

        # sequence_output: [batch_size, 1, hidden_size],
        # step_cache: [batch_size, 1, num_layers, num_heads, head_size]
        sequence_output, step_cache = model.compute_transformer(
            input_ids=input_ids,
            input_segment_id=tf.fill(tensor_utils.shape(input_ids),
                                     segment_id),
            input_positions=tf.fill(tensor_utils.shape(input_ids), i),
            attention_mask=attention_mask,
            input_cache=input_cache,
            reuse=reuse)

        # [batch_size, 1, vocab_size]
        logits = model.compute_logits(sequence_output, reuse=reuse)

        def update_values(old_values, current_value):
            """Update stored values with this time step."""
            shape = [1] * len(tensor_utils.shape(old_values))
            shape[:2] = [batch_size, num_steps]
            tile = tensor_utils.shape(old_values)
            tile[:2] = [1, 1]
            condition = tf.tile(tf.reshape(is_written, shape), tile)
            tile = [1] * len(tensor_utils.shape(old_values))
            tile[1] = num_steps
            current_value = tf.tile(current_value, tile)
            return tf.where(condition, old_values, current_value)

        # [batch_size, num_steps, num_layers, num_heads, head_size]
        beam_output_cache = TransformerCache(
            keys=update_values(state.output_cache.keys, step_cache.keys),
            values=update_values(state.output_cache.values, step_cache.values))

        # Return new state.
        state = DecodeState(encoder_cache=state.encoder_cache,
                            encoder_cache_mask=state.encoder_cache_mask,
                            output_cache=beam_output_cache)

        return tf.squeeze(logits, 1), state

    # Initialize output cache with zeros.
    shape = [true_batch_size, num_steps, num_layers, num_heads, head_size]
    output_cache = TransformerCache(keys=tf.zeros(shape),
                                    values=tf.zeros(shape))

    # Initialize state.
    state = DecodeState(encoder_cache=encoder_cache,
                        encoder_cache_mask=encoder_cache_mask,
                        output_cache=output_cache)

    # Decode using beam search.
    decoded_ids, scores, state = beam_search.beam_search(
        symbols_to_logits_fn=symbols_to_logits_fn,
        initial_ids=tf.fill([true_batch_size], start_id),
        eos_id=stop_id,
        beam_size=beam_size,
        alpha=alpha,
        decode_length=num_steps,
        vocab_size=model.config.vocab_size,
        states=state,
        use_tpu=True)

    # Postprocess.
    flat_mask = text_utils.get_token_mask(
        tf.reshape(decoded_ids, [-1, num_steps + 1]), stop_id)
    mask = tf.reshape(flat_mask, [true_batch_size, beam_size, num_steps + 1])
    decoded_ids *= mask

    return DecodeOutput(decoded_ids, mask, scores)
Example #14
def model_fn(features, labels, mode, params, vocab):
    """Model function that satisfies the Estimator API.

  Args:
    features: Dictionary of model input tensors.
    labels: Unused.
    mode: A tf.estimator.ModeKeys value.
    params: Dictionary of model parameters.
    vocab: A utils.text_utils.Vocab instance.

  Returns:
    spec: A tf.estimator.TPUEstimatorSpec.
  """
    del labels

    # ----------------------------------------------------------------------------
    # INITIALIZATION.
    # ----------------------------------------------------------------------------

    # Update model config from the pre-trained checkpoint.
    model = transformer_utils.TransformerModel(
        config=transformer_utils.TransformerConfig.from_dict(params),
        is_training=(mode == tf_estimator.ModeKeys.TRAIN))

    # Initialize QA model.
    rc_model = hub.Module(params["rc_model"])

    # image_features: [batch_size, num_regions, feature_size]
    # image_positions: [batch_size, num_regions]
    # image_mask: [batch_size, num_regions]
    image_features = features["object_features"].features
    image_positions = features["object_features"].positions
    image_mask = features["object_features"].mask

    # Expand mask by 1 to account for the leading [IMG] token.
    # [batch_size, num_regions + 1]
    batch_size = tensor_utils.shape(image_mask, 0)
    input_mask = tf.pad(image_mask, [[0, 0], [1, 0]], constant_values=1)

    # Encode the image and store the cached transformer values.
    # [batch_size, num_regions + 1, num_layers, num_heads, head_size]
    _, input_cache = model.compute_image_transformer(
        input_ids=tf.fill([batch_size, 1], vocab.t2i(vocab.IMG)),
        input_image=image_features,
        input_image_mask=input_mask,
        input_positions=image_positions)

    # ----------------------------------------------------------------------------
    # TRAINING
    # ----------------------------------------------------------------------------

    if mode == tf_estimator.ModeKeys.TRAIN:
        # MIXER-style training objective consists of two parts:
        #   1) Policy gradient on rewarded rollouts.
        #   2) MLE regularization on references.
        # The full loss is L_total = L_pg + L_mle.

        # Step 1: Policy gradient.
        # Compute and score policy rollouts (multiple per image).
        rollouts = reward_utils.compute_rollouts(model=model,
                                                 rc_model=rc_model,
                                                 features=features,
                                                 encoder_cache=input_cache,
                                                 encoder_cache_mask=input_mask,
                                                 vocab=vocab,
                                                 params=params)

        # Using a self-critical baseline, R'(y) = R(y) - b, where b = R(y*) and
        # y* = argmax p(y|x), sample a single rollout with non-zero reward.
        rollout, reward = reward_utils.sample_from_rollouts(
            rollouts=rollouts,
            baseline=rollouts.rewards[params["reward"]][:, 0],
            reward_type=params["reward"])

        # Compute the probability of the rollout (differentiable for backprop).
        # [batch_size, decode_length, input_length + decode_length]
        rollout_attention_mask = transformer_utils.compute_attention_mask(
            token_mask=rollout.mask[:, :-1], input_mask=input_mask)

        # [batch_size, decode_length, vocab_size]
        rollout_emb, _ = model.compute_transformer(
            input_ids=rollout.token_ids[:, :-1],
            input_segment_id=rollout.segment_ids[:, :-1],
            input_positions=rollout.positions[:, :-1],
            attention_mask=rollout_attention_mask,
            input_cache=input_cache,
            reuse=tf.AUTO_REUSE)

        # [batch_size, decode_length, vocab_size]
        rollout_logits = model.compute_logits(rollout_emb, reuse=tf.AUTO_REUSE)

        # Compute the RL loss, -R(y) * log p(y|x)
        # Some elements in this batch are MLE only, mask those out from the loss.
        rollout_mask = tf.cast(rollout.mask[:, 1:], tf.float32)
        pg_mask = tf.equal(features["input_type"], datasets.DatasetTypes.VQA)
        rollout_mask *= tf.expand_dims(tf.cast(pg_mask, tf.float32), 1)
        rl_loss = tf.losses.sparse_softmax_cross_entropy(
            labels=rollout.token_ids[:, 1:],
            logits=rollout_logits,
            weights=tf.expand_dims(reward, 1) * rollout_mask,
            reduction=tf.losses.Reduction.SUM)
        rl_loss = tf.math.divide_no_nan(rl_loss, tf.reduce_sum(rollout_mask))

        # Step 2: MLE on references.
        # [batch_size, decode_length, input_length + decode_length]
        reference_attention_mask = transformer_utils.compute_attention_mask(
            token_mask=features["token_inputs"].mask, input_mask=input_mask)

        # [batch_size, decode_length, hidden_size]
        target_emb, _ = model.compute_transformer(
            input_ids=features["token_inputs"].token_ids,
            input_segment_id=features["token_inputs"].segment_ids,
            input_positions=features["token_inputs"].positions,
            attention_mask=reference_attention_mask,
            input_cache=input_cache,
            reuse=tf.AUTO_REUSE)

        # [batch_size, decode_length, vocab_size]
        target_logits = model.compute_logits(target_emb, reuse=tf.AUTO_REUSE)

        # Compute the MLE objective (cross-entropy loss).
        weights = features["token_outputs"].mask
        ref_mask = tf.equal(features["input_type"],
                            datasets.DatasetTypes.REFERENCE)
        weights *= tf.expand_dims(tf.cast(ref_mask, tf.int32), 1)
        reference_loss = tf.losses.sparse_softmax_cross_entropy(
            labels=features["token_outputs"].token_ids,
            logits=target_logits,
            weights=weights)

        # Add both losses together.
        loss = rl_loss + reference_loss

        # BERT-style optimization with linear warmup.
        train_op = optimization.create_optimizer(
            loss=loss,
            init_lr=params["learning_rate"],
            num_train_steps=params["num_train_steps"],
            num_warmup_steps=params["num_warmup_steps"],
            use_tpu=params.get("use_tpu"))

        # Book-keeping.
        summaries = tpu_summaries.TpuSummaries(params["model_dir"])
        summaries.scalar("loss", loss)

        # Check what percentage of examples have non-zero reward.
        total_vqa = tf.reduce_sum(tf.cast(pg_mask, tf.float32))
        nonzero = tf.cast(tf.not_equal(reward, 0), tf.float32)
        nonzero *= tf.cast(pg_mask, tf.float32)
        total_nonzero = tf.reduce_sum(nonzero)
        summaries.scalar("density", tf.div_no_nan(total_nonzero, total_vqa))

        # Total (non-normalized) reward.
        reward = rollouts.rewards[params["reward"]][:, 0]
        reward *= tf.cast(pg_mask, tf.float32)
        total_reward = tf.reduce_sum(reward)
        summaries.scalar("reward", tf.div_no_nan(total_reward, total_vqa))
        host_call = summaries.get_host_call()
    else:
        loss = None
        train_op = None
        host_call = None

    # ----------------------------------------------------------------------------
    # TESTING.
    # ----------------------------------------------------------------------------

    if mode == tf_estimator.ModeKeys.PREDICT:
        decode_output = transformer_utils.beam_search_decode(
            model=model,
            encoder_cache=input_cache,
            encoder_cache_mask=input_mask,
            start_id=vocab.t2i(vocab.CLS),
            stop_id=vocab.t2i(vocab.SEP),
            segment_id=0,
            num_steps=params["decode_length"],
            beam_size=params["beam_size"],
            alpha=params["beam_length_penalty"],
            reuse=tf.AUTO_REUSE)
        predictions = dict(image_id=features.get("image_id", -1),
                           question_id=features.get("question_id", -1),
                           token_ids=decode_output.token_ids[:, :, 1:])
    else:
        predictions = None

    # ----------------------------------------------------------------------------
    # WARM-START.
    # ----------------------------------------------------------------------------

    # Initialize from pretrained model.
    def scaffold_fn():
        """Init op run on host."""
        checkpoint = params["base_model"]
        if params["warm_start_path"]:
            checkpoint = params["warm_start_path"]
        if checkpoint:
            checkpoint_utils.init_from_checkpoint(checkpoint)
        return tf.train.Scaffold()

    return tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions,
        scaffold_fn=scaffold_fn,
        host_call=host_call,
    )
Example #15
def indicator_score(answer_ids, answer_mask, context_ids, vocab):
    """Compute indicator score of answer and context.

  Checks if the answer tokens are a subspan of the context.

  Args:
    answer_ids: <int32> [batch_size, answer_length]
    answer_mask: <int32> [batch_size, answer_length]
    context_ids: <int32> [batch_size, context_length]
    vocab: Instance of text_utils.Vocab.

  Returns:
    score: <float32> [batch_size] tensor of {0.0, 1.0}.
  """
    batch_size = tensor_utils.shape(answer_ids, 0)

    # Get cleanable words.
    remove_ids = list(_get_normalized_set(vocab))
    remove_ids = tf.reshape(remove_ids, [1, 1, -1])
    remove_ids = tf.tile(remove_ids, [batch_size, 1, 1])

    # Clean answer: remove tokens that are in the normalized set.
    should_keep = tf.reduce_all(tf.not_equal(tf.expand_dims(answer_ids, -1),
                                             remove_ids),
                                axis=-1)
    answer_ids = tensor_utils.boolean_mask(answer_ids, should_keep)
    answer_mask = tensor_utils.boolean_mask(answer_mask, should_keep)

    # Clean context: remove tokens that are in the normalized set.
    should_keep = tf.reduce_all(tf.not_equal(tf.expand_dims(context_ids, -1),
                                             remove_ids),
                                axis=-1)
    context_ids = tensor_utils.boolean_mask(context_ids, should_keep)

    # Cleaned lengths.
    answer_len = tensor_utils.shape(answer_ids, 1)
    context_len = tensor_utils.shape(context_ids, 1)

    # Pad start of context (to select NULL for over-length indices).
    context_ids = tf.pad(context_ids, [[0, 0], [1, 0]])
    context_len += 1

    # Sliding window approach: take the full context of length N and gather
    # it into a tensor with all windows of length M (an N x M tensor).
    # [context_len, answer_len]
    window_idx = tf.range(answer_len)
    window_idx = tf.tile(tf.expand_dims(window_idx, 0), [context_len, 1])
    offsets = tf.expand_dims(tf.range(context_len), 1)
    window_idx += offsets
    window_idx *= tf.cast(tf.less(window_idx, context_len), tf.int32)

    # [batch_size, context_len * answer_len]
    window_idx = tf.reshape(window_idx, [1, -1])
    window_idx = tf.tile(window_idx, [batch_size, 1])

    # [batch_size, context_len * answer_len]
    batch_idx = tf.range(batch_size)
    batch_idx = tf.expand_dims(batch_idx, 1)
    batch_idx = tf.tile(batch_idx, [1, context_len * answer_len])

    # [batch_size, context_len, answer_len]
    batch_idx = tf.reshape(batch_idx, [-1])
    window_idx = tf.reshape(window_idx, [-1])
    coords = tf.stack([batch_idx, window_idx], axis=1)
    window_ids = tf.gather_nd(context_ids, coords)
    window_ids = tf.reshape(window_ids, [batch_size, context_len, answer_len])

    # [batch_size, context_len, answer_len]
    answer_mask = tf.expand_dims(answer_mask, 1)
    window_ids *= answer_mask

    # Check for equality. The whole window has to match the answer, but a
    # single matching window is enough for a positive indicator value.
    answer_ids = tf.expand_dims(answer_ids, 1)
    is_equal = tf.reduce_all(tf.equal(answer_ids, window_ids), axis=-1)
    score = tf.cast(tf.reduce_any(is_equal, axis=-1), tf.float32)

    return score
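
The window index construction above enumerates every answer-length window of the
context: row r holds indices [r, r+1, ...], and indices that run past the context
are zeroed so they pick up the NULL token padded onto the front. A toy sketch
with context_len = 4 and answer_len = 2:

import tensorflow.compat.v1 as tf

context_len, answer_len = 4, 2
window_idx = tf.tile(tf.expand_dims(tf.range(answer_len), 0), [context_len, 1])
window_idx += tf.expand_dims(tf.range(context_len), 1)
window_idx *= tf.cast(tf.less(window_idx, context_len), tf.int32)
# window_idx == [[0, 1],
#                [1, 2],
#                [2, 3],
#                [3, 0]]   <- out-of-range index falls back to the NULL position
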
Example #16
    def symbols_to_logits_fn(input_ids, i, state):
        """Go from ids to logits for next symbol."""
        # Size of expanded tensor (expanded by beam size).
        batch_size = tensor_utils.shape(input_ids, 0)

        # [batch_size, 1]
        current_step_mask = tf.ones([batch_size, 1], tf.int32)

        # [batch_size, num_steps]
        written_mask = tf.cast(tf.less(tf.range(num_steps), i), tf.int32)
        written_mask = tf.tile(tf.expand_dims(written_mask, 0),
                               [batch_size, 1])
        is_written = tf.cast(written_mask, tf.bool)

        # [batch_size, cache_size + num_steps, num_layers, num_heads, head_size]
        input_cache = TransformerCache(
            keys=tf.concat([state.encoder_cache.keys, state.output_cache.keys],
                           1),
            values=tf.concat(
                [state.encoder_cache.values, state.output_cache.values], 1))

        # [batch_size, 1, cache_size + num_steps]
        masks = [state.encoder_cache_mask, written_mask, current_step_mask]
        attention_mask = tf.concat(masks, axis=1)
        attention_mask = tf.expand_dims(attention_mask, 1)

        # sequence_output: [batch_size, 1, hidden_size],
        # step_cache: [batch_size, 1, num_layers, num_heads, head_size]
        sequence_output, step_cache = model.compute_transformer(
            input_ids=input_ids,
            input_segment_id=tf.fill(tensor_utils.shape(input_ids),
                                     segment_id),
            input_positions=tf.fill(tensor_utils.shape(input_ids), i),
            attention_mask=attention_mask,
            input_cache=input_cache,
            reuse=reuse)

        # [batch_size, 1, vocab_size]
        logits = model.compute_logits(sequence_output, reuse=reuse)

        def update_values(old_values, current_value):
            """Update stored values with this time step."""
            shape = [1] * len(tensor_utils.shape(old_values))
            shape[:2] = [batch_size, num_steps]
            tile = tensor_utils.shape(old_values)
            tile[:2] = [1, 1]
            condition = tf.tile(tf.reshape(is_written, shape), tile)
            tile = [1] * len(tensor_utils.shape(old_values))
            tile[1] = num_steps
            current_value = tf.tile(current_value, tile)
            return tf.where(condition, old_values, current_value)

        # [batch_size, num_steps, num_layers, num_heads, head_size]
        beam_output_cache = TransformerCache(
            keys=update_values(state.output_cache.keys, step_cache.keys),
            values=update_values(state.output_cache.values, step_cache.values))

        # Return new state.
        state = DecodeState(encoder_cache=state.encoder_cache,
                            encoder_cache_mask=state.encoder_cache_mask,
                            output_cache=beam_output_cache)

        return tf.squeeze(logits, 1), state
Example #17
def compute_rollouts(
    model,
    rc_model,
    features,
    encoder_cache,
    encoder_cache_mask,
    vocab,
    params,
):
    """Rollout model and compute rewards for each sample.

  Args:
    model: utils.transformer_utils.TransformerModel instance.
    rc_model: TF Hub module for extractive QA.
    features: Input features (questions and answers).
    encoder_cache: Transformer cache for encoded input.
    encoder_cache_mask: Input mask for the Transformer cache.
    vocab: Instance of text_utils.Vocab.
    params: Model parameters.

  Returns:
    rollout: Instance of RolloutOutputs.
  """
    # 1) First rollout the model with top-K beam search.
    rollout = transformer_utils.beam_search_decode(
        model=model,
        encoder_cache=encoder_cache,
        encoder_cache_mask=encoder_cache_mask,
        start_id=vocab.t2i(vocab.CLS),
        stop_id=vocab.t2i(vocab.SEP),
        segment_id=0,
        num_steps=params["decode_length"],
        beam_size=params["num_rollouts"],
        alpha=params["beam_length_penalty"],
        reuse=tf.AUTO_REUSE)

    # [batch_size, num_rollouts, rollout_length]
    batch_size = tensor_utils.shape(rollout.token_ids, 0)
    num_rollouts = tensor_utils.shape(rollout.token_ids, 1)
    rollout_ids = rollout.token_ids
    rollout_mask = rollout.mask

    # [batch_size * num_rollouts, rollout_length]
    rollout_length = tensor_utils.shape(rollout_ids, -1)
    rollout_ids = tf.reshape(rollout_ids, [-1, rollout_length])
    rollout_mask = tf.reshape(rollout_mask, [-1, rollout_length])

    # 2) Compute the QA rewards on the rollouts.
    # [batch_size * num_rollouts, question_length]
    question = tensor_utils.tile_batch(features["question_inputs"],
                                       num_rollouts)

    # [batch_size * num_rollouts, answer_length]
    answer = tensor_utils.tile_batch(features["answer_outputs"], num_rollouts)

    # [batch_size * num_rollouts]
    rewards = compute_qa_rewards(question_ids=question.token_ids,
                                 question_mask=question.mask,
                                 answer_ids=answer.token_ids,
                                 answer_mask=answer.mask,
                                 context_ids=rollout_ids[:, 1:],
                                 context_mask=rollout_mask[:, 1:],
                                 rc_model=rc_model,
                                 vocab=vocab,
                                 max_answer_length=params["answer_length"],
                                 no_answer_bias=params["no_answer_bias"])

    # [batch_size, num_rollouts, ...]
    reshaped_rewards = {}
    for k, v in rewards.items():
        if len(v.shape) > 1:
            v = tf.reshape(v, [batch_size, num_rollouts, -1])
        else:
            v = tf.reshape(v, [batch_size, num_rollouts])
        reshaped_rewards[k] = v

    # 3) Combine rollouts and rewards.
    rollouts = RolloutOutputs(token_ids=rollout.token_ids,
                              mask=rollout.mask,
                              scores=rollout.scores,
                              rewards=reshaped_rewards)

    return rollouts
Example #18
def model_fn(features, labels, mode, params, vocab):
    """Model function that satisfies the Estimator API.

  Args:
    features: Dictionary of model input tensors.
    labels: Unused.
    mode: A tf.estimator.ModeKeys value.
    params: Dictionary of model parameters.
    vocab: A utils.text_utils.Vocab instance.

  Returns:
    spec: A tf.estimator.TPUEstimatorSpec.
  """
    del labels

    # ----------------------------------------------------------------------------
    # INITIALIZATION.
    # ----------------------------------------------------------------------------

    model = transformer_utils.TransformerModel(
        config=transformer_utils.TransformerConfig.from_dict(params),
        is_training=(mode == tf_estimator.ModeKeys.TRAIN))

    # image_features: [batch_size, num_regions, feature_size]
    # image_positions: [batch_size, num_regions]
    # image_mask: [batch_size, num_regions]
    image_features = features["object_features"].features
    image_positions = features["object_features"].positions
    image_mask = features["object_features"].mask

    # Expand mask by 1 to account for the leading [IMG] token.
    # [batch_size, num_regions + 1]
    batch_size = tensor_utils.shape(image_mask, 0)
    input_mask = tf.pad(image_mask, [[0, 0], [1, 0]], constant_values=1)

    # Encode the image and store the cached transformer values.
    # [batch_size, num_regions + 1, num_layers, num_heads, head_size]
    _, input_cache = model.compute_image_transformer(
        input_ids=tf.fill([batch_size, 1], vocab.t2i(vocab.IMG)),
        input_image=image_features,
        input_image_mask=input_mask,
        input_positions=image_positions)

    if params.get("conditional_decoding"):
        # Add additional (text) conditioning information to the input cache.
        # The conditioning information gets to see the image information.
        # The new input consists of both the image and the extra encoded text.
        # This is used for the LEARN function of Alg. 1 in the paper.

        # [batch_size, num_regions + condition_length + 1]
        input_mask = tf.concat([input_mask, features["condition_inputs"].mask],
                               1)

        # [batch_size, condition_length, num_layers, num_heads, head_size]
        _, condition_cache = model.compute_transformer(
            input_ids=features["condition_inputs"].token_ids,
            input_segment_id=features["condition_inputs"].segment_ids,
            input_positions=features["condition_inputs"].positions,
            attention_mask=tf.expand_dims(input_mask, 1),
            input_cache=input_cache,
            reuse=tf.AUTO_REUSE,
            conditional=True)

        # [batch_size, input_length, num_layers, num_heads, head_size]
        input_cache = transformer_utils.TransformerCache(
            keys=tf.concat([input_cache.keys, condition_cache.keys], 1),
            values=tf.concat([input_cache.values, condition_cache.values], 1))

    # ----------------------------------------------------------------------------
    # TRAINING
    # ----------------------------------------------------------------------------

    if mode == tf_estimator.ModeKeys.TRAIN:
        # During training, apply forced decoding with a diagonal attention mask.
        # [batch_size, caption_length - 1, input_length + caption_length - 1]
        attention_mask = transformer_utils.compute_attention_mask(
            token_mask=features["token_inputs"].mask, input_mask=input_mask)

        # [batch_size, caption_length - 1, hidden_size]
        target_emb, _ = model.compute_transformer(
            input_ids=features["token_inputs"].token_ids,
            input_segment_id=features["token_inputs"].segment_ids,
            input_positions=features["token_inputs"].positions,
            attention_mask=attention_mask,
            input_cache=input_cache,
            reuse=tf.AUTO_REUSE)

        # [batch_size, caption_length - 1, vocab_size]
        target_logits = model.compute_logits(target_emb, reuse=tf.AUTO_REUSE)

        # Compute the MLE objective (cross-entropy loss).
        loss = tf.losses.sparse_softmax_cross_entropy(
            labels=features["token_outputs"].token_ids,
            logits=target_logits,
            weights=features["token_outputs"].mask)

        # BERT-style optimization with linear warmup.
        train_op = optimization.create_optimizer(
            loss=loss,
            init_lr=params["learning_rate"],
            num_train_steps=params["num_train_steps"],
            num_warmup_steps=params["num_warmup_steps"],
            use_tpu=params.get("use_tpu"))

        summaries = tpu_summaries.TpuSummaries(params["model_dir"])
        summaries.scalar("loss", loss)
        host_call = summaries.get_host_call()
    else:
        loss = None
        train_op = None
        host_call = None

    # ----------------------------------------------------------------------------
    # TESTING.
    # ----------------------------------------------------------------------------

    if mode == tf_estimator.ModeKeys.PREDICT:
        decode_output = transformer_utils.beam_search_decode(
            model=model,
            encoder_cache=input_cache,
            encoder_cache_mask=input_mask,
            start_id=vocab.t2i(vocab.CLS),
            stop_id=vocab.t2i(vocab.SEP),
            segment_id=0,
            num_steps=params["decode_length"],
            beam_size=params["beam_size"],
            alpha=params["beam_length_penalty"],
            reuse=tf.AUTO_REUSE)
        predictions = dict(image_id=features.get("image_id", -1),
                           question_id=features.get("question_id", -1),
                           token_ids=decode_output.token_ids[:, :, 1:])
    else:
        predictions = None

    # ----------------------------------------------------------------------------
    # WARM-START.
    # ----------------------------------------------------------------------------

    # Initialize from pretrained model.
    def scaffold_fn():
        """Init op run on host."""
        checkpoint = params.get("warm_start_path")
        if checkpoint:
            checkpoint_utils.init_from_checkpoint(checkpoint)
        return tf.train.Scaffold()

    return tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions,
        scaffold_fn=scaffold_fn,
        host_call=host_call,
    )
Example #19
def model_fn(features, labels, mode, params, vocab):
    """Model function."""
    del labels
    assert mode == tf.estimator.ModeKeys.PREDICT, "Mode should be PREDICT."

    # Initialize transformer model.
    model = transformer_utils.TransformerModel(
        config=transformer_utils.TransformerConfig.from_dict(params),
        is_training=(mode == tf.estimator.ModeKeys.TRAIN))

    # image_features: [batch_size, num_regions, feature_size]
    # image_positions: [batch_size, num_regions]
    # image_mask: [batch_size, num_regions]
    image_features = features["object_features"].features
    image_positions = features["object_features"].positions
    image_mask = features["object_features"].mask

    # Expand mask by 1 for IMG token.
    batch_size = tensor_utils.shape(image_mask, 0)
    input_mask = tf.pad(image_mask, [[0, 0], [1, 0]], constant_values=1)

    # [batch_size, num_regions + 1, num_layers, num_heads, head_size]
    _, input_cache = model.compute_image_transformer(
        input_ids=tf.fill([batch_size, 1], vocab.t2i(vocab.IMG)),
        input_image=image_features,
        input_image_mask=input_mask,
        input_positions=image_positions)

    # Add conditioning information to input cache.
    if params.get("conditional_decoding"):
        # Add additional (text) conditioning information to the input cache.
        # The conditioning information gets to see the image information.
        # The new input consists of both the image and the extra encoded text.
        # This is used for the LEARN function of Alg. 1 in the paper.

        # [batch_size, num_regions + condition_length + 1]
        input_mask = tf.concat([input_mask, features["condition_inputs"].mask],
                               1)

        # [batch_size, condition_length, num_layers, num_heads, head_size]
        _, condition_cache = model.compute_transformer(
            input_ids=features["condition_inputs"].token_ids,
            input_segment_id=features["condition_inputs"].segment_ids,
            input_positions=features["condition_inputs"].positions,
            attention_mask=tf.expand_dims(input_mask, 1),
            input_cache=input_cache,
            reuse=tf.AUTO_REUSE,
            conditional=True)

        # [batch_size, input_length, num_layers, num_heads, head_size]
        input_cache = transformer_utils.TransformerCache(
            keys=tf.concat([input_cache.keys, condition_cache.keys], 1),
            values=tf.concat([input_cache.values, condition_cache.values], 1))

    # Initialize QA model.
    rc_model = hub.Module(params["rc_model"])

    # Compute rollouts.
    rollouts = reward_utils.compute_rollouts(model=model,
                                             rc_model=rc_model,
                                             features=features,
                                             encoder_cache=input_cache,
                                             encoder_cache_mask=input_mask,
                                             vocab=vocab,
                                             params=params)

    # Add to predictions.
    predictions = dict(
        image_id=features["image_id"],
        question_id=features["question_id"],
        token_ids=rollouts.token_ids[:, :, 1:],
        scores=rollouts.scores,
    )

    # Add all rewards.
    for k, v in rollouts.rewards.items():
        predictions[k] = v

    # Initialize base model.
    def scaffold_fn():
        """Init op run on host."""
        checkpoint_utils.init_from_checkpoint(params["checkpoint"])
        return tf.train.Scaffold()

    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions,
        scaffold_fn=scaffold_fn,
    )
Example #20
def rc_span(
    question_ids,
    question_mask,
    context_ids,
    context_mask,
    rc_model,
    vocab,
    max_length=10,
    no_answer_bias=0,
):
    """Computes exact match score from QA model run on context.

  Args:
    question_ids: <int32> [batch_size, question_len]
    question_mask: <int32> [batch_size, question_len]
    context_ids: <int32> [batch_size, context_len]
    context_mask: <int32> [batch_size, context_len]
    rc_model: Extractive question answering model.
    vocab: Instance of text_utils.Vocab.
    max_length: Max answer length.
    no_answer_bias: Log-odds ratio for answer span over NULL.

  Returns:
    answer_span: <int32> [batch_size, answer_len] padded ids of the selected span.
    span_scores: <float32> [batch_size] scores of the selected spans.
  """
    # Mask out stop id in context if present.
    stop_id = vocab.t2i(vocab.SEP)
    stop_mask = tf.cast(tf.not_equal(context_ids, stop_id), tf.int32)
    context_mask *= stop_mask

    # Prepare rc inputs.
    input_ids, input_mask, segment_ids = _get_rc_model_input(
        question_ids=question_ids,
        question_mask=question_mask,
        context_ids=context_ids,
        context_mask=context_mask,
        vocab=vocab)

    # Get start/end logits from RC model.
    outputs = rc_model(inputs=dict(input_ids=input_ids,
                                   input_mask=input_mask,
                                   segment_ids=segment_ids),
                       signature="extractive_qa",
                       as_dict=True)

    # Dimensions
    batch_size = tensor_utils.shape(input_ids, 0)
    context_len = tensor_utils.shape(input_ids, 1)

    # Decode span.
    start_logits = tf.reshape(outputs["start_logits"], [-1, context_len])
    end_logits = tf.reshape(outputs["end_logits"], [-1, context_len])
    start, end, span_scores = max_scoring_span(start_scores=start_logits,
                                               end_scores=end_logits,
                                               max_length=max_length,
                                               no_answer_bias=no_answer_bias)

    # Expand shape to be compatible for broadcasting.
    start = tf.reshape(start, [-1, 1])
    end = tf.reshape(end, [-1, 1])

    # Create mask where mask[i, j] = True if start[i] <= j <= end[i].
    # [batch_size, max_rc_input_len]
    mask = tf.tile(tf.expand_dims(tf.range(context_len), 0), [batch_size, 1])
    mask = tf.logical_and(tf.greater_equal(mask, start),
                          tf.less_equal(mask, end))

    # Gather padded answer span from context.
    answer_span = tensor_utils.boolean_mask(input_ids, mask)

    return answer_span, span_scores
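
The start/end indices are turned into a per-row span mask (True for
start <= j <= end), and the answer tokens are then compacted out of the packed
input with the same boolean_mask trick as in Example #1. A toy sketch of the
span mask (indices are arbitrary):

import tensorflow.compat.v1 as tf

start = tf.constant([[1]])
end = tf.constant([[3]])
positions = tf.expand_dims(tf.range(5), 0)            # [[0, 1, 2, 3, 4]]
mask = tf.logical_and(tf.greater_equal(positions, start),
                      tf.less_equal(positions, end))
# mask == [[False, True, True, True, False]]
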