コード例 #1
0
ファイル: datastore.py プロジェクト: wojohowitz00/mindsdb
class DataStore():
    """Register, persist and load datasources.

    Datasource metadata is stored in the ``Datasource`` table (through the
    module-level ``session``), scoped to ``self.company_id``. The files that
    back a datasource live in ``<self.dir>/<name>`` and are mirrored through
    ``self.fs_store`` under the key ``datasource_<company_id>_<record id>``.
    """

    def __init__(self):
        self.config = Config()

        self.fs_store = FsSotre()
        # None when the MINDSDB_COMPANY_ID environment variable is not set.
        self.company_id = os.environ.get('MINDSDB_COMPANY_ID', None)
        self.dir = self.config.paths['datasources']
        self.mindsdb_native = NativeInterface()

    def get_analysis(self, name):
        """Return the analysis dict of datasource ``name``.

        The analysis is computed through ``NativeInterface.analyse_dataset``
        on first request and cached as JSON on the datasource record.
        """
        datasource_record = session.query(Datasource).filter_by(
            company_id=self.company_id, name=name).first()
        if datasource_record.analysis is None:
            datasource_record.analysis = json.dumps(
                self.mindsdb_native.analyse_dataset(
                    self.get_datasource_obj(name)))
            session.commit()

        return json.loads(datasource_record.analysis)

    def get_datasources(self, name=None):
        """Return metadata dicts for every datasource, or only for ``name``.

        Each dict is the record's decoded JSON ``data`` extended with
        ``created_at``, ``updated_at``, ``name`` and ``id``. Records whose data
        cannot be processed are logged and skipped.
        """
        filters = {'company_id': self.company_id}
        if name is not None:
            filters['name'] = name
        datasource_record_arr = session.query(Datasource).filter_by(**filters)

        datasource_arr = []
        for datasource_record in datasource_record_arr:
            try:
                datasource = json.loads(datasource_record.data)
                datasource['created_at'] = datasource_record.created_at
                datasource['updated_at'] = datasource_record.updated_at
                datasource['name'] = datasource_record.name
                datasource['id'] = datasource_record.id
                datasource_arr.append(datasource)
            except Exception as e:
                log.error(e)
        return datasource_arr

    def get_data(self, name, where=None, limit=None, offset=None):
        """Return rows of datasource ``name`` filtered by ``where``.

        :param limit: maximum number of rows to return (all rows when None)
        :param offset: number of leading rows to skip; only applied when
            ``limit`` is given
        :return: dict with ``data`` (list of row dicts, missing cells as None),
            ``rowcount`` (``len()`` of the whole datasource object) and
            ``columns_names`` (columns of the filtered frame)
        """
        offset = 0 if offset is None else offset
        ds = self.get_datasource_obj(name)

        if limit is not None:
            # @TODO Add `offset` to the `filter` method of the datasource and get rid of `offset`
            filtered_ds = ds.filter(where=where,
                                    limit=limit + offset).iloc[offset:]
        else:
            filtered_ds = ds.filter(where=where)

        # Replace missing (NaN) cells with None.
        filtered_ds = filtered_ds.where(pd.notnull(filtered_ds), None)
        data = filtered_ds.to_dict(orient='records')
        return {
            'data': data,
            'rowcount': len(ds),
            'columns_names': filtered_ds.columns
        }

    def get_datasource(self, name):
        """Return the metadata dict of datasource ``name``, or None if absent.

        :raises Exception: if more than one record has this name.
        """
        datasource_arr = self.get_datasources(name)
        if len(datasource_arr) == 1:
            return datasource_arr[0]
        # @TODO: Remove when db switch is more stable, this should never happen, but good sanity check while this is kinda buggy
        elif len(datasource_arr) > 1:
            # logging merges extra positional args into the message with
            # %-formatting, so the message needs matching placeholders.
            log.error('Two or more datasource with the same name, (%s) | Full list: %s',
                      len(datasource_arr), datasource_arr)
            raise Exception('Two or more datasource with the same name')
        return None

    def delete_datasource(self, name):
        """Delete datasource ``name``: its record, stored files and local copy."""
        datasource_record = session.query(Datasource).filter_by(
            company_id=self.company_id, name=name).first()
        # Read the id before deleting: the instance should not be relied on
        # once it has been deleted and the transaction committed.
        datasource_id = datasource_record.id
        session.delete(datasource_record)
        session.commit()
        self.fs_store.delete(f'datasource_{self.company_id}_{datasource_id}')
        # Best-effort removal of the local copy; errors (e.g. a missing
        # directory) are ignored.
        try:
            shutil.rmtree(os.path.join(self.dir, name))
        except Exception:
            pass

    def save_datasource(self, name, source_type, source, file_path=None):
        """Create datasource ``name`` and return ``(creation_info, name)``.

        :param source_type: ``'file'``, a key of ``config['integrations']``,
            or anything else (then ``source`` is handed to ``FileDS``; the
            original comment suggests this is used for URLs).
        :param source: file name (``'file'``), a dict of query parameters
            (integrations) or a ``FileDS`` source (fallback).
        :param file_path: uploaded file to move into the datasource directory;
            required when ``source_type == 'file'``.
        :raises Exception: on missing ``file_path`` or invalid column names.
        :raises KeyError: if the integration type has no DS class.

        On failure the datasource directory and the freshly created record
        are removed again before the exception is re-raised.
        """
        if source_type == 'file' and (file_path is None):
            raise Exception(
                '`file_path` argument required when source_type == "file"')

        ds_meta_dir = os.path.join(self.dir, name)
        os.mkdir(ds_meta_dir)

        # Keep a reference to the record we insert (its id is loaded after the
        # commit) instead of re-querying by name, which could match another
        # record with the same name.
        datasource_record = Datasource(company_id=self.company_id, name=name)
        session.add(datasource_record)
        session.commit()

        try:
            if source_type == 'file':
                source = os.path.join(ds_meta_dir, source)
                shutil.move(file_path, source)
                ds = FileDS(source)

                creation_info = {
                    'class': 'FileDS',
                    'args': [source],
                    'kwargs': {}
                }

            elif source_type in self.config['integrations']:
                integration = self.config['integrations'][source_type]

                ds_class_map = {
                    'clickhouse': ClickhouseDS,
                    'mariadb': MariaDS,
                    'mysql': MySqlDS,
                    'postgres': PostgresDS,
                    'mssql': MSSQLDS,
                    'mongodb': MongoDS,
                    'snowflake': SnowflakeDS
                }

                try:
                    dsClass = ds_class_map[integration['type']]
                except KeyError:
                    raise KeyError(
                        f"Unknown DS type: {source_type}, type is {integration['type']}"
                    )

                if integration['type'] in ['clickhouse']:
                    creation_info = {
                        'class': dsClass.__name__,
                        'args': [],
                        'kwargs': {
                            'query': source['query'],
                            'user': integration['user'],
                            'password': integration['password'],
                            'host': integration['host'],
                            'port': integration['port']
                        }
                    }
                    ds = dsClass(**creation_info['kwargs'])

                elif integration['type'] in [
                        'mssql', 'postgres', 'mariadb', 'mysql'
                ]:
                    creation_info = {
                        'class': dsClass.__name__,
                        'args': [],
                        'kwargs': {
                            'query': source['query'],
                            'user': integration['user'],
                            'password': integration['password'],
                            'host': integration['host'],
                            'port': integration['port']
                        }
                    }

                    # A database given with the request overrides the
                    # integration's default database.
                    if 'database' in integration:
                        creation_info['kwargs']['database'] = integration[
                            'database']

                    if 'database' in source:
                        creation_info['kwargs']['database'] = source[
                            'database']

                    ds = dsClass(**creation_info['kwargs'])

                elif integration['type'] == 'snowflake':
                    creation_info = {
                        'class': dsClass.__name__,
                        'args': [],
                        'kwargs': {
                            'query': source['query'],
                            'schema': source['schema'],
                            'warehouse': source['warehouse'],
                            'database': source['database'],
                            'host': integration['host'],
                            'password': integration['password'],
                            'user': integration['user'],
                            'account': integration['account']
                        }
                    }

                    ds = dsClass(**creation_info['kwargs'])

                elif integration['type'] == 'mongodb':
                    # `find` may arrive as a JSON string; the decoded value is
                    # also what gets stored in the record's `source`.
                    if isinstance(source['find'], str):
                        source['find'] = json.loads(source['find'])
                    creation_info = {
                        'class': dsClass.__name__,
                        'args': [],
                        'kwargs': {
                            'database': source['database'],
                            'collection': source['collection'],
                            'query': source['find'],
                            'user': integration['user'],
                            'password': integration['password'],
                            'host': integration['host'],
                            'port': integration['port']
                        }
                    }

                    ds = dsClass(**creation_info['kwargs'])
            else:
                # This probably only happens for urls
                ds = FileDS(source)
                creation_info = {
                    'class': 'FileDS',
                    'args': [source],
                    'kwargs': {}
                }

            df = ds.df

            # The directory cleanup is done by the handler below.
            if '' in df.columns or len(df.columns) != len(set(df.columns)):
                raise Exception(
                    'Each column in datasource must have unique non-empty name'
                )

            datasource_record.creation_info = json.dumps(creation_info)
            datasource_record.data = json.dumps({
                'source_type': source_type,
                'source': source,
                'row_count': len(df),
                'columns': [dict(name=x) for x in list(df.keys())]
            })

            self.fs_store.put(
                name, f'datasource_{self.company_id}_{datasource_record.id}',
                self.dir)

        except Exception:
            if os.path.isdir(ds_meta_dir):
                shutil.rmtree(ds_meta_dir)
            # Remove the half-created record as well; otherwise a retry with
            # the same name produces duplicates that get_datasource() rejects.
            try:
                session.rollback()
                session.delete(datasource_record)
                session.commit()
            except Exception as cleanup_error:
                log.error(f'Could not remove datasource record {name}: {cleanup_error}')
            # Re-raise the original error, not the cleanup one.
            raise

        session.commit()
        return self.get_datasource_obj(name, raw=True), name

    def get_datasource_obj(self, name, raw=False):
        """Return datasource ``name`` as an instantiated DS object.

        The datasource files are first fetched from ``fs_store`` into
        ``self.dir``. With ``raw=True`` the stored creation info dict
        (``class``/``args``/``kwargs``) is returned instead. Any error is
        logged and None is returned.
        """
        try:
            datasource_record = session.query(Datasource).filter_by(
                company_id=self.company_id, name=name).first()
            self.fs_store.get(
                name, f'datasource_{self.company_id}_{datasource_record.id}',
                self.dir)
            creation_info = json.loads(datasource_record.creation_info)
            if raw:
                return creation_info
            # NOTE(review): eval() of a class name read from the database. The
            # names are written by save_datasource, but a tampered record could
            # execute arbitrary code; consider an explicit name -> class map.
            return eval(creation_info['class'])(*creation_info['args'],
                                                **creation_info['kwargs'])
        except Exception as e:
            log.error(f'\n{e}\n')
            return None
