async def _get_documents(self, sha, sca, cmd, args, fmt, otpt):
    """Collect the stored command documents matching the given filters.

    Args:
        sha: Value passed as the ``sha256_digest`` filter of the query.
        sca: Value passed as the ``scale`` filter of the query.
        cmd: Value passed as the ``command`` filter of the query.
        args: Value passed as the ``args`` filter of the query.
        fmt: Output format handed to the scale's formatter and stored
            under ``format`` in every returned document.
        otpt: When truthy, the formatted command output is attached to
            each document under ``output``.

    Returns:
        list: The dumped command documents in cursor order (the query is
        sorted on ``timestamp``). Documents whose scale/commands cannot
        be resolved, or whose output cannot be formatted, are skipped.
    """
    documents = []
    cur = db.async_command_collection.select_many(
        sha256_digest=sha, scale=sca, command=cmd, args=args,
        sort="timestamp")
    while await cur.fetch_next:
        doc = schema.CommandSchema().load(cur.next_object())
        try:  # Ignore output for missing scales and/or commands
            scale = scale_manager.get_scale(doc['scale'])
            commands = scale_manager.get_component(
                scale, enums.ScaleComponent.COMMANDS)
        except Exception as err:
            print("%s - %s" % (doc['scale'], err))  # TODO: Output to log
            continue
        output = None
        if '_output_id' in doc and doc['_output_id']:
            output = await db.async_command_output_collection.get(
                doc['_output_id'])
        doc = schema.CommandSchema().dump(doc)
        try:
            if otpt:
                # Format with the command recorded on the document itself:
                # ``cmd`` is only the query filter, so it does not identify
                # the command of each result when no filter was supplied.
                doc['output'] = commands.snake.format(
                    fmt, doc['command'], output)
            doc['format'] = fmt
        except (SnakeError, TypeError) as err:
            print("%s - %s" % (doc['scale'], err))  # TODO: Output to log
            continue
        documents.append(doc)
    return documents
async def post(self, data):
    """Queue a single command for a sample and respond with its document.

    The request is answered with a warning (and no command document) when
    the sample is unknown, the scale/command cannot be resolved, the
    arguments fail validation, queueing fails, or the output cannot be
    formatted. Otherwise the dumped command document, including formatted
    output, is returned under the ``command`` key.
    """

    def reject(message, status):
        # Every failure path answers with a warning and closes the request.
        self.write_warning(message, status, data)
        self.finish()

    # Check that there is a file for this hash
    sample = await db.async_file_collection.select(data['sha256_digest'])
    if not sample:
        reject("no sample for given data", 404)
        return

    # Check scale support
    try:
        scale = scale_manager.get_scale(data['scale'], sample['file_type'])
        commands = scale_manager.get_component(
            scale, enums.ScaleComponent.COMMANDS)
        cmd = commands.snake.command(data['command'])
    except SnakeError as err:
        reject("%s" % err, 404)
        return

    # Validate arguments as to not waste users time, yes this is also done
    # on execution
    valid, checked_args = validate_args(cmd, data['args'])
    if not valid:
        # On failure the second element carries the validation message.
        reject(checked_args, 422)
        return
    data['args'] = checked_args

    # Queue command
    try:
        queued = await route_support.queue_command(data)
    except SnakeError as err:
        reject("%s" % err, 500)
        return

    record = schema.CommandSchema().load(queued)
    output_id = record['_output_id']
    output = None
    if output_id:
        output = await db.async_command_output_collection.get(output_id)

    try:
        record['output'] = commands.snake.format(
            data['format'], record['command'], output)
        record['format'] = data['format']
    except SnakeError as err:
        reject("%s" % err, 404)
        return

    # Dump and finish
    self.jsonify({"command": schema.CommandSchema().dump(record)})
    self.finish()
