コード例 #1
0
def create_node(
    node_create: NodeCreate,
    db_node_type: DeclarativeMeta,
    db: Session,
    exclude: "set | dict | None" = None,
) -> DeclarativeMeta:
    """
    Helper function when creating a new Node that sets the attributes inherited from Node.

    Args:
        node_create: The request model describing the new Node.
        db_node_type: The database model class to instantiate (for example, Event).
        db: The database session used to look up the related objects.
        exclude: Field names to leave out of ``node_create.dict()`` before the remaining
            fields are passed to ``db_node_type``. It is forwarded unchanged to
            ``dict(exclude=...)``, so a set such as ``{"alert_uuids"}`` is accepted
            (the previous ``dict`` annotation did not match how callers use it).

    Returns:
        The new database object. It is not added to the session or committed here;
        the caller is responsible for that.
    """

    db_node: Node = db_node_type(**node_create.dict(exclude=exclude))

    # The relationship fields arrive as plain values in the request, so each one that was
    # given is replaced with the matching database object(s).
    if node_create.directives:
        db_node.directives = crud.read_by_values(values=node_create.directives,
                                                 db_table=NodeDirective,
                                                 db=db)

    if node_create.tags:
        db_node.tags = crud.read_by_values(values=node_create.tags,
                                           db_table=NodeTag,
                                           db=db)

    if node_create.threat_actor:
        db_node.threat_actor = crud.read_by_value(
            value=node_create.threat_actor, db_table=NodeThreatActor, db=db)

    if node_create.threats:
        db_node.threats = crud.read_by_values(values=node_create.threats,
                                              db_table=NodeThreat,
                                              db=db)

    return db_node
コード例 #2
0
ファイル: event.py プロジェクト: hollyfoxx/ace2-gui
def create_event(
        event: EventCreate,
        request: Request,
        response: Response,
        db: Session = Depends(get_db),
):
    """
    Create a new event from the request data and save it to the database.

    The Node-level attributes are set by create_node (with "alert_uuids" excluded).
    The required status is always looked up. Each optional relationship is looked up
    only when the request supplied a truthy value for it. Finally, the response's
    Content-Location header is set to the URL of the new event.
    """
    event_node: Event = create_node(
        node_create=event,
        db_node_type=Event,
        db=db,
        exclude={"alert_uuids"},
    )

    # Status is required, so it is always resolved.
    event_node.status = crud.read_by_value(value=event.status, db_table=EventStatus, db=db)

    # The owner is resolved by username rather than through a generic value lookup.
    if event.owner:
        event_node.owner = crud.read_user_by_username(username=event.owner, db=db)

    # Optional relationships, resolved in this order and only when given:
    # (attribute name, lookup table, whether the field holds a list of values)
    optional_lookups = (
        ("prevention_tools", EventPreventionTool, True),
        ("remediations", EventRemediation, True),
        ("risk_level", EventRiskLevel, False),
        ("source", EventSource, False),
        ("type", EventType, False),
        ("vectors", EventVector, True),
    )
    for attr_name, lookup_table, is_list in optional_lookups:
        requested = getattr(event, attr_name)
        if not requested:
            continue
        if is_list:
            resolved = crud.read_by_values(values=requested, db_table=lookup_table, db=db)
        else:
            resolved = crud.read_by_value(value=requested, db_table=lookup_table, db=db)
        setattr(event_node, attr_name, resolved)

    db.add(event_node)
    crud.commit(db)

    response.headers["Content-Location"] = request.url_for("get_event", uuid=event_node.uuid)
コード例 #3
0
def update_analysis_module_type(
        uuid: UUID,
        analysis_module_type: AnalysisModuleTypeUpdate,
        request: Request,
        response: Response,
        db: Session = Depends(get_db),
):
    """
    Apply a partial update to an existing analysis module type.

    Only fields that were explicitly set in the request are applied. Plain fields are
    copied as-is. Relationship fields are resolved with crud.read_by_values against
    their lookup table. The response's Content-Location header points at the updated
    object.
    """
    module_type: AnalysisModuleType = crud.read(uuid=uuid, db_table=AnalysisModuleType, db=db)

    changes = analysis_module_type.dict(exclude_unset=True)

    # Fields in the order they are applied. A table means the request holds values that
    # must be resolved from the database; None means the value is copied directly.
    field_tables = (
        ("description", None),
        ("extended_version", None),
        ("manual", None),
        ("value", None),
        ("observable_types", ObservableType),
        ("required_directives", NodeDirective),
        ("required_tags", NodeTag),
        ("version", None),
    )
    for field_name, lookup_table in field_tables:
        if field_name not in changes:
            continue
        new_value = changes[field_name]
        if lookup_table is not None:
            new_value = crud.read_by_values(values=new_value, db_table=lookup_table, db=db)
        setattr(module_type, field_name, new_value)

    crud.commit(db)

    response.headers["Content-Location"] = request.url_for("get_analysis_module_type", uuid=uuid)
