class InputValue(DBObject):
    """Single value stored in the `values` list of an Input.

    All three fields are strings that default to the empty string.
    """

    FIELDS = {
        'node_id': DBObjectField(type=str, default='', is_list=False),
        'output_id': DBObjectField(type=str, default='', is_list=False),
        'resource_id': DBObjectField(type=str, default='', is_list=False),
    }

    def __str__(self):
        # Short form: only the (node_id, output_id) pair is shown.
        node_id, output_id = self.node_id, self.output_id
        return 'InputValue({}, {})'.format(node_id, output_id)

    def __repr__(self):
        # Long form: string rendering of whatever to_dict() returns.
        as_text = str(self.to_dict())
        return 'InputValue({})'.format(as_text)
class Output(DBObject):
    """Named output slot with a file type and an optional resource id."""

    FIELDS = {
        'name': DBObjectField(type=str, default='', is_list=False),
        'file_type': DBObjectField(type=str, default='', is_list=False),
        # NOTE(review): defaults to None here, while InputValue.resource_id
        # defaults to '' -- confirm the difference is intentional.
        'resource_id': DBObjectField(type=str, default=None, is_list=False),
    }

    def __str__(self):
        # Short form: identified by name only.
        output_name = self.name
        return 'Output(name="{}")'.format(output_name)

    def __repr__(self):
        # Long form: string rendering of whatever to_dict() returns.
        as_text = str(self.to_dict())
        return 'Output({})'.format(as_text)
class MasterState(DBObject):
    """Master statistics snapshot.

    Holds its own `_id` and a list of WorkerState entries; stored in the
    collection named by `Collections.MASTER_STATE`.
    """

    FIELDS = {
        # `default=ObjectId` passes the class itself, not an instance.
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'workers': DBObjectField(type=WorkerState, default=list, is_list=True),
    }

    DB_COLLECTION = Collections.MASTER_STATE
class ParameterEnum(DBObject):
    """Enum value: a list of string choices plus an `index` field."""

    FIELDS = {
        'values': DBObjectField(type=str, default=list, is_list=True),
        # NOTE(review): declared as type=str but the default is the int -1;
        # confirm which type is intended before relying on either.
        'index': DBObjectField(type=str, default=-1, is_list=False),
    }

    def __repr__(self):
        # String rendering of whatever to_dict() returns.
        as_text = str(self.to_dict())
        return 'ParameterEnum({})'.format(as_text)
class ParameterCode(DBObject):
    """Code value: the source text plus the `mode` it is written in."""

    FIELDS = {
        'value': DBObjectField(type=str, default='', is_list=False),
        'mode': DBObjectField(type=str, default='python', is_list=False),
    }

    # Unused
    MODES = {'python'}

    def __repr__(self):
        # String rendering of whatever to_dict() returns.
        as_text = str(self.to_dict())
        return 'ParameterCode({})'.format(as_text)
class WorkerState(DBObject):
    """Status of the worker.

    Declares the worker's id and host, the graph id and Node it is
    associated with, and a counter of finished jobs.
    """

    FIELDS = {
        # `default=ObjectId` passes the class itself, not an instance.
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'worker_id': DBObjectField(type=str, default=None, is_list=False),
        'graph_id': DBObjectField(type=ObjectId, default=None, is_list=False),
        'node': DBObjectField(type=Node, default=None, is_list=False),
        'host': DBObjectField(type=str, default='', is_list=False),
        'num_finished_jobs': DBObjectField(type=int, default=0, is_list=False),
    }
class GraphCancellation(DBObject):
    """GraphCancellation represents Graph Cancellation event in the database.

    Stored in the collection named by `Collections.GRAPHS_CANCELLATIONS`.
    """

    FIELDS = {
        # `default=ObjectId` passes the class itself, not an instance.
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'graph_id': DBObjectField(type=ObjectId, default=None, is_list=False),
        'acknowledged': DBObjectField(type=bool, default=False, is_list=False),
    }

    DB_COLLECTION = Collections.GRAPHS_CANCELLATIONS
