class XmlDataSource(db.Model, BaseMixin, RefXmlImportHandlerMixin):
    TYPES = DataSource.DATASOURCE_DICT.keys()

    # TODO: unique for XmlImportHandler
    name = db.Column(db.String(200), nullable=False)
    type = db.Column(db.Enum(*TYPES, name='xml_datasource_types'))
    params = deferred(db.Column(JSONType))

    @validates('params')
    def validate_params(self, key, params):
        # TODO:
        return params

    def to_xml(self, secure=False, to_string=False, pretty_print=True):
        extra = self.params if secure else {}
        elem = etree.Element(self.type, name=self.name, **extra)
        if to_string:
            return etree.tostring(elem, pretty_print=pretty_print)
        return elem

    @property
    def core_datasource(self):
        # TODO: secure
        ds_xml = self.to_xml(secure=True)
        return DataSource.factory(ds_xml)

    def __repr__(self):
        return "<DataSource %s>" % self.name
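# A minimal standalone sketch of the lxml pattern used by to_xml() above:
# keyword arguments to etree.Element become XML attributes (the datasource
# type, name and params values below are made-up examples):
#
#   from lxml import etree
#   params = {'host': 'localhost', 'dbname': 'cloudml'}
#   elem = etree.Element('db', name='my_datasource', **params)
#   print etree.tostring(elem, pretty_print=True)
#   # prints something like: <db name="my_datasource" host="localhost" .../>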
class PredefinedScaler(BaseModel, PredefinedItemMixin, db.Model,
                       ExportImportMixin):
    """ Represents predefined feature scaler """
    FIELDS_TO_SERIALIZE = ('type', 'params')
    NO_PARAMS_KEY = False
    TYPES_LIST = SCALERS.keys()

    type = db.Column(db.Enum(*TYPES_LIST, name='scaler_types'),
                     nullable=False)
class Transformer(BaseModel, BaseTrainedEntity, db.Model):
    """ Represents pretrained transformer """
    from api.features.config import TRANSFORMERS

    TYPES_LIST = TRANSFORMERS.keys()

    params = db.Column(JSONType)
    field_name = db.Column(db.String(100))
    feature_type = db.Column(db.String(100))
    type = db.Column(db.Enum(*TYPES_LIST,
                             name='pretrained_transformer_types'),
                     nullable=False)
    datasets = relationship(
        'DataSet', secondary=lambda: transformer_data_sets_table)

    def train(self, iterator, *args, **kwargs):
        from cloudml.transformers.transformer import Transformer
        transformer = Transformer(json.dumps(self.json), is_file=False)
        transformer.train(iterator)
        return transformer

    def set_trainer(self, transformer):
        from bson import Binary
        import cPickle as pickle
        trainer_data = Binary(pickle.dumps(transformer))
        self.trainer = trainer_data
        self.trainer_size = len(trainer_data)

    def get_trainer(self):
        import cPickle as pickle
        return pickle.loads(self.trainer)

    @property
    def json(self):
        return {
            "transformer-name": self.name,
            "field-name": self.field_name,
            "type": self.feature_type,
            "transformer": {
                "type": self.type,
                "params": self.params
            }
        }

    def load_from_json(self, json):
        self.name = json.get("transformer-name")
        self.field_name = json.get("field-name")
        self.feature_type = json.get("type")
        if "transformer" in json and json["transformer"]:
            transformer_config = json["transformer"]
            self.type = transformer_config.get("type")
            self.params = transformer_config.get("params")

    @property
    def can_delete(self):
        if self.training_in_progress:
            self.reason_msg = "The transformer cannot be deleted while " \
                              "training is still in progress."
        return not self.training_in_progress and \
            super(Transformer, self).can_delete
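# A hedged sketch of the trainer round-trip behind set_trainer()/get_trainer():
# the trained transformer is pickled into a bson.Binary blob and later restored
# with pickle.loads (Python 2 / cPickle, as used above; DummyTransformer is a
# made-up stand-in for the real cloudml transformer object):
#
#   import cPickle as pickle
#   from bson import Binary
#
#   class DummyTransformer(object):
#       vocabulary_ = {'foo': 0, 'bar': 1}
#
#   blob = Binary(pickle.dumps(DummyTransformer()))  # stored in self.trainer
#   restored = pickle.loads(blob)                    # what get_trainer() yields
#   assert restored.vocabulary_ == {'foo': 0, 'bar': 1}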
class PredefinedClassifier(BaseModel, PredefinedItemMixin, db.Model,
                           ExportImportMixin):
    """ Represents predefined classifier """
    NO_PARAMS_KEY = False
    FIELDS_TO_SERIALIZE = ('type', 'params')
    TYPES_LIST = CLASSIFIERS.keys()

    type = db.Column(db.Enum(*TYPES_LIST, name='classifier_types'),
                     nullable=False)
class AsyncTask(db.Model, BaseModel):
    STATUS_IN_PROGRESS = 'In Progress'
    STATUS_COMPLETED = 'Completed'
    STATUS_ERROR = 'Error'
    STATUSES = [STATUS_IN_PROGRESS, STATUS_COMPLETED, STATUS_ERROR]

    status = db.Column(db.Enum(*STATUSES, name='async_task_statuses'),
                       default=STATUS_IN_PROGRESS)
    error = db.Column(db.String(300))
    args = db.Column(JSONType)
    kwargs = db.Column(JSONType)
    result = db.Column(JSONType)
    task_name = db.Column(db.String(300))
    task_id = db.Column(db.String(300))
    object_type = db.Column(db.String(300))
    object_id = db.Column(db.Integer)

    @classmethod
    def _get_object_type_name(cls, obj):
        return obj.__class__.__name__

    @classmethod
    def create_by_task_and_object(cls, task_name, task_id,
                                  args, kwargs, obj):
        return cls(task_name=task_name,
                   task_id=task_id,
                   object_type=cls._get_object_type_name(obj),
                   object_id=obj.id,
                   args=args,
                   kwargs=kwargs)

    @classmethod
    def get_current_by_object(cls, obj, task_name=None, user=None,
                              statuses=[STATUS_IN_PROGRESS,
                                        STATUS_COMPLETED],
                              **kwargs):
        cursor = cls.query.filter_by(
            object_type=cls._get_object_type_name(obj),
            object_id=obj.id,
        ).filter(cls.status.in_(statuses))
        if task_name:
            cursor = cursor.filter_by(task_name=task_name)
        if user:
            cursor = cursor.filter_by(created_by=user)
        if kwargs:
            cursor = cursor.filter_by(**kwargs)
        return cursor.order_by(desc(AsyncTask.created_on)).all()

    def terminate_task(self):
        from api import celery
        celery.control.revoke(self.task_id, terminate=True,
                              signal='SIGKILL')
class NamedFeatureType(BaseModel, PredefinedItemMixin, ExportImportMixin,
                       db.Model):
    """ Represents named feature type """
    __tablename__ = 'predefined_feature_type'

    TYPES_LIST = [
        'boolean', 'int', 'float', 'numeric', 'date', 'map',
        'categorical_label', 'categorical', 'text', 'regex', 'composite'
    ]
    FIELDS_TO_SERIALIZE = ('name', 'type', 'input_format', 'params')

    type = db.Column(db.Enum(*TYPES_LIST, name='named_feature_types'),
                     nullable=False)
    input_format = db.Column(db.String(200))
class XmlField(db.Model, BaseMixin):
    TYPES = PROCESS_STRATEGIES.keys()
    TRANSFORM_TYPES = ['json', 'csv']
    FIELDS_TO_SERIALIZE = ['name', 'type', 'column', 'jsonpath',
                           'delimiter', 'regex', 'split', 'dateFormat',
                           'template', 'transform', 'headers', 'script',
                           'required', 'multipart', 'key_path',
                           'value_path']

    def to_dict(self):
        fieldDict = super(XmlField, self).to_dict()
        if 'multipart' in fieldDict and fieldDict['multipart'] == 'false':
            fieldDict.pop('multipart')
        if 'required' in fieldDict and fieldDict['required'] == 'false':
            fieldDict.pop('required')
        return fieldDict

    name = db.Column(db.String(200), nullable=False)
    type = db.Column(db.Enum(*TYPES, name='xml_field_types'))
    column = db.Column(db.String(200))
    jsonpath = db.Column(db.String(200))
    delimiter = db.Column(db.String(200))
    regex = db.Column(db.String(200))
    split = db.Column(db.String(200))
    dateFormat = db.Column(db.String(200))
    template = db.Column(db.String(200))
    transform = db.Column(
        db.Enum(*TRANSFORM_TYPES, name='xml_transform_types'))
    headers = db.Column(db.String(200))
    script = db.Column(db.Text)
    required = db.Column(db.Boolean, default=False)
    multipart = db.Column(db.Boolean, default=False)
    key_path = db.Column(db.String(200))
    value_path = db.Column(db.String(200))
    entity_id = db.Column(db.ForeignKey('xml_entity.id'))
    entity = relationship(
        'XmlEntity', foreign_keys=[entity_id],
        backref=backref(
            'fields', cascade='all,delete', order_by='XmlField.id'))
class XmlScript(db.Model, BaseMixin, RefXmlImportHandlerMixin):
    TYPE_PYTHON_CODE = 'python_code'
    TYPE_PYTHON_FILE = 'python_file'
    TYPES = [TYPE_PYTHON_CODE, TYPE_PYTHON_FILE]

    data = db.Column(db.Text)
    type = db.Column(db.Enum(*TYPES, name='xml_script_types'),
                     server_default=TYPE_PYTHON_CODE)

    @staticmethod
    def to_s3(data, import_handler_id):
        from api.amazon_utils import AmazonS3Helper
        from datetime import datetime
        import api
        try:
            handler = XmlImportHandler.query.get(import_handler_id)
            if not handler:
                raise ValueError("Import handler {0} not found".format(
                    import_handler_id))
            key = "{0}/{1}_python_script_{2}.py".format(
                api.app.config['IMPORT_HANDLER_SCRIPTS_FOLDER'],
                handler.name,
                datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
            s3helper = AmazonS3Helper()
            s3helper.save_key_string(key, data)
        except Exception as e:
            raise ValueError("Error when uploading file to Amazon S3: "
                             "{0}".format(e))
        return key

    def to_xml(self, to_string=False, pretty_print=True):
        attrib = {"src": self.data} \
            if self.type == XmlScript.TYPE_PYTHON_FILE \
            else {}
        text = self.data if self.type == XmlScript.TYPE_PYTHON_CODE else None
        elem = etree.Element(self.type, attrib)
        elem.text = text
        if to_string:
            return etree.tostring(elem, pretty_print=pretty_print)
        return elem

    @property
    def script_string(self):
        try:
            script = Script(self.to_xml())
            return script.get_script_str()
        except Exception as e:
            raise ValueError("Can't load script sources. {0}".format(e))
class XmlInputParameter(db.Model, BaseMixin, RefXmlImportHandlerMixin):
    FIELDS_TO_SERIALIZE = ['name', 'type', 'regex', 'format']
    TYPES = PROCESS_STRATEGIES.keys()

    # TODO: unique for XmlImportHandler
    name = db.Column(db.String(200), nullable=False)
    type = db.Column(db.Enum(*TYPES, name='xml_input_types'))
    regex = db.Column(db.String(200))
    format = db.Column(db.String(200))

    def save(self, *args, **kwargs):
        super(XmlInputParameter, self).save(*args, **kwargs)
        self.import_handler.update_import_params()

    def delete(self, *args, **kwargs):
        handler = self.import_handler
        super(XmlInputParameter, self).delete(*args, **kwargs)
        handler.update_import_params()
class ClassifierGridParams(db.Model, BaseModel):
    STATUS_LIST = ('New', 'Queued', 'Calculating', 'Completed', 'Error')

    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    model = relationship(Model, backref=backref('classifier_grid_params'))
    scoring = db.Column(db.String(100), default='accuracy')
    status = db.Column(db.Enum(*STATUS_LIST,
                               name='classifier_grid_params_statuses'),
                       nullable=False, default='New')
    train_data_set_id = db.Column(
        db.Integer, db.ForeignKey('data_set.id', ondelete='SET NULL'))
    train_dataset = relationship('DataSet',
                                 foreign_keys=[train_data_set_id])
    test_data_set_id = db.Column(
        db.Integer, db.ForeignKey('data_set.id', ondelete='SET NULL'))
    test_dataset = relationship('DataSet', foreign_keys=[test_data_set_id])
    parameters = db.Column(JSONType)
    parameters_grid = db.Column(JSONType)
class PredefinedDataSource(db.Model, BaseModel):
    """ Datasource that is used for db exports """
    TYPE_REQUEST = 'http'
    TYPE_SQL = 'sql'
    TYPES_LIST = (TYPE_REQUEST, TYPE_SQL)
    VENDOR_POSTGRES = 'postgres'
    VENDORS_LIST = (VENDOR_POSTGRES, )

    name = db.Column(db.String(200), nullable=False, unique=True)
    type = db.Column(db.Enum(*TYPES_LIST, name='datasource_types'),
                     default=TYPE_SQL)
    # sample: {"conn": basestring, "vendor": basestring}
    db = deferred(db.Column(JSONType))

    @property
    def safe_db(self):
        if not self.can_edit:
            return {
                'vendor': self.db['vendor'],
                'conn': re.sub(PASSWORD_REGEX, HIDDEN_PASSWORD,
                               self.db['conn'])
            }
        return self.db

    @validates('db')
    def validate_db(self, key, db):
        self.validate_db_fields(db)
        return db

    @classmethod
    def validate_db_fields(cls, db):
        key = 'db'
        assert 'vendor' in db, assertion_msg(key, 'vendor is required')
        assert db['vendor'] in cls.VENDORS_LIST, assertion_msg(
            key, 'choose vendor from %s' % ', '.join(cls.VENDORS_LIST))
        assert 'conn' in db, assertion_msg(key, 'conn is required')
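# A standalone sketch of the password masking behind safe_db: the connection
# string is passed through re.sub before being exposed to read-only users.
# PASSWORD_REGEX and HIDDEN_PASSWORD are defined elsewhere in the code base;
# the values below are illustrative assumptions only:
#
#   import re
#   PASSWORD_REGEX = r"password='[^']*'"
#   HIDDEN_PASSWORD = "password='***'"
#   conn = "host='localhost' dbname='cloudml' user='ml' password='secret'"
#   print re.sub(PASSWORD_REGEX, HIDDEN_PASSWORD, conn)
#   # host='localhost' dbname='cloudml' user='ml' password='***'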
class ServerModelVerification(BaseModel, db.Model,
                              RefXmlImportHandlerMixin):
    """ Represents verification of a model that is deployed to the server """
    STATUS_NEW = 'New'
    STATUS_QUEUED = 'Queued'
    STATUS_IN_PROGRESS = 'In Progress'
    STATUS_ERROR = 'Error'
    STATUS_DONE = 'Done'
    STATUSES = [
        STATUS_NEW, STATUS_QUEUED, STATUS_IN_PROGRESS,
        STATUS_ERROR, STATUS_DONE
    ]

    status = db.Column(db.Enum(*STATUSES,
                               name='model_verification_statuses'),
                       nullable=False, default=STATUS_NEW)
    error = db.Column(db.Text)
    server_id = db.Column(db.Integer, db.ForeignKey('server.id'))
    server = relationship(
        Server,
        backref=backref('model_verifications', cascade='all,delete'))
    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    model = relationship(
        Model,
        backref=backref('model_verifications', cascade='all,delete'))
    test_result_id = db.Column(db.Integer, db.ForeignKey('test_result.id'))
    test_result = relationship(
        'TestResult',
        backref=backref('model_verifications', cascade='all,delete'))
    description = db.Column(JSONType)
    result = db.Column(JSONType)
    params_map = db.Column(JSONType)
    clazz = db.Column(db.String(200))

    def __repr__(self):
        return '<ServerModelVerification {0}>'.format(self.model.name)
class Instance(BaseModel, db.Model):
    """ Represents an instance which can be used to execute tasks """
    TYPES_LIST = ['small', 'large']

    name = db.Column(db.String(200), nullable=False, unique=True)
    description = deferred(db.Column(db.Text))
    ip = db.Column(db.String(200), nullable=False)
    type = db.Column(db.Enum(*TYPES_LIST, name='instance_types'),
                     nullable=False)
    is_default = db.Column(db.Boolean, default=False)

    def save(self, commit=True):
        super(Instance, self).save(False)
        if self.is_default:
            Instance.query\
                .filter(Instance.is_default, Instance.name != self.name)\
                .update({Instance.is_default: False})
        if commit:
            db.session.commit()

    def __repr__(self):
        return "<Instance %s>" % self.name
class DataSet(db.Model, BaseModel):
    """ Set of the imported data. """
    LOG_TYPE = LogMessage.IMPORT_DATA

    STATUS_NEW = 'New'
    STATUS_IMPORTING = 'Importing'
    STATUS_UPLOADING = 'Uploading'
    STATUS_IMPORTED = 'Imported'
    STATUS_ERROR = 'Error'
    STATUSES = [
        STATUS_IMPORTING, STATUS_UPLOADING, STATUS_IMPORTED,
        STATUS_ERROR, STATUS_NEW
    ]

    FORMAT_JSON = 'json'
    FORMAT_CSV = 'csv'
    FORMATS = [FORMAT_JSON, FORMAT_CSV]

    name = db.Column(db.String(200))
    status = db.Column(db.Enum(*STATUSES, name='dataset_statuses'),
                       default=STATUS_NEW)
    error = db.Column(db.String(300))  # TODO: trunc error to 300 symbols
    data = db.Column(db.String(200))
    import_params = db.Column(JSONType)

    # Generic relation to import handler
    import_handler_id = db.Column(db.Integer, nullable=False)
    import_handler_type = db.Column(db.String(200), default='xml')
    import_handler_xml = db.Column(db.Text)

    cluster_id = db.Column(db.Integer,
                           db.ForeignKey('cluster.id', ondelete='SET NULL'))
    cluster = relationship('Cluster', backref=backref('datasets'))
    pig_step = db.Column(db.Integer, nullable=True)
    pig_row = db.Column(JSONType)

    on_s3 = db.Column(db.Boolean)
    compress = db.Column(db.Boolean)
    filename = db.Column(db.String(200))
    filesize = db.Column(db.BigInteger)
    records_count = db.Column(db.Integer)
    time = db.Column(db.Integer)
    data_fields = db.Column(postgresql.ARRAY(db.String))
    format = db.Column(db.String(10))
    uid = db.Column(db.String(200))
    locked = db.Column(db.Boolean, default=False)

    @property
    def import_handler(self):
        """Provides in-Python access to the "parent" by choosing
        the appropriate relationship.
        """
        return ImportHandler.query.get(self.import_handler_id)

    @import_handler.setter
    def import_handler(self, handler):
        self.import_handler_id = handler.id
        self.import_handler_type = handler.TYPE
        self.import_handler_xml = handler.data

    def set_uid(self):
        if not self.uid:
            self.uid = uuid.uuid1().hex

    def get_s3_download_url(self, expires_in=3600):
        helper = AmazonS3Helper()
        return helper.get_download_url(self.uid, expires_in)

    def set_file_path(self):
        self.set_uid()
        data = '%s.%s' % (self.uid, 'gz' if self.compress else 'json')
        self.data = data
        from api.base.io_utils import get_or_create_data_folder
        path = get_or_create_data_folder()
        self.filename = join(path, data)
        self.save()

    @property
    def loaded_data(self):
        if not self.on_s3:
            raise Exception('Invalid operation: data is not stored on S3')
        if not hasattr(self, '_data'):
            self._data = self.load_from_s3()
        return self._data

    def get_data_stream(self):
        import gzip
        if not self.on_s3 or exists(self.filename):
            logging.info('Loading data from local file')
            open_meth = gzip.open if self.compress else open
            return open_meth(self.filename, 'r')
        else:
            logging.info('Loading data from Amazon S3')
            stream = StringIO.StringIO(self.loaded_data)
            if self.compress:
                logging.info('Decompress data')
                return gzip.GzipFile(fileobj=stream, mode='r')
            return stream

    def get_iterator(self, stream):
        from cloudml.trainer.streamutils import streamingiterload
        return streamingiterload(stream, source_format=self.format)

    def load_from_s3(self):
        helper = AmazonS3Helper()
        return helper.load_key(self.uid)

    def save_to_s3(self):
        meta = {
            'handler': self.import_handler_id,
            'dataset': self.name,
            'params': str(self.import_params)
        }
        self.set_uid()
        helper = AmazonS3Helper()
        helper.save_gz_file(self.uid, self.filename, meta)
        helper.close()
        self.on_s3 = True
        self.save()

    def set_error(self, error, commit=True):
        self.error = str(error)[:299]
        self.status = self.STATUS_ERROR
        if commit:
            self.save()

    def delete(self):
        # Stop task
        # self.terminate_task()  # TODO
        filename = self.filename
        on_s3 = self.on_s3
        uid = self.uid

        super(DataSet, self).delete()
        LogMessage.delete_related_logs(self.id,
                                       type_=LogMessage.IMPORT_DATA)

        # TODO: check import handler type
        try:
            os.remove(filename)
        except OSError:
            pass

        if on_s3:
            from botocore.exceptions import ClientError
            helper = AmazonS3Helper()
            try:
                helper.delete_key(uid)
            except ClientError as e:
                logging.exception(str(e))

    def save(self, *args, **kwargs):
        if self.status != self.STATUS_ERROR:
            self.error = ''
        super(DataSet, self).save(*args, **kwargs)

    def __repr__(self):
        return '<Dataset %r>' % self.name

    def _check_locked(self):
        if self.locked:
            self.reason_msg = 'Some existing models were trained/tested ' \
                              'using this dataset. '
            return False
        return True

    @property
    def can_edit(self):
        return self._check_locked() and super(DataSet, self).can_edit

    @property
    def can_delete(self):
        return self._check_locked() and super(DataSet, self).can_delete

    def unlock(self):
        from api.ml_models.models import data_sets_table
        from api.model_tests.models import TestResult
        if db.session.query(data_sets_table).filter(
                data_sets_table.c.data_set_id == self.id).count() == 0 and \
                TestResult.query.filter(
                    TestResult.data_set_id == self.id).count() == 0:
            self.locked = False
            self.save()
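# A minimal sketch of the S3 branch in get_data_stream(): compressed payloads
# are wrapped in a StringIO buffer and read back through gzip.GzipFile
# (Python 2, matching the imports used in this module; the payload is made up):
#
#   import gzip
#   import StringIO
#
#   buf = StringIO.StringIO()
#   with gzip.GzipFile(fileobj=buf, mode='w') as f:
#       f.write('{"opening_id": 1}\n')
#   stream = StringIO.StringIO(buf.getvalue())
#   print gzip.GzipFile(fileobj=stream, mode='r').read()
#   # {"opening_id": 1}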
class Cluster(BaseModel, db.Model):
    STATUS_NEW = 'New'
    STATUS_STARTING = 'Starting'
    STATUS_RUNNING = 'Running'
    STATUS_WAITING = 'Waiting'
    STATUS_ERROR = 'Error'
    STATUS_TERMINATED = 'Terminated'
    STATUS_TERMINATING = 'Terminating'
    STATUS_TERMINATED_WITH_ERRORS = 'Terminated_with_errors'
    # EMR job flow states: ['STARTING', 'BOOTSTRAPPING', 'RUNNING',
    # 'WAITING', 'TERMINATING', 'TERMINATED', 'TERMINATED_WITH_ERRORS']

    PENDING = -1
    PORT_RANGE = (9000, 9010)

    STATUSES = [
        STATUS_NEW, STATUS_STARTING, STATUS_RUNNING, STATUS_WAITING,
        STATUS_ERROR, STATUS_TERMINATED, STATUS_TERMINATING,
        STATUS_TERMINATED_WITH_ERRORS
    ]
    TERMINATED_STATUSES = [
        STATUS_TERMINATED, STATUS_TERMINATING,
        STATUS_TERMINATED_WITH_ERRORS
    ]
    ACTIVE_STATUSES = [
        STATUS_NEW, STATUS_STARTING, STATUS_RUNNING, STATUS_WAITING,
        STATUS_ERROR
    ]

    jobflow_id = db.Column(db.String(200), nullable=False, unique=True)
    master_node_dns = db.Column(db.String(200), nullable=True)
    port = db.Column(db.Integer, nullable=True)
    pid = db.Column(db.Integer, nullable=True)
    status = db.Column(db.Enum(*STATUSES, name='cluster_statuses'),
                       default=STATUS_NEW)
    logs_folder = db.Column(db.String(200), nullable=True)
    # FIXME: do we need this field?
    is_default = db.Column(db.Boolean, default=False)

    def generate_port(self):
        """
        Generates a random port which isn't used by any other cluster.
        Available ports are in the range `PORT_RANGE`.
        """
        exclude = set([
            cl[0] for cl in
            Cluster.query.with_entities(Cluster.port).filter(
                Cluster.status != Cluster.STATUS_TERMINATED,
                Cluster.status != Cluster.STATUS_ERROR)
        ])
        ports = list(set(xrange(*self.PORT_RANGE)) - exclude)
        if ports:
            self.port = random.choice(ports)
        else:
            raise ValueError('All ports are busy')

    @property
    def tunnels(self):
        from api.async_tasks.models import AsyncTask
        return AsyncTask.get_current_by_object(
            self, 'api.instances.tasks.run_ssh_tunnel',
        )

    @property
    def active_tunnel(self):
        return self.pid

    def create_ssh_tunnel(self):
        from api.instances.tasks import run_ssh_tunnel
        if self.pid is None:  # task delayed
            self.pid = self.PENDING
            self.save()
            run_ssh_tunnel.delay(self.id)

    def terminate_ssh_tunnel(self):
        import os
        import signal
        if self.pid is not None and self.pid != self.PENDING:
            try:
                os.kill(self.pid, signal.SIGKILL)
            except Exception as exc:
                logging.error("Unknown error occurred while removing "
                              "process: {0}".format(exc))
            self.pid = None
            self.save()
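# A standalone sketch of the port-picking logic in generate_port(): take the
# configured range, subtract ports already held by live clusters, and pick one
# at random (the 'used' set below is a made-up example):
#
#   import random
#   PORT_RANGE = (9000, 9010)
#   used = set([9000, 9001, 9005])
#   free = list(set(xrange(*PORT_RANGE)) - used)
#   if free:
#       port = random.choice(free)
#   else:
#       raise ValueError('All ports are busy')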
class Server(BaseModel, db.Model):
    """ Represents cloudml-predict server """
    ALLOWED_FOLDERS = [FOLDER_MODELS, FOLDER_IMPORT_HANDLERS]

    PRODUCTION = 'Production'
    STAGING = 'Staging'
    DEV = 'Development'
    ANALYTICS = 'Analytics'
    TYPES = [STAGING, DEV, ANALYTICS, PRODUCTION]
    ENV_MAP = {
        PRODUCTION: 'prod',
        STAGING: 'staging',
        DEV: 'dev',
        ANALYTICS: 'analytics'
    }

    name = db.Column(db.String(200), nullable=False, unique=True)
    description = deferred(db.Column(db.Text))
    ip = db.Column(db.String(200), nullable=False)
    folder = db.Column(db.String(600), nullable=False)
    is_default = db.Column(db.Boolean, default=False)
    memory_mb = db.Column(db.Integer, nullable=False, default=0)
    type = db.Column(db.Enum(*TYPES, name='server_types'), default=DEV)
    logs_url = db.Column(db.Text)

    def __repr__(self):
        return '<Server {0}>'.format(self.name)

    @property
    def grafana_name(self):
        if not hasattr(self, '_grafana_name'):
            from slugify import slugify
            name = self.name.replace('_', '-')
            self._grafana_name = slugify('CloudMl ' + name)
        return self._grafana_name

    @property
    def grafana_url(self):
        from api import app
        return 'http://{0}/dashboard/db/{1}'.format(
            app.config.get('GRAFANA_HOST'), self.grafana_name)

    def list_keys(self, folder=None, params={}):
        path = self.folder.strip('/')
        if folder and folder in self.ALLOWED_FOLDERS:
            path += '/{0!s}'.format(folder)
        objects = []
        s3 = AmazonS3Helper(
            bucket_name=app.config['CLOUDML_PREDICT_BUCKET_NAME'])
        for key in s3.list_keys(path):
            uid = key['Key'].split('/')[-1]
            key = s3.load_key(key['Key'], with_metadata=True)
            if key['Metadata']['hide'] == 'True':
                continue
            objects.append({
                'id': uid,
                'object_name': key['Metadata'].get('object_name', None),
                'size': key['ContentLength'],
                'uploaded_on': key['Metadata'].get('uploaded_on', None),
                'last_modified': str(key['LastModified']),
                'name': key['Metadata'].get('name', None),
                'object_id': key['Metadata'].get('id', None),
                'object_type': key['Metadata'].get('type', None),
                'user_id': key['Metadata'].get('user_id', None),
                'user_name': key['Metadata'].get('user_name', None),
                'crc32': key['Metadata'].get('crc32', None),
                'server_id': self.id,
                'loading_error': key['Metadata'].get('loading_error', None),
                'count_400': key['Metadata'].get('count_400', None),
                'count_500': key['Metadata'].get('count_500', None),
                'count_of_max_response':
                    key['Metadata'].get('count_of_max_response', None),
                'longest_resp_count':
                    key['Metadata'].get('longest_resp_count', None),
                'longest_resp_time':
                    key['Metadata'].get('longest_resp_time', None),
                'max_response_time':
                    key['Metadata'].get('max_response_time', None),
                'requests': key['Metadata'].get('requests', None)
            })
        sort_by = params.get('sort_by', None)
        order = params.get('order', 'asc')
        if objects and sort_by:
            obj = objects[0]
            if sort_by in obj.keys():
                return sorted(objects, key=lambda x: x[sort_by],
                              reverse=order != 'asc')
            else:
                raise ValueError(
                    'Unable to sort by %s. Property does not exist.'
                    % sort_by)
        return objects

    def set_key_metadata(self, uid, folder, key, value):
        if self.check_edit_metadata(folder, key, value):
            key_name = '{0}/{1}/{2}'.format(self.folder, folder, uid)
            s3 = AmazonS3Helper(
                bucket_name=app.config['CLOUDML_PREDICT_BUCKET_NAME'])
            s3.set_key_metadata(key_name, {key: value}, True)
            # hiding the key means it is deleted, so the related
            # model/import handler needs to be updated
            if key == 'hide' and value == 'True':
                obj = s3.load_key(key_name, with_metadata=True)
                cl = Model if folder == FOLDER_MODELS else XmlImportHandler
                model = cl.query.get(obj['Metadata']['id'])
                server_list = [s for s in model.servers_ids
                               if s != self.id]
                model.servers_ids = server_list
                model.save()

    def check_edit_metadata(self, folder, key, value):
        entities_by_folder = {
            FOLDER_MODELS: 'Model',
            FOLDER_IMPORT_HANDLERS: 'Import Handler'
        }
        integer_keys = ['count_400', 'count_500', 'count_of_max_response']
        entity = entities_by_folder.get(folder, None)
        if not entity:
            raise ValueError('Wrong folder: %s' % folder)
        if key == 'name':
            files = self.list_keys(folder)
            for file_ in files:
                if file_['name'] == value:
                    raise ValueError('{0} with name "{1}" already exists on '
                                     'the server {2}'.format(
                                         entity, value, self.name))
        if key in integer_keys:
            try:
                v = int(value)
            except Exception:
                raise ValueError("Incorrect value '{0}' for '{1}'. Integer "
                                 "is expected".format(value, key))
        return True

    def get_key_metadata(self, uid, folder, key):
        key_name = '{0}/{1}/{2}'.format(self.folder, folder, uid)
        s3 = AmazonS3Helper(
            bucket_name=app.config['CLOUDML_PREDICT_BUCKET_NAME'])
        s3key = s3.load_key(key_name, with_metadata=True)
        return s3key['Metadata'][key]

    def save(self, commit=True):
        BaseModel.save(self, commit=False)
        if self.is_default:
            Server.query\
                .filter(Server.is_default, Server.name != self.name)\
                .update({Server.is_default: False})
        if commit:
            db.session.commit()
class TestResult(db.Model, BaseModel):
    LOG_TYPE = LogMessage.RUN_TEST

    STATUS_QUEUED = 'Queued'
    STATUS_IMPORTING = 'Importing'
    STATUS_IMPORTED = 'Imported'
    STATUS_IN_PROGRESS = 'In Progress'
    STATUS_STORING = 'Storing'
    STATUS_COMPLETED = 'Completed'
    STATUS_ERROR = 'Error'
    STATUSES = [
        STATUS_QUEUED, STATUS_IMPORTING, STATUS_IMPORTED,
        STATUS_IN_PROGRESS, STATUS_STORING, STATUS_COMPLETED, STATUS_ERROR
    ]
    TEST_STATUSES = [
        STATUS_QUEUED, STATUS_IMPORTING, STATUS_IMPORTED,
        STATUS_IN_PROGRESS, STATUS_STORING
    ]

    __tablename__ = 'test_result'

    name = db.Column(db.String(200), nullable=False)
    status = db.Column(db.Enum(*STATUSES, name='test_statuses'))
    error = db.Column(db.String(300))
    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    model = relationship(Model,
                         backref=backref('tests', cascade='all,delete'))
    model_name = db.Column(db.String(200))
    data_set_id = db.Column(
        db.Integer, db.ForeignKey('data_set.id', ondelete='SET NULL'))
    dataset = relationship(DataSet, foreign_keys=[data_set_id])
    examples_count = db.Column(db.Integer)
    examples_fields = db.Column(postgresql.ARRAY(db.String))
    examples_size = db.Column(db.Float)
    parameters = db.Column(JSONType)
    classes_set = db.Column(postgresql.ARRAY(db.String))
    accuracy = db.Column(db.Float)
    roc_auc = db.Column(JSONType)
    metrics = db.Column(JSONType)
    memory_usage = db.Column(db.Integer)
    vect_data = deferred(db.Column(S3File))
    fill_weights = db.Column(db.Boolean, default=False)

    def __repr__(self):
        return '<TestResult {0}>'.format(self.name)

    def get_vect_data(self, num, segment):
        from pickle import loads
        data = loads(self.vect_data)
        offset = 0
        for k, v in data.items():
            offset += v.shape[0]
            if k == segment:
                break
        import numpy
        if isinstance(data[segment], numpy.ndarray):
            return data[num - offset]
        return data[segment].getrow(num - offset).todense().tolist()[0]

    def set_error(self, error, commit=True):
        self.error = str(error)[:299]
        self.status = TestResult.STATUS_ERROR
        if commit:
            self.save()

    @property
    def exports(self):
        from api.async_tasks.models import AsyncTask
        return AsyncTask.get_current_by_object(
            self, 'api.model_tests.tasks.get_csv_results',
        )

    @property
    def db_exports(self):
        from api.async_tasks.models import AsyncTask
        return AsyncTask.get_current_by_object(
            self, 'api.model_tests.tasks.export_results_to_db',
        )

    @property
    def confusion_matrix_calculations(self):
        from api.async_tasks.models import AsyncTask
        return AsyncTask.get_current_by_object(
            self, 'api.model_tests.tasks.calculate_confusion_matrix',
            statuses=AsyncTask.STATUSES)

    @property
    def can_edit(self):
        if not self.model.can_edit:
            self.reason_msg = self.model.reason_msg
            return False
        return super(TestResult, self).can_edit

    @property
    def can_delete(self):
        if not self.model.can_delete:
            self.reason_msg = self.model.reason_msg
            return False
        return super(TestResult, self).can_delete

    def delete(self):
        ds = self.dataset
        super(TestResult, self).delete()
        ds.unlock()

    @property
    def test_in_progress(self):
        return self.status in self.TEST_STATUSES
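# A minimal sketch of the sparse-row extraction in get_vect_data(): a scipy
# sparse matrix row is densified into a flat Python list (assumes scipy is
# available; the toy matrix below is illustrative only):
#
#   from scipy.sparse import csr_matrix
#   m = csr_matrix([[0, 1, 0], [2, 0, 3]])
#   row = m.getrow(1).todense().tolist()[0]
#   # row == [2, 0, 3]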
def status(cls):
    from api.base.utils import convert_name
    name = convert_name(cls.__name__)
    return db.Column(db.Enum(*cls.STATUSES, name='%s_statuses' % name),
                     default=cls.STATUS_NEW)
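# For reference, a hedged sketch of the naming convention this helper relies
# on: convert_name() (from api.base.utils, implementation not shown here) is
# assumed to turn a CamelCase class name into a lowercase identifier. A rough
# standalone equivalent:
#
#   import re
#   def convert_name(name):  # assumption, not the actual api.base.utils code
#       return re.sub(r'(?<!^)(?=[A-Z])', '_', name).lower()
#   print convert_name('AsyncTask') + '_statuses'
#   # async_task_statuses  (the enum name used on AsyncTask above)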