Exemplo n.º 1
0
def preprocess_mapper(features, params, lookup_table, vocab, mode):
    """Turn one raw dataset example into REFERENCE-type model features.

    Outside of PREDICT mode, one caption is drawn uniformly at random from
    the serialized ``captions`` tensor and encoded into shifted input/output
    token sequences. Object features are always parsed. The result is then
    restricted to the names in ``KEYS`` and completed with the entries of
    ``datasets.footprint(params)`` for every name that is still missing.

    Args:
        features: Mapping of feature names to values. Reads
            ``object_features`` and ``object_positions``, plus ``captions``
            when not predicting. It is updated in place before the filtered
            copy is built.
        params: Mapping of hyper-parameters. Reads ``caption_length`` and,
            when ``conditional_decoding`` is truthy, ``span_sample_p``,
            ``span_length`` and ``condition_length``.
        lookup_table: Forwarded to ``text_utils.build_text_inputs``.
        vocab: Object providing the ``CLS``, ``SEP`` and ``ANS`` tokens.
        mode: Value compared against ``tf.estimator.ModeKeys.PREDICT``.

    Returns:
        A new dict holding only the features named in ``KEYS``, plus any
        footprint entries that were not already present.
    """
    # Tag the example with its dataset type.
    features["input_type"] = tf.constant(datasets.DatasetTypes.REFERENCE)

    needs_caption = mode != tf.estimator.ModeKeys.PREDICT
    if needs_caption:
        # Deserialize all captions and sample the index of one of them.
        all_captions = tf.io.parse_tensor(features["captions"], tf.string)
        caption_count = tensor_utils.shape(all_captions, 0)
        pick = tf.random.uniform([], maxval=caption_count, dtype=tf.int32)

        encoded = text_utils.build_text_inputs(
            text=all_captions[pick],
            length=params["caption_length"],
            lookup_table=lookup_table,
            segment_id=0,
            start_token=vocab.CLS,
            end_token=vocab.SEP,
        )
        assert isinstance(encoded, text_utils.TextInputs)

        # Inputs drop the last position and outputs drop the first one
        # (presumably for next-token prediction -- confirm against the model).
        input_fields = {
            "token_ids": encoded.token_ids[:-1],
            "mask": encoded.mask[:-1],
            "segment_ids": encoded.segment_ids[:-1],
            "positions": encoded.positions[:-1],
        }
        features["token_inputs"] = text_utils.TextInputs(**input_fields)

        target_ids = encoded.token_ids[1:]
        target_mask = encoded.mask[1:]
        features["token_outputs"] = text_utils.TextOutputs(
            token_ids=target_ids, mask=target_mask)

        if params.get("conditional_decoding"):
            # Condition on a random span taken from the same caption.
            span = text_utils.get_random_span(
                text=all_captions[pick],
                p=params["span_sample_p"],
                max_span_len=params["span_length"],
            )
            features["condition_inputs"] = text_utils.build_text_inputs(
                text=span,
                length=params["condition_length"],
                lookup_table=lookup_table,
                segment_id=1,
                start_token=vocab.ANS,
            )

    raw_objects = features["object_features"]
    object_positions = features["object_positions"]
    features["object_features"] = image_utils.parse_object_features(
        raw_objects, object_positions, params)

    # Keep only the features the model consumes.
    selected = {
        name: value for name, value in features.items() if name in KEYS
    }

    # Fill in placeholder inputs so every task exposes the same structure.
    defaults = datasets.footprint(params)
    assert defaults
    for name, placeholder in defaults.items():
        selected.setdefault(name, placeholder)

    return selected
