class Showcase(SessionSingleton):
    '''
    This class enables other classes to translate names previously
    registered to actual objects.

    Both registries hold their values weakly: an entry disappears as soon
    as nothing else references the registered object.
    '''

    def __init__(self):
        # oid -> exposed object (weak values)
        self._objects = WeakValueDictionary()
        # tag -> case (weak values)
        self._cases = WeakValueDictionary()

    def get(self, oid):
        '''
        :param str oid: the oid registered in the NameAuthority
        :returns: the registered object, or None if the oid is unknown or
            its object has already been garbage collected
        '''
        return self._objects.get(oid)

    def put(self, instance):
        '''
        Register *instance* under its ``oid`` attribute.

        :param INamed instance: the exposed object.
        '''
        self._objects[instance.oid] = instance

    def get_case(self, tag, default=None):
        '''
        :param str tag: The tag that the returned case should have
        :param default: value returned when no live case has that tag
        '''
        return self._cases.get(tag, default)
def __init__(self, diskCacheFolder, name, logger):
    """Create empty memory caches and load any cache previously saved to disk.

    :param diskCacheFolder: directory containing the cache file
    :param name: cache name; the file read is
        ``<diskCacheFolder>/<name>.json.bz2``
    :param logger: object whose ``error(msg, childName=...)`` method is used
        to report a cache file that cannot be read or parsed
    """
    self._diskCacheFile = os.path.join(diskCacheFolder, '{}.json.bz2'.format(name))
    self._logger = logger
    # Initialize memory data cache
    self.__typeDataCache = {}
    self.__attributeDataCache = {}
    self.__effectDataCache = {}
    self.__modifierDataCache = {}
    self.__fingerprint = None
    # Initialize weakref object cache
    self.__typeObjCache = WeakValueDictionary()
    self.__attributeObjCache = WeakValueDictionary()
    self.__effectObjCache = WeakValueDictionary()
    self.__modifierObjCache = WeakValueDictionary()
    # If cache doesn't exist, silently finish initialization
    if not os.path.exists(self._diskCacheFile):
        return
    # Read JSON into local variable
    try:
        with bz2.BZ2File(self._diskCacheFile, 'r') as cacheFile:
            jsonData = cacheFile.read().decode('utf-8')
        data = json.loads(jsonData)
    # If the file can't be read, JSON load errors occur, or anything else
    # goes wrong, do not load anything and leave values as initialized.
    # Catch Exception rather than everything so that KeyboardInterrupt and
    # SystemExit still propagate.
    except Exception:
        msg = 'error during reading cache'
        self._logger.error(msg, childName='cacheHandler')
    # Load data into data cache, if no errors occurred
    # during JSON reading/parsing
    else:
        self.__updateMemCache(data)
class Signal(object):
    """Callable that forwards its arguments to every connected slot.

    Bound methods are stored as (function, id(instance)) -> instance with the
    instance held weakly, so connecting a method does not keep its object
    alive. Plain callables are stored as (callable, None) -> callable; the
    key references the callable, so those stay alive until disconnected.
    """

    def __init__(self):
        self.__slots = WeakValueDictionary()

    def __call__(self, *args, **kargs):
        # Snapshot the live entries first: the strong references keep every
        # target alive for the duration of the dispatch, and slots may
        # connect/disconnect while we are calling them.
        for (func, selfid), target in list(self.__slots.items()):
            if selfid is not None:
                func(target, *args, **kargs)
            else:
                func(*args, **kargs)

    @staticmethod
    def __split_method(slot):
        """Return (function, instance) for a bound method, else None.

        Understands Python 2 (im_func/im_self) and Python 3
        (__func__/__self__) bound methods.
        """
        func = getattr(slot, '__func__', getattr(slot, 'im_func', None))
        owner = getattr(slot, '__self__', getattr(slot, 'im_self', None))
        if func is not None and owner is not None:
            return func, owner
        return None

    def __get_key(self, slot):
        method = self.__split_method(slot)
        if method is not None:
            return (method[0], id(method[1]))
        return (slot, None)

    def connect(self, slot):
        """Register *slot* (a bound method or any other callable)."""
        method = self.__split_method(slot)
        if method is None:
            self.__slots[(slot, None)] = slot
        else:
            func, owner = method
            self.__slots[(func, id(owner))] = owner

    def disconnect(self, slot):
        """Unregister *slot*; unknown slots are ignored."""
        self.__slots.pop(self.__get_key(slot), None)

    def clear(self):
        self.__slots.clear()
def __init__(self, nodes, ways, min_lat, max_lat, min_lon, max_lon, *args,
             node_color=(0, 0, 0), way_color="allrandom", bg_color="white",
             enlargement=50000):
    """Export map data (nodes and ways) as a map-like image.

    Params:
    nodes       - the raw nodes, as read by any OSM file reader
    ways        - the raw ways, as read by any OSM file reader
    min_lat     - southern border of the map
    max_lat     - northern border of the map
    min_lon     - western border of the map
    max_lon     - eastern border of the map
    node_color  - colour of the nodes in the image
    way_color   - colour of the ways in the image
    bg_color    - colour of the image background
    enlargement - multiplication factor from map coordinate to pixel
                  coordinate; determines the image size
    Extra positional arguments (*args) are accepted and ignored.
    """
    super(MapImageExporter, self).__init__(
        min_lat, max_lat, min_lon, max_lon, bg_color, enlargement)
    logger_name = '.'.join((__name__, type(self).__name__))
    self.logger = logging.getLogger(logger_name)
    # Only weak references are kept: the caller must keep the map objects alive.
    self.nodes, self.ways = WeakValueDictionary(nodes), WeakValueDictionary(ways)
    self.node_color, self.way_color = node_color, way_color
class Signal(object):
    """
    A Signal is callable. When called, it calls all the callables in its
    slots, passing each slot's instance as the first argument.

    Only bound methods are supported. Their instances are referenced weakly,
    so a slot disappears once its object is garbage collected.
    """

    def __init__(self):
        self._slots = WeakValueDictionary()

    @staticmethod
    def _key_and_owner(slot):
        """Return ((function, id(owner)), owner) for bound method *slot*.

        Accepts Python 3 (__func__/__self__) and Python 2 (im_func/im_self)
        bound methods; anything else raises AttributeError.
        """
        try:
            func, owner = slot.__func__, slot.__self__
        except AttributeError:
            func, owner = slot.im_func, slot.im_self
        return (func, id(owner)), owner

    def __call__(self, *args, **kargs):
        # Snapshot the live entries: keeps targets alive during dispatch and
        # tolerates slots that connect/disconnect while being called.
        for (func, _), owner in list(self._slots.items()):
            func(owner, *args, **kargs)

    def connect(self, slot):
        """
        Slots must call this to register a callback method.

        :param slot: bound method
        """
        key, owner = self._key_and_owner(slot)
        self._slots[key] = owner

    def disconnect(self, slot):
        """
        They can also unregister their callbacks here. Unknown slots are
        ignored.

        :param slot: bound method
        """
        key, _ = self._key_and_owner(slot)
        self._slots.pop(key, None)

    def clear(self):
        """
        Clears all slots
        """
        self._slots.clear()
def __init__(self, cache_path):
    """Create empty memory caches and load any cache previously saved to disk.

    :param cache_path: path of the bz2-compressed JSON cache file; it is
        stored as an absolute path
    """
    self._cache_path = os.path.abspath(cache_path)
    # Initialize memory data cache
    self.__type_data_cache = {}
    self.__attribute_data_cache = {}
    self.__effect_data_cache = {}
    self.__modifier_data_cache = {}
    self.__fingerprint = None
    # Initialize weakref object cache
    self.__type_obj_cache = WeakValueDictionary()
    self.__attribute_obj_cache = WeakValueDictionary()
    self.__effect_obj_cache = WeakValueDictionary()
    self.__modifier_obj_cache = WeakValueDictionary()
    # If cache doesn't exist, silently finish initialization
    if not os.path.exists(self._cache_path):
        return
    # Read JSON into local variable
    try:
        with bz2.BZ2File(self._cache_path, 'r') as cache_file:
            json_data = cache_file.read().decode('utf-8')
        data = json.loads(json_data)
    # If the file can't be read, JSON load errors occur, or anything else
    # goes wrong, do not load anything and leave values as initialized.
    # Catching Exception (not a bare except) lets KeyboardInterrupt and
    # SystemExit propagate without a dedicated re-raise clause.
    except Exception:
        msg = 'error during reading cache'
        logger.error(msg)
    # Load data into data cache, if no errors occurred
    # during JSON reading/parsing
    else:
        self.__update_mem_cache(data)
class LRUCache:
    """Fixed-size least-recently-used cache.

    ``LRU`` holds exactly ``max_size`` nodes (strong references).
    ``search`` maps names to those same nodes weakly, for O(1) lookup.
    The slots start out filled with placeholder nodes named ``"none0"``,
    ``"none1"``, and so on.
    """

    def __init__(self, max_size):
        self.LRU = [Node(time(), "none%s" % i) for i in range(max_size)]
        self.search = WeakValueDictionary()
        for node in self.LRU:
            self.search[node.name] = node

    def __setitem__(self, name, value):
        """Store *value* under *name*, evicting the least recently used entry
        when *name* is new."""
        node = self.search.get(name)
        if node is not None:
            node.data = value
            node.time = time()
        else:
            # Updates and get() refresh timestamps without re-sorting, so
            # LRU[0] is not necessarily the oldest node; min() is.
            lru = min(self.LRU)
            self.search.pop(lru.name)
            lru.data = value
            lru.time = time()
            lru.name = name
            self.search[lru.name] = lru
            self.LRU.sort()

    def get(self, name, default=None):
        """Return the data stored under *name* and mark it as recently used.

        If *name* is absent, return *default* when it is not None; otherwise
        re-raise the KeyError.
        """
        try:
            node = self.search[name]
        except KeyError:
            if default is not None:
                return default
            raise
        node.time = time()
        return node.data
def __init__(self, name, parent=None, nolabel=False, **kwargs):
    """Create a Maya menu (top level) or a menu item (under a parent Menu).

    :param name: key under which the menu is stored in its parent. Unless
        *nolabel* is True, it also becomes the menu's label.
    :type name: str
    :param parent: Optional parent menu. None creates a top-level menu under
        'MayaWindow'; a Menu instance creates a menu item under that menu.
        Default is None.
    :type parent: Menu|None
    :param nolabel: if True, the ``label`` flag passed to the Maya command
        is not overwritten by *name*
    :type nolabel: bool
    :param kwargs: all keyword arguments used for the cmds.menu /
        cmds.menuItem command
    :type kwargs: named arguments
    :returns: None
    :rtype: None
    :raises: errors.MenuExistsError if *parent* already contains *name*
    """
    WeakValueDictionary.__init__(self)
    self.__menustring = None
    self.__parent = parent
    self.__name = name
    self.__kwargs = kwargs
    if not nolabel:
        self.__kwargs['label'] = name

    if parent is None:
        # Top-level menu attached to Maya's main window.
        cmds.setParent('MayaWindow')
        self.__menustring = cmds.menu(**self.__kwargs)
        return

    if name in parent:
        raise errors.MenuExistsError(
            "A menu with this name: %s and parent: %s exists already!"
            % (name, parent))
    cmds.setParent(parent.menustring(), menu=1)
    self.__kwargs['parent'] = parent.menustring()
    self.__menustring = cmds.menuItem(**self.__kwargs)
    # Register with the parent (presumably another Menu, i.e. a
    # WeakValueDictionary) under our name.
    parent[name] = self
def __init__(self, maxsize, cullsize=2, peakmult=10, aggressive_gc=True,
             *args, **kwargs):
    """Configure the size limits, then initialise the WeakValueDictionary.

    :param maxsize: size limit; never smaller than the effective cullsize
    :param cullsize: clamped to at least 2
    :param peakmult: stored as ``self.peakmult``
    :param aggressive_gc: stored as ``self.aggressive_gc``
    :param args, kwargs: forwarded to ``WeakValueDictionary.__init__``
        (e.g. initial contents)
    """
    self.cullsize = max(2, cullsize)
    # Compare against the clamped cullsize; using the raw argument let
    # maxsize end up below self.cullsize (e.g. cullsize=0, maxsize=1).
    self.maxsize = max(self.cullsize, maxsize)
    self.aggressive_gc = aggressive_gc
    self.peakmult = peakmult
    # Created before the base init, which may already insert items.
    self.queue = deque()
    WeakValueDictionary.__init__(self, *args, **kwargs)
def __init__(self):
    """Initialise both bases and install a weakref-removal callback that
    does not keep this dictionary alive."""
    SendObject.__init__(self)
    WeakValueDictionary.__init__(self)

    owner_ref = ref(self)

    def remove_wr(wr, selfref=owner_ref):
        # Only a weak reference to the dictionary is captured, so the
        # callback may outlive it; in that case there is nothing to do.
        owner = selfref()
        if owner is not None:
            del owner[wr.key]

    # Must come after WeakValueDictionary.__init__, which installs its own
    # _remove; entries added from now on use this callback.
    self._remove = remove_wr
class Signal(object):
    """Thread-safe signal whose slots are bound methods held weakly.

    Every slot is called with the signal's arguments plus a ``sender``
    keyword argument, and must therefore accept ``**kwargs``.
    """

    def __init__(self, sender, max_connections=0, exc_catch=True):
        """
        :param sender: value passed to every slot as ``sender=``
        :param max_connections: maximum number of slots; 0 means unlimited
        :param exc_catch: if True, an exception raised by a slot is reported
            and dispatch continues; otherwise it is re-raised as an
            ``Exception`` carrying the formatted traceback
        """
        self._maxconn = max_connections
        self._sender = sender
        self._exc_catch = exc_catch
        self._slots = WeakValueDictionary()
        self._lock = threading.Lock()

    @property
    def connected(self):
        """Number of live slots."""
        return len(self._slots)

    @staticmethod
    def _method_key(slot):
        """Return ((function, id(instance)), instance) for a bound method.

        Accepts Python 3 (__func__/__self__) and Python 2 (im_func/im_self)
        bound methods; anything else raises AttributeError.
        """
        try:
            func, owner = slot.__func__, slot.__self__
        except AttributeError:
            func, owner = slot.im_func, slot.im_self
        return (func, id(owner)), owner

    @staticmethod
    def _slot_argspec(slot):
        """Return the argspec of *slot* (or of its __call__), or None."""
        # getargspec was removed in Python 3.11; getfullargspec also keeps
        # the **kwargs name at index 2.
        getspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        try:
            return getspec(slot)
        except TypeError:
            try:
                return getspec(slot.__call__)
            except (TypeError, AttributeError):
                return None

    def connect(self, slot):
        assert callable(slot), "Signal slots must be callable."
        # Check for **kwargs
        argspec = self._slot_argspec(slot)
        if argspec:
            assert argspec[2] is not None, \
                "Signal receivers must accept keyword arguments (**kwargs)."
        key, owner = self._method_key(slot)
        with self._lock:
            # Checked under the lock so concurrent connects cannot overshoot.
            if self._maxconn > 0 and len(self._slots) >= self._maxconn:
                raise SignalError("Maximum number of connections was exceeded")
            self._slots[key] = owner

    def disconnect(self, slot):
        key, _ = self._method_key(slot)
        with self._lock:
            self._slots.pop(key, None)

    def _report(self, message):
        """Report a slot failure through ``self.exception`` if a subclass
        defines it, otherwise through the logging module."""
        report = getattr(self, 'exception', None)
        if report is None:
            import logging
            report = logging.getLogger(__name__).exception
        report(message)

    def __call__(self, *args, **kwargs):
        """Call every live slot; return a list of (function, response)."""
        assert "sender" not in kwargs, \
            "'sender' keyword argument is occupied"
        responses = []
        kwargs["sender"] = self._sender
        # Snapshot under the lock: concurrent connect/disconnect cannot break
        # the iteration, and the strong references keep targets alive.
        with self._lock:
            slots = list(self._slots.items())
        for (func, _), owner in slots:
            try:
                response = func(owner, *args, **kwargs)
                responses.append((func, response))
            except Exception as err:
                if self._exc_catch:
                    self._report("Slot {0} exception: {1}".format(str(func), err))
                else:
                    raise Exception(traceback.format_exc())
        return responses
class Monitor(QObject):
    """File monitor

    Tracks single files. Each path has at most one live
    :any:`SingleFileWatcher` proxy, shared by every caller of
    :any:`monitorFile` for that path.
    """

    def __init__(self, **kwargs):
        super(Monitor, self).__init__(**kwargs)
        # absolute path -> SingleFileWatcher proxy (weak values)
        self.watched = WeakValueDictionary()
        # Routes each proxy's ``destroyed`` signal to unmonitorFile(path).
        self.delMapper = QSignalMapper(self)
        self.delMapper.mapped[str].connect(self.unmonitorFile)
        self.watcher = MonitorWithRename(parent=self)
        self.watcher.fileChanged.connect(self._onFileChanged)

    def monitorFile(self, path):
        """Monitor a file and return an object that tracks only `path`

        :rtype: SingleFileWatcher
        :return: an object tracking `path`; repeated calls with the same
                 path return the same object while it is alive.
        """
        full_path = os.path.abspath(path)
        self.watcher.addPath(full_path)
        existing = self.watched.get(full_path)
        if existing:
            return existing
        proxy = SingleFileWatcher(full_path)
        proxy.destroyed.connect(self.delMapper.map)
        self.delMapper.setMapping(proxy, full_path)
        self.watched[full_path] = proxy
        return proxy

    @Slot(str)
    def unmonitorFile(self, path):
        """Stop monitoring a file

        There is only one :any:`SingleFileWatcher` object per path, so every
        object monitoring `path` stops receiving notifications. To silence a
        single object instead, disconnect its `modified` signal.

        Destroying the :any:`SingleFileWatcher` returned by
        :any:`monitorFile` un-monitors the file automatically.
        """
        full_path = os.path.abspath(path)
        self.watcher.removePath(full_path)
        self.watched.pop(full_path, None)

    @Slot(str)
    def _onFileChanged(self, path):
        proxy = self.watched.get(path)
        if not proxy:
            return
        proxy.modified.emit()
def __init__(self, n=None):
    """Initialise the dictionary and its most-recent-value queue.

    :param n: queue size limit. ``None`` or any value below 1 disables the
        queue and turns this instance into a plain WeakValueDictionary.
        Integers (bools included) are used as-is; any other value selects
        the default of 50.
    """
    WeakValueDictionary.__init__(self)
    # Python 2 ordered None below every int, so the default n=None disabled
    # the queue; Python 3 raises TypeError on `None < 1`, hence the
    # explicit None test.
    if n is None or n < 1:
        # user doesn't want any Most Recent value queue
        self.__class__ = WeakValueDictionary  # revert to regular WVD
        return
    if isinstance(n, int):
        self.n = n  # size limit
    else:
        self.n = 50
    self.i = 0  # counter
    self._keepDict = {}  # most recent queue
def __init__(self, n=None):
    """Initialise the dictionary and its most-recent-value queue.

    :param n: queue size limit. ``None`` or any value below 1 disables the
        queue and turns this instance into a plain WeakValueDictionary.
        ``True`` selects the default of 50; anything else is passed
        through ``int()``.
    """
    WeakValueDictionary.__init__(self)
    # Python 2 ordered None below every int, so the default n=None disabled
    # the queue; Python 3 raises TypeError on `None < 1`, hence the
    # explicit None test.
    if n is None or n < 1:
        # user doesn't want any Most Recent value queue
        self.__class__ = WeakValueDictionary  # revert to regular WVD
        return
    if n is True:  # assign default value
        self.n = 50
    else:
        self.n = int(n)  # size limit
    self._head = self._tail = None
    self._keepDict = {}  # most recent queue
def __init__(self, callback):
    """Initialise the dictionary so that ``callback(key)`` runs whenever an
    entry's value is garbage collected.

    :param callback: called with the key of each entry whose value died,
        after the dictionary's own removal logic has run
    """
    WeakValueDictionary.__init__(self)
    # WeakValueDictionary.__init__ stores the callback it attaches to every
    # KeyedRef as self._remove, so that is what we wrap and replace.
    base_remove = self._remove

    # Deliberately a plain function kept as an *instance* attribute, not a
    # method: as a weakref callback it may run after the dictionary itself
    # has been collected.
    def remove(wr, _callback=callback, _original=base_remove):
        _original(wr)
        _callback(wr.key)

    self._remove = remove
def __init__(self):
    # Run cleanUp when the interpreter exits.
    atexit.register(self.cleanUp)

    # Polling configuration
    self._polling_period = self.DefaultPollingPeriod
    self.polling_timers = {}
    self._polling_enabled = True

    # Weak-valued registries (presumably attributes, devices and
    # authorities, judging by the names — TODO confirm)
    self._attrs = WeakValueDictionary()
    self._devs = WeakValueDictionary()
    self._auths = WeakValueDictionary()

    # Imported here rather than at module level (presumably to avoid an
    # import cycle — TODO confirm).
    import taurusmanager
    self._serialization_mode = taurusmanager.TaurusManager().getSerializationMode()
class Subscriber(Link, MutableMapping):
    """Link that stores subscribed objects by name, holding them weakly."""

    def __init__(self, connect):
        super(Subscriber, self).__init__()
        # name -> subscribed object (weak values)
        self._dict = WeakValueDictionary()
        # wrapper -> the key it was registered under (weak keys)
        self._names = WeakKeyDictionary()
        self.connect = connect
        self.call()

    def subscribe(self, obj):
        """Subscribe to *obj* under the anonymous key ``'_'``."""
        self['_'] = obj

    def unsubscribe(self, obj):
        """Unsubscribe from *obj* and forget every name it was stored under."""
        wrapper = get_wrapper(obj)
        self._unsubscribe(wrapper)
        names = wrapper._unsubscribe(self)
        names.append(self._names[wrapper])
        for name in names:
            self._dict.pop(name, None)

    def __setitem__(self, key, obj):
        """Subscribe to *obj*. It is stored under *key* plus every name the
        wrapper reports, and *obj* is returned."""
        wrapper = get_wrapper(obj)
        self._names[wrapper] = key
        self._subscribe(wrapper)
        names = wrapper._subscribe(self)
        names.append(key)
        # Only the anonymous key '_' may be reused.
        assert not (key != '_' and key in self._dict), 'same name'
        for name in names:
            self._dict[name] = wrapper.obj
        return obj

    def __getitem__(self, key):
        return self._dict[key]

    def __delitem__(self, key):
        self.unsubscribe(self[key])

    def __hash__(self):
        # Presumably needed because Mapping defines __eq__, which disables
        # inherited hashing; delegate to Link's hash.
        return Link.__hash__(self)

    def kill(self):
        for linked in set(self.links):
            self.unsubscribe(linked)
        super(Subscriber, self).kill()

    def call(self):
        """Hook run at the end of __init__; does nothing by default."""

    def send(self, data):
        self.connect.send(data)

    def receive(self, data):
        # Delegates to the module-level receive() function.
        receive(self, data)
class MapImageExporter(MapExporter):
    def __init__(self, nodes, ways, min_lat, max_lat, min_lon, max_lon, *args,
                 node_color=(0, 0, 0), way_color="allrandom", bg_color="white",
                 enlargement=50000):
        """Export map data (nodes and ways) as a map like image.

        Params:
        nodes - The raw nodes as read by any OSM file reader
        ways - The raw ways as read by any OSM file reader
        min_lat - The southern border of the map
        max_lat - The northern border of the map
        min_lon - The western border of the map
        max_lon - The eastern border of the map
        node_color - The colour of the nodes in the image
        way_color - The colour of the ways in the image
        bg_color - The colour of the image background
        enlargement - Multiplication factor from map coordinate to pixel
                      coordinate. Determines image size.
        Extra positional arguments (*args) are accepted and ignored.
        """
        super(MapImageExporter, self).__init__(min_lat, max_lat, min_lon,
                                               max_lon, bg_color, enlargement)
        self.logger = logging.getLogger('.'.join((__name__, type(self).__name__)))
        # Only weak references are kept: the caller must keep the map
        # objects alive for the duration of the export.
        self.nodes = WeakValueDictionary(nodes)
        self.ways = WeakValueDictionary(ways)
        self.node_color = node_color
        self.way_color = way_color

    def _to_pixel(self, point):
        """Return the image (x, y) of an object with ``lon``/``lat`` attributes."""
        return ((point.lon - self.min_lon) * self.enlargement,
                (point.lat - self.min_lat) * self.enlargement)

    def export(self, filename="export.png"):
        """Export the information to an image file

        Params:
        filename - The filename to export to, must have a valid image
                   extension. Default: export.png
        """
        self.logger.info('Exporting a map image to %s', filename)
        # NOTE(review): y grows with latitude here while image rows usually
        # grow downwards; unless MapExporter/_save_image flips the image,
        # north ends up at the bottom — confirm.
        # Draw all ways
        self.logger.info('Drawing the ways')
        nodes = self.nodes
        for way in self.ways.values():
            coords = [self._to_pixel(nodes[node_id]) for node_id in way.nodes]
            self.draw.line(coords, fill=self.way_color)
        # draw all nodes as points
        self.logger.info('Drawing the nodes')
        for node in self.nodes.values():
            self.draw.point(self._to_pixel(node), fill=self.node_color)
        self._save_image(filename)
class ObjectPool(object):
    """
    This class allows to fetch mvc model objects using their UUID.
    This requires the model to have a property called "uuid". All classes
    inheriting from the base 'Model' class will have this.
    If implementing a custom model, the UUID property is responsible
    for the removal and addition to the pool when it changes values.
    Also see the UUIDPropIntel class for an example implementation.
    We can use this to store complex relations between objects where
    references to each other can be replaced with the UUID.
    For a multi-threaded version see ThreadedObjectPool.

    Objects are held weakly: the pool never keeps an object alive.
    """

    def __init__(self, *args, **kwargs):
        object.__init__(self)
        # uuid -> model object (weak values)
        self._objects = WeakValueDictionary()

    def add_or_get_object(self, obj):
        """Add *obj*, or return the object already registered under its UUID."""
        try:
            # fail_on_duplicate=True makes add_object raise KeyError for a
            # taken UUID, which selects the lookup branch below.
            self.add_object(obj, force=False, fail_on_duplicate=True)
            return obj
        except KeyError:
            return self.get_object(obj.uuid)

    def add_object(self, obj, force=False, fail_on_duplicate=False):
        """Register *obj* under ``obj.uuid``.

        :param force: replace any object already registered under that UUID
        :param fail_on_duplicate: raise KeyError on a taken UUID; when False,
            a warning is logged and *obj* is given a fresh UUID instead
        :raises KeyError: on a duplicate UUID with fail_on_duplicate=True
        """
        if obj.uuid not in self._objects or force:
            self._objects[obj.uuid] = obj
        elif fail_on_duplicate:
            raise KeyError(
                "UUID %s is already taken by another object %s, cannot add object %s"
                % (obj.uuid, self._objects[obj.uuid], obj))
        else:
            # Just change the objects uuid, will break refs, but
            # it prevents issues with inherited properties etc.
            logger.warning("A duplicate UUID was passed to an ObjectPool for a %s object." % obj)
            obj.uuid = get_new_uuid()

    def change_all_uuids(self):
        """Give every pooled object a fresh UUID."""
        # Snapshot first: assigning a uuid may add/remove pool entries.
        for obj in list(self._objects.values()):
            obj.uuid = get_new_uuid()

    def remove_object(self, obj):
        """Remove *obj* if it is the object registered under its UUID."""
        # The local strong reference keeps the entry alive between the
        # check and the delete.
        current = self._objects.get(obj.uuid)
        if current is not None and current == obj:
            del self._objects[obj.uuid]

    def get_object(self, uuid):
        """Return the object registered under *uuid*, or None."""
        return self._objects.get(uuid, None)

    def clear(self):
        self._objects.clear()
def __init__(self, allow_none_id=False):
    """
    :param bool allow_none_id: Flag specifying if calling :meth:`add` with
      an entity that does not have an ID is allowed.
    """
    # Stored so the documented flag takes effect; the assignment used to be
    # commented out, which silently discarded the argument.
    self.__allow_none_id = allow_none_id
    # List of cached entities. This is the only place we are holding a
    # real reference to the entity.
    self.__entities = []
    # Dictionary mapping entity IDs to entities for fast lookup by ID.
    self.__id_map = WeakValueDictionary()
    # Dictionary mapping entity slugs to entities for fast lookup by slug.
    self.__slug_map = WeakValueDictionary()
def __init__(self):
    # Registration lookup table and singleton storage (guarded by its lock).
    self._key_to_registration = {}
    self._singleton_instances = {}
    self._singleton_instances_lock = threading.Lock()
    # Weak-valued instance store (guarded by its own lock).
    self._weak_references = WeakValueDictionary()
    self._weak_references_lock = threading.Lock()
    # Per-thread state.
    self._thread_local = threading.local()
def __init__(self, config):
    self._config = config
    # (domain, user) tuple -> entry
    self._entries = {}
    # domain -> entry (weak values)
    self._entry_by_domain = WeakValueDictionary()
def __init__(self, store_uri, bucket_length=NOTSET, cache_size=3): self._bucket_store = BaseBucketStore.from_uri(store_uri=store_uri, default_scheme='file') # set empty fields self._bucket_length = None self._bucket_count = 0 self._len = 0 self._bucket_cache = None self._cache_size = None self.bucket_key_fmt = None # load current settings try: for attr, value in self._bucket_store.fetch_head().items(): setattr(self, attr, value) except BucketNotFound: pass # apply new settings self.bucket_length = bucket_length # LRU store for objects fetched from disk self.cache_size = cache_size # weakref store for objects still in use self._active_buckets = WeakValueDictionary() self._active_items = WeakValueDictionary() # calcualate metadata self._length = self._fetch_length() # store new settings self._store_head()
def __init__(self, data_file, index_file):
    """Open a key/value store backed by a data file and a text index file.

    :param data_file: file object holding the values; must support
        ``fileno()``. Its size is recorded as ``data_total``.
    :param index_file: text file object whose lines read
        ``"<offset> <base64 of pickled key>"``; malformed lines are skipped
    Both files are left positioned at their end.
    """
    # Dict storing currently loaded values
    self.cache = WeakValueDictionary()
    # Data file, and index mapping key to data offset
    self.data_file = data_file
    self.data_total = os.fstat(data_file.fileno()).st_size
    self.index_file = index_file
    self.index = {}
    self.used = set()
    # Read index data into dict, streaming lines instead of readlines()
    # so a large index is not held in memory twice.
    self.index_file.seek(0)
    for line in self.index_file:
        parts = line.split()
        if len(parts) != 2:
            continue
        offset_str, key_str = parts
        offset = int(offset_str)
        # SECURITY: pickle.loads can execute arbitrary code; only open index
        # files that come from a trusted source.
        key = pickle.loads(base64.decodebytes(key_str.encode('ascii')))
        self.index[key] = offset
    # Seek both to EOF
    self.data_file.seek(0, os.SEEK_END)
    self.index_file.seek(0, os.SEEK_END)
def __init__(self):
    # --- core data ---
    self.root = FolderNode(0, "root", None, "root", 1, 1, 0, FifoStrategy())
    # id -> node (weak values); the root stays alive through self.root.
    self.nodes = WeakValueDictionary()
    self.nodes[0] = self.root
    self.pools = {}
    self.renderNodes = {}
    self.tasks = {}
    self.rules = []
    self.poolShares = {}
    self.commands = {}

    # --- deduced properties ---
    self.nodeMaxId = self.poolMaxId = self.renderNodeMaxId = 0
    self.taskMaxId = self.commandMaxId = self.poolShareMaxId = 0
    # Element lists for the create / modify / archive operations.
    self.toCreateElements = []
    self.toModifyElements = []
    self.toArchiveElements = []

    # --- listeners: (creation, destruction, change) callbacks per type ---
    self.nodeListener = ObjectListener(
        self.onNodeCreation, self.onNodeDestruction, self.onNodeChange)
    self.taskListener = ObjectListener(
        self.onTaskCreation, self.onTaskDestruction, self.onTaskChange)
    self.renderNodeListener = ObjectListener(
        self.onRenderNodeCreation,
        self.onRenderNodeDestruction,
        self.onRenderNodeChange,
    )
    self.poolListener = ObjectListener(
        self.onPoolCreation, self.onPoolDestruction, self.onPoolChange)
    self.commandListener = ObjectListener(
        onCreationEvent=self.onCommandCreation,
        onChangeEvent=self.onCommandChange,
    )
    self.poolShareListener = ObjectListener(self.onPoolShareCreation)
    self.modifiedNodes = []
class PQ(object):
    """Convenient queue manager.

    Queues are created lazily by name and cached weakly, so a queue lives
    only as long as something else references it.
    """

    table = 'queue'
    template_path = os.path.dirname(__file__)

    def __init__(self, *args, **kwargs):
        self.queue_class = kwargs.pop('queue_class', Queue)
        # Remaining constructor arguments are forwarded to every queue.
        self.params = args, kwargs
        self.queues = WeakValueDictionary()

    def __getitem__(self, name):
        """Return the queue called *name*, creating it if needed."""
        existing = self.queues.get(name)
        if existing is not None:
            return existing
        args, kwargs = self.params
        return self.queues.setdefault(
            name, self.queue_class(name, *args, **kwargs))

    def close(self):
        """Call close() on the queue named ''."""
        self[''].close()

    def create(self):
        """Execute the bundled ``create.sql`` template in a transaction."""
        queue = self['']
        sql_path = os.path.join(self.template_path, 'create.sql')
        with open(sql_path, 'r') as f:
            sql = f.read()
        with queue._transaction() as cursor:
            cursor.execute(sql, {'name': Literal(queue.table)})
