Ejemplo n.º 1
0
def test_keras_load_model_on_resume(tmp_dir, xor_model, mocker,
                                    save_weights_only, capture_wrap):
    """Resuming with an existing ``model_file`` restores it before training.

    When ``save_weights_only`` is set, the callback is expected to restore via
    ``model.load_weights``. Otherwise it should use ``dvclive.keras.load_model``.
    Exactly one of the two loaders must run, exactly once.
    """
    import dvclive.keras

    model, x, y = xor_model()

    # Seed the on-disk checkpoint that the resuming callback should pick up.
    if save_weights_only:
        model.save_weights("model.h5")
    else:
        model.save("model.h5")

    load_weights = mocker.spy(model, "load_weights")
    # Spy on the module attribute; presumably the callback resolves
    # ``load_model`` from the ``dvclive.keras`` namespace.
    load_model = mocker.spy(dvclive.keras, "load_model")

    model.fit(
        x,
        y,
        epochs=1,
        batch_size=1,
        callbacks=[
            DvcLiveCallback(
                model_file="model.h5",
                save_weights_only=save_weights_only,
                resume=True,
            )
        ],
    )

    # Pin exact counts. An int-vs-bool check such as `count != flag` would
    # also accept the wrong loader being called repeatedly.
    expected_model_loads = 0 if save_weights_only else 1
    expected_weight_loads = 1 if save_weights_only else 0
    assert load_model.call_count == expected_model_loads
    assert load_weights.call_count == expected_weight_loads
Ejemplo n.º 2
0
def test_keras_callback(tmp_dir, xor_model, capture_wrap):
    """A default ``DvcLiveCallback`` logs an ``accuracy`` scalar during fit."""
    model, inputs, labels = xor_model()
    live_cb = DvcLiveCallback()

    model.fit(inputs, labels, epochs=1, batch_size=1, callbacks=[live_cb])

    assert os.path.exists("dvclive")
    scalar_dir = tmp_dir / "dvclive" / Scalar.subfolder
    logs, _ = read_logs(scalar_dir)

    assert "accuracy" in logs
Ejemplo n.º 3
0
def test_keras_model_file(tmp_dir, xor_model, mocker, save_weights_only,
                          capture_wrap):
    """The callback writes ``model_file`` with the method matching the mode.

    With ``save_weights_only`` only ``model.save_weights`` should be used.
    Otherwise only ``model.save`` should be used.
    """
    model, x, y = xor_model()
    save = mocker.spy(model, "save")
    save_weights = mocker.spy(model, "save_weights")

    model.fit(
        x,
        y,
        epochs=1,
        batch_size=1,
        callbacks=[
            DvcLiveCallback(model_file="model.h5",
                            save_weights_only=save_weights_only)
        ],
    )

    # One epoch was run; assumes the callback checkpoints once per epoch.
    # Exact counts avoid the int-vs-bool `count != flag` check, which would
    # also accept repeated calls to the wrong saver.
    expected_saves = 0 if save_weights_only else 1
    expected_weight_saves = 1 if save_weights_only else 0
    assert save.call_count == expected_saves
    assert save_weights.call_count == expected_weight_saves
Ejemplo n.º 4
0
def test_keras_None_model_file_skip_load(tmp_dir, xor_model, mocker,
                                         capture_wrap):
    """Without a ``model_file``, resuming must not try to restore weights.

    A weights file exists on disk, but the callback is not given its path.
    So ``model.load_weights`` should never be invoked during ``fit``.
    """
    net, features, targets = xor_model()
    net.save_weights("model.h5")

    weights_spy = mocker.spy(net, "load_weights")
    resume_cb = DvcLiveCallback(save_weights_only=True, resume=True)

    net.fit(features, targets, epochs=1, batch_size=1, callbacks=[resume_cb])

    assert not weights_spy.called
Ejemplo n.º 5
0
                      include_top=False,
                      weights='imagenet')

