Example #1
0
    def __init__(self, doc_class, cb):
        """
        Initialize the Query with an empty query builder and default paging.

        Args:
            doc_class (class): Forwarded to the parent initializer.
            cb (object): Forwarded to the parent initializer.
        """
        # The parent initializer is always given None as its third argument.
        super().__init__(doc_class, cb, None)

        # No sorting or grouping is configured on a fresh query.
        self._sort_by = None
        self._group_by = None
        self._default_args = {}
        self._batch_size = 100
        self._query_builder = QueryBuilder()
Example #2
0
 def __init__(self, doc_class, cb):
     """
     Set up an empty RunHistoryQuery.

     Args:
         doc_class (class): Stored on the instance as ``_doc_class``.
         cb (object): Stored on the instance as ``_cb``.
     """
     # These attributes are assigned before the parent initializer runs.
     self._doc_class, self._cb = doc_class, cb
     self._count_valid = False
     super().__init__()

     # A new query starts with no criteria and no sort order.
     self._criteria = {}
     self._sort = {}
     self._query_builder = QueryBuilder()
Example #3
0
    def __init__(self, doc_class, cb):
        """
        Set up an empty FacetQuery.

        Args:
            doc_class (class): Stored on the instance as ``_doc_class``.
            cb (object): Stored on the instance as ``_cb``.
        """
        # These attributes are assigned before the parent initializer runs.
        self._doc_class, self._cb = doc_class, cb
        self._count_valid = False
        super().__init__()

        # No run selected, no facet fields and no criteria yet.
        self._run_id = None
        self._facet_fields = []
        self._criteria = {}
        self._query_builder = QueryBuilder()
Example #4
0
    def __init__(self, doc_class, cb):
        """
        Set up an empty ResultQuery.

        Args:
            doc_class (class): Stored on the instance as ``_doc_class``.
            cb (object): Stored on the instance as ``_cb``.
        """
        # These attributes are assigned before the parent initializer runs.
        self._doc_class, self._cb = doc_class, cb
        self._count_valid = False
        super().__init__()

        # No run selected yet; results are requested with a batch size of 100.
        self._run_id = None
        self._batch_size = 100
        self._sort = {}
        self._criteria = {}
        self._query_builder = QueryBuilder()
Example #5
0
    def __init__(self, doc_class, cb, query=None):
        """
        Initialize a Query object, optionally seeded from an existing query.

        Args:
            doc_class (class): Forwarded to the parent initializer.
            cb (object): Forwarded to the parent initializer.
            query (Query): Optional existing query. When given, the new query builder is pointed at the
                same ``_query`` object held by ``query``'s builder (it is not duplicated).
        """
        super().__init__(doc_class, cb, query)

        # max batch_size is 5000
        self._batch_size = 100

        # Every instance gets its own builder; only the collected query is carried over.
        self._query_builder = QueryBuilder()
        if query is not None:
            # reuse existing .where(), and_() queries
            self._query_builder._query = query._query_builder._query
    def __init__(self, doc_class, cb):
        """
        Set up an empty DeviceSearchQuery.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        super().__init__(doc_class, cb)

        # All filter containers start out empty.
        self._sortcriteria = {}
        self._exclusions = {}
        self._time_filter = {}
        self._criteria = {}
        self._query_builder = QueryBuilder()
Example #7
0
    def __init__(self, doc_class, cb):
        """
        Set up an empty USBDeviceApprovalQuery.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        # These attributes are assigned before the parent initializer chain runs.
        self._doc_class, self._cb = doc_class, cb
        self._count_valid = False
        # NOTE(review): super(BaseQuery, self) starts the lookup *after* BaseQuery in the MRO, so
        # BaseQuery.__init__ itself is presumably skipped on purpose - confirm before changing.
        super(BaseQuery, self).__init__()

        # No results counted and no criteria collected yet.
        self._total_results = 0
        self._criteria = {}
        self._query_builder = QueryBuilder()
    def __init__(self, doc_class, cb):
        """
        Set up an empty BaseAlertSearchQuery.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        super().__init__(doc_class, cb)

        # No count has been retrieved yet.
        self._count_valid = False
        self._total_results = 0

        # URL template used for bulk updates; "{0}" is a placeholder filled in elsewhere.
        self._bulkupdate_url = "/appservices/v6/orgs/{0}/alerts/workflow/_criteria"

        # All filter containers start out empty.
        self._sortcriteria = {}
        self._time_filter = {}
        self._criteria = {}
        self._query_builder = QueryBuilder()
Example #9
0
    def __init__(self, doc_class, cb):
        """
        Set up an empty ReputationOverrideQuery.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        # These attributes are assigned before the parent initializer runs.
        self._doc_class, self._cb = doc_class, cb
        self._count_valid = False
        super().__init__()

        # No criteria and no sort order until the caller adds them.
        self._sortcriteria = {}
        self._criteria = {}
        self._query_builder = QueryBuilder()
Example #10
0
    def __init__(self, doc_class, cb, device=None):
        """
        Set up an empty VulnerabilityQuery.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
            device (Device): Optional Device object to indicate VulnerabilityQuery is for a specific device
        """
        # These attributes are assigned before the parent initializer chain runs.
        self._doc_class, self._cb = doc_class, cb
        self._count_valid = False
        # NOTE(review): super(BaseQuery, self) starts the lookup *after* BaseQuery in the MRO, so
        # BaseQuery.__init__ itself is presumably skipped on purpose - confirm before changing.
        super(BaseQuery, self).__init__()

        # No results counted, no criteria and no sort order yet.
        self._total_results = 0
        self._sortcriteria = {}
        self._criteria = {}
        self._query_builder = QueryBuilder()

        # Optional scoping: a specific device and/or a vcenter.
        self._vcenter_uuid = None
        self.device = device
Example #11
0
class Query(PaginatedQuery, QueryBuilderSupportMixin, IterableQueryMixin):
    """Represents a prepared query to the Cb Enterprise EDR backend.

    This object is returned as part of a `CbEnterpriseEDRAPI.select`
    operation on models requested from the Cb Enterprise EDR backend. You should not have to create this class yourself.

    The query is not executed on the server until it's accessed, either as an iterator (where it will generate values
    on demand as they're requested) or as a list (where it will retrieve the entire result set and save to a list).
    You can also call the Python built-in ``len()`` on this object to retrieve the total number of items matching
    the query.

    Examples::

    >>> from cbc_sdk.enterprise_edr import CBCloudAPI,Process
    >>> cb = CBCloudAPI()
    >>> query = cb.select(Process)
    >>> query = query.where(process_name="notepad.exe")
    >>> # alternatively:
    >>> query = query.where("process_name:notepad.exe")

    Notes:
        - The slicing operator only supports start and end parameters, but not step. ``[1:-1]`` is legal, but
          ``[1:2:-1]`` is not.
        - You can chain where clauses together to create AND queries; only objects that match all ``where`` clauses
          will be returned.
    """
    def __init__(self, doc_class, cb):
        """
        Initialize the Query with an empty query builder and a batch size of 100.

        Args:
            doc_class (class): Forwarded to the parent initializer.
            cb (object): Forwarded to the parent initializer.
        """
        super(Query, self).__init__(doc_class, cb, None)

        self._query_builder = QueryBuilder()
        self._sort_by = None
        self._group_by = None
        self._batch_size = 100
        self._default_args = {}

    def _get_query_parameters(self):
        """
        Build the argument dictionary for a request from the current query state.

        Returns:
            dict: A copy of the default arguments plus "query", "fields" and, when one can be determined,
                "process_guid".
        """
        args = self._default_args.copy()
        args['query'] = self._query_builder._collapse()
        if self._query_builder._process_guid is not None:
            args["process_guid"] = self._query_builder._process_guid
        # A "process_guid:" term inside the query text takes precedence over the builder's value;
        # everything after the first occurrence of the marker is used as the GUID.
        if 'process_guid:' in args['query']:
            q = args['query'].split('process_guid:', 1)
            args["process_guid"] = q[1]
        args["fields"] = [
            "*", "parent_hash", "parent_name", "process_cmdline",
            "backend_timestamp", "device_external_ip", "device_group",
            "device_internal_ip", "device_os", "device_policy",
            "process_effective_reputation", "process_reputation",
            "process_start_time", "ttp"
        ]

        return args

    def _count(self):
        """
        Post the query and record how many results the server reports as available.

        Returns:
            int: The "num_available" value of the response (0 if absent).

        Raises:
            KeyError: If no process GUID could be determined for the query.
        """
        args = self._get_query_parameters()

        log.debug("args: {}".format(str(args)))

        # The URL template takes the org key and the process GUID.
        result = self._cb.post_object(self._doc_class.urlobject.format(
            self._cb.credentials.org_key, args["process_guid"]),
                                      body=args).json()

        self._total_results = int(result.get('num_available', 0))
        self._count_valid = True

        return self._total_results

    def _validate(self, args):
        """
        Validate the query against the model's validation URL, if the model defines one.

        Note that ``args`` is modified in place: a "q" key may be added and any "sort" key is removed.

        Args:
            args (dict): The query parameters, as produced by ``_get_query_parameters``.

        Raises:
            ApiError: If the response does not report the query as valid.
        """
        if not hasattr(self._doc_class, "validation_url"):
            return

        url = self._doc_class.validation_url.format(
            self._cb.credentials.org_key)

        # The validation request receives the query text under the key "q" as well.
        if args.get('query', False):
            args['q'] = args['query']

        # v2 search sort key does not work with v1 validation
        args.pop('sort', None)

        validated = self._cb.get_object(url, query_parameters=args)

        if not validated.get("valid"):
            raise ApiError("Invalid query: {}: {}".format(
                args, validated["invalid_message"]))

    def _search(self, start=0, rows=0):
        """
        Validate and run the query, yielding raw result items page by page.

        Args:
            start (int): Row to start from; only sent to the server when non-zero.
            rows (int): Maximum number of items to yield; 0 means no limit.

        Yields:
            The items of each response's "results" list, in order.

        Raises:
            ApiError: If validation reports the query as invalid.
            KeyError: If no process GUID could be determined for the query.
        """
        # iterate over total result set, 100 at a time
        args = self._get_query_parameters()
        self._validate(args)

        if start != 0:
            args['start'] = start
        args['rows'] = self._batch_size

        current = start
        numrows = 0

        still_querying = True

        while still_querying:
            url = self._doc_class.urlobject.format(
                self._cb.credentials.org_key, args["process_guid"])
            resp = self._cb.post_object(url, body=args)
            result = resp.json()

            # Refresh the bookkeeping from every page of the response.
            self._total_results = result.get("num_available", 0)
            self._total_segments = result.get("total_segments", 0)
            self._processed_segments = result.get("processed_segments", 0)
            self._count_valid = True

            results = result.get('results', [])

            for item in results:
                yield item
                current += 1
                numrows += 1
                # Stop once the caller's row limit (if any) has been reached.
                if rows and numrows == rows:
                    still_querying = False
                    break

            args[
                'start'] = current + 1  # as of 6/2017, the indexing on the Cb Endpoint Standard backend is still 1-based

            if current >= self._total_results:
                break
            # An empty page before reaching the reported total means the total was too high;
            # correct it and stop rather than requesting the same page again.
            if not results:
                log.debug(
                    "server reported total_results overestimated the number of results for this query by {0}"
                    .format(self._total_results - current))
                log.debug(
                    "resetting total_results for this query to {0}".format(
                        current))
                self._total_results = current
                break
