def test_set(self):
        """validators.Set returns its members and None unchanged and rejects anything else."""
        set_validator = validators.Set('a', 'b', 'c')
        for member in ('a', 'b', 'c'):
            self.assertEqual(set_validator(member), member)
        self.assertEqual(set_validator(None), None)
        with self.assertRaises(ValueError):
            set_validator('d')
class StubbedReportingCommand(ReportingCommand):
    """Reporting command that declares one option per validator type; map/reduce are no-ops."""

    boolean = Option(
        doc='''
        **Syntax:** **boolean=***<value>*
        **Description:** A boolean value''',
        require=False, validate=validators.Boolean())

    duration = Option(
        doc='''
        **Syntax:** **duration=***<value>*
        **Description:** A length of time''',
        validate=validators.Duration())

    fieldname = Option(
        doc='''
        **Syntax:** **fieldname=***<value>*
        **Description:** Name of a field''',
        validate=validators.Fieldname())

    file = Option(
        doc='''
        **Syntax:** **file=***<value>*
        **Description:** Name of a file''',
        validate=validators.File(mode='r'))

    integer = Option(
        doc='''
        **Syntax:** **integer=***<value>*
        **Description:** An integer value''',
        validate=validators.Integer())

    optionname = Option(
        doc='''
        **Syntax:** **optionname=***<value>*
        **Description:** The name of an option (used internally)''',
        validate=validators.OptionName())

    regularexpression = Option(
        doc='''
        **Syntax:** **regularexpression=***<value>*
        **Description:** Regular expression pattern to match''',
        validate=validators.RegularExpression())

    # Help text previously duplicated the regularexpression description; this
    # option is validated by validators.Set, so it documents set membership.
    set = Option(
        doc='''
        **Syntax:** **set=***<value>*
        **Description:** A member of a set''',
        validate=validators.Set("foo", "bar", "test"))

    @Configuration()
    def map(self, records):
        """No-op map phase."""
        pass

    def reduce(self, records):
        """No-op reduce phase."""
        pass
Beispiel #3
0
class RequestCommand(GeneratingCommand):
    """Issue an HTTP request and emit the response body as events.

    Options:
        url: request URL (required).
        method: HTTP verb, one of GET, POST, PUT, DELETE, OPTIONS, HEAD, PATCH (default GET).
        body: optional JSON text, decoded and sent as the JSON request body.
        headers: optional JSON text, decoded and sent as request headers.
    """
    url = Option(require=True)
    method = Option(default='GET',
                    validate=validators.Set('GET', 'POST', 'PUT', 'DELETE',
                                            'OPTIONS', 'HEAD', 'PATCH'))
    body = Option(default=None)
    headers = Option(default=None)

    def generate(self):
        """Perform the request and yield one event per JSON array item.

        A JSON array produces one event per element; any other JSON value
        produces a single event; a non-JSON body produces a single event
        holding the decoded response text.

        :raises requests.HTTPError: if the response has an error status.
        """
        response = requests.request(
            method=self.method,
            url=self.url,
            # Decode only non-empty values; an empty string must not be sent
            # as a JSON "" body.
            json=json.loads(self.body) if self.body else None,
            headers=json.loads(self.headers) if self.headers else None,
        )
        # Raise an error in case of bad request
        response.raise_for_status()
        try:
            data = response.json()
        except ValueError:
            # json.JSONDecodeError, simplejson's JSONDecodeError and requests'
            # own JSONDecodeError all derive from ValueError.
            yield self.to_event(response.text)
            return

        # A JSON string is also a typing.Sequence; it must not be split into
        # one event per character.
        if isinstance(data, typing.Sequence) and not isinstance(data, str):
            for item in data:
                yield self.to_event(item)
        else:
            yield self.to_event(data)

    @staticmethod
    def to_event(item):
        """
        Transform non-dicts into dicts (which represents splunk events)

        :return: An event which represents `item`
        :rtype: dict
        """
        return item if isinstance(item, typing.Mapping) else dict(value=item)
