Example #1
def test_import_json_to_fitted_chain_template_correctly():
    json_path_load = create_correct_path('test_fitted_chain_convert_to_json')

    chain = Chain()
    chain_template = ChainTemplate(chain)
    chain_template.import_chain(json_path_load)
    json_actual = chain_template.convert_to_dict()

    with open(json_path_load, 'r') as json_file:
        json_expected = json.load(json_file)

    assert json.dumps(json_actual) == json.dumps(json_expected)
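The fixture above relies on ChainTemplate being able to round-trip a chain through JSON. A minimal sketch of that round trip is shown below; the import paths, the node operation names and the file name are assumptions for illustration, not part of the original test suite.

# Round-trip sketch (assumed imports, operation names and file name):
from fedot.core.chains.chain import Chain
from fedot.core.chains.chain_template import ChainTemplate
from fedot.core.chains.node import PrimaryNode, SecondaryNode

chain = Chain(SecondaryNode('logit', nodes_from=[PrimaryNode('scaling')]))
ChainTemplate(chain).export_chain('chain.json')      # serialise the structure (and any fitted operations)

restored = Chain()
restored_template = ChainTemplate(restored)
restored_template.import_chain('chain.json')         # rebuild the chain from the JSON file
assert restored_template.convert_to_dict() == ChainTemplate(chain).convert_to_dict()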
Example #2
def test_import_json_template_to_chain_correctly():
    json_path_load = create_correct_path('test_chain_convert_to_json')

    chain = Chain()
    chain_template = ChainTemplate(chain)
    chain_template.import_chain(json_path_load)
    json_actual = chain_template.convert_to_dict()

    chain_expected = create_chain()
    chain_expected_template = ChainTemplate(chain_expected)
    json_expected = chain_expected_template.convert_to_dict()

    assert json.dumps(json_actual) == json.dumps(json_expected)
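create_chain() is an external fixture that is not shown here. A plausible minimal stand-in, assuming PrimaryNode/SecondaryNode constructors that accept an operation name (as in the previous sketch), could look like this; the operation names are assumptions.

# Hypothetical stand-in for the create_chain() fixture (operation names are assumed):
def create_chain() -> Chain:
    first = PrimaryNode('scaling')
    second = PrimaryNode('pca')
    final = SecondaryNode('logit', nodes_from=[first, second])
    return Chain(final)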
Example #3
class Chain:
    """
    Base class used for composite model structure definition

    :param nodes: Node object(s)
    :param log: Log object to record messages

    .. note::
        fitted_on_data stores characteristics of the data used in the last chain fitting
        (the dict is empty if the chain has not been fitted yet)
    """
    def __init__(self,
                 nodes: Optional[Union[Node, List[Node]]] = None,
                 log: Log = None):
        self.nodes = []
        self.template = None
        self.computation_time = None
        self.operator = GraphOperator(self)

        if not log:
            self.log = default_log(__name__)
        else:
            self.log = log

        if nodes:
            if isinstance(nodes, list):
                for node in nodes:
                    self.add_node(node)
            else:
                self.add_node(nodes)
        self.fitted_on_data = {}

    def fit_from_scratch(self, input_data: InputData = None):
        """
        Method used for training the chain without using cached information

        :param input_data: data used for operation training
        """
        # Clean all cache and fit all operations
        self.log.info('Fit chain from scratch')
        self.unfit()
        self.fit(input_data, use_cache=False)

    def update_fitted_on_data(self, data: InputData):
        characteristics = input_data_characteristics(data=data, log=self.log)
        self.fitted_on_data['data_type'] = characteristics[0]
        self.fitted_on_data['features_hash'] = characteristics[1]
        self.fitted_on_data['target_hash'] = characteristics[2]

    def _cache_status_if_new_data(self, new_input_data: InputData,
                                  cache_status: bool):
        new_data_params = input_data_characteristics(new_input_data,
                                                     log=self.log)
        if cache_status and self.fitted_on_data:
            params_names = ('data_type', 'features_hash', 'target_hash')
            are_data_params_different = any([
                new_data_param != self.fitted_on_data[param_name]
                for new_data_param, param_name in zip(new_data_params,
                                                      params_names)
            ])
            if are_data_params_different:
                info = 'Trained operation cache is not actual because you are using new dataset for training. ' \
                       'Parameter use_cache value changed to False'
                self.log.info(info)
                cache_status = False
        return cache_status

    def _fit_with_time_limit(
        self,
        input_data: Optional[InputData] = None,
        use_cache=False,
        time: timedelta = timedelta(minutes=3)) -> Manager:
        """
        Run training process with time limit. Create

        :param input_data: data used for operation training
        :param use_cache: flag defining whether use cache information about previous executions or not, default True
        :param time: time constraint for operation fitting process (seconds)
        """
        time = int(time.total_seconds())
        manager = Manager()
        process_state_dict = manager.dict()
        fitted_operations = manager.list()
        p = Process(target=self._fit,
                    args=(input_data, use_cache, process_state_dict,
                          fitted_operations),
                    kwargs={})
        p.start()
        p.join(time)
        if p.is_alive():
            p.terminate()
            raise TimeoutError(
                'Chain fitness evaluation time limit has expired')

        self.fitted_on_data = process_state_dict['fitted_on_data']
        self.computation_time = process_state_dict['computation_time']
        for node_num, node in enumerate(self.nodes):
            self.nodes[node_num].fitted_operation = fitted_operations[node_num]
        return process_state_dict['train_predicted']

    def _fit(self,
             input_data: InputData,
             use_cache=False,
             process_state_dict: Manager = None,
             fitted_operations: Manager = None):
        """
        Run training process in all nodes in chain starting with root.

        :param input_data: data used for operation training
        :param use_cache: flag defining whether use cache information about previous executions or not, default True
        :param process_state_dict: this dictionary is used for saving required chain parameters (which were changed
        inside the process) in a case of operation fit time control (when process created)
        :param fitted_operations: this list is used for saving fitted operations of chain nodes
        """

        # InputData was set directly to the primary nodes
        if input_data is None:
            use_cache = False
        else:
            use_cache = self._cache_status_if_new_data(
                new_input_data=input_data, cache_status=use_cache)

            if not use_cache or not self.fitted_on_data:
                # Don't use previous information
                self.unfit()
                self.update_fitted_on_data(input_data)

        with Timer(log=self.log) as t:
            computation_time_update = not use_cache or not self.root_node.fitted_operation or \
                                      self.computation_time is None

            train_predicted = self.root_node.fit(input_data=input_data)
            if computation_time_update:
                self.computation_time = round(t.minutes_from_start, 3)

        if process_state_dict is None:
            return train_predicted
        else:
            process_state_dict['train_predicted'] = train_predicted
            process_state_dict['computation_time'] = self.computation_time
            process_state_dict['fitted_on_data'] = self.fitted_on_data
            for node in self.nodes:
                fitted_operations.append(node.fitted_operation)

    def fit(self,
            input_data: Optional[InputData] = None,
            use_cache=True,
            time_constraint: Optional[timedelta] = None):
        """
        Run training process in all nodes in chain starting with root.

        :param input_data: data used for operation training
        :param use_cache: flag defining whether use cache information about previous executions or not, default True
        :param time_constraint: time constraint for operation fitting (seconds)
        """
        if not use_cache:
            self.unfit()

        if time_constraint is None:
            train_predicted = self._fit(input_data=input_data,
                                        use_cache=use_cache)
        else:
            train_predicted = self._fit_with_time_limit(input_data=input_data,
                                                        use_cache=use_cache,
                                                        time=time_constraint)
        return train_predicted

    def predict(self,
                input_data: InputData = None,
                output_mode: str = 'default'):
        """
        Run the predict process in all nodes in chain starting with root.

        :param input_data: data for prediction
        :param output_mode: desired form of output for operations. Available options are:
                'default' (as is),
                'labels' (numbers of classes - for classification) ,
                'probs' (probabilities - for classification =='default'),
                'full_probs' (return all probabilities - for binary classification).
        :return: OutputData with prediction
        """

        if not self.is_fitted:
            ex = 'Trained operation cache is not actual or empty'
            self.log.error(ex)
            raise ValueError(ex)

        result = self.root_node.predict(input_data=input_data,
                                        output_mode=output_mode)
        return result

    def fine_tune_all_nodes(self,
                            loss_function: Callable,
                            loss_params: Callable = None,
                            input_data: Optional[InputData] = None,
                            iterations=50,
                            max_lead_time: int = 5) -> 'Chain':
        """ Tune all hyperparameters of nodes simultaneously via black-box
            optimization using ChainTuner. For details, see
        :meth:`~fedot.core.chains.tuning.unified.ChainTuner.tune_chain`
        """
        max_lead_time = timedelta(minutes=max_lead_time)
        chain_tuner = ChainTuner(chain=self,
                                 task=input_data.task,
                                 iterations=iterations,
                                 max_lead_time=max_lead_time)
        self.log.info('Start tuning of chain')
        tuned_chain = chain_tuner.tune_chain(input_data=input_data,
                                             loss_function=loss_function,
                                             loss_params=loss_params)
        self.log.info('Tuning was finished')

        return tuned_chain

    def add_node(self, new_node: Node):
        """
        Add new node to the Chain

        :param node: new Node object
        """
        self.operator.add_node(new_node)

    def update_node(self, old_node: Node, new_node: Node):
        """
        Replace old_node with new one.

        :param old_node: Node object to replace
        :param new_node: Node object to replace
        """

        self.operator.update_node(old_node, new_node)

    def update_subtree(self, old_subroot: Node, new_subroot: Node):
        """
        Replace the subtrees with old and new nodes as subroots

        :param old_subroot: Node object to replace
        :param new_subroot: Node object to replace
        """
        self.operator.update_subtree(old_subroot, new_subroot)

    def delete_node(self, node: Node):
        """
        Delete chosen node redirecting all its parents to the child.

        :param node: Node object to delete
        """

        self.operator.delete_node(node)

    def delete_subtree(self, subroot: Node):
        """
        Delete the subtree with node as subroot.

        :param subroot:
        """
        self.operator.delete_subtree(subroot)

    @property
    def is_fitted(self):
        return all([(node.fitted_operation is not None)
                    for node in self.nodes])

    def unfit(self):
        """
        Remove fitted operations for all nodes.
        """
        for node in self.nodes:
            node.unfit()

    def fit_from_cache(self, cache):
        for node in self.nodes:
            cached_state = cache.get(node)
            if cached_state:
                node.fitted_operation = cached_state.operation
            else:
                node.fitted_operation = None

    def save(self, path: str):
        """
        Save the chain to the json representation with pickled fitted operations.

        :param path to json file with operation
        :return: json containing a composite operation description
        """
        if not self.template:
            self.template = ChainTemplate(self, self.log)
        json_object = self.template.export_chain(path)
        return json_object

    def load(self, path: str):
        """
        Load the chain the json representation with pickled fitted operations.

        :param path to json file with operation
        """
        self.nodes = []
        self.template = ChainTemplate(self, self.log)
        self.template.import_chain(path)

    def show(self, path: str = None):
        ChainVisualiser().visualise(self, path)

    def __eq__(self, other) -> bool:
        return self.root_node.descriptive_id == other.root_node.descriptive_id

    def __str__(self):
        description = {
            'depth': self.depth,
            'length': self.length,
            'nodes': self.nodes,
        }
        return f'{description}'

    def __repr__(self):
        return self.__str__()

    @property
    def root_node(self) -> Optional[Node]:
        if len(self.nodes) == 0:
            return None
        root = [
            node for node in self.nodes
            if not any(self.operator.node_children(node))
        ]
        if len(root) > 1:
            raise ValueError(f'{ERROR_PREFIX} More than 1 root_nodes in chain')
        return root[0]

    @property
    def length(self) -> int:
        return len(self.nodes)

    @property
    def depth(self) -> int:
        def _depth_recursive(node):
            if node is None:
                return 0
            if isinstance(node, PrimaryNode):
                return 1
            else:
                return 1 + max([
                    _depth_recursive(next_node)
                    for next_node in node.nodes_from
                ])

        return _depth_recursive(self.root_node)
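A short usage sketch of the API above (fit with a time constraint, predict, save). The dataset paths, the InputData.from_csv loader call and the operation names are assumptions used only for illustration; the Chain/node imports are the same assumed ones as in the first sketch.

# Usage sketch for the Chain above; paths, loader and operation names are assumed.
from datetime import timedelta

from fedot.core.data.data import InputData
from fedot.core.repository.tasks import Task, TaskTypesEnum

task = Task(TaskTypesEnum.classification)
train = InputData.from_csv('train.csv', task=task)   # hypothetical files
test = InputData.from_csv('test.csv', task=task)

chain = Chain(SecondaryNode('rf', nodes_from=[PrimaryNode('scaling')]))
chain.fit(train, use_cache=False, time_constraint=timedelta(minutes=2))
labels = chain.predict(test, output_mode='labels')    # OutputData with class labels
chain.save('saved_chain.json')                        # delegates to ChainTemplate.export_chain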
Example #4
class Chain:
    """
    Base class used for composite model structure definition

    :param nodes: Node object(s)
    :param log: Log object to record messages

    .. note::
        fitted_on_data stores the data used in the last chain fitting (equals None if the chain has not been
        fitted yet)
    """
    def __init__(self,
                 nodes: Optional[Union[Node, List[Node]]] = None,
                 log: Log = None):
        self.nodes = []
        self.template = None

        if not log:
            self.log = default_log(__name__)
        else:
            self.log = log

        if nodes:
            if isinstance(nodes, list):
                for node in nodes:
                    self.add_node(node)
            else:
                self.add_node(nodes)
        self.fitted_on_data = None

    def fit_from_scratch(self, input_data: InputData, verbose=False):
        """
        Method used for training the chain without using cached information

        :param input_data: data used for model training
        :param verbose: flag used for status printing to console, default False
        """
        # Clean all cache and fit all models
        self.log.info('Fit chain from scratch')
        self.fit(input_data, use_cache=False, verbose=verbose)

    def cache_status_if_new_data(self, new_input_data: InputData,
                                 cache_status: bool):
        if self.fitted_on_data is not None and self.fitted_on_data is not new_input_data:
            if cache_status:
                self.log.warn(
                    'Trained model cache is not actual because you are using new dataset for training. '
                    'Parameter use_cache value changed to False')
                cache_status = False
        return cache_status

    def fit(self, input_data: InputData, use_cache=True, verbose=False):
        """
        Run training process in all nodes in chain starting with root.

        :param input_data: data used for model training
        :param use_cache: flag defining whether use cache information about previous executions or not, default True
        :param verbose: flag used for status printing to console, default False
        """
        use_cache = self.cache_status_if_new_data(new_input_data=input_data,
                                                  cache_status=use_cache)

        if not use_cache:
            self._clean_model_cache()

        if input_data.task.task_type == TaskTypesEnum.ts_forecasting:
            if input_data.task.task_params.make_future_prediction:
                input_data.task.task_params.return_all_steps = True
            # the make_future_prediction is useless for the fit stage
            input_data.task.task_params.make_future_prediction = False
            check_data_appropriate_for_task(input_data)

        if not use_cache or self.fitted_on_data is None:
            self.fitted_on_data = input_data
        train_predicted = self.root_node.fit(input_data=input_data,
                                             verbose=verbose)
        return train_predicted

    def predict(self, input_data: InputData, output_mode: str = 'default'):
        """
        Run the predict process in all nodes in chain starting with root.

        :param input_data: data for prediction
        :param output_mode: desired form of output for models. Available options are:
                'default' (as is),
                'labels' (numbers of classes - for classification) ,
                'probs' (probabilities - for classification =='default'),
                'full_probs' (return all probabilities - for binary classification).
        :return: array of predicted target values
        """

        if not self.is_all_cache_actual():
            ex = 'Trained model cache is not actual or empty'
            self.log.error(ex)
            raise ValueError(ex)

        result = self.root_node.predict(input_data=input_data,
                                        output_mode=output_mode)
        return result

    def fine_tune_primary_nodes(
            self,
            input_data: InputData,
            iterations: int = 30,
            max_lead_time: timedelta = timedelta(minutes=5),
            verbose=False):
        """
        Optimize hyperparameters of the models in the primary nodes

        :param input_data: data used for tuning
        :param iterations: max number of iterations
        :param max_lead_time: max time available for tuning process
        :param verbose: flag used for status printing to console, default False
        """
        # Select all primary nodes
        # Perform fine-tuning for each model in node
        if verbose:
            self.log.info('Start tuning of primary nodes')

        all_primary_nodes = [
            node for node in self.nodes if isinstance(node, PrimaryNode)
        ]
        for node in all_primary_nodes:
            node.fine_tune(input_data,
                           max_lead_time=max_lead_time,
                           iterations=iterations)

        if verbose:
            self.log.info('End tuning')

    def fine_tune_all_nodes(self,
                            input_data: InputData,
                            iterations: int = 30,
                            max_lead_time: timedelta = timedelta(minutes=5),
                            verbose=False):
        """
        Optimize hyperparameters of the models in all nodes

        :param input_data: data used for tuning
        :param iterations: max number of iterations
        :param max_lead_time: max time available for tuning process
        :param verbose: flag used for status printing to console, default False
        """
        if verbose:
            self.log.info('Start tuning of chain')

        node = self.root_node
        node.fine_tune(input_data,
                       max_lead_time=max_lead_time,
                       iterations=iterations)

        if verbose:
            self.log.info('End tuning')

    def add_node(self, new_node: Node):
        """
        Add new node to the Chain

        :param new_node: new Node object
        """
        if new_node not in self.nodes:
            self.nodes.append(new_node)
            if new_node.nodes_from:
                for new_parent_node in new_node.nodes_from:
                    if new_parent_node not in self.nodes:
                        self.add_node(new_parent_node)

    def _actualise_old_node_childs(self, old_node: Node, new_node: Node):
        old_node_offspring = self.node_childs(old_node)
        for old_node_child in old_node_offspring:
            old_node_child.nodes_from[old_node_child.nodes_from.index(
                old_node)] = new_node

    def replace_node_with_parents(self, old_node: Node, new_node: Node):
        new_node = deepcopy(new_node)
        self._actualise_old_node_childs(old_node, new_node)
        new_nodes = [
            parent for parent in new_node.ordered_subnodes_hierarchy
            if parent not in self.nodes
        ]
        old_nodes = [
            node for node in self.nodes
            if node not in old_node.ordered_subnodes_hierarchy
        ]
        self.nodes = new_nodes + old_nodes
        self.sort_nodes()

    def update_node(self, old_node: Node, new_node: Node):
        self._actualise_old_node_childs(old_node, new_node)
        new_node.nodes_from = old_node.nodes_from
        self.nodes.remove(old_node)
        self.nodes.append(new_node)
        self.sort_nodes()

    def delete_node(self, node: Node):
        for node_child in self.node_childs(node):
            node_child.nodes_from.remove(node)
        for subtree_node in node.ordered_subnodes_hierarchy:
            self.nodes.remove(subtree_node)

    def _clean_model_cache(self):
        for node in self.nodes:
            node.cache = FittedModelCache(node)

    def is_all_cache_actual(self):
        cache_status = [
            node.cache.actual_cached_state is not None for node in self.nodes
        ]
        return all(cache_status)

    def node_childs(self, node) -> List[Optional[Node]]:
        return [
            other_node for other_node in self.nodes
            if isinstance(other_node, SecondaryNode)
            and node in other_node.nodes_from
        ]

    def _is_node_has_child(self, node) -> bool:
        return any(self.node_childs(node))

    def import_cache(self, fitted_chain: 'Chain'):
        for node in self.nodes:
            if not node.cache.actual_cached_state:
                for fitted_node in fitted_chain.nodes:
                    if fitted_node.descriptive_id == node.descriptive_id:
                        node.cache.import_from_other_cache(fitted_node.cache)
                        break

    # TODO: why is the tree visualisation incorrect?
    def sort_nodes(self):
        """Layer-by-layer sorting of the nodes"""
        nodes = self.root_node.ordered_subnodes_hierarchy
        self.nodes = nodes

    def save_chain(self, path: str):
        if not self.template:
            self.template = ChainTemplate(self, self.log)
        json_object = self.template.export_chain(path)
        return json_object

    def load_chain(self, path: str):
        self.nodes = []
        self.template = ChainTemplate(self, self.log)
        self.template.import_chain(path)

    def __eq__(self, other) -> bool:
        return self.root_node.descriptive_id == other.root_node.descriptive_id

    @property
    def root_node(self) -> Optional[Node]:
        if len(self.nodes) == 0:
            return None
        root = [
            node for node in self.nodes if not self._is_node_has_child(node)
        ]
        if len(root) > 1:
            raise ValueError(f'{ERROR_PREFIX} More than 1 root_nodes in chain')
        return root[0]

    @property
    def length(self) -> int:
        return len(self.nodes)

    @property
    def depth(self) -> int:
        def _depth_recursive(node):
            if node is None:
                return 0
            if isinstance(node, PrimaryNode):
                return 1
            else:
                return 1 + max([
                    _depth_recursive(next_node)
                    for next_node in node.nodes_from
                ])

        return _depth_recursive(self.root_node)
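For comparison, a usage sketch of this older, model-based API, reusing the imports and the hypothetical train/test InputData objects from the previous sketch; the model names are assumptions.

# Usage sketch for the older model-based Chain; model names and data are assumed.
chain = Chain(SecondaryNode('xgboost', nodes_from=[PrimaryNode('logit')]))
chain.fit(train, use_cache=True, verbose=True)

if chain.is_all_cache_actual():                  # all node caches hold fitted models
    probs = chain.predict(test, output_mode='probs')

chain.fine_tune_primary_nodes(train, iterations=10,
                              max_lead_time=timedelta(minutes=2))
chain.save_chain('old_chain.json')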