class MLP(tf.keras.Model):
    """Multi-layer perceptron classifier made of repeated hidden blocks.

    The input is flattened, passed through ``num_blocks`` hidden blocks of
    ``Dense(relu) -> BatchNormalization -> Dropout`` and finally through a
    ``Dense`` softmax classifier.

    Args:
        num_labels: Number of output units of the softmax classifier.
        num_dense_units: Units of every hidden ``Dense`` layer.
        dropout_rate: Rate of the single ``Dropout`` layer that is reused
            after every hidden block.
        num_blocks: Number of hidden blocks; must be at least 1.
        name: Name of the model.
        **kwargs: Forwarded unchanged to ``tf.keras.Model``.

    Raises:
        ValueError: If ``num_blocks`` is less than 1.
    """

    def __init__(self, num_labels: int, num_dense_units: int = 256,
                 dropout_rate: float = 0.3, num_blocks: int = 3,
                 name: str = "mlp", **kwargs):
        super(MLP, self).__init__(name=name, **kwargs)
        if num_blocks < 1:
            # ValueError is a subclass of Exception, so callers that catch
            # Exception keep working.
            raise ValueError("Cannot have less than 1 block!")
        self.num_blocks = num_blocks
        self.flatten = Flatten(name="flatten")
        # Hidden layers are stored as attributes named dense_<n> and
        # batch_norm_<n> (1-based). Both `build` and `call` rely on these
        # names, so keep them stable.
        for num in range(1, num_blocks + 1):
            setattr(
                self, "dense_{}".format(num),
                Dense(num_dense_units, activation="relu",
                      name="dense_{}".format(num)))
            setattr(
                self, "batch_norm_{}".format(num),
                BatchNormalization(name="batch_norm_{}".format(num)))
        # A single dropout layer, applied after every hidden block.
        self.dropout = Dropout(dropout_rate, name="dropout")
        self.classifier = Dense(num_labels, activation="softmax",
                                name="classifier")

    def build(self, inputs_shape):
        """Build every sub-layer from the model's input shape.

        Args:
            inputs_shape: Input shape including the leading batch dimension.
                All non-batch dimensions must be known.

        Raises:
            ValueError: If a non-batch dimension of ``inputs_shape`` is unknown.
        """
        self.flatten.build(inputs_shape)
        # Number of features produced by Flatten, as a plain Python int
        # (None when any non-batch dimension is unknown).
        flat_dim = tf.TensorShape(inputs_shape)[1:].num_elements()
        if flat_dim is None:
            raise ValueError(
                "MLP requires fully defined non-batch input dimensions, "
                "got {}".format(inputs_shape))
        self.dense_1.build((None, flat_dim))
        # All remaining layers (later dense/batch-norm layers, dropout and
        # the classifier) consume the output of a hidden Dense layer. Every
        # hidden Dense layer has num_dense_units units, so dense_1.units is
        # the correct input width for all of them.
        for layer in self.layers:
            if layer.name not in ("flatten", "dense_1"):
                layer.build((None, self.dense_1.units))
        super(MLP, self).build(inputs_shape)

    def call(self, inputs, training=False):
        """Run the forward pass.

        Args:
            inputs: Input batch. It is flattened before the first Dense layer.
            training: Whether batch normalization and dropout run in
                training mode.

        Returns:
            The softmax output of the classifier layer.
        """
        x = self.flatten(inputs)
        for num in range(1, self.num_blocks + 1):
            # Direct attribute access avoids get_layer's linear name scan.
            dense = getattr(self, "dense_{}".format(num))
            batch_norm = getattr(self, "batch_norm_{}".format(num))
            x = dense(x)
            # Pass `training` explicitly, as for dropout, so batch-norm mode
            # does not depend on Keras call-context propagation.
            x = batch_norm(x, training=training)
            x = self.dropout(x, training=training)
        return self.classifier(x)
layer_config = Conv2D.get_config(layer_lst[ii]) layer_temp = Conv2D.from_config(layer_config) elif layer_type == "MaxPooling2D": layer_config = MaxPooling2D.get_config(layer_lst[ii]) layer_temp = MaxPooling2D.from_config(layer_config) elif layer_type == "Dense": layer_config = Dense.get_config(layer_lst[ii]) layer_temp = Dense.from_config(layer_config) elif layer_type == "Flatten": layer_temp = Flatten() else: layer_config = tf.keras.layers.Layer.get_config(layer_lst[ii]) layer_temp = tf.keras.layers.Layer.from_config(layer_config) layer_weight = layer_lst[ii].get_weights() layer_temp.build(layer_lst[ii].input_shape) if layer_weight: layer_temp.set_weights(layer_weight) model_new.add(layer_temp) model_new.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) model_new.summary() tf.keras.models.save_model(model_new, 'mnist_swap_model3.h5') # Evaluating the new model model_new.evaluate(test_images, test_labels)