Beispiel #4
0
class TestSearchCommand(SearchCommand):
    """Search command declaring one Option per splunklib validator type.

    Options come in pairs: an optional ``<name>`` and a ``required_<name>``
    twin declared with ``require=True``. The boolean option has an extra
    required twin whose option name is overridden to ``foo``.
    NOTE(review): presumably a test fixture for option parsing/validation —
    confirm against the tests that use it.
    """

    boolean = Option(
        doc='''
        **Syntax:** **boolean=***<value>*
        **Description:** A boolean value''',
        validate=validators.Boolean())

    required_boolean = Option(
        doc='''
        **Syntax:** **boolean=***<value>*
        **Description:** A boolean value''',
        require=True, validate=validators.Boolean())

    # Option name overridden to 'foo' (attribute name differs).
    # NOTE(review): the help text still says boolean=; confirm whether foo= was intended.
    aliased_required_boolean = Option(
        doc='''
        **Syntax:** **boolean=***<value>*
        **Description:** A boolean value''',
        name='foo', require=True, validate=validators.Boolean())

    code = Option(
        doc='''
        **Syntax:** **code=***<value>*
        **Description:** A Python expression, if mode == "eval", or statement, if mode == "exec"''',
        validate=validators.Code())

    required_code = Option(
        doc='''
        **Syntax:** **code=***<value>*
        **Description:** A Python expression, if mode == "eval", or statement, if mode == "exec"''',
        require=True, validate=validators.Code())

    duration = Option(
        doc='''
        **Syntax:** **duration=***<value>*
        **Description:** A length of time''',
        validate=validators.Duration())

    required_duration = Option(
        doc='''
        **Syntax:** **duration=***<value>*
        **Description:** A length of time''',
        require=True, validate=validators.Duration())

    fieldname = Option(
        doc='''
        **Syntax:** **fieldname=***<value>*
        **Description:** Name of a field''',
        validate=validators.Fieldname())

    required_fieldname = Option(
        doc='''
        **Syntax:** **fieldname=***<value>*
        **Description:** Name of a field''',
        require=True, validate=validators.Fieldname())

    # validators.File() with its default arguments (no explicit mode).
    file = Option(
        doc='''
        **Syntax:** **file=***<value>*
        **Description:** Name of a file''',
        validate=validators.File())

    required_file = Option(
        doc='''
        **Syntax:** **file=***<value>*
        **Description:** Name of a file''',
        require=True, validate=validators.File())

    integer = Option(
        doc='''
        **Syntax:** **integer=***<value>*
        **Description:** An integer value''',
        validate=validators.Integer())

    required_integer = Option(
        doc='''
        **Syntax:** **integer=***<value>*
        **Description:** An integer value''',
        require=True, validate=validators.Integer())

    # validators.Map configured with the pairs foo=1, bar=2, test=3.
    map = Option(
        doc='''
        **Syntax:** **map=***<value>*
        **Description:** A mapping from one value to another''',
        validate=validators.Map(foo=1, bar=2, test=3))

    required_map = Option(
        doc='''
        **Syntax:** **map=***<value>*
        **Description:** A mapping from one value to another''',
        require=True, validate=validators.Map(foo=1, bar=2, test=3))

    # Pattern ddd-dd-dddd, labelled 'social security number' for the validator.
    match = Option(
        doc='''
        **Syntax:** **match=***<value>*
        **Description:** A value that matches a regular expression pattern''',
        validate=validators.Match('social security number', r'\d{3}-\d{2}-\d{4}'))

    # NOTE(review): unlike the other required_* twins, this help text names the
    # option required_match= rather than match=; confirm which is intended.
    required_match = Option(
        doc='''
        **Syntax:** **required_match=***<value>*
        **Description:** A value that matches a regular expression pattern''',
        require=True, validate=validators.Match('social security number', r'\d{3}-\d{2}-\d{4}'))

    optionname = Option(
        doc='''
        **Syntax:** **optionname=***<value>*
        **Description:** The name of an option (used internally)''',
        validate=validators.OptionName())

    required_optionname = Option(
        doc='''
        **Syntax:** **optionname=***<value>*
        **Description:** The name of an option (used internally)''',
        require=True, validate=validators.OptionName())

    regularexpression = Option(
        doc='''
        **Syntax:** **regularexpression=***<value>*
        **Description:** Regular expression pattern to match''',
        validate=validators.RegularExpression())

    required_regularexpression = Option(
        doc='''
        **Syntax:** **regularexpression=***<value>*
        **Description:** Regular expression pattern to match''',
        require=True, validate=validators.RegularExpression())

    set = Option(
        doc='''
        **Syntax:** **set=***<value>*
        **Description:** A member of a set''',
        validate=validators.Set('foo', 'bar', 'test'))

    required_set = Option(
        doc='''
        **Syntax:** **set=***<value>*
        **Description:** A member of a set''',
        require=True, validate=validators.Set('foo', 'bar', 'test'))

    class ConfigurationSettings(SearchCommand.ConfigurationSettings):
        # Overrides the base class's fix_up hook with a no-op.
        # NOTE(review): presumably to skip base-class checks on command_class — confirm.
        @classmethod
        def fix_up(cls, command_class):
            pass