async def get(self, data):
    """Respond with the stored command document matching the request.

    ``args`` is read from the raw query arguments and decoded as JSON.
    The request is answered with a warning when ``args`` is not valid
    JSON (422), when no matching command exists (404), when the stored
    command is in the error state (404), or when the scale/formatter
    fails (404). Otherwise the dumped document is returned under the
    ``command`` key, with formatted output attached when
    ``data['output']`` is truthy.
    """
    # NOTE: Tornado/Marshmallow does not like Dict in args, will have to parse manually
    # TODO: Use marshmallow validation
    raw_args = self.request.arguments.get('args')
    if raw_args:
        try:
            data['args'] = json.loads(raw_args[0])
        except ValueError as err:
            # json.JSONDecodeError and UnicodeDecodeError both derive from
            # ValueError; a malformed client value is a request error, not
            # a server failure.
            self.write_warning("invalid args - %s" % err, 422, data)
            self.finish()
            return
    else:
        data['args'] = {}

    document = await db.async_command_collection.select(
        data['sha256_digest'], data['scale'], data['command'], data['args'])
    if not document:
        self.write_warning("no output for given data", 404, data)
        self.finish()
        return

    if document['status'] == enums.Status.ERROR:
        self.write_warning("%s" % document['output'], 404, data)
        self.finish()
        return

    document = schema.CommandSchema().load(document)
    output = None
    if document['_output_id']:
        output = await db.async_command_output_collection.get(
            document['_output_id'])
    try:
        scale = scale_manager.get_scale(data['scale'])
        commands = scale_manager.get_component(
            scale, enums.ScaleComponent.COMMANDS)
        if data['output']:
            document['output'] = commands.snake.format(
                data['format'], document['command'], output)
        document['format'] = data['format']
    except (SnakeError, TypeError) as err:
        self.write_warning("%s" % err, 404, data)
        self.finish()
        return

    document = schema.CommandSchema().dump(document)
    self.jsonify({'command': document})
    self.finish()
async def execute_autoruns(sha256_digest, file_type, mime_type):
    """Find and queue autoruns for a given file (sha256_digest).

    Nothing is queued unless the ``command_autoruns`` configuration entry
    is truthy. Each autorun returned for the file type is queued as an
    asynchronous command; an autorun that names a mime type is only
    queued when that mime type equals ``mime_type``.

    Args:
        sha256_digest (str): The hash of the file to execute the autoruns on.
        file_type (:obj:`FileType`): The file type used to help apply autoruns.
        mime_type (str): The mime type used to help apply autoruns.
    """
    if not snake_config['command_autoruns']:
        return

    for mod, cmd, mime in scale_manager.get_autoruns(file_type=file_type):
        # Autoruns without a mime restriction always apply.
        if not mime or mime == mime_type:
            queued = schema.CommandSchema().load({
                'sha256_digest': sha256_digest,
                'scale': mod,
                'command': cmd,
                'asynchronous': True
            })
            await queue_command(queued)
def execute_command(command_schema):
    """Execute the command on the celery worker

    This is the task used by celery for the workers. It will execute the
    command and update the database as required.

    The command document is first stored with a start time and the
    running status. Whatever the outcome, the ``finally`` clause then
    stores the output (or an ``{'error': ...}`` payload), the end time,
    the final status and the new ``_output_id``.

    Args:
        command_schema (:obj:`CommandSchema`): The command schema to execute.
    """
    from snake.config import snake_config  # XXX: Reload config, bit hacky but required
    with pymongo.MongoClient(snake_config['mongodb']) as connection:
        # Create the collections before entering the ``try``: the
        # ``finally`` clause uses both unconditionally, so building them
        # inside the ``try`` would turn a constructor failure into an
        # unbound-name error that hides the real one.
        command_collection = command.CommandCollection(connection.snake)
        command_output_collection = command.CommandOutputCollection(
            connection.snake)
        try:
            # NOTE: We assume the _output_id is always NULL!
            command_schema['start_time'] = datetime.utcnow()
            command_schema['status'] = enums.Status.RUNNING
            command_schema = schema.CommandSchema().dump(command_schema)
            command_collection.update(command_schema['sha256_digest'],
                                      command_schema['scale'],
                                      command_schema['command'],
                                      command_schema['args'],
                                      command_schema)
            command_schema = schema.CommandSchema().load(command_schema)
            scale_manager_ = scale_manager.ScaleManager(
                [command_schema['scale']])
            scale = scale_manager_.get_scale(command_schema['scale'])
            commands = scale_manager_.get_component(
                scale, enums.ScaleComponent.COMMANDS)
            cmd = commands.snake.command(command_schema['command'])
            output = cmd(args=command_schema['args'],
                         sha256_digest=command_schema['sha256_digest'])
            command_schema['status'] = enums.Status.SUCCESS
        except error.CommandWarning as err:
            output = {'error': str(err)}
            command_schema['status'] = enums.Status.FAILED
            app_log.warning(err)
        except (error.SnakeError, error.MongoError, TypeError) as err:
            output = {'error': str(err)}
            command_schema['status'] = enums.Status.FAILED
            app_log.error(err)
        except (exceptions.SoftTimeLimitExceeded,
                exceptions.TimeLimitExceeded, BrokenPipeError) as err:
            output = {'error': 'time limit exceeded'}
            command_schema['status'] = enums.Status.FAILED
            app_log.exception(err)
        except Exception as err:
            output = {'error': 'a server side error has occurred'}
            command_schema['status'] = enums.Status.FAILED
            app_log.exception(err)
        else:
            # Test serialising of scale output as it could fail and we need to catch that
            try:
                json.dumps(output)
            except (TypeError, ValueError) as err:
                # json.dumps raises TypeError for unsupported types and
                # ValueError for circular references. Either way the stored
                # output is an error payload, so the status must not stay
                # at SUCCESS.
                output = {
                    'error':
                    'failed to serialize scale output - {}'.format(err)
                }
                command_schema['status'] = enums.Status.FAILED
                app_log.error(err)
        finally:
            command_schema['end_time'] = datetime.utcnow()
            command_schema = schema.CommandSchema().dump(command_schema)
            _output_id = command_output_collection.put(
                command_schema['command'],
                bytes(json.dumps(output), 'utf-8'))
            command_schema['_output_id'] = _output_id
            command_collection.update(command_schema['sha256_digest'],
                                      command_schema['scale'],
                                      command_schema['command'],
                                      command_schema['args'],
                                      command_schema)
