Example #1
0
def get_gluon_random_data_run(log_models=True):
    """Fit a small Gluon network with autologging on and return its MLflow run.

    Args:
        log_models: Passed through as the first positional argument of
            ``mlflow.gluon.autolog``.

    Returns:
        The run object looked up through ``MlflowClient.get_run`` using the
        id of the run that was started for the training.
    """
    mlflow.gluon.autolog(log_models)

    with mlflow.start_run() as started_run:
        train_loader = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")
        val_loader = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

        # Two 64-unit ReLU layers followed by a 10-unit output layer.
        net = HybridSequential()
        for layer in (
            Dense(64, activation="relu"),
            Dense(64, activation="relu"),
            Dense(10),
        ):
            net.add(layer)
        net.initialize()
        net.hybridize()

        adam_settings = {"learning_rate": 0.001, "epsilon": 1e-07}
        optimizer = Trainer(net.collect_params(), "adam", optimizer_params=adam_settings)
        estimator = get_estimator(net, optimizer)

        # Warnings are suppressed only while fit() runs.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            estimator.fit(train_loader, epochs=3, val_data=val_loader)

    # Re-fetch the run after it has been closed by the context manager.
    return mlflow.tracking.MlflowClient().get_run(started_run.info.run_id)
Example #2
0
def test_autolog_registering_model():
    """Check that autologging with ``registered_model_name`` registers a model.

    After fitting inside an explicitly started run, the registry is queried
    for the configured name and the returned model's name must match it.
    """
    expected_name = "test_autolog_registered_model"
    mlflow.gluon.autolog(registered_model_name=expected_name)

    loader = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

    # One 64-unit ReLU layer followed by a 10-unit output layer.
    net = HybridSequential()
    for layer in (Dense(64, activation="relu"), Dense(10)):
        net.add(layer)
    net.initialize()
    net.hybridize()

    adam_settings = {"learning_rate": 0.001, "epsilon": 1e-07}
    optimizer = Trainer(net.collect_params(), "adam", optimizer_params=adam_settings)
    estimator = get_estimator(net, optimizer)

    with mlflow.start_run():
        # Warnings stay suppressed for both the fit and the registry lookup.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            estimator.fit(loader, epochs=3)

            fetched = MlflowClient().get_registered_model(expected_name)
            assert fetched.name == expected_name
Example #3
0
def test_autolog_persists_manually_created_run():
    """Check that a run started by the test is still the active run after fit.

    The training happens inside ``mlflow.start_run()``; once ``fit`` returns,
    the active run id must equal the id of the run opened here.
    """
    mlflow.gluon.autolog()

    loader = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

    with mlflow.start_run() as manual_run:
        # Two 64-unit ReLU layers followed by a 10-unit output layer.
        net = HybridSequential()
        for layer in (
            Dense(64, activation="relu"),
            Dense(64, activation="relu"),
            Dense(10),
        ):
            net.add(layer)
        net.initialize()
        net.hybridize()

        adam_settings = {"learning_rate": 0.001, "epsilon": 1e-07}
        optimizer = Trainer(net.collect_params(), "adam", optimizer_params=adam_settings)
        estimator = get_estimator(net, optimizer)

        # Warnings are suppressed only while fit() runs.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            estimator.fit(loader, epochs=3)

        current_run_id = mlflow.active_run().info.run_id
        assert current_run_id == manual_run.info.run_id
def gluon_model(model_data):
    """Build and fit a three-layer Gluon network on the supplied data.

    Args:
        model_data: A three-element iterable; the first two elements are
            wrapped in an ``ArrayDataset`` as training data and labels, and
            the third element is ignored.

    Returns:
        The ``HybridSequential`` model after three epochs of fitting.
    """
    features, labels, _ = model_data
    loader = DataLoader(
        mx.gluon.data.ArrayDataset(features, labels),
        batch_size=128,
        last_batch="discard",
    )

    # 128-unit and 64-unit ReLU layers followed by a 10-unit output layer.
    net = HybridSequential()
    for layer in (
        Dense(128, activation="relu"),
        Dense(64, activation="relu"),
        Dense(10),
    ):
        net.add(layer)
    net.initialize()
    net.hybridize()

    adam_settings = {"learning_rate": 0.001, "epsilon": 1e-07}
    optimizer = Trainer(net.collect_params(), "adam", optimizer_params=adam_settings)

    estimator = get_estimator(net, optimizer)
    estimator.fit(loader, epochs=3)

    return net
Example #5
0
def test_autolog_ends_auto_created_run():
    """Check that no MLflow run is left active after fitting.

    The test never calls ``mlflow.start_run()`` itself; once ``fit`` returns,
    ``mlflow.active_run()`` must be ``None``.
    """
    mlflow.gluon.autolog()

    loader = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

    # Two 64-unit ReLU layers followed by a 10-unit output layer.
    net = HybridSequential()
    for layer in (
        Dense(64, activation="relu"),
        Dense(64, activation="relu"),
        Dense(10),
    ):
        net.add(layer)
    net.initialize()
    net.hybridize()

    adam_settings = {"learning_rate": 0.001, "epsilon": 1e-07}
    optimizer = Trainer(net.collect_params(), "adam", optimizer_params=adam_settings)
    estimator = get_estimator(net, optimizer)

    estimator.fit(loader, epochs=3)

    leftover_run = mlflow.active_run()
    assert leftover_run is None