class PQ(object):
    """Convenient queue manager.

    Queues are created lazily by name and cached weakly, so a queue lives
    only as long as something else references it.
    """

    table = 'queue'

    def __init__(self, *args, **kwargs):
        # Constructor arguments forwarded to every Queue(name, ...) call.
        self.params = args, kwargs
        self.queues = WeakValueDictionary()

    def __getitem__(self, name):
        """Return the queue called *name*, creating it if needed."""
        existing = self.queues.get(name)
        if existing is not None:
            return existing
        args, kwargs = self.params
        return self.queues.setdefault(name, Queue(name, *args, **kwargs))

    def close(self):
        """Call close() on the queue named ''."""
        self[''].close()

    def create(self):
        """Execute the 'create' SQL script in a transaction."""
        queue = self['']
        conn = queue._conn()
        sql = _read_sql('create')
        with transaction(conn) as cursor:
            cursor.execute(sql, {'name': Literal(queue.table)})
class _TransformExecutorServices(object):
    """Schedules and completes TransformExecutors.

    Controls the concurrency as appropriate for the applied transform the
    executor exists for.
    """

    def __init__(self, executor_service):
        self._executor_service = executor_service
        # Set shared with every evaluation state created here.
        self._scheduled = set()
        self._parallel = _ParallelEvaluationState(
            self._executor_service, self._scheduled)
        # step -> serial state (weak values): a step's state is recreated
        # once nothing references it any more.
        self._serial_cache = WeakValueDictionary()

    def parallel(self):
        return self._parallel

    def serial(self, step):
        """Return the serial evaluation state for *step*, creating it on demand."""
        cached = self._serial_cache.get(step)
        # Test for absence explicitly: a cached state that happens to be
        # falsy (e.g. defines __len__) must still be reused, otherwise two
        # states would exist for one step.
        if cached is None:
            cached = _SerialEvaluationState(self._executor_service,
                                            self._scheduled)
            self._serial_cache[step] = cached
        return cached

    @property
    def executors(self):
        return frozenset(self._scheduled)
def __init__(self, datagram_socket):
    """Constructor

    Arguments:
    datagram_socket -- the root socket; this must be a bound, unconnected
                       datagram socket

    Raises InvalidSocketError if datagram_socket is not SOCK_DGRAM, is
    unbound, or is connected.
    """
    if datagram_socket.type != socket.SOCK_DGRAM:
        raise InvalidSocketError("datagram_socket is not of " +
                                 "type SOCK_DGRAM")
    try:
        local_address = datagram_socket.getsockname()
    except socket.error:
        raise InvalidSocketError("datagram_socket is unbound")
    # On POSIX, getsockname() succeeds even for an unbound IP socket and
    # reports port 0, whereas binding always yields a non-zero port.
    ip_families = (socket.AF_INET, getattr(socket, 'AF_INET6', socket.AF_INET))
    if datagram_socket.family in ip_families and local_address[1] == 0:
        raise InvalidSocketError("datagram_socket is unbound")
    try:
        datagram_socket.getpeername()
    except socket.error:
        pass  # no peer: the socket is unconnected, as required
    else:
        raise InvalidSocketError("datagram_socket is connected")
    self.datagram_socket = datagram_socket
    self.payload = ""
    self.payload_peer_address = None
    # Weak-valued registry; presumably keyed by peer address — TODO confirm.
    self.connections = WeakValueDictionary()
def __init__(self):
    """Bind a non-blocking listening TCP socket on PORT for every passive
    address and register each with a poll set.

    Raises ListenError if no address could be bound.
    """
    # Get a list of potential bind addresses
    addrs = socket.getaddrinfo(None, PORT, 0, socket.SOCK_STREAM, 0,
                               socket.AI_PASSIVE)
    # Try to bind to each address
    socks = []
    for family, socktype, proto, _canonname, addr in addrs:
        try:
            sock = socket.socket(family, socktype, proto)
        except socket.error:
            continue
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if family == socket.AF_INET6:
                # Ensure an IPv6 listener doesn't also bind to IPv4,
                # since depending on the order of getaddrinfo return
                # values this could cause the IPv6 bind to fail
                sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
            sock.bind(addr)
            sock.listen(BACKLOG)
            sock.setblocking(0)
        except socket.error:
            # Close the failed socket instead of leaking its descriptor.
            sock.close()
            continue
        socks.append(sock)
    if not socks:
        # None of the addresses worked
        raise ListenError("Couldn't bind listening socket")
    self._poll = _PendingConnPollSet()
    for sock in socks:
        self._poll.register(_ListeningSocket(sock), select.POLLIN)
    # Weak-valued map, presumably nonce -> pending connection (by name).
    self._nonce_to_pending = WeakValueDictionary()
def __init__(self, *args, **kw):
    """Initialise the base class, then the work queue, the weak-valued
    results registry, and the lock that guards it."""
    super(TaskQueue, self).__init__(*args, **kw)
    self.queue, self.results, self.results_lock = (
        Queue(), WeakValueDictionary(), Lock())
from contextlib import contextmanager
from functools import partial
from time import time, sleep
from traceback import print_exc
from threading import Lock
from weakref import WeakValueDictionary
import socket
import os

from accelerator.compat import str_types, iteritems
from accelerator import g

status_tree = {}
# pid -> status object (weak values)
status_all = WeakValueDictionary()
status_stacks_lock = Lock()


# all currently (or recently) running launch.py PIDs
class Children(set):
    def add(self, pid):
        """Add *pid* to the set while holding status_stacks_lock."""
        with status_stacks_lock:
            set.add(self, pid)

    def remove(self, pid):
        """Drop the status bookkeeping for *pid* under status_stacks_lock.

        NOTE(review): unlike set.remove, this does not take pid out of the
        set itself — confirm that is intended.
        """
        with status_stacks_lock:
            status = status_all.pop(pid, None)
            parent = status_all.get(status.parent_pid) if status else None
            if parent is not None:
                parent.children.pop(pid, None)
            status_tree.pop(pid, None)
from io import TextIOBase, RawIOBase, IOBase, BufferedIOBase from traceback import extract_stack, print_stack from types import CodeType, FunctionType # noqa from typing import (Callable, Any, Union, Dict, List, TypeVar, Tuple, Set, Sequence, get_type_hints, TextIO, Optional, IO, BinaryIO) from warnings import warn from weakref import WeakKeyDictionary, WeakValueDictionary try: from typing import Type except ImportError: Type = None _type_hints_map = WeakKeyDictionary( ) # type: Dict[FunctionType, Dict[str, Any]] _functions_map = WeakValueDictionary() # type: Dict[CodeType, FunctionType] class _CallMemo: __slots__ = ('func', 'func_name', 'signature', 'typevars', 'arguments', 'type_hints') def __init__(self, func: Callable, frame=None, args: tuple = None, kwargs: Dict[str, Any] = None): self.func = func self.func_name = function_name(func) self.signature = inspect.signature(func) self.typevars = {} # type: Dict[Any, type]
copy_raw_to_string._always_inline_ = True copy_raw_to_string = func_with_new_name(copy_raw_to_string, 'copy_raw_to_%s' % name) return (copy_string_to_raw, copy_raw_to_string, copy_string_contents, _get_raw_buf) (copy_string_to_raw, copy_raw_to_string, copy_string_contents, _get_raw_buf_string) = _new_copy_contents_fun(STR, STR, Char, 'string') (copy_unicode_to_raw, copy_raw_to_unicode, copy_unicode_contents, _get_raw_buf_unicode) = _new_copy_contents_fun(UNICODE, UNICODE, UniChar, 'unicode') CONST_STR_CACHE = WeakValueDictionary() CONST_UNICODE_CACHE = WeakValueDictionary() class BaseLLStringRepr(Repr): def convert_const(self, value): if value is None: return nullptr(self.lowleveltype.TO) #value = getattr(value, '__self__', value) # for bound string methods if not isinstance(value, self.basetype): raise TyperError("not a str: %r" % (value, )) try: return self.CACHE[value] except KeyError: p = self.malloc(len(value)) for i in range(len(value)):
class Bitcoin:
    """
    Conceptually, a Bitcoin object disallows some arithmetic operations
    (such as multiplying Bitcoin by Bitcoin) while allowing others (adding
    Bitcoin to Bitcoin or taking a fraction of Bitcoin).

    Since there is a minimum unit of bitcoin, a satoshi, it makes sense to do
    arithmetic operations on satoshis (integers) as much as possible, while
    using appropriately quantized Decimals when needed. This class wraps
    around both integer and Decimal amounts appropriately.

    Bitcoin is immutable and the implementation uses an intern pool to save
    memory: at most one live instance exists per satoshi amount, which is why
    equality is tested by identity.
    """

    DECIMAL_PLACES = 8  # minimal unit is 1e-8, a "satoshi"

    # part of Bitcoin interning implementation:
    # weakrefs are used to avoid memory leaks,
    # lock is used for thread safety (maybe not
    # really needed in our use case, due to the GIL)
    __pool = WeakValueDictionary()
    __pool_lock = threading.Lock()

    # prevent the addition of extra attributes,
    # but we do add __weakref__ to allow the usage of weak references
    __slots__ = "satoshis", "decimal", "__weakref__"

    def __new__(cls, value):
        """
        Create (or fetch the interned) Bitcoin worth *value* satoshis.

        Since Bitcoin is immutable, its attributes are set in __new__
        instead of __init__. Fractional satoshi amounts are rounded
        half-to-even.

        :raises BitcoinValueError: if *value* cannot be converted
        """
        if cls._is_bitcoin(value):
            # no need to copy an immutable
            return value
        try:
            value = cls.quantize(value)
            satoshis = int(value)
            dec = value / 10**cls.DECIMAL_PLACES
        except (ValueError, TypeError, InvalidOperation):
            raise BitcoinValueError(
                "Cannot create Bitcoin with value '{}' of type {}".format(
                    value, type(value).__name__,
                ))
        self = cls.__unintern(satoshis, dec)
        return self

    @classmethod
    def __unintern(cls, satoshis, dec):
        """ retrieve Bitcoin instance with given value """
        with cls.__pool_lock:
            try:
                self = cls.__pool[satoshis]
            except KeyError:
                self = cls.__intern(satoshis, dec)
        return self

    @classmethod
    def __intern(cls, satoshis, dec):
        """ create and cache Bitcoin instance from value """
        self = super().__new__(cls)
        # __setattr__ is disabled, so bypass it for initialisation.
        object.__setattr__(self, "satoshis", satoshis)
        object.__setattr__(self, "decimal", dec)
        cls.__pool[satoshis] = self
        return self

    def __setattr__(self, *args):
        raise BitcoinAttributeError("Cannot set attribute on Bitcoin.")

    def __delattr__(self, *args):
        raise BitcoinAttributeError("Cannot delete attribute on Bitcoin.")

    def to_bytes(self, *args, **kwargs):
        """
        Convenience method, since serialization is everywhere in Bitcoin!
        Since the usual format sent over the wire is satoshis encoded as
        bytes, we delegate to the 'satoshis' attribute.
        """
        return self.satoshis.to_bytes(*args, **kwargs)

    def __repr__(self):
        return "Bitcoin('{!s}')".format(self.satoshis)

    def __str__(self):
        return str(self.satoshis)

    @staticmethod
    def _is_bitcoin(operand):
        return isinstance(operand, Bitcoin)

    def __eq__(self, other):
        # Interning guarantees one instance per value.
        return self is other

    def __hash__(self):
        return hash(self.satoshis)

    def __lt__(self, other):
        if self._is_bitcoin(other):
            return self.satoshis < other.satoshis
        return NotImplemented

    def __le__(self, other):
        if self._is_bitcoin(other):
            return self.satoshis <= other.satoshis
        return NotImplemented

    def __gt__(self, other):
        if self._is_bitcoin(other):
            return self.satoshis > other.satoshis
        return NotImplemented

    def __ge__(self, other):
        if self._is_bitcoin(other):
            return self.satoshis >= other.satoshis
        return NotImplemented

    def __pos__(self):
        return self

    def __neg__(self):
        return Bitcoin(-1 * self.satoshis)

    def __add__(self, other):
        if not self._is_bitcoin(other):
            raise BitcoinTypeError(
                "Cannot add Bitcoin object to non-Bitcoin object.")
        return Bitcoin(self.satoshis + other.satoshis)

    def __sub__(self, other):
        if not self._is_bitcoin(other):
            raise BitcoinTypeError(
                "Cannot subtract non-Bitcoin object from Bitcoin object.")
        return Bitcoin(self.satoshis - other.satoshis)

    def __rsub__(self, other):
        if not self._is_bitcoin(other):
            raise BitcoinTypeError(
                "Cannot subtract Bitcoin object from non-Bitcoin object.")
        return Bitcoin(other.satoshis - self.satoshis)

    def __mul__(self, other):
        if self._is_bitcoin(other):
            raise BitcoinTypeError(
                "Cannot multiply Bitcoin object by another Bitcoin object.")
        return Bitcoin(self.satoshis * other)

    def __truediv__(self, other):
        # Bitcoin / Bitcoin is a plain ratio; Bitcoin / number is Bitcoin.
        if self._is_bitcoin(other):
            return self.decimal / other.decimal
        return Bitcoin(self.satoshis / other)

    def __rtruediv__(self, other):  # pylint: disable=unused-argument
        raise BitcoinTypeError(
            "Cannot divide non-Bitcoin object by Bitcoin object.")

    __radd__ = __add__
    __rmul__ = __mul__

    @classmethod
    def quantize(cls, dec):
        """Round the satoshi amount *dec* to a whole number of satoshis
        using ROUND_HALF_EVEN, returned as a Decimal."""
        # Exponent 0 (Decimal(1)) is what the previous spelling
        # Decimal("10") * -8 == Decimal(-80) evaluated to; stating it
        # directly keeps results identical.
        return Decimal(dec).quantize(Decimal(1), rounding=ROUND_HALF_EVEN)

    def __reduce__(self):
        """
        since we made Bitcoin immutable, pickle will have problems with
        re-instantiation; this tells it how to do so
        """
        return Bitcoin, (self.satoshis, )

    def deconstruct(self):
        """
        Needed so Django can write migration files.
        Return a 3-tuple of class import path, positional arguments,
        and keyword arguments.
        """
        module_path = self.__class__.__module__
        class_name = self.__class__.__name__
        import_path = module_path + "." + class_name
        return import_path, (self.satoshis, ), {}
def __init__(self) -> None:
    """Start with an empty pair set and an empty weak result cache."""
    # Pairs held by this object.
    self._content: Collection[GL.Pair] = set()
    # Cached result per type (weak values).
    # NOTE(review): the original "Do not serialize" comment sat between
    # these two assignments; presumably it refers to _results — confirm.
    self._results: WeakValueDictionary[Type[object], GL.GLResult] = (
        WeakValueDictionary()
    )
from typing import TYPE_CHECKING, Iterator from weakref import WeakValueDictionary from django.core.files.storage import Storage if TYPE_CHECKING: # Avoid circular imports from .fields import S3FileField FieldsDictType = WeakValueDictionary[str, 'S3FileField'] StoragesDictType = WeakValueDictionary[int, Storage] _fields: 'FieldsDictType' = WeakValueDictionary() _storages: 'StoragesDictType' = WeakValueDictionary() def register_field(field: 'S3FileField') -> None: field_id = field.id if field_id in _fields and not (_fields[field_id] is field): # This might be called multiple times, but it should always be consistent raise Exception(f'Cannot overwrite existing S3FileField declaration for {field_id}') _fields[field_id] = field storage = field.storage storage_label = id(storage) _storages[storage_label] = storage def get_field(field_id: str) -> 'S3FileField': """Get an S3FileFields by its __str__."""
class PluginManager:
    """
    PluginManager is the core of CloudBot plugin loading.

    PluginManager loads Plugins, and adds their Hooks to easy-access dicts/lists.

    Each Plugin represents a file, and loads hooks onto itself using find_hooks.

    Plugins are the lowest level of abstraction in this class. There are four different plugin types:
    - CommandPlugin is for bot commands
    - RawPlugin hooks onto irc_raw irc lines
    - RegexPlugin loads a regex parameter, and executes on irc lines which match the regex
    - SievePlugin is a catch-all sieve, which all other plugins go through before being executed.

    :type bot: cloudbot.bot.CloudBot
    :type plugins: dict[str, Plugin]
    :type commands: dict[str, CommandHook]
    :type raw_triggers: dict[str, list[RawHook]]
    :type catch_all_triggers: list[RawHook]
    :type event_type_hooks: dict[cloudbot.event.EventType, list[EventHook]]
    :type regex_hooks: list[(re.__Regex, RegexHook)]
    :type sieves: list[SieveHook]
    """

    def __init__(self, bot):
        """
        Creates a new PluginManager. You generally only need to do this from inside cloudbot.bot.CloudBot

        :type bot: cloudbot.bot.CloudBot
        """
        self.bot = bot

        # Loaded plugins keyed by str(resolved file path).
        self.plugins = {}
        # title -> Plugin; weak so it never keeps an unloaded plugin alive.
        self._plugin_name_map = WeakValueDictionary()
        self.commands = {}
        self.raw_triggers = {}
        self.catch_all_triggers = []
        self.event_type_hooks = {}
        self.regex_hooks = []
        self.sieves = []
        self.cap_hooks = {
            "on_available": defaultdict(list),
            "on_ack": defaultdict(list)
        }
        self.connect_hooks = []
        self.out_sieves = []
        self.hook_hooks = defaultdict(list)
        self.perm_hooks = defaultdict(list)
        # (plugin title, function name) -> None | asyncio.Queue; see launch().
        self._hook_waiting_queues = {}

    def find_plugin(self, title):
        """
        Finds a loaded plugin and returns its Plugin object

        :param title: the title of the plugin to find
        :return: The Plugin object if it exists, otherwise None
        """
        return self._plugin_name_map.get(title)

    @asyncio.coroutine
    def load_all(self, plugin_dir):
        """
        Load a plugin from each *.py file in the given directory.

        Won't load any plugins listed in "disabled_plugins".

        :type plugin_dir: str
        """
        plugin_dir = Path(plugin_dir)
        # Load all .py files in the plugins directory and any subdirectory
        # But ignore files starting with _
        path_list = plugin_dir.rglob("[!_]*.py")
        # Load plugins asynchronously :O
        yield from asyncio.gather(
            *[self.load_plugin(path) for path in path_list],
            loop=self.bot.loop)

    @asyncio.coroutine
    def unload_all(self):
        """Unload every currently loaded plugin concurrently."""
        # Snapshot the keys: each unload_plugin() deletes from self.plugins.
        yield from asyncio.gather(
            *[self.unload_plugin(path) for path in list(self.plugins)],
            loop=self.bot.loop)

    @asyncio.coroutine
    def load_plugin(self, path):
        """
        Loads a plugin from the given path and plugin object, then registers all hooks from that plugin.

        Won't load any plugins listed in "disabled_plugins".

        :type path: str | Path
        """
        path = Path(path)
        file_path = path.resolve()
        file_name = file_path.name
        # Resolve the path relative to the current directory
        plugin_path = file_path.relative_to(self.bot.base_dir)
        title = '.'.join(plugin_path.parts[1:]).rsplit('.', 1)[0]

        if "plugin_loading" in self.bot.config:
            pl = self.bot.config.get("plugin_loading")

            if pl.get("use_whitelist", False):
                if title not in pl.get("whitelist", []):
                    logger.info(
                        'Not loading plugin module "{}": plugin not whitelisted'
                        .format(title))
                    return
            else:
                if title in pl.get("blacklist", []):
                    logger.info(
                        'Not loading plugin module "{}": plugin blacklisted'.
                        format(title))
                    return

        # make sure to unload the previously loaded plugin from this path, if it was loaded.
        # self.plugins is keyed by str(file_path) (see Plugin(...) below and
        # unload_plugin); testing the Path object itself never matched, so a
        # reload registered every hook a second time.
        if str(file_path) in self.plugins:
            yield from self.unload_plugin(file_path)

        module_name = "plugins.{}".format(title)
        try:
            plugin_module = importlib.import_module(module_name)
            # if this plugin was loaded before, reload it
            if hasattr(plugin_module, "_cloudbot_loaded"):
                importlib.reload(plugin_module)
        except Exception:
            logger.exception("Error loading {}:".format(title))
            return

        # create the plugin
        plugin = Plugin(str(file_path), file_name, title, plugin_module)

        # proceed to register hooks

        # create database tables
        yield from plugin.create_tables(self.bot)

        # run on_start hooks
        for on_start_hook in plugin.hooks["on_start"]:
            success = yield from self.launch(
                on_start_hook, Event(bot=self.bot, hook=on_start_hook))
            if not success:
                logger.warning(
                    "Not registering hooks from plugin {}: on_start hook errored"
                    .format(plugin.title))

                # unregister databases
                plugin.unregister_tables(self.bot)
                return

        self.plugins[plugin.file_path] = plugin
        self._plugin_name_map[plugin.title] = plugin

        for on_cap_available_hook in plugin.hooks["on_cap_available"]:
            for cap in on_cap_available_hook.caps:
                self.cap_hooks["on_available"][cap.casefold()].append(
                    on_cap_available_hook)
            self._log_hook(on_cap_available_hook)

        for on_cap_ack_hook in plugin.hooks["on_cap_ack"]:
            for cap in on_cap_ack_hook.caps:
                self.cap_hooks["on_ack"][cap.casefold()].append(
                    on_cap_ack_hook)
            self._log_hook(on_cap_ack_hook)

        for periodic_hook in plugin.hooks["periodic"]:
            task = async_util.wrap_future(self._start_periodic(periodic_hook))
            plugin.tasks.append(task)
            self._log_hook(periodic_hook)

        # register commands
        for command_hook in plugin.hooks["command"]:
            for alias in command_hook.aliases:
                if alias in self.commands:
                    logger.warning(
                        "Plugin {} attempted to register command {} which was already registered by {}. "
                        "Ignoring new assignment.".format(
                            plugin.title, alias,
                            self.commands[alias].plugin.title))
                else:
                    self.commands[alias] = command_hook
            self._log_hook(command_hook)

        # register raw hooks
        for raw_hook in plugin.hooks["irc_raw"]:
            if raw_hook.is_catch_all():
                self.catch_all_triggers.append(raw_hook)
            else:
                for trigger in raw_hook.triggers:
                    self.raw_triggers.setdefault(trigger, []).append(raw_hook)
            self._log_hook(raw_hook)

        # register events
        for event_hook in plugin.hooks["event"]:
            for event_type in event_hook.types:
                self.event_type_hooks.setdefault(event_type, []).append(
                    event_hook)
            self._log_hook(event_hook)

        # register regexps
        for regex_hook in plugin.hooks["regex"]:
            for regex_match in regex_hook.regexes:
                self.regex_hooks.append((regex_match, regex_hook))
            self._log_hook(regex_hook)

        # register sieves
        for sieve_hook in plugin.hooks["sieve"]:
            self.sieves.append(sieve_hook)
            self._log_hook(sieve_hook)

        # register connect hooks
        for connect_hook in plugin.hooks["on_connect"]:
            self.connect_hooks.append(connect_hook)
            self._log_hook(connect_hook)

        for out_hook in plugin.hooks["irc_out"]:
            self.out_sieves.append(out_hook)
            self._log_hook(out_hook)

        for post_hook in plugin.hooks["post_hook"]:
            self.hook_hooks["post"].append(post_hook)
            self._log_hook(post_hook)

        for perm_hook in plugin.hooks["perm_check"]:
            for perm in perm_hook.perms:
                self.perm_hooks[perm].append(perm_hook)
            self._log_hook(perm_hook)

        # Sort hooks by priority.  regex_hooks holds (regex, hook) tuples so it
        # needs its own key; every other list (including sieves and connect
        # hooks) is sorted by the single loop below.
        self.regex_hooks.sort(key=lambda x: x[1].priority)
        dicts_of_lists_of_hooks = (self.event_type_hooks, self.raw_triggers,
                                   self.perm_hooks, self.hook_hooks)
        lists_of_hooks = [
            self.catch_all_triggers, self.sieves, self.connect_hooks,
            self.out_sieves
        ]
        lists_of_hooks.extend(
            chain.from_iterable(d.values() for d in dicts_of_lists_of_hooks))

        for lst in lists_of_hooks:
            lst.sort(key=attrgetter("priority"))

        # we don't need this anymore
        del plugin.hooks["on_start"]

    @asyncio.coroutine
    def unload_plugin(self, path):
        """
        Unloads the plugin from the given path, unregistering all hooks from the plugin.

        Returns True if the plugin was unloaded, False if the plugin wasn't loaded in the first place.

        :type path: str | Path
        :rtype: bool
        """
        path = Path(path)
        file_path = path.resolve()

        # make sure this plugin is actually loaded
        if str(file_path) not in self.plugins:
            return False

        # get the loaded plugin
        plugin = self.plugins[str(file_path)]

        for task in plugin.tasks:
            task.cancel()

        for on_cap_available_hook in plugin.hooks["on_cap_available"]:
            available_hooks = self.cap_hooks["on_available"]
            for cap in on_cap_available_hook.caps:
                cap_cf = cap.casefold()
                available_hooks[cap_cf].remove(on_cap_available_hook)
                if not available_hooks[cap_cf]:
                    del available_hooks[cap_cf]

        for on_cap_ack in plugin.hooks["on_cap_ack"]:
            ack_hooks = self.cap_hooks["on_ack"]
            for cap in on_cap_ack.caps:
                cap_cf = cap.casefold()
                ack_hooks[cap_cf].remove(on_cap_ack)
                if not ack_hooks[cap_cf]:
                    del ack_hooks[cap_cf]

        # unregister commands
        for command_hook in plugin.hooks["command"]:
            for alias in command_hook.aliases:
                if alias in self.commands and self.commands[
                        alias] == command_hook:
                    # we need to make sure that there wasn't a conflict, so we don't delete another plugin's command
                    del self.commands[alias]

        # unregister raw hooks
        for raw_hook in plugin.hooks["irc_raw"]:
            if raw_hook.is_catch_all():
                self.catch_all_triggers.remove(raw_hook)
            else:
                for trigger in raw_hook.triggers:
                    assert trigger in self.raw_triggers  # this can't be not true
                    self.raw_triggers[trigger].remove(raw_hook)
                    if not self.raw_triggers[
                            trigger]:  # if that was the last hook for this trigger
                        del self.raw_triggers[trigger]

        # unregister events
        for event_hook in plugin.hooks["event"]:
            for event_type in event_hook.types:
                assert event_type in self.event_type_hooks  # this can't be not true
                self.event_type_hooks[event_type].remove(event_hook)
                if not self.event_type_hooks[
                        event_type]:  # if that was the last hook for this event type
                    del self.event_type_hooks[event_type]

        # unregister regexps
        for regex_hook in plugin.hooks["regex"]:
            for regex_match in regex_hook.regexes:
                self.regex_hooks.remove((regex_match, regex_hook))

        # unregister sieves
        for sieve_hook in plugin.hooks["sieve"]:
            self.sieves.remove(sieve_hook)

        # unregister connect hooks
        for connect_hook in plugin.hooks["on_connect"]:
            self.connect_hooks.remove(connect_hook)

        for out_hook in plugin.hooks["irc_out"]:
            self.out_sieves.remove(out_hook)

        for post_hook in plugin.hooks["post_hook"]:
            self.hook_hooks["post"].remove(post_hook)

        for perm_hook in plugin.hooks["perm_check"]:
            for perm in perm_hook.perms:
                self.perm_hooks[perm].remove(perm_hook)

        # Run on_stop hooks
        for on_stop_hook in plugin.hooks["on_stop"]:
            event = Event(bot=self.bot, hook=on_stop_hook)
            yield from self.launch(on_stop_hook, event)

        # unregister databases
        plugin.unregister_tables(self.bot)

        # Drop the title lookup explicitly.  The weak map would only forget
        # the plugin once it is collected, which can be delayed because the
        # plugin and its hooks may reference each other; until then
        # find_plugin() would keep returning the unloaded plugin.  Only
        # remove the entry if it still points at *this* plugin.
        if self._plugin_name_map.get(plugin.title) is plugin:
            del self._plugin_name_map[plugin.title]

        # remove last reference to plugin
        del self.plugins[plugin.file_path]

        if self.bot.config.get("logging", {}).get("show_plugin_loading", True):
            logger.info("Unloaded all plugins from {}".format(plugin.title))

        return True

    def _log_hook(self, hook):
        """
        Logs registering a given hook

        :type hook: Hook
        """
        if self.bot.config.get("logging", {}).get("show_plugin_loading", True):
            logger.info("Loaded {}".format(hook))
            logger.debug("Loaded {}".format(repr(hook)))

    def _prepare_parameters(self, hook, event):
        """
        Prepares arguments for the given hook

        Returns None (after logging) if the hook asks for an argument the
        event does not have.

        :type hook: cloudbot.plugin.Hook
        :type event: cloudbot.event.Event
        :rtype: list
        """
        parameters = []
        for required_arg in hook.required_args:
            if hasattr(event, required_arg):
                value = getattr(event, required_arg)
                parameters.append(value)
            else:
                logger.error(
                    "Plugin {} asked for invalid argument '{}', cancelling execution!"
                    .format(hook.description, required_arg))
                logger.debug("Valid arguments are: {} ({})".format(
                    dir(event), event))
                return None
        return parameters

    def _execute_hook_threaded(self, hook, event):
        """
        Run a threaded hook's function synchronously (called in an executor).

        :type hook: Hook
        :type event: cloudbot.event.Event
        """
        event.prepare_threaded()

        parameters = self._prepare_parameters(hook, event)
        if parameters is None:
            return None

        try:
            return hook.function(*parameters)
        finally:
            event.close_threaded()

    @asyncio.coroutine
    def _execute_hook_sync(self, hook, event):
        """
        Run a coroutine hook's function on the event loop.

        :type hook: Hook
        :type event: cloudbot.event.Event
        """
        yield from event.prepare()

        parameters = self._prepare_parameters(hook, event)
        if parameters is None:
            return None

        try:
            return (yield from hook.function(*parameters))
        finally:
            yield from event.close()

    @asyncio.coroutine
    def internal_launch(self, hook, event):
        """
        Launches a hook with the data from [event]

        :param hook: The hook to launch
        :param event: The event providing data for the hook
        :return: a tuple of (ok, result) where ok is a boolean that determines if the hook ran without error and result is the result from the hook
        """
        try:
            if hook.threaded:
                out = yield from self.bot.loop.run_in_executor(
                    None, self._execute_hook_threaded, hook, event)
            else:
                out = yield from self._execute_hook_sync(hook, event)
        except Exception as e:
            logger.exception("Error in hook {}".format(hook.description))
            return False, e

        return True, out

    @asyncio.coroutine
    def _execute_hook(self, hook, event):
        """
        Runs the specific hook with the given bot and event.

        Returns False if the hook errored, True otherwise.

        After the hook runs, every "post" hook is launched with a
        PostHookEvent describing the result/error; a post hook that succeeds
        and returns False stops the remaining post hooks.

        :type hook: cloudbot.plugin.Hook
        :type event: cloudbot.event.Event
        :rtype: bool
        """
        ok, out = yield from self.internal_launch(hook, event)
        result, error = None, None
        if ok is True:
            result = out
        else:
            error = out

        post_event = partial(PostHookEvent,
                             launched_hook=hook,
                             launched_event=event,
                             bot=event.bot,
                             conn=event.conn,
                             result=result,
                             error=error)
        for post_hook in self.hook_hooks["post"]:
            success, res = yield from self.internal_launch(
                post_hook, post_event(hook=post_hook))
            if success and res is False:
                break

        return ok

    @asyncio.coroutine
    def _sieve(self, sieve, event, hook):
        """
        Run one sieve; returns its result, or None if the sieve raised.

        :type sieve: cloudbot.plugin.Hook
        :type event: cloudbot.event.Event
        :type hook: cloudbot.plugin.Hook
        :rtype: cloudbot.event.Event
        """
        try:
            if sieve.threaded:
                result = yield from self.bot.loop.run_in_executor(
                    None, sieve.function, self.bot, event, hook)
            else:
                result = yield from sieve.function(self.bot, event, hook)
        except Exception:
            logger.exception("Error running sieve {} on {}:".format(
                sieve.description, hook.description))
            return None
        else:
            return result

    @asyncio.coroutine
    def _start_periodic(self, hook):
        """Wait hook.initial_interval, then launch the hook every hook.interval seconds, forever."""
        interval = hook.interval
        initial_interval = hook.initial_interval
        yield from asyncio.sleep(initial_interval)

        while True:
            event = Event(bot=self.bot, hook=hook)
            yield from self.launch(hook, event)
            yield from asyncio.sleep(interval)

    @asyncio.coroutine
    def launch(self, hook, event):
        """
        Dispatch a given event to a given hook using a given bot object.

        Returns False if the hook didn't run successfully, and True if it ran successfully.

        :type event: cloudbot.event.Event | cloudbot.event.CommandEvent
        :type hook: cloudbot.plugin.Hook | cloudbot.plugin.CommandHook
        :rtype: bool
        """
        if hook.type not in (
                "on_start", "on_stop",
                "periodic"):  # we don't need sieves on on_start hooks.
            for sieve in self.bot.plugin_manager.sieves:
                event = yield from self._sieve(sieve, event, hook)
                if event is None:
                    return False

        if hook.single_thread:
            # There should only be one running instance of this hook, so let's wait for the last event to be processed
            # before starting this one.

            key = (hook.plugin.title, hook.function_name)
            if key in self._hook_waiting_queues:
                queue = self._hook_waiting_queues[key]
                if queue is None:
                    # there's a hook running, but the queue hasn't been created yet, since there's only one hook
                    queue = asyncio.Queue()
                    self._hook_waiting_queues[key] = queue
                assert isinstance(queue, asyncio.Queue)
                # create a future to represent this task
                future = asyncio.Future()
                queue.put_nowait(future)
                # wait until the last task is completed
                yield from future
            else:
                # set to None to signify that this hook is running, but there's no need to create a full queue
                # in case there are no more hooks that will wait
                self._hook_waiting_queues[key] = None

            # Run the plugin with the message, and wait for it to finish
            result = yield from self._execute_hook(hook, event)

            queue = self._hook_waiting_queues[key]
            if queue is None or queue.empty():
                # We're the last task in the queue, we can delete it now.
                del self._hook_waiting_queues[key]
            else:
                # set the result for the next task's future, so they can execute
                next_future = yield from queue.get()
                next_future.set_result(None)
        else:
            # Run the plugin with the message, and wait for it to finish
            result = yield from self._execute_hook(hook, event)

        # Return the result
        return result
