def test_iterkeys(self):
    """DiskDict.iterkeys() yields every stored key exactly once."""
    disk_dict = DiskDict()

    for key in ('a', 'b', 'c'):
        disk_dict[key] = 'abc'

    self.assertEqual(set(disk_dict.iterkeys()), set(['a', 'b', 'c']))
def __init__(self):
    # Serializes read-modify-write cycles against the DiskDicts
    self._db_lock = threading.RLock()

    # Disk-backed storage: one table for "similar" variants, one for
    # exact-match request hashes
    self._variants = DiskDict(table_prefix='variant_db')
    self._variants_eq = DiskDict(table_prefix='variant_db_eq')

    # Limits come from the user configuration, with framework defaults
    # as fallback
    self.params_max_variants = cf.cf.get('params_max_variants',
                                         PARAMS_MAX_VARIANTS)
    self.path_max_variants = cf.cf.get('path_max_variants',
                                       PATH_MAX_VARIANTS)
def __init__(self):
    # Serializes read-modify-write cycles against the DiskDicts
    self._db_lock = threading.RLock()

    # One disk-backed table per kind of variant tracking
    self._variants = DiskDict(table_prefix='variant_db')
    self._variants_eq = DiskDict(table_prefix='variant_db_eq')
    self._variants_form = DiskDict(table_prefix='variant_db_form')

    # All limits come straight from the framework configuration
    self.params_max_variants = cf.cf.get('params_max_variants')
    self.path_max_variants = cf.cf.get('path_max_variants')
    self.max_equal_form_variants = cf.cf.get('max_equal_form_variants')
def test_remove_table(self):
    """cleanup() must drop the DiskDict's backing table."""
    disk_dict = DiskDict()
    db = get_default_temp_db_instance()
    table_name = disk_dict.table_name

    # Table exists right after construction...
    self.assertTrue(db.table_exists(table_name))

    disk_dict.cleanup()

    # ...and is gone after cleanup()
    self.assertFalse(db.table_exists(table_name))
def test_get(self):
    """get() returns stored values and falls back to the default."""
    disk_dict = DiskDict()
    disk_dict[0] = 'abc'

    # Existing key: any provided default is ignored
    self.assertEqual(disk_dict.get(0), 'abc')
    self.assertEqual(disk_dict.get(0, 1), 'abc')

    # Missing key: the provided default is returned
    self.assertEqual(disk_dict.get(1, 2), 2)
def test_table_with_prefix(self):
    """A user-supplied prefix must end up in the table name."""
    prefix = 'unittest'
    disk_dict = DiskDict(prefix)

    self.assertIn(prefix, disk_dict.table_name)

    db = get_default_temp_db_instance()
    self.assertTrue(db.table_exists(disk_dict.table_name))

    disk_dict.cleanup()
    self.assertFalse(db.table_exists(disk_dict.table_name))
def __init__(self):
    GrepPlugin.__init__(self)

    # Disk-backed store of every HTML comment seen so far
    self._comments = DiskDict(table_prefix='html_comments')
    # Bloom filter used to avoid reporting the same finding twice
    self._already_reported = ScalableBloomFilter()
    # NOTE(review): presumably flipped once end() runs — confirm there
    self._end_was_called = False
def __init__(self):
    AuditPlugin.__init__(self)

    # Internal variables: the regex pulls the two 5-digit seeds out of a
    # payload, the DiskDict maps expected output to the mutant that sent it
    self._extract_expected_re = re.compile('[1-9]{5}')
    self._expected_mutant_dict = DiskDict(table_prefix='ssi')
    self._persistent_multi_in = None
def _init(self, maxsize):
    """
    Initialize the in-memory / on-disk storage and bookkeeping containers

    :param maxsize: The max size for the queue
    """
    self.queue_order = []
    self.hash_to_uuid = {}
    self.memory = {}
    self.disk = DiskDict(table_prefix='%sCachedQueue' % self.name)
def __init__(self):
    AuditPlugin.__init__(self)

    # Internal variables
    self._expected_res_mutant = DiskDict()
    self._freq_list = DiskList()

    # Matches the echoed seeds produced by the injected exec payload
    self._extract_results_re = re.compile(
        '<!--#exec cmd="echo -n (.*?);echo -n (.*?)" -->')
def _init(self, maxsize):
    """
    Set up the hybrid storage: a dict for the in-memory items, a DiskDict
    for the overflow, and the get/put pointers.

    :param maxsize: The max size for the queue
    """
    self.get_pointer = 0
    self.put_pointer = 0
    self.memory = {}
    self.disk = DiskDict(table_prefix='%sCachedQueue' % self.name)
def __init__(self, max_in_memory=50, table_prefix=None):
    """
    :param max_in_memory: The max number of items to keep in memory
    """
    assert max_in_memory > 0, 'In-memory items must be > 0'

    self._max_in_memory = max_in_memory
    self._disk_dict = DiskDict(table_prefix=self._get_table_prefix(table_prefix))
    self._in_memory = {}
    self._access_count = {}
def __init__(self, iterable=(), maxsize=-1):
    # NOTE(review): the hasattr guard suggests __init__ may run more than
    # once on the same instance — confirm with callers before removing
    if not hasattr(self, "data"):
        self.left = 0
        self.right = 0
        self.data = DiskDict(table_prefix="deque")

    self.maxsize = maxsize
    self.extend(iterable)
class VariantDB(object):
    """
    See the notes on PARAMS_MAX_VARIANTS and PATH_MAX_VARIANTS above.

    Also understand that we'll keep "dirty" versions of the
    references/fuzzable requests in order to be able to answer "False" to
    a call for need_more_variants in a situation like this:

        need_more_variants('http://foo.com/abc?id=32') --> True
        append('http://foo.com/abc?id=32')
        need_more_variants('http://foo.com/abc?id=32') --> False
    """
    # Headers that must not influence the request hash
    HASH_IGNORE_HEADERS = ('referer',)
    TAG = '[variant_db]'

    def __init__(self,
                 params_max_variants=PARAMS_MAX_VARIANTS,
                 path_max_variants=PATH_MAX_VARIANTS):
        # Disk-backed set of exact request hashes already seen
        self._variants_eq = DiskDict(table_prefix='variant_db_eq')
        # Disk-backed counter of variants per "cleaned" request pattern
        self._variants = DiskDict(table_prefix='variant_db')
        self.params_max_variants = params_max_variants
        self.path_max_variants = path_max_variants
        # Serializes the read-modify-write cycle in append()
        self._db_lock = threading.RLock()

    def cleanup(self):
        # Drop the tables which back both DiskDicts
        self._variants_eq.cleanup()
        self._variants.cleanup()

    def append(self, fuzzable_request):
        """
        :param fuzzable_request: The fuzzable request to record
        :return: True if we added a new fuzzable request variant to the DB,
                 False if no more variants are required for this fuzzable
                 request.
        """
        with self._db_lock:
            #
            # Is the fuzzable request already known to us? (exactly the same)
            #
            request_hash = fuzzable_request.get_request_hash(self.HASH_IGNORE_HEADERS)
            already_seen = self._variants_eq.get(request_hash, False)
            if already_seen:
                return False

            # Store it to avoid duplicated fuzzable requests in our framework
            self._variants_eq[request_hash] = True

            #
            # Do we need more variants for the fuzzable request? (similar match)
            #
            clean_dict_key = clean_fuzzable_request(fuzzable_request)
            count = self._variants.get(clean_dict_key, None)

            if count is None:
                # First time we see this pattern
                self._variants[clean_dict_key] = 1
                return True

            # We've seen at least one fuzzable request with this pattern...
            url = fuzzable_request.get_uri()
            has_params = url.has_query_string() or fuzzable_request.get_raw_data()

            # Choose which max_variants to use
            if has_params:
                max_variants = self.params_max_variants
            else:
                max_variants = self.path_max_variants

            if count >= max_variants:
                return False
            else:
                self._variants[clean_dict_key] = count + 1
                return True