Beispiel #5
0
class MACFormatCommand(StreamingCommand):
    """ Convert a given MAC address field to specified format.

    ##Syntax

    .. code-block::
        | macformat inputs=field-list outputs=field-list format=[cisco|dash|ieee|none]

    ## Description

    Convert the fields in the `inputs` field list to the ones in the `outputs` list; both lists are
    optional. The `inputs` list defaults to the `inputs` setting of the `[macformat]` stanza in
    app.conf (local overrides default), or `macaddress`. The `outputs` list is padded with the
    trailing fields of the `inputs` list if it is shorter than `inputs`.

    The `format` option is one of [cisco|dash|ieee|none]. It defaults to the `format` setting of
    the `[macformat]` stanza in app.conf, or `none`.

    If a value cannot be converted, the original value is copied to the output field and the
    error is logged.
    """
    format = Option(doc='''
        **Syntax:** **format=**`[cisco|dash|ieee|none]`
        **Description:** Format of the output MAC address. Defaults to `none`.''',
                    require=False,
                    validate=validators.Set('cisco', 'dash', 'ieee', 'none'))

    inputs = Option(doc='''
        **Syntax:** **inputs=***<field-list>*
        **Description:** A comma-delimited list of input fields to convert. Defaults to `macaddress`.''',
                    require=False,
                    validate=validators.List())

    outputs = Option(doc='''
        **Syntax:** **outputs=***<field-list>*
        **Description:** A comma-delimited list of fields for the results. Defaults to `inputs`.''',
                     require=False,
                     validate=validators.List())

    @staticmethod
    def _pad_outputs(inputs, outputs):
        """Return `outputs` padded with the trailing `inputs` it lacks (`inputs` itself if None).

        Builds a new list instead of extending `outputs` in place.
        """
        if outputs is None:
            return inputs
        if len(outputs) < len(inputs):
            return outputs + inputs[len(outputs):]
        return outputs

    def prepare(self):
        """ Prepare the options.

        Resolves the converter function from `format` (or the configured
        default) and fills in default input/output field lists.

        :return: :const:`None`
        :rtype: NoneType
        """
        # Converters are module-level functions named `_<format>`; `format` is
        # restricted by its Set validator.
        self.toform = globals()['_' + (self.format or self.def_format)]
        inputs = self.inputs
        if inputs is None:
            self.inputs = inputs = self.def_inputs
        self.outputs = outputs = self._pad_outputs(inputs, self.outputs)
        self.logger.debug(
            'MACFormatCommand.prepare: inputs = %s, outputs = %s', self.inputs,
            outputs)

    def stream(self, records):
        """Convert each input field of every record and write it to its output field."""
        toform = self.toform
        inputs = self.inputs
        outputs = self._pad_outputs(inputs, self.outputs)
        for record in records:
            self.logger.debug('MACFormatCommand: record = %s', record)
            # outputs is at least as long as inputs, so zip covers every input.
            for src, dest in zip(inputs, outputs):
                mac = record.get(src)
                if mac is None:
                    continue
                try:
                    record[dest] = toform(mac)
                except Exception as err:
                    # Best effort: keep the original value and log the reason.
                    # (Python 3 exceptions have no `.message` attribute.)
                    record[dest] = mac
                    self.logger.error('(input=%s) %s', src, err)
            yield record

    def __init__(self):
        StreamingCommand.__init__(self)
        appdir = path.dirname(path.dirname(__file__))
        defconfpath = path.join(appdir, "default", "app.conf")
        defconf = cli.readConfFile(defconfpath).get('macformat') or {}
        localconfpath = path.join(appdir, "local", "app.conf")
        localconf = (cli.readConfFile(localconfpath).get('macformat')
                     or {}) if path.exists(localconfpath) else {}
        # local/app.conf settings take precedence over default/app.conf.
        self.def_format = localconf.get('format') or defconf.get(
            'format') or 'none'
        inputs = localconf.get('inputs') or defconf.get('inputs')
        # Split on runs of commas/whitespace and drop empty names, so that
        # "a, b" yields ['a', 'b'] rather than ['a', '', 'b'].
        fields = [name for name in re.split(r'[\s,]+', inputs or '') if name]
        self.def_inputs = fields or ['macaddress']