def __init__(cls, *args):
    """Finish initialising *cls*, then give it a fresh ``aliases`` registry.

    NOTE(review): the ``cls`` first parameter suggests this is a metaclass
    ``__init__`` - confirm against the enclosing class.  Every class it
    initialises receives its own ``WeakValueDictionary``; an alias entry
    disappears once its value is no longer referenced elsewhere.
    """
    super().__init__(*args)
    alias_registry = WeakValueDictionary()
    cls.aliases = alias_registry
class Cell(_FortranObjectWithID):
    """Cell stored internally.

    This class exposes a cell that is stored internally in the OpenMC
    library. To obtain a view of a cell with a given ID, use the
    :data:`openmc.lib.cells` mapping.

    Parameters
    ----------
    index : int
         Index in the `cells` array.

    Attributes
    ----------
    id : int
        ID of the cell

    """
    # One proxy object per internal index; an entry disappears once no proxy
    # for that index is referenced anywhere else.
    __instances = WeakValueDictionary()

    def __new__(cls, uid=None, new=True, index=None):
        """Return the shared proxy for a cell, allocating a new cell if asked.

        * ``index`` given: wrap that internal index directly.
        * ``index`` None and ``new`` true: allocate a new cell in the library
          with ID ``uid``, or one past the largest ID in ``cells``.
        * ``index`` None and ``new`` false: look up the cell with ID ``uid``.

        Raises AllocationError if a new cell is requested with an ID that is
        already allocated.
        """
        mapping = cells
        if index is None:
            if new:
                # Determine ID to assign
                if uid is None:
                    uid = max(mapping, default=0) + 1
                else:
                    if uid in mapping:
                        raise AllocationError('A cell with ID={} has already '
                                              'been allocated.'.format(uid))

                index = c_int32()
                _dll.openmc_extend_cells(1, index, None)
                index = index.value
            else:
                index = mapping[uid]._index

        # Fetch once: an ``in`` test followed by ``[]`` on a weak dictionary
        # can race with the value being collected in between (KeyError).
        instance = cls.__instances.get(index)
        if instance is None:
            instance = super().__new__(cls)
            instance._index = index
            if uid is not None:
                instance.id = uid
            cls.__instances[index] = instance

        return instance

    @property
    def id(self):
        """ID of the cell as reported by the library."""
        cell_id = c_int32()
        _dll.openmc_cell_get_id(self._index, cell_id)
        return cell_id.value

    @id.setter
    def id(self, cell_id):
        _dll.openmc_cell_set_id(self._index, cell_id)

    @property
    def name(self):
        """Name of the cell (decoded from the library's C string)."""
        name = c_char_p()
        _dll.openmc_cell_get_name(self._index, name)
        return name.value.decode()

    @name.setter
    def name(self, name):
        name_ptr = c_char_p(name.encode())
        _dll.openmc_cell_set_name(self._index, name_ptr)

    @property
    def fill(self):
        """Material fill of the cell.

        Returns a :class:`Material`, ``None`` for a void fill, or a list of
        those when the fill has more than one entry. Non-material fills
        raise NotImplementedError.
        """
        fill_type = c_int()
        indices = POINTER(c_int32)()
        n = c_int32()
        _dll.openmc_cell_get_fill(self._index, fill_type, indices, n)
        if fill_type.value == 0:
            # The setter encodes a void (None) entry as index -1; map it back
            # to None rather than building a Material on an invalid index.
            if n.value > 1:
                return [Material(index=i) if i >= 0 else None
                        for i in indices[:n.value]]
            else:
                index = indices[0]
                return Material(index=index) if index >= 0 else None
        else:
            raise NotImplementedError

    @fill.setter
    def fill(self, fill):
        if isinstance(fill, Iterable):
            n = len(fill)
            # None entries are stored as -1 (void).
            indices = (c_int32 * n)(*(m._index if m is not None else -1
                                      for m in fill))
            _dll.openmc_cell_set_fill(self._index, 0, n, indices)
        elif isinstance(fill, Material):
            indices = (c_int32 * 1)(fill._index)
            _dll.openmc_cell_set_fill(self._index, 0, 1, indices)
        elif fill is None:
            indices = (c_int32 * 1)(-1)
            _dll.openmc_cell_set_fill(self._index, 0, 1, indices)
        else:
            # Previously any other value was silently ignored, leaving the
            # cell's fill unchanged without telling the caller.
            raise TypeError('Unsupported cell fill type: {}'.format(
                type(fill).__name__))

    def get_temperature(self, instance=None):
        """Get the temperature of a cell

        Parameters
        ----------
        instance: int or None
            Which instance of the cell

        """
        if instance is not None:
            instance = c_int32(instance)

        T = c_double()
        _dll.openmc_cell_get_temperature(self._index, instance, T)
        return T.value

    def set_temperature(self, T, instance=None):
        """Set the temperature of a cell

        Parameters
        ----------
        T : float
            Temperature in K
        instance : int or None
            Which instance of the cell

        """
        if instance is not None:
            instance = c_int32(instance)

        _dll.openmc_cell_set_temperature(self._index, T, instance)

    @property
    def bounding_box(self):
        """Lower-left and upper-right corners of the cell's bounding box.

        The library reports unbounded directions as +/- the largest finite
        float; those are converted to +/- ``np.inf``.
        """
        inf = sys.float_info.max
        llc = np.zeros(3)
        urc = np.zeros(3)
        _dll.openmc_cell_bounding_box(self._index,
                                      llc.ctypes.data_as(POINTER(c_double)),
                                      urc.ctypes.data_as(POINTER(c_double)))
        llc[llc == inf] = np.inf
        urc[urc == inf] = np.inf
        llc[llc == -inf] = -np.inf
        urc[urc == -inf] = -np.inf

        return llc, urc
print_function) import atexit from weakref import WeakValueDictionary, ref import zsh from powerline.shell import ShellPowerline from powerline.lib.overrides import parsedotval, parse_override_var from powerline.lib.unicode import unicode, u from powerline.lib.encoding import (get_preferred_output_encoding, get_preferred_environment_encoding) from powerline.lib.dict import mergeargs used_powerlines = WeakValueDictionary() def shutdown(): for powerline in tuple(used_powerlines.values()): powerline.shutdown() def get_var_config(var): try: val = zsh.getvalue(var) if isinstance(val, dict): return mergeargs( [parsedotval((u(k), u(v))) for k, v in val.items()]) elif isinstance(val, (unicode, str, bytes)): return mergeargs(parse_override_var(u(val)))
class TaskQueue(AbstractTaskQueue):
    """Simple in-memory task queue implementation"""

    @classmethod
    def factory(cls, url, name=const.DEFAULT, *args, **kw):
        """Return the queue for (url, name), creating and caching it in _REFS."""
        obj = _REFS.get((url, name))
        if obj is None:
            obj = _REFS[(url, name)] = cls(url, name, *args, **kw)
        return obj

    def __init__(self, *args, **kw):
        super(TaskQueue, self).__init__(*args, **kw)
        self.queue = Queue()
        # task id -> result object; held weakly, so callers must keep results alive.
        self.results = WeakValueDictionary()
        self.results_lock = Lock()

    def _init_result(self, result, status, message):
        """Attach queue bookkeeping to *result* and register it.

        Returns False (and leaves *result* untouched) if a result with the
        same id is already registered.
        """
        with self.results_lock:
            if result.id in self.results:
                return False
            # Fully initialise before publishing: other methods read
            # self.results without taking results_lock and then use these
            # attributes straight away.
            result.__status = status
            result.__value = Queue()
            result.__task = message
            result.__args = {}
            result.__lock = Lock()
            result.__for = None
            self.results[result.id] = result
        return True

    def enqueue_task(self, result, message):
        if self._init_result(result, const.ENQUEUED, message):
            self.queue.put(result)
            return True
        return False

    def defer_task(self, result, message, args):
        if self._init_result(result, const.PENDING, message):
            results = self.results
            # keep references to results to prevent GC
            result.__refs = [results.get(arg) for arg in args]
            return True
        return False

    def undefer_task(self, task_id):
        result = self.results[task_id]
        self.queue.put(result)

    def get(self, timeout=None):
        """Pop the next queued result; return (task id, message) or None on timeout."""
        try:
            result = self.queue.get(timeout=timeout)
        except Empty:
            return None
        result.__status = const.PROCESSING
        return result.id, result.__task

    def size(self):
        return len(self.results)

    def discard_pending(self):
        """Drain the task queue and forget every tracked result."""
        with self.results_lock:
            while True:
                try:
                    self.queue.get_nowait()
                except Empty:
                    break
            self.results.clear()

    def reserve_argument(self, argument_id, deferred_id):
        """Reserve result *argument_id* for deferred task *deferred_id*.

        Returns (reserved, message); message is the argument's value if it
        has already been produced, else None.
        """
        result = self.results.get(argument_id)
        if result is None:
            return (False, None)
        with result.__lock:
            if result.__for is not None:
                return (False, None)
            result.__for = deferred_id
            try:
                message = result.__value.get_nowait()
            except Empty:
                message = None
            if message is not None:
                with self.results_lock:
                    self.results.pop(argument_id, None)
            return (True, message)

    def set_argument(self, task_id, argument_id, message):
        """Store an argument for *task_id*; return True once all are present."""
        result = self.results[task_id]
        with self.results_lock:
            self.results.pop(argument_id, None)
        with result.__lock:
            result.__args[argument_id] = message
            return len(result.__args) == len(result.__refs)

    def get_arguments(self, task_id):
        try:
            return self.results[task_id].__args
        except KeyError:
            return {}

    def set_task_timeout(self, task_id, timeout):
        pass

    def get_status(self, task_id):
        result = self.results.get(task_id)
        return None if result is None else result.__status

    def set_result(self, task_id, message, timeout):
        """Publish *message* as the result; return the id it is reserved for, if any."""
        result = self.results.get(task_id)
        if result is not None:
            with result.__lock:
                result.__value.put(message)
                return result.__for

    def pop_result(self, task_id, timeout):
        """Wait for and return the result value, or TASK_EXPIRED if unknown.

        A timeout of 0 polls without blocking; None is returned if no value
        arrives in time.
        """
        result = self.results.get(task_id)
        if result is None:
            return const.TASK_EXPIRED
        try:
            if timeout == 0:
                value = result.__value.get_nowait()
            else:
                value = result.__value.get(timeout=timeout)
        except Empty:
            value = None
        else:
            self.results.pop(task_id, None)
        return value

    def discard_result(self, task_id, task_expired_token):
        """Forget *task_id* and wake any waiter with *task_expired_token*."""
        # pop() without a default raised KeyError for an unknown (or already
        # garbage-collected) id, making the None check below unreachable.
        result = self.results.pop(task_id, None)
        if result is not None:
            result.__value.put(task_expired_token)
"""In-memory message queue and result store."""
import logging
try:
    from Queue import Queue, Empty
except ImportError:
    from queue import Queue, Empty
from threading import Lock
from weakref import WeakValueDictionary

import worq.const as const
from worq.core import AbstractTaskQueue

log = logging.getLogger(__name__)

# (url, name) -> TaskQueue; weak, so an unused queue can be collected.
_REFS = WeakValueDictionary()


class TaskQueue(AbstractTaskQueue):
    """Simple in-memory task queue implementation"""

    @classmethod
    def factory(cls, url, name=const.DEFAULT, *args, **kw):
        """Return the shared queue for (url, name), building it on first request."""
        cache_key = (url, name)
        queue_obj = _REFS.get(cache_key)
        if queue_obj is not None:
            return queue_obj
        queue_obj = cls(url, name, *args, **kw)
        _REFS[cache_key] = queue_obj
        return queue_obj

    def __init__(self, *args, **kw):
        super(TaskQueue, self).__init__(*args, **kw)
        self.queue = Queue()
class RedisStore:
    """Persist db objects as JSON in Redis, caching loaded instances weakly.

    ``_object_map`` maps ``dbo_key`` to the live object.  It is a
    ``WeakValueDictionary``, so an entry disappears once nothing else holds
    the object.
    """

    def __init__(self, db_host, db_port, db_num, db_pw):
        self.pool = ConnectionPool(max_connections=2, db=db_num, host=db_host, port=db_port, password=db_pw,
                                   decode_responses=True)
        self.redis = StrictRedis(connection_pool=self.pool)
        self.redis.ping()
        self._object_map = WeakValueDictionary()

    def create_object(self, dbo_class, dbo_dict, update_timestamp=True):
        """Create, save and return a new object of *dbo_class*.

        *dbo_dict* is either a dict of initial values containing 'dbo_id',
        or the bare id itself.  Returns None if the class cannot be resolved
        or the id is empty.

        :raises ObjectExistsError: if an object with that id already exists.
        """
        dbo_class = get_dbo_class(getattr(dbo_class, 'dbo_key_type', dbo_class))
        if not dbo_class:
            return
        try:
            dbo_id = dbo_dict['dbo_id']
        except (KeyError, TypeError):
            # dbo_dict was the bare id.  Indexing a str (or int/None) with
            # 'dbo_id' raises TypeError, not KeyError, so catching only
            # KeyError made this fallback unreachable for plain ids.
            dbo_id, dbo_dict = dbo_dict, {}
        if dbo_id is None or dbo_id == '':
            log.warn("create_object called with empty dbo_id")
            return
        dbo_id = str(dbo_id).lower()
        if self.object_exists(dbo_class.dbo_key_type, dbo_id):
            raise ObjectExistsError(dbo_id)
        dbo = dbo_class()
        dbo.dbo_id = dbo_id
        dbo.hydrate(dbo_dict)
        dbo.db_created()
        if dbo.dbo_set_key:
            self.redis.sadd(dbo.dbo_set_key, dbo.dbo_id)
        self.save_object(dbo, update_timestamp)
        return dbo

    def load_object(self, dbo_key, key_type=None, silent=False):
        """Return the object for *dbo_key*, from the cache or from Redis.

        With *key_type* (a string or an object with ``dbo_key_type``),
        *dbo_key* is the bare id; otherwise it is the full 'type:id' key.
        Returns None if the key is invalid or not found.
        """
        if key_type:
            try:
                key_type = key_type.dbo_key_type
            except AttributeError:
                pass
            try:
                dbo_key, dbo_id = ':'.join((key_type, dbo_key)), dbo_key
            except TypeError:
                if not silent:
                    log.exception("Invalid dbo_key passed to load_object", stack_info=True)
                return
        else:
            key_type, _, dbo_id = dbo_key.partition(':')

        # Compare against None: a live object that happens to be falsy
        # (e.g. defines __len__) must still be served from the cache rather
        # than re-loaded as a second, duplicate instance.
        cached_dbo = self._object_map.get(dbo_key)
        if cached_dbo is not None:
            return cached_dbo
        json_str = self.redis.get(dbo_key)
        if not json_str:
            if not silent:
                log.warn("Failed to find {} in database", dbo_key)
            return
        return self._json_to_obj(json_str, key_type, dbo_id)

    def save_object(self, dbo, update_timestamp=False, autosave=False):
        """Write *dbo* to Redis (indexes, refs, JSON body) and cache it."""
        if update_timestamp:
            dbo.dbo_ts = int(time.time())
        if dbo.dbo_indexes:
            self._update_indexes(dbo)
        self._clear_old_refs(dbo)
        save_root, new_refs = dbo.to_db_value()
        self.redis.set(dbo.dbo_key, json_encode(save_root))
        if new_refs:
            self._set_new_refs(dbo, new_refs)
        log.debug("db object {} {}saved", dbo.dbo_key, "auto" if autosave else "")
        self._object_map[dbo.dbo_key] = dbo
        return dbo

    def update_object(self, dbo, dbo_dict):
        dbo.hydrate(dbo_dict)
        return self.save_object(dbo, True)

    def delete_object(self, dbo):
        """Delete *dbo*, its refs, set membership, children and index entries."""
        key = dbo.dbo_key
        dbo.db_deleted()
        self.delete_key(key)
        self._clear_old_refs(dbo)
        if dbo.dbo_set_key:
            self.redis.srem(dbo.dbo_set_key, dbo.dbo_id)
        for children_type in dbo.dbo_children_types:
            self.delete_object_set(get_dbo_class(children_type),
                                   "{}_{}s:{}".format(dbo.dbo_key_type, children_type, dbo.dbo_id))
        for ix_name in dbo.dbo_indexes:
            ix_value = getattr(dbo, ix_name, None)
            if ix_value is not None and ix_value != '':
                self.delete_index('ix:{}:{}'.format(dbo.dbo_key_type, ix_name), ix_value)
        log.debug("object deleted: {}", key)
        self.evict_object(dbo)

    def load_cached(self, dbo_key):
        return self._object_map.get(dbo_key)

    def object_exists(self, obj_type, obj_id):
        return self.redis.exists('{}:{}'.format(obj_type, obj_id))

    def load_object_set(self, dbo_class, set_key=None):
        """Load every object whose id is in the Redis set *set_key*.

        Cached objects are reused; the rest are fetched in one pipeline.
        Ids with no stored JSON are removed from the set.
        """
        dbo_class = get_dbo_class(getattr(dbo_class, 'dbo_key_type', dbo_class))
        key_type = dbo_class.dbo_key_type
        if not set_key:
            set_key = dbo_class.dbo_set_key
        results = set()
        keys = deque()
        pipeline = self.redis.pipeline()
        for key in self.fetch_set_keys(set_key):
            dbo_key = ':'.join([key_type, key])
            try:
                results.add(self._object_map[dbo_key])
            except KeyError:
                keys.append(key)
                pipeline.get(dbo_key)

        for dbo_id, json_str in zip(keys, pipeline.execute()):
            if json_str:
                obj = self._json_to_obj(json_str, key_type, dbo_id)
                if obj is not None:
                    results.add(obj)
                continue
            log.warn("Removing missing object from set {}", set_key)
            self.delete_set_key(set_key, dbo_id)
        return results

    def delete_object_set(self, dbo_class, set_key=None):
        if not set_key:
            set_key = dbo_class.dbo_set_key
        for dbo in self.load_object_set(dbo_class, set_key):
            self.delete_object(dbo)
        self.delete_key(set_key)

    def reload_object(self, dbo_key):
        """Refresh a cached object from Redis in place, or load it if not cached."""
        dbo = self._object_map.get(dbo_key)
        if dbo is not None:
            json_str = self.redis.get(dbo_key)
            if not json_str:
                log.warn("Failed to find {} in database for reload", dbo_key)
                return None
            return self.update_object(dbo, json_decode(json_str))
        return self.load_object(dbo_key)

    def evict_object(self, dbo):
        self._object_map.pop(dbo.dbo_key, None)

    def load_value(self, key, default=None):
        # Named json_str (not json) to avoid shadowing the json module name.
        json_str = self.redis.get(key)
        if json_str:
            return json_decode(json_str)
        return default

    def save_value(self, key, value):
        self.redis.set(key, json_encode(value))

    def fetch_set_keys(self, set_key):
        return self.redis.smembers(set_key)

    def add_set_key(self, set_key, *values):
        self.redis.sadd(set_key, *values)

    def delete_set_key(self, set_key, value):
        self.redis.srem(set_key, value)

    def set_key_exists(self, set_key, value):
        return self.redis.sismember(set_key, value)

    def db_counter(self, counter_id, inc=1):
        return self.redis.incr("counter:{}".format(counter_id), inc)

    def delete_key(self, key):
        self.redis.delete(key)

    def set_index(self, index_name, key, value):
        return self.redis.hset(index_name, key, value)

    def get_index(self, index_name, key):
        return self.redis.hget(index_name, key)

    def get_full_index(self, index_name):
        return self.redis.hgetall(index_name)

    def delete_index(self, index_name, key):
        return self.redis.hdel(index_name, key)

    def get_all_hash(self, index_name):
        return {key: json_decode(value) for key, value in self.redis.hgetall(index_name).items()}

    def get_hash_keys(self, hash_id):
        return self.redis.hkeys(hash_id)

    def set_db_hash(self, hash_id, hash_key, value):
        return self.redis.hset(hash_id, hash_key, json_encode(value))

    def get_db_hash(self, hash_id, hash_key):
        return json_decode(self.redis.hget(hash_id, hash_key))

    def remove_db_hash(self, hash_id, hash_key):
        self.redis.hdel(hash_id, hash_key)

    def get_all_db_hash(self, hash_id):
        return [json_decode(value) for value in self.redis.hgetall(hash_id).values()]

    def get_db_list(self, list_id, start=0, end=-1):
        return [json_decode(value) for value in self.redis.lrange(list_id, start, end)]

    def add_db_list(self, list_id, value):
        self.redis.lpush(list_id, json_encode(value))

    def trim_db_list(self, list_id, start, end):
        return self.redis.ltrim(list_id, start, end)

    def dbo_holders(self, dbo_key, degrees=0):
        """Return keys holding *dbo_key*, following holders up to *degrees* extra levels."""
        all_keys = set()

        def find(find_key, degree):
            holder_keys = self.fetch_set_keys('{}:holders'.format(find_key))
            for new_key in holder_keys:
                if new_key != dbo_key and new_key not in all_keys:
                    all_keys.add(new_key)
                    if degree < degrees:
                        find(new_key, degree + 1)

        find(dbo_key, 0)
        return all_keys

    def _json_to_obj(self, json_str, key_type, dbo_id):
        dbo_dict = json_decode(json_str)
        dbo = get_mixed_type(key_type, dbo_dict.get('mixins'))()
        dbo.dbo_id = dbo_id
        dbo.hydrate(dbo_dict)
        self._object_map[dbo.dbo_key] = dbo
        return dbo

    def _update_indexes(self, dbo):
        """Bring the unique indexes for *dbo* in line with its current values.

        :raises NonUniqueError: if a new value is already indexed.  All
            changes are validated before any index is modified, so on error
            the indexes are left exactly as they were (previously an old
            entry could already be deleted when the error was raised,
            leaving the object unindexed).
        """
        try:
            old_dbo = json_decode(self.redis.get(dbo.dbo_key))
        except TypeError:
            old_dbo = None

        # Pass 1: collect changes and check uniqueness without writing.
        changes = []
        for ix_name in dbo.dbo_indexes:
            new_val = getattr(dbo, ix_name, None)
            old_val = old_dbo.get(ix_name) if old_dbo else None
            if old_val == new_val:
                continue
            ix_key = 'ix:{}:{}'.format(dbo.dbo_key_type, ix_name)
            set_new = new_val is not None and new_val != ''
            if set_new and self.get_index(ix_key, new_val):
                raise NonUniqueError(ix_key, new_val)
            changes.append((ix_key, old_val, new_val, set_new))

        # Pass 2: apply.
        for ix_key, old_val, new_val, set_new in changes:
            if old_val is not None:
                self.delete_index(ix_key, old_val)
            if set_new:
                self.set_index(ix_key, new_val, dbo.dbo_id)

    def _clear_old_refs(self, dbo):
        """Remove *dbo* from the holders set of everything it referenced."""
        dbo_key = dbo.dbo_key
        ref_key = '{}:refs'.format(dbo_key)
        for ref_id in self.fetch_set_keys(ref_key):
            holder_key = '{}:holders'.format(ref_id)
            self.delete_set_key(holder_key, dbo_key)
        self.delete_key(ref_key)

    def _set_new_refs(self, dbo, new_refs):
        """Record *new_refs* for *dbo* and add *dbo* to each ref's holders set."""
        dbo_key = dbo.dbo_key
        self.add_set_key("{}:refs".format(dbo_key), *new_refs)
        for ref_id in new_refs:
            holder_key = '{}:holders'.format(ref_id)
            self.add_set_key(holder_key, dbo_key)
class UniversalCursors(object):
    """Blitted cursor lines shared across several matplotlib axes.

    Each named cursor (see :meth:`add`) is a list of ``axvline``/``axhline``
    artists, one per axes it was created on. Mouse motion over any registered
    canvas moves every cursor's lines to the pointer position and redraws
    them by blitting over a saved background.
    """

    class _CursorList(list):
        """List of line artists usable as a ``WeakKeyDictionary`` key.

        Hashing is identity-based on purpose: the list is mutated (lines are
        appended) after creation, and a content-based hash would change with
        every append, leaving stale duplicate entries in ``cursors_orient``.
        """

        __hash__ = object.__hash__

    def __init__(self):
        # name -> _CursorList; this strong reference is what keeps each
        # cursor alive (``cursors_orient`` only holds it weakly).
        self.name_cursors = {}
        # _CursorList -> 'vertical' | 'horizontal'
        self.cursors_orient = WeakKeyDictionary()
        # id(canvas) -> canvas, and id(axes) -> axes, held weakly.
        self.all_canvas = WeakValueDictionary()
        self.all_axes = WeakValueDictionary()
        # id(canvas) -> background region saved by _clear on draw_event.
        self.backgrounds = {}
        self.visible = True
        self.needclear = False

    def _onmove(self, event):
        """Handle ``motion_notify_event``: move the cursors to the pointer."""
        for canvas in self.all_canvas.values():
            if not canvas.widgetlock.available(self):
                return
        if event.inaxes is None or not self.visible:
            if self.needclear:
                self._update(event)
                for canvas in self.all_canvas.values():
                    canvas.draw()
                self.needclear = False
            return
        self._update(event)

    def _update(self, event):
        """Restore backgrounds, reposition every cursor line and blit."""
        # Blitting needs a saved background for every canvas. Those are
        # captured in _clear on draw_event; until then there is nothing
        # clean to restore, so skip instead of raising KeyError.
        if any(id(canvas) not in self.backgrounds
               for canvas in self.all_canvas.values()):
            return
        # 1/ Reset background
        for canvas in self.all_canvas.values():
            canvas.restore_region(self.backgrounds[id(canvas)])
        # 2/ update cursors
        for cursors, orient in self.cursors_orient.items():
            # ``Artist.axes`` replaces the removed ``get_axes()`` accessor.
            cursor_axes = [line.axes for line in cursors]
            if event.inaxes in cursor_axes and self.visible:
                visible = True
                self.needclear = True
            else:
                visible = False
            for line in cursors:
                if orient == 'vertical':
                    line.set_xdata((event.xdata, event.xdata))
                elif orient == 'horizontal':
                    line.set_ydata((event.ydata, event.ydata))
                line.set_visible(visible)
                line.axes.draw_artist(line)
        # 3/ update canvas
        for canvas in self.all_canvas.values():
            canvas.blit(canvas.figure.bbox)

    def _clear(self, event):
        """Handle ``draw_event``: save clean backgrounds and hide cursors."""
        self.backgrounds = {}
        for canvas in self.all_canvas.values():
            self.backgrounds[id(canvas)] = (
                canvas.copy_from_bbox(canvas.figure.bbox))
        for cursor in self.cursors_orient.keys():
            for line in cursor:
                line.set_visible(False)

    def add(self, name, axes=(), orient='vertical', **lineprops):
        """Create a cursor called *name* spanning every axes in *axes*.

        :param name: key for the cursor; must not already be registered.
        :param axes: iterable of matplotlib axes to draw the cursor on.
        :param orient: ``'vertical'`` or ``'horizontal'``.
        :param lineprops: forwarded to ``axvline``/``axhline``.
        :raises NameError: if *name* is already registered.
        :raises ValueError: if *orient* is not a supported orientation.
        """
        if name in self.name_cursors:
            raise NameError('cursor {!r} already exists'.format(name))
        if orient not in ('vertical', 'horizontal'):
            raise ValueError(
                "orient must be 'vertical' or 'horizontal', got {!r}".format(
                    orient))
        cursor = self._CursorList()
        # Required to keep the cursor alive (cursors_orient is weak).
        self.name_cursors[name] = cursor
        for ax in axes:
            self.all_axes[id(ax)] = ax
            ax_canvas = ax.get_figure().canvas
            if ax_canvas not in self.all_canvas.values():
                self.all_canvas[id(ax_canvas)] = ax_canvas
                ax_canvas.mpl_connect('motion_notify_event', self._onmove)
                ax_canvas.mpl_connect('draw_event', self._clear)
            if orient == 'vertical':
                line = ax.axvline(ax.get_xbound()[0], visible=False,
                                  animated=True, **lineprops)
            else:
                line = ax.axhline(ax.get_ybound()[0], visible=False,
                                  animated=True, **lineprops)
            cursor.append(line)
        # Register once, after all lines exist (as before, only when the
        # cursor actually has lines).
        if cursor:
            self.cursors_orient[cursor] = orient

    def remove(self, name):
        """Forget cursor *name*; its weak ``cursors_orient`` entry goes too."""
        del self.name_cursors[name]
