def test__get_model_store(ms_class, config):
    """_get_model_store hands back the instance produced by the patched store class.

    ``ms_class`` is configured so that calling it yields a known Mock; after the
    caches are cleared, the first lookup must return exactly that object.
    """
    expected_store = mock.Mock()
    ms_class.return_value = expected_store
    ma.clear_caches()

    assert ma._get_model_store(config) is expected_store
def test__get_model_caching(ms_class, config):
    """_get_model memoizes loaded models unless called with ``cache=False``.

    Each scenario stubs the store's ``load_trained_model`` with a fresh
    (wrapper, meta) pair before every call, so the assertions can check by
    identity which load produced each returned object.
    """

    def stub_next_load():
        # Install a brand-new (wrapper, meta) pair as the next load result and
        # return it so the caller can compare identities against it.
        pair = (mock.Mock(), mock.MagicMock())
        ms_class.return_value.load_trained_model.return_value = pair
        return pair

    # caching: the second call must be served from the cache, i.e. it returns
    # the first loaded pair even though the stub now yields a different one.
    ma.clear_caches()
    first_wrapper, first_meta = stub_next_load()
    mw1, mm1 = ma._get_model(config)
    stub_next_load()
    mw2, mm2 = ma._get_model(config)
    assert mw1 is first_wrapper
    assert mm1 is first_meta
    assert mw2 is first_wrapper
    assert mm2 is first_meta

    # no caching: a cache=False call must not populate the cache, so the
    # following default call performs a fresh load and gets the second pair.
    ma.clear_caches()
    first_wrapper, first_meta = stub_next_load()
    mw1, mm1 = ma._get_model(config, cache=False)
    second_wrapper, second_meta = stub_next_load()
    mw2, mm2 = ma._get_model(config)
    assert mw1 is first_wrapper
    assert mm1 is first_meta
    assert mw2 is second_wrapper
    assert mm2 is second_meta
def test__get_model(ms_class, config):
    """_get_model returns exactly the (wrapper, meta) pair the store loads."""
    ma.clear_caches()
    expected_pair = (mock.Mock(), mock.MagicMock())
    ms_class.return_value.load_trained_model.return_value = expected_pair

    wrapper, meta = ma._get_model(config)

    expected_wrapper, expected_meta = expected_pair
    assert wrapper is expected_wrapper
    assert meta is expected_meta
def test_clear_caches():
    """clear_caches resets every module-level cache dict to empty."""
    # TODO: This is testing the implementation, should test the functionality instead
    cache_attrs = (
        "_cached_model_stores",
        "_cached_model_tuples",
        "_cached_data_source_sink_tuples",
        "_cached_model_makers",
        "_cached_model_classes",
    )
    # Seed each cache with a distinct non-empty dict.
    for attr in cache_attrs:
        setattr(ma, attr, {"a": 1})

    ma.clear_caches()

    for attr in cache_attrs:
        assert getattr(ma, attr) == {}
def test__get_model_store_caching(ms_class, config):
    """_get_model_store memoizes the store unless called with ``cache=False``.

    Unique sentinel objects stand in for store instances, so every check is an
    identity comparison on a distinct object. The original used small ints
    compared with ``is``.
    """
    first_store = mock.sentinel.first_store
    second_store = mock.sentinel.second_store

    # caching: the second lookup must come from the cache, so it returns the
    # first store even though the patched class now yields a different one.
    ma.clear_caches()
    ms_class.return_value = first_store
    ms1 = ma._get_model_store(config)
    ms_class.return_value = second_store
    ms2 = ma._get_model_store(config)
    assert ms1 is first_store
    assert ms2 is first_store

    # no caching: a cache=False lookup must not populate the cache, so the
    # following default lookup constructs (and returns) the second store.
    ma.clear_caches()
    ms_class.return_value = first_store
    ms1 = ma._get_model_store(config, cache=False)
    ms_class.return_value = second_store
    ms2 = ma._get_model_store(config)
    assert ms1 is first_store
    assert ms2 is second_store
def test__get_model_class(ms_class, imp, config):
    """_get_model_class resolves to MockModelClass via the patched ``imp`` hook.

    The ``imp`` fixture must have been called with "blamodule"; presumably that
    module name comes from ``config`` (NOTE(review): confirm in the fixture).
    """
    ma.clear_caches()
    resolved_class = ma._get_model_class(config)

    imp.assert_called_with("blamodule")
    assert resolved_class is MockModelClass
def test__get_model_maker(ms_class, imp, config):
    """_get_model_maker yields a MockModelMakerClass instance via the ``imp`` hook.

    The ``imp`` fixture must have been called with "blamodule"; presumably that
    module name comes from ``config`` (NOTE(review): confirm in the fixture).
    """
    ma.clear_caches()
    maker_instance = ma._get_model_maker(config)

    imp.assert_called_with("blamodule")
    assert isinstance(maker_instance, MockModelMakerClass)