Ejemplo n.º 1
0
 def test_one_time_compliance_check_one_ok_tuple(self):
     # A distributed (one-time) run of a compliance check query returning a single OK row
     # must produce exactly one OsqueryCheckStatusUpdated event and one new MachineStatus.
     query, _, distributed_query = self._force_query(
         force_pack=True,
         force_compliance_check=True,
         force_distributed_query=True
     )
     cc = query.compliance_check
     msn = get_random_string()
     aggregator = ComplianceCheckStatusAggregator(msn)
     now = datetime.utcnow()
     ok_rows = [{"ztl_status": Status.OK.name}]
     aggregator.add_result(query.pk, query.version, now, ok_rows, distributed_query.pk)
     events = list(aggregator.commit())
     self.assertEqual(len(events), 1)

     # event checks
     event = events[0]
     self.assertIsInstance(event, OsqueryCheckStatusUpdated)
     payload = event.payload
     self.assertEqual(payload["pk"], cc.pk)
     for expected_version in (query.version, cc.version):
         self.assertEqual(payload["version"], expected_version)
     self.assertEqual(payload["osquery_query"], {"pk": query.pk})
     self.assertIsNone(payload.get("osquery_pack"))
     self.assertEqual(payload["osquery_run"], {"pk": distributed_query.pk})
     self.assertEqual(payload["status"], Status.OK.name)
     self.assertIsNone(payload.get("previous_status"))
     expected_linked_keys = {
         "compliance_check": [(cc.pk,)],
         "osquery_query": [(query.pk,)],
         "osquery_run": [(distributed_query.pk,)],
     }
     self.assertEqual(event.get_linked_objects_keys(), expected_linked_keys)

     # stored machine status checks
     statuses = MachineStatus.objects.filter(compliance_check=cc, serial_number=msn)
     self.assertEqual(statuses.count(), 1)
     machine_status = statuses.first()
     self.assertEqual(machine_status.compliance_check_version, cc.version)
     self.assertEqual(machine_status.status, Status.OK.value)
     self.assertEqual(machine_status.status_time, now)
Ejemplo n.º 2
0
 def test_scheduled_compliance_check_one_ok_tuple_update(self):
     # A scheduled run reporting FAILED for a machine whose stored status is OK must update
     # the existing MachineStatus row in place and report OK as the previous status.
     query, _, _ = self._force_query(force_pack=True, force_compliance_check=True)
     cc = query.compliance_check
     msn = get_random_string()
     aggregator = ComplianceCheckStatusAggregator(msn)
     existing_ms = MachineStatus.objects.create(
         serial_number=msn,
         compliance_check=cc,
         compliance_check_version=cc.version,
         status=Status.OK.value,
         status_time=datetime(2001, 1, 1),
     )
     now = datetime.utcnow()
     failed_rows = [{"ztl_status": Status.FAILED.name}]
     aggregator.add_result(query.pk, query.version, now, failed_rows)
     events = list(aggregator.commit())
     self.assertEqual(len(events), 1)
     payload = events[0].payload
     self.assertEqual(payload["status"], Status.FAILED.name)
     self.assertEqual(payload["previous_status"], Status.OK.name)

     statuses = MachineStatus.objects.filter(compliance_check=cc, serial_number=msn)
     self.assertEqual(statuses.count(), 1)
     updated_ms = statuses.first()
     self.assertEqual(updated_ms, existing_ms)
     self.assertEqual(updated_ms.compliance_check_version, cc.version)
     self.assertEqual(updated_ms.status, Status.FAILED.value)
     self.assertEqual(updated_ms.status_time, now)
Ejemplo n.º 3
0
 def test_no_compliance_check(self):
     # Results for a query that is not linked to a compliance check must be ignored:
     # no event is produced and no MachineStatus is stored for the machine.
     query, _, _ = self._force_query(force_pack=True)
     serial_number = get_random_string()
     cc_status_agg = ComplianceCheckStatusAggregator(serial_number)
     # Use the status *name*, like every other aggregator test, so the row itself is valid
     # and the missing compliance check is the only reason for it to be dropped.
     cc_status_agg.add_result(query.pk, query.version, datetime.utcnow(), [{"ztl_status": Status.OK.name}])
     events = list(cc_status_agg.commit())
     self.assertEqual(len(events), 0)
     # query.compliance_check is None here, so filtering on it would only match rows with a
     # NULL compliance check. The serial number is random, so any row for it would have been
     # created by this aggregator.
     ms_qs = MachineStatus.objects.filter(serial_number=serial_number)
     self.assertEqual(ms_qs.count(), 0)