class Input(DBObject):
    """Named input slot of a node.

    Carries the accepted file types, the attached InputValue entries and
    the `min_count` / `max_count` bounds (both default to 1).
    """

    FIELDS = {
        'name': DBObjectField(type=str, default='', is_list=False),
        'file_types': DBObjectField(type=str, default=list, is_list=True),
        'values': DBObjectField(type=InputValue, default=list, is_list=True),
        'min_count': DBObjectField(type=int, default=1, is_list=False),
        'max_count': DBObjectField(type=int, default=1, is_list=False),
    }

    def __str__(self):
        # Short form: identified by name only.
        input_name = self.name
        return 'Input(name="{}")'.format(input_name)

    def __repr__(self):
        # Long form: string rendering of whatever to_dict() returns.
        as_text = str(self.to_dict())
        return 'Input({})'.format(as_text)
class ParameterWidget(DBObject):
    """Widget attached to a Parameter; carries only an `alias` string."""

    FIELDS = {
        'alias': DBObjectField(type=str, default='', is_list=False),
    }

    def __str__(self):
        # Short form: identified by alias only.
        widget_alias = self.alias
        return 'ParameterWidget(alias="{}")'.format(widget_alias)

    def __repr__(self):
        # Long form: string rendering of whatever to_dict() returns.
        as_text = str(self.to_dict())
        return 'ParameterWidget({})'.format(as_text)
