Beispiel #1
0
class InputValue(DBObject):
    """A single value stored inside an Input.

    Declares three string fields, each defaulting to the empty string:
    ``node_id``, ``output_id`` and ``resource_id``.
    """

    FIELDS = {
        'node_id': DBObjectField(type=str, default='', is_list=False),
        'output_id': DBObjectField(type=str, default='', is_list=False),
        'resource_id': DBObjectField(type=str, default='', is_list=False),
    }

    def __str__(self):
        # Compact form: only node_id and output_id are shown.
        node_id, output_id = self.node_id, self.output_id
        return 'InputValue({}, {})'.format(node_id, output_id)

    def __repr__(self):
        # Full form: every field, as produced by to_dict().
        as_text = str(self.to_dict())
        return 'InputValue({})'.format(as_text)
Beispiel #2
0
class Output(DBObject):
    """Basic Output structure.

    Declares the string fields ``name`` and ``file_type`` (default ``''``)
    and ``resource_id`` (default ``None``).
    """

    FIELDS = {
        'name': DBObjectField(type=str, default='', is_list=False),
        'file_type': DBObjectField(type=str, default='', is_list=False),
        # Unlike the other fields, resource_id defaults to None rather than ''.
        'resource_id': DBObjectField(type=str, default=None, is_list=False),
    }

    def __str__(self):
        # Compact form: only the name is shown.
        name = self.name
        return 'Output(name="{}")'.format(name)

    def __repr__(self):
        # Full form: every field, as produced by to_dict().
        as_text = str(self.to_dict())
        return 'Output({})'.format(as_text)
Beispiel #3
0
class MasterState(DBObject):
    """Master statistics snapshot.

    Holds an ``_id`` (freshly generated by calling ``ObjectId`` when missing)
    and a list of ``WorkerState`` objects under ``workers``.
    """

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'workers': DBObjectField(type=WorkerState, default=list, is_list=True),
    }

    # Collection the snapshots are stored in.
    DB_COLLECTION = Collections.MASTER_STATE
Beispiel #4
0
class ParameterEnum(DBObject):
    """Enum value: a list of string options plus an ``index`` field."""

    FIELDS = {
        'values': DBObjectField(type=str, default=list, is_list=True),
        # NOTE(review): `index` is declared with type=str but defaults to the
        # integer -1. Looks like it should be type=int -- confirm against how
        # DBObject applies `type` and how stored documents encode the index.
        'index': DBObjectField(type=str, default=-1, is_list=False),
    }

    def __repr__(self):
        # Full form: every field, as produced by to_dict().
        as_text = str(self.to_dict())
        return 'ParameterEnum({})'.format(as_text)
Beispiel #5
0
class ParameterCode(DBObject):
    """Code value: the source text (``value``) and its ``mode``."""

    FIELDS = {
        'value': DBObjectField(type=str, default='', is_list=False),
        'mode': DBObjectField(type=str, default='python', is_list=False),
    }

    # Unused
    MODES = {'python'}

    def __repr__(self):
        # Full form: every field, as produced by to_dict().
        as_text = str(self.to_dict())
        return 'ParameterCode({})'.format(as_text)
Beispiel #6
0
class WorkerState(DBObject):
    """Status of the worker.

    Fields: ``_id``, ``worker_id``, ``graph_id``, ``node``, ``host`` and
    ``num_finished_jobs``.
    """

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'worker_id': DBObjectField(type=str, default=None, is_list=False),
        'graph_id': DBObjectField(type=ObjectId, default=None, is_list=False),
        'node': DBObjectField(type=Node, default=None, is_list=False),
        'host': DBObjectField(type=str, default='', is_list=False),
        'num_finished_jobs': DBObjectField(type=int, default=0, is_list=False),
    }
Beispiel #7
0
class GraphCancellation(DBObject):
    """GraphCancellation represents Graph Cancellation event in the database.

    Fields: ``_id``, the ``graph_id`` the event refers to, and an
    ``acknowledged`` flag that starts out False.
    """

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'graph_id': DBObjectField(type=ObjectId, default=None, is_list=False),
        'acknowledged': DBObjectField(type=bool, default=False, is_list=False),
    }

    # Collection the cancellation events are stored in.
    DB_COLLECTION = Collections.GRAPHS_CANCELLATIONS
