def create_node(
    node_create: NodeCreate,
    db_node_type: DeclarativeMeta,
    db: Session,
    exclude: "set[str] | dict | None" = None,
) -> DeclarativeMeta:
    """
    Helper function when creating a new Node that sets the attributes inherited from Node.

    Args:
        node_create: The request schema for the new Node. Its fields (minus any listed in
            ``exclude``) are passed as keyword arguments to ``db_node_type``.
        db_node_type: The database model class to instantiate (for example ``Event``).
        db: The database session used to look up the directives, tags, threat actor and
            threats named in the request.
        exclude: Fields to leave out of ``node_create.dict()`` when building the model. It is
            forwarded unchanged to Pydantic's ``dict(exclude=...)``, which accepts either a set
            of field names or a mapping. Callers such as ``create_event`` pass a set, so the
            previous ``dict`` annotation was inaccurate.

    Returns:
        The new database object. It is NOT added to the session here; the caller adds and
        commits it.
    """
    db_node: Node = db_node_type(**node_create.dict(exclude=exclude))

    # The request carries plain values for these relationships; replace them with the
    # matching database objects. A falsy value leaves whatever the constructor assigned.
    if node_create.directives:
        db_node.directives = crud.read_by_values(values=node_create.directives, db_table=NodeDirective, db=db)

    if node_create.tags:
        db_node.tags = crud.read_by_values(values=node_create.tags, db_table=NodeTag, db=db)

    if node_create.threat_actor:
        db_node.threat_actor = crud.read_by_value(value=node_create.threat_actor, db_table=NodeThreatActor, db=db)

    if node_create.threats:
        db_node.threats = crud.read_by_values(values=node_create.threats, db_table=NodeThreat, db=db)

    return db_node
def create_event(
    event: EventCreate,
    request: Request,
    response: Response,
    db: Session = Depends(get_db),
):
    # Build the Event through the shared Node helper, which also resolves the Node-level
    # relationships. "alert_uuids" is kept out of the Event constructor.
    db_event: Event = create_node(node_create=event, db_node_type=Event, db=db, exclude={"alert_uuids"})

    # The status is always looked up; every other property below is optional.
    db_event.status = crud.read_by_value(value=event.status, db_table=EventStatus, db=db)

    # The owner is looked up by username rather than through the generic value lookup.
    if event.owner:
        db_event.owner = crud.read_user_by_username(username=event.owner, db=db)

    # Remaining optional properties, in the order they are resolved:
    # (attribute name, lookup table, whether the request holds a list of values).
    # A falsy value in the request leaves the attribute untouched.
    optional_lookups = (
        ("prevention_tools", EventPreventionTool, True),
        ("remediations", EventRemediation, True),
        ("risk_level", EventRiskLevel, False),
        ("source", EventSource, False),
        ("type", EventType, False),
        ("vectors", EventVector, True),
    )
    for field_name, lookup_table, is_collection in optional_lookups:
        requested = getattr(event, field_name)
        if not requested:
            continue
        if is_collection:
            resolved = crud.read_by_values(values=requested, db_table=lookup_table, db=db)
        else:
            resolved = crud.read_by_value(value=requested, db_table=lookup_table, db=db)
        setattr(db_event, field_name, resolved)

    # Persist the event and point the client at its canonical URL.
    db.add(db_event)
    crud.commit(db)

    response.headers["Content-Location"] = request.url_for("get_event", uuid=db_event.uuid)
def update_analysis_module_type(
    uuid: UUID,
    analysis_module_type: AnalysisModuleTypeUpdate,
    request: Request,
    response: Response,
    db: Session = Depends(get_db),
):
    # Fetch the stored analysis module type that is being modified.
    module_type: AnalysisModuleType = crud.read(uuid=uuid, db_table=AnalysisModuleType, db=db)

    # Only the fields the client explicitly sent are applied.
    changes = analysis_module_type.dict(exclude_unset=True)

    # Fields are applied in this order. When a lookup table is given, the request holds a list
    # of values that is swapped for the matching database objects before assignment.
    field_tables = (
        ("description", None),
        ("extended_version", None),
        ("manual", None),
        ("value", None),
        ("observable_types", ObservableType),
        ("required_directives", NodeDirective),
        ("required_tags", NodeTag),
        ("version", None),
    )
    for field_name, lookup_table in field_tables:
        if field_name not in changes:
            continue
        new_value = changes[field_name]
        if lookup_table is not None:
            new_value = crud.read_by_values(values=new_value, db_table=lookup_table, db=db)
        setattr(module_type, field_name, new_value)

    crud.commit(db)

    response.headers["Content-Location"] = request.url_for("get_analysis_module_type", uuid=uuid)
