示例#1
0
def test_get_bin_index_on_categorical_value():
    """Known categories map to their bin index; "nan" maps to 0 and any
    other unknown string maps to -1."""
    g = graph.create_graph()
    value = graph.create_input(g, "i", onnx.TensorProto.STRING, [None, 1])

    mapping = {'foo': 1, 'bar': 2, 'biz': 3}
    g = ebm.get_bin_index_on_categorical_value(mapping)(value)
    g = graph.add_output(g, g.transients[0].name, onnx.TensorProto.INT64,
                         [None, 1])

    result = infer_model(
        graph.compile(g, target_opset=13),
        input={'i': [["biz"], ["foo"], ["bar"], ["nan"], ["okif"]]},
    )
    expected = [[3], [1], [2], [0], [-1]]
    assert len(expected) == len(result[0])
    for want, got in zip(expected, result[0]):
        assert got == want
示例#2
0
def test_less():
    """Element-wise less-than of the [4] initializer against the broadcast
    input column; comparisons against NaN are always False."""
    g = graph.create_graph()

    a = graph.create_initializer(g, "a", onnx.TensorProto.FLOAT, [4],
                                 [1.1, 2.3, 3.5, 9.6])
    b = graph.create_input(g, "b", onnx.TensorProto.FLOAT, [None, 1])

    l = ops.less()(graph.merge(a, b))
    l = graph.add_output(l, l.transients[0].name, onnx.TensorProto.BOOL,
                         [None, 4])

    # np.NaN was removed in NumPy 2.0; np.nan is the portable spelling.
    assert_model_result(l,
                        input={'b': [
                            [0.1],
                            [1.2],
                            [11],
                            [4.2],
                            [np.nan],
                        ]},
                        expected_result=[[
                            [False, False, False, False],
                            [True, False, False, False],
                            [True, True, True, True],
                            [True, True, True, False],
                            [False, False, False, False],
                        ]])
示例#3
0
def test_get_bin_score_1d_multiclass():
    """Lookup with 3 classes: the score matrix is [bin_count, class_count]
    and each bin index selects a whole row."""
    g = graph.create_graph()
    idx = graph.create_input(g, "i", onnx.TensorProto.INT64, [None, 1])

    scores = np.array([
        [0.0, 1.0, 2.0],
        [0.1, 1.1, 2.1],
        [0.2, 1.2, 2.2],
        [0.3, 1.3, 2.3],
    ])
    g = ebm.get_bin_score_1d(scores)(idx)
    g = graph.add_output(g, g.transients[0].name, onnx.TensorProto.FLOAT,
                         [None, 1, 3])

    assert_model_result(
        g,
        input={'i': [[3], [1], [2], [0], [2]]},
        expected_result=[[
            [[0.3, 1.3, 2.3]],
            [[0.1, 1.1, 2.1]],
            [[0.2, 1.2, 2.2]],
            [[0.0, 1.0, 2.0]],
            [[0.2, 1.2, 2.2]],
        ]],
    )
示例#4
0
def test_get_bin_score_2d():
    """2-d lookup: (i1, i2) index the row and column of the score matrix."""
    g = graph.create_graph()
    i1 = graph.create_input(g, "i1", onnx.TensorProto.INT64, [None, 1])
    i2 = graph.create_input(g, "i2", onnx.TensorProto.INT64, [None, 1])

    scores = np.array([
        [0.0, 0.1, 0.2, 0.3],
        [1.0, 2.1, 3.2, 4.3],
        [10.0, 20.1, 30.2, 40.3],
    ])
    g = ebm.get_bin_score_2d(scores)(graph.merge(i1, i2))
    g = graph.add_output(g, g.transients[0].name, onnx.TensorProto.FLOAT,
                         [None, 1, 1])

    assert_model_result(
        g,
        input={
            'i1': [[2], [1], [2], [0]],
            'i2': [[3], [0], [2], [1]],
        },
        expected_result=[[[[40.3]], [[1.0]], [[30.2]], [[0.1]]]],
    )
