Example #1
0
def main():
    """Train AlexNet on the TensorFlow flower_photos dataset and plot accuracy.

    Downloads and extracts the flower_photos archive (cached by Keras), builds
    an 80/20 training/validation split, trains for 10 epochs, saves the model
    to ``./save_weights/AlexNet_model`` and shows the training/validation
    accuracy curves.

    Exits the process with status -1 if GPU memory growth cannot be enabled.
    """
    import sys

    # Allocate GPU memory on demand instead of reserving it all up front.
    # TensorFlow raises RuntimeError if the GPUs were already initialized.
    gpus = tf.config.experimental.list_physical_devices("GPU")
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            print(e)
            sys.exit(-1)

    data_root = tf.keras.utils.get_file(
        'flower_photos',
        'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
        untar=True)

    # Both subsets must use the same validation_split and seed, otherwise the
    # training and validation file lists can overlap.
    split_options = dict(
        validation_split=0.2,
        seed=123,
        image_size=(224, 224),
        batch_size=32,
    )
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        str(data_root), subset="training", **split_options)
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        str(data_root), subset="validation", **split_options)

    # Directory for the saved model.
    os.makedirs("save_weights", exist_ok=True)

    model = AlexNet()

    model.compile(
        optimizer='adam',
        # NOTE(review): from_logits=False assumes AlexNet() ends in a softmax
        # layer — confirm against the model definition.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        # History keys follow the metric name given here. 'accuracy' yields
        # 'accuracy'/'val_accuracy', which the plot below reads ('acc' would
        # produce 'acc'/'val_acc' and a KeyError).
        metrics=['accuracy'])

    history = model.fit(train_ds, epochs=10, validation_data=val_ds)

    model.save("./save_weights/AlexNet_model")

    # Evaluate the model: plot training vs. validation accuracy per epoch.
    plt.plot(history.history['accuracy'], label='accuracy')
    plt.plot(history.history['val_accuracy'], label='val_accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.ylim([0.5, 1])
    plt.legend(loc='lower right')
    plt.show()
Example #2
0
    def __init__(self):
        """Build, train and export an AlexNet image classifier.

        Training happens during construction. Images are streamed from
        ``data/all/train/`` and ``data/all/validation/``, and the model is fit
        until early stopping fires or the user presses Ctrl-C. The result is
        then written both as Keras weights (``data/model.h5``) and as a
        TensorFlow.js model (``data/model``).
        """
        self._input_shape = (224, 224, 3)
        self._output_dim = 2
        batch_size = 32

        model = AlexNet(self._input_shape, self._output_dim).get_model()
        optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
        # flow_from_directory defaults to class_mode="categorical" (one-hot
        # labels), which matches categorical_crossentropy.
        model.compile(loss=tf.keras.losses.categorical_crossentropy,
                      optimizer=optimizer,
                      metrics=["accuracy"])
        # Multiply the learning rate by 0.8 whenever val_loss stalls for 10
        # epochs, then wait 5 epochs before reducing again.
        reduce_lr_on_plateau = ReduceLROnPlateau(monitor='val_loss',
                                                 factor=0.8,
                                                 patience=10,
                                                 verbose=1,
                                                 mode='auto',
                                                 min_delta=0.0001,
                                                 cooldown=5,
                                                 min_lr=1e-10)
        # Stop after 20 epochs without val_loss improvement. The huge epoch
        # count passed to fit() below means "train until early stopping".
        # NOTE(review): without restore_best_weights=True the weights saved
        # below come from the last epoch, not the best one — confirm intent.
        early = EarlyStopping(monitor="val_loss", mode="min", patience=20)

        # NOTE(review): no rescaling is configured, so pixel values are passed
        # through unscaled; presumably AlexNet normalizes its input — confirm.
        data_gen = ImageDataGenerator()
        target_size = self._input_shape[:2]
        # The iterators produce the batches, so the batch size is set here.
        # Model.fit must not be given batch_size for generator/Sequence input.
        train_it = data_gen.flow_from_directory('data/all/train/',
                                                target_size=target_size,
                                                batch_size=batch_size)
        val_it = data_gen.flow_from_directory('data/all/validation/',
                                              target_size=target_size,
                                              batch_size=batch_size)

        callbacks_list = [early, reduce_lr_on_plateau]

        try:
            model.fit(train_it,
                      epochs=10000,
                      validation_data=val_it,
                      callbacks=callbacks_list,
                      verbose=1)
        except KeyboardInterrupt:
            # Deliberate: Ctrl-C ends training early, but the model trained
            # so far is still saved below.
            pass

        model.save_weights("data/model.h5")
        tfjs.converters.save_keras_model(model, "data/model")