Beispiel #6
0
class JsonFormatCommand(StreamingCommand):
    """ Format a JSON field (optionally renaming it) and report any parsing errors, if requested.

    ##Syntax

    .. code-block::
        jsonformat (indent=<int>)? (order=undefined|preserve|sort) (input_mode=json|python)? (errors=<field>)? (<field> (as <field>)?)*

    """
    indent = Option(doc="How many spaces for each indentation.",
                    require=False,
                    default=2,
                    validate=validators.Integer(0, 10))

    order = Option(
        doc=
        "Pick order options.  undefined (default), preserve, or sort.  Only impacts hash order",
        require=False,
        default="preserve",
        validate=validators.Set("undefined", "preserve", "sort"))

    errors = Option(doc="field name to capture any parsing error messages.",
                    require=False,
                    default=None,
                    validate=validators.Fieldname())

    input_mode = Option(
        doc=
        "Select an alternate input mode.  Supports 'json' and 'python' repr format "
        "(literals only).  In this mode, the 'preserve' order option will not work.",
        require=False,
        default="json",
        validate=validators.Set("json", "python"))

    output_mode = Option(
        doc=
        "Select an alternate output mode.  Supports 'json' (the default) and 'makeresults' "
        "which allows easy creation of run-anywhere sample of a json object.  "
        "You can paste the output to Splunk Answers when requesting help with JSON processing.",
        require=False,
        default="json",
        validate=validators.Set("json", "makeresults"))

    @staticmethod
    def handle_field_as(fieldnames):
        """ Convert a list of fields, which may include "a as b" style renaming, into a more usable
        output format.  The output is a list of tuples in the form of (src, dest) showing any rename
        mappings.  In the simple case, where no renaming occurs, src and dest are the same.
        """
        fields = fieldnames[:]
        fieldpairs = []
        while fields:
            f = fields.pop(0)
            # "src as dest" needs both the "as" keyword and a destination name.
            if len(fields) > 1 and fields[0].lower() == "as":
                fieldpairs.append((f, fields[1]))
                fields = fields[2:]
            else:
                fieldpairs.append((f, f))
        return fieldpairs

    def stream(self, records):
        """Reformat the selected JSON fields of each record.

        Fields default to `_raw`.  When `_raw` is rewritten and the record has a
        `linecount` field, that count is updated.  Parse errors are collected
        into the `errors` field if one was requested.
        """
        json_loads = json.loads
        json_dumps = partial(json.dumps, indent=self.indent)

        if self.order == "preserve":
            json_loads = partial(json.loads, object_pairs_hook=OrderedDict)
        elif self.order == "sort":
            json_dumps = partial(json.dumps,
                                 indent=self.indent,
                                 sort_keys=True)

        if self.input_mode == "python":
            json_loads = from_python

        if self.fieldnames:
            fieldpairs = self.handle_field_as(self.fieldnames)
        else:
            fieldpairs = [("_raw", "_raw")]

        self.logger.info("fieldnames={}".format(self.fieldnames))
        for (src_field, dest_field) in fieldpairs:
            if src_field != dest_field:
                self.logger.info("Mapping JSON field {} -> {}".format(
                    src_field, dest_field))
        self.logger.info("fieldpairs={}".format(fieldpairs))

        def output_json(json_string, src_field):
            # Normal mode.  Just load and dump json (src_field is unused here).
            data = json_loads(json_string)
            return json_dumps(data)

        def output_makeresults(json_string, src_field):
            # Build a "makeresults" (run-anywhere) output sample.  src_field is
            # passed explicitly instead of relying on the enclosing loop variable.
            quote_chars = ('\\', "\n", "\t", '"')  # Order matters: backslash first
            try:
                data = json_loads(json_string)
                json_min = json.dumps(data, indent=None, separators=(",", ":"))
                for char in quote_chars:
                    json_min = json_min.replace(char, "\\" + char)
                return '| makeresults | eval {}="{}"'.format(
                    src_field, json_min)
            except ValueError as e:
                return "ERROR:  {!r}   {}".format(json_string, e)

        # output_mode is restricted to these values by its Set validator.
        output = {
            "json": output_json,
            "makeresults": output_makeresults,
        }[self.output_mode]

        first_row = True
        linecount_set = False

        for record in records:
            errors = []
            for (src_field, dest_field) in fieldpairs:
                json_string = record.get(src_field, None)
                if isinstance(json_string, (list, tuple)):
                    # XXX: Add proper support for multivalue input fields.  For now, skip.
                    json_string = None
                if json_string:
                    try:
                        text = output(json_string, src_field)
                        record[dest_field] = text
                        # Handle special case for _raw message update
                        if dest_field == "_raw":
                            if "linecount" in record:
                                record["linecount"] = len(text.splitlines())
                                linecount_set = True
                    except ValueError as e:
                        # Exceptions have no `.message` attribute on Python 3.
                        if len(fieldpairs) > 1:
                            errors.append("Field {} error:  {}".format(
                                src_field, str(e)))
                        else:
                            errors.append(str(e))
                else:
                    if src_field != dest_field:
                        record[dest_field] = json_string
            if self.errors:
                record[self.errors] = errors or "none"

            # Make sure that all of our output fields are present on the first record, since this
            # dictates the possible return fields which cannot be updated later.
            if first_row:
                first_row = False
                needed_fields = [df for (sf, df) in fieldpairs]
                if linecount_set:
                    needed_fields.append("linecount")
                for f in needed_fields:
                    if f not in record:
                        record[f] = None

            yield record
