def _clone_function(layer):
    """Return the layer that should stand in for `layer` in a cloned model.

    Presumably used as the `clone_function` of
    `tf.keras.models.clone_model` -- confirm at the call site.

    Args:
        layer: The layer being visited.

    Returns:
        * a new `SparseEmbedding` when `layer` is exactly an ElasticDL
          `Embedding` whose `combiner` is not None;
        * a new `tf.keras.layers.Embedding` when `layer` is exactly an
          ElasticDL `Embedding` whose `combiner` is None;
        * the result of `_replace_edl_embedding_column_with_tf(layer)`
          when `layer` is exactly a `tf.keras.layers.DenseFeatures`;
        * `layer` itself, unchanged, in every other case.

        Replacement embedding layers keep the original layer's name.
    """
    # Exact type comparison on purpose: subclasses are passed through.
    if type(layer) is Embedding:
        # The combiner is not None only for SparseEmbedding.
        if layer.combiner is not None:
            logger.info(
                "Replace elasticdl.layers.Embedding with SparseEmbedding"
            )
            return SparseEmbedding(
                output_dim=layer.output_dim,
                input_dim=layer.input_dim,
                embeddings_initializer=layer.embeddings_initializer,
                name=layer.name,
                combiner=layer.combiner,
            )

        logger.info(
            "Replace elasticdl.layers.Embedding with "
            "tf.keras.layers.Embedding"
        )
        return tf.keras.layers.Embedding(
            output_dim=layer.output_dim,
            input_dim=layer.input_dim,
            embeddings_initializer=layer.embeddings_initializer,
            mask_zero=layer.mask_zero,
            input_length=layer.input_length,
            name=layer.name,
        )

    if type(layer) is tf.keras.layers.DenseFeatures:
        return _replace_edl_embedding_column_with_tf(layer)

    return layer
def _replace_attr_with_keras_embedding(model):
    """Replace the elasticdl.layers.Embedding attributes in the model
    with `tf.keras.layers.Embedding` or `SparseEmbedding` layers.

    Only instance attributes (entries of `model.__dict__`) whose exact
    type is `Embedding` are replaced; every other attribute is left
    untouched. The replacement layers are created without an explicit
    `name`.

    Args:
        model: The object whose attributes are scanned and overwritten
            in place via `setattr`.

    Returns:
        The same `model` object that was passed in.
    """
    # Walk a snapshot of the attributes so that whatever `setattr` does on
    # the model cannot mutate the dict while it is being iterated.
    for name, value in list(model.__dict__.items()):
        if type(value) is not Embedding:
            continue

        # The combiner is not None only for SparseEmbedding.
        if value.combiner is not None:
            logger.info("Replace elasticdl with SparseEmbedding")
            embedding_layer = SparseEmbedding(
                output_dim=value.output_dim,
                input_dim=value.input_dim,
                embeddings_initializer=value.embeddings_initializer,
                combiner=value.combiner,
            )
        else:
            # A single message string: passing the layer name as a second
            # positional argument would be treated as a %-format argument
            # and trigger a logging formatting error.
            logger.info("Replace elasticdl with tf.keras.layers.Embedding")
            embedding_layer = tf.keras.layers.Embedding(
                output_dim=value.output_dim,
                input_dim=value.input_dim,
                embeddings_initializer=value.embeddings_initializer,
                mask_zero=value.mask_zero,
                input_length=value.input_length,
            )
        setattr(model, name, embedding_layer)
    return model
def custom_model_with_sparse_embedding():
    """Build a small functional Keras model driven by one sparse input.

    The model takes a sparse int64 input named "sparse_feature" with
    shape (4,), feeds it through `SparseEmbedding(4, 2, combiner="sum")`
    named "embedding", and then through a one-unit `Dense` layer.

    Returns:
        A `tf.keras.models.Model` from the sparse input to the Dense
        output.
    """
    keras_layers = tf.keras.layers

    inputs = keras_layers.Input(
        shape=(4,),
        dtype="int64",
        sparse=True,
        name="sparse_feature",
    )

    # Layers are created in the same order as they are applied.
    embedding_layer = SparseEmbedding(4, 2, combiner="sum", name="embedding")
    embedded = embedding_layer(inputs)

    dense_layer = keras_layers.Dense(1)
    outputs = dense_layer(embedded)

    return tf.keras.models.Model(inputs, outputs)