コード例 #2
0
ファイル: model_controller.py プロジェクト: szhorizon/mindsdb
class ModelController():
    """Create, train, query, update and delete mindsdb_native predictors.

    Predictor metadata is kept in the ``Predictor`` table (``is_custom=False``)
    scoped to ``self.company_id``; predictor files are mirrored through
    ``self.fs_store`` under ``predictor_<company_id>_<record id>``. Results
    are passed through ``_pack`` before being returned.
    """

    def __init__(self, ray_based):
        self.config = Config()
        self.fs_store = FsSotre()
        self.company_id = os.environ.get('MINDSDB_COMPANY_ID', None)
        self.dbw = DatabaseWrapper()
        # predictor name -> {'predictor': mindsdb_native.Predictor, 'created': datetime}
        self.predictor_cache = {}
        # When True results are returned as-is, otherwise pickled for xmlrpc.
        self.ray_based = ray_based

    def _pack(self, obj):
        """Return ``obj`` unchanged when ray-based, else as a pickled xmlrpc Binary."""
        if self.ray_based:
            return obj
        return xmlrpc.client.Binary(pickle.dumps(obj))

    def _invalidate_cached_predictors(self):
        """Evict cached predictors that were loaded more than 1200 seconds ago."""
        # @TODO: Cache will become stale if the respective NativeInterface is not invoked yet a bunch of predictors remained cached, no matter where we invoke it. In practice shouldn't be a big issue though
        now = datetime.datetime.now()
        for predictor_name in list(self.predictor_cache.keys()):
            if (now - self.predictor_cache[predictor_name]['created']).total_seconds() > 1200:
                del self.predictor_cache[predictor_name]

    def _lock_predictor(self, id, mode='write'):
        """Block until a semaphore for predictor ``id`` is obtained in ``mode``.

        A 'read' request succeeds immediately if a 'read' semaphore already
        exists. Otherwise a new ``Semaphor`` row is inserted; if that fails
        (presumably because a conflicting row exists - TODO confirm the table
        enforces uniqueness) the session is rolled back and the attempt is
        retried every second.
        """
        from mindsdb.interfaces.storage.db import session, Semaphor

        while True:
            semaphor_record = session.query(Semaphor).filter_by(company_id=self.company_id, entity_id=id, entity_type='predictor').first()
            if semaphor_record is not None and mode == 'read' and semaphor_record.action == 'read':
                return True
            try:
                semaphor_record = Semaphor(company_id=self.company_id, entity_id=id, entity_type='predictor', action=mode)
                session.add(semaphor_record)
                session.commit()
                return True
            except Exception:
                # A failed commit leaves the session unusable until rolled back.
                session.rollback()
            time.sleep(1)

    def _unlock_predictor(self, id):
        """Delete the semaphore row of predictor ``id``, if there is one."""
        from mindsdb.interfaces.storage.db import session, Semaphor
        semaphor_record = session.query(Semaphor).filter_by(company_id=self.company_id, entity_id=id, entity_type='predictor').first()
        if semaphor_record is not None:
            session.delete(semaphor_record)
            session.commit()

    @contextmanager
    def _lock_context(self, id, mode='write'):
        """Hold the semaphore of predictor ``id`` in ``mode`` for a ``with`` block."""
        # Acquire before the try so a failed acquisition never releases a
        # semaphore this call does not hold.
        self._lock_predictor(id, mode)
        try:
            yield True
        finally:
            self._unlock_predictor(id)

    def _setup_for_creation(self, name):
        """Drop cache state for ``name``, create its directory and a Predictor record."""
        from mindsdb.interfaces.storage.db import session, Predictor

        if name in self.predictor_cache:
            del self.predictor_cache[name]
        # Here for no particular reason, because we want to run this sometimes but not too often
        self._invalidate_cached_predictors()

        predictor_dir = Path(self.config.paths['predictors']).joinpath(name)
        create_directory(predictor_dir)
        predictor_record = Predictor(company_id=self.company_id, name=name, is_custom=False)

        session.add(predictor_record)
        session.commit()

    def _try_outdate_db_status(self, predictor_record):
        """Mark ``predictor_record`` 'available' for update if it was built with other versions.

        Records in 'update_failed' state are returned untouched.
        """
        from mindsdb_native import __version__ as native_version
        from mindsdb import __version__ as mindsdb_version
        from mindsdb.interfaces.storage.db import session

        if predictor_record.update_status == 'update_failed':
            return predictor_record

        if predictor_record.native_version != native_version:
            predictor_record.update_status = 'available'
        if predictor_record.mindsdb_version != mindsdb_version:
            predictor_record.update_status = 'available'

        session.commit()
        return predictor_record

    def _update_db_status(self, predictor_record):
        """Stamp ``predictor_record`` with the current versions and mark it 'up_to_date'."""
        from mindsdb_native import __version__ as native_version
        from mindsdb import __version__ as mindsdb_version
        from mindsdb.interfaces.storage.db import session

        predictor_record.native_version = native_version
        predictor_record.mindsdb_version = mindsdb_version
        predictor_record.update_status = 'up_to_date'

        session.commit()
        return predictor_record

    def create(self, name):
        """Register predictor ``name`` and return a new mindsdb_native Predictor for it."""
        import mindsdb_native

        self._setup_for_creation(name)
        predictor = mindsdb_native.Predictor(name=name, run_env={'trigger': 'mindsdb'})
        return predictor

    def learn(self, name, from_data, to_predict, datasource_id, kwargs=None):
        """Train predictor ``name`` to predict ``to_predict`` from ``from_data``.

        Training runs through ``run_learn`` when ray-based, otherwise in a
        ``LearnProcess``. If ``kwargs['join_learn_process']`` is True the call
        waits for that process.

        :raises Exception: if a joined learn process exits with a non-zero code.
        :return: 0
        """
        from mindsdb.interfaces.model.learn_process import LearnProcess, run_learn

        # Work on a copy so the caller's dict is never modified.
        kwargs = dict(kwargs) if kwargs is not None else {}
        join_learn_process = kwargs.pop('join_learn_process', False)

        create_process_mark('learn')
        try:
            self._setup_for_creation(name)

            if self.ray_based:
                run_learn(name, from_data, to_predict, kwargs, datasource_id)
            else:
                p = LearnProcess(name, from_data, to_predict, kwargs, datasource_id)
                p.start()
                if join_learn_process is True:
                    p.join()
                    if p.exitcode != 0:
                        raise Exception('Learning process failed !')
        finally:
            # Always clear the mark, also when setup or training fails.
            delete_process_mark('learn')
        return 0

    def predict(self, name, pred_format, when_data=None, kwargs={}):
        """Run predictor ``name`` on ``when_data`` and return (packed) predictions.

        :param pred_format: 'explain'/'new_explain', 'dict' or 'dict&explain'.
        :param when_data: datasource creation info (a dict with 'class',
            'args', 'kwargs'), data convertible to a DataFrame, or anything
            else, which is passed to the predictor unchanged.
        :param kwargs: extra keyword arguments for the native ``predict``.
        :raises Exception: on an unknown ``pred_format``.
        """
        import mindsdb_datasources
        import mindsdb_native
        from mindsdb.interfaces.storage.db import Predictor

        create_process_mark('predict')
        try:
            if name not in self.predictor_cache:
                # Clear the cache entirely if we have less than 1.2 GB left
                if psutil.virtual_memory().available < 1.2 * pow(10, 9):
                    self.predictor_cache = {}

                predictor_record = Predictor.query.filter_by(company_id=self.company_id, name=name, is_custom=False).first()
                if predictor_record.data['status'] == 'complete':
                    self.fs_store.get(name, f'predictor_{self.company_id}_{predictor_record.id}', self.config['paths']['predictors'])
                    self.predictor_cache[name] = {
                        'predictor': mindsdb_native.Predictor(name=name, run_env={'trigger': 'mindsdb'}),
                        'created': datetime.datetime.now()
                    }

            if isinstance(when_data, dict) and 'kwargs' in when_data and 'args' in when_data:
                # Datasource creation info: instantiate the DS class by name.
                data_source = getattr(mindsdb_datasources, when_data['class'])(*when_data['args'], **when_data['kwargs'])
            elif when_data is None:
                data_source = None
            else:
                # @TODO: Replace with Datasource
                try:
                    data_source = pd.DataFrame(when_data)
                except Exception:
                    data_source = when_data

            predictions = self.predictor_cache[name]['predictor'].predict(
                when_data=data_source,
                **kwargs
            )
            if pred_format == 'explain' or pred_format == 'new_explain':
                predictions = [p.explain() for p in predictions]
            elif pred_format == 'dict':
                predictions = [p.as_dict() for p in predictions]
            elif pred_format == 'dict&explain':
                predictions = [[p.as_dict() for p in predictions], [p.explain() for p in predictions]]
            else:
                raise Exception(f'Unkown predictions format: {pred_format}')
        finally:
            delete_process_mark('predict')
        return self._pack(predictions)

    def analyse_dataset(self, ds):
        """Instantiate a datasource from creation info ``ds`` and return its (packed) analysis."""
        # These names are resolved by the eval() below.
        from mindsdb_datasources import FileDS, ClickhouseDS, MariaDS, MySqlDS, PostgresDS, MSSQLDS, MongoDS, SnowflakeDS, AthenaDS  # noqa: F401
        from mindsdb_native import F

        create_process_mark('analyse')
        try:
            # NOTE(review): eval() of a class name taken from `ds`; untrusted
            # input could execute arbitrary code. Consider a name -> class map.
            ds = eval(ds['class'])(*ds['args'], **ds['kwargs'])
            analysis = F.analyse_dataset(ds)
        finally:
            delete_process_mark('analyse')
        return self._pack(analysis)

    def get_model_data(self, name, db_fix=True):
        """Return (packed) model data for predictor ``name``.

        While the stored data is missing or still 'training', fresh data is
        read from the predictor files with ``F.get_model_data`` and saved when
        it has more keys than what is stored. ``created_at``, ``updated_at``,
        ``predict`` and ``update`` are added from the record.

        :param db_fix: give empty columns an INT subtype so database
            integrations don't break on them.
        """
        from mindsdb_native import F
        from mindsdb_native.libs.constants.mindsdb import DATA_SUBTYPES
        from mindsdb.interfaces.storage.db import session, Predictor

        predictor_record = Predictor.query.filter_by(company_id=self.company_id, name=name, is_custom=False).first()
        predictor_record = self._try_outdate_db_status(predictor_record)
        model = predictor_record.data
        if model is None or model['status'] == 'training':
            try:
                self.fs_store.get(name, f'predictor_{self.company_id}_{predictor_record.id}', self.config['paths']['predictors'])
                # `F` is imported above; `mindsdb_native` is not bound here.
                new_model_data = F.get_model_data(name)
            except Exception:
                new_model_data = None

            if predictor_record.data is None or (new_model_data is not None and len(new_model_data) > len(predictor_record.data)):
                predictor_record.data = new_model_data
                model = new_model_data
                session.commit()

        # Make some corrections for databases not to break when dealing with empty columns
        if db_fix:
            data_analysis = model['data_analysis_v2']
            for column in model['columns']:
                analysis = data_analysis.get(column)
                if isinstance(analysis, dict) and (len(analysis) == 0 or analysis.get('empty', {}).get('is_empty', False)):
                    data_analysis[column]['typing'] = {
                        'data_subtype': DATA_SUBTYPES.INT
                    }

        model['created_at'] = str(parse_datetime(str(predictor_record.created_at).split('.')[0]))
        model['updated_at'] = str(parse_datetime(str(predictor_record.updated_at).split('.')[0]))
        model['predict'] = predictor_record.to_predict
        model['update'] = predictor_record.update_status
        return self._pack(model)

    def get_models(self):
        """Return (packed) summaries of all non-custom predictors.

        Predictors whose data cannot be read are logged and skipped.
        """
        from mindsdb.interfaces.storage.db import Predictor

        predictor_records = Predictor.query.filter_by(company_id=self.company_id, is_custom=False)
        predictor_names = [x.name for x in predictor_records]

        models = []
        for model_name in predictor_names:
            try:
                if self.ray_based:
                    model_data = self.get_model_data(model_name, db_fix=False)
                else:
                    packed = self.get_model_data(model_name, db_fix=False)
                    model_data = pickle.loads(packed.data)

                reduced_model_data = {
                    k: model_data.get(k, None)
                    for k in ['name', 'version', 'is_active', 'predict', 'status', 'current_phase', 'accuracy', 'data_source', 'update']
                }

                for k in ['train_end_at', 'updated_at', 'created_at']:
                    reduced_model_data[k] = model_data.get(k, None)
                    if reduced_model_data[k] is not None:
                        try:
                            reduced_model_data[k] = parse_datetime(str(reduced_model_data[k]).split('.')[0])
                        except Exception as e:
                            # @TODO Does this ever happen
                            log.error(f'Date parsing exception while parsing: {k} in get_models: {e}')
                            reduced_model_data[k] = parse_datetime(str(reduced_model_data[k]))

                models.append(reduced_model_data)
            except Exception as e:
                log.error(f"Can't list data for model: '{model_name}' when calling `get_models(), error: {e}`")
        return self._pack(models)

    def delete_model(self, name):
        """Delete predictor ``name``: record, native files, DB integrations and stored files."""
        from mindsdb_native import F
        from mindsdb.interfaces.storage.db import session, Predictor

        predictor_record = Predictor.query.filter_by(company_id=self.company_id, name=name, is_custom=False).first()
        # Read the id before the record is deleted.
        predictor_id = predictor_record.id
        session.delete(predictor_record)
        session.commit()
        F.delete_model(name)
        self.dbw.unregister_predictor(name)
        self.fs_store.delete(f'predictor_{self.company_id}_{predictor_id}')
        return 0

    def update_model(self, name):
        """Retrain predictor ``name`` with the current versions.

        :return: None on success, the error message string on failure (the
            record, when found, is then marked 'update_failed').
        """
        from mindsdb_native import F
        from mindsdb_worker.updater.update_model import update_model
        from mindsdb.interfaces.storage.db import session, Predictor
        from mindsdb.interfaces.datastore.datastore import DataStore

        predictor_record = None
        try:
            predictor_record = Predictor.query.filter_by(company_id=self.company_id, name=name, is_custom=False).first()
            predictor_record.update_status = 'updating'
            session.commit()
            update_model(name, self.delete_model, F.delete_model, self.learn, self._lock_context, self.company_id, self.config['paths']['predictors'], predictor_record, self.fs_store, DataStore())

            predictor_record = self._update_db_status(predictor_record)
        except Exception as e:
            log.error(e)
            # Discard any failed transaction so the status can be committed.
            session.rollback()
            if predictor_record is not None:
                predictor_record.update_status = 'update_failed'
                session.commit()
            return str(e)