class Graph(DBObject):
    """Basic graph with db interface.

    A graph owns a list of Node objects plus title/description metadata,
    a running status, an author id and a `public` flag. It is stored in
    the collection named by `Collections.GRAPHS`.
    """

    FIELDS = {
        # `default=ObjectId` passes the class itself, not an instance.
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'title': DBObjectField(type=str, default='Title', is_list=False),
        'description': DBObjectField(
            type=str,
            default='Description',
            is_list=False,
        ),
        'graph_running_status': DBObjectField(
            type=str,
            default=GraphRunningStatus.CREATED,
            is_list=False,
        ),
        'author': DBObjectField(type=ObjectId, default=None, is_list=False),
        'public': DBObjectField(type=bool, default=False, is_list=False),
        'nodes': DBObjectField(type=Node, default=list, is_list=True),
    }

    DB_COLLECTION = Collections.GRAPHS

    def cancel(self):
        """Cancel the graph.

        Sets the graph status to CANCELED, sets every node that is RUNNING
        or IN_QUEUE to CANCELED, then saves the graph with `force=True`.
        """
        self.graph_running_status = GraphRunningStatus.CANCELED
        for node in self.nodes:
            if node.node_running_status in [
                    NodeRunningStatus.RUNNING,
                    NodeRunningStatus.IN_QUEUE,
            ]:
                node.node_running_status = NodeRunningStatus.CANCELED
        self.save(force=True)

    def get_validation_error(self):
        """Validate Graph.

        Checks that `title` and `description` are not empty and collects
        the validation error of every node that reports one.

        Return:
            (ValidationError)   Validation error if found; else None
        """
        violations = []
        if self.title == '':
            violations.append(
                ValidationError(
                    target=ValidationTargetType.PROPERTY,
                    object_id='title',
                    validation_code=ValidationCode.MISSING_PARAMETER))
        if self.description == '':
            violations.append(
                ValidationError(
                    target=ValidationTargetType.PROPERTY,
                    object_id='description',
                    validation_code=ValidationCode.MISSING_PARAMETER))

        # Meaning the node is in the graph.
        # Otherwise shouldn't be in validation step
        for node in self.nodes:
            node_violation = node.get_validation_error()
            if node_violation:
                violations.append(node_violation)

        if not violations:
            return None

        # Individual problems are reported as children of one graph error.
        return ValidationError(
            target=ValidationTargetType.GRAPH,
            object_id=str(self._id),
            validation_code=ValidationCode.IN_DEPENDENTS,
            children=violations)

    def arrange_auto_layout(self):
        """Use heuristic to rearrange nodes.

        Every node gets a level: nodes that no other node reads from are
        level 0, and a node that feeds other nodes is placed at least one
        level above each of them. Levels become columns (highest level in
        the leftmost column) and the position inside a level becomes the
        row. The result is written to `x` and `y` of each node; nothing is
        returned or saved.

        A graph without nodes is left untouched.
        """
        HEADER_HEIGHT = 23
        DESCRIPTION_HEIGHT = 20
        FOOTER_HEIGHT = 10
        BORDERS_HEIGHT = 2
        ITEM_HEIGHT = 20
        SPACE_HEIGHT = 50
        LEFT_PADDING = 30
        TOP_PADDING = 80
        LEVEL_WIDTH = 252
        SPECIAL_PARAMETER_HEIGHT = 20
        SPECIAL_PARAMETER_TYPES = [ParameterTypes.CODE]
        min_node_height = (
            HEADER_HEIGHT + DESCRIPTION_HEIGHT + FOOTER_HEIGHT + BORDERS_HEIGHT
        )

        if not self.nodes:
            # Nothing to lay out. Without this guard the max() over the
            # (empty) level map below raises ValueError.
            return

        def parent_ids_of(node):
            """Ids of the nodes `node` reads from, in input/value order."""
            return [
                to_object_id(value.node_id)
                for node_input in node.inputs
                for value in node_input.values
            ]

        # Unknown ids read as level -1 until a level is assigned.
        node_id_to_level = defaultdict(lambda: -1)
        node_id_to_node = {}
        queued_node_ids = set()
        children_ids = defaultdict(set)

        node_ids = {node._id for node in self.nodes}
        non_zero_node_ids = set()
        for node in self.nodes:
            node_id_to_node[node._id] = node
            for parent_node_id in parent_ids_of(node):
                non_zero_node_ids.add(parent_node_id)
                children_ids[parent_node_id].add(node._id)

        # Leaves are the nodes nobody reads from; they seed the traversal.
        leaves = node_ids - non_zero_node_ids
        to_visit = deque()
        for leaf_id in leaves:
            node_id_to_level[leaf_id] = 0
            to_visit.append(leaf_id)

        # Breadth-first walk from the leaves towards their producers.
        while to_visit:
            node_id = to_visit.popleft()
            node = node_id_to_node[node_id]
            node_level = max(
                [node_id_to_level[node_id]] +
                [
                    node_id_to_level[child_id] + 1
                    for child_id in children_ids[node_id]
                ]
            )
            node_id_to_level[node_id] = node_level
            for parent_node_id in parent_ids_of(node):
                parent_level = node_id_to_level[parent_node_id]
                node_id_to_level[parent_node_id] = max(
                    node_level + 1, parent_level)
                if parent_node_id not in queued_node_ids:
                    to_visit.append(parent_node_id)
                    queued_node_ids.add(parent_node_id)

        max_level = max(node_id_to_level.values())
        level_to_node_ids = defaultdict(list)
        row_heights = defaultdict(int)

        def get_index_helper(node, level):
            """Row of the first node on `level` that `node` reads from."""
            if level < 0:
                return 0
            parent_node_ids = set(parent_ids_of(node))
            for index, node_id in enumerate(level_to_node_ids[level]):
                if node_id in parent_node_ids:
                    return index
            return -1

        def get_index(node, max_level, level):
            """Sort key: parent rows on every level above `level`."""
            return tuple(
                get_index_helper(node, lvl)
                for lvl in range(max_level, level, -1)
            )

        for node_id, level in node_id_to_level.items():
            level_to_node_ids[level].append(node_id)

        # Order each level, starting from the highest one, so that the sort
        # key of a lower level sees the final order of the levels above it.
        for level in range(max_level, -1, -1):
            index_to_node_id = [
                (get_index(node_id_to_node[node_id], max_level, level), node_id)
                for node_id in level_to_node_ids[level]
            ]
            index_to_node_id.sort()
            level_to_node_ids[level] = [
                node_id for _, node_id in index_to_node_id
            ]

            for index, node_id in enumerate(level_to_node_ids[level]):
                node = node_id_to_node[node_id]
                special_parameters_count = sum(
                    1
                    for parameter in node.parameters
                    if parameter.parameter_type in SPECIAL_PARAMETER_TYPES
                    and parameter.widget
                )
                node_height = (
                    min_node_height +
                    ITEM_HEIGHT * max(len(node.inputs), len(node.outputs)) +
                    special_parameters_count * SPECIAL_PARAMETER_HEIGHT
                )
                # A row is as tall as the tallest node placed in it.
                row_heights[index] = max(row_heights[index], node_height)

        # cum_heights[i] is the vertical offset at which row i starts.
        cum_heights = [0]
        for index in range(len(row_heights)):
            cum_heights.append(
                cum_heights[-1] + row_heights[index] + SPACE_HEIGHT)
        max_height = max(cum_heights)

        for level in range(max_level, -1, -1):
            level_node_ids = level_to_node_ids[level]
            level_height = cum_heights[len(level_node_ids)]
            # Shift shorter levels down by half of the unused height.
            level_padding = (max_height - level_height) // 2
            for index, node_id in enumerate(level_node_ids):
                node = node_id_to_node[node_id]
                node.x = LEFT_PADDING + (max_level - level) * LEVEL_WIDTH
                node.y = TOP_PADDING + level_padding + cum_heights[index]

    def __str__(self):
        return 'Graph(_id="{}", nodes={})'.format(
            self._id, [str(b) for b in self.nodes])

    def __repr__(self):
        return 'Graph(_id="{}", title="{}", nodes={})'.format(
            self._id, self.title, str(self.nodes))
