def _clone_function(layer):
    """Clone callback that swaps embedding layers for ElasticDL ones.

    Intended for use with ``tf.keras.models.clone_model``:

    * A ``tf.keras.layers.Embedding`` or ``SparseEmbedding`` that needs
      partitioning is rebuilt as an ``elasticdl.layers.Embedding``.
    * A ``DenseFeatures`` layer has its embedding columns replaced via
      ``_replace_tf_embedding_column_with_edl``.
    * Every other layer is returned unchanged.
    """
    layer_class = type(layer)
    # Exact-class membership test (not isinstance) so subclasses are
    # left alone, matching the rest of this module's replacement logic.
    is_native_embedding = layer_class in [
        tf.keras.layers.Embedding,
        SparseEmbedding,
    ]
    if is_native_embedding and _need_partition_embedding(layer):
        logger.debug("Replace {} with {}".format(layer.name, Embedding))
        # ElasticDL embedding only accept a string type initializer
        init = tf.keras.initializers.serialize(
            layer.embeddings_initializer
        )["class_name"]
        shared_kwargs = dict(
            output_dim=layer.output_dim,
            input_dim=layer.input_dim,
            embeddings_initializer=init,
            name=layer.name,
        )
        if layer_class == tf.keras.layers.Embedding:
            edl_layer = Embedding(
                mask_zero=layer.mask_zero,
                input_length=layer.input_length,
                **shared_kwargs,
            )
        else:
            edl_layer = Embedding(
                combiner=layer.combiner,
                **shared_kwargs,
            )
        edl_layer.set_embedding_weight_name(
            layer.trainable_weights[0].name
        )
        return edl_layer
    if layer_class == tf.keras.layers.DenseFeatures:
        return _replace_tf_embedding_column_with_edl(layer)
    return layer
def _replace_attr_with_edl_embedding(model):
    """Replace the keras embedding attributes in the model with
    `elasticdl.layers.Embedding` layers.

    Walks the attributes of a subclassed keras model and, for each
    ``tf.keras.layers.Embedding`` or ``SparseEmbedding`` that needs
    partitioning, swaps it for an ``elasticdl.layers.Embedding``.
    ``DenseFeatures`` attributes have their embedding columns replaced
    via ``_replace_tf_embedding_column_with_edl``.

    Args:
        model: a subclassed ``tf.keras.Model`` instance whose attributes
            may hold embedding layers.

    Returns:
        The same model instance, mutated in place.
    """
    # setattr only overwrites existing attribute names here, so the size
    # of model.__dict__ does not change during this iteration.
    for name, value in model.__dict__.items():
        if type(
            value
        ) == tf.keras.layers.Embedding and _need_partition_embedding(
            value
        ):
            logger.info(
                "Replace {} layer with "
                "elasticdl.layers.Embedding".format(value)
            )
            # ElasticDL Embedding only accepts a string initializer name.
            initializer_name = tf.keras.initializers.serialize(
                value.embeddings_initializer
            )["class_name"]
            embedding_layer = Embedding(
                output_dim=value.output_dim,
                input_dim=value.input_dim,
                embeddings_initializer=initializer_name,
                mask_zero=value.mask_zero,
                input_length=value.input_length,
                name=value.name,
            )
            # The weights of subclass model is None, so we need to create
            # the weight name which is "{layer_name}/embeddings:0" in
            # tf.keras.layers.Embedding.
            embedding_layer.set_embedding_weight_name(
                value.name + "/embeddings:0"
            )
            setattr(model, name, embedding_layer)
        elif type(value) == SparseEmbedding and _need_partition_embedding(
            value
        ):
            logger.info(
                "Replace {} layer with "
                "elasticdl.layers.Embedding".format(value)
            )
            # Bug fix: previously this branch reused `initializer_name`
            # from the Embedding branch above, which raised NameError when
            # a SparseEmbedding was seen first (or used the wrong layer's
            # initializer otherwise). Serialize this layer's own
            # initializer instead.
            initializer_name = tf.keras.initializers.serialize(
                value.embeddings_initializer
            )["class_name"]
            embedding_layer = Embedding(
                output_dim=value.output_dim,
                input_dim=value.input_dim,
                embeddings_initializer=initializer_name,
                combiner=value.combiner,
                name=value.name,
            )
            embedding_layer.set_embedding_weight_name(
                value.name + "/embeddings:0"
            )
            setattr(model, name, embedding_layer)
        elif type(value) == tf.keras.layers.DenseFeatures:
            feature_layer = _replace_tf_embedding_column_with_edl(value)
            setattr(model, name, feature_layer)
    return model