class XmlDataSource(db.Model, BaseMixin, RefXmlImportHandlerMixin):
    TYPES = DataSource.DATASOURCE_DICT.keys()

    # TODO: unique for XmlImportHandler
    name = db.Column(db.String(200), nullable=False)
    type = db.Column(db.Enum(*TYPES, name='xml_datasource_types'))
    params = deferred(db.Column(JSONType))

    @validates('params')
    def validate_params(self, key, params):
        # TODO:
        return params

    def to_xml(self, secure=False, to_string=False, pretty_print=True):
        extra = self.params if secure else {}
        elem = etree.Element(self.type, name=self.name, **extra)
        if to_string:
            return etree.tostring(elem, pretty_print=pretty_print)
        return elem

    @property
    def core_datasource(self):  # TODO: secure
        ds_xml = self.to_xml(secure=True)
        return DataSource.factory(ds_xml)

    def __repr__(self):
        return "<DataSource %s>" % self.name

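# Usage sketch for ``XmlDataSource.to_xml`` (illustrative only; the type
# 'db' and the params below are assumptions -- valid types come from
# DataSource.DATASOURCE_DICT):
#
#   ds = XmlDataSource(name='odw', type='db',
#                      params={'host': 'localhost', 'dbname': 'odw'})
#   ds.to_xml(to_string=True)
#   # -> '<db name="odw"/>'              (params omitted by default)
#   ds.to_xml(secure=True, to_string=True)
#   # -> '<db name="odw" host="localhost" dbname="odw"/>'
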
class Segment(db.Model, BaseMixin):
    __tablename__ = 'segment'

    name = db.Column(db.String(200))
    records = db.Column(db.Integer)

    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    model = relationship(Model, backref=backref('segments'))

class Transformer(BaseModel, BaseTrainedEntity, db.Model):
    """ Represents a pretrained transformer. """
    from api.features.config import TRANSFORMERS
    TYPES_LIST = TRANSFORMERS.keys()

    params = db.Column(JSONType)
    field_name = db.Column(db.String(100))
    feature_type = db.Column(db.String(100))
    type = db.Column(db.Enum(*TYPES_LIST,
                             name='pretrained_transformer_types'),
                     nullable=False)
    datasets = relationship('DataSet',
                            secondary=lambda: transformer_data_sets_table)

    def train(self, iterator, *args, **kwargs):
        from cloudml.transformers.transformer import Transformer
        transformer = Transformer(json.dumps(self.json), is_file=False)
        transformer.train(iterator)
        return transformer

    def set_trainer(self, transformer):
        from bson import Binary
        import cPickle as pickle
        trainer_data = Binary(pickle.dumps(transformer))
        self.trainer = trainer_data
        self.trainer_size = len(trainer_data)

    def get_trainer(self):
        import cPickle as pickle
        return pickle.loads(self.trainer)

    @property
    def json(self):
        return {
            "transformer-name": self.name,
            "field-name": self.field_name,
            "type": self.feature_type,
            "transformer": {
                "type": self.type,
                "params": self.params
            }
        }

    def load_from_json(self, json):
        self.name = json.get("transformer-name")
        self.field_name = json.get("field-name")
        self.feature_type = json.get("type")
        if "transformer" in json and json["transformer"]:
            transformer_config = json["transformer"]
            self.type = transformer_config.get("type")
            self.params = transformer_config.get("params")

    @property
    def can_delete(self):
        if self.training_in_progress:
            self.reason_msg = "The transformer cannot be deleted while " \
                              "training is still in progress."
        return not self.training_in_progress and \
            super(Transformer, self).can_delete

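# Round-trip sketch for ``Transformer.json`` / ``load_from_json``
# (illustrative; the transformer type 'Tfidf' and its params are
# assumptions -- valid types come from api.features.config.TRANSFORMERS):
#
#   t = Transformer()
#   t.load_from_json({
#       "transformer-name": "contractor-skills",
#       "field-name": "contractor.skills",
#       "type": "text",
#       "transformer": {"type": "Tfidf",
#                       "params": {"ngram_range_min": 1}},
#   })
#   t.json  # -> the same dict shape, rebuilt from the model fields
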
class Tag(db.Model, BaseMixin):
    """ Model tag. """
    text = db.Column(db.String(200))
    count = db.Column(db.Integer)

    def update_counter(self):
        """ Recalculates the counter. """
        self.count = len(self.models)
        self.save()

class XmlQuery(db.Model, BaseMixin):
    FIELDS_TO_SERIALIZE = ['target', 'sqoop_dataset_name',
                           'autoload_sqoop_dataset']

    target = db.Column(db.String(200))
    # Could be filled when entity contains sqoop element
    sqoop_dataset_name = db.Column(db.String(200))
    autoload_sqoop_dataset = db.Column(db.Boolean)
    text = db.Column(db.Text)

    def __repr__(self):
        return "<Query %s>" % self.text

class NamedFeatureType(BaseModel, PredefinedItemMixin,
                       ExportImportMixin, db.Model):
    """ Represents a named feature type. """
    __tablename__ = 'predefined_feature_type'

    TYPES_LIST = [
        'boolean', 'int', 'float', 'numeric', 'date', 'map',
        'categorical_label', 'categorical', 'text', 'regex', 'composite'
    ]
    FIELDS_TO_SERIALIZE = ('name', 'type', 'input_format', 'params')

    type = db.Column(db.Enum(*TYPES_LIST, name='named_feature_types'),
                     nullable=False)
    input_format = db.Column(db.String(200))

class VerificationExample(BaseMixin, db.Model):
    verification_id = db.Column(
        db.Integer, db.ForeignKey('server_model_verification.id'))
    verification = relationship(
        'ServerModelVerification',
        backref=backref('verification_examples', cascade='all,delete'))
    example_id = db.Column(db.Integer, db.ForeignKey('test_example.id'))
    example = relationship(
        'TestExample',
        backref=backref('verification_examples', cascade='all,delete'))
    result = db.Column(JSONType)

class Predict(db.Model, BaseMixin):
    models = relationship(
        'PredictModel',
        secondary=lambda: predict_models_table,
        backref='predict_section')

    # Results
    label_id = db.Column(db.ForeignKey('predict_result_label.id'))
    label = relationship('PredictResultLabel',
                         foreign_keys=[label_id],
                         cascade='all,delete',
                         backref='results')
    probability_id = db.Column(
        db.ForeignKey('predict_result_probability.id'))
    probability = relationship(
        'PredictResultProbability',
        foreign_keys=[probability_id],
        cascade='all,delete',
        backref='probabilities')

class AsyncTask(db.Model, BaseModel):
    STATUS_IN_PROGRESS = 'In Progress'
    STATUS_COMPLETED = 'Completed'
    STATUS_ERROR = 'Error'
    STATUSES = [STATUS_IN_PROGRESS, STATUS_COMPLETED, STATUS_ERROR]

    status = db.Column(db.Enum(*STATUSES, name='async_task_statuses'),
                       default=STATUS_IN_PROGRESS)
    error = db.Column(db.String(300))
    args = db.Column(JSONType)
    kwargs = db.Column(JSONType)
    result = db.Column(JSONType)
    task_name = db.Column(db.String(300))
    task_id = db.Column(db.String(300))
    object_type = db.Column(db.String(300))
    object_id = db.Column(db.Integer)

    @classmethod
    def _get_object_type_name(cls, obj):
        return obj.__class__.__name__

    @classmethod
    def create_by_task_and_object(cls, task_name, task_id,
                                  args, kwargs, obj):
        return cls(task_name=task_name,
                   task_id=task_id,
                   object_type=cls._get_object_type_name(obj),
                   object_id=obj.id,
                   args=args,
                   kwargs=kwargs)

    @classmethod
    def get_current_by_object(cls, obj, task_name=None, user=None,
                              statuses=[STATUS_IN_PROGRESS,
                                        STATUS_COMPLETED],
                              **kwargs):
        cursor = cls.query.filter_by(
            object_type=cls._get_object_type_name(obj),
            object_id=obj.id,
        ).filter(cls.status.in_(statuses))
        if task_name:
            cursor = cursor.filter_by(task_name=task_name)
        if user:
            cursor = cursor.filter_by(created_by=user)
        if kwargs:
            cursor = cursor.filter_by(**kwargs)
        return cursor.order_by(desc(AsyncTask.created_on)).all()

    def terminate_task(self):
        from api import celery
        celery.control.revoke(self.task_id, terminate=True,
                              signal='SIGKILL')

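# Usage sketch (illustrative; ``model`` and ``celery_task_id`` are assumed
# to exist -- any persisted entity with an ``id`` works as ``obj``):
#
#   task = AsyncTask.create_by_task_and_object(
#       'api.ml_models.tasks.train_model', celery_task_id,
#       args=[], kwargs={}, obj=model)
#   task.save()
#   # Later: list in-progress/completed tasks attached to the same object.
#   tasks = AsyncTask.get_current_by_object(
#       model, task_name='api.ml_models.tasks.train_model')
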
class WeightsCategory(db.Model, BaseMixin):
    """
    Represents Model Parameter Weights Category.

    NOTE: used for constructing trees of weights.
    """
    __tablename__ = 'weights_category'

    name = db.Column(db.String(200))
    short_name = db.Column(db.String(200))
    # TODO: remove it
    model_name = db.Column(db.String(200))

    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    model = relationship(Model, backref=backref('weight_categories'))
    segment_id = db.Column(db.Integer, db.ForeignKey('segment.id'))
    segment = relationship(Segment, backref=backref('weight_categories'))
    normalized_weight = db.Column(db.Float)
    class_label = db.Column(db.String(100), nullable=True)
    parent = db.Column(db.String(200))
    # TODO: Maybe have FK Weight to WeightsCategory?
    # @aggregated('normalized_weight', sa.Column(sa.Float))
    # def normalized_weight(self):
    #     return sa.func.sum(Weight.value2)

    def __repr__(self):
        return '<Category {0}>'.format(self.name)

class XmlScript(db.Model, BaseMixin, RefXmlImportHandlerMixin):
    TYPE_PYTHON_CODE = 'python_code'
    TYPE_PYTHON_FILE = 'python_file'
    TYPES = [TYPE_PYTHON_CODE, TYPE_PYTHON_FILE]

    data = db.Column(db.Text)
    type = db.Column(db.Enum(*TYPES, name='xml_script_types'),
                     server_default=TYPE_PYTHON_CODE)

    @staticmethod
    def to_s3(data, import_handler_id):
        from api.amazon_utils import AmazonS3Helper
        from datetime import datetime
        import api
        try:
            handler = XmlImportHandler.query.get(import_handler_id)
            if not handler:
                raise ValueError("Import handler {0} not found".format(
                    import_handler_id))
            key = "{0}/{1}_python_script_{2}.py".format(
                api.app.config['IMPORT_HANDLER_SCRIPTS_FOLDER'],
                handler.name,
                datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
            s3helper = AmazonS3Helper()
            s3helper.save_key_string(key, data)
        except Exception as e:
            raise ValueError("Error when uploading file to Amazon S3: "
                             "{0}".format(e))
        return key

    def to_xml(self, to_string=False, pretty_print=True):
        attrib = {"src": self.data} \
            if self.type == XmlScript.TYPE_PYTHON_FILE \
            else {}
        text = self.data if self.type == XmlScript.TYPE_PYTHON_CODE else None
        elem = etree.Element(self.type, attrib)
        elem.text = text
        if to_string:
            return etree.tostring(elem, pretty_print=pretty_print)
        return elem

    @property
    def script_string(self):
        try:
            script = Script(self.to_xml())
            return script.get_script_str()
        except Exception as e:
            raise ValueError("Can't load script sources. {0}".format(e))

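# Shape of ``XmlScript.to_xml`` output, depending on ``type``
# (illustrative):
#
#   <python_code>print 'hello'</python_code>    # data holds the source
#   <python_file src="scripts/handler.py"/>     # data holds the S3 key
#                                               # produced by ``to_s3``
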
class PredefinedScaler(BaseModel, PredefinedItemMixin, db.Model,
                       ExportImportMixin):
    """ Represents a predefined feature scaler. """
    FIELDS_TO_SERIALIZE = ('type', 'params')
    NO_PARAMS_KEY = False
    TYPES_LIST = SCALERS.keys()

    type = db.Column(db.Enum(*TYPES_LIST, name='scaler_types'),
                     nullable=False)

class PredefinedClassifier(BaseModel, PredefinedItemMixin, db.Model,
                           ExportImportMixin):
    """ Represents a predefined classifier. """
    NO_PARAMS_KEY = False
    FIELDS_TO_SERIALIZE = ('type', 'params')
    TYPES_LIST = CLASSIFIERS.keys()

    type = db.Column(db.Enum(*TYPES_LIST, name='classifier_types'),
                     nullable=False)

class PredefinedItemMixin(object):
    name = db.Column(db.String(200), nullable=False, unique=True)

    @declared_attr
    def params(cls):
        return db.Column(JSONType)

    def __repr__(self):
        return '<%s %s>' % (self.__class__.__name__.lower(), self.type)

class XmlInputParameter(db.Model, BaseMixin, RefXmlImportHandlerMixin):
    FIELDS_TO_SERIALIZE = ['name', 'type', 'regex', 'format']
    TYPES = PROCESS_STRATEGIES.keys()

    # TODO: unique for XmlImportHandler
    name = db.Column(db.String(200), nullable=False)
    type = db.Column(db.Enum(*TYPES, name='xml_input_types'))
    regex = db.Column(db.String(200))
    format = db.Column(db.String(200))

    def save(self, *args, **kwargs):
        super(XmlInputParameter, self).save(*args, **kwargs)
        self.import_handler.update_import_params()

    def delete(self, *args, **kwargs):
        handler = self.import_handler
        super(XmlInputParameter, self).delete(*args, **kwargs)
        handler.update_import_params()

class PredefinedDataSource(db.Model, BaseModel):
    """ Datasource that is used for db exports. """
    TYPE_REQUEST = 'http'
    TYPE_SQL = 'sql'
    TYPES_LIST = (TYPE_REQUEST, TYPE_SQL)

    VENDOR_POSTGRES = 'postgres'
    VENDORS_LIST = (VENDOR_POSTGRES, )

    name = db.Column(db.String(200), nullable=False, unique=True)
    type = db.Column(db.Enum(*TYPES_LIST, name='datasource_types'),
                     default=TYPE_SQL)
    # sample: {"conn": basestring, "vendor": basestring}
    db = deferred(db.Column(JSONType))

    @property
    def safe_db(self):
        if not self.can_edit:
            return {
                'vendor': self.db['vendor'],
                'conn': re.sub(PASSWORD_REGEX, HIDDEN_PASSWORD,
                               self.db['conn'])
            }
        return self.db

    @validates('db')
    def validate_db(self, key, db):
        self.validate_db_fields(db)
        return db

    @classmethod
    def validate_db_fields(cls, db):
        key = 'db'
        assert 'vendor' in db, assertion_msg(key, 'vendor is required')
        assert db['vendor'] in cls.VENDORS_LIST, assertion_msg(
            key, 'choose vendor from %s' % ', '.join(cls.VENDORS_LIST))
        assert 'conn' in db, assertion_msg(key, 'conn is required')

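# Sketch of ``safe_db`` masking (illustrative; assumes PASSWORD_REGEX
# matches the password clause of the connection string and the current
# user lacks edit rights):
#
#   ds.db = {'vendor': 'postgres',
#            'conn': 'host=localhost user=app password=secret'}
#   ds.safe_db  # -> {'vendor': 'postgres', 'conn': '...'} with the
#               #    password portion replaced by HIDDEN_PASSWORD
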
class Instance(BaseModel, db.Model):
    """ Represents an instance which can be used for executing tasks. """
    TYPES_LIST = ['small', 'large']

    name = db.Column(db.String(200), nullable=False, unique=True)
    description = deferred(db.Column(db.Text))
    ip = db.Column(db.String(200), nullable=False)
    type = db.Column(db.Enum(*TYPES_LIST, name='instance_types'),
                     nullable=False)
    is_default = db.Column(db.Boolean, default=False)

    def save(self, commit=True):
        super(Instance, self).save(False)
        if self.is_default:
            # Only one instance may be the default at a time.
            Instance.query\
                .filter(Instance.is_default, Instance.name != self.name)\
                .update({Instance.is_default: False})
        if commit:
            db.session.commit()

    def __repr__(self):
        return "<Instance %s>" % self.name

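# ``Instance.save`` keeps ``is_default`` exclusive: saving an instance
# with is_default=True clears the flag on all others. Sketch
# (illustrative names):
#
#   Instance(name='worker-1', ip='10.0.0.1', type='small',
#            is_default=True).save()
#   Instance(name='worker-2', ip='10.0.0.2', type='large',
#            is_default=True).save()  # worker-1 loses the default flag
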
class ServerModelVerification(BaseModel, db.Model,
                              RefXmlImportHandlerMixin):
    """ Represents a verification of a model deployed to a server. """
    STATUS_NEW = 'New'
    STATUS_QUEUED = 'Queued'
    STATUS_IN_PROGRESS = 'In Progress'
    STATUS_ERROR = 'Error'
    STATUS_DONE = 'Done'
    STATUSES = [
        STATUS_NEW, STATUS_QUEUED, STATUS_IN_PROGRESS, STATUS_ERROR,
        STATUS_DONE
    ]

    status = db.Column(
        db.Enum(*STATUSES, name='model_verification_statuses'),
        nullable=False, default=STATUS_NEW)
    error = db.Column(db.Text)
    server_id = db.Column(db.Integer, db.ForeignKey('server.id'))
    server = relationship(
        Server,
        backref=backref('model_verifications', cascade='all,delete'))
    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    model = relationship(
        Model,
        backref=backref('model_verifications', cascade='all,delete'))
    test_result_id = db.Column(db.Integer, db.ForeignKey('test_result.id'))
    test_result = relationship(
        'TestResult',
        backref=backref('model_verifications', cascade='all,delete'))
    description = db.Column(JSONType)
    result = db.Column(JSONType)
    params_map = db.Column(JSONType)
    clazz = db.Column(db.String(200))

    def __repr__(self):
        return '<ServerModelVerification {0}>'.format(self.model.name)

class XmlSqoop(db.Model, BaseMixin):
    target = db.Column(db.String(200), nullable=False)
    table = db.Column(db.String(200), nullable=False)
    where = db.Column(db.String(200), nullable=True)
    direct = db.Column(db.String(200), nullable=True)
    mappers = db.Column(db.String(200), nullable=True)
    options = db.Column(db.String(200), nullable=True)
    text = db.Column(db.Text, nullable=True)

    FIELDS_TO_SERIALIZE = ['target', 'table', 'where', 'direct',
                           'mappers', 'options']

    # Global datasource
    datasource_id = db.Column(db.ForeignKey('xml_data_source.id',
                                            ondelete='SET NULL'))
    datasource = relationship('XmlDataSource',
                              foreign_keys=[datasource_id])
    entity_id = db.Column(db.ForeignKey('xml_entity.id'))
    entity = relationship(
        'XmlEntity', foreign_keys=[entity_id],
        backref=backref('sqoop_imports', cascade='all,delete',
                        order_by='XmlSqoop.id'))

    @property
    def pig_fields(self):
        from api.async_tasks.models import AsyncTask
        return AsyncTask.get_current_by_object(
            self, 'api.import_handlers.tasks.load_pig_fields')

    def to_dict(self):
        sqoop = super(XmlSqoop, self).to_dict()
        if self.datasource:
            sqoop['datasource'] = self.datasource.name
        return sqoop

class ClassifierGridParams(db.Model, BaseModel):
    STATUS_LIST = ('New', 'Queued', 'Calculating', 'Completed', 'Error')

    model_id = db.Column(db.Integer, db.ForeignKey('model.id'))
    model = relationship(Model, backref=backref('classifier_grid_params'))
    scoring = db.Column(db.String(100), default='accuracy')
    status = db.Column(
        db.Enum(*STATUS_LIST, name='classifier_grid_params_statuses'),
        nullable=False, default='New')
    train_data_set_id = db.Column(
        db.Integer, db.ForeignKey('data_set.id', ondelete='SET NULL'))
    train_dataset = relationship('DataSet',
                                 foreign_keys=[train_data_set_id])
    test_data_set_id = db.Column(
        db.Integer, db.ForeignKey('data_set.id', ondelete='SET NULL'))
    test_dataset = relationship('DataSet', foreign_keys=[test_data_set_id])
    parameters = db.Column(JSONType)
    parameters_grid = db.Column(JSONType)

class XmlEntity(db.Model, BaseMixin, RefXmlImportHandlerMixin):
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(200), nullable=False)
    autoload_fields = db.Column(db.Boolean, default=False)

    # JSON or CSV field as datasource
    transformed_field_id = db.Column(db.ForeignKey(
        'xml_field.id', use_alter=True,
        name="fk_transformed_field", ondelete='SET NULL'))
    transformed_field = relationship(
        'XmlField', post_update=True,
        foreign_keys=[transformed_field_id],
        backref='entities_for_field_ds')

    # Sub entity
    entity_id = db.Column(db.ForeignKey('xml_entity.id'))
    entity = relationship('XmlEntity', remote_side=[id],
                          backref=backref('entities',
                                          cascade='all,delete'))

    # Global datasource
    datasource_id = db.Column(db.ForeignKey('xml_data_source.id',
                                            ondelete='CASCADE'))
    datasource = relationship('XmlDataSource',
                              foreign_keys=[datasource_id])

    query_id = db.Column(db.ForeignKey('xml_query.id'))
    query_obj = relationship('XmlQuery', foreign_keys=[query_id],
                             cascade='all,delete',
                             backref='parent_entity')

    def __repr__(self):
        return "<Entity %s>" % self.name

    def to_dict(self):
        ent = {'name': self.name}
        if self.transformed_field:
            ent['datasource'] = self.transformed_field.name
        if self.datasource:
            ent['datasource'] = self.datasource.name
        if self.autoload_fields:
            ent['autoload_fields'] = str(self.autoload_fields).lower()
        return ent

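# ``XmlEntity.to_dict`` note: both the transformed field and the global
# datasource write to the same 'datasource' key, so the global datasource
# wins when both are set. Sketch of typical output (illustrative values):
#
#   {'name': 'contractor', 'datasource': 'odw', 'autoload_fields': 'true'}
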
class LogMessage(BaseMixin, db.Model):
    TRAIN_MODEL = 'trainmodel_log'
    IMPORT_DATA = 'importdata_log'
    RUN_TEST = 'runtest_log'
    CONFUSION_MATRIX_LOG = 'confusion_matrix_log'
    GRID_SEARCH = 'gridsearch_log'
    TYPES_LIST = (TRAIN_MODEL, IMPORT_DATA, RUN_TEST,
                  CONFUSION_MATRIX_LOG, GRID_SEARCH)
    LEVELS_LIST = [
        'CRITICAL', 'ERROR', 'WARN', 'WARNING', 'INFO', 'DEBUG', 'NOTSET'
    ]

    id = db.Column(Integer, primary_key=True)
    level = db.Column(Enum(*LEVELS_LIST, name='log_levels'))
    params = deferred(db.Column(JSONType))
    content = deferred(db.Column(String(600)))
    type = db.Column(Enum(*TYPES_LIST, name='log_types'))
    created_on = db.Column(db.DateTime, server_default=func.now())

    @classmethod
    def delete_related_logs(cls, obj):
        # TODO: fill with code
        pass

class PredictResultLabel(db.Model, RefPredictModelMixin):
    FIELDS_TO_SERIALIZE = ('script', 'predict_model')

    script = db.Column(db.Text)

def params(cls):
    return db.Column(JSONType)

class PredictResultProbability(db.Model, RefPredictModelMixin):
    FIELDS_TO_SERIALIZE = ('label', 'script', 'predict_model')

    script = db.Column(db.Text)
    label = db.Column(db.String(200))

class Cluster(BaseModel, db.Model):
    STATUS_NEW = 'New'
    STATUS_STARTING = 'Starting'
    STATUS_RUNNING = 'Running'
    STATUS_WAITING = 'Waiting'
    STATUS_ERROR = 'Error'
    STATUS_TERMINATED = 'Terminated'
    STATUS_TERMINATING = 'Terminating'
    STATUS_TERMINATED_WITH_ERRORS = 'Terminated_with_errors'
    # EMR job flow statuses: ['STARTING', 'BOOTSTRAPPING', 'RUNNING',
    # 'WAITING', 'TERMINATING', 'TERMINATED', 'TERMINATED_WITH_ERRORS']
    PENDING = -1
    PORT_RANGE = (9000, 9010)
    STATUSES = [
        STATUS_NEW, STATUS_STARTING, STATUS_RUNNING, STATUS_WAITING,
        STATUS_ERROR, STATUS_TERMINATED, STATUS_TERMINATING,
        STATUS_TERMINATED_WITH_ERRORS
    ]
    TERMINATED_STATUSES = [
        STATUS_TERMINATED, STATUS_TERMINATING,
        STATUS_TERMINATED_WITH_ERRORS
    ]
    ACTIVE_STATUSES = [
        STATUS_NEW, STATUS_STARTING, STATUS_RUNNING, STATUS_WAITING,
        STATUS_ERROR
    ]

    jobflow_id = db.Column(db.String(200), nullable=False, unique=True)
    master_node_dns = db.Column(db.String(200), nullable=True)
    port = db.Column(db.Integer, nullable=True)
    pid = db.Column(db.Integer, nullable=True)
    status = db.Column(db.Enum(*STATUSES, name='cluster_statuses'),
                       default=STATUS_NEW)
    logs_folder = db.Column(db.String(200), nullable=True)
    # FIXME: do we need this field?
    is_default = db.Column(db.Boolean, default=False)

    def generate_port(self):
        """
        Generates a random port which isn't used by any other cluster.
        Available ports are in the range `PORT_RANGE`.
        """
        exclude = set([
            cl[0] for cl in
            Cluster.query.with_entities(Cluster.port).filter(
                Cluster.status != Cluster.STATUS_TERMINATED,
                Cluster.status != Cluster.STATUS_ERROR)
        ])
        ports = list(set(xrange(*self.PORT_RANGE)) - exclude)
        if ports:
            self.port = random.choice(ports)
        else:
            raise ValueError('All ports are busy')

    @property
    def tunnels(self):
        from api.async_tasks.models import AsyncTask
        return AsyncTask.get_current_by_object(
            self, 'api.instances.tasks.run_ssh_tunnel')

    @property
    def active_tunnel(self):
        return self.pid

    def create_ssh_tunnel(self):
        from api.instances.tasks import run_ssh_tunnel
        if self.pid is None:
            # task delayed
            self.pid = self.PENDING
            self.save()
            run_ssh_tunnel.delay(self.id)

    def terminate_ssh_tunnel(self):
        import os
        import signal
        if self.pid is not None and self.pid != self.PENDING:
            try:
                os.kill(self.pid, signal.SIGKILL)
            except Exception as exc:
                logging.error("Unknown error occurred while killing the "
                              "tunnel process: {0}".format(exc))
            self.pid = None
            self.save()

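# Usage sketch for ``Cluster.generate_port`` (illustrative): a free port
# is drawn from PORT_RANGE, excluding ports held by clusters that are not
# terminated or errored.
#
#   cluster = Cluster(jobflow_id='j-ABC123')   # jobflow id is assumed
#   cluster.generate_port()                    # port in [9000, 9010)
#   cluster.save()
#   # Raises ValueError('All ports are busy') once every port is taken.
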
class FeatureSet(ExportImportMixin, BaseModel, db.Model):
    """ Represents a list of features with a schema name. """
    FIELDS_TO_SERIALIZE = ('schema_name', )
    FEATURES_STRUCT = {'schema-name': '', 'features': [],
                       "feature-types": []}

    schema_name = db.Column(db.String(200), nullable=False,
                            default='noname')
    group_by = relationship('Feature', secondary=lambda: group_by_table,
                            backref='group_by_feature_set')
    target_variable = db.Column(db.String(200))
    features_count = db.Column(db.Integer, default=0)
    features_dict = db.Column(JSONType)
    modified = db.Column(db.Boolean, default=False)
    locked = db.Column(db.Boolean, default=False)

    __table_args__ = (CheckConstraint(
        features_count >= 0, name='check_features_count_positive'), {})

    def __repr__(self):
        return '<Feature Set {0} ({1})>'.format(self.schema_name,
                                                self.target_variable)

    @property
    def features(self):
        if self.features_dict is None or self.modified:
            self.features_dict = self.to_dict()
            self.modified = False
            self.save()
        return self.features_dict

    def from_dict(self, features_dict, commit=True):
        if features_dict is None or \
                not isinstance(features_dict, dict):
            raise ValueError('should be a dictionary')
        self.schema_name = features_dict['schema-name']
        self.group_by = []

        type_list = features_dict.get('feature-types', None)
        if type_list:
            for feature_type in type_list:
                count = NamedFeatureType.query.filter_by(
                    name=feature_type['name']).count()
                if not count:
                    ntype = NamedFeatureType()
                    ntype.from_dict(feature_type, commit=False)

        group_by_exist = 'group-by' in features_dict
        for feature_dict in features_dict['features']:
            feature = Feature(feature_set=self)
            feature.from_dict(feature_dict, commit=False)
            if group_by_exist and \
                    feature.name in features_dict['group-by']:
                self.group_by.append(feature)

        if commit:
            db.session.commit()
            db.session.expire(
                self,
                ['target_variable', 'features_count', 'features_dict'])

    def to_dict(self):
        features_dict = {
            'schema-name': self.schema_name,
            'group-by': [f.name for f in self.group_by],
            'features': [],
            "feature-types": []
        }
        types = []
        for feature in Feature.query.filter_by(feature_set=self):
            if feature.type not in NamedFeatureType.TYPES_LIST:
                types.append(feature.type)
            features_dict['features'].append(feature.to_dict())
        for ftype in set(types):
            named_type = NamedFeatureType.query.filter_by(name=ftype).one()
            features_dict['feature-types'].append(named_type.to_dict())
        return features_dict

    def _check_locked(self):
        if self.locked:
            self.reason_msg = 'The model referring to this feature set ' \
                              'is deployed and blocked for modifications.'
            return False
        return True

    @property
    def can_edit(self):
        return self._check_locked() and super(FeatureSet, self).can_edit

    @property
    def can_delete(self):
        return self._check_locked() and super(FeatureSet, self).can_delete

    def save(self, commit=True):
        # TODO: why doesn't the column-level default apply here?
        if self.features_dict is None:
            self.features_dict = self.FEATURES_STRUCT
        self.features_dict['schema-name'] = self.schema_name
        BaseModel.save(self, commit=commit)

    def delete(self):
        features = Feature.query.filter(
            Feature.feature_set_id == self.id).all()
        for feature in features:
            feature.delete()
        super(FeatureSet, self).delete()

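# Sketch of the dict shape ``FeatureSet.from_dict`` consumes (and
# ``to_dict`` emits); the concrete features and the key spellings inside
# each feature dict are assumptions:
#
#   fs = FeatureSet()
#   fs.from_dict({
#       'schema-name': 'bestmatch',
#       'features': [
#           {'name': 'hire_outcome', 'type': 'int'},
#           {'name': 'country', 'type': 'categorical'},
#       ],
#       'group-by': ['country'],
#       'feature-types': [],
#   })
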
class Feature(ExportImportMixin, RefFeatureSetMixin, BaseModel, db.Model):
    FIELDS_TO_SERIALIZE = ('name', 'type', 'input_format', 'params',
                           'default', 'is_target_variable', 'required',
                           'transformer', 'scaler', 'disabled')

    name = db.Column(db.String(200), nullable=False)
    type = db.Column(db.String(200), nullable=False)
    input_format = db.Column(db.String(200))
    default = db.Column(JSONType)  # TODO: think about type
    required = db.Column(db.Boolean, default=True)
    is_target_variable = db.Column(db.Boolean, default=False)
    disabled = db.Column(db.Boolean, default=False, nullable=False,
                         server_default='false')
    params = deferred(db.Column(JSONType, default={}))
    transformer = deferred(db.Column(JSONType))
    scaler = deferred(db.Column(JSONType))

    __table_args__ = (UniqueConstraint('feature_set_id', 'name',
                                       name='name_unique'), )

    def __repr__(self):
        return '<Feature %s>' % self.name

    def transformer_type(self):
        if self.transformer is None:
            return None
        return self.transformer['type']

    def scaler_type(self):
        if self.scaler is None:
            return None
        return self.scaler['type']

    def save(self, commit=True):
        super(Feature, self).save(commit=False)
        if self.is_target_variable:
            # Only one target variable is allowed per feature set.
            Feature.query\
                .filter(
                    Feature.is_target_variable,
                    Feature.name != self.name,
                    Feature.feature_set_id == self.feature_set_id)\
                .update({Feature.is_target_variable: False})
            self.required = True
        if commit:
            db.session.commit()

    @property
    def can_delete(self):
        if self.is_target_variable:
            self.reason_msg = "Target variable can not be deleted"
            return False
        return super(Feature, self).can_delete

    @staticmethod
    def field_type_to_feature_type(field_type):
        if field_type == 'float' or field_type == 'boolean':
            return field_type
        elif field_type == 'integer':
            return 'int'
        elif field_type == 'string':
            return 'text'
        elif field_type == 'json':
            return 'map'
        else:
            return 'text'

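# ``field_type_to_feature_type`` maps import-handler field types onto
# feature types, falling back to 'text' for anything unrecognized:
#
#   Feature.field_type_to_feature_type('integer')  # -> 'int'
#   Feature.field_type_to_feature_type('json')     # -> 'map'
#   Feature.field_type_to_feature_type('uuid')     # -> 'text' (fallback)
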
def feature_set_id(cls):
    return db.Column('feature_set_id', db.ForeignKey('feature_set.id'))

group_by_table = db.Table(
    'group_by_table', db.Model.metadata,
    db.Column('feature_set_id', db.Integer,
              db.ForeignKey('feature_set.id', ondelete='CASCADE',
                            onupdate='CASCADE')),
    db.Column('feature_id', db.Integer,
              db.ForeignKey('feature.id', ondelete='CASCADE',
                            onupdate='CASCADE')))


@event.listens_for(Feature, "after_insert")
def after_insert_feature(mapper, connection, target):
    if target.feature_set is None and target.feature_set_id is not None:
        from sqlalchemy.orm import joinedload
        target = target.__class__.query.options(
            joinedload('feature_set')).get(target.id)
    if target.feature_set is not None:
        update_feature_set_on_change_features(
            connection, target.feature_set, target)