示例#1
0
 def test_model_with_lookup(self):
     """Run an ``IndexLookup`` layer inside a Keras functional model.

     The layer is built with the vocabulary ``["A", "B", "C"]``. The
     expected output maps those tokens to 0, 1 and 2, and maps the two
     tokens that are not in the vocabulary ("D" and "E") to 3.
     """
     string_input = tf.keras.Input(shape=(1, ), dtype=tf.string)
     lookup_layer = IndexLookup(vocabulary=["A", "B", "C"])
     model = tf.keras.Model(
         inputs=string_input, outputs=lookup_layer(string_input)
     )

     tokens = tf.constant([["A"], ["C"], ["B"], ["D"], ["E"]])
     expected_ids = np.array([[0], [2], [1], [3], [3]], dtype=int)

     actual_ids = model.call(tokens).numpy()
     self.assertTrue(np.array_equal(expected_ids, actual_ids))
示例#2
0
 def test_lookup_with_file(self):
     """Build an ``IndexLookup`` from a vocabulary file and check it.

     A three-line vocabulary (A, B, C) is written to a file inside a
     temporary directory. The file path is passed as ``vocabulary`` and
     the resulting layer is handed to ``self._check_lookup``.
     """
     vocab_lines = ["A\n", "B\n", "C\n"]
     with tempfile.TemporaryDirectory() as temp_dir:
         vocab_path = os.path.join(temp_dir, "vocab_test.txt")
         with open(vocab_path, "w") as vocab_fp:
             vocab_fp.writelines(vocab_lines)
         # Create and check the layer while the temporary directory, and
         # therefore the vocabulary file, still exists.
         layer = IndexLookup(vocabulary=vocab_path)
         self._check_lookup(layer)
def transform_from_code_gen(source_inputs):
    """Turn per-feature inputs into wide and deep embedding outputs.

    Each feature is first passed through ``ToSparse`` and then through one
    of ``Hashing``, ``IndexLookup`` or ``Discretization``. The results are
    merged into three groups with ``ConcatenateWithOffset`` and every group
    is embedded with ``SparseEmbedding``.

    Args:
        source_inputs: Object indexed by feature name that supports
            ``copy()``. The copy is what gets indexed, using the keys
            "education", "occupation", "native_country", "workclass",
            "marital_status", "relationship", "race", "sex", "age",
            "capital_gain", "capital_loss" and "hours_per_week".

    Returns:
        A ``(wide_embeddings_out, deep_embeddings_out)`` tuple. The first
        item is a list with the wide embeddings of groups 1 and 2; the
        second is a list with the deep embeddings of groups 1, 2 and 3.
    """
    inputs = source_inputs.copy()

    # The helpers below build a layer first and only then create and apply
    # the ``ToSparse`` layer, so layers are instantiated in the same order
    # as if every expression were written out in full.
    def _sparse(feature_name):
        return ToSparse()(inputs[feature_name])

    def _hashed(config, feature_name):
        return Hashing(config.hash_bucket_size)(_sparse(feature_name))

    def _looked_up(config, feature_name):
        return IndexLookup(config.vocabulary_list)(_sparse(feature_name))

    def _bucketized(config, feature_name):
        return Discretization(config.boundaries)(_sparse(feature_name))

    def _concatenated(config, id_tensors):
        return ConcatenateWithOffset(config.id_offsets)(id_tensors)

    def _embedded(config, ids):
        layer = SparseEmbedding(
            input_dim=config.input_dim, output_dim=config.output_dim
        )
        return layer(ids)

    # Hashed features.
    education_ids = _hashed(education_hash, "education")
    occupation_ids = _hashed(occupation_hash, "occupation")
    native_country_ids = _hashed(native_country_hash, "native_country")

    # Vocabulary lookups.
    workclass_ids = _looked_up(workclass_lookup, "workclass")
    marital_status_ids = _looked_up(marital_status_lookup, "marital_status")
    relationship_ids = _looked_up(relationship_lookup, "relationship")
    race_ids = _looked_up(race_lookup, "race")
    sex_ids = _looked_up(sex_lookup, "sex")

    # Bucketized features.
    age_ids = _bucketized(age_bucketize, "age")
    capital_gain_ids = _bucketized(capital_gain_bucketize, "capital_gain")
    capital_loss_ids = _bucketized(capital_loss_bucketize, "capital_loss")
    hours_per_week_ids = _bucketized(
        hours_per_week_bucketize, "hours_per_week"
    )

    # Merge the per-feature ids into three groups.
    group1_ids = _concatenated(
        group1,
        [
            workclass_ids,
            hours_per_week_ids,
            capital_gain_ids,
            capital_loss_ids,
        ],
    )
    group2_ids = _concatenated(
        group2,
        [
            education_ids,
            marital_status_ids,
            relationship_ids,
            occupation_ids,
        ],
    )
    group3_ids = _concatenated(
        group3,
        [
            age_ids,
            sex_ids,
            race_ids,
            native_country_ids,
        ],
    )

    # Only groups 1 and 2 get a wide embedding; all three get a deep one.
    wide_embeddings_out = [
        _embedded(group1_embedding_wide, group1_ids),
        _embedded(group2_embedding_wide, group2_ids),
    ]
    deep_embeddings_out = [
        _embedded(group1_embedding_deep, group1_ids),
        _embedded(group2_embedding_deep, group2_ids),
        _embedded(group3_embedding_deep, group3_ids),
    ]

    return wide_embeddings_out, deep_embeddings_out
示例#4
0
 def test_lookup_with_list(self):
     """Build an ``IndexLookup`` from an in-memory list and check it.

     The layer is handed to ``self._check_lookup``, and ``vocab_size()``
     is then expected to be 4 for the three-token vocabulary.
     """
     vocabulary = ["A", "B", "C"]
     layer = IndexLookup(vocabulary=vocabulary)
     self._check_lookup(layer)

     reported_size = layer.vocab_size()
     self.assertEqual(reported_size, 4)
 def test_lookup_with_list(self):
     """Build an ``IndexLookup`` from an in-memory list and check it.

     The layer is created with the vocabulary ``["A", "B", "C"]`` and
     passed to ``self._check_lookup``.
     """
     # NOTE(review): a method with this same name is defined just above in
     # this file; if both live in one class, this definition replaces the
     # earlier one - confirm that is intended.
     lookup_layer = IndexLookup(vocabulary=["A", "B", "C"])
     self._check_lookup(lookup_layer)