コード例 #4
0
def update_node_threat(
        uuid: UUID,
        node_threat: NodeThreatUpdate,
        request: Request,
        response: Response,
        db: Session = Depends(get_db),
):
    """
    Apply a partial update to an existing node threat.

    Only fields explicitly set in the request are changed. The threat types are given
    as values and resolved against NodeThreatType. The response's Content-Location
    header points at the updated threat.
    """
    threat_row: NodeThreat = crud.read(uuid=uuid, db_table=NodeThreat, db=db)

    provided = node_threat.dict(exclude_unset=True)

    # Plain columns are copied straight from the request.
    for plain_field in ("description", "value"):
        if plain_field in provided:
            setattr(threat_row, plain_field, provided[plain_field])

    # The types relationship needs its values looked up first.
    if "types" in provided:
        threat_row.types = crud.read_by_values(
            values=provided["types"], db_table=NodeThreatType, db=db)

    crud.commit(db)

    response.headers["Content-Location"] = request.url_for("get_node_threat", uuid=uuid)
コード例 #5
0
def create_user(
        user: UserCreate,
        request: Request,
        response: Response,
        db: Session = Depends(get_db),
):
    """
    Create a new user from the request data and save it to the database.

    The user's default alert queue and roles are given as values and resolved from
    the database. The password is replaced with the result of hash_password before
    saving. The response's Content-Location header points at the new user.
    """
    user_row = User(**user.dict())

    # Resolve the related objects that the request refers to by value.
    alert_queue = crud.read_by_value(user.default_alert_queue, db_table=AlertQueue, db=db)
    user_row.default_alert_queue = alert_queue

    role_rows = crud.read_by_values(user.roles, db_table=UserRole, db=db)
    user_row.roles = role_rows

    # Never store the password as given. hash_password is defined elsewhere; per the
    # original note it uses Bcrypt_256 to avoid bcrypt's silent truncation of passwords
    # longer than 72 characters and its NULL-byte handling (not verifiable from here).
    user_row.password = hash_password(user_row.password)

    db.add(user_row)
    crud.commit(db)

    response.headers["Content-Location"] = request.url_for("get_user", uuid=user_row.uuid)
コード例 #6
0
def create_analysis_module_type(
        analysis_module_type: AnalysisModuleTypeCreate,
        request: Request,
        response: Response,
        db: Session = Depends(get_db),
):
    """
    Create a new analysis module type from the request data and save it to the database.

    The observable types, required directives, and required tags are always assigned.
    Each is resolved from the database when the request gave values, and is set to an
    empty list otherwise. The response's Content-Location header points at the new
    object.
    """
    module_type = AnalysisModuleType(**analysis_module_type.dict())

    def _resolve(values, table):
        # Look up the given values in the table, or return an empty list when none were given.
        if not values:
            return []
        return crud.read_by_values(values=values, db_table=table, db=db)

    module_type.observable_types = _resolve(analysis_module_type.observable_types, ObservableType)
    module_type.required_directives = _resolve(analysis_module_type.required_directives, NodeDirective)
    module_type.required_tags = _resolve(analysis_module_type.required_tags, NodeTag)

    db.add(module_type)
    crud.commit(db)

    response.headers["Content-Location"] = request.url_for(
        "get_analysis_module_type", uuid=module_type.uuid)
コード例 #7
0
def update_node(node_update: NodeUpdate, uuid: UUID, db_table: DeclarativeMeta,
                db: Session) -> DeclarativeMeta:
    """
    Helper function when updating a Node that enforces version matching and updates the
    attributes inherited from Node.

    Args:
        node_update: The request model; only the fields explicitly set in it are applied.
        uuid: The UUID of the Node to update.
        db_table: The database model class the Node is stored in (for example, Event).
        db: The database session.

    Returns:
        The updated database object with a freshly generated version. Nothing is
        committed here; the caller is responsible for that.

    Raises:
        HTTPException: 409 Conflict when the request did not include a version, or when
            the version does not match the Node's current version.
    """

    # Fetch the Node from the database
    db_node: Node = crud.read(uuid=uuid, db_table=db_table, db=db)

    # Only the fields the client actually sent are included (exclude_unset=True).
    update_data = node_update.dict(exclude_unset=True)

    # Optimistic-concurrency check: the client must send back the version it last read.
    # A missing version is treated as a conflict rather than an unhandled KeyError (500).
    if "version" not in update_data or update_data["version"] != db_node.version:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Unable to update Node due to version mismatch")

    if "directives" in update_data:
        db_node.directives = crud.read_by_values(
            values=update_data["directives"], db_table=NodeDirective, db=db)

    if "tags" in update_data:
        db_node.tags = crud.read_by_values(values=update_data["tags"],
                                           db_table=NodeTag,
                                           db=db)

    if "threat_actor" in update_data:
        db_node.threat_actor = crud.read_by_value(
            value=update_data["threat_actor"], db_table=NodeThreatActor, db=db)

    if "threats" in update_data:
        db_node.threats = crud.read_by_values(values=update_data["threats"],
                                              db_table=NodeThreat,
                                              db=db)

    # Give the Node a new version so later updates based on the old version are rejected.
    db_node.version = uuid4()

    return db_node