class Node(DBObject):
    """Basic Node with db interface.

    A node has inputs, outputs, parameters and logs, two status fields,
    layout coordinates (`x`, `y`) and ownership metadata. It is stored in
    the 'nodes' collection.
    """

    FIELDS = {
        # `default=ObjectId` passes the class itself, not an instance.
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'title': DBObjectField(type=str, default='Title', is_list=False),
        'description': DBObjectField(
            type=str,
            default='Description',
            is_list=False,
        ),
        'base_node_name': DBObjectField(
            type=str,
            default='bash_jinja2',
            is_list=False,
        ),
        'parent_node': DBObjectField(
            type=ObjectId,
            default=None,
            is_list=False,
        ),
        'successor_node': DBObjectField(
            type=ObjectId,
            default=None,
            is_list=False,
        ),
        'inputs': DBObjectField(type=Input, default=list, is_list=True),
        'outputs': DBObjectField(type=Output, default=list, is_list=True),
        'parameters': DBObjectField(
            type=Parameter,
            default=list,
            is_list=True,
        ),
        'logs': DBObjectField(type=Output, default=list, is_list=True),
        'node_running_status': DBObjectField(
            type=str,
            default=NodeRunningStatus.CREATED,
            is_list=False,
        ),
        'node_status': DBObjectField(
            type=str,
            default=NodeStatus.CREATED,
            is_list=False,
        ),
        'cache_url': DBObjectField(type=str, default='', is_list=False),
        'x': DBObjectField(type=int, default=0, is_list=False),
        'y': DBObjectField(type=int, default=0, is_list=False),
        'author': DBObjectField(type=ObjectId, default=None, is_list=False),
        'public': DBObjectField(type=bool, default=False, is_list=False),
    }

    DB_COLLECTION = 'nodes'

    def get_validation_error(self):
        """Validate Node.

        Checks the title, the min/max count bounds of every input and, for
        nodes whose status is not CREATED, that each input has at least
        `min_count` values and that the node is not MANDATORY_DEPRECATED.

        Return:
            (ValidationError)   Validation error if found; else None
        """

        def input_error(input_name, code):
            # All input-level problems share the same target type.
            return ValidationError(
                target=ValidationTargetType.INPUT,
                object_id=input_name,
                validation_code=code)

        found = []
        if self.title == '':
            found.append(
                ValidationError(
                    target=ValidationTargetType.PROPERTY,
                    object_id='title',
                    validation_code=ValidationCode.MISSING_PARAMETER))

        for node_input in self.inputs:
            lower = node_input.min_count
            upper = node_input.max_count
            if lower < 0:
                found.append(input_error(
                    node_input.name,
                    ValidationCode.MINIMUM_COUNT_MUST_NOT_BE_NEGATIVE))
            # Only a positive max_count is compared against min_count.
            if upper > 0 and lower > upper:
                found.append(input_error(
                    node_input.name,
                    ValidationCode.MINIMUM_COUNT_MUST_BE_GREATER_THAN_MAXIMUM))
            if upper == 0:
                found.append(input_error(
                    node_input.name,
                    ValidationCode.MAXIMUM_COUNT_MUST_NOT_BE_ZERO))

        # Meaning the node is in the graph.
        # Otherwise shouldn't be in validation step
        if self.node_status != NodeStatus.CREATED:
            for node_input in self.inputs:
                if len(node_input.values) < node_input.min_count:
                    found.append(input_error(
                        node_input.name,
                        ValidationCode.MISSING_INPUT))

            if self.node_status == NodeStatus.MANDATORY_DEPRECATED:
                found.append(
                    ValidationError(
                        target=ValidationTargetType.NODE,
                        object_id=str(self._id),
                        validation_code=ValidationCode.DEPRECATED_NODE))

        if len(found) == 0:
            return None

        # Individual problems are reported as children of one node error.
        return ValidationError(
            target=ValidationTargetType.NODE,
            object_id=str(self._id),
            validation_code=ValidationCode.IN_DEPENDENTS,
            children=found)

    def apply_properties(self, other_node):
        """Apply Properties and Inputs of another Node.

        Inputs and parameters are matched by name against the first entry
        of this node with the same name. Input values are taken over only
        when this node's max_count and file types can accommodate the other
        input; parameter values only when the types match and this node's
        parameter has a widget. Description and coordinates are always
        copied.

        Args:
            other_node  (Node):     A node to copy Properties and Inputs from
        """
        for their_input in other_node.inputs:
            for own_input in self.inputs:
                if own_input.name != their_input.name:
                    continue
                # A negative max_count on this side accepts any count.
                if (
                    (own_input.max_count < 0 or
                     own_input.max_count >= their_input.max_count) and
                    set(own_input.file_types) >= set(their_input.file_types)
                ):
                    own_input.values = their_input.values
                # Only the first input with a matching name is considered.
                break

        for their_parameter in other_node.parameters:
            for own_parameter in self.parameters:
                if own_parameter.name != their_parameter.name:
                    continue
                same_type = (
                    own_parameter.parameter_type ==
                    their_parameter.parameter_type
                )
                if same_type and own_parameter.widget:
                    own_parameter.value = their_parameter.value
                # Only the first parameter with a matching name is considered.
                break

        self.description = other_node.description
        self.x = other_node.x
        self.y = other_node.y

    def __str__(self):
        return 'Node(_id="{}")'.format(self._id)

    def __repr__(self):
        as_text = str(self.to_dict())
        return 'Node({})'.format(as_text)

    def _get_custom_element(self, arr, name):
        """Return the first element of `arr` whose `name` equals `name`.

        Raises:
            Exception: if no element matches.
        """
        for element in arr:
            if element.name == name:
                return element
        message = 'Parameter "{}" not found in {}'.format(name, self.title)
        raise Exception(message)

    def get_input_by_name(self, name):
        """Return the input called `name`; raise Exception if missing."""
        return self._get_custom_element(self.inputs, name)

    def get_parameter_by_name(self, name):
        """Return the parameter called `name`; raise Exception if missing."""
        return self._get_custom_element(self.parameters, name)

    def get_output_by_name(self, name):
        """Return the output called `name`; raise Exception if missing."""
        return self._get_custom_element(self.outputs, name)

    def get_log_by_name(self, name):
        """Return the log called `name`; raise Exception if missing."""
        return self._get_custom_element(self.logs, name)
