Example #1
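
# Stdlib modules used below; MindsDB-internal names (TransactionData,
# PersistentModelMetadata, PersistentMlModelInfo, the TRANSACTION_* and
# MODEL_STATUS_* constants, CONFIG, convert_cammelcase_to_snake_string, ...)
# come from the surrounding project and are not reproduced here.
import importlib
import traceback
import _thread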
class TransactionController:
    def __init__(self, session, transaction_metadata, breakpoint=PHASE_END):
        """
        A transaction is the interface to start some MindsDB operation within a session

        :param session:
        :type session: utils.controllers.session_controller.SessionController
        :param transaction_type:
        :param transaction_metadata:
        :type transaction_metadata: TransactionMetadata
        :param breakpoint:
        """

        self.session = session
        self.breakpoint = breakpoint
        self.session.current_transaction = self
        self.metadata = transaction_metadata  # type: TransactionMetadata

        # variables to be defined during setup
        self.error = None
        self.errorMsg = None

        self.input_data = TransactionData()
        self.output_data = TransactionOutputData(
            predicted_columns=self.metadata.model_predict_columns)

        self.model_data = ModelData()

        # variables that can be persisted
        self.persistent_model_metadata = PersistentModelMetadata()
        self.persistent_model_metadata.model_name = self.metadata.model_name
        self.persistent_ml_model_info = PersistentMlModelInfo()
        self.persistent_ml_model_info.model_name = self.metadata.model_name

        self.run()

    def getPhaseInstance(self, module_name, **kwargs):
        """
        Loads the module that we want to start for

        :param module_name:
        :param kwargs:
        :return:
        """

        module_path = convert_cammelcase_to_snake_string(module_name)
        module_full_path = 'mindsdb.libs.phases.{module_path}.{module_path}'.format(
            module_path=module_path)
        try:
            main_module = importlib.import_module(module_full_path)
            module = getattr(main_module, module_name)
            return module(self.session, self, **kwargs)
        except Exception:
            self.session.logging.error(
                'Could not load module {module_name}'.format(
                    module_name=module_name))
            self.session.logging.error(traceback.format_exc())
            return None
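
    # Illustration (assumed example): for module_name == 'StatsGenerator',
    # convert_cammelcase_to_snake_string() yields 'stats_generator', so the
    # import target is 'mindsdb.libs.phases.stats_generator.stats_generator'
    # and the StatsGenerator class is looked up inside that module.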

    def callPhaseModule(self, module_name):
        """

        :param module_name:
        :return:
        """
        module = self.getPhaseInstance(module_name)
        return module()

    def executeLearn(self):
        """

        :return:
        """

        self.callPhaseModule('DataExtractor')
        if len(self.input_data.data_array) <= 0 or len(
                self.input_data.data_array[0]) <= 0:
            self.metadata.type = TRANSACTION_BAD_QUERY
            self.errorMsg = "No results for this query."
            return

        try:
            # make sure that we remove all previous data about this model
            info = self.persistent_ml_model_info.find_one(
                self.persistent_model_metadata.getPkey())
            if info is not None:
                info.deleteFiles()
            self.persistent_model_metadata.delete()
            self.persistent_ml_model_info.delete()

            # start populating data
            self.persistent_model_metadata.train_metadata = self.metadata.getAsDict(
            )
            self.persistent_model_metadata.current_phase = MODEL_STATUS_ANALYZING
            self.persistent_model_metadata.columns = self.input_data.columns  # this is populated by data extractor
            self.persistent_model_metadata.predict_columns = self.metadata.model_predict_columns
            self.persistent_model_metadata.insert()

            self.callPhaseModule('StatsGenerator')
            self.persistent_model_metadata.current_phase = MODEL_STATUS_PREPARING
            self.persistent_model_metadata.update()

            self.callPhaseModule('DataVectorizer')
            self.persistent_model_metadata.current_phase = MODEL_STATUS_TRAINING
            self.persistent_model_metadata.update()

            # self.callPhaseModule('DataEncoder')
            self.callPhaseModule('ModelTrainer')
            # TODO: Loop over all stats and when all stats are done, then we can mark model as MODEL_STATUS_TRAINED

            return
        except Exception as e:

            self.persistent_model_metadata.current_phase = MODEL_STATUS_ERROR
            # format_exc() returns the traceback as a string (print_exc() returns None)
            self.persistent_model_metadata.error_msg = traceback.format_exc()
            self.persistent_model_metadata.update()
            self.session.logging.error(
                self.persistent_model_metadata.error_msg)
            self.session.logging.error(e)

            return

    def executeDropModel(self):
        """

        :return:
        """

        # make sure that we remove all previous data about this model
        self.persistent_model_metadata.delete()
        self.persistent_ml_model_info.delete()

        self.output_data.data_array = [[
            'Model ' + self.metadata.model_name + ' deleted.'
        ]]
        self.output_data.columns = ['Status']

        return

    def executeNormalSelect(self):
        """

        :return:
        """

        self.callPhaseModule('DataExtractor')
        self.output_data = self.input_data
        return

    def executePredict(self):
        """

        :return:
        """

        self.callPhaseModule('StatsLoader')
        if self.persistent_model_metadata is None:
            self.session.logging.error('No metadata found for this model')
            return

        self.callPhaseModule('DataExtractor')
        if len(self.input_data.data_array) <= 0 or len(
                self.input_data.data_array[0]) <= 0:
            self.output_data = self.input_data
            return

        self.callPhaseModule('DataVectorizer')
        self.callPhaseModule('ModelPredictor')

        return

    def run(self):
        """

        :return:
        """

        if self.metadata.type == TRANSACTION_BAD_QUERY:
            self.session.logging.error(self.errorMsg)
            self.error = True
            return

        if self.metadata.type == TRANSACTION_DROP_MODEL:
            self.executeDropModel()
            return

        if self.metadata.type == TRANSACTION_LEARN:
            self.output_data.data_array = [[
                'Model ' + self.metadata.model_name + ' training.'
            ]]
            self.output_data.columns = ['Status']

            if not CONFIG.EXEC_LEARN_IN_THREAD:
                self.executeLearn()
            else:
                _thread.start_new_thread(self.executeLearn, ())
            return

        elif self.metadata.type == TRANSACTION_PREDICT:
            self.executePredict()
        elif self.metadata.type == TRANSACTION_NORMAL_SELECT:
            self.executeNormalSelect()
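
A minimal usage sketch (hypothetical: my_session stands for an existing SessionController, and the metadata values shown are assumptions for illustration, not taken from the source):

# Hypothetical illustration only; field values below are assumptions.
tm = TransactionMetadata()
tm.type = TRANSACTION_LEARN                  # run a learn transaction
tm.model_name = 'home_rentals'               # assumed model name
tm.model_predict_columns = ['rental_price']  # assumed target column

transaction = TransactionController(my_session, tm)  # __init__ calls run()
if transaction.error:
    my_session.logging.error(transaction.errorMsg)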
Example #2
File: predict.py Project: torrmal/main
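
# Stdlib modules used below; MindsDB-internal names (PersistentModelMetadata,
# PersistentMlModelInfo, Sampler, denorm, convert_snake_to_cammelcase_string, ...)
# come from the surrounding project and are not reproduced here.
import importlib
import logging
import time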
class PredictWorker:
    def __init__(self, data, model_name):
        """
        Load basic data needed to find the model data
        :param data: data to make predictions on
        :param model_name: the model to load
        :param submodel_name: if its also a submodel, the submodel name
        """

        self.data = data
        self.model_name = model_name

        self.persistent_model_metadata = PersistentModelMetadata()
        self.persistent_model_metadata.model_name = self.model_name
        self.persistent_ml_model_info = PersistentMlModelInfo()
        self.persistent_ml_model_info.model_name = self.model_name

        self.persistent_model_metadata = self.persistent_model_metadata.find_one(
            self.persistent_model_metadata.getPkey())

        # load the most accurate model (highest r_squared)

        info = self.persistent_ml_model_info.find(
            {'model_name': self.model_name},
            order_by=[('r_squared', -1)],
            limit=1)

        if info is not None and len(info) > 0:
            self.persistent_ml_model_info = info[0]  # type: PersistentMlModelInfo
        else:
            # TODO: Make sure we have a model for this
            logging.info('No model found')
            return

        self.predict_sampler = Sampler(
            self.data.predict_set,
            metadata_as_stored=self.persistent_model_metadata)

        self.ml_model_name = self.persistent_ml_model_info.ml_model_name
        self.config_serialized = self.persistent_ml_model_info.config_serialized

        fs_file_ids = self.persistent_ml_model_info.fs_file_ids
        self.framework, self.dummy, self.ml_model_name = self.ml_model_name.split('.')
        self.ml_model_module_path = 'mindsdb.libs.ml_models.{framework}.models.{name}.{name}'.format(
            framework=self.framework, name=self.ml_model_name)
        self.ml_model_class_name = convert_snake_to_cammelcase_string(
            self.ml_model_name)
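
        # Illustration (assumed value): if ml_model_name was stored as
        # 'pytorch.models.fully_connected_net', the split gives framework='pytorch',
        # dummy='models' and ml_model_name='fully_connected_net'; the module path is
        # then 'mindsdb.libs.ml_models.pytorch.models.fully_connected_net.fully_connected_net'
        # and the class name becomes FullyConnectedNet.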

        self.ml_model_module = importlib.import_module(
            self.ml_model_module_path)
        self.ml_model_class = getattr(self.ml_model_module,
                                      self.ml_model_class_name)

        self.sample_batch = self.predict_sampler.getSampleBatch()

        self.gfs_save_head_time = time.time()  # the last time it was saved into GridFS; assume it was now

        logging.info('Starting model...')
        self.data_model_object = self.ml_model_class.loadFromDisk(
            file_ids=fs_file_ids)
        self.data_model_object.sample_batch = self.sample_batch

    def predict(self):
        """
        This actually calls the model and returns the predictions in diff form

        :return: diffs, which is a list of dictionaries with pointers as to where to replace the prediction given the value that was predicted

        """
        self.predict_sampler.variable_wrapper = self.ml_model_class.variable_wrapper
        self.predict_sampler.variable_unwrapper = self.ml_model_class.variable_unwrapper

        ret_diffs = []
        for batch in self.predict_sampler:

            logging.info('predicting batch...')
            ret = self.data_model_object.forward(
                batch.getInput(flatten=self.data_model_object.flatInput))
            if not isinstance(ret, dict):
                ret_dict = batch.deflatTarget(ret)
            else:
                ret_dict = ret

            ret_dict_denorm = {}

            for col in ret_dict:
                ret_dict[col] = self.ml_model_class.variable_unwrapper(
                    ret_dict[col])
                for row in ret_dict[col]:
                    if col not in ret_dict_denorm:
                        ret_dict_denorm[col] = []

                    ret_dict_denorm[col] += [
                        denorm(
                            row,
                            self.persistent_model_metadata.column_stats[col])
                    ]

            ret_total_item = {
                'group_pointer': batch.group_pointer,
                'column_pointer': batch.column_pointer,
                'start_pointer': batch.start_pointer,
                'end_pointer': batch.end_pointer,
                'ret_dict': ret_dict_denorm
            }
            ret_diffs += [ret_total_item]

        return ret_diffs
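
    # Each item in ret_diffs mirrors ret_total_item above: the four pointers
    # (group/column/start/end) locate where the predictions belong in the
    # original data, and 'ret_dict' maps each predicted column to its list of
    # denormalized values.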

    @staticmethod
    def start(data, model_name):
        """
        We use this worker to parallel train different data models and data model configurations

        :param data: This is the vectorized data
        :param model_name: This will be the model name so we can pull stats and other
        :param data_model: This will be the data model name, which can let us find the data model implementation
        :param config: this is the hyperparameter config
        """

        w = PredictWorker(data, model_name)
        logging.info('Inferring from model and data...')
        return w.predict()
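
A minimal usage sketch (hypothetical: vectorized_data stands for the already-vectorized data object whose predict_set the Sampler consumes, and the model name is an assumption):

# Hypothetical illustration only; assumes the model was trained and persisted earlier.
diffs = PredictWorker.start(vectorized_data, 'home_rentals')
for diff in diffs:
    print(diff['ret_dict'])  # denormalized predictions, keyed by column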