コード例 #1
0
        def _clone_function(layer):
            """Return the layer that should stand in for ``layer``.

            * A layer whose exact type is ``tf.keras.layers.Embedding`` or
              ``SparseEmbedding`` and for which ``_need_partition_embedding``
              is true becomes an ElasticDL ``Embedding`` with the same
              ``output_dim``, ``input_dim``, ``name`` and initializer (passed
              by class name). Keras embeddings additionally keep
              ``mask_zero``/``input_length``; ``SparseEmbedding`` keeps its
              ``combiner``. The new layer's embedding weight name is taken
              from the original layer's first trainable weight.
            * A layer whose exact type is ``tf.keras.layers.DenseFeatures`` is
              converted by ``_replace_tf_embedding_column_with_edl``.
            * Any other layer is returned unchanged.
            """
            layer_type = type(layer)
            # Exact type match on purpose: only these two classes qualify.
            partitionable = layer_type in (
                tf.keras.layers.Embedding,
                SparseEmbedding,
            ) and _need_partition_embedding(layer)
            if not partitionable:
                if layer_type == tf.keras.layers.DenseFeatures:
                    return _replace_tf_embedding_column_with_edl(layer)
                return layer

            logger.debug(
                "Replace {} with {}".format(layer.name, Embedding)
            )
            # ElasticDL's Embedding only accepts its initializer as a string,
            # so hand over the class name of the keras initializer.
            initializer_name = tf.keras.initializers.serialize(
                layer.embeddings_initializer
            )["class_name"]

            shared_kwargs = dict(
                output_dim=layer.output_dim,
                input_dim=layer.input_dim,
                embeddings_initializer=initializer_name,
                name=layer.name,
            )
            if layer_type == tf.keras.layers.Embedding:
                variant_kwargs = dict(
                    mask_zero=layer.mask_zero,
                    input_length=layer.input_length,
                )
            else:
                variant_kwargs = dict(combiner=layer.combiner)

            replacement = Embedding(**shared_kwargs, **variant_kwargs)
            replacement.set_embedding_weight_name(
                layer.trainable_weights[0].name
            )
            return replacement
コード例 #2
0
ファイル: model_handler.py プロジェクト: shanhaijun/elasticdl
 def _replace_attr_with_edl_embedding(model):
     """Replace the keras embedding attributes in the model with
     `elasticdl.layers.Embedding` layers.

     Every instance attribute of ``model`` is inspected and, where it
     matches, replaced in place via ``setattr``:

     * an attribute whose exact type is ``tf.keras.layers.Embedding`` or
       ``SparseEmbedding`` and for which ``_need_partition_embedding``
       returns true becomes an ``Embedding`` built from the same
       ``output_dim``, ``input_dim``, initializer (by class name) and
       ``name``, plus ``mask_zero``/``input_length`` for keras embeddings
       or ``combiner`` for ``SparseEmbedding``;
     * an attribute whose exact type is ``tf.keras.layers.DenseFeatures``
       is replaced by the result of
       ``_replace_tf_embedding_column_with_edl``.

     All other attributes are left untouched.

     Args:
         model: The (subclassed) model whose attributes are replaced.

     Returns:
         The same ``model`` object, modified in place.
     """
     # Iterate over a snapshot: ``setattr`` goes through the model's own
     # ``__setattr__`` (Keras layer tracking), which can modify
     # ``model.__dict__``; mutating a dict while iterating over it may
     # raise RuntimeError or skip/revisit entries.
     for name, value in list(model.__dict__.items()):
         layer_type = type(value)
         if layer_type in (
             tf.keras.layers.Embedding,
             SparseEmbedding,
         ) and _need_partition_embedding(value):
             logger.info(
                 "Replace {} layer with "
                 "elasticdl.layers.Embedding".format(value)
             )
             # ElasticDL embedding only accepts a string type initializer.
             # Derive it from *this* layer so that every replacement,
             # including SparseEmbedding ones, uses its own initializer.
             initializer_name = tf.keras.initializers.serialize(
                 value.embeddings_initializer
             )["class_name"]
             if layer_type == tf.keras.layers.Embedding:
                 embedding_layer = Embedding(
                     output_dim=value.output_dim,
                     input_dim=value.input_dim,
                     embeddings_initializer=initializer_name,
                     mask_zero=value.mask_zero,
                     input_length=value.input_length,
                     name=value.name,
                 )
             else:
                 embedding_layer = Embedding(
                     output_dim=value.output_dim,
                     input_dim=value.input_dim,
                     embeddings_initializer=initializer_name,
                     combiner=value.combiner,
                     name=value.name,
                 )
             # The weights of a subclassed model are None at this point, so
             # build the weight name "{layer_name}/embeddings:0" that
             # tf.keras.layers.Embedding would assign.
             embedding_layer.set_embedding_weight_name(
                 value.name + "/embeddings:0"
             )
             setattr(model, name, embedding_layer)
         elif layer_type == tf.keras.layers.DenseFeatures:
             feature_layer = _replace_tf_embedding_column_with_edl(value)
             setattr(model, name, feature_layer)
     return model