class LanguageTool:
    """
    Main class used for checking text against different rules.
    LanguageTool v2 API documentation:
    https://languagetool.org/http-api/swagger-ui/#!/default/post_check

    Each instance either talks to a remote server (``remote_server``) or
    starts its own local LanguageTool server subprocess, which ``close()``
    (or leaving a ``with`` block) terminates.
    """
    _HOST = socket.gethostbyname('localhost')
    _MIN_PORT = 8081
    _MAX_PORT = 8999
    _TIMEOUT = 5 * 60
    _remote = False
    _port = _MIN_PORT
    _server = None
    _consumer_thread = None
    # Live instances keyed by id(); weak, so it never keeps one alive.
    _instances = WeakValueDictionary()
    _PORT_RE = re.compile(r"(?:https?://.*:|port\s+)(\d+)", re.I)

    def __init__(self, language=None, motherTongue=None,
                 remote_server=None, newSpellings=None,
                 new_spellings_persist=True):
        """
        :param language: language code; defaults to the locale language,
            or ``FAILSAFE_LANGUAGE`` if that cannot be determined.
        :param motherTongue: optional source language for bilingual checks.
        :param remote_server: URL of a remote server; if omitted a local
            server is started.
        :param newSpellings: words appended to the bundled spelling file.
        :param new_spellings_persist: if False, ``close()`` removes the
            words registered via *newSpellings* again.
        """
        self._new_spellings = None
        self._new_spellings_persist = new_spellings_persist
        if newSpellings:
            self._new_spellings = newSpellings
            self._register_spellings(self._new_spellings)
        if remote_server is not None:
            self._remote = True
            self._url = parse_url(remote_server)
            self._url = urllib.parse.urljoin(self._url, 'v2/')
            self._update_remote_server_config(self._url)
        elif not self._server_is_alive():
            self._start_server_on_free_port()
        if language is None:
            try:
                language = get_locale_language()
            except ValueError:
                language = FAILSAFE_LANGUAGE
        self._language = LanguageTag(language, self._get_languages())
        self.motherTongue = motherTongue
        self.disabled_rules = set()
        self.enabled_rules = set()
        self.disabled_categories = set()
        self.enabled_categories = set()
        self.enabled_rules_only = False
        self._instances[id(self)] = self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def __repr__(self):
        return '{}(language={!r}, motherTongue={!r})'.format(
            self.__class__.__name__, self.language, self.motherTongue)

    def close(self):
        """Stop this instance's local server and drop temporary spellings."""
        # ``_instances`` always contains ``self`` while this method runs, so
        # a ``not self._instances`` guard could never be true and the server
        # was never stopped. Each instance starts and owns its own
        # ``self._server`` (see _start_local_server), so stop it directly.
        if self._server_is_alive():
            self._terminate_server()
        if not self._new_spellings_persist and self._new_spellings:
            self._unregister_spellings()
            self._new_spellings = []

    @property
    def language(self):
        """The language to be used."""
        return self._language

    @language.setter
    def language(self, language):
        self._language = LanguageTag(language, self._get_languages())
        # Rule ids are language-specific, so reset the explicit selections.
        self.disabled_rules.clear()
        self.enabled_rules.clear()

    @property
    def motherTongue(self):
        """The user's mother tongue or None.

        The mother tongue may also be used as a source language for
        checking bilingual texts.
        """
        return self._motherTongue

    @motherTongue.setter
    def motherTongue(self, motherTongue):
        self._motherTongue = (None if motherTongue is None else LanguageTag(
            motherTongue, self._get_languages()))

    @property
    def _spell_checking_categories(self):
        return {'TYPOS'}

    def check(self, text: str) -> [Match]:
        """Match text against enabled rules."""
        url = urllib.parse.urljoin(self._url, 'check')
        response = self._query_server(url, self._encode(text))
        matches = response['matches']
        return [Match(match) for match in matches]

    def _encode(self, text):
        """Build the url-encoded POST body for a ``check`` request."""
        params = {'language': self.language, 'text': text.encode('utf-8')}
        if self.motherTongue is not None:
            params['motherTongue'] = self.motherTongue
        if self.disabled_rules:
            params['disabledRules'] = ','.join(self.disabled_rules)
        if self.enabled_rules:
            params['enabledRules'] = ','.join(self.enabled_rules)
        if self.enabled_rules_only:
            params['enabledOnly'] = 'true'
        if self.disabled_categories:
            params['disabledCategories'] = ','.join(self.disabled_categories)
        if self.enabled_categories:
            params['enabledCategories'] = ','.join(self.enabled_categories)
        return urllib.parse.urlencode(params).encode()

    def correct(self, text: str) -> str:
        """Automatically apply suggestions to the text."""
        # ``correct`` here resolves to the module-level function.
        return correct(text, self.check(text))

    def enable_spellchecking(self):
        """Enable spell-checking rules."""
        self.disabled_categories.difference_update(
            self._spell_checking_categories)

    def disable_spellchecking(self):
        """Disable spell-checking rules."""
        self.disabled_categories.update(self._spell_checking_categories)

    @staticmethod
    def _get_valid_spelling_file_path() -> str:
        """Return the bundled English spelling file path.

        :raises FileNotFoundError: if the file does not exist.
        """
        library_path = get_language_tool_directory()
        spelling_file_path = os.path.join(
            library_path, "org/languagetool/resource/en/hunspell/spelling.txt")
        if not os.path.exists(spelling_file_path):
            raise FileNotFoundError(
                "Failed to find the spellings file at {}\n Please file an issue at "
                "https://github.com/jxmorris12/language_tool_python/issues".
                format(spelling_file_path))
        return spelling_file_path

    def _register_spellings(self, spellings):
        """Append *spellings* (one per line) to the spelling file."""
        spelling_file_path = self._get_valid_spelling_file_path()
        with open(spelling_file_path, "a+") as spellings_file:
            spellings_file.write("\n" + "\n".join(spellings))
        if DEBUG_MODE:
            print("Registered new spellings at {}".format(spelling_file_path))

    def _unregister_spellings(self):
        """Remove the lines appended by _register_spellings.

        Walks backwards from the end of the file over one newline per
        registered word, then truncates at the newline that preceded the
        appended block.
        """
        spelling_file_path = self._get_valid_spelling_file_path()
        with open(spelling_file_path, 'r+') as spellings_file:
            spellings_file.seek(0, os.SEEK_END)
            for _ in range(len(self._new_spellings)):
                while spellings_file.read(1) != '\n':
                    spellings_file.seek(spellings_file.tell() - 2,
                                        os.SEEK_SET)
                spellings_file.seek(spellings_file.tell() - 2, os.SEEK_SET)
            spellings_file.seek(spellings_file.tell() + 1, os.SEEK_SET)
            spellings_file.truncate()
        if DEBUG_MODE:
            print(
                "Unregistered new spellings at {}".format(spelling_file_path))

    def _get_languages(self) -> set:
        """Get supported languages (by querying the server)."""
        self._start_server_if_needed()
        url = urllib.parse.urljoin(self._url, 'languages')
        languages = set()
        for e in self._query_server(url, num_tries=1):
            languages.add(e.get('code'))
            languages.add(e.get('longCode'))
        return languages

    def _start_server_if_needed(self):
        # Start server.
        if not self._server_is_alive() and self._remote is False:
            self._start_server_on_free_port()

    def _update_remote_server_config(self, url):
        self._url = url
        self._remote = True

    def _query_server(self, url, data=None, num_tries=2):
        """Fetch *url* (POSTing *data* if given) and decode the JSON reply.

        On I/O or HTTP errors a local server is restarted and the request is
        retried, up to *num_tries* attempts in total.

        :raises LanguageToolError: when the last attempt fails.
        :raises json.decoder.JSONDecodeError: on an invalid JSON response.
        """
        if DEBUG_MODE:
            print('_query_server url:', url, 'data:', data)
        for n in range(num_tries):
            try:
                with urlopen(url, data, self._TIMEOUT) as f:
                    raw_data = f.read().decode('utf-8')
                    try:
                        return json.loads(raw_data)
                    except json.decoder.JSONDecodeError:
                        # The message was previously printed with literal
                        # '{url}'/'{data}' placeholders; format it.
                        print(
                            'URL {url} and data {data} returned invalid JSON response:'
                            .format(url=url, data=data))
                        print(raw_data)
                        raise
            except (IOError, http.client.HTTPException) as e:
                if self._remote is False:
                    self._terminate_server()
                    self._start_local_server()
                if n + 1 >= num_tries:
                    raise LanguageToolError('{}: {}'.format(self._url, e))

    def _start_server_on_free_port(self):
        """Start a local server, trying successive ports on ServerError."""
        while True:
            self._url = 'http://{}:{}/v2/'.format(self._HOST, self._port)
            try:
                self._start_local_server()
                break
            except ServerError:
                if self._MIN_PORT <= self._port < self._MAX_PORT:
                    self._port += 1
                else:
                    raise

    def _start_local_server(self):
        """Spawn the LanguageTool server on ``self._port``.

        :raises PathError: if the LanguageTool installation cannot be found.
        :raises LanguageToolError: if the server reports a different port.
        :raises ServerError: if no server process could be started here.
        """
        # Before starting local server, download language tool if needed.
        download_lt()
        err = None
        try:
            server_cmd = get_server_cmd(self._port)
        except PathError as e:
            # Can't find path to LanguageTool.
            err = e
        else:
            # Need to PIPE all handles: http://bugs.python.org/issue3905
            self._server = subprocess.Popen(server_cmd,
                                            stdin=subprocess.PIPE,
                                            stdout=subprocess.PIPE,
                                            stderr=subprocess.PIPE,
                                            universal_newlines=True,
                                            startupinfo=startupinfo)
            global RUNNING_SERVER_PROCESSES
            RUNNING_SERVER_PROCESSES.append(self._server)

            match = None
            while True:
                line = self._server.stdout.readline()
                if not line:
                    break
                match = self._PORT_RE.search(line)
                if match:
                    port = int(match.group(1))
                    if port != self._port:
                        raise LanguageToolError(
                            'requested port {}, but got {}'.format(
                                self._port, port))
                    break
            if not match:
                err_msg = self._terminate_server()
                match = self._PORT_RE.search(err_msg)
                if not match:
                    raise LanguageToolError(err_msg)
                port = int(match.group(1))
                if port != self._port:
                    raise LanguageToolError(err_msg)

        if err is not None:
            # A missing installation is not port-related. Raising ServerError
            # here would make _start_server_on_free_port retry every port.
            raise err
        if self._server:
            self._consumer_thread = threading.Thread(
                target=lambda: _consume(self._server.stdout))
            self._consumer_thread.daemon = True
            self._consumer_thread.start()
        else:
            # Couldn't start the server, so maybe there is already one running.
            raise ServerError('Server running; don\'t start a server here.')

    def _server_is_alive(self):
        return self._server and self._server.poll() is None

    def _terminate_server(self):
        """Terminate the local server process and close its pipes.

        :returns: the server's stripped stderr output ('' if unavailable or
            if no server was running).
        """
        error_message = ''
        if self._server is None:
            # Nothing to stop (e.g. a previous start attempt failed).
            return error_message
        try:
            self._server.terminate()
        except OSError:
            pass
        try:
            error_message = self._server.communicate()[1].strip()
        except (IOError, ValueError):
            pass
        try:
            self._server.stdout.close()
        except IOError:
            pass
        try:
            self._server.stdin.close()
        except IOError:
            pass
        try:
            self._server.stderr.close()
        except IOError:
            pass
        self._server = None
        return error_message
class Boss: def __init__(self, os_window_id, opts, args, cached_values, new_os_window_trigger): set_layout_options(opts) self.clipboard_buffers = {} self.update_check_process = None self.window_id_map = WeakValueDictionary() self.startup_colors = { k: opts[k] for k in opts if isinstance(opts[k], Color) } self.startup_cursor_text_color = opts.cursor_text_color self.pending_sequences = None self.cached_values = cached_values self.os_window_map = {} self.os_window_death_actions = {} self.cursor_blinking = True self.shutting_down = False talk_fd = getattr(single_instance, 'socket', None) talk_fd = -1 if talk_fd is None else talk_fd.fileno() listen_fd = -1 if args.listen_on and (opts.allow_remote_control in ('y', 'socket-only')): listen_fd = listen_on(args.listen_on) self.child_monitor = ChildMonitor( self.on_child_death, DumpCommands(args) if args.dump_commands or args.dump_bytes else None, talk_fd, listen_fd) set_boss(self) self.opts, self.args = opts, args startup_sessions = create_sessions( opts, args, default_session=opts.startup_session) self.keymap = self.opts.keymap.copy() if new_os_window_trigger is not None: self.keymap.pop(new_os_window_trigger, None) for startup_session in startup_sessions: self.add_os_window(startup_session, os_window_id=os_window_id) os_window_id = None if args.start_as != 'normal': if args.start_as == 'fullscreen': self.toggle_fullscreen() else: change_os_window_state(args.start_as) if is_macos: from .fast_data_types import cocoa_set_notification_activated_callback cocoa_set_notification_activated_callback( self.notification_activated) def add_os_window(self, startup_session=None, os_window_id=None, wclass=None, wname=None, opts_for_size=None, startup_id=None): if os_window_id is None: opts_for_size = opts_for_size or getattr( startup_session, 'os_window_size', None) or self.opts cls = wclass or self.args.cls or appname with startup_notification_handler( do_notify=startup_id is not None, startup_id=startup_id) as pre_show_callback: 
os_window_id = create_os_window( initial_window_size_func(opts_for_size, self.cached_values), pre_show_callback, appname, wname or self.args.name or cls, cls) tm = TabManager(os_window_id, self.opts, self.args, startup_session) self.os_window_map[os_window_id] = tm return os_window_id def list_os_windows(self): with cached_process_data(): active_tab, active_window = self.active_tab, self.active_window active_tab_manager = self.active_tab_manager for os_window_id, tm in self.os_window_map.items(): yield { 'id': os_window_id, 'is_focused': tm is active_tab_manager, 'tabs': list(tm.list_tabs(active_tab, active_window)), } @property def all_tab_managers(self): yield from self.os_window_map.values() @property def all_tabs(self): for tm in self.all_tab_managers: yield from tm @property def all_windows(self): for tab in self.all_tabs: yield from tab def match_windows(self, match): try: field, exp = match.split(':', 1) except ValueError: return if field == 'num': tab = self.active_tab if tab is not None: try: w = tab.get_nth_window(int(exp)) except Exception: return if w is not None: yield w return if field == 'env': kp, vp = exp.partition('=')[::2] if vp: pat = tuple(map(re.compile, (kp, vp))) else: pat = re.compile(kp), None else: pat = re.compile(exp) for window in self.all_windows: if window.matches(field, pat): yield window def tab_for_window(self, window): for tab in self.all_tabs: for w in tab: if w.id == window.id: return tab def match_tabs(self, match): try: field, exp = match.split(':', 1) except ValueError: return pat = re.compile(exp) found = False if field in ('title', 'id'): for tab in self.all_tabs: if tab.matches(field, pat): yield tab found = True if not found: tabs = {self.tab_for_window(w) for w in self.match_windows(match)} for tab in tabs: if tab: yield tab def set_active_window(self, window): for os_window_id, tm in self.os_window_map.items(): for tab in tm: for w in tab: if w.id == window.id: if tab is not self.active_tab: tm.set_active_tab(tab) 
tab.set_active_window(w) return os_window_id def _new_os_window(self, args, cwd_from=None): if isinstance(args, SpecialWindowInstance): sw = args else: sw = self.args_to_special_window(args, cwd_from) if args else None startup_session = next( create_sessions(self.opts, special_window=sw, cwd_from=cwd_from)) return self.add_os_window(startup_session) def new_os_window(self, *args): self._new_os_window(args) @property def active_window_for_cwd(self): w = self.active_window if w is not None and w.overlay_for is not None and w.overlay_for in self.window_id_map: w = self.window_id_map[w.overlay_for] return w def new_os_window_with_cwd(self, *args): w = self.active_window_for_cwd cwd_from = w.child.pid_for_cwd if w is not None else None self._new_os_window(args, cwd_from) def new_os_window_with_wd(self, wd): special_window = SpecialWindow(None, cwd=wd) self._new_os_window(special_window) def add_child(self, window): self.child_monitor.add_child(window.id, window.child.pid, window.child.child_fd, window.screen) self.window_id_map[window.id] = window def _handle_remote_command(self, cmd, window=None, from_peer=False): response = None if self.opts.allow_remote_control == 'y' or from_peer or getattr( window, 'allow_remote_control', False): try: response = handle_cmd(self, window, cmd) except Exception as err: import traceback response = {'ok': False, 'error': str(err)} if not getattr(err, 'hide_traceback', False): response['tb'] = traceback.format_exc() else: response = { 'ok': False, 'error': 'Remote control is disabled. 
Add allow_remote_control to your kitty.conf' } return response def peer_message_received(self, msg): msg = msg.decode('utf-8') cmd_prefix = '\x1bP@kitty-cmd' if msg.startswith(cmd_prefix): cmd = msg[len(cmd_prefix):-2] response = self._handle_remote_command(cmd, from_peer=True) if response is not None: response = (cmd_prefix + json.dumps(response) + '\x1b\\').encode('utf-8') return response else: msg = json.loads(msg) if isinstance(msg, dict) and msg.get('cmd') == 'new_instance': startup_id = msg.get('startup_id') args, rest = parse_args(msg['args'][1:]) args.args = rest opts = create_opts(args) if not os.path.isabs(args.directory): args.directory = os.path.join(msg['cwd'], args.directory) for session in create_sessions(opts, args, respect_cwd=True): os_window_id = self.add_os_window(session, wclass=args.cls, wname=args.name, opts_for_size=opts, startup_id=startup_id) if msg.get('notify_on_os_window_death'): self.os_window_death_actions[os_window_id] = partial( self.notify_on_os_window_death, msg['notify_on_os_window_death']) else: log_error('Unknown message received from peer, ignoring') def handle_remote_cmd(self, cmd, window=None): response = self._handle_remote_command(cmd, window) if response is not None: if window is not None: window.send_cmd_response(response) def _cleanup_tab_after_window_removal(self, src_tab): if len(src_tab) < 1: tm = src_tab.tab_manager_ref() if tm is not None: tm.remove(src_tab) src_tab.destroy() if len(tm) == 0: if not self.shutting_down: mark_os_window_for_close(src_tab.os_window_id) def on_child_death(self, window_id): window = self.window_id_map.pop(window_id, None) if window is None: return if window.action_on_close: try: window.action_on_close(window) except Exception: import traceback traceback.print_exc() os_window_id = window.os_window_id window.destroy() tm = self.os_window_map.get(os_window_id) tab = None if tm is not None: for q in tm: if window in q: tab = q break if tab is not None: tab.remove_window(window) 
self._cleanup_tab_after_window_removal(tab) if window.action_on_removal: try: window.action_on_removal(window) except Exception: import traceback traceback.print_exc() window.action_on_close = window.action_on_removal = None def close_window(self, window=None): if window is None: window = self.active_window self.child_monitor.mark_for_close(window.id) def close_tab(self, tab=None): if tab is None: tab = self.active_tab for window in tab: self.close_window(window) def toggle_fullscreen(self): toggle_fullscreen() def toggle_maximized(self): toggle_maximized() def start(self): if not getattr(self, 'io_thread_started', False): self.child_monitor.start() self.io_thread_started = True if self.opts.update_check_interval > 0 and not hasattr( self, 'update_check_started'): from .update_check import run_update_check run_update_check(self.opts.update_check_interval * 60 * 60) self.update_check_started = True def activate_tab_at(self, os_window_id, x): tm = self.os_window_map.get(os_window_id) if tm is not None: tm.activate_tab_at(x) def on_window_resize(self, os_window_id, w, h, dpi_changed): if dpi_changed: self.on_dpi_change(os_window_id) else: tm = self.os_window_map.get(os_window_id) if tm is not None: tm.resize() def clear_terminal(self, action, only_active): if only_active: windows = [] w = self.active_window if w is not None: windows.append(w) else: windows = self.all_windows reset = action == 'reset' how = 3 if action == 'scrollback' else 2 for w in windows: if action == 'scroll': w.screen.scroll_until_cursor() continue w.screen.cursor.x = w.screen.cursor.y = 0 if reset: w.screen.reset() else: w.screen.erase_in_display(how, False) def increase_font_size(self): # legacy cfs = global_font_size() self.set_font_size(min(self.opts.font_size * 5, cfs + 2.0)) def decrease_font_size(self): # legacy cfs = global_font_size() self.set_font_size(max(MINIMUM_FONT_SIZE, cfs - 2.0)) def restore_font_size(self): # legacy self.set_font_size(self.opts.font_size) def set_font_size(self, 
new_size): # legacy self.change_font_size(True, None, new_size) def change_font_size(self, all_windows, increment_operation, amt): def calc_new_size(old_size): new_size = old_size if amt == 0: new_size = self.opts.font_size else: if increment_operation: new_size += (1 if increment_operation == '+' else -1) * amt else: new_size = amt new_size = max(MINIMUM_FONT_SIZE, min(new_size, self.opts.font_size * 5)) return new_size if all_windows: current_global_size = global_font_size() new_size = calc_new_size(current_global_size) if new_size != current_global_size: global_font_size(new_size) os_windows = tuple(self.os_window_map.keys()) else: os_windows = [] w = self.active_window if w is not None: os_windows.append(w.os_window_id) if os_windows: final_windows = {} for wid in os_windows: current_size = os_window_font_size(wid) if current_size: new_size = calc_new_size(current_size) if new_size != current_size: final_windows[wid] = new_size if final_windows: self._change_font_size(final_windows) def _change_font_size(self, sz_map): for os_window_id, sz in sz_map.items(): tm = self.os_window_map.get(os_window_id) if tm is not None: os_window_font_size(os_window_id, sz) tm.resize() def on_dpi_change(self, os_window_id): tm = self.os_window_map.get(os_window_id) if tm is not None: sz = os_window_font_size(os_window_id) if sz: os_window_font_size(os_window_id, sz, True) tm.resize() def _set_os_window_background_opacity(self, os_window_id, opacity): change_background_opacity(os_window_id, max(0.1, min(opacity, 1.0))) def set_background_opacity(self, opacity): window = self.active_window if window is None or not opacity: return if not self.opts.dynamic_background_opacity: return self.show_error( _('Cannot change background opacity'), _('You must set the dynamic_background_opacity option in kitty.conf to be able to change background opacity' )) os_window_id = window.os_window_id if opacity[0] in '+-': old_opacity = background_opacity_of(os_window_id) if old_opacity is None: return 
opacity = old_opacity + float(opacity) elif opacity == 'default': opacity = self.opts.background_opacity else: opacity = float(opacity) self._set_os_window_background_opacity(os_window_id, opacity) @property def active_tab_manager(self): os_window_id = current_os_window() return self.os_window_map.get(os_window_id) @property def active_tab(self): tm = self.active_tab_manager if tm is not None: return tm.active_tab @property def active_window(self): t = self.active_tab if t is not None: return t.active_window def dispatch_special_key(self, key, native_key, action, mods): # Handles shortcuts, return True if the key was consumed key_action = get_shortcut(self.keymap, mods, key, native_key) if key_action is None: sequences = get_shortcut(self.opts.sequence_map, mods, key, native_key) if sequences: self.pending_sequences = sequences set_in_sequence_mode(True) return True else: self.current_key_press_info = key, native_key, action, mods return self.dispatch_action(key_action) def process_sequence(self, key, native_key, action, mods): if not self.pending_sequences: set_in_sequence_mode(False) remaining = {} matched_action = None for seq, key_action in self.pending_sequences.items(): if shortcut_matches(seq[0], mods, key, native_key): seq = seq[1:] if seq: remaining[seq] = key_action else: matched_action = key_action if remaining: self.pending_sequences = remaining else: self.pending_sequences = None set_in_sequence_mode(False) if matched_action is not None: self.dispatch_action(matched_action) def start_resizing_window(self): w = self.active_window if w is None: return overlay_window = self._run_kitten( 'resize_window', args=[ '--horizontal-increment={}'.format( self.opts.window_resize_step_cells), '--vertical-increment={}'.format( self.opts.window_resize_step_lines) ]) if overlay_window is not None: overlay_window.allow_remote_control = True def resize_layout_window(self, window, increment, is_horizontal, reset=False): tab = window.tabref() if tab is None or not 
increment: return False if reset: return tab.reset_window_sizes() return tab.resize_window_by(window.id, increment, is_horizontal) def default_bg_changed_for(self, window_id): w = self.window_id_map.get(window_id) if w is not None: tm = self.os_window_map.get(w.os_window_id) if tm is not None: tm.update_tab_bar_data() tm.mark_tab_bar_dirty() t = tm.tab_for_id(w.tab_id) if t is not None: t.relayout_borders() def dispatch_action(self, key_action): if key_action is not None: f = getattr(self, key_action.func, None) if f is not None: if self.args.debug_keyboard: print('Keypress matched action:', func_name(f)) passthrough = f(*key_action.args) if passthrough is not True: return True tab = self.active_tab if tab is None: return False window = self.active_window if window is None: return False if key_action is not None: f = getattr(tab, key_action.func, getattr(window, key_action.func, None)) if f is not None: passthrough = f(*key_action.args) if self.args.debug_keyboard: print('Keypress matched action:', func_name(f)) if passthrough is not True: return True return False def combine(self, *actions): for key_action in actions: self.dispatch_action(key_action) def on_focus(self, os_window_id, focused): tm = self.os_window_map.get(os_window_id) if tm is not None: w = tm.active_window if w is not None: w.focus_changed(focused) tm.mark_tab_bar_dirty() def update_tab_bar_data(self, os_window_id): tm = self.os_window_map.get(os_window_id) if tm is not None: tm.update_tab_bar_data() def on_drop(self, os_window_id, strings): tm = self.os_window_map.get(os_window_id) if tm is not None: w = tm.active_window if w is not None: w.paste('\n'.join(strings)) def on_os_window_closed(self, os_window_id, viewport_width, viewport_height): self.cached_values['window-size'] = viewport_width, viewport_height tm = self.os_window_map.pop(os_window_id, None) if tm is not None: tm.destroy() for window_id in tuple( w.id for w in self.window_id_map.values() if getattr(w, 'os_window_id', None) == 
os_window_id): self.window_id_map.pop(window_id, None) action = self.os_window_death_actions.pop(os_window_id, None) if action is not None: action() def notify_on_os_window_death(self, address): import socket s = socket.socket(family=socket.AF_UNIX) with suppress(Exception): s.connect(address) s.sendall(b'c') with suppress(EnvironmentError): s.shutdown(socket.SHUT_RDWR) s.close() def display_scrollback(self, window, data, cmd): tab = self.active_tab if tab is not None and window.overlay_for is None: tab.new_special_window( SpecialWindow(cmd, data, _('History'), overlay_for=window.id)) def edit_config_file(self, *a): confpath = prepare_config_file_for_editing() # On macOS vim fails to handle SIGWINCH if it occurs early, so add a # small delay. cmd = [ kitty_exe(), '+runpy', 'import os, sys, time; time.sleep(0.05); os.execvp(sys.argv[1], sys.argv[1:])' ] + get_editor() + [confpath] self.new_os_window(*cmd) def get_output(self, source_window, num_lines=1): output = '' s = source_window.screen if num_lines is None: num_lines = s.lines for i in range(min(num_lines, s.lines)): output += str(s.linebuf.line(i)) return output def _run_kitten(self, kitten, args=(), input_data=None, window=None, custom_callback=None, action_on_removal=None): orig_args, args = list(args), list(args) from kittens.runner import create_kitten_handler end_kitten = create_kitten_handler(kitten, orig_args) if window is None: w = self.active_window tab = self.active_tab else: w = window tab = w.tabref() if end_kitten.no_ui: end_kitten(None, getattr(w, 'id', None), self) return if w is not None and tab is not None and w.overlay_for is None: args[0:0] = [config_dir, kitten] if input_data is None: type_of_input = end_kitten.type_of_input if type_of_input in ('text', 'history', 'ansi', 'ansi-history', 'screen', 'screen-history', 'screen-ansi', 'screen-ansi-history'): data = w.as_text(as_ansi='ansi' in type_of_input, add_history='history' in type_of_input, add_wrap_markers='screen' in 
type_of_input).encode('utf-8') elif type_of_input is None: data = None else: raise ValueError( 'Unknown type_of_input: {}'.format(type_of_input)) else: data = input_data if isinstance(data, str): data = data.encode('utf-8') copts = { k: self.opts[k] for k in ('select_by_word_characters', 'open_url_with') } overlay_window = tab.new_special_window(SpecialWindow( [ kitty_exe(), '+runpy', 'from kittens.runner import main; main()' ] + args, stdin=data, env={ 'KITTY_COMMON_OPTS': json.dumps(copts), 'KITTY_CHILD_PID': w.child.pid, 'PYTHONWARNINGS': 'ignore', 'OVERLAID_WINDOW_LINES': str(w.screen.lines), 'OVERLAID_WINDOW_COLS': str(w.screen.columns), }, cwd=w.cwd_of_child, overlay_for=w.id), copy_colors_from=w) wid = w.id overlay_window.action_on_close = partial( self.on_kitten_finish, wid, custom_callback or end_kitten) if action_on_removal is not None: overlay_window.action_on_removal = lambda *a: action_on_removal( wid, self) return overlay_window def kitten(self, kitten, *args): import shlex cmdline = args[0] if args else '' args = shlex.split(cmdline) if cmdline else [] self._run_kitten(kitten, args) def on_kitten_finish(self, target_window_id, end_kitten, source_window): output = self.get_output(source_window, num_lines=None) from kittens.runner import deserialize data = deserialize(output) if data is not None: end_kitten(data, target_window_id, self) def input_unicode_character(self): self._run_kitten('unicode_input') def set_tab_title(self): tab = self.active_tab if tab: args = [ '--name=tab-title', '--message', _('Enter the new title for this tab below.'), 'do_set_tab_title', str(tab.id) ] self._run_kitten('ask', args) def show_error(self, title, msg): self._run_kitten('show_error', args=['--title', title], input_data=msg) def do_set_tab_title(self, title, tab_id): tm = self.active_tab_manager if tm is not None and title: tab_id = int(tab_id) for tab in tm.tabs: if tab.id == tab_id: tab.set_title(title) break def kitty_shell(self, window_type): cmd = ['@', 
kitty_exe(), '@'] if window_type == 'tab': self._new_tab(cmd) elif window_type == 'os_window': os_window_id = self._new_os_window(cmd) self.os_window_map[os_window_id] elif window_type == 'overlay': w = self.active_window tab = self.active_tab if w is not None and tab is not None and w.overlay_for is None: tab.new_special_window(SpecialWindow(cmd, overlay_for=w.id)) else: self._new_window(cmd) def switch_focus_to(self, window_idx): tab = self.active_tab tab.set_active_window_idx(window_idx) def open_url(self, url, program=None, cwd=None): if url: if isinstance(program, str): program = to_cmdline(program) open_url(url, program or self.opts.open_url_with, cwd=cwd) def open_url_lines(self, lines, program=None): self.open_url(''.join(lines), program) def destroy(self): self.shutting_down = True self.child_monitor.shutdown_monitor() self.set_update_check_process() self.update_check_process = None del self.child_monitor for tm in self.os_window_map.values(): tm.destroy() self.os_window_map = {} destroy_global_data() def paste_to_active_window(self, text): if text: w = self.active_window if w is not None: w.paste(text) def paste_from_clipboard(self): text = get_clipboard_string() self.paste_to_active_window(text) def paste_from_selection(self): text = get_primary_selection( ) if supports_primary_selection else get_clipboard_string() self.paste_to_active_window(text) def set_primary_selection(self): w = self.active_window if w is not None and not w.destroyed: text = w.text_for_selection() if text: set_primary_selection(text) if self.opts.copy_on_select: self.copy_to_buffer(self.opts.copy_on_select) def copy_to_buffer(self, buffer_name): w = self.active_window if w is not None and not w.destroyed: text = w.text_for_selection() if text: if buffer_name == 'clipboard': set_clipboard_string(text) elif buffer_name == 'primary': set_primary_selection(text) else: self.clipboard_buffers[buffer_name] = text def paste_from_buffer(self, buffer_name): if buffer_name == 'clipboard': 
text = get_clipboard_string() elif buffer_name == 'primary': text = get_primary_selection() else: text = self.clipboard_buffers.get(buffer_name) if text: self.paste_to_active_window(text) def goto_tab(self, tab_num): tm = self.active_tab_manager if tm is not None: tm.goto_tab(tab_num - 1) def set_active_tab(self, tab): tm = self.active_tab_manager if tm is not None: return tm.set_active_tab(tab) return False def next_tab(self): tm = self.active_tab_manager if tm is not None: tm.next_tab() def previous_tab(self): tm = self.active_tab_manager if tm is not None: tm.next_tab(-1) prev_tab = previous_tab def process_stdin_source(self, window=None, stdin=None): w = window or self.active_window env = None if stdin: add_wrap_markers = stdin.endswith('_wrap') if add_wrap_markers: stdin = stdin[:-len('_wrap')] stdin = data_for_at(w, stdin, add_wrap_markers=add_wrap_markers) if stdin is not None: pipe_data = w.pipe_data( stdin, has_wrap_markers=add_wrap_markers) if w else {} if pipe_data: env = { 'KITTY_PIPE_DATA': '{scrolled_by}:{cursor_x},{cursor_y}:{lines},{columns}' .format(**pipe_data) } stdin = stdin.encode('utf-8') return env, stdin def special_window_for_cmd(self, cmd, window=None, stdin=None, cwd_from=None, as_overlay=False): w = window or self.active_window env, stdin = self.process_stdin_source(w, stdin) cmdline = [] for arg in cmd: if arg == '@selection': arg = data_for_at(w, arg) if not arg: continue cmdline.append(arg) overlay_for = w.id if as_overlay and w.overlay_for is None else None return SpecialWindow(cmd, stdin, cwd_from=cwd_from, overlay_for=overlay_for, env=env) def pipe(self, source, dest, exe, *args): cmd = [exe] + list(args) window = self.active_window cwd_from = window.child.pid_for_cwd if window else None def create_window(): return self.special_window_for_cmd(cmd, stdin=source, as_overlay=dest == 'overlay', cwd_from=cwd_from) if dest == 'overlay' or dest == 'window': tab = self.active_tab if tab is not None: return 
tab.new_special_window(create_window()) elif dest == 'tab': tm = self.active_tab_manager if tm is not None: tm.new_tab(special_window=create_window(), cwd_from=cwd_from) elif dest == 'os_window': self._new_os_window(create_window(), cwd_from=cwd_from) elif dest in ('clipboard', 'primary'): env, stdin = self.process_stdin_source(stdin=source, window=window) if stdin: func = set_clipboard_string if dest == 'clipboard' else set_primary_selection func(stdin) else: import subprocess env, stdin = self.process_stdin_source(stdin=source, window=window) cwd = None if cwd_from: with suppress(Exception): cwd = cwd_of_process(cwd_from) if stdin: r, w = safe_pipe(False) try: subprocess.Popen(cmd, env=env, stdin=r, cwd=cwd) except Exception: os.close(w) else: thread_write(w, stdin) finally: os.close(r) else: subprocess.Popen(cmd, env=env, cwd=cwd) def args_to_special_window(self, args, cwd_from=None): args = list(args) stdin = None w = self.active_window if args[0].startswith('@') and args[0] != '@': stdin = data_for_at(w, args[0]) or None if stdin is not None: stdin = stdin.encode('utf-8') del args[0] cmd = [] for arg in args: if arg == '@selection': arg = data_for_at(w, arg) if not arg: continue cmd.append(arg) return SpecialWindow(cmd, stdin, cwd_from=cwd_from) def _new_tab(self, args, cwd_from=None, as_neighbor=False): special_window = None if args: if isinstance(args, SpecialWindowInstance): special_window = args else: special_window = self.args_to_special_window(args, cwd_from=cwd_from) tm = self.active_tab_manager if tm is not None: return tm.new_tab(special_window=special_window, cwd_from=cwd_from, as_neighbor=as_neighbor) def _create_tab(self, args, cwd_from=None): as_neighbor = False if args and args[0].startswith('!'): as_neighbor = 'neighbor' in args[0][1:].split(',') args = args[1:] self._new_tab(args, as_neighbor=as_neighbor, cwd_from=cwd_from) def new_tab(self, *args): self._create_tab(args) def new_tab_with_cwd(self, *args): w = self.active_window_for_cwd 
cwd_from = w.child.pid_for_cwd if w is not None else None self._create_tab(args, cwd_from=cwd_from) def new_tab_with_wd(self, wd): special_window = SpecialWindow(None, cwd=wd) self._new_tab(special_window) def _new_window(self, args, cwd_from=None): tab = self.active_tab if tab is not None: location = None if args and args[0].startswith('!'): location = args[0][1:].lower() args = args[1:] if args: return tab.new_special_window(self.args_to_special_window( args, cwd_from=cwd_from), location=location) else: return tab.new_window(cwd_from=cwd_from, location=location) def new_window(self, *args): self._new_window(args) def new_window_with_cwd(self, *args): w = self.active_window_for_cwd if w is None: return self.new_window(*args) cwd_from = w.child.pid_for_cwd if w is not None else None self._new_window(args, cwd_from=cwd_from) def move_tab_forward(self): tm = self.active_tab_manager if tm is not None: tm.move_tab(1) def move_tab_backward(self): tm = self.active_tab_manager if tm is not None: tm.move_tab(-1) def disable_ligatures_in(self, where, strategy): if isinstance(where, str): windows = () if where == 'active': if self.active_window is not None: windows = (self.active_window, ) elif where == 'all': windows = self.all_windows elif where == 'tab': if self.active_tab is not None: windows = tuple(self.active_tab) else: windows = where for window in windows: window.screen.disable_ligatures = strategy window.refresh() def patch_colors(self, spec, cursor_text_color, configured=False): if configured: for k, v in spec.items(): if hasattr(self.opts, k): setattr(self.opts, k, color_from_int(v)) if cursor_text_color is not False: if isinstance(cursor_text_color, int): cursor_text_color = color_from_int(cursor_text_color) self.opts.cursor_text_color = cursor_text_color for tm in self.all_tab_managers: tm.tab_bar.patch_colors(spec) patch_global_colors(spec, configured) def safe_delete_temp_file(self, path): if is_path_in_temp_dir(path): with suppress(FileNotFoundError): 
os.remove(path) def set_update_check_process(self, process=None): if self.update_check_process is not None: with suppress(Exception): if self.update_check_process.poll() is None: self.update_check_process.kill() self.update_check_process = process def on_monitored_pid_death(self, pid, exit_status): update_check_process = getattr(self, 'update_check_process', None) if update_check_process is not None and pid == update_check_process.pid: self.update_check_process = None from .update_check import process_current_release try: raw = update_check_process.stdout.read().decode('utf-8') except Exception as e: log_error( 'Failed to read data from update check process, with error: {}' .format(e)) else: try: process_current_release(raw) except Exception as e: log_error( 'Failed to process update check data {!r}, with error: {}' .format(raw, e)) def notification_activated(self, identifier): if identifier == 'new-version': from .update_check import notification_activated notification_activated() def dbus_notification_callback(self, activated, *args): from .notify import dbus_notification_created, dbus_notification_activated if activated: dbus_notification_activated(*args) else: dbus_notification_created(*args) def show_bad_config_lines(self, bad_lines): def format_bad_line(bad_line): return '{}:{} in line: {}\n'.format(bad_line.number, bad_line.exception, bad_line.line) msg = '\n'.join(map(format_bad_line, bad_lines)).rstrip() self.show_error(_('Errors in kitty.conf'), msg) def set_colors(self, *args): from .cmds import parse_subcommand_cli, cmd_set_colors, set_colors opts, items = parse_subcommand_cli(cmd_set_colors, ['set-colors'] + list(args)) payload = cmd_set_colors(None, opts, items) set_colors(self, self.active_window, payload) def _move_window_to(self, window=None, target_tab_id=None, target_os_window_id=None): window = window or self.active_window if not window: return src_tab = self.tab_for_window(window) if src_tab is None: return if target_os_window_id == 'new': 
target_os_window_id = self.add_os_window() tm = self.os_window_map[target_os_window_id] target_tab = tm.new_tab(empty_tab=True) else: target_os_window_id = target_os_window_id or current_os_window() if target_tab_id == 'new': tm = self.os_window_map[target_os_window_id] target_tab = tm.new_tab(empty_tab=True) else: for tab in self.all_tabs: if tab.id == target_tab_id: target_tab = tab target_os_window_id = tab.os_window_id break else: return underlaid_window, overlaid_window = src_tab.detach_window(window) if underlaid_window: target_tab.attach_window(underlaid_window) if overlaid_window: target_tab.attach_window(overlaid_window) self._cleanup_tab_after_window_removal(src_tab) target_tab.make_active() def _move_tab_to(self, tab=None, target_os_window_id=None): tab = tab or self.active_tab if tab is None: return if target_os_window_id is None: target_os_window_id = self.add_os_window() tm = self.os_window_map[target_os_window_id] target_tab = tm.new_tab(empty_tab=True) target_tab.take_over_from(tab) self._cleanup_tab_after_window_removal(tab) target_tab.make_active() def detach_window(self, *args): if not args or args[0] == 'new': return self._move_window_to(target_os_window_id='new') if args[0] == 'new-tab': return self._move_window_to(target_tab_id='new') lines = ['Choose a tab to move the window to', ''] tab_id_map = {} current_tab = self.active_tab for i, tab in enumerate(self.all_tabs): if tab is not current_tab: tab_id_map[i + 1] = tab.id lines.append('{} {}'.format(i + 1, tab.title)) new_idx = len(tab_id_map) + 1 tab_id_map[new_idx] = 'new' lines.append('{} {}'.format(new_idx, 'New tab')) new_idx = len(tab_id_map) + 1 tab_id_map[new_idx] = None lines.append('{} {}'.format(new_idx, 'New OS Window')) def done(data, target_window_id, self): done.tab_id = tab_id_map[int(data['match'][0].partition(' ')[0])] def done2(target_window_id, self): tab_id = done.tab_id target_window = None for w in self.all_windows: if w.id == target_window_id: target_window = w break if 
tab_id is None: self._move_window_to(window=target_window, target_os_window_id='new') else: self._move_window_to(window=target_window, target_tab_id=tab_id) self._run_kitten('hints', args=( '--type=regex', r'--regex=(?m)^\d+ .+$', ), input_data='\r\n'.join(lines).encode('utf-8'), custom_callback=done, action_on_removal=done2) def detach_tab(self, *args): if not args or args[0] == 'new': return self._move_tab_to() lines = ['Choose an OS window to move the tab to', ''] os_window_id_map = {} current_os_window = getattr(self.active_tab, 'os_window_id', 0) for i, osw in enumerate(self.os_window_map): tm = self.os_window_map[osw] if current_os_window != osw and tm.active_tab and tm.active_tab: os_window_id_map[i + 1] = osw lines.append('{} {}'.format(i + 1, tm.active_tab.title)) new_idx = len(os_window_id_map) + 1 os_window_id_map[new_idx] = None lines.append('{} {}'.format(new_idx, 'New OS Window')) def done(data, target_window_id, self): done.os_window_id = os_window_id_map[int( data['match'][0].partition(' ')[0])] def done2(target_window_id, self): os_window_id = done.os_window_id target_tab = self.active_tab for w in self.all_windows: if w.id == target_window_id: target_tab = w.tabref() break if target_tab and target_tab.os_window_id == os_window_id: return self._move_tab_to(tab=target_tab, target_os_window_id=os_window_id) self._run_kitten('hints', args=( '--type=regex', r'--regex=(?m)^\d+ .+$', ), input_data='\r\n'.join(lines).encode('utf-8'), custom_callback=done, action_on_removal=done2)
class Registry(Mapping):
    """ Model registry for a particular database.

    The registry is essentially a mapping between model names and model classes.
    There is one registry instance per database.

    Registry instances are cached per database name in :attr:`registries`, and
    creation/deletion of registries is serialized through the class-level
    ``_lock``.
    """
    # Class-wide re-entrant lock guarding ``registries``. While in test mode it
    # is swapped for a ``DummyRLock`` (see enter_test_mode/leave_test_mode).
    _lock = threading.RLock()
    # Holds the real ``_lock`` while test mode has replaced it; None otherwise.
    _saved_lock = None

    # a cache for model classes, indexed by their base classes
    model_cache = WeakValueDictionary()

    @lazy_classproperty
    def registries(cls):
        """ A mapping from database names to registries.

        The mapping is an ``LRU`` whose capacity is ``registry_lru_size`` from
        the config when that is set; otherwise 42 on non-POSIX systems, or
        ``limit_memory_soft`` divided by 15 MiB on POSIX systems.
        """
        size = config.get('registry_lru_size', None)
        if not size:
            # Size the LRU depending of the memory limits
            if os.name != 'posix':
                # cannot specify the memory limit soft on windows...
                size = 42
            else:
                # A registry takes 10MB of memory on average, so we reserve
                # 10Mb (registry) + 5Mb (working memory) per registry
                avgsz = 15 * 1024 * 1024
                size = int(config['limit_memory_soft'] / avgsz)
        return LRU(size)

    def __new__(cls, db_name):
        """ Return the registry for the given database name.

        The cached registry is returned when one exists; otherwise a new one
        is built with :meth:`new`. In every case (including when building
        raises) the current thread's ``dbname`` attribute is set to *db_name*.
        """
        with cls._lock:
            try:
                return cls.registries[db_name]
            except KeyError:
                return cls.new(db_name)
            finally:
                # set db tracker - cleaned up at the WSGI dispatching phase in
                # odoo.service.wsgi_server.application
                threading.current_thread().dbname = db_name

    @classmethod
    def new(cls, db_name, force_demo=False, status=None, update_module=False):
        """ Create and return a new registry for the given database name.

        :param db_name: name of the database the registry is bound to
        :param force_demo: forwarded to ``odoo.modules.load_modules``
        :param status: forwarded to ``odoo.modules.load_modules``
        :param update_module: forwarded to ``odoo.modules.load_modules``; when
            truthy the returned registry is also flagged as invalidated
        :raises Exception: any error from signaling setup or module loading is
            logged, the registry is removed from :attr:`registries`, and the
            error is re-raised
        """
        with cls._lock:
            with odoo.api.Environment.manage():
                # Bypass __new__ (which would look up the cache) and initialize
                # the raw instance through init() instead of __init__.
                registry = object.__new__(cls)
                registry.init(db_name)

                # Initializing a registry will call general code which will in
                # turn call Registry() to obtain the registry being initialized.
                # Make it available in the registries dictionary then remove it
                # if an exception is raised.
                cls.delete(db_name)
                cls.registries[db_name] = registry
                try:
                    registry.setup_signaling()
                    # This should be a method on Registry
                    try:
                        odoo.modules.load_modules(registry._db, force_demo, status, update_module)
                    except Exception:
                        odoo.modules.reset_modules_state(db_name)
                        raise
                except Exception:
                    _logger.exception('Failed to load registry')
                    del cls.registries[db_name]
                    raise

                # load_modules() above can replace the registry by calling
                # indirectly new() again (when modules have to be uninstalled).
                # Yeah, crazy.
                # Carry over the parent-store work collected on the first
                # instance to whichever registry is now cached.
                init_parent = registry._init_parent
                registry = cls.registries[db_name]
                registry._init_parent.update(init_parent)

                with closing(registry.cursor()) as cr:
                    registry.do_parent_store(cr)
                    cr.commit()

        registry.ready = True
        registry.registry_invalidated = bool(update_module)
        return registry

    def init(self, db_name):
        """ Initialize the state of a freshly allocated registry.

        Called by :meth:`new` on an instance created with ``object.__new__``.
        Connects to *db_name* and opens a cursor once to detect whether the
        database provides ``unaccent()``.
        """
        self.models = {}    # model name/model instance mapping
        self._sql_error = {}
        self._init = True
        # names of models whose parent store must be computed by do_parent_store()
        self._init_parent = set()
        self._assertion_report = assertion_report.assertion_report()
        self._fields_by_model = None
        # callables queued by post_init(), drained by init_models()
        self._post_init_queue = deque()

        # modules fully loaded (maintained during init phase by `loading` module)
        self._init_modules = set()
        self.updated_modules = []       # installed/updated modules

        self.db_name = db_name
        self._db = odoo.sql_db.db_connect(db_name)

        # cursor for test mode; None means "normal" mode
        self.test_cr = None
        self.test_lock = None

        # Readiness state of the registry:
        self.loaded = False             # whether all modules are loaded
        self.ready = False              # whether everything is set up

        # Inter-process signaling (used only when odoo.multi_process is True):
        # The `base_registry_signaling` sequence indicates the whole registry
        # must be reloaded.
        # The `base_cache_signaling sequence` indicates all caches must be
        # invalidated (i.e. cleared).
        self.registry_sequence = None
        self.cache_sequence = None

        # Flags indicating invalidation of the registry or the cache.
        self.registry_invalidated = False
        self.cache_invalidated = False

        with closing(self.cursor()) as cr:
            has_unaccent = odoo.modules.db.has_unaccent(cr)
            if odoo.tools.config['unaccent'] and not has_unaccent:
                _logger.warning("The option --unaccent was given but no unaccent() function was found in database.")
            # unaccent is used only if both requested and available
            self.has_unaccent = odoo.tools.config['unaccent'] and has_unaccent

    @classmethod
    def delete(cls, db_name):
        """ Delete the registry linked to a given database.

        The removed registry also has its model caches cleared and is flagged
        as invalidated. Unknown database names are ignored.
        """
        with cls._lock:
            if db_name in cls.registries:
                registry = cls.registries.pop(db_name)
                registry.clear_caches()
                registry.registry_invalidated = True

    @classmethod
    def delete_all(cls):
        """ Delete all the registries. """
        with cls._lock:
            # iterate over a snapshot: delete() mutates cls.registries
            for db_name in list(cls.registries.keys()):
                cls.delete(db_name)

    #
    # Mapping abstract methods implementation
    # => mixin provides methods keys, items, values, get, __eq__, and __ne__
    #
    def __len__(self):
        """ Return the size of the registry. """
        return len(self.models)

    def __iter__(self):
        """ Return an iterator over all model names. """
        return iter(self.models)

    def __getitem__(self, model_name):
        """ Return the model with the given name or raise KeyError if it doesn't exist."""
        return self.models[model_name]

    def __call__(self, model_name):
        """ Same as ``self[model_name]``. """
        return self.models[model_name]

    def __setitem__(self, model_name, model):
        """ Add or replace a model in the registry."""
        self.models[model_name] = model

    @lazy_property
    def field_sequence(self):
        """ Return a function mapping a field to an integer. The value of a
            field is guaranteed to be strictly greater than the value of the
            field's dependencies.

            The returned function is ``dict.get`` on the computed mapping, so
            it returns None for a field that is not in the mapping.
        """
        # map fields on their dependents
        dependents = {
            field: set(dep for dep, _ in model._field_triggers[field] if dep != field)
            for model in self.values()
            for field in model._fields.values()
        }
        # sort them topologically, and associate a sequence number to each field
        mapping = {
            field: num
            for num, field in enumerate(reversed(topological_sort(dependents)))
        }
        return mapping.get

    def do_parent_store(self, cr):
        """ Run ``_parent_store_compute()`` on every model recorded in
        ``_init_parent`` that exists in the environment, then clear the
        ``_init`` flag.
        """
        env = odoo.api.Environment(cr, SUPERUSER_ID, {})
        for model_name in self._init_parent:
            if model_name in env:
                env[model_name]._parent_store_compute()
        self._init = False

    def descendants(self, model_names, *kinds):
        """ Return the models corresponding to ``model_names`` and all those
        that inherit/inherits from them.

        :param kinds: any of ``'_inherit'`` and ``'_inherits'``; selects which
            ``<kind>_children`` attributes are followed
        :return: an ``OrderedSet`` of model names, in breadth-first order
        """
        assert all(kind in ('_inherit', '_inherits') for kind in kinds)
        funcs = [attrgetter(kind + '_children') for kind in kinds]

        models = OrderedSet()
        # breadth-first walk; queue items are used as keys into self
        queue = deque(model_names)
        while queue:
            model = self[queue.popleft()]
            models.add(model._name)
            for func in funcs:
                queue.extend(func(model))
        return models

    def load(self, cr, module):
        """ Load a given module in the registry, and return the names of the
        modified models.

        At the Python level, the modules are already loaded, but not yet on a
        per-registry level. This method populates a registry with the given
        modules, i.e. it instantiates all the classes of the given module
        and registers them in the registry.

        """
        from .. import models

        # drop every lazy_property value computed on this registry
        lazy_property.reset_all(self)

        # Instantiate registered classes (via the MetaModel automatic discovery
        # or via explicit constructor call), and add them to the pool.
        model_names = []
        for cls in models.MetaModel.module_to_models.get(module.name, []):
            # models register themselves in self.models
            model = cls._build_model(self, cr)
            model_names.append(model._name)

        return self.descendants(model_names, '_inherit', '_inherits')

    def setup_models(self, cr):
        """ Complete the setup of models.
            This must be called after loading modules and before using the ORM.

            Marks the registry as invalidated once the setup is done.
        """
        lazy_property.reset_all(self)
        env = odoo.api.Environment(cr, SUPERUSER_ID, {})

        # add manual models
        if self._init_modules:
            env['ir.model']._add_manual_models()

        # prepare the setup on all models
        models = list(env.values())
        for model in models:
            model._prepare_setup()

        # do the actual setup from a clean state
        # (each phase runs on all models before the next phase starts)
        self._m2m = {}
        for model in models:
            model._setup_base()

        for model in models:
            model._setup_fields()

        for model in models:
            model._setup_complete()

        self.registry_invalidated = True

    def post_init(self, func, *args, **kwargs):
        """ Register a function to call at the end of :meth:`~.init_models`. """
        self._post_init_queue.append(partial(func, *args, **kwargs))

    def init_models(self, cr, model_names, context):
        """ Initialize a list of models (given by their name). Call methods
            ``_auto_init`` and ``init`` on each model to create or update the
            database tables supporting the models.

            The ``context`` may contain the following items:
             - ``module``: the name of the module being installed/updated, if any;
             - ``update_custom_fields``: whether custom fields should be updated.

            Afterwards, runs the callables queued by :meth:`post_init`, and
            checks that every non-abstract model has a table; missing tables
            are recreated in dependency order and still-missing ones are
            logged as errors.
        """
        if 'module' in context:
            _logger.info('module %s: creating or updating database tables', context['module'])

        env = odoo.api.Environment(cr, SUPERUSER_ID, context)
        models = [env[model_name] for model_name in model_names]

        for model in models:
            model._auto_init()
            model.init()

        while self._post_init_queue:
            func = self._post_init_queue.popleft()
            func()

        if models:
            models[0].recompute()

        # make sure all tables are present
        table2model = {model._table: name for name, model in env.items() if not model._abstract}
        missing_tables = set(table2model).difference(existing_tables(cr, table2model))
        if missing_tables:
            missing = {table2model[table] for table in missing_tables}
            _logger.warning("Models have no table: %s.", ", ".join(missing))
            # recreate missing tables following model dependencies
            deps = {name: model._depends for name, model in env.items()}
            for name in topological_sort(deps):
                if name in missing:
                    _logger.info("Recreate table of model %s.", name)
                    env[name].init()
            # check again, and log errors if tables are still missing
            missing_tables = set(table2model).difference(existing_tables(cr, table2model))
            for table in missing_tables:
                _logger.error("Model %s has no table.", table2model[table])

    @lazy_property
    def cache(self):
        """ A cache for model methods. """
        # this lazy_property is automatically reset by lazy_property.reset_all()
        return LRU(8192)

    def _clear_cache(self):
        """ Clear the cache and mark it as invalidated. """
        self.cache.clear()
        self.cache_invalidated = True

    def clear_caches(self):
        """ Clear the caches associated to methods decorated with
        ``tools.ormcache`` or ``tools.ormcache_multi`` for all the models.
        """
        for model in self.models.values():
            model.clear_caches()

    def setup_signaling(self):
        """ Setup the inter-process signaling on this registry.

        No-op unless ``odoo.multi_process`` is set. Creates the two signaling
        sequences if they do not exist yet, then records their current values.
        """
        if not odoo.multi_process:
            return

        # The cursor is used directly as a context manager (not via closing())
        # so that, per cursor()'s docstring, the sequence creation is committed.
        with self.cursor() as cr:
            # The `base_registry_signaling` sequence indicates when the registry
            # must be reloaded.
            # The `base_cache_signaling` sequence indicates when all caches must
            # be invalidated (i.e. cleared).
            cr.execute("SELECT sequence_name FROM information_schema.sequences WHERE sequence_name='base_registry_signaling'")
            if not cr.fetchall():
                cr.execute("CREATE SEQUENCE base_registry_signaling INCREMENT BY 1 START WITH 1")
                cr.execute("SELECT nextval('base_registry_signaling')")
                cr.execute("CREATE SEQUENCE base_cache_signaling INCREMENT BY 1 START WITH 1")
                cr.execute("SELECT nextval('base_cache_signaling')")

            cr.execute(""" SELECT base_registry_signaling.last_value, base_cache_signaling.last_value FROM base_registry_signaling, base_cache_signaling""")
            self.registry_sequence, self.cache_sequence = cr.fetchone()
            _logger.debug("Multiprocess load registry signaling: [Registry: %s] [Cache: %s]",
                          self.registry_sequence, self.cache_sequence)

    def check_signaling(self):
        """ Check whether the registry has changed, and performs all necessary
        operations to update the registry. Return an up-to-date registry.

        When the registry sequence moved, a brand-new registry is built and
        returned instead of ``self``; callers must use the return value.
        """
        if not odoo.multi_process:
            return self

        with closing(self.cursor()) as cr:
            cr.execute(""" SELECT base_registry_signaling.last_value, base_cache_signaling.last_value FROM base_registry_signaling, base_cache_signaling""")
            r, c = cr.fetchone()
            _logger.debug("Multiprocess signaling check: [Registry - %s -> %s] [Cache - %s -> %s]",
                          self.registry_sequence, r, self.cache_sequence, c)
            # Check if the model registry must be reloaded
            if self.registry_sequence != r:
                _logger.info("Reloading the model registry after database signaling.")
                # rebinding ``self``: the updates below apply to the new registry
                self = Registry.new(self.db_name)
            # Check if the model caches must be invalidated.
            elif self.cache_sequence != c:
                _logger.info("Invalidating all model caches after database signaling.")
                self.clear_caches()
                # the change came from another process: nothing to re-signal
                self.cache_invalidated = False
            self.registry_sequence = r
            self.cache_sequence = c

        return self

    def signal_changes(self):
        """ Notifies other processes if registry or cache has been invalidated.

        Bumps the matching database sequence (only when ``odoo.multi_process``
        is set) and always resets both invalidation flags.
        """
        if odoo.multi_process and self.registry_invalidated:
            _logger.info("Registry changed, signaling through the database")
            with closing(self.cursor()) as cr:
                cr.execute("select nextval('base_registry_signaling')")
                self.registry_sequence = cr.fetchone()[0]

        # no need to notify cache invalidation in case of registry invalidation,
        # because reloading the registry implies starting with an empty cache
        elif odoo.multi_process and self.cache_invalidated:
            _logger.info("At least one model cache has been invalidated, signaling through the database.")
            with closing(self.cursor()) as cr:
                cr.execute("select nextval('base_cache_signaling')")
                self.cache_sequence = cr.fetchone()[0]

        self.registry_invalidated = False
        self.cache_invalidated = False

    def reset_changes(self):
        """ Reset the registry and cancel all invalidations. """
        if self.registry_invalidated:
            with closing(self.cursor()) as cr:
                self.setup_models(cr)
                # setup_models() sets the flag again; clear it afterwards
                self.registry_invalidated = False
        if self.cache_invalidated:
            self.cache.clear()
            self.cache_invalidated = False

    @contextmanager
    def manage_changes(self):
        """ Context manager to signal/discard registry and cache invalidations.

        On normal exit the changes are signaled to other processes; if the
        body raises, the changes are reset and the exception propagates.
        """
        try:
            yield self
            self.signal_changes()
        except Exception:
            self.reset_changes()
            raise

    def in_test_mode(self):
        """ Test whether the registry is in 'test' mode. """
        return self.test_cr is not None

    def enter_test_mode(self, cr):
        """ Enter the 'test' mode, where one cursor serves several requests.

        Also replaces the class-level ``_lock`` with a ``DummyRLock`` (saved in
        ``_saved_lock``); presumably a no-op lock -- TODO confirm.
        """
        # NOTE(review): these asserts are skipped under ``python -O``.
        assert self.test_cr is None
        self.test_cr = cr
        self.test_lock = threading.RLock()
        assert Registry._saved_lock is None
        Registry._saved_lock = Registry._lock
        Registry._lock = DummyRLock()

    def leave_test_mode(self):
        """ Leave the test mode and restore the class-level ``_lock``. """
        assert self.test_cr is not None
        self.test_cr = None
        self.test_lock = None
        assert Registry._saved_lock is not None
        Registry._lock = Registry._saved_lock
        Registry._saved_lock = None

    def cursor(self):
        """ Return a new cursor for the database. The cursor itself may be used
            as a context manager to commit/rollback and close automatically.
        """
        if self.test_cr is not None:
            # When in test mode, we use a proxy object that uses 'self.test_cr'
            # underneath.
            return TestCursor(self.test_cr, self.test_lock)
        return self._db.cursor()