async def post(self, data):  # pylint: disable=too-many-locals, too-many-branches, too-many-statements
    """Validate and queue a batch of commands.

    ``data`` is iterated as a sequence of command entries, each read for
    ``scale``, ``command``, ``args`` and ``sha256_digests``. A digest may
    be a sample hash, ``all`` (every stored sample) or ``all:file`` /
    ``all:memory`` (every stored sample of that file type).

    Nothing is queued unless every entry passes validation. The request
    is answered with a warning when the body is empty (422), a
    scale/command cannot be resolved, arguments fail validation (422),
    samples are missing (404) or a scale does not support the requested
    samples (422). On success the dumped command documents are returned
    under the ``commands`` key.
    """
    # XXX: Needs a major clean/rework
    if not data:
        self.write_warning("commands - no request body found", 422, data)
        self.finish()
        return

    # Find the commands and validate their arguments
    for d in data:
        # Find the command
        try:
            s = scale_manager.get_scale(d['scale'])
            c = scale_manager.get_component(s, enums.ScaleComponent.COMMANDS)
            cmd = c.snake.command(d['command'])
        except ScaleError as err:
            self.write_warning(err.message, err.status_code)
            self.finish()
            return

        result, args = validate_args(cmd, d['args'])
        if not result:
            self.write_warning(
                self.json_decode(args.replace("'", '"')), 422, data)
            self.finish()
            return
        d['args'] = args

    # Names accepted after the 'all:' prefix.
    file_types = {
        'file': enums.FileType.FILE,
        'memory': enums.FileType.MEMORY,
    }

    # Validate hashes and validate them against scales
    missing = []
    unsupported = []
    for d in data:
        s = scale_manager.get_scale(d['scale'])
        for sha in d['sha256_digests']:
            selector = sha.lower()
            if selector == 'all':
                # An empty ``supports`` means "no restriction" (see the
                # branches below), so only a scale that restricts itself to
                # a subset of the file types cannot run on every sample.
                if s.supports and len(s.supports) != len(
                        list(enums.FileType)):
                    unsupported += [d]
                break
            elif selector[:4] == 'all:':
                ft = file_types.get(selector.split(':')[1])
                if ft is None or (s.supports and ft not in s.supports):
                    unsupported += [(sha, s.name)]
                break
            else:
                document = await db.async_file_collection.select(sha)
                if not document:
                    missing += [d]
                elif s.supports and document[
                        'file_type'] not in s.supports:  # Check scale support
                    unsupported += [d]
    if missing:
        self.write_warning("commands - no sample(s) for given data", 404,
                           missing)
        self.finish()
        return
    if unsupported:
        self.write_warning("commands - command unsupported for given data",
                           422, unsupported)
        self.finish()
        return

    # Queue commands
    documents = []
    for d in data:
        # Everything except the digest list is shared by each queued command.
        cmd_dict = {k: v for k, v in d.items() if k != 'sha256_digests'}
        cmd_dict['asynchronous'] = True

        for sha in d['sha256_digests']:
            selector = sha.lower()
            if selector == 'all':
                cursor = db.async_file_collection.select_all()
            elif selector[:4] == 'all:':
                name = selector.split(':')[1]
                # Unknown names were rejected during validation above.
                cursor = db.async_file_collection.select_many(
                    file_type=file_types.get(name, name))
            else:
                cursor = None

            if cursor is None:
                cmd_dict['sha256_digest'] = sha
                cmd_d = schema.CommandSchema().load(cmd_dict)
                documents += [await route_support.queue_command(cmd_d)]
                continue

            while await cursor.fetch_next:
                cmd_dict['sha256_digest'] = cursor.next_object(
                )['sha256_digest']
                cmd_d = schema.CommandSchema().load(cmd_dict)
                documents += [await route_support.queue_command(cmd_d)]
            # An 'all' selector covers the entry; ignore remaining digests.
            break

    # Dump and finish
    documents = schema.CommandSchema(many=True).load(documents)
    documents = schema.CommandSchema(many=True).dump(documents)
    self.jsonify({"commands": documents})
    self.finish()
