Example #1
0
    def test_token_to_char_ids_mapper(self):
        """Verify the char-ID mapper on a tiny two-token dataset.

        The mapped dataset must keep the string feature "s" and gain an int32
        feature "s_cid" whose per-element shape is [4]. For the one-character
        tokens used here, each "s_cid" row is expected to be BOW_CHAR, the
        character's code, EOW_CHAR and then PAD_CHAR.
        """
        tokens = ["a", "b"]
        # Second mapper argument; asserted below as the length of "s_cid".
        cid_len = 4

        with tf.Graph().as_default():
            source = tf.data.Dataset.from_tensor_slices({"s": tokens})
            mapper = char_utils.token_to_char_ids_mapper(["s"], cid_len)
            mapped = source.map(mapper)

            # Static structure of a single (unbatched) element.
            self.assertDictEqual(mapped.output_types,
                                 {"s": tf.string, "s_cid": tf.int32})
            self.assertDictEqual(mapped.output_shapes,
                                 {"s": [], "s_cid": [cid_len]})

            # Fetch both examples together as one batch.
            batched = mapped.batch(len(tokens))
            batch = batched.make_one_shot_iterator().get_next()
            with tf.Session() as sess:
                tf_s, tf_s_cid = sess.run([batch["s"], batch["s_cid"]])

            self.assertAllEqual(tf_s, tokens)

            # One expected row per token: BOW, character code, EOW, PAD.
            expected_cids = [
                [
                    char_utils.BOW_CHAR,
                    ord(token),
                    char_utils.EOW_CHAR,
                    char_utils.PAD_CHAR,
                ]
                for token in tokens
            ]
            self.assertAllEqual(tf_s_cid, expected_cids)
Example #2
0
def preprocess_mapper(features, lookup_table):
    """Apply the model-specific preprocessing steps to one features dict.

    Steps, in order:
      1. Clip ``features["context"]`` to ``FLAGS.max_context_len`` entries.
      2. Add input lengths ("question_len" and "context_len").
      3. Add word IDs ("question_wid" and "context_wid") via ``lookup_table``.
      4. Add character IDs ("question_cid" and "context_cid").

    Args:
        features: Dict of dataset features containing at least "question"
            and "context". Its "context" entry is overwritten in place.
        lookup_table: Mapping handed to ``dataset_utils.string_to_int_mapper``
            to turn tokens into word IDs.

    Returns:
        The features returned by the final (character-ID) mapper.
    """
    token_keys = ("question", "context")

    # Over-long contexts are cut down before any derived feature is computed.
    max_len = FLAGS.max_context_len
    features["context"] = features["context"][:max_len]

    # Lengths: "question_len" and "context_len".
    add_lengths = dataset_utils.length_mapper(list(token_keys))
    features = add_lengths(features)

    # Word IDs: "question_wid" and "context_wid".
    add_word_ids = dataset_utils.string_to_int_mapper(
        list(token_keys), mapping=lookup_table, suffix="_wid")
    features = add_word_ids(features)

    # Character IDs: "question_cid" and "context_cid".
    add_char_ids = char_utils.token_to_char_ids_mapper(list(token_keys))
    return add_char_ids(features)