示例#5
0
def test_create_graph():
    """A freshly created graph has a name generator and no content."""
    new_graph = graph.create_graph()

    assert new_graph.generate_name is not None
    for collection in (new_graph.inputs, new_graph.outputs, new_graph.nodes):
        assert collection == []
示例#6
0
def test_compute_multiclass_score():
    """Per-class score is the sum of the per-feature scores plus the
    per-class intercepts [0.1, 0.2, 0.3]."""
    g = graph.create_graph()
    feature_scores = [
        graph.create_input(g, input_name, onnx.TensorProto.FLOAT,
                           [None, 1, 3])
        for input_name in ("i1", "i2", "i3")
    ]

    g, _ = ebm.compute_class_score(np.array([0.1, 0.2, 0.3]))(
        graph.merge(*feature_scores))
    g = graph.add_output(g, g.transients[0].name, onnx.TensorProto.FLOAT,
                         [None, 3])

    assert_model_result(g,
                        input={
                            'i1': [[[0.1, 0.2, 0.3]], [[0.2, 0.3, 0.4]],
                                   [[0.3, 0.4, 0.5]], [[0.4, 0.5, 0.6]]],
                            'i2': [[[1.1, 1.2, 1.3]], [[1.2, 1.3, 1.4]],
                                   [[1.3, 1.4, 1.5]], [[1.4, 1.5, 1.6]]],
                            'i3': [[[2.1, 2.2, 2.3]], [[2.2, 2.3, 2.4]],
                                   [[2.3, 2.4, 2.5]], [[2.4, 2.5, 2.6]]],
                        },
                        expected_result=[[
                            [3.4, 3.8, 4.2],
                            [3.7, 4.1, 4.5],
                            [4.0, 4.4, 4.8],
                            [4.3, 4.7, 5.1],
                        ]])
示例#7
0
def test_create_one_input():
    """create_input registers exactly one input, mirrored in transients."""
    g = graph.create_graph()

    # renamed from `input`, which shadowed the builtin
    created = graph.create_input(g, "foo", onnx.TensorProto.FLOAT, [None, 3])
    assert len(created.inputs) == 1
    assert created.inputs == [onnx.helper.make_tensor_value_info(
        'foo',
        onnx.TensorProto.FLOAT,
        [None, 3])
    ]
    assert created.inputs == created.transients
示例#8
0
def test_predict_class_binary():
    """Binary class prediction: positive logits map to 1, negative to 0."""
    g = graph.create_graph()
    logit = graph.create_input(g, "i", onnx.TensorProto.FLOAT, [None, 1])

    g = ebm.predict_class(binary=True)(logit)
    g = graph.add_output(g, g.transients[0].name, onnx.TensorProto.INT64,
                         [None])

    assert_model_result(
        g,
        input={'i': [[3.5], [-3.8], [-0.1], [0.2]]},
        expected_result=[[1, 0, 0, 1]],
    )
示例#9
0
def test_create_initializer():
    """create_initializer registers one tensor (name suffixed '_0'),
    mirrored in transients."""
    g = graph.create_graph()

    init = graph.create_initializer(g, "foo", onnx.TensorProto.FLOAT, [4],
                                    [0.1, 0.2, 0.3, 0.4])
    assert len(init.initializers) == 1
    expected_tensor = onnx.helper.make_tensor(
        'foo_0', onnx.TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4])
    assert init.initializers == [expected_tensor]
    assert init.initializers == init.transients
示例#10
0
def test_flatten():
    """flatten turns a 1-d input into a column vector."""
    g = graph.create_graph()
    vec = graph.create_input(g, "i", onnx.TensorProto.FLOAT, [None])

    out = ops.flatten()(vec)
    out = graph.add_output(out, out.transients[0].name,
                           onnx.TensorProto.FLOAT, [None, 1])

    assert_model_result(out,
                        input={'i': [0.1, 0.2, 0.3, 0.4]},
                        expected_result=[[[0.1], [0.2], [0.3], [0.4]]])