def update_node_threat(
    uuid: UUID,
    node_threat: NodeThreatUpdate,
    request: Request,
    response: Response,
    db: Session = Depends(get_db),
):
    # Fetch the stored node threat that is being modified.
    threat: NodeThreat = crud.read(uuid=uuid, db_table=NodeThreat, db=db)

    # Only the fields the client explicitly sent are applied.
    changes = node_threat.dict(exclude_unset=True)

    # Plain columns are copied across unchanged.
    for field_name in ("description", "value"):
        if field_name in changes:
            setattr(threat, field_name, changes[field_name])

    # Threat types arrive as values and are replaced with their database objects.
    if "types" in changes:
        threat.types = crud.read_by_values(values=changes["types"], db_table=NodeThreatType, db=db)

    crud.commit(db)

    response.headers["Content-Location"] = request.url_for("get_node_threat", uuid=uuid)
def create_user(
    user: UserCreate,
    request: Request,
    response: Response,
    db: Session = Depends(get_db),
):
    # Start from the raw request fields; the relationship and password attributes are
    # overwritten below before anything is saved.
    db_user = User(**user.dict())

    # Replace the plain values from the request with their database objects.
    db_user.default_alert_queue = crud.read_by_value(value=user.default_alert_queue, db_table=AlertQueue, db=db)
    db_user.roles = crud.read_by_values(values=user.roles, db_table=UserRole, db=db)

    # Store only the hashed/salted password. Bcrypt_256 is used (via hash_password) to get
    # around Bcrypt silently truncating passwords longer than 72 characters and not
    # handling NULL bytes.
    db_user.password = hash_password(db_user.password)

    # Persist the user and point the client at its canonical URL.
    db.add(db_user)
    crud.commit(db)

    response.headers["Content-Location"] = request.url_for("get_user", uuid=db_user.uuid)
def create_analysis_module_type(
    analysis_module_type: AnalysisModuleTypeCreate,
    request: Request,
    response: Response,
    db: Session = Depends(get_db),
):
    # Build the new analysis module type from the request fields.
    module_type = AnalysisModuleType(**analysis_module_type.dict())

    def _resolve(values, lookup_table):
        # Look up the given values; an empty or missing list becomes an empty list without
        # querying the database.
        if not values:
            return []
        return crud.read_by_values(values=values, db_table=lookup_table, db=db)

    # Swap each list of values for the matching database objects.
    module_type.observable_types = _resolve(analysis_module_type.observable_types, ObservableType)
    module_type.required_directives = _resolve(analysis_module_type.required_directives, NodeDirective)
    module_type.required_tags = _resolve(analysis_module_type.required_tags, NodeTag)

    # Persist it and point the client at its canonical URL.
    db.add(module_type)
    crud.commit(db)

    response.headers["Content-Location"] = request.url_for("get_analysis_module_type", uuid=module_type.uuid)
def update_node(node_update: NodeUpdate, uuid: UUID, db_table: DeclarativeMeta, db: Session) -> DeclarativeMeta:
    """
    Apply the Node-level fields of an update request to an existing database object.

    The request's ``version`` must equal the stored version. After the update, the stored
    version is replaced with a new random UUID (``uuid4``). The caller is responsible for
    committing.

    Args:
        node_update: The update schema; only fields explicitly set by the client are applied.
        uuid: Identifier of the Node to update.
        db_table: The database model class the Node is stored in (for example ``Event``).
        db: The database session.

    Returns:
        The modified database object.

    Raises:
        HTTPException: 409 Conflict when the given version does not match the stored version.
    """
    node: Node = crud.read(uuid=uuid, db_table=db_table, db=db)
    changes = node_update.dict(exclude_unset=True)

    # NOTE(review): assumes "version" is always present (a required field on NodeUpdate);
    # if it were optional and omitted, this lookup would raise KeyError.
    if changes["version"] != node.version:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Unable to update Node due to version mismatch",
        )

    # Relationship fields in the order they are resolved:
    # (attribute name, lookup table, whether the request holds a list of values).
    relationship_lookups = (
        ("directives", NodeDirective, True),
        ("tags", NodeTag, True),
        ("threat_actor", NodeThreatActor, False),
        ("threats", NodeThreat, True),
    )
    for field_name, lookup_table, is_collection in relationship_lookups:
        if field_name not in changes:
            continue
        if is_collection:
            resolved = crud.read_by_values(values=changes[field_name], db_table=lookup_table, db=db)
        else:
            resolved = crud.read_by_value(value=changes[field_name], db_table=lookup_table, db=db)
        setattr(node, field_name, resolved)

    # Give the Node a fresh version so later updates must present this new value.
    node.version = uuid4()

    return node
