Beispiel #1
0
def sync_checks(task, request):
    """Synchronise MalwareCheck rows in the database with the classes in ``checks``.

    Only DB rows whose ``is_stale`` is false are considered. They are split
    into wiped-out rows (state ``WipedOut``) and active rows (every other
    state), keyed by check name.

    If any active DB check has no class of the same name in ``checks``, an
    error listing those names is logged and an exception is raised before
    anything is added to the session. Otherwise, for every class in
    ``checks``:

    * wiped-out in the DB: log an error and skip it;
    * active in the DB with the same ``version``: leave it unchanged;
    * active in the DB with a different ``version``: add a new row built from
      ``get_check_fields``. If the old row is not ``Disabled``, copy its state
      onto the new row and set the old row to ``Disabled``;
    * not in the DB: add a new row built from ``get_check_fields``.

    :param task: not used by this function (presumably the task object passed
        in by the task runner — TODO confirm).
    :param request: ``request.db`` is used for ``query``/``add`` and
        ``request.log`` for ``info``/``error`` messages.
    :raises Exception: if an active DB check is missing from the codebase.
    """
    code_checks = inspect.getmembers(checks, inspect.isclass)
    request.log.info("%d malware checks found in codebase." % len(code_checks))
    code_check_names = {name for name, _ in code_checks}

    active_checks = {}
    wiped_out_checks = {}
    for check in request.db.query(MalwareCheck).all():
        if check.is_stale:
            continue
        if check.state == MalwareCheckState.WipedOut:
            wiped_out_checks[check.name] = check
        else:
            active_checks[check.name] = check

    # Compare by name rather than by count: with equal counts the DB can still
    # hold an active check whose class was removed from (or renamed in) code.
    # Sorted so the logged list is deterministic.
    missing = sorted(set(active_checks) - code_check_names)
    if missing:
        request.log.error(
            "Found %d active checks in the db, but only %d checks in "
            "code. Please manually move superfluous checks to the wiped_out "
            "state in the check admin: %s"
            % (len(active_checks), len(code_checks), ", ".join(missing))
        )
        raise Exception(
            "Mismatch between number of db checks and code checks.")

    # getmembers already yields the class object, so no getattr is needed.
    for check_name, check in code_checks:
        if check_name in wiped_out_checks:
            request.log.error(
                "%s is wiped_out and cannot be synced. Please remove check "
                "from codebase." % check_name
            )
            continue

        db_check = active_checks.get(check_name)
        if db_check is None:
            request.log.info("Adding new %s to the database." % check_name)
            request.db.add(MalwareCheck(**get_check_fields(check)))
            continue

        if check.version == db_check.version:
            request.log.info("%s is unmodified." % check_name)
            continue

        request.log.info("Updating existing %s." % check_name)
        fields = get_check_fields(check)

        # Migrate the check state to the newest check.
        # Then mark the old check state as disabled.
        if db_check.state != MalwareCheckState.Disabled:
            fields["state"] = db_check.state.value
            db_check.state = MalwareCheckState.Disabled

        request.db.add(MalwareCheck(**fields))
Beispiel #2
0
def test_checks_fields(checks):
    """Verify get_check_fields against a direct inspection of each check class.

    For every class found in the ``checks`` module passed in, the expected
    mapping is ``{"name": <class name>}`` extended with every non-routine
    class member whose name does not start with ``_``.
    """
    def _is_data(member):
        return not inspect.isroutine(member)

    for cls_name, cls in inspect.getmembers(checks, inspect.isclass):
        # Dunder and "private" (_-prefixed) members are excluded.
        public_attrs = {
            attr: value
            for attr, value in inspect.getmembers(cls, _is_data)
            if not attr.startswith("_")
        }
        # Later keys win, so a public "name" attribute overrides the class name.
        expected = {"name": cls_name, **public_attrs}

        assert expected == get_check_fields(cls)
Beispiel #3
0
    def test_failure(self, monkeypatch):
        """get_check_fields raises AttributeError when ``schedule`` is absent."""
        check_cls = ExampleScheduledCheck
        # monkeypatch restores the attribute after the test finishes.
        monkeypatch.delattr(check_cls, "schedule")
        with pytest.raises(AttributeError):
            get_check_fields(check_cls)
Beispiel #4
0
 def test_success(self, check, result):
     """get_check_fields(check) must equal the expected ``result`` mapping."""
     fields = get_check_fields(check)
     assert fields == result