class NodeCache(DBObject):
    """Basic Node Cache with db interface.

    Stores the outputs and logs of a node under a hash key derived from
    the node's parent node, inputs, parameters and the user id. Kept in
    the 'node_cache' collection.
    """

    FIELDS = {
        # `default=ObjectId` passes the class itself, not an instance.
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'key': DBObjectField(type=str, default='', is_list=False),
        'graph_id': DBObjectField(type=ObjectId, default=None, is_list=False),
        'node_id': DBObjectField(type=ObjectId, default=None, is_list=False),
        'outputs': DBObjectField(type=Output, default=list, is_list=True),
        'logs': DBObjectField(type=Output, default=list, is_list=True),
    }

    DB_COLLECTION = 'node_cache'

    # Parameters with these names do not contribute to the cache key.
    IGNORED_PARAMETERS = {'cmd'}

    @staticmethod
    def instantiate(node, graph_id, user_id):
        """Instantiate a Node Cache from Node.

        Args:
            node        (Node):             Node object
            graph_id    (ObjectId, str):    Graph ID
            user_id     (ObjectId, str):    User ID

        Return:
            (NodeCache)
        """
        cache_dict = {
            'key': NodeCache.generate_key(node, user_id),
            'node_id': node._id,
            'graph_id': graph_id,
            'outputs': [output.to_dict() for output in node.outputs],
            'logs': [log.to_dict() for log in node.logs],
        }
        return NodeCache(cache_dict)

    # TODO after Demo: remove user_id
    @staticmethod
    def generate_key(node, user_id):
        """Generate hash.

        The key is the SHA-256 hex digest of the node's parent node id,
        its inputs (sorted by name, each with its sorted resource ids),
        its parameters (sorted by name, skipping IGNORED_PARAMETERS) and
        the user id.

        Args:
            node        (Node):             Node object
            user_id     (ObjectId, str):    User ID

        Return:
            (str)   Hash value
        """
        inputs = node.inputs
        parameters = node.parameters
        parent_node = node.parent_node

        # "<input name>:<sorted resource ids>" for every input, by name.
        input_parts = []
        for node_input in sorted(inputs, key=lambda item: item.name):
            resource_ids = sorted(
                value.resource_id for value in node_input.values)
            input_parts.append(
                '{}:{}'.format(node_input.name, ','.join(resource_ids)))
        inputs_hash = ','.join(input_parts)

        # "<parameter name>:<value>" for every non-ignored parameter.
        parameter_parts = []
        for parameter in sorted(parameters, key=lambda item: item.name):
            if parameter.name in NodeCache.IGNORED_PARAMETERS:
                continue
            parameter_parts.append(
                '{}:{}'.format(parameter.name, parameter.value))
        parameters_hash = ','.join(parameter_parts)

        payload = '{};{};{};{}'.format(
            parent_node, inputs_hash, parameters_hash, str(user_id))
        return hashlib.sha256(payload.encode('utf-8')).hexdigest()

    def __str__(self):
        return 'NodeCache(_id="{}")'.format(self._id)

    def __repr__(self):
        as_text = str(self.to_dict())
        return 'NodeCache({})'.format(as_text)