inputs = tf.keras.Input(shape=IMG_SHAPE)
x = data_augmentation(inputs)
x = PREPROCESS_INPUT(x)
# The base model is called in inference mode (training=False).
x = base_model(x, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
# Single raw logit; the loss below is configured with from_logits=True.
outputs = tf.keras.layers.Dense(1)(x)

model = tf.keras.Model(inputs, outputs)

callbacks = [
    # Use dvclive's Keras callback
    DvcLiveCallback(),
    # NOTE(review): save_weights_only is not passed, so per Keras defaults this
    # checkpoint stores the full model despite the "best_weights" filename.
    ModelCheckpoint(str(TRAIN_DIR / "best_weights.h5"), save_best_only=True),
]

#%% Freeze the base model and train 10 epochs
base_model.trainable = False

model.compile(
    # `learning_rate` is the supported argument name. The `lr` alias is
    # deprecated in TF 2.x Keras and not accepted by newer Keras versions.
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.summary()

history = model.fit(
    train_dataset,
Ejemplo n.º 6
0
def main():
    """Train the model and write metrics and confusion-matrix artifacts.

    Steps:
      1. Load params, build the model and read the dataset.
      2. Train with ``DvcLiveCallback`` checkpointing to ``OUTPUT_DIR``.
      3. Evaluate on the test split and dump the metrics dict as JSON to
         ``METRICS_FILE``.
      4. Write ``plots/confusion.csv`` (one ``actual,predicted`` row per
         validation example) and ``plots/confusion.png``.

    Raises:
        ValueError: if train + test do not total 70000 images and 70000 labels.
    """
    params = load_params()
    m = get_model(conv_units=params['model']['conv_units'])
    m.summary()

    training_images, training_labels, testing_images, testing_labels = read_dataset(
        DATASET_FILE)

    # Explicit check rather than `assert`, so it still runs under `python -O`.
    n_images = training_images.shape[0] + testing_images.shape[0]
    n_labels = training_labels.shape[0] + testing_labels.shape[0]
    if n_images != 70000 or n_labels != 70000:
        raise ValueError(
            "Expected 70000 images and 70000 labels in total, got "
            f"{n_images} images and {n_labels} labels")

    training_images = normalize(training_images)
    testing_images = normalize(testing_images)

    training_labels = tf.keras.utils.to_categorical(training_labels,
                                                    num_classes=10,
                                                    dtype="float32")
    testing_labels = tf.keras.utils.to_categorical(testing_labels,
                                                   num_classes=10,
                                                   dtype="float32")

    # We use the test set as validation for simplicity
    x_train = training_images
    x_valid = testing_images
    y_train = training_labels
    y_valid = testing_labels

    m.fit(
        x_train,
        y_train,
        batch_size=BATCH_SIZE,
        epochs=params["train"]["epochs"],
        verbose=1,
        validation_data=(x_valid, y_valid),
        callbacks=[DvcLiveCallback(model_file=f"{OUTPUT_DIR}/model.h5")],
    )

    metrics_dict = m.evaluate(
        testing_images,
        testing_labels,
        batch_size=BATCH_SIZE,
        return_dict=True,
    )

    with open(METRICS_FILE, "w") as f:
        json.dump(metrics_dict, f)

    # (actual, predicted) -> one example image for that pair. Later examples
    # overwrite earlier ones, and correct predictions (actual == predicted)
    # are stored as well.
    misclassified = {}

    # predictions for the confusion matrix
    y_prob = m.predict(x_valid)
    y_pred = y_prob.argmax(axis=-1)
    # exist_ok: re-running the script must not fail on an existing directory.
    os.makedirs("plots", exist_ok=True)
    with open("plots/confusion.csv", "w") as f:
        f.write("actual,predicted\n")
        for one_hot, predicted, image in zip(y_valid, y_pred, x_valid):
            actual = one_hot.argmax()
            f.write(f"{actual},{predicted}\n")
            misclassified[(actual, predicted)] = image

    # find misclassified examples and generate a confusion table image
    confusion_out = create_image_matrix(misclassified)
    imageio.imwrite("plots/confusion.png", confusion_out)