def _compare_mleap_pyspark(mleap_prediction, spark_prediction):
    """Check that MLeap and Spark produce the same predictions and probabilities.

    Both arguments are converted with ``toPandas()``. Their ``prediction``
    columns are compared first, then their ``probability`` columns, each with
    ``compare_outputs(..., decimal=5)``.

    :param mleap_prediction: predictions obtained through MLeap (presumably a
        Spark DataFrame -- anything exposing ``toPandas()`` works here).
    :param spark_prediction: predictions produced by the Spark model itself.
    :raises OnnxRuntimeAssertionError: if either comparison reports a mismatch;
        the detail returned by ``compare_outputs`` is appended to the message.
    """
    import pandas

    def _probability_matrix(frame):
        # Each ``probability`` cell is a vector exposing ``toArray()``; expand
        # it into one column per entry so the result is a 2-D array.
        return frame.probability.apply(
            lambda vec: pandas.Series(vec.toArray())).values

    spark_pandas = spark_prediction.toPandas()
    mleap_pandas = mleap_prediction.toPandas()

    msg = compare_outputs(spark_pandas.prediction.values,
                          mleap_pandas.prediction.values, decimal=5)
    if msg:
        raise OnnxRuntimeAssertionError(
            "Predictions in mleap and spark do not match\n{0}".format(msg))

    msg = compare_outputs(_probability_matrix(spark_pandas),
                          _probability_matrix(mleap_pandas), decimal=5)
    if msg:
        raise OnnxRuntimeAssertionError(
            "Probabilities in mleap and spark do not match\n{0}".format(msg))
def compare_results(expected, output, decimal=5):
    """Recursively compare an expected result with the output actually produced.

    Supported types for ``expected``:

    * ``list`` -- ``output`` must be a list of the same length; elements are
      compared pairwise through recursive calls.
    * ``dict`` -- ``output`` must be a dict; each key of ``output`` that also
      appears in ``expected`` is compared with ``compare_outputs``. Keys found
      in only one of the two dicts are skipped.
    * ``numpy.ndarray`` -- ``output`` may be an array, a one-element list or
      dict wrapping an array, or a list of dicts (one per row of ``expected``)
      which is turned into a 2-D array with columns in sorted key order.
    * ``scipy.sparse.csr_matrix`` -- densified and compared with
      ``numpy.array(output)`` (DictVectorizer-style results).

    :param expected: reference value.
    :param output: value to check against ``expected``.
    :param decimal: precision forwarded to ``compare_outputs``.
    :raises OnnxRuntimeAssertionError: on a type, length or value mismatch, on
        an unsupported ``expected`` type, or when nothing ended up compared.
    :raises ExpectedAssertionError: re-raised as-is when ``compare_outputs``
        returns one during an ndarray comparison.
    """
    tested = 0
    if isinstance(expected, list):
        if not isinstance(output, list):
            raise OnnxRuntimeAssertionError(
                "Type mismatch: output type is {0}".format(type(output)))
        if len(expected) != len(output):
            raise OnnxRuntimeAssertionError(
                "Unexpected number of outputs: expected={0}, got={1}".format(
                    len(expected), len(output)))
        for exp, out in zip(expected, output):
            compare_results(exp, out, decimal=decimal)
            tested += 1
    elif isinstance(expected, dict):
        if not isinstance(output, dict):
            raise OnnxRuntimeAssertionError(
                "Type mismatch: output type is {0}".format(type(output)))
        for k, v in output.items():
            if k not in expected:
                continue
            msg = compare_outputs(expected[k], v, decimal=decimal)
            if msg:
                raise OnnxRuntimeAssertionError(
                    "Unexpected output '{0}': \n{1}".format(k, msg))
            tested += 1
    elif isinstance(expected, numpy.ndarray):
        # A list of per-row dicts (e.g. probability maps) becomes a 2-D array
        # whose columns follow the sorted keys. Guard against an empty list and
        # a 0-d ``expected`` before touching output[0] / shape[0].
        if (isinstance(output, list) and output and expected.ndim > 0
                and expected.shape[0] == len(output)
                and isinstance(output[0], dict)):
            import pandas
            frame = pandas.DataFrame(output)
            output = frame[sorted(frame.columns)].values
        if isinstance(output, (dict, list)):
            if len(output) != 1:
                ex = str(output)
                if len(ex) > 70:
                    ex = ex[:70] + "..."
                raise OnnxRuntimeAssertionError(
                    "More than one output when 1 is expected\n{0}".format(ex))
            # A dict cannot be indexed positionally: unwrap its only value.
            if isinstance(output, dict):
                output = next(iter(output.values()))
            else:
                output = output[0]
        if not isinstance(output, numpy.ndarray):
            raise OnnxRuntimeAssertionError(
                "output must be an array not {0}".format(type(output)))
        msg = compare_outputs(expected, output, decimal=decimal)
        if isinstance(msg, ExpectedAssertionError):
            raise msg
        if msg:
            raise OnnxRuntimeAssertionError("Unexpected output\n{0}".format(msg))
        tested += 1
    else:
        # ``scipy.sparse.csr`` is a deprecated namespace; import from the
        # public ``scipy.sparse`` package instead.
        from scipy.sparse import csr_matrix
        if not isinstance(expected, csr_matrix):
            raise OnnxRuntimeAssertionError(
                "Unexpected type for expected output ({0})".format(type(expected)))
        # DictVectorizer
        one_array = numpy.array(output)
        msg = compare_outputs(expected.todense(), one_array, decimal=decimal)
        if msg:
            raise OnnxRuntimeAssertionError("Unexpected output\n{0}".format(msg))
        tested += 1
    if tested == 0:
        raise OnnxRuntimeAssertionError("No test for model")