class User(DBObject):
    """Basic User class with db interface.

    Stores a username and a password hash in the 'users' collection and
    provides helpers to issue and verify signed tokens.
    """

    FIELDS = {
        # `default=ObjectId` passes the class itself, not an instance.
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'username': DBObjectField(type=str, default='', is_list=False),
        'password_hash': DBObjectField(type=str, default='', is_list=False),
    }

    DB_COLLECTION = 'users'

    def hash_password(self, password):
        """Change password.

        Stores the result of `pwd_context.encrypt(password)` in
        `password_hash`; the object is not saved.

        Args:
            password    (str)   Real password string
        """
        self.password_hash = pwd_context.encrypt(password)

    def verify_password(self, password):
        """Verify password.

        Args:
            password    (str)   Real password string

        Return:
            (bool)    True if password matches else False
        """
        return pwd_context.verify(password, self.password_hash)

    def generate_access_token(self, expiration=600):
        """Generate access token.

        Args:
            expiration  (int)   Time to Live (TTL) in sec

        Return:
            (str)   Secured token
        """
        s = TimedSerializer(get_auth_config().secret_key, expires_in=expiration)
        return s.dumps({'username': self.username, 'type': 'access'})

    def generate_refresh_token(self):
        """Generate refresh token.

        Return:
            (str)   Secured token
        """
        s = Serializer(get_auth_config().secret_key)
        return s.dumps({'username': self.username, 'type': 'refresh'})

    def __str__(self):
        return 'User(_id="{}", username={})'.format(self._id, self.username)

    def __repr__(self):
        return 'User({})'.format(self.to_dict())

    def __getattr__(self, name):
        """Reject access to attributes that are not set on the object.

        Python calls this hook only after normal attribute lookup failed.
        AttributeError is the exception the attribute protocol expects, so
        `hasattr()` and `getattr(obj, name, default)` behave normally;
        callers that catch `Exception` still catch it.

        Raises:
            AttributeError: always.
        """
        raise AttributeError("Can't get attribute '{}'".format(name))

    @staticmethod
    def find_user_by_name(username):
        """Find User.

        Args:
            username    (str)   Username

        Return:
            (User)   User object or None
        """
        user_dict = get_db_connector().users.find_one({'username': username})
        if not user_dict:
            return None

        return User(user_dict)

    @staticmethod
    def verify_auth_token(token):
        """Verify token.

        The token is first read as an access token. If its signature is
        rejected or expired, it is read again as a refresh token.

        Args:
            token   (str)   Token

        Return:
            (User)   User object or None
        """
        s = TimedSerializer(get_auth_config().secret_key)
        try:
            data = s.loads(token)
            if data['type'] != 'access':
                raise Exception('Not access token')
        except (BadSignature, SignatureExpired):
            # access token is not valid or expired
            s = Serializer(get_auth_config().secret_key)
            try:
                data = s.loads(token)
                if data['type'] != 'refresh':
                    raise Exception('Not refresh token')
            except Exception:
                return None
        except Exception as e:
            print("Unexpected exception: {}".format(e))
            return None

        user = User.find_user_by_name(data['username'])
        return user