コード例 #3
0
class CustomModels():
    """Manage user-supplied ("custom") models.

    A custom model lives in ``<storage_dir>/<name>/`` as an importable module
    ``<name>.py`` whose ``Model`` class implements the model. The directory is
    mirrored through ``self.fs_store`` under ``custom_model_<company_id>_<name>``
    and its metadata is stored on a ``Predictor`` row with ``is_custom=True``.
    """

    def __init__(self):
        self.config = Config()
        self.fs_store = FsSotre()
        self.company_id = os.environ.get('MINDSDB_COMPANY_ID', None)
        self.dbw = DatabaseWrapper()
        self.storage_dir = self.config['paths']['custom_models']
        os.makedirs(self.storage_dir, exist_ok=True)
        # model name -> loaded Model instance
        self.model_cache = {}
        self.mindsdb_native = NativeInterface()

    def _dir(self, name):
        """Return the local directory of model ``name``."""
        return str(os.path.join(self.storage_dir, name))

    def _internal_load(self, name):
        """Fetch model ``name`` from storage, import it and return a Model instance.

        A pickled model (``model.pickle``) is loaded when possible; otherwise
        a fresh ``Model`` is created, its column types initialized and its
        optional ``setup()`` hook called.
        """
        self.fs_store.get(name, f'custom_model_{self.company_id}_{name}',
                          self.storage_dir)
        model_dir = self._dir(name)
        # Make `<name>.py` importable without growing sys.path on every load.
        if model_dir not in sys.path:
            sys.path.insert(0, model_dir)
        module = __import__(name)

        try:
            model = module.Model.load(os.path.join(model_dir, 'model.pickle'))
        except Exception:
            model = module.Model()
            model.initialize_column_types()
            if hasattr(model, 'setup'):
                model.setup()

        self.model_cache[name] = model

        return model

    def learn(self, name, from_data, to_predict, datasource_id, kwargs=None):
        """Train custom model ``name`` on datasource creation info ``from_data``.

        The model data ``status`` goes from 'training' to 'completed'; the
        trained model is saved to ``model.pickle``, uploaded through
        ``fs_store`` and re-registered with the database integrations.
        """
        if kwargs is None:
            kwargs = {}

        model_data = self.get_model_data(name)
        model_data['status'] = 'training'
        self.save_model_data(name, model_data)

        to_predict = to_predict if isinstance(to_predict, list) else [to_predict]

        data_source = getattr(mindsdb_datasources,
                              from_data['class'])(*from_data['args'],
                                                  **from_data['kwargs'])
        data_frame = data_source.df
        model = self._internal_load(name)
        model.to_predict = to_predict

        model_data = self.get_model_data(name)
        model_data['predict'] = model.to_predict
        self.save_model_data(name, model_data)

        data_analysis = self.mindsdb_native.analyse_dataset(
            data_source)['data_analysis_v2']

        model_data = self.get_model_data(name)
        model_data['data_analysis_v2'] = data_analysis
        self.save_model_data(name, model_data)

        model.fit(data_frame, to_predict, data_analysis, kwargs)

        model.save(os.path.join(self._dir(name), 'model.pickle'))
        self.model_cache[name] = model

        model_data = self.get_model_data(name)
        model_data['status'] = 'completed'
        model_data['columns'] = list(data_analysis.keys())
        self.save_model_data(name, model_data)
        self.fs_store.put(name, f'custom_model_{self.company_id}_{name}',
                          self.storage_dir)

        self.dbw.unregister_predictor(name)
        self.dbw.register_predictors([self.get_model_data(name)])

    def predict(self, name, when_data=None, from_data=None, kwargs=None):
        """Predict with model ``name``.

        :param from_data: datasource creation info dict or a DS instance;
            takes precedence over ``when_data``.
        :param when_data: a single row as ``{column: value}`` or anything
            ``pd.DataFrame`` accepts.
        :raises ValueError: if neither ``from_data`` nor ``when_data`` is given.
        :return: one dict per row: ``{column: {'predicted_value': value}}``.
        """
        self.fs_store.get(name, f'custom_model_{self.company_id}_{name}',
                          self.storage_dir)
        if kwargs is None:
            kwargs = {}
        if from_data is not None:
            if isinstance(from_data, dict):
                data_source = getattr(mindsdb_datasources, from_data['class'])(
                    *from_data['args'], **from_data['kwargs'])
            # assume that particular instance of any DataSource class is provided
            else:
                data_source = from_data
            data_frame = data_source.df
        elif when_data is not None:
            if isinstance(when_data, dict):
                # One row: wrap each value in a list, on a copy so the
                # caller's dict is left untouched.
                when_data = {k: [v] for k, v in when_data.items()}
            data_frame = pd.DataFrame(when_data)
        else:
            raise ValueError('Either `when_data` or `from_data` must be provided')

        model = self._internal_load(name)
        predictions = model.predict(data_frame, kwargs)

        return [
            {col: {'predicted_value': predictions[col].iloc[i]} for col in predictions.columns}
            for i in range(len(predictions))
        ]

    def get_model_data(self, name):
        """Return the stored data of custom model ``name``."""
        predictor_record = Predictor.query.filter_by(
            company_id=self.company_id, name=name, is_custom=True).first()
        return predictor_record.data

    def save_model_data(self, name, data):
        """Store ``data`` for custom model ``name``, creating its record if needed."""
        predictor_record = Predictor.query.filter_by(
            company_id=self.company_id, name=name, is_custom=True).first()
        if predictor_record is None:
            predictor_record = Predictor(company_id=self.company_id,
                                         name=name,
                                         is_custom=True,
                                         data=data)
            session.add(predictor_record)
        else:
            predictor_record.data = data
        session.commit()

    def get_models(self):
        """Return the stored data of every custom model."""
        predictor_names = [
            x.name
            for x in Predictor.query.filter_by(company_id=self.company_id,
                                               is_custom=True)
        ]
        return [self.get_model_data(name) for name in predictor_names]

    def delete_model(self, name):
        """Delete custom model ``name``: record, local files, integrations and stored files."""
        Predictor.query.filter_by(company_id=self.company_id,
                                  name=name,
                                  is_custom=True).delete()
        session.commit()
        # The local copy may be absent; don't let that abort the cleanup below.
        shutil.rmtree(self._dir(name), ignore_errors=True)
        self.dbw.unregister_predictor(name)
        self.fs_store.delete(f'custom_model_{self.company_id}_{name}')

    def rename_model(self, name, new_name):
        """Rename custom model ``name`` to ``new_name`` locally, in the DB and in storage."""
        self.fs_store.get(name, f'custom_model_{self.company_id}_{name}',
                          self.storage_dir)

        self.dbw.unregister_predictor(name)
        shutil.move(self._dir(name), self._dir(new_name))
        # The module file is named after the model, so rename it as well.
        shutil.move(os.path.join(self._dir(new_name), f'{name}.py'),
                    os.path.join(self._dir(new_name), f'{new_name}.py'))

        predictor_record = Predictor.query.filter_by(
            company_id=self.company_id, name=name, is_custom=True).first()
        predictor_record.name = new_name
        # Keep the stored 'name' (set by load_model) in sync with the record.
        if isinstance(predictor_record.data, dict) and 'name' in predictor_record.data:
            predictor_record.data = {**predictor_record.data, 'name': new_name}
        session.commit()

        self.dbw.register_predictors([self.get_model_data(new_name)])

        # The local directory now has the new name; upload it under the new
        # key and drop the old one.
        self.fs_store.put(new_name, f'custom_model_{self.company_id}_{new_name}',
                          self.storage_dir)
        self.fs_store.delete(f'custom_model_{self.company_id}_{name}')

    def export_model(self, name):
        """Zip the directory of model ``name`` and return the archive path.

        The archive is written next to the model directory as
        ``<storage_dir>/<name>.zip``.
        """
        # base_name must be the full path; a bare name would put the archive
        # in the current working directory instead of the returned path.
        shutil.make_archive(base_name=self._dir(name),
                            format='zip',
                            root_dir=self._dir(name))
        return str(self._dir(name)) + '.zip'

    def load_model(self, fpath, name, trained_status):
        """Install custom model ``name`` from the zip archive at ``fpath``.

        The archive's ``model.py`` is renamed to ``<name>.py``, the model data
        is stored, the directory is uploaded, and the model is registered with
        the integrations when ``trained_status == 'trained'``.
        """
        shutil.unpack_archive(fpath, self._dir(name), 'zip')
        shutil.move(os.path.join(self._dir(name), 'model.py'),
                    os.path.join(self._dir(name), f'{name}.py'))
        model = self._internal_load(name)
        model.to_predict = model.to_predict if isinstance(
            model.to_predict, list) else [model.to_predict]
        self.save_model_data(
            name, {
                'name': name,
                'data_analysis_v2': model.column_type_map,
                'predict': model.to_predict,
                'status': trained_status,
                'is_custom': True,
                'columns': list(model.column_type_map.keys())
            })

        with open(os.path.join(self._dir(name), '__init__.py'), 'w') as fp:
            fp.write('')

        self.fs_store.put(name, f'custom_model_{self.company_id}_{name}',
                          self.storage_dir)

        if trained_status == 'trained':
            self.dbw.register_predictors([self.get_model_data(name)])