Beispiel #8
0
class Input(DBObject):
    """Basic Input structure.

    Fields: ``name``, the accepted ``file_types`` (list of strings), the
    attached ``values`` (list of InputValue), and the ``min_count`` /
    ``max_count`` bounds, which both default to 1.
    """

    FIELDS = {
        'name': DBObjectField(type=str, default='', is_list=False),
        'file_types': DBObjectField(type=str, default=list, is_list=True),
        'values': DBObjectField(type=InputValue, default=list, is_list=True),
        'min_count': DBObjectField(type=int, default=1, is_list=False),
        'max_count': DBObjectField(type=int, default=1, is_list=False),
    }

    def __str__(self):
        # Compact form: only the name is shown.
        name = self.name
        return 'Input(name="{}")'.format(name)

    def __repr__(self):
        # Full form: every field, as produced by to_dict().
        as_text = str(self.to_dict())
        return 'Input({})'.format(as_text)
Beispiel #9
0
class ParameterWidget(DBObject):
    """Basic ParameterWidget structure with a single ``alias`` string field."""

    FIELDS = {
        'alias': DBObjectField(type=str, default='', is_list=False),
    }

    def __str__(self):
        # Compact form: only the alias is shown.
        alias = self.alias
        return 'ParameterWidget(alias="{}")'.format(alias)

    def __repr__(self):
        # Full form: every field, as produced by to_dict().
        as_text = str(self.to_dict())
        return 'ParameterWidget({})'.format(as_text)
