コード例 #1
def test_import_json_to_fitted_chain_template_correctly():
    """Importing a fitted-chain JSON file and exporting it again must round-trip.

    Loads ``test_fitted_chain_convert_to_json`` into an empty ``Chain`` through
    ``ChainTemplate``. It then converts the template back to a dict and checks
    that the result serialises identically to the raw contents of the same file.
    """
    json_path_load = create_correct_path('test_fitted_chain_convert_to_json')

    chain = Chain()
    chain_template = ChainTemplate(chain)
    # Populate the empty chain from the JSON file. Without this step the
    # template describes an empty chain and no import is exercised at all.
    # NOTE(review): the import method is not visible in this file;
    # `import_chain` matches upstream FEDOT's ChainTemplate API - confirm.
    chain_template.import_chain(json_path_load)
    json_actual = chain_template.convert_to_dict()

    with open(json_path_load, 'r') as json_file:
        json_expected = json.load(json_file)

    # Compare serialised forms; json.dumps renders tuples and lists alike.
    assert json.dumps(json_actual) == json.dumps(json_expected)
コード例 #2
def test_import_json_template_to_chain_correctly():
    """A chain imported from JSON must match the reference chain built in code.

    Loads ``test_chain_convert_to_json`` into an empty ``Chain`` through
    ``ChainTemplate`` and compares its dict form with the dict form of
    ``create_chain()``.
    """
    json_path_load = create_correct_path('test_chain_convert_to_json')

    chain = Chain()
    chain_template = ChainTemplate(chain)
    # Populate the empty chain from the JSON file. Without this step
    # `json_path_load` is unused and an empty chain is compared with
    # `create_chain()`.
    # NOTE(review): the import method is not visible in this file;
    # `import_chain` matches upstream FEDOT's ChainTemplate API - confirm.
    chain_template.import_chain(json_path_load)
    json_actual = chain_template.convert_to_dict()

    chain_expected = create_chain()
    chain_expected_template = ChainTemplate(chain_expected)
    json_expected = chain_expected_template.convert_to_dict()

    # Compare serialised forms; json.dumps renders tuples and lists alike.
    assert json.dumps(json_actual) == json.dumps(json_expected)