Beispiel #7
0
class ReaperCommand(EventingCommand):
    """ Filters out noise from Splunk queries by leveraging the Threshing Floor
        API.
    ##Syntax
    .. code-block::
        reaper [type=<http|auth|generic>] [ports=<port>:<udp|tcp>[;<port>:<udp|tcp>...]]
    ##Description
    The :code:`reaper` command filters network security noise from HTTP logs,
    ssh access logs, and generic log files. If `type` is omitted it is `generic`
    when `ports` is given, otherwise it is guessed from the events.
    """

    BASE_URI = "https://api.threshingfloor.io"
    API_KEY = ""

    logtype = Option(doc='''**Syntax:** **type=***<event-type>*
        **Description:** The type of events you wish to reduce. Can be `http`, `auth`, or `generic`.''',
                     name='type',
                     validate=validators.Set('http', 'auth', 'generic'))

    ports = Option()

    def get_config(self, conf_file_name, section):
        """Return the options of `section` in `conf_file_name`, as merged by `btool <conf> list`.

        :return: dict mapping option name to raw string value.
        """
        env = dict(os.environ)
        splunk_home = env.get('SPLUNK_HOME', '/Applications/Splunk')
        btool = os.path.join(splunk_home, "bin", "btool")
        proc = subprocess.Popen([btool, conf_file_name, "list"],
                                stdout=subprocess.PIPE,
                                env=env)
        output, _ = proc.communicate()

        buf = StringIO.StringIO()
        buf.write(output)
        buf.seek(0)
        cfgparse = ConfigParser.RawConfigParser()
        cfgparse.readfp(buf)

        return dict((opt, cfgparse.get(section, opt))
                    for opt in cfgparse.options(section))

    def transform(self, events):
        """Yield only the events whose `_raw` survives the Threshing Floor reduction.

        :raises Exception: generic mode without a `ports` option.
        :raises TFException: the log type cannot be determined.
        """
        # We keep both the full events and their raw text for the library.
        dictEvent = []
        rawEvents = []
        for event in events:
            dictEvent.append(event)
            rawEvents.append(event['_raw'])

        if not rawEvents:
            # Nothing to classify or reduce.
            return

        # Only infer the type when the user did not supply one; previously a
        # user-specified type= was always overwritten by guessType.
        if self.logtype is None:
            if self.ports is not None:
                self.logtype = 'generic'
            else:
                self.logtype = self.guessType(rawEvents)

        if self.logtype == 'generic' and not self.ports:
            raise Exception("Generic mode requires the port option.")

        # Always bound, even when ports is missing or empty.
        ports = self.ports.split(";") if self.ports else None

        # Initialize the correct log type
        if self.logtype == 'auth':
            analyzed = TFAuthLog(rawEvents, self.API_KEY, self.BASE_URI)
        elif self.logtype == 'http':
            analyzed = TFHttpLog(rawEvents, self.API_KEY, self.BASE_URI)
        elif self.logtype == 'generic':
            analyzed = TFGenericLog(rawEvents, ports, self.API_KEY,
                                    self.BASE_URI)
        else:
            raise TFException("Failed to parse the query.")

        # Walk the events and the reduced lines in lockstep. A sentinel marks
        # exhaustion so no StopIteration escapes this generator (PEP 479 turns
        # that into RuntimeError on Python 3.7+).
        reduced = analyzed.reduce()
        exhausted = object()
        reducedItem = next(reduced, exhausted)
        for event in dictEvent:
            if reducedItem is exhausted:
                break
            if event['_raw'] == reducedItem:
                yield event
                reducedItem = next(reduced, exhausted)

    def guessType(self, logfile, baseName=None):
        """Guess 'http' or 'auth' from a sample line of `logfile` (a list of raw lines).

        :raises TFException: if `logfile` is empty or no type matches.
        """
        REGEX_HTTP = r"^\[(?P<timestamp>.+)?\]\s\"(?P<request>.+?)\"\s(?P<responseCode>\d+)\s(?P<size>\d+)(?P<combinedFields>.*)"

        if not logfile:
            raise TFException(
                "Unable to automatically identify the log type. Please specify a type with the type= option."
            )

        # Sample the 11th line (or the last one for shorter inputs).
        logline = logfile[min(10, len(logfile) - 1)]
        splitLine = logline.split()

        # See if it's http
        if re.search(REGEX_HTTP, " ".join(splitLine[3:])):
            return 'http'

        # See if it's auth: the line must start with a syslog-style
        # "Mon DD HH:MM:SS" stamp. The year is only needed for parsing; a leap
        # year lets "Feb 29" stamps parse.
        try:
            stamp = time.strptime(" ".join(splitLine[0:3]) + " " + "2000",
                                  "%b %d %H:%M:%S %Y")
            if int(time.mktime(stamp)) > 0:
                return 'auth'
        except (ValueError, OverflowError):
            pass

        # If we haven't returned by now, we can't figure out the type
        raise TFException(
            "Unable to automatically identify the log type. Please specify a type with the type= option."
        )

    def __init__(self):
        EventingCommand.__init__(self)

        conf = self.get_config('threshingfloor', 'api-config')
        # Keep the class default URI when the config does not set one
        # (previously it became None).
        self.BASE_URI = conf.get('base_uri', self.BASE_URI)
        self.API_KEY = conf.get('api_key', None)