Example #12
0
class VulnerabilityQuery(BaseQuery, QueryBuilderSupportMixin,
                         IterableQueryMixin, AsyncQueryMixin):
    """Represents a query that is used to locate Vulnerability objects."""
    VALID_DEVICE_TYPE = ["WORKLOAD", "ENDPOINT"]
    VALID_OS_TYPE = ["CENTOS", "RHEL", "SLES", "UBUNTU", "WINDOWS"]
    VALID_SEVERITY = ["CRITICAL", "IMPORTANT", "MODERATE", "LOW"]
    VALID_SYNC_TYPE = ["MANUAL", "SCHEDULED"]
    VALID_SYNC_STATUS = [
        "NOT_STARTED", "MATCHED", "ERROR", "NOT_MATCHED", "NOT_SUPPORTED",
        "CANCELLED", "IN_PROGRESS", "ACTIVE", "COMPLETED"
    ]
    VALID_DIRECTIONS = ["ASC", "DESC"]

    def __init__(self, doc_class, cb, device=None):
        """
        Initialize the VulnerabilityQuery.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
            device (Device): Optional Device object to indicate VulnerabilityQuery is for a specific device
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        # NOTE(review): super(BaseQuery, self) starts the lookup *after* BaseQuery in the MRO, so
        # BaseQuery.__init__ itself is presumably skipped on purpose - confirm before changing.
        super(BaseQuery, self).__init__()

        self._query_builder = QueryBuilder()
        self._criteria = {}
        self._sortcriteria = {}
        self._total_results = 0

        self.device = device
        self._vcenter_uuid = None

    @staticmethod
    def _is_missing_number(value):
        """
        Tells whether a numeric criteria value is absent.

        A numeric zero (``0`` or ``0.0``) is a legitimate value to filter on and counts as present; every other
        falsy value (``None``, ``""``, ``False``, empty containers) counts as missing.

        Args:
            value (object): The value to be checked.

        Returns:
            bool: True if the value should be rejected as missing, False otherwise.
        """
        is_numeric_zero = (isinstance(value, (int, float))
                           and not isinstance(value, bool) and value == 0)
        return not value and not is_numeric_zero

    def set_vcenter(self, vcenter_uuid):
        """
        Restricts the vulnerabilities that this query is performed on to the specified vcenter id.

        A falsy ``vcenter_uuid`` leaves the current setting unchanged.

        Args:
            vcenter_uuid (str): vcenter uuid.

        Returns:
            VulnerabilityQuery: This instance.
        """
        if vcenter_uuid:
            self._vcenter_uuid = vcenter_uuid
        return self

    def add_criteria(self, key, value, operator='EQUALS'):
        """
        Restricts the vulnerabilities that this query is performed on to the specified key value pair.

        Args:
            key (str): Property from the vulnerability object
            value (str): Value of the property to filter by
            operator (str): (optional) logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.
        """
        self._update_criteria(key, value, operator)
        return self

    def set_device_type(self, device_type, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified device type.

        Args:
            device_type (str): device type ("WORKLOAD", "ENDPOINT")
            operator (str):  logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.

        Raises:
            ApiError: If the device type is not one of the valid values.
        """
        if device_type not in VulnerabilityQuery.VALID_DEVICE_TYPE:
            raise ApiError("Invalid device type")
        self._update_criteria("device_type", device_type, operator)
        return self

    def set_highest_risk_score(self, highest_risk_score, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified highest_risk_score.

        Args:
            highest_risk_score (double): highest_risk_score; a score of zero is accepted.
            operator (str):  logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.

        Raises:
            ApiError: If no risk score is supplied.
        """
        # A plain truthiness test would wrongly reject a risk score of 0.
        if self._is_missing_number(highest_risk_score):
            raise ApiError("Invalid highest risk score")
        self._update_criteria("highest_risk_score", highest_risk_score,
                              operator)
        return self

    def set_name(self, name, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified name.

        Args:
            name (str): name.
            operator (str):  logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.

        Raises:
            ApiError: If the name is empty.
        """
        if not name:
            raise ApiError("Invalid name")
        self._update_criteria("name", name, operator)
        return self

    def set_os_arch(self, os_arch, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified os_arch.

        Args:
            os_arch (str): os_arch.
            operator (str):  logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.

        Raises:
            ApiError: If the os architecture is empty.
        """
        if not os_arch:
            raise ApiError("Invalid os architecture")
        self._update_criteria("os_arch", os_arch, operator)
        return self

    def set_os_name(self, os_name, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified os_name.

        Args:
            os_name (str): os_name.
            operator (str):  logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.

        Raises:
            ApiError: If the os name is empty.
        """
        if not os_name:
            raise ApiError("Invalid os name")
        self._update_criteria("os_name", os_name, operator)
        return self

    def set_os_type(self, os_type, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified os type.

        Args:
            os_type (str): os type ("CENTOS", "RHEL", "SLES", "UBUNTU", "WINDOWS")
            operator (str):  logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.

        Raises:
            ApiError: If the os type is not one of the valid values.
        """
        if os_type not in VulnerabilityQuery.VALID_OS_TYPE:
            raise ApiError("Invalid os type")
        self._update_criteria("os_type", os_type, operator)
        return self

    def set_os_version(self, os_version, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified os_version.

        Args:
            os_version (str): os_version.
            operator (str):  logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.

        Raises:
            ApiError: If the os version is empty.
        """
        if not os_version:
            raise ApiError("Invalid os version")
        self._update_criteria("os_version", os_version, operator)
        return self

    def set_severity(self, severity, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified severity.

        Args:
            severity (str): severity ("CRITICAL", "IMPORTANT", "MODERATE", "LOW")
            operator (str):  logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.

        Raises:
            ApiError: If the severity is not one of the valid values.
        """
        if severity not in VulnerabilityQuery.VALID_SEVERITY:
            raise ApiError("Invalid severity")
        self._update_criteria("severity", severity, operator)
        return self

    def set_sync_type(self, sync_type, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified sync_type.

        Args:
            sync_type (str): sync_type ("MANUAL", "SCHEDULED")
            operator (str):  logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.

        Raises:
            ApiError: If the sync type is not one of the valid values.
        """
        if sync_type not in VulnerabilityQuery.VALID_SYNC_TYPE:
            raise ApiError("Invalid sync type")
        self._update_criteria("sync_type", sync_type, operator)
        return self

    def set_sync_status(self, sync_status, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified sync_status.

        Args:
            sync_status (str): sync_status ("NOT_STARTED", "MATCHED", "ERROR", "NOT_MATCHED", "NOT_SUPPORTED",
                                               "CANCELLED", "IN_PROGRESS", "ACTIVE", "COMPLETED")
            operator (str):  logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.

        Raises:
            ApiError: If the sync status is not one of the valid values.
        """
        if sync_status not in VulnerabilityQuery.VALID_SYNC_STATUS:
            raise ApiError("Invalid sync status")
        self._update_criteria("sync_status", sync_status, operator)
        return self

    def set_vm_id(self, vm_id, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified vm_id.

        Args:
            vm_id (str): vm_id.
            operator (str):  logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.

        Raises:
            ApiError: If the vm id is empty.
        """
        if not vm_id:
            raise ApiError("Invalid vm id")
        self._update_criteria("vm_id", vm_id, operator)
        return self

    def set_vuln_count(self, vuln_count, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified vuln_count.

        Args:
            vuln_count (str): vuln_count; a count of zero is accepted.
            operator (str):  logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.

        Raises:
            ApiError: If no count is supplied.
        """
        # A plain truthiness test would wrongly reject a count of 0.
        if self._is_missing_number(vuln_count):
            raise ApiError("Invalid vuln count")
        self._update_criteria("vuln_count", vuln_count, operator)
        return self

    def set_last_sync_ts(self, last_sync_ts, operator):
        """
        Restricts the vulnerabilities that this query is performed on to the specified last_sync_ts.

        Args:
            last_sync_ts (str): last_sync_ts.
            operator (str):  logic operator to apply to property value.

        Returns:
            VulnerabilityQuery: This instance.

        Raises:
            ApiError: If the timestamp is empty.
        """
        if not last_sync_ts:
            raise ApiError("Invalid last_sync_ts")
        self._update_criteria("last_sync_ts", last_sync_ts, operator)
        return self

    # Including custom update_criteria, because the format is different:
    #   "criteria": {
    #       "property": {
    #           "value": "<str>",
    #           "operator": "<str>"
    #       }
    #   }

    def _update_criteria(self, key, value, operator, overwrite=False):
        """
        Sets the criteria item for a property, unless one is already present.

        Args:
            key (str): The property for the criteria item to be set.
            value (can be different types): the value for the criteria
            operator (str): any of the following types:
                - EQUALS, NOT_EQUALS, GREATER_THAN, LESS_THAN, IS_NULL,
                  IS_NOT_NULL, IS_TRUE, IS_FALSE, IN, NOT_IN, LIKE
            overwrite (bool): Overwrite the existing criteria for specified key

            The values are not lists, so if overwrite is not allowed, disregard the change!
        """
        if self._criteria.get(key, None) is None or overwrite:
            self._criteria[key] = dict(value=value, operator=operator)

    def _build_request(self, from_row, max_rows, add_sort=True):
        """
        Creates the request body for an API call.

        Args:
            from_row (int): The row to start the query at.
            max_rows (int): The maximum number of rows to be returned; a negative value keeps the page size of 100.
            add_sort (bool): If True(default), the sort criteria will be added as part of the request.

        Returns:
            dict: The complete request body.
        """
        # Fetch 100 rows per page (instead of 10 by default) for better performance
        request = {
            "criteria": self._criteria,
            "query": self._query_builder._collapse(),
            "rows": 100
        }
        if from_row > 0:
            request["start"] = from_row
        if max_rows >= 0:
            request["rows"] = max_rows
        if add_sort and self._sortcriteria:
            request["sort"] = [self._sortcriteria]
        return request

    def _build_url(self, tail_end):
        """
        Creates the URL to be used for an API call.

        A device, when set, takes precedence over a vcenter restriction.

        Args:
            tail_end (str): str to be appended to the end of the generated URL.

        Returns:
            str: The complete URL.
        """
        if self.device:
            additional = f"/devices/{self.device._model_unique_id}/vulnerabilities"
        else:
            additional = "/devices/vulnerabilities"
            if self._vcenter_uuid:
                additional = f"/vcenters/{self._vcenter_uuid}" + additional

        url = self._doc_class.urlobject.format(
            self._cb.credentials.org_key) + additional + tail_end
        return url

    def _count(self):
        """
        Returns the number of results from the run of this query.

        The value is cached; the server is only contacted when no valid count is available.

        Returns:
            int: The number of results from the run of this query.
        """
        if self._count_valid:
            return self._total_results
        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()

        self._total_results = result["num_found"]
        self._count_valid = True

        return self._total_results

    def _perform_query(self, from_row=0, max_rows=-1):
        """
        Performs the query and returns the results of the query in an iterable fashion.

        Args:
            from_row (int): The row to start the query at (default 0).
            max_rows (int): The maximum number of rows to be returned (default -1, meaning "all").

        Returns:
            Iterable: The iterated query.
        """
        url = self._build_url("/_search?dataForExport=true")
        current = from_row
        numrows = 0
        while True:
            request = self._build_request(current, max_rows)
            resp = self._cb.post_object(url, body=request)
            result = resp.json()

            self._total_results = result["num_found"]
            self._count_valid = True

            results = result.get("results", [])
            for item in results:
                yield self._doc_class(self._cb,
                                      item.get('vuln_info',
                                               {}).get('cve_id', None),
                                      initial_data=item)
                current += 1
                numrows += 1

                if max_rows > 0 and numrows == max_rows:
                    return

            # Stop when every reported row has been consumed. Also stop on an empty page: the position
            # would not advance, so asking again would repeat the same request forever.
            if not results or current >= self._total_results:
                break

    def _run_async_query(self, context):
        """
        Executed in the background to run an asynchronous query.

        Args:
            context (object): Not used, always None.

        Returns:
            list: Result of the async query, which is then returned by the future.
        """
        url = self._build_url("/_search?dataForExport=true")
        request = self._build_request(0, -1)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()
        self._total_results = result["num_found"]
        self._count_valid = True
        results = result.get("results", [])
        return [
            self._doc_class(self._cb,
                            item.get('vuln_info', {}).get('cve_id', None),
                            initial_data=item) for item in results
        ]

    def sort_by(self, key, direction="ASC"):
        """
        Sets the sorting behavior on a query's results.

        Example:
            >>> cb.select(Vulnerability).sort_by("status")

        Args:
            key (str): The key in the schema to sort by.
            direction (str): The sort order, either "ASC" or "DESC".

        Returns:
            VulnerabilityQuery: This instance.

        Raises:
            ApiError: If an invalid direction value is passed.
        """
        if direction not in VulnerabilityQuery.VALID_DIRECTIONS:
            raise ApiError("invalid sort direction specified")
        self._sortcriteria = {"field": key, "order": direction}
        return self
Example #13
0
class FacetQuery(BaseQuery, QueryBuilderSupportMixin, IterableQueryMixin, CriteriaBuilderSupportMixin):
    """Represents a query that receives facet information from a LiveQuery run."""
    def __init__(self, doc_class, cb):
        """Initialize a FacetQuery object.

        Arguments:
            doc_class (class): The model class that will be returned by this query.
            cb (object): Reference to API object used to communicate with the server.
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        super(FacetQuery, self).__init__()

        self._query_builder = QueryBuilder()
        self._facet_fields = []
        self._criteria = {}
        self._run_id = None

    def facet_field(self, field):
        """Sets the facet fields to be received by this query.

        Arguments:
            field (str or [str]): Field(s) to be received.

        Returns:
            FacetQuery that will receive field(s) facet_field.

        Example:

        >>> cb.select(ResultFacet).run_id(my_run).facet_field(["device.policy_name", "device.os"])
        """
        if isinstance(field, str):
            self._facet_fields.append(field)
        else:
            self._facet_fields.extend(field)
        return self

    def set_device_ids(self, device_ids):
        """Sets the device.id criteria filter.

        Arguments:
            device_ids ([int]): Device IDs to filter on.

        Returns:
            The FacetQuery with specified device.id.

        Raises:
            ApiError: If any of the device IDs is not an integer.
        """
        if not all(isinstance(device_id, int) for device_id in device_ids):
            raise ApiError("One or more invalid device IDs")
        self._update_criteria("device.id", device_ids)
        return self

    def set_device_names(self, device_names):
        """Sets the device.name criteria filter.

        Arguments:
            device_names ([str]): Device names to filter on.

        Returns:
            The FacetQuery with specified device.name.

        Raises:
            ApiError: If any of the device names is not a string.
        """
        if not all(isinstance(name, str) for name in device_names):
            raise ApiError("One or more invalid device names")
        self._update_criteria("device.name", device_names)
        return self

    def set_device_os(self, device_os):
        """Sets the device.os criteria.

        Arguments:
            device_os ([str]): Device OS's to filter on.

        Returns:
            The FacetQuery object with specified device_os.

        Raises:
            ApiError: If any of the OS values is not a string.

        Note:
            Device OS's can be one or more of ["WINDOWS", "MAC", "LINUX"].
        """
        if not all(isinstance(os_name, str) for os_name in device_os):
            # The message names the parameter the caller actually passed (device_os).
            raise ApiError("device_os must be a list of strings, including"
                           " 'WINDOWS', 'MAC', and/or 'LINUX'")
        self._update_criteria("device.os", device_os)
        return self

    def set_policy_ids(self, policy_ids):
        """Sets the device.policy_id criteria.

        Arguments:
            policy_ids ([int]): Device policy ID's to filter on.

        Returns:
            The FacetQuery object with specified policy_ids.

        Raises:
            ApiError: If any of the policy IDs is not an integer.
        """
        if not all(isinstance(policy_id, int) for policy_id in policy_ids):
            raise ApiError("policy_ids must be a list of integers.")
        self._update_criteria("device.policy_id", policy_ids)
        return self

    def set_policy_names(self, policy_names):
        """Sets the device.policy_name criteria.

        Arguments:
            policy_names ([str]): Device policy names to filter on.

        Returns:
            The FacetQuery object with specified policy_names.

        Raises:
            ApiError: If any of the policy names is not a string.
        """
        if not all(isinstance(name, str) for name in policy_names):
            raise ApiError("policy_names must be a list of strings.")
        self._update_criteria("device.policy_name", policy_names)
        return self

    def set_statuses(self, statuses):
        """Sets the status criteria.

        Arguments:
            statuses ([str]): Query statuses to filter on.

        Returns:
            The FacetQuery object with specified statuses.

        Raises:
            ApiError: If any of the statuses is not a string.
        """
        if not all(isinstance(status, str) for status in statuses):
            raise ApiError("statuses must be a list of strings.")
        self._update_criteria("status", statuses)
        return self

    def run_id(self, run_id):
        """Sets the run ID to query results for.

        Arguments:
            run_id (int): The run ID to retrieve results for.

        Returns:
            FacetQuery object with specified run_id.

        Example:
        >>> cb.select(ResultFacet).run_id(my_run)
        """
        self._run_id = run_id
        return self

    def _build_request(self, rows):
        """Creates the request body for an API call.

        Arguments:
            rows (int): Number of rows to ask for; 0 leaves "rows" out of the request.

        Returns:
            dict: The request body, with "criteria" included only when some have been set.
        """
        terms = {"fields": self._facet_fields}
        if rows != 0:
            terms["rows"] = rows
        request = {"query": self._query_builder._collapse(), "terms": terms}
        if self._criteria:
            request["criteria"] = self._criteria
        return request

    def _perform_query(self, rows=0):
        """Performs the query and yields one model object per entry of the response's "terms" list.

        Arguments:
            rows (int): Number of rows to ask for; 0 leaves "rows" out of the request.

        Returns:
            Iterable: The iterated query.

        Raises:
            ApiError: If no run ID has been set; raised when iteration starts.
        """
        if self._run_id is None:
            raise ApiError("Can't retrieve results without a run ID")

        url = self._doc_class.urlobject.format(
            self._cb.credentials.org_key, self._run_id
        )
        request = self._build_request(rows)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()
        for item in result.get("terms", []):
            yield self._doc_class(self._cb, item)
class DeviceSearchQuery(BaseQuery, QueryBuilderSupportMixin,
                        CriteriaBuilderSupportMixin, IterableQueryMixin,
                        AsyncQueryMixin):
    """
    Represents a query that is used to locate Device objects.

    Criteria, exclusions, a last-contact time filter and a sort order are collected through the chainable
    ``set_*`` / ``sort_by`` methods.  The collected search definition is sent to the ``/_search`` endpoint
    when the query runs, and is also reused for CSV download and for bulk device actions.
    """
    VALID_OS = ["WINDOWS", "ANDROID", "MAC", "IOS", "LINUX", "OTHER"]
    VALID_STATUSES = [
        "PENDING", "REGISTERED", "UNINSTALLED", "DEREGISTERED", "ACTIVE",
        "INACTIVE", "ERROR", "ALL", "BYPASS_ON", "BYPASS", "QUARANTINE",
        "SENSOR_OUTOFDATE", "DELETED", "LIVE"
    ]
    VALID_PRIORITIES = ["LOW", "MEDIUM", "HIGH", "MISSION_CRITICAL"]
    VALID_DIRECTIONS = ["ASC", "DESC"]
    VALID_DEPLOYMENT_TYPES = ["ENDPOINT", "WORKLOAD"]

    def __init__(self, doc_class, cb):
        """
        Initialize the DeviceSearchQuery.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        super(DeviceSearchQuery, self).__init__()

        self._query_builder = QueryBuilder()
        self._criteria = {}
        self._time_filter = {}
        self._exclusions = {}
        self._sortcriteria = {}
        # -1 means "no explicit row limit set"; see set_max_rows() and _build_request().
        self.max_rows = -1

    def _update_exclusions(self, key, newlist):
        """
        Updates the exclusion criteria being collected for a query.

        Assumes the specified criteria item is defined as a list; the list passed in will be set as the value for this
        criteria item, or appended to the existing one if there is one.

        Args:
            key (str): The key for the criteria item to be set.
            newlist (list): List of values to be set for the criteria item.
        """
        oldlist = self._exclusions.get(key, [])
        # Concatenation builds a new list, so the caller's list is never modified.
        self._exclusions[key] = oldlist + newlist

    def set_ad_group_ids(self, ad_group_ids):
        """
        Restricts the devices that this query is performed on to the specified AD group IDs.

        Args:
            ad_group_ids (list): List of AD group IDs to restrict the search to.

        Returns:
            DeviceSearchQuery: This instance.

        Raises:
            ApiError: If invalid (non-int) values are passed in the list.
        """
        if not all(
                isinstance(ad_group_id, int) for ad_group_id in ad_group_ids):
            raise ApiError("One or more invalid AD group IDs")
        self._update_criteria("ad_group_id", ad_group_ids)
        return self

    def set_device_ids(self, device_ids):
        """
        Restricts the devices that this query is performed on to the specified device IDs.

        Args:
            device_ids (list): List of device IDs to restrict the search to.

        Returns:
            DeviceSearchQuery: This instance.

        Raises:
            ApiError: If invalid (non-int) values are passed in the list.
        """
        if not all(isinstance(device_id, int) for device_id in device_ids):
            raise ApiError("One or more invalid device IDs")
        self._update_criteria("id", device_ids)
        return self

    def set_last_contact_time(self, *args, **kwargs):
        """
        Restricts the devices that this query is performed on to the specified last contact time.

        Either both "start" and "end" or just "range" must be given.  Start and end values that are not
        strings are converted with their ``isoformat()`` method.

        Args:
            *args (list): Not used, retained for compatibility.
            **kwargs (dict): Keyword arguments to this function.  The critical ones are "start" (the start time),
                             "end" (the end time), and "range" (the range value).

        Returns:
            DeviceSearchQuery: This instance.

        Raises:
            ApiError: If an invalid combination of keyword parameters are specified.
        """
        if kwargs.get("start", None) and kwargs.get("end", None):
            if kwargs.get("range", None):
                raise ApiError(
                    "cannot specify range= in addition to start= and end=")
            stime = kwargs["start"]
            if not isinstance(stime, str):
                stime = stime.isoformat()
            etime = kwargs["end"]
            if not isinstance(etime, str):
                etime = etime.isoformat()
            self._time_filter = {"start": stime, "end": etime}
        elif kwargs.get("range", None):
            # Reached with a range plus at most one of start/end; any of those alongside range is an error.
            if kwargs.get("start", None) or kwargs.get("end", None):
                raise ApiError(
                    "cannot specify start= or end= in addition to range=")
            self._time_filter = {"range": kwargs["range"]}
        else:
            raise ApiError("must specify either start= and end= or range=")
        return self

    def set_os(self, operating_systems):
        """
        Restricts the devices that this query is performed on to the specified operating systems.

        Args:
            operating_systems (list): List of operating systems to restrict search to.  Valid values in this list are
                                      "WINDOWS", "ANDROID", "MAC", "IOS", "LINUX", and "OTHER".

        Returns:
            DeviceSearchQuery: This instance.

        Raises:
            ApiError: If invalid operating system values are passed in the list.
        """
        if not all((osval in DeviceSearchQuery.VALID_OS)
                   for osval in operating_systems):
            raise ApiError("One or more invalid operating systems")
        self._update_criteria("os", operating_systems)
        return self

    def set_policy_ids(self, policy_ids):
        """
        Restricts the devices that this query is performed on to the specified policy IDs.

        Args:
            policy_ids (list): List of policy IDs to restrict the search to.

        Returns:
            DeviceSearchQuery: This instance.

        Raises:
            ApiError: If invalid (non-int) values are passed in the list.
        """
        if not all(isinstance(policy_id, int) for policy_id in policy_ids):
            raise ApiError("One or more invalid policy IDs")
        self._update_criteria("policy_id", policy_ids)
        return self

    def set_status(self, statuses):
        """
        Restricts the devices that this query is performed on to the specified status values.

        Args:
            statuses (list): List of statuses to restrict search to.  Valid values in this list are "PENDING",
                             "REGISTERED", "UNINSTALLED", "DEREGISTERED", "ACTIVE", "INACTIVE", "ERROR", "ALL",
                             "BYPASS_ON", "BYPASS", "QUARANTINE", "SENSOR_OUTOFDATE", "DELETED", and "LIVE".

        Returns:
            DeviceSearchQuery: This instance.

        Raises:
            ApiError: If invalid status values are passed in the list.
        """
        if not all(
            (stat in DeviceSearchQuery.VALID_STATUSES) for stat in statuses):
            raise ApiError("One or more invalid status values")
        self._update_criteria("status", statuses)
        return self

    def set_target_priorities(self, target_priorities):
        """
        Restricts the devices that this query is performed on to the specified target priority values.

        Args:
            target_priorities (list): List of priorities to restrict search to.  Valid values in this list are "LOW",
                                      "MEDIUM", "HIGH", and "MISSION_CRITICAL".

        Returns:
            DeviceSearchQuery: This instance.

        Raises:
            ApiError: If invalid priority values are passed in the list.
        """
        if not all((prio in DeviceSearchQuery.VALID_PRIORITIES)
                   for prio in target_priorities):
            raise ApiError("One or more invalid target priority values")
        self._update_criteria("target_priority", target_priorities)
        return self

    def set_exclude_sensor_versions(self, sensor_versions):
        """
        Restricts the devices that this query is performed on to exclude specified sensor versions.

        Args:
            sensor_versions (list): List of sensor versions to be excluded.

        Returns:
            DeviceSearchQuery: This instance.

        Raises:
            ApiError: If invalid (non-string) values are passed in the list.
        """
        if not all(isinstance(v, str) for v in sensor_versions):
            raise ApiError("One or more invalid sensor versions")
        self._update_exclusions("sensor_version", sensor_versions)
        return self

    def sort_by(self, key, direction="ASC"):
        """
        Sets the sorting behavior on a query's results.

        Calling this again replaces the previously set sort order.

        Example:
            >>> cb.select(Device).sort_by("status")

        Args:
            key (str): The key in the schema to sort by.
            direction (str): The sort order, either "ASC" or "DESC".

        Returns:
            DeviceSearchQuery: This instance.

        Raises:
            ApiError: If an invalid direction value is passed.
        """
        if direction not in DeviceSearchQuery.VALID_DIRECTIONS:
            raise ApiError("invalid sort direction specified")
        self._sortcriteria = {"field": key, "order": direction}
        return self

    def set_deployment_type(self, deployment_type):
        """
        Restricts the devices that this query is performed on to the specified deployment types.

        Args:
            deployment_type (list): List of deployment types to restrict search to.  Valid values in this list are
                                    "ENDPOINT" and "WORKLOAD".

        Returns:
            DeviceSearchQuery: This instance.

        Raises:
            ApiError: If invalid deployment type values are passed in the list.
        """
        # The loop variable is deliberately not called "type", to avoid shadowing the builtin.
        if not all((entry in DeviceSearchQuery.VALID_DEPLOYMENT_TYPES)
                   for entry in deployment_type):
            raise ApiError("invalid deployment_type specified")
        self._update_criteria("deployment_type", deployment_type)
        return self

    def set_max_rows(self, max_rows):
        """
        Sets the max number of devices to fetch in a singular query

        Args:
            max_rows (integer): Max number of devices

        Returns:
            DeviceSearchQuery: This instance.

        Raises:
            ApiError: If rows is negative or greater than 10000
        """
        if max_rows < 0 or max_rows > 10000:
            raise ApiError("Max rows must be between 0 and 10000")
        self.max_rows = max_rows
        return self

    def _build_request(self, from_row, max_rows):
        """
        Creates the request body for an API call.

        Args:
            from_row (int): The row to start the query at.  Only sent as "start" when greater than 1.
            max_rows (int): The maximum number of rows to be returned.  A negative value falls back to the limit
                            set with set_max_rows(), if any.

        Returns:
            dict: The complete request body.
        """
        # Work on a shallow copy so that building a request never writes "last_contact_time" back into
        # the criteria stored on the query.
        criteria = dict(self._criteria)
        if self._time_filter:
            criteria["last_contact_time"] = self._time_filter
        request = {"criteria": criteria, "exclusions": self._exclusions}
        request["query"] = self._query_builder._collapse()
        if from_row > 1:
            request["start"] = from_row
        if max_rows >= 0:
            request["rows"] = max_rows
        elif self.max_rows >= 0:
            request["rows"] = self.max_rows
        if self._sortcriteria != {}:
            request["sort"] = [self._sortcriteria]
        return request

    def _build_url(self, tail_end):
        """
        Creates the URL to be used for an API call.

        Args:
            tail_end (str): String to be appended to the end of the generated URL.

        Returns:
            str: The complete URL.
        """
        url = self._doc_class.urlobject.format(
            self._cb.credentials.org_key) + tail_end
        return url

    def _count(self):
        """
        Returns the number of results from the run of this query.

        The value is cached after the first successful request.

        Returns:
            int: The number of results from the run of this query.
        """
        if self._count_valid:
            return self._total_results

        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()

        self._total_results = result["num_found"]
        self._count_valid = True

        return self._total_results

    def _perform_query(self, from_row=1, max_rows=-1):
        """
        Performs the query and returns the results of the query in an iterable fashion.

        Device v6 API uses base 1 instead of 0.

        Args:
            from_row (int): The row to start the query at (default 1).
            max_rows (int): The maximum number of rows to be returned (default -1, meaning "all").

        Returns:
            Iterable: The iterated query.
        """
        url = self._build_url("/_search")
        current = from_row
        numrows = 0
        while True:
            request = self._build_request(current, max_rows)
            resp = self._cb.post_object(url, body=request)
            result = resp.json()

            self._total_results = result["num_found"]
            self._count_valid = True

            results = result.get("results", [])
            if not results:
                # An empty page cannot advance "current"; stop instead of re-requesting the same page forever.
                return

            for item in results:
                yield self._doc_class(self._cb, item["id"], item)
                current += 1
                numrows += 1

                if max_rows > 0 and numrows == max_rows:
                    return

            # NOTE(review): with 1-based rows, "current" is the next row to fetch, so stopping at
            # current == num_found may skip the final row - confirm against the API's "start" semantics.
            if current >= self._total_results:
                return

    def _run_async_query(self, context):
        """
        Executed in the background to run an asynchronous query. Must be implemented in any inheriting classes.

        Args:
            context (object): The context returned by _init_async_query. May be None.

        Returns:
            Any: Result of the async query, which is then returned by the future.
        """
        url = self._build_url("/_search")
        self._total_results = 0
        self._count_valid = False
        output = []
        while not self._count_valid or len(output) < self._total_results:
            # NOTE(review): this passes a 0-based count as from_row, whereas _perform_query counts from 1 -
            # confirm which convention the endpoint expects.
            request = self._build_request(len(output), -1)
            resp = self._cb.post_object(url, body=request)
            result = resp.json()

            if not self._count_valid:
                self._total_results = result["num_found"]
                self._count_valid = True

            results = result.get("results", [])
            if not results:
                # No progress is possible on an empty page; return what has been collected so far.
                break
            output += [
                self._doc_class(self._cb, item["id"], item) for item in results
            ]
        return output

    def download(self):
        """
        Uses the query parameters that have been set to download all device listings in CSV format.

        Only the status, AD group ID, policy ID and target priority criteria, the query string and the sort
        order are forwarded as URL query parameters.

        Example:
            >>> cb.select(Device).set_status(["ALL"]).download()

        Returns:
            str: The CSV raw data as returned from the server.

        Raises:
            ApiError: If status values have not been set before calling this function.
        """
        tmp = self._criteria.get("status", [])
        if not tmp:
            raise ApiError("at least one status must be specified to download")
        query_params = {"status": ",".join(tmp)}
        tmp = self._criteria.get("ad_group_id", [])
        if tmp:
            # IDs are validated as ints by the setters, so convert before joining.
            query_params["ad_group_id"] = ",".join([str(t) for t in tmp])
        tmp = self._criteria.get("policy_id", [])
        if tmp:
            query_params["policy_id"] = ",".join([str(t) for t in tmp])
        tmp = self._criteria.get("target_priority", [])
        if tmp:
            query_params["target_priority"] = ",".join(tmp)
        tmp = self._query_builder._collapse()
        if tmp:
            query_params["query_string"] = tmp
        if self._sortcriteria:
            query_params["sort_field"] = self._sortcriteria["field"]
            query_params["sort_order"] = self._sortcriteria["order"]
        url = self._build_url("/_search/download")
        return self._cb.get_raw_data(url, query_params)

    def _bulk_device_action(self, action_type, options=None):
        """
        Perform a bulk action on all devices matching the current search criteria.

        Args:
            action_type (str): The action type to be performed.
            options (dict): Any options for the bulk device action.

        Returns:
            str: The JSON output from the request.
        """
        request = {
            "action_type": action_type,
            "search": self._build_request(0, -1)
        }
        if options:
            request["options"] = options
        return self._cb._raw_device_action(request)

    def background_scan(self, scan):
        """
        Set the background scan option for the specified devices.

        Args:
            scan (bool): True to turn background scan on, False to turn it off.

        Returns:
            str: The JSON output from the request.
        """
        return self._bulk_device_action("BACKGROUND_SCAN",
                                        self._cb._action_toggle(scan))

    def bypass(self, enable):
        """
        Set the bypass option for the specified devices.

        Args:
            enable (bool): True to enable bypass, False to disable it.

        Returns:
            str: The JSON output from the request.
        """
        return self._bulk_device_action("BYPASS",
                                        self._cb._action_toggle(enable))

    def delete_sensor(self):
        """
        Delete the specified sensor devices.

        Returns:
            str: The JSON output from the request.
        """
        return self._bulk_device_action("DELETE_SENSOR")

    def uninstall_sensor(self):
        """
        Uninstall the specified sensor devices.

        Returns:
            str: The JSON output from the request.
        """
        return self._bulk_device_action("UNINSTALL_SENSOR")

    def quarantine(self, enable):
        """
        Set the quarantine option for the specified devices.

        Args:
            enable (bool): True to enable quarantine, False to disable it.

        Returns:
            str: The JSON output from the request.
        """
        return self._bulk_device_action("QUARANTINE",
                                        self._cb._action_toggle(enable))

    def update_policy(self, policy_id):
        """
        Set the current policy for the specified devices.

        Args:
            policy_id (int): ID of the policy to set for the devices.

        Returns:
            str: The JSON output from the request.
        """
        return self._bulk_device_action("UPDATE_POLICY",
                                        {"policy_id": policy_id})

    def update_sensor_version(self, sensor_version):
        """
        Update the sensor version for the specified devices.

        Args:
            sensor_version (dict): New version properties for the sensor.

        Returns:
            str: The JSON output from the request.
        """
        return self._bulk_device_action("UPDATE_SENSOR_VERSION",
                                        {"sensor_version": sensor_version})
Example #15
0
class ResultQuery(BaseQuery, QueryBuilderSupportMixin, IterableQueryMixin, CriteriaBuilderSupportMixin):
    """Represents a query that retrieves results from a LiveQuery run.

    A run ID must be set with run_id() before the query can be counted or iterated.
    """
    def __init__(self, doc_class, cb):
        """Initialize a ResultQuery object.

        Arguments:
            doc_class (class): The model class that will be returned by this query.
            cb: Reference to the API object used to communicate with the server.
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        super(ResultQuery, self).__init__()

        self._query_builder = QueryBuilder()
        self._criteria = {}
        self._sort = {}
        self._batch_size = 100
        self._run_id = None

    def set_device_ids(self, device_ids):
        """Sets the device.id criteria filter.

        Arguments:
            device_ids ([int]): Device IDs to filter on.

        Returns:
            The ResultQuery with specified device.id.

        Raises:
            ApiError: If any entry is not an int.
        """
        if not all(isinstance(device_id, int) for device_id in device_ids):
            raise ApiError("One or more invalid device IDs")
        self._update_criteria("device.id", device_ids)
        return self

    def set_device_names(self, device_names):
        """Sets the device.name criteria filter.

        Arguments:
            device_names ([str]): Device names to filter on.

        Returns:
            The ResultQuery with specified device.name.

        Raises:
            ApiError: If any entry is not a string.
        """
        if not all(isinstance(name, str) for name in device_names):
            raise ApiError("One or more invalid device names")
        self._update_criteria("device.name", device_names)
        return self

    def set_device_os(self, device_os):
        """Sets the device.os criteria.

        Arguments:
            device_os ([str]): Device OS's to filter on.

        Returns:
            The ResultQuery object with specified device_os.

        Raises:
            ApiError: If any entry is not a string.

        Note:
            Device OS's can be one or more of ["WINDOWS", "MAC", "LINUX"].
        """
        # Only the element type is checked here; the values themselves are not validated.
        if not all(isinstance(os_name, str) for os_name in device_os):
            raise ApiError("device_os must be a list of strings, including"
                           " 'WINDOWS', 'MAC', and/or 'LINUX'")
        self._update_criteria("device.os", device_os)
        return self

    def set_policy_ids(self, policy_ids):
        """Sets the device.policy_id criteria.

        Arguments:
            policy_ids ([int]): Device policy ID's to filter on.

        Returns:
            The ResultQuery object with specified policy_ids.

        Raises:
            ApiError: If any entry is not an int.
        """
        if not all(isinstance(policy_id, int) for policy_id in policy_ids):
            raise ApiError("policy_ids must be a list of integers.")
        self._update_criteria("device.policy_id", policy_ids)
        return self

    def set_policy_names(self, policy_names):
        """Sets the device.policy_name criteria.

        Arguments:
            policy_names ([str]): Device policy names to filter on.

        Returns:
            The ResultQuery object with specified policy_names.

        Raises:
            ApiError: If any entry is not a string.
        """
        if not all(isinstance(name, str) for name in policy_names):
            raise ApiError("policy_names must be a list of strings.")
        self._update_criteria("device.policy_name", policy_names)
        return self

    def set_statuses(self, statuses):
        """Sets the status criteria.

        Arguments:
            statuses ([str]): Query statuses to filter on.

        Returns:
            The ResultQuery object with specified statuses.

        Raises:
            ApiError: If any entry is not a string.
        """
        if not all(isinstance(status, str) for status in statuses):
            raise ApiError("statuses must be a list of strings.")
        self._update_criteria("status", statuses)
        return self

    def sort_by(self, key, direction="ASC"):
        """Sets the sorting behavior on a query's results.

        Arguments:
            key (str): The key in the schema to sort by.
            direction (str): The sort order, either "ASC" or "DESC".

        Returns:
            ResultQuery object with specified sorting key and order.

        Example:

        >>> cb.select(Result).run_id(my_run).where(username="******").sort_by("uid")
        """
        self._sort.update({"field": key, "order": direction})
        return self

    def run_id(self, run_id):
        """Sets the run ID to query results for.

        Arguments:
            run_id (int): The run ID to retrieve results for.

        Returns:
            ResultQuery object with specified run_id.

        Example:

        >>> cb.select(Result).run_id(my_run)
        """
        self._run_id = run_id
        return self

    def _build_request(self, start, rows):
        """Creates the request body for an API call.

        Arguments:
            start (int): The row to start the query at.
            rows (int): The number of rows to request; 0 leaves "rows" out of the request.

        Returns:
            dict: The request body, including criteria and sort order only when they are set.
        """
        request = {"start": start, "query": self._query_builder._collapse()}

        if rows != 0:
            request["rows"] = rows
        if self._criteria:
            request["criteria"] = self._criteria
        if self._sort:
            request["sort"] = [self._sort]

        return request

    def _count(self):
        """Returns the number of results for this query, caching the value after the first request.

        Returns:
            int: The "num_found" value reported by the server.

        Raises:
            ApiError: If no run ID has been set.
        """
        if self._count_valid:
            return self._total_results

        if self._run_id is None:
            raise ApiError("Can't retrieve count without a run ID")

        url = self._doc_class.urlobject.format(
            self._cb.credentials.org_key, self._run_id
        )
        request = self._build_request(start=0, rows=0)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()

        self._total_results = result["num_found"]
        self._count_valid = True

        return self._total_results

    def _perform_query(self, start=0, rows=0):
        """Performs the query and yields the results page by page.

        Arguments:
            start (int): The row to start the query at (default 0).
            rows (int): The maximum number of results to yield (default 0, meaning no limit).

        Yields:
            An instance of the query's document class for each returned result.

        Raises:
            ApiError: If no run ID has been set.
        """
        if self._run_id is None:
            raise ApiError("Can't retrieve results without a run ID")

        url = self._doc_class.urlobject.format(
            self._cb.credentials.org_key, self._run_id
        )
        current = start
        numrows = 0
        while True:
            request = self._build_request(current, rows)
            resp = self._cb.post_object(url, body=request)
            result = resp.json()

            # The total is clamped so that paging never continues past MAX_RESULTS_LIMIT.
            self._total_results = result["num_found"]
            if self._total_results > MAX_RESULTS_LIMIT:
                self._total_results = MAX_RESULTS_LIMIT
            self._count_valid = True

            results = result.get("results", [])
            if not results:
                # An empty page cannot advance "current"; stop instead of re-requesting the same page forever.
                return

            for item in results:
                yield self._doc_class(self._cb, item)
                current += 1
                numrows += 1

                if rows and numrows == rows:
                    return

            if current >= self._total_results:
                return
Example #16
0
class RunHistoryQuery(BaseQuery, QueryBuilderSupportMixin, IterableQueryMixin, CriteriaBuilderSupportMixin):
    """Represents a query that retrieves historic LiveQuery runs."""
    def __init__(self, doc_class, cb):
        """Initialize a RunHistoryQuery object.

        Arguments:
            doc_class (class): The model class that will be returned by this query.
            cb: Reference to the API object used to communicate with the server.
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        super(RunHistoryQuery, self).__init__()
        self._query_builder = QueryBuilder()
        self._sort = {}
        self._criteria = {}

    def set_template_ids(self, template_ids):
        """Sets the template_id criteria filter.

        Arguments:
            template_ids ([str]): Template IDs to filter on.

        Returns:
            The RunHistoryQuery with specified template_id.

        Raises:
            ApiError: If any entry is not a string.
        """
        if not all(isinstance(template_id, str) for template_id in template_ids):
            raise ApiError("One or more invalid template IDs")
        self._update_criteria("template_id", template_ids)
        return self

    def sort_by(self, key, direction="ASC"):
        """Sets the sorting behavior on a query's results.

        Arguments:
            key (str): The key in the schema to sort by.
            direction (str): The sort order, either "ASC" or "DESC".

        Returns:
            RunHistoryQuery object with specified sorting key and order.

        Example:

        >>> history_query.sort_by("id", direction="DESC")  # illustrative key; history_query is a RunHistoryQuery
        """
        self._sort.update({"field": key, "order": direction})
        return self

    def _build_request(self, start, rows):
        """Creates the request body for an API call.

        Arguments:
            start (int): The row to start the query at.
            rows (int): The number of rows to request; 0 leaves "rows" out of the request.

        Returns:
            dict: The request body, including criteria and sort order only when they are set.
        """
        request = {"start": start}

        if self._query_builder:
            request["query"] = self._query_builder._collapse()
        if rows != 0:
            request["rows"] = rows
        if self._criteria:
            request["criteria"] = self._criteria
        if self._sort:
            request["sort"] = [self._sort]

        return request

    def _count(self):
        """Returns the number of results for this query, caching the value after the first request.

        Returns:
            int: The "num_found" value reported by the server.
        """
        if self._count_valid:
            return self._total_results

        url = self._doc_class.urlobject_history.format(
            self._cb.credentials.org_key
        )
        request = self._build_request(start=0, rows=0)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()

        self._total_results = result["num_found"]
        self._count_valid = True

        return self._total_results

    def _perform_query(self, start=0, rows=0):
        """Performs the query and yields the results page by page.

        Arguments:
            start (int): The row to start the query at (default 0).
            rows (int): The maximum number of results to yield (default 0, meaning no limit).

        Yields:
            An instance of the query's document class for each returned result.
        """
        url = self._doc_class.urlobject_history.format(
            self._cb.credentials.org_key
        )
        current = start
        numrows = 0
        while True:
            request = self._build_request(current, rows)
            resp = self._cb.post_object(url, body=request)
            result = resp.json()

            self._total_results = result["num_found"]
            self._count_valid = True

            results = result.get("results", [])
            if not results:
                # An empty page cannot advance "current"; stop instead of re-requesting the same page forever.
                return

            for item in results:
                yield self._doc_class(self._cb, item)
                current += 1
                numrows += 1

                if rows and numrows == rows:
                    return

            if current >= self._total_results:
                return
Example #17
0
class USBDeviceApprovalQuery(BaseQuery, QueryBuilderSupportMixin,
                             CriteriaBuilderSupportMixin, IterableQueryMixin,
                             AsyncQueryMixin):
    """Represents a query that is used to locate USBDeviceApproval objects."""
    def __init__(self, doc_class, cb):
        """
        Initialize the USBDeviceApprovalQuery.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        # NOTE(review): super(BaseQuery, self) resumes the MRO lookup *after* BaseQuery,
        # so BaseQuery.__init__ itself is not invoked here - confirm this is intentional.
        super(BaseQuery, self).__init__()
        self._query_builder = QueryBuilder()
        self._criteria = {}
        self._total_results = 0

    def set_device_ids(self, device_ids):
        """
        Restricts the device approvals that this query is performed on to the specified device IDs.

        Args:
            device_ids (list): List of string device IDs.

        Returns:
            USBDeviceApprovalQuery: This instance.

        Raises:
            ApiError: If any element of device_ids is not a string.
        """
        if not all(isinstance(device_id, str) for device_id in device_ids):
            raise ApiError("One or more invalid device IDs")
        self._update_criteria("device.id", device_ids)
        return self

    def set_product_names(self, product_names):
        """
        Restricts the device approvals that this query is performed on to the specified product names.

        Args:
            product_names (list): List of string product names.

        Returns:
            USBDeviceApprovalQuery: This instance.

        Raises:
            ApiError: If any element of product_names is not a string.
        """
        if not all(
                isinstance(product_name, str)
                for product_name in product_names):
            raise ApiError("One or more invalid product names")
        self._update_criteria("product_name", product_names)
        return self

    def set_vendor_names(self, vendor_names):
        """
        Restricts the device approvals that this query is performed on to the specified vendor names.

        Args:
            vendor_names (list): List of string vendor names.

        Returns:
            USBDeviceApprovalQuery: This instance.

        Raises:
            ApiError: If any element of vendor_names is not a string.
        """
        if not all(
                isinstance(vendor_name, str) for vendor_name in vendor_names):
            raise ApiError("One or more invalid vendor names")
        self._update_criteria("vendor_name", vendor_names)
        return self

    def _build_request(self, from_row, max_rows):
        """
        Creates the request body for an API call.

        Args:
            from_row (int): The row to start the query at. Only sent when greater than 0.
            max_rows (int): The maximum number of rows to be returned. A negative value
                keeps the default page size of 100.

        Returns:
            dict: The complete request body.
        """
        # Fetch 100 rows per page (instead of 10 by default) for better performance
        request = {
            "criteria": self._criteria,
            "query": self._query_builder._collapse(),
            "rows": 100
        }
        if from_row > 0:
            request["start"] = from_row
        if max_rows >= 0:
            request["rows"] = max_rows
        return request

    def _build_url(self, tail_end):
        """
        Creates the URL to be used for an API call.

        Args:
            tail_end (str): String to be appended to the end of the generated URL.

        Returns:
            str: The complete URL.
        """
        url = self._doc_class.urlobject.format(
            self._cb.credentials.org_key) + tail_end
        return url

    def _count(self):
        """
        Returns the number of results from the run of this query.

        The value is taken from the cache when a previous request already stored it;
        otherwise a search request is issued and its "num_found" value is cached.

        Returns:
            int: The number of results from the run of this query.
        """
        if self._count_valid:
            return self._total_results

        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()

        self._total_results = result["num_found"]
        self._count_valid = True

        return self._total_results

    def _perform_query(self, from_row=0, max_rows=-1):
        """
        Performs the query and returns the results of the query in an iterable fashion.

        Args:
            from_row (int): The row to start the query at (default 0).
            max_rows (int): The maximum number of rows to be returned (default -1, meaning "all").

        Returns:
            Iterable: The iterated query.
        """
        url = self._build_url("/_search")
        current = from_row
        numrows = 0
        while True:
            request = self._build_request(current, max_rows)
            resp = self._cb.post_object(url, body=request)
            result = resp.json()

            self._total_results = result["num_found"]
            self._count_valid = True

            # tolerate a missing or null "results" entry
            results = result.get("results") or []
            for item in results:
                yield self._doc_class(self._cb, item["id"], item)
                current += 1
                numrows += 1

                if max_rows > 0 and numrows == max_rows:
                    return

            # Stop when everything has been consumed. An empty page is terminal too:
            # if the server reports more rows than it delivers, "current" would never
            # advance and the same page would be requested forever.
            if not results or current >= self._total_results:
                return

    def _run_async_query(self, context):
        """
        Executed in the background to run an asynchronous query.

        A single search request is issued (starting at row 0 with the default page
        size of 100); no further pages are fetched.

        Args:
            context (object): Not used, always None.

        Returns:
            list: Result of the async query, which is then returned by the future.
        """
        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()
        self._total_results = result["num_found"]
        self._count_valid = True
        results = result.get("results", [])
        return [
            self._doc_class(self._cb, item["id"], item) for item in results
        ]
Example #18
0
class ReputationOverrideQuery(BaseQuery, QueryBuilderSupportMixin,
                              IterableQueryMixin, AsyncQueryMixin):
    """Represents a query that is used to locate ReputationOverride objects."""
    VALID_DIRECTIONS = ["ASC", "DESC", "asc", "desc"]

    def __init__(self, doc_class, cb):
        """
        Initialize the ReputationOverrideQuery.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        super(ReputationOverrideQuery, self).__init__()

        self._query_builder = QueryBuilder()
        self._criteria = {}
        self._sortcriteria = {}
        # defined up front so the attribute exists before the first request is made
        self._total_results = 0

    def set_override_list(self, override_list):
        """Sets the override_list criteria filter.

        Arguments:
            override_list (str): Override List to filter on, either "WHITE_LIST" or "BLACK_LIST".

        Returns:
            The ReputationOverrideQuery with specified override_list.

        Raises:
            ApiError: If override_list is not a string or not one of the valid values.
        """
        # Reject anything that is not a string OR not a known value. The previous
        # "not str AND in list" condition could never be true, so nothing was validated.
        if not isinstance(override_list, str) or override_list not in [
                "WHITE_LIST", "BLACK_LIST"
        ]:
            raise ApiError(
                "Invalid override_list must be one of WHITE_LIST, BLACK_LIST")
        self._criteria["override_list"] = override_list
        return self

    def set_override_type(self, override_type):
        """Sets the override_type criteria filter.

        Arguments:
            override_type (str): Override type to filter on, one of "SHA256", "CERT" or "IT_TOOL".

        Returns:
            The ReputationOverrideQuery with specified override_type.

        Raises:
            ApiError: If override_type is not a string or not one of the valid values.
        """
        # Same validation rule as set_override_list: non-strings and unknown values fail.
        if not isinstance(override_type, str) or override_type not in [
                "SHA256", "CERT", "IT_TOOL"
        ]:
            raise ApiError(
                "Invalid override_type must be one of SHA256, CERT, IT_TOOL")
        self._criteria["override_type"] = override_type
        return self

    def sort_by(self, key, direction="ASC"):
        """
        Sets the sorting behavior on a query's results.

        Example:
            >>> cb.select(ReputationOverride).sort_by("create_time")

        Args:
            key (str): The key in the schema to sort by.
            direction (str): The sort order, either "ASC" or "DESC".

        Returns:
            ReputationOverrideQuery: This instance.

        Raises:
            ApiError: If an invalid direction value is passed.
        """
        if direction not in ReputationOverrideQuery.VALID_DIRECTIONS:
            raise ApiError("invalid sort direction specified")
        self._sortcriteria = {"sort_field": key, "sort_order": direction}
        return self

    def _build_request(self, from_row, max_rows):
        """
        Creates the request body for an API call.

        Args:
            from_row (int): The row to start the query at. Only sent when greater than 0.
            max_rows (int): The maximum number of rows to be returned. A negative value
                leaves "rows" out of the request.

        Returns:
            dict: The complete request body.
        """
        request = {
            "criteria": self._criteria,
            "query": self._query_builder._collapse()
        }
        if from_row > 0:
            request["start"] = from_row
        if max_rows >= 0:
            request["rows"] = max_rows
        if self._sortcriteria != {}:
            # sort settings are merged in as top-level "sort_field"/"sort_order" keys
            request.update(self._sortcriteria)
        return request

    def _build_url(self, tail_end):
        """
        Creates the URL to be used for an API call.

        Args:
            tail_end (str): String to be appended to the end of the generated URL.

        Returns:
            str: The complete URL.
        """
        url = self._doc_class.urlobject.format(
            self._cb.credentials.org_key) + tail_end
        return url

    def _count(self):
        """
        Returns the number of results from the run of this query.

        The value is taken from the cache when a previous request already stored it;
        otherwise a search request is issued and its "num_found" value is cached.

        Returns:
            int: The number of results from the run of this query.
        """
        if self._count_valid:
            return self._total_results

        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()

        self._total_results = result["num_found"]
        self._count_valid = True

        return self._total_results

    def _perform_query(self, from_row=0, max_rows=-1):
        """
        Performs the query and returns the results of the query in an iterable fashion.

        Args:
            from_row (int): The row to start the query at (default 0).
            max_rows (int): The maximum number of rows to be returned (default -1, meaning "all").

        Returns:
            Iterable: The iterated query.
        """
        url = self._build_url("/_search")
        current = from_row
        numrows = 0
        while True:
            request = self._build_request(current, max_rows)
            resp = self._cb.post_object(url, body=request)
            result = resp.json()

            self._total_results = result["num_found"]
            self._count_valid = True

            # tolerate a missing or null "results" entry
            results = result.get("results") or []
            for item in results:
                yield self._doc_class(self._cb, item["id"], item)
                current += 1
                numrows += 1

                if max_rows > 0 and numrows == max_rows:
                    return

            # Stop when everything has been consumed. An empty page is terminal too:
            # if the server reports more rows than it delivers, "current" would never
            # advance and the same page would be requested forever.
            if not results or current >= self._total_results:
                return

    def _run_async_query(self, context):
        """
        Executed in the background to run an asynchronous query.

        A single search request is issued starting at row 0; no further pages are fetched.

        Args:
            context (object): Not used, always None.

        Returns:
            list: Result of the async query, which is then returned by the future.
        """
        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()
        self._total_results = result["num_found"]
        self._count_valid = True
        results = result.get("results", [])
        return [
            self._doc_class(self._cb, item["id"], item) for item in results
        ]
Example #19
0
class USBDeviceQuery(BaseQuery, QueryBuilderSupportMixin,
                     CriteriaBuilderSupportMixin, IterableQueryMixin,
                     AsyncQueryMixin):
    """Represents a query that is used to locate USBDevice objects."""
    VALID_STATUSES = ["APPROVED", "UNAPPROVED"]
    VALID_FACET_FIELDS = [
        "vendor_name", "product_name", "endpoint.endpoint_name", "status"
    ]

    def __init__(self, doc_class, cb):
        """
        Initialize the USBDeviceQuery.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        # NOTE(review): super(BaseQuery, self) resumes the MRO lookup *after* BaseQuery,
        # so BaseQuery.__init__ itself is not invoked here - confirm this is intentional.
        super(BaseQuery, self).__init__()
        self._query_builder = QueryBuilder()
        self._criteria = {}
        self._sortcriteria = {}
        self._total_results = 0

    def set_endpoint_names(self, endpoint_names):
        """
        Restricts the devices that this query is performed on to the specified endpoint names.

        Args:
            endpoint_names (list): List of string endpoint names.

        Returns:
            USBDeviceQuery: This instance.

        Raises:
            ApiError: If any element of endpoint_names is not a string.
        """
        if not all(
                isinstance(endpoint_name, str)
                for endpoint_name in endpoint_names):
            raise ApiError("One or more invalid endpoint names")
        self._update_criteria("endpoint.endpoint_name", endpoint_names)
        return self

    def set_product_names(self, product_names):
        """
        Restricts the devices that this query is performed on to the specified product names.

        Args:
            product_names (list): List of string product names.

        Returns:
            USBDeviceQuery: This instance.

        Raises:
            ApiError: If any element of product_names is not a string.
        """
        if not all(
                isinstance(product_name, str)
                for product_name in product_names):
            raise ApiError("One or more invalid product names")
        self._update_criteria("product_name", product_names)
        return self

    def set_serial_numbers(self, serial_numbers):
        """
        Restricts the devices that this query is performed on to the specified serial numbers.

        Args:
            serial_numbers (list): List of string serial numbers.

        Returns:
            USBDeviceQuery: This instance.

        Raises:
            ApiError: If any element of serial_numbers is not a string.
        """
        if not all(
                isinstance(serial_number, str)
                for serial_number in serial_numbers):
            raise ApiError("One or more invalid serial numbers")
        self._update_criteria("serial_number", serial_numbers)
        return self

    def set_statuses(self, statuses):
        """
        Restricts the devices that this query is performed on to the specified status values.

        Args:
            statuses (list): List of string status values.  Valid values are APPROVED and UNAPPROVED.

        Returns:
            USBDeviceQuery: This instance.

        Raises:
            ApiError: If any element of statuses is not in VALID_STATUSES.
        """
        if not all((s in USBDeviceQuery.VALID_STATUSES) for s in statuses):
            raise ApiError("One or more invalid status values")
        self._update_criteria("status", statuses)
        return self

    def set_vendor_names(self, vendor_names):
        """
        Restricts the devices that this query is performed on to the specified vendor names.

        Args:
            vendor_names (list): List of string vendor names.

        Returns:
            USBDeviceQuery: This instance.

        Raises:
            ApiError: If any element of vendor_names is not a string.
        """
        if not all(
                isinstance(vendor_name, str) for vendor_name in vendor_names):
            raise ApiError("One or more invalid vendor names")
        self._update_criteria("vendor_name", vendor_names)
        return self

    def sort_by(self, key, direction="ASC"):
        """
        Sets the sorting behavior on a query's results.

        Example:
            >>> cb.select(USBDevice).sort_by("product_name")

        Args:
            key (str): The key in the schema to sort by.
            direction (str): The sort order, either "ASC" or "DESC".

        Returns:
            USBDeviceQuery: This instance.

        Raises:
            ApiError: If direction is not one of DeviceSearchQuery.VALID_DIRECTIONS.
        """
        # the list of accepted directions is shared with DeviceSearchQuery
        if direction not in DeviceSearchQuery.VALID_DIRECTIONS:
            raise ApiError("invalid sort direction specified")
        self._sortcriteria = {"field": key, "order": direction}
        return self

    def _build_request(self, from_row, max_rows, add_sort=True):
        """
        Creates the request body for an API call.

        Args:
            from_row (int): The row to start the query at. Only sent when greater than 0.
            max_rows (int): The maximum number of rows to be returned. A negative value
                keeps the default page size of 100.
            add_sort (bool): If True(default), the sort criteria will be added as part of the request.

        Returns:
            dict: The complete request body.
        """
        # Fetch 100 rows per page (instead of 10 by default) for better performance
        request = {
            "criteria": self._criteria,
            "query": self._query_builder._collapse(),
            "rows": 100
        }
        if from_row > 0:
            request["start"] = from_row
        if max_rows >= 0:
            request["rows"] = max_rows
        if add_sort and self._sortcriteria != {}:
            request["sort"] = [self._sortcriteria]
        return request

    def _build_url(self, tail_end):
        """
        Creates the URL to be used for an API call.

        Args:
            tail_end (str): String to be appended to the end of the generated URL.

        Returns:
            str: The complete URL.
        """
        url = self._doc_class.urlobject.format(
            self._cb.credentials.org_key) + tail_end
        return url

    def _count(self):
        """
        Returns the number of results from the run of this query.

        The value is taken from the cache when a previous request already stored it;
        otherwise a search request is issued and its "num_found" value is cached.

        Returns:
            int: The number of results from the run of this query.
        """
        if self._count_valid:
            return self._total_results

        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()

        self._total_results = result["num_found"]
        self._count_valid = True

        return self._total_results

    def _perform_query(self, from_row=0, max_rows=-1):
        """
        Performs the query and returns the results of the query in an iterable fashion.

        Args:
            from_row (int): The row to start the query at (default 0).
            max_rows (int): The maximum number of rows to be returned (default -1, meaning "all").

        Returns:
            Iterable: The iterated query.
        """
        url = self._build_url("/_search")
        current = from_row
        numrows = 0
        while True:
            request = self._build_request(current, max_rows)
            resp = self._cb.post_object(url, body=request)
            result = resp.json()

            self._total_results = result["num_found"]
            self._count_valid = True

            # tolerate a missing or null "results" entry
            results = result.get("results") or []
            for item in results:
                yield self._doc_class(self._cb, item["id"], item)
                current += 1
                numrows += 1

                if max_rows > 0 and numrows == max_rows:
                    return

            # Stop when everything has been consumed. An empty page is terminal too:
            # if the server reports more rows than it delivers, "current" would never
            # advance and the same page would be requested forever.
            if not results or current >= self._total_results:
                return

    def _run_async_query(self, context):
        """
        Executed in the background to run an asynchronous query.

        A single search request is issued (starting at row 0 with the default page
        size of 100); no further pages are fetched.

        Args:
            context (object): Not used, always None.

        Returns:
            list: Result of the async query, which is then returned by the future.
        """
        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()
        self._total_results = result["num_found"]
        self._count_valid = True
        results = result.get("results", [])
        return [
            self._doc_class(self._cb, item["id"], item) for item in results
        ]

    def facets(self, fieldlist, max_rows=0):
        """
        Return information about the facets for all known USB devices, using the defined criteria.

        Args:
            fieldlist (list): List of facet field names. Valid names are "vendor_name", "product_name",
                              "endpoint.endpoint_name", and "status".
            max_rows (int): The maximum number of rows to return. 0 means return all rows.

        Returns:
            list: A list of facet information specified as dicts.

        Raises:
            ApiError: If any element of fieldlist is not in VALID_FACET_FIELDS.
        """
        if not all((field in USBDeviceQuery.VALID_FACET_FIELDS)
                   for field in fieldlist):
            raise ApiError("One or more invalid term field names")
        # reuse the search body without sort; the row limit moves into the "terms" section
        request = self._build_request(0, -1, False)
        del request["rows"]
        request["terms"] = {"fields": fieldlist, "rows": max_rows}
        url = self._build_url("/_facet")
        resp = self._cb.post_object(url, body=request)
        result = resp.json()
        return result.get("terms", [])
Example #20
0
class Query(PaginatedQuery, PlatformQueryBase, QueryBuilderSupportMixin, IterableQueryMixin):
    """Represents a prepared query to the Cb Endpoint Standard server.

    This object is returned as part of a `CBCloudAPI.select`
    operation on models requested from the Cb Endpoint Standard server.
    You should not have to create this class yourself.

    The query is not executed on the server until it's accessed, either as an iterator (where it will generate values
    on demand as they're requested) or as a list (where it will retrieve the entire result set and save to a list).
    You can also call the Python built-in `len()` on this object to retrieve the total number of items matching
    the query.

    Example:
    >>> from cbc_sdk import CBCloudAPI
    >>> cb = CBCloudAPI()

    Notes:
        - The slicing operator only supports start and end parameters, but not step. ``[1:-1]`` is legal, but
          ``[1:2:-1]`` is not.
        - You can chain where clauses together to create AND queries; only objects that match all ``where`` clauses
          will be returned.
        - Queries with multiple search parameters only support AND operations, not OR. Calling
          ``Query.or_()`` is not supported and raises an ``ApiError``.
    """
    def __init__(self, doc_class, cb, query=None):
        """
        Initialize a Query object.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
            query (Query): Optional existing query whose where()/and_() clauses are carried
                over into this one.
        """
        super(Query, self).__init__(doc_class, cb, query)

        # max batch_size is 5000
        self._batch_size = 100
        self._query_builder = QueryBuilder()
        if query is not None:
            # copy existing .where(), and_() queries
            self._query_builder._query = query._query_builder._query

    def _clone(self):
        """Return a new query of the same class carrying this query's clauses and batch size."""
        nq = self.__class__(self._doc_class, self._cb, query=self)
        nq._batch_size = self._batch_size
        return nq

    def or_(self, **kwargs):
        """Unsupported. Will raise if called.

        Raises:
            ApiError: .or_() cannot be called on Endpoint Standard queries.
        """
        raise ApiError(".or_() cannot be called on Endpoint Standard queries.")

    def prepare_query(self, args):
        """Adds query parameters that are part of a `select().where()` clause to the request.

        Args:
            args (dict): Request parameters collected so far. This dict is updated in place.

        Returns:
            dict: The same dict, with one entry added for each ``key:value`` term of the collapsed query.
        """
        request = args
        params = self._query_builder._collapse()
        if params is not None:
            for query in params.split(' '):
                try:
                    # convert from str('key:value') to dict{'key': 'value'}
                    key, value = query.split(':', 1)
                    # must remove leading or trailing parentheses that were inserted by logical combinations
                    key = key.strip('(').strip(')')
                    value = value.strip('(').strip(')')
                    request[key] = value
                except ValueError:
                    # AND or OR encountered
                    pass
        return request

    def _count(self):
        """
        Return the total number of results for this query.

        The value is cached after the first request to the server.

        Returns:
            int: The "totalResults" value reported by the server (0 when absent).
        """
        if self._count_valid:
            return self._total_results

        args = self.prepare_query({})

        query_args = convert_query_params(args)
        self._total_results = int(self._cb.get_object(self._doc_class.urlobject, query_parameters=query_args)
                                  .get("totalResults", 0))
        self._count_valid = True
        return self._total_results

    def _search(self, start=0, rows=0):
        """
        Run the query and lazily yield the raw result items, one batch at a time.

        Args:
            start (int): Value sent as the initial "start" parameter; omitted from the request when 0.
            rows (int): Maximum number of items to yield; 0 (default) means no limit.

        Yields:
            Each item of the "results" list of every response, unmodified.
        """
        # iterate over total result set, in batches of self._batch_size at a time
        # defaults to 100 results each call
        args = {}
        if start != 0:
            # NOTE(review): sent as given, while follow-up pages below use current + 1 - confirm
            # whether callers are expected to pass a 1-based value here.
            args['start'] = start
        args['rows'] = self._batch_size

        current = start
        numrows = 0

        args = self.prepare_query(args)
        still_querying = True

        while still_querying:
            query_args = convert_query_params(args)
            result = self._cb.get_object(self._doc_class.urlobject, query_parameters=query_args)

            # coerce to int, as _count() does, so the comparisons below are always numeric
            self._total_results = int(result.get("totalResults", 0))
            self._count_valid = True

            results = result.get('results', [])

            if results is None:
                log.debug("Results are None")
                if current >= 100000:
                    log.info("Max result size exceeded. Truncated to 100k.")
                break

            for item in results:
                yield item
                current += 1
                numrows += 1
                if rows and numrows == rows:
                    still_querying = False
                    break
            # as of 6/2017, the indexing on the Cb Endpoint Standard backend is still 1-based
            args['start'] = current + 1

            if current >= self._total_results:
                break

            if not results:
                # an empty page before reaching the reported total: trust what was actually received
                log.debug("server reported total_results overestimated the number of results for this query by {0}"
                          .format(self._total_results - current))
                log.debug("resetting total_results for this query to {0}".format(current))
                self._total_results = current
                break
Example #21
0
class BaseAlertSearchQuery(BaseQuery, QueryBuilderSupportMixin, IterableQueryMixin, CriteriaBuilderSupportMixin):
    """Represents a query that is used to locate BaseAlert objects."""
    VALID_CATEGORIES = ["THREAT", "MONITORED", "INFO", "MINOR", "SERIOUS", "CRITICAL"]
    VALID_REPUTATIONS = ["KNOWN_MALWARE", "SUSPECT_MALWARE", "PUP", "NOT_LISTED", "ADAPTIVE_WHITE_LIST",
                         "COMMON_WHITE_LIST", "TRUSTED_WHITE_LIST", "COMPANY_BLACK_LIST"]
    VALID_ALERT_TYPES = ["CB_ANALYTICS", "DEVICE_CONTROL", "WATCHLIST"]
    VALID_WORKFLOW_VALS = ["OPEN", "DISMISSED"]
    VALID_FACET_FIELDS = ["ALERT_TYPE", "CATEGORY", "REPUTATION", "WORKFLOW", "TAG", "POLICY_ID",
                          "POLICY_NAME", "DEVICE_ID", "DEVICE_NAME", "APPLICATION_HASH",
                          "APPLICATION_NAME", "STATUS", "RUN_STATE", "POLICY_APPLIED_STATE",
                          "POLICY_APPLIED", "SENSOR_ACTION"]

    def __init__(self, doc_class, cb):
        """
        Set up an empty alert search query.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        # These are assigned before the base-class initializers run.
        self._doc_class, self._cb = doc_class, cb
        self._count_valid = False
        super(BaseAlertSearchQuery, self).__init__()

        # URL template for bulk workflow updates; {0} is replaced with the org key when it is used.
        self._bulkupdate_url = "/appservices/v6/orgs/{0}/alerts/workflow/_criteria"
        self._query_builder = QueryBuilder()
        self._criteria, self._time_filters, self._sortcriteria = {}, {}, {}
        # No result count is known until a request has been made.
        self._total_results = 0
        self._count_valid = False

    def set_categories(self, categories):
        """
        Restricts the alerts that this query is performed on to the specified categories.

        Args:
            categories (list): List of categories to be restricted to. Valid categories are
                               "THREAT", "MONITORED", "INFO", "MINOR", "SERIOUS", and "CRITICAL."

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all((c in BaseAlertSearchQuery.VALID_CATEGORIES) for c in categories):
            raise ApiError("One or more invalid category values")
        self._update_criteria("category", categories)
        return self

    def set_create_time(self, *args, **kwargs):
        """
        Restricts the alerts that this query is performed on to the specified creation time.

        The time may either be specified as a start and end point or as a range.

        Args:
            *args (list): Not used.
            **kwargs (dict): Used to specify start= for start time, end= for end time, and range= for range.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if kwargs.get("start", None) and kwargs.get("end", None):
            if kwargs.get("range", None):
                raise ApiError("cannot specify range= in addition to start= and end=")
            stime = kwargs["start"]
            if not isinstance(stime, str):
                stime = stime.isoformat()
            etime = kwargs["end"]
            if not isinstance(etime, str):
                etime = etime.isoformat()
            self._time_filters["create_time"] = {"start": stime, "end": etime}
        elif kwargs.get("range", None):
            if kwargs.get("start", None) or kwargs.get("end", None):
                raise ApiError("cannot specify start= or end= in addition to range=")
            self._time_filters["create_time"] = {"range": kwargs["range"]}
        else:
            raise ApiError("must specify either start= and end= or range=")
        return self

    def set_device_ids(self, device_ids):
        """
        Restricts the alerts that this query is performed on to the specified device IDs.

        Args:
            device_ids (list): List of integer device IDs.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(device_id, int) for device_id in device_ids):
            raise ApiError("One or more invalid device IDs")
        self._update_criteria("device_id", device_ids)
        return self

    def set_device_names(self, device_names):
        """
        Restricts the alerts that this query is performed on to the specified device names.

        Args:
            device_names (list): List of string device names.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(n, str) for n in device_names):
            raise ApiError("One or more invalid device names")
        self._update_criteria("device_name", device_names)
        return self

    def set_device_os(self, device_os):
        """
        Limit the query to alerts on devices running the given operating systems.

        Args:
            device_os (list): List of string operating systems.  Valid values are "WINDOWS", "ANDROID",
                              "MAC", "IOS", "LINUX", and "OTHER."

        Returns:
            BaseAlertSearchQuery: This instance.

        Raises:
            ApiError: If any supplied value is not in DeviceSearchQuery.VALID_OS.
        """
        if any(os_name not in DeviceSearchQuery.VALID_OS for os_name in device_os):
            raise ApiError("One or more invalid operating systems")
        self._update_criteria("device_os", device_os)
        return self

    def set_device_os_versions(self, device_os_versions):
        """
        Restricts the alerts that this query is performed on to the specified device operating system versions.

        Args:
            device_os_versions (list): List of string operating system versions.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(n, str) for n in device_os_versions):
            raise ApiError("One or more invalid device OS versions")
        self._update_criteria("device_os_version", device_os_versions)
        return self

    def set_device_username(self, users):
        """
        Restricts the alerts that this query is performed on to the specified user names.

        Args:
            users (list): List of string user names.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(u, str) for u in users):
            raise ApiError("One or more invalid user names")
        self._update_criteria("device_username", users)
        return self

    def set_group_results(self, do_group):
        """
        Specifies whether or not to group the results of the query.

        Args:
            do_group (bool): True to group the results, False to not do so.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        self._criteria["group_results"] = True if do_group else False
        return self

    def set_alert_ids(self, alert_ids):
        """
        Restricts the alerts that this query is performed on to the specified alert IDs.

        Args:
            alert_ids (list): List of string alert IDs.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(v, str) for v in alert_ids):
            raise ApiError("One or more invalid alert ID values")
        self._update_criteria("id", alert_ids)
        return self

    def set_legacy_alert_ids(self, alert_ids):
        """
        Restricts the alerts that this query is performed on to the specified legacy alert IDs.

        Args:
            alert_ids (list): List of string legacy alert IDs.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(v, str) for v in alert_ids):
            raise ApiError("One or more invalid alert ID values")
        self._update_criteria("legacy_alert_id", alert_ids)
        return self

    def set_minimum_severity(self, severity):
        """
        Restricts the alerts that this query is performed on to the specified minimum severity level.

        Args:
            severity (int): The minimum severity level for alerts.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        self._criteria["minimum_severity"] = severity
        return self

    def set_policy_ids(self, policy_ids):
        """
        Restricts the alerts that this query is performed on to the specified policy IDs.

        Args:
            policy_ids (list): List of integer policy IDs.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(policy_id, int) for policy_id in policy_ids):
            raise ApiError("One or more invalid policy IDs")
        self._update_criteria("policy_id", policy_ids)
        return self

    def set_policy_names(self, policy_names):
        """
        Restricts the alerts that this query is performed on to the specified policy names.

        Args:
            policy_names (list): List of string policy names.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(n, str) for n in policy_names):
            raise ApiError("One or more invalid policy names")
        self._update_criteria("policy_name", policy_names)
        return self

    def set_process_names(self, process_names):
        """
        Restricts the alerts that this query is performed on to the specified process names.

        Args:
            process_names (list): List of string process names.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(n, str) for n in process_names):
            raise ApiError("One or more invalid process names")
        self._update_criteria("process_name", process_names)
        return self

    def set_process_sha256(self, shas):
        """
        Restricts the alerts that this query is performed on to the specified process SHA-256 hash values.

        Args:
            shas (list): List of string process SHA-256 hash values.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(n, str) for n in shas):
            raise ApiError("One or more invalid SHA256 values")
        self._update_criteria("process_sha256", shas)
        return self

    def set_reputations(self, reps):
        """
        Restricts the alerts that this query is performed on to the specified reputation values.

        Args:
            reps (list): List of string reputation values.  Valid values are "KNOWN_MALWARE", "SUSPECT_MALWARE",
                         "PUP", "NOT_LISTED", "ADAPTIVE_WHITE_LIST", "COMMON_WHITE_LIST", "TRUSTED_WHITE_LIST",
                          and "COMPANY_BLACK_LIST".

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all((r in BaseAlertSearchQuery.VALID_REPUTATIONS) for r in reps):
            raise ApiError("One or more invalid reputation values")
        self._update_criteria("reputation", reps)
        return self

    def set_tags(self, tags):
        """
        Restricts the alerts that this query is performed on to the specified tag values.

        Args:
            tags (list): List of string tag values.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(tag, str) for tag in tags):
            raise ApiError("One or more invalid tags")
        self._update_criteria("tag", tags)
        return self

    def set_target_priorities(self, priorities):
        """
        Limit the query to alerts with the given target priority values.

        The values are stored under the "target_value" criteria key.

        Args:
            priorities (list): List of string target priority values.  Valid values are "LOW", "MEDIUM",
                               "HIGH", and "MISSION_CRITICAL".

        Returns:
            BaseAlertSearchQuery: This instance.

        Raises:
            ApiError: If any supplied value is not in DeviceSearchQuery.VALID_PRIORITIES.
        """
        if any(level not in DeviceSearchQuery.VALID_PRIORITIES for level in priorities):
            raise ApiError("One or more invalid priority values")
        self._update_criteria("target_value", priorities)
        return self

    def set_threat_ids(self, threats):
        """
        Restricts the alerts that this query is performed on to the specified threat ID values.

        Args:
            threats (list): List of string threat ID values.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all(isinstance(t, str) for t in threats):
            raise ApiError("One or more invalid threat ID values")
        self._update_criteria("threat_id", threats)
        return self

    def set_time_range(self, key, **kwargs):
        """
        Restricts the alerts that this query is performed on to the specified time range.

        The time may either be specified as a start and end point or as a range.

        Args:
            key (str): The key to use for criteria one of create_time,
                       first_event_time, last_event_time, or last_update_time
            **kwargs (dict): Used to specify start= for start time, end= for end time, and range= for range.

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if key not in ["create_time", "first_event_time", "last_event_time", "last_update_time"]:
            raise ApiError("key must be one of create_time, first_event_time, last_event_time, or last_update_time")
        if kwargs.get("start", None) and kwargs.get("end", None):
            if kwargs.get("range", None):
                raise ApiError("cannot specify range= in addition to start= and end=")
            stime = kwargs["start"]
            if not isinstance(stime, str):
                stime = stime.isoformat()
            etime = kwargs["end"]
            if not isinstance(etime, str):
                etime = etime.isoformat()
            self._time_filters[key] = {"start": stime, "end": etime}
        elif kwargs.get("range", None):
            if kwargs.get("start", None) or kwargs.get("end", None):
                raise ApiError("cannot specify start= or end= in addition to range=")
            self._time_filters[key] = {"range": kwargs["range"]}
        else:
            raise ApiError("must specify either start= and end= or range=")
        return self

    def set_types(self, alerttypes):
        """
        Restricts the alerts that this query is performed on to the specified alert type values.

        Args:
            alerttypes (list): List of string alert type values.  Valid values are "CB_ANALYTICS",
                               and "WATCHLIST".

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all((t in BaseAlertSearchQuery.VALID_ALERT_TYPES) for t in alerttypes):
            raise ApiError("One or more invalid alert type values")
        self._update_criteria("type", alerttypes)
        return self

    def set_workflows(self, workflow_vals):
        """
        Restricts the alerts that this query is performed on to the specified workflow status values.

        Args:
            workflow_vals (list): List of string alert type values.  Valid values are "OPEN" and "DISMISSED".

        Returns:
            BaseAlertSearchQuery: This instance.
        """
        if not all((t in BaseAlertSearchQuery.VALID_WORKFLOW_VALS) for t in workflow_vals):
            raise ApiError("One or more invalid workflow status values")
        self._update_criteria("workflow", workflow_vals)
        return self

    def _build_criteria(self):
        """
        Builds the criteria object for use in a query.

        Returns:
            dict: The criteria object.
        """
        mycrit = self._criteria
        if self._time_filters:
            mycrit.update(self._time_filters)
        return mycrit

    def sort_by(self, key, direction="ASC"):
        """
        Choose the field and direction used to sort the query's results.

        Example:
            >>> cb.select(BaseAlert).sort_by("name")

        Args:
            key (str): The key in the schema to sort by.
            direction (str): The sort order, either "ASC" or "DESC".

        Returns:
            BaseAlertSearchQuery: This instance.

        Raises:
            ApiError: If direction is not in DeviceSearchQuery.VALID_DIRECTIONS.
        """
        direction_ok = direction in DeviceSearchQuery.VALID_DIRECTIONS
        if not direction_ok:
            raise ApiError("invalid sort direction specified")
        # Any earlier sort setting is replaced rather than extended.
        self._sortcriteria = dict(field=key, order=direction)
        return self

    def _build_request(self, from_row, max_rows, add_sort=True):
        """
        Creates the request body for an API call.

        Args:
            from_row (int): The row to start the query at.
            max_rows (int): The maximum number of rows to be returned.
            add_sort (bool): If True(default), the sort criteria will be added as part of the request.

        Returns:
            dict: The complete request body.
        """
        request = {"criteria": self._build_criteria()}
        request["query"] = self._query_builder._collapse()
        # Fetch 100 rows per page (instead of 10 by default) for better performance
        request["rows"] = 100
        if from_row > 0:
            request["start"] = from_row
        if max_rows >= 0:
            request["rows"] = max_rows
        if add_sort and self._sortcriteria != {}:
            request["sort"] = [self._sortcriteria]
        return request

    def _build_url(self, tail_end):
        """
        Creates the URL to be used for an API call.

        Args:
            tail_end (str): String to be appended to the end of the generated URL.

        Returns:
            str: The complete URL.
        """
        url = self._doc_class.urlobject.format(self._cb.credentials.org_key) + tail_end
        return url

    def _count(self):
        """
        Returns the number of results from the run of this query.

        Returns:
            int: The number of results from the run of this query.
        """
        if self._count_valid:
            return self._total_results

        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()

        self._total_results = result["num_found"]
        self._count_valid = True

        return self._total_results

    def _perform_query(self, from_row=0, max_rows=-1):
        """
        Performs the query and returns the results of the query in an iterable fashion.

        Args:
            from_row (int): The row to start the query at (default 0).
            max_rows (int): The maximum number of rows to be returned (default -1, meaning "all").

        Returns:
            Iterable: The iterated query.
        """
        url = self._build_url("/_search")
        current = from_row
        numrows = 0
        still_querying = True
        while still_querying:
            request = self._build_request(current, max_rows)
            resp = self._cb.post_object(url, body=request)
            result = resp.json()

            self._total_results = result["num_found"]
            self._count_valid = True

            results = result.get("results", [])
            for item in results:
                yield self._doc_class(self._cb, item["id"], item)
                current += 1
                numrows += 1

                if max_rows > 0 and numrows == max_rows:
                    still_querying = False
                    break

            from_row = current
            if current >= self._total_results:
                still_querying = False
                break

    def facets(self, fieldlist, max_rows=0):
        """
        Return information about the facets for this alert by search, using the defined criteria.

        Args:
            fieldlist (list): List of facet field names. Valid names are "ALERT_TYPE", "CATEGORY", "REPUTATION",
                              "WORKFLOW", "TAG", "POLICY_ID", "POLICY_NAME", "DEVICE_ID", "DEVICE_NAME",
                              "APPLICATION_HASH", "APPLICATION_NAME", "STATUS", "RUN_STATE", "POLICY_APPLIED_STATE",
                              "POLICY_APPLIED", and "SENSOR_ACTION".
            max_rows (int): The maximum number of rows to return. 0 means return all rows.

        Returns:
            list: A list of facet information specified as dicts.
        """
        if not all((field in BaseAlertSearchQuery.VALID_FACET_FIELDS) for field in fieldlist):
            raise ApiError("One or more invalid term field names")
        request = self._build_request(0, -1, False)
        request["terms"] = {"fields": fieldlist, "rows": max_rows}
        url = self._build_url("/_facet")
        resp = self._cb.post_object(url, body=request)
        result = resp.json()
        return result.get("results", [])

    def _update_status(self, status, remediation, comment):
        """
        Updates the status of all alerts matching the given query.

        Args:
            status (str): The status to put the alerts into, either "OPEN" or "DISMISSED".
            remediation (str): The remediation state to set for all alerts.
            comment (str): The comment to set for all alerts.

        Returns:
            str: The request ID, which may be used to select a WorkflowStatus object.
        """
        request = {"state": status, "criteria": self._build_criteria(), "query": self._query_builder._collapse()}
        if remediation is not None:
            request["remediation_state"] = remediation
        if comment is not None:
            request["comment"] = comment
        resp = self._cb.post_object(self._bulkupdate_url.format(self._cb.credentials.org_key), body=request)
        output = resp.json()
        return output["request_id"]

    def update(self, remediation=None, comment=None):
        """
        Update all alerts matching the given query. The alerts will be left in an OPEN state after this request.

        Args:
            remediation (str): The remediation state to set for all alerts.
            comment (str): The comment to set for all alerts.

        Returns:
            str: The request ID, which may be used to select a WorkflowStatus object.
        """
        return self._update_status("OPEN", remediation, comment)

    def dismiss(self, remediation=None, comment=None):
        """
        Dismiss all alerts matching the given query. The alerts will be left in a DISMISSED state after this request.

        Args:
            remediation (str): The remediation state to set for all alerts.
            comment (str): The comment to set for all alerts.

        Returns:
            str: The request ID, which may be used to select a WorkflowStatus object.
        """
        return self._update_status("DISMISSED", remediation, comment)
class ComputeResourceQuery(BaseQuery, QueryBuilderSupportMixin, CriteriaBuilderSupportMixin,
                           IterableQueryMixin, AsyncQueryMixin):
    """Represents a query that is used to locate ComputeResource objects."""
    VALID_OS_TYPE = ("WINDOWS", "RHEL", "UBUNTU", "SUSE", "SLES", "CENTOS", "OTHER", "AMAZON_LINUX", "ORACLE")
    VALID_DIRECTIONS = ("ASC", "DESC")
    VALID_ELIGIBILITY = ("ELIGIBLE", "NOT_ELIGIBLE", "UNSUPPORTED")
    VALID_OS_ARCHITECTURE = ("32", "64")
    VALID_INSTALLATION_STATUS = ("SUCCESS", "ERROR", "PENDING", "NOT_INSTALLED")

    def __init__(self, doc_class, cb):
        """
        Initialize the ComputeResourceQuery.

        Args:
            doc_class (class): The model class that will be returned by this query.
            cb (BaseAPI): Reference to API object used to communicate with the server.
        """
        self._doc_class = doc_class
        self._cb = cb
        self._count_valid = False
        # Start the initializer chain at this class. super(BaseQuery, self) begins the method lookup
        # *after* BaseQuery in the MRO, so BaseQuery's own __init__ would never run.
        super(ComputeResourceQuery, self).__init__()
        self._query_builder = QueryBuilder()
        self._criteria = {}
        self._sortcriteria = {}
        self._total_results = 0

    def set_appliance_uuid(self, appliance_uuid):
        """
        Restricts the search that this query is performed on to the specified appliance uuid.

        Args:
            appliance_uuid (list): List of string appliance uuids.

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all(isinstance(_, str) for _ in appliance_uuid):
            raise ApiError("One or more invalid appliance uuid")
        self._update_criteria("appliance_uuid", appliance_uuid)
        return self

    def set_eligibility(self, eligibility):
        """
        Restricts the search that this query is performed on to the specified eligibility.

        Args:
            eligibility (list): List of string eligibilities.

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all((_ in ComputeResourceQuery.VALID_ELIGIBILITY) for _ in eligibility):
            raise ApiError("One or more invalid eligibility")
        self._update_criteria("eligibility", eligibility)
        return self

    def set_cluster_name(self, cluster_name):
        """
        Restricts the search that this query is performed on to the specified cluster name.

        Args:
            cluster_name (list): List of string cluster names.

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all(isinstance(_, str) for _ in cluster_name):
            raise ApiError("One or more invalid cluster name")
        self._update_criteria("cluster_name", cluster_name)
        return self

    def set_name(self, name):
        """
        Restricts the search that this query is performed on to the specified name.

        Args:
            name (list): List of string names.

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all(isinstance(_, str) for _ in name):
            raise ApiError("One or more invalid names")
        self._update_criteria("name", name)
        return self

    def set_ip_address(self, ip_address):
        """
        Restricts the search that this query is performed on to the specified ip address.

        Args:
            ip_address (list): List of string ip addresses.

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all(isinstance(_, str) for _ in ip_address):
            raise ApiError("One or more invalid ip address")
        self._update_criteria("ip_address", ip_address)
        return self

    def set_installation_status(self, installation_status):
        """
        Restricts the search that this query is performed on to the specified installation status.

        Args:
            installation_status (list): List of string installation status.

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all((_ in ComputeResourceQuery.VALID_INSTALLATION_STATUS) for _ in installation_status):
            raise ApiError("One or more invalid installation status")
        self._update_criteria("installation_status", installation_status)
        return self

    def set_uuid(self, uuid):
        """
        Restricts the search that this query is performed on to the specified uuid.

        Args:
            uuid (list): List of string uuid.

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all(isinstance(_, str) for _ in uuid):
            raise ApiError("One or more invalid uuid")
        self._update_criteria("uuid", uuid)
        return self

    def set_os_type(self, os_type):
        """
        Restricts the search that this query is performed on to the specified os type.

        Args:
            os_type (list): List of string os type.

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all((_ in ComputeResourceQuery.VALID_OS_TYPE) for _ in os_type):
            raise ApiError("One or more invalid os type")
        self._update_criteria("os_type", os_type)
        return self

    def set_os_architecture(self, os_architecture):
        """
        Restricts the search that this query is performed on to the specified os architecture.

        Args:
            os_architecture (list): List of string os architecture.

        Returns:
            ComputeResourceQuery: This instance.
        """
        if not all((_ in ComputeResourceQuery.VALID_OS_ARCHITECTURE) for _ in os_architecture):
            raise ApiError("One or more invalid os architecture")
        self._update_criteria("os_architecture", os_architecture)
        return self

    def sort_by(self, key, direction="ASC"):
        """
        Sets the sorting behavior on a query's results.

        Example:
            >>> cb.select(ComputeResource).sort_by("name")

        Args:
            key (str): The key in the schema to sort by.
            direction (str): The sort order.

        Returns:
            ComputeResourceQuery: This instance.
        """
        if direction not in ComputeResourceQuery.VALID_DIRECTIONS:
            raise ApiError("invalid sort direction specified")
        self._sortcriteria = {"field": key, "order": direction}
        return self

    def _build_request(self, from_row, max_rows, add_sort=True):
        """
        Creates the request body for an API call.

        Args:
            from_row (int): The row to start the query at.
            max_rows (int): The maximum number of rows to be returned.
            add_sort (bool): If True(default), the sort criteria will be added as part of the request.

        Returns:
            dict: The complete request body.
        """
        request = {"criteria": self._criteria, "query": self._query_builder._collapse(), "rows": 100}
        # Fetch 100 rows per page (instead of 10 by default) for better performance
        if from_row > 0:
            request["start"] = from_row
        if max_rows >= 0:
            request["rows"] = max_rows
        if add_sort and self._sortcriteria != {}:
            request["sort"] = [self._sortcriteria]
        return request

    def _build_url(self, tail_end):
        """
        Creates the URL to be used for an API call.

        Args:
            tail_end (str): String to be appended to the end of the generated URL.

        Returns:
            str: The complete URL.
        """
        url = self._doc_class.urlobject.format(self._cb.credentials.org_key) + tail_end
        return url

    def _count(self):
        """
        Returns the number of results from the run of this query.

        Returns:
            int: The number of results from the run of this query.
        """
        if self._count_valid:
            return self._total_results

        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()

        self._total_results = result["num_found"]
        self._count_valid = True

        return self._total_results

    def _perform_query(self, from_row=0, max_rows=-1):
        """
        Performs the query and returns the results of the query in an iterable fashion.

        Args:
            from_row (int): The row to start the query at (default 0).
            max_rows (int): The maximum number of rows to be returned (default -1, meaning "all").

        Returns:
            Iterable: The iterated query.
        """
        url = self._build_url("/_search")
        current = from_row
        numrows = 0
        still_querying = True
        while still_querying:
            request = self._build_request(current, max_rows)
            resp = self._cb.post_object(url, body=request)
            result = resp.json()

            self._total_results = result["num_found"]
            self._count_valid = True

            results = result.get("results", [])
            for item in results:
                yield self._doc_class(self._cb, item["id"], item)
                current += 1
                numrows += 1

                if max_rows > 0 and numrows == max_rows:
                    still_querying = False
                    break

            if current >= self._total_results:
                break

    def _run_async_query(self, context):
        """
        Executed in the background to run an asynchronous query.

        Args:
            context (object): Not used, always None.

        Returns:
            list: Result of the async query, which is then returned by the future.
        """
        url = self._build_url("/_search")
        request = self._build_request(0, -1)
        resp = self._cb.post_object(url, body=request)
        result = resp.json()
        self._total_results = result["num_found"]
        self._count_valid = True
        results = result.get("results", [])
        return [self._doc_class(self._cb, item["id"], item) for item in results]