def test_hash_compute_output_signature(self):
  """Output signature of a 2-bin Hashing layer keeps the shape, yields int64.

  Builds a string `TensorSpec` of shape [2, 3], asks the layer for its output
  signature, and checks that the dims match the input and the dtype is int64.
  """
  expected_shape = tensor_shape.TensorShape([2, 3])
  string_spec = tensor_spec.TensorSpec(expected_shape, dtypes.string)
  hashing_layer = categorical.Hashing(num_bins=2)

  signature = hashing_layer.compute_output_signature(string_spec)

  self.assertEqual(signature.shape.dims, expected_shape.dims)
  self.assertEqual(signature.dtype, dtypes.int64)
def test_hash_sparse_input(self):
  """Hashing a sparse string tensor into 2 bins uses both bin 0 and bin 1."""
  hashing_layer = categorical.Hashing(num_bins=2)
  sparse_strings = sparse_tensor.SparseTensor(
      indices=[[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]],
      values=['omar', 'stringer', 'marlo', 'wire', 'skywalker'],
      dense_shape=[3, 2])

  hashed = hashing_layer(sparse_strings)

  # Only the stored values are inspected; the largest bin index seen must be
  # 1 and the smallest must be 0.
  self.assertEqual(hashed.values.numpy().max(), 1)
  self.assertEqual(hashed.values.numpy().min(), 0)
def test_hash_ragged_string_input(self):
  """Hashes ragged string input directly and through a functional model.

  Checks that the bin indices span 0..1, that repeated strings map to the same
  bin, and that a model wrapping the same layer predicts the same output as
  calling the layer directly.
  """
  hashing_layer = categorical.Hashing(num_bins=2)
  ragged_strings = ragged_factory_ops.constant(
      [['omar', 'stringer', 'marlo', 'wire'],
       ['marlo', 'skywalker', 'wire']],
      dtype=dtypes.string)

  hashed = hashing_layer(ragged_strings)
  self.assertEqual(hashed.values.numpy().max(), 1)
  self.assertEqual(hashed.values.numpy().min(), 0)

  # 'marlo' occurs in both rows and must hash to the same bin each time.
  self.assertAllClose(hashed[0][2], hashed[1][0])
  # The same holds for 'wire'.
  self.assertAllClose(hashed[0][3], hashed[1][2])

  # Wire the same layer into a model fed by a ragged string input and compare
  # its prediction with the direct call above.
  symbolic_in = input_layer.Input(
      shape=(None,), ragged=True, dtype=dtypes.string)
  symbolic_out = hashing_layer(symbolic_in)
  model = training.Model(inputs=symbolic_in, outputs=symbolic_out)
  self.assertAllClose(hashed, model.predict(ragged_strings))
def test_config_with_custom_name(self):
  """A custom layer name survives a get_config / from_config round trip."""
  original = categorical.Hashing(num_bins=2, name='hashing')
  restored = categorical.Hashing.from_config(original.get_config())
  self.assertEqual(restored.name, original.name)
def test_hash_two_bins(self):
  """Five distinct single-letter strings hashed into 2 bins hit both bins."""
  hashing_layer = categorical.Hashing(num_bins=2)
  # Column vector of the strings 'A' through 'E', one per row.
  letters = np.asarray([[letter] for letter in ('A', 'B', 'C', 'D', 'E')])

  hashed = hashing_layer(letters)

  self.assertEqual(hashed.numpy().max(), 1)
  self.assertEqual(hashed.numpy().min(), 0)
def test_hash_single_bin(self):
  """With a single bin, every input string is assigned to bin 0."""
  hashing_layer = categorical.Hashing(num_bins=1)
  letters = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
  # One zero per input row: the only available bin index.
  expected = np.asarray([[0] for _ in range(5)])

  hashed = hashing_layer(letters)

  self.assertAllClose(expected, hashed)