示例#11
0
def test_add():
    """add broadcasts the scalar initializer 0.3 over the input vector."""
    g = graph.create_graph()

    offset = graph.create_initializer(g, "a", onnx.TensorProto.FLOAT, [1],
                                      [0.3])
    vec = graph.create_input(g, "i", onnx.TensorProto.FLOAT, [None])

    out = ops.add()(graph.merge(vec, offset))
    out = graph.add_output(out, out.transients[0].name,
                           onnx.TensorProto.FLOAT, [None])

    assert_model_result(out,
                        input={'i': [0.1, 1.2, 11, 4.2]},
                        expected_result=[[0.4, 1.5, 11.3, 4.5]])
示例#12
0
def test_predict_proba_binary():
    """Binary probabilities: each row yields a pair summing to 1."""
    g = graph.create_graph()
    logit = graph.create_input(g, "i", onnx.TensorProto.FLOAT, [None, 1])

    g = ebm.predict_proba(binary=True)(logit)
    g = graph.add_output(g, g.transients[0].name, onnx.TensorProto.FLOAT,
                         [None, 2])

    assert_model_result(
        g,
        input={'i': [[3.5], [-3.8], [-0.1], [0.2]]},
        expected_result=[[
            [0.02931223, 0.97068775],
            [0.97811866, 0.02188127],
            [0.5249792, 0.4750208],
            [0.450166, 0.54983395],
        ]],
    )
示例#13
0
def test_cast():
    """cast converts INT64 values to FLOAT."""
    g = graph.create_graph()
    ints = graph.create_input(g, "i", onnx.TensorProto.INT64, [None, 1])

    out = ops.cast(onnx.TensorProto.FLOAT)(ints)
    out = graph.add_output(out, out.transients[0].name,
                           onnx.TensorProto.FLOAT, [None, 1])

    assert_model_result(out,
                        input={'i': [[1], [2], [11], [4]]},
                        expected_result=[[[1.0], [2.0], [11.0], [4.0]]])
示例#14
0
def test_predict_multiclass_binary():
    """Multiclass prediction returns the index of the row maximum.

    NOTE(review): despite the "binary" suffix, this exercises
    predict_class(binary=False) over 3-class scores — consider renaming.
    """
    g = graph.create_graph()
    scores = graph.create_input(g, "i", onnx.TensorProto.FLOAT, [None, 3])

    g = ebm.predict_class(binary=False)(scores)
    g = graph.add_output(g, g.transients[0].name, onnx.TensorProto.INT64,
                         [None])

    assert_model_result(g,
                        input={
                            'i': [
                                [3.4, 3.8, 4.2],
                                [3.7, 4.1, 0.5],
                                [4.0, 0.4, 0.8],
                                [4.3, 4.7, 5.1],
                            ]
                        },
                        expected_result=[[2, 1, 0, 2]])
示例#15
0
def test_reshape():
    """reshape collapses a [None, 1] column back into a flat vector."""
    g = graph.create_graph()

    target_shape = graph.create_initializer(g, "shape",
                                            onnx.TensorProto.INT64, [1], [0])
    col = graph.create_input(g, "i", onnx.TensorProto.FLOAT, [None, 1])

    out = ops.reshape()(graph.merge(col, target_shape))
    out = graph.add_output(out, out.transients[0].name,
                           onnx.TensorProto.FLOAT, [None])

    assert_model_result(out,
                        input={'i': [[0.1], [1.2], [11], [4.2]]},
                        expected_result=[[0.1, 1.2, 11, 4.2]])
示例#16
0
def test_create_several_inputs():
    """Each create_input call registers its own named input/transient."""
    g = graph.create_graph()

    i1 = graph.create_input(g, "foo", onnx.TensorProto.FLOAT, [None, 3])
    i2 = graph.create_input(g, "bar", onnx.TensorProto.INT64, [None, 2])

    assert i1.inputs == [onnx.helper.make_tensor_value_info(
        'foo',
        onnx.TensorProto.FLOAT,
        [None, 3])
    ]
    assert i1.inputs == i1.transients

    # the original used f'bar' — an f-string with no placeholder
    assert i2.inputs == [onnx.helper.make_tensor_value_info(
        'bar',
        onnx.TensorProto.INT64,
        [None, 2])
    ]
    assert i2.inputs == i2.transients
