def test_schema_and_examples_are_save_correctly(saved_tf_iris_model):
    """Check that signature and input example survive ``save_model``.

    Every combination of (no signature | inferred signature) and
    (no input example | first three training rows) is saved to a fresh
    temporary directory, loaded back with ``Model.load`` and compared
    against what was passed in.

    :param saved_tf_iris_model: Object exposing ``path``,
        ``meta_graph_tags`` and ``signature_def_key`` of an exported
        TensorFlow SavedModel.
    """
    features, labels = iris_data_utils.load_data()[0]
    X = pd.DataFrame(features)
    y = pd.Series(labels)

    # Cartesian product of the signature and input-example options,
    # in the same order as nested loops would visit them.
    combinations = [
        (signature, example)
        for signature in (None, infer_signature(X, y))
        for example in (None, X.head(3))
    ]

    for signature, example in combinations:
        with TempDir() as tmp:
            path = tmp.path("model")
            mlflow.tensorflow.save_model(
                tf_saved_model_dir=saved_tf_iris_model.path,
                tf_meta_graph_tags=saved_tf_iris_model.meta_graph_tags,
                tf_signature_def_key=saved_tf_iris_model.signature_def_key,
                path=path,
                signature=signature,
                input_example=example,
            )
            mlflow_model = Model.load(path)

            # The stored signature must equal the one supplied (None included).
            assert signature == mlflow_model.signature

            if example is not None:
                # Element-wise comparison of the stored example with the input.
                assert all((_read_example(mlflow_model, path) == example).all())
            else:
                assert mlflow_model.saved_input_example_info is None
def saved_tf_iris_model(tmpdir):
    """Train a small iris ``DNNClassifier``, export it and describe the export.

    The estimator (two hidden layers of 10 units, 3 classes) is trained
    for 1000 steps with batch size 100, used to predict three hard-coded
    samples, and exported as a SavedModel under ``tmpdir/saved_model``.

    NOTE(review): tests in this file request this function by name as an
    argument, so it is presumably registered as a pytest fixture; the
    decorator is not visible in this chunk -- confirm.

    :param tmpdir: Object providing ``mkdir(name)``; the directory it
        creates is the export base directory.
    :return: ``SavedModelInfo`` holding the export path, the
        ``["serve"]`` meta graph tags, the ``"predict"`` signature def
        key, the inference DataFrame, the expected species names, the
        raw prediction dictionary and its DataFrame form.
    """
    # Following code from
    # https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py
    train_x, train_y = iris_data_utils.load_data()[0]

    # Feature columns describe how to use the input.
    my_feature_columns = [
        tf.feature_column.numeric_column(key=key) for key in train_x.keys()
    ]

    # Build 2 hidden layer DNN with 10, 10 units respectively.
    estimator = tf.estimator.DNNClassifier(
        feature_columns=my_feature_columns,
        # Two hidden layers of 10 nodes each.
        hidden_units=[10, 10],
        # The model must choose between 3 classes.
        n_classes=3,
    )

    # Train the Model.
    batch_size = 100
    train_steps = 1000
    estimator.train(
        input_fn=lambda: iris_data_utils.train_input_fn(
            train_x, train_y, batch_size),
        steps=train_steps,
    )

    # Generate predictions from the model
    predict_x = {
        "SepalLength": [5.1, 5.9, 6.9],
        "SepalWidth": [3.3, 3.0, 3.1],
        "PetalLength": [1.7, 4.2, 5.4],
        "PetalWidth": [0.5, 1.5, 2.1],
    }

    estimator_preds = estimator.predict(
        lambda: iris_data_utils.eval_input_fn(predict_x, None, batch_size))

    # Building a dictionary of the predictions by the estimator: seed it
    # with the first prediction, then stack each later one underneath.
    # The builtin next() works on any iterator, so no interpreter-version
    # branch is required.
    estimator_preds_dict = next(estimator_preds)
    for row in estimator_preds:
        for key, value in row.items():
            estimator_preds_dict[key] = np.vstack(
                (estimator_preds_dict[key], value))

    # Building a pandas DataFrame out of the prediction dictionary.
    # Work on a deep copy so the raw dictionary returned below keeps its
    # stacked arrays untouched.
    preds_columns = copy.deepcopy(estimator_preds_dict)
    for col in preds_columns:
        if all(len(element) == 1 for element in preds_columns[col]):
            # One value per row: flatten to a 1-D column.
            preds_columns[col] = preds_columns[col].ravel()
        else:
            # Several values per row: keep each row as a list cell.
            preds_columns[col] = preds_columns[col].tolist()
    estimator_preds_df = pd.DataFrame.from_dict(data=preds_columns)

    # Building a DataFrame that contains the names of the flowers predicted.
    estimator_preds_results = [
        iris_data_utils.SPECIES[class_id[0]]
        for class_id in estimator_preds_dict["class_ids"]
    ]
    estimator_preds_results_df = pd.DataFrame(
        {"predictions": estimator_preds_results})

    # Define a function for estimator inference
    feature_spec = {
        column.key: tf.Variable([], dtype=tf.float64, name=column.key)
        for column in my_feature_columns
    }
    receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
        feature_spec)

    # Save the estimator and its inference function
    export_base_dir = str(tmpdir.mkdir("saved_model"))
    # The export call's result is decoded from UTF-8 bytes to get a str path.
    saved_estimator_path = estimator.export_saved_model(
        export_base_dir, receiver_fn).decode("utf-8")

    return SavedModelInfo(
        path=saved_estimator_path,
        meta_graph_tags=["serve"],
        signature_def_key="predict",
        inference_df=pd.DataFrame(
            data=predict_x,
            columns=[column.key for column in my_feature_columns]),
        expected_results_df=estimator_preds_results_df,
        raw_results=estimator_preds_dict,
        raw_df=estimator_preds_df,
    )