def test_execute_command(mocker):
    """
    Test the execute_command function

    The mongo client, the two command collections, the scale manager and
    ``json.dumps`` are replaced with local doubles. The success path is
    run first, then each failure class is provoked by a scale manager
    whose constructor raises.
    """
    # NOTE: The setup probably warrant tests in themselves but this is better than nothing ;)
    base_data = schema.CommandSchema().load({
        'sha256_digest': 'abcd',
        'scale': 'abcd',
        'command': 'abcd'
    })

    class DataBase():  # pylint: disable=too-few-public-methods
        def __init__(self):  # pylint: disable=unused-argument
            self.data = schema.CommandSchema().dump(base_data)
            self.output = ''

    database = DataBase()

    class CommandCollection():  # pylint: disable=too-few-public-methods, no-self-use
        def __init__(self, db):  # pylint: disable=unused-argument
            self.db = db

        def update(self, sha256_digest, scale, command, *args_and_data):  # pylint: disable=unused-argument
            # The document is always the last positional argument; accepting
            # the optional ``args`` key in front of it keeps this double
            # callable the way execute_command calls the real collection.
            self.db.data = args_and_data[-1]

    class CommandOutputCollection():  # pylint: disable=too-few-public-methods, no-self-use
        def __init__(self, db):  # pylint: disable=unused-argument
            self.db = db

        def put(self, file_name, data):  # pylint: disable=unused-argument
            self.db.output = data

    class MongoClient():  # pylint: disable=too-few-public-methods
        class Snake:
            def __init__(self, db):  # pylint: disable=unused-argument
                self.snake = db

        def dummy(self, *args, **kwargs):  # pylint: disable=unused-argument
            return self.snake

        __enter__ = dummy

        def __exit__(self, *args, **kwargs):  # pylint: disable=unused-argument
            # Must be falsy: a truthy return would swallow any exception
            # escaping execute_command and hide it from the test.
            return False

        def __init__(self, db):  # pylint: disable=unused-argument
            self.snake = self.Snake(database)

    class ScaleManagerCW():  # pylint: disable=too-few-public-methods
        def __init__(self, *args, **kwargs):  # pylint: disable=unused-argument
            raise error.CommandWarning('error')

    class ScaleManagerSE():  # pylint: disable=too-few-public-methods
        def __init__(self, *args, **kwargs):  # pylint: disable=unused-argument
            raise error.SnakeError('error')

    class ScaleManagerTE():  # pylint: disable=too-few-public-methods
        def __init__(self, *args, **kwargs):  # pylint: disable=unused-argument
            raise BrokenPipeError('error')

    class ScaleManagerE():  # pylint: disable=too-few-public-methods
        def __init__(self, *args, **kwargs):  # pylint: disable=unused-argument
            raise Exception('error')

    def dumps(data):
        try:
            return str(data).replace("\'", '\"')
        except Exception as err:
            return '{"dummy": "%s"}' % err

    mocker.patch('json.dumps', dumps)
    mocker.patch('pymongo.MongoClient', MongoClient)
    mocker.patch('snake.core.scale_manager.ScaleManager')
    mocker.patch('snake.engines.mongo.command.CommandCollection',
                 CommandCollection)
    mocker.patch('snake.engines.mongo.command.CommandOutputCollection',
                 CommandOutputCollection)

    # Test success
    data = schema.CommandSchema().dump(base_data)
    celery.execute_command(data)
    assert database.data['status'] == 'success'

    # Cause, in order: command warning, snake error, timeout error and
    # general error. Each must end in a failed status with an error payload.
    failing_managers = (ScaleManagerCW, ScaleManagerSE, ScaleManagerTE,
                        ScaleManagerE)
    for failing_manager in failing_managers:
        mocker.patch('snake.core.scale_manager.ScaleManager',
                     failing_manager)
        data = schema.CommandSchema().dump(base_data)
        celery.execute_command(data)
        assert database.data['status'] == 'failed'
        output = database.output
        if isinstance(output, bytes):
            output = output.decode('utf-8')
        assert 'error' in json.loads(output)