示例#17
0
def test_argmax():
    """argmax over axis 1 returns the index of each row's maximum."""
    g = graph.create_graph()
    rows = graph.create_input(g, "i", onnx.TensorProto.FLOAT, [None, 3])

    out = ops.argmax(axis=1)(rows)
    out = graph.add_output(out, out.transients[0].name,
                           onnx.TensorProto.INT64, [None, 1])

    assert_model_result(out,
                        input={'i': [
                            [1, 4, 2],
                            [2, 8, 12],
                            [11, 0, 5],
                        ]},
                        expected_result=[[[1], [2], [0]]])
示例#18
0
def test_mul():
    """mul broadcasts the [3] factor initializer across each input row."""
    g = graph.create_graph()

    factors = graph.create_initializer(g, "a", onnx.TensorProto.FLOAT, [3],
                                       [1.0, 2.0, 3.0])
    rows = graph.create_input(g, "b", onnx.TensorProto.FLOAT, [None, 3])

    out = ops.mul()(graph.merge(factors, rows))
    out = graph.add_output(out, out.transients[0].name,
                           onnx.TensorProto.FLOAT, [None, 3])

    assert_model_result(out,
                        input={'b': [
                            [0.1, 0.1, 0.1],
                            [0.1, 0.2, 0.3],
                        ]},
                        expected_result=[[
                            [0.1, 0.2, 0.3],
                            [0.1, 0.4, 0.9],
                        ]])
示例#19
0
def test_concat():
    """concat on axis 1 places the two column inputs side by side."""
    g = graph.create_graph()

    left = graph.create_input(g, "a", onnx.TensorProto.FLOAT, [3, 1])
    right = graph.create_input(g, "b", onnx.TensorProto.FLOAT, [3, 1])

    out = ops.concat(axis=1)(graph.merge(left, right))
    out = graph.add_output(out, out.transients[0].name,
                           onnx.TensorProto.FLOAT, [None, 2])

    assert_model_result(out,
                        input={
                            'a': [[0.1], [0.2], [0.3]],
                            'b': [[1.1], [1.2], [1.3]],
                        },
                        expected_result=[[
                            [0.1, 1.1],
                            [0.2, 1.2],
                            [0.3, 1.3],
                        ]])
示例#20
0
def test_softmax():
    """softmax normalizes each row into values summing to 1."""
    g = graph.create_graph()
    logits = graph.create_input(g, "i", onnx.TensorProto.FLOAT, [None, 2])

    out = ops.softmax()(logits)
    out = graph.add_output(out, out.transients[0].name,
                           onnx.TensorProto.FLOAT, [None, 2])

    assert_model_result(
        out,
        input={'i': [
            [0.0, 0.68],
            [0.0, 0.2],
            [1.2, 0.3],
            [0.0, -0.2],
        ]},
        expected_result=[[[0.3362613, 0.66373867], [0.450166, 0.54983395],
                          [0.7109495, 0.2890505], [0.54983395, 0.450166]]],
    )
示例#21
0
def test_get_bin_score_1d():
    """Single-class bin-score lookup: index i selects scores[i]."""
    g = graph.create_graph()
    idx = graph.create_input(g, "i", onnx.TensorProto.INT64, [None, 1])

    g = ebm.get_bin_score_1d(np.array([0.0, 0.1, 0.2, 0.3]))(idx)
    g = graph.add_output(g, g.transients[0].name, onnx.TensorProto.FLOAT,
                         [None, 1, 1])

    assert_model_result(
        g,
        input={'i': [[3], [1], [2], [0]]},
        expected_result=[[[[0.3]], [[0.1]], [[0.2]], [[0.0]]]],
    )
示例#22
0
def test_reduce_sum():
    """reduce_sum over axis 1 with keepdims=0 yields one scalar per row."""
    g = graph.create_graph()

    axis = graph.create_initializer(g, "axis", onnx.TensorProto.INT64, [1],
                                    [1])
    rows = graph.create_input(g, "i", onnx.TensorProto.FLOAT, [None, 3])

    out = ops.reduce_sum(keepdims=0)(graph.merge(rows, axis))
    out = graph.add_output(out, out.transients[0].name,
                           onnx.TensorProto.FLOAT, [None])

    assert_model_result(out,
                        input={
                            'i': [
                                [0.1, 1.0, 1.2],
                                [1.2, 0.4, 0.9],
                                [11, 0.8, -0.2],
                                [4.2, 3.2, -6.4],
                            ]
                        },
                        expected_result=[[2.3, 2.5, 11.6, 1.0]])
