Exemplo n.º 1
0
def test_gbm2ds(tmpdir):
    """Convert the GBM PMML fixture to a file and compare with the stored output.

    ``pyml2ds`` is asked to write its result to ``gbm.sas`` inside the
    per-test temporary directory; the bytes of that file must equal the
    bytes of the reference ``gbm.sas`` kept under ``DATA_PATH``.
    """
    # NOTE: despite the ``IN_PKL`` name, the input here is a PMML file.
    IN_PKL = os.path.join(DATA_PATH, 'gbm.pmml')
    OUT_SAS = os.path.join(str(tmpdir), 'gbm.sas')
    EXPECTED_SAS = os.path.join(DATA_PATH, 'gbm.sas')

    # ``test=True`` is forwarded to pyml2ds as-is; its effect is defined there.
    pyml2ds(IN_PKL, OUT_SAS, test=True)

    # Binary mode keeps the comparison byte-exact. Context managers make sure
    # both handles are closed even if a read raises.
    with open(OUT_SAS, 'rb') as f:
        result = f.read()
    with open(EXPECTED_SAS, 'rb') as f:
        expected = f.read()
    assert result == expected
Exemplo n.º 2
0
def test_gbm2ds():
    """Convert the GBM PMML fixture and compare the returned text with the stored output.

    ``pyml2ds`` is called without an output file, and its return value is
    compared against the text content of the reference ``gbm.sas`` kept
    under ``DATA_PATH``.
    """
    # NOTE: despite the ``IN_PKL`` name, the input here is a PMML file.
    IN_PKL = os.path.join(DATA_PATH, 'gbm.pmml')
    EXPECTED_SAS = os.path.join(DATA_PATH, 'gbm.sas')

    from sasctl.utils.pyml2ds.connectors.ensembles.pmml import PmmlTreeParser

    # Expected output contains integer values instead of floats.
    # Convert to ensure match.
    class TestPmmlTreeParser(PmmlTreeParser):
        def _split_value(self):
            val = super(TestPmmlTreeParser, self)._split_value()
            return int(float(val))

        def _leaf_value(self):
            val = super(TestPmmlTreeParser, self)._leaf_value()
            return int(float(val))

    test_parser = TestPmmlTreeParser()

    # Replace the parser class in its module so that instantiating it yields
    # the integer-converting parser created above.
    with mock.patch(
            'sasctl.utils.pyml2ds.connectors.ensembles.pmml.PmmlTreeParser'
    ) as parser:
        parser.return_value = test_parser
        result = pyml2ds(IN_PKL)

    # Use a context manager so the reference file is closed deterministically.
    with open(EXPECTED_SAS, 'r') as f:
        expected = f.read()
    assert result == expected
Exemplo n.º 3
0
def test_lgb2ds():
    """Convert the pickled LightGBM fixture and compare with the stored DATA step.

    The test is skipped when ``lightgbm`` cannot be imported.
    """
    pytest.importorskip('lightgbm')

    model_path = os.path.join(DATA_PATH, 'lgb.pkl')
    reference_path = os.path.join(DATA_PATH, 'lgb_datastep')

    from sasctl.utils.pyml2ds.connectors.ensembles.lgb import LightgbmTreeParser

    def _to_int(value):
        # The stored reference output holds integers where the parser
        # produces floats, so values are truncated to make them match.
        return int(float(value))

    class IntegerLightgbmTreeParser(LightgbmTreeParser):
        """Parser whose split and leaf values are converted to integers."""

        def _split_value(self):
            return _to_int(
                super(IntegerLightgbmTreeParser, self)._split_value())

        def _leaf_value(self):
            return _to_int(
                super(IntegerLightgbmTreeParser, self)._leaf_value())

    int_parser = IntegerLightgbmTreeParser()

    # While patched, instantiating the parser class yields ``int_parser``.
    patch_target = (
        'sasctl.utils.pyml2ds.connectors.ensembles.lgb.LightgbmTreeParser'
    )
    with mock.patch(patch_target, return_value=int_parser):
        actual = pyml2ds(model_path)

    with open(reference_path, 'r') as reference_file:
        wanted = reference_file.read()
    assert actual == wanted
Exemplo n.º 4
0
def test_pickle_input():
    """pyml2ds should accept a binary pickle string as input."""
    import pickle
    from sasctl.utils.pyml2ds import pyml2ds

    # A plain dict stands in for the "model"
    model = {'msg': 'hello world'}

    # Serialize the "model" to an in-memory byte string
    pickled_model = pickle.dumps(model)
    sas_path = 'model.sas'

    with mock.patch('sasctl.utils.pyml2ds.core._check_type') as type_check:
        type_check.translate.return_value = 'translated'
        pyml2ds(pickled_model, sas_path)

    # _check_type must have been invoked exactly once, and its first
    # positional argument must be the unpickled "model"
    assert type_check.call_count == 1
    positional_args = type_check.call_args[0]
    assert positional_args[0] == model
Exemplo n.º 5
0
def test_path_input(tmpdir_factory):
    """pyml2ds should accept a file path (str) as input."""
    import pickle
    from sasctl.utils.pyml2ds import pyml2ds

    # A plain dict stands in for the "model"
    model = {'msg': 'hello world'}

    # Both the pickled input and the output live in a fresh temp directory
    work_dir = tmpdir_factory.mktemp('pyml2ds')
    pickle_path = str(work_dir.join('model.pkl'))
    sas_path = str(work_dir.join('model.sas'))

    # Write the pickled "model" to disk
    with open(pickle_path, 'wb') as pickle_file:
        pickle.dump(model, pickle_file)

    with mock.patch('sasctl.utils.pyml2ds.core._check_type') as type_check:
        type_check.translate.return_value = 'translated'
        pyml2ds(pickle_path, sas_path)

    # _check_type must have been invoked exactly once, and its first
    # positional argument must be the "model" loaded from the file
    assert type_check.call_count == 1
    positional_args = type_check.call_args[0]
    assert positional_args[0] == model