Beispiel #10
0
class Graph(DBObject):
    """Basic graph with db interface.

    Besides the persisted fields, the class offers cancellation, validation
    and a heuristic auto-layout of its nodes.
    """

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'title': DBObjectField(type=str, default='Title', is_list=False),
        'description': DBObjectField(
            type=str,
            default='Description',
            is_list=False,
        ),
        'graph_running_status': DBObjectField(
            type=str,
            default=GraphRunningStatus.CREATED,
            is_list=False,
        ),
        'author': DBObjectField(type=ObjectId, default=None, is_list=False),
        'public': DBObjectField(type=bool, default=False, is_list=False),
        'nodes': DBObjectField(type=Node, default=list, is_list=True),
    }

    DB_COLLECTION = Collections.GRAPHS

    def cancel(self):
        """Cancel the graph.

        Sets the graph status to CANCELED, switches every node that is
        RUNNING or IN_QUEUE to CANCELED, and saves with ``force=True``.
        """
        self.graph_running_status = GraphRunningStatus.CANCELED
        active_statuses = (
            NodeRunningStatus.RUNNING,
            NodeRunningStatus.IN_QUEUE,
        )
        for node in self.nodes:
            if node.node_running_status in active_statuses:
                node.node_running_status = NodeRunningStatus.CANCELED
        self.save(force=True)

    def get_validation_error(self):
        """Validate Graph.

        An empty ``title`` or ``description`` is reported as a missing
        parameter; after that every node contributes its own validation
        error, if it has one.

        Return:
            (ValidationError)   Validation error if found; else None
        """
        violations = []
        for property_name in ('title', 'description'):
            if getattr(self, property_name) == '':
                violations.append(
                    ValidationError(
                        target=ValidationTargetType.PROPERTY,
                        object_id=property_name,
                        validation_code=ValidationCode.MISSING_PARAMETER))

        # Collect the violations reported by the individual nodes.
        for node in self.nodes:
            node_violation = node.get_validation_error()
            if node_violation:
                violations.append(node_violation)

        if not violations:
            return None

        return ValidationError(target=ValidationTargetType.GRAPH,
                               object_id=str(self._id),
                               validation_code=ValidationCode.IN_DEPENDENTS,
                               children=violations)

    def arrange_auto_layout(self):
        """Use heuristic to rearrange nodes.

        Every node gets a level: nodes whose outputs are not consumed by any
        other node start at level 0, and the nodes feeding them get higher
        levels. Levels become columns (the highest level is the leftmost
        column), nodes inside a column are ordered by the position of their
        parents in the columns to the left, and each column is vertically
        centred against the tallest one. The result is written to ``node.x``
        and ``node.y``; nothing is saved.

        A graph without nodes, or one in which every node is consumed by
        another node (so no starting point exists), is left untouched.
        """
        if not self.nodes:
            # Nothing to lay out; max() below would fail on an empty sequence.
            return

        HEADER_HEIGHT = 23
        DESCRIPTION_HEIGHT = 20
        FOOTER_HEIGHT = 10
        BORDERS_HEIGHT = 2
        ITEM_HEIGHT = 20
        SPACE_HEIGHT = 50
        LEFT_PADDING = 30
        TOP_PADDING = 80
        LEVEL_WIDTH = 252
        SPECIAL_PARAMETER_HEIGHT = 20
        SPECIAL_PARAMETER_TYPES = [ParameterTypes.CODE]
        min_node_height = HEADER_HEIGHT + DESCRIPTION_HEIGHT + FOOTER_HEIGHT + BORDERS_HEIGHT

        node_id_to_level = defaultdict(lambda: -1)
        node_id_to_node = {}
        queued_node_ids = set()
        children_ids = defaultdict(set)

        node_ids = {node._id for node in self.nodes}
        # Ids referenced as the source of at least one input value.
        non_zero_node_ids = set()
        for node in self.nodes:
            node_id_to_node[node._id] = node
            for node_input in node.inputs:
                for value in node_input.values:
                    parent_node_id = to_object_id(value.node_id)
                    non_zero_node_ids.add(parent_node_id)
                    children_ids[parent_node_id].add(node._id)

        # Nodes nobody consumes are the starting points of the traversal.
        leaves = node_ids - non_zero_node_ids
        to_visit = deque()
        for leaf_id in leaves:
            node_id_to_level[leaf_id] = 0
            to_visit.append(leaf_id)

        # Breadth-first walk from the leaves towards their parents. A parent
        # is pushed to (at least) one level above each child that reaches it.
        while to_visit:
            node_id = to_visit.popleft()
            node = node_id_to_node[node_id]
            node_level = max([node_id_to_level[node_id]] + [
                node_id_to_level[child_id] + 1
                for child_id in children_ids[node_id]
            ])
            node_id_to_level[node_id] = node_level
            for node_input in node.inputs:
                for value in node_input.values:
                    parent_node_id = to_object_id(value.node_id)
                    parent_level = node_id_to_level[parent_node_id]
                    node_id_to_level[parent_node_id] = max(
                        node_level + 1, parent_level)
                    if parent_node_id not in queued_node_ids:
                        to_visit.append(parent_node_id)
                        queued_node_ids.add(parent_node_id)

        if not node_id_to_level:
            # No starting point was found (e.g. all nodes form a cycle), so
            # no level was assigned; max() below would raise ValueError.
            return

        max_level = max(node_id_to_level.values())
        level_to_node_ids = defaultdict(list)
        row_heights = defaultdict(lambda: 0)

        def get_index_helper(node, level):
            # Position, within `level`, of the first node that is a parent
            # of `node`; -1 when that level holds no parent of it.
            if level < 0:
                return 0
            parent_node_ids = set()
            for node_input in node.inputs:
                for value in node_input.values:
                    parent_node_ids.add(to_object_id(value.node_id))

            for index, node_id in enumerate(level_to_node_ids[level]):
                if node_id in parent_node_ids:
                    return index
            return -1

        def get_index(node, max_level, level):
            # Sort key: parent positions on every level above `level`,
            # starting from the top-most one.
            return tuple([
                get_index_helper(node, lvl)
                for lvl in range(max_level, level, -1)
            ])

        for node_id, level in node_id_to_level.items():
            level_to_node_ids[level].append(node_id)

        # Highest level first, so that parent positions are already final
        # when a lower level is being ordered.
        for level in range(max_level, -1, -1):
            level_node_ids = level_to_node_ids[level]
            index_to_node_id = []
            for node_id in level_node_ids:
                node = node_id_to_node[node_id]
                index = get_index(node, max_level, level)
                index_to_node_id.append((index, node_id))

            index_to_node_id.sort()
            level_to_node_ids[level] = [
                node_id for _, node_id in index_to_node_id
            ]

            for index, node_id in enumerate(level_to_node_ids[level]):
                node = node_id_to_node[node_id]
                special_parameters_count = sum(
                    1 for parameter in node.parameters
                    if parameter.parameter_type in SPECIAL_PARAMETER_TYPES
                    and parameter.widget)
                node_height = (
                    min_node_height
                    + ITEM_HEIGHT * max(len(node.inputs), len(node.outputs))
                    + special_parameters_count * SPECIAL_PARAMETER_HEIGHT
                )
                # A row is as tall as the tallest node placed in it.
                row_heights[index] = max(row_heights[index], node_height)

        # cum_heights[i] is the vertical offset of row i.
        cum_heights = [0]
        for index in range(len(row_heights)):
            cum_heights.append(cum_heights[-1] + row_heights[index] +
                               SPACE_HEIGHT)

        max_height = max(cum_heights)

        for level in range(max_level, -1, -1):
            level_node_ids = level_to_node_ids[level]
            level_height = cum_heights[len(level_node_ids)]
            # Centre shorter columns against the tallest one.
            level_padding = (max_height - level_height) // 2
            for index, node_id in enumerate(level_node_ids):
                node = node_id_to_node[node_id]
                node.x = LEFT_PADDING + (max_level - level) * LEVEL_WIDTH
                node.y = TOP_PADDING + level_padding + cum_heights[index]

    def __str__(self):
        return 'Graph(_id="{}", nodes={})'.format(self._id,
                                                  [str(b) for b in self.nodes])

    def __repr__(self):
        return 'Graph(_id="{}", title="{}", nodes={})'.format(
            self._id, self.title, str(self.nodes))