コード例 #3
# NOTE(review): this copy of `Chain` appears to be a lossy extract (e.g. from a
# code-example listing). Docstring quotes, decorators, `else:` lines, loop
# bodies and closing brackets are missing, so it is not valid Python as shown.
# The NOTE(review) comments below mark the visible gaps. Restore them from the
# upstream FEDOT source before relying on this class.
class Chain:
    # NOTE(review): the class docstring below has lost its triple quotes.
    Base class used for composite model structure definition

    :param nodes: Node object(s)
    :param log: Log object to record messages

    .. note::
        fitted_on_data stores the data which were used in last chain fitting (equals None if chain hasn't been
        fitted yet)
    def __init__(self,
                 nodes: Optional[Union[Node, List[Node]]] = None,
                 log: Log = None):
        self.nodes = []
        self.log = log
        self.template = None
        self.computation_time = None
        # Graph helper; update_node, update_subtree and root_node delegate to it.
        self.operator = GraphOperator(self)
        if not log:
            self.log = default_log(__name__)
            # NOTE(review): an `else:` line appears to be missing here. As shown,
            # the next line overwrites the default logger with the falsy `log`.
            self.log = log

        if nodes:
            if isinstance(nodes, list):
                for node in nodes:
                    # NOTE(review): the loop body is missing, as is (presumably)
                    # the branch that handles a single Node passed without a list.
        # NOTE(review): initialised to an empty dict, whereas the class
        # docstring says it equals None before the first fit.
        self.fitted_on_data = {}

    def fit_from_scratch(self, input_data: InputData = None):
        Method used for training the chain without using cached information

        :param input_data: data used for operation training
        # Clean all cache and fit all operations
        self.log.info('Fit chain from scratch')
        self.fit(input_data, use_cache=False)

    def update_fitted_on_data(self, data: InputData):
        # Record the characteristics of the training data. A later fit on data
        # with different characteristics then disables the cache (see
        # _cache_status_if_new_data).
        characteristics = input_data_characteristics(data=data, log=self.log)
        self.fitted_on_data['data_type'] = characteristics[0]
        self.fitted_on_data['features_hash'] = characteristics[1]
        self.fitted_on_data['target_hash'] = characteristics[2]

    def _cache_status_if_new_data(self, new_input_data: InputData,
                                  cache_status: bool):
        # Return False when caching was requested but the new data's
        # characteristics differ from those stored in fitted_on_data.
        # Otherwise return cache_status unchanged.
        new_data_params = input_data_characteristics(new_input_data,
        # NOTE(review): the call above is unterminated; its remaining
        # arguments are missing in this copy.
        if cache_status and self.fitted_on_data:
            params_names = ('data_type', 'features_hash', 'target_hash')
            are_data_params_different = any([
                new_data_param != self.fitted_on_data[param_name]
                for new_data_param, param_name in zip(new_data_params,
            # NOTE(review): the second zip argument (presumably params_names)
            # and the closing brackets are missing.
            if are_data_params_different:
                info = 'Trained operation cache is not actual because you are using new dataset for training. ' \
                       'Parameter use_cache value changed to False'
                # NOTE(review): `info` is built but never logged in this copy.
                cache_status = False
        return cache_status

    def _fit_with_time_limit(
        # NOTE(review): `self` and `use_cache` appear to be missing from this
        # signature; both are used in the body.
        input_data: Optional[InputData] = None,
        time: timedelta = timedelta(minutes=3)) -> Manager:
        Run training process with time limit. Create

        :param input_data: data used for operation training
        :param use_cache: flag defining whether use cache information about previous executions or not, default True
        :param time: time constraint for operation fitting process (seconds)
        # `time` arrives as a timedelta and is truncated to whole seconds here.
        time = int(time.total_seconds())
        # Manager-backed containers carry the child process's results back;
        # they are read below once the process has finished.
        manager = Manager()
        process_state_dict = manager.dict()
        fitted_operations = manager.list()
        p = Process(target=self._fit,
                    args=(input_data, use_cache, process_state_dict,
        # NOTE(review): the Process args are unterminated, and no
        # p.start() / p.join(time) is visible before the liveness check.
        if p.is_alive():
            raise TimeoutError(
                f'Chain fitness evaluation time limit is expired')

        # Copy the state produced in the child process back onto this chain.
        self.fitted_on_data = process_state_dict['fitted_on_data']
        self.computation_time = process_state_dict['computation_time']
        for node_num, node in enumerate(self.nodes):
            self.nodes[node_num].fitted_operation = fitted_operations[node_num]
        # NOTE(review): annotated `-> Manager`, but the value returned is the
        # stored `train_predicted` entry.
        return process_state_dict['train_predicted']

    def _fit(self,
             # NOTE(review): a `use_cache` parameter appears to be missing here;
             # it is documented below and read in the body.
             input_data: InputData,
             process_state_dict: Manager = None,
             fitted_operations: Manager = None):
        Run training process in all nodes in chain starting with root.

        :param input_data: data used for operation training
        :param use_cache: flag defining whether use cache information about previous executions or not, default True
        :param process_state_dict: this dictionary is used for saving required chain parameters (which were changed
        inside the process) in a case of operation fit time control (when process created)
        :param fitted_operations: this list is used for saving fitted operations of chain nodes

        # InputData was set directly to the primary nodes
        if input_data is None:
            use_cache = False
            # NOTE(review): an `else:` line appears to be missing here. As shown,
            # the cache check below runs inside the `input_data is None` branch.
            use_cache = self._cache_status_if_new_data(
                new_input_data=input_data, cache_status=use_cache)

            if not use_cache or not self.fitted_on_data:
                # Don't use previous information
                # NOTE(review): the body of this `if` is missing.

        with Timer(log=self.log) as t:
            # Record a new computation_time when caching is off, the root has no
            # fitted operation yet, or no time has been recorded before.
            computation_time_update = not use_cache or not self.root_node.fitted_operation or \
                                      self.computation_time is None

            train_predicted = self.root_node.fit(input_data=input_data)
            if computation_time_update:
                self.computation_time = round(t.minutes_from_start, 3)

        # With no shared state dict, return the prediction directly. Otherwise
        # store results in it for the parent process (see _fit_with_time_limit).
        if process_state_dict is None:
            return train_predicted
            # NOTE(review): an `else:` line appears to be missing; as shown, the
            # lines below are unreachable.
            process_state_dict['train_predicted'] = train_predicted
            process_state_dict['computation_time'] = self.computation_time
            process_state_dict['fitted_on_data'] = self.fitted_on_data
            for node in self.nodes:
                # NOTE(review): loop body missing (presumably appends each node's
                # fitted_operation to fitted_operations, which
                # _fit_with_time_limit reads back by index).

    def fit(self,
            # NOTE(review): a `use_cache` parameter appears to be missing here;
            # it is documented below and read in the body.
            input_data: Optional[InputData] = None,
            time_constraint: Optional[timedelta] = None):
        Run training process in all nodes in chain starting with root.

        :param input_data: data used for operation training
        :param use_cache: flag defining whether use cache information about previous executions or not, default True
        :param time_constraint: time constraint for operation fitting (seconds)
        if not use_cache:
            # NOTE(review): the body of this `if` is missing.

        # `time_constraint` is a timedelta despite "(seconds)" above;
        # _fit_with_time_limit converts it to seconds.
        if time_constraint is None:
            train_predicted = self._fit(input_data=input_data,
            # NOTE(review): the call above is unterminated, and an `else:` line
            # is missing before the time-limited branch below.
            train_predicted = self._fit_with_time_limit(input_data=input_data,
            # NOTE(review): this call is unterminated as well.
        return train_predicted

    def predict(self,
                input_data: InputData = None,
                output_mode: str = 'default'):
        Run the predict process in all nodes in chain starting with root.

        :param input_data: data for prediction
        :param output_mode: desired form of output for operations. Available options are:
                'default' (as is),
                'labels' (numbers of classes - for classification) ,
                'probs' (probabilities - for classification =='default'),
                'full_probs' (return all probabilities - for binary classification).
        :return: OutputData with prediction

        # NOTE(review): `is_fitted` is defined below without a visible
        # `@property`. If it is a plain method, `not self.is_fitted` is always
        # False and this check never fires.
        if not self.is_fitted:
            ex = 'Trained operation cache is not actual or empty'
            raise ValueError(ex)

        result = self.root_node.predict(input_data=input_data,
        # NOTE(review): the call above is unterminated (presumably output_mode
        # is passed here).
        return result

    def fine_tune_all_nodes(self,
                            loss_function: Callable,
                            loss_params: Callable = None,
                            input_data: Optional[InputData] = None,
                            max_lead_time: int = 5) -> 'Chain':
        # NOTE(review): the docstring below is never closed in this copy, so it
        # swallows everything up to the next triple quote.
        """ Tune all hyperparameters of nodes simultaneously via black-box
            optimization using ChainTuner. For details, see
        # `max_lead_time` is given in minutes and converted to a timedelta.
        max_lead_time = timedelta(minutes=max_lead_time)
        chain_tuner = ChainTuner(chain=self,
        # NOTE(review): the ChainTuner(...) and tune_chain(...) calls are
        # unterminated. The first log message says "primary nodes" although
        # this method tunes all nodes.
        self.log.info('Start tuning of primary nodes')
        tuned_chain = chain_tuner.tune_chain(input_data=input_data,
        self.log.info('Tuning was finished')

        return tuned_chain

    def add_node(self, new_node: Node):
        Add new node to the Chain

        :param node: new Node object
        # NOTE(review): the implementation is missing in this copy, and the
        # docstring names `node` while the parameter is `new_node`.

    def update_node(self, old_node: Node, new_node: Node):
        Replace old_node with new one.

        :param old_node: Node object to replace
        :param new_node: Node object to replace

        # NOTE(review): the `new_node` description above looks copy-pasted; it is
        # the replacement node. The edit itself is delegated to GraphOperator.
        self.operator.update_node(old_node, new_node)

    def update_subtree(self, old_subroot: Node, new_subroot: Node):
        Replace the subtrees with old and new nodes as subroots

        :param old_subroot: Node object to replace
        :param new_subroot: Node object to replace
        # Delegated to GraphOperator.
        self.operator.update_subtree(old_subroot, new_subroot)

    def delete_node(self, node: Node):
        Delete chosen node redirecting all its parents to the child.

        :param node: Node object to delete
        # NOTE(review): the implementation is missing in this copy.


    def delete_subtree(self, subroot: Node):
        Delete the subtree with node as subroot.

        :param subroot:
        # NOTE(review): the implementation is missing in this copy.

    def is_fitted(self):
        # True only when every node has a fitted operation (vacuously True for
        # a chain with no nodes).
        # NOTE(review): predict() reads this as an attribute, so a `@property`
        # decorator has presumably been lost in this copy.
        return all([(node.fitted_operation is not None)
                    for node in self.nodes])

    def unfit(self):
        Remove fitted operations for all nodes.
        for node in self.nodes:
            # NOTE(review): the loop body is missing in this copy.

    def fit_from_cache(self, cache):
        # Restore each node's fitted operation from `cache` when an entry exists.
        for node in self.nodes:
            cached_state = cache.get(node)
            if cached_state:
                node.fitted_operation = cached_state.operation
                # NOTE(review): an `else:` line appears to be missing; as shown,
                # every cached operation is immediately overwritten with None.
                node.fitted_operation = None

    def save(self, path: str):
        Save the chain to the json representation with pickled fitted operations.

        :param path to json file with operation
        :return: json containing a composite operation description
        # Reuse a template created earlier (e.g. by load), otherwise build one.
        if not self.template:
            self.template = ChainTemplate(self, self.log)
        json_object = self.template.export_chain(path)
        return json_object

    def load(self, path: str):
        Load the chain the json representation with pickled fitted operations.

        :param path to json file with operation
        self.nodes = []
        self.template = ChainTemplate(self, self.log)
        # NOTE(review): `path` is never used in this copy; the step that imports
        # the JSON file into the template appears to be missing.

    def show(self, path: str = None):
        ChainVisualiser().visualise(self, path)

    def __eq__(self, other) -> bool:
        # Chains compare equal when their root nodes share a descriptive_id.
        # NOTE(review): `other` must expose `root_node`; any other object raises
        # AttributeError instead of returning NotImplemented. No __hash__ is
        # visible, so defining __eq__ leaves Chain instances unhashable.
        return self.root_node.descriptive_id == other.root_node.descriptive_id

    def __str__(self):
        description = {
            'depth': self.depth,
            'length': self.length,
            'nodes': self.nodes,
        # NOTE(review): the dict literal is not closed in this copy.
        return f'{description}'

    def __repr__(self):
        return self.__str__()

    def root_node(self) -> Optional[Node]:
        # NOTE(review): used as an attribute elsewhere (e.g. `self.root_node.fit`),
        # so a `@property` decorator has presumably been lost in this copy.
        if len(self.nodes) == 0:
            return None
        # The root is the node with no children.
        root = [
            node for node in self.nodes
            if not any(self.operator.node_children(node))
        # NOTE(review): the list comprehension is not closed in this copy.
        if len(root) > 1:
            raise ValueError(f'{ERROR_PREFIX} More than 1 root_nodes in chain')
        return root[0]

    def length(self) -> int:
        # NOTE(review): read as `self.length` in __str__; presumably a @property.
        return len(self.nodes)

    def depth(self) -> int:
        # Presumably the number of nodes on the longest root-to-PrimaryNode path;
        # the recursive step is incomplete in this copy.
        # NOTE(review): read as `self.depth` in __str__; presumably a @property.
        def _depth_recursive(node):
            if node is None:
                return 0
            if isinstance(node, PrimaryNode):
                return 1
                # NOTE(review): an `else:` line and the comprehension's element
                # expression (presumably `_depth_recursive(next_node)`) are missing.
                return 1 + max([
                    for next_node in node.nodes_from

        return _depth_recursive(self.root_node)
コード例 #4
ファイル: chain.py プロジェクト: STATAN/FEDOT
# NOTE(review): this copy of `Chain` appears to be a lossy extract (e.g. from a
# code-example listing). Docstring quotes, decorators, `else:` lines, loop
# bodies and closing brackets are missing, so it is not valid Python as shown.
# The NOTE(review) comments below mark the visible gaps. Restore them from the
# upstream FEDOT source before relying on this class.
class Chain:
    # NOTE(review): the class docstring below has lost its triple quotes.
    Base class used for composite model structure definition

    :param nodes: Node object(s)
    :param log: Log object to record messages

    .. note::
        fitted_on_data stores the data which were used in last chain fitting (equals None if chain hasn't been
        fitted yet)
    def __init__(self,
                 nodes: Optional[Union[Node, List[Node]]] = None,
                 log: Log = None):
        self.nodes = []
        self.log = log
        self.template = None

        if not log:
            self.log = default_log(__name__)
            # NOTE(review): an `else:` line appears to be missing here. As shown,
            # the next line overwrites the default logger with the falsy `log`.
            self.log = log

        if nodes:
            if isinstance(nodes, list):
                for node in nodes:
                    # NOTE(review): the loop body is missing, as is (presumably)
                    # the branch that handles a single Node passed without a list.
        self.fitted_on_data = None

    def fit_from_scratch(self, input_data: InputData, verbose=False):
        Method used for training the chain without using cached information

        :param input_data: data used for model training
        :param verbose: flag used for status printing to console, default False
        # Clean all cache and fit all models
        self.log.info('Fit chain from scratch')
        self.fit(input_data, use_cache=False, verbose=verbose)

    def cache_status_if_new_data(self, new_input_data: InputData,
                                 cache_status: bool):
        # Identity (not equality) check: a different InputData object with the
        # same content still disables the cache once the chain has been fitted.
        if self.fitted_on_data is not None and self.fitted_on_data is not new_input_data:
            if cache_status:
                # NOTE(review): the opening of a call (presumably
                # `self.log.info(`) is missing before these string fragments.
                    'Trained model cache is not actual because you are using new dataset for training. '
                    'Parameter use_cache value changed to False')
                cache_status = False
        return cache_status

    def fit(self, input_data: InputData, use_cache=True, verbose=False):
        Run training process in all nodes in chain starting with root.

        :param input_data: data used for model training
        :param use_cache: flag defining whether use cache information about previous executions or not, default True
        :param verbose: flag used for status printing to console, default False
        use_cache = self.cache_status_if_new_data(new_input_data=input_data,
        # NOTE(review): the call above is unterminated in this copy.

        if not use_cache:
            # NOTE(review): the body of this `if` is missing.

        # For time-series forecasting, adjust the task params before fitting.
        # NOTE: this mutates the caller's input_data.task in place.
        if input_data.task.task_type == TaskTypesEnum.ts_forecasting:
            if input_data.task.task_params.make_future_prediction:
                input_data.task.task_params.return_all_steps = True
            # the make_future_prediction is useless for the fit stage
            input_data.task.task_params.make_future_prediction = False

        if not use_cache or self.fitted_on_data is None:
            self.fitted_on_data = input_data
        train_predicted = self.root_node.fit(input_data=input_data,
        # NOTE(review): the call above is unterminated in this copy.
        return train_predicted

    def predict(self, input_data: InputData, output_mode: str = 'default'):
        Run the predict process in all nodes in chain starting with root.

        :param input_data: data for prediction
        :param output_mode: desired form of output for models. Available options are:
                'default' (as is),
                'labels' (numbers of classes - for classification) ,
                'probs' (probabilities - for classification =='default'),
                'full_probs' (return all probabilities - for binary classification).
        :return: array of predicted target values

        if not self.is_all_cache_actual():
            ex = 'Trained model cache is not actual or empty'
            raise ValueError(ex)

        result = self.root_node.predict(input_data=input_data,
        # NOTE(review): the call above is unterminated (presumably output_mode
        # is passed here).
        return result

    def fine_tune_primary_nodes(
            # NOTE(review): `self` and a `verbose` parameter (documented below and
            # read in the body) appear to be missing, as does the closing `):`.
            input_data: InputData,
            iterations: int = 30,
            max_lead_time: timedelta = timedelta(minutes=5),
        Optimize hyperparameters in primary nodes models

        :param input_data: data used for tuning
        :param iterations: max number of iterations
        :param max_lead_time: max time available for tuning process
        :param verbose: flag used for status printing to console, default False
        # Select all primary nodes
        # Perform fine-tuning for each model in node
        if verbose:
            self.log.info('Start tuning of primary nodes')

        all_primary_nodes = [
            node for node in self.nodes if isinstance(node, PrimaryNode)
        # NOTE(review): the list is not closed, and the loop below has no body.
        for node in all_primary_nodes:

        if verbose:
            self.log.info('End tuning')

    def fine_tune_all_nodes(self,
                            input_data: InputData,
                            iterations: int = 30,
                            max_lead_time: timedelta = timedelta(minutes=5),
                            # NOTE(review): a `verbose` parameter and the closing
                            # `):` appear to be missing here.
        Optimize hyperparameters in all nodes models

        :param input_data: data used for tuning
        :param iterations: max number of iterations
        :param max_lead_time: max time available for tuning process
        :param verbose: flag used for status printing to console, default False
        if verbose:
            self.log.info('Start tuning of chain')

        node = self.root_node
        # NOTE(review): `node` is assigned, but the tuning call that uses it is
        # missing in this copy.

        if verbose:
            self.log.info('End tuning')

    def add_node(self, new_node: Node):
        Add new node to the Chain

        :param new_node: new Node object
        if new_node not in self.nodes:
            if new_node.nodes_from:
                for new_parent_node in new_node.nodes_from:
                    if new_parent_node not in self.nodes:
                        # NOTE(review): the body of this `if`, and the step that
                        # adds `new_node` itself, are missing in this copy.

    def _actualise_old_node_childs(self, old_node: Node, new_node: Node):
        # Appears to re-point each child of old_node so that it references
        # new_node instead.
        old_node_offspring = self.node_childs(old_node)
        for old_node_child in old_node_offspring:
            # NOTE(review): the start of the next statement is missing; it
            # presumably indexes old_node_child.nodes_from by old_node.
                old_node)] = new_node

    def replace_node_with_parents(self, old_node: Node, new_node: Node):
        # Work on a copy so the caller's node graph is not shared with this chain.
        new_node = deepcopy(new_node)
        self._actualise_old_node_childs(old_node, new_node)
        # Keep new_node's hierarchy entries not already in the chain, plus the
        # chain's nodes outside old_node's hierarchy.
        new_nodes = [
            parent for parent in new_node.ordered_subnodes_hierarchy
            if not parent in self.nodes
        # NOTE(review): the list above is not closed in this copy.
        old_nodes = [
            node for node in self.nodes
            if not node in old_node.ordered_subnodes_hierarchy
        # NOTE(review): the list above is not closed in this copy.
        self.nodes = new_nodes + old_nodes

    def update_node(self, old_node: Node, new_node: Node):
        self._actualise_old_node_childs(old_node, new_node)
        new_node.nodes_from = old_node.nodes_from
        # NOTE(review): this copy does not show `new_node` replacing `old_node`
        # in self.nodes.

    def delete_node(self, node: Node):
        for node_child in self.node_childs(node):
            # NOTE(review): the loop body is missing in this copy.
        for subtree_node in node.ordered_subnodes_hierarchy:
            # NOTE(review): the loop body is missing in this copy.

    def _clean_model_cache(self):
        # Give every node a fresh FittedModelCache.
        for node in self.nodes:
            node.cache = FittedModelCache(node)

    def is_all_cache_actual(self):
        # True only when every node has an actual cached state (vacuously True
        # for a chain with no nodes).
        cache_status = [
            node.cache.actual_cached_state is not None for node in self.nodes
        # NOTE(review): the list above is not closed in this copy.
        return all(cache_status)

    def node_childs(self, node) -> List[Optional[Node]]:
        # Secondary nodes that list `node` among their parents (nodes_from).
        return [
            other_node for other_node in self.nodes
            if isinstance(other_node, SecondaryNode)
            and node in other_node.nodes_from
        # NOTE(review): the list above is not closed in this copy.

    def _is_node_has_child(self, node) -> bool:
        return any(self.node_childs(node))

    def import_cache(self, fitted_chain: 'Chain'):
        # For nodes without an actual cached state, look for a node with the
        # same descriptive_id in `fitted_chain`.
        for node in self.nodes:
            if not node.cache.actual_cached_state:
                for fitted_node in fitted_chain.nodes:
                    if fitted_node.descriptive_id == node.descriptive_id:
                        # NOTE(review): the body of this `if` is missing.

    # TODO why trees visualisation is incorrect?
    def sort_nodes(self):
        """Replace self.nodes with root_node.ordered_subnodes_hierarchy (layer by layer)."""
        nodes = self.root_node.ordered_subnodes_hierarchy
        self.nodes = nodes

    def save_chain(self, path: str):
        # Reuse a template created earlier (e.g. by load_chain), otherwise build one.
        if not self.template:
            self.template = ChainTemplate(self, self.log)
        json_object = self.template.export_chain(path)
        return json_object

    def load_chain(self, path: str):
        self.nodes = []
        self.template = ChainTemplate(self, self.log)
        # NOTE(review): `path` is never used in this copy; the step that imports
        # the JSON file into the template appears to be missing.

    def __eq__(self, other) -> bool:
        # Chains compare equal when their root nodes share a descriptive_id.
        # NOTE(review): `other` must expose `root_node`; any other object raises
        # AttributeError instead of returning NotImplemented. No __hash__ is
        # visible, so defining __eq__ leaves Chain instances unhashable.
        return self.root_node.descriptive_id == other.root_node.descriptive_id

    def root_node(self) -> Optional[Node]:
        # NOTE(review): used as an attribute elsewhere (e.g. `self.root_node.fit`),
        # so a `@property` decorator has presumably been lost in this copy.
        if len(self.nodes) == 0:
            return None
        # The root is the node with no children.
        root = [
            node for node in self.nodes if not self._is_node_has_child(node)
        # NOTE(review): the list above is not closed in this copy.
        if len(root) > 1:
            raise ValueError(f'{ERROR_PREFIX} More than 1 root_nodes in chain')
        return root[0]

    def length(self) -> int:
        # NOTE(review): presumably a @property in the original source.
        return len(self.nodes)

    def depth(self) -> int:
        # Presumably the number of nodes on the longest root-to-PrimaryNode path;
        # the recursive step is incomplete in this copy.
        def _depth_recursive(node):
            if node is None:
                return 0
            if isinstance(node, PrimaryNode):
                return 1
                # NOTE(review): an `else:` line and the comprehension's element
                # expression (presumably `_depth_recursive(next_node)`) are missing.
                return 1 + max([
                    for next_node in node.nodes_from

        return _depth_recursive(self.root_node)