示例#23
0
def test_gather_elements():
    """gather_elements picks one table value per index row."""
    g = graph.create_graph()

    table = graph.create_initializer(g, "a", onnx.TensorProto.FLOAT, [3, 1],
                                     [0.1, 0.2, 0.3])
    indices = graph.create_input(g, "b", onnx.TensorProto.INT64, [None, 1])

    out = ops.gather_elements()(graph.merge(table, indices))
    out = graph.add_output(out, out.transients[0].name,
                           onnx.TensorProto.FLOAT, [None, 1])

    assert_model_result(out,
                        input={'b': [[2], [1], [0]]},
                        expected_result=[[[0.3], [0.2], [0.1]]])
示例#24
0
def test_get_bin_index_on_continuous_value():
    """Continuous binning maps each value to a bin index.

    With edges [-inf, -inf, 0.2, 0.7, 1.2, 4.3]: values below all real
    edges land in bin 1, values above them in bin 5 (the last index).
    """
    g = graph.create_graph()
    i = graph.create_input(g, "i", onnx.TensorProto.FLOAT, [None, 1])

    # np.NINF was removed in NumPy 2.0; -np.inf is the portable spelling.
    g = ebm.get_bin_index_on_continuous_value(
        [-np.inf, -np.inf, 0.2, 0.7, 1.2, 4.3])(i)
    g = graph.add_output(g, g.transients[0].name, onnx.TensorProto.INT64,
                         [None, 1])

    assert_model_result(g,
                        input={'i': [
                            [1.3],
                            [0.6999],
                            [-9.6],
                            [9.6],
                        ]},
                        expected_result=[[
                            [4],
                            [2],
                            [1],
                            [5],
                        ]])
示例#25
0
def test_strip_to_transients():
    """strip_to_transients drops inputs and initializers but keeps the
    transients in merge order."""
    g = graph.create_graph()

    first = graph.create_input(g, "bar1", onnx.TensorProto.FLOAT, [None, 3])
    second = graph.create_input(g, "bar2", onnx.TensorProto.FLOAT, [None, 4])

    stripped = graph.strip_to_transients(graph.merge(first, second))

    assert stripped.initializers == []
    assert stripped.inputs == []
    assert stripped.transients == [
        onnx.helper.make_tensor_value_info(
            'bar1', onnx.TensorProto.FLOAT, [None, 3]),
        onnx.helper.make_tensor_value_info(
            'bar2', onnx.TensorProto.FLOAT, [None, 4]),
    ]
示例#26
0
def test_gather_nd():
    """gather_nd treats each index row as a (row, col) coordinate pair."""
    g = graph.create_graph()

    matrix = np.array([
        [0.1, 0.2, 0.3],
        [1.1, 2.2, 3.3],
        [0.1, 20.2, 30.3],
    ])
    table = graph.create_initializer(g, "a", onnx.TensorProto.FLOAT, [3, 3],
                                     matrix.flatten())
    coords = graph.create_input(g, "b", onnx.TensorProto.INT64, [None, 2])

    out = ops.gather_nd()(graph.merge(table, coords))
    out = graph.add_output(out, out.transients[0].name,
                           onnx.TensorProto.FLOAT, [None])

    assert_model_result(out,
                        input={'b': [[2, 0], [1, 1], [0, 1]]},
                        expected_result=np.array([[0.1, 2.2, 0.2]]))