Beispiel #11
0
class Node(DBObject):
    """Basic Node with db interface.

    Carries the node's descriptive fields, its inputs / outputs / parameters /
    logs, its running and lifecycle statuses, and its ``x`` / ``y`` position.
    """

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'title': DBObjectField(type=str, default='Title', is_list=False),
        'description': DBObjectField(
            type=str,
            default='Description',
            is_list=False,
        ),
        'base_node_name': DBObjectField(
            type=str,
            default='bash_jinja2',
            is_list=False,
        ),
        'parent_node': DBObjectField(type=ObjectId, default=None, is_list=False),
        'successor_node': DBObjectField(
            type=ObjectId,
            default=None,
            is_list=False,
        ),
        'inputs': DBObjectField(type=Input, default=list, is_list=True),
        'outputs': DBObjectField(type=Output, default=list, is_list=True),
        'parameters': DBObjectField(type=Parameter, default=list, is_list=True),
        'logs': DBObjectField(type=Output, default=list, is_list=True),
        'node_running_status': DBObjectField(
            type=str,
            default=NodeRunningStatus.CREATED,
            is_list=False,
        ),
        'node_status': DBObjectField(
            type=str,
            default=NodeStatus.CREATED,
            is_list=False,
        ),
        'cache_url': DBObjectField(type=str, default='', is_list=False),
        'x': DBObjectField(type=int, default=0, is_list=False),
        'y': DBObjectField(type=int, default=0, is_list=False),
        'author': DBObjectField(type=ObjectId, default=None, is_list=False),
        'public': DBObjectField(type=bool, default=False, is_list=False),
    }

    DB_COLLECTION = 'nodes'

    def get_validation_error(self):
        """Validate Node.

        Checks, in order: a non-empty title; sane ``min_count`` /
        ``max_count`` on every input; and, unless the node status is CREATED,
        that every input has at least ``min_count`` values and that the node
        is not MANDATORY_DEPRECATED.

        Return:
            (ValidationError)   Validation error if found; else None
        """

        def input_violation(input_name, code):
            # All per-input violations share the same target.
            return ValidationError(target=ValidationTargetType.INPUT,
                                   object_id=input_name,
                                   validation_code=code)

        violations = []
        if self.title == '':
            violations.append(
                ValidationError(
                    target=ValidationTargetType.PROPERTY,
                    object_id='title',
                    validation_code=ValidationCode.MISSING_PARAMETER))

        for node_input in self.inputs:
            if node_input.min_count < 0:
                violations.append(input_violation(
                    node_input.name,
                    ValidationCode.MINIMUM_COUNT_MUST_NOT_BE_NEGATIVE))
            if node_input.min_count > node_input.max_count and node_input.max_count > 0:
                violations.append(input_violation(
                    node_input.name,
                    ValidationCode.MINIMUM_COUNT_MUST_BE_GREATER_THAN_MAXIMUM))
            if node_input.max_count == 0:
                violations.append(input_violation(
                    node_input.name,
                    ValidationCode.MAXIMUM_COUNT_MUST_NOT_BE_ZERO))

        # The checks below are skipped while the node status is still CREATED.
        if self.node_status != NodeStatus.CREATED:
            for node_input in self.inputs:
                if len(node_input.values) < node_input.min_count:
                    violations.append(input_violation(
                        node_input.name, ValidationCode.MISSING_INPUT))

            if self.node_status == NodeStatus.MANDATORY_DEPRECATED:
                violations.append(
                    ValidationError(
                        target=ValidationTargetType.NODE,
                        object_id=str(self._id),
                        validation_code=ValidationCode.DEPRECATED_NODE))

        if not violations:
            return None

        return ValidationError(target=ValidationTargetType.NODE,
                               object_id=str(self._id),
                               validation_code=ValidationCode.IN_DEPENDENTS,
                               children=violations)

    def apply_properties(self, other_node):
        """Apply Properties and Inputs of another Node.

        Inputs and parameters are matched by name (first match wins). Input
        values are taken over only when this node's input allows at least as
        many values and accepts every file type of the other input; parameter
        values only when the parameter types agree and this node's parameter
        has a widget. Description and position are always copied.

        Args:
            other_node  (Node):     A node to copy Properties and Inputs from
        """
        for other_input in other_node.inputs:
            for own_input in self.inputs:
                if other_input.name != own_input.name:
                    continue
                # A negative max_count is treated as "no upper limit".
                if (own_input.max_count < 0 or
                        own_input.max_count >= other_input.max_count) and set(
                            own_input.file_types) >= set(
                                other_input.file_types):
                    own_input.values = other_input.values
                break

        for other_parameter in other_node.parameters:
            for own_parameter in self.parameters:
                if other_parameter.name != own_parameter.name:
                    continue
                if own_parameter.parameter_type == other_parameter.parameter_type and own_parameter.widget:
                    own_parameter.value = other_parameter.value
                break

        self.description = other_node.description
        self.x, self.y = other_node.x, other_node.y

    def __str__(self):
        return 'Node(_id="{}")'.format(self._id)

    def __repr__(self):
        as_text = str(self.to_dict())
        return 'Node({})'.format(as_text)

    def _get_custom_element(self, arr, name):
        """Return the first element of ``arr`` whose ``name`` equals ``name``.

        Raises:
            Exception: if no element carries that name.
        """
        for element in arr:
            if element.name == name:
                return element
        message = 'Parameter "{}" not found in {}'.format(name, self.title)
        raise Exception(message)

    def get_input_by_name(self, name):
        """Look up an input by name; raises if it does not exist."""
        return self._get_custom_element(self.inputs, name)

    def get_parameter_by_name(self, name):
        """Look up a parameter by name; raises if it does not exist."""
        return self._get_custom_element(self.parameters, name)

    def get_output_by_name(self, name):
        """Look up an output by name; raises if it does not exist."""
        return self._get_custom_element(self.outputs, name)

    def get_log_by_name(self, name):
        """Look up a log by name; raises if it does not exist."""
        return self._get_custom_element(self.logs, name)