class SerializableLock:
    """ A Serializable per-process Lock

    This wraps a normal ``threading.Lock`` object and satisfies the same
    interface.  However, this lock can also be serialized and sent to different
    processes.  It will not block concurrent operations between processes (for
    this you should look at ``multiprocessing.Lock`` or ``locket.lock_file``),
    but will consistently deserialize into the same lock.

    So if we make a lock in one process::

        lock = SerializableLock()

    And then send it over to another process multiple times::

        payload = pickle.dumps(lock)
        a = pickle.loads(payload)
        b = pickle.loads(payload)

    Then the deserialized objects will operate as though they were the same
    lock, and collide as appropriate.

    This is useful for consistently protecting resources on a per-process
    level.

    The creation of locks is itself not threadsafe.
    """

    # token -> underlying Lock. Weak values: an entry disappears once no
    # SerializableLock instance holds the Lock any more.
    _locks = WeakValueDictionary()

    def __init__(self, token=None):
        """Bind to the per-process lock identified by *token*.

        :param token: lock identifier; a fresh UUID string when falsy.
        """
        self.token = token or str(uuid.uuid4())
        # One lookup only: on a WeakValueDictionary, an ``in`` test followed by
        # ``[]`` can raise KeyError if the value is collected in between.
        lock = SerializableLock._locks.get(self.token)
        if lock is None:
            lock = Lock()
            SerializableLock._locks[self.token] = lock
        self.lock = lock

    def acquire(self, *args, **kwargs):
        """Acquire the underlying lock; same signature as ``Lock.acquire``."""
        return self.lock.acquire(*args, **kwargs)

    def release(self, *args, **kwargs):
        """Release the underlying lock."""
        return self.lock.release(*args, **kwargs)

    def __enter__(self):
        self.lock.__enter__()

    def __exit__(self, *args):
        self.lock.__exit__(*args)

    def locked(self):
        """Return whether the underlying lock is currently held."""
        return self.lock.locked()

    def __getstate__(self):
        # Only the token is pickled; the Lock itself is process-local.
        return self.token

    def __setstate__(self, token):
        # Re-attach to (or create) this process's lock for the token.
        self.__init__(token)

    def __str__(self):
        return "<%s: %s>" % (self.__class__.__name__, self.token)

    __repr__ = __str__
class PSPRayleighReflectance(CompositeBase):
    """Remove Rayleigh scattering and aerosol absorption from a band.

    The correction is computed with :class:`pyspectral.rayleigh.Rayleigh`.
    Correctors are kept in a class-level weak-value cache keyed by
    ``(platform_name, sensor, atmosphere, aerosol_type)``, so a cached
    corrector is reused only while something else still references it.
    """

    _rayleigh_cache = WeakValueDictionary()

    def get_angles(self, vis):
        """Compute the viewing and solar angles for every pixel of *vis*.

        Returns ``(sata, satz, suna, sunz)``: satellite azimuth, satellite
        zenith (computed as ``90 - elevation``), solar azimuth (converted from
        radians to degrees) and solar zenith angle.
        """
        from pyorbital.astronomy import get_alt_az, sun_zenith_angle
        from pyorbital.orbital import get_observer_look

        lons, lats = vis.attrs['area'].get_lonlats_dask(chunks=vis.data.chunks)

        _sunalt, suna = get_alt_az(vis.attrs['start_time'], lons, lats)
        suna = xu.rad2deg(suna)
        sunz = sun_zenith_angle(vis.attrs['start_time'], lons, lats)
        sata, satel = get_observer_look(
            vis.attrs['satellite_longitude'],
            vis.attrs['satellite_latitude'],
            vis.attrs['satellite_altitude'],
            vis.attrs['start_time'],
            lons, lats, 0)
        satz = 90 - satel
        return sata, satz, suna, sunz

    def __call__(self, projectables, optional_datasets=None, **info):
        """Get the corrected reflectance when removing Rayleigh scattering.

        Uses pyspectral.

        *projectables* is ``(vis, red)``. When exactly four
        *optional_datasets* are given they are used as
        ``(sata, satz, suna, sunz)``; otherwise the angles are computed from
        *vis* with :meth:`get_angles`. Returns ``vis`` minus the correction,
        carrying ``vis``'s attributes.
        """
        from pyspectral.rayleigh import Rayleigh

        if not optional_datasets or len(optional_datasets) != 4:
            vis, red = self.check_areas(projectables)
            sata, satz, suna, sunz = self.get_angles(vis)
            red.data = da.rechunk(red.data, vis.data.chunks)
        else:
            vis, red, sata, satz, suna, sunz = self.check_areas(
                projectables + optional_datasets)
            sata, satz, suna, sunz = optional_datasets

        # get the dask array underneath
        sata = sata.data
        satz = satz.data
        suna = suna.data
        sunz = sunz.data

        LOG.info('Removing Rayleigh scattering and aerosol absorption')

        # First make sure the two azimuth angles are in the range 0-360:
        sata = sata % 360.
        suna = suna % 360.
        # Relative sun/satellite azimuth, folded into [0, 180].
        ssadiff = da.absolute(suna - sata)
        ssadiff = da.minimum(ssadiff, 360 - ssadiff)
        del sata, suna

        atmosphere = self.attrs.get('atmosphere', 'us-standard')
        aerosol_type = self.attrs.get('aerosol_type', 'marine_clean_aerosol')
        rayleigh_key = (vis.attrs['platform_name'],
                        vis.attrs['sensor'], atmosphere, aerosol_type)
        # One lookup only: on a WeakValueDictionary an ``in`` test followed by
        # ``[]`` can raise KeyError if the entry is collected in between.
        corrector = self._rayleigh_cache.get(rayleigh_key)
        if corrector is None:
            corrector = Rayleigh(vis.attrs['platform_name'], vis.attrs['sensor'],
                                 atmosphere=atmosphere,
                                 aerosol_type=aerosol_type)
            self._rayleigh_cache[rayleigh_key] = corrector

        try:
            refl_cor_band = corrector.get_reflectance(sunz, satz, ssadiff,
                                                      vis.attrs['name'],
                                                      red.data)
        except (KeyError, IOError):
            LOG.warning(
                "Could not get the reflectance correction using band name: %s",
                vis.attrs['name'])
            LOG.warning(
                "Will try use the wavelength, however, this may be ambiguous!")
            refl_cor_band = corrector.get_reflectance(
                sunz, satz, ssadiff, vis.attrs['wavelength'][1], red.data)

        proj = vis - refl_cor_band
        proj.attrs = vis.attrs
        self.apply_modifier_info(vis, proj)
        return proj
