def test_model_with_lookup(self):
    """End-to-end check: an IndexLookup layer inside a functional Keras model.

    Known tokens "A"/"B"/"C" map to ids 0/1/2; unseen tokens ("D", "E")
    both map to the OOV id 3 (per the expected array asserted below).
    """
    inputs = tf.keras.Input(shape=(1, ), dtype=tf.string)
    lookup_out = IndexLookup(vocabulary=["A", "B", "C"])(inputs)
    model = tf.keras.Model(inputs=inputs, outputs=lookup_out)
    # Invoke through model(...) rather than model.call(...): calling
    # `call` directly skips Keras's `__call__` wrapper (input-spec
    # validation, build logic), which tests should exercise as real
    # callers do.
    out = model(tf.constant([["A"], ["C"], ["B"], ["D"], ["E"]]))
    self.assertTrue(
        np.array_equal(
            np.array([[0], [2], [1], [3], [3]], dtype=int), out.numpy()
        )
    )
def test_lookup_with_file(self):
    """IndexLookup should build its vocabulary from a file path."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        vocab_path = os.path.join(tmp_dir, "vocab_test.txt")
        # One token per line, matching the layer's expected file format.
        with open(vocab_path, "w") as vocab_fp:
            vocab_fp.writelines(["A\n", "B\n", "C\n"])
        layer = IndexLookup(vocabulary=vocab_path)
        self._check_lookup(layer)
def transform_from_code_gen(source_inputs):
    """Feature-transform pipeline (code-generated; edit with care).

    Hashes/looks-up/bucketizes raw census-style feature columns, groups the
    resulting ids with per-group offsets, and embeds each group for the
    "wide" and "deep" branches of a wide-and-deep model.

    Args:
        source_inputs: mapping from feature name to input tensor; only read
            via a shallow ``.copy()``, so the caller's dict is not mutated.

    Returns:
        A ``(wide_embeddings_out, deep_embeddings_out)`` tuple of lists of
        embedding outputs.

    NOTE(review): relies on module-level config objects (``education_hash``,
    ``group1``, ``group1_embedding_wide``, ...) defined outside this chunk —
    their attributes (``hash_bucket_size``, ``vocabulary_list``,
    ``boundaries``, ``id_offsets``, ``input_dim``/``output_dim``) are assumed
    to exist; confirm against the generator.
    """
    # Shallow copy so lookups below never mutate the caller's mapping.
    inputs = source_inputs.copy()

    # --- Categorical features with unbounded cardinality: hash to buckets.
    education_hash_out = Hashing(education_hash.hash_bucket_size)(
        ToSparse()(inputs["education"])
    )
    occupation_hash_out = Hashing(occupation_hash.hash_bucket_size)(
        ToSparse()(inputs["occupation"])
    )
    native_country_hash_out = Hashing(native_country_hash.hash_bucket_size)(
        ToSparse()(inputs["native_country"])
    )

    # --- Categorical features with known vocabularies: integer lookup.
    workclass_lookup_out = IndexLookup(workclass_lookup.vocabulary_list)(
        ToSparse()(inputs["workclass"])
    )
    marital_status_lookup_out = IndexLookup(
        marital_status_lookup.vocabulary_list
    )(ToSparse()(inputs["marital_status"]))
    relationship_lookup_out = IndexLookup(relationship_lookup.vocabulary_list)(
        ToSparse()(inputs["relationship"])
    )
    race_lookup_out = IndexLookup(race_lookup.vocabulary_list)(
        ToSparse()(inputs["race"])
    )
    sex_lookup_out = IndexLookup(sex_lookup.vocabulary_list)(
        ToSparse()(inputs["sex"])
    )

    # --- Numeric features: bucketize against precomputed boundaries.
    age_bucketize_out = Discretization(age_bucketize.boundaries)(
        ToSparse()(inputs["age"])
    )
    capital_gain_bucketize_out = Discretization(
        capital_gain_bucketize.boundaries
    )(ToSparse()(inputs["capital_gain"]))
    capital_loss_bucketize_out = Discretization(
        capital_loss_bucketize.boundaries
    )(ToSparse()(inputs["capital_loss"]))
    hours_per_week_bucketize_out = Discretization(
        hours_per_week_bucketize.boundaries
    )(ToSparse()(inputs["hours_per_week"]))

    # --- Concatenate transformed ids into groups; id_offsets shift each
    # member into a disjoint id range within the shared embedding table.
    group1_out = ConcatenateWithOffset(group1.id_offsets)(
        [
            workclass_lookup_out,
            hours_per_week_bucketize_out,
            capital_gain_bucketize_out,
            capital_loss_bucketize_out,
        ]
    )
    group2_out = ConcatenateWithOffset(group2.id_offsets)(
        [
            education_hash_out,
            marital_status_lookup_out,
            relationship_lookup_out,
            occupation_hash_out,
        ]
    )
    group3_out = ConcatenateWithOffset(group3.id_offsets)(
        [
            age_bucketize_out,
            sex_lookup_out,
            race_lookup_out,
            native_country_hash_out,
        ]
    )

    # --- Per-group sparse embeddings for the wide branch (groups 1-2 only).
    group1_embedding_wide_out = SparseEmbedding(
        input_dim=group1_embedding_wide.input_dim,
        output_dim=group1_embedding_wide.output_dim,
    )(group1_out)
    group2_embedding_wide_out = SparseEmbedding(
        input_dim=group2_embedding_wide.input_dim,
        output_dim=group2_embedding_wide.output_dim,
    )(group2_out)

    # --- Per-group sparse embeddings for the deep branch (all three groups).
    group1_embedding_deep_out = SparseEmbedding(
        input_dim=group1_embedding_deep.input_dim,
        output_dim=group1_embedding_deep.output_dim,
    )(group1_out)
    group2_embedding_deep_out = SparseEmbedding(
        input_dim=group2_embedding_deep.input_dim,
        output_dim=group2_embedding_deep.output_dim,
    )(group2_out)
    group3_embedding_deep_out = SparseEmbedding(
        input_dim=group3_embedding_deep.input_dim,
        output_dim=group3_embedding_deep.output_dim,
    )(group3_out)

    wide_embeddings_out = [
        group1_embedding_wide_out,
        group2_embedding_wide_out,
    ]
    deep_embeddings_out = [
        group1_embedding_deep_out,
        group2_embedding_deep_out,
        group3_embedding_deep_out,
    ]
    return wide_embeddings_out, deep_embeddings_out
def test_lookup_with_list(self):
    """IndexLookup built from an in-memory vocabulary list.

    ``vocab_size()`` is expected to be 4: the three supplied tokens
    plus one out-of-vocabulary slot.
    """
    layer = IndexLookup(vocabulary=["A", "B", "C"])
    self._check_lookup(layer)
    self.assertEqual(4, layer.vocab_size())
def test_lookup_with_list_basic(self):
    """IndexLookup built from an in-memory vocabulary list (basic checks).

    NOTE(review): this method was originally also named
    ``test_lookup_with_list``, duplicating another test in this file; in a
    single test class the later definition silently shadows the earlier
    one, so only one of the two would ever be discovered and run. Renamed
    so both tests execute — confirm the two defs really share one class.
    """
    lookup_layer = IndexLookup(vocabulary=["A", "B", "C"])
    self._check_lookup(lookup_layer)