class Parameter(DBObject):
    """Basic Parameter structure.

    The `value` field is stored as-is; its expected type is determined by
    `parameter_type` and enforced in `__init__`.
    """

    FIELDS = {
        'name': DBObjectField(type=str, default='', is_list=False),
        'parameter_type': DBObjectField(
            type=str,
            default=ParameterTypes.STR,
            is_list=False,
        ),
        # TODO make type factory
        'value': DBObjectField(
            type=lambda x: x,   # Preserve type
            default='',
            is_list=False,
        ),
        'mutable_type': DBObjectField(type=bool, default=True, is_list=False),
        'removable': DBObjectField(type=bool, default=True, is_list=False),
        'publicable': DBObjectField(type=bool, default=True, is_list=False),
        'widget': DBObjectField(
            type=ParameterWidget,
            default=None,
            is_list=False,
        ),
    }

    def __init__(self, obj_dict=None):
        """Build the parameter and normalise `value` for its type.

        A `None` value is replaced by the default for `parameter_type`;
        ENUM and CODE values are converted with `from_dict` of
        ParameterEnum / ParameterCode. The result is then checked with an
        assert.

        Args:
            obj_dict    (dict, None)    Passed through to DBObject.__init__
        """
        super(Parameter, self).__init__(obj_dict)

        # `value` field is a special case: the type depends on `parameter_type`
        declared_type = self.parameter_type
        if self.value is None:
            self.value = _get_default_by_type(declared_type)
        elif declared_type == ParameterTypes.ENUM:
            self.value = ParameterEnum.from_dict(self.value)
        elif declared_type == ParameterTypes.CODE:
            self.value = ParameterCode.from_dict(self.value)

        assert _value_is_valid(self.value, self.parameter_type), \
            "Invalid parameter value type: {}: {}".format(self.name, self.value)

    def __str__(self):
        # Short form: identified by name only.
        parameter_name = self.name
        return 'Parameter(name="{}")'.format(parameter_name)

    def __repr__(self):
        # Long form: string rendering of whatever to_dict() returns.
        as_text = str(self.to_dict())
        return 'Parameter({})'.format(as_text)