0
class NodeCache(DBObject):
    """Basic Node Cache with db interface.

    Stores the outputs and logs of a node under a key derived from the
    node's parent, inputs, parameters and the user id.
    """

    FIELDS = {
        '_id': DBObjectField(type=ObjectId, default=ObjectId, is_list=False),
        'key': DBObjectField(type=str, default='', is_list=False),
        'graph_id': DBObjectField(type=ObjectId, default=None, is_list=False),
        'node_id': DBObjectField(type=ObjectId, default=None, is_list=False),
        'outputs': DBObjectField(type=Output, default=list, is_list=True),
        'logs': DBObjectField(type=Output, default=list, is_list=True),
    }

    DB_COLLECTION = 'node_cache'

    # Parameters with these names do not take part in the cache key.
    IGNORED_PARAMETERS = {'cmd'}

    @staticmethod
    def instantiate(node, graph_id, user_id):
        """Instantiate a Node Cache from Node.

        Args:
            node        (Node):             Node object
            graph_id    (ObjectId, str):    Graph ID
            user_id     (ObjectId, str):    User ID

        Return:
            (NodeCache)
        """
        cache_dict = {
            'key': NodeCache.generate_key(node, user_id),
            'node_id': node._id,
            'graph_id': graph_id,
            'outputs': [output.to_dict() for output in node.outputs],
            'logs': [log.to_dict() for log in node.logs],
        }
        return NodeCache(cache_dict)

    # TODO after Demo: remove user_id
    @staticmethod
    def generate_key(node, user_id):
        """Generate hash.

        The key is the SHA-256 hex digest of
        ``"<parent_node>;<inputs>;<parameters>;<user_id>"`` where inputs and
        parameters are serialized in name order, so the order in which they
        are listed on the node does not change the key.

        Args:
            node        (Node):             Node object
            user_id     (ObjectId, str):    User ID

        Return:
            (str)   Hash value
        """
        # "<input name>:<sorted resource ids>" for every input, by name.
        input_parts = []
        for node_input in sorted(node.inputs, key=lambda item: item.name):
            resource_ids = sorted(value.resource_id for value in node_input.values)
            input_parts.append(
                '{}:{}'.format(node_input.name, ','.join(resource_ids)))
        inputs_hash = ','.join(input_parts)

        # "<parameter name>:<value>" for every non-ignored parameter, by name.
        parameter_parts = []
        for parameter in sorted(node.parameters, key=lambda item: item.name):
            if parameter.name in NodeCache.IGNORED_PARAMETERS:
                continue
            parameter_parts.append(
                '{}:{}'.format(parameter.name, parameter.value))
        parameters_hash = ','.join(parameter_parts)

        raw_key = '{};{};{};{}'.format(
            node.parent_node,
            inputs_hash,
            parameters_hash,
            str(user_id))
        return hashlib.sha256(raw_key.encode('utf-8')).hexdigest()

    def __str__(self):
        return 'NodeCache(_id="{}")'.format(self._id)

    def __repr__(self):
        as_text = str(self.to_dict())
        return 'NodeCache({})'.format(as_text)
