Пример #1
0
class MLP(tf.keras.Model):
    """Multi-layer perceptron classifier made of repeated hidden blocks.

    The input is flattened, passed through ``num_blocks`` blocks of
    ``Dense(relu) -> BatchNormalization -> Dropout``, and finally through a
    softmax ``Dense`` classifier with ``num_labels`` outputs.

    Hidden layers live in attributes named ``dense_<i>`` and
    ``batch_norm_<i>`` (1-based), which are also their layer names; these
    attribute names are kept stable because Keras tracks sub-layers by
    attribute when saving/loading weights.
    """

    def __init__(self,
                 num_labels: int,
                 num_dense_units: int = 256,
                 dropout_rate: float = 0.3,
                 num_blocks: int = 3,
                 name: str = "mlp",
                 **kwargs):
        """Create all sub-layers of the model.

        Args:
            num_labels: Number of output units of the softmax classifier.
            num_dense_units: Units of every hidden ``Dense`` layer.
            dropout_rate: Rate of the single ``Dropout`` layer that is
                re-applied after every hidden block.
            num_blocks: Number of hidden blocks; must be at least 1.
            name: Name of the model.
            **kwargs: Forwarded unchanged to ``tf.keras.Model.__init__``.

        Raises:
            ValueError: If ``num_blocks`` is less than 1.
        """
        super().__init__(name=name, **kwargs)

        if num_blocks < 1:
            # ValueError subclasses Exception, so existing handlers still match.
            raise ValueError("Cannot have less than 1 block!")
        self.num_blocks = num_blocks

        self.flatten = Flatten(name="flatten")

        for num in range(1, num_blocks + 1):
            dense_name = "dense_{}".format(num)
            norm_name = "batch_norm_{}".format(num)
            # setattr keeps the numbered attribute names (dense_1, ...), which
            # Keras tracks as sub-layers.
            setattr(self, dense_name,
                    Dense(num_dense_units, activation="relu", name=dense_name))
            setattr(self, norm_name, BatchNormalization(name=norm_name))

        # Dropout has no weights, so one instance is shared by all blocks.
        self.dropout = Dropout(dropout_rate, name="dropout")
        self.classifier = Dense(num_labels,
                                activation="softmax",
                                name="classifier")

    def _block_layers(self, num):
        """Return the ``(dense, batch_norm)`` layers of 1-based block ``num``."""
        return (getattr(self, "dense_{}".format(num)),
                getattr(self, "batch_norm_{}".format(num)))

    def build(self, inputs_shape):
        """Build every sub-layer from the shape of the model input.

        Args:
            inputs_shape: Shape of the model input, batch dimension first.
        """
        self.flatten.build(inputs_shape)
        # Let Flatten compute the flattened width as a plain int (or None if
        # unknown) instead of feeding a tf.Tensor in as a layer dimension.
        flat_shape = self.flatten.compute_output_shape(inputs_shape)
        self.dense_1.build((None, flat_shape[-1]))

        # Every remaining layer consumes the output of a hidden Dense layer,
        # whose width is num_dense_units for all blocks.
        hidden_shape = (None, self.dense_1.units)
        for layer in self.layers:
            if layer is self.flatten or layer is self.dense_1:
                continue
            layer.build(hidden_shape)
        super().build(inputs_shape)

    def call(self, inputs, training=False):
        """Run the forward pass.

        Args:
            inputs: Input batch; all axes after the first are flattened.
            training: Whether batch normalization and dropout run in
                training mode.

        Returns:
            Output of the softmax classifier layer.
        """
        x = self.flatten(inputs)
        for num in range(1, self.num_blocks + 1):
            dense, batch_norm = self._block_layers(num)
            x = dense(x)
            # Pass `training` explicitly so BatchNormalization is correct even
            # when `call` is invoked directly rather than through __call__.
            x = batch_norm(x, training=training)
            x = self.dropout(x, training=training)
        return self.classifier(x)
Пример #2
0
                layer_config = Conv2D.get_config(layer_lst[ii])
                layer_temp = Conv2D.from_config(layer_config)
            elif layer_type == "MaxPooling2D":
                layer_config = MaxPooling2D.get_config(layer_lst[ii])
                layer_temp = MaxPooling2D.from_config(layer_config)
            elif layer_type == "Dense":
                layer_config = Dense.get_config(layer_lst[ii])
                layer_temp = Dense.from_config(layer_config)
            elif layer_type == "Flatten":
                layer_temp = Flatten()
            else:
                layer_config = tf.keras.layers.Layer.get_config(layer_lst[ii])
                layer_temp = tf.keras.layers.Layer.from_config(layer_config)

            layer_weight = layer_lst[ii].get_weights()
            layer_temp.build(layer_lst[ii].input_shape)
            if layer_weight:
                layer_temp.set_weights(layer_weight)

            model_new.add(layer_temp)

    model_new.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
    model_new.summary()

    tf.keras.models.save_model(model_new, 'mnist_swap_model3.h5')

# Evaluate `model_new` on the test images and labels; the return value is
# discarded.
# NOTE(review): `model_new` is assigned in the indented block above (where it
# is compiled and saved). If that block is a function body rather than a
# top-level `if`/`with`, this line raises NameError — confirm. `test_images`
# and `test_labels` are likewise assumed to be bound at module scope.
model_new.evaluate(test_images, test_labels)