Ejemplo n.º 4
0
 def test_scheduled_compliance_check_one_outdated_version_failed_tuple(self):
     # Results reported for a query version older than the current one are expected to be
     # discarded: no event and no MachineStatus.
     query, _, _ = self._force_query(force_pack=True, force_compliance_check=True)
     query.version = 127
     query.save()
     msn = get_random_string()
     aggregator = ComplianceCheckStatusAggregator(msn)
     outdated_version = 1
     aggregator.add_result(query.pk, outdated_version, datetime.utcnow(),
                           [{"ztl_status": Status.FAILED.name}])
     self.assertEqual(len(list(aggregator.commit())), 0)
     statuses = MachineStatus.objects.filter(compliance_check=query.compliance_check,
                                             serial_number=msn)
     self.assertEqual(statuses.count(), 0)
Ejemplo n.º 5
0
 def test_scheduled_compliance_check_no_tuple(self):
     # A scheduled run that returns no rows at all still produces one event and stores
     # an UNKNOWN status for the machine.
     query, _, _ = self._force_query(force_pack=True, force_compliance_check=True)
     cc = query.compliance_check
     msn = get_random_string()
     aggregator = ComplianceCheckStatusAggregator(msn)
     now = datetime.utcnow()
     no_rows = []
     aggregator.add_result(query.pk, query.version, now, no_rows)
     self.assertEqual(len(list(aggregator.commit())), 1)
     statuses = MachineStatus.objects.filter(compliance_check=cc, serial_number=msn)
     self.assertEqual(statuses.count(), 1)
     machine_status = statuses.first()
     self.assertEqual(machine_status.compliance_check_version, cc.version)
     self.assertEqual(machine_status.status, Status.UNKNOWN.value)
     self.assertEqual(machine_status.status_time, now)
Ejemplo n.º 6
0
def post_results(msn, user_agent, ip, results):
    """Post osquery result events for a machine and update its compliance check statuses.

    Every cleaned-up record in *results* is posted as an ``OsqueryResultEvent``. All of
    them share a single event UUID and carry their position as the metadata index.
    Records that carry a ``snapshot`` and whose result name parses into a query pk and
    version are fed to a ``ComplianceCheckStatusAggregator``. Once every record has been
    processed, ``commit_and_post_events()`` is called on it.

    :param msn: machine serial number
    :param user_agent: request user agent (falsy if unknown)
    :param ip: request IP address (falsy if unknown)
    :param results: raw osquery result records
    """
    from datetime import datetime  # fallback compliance check status time, see below

    event_uuid = uuid.uuid4()
    request = EventRequest(user_agent, ip) if user_agent or ip else None
    cc_status_agg = ComplianceCheckStatusAggregator(msn)
    for index, result in enumerate(_iter_cleaned_up_records(results)):
        try:
            event_time = _get_record_created_at(result)
        except Exception:
            logger.exception("Could not extract osquery result time")
            event_time = None
        metadata = EventMetadata(uuid=event_uuid,
                                 index=index,
                                 machine_serial_number=msn,
                                 request=request,
                                 created_at=event_time)
        event = OsqueryResultEvent(metadata, result)
        event.post()
        snapshot = event.payload.get("snapshot")
        if snapshot is None:
            # no snapshot, cannot be a compliance check
            continue
        try:
            _, query_pk, query_version = event.parse_result_name()
        except ValueError as e:
            logger.warning(str(e))
            continue
        # A record whose time could not be extracted would otherwise hand None to the
        # aggregator as the status time. Fall back to "now", as the distributed query
        # path does.
        # NOTE(review): assumes _get_record_created_at returns naive UTC datetimes — confirm.
        status_time = event_time if event_time is not None else datetime.utcnow()
        cc_status_agg.add_result(query_pk, query_version, status_time, snapshot)
    cc_status_agg.commit_and_post_events()