Beispiel #13
0
class User(DBObject):
    """Basic User class with db interface.

    Stores the username and the password hash, and issues / verifies the
    signed access and refresh tokens used for authentication.
    """

    FIELDS = {
        '_id': DBObjectField(
            type=ObjectId,
            default=ObjectId,
            is_list=False,
            ),
        'username': DBObjectField(
            type=str,
            default='',
            is_list=False,
            ),
        'password_hash': DBObjectField(
            type=str,
            default='',
            is_list=False,
            ),
    }

    DB_COLLECTION = 'users'

    def hash_password(self, password):
        """Change password.

        Only the hash is kept, in ``password_hash``; nothing is saved here.

        Args:
            password    (str)   Real password string
        """
        # NOTE(review): pwd_context is presumably a passlib CryptContext; if
        # so, `encrypt` is a deprecated alias of `hash` -- confirm the pinned
        # passlib version before switching.
        self.password_hash = pwd_context.encrypt(password)

    def verify_password(self, password):
        """Verify password.

        Args:
            password    (str)   Real password string

        Return:
            (bool)    True if password matches else False
        """
        return pwd_context.verify(password, self.password_hash)

    def generate_access_token(self, expiration=600):
        """Generate access token.

        Args:
            expiration  (int)   Time to Live (TTL) in sec

        Return:
            (str)   Secured token
        """
        s = TimedSerializer(get_auth_config().secret_key, expires_in=expiration)
        return s.dumps({'username': self.username, 'type': 'access'})

    def generate_refresh_token(self):
        """Generate refresh token.

        Return:
            (str)   Secured token
        """
        s = Serializer(get_auth_config().secret_key)
        return s.dumps({'username': self.username, 'type': 'refresh'})

    def __str__(self):
        return 'User(_id="{}", username={})'.format(self._id, self.username)

    def __repr__(self):
        return 'User({})'.format(self.to_dict())

    def __getattr__(self, name):
        """Fail loudly on access to an attribute that does not exist.

        Raises:
            AttributeError: always. This is the exception type the attribute
                protocol expects, so ``hasattr``, ``getattr`` with a default,
                ``copy`` and ``pickle`` keep working. It is a subclass of
                ``Exception``, so existing ``except Exception`` handlers
                still catch it.
        """
        raise AttributeError("Can't get attribute '{}'".format(name))

    @staticmethod
    def find_user_by_name(username):
        """Find User.

        Args:
            username    (str)   Username

        Return:
            (User)   User object or None
        """
        user_dict = get_db_connector().users.find_one({'username': username})
        if not user_dict:
            return None

        return User(user_dict)

    @staticmethod
    def verify_auth_token(token):
        """Verify token.

        The token is first read as an access token. If its signature is
        invalid or expired it is read once more as a refresh token. A token
        that loads correctly but carries the wrong ``type`` is rejected.

        Args:
            token   (str)   Token

        Return:
            (User)   User object or None
        """
        s = TimedSerializer(get_auth_config().secret_key)
        try:
            data = s.loads(token)
            if data['type'] != 'access':
                raise Exception('Not access token')
        except (BadSignature, SignatureExpired):
            # access token is not valid or expired
            s = Serializer(get_auth_config().secret_key)
            try:
                data = s.loads(token)
                if data['type'] != 'refresh':
                    raise Exception('Not refresh token')
            except Exception:
                return None
        except Exception as e:
            print("Unexpected exception: {}".format(e))
            return None
        user = User.find_user_by_name(data['username'])
        return user
