import socket

from twisted.internet import defer, reactor
from twisted.python.threadpool import ThreadPool


class AddressResolver(object):
    pool = None

    def __init__(self, minthreads=1, maxthreads=4):
        self.pool = ThreadPool(minthreads=minthreads, maxthreads=maxthreads)
        # An unclosed ThreadPool leads to reactor hangs at shutdown.
        # This is a problem in many situations, so better to enforce pool stop here.
        reactor.addSystemEventTrigger(
            "before", "shutdown", self.pool.stop
        )
        self.pool.start()

    def get_host_by_name(self, address):
        d = defer.Deferred()

        def func():
            try:
                # gethostbyname() blocks, so hand the result back to the
                # reactor thread via callFromThread().
                reactor.callFromThread(
                    d.callback, socket.gethostbyname(address)
                )
            except Exception as e:
                reactor.callFromThread(d.errback, e)

        self.pool.callInThread(func)
        return d

    def close(self):
        self.pool.stop()
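# --- Usage sketch (added) ---
# A minimal, hedged example of driving AddressResolver under a running reactor.
# "example.com" is an arbitrary hostname. close() is deliberately not called:
# the "before shutdown" trigger registered in __init__ already stops the pool,
# and stopping a ThreadPool twice can raise AlreadyQuit in recent Twisted.
if __name__ == "__main__":
    resolver = AddressResolver()

    def on_result(ip):
        print("resolved to:", ip)
        reactor.stop()

    def on_error(failure):
        print("lookup failed:", failure)
        reactor.stop()

    resolver.get_host_by_name("example.com").addCallbacks(on_result, on_error)
    reactor.run()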
def run_by_pool():
    urls = [LINK_URL % n for n in range(1, PAGE_NUM + 1)]
    print(urls)  # 5 * 20: at most 100 threads running
    error_log("start:" + str(time.time()))
    pool = ThreadPool(minthreads=1, maxthreads=5)
    for url in urls:
        # Work queued with callInThread() before start() runs once the pool starts.
        pool.callInThread(start, url, save_path=IMAGE_SAVE_BASEPATH)
    pool.start()
    while True:
        # Check the pool state every 20s; once no thread is still working,
        # stop the pool and finish the download run.
        time.sleep(20)
        if len(pool.working) == 0:
            pool.stop()
            error_log("end:" + str(time.time()))
            break
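# --- Alternative sketch (added) ---
# The same fan-out without the 20-second polling loop: deferToThreadPool()
# returns a Deferred per URL, and a DeferredList fires once every download has
# finished, at which point the pool is stopped. This assumes the same
# module-level names as run_by_pool() above (LINK_URL, PAGE_NUM, start,
# IMAGE_SAVE_BASEPATH) and a running reactor.
from twisted.internet import defer, reactor, threads

def run_by_pool_without_polling():
    urls = [LINK_URL % n for n in range(1, PAGE_NUM + 1)]
    pool = ThreadPool(minthreads=1, maxthreads=5)
    pool.start()
    ds = [
        threads.deferToThreadPool(reactor, pool, start, url,
                                  save_path=IMAGE_SAVE_BASEPATH)
        for url in urls
    ]

    def stop_pool(results):
        pool.stop()
        return results

    # consumeErrors=True keeps one failed download from being logged as unhandled.
    return defer.DeferredList(ds, consumeErrors=True).addBoth(stop_pool)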
class NodeEngagementMutex:
    """
    TODO: Does this belong on middleware?

    TODO: There are a couple of ways this can break.  If one of the jobs hangs,
    the whole thing will hang.  Also, if there are fewer successfully completed
    than percent_to_complete_before_release, the partial queue will never release.

    TODO: Make registry per... I guess Policy?  It's weird to be able to accidentally
    enact again.
    """
    log = Logger("Policy")

    def __init__(self,
                 callable_to_engage,  # TODO: typing.Protocol
                 nodes,
                 network_middleware,
                 percent_to_complete_before_release=5,
                 note=None,
                 threadpool_size=120,
                 timeout=20,
                 *args,
                 **kwargs):
        self.f = callable_to_engage
        self.nodes = nodes
        self.network_middleware = network_middleware
        self.args = args
        self.kwargs = kwargs

        self.completed = {}
        self.failed = {}

        self._started = False
        self._finished = False
        self.timeout = timeout

        self.percent_to_complete_before_release = percent_to_complete_before_release
        self._partial_queue = Queue()
        self._completion_queue = Queue()
        self._block_until_this_many_are_complete = math.ceil(
            len(nodes) * self.percent_to_complete_before_release / 100)
        self.nodes_contacted_during_partial_block = False
        self.when_complete = Deferred()  # TODO: Allow cancelling via KB Interrupt or some other way?

        if note is None:
            self._repr = f"{callable_to_engage} to {len(nodes)} nodes"
        else:
            self._repr = f"{note}: {callable_to_engage} to {len(nodes)} nodes"

        self._threadpool = ThreadPool(minthreads=threadpool_size,
                                      maxthreads=threadpool_size,
                                      name=self._repr)
        self.log.info(f"NEM spinning up {self._threadpool}")
        self._threadpool.callInThread(self._bail_on_timeout)

    def __repr__(self):
        return self._repr

    def _bail_on_timeout(self):
        while True:
            if self.when_complete.called:
                return
            duration = datetime.datetime.now() - self._started
            if duration.seconds >= self.timeout:
                try:
                    self._threadpool.stop()
                except AlreadyQuit:
                    raise RuntimeError(
                        "Is there a race condition here?  If this line is being hit, it's a bug.")
                raise RuntimeError(f"Timed out.  Nodes completed: {self.completed}")
            time.sleep(.5)

    def block_until_success_is_reasonably_likely(self):
        """
        https://www.youtube.com/watch?v=OkSLswPSq2o
        """
        if len(self.completed) < self._block_until_this_many_are_complete:
            try:
                completed_for_reasonable_likelihood_of_success = self._partial_queue.get(
                    timeout=self.timeout)  # TODO: Shorter timeout here?
            except Empty:
                raise RuntimeError(f"Timed out.  Nodes completed: {self.completed}")
            self.log.debug(
                f"{len(self.completed)} nodes were contacted while blocking for a little while.")
            return completed_for_reasonable_likelihood_of_success
        else:
            return self.completed

    def block_until_complete(self):
        if self.total_disposed() < len(self.nodes):
            try:
                # Interesting opportunity to pass some data,
                # like the list of contacted nodes above.
                _ = self._completion_queue.get(timeout=self.timeout)
            except Empty:
                raise RuntimeError(f"Timed out.  Nodes completed: {self.completed}")
        if not reactor.running and not self._threadpool.joined:
            # If the reactor isn't running, the user *must* call this,
            # because this is where we stop.
            self._threadpool.stop()

    def _handle_success(self, response, node):
        if response.status_code == 201:
            self.completed[node] = response
        else:
            # TODO: What happens if this is a 300 or 400 level response?
            # (A 500 response will propagate as an error
            # and be handled in the errback chain.)
            assert False

        if self.nodes_contacted_during_partial_block:
            self._consider_finalizing()
        else:
            if len(self.completed) >= self._block_until_this_many_are_complete:
                contacted = tuple(self.completed.keys())
                self.nodes_contacted_during_partial_block = contacted
                self.log.debug(f"Blocked for a little while, completed {contacted} nodes")
                self._partial_queue.put(contacted)
        return response

    def _handle_error(self, failure, node):
        self.failed[node] = failure  # TODO: Add a failfast mode?
        self._consider_finalizing()
        self.log.warn(f"{node} failed: {failure}")

    def total_disposed(self):
        return len(self.completed) + len(self.failed)

    def _consider_finalizing(self):
        if not self._finished:
            if self.total_disposed() == len(self.nodes):
                # TODO: Consider whether this can possibly hang.
                self._finished = True
                if reactor.running:
                    reactor.callInThread(self._threadpool.stop)
                self._completion_queue.put(self.completed)
                self.when_complete.callback(self.completed)
                self.log.info(f"{self} finished.")
        else:
            raise RuntimeError("Already finished.")

    def _engage_node(self, node):
        maybe_coro = self.f(node, network_middleware=self.network_middleware,
                            *self.args, **self.kwargs)
        d = ensureDeferred(maybe_coro)
        d.addCallback(self._handle_success, node)
        d.addErrback(self._handle_error, node)
        return d

    def start(self):
        if self._started:
            raise RuntimeError("Already started.")
        self._started = datetime.datetime.now()
        self.log.info(f"NEM Starting {self._threadpool}")
        for node in self.nodes:
            self._threadpool.callInThread(self._engage_node, node)
        self._threadpool.start()
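# --- Usage sketch (added, hypothetical) ---
# NodeEngagementMutex fires `callable_to_engage` at every node at once and lets
# the caller unblock early once a small percentage has responded. `ping`,
# `known_nodes` and `middleware` below are hypothetical stand-ins; the callable
# must (eventually) return a response-like object whose status_code is 201,
# since _handle_success() asserts on anything else.
nem = NodeEngagementMutex(
    callable_to_engage=ping,        # hypothetical: returns a 201 response per node
    nodes=known_nodes,              # hypothetical sequence of node objects
    network_middleware=middleware,  # hypothetical middleware instance
    percent_to_complete_before_release=5,
    timeout=20,
)
nem.start()
early_nodes = nem.block_until_success_is_reasonably_likely()  # ~5% contacted
nem.block_until_complete()  # all nodes disposed of (completed or failed)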
class pylabsTaskletRunner(TaskletRunner):
    def __init__(self, engine, threadpoolsize=10):
        self.engine = engine
        # Job queue
        self._queue = Queue.Queue()
        # Threadpool
        self._runners = list()
        self._threadpool = None

        reactor.addSystemEventTrigger('after', 'startup',
                                      self.start, threadpoolsize)
        reactor.addSystemEventTrigger('before', 'shutdown',
                                      self.shutdown)

    def start(self, threadpoolsize):
        # Set up threadpool
        self._threadpool = ThreadPool(minthreads=threadpoolsize,
                                      maxthreads=threadpoolsize + 1)
        q.logger.log('[PMTASKLETS] Constructing taskletserver threadpool', 6)
        self._threadpool.start()

        for i in xrange(threadpoolsize):
            runner = TaskletRunnerThread(self._queue)
            self._runners.append(runner)
            self._threadpool.callInThread(runner.run)

        self._running = True

    def queue(self, params, author=None, name=None, tags=None, priority=-1,
              logname=None):
        author = author or '*'
        name = name or '*'
        tags = tags or list()
        priority = priority if priority > -1 else -1

        q.logger.log('[PMTASKLETS] Queue: params=%s, author=%s, name=%s, '
                     'tags=%s, priority=%d' %
                     (params, author, name, tags, priority), 4)

        # Wrap the tasklet executor methods so the appname (for logging) is set
        # correctly
        def logwrapper(func):
            @functools.wraps(func)
            def _wrapped(*args, **kwargs):
                import pylabs
                oldappname = pylabs.q.application.appname
                if logname:
                    pylabs.q.application.appname = \
                        'applicationserver:pmtasklets:%s' % logname
                else:
                    pylabs.q.application.appname = \
                        'applicationserver:pmtasklets'

                try:
                    ret = func(*args, **kwargs)
                finally:
                    pylabs.q.application.appname = oldappname

                return ret

            return _wrapped

        execute_args = {
            'author': author,
            'name': name,
            'tags': tags,
            'priority': priority,
            'params': params,
            'wrapper': logwrapper,
        }

        # Append list of tasklet methods to run to the queue
        self._queue.put((self.engine, execute_args, ))

    def shutdown(self):
        q.logger.log('Shutting down tasklet runner', 5)
        self._running = False
        # Tell all threads to stop running
        for runner in self._runners:
            runner.keep_running = False
        self._threadpool.stop()

    @classmethod
    def install(cls):
        log.msg('Installing pylabs tasklet runner')
        import applicationserver
        applicationserver.TaskletRunner = cls
class WorkerPool:
    """
    A generalized class that can start multiple workers in a thread pool with values
    drawn from the given value factory object,
    and wait for their completion and a given number of successes
    (a worker returning something without throwing an exception).
    """

    class TimedOut(WorkerPoolException):
        """Raised if waiting for the target number of successes timed out."""

        def __init__(self, timeout: float, *args, **kwargs):
            self.timeout = timeout
            super().__init__(message_prefix=f"Execution timed out after {timeout}s",
                             *args, **kwargs)

    class OutOfValues(WorkerPoolException):
        """Raised if the value factory is out of values, but the target number was not reached."""

        def __init__(self, *args, **kwargs):
            super().__init__(
                message_prefix="Execution stopped before completion - not enough available values",
                *args, **kwargs)

    def __init__(self,
                 worker: Callable[[Any], Any],
                 value_factory: Callable[[int], Optional[List[Any]]],
                 target_successes,
                 timeout: float,
                 stagger_timeout: float = 0,
                 threadpool_size: int = None):

        # TODO: make stagger_timeout a part of the value factory?

        self._worker = worker
        self._value_factory = value_factory
        self._timeout = timeout
        self._stagger_timeout = stagger_timeout
        self._target_successes = target_successes

        thread_pool_kwargs = {}
        if threadpool_size is not None:
            thread_pool_kwargs['minthreads'] = threadpool_size
            thread_pool_kwargs['maxthreads'] = threadpool_size
        self._threadpool = ThreadPool(**thread_pool_kwargs)

        # These three tasks must be run in separate threads
        # to avoid being blocked by workers in the thread pool.
        self._bail_on_timeout_thread = Thread(target=self._bail_on_timeout)
        self._produce_values_thread = Thread(target=self._produce_values)
        self._process_results_thread = Thread(target=self._process_results)

        self._successes = {}
        self._failures = {}
        self._started_tasks = 0
        self._finished_tasks = 0

        self._cancel_event = Event()
        self._result_queue = Queue()
        self._target_value = Future()
        self._producer_error = Future()
        self._results_lock = Lock()
        self._threadpool_stop_lock = Lock()
        self._threadpool_stopped = False

    def start(self):
        # TODO: check if already started?
        self._threadpool.start()
        self._produce_values_thread.start()
        self._process_results_thread.start()
        self._bail_on_timeout_thread.start()

    def cancel(self):
        """
        Cancels the tasks enqueued in the thread pool and stops the producer thread.
        """
        self._cancel_event.set()

    def _stop_threadpool(self):
        # This can be called from multiple threads
        # (`join()` itself can be called from multiple threads,
        # and we also attempt to stop the pool from the `_process_results()` thread).
        with self._threadpool_stop_lock:
            if not self._threadpool_stopped:
                self._threadpool.stop()
                self._threadpool_stopped = True

    def _check_for_producer_error(self):
        # Check for any unexpected exceptions in the producer thread
        if self._producer_error.is_set():
            # Will raise if the Future was set with an exception
            self._producer_error.get()

    def join(self):
        """
        Waits for all the threads to finish.
        Can be called several times.
        """
        self._produce_values_thread.join()
        self._process_results_thread.join()
        self._bail_on_timeout_thread.join()

        # In most cases `_threadpool` will be stopped by the `_process_results()` thread.
        # But in case there's some unexpected bug in its code, we're making sure
        # the pool is stopped to avoid the whole process hanging.
        self._stop_threadpool()

        self._check_for_producer_error()

    def _sleep(self, timeout):
        """
        Sleeps for a given timeout, can be interrupted by a cancellation event.
        """
        if self._cancel_event.wait(timeout):
            raise Cancelled

    def block_until_target_successes(self) -> Dict:
        """
        Blocks until the target number of successes is reached.
        Returns a dictionary of values matched to results.
        Can be called several times.
        """
        self._check_for_producer_error()

        result = self._target_value.get()
        if result == TIMEOUT_TRIGGERED:
            raise self.TimedOut(timeout=self._timeout, failures=self.get_failures())
        elif result == PRODUCER_STOPPED:
            raise self.OutOfValues(failures=self.get_failures())
        return result

    def get_failures(self) -> Dict:
        """
        Get the current failures, as a dictionary of values to thrown exceptions.
        """
        with self._results_lock:
            return dict(self._failures)

    def get_successes(self) -> Dict:
        """
        Get the current successes, as a dictionary of values to worker return values.
        """
        with self._results_lock:
            return dict(self._successes)

    def _bail_on_timeout(self):
        """
        A service thread that cancels the pool on timeout.
        """
        if not self._cancel_event.wait(timeout=self._timeout):
            self._target_value.set(TIMEOUT_TRIGGERED)
            self._cancel_event.set()

    def _worker_wrapper(self, value):
        """
        A wrapper that catches exceptions thrown by the worker
        and sends the results to the processing thread.
        """
        try:
            # If we're in the cancelled state, interrupt early
            self._sleep(0)

            result = self._worker(value)
            self._result_queue.put(Success(value, result))
        except Cancelled as e:
            self._result_queue.put(e)
        except BaseException:
            self._result_queue.put(Failure(value, sys.exc_info()))

    def _process_results(self):
        """
        A service thread that processes worker results
        and waits for the target number of successes to be reached.
        """
        producer_stopped = False
        success_event_reached = False
        while True:
            result = self._result_queue.get()

            if result == PRODUCER_STOPPED:
                producer_stopped = True
            else:
                self._finished_tasks += 1

                if isinstance(result, Success):
                    with self._results_lock:
                        self._successes[result.value] = result.result
                        len_successes = len(self._successes)
                    if not success_event_reached and len_successes == self._target_successes:
                        # A protection for the case of repeating values.
                        # Only trigger the target value once.
                        success_event_reached = True
                        self._target_value.set(self.get_successes())

                if isinstance(result, Failure):
                    with self._results_lock:
                        self._failures[result.value] = result.exc_info

            if success_event_reached:
                # No need to continue processing results
                self.cancel()  # to cancel the timeout thread
                break

            if producer_stopped and self._finished_tasks == self._started_tasks:
                self.cancel()  # to cancel the timeout thread
                self._target_value.set(PRODUCER_STOPPED)
                break

        self._stop_threadpool()

    def _produce_values(self):
        while True:
            try:
                with self._results_lock:
                    len_successes = len(self._successes)
                batch = self._value_factory(len_successes)
                if not batch:
                    break

                self._started_tasks += len(batch)
                for value in batch:
                    # There is a possible race between `callInThread()` and `stop()`,
                    # but we never execute them at the same time,
                    # because `join()` checks that the producer thread is stopped.
                    self._threadpool.callInThread(self._worker_wrapper, value)

                self._sleep(self._stagger_timeout)

            except Cancelled:
                break

            except BaseException:
                self._producer_error.set_exception()
                self.cancel()
                break

        self._result_queue.put(PRODUCER_STOPPED)
class WorkerPool:
    """
    A generalized class that can start multiple workers in a thread pool with values
    drawn from the given value factory object,
    and wait for their completion and a given number of successes
    (a worker returning something without throwing an exception).
    """

    class TimedOut(Exception):
        """Raised if waiting for the target number of successes timed out."""

    class OutOfValues(Exception):
        """Raised if the value factory is out of values, but the target number was not reached."""

    def __init__(self,
                 worker: Callable[[Any], Any],
                 value_factory: Callable[[int], Optional[List[Any]]],
                 target_successes,
                 timeout: float,
                 stagger_timeout: float = 0,
                 threadpool_size: int = None):

        # TODO: make stagger_timeout a part of the value factory?

        self._worker = worker
        self._value_factory = value_factory
        self._timeout = timeout
        self._stagger_timeout = stagger_timeout
        self._target_successes = target_successes

        thread_pool_kwargs = {}
        if threadpool_size is not None:
            thread_pool_kwargs['minthreads'] = threadpool_size
            thread_pool_kwargs['maxthreads'] = threadpool_size
        self._threadpool = ThreadPool(**thread_pool_kwargs)

        # These three tasks must be run in separate threads
        # to avoid being blocked by workers in the thread pool.
        self._bail_on_timeout_thread = Thread(target=self._bail_on_timeout)
        self._produce_values_thread = Thread(target=self._produce_values)
        self._process_results_thread = Thread(target=self._process_results)

        self._successes = {}
        self._failures = {}
        self._started_tasks = 0
        self._finished_tasks = 0

        self._cancel_event = Event()
        self._result_queue = Queue()
        self._target_value = SetOnce()
        self._unexpected_error = SetOnce()
        self._results_lock = Lock()
        self._stopped = False

    def start(self):
        # TODO: check if already started?
        self._threadpool.start()
        self._produce_values_thread.start()
        self._process_results_thread.start()
        self._bail_on_timeout_thread.start()

    def cancel(self):
        """
        Cancels the tasks enqueued in the thread pool and stops the producer thread.
        """
        self._cancel_event.set()

    def join(self):
        """
        Waits for all the threads to finish.
        Can be called several times.
        """
        if self._stopped:
            return  # or raise AlreadyStopped?

        self._produce_values_thread.join()
        self._process_results_thread.join()
        self._bail_on_timeout_thread.join()

        # Protect from a possible race
        try:
            self._threadpool.stop()
        except AlreadyQuit:
            pass
        self._stopped = True

        if self._unexpected_error.is_set():
            e = self._unexpected_error.get()
            raise RuntimeError(f"Unexpected error in the producer thread: {e}")

    def _sleep(self, timeout):
        """
        Sleeps for a given timeout, can be interrupted by a cancellation event.
        """
        if self._cancel_event.wait(timeout):
            raise Cancelled

    def block_until_target_successes(self) -> Dict:
        """
        Blocks until the target number of successes is reached.
        Returns a dictionary of values matched to results.
        Can be called several times.
        """
        if self._unexpected_error.is_set():
            # So that we don't raise it again when join() is called
            e = self._unexpected_error.get_and_clear()
            raise RuntimeError(f"Unexpected error in the producer thread: {e}")

        result = self._target_value.get()
        if result == TIMEOUT_TRIGGERED:
            raise self.TimedOut()
        elif result == PRODUCER_STOPPED:
            raise self.OutOfValues()
        return result

    def get_failures(self) -> Dict:
        """
        Get the current failures, as a dictionary of values to thrown exceptions.
        """
        with self._results_lock:
            return dict(self._failures)

    def get_successes(self) -> Dict:
        """
        Get the current successes, as a dictionary of values to worker return values.
        """
        with self._results_lock:
            return dict(self._successes)

    def _bail_on_timeout(self):
        """
        A service thread that cancels the pool on timeout.
        """
        if not self._cancel_event.wait(timeout=self._timeout):
            self._target_value.set(TIMEOUT_TRIGGERED)
            self._cancel_event.set()

    def _worker_wrapper(self, value):
        """
        A wrapper that catches exceptions thrown by the worker
        and sends the results to the processing thread.
        """
        try:
            # If we're in the cancelled state, interrupt early
            self._sleep(0)

            result = self._worker(value)
            self._result_queue.put(Success(value, result))
        except Cancelled as e:
            self._result_queue.put(e)
        except BaseException as e:
            self._result_queue.put(Failure(value, str(e)))

    def _process_results(self):
        """
        A service thread that processes worker results
        and waits for the target number of successes to be reached.
        """
        producer_stopped = False
        success_event_reached = False
        while True:
            result = self._result_queue.get()

            if result == PRODUCER_STOPPED:
                producer_stopped = True
            else:
                self._finished_tasks += 1

                if isinstance(result, Success):
                    with self._results_lock:
                        self._successes[result.value] = result.result
                        len_successes = len(self._successes)
                    if not success_event_reached and len_successes == self._target_successes:
                        # A protection for the case of repeating values.
                        # Only trigger the target value once.
                        success_event_reached = True
                        self._target_value.set(self.get_successes())

                if isinstance(result, Failure):
                    with self._results_lock:
                        self._failures[result.value] = result.exception

            if producer_stopped and self._finished_tasks == self._started_tasks:
                self.cancel()  # to cancel the timeout thread
                self._target_value.set(PRODUCER_STOPPED)
                break

    def _produce_values(self):
        while True:
            try:
                with self._results_lock:
                    len_successes = len(self._successes)
                batch = self._value_factory(len_successes)
                if not batch:
                    break

                self._started_tasks += len(batch)
                for value in batch:
                    # There is a possible race between `callInThread()` and `stop()`,
                    # but we never execute them at the same time,
                    # because `join()` checks that the producer thread is stopped.
                    self._threadpool.callInThread(self._worker_wrapper, value)

                self._sleep(self._stagger_timeout)

            except Cancelled:
                break

            except BaseException as e:
                self._unexpected_error.set(e)
                self.cancel()
                break

        self._result_queue.put(PRODUCER_STOPPED)
class SwiftStorageProviderBackend(StorageProvider):
    """
    Args:
        hs (HomeServer)
        config: The config returned by `parse_config`
    """

    def __init__(self, hs, config):
        self.cache_directory = hs.config.media_store_path
        self.container = config["container"]
        self.cloud = config["cloud"]
        self.api_kwargs = {}

        if "region_name" in config:
            self.api_kwargs["region_name"] = config["region_name"]

        threadpool_size = config.get("threadpool_size", 40)
        self._download_pool = ThreadPool(
            name="swift-download-pool", maxthreads=threadpool_size
        )
        self._download_pool.start()

    def store_file(self, path, file_info):
        """See StorageProvider.store_file"""

        def _store_file():
            connection = openstack.connection.from_config(**self.api_kwargs)
            connection.object_store.create_object(
                self.container,
                path,
                filename=os.path.join(self.cache_directory, path),
            )

        # XXX: reactor.callInThread doesn't return anything, so I don't think this does
        # what the author intended.
        return make_deferred_yieldable(reactor.callInThread(_store_file))

    def fetch(self, path, file_info):
        """See StorageProvider.fetch"""
        logcontext = current_context()

        d = defer.Deferred()
        self._download_pool.callInThread(
            swift_download_task, self.container, self.api_kwargs, path, d, logcontext
        )
        return make_deferred_yieldable(d)

    @staticmethod
    def parse_config(config):
        """Called on startup to parse config supplied. This should parse
        the config and raise if there is a problem.

        The returned value is passed into the constructor.

        In this case we return a dict with fields `container` and `cloud`.
        """
        container = config["container"]
        cloud = config["cloud"]

        assert isinstance(container, string_types)

        result = {"container": container, "cloud": cloud}
        if "region_name" in config:
            result["region_name"] = config["region_name"]

        return result
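# --- Fix sketch (added) ---
# As the XXX comment above notes, reactor.callInThread() returns None, so the
# value handed to make_deferred_yieldable() resolves immediately instead of
# waiting for the upload. A hedged correction is to route the call through
# twisted.internet.threads.deferToThread(), which does return a Deferred that
# fires when _store_file completes (the S3 provider further below uses the
# deferToThreadPool() variant of the same idea). Written as a drop-in
# replacement for the method above:
from twisted.internet import threads

def store_file_fixed(self, path, file_info):
    """Like store_file() above, but actually waits for the upload to finish."""

    def _store_file():
        connection = openstack.connection.from_config(**self.api_kwargs)
        connection.object_store.create_object(
            self.container,
            path,
            filename=os.path.join(self.cache_directory, path),
        )

    return make_deferred_yieldable(threads.deferToThread(_store_file))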
class S3StorageProviderBackend(StorageProvider):
    """
    Args:
        hs (HomeServer)
        config: The config returned by `parse_config`
    """

    def __init__(self, hs, config):
        self.cache_directory = hs.config.media_store_path
        self.bucket = config["bucket"]
        self.storage_class = config["storage_class"]
        self.api_kwargs = {}

        if "region_name" in config:
            self.api_kwargs["region_name"] = config["region_name"]
        if "endpoint_url" in config:
            self.api_kwargs["endpoint_url"] = config["endpoint_url"]
        if "access_key_id" in config:
            self.api_kwargs["aws_access_key_id"] = config["access_key_id"]
        if "secret_access_key" in config:
            self.api_kwargs["aws_secret_access_key"] = config["secret_access_key"]

        threadpool_size = config.get("threadpool_size", 40)
        self._download_pool = ThreadPool(
            name="s3-download-pool", maxthreads=threadpool_size
        )
        self._download_pool.start()

    def store_file(self, path, file_info):
        """See StorageProvider.store_file"""

        def _store_file():
            session = boto3.session.Session()
            session.resource("s3", **self.api_kwargs).Bucket(self.bucket).upload_file(
                Filename=os.path.join(self.cache_directory, path),
                Key=path,
                ExtraArgs={"StorageClass": self.storage_class},
            )

        # XXX: reactor.callInThread doesn't return anything, so I don't think this does
        # what the author intended.
        return make_deferred_yieldable(reactor.callInThread(_store_file))

    def fetch(self, path, file_info):
        """See StorageProvider.fetch"""
        logcontext = current_context()

        d = defer.Deferred()
        self._download_pool.callInThread(
            s3_download_task, self.bucket, self.api_kwargs, path, d, logcontext
        )
        return make_deferred_yieldable(d)

    @staticmethod
    def parse_config(config):
        """Called on startup to parse config supplied. This should parse
        the config and raise if there is a problem.

        The returned value is passed into the constructor.

        In this case we return a dict with fields `bucket` and `storage_class`.
        """
        bucket = config["bucket"]
        storage_class = config.get("storage_class", "STANDARD")

        assert isinstance(bucket, string_types)
        assert storage_class in _VALID_STORAGE_CLASSES

        result = {
            "bucket": bucket,
            "storage_class": storage_class,
        }
        if "region_name" in config:
            result["region_name"] = config["region_name"]
        if "endpoint_url" in config:
            result["endpoint_url"] = config["endpoint_url"]
        if "access_key_id" in config:
            result["access_key_id"] = config["access_key_id"]
        if "secret_access_key" in config:
            result["secret_access_key"] = config["secret_access_key"]

        return result
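# --- parse_config contract sketch (added) ---
# parse_config() runs once at startup and its return value becomes the `config`
# argument of __init__(). The values below are illustrative only.
example_config = {
    "bucket": "synapse-media",    # required
    "storage_class": "STANDARD",  # optional; must be in _VALID_STORAGE_CLASSES
    "region_name": "us-east-1",   # optional; forwarded to boto3
}
parsed = S3StorageProviderBackend.parse_config(example_config)
# parsed == {"bucket": "synapse-media", "storage_class": "STANDARD",
#            "region_name": "us-east-1"}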
class S3StorageProviderBackend(StorageProvider):
    """
    Args:
        hs (HomeServer)
        config: The config returned by `parse_config`
    """

    def __init__(self, hs, config):
        self.cache_directory = hs.config.media.media_store_path
        self.bucket = config["bucket"]
        self.storage_class = config["storage_class"]
        self.api_kwargs = {}

        if "region_name" in config:
            self.api_kwargs["region_name"] = config["region_name"]
        if "endpoint_url" in config:
            self.api_kwargs["endpoint_url"] = config["endpoint_url"]
        if "access_key_id" in config:
            self.api_kwargs["aws_access_key_id"] = config["access_key_id"]
        if "secret_access_key" in config:
            self.api_kwargs["aws_secret_access_key"] = config["secret_access_key"]

        self._s3_client = None
        self._s3_client_lock = threading.Lock()

        threadpool_size = config.get("threadpool_size", 40)
        self._s3_pool = ThreadPool(name="s3-pool", maxthreads=threadpool_size)
        self._s3_pool.start()

        # Manually stop the thread pool on shutdown. If we don't do this then
        # stopping Synapse takes an extra ~30s as Python waits for the threads
        # to exit.
        reactor.addSystemEventTrigger(
            "during", "shutdown", self._s3_pool.stop,
        )

    def _get_s3_client(self):
        # This method is designed to be thread-safe, so that we can share a
        # single boto3 client across multiple threads.
        #
        # (XXX: is creating a client actually a blocking operation, or could we do
        # this on the main thread, to simplify all this?)

        # First of all, do a fast lock-free check
        s3 = self._s3_client
        if s3:
            return s3

        # No joy, grab the lock and repeat the check
        with self._s3_client_lock:
            s3 = self._s3_client
            if not s3:
                b3_session = boto3.session.Session()
                self._s3_client = s3 = b3_session.client("s3", **self.api_kwargs)
            return s3

    def store_file(self, path, file_info):
        """See StorageProvider.store_file"""
        parent_logcontext = current_context()

        def _store_file():
            with LoggingContext(parent_context=parent_logcontext):
                self._get_s3_client().upload_file(
                    Filename=os.path.join(self.cache_directory, path),
                    Bucket=self.bucket,
                    Key=path,
                    ExtraArgs={"StorageClass": self.storage_class},
                )

        return make_deferred_yieldable(
            threads.deferToThreadPool(reactor, self._s3_pool, _store_file)
        )

    def fetch(self, path, file_info):
        """See StorageProvider.fetch"""
        logcontext = current_context()

        d = defer.Deferred()

        def _get_file():
            s3_download_task(
                self._get_s3_client(), self.bucket, path, d, logcontext
            )

        self._s3_pool.callInThread(_get_file)
        return make_deferred_yieldable(d)

    @staticmethod
    def parse_config(config):
        """Called on startup to parse config supplied. This should parse
        the config and raise if there is a problem.

        The returned value is passed into the constructor.

        In this case we return a dict with fields `bucket` and `storage_class`.
        """
        bucket = config["bucket"]
        storage_class = config.get("storage_class", "STANDARD")

        assert isinstance(bucket, string_types)
        assert storage_class in _VALID_STORAGE_CLASSES

        result = {
            "bucket": bucket,
            "storage_class": storage_class,
        }
        if "region_name" in config:
            result["region_name"] = config["region_name"]
        if "endpoint_url" in config:
            result["endpoint_url"] = config["endpoint_url"]
        if "access_key_id" in config:
            result["access_key_id"] = config["access_key_id"]
        if "secret_access_key" in config:
            result["secret_access_key"] = config["secret_access_key"]

        return result
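# --- Caller sketch (added, hypothetical) ---
# Both store_file() and fetch() return Deferreds wrapped in
# make_deferred_yieldable(), so Synapse-style async code can await them
# directly. `provider`, `path` and `file_info` here are illustrative names,
# not part of the class above.
async def upload_then_fetch(provider, path, file_info=None):
    await provider.store_file(path, file_info)         # upload runs on the s3 pool
    responder = await provider.fetch(path, file_info)  # fires when s3_download_task resolves d
    return responder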