Beispiel #8
0
class VirusTotalCommand(StreamingCommand):
    """Enrich events with VirusTotal reports for a hash, ip, url, or domain field.

    Exactly one of `hash=`, `ip=`, `url=`, `domain=` must name the field to look up.
    """
    hash = Option(
        doc='''
        **Syntax:** **hash=***<fieldname>*
        **Description:** Name of the field which contains the hash''',
        require=False, validate=validators.Fieldname())

    ip = Option(
        doc='''
        **Syntax:** **ip=***<fieldname>*
        **Description:** Name of the field which contains the ip''',
        require=False, validate=validators.Fieldname())

    url = Option(
        doc='''
        **Syntax:** **url=***<fieldname>*
        **Description:** Name of the field which contains the url''',
        require=False, validate=validators.Fieldname())

    domain = Option(
        doc='''
            **Syntax:** **domain=***<fieldname>*
            **Description:** Name of the field which contains the domain''',
        require=False, validate=validators.Fieldname())

    mode = Option(
        doc='''
        **Syntax:** **mode=***<json|v1>*
        **Description:** Output mode. `v1` (default) writes each result field prefixed with `vt_`; `json` writes result fields without a prefix.''',
        require=False, default="v1", validate=validators.Set('json', 'v1'))

    rescan = Option(
        doc='''
            **Syntax:** **rescan=***<bool>*
            **Description:** bool. If false, will not rescan rows that already have vt_* fields.
            If true, will scan all hashes. Uses vt_resource field to determine if info exists. (Defaults True)''',
        require=False, default=True, validate=validators.Boolean())

    def correlate_vt(self, records):
        """
        Incorporate VT information into the events provided in 'records'
        :param records: The records to be supplemented with added information
        :return: None
        """
        # Ensure every output field exists on every record.
        for record in records:
            for k in ALL_OUTPUT_FIELDS[self.report_type]:
                if k not in record:
                    record[k] = ""

        expected_min_resource_len = {
            "hash": 20,
            "ip": 7,
            "url": 4,
            "domain": 4,
        }.get(self.report_type, 0)

        # Ignore records that already have info if we are not rescanning
        if not self.rescan:
            records = [record for record in records
                       if "vt_resource" not in record
                       or not isinstance(record['vt_resource'], str)
                       or len(record['vt_resource']) < expected_min_resource_len]

        already_warned = False
        # Map each resource to ALL records carrying it (several events in one
        # batch may share a resource; previously only the last got results),
        # and query each resource once, in first-seen order.
        records_dict = {}
        resources = []
        for record in records:
            # The following checks makes sure the field is a string
            value = record[self.matching_field] if self.matching_field in record else None
            if isinstance(value, str) and len(value) >= expected_min_resource_len \
                    and value == value.strip():
                if value not in records_dict:
                    records_dict[value] = []
                    resources.append(value)
                records_dict[value].append(record)
            elif not already_warned:
                self.write_warning("VirusTotal Command: Warning: \
                One or more events had bad data or no data in your input field. \
                Normalize the field in your data to correct this issue. Note: this \
                is often caused by empty values, mvfield values, or values with leading or trailing whitespaces. \
                Warning: Unaddressed data quality issues can additionally cause subsequent failures with lookups. "
                                   )
                already_warned = True

        # If there are no hashes to scan, exit.
        if len(resources) == 0:
            self.logger.debug("Not querying VT API with %d resources" % len(resources))
            return
        self.logger.debug("Querying VT API with %d resources (%s)" % (len(resources), self.report_type))

        query_functions = {
            "hash": _query_virustotal_hashes,
            "ip": _query_virustotal_ips,
            "url": _query_virustotal_urls,
            "domain": _query_virustotal_domains,
        }
        query = query_functions[self.report_type]

        attempts = 0
        # Query the API, waiting out rate limits for up to 10 attempts.
        while True:
            try:
                attempts += 1
                vt_res = query(resources, mode=self.mode)
                break
            except VTRequestLimitExceededException:
                # Always log to the search.log file
                self.logger.warning("VirusTotal Request Limit Exceeded. Waiting 1 minute before resuming queries.")

                # End sleep in 60 seconds
                sleep_end_time = time.time() + 60
                while time.time() < sleep_end_time:
                    # Check if user terminated the job
                    self.termination_helper.check_termination(now=True)
                    # Sleep at most 5 seconds, and at least enough seconds to reach end of timeout period
                    time.sleep(max(0.0, min(5.0, sleep_end_time - time.time())))
            except Exception as e:
                # Python 3 exceptions have no `.message`; format the exception itself.
                self.error_exit(e, "Unexpected error when querying VirusTotal API: %s" % e)
            if attempts > 10:
                self.error_exit(None, "Failed to retrieve results from VirusTotal after 10 retries. Aborting.")

        # Verify that we got expected number of results (one per unique resource)
        if len(vt_res) != len(records_dict):
            self.error_exit(None, "VirusTotal returned %d results, but %d were expected. "
                                  "Is the batch_size value set too high for this specific key (app setup)?"
                            % (len(vt_res), len(records_dict)))

        # Place values from results into every row carrying that resource.
        for resource, result in vt_res.items():
            for record in records_dict[resource]:
                for vtk, vtv in result.items():
                    if self.mode == 'v1':
                        record["vt_%s" % vtk] = vtv
                    elif self.mode == 'json':
                        record[vtk] = vtv

    def prepare(self):
        """
        Called by splunkd before the command executes.
        Used to get configuration data for this command from splunk.
        Sets the module globals API_KEY, BATCHING, CMD_TIMEOUT and (if enabled) PROXY.
        :return: None
        """
        global API_KEY, BATCHING, CMD_TIMEOUT, PROXY

        self.logger.debug('VirusTotalCommand: %s', self)  # logs command line

        proxy_password = None

        # Get the API key from Splunkd's REST API
        # Also get proxy password if configured
        for passwd in self.service.storage_passwords:  # type: StoragePassword
            if (passwd.realm is None or passwd.realm.strip() == "") and passwd.username == "virustotal":
                API_KEY = passwd.clear_password
            if (passwd.realm is None or passwd.realm.strip() == "") and passwd.username == "vt_proxy":
                proxy_password = passwd.clear_password

        # Verify we got the key
        if API_KEY is None or API_KEY == "defaults_empty":
            self.error_exit(None, "No API key found for VirusTotal. Re-run the app setup for the TA.")

        # Helper method to get config settings from virustotal.conf with error checking
        def get_safely(stanza, key, thetype):
            try:
                return thetype(self.service.confs[str('virustotal')][str(stanza)][str(key)])
            except Exception:
                self.error_exit(sys.exc_info(), "VirusTotal command: Error while processing %s. "
                                                "Ensure that the batch_size variable is correctly configured (app setup)"
                                                " and that defaults.conf was not damaged. "
                                                "Error: %s" % (key, sys.exc_info()[1]))

        # Configure some common settings

        BATCHING = get_safely('settings', 'batch_size', int)
        CMD_TIMEOUT = get_safely('settings', 'cmd_timeout', int)

        # Configure proxy settings (if the user enabled the proxy in setup)

        # Following "best practice" of using "disabled" instead of "enabled" leads to bad-looking logic...
        if not get_safely('proxy', 'disabled', validators.Boolean()):
            # Groups: 1 = scheme, 2 = host, 3 = ":port" (optional), 4 = port.
            # This regex is here mainly to sanitise/validate user input.
            match = re.match(r'^(https?|socks5)://([^@:#\$ _]+)(:(\d+))?$', get_safely('proxy', 'url', str))

            import requests
            self.logger.debug('requests version: %s', requests.__version__)

            if match is None:
                self.error_exit(None, "VirusTotal Command: Proxy settings appear to be incorrect. "
                                      "Go to the App Setup page and ensure that the URL for the proxy is configured correctly,")
                return

            username = get_safely('proxy', 'username', str)

            # NOTE(review): credentials are not URL-quoted, and proxy_password may
            # still be None here if no vt_proxy secret was stored.
            if username is not None and len(username) > 0:
                url = '%s://%s:%s@%s' % (match.group(1), username, proxy_password, match.group(2))
            else:
                url = '%s://%s' % (match.group(1), match.group(2))

            if match.group(4) is not None:
                url = url + ":%s" % match.group(4)

            PROXY = {
                'https': url
            }

    def stream(self, records):
        """
        Hooking point for splunk.
        :param records: The generator function provided by Splunk which will provide all the events.
        :return: yields events one at a time
        """
        self.termination_helper = TerminationHelper(self.service, self)

        self.logger.debug("VirusTotalCommand: BATCHING = %d" % BATCHING)
        self.logger.debug("VirusTotalCommand: RESCAN = %s" % self.rescan)

        # Exactly one report-type option must name the field to match on.
        self.matching_field = None
        self.report_type = None
        for rt in REPORT_TYPES:
            if getattr(self, rt) is not None:
                if self.report_type is not None:
                    self.error_exit(None, "VirusTotal Command: Getting multiple types of reports in a single search is not supported. "
                                          "Specify only one of 'hash=', 'ip=', 'url=', or 'domain=' and try again.")
                    return
                self.report_type = rt
                self.matching_field = getattr(self, rt)
        if self.report_type is None:
            self.error_exit(None, "VirusTotal Command: No field specified for matching. "
                                  "Specify one of 'hash=', 'ip=', 'url=', or 'domain=' and try again.")
            return

        # Process the events in batches of BATCHING records.
        try:
            while True:
                _records = batch(records, n=BATCHING)
                if len(_records) == 0:
                    break
                self.termination_helper.check_termination()
                self.correlate_vt(_records)
                for record in _records:
                    yield record
        except SplunkJobTerminatedException as sjt:
            warning = "VirusTotal Command: Forcing exit. Reason: Parent job termination detected. " \
                      "Parent job state: %s" % sjt.state
            self.write_warning(warning)
            self.logger.warning(warning)
            return
        except CustomCommandTimeoutException as cct:
            warning = "VirusTotal Command: Forcing exit. Reason: Internal timeout reached. " \
                      "If necessary, the timeout can be increased on the app setup page. " \
                      "Command has been running for: %d seconds" % cct.runtime
            self.write_warning(warning)
            self.logger.warning(warning)
            return