Пример #1
0
def test_shared_cache(data_setup):
    """Check how the root-node cache of a SharedChain follows the shared dict.

    Two chains are wrapped around the same dict; only the second one is
    fitted, and every assertion is made against the first one's root node.
    """
    train_data, _ = data_setup

    storage = {}
    observed_chain = SharedChain(chain_first(), storage)
    fitted_chain = SharedChain(chain_first(), storage)
    fitted_chain.fit(train_data)

    def root_cache():
        # looked up afresh on every call, just like an inline attribute access
        return observed_chain.root_node.cache

    # the root node of the shared chain carries a SharedCache ...
    assert isinstance(root_cache(), SharedCache)
    # ... which already exposes a state although only the sibling was fitted
    assert root_cache().actual_cached_state is not None

    stored_state = root_cache().actual_cached_state
    root_cache().clear()
    # clearing through the node's cache still leaves a state available
    assert root_cache().actual_cached_state is not None

    storage.clear()
    # emptying the shared dict itself makes the state disappear
    assert root_cache().actual_cached_state is None

    root_cache().append(stored_state)
    # appending the saved state by hand makes it available again ...
    assert root_cache().actual_cached_state is not None
    # ... and it is stored in the shared dict under the root's descriptive_id
    assert storage[observed_chain.root_node.descriptive_id] == stored_state
Пример #2
0
def test_chain_sharing_and_unsharing(data_setup):
    """Check node cache types before sharing, after sharing and after unsharing."""

    def every_cache_is(chain_under_test, cache_type):
        # True when each node of the chain holds a cache of the given type
        return all([isinstance(node.cache, cache_type)
                    for node in chain_under_test.nodes])

    plain_chain = chain_first()
    assert every_cache_is(plain_chain, FittedModelCache)

    shared_chain = SharedChain(plain_chain, {})
    assert every_cache_is(shared_chain, SharedCache)

    restored_chain = shared_chain.unshare()
    assert every_cache_is(restored_chain, FittedModelCache)
    # unshare() must hand back a Chain instance
    assert isinstance(restored_chain, Chain)
Пример #3
0
 def metric_for_nodes(
         self,
         metric_function,
         train_data: InputData,
         test_data: InputData,
         is_chain_shared: bool,
         chain: Chain,
 ) -> float:
     """Validate and fit ``chain``, then score it with ``metric_function``.

     :param metric_function: callable invoked as ``metric_function(chain, test_data)``
     :param train_data: data passed to the chain's ``fit``
     :param test_data: data handed over to ``metric_function``
     :param is_chain_shared: when truthy, the chain is wrapped into a
         ``SharedChain`` backed by ``self.shared_cache`` before fitting
     :param chain: chain to assess
     :return: the value returned by ``metric_function``
     """
     validate(chain)

     # validation always runs on the chain as passed in; wrapping happens afterwards
     assessed_chain = (
         SharedChain(base_chain=chain, shared_cache=self.shared_cache)
         if is_chain_shared
         else chain
     )
     assessed_chain.fit(input_data=train_data)

     score = metric_function(assessed_chain, test_data)
     return score
Пример #4
0
 def metric_for_nodes(
         self,
         metric_function,
         train_data: InputData,
         test_data: InputData,
         is_chain_shared: bool,
         chain: Chain,
 ) -> float:
     """Validate and fit ``chain``, then score it with ``metric_function``.

     Any ``Exception`` raised while validating, wrapping, fitting or scoring
     is reported with ``print`` and ``max_int_value`` is returned instead of
     a score.

     :param metric_function: callable invoked as ``metric_function(chain, test_data)``
     :param train_data: data passed to the chain's ``fit``
     :param test_data: data handed over to ``metric_function``
     :param is_chain_shared: when truthy, the chain is wrapped into a
         ``SharedChain`` backed by ``self.shared_cache`` before fitting
     :param chain: chain to assess
     :return: the value returned by ``metric_function``, or ``max_int_value``
         when an exception was caught
     """
     try:
         validate(chain)
         assessed_chain = (
             SharedChain(base_chain=chain, shared_cache=self.shared_cache)
             if is_chain_shared
             else chain
         )
         assessed_chain.fit(input_data=train_data)
         score = metric_function(assessed_chain, test_data)
     except Exception as ex:
         # best-effort: report the failure and fall back to the sentinel value
         message = f'Error in chain assessment during composition: {ex}. Continue.'
         print(message)
         return max_int_value
     return score
Пример #5
0
def test_multi_chain_caching_with_shared_cache(data_setup):
    """Check which nodes of a never-fitted SharedChain see a cached state.

    In both scenarios the chain under inspection is not fitted itself; only
    other SharedChain instances built on the same dict are fitted.
    """
    train_data, _ = data_setup

    def cached_states(nodes):
        # one entry per node, in order, so all()/any() see the complete list
        return [node.cache.actual_cached_state for node in nodes]

    # scenario 1: a single sibling chain is fitted
    storage = {}
    target_chain = SharedChain(base_chain=chain_second(), shared_cache=storage)
    sibling_chain = SharedChain(base_chain=chain_first(), shared_cache=storage)
    sibling_chain.fit(input_data=train_data)

    # root, its first parent and that parent's own parents are expected to stay unfitted
    expected_unfitted = [
        target_chain.root_node,
        target_chain.root_node.nodes_from[0],
        *target_chain.root_node.nodes_from[0].nodes_from,
    ]
    expected_fitted = [node for node in target_chain.nodes
                       if node not in expected_unfitted]

    # the remaining nodes have a cached state although target_chain.fit() was never called
    assert all(cached_states(expected_fitted))
    # the listed nodes have no cached state
    assert not any(cached_states(expected_unfitted))

    # scenario 2: a fresh shared dict and two fitted sibling chains
    storage = {}
    target_chain = SharedChain(base_chain=chain_fourth(), shared_cache=storage)

    first_sibling = SharedChain(base_chain=chain_third(), shared_cache=storage)
    second_sibling = SharedChain(base_chain=chain_fifth(), shared_cache=storage)
    first_sibling.fit(input_data=train_data)
    second_sibling.fit(input_data=train_data)

    expected_unfitted = [target_chain.root_node,
                         target_chain.root_node.nodes_from[1]]
    expected_fitted = list(target_chain.root_node.nodes_from[0].nodes_from)

    assert not any(cached_states(expected_unfitted))
    assert all(cached_states(expected_fitted))