class html_comments(GrepPlugin): """ Extract and analyze HTML comments. :author: Andres Riancho ([email protected]) """ HTML_RE = re.compile('<[a-zA-Z]*.*?>.*?</[a-zA-Z]>') INTERESTING_WORDS = ( # In English 'user', 'pass', 'xxx', 'fix', 'bug', 'broken', 'oops', 'hack', 'caution', 'todo', 'note', 'warning', '!!!', '???', 'shit', 'pass', 'password', 'passwd', 'pwd', 'secret', 'stupid', # In Spanish 'tonto', 'porqueria', 'cuidado', 'usuario', u'contraseña', 'puta', 'email', 'security', 'captcha', 'pinga', 'cojones', # some in Portuguese 'banco', 'bradesco', 'itau', 'visa', 'bancoreal', u'transfêrencia', u'depósito', u'cartão', u'crédito', 'dados pessoais' ) _multi_in = multi_in([' %s ' % w for w in INTERESTING_WORDS]) def __init__(self): GrepPlugin.__init__(self) # Internal variables self._comments = DiskDict() self._already_reported_interesting = ScalableBloomFilter() def grep(self, request, response): """ Plugin entry point, parse those comments! :param request: The HTTP request object. :param response: The HTTP response object :return: None """ if not response.is_text_or_html(): return try: dp = parser_cache.dpc.get_document_parser_for(response) except BaseFrameworkException: return for comment in dp.get_comments(): # These next two lines fix this issue: # audit.ssi + grep.html_comments + web app with XSS = false positive if request.sent(comment): continue if self._is_new(comment, response): self._interesting_word(comment, request, response) self._html_in_comment(comment, request, response) def _interesting_word(self, comment, request, response): """ Find interesting words in HTML comments """ comment = comment.lower() for word in self._multi_in.query(comment): if (word, response.get_url()) not in self._already_reported_interesting: desc = 'A comment with the string "%s" was found in: "%s".'\ ' This could be interesting.' 
desc = desc % (word, response.get_url()) i = Info('Interesting HTML comment', desc, response.id, self.get_name()) i.set_dc(request.get_dc()) i.set_uri(response.get_uri()) i.add_to_highlight(word) kb.kb.append(self, 'interesting_comments', i) om.out.information(i.get_desc()) self._already_reported_interesting.add((word, response.get_url())) def _html_in_comment(self, comment, request, response): """ Find HTML code in HTML comments """ html_in_comment = self.HTML_RE.search(comment) if html_in_comment and \ (comment, response.get_url()) not in self._already_reported_interesting: # There is HTML code in the comment. comment = comment.strip() comment = comment.replace('\n', '') comment = comment.replace('\r', '') comment = comment[:40] desc = 'A comment with the string "%s" was found in: "%s".'\ ' This could be interesting.' desc = desc % (comment, response.get_url()) i = Info('HTML comment contains HTML code', desc, response.id, self.get_name()) i.set_dc(request.get_dc()) i.set_uri(response.get_uri()) i.add_to_highlight(html_in_comment.group(0)) kb.kb.append(self, 'html_comment_hides_html', i) om.out.information(i.get_desc()) self._already_reported_interesting.add( (comment, response.get_url())) def _is_new(self, comment, response): """ Make sure that we perform a thread safe check on the self._comments dict, in order to avoid duplicates. """ with self._plugin_lock: #pylint: disable=E1103 comment_data = self._comments.get(comment, None) if comment_data is None: self._comments[comment] = [(response.get_url(), response.id), ] return True else: if response.get_url() not in [x[0] for x in comment_data]: comment_data.append((response.get_url(), response.id)) self._comments[comment] = comment_data return True #pylint: enable=E1103 return False def end(self): """ This method is called when the plugin wont be used anymore. 
:return: None """ inform = [] for comment in self._comments.iterkeys(): urls_with_this_comment = self._comments[comment] stick_comment = ' '.join(comment.split()) if len(stick_comment) > 40: msg = 'A comment with the string "%s..." (and %s more bytes)'\ ' was found on these URL(s):' om.out.information( msg % (stick_comment[:40], str(len(stick_comment) - 40))) else: msg = 'A comment containing "%s" was found on these URL(s):' om.out.information(msg % (stick_comment)) for url, request_id in urls_with_this_comment: inform.append('- ' + url + ' (request with id: ' + str(request_id) + ')') inform.sort() for i in inform: om.out.information(i) self._comments.cleanup() def get_long_desc(self): """ :return: A DETAILED description of the plugin functions and features. """ return """
def inner(self, *args, **kwargs):
    # Lazily initialize the cache dict the first time the wrapped
    # function runs on this instance.
    # NOTE(review): 'key_set' presumably tracks the keys stored in the
    # DiskDict — confirm against the wrapped function's usage
    if not hasattr(self, 'disk_cache'):
        self.disk_cache = {'key_set': set(),
                           'disk_cache': DiskDict('rsp_parser')}
    return func(self, *args, **kwargs)
class CachedDiskDict(object):
    """
    This data structure keeps the `max_in_memory` most frequently accessed
    keys in memory and stores the rest on disk.

    It is ideal for situations where a DiskDict is frequently accessed,
    fast read / writes are required, and items can take considerable
    amounts of memory.
    """

    # Unique sentinel for get(). The previous implementation used the
    # magic number -456 compared with `is not`, which (a) relies on
    # CPython int-identity details and (b) makes get(key, -456) raise
    # KeyError instead of returning -456. A fresh object() can never
    # collide with a caller-supplied default.
    _MISSING = object()

    def __init__(self, max_in_memory=50, table_prefix=None):
        """
        :param max_in_memory: The max number of items to keep in memory
        """
        assert max_in_memory > 0, 'In-memory items must be > 0'

        table_prefix = self._get_table_prefix(table_prefix)

        self._max_in_memory = max_in_memory
        self._disk_dict = DiskDict(table_prefix=table_prefix)
        self._in_memory = dict()
        self._access_count = Counter()

    def cleanup(self):
        """Drop the table which backs the on-disk part of the dict"""
        self._disk_dict.cleanup()

    def _get_table_prefix(self, table_prefix):
        """
        :return: A unique table prefix, embedding the user-supplied one
                 when available.
        """
        if table_prefix is None:
            table_prefix = 'cached_disk_dict_%s' % rand_alpha(16)
        else:
            args = (table_prefix, rand_alpha(16))
            table_prefix = 'cached_disk_dict_%s_%s' % args

        return table_prefix

    def get(self, key, default=_MISSING):
        """
        :param key: The key to look up
        :param default: Value to return when `key` is missing. When no
                        default is given, a KeyError is raised instead.
        :return: The stored value, or `default`
        """
        try:
            return self[key]
        except KeyError:
            if default is not self._MISSING:
                return default

            # Include the key to make debugging easier
            raise KeyError(key)

    def __getitem__(self, key):
        try:
            value = self._in_memory[key]
        except KeyError:
            # This will raise KeyError if k is not found, and that is OK
            # because we don't need to increase the access count when the
            # key doesn't exist
            value = self._disk_dict[key]

        self._increase_access_count(key)
        return value

    def _get_keys_for_memory(self):
        """
        :return: The names of the keys that should be kept in memory.

                 For example, if `max_in_memory` is set to 2 and:

                    _in_memory: {1: None, 2: None}
                    _access_count: {1: 10, 2: 20, 3: 5}
                    _disk_dict: {3: None}

                 Then the method returns [1, 2].
        """
        return [k for k, v in self._access_count.most_common(self._max_in_memory)]

    def _increase_access_count(self, key):
        """Record one access to `key` and re-balance memory vs. disk"""
        self._access_count.update([key])

        keys_for_memory = self._get_keys_for_memory()

        self._move_key_to_disk_if_needed(keys_for_memory)
        self._move_key_to_memory_if_needed(key, keys_for_memory)

    def _move_key_to_disk_if_needed(self, keys_for_memory):
        """
        Analyzes the current access count for the last accessed key and
        checks if any of the keys in memory should be moved to disk.

        :param keys_for_memory: The keys that should be in memory
        :return: The name of the key that was moved to disk, or None if
                 all the keys are still in memory.
        """
        for key in self._in_memory:

            if key in keys_for_memory:
                continue

            try:
                value = self._in_memory.pop(key)
            except KeyError:
                # Someone else removed it concurrently; nothing to move
                return
            else:
                self._disk_dict[key] = value
                # Return immediately: at most one key per access can fall
                # out of the in-memory set, and returning also avoids
                # mutating the dict while iterating it
                return key

    def _move_key_to_memory_if_needed(self, key, keys_for_memory):
        """
        Analyzes the current access count for the last accessed key and
        checks if any of the keys on disk should be moved to memory.

        :param key: The key that was last accessed
        :param keys_for_memory: The keys that should be in memory
        :return: The name of the key that was moved to memory, or None if
                 all the keys are still on disk.
        """
        # The key is already in memory, nothing to do here
        if key in self._in_memory:
            return

        # The key must not be in memory, nothing to do here
        if key not in keys_for_memory:
            return

        try:
            value = self._disk_dict.pop(key)
        except KeyError:
            # Someone else removed it concurrently; nothing to move
            return
        else:
            self._in_memory[key] = value
            return key

    def __setitem__(self, key, value):
        if key in self._in_memory:
            self._in_memory[key] = value

        elif len(self._in_memory) < self._max_in_memory:
            self._in_memory[key] = value

        else:
            self._disk_dict[key] = value

        self._increase_access_count(key)

    def __delitem__(self, key):
        try:
            del self._in_memory[key]
        except KeyError:
            # This will raise KeyError if k is not found, and that is OK
            # because we don't need to increase the access count when the
            # key doesn't exist
            del self._disk_dict[key]

        try:
            del self._access_count[key]
        except KeyError:
            # Another thread removed this key
            pass

    def __contains__(self, key):
        if key in self._in_memory:
            self._increase_access_count(key)
            return True

        if key in self._disk_dict:
            self._increase_access_count(key)
            return True

        return False

    def __iter__(self):
        """
        Decided not to increase the access count when iterating through the
        items. In most cases the iteration will be performed on all items,
        thus increasing the access count +1 for each, which will leave all
        access counts +1, forcing no movements between memory and disk.
        """
        for key in self._in_memory:
            yield key

        for key in self._disk_dict:
            yield key

    def iteritems(self):
        for key, value in self._in_memory.iteritems():
            yield key, value

        for key, value in self._disk_dict.iteritems():
            yield key, value
def __init__(self):
    AuditPlugin.__init__(self)

    # Internal variables: the regex pulls the two 5-digit seeds out of a
    # payload, the DiskDict maps expected output to the mutant that sent it
    self._extract_expected_re = re.compile("[1-9]{5}")
    self._expected_mutant_dict = DiskDict(table_prefix="ssi")
class ssi(AuditPlugin): """ Find server side inclusion vulnerabilities. :author: Andres Riancho ([email protected]) """ def __init__(self): AuditPlugin.__init__(self) # Internal variables self._expected_mutant_dict = DiskDict(table_prefix="ssi") self._extract_expected_re = re.compile("[1-9]{5}") def audit(self, freq, orig_response): """ Tests an URL for server side inclusion vulnerabilities. :param freq: A FuzzableRequest """ ssi_strings = self._get_ssi_strings() mutants = create_mutants(freq, ssi_strings, orig_resp=orig_response) self._send_mutants_in_threads(self._uri_opener.send_mutant, mutants, self._analyze_result) def _get_ssi_strings(self): """ This method returns a list of server sides to try to include. :return: A string, see above. """ # Generic yield '<!--#exec cmd="echo -n %s;echo -n %s" -->' % get_seeds() # Perl SSI yield ( '<!--#set var="SEED_A" value="%s" -->' '<!--#echo var="SEED_A" -->' '<!--#set var="SEED_B" value="%s" -->' '<!--#echo var="SEED_B" -->' % get_seeds() ) # Smarty # http://www.smarty.net/docsv2/en/language.function.math.tpl yield '{math equation="x * y" x=%s y=%s}' % get_seeds() # Mako # http://docs.makotemplates.org/en/latest/syntax.html yield "${%s * %s}" % get_seeds() # Jinja2 and Twig # http://jinja.pocoo.org/docs/dev/templates/#math # http://twig.sensiolabs.org/doc/templates.html yield "{{%s * %s}}" % get_seeds() # Generic yield "{%s * %s}" % get_seeds() def _get_expected_results(self, mutant): """ Extracts the potential results from the mutant payload and returns them in a list. """ sent_payload = mutant.get_token_payload() seed_numbers = self._extract_expected_re.findall(sent_payload) seed_a = int(seed_numbers[0]) seed_b = int(seed_numbers[1]) return [str(seed_a * seed_b), "%s%s" % (seed_a, seed_b)] def _analyze_result(self, mutant, response): """ Analyze the result of the previously sent request. :return: None, save the vuln to the kb. 
""" # Store the mutants in order to be able to analyze the persistent case # later expected_results = self._get_expected_results(mutant) for expected_result in expected_results: self._expected_mutant_dict[expected_result] = mutant # Now we analyze the "reflected" case if self._has_bug(mutant): return for expected_result in expected_results: if expected_result not in response: continue if expected_result in mutant.get_original_response_body(): continue desc = "Server side include (SSI) was found at: %s" desc %= mutant.found_at() v = Vuln.from_mutant( "Server side include vulnerability", desc, severity.HIGH, response.id, self.get_name(), mutant ) v.add_to_highlight(expected_result) self.kb_append_uniq(self, "ssi", v) def end(self): """ This method is called when the plugin wont be used anymore and is used to find persistent SSI vulnerabilities. Example where a persistent SSI can be found: Say you have a "guest book" (a CGI application that allows visitors to leave messages for everyone to see) on a server that has SSI enabled. Most such guest books around the Net actually allow visitors to enter HTML code as part of their comments. Now, what happens if a malicious visitor decides to do some damage by entering the following: <!--#exec cmd="ls" --> If the guest book CGI program was designed carefully, to strip SSI commands from the input, then there is no problem. But, if it was not, there exists the potential for a major headache! For a working example please see moth VM. """ fuzzable_request_set = kb.kb.get_all_known_fuzzable_requests() self._send_mutants_in_threads( self._uri_opener.send_mutant, fuzzable_request_set, self._analyze_persistent, cache=False ) self._expected_mutant_dict.cleanup() def _analyze_persistent(self, freq, response): """ Analyze the response of sending each fuzzable request found by the framework, trying to identify any locations where we might have injected a payload. 
:param freq: The fuzzable request :param response: The HTTP response :return: None, vulns are stored in KB """ multi_in_inst = multi_in(self._expected_mutant_dict.keys()) for matched_expected_result in multi_in_inst.query(response.get_body()): # We found one of the expected results, now we search the # self._expected_mutant_dict to find which of the mutants sent it # and create the vulnerability mutant = self._expected_mutant_dict[matched_expected_result] desc = ( "Server side include (SSI) was found at: %s" " The result of that injection is shown by browsing" ' to "%s".' ) desc %= (mutant.found_at(), freq.get_url()) v = Vuln.from_mutant( "Persistent server side include vulnerability", desc, severity.HIGH, response.id, self.get_name(), mutant, ) v.add_to_highlight(matched_expected_result) self.kb_append(self, "ssi", v) def get_long_desc(self): """ :return: A DETAILED description of the plugin functions and features. """ return """
class VariantDB(object):
    """
    Store a counter of "variants" per normalized reference, and answer
    whether more variants of a given reference are still needed.
    """

    def __init__(self, max_variants=DEFAULT_MAX_VARIANTS):
        # Disk-backed {clean_reference: seen_count} mapping
        self._disk_dict = DiskDict(table_prefix='variant_db')
        # Protects the read-increment-write cycle in append()
        self._db_lock = threading.RLock()
        # Max number of variants to accept per normalized reference
        self.max_variants = max_variants

    def append(self, reference):
        """
        Called when a new reference is found and we proved that new
        variants are still needed.

        :param reference: The reference (as a URL object) to add. This
                          method will "normalize" it before adding it to
                          the internal shelve.
        """
        clean_reference = self._clean_reference(reference)

        with self._db_lock:
            count = self._disk_dict.get(clean_reference, None)

            if count is not None:
                self._disk_dict[clean_reference] = count + 1
            else:
                self._disk_dict[clean_reference] = 1

    def need_more_variants(self, reference):
        """
        :return: True if there are not enough variants associated with
                 this reference in the DB.
        """
        clean_reference = self._clean_reference(reference)

        # I believe this is atomic enough...
        count = self._disk_dict.get(clean_reference, 0)
        if count >= self.max_variants:
            return False
        else:
            return True

    def _clean_reference(self, reference):
        """
        This method is VERY dependent on the are_variants method from
        core.data.request.variant_identification , make sure to remember
        that when changing stuff here or there.

        What this method does is to "normalize" any input reference string
        so that they can be compared very simply using string match.
        """
        res = reference.get_domain_path() + reference.get_file_name()

        if reference.has_query_string():
            res += '?'

            # Deep-copy so the setter() calls below do not mutate the
            # caller's query string
            qs = copy.deepcopy(reference.querystring)

            for key, value, path, setter in qs.iter_setters():
                # Collapse every value to its "type" so different values
                # of the same parameter count as the same variant
                if value.isdigit():
                    setter('number')
                else:
                    setter('string')

            res += str(qs)

        return res
class VariantDB(object):
    """
    See the notes on PARAMS_MAX_VARIANTS and PATH_MAX_VARIANTS above.

    Also understand that we'll keep "dirty" versions of the
    references/fuzzable requests in order to be able to answer "False" to
    a call for need_more_variants in a situation like this:

        need_more_variants('http://foo.com/abc?id=32') --> True
        append('http://foo.com/abc?id=32')
        need_more_variants('http://foo.com/abc?id=32') --> False
    """
    # Headers that must not influence the request hash
    HASH_IGNORE_HEADERS = ('referer',)
    TAG = '[variant_db]'

    def __init__(self):
        # Disk-backed set of exact request hashes already seen
        self._variants_eq = DiskDict(table_prefix='variant_db_eq')
        # Disk-backed counter of variants per "cleaned" request pattern
        self._variants = DiskDict(table_prefix='variant_db')

        # User-configured limits, falling back to the framework defaults
        self.params_max_variants = cf.cf.get('params_max_variants',
                                             PARAMS_MAX_VARIANTS)
        self.path_max_variants = cf.cf.get('path_max_variants',
                                           PATH_MAX_VARIANTS)

        # Serializes the read-modify-write cycle in append()
        self._db_lock = threading.RLock()

    def cleanup(self):
        # Drop the tables which back both DiskDicts
        self._variants_eq.cleanup()
        self._variants.cleanup()

    def append(self, fuzzable_request):
        """
        :param fuzzable_request: The fuzzable request to record
        :return: True if we added a new fuzzable request variant to the DB,
                 False if no more variants are required for this fuzzable
                 request.
        """
        with self._db_lock:
            #
            # Is the fuzzable request already known to us? (exactly the same)
            #
            request_hash = fuzzable_request.get_request_hash(self.HASH_IGNORE_HEADERS)
            already_seen = self._variants_eq.get(request_hash, False)
            if already_seen:
                return False

            # Store it to avoid duplicated fuzzable requests in our framework
            self._variants_eq[request_hash] = True

            #
            # Do we need more variants for the fuzzable request? (similar match)
            #
            clean_dict_key = clean_fuzzable_request(fuzzable_request)
            count = self._variants.get(clean_dict_key, None)

            if count is None:
                # First time we see this pattern
                self._variants[clean_dict_key] = 1
                return True

            # We've seen at least one fuzzable request with this pattern...
            url = fuzzable_request.get_uri()
            has_params = url.has_query_string() or fuzzable_request.get_raw_data()

            # Choose which max_variants to use
            if has_params:
                max_variants = self.params_max_variants
            else:
                max_variants = self.path_max_variants

            if count >= max_variants:
                return False
            else:
                self._variants[clean_dict_key] = count + 1
                return True
class VariantDB(object):
    """
    Store a counter of "variants" per normalized reference / fuzzable
    request, and answer whether more variants are still needed.
    """

    def __init__(self, max_variants=DEFAULT_MAX_VARIANTS):
        # Disk-backed {clean reference or fuzzable request: count} mapping
        self._disk_dict = DiskDict(table_prefix='variant_db')
        # Protects the read-increment-write cycles below
        self._db_lock = threading.RLock()
        # Max number of variants to accept per normalized key
        self.max_variants = max_variants

    def append(self, reference):
        """
        Called when a new reference is found and we proved that new
        variants are still needed.

        :param reference: The reference (as a URL object) to add. This
                          method will "normalize" it before adding it to
                          the internal shelve.
        """
        clean_reference = self._clean_reference(reference)

        with self._db_lock:
            count = self._disk_dict.get(clean_reference, None)

            if count is not None:
                self._disk_dict[clean_reference] = count + 1
            else:
                self._disk_dict[clean_reference] = 1

    def append_fr(self, fuzzable_request):
        """
        See append()'s documentation

        :param fuzzable_request: The fuzzable request to record
        """
        clean_fuzzable_request = self._clean_fuzzable_request(fuzzable_request)

        with self._db_lock:
            count = self._disk_dict.get(clean_fuzzable_request, None)

            if count is not None:
                self._disk_dict[clean_fuzzable_request] = count + 1
            else:
                self._disk_dict[clean_fuzzable_request] = 1

    def need_more_variants(self, reference):
        """
        :return: True if there are not enough variants associated with
                 this reference in the DB.
        """
        clean_reference = self._clean_reference(reference)
        has_qs = reference.has_query_string()

        # I believe this is atomic enough...
        count = self._disk_dict.get(clean_reference, 0)

        # When we're analyzing a path (without QS), we just need 1
        max_variants = self.max_variants if has_qs else 1

        if count >= max_variants:
            return False
        else:
            return True

    def need_more_variants_for_fr(self, fuzzable_request):
        """
        :return: True if there are not enough variants associated with
                 this fuzzable request in the DB.
        """
        clean_fuzzable_request = self._clean_fuzzable_request(fuzzable_request)

        # I believe this is atomic enough...
        count = self._disk_dict.get(clean_fuzzable_request, 0)
        if count >= self.max_variants:
            return False
        else:
            return True

    def _clean_reference(self, reference):
        """
        This method is VERY dependent on the are_variants method from
        core.data.request.variant_identification , make sure to remember
        that when changing stuff here or there.

        What this method does is to "normalize" any input reference string
        so that they can be compared very simply using string match.

        Since this is a reference (link) we'll prepend '(GET)-' to the
        result, which will help us add support for forms/fuzzable requests
        with '(POST)-' in the future.
        """
        res = '(GET)-'
        res += reference.get_domain_path() + reference.get_file_name()

        if reference.has_query_string():
            res += '?' + self._clean_data_container(reference.querystring)

        return res

    def _clean_data_container(self, data_container):
        """
        A simplified/serialized version of the query string
        """
        # Deep-copy: the setter() calls below replace values in place
        dc = copy.deepcopy(data_container)

        for key, value, path, setter in dc.iter_setters():
            # Collapse values to their "type" so different values of the
            # same parameter compare equal
            if value.isdigit():
                setter('number')
            else:
                setter('string')

        return str(dc)

    def _clean_fuzzable_request(self, fuzzable_request):
        """
        Very similar to _clean_reference but we receive a fuzzable request
        instead. The output includes the HTTP method and any parameters
        which might be sent over HTTP post-data in the request are appended
        to the result as query string params.

        :param fuzzable_request: The fuzzable request instance to clean
        :return: See _clean_reference
        """
        res = '(%s)-' % fuzzable_request.get_method().upper()

        uri = fuzzable_request.get_uri()
        res += uri.get_domain_path() + uri.get_file_name()

        if uri.has_query_string():
            res += '?' + self._clean_data_container(uri.querystring)

        if fuzzable_request.get_raw_data():
            res += '!' + self._clean_data_container(fuzzable_request.get_raw_data())

        return res
def __init__(self):
    GrepPlugin.__init__(self)

    # Bloom filter to avoid duplicated "interesting comment" findings,
    # plus a disk-backed store for every comment seen
    self._already_reported_interesting = ScalableBloomFilter()
    self._comments = DiskDict()
def __init__(self, max_variants=DEFAULT_MAX_VARIANTS):
    # Max number of variants to accept per normalized reference
    self.max_variants = max_variants
    self._disk_dict = DiskDict(table_prefix='variant_db')
    self._db_lock = threading.RLock()
def test_not_in(self):
    """Reading a missing key must raise KeyError."""
    disk_dict = DiskDict()

    with self.assertRaises(KeyError):
        disk_dict['abc']
def test_del(self):
    """A deleted key must not be present anymore."""
    disk_dict = DiskDict()

    disk_dict['a'] = 'abc'
    del disk_dict['a']

    self.assertNotIn('a', disk_dict)
class html_comments(GrepPlugin):
    """
    Extract and analyze HTML comments.

    :author: Andres Riancho ([email protected])
    """

    # Matches a complete open/close HTML tag pair; used to detect HTML
    # code that was commented out
    HTML_RE = re.compile('<[a-zA-Z]*.*?>.*?</[a-zA-Z]>')

    INTERESTING_WORDS = (
        # In English
        'user', 'pass', 'xxx', 'fix', 'bug', 'broken', 'oops', 'hack',
        'caution', 'todo', 'note', 'warning', '!!!', '???', 'shit',
        'pass', 'password', 'passwd', 'pwd', 'secret', 'stupid',

        # In Spanish
        'tonto', 'porqueria', 'cuidado', 'usuario', u'contraseña', 'puta',
        'email', 'security', 'captcha', 'pinga', 'cojones',

        # some in Portuguese
        'banco', 'bradesco', 'itau', 'visa', 'bancoreal',
        u'transfêrencia', u'depósito', u'cartão', u'crédito',
        'dados pessoais'
    )

    # Words are padded with spaces so only whole-word matches count
    _multi_in = MultiIn([' %s ' % w for w in INTERESTING_WORDS])

    def __init__(self):
        GrepPlugin.__init__(self)

        # Internal variables
        self._comments = DiskDict(table_prefix='html_comments')
        self._already_reported = ScalableBloomFilter()
        # Debug aid for _handle_no_such_table() / issue #10849
        self._end_was_called = False

    def grep(self, request, response):
        """
        Plugin entry point, parse those comments!

        :param request: The HTTP request object.
        :param response: The HTTP response object
        :return: None
        """
        if not response.is_text_or_html():
            return

        try:
            dp = parser_cache.dpc.get_document_parser_for(response)
        except BaseFrameworkException:
            return

        for comment in dp.get_comments():
            # These next two lines fix this issue:
            # audit.ssi + grep.html_comments + web app with XSS = false positive
            if request.sent(comment):
                continue

            if self._is_new(comment, response):
                self._interesting_word(comment, request, response)
                self._html_in_comment(comment, request, response)

    def _interesting_word(self, comment, request, response):
        """
        Find interesting words in HTML comments
        """
        comment = comment.lower()

        for word in self._multi_in.query(comment):
            if (word, response.get_url()) in self._already_reported:
                continue

            desc = ('A comment with the string "%s" was found in: "%s".'
                    ' This could be interesting.')
            desc %= (word, response.get_url())

            i = Info.from_fr('Interesting HTML comment', desc, response.id,
                             self.get_name(), request)
            i.add_to_highlight(word)

            kb.kb.append(self, 'interesting_comments', i)
            om.out.information(i.get_desc())

            self._already_reported.add((word, response.get_url()))

    def _html_in_comment(self, comment, request, response):
        """
        Find HTML code in HTML comments
        """
        html_in_comment = self.HTML_RE.search(comment)

        if html_in_comment is None:
            return

        if (comment, response.get_url()) in self._already_reported:
            return

        # There is HTML code in the comment.
        comment = comment.strip()
        comment = comment.replace('\n', '')
        comment = comment.replace('\r', '')
        comment = comment[:40]

        desc = ('A comment with the string "%s" was found in: "%s".'
                ' This could be interesting.')
        desc %= (comment, response.get_url())

        i = Info.from_fr('HTML comment contains HTML code', desc,
                         response.id, self.get_name(), request)
        i.set_uri(response.get_uri())
        i.add_to_highlight(html_in_comment.group(0))

        kb.kb.append(self, 'html_comment_hides_html', i)
        om.out.information(i.get_desc())

        self._already_reported.add((comment, response.get_url()))

    def _handle_no_such_table(self, comment, response, nste):
        """
        I had a lot of issues trying to reproduce [0], so this code is just
        a helper for me to identify the root cause.

        [0] https://github.com/andresriancho/w3af/issues/10849

        :param nste: The original exception
        :param comment: The comment we're analyzing
        :param response: The HTTP response
        :return: None, an exception with more information is re-raised
        """
        msg = ('A NoSuchTableException was raised by the DBMS. This issue is'
               ' related with #10849 , but since I was unable to reproduce'
               ' it, extra debug information is added to the exception:'
               '\n'
               '\n - Grep plugin end() was called: %s'
               '\n - Response ID is: %s'
               '\n - HTML comment is: "%s"'
               '\n - Original exception: "%s"'
               '\n\n'
               'https://github.com/andresriancho/w3af/issues/10849\n')
        args = (self._end_was_called, response.get_id(), comment, nste)
        raise NoSuchTableException(msg % args)

    def _is_new(self, comment, response):
        """
        Make sure that we perform a thread safe check on the self._comments
        dict, in order to avoid duplicates.
        """
        with self._plugin_lock:

            #pylint: disable=E1103
            try:
                comment_data = self._comments.get(comment, None)
            except NoSuchTableException, nste:
                self._handle_no_such_table(comment, response, nste)

            response_url = response.get_url()

            if comment_data is None:
                self._comments[comment] = [(response_url, response.id)]
                return True
            else:
                for saved_url, response_id in comment_data:
                    if response_url == saved_url:
                        return False
                else:
                    # for/else: no stored URL matched, record the new one
                    comment_data.append((response_url, response.id))
                    self._comments[comment] = comment_data
                    return True
def __init__(self, max_variants=5):
    self._db_lock = threading.RLock()
    self._disk_dict = DiskDict()
    # Max number of variants we will store for a single reference
    self.max_variants = max_variants
class CachedQueue(Queue.Queue, QueueSpeedMeasurement):
    """
    The framework uses the producer / consumer design pattern extensively.
    In order to avoid high memory usage in the queues connecting the different
    parts of the framework we defined a max size.

    When a queue max size is reached, one or more threads will block. This
    line is printed during a real scan:

        Thread blocked 5.76617312431 seconds waiting for Queue.put() to
        have space in the Grep queue. The queue's maxsize is 20.

    In the case of the Grep consumer / producer the problem with a block is
    increased by the fact that HTTP responses won't reach other parts of
    the framework until the queue has space.

    Increasing the queue size would increase memory usage.

    Using an on-disk queue would increase CPU (serialization) and disk IO.

    The CacheQueue is a mix of in-memory and on-disk queue. The first N
    items are stored in memory, when more items are put() we just write
    them to disk.

    The CacheQueue object implements these methods from QueueSpeedMeasurement:
        * get_input_rpm
        * get_output_rpm

    Which allows users to understand how fast a queue is moving.
    """
    def __init__(self, maxsize=0, name='Unknown'):
        """
        :param maxsize: Number of items to keep in memory; anything beyond
                        that is transparently stored on disk by _put().
        :param name: Used as the DiskDict table prefix and in debug messages
        """
        self.name = name
        self.max_in_memory = maxsize

        QueueSpeedMeasurement.__init__(self)

        # We want to send zero to the maxsize of the Queue implementation
        # here because we can write an infinite number of items: whatever
        # does not fit in memory overflows to the DiskDict.
        #
        # Note that Queue.Queue.__init__() calls self._init(), which is
        # where self.memory and self.disk are really created.
        Queue.Queue.__init__(self, maxsize=0)

    def get_name(self):
        return self.name

    def next_item_saved_to_memory(self):
        """
        :return: True if the next call to put() will store the item in memory
        """
        return len(self.memory) < self.max_in_memory

    def _init(self, maxsize):
        """
        Initialize the in-memory dict, the on-disk dict and the pointers.
        Called by Queue.Queue.__init__().

        :param maxsize: The max size for the queue (unused here, see __init__)
        """
        self.memory = dict()
        self.disk = DiskDict(table_prefix='%sCachedQueue' % self.name)
        self.get_pointer = 0
        self.put_pointer = 0

    def _qsize(self, len=len):
        # The queue size is the sum of both storage back-ends
        return len(self.memory) + len(self.disk)

    def _get_class_name(self, obj):
        """
        :return: The class name of obj, only used for debug messages.
        """
        try:
            return obj.__class__.__name__
        except Exception:
            # Narrowed from a bare `except:` which also swallowed
            # SystemExit / KeyboardInterrupt. Some objects might not
            # expose __class__, fall back to type()
            return type(obj)

    def _put(self, item):
        """
        Put a new item in the queue
        """
        #
        # This is very useful information for finding bottlenecks in the
        # framework / strategy
        #
        if len(self.memory) == self.max_in_memory:
            #
            # If you see many messages like this in the scan log, then you
            # might want to experiment with a larger maxsize for this queue
            #
            msg = ('CachedQueue.put() will write a %r item to the %s DiskDict.'
                   ' This uses more CPU and disk IO than storing in memory'
                   ' but will avoid high memory usage issues. The current'
                   ' %s DiskDict size is %s.')
            args = (self._get_class_name(item),
                    self.get_name(),
                    self.get_name(),
                    len(self.disk))
            om.out.debug(msg % args)

        #
        # And now we just save the item to memory (if there is space) or
        # disk (if it doesn't fit on memory)
        #
        if len(self.memory) < self.max_in_memory:
            self.memory[self.put_pointer] = item
        else:
            self.disk[self.put_pointer] = item

        self.put_pointer += 1
        self._item_added_to_queue()

    def _get(self):
        """
        Get an item from the queue
        """
        try:
            item = self.memory.pop(self.get_pointer)
        except KeyError:
            # Not in memory, it must have overflowed to disk
            item = self.disk.pop(self.get_pointer)

            if len(self.disk):
                #
                # If you see many messages like this in the scan log, then you
                # might want to experiment with a larger maxsize for this queue
                #
                msg = ('CachedQueue.get() from %s DiskDict was used to read an'
                       ' item from disk. The current %s DiskDict size is %s.')
                args = (self.get_name(), self.get_name(), len(self.disk))
                om.out.debug(msg % args)

        self._item_left_queue()
        self.get_pointer += 1
        return item
class ssi(AuditPlugin): """ Find server side inclusion vulnerabilities. :author: Andres Riancho ([email protected]) """ def __init__(self): AuditPlugin.__init__(self) # Internal variables self._expected_res_mutant = DiskDict() self._freq_list = DiskList() re_str = '<!--#exec cmd="echo -n (.*?);echo -n (.*?)" -->' self._extract_results_re = re.compile(re_str) def audit(self, freq, orig_response): """ Tests an URL for server side inclusion vulnerabilities. :param freq: A FuzzableRequest """ # Create the mutants to send right now, ssi_strings = self._get_ssi_strings() mutants = create_mutants(freq, ssi_strings, orig_resp=orig_response) # Used in end() to detect "persistent SSI" for mut in mutants: expected_result = self._extract_result_from_payload( mut.get_token_value()) self._expected_res_mutant[expected_result] = mut self._freq_list.append(freq) # End of persistent SSI setup self._send_mutants_in_threads(self._uri_opener.send_mutant, mutants, self._analyze_result) def _get_ssi_strings(self): """ This method returns a list of server sides to try to include. :return: A string, see above. """ yield '<!--#exec cmd="echo -n %s;echo -n %s" -->' % (rand_alpha(5), rand_alpha(5)) # TODO: Add mod_perl ssi injection support # http://www.sens.buffalo.edu/services/webhosting/advanced/perlssi.shtml #yield <!--#perl sub="sub {print qq/If you see this, mod_perl is working!/;}" --> def _extract_result_from_payload(self, payload): """ Extract the expected result from the payload we're sending. """ match = self._extract_results_re.search(payload) return match.group(1) + match.group(2) def _analyze_result(self, mutant, response): """ Analyze the result of the previously sent request. :return: None, save the vuln to the kb. 
""" if self._has_no_bug(mutant): e_res = self._extract_result_from_payload(mutant.get_token_value()) if e_res in response and not e_res in mutant.get_original_response_body(): desc = 'Server side include (SSI) was found at: %s' desc = desc % mutant.found_at() v = Vuln.from_mutant('Server side include vulnerability', desc, severity.HIGH, response.id, self.get_name(), mutant) v.add_to_highlight(e_res) self.kb_append_uniq(self, 'ssi', v) def end(self): """ This method is called when the plugin wont be used anymore and is used to find persistent SSI vulnerabilities. Example where a persistent SSI can be found: Say you have a "guestbook" (a CGI application that allows visitors to leave messages for everyone to see) on a server that has SSI enabled. Most such guestbooks around the Net actually allow visitors to enter HTML code as part of their comments. Now, what happens if a malicious visitor decides to do some damage by entering the following: <!--#exec cmd="ls" --> If the guestbook CGI program was designed carefully, to strip SSI commands from the input, then there is no problem. But, if it was not, there exists the potential for a major headache! For a working example please see moth VM. """ multi_in_inst = multi_in(self._expected_res_mutant.keys()) def filtered_freq_generator(freq_list): already_tested = ScalableBloomFilter() for freq in freq_list: if freq not in already_tested: already_tested.add(freq) yield freq def analyze_persistent(freq, response): for matched_expected_result in multi_in_inst.query(response.get_body()): # We found one of the expected results, now we search the # self._persistent_data to find which of the mutants sent it # and create the vulnerability mutant = self._expected_res_mutant[matched_expected_result] desc = 'Server side include (SSI) was found at: %s' \ ' The result of that injection is shown by browsing'\ ' to "%s".' 
desc = desc % (mutant.found_at(), freq.get_url()) v = Vuln.from_mutant('Persistent server side include vulnerability', desc, severity.HIGH, response.id, self.get_name(), mutant) v.add_to_highlight(matched_expected_result) self.kb_append(self, 'ssi', v) self._send_mutants_in_threads(self._uri_opener.send_mutant, filtered_freq_generator(self._freq_list), analyze_persistent, cache=False) self._expected_res_mutant.cleanup() self._freq_list.cleanup() def get_long_desc(self): """ :return: A DETAILED description of the plugin functions and features. """ return """
def __init__(self, iterable=(), maxsize=-1):
    """
    :param iterable: Items used to pre-populate the deque
    :param maxsize: Max item count, -1 disables the limit
    """
    # The storage might already exist when __init__ is invoked a second
    # time on the same instance (e.g. via __setstate__), only create the
    # pointers and the DiskDict once
    if not hasattr(self, 'data'):
        self.maxsize = maxsize
        self.data = DiskDict()
        self.left = self.right = 0
    self.extend(iterable)
def test_len(self):
    """A DiskDict with a single item must report len() == 1."""
    dd = DiskDict()
    dd['a'] = 'abc'
    self.assertEqual(1, len(dd))
class OrderedCachedQueue(Queue.Queue, QueueSpeedMeasurement):
    """
    This queue implements all the features explained in CachedQueue (see
    cached_queue.py) plus it will order the items in the queue as they are
    inserted.

    The queue is ordered by a unique identifier that is returned by the
    object being added. If the object is None, then it is added to the
    end of the queue.

    The goal of this ordered queue is to impose an order in which URLs and
    forms identified by the w3af framework are processed by the plugins.
    Since plugins are run in threads, the order in which new URLs are added
    to the queue is "completely random" and depends on HTTP response times,
    CPU-load, memory swapping, etc.
    """

    # Sorts after any real MD5 hex digest, see _get_hash()
    LAST_MD5_HASH = 'f' * 32

    def __init__(self, maxsize=0, name='Unknown'):
        """
        :param maxsize: Number of items to keep in memory; anything beyond
                        that is transparently stored on disk by _put().
        :param name: Used as the DiskDict table prefix and in debug messages
        """
        self.name = name
        self.max_in_memory = maxsize
        self.processed_tasks = 0

        QueueSpeedMeasurement.__init__(self)

        # Real initialization happens in _init(), which is invoked by
        # Queue.Queue.__init__() below
        self.queue_order = None
        self.hash_to_uuid = None
        self.memory = None
        self.disk = None

        # We want to send zero to the maxsize of the Queue implementation
        # here because we can write an infinite number of items. But keep
        # in mind that we don't really use the queue storage in any way
        Queue.Queue.__init__(self, maxsize=0)

    def get_name(self):
        return self.name

    def get_processed_tasks(self):
        return self.processed_tasks

    def next_item_saved_to_memory(self):
        """
        :return: True if the next call to put() will store the item in memory
        """
        return len(self.memory) < self.max_in_memory

    def _init(self, maxsize):
        """
        Initialize the dicts and pointer
        :param maxsize: The max size for the queue
        """
        self.queue_order = list()
        self.hash_to_uuid = dict()
        self.memory = dict()
        self.disk = DiskDict(table_prefix='%sCachedQueue' % self.name)

    def _qsize(self, _len=len):
        # The queue size is the sum of both storage back-ends
        return _len(self.memory) + _len(self.disk)

    def _get_class_name(self, obj):
        """
        :return: The class name of obj, only used for debug messages.
        """
        try:
            return obj.__class__.__name__
        except Exception:
            # Narrowed from a bare `except:` which also swallowed
            # SystemExit / KeyboardInterrupt. Some objects might not
            # expose __class__, fall back to type()
            return type(obj)

    def _get_hash(self, item):
        """
        :return: The hash used to keep the queue ordered. None and the
                 poison pill always sort last.
        """
        if item is None or item == POISON_PILL:
            # Return ffff...ffff which is the latest (in alphanumeric order)
            # hash that exists in MD5. This forces the None item to be placed
            # at the end of the queue.
            #
            # Warning! If FuzzableRequest.get_hash() ever changes its
            # implementation this will stop working as expected!
            return self.LAST_MD5_HASH

        return item.get_hash()

    def _put(self, item):
        """
        Put a new item in the queue
        """
        #
        # This is very useful information for finding bottlenecks in the
        # framework / strategy
        #
        if len(self.memory) == self.max_in_memory:
            #
            # If you see many messages like this in the scan log, then you
            # might want to experiment with a larger maxsize for this queue
            #
            msg = ('OrderedCachedQueue.put() will write a %r item to the %s'
                   ' DiskDict. This uses more CPU and disk IO than storing'
                   ' in memory but will avoid high memory usage issues. The'
                   ' current %s DiskDict size is %s.')
            args = (self._get_class_name(item),
                    self.get_name(),
                    self.get_name(),
                    len(self.disk))
            om.out.debug(msg % args)

        #
        # Get the item hash to store it in the queue order list, and insert
        # it using bisect.insort() that will keep the order at a low cost
        #
        item_hash = self._get_hash(item)
        bisect.insort(self.queue_order, item_hash)

        #
        # Keep an in-memory dict that allows us to find the fuzzable requests
        # in the other dictionaries
        #
        unique_id = str(uuid.uuid4())

        unique_id_list = self.hash_to_uuid.setdefault(item_hash, [])
        bisect.insort(unique_id_list, unique_id)

        #
        # And now we just save the item to memory (if there is space) or
        # disk (if it doesn't fit on memory)
        #
        if len(self.memory) < self.max_in_memory:
            self.memory[unique_id] = item
        else:
            self.disk[unique_id] = item

        self._item_added_to_queue()

    def _get(self):
        """
        Get an item from the queue
        """
        item_hash = self.queue_order.pop(0)
        unique_id_list = self.hash_to_uuid.pop(item_hash)
        unique_id = unique_id_list.pop(0)

        if unique_id_list:
            #
            # There are still items in this unique_id_list, this is most likely
            # because two items with the same hash were added to the queue, and
            # only one of those has been read.
            #
            # Need to add the other item(s) to the list again
            #
            bisect.insort(self.queue_order, item_hash)
            self.hash_to_uuid[item_hash] = unique_id_list

        try:
            item = self.memory.pop(unique_id)
        except KeyError:
            # Not in memory, it must have overflowed to disk
            item = self.disk.pop(unique_id)

            if len(self.disk):
                #
                # If you see many messages like this in the scan log, then you
                # might want to experiment with a larger maxsize for this queue
                #
                msg = ('OrderedCachedQueue.get() from %s DiskDict was used to'
                       ' read an item from disk. The current %s DiskDict'
                       ' size is %s.')
                args = (self.get_name(), self.get_name(), len(self.disk))
                om.out.debug(msg % args)

        self._item_left_queue()
        self.processed_tasks += 1
        return item
def __init__(self, iterable=(), maxsize=-1):
    """
    :param iterable: Initial content for the deque
    :param maxsize: Upper bound on the item count; -1 means unbounded
    """
    # Create the on-disk storage and pointers only on the first call:
    # __init__ can run again on an existing instance (e.g. __setstate__)
    if not hasattr(self, 'data'):
        self.maxsize = maxsize
        self.data = DiskDict(table_prefix='deque')
        self.left = self.right = 0
    self.extend(iterable)
class DiskDeque(object):
    """
    The base code for this file comes from [0], I've modified it to use a
    DiskDict which stores the "self.data" dictionary to disk in order to
    save memory.

    Items live in self.data keyed by an increasing integer; `left` and
    `right` are the half-open index bounds of the live items.

    [0] https://code.activestate.com/recipes/259179/
    """
    def __init__(self, iterable=(), maxsize=-1):
        """
        :param iterable: Initial items for the deque
        :param maxsize: Max item count, -1 means "no limit"
        """
        # __setstate__ / __deepcopy__ call __init__ on an instance that
        # already owns its storage; in that case only extend() runs
        if not hasattr(self, 'data'):
            self.left = self.right = 0
            self.data = DiskDict(table_prefix='deque')
            self.maxsize = maxsize
        self.extend(iterable)

    def append(self, x):
        self.data[self.right] = x
        self.right += 1
        # Enforce the bound by dropping from the opposite end
        if self.maxsize != -1 and len(self) > self.maxsize:
            self.popleft()

    def appendleft(self, x):
        self.left -= 1
        self.data[self.left] = x
        if self.maxsize != -1 and len(self) > self.maxsize:
            self.pop()

    def pop(self):
        if self.left == self.right:
            raise IndexError('cannot pop from empty deque')
        self.right -= 1
        elem = self.data[self.right]
        del self.data[self.right]
        return elem

    def popleft(self):
        if self.left == self.right:
            raise IndexError('cannot pop from empty deque')
        elem = self.data[self.left]
        del self.data[self.left]
        self.left += 1
        return elem

    def clear(self):
        # NOTE(review): cleanup() drops the DiskDict storage; the sibling
        # DiskDeque implementation calls data.clear() here instead --
        # confirm that appending after clear() still works
        self.data.cleanup()
        self.left = self.right = 0

    def extend(self, iterable):
        for elem in iterable:
            self.append(elem)

    def extendleft(self, iterable):
        for elem in iterable:
            self.appendleft(elem)

    def rotate(self, n=1):
        if self:
            n %= len(self)
            for i in xrange(n):
                self.appendleft(self.pop())

    def __getitem__(self, i):
        if i < 0:
            i += len(self)
        try:
            return self.data[i + self.left]
        except KeyError:
            raise IndexError

    def __setitem__(self, i, value):
        if i < 0:
            i += len(self)
        try:
            self.data[i + self.left] = value
        except KeyError:
            raise IndexError

    def __delitem__(self, i):
        size = len(self)
        if not (-size <= i < size):
            raise IndexError
        data = self.data
        if i < 0:
            i += size
        # Shift every item after `i` one slot to the left, then drop the
        # now-duplicated right-most element
        for j in xrange(self.left + i, self.right - 1):
            data[j] = data[j + 1]
        self.pop()

    def __len__(self):
        return self.right - self.left

    def __cmp__(self, other):
        if type(self) != type(other):
            return cmp(type(self), type(other))
        return cmp(list(self), list(other))

    def __repr__(self, _track=[]):
        # The mutable default is used on purpose: it records the ids of
        # the deques currently being printed to break repr() cycles
        if id(self) in _track:
            return '...'
        _track.append(id(self))
        r = 'deque(%r)' % (list(self), )
        _track.remove(id(self))
        return r

    def __getstate__(self):
        # Bug fix: __setstate__ calls self.__init__(s[0]), i.e. it expects
        # a 1-tuple whose first element is the sequence of items. The
        # previous `return tuple(self)` made a pickle round-trip restore
        # only the first item. This also matches the other DiskDeque
        # implementation in the code base.
        return (tuple(self),)

    def __setstate__(self, s):
        self.__init__(s[0])

    def __hash__(self):
        raise TypeError

    def __copy__(self):
        return self.__class__(self)

    def __deepcopy__(self, memo=None):
        # Use None instead of a mutable {} default so unrelated
        # deepcopy() calls never share the same memo dict
        from copy import deepcopy
        if memo is None:
            memo = {}
        result = self.__class__()
        memo[id(self)] = result
        result.__init__(deepcopy(tuple(self), memo))
        return result
class VariantDB(object):
    """
    Count how many variants of each normalized URL / fuzzable request have
    been seen, and answer whether more variants of it are still needed.
    """
    def __init__(self, max_variants=DEFAULT_MAX_VARIANTS):
        """
        :param max_variants: Max number of variants accepted per reference
        """
        self._disk_dict = DiskDict(table_prefix='variant_db')
        self._db_lock = threading.RLock()
        self.max_variants = max_variants

    def _increase_count(self, key):
        """
        Thread-safe increment of the "times seen" counter stored for `key`.
        Shared by append() and append_fr() which previously duplicated
        this logic.
        """
        with self._db_lock:
            count = self._disk_dict.get(key, None)

            if count is not None:
                self._disk_dict[key] = count + 1
            else:
                self._disk_dict[key] = 1

    def append(self, reference):
        """
        Called when a new reference is found and we proved that new
        variants are still needed.

        :param reference: The reference (as a URL object) to add. This
                          method will "normalize" it before adding it to
                          the internal shelve.
        """
        self._increase_count(self._clean_reference(reference))

    def append_fr(self, fuzzable_request):
        """
        See append()'s documentation
        """
        self._increase_count(self._clean_fuzzable_request(fuzzable_request))

    def need_more_variants(self, reference):
        """
        :return: True if there are not enough variants associated with
        this reference in the DB.
        """
        clean_reference = self._clean_reference(reference)
        has_qs = reference.has_query_string()

        # I believe this is atomic enough...
        count = self._disk_dict.get(clean_reference, 0)

        # When we're analyzing a path (without QS), we just need 1
        max_variants = self.max_variants if has_qs else 1

        return count < max_variants

    def need_more_variants_for_fr(self, fuzzable_request):
        """
        :return: True if there are not enough variants associated with
        this fuzzable request in the DB.
        """
        clean_fuzzable_request = self._clean_fuzzable_request(fuzzable_request)

        # I believe this is atomic enough...
        count = self._disk_dict.get(clean_fuzzable_request, 0)

        return count < self.max_variants

    def _clean_reference(self, reference):
        """
        This method is VERY dependent on the are_variants method from
        core.data.request.variant_identification , make sure to remember
        that when changing stuff here or there.

        What this method does is to "normalize" any input reference string
        so that they can be compared very simply using string match.

        Since this is a reference (link) we'll prepend '(GET)-' to the
        result, which will help us add support for forms/fuzzable requests
        with '(POST)-' in the future.
        """
        res = '(GET)-'
        res += reference.get_domain_path().url_string.encode(DEFAULT_ENCODING)
        res += reference.get_file_name()

        if reference.has_query_string():
            res += '?' + self._clean_data_container(reference.querystring)

        return res

    def _clean_data_container(self, data_container):
        """
        A simplified/serialized version of the query string: each value is
        replaced by a 'number' / 'string' token so that two references
        which only differ in parameter values map to the same key.
        """
        dc = copy.deepcopy(data_container)

        for key, value, path, setter in dc.iter_setters():
            if value.isdigit():
                setter('number')
            else:
                setter('string')

        return str(dc)

    def _clean_fuzzable_request(self, fuzzable_request):
        """
        Very similar to _clean_reference but we receive a fuzzable request
        instead. The output includes the HTTP method and any parameters
        which might be sent over HTTP post-data in the request are
        appended to the result as query string params.

        :param fuzzable_request: The fuzzable request instance to clean
        :return: See _clean_reference
        """
        res = '(%s)-' % fuzzable_request.get_method().upper()

        uri = fuzzable_request.get_uri()
        res += uri.get_domain_path() + uri.get_file_name()

        if uri.has_query_string():
            res += '?' + self._clean_data_container(uri.querystring)

        if fuzzable_request.get_raw_data():
            res += '!' + self._clean_data_container(fuzzable_request.get_raw_data())

        return res
class DiskDeque(object):
    """
    The base code for this file comes from [0], I've modified it to use a
    DiskDict which stores the "self.data" dictionary to disk in order to
    save memory.

    Items are stored in self.data keyed by an increasing integer index;
    `left` and `right` are the half-open bounds of the live items.

    [0] https://code.activestate.com/recipes/259179/
    """
    def __init__(self, iterable=(), maxsize=-1):
        # Skip storage creation when __init__ runs on an instance that
        # already has it (e.g. a second call via __setstate__)
        if not hasattr(self, 'data'):
            self.left = self.right = 0
            self.data = DiskDict()
            # maxsize of -1 means "no limit"
            self.maxsize = maxsize
        self.extend(iterable)

    def append(self, x):
        self.data[self.right] = x
        self.right += 1
        # Enforce the bound by dropping from the opposite end
        if self.maxsize != -1 and len(self) > self.maxsize:
            self.popleft()

    def appendleft(self, x):
        self.left -= 1
        self.data[self.left] = x
        # Enforce the bound by dropping from the opposite end
        if self.maxsize != -1 and len(self) > self.maxsize:
            self.pop()

    def pop(self):
        # left == right means the deque is empty
        if self.left == self.right:
            raise IndexError('cannot pop from empty deque')
        self.right -= 1
        elem = self.data[self.right]
        del self.data[self.right]
        return elem

    def popleft(self):
        if self.left == self.right:
            raise IndexError('cannot pop from empty deque')
        elem = self.data[self.left]
        del self.data[self.left]
        self.left += 1
        return elem

    def clear(self):
        # NOTE(review): relies on DiskDict exposing clear(); a sibling
        # DiskDeque implementation calls data.cleanup() here instead --
        # confirm which one is correct for DiskDict
        self.data.clear()
        self.left = self.right = 0

    def extend(self, iterable):
        for elem in iterable:
            self.append(elem)

    def extendleft(self, iterable):
        for elem in iterable:
            self.appendleft(elem)

    def rotate(self, n=1):
        if self:
            n %= len(self)
            for i in xrange(n):
                self.appendleft(self.pop())

    def __getitem__(self, i):
        # Support negative indices like a regular sequence
        if i < 0:
            i += len(self)
        try:
            return self.data[i + self.left]
        except KeyError:
            raise IndexError

    def __setitem__(self, i, value):
        if i < 0:
            i += len(self)
        try:
            self.data[i + self.left] = value
        except KeyError:
            raise IndexError

    def __delitem__(self, i):
        size = len(self)
        if not (-size <= i < size):
            raise IndexError
        data = self.data
        if i < 0:
            i += size
        # Shift every item after `i` one slot to the left, then drop the
        # now-duplicated right-most element
        for j in xrange(self.left+i, self.right-1):
            data[j] = data[j+1]
        self.pop()

    def __len__(self):
        return self.right - self.left

    def __cmp__(self, other):
        if type(self) != type(other):
            return cmp(type(self), type(other))
        return cmp(list(self), list(other))

    def __repr__(self, _track=[]):
        # The mutable default is intentional: it records the ids of the
        # deques currently being printed to break repr() recursion cycles
        if id(self) in _track:
            return '...'
        _track.append(id(self))
        r = 'deque(%r)' % (list(self),)
        _track.remove(id(self))
        return r

    def __getstate__(self):
        # 1-tuple wrapping matches __setstate__, which unpacks s[0]
        return (tuple(self),)

    def __setstate__(self, s):
        self.__init__(s[0])

    def __hash__(self):
        # Deques are mutable, therefore unhashable
        raise TypeError

    def __copy__(self):
        return self.__class__(self)

    def __deepcopy__(self, memo={}):
        from copy import deepcopy
        result = self.__class__()
        memo[id(self)] = result
        result.__init__(deepcopy(tuple(self), memo))
        return result
class CachedDiskDict(object):
    """
    This data structure keeps the `max_in_memory` most frequently accessed
    keys in memory and stores the rest on disk.

    It is ideal for situations where a DiskDict is frequently accessed,
    fast read / writes are required, and items can take considerable
    amounts of memory.
    """

    # Private sentinel used by get() to tell "no default was provided"
    # apart from any real default value. The previous implementation used
    # `default=-456` together with `default is not -456`, which depends on
    # CPython constant interning and misbehaves when a caller passes -456
    # (or any equal-but-not-identical value) explicitly.
    _MISSING = object()

    def __init__(self, max_in_memory=50, table_prefix=None):
        """
        :param max_in_memory: The max number of items to keep in memory
        """
        assert max_in_memory > 0, 'In-memory items must be > 0'

        table_prefix = self._get_table_prefix(table_prefix)

        self._max_in_memory = max_in_memory
        self._disk_dict = DiskDict(table_prefix=table_prefix)
        self._in_memory = dict()
        self._access_count = dict()

    def cleanup(self):
        self._disk_dict.cleanup()

    def _get_table_prefix(self, table_prefix):
        """
        :return: A unique table prefix, derived from the user-supplied one
                 when available.
        """
        if table_prefix is None:
            table_prefix = 'cached_disk_dict_%s' % rand_alpha(16)
        else:
            args = (table_prefix, rand_alpha(16))
            table_prefix = 'cached_disk_dict_%s_%s' % args

        return table_prefix

    def get(self, key, default=_MISSING):
        """
        :param key: The key to look up
        :param default: Value to return when `key` is missing. When no
                        default is provided a KeyError is raised instead.
        """
        try:
            return self[key]
        except KeyError:
            if default is not self._MISSING:
                return default

            # Include the key in the exception to ease debugging
            raise KeyError(key)

    def __getitem__(self, key):
        try:
            value = self._in_memory[key]
        except KeyError:
            # This will raise KeyError if k is not found, and that is OK
            # because we don't need to increase the access count when the
            # key doesn't exist
            value = self._disk_dict[key]

        self._increase_access_count(key)
        return value

    def _get_keys_for_memory(self):
        """
        :return: Generate the names of the keys that should be kept in
                 memory. For example, if `max_in_memory` is set to 2 and:

                    _in_memory: {1: None, 2: None}
                    _access_count: {1: 10, 2: 20, 3: 5}
                    _disk_dict: {3: None}

                Then the method will generate [1, 2].
        """
        items = self._access_count.items()
        items.sort(sort_by_value)

        keep_in_memory = min(self._max_in_memory, len(items))

        for i in xrange(keep_in_memory):
            yield items[i][0]

    def _belongs_in_memory(self, key):
        """
        :param key: A key
        :return: True if the key should be stored in memory
        """
        if key in self._get_keys_for_memory():
            return True

        return False

    def _increase_access_count(self, key):
        access_count = self._access_count.get(key, 0)
        access_count += 1
        self._access_count[key] = access_count

        self._move_key_to_disk_if_needed(key)
        self._move_key_to_memory_if_needed(key)

    def _move_key_to_disk_if_needed(self, key):
        """
        Analyzes the current access count for the last accessed key and
        checks if any of the keys in memory should be moved to disk.

        :param key: The key that was last accessed
        :return: The name of the key that was moved to disk, or None if
                 all the keys are still in memory.
        """
        # Renamed the loop variable: it used to shadow the `key` parameter
        for mem_key in self._in_memory.keys():

            if not self._belongs_in_memory(mem_key):
                try:
                    value = self._in_memory[mem_key]
                except KeyError:
                    return None
                else:
                    self._disk_dict[mem_key] = value
                    self._in_memory.pop(mem_key, None)
                    return mem_key

    def _move_key_to_memory_if_needed(self, key):
        """
        Analyzes the current access count for the last accessed key and
        checks if any of the keys on disk should be moved to memory.

        :param key: The key that was last accessed
        :return: The name of the key that was moved to memory, or None if
                 all the keys are still on disk.
        """
        key_belongs_in_memory = self._belongs_in_memory(key)

        if not key_belongs_in_memory:
            return None

        try:
            value = self._disk_dict[key]
        except KeyError:
            return None
        else:
            self._in_memory[key] = value
            self._disk_dict.pop(key, None)
            return key

    def __setitem__(self, key, value):
        if len(self._in_memory) < self._max_in_memory:
            self._in_memory[key] = value
        else:
            self._disk_dict[key] = value

        self._increase_access_count(key)

    def __delitem__(self, key):
        try:
            del self._in_memory[key]
        except KeyError:
            # This will raise KeyError if k is not found, and that is OK
            # because we don't need to increase the access count when the
            # key doesn't exist
            del self._disk_dict[key]

        try:
            del self._access_count[key]
        except KeyError:
            # Another thread removed this key
            pass

    def __contains__(self, key):
        if key in self._in_memory:
            self._increase_access_count(key)
            return True

        if key in self._disk_dict:
            self._increase_access_count(key)
            return True

        return False

    def __iter__(self):
        """
        Decided not to increase the access count when iterating through the
        items. In most cases the iteration will be performed on all items,
        thus increasing the access count +1 for each, which will leave all
        access counts +1, forcing no movements between memory and disk.
        """
        for key in self._in_memory:
            yield key

        for key in self._disk_dict:
            yield key