Example #1
def run_fit(predictor_id: int, df: pd.DataFrame) -> None:
    try:
        predictor_record = session.query(db.Predictor).filter_by(id=predictor_id).first()
        assert predictor_record is not None

        fs_store = FsStore()
        config = Config()

        predictor_record.data = {'training_log': 'training'}
        session.commit()
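        # Instantiate a predictor from the stored Lightwood code and train it on the dataframe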
        predictor: lightwood.PredictorInterface = lightwood.predictor_from_code(predictor_record.code)
        predictor.learn(df)

        session.refresh(predictor_record)

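        # Save the trained predictor to local disk, then push it to the shared FsStore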
        fs_name = f'predictor_{predictor_record.company_id}_{predictor_record.id}'
        pickle_path = os.path.join(config['paths']['predictors'], fs_name)
        predictor.save(pickle_path)

        fs_store.put(fs_name, fs_name, config['paths']['predictors'])

        predictor_record.data = predictor.model_analysis.to_dict()
        predictor_record.dtype_dict = predictor.dtype_dict
        session.commit()

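        # Register the trained predictor so database integrations can query it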
        dbw = DatabaseWrapper(predictor_record.company_id)
        mi = ModelInterfaceWrapper(ModelInterface(), predictor_record.company_id)
        dbw.register_predictors([mi.get_model_data(predictor_record.name)])
    except Exception as e:
        session.refresh(predictor_record)
        predictor_record.data = {'error': f'{traceback.format_exc()}\nMain error: {e}'}
        session.commit()
        raise e
Example #2
def _get_integration_record_data(integration_record, sensitive_info=True):
    if integration_record is None or integration_record.data is None:
        return None
    data = deepcopy(integration_record.data)
    if data.get('password', None) is None:
        data['password'] = ''
    data['date_last_update'] = deepcopy(integration_record.updated_at)

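    # If the integration uses local certificate files (MySQL SSL certs or a
    # Cassandra/Scylla secure connect bundle), sync them down from the FsStore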
    bundle_path = data.get('secure_connect_bundle')
    mysql_ssl_ca = data.get('ssl_ca')
    mysql_ssl_cert = data.get('ssl_cert')
    mysql_ssl_key = data.get('ssl_key')
    if ((data.get('type') in ('mysql', 'mariadb')
            and (_is_not_empty_str(mysql_ssl_ca)
                 or _is_not_empty_str(mysql_ssl_cert)
                 or _is_not_empty_str(mysql_ssl_key)))
            or (data.get('type') in ('cassandra', 'scylla')
                and bundle_path is not None)):
        fs_store = FsStore()
        integrations_dir = Config()['paths']['integrations']
        folder_name = f'integration_files_{integration_record.company_id}_{integration_record.id}'
        integration_dir = os.path.join(integrations_dir, folder_name)
        fs_store.get(folder_name, integration_dir, integrations_dir)

    if not sensitive_info:
        if 'password' in data:
            data['password'] = None
        if (data.get('type') == 'redis'
                and isinstance(data.get('connection'), dict)
                and 'password' in data['connection']):
            data['connection'] = None

    data['id'] = integration_record.id

    return data
Example #3
def run_update(name: str, company_id: int):
    original_name = name
    name = f'{company_id}@@@@@{name}'

    fs_store = FsStore()
    config = Config()
    data_store = DataStoreWrapper(DataStore(), company_id)

    try:
        predictor_record = Predictor.query.filter_by(company_id=company_id, name=original_name).first()
        assert predictor_record is not None

        predictor_record.update_status = 'updating'

        session.commit()
        ds = data_store.get_datasource_obj(None, raw=False, id=predictor_record.datasource_id)
        df = ds.df

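        # Reuse the stored learn arguments as the problem definition for re-training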
        problem_definition = predictor_record.learn_args

        problem_definition['target'] = predictor_record.to_predict[0]

        if 'join_learn_process' in problem_definition:
            del problem_definition['join_learn_process']

        # Adapt legacy learn_args keys to the current problem definition format
        if 'stop_training_in_x_seconds' in problem_definition:
            problem_definition['time_aim'] = problem_definition['stop_training_in_x_seconds']

        json_ai = lightwood.json_ai_from_problem(df, problem_definition)
        predictor_record.json_ai = json_ai.to_dict()
        predictor_record.code = lightwood.code_from_json_ai(json_ai)
        predictor_record.data = {'training_log': 'training'}
        session.commit()
        predictor: lightwood.PredictorInterface = lightwood.predictor_from_code(predictor_record.code)
        predictor.learn(df)

        fs_name = f'predictor_{predictor_record.company_id}_{predictor_record.id}'
        pickle_path = os.path.join(config['paths']['predictors'], fs_name)
        predictor.save(pickle_path)
        fs_store.put(fs_name, fs_name, config['paths']['predictors'])
        predictor_record.data = predictor.model_analysis.to_dict()  # type: ignore
        session.commit()

        predictor_record.lightwood_version = lightwood.__version__
        predictor_record.mindsdb_version = mindsdb_version
        predictor_record.update_status = 'up_to_date'
        session.commit()

    except Exception as e:
        log.error(e)
        predictor_record.update_status = 'update_failed'  # type: ignore
        session.commit()
        return str(e)
Example #4
def run_fit(predictor_id: int, df: pd.DataFrame) -> None:
    try:
        predictor_record = Predictor.query.with_for_update().get(predictor_id)
        assert predictor_record is not None

        fs_store = FsStore()
        config = Config()

        predictor_record.data = {'training_log': 'training'}
        session.commit()
        predictor: lightwood.PredictorInterface = lightwood.predictor_from_code(
            predictor_record.code)
        predictor.learn(df)

        session.refresh(predictor_record)

        fs_name = f'predictor_{predictor_record.company_id}_{predictor_record.id}'
        pickle_path = os.path.join(config['paths']['predictors'], fs_name)
        predictor.save(pickle_path)

        fs_store.put(fs_name, fs_name, config['paths']['predictors'])

        predictor_record.data = predictor.model_analysis.to_dict()

        # Get the training time for each mixer that was tried; this is only
        # available after training completes
        fit_mixers = list(predictor.runtime_log[x]
                          for x in predictor.runtime_log
                          if isinstance(x, tuple) and x[0] == "fit_mixer")
        submodel_data = predictor_record.data.get("submodel_data", [])
        # Attach the training time to each mixer's info
        if submodel_data and fit_mixers and len(submodel_data) == len(fit_mixers):
            for i, tr_time in enumerate(fit_mixers):
                submodel_data[i]["training_time"] = tr_time
        predictor_record.data["submodel_data"] = submodel_data

        predictor_record.dtype_dict = predictor.dtype_dict
        session.commit()

        dbw = DatabaseWrapper(predictor_record.company_id)
        mi = WithKWArgsWrapper(ModelInterface(),
                               company_id=predictor_record.company_id)
    except Exception as e:
        session.refresh(predictor_record)
        predictor_record.data = {
            'error': f'{traceback.format_exc()}\nMain error: {e}'
        }
        session.commit()
        raise e

    try:
        dbw.register_predictors([mi.get_model_data(predictor_record.name)])
    except Exception as e:
        log.warning(e)
Example #5
def remove_db_integration(name, company_id):
    integration_record = session.query(Integration).filter_by(
        company_id=company_id, name=name).first()
    assert integration_record is not None
    integrations_dir = Config()['paths']['integrations']
    folder_name = f'integration_files_{company_id}_{integration_record.id}'
    integration_dir = os.path.join(integrations_dir, folder_name)
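    # Remove the local copy of the integration files, then the FsStore copy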
    if os.path.isdir(integration_dir):
        shutil.rmtree(integration_dir)
    try:
        FsStore().delete(folder_name)
    except Exception:
        pass
    session.delete(integration_record)
    session.commit()
Example #6
    def __init__(self):
        self.config = Config()
        self.fs_store = FsStore()
        self.dir = self.config['paths']['datasources']
        self.model_interface = ModelInterface()
Example #7
class DataStore():
    def __init__(self):
        self.config = Config()
        self.fs_store = FsStore()
        self.dir = self.config['paths']['datasources']
        # Used by create_datasource when resolving cert/bundle paths
        self.integrations_dir = self.config['paths']['integrations']
        self.model_interface = ModelInterface()

    def get_analysis(self, name, company_id=None):
        datasource_record = session.query(Datasource).filter_by(
            company_id=company_id, name=name).first()
        if datasource_record.analysis is None:
            return None
        analysis = json.loads(datasource_record.analysis)
        return analysis

    def start_analysis(self, name, company_id=None):
        datasource_record = session.query(Datasource).filter_by(
            company_id=company_id, name=name).first()
        if datasource_record.analysis is not None:
            return None
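        # A Semaphor row acts as a cross-process lock so only one analysis runs per datasource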
        semaphor_record = session.query(Semaphor).filter_by(
            company_id=company_id,
            entity_id=datasource_record.id,
            entity_type='datasource').first()
        if semaphor_record is None:
            semaphor_record = Semaphor(company_id=company_id,
                                       entity_id=datasource_record.id,
                                       entity_type='datasource',
                                       action='write')
            session.add(semaphor_record)
            session.commit()
        else:
            return
        try:
            analysis = self.model_interface.analyse_dataset(
                ds=self.get_datasource_obj(name,
                                           raw=True,
                                           company_id=company_id),
                company_id=company_id)
            datasource_record = session.query(Datasource).filter_by(
                company_id=company_id, name=name).first()
            datasource_record.analysis = json.dumps(analysis,
                                                    cls=CustomJSONEncoder)
            session.commit()
        except Exception as e:
            log.error(e)
        finally:
            semaphor_record = session.query(Semaphor).filter_by(
                company_id=company_id,
                entity_id=datasource_record.id,
                entity_type='datasource').first()
            session.delete(semaphor_record)
            session.commit()

    def get_datasources(self, name=None, company_id=None):
        datasource_arr = []
        if name is not None:
            datasource_record_arr = session.query(Datasource).filter_by(
                company_id=company_id, name=name)
        else:
            datasource_record_arr = session.query(Datasource).filter_by(
                company_id=company_id)
        for datasource_record in datasource_record_arr:
            try:
                if datasource_record.data is None:
                    continue
                datasource = json.loads(datasource_record.data)
                datasource['created_at'] = datasource_record.created_at
                datasource['updated_at'] = datasource_record.updated_at
                datasource['name'] = datasource_record.name
                datasource['id'] = datasource_record.id
                datasource_arr.append(datasource)
            except Exception as e:
                log.error(e)
        return datasource_arr

    def get_data(self,
                 name,
                 where=None,
                 limit=None,
                 offset=None,
                 company_id=None):
        offset = 0 if offset is None else offset
        ds = self.get_datasource_obj(name, company_id=company_id)

        if limit is not None:
            # @TODO Add `offset` to the datasource's `filter` method and get rid of this workaround
            filtered_ds = ds.filter(where=where,
                                    limit=limit + offset).iloc[offset:]
        else:
            filtered_ds = ds.filter(where=where)

        filtered_ds = filtered_ds.where(pd.notnull(filtered_ds), None)
        data = filtered_ds.to_dict(orient='records')
        return {
            'data': data,
            'rowcount': len(ds),
            'columns_names': list(data[0].keys()) if len(data) > 0 else []
        }

    def get_datasource(self, name, company_id=None):
        datasource_arr = self.get_datasources(name, company_id=company_id)
        if len(datasource_arr) == 1:
            return datasource_arr[0]
        # @TODO: Remove when the db switch is more stable; this should never happen, but it's a good sanity check while this is kinda buggy
        elif len(datasource_arr) > 1:
            log.error(f'Two or more datasources with the same name ({len(datasource_arr)}) | Full list: {datasource_arr}')
            raise Exception('Two or more datasources with the same name')
        return None

    def delete_datasource(self, name, company_id=None):
        datasource_record = Datasource.query.filter_by(company_id=company_id,
                                                       name=name).first()
        if not Config()["force_datasource_removing"]:
            linked_models = Predictor.query.filter_by(
                company_id=company_id,
                datasource_id=datasource_record.id).all()
            if linked_models:
                raise Exception(
                    "Can't delete datasource {} because the following models are linked to it: {}"
                    .format(name, [model.name for model in linked_models]))
        session.query(Semaphor).filter_by(company_id=company_id,
                                          entity_id=datasource_record.id,
                                          entity_type='datasource').delete()
        session.delete(datasource_record)
        session.commit()
        self.fs_store.delete(f'datasource_{company_id}_{datasource_record.id}')
        try:
            shutil.rmtree(os.path.join(self.dir, f'{company_id}@@@@@{name}'))
        except Exception:
            pass

    def get_vacant_name(self, base=None, company_id=None):
        ''' Returns a name for a new datasource that starts with `base` and is not taken yet
        '''
        if base is None:
            base = 'datasource'
        datasources = session.query(
            Datasource.name).filter_by(company_id=company_id).all()
        datasources_names = [x[0] for x in datasources]
        if base not in datasources_names:
            return base
        for i in range(1, 1000):
            candidate = f'{base}_{i}'
            if candidate not in datasources_names:
                return candidate
        raise Exception(
            f"Cannot find an appropriate name for datasource '{base}'")

    def create_datasource(self,
                          source_type,
                          source,
                          file_path=None,
                          company_id=None,
                          ds_meta_dir=None):
        datasource_controller = DatasourceController()
        if source_type == 'file':
            source = os.path.join(ds_meta_dir, source)
            shutil.move(file_path, source)
            ds = FileDS(source)

            creation_info = {'class': 'FileDS', 'args': [source], 'kwargs': {}}

        elif datasource_controller.get_db_integration(source_type,
                                                      company_id) is not None:
            integration = datasource_controller.get_db_integration(
                source_type, company_id)

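            # Map integration types to datasource classes; several types share an implementation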
            ds_class_map = {
                'clickhouse': ClickhouseDS,
                'mariadb': MariaDS,
                'mysql': MySqlDS,
                'singlestore': MySqlDS,
                'postgres': PostgresDS,
                'cockroachdb': PostgresDS,
                'mssql': MSSQLDS,
                'mongodb': MongoDS,
                'snowflake': SnowflakeDS,
                'athena': AthenaDS,
                'cassandra': CassandraDS,
                'scylladb': ScyllaDS,
                'trinodb': TrinoDS
            }

            try:
                dsClass = ds_class_map[integration['type']]
            except KeyError:
                raise KeyError(
                    f"Unknown DS type: {source_type}, type is {integration['type']}"
                )

            if dsClass is None:
                raise Exception(
                    f"Unsupported datasource: {source_type}, type is {integration['type']}, please install required dependencies!"
                )

            if integration['type'] in ['clickhouse']:
                creation_info = {
                    'class': dsClass.__name__,
                    'args': [],
                    'kwargs': {
                        'query': source['query'],
                        'user': integration['user'],
                        'password': integration['password'],
                        'host': integration['host'],
                        'port': integration['port']
                    }
                }
                ds = dsClass(**creation_info['kwargs'])

            elif integration['type'] in [
                    'mssql', 'postgres', 'cockroachdb', 'mariadb', 'mysql',
                    'singlestore', 'cassandra', 'scylladb'
            ]:
                creation_info = {
                    'class': dsClass.__name__,
                    'args': [],
                    'kwargs': {
                        'query': source['query'],
                        'user': integration['user'],
                        'password': integration['password'],
                        'host': integration['host'],
                        'port': integration['port']
                    }
                }
                kwargs = creation_info['kwargs']

                integration_folder_name = f'integration_files_{company_id}_{integration["id"]}'
                if integration['type'] in ('mysql', 'mariadb'):
                    kwargs['ssl'] = integration.get('ssl')
                    kwargs['ssl_ca'] = integration.get('ssl_ca')
                    kwargs['ssl_cert'] = integration.get('ssl_cert')
                    kwargs['ssl_key'] = integration.get('ssl_key')
                    for key in ['ssl_ca', 'ssl_cert', 'ssl_key']:
                        if isinstance(kwargs[key], str) and len(kwargs[key]) > 0:
                            kwargs[key] = os.path.join(
                                self.integrations_dir, integration_folder_name,
                                kwargs[key])
                elif integration['type'] in ('cassandra', 'scylla'):
                    kwargs['secure_connect_bundle'] = integration.get(
                        'secure_connect_bundle')
                    if (isinstance(kwargs['secure_connect_bundle'], str)
                            and len(kwargs['secure_connect_bundle']) > 0):
                        kwargs['secure_connect_bundle'] = os.path.join(
                            self.integrations_dir, integration_folder_name,
                            kwargs['secure_connect_bundle'])

                if 'database' in integration:
                    kwargs['database'] = integration['database']

                if 'database' in source:
                    kwargs['database'] = source['database']

                ds = dsClass(**kwargs)

            elif integration['type'] == 'snowflake':
                creation_info = {
                    'class': dsClass.__name__,
                    'args': [],
                    'kwargs': {
                        'query': source['query'].replace('"', "'"),
                        'schema': source.get('schema', integration['schema']),
                        'warehouse': source.get('warehouse', integration['warehouse']),
                        'database': source.get('database', integration['database']),
                        'host': integration['host'],
                        'password': integration['password'],
                        'user': integration['user'],
                        'account': integration['account']
                    }
                }

                ds = dsClass(**creation_info['kwargs'])

            elif integration['type'] == 'mongodb':
                if isinstance(source['find'], str):
                    source['find'] = json.loads(source['find'])
                creation_info = {
                    'class': dsClass.__name__,
                    'args': [],
                    'kwargs': {
                        'database': source['database'],
                        'collection': source['collection'],
                        'query': source['find'],
                        'user': integration['user'],
                        'password': integration['password'],
                        'host': integration['host'],
                        'port': integration['port']
                    }
                }

                ds = dsClass(**creation_info['kwargs'])

            elif integration['type'] == 'athena':
                creation_info = {
                    'class': dsClass.__name__,
                    'args': [],
                    'kwargs': {
                        'query': source['query'],
                        'staging_dir': source['staging_dir'],
                        'database': source['database'],
                        'access_key': source['access_key'],
                        'secret_key': source['secret_key'],
                        'region_name': source['region_name']
                    }
                }

                ds = dsClass(**creation_info['kwargs'])

            elif integration['type'] == 'trinodb':
                creation_info = {
                    'class': dsClass.__name__,
                    'args': [],
                    'kwargs': {
                        'query': source['query'],
                        'user': integration['user'],
                        'password': integration['password'],
                        'host': integration['host'],
                        'port': integration['port'],
                        'schema': integration['schema'],
                        'catalog': integration['catalog']
                    }
                }

                ds = dsClass(**creation_info['kwargs'])
        else:
            # This probably only happens for urls
            ds = FileDS(source)
            creation_info = {'class': 'FileDS', 'args': [source], 'kwargs': {}}
        return ds, creation_info

    def save_datasource(self,
                        name,
                        source_type,
                        source,
                        file_path=None,
                        company_id=None):
        if source_type == 'file' and (file_path is None):
            raise Exception(
                '`file_path` argument required when source_type == "file"')

        datasource_record = session.query(Datasource).filter_by(
            company_id=company_id, name=name).first()
        if datasource_record is not None:
            raise Exception(f'Datasource with name {name} already exists')

        try:
            datasource_record = Datasource(
                company_id=company_id,
                name=name,
                datasources_version=mindsdb_datasources.__version__,
                mindsdb_version=mindsdb_version)
            session.add(datasource_record)
            session.commit()

            ds_meta_dir = os.path.join(self.dir, f'{company_id}@@@@@{name}')
            os.mkdir(ds_meta_dir)

            ds, creation_info = self.create_datasource(source_type, source,
                                                       file_path, company_id,
                                                       ds_meta_dir)

            if hasattr(ds, 'get_columns') and hasattr(ds, 'get_row_count'):
                try:
                    column_names = ds.get_columns()
                    row_count = ds.get_row_count()
                except Exception:
                    df = ds.df
                    column_names = list(df.keys())
                    row_count = len(df)
            else:
                df = ds.df
                column_names = list(df.keys())
                row_count = len(df)

            if '' in column_names or len(column_names) != len(
                    set(column_names)):
                shutil.rmtree(ds_meta_dir)
                raise Exception(
                    'Each column in datasource must have unique non-empty name'
                )

            datasource_record.creation_info = json.dumps(creation_info)
            datasource_record.data = json.dumps({
                'source_type': source_type,
                'source': source,
                'row_count': row_count,
                'columns': [dict(name=x) for x in column_names]
            })

            self.fs_store.put(
                f'{company_id}@@@@@{name}',
                f'datasource_{company_id}_{datasource_record.id}', self.dir)
            session.commit()

        except Exception as e:
            log.error(f'Error creating datasource {name}, exception: {e}')
            try:
                self.delete_datasource(name, company_id=company_id)
            except Exception:
                pass
            raise e

        return self.get_datasource_obj(name, raw=True, company_id=company_id)

    def get_datasource_obj(self,
                           name=None,
                           id=None,
                           raw=False,
                           company_id=None):
        try:
            if name is not None:
                datasource_record = session.query(Datasource).filter_by(
                    company_id=company_id, name=name).first()
            else:
                datasource_record = session.query(Datasource).filter_by(
                    company_id=company_id, id=id).first()

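            # Make sure the datasource files are present locally before reconstructing the object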
            self.fs_store.get(
                f'{company_id}@@@@@{name}',
                f'datasource_{company_id}_{datasource_record.id}', self.dir)
            creation_info = json.loads(datasource_record.creation_info)
            if raw:
                return creation_info
            else:
                return eval(creation_info['class'])(*creation_info['args'],
                                                    **creation_info['kwargs'])
        except Exception as e:
            log.error(f'Error getting datasource {name}, exception: {e}')
            return None
Example #8
    def __init__(self, ray_based: bool) -> None:
        self.config = Config()
        self.fs_store = FsStore()
        self.predictor_cache = {}
        self.ray_based = ray_based
Example #9
class ModelController():
    config: Config
    fs_store: FsStore
    predictor_cache: Dict[str, Dict[str, Union[Any]]]
    ray_based: bool

    def __init__(self, ray_based: bool) -> None:
        self.config = Config()
        self.fs_store = FsStore()
        self.predictor_cache = {}
        self.ray_based = ray_based

    def _invalidate_cached_predictors(self) -> None:
        # @TODO: The cache can become stale if the respective ModelInterface is not invoked while a bunch of predictors remain cached, no matter where we invoke it. In practice this shouldn't be a big issue though.
        for predictor_name in list(self.predictor_cache.keys()):
            if (datetime.datetime.now() -
                    self.predictor_cache[predictor_name]['created']
                ).total_seconds() > 1200:
                del self.predictor_cache[predictor_name]

    def _lock_predictor(self, id: int, mode: str) -> None:
        from mindsdb.interfaces.storage.db import session, Semaphor

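        # Spin until the lock is acquired; inserting a second Semaphor row is expected
        # to fail (presumably via a uniqueness constraint), which serializes writers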
        while True:
            semaphor_record = session.query(Semaphor).filter_by(
                entity_id=id, entity_type='predictor').first()
            if semaphor_record is not None:
                if mode == 'read' and semaphor_record.action == 'read':
                    return True
            try:
                semaphor_record = Semaphor(entity_id=id,
                                           entity_type='predictor',
                                           action=mode)
                session.add(semaphor_record)
                session.commit()
                return True
            except Exception:
                pass
            time.sleep(1)

    def _unlock_predictor(self, id: int) -> None:
        from mindsdb.interfaces.storage.db import session, Semaphor
        semaphor_record = session.query(Semaphor).filter_by(
            entity_id=id, entity_type='predictor').first()
        if semaphor_record is not None:
            session.delete(semaphor_record)
            session.commit()

    @contextmanager
    def _lock_context(self, id, mode: str):
        try:
            self._lock_predictor(id, mode)
            yield True
        finally:
            self._unlock_predictor(id)

    def _get_from_data_df(self, from_data: dict) -> DataFrame:
        ds_cls = getattr(mindsdb_datasources, from_data['class'])
        ds = ds_cls(*from_data['args'], **from_data['kwargs'])
        return ds.df

    def _unpack_old_args(
        self,
        from_data: dict,
        kwargs: dict,
        to_predict: Optional[Union[str, list]] = None
    ) -> Tuple[pd.DataFrame, ProblemDefinition, bool]:
        problem_definition = kwargs or {}
        if isinstance(to_predict, str):
            problem_definition['target'] = to_predict
        elif isinstance(to_predict, list) and len(to_predict) == 1:
            problem_definition['target'] = to_predict[0]
        elif problem_definition.get('target') is None:
            raise Exception(
                f"Predict target must be 'str' or 'list' with 1 element. Got: {to_predict}"
            )

        join_learn_process = kwargs.get('join_learn_process', False)
        if 'join_learn_process' in kwargs:
            del kwargs['join_learn_process']

        # Adapt kwargs to problem definition
        if 'timeseries_settings' in kwargs:
            problem_definition['timeseries_settings'] = kwargs[
                'timeseries_settings']

        if 'stop_training_in_x_seconds' in kwargs:
            problem_definition['time_aim'] = kwargs[
                'stop_training_in_x_seconds']

        if kwargs.get('ignore_columns') is not None:
            problem_definition['ignore_features'] = kwargs['ignore_columns']

        if (problem_definition.get('ignore_features') is not None
                and isinstance(problem_definition['ignore_features'],
                               list) is False):
            problem_definition['ignore_features'] = [
                problem_definition['ignore_features']
            ]

        df = self._get_from_data_df(from_data)

        return df, problem_definition, join_learn_process

    @mark_process(name='learn')
    def learn(self,
              name: str,
              from_data: dict,
              to_predict: str,
              datasource_id: int,
              kwargs: dict,
              company_id: int,
              delete_ds_on_fail: Optional[bool] = False) -> None:
        predictor_record = db.session.query(db.Predictor).filter_by(
            company_id=company_id, name=name).first()
        if predictor_record is not None:
            raise Exception('Predictor name must be unique.')

        df, problem_definition, join_learn_process = self._unpack_old_args(
            from_data, kwargs, to_predict)

        problem_definition = ProblemDefinition.from_dict(problem_definition)
        predictor_record = db.Predictor(
            company_id=company_id,
            name=name,
            datasource_id=datasource_id,
            mindsdb_version=mindsdb_version,
            lightwood_version=lightwood_version,
            to_predict=problem_definition.target,
            learn_args=problem_definition.to_dict(),
            data={'name': name})

        db.session.add(predictor_record)
        db.session.commit()
        predictor_id = predictor_record.id

        p = LearnProcess(df, problem_definition, predictor_id,
                         delete_ds_on_fail)
        p.start()
        if join_learn_process:
            p.join()
            if not IS_PY36:
                p.close()
        db.session.refresh(predictor_record)

        data = {}
        if predictor_record.update_status == 'available':
            data['status'] = 'complete'
        elif predictor_record.json_ai is None and predictor_record.code is None:
            data['status'] = 'generating'
        elif predictor_record.data is None:
            data['status'] = 'editable'
        elif 'training_log' in predictor_record.data:
            data['status'] = 'training'
        elif 'error' not in predictor_record.data:
            data['status'] = 'complete'
        else:
            data['status'] = 'error'

    @mark_process(name='predict')
    def predict(self, name: str, when_data: Union[dict, list, pd.DataFrame],
                pred_format: str, company_id: int):
        original_name = name
        name = f'{company_id}@@@@@{name}'

        predictor_record = db.session.query(db.Predictor).filter_by(
            company_id=company_id, name=original_name).first()
        assert predictor_record is not None
        predictor_data = self.get_model_data(name, company_id)
        fs_name = f'predictor_{company_id}_{predictor_record.id}'

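        # Invalidate the cached predictor if the stored record changed since it was cached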
        if (name in self.predictor_cache
                and self.predictor_cache[name]['updated_at'] != predictor_record.updated_at):
            del self.predictor_cache[name]

        if name not in self.predictor_cache:
            # Clear the cache entirely if we have less than 1.2 GB left
            if psutil.virtual_memory().available < 1.2 * pow(10, 9):
                self.predictor_cache = {}

            if predictor_data['status'] == 'complete':
                self.fs_store.get(fs_name, fs_name,
                                  self.config['paths']['predictors'])
                self.predictor_cache[name] = {
                    'predictor': lightwood.predictor_from_state(
                        os.path.join(self.config['paths']['predictors'], fs_name),
                        predictor_record.code),
                    'updated_at': predictor_record.updated_at,
                    'created': datetime.datetime.now(),
                    'code': predictor_record.code,
                    'pickle': str(os.path.join(self.config['paths']['predictors'], fs_name))
                }
            else:
                raise Exception(
                    f'Trying to predict using predictor {original_name} with status: {predictor_data["status"]}. Error is: {predictor_data.get("error", "unknown")}'
                )

        if isinstance(when_data, dict) and 'kwargs' in when_data and 'args' in when_data:
            ds_cls = getattr(mindsdb_datasources, when_data['class'])
            df = ds_cls(*when_data['args'], **when_data['kwargs']).df
        else:
            if isinstance(when_data, dict):
                when_data = [when_data]
            df = pd.DataFrame(when_data)

        predictions = self.predictor_cache[name]['predictor'].predict(df)
        predictions = predictions.to_dict(orient='records')
        # Below is useful for debugging caching and storage issues
        # del self.predictor_cache[name]

        target = predictor_record.to_predict[0]
        if pred_format in ('explain', 'dict', 'dict&explain'):
            explain_arr = []
            dict_arr = []
            for i, row in enumerate(predictions):
                explain_arr.append({
                    target: {
                        'predicted_value': row['prediction'],
                        'confidence': row.get('confidence', None),
                        'confidence_lower_bound': row.get('lower', None),
                        'confidence_upper_bound': row.get('upper', None),
                        'anomaly': row.get('anomaly', None),
                        'truth': row.get('truth', None)
                    }
                })

                td = {'predicted_value': row['prediction']}
                for col in df.columns:
                    if col in row:
                        td[col] = row[col]
                    elif f'order_{col}' in row:
                        td[col] = row[f'order_{col}']
                    elif f'group_{col}' in row:
                        td[col] = row[f'group_{col}']
                    else:
                        original_index = row.get('original_index')
                        if original_index is None:
                            log.warning('original_index is None')
                            original_index = i
                        td[col] = df.iloc[original_index][col]
                dict_arr.append({target: td})
            if pred_format == 'explain':
                return explain_arr
            elif pred_format == 'dict':
                return dict_arr
            elif pred_format == 'dict&explain':
                return dict_arr, explain_arr
        # New format -- Try switching to this in 2-3 months for speed, for now above is ok
        else:
            return predictions

    @mark_process(name='analyse')
    def analyse_dataset(self, ds: dict,
                        company_id: int) -> lightwood.DataAnalysis:
        ds_cls = getattr(mindsdb_datasources, ds['class'])
        df = ds_cls(*ds['args'], **ds['kwargs']).df
        analysis = lightwood.analyze_dataset(df)
        return analysis.to_dict()  # type: ignore

    def get_model_data(self, name, company_id: int):
        if '@@@@@' in name:
            sn = name.split('@@@@@')
            assert len(sn) < 3  # security
            name = sn[1]

        original_name = name
        name = f'{company_id}@@@@@{name}'

        predictor_record = db.session.query(db.Predictor).filter_by(
            company_id=company_id, name=original_name).first()
        assert predictor_record is not None

        linked_db_ds = db.session.query(db.Datasource).filter_by(
            company_id=company_id, id=predictor_record.datasource_id).first()

        data = deepcopy(predictor_record.data)
        data['dtype_dict'] = predictor_record.dtype_dict
        data['created_at'] = str(
            parse_datetime(str(predictor_record.created_at).split('.')[0]))
        data['updated_at'] = str(
            parse_datetime(str(predictor_record.updated_at).split('.')[0]))
        data['predict'] = predictor_record.to_predict[0]
        data['update'] = predictor_record.update_status
        data['mindsdb_version'] = predictor_record.mindsdb_version
        data['name'] = predictor_record.name
        data['code'] = predictor_record.code
        data['json_ai'] = predictor_record.json_ai
        data['data_source_name'] = linked_db_ds.name if linked_db_ds else None
        data['problem_definition'] = predictor_record.learn_args

        # assume older models are complete, only temporary
        if 'error' in predictor_record.data:
            data['status'] = 'error'
        elif predictor_record.update_status == 'available':
            data['status'] = 'complete'
        elif predictor_record.json_ai is None and predictor_record.code is None:
            data['status'] = 'generating'
        elif predictor_record.data is None:
            data['status'] = 'editable'
        elif 'training_log' in predictor_record.data:
            data['status'] = 'training'
        elif 'error' not in predictor_record.data:
            data['status'] = 'complete'
        else:
            data['status'] = 'error'

        if data.get('accuracies', None) is not None:
            if len(data['accuracies']) > 0:
                data['accuracy'] = float(
                    np.mean(list(data['accuracies'].values())))
        return data

    def get_model_description(self, name: str, company_id: int):
        """
        Similar to `get_model_data` but meant to be seen directly by the user, rather than parsed by something like the Studio predictor view.

        Uses `get_model_data` to compose this, but in the future we might want to make this independent if we deprected `get_model_data`

        :returns: Dictionary of the analysis (meant to be foramtted by the APIs and displayed as json/yml/whatever)
        """ # noqa
        model_description = {}
        model_data = self.get_model_data(name, company_id)

        model_description['accuracies'] = model_data['accuracies']
        model_description['column_importances'] = model_data[
            'column_importances']
        model_description['outputs'] = [model_data['predict']]
        model_description['inputs'] = [
            col for col in model_data['dtype_dict']
            if col not in model_description['outputs']
        ]
        model_description['datasource'] = model_data['data_source_name']
        model_description['model'] = ' --> '.join(
            str(k) for k in model_data['json_ai'])

        return model_description

    def get_models(self, company_id: int):
        models = []
        for db_p in db.session.query(
                db.Predictor).filter_by(company_id=company_id):
            model_data = self.get_model_data(db_p.name, company_id=company_id)
            reduced_model_data = {}

            for k in [
                    'name', 'version', 'is_active', 'predict', 'status',
                    'current_phase', 'accuracy', 'data_source', 'update',
                    'data_source_name', 'mindsdb_version', 'error'
            ]:
                reduced_model_data[k] = model_data.get(k, None)

            for k in ['train_end_at', 'updated_at', 'created_at']:
                reduced_model_data[k] = model_data.get(k, None)
                if reduced_model_data[k] is not None:
                    try:
                        reduced_model_data[k] = parse_datetime(
                            str(reduced_model_data[k]).split('.')[0])
                    except Exception as e:
                        # @TODO Does this ever happen?
                        log.error(
                            f'Date parsing exception while parsing {k} in get_models: {e}')
                        reduced_model_data[k] = parse_datetime(
                            str(reduced_model_data[k]))

            models.append(reduced_model_data)
        return models

    def delete_model(self, name, company_id: int):
        original_name = name
        name = f'{company_id}@@@@@{name}'

        db_p = db.session.query(db.Predictor).filter_by(
            company_id=company_id, name=original_name).first()
        if db_p is None:
            raise Exception(f"Predictor '{name}' does not exist")
        db.session.delete(db_p)
        if db_p.datasource_id is not None:
            try:
                dataset_record = db.Datasource.query.get(db_p.datasource_id)
                if (isinstance(dataset_record.data, str) and json.loads(
                        dataset_record.data).get('source_type') != 'file'):
                    DataStore().delete_datasource(dataset_record.name,
                                                  company_id)
            except Exception:
                pass
        db.session.commit()

        DatabaseWrapper(company_id).unregister_predictor(name)

        # delete from s3
        self.fs_store.delete(f'predictor_{company_id}_{db_p.id}')

        return 0

    def rename_model(self, old_name, new_name, company_id: int):
        db_p = db.session.query(db.Predictor).filter_by(company_id=company_id,
                                                        name=old_name).first()
        db_p.name = new_name
        db.session.commit()
        dbw = DatabaseWrapper(company_id)
        dbw.unregister_predictor(old_name)
        dbw.register_predictors([self.get_model_data(new_name, company_id)])

    @mark_process(name='learn')
    def update_model(self, name: str, company_id: int):
        # TODO: Add version check here once we're done debugging
        predictor_record = db.session.query(db.Predictor).filter_by(
            company_id=company_id, name=name).first()
        assert predictor_record is not None
        predictor_record.update_status = 'updating'
        db.session.commit()

        p = UpdateProcess(name, company_id)
        p.start()
        return 'Update in progress'

    @mark_process(name='learn')
    def generate_predictor(self, name: str, from_data: dict, datasource_id,
                           problem_definition_dict: dict,
                           join_learn_process: bool, company_id: int):
        predictor_record = db.session.query(db.Predictor).filter_by(
            company_id=company_id, name=name).first()
        if predictor_record is not None:
            raise Exception('Predictor name must be unique.')

        df, problem_definition, _ = self._unpack_old_args(
            from_data, problem_definition_dict)

        problem_definition = ProblemDefinition.from_dict(problem_definition)

        predictor_record = db.Predictor(
            company_id=company_id,
            name=name,
            datasource_id=datasource_id,
            mindsdb_version=mindsdb_version,
            lightwood_version=lightwood_version,
            to_predict=problem_definition.target,
            learn_args=problem_definition.to_dict(),
            data={'name': name})

        db.session.add(predictor_record)
        db.session.commit()
        predictor_id = predictor_record.id

        p = GenerateProcess(df, problem_definition, predictor_id)
        p.start()
        if join_learn_process:
            p.join()
            if not IS_PY36:
                p.close()
        db.session.refresh(predictor_record)

    def edit_json_ai(self, name: str, json_ai: dict, company_id=None):
        predictor_record = db.session.query(db.Predictor).filter_by(
            company_id=company_id, name=name).first()
        assert predictor_record is not None

        json_ai = lightwood.JsonAI.from_dict(json_ai)
        predictor_record.code = lightwood.code_from_json_ai(json_ai)
        predictor_record.json_ai = json_ai.to_dict()
        db.session.commit()

    def code_from_json_ai(self, json_ai: dict, company_id=None):
        json_ai = lightwood.JsonAI.from_dict(json_ai)
        code = lightwood.code_from_json_ai(json_ai)
        return code

    def edit_code(self, name: str, code: str, company_id=None):
        """Edit an existing predictor's code"""
        if self.config.get('cloud', False):
            raise Exception('Code editing prohibited on cloud')

        predictor_record = db.session.query(db.Predictor).filter_by(
            company_id=company_id, name=name).first()
        assert predictor_record is not None

        lightwood.predictor_from_code(code)
        predictor_record.code = code
        predictor_record.json_ai = None
        db.session.commit()

    @mark_process(name='learn')
    def fit_predictor(self, name: str, from_data: dict,
                      join_learn_process: bool, company_id: int) -> None:
        predictor_record = db.session.query(db.Predictor).filter_by(
            company_id=company_id, name=name).first()
        assert predictor_record is not None

        df = self._get_from_data_df(from_data)
        p = FitProcess(predictor_record.id, df)
        p.start()
        if join_learn_process:
            p.join()
            if not IS_PY36:
                p.close()
Example #10
def add_db_integration(name, data, company_id):
    if 'database_name' not in data:
        data['database_name'] = name
    if 'publish' not in data:
        data['publish'] = True

    bundle_path = data.get('secure_connect_bundle')
    if data.get('type') in ('cassandra', 'scylla') and _is_not_empty_str(bundle_path):
        if os.path.isfile(bundle_path) is False:
            raise Exception(f'Cannot access file: {bundle_path}')
        integrations_dir = Config()['paths']['integrations']

        p = Path(bundle_path)
        data['secure_connect_bundle'] = p.name

        integration_record = Integration(name=name,
                                         data=data,
                                         company_id=company_id)
        session.add(integration_record)
        session.commit()
        integration_id = integration_record.id

        folder_name = f'integration_files_{company_id}_{integration_id}'
        integration_dir = os.path.join(integrations_dir, folder_name)
        create_directory(integration_dir)
        shutil.copyfile(bundle_path, os.path.join(integration_dir, p.name))

        FsStore().put(folder_name, integration_dir, integrations_dir)
    elif data.get('type') in ('mysql', 'mariadb'):
        ssl = data.get('ssl')
        files = {}
        temp_dir = None
        if ssl is True:
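            # Certs may be given inline; write them to a temp dir so only file names are stored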
            for key in ['ssl_ca', 'ssl_cert', 'ssl_key']:
                if key not in data:
                    continue
                if os.path.isfile(data[key]) is False:
                    if _is_not_empty_str(data[key]) is False:
                        raise Exception(
                            "'ssl_ca', 'ssl_cert' and 'ssl_key' must be paths or inline certs"
                        )
                    if temp_dir is None:
                        temp_dir = tempfile.mkdtemp(
                            prefix='integration_files_')
                    cert_file_name = data.get(f'{key}_name', f'{key}.pem')
                    cert_file_path = os.path.join(temp_dir, cert_file_name)
                    with open(cert_file_path, 'wt') as f:
                        f.write(data[key])
                    data[key] = cert_file_path
                files[key] = data[key]
                p = Path(data[key])
                data[key] = p.name
        integration_record = Integration(name=name,
                                         data=data,
                                         company_id=company_id)
        session.add(integration_record)
        session.commit()
        integration_id = integration_record.id

        if len(files) > 0:
            integrations_dir = Config()['paths']['integrations']
            folder_name = f'integration_files_{company_id}_{integration_id}'
            integration_dir = os.path.join(integrations_dir, folder_name)
            create_directory(integration_dir)
            for file_path in files.values():
                p = Path(file_path)
                shutil.copyfile(file_path,
                                os.path.join(integration_dir, p.name))
            FsStore().put(folder_name, integration_dir, integrations_dir)
    else:
        integration_record = Integration(name=name,
                                         data=data,
                                         company_id=company_id)
        session.add(integration_record)
        session.commit()