示例#27
0
def test_compute_class_score():
    """Single-class score: per-feature scores summed plus intercept 0.2."""
    g = graph.create_graph()
    feature_scores = [
        graph.create_input(g, input_name, onnx.TensorProto.FLOAT,
                           [None, 1, 1])
        for input_name in ("i1", "i2", "i3")
    ]

    g, _ = ebm.compute_class_score(np.array([0.2]))(
        graph.merge(*feature_scores))
    g = graph.add_output(g, g.transients[0].name, onnx.TensorProto.FLOAT,
                         [None, 1])

    assert_model_result(g,
                        input={
                            'i1': [[[0.1]], [[0.2]], [[0.3]], [[0.4]]],
                            'i2': [[[1.1]], [[1.2]], [[1.3]], [[1.4]]],
                            'i3': [[[2.1]], [[2.2]], [[2.3]], [[2.4]]],
                        },
                        expected_result=[[
                            [3.5],
                            [3.8],
                            [4.1],
                            [4.4],
                        ]])
示例#28
0
def test_merge():
    """merge concatenates initializers, inputs and transients; transients
    preserve the order in which the parts were passed to merge."""
    g = graph.create_graph()

    init1 = graph.create_initializer(g, "foo", onnx.TensorProto.FLOAT, [4],
                                     [0.1, 0.2, 0.3, 0.4])
    init2 = graph.create_initializer(g, "foo", onnx.TensorProto.FLOAT, [4],
                                     [1.1, 1.2, 3.3, 4.4])
    input1 = graph.create_input(g, "bar1", onnx.TensorProto.FLOAT, [None, 3])
    input2 = graph.create_input(g, "bar2", onnx.TensorProto.FLOAT, [None, 4])

    merged = graph.merge(init1, input1, init2, input2)

    assert len(merged.initializers) == 2
    assert len(merged.inputs) == 2
    assert len(merged.transients) == 4

    # expected protobuf values, built once and reused across the checks
    tensor_foo_0 = onnx.helper.make_tensor(
        'foo_0', onnx.TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4])
    tensor_foo_1 = onnx.helper.make_tensor(
        'foo_1', onnx.TensorProto.FLOAT, [4], [1.1, 1.2, 3.3, 4.4])
    info_bar1 = onnx.helper.make_tensor_value_info(
        'bar1', onnx.TensorProto.FLOAT, [None, 3])
    info_bar2 = onnx.helper.make_tensor_value_info(
        'bar2', onnx.TensorProto.FLOAT, [None, 4])

    assert merged.initializers == [tensor_foo_0, tensor_foo_1]
    assert merged.inputs == [info_bar1, info_bar2]
    assert merged.transients == [tensor_foo_0, info_bar1, tensor_foo_1,
                                 info_bar2]