Exemplo n.º 2
0
def preprocess_mapper(features, params, lookup_table, vocab):
  """Turn one raw dataset example into VQA-type model features.

  Encodes the question and the answer, parses the object features and,
  when ``conditional_decoding`` is truthy in ``params``, builds planner
  inputs from the question/answer pair. The result is restricted to the
  names in ``KEYS`` and completed with the entries of
  ``datasets.footprint(params)`` for every name that is still missing.

  Args:
    features: Mapping of feature names to values. Reads ``question_id``,
      ``question``, ``answer``, ``object_features`` and
      ``object_positions``. It is updated in place before the filtered
      copy is built.
    params: Mapping of hyper-parameters. Reads ``question_length``,
      ``answer_length`` and, when ``conditional_decoding`` is truthy,
      ``condition_length``.
    lookup_table: Forwarded to the ``text_utils`` builders.
    vocab: Object providing the ``ANS`` token.

  Returns:
    A new dict holding only the features named in ``KEYS``, plus any
    footprint entries that were not already present.
  """
  # Tag the example with its dataset type.
  features["input_type"] = tf.constant(datasets.DatasetTypes.VQA)

  # Pin the question id to a scalar shape.
  features["question_id"] = tf.ensure_shape(features["question_id"], [])

  question_text = features["question"]
  features["question_inputs"] = text_utils.build_text_inputs(
      text=question_text,
      length=params["question_length"],
      lookup_table=lookup_table,
      segment_id=0,
  )

  answer_text = features["answer"]
  encoded_answer = text_utils.build_text_inputs(
      text=answer_text,
      length=params["answer_length"],
      lookup_table=lookup_table,
      segment_id=1,
      start_token=vocab.ANS,
  )
  assert isinstance(encoded_answer, text_utils.TextInputs)

  # NOTE(review): the inputs keep the full sequence while the outputs drop
  # the first position -- confirm the model accounts for the length gap.
  features["answer_inputs"] = encoded_answer
  target_ids = encoded_answer.token_ids[1:]
  target_mask = encoded_answer.mask[1:]
  features["answer_outputs"] = text_utils.TextOutputs(
      token_ids=target_ids, mask=target_mask)

  raw_objects = features["object_features"]
  object_positions = features["object_positions"]
  features["object_features"] = image_utils.parse_object_features(
      raw_objects, object_positions, params)

  if params.get("conditional_decoding"):
    features["condition_inputs"] = text_utils.build_planner_inputs(
        question=features["question"],
        answer=features["answer"],
        length=params["condition_length"],
        lookup_table=lookup_table,
    )

  # Keep only the features the model consumes.
  selected = {
      name: value for name, value in features.items() if name in KEYS
  }

  # Fill in placeholder inputs so every task exposes the same structure.
  for name, placeholder in datasets.footprint(params).items():
    selected.setdefault(name, placeholder)

  return selected
Exemplo n.º 3
0
def preprocess_mapper(features, params, lookup_table, vocab):
    """Turn one raw dataset example into GUIDED-type model features.

    Encodes the caption into shifted input/output token sequences, builds
    planner inputs from the question/answer pair and parses the object
    features. The result is restricted to the names in ``KEYS`` and
    completed with the entries of ``datasets.footprint(params)`` for every
    name that is still missing.

    Args:
        features: Mapping of feature names to values. Reads ``caption``,
            ``question``, ``answer``, ``object_features`` and
            ``object_positions``. It is updated in place before the filtered
            copy is built.
        params: Mapping of hyper-parameters. Reads ``caption_length`` and
            ``condition_length``.
        lookup_table: Forwarded to the ``text_utils`` builders.
        vocab: Object providing the ``CLS`` and ``SEP`` tokens.

    Returns:
        A new dict holding only the features named in ``KEYS``, plus any
        footprint entries that were not already present.
    """
    # Tag the example with its dataset type.
    features["input_type"] = tf.constant(datasets.DatasetTypes.GUIDED)

    encoded = text_utils.build_text_inputs(
        text=features["caption"],
        length=params["caption_length"],
        lookup_table=lookup_table,
        segment_id=0,
        start_token=vocab.CLS,
        end_token=vocab.SEP,
    )
    assert isinstance(encoded, text_utils.TextInputs)

    # Inputs drop the last position and outputs drop the first one
    # (presumably for next-token prediction -- confirm against the model).
    input_fields = {
        "token_ids": encoded.token_ids[:-1],
        "mask": encoded.mask[:-1],
        "segment_ids": encoded.segment_ids[:-1],
        "positions": encoded.positions[:-1],
    }
    features["token_inputs"] = text_utils.TextInputs(**input_fields)

    target_ids = encoded.token_ids[1:]
    target_mask = encoded.mask[1:]
    features["token_outputs"] = text_utils.TextOutputs(
        token_ids=target_ids, mask=target_mask)

    # Unlike the caption, the condition is always built here.
    features["condition_inputs"] = text_utils.build_planner_inputs(
        question=features["question"],
        answer=features["answer"],
        length=params["condition_length"],
        lookup_table=lookup_table,
    )

    raw_objects = features["object_features"]
    object_positions = features["object_positions"]
    features["object_features"] = image_utils.parse_object_features(
        raw_objects, object_positions, params)

    # Keep only the features the model consumes.
    selected = {
        name: value for name, value in features.items() if name in KEYS
    }

    # Fill in placeholder inputs so every task exposes the same structure.
    for name, placeholder in datasets.footprint(params).items():
        selected.setdefault(name, placeholder)

    return selected