class Path(object):
    """
    :class:`Path` represents a series of possibly disconnected,
    possibly closed, line and curve segments.

    The underlying storage is made up of two parallel numpy arrays:
      - *vertices*: an Nx2 float array of vertices
      - *codes*: an N-length uint8 array of vertex types

    These two arrays always have the same length in the first
    dimension.  For example, to represent a cubic curve, you must
    provide three vertices as well as three codes ``CURVE4``.

    The code types are:

       - ``STOP``   :  1 vertex (ignored)
           A marker for the end of the entire path (currently not
           required and ignored)

       - ``MOVETO`` :  1 vertex
            Pick up the pen and move to the given vertex.

       - ``LINETO`` :  1 vertex
            Draw a line from the current position to the given vertex.

       - ``CURVE3`` :  1 control point, 1 endpoint
          Draw a quadratic Bezier curve from the current position,
          with the given control point, to the given end point.

       - ``CURVE4`` :  2 control points, 1 endpoint
          Draw a cubic Bezier curve from the current position, with
          the given control points, to the given end point.

       - ``CLOSEPOLY`` : 1 vertex (ignored)
          Draw a line segment to the start point of the current
          polyline.

    Users of Path objects should not access the vertices and codes
    arrays directly.  Instead, they should use :meth:`iter_segments`
    to get the vertex/code pairs.  This is important, since many
    :class:`Path` objects, as an optimization, do not store a *codes*
    at all, but have a default one provided for them by
    :meth:`iter_segments`.

    Note also that the vertices and codes arrays should be treated as
    immutable -- there are a number of optimizations and assumptions
    made up front in the constructor that will not change when the
    data changes.
    """

    # Path codes
    STOP = 0         # 1 vertex
    MOVETO = 1       # 1 vertex
    LINETO = 2       # 1 vertex
    CURVE3 = 3       # 2 vertices
    CURVE4 = 4       # 3 vertices
    CLOSEPOLY = 0x4f  # 1 vertex

    # Number of vertices consumed by each code, indexed by ``code & 0xf``
    # (so CLOSEPOLY = 0x4f maps to index 15).
    NUM_VERTICES = [1, 1, 1, 2,
                    3, 1, 1, 1,
                    1, 1, 1, 1,
                    1, 1, 1, 1]

    code_type = np.uint8

    def __init__(self, vertices, codes=None, _interpolation_steps=1):
        """
        Create a new path with the given vertices and codes.

        *vertices* is an Nx2 numpy float array, masked array or Python
        sequence.

        *codes* is an N-length numpy array or Python sequence of type
        :attr:`matplotlib.path.Path.code_type`.

        These two arrays must have the same length in the first
        dimension.

        If *codes* is None, *vertices* will be treated as a series of
        line segments.

        If *vertices* contains masked values, they will be converted
        to NaNs which are then handled correctly by the Agg
        PathIterator and other consumers of path data, such as
        :meth:`iter_segments`.

        *interpolation_steps* is used as a hint to certain projections,
        such as Polar, that this path should be linearly interpolated
        immediately before drawing.  This attribute is primarily an
        implementation detail and is not intended for public use.
        """
        if ma.isMaskedArray(vertices):
            vertices = vertices.astype(np.float64).filled(np.nan)
        else:
            vertices = np.asarray(vertices, np.float64)

        if codes is not None:
            codes = np.asarray(codes, self.code_type)
            assert codes.ndim == 1
            assert len(codes) == len(vertices)

        assert vertices.ndim == 2
        assert vertices.shape[1] == 2

        # Only long, pure-polyline paths are worth simplifying.
        self.should_simplify = (rcParams['path.simplify'] and
                                (len(vertices) >= 128 and
                                 (codes is None or
                                  np.all(codes <= Path.LINETO))))
        self.simplify_threshold = rcParams['path.simplify_threshold']
        self.has_nonfinite = not np.isfinite(vertices).all()
        self.codes = codes
        self.vertices = vertices
        self._interpolation_steps = _interpolation_steps

    @classmethod
    def make_compound_path(cls, *args):
        """
        (staticmethod) Make a compound path from a list of Path
        objects.  Only polygons (not curves) are supported.

        Each non-empty sub-path starts with a ``MOVETO``; all other
        vertices are ``LINETO``.  With no arguments an empty path is
        returned.
        """
        for p in args:
            assert p.codes is None

        if not args:
            # np.vstack([]) raises; an empty compound path is simply empty.
            return cls(np.empty((0, 2), np.float64))

        lengths = [len(x) for x in args]
        total_length = sum(lengths)

        vertices = np.vstack([x.vertices for x in args])
        codes = cls.LINETO * np.ones(total_length, cls.code_type)
        i = 0
        for length in lengths:
            # Skip empty sub-paths: their start index belongs to the next
            # sub-path (or lies past the end for a trailing empty one).
            if length:
                codes[i] = cls.MOVETO
            i += length

        return cls(vertices, codes)

    def __repr__(self):
        return "Path(%s, %s)" % (self.vertices, self.codes)

    def __len__(self):
        return len(self.vertices)

    def iter_segments(self, transform=None, remove_nans=True, clip=None,
                      quantize=False, simplify=None, curves=True):
        """
        Iterates over all of the curve segments in the path.  Each
        iteration returns a 2-tuple (*vertices*, *code*), where
        *vertices* is a sequence of 1 - 3 coordinate pairs, and *code* is
        one of the :class:`Path` codes.

        Additionally, this method can provide a number of standard
        cleanups and conversions to the path.

        *transform*: if not None, the given affine transformation will
         be applied to the path.

        *remove_nans*: if True, will remove all NaNs from the path and
         insert MOVETO commands to skip over them.

        *clip*: if not None, must be a four-tuple (x1, y1, x2, y2)
         defining a rectangle in which to clip the path.

        *quantize*: if None, auto-quantize.  If True, force quantize,
         and if False, don't quantize.

        *simplify*: if True, perform simplification, to remove
         vertices that do not affect the appearance of the path.  If
         False, perform no simplification.  If None, use the
         should_simplify member variable.

        *curves*: If True, curve segments will be returned as curve
         segments.  If False, all curves will be converted to line
         segments.
        """
        if not len(self.vertices):
            return

        NUM_VERTICES = self.NUM_VERTICES
        STOP = self.STOP

        vertices, codes = cleanup_path(self, transform, remove_nans, clip,
                                       quantize, simplify, curves)
        len_vertices = len(vertices)

        i = 0
        while i < len_vertices:
            code = codes[i]
            if code == STOP:
                return
            num_vertices = NUM_VERTICES[int(code) & 0xf]
            curr_vertices = vertices[i:i + num_vertices].flatten()
            yield curr_vertices, code
            i += num_vertices

    def transformed(self, transform):
        """
        Return a transformed copy of the path.

        .. seealso::

            :class:`matplotlib.transforms.TransformedPath`
                A specialized path class that will cache the
                transformed result and automatically update when the
                transform changes.
        """
        return Path(transform.transform(self.vertices), self.codes,
                    self._interpolation_steps)

    def contains_point(self, point, transform=None):
        """
        Returns *True* if the path contains the given point.

        If *transform* is not *None*, the path will be transformed
        before performing the test.
        """
        if transform is not None:
            transform = transform.frozen()
        return point_in_path(point[0], point[1], self, transform)

    def contains_path(self, path, transform=None):
        """
        Returns *True* if this path completely contains the given path.

        If *transform* is not *None*, the path will be transformed
        before performing the test.
        """
        if transform is not None:
            transform = transform.frozen()
        return path_in_path(self, None, path, transform)

    def get_extents(self, transform=None):
        """
        Returns the extents (*xmin*, *ymin*, *xmax*, *ymax*) of the
        path.

        Unlike computing the extents on the *vertices* alone, this
        algorithm will take into account the curves and deal with
        control points appropriately.
        """
        # Absolute import: the implicit relative "from transforms import"
        # only works on Python 2.
        from matplotlib.transforms import Bbox
        if transform is not None:
            transform = transform.frozen()
        return Bbox(get_path_extents(self, transform))

    def intersects_path(self, other, filled=True):
        """
        Returns *True* if this path intersects another given path.

        *filled*, when True, treats the paths as if they were filled.
        That is, if one path completely encloses the other,
        :meth:`intersects_path` will return True.
        """
        return path_intersects_path(self, other, filled)

    def intersects_bbox(self, bbox, filled=True):
        """
        Returns *True* if this path intersects a given
        :class:`~matplotlib.transforms.Bbox`.

        *filled*, when True, treats the path as if it was filled.
        That is, if one path completely encloses the other,
        :meth:`intersects_path` will return True.
        """
        # Absolute import: the implicit relative "from transforms import"
        # only works on Python 2.
        from matplotlib.transforms import BboxTransformTo
        rectangle = self.unit_rectangle().transformed(
            BboxTransformTo(bbox))
        result = self.intersects_path(rectangle, filled)
        return result

    def interpolated(self, steps):
        """
        Returns a new path resampled to length N x steps.  Does not
        currently handle interpolating curves.
        """
        if steps == 1:
            return self

        vertices = simple_linear_interpolation(self.vertices, steps)
        codes = self.codes
        if codes is not None:
            new_codes = Path.LINETO * np.ones(((len(codes) - 1) * steps + 1, ))
            # Keep the original codes at the original vertex positions.
            new_codes[0::steps] = codes
        else:
            new_codes = None
        return Path(vertices, new_codes)

    def to_polygons(self, transform=None, width=0, height=0):
        """
        Convert this path to a list of polygons.  Each polygon is an
        Nx2 array of vertices.  In other words, each polygon has no
        ``MOVETO`` instructions or curves.  This is useful for
        displaying in backends that do not support compound paths or
        Bezier curves, such as GDK.

        If *width* and *height* are both non-zero then the lines will
        be simplified so that vertices outside of (0, 0), (width,
        height) will be clipped.
        """
        if len(self.vertices) == 0:
            return []

        if transform is not None:
            transform = transform.frozen()

        if self.codes is None and (width == 0 or height == 0):
            if transform is None:
                return [self.vertices]
            else:
                return [transform.transform(self.vertices)]

        # Deal with the case where there are curves and/or multiple
        # subpaths (using extension code)
        return convert_path_to_polygons(self, transform, width, height)

    _unit_rectangle = None

    @classmethod
    def unit_rectangle(cls):
        """
        (staticmethod) Returns a :class:`Path` of the unit rectangle
        from (0, 0) to (1, 1).
        """
        if cls._unit_rectangle is None:
            cls._unit_rectangle = \
                cls([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0],
                     [0.0, 0.0]])
        return cls._unit_rectangle

    # Weak cache keyed by numVertices (only for numVertices <= 16).
    _unit_regular_polygons = WeakValueDictionary()

    @classmethod
    def unit_regular_polygon(cls, numVertices):
        """
        (staticmethod) Returns a :class:`Path` for a unit regular
        polygon with the given *numVertices* and radius of 1.0,
        centered at (0, 0).
        """
        cacheable = numVertices <= 16
        path = cls._unit_regular_polygons.get(numVertices) if cacheable else None
        if path is None:
            theta = (2 * np.pi / numVertices *
                     np.arange(numVertices + 1).reshape((numVertices + 1, 1)))
            # This initial rotation is to make sure the polygon always
            # "points-up"
            theta += np.pi / 2.0
            verts = np.concatenate((np.cos(theta), np.sin(theta)), 1)
            path = cls(verts)
            # Only store what can later be looked up.
            if cacheable:
                cls._unit_regular_polygons[numVertices] = path
        return path

    # Weak cache keyed by (numVertices, innerCircle) (numVertices <= 16).
    _unit_regular_stars = WeakValueDictionary()

    @classmethod
    def unit_regular_star(cls, numVertices, innerCircle=0.5):
        """
        (staticmethod) Returns a :class:`Path` for a unit regular star
        with the given numVertices and radius of 1.0, centered at (0,
        0).
        """
        key = (numVertices, innerCircle)
        cacheable = numVertices <= 16
        path = cls._unit_regular_stars.get(key) if cacheable else None
        if path is None:
            ns2 = numVertices * 2
            theta = (2 * np.pi / ns2 * np.arange(ns2 + 1))
            # This initial rotation is to make sure the polygon always
            # "points-up"
            theta += np.pi / 2.0
            # Alternate between the outer (1.0) and inner radius.
            r = np.ones(ns2 + 1)
            r[1::2] = innerCircle
            verts = np.vstack((r * np.cos(theta), r * np.sin(theta))).transpose()
            path = cls(verts)
            # Store in the star cache (previously written to the polygon
            # cache, so star lookups never hit).
            if cacheable:
                cls._unit_regular_stars[key] = path
        return path

    @classmethod
    def unit_regular_asterisk(cls, numVertices):
        """
        (staticmethod) Returns a :class:`Path` for a unit regular
        asterisk with the given numVertices and radius of 1.0,
        centered at (0, 0).
        """
        return cls.unit_regular_star(numVertices, 0.0)

    _unit_circle = None

    @classmethod
    def unit_circle(cls):
        """
        (staticmethod) Returns a :class:`Path` of the unit circle.
        The circle is approximated using cubic Bezier curves.  This
        uses 8 splines around the circle using the approach presented
        here:

          Lancaster, Don.  `Approximating a Circle or an Ellipse Using Four
          Bezier Cubic Splines <http://www.tinaja.com/glib/ellipse4.pdf>`_.
        """
        if cls._unit_circle is None:
            MAGIC = 0.2652031
            SQRTHALF = np.sqrt(0.5)
            MAGIC45 = np.sqrt((MAGIC * MAGIC) / 2.0)

            vertices = np.array(
                [[0.0, -1.0],

                 [MAGIC, -1.0],
                 [SQRTHALF - MAGIC45, -SQRTHALF - MAGIC45],
                 [SQRTHALF, -SQRTHALF],

                 [SQRTHALF + MAGIC45, -SQRTHALF + MAGIC45],
                 [1.0, -MAGIC],
                 [1.0, 0.0],

                 [1.0, MAGIC],
                 [SQRTHALF + MAGIC45, SQRTHALF - MAGIC45],
                 [SQRTHALF, SQRTHALF],

                 [SQRTHALF - MAGIC45, SQRTHALF + MAGIC45],
                 [MAGIC, 1.0],
                 [0.0, 1.0],

                 [-MAGIC, 1.0],
                 [-SQRTHALF + MAGIC45, SQRTHALF + MAGIC45],
                 [-SQRTHALF, SQRTHALF],

                 [-SQRTHALF - MAGIC45, SQRTHALF - MAGIC45],
                 [-1.0, MAGIC],
                 [-1.0, 0.0],

                 [-1.0, -MAGIC],
                 [-SQRTHALF - MAGIC45, -SQRTHALF + MAGIC45],
                 [-SQRTHALF, -SQRTHALF],

                 [-SQRTHALF + MAGIC45, -SQRTHALF - MAGIC45],
                 [-MAGIC, -1.0],
                 [0.0, -1.0],

                 [0.0, -1.0]],
                np.float64)

            codes = cls.CURVE4 * np.ones(26)
            codes[0] = cls.MOVETO
            codes[-1] = cls.CLOSEPOLY

            cls._unit_circle = cls(vertices, codes)
        return cls._unit_circle

    @classmethod
    def arc(cls, theta1, theta2, n=None, is_wedge=False):
        """
        (staticmethod) Returns an arc on the unit circle from angle
        *theta1* to angle *theta2* (in degrees).

        If *n* is provided, it is the number of spline segments to make.
        If *n* is not provided, the number of spline segments is
        determined based on the delta between *theta1* and *theta2*.

           Masionobe, L.  2003.  `Drawing an elliptical arc using
           polylines, quadratic or cubic Bezier curves
           <http://www.spaceroots.org/documents/ellipse/index.html>`_.
        """
        # degrees to radians
        theta1 *= np.pi / 180.0
        theta2 *= np.pi / 180.0

        twopi = np.pi * 2.0
        halfpi = np.pi * 0.5

        eta1 = np.arctan2(np.sin(theta1), np.cos(theta1))
        eta2 = np.arctan2(np.sin(theta2), np.cos(theta2))
        eta2 -= twopi * np.floor((eta2 - eta1) / twopi)
        if (theta2 - theta1 > np.pi) and (eta2 - eta1 < np.pi):
            eta2 += twopi

        # number of curve segments to make
        if n is None:
            n = int(2 ** np.ceil((eta2 - eta1) / halfpi))
        if n < 1:
            raise ValueError("n must be >= 1 or None")

        deta = (eta2 - eta1) / n
        t = np.tan(0.5 * deta)
        alpha = np.sin(deta) * (np.sqrt(4.0 + 3.0 * t * t) - 1) / 3.0

        steps = np.linspace(eta1, eta2, n + 1, True)
        cos_eta = np.cos(steps)
        sin_eta = np.sin(steps)

        xA = cos_eta[:-1]
        yA = sin_eta[:-1]
        xA_dot = -yA
        yA_dot = xA

        xB = cos_eta[1:]
        yB = sin_eta[1:]
        xB_dot = -yB
        yB_dot = xB

        if is_wedge:
            length = n * 3 + 4
            # np.zeros, not np.empty: vertex 0 (MOVETO) and the final
            # LINETO/CLOSEPOLY vertices are never assigned below and must
            # be the wedge centre (0, 0).
            vertices = np.zeros((length, 2), np.float64)
            codes = cls.CURVE4 * np.ones((length, ), cls.code_type)
            vertices[1] = [xA[0], yA[0]]
            codes[0:2] = [cls.MOVETO, cls.LINETO]
            codes[-2:] = [cls.LINETO, cls.CLOSEPOLY]
            vertex_offset = 2
            end = length - 2
        else:
            length = n * 3 + 1
            vertices = np.empty((length, 2), np.float64)
            codes = cls.CURVE4 * np.ones((length, ), cls.code_type)
            vertices[0] = [xA[0], yA[0]]
            codes[0] = cls.MOVETO
            vertex_offset = 1
            end = length

        # Two control points and one end point per Bezier segment.
        vertices[vertex_offset:end:3, 0] = xA + alpha * xA_dot
        vertices[vertex_offset:end:3, 1] = yA + alpha * yA_dot
        vertices[vertex_offset + 1:end:3, 0] = xB - alpha * xB_dot
        vertices[vertex_offset + 1:end:3, 1] = yB - alpha * yB_dot
        vertices[vertex_offset + 2:end:3, 0] = xB
        vertices[vertex_offset + 2:end:3, 1] = yB

        return cls(vertices, codes)

    @classmethod
    def wedge(cls, theta1, theta2, n=None):
        """
        (staticmethod) Returns a wedge of the unit circle from angle
        *theta1* to angle *theta2* (in degrees).

        If *n* is provided, it is the number of spline segments to make.
        If *n* is not provided, the number of spline segments is
        determined based on the delta between *theta1* and *theta2*.
        """
        return cls.arc(theta1, theta2, n, True)

    # Bounded cache of hatch paths keyed by (hatchpattern, density).
    _hatch_dict = maxdict(8)

    @classmethod
    def hatch(cls, hatchpattern, density=6):
        """
        Given a hatch specifier, *hatchpattern*, generates a Path that
        can be used in a repeated hatching pattern.  *density* is the
        number of lines per unit square.
        """
        from matplotlib.hatch import get_path

        if hatchpattern is None:
            return None

        hatch_path = cls._hatch_dict.get((hatchpattern, density))
        if hatch_path is not None:
            return hatch_path

        hatch_path = get_path(hatchpattern, density)
        cls._hatch_dict[(hatchpattern, density)] = hatch_path
        return hatch_path
def Resolver(
        get_host=get_host_default,
        get_nameservers=get_nameservers_default,
        set_sock_options=set_sock_options_default,
        transform_fqdn=mix_case,
        max_cname_chain_length=20,
        get_logger_adapter=get_logger_adapter_default,
):
    """Build a caching asynchronous DNS resolver bound to the running loop.

    Must be called while an event loop is running (``get_running_loop``).
    /etc/hosts and /etc/resolv.conf are parsed once, when this is called.

    :returns: ``(resolve, clear_cache)``.  ``resolve(fqdn_str, qtype)``
        follows at most *max_cname_chain_length* CNAMEs and raises
        ``DnsCnameChainTooLong`` beyond that; ``clear_cache()`` drops every
        cached answer.
    """

    loop = get_running_loop()
    logger = get_logger_adapter({})

    # (fqdn, qtype) -> answers; each entry is removed by a loop.call_at
    # timer (kept in invalidate_callbacks) when its earliest TTL expires.
    cache = {}
    invalidate_callbacks = {}
    # Deduplicates identical requests that are in flight at the same time.
    in_progress = WeakValueDictionary()

    parsed_etc_hosts = parse_etc_hosts()
    logger.debug('Parsed /etc/hosts: %s', parsed_etc_hosts)
    parsed_resolve_conf = parse_resolve_conf()
    logger.debug('Parsed /etc/resolv.conf: %s', parsed_resolve_conf)

    async def resolve(
            fqdn_str, qtype,
            get_logger_adapter=get_logger_adapter,
    ):
        logger = get_logger_adapter({
            'aiodnsresolver_fqdn': fqdn_str,
            'aiodnsresolver_qtype': qtype,
        })
        fqdn = BytesExpiresAt(fqdn_str.encode('idna'), expires_at=float('inf'))

        for _ in range(max_cname_chain_length):
            host = await get_host(parsed_etc_hosts, fqdn, qtype)
            if host is not None:
                logger.info('Resolved %s from hosts: %s', fqdn, host)
                return (host,)

            cname_rdata, qtype_rdata = await request_memoized(logger, fqdn, qtype)
            # The result may not outlive any CNAME on the way to it.
            min_expires_at = fqdn.expires_at  # pylint: disable=no-member

            if qtype_rdata:
                logger.info('Resolved %s %s: %s', fqdn, qtype, qtype_rdata)
                return rdata_expires_at_min(qtype_rdata, min_expires_at)

            fqdn = rdata_expires_at_min([cname_rdata[0]], min_expires_at)[0]
            logger.info('Resolved CNAME: %s', fqdn)

        raise DnsCnameChainTooLong()

    async def request_memoized(logger, fqdn, qtype):
        key = (fqdn, qtype)

        try:
            cached_result = cache[key]
        except KeyError:
            logger.info('Not found %s in cache', fqdn)
        else:
            logger.info('Found %s in cache: %s', fqdn, cached_result)
            return cached_result

        try:
            memoized_mutex = in_progress[key]
        except KeyError:
            memoized_mutex = MemoizedMutex(request_and_cache, logger, fqdn, qtype)
            in_progress[key] = memoized_mutex
        else:
            logger.debug('Concurrent request found, waiting for it to complete')

        return await memoized_mutex()

    async def request_and_cache(logger, fqdn, qtype):
        answers = await request_until_response(logger, fqdn, qtype)

        # Cache until the first record in the answer expires.
        expires_at = min(
            rdata_ttl.expires_at
            for rdata_groups in answers
            for rdata_ttl in rdata_groups
        )
        key = (fqdn, qtype)
        invalidate_callbacks[key] = loop.call_at(expires_at, invalidate, logger, key)
        cache[key] = answers
        return answers

    def invalidate(logger, key):
        logger.debug('Removing from DNS cache: %s', key)
        del cache[key]
        invalidate_callbacks.pop(key).cancel()

    async def clear_cache(
            get_logger_adapter=get_logger_adapter,
    ):
        logger = get_logger_adapter({})
        logger.debug('Clearing DNS cache')
        for callback in invalidate_callbacks.values():
            callback.cancel()
        invalidate_callbacks.clear()
        cache.clear()

    async def request_until_response(logger, fqdn, qtype):
        exception = DnsError()
        async for nameserver in get_nameservers(parsed_resolve_conf, fqdn):
            logger.debug('Attempting nameserver: %s', nameserver)
            timeout, addrs = nameserver[0], nameserver[1:]
            try:
                return await request_with_timeout(logger, timeout, addrs, fqdn, qtype)
            except DnsRecordDoesNotExist:
                # A definitive "no such record": do not try other servers.
                raise
            except DnsError as recent_exception:
                logger.warning('Nameserver failed: %s', nameserver[1])
                exception = recent_exception

        raise exception

    async def request_with_timeout(logger, timeout, addrs, fqdn, qtype):
        cancelling_due_to_timeout = False
        task = current_task()

        def cancel():
            nonlocal cancelling_due_to_timeout
            cancelling_due_to_timeout = True
            task.cancel()

        handle = loop.call_later(timeout, cancel)

        # Most recent low-level failure, reported as the cause of a timeout.
        last_exception = None

        def set_timeout_cause(exception):
            nonlocal last_exception
            last_exception = exception

        try:
            return await request(logger, addrs, fqdn, qtype, set_timeout_cause)
        except CancelledError:
            if cancelling_due_to_timeout:
                raise DnsTimeout() from last_exception

            logger.debug('Cancelled')
            raise
        finally:
            handle.cancel()

    async def request(logger, addrs, fqdn, qtype, set_timeout_cause):

        async def make_request():
            qid = randbelow(65536)
            fqdn_transformed = await transform_fqdn(fqdn)
            return Message(
                qid=qid, qr=QUESTION, opcode=0, aa=0, tc=0, rd=1, ra=0, z=0, rcode=0,
                qd=(QuestionRecord(fqdn_transformed, qtype, qclass=1),),
                an=(), ns=(), ar=(),
            )

        with ExitStack() as stack:
            socks = tuple(
                stack.enter_context(socket(AF_INET, SOCK_DGRAM))
                for addr in addrs
            )

            # addr_port -> (socket, query sent on that socket).  Each
            # nameserver gets its own query (own random qid and case mix).
            connections = {}
            last_exception = OSError()
            for addr_port, sock in zip(addrs, socks):
                try:
                    sock.setblocking(False)
                    set_sock_options(sock)
                    sock.connect(addr_port)
                except OSError as exception:
                    last_exception = exception
                    set_timeout_cause(exception)
                else:
                    connections[addr_port] = (sock, await make_request())

            if not connections:
                logger.debug('No sockets connected')
                raise DnsSocketError() from last_exception

            ttl_start = loop.time()
            for sock, sent_req in connections.values():
                logger.debug('Sending %s to %s', sent_req, sock)
                await loop.sock_sendall(sock, pack(sent_req))

            last_exception = DnsError()
            while connections:
                connected_socks = tuple(sock for sock, _ in connections.values())
                try:
                    response_data, addr_port = await recvfrom(loop, connected_socks, 512)
                except OSError as exception:
                    logger.debug('Exception receiving from: %s', connected_socks)
                    last_exception = exception
                    set_timeout_cause(exception)
                    continue
                else:
                    logger.debug('Response from: %s', addr_port)

                try:
                    res = parse(response_data)
                except (struct_error, IndexError, DnsPointerLoop) as exception:
                    logger.debug('Error parsing response: %s', type(exception).__name__)
                    last_exception = exception
                    set_timeout_cause(exception)
                    continue

                logger.debug('Received response: %s', res)

                # Match the response against the query sent to *this*
                # nameserver.  Previously a loop variable shadowed the query
                # factory and left only the last-sent query in scope, so
                # valid answers from every other nameserver were rejected.
                connection = connections.get(addr_port)
                trusted = (
                    connection is not None and
                    res.qid == connection[1].qid and
                    res.qd == connection[1].qd
                )
                if not trusted:
                    logger.debug('Response not trusted')
                    continue

                sent_req = connection[1]
                del connections[addr_port]

                if res.tc:
                    logger.warning('Response truncated')

                name_error = res.rcode == 3
                non_name_error = res.rcode and not name_error
                name_lower = sent_req.qd[0].name.lower()
                cname_answers = tuple(
                    rdata_expires_at(answer, ttl_start + answer.ttl)
                    for answer in res.an
                    if answer.name.lower() == name_lower and answer.qtype == TYPES.CNAME
                )
                qtype_answers = tuple(
                    rdata_expires_at(answer, ttl_start + answer.ttl)
                    for answer in res.an
                    if answer.name.lower() == name_lower and answer.qtype == qtype
                )

                if non_name_error:
                    last_exception = DnsResponseCode(res.rcode)
                    set_timeout_cause(last_exception)
                    logger.debug('Error from %s', addr_port)
                elif name_error or (not cname_answers and not qtype_answers):
                    # a name error can be returned by some non-authoritative
                    # servers on not-existing, contradicting RFC 1035
                    logger.debug('Record not found from %s', addr_port)
                    raise DnsRecordDoesNotExist()
                else:
                    return cname_answers, qtype_answers

            if isinstance(last_exception, DnsError):
                raise last_exception
            raise DnsError() from last_exception

    def rdata_expires_at(record, expires_at):
        return \
            IPv4AddressExpiresAt(record.rdata, expires_at) if record.qtype == TYPES.A else \
            IPv6AddressExpiresAt(record.rdata, expires_at) if record.qtype == TYPES.AAAA else \
            BytesExpiresAt(record.rdata.lower(), expires_at)

    def rdata_expires_at_min(rdatas, expires_at):
        return tuple(
            type(rdata)(rdata=rdata, expires_at=min(expires_at, rdata.expires_at))
            for rdata in rdatas
        )

    return resolve, clear_cache
