Esempio n. 1
0
 def test_cross_output_dtype(self):
   """Checks the default and overridden output dtype for each output mode.

   With the default output mode the result is int64, and with
   `output_mode='one_hot'` it is float32; passing `dtype=` overrides either.
   """
   cases = [
       ({}, tf.int64),
       ({'dtype': tf.int32}, tf.int32),
       ({'output_mode': 'one_hot'}, tf.float32),
       ({'output_mode': 'one_hot', 'dtype': tf.float64}, tf.float64),
   ]
   for extra_kwargs, expected_dtype in cases:
     # subTest reports which configuration failed instead of stopping at the
     # first mismatch with no context.
     with self.subTest(**{k: str(v) for k, v in extra_kwargs.items()}):
       layer = hashed_crossing.HashedCrossing(num_bins=2, **extra_kwargs)
       # A dtype is a single scalar object, so compare it with assertEqual
       # rather than the array-oriented assertAllEqual.
       self.assertEqual(layer(([1], [1])).dtype, expected_dtype)
Esempio n. 2
0
 def test_cross_batch_of_scalars_2d(self, data_fn):
     """Crosses a 5x1 string column with a 5x1 int column into 10 bins.

     `data_fn` converts the nested Python lists into the input container under
     test (its exact type is supplied by the test parameterization).
     """
     strings = data_fn([['A'], ['B'], ['A'], ['B'], ['A']])
     ints = data_fn([[101], [101], [101], [102], [102]])
     crossing = hashed_crossing.HashedCrossing(num_bins=10)
     crossed = crossing((strings, ints))
     # One bin index per row, keeping the trailing unit dimension.
     self.assertAllClose(crossed, [[1], [4], [1], [6], [3]])
     self.assertAllEqual(crossed.shape.as_list(), [5, 1])
Esempio n. 3
0
 def test_cross_batch_of_scalars_1d(self, data_fn):
     """Crosses a flat batch of 5 strings with a flat batch of 5 ints.

     `data_fn` converts the Python lists into the input container under test
     (its exact type is supplied by the test parameterization).
     """
     strings = data_fn(['A', 'B', 'A', 'B', 'A'])
     ints = data_fn([101, 101, 101, 102, 102])
     crossing = hashed_crossing.HashedCrossing(num_bins=10)
     crossed = crossing((strings, ints))
     # A rank-1 batch in gives a rank-1 batch of bin indices out.
     self.assertAllClose(crossed, [1, 4, 1, 6, 3])
     self.assertAllEqual(crossed.shape.as_list(), [5])
Esempio n. 4
0
 def test_cross_scalars(self, data_fn):
     """Crossing two scalar features yields a single scalar bin index."""
     crossed = hashed_crossing.HashedCrossing(num_bins=10)(
         (data_fn('A'), data_fn(101)))
     self.assertAllClose(crossed, 1)
     # Scalar in, scalar out: the result has an empty shape.
     self.assertAllEqual(crossed.shape.as_list(), [])
 def test_float_input_fails(self):
     """Float features are rejected; only integer or string inputs cross."""
     make_layer = hashed_crossing.HashedCrossing
     expected_message = "should have an integer or string"
     with self.assertRaisesRegex(ValueError, expected_message):
         float_features = (tf.constant([1.0]), tf.constant([1.0]))
         make_layer(num_bins=10)(float_features)
Esempio n. 6
0
 def test_from_config(self):
     """A layer rebuilt via from_config(get_config()) matches the original."""
     original = hashed_crossing.HashedCrossing(
         num_bins=5, output_mode='one_hot', sparse=True)
     config = original.get_config()
     clone = hashed_crossing.HashedCrossing.from_config(config)

     inputs = (
         tf.constant([['A'], ['B'], ['A'], ['B'], ['A']]),
         tf.constant([[101], [101], [101], [102], [102]]),
     )
     original_result = original(inputs)
     clone_result = clone(inputs)

     # sparse=True makes both layers return sparse tensors; densify them so
     # the values can be compared element-wise.
     self.assertAllEqual(
         tf.sparse.to_dense(clone_result),
         tf.sparse.to_dense(original_result),
     )
Esempio n. 7
0
 def test_cross_one_hot_output(self, sparse):
   """One-hot mode yields a 5x5 indicator matrix, dense or sparse."""
   expected = [
       [0, 1, 0, 0, 0],
       [0, 0, 0, 0, 1],
       [0, 1, 0, 0, 0],
       [0, 1, 0, 0, 0],
       [0, 0, 0, 1, 0],
   ]
   layer = hashed_crossing.HashedCrossing(
       num_bins=5, output_mode='one_hot', sparse=sparse)
   strings = tf.constant([['A'], ['B'], ['A'], ['B'], ['A']])
   ints = tf.constant([[101], [101], [101], [102], [102]])
   result = layer((strings, ints))
   # With sparse=True the layer's output is densified before comparison.
   dense = tf.sparse.to_dense(result) if sparse else result
   self.assertAllClose(dense, expected)
   self.assertAllEqual(dense.shape.as_list(), [5, 5])