def __init__(self):  # pylint: disable=unused-argument
    """Start with the dumped base command document and an empty output."""
    dumped = schema.CommandSchema().dump(base_data)
    self.data = dumped
    self.output = ''
async def queue_command(data):
    """Queue commands for execution

    This will queue commands for execution on the celery workers.

    If a matching command is already stored with the running status it is
    returned as is and nothing is queued. Otherwise the stored command is
    replaced (or inserted when absent), any previous output is deleted,
    and the task is submitted. For a non-asynchronous command the call
    waits for the task to finish.

    Note:
        The returned command schema will reflect the status of the queued
        command.

    Args:
        data (:obj:`CommandSchema`): The command to queue for execution.

    Returns:
        :obj:`CommandSchema`: The command schema with updates

    Raises:
        SnakeError: If a non-asynchronous task did not finish successfully;
            the stored command is marked as failed first.
    """
    # The latest execution always wins, thus we replace the current one in the db
    document = await db.async_command_collection.select(
        data['sha256_digest'], data['scale'], data['command'], data['args'])
    if document:
        if 'status' in document and document['status'] == enums.Status.RUNNING:
            return schema.CommandSchema().dump(
                schema.CommandSchema().load(document))
        _output_id = document.get('_output_id')
        data['timestamp'] = datetime.utcnow()
        data = schema.CommandSchema().dump(data)
        await db.async_command_collection.replace(data['sha256_digest'],
                                                  data['scale'],
                                                  data['command'],
                                                  data['args'], data)
        # NOTE: We delete after the replace to try and prevent concurrent
        # reads to a file while it is being deleted
        if _output_id:
            await db.async_command_output_collection.delete(_output_id)
    else:
        # Save the command, this will be in a pending state
        data['timestamp'] = datetime.utcnow()
        data = schema.CommandSchema().dump(data)
        await db.async_command_collection.insert(data)

    data = schema.CommandSchema().load(data)
    # The hard limit leaves 30 seconds after the soft limit fires.
    task = celery.execute_command.apply_async(
        args=[data],
        time_limit=data['timeout'] + 30,
        soft_time_limit=data['timeout'])
    if data['asynchronous'] is not True:
        result = await celery.wait_for_task(task)
        if not task.successful():
            document = await db.async_command_collection.select(
                data['sha256_digest'], data['scale'], data['command'],
                data['args'])
            _output_id = document.get('_output_id')
            # Outputs are stored as JSON everywhere else, so the fallback
            # error must be valid JSON too (double-quoted strings).
            _new_output_id = await db.async_command_output_collection.put(
                document['command'],
                b'{"error": "worker failed please check log"}')
            document['_output_id'] = _new_output_id
            document['status'] = enums.Status.FAILED
            await db.async_command_collection.update(
                document['sha256_digest'], document['scale'],
                document['command'], data['args'], document)
            # Delete the old output only after the document points at the new one.
            if _output_id:
                await db.async_command_output_collection.delete(_output_id)
            raise error.SnakeError(result)

    return await db.async_command_collection.select(data['sha256_digest'],
                                                    data['scale'],
                                                    data['command'],
                                                    data['args'])