The full license is in the file LICENSE, distributed with this software. Copyright (C) Jun Zhu. All rights reserved. """ from abc import abstractmethod import os.path as osp import platform from collections import OrderedDict from weakref import WeakValueDictionary import h5py # Track all FileAccess objects - {path: FileAccess} _file_access_registry = WeakValueDictionary() class FileOpenRegistry: def __init__(self, n_max): """Initialization. :param int n_max: maximum number of files. """ self._n_max = n_max # key: filepath, value: None (not used) self._cache = OrderedDict() def n_opened(self): """Return the number of opened files."""
class CommandDispatcher(object):
    """Route actions to commands, optionally through a tree of dispatchers.

    Actions and commands are registered by name by using the instance as a
    class decorator (see :meth:`__call__`).  A child dispatcher (see
    :meth:`add_child`) delegates its undo stack, action history,
    ``action_added`` signal and command wrapping to its parent, and forwards
    dispatches up the chain as a dotted path string such as
    ``'child.action(args)'`` that the root then parses and routes back down.
    """

    # Source of auto-generated ids for children added without one.
    count = 0

    def __init__(self, dispatcher_id=None):
        try:
            self.dispatcher_id = dispatcher_id
        except AttributeError:
            # NOTE(review): presumably a subclass exposes dispatcher_id as a
            # read-only property - confirm.
            pass

        self.main_window = None  # :type: QtGui.QWidget

        self._undo_stack = UndoStack()
        self._action_history = []

        # action name -> Action / Command class, filled via __call__.
        self._actions = {}
        self._commands = {}

        self._parent_dispatcher = None  # :type: ref[CommandDispatcher]
        # Children are held weakly: dropped once nothing else references them.
        self._children_dispatchers = WeakValueDictionary()

        self._action_added = MrSignal()

        self._main_data = None
        self.main_data_can_change = True

    def _get_parent(self):
        """Return the live parent dispatcher, or None if there is none.

        ``_parent_dispatcher`` is None or a weakref (see :meth:`set_parent`);
        a dead weakref also yields None.  Resolving the parent explicitly
        keeps TypeError/AttributeError raised *by* the parent from being
        mistaken for "no parent".
        """
        if self._parent_dispatcher is None:
            return None
        return self._parent_dispatcher()

    @property
    def main_data(self):
        return self._main_data

    @main_data.setter
    def main_data(self, value):
        # Assignment is silently ignored while main_data_can_change is False.
        if self.main_data_can_change:
            self._main_data = value

    @property
    def undo_stack(self):
        """The undo stack of the root dispatcher."""
        parent = self._get_parent()
        if parent is None:
            return self._undo_stack
        return parent.undo_stack

    @property
    def action_history(self):
        """The action history list of the root dispatcher."""
        parent = self._get_parent()
        if parent is None:
            return self._action_history
        return parent.action_history

    @property
    def action_added(self):
        """The ``action_added`` signal of the root dispatcher."""
        parent = self._get_parent()
        if parent is None:
            return self._action_added
        return parent.action_added

    def clear_children(self):
        self._children_dispatchers.clear()

    def add_child(self, dispatcher):
        """Register *dispatcher* as a child, assigning an id if it has none."""
        if dispatcher.dispatcher_id is None:
            CommandDispatcher.count += 1
            dispatcher.dispatcher_id = str(CommandDispatcher.count)

        if dispatcher.dispatcher_id in self._children_dispatchers:
            assert self._children_dispatchers[dispatcher.dispatcher_id] is dispatcher
            return

        self._children_dispatchers[dispatcher.dispatcher_id] = dispatcher
        dispatcher.set_parent(self)

    def set_parent(self, parent_dispatcher):
        # Weak reference, so a child does not keep its parent alive.
        self._parent_dispatcher = ref(parent_dispatcher)

    def multiple_dispatch(self, actions):
        for action in actions:
            self.dispatch(action)

    def _get_command(self, action):
        """Return the command for *action*.

        :raises TypeError: if no command is registered for the action name,
            or *action* is neither an Action nor a Command/ChildCommand.
        """
        if isinstance(action, Action):
            action_name = action.action_name
            # Only the lookup is guarded: a KeyError raised while *building*
            # the command is a bug in the command and must propagate as-is.
            try:
                command_cls = self._commands[action_name]
            except KeyError:
                raise TypeError('CommandDispatcher4: Command %s not found in defined actions!' % str(action_name))
            command = command_cls(action, main_window=self.main_window)
        elif isinstance(action, (Command, ChildCommand)):
            command = action
            command.main_window = self.main_window
        else:
            raise TypeError('CommandDispatcher4: Action type not valid! %s' % str(action))

        return command

    def _try_undo_redo(self, action):
        """Handle the strings 'undo'/'redo' (any case); return True if handled."""
        if isinstance(action, str):
            tmp = action.upper()
            if tmp == 'UNDO':
                self.action_history.append('Undo')
                self.action_added.emit(action)
                self.undo_stack.undo()
                return True
            elif tmp == 'REDO':
                self.action_history.append('Redo')
                self.action_added.emit(action)
                self.undo_stack.redo()
                return True

        return False

    def undo(self):
        self.dispatch('Undo')

    def redo(self):
        self.dispatch('Redo')

    def _subdata(self, data):
        return self.main_data.subdata(data)

    def _action_str(self, action, action_data=None):
        """Prefix *action* with this dispatcher's id (and index data, if any).

        Only dispatchers that have both an id and a parent add a prefix.
        """
        assert isinstance(action, str)

        if action.upper() in ('REDO', 'UNDO'):
            return action

        if None not in (self.dispatcher_id, self._parent_dispatcher):
            if action_data is not None:
                # assumes split() returns (index_or_None, args_str_or_None)
                # - confirm against ActionDataCls
                data = action_data.split()

                if data[0] is None and data[1] is None:
                    action = '%s.%s()' % (self.dispatcher_id, action)
                elif data[0] is None:
                    action = '%s.%s%s' % (self.dispatcher_id, action, data[1])
                elif data[1] is None:
                    action = '%s[%s].%s()' % (self.dispatcher_id, data[0], action)
                else:
                    action = '%s[%s].%s%s' % (self.dispatcher_id, data[0], action, data[1])
            else:
                action = '%s.%s' % (self.dispatcher_id, action)

        return action

    def dispatch(self, action, tracking=True):
        """Dispatch *action*: an undo/redo string, an action string, or a
        ``(action_name, action_data)`` pair.

        :raises TypeError: if the action name is not registered.
        """
        if self._try_undo_redo(action):
            return True

        try:
            action_name, action_data = action
            action_name = action_name.replace('.', '_')
        except (TypeError, ValueError):
            assert isinstance(action, str)
            action_info = self.parse_action(action)
            return self._dispatch(action_info, tracking, action)

        if not isinstance(action_data, tuple):
            action_data = (action_data,)

        try:
            action_cls = self._actions[action_name]
        except KeyError:
            raise TypeError('CommandDispatcher4: Action type not valid! %s' % str(action_name))

        action_data = action_cls.ActionDataCls(*action_data)

        parent = self._get_parent()
        if parent is not None:
            # Walk up to the root, which routes the action back down.
            action_str = self._action_str(action_name, action_data)
            return parent._traceback(action_str, tracking, action_data)

        action_str = '%s%s' % (action_name, str(action_data))
        action_info = self.parse_action(action_str)
        return self._dispatch(action_info, tracking, action_str, action_data)

    def _traceback(self, action, tracking=True, action_data=None):
        """Prefix *action* with this dispatcher's id and pass it upwards; the
        root parses it and dispatches."""
        action_str = self._action_str(action)

        parent = self._get_parent()
        if parent is not None:
            return parent._traceback(action_str, tracking, action_data)

        action_info = self.parse_action(action_str)
        return self._dispatch(action_info, tracking, action_str, action_data)

    def _dispatch(self, action_info, tracking, action_str, action_data=None):
        """Route a parsed ``(dispatch_path, (action_name, data))`` to the
        dispatcher at the end of the path."""
        _dispatches, _action = action_info

        if not _dispatches:
            action_name, action_data_ = _action
            if action_data is None:
                action_data = action_data_
            return self._final_dispatch(action_name, action_data, action_str, tracking)

        dispatcher_id, dispatcher_data = _dispatches[0]

        if self.dispatcher_id is not None:
            try:
                assert dispatcher_id == self.dispatcher_id
            except AssertionError:
                print('This dispatcher = %s, other dispatcher = %s' % (self.dispatcher_id, dispatcher_id))
                raise

            try:
                dispatcher_id = _dispatches[1][0]
            except IndexError:
                _action_info = [], _action
                return self._dispatch(_action_info, tracking, action_str, action_data)
            else:
                # NOTE(review): this forwards the full action_info (still
                # headed by this dispatcher's own id) to the child, while the
                # branch below strips the consumed id - confirm which path
                # convention is intended.
                return self._children_dispatchers[dispatcher_id]._dispatch(action_info, tracking, action_str, action_data)

        dispatcher = self._children_dispatchers[dispatcher_id]

        # FIXME: should the dispatcher be responsible for this? might be taken care of by the commands
        if dispatcher_data is not None:
            subdata = self._subdata(dispatcher_data)
        else:
            subdata = None

        # NOTE(review): CommandDispatcher never defines get_model itself;
        # presumably subclasses/callers set it - confirm.
        old_main_data = dispatcher.get_model

        if subdata is not None:
            dispatcher.get_model = subdata

        _action_info = _dispatches[1:], _action

        # Restore the child's model even if its dispatch raises.
        try:
            # noinspection PyProtectedMember
            return dispatcher._dispatch(_action_info, tracking, action_str, action_data)
        finally:
            dispatcher.get_model = old_main_data

    def _final_dispatch(self, action_name, action_data, action_str, tracking=True):
        """Build, run and (if successful) record/push the command.

        :returns: True if the command succeeded, False otherwise.
        :raises TypeError: if *action_name* is not registered.
        """
        if self._try_undo_redo(action_name):
            return True

        try:
            action_cls = self._actions[action_name]
        except KeyError:
            raise TypeError('CommandDispatcher4: Action type not valid! %s' % str(action_name))

        if isinstance(action_data, tuple):
            action_data = action_cls.ActionDataCls(*action_data)

        assert isinstance(action_data, action_cls.ActionDataCls)

        action = action_cls(action_data)
        action.get_model = self.main_data

        command = self._get_command(action)

        if command is None:
            return False

        command = self._wrap_command(command)

        # Run once now; skip_first makes the undo stack's initial redo on
        # push a no-op.
        command.skip_first = False
        command.redo()
        command_result = command.command_result
        command.skip_first = True

        # TODO: not sure if this is the desired behavior
        if command_result is False:
            command.finalize()
            return False

        action_ = command.action

        if action_.log_action is True and tracking is True:
            if action_str is None:
                action_str = str(action_)
            self.action_history.append(action_str)
            # this notifies the main window that an action has been added, so that it can update the log
            self.action_added.emit(action_str)

        # if the action is successful, push it to the stack (it will be skipped on first push)
        if command_result is True:
            if command.push_to_stack and tracking is True:
                self.undo_stack.push(command)
            if command.set_clean is True:
                self.undo_stack.setClean()
            return True
        else:
            return False

    def _wrap_command(self, command):
        """Let the root dispatcher wrap *command*; the default is no wrapping."""
        parent = self._get_parent()
        if parent is None:
            return command
        return parent._wrap_command(command)

    def verify(self):
        """Check that every action has a command and vice versa, recursively.

        :raises Exception: on the first mismatch found.
        """
        action_keys = set(self._actions.keys())
        command_keys = set(self._commands.keys())

        if action_keys != command_keys:
            if len(action_keys) > len(command_keys):
                raise Exception("CommandDispatcher4: Missing commands! %s" % str(action_keys - command_keys))
            else:
                raise Exception("CommandDispatcher4: Missing actions! %s" % str(command_keys - action_keys))

        for key, child in iteritems(self._children_dispatchers):
            child.verify()

    def finalize(self):
        """Drop the parent link and all registrations, recursively."""
        self._parent_dispatcher = None
        self._actions.clear()
        self._commands.clear()

        for key, child in iteritems(self._children_dispatchers):
            child.finalize()

        self._children_dispatchers.clear()

    def __call__(self, action_name):
        """Class decorator registering an Action or Command under *action_name*
        (dots are replaced by underscores)."""
        action_name = action_name.replace('.', '_')

        def add_action(cls):
            if issubclass(cls, Action):
                self._actions[action_name] = cls
            elif issubclass(cls, Command):
                self._commands[action_name] = cls
            else:
                raise TypeError("CommandDispatcher4: %s is not an Action or Command!" % cls.__name__)

            cls.action_name = action_name

            return cls

        return add_action

    @staticmethod
    def parse_action(s):
        """Parse ``'a.b[idx].name(args)'`` into
        ``([('a', None), ('b', None)], ('name', data_tuple))``.

        The index of the last dispatcher segment, if present, is prepended to
        the action data.  Data is parsed with ``ast.literal_eval``.
        """
        tmp = s
        data = ''

        # Split off the trailing "(...)" argument list, matching parentheses.
        if s[-1] == ')':
            count = 1
            for i in range(1, len(s)):
                a = s[-i - 1]
                if a == ')':
                    count += 1
                elif a == '(':
                    count -= 1

                if count == 0:
                    j = len(s) - i - 1
                    data_ = s[j + 1:-1]
                    if data_ != '':
                        data = literal_eval(data_)
                    else:
                        data = []
                    tmp = s[:j]
                    break

        tmp = tmp.split('.')

        # Every segment but the last is a dispatcher, optionally "id[index]".
        tmp_ = tmp[:-1]
        _tmp = []
        for i in tmp_:
            a = i.split('[')
            b = a[0]
            try:
                c = literal_eval(a[1][:-1])
            except IndexError:
                c = None
            _tmp.append((b, c))

        try:
            insert_data = _tmp[-1][1]
            _tmp[-1] = _tmp[-1][0], None
            if insert_data is not None:
                try:
                    data.insert(0, insert_data)
                    data = tuple(data)
                except AttributeError:
                    data = tuple([insert_data, data])
            else:
                data = tuple([data])
        except IndexError:
            # No dispatcher segments at all.
            data = tuple(data)

        return _tmp, (tmp[-1], data)
def __init__(self):
    """Set up an empty slot registry.

    The registry is a ``WeakValueDictionary``, so a stored value is held
    only weakly and its entry disappears once the value is collected.
    """
    registry = WeakValueDictionary()
    self.__slots = registry
class Path(object):
    """
    :class:`Path` represents a series of possibly disconnected,
    possibly closed, line and curve segments.

    The underlying storage is made up of two parallel numpy arrays:
      - *vertices*: an Nx2 float array of vertices
      - *codes*: an N-length uint8 array of vertex types

    These two arrays always have the same length in the first
    dimension.  For example, to represent a cubic curve, you must
    provide three vertices as well as three codes ``CURVE4``.

    The code types are:

       - ``STOP``   :  1 vertex (ignored)
           A marker for the end of the entire path (currently not
           required and ignored)

       - ``MOVETO`` :  1 vertex
            Pick up the pen and move to the given vertex.

       - ``LINETO`` :  1 vertex
            Draw a line from the current position to the given vertex.

       - ``CURVE3`` :  1 control point, 1 endpoint
          Draw a quadratic Bezier curve from the current position,
          with the given control point, to the given end point.

       - ``CURVE4`` :  2 control points, 1 endpoint
          Draw a cubic Bezier curve from the current position, with
          the given control points, to the given end point.

       - ``CLOSEPOLY`` : 1 vertex (ignored)
          Draw a line segment to the start point of the current
          polyline.

    Users of Path objects should not access the vertices and codes
    arrays directly.  Instead, they should use :meth:`iter_segments`
    or :meth:`cleaned` to get the vertex/code pairs.  This is important,
    since many :class:`Path` objects, as an optimization, do not store a
    *codes* array at all, but have a default one provided for them by
    :meth:`iter_segments`.

    .. note::

        The vertices and codes arrays should be treated as
        immutable -- there are a number of optimizations and assumptions
        made up front in the constructor that will not change when the
        data changes.
    """

    # Path codes
    STOP = 0         # 1 vertex
    MOVETO = 1       # 1 vertex
    LINETO = 2       # 1 vertex
    CURVE3 = 3       # 2 vertices
    CURVE4 = 4       # 3 vertices
    CLOSEPOLY = 79   # 1 vertex

    #: A dictionary mapping Path codes to the number of vertices that the
    #: code expects.
    NUM_VERTICES_FOR_CODE = {STOP: 1,
                             MOVETO: 1,
                             LINETO: 1,
                             CURVE3: 2,
                             CURVE4: 3,
                             CLOSEPOLY: 1}

    code_type = np.uint8

    def __init__(self, vertices, codes=None, _interpolation_steps=1,
                 closed=False, readonly=False):
        """
        Create a new path with the given vertices and codes.

        Parameters
        ----------
        vertices : array_like
            The ``(n, 2)`` float array, masked array or sequence of pairs
            representing the vertices of the path.

            If *vertices* contains masked values, they will be converted
            to NaNs which are then handled correctly by the Agg
            PathIterator and other consumers of path data, such as
            :meth:`iter_segments`.
        codes : {None, array_like}, optional
            n-length array integers representing the codes of the path.
            If not None, codes must be the same length as vertices.
            If None, *vertices* will be treated as a series of line segments.
        _interpolation_steps : int, optional
            Used as a hint to certain projections, such as Polar, that this
            path should be linearly interpolated immediately before drawing.
            This attribute is primarily an implementation detail and is not
            intended for public use.
        closed : bool, optional
            If *codes* is None and closed is True, vertices will be treated as
            line segments of a closed polygon.  Ignored for an empty
            *vertices* array.
        readonly : bool, optional
            Makes the path behave in an immutable way and sets the vertices
            and codes as read-only arrays.

        Raises
        ------
        ValueError
            If *vertices* is not Nx2, if *codes* is not 1-D with the same
            length as *vertices*, or if a non-empty *codes* does not start
            with ``MOVETO``.
        """
        if isinstance(vertices, np.ma.MaskedArray):
            vertices = vertices.astype(float).filled(np.nan)
        else:
            vertices = np.asarray(vertices, float)
        if (vertices.ndim != 2) or (vertices.shape[1] != 2):
            msg = "'vertices' must be a 2D list or array with shape Nx2"
            raise ValueError(msg)

        if codes is not None:
            codes = np.asarray(codes, self.code_type)
            if (codes.ndim != 1) or len(codes) != len(vertices):
                msg = ("'codes' must be a 1D list or array with the same"
                       " length of 'vertices'")
                raise ValueError(msg)
            if len(codes) and codes[0] != self.MOVETO:
                msg = ("The first element of 'code' must be equal to 'MOVETO':"
                       " {0}")
                raise ValueError(msg.format(self.MOVETO))
        elif closed and len(vertices):
            # Without the length guard an empty vertex array would raise
            # IndexError on codes[0]; an empty path simply stays codes=None.
            codes = np.empty(len(vertices), dtype=self.code_type)
            codes[0] = self.MOVETO
            codes[1:-1] = self.LINETO
            codes[-1] = self.CLOSEPOLY

        self._vertices = vertices
        self._codes = codes
        self._interpolation_steps = _interpolation_steps
        self._update_values()

        if readonly:
            self._vertices.flags.writeable = False
            if self._codes is not None:
                self._codes.flags.writeable = False
            self._readonly = True
        else:
            self._readonly = False

    @classmethod
    def _fast_from_codes_and_verts(cls, verts, codes, internals=None):
        """
        Creates a Path instance without the expense of calling the constructor

        Parameters
        ----------
        verts : numpy array
        codes : numpy array
        internals : dict or None
            The attributes that the resulting path should have.
            Allowed keys are ``readonly``, ``should_simplify``,
            ``simplify_threshold``, ``has_nonfinite`` and
            ``interpolation_steps``.  The caller's dict is not modified.

        Raises
        ------
        ValueError
            If *internals* contains any other key.
        """
        # Work on a copy: the keys are popped below and the caller's dict
        # must not be emptied as a side effect.
        internals = dict(internals or {})
        pth = cls.__new__(cls)
        if isinstance(verts, np.ma.MaskedArray):
            verts = verts.astype(float).filled(np.nan)
        else:
            verts = np.asarray(verts, float)
        pth._vertices = verts
        pth._codes = codes
        pth._readonly = internals.pop('readonly', False)
        pth.should_simplify = internals.pop('should_simplify', True)
        pth.simplify_threshold = (
            internals.pop('simplify_threshold',
                          rcParams['path.simplify_threshold'])
        )
        pth._has_nonfinite = internals.pop('has_nonfinite', False)
        pth._interpolation_steps = internals.pop('interpolation_steps', 1)
        if internals:
            raise ValueError('Unexpected internals provided to '
                             '_fast_from_codes_and_verts: '
                             '{0}'.format('\n *'.join(internals)))
        return pth

    def _update_values(self):
        # Simplification is only enabled for long, purely straight-line paths.
        self._should_simplify = (
            rcParams['path.simplify'] and
            (len(self._vertices) >= 128 and
             (self._codes is None or np.all(self._codes <= Path.LINETO)))
        )
        self._simplify_threshold = rcParams['path.simplify_threshold']
        self._has_nonfinite = not np.isfinite(self._vertices).all()

    @property
    def vertices(self):
        """
        The list of vertices in the `Path` as an Nx2 numpy array.
        """
        return self._vertices

    @vertices.setter
    def vertices(self, vertices):
        if self._readonly:
            raise AttributeError("Can't set vertices on a readonly Path")
        self._vertices = vertices
        self._update_values()

    @property
    def codes(self):
        """
        The list of codes in the `Path` as a 1-D numpy array.  Each
        code is one of `STOP`, `MOVETO`, `LINETO`, `CURVE3`, `CURVE4`
        or `CLOSEPOLY`.  For codes that correspond to more than one
        vertex (`CURVE3` and `CURVE4`), that code will be repeated so
        that the length of `self.vertices` and `self.codes` is always
        the same.
        """
        return self._codes

    @codes.setter
    def codes(self, codes):
        if self._readonly:
            raise AttributeError("Can't set codes on a readonly Path")
        self._codes = codes
        self._update_values()

    @property
    def simplify_threshold(self):
        """
        The fraction of a pixel difference below which vertices will
        be simplified out.
        """
        return self._simplify_threshold

    @simplify_threshold.setter
    def simplify_threshold(self, threshold):
        self._simplify_threshold = threshold

    @property
    def has_nonfinite(self):
        """
        `True` if the vertices array has nonfinite values.
        """
        return self._has_nonfinite

    @property
    def should_simplify(self):
        """
        `True` if the vertices array should be simplified.
        """
        return self._should_simplify

    @should_simplify.setter
    def should_simplify(self, should_simplify):
        self._should_simplify = should_simplify

    @property
    def readonly(self):
        """
        `True` if the `Path` is read-only.
        """
        return self._readonly

    def __copy__(self):
        """
        Returns a shallow copy of the `Path`, which will share the
        vertices and codes with the source `Path`.
        """
        # copy.copy(self) would look up this very __copy__ and recurse
        # forever, so the instance attributes are copied by hand.
        cls = self.__class__
        new = cls.__new__(cls)
        new.__dict__.update(self.__dict__)
        return new

    copy = __copy__

    def __deepcopy__(self, memo=None):
        """
        Returns a deepcopy of the `Path`.  The `Path` will not be
        readonly, even if the source `Path` is.
        """
        try:
            codes = self.codes.copy()
        except AttributeError:
            # codes is None for paths made only of line segments.
            codes = None
        return self.__class__(
            self.vertices.copy(), codes,
            _interpolation_steps=self._interpolation_steps)

    deepcopy = __deepcopy__

    @classmethod
    def make_compound_path_from_polys(cls, XY):
        """
        Make a compound path object to draw a number
        of polygons with equal numbers of sides XY is a (numpolys x
        numsides x 2) numpy array of vertices.  Return object is a
        :class:`Path`

        .. plot:: mpl_examples/api/histogram_path_demo.py

        """
        # for each poly: 1 for the MOVETO, (numsides-1) for the LINETO, 1 for
        # the CLOSEPOLY; the vert for the closepoly is ignored but we still
        # need it to keep the codes aligned with the vertices
        numpolys, numsides, two = XY.shape
        if two != 2:
            raise ValueError("The third dimension of 'XY' must be 2")
        stride = numsides + 1
        nverts = numpolys * stride
        verts = np.zeros((nverts, 2))
        codes = np.ones(nverts, int) * cls.LINETO
        codes[0::stride] = cls.MOVETO
        codes[numsides::stride] = cls.CLOSEPOLY
        for i in range(numsides):
            verts[i::stride] = XY[:, i]

        return cls(verts, codes)

    @classmethod
    def make_compound_path(cls, *args):
        """Make a compound path from a list of Path objects.

        Sub-paths without codes start with ``MOVETO`` followed by
        ``LINETO``; sub-paths with codes keep them.  Empty sub-paths
        contribute nothing.
        """
        # Handle an empty list in args (i.e. no args).
        if not args:
            return Path(np.empty([0, 2], dtype=np.float32))

        total_length = sum(len(x) for x in args)
        vertices = np.vstack([x.vertices for x in args])

        codes = np.empty(total_length, dtype=cls.code_type)
        i = 0
        for path in args:
            n = len(path.vertices)
            if n == 0:
                # Writing a MOVETO for an empty sub-path would overwrite the
                # next path's first code, or run past the end of the array
                # when the empty path comes last.
                continue
            if path.codes is None:
                codes[i] = cls.MOVETO
                codes[i + 1:i + n] = cls.LINETO
            else:
                codes[i:i + len(path.codes)] = path.codes
            i += n

        return cls(vertices, codes)

    def __repr__(self):
        return "Path(%r, %r)" % (self.vertices, self.codes)

    def __len__(self):
        return len(self.vertices)

    def iter_segments(self, transform=None, remove_nans=True, clip=None,
                      snap=False, stroke_width=1.0, simplify=None,
                      curves=True, sketch=None):
        """
        Iterates over all of the curve segments in the path.  Each
        iteration returns a 2-tuple (*vertices*, *code*), where
        *vertices* is a sequence of 1 - 3 coordinate pairs, and *code* is
        one of the :class:`Path` codes.

        Additionally, this method can provide a number of standard
        cleanups and conversions to the path.

        Parameters
        ----------
        transform : None or :class:`~matplotlib.transforms.Transform` instance
            If not None, the given affine transformation will
            be applied to the path.
        remove_nans : {False, True}, optional
            If True, will remove all NaNs from the path and
            insert MOVETO commands to skip over them.
        clip : None or sequence, optional
            If not None, must be a four-tuple (x1, y1, x2, y2)
            defining a rectangle in which to clip the path.
        snap : None or bool, optional
            If None, auto-snap to pixels, to reduce
            fuzziness of rectilinear lines.  If True, force snapping, and
            if False, don't snap.
        stroke_width : float, optional
            The width of the stroke being drawn.  Needed
             as a hint for the snapping algorithm.
        simplify : None or bool, optional
            If True, perform simplification, to remove
             vertices that do not affect the appearance of the path.  If
             False, perform no simplification.  If None, use the
             should_simplify member variable.
        curves : {True, False}, optional
            If True, curve segments will be returned as curve
            segments.  If False, all curves will be converted to line
            segments.
        sketch : None or sequence, optional
            If not None, must be a 3-tuple of the form
            (scale, length, randomness), representing the sketch
            parameters.
        """
        if not len(self):
            return

        cleaned = self.cleaned(transform=transform,
                               remove_nans=remove_nans, clip=clip,
                               snap=snap, stroke_width=stroke_width,
                               simplify=simplify, curves=curves,
                               sketch=sketch)
        vertices = cleaned.vertices
        codes = cleaned.codes
        len_vertices = vertices.shape[0]

        # Cache these object lookups for performance in the loop.
        NUM_VERTICES_FOR_CODE = self.NUM_VERTICES_FOR_CODE
        STOP = self.STOP

        i = 0
        while i < len_vertices:
            code = codes[i]
            if code == STOP:
                return
            else:
                num_vertices = NUM_VERTICES_FOR_CODE[code]
                curr_vertices = vertices[i:i + num_vertices].flatten()
                yield curr_vertices, code
                i += num_vertices

    def cleaned(self, transform=None, remove_nans=False, clip=None,
                quantize=False, simplify=False, curves=False,
                stroke_width=1.0, snap=False, sketch=None):
        """
        Cleans up the path according to the parameters returning a new
        Path instance.

        .. seealso::

            See :meth:`iter_segments` for details of the keyword arguments.

        Returns
        -------
        Path instance with cleaned up vertices and codes.

        """
        vertices, codes = _path.cleanup_path(self, transform,
                                             remove_nans, clip,
                                             snap, stroke_width,
                                             simplify, curves, sketch)
        internals = {'should_simplify': self.should_simplify and not simplify,
                     'has_nonfinite': self.has_nonfinite and not remove_nans,
                     'simplify_threshold': self.simplify_threshold,
                     'interpolation_steps': self._interpolation_steps}
        return Path._fast_from_codes_and_verts(vertices, codes, internals)

    def transformed(self, transform):
        """
        Return a transformed copy of the path.

        .. seealso::

            :class:`matplotlib.transforms.TransformedPath`
                A specialized path class that will cache the
                transformed result and automatically update when the
                transform changes.
        """
        return Path(transform.transform(self.vertices), self.codes,
                    self._interpolation_steps)

    def contains_point(self, point, transform=None, radius=0.0):
        """
        Returns *True* if the path contains the given point.

        If *transform* is not *None*, the path will be transformed
        before performing the test.

        *radius* allows the path to be made slightly larger or
        smaller.
        """
        if transform is not None:
            transform = transform.frozen()
        result = _path.point_in_path(point[0], point[1], radius, self,
                                     transform)
        return result

    def contains_points(self, points, transform=None, radius=0.0):
        """
        Returns a bool array which is *True* if the path contains the
        corresponding point.

        If *transform* is not *None*, the path will be transformed
        before performing the test.

        *radius* allows the path to be made slightly larger or
        smaller.
        """
        if transform is not None:
            transform = transform.frozen()
        result = _path.points_in_path(points, radius, self, transform)
        return result.astype('bool')

    def contains_path(self, path, transform=None):
        """
        Returns *True* if this path completely contains the given path.

        If *transform* is not *None*, the path will be transformed
        before performing the test.
        """
        if transform is not None:
            transform = transform.frozen()
        return _path.path_in_path(self, None, path, transform)

    def get_extents(self, transform=None):
        """
        Returns the extents (*xmin*, *ymin*, *xmax*, *ymax*) of the
        path.

        Unlike computing the extents on the *vertices* alone, this
        algorithm will take into account the curves and deal with
        control points appropriately.
        """
        from .transforms import Bbox
        path = self
        if transform is not None:
            transform = transform.frozen()
            if not transform.is_affine:
                # Non-affine transforms are applied eagerly; affine ones are
                # handed to the extension code.
                path = self.transformed(transform)
                transform = None
        return Bbox(_path.get_path_extents(path, transform))

    def intersects_path(self, other, filled=True):
        """
        Returns *True* if this path intersects another given path.

        *filled*, when True, treats the paths as if they were filled.
        That is, if one path completely encloses the other,
        :meth:`intersects_path` will return True.
        """
        return _path.path_intersects_path(self, other, filled)

    def intersects_bbox(self, bbox, filled=True):
        """
        Returns *True* if this path intersects a given
        :class:`~matplotlib.transforms.Bbox`.

        *filled*, when True, treats the path as if it was filled.
        That is, if one path completely encloses the other,
        :meth:`intersects_path` will return True.
        """
        from .transforms import BboxTransformTo
        rectangle = self.unit_rectangle().transformed(
            BboxTransformTo(bbox))
        result = self.intersects_path(rectangle, filled)
        return result

    def interpolated(self, steps):
        """
        Returns a new path resampled to length N x steps.  Does not
        currently handle interpolating curves.
        """
        if steps == 1:
            return self

        vertices = simple_linear_interpolation(self.vertices, steps)
        codes = self.codes
        if codes is not None:
            new_codes = Path.LINETO * np.ones(((len(codes) - 1) * steps + 1, ))
            new_codes[0::steps] = codes
        else:
            new_codes = None
        return Path(vertices, new_codes)

    def to_polygons(self, transform=None, width=0, height=0, closed_only=True):
        """
        Convert this path to a list of polygons or polylines.  Each
        polygon/polyline is an Nx2 array of vertices.  In other words,
        each polygon has no ``MOVETO`` instructions or curves.  This
        is useful for displaying in backends that do not support
        compound paths or Bezier curves, such as GDK.

        If *width* and *height* are both non-zero then the lines will
        be simplified so that vertices outside of (0, 0), (width,
        height) will be clipped.

        If *closed_only* is `True` (default), only closed polygons,
        with the last point being the same as the first point, will be
        returned.  Any unclosed polylines in the path will be
        explicitly closed.  If *closed_only* is `False`, any unclosed
        polygons in the path will be returned as unclosed polygons,
        and the closed polygons will be returned explicitly closed by
        setting the last point to the same as the first point.
        """
        if len(self.vertices) == 0:
            return []

        if transform is not None:
            transform = transform.frozen()

        if self.codes is None and (width == 0 or height == 0):
            vertices = self.vertices
            if closed_only:
                if len(vertices) < 3:
                    return []
                elif np.any(vertices[0] != vertices[-1]):
                    # Keep the result an Nx2 array (a plain list was
                    # returned here before).
                    vertices = np.concatenate([vertices, vertices[:1]])

            if transform is None:
                return [vertices]
            else:
                return [transform.transform(vertices)]

        # Deal with the case where there are curves and/or multiple
        # subpaths (using extension code)
        return _path.convert_path_to_polygons(
            self, transform, width, height, closed_only)

    _unit_rectangle = None

    @classmethod
    def unit_rectangle(cls):
        """
        Return a :class:`Path` instance of the unit rectangle
        from (0, 0) to (1, 1).
        """
        if cls._unit_rectangle is None:
            cls._unit_rectangle = \
                cls([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0],
                     [0.0, 0.0]],
                    [cls.MOVETO, cls.LINETO, cls.LINETO, cls.LINETO,
                     cls.CLOSEPOLY],
                    readonly=True)
        return cls._unit_rectangle

    _unit_regular_polygons = WeakValueDictionary()

    @classmethod
    def unit_regular_polygon(cls, numVertices):
        """
        Return a :class:`Path` instance for a unit regular
        polygon with the given *numVertices* and radius of 1.0,
        centered at (0, 0).
        """
        # Only small polygons are cached (weakly).
        if numVertices <= 16:
            path = cls._unit_regular_polygons.get(numVertices)
        else:
            path = None
        if path is None:
            theta = (2 * np.pi / numVertices *
                     np.arange(numVertices + 1).reshape((numVertices + 1, 1)))
            # This initial rotation is to make sure the polygon always
            # "points-up"
            theta += np.pi / 2.0
            verts = np.concatenate((np.cos(theta), np.sin(theta)), 1)
            codes = np.empty((numVertices + 1, ))
            codes[0] = cls.MOVETO
            codes[1:-1] = cls.LINETO
            codes[-1] = cls.CLOSEPOLY
            path = cls(verts, codes, readonly=True)
            if numVertices <= 16:
                cls._unit_regular_polygons[numVertices] = path
        return path

    _unit_regular_stars = WeakValueDictionary()

    @classmethod
    def unit_regular_star(cls, numVertices, innerCircle=0.5):
        """
        Return a :class:`Path` for a unit regular star
        with the given numVertices and radius of 1.0, centered at (0,
        0).
        """
        if numVertices <= 16:
            path = cls._unit_regular_stars.get((numVertices, innerCircle))
        else:
            path = None
        if path is None:
            ns2 = numVertices * 2
            theta = (2 * np.pi / ns2 * np.arange(ns2 + 1))
            # This initial rotation is to make sure the polygon always
            # "points-up"
            theta += np.pi / 2.0
            # Alternate between the outer radius 1 and the inner radius.
            r = np.ones(ns2 + 1)
            r[1::2] = innerCircle
            verts = np.vstack((r * np.cos(theta), r * np.sin(theta))).transpose()
            codes = np.empty((ns2 + 1, ))
            codes[0] = cls.MOVETO
            codes[1:-1] = cls.LINETO
            codes[-1] = cls.CLOSEPOLY
            path = cls(verts, codes, readonly=True)
            if numVertices <= 16:
                cls._unit_regular_stars[(numVertices, innerCircle)] = path
        return path

    @classmethod
    def unit_regular_asterisk(cls, numVertices):
        """
        Return a :class:`Path` for a unit regular
        asterisk with the given numVertices and radius of 1.0,
        centered at (0, 0).
        """
        return cls.unit_regular_star(numVertices, 0.0)

    _unit_circle = None

    @classmethod
    def unit_circle(cls):
        """
        Return the readonly :class:`Path` of the unit circle.

        For most cases, :func:`Path.circle` will be what you want.

        """
        if cls._unit_circle is None:
            cls._unit_circle = cls.circle(center=(0, 0), radius=1,
                                          readonly=True)
        return cls._unit_circle

    @classmethod
    def circle(cls, center=(0., 0.), radius=1., readonly=False):
        """
        Return a Path representing a circle of a given radius and center.

        Parameters
        ----------
        center : pair of floats
            The center of the circle. Default ``(0, 0)``.
        radius : float
            The radius of the circle. Default is 1.
        readonly : bool
            Whether the created path should have the "readonly" argument
            set when creating the Path instance.

        Notes
        -----
        The circle is approximated using cubic Bezier curves.  This
        uses 8 splines around the circle using the approach presented
        here:

          Lancaster, Don.  `Approximating a Circle or an Ellipse Using Four
          Bezier Cubic Splines <http://www.tinaja.com/glib/ellipse4.pdf>`_.

        """
        MAGIC = 0.2652031
        SQRTHALF = np.sqrt(0.5)
        MAGIC45 = np.sqrt((MAGIC * MAGIC) / 2.0)

        vertices = np.array([[0.0, -1.0],

                             [MAGIC, -1.0],
                             [SQRTHALF - MAGIC45, -SQRTHALF - MAGIC45],
                             [SQRTHALF, -SQRTHALF],

                             [SQRTHALF + MAGIC45, -SQRTHALF + MAGIC45],
                             [1.0, -MAGIC],
                             [1.0, 0.0],

                             [1.0, MAGIC],
                             [SQRTHALF + MAGIC45, SQRTHALF - MAGIC45],
                             [SQRTHALF, SQRTHALF],

                             [SQRTHALF - MAGIC45, SQRTHALF + MAGIC45],
                             [MAGIC, 1.0],
                             [0.0, 1.0],

                             [-MAGIC, 1.0],
                             [-SQRTHALF + MAGIC45, SQRTHALF + MAGIC45],
                             [-SQRTHALF, SQRTHALF],

                             [-SQRTHALF - MAGIC45, SQRTHALF - MAGIC45],
                             [-1.0, MAGIC],
                             [-1.0, 0.0],

                             [-1.0, -MAGIC],
                             [-SQRTHALF - MAGIC45, -SQRTHALF + MAGIC45],
                             [-SQRTHALF, -SQRTHALF],

                             [-SQRTHALF + MAGIC45, -SQRTHALF - MAGIC45],
                             [-MAGIC, -1.0],
                             [0.0, -1.0],

                             [0.0, -1.0]],
                            dtype=float)

        codes = [cls.CURVE4] * 26
        codes[0] = cls.MOVETO
        codes[-1] = cls.CLOSEPOLY
        return Path(vertices * radius + center, codes, readonly=readonly)

    _unit_circle_righthalf = None

    @classmethod
    def unit_circle_righthalf(cls):
        """
        Return a :class:`Path` of the right half
        of a unit circle. The circle is approximated using cubic Bezier
        curves.  This uses 4 splines around the circle using the approach
        presented here:

          Lancaster, Don.  `Approximating a Circle or an Ellipse Using Four
          Bezier Cubic Splines <http://www.tinaja.com/glib/ellipse4.pdf>`_.
        """
        if cls._unit_circle_righthalf is None:
            MAGIC = 0.2652031
            SQRTHALF = np.sqrt(0.5)
            MAGIC45 = np.sqrt((MAGIC * MAGIC) / 2.0)

            vertices = np.array(
                [[0.0, -1.0],

                 [MAGIC, -1.0],
                 [SQRTHALF - MAGIC45, -SQRTHALF - MAGIC45],
                 [SQRTHALF, -SQRTHALF],

                 [SQRTHALF + MAGIC45, -SQRTHALF + MAGIC45],
                 [1.0, -MAGIC],
                 [1.0, 0.0],

                 [1.0, MAGIC],
                 [SQRTHALF + MAGIC45, SQRTHALF - MAGIC45],
                 [SQRTHALF, SQRTHALF],

                 [SQRTHALF - MAGIC45, SQRTHALF + MAGIC45],
                 [MAGIC, 1.0],
                 [0.0, 1.0],

                 [0.0, -1.0]],

                float)

            codes = cls.CURVE4 * np.ones(14)
            codes[0] = cls.MOVETO
            codes[-1] = cls.CLOSEPOLY

            cls._unit_circle_righthalf = cls(vertices, codes, readonly=True)
        return cls._unit_circle_righthalf

    @classmethod
    def arc(cls, theta1, theta2, n=None, is_wedge=False):
        """
        Return an arc on the unit circle from angle
        *theta1* to angle *theta2* (in degrees).

        If *n* is provided, it is the number of spline segments to make.
        If *n* is not provided, the number of spline segments is
        determined based on the delta between *theta1* and *theta2*.

           Masionobe, L.  2003.  `Drawing an elliptical arc using
           polylines, quadratic or cubic Bezier curves
           <http://www.spaceroots.org/documents/ellipse/index.html>`_.

        Raises
        ------
        ValueError
            If *n* is given and less than 1.
        """
        theta1, theta2 = np.deg2rad([theta1, theta2])

        twopi = np.pi * 2.0
        halfpi = np.pi * 0.5

        eta1 = np.arctan2(np.sin(theta1), np.cos(theta1))
        eta2 = np.arctan2(np.sin(theta2), np.cos(theta2))
        # Shift eta2 so that 0 <= eta2 - eta1 < 2*pi.
        eta2 -= twopi * np.floor((eta2 - eta1) / twopi)

        # number of curve segments to make
        if n is None:
            n = int(2 ** np.ceil((eta2 - eta1) / halfpi))
        if n < 1:
            raise ValueError("n must be >= 1 or None")

        deta = (eta2 - eta1) / n
        t = np.tan(0.5 * deta)
        alpha = np.sin(deta) * (np.sqrt(4.0 + 3.0 * t * t) - 1) / 3.0

        steps = np.linspace(eta1, eta2, n + 1, True)
        cos_eta = np.cos(steps)
        sin_eta = np.sin(steps)

        xA = cos_eta[:-1]
        yA = sin_eta[:-1]
        xA_dot = -yA
        yA_dot = xA

        xB = cos_eta[1:]
        yB = sin_eta[1:]
        xB_dot = -yB
        yB_dot = xB

        if is_wedge:
            # Extra leading MOVETO (to the origin, left as zeros) / LINETO
            # and trailing LINETO / CLOSEPOLY close the wedge.
            length = n * 3 + 4
            vertices = np.zeros((length, 2), float)
            codes = cls.CURVE4 * np.ones((length, ), cls.code_type)
            vertices[1] = [xA[0], yA[0]]
            codes[0:2] = [cls.MOVETO, cls.LINETO]
            codes[-2:] = [cls.LINETO, cls.CLOSEPOLY]
            vertex_offset = 2
            end = length - 2
        else:
            length = n * 3 + 1
            vertices = np.empty((length, 2), float)
            codes = cls.CURVE4 * np.ones((length, ), cls.code_type)
            vertices[0] = [xA[0], yA[0]]
            codes[0] = cls.MOVETO
            vertex_offset = 1
            end = length

        vertices[vertex_offset:end:3, 0] = xA + alpha * xA_dot
        vertices[vertex_offset:end:3, 1] = yA + alpha * yA_dot
        vertices[vertex_offset + 1:end:3, 0] = xB - alpha * xB_dot
        vertices[vertex_offset + 1:end:3, 1] = yB - alpha * yB_dot
        vertices[vertex_offset + 2:end:3, 0] = xB
        vertices[vertex_offset + 2:end:3, 1] = yB

        return cls(vertices, codes, readonly=True)

    @classmethod
    def wedge(cls, theta1, theta2, n=None):
        """
        Return a wedge of the unit circle from angle
        *theta1* to angle *theta2* (in degrees).

        If *n* is provided, it is the number of spline segments to make.
        If *n* is not provided, the number of spline segments is
        determined based on the delta between *theta1* and *theta2*.
        """
        return cls.arc(theta1, theta2, n, True)

    _hatch_dict = maxdict(8)

    @classmethod
    def hatch(cls, hatchpattern, density=6):
        """
        Given a hatch specifier, *hatchpattern*, generates a Path that
        can be used in a repeated hatching pattern.  *density* is the
        number of lines per unit square.
        """
        from matplotlib.hatch import get_path

        if hatchpattern is None:
            return None

        hatch_path = cls._hatch_dict.get((hatchpattern, density))
        if hatch_path is not None:
            return hatch_path

        hatch_path = get_path(hatchpattern, density)
        cls._hatch_dict[(hatchpattern, density)] = hatch_path
        return hatch_path

    def clip_to_bbox(self, bbox, inside=True):
        """
        Clip the path to the given bounding box.

        The path must be made up of one or more closed polygons.  This
        algorithm will not behave correctly for unclosed paths.

        If *inside* is `True`, clip to the inside of the box, otherwise
        to the outside of the box.
        """
        # Use make_compound_path_from_polys
        verts = _path.clip_path_to_rect(self, bbox, inside)
        paths = [Path(poly) for poly in verts]
        return self.make_compound_path(*paths)
class FileYAMLReader(AbstractYAMLReader):
    """Primary reader base class that is configured by a YAML file.

    This class uses the idea of per-file "file handler" objects to read file
    contents and determine what is available in the file. This differs from
    the base :class:`AbstractYAMLReader` which does not depend on individual
    file handler objects. In almost all cases this class should be used over
    its base class and can be used as a reader by itself and requires no
    subclassing.

    """

    def __init__(self, config_files, filter_parameters=None, filter_filenames=True, **kwargs):
        """Set up initial internal storage for loading file data.

        Args:
            config_files: Reader YAML configuration file(s), passed on to the
                base class.
            filter_parameters (dict or None): Criteria (``start_time``,
                ``end_time``, ``area`` or any other metadata key) used to
                filter files and file handlers.
            filter_filenames (bool): Whether to pre-filter filenames using the
                metadata parsed from them. A ``filter_filenames`` entry in the
                reader configuration takes precedence over this argument.
            kwargs: Accepted for interface compatibility; not used here.

        """
        super(FileYAMLReader, self).__init__(config_files)

        # filetype name -> list of file handlers sorted by (start_time, filename)
        self.file_handlers = {}
        # DatasetID -> dataset info for datasets loadable from the loaded files
        self.available_ids = {}
        self.filter_filenames = self.info.get('filter_filenames', filter_filenames)
        self.filter_parameters = filter_parameters or {}
        # (coords[0].data.name, coords[1].data.name) -> SwathDefinition. Weak
        # values so swath definitions nobody uses any more can be collected.
        self.coords_cache = WeakValueDictionary()

    @property
    def sensor_names(self):
        """Names of sensors whose data is being loaded by this reader.

        Falls back to the ``sensors`` listed in the reader configuration when
        no file handlers are loaded or none of them can report sensor names.
        """
        if not self.file_handlers:
            return self.info['sensors']

        file_handlers = (handlers[0] for handlers in self.file_handlers.values())
        sensor_names = set()
        for fh in file_handlers:
            try:
                sensor_names.update(fh.sensor_names)
            except NotImplementedError:
                continue
        if not sensor_names:
            return self.info['sensors']
        return sorted(sensor_names)

    @property
    def available_dataset_ids(self):
        """Get DatasetIDs that are loadable by this reader."""
        return self.available_ids.keys()

    @property
    def start_time(self):
        """Start time of the earliest file used by this reader."""
        if not self.file_handlers:
            raise RuntimeError("Start time unknown until files are selected")
        return min(x[0].start_time for x in self.file_handlers.values())

    @property
    def end_time(self):
        """End time of the latest file used by this reader."""
        if not self.file_handlers:
            raise RuntimeError("End time unknown until files are selected")
        return max(x[-1].end_time for x in self.file_handlers.values())

    @staticmethod
    def check_file_covers_area(file_handler, check_area):
        """Check if the file covers the current area.

        If the file doesn't provide any bounding box information or 'area'
        was not provided in `filter_parameters`, the check returns True.
        """
        try:
            gbb = Boundary(*file_handler.get_bounding_box())
        except NotImplementedError as err:
            logger.debug("Bounding box computation not implemented: %s",
                         str(err))
        else:
            abb = AreaDefBoundary(get_area_def(check_area), frequency=1000)

            intersection = gbb.contour_poly.intersection(abb.contour_poly)
            if not intersection:
                return False
        return True

    def find_required_filehandlers(self, requirements, filename_info):
        """Find the necessary file handlers for the given requirements.

        We assume here requirements are available.

        Raises:
            KeyError, if no handler for the given requirements is available.
            RuntimeError, if there is a handler for the given requirements,
            but it doesn't match the filename info.

        """
        req_fh = []
        filename_info = set(filename_info.items())
        if requirements:
            for requirement in requirements:
                for fhd in self.file_handlers[requirement]:
                    # A required handler matches when every one of its
                    # filename fields appears, with the same value, in ours.
                    if set(fhd.filename_info.items()).issubset(filename_info):
                        req_fh.append(fhd)
                        break
                else:
                    # No loaded handler of this type matches this file; the
                    # caller (new_filehandler_instances) warns and skips it.
                    raise RuntimeError("No matching requirement file of type "
                                       "{}".format(requirement))
        return req_fh

    def sorted_filetype_items(self):
        """Sort the instance's filetypes in dependency order.

        Yields ``(filetype, filetype_info)`` pairs from the ``file_types``
        configuration such that every filetype named in a ``requires`` entry
        is yielded before any filetype requiring it.

        Raises:
            ValueError: if some ``requires`` entries can never be satisfied
                (unknown filetype or circular requirement).

        """
        processed_types = set()
        file_type_items = deque(self.config['file_types'].items())
        # consecutive items postponed since the last one was yielded
        deferred_in_a_row = 0
        while file_type_items:
            filetype, filetype_info = file_type_items.popleft()

            requirements = filetype_info.get('requires')
            if requirements is not None:
                # requirements have not been processed yet -> wait
                missing = [req for req in requirements if req not in processed_types]
                if missing:
                    file_type_items.append((filetype, filetype_info))
                    deferred_in_a_row += 1
                    # Every pending item has been inspected since the last
                    # yield and none was ready, so none can ever become ready.
                    if deferred_in_a_row >= len(file_type_items):
                        unresolved = {
                            ftype: [req for req in info['requires']
                                    if req not in processed_types]
                            for ftype, info in file_type_items}
                        raise ValueError("Unresolvable 'requires' in file_types "
                                         "configuration: {}".format(unresolved))
                    continue

            deferred_in_a_row = 0
            processed_types.add(filetype)
            yield filetype, filetype_info

    @staticmethod
    def filename_items_for_filetype(filenames, filetype_info):
        """Iterate over the filenames matching *filetype_info*.

        Yields ``(filename, filename_info)`` pairs, where ``filename_info`` is
        parsed with the first pattern of ``file_patterns`` that both matches
        and parses the filename. Each filename is yielded at most once.
        """
        matched_files = set()
        for pattern in filetype_info['file_patterns']:
            for filename in match_filenames(filenames, pattern):
                if filename in matched_files:
                    continue
                try:
                    filename_info = parse(
                        pattern, get_filebase(filename, pattern))
                except ValueError:
                    logger.debug("Can't parse %s with %s.", filename, pattern)
                    continue
                matched_files.add(filename)
                yield filename, filename_info

    def new_filehandler_instances(self, filetype_info, filename_items, fh_kwargs=None):
        """Generate new filehandler instances.

        Files whose required file handlers are missing or do not match are
        skipped with a warning.
        """
        requirements = filetype_info.get('requires')
        filetype_cls = filetype_info['file_reader']

        if fh_kwargs is None:
            fh_kwargs = {}

        for filename, filename_info in filename_items:
            try:
                req_fh = self.find_required_filehandlers(requirements,
                                                         filename_info)
            except KeyError as req:
                msg = "No handler for reading requirement {} for {}".format(
                    req, filename)
                warnings.warn(msg)
                continue
            except RuntimeError as err:
                warnings.warn(str(err) + ' for {}'.format(filename))
                continue

            yield filetype_cls(filename, filename_info, filetype_info, *req_fh, **fh_kwargs)

    def time_matches(self, fstart, fend):
        """Check that a file's start and end time match filter_parameters of this reader.

        A missing *fend* is replaced by *fstart*. Missing file times or
        missing filter bounds never cause a mismatch.
        """
        start_time = self.filter_parameters.get('start_time')
        end_time = self.filter_parameters.get('end_time')
        fend = fend or fstart
        if start_time and fend and fend < start_time:
            return False
        if end_time and fstart and fstart > end_time:
            return False
        return True

    def metadata_matches(self, sample_dict, file_handler=None):
        """Check that file metadata matches filter_parameters of this reader.

        Keys absent from *sample_dict* are ignored, except ``area``, which is
        checked against *file_handler*'s bounding box when a handler is given.
        """
        # special handling of start/end times
        if not self.time_matches(
                sample_dict.get('start_time'), sample_dict.get('end_time')):
            return False
        for key, val in self.filter_parameters.items():
            if key != 'area' and key not in sample_dict:
                continue

            if key in ['start_time', 'end_time']:
                continue
            elif key == 'area' and file_handler:
                if not self.check_file_covers_area(file_handler, val):
                    logger.info('Filtering out %s based on area',
                                file_handler.filename)
                    break
            elif key in sample_dict and val != sample_dict[key]:
                # don't use this file
                break
        else:
            # all the metadata keys are equal
            return True
        return False

    def filter_filenames_by_info(self, filename_items):
        """Filter out files using metadata from the filenames.

        Currently only uses start and end time. If only start time is available
        from the filename, keep all the filenames that have a start time before
        the requested end time.

        Note that ``filename_info`` dicts are updated in place: a missing
        ``start_time`` defaults to ``end_time``, and an ``end_time`` earlier
        than ``start_time`` gets the date of ``start_time``.
        """
        for filename, filename_info in filename_items:
            fend = filename_info.get('end_time')
            fstart = filename_info.setdefault('start_time', fend)
            if fend and fend < fstart:
                # correct for filenames with 1 date and 2 times
                fend = fend.replace(year=fstart.year,
                                    month=fstart.month,
                                    day=fstart.day)
                filename_info['end_time'] = fend
            if self.metadata_matches(filename_info):
                yield filename, filename_info

    def filter_fh_by_metadata(self, filehandlers):
        """Filter out filehandlers using the provided filter parameters.

        Each handler's ``metadata`` is updated in place with its
        ``start_time`` and ``end_time`` before it is checked.
        """
        for filehandler in filehandlers:
            filehandler.metadata['start_time'] = filehandler.start_time
            filehandler.metadata['end_time'] = filehandler.end_time
            if self.metadata_matches(filehandler.metadata, filehandler):
                yield filehandler

    def filter_selected_filenames(self, filenames):
        """Filter provided files based on metadata in the filename."""
        for _, filetype_info in self.sorted_filetype_items():
            filename_iter = self.filename_items_for_filetype(filenames,
                                                             filetype_info)
            if self.filter_filenames:
                filename_iter = self.filter_filenames_by_info(filename_iter)

            for fn, _ in filename_iter:
                yield fn

    def new_filehandlers_for_filetype(self, filetype_info, filenames, fh_kwargs=None):
        """Create filehandlers for a given filetype."""
        filename_iter = self.filename_items_for_filetype(filenames,
                                                         filetype_info)
        if self.filter_filenames:
            # preliminary filter of filenames based on start/end time
            # to reduce the number of files to open
            filename_iter = self.filter_filenames_by_info(filename_iter)
        filehandler_iter = self.new_filehandler_instances(filetype_info,
                                                          filename_iter,
                                                          fh_kwargs=fh_kwargs)
        filtered_iter = self.filter_fh_by_metadata(filehandler_iter)
        return list(filtered_iter)

    def create_filehandlers(self, filenames, fh_kwargs=None):
        """Organize the filenames into file types and create file handlers.

        Returns:
            dict mapping filetype name to the file handlers created by this
            call. Handlers from earlier calls stay in ``self.file_handlers``.

        """
        # de-duplicate while keeping the original order
        filenames = list(OrderedDict.fromkeys(filenames))
        logger.debug("Assigning to %s: %s", self.info['name'], filenames)

        self.info.setdefault('filenames', []).extend(filenames)
        filename_set = set(filenames)
        created_fhs = {}
        # load files that we know about by creating the file handlers
        for filetype, filetype_info in self.sorted_filetype_items():
            filehandlers = self.new_filehandlers_for_filetype(filetype_info,
                                                              filename_set,
                                                              fh_kwargs=fh_kwargs)

            # files already claimed are not offered to later filetypes
            filename_set -= {fhd.filename for fhd in filehandlers}
            if filehandlers:
                created_fhs[filetype] = filehandlers
                self.file_handlers[filetype] = sorted(
                    self.file_handlers.get(filetype, []) + filehandlers,
                    key=lambda fhd: (fhd.start_time, fhd.filename))

        # load any additional dataset IDs determined dynamically from the file
        # and update any missing metadata that only the file knows
        self.update_ds_ids_from_file_handlers()
        return created_fhs

    def _file_handlers_available_datasets(self):
        """Generate a series of available dataset information.

        This is done by chaining file handler's
        :meth:`satpy.readers.file_handlers.BaseFileHandler.available_datasets`
        together. See that method's documentation for more information.

        Returns:
            Generator of (bool, dict) where the boolean tells whether the
            current dataset is available from any of the file handlers. The
            boolean can also be None in the case where no loaded file handler
            is configured to load the dataset. The
            dictionary is the metadata provided either by the YAML
            configuration files or by the file handler itself if it is a new
            dataset. The file handler may have also supplemented or modified
            the information.

        """
        # flatten all file handlers in to one list
        flat_fhs = (fh for fhs in self.file_handlers.values() for fh in fhs)
        id_values = list(self.all_ids.values())
        configured_datasets = ((None, ds_info) for ds_info in id_values)
        for fh in flat_fhs:
            # chain the 'available_datasets' methods together by calling the
            # current file handler's method with the previous ones result
            configured_datasets = fh.available_datasets(configured_datasets=configured_datasets)
        return configured_datasets

    def update_ds_ids_from_file_handlers(self):
        """Add or modify available dataset information.

        Each file handler is consulted on whether or not it can load the
        dataset with the provided information dictionary.
        See
        :meth:`satpy.readers.file_handlers.BaseFileHandler.available_datasets`
        for more information.

        """
        avail_datasets = self._file_handlers_available_datasets()
        new_ids = {}
        for is_avail, ds_info in avail_datasets:
            # especially from the yaml config
            coordinates = ds_info.get('coordinates')
            if isinstance(coordinates, list):
                # xarray doesn't like concatenating attributes that are
                # lists: https://github.com/pydata/xarray/issues/2060
                ds_info['coordinates'] = tuple(ds_info['coordinates'])

            ds_info.setdefault('modifiers', tuple())  # default to no mods
            ds_id = DatasetID.from_dict(ds_info)
            # all datasets
            new_ids[ds_id] = ds_info
            # available datasets
            # False == we have the file type but it doesn't have this dataset
            # None == we don't have the file type object to ask
            if is_avail:
                self.available_ids[ds_id] = ds_info
        self.all_ids = new_ids

    @staticmethod
    def _load_dataset(dsid, ds_info, file_handlers, dim='y', **kwargs):
        """Load *dsid* from every handler and join the pieces along *dim*.

        Handlers returning None or raising KeyError are skipped (the KeyError
        is logged). If the first piece has no *dim* dimension, it is returned
        on its own without concatenation.

        Raises:
            KeyError: if no handler could provide the dataset.

        """
        slice_list = []
        failure = True
        for fh in file_handlers:
            try:
                projectable = fh.get_dataset(dsid, ds_info)
                if projectable is not None:
                    slice_list.append(projectable)
                    failure = False
            except KeyError:
                logger.warning("Failed to load %s from %s", str(dsid), str(fh),
                               exc_info=True)

        if failure:
            raise KeyError(
                "Could not load {} from any provided files".format(dsid))

        if dim not in slice_list[0].dims:
            return slice_list[0]
        res = xr.concat(slice_list, dim=dim)

        combined_info = file_handlers[0].combine_info(
            [p.attrs for p in slice_list])

        res.attrs = combined_info
        return res

    def _load_dataset_data(self, file_handlers, dsid, **kwargs):
        """Load *dsid* and stamp it with the overall start and end times.

        ``start_time`` comes from the first handler and ``end_time`` from the
        last one in *file_handlers*.
        """
        ds_info = self.all_ids[dsid]
        proj = self._load_dataset(dsid, ds_info, file_handlers, **kwargs)
        # FIXME: areas could be concatenated here
        # Update the metadata
        proj.attrs['start_time'] = file_handlers[0].start_time
        proj.attrs['end_time'] = file_handlers[-1].end_time
        return proj

    def _preferred_filetype(self, filetypes):
        """Get the preferred filetype out of the *filetypes* list.

        At the moment, it just returns the first filetype that has been loaded.
        """
        if not isinstance(filetypes, list):
            filetypes = [filetypes]

        # look through the file types and use the first one that we have loaded
        for filetype in filetypes:
            if filetype in self.file_handlers:
                return filetype
        return None

    def _load_area_def(self, dsid, file_handlers, **kwargs):
        """Load the area definition of *dsid*."""
        # Delegates to the module-level function of the same name.
        return _load_area_def(dsid, file_handlers)

    def _get_coordinates_for_dataset_key(self, dsid):
        """Get the coordinate dataset keys for *dsid*."""
        ds_info = self.all_ids[dsid]
        cids = []
        for cinfo in ds_info.get('coordinates', []):
            if isinstance(cinfo, dict):
                # Work on a copy: the keys set below must not leak into the
                # dataset configuration, which may be shared between datasets.
                cinfo = dict(cinfo)
            else:
                cinfo = {'name': cinfo}

            cinfo['resolution'] = ds_info['resolution']
            if 'polarization' in ds_info:
                cinfo['polarization'] = ds_info['polarization']
            cid = DatasetID(**cinfo)
            cids.append(self.get_dataset_key(cid))

        return cids

    def _get_coordinates_for_dataset_keys(self, dsids):
        """Get all coordinates."""
        coordinates = {}
        for dsid in dsids:
            cids = self._get_coordinates_for_dataset_key(dsid)
            coordinates.setdefault(dsid, []).extend(cids)
        return coordinates

    def _get_file_handlers(self, dsid):
        """Get the file handlers to load this dataset, or None if none are loaded."""
        ds_info = self.all_ids[dsid]

        filetype = self._preferred_filetype(ds_info['file_type'])
        if filetype is None:
            logger.warning("Required file type '%s' not found or loaded for "
                           "'%s'", ds_info['file_type'], dsid.name)
        else:
            return self.file_handlers[filetype]

    def _make_area_from_coords(self, coords):
        """Create an appropriate area with the given *coords*.

        Two coordinates with standard names longitude/latitude (in that order)
        give a cached SwathDefinition. No coordinates give None.

        Raises:
            ValueError: if two coordinates lack those standard names.
            NameError: for any other non-zero number of coordinates.

        """
        if len(coords) == 2:
            lon_sn = coords[0].attrs.get('standard_name')
            lat_sn = coords[1].attrs.get('standard_name')
            if lon_sn == 'longitude' and lat_sn == 'latitude':
                key = None
                try:
                    key = (coords[0].data.name, coords[1].data.name)
                    sdef = self.coords_cache.get(key)
                except AttributeError:
                    # coordinate data without a .name cannot be cached
                    sdef = None
                if sdef is None:
                    sdef = SwathDefinition(*coords)
                    sensor_str = '_'.join(self.info['sensors'])
                    shape_str = '_'.join(map(str, coords[0].shape))
                    sdef.name = "{}_{}_{}_{}".format(sensor_str, shape_str,
                                                     coords[0].attrs['name'],
                                                     coords[1].attrs['name'])
                    if key is not None:
                        self.coords_cache[key] = sdef
                return sdef
            else:
                raise ValueError(
                    'Coordinates info object missing standard_name key: ' +
                    str(coords))
        elif len(coords) != 0:
            raise NameError("Don't know what to do with coordinates " + str(
                coords))

    def _load_dataset_area(self, dsid, file_handlers, coords, **kwargs):
        """Get the area for *dsid*.

        The file handlers' area definition is preferred; when they raise
        NotImplementedError, the area is built from *coords* instead.
        """
        try:
            return self._load_area_def(dsid, file_handlers, **kwargs)
        except NotImplementedError:
            if any(x is None for x in coords):
                logger.warning(
                    "Failed to load coordinates for '%s'", str(dsid))
                return None

            area = self._make_area_from_coords(coords)
            if area is None:
                logger.debug("No coordinates found for %s", str(dsid))
            return area

    def _load_dataset_with_area(self, dsid, coords, **kwargs):
        """Load *dsid* and its area if available.

        Returns None when no file handler is loaded for the dataset or when
        loading fails with KeyError/ValueError (the error is logged).
        """
        file_handlers = self._get_file_handlers(dsid)
        if not file_handlers:
            return

        area = self._load_dataset_area(dsid, file_handlers, coords, **kwargs)

        try:
            ds = self._load_dataset_data(file_handlers, dsid, **kwargs)
        except (KeyError, ValueError) as err:
            logger.exception("Could not load dataset '%s': %s", dsid, str(err))
            return None

        if area is not None:
            ds.attrs['area'] = area
            ds = add_crs_xy_coords(ds, area)
        return ds

    def _load_ancillary_variables(self, datasets):
        """Load the ancillary variables of `datasets`.

        Each dataset's ``ancillary_variables`` attribute (a list/tuple/set of
        keys, or a whitespace-separated string) is resolved to DatasetIDs.
        IDs not yet in *datasets* are loaded into *datasets* in place. The
        attribute is finally replaced by the loaded dataset objects. Keys that
        cannot be resolved or loaded are dropped with a warning.
        """
        all_av_ids = set()
        for dataset in datasets.values():
            ancillary_variables = dataset.attrs.get('ancillary_variables', [])
            if not isinstance(ancillary_variables, (list, tuple, set)):
                # split on any whitespace so repeated spaces give no empty keys
                ancillary_variables = ancillary_variables.split()
            av_ids = []
            for key in ancillary_variables:
                try:
                    av_ids.append(self.get_dataset_key(key))
                except KeyError:
                    logger.warning("Can't load ancillary dataset %s", str(key))

            all_av_ids |= set(av_ids)
            dataset.attrs['ancillary_variables'] = av_ids
        loadable_av_ids = [av_id for av_id in all_av_ids if av_id not in datasets]
        if not all_av_ids:
            return
        if loadable_av_ids:
            self.load(loadable_av_ids, previous_datasets=datasets)

        for dataset in datasets.values():
            new_vars = []
            for av_id in dataset.attrs.get('ancillary_variables', []):
                if not isinstance(av_id, DatasetID):
                    new_vars.append(av_id)
                elif av_id in datasets:
                    new_vars.append(datasets[av_id])
                else:
                    # load() silently skips datasets that fail to load, so
                    # the ID may still be missing here.
                    logger.warning("Can't load ancillary dataset %s", str(av_id))
            dataset.attrs['ancillary_variables'] = new_vars

    def get_dataset_key(self, key, available_only=False, **kwargs):
        """Get the fully qualified `DatasetID` matching `key`.

        This will first search through available DatasetIDs, datasets that
        should be possible to load, and fallback to "known" datasets, those
        that are configured but aren't loadable from the provided files.
        Providing ``available_only=True`` will stop this fallback behavior
        and raise a ``KeyError`` exception if no available dataset is found.

        Args:
            key (str, float, DatasetID): Key to search for in this reader.
            available_only (bool): Search only loadable datasets for the
                provided key. Loadable datasets are always searched first,
                but if ``available_only=False`` (default) then all known
                datasets will be searched.
            kwargs: See :func:`satpy.readers.get_key` for more information about
                kwargs.

        Returns:
            Best matching DatasetID to the provided ``key``.

        Raises:
            KeyError: if no key match is found.

        """
        try:
            return get_key(key, self.available_ids.keys(), **kwargs)
        except KeyError:
            if available_only:
                raise
            return get_key(key, self.all_ids.keys(), **kwargs)

    def load(self, dataset_keys, previous_datasets=None, **kwargs):
        """Load `dataset_keys`.

        If `previous_datasets` is provided, do not reload those. It is updated
        in place with everything loaded by this call, including coordinates
        and ancillary variables.

        Returns:
            DatasetDict with the newly loaded datasets for `dataset_keys`.

        """
        # Compare with None: an empty DatasetDict from the caller must still
        # be filled in place rather than replaced by a fresh one.
        all_datasets = DatasetDict() if previous_datasets is None else previous_datasets
        datasets = DatasetDict()

        # Include coordinates in the list of datasets to load
        dsids = [self.get_dataset_key(ds_key) for ds_key in dataset_keys]
        coordinates = self._get_coordinates_for_dataset_keys(dsids)
        # coordinates first, so they are loaded before the datasets using them
        all_dsids = list(set().union(*coordinates.values())) + dsids
        for dsid in all_dsids:
            if dsid in all_datasets:
                continue
            coords = [all_datasets.get(cid, None)
                      for cid in coordinates.get(dsid, [])]
            ds = self._load_dataset_with_area(dsid, coords, **kwargs)
            if ds is not None:
                all_datasets[dsid] = ds
                if dsid in dsids:
                    datasets[dsid] = ds
        self._load_ancillary_variables(all_datasets)

        return datasets
class SendFiles:
    """Send a group of files to one client and report the outcome via a callback.

    The transfer is started from ``__init__``. Afterwards ``Delay`` timers are
    used to poll until the last file has left the net channel's waiting list,
    or until the time limit is reached.
    """

    # Registry of instances keyed by id(self). Because it is a
    # WeakValueDictionary, an entry disappears once nothing else holds a
    # strong reference to the instance. Scanned by on_client_disconnect.
    _send_files = WeakValueDictionary()

    def __init__(self, index, files, callback, time_multiplier=1.5, size=0):
        """Initialize the send_files.

        :param int index: Index of the receiver.
        :param tuple files: The files to be sent.
        :param callback: A callable object to be called after all transfers are complete.
        :param float time_multiplier: The multiplier for the wait time to stop the downloader.
        :param int size: Total compressed file size to skip compression.
        """
        # Module-level counter: passed to send_file and incremented once per
        # file that was queued successfully.
        global transfer_id

        self.index = index
        self.files = files
        self.callback = callback
        self.size = size

        # Only the last file is looked up in the waiting list by check_files,
        # so only it is kept (as UTF-8 bytes).
        self.file = self.files[-1].encode("utf-8")

        if not self.size:
            # No size supplied: add up per-file sizes. With
            # net_compresspackets enabled this uses compress_file()'s result
            # (presumably the compressed size, TODO confirm); otherwise it
            # uses the on-disk size.
            compress = net_compresspackets.get_bool()
            for file in files:
                path = Path(GAME_PATH / file)
                self.size += compress_file(path) if compress else path.stat().st_size

        # NOTE(review): appears to assume about 256 bytes transferred per
        # server tick; confirm against the engine's net settings.
        self.estimated_time = (self.size / 256) * server.tick_interval
        # check_files adds 1 per poll until transfer_time reaches time_limit.
        self.transfer_time = self.estimated_time
        self.time_limit = self.estimated_time * time_multiplier

        net_channel = Player(index).client.net_channel
        for file in self.files:
            if not net_channel.send_file(file, transfer_id):
                # Queuing failed: report the error via transfer_error shortly
                # after, rather than synchronously from the constructor.
                self.delay = Delay(0.1, self.transfer_error, cancel_on_level_end=True)
                break

            transfer_id += 1
        else:
            # Every file was queued (no break). Keep a raw pointer to the
            # channel for check_files, then wait the estimated time.
            self.net_channel = ctypes.c_void_p(net_channel._ptr().address)
            self.delay = Delay(self.estimated_time, self.transfer_end, cancel_on_level_end=True)

        # NOTE(review): the source's indentation was lost. This line is assumed
        # to run after the for/else (registering both the success and the
        # error path). Confirm it was not inside the `else:` branch.
        self._send_files[id(self)] = self

    def transfer_error(self):
        """Report failure: clear the pending timer and call back with False."""
        self.delay = None
        self.callback(self, False)

    def transfer_end(self):
        """Run when the estimated transfer time has elapsed.

        Refreshes ``sv_allowupload`` for this client, then starts polling.
        """
        sv_allowupload.update_index(self.index)
        self.check_files()

    def check_files(self):
        """Poll once for completion; re-arm a 1-unit Delay until the time limit."""
        self.delay = None
        if not net_chan_is_file_in_waiting_list(self.net_channel, self.file):
            # The last file is no longer waiting. Success is decided by
            # sv_allowupload membership.
            # NOTE(review): the inner if/else pairing was reconstructed from
            # unindented source; confirm.
            if self.index not in sv_allowupload:
                return self.callback(self, True)
            else:
                return self.callback(self, False)

        if self.transfer_time >= self.time_limit:
            # Still waiting but out of time: give up.
            return self.callback(self, False)

        self.transfer_time += 1
        self.delay = Delay(1, self.check_files, cancel_on_level_end=True)


@OnClientDisconnect
def on_client_disconnect(index):
    """Cancel the pending timers of every transfer to the disconnecting client."""
    for send_files in SendFiles._send_files.values():
        if (send_files.index == index and
                send_files.delay is not None):
            if send_files.delay.running:
                send_files.delay.cancel()
            send_files.delay = None
from twisted.internet import reactor
from twisted.internet.task import deferLater
from twisted.internet.defer import inlineCallbacks, returnValue
from django.conf import settings
from evennia.commands.command import InterruptCommand
from evennia.comms.channelhandler import CHANNELHANDLER
from evennia.utils import logger, utils
from evennia.utils.utils import string_suggestions

from django.utils.translation import gettext as _

# NOTE(review): `defaultdict` and `WeakValueDictionary` are used below but are
# not imported in this chunk; presumably they are imported above. Confirm.

# Read once, at import time, from the Django settings.
_IN_GAME_ERRORS = settings.IN_GAME_ERRORS

# Public names exported by `from <this module> import *`.
__all__ = ("cmdhandler", "InterruptCommand")

# Shortcut to object.__getattribute__, which bypasses any __getattribute__
# override defined on an instance's class.
_GA = object.__getattribute__
# Weak-valued cache: an entry vanishes once its value has no other strong
# references. Presumably caches merged cmdsets (per its name); confirm.
_CMDSET_MERGE_CACHE = WeakValueDictionary()

# tracks recursive calls by each caller
# to avoid infinite loops (commands calling themselves)
# (missing keys start at a count of 0)
_COMMAND_NESTING = defaultdict(lambda: 0)
_COMMAND_RECURSION_LIMIT = 10

# This decides which command parser is to be used.
# You have to restart the server for changes to take effect.
# settings.COMMAND_PARSER is split at its last "." into two arguments
# (presumably module path and attribute name) for variable_from_module.
_COMMAND_PARSER = utils.variable_from_module(
    *settings.COMMAND_PARSER.rsplit(".", 1))

# System command names - import these variables rather than trying to
# remember the actual string constants. If not defined, Evennia
# hard-coded defaults are used instead.
def __init__(self, conn):
    """Remember *conn* and start out with empty object tables.

    Args:
        conn: connection object, stored unchanged as ``self.conn``.
    """
    self.conn = conn
    # Ordinary mapping holding strong references; starts empty.
    self.objects = dict()
    # Weak-valued mapping: each entry is dropped automatically once its value
    # is no longer strongly referenced anywhere else.
    self.proxy_cache = WeakValueDictionary()