Beispiel #14
0
class Parameter(DBObject):
    """Basic Parameter structure.

    The type of ``value`` is not fixed: it depends on ``parameter_type`` and
    is normalized in ``__init__``.
    """

    FIELDS = {
        'name': DBObjectField(type=str, default='', is_list=False),
        'parameter_type': DBObjectField(
            type=str,
            default=ParameterTypes.STR,
            is_list=False,
        ),
        # TODO make type factory
        'value': DBObjectField(
            type=lambda x: x,  # Preserve type
            default='',
            is_list=False,
        ),
        'mutable_type': DBObjectField(type=bool, default=True, is_list=False),
        'removable': DBObjectField(type=bool, default=True, is_list=False),
        'publicable': DBObjectField(type=bool, default=True, is_list=False),
        'widget': DBObjectField(
            type=ParameterWidget,
            default=None,
            is_list=False,
        ),
    }

    def __init__(self, obj_dict=None):
        """Build the parameter and coerce ``value`` to match ``parameter_type``.

        A ``None`` value is replaced by the default for the parameter type;
        ENUM and CODE values are converted with the ``from_dict`` of
        ParameterEnum / ParameterCode. The result is then checked with an
        ``assert``, which Python skips when run with ``-O``.
        """
        super(Parameter, self).__init__(obj_dict)

        # `value` field is a special case: the type depends on `parameter_type`
        parameter_type = self.parameter_type
        if self.value is None:
            self.value = _get_default_by_type(parameter_type)
        elif parameter_type == ParameterTypes.ENUM:
            self.value = ParameterEnum.from_dict(self.value)
        elif parameter_type == ParameterTypes.CODE:
            self.value = ParameterCode.from_dict(self.value)

        assert _value_is_valid(self.value, self.parameter_type), \
            "Invalid parameter value type: {}: {}".format(self.name, self.value)

    def __str__(self):
        # Compact form: only the name is shown.
        name = self.name
        return 'Parameter(name="{}")'.format(name)

    def __repr__(self):
        # Full form: every field, as produced by to_dict().
        as_text = str(self.to_dict())
        return 'Parameter({})'.format(as_text)