def update_user(
    uuid: UUID,
    user: UserUpdate,
    request: Request,
    response: Response,
    db: Session = Depends(get_db),
):
    # Fetch the stored user that is being modified.
    db_user: User = crud.read(uuid=uuid, db_table=User, db=db)

    # Only the fields the client explicitly sent are applied.
    changes = user.dict(exclude_unset=True)

    def _alert_queue(value):
        return crud.read_by_value(value=value, db_table=AlertQueue, db=db)

    def _roles(values):
        return crud.read_by_values(values=values, db_table=UserRole, db=db)

    # Fields are applied in this order. A converter, when present, turns the request value
    # into what gets stored (a database object, or the hashed password).
    field_converters = (
        ("default_alert_queue", _alert_queue),
        ("display_name", None),
        ("email", None),
        ("enabled", None),
        ("password", hash_password),
        ("roles", _roles),
        ("timezone", None),
        ("username", None),
    )
    for field_name, convert in field_converters:
        if field_name in changes:
            raw_value = changes[field_name]
            setattr(db_user, field_name, raw_value if convert is None else convert(raw_value))

    crud.commit(db)

    response.headers["Content-Location"] = request.url_for("get_user", uuid=uuid)
def create_node_threat(
    node_threat: NodeThreatCreate,
    request: Request,
    response: Response,
    db: Session = Depends(get_db),
):
    # Look up the requested threat types before building anything, so every given type is
    # confirmed to exist first.
    threat_types = crud.read_by_values(values=node_threat.types, db_table=NodeThreatType, db=db)

    # Build the threat from the request and attach the resolved type objects.
    threat = NodeThreat(**node_threat.dict())
    threat.types = threat_types

    # Persist it and point the client at its canonical URL.
    db.add(threat)
    crud.commit(db)

    response.headers["Content-Location"] = request.url_for("get_node_threat", uuid=threat.uuid)
for value in data["user_role"]: db.add(UserRole(value=value)) print(f"Adding user role: {value}") else: # Make sure there is always an "admin" role db.add(UserRole(value="admin")) print("Adding user role: admin") # Add an "analyst" user if there are no existing users if not crud.read_all(db_table=User, db=db): # Commit the database changes so that they can be used to create the analyst user crud.commit(db) db.add( User( default_alert_queue=crud.read_by_value(value="default", db_table=AlertQueue, db=db), display_name="Analyst", email="analyst@localhost", password="******", roles=crud.read_by_values(values=["admin"], db_table=UserRole, db=db), username="******", )) print("Adding user: analyst") # Commit all of the changes crud.commit(db)
def update_event(
    uuid: UUID,
    event: EventUpdate,
    request: Request,
    response: Response,
    db: Session = Depends(get_db),
):
    # The shared Node helper checks the version and applies the Node-level fields first.
    db_event: Event = update_node(node_update=event, uuid=uuid, db_table=Event, db=db)

    # Only the fields the client explicitly sent are applied.
    changes = event.dict(exclude_unset=True)

    def _by_value(lookup_table):
        # Converter resolving a single request value to its database object.
        return lambda raw: crud.read_by_value(value=raw, db_table=lookup_table, db=db)

    def _by_values(lookup_table):
        # Converter resolving a list of request values to their database objects.
        return lambda raw: crud.read_by_values(values=raw, db_table=lookup_table, db=db)

    def _owner(username):
        return crud.read_user_by_username(username=username, db=db)

    # Fields are applied in this order; a converter of None means the value is stored as-is.
    field_converters = (
        ("alert_time", None),
        ("contain_time", None),
        ("disposition_time", None),
        ("event_time", None),
        ("name", None),
        ("owner", _owner),
        ("ownership_time", None),
        ("prevention_tools", _by_values(EventPreventionTool)),
        ("remediation_time", None),
        ("remediations", _by_values(EventRemediation)),
        ("risk_level", _by_value(EventRiskLevel)),
        ("source", _by_value(EventSource)),
        ("status", _by_value(EventStatus)),
        ("type", _by_value(EventType)),
        ("vectors", _by_values(EventVector)),
    )
    for field_name, convert in field_converters:
        if field_name not in changes:
            continue
        raw_value = changes[field_name]
        setattr(db_event, field_name, raw_value if convert is None else convert(raw_value))

    crud.commit(db)

    response.headers["Content-Location"] = request.url_for("get_event", uuid=uuid)