def __init__(self, doc_class, cb):
    """
    Set up a new Query.

    Args:
        doc_class (class): The model class that will be returned by this query.
        cb (BaseAPI): Reference to API object used to communicate with the server.
    """
    # The parent initializer runs first; every attribute below is assigned afterwards.
    super(Query, self).__init__(doc_class, cb, None)
    self._default_args = {}
    self._batch_size = 100
    self._group_by = None
    self._sort_by = None
    self._query_builder = QueryBuilder()
def __init__(self, doc_class, cb):
    """
    Set up a new RunHistoryQuery.

    Args:
        doc_class (class): The model class that will be returned by this query.
        cb (BaseAPI): Reference to API object used to communicate with the server.
    """
    # These three are assigned before the parent initializer is invoked.
    self._count_valid = False
    self._cb = cb
    self._doc_class = doc_class
    super(RunHistoryQuery, self).__init__()
    # Per-query state, assigned after the parent initializer.
    self._criteria = {}
    self._sort = {}
    self._query_builder = QueryBuilder()
def __init__(self, doc_class, cb):
    """
    Set up a new FacetQuery.

    Args:
        doc_class (class): The model class that will be returned by this query.
        cb (BaseAPI): Reference to API object used to communicate with the server.
    """
    # These three are assigned before the parent initializer is invoked.
    self._count_valid = False
    self._cb = cb
    self._doc_class = doc_class
    super(FacetQuery, self).__init__()
    # Per-query state; no run ID is selected until one is supplied.
    self._run_id = None
    self._criteria = {}
    self._facet_fields = []
    self._query_builder = QueryBuilder()
def __init__(self, doc_class, cb):
    """
    Set up a new ResultQuery.

    Args:
        doc_class (class): The model class that will be returned by this query.
        cb (BaseAPI): Reference to API object used to communicate with the server.
    """
    # These three are assigned before the parent initializer is invoked.
    self._count_valid = False
    self._cb = cb
    self._doc_class = doc_class
    super(ResultQuery, self).__init__()
    # Per-query state, assigned after the parent initializer.
    self._run_id = None
    self._batch_size = 100
    self._sort = {}
    self._criteria = {}
    self._query_builder = QueryBuilder()
def __init__(self, doc_class, cb, query=None):
    """
    Set up a new Query, optionally seeded from an existing one.

    Args:
        doc_class (class): The model class that will be returned by this query.
        cb (BaseAPI): Reference to API object used to communicate with the server.
        query (Query): Optional existing query whose accumulated ``.where()`` / ``.and_()``
            clauses are carried over into this one.
    """
    super(Query, self).__init__(doc_class, cb, query)
    # max batch_size is 5000
    self._batch_size = 100
    # A fresh builder is always created; when seeding from another query, its
    # ``_query`` object is then shared by reference (not deep-copied).
    self._query_builder = QueryBuilder()
    if query is not None:
        self._query_builder._query = query._query_builder._query
def __init__(self, doc_class, cb):
    """
    Set up a new DeviceSearchQuery.

    Args:
        doc_class (class): The model class that will be returned by this query.
        cb (BaseAPI): Reference to API object used to communicate with the server.
    """
    super().__init__(doc_class, cb)
    # Empty containers for the different parts of the eventual request.
    self._sortcriteria = {}
    self._exclusions = {}
    self._time_filter = {}
    self._criteria = {}
    self._query_builder = QueryBuilder()
def __init__(self, doc_class, cb):
    """
    Set up a new USBDeviceApprovalQuery.

    Args:
        doc_class (class): The model class that will be returned by this query.
        cb (BaseAPI): Reference to API object used to communicate with the server.
    """
    self._count_valid = False
    self._cb = cb
    self._doc_class = doc_class
    # super(BaseQuery, self) begins the MRO lookup *after* BaseQuery, so
    # BaseQuery.__init__ itself is not executed. NOTE(review): confirm intentional.
    super(BaseQuery, self).__init__()
    self._total_results = 0
    self._criteria = {}
    self._query_builder = QueryBuilder()
def __init__(self, doc_class, cb):
    """
    Set up a new BaseAlertSearchQuery.

    Args:
        doc_class (class): The model class that will be returned by this query.
        cb (BaseAPI): Reference to API object used to communicate with the server.
    """
    super().__init__(doc_class, cb)
    # Count bookkeeping, (re)set after the parent initializer has run.
    self._total_results = 0
    self._count_valid = False
    # "{0}" is a format placeholder (presumably the org key) filled in by callers.
    self._bulkupdate_url = "/appservices/v6/orgs/{0}/alerts/workflow/_criteria"
    self._sortcriteria = {}
    self._time_filter = {}
    self._criteria = {}
    self._query_builder = QueryBuilder()
def __init__(self, doc_class, cb):
    """
    Set up a new ReputationOverrideQuery.

    Args:
        doc_class (class): The model class that will be returned by this query.
        cb (BaseAPI): Reference to API object used to communicate with the server.
    """
    # These three are assigned before the parent initializer is invoked.
    self._count_valid = False
    self._cb = cb
    self._doc_class = doc_class
    super(ReputationOverrideQuery, self).__init__()
    self._sortcriteria = {}
    self._criteria = {}
    self._query_builder = QueryBuilder()
def __init__(self, doc_class, cb, device=None):
    """
    Set up a new VulnerabilityQuery.

    Args:
        doc_class (class): The model class that will be returned by this query.
        cb (BaseAPI): Reference to API object used to communicate with the server.
        device (Device): Optional Device object; when given, the query is scoped to that device.
    """
    self._count_valid = False
    self._cb = cb
    self._doc_class = doc_class
    # super(BaseQuery, self) begins the MRO lookup *after* BaseQuery, so
    # BaseQuery.__init__ itself is not executed. NOTE(review): confirm intentional.
    super(BaseQuery, self).__init__()
    self._vcenter_uuid = None
    self.device = device
    self._total_results = 0
    self._sortcriteria = {}
    self._criteria = {}
    self._query_builder = QueryBuilder()
