def get_config(self):
    """Returns the serializable config of this features layer.

    The layer-specific entries are laid over the config returned by the next
    class after `_BaseFeaturesLayer` in the MRO, so they win on any key
    collision.

    Returns:
      A new dict holding that parent config plus two extra keys:
      'feature_columns' (the serialized `self._feature_columns`) and
      'partitioner' (the serialized `self._partitioner`).
    """
    # Imported lazily rather than at module scope to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top

    own_entries = {
        'feature_columns':
            serialization.serialize_feature_columns(self._feature_columns),
        'partitioner':
            generic_utils.serialize_keras_object(self._partitioner),
    }
    merged = dict(
        super(  # pylint: disable=bad-super-call
            _BaseFeaturesLayer, self).get_config())
    merged.update(own_entries)
    return merged
def get_config(self):
    """Returns the serializable config of this model.

    Feature columns are turned into serializable structures with the
    TensorFlow feature-column serialization helpers. The model-specific keys
    override any same-named keys in the base config.

    Returns:
      A new dict combining `super(EgesModel, self).get_config()` with
      'item_embedding_size', 'item_id_column', 'item_feature_columns' and
      'target_id_column'.
    """
    from tensorflow.python.feature_column import serialization as fc_serialization

    # NOTE(review): each column (list) is serialized by a separate call; a
    # matching from_config (not visible here) presumably deserializes them
    # separately too, so columns shared between these fields may not be
    # deduplicated on reload -- confirm against from_config.
    model_entries = {}
    model_entries["item_embedding_size"] = self.item_embedding_size
    model_entries["item_id_column"] = fc_serialization.serialize_feature_column(
        self.item_id_column)
    model_entries["item_feature_columns"] = (
        fc_serialization.serialize_feature_columns(self.item_feature_columns))
    model_entries["target_id_column"] = fc_serialization.serialize_feature_column(
        self.target_id_column)

    merged = dict(super(EgesModel, self).get_config())
    merged.update(model_entries)
    return merged
def test_deserialization_deduping(self):
    """Round-trips two related columns and checks the shared source is deduped.

    `bucketized_price` is built on `price`. After serializing both and
    deserializing the configs in a single call, the rebuilt bucketized column
    must reference the very same rebuilt `price` object.
    """
    price = fc.numeric_column('price')
    bucketized_price = fc.bucketized_column(price, boundaries=[0, 1])
    originals = [price, bucketized_price]

    restored = serialization.deserialize_feature_columns(
        serialization.serialize_feature_columns(originals))

    self.assertLen(restored, 2)
    new_price, new_bucketized_price = restored

    # The round trip must yield fresh objects...
    for original, rebuilt in zip(originals, restored):
        self.assertIsNot(original, rebuilt)
    # ...that nevertheless compare equal to the originals.
    for original, rebuilt in zip(originals, restored):
        self.assertEqual(original, rebuilt)

    # Deduping: the rebuilt bucketized column points at the rebuilt price.
    self.assertIs(new_bucketized_price.source_column, new_price)
def deserialization_custom_objects(self):
    """Checks that the public (de)serialization wrappers honour custom_objects.

    `price` uses a locally defined normalizer that the serialized config can
    only refer to by name. The public `deserialize_feature_columns` wrapper is
    given a `custom_objects` mapping so that the name resolves back to the
    exact same function object.
    """
    # Note that custom_objects is also tested extensively above per class, this
    # test ensures that the public wrappers also handle it correctly.
    def _custom_fn(input_tensor):
        return input_tensor + 42.

    price = fc.numeric_column('price', normalizer_fn=_custom_fn)

    configs = serialization.serialize_feature_columns([price])

    # Pass the mapping explicitly. Exercising custom_objects through the public
    # wrapper is the point of this test, and a locally defined function
    # presumably cannot be resolved by name without it.
    deserialized_feature_columns = serialization.deserialize_feature_columns(
        configs, custom_objects={'_custom_fn': _custom_fn})
    self.assertLen(deserialized_feature_columns, 1)
    new_price = deserialized_feature_columns[0]

    # Ensure these are not the original objects:
    self.assertIsNot(price, new_price)
    # But they are equivalent:
    self.assertEqual(price, new_price)

    # Check that normalizer_fn points to the correct function.
    self.assertIs(new_price.normalizer_fn, _custom_fn)


# unittest only collects methods whose names start with "test", so the method
# above was never run under its original name. This alias registers it with
# the test runner while keeping the original name defined.
test_deserialization_custom_objects = deserialization_custom_objects