def test_shared_cache(data_setup):
    """Check that two SharedChains built over one dict share fitted node state.

    A second chain is fitted, while ``main_chain.fit`` is never called; the root
    node of ``main_chain`` must nevertheless see a fitted state through its
    SharedCache. Clearing only that node's local cache must not hide the state,
    clearing the shared dict must, and appending the saved model back must make
    it actual again and store it under the root node's ``descriptive_id``.
    """
    train_data, _ = data_setup
    common_storage = {}
    main_chain = SharedChain(chain_first(), common_storage)
    donor_chain = SharedChain(chain_first(), common_storage)
    donor_chain.fit(train_data)

    def root_cache():
        # Re-read the full attribute chain on every call rather than binding the
        # cache object once, so each check sees whatever the chain exposes now.
        return main_chain.root_node.cache

    # the root node of a SharedChain carries a SharedCache
    assert isinstance(root_cache(), SharedCache)
    # fitting donor_chain already made the root state available to main_chain
    assert root_cache().actual_cached_state is not None

    remembered_model = root_cache().actual_cached_state
    root_cache().clear()
    # dropping the local cache alone must not hide the shared state
    assert root_cache().actual_cached_state is not None

    common_storage.clear()
    # once the shared dict is emptied the state is no longer actual
    assert root_cache().actual_cached_state is None

    root_cache().append(remembered_model)
    # manually appending the model makes the state actual again ...
    assert root_cache().actual_cached_state is not None
    # ... and puts it into the shared dict under the node's descriptive_id
    assert common_storage[main_chain.root_node.descriptive_id] == remembered_model
def test_chain_sharing_and_unsharing(data_setup):
    """Check that sharing and unsharing a chain swaps the cache type of every node.

    A freshly built chain uses FittedModelCache on all nodes, wrapping it in a
    SharedChain switches every node to SharedCache, and ``unshare()`` yields a
    ``Chain`` whose nodes use FittedModelCache again. The ``data_setup`` fixture
    is requested but not used by this test.
    """
    def every_node_uses(cache_type, some_chain):
        # True when each node of ``some_chain`` holds a cache of ``cache_type``.
        return all(isinstance(node.cache, cache_type) for node in some_chain.nodes)

    current = chain_first()
    assert every_node_uses(FittedModelCache, current)

    current = SharedChain(current, {})
    assert every_node_uses(SharedCache, current)

    current = current.unshare()
    assert every_node_uses(FittedModelCache, current)
    assert isinstance(current, Chain)
def metric_for_nodes(self, metric_function, train_data: InputData, test_data: InputData,
                     is_chain_shared: bool, chain: Chain) -> float:
    """Validate, fit and score a chain.

    ``validate(chain)`` is called first (presumably raising on an invalid chain —
    confirm against ``validate``). When ``is_chain_shared`` is true, the chain is
    wrapped in a ``SharedChain`` backed by ``self.shared_cache`` before fitting.

    :param metric_function: called as ``metric_function(fitted_chain, test_data)``;
        its result is returned unchanged.
    :param train_data: data passed to ``fit(input_data=...)``.
    :param test_data: data passed to ``metric_function``.
    :param is_chain_shared: whether to fit through a shared-cache wrapper.
    :param chain: the chain to evaluate.
    :return: the value produced by ``metric_function``.
    """
    validate(chain)
    candidate = (SharedChain(base_chain=chain, shared_cache=self.shared_cache)
                 if is_chain_shared else chain)
    candidate.fit(input_data=train_data)
    return metric_function(candidate, test_data)
def metric_for_nodes(self, metric_function, train_data: InputData, test_data: InputData,
                     is_chain_shared: bool, chain: Chain) -> float:
    """Validate, fit and score a chain, degrading to ``max_int_value`` on failure.

    When ``is_chain_shared`` is true, the chain is wrapped in a ``SharedChain``
    backed by ``self.shared_cache`` before fitting. Any exception raised by
    validation, fitting or ``metric_function`` is printed and swallowed, so that
    composition can continue.

    :param metric_function: called as ``metric_function(fitted_chain, test_data)``.
    :param train_data: data passed to ``fit(input_data=...)``.
    :param test_data: data passed to ``metric_function``.
    :param is_chain_shared: whether to fit through a shared-cache wrapper.
    :param chain: the chain to evaluate.
    :return: the metric value, or ``max_int_value`` if anything raised
        (presumably a "worst possible score" sentinel — confirm at its definition).
    """
    try:
        validate(chain)
        evaluated = chain
        if is_chain_shared:
            evaluated = SharedChain(base_chain=chain, shared_cache=self.shared_cache)
        evaluated.fit(input_data=train_data)
        score = metric_function(evaluated, test_data)
    except Exception as ex:
        # Deliberate best-effort: a broken candidate must not abort composition.
        print(f'Error in chain assessment during composition: {ex}. Continue.')
        return max_int_value
    return score
def test_multi_chain_caching_with_shared_cache(data_setup):
    """Check that fitting one SharedChain makes identical parts of another actual.

    Chains sharing one cache dict are built; only the "donor" chains are fitted.
    Nodes of ``main_chain`` that coincide with fitted donor nodes must report an
    actual cached state, while the nodes that differ must not, even though
    ``main_chain.fit`` is never called. The check is done for two chain setups.
    """
    train, _ = data_setup

    def is_actual(node):
        # A non-actual cache reports ``None`` (clearing the shared dict yields
        # ``actual_cached_state is None``), so compare against None explicitly
        # instead of relying on the truthiness of the cached model object.
        return node.cache.actual_cached_state is not None

    shared_cache = {}
    main_chain = SharedChain(base_chain=chain_second(), shared_cache=shared_cache)
    other_chain = SharedChain(base_chain=chain_first(), shared_cache=shared_cache)
    # fit other_chain, which contains parts identical to main_chain
    other_chain.fit(input_data=train)

    nodes_with_non_actual_cache = (
        [main_chain.root_node, main_chain.root_node.nodes_from[0]]
        + list(main_chain.root_node.nodes_from[0].nodes_from)
    )
    nodes_with_actual_cache = [node for node in main_chain.nodes
                               if node not in nodes_with_non_actual_cache]
    # the parts of main_chain identical to other_chain are fitted through the
    # shared cache, although main_chain.fit() was never called
    assert all(is_actual(node) for node in nodes_with_actual_cache)
    # the non-identical parts are still not fitted
    assert not any(is_actual(node) for node in nodes_with_non_actual_cache)

    # check the same case with other chains
    shared_cache = {}
    main_chain = SharedChain(base_chain=chain_fourth(), shared_cache=shared_cache)
    prev_chain_first = SharedChain(base_chain=chain_third(), shared_cache=shared_cache)
    prev_chain_second = SharedChain(base_chain=chain_fifth(), shared_cache=shared_cache)
    prev_chain_first.fit(input_data=train)
    prev_chain_second.fit(input_data=train)

    nodes_with_non_actual_cache = [main_chain.root_node, main_chain.root_node.nodes_from[1]]
    nodes_with_actual_cache = list(main_chain.root_node.nodes_from[0].nodes_from)
    assert not any(is_actual(node) for node in nodes_with_non_actual_cache)
    assert all(is_actual(node) for node in nodes_with_actual_cache)