class WorkerState(DBObject):
    """Worker statistics snapshot."""

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'worker_id': DBObjectField(type=str, default=None, is_list=False),
        'host': DBObjectField(type=str, default='', is_list=False),
        'runs': DBObjectField(type=Node, default=list, is_list=True),
        'kinds': DBObjectField(type=str, default=list, is_list=True),
    }

    DB_COLLECTION = Collections.WORKER_STATE
class Output(DBObject):
    """Basic Output structure."""

    FIELDS = {
        'name': DBObjectField(type=str, default='', is_list=False),
        'file_type': DBObjectField(type=str, default='', is_list=False),
        'resource_id': DBObjectField(type=str, default=None, is_list=False),
    }

    def __str__(self):
        return 'Output(name="{}")'.format(self.name)

    def __repr__(self):
        return 'Output({})'.format(str(self.to_dict()))
class InputValue(DBObject):
    """Basic Value of the Input structure."""

    FIELDS = {
        'node_id': DBObjectField(type=str, default='', is_list=False),
        'output_id': DBObjectField(type=str, default='', is_list=False),
        'resource_id': DBObjectField(type=str, default='', is_list=False),
    }

    def __str__(self):
        return 'InputValue({}, {})'.format(self.node_id, self.output_id)

    def __repr__(self):
        return 'InputValue({})'.format(str(self.to_dict()))
class RunCancellation(DBObject):
    """RunCancellation represents a Run Cancellation event in the database."""

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'run_id': DBObjectField(type=ObjectId, default=None, is_list=False),
    }

    DB_COLLECTION = Collections.RUN_CANCELLATIONS
class MasterState(DBObject):
    """Master statistics snapshot."""

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'workers': DBObjectField(type=WorkerState, default=list, is_list=True),
    }

    DB_COLLECTION = Collections.MASTER_STATE
class ParameterEnum(DBObject):
    """Enum value."""

    FIELDS = {
        'values': DBObjectField(type=str, default=list, is_list=True),
        'index': DBObjectField(type=str, default=-1, is_list=False),
    }

    def __repr__(self):
        return 'ParameterEnum({})'.format(str(self.to_dict()))
class ParameterCode(DBObject):
    """Code value."""

    FIELDS = {
        'value': DBObjectField(type=str, default='', is_list=False),
        'mode': DBObjectField(type=str, default='python', is_list=False),
    }

    def __repr__(self):
        return 'ParameterCode({})'.format(str(self.to_dict()))
class Group(DBObject):
    """Group with db interface."""

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        '_type': DBObjectField(type=str, default='Group', is_list=False),
        'title': DBObjectField(type=str, default='Title', is_list=False),
        # Kind, such as plynx.plugins.executors.local.BashJinja2.
        # Derived from the plynx.plugins.executors.BaseExecutor class.
        'kind': DBObjectField(type=str, default='', is_list=False),
        'items': DBObjectField(type=lambda x: x, default=list, is_list=True),  # Preserve type
        'author': DBObjectField(type=ObjectId, default=None, is_list=False),
        'starred': DBObjectField(type=bool, default=False, is_list=False),
    }

    DB_COLLECTION = Collections.GROUPS

    def __str__(self):
        return 'Group(_id="{}")'.format(self._id)

    def __repr__(self):
        return 'Group({})'.format(str(self.to_dict()))
class WorkerState(DBObject):
    """Status of the worker."""

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'worker_id': DBObjectField(type=str, default=None, is_list=False),
        'graph_id': DBObjectField(type=ObjectId, default=None, is_list=False),
        'node': DBObjectField(type=Node, default=None, is_list=False),
        'host': DBObjectField(type=str, default='', is_list=False),
        'num_finished_jobs': DBObjectField(type=int, default=0, is_list=False),
    }
class GraphCancellation(DBObject):
    """GraphCancellation represents a Graph Cancellation event in the database."""

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'graph_id': DBObjectField(type=ObjectId, default=None, is_list=False),
        'acknowledged': DBObjectField(type=bool, default=False, is_list=False),
    }

    DB_COLLECTION = Collections.GRAPHS_CANCELLATIONS
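# Illustrative helper (not part of the original module): cancelling a running graph
# amounts to inserting a GraphCancellation document keyed by the graph's _id.
# save() is the standard DBObject persistence used throughout this module.
def request_graph_cancellation(graph):
    cancellation = GraphCancellation({'graph_id': graph._id})
    cancellation.save()
    return cancellation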
class ParameterListOfNodes(DBObject):
    """List Of Nodes value."""

    FIELDS = {
        'value': DBObjectField(type=Node, default=list, is_list=True),
    }

    def __repr__(self):
        return 'ParameterListOfNodes({})'.format(str(self.to_dict()))
class Input(DBObject):
    """Basic Input structure."""

    FIELDS = {
        'name': DBObjectField(type=str, default='', is_list=False),
        'file_types': DBObjectField(type=str, default=list, is_list=True),
        'values': DBObjectField(type=InputValue, default=list, is_list=True),
        'min_count': DBObjectField(type=int, default=1, is_list=False),
        'max_count': DBObjectField(type=int, default=1, is_list=False),
    }

    def __str__(self):
        return 'Input(name="{}")'.format(self.name)

    def __repr__(self):
        return 'Input({})'.format(str(self.to_dict()))
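# Illustrative helper (not from the original source): wiring one node's output into
# another node's input by appending an InputValue. `producer` and `consumer` are
# Node objects; get_output_by_name() and get_input_by_name() are the accessors
# defined on the Node class below.
def connect(producer, output_name, consumer, input_name):
    output = producer.get_output_by_name(output_name)
    consumer_input = consumer.get_input_by_name(input_name)
    consumer_input.values.append(InputValue({
        'node_id': str(producer._id),
        'output_id': output.name,
    }))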
class Input(DBObject):
    """Basic Input structure."""

    FIELDS = dict({
        'input_references': DBObjectField(type=InputReference, default=list, is_list=True),
    }, **RESOURCE_FIELDS)

    def __str__(self):
        return 'Input(name="{}")'.format(self.name)

    def __repr__(self):
        return 'Input({})'.format(str(self.to_dict()))
class ParameterWidget(DBObject):
    """Basic ParameterWidget structure."""

    FIELDS = {
        'alias': DBObjectField(type=str, default='', is_list=False),
    }

    def __str__(self):
        return 'ParameterWidget(alias="{}")'.format(self.alias)

    def __repr__(self):
        return 'ParameterWidget({})'.format(str(self.to_dict()))
class NodeCache(DBObject):
    """Basic Node Cache with db interface."""

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'key': DBObjectField(type=str, default='', is_list=False),
        'run_id': DBObjectField(type=ObjectId, default=None, is_list=False),
        'node_id': DBObjectField(type=ObjectId, default=None, is_list=False),
        'outputs': DBObjectField(type=Output, default=list, is_list=True),
        'logs': DBObjectField(type=Output, default=list, is_list=True),
    }

    DB_COLLECTION = Collections.NODE_CACHE

    IGNORED_PARAMETERS = {'cmd', '_timeout'}

    @staticmethod
    def instantiate(node, run_id):
        """Instantiate a Node Cache from Node.

        Args:
            node (Node): Node object
            run_id (ObjectId, str): Run ID

        Return:
            (NodeCache)
        """
        return NodeCache({
            'key': NodeCache.generate_key(node),
            'node_id': node._id,
            'run_id': run_id,
            'outputs': [output.to_dict() for output in node.outputs],
            'logs': [log.to_dict() for log in node.logs],
        })

    @staticmethod
    def generate_key(node):
        """Generate hash.

        Args:
            node (Node): Node object

        Return:
            (str) Hash value
        """
        inputs = node.inputs
        parameters = node.parameters
        original_node_id = node.original_node_id

        sorted_inputs = sorted(inputs, key=lambda x: x.name)
        inputs_hash = ','.join([
            '{}:{}'.format(input.name, ','.join(sorted(input.values)))
            for input in sorted_inputs
        ])

        sorted_parameters = sorted(parameters, key=lambda x: x.name)
        parameters_hash = ','.join([
            '{}:{}'.format(parameter.name, parameter.value)
            for parameter in sorted_parameters
            if parameter.name not in NodeCache.IGNORED_PARAMETERS
        ])

        return hashlib.sha256(';'.join([
            str(original_node_id),
            inputs_hash,
            parameters_hash,
        ]).encode('utf-8')).hexdigest()

    def __str__(self):
        return 'NodeCache(_id="{}")'.format(self._id)

    def __repr__(self):
        return 'NodeCache({})'.format(str(self.to_dict()))
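# Hedged usage sketch (hypothetical helpers, not part of the module): the key from
# generate_key() is deterministic for a given original_node_id, inputs, and
# non-ignored parameters, so a finished node's outputs can be cached and an
# identical node can later be resolved without re-running it. get_db_connector()
# and save() follow the same persistence pattern used by other classes here.
def cache_finished_node(node, run_id):
    node_cache = NodeCache.instantiate(node, run_id)
    node_cache.save()
    return node_cache.key


def find_cached_outputs(node):
    cache_dict = getattr(get_db_connector(), NodeCache.DB_COLLECTION).find_one({
        'key': NodeCache.generate_key(node),
    })
    return NodeCache(cache_dict) if cache_dict else None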
class Parameter(DBObject):
    """Basic Parameter structure."""

    FIELDS = {
        'name': DBObjectField(type=str, default='', is_list=False),
        'parameter_type': DBObjectField(type=str, default=ParameterTypes.STR, is_list=False),
        # TODO make type factory
        'value': DBObjectField(type=lambda x: x, default='', is_list=False),  # Preserve type
        'mutable_type': DBObjectField(type=bool, default=True, is_list=False),
        'removable': DBObjectField(type=bool, default=True, is_list=False),
        'publicable': DBObjectField(type=bool, default=True, is_list=False),
        'widget': DBObjectField(type=str, default=None, is_list=False),
        # Link to global parameter
        'reference': DBObjectField(type=str, default=None, is_list=False),
    }

    def __init__(self, obj_dict=None):
        super(Parameter, self).__init__(obj_dict)

        # `value` field is a special case: the type depends on `parameter_type`
        if self.value is None:
            self.value = _get_default_by_type(self.parameter_type)
        elif self.parameter_type == ParameterTypes.ENUM:
            self.value = ParameterEnum.from_dict(self.value)
        elif self.parameter_type == ParameterTypes.CODE:
            self.value = ParameterCode.from_dict(self.value)
        elif self.parameter_type == ParameterTypes.LIST_NODE:
            self.value = ParameterListOfNodes.from_dict(self.value)
        if not _value_is_valid(self.value, self.parameter_type):
            raise ValueError("Invalid parameter value type: {}: {}".format(self.name, self.value))

    def __str__(self):
        return 'Parameter(name="{}")'.format(self.name)

    def __repr__(self):
        return 'Parameter({})'.format(str(self.to_dict()))
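# A minimal sketch of how Parameter.__init__ resolves `value` from `parameter_type`.
# The dicts mirror the FIELDS layout above; the concrete names and values are
# illustrative, and it assumes _value_is_valid() accepts the wrapper types.
enum_param = Parameter({
    'name': 'mode',
    'parameter_type': ParameterTypes.ENUM,
    'value': {'values': ['train', 'test'], 'index': 0},
})
# enum_param.value is now a ParameterEnum instance

code_param = Parameter({
    'name': 'cmd',
    'parameter_type': ParameterTypes.CODE,
    'value': {'value': 'echo hello', 'mode': 'sh'},
})
# code_param.value is now a ParameterCode instance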
class Node(DBObject):
    """Basic Node with db interface."""

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'title': DBObjectField(type=str, default='Title', is_list=False),
        'description': DBObjectField(type=str, default='Description', is_list=False),
        'base_node_name': DBObjectField(type=str, default='bash_jinja2', is_list=False),
        'parent_node': DBObjectField(type=ObjectId, default=None, is_list=False),
        'successor_node': DBObjectField(type=ObjectId, default=None, is_list=False),
        'inputs': DBObjectField(type=Input, default=list, is_list=True),
        'outputs': DBObjectField(type=Output, default=list, is_list=True),
        'parameters': DBObjectField(type=Parameter, default=list, is_list=True),
        'logs': DBObjectField(type=Output, default=list, is_list=True),
        'node_running_status': DBObjectField(type=str, default=NodeRunningStatus.CREATED, is_list=False),
        'node_status': DBObjectField(type=str, default=NodeStatus.CREATED, is_list=False),
        'cache_url': DBObjectField(type=str, default='', is_list=False),
        'x': DBObjectField(type=int, default=0, is_list=False),
        'y': DBObjectField(type=int, default=0, is_list=False),
        'author': DBObjectField(type=ObjectId, default=None, is_list=False),
        'starred': DBObjectField(type=bool, default=False, is_list=False),
    }

    DB_COLLECTION = Collections.NODES

    def get_validation_error(self):
        """Validate Node.

        Return:
            (ValidationError) Validation error if found; else None
        """
        violations = []
        if self.title == '':
            violations.append(
                ValidationError(
                    target=ValidationTargetType.PROPERTY,
                    object_id='title',
                    validation_code=ValidationCode.MISSING_PARAMETER
                ))

        for input in self.inputs:
            if input.min_count < 0:
                violations.append(
                    ValidationError(
                        target=ValidationTargetType.INPUT,
                        object_id=input.name,
                        validation_code=ValidationCode.MINIMUM_COUNT_MUST_NOT_BE_NEGATIVE
                    ))
            if input.min_count > input.max_count and input.max_count > 0:
                violations.append(
                    ValidationError(
                        target=ValidationTargetType.INPUT,
                        object_id=input.name,
                        validation_code=ValidationCode.MINIMUM_COUNT_MUST_BE_GREATER_THAN_MAXIMUM
                    ))
            if input.max_count == 0:
                violations.append(
                    ValidationError(
                        target=ValidationTargetType.INPUT,
                        object_id=input.name,
                        validation_code=ValidationCode.MAXIMUM_COUNT_MUST_NOT_BE_ZERO
                    ))

        # Meaning the node is in the graph. Otherwise it shouldn't be in the validation step.
        if self.node_status != NodeStatus.CREATED:
            for input in self.inputs:
                if len(input.values) < input.min_count:
                    violations.append(
                        ValidationError(
                            target=ValidationTargetType.INPUT,
                            object_id=input.name,
                            validation_code=ValidationCode.MISSING_INPUT
                        ))

            if self.node_status == NodeStatus.MANDATORY_DEPRECATED:
                violations.append(
                    ValidationError(
                        target=ValidationTargetType.NODE,
                        object_id=str(self._id),
                        validation_code=ValidationCode.DEPRECATED_NODE
                    ))

        if len(violations) == 0:
            return None

        return ValidationError(
            target=ValidationTargetType.NODE,
            object_id=str(self._id),
            validation_code=ValidationCode.IN_DEPENDENTS,
            children=violations
        )

    def apply_properties(self, other_node):
        """Apply Properties and Inputs of another Node.

        Args:
            other_node (Node): A node to copy Properties and Inputs from
        """
        for other_input in other_node.inputs:
            for input in self.inputs:
                if other_input.name == input.name:
                    if (input.max_count < 0 or input.max_count >= other_input.max_count) and \
                            set(input.file_types) >= set(other_input.file_types):
                        input.values = other_input.values
                    break

        for other_parameter in other_node.parameters:
            for parameter in self.parameters:
                if other_parameter.name == parameter.name:
                    if parameter.parameter_type == other_parameter.parameter_type and parameter.widget:
                        parameter.value = other_parameter.value
                    break

        self.description = other_node.description
        self.x = other_node.x
        self.y = other_node.y

    def __str__(self):
        return 'Node(_id="{}")'.format(self._id)

    def __repr__(self):
        return 'Node({})'.format(str(self.to_dict()))

    def _get_custom_element(self, arr, name, throw):
        for parameter in arr:
            if parameter.name == name:
                return parameter
        if throw:
            raise Exception('Parameter "{}" not found in {}'.format(name, self.title))
        return None

    def get_input_by_name(self, name, throw=True):
        return self._get_custom_element(self.inputs, name, throw)

    def get_parameter_by_name(self, name, throw=True):
        return self._get_custom_element(self.parameters, name, throw)

    def get_output_by_name(self, name, throw=True):
        return self._get_custom_element(self.outputs, name, throw)

    def get_log_by_name(self, name, throw=True):
        return self._get_custom_element(self.logs, name, throw)
class User(DBObject):
    """Basic User class with db interface."""

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'username': DBObjectField(type=str, default='', is_list=False),
        'password_hash': DBObjectField(type=str, default='', is_list=False),
        'active': DBObjectField(type=bool, default=True, is_list=False),
    }

    DB_COLLECTION = Collections.USERS

    def hash_password(self, password):
        """Change password.

        Args:
            password (str): Real password string
        """
        self.password_hash = pwd_context.encrypt(password)

    def verify_password(self, password):
        """Verify password.

        Args:
            password (str): Real password string

        Return:
            (bool) True if password matches else False
        """
        return pwd_context.verify(password, self.password_hash)

    def generate_access_token(self, expiration=600):
        """Generate access token.

        Args:
            expiration (int): Time to Live (TTL) in sec

        Return:
            (str) Secured token
        """
        s = TimedSerializer(get_auth_config().secret_key, expires_in=expiration)
        return s.dumps({'username': self.username, 'type': 'access'})

    def generate_refresh_token(self):
        """Generate refresh token.

        Return:
            (str) Secured token
        """
        s = Serializer(get_auth_config().secret_key)
        return s.dumps({'username': self.username, 'type': 'refresh'})

    def __str__(self):
        return 'User(_id="{}", username={})'.format(self._id, self.username)

    def __repr__(self):
        return 'User({})'.format(self.to_dict())

    def __getattr__(self, name):
        raise Exception("Can't get attribute '{}'".format(name))

    @staticmethod
    def find_user_by_name(username):
        """Find User.

        Args:
            username (str): Username

        Return:
            (User) User object or None
        """
        user_dict = getattr(get_db_connector(), User.DB_COLLECTION).find_one({'username': username})
        if not user_dict:
            return None
        return User(user_dict)

    @staticmethod
    def find_users():
        return getattr(get_db_connector(), User.DB_COLLECTION).find({})

    @staticmethod
    def verify_auth_token(token):
        """Verify token.

        Args:
            token (str): Token

        Return:
            (User) User object or None
        """
        s = TimedSerializer(get_auth_config().secret_key)
        try:
            data = s.loads(token)
            if data['type'] != 'access':
                raise Exception('Not access token')
        except (BadSignature, SignatureExpired):
            # access token is not valid or expired
            s = Serializer(get_auth_config().secret_key)
            try:
                data = s.loads(token)
                if data['type'] != 'refresh':
                    raise Exception('Not refresh token')
            except Exception:
                return None
        except Exception as e:
            print("Unexpected exception: {}".format(e))
            return None

        user = User.find_user_by_name(data['username'])
        if not user or not user.active:
            # unknown username or deactivated account
            return None
        return user
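# Hedged end-to-end sketch of the auth helpers above (hypothetical functions; the
# credentials are placeholders and save() is the standard DBObject persistence).
def register_user(username, password):
    user = User({'username': username})
    user.hash_password(password)
    user.save()
    return user


def issue_tokens(username, password):
    # Returns (access_token, refresh_token) if the credentials are valid, else None.
    user = User.find_user_by_name(username)
    if user is None or not user.verify_password(password):
        return None
    return user.generate_access_token(expiration=600), user.generate_refresh_token()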
class Node(DBObject):
    """Basic Node with db interface."""

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        '_type': DBObjectField(type=str, default='Node', is_list=False),
        'title': DBObjectField(type=str, default='Title', is_list=False),
        'description': DBObjectField(type=str, default='Description', is_list=False),
        # Kind, such as plynx.plugins.executors.local.BashJinja2.
        # Derived from the plynx.plugins.executors.BaseExecutor class.
        'kind': DBObjectField(type=str, default='dummy', is_list=False),
        # ID of the previous version of the node; always refers to the `nodes` collection.
        'parent_node_id': DBObjectField(type=ObjectId, default=None, is_list=False),
        # ID of the next version of the node; always refers to the `nodes` collection.
        'successor_node_id': DBObjectField(type=ObjectId, default=None, is_list=False),
        # ID of the original node, used in `runs`; always refers to the `nodes` collection.
        # A Run refers to the original node.
        'original_node_id': DBObjectField(type=ObjectId, default=None, is_list=False),
        'inputs': DBObjectField(type=Input, default=list, is_list=True),
        'outputs': DBObjectField(type=Output, default=list, is_list=True),
        'parameters': DBObjectField(type=lambda object_dict: Parameter(object_dict), default=list, is_list=True),
        'logs': DBObjectField(type=Output, default=list, is_list=True),
        'node_running_status': DBObjectField(type=str, default=NodeRunningStatus.CREATED, is_list=False),
        'node_status': DBObjectField(type=str, default=NodeStatus.CREATED, is_list=False),
        'cache_url': DBObjectField(type=str, default='', is_list=False),
        'x': DBObjectField(type=int, default=0, is_list=False),
        'y': DBObjectField(type=int, default=0, is_list=False),
        'author': DBObjectField(type=ObjectId, default=None, is_list=False),
        'starred': DBObjectField(type=bool, default=False, is_list=False),
    }

    DB_COLLECTION = Collections.TEMPLATES

    def _DEFAULT_LOG(name):
        return Output.from_dict({
            'name': name,
            'file_type': FILE_KIND,
            'values': [],
            'is_array': False,
            'min_count': 1,
        })

    def apply_properties(self, other_node):
        """Apply Properties and Inputs of another Node.

        This method is used for updating nodes.

        Args:
            other_node (Node): A node to copy Properties and Inputs from
        """
        for other_input in other_node.inputs:
            for input in self.inputs:
                if other_input.name == input.name and \
                        other_input.file_type == input.file_type and (
                            input.is_array or (not input.is_array and 1 == len(other_input.input_references))
                        ):
                    input.input_references = other_input.input_references
                    break

        for other_parameter in other_node.parameters:
            for parameter in self.parameters:
                if other_parameter.name == parameter.name:
                    if parameter.parameter_type == other_parameter.parameter_type and parameter.widget:
                        parameter.value = other_parameter.value
                    break

        self.description = other_node.description
        self.x = other_node.x
        self.y = other_node.y

    def clone(self, node_clone_policy):
        return _clone_update_in_place(self.copy(), node_clone_policy)

    def __str__(self):
        return 'Node(_id="{}")'.format(self._id)

    def __repr__(self):
        return 'Node({})'.format(str(self.to_dict()))

    def _get_custom_element(self, arr, name, throw, default=None):
        for parameter in arr:
            if parameter.name == name:
                return parameter
        if throw:
            raise Exception('Parameter "{}" not found in {}'.format(name, self.title))
        if default:
            arr.append(default(name))
            return arr[-1]
        return None

    def get_input_by_name(self, name, throw=True):
        return self._get_custom_element(self.inputs, name, throw)

    def get_parameter_by_name(self, name, throw=True):
        return self._get_custom_element(self.parameters, name, throw)

    def get_output_by_name(self, name, throw=True):
        return self._get_custom_element(self.outputs, name, throw)

    def get_log_by_name(self, name, throw=False):
        return self._get_custom_element(self.logs, name, throw, default=Node._DEFAULT_LOG)

    def arrange_auto_layout(self, readonly=False):
        """Use a heuristic to rearrange nodes."""
        HEADER_HEIGHT = 23
        TITLE_HEIGHT = 20
        FOOTER_HEIGHT = 10
        BORDERS_HEIGHT = 2
        ITEM_HEIGHT = 20
        SPACE_HEIGHT = 50
        LEFT_PADDING = 30
        TOP_PADDING = 80
        LEVEL_WIDTH = 252
        SPECIAL_PARAMETER_HEIGHT = 20
        SPECIAL_PARAMETER_TYPES = [ParameterTypes.CODE]
        min_node_height = HEADER_HEIGHT + TITLE_HEIGHT + FOOTER_HEIGHT + BORDERS_HEIGHT

        node_id_to_level = defaultdict(lambda: -1)
        node_id_to_node = {}
        queued_node_ids = set()
        children_ids = defaultdict(set)

        sub_nodes = self.get_parameter_by_name('_nodes').value.value
        if len(sub_nodes) == 0:
            return

        node_ids = set([node._id for node in sub_nodes])
        non_zero_node_ids = set()
        for node in sub_nodes:
            node_id_to_node[node._id] = node
            for input in node.inputs:
                for input_reference in input.input_references:
                    parent_node_id = ObjectId(input_reference.node_id)
                    non_zero_node_ids.add(parent_node_id)
                    children_ids[parent_node_id].add(node._id)

        leaves = node_ids - non_zero_node_ids
        to_visit = deque()
        # Always put the Output Node at the end
        push_special = True if SpecialNodeId.OUTPUT in leaves and len(leaves) > 1 else False
        for leaf_id in leaves:
            node_id_to_level[leaf_id] = 1 if push_special and leaf_id != SpecialNodeId.OUTPUT else 0
            to_visit.append(leaf_id)

        while to_visit:
            node_id = to_visit.popleft()
            node = node_id_to_node[node_id]
            node_level = max(
                [node_id_to_level[node_id]] +
                [node_id_to_level[child_id] + 1 for child_id in children_ids[node_id]]
            )
            node_id_to_level[node_id] = node_level
            for input in node.inputs:
                for input_reference in input.input_references:
                    parent_node_id = ObjectId(input_reference.node_id)
                    parent_level = node_id_to_level[parent_node_id]
                    node_id_to_level[parent_node_id] = max(node_level + 1, parent_level)
                    if parent_node_id not in queued_node_ids:
                        to_visit.append(parent_node_id)
                        queued_node_ids.add(parent_node_id)

        max_level = max(node_id_to_level.values())
        level_to_node_ids = defaultdict(list)
        row_heights = defaultdict(lambda: 0)

        def get_index_helper(node, level):
            if level < 0:
                return 0
            parent_node_ids = set()
            for input in node.inputs:
                for input_reference in input.input_references:
                    parent_node_ids.add(ObjectId(input_reference.node_id))
            for index, node_id in enumerate(level_to_node_ids[level]):
                if node_id in parent_node_ids:
                    return index
            return -1

        def get_index(node, max_level, level):
            return tuple(
                [get_index_helper(node, lvl) for lvl in range(max_level, level, -1)]
            )

        for node_id, level in node_id_to_level.items():
            level_to_node_ids[level].append(node_id)

        # Push the Input Node up to the top level
        if SpecialNodeId.INPUT in node_id_to_level and \
                (node_id_to_level[SpecialNodeId.INPUT] != max_level or len(level_to_node_ids[max_level]) > 1):
            input_level = node_id_to_level[SpecialNodeId.INPUT]
            level_to_node_ids[input_level] = [
                node_id for node_id in level_to_node_ids[input_level] if node_id != SpecialNodeId.INPUT
            ]
            max_level += 1
            node_id_to_level[SpecialNodeId.INPUT] = max_level
            level_to_node_ids[max_level] = [SpecialNodeId.INPUT]

        for level in range(max_level, -1, -1):
            level_node_ids = level_to_node_ids[level]
            index_to_node_id = []
            for node_id in level_node_ids:
                node = node_id_to_node[node_id]
                index = get_index(node, max_level, level)
                index_to_node_id.append((index, node_id))
            index_to_node_id.sort()
            level_to_node_ids[level] = [node_id for _, node_id in index_to_node_id]
            for index, node_id in enumerate(level_to_node_ids[level]):
                node = node_id_to_node[node_id]
                special_parameters_count = sum(
                    1 if parameter.parameter_type in SPECIAL_PARAMETER_TYPES and parameter.widget else 0
                    for parameter in node.parameters
                )
                node_height = sum([
                    min_node_height,
                    ITEM_HEIGHT * max(len(node.inputs), len(node.outputs)),
                    special_parameters_count * SPECIAL_PARAMETER_HEIGHT,
                ])
                row_heights[index] = max(row_heights[index], node_height)

        # TODO compute grid in a separate function
        if readonly:
            return level_to_node_ids, node_id_to_node

        cum_heights = [0]
        for index in range(len(row_heights)):
            cum_heights.append(cum_heights[-1] + row_heights[index] + SPACE_HEIGHT)
        max_height = max(cum_heights)
        for level in range(max_level, -1, -1):
            level_node_ids = level_to_node_ids[level]
            level_height = cum_heights[len(level_node_ids)]
            level_padding = (max_height - level_height) // 2
            for index, node_id in enumerate(level_node_ids):
                node = node_id_to_node[node_id]
                node.x = LEFT_PADDING + (max_level - level) * LEVEL_WIDTH
                node.y = TOP_PADDING + level_padding + cum_heights[index]
# Tail of a clone/update helper (likely the _clone_update_in_place() referenced by
# Node.clone() above): reset sub-node logs, resolve parameters that reference a
# parent parameter, and clear cached resource ids before returning the cloned node.
            log.values = []
        for parameter in sub_node.parameters:
            if not parameter.reference:
                continue
            parameter.value = node.get_parameter_by_name(parameter.reference, throw=True).value
    for output_or_log in node.outputs + node.logs:
        output_or_log.resource_id = None
    return node


RESOURCE_FIELDS = {
    'name': DBObjectField(type=str, default='', is_list=False),
    'file_type': DBObjectField(type=str, default=FILE_KIND, is_list=False),
    'values': DBObjectField(type=lambda object_dict: object_dict, default=list, is_list=True),
    'is_array': DBObjectField(type=bool, default=False,
class Graph(DBObject):
    """Basic graph with db interface."""

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'title': DBObjectField(type=str, default='Title', is_list=False),
        'description': DBObjectField(type=str, default='Description', is_list=False),
        'graph_running_status': DBObjectField(type=str, default=GraphRunningStatus.CREATED, is_list=False),
        'author': DBObjectField(type=ObjectId, default=None, is_list=False),
        'nodes': DBObjectField(type=Node, default=list, is_list=True),
    }

    DB_COLLECTION = Collections.GRAPHS

    def cancel(self):
        """Cancel the graph."""
        self.graph_running_status = GraphRunningStatus.CANCELED
        for node in self.nodes:
            if node.node_running_status in [NodeRunningStatus.RUNNING, NodeRunningStatus.IN_QUEUE]:
                node.node_running_status = NodeRunningStatus.CANCELED
        self.save(force=True)

    def get_validation_error(self):
        """Validate Graph.

        Return:
            (ValidationError) Validation error if found; else None
        """
        violations = []
        if self.title == '':
            violations.append(
                ValidationError(
                    target=ValidationTargetType.PROPERTY,
                    object_id='title',
                    validation_code=ValidationCode.MISSING_PARAMETER))
        if self.description == '':
            violations.append(
                ValidationError(
                    target=ValidationTargetType.PROPERTY,
                    object_id='description',
                    validation_code=ValidationCode.MISSING_PARAMETER))
        if len(self.nodes) == 0:
            violations.append(
                ValidationError(
                    target=ValidationTargetType.PROPERTY,
                    object_id=str(self._id),
                    validation_code=ValidationCode.EMPTY_GRAPH))

        # Meaning the node is in the graph. Otherwise it shouldn't be in the validation step.
        for node in self.nodes:
            node_violation = node.get_validation_error()
            if node_violation:
                violations.append(node_violation)

        if len(violations) == 0:
            return None

        return ValidationError(
            target=ValidationTargetType.GRAPH,
            object_id=str(self._id),
            validation_code=ValidationCode.IN_DEPENDENTS,
            children=violations)

    def arrange_auto_layout(self, readonly=False):
        """Use a heuristic to rearrange nodes."""
        HEADER_HEIGHT = 23
        DESCRIPTION_HEIGHT = 20
        FOOTER_HEIGHT = 10
        BORDERS_HEIGHT = 2
        ITEM_HEIGHT = 20
        SPACE_HEIGHT = 50
        LEFT_PADDING = 30
        TOP_PADDING = 80
        LEVEL_WIDTH = 252
        SPECIAL_PARAMETER_HEIGHT = 20
        SPECIAL_PARAMETER_TYPES = [ParameterTypes.CODE]
        min_node_height = HEADER_HEIGHT + DESCRIPTION_HEIGHT + FOOTER_HEIGHT + BORDERS_HEIGHT

        node_id_to_level = defaultdict(lambda: -1)
        node_id_to_node = {}
        queued_node_ids = set()
        children_ids = defaultdict(set)

        node_ids = set([node._id for node in self.nodes])
        non_zero_node_ids = set()
        for node in self.nodes:
            node_id_to_node[node._id] = node
            for input in node.inputs:
                for value in input.values:
                    parent_node_id = to_object_id(value.node_id)
                    non_zero_node_ids.add(parent_node_id)
                    children_ids[parent_node_id].add(node._id)

        leaves = node_ids - non_zero_node_ids
        to_visit = deque()
        for leaf_id in leaves:
            node_id_to_level[leaf_id] = 0
            to_visit.append(leaf_id)

        while to_visit:
            node_id = to_visit.popleft()
            node = node_id_to_node[node_id]
            node_level = max([node_id_to_level[node_id]] + [
                node_id_to_level[child_id] + 1 for child_id in children_ids[node_id]
            ])
            node_id_to_level[node_id] = node_level
            for input in node.inputs:
                for value in input.values:
                    parent_node_id = to_object_id(value.node_id)
                    parent_level = node_id_to_level[parent_node_id]
                    node_id_to_level[parent_node_id] = max(node_level + 1, parent_level)
                    if parent_node_id not in queued_node_ids:
                        to_visit.append(parent_node_id)
                        queued_node_ids.add(parent_node_id)

        max_level = max(node_id_to_level.values())
        level_to_node_ids = defaultdict(list)
        row_heights = defaultdict(lambda: 0)

        def get_index_helper(node, level):
            if level < 0:
                return 0
            parent_node_ids = set()
            for input in node.inputs:
                for value in input.values:
                    parent_node_ids.add(to_object_id(value.node_id))
            for index, node_id in enumerate(level_to_node_ids[level]):
                if node_id in parent_node_ids:
                    return index
            return -1

        def get_index(node, max_level, level):
            return tuple([
                get_index_helper(node, lvl) for lvl in range(max_level, level, -1)
            ])

        for node_id, level in node_id_to_level.items():
            level_to_node_ids[level].append(node_id)

        for level in range(max_level, -1, -1):
            level_node_ids = level_to_node_ids[level]
            index_to_node_id = []
            for node_id in level_node_ids:
                node = node_id_to_node[node_id]
                index = get_index(node, max_level, level)
                index_to_node_id.append((index, node_id))
            index_to_node_id.sort()
            level_to_node_ids[level] = [node_id for _, node_id in index_to_node_id]
            for index, node_id in enumerate(level_to_node_ids[level]):
                node = node_id_to_node[node_id]
                special_parameters_count = sum(
                    1 if parameter.parameter_type in SPECIAL_PARAMETER_TYPES and parameter.widget else 0
                    for parameter in node.parameters)
                node_height = sum([
                    min_node_height,
                    ITEM_HEIGHT * max(len(node.inputs), len(node.outputs)),
                    special_parameters_count * SPECIAL_PARAMETER_HEIGHT,
                ])
                row_heights[index] = max(row_heights[index], node_height)

        # TODO compute grid in a separate function
        if readonly:
            return level_to_node_ids, node_id_to_node

        cum_heights = [0]
        for index in range(len(row_heights)):
            cum_heights.append(cum_heights[-1] + row_heights[index] + SPACE_HEIGHT)
        max_height = max(cum_heights)
        for level in range(max_level, -1, -1):
            level_node_ids = level_to_node_ids[level]
            level_height = cum_heights[len(level_node_ids)]
            level_padding = (max_height - level_height) // 2
            for index, node_id in enumerate(level_node_ids):
                node = node_id_to_node[node_id]
                node.x = LEFT_PADDING + (max_level - level) * LEVEL_WIDTH
                node.y = TOP_PADDING + level_padding + cum_heights[index]
        return None, None

    def generate_code(self):
        code_blocks = []
        unique_nodes = {node.parent_node: node for node in self.nodes}

        def name_iteration_handler(lst):
            return ', '.join(map(lambda element: "'{}'".format(element.name), lst))

        def generate_class_name(title):
            return ''.join(map(lambda s: s.title(), re.split('[^a-zA-Z]', title)))

        def generate_var_name(title):
            return '_'.join(re.split('[^a-zA-Z0-9]', title))

        def param_to_value(param):
            if param.parameter_type == ParameterTypes.INT:
                return param.value
            elif param.parameter_type == ParameterTypes.ENUM:
                return repr(param.value.values[int(param.value.index)])
            elif param.parameter_type == ParameterTypes.LIST_INT:
                return map(int, param.value)
            elif param.parameter_type == ParameterTypes.CODE:
                return repr(param.value.value)
            return repr(param.value)

        used_class_names = set()
        node_id_to_class_name = {}
        for node_id, node in unique_nodes.items():
            if node.base_node_name == 'file':
                class_type = 'File'
                content = '\n    '.join([
                    "",
                    "id='{}',".format(node_id),
                    "title='{}',".format(node.title),
                    "description='{}',".format(node.description),
                ])
                orig_class_name = generate_var_name(node.title)
            else:
                class_type = 'Operation'
                content = '\n    '.join([
                    "",
                    "id='{}',".format(node_id),
                    "title='{}',".format(node.title),
                    "inputs=[{}],".format(name_iteration_handler(node.inputs)),
                    "params=[{}],".format(name_iteration_handler(filter(lambda p: p.widget, node.parameters))),
                    "outputs=[{}],".format(name_iteration_handler(node.outputs)),
                ])
                orig_class_name = generate_class_name(node.title)

            class_name = orig_class_name
            while class_name in used_class_names:
                class_name = '{}_{}'.format(orig_class_name, str(uuid.uuid1())[:4])
            used_class_names.add(class_name)
            node_id_to_class_name[node_id] = class_name

            code = "{class_name} = {class_type}({content}\n)\n".format(
                class_name=class_name,
                class_type=class_type,
                content=content,
            )
            code_blocks.append(code)

        level_to_node_ids, node_id_to_node = self.arrange_auto_layout(readonly=True)
        max_level = max(level_to_node_ids.keys())
        node_id_to_var_name = {}
        for level in range(max_level, -1, -1):
            for row, node_id in enumerate(level_to_node_ids[level]):
                node = node_id_to_node[node_id]
                if node.base_node_name == 'file':
                    node_id_to_var_name[node_id] = node_id_to_class_name[node.parent_node]
                    continue
                var_name = '{}_{}_{}'.format(generate_var_name(node.title.lower()), max_level - level, row)

                # generate args
                args = []
                for input in node.inputs:
                    values = []
                    for value in input.values:
                        values.append('{}.outputs.{}'.format(
                            node_id_to_var_name[to_object_id(value.node_id)],
                            value.output_id))
                    if values:
                        args.append('    {}={},'.format(
                            generate_var_name(input.name),
                            values[0] if len(values) == 1 else '[{}]'.format(', '.join(values))))
                for param in node.parameters:
                    if param.widget:
                        args.append('    {}={},'.format(generate_var_name(param.name), param_to_value(param)))

                # generate var declaration
                node_id_to_var_name[node_id] = var_name
                content = '{var_name} = {class_name}(\n{args}\n)\n'.format(
                    var_name=var_name,
                    class_name=node_id_to_class_name[node.parent_node],
                    args='\n'.join(args))
                code_blocks.append(content)

        code_blocks.append(
            "graph = Graph(\n"
            "    Client(\n"
            "        token=TOKEN,\n"
            "        endpoint=ENDPOINT,\n"
            "    ),\n"
            "    title='{title}',\n"
            "    description='{description}',\n"
            "    targets=[{targets}]\n"
            ")\n\n"
            "graph.approve().wait()\n".format(
                title=self.title,
                description=self.description,
                targets=", ".join(map(lambda node_id: node_id_to_var_name[node_id], level_to_node_ids[0]))))

        return '\n'.join(code_blocks)

    def clone(self):
        graph = Graph.from_dict(self.to_dict())
        graph._id = ObjectId()
        graph.graph_running_status = GraphRunningStatus.CREATED
        for node in graph.nodes:
            if node.node_running_status != NodeRunningStatus.STATIC:
                node.node_running_status = NodeRunningStatus.CREATED
                for output in node.outputs:
                    output.resource_id = None
                for input in node.inputs:
                    for value in input.values:
                        value.resource_id = None
                for log in node.logs:
                    log.resource_id = None
        return graph

    def __str__(self):
        return 'Graph(_id="{}", nodes={})'.format(self._id, [str(b) for b in self.nodes])

    def __repr__(self):
        return 'Graph(_id="{}", title="{}", nodes={})'.format(self._id, self.title, str(self.nodes))
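# Illustrative only: re-running an existing graph by cloning it. `graph_dict` is
# assumed to be a document fetched from Collections.GRAPHS; clone() resets the _id,
# running statuses, and cached resource ids, after which the copy can be validated
# and saved as a fresh graph.
def rerun_graph(graph_dict):
    new_graph = Graph(graph_dict).clone()
    if new_graph.get_validation_error() is not None:
        return None
    new_graph.save(force=True)
    return new_graph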
class NodeCache(DBObject):
    """Basic Node Cache with db interface."""

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'key': DBObjectField(type=str, default='', is_list=False),
        'graph_id': DBObjectField(type=ObjectId, default=None, is_list=False),
        'node_id': DBObjectField(type=ObjectId, default=None, is_list=False),
        'outputs': DBObjectField(type=Output, default=list, is_list=True),
        'logs': DBObjectField(type=Output, default=list, is_list=True),
        # `protected` is used to prevent removing saved cache
        'protected': DBObjectField(type=bool, default=False, is_list=False),
        'removed': DBObjectField(type=bool, default=False, is_list=False),
    }

    DB_COLLECTION = Collections.NODE_CACHE

    IGNORED_PARAMETERS = {'cmd', '_timeout'}

    @staticmethod
    def instantiate(node, graph_id, user_id):
        """Instantiate a Node Cache from Node.

        Args:
            node (Node): Node object
            graph_id (ObjectId, str): Graph ID
            user_id (ObjectId, str): User ID

        Return:
            (NodeCache)
        """
        return NodeCache({
            'key': NodeCache.generate_key(node, user_id),
            'node_id': node._id,
            'graph_id': graph_id,
            'outputs': [output.to_dict() for output in node.outputs],
            'logs': [log.to_dict() for log in node.logs],
        })

    # TODO after Demo: remove user_id
    @staticmethod
    def generate_key(node, user_id):
        """Generate hash.

        Args:
            node (Node): Node object
            user_id (ObjectId, str): User ID

        Return:
            (str) Hash value
        """
        if not demo_config.enabled:
            user_id = ''  # TODO after demo

        inputs = node.inputs
        parameters = node.parameters
        original_node_id = node.original_node_id

        sorted_inputs = sorted(inputs, key=lambda x: x.name)
        inputs_hash = ','.join([
            '{}:{}'.format(
                input.name,
                ','.join(sorted(map(lambda x: x.resource_id, input.values))))
            for input in sorted_inputs
        ])

        sorted_parameters = sorted(parameters, key=lambda x: x.name)
        parameters_hash = ','.join([
            '{}:{}'.format(parameter.name, parameter.value)
            for parameter in sorted_parameters
            if parameter.name not in NodeCache.IGNORED_PARAMETERS
        ])

        return hashlib.sha256('{};{};{};{}'.format(
            original_node_id,
            inputs_hash,
            parameters_hash,
            str(user_id)).encode('utf-8')).hexdigest()

    def __str__(self):
        return 'NodeCache(_id="{}")'.format(self._id)

    def __repr__(self):
        return 'NodeCache({})'.format(str(self.to_dict()))