コード例 #8
0
def update_user(
        uuid: UUID,
        user: UserUpdate,
        request: Request,
        response: Response,
        db: Session = Depends(get_db),
):
    """
    Apply a partial update to an existing user.

    Only fields explicitly set in the request are changed. The default alert queue and
    roles are resolved from the database. The password is passed through
    hash_password. All other fields are copied as-is. The response's Content-Location
    header points at the updated user.
    """
    user_row: User = crud.read(uuid=uuid, db_table=User, db=db)

    requested = user.dict(exclude_unset=True)

    # Fields in the order they are applied, each with an optional converter that turns
    # the request value into what is stored on the user (None means copy directly).
    converters = (
        ("default_alert_queue",
         lambda raw: crud.read_by_value(value=raw, db_table=AlertQueue, db=db)),
        ("display_name", None),
        ("email", None),
        ("enabled", None),
        ("password", lambda raw: hash_password(raw)),
        ("roles",
         lambda raw: crud.read_by_values(values=raw, db_table=UserRole, db=db)),
        ("timezone", None),
        ("username", None),
    )
    for field_name, convert in converters:
        if field_name in requested:
            raw_value = requested[field_name]
            setattr(user_row, field_name, raw_value if convert is None else convert(raw_value))

    crud.commit(db)

    response.headers["Content-Location"] = request.url_for("get_user", uuid=uuid)
コード例 #9
0
def create_node_threat(
        node_threat: NodeThreatCreate,
        request: Request,
        response: Response,
        db: Session = Depends(get_db),
):
    """
    Create a new node threat from the request data and save it to the database.

    The requested threat types are looked up before the threat is built. The intent of
    doing this first is to check that every given type exists. The response's
    Content-Location header points at the new threat.
    """
    threat_types = crud.read_by_values(
        values=node_threat.types, db_table=NodeThreatType, db=db)

    threat_row = NodeThreat(**node_threat.dict())
    threat_row.types = threat_types

    db.add(threat_row)
    crud.commit(db)

    response.headers["Content-Location"] = request.url_for(
        "get_node_threat", uuid=threat_row.uuid)
コード例 #10
0
ファイル: seed.py プロジェクト: hollyfoxx/ace2-gui
        for value in data["user_role"]:
            db.add(UserRole(value=value))
            print(f"Adding user role: {value}")
    else:
        # Make sure there is always an "admin" role
        db.add(UserRole(value="admin"))
        print("Adding user role: admin")

# Seed an "analyst" user, but only when the user table is still empty.
if not crud.read_all(db_table=User, db=db):
    # Commit everything added so far so the lookups below can find it.
    crud.commit(db)

    analyst_queue = crud.read_by_value(value="default", db_table=AlertQueue, db=db)
    analyst_roles = crud.read_by_values(values=["admin"], db_table=UserRole, db=db)

    # NOTE(review): this password is stored exactly as written here, while the user API
    # endpoints pass passwords through hash_password before saving. Confirm whether the
    # seeded value should be hashed as well.
    analyst = User(
        default_alert_queue=analyst_queue,
        display_name="Analyst",
        email="analyst@localhost",
        password="******",
        roles=analyst_roles,
        username="******",
    )
    db.add(analyst)
    print("Adding user: analyst")

# Persist all of the changes made above.
crud.commit(db)
コード例 #11
0
ファイル: event.py プロジェクト: hollyfoxx/ace2-gui
def update_event(
        uuid: UUID,
        event: EventUpdate,
        request: Request,
        response: Response,
        db: Session = Depends(get_db),
):
    """
    Apply a partial update to an existing event.

    The Node-level attributes and version check are handled by update_node. Then every
    event field explicitly set in the request is applied. Relationship fields are
    resolved from the database; timestamps and the name are copied as-is. The
    response's Content-Location header points at the updated event.
    """
    event_row: Event = update_node(node_update=event, uuid=uuid, db_table=Event, db=db)

    requested = event.dict(exclude_unset=True)

    def _one(table):
        # Resolve a single value against the given lookup table.
        return lambda raw: crud.read_by_value(value=raw, db_table=table, db=db)

    def _many(table):
        # Resolve a list of values against the given lookup table.
        return lambda raw: crud.read_by_values(values=raw, db_table=table, db=db)

    # Fields in the order they are applied; None means the value is copied directly.
    resolvers = (
        ("alert_time", None),
        ("contain_time", None),
        ("disposition_time", None),
        ("event_time", None),
        ("name", None),
        ("owner", lambda raw: crud.read_user_by_username(username=raw, db=db)),
        ("ownership_time", None),
        ("prevention_tools", _many(EventPreventionTool)),
        ("remediation_time", None),
        ("remediations", _many(EventRemediation)),
        ("risk_level", _one(EventRiskLevel)),
        ("source", _one(EventSource)),
        ("status", _one(EventStatus)),
        ("type", _one(EventType)),
        ("vectors", _many(EventVector)),
    )
    for field_name, resolve in resolvers:
        if field_name not in requested:
            continue
        incoming = requested[field_name]
        setattr(event_row, field_name, incoming if resolve is None else resolve(incoming))

    crud.commit(db)

    response.headers["Content-Location"] = request.url_for("get_event", uuid=uuid)