Ejemplo n.º 7
0
    def do_node_post(self):
        """Store the distributed query results posted by an osquery node.

        ``self.data`` may contain three mappings keyed by DistributedQueryMachine pk:
        ``queries`` (result rows), ``statuses`` and ``messages``. For every known
        DistributedQueryMachine among those keys, the method:

        1. saves its status (999 when missing) and error message,
        2. bulk-creates one DistributedQueryResult per result row, in batches of
           ``self.batch_size``,
        3. feeds the rows of queries linked to a compliance check to a
           ComplianceCheckStatusAggregator, then commits it and posts its events.

        Always returns an empty dict.
        """
        results = self.data.get("queries", {})
        statuses = self.data.get("statuses", {})
        messages = self.data.get("messages", {})
        posted_pks = set(chain(results.keys(), statuses.keys(), messages.keys()))
        if not posted_pks:
            return {}
        machines_qs = (DistributedQueryMachine.objects
                       .select_related("distributed_query__query__compliance_check")
                       .filter(pk__in=posted_pks))
        dqm_cache = {str(machine.pk): machine for machine in machines_qs}

        # 1. update the distributed query machines
        for pk_str, machine in dqm_cache.items():
            machine_status = statuses.get(pk_str)
            if machine_status is None:
                logger.warning("Missing status for DistributedQueryMachine %s", pk_str)
                machine_status = 999  # TODO: better?
            machine.status = machine_status
            machine.error_message = messages.get(pk_str)
            machine.save()

        # 2. save the result rows, lazily built and inserted batch by batch
        pending_results = (
            DistributedQueryResult(
                distributed_query=machine.distributed_query,
                serial_number=self.machine.serial_number,
                row=remove_null_character(row)
            )
            for pk_str, machine in dqm_cache.items()
            for row in results.get(pk_str, [])
        )
        batch = list(islice(pending_results, self.batch_size))
        while batch:
            DistributedQueryResult.objects.bulk_create(batch, self.batch_size)
            batch = list(islice(pending_results, self.batch_size))

        # 3. process the compliance checks
        aggregator = ComplianceCheckStatusAggregator(self.machine.serial_number)
        status_time = datetime.utcnow()  # TODO: how to get a better time? add ztl_status_time = now() to the query?
        for pk_str, machine in dqm_cache.items():
            dq = machine.distributed_query
            dq_query = dq.query
            if dq_query and dq_query.compliance_check:
                aggregator.add_result(
                    dq_query.pk,
                    dq.query_version,
                    status_time,
                    results.get(pk_str, []),
                    dq.pk
                )
        aggregator.commit_and_post_events()

        return {}
Ejemplo n.º 8
0
 def test_scheduled_compliance_check_one_ok_one_failed_tuple(self):
     # Two scheduled compliance check queries in the same batch, one OK and one FAILED:
     # each must get its own event and its own MachineStatus.
     query_ok, _, _ = self._force_query(force_pack=True, force_compliance_check=True)
     query_failed, _, _ = self._force_query(force_pack=True, force_compliance_check=True)
     msn = get_random_string()
     aggregator = ComplianceCheckStatusAggregator(msn)
     ok_time = datetime.utcnow()
     failed_time = datetime.utcnow()
     aggregator.add_result(query_ok.pk, query_ok.version, ok_time,
                           [{"ztl_status": Status.OK.name}])
     aggregator.add_result(query_failed.pk, query_failed.version, failed_time,
                           [{"ztl_status": Status.FAILED.name}])
     events = list(aggregator.commit())
     self.assertEqual(len(events), 2)

     expectations = (
         (query_ok, Status.OK.value, ok_time),
         (query_failed, Status.FAILED.value, failed_time),
     )
     # first check that each compliance check has exactly one status for the machine
     querysets = []
     for checked_query, _, _ in expectations:
         statuses = MachineStatus.objects.filter(compliance_check=checked_query.compliance_check,
                                                 serial_number=msn)
         self.assertEqual(statuses.count(), 1)
         querysets.append(statuses)
     # then check the content of each status
     for statuses, (checked_query, expected_status, expected_time) in zip(querysets, expectations):
         machine_status = statuses.get(compliance_check=checked_query.compliance_check)
         self.assertEqual(machine_status.compliance_check_version,
                          checked_query.compliance_check.version)
         self.assertEqual(machine_status.status, expected_status)
         self.assertEqual(machine_status.status_time, expected_time)