def embedding_varlen(batch_size):
    """Benchmark `HashedCrossing` against the equivalent crossed feature column.

    Times a Keras model wrapping `HashedCrossing` and a `tf.function` wrapping
    `tf.feature_column.crossed_column`. Both cross an int64 feature with its
    string representation into `num_buckets` bins.

    Args:
      batch_size: Batch size forwarded to the benchmark runners.

    Returns:
      A `(k_avg_time, fc_avg_time)` tuple holding the values returned by
      `fc_bm.run_keras` and `fc_bm.run_fc` (average timings, per the names).
    """
    # Data and constants.
    num_buckets = 10000
    data_a = tf.random.uniform(shape=(batch_size * NUM_REPEATS, 1),
                               maxval=32768,
                               dtype=tf.int64)
    # data_b is the string form of data_a, so both features share one shape.
    data_b = tf.strings.as_string(data_a)

    # Keras implementation
    input_1 = keras.Input(shape=(1, ), name="data_a", dtype=tf.int64)
    input_2 = keras.Input(shape=(1, ), name="data_b", dtype=tf.string)
    outputs = hashed_crossing.HashedCrossing(num_buckets)([input_1, input_2])
    model = keras.Model([input_1, input_2], outputs)

    # FC implementation
    fc = tf.feature_column.crossed_column(["data_a", "data_b"], num_buckets)

    # Wrap the FC implementation in a tf.function for a fair comparison.
    @tf_function()
    def fc_fn(tensors):
        # Return the crossed result. Inside a graph function, stateless ops
        # whose outputs are never used may be pruned, which would leave the
        # FC side of the benchmark timing (almost) no work.
        return fc.transform_feature(
            tf.__internal__.feature_column.FeatureTransformationCache(tensors),
            None,
        )

    # Both implementations consume the same features. Each runner gets its
    # own shallow copy of the dict, as with the two separate dicts before.
    features = {
        "data_a": data_a,
        "data_b": data_b,
    }

    # Benchmark runs
    k_avg_time = fc_bm.run_keras(
        dict(features), model, batch_size, NUM_REPEATS)
    fc_avg_time = fc_bm.run_fc(
        dict(features), fc_fn, batch_size, NUM_REPEATS)

    return k_avg_time, fc_avg_time
Esempio n. 9
0
    def test_saved_model_keras(self):
        """A model containing HashedCrossing survives a SavedModel round trip."""
        expected = [[1], [4], [1], [6], [3]]

        inputs = (keras.Input(shape=(1, ), dtype=tf.string),
                  keras.Input(shape=(1, ), dtype=tf.int64))
        crossed = hashed_crossing.HashedCrossing(num_bins=10)(inputs)
        model = keras.Model(inputs=inputs, outputs=crossed)

        strings = tf.constant([['A'], ['B'], ['A'], ['B'], ['A']])
        ints = tf.constant([[101], [101], [101], [102], [102]])
        self.assertAllClose(model((strings, ints)), expected)

        # Round-trip through the TF SavedModel format; the custom layer class
        # is handed to the loader through custom_objects.
        save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
        model.save(save_dir, save_format='tf')
        restored = keras.models.load_model(
            save_dir,
            custom_objects={'HashedCrossing': hashed_crossing.HashedCrossing})

        # The restored model must produce the same bin indices.
        self.assertAllClose(restored((strings, ints)), expected)
Esempio n. 10
0
 def test_upsupported_shape_input_fails(self):
     """Rank-3 inputs are rejected with a shape error."""
     with self.assertRaisesRegex(ValueError, 'inputs should have shape'):
         rank3_pair = (tf.constant([[[1.]]]), tf.constant([[[1.]]]))
         layer = hashed_crossing.HashedCrossing(num_bins=10)
         layer(rank3_pair)
Esempio n. 11
0
 def test_sparse_input_fails(self):
     """SparseTensor inputs are rejected; the layer requires dense tensors."""
     # Build the fixture outside the assertion context so that only the layer
     # call is the region expected to raise.
     sparse_in = tf.sparse.from_dense(tf.constant([1]))
     layer = hashed_crossing.HashedCrossing(num_bins=10)
     with self.assertRaisesRegex(ValueError,
                                 'inputs should be dense tensors'):
         layer((sparse_in, sparse_in))
Esempio n. 12
0
 def test_single_input_fails(self):
     """A one-element input list is rejected; crossing needs two features."""
     expected_message = 'at least two inputs'
     with self.assertRaisesRegex(ValueError, expected_message):
         layer = hashed_crossing.HashedCrossing(num_bins=10)
         layer([tf.constant(1)])
Esempio n. 13
0
 def test_non_list_input_fails(self):
     """A bare tensor, rather than a collection of features, is rejected."""
     with self.assertRaisesRegex(ValueError, 'should be called on a list'):
         bare_tensor = tf.constant(1)
         hashed_crossing.HashedCrossing(num_bins=10)(bare_tensor)