Exemplo n.º 4
0
def preprocess_mapper(raw_text, params, lookup_table, vocab):
    """Turn one tab-separated text record into GUIDED-type model features.

    The record is split on tabs; the fields at indices 1, 2 and 3 are used
    as question, answer and caption text respectively (index 0 is not
    read). The caption is encoded into shifted input/output token
    sequences and, when ``conditional_decoding`` is truthy in ``params``,
    planner inputs are built from the question/answer pair. Entries of
    ``datasets.footprint(params)`` are added for every name still missing.

    Args:
        raw_text: The record to split; it is wrapped in a one-element list
            before being handed to ``tf.strings.split``.
        params: Mapping of hyper-parameters. Reads ``caption_length`` and,
            when ``conditional_decoding`` is truthy, ``condition_length``.
        lookup_table: Forwarded to the ``text_utils`` builders.
        vocab: Object providing the ``CLS`` and ``SEP`` tokens.

    Returns:
        A dict with ``input_type``, ``token_inputs``, ``token_outputs``,
        optionally ``condition_inputs``, and any footprint entries that
        were not already present.
    """
    # NOTE(review): the type tag is stored as a plain Python value, not a
    # tf.constant -- confirm downstream code converts it as expected.
    features = {"input_type": datasets.DatasetTypes.GUIDED}

    fields = tf.strings.split([raw_text], "\t").values
    question = fields[1]
    answer = fields[2]
    caption_text = fields[3]

    encoded = text_utils.build_text_inputs(
        text=caption_text,
        length=params["caption_length"],
        lookup_table=lookup_table,
        segment_id=0,
        start_token=vocab.CLS,
        end_token=vocab.SEP,
    )
    assert isinstance(encoded, text_utils.TextInputs)

    # Inputs drop the last position and outputs drop the first one
    # (presumably for next-token prediction -- confirm against the model).
    input_fields = {
        "token_ids": encoded.token_ids[:-1],
        "mask": encoded.mask[:-1],
        "segment_ids": encoded.segment_ids[:-1],
        "positions": encoded.positions[:-1],
    }
    features["token_inputs"] = text_utils.TextInputs(**input_fields)

    target_ids = encoded.token_ids[1:]
    target_mask = encoded.mask[1:]
    features["token_outputs"] = text_utils.TextOutputs(
        token_ids=target_ids, mask=target_mask)

    if params.get("conditional_decoding"):
        features["condition_inputs"] = text_utils.build_planner_inputs(
            question=question,
            answer=answer,
            length=params["condition_length"],
            lookup_table=lookup_table,
        )

    # Fill in placeholder inputs so every task exposes the same structure.
    for name, placeholder in datasets.footprint(params).items():
        features.setdefault(name, placeholder)

    return features