def _spellchecker_for(word_set, name, spellcheck_cache_path=None, sources=None):
    """Get a whoosh spellchecker for :word_set:.

    The word graph for this spellchecker will be stored on-disk with the
    unique name :name: in :spellcheck_cache_path:, if it exists. This allows
    for much faster loading of word graphs after they have been pre-populated.

    :sources: is a list of filenames which will be checked to see if they
    are newer than the stored word graph. If they are newer, then the word
    graph gets repopulated.
    """
    assert "/" not in name and "\\" not in name

    if _spellchecker_cache.get(name, None) is not None:
        return _spellchecker_cache[name].corrector

    # Check the modification time of all the paths in :sources: to see
    # if they've been modified since the cache file was created. If so,
    # delete the cache file. This will cause it to be regenerated.
    #
    # Note that this relies on an implementation detail in whoosh, namely
    # that the cache file is always stored at spellcheck_cache_path/name.
    if spellcheck_cache_path:
        # Ensure that the directory has been created
        try:
            os.makedirs(spellcheck_cache_path)
        except OSError as error:
            if error.errno != errno.EEXIST:  # suppress(PYC90)
                raise error

        graph_path = os.path.realpath(spellcheck_cache_path)
        file_storage = FileStorage(graph_path)

        preexisting_cache = os.path.abspath(
            os.path.join(spellcheck_cache_path, name))
        if os.path.exists(preexisting_cache):
            cache_mtime = os.path.getmtime(preexisting_cache)
            for source in sources:
                source_path = os.path.realpath(source)
                if not os.path.exists(source_path):
                    continue
                if os.path.getmtime(source_path) > cache_mtime:
                    file_storage.delete_file(name)
                    break

        try:
            word_graph = copy_to_ram(file_storage).open_file(name)
        except (IOError, NameError):
            word_graph = _create_word_graph_file(name, file_storage, word_set)
    else:
        ram_storage = RamStorage()
        word_graph = _create_word_graph_file(name, ram_storage, word_set)

    reader = fst.GraphReader(word_graph)
    corrector = spelling.GraphCorrector(reader)
    _spellchecker_cache[name] = SpellcheckerCacheEntry(corrector, reader)
    return corrector
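# A minimal usage sketch (illustrative, not part of the original module). It
# assumes the module-level _spellchecker_cache dict and the whoosh imports the
# function relies on (FileStorage, RamStorage, copy_to_ram, fst, spelling) are
# in scope; the word set, cache directory, and source file are hypothetical.
if __name__ == "__main__":
    corrector = _spellchecker_for({"install", "update", "remove"},
                                  "commands",  # unique on-disk graph name
                                  spellcheck_cache_path="/tmp/spellcheck-cache",
                                  sources=["commands.txt"])
    # whoosh's Corrector.suggest() returns the closest words in the graph
    print(corrector.suggest("instal", limit=3))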
class Indexer(RaftNode):
    def __init__(self, host='localhost', port=7070, seed_addr=None,
                 conf=SyncObjConf(), data_dir='/tmp/cockatrice/index',
                 grpc_port=5050, grpc_max_workers=10, http_port=8080,
                 logger=getLogger(), http_logger=getLogger(),
                 metrics_registry=CollectorRegistry()):
        self.__host = host
        self.__port = port
        self.__seed_addr = seed_addr
        self.__conf = conf
        self.__data_dir = data_dir
        self.__grpc_port = grpc_port
        self.__grpc_max_workers = grpc_max_workers
        self.__http_port = http_port
        self.__logger = logger
        self.__http_logger = http_logger
        self.__metrics_registry = metrics_registry

        # metrics
        self.__metrics_core_documents = Gauge(
            '{0}_indexer_index_documents'.format(NAME),
            'The number of documents.',
            ['index_name'],
            registry=self.__metrics_registry)
        self.__metrics_requests_total = Counter(
            '{0}_indexer_requests_total'.format(NAME),
            'The number of requests.',
            ['func'],
            registry=self.__metrics_registry)
        self.__metrics_requests_duration_seconds = Histogram(
            '{0}_indexer_requests_duration_seconds'.format(NAME),
            'The invocation duration in seconds.',
            ['func'],
            registry=self.__metrics_registry)

        self.__self_addr = '{0}:{1}'.format(self.__host, self.__port)
        self.__peer_addrs = [] if self.__seed_addr is None else get_peers(
            bind_addr=self.__seed_addr, timeout=10)
        self.__other_addrs = [peer_addr for peer_addr in self.__peer_addrs
                              if peer_addr != self.__self_addr]
        self.__conf.serializer = self.__serialize
        self.__conf.deserializer = self.__deserialize
        self.__conf.validate()

        self.__indices = {}
        self.__index_configs = {}
        self.__writers = {}
        self.__auto_commit_timers = {}

        self.__lock = RLock()

        # create the data dir
        os.makedirs(self.__data_dir, exist_ok=True)
        self.__file_storage = FileStorage(self.__data_dir, supports_mmap=True,
                                          readonly=False, debug=False)
        self.__ram_storage = RamStorage()

        # if a seed addr is specified and this node is not yet in the cluster,
        # add it to the cluster
        if self.__seed_addr is not None and self.__self_addr not in self.__peer_addrs:
            Thread(target=add_node,
                   kwargs={'node_name': self.__self_addr,
                           'bind_addr': self.__seed_addr,
                           'timeout': 10}).start()

        # copy the snapshot from the leader node
        if self.__seed_addr is not None:
            try:
                metadata = get_metadata(
                    bind_addr=get_leader(bind_addr=self.__seed_addr, timeout=10),
                    timeout=10)
                response = requests.get(
                    'http://{0}/snapshot'.format(metadata['http_addr']))
                if response.status_code == HTTPStatus.OK:
                    with open(self.__conf.fullDumpFile, 'wb') as f:
                        f.write(response.content)
            except Exception as ex:
                self.__logger.error('failed to copy snapshot: {0}'.format(ex))

        # start the node
        metadata = {
            'grpc_addr': '{0}:{1}'.format(self.__host, self.__grpc_port),
            'http_addr': '{0}:{1}'.format(self.__host, self.__http_port)
        }
        self.__logger.info('starting raft state machine')
        super(Indexer, self).__init__(self.__self_addr, self.__peer_addrs,
                                      conf=self.__conf, metadata=metadata)
        self.__logger.info('raft state machine has started')

        if os.path.exists(self.__conf.fullDumpFile):
            self.__logger.debug(
                'snapshot exists: {0}'.format(self.__conf.fullDumpFile))

        while not self.isReady():
            # recovering data
            self.__logger.debug('waiting for cluster ready')
            self.__logger.debug(self.getStatus())
            time.sleep(1)
        self.__logger.info('cluster ready')
        self.__logger.debug(self.getStatus())

        # open existing indices on startup
        for index_name in self.get_index_names():
            self.__open_index(index_name, index_config=None)

        # timer for recording index metrics
        self.metrics_timer = Timer(10, self.__record_index_metrics)
        self.metrics_timer.start()

        # start the gRPC server
        self.__grpc_server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self.__grpc_max_workers))
        add_IndexServicer_to_server(
            IndexGRPCServicer(self, logger=self.__logger,
                              metrics_registry=self.__metrics_registry),
            self.__grpc_server)
        self.__grpc_server.add_insecure_port(
            '{0}:{1}'.format(self.__host, self.__grpc_port))
        self.__grpc_server.start()
        self.__logger.info('gRPC server has started')

        # start the HTTP server
        self.__http_servicer = IndexHTTPServicer(self, self.__logger,
                                                 self.__http_logger,
                                                 self.__metrics_registry)
        self.__http_server = HTTPServer(self.__host, self.__http_port,
                                        self.__http_servicer)
        self.__http_server.start()
        self.__logger.info('HTTP server has started')

        self.__logger.info('indexer has started')

    def stop(self):
        # stop the HTTP server
        self.__http_server.stop()
        self.__logger.info('HTTP server has stopped')

        # stop the gRPC server
        self.__grpc_server.stop(grace=0.0)
        self.__logger.info('gRPC server has stopped')

        self.metrics_timer.cancel()

        # close the indices
        for index_name in list(self.__indices.keys()):
            self.__close_index(index_name)

        self.destroy()
        self.__logger.info('index core has stopped')

    def __record_index_metrics(self):
        for index_name in list(self.__indices.keys()):
            try:
                self.__metrics_core_documents.labels(
                    index_name=index_name).set(self.get_doc_count(index_name))
            except Exception as ex:
                self.__logger.error(ex)

    def __record_metrics(self, start_time, func_name):
        self.__metrics_requests_total.labels(func=func_name).inc()
        self.__metrics_requests_duration_seconds.labels(
            func=func_name).observe(time.time() - start_time)

    # def __serialize_indices(self, filename):
    #     with self.__lock:
    #         try:
    #             self.__logger.info('starting serialize indices')
    #         except Exception as ex:
    #             self.__logger.error('failed to create snapshot: {0}'.format(ex))
    #         finally:
    #             self.__logger.info('serialize indices has finished')

    # def __serialize_raft_data(self, filename, raft_data):
    #     with self.__lock:
    #         pass

    # index serializer
    def __serialize(self, filename, raft_data):
        with self.__lock:
            try:
                self.__logger.debug('serializer has started')

                # store the index files and raft logs in the snapshot file
                with zipfile.ZipFile(filename, 'w', zipfile.ZIP_DEFLATED) as f:
                    for index_name in self.get_index_names():
                        self.__commit_index(index_name)

                        # with self.__get_writer(index_name).writelock:
                        # with self.__indices[index_name].lock('WRITELOCK'):

                        # index files
                        for index_filename in self.get_index_files(index_name):
                            if self.__index_configs.get(
                                    index_name).get_storage_type() == 'ram':
                                with self.__ram_storage.open_file(
                                        index_filename) as r:
                                    f.writestr(index_filename, r.read())
                            else:
                                f.write(
                                    os.path.join(self.__file_storage.folder,
                                                 index_filename),
                                    index_filename)
                            self.__logger.debug(
                                '{0} has been stored in {1}'.format(
                                    index_filename, filename))

                        # index config file
                        f.write(
                            os.path.join(
                                self.__file_storage.folder,
                                self.get_index_config_file(index_name)),
                            self.get_index_config_file(index_name))
                        self.__logger.debug('{0} has been stored in {1}'.format(
                            self.get_index_config_file(index_name), filename))

                    # store the raft data
                    f.writestr(RAFT_DATA_FILE, pickle.dumps(raft_data))
                    self.__logger.debug(
                        '{0} has been stored'.format(RAFT_DATA_FILE))

                self.__logger.debug('snapshot has been created')
            except Exception as ex:
                self.__logger.error(
                    'failed to create snapshot: {0}'.format(ex))
            finally:
                self.__logger.debug('serializer has stopped')

    # index deserializer
    def __deserialize(self, filename):
        with self.__lock:
            try:
                self.__logger.debug('deserializer has started')

                with zipfile.ZipFile(filename, 'r') as zf:
                    # get the file names in the snapshot file
                    filenames = list(zf.namelist())

                    # get the index names in the snapshot file
                    index_names = []
                    pattern_toc = re.compile(r'^_(.+)_\d+\.toc$')
                    for f in filenames:
                        match = pattern_toc.search(f)
                        if match and match.group(1) not in index_names:
                            index_names.append(match.group(1))

                    for index_name in index_names:
                        # extract the index config first
                        zf.extract(self.get_index_config_file(index_name),
                                   path=self.__file_storage.folder)
                        index_config = pickle.loads(
                            zf.read(self.get_index_config_file(index_name)))

                        # get the index files
                        pattern_toc = re.compile(
                            r'^_{0}_(\d+)\..+$'.format(index_name))  # ex) _myindex_0.toc
                        pattern_seg = re.compile(
                            r'^{0}_([a-z0-9]+)\..+$'.format(index_name))  # ex) myindex_zseabukc2nbpvh0u.seg
                        pattern_lock = re.compile(
                            r'^{0}_WRITELOCK$'.format(index_name))  # ex) myindex_WRITELOCK
                        index_files = []
                        for file_name in filenames:
                            if re.match(pattern_toc, file_name):
                                index_files.append(file_name)
                            elif re.match(pattern_seg, file_name):
                                index_files.append(file_name)
                            elif re.match(pattern_lock, file_name):
                                index_files.append(file_name)

                        # extract the index files
                        for index_file in index_files:
                            if index_config.get_storage_type() == 'ram':
                                with self.__ram_storage.create_file(
                                        index_file) as r:
                                    r.write(zf.read(index_file))
                            else:
                                zf.extract(index_file,
                                           path=self.__file_storage.folder)
                            self.__logger.debug(
                                '{0} has been restored from {1}'.format(
                                    index_file, filename))

                        self.__logger.debug(
                            '{0} has been restored'.format(index_name))

                    # extract the raft data
                    raft_data = pickle.loads(zf.read(RAFT_DATA_FILE))
                    self.__logger.debug(
                        '{0} has been restored'.format(RAFT_DATA_FILE))

                    return raft_data
            except Exception as ex:
                self.__logger.error(
                    'failed to restore indices: {0}'.format(ex))
            finally:
                self.__logger.debug('deserializer has stopped')

    def is_healthy(self):
        return self.isHealthy()

    def is_alive(self):
        return self.isAlive()

    def is_ready(self):
        return self.isReady()

    def get_addr(self):
        return self.__self_addr

    def get_index_files(self, index_name):
        index_files = []

        pattern_toc = re.compile(
            r'^_{0}_(\d+)\..+$'.format(index_name))  # ex) _myindex_0.toc
        pattern_seg = re.compile(
            r'^{0}_([a-z0-9]+)\..+$'.format(index_name))  # ex) myindex_zseabukc2nbpvh0u.seg
        pattern_lock = re.compile(
            r'^{0}_WRITELOCK$'.format(index_name))  # ex) myindex_WRITELOCK

        if self.__index_configs.get(index_name).get_storage_type() == 'ram':
            storage = self.__ram_storage
        else:
            storage = self.__file_storage

        for file_name in list(storage.list()):
            if re.match(pattern_toc, file_name):
                index_files.append(file_name)
            elif re.match(pattern_seg, file_name):
                index_files.append(file_name)
            elif re.match(pattern_lock, file_name):
                index_files.append(file_name)

        return index_files

    @staticmethod
    def get_index_config_file(index_name):
        return '{0}_CONFIG'.format(index_name)

    def get_index_names(self):
        index_names = []

        pattern_toc = re.compile(r'^_(.+)_\d+\.toc$')
        for filename in list(self.__file_storage.list()):
            match = pattern_toc.search(filename)
            if match and match.group(1) not in index_names:
                index_names.append(match.group(1))
        for filename in list(self.__ram_storage.list()):
            match = pattern_toc.search(filename)
            if match and match.group(1) not in index_names:
                index_names.append(match.group(1))

        return index_names

    def is_index_exist(self, index_name):
        return self.__file_storage.index_exists(
            indexname=index_name) or self.__ram_storage.index_exists(
                indexname=index_name)

    def is_index_open(self, index_name):
        return index_name in self.__indices

    @replicated
    def open_index(self, index_name, index_config=None):
        return self.__open_index(index_name, index_config=index_config)

    def __open_index(self, index_name, index_config=None):
        start_time = time.time()

        index = None
        try:
            # open the index
            index = self.__indices.get(index_name)
            if index is None:
                self.__logger.debug('opening {0}'.format(index_name))
                if index_config is None:
                    # use the saved index config
                    with open(os.path.join(
                            self.__file_storage.folder,
                            self.get_index_config_file(index_name)), 'rb') as f:
                        self.__index_configs[index_name] = pickle.loads(f.read())
                else:
                    # use the given index config
                    self.__index_configs[index_name] = index_config

                if self.__index_configs[index_name].get_storage_type() == 'ram':
                    index = self.__ram_storage.open_index(
                        indexname=index_name,
                        schema=self.__index_configs[index_name].get_schema())
                else:
                    index = self.__file_storage.open_index(
                        indexname=index_name,
                        schema=self.__index_configs[index_name].get_schema())
                self.__indices[index_name] = index
                self.__logger.info('{0} has been opened'.format(index_name))

                # open the index writer
                self.__open_writer(index_name)
        except Exception as ex:
            self.__logger.error('failed to open {0}: {1}'.format(
                index_name, ex))
        finally:
            self.__record_metrics(start_time, 'open_index')

        return index

    @replicated
    def close_index(self, index_name):
        return self.__close_index(index_name)

    def __close_index(self, index_name):
        start_time = time.time()

        index = None
        try:
            # close the index writer
            self.__close_writer(index_name)

            # close the index
            index = self.__indices.pop(index_name)
            if index is not None:
                self.__logger.debug('closing {0}'.format(index_name))
                index.close()
                self.__logger.info('{0} has been closed'.format(index_name))
        except Exception as ex:
            self.__logger.error('failed to close {0}: {1}'.format(
                index_name, ex))
        finally:
            self.__record_metrics(start_time, 'close_index')

        return index

    @replicated
    def create_index(self, index_name, index_config):
        return self.__create_index(index_name, index_config)

    def __create_index(self, index_name, index_config):
        if self.is_index_exist(index_name):
            # open the existing index instead
            return self.__open_index(index_name, index_config=index_config)

        start_time = time.time()

        index = None
        with self.__lock:
            try:
                self.__logger.debug('creating {0}'.format(index_name))

                # set the index config
                self.__index_configs[index_name] = index_config
                self.__logger.debug(
                    self.__index_configs[index_name].get_storage_type())

                # create the index
                if self.__index_configs[index_name].get_storage_type() == 'ram':
                    index = self.__ram_storage.create_index(
                        self.__index_configs[index_name].get_schema(),
                        indexname=index_name)
                else:
                    index = self.__file_storage.create_index(
                        self.__index_configs[index_name].get_schema(),
                        indexname=index_name)
                self.__indices[index_name] = index
                self.__logger.info('{0} has been created'.format(index_name))

                # save the index config
                with open(os.path.join(
                        self.__file_storage.folder,
                        self.get_index_config_file(index_name)), 'wb') as f:
                    f.write(pickle.dumps(index_config))

                # open the index writer
                self.__open_writer(index_name)
            except Exception as ex:
                self.__logger.error('failed to create {0}: {1}'.format(
                    index_name, ex))
            finally:
                self.__record_metrics(start_time, 'create_index')

        return index

    @replicated
    def delete_index(self, index_name):
        return self.__delete_index(index_name)

    def __delete_index(self, index_name):
        # close the index
        index = self.__close_index(index_name)

        start_time = time.time()

        with self.__lock:
            try:
                self.__logger.debug('deleting {0}'.format(index_name))

                # delete the index files
                for filename in self.get_index_files(index_name):
                    self.__file_storage.delete_file(filename)
                    self.__logger.debug('{0} was deleted'.format(filename))
                self.__logger.info('{0} has been deleted'.format(index_name))

                # delete the index config
                self.__index_configs.pop(index_name, None)
                os.remove(os.path.join(
                    self.__file_storage.folder,
                    self.get_index_config_file(index_name)))
            except Exception as ex:
                self.__logger.error('failed to delete {0}: {1}'.format(
                    index_name, ex))
            finally:
                self.__record_metrics(start_time, 'delete_index')

        return index

    def get_index(self, index_name):
        return self.__get_index(index_name)

    def __get_index(self, index_name):
        start_time = time.time()

        try:
            index = self.__indices.get(index_name)
        except Exception as ex:
            raise ex
        finally:
            self.__record_metrics(start_time, 'get_index')

        return index

    def __start_auto_commit_timer(self, index_name, period):
        timer = self.__auto_commit_timers.get(index_name, None)
        if timer is None:
            self.__auto_commit_timers[index_name] = threading.Timer(
                period, self.__auto_commit_index,
                kwargs={'index_name': index_name, 'period': period})
            self.__auto_commit_timers[index_name].start()
            self.__logger.debug(
                'auto commit timer for {0} was started'.format(index_name))

    def __stop_auto_commit_timer(self, index_name):
        timer = self.__auto_commit_timers.pop(index_name, None)
        if timer is not None:
            timer.cancel()
            self.__logger.debug(
                'auto commit timer for {0} was stopped'.format(index_name))

    def __auto_commit_index(self, index_name, period):
        self.__stop_auto_commit_timer(index_name)
        self.__commit_index(index_name)
        self.__start_auto_commit_timer(index_name, period=period)

    def __open_writer(self, index_name):
        writer = None
        try:
            writer = self.__writers.get(index_name, None)
            if writer is None or writer.is_closed:
                self.__logger.debug(
                    'opening writer for {0}'.format(index_name))
                writer = self.__indices.get(index_name).writer()
                self.__writers[index_name] = writer
                self.__logger.debug(
                    'writer for {0} has been opened'.format(index_name))
                self.__start_auto_commit_timer(
                    index_name,
                    period=self.__index_configs.get(
                        index_name).get_writer_auto_commit_period())
        except Exception as ex:
            self.__logger.error('failed to open writer for {0}: {1}'.format(
                index_name, ex))

        return writer

    def __close_writer(self, index_name):
        writer = None
        try:
            self.__stop_auto_commit_timer(index_name)

            # close the index writer
            writer = self.__writers.pop(index_name, None)
            if writer is not None:
                self.__logger.debug(
                    'closing writer for {0}'.format(index_name))
                writer.commit()
                self.__logger.debug(
                    'writer for {0} has been closed'.format(index_name))
        except Exception as ex:
            self.__logger.error('failed to close writer for {0}: {1}'.format(
                index_name, ex))

        return writer

    def __get_writer(self, index_name):
        return self.__writers.get(index_name, None)

    def __get_searcher(self, index_name, weighting=None):
        try:
            if weighting is None:
                searcher = self.__indices.get(index_name).searcher()
            else:
                searcher = self.__indices.get(index_name).searcher(
                    weighting=weighting)
        except Exception as ex:
            raise ex

        return searcher

    @replicated
    def commit_index(self, index_name):
        return self.__commit_index(index_name)

    def __commit_index(self, index_name):
        start_time = time.time()

        success = False
        with self.__lock:
            try:
                self.__logger.debug('committing {0}'.format(index_name))
                self.__get_writer(index_name).commit()
                self.__open_writer(index_name)  # reopen the writer
                self.__logger.info('{0} has been committed'.format(index_name))
                success = True
            except Exception as ex:
                self.__logger.error('failed to commit index {0}: {1}'.format(
                    index_name, ex))
            finally:
                self.__record_metrics(start_time, 'commit_index')

        return success

    @replicated
    def rollback_index(self, index_name):
        return self.__rollback_index(index_name)

    def __rollback_index(self, index_name):
        start_time = time.time()

        success = False
        with self.__lock:
            try:
                self.__logger.debug('rolling back {0}'.format(index_name))
                self.__get_writer(index_name).cancel()
                self.__open_writer(index_name)  # reopen the writer
                self.__logger.info(
                    '{0} has been rolled back'.format(index_name))
                success = True
            except Exception as ex:
                self.__logger.error('failed to roll back index {0}: {1}'.format(
                    index_name, ex))
            finally:
                self.__record_metrics(start_time, 'rollback_index')

        return success

    @replicated
    def optimize_index(self, index_name):
        return self.__optimize_index(index_name)

    def __optimize_index(self, index_name):
        start_time = time.time()

        success = False
        with self.__lock:
            try:
                self.__logger.debug('optimizing {0}'.format(index_name))
                self.__get_writer(index_name).commit(optimize=True,
                                                     merge=False)
                self.__open_writer(index_name)  # reopen the writer
                self.__logger.info('{0} has been optimized'.format(index_name))
                success = True
            except Exception as ex:
                self.__logger.error('failed to optimize {0}: {1}'.format(
                    index_name, ex))
            finally:
                self.__record_metrics(start_time, 'optimize_index')

        return success

    def get_doc_count(self, index_name):
        try:
            cnt = self.__indices.get(index_name).doc_count()
        except Exception as ex:
            raise ex

        return cnt

    def get_schema(self, index_name):
        try:
            schema = self.__indices.get(index_name).schema
        except Exception as ex:
            raise ex

        return schema

    @replicated
    def put_document(self, index_name, doc_id, fields):
        return self.__put_document(index_name, doc_id, fields)

    def __put_document(self, index_name, doc_id, fields):
        doc = copy.deepcopy(fields)
        doc[self.__index_configs.get(index_name).get_doc_id_field()] = doc_id

        return self.__put_documents(index_name, [doc])

    @replicated
    def put_documents(self, index_name, docs):
        return self.__put_documents(index_name, docs)

    def __put_documents(self, index_name, docs):
        start_time = time.time()

        with self.__lock:
            try:
                self.__logger.debug(
                    'putting documents to {0}'.format(index_name))

                # count = self.__get_writer(index_name).update_documents(docs)
                count = 0
                for doc in docs:
                    self.__get_writer(index_name).update_document(**doc)
                    count += 1
                self.__logger.info('{0} documents have been put to {1}'.format(
                    count, index_name))
            except Exception as ex:
                self.__logger.error(
                    'failed to put documents to {0}: {1}'.format(
                        index_name, ex))
                count = -1
            finally:
                self.__record_metrics(start_time, 'put_documents')

        return count

    def get_document(self, index_name, doc_id):
        try:
            results_page = self.search_documents(
                index_name, doc_id,
                self.__index_configs.get(index_name).get_doc_id_field(), 1,
                page_len=1)
            if results_page.total > 0:
                self.__logger.debug('{0} was retrieved from {1}'.format(
                    doc_id, index_name))
            else:
                self.__logger.debug('{0} did not exist in {1}'.format(
                    doc_id, index_name))
        except Exception as ex:
            raise ex

        return results_page

    @replicated
    def delete_document(self, index_name, doc_id):
        return self.__delete_document(index_name, doc_id)

    def __delete_document(self, index_name, doc_id):
        return self.__delete_documents(index_name, [doc_id])

    @replicated
    def delete_documents(self, index_name, doc_ids):
        return self.__delete_documents(index_name, doc_ids)

    def __delete_documents(self, index_name, doc_ids):
        start_time = time.time()

        with self.__lock:
            try:
                self.__logger.debug(
                    'deleting documents from {0}'.format(index_name))

                # count = self.__get_writer(index_name).delete_documents(
                #     doc_ids, doc_id_field=self.__index_configs.get(
                #         index_name).get_doc_id_field())
                count = 0
                for doc_id in doc_ids:
                    count += self.__get_writer(index_name).delete_by_term(
                        self.__index_configs.get(
                            index_name).get_doc_id_field(), doc_id)
                self.__logger.info(
                    '{0} documents have been deleted from {1}'.format(
                        count, index_name))
            except Exception as ex:
                self.__logger.error(
                    'failed to delete documents from {0}: {1}'.format(
                        index_name, ex))
                count = -1
            finally:
                self.__record_metrics(start_time, 'delete_documents')

        return count

    def search_documents(self, index_name, query, search_field, page_num,
                         page_len=10, weighting=None, **kwargs):
        start_time = time.time()

        try:
            searcher = self.__get_searcher(index_name, weighting=weighting)
            query_parser = QueryParser(search_field,
                                       self.get_schema(index_name))
            query_obj = query_parser.parse(query)
            results_page = searcher.search_page(query_obj, page_num,
                                                pagelen=page_len, **kwargs)
            self.__logger.info('{0} documents were searched from {1}'.format(
                results_page.total, index_name))
        except Exception as ex:
            raise ex
        finally:
            self.__record_metrics(start_time, 'search_documents')

        return results_page

    @replicated
    def create_snapshot(self):
        self.__create_snapshot()

    def __create_snapshot(self):
        self.forceLogCompaction()

    def get_snapshot_file_name(self):
        return self.__conf.fullDumpFile

    def is_snapshot_exist(self):
        return os.path.exists(self.get_snapshot_file_name())

    def open_snapshot_file(self):
        with self.__lock:
            try:
                file = open(self.get_snapshot_file_name(), mode='rb')
            except Exception as ex:
                raise ex

        return file
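# A minimal single-node usage sketch (illustrative, not from the original
# source). The stand-in config class below implements just the four accessors
# the Indexer above actually calls; the real project presumably ships its own
# index-config class. Ports, paths, and the whoosh schema are hypothetical.
if __name__ == "__main__":
    from whoosh.fields import Schema, ID, TEXT

    class _SketchIndexConfig:
        """Hypothetical stand-in for the project's index config object."""
        def get_storage_type(self):
            return 'file'  # or 'ram'

        def get_schema(self):
            return Schema(id=ID(stored=True, unique=True),
                          title=TEXT(stored=True))

        def get_doc_id_field(self):
            return 'id'

        def get_writer_auto_commit_period(self):
            return 60  # seconds between auto-commits

    conf = SyncObjConf(fullDumpFile='/tmp/cockatrice/index/snapshot.zip')
    indexer = Indexer(conf=conf, data_dir='/tmp/cockatrice/index')
    try:
        # the mutating calls below are @replicated, i.e. proposed through Raft
        indexer.create_index('myindex', _SketchIndexConfig())
        indexer.put_document('myindex', '1', {'title': 'hello'})
        indexer.commit_index('myindex')
        page = indexer.search_documents('myindex', 'hello', 'title', 1,
                                        page_len=10)
    finally:
        indexer.stop()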
class TestReadWrite(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(TestReadWrite, self).__init__(*args, **kwargs)
        self.fs = FileStorage(".")

    def make_postings(self):
        postings = [(1, 23), (3, 45), (12, 2), (34, 21), (43, 7), (67, 103),
                    (68, 1), (102, 31), (145, 4), (212, 9), (283, 30),
                    (291, 6), (412, 39), (900, 50), (905, 28), (1024, 8),
                    (1800, 13), (2048, 3), (15000, 40)]
        return postings

    def make_file(self, name):
        return self.fs.create_file(name + "_test.pst")

    def open_file(self, name):
        return self.fs.open_file(name + "_test.pst")

    def delete_file(self, name):
        try:
            self.fs.delete_file(name + "_test.pst")
        except OSError:
            pass

    def test_readwrite(self):
        format = Frequency(None)
        postings = self.make_postings()

        postfile = self.make_file("readwrite")
        try:
            fpw = FilePostingWriter(postfile, blocklimit=8)
            fpw.start(format)
            for id, freq in postings:
                fpw.write(id, format.encode(freq))
            fpw.close()

            postfile = self.open_file("readwrite")
            fpr = FilePostingReader(postfile, 0, format)
            #self.assertEqual(postings, list(fpr.items_as("frequency")))
            fpr.close()
        finally:
            self.delete_file("readwrite")

    def test_skip(self):
        format = Frequency(None)
        postings = self.make_postings()

        postfile = self.make_file("skip")
        try:
            fpw = FilePostingWriter(postfile, blocklimit=8)
            fpw.start(format)
            for id, freq in postings:
                fpw.write(id, format.encode(freq))
            fpw.close()

            postfile = self.open_file("skip")
            fpr = FilePostingReader(postfile, 0, format)
            #fpr.skip_to(220)
            #self.assertEqual(postings[10:], list(fpr.items_as("frequency")))
            fpr.close()
        finally:
            self.delete_file("skip")

    def roundtrip(self, postings, format, astype):
        postfile = self.make_file(astype)
        readback = None
        try:
            fpw = FilePostingWriter(postfile, blocklimit=8)
            fpw.start(format)
            for id, value in postings:
                fpw.write(id, format.encode(value))
            fpw.close()

            postfile = self.open_file(astype)
            fpr = FilePostingReader(postfile, 0, format)
            readback = list(fpr.all_as(astype))
            fpr.close()
        finally:
            self.delete_file(astype)
        return readback

    def test_existence_postings(self):
        postings = []
        docnum = 0
        for _ in xrange(0, 20):
            docnum += randint(1, 10)
            postings.append((docnum, 1))

        self.assertEqual(
            postings, self.roundtrip(postings, Existence(None), "frequency"))

    def test_docboost_postings(self):
        postings = []
        docnum = 0
        for _ in xrange(0, 20):
            docnum += randint(1, 10)
            freq = randint(1, 1000)
            boost = byte_to_float(float_to_byte(random() * 2))
            postings.append((docnum, (freq, boost)))

        self.assertEqual(
            postings, self.roundtrip(postings, DocBoosts(None), "docboosts"))

    def test_position_postings(self):
        postings = []
        docnum = 0
        for _ in xrange(0, 20):
            docnum += randint(1, 10)
            posns = []
            pos = 0
            for __ in xrange(0, randint(1, 10)):
                pos += randint(1, 10)
                posns.append(pos)
            postings.append((docnum, posns))

        self.assertEqual(
            postings, self.roundtrip(postings, Positions(None), "positions"))

        as_freq = [(docnum, len(posns)) for docnum, posns in postings]
        self.assertEqual(
            as_freq, self.roundtrip(postings, Positions(None), "frequency"))

    def test_character_postings(self):
        postings = []
        docnum = 0
        for _ in xrange(0, 20):
            docnum += randint(1, 10)
            posns = []
            pos = 0
            endchar = 0
            for __ in xrange(0, randint(1, 10)):
                pos += randint(1, 10)
                startchar = endchar + randint(3, 10)
                endchar = startchar + randint(3, 10)
                posns.append((pos, startchar, endchar))
            postings.append((docnum, posns))

        self.assertEqual(
            postings,
            self.roundtrip(postings, Characters(None), "characters"))

        as_posns = [(docnum, [pos for pos, sc, ec in posns])
                    for docnum, posns in postings]
        self.assertEqual(
            as_posns, self.roundtrip(postings, Characters(None), "positions"))

        as_freq = [(docnum, len(posns)) for docnum, posns in as_posns]
        self.assertEqual(
            as_freq, self.roundtrip(postings, Characters(None), "frequency"))

    def test_posboost_postings(self):
        postings = []
        docnum = 0
        for _ in xrange(0, 3):
            docnum += randint(1, 10)
            posns = []
            pos = 0
            for __ in xrange(0, randint(1, 3)):
                pos += randint(1, 10)
                boost = byte_to_float(float_to_byte(random() * 2))
                posns.append((pos, boost))
            postings.append((docnum, posns))

        self.assertEqual(
            postings,
            self.roundtrip(postings, PositionBoosts(None), "position_boosts"))

        as_posns = [(docnum, [pos for pos, boost in posns])
                    for docnum, posns in postings]
        self.assertEqual(
            as_posns,
            self.roundtrip(postings, PositionBoosts(None), "positions"))

        as_freq = [(docnum, len(posns)) for docnum, posns in postings]
        self.assertEqual(
            as_freq,
            self.roundtrip(postings, PositionBoosts(None), "frequency"))

    def test_charboost_postings(self):
        postings = []
        docnum = 0
        for _ in xrange(0, 20):
            docnum += randint(1, 10)
            posns = []
            pos = 0
            endchar = 0
            for __ in xrange(0, randint(1, 10)):
                pos += randint(1, 10)
                startchar = endchar + randint(3, 10)
                endchar = startchar + randint(3, 10)
                boost = byte_to_float(float_to_byte(random() * 2))
                posns.append((pos, startchar, endchar, boost))
            postings.append((docnum, posns))

        self.assertEqual(
            postings,
            self.roundtrip(postings, CharacterBoosts(None),
                           "character_boosts"))

        as_chars = [(docnum, [(pos, sc, ec) for pos, sc, ec, bst in posns])
                    for docnum, posns in postings]
        self.assertEqual(
            as_chars,
            self.roundtrip(postings, CharacterBoosts(None), "characters"))

        as_posbsts = [(docnum, [(pos, bst) for pos, sc, ec, bst in posns])
                      for docnum, posns in postings]
        self.assertEqual(
            as_posbsts,
            self.roundtrip(postings, CharacterBoosts(None),
                           "position_boosts"))

        as_posns = [(docnum, [pos for pos, sc, ec, bst in posns])
                    for docnum, posns in postings]
        self.assertEqual(
            as_posns,
            self.roundtrip(postings, CharacterBoosts(None), "positions"))

        as_freq = [(docnum, len(posns)) for docnum, posns in as_posns]
        self.assertEqual(
            as_freq,
            self.roundtrip(postings, CharacterBoosts(None), "frequency"))
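# The tests above can be driven by the standard unittest runner; they assume
# the whoosh imports used in the class (FileStorage, FilePostingWriter,
# FilePostingReader, the posting formats, byte_to_float/float_to_byte) plus
# randint/random are in scope, and the xrange calls imply a Python 2 era
# codebase.
if __name__ == '__main__':
    unittest.main()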