コード例 #4
0
class NativeInterface():
    def __init__(self):
        """Set up config, file storage, DB wrapper and an empty predictor cache."""
        self.config = Config()
        self.fs_store = FsSotre()
        # Unset env var -> None (no company scoping).
        self.company_id = os.environ.get('MINDSDB_COMPANY_ID')
        self.dbw = DatabaseWrapper()
        # name -> {'predictor': ..., 'created': datetime}; populated lazily by predict()
        self.predictor_cache = dict()

    def _invalidate_cached_predictors(self):
        # @TODO: Cache will become stale if the respective NativeInterface is not invoked yet a bunch of predictors remained cached, no matter where we invoke it. In practice shouldn't be a big issue though
        for predictor_name in list(self.predictor_cache.keys()):
            if (datetime.datetime.now() -
                    self.predictor_cache[predictor_name]['created']
                ).total_seconds() > 1200:
                del self.predictor_cache[predictor_name]

    def _setup_for_creation(self, name):
        """Prepare for creating predictor ``name``.

        Drops any cached instance of ``name``, runs cache expiry, creates the
        predictor directory and adds/commits a non-custom ``Predictor`` record.
        """
        self.predictor_cache.pop(name, None)
        # Piggy-back cache expiry here for no particular reason: it just needs
        # to run now and then, but not too often.
        self._invalidate_cached_predictors()

        create_directory(Path(self.config.paths['predictors']) / name)

        session.add(Predictor(company_id=self.company_id,
                              name=name,
                              is_custom=False))
        session.commit()

    def create(self, name):
        """Run creation setup for ``name`` and return a new native Predictor for it."""
        self._setup_for_creation(name)
        return mindsdb_native.Predictor(
            name=name,
            run_env={'trigger': 'mindsdb'},
        )

    def learn(self, name, from_data, to_predict, datasource_id, kwargs=None):
        """Start training predictor ``name`` in a ``LearnProcess``.

        Args:
            name: predictor name.
            from_data: training data, forwarded to ``LearnProcess``.
            to_predict: target column(s), forwarded to ``LearnProcess``.
            datasource_id: id of the source datasource, forwarded to ``LearnProcess``.
            kwargs: extra options forwarded to ``LearnProcess``. The key
                ``'join_learn_process'`` (default False) is consumed here and
                not forwarded. When it is exactly ``True`` the call blocks
                until the process exits.

        Raises:
            Exception: if the joined learning process exits with a non-zero code.
        """
        # Work on a copy so neither the caller's dict nor a shared default
        # is mutated when the control key is removed.
        kwargs = dict(kwargs) if kwargs is not None else {}
        join_learn_process = kwargs.pop('join_learn_process', False)

        self._setup_for_creation(name)

        # NOTE(review): LearnProcess exposes start/join/exitcode, presumably a
        # multiprocessing.Process subclass - confirm.
        p = LearnProcess(name, from_data, to_predict, kwargs, datasource_id)
        p.start()
        if join_learn_process is True:
            p.join()
            if p.exitcode != 0:
                raise Exception('Learning process failed !')

    def predict(self, name, when_data=None, kwargs=None):
        """Run predictor ``name`` on ``when_data`` and return its predictions.

        The native predictor is loaded lazily from ``fs_store`` and cached in
        ``self.predictor_cache``. The whole cache is dropped first when less
        than 1.2 GB of memory is available. During the call the process title
        is set to ``'mindsdb_native_process'``. It is restored afterwards,
        including when the prediction raises.

        Args:
            name: predictor name.
            when_data: input data forwarded to the native ``predict``.
            kwargs: extra keyword arguments forwarded to the native ``predict``.

        Raises:
            KeyError: if the predictor is not cached and its stored status is
                not ``'complete'``, so it could not be loaded.
        """
        if kwargs is None:
            kwargs = {}

        original_process_title = None
        try:
            original_process_title = setproctitle.getproctitle()
            setproctitle.setproctitle('mindsdb_native_process')
        except Exception:
            # Renaming the process is best-effort; never fail a prediction over it.
            pass

        try:
            if name not in self.predictor_cache:
                # Clear the cache entirely if we have less than 1.2 GB left
                if psutil.virtual_memory().available < 1.2 * pow(10, 9):
                    self.predictor_cache = {}

                # NOTE(review): an unknown name yields None here and the
                # `.data` access below raises AttributeError.
                predictor_record = Predictor.query.filter_by(
                    company_id=self.company_id, name=name,
                    is_custom=False).first()
                if predictor_record.data['status'] == 'complete':
                    self.fs_store.get(
                        name,
                        f'predictor_{self.company_id}_{predictor_record.id}',
                        self.config['paths']['predictors'])
                    self.predictor_cache[name] = {
                        'predictor':
                        mindsdb_native.Predictor(name=name,
                                                 run_env={'trigger': 'mindsdb'}),
                        'created':
                        datetime.datetime.now()
                    }

            return self.predictor_cache[name]['predictor'].predict(
                when_data=when_data, **kwargs)
        finally:
            # Restore the title on success and failure alike.
            if original_process_title is not None:
                try:
                    setproctitle.setproctitle(original_process_title)
                except Exception:
                    pass

    # @TODO Move somewhere else to avoid circular import issues in the future
    def analyse_dataset(self, ds):
        """Analyse datasource ``ds`` by delegating to ``F.analyse_dataset``."""
        analysis = F.analyse_dataset(ds)
        return analysis

    def get_model_data(self, name, db_fix=True):
        """Return the stored metadata dict of non-custom predictor ``name``.

        When nothing is stored yet, or the stored status is ``'training'``,
        the model is pulled from ``fs_store`` and re-read via
        ``mindsdb_native.F.get_model_data``. Any failure there is ignored.
        The fresh data replaces the stored data (and is committed) when
        nothing was stored, or when it has more top-level keys.

        With ``db_fix`` (default True), each column whose analysis is an empty
        dict or is flagged ``empty.is_empty`` gets
        ``typing = {'data_subtype': DATA_SUBTYPES.INT}``, written in place into
        the returned dict, so databases do not break on empty columns.

        NOTE(review): an unknown ``name`` makes ``.first()`` return None and
        the ``.data`` access raise AttributeError.
        """
        predictor_record = Predictor.query.filter_by(
            company_id=self.company_id, name=name, is_custom=False).first()
        model = predictor_record.data

        needs_refresh = model is None or model['status'] == 'training'
        if needs_refresh:
            try:
                self.fs_store.get(
                    name, f'predictor_{self.company_id}_{predictor_record.id}',
                    self.config['paths']['predictors'])
                fresh_data = mindsdb_native.F.get_model_data(name)
            except Exception:
                fresh_data = None

            stored_data = predictor_record.data
            # Heuristic: a larger dict is assumed to be more complete.
            if stored_data is None or (fresh_data is not None
                                       and len(fresh_data) > len(stored_data)):
                predictor_record.data = fresh_data
                model = fresh_data
                session.commit()

        if not db_fix:
            return model

        analysis_by_column = model['data_analysis_v2']
        for column in model['columns']:
            column_analysis = analysis_by_column.get(column)
            if not isinstance(column_analysis, dict):
                continue
            is_empty_column = (len(column_analysis) == 0
                               or column_analysis.get('empty', {}).get(
                                   'is_empty', False))
            if is_empty_column:
                analysis_by_column[column]['typing'] = {
                    'data_subtype': DATA_SUBTYPES.INT
                }

        return model

    def get_models(self):
        """Return a summary dict for each non-custom predictor of this company.

        Each summary copies a fixed set of fields from
        ``get_model_data(name, db_fix=False)``; missing fields become None.
        The timestamp fields are parsed with ``parse_datetime`` after
        stripping any ``.<fraction>`` suffix. A model whose status is
        ``'training'`` and whose ``created_at`` precedes the config's
        ``mindsdb_last_started_at`` is omitted (presumably abandoned by an
        earlier run). Any error while summarising a model is printed and that
        model is skipped.
        """
        plain_fields = ('name', 'version', 'is_active', 'predict', 'status',
                        'current_phase', 'accuracy', 'data_source')
        date_fields = ('train_end_at', 'updated_at', 'created_at')

        records = Predictor.query.filter_by(company_id=self.company_id,
                                            is_custom=False)
        names = [record.name for record in records]

        summaries = []
        for model_name in names:
            try:
                model_data = self.get_model_data(model_name, db_fix=False)
                if (model_data['status'] == 'training'
                        and parse_datetime(model_data['created_at'])
                        < parse_datetime(self.config['mindsdb_last_started_at'])):
                    continue

                summary = {
                    field: model_data.get(field, None)
                    for field in plain_fields
                }
                for field in date_fields:
                    value = model_data.get(field, None)
                    if value is not None:
                        try:
                            value = parse_datetime(str(value).split('.')[0])
                        except Exception as e:
                            # @TODO Does this ever happen
                            print(
                                f'Date parsing exception while parsing: {field} in get_models: ',
                                e)
                            value = parse_datetime(str(value))
                    summary[field] = value

                summaries.append(summary)
            except Exception as e:
                print(
                    f"Can't list data for model: '{model_name}' when calling `get_models(), error: {e}`"
                )

        return summaries

    def delete_model(self, name):
        """Delete the non-custom predictor ``name`` belonging to this company.

        Deletes and commits its DB record, calls ``F.delete_model(name)``,
        unregisters it from ``self.dbw`` and deletes its
        ``predictor_<company_id>_<id>`` entry from ``self.fs_store``.

        NOTE(review): if no matching record exists, ``.first()`` presumably
        returns None and the ``.id`` access raises AttributeError.
        """
        predictor_record = Predictor.query.filter_by(
            company_id=self.company_id, name=name, is_custom=False).first()
        # Capture the id before the record is deleted; it forms the fs_store
        # key used below. NOTE(review): `id` shadows the builtin; consider renaming.
        id = predictor_record.id
        session.delete(predictor_record)
        session.commit()
        F.delete_model(name)
        self.dbw.unregister_predictor(name)
        self.fs_store.delete(f'predictor_{self.company_id}_{id}')