class Query(PaginatedQuery, QueryBuilderSupportMixin, IterableQueryMixin):
    """Represents a prepared query to the Cb Enterprise EDR backend.

    This object is returned as part of a `CbEnterpriseEDRAPI.select`
    operation on models requested from the Cb Enterprise EDR backend. You should not have to create this class yourself.

    The query is not executed on the server until it's accessed, either as an iterator (where it will generate values
    on demand as they're requested) or as a list (where it will retrieve the entire result set and save to a list).
    You can also call the Python built-in ``len()`` on this object to retrieve the total number of items matching
    the query.

    Examples::

        >>> from cbc_sdk.enterprise_edr import CBCloudAPI,Process
        >>> cb = CBCloudAPI()
        >>> query = cb.select(Process)
        >>> query = query.where(process_name="notepad.exe")
        >>> # alternatively:
        >>> query = query.where("process_name:notepad.exe")

    Notes:
        - The slicing operator only supports start and end parameters, but not step. ``[1:-1]`` is legal, but
          ``[1:2:-1]`` is not.
        - You can chain where clauses together to create AND queries; only objects that match all ``where`` clauses
          will be returned.
    """

    def __init__(self, doc_class, cb):
        """
        Initialize the Query.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        super(Query, self).__init__(doc_class, cb, None)
        self._query_builder = QueryBuilder()
        self._sort_by = None
        self._group_by = None
        # Number of rows requested per page by _search().
        self._batch_size = 100
        self._default_args = {}

    def _get_query_parameters(self):
        """
        Build the request arguments for this query.

        Returns:
            dict: A fresh copy of the default arguments plus ``query``, ``fields`` and, when one can be
            determined, ``process_guid``.
        """
        args = self._default_args.copy()
        args['query'] = self._query_builder._collapse()
        if self._query_builder._process_guid is not None:
            args["process_guid"] = self._query_builder._process_guid
        # The collapsed query may be empty/None (presumably when no where() clause was given);
        # only perform the substring search on a non-empty query string.
        query_text = args['query']
        if query_text and 'process_guid:' in query_text:
            # Everything after the first "process_guid:" is taken as the GUID.
            args["process_guid"] = query_text.split('process_guid:', 1)[1]
        args["fields"] = [
            "*",
            "parent_hash",
            "parent_name",
            "process_cmdline",
            "backend_timestamp",
            "device_external_ip",
            "device_group",
            "device_internal_ip",
            "device_os",
            "device_policy",
            "process_effective_reputation",
            "process_reputation",
            "process_start_time",
            "ttp"
        ]
        return args

    def _count(self):
        """
        Ask the server for the number of results matching this query.

        Returns:
            int: The ``num_available`` value reported by the server (0 if absent).
        """
        args = self._get_query_parameters()

        log.debug("args: {}".format(str(args)))

        result = self._cb.post_object(
            self._doc_class.urlobject.format(self._cb.credentials.org_key, args["process_guid"]),
            body=args
        ).json()

        self._total_results = int(result.get('num_available', 0))
        self._count_valid = True

        return self._total_results

    def _validate(self, args):
        """
        Validate the query against the model's ``validation_url``, if it defines one.

        The caller's ``args`` dict is NOT modified: ``_search`` reuses it as the search request body,
        so the validation-only adjustments below are applied to a copy.

        Args:
            args (dict): Query arguments as produced by ``_get_query_parameters``.

        Raises:
            ApiError: If the server reports the query as invalid.
        """
        if not hasattr(self._doc_class, "validation_url"):
            return

        url = self._doc_class.validation_url.format(self._cb.credentials.org_key)

        params = dict(args)
        if params.get('query', False):
            params['q'] = params['query']

        # v2 search sort key does not work with v1 validation
        params.pop('sort', None)

        validated = self._cb.get_object(url, query_parameters=params)

        if not validated.get("valid"):
            raise ApiError("Invalid query: {}: {}".format(params, validated["invalid_message"]))

    def _search(self, start=0, rows=0):
        """
        Execute the query, yielding result items page by page.

        Args:
            start (int): Row offset to begin at; 0 means "from the beginning".
            rows (int): Maximum number of items to yield; 0 means no limit.

        Yields:
            Each raw item from the response's ``results`` list.
        """
        # iterate over the total result set, self._batch_size rows per request
        args = self._get_query_parameters()
        self._validate(args)

        if start != 0:
            args['start'] = start
        args['rows'] = self._batch_size

        current = start
        numrows = 0

        still_querying = True

        while still_querying:
            url = self._doc_class.urlobject.format(self._cb.credentials.org_key, args["process_guid"])
            resp = self._cb.post_object(url, body=args)
            result = resp.json()

            self._total_results = result.get("num_available", 0)
            self._total_segments = result.get("total_segments", 0)
            self._processed_segments = result.get("processed_segments", 0)
            self._count_valid = True

            results = result.get('results', [])

            for item in results:
                yield item
                current += 1
                numrows += 1

                if rows and numrows == rows:
                    still_querying = False
                    break

            # as of 6/2017, the indexing on the Cb Endpoint Standard backend is still 1-based
            args['start'] = current + 1

            if current >= self._total_results:
                break

            if not results:
                # The server promised more than it returned; stop instead of re-requesting forever.
                log.debug("server reported total_results overestimated the number of results for this query by {0}"
                          .format(self._total_results - current))
                log.debug("resetting total_results for this query to {0}".format(current))
                self._total_results = current
                break
class VulnerabilityQuery(BaseQuery, QueryBuilderSupportMixin, IterableQueryMixin, AsyncQueryMixin):
    """Represents a query that is used to locate Vulnerability objects."""

    VALID_DEVICE_TYPE = ["WORKLOAD", "ENDPOINT"]
    VALID_OS_TYPE = ["CENTOS", "RHEL", "SLES", "UBUNTU", "WINDOWS"]
    VALID_SEVERITY = ["CRITICAL", "IMPORTANT", "MODERATE", "LOW"]
    VALID_SYNC_TYPE = ["MANUAL", "SCHEDULED"]
    VALID_SYNC_STATUS = ["NOT_STARTED", "MATCHED", "ERROR", "NOT_MATCHED", "NOT_SUPPORTED", "CANCELLED",
                         "IN_PROGRESS", "ACTIVE", "COMPLETED"]
    VALID_DIRECTIONS = ["ASC", "DESC"]

    def __init__(self, doc_class, cb, device=None):
        """
        Initialize the VulnerabilityQuery.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
            device (Device): Optional Device object to indicate VulnerabilityQuery is for a specific device
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        # super(BaseQuery, self) begins the MRO lookup *after* BaseQuery, so BaseQuery.__init__
        # itself is not executed. NOTE(review): confirm intentional.
        super(BaseQuery, self).__init__()
        self._query_builder = QueryBuilder()
        self._criteria = {}
        self._sortcriteria = {}
        self._total_results = 0
        self.device = device
        self._vcenter_uuid = None

    def set_vcenter(self, vcenter_uuid):
        """
        Restricts the vulnerabilities that this query is performed on to the specified vcenter id.

        Args:
            vcenter_uuid (str): vcenter uuid. A falsy value is ignored.

        Returns:
            VulnerabilityQuery: This instance.
        """
        if vcenter_uuid:
            self._vcenter_uuid = vcenter_uuid
        return self

    def add_criteria(self, key, value, operator='EQUALS'):
        """
        Restricts the vulnerabilities that this query is performed on to the specified key value pair.

        Args:
            key (str): Property from the vulnerability object
            value (str): Value of the property to filter by
            operator (str): (optional) logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.
        """
        self._update_criteria(key, value, operator)
        return self

    def set_device_type(self, device_type, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified device type.

        Args:
            device_type (str): device type ("WORKLOAD", "ENDPOINT")
            operator (str): logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.
        """
        if device_type not in VulnerabilityQuery.VALID_DEVICE_TYPE:
            raise ApiError("Invalid device type")
        self._update_criteria("device_type", device_type, operator)
        return self

    def set_highest_risk_score(self, highest_risk_score, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified highest_risk_score.

        Args:
            highest_risk_score (double): highest_risk_score.
            operator (str): logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.
        """
        if not highest_risk_score:
            raise ApiError("Invalid highest risk score")
        self._update_criteria("highest_risk_score", highest_risk_score, operator)
        return self

    def set_name(self, name, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified name.

        Args:
            name (str): name.
            operator (str): logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.
        """
        if not name:
            raise ApiError("Invalid name")
        self._update_criteria("name", name, operator)
        return self

    def set_os_arch(self, os_arch, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified os_arch.

        Args:
            os_arch (str): os_arch.
            operator (str): logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.
        """
        if not os_arch:
            raise ApiError("Invalid os architecture")
        self._update_criteria("os_arch", os_arch, operator)
        return self

    def set_os_name(self, os_name, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified os_name.

        Args:
            os_name (str): os_name.
            operator (str): logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.
        """
        if not os_name:
            raise ApiError("Invalid os name")
        self._update_criteria("os_name", os_name, operator)
        return self

    def set_os_type(self, os_type, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified os type.

        Args:
            os_type (str): os type ("CENTOS", "RHEL", "SLES", "UBUNTU", "WINDOWS")
            operator (str): logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.
        """
        if os_type not in VulnerabilityQuery.VALID_OS_TYPE:
            raise ApiError("Invalid os type")
        self._update_criteria("os_type", os_type, operator)
        return self

    def set_os_version(self, os_version, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified os_version.

        Args:
            os_version (str): os_version.
            operator (str): logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.
        """
        if not os_version:
            raise ApiError("Invalid os version")
        self._update_criteria("os_version", os_version, operator)
        return self

    def set_severity(self, severity, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified severity.

        Args:
            severity (str): severity ("CRITICAL", "IMPORTANT", "MODERATE", "LOW")
            operator (str): logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.
        """
        if severity not in VulnerabilityQuery.VALID_SEVERITY:
            raise ApiError("Invalid severity")
        self._update_criteria("severity", severity, operator)
        return self

    def set_sync_type(self, sync_type, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified sync_type.

        Args:
            sync_type (str): sync_type ("MANUAL", "SCHEDULED")
            operator (str): logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.
        """
        if sync_type not in VulnerabilityQuery.VALID_SYNC_TYPE:
            raise ApiError("Invalid sync type")
        self._update_criteria("sync_type", sync_type, operator)
        return self

    def set_sync_status(self, sync_status, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified sync_status.

        Args:
            sync_status (str): sync_status ("NOT_STARTED", "MATCHED", "ERROR", "NOT_MATCHED", "NOT_SUPPORTED",
                "CANCELLED", "IN_PROGRESS", "ACTIVE", "COMPLETED")
            operator (str): logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.
        """
        if sync_status not in VulnerabilityQuery.VALID_SYNC_STATUS:
            raise ApiError("Invalid sync status")
        self._update_criteria("sync_status", sync_status, operator)
        return self

    def set_vm_id(self, vm_id, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified vm_id.

        Args:
            vm_id (str): vm_id.
            operator (str): logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.
        """
        if not vm_id:
            raise ApiError("Invalid vm id")
        self._update_criteria("vm_id", vm_id, operator)
        return self

    def set_vuln_count(self, vuln_count, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified vuln_count.

        Args:
            vuln_count (str): vuln_count.
            operator (str): logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.
        """
        if not vuln_count:
            raise ApiError("Invalid vuln count")
        self._update_criteria("vuln_count", vuln_count, operator)
        return self

    def set_last_sync_ts(self, last_sync_ts, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified last_sync_ts.

        Args:
            last_sync_ts (str): last_sync_ts.
            operator (str): logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.
        """
        if not last_sync_ts:
            raise ApiError("Invalid last_sync_ts")
        self._update_criteria("last_sync_ts", last_sync_ts, operator)
        return self

    # A custom _update_criteria is needed because the criteria format differs here:
    #   "criteria": {
    #       "property": {
    #           "value": "<str>",
    #           "operator": "<str>"
    #       }
    #   }
    def _update_criteria(self, key, value, operator, overwrite=False):
        """
        Updates a list of criteria being collected for a query, by setting or appending items.

        Args:
            key (str): The property for the criteria item to be set.
            value (can be different types): the value for the criteria
            operator (str): any of the following types:
                - EQUALS, NOT_EQUALS, GREATER_THAN, LESS_THAN, IS_NULL, IS_NOT_NULL, IS_TRUE, IS_FALSE, IN, NOT_IN,
                  LIKE
            overwrite (bool): Overwrite the existing criteria for specified key

        The values are not lists, so if overwrite is not allowed, an existing entry for the key is kept
        and the change is disregarded.
        """
        if self._criteria.get(key, None) is None or overwrite:
            self._criteria[key] = dict(value=value, operator=operator)

    def _build_request(self, from_row, max_rows, add_sort=True):
        """
        Creates the request body for an API call.

        Args:
            from_row (int): The row to start the query at.
            max_rows (int): The maximum number of rows to be returned; a negative value keeps the default of 100.
            add_sort (bool): If True(default), the sort criteria will be added as part of the request.

        Returns:
            dict: The complete request body.
        """
        request = {
            "criteria": self._criteria,
            "query": self._query_builder._collapse(),
            # Fetch 100 rows per page (instead of 10 by default) for better performance
            "rows": 100
        }
        if from_row > 0:
            request["start"] = from_row
        if max_rows >= 0:
            request["rows"] = max_rows
        if add_sort and self._sortcriteria != {}:
            request["sort"] = [self._sortcriteria]
        return request

    def _build_url(self, tail_end):
        """
        Creates the URL to be used for an API call.

        Args:
            tail_end (str): str to be appended to the end of the generated URL.

        Returns:
            str: The complete URL.
        """
        if self.device:
            additional = f"/devices/{self.device._model_unique_id}/vulnerabilities"
        else:
            additional = "/devices/vulnerabilities"

        if self._vcenter_uuid:
            additional = f"/vcenters/{self._vcenter_uuid}" + additional

        url = self._doc_class.urlobject.format(self._cb.credentials.org_key) + additional + tail_end
        return url

    def _count(self):
        """
        Returns the number of results from the run of this query.

        Returns:
            int: The number of results from the run of this query.
        """
        if self._count_valid:
            return self._total_results

        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()

        self._total_results = result["num_found"]
        self._count_valid = True

        return self._total_results

    def _perform_query(self, from_row=0, max_rows=-1):
        """
        Performs the query and returns the results of the query in an iterable fashion.

        Pages through the server results until ``max_rows`` items have been yielded, the reported
        ``num_found`` has been reached, or the server returns an empty page.

        Args:
            from_row (int): The row to start the query at (default 0).
            max_rows (int): The maximum number of rows to be returned (default -1, meaning "all").

        Returns:
            Iterable: The iterated query.
        """
        url = self._build_url("/_search?dataForExport=true")
        current = from_row
        numrows = 0
        still_querying = True
        while still_querying:
            request = self._build_request(current, max_rows)
            resp = self._cb.post_object(url, body=request)
            result = resp.json()

            self._total_results = result["num_found"]
            self._count_valid = True

            results = result.get("results", [])
            for item in results:
                yield self._doc_class(self._cb, item.get('vuln_info', {}).get('cve_id', None), initial_data=item)
                current += 1
                numrows += 1

                if max_rows > 0 and numrows == max_rows:
                    still_querying = False
                    break

            if current >= self._total_results:
                break

            if not results:
                # The server reported more results than it returned (or rows=0 was requested);
                # re-issuing the identical request would loop forever.
                break

    def _run_async_query(self, context):
        """
        Executed in the background to run an asynchronous query.

        Retrieves every page of results (not just the first), using the same pagination as iteration.

        Args:
            context (object): Not used, always None.

        Returns:
            list: Result of the async query, which is then returned by the future.
        """
        return list(self._perform_query())

    def sort_by(self, key, direction="ASC"):
        """
        Sets the sorting behavior on a query's results.

        Example:
            >>> cb.select(Vulnerability).sort_by("status")

        Args:
            key (str): The key in the schema to sort by.
            direction (str): The sort order, either "ASC" or "DESC".

        Returns:
            VulnerabilityQuery: This instance.

        Raises:
            ApiError: If an invalid direction value is passed.
        """
        if direction not in VulnerabilityQuery.VALID_DIRECTIONS:
            raise ApiError("invalid sort direction specified")
        self._sortcriteria = {"field": key, "order": direction}
        return self
class FacetQuery(BaseQuery, QueryBuilderSupportMixin, IterableQueryMixin, CriteriaBuilderSupportMixin):
    """Represents a query that receives facet information from a LiveQuery run."""

    def __init__(self, doc_class, cb):
        """
        Initialize a FacetQuery object.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        super(FacetQuery, self).__init__()
        self._query_builder = QueryBuilder()
        self._facet_fields = []
        self._criteria = {}
        self._run_id = None

    def facet_field(self, field):
        """Sets the facet fields to be received by this query.

        Arguments:
            field (str or [str]): Field(s) to be received. A single string is added as one field;
                any other iterable has each of its elements added.

        Returns:
            FacetQuery that will receive field(s) facet_field.

        Example:
            >>> cb.select(ResultFacet).run_id(my_run).facet_field(["device.policy_name", "device.os"])
        """
        if isinstance(field, str):
            self._facet_fields.append(field)
        else:
            self._facet_fields.extend(field)
        return self

    def set_device_ids(self, device_ids):
        """Sets the device.id criteria filter.

        Arguments:
            device_ids ([int]): Device IDs to filter on.

        Returns:
            The FacetQuery with specified device.id.

        Raises:
            ApiError: If any element is not an int.
        """
        if not all(isinstance(device_id, int) for device_id in device_ids):
            raise ApiError("One or more invalid device IDs")
        self._update_criteria("device.id", device_ids)
        return self

    def set_device_names(self, device_names):
        """Sets the device.name criteria filter.

        Arguments:
            device_names ([str]): Device names to filter on.

        Returns:
            The FacetQuery with specified device.name.

        Raises:
            ApiError: If any element is not a str.
        """
        if not all(isinstance(name, str) for name in device_names):
            raise ApiError("One or more invalid device names")
        self._update_criteria("device.name", device_names)
        return self

    def set_device_os(self, device_os):
        """Sets the device.os criteria.

        Arguments:
            device_os ([str]): Device OS's to filter on.

        Returns:
            The FacetQuery object with specified device_os.

        Raises:
            ApiError: If any element is not a str.

        Note:
            Device OS's can be one or more of ["WINDOWS", "MAC", "LINUX"].
        """
        if not all(isinstance(os_name, str) for os_name in device_os):
            raise ApiError("device_os must be a list of strings, including"
                           " 'WINDOWS', 'MAC', and/or 'LINUX'")
        self._update_criteria("device.os", device_os)
        return self

    def set_policy_ids(self, policy_ids):
        """Sets the device.policy_id criteria.

        Arguments:
            policy_ids ([int]): Device policy ID's to filter on.

        Returns:
            The FacetQuery object with specified policy_ids.

        Raises:
            ApiError: If any element is not an int.
        """
        if not all(isinstance(policy_id, int) for policy_id in policy_ids):
            raise ApiError("policy_ids must be a list of integers.")
        self._update_criteria("device.policy_id", policy_ids)
        return self

    def set_policy_names(self, policy_names):
        """Sets the device.policy_name criteria.

        Arguments:
            policy_names ([str]): Device policy names to filter on.

        Returns:
            The FacetQuery object with specified policy_names.

        Raises:
            ApiError: If any element is not a str.
        """
        if not all(isinstance(name, str) for name in policy_names):
            raise ApiError("policy_names must be a list of strings.")
        self._update_criteria("device.policy_name", policy_names)
        return self

    def set_statuses(self, statuses):
        """Sets the status criteria.

        Arguments:
            statuses ([str]): Query statuses to filter on.

        Returns:
            The FacetQuery object with specified statuses.

        Raises:
            ApiError: If any element is not a str.
        """
        if not all(isinstance(status, str) for status in statuses):
            raise ApiError("statuses must be a list of strings.")
        self._update_criteria("status", statuses)
        return self

    def run_id(self, run_id):
        """Sets the run ID to query results for.

        Arguments:
            run_id (int): The run ID to retrieve results for.

        Returns:
            FacetQuery object with specified run_id.

        Example:
            >>> cb.select(ResultFacet).run_id(my_run)
        """
        self._run_id = run_id
        return self

    def _build_request(self, rows):
        """
        Build the facet request body.

        Args:
            rows (int): Number of rows to request per facet; 0 omits the key from the request.

        Returns:
            dict: Request body with ``query``, ``terms`` and, if any criteria are set, ``criteria``.
        """
        terms = {"fields": self._facet_fields}
        if rows != 0:
            terms["rows"] = rows
        request = {"query": self._query_builder._collapse(), "terms": terms}
        if self._criteria:
            request["criteria"] = self._criteria
        return request

    def _perform_query(self, rows=0):
        """
        Run the facet query and yield one ``doc_class`` instance per returned term.

        Args:
            rows (int): Passed through to ``_build_request``.

        Raises:
            ApiError: If no run ID has been set with ``run_id()``.
        """
        if self._run_id is None:
            raise ApiError("Can't retrieve results without a run ID")

        url = self._doc_class.urlobject.format(self._cb.credentials.org_key, self._run_id)
        request = self._build_request(rows)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()
        results = result.get("terms", [])
        for item in results:
            yield self._doc_class(self._cb, item)
class DeviceSearchQuery(BaseQuery, QueryBuilderSupportMixin, CriteriaBuilderSupportMixin, IterableQueryMixin, AsyncQueryMixin): """Represents a query that is used to locate Device objects.""" VALID_OS = ["WINDOWS", "ANDROID", "MAC", "IOS", "LINUX", "OTHER"] VALID_STATUSES = [ "PENDING", "REGISTERED", "UNINSTALLED", "DEREGISTERED", "ACTIVE", "INACTIVE", "ERROR", "ALL", "BYPASS_ON", "BYPASS", "QUARANTINE", "SENSOR_OUTOFDATE", "DELETED", "LIVE" ] VALID_PRIORITIES = ["LOW", "MEDIUM", "HIGH", "MISSION_CRITICAL"] VALID_DIRECTIONS = ["ASC", "DESC"] VALID_DEPLOYMENT_TYPES = ["ENDPOINT", "WORKLOAD"] def __init__(self, doc_class, cb): """ Initialize the DeviceSearchQuery. Args: doc_class (class): The model class that will be returned by this query. cb (BaseAPI): Reference to API object used to communicate with the server. """ self._doc_class = doc_class self._cb = cb self._count_valid = False super(DeviceSearchQuery, self).__init__() self._query_builder = QueryBuilder() self._criteria = {} self._time_filter = {} self._exclusions = {} self._sortcriteria = {} self.max_rows = -1 def _update_exclusions(self, key, newlist): """ Updates the exclusion criteria being collected for a query. Assumes the specified criteria item is defined as a list; the list passed in will be set as the value for this criteria item, or appended to the existing one if there is one. Args: key (str): The key for the criteria item to be set. newlist (list): List of values to be set for the criteria item. """ oldlist = self._exclusions.get(key, []) self._exclusions[key] = oldlist + newlist def set_ad_group_ids(self, ad_group_ids): """ Restricts the devices that this query is performed on to the specified AD group IDs. Args: ad_group_ids (list): List of AD group IDs to restrict the search to. Returns: DeviceSearchQuery: This instance. Raises: ApiError: If invalid (non-int) values are passed in the list. 
""" if not all( isinstance(ad_group_id, int) for ad_group_id in ad_group_ids): raise ApiError("One or more invalid AD group IDs") self._update_criteria("ad_group_id", ad_group_ids) return self def set_device_ids(self, device_ids): """ Restricts the devices that this query is performed on to the specified device IDs. Args: device_ids (list): List of device IDs to restrict the search to. Returns: DeviceSearchQuery: This instance. Raises: ApiError: If invalid (non-int) values are passed in the list. """ if not all(isinstance(device_id, int) for device_id in device_ids): raise ApiError("One or more invalid device IDs") self._update_criteria("id", device_ids) return self def set_last_contact_time(self, *args, **kwargs): """ Restricts the devices that this query is performed on to the specified last contact time. Args: *args (list): Not used, retained for compatibility. **kwargs (dict): Keyword arguments to this function. The critical ones are "start" (the start time), "end" (the end time), and "range" (the range value). Returns: DeviceSearchQuery: This instance. Raises: ApiError: If an invalid combination of keyword parameters are specified. """ if kwargs.get("start", None) and kwargs.get("end", None): if kwargs.get("range", None): raise ApiError( "cannot specify range= in addition to start= and end=") stime = kwargs["start"] if not isinstance(stime, str): stime = stime.isoformat() etime = kwargs["end"] if not isinstance(etime, str): etime = etime.isoformat() self._time_filter = {"start": stime, "end": etime} elif kwargs.get("range", None): if kwargs.get("start", None) or kwargs.get("end", None): raise ApiError( "cannot specify start= or end= in addition to range=") self._time_filter = {"range": kwargs["range"]} else: raise ApiError("must specify either start= and end= or range=") return self def set_os(self, operating_systems): """ Restricts the devices that this query is performed on to the specified operating systems. 
Args: operating_systems (list): List of operating systems to restrict search to. Valid values in this list are "WINDOWS", "ANDROID", "MAC", "IOS", "LINUX", and "OTHER". Returns: DeviceSearchQuery: This instance. Raises: ApiError: If invalid operating system values are passed in the list. """ if not all((osval in DeviceSearchQuery.VALID_OS) for osval in operating_systems): raise ApiError("One or more invalid operating systems") self._update_criteria("os", operating_systems) return self def set_policy_ids(self, policy_ids): """ Restricts the devices that this query is performed on to the specified policy IDs. Args: policy_ids (list): List of policy IDs to restrict the search to. Returns: DeviceSearchQuery: This instance. Raises: ApiError: If invalid (non-int) values are passed in the list. """ if not all(isinstance(policy_id, int) for policy_id in policy_ids): raise ApiError("One or more invalid policy IDs") self._update_criteria("policy_id", policy_ids) return self def set_status(self, statuses): """ Restricts the devices that this query is performed on to the specified status values. Args: statuses (list): List of statuses to restrict search to. Valid values in this list are "PENDING", "REGISTERED", "UNINSTALLED", "DEREGISTERED", "ACTIVE", "INACTIVE", "ERROR", "ALL", "BYPASS_ON", "BYPASS", "QUARANTINE", "SENSOR_OUTOFDATE", "DELETED", and "LIVE". Returns: DeviceSearchQuery: This instance. Raises: ApiError: If invalid status values are passed in the list. """ if not all( (stat in DeviceSearchQuery.VALID_STATUSES) for stat in statuses): raise ApiError("One or more invalid status values") self._update_criteria("status", statuses) return self def set_target_priorities(self, target_priorities): """ Restricts the devices that this query is performed on to the specified target priority values. Args: target_priorities (list): List of priorities to restrict search to. Valid values in this list are "LOW", "MEDIUM", "HIGH", and "MISSION_CRITICAL". 
Returns: DeviceSearchQuery: This instance. Raises: ApiError: If invalid priority values are passed in the list. """ if not all((prio in DeviceSearchQuery.VALID_PRIORITIES) for prio in target_priorities): raise ApiError("One or more invalid target priority values") self._update_criteria("target_priority", target_priorities) return self def set_exclude_sensor_versions(self, sensor_versions): """ Restricts the devices that this query is performed on to exclude specified sensor versions. Args: sensor_versions (list): List of sensor versions to be excluded. Returns: DeviceSearchQuery: This instance. Raises: ApiError: If invalid (non-string) values are passed in the list. """ if not all(isinstance(v, str) for v in sensor_versions): raise ApiError("One or more invalid sensor versions") self._update_exclusions("sensor_version", sensor_versions) return self def sort_by(self, key, direction="ASC"): """ Sets the sorting behavior on a query's results. Example: >>> cb.select(Device).sort_by("status") Args: key (str): The key in the schema to sort by. direction (str): The sort order, either "ASC" or "DESC". Returns: DeviceSearchQuery: This instance. Raises: ApiError: If an invalid direction value is passed. """ if direction not in DeviceSearchQuery.VALID_DIRECTIONS: raise ApiError("invalid sort direction specified") self._sortcriteria = {"field": key, "order": direction} return self def set_deployment_type(self, deployment_type): """ Restricts the devices that this query is performed on to the specified deployment types. Args: deployment_type (list): List of deployment types to restrict search to. Returns: DeviceSearchQuery: This instance. Raises: ApiError: If invalid deployment type values are passed in the list. 
""" if not all((type in DeviceSearchQuery.VALID_DEPLOYMENT_TYPES) for type in deployment_type): raise ApiError("invalid deployment_type specified") self._update_criteria("deployment_type", deployment_type) return self def set_max_rows(self, max_rows): """ Sets the max number of devices to fetch in a singular query Args: max_rows (integer): Max number of devices Returns: DeviceSearchQuery: This instance. Raises: ApiError: If rows is negative or greater than 10000 """ if max_rows < 0 or max_rows > 10000: raise ApiError("Max rows must be between 0 and 10000") self.max_rows = max_rows return self def _build_request(self, from_row, max_rows): """ Creates the request body for an API call. Args: from_row (int): The row to start the query at. max_rows (int): The maximum number of rows to be returned. Returns: dict: The complete request body. """ mycrit = self._criteria if self._time_filter: mycrit["last_contact_time"] = self._time_filter request = {"criteria": mycrit, "exclusions": self._exclusions} request["query"] = self._query_builder._collapse() if from_row > 1: request["start"] = from_row if max_rows >= 0: request["rows"] = max_rows elif self.max_rows >= 0: request["rows"] = self.max_rows if self._sortcriteria != {}: request["sort"] = [self._sortcriteria] return request def _build_url(self, tail_end): """ Creates the URL to be used for an API call. Args: tail_end (str): String to be appended to the end of the generated URL. Returns: str: The complete URL. """ url = self._doc_class.urlobject.format( self._cb.credentials.org_key) + tail_end return url def _count(self): """ Returns the number of results from the run of this query. Returns: int: The number of results from the run of this query. 
""" if self._count_valid: return self._total_results url = self._build_url("/_search") request = self._build_request(0, -1) resp = self._cb.post_object(url, body=request) result = resp.json() self._total_results = result["num_found"] self._count_valid = True return self._total_results def _perform_query(self, from_row=1, max_rows=-1): """ Performs the query and returns the results of the query in an iterable fashion. Device v6 API uses base 1 instead of 0. Args: from_row (int): The row to start the query at (default 1). max_rows (int): The maximum number of rows to be returned (default -1, meaning "all"). Returns: Iterable: The iterated query. """ url = self._build_url("/_search") current = from_row numrows = 0 still_querying = True while still_querying: request = self._build_request(current, max_rows) resp = self._cb.post_object(url, body=request) result = resp.json() self._total_results = result["num_found"] self._count_valid = True results = result.get("results", []) for item in results: yield self._doc_class(self._cb, item["id"], item) current += 1 numrows += 1 if max_rows > 0 and numrows == max_rows: still_querying = False break from_row = current if current >= self._total_results: still_querying = False break def _run_async_query(self, context): """ Executed in the background to run an asynchronous query. Must be implemented in any inheriting classes. Args: context (object): The context returned by _init_async_query. May be None. Returns: Any: Result of the async query, which is then returned by the future. 
""" url = self._build_url("/_search") self._total_results = 0 self._count_valid = False output = [] while not self._count_valid or len(output) < self._total_results: request = self._build_request(len(output), -1) resp = self._cb.post_object(url, body=request) result = resp.json() if not self._count_valid: self._total_results = result["num_found"] self._count_valid = True results = result.get("results", []) output += [ self._doc_class(self._cb, item["id"], item) for item in results ] return output def download(self): """ Uses the query parameters that have been set to download all device listings in CSV format. Example: >>> cb.select(Device).set_status(["ALL"]).download() Returns: str: The CSV raw data as returned from the server. Raises: ApiError: If status values have not been set before calling this function. """ tmp = self._criteria.get("status", []) if not tmp: raise ApiError("at least one status must be specified to download") query_params = {"status": ",".join(tmp)} tmp = self._criteria.get("ad_group_id", []) if tmp: query_params["ad_group_id"] = ",".join([str(t) for t in tmp]) tmp = self._criteria.get("policy_id", []) if tmp: query_params["policy_id"] = ",".join([str(t) for t in tmp]) tmp = self._criteria.get("target_priority", []) if tmp: query_params["target_priority"] = ",".join(tmp) tmp = self._query_builder._collapse() if tmp: query_params["query_string"] = tmp if self._sortcriteria: query_params["sort_field"] = self._sortcriteria["field"] query_params["sort_order"] = self._sortcriteria["order"] url = self._build_url("/_search/download") return self._cb.get_raw_data(url, query_params) def _bulk_device_action(self, action_type, options=None): """ Perform a bulk action on all devices matching the current search criteria. Args: action_type (str): The action type to be performed. options (dict): Any options for the bulk device action. Returns: str: The JSON output from the request. 
""" request = { "action_type": action_type, "search": self._build_request(0, -1) } if options: request["options"] = options return self._cb._raw_device_action(request) def background_scan(self, scan): """ Set the background scan option for the specified devices. Args: scan (bool): True to turn background scan on, False to turn it off. Returns: str: The JSON output from the request. """ return self._bulk_device_action("BACKGROUND_SCAN", self._cb._action_toggle(scan)) def bypass(self, enable): """ Set the bypass option for the specified devices. Args: enable (bool): True to enable bypass, False to disable it. Returns: str: The JSON output from the request. """ return self._bulk_device_action("BYPASS", self._cb._action_toggle(enable)) def delete_sensor(self): """ Delete the specified sensor devices. Returns: str: The JSON output from the request. """ return self._bulk_device_action("DELETE_SENSOR") def uninstall_sensor(self): """ Uninstall the specified sensor devices. Returns: str: The JSON output from the request. """ return self._bulk_device_action("UNINSTALL_SENSOR") def quarantine(self, enable): """ Set the quarantine option for the specified devices. Args: enable (bool): True to enable quarantine, False to disable it. Returns: str: The JSON output from the request. """ return self._bulk_device_action("QUARANTINE", self._cb._action_toggle(enable)) def update_policy(self, policy_id): """ Set the current policy for the specified devices. Args: policy_id (int): ID of the policy to set for the devices. Returns: str: The JSON output from the request. """ return self._bulk_device_action("UPDATE_POLICY", {"policy_id": policy_id}) def update_sensor_version(self, sensor_version): """ Update the sensor version for the specified devices. Args: sensor_version (dict): New version properties for the sensor. Returns: str: The JSON output from the request. """ return self._bulk_device_action("UPDATE_SENSOR_VERSION", {"sensor_version": sensor_version})
class ResultQuery(BaseQuery, QueryBuilderSupportMixin, IterableQueryMixin, CriteriaBuilderSupportMixin):
    """Represents a query that retrieves results from a LiveQuery run."""

    def __init__(self, doc_class, cb):
        """Initialize a ResultQuery object.

        Arguments:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        super(ResultQuery, self).__init__()
        self._query_builder = QueryBuilder()
        self._criteria = {}
        self._sort = {}
        self._batch_size = 100
        self._run_id = None

    def set_device_ids(self, device_ids):
        """Sets the device.id criteria filter.

        Arguments:
            device_ids ([int]): Device IDs to filter on.

        Returns:
            The ResultQuery with specified device.id.

        Raises:
            ApiError: If any device ID is not an integer.
        """
        if not all(isinstance(device_id, int) for device_id in device_ids):
            raise ApiError("One or more invalid device IDs")
        self._update_criteria("device.id", device_ids)
        return self

    def set_device_names(self, device_names):
        """Sets the device.name criteria filter.

        Arguments:
            device_names ([str]): Device names to filter on.

        Returns:
            The ResultQuery with specified device.name.

        Raises:
            ApiError: If any device name is not a string.
        """
        if not all(isinstance(name, str) for name in device_names):
            raise ApiError("One or more invalid device names")
        self._update_criteria("device.name", device_names)
        return self

    def set_device_os(self, device_os):
        """Sets the device.os criteria.

        Arguments:
            device_os ([str]): Device OS's to filter on.

        Returns:
            The ResultQuery object with specified device_os.

        Raises:
            ApiError: If any entry is not a string.

        Note:
            Device OS's can be one or more of ["WINDOWS", "MAC", "LINUX"].
        """
        if not all(isinstance(os_name, str) for os_name in device_os):
            # The message names the actual parameter (previously it said "device_type").
            raise ApiError("device_os must be a list of strings, including"
                           " 'WINDOWS', 'MAC', and/or 'LINUX'")
        self._update_criteria("device.os", device_os)
        return self

    def set_policy_ids(self, policy_ids):
        """Sets the device.policy_id criteria.

        Arguments:
            policy_ids ([int]): Device policy ID's to filter on.

        Returns:
            The ResultQuery object with specified policy_ids.

        Raises:
            ApiError: If any policy ID is not an integer.
        """
        if not all(isinstance(policy_id, int) for policy_id in policy_ids):
            raise ApiError("policy_ids must be a list of integers.")
        self._update_criteria("device.policy_id", policy_ids)
        return self

    def set_policy_names(self, policy_names):
        """Sets the device.policy_name criteria.

        Arguments:
            policy_names ([str]): Device policy names to filter on.

        Returns:
            The ResultQuery object with specified policy_names.

        Raises:
            ApiError: If any policy name is not a string.
        """
        if not all(isinstance(name, str) for name in policy_names):
            raise ApiError("policy_names must be a list of strings.")
        self._update_criteria("device.policy_name", policy_names)
        return self

    def set_statuses(self, statuses):
        """Sets the status criteria.

        Arguments:
            statuses ([str]): Query statuses to filter on.

        Returns:
            The ResultQuery object with specified statuses.

        Raises:
            ApiError: If any status is not a string.
        """
        if not all(isinstance(status, str) for status in statuses):
            raise ApiError("statuses must be a list of strings.")
        self._update_criteria("status", statuses)
        return self

    def sort_by(self, key, direction="ASC"):
        """Sets the sorting behavior on a query's results.

        Arguments:
            key (str): The key in the schema to sort by.
            direction (str): The sort order, either "ASC" or "DESC".

        Returns:
            ResultQuery object with specified sorting key and order.

        Example:
            >>> cb.select(Result).run_id(my_run).where(username="******").sort_by("uid")
        """
        self._sort.update({"field": key, "order": direction})
        return self

    def run_id(self, run_id):
        """Sets the run ID to query results for.

        Arguments:
            run_id (int): The run ID to retrieve results for.

        Returns:
            ResultQuery object with specified run_id.

        Example:
            >>> cb.select(Result).run_id(my_run)
        """
        self._run_id = run_id
        return self

    def _build_request(self, start, rows):
        """Creates the request body for an API call.

        Arguments:
            start (int): Row offset to start at; always included in the request.
            rows (int): Number of rows to request; 0 omits the "rows" key entirely.

        Returns:
            dict: The complete request body.
        """
        request = {"start": start, "query": self._query_builder._collapse()}
        if rows != 0:
            request["rows"] = rows
        if self._criteria:
            request["criteria"] = self._criteria
        if self._sort:
            request["sort"] = [self._sort]
        return request

    def _count(self):
        """Returns the number of results reported by the server for this run.

        Returns:
            int: The server's "num_found" value (cached after the first call).
            NOTE(review): unlike _perform_query, this is not capped at MAX_RESULTS_LIMIT.

        Raises:
            ApiError: If no run ID has been set.
        """
        if self._count_valid:
            return self._total_results

        if self._run_id is None:
            raise ApiError("Can't retrieve count without a run ID")

        url = self._doc_class.urlobject.format(self._cb.credentials.org_key, self._run_id)
        request = self._build_request(start=0, rows=0)
        result = self._cb.post_object(url, body=request).json()

        self._total_results = result["num_found"]
        self._count_valid = True
        return self._total_results

    def _perform_query(self, start=0, rows=0):
        """Performs the query and yields results page by page.

        Arguments:
            start (int): Row offset to start at (default 0).
            rows (int): Maximum number of results to yield; 0 means all.

        Yields:
            The doc_class objects built from each result item.

        Raises:
            ApiError: If no run ID has been set.
        """
        if self._run_id is None:
            raise ApiError("Can't retrieve results without a run ID")

        url = self._doc_class.urlobject.format(self._cb.credentials.org_key, self._run_id)
        current = start
        numrows = 0
        still_querying = True

        while still_querying:
            request = self._build_request(current, rows)
            result = self._cb.post_object(url, body=request).json()

            # Total is capped at MAX_RESULTS_LIMIT, so iteration never goes past it.
            self._total_results = min(result["num_found"], MAX_RESULTS_LIMIT)
            self._count_valid = True

            results = result.get("results", [])
            for item in results:
                yield self._doc_class(self._cb, item)
                current += 1
                numrows += 1
                if rows and numrows == rows:
                    still_querying = False
                    break

            # Also stop on an empty page: otherwise the identical request would be
            # re-sent forever whenever num_found overstates the available rows.
            if current >= self._total_results or not results:
                break
class RunHistoryQuery(BaseQuery, QueryBuilderSupportMixin, IterableQueryMixin, CriteriaBuilderSupportMixin):
    """Represents a query that retrieves historic LiveQuery runs."""

    def __init__(self, doc_class, cb):
        """Initialize a RunHistoryQuery object.

        Arguments:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        super(RunHistoryQuery, self).__init__()
        self._query_builder = QueryBuilder()
        self._sort = {}
        self._criteria = {}

    def set_template_ids(self, template_ids):
        """Sets the template_id criteria filter.

        Arguments:
            template_ids ([str]): Template IDs to filter on.

        Returns:
            The RunHistoryQuery with specified template_id.

        Raises:
            ApiError: If any template ID is not a string.
        """
        if not all(isinstance(template_id, str) for template_id in template_ids):
            raise ApiError("One or more invalid template IDs")
        self._update_criteria("template_id", template_ids)
        return self

    def sort_by(self, key, direction="ASC"):
        """Sets the sorting behavior on a query's results.

        Arguments:
            key (str): The key in the schema to sort by.
            direction (str): The sort order, either "ASC" or "DESC".

        Returns:
            RunHistoryQuery object with specified sorting key and order.
        """
        self._sort.update({"field": key, "order": direction})
        return self

    def _build_request(self, start, rows):
        """Creates the request body for an API call.

        Arguments:
            start (int): Row offset to start at; always included in the request.
            rows (int): Number of rows to request; 0 omits the "rows" key entirely.

        Returns:
            dict: The complete request body.
        """
        request = {"start": start}
        if self._query_builder:
            request["query"] = self._query_builder._collapse()
        if rows != 0:
            request["rows"] = rows
        if self._criteria:
            request["criteria"] = self._criteria
        if self._sort:
            request["sort"] = [self._sort]
        return request

    def _count(self):
        """Returns the number of historic runs reported by the server.

        Returns:
            int: The server's "num_found" value (cached after the first call).
        """
        if self._count_valid:
            return self._total_results

        url = self._doc_class.urlobject_history.format(self._cb.credentials.org_key)
        request = self._build_request(start=0, rows=0)
        result = self._cb.post_object(url, body=request).json()

        self._total_results = result["num_found"]
        self._count_valid = True
        return self._total_results

    def _perform_query(self, start=0, rows=0):
        """Performs the query and yields historic runs page by page.

        Arguments:
            start (int): Row offset to start at (default 0).
            rows (int): Maximum number of results to yield; 0 means all.

        Yields:
            The doc_class objects built from each result item.
        """
        url = self._doc_class.urlobject_history.format(self._cb.credentials.org_key)
        current = start
        numrows = 0
        still_querying = True

        while still_querying:
            request = self._build_request(current, rows)
            result = self._cb.post_object(url, body=request).json()

            self._total_results = result["num_found"]
            self._count_valid = True

            results = result.get("results", [])
            for item in results:
                yield self._doc_class(self._cb, item)
                current += 1
                numrows += 1
                if rows and numrows == rows:
                    still_querying = False
                    break

            # Also stop on an empty page: otherwise the identical request would be
            # re-sent forever whenever num_found overstates the available rows.
            if current >= self._total_results or not results:
                break
class USBDeviceApprovalQuery(BaseQuery, QueryBuilderSupportMixin, CriteriaBuilderSupportMixin,
                             IterableQueryMixin, AsyncQueryMixin):
    """Represents a query that is used to locate USBDeviceApproval objects."""

    def __init__(self, doc_class, cb):
        """
        Initialize the USBDeviceApprovalQuery.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        # super() is anchored at this class (as in the sibling query classes) so that
        # BaseQuery.__init__ is not skipped in the MRO.
        super(USBDeviceApprovalQuery, self).__init__()
        self._query_builder = QueryBuilder()
        self._criteria = {}
        self._total_results = 0

    def set_device_ids(self, device_ids):
        """
        Restricts the device approvals that this query is performed on to the specified device IDs.

        Args:
            device_ids (list): List of string device IDs.

        Returns:
            USBDeviceApprovalQuery: This instance.

        Raises:
            ApiError: If any device ID is not a string.
        """
        if not all(isinstance(device_id, str) for device_id in device_ids):
            raise ApiError("One or more invalid device IDs")
        self._update_criteria("device.id", device_ids)
        return self

    def set_product_names(self, product_names):
        """
        Restricts the device approvals that this query is performed on to the specified product names.

        Args:
            product_names (list): List of string product names.

        Returns:
            USBDeviceApprovalQuery: This instance.

        Raises:
            ApiError: If any product name is not a string.
        """
        if not all(isinstance(product_name, str) for product_name in product_names):
            raise ApiError("One or more invalid product names")
        self._update_criteria("product_name", product_names)
        return self

    def set_vendor_names(self, vendor_names):
        """
        Restricts the device approvals that this query is performed on to the specified vendor names.

        Args:
            vendor_names (list): List of string vendor names.

        Returns:
            USBDeviceApprovalQuery: This instance.

        Raises:
            ApiError: If any vendor name is not a string.
        """
        if not all(isinstance(vendor_name, str) for vendor_name in vendor_names):
            raise ApiError("One or more invalid vendor names")
        self._update_criteria("vendor_name", vendor_names)
        return self

    def _build_request(self, from_row, max_rows):
        """
        Creates the request body for an API call.

        Args:
            from_row (int): The row to start the query at (omitted when 0).
            max_rows (int): The maximum number of rows to be returned; negative keeps the 100-row default.

        Returns:
            dict: The complete request body.
        """
        request = {
            "criteria": self._criteria,
            "query": self._query_builder._collapse(),
            "rows": 100
        }
        # Fetch 100 rows per page (instead of 10 by default) for better performance
        if from_row > 0:
            request["start"] = from_row
        if max_rows >= 0:
            request["rows"] = max_rows
        return request

    def _build_url(self, tail_end):
        """
        Creates the URL to be used for an API call.

        Args:
            tail_end (str): String to be appended to the end of the generated URL.

        Returns:
            str: The complete URL.
        """
        url = self._doc_class.urlobject.format(self._cb.credentials.org_key) + tail_end
        return url

    def _count(self):
        """
        Returns the number of results from the run of this query.

        Returns:
            int: The number of results from the run of this query.
        """
        if self._count_valid:
            return self._total_results

        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        result = self._cb.post_object(url, body=request).json()

        self._total_results = result["num_found"]
        self._count_valid = True
        return self._total_results

    def _perform_query(self, from_row=0, max_rows=-1):
        """
        Performs the query and returns the results of the query in an iterable fashion.

        Args:
            from_row (int): The row to start the query at (default 0).
            max_rows (int): The maximum number of rows to be returned (default -1, meaning "all").

        Returns:
            Iterable: The iterated query.
        """
        url = self._build_url("/_search")
        current = from_row
        numrows = 0
        still_querying = True

        while still_querying:
            request = self._build_request(current, max_rows)
            result = self._cb.post_object(url, body=request).json()

            self._total_results = result["num_found"]
            self._count_valid = True

            results = result.get("results", [])
            for item in results:
                yield self._doc_class(self._cb, item["id"], item)
                current += 1
                numrows += 1
                if max_rows > 0 and numrows == max_rows:
                    still_querying = False
                    break

            # Also stop on an empty page: otherwise the identical request would be
            # re-sent forever whenever num_found overstates the available rows.
            if current >= self._total_results or not results:
                break

    def _run_async_query(self, context):
        """
        Executed in the background to run an asynchronous query.

        Fetches every page (not just the first), so the result matches synchronous iteration.

        Args:
            context (object): Not used, always None.

        Returns:
            list: Result of the async query, which is then returned by the future.
        """
        return list(self._perform_query())
class ReputationOverrideQuery(BaseQuery, QueryBuilderSupportMixin, IterableQueryMixin, AsyncQueryMixin):
    """Represents a query that is used to locate ReputationOverride objects."""
    VALID_DIRECTIONS = ["ASC", "DESC", "asc", "desc"]
    VALID_OVERRIDE_LISTS = ["WHITE_LIST", "BLACK_LIST"]
    VALID_OVERRIDE_TYPES = ["SHA256", "CERT", "IT_TOOL"]

    def __init__(self, doc_class, cb):
        """
        Initialize the ReputationOverrideQuery.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        super(ReputationOverrideQuery, self).__init__()
        self._query_builder = QueryBuilder()
        self._criteria = {}
        self._sortcriteria = {}

    def set_override_list(self, override_list):
        """Sets the override_list criteria filter.

        Arguments:
            override_list (str): Override List to filter on, either "WHITE_LIST" or "BLACK_LIST".

        Returns:
            The ReputationOverrideQuery with specified override_list.

        Raises:
            ApiError: If override_list is not one of the valid values.
        """
        # Reject anything that is not one of the known string values. The previous
        # "not str AND in list" test could never be true, so nothing was ever rejected.
        if (not isinstance(override_list, str)
                or override_list not in ReputationOverrideQuery.VALID_OVERRIDE_LISTS):
            raise ApiError("Invalid override_list must be one of WHITE_LIST, BLACK_LIST")
        self._criteria["override_list"] = override_list
        return self

    def set_override_type(self, override_type):
        """Sets the override_type criteria filter.

        Arguments:
            override_type (str): Override type to filter on: "SHA256", "CERT", or "IT_TOOL".

        Returns:
            The ReputationOverrideQuery with specified override_type.

        Raises:
            ApiError: If override_type is not one of the valid values.
        """
        if (not isinstance(override_type, str)
                or override_type not in ReputationOverrideQuery.VALID_OVERRIDE_TYPES):
            raise ApiError("Invalid override_type must be one of SHA256, CERT, IT_TOOL")
        self._criteria["override_type"] = override_type
        return self

    def sort_by(self, key, direction="ASC"):
        """
        Sets the sorting behavior on a query's results.

        Example:
            >>> cb.select(ReputationOverride).sort_by("create_time")

        Args:
            key (str): The key in the schema to sort by.
            direction (str): The sort order, either "ASC" or "DESC".

        Returns:
            ReputationOverrideQuery: This instance.

        Raises:
            ApiError: If an invalid direction value is passed.
        """
        if direction not in ReputationOverrideQuery.VALID_DIRECTIONS:
            raise ApiError("invalid sort direction specified")
        self._sortcriteria = {"sort_field": key, "sort_order": direction}
        return self

    def _build_request(self, from_row, max_rows):
        """
        Creates the request body for an API call.

        Args:
            from_row (int): The row to start the query at (omitted when 0).
            max_rows (int): The maximum number of rows to be returned (omitted when negative).

        Returns:
            dict: The complete request body.
        """
        request = {
            "criteria": self._criteria,
            "query": self._query_builder._collapse()
        }
        if from_row > 0:
            request["start"] = from_row
        if max_rows >= 0:
            request["rows"] = max_rows
        if self._sortcriteria != {}:
            # This API takes sort_field / sort_order as top-level keys, not a "sort" list.
            request.update(self._sortcriteria)
        return request

    def _build_url(self, tail_end):
        """
        Creates the URL to be used for an API call.

        Args:
            tail_end (str): String to be appended to the end of the generated URL.

        Returns:
            str: The complete URL.
        """
        url = self._doc_class.urlobject.format(self._cb.credentials.org_key) + tail_end
        return url

    def _count(self):
        """
        Returns the number of results from the run of this query.

        Returns:
            int: The number of results from the run of this query.
        """
        if self._count_valid:
            return self._total_results

        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        result = self._cb.post_object(url, body=request).json()

        self._total_results = result["num_found"]
        self._count_valid = True
        return self._total_results

    def _perform_query(self, from_row=0, max_rows=-1):
        """
        Performs the query and returns the results of the query in an iterable fashion.

        Args:
            from_row (int): The row to start the query at (default 0).
            max_rows (int): The maximum number of rows to be returned (default -1, meaning "all").

        Returns:
            Iterable: The iterated query.
        """
        url = self._build_url("/_search")
        current = from_row
        numrows = 0
        still_querying = True

        while still_querying:
            request = self._build_request(current, max_rows)
            result = self._cb.post_object(url, body=request).json()

            self._total_results = result["num_found"]
            self._count_valid = True

            results = result.get("results", [])
            for item in results:
                yield self._doc_class(self._cb, item["id"], item)
                current += 1
                numrows += 1
                if max_rows > 0 and numrows == max_rows:
                    still_querying = False
                    break

            # Also stop on an empty page: otherwise the identical request would be
            # re-sent forever whenever num_found overstates the available rows.
            if current >= self._total_results or not results:
                break

    def _run_async_query(self, context):
        """
        Executed in the background to run an asynchronous query.

        Fetches every page (not just the first), so the result matches synchronous iteration.

        Args:
            context (object): Not used, always None.

        Returns:
            list: Result of the async query, which is then returned by the future.
        """
        return list(self._perform_query())
class USBDeviceQuery(BaseQuery, QueryBuilderSupportMixin, CriteriaBuilderSupportMixin,
                     IterableQueryMixin, AsyncQueryMixin):
    """Represents a query that is used to locate USBDevice objects."""
    VALID_STATUSES = ["APPROVED", "UNAPPROVED"]
    VALID_FACET_FIELDS = [
        "vendor_name", "product_name", "endpoint.endpoint_name", "status"
    ]

    def __init__(self, doc_class, cb):
        """
        Initialize the USBDeviceQuery.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        # super() is anchored at this class (as in the sibling query classes) so that
        # BaseQuery.__init__ is not skipped in the MRO.
        super(USBDeviceQuery, self).__init__()
        self._query_builder = QueryBuilder()
        self._criteria = {}
        self._sortcriteria = {}
        self._total_results = 0

    def set_endpoint_names(self, endpoint_names):
        """
        Restricts the devices that this query is performed on to the specified endpoint names.

        Args:
            endpoint_names (list): List of string endpoint names.

        Returns:
            USBDeviceQuery: This instance.

        Raises:
            ApiError: If any endpoint name is not a string.
        """
        if not all(isinstance(endpoint_name, str) for endpoint_name in endpoint_names):
            raise ApiError("One or more invalid endpoint names")
        self._update_criteria("endpoint.endpoint_name", endpoint_names)
        return self

    def set_product_names(self, product_names):
        """
        Restricts the devices that this query is performed on to the specified product names.

        Args:
            product_names (list): List of string product names.

        Returns:
            USBDeviceQuery: This instance.

        Raises:
            ApiError: If any product name is not a string.
        """
        if not all(isinstance(product_name, str) for product_name in product_names):
            raise ApiError("One or more invalid product names")
        self._update_criteria("product_name", product_names)
        return self

    def set_serial_numbers(self, serial_numbers):
        """
        Restricts the devices that this query is performed on to the specified serial numbers.

        Args:
            serial_numbers (list): List of string serial numbers.

        Returns:
            USBDeviceQuery: This instance.

        Raises:
            ApiError: If any serial number is not a string.
        """
        if not all(isinstance(serial_number, str) for serial_number in serial_numbers):
            raise ApiError("One or more invalid serial numbers")
        self._update_criteria("serial_number", serial_numbers)
        return self

    def set_statuses(self, statuses):
        """
        Restricts the devices that this query is performed on to the specified status values.

        Args:
            statuses (list): List of string status values.  Valid values are APPROVED and UNAPPROVED.

        Returns:
            USBDeviceQuery: This instance.

        Raises:
            ApiError: If any status is not one of VALID_STATUSES.
        """
        if not all((s in USBDeviceQuery.VALID_STATUSES) for s in statuses):
            raise ApiError("One or more invalid status values")
        self._update_criteria("status", statuses)
        return self

    def set_vendor_names(self, vendor_names):
        """
        Restricts the devices that this query is performed on to the specified vendor names.

        Args:
            vendor_names (list): List of string vendor names.

        Returns:
            USBDeviceQuery: This instance.

        Raises:
            ApiError: If any vendor name is not a string.
        """
        if not all(isinstance(vendor_name, str) for vendor_name in vendor_names):
            raise ApiError("One or more invalid vendor names")
        self._update_criteria("vendor_name", vendor_names)
        return self

    def sort_by(self, key, direction="ASC"):
        """
        Sets the sorting behavior on a query's results.

        Example:
            >>> cb.select(USBDevice).sort_by("product_name")

        Args:
            key (str): The key in the schema to sort by.
            direction (str): The sort order, either "ASC" or "DESC".

        Returns:
            USBDeviceQuery: This instance.

        Raises:
            ApiError: If direction is not in DeviceSearchQuery.VALID_DIRECTIONS.
        """
        if direction not in DeviceSearchQuery.VALID_DIRECTIONS:
            raise ApiError("invalid sort direction specified")
        self._sortcriteria = {"field": key, "order": direction}
        return self

    def _build_request(self, from_row, max_rows, add_sort=True):
        """
        Creates the request body for an API call.

        Args:
            from_row (int): The row to start the query at (omitted when 0).
            max_rows (int): The maximum number of rows to be returned; negative keeps the 100-row default.
            add_sort (bool): If True(default), the sort criteria will be added as part of the request.

        Returns:
            dict: The complete request body.
        """
        request = {
            "criteria": self._criteria,
            "query": self._query_builder._collapse(),
            "rows": 100
        }
        # Fetch 100 rows per page (instead of 10 by default) for better performance
        if from_row > 0:
            request["start"] = from_row
        if max_rows >= 0:
            request["rows"] = max_rows
        if add_sort and self._sortcriteria != {}:
            request["sort"] = [self._sortcriteria]
        return request

    def _build_url(self, tail_end):
        """
        Creates the URL to be used for an API call.

        Args:
            tail_end (str): String to be appended to the end of the generated URL.

        Returns:
            str: The complete URL.
        """
        url = self._doc_class.urlobject.format(self._cb.credentials.org_key) + tail_end
        return url

    def _count(self):
        """
        Returns the number of results from the run of this query.

        Returns:
            int: The number of results from the run of this query.
        """
        if self._count_valid:
            return self._total_results

        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        result = self._cb.post_object(url, body=request).json()

        self._total_results = result["num_found"]
        self._count_valid = True
        return self._total_results

    def _perform_query(self, from_row=0, max_rows=-1):
        """
        Performs the query and returns the results of the query in an iterable fashion.

        Args:
            from_row (int): The row to start the query at (default 0).
            max_rows (int): The maximum number of rows to be returned (default -1, meaning "all").

        Returns:
            Iterable: The iterated query.
        """
        url = self._build_url("/_search")
        current = from_row
        numrows = 0
        still_querying = True

        while still_querying:
            request = self._build_request(current, max_rows)
            result = self._cb.post_object(url, body=request).json()

            self._total_results = result["num_found"]
            self._count_valid = True

            results = result.get("results", [])
            for item in results:
                yield self._doc_class(self._cb, item["id"], item)
                current += 1
                numrows += 1
                if max_rows > 0 and numrows == max_rows:
                    still_querying = False
                    break

            # Also stop on an empty page: otherwise the identical request would be
            # re-sent forever whenever num_found overstates the available rows.
            if current >= self._total_results or not results:
                break

    def _run_async_query(self, context):
        """
        Executed in the background to run an asynchronous query.

        Fetches every page (not just the first), so the result matches synchronous iteration.

        Args:
            context (object): Not used, always None.

        Returns:
            list: Result of the async query, which is then returned by the future.
        """
        return list(self._perform_query())

    def facets(self, fieldlist, max_rows=0):
        """
        Return information about the facets for all known USB devices, using the defined criteria.

        Args:
            fieldlist (list): List of facet field names. Valid names are "vendor_name", "product_name",
                              "endpoint.endpoint_name", and "status".
            max_rows (int): The maximum number of rows to return. 0 means return all rows.

        Returns:
            list: A list of facet information specified as dicts.

        Raises:
            ApiError: If any field name is not in VALID_FACET_FIELDS.
        """
        if not all((field in USBDeviceQuery.VALID_FACET_FIELDS) for field in fieldlist):
            raise ApiError("One or more invalid term field names")
        # Reuse the search criteria/query, but without sorting or the paging "rows" key.
        request = self._build_request(0, -1, False)
        del request["rows"]
        request["terms"] = {"fields": fieldlist, "rows": max_rows}
        url = self._build_url("/_facet")
        result = self._cb.post_object(url, body=request).json()
        return result.get("terms", [])
class Query(PaginatedQuery, PlatformQueryBase, QueryBuilderSupportMixin, IterableQueryMixin):
    """Represents a prepared query to the Cb Endpoint Standard server.

    This object is returned as part of a `CBCloudAPI.select`
    operation on models requested from the Cb Endpoint Standard server.
    You should not have to create this class yourself.

    The query is not executed on the server until it's accessed, either as an iterator
    (where it will generate values on demand as they're requested) or as a list (where it
    will retrieve the entire result set and save to a list). You can also call the Python
    built-in ``len()`` on this object to retrieve the total number of items matching the query.

    Example:
        >>> from cbc_sdk import CBCloudAPI
        >>> cb = CBCloudAPI()

    Notes:
        - The slicing operator only supports start and end parameters, but not step.
          ``[1:-1]`` is legal, but ``[1:2:-1]`` is not.
        - You can chain where clauses together to create AND queries; only objects that
          match all ``where`` clauses will be returned.
        - Device Queries with multiple search parameters only support AND operations,
          not OR. Use of Query.or_(myParameter='myValue') will add 'AND myParameter:myValue'
          to the search query.
    """

    def __init__(self, doc_class, cb, query=None):
        """Initialize a Query object.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
            query (Query): Optional existing query whose where()/and_() clauses are carried over.
        """
        super(Query, self).__init__(doc_class, cb, query)
        # max batch_size is 5000
        self._batch_size = 100
        self._query_builder = QueryBuilder()
        if query is not None:
            # copy existing .where(), and_() queries
            self._query_builder._query = query._query_builder._query

    def _clone(self):
        """Return a new Query with the same model, API, query clauses and batch size."""
        nq = self.__class__(self._doc_class, self._cb, query=self)
        nq._batch_size = self._batch_size
        return nq

    def or_(self, **kwargs):
        """Unsupported. Will raise if called.

        Raises:
            ApiError: .or_() cannot be called on Endpoint Standard queries.
        """
        raise ApiError(".or_() cannot be called on Endpoint Standard queries.")

    def prepare_query(self, args):
        """Adds query parameters that are part of a `select().where()` clause to the request.

        Each space-separated ``key:value`` term of the collapsed query becomes ``args[key] = value``;
        terms without a colon (such as AND / OR) are skipped.

        Args:
            args (dict): Request parameters; updated in place.

        Returns:
            dict: The same ``args`` dict with the query terms added.
        """
        params = self._query_builder._collapse()
        if params is not None:
            for term in params.split(' '):
                try:
                    # convert from str('key:value') to dict{'key': 'value'}
                    key, value = term.split(':', 1)
                except ValueError:
                    # AND or OR encountered
                    continue
                # must remove leading or trailing parentheses that were inserted by logical combinations
                args[key.strip('()')] = value.strip('()')
        return args

    def _count(self):
        """Returns the total number of results reported by the server ("totalResults").

        Returns:
            int: The total result count (cached after the first call).
        """
        if self._count_valid:
            return self._total_results

        query_args = convert_query_params(self.prepare_query({}))
        result = self._cb.get_object(self._doc_class.urlobject, query_parameters=query_args)
        self._total_results = int(result.get("totalResults", 0))
        self._count_valid = True
        return self._total_results

    def _search(self, start=0, rows=0):
        """Iterate over the raw result set, fetching ``self._batch_size`` results per request.

        Args:
            start (int): Zero-based index of the first result to return (default 0).
            rows (int): Maximum number of results to return; 0 means all.

        Yields:
            Raw result items as returned by the server.
        """
        args = {}
        if start != 0:
            # The backend is 1-based (see below); shift the zero-based start the same way
            # later pages are shifted, so the first page neither overlaps nor skips a row.
            args['start'] = start + 1
        args['rows'] = self._batch_size

        current = start
        numrows = 0
        args = self.prepare_query(args)
        still_querying = True

        while still_querying:
            query_args = convert_query_params(args)
            result = self._cb.get_object(self._doc_class.urlobject, query_parameters=query_args)

            # Coerced to int, matching _count, so the comparisons below are numeric.
            self._total_results = int(result.get("totalResults", 0))
            self._count_valid = True

            results = result.get('results', [])

            if results is None:
                log.debug("Results are None")
                if current >= 100000:
                    log.info("Max result size exceeded. Truncated to 100k.")
                break

            for item in results:
                yield item
                current += 1
                numrows += 1
                if rows and numrows == rows:
                    still_querying = False
                    break

            # as of 6/2017, the indexing on the Cb Endpoint Standard backend is still 1-based
            args['start'] = current + 1

            if current >= self._total_results:
                break

            if not results:
                log.debug("server reported total_results overestimated the number of results for this query by %s",
                          self._total_results - current)
                log.debug("resetting total_results for this query to %s", current)
                self._total_results = current
                break
class BaseAlertSearchQuery(BaseQuery, QueryBuilderSupportMixin, IterableQueryMixin, CriteriaBuilderSupportMixin):
    """Represents a query that is used to locate BaseAlert objects."""
    VALID_CATEGORIES = ["THREAT", "MONITORED", "INFO", "MINOR", "SERIOUS", "CRITICAL"]
    VALID_REPUTATIONS = ["KNOWN_MALWARE", "SUSPECT_MALWARE", "PUP", "NOT_LISTED", "ADAPTIVE_WHITE_LIST",
                         "COMMON_WHITE_LIST", "TRUSTED_WHITE_LIST", "COMPANY_BLACK_LIST"]
    VALID_ALERT_TYPES = ["CB_ANALYTICS", "DEVICE_CONTROL", "WATCHLIST"]
    VALID_WORKFLOW_VALS = ["OPEN", "DISMISSED"]
    VALID_FACET_FIELDS = ["ALERT_TYPE", "CATEGORY", "REPUTATION", "WORKFLOW", "TAG", "POLICY_ID",
                          "POLICY_NAME", "DEVICE_ID", "DEVICE_NAME", "APPLICATION_HASH",
                          "APPLICATION_NAME", "STATUS", "RUN_STATE", "POLICY_APPLIED_STATE",
                          "POLICY_APPLIED", "SENSOR_ACTION"]

    def __init__(self, doc_class, cb):
        """
        Initialize the BaseAlertSearchQuery.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        super(BaseAlertSearchQuery, self).__init__()
        self._query_builder = QueryBuilder()
        self._criteria = {}
        # Time filters are kept apart from _criteria and merged in by _build_criteria().
        self._time_filters = {}
        self._sortcriteria = {}
        self._bulkupdate_url = "/appservices/v6/orgs/{0}/alerts/workflow/_criteria"
        self._total_results = 0

    def set_categories(self, categories):
        """
        Restricts the alerts that this query is performed on to the specified categories.

        Args:
            categories (list): List of categories to be restricted to. Valid categories are
                               "THREAT", "MONITORED", "INFO", "MINOR", "SERIOUS", and "CRITICAL."

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all((c in BaseAlertSearchQuery.VALID_CATEGORIES) for c in categories):
            raise ApiError("One or more invalid category values")
        self._update_criteria("category", categories)
        return self

    def _set_time_filter(self, key, kwargs):
        """
        Validate start=/end=/range= keyword arguments and store them as the time filter for ``key``.

        Args:
            key (str): The criteria key the time filter is stored under.
            kwargs (dict): Keyword arguments supplied by the caller; start= and end= (str or objects
                           with an ``isoformat()`` method) or range= (str).

        Raises:
            ApiError: If neither or both forms are supplied.
        """
        start = kwargs.get("start", None)
        end = kwargs.get("end", None)
        time_range = kwargs.get("range", None)
        if start and end:
            if time_range:
                raise ApiError("cannot specify range= in addition to start= and end=")
            # Non-string times (e.g. datetime objects) are serialized via isoformat().
            stime = start if isinstance(start, str) else start.isoformat()
            etime = end if isinstance(end, str) else end.isoformat()
            self._time_filters[key] = {"start": stime, "end": etime}
        elif time_range:
            if start or end:
                raise ApiError("cannot specify start= or end= in addition to range=")
            self._time_filters[key] = {"range": time_range}
        else:
            raise ApiError("must specify either start= and end= or range=")

    def set_create_time(self, *args, **kwargs):
        """
        Restricts the alerts that this query is performed on to the specified creation time.

        The time may either be specified as a start and end point or as a range.

        Args:
            *args (list): Not used.
            **kwargs (dict): Used to specify start= for start time, end= for end time, and range= for range.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        self._set_time_filter("create_time", kwargs)
        return self

    def set_device_ids(self, device_ids):
        """
        Restricts the alerts that this query is performed on to the specified device IDs.

        Args:
            device_ids (list): List of integer device IDs.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(device_id, int) for device_id in device_ids):
            raise ApiError("One or more invalid device IDs")
        self._update_criteria("device_id", device_ids)
        return self

    def set_device_names(self, device_names):
        """
        Restricts the alerts that this query is performed on to the specified device names.

        Args:
            device_names (list): List of string device names.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(n, str) for n in device_names):
            raise ApiError("One or more invalid device names")
        self._update_criteria("device_name", device_names)
        return self

    def set_device_os(self, device_os):
        """
        Restricts the alerts that this query is performed on to the specified device operating systems.

        Args:
            device_os (list): List of string operating systems. Valid values are "WINDOWS", "ANDROID",
                              "MAC", "IOS", "LINUX", and "OTHER."

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all((osval in DeviceSearchQuery.VALID_OS) for osval in device_os):
            raise ApiError("One or more invalid operating systems")
        self._update_criteria("device_os", device_os)
        return self

    def set_device_os_versions(self, device_os_versions):
        """
        Restricts the alerts that this query is performed on to the specified device operating system versions.

        Args:
            device_os_versions (list): List of string operating system versions.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(n, str) for n in device_os_versions):
            raise ApiError("One or more invalid device OS versions")
        self._update_criteria("device_os_version", device_os_versions)
        return self

    def set_device_username(self, users):
        """
        Restricts the alerts that this query is performed on to the specified user names.

        Args:
            users (list): List of string user names.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(u, str) for u in users):
            raise ApiError("One or more invalid user names")
        self._update_criteria("device_username", users)
        return self

    def set_group_results(self, do_group):
        """
        Specifies whether or not to group the results of the query.

        Args:
            do_group (bool): True to group the results, False to not do so.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        self._criteria["group_results"] = bool(do_group)
        return self

    def set_alert_ids(self, alert_ids):
        """
        Restricts the alerts that this query is performed on to the specified alert IDs.

        Args:
            alert_ids (list): List of string alert IDs.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(v, str) for v in alert_ids):
            raise ApiError("One or more invalid alert ID values")
        self._update_criteria("id", alert_ids)
        return self

    def set_legacy_alert_ids(self, alert_ids):
        """
        Restricts the alerts that this query is performed on to the specified legacy alert IDs.

        Args:
            alert_ids (list): List of string legacy alert IDs.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(v, str) for v in alert_ids):
            raise ApiError("One or more invalid alert ID values")
        self._update_criteria("legacy_alert_id", alert_ids)
        return self

    def set_minimum_severity(self, severity):
        """
        Restricts the alerts that this query is performed on to the specified minimum severity level.

        Args:
            severity (int): The minimum severity level for alerts.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        self._criteria["minimum_severity"] = severity
        return self

    def set_policy_ids(self, policy_ids):
        """
        Restricts the alerts that this query is performed on to the specified policy IDs.

        Args:
            policy_ids (list): List of integer policy IDs.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(policy_id, int) for policy_id in policy_ids):
            raise ApiError("One or more invalid policy IDs")
        self._update_criteria("policy_id", policy_ids)
        return self

    def set_policy_names(self, policy_names):
        """
        Restricts the alerts that this query is performed on to the specified policy names.

        Args:
            policy_names (list): List of string policy names.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(n, str) for n in policy_names):
            raise ApiError("One or more invalid policy names")
        self._update_criteria("policy_name", policy_names)
        return self

    def set_process_names(self, process_names):
        """
        Restricts the alerts that this query is performed on to the specified process names.

        Args:
            process_names (list): List of string process names.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(n, str) for n in process_names):
            raise ApiError("One or more invalid process names")
        self._update_criteria("process_name", process_names)
        return self

    def set_process_sha256(self, shas):
        """
        Restricts the alerts that this query is performed on to the specified process SHA-256 hash values.

        Args:
            shas (list): List of string process SHA-256 hash values.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(n, str) for n in shas):
            raise ApiError("One or more invalid SHA256 values")
        self._update_criteria("process_sha256", shas)
        return self

    def set_reputations(self, reps):
        """
        Restricts the alerts that this query is performed on to the specified reputation values.

        Args:
            reps (list): List of string reputation values. Valid values are "KNOWN_MALWARE", "SUSPECT_MALWARE",
                         "PUP", "NOT_LISTED", "ADAPTIVE_WHITE_LIST", "COMMON_WHITE_LIST",
                         "TRUSTED_WHITE_LIST", and "COMPANY_BLACK_LIST".

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all((r in BaseAlertSearchQuery.VALID_REPUTATIONS) for r in reps):
            raise ApiError("One or more invalid reputation values")
        self._update_criteria("reputation", reps)
        return self

    def set_tags(self, tags):
        """
        Restricts the alerts that this query is performed on to the specified tag values.

        Args:
            tags (list): List of string tag values.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(tag, str) for tag in tags):
            raise ApiError("One or more invalid tags")
        self._update_criteria("tag", tags)
        return self

    def set_target_priorities(self, priorities):
        """
        Restricts the alerts that this query is performed on to the specified target priority values.

        Args:
            priorities (list): List of string target priority values. Valid values are "LOW", "MEDIUM",
                               "HIGH", and "MISSION_CRITICAL".

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all((prio in DeviceSearchQuery.VALID_PRIORITIES) for prio in priorities):
            raise ApiError("One or more invalid priority values")
        self._update_criteria("target_value", priorities)
        return self

    def set_threat_ids(self, threats):
        """
        Restricts the alerts that this query is performed on to the specified threat ID values.

        Args:
            threats (list): List of string threat ID values.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(t, str) for t in threats):
            raise ApiError("One or more invalid threat ID values")
        self._update_criteria("threat_id", threats)
        return self

    def set_time_range(self, key, **kwargs):
        """
        Restricts the alerts that this query is performed on to the specified time range.

        The time may either be specified as a start and end point or as a range.

        Args:
            key (str): The key to use for criteria one of create_time, first_event_time, last_event_time,
                       or last_update_time
            **kwargs (dict): Used to specify start= for start time, end= for end time, and range= for range.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if key not in ["create_time", "first_event_time", "last_event_time", "last_update_time"]:
            raise ApiError("key must be one of create_time, first_event_time, last_event_time, or last_update_time")
        self._set_time_filter(key, kwargs)
        return self

    def set_types(self, alerttypes):
        """
        Restricts the alerts that this query is performed on to the specified alert type values.

        Args:
            alerttypes (list): List of string alert type values. Valid values are "CB_ANALYTICS",
                               "DEVICE_CONTROL", and "WATCHLIST".

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all((t in BaseAlertSearchQuery.VALID_ALERT_TYPES) for t in alerttypes):
            raise ApiError("One or more invalid alert type values")
        self._update_criteria("type", alerttypes)
        return self

    def set_workflows(self, workflow_vals):
        """
        Restricts the alerts that this query is performed on to the specified workflow status values.

        Args:
            workflow_vals (list): List of string workflow status values. Valid values are "OPEN" and "DISMISSED".

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all((t in BaseAlertSearchQuery.VALID_WORKFLOW_VALS) for t in workflow_vals):
            raise ApiError("One or more invalid workflow status values")
        self._update_criteria("workflow", workflow_vals)
        return self

    def _build_criteria(self):
        """
        Builds the criteria object for use in a query.

        The result is a new dict combining the criteria and the time filters; ``self._criteria``
        itself is not modified.

        Returns:
            dict: The criteria object.
        """
        mycrit = dict(self._criteria)
        mycrit.update(self._time_filters)
        return mycrit

    def sort_by(self, key, direction="ASC"):
        """
        Sets the sorting behavior on a query's results.

        Example:
            >>> cb.select(BaseAlert).sort_by("name")

        Args:
            key (str): The key in the schema to sort by.
            direction (str): The sort order, either "ASC" or "DESC".

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if direction not in DeviceSearchQuery.VALID_DIRECTIONS:
            raise ApiError("invalid sort direction specified")
        self._sortcriteria = {"field": key, "order": direction}
        return self

    def _build_request(self, from_row, max_rows, add_sort=True):
        """
        Creates the request body for an API call.

        Args:
            from_row (int): The row to start the query at.
            max_rows (int): The maximum number of rows to be returned.
            add_sort (bool): If True(default), the sort criteria will be added as part of the request.

        Returns:
            dict: The complete request body.
        """
        request = {"criteria": self._build_criteria()}
        request["query"] = self._query_builder._collapse()
        # Fetch 100 rows per page (instead of 10 by default) for better performance
        request["rows"] = 100
        if from_row > 0:
            request["start"] = from_row
        if max_rows >= 0:
            request["rows"] = max_rows
        if add_sort and self._sortcriteria != {}:
            request["sort"] = [self._sortcriteria]
        return request

    def _build_url(self, tail_end):
        """
        Creates the URL to be used for an API call.

        Args:
            tail_end (str): String to be appended to the end of the generated URL.

        Returns:
            str: The complete URL.
        """
        url = self._doc_class.urlobject.format(self._cb.credentials.org_key) + tail_end
        return url

    def _count(self):
        """
        Returns the number of results from the run of this query.

        Returns:
            int: The number of results from the run of this query.
        """
        if self._count_valid:
            return self._total_results

        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()

        self._total_results = result["num_found"]
        self._count_valid = True

        return self._total_results

    def _perform_query(self, from_row=0, max_rows=-1):
        """
        Performs the query and returns the results of the query in an iterable fashion.

        Args:
            from_row (int): The row to start the query at (default 0).
            max_rows (int): The maximum number of rows to be returned (default -1, meaning "all").

        Returns:
            Iterable: The iterated query.
        """
        url = self._build_url("/_search")
        current = from_row
        numrows = 0
        while True:
            request = self._build_request(current, max_rows)
            resp = self._cb.post_object(url, body=request)
            result = resp.json()

            self._total_results = result["num_found"]
            self._count_valid = True

            # Treat a missing or null "results" member as an empty page.
            results = result.get("results") or []
            for item in results:
                yield self._doc_class(self._cb, item["id"], item)
                current += 1
                numrows += 1
                if max_rows > 0 and numrows == max_rows:
                    return

            # An empty page means the server has nothing more to give, even if num_found
            # claims otherwise; stopping here prevents re-requesting the same page forever.
            if not results or current >= self._total_results:
                break

    def facets(self, fieldlist, max_rows=0):
        """
        Return information about the facets for this alert by search, using the defined criteria.

        Args:
            fieldlist (list): List of facet field names. Valid names are "ALERT_TYPE", "CATEGORY", "REPUTATION",
                              "WORKFLOW", "TAG", "POLICY_ID", "POLICY_NAME", "DEVICE_ID", "DEVICE_NAME",
                              "APPLICATION_HASH", "APPLICATION_NAME", "STATUS", "RUN_STATE", "POLICY_APPLIED_STATE",
                              "POLICY_APPLIED", and "SENSOR_ACTION".
            max_rows (int): The maximum number of rows to return. 0 means return all rows.

        Returns:
            list: A list of facet information specified as dicts.
        """
        if not all((field in BaseAlertSearchQuery.VALID_FACET_FIELDS) for field in fieldlist):
            raise ApiError("One or more invalid term field names")
        request = self._build_request(0, -1, False)
        request["terms"] = {"fields": fieldlist, "rows": max_rows}
        url = self._build_url("/_facet")
        resp = self._cb.post_object(url, body=request)
        result = resp.json()
        return result.get("results", [])

    def _update_status(self, status, remediation, comment):
        """
        Updates the status of all alerts matching the given query.

        Args:
            status (str): The status to put the alerts into, either "OPEN" or "DISMISSED".
            remediation (str): The remediation state to set for all alerts.
            comment (str): The comment to set for all alerts.

        Returns:
            str: The request ID, which may be used to select a WorkflowStatus object.
        """
        request = {"state": status, "criteria": self._build_criteria(), "query": self._query_builder._collapse()}
        if remediation is not None:
            request["remediation_state"] = remediation
        if comment is not None:
            request["comment"] = comment
        resp = self._cb.post_object(self._bulkupdate_url.format(self._cb.credentials.org_key), body=request)
        output = resp.json()
        return output["request_id"]

    def update(self, remediation=None, comment=None):
        """
        Update all alerts matching the given query. The alerts will be left in an OPEN state after this request.

        Args:
            remediation (str): The remediation state to set for all alerts.
            comment (str): The comment to set for all alerts.

        Returns:
            str: The request ID, which may be used to select a WorkflowStatus object.
        """
        return self._update_status("OPEN", remediation, comment)

    def dismiss(self, remediation=None, comment=None):
        """
        Dismiss all alerts matching the given query. The alerts will be left in a DISMISSED state after this request.

        Args:
            remediation (str): The remediation state to set for all alerts.
            comment (str): The comment to set for all alerts.

        Returns:
            str: The request ID, which may be used to select a WorkflowStatus object.
        """
        return self._update_status("DISMISSED", remediation, comment)
class ComputeResourceQuery(BaseQuery, QueryBuilderSupportMixin, CriteriaBuilderSupportMixin,
                           IterableQueryMixin, AsyncQueryMixin):
    """Represents a query that is used to locate ComputeResource objects."""
    VALID_OS_TYPE = ("WINDOWS", "RHEL", "UBUNTU", "SUSE", "SLES", "CENTOS", "OTHER", "AMAZON_LINUX", "ORACLE")
    VALID_DIRECTIONS = ("ASC", "DESC")
    VALID_ELIGIBILITY = ("ELIGIBLE", "NOT_ELIGIBLE", "UNSUPPORTED")
    VALID_OS_ARCHITECTURE = ("32", "64")
    VALID_INSTALLATION_STATUS = ("SUCCESS", "ERROR", "PENDING", "NOT_INSTALLED")

    def __init__(self, doc_class, cb):
        """
        Initialize the ComputeResourceQuery.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        # Start the MRO walk at this class (as BaseAlertSearchQuery does) so that
        # BaseQuery.__init__ is not skipped.
        super(ComputeResourceQuery, self).__init__()
        self._query_builder = QueryBuilder()
        self._criteria = {}
        self._sortcriteria = {}
        self._total_results = 0

    def set_appliance_uuid(self, appliance_uuid):
        """
        Restricts the search that this query is performed on to the specified appliance uuid.

        Args:
            appliance_uuid (list): List of string appliance uuids.

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all(isinstance(value, str) for value in appliance_uuid):
            raise ApiError("One or more invalid appliance uuid")
        self._update_criteria("appliance_uuid", appliance_uuid)
        return self

    def set_eligibility(self, eligibility):
        """
        Restricts the search that this query is performed on to the specified eligibility.

        Args:
            eligibility (list): List of string eligibilities. Valid values are "ELIGIBLE",
                                "NOT_ELIGIBLE", and "UNSUPPORTED".

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all((value in ComputeResourceQuery.VALID_ELIGIBILITY) for value in eligibility):
            raise ApiError("One or more invalid eligibility")
        self._update_criteria("eligibility", eligibility)
        return self

    def set_cluster_name(self, cluster_name):
        """
        Restricts the search that this query is performed on to the specified cluster name.

        Args:
            cluster_name (list): List of string cluster names.

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all(isinstance(value, str) for value in cluster_name):
            raise ApiError("One or more invalid cluster name")
        self._update_criteria("cluster_name", cluster_name)
        return self

    def set_name(self, name):
        """
        Restricts the search that this query is performed on to the specified name.

        Args:
            name (list): List of string names.

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all(isinstance(value, str) for value in name):
            raise ApiError("One or more invalid names")
        self._update_criteria("name", name)
        return self

    def set_ip_address(self, ip_address):
        """
        Restricts the search that this query is performed on to the specified ip address.

        Args:
            ip_address (list): List of string ip addresses.

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all(isinstance(value, str) for value in ip_address):
            raise ApiError("One or more invalid ip address")
        self._update_criteria("ip_address", ip_address)
        return self

    def set_installation_status(self, installation_status):
        """
        Restricts the search that this query is performed on to the specified installation status.

        Args:
            installation_status (list): List of string installation status. Valid values are "SUCCESS",
                                        "ERROR", "PENDING", and "NOT_INSTALLED".

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all((value in ComputeResourceQuery.VALID_INSTALLATION_STATUS) for value in installation_status):
            raise ApiError("One or more invalid installation status")
        self._update_criteria("installation_status", installation_status)
        return self

    def set_uuid(self, uuid):
        """
        Restricts the search that this query is performed on to the specified uuid.

        Args:
            uuid (list): List of string uuid.

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all(isinstance(value, str) for value in uuid):
            raise ApiError("One or more invalid uuid")
        self._update_criteria("uuid", uuid)
        return self

    def set_os_type(self, os_type):
        """
        Restricts the search that this query is performed on to the specified os type.

        Args:
            os_type (list): List of string os type. Valid values are listed in VALID_OS_TYPE.

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all((value in ComputeResourceQuery.VALID_OS_TYPE) for value in os_type):
            raise ApiError("One or more invalid os type")
        self._update_criteria("os_type", os_type)
        return self

    def set_os_architecture(self, os_architecture):
        """
        Restricts the search that this query is performed on to the specified os architecture.

        Args:
            os_architecture (list): List of string os architecture. Valid values are "32" and "64".

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all((value in ComputeResourceQuery.VALID_OS_ARCHITECTURE) for value in os_architecture):
            raise ApiError("One or more invalid os architecture")
        self._update_criteria("os_architecture", os_architecture)
        return self

    def sort_by(self, key, direction="ASC"):
        """
        Sets the sorting behavior on a query's results.

        Example:
            >>> cb.select(ComputeResource).sort_by("name")

        Args:
            key (str): The key in the schema to sort by.
            direction (str): The sort order, either "ASC" or "DESC".

        Returns:
            ComputeResourceQuery: This instance.
        """
        if direction not in ComputeResourceQuery.VALID_DIRECTIONS:
            raise ApiError("invalid sort direction specified")
        self._sortcriteria = {"field": key, "order": direction}
        return self

    def _build_request(self, from_row, max_rows, add_sort=True):
        """
        Creates the request body for an API call.

        Args:
            from_row (int): The row to start the query at.
            max_rows (int): The maximum number of rows to be returned.
            add_sort (bool): If True(default), the sort criteria will be added as part of the request.

        Returns:
            dict: The complete request body.
        """
        # Fetch 100 rows per page (instead of 10 by default) for better performance
        request = {"criteria": self._criteria, "query": self._query_builder._collapse(), "rows": 100}
        if from_row > 0:
            request["start"] = from_row
        if max_rows >= 0:
            request["rows"] = max_rows
        if add_sort and self._sortcriteria != {}:
            request["sort"] = [self._sortcriteria]
        return request

    def _build_url(self, tail_end):
        """
        Creates the URL to be used for an API call.

        Args:
            tail_end (str): String to be appended to the end of the generated URL.

        Returns:
            str: The complete URL.
        """
        url = self._doc_class.urlobject.format(self._cb.credentials.org_key) + tail_end
        return url

    def _count(self):
        """
        Returns the number of results from the run of this query.

        Returns:
            int: The number of results from the run of this query.
        """
        if self._count_valid:
            return self._total_results

        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()

        self._total_results = result["num_found"]
        self._count_valid = True

        return self._total_results

    def _perform_query(self, from_row=0, max_rows=-1):
        """
        Performs the query and returns the results of the query in an iterable fashion.

        Args:
            from_row (int): The row to start the query at (default 0).
            max_rows (int): The maximum number of rows to be returned (default -1, meaning "all").

        Returns:
            Iterable: The iterated query.
        """
        url = self._build_url("/_search")
        current = from_row
        numrows = 0
        while True:
            request = self._build_request(current, max_rows)
            resp = self._cb.post_object(url, body=request)
            result = resp.json()

            self._total_results = result["num_found"]
            self._count_valid = True

            # Treat a missing or null "results" member as an empty page.
            results = result.get("results") or []
            for item in results:
                yield self._doc_class(self._cb, item["id"], item)
                current += 1
                numrows += 1
                if max_rows > 0 and numrows == max_rows:
                    return

            # An empty page means the server has nothing more to give, even if num_found
            # claims otherwise; stopping here prevents re-requesting the same page forever.
            if not results or current >= self._total_results:
                break

    def _run_async_query(self, context):
        """
        Executed in the background to run an asynchronous query.

        Args:
            context (object): Not used, always None.

        Returns:
            list: Result of the async query, which is then returned by the future.
        """
        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()
        self._total_results = result["num_found"]
        self._count_valid = True
        # Treat a missing or null "results" member as no results.
        results = result.get("results") or []
        return [self._doc_class(self._cb, item["id"], item) for item in results]