示例#29
0
def to_onnx(
    model,
    dtype,
    name="ebm",
    predict_proba=False,
    explain=False,
    target_opset=None,
):
    """Converts an EBM model to ONNX

    Args:
        model: The EBM model, trained with interpretml
        dtype: A dict containing the type of each input feature. Types are
            expressed as strings, the following values are supported: float,
            double, int, str.
        name: [Optional] The name of the model
        predict_proba: [Optional] For classification models, output prediction
            probabilities instead of class
        explain: [Optional] Adds an additional output with the score per
            feature per class
        target_opset: [Optional] The target onnx opset version to use

    Returns:
        An ONNX model.

    Raises:
        ValueError: If a categorical feature is not declared as a string type.
        NotImplementedError: If a feature type or model type is not supported.
    """
    target_opset = target_opset or get_latest_opset_version()
    root = graph.create_graph()

    # per-feature input sub-graphs, kept so interaction terms can reuse them
    inputs = [None for _ in model.feature_names]
    parts = []

    # first compute the score of each feature
    for feature_index, feature_name in enumerate(model.feature_names):
        feature_type = model.feature_types[feature_index]
        feature_group = model.feature_groups_[feature_index]

        if feature_type == 'continuous':
            # np.NINF was removed in NumPy 2.0; -np.inf is equivalent.
            bins = [-np.inf, -np.inf] + list(
                model.preprocessor_.col_bin_edges_[feature_group[0]])
            additive_terms = model.additive_terms_[feature_index]

            feature_dtype = infer_features_dtype(dtype, feature_name)
            part = graph.create_input(root, feature_name, feature_dtype,
                                      [None])
            part = ops.flatten()(part)
            inputs[feature_index] = part
            part = ebm.get_bin_index_on_continuous_value(bins)(part)
            part = ebm.get_bin_score_1d(additive_terms)(part)
            parts.append(part)

        elif feature_type == 'categorical':
            col_mapping = model.preprocessor_.col_mapping_[feature_group[0]]
            additive_terms = model.additive_terms_[feature_index]

            feature_dtype = infer_features_dtype(dtype, feature_name)
            if feature_dtype != onnx.TensorProto.STRING:
                raise ValueError(
                    "categorical features must be encoded as strings only. "
                    "{} is encoded as {} which is not supported.".format(
                        feature_name, dtype[feature_name]))
            part = graph.create_input(root, feature_name, feature_dtype,
                                      [None])
            part = ops.flatten()(part)
            inputs[feature_index] = part
            part = ebm.get_bin_index_on_categorical_value(col_mapping)(part)
            part = ebm.get_bin_score_1d(additive_terms)(part)
            parts.append(part)

        elif feature_type == 'interaction':
            # re-bin both member features (pairs may use different bin edges)
            i_parts = []
            for index in range(2):
                i_feature_index = feature_group[index]
                i_feature_type = model.feature_types[i_feature_index]

                if i_feature_type == 'continuous':
                    bins = [-np.inf, -np.inf] + list(
                        model.pair_preprocessor_.
                        col_bin_edges_[i_feature_index])
                    transients = graph.strip_to_transients(
                        inputs[i_feature_index])
                    i_parts.append(
                        ebm.get_bin_index_on_continuous_value(bins)(
                            transients))

                elif i_feature_type == 'categorical':
                    # NOTE(review): continuous pair members read
                    # pair_preprocessor_ but categorical ones read
                    # preprocessor_ — confirm this asymmetry is intended.
                    col_mapping = model.preprocessor_.col_mapping_[
                        i_feature_index]
                    transients = graph.strip_to_transients(
                        inputs[i_feature_index])
                    i_parts.append(
                        ebm.get_bin_index_on_categorical_value(col_mapping)(
                            transients))

                else:
                    raise NotImplementedError(
                        f"feature type {feature_type} is not supported in interactions"
                    )

            part = graph.merge(*i_parts)
            additive_terms = model.additive_terms_[feature_index]
            part = ebm.get_bin_score_2d(np.array(additive_terms))(part)
            parts.append(part)

        else:
            raise NotImplementedError(
                f"feature type {feature_type} is not supported")

    # compute scores, predict and proba
    g = graph.merge(*parts)
    if type(model) is ExplainableBoostingClassifier:
        g, scores_output_name = ebm.compute_class_score(model.intercept_)(g)
        # the binary/multiclass cases only differ in the `binary` flag
        binary = len(model.classes_) == 2
        if predict_proba is False:
            g = ebm.predict_class(binary=binary)(g)
            g = graph.add_output(g, g.transients[0].name,
                                 onnx.TensorProto.INT64, [None])
        else:
            g = ebm.predict_proba(binary=binary)(g)
            g = graph.add_output(g, g.transients[0].name,
                                 onnx.TensorProto.FLOAT,
                                 [None, len(model.classes_)])
    elif type(model) is ExplainableBoostingRegressor:
        g, scores_output_name = ebm.compute_class_score(
            np.array([model.intercept_]))(g)
        g = ebm.predict_value()(g)
        g = graph.add_output(g, g.transients[0].name, onnx.TensorProto.FLOAT,
                             [None])
    else:
        raise NotImplementedError("{} models are not supported".format(
            type(model)))

    if explain is True:
        # Regressors have no classes_ attribute (the previous code crashed
        # here for them); binary classifiers and regressors both emit one
        # score per feature, multiclass one score per class.
        if (type(model) is ExplainableBoostingClassifier
                and len(model.classes_) > 2):
            g = graph.add_output(
                g, scores_output_name, onnx.TensorProto.FLOAT,
                [None, len(model.feature_names),
                 len(model.classes_)])
        else:
            g = graph.add_output(g, scores_output_name, onnx.TensorProto.FLOAT,
                                 [None, len(model.feature_names), 1])

    # named onnx_model to avoid shadowing the `model` parameter
    onnx_model = graph.compile(g, target_opset, name=name)
    return onnx_model