示例#1
0
 def metric_for_nodes(self, metric_function, train_data: InputData,
                      test_data: InputData, is_chain_shared: bool,
                      chain: Chain) -> float:
     """Validate and fit ``chain``, then score it with ``metric_function``.

     :param metric_function: called as ``metric_function(chain, test_data)``;
         its result is returned unchanged.
     :param train_data: data the chain is fitted on.
     :param test_data: data handed to ``metric_function``.
     :param is_chain_shared: when true, the chain is wrapped in ``SharedChain``
         backed by ``self.shared_cache`` before fitting.
     :param chain: the chain to assess.
     :return: the metric value, or ``max_int_value`` if validation, wrapping,
         fitting or scoring raises (the error is logged at info level).
     """
     try:
         validate(chain)
         # Fit and score either the shared wrapper or the chain itself.
         target = (SharedChain(base_chain=chain, shared_cache=self.shared_cache)
                   if is_chain_shared else chain)
         target.fit(input_data=train_data)
         return metric_function(target, test_data)
     except Exception as ex:
         # Deliberate best-effort: a failing chain gets the sentinel score.
         self.log.info(f'Error in chain assessment during composition: {ex}. Continue.')
         return max_int_value
示例#2
0
    def composer_metric(self, metrics, train_data: InputData,
                        test_data: InputData,
                        chain: Chain) -> Optional[Tuple[Any]]:
        """Fit ``chain`` (reusing the cache when one is set) and evaluate it.

        :param metrics: a single metric or a list of metrics. Each item is
            either a callable invoked as ``metric(chain, reference_data=...)``
            or an identifier resolved through
            ``MetricsRepository().metric_by_id``.
        :param train_data: data used to fit the chain if it is not fitted yet.
        :param test_data: reference data passed to every metric.
        :param chain: the candidate chain. It is validated first, and its
            ``log`` attribute is replaced with ``self.log``.
        :return: a tuple with one value per metric, in the given order, or
            ``None`` if validation, fitting, caching or any metric raised.
        """
        try:
            validate(chain)
            chain.log = self.log

            if not isinstance(metrics, list):
                metrics = [metrics]

            if self.cache is not None:
                # TODO improve cache
                chain.fit_from_cache(self.cache)

            if not chain.is_fitted:
                self.log.debug(
                    f'Chain {chain.root_node.descriptive_id} fit started')
                chain.fit(input_data=train_data,
                          time_constraint=self.composer_requirements.
                          max_chain_fit_time)
                # The cache is optional (see the check above). Saving
                # unconditionally raised AttributeError when it was None,
                # which turned every uncached assessment into ``None``.
                if self.cache is not None:
                    self.cache.save_chain(chain)

            values = []
            for metric in metrics:
                metric_func = (metric if callable(metric)
                               else MetricsRepository().metric_by_id(metric))
                values.append(metric_func(chain, reference_data=test_data))
            evaluated_metrics = tuple(values)

            self.log.debug(
                f'Chain {chain.root_node.descriptive_id} with metrics: {list(evaluated_metrics)}'
            )

        except Exception as ex:
            # Deliberate best-effort: any failure is logged and reported as
            # ``None`` instead of propagating.
            self.log.info(f'Chain assessment warning: {ex}. Continue.')
            evaluated_metrics = None

        return evaluated_metrics
示例#3
0
def constraint_function(chain: Chain):
    """Report whether ``chain`` passes ``validate``.

    Returns False when ``validate`` raises ``ValueError`` and True otherwise.
    Exceptions of any other type propagate to the caller.
    """
    try:
        validate(chain)
    except ValueError:
        return False
    return True
示例#4
0
def test_chain_validate_correct():
    """A chain built by ``valid_chain`` must pass ``validate`` without raising."""
    validate(valid_chain())