class ImmutableDict(Mapping):
    """
    Copies a dict and proxies it via types.MappingProxyType to make it immutable.

    https://stackoverflow.com/questions/9997176/immutable-dictionary-only-use-as-a-key-for-another-dictionary/39673094#39673094
    """

    def __init__(self, somedict):
        dictcopy = dict(somedict)  # make a copy
        self._dict = MappingProxyType(dictcopy)  # lock it
        self._hash = None

    def __getitem__(self, key):
        return self._dict[key]

    def __len__(self):
        return len(self._dict)

    def __iter__(self):
        return iter(self._dict)

    def __hash__(self):
        if self._hash is None:
            self._hash = hash(frozenset(self._dict.items()))
        return self._hash

    def __eq__(self, other):
        return self._dict == other._dict

    def __repr__(self):
        return str(self._dict)
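# A minimal usage sketch (not part of the original snippet): because ImmutableDict
# is hashable and compares by content, it can be used as a key in another dict.
# Assumes `Mapping` and `MappingProxyType` are imported as the class above expects.
settings = ImmutableDict({"host": "localhost", "port": 8080})
cache = {settings: "connection-pool-1"}
assert cache[ImmutableDict({"host": "localhost", "port": 8080})] == "connection-pool-1"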
def build_context(ctx: MappingProxyType = None, session=None, admin=None, testing=None, request_id=None):
    """
    :rtype: MappingProxyType
    """
    new_fields = {
        CTX_SQL_SESSION: session,
        CTX_ADMIN: admin,
        CTX_TESTING: testing,
        CTX_REQUEST_ID: request_id,
    }
    if ctx is None:
        return MappingProxyType(new_fields)

    old_fields = {k: v for k, v in ctx.items() if v is not None}  # Remove None fields.
    merged = {**new_fields, **old_fields}  # Merge old and new context, with priority for the new one.
    return MappingProxyType(merged)
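# Hedged usage sketch (not from the original source): contexts are rebuilt rather
# than mutated, so each call returns a fresh read-only MappingProxyType. The CTX_*
# keys are the constants defined elsewhere in the original module; the values here
# are placeholders.
ctx = build_context(session="db-session", admin="admin-id")
ctx = build_context(ctx=ctx, request_id="req-42")
assert ctx[CTX_REQUEST_ID] == "req-42"
assert ctx[CTX_SQL_SESSION] == "db-session"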
def helper(d, schema):
    d_type = type(d)
    schema_type = type(schema)

    # Handle lists
    if d_type == list:
        assert schema_type in JSON_ARRAY_TYPES, schema_type
        for element in d:
            helper(element, schema[0])
        return

    # Type checks and safety guards
    assert d_type == dict, d_type
    assert schema_type in JSON_MAPPING_TYPES, schema_type
    if schema_type != MappingProxyType:
        schema = MappingProxyType(schema)

    # Fill in schema keys that are missing from d
    for schema_k, schema_v in schema.items():
        if schema_k not in d:
            schema_v_type = type(schema_v)
            if schema_v_type in JSON_ARRAY_TYPES:
                d[schema_k] = []
            elif schema_v_type in JSON_MAPPING_TYPES:
                d[schema_k] = {}
                helper(d[schema_k], schema_v)
            else:
                assert schema_v_type in JSON_LEAF_TYPES, schema_v_type
                d[schema_k] = schema_v

    # Process keys of d that are present in the schema
    for k, v in d.items():
        if k in schema:
            v_type = type(v)
            if v_type in JSON_CONTAINER_TYPES:
                helper(v, schema[k])
            else:
                assert v_type in JSON_LEAF_TYPES, v_type
                schema_v = schema[k]
                schema_v_type = type(schema_v)
                assert schema_v_type in JSON_LEAF_TYPES, schema_v_type
                if v is not None and schema_v is not None:
                    if not isinstance(v, type(schema_v)):
                        if v_type == int and schema_v_type in (float, str):
                            d[k] = schema_v_type(v)
                        elif v_type == float and schema_v_type in (str, ):
                            d[k] = schema_v_type(v)
                        else:
                            assert False, "k:v == ({}, {})\nschema = {}".format(k, v, schema)
class SetEnviron:
    """A reentrant, re-usable context manager for temporarily setting environment variables."""

    __slots__ = ('__weakref__', '_kwargs', '_kwargs_old')

    _kwargs: Mapping[str, str]
    _kwargs_old: Mapping[str, Optional[str]]

    def __init__(self, **kwargs: str) -> None:
        r"""Initialize the context manager.

        Parameters
        ----------
        \**kwargs : :class:`str`
            The to-be updated parameters.

        """
        self._kwargs = MappingProxyType(kwargs)
        self._kwargs_old = MappingProxyType({
            k: os.environ.get(k) for k in self._kwargs
        })

    def __enter__(self) -> None:
        """Enter the context manager."""
        os.environ.update(self._kwargs)

    def __exit__(
        self,
        __exc_type: Optional[Type[BaseException]],
        __exc_value: Optional[BaseException],
        __traceback: Optional[TracebackType],
    ) -> None:
        """Exit the context manager."""
        for k, v in self._kwargs_old.items():
            if v is None:
                # Use `pop` instead of `del` to ensure thread-safety
                os.environ.pop(k, None)
            else:
                os.environ[k] = v
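# Hedged usage sketch (not part of the original): temporarily override an
# environment variable; the previous value (or its absence) is restored on exit.
# "MY_FLAG" is a placeholder name.
os.environ.pop("MY_FLAG", None)
with SetEnviron(MY_FLAG="1"):
    assert os.environ["MY_FLAG"] == "1"
assert "MY_FLAG" not in os.environ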
def _make_jitclass_for_type(base, nbtype):
    # Note: the function header and the final return of the lower-triangle branch
    # are not in the original excerpt; they are inferred from the call sites below.
    spec = [
        ('array', nbtype[:]),
        ('arraysize', nb.intp),
        ('size', nb.intp),
        ('upper', nb.boolean),
        ('diag_val', nbtype)
    ]
    return nb.jitclass(spec)(base)


# Jitclasses of NbTriLMatrixBase for each supported array data type
_LOWER_JITCLASS_BY_TYPE = MappingProxyType({
    dtype: _make_jitclass_for_type(NbTriLMatrixBase, nbtype)
    for dtype, nbtype in NUMBA_TYPES.items()
})

# Jitclasses of NbTriUMatrixBase for each supported array data type
_UPPER_JITCLASS_BY_TYPE = MappingProxyType({
    dtype: _make_jitclass_for_type(NbTriUMatrixBase, nbtype)
    for dtype, nbtype in NUMBA_TYPES.items()
})


def _get_jitclass_for_dtype(dtype, upper):
    """
    Get the correct jitclass of NbTriMatrixBase for the data type and triangle.
    """
    if upper:
        return _UPPER_JITCLASS_BY_TYPE[dtype]
    return _LOWER_JITCLASS_BY_TYPE[dtype]
class Network:
    """Maintains connections over the raft network as well as queues for messages"""

    def __init__(self, peer_id):
        self.peer_id = peer_id
        self._active_connections = {}
        self.inbox = asyncio.Queue()
        self._active_ids = self._active_connections.keys()
        self._all_ids = frozenset(config.SERVERS) - {peer_id}
        # have a queue per peer in the network
        self.outbox = MappingProxyType(
            {i: asyncio.Queue() for i in self._all_ids})

    def send(self, dest: int, msg):
        self.outbox[dest].put_nowait(msg)

    async def _send_messages(self):
        suppresser = contextlib.suppress(KeyError)
        for peer, queue in itertools.cycle(self.outbox.items()):
            try:
                msg = queue.get_nowait()
            except asyncio.QueueEmpty:
                # no message here, come back later
                await asyncio.sleep(.1)
            else:
                queue.task_done()
                with suppresser:
                    reader, writer = self._active_connections[peer]
                    logger.debug(f"Sending to {peer=} {msg=}")
                    await transport.send(writer, msg)

    @cached_property
    def peers(self) -> frozenset:
        return self._all_ids

    async def recv(self) -> bytes:
        return await self.inbox.get()

    async def _on_connetion(self, reader, writer):
        logger.info(f"Handling client: {writer.get_extra_info('peername')}")
        connected_peer = await transport.recv(
            reader)  # first message is always peer_id
        logger.debug(f"Starting to receive from {connected_peer=}")
        with contextlib.suppress(asyncio.IncompleteReadError):
            while message := await transport.recv(reader):
                await self.inbox.put(message)
        logger.warning(
            f"Broken peer {connected_peer}. Closing connection and clearing queue."
        )
        writer.close()
        await writer.wait_closed()
        # remove all that was going to be sent
        peer_q = self.outbox[connected_peer]
        while peer_q.qsize():
            peer_q.get_nowait()
            peer_q.task_done()
        logger.warning(
            f"Queue {peer_q} empty. Resurrecting connection for {connected_peer}"
        )
        del self._active_connections[connected_peer]
        awaitable = self._connect(connected_peer)
        asyncio.create_task(
            awaitable)  # resurrect connection and hope for the best
class Snapshot:
    """
    An immutable snapshot of contract data.
    """

    def __init__(self, creates: Mapping[ContractId, CreateEvent], offset: Optional[str]):
        self._offset = offset
        self._creates = MappingProxyType(dict(creates))
        self._contracts = ContractDataView(self._creates)
        self._template_ids = {
            cid.value_type: matching_normalizations(cid.value_type)
            for cid in self._creates
        }

    @property
    def offset(self) -> Optional[str]:
        """
        The offset point in a stream that contains all of the creates without offsetting archives
        in this snapshot.

        This value is ``None`` when connecting to a ledger that is completely empty.
        """
        return self._offset

    def earliest_contract(
            self,
            template_id: Union[str, TypeConName],
            match: Optional[ContractMatch] = None) -> Optional[ContractData]:
        """
        Return the earliest contract in the Active Contract Set (in other words, _still_ active)
        that was created in the transaction stream that matches the specified filter, or ``None``
        if there are no matches.
        """
        ev = self.earliest_create(template_id, match)
        return ev.payload if ev is not None else None

    def matching_contracts(
            self,
            template_id: Union[str, TypeConName],
            match: Optional[ContractMatch] = None) -> Mapping[ContractId, ContractData]:
        """"""
        return ContractDataView(self.matching_creates(template_id, match))

    def latest_contract(
            self,
            template_id: Union[str, TypeConName],
            match: Optional[ContractMatch] = None) -> Optional[ContractData]:
        """
        Return the contract that was created last in the transaction stream that matches the
        specified filter, or ``None`` if there are no matches.
        """
        ev = self.latest_create(template_id, match)
        return ev.payload if ev is not None else None

    def earliest_create(
            self,
            template_id: Union[str, TypeConName],
            match: Optional[ContractMatch] = None) -> Optional[CreateEvent]:
        """
        Return the earliest :class:`CreateEvent` in the Active Contract Set (in other words, the
        corresponding contract is _still_ active) that was created in the transaction stream that
        matches the specified filter, or ``None`` if there are no matches.
        """
        wanted_template_ids = self._matching_template_ids(template_id)
        # in Python 3.9, dict values can be `reversed`; sadly we're not there yet
        for cev in reversed(list(self._creates.values())):
            if cev.contract_id.value_type in wanted_template_ids and is_match(match, cev.payload):
                return cev
        return None

    def matching_creates(
            self,
            template_id: Union[str, TypeConName],
            match: Optional[ContractMatch] = None) -> Mapping[ContractId, CreateEvent]:
        """
        Return the :class:`CreateEvent`s (indexed by :class:`ContractId`) whose contracts match
        the specified criteria.
        """
        wanted_template_ids = self._matching_template_ids(template_id)
        matches = {}
        for cid, cev in self._creates.items():
            if cev.contract_id.value_type in wanted_template_ids and is_match(match, cev.payload):
                matches[cev.contract_id] = cev
        return MappingProxyType(matches)

    def latest_create(
            self,
            template_id: Union[str, TypeConName],
            match: Optional[ContractMatch] = None) -> Optional[CreateEvent]:
        """
        Return the contract that was created last in the transaction stream that matches the
        specified filter, or ``None`` if there are no matches.
        """
        wanted_template_ids = self._matching_template_ids(template_id)
        for cev in self._creates.values():
            if cev.contract_id.value_type in wanted_template_ids and is_match(match, cev.payload):
                return cev
        return None

    def _matching_template_ids(self, template_id: Union[str, TypeConName]) -> Collection[TypeConName]:
        t = normalize(template_id)
        return {tid for tid, matches in self._template_ids.items() if t in matches}

    @property
    def contracts(self) -> Mapping[ContractId, ContractData]:
        """
        A read-only map of contract IDs to contract data.
        """
        return self._contracts

    @property
    def creates(self) -> Mapping[ContractId, CreateEvent]:
        """
        A map of contract IDs to :class:`CreateEvent`s that represent the current set of contracts
        in this ACS.

        :class:`CreateEvent`'s expose additional information, including signatories and observers.
        """
        return self._creates

    def __bool__(self) -> bool:
        return bool(self._creates)

    def __len__(self) -> int:
        return len(self._creates)

    def __repr__(self) -> str:
        return f"Snapshot(len={len(self)})"
class DRV(object): """ A discrete random variable. A DRV has one or more :dfn:`possible values` (or just :dfn:`values`), which can be any type. Each possible value has an associated :dfn:`probability`, which is a real number between 0 and 1. It is strongly recommended that the probabilities add up to exactly 1. This might be difficult to achieve with :obj:`float` probabilities, and so this class does not enforce that restriction, and makes it possible to sample a variable even if the total is not 1. The exact distribution of the samples in that case is not specified, only that it will attempt to follow the probabilities given. Loosely: if the total is too low then one value's probability is rounded up. If the total is too high, then one probability is rounded down, and/or one or more values is ignored. These adjustments apply only to sampling: the original probabilities are still reported by :func:`to_dict()` etc. Because :code:`==` is overridden to return a DRV (not a boolean), DRV objects are not hashable and cannot be used in a set or as a dictionary key, even though the objects are immutable. This means you cannot have a DRV as a "possible value" of another DRV. DRV also resists being considered in boolean context, so for example you cannot in general test whether or not a DRV appears in a list:: >>> from omnidice.dice import d3, d6 >>> d3 in [d3, d6] True >>> d6 in [d3, d6] Traceback (most recent call last): File "<stdin>", line 1, in <module> File "omnidice/drv.py", line 452, in __bool__ raise ValueError('The truth value of a random variable is ambiguous') ValueError: The truth value of a random variable is ambiguous This is the same solution used by (for example) :obj:`numpy.array`. If the object allowed standard boolean conversion then :code:`d4 in [d3, d6]` would be True, which is unacceptably surprising! :param distribution: Any value from which a dictionary can be constructed, that is a :obj:`Mapping` or :obj:`Iterable` of (value, probability) pairs. :param tree: The expression from which this object was defined. Currently this is used only for the string representation, but might in future help support lazily-evaluated DRVs. """ def __init__( self, distribution: 'DictData', *, tree: ExpressionTree = None, ): self.__dist = MappingProxyType(dict(distribution)) # Cumulative distribution. Defer calculating this, because we only # need it if the variable is actually sampled. Intermediate values in # building up a complex DRV won't ever be sampled, so save the work. self.__cdf = None self.__lcm = None self.__intvalued = None self.__expr_tree = tree # Computed probabilities can hit 0 due to float underflow, but maybe # we should strip out anything with probability 0. if not all(0 <= prob <= 1 for value, prob in self._items()): raise ValueError('Probability not in range') def __repr__(self): if self.__expr_tree is not None: return self.__expr_tree.bracketed() return f'DRV({self.__dist})' def is_same(self, other: 'DRV') -> bool: """ Return True if `self` and `other` have the same discrete probability distribution. Possible values with 0 probability are excluded from the comparison. 
""" values = set(value for value, prob in self._items() if prob != 0) othervalues = set(value for value, prob in other._items() if prob != 0) if values != othervalues: return False return all(self.__dist[val] == other.__dist[val] for val in values) def is_close(self, other: 'DRV', *, rel_tol=None, abs_tol=None) -> bool: """ Return True if `self` and `other` have approximately the same discrete probability distribution, within the specified tolerances. Possible values with 0 probability are excluded from the comparison. `rel_tol` and `abs_tol` are applied only to the probabilities, not to the possible values. They are defined as for :func:`math.isclose`. """ values = set(value for value, prob in self._items() if prob != 0) othervalues = set(value for value, prob in other._items() if prob != 0) if values != othervalues: return False kwargs = {} if rel_tol is not None: kwargs['rel_tol'] = rel_tol if abs_tol is not None: kwargs['abs_tol'] = abs_tol return all( isclose(self.__dist[val], other.__dist[val], **kwargs) for val in values ) def to_dict(self) -> Dict[Any, 'Probability']: """ Return a dictionary mapping all possible values to probabilities. """ # dict(self.__dist) is type-correct, but about 3 times slower. # Unfortunately there's no way to parameterise MappingProxyType to # say what the type is of the underlying mapping that gets copied. return self.__dist.copy() # type: ignore def to_pd(self): """ Return a :class:`pandas.Series` mapping values to probabilities. The series is indexed by the possible values. :raises: :class:`ModuleNotFoundError` if pandas is not installed. Note that pandas is not a hard dependency of this package. You must install it to use this method. """ try: import pandas as pd except ModuleNotFoundError: msg = 'You must install pandas for this optional feature' raise ModuleNotFoundError(msg) return pd.Series(self.__dist, name='probability') def to_table(self, as_float: bool = False) -> str: """ Return a string containing the values and probabilities formatted as a table. This is intended only for manually checking small distributions. :param as_float: Display probabilites as floating-point. You might find floats easier to read by eye. """ if not as_float: items = self._items() else: items = ((v, float(p)) for v, p in self._items()) with contextlib.suppress(TypeError): items = sorted(items) return '\n'.join([ 'value\tprobability', *(f'{v}\t{p}' for v, p in items), ]) def faster(self) -> 'DRV': """ Return a new DRV, with all probabilities converted to float. """ return DRV( {x: float(y) for x, y in self._items()}, tree=self._combine_post('.faster()'), ) def _items(self): return self.__dist.items() def replace_tree(self, tree: ExpressionTree) -> 'DRV': """ Return a new DRV with the same distribution as this DRV, but defined from the specified expression. This is used for example when some optimisation has computed a DRV one way, but we want to represent it the original way. 
""" return DRV(self.__dist, tree=tree) @property def cdf(self): if self.__cdf is None: def iter_totals(): total = 0 for value, probability in self._items(): total += probability yield value, total # In case of rounding errors if total < 1: yield value, 1 self.__cdf_values, self.__cdf = map(tuple, zip(*iter_totals())) return self.__cdf @property def _lcm(self): def lcm(a, b): return (a * b) // gcd(a, b) if self.__lcm is None: result = 1 for _, prob in self._items(): if not isinstance(prob, Fraction): result = 0 break result = lcm(prob.denominator, result) self.__lcm = result return self.__lcm def sample(self, random: Random = rng): """ Sample this variable. :param random: Random number generator to use. The default is a single object shared by all instances of :class:`DRV`. :returns: One possible value of this variable. """ sample: Probability if self._lcm == 0: sample = random.random() else: sample = Fraction(random.randrange(self._lcm) + 1, self._lcm) # The index of the first cumulative probability greater than or equal # to our random sample. If there's a repeated probability in the array, # that means there was a value with probability 0. So we don't want to # select that value even in the very unlikely case of our sample being # exactly equal to the repeated probability! idx = bisect_left(self.cdf, sample) return self.__cdf_values[idx] @property def _intvalued(self): if self.__intvalued is None: self.__intvalued = all(isinstance(x, int) for x in self.__dist) return self.__intvalued def __add__(self, right) -> 'DRV': """ Handler for :code:`self + right`. Return a random variable which is the result of adding this variable to `right`. `right` can be either a constant or another DRV (in which case the result assumes that the two random variables are independent). As with :meth:`apply()`, probabilities are added up wherever addition is many-to-one (for constant numbers it is one-to-one provided overflow does not occur). """ while CONVOLVE_OPTIMISATION: if np is None: break if not isinstance(right, DRV): break product_size = len(self.__dist) * len(right.__dist) if product_size <= CONVOLVE_SIZE_LIMIT: break if not self._intvalued or not right._intvalued: break def get_range(dist): return range(min(dist), max(dist) + 1) self_values = get_range(self.__dist) right_values = get_range(right.__dist) # Very sparse arrays aren't faster to convolve. if 100 * product_size <= len(self_values) * len(right_values): break final_probs = np.convolve( np.array(tuple(self.__dist.get(x, 0) for x in self_values)), np.array(tuple(right.__dist.get(x, 0) for x in right_values)), ) values = range( min(self_values) + min(right_values), max(self_values) + max(right_values) + 1, ) filtered = (final_probs > 0) values = np.array(values)[filtered].tolist() final_probs = final_probs[filtered] return DRV( zip(values, final_probs), tree=self._combine(self, right, '+'), ) return self._apply2(operator.add, right, connective='+') def __sub__(self, right) -> 'DRV': """ Handler for :code:`self - right`. Return a random variable which is the result of subtracting `right` from this variable. `right` can be either a constant or another DRV (in which case the result assumes that the two random variables are independent). As with :meth:`apply()`, probabilities are added up wherever subtraction is many-to-one (for constant numbers it is one-to-one provided overflow does not occur). 
""" if isinstance(right, DRV): # So that we get the convolve optimisation tree = self._combine(self, right, '-') return (self + -right).replace_tree(tree) else: return self._apply2(operator.sub, right, connective='-') def __mul__(self, right): """ Handler for :code:`self * right`. Return a random variable which is the result of multiplying this variable with `right`. `right` can be either a constant or another DRV (in which case the result assumes that the two random variables are independent). As with :meth:`apply()`, probabilities are added up in the case where multiplication is not one-to-one (for constant numbers other than zero it is one-to-one provided overflow and underflow do not occur). """ return self._apply2(operator.mul, right, connective='*') def __rmatmul__(self, left: int) -> 'DRV': """ Handler for :code:`left @ self`. Return a random variable which is the result of sampling this variable `left` times, and adding the results together. """ if not isinstance(left, int): return NotImplemented if left <= 0: raise ValueError(left) # Exponentiation by squaring. This isn't massively faster, but does # help a bit for hundreds of dice. result = None so_far = self original = left while True: if left % 2 == 1: if result is None: result = so_far else: result += so_far left //= 2 if left == 0: break so_far += so_far # left was non-zero, so result cannot still be None result = cast(DRV, result) return result.replace_tree(self._combine(original, self, '@')) def __matmul__(self, right: 'DRV') -> 'DRV': """ Handler for :code:`self @ right`. Return a random variable which is the result of sampling this variable once, then adding together that many samples of `right`. All possible values of this variable must be of type :obj:`int`. """ if not isinstance(right, DRV): return NotImplemented if not all(isinstance(value, int) for value in self.__dist): raise TypeError('require integers on LHS of @') def iter_drvs(): so_far = min(self.__dist) @ right for num_dice in range(min(self.__dist), max(self.__dist) + 1): if num_dice in self.__dist: yield so_far, self.__dist[num_dice] so_far += right return DRV.weighted_average( iter_drvs(), tree=self._combine(self, right, '@'), ) def __truediv__(self, right) -> 'DRV': """ Handler for :code:`self / right`. Return a random variable which is the result of floor-dividing this variable by `right`. `right` can be either a constant or another DRV (in which case the result assumes that the two random variables are independent). As with :meth:`apply()`, probabilities are added up wherever division is many-to-one (for constant numbers other than zero it is one-to-one provided overflow and underflow do not occur). 0 must not be a possible value of `right` (even with probability 0). """ return self._apply2(operator.truediv, right, connective='/') def __floordiv__(self, right) -> 'DRV': """ Handler for :code:`self // right`. Return a random variable which is the result of dividing this variable by `right`. `right` can be either a constant or another DRV (in which case the result assumes that the two random variables are independent). As with :meth:`apply()`, probabilities are added up wherever floor division is many-to-one (for numbers it is mostly many-to-one, for example :code:`2 // 2 == 1 == 3 // 2`). 0 must not be a possible value of `right` (even with probability 0). """ return self._apply2(operator.floordiv, right, connective='//') def __neg__(self) -> 'DRV': """ Handler for :code:`-self`. 
Return a random variable which is the result of negating the values of this variable. As with :meth:`apply()`, probabilities are added up wherever negation is many-to-one (for numbers it is one-to-one). """ return self.apply(operator.neg, tree=self._combine(self, '-')) def __eq__(self, right) -> 'DRV': # type: ignore[override] """ Handler for :code:`self == right`. Return a random variable which takes value :obj:`True` where `self` is equal to `right`, and :obj:`False` otherwise. `right` can be either a constant or another DRV (in which case the result assumes that the two random variables are independent). If either :obj:`True` or :obj:`False` cannot happen then the result has only one possible value, with probability 1. There is no possible value with probability 0. """ if isinstance(right, DRV): small, big = sorted([self, right], key=lambda x: len(x.__dist)) prob = sum( prob * big.__dist.get(val, 0) for val, prob in small._items() ) else: prob = self.__dist.get(right) if not prob: return DRV({False: 1}) if prob >= 1.0: return DRV({True: 1}) return DRV( {False: 1 - prob, True: prob}, tree=self._combine(self, right, '=='), ) def __ne__(self, right: 'DRV') -> 'DRV': # type: ignore[override] """ Handler for :code:`self != right`. Return a random variable which takes value :obj:`True` where `self` is not equal to `right`, and :obj:`False` otherwise. `right` can be either a constant or another DRV (in which case the result assumes that the two random variables are independent). If either :obj:`True` or :obj:`False` cannot happen then the result has only one possible value, with probability 1. There is no possible value with probability 0. """ return ( (self == right) .apply(operator.not_) .replace_tree(self._combine(self, right, '!=')) ) def __bool__(self): # Prevent DRVs being truthy, and hence "3 in [DRV({2: 1})]" is true. raise ValueError('The truth value of a random variable is ambiguous') def __le__(self, right) -> 'DRV': """ Handler for :code:`self <= right`. Return a random variable which takes value :obj:`True` where `self` is less than or equal to `right`, and :obj:`False` otherwise. `right` can be either a constant or another DRV (in which case the result assumes that the two random variables are independent). If either :obj:`True` or :obj:`False` cannot happen then the result has only one possible value, with probability 1. There is no possible value with probability 0. """ return self._apply2(operator.le, right, connective='<=') def __lt__(self, right) -> 'DRV': """ Handler for :code:`self < right`. Return a random variable which takes value :obj:`True` where `self` is less than `right`, and :obj:`False` otherwise. `right` can be either a constant or another DRV (in which case the result assumes that the two random variables are independent). If either :obj:`True` or :obj:`False` cannot happen then the result has only one possible value, with probability 1. There is no possible value with probability 0. """ return self._apply2(operator.lt, right, connective='<') def __ge__(self, right) -> 'DRV': """ Handler for :code:`self >= right`. Return a random variable which takes value :obj:`True` where `self` is greater than or equal to `right`, and :obj:`False` otherwise. `right` can be either a constant or another DRV (in which case the result assumes that the two random variables are independent). If either :obj:`True` or :obj:`False` cannot happen then the result has only one possible value, with probability 1. There is no possible value with probability 0. 
""" return self._apply2(operator.ge, right, connective='>=') def __gt__(self, right) -> 'DRV': """ Handler for :code:`self > right`. Return a random variable which takes value :obj:`True` where `self` is greater than `right`, and :obj:`False` otherwise. `right` can be either a constant or another DRV (in which case the result assumes that the two random variables are independent). If either :obj:`True` or :obj:`False` cannot happen then the result has only one possible value, with probability 1. There is no possible value with probability 0. """ return self._apply2(operator.gt, right, connective='>') def explode(self, rerolls: int = 50) -> 'DRV': """ Return a new DRV distributed according to the rules of an "exploding die". This means, first roll the die (sample this DRV). If the result is not the maximum possible, then keep it. If it is the maximum, then roll again and add the new result to the original. Because DRV represents only finitely-many possible values, whereas the process of rerolling can (with minuscule probability) carry on indefinitely, this method imposes an arbitary limit to the number of rerolls. :param rerolls: The maximum number of rerolls. Set this to 1 for a die that can only "explode" once, not indefinitely. """ reroll_value = max(self.__dist.keys()) reroll_prob = self.__dist[reroll_value] each_die = self.to_dict() each_die.pop(reroll_value) def iter_pairs(): for idx in range(rerolls + 1): for value, prob in each_die.items(): value += reroll_value * idx prob *= reroll_prob ** idx yield (value, prob) yield (reroll_value * (idx + 1), reroll_prob ** (idx + 1)) postfix = '.explode()' if rerolls == 50 else f'.explode({rerolls!r})' return self._reduced(iter_pairs(), tree=self._combine_post(postfix)) def apply( self, func: Callable[[Any], Any], *, tree: ExpressionTree = None, allow_drv: bool = False, ) -> 'DRV': """ Apply a unary function to the values produced by this DRV. If `func` is an injective (one-to-one) function, then the probabilities are unchanged. If `func` is many-to-one, then the probabilities are added together. :param func: Function to map the values. Each value `x` is replaced by `func(x)`. :param tree: the expression from which this object was defined. If ``None``, the result DRV is represented by listing out all the values and probabilities. :param allow_drv: If True, then when `func` returns a DRV, the possible values of that DRV are each included in the returned DRV. Recall that a DRV cannot be a possible value of the returned DRV, because it is not hashable. So, without this option `func` cannot return a DRV. .. versionchanged:: 1.1 Added ``allow_drv`` option. """ return DRV._reduced(self._items(), func, tree=tree, drv=allow_drv) def _apply2(self, func, right, connective=None) -> 'DRV': """Apply a binary function, with the values of this DRV on the left.""" expr_tree = self._combine(self, right, connective) if isinstance(right, DRV): return self._cross_reduce(func, right, tree=expr_tree) return self.apply(lambda x: func(x, right), tree=expr_tree) def _cross_reduce(self, func, right, tree=None) -> 'DRV': """ Take the cross product of self and right, then reduce by applying func. """ return DRV._reduced( self._iter_cross(right), lambda value: func(*value), tree=tree, ) def _iter_cross(self, right): """ Take the cross product of self and right, with probabilities assuming that the two are independent variables. 
Note that the cross product of an object with itself represents the results of sampling it twice, *not* just the pairs (x, x) for each possible value! """ for (lvalue, lprob) in self._items(): for (rvalue, rprob) in right._items(): yield ((lvalue, rvalue), lprob * rprob) @staticmethod def _reduced(iterable, func=lambda x: x, tree=None, drv=False) -> 'DRV': distribution: dict = collections.defaultdict(int) if not drv: # Optimisation does make a difference to e.g. test_convolve for value, prob in iterable: distribution[func(value)] += prob else: for value, prob in iterable: transformed = func(value) if isinstance(transformed, DRV): for value2, prob2 in transformed._weighted_items(prob): distribution[value2] += prob2 else: distribution[transformed] += prob return DRV(distribution, tree=tree) @staticmethod def weighted_average( iterable: Iterable[Tuple['DRV', 'Probability']], *, tree: ExpressionTree = None, ) -> 'DRV': """ Compute a weighted average of DRVs, each with its own probability. This is for when you have a set of mutually-exclusive events which can happen, and then the final outcome occurs with a different known distribution according to which of those events occurs. For example, this function is used to implement the ``@`` operator when the left-hand-side is a DRV. The first roll determines what the second roll will be. The DRVs that are averaged together do not need to be disjoint (that is, they can have overlapping possible values). Whenever multiple events lead to the same final outcome, the probabilities are combined: https://en.wikipedia.org/wiki/Law_of_total_probability :param iterable: Pairs, each containing a DRV and the probability of that DRV being the one selected. The probabilities should add to 1, but this is not enforced. :param tree: the expression from which this object was defined. If ``None``, the result DRV is represented by listing out all the values and probabilities. .. versionadded:: 1.1 """ def iter_pairs(): for drv, weight in iterable: yield from drv._weighted_items(weight) return DRV._reduced(iter_pairs(), tree=tree) def _weighted_items(self, weight, pred=lambda x: True): for value, prob in self.__dist.items(): if pred(value): yield value, prob * weight def given(self, predicate: Callable[[Any], bool]) -> 'DRV': """ Return the conditional probability distribution of this DRV, restricted to the possible values for which `predicate` is true. For example, :code:`drv.given(lambda x: True)` is the same distribution as :code:`drv`, and the following are equivalent to each other:: d6.given(lambda x: bool(x % 2)) DRV({1: Fraction(1, 3), 3: Fraction(1, 3), 5: Fraction(1, 3)}) If `x` is a DRV, and `A` and `B` are predicates, then the conditional probability of `A` given `B`, written in probability theory as ``p(A(x) | B(x))``, can be computed as :code:`p(x.given(B).apply(A)))`. :param predicate: Called with possible values of `self`, and must return :obj:`bool` (not just truthy). :raises ZeroDivisionError: if the probability of `predicate` being true is 0. .. versionadded:: 1.1 """ total = p(self.apply(predicate)) if total == 0: # Would be raised anyway, but nicer error message raise ZeroDivisionError('predicate is True with probability 0') return DRV(self._weighted_items(1 / total, predicate)) @staticmethod def _combine(*args): """ Helper for combining two expressions into a combined expression. 
""" for arg in args: if isinstance(arg, DRV) and arg.__expr_tree is None: return None def unpack(subexpr): if isinstance(subexpr, DRV): return subexpr.__expr_tree return Atom(repr(subexpr)) if len(args) == 2: # Unary expression subexpr, connective = args return UnaryExpression(unpack(subexpr), connective) # Binary expression left, right, connective = args return BinaryExpression(unpack(left), unpack(right), connective) def _combine_post(self, postfix): if self.__expr_tree is None: return None return AttrExpression(self.__expr_tree, postfix)
""" Postponed Evaluation of Annotations Becomes Default """ # int.bit_count() num1 = 1 print("num1 = {}".format(num1)) print("num1.bit_count() => {}".format(num1.bit_count())) num1 = 10 print("num1 = {}".format(num1)) print("num1.bit_count() => {}".format(num1.bit_count())) num1 = 100 print("num1 = {}".format(num1)) print("num1.bit_count() => {}".format(num1.bit_count())) from types import MappingProxyType dict1 = dict() mappingProxy = MappingProxyType(dict1) print(type(dict1.items()) is type(mappingProxy.items())) print(mappingProxy.items().mapping is dict1) # zip(*iterables, strict=False) list1 = list(zip(("a", "b", "c"), (1, 2, 3), strict=True)) print("list1 = list(zip(('a', 'b', 'c'), (1, 2, 3), strict=True))") print("list1 => {}".format(list1))
class Writer(object): def __init__( self, filename, workdir=None, encoding=UTF8, compression=DEFAULT_COMPRESSION, min_bin_size=512 * 1024, max_redirects=5, observer=None, ): self.filename = filename self.observer = observer if os.path.exists(self.filename): raise SystemExit('File %r already exists' % self.filename) # make sure we can write with fopen(self.filename, 'wb'): pass self.encoding = encoding if encodings.search_function(self.encoding) is None: raise UnknownEncoding(self.encoding) self.workdir = workdir self.tmpdir = tmpdir = tempfile.TemporaryDirectory( prefix='{0}-'.format(os.path.basename(filename)), dir=workdir) self.f_ref_positions = self._wbfopen('ref-positions') self.f_store_positions = self._wbfopen('store-positions') self.f_refs = self._wbfopen('refs') self.f_store = self._wbfopen('store') self.max_redirects = max_redirects if max_redirects: self.aliases_path = os.path.join(tmpdir.name, 'aliases') self.f_aliases = Writer( self.aliases_path, workdir=tmpdir.name, max_redirects=0, compression=None, ) if compression is None: compression = '' if compression not in COMPRESSIONS: raise UnknownCompression(compression) else: self.compress = COMPRESSIONS[compression].compress self.compression = compression self.content_types = {} self.min_bin_size = min_bin_size self.current_bin = None self.blob_count = 0 self.ref_count = 0 self.bin_count = 0 self._tags = { 'version.python': sys.version.replace('\n', ' '), 'version.pyicu': icu.VERSION, 'version.icu': icu.ICU_VERSION, 'created.at': datetime.now(timezone.utc).isoformat() } self.tags = MappingProxyType(self._tags) def _wbfopen(self, name): return StructWriter(fopen(os.path.join(self.tmpdir.name, name), 'wb'), encoding=self.encoding) def tag(self, name, value=''): if len(name.encode(self.encoding)) > MAX_TINY_TEXT_LEN: self._fire_event('tag_name_too_long', (name, value)) return if len(value.encode(self.encoding)) > MAX_TINY_TEXT_LEN: self._fire_event('tag_value_too_long', (name, value)) value = '' self._tags[name] = value def _split_key(self, key): if isinstance(key, str): actual_key = key fragment = '' else: actual_key, fragment = key if len(actual_key) > MAX_TEXT_LEN or len(fragment) > MAX_TINY_TEXT_LEN: raise KeyTooLongException(key) return actual_key, fragment def add(self, blob, *keys, content_type=''): if len(blob) > MAX_LARGE_BYTE_STRING_LEN: self._fire_event('content_too_long', blob) return if len(content_type) > MAX_TEXT_LEN: self._fire_event('content_type_too_long', content_type) return actual_keys = [] for key in keys: try: actual_key, fragment = self._split_key(key) except KeyTooLongException as e: self._fire_event('key_too_long', e.key) else: actual_keys.append((actual_key, fragment)) if len(actual_keys) == 0: return if self.current_bin is None: self.current_bin = BinMemWriter() self.bin_count += 1 if content_type not in self.content_types: self.content_types[content_type] = len(self.content_types) self.current_bin.add(self.content_types[content_type], blob) self.blob_count += 1 bin_item_index = len(self.current_bin) - 1 bin_index = self.bin_count - 1 for actual_key, fragment in actual_keys: self._write_ref(actual_key, bin_index, bin_item_index, fragment) if (self.current_bin.current_offset > self.min_bin_size or len(self.current_bin) == MAX_BIN_ITEM_COUNT): self._write_current_bin() def add_alias(self, key, target_key): if self.max_redirects: try: self._split_key(key) except KeyTooLongException as e: self._fire_event('alias_too_long', e.key) return try: self._split_key(target_key) except KeyTooLongException as e: 
self._fire_event('alias_target_too_long', e.key) return self.f_aliases.add(pickle.dumps(target_key), key) else: raise NotImplementedError() def _fire_event(self, name, data=None): if self.observer: self.observer(WriterEvent(name, data)) def _write_current_bin(self): self.f_store_positions.write_long(self.f_store.tell()) self.current_bin.finalize(self.f_store, self.compress) self.current_bin = None def _write_ref(self, key, bin_index, item_index, fragment=''): self.f_ref_positions.write_long(self.f_refs.tell()) self.f_refs.write_text(key) self.f_refs.write_int(bin_index) self.f_refs.write_short(item_index) self.f_refs.write_tiny_text(fragment) self.ref_count += 1 def _sort(self): self._fire_event('begin_sort') f_ref_positions_sorted = self._wbfopen('ref-positions-sorted') self.f_refs.flush() self.f_ref_positions.close() with MultiFileReader(self.f_ref_positions.name, self.f_refs.name) as f: ref_list = RefList(f, self.encoding, count=self.ref_count) sortkey_func = sortkey(IDENTICAL) for i in sorted(range(len(ref_list)), key=lambda j: sortkey_func(ref_list[j].key)): ref_pos = ref_list.pos(i) f_ref_positions_sorted.write_long(ref_pos) f_ref_positions_sorted.close() os.remove(self.f_ref_positions.name) os.rename(f_ref_positions_sorted.name, self.f_ref_positions.name) self.f_ref_positions = StructWriter(fopen(self.f_ref_positions.name, 'ab'), encoding=self.encoding) self._fire_event('end_sort') def _resolve_aliases(self): self._fire_event('begin_resolve_aliases') self.f_aliases.finalize() with MultiFileReader(self.f_ref_positions.name, self.f_refs.name) as f_ref_list: ref_list = RefList(f_ref_list, self.encoding, count=self.ref_count) ref_dict = ref_list.as_dict() with Slob(self.aliases_path) as r: aliases = r.as_dict() path = os.path.join(self.tmpdir.name, 'resolved-aliases') with Writer( path, workdir=self.tmpdir.name, max_redirects=0, compression=None, ) as alias_writer: def read_key_frag(item, default_fragment): key_frag = pickle.loads(item.content) if isinstance(key_frag, str): return key_frag, default_fragment else: return key_frag for item in r: from_key = item.key keys = set() keys.add(from_key) to_key, fragment = read_key_frag(item, item.fragment) count = 0 while count <= self.max_redirects: # is target key itself a redirect? 
try: orig_to_key = to_key to_key, fragment = read_key_frag( next(aliases[to_key]), fragment) count += 1 keys.add(orig_to_key) except StopIteration: break if count > self.max_redirects: self._fire_event('too_many_redirects', from_key) try: target_ref = next(ref_dict[to_key]) except StopIteration: self._fire_event('alias_target_not_found', to_key) else: for key in keys: ref = Ref( key=key, bin_index=target_ref.bin_index, item_index=target_ref.item_index, # last fragment in the chain wins fragment=target_ref.fragment or fragment, ) alias_writer.add(pickle.dumps(ref), key) with Slob(path) as resolved_aliases_reader: previous_key = None for item in resolved_aliases_reader: ref = pickle.loads(item.content) if ref.key == previous_key: continue self._write_ref( ref.key, ref.bin_index, ref.item_index, ref.fragment, ) previous_key = ref.key self._sort() self._fire_event('end_resolve_aliases') def finalize(self): self._fire_event('begin_finalize') if self.current_bin is not None: self._write_current_bin() self._sort() if self.max_redirects: self._resolve_aliases() files = ( self.f_ref_positions, self.f_refs, self.f_store_positions, self.f_store, ) for f in files: f.close() buf_size = 10 * 1024 * 1024 with fopen(self.filename, mode='wb') as output_file: out = StructWriter(output_file, self.encoding) out.write(MAGIC) out.write(uuid4().bytes) out.write_tiny_text(self.encoding, encoding=UTF8) out.write_tiny_text(self.compression) def write_tags(tags, f): f.write(pack(U_CHAR, len(tags))) for key, value in tags.items(): f.write_tiny_text(key) f.write_tiny_text(value, editable=True) write_tags(self.tags, out) def write_content_types(content_types, f): count = len(content_types) f.write(pack(U_CHAR, count)) types = sorted(content_types.items(), key=lambda x: x[1]) for content_type, _ in types: f.write_text(content_type) write_content_types(self.content_types, out) out.write_int(self.blob_count) store_offset = ( out.tell() + U_LONG_LONG_SIZE + # this value U_LONG_LONG_SIZE + # file size value U_INT_SIZE + # ref count value os.stat(self.f_ref_positions.name).st_size + os.stat(self.f_refs.name).st_size) out.write_long(store_offset) out.flush() file_size = ( out.tell() + # bytes written so far U_LONG_LONG_SIZE + # file size value 2 * U_INT_SIZE # ref count and bin count ) file_size += sum((os.stat(f.name).st_size for f in files)) out.write_long(file_size) def mv(src, out): fname = src.name self._fire_event('begin_move', fname) with fopen(fname, mode='rb') as f: while True: data = f.read(buf_size) if len(data) == 0: break out.write(data) out.flush() os.remove(fname) self._fire_event('end_move', fname) out.write_int(self.ref_count) mv(self.f_ref_positions, out) mv(self.f_refs, out) out.write_int(self.bin_count) mv(self.f_store_positions, out) mv(self.f_store, out) self.tmpdir.cleanup() self._fire_event('end_finalize') def size_header(self): size = 0 size += len(MAGIC) size += 16 # uuid bytes size += U_CHAR_SIZE + len(self.encoding.encode(UTF8)) size += U_CHAR_SIZE + len(self.compression.encode(self.encoding)) size += U_CHAR_SIZE # tag length size += U_CHAR_SIZE # content types count # tags and content types themselves counted elsewhere size += U_INT_SIZE # blob count size += U_LONG_LONG_SIZE # store offset size += U_LONG_LONG_SIZE # file size size += U_INT_SIZE # ref count size += U_INT_SIZE # bin count return size def size_tags(self): size = 0 for key, _ in self.tags.items(): size += U_CHAR_SIZE + len(key.encode(self.encoding)) size += 255 return size def size_content_types(self): size = 0 for content_type in 
self.content_types: size += U_CHAR_SIZE + len(content_type.encode(self.encoding)) return size def size_data(self): files = ( self.f_ref_positions, self.f_refs, self.f_store_positions, self.f_store, ) return sum((os.stat(f.name).st_size for f in files)) def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.finalize() return False
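# Hedged usage sketch (not from the original source): Writer is designed to be
# used as a context manager so finalize() runs on exit. 'example.slob', the tag
# and the keys below are placeholders.
#
#     with Writer('example.slob') as w:
#         w.tag('label', 'Example dictionary')
#         w.add(b'<p>hello</p>', 'hello', content_type='text/html; charset=utf-8')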
STRING = 0
PATH = 1
LIST = 2


W_STR_RENDERING_TYPE_2_VALUE_TYPE = MappingProxyType(
    collections.OrderedDict([
        (WithStrRenderingType.STRING, ValueType.STRING),
        (WithStrRenderingType.PATH, ValueType.PATH),
        (WithStrRenderingType.LIST, ValueType.LIST),
    ]))

VALUE_TYPE_2_W_STR_RENDERING_TYPE = MappingProxyType(
    collections.OrderedDict([
        (item[1], item[0])
        for item in W_STR_RENDERING_TYPE_2_VALUE_TYPE.items()
    ]))


def sorted_types(values: Iterable[ValueType]) -> Tuple[ValueType, ...]:
    return tuple(
        sorted(values, key=_value_type_sorting_key)
    )


def _value_type_sorting_key(x: ValueType) -> int:
    return x.value
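# Hedged usage sketch (not from the original source): the two proxies form a
# read-only, bidirectional translation between the types defined above.
assert W_STR_RENDERING_TYPE_2_VALUE_TYPE[WithStrRenderingType.PATH] == ValueType.PATH
assert VALUE_TYPE_2_W_STR_RENDERING_TYPE[ValueType.PATH] == WithStrRenderingType.PATH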
        '@type': '@id',
    },
    'rdfs:subClassOf': {
        '@type': '@id',
    },
    'octa:subjectOf': {
        '@type': '@id',
    },
    'title': 'octa:title',
    'position': 'octa:position',

    # The default namespaces list should be included in context
    # We have to convert URLs to strings though - to make them serializable.
    **{
        namespace_name: str(namespace_url)
        for namespace_name, namespace_url in DEFAULT_NAMESPACES.items()
        if namespace_name
    },
})


class Triple(NamedTuple):
    """RDF triple."""

    subject: rdflib.URIRef
    predicate: rdflib.URIRef
    object: Union[rdflib.URIRef, rdflib.Literal]  # noqa: WPS125

    def as_quad(self, graph: rdflib.URIRef) -> 'Quad':
        """Add graph to this triple and hence get a quad."""
        return Quad(self.subject, self.predicate, self.object, graph)
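# Hedged usage sketch (not from the original source): lifting a triple into a
# named graph. The URIs are placeholders, and `Quad` is assumed to be the
# companion NamedTuple defined elsewhere in the original module.
triple = Triple(
    subject=rdflib.URIRef('https://example.com/page'),
    predicate=rdflib.URIRef('octa:title'),
    object=rdflib.Literal('Example'),
)
quad = triple.as_quad(rdflib.URIRef('https://example.com/graph'))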
class Api(metaclass=ApiMeta): r""" Construct and send an NS API request. NS API docs can be found here: https://www.nationstates.net/pages/api.html This is a low-level API wrapper. Some attempts will be made to prevent bad requests, but it will not check shards against a verified list. Authentication may be provided for private nation shards. X-Pin headers are be stored internally and globally for ease of use. Api objects may be awaited or asynchronously iterated. To perform operations from another thread, use the :attr:`threadsafe` property. The Api object itself supports all :class:`collections.abc.Mapping` methods. ================ ============================================================================ Operation Description ================ ============================================================================ await x Make a request to the NS API and returns the root XML element. async for y in x Make a request to the NS API and return each shard element as it is parsed. Useful for larger requests. x + y Combines two :class:`Api` objects together into a new one. Shard keywords that can't be combined will be overwritten with y's data. bool(x) Naïvely check if this :class:`Api` object would result in a 400 Bad Request. Truthy :class:`Api` objects may still result in a 400 Bad Request. Use `len(x)` to check for containment. str(x) Return the URL this :class:`Api` object will make a request to. Other All other :class:`collections.abc.Mapping` methods, except x == y, are also supported. ================ ============================================================================ Parameters ---------- \*shards: str Shards to request from the API. password: str X-Password authentication for private nation shards. autologin: str X-Autologin authentication for private nation shards. \*\*parameters: str Query parameters to append to the request, e.g. nation, scale. 
Examples -------- Usage:: darc = await Api(nation="darcania") async for shard in Api(nation="testlandia"): print(pretty_string(shard)) tnp = Api(region="the_north_pacific").threadsafe() for shard in Api(region="testregionia").threadsafe: print(pretty_string(shard)) """ __slots__ = ("__proxy", "_password", "_str", "_hash", "_last_response") def __new__( cls, *shards: _Union[str, _Iterable[str]], password: _Optional[str] = None, autologin: _Optional[str] = None, **parameters: str, ): if len(shards) == 1 and not parameters: if isinstance(shards[0], cls): return shards[0] with contextlib.suppress(Exception): return cls.from_url(shards[0]) return super().__new__(cls) def __init__( self, *shards: _Union[str, _Iterable[str]], password: _Optional[str] = None, autologin: _Optional[str] = None, **parameters: str, ): has_nation = "nation" in parameters dicts = [parameters] if parameters else [] for shard in filter(bool, shards): if isinstance(shard, Mapping): dicts.append(shard) if not has_nation and "nation" in shard: has_nation = True else: dicts.append({"q": shard}) if not has_nation and (password or autologin): raise ValueError( "Authentication may only be used with the Nation API.") self.__proxy = MappingProxyType(_normalize_dicts(*dicts)) self._password = password self._last_response = None self._str = None self._hash = None async def __await(self): async for element in self.__aiter__(clear=False): pass return element def __await__(self): return self.__await().__await__() async def __aiter__( self, *, clear: bool = True ) -> _AsyncGenerator[objectify.ObjectifiedElement, None]: if not Api.agent: raise RuntimeError("The API's user agent is not yet set.") if "a" in self and self["a"].lower() == "sendtg": raise RuntimeError( "This API wrapper does not support API telegrams.") if not self: # Preempt the request to conserve ratelimit raise BadRequest() url = str(self) headers = {"User-Agent": Api.agent} if self._password: headers["X-Password"] = self._password autologin = self.autologin if autologin: headers["X-Autologin"] = autologin if self.get("nation") in PINS: headers["X-Pin"] = PINS[self["nation"]] async with Api.session.request("GET", url, headers=headers) as response: self._last_response = response if "X-Autologin" in response.headers: self._password = None if "X-Pin" in response.headers: PINS[self["nation"]] = response.headers["X-Pin"] response.raise_for_status() encoding = (response.headers["Content-Type"].split("charset=") [1].split(",")[0]) with contextlib.suppress(etree.XMLSyntaxError), contextlib.closing( etree.XMLPullParser(["end"], base_url=url, remove_blank_text=True)) as parser: parser.set_element_class_lookup( objectify.ObjectifyElementClassLookup()) events = parser.read_events() async for data, _ in response.content.iter_chunks(): parser.feed(data.decode(encoding)) for _, element in events: if clear and (element.getparent() is None or element.getparent().getparent() is not None): continue yield element if clear: element.clear(keep_tail=True) def __add__(self, other: _Any) -> "Api": if isinstance(other, str): with contextlib.suppress(Exception): other = type(self).from_url(other) with contextlib.suppress(Exception): return type(self)(self, other) return NotImplemented def __bool__(self): if any(a in self for a in ("nation", "region")): return True if "a" in self: if self["a"] == "verify" and all(a in self for a in ("nation", "checksum")): return True if self["a"] == "sendtg" and all( a in self for a in ("client", "tgid", "key", "to")): return True return False return "q" in self 
def __contains__(self, key: str) -> bool: return key in self.__proxy def __dir__(self): return set(super().__dir__()).union( dir(self.__proxy), (a for a in dir(type(self)) if not hasattr(type, a))) __eq__ = None def __getattr__(self, name: str): with contextlib.suppress(AttributeError): return getattr(self.__proxy, name) raise AttributeError( f"{type(self).__name__!r} has no attribute {name!r}") def __getitem__(self, key): return self.__proxy[str(key).lower()] def __hash__(self): if self._hash is not None: return self._hash params = sorted((k, v if isinstance(v, str) else " ".join(sorted(v))) for k, v in self.items()) self._hash = hash(tuple(params)) return self._hash def __iter__(self): return iter(self.__proxy) def __len__(self): return len(self.__proxy) def __repr__(self) -> str: return "{}.{}({})".format( type(self).__module__, type(self).__name__, ", ".join("{}={!r}".format(*t) for t in self.__proxy.items()), ) def __str__(self) -> str: if self._str is not None: return self._str params = [(k, v if isinstance(v, str) else "+".join(v)) for k, v in self.items()] self._str = urlunparse((*API_URL, None, urlencode(params, safe="+"), None)) return self._str @property def autologin(self) -> _Optional[str]: """ If a private nation shard was properly requested and returned, this property may be used to get the "X-Autologin" token. """ if self._last_response: return self._last_response.headers.get("X-Autologin") return None @property def last_headers(self) -> _Optional[_Mapping[str, str]]: """ Returns the headers returned from the last request this API object sent. """ if self._last_response: return self._last_response.headers return None @property def last_response(self) -> _Optional[NSResponse]: """ Returns the response object from the last request this API object sent. """ return self._last_response @property def threadsafe(self) -> Threadsafe: """ Returns a threadsafe wrapper around this object. The returned wrapper may be called, awaited, or iterated over. Both standard and async iteration are supported. """ return Threadsafe(self) @classmethod def from_url(cls, url: str, *shards: _Iterable[str], **parameters: _Iterable[str]) -> "Api": """ Constructs an Api object from a provided URL. The Api object may be further modified with shards and parameters, as per the :class:`Api` constructor. """ parsed_url = urlparse(str(url)) url = parsed_url[:len(API_URL)] if any(url) and url != API_URL: raise ValueError( "URL must be solely query parameters or an API url") return cls(*shards, parse_qs(parsed_url.query), parameters)
class Environment(object): """Abstract base class for environments. Represents a type and configuration of environment. Each type of Environment should have a unique urn. For internal use only. No backwards compatibility guarantees. """ _known_urns = {} # type: Dict[str, Tuple[Optional[type], ConstructorFn]] _urn_to_env_cls = {} # type: Dict[str, type] def __init__( self, capabilities=(), # type: Iterable[str] artifacts=( ), # type: Iterable[beam_runner_api_pb2.ArtifactInformation] resource_hints=None, # type: Optional[Mapping[str, bytes]] ): # type: (...) -> None self._capabilities = capabilities self._artifacts = sorted(artifacts, key=lambda x: x.SerializeToString()) # Hints on created environments should be immutable since pipeline context # stores environments in hash maps and we use hints to compute the hash. self._resource_hints = MappingProxyType( dict(resource_hints) if resource_hints else {}) def __eq__(self, other): return ( self.__class__ == other.__class__ and self._artifacts == other._artifacts # Assuming that we don't have instances of the same Environment subclass # with different set of capabilities. and self._resource_hints == other._resource_hints) def __hash__(self): # type: () -> int return hash((self.__class__, frozenset(self._resource_hints.items()))) def artifacts(self): # type: () -> Iterable[beam_runner_api_pb2.ArtifactInformation] return self._artifacts def to_runner_api_parameter(self, context): # type: (PipelineContext) -> Tuple[str, Optional[Union[message.Message, bytes, str]]] raise NotImplementedError def capabilities(self): # type: () -> Iterable[str] return self._capabilities def resource_hints(self): # type: () -> Mapping[str, bytes] return self._resource_hints @classmethod @overload def register_urn( cls, urn, # type: str parameter_type, # type: Type[T] ): # type: (...) -> Callable[[Union[type, Callable[[T, Iterable[str], PipelineContext], Any]]], Callable[[T, Iterable[str], PipelineContext], Any]] pass @classmethod @overload def register_urn( cls, urn, # type: str parameter_type, # type: None ): # type: (...) -> Callable[[Union[type, Callable[[bytes, Iterable[str], Iterable[beam_runner_api_pb2.ArtifactInformation], PipelineContext], Any]]], Callable[[bytes, Iterable[str], PipelineContext], Any]] pass @classmethod @overload def register_urn( cls, urn, # type: str parameter_type, # type: Type[T] constructor # type: Callable[[T, Iterable[str], Iterable[beam_runner_api_pb2.ArtifactInformation], PipelineContext], Any] ): # type: (...) -> None pass @classmethod @overload def register_urn( cls, urn, # type: str parameter_type, # type: None constructor # type: Callable[[bytes, Iterable[str], Iterable[beam_runner_api_pb2.ArtifactInformation], PipelineContext], Any] ): # type: (...) -> None pass @classmethod def register_urn(cls, urn, parameter_type, constructor=None): def register(constructor): if isinstance(constructor, type): constructor.from_runner_api_parameter = register( constructor.from_runner_api_parameter) # register environment urn to environment class cls._urn_to_env_cls[urn] = constructor return constructor else: cls._known_urns[urn] = parameter_type, constructor return staticmethod(constructor) if constructor: # Used as a statement. register(constructor) else: # Used as a decorator. 
return register @classmethod def get_env_cls_from_urn(cls, urn): # type: (str) -> Type[Environment] return cls._urn_to_env_cls[urn] def to_runner_api(self, context): # type: (PipelineContext) -> beam_runner_api_pb2.Environment urn, typed_param = self.to_runner_api_parameter(context) return beam_runner_api_pb2.Environment( urn=urn, payload=typed_param.SerializeToString() if isinstance( typed_param, message.Message) else typed_param if (isinstance(typed_param, bytes) or typed_param is None) else typed_param.encode('utf-8'), capabilities=self.capabilities(), dependencies=self.artifacts(), resource_hints=self.resource_hints()) @classmethod def from_runner_api( cls, proto, # type: Optional[beam_runner_api_pb2.Environment] context # type: PipelineContext ): # type: (...) -> Optional[Environment] if proto is None or not proto.urn: return None parameter_type, constructor = cls._known_urns[proto.urn] return constructor( proto_utils.parse_Bytes(proto.payload, parameter_type), proto.capabilities, proto.dependencies, proto.resource_hints, context) @classmethod def from_options(cls, options): # type: (Type[EnvironmentT], PortableOptions) -> EnvironmentT """Creates an Environment object from PortableOptions. Args: options: The PortableOptions object. """ raise NotImplementedError
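# Sketch of the registration pattern defined above (the urn string and payload
# handling here are made up; Beam's real subclasses such as DockerEnvironment
# follow the same shape). Used as a decorator, register_urn records the
# constructor in _known_urns and the class in _urn_to_env_cls, which is what
# lets from_runner_api() rebuild the right Environment from a proto.
@Environment.register_urn("beam:env:example:v1", None)
class ExampleEnvironment(Environment):

  def to_runner_api_parameter(self, context):
    # type: (PipelineContext) -> Tuple[str, bytes]
    return "beam:env:example:v1", b""

  @staticmethod
  def from_runner_api_parameter(payload, capabilities, artifacts, resource_hints, context):
    return ExampleEnvironment(
        capabilities=capabilities,
        artifacts=artifacts,
        resource_hints=resource_hints)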
skill['Two Heads'], skill['Very Long Legs'], ))), (skillcat['RACIAL CHARACTERISTICS'], frozenset(( skill['Always Hungry'], skill['Big Guy'], skill['Blood Lust'], skill['Bone Head'], skill['Easily Confused'], skill['Hypnotic Gaze'], skill['Nurgle\'s Rot'], skill['Really Stupid'], skill['Regeneration'], skill['Right Stuff'], skill['Stunty'], skill['Take Root'], skill['Throw Team-Mate'], skill['Thrud\'s Fans'], skill['Wild Animal'], ))), ))) ## REDUNDANT BUT USEFUL ######################################## skillcat_by_skill = MappingProxyType( {skill: skillcat for skillcat, skills in skill_by_skillcat.items() for skill in skills} )
class Api(metaclass=_ApiMeta): __slots__ = ("__dict", ) def __init__(self, *shards: _Union[str, _Mapping[str, str]], **kwargs: str): dicts = [kwargs] if kwargs else [] for shard in shards: if isinstance(shard, Mapping) and shard: dicts.append(shard) elif shard: dicts.append({"q": shard}) self.__dict = MappingProxyType(_normalize_dicts(*dicts)) async def __await(self): # pylint: disable=E1133 async for element in self.__aiter__(clear=False): pass return element def __await__(self): return self.__await().__await__() async def __aiter__(self, *, clear: bool = True): if not self: raise ValueError("Bad request") url = str(self) parser = etree.XMLPullParser(["end"], base_url=url, remove_blank_text=True) parser.set_element_class_lookup( etree.ElementDefaultClassLookup(element=_NSElement)) events = parser.read_events() async with type(self).session.request( "GET", url, headers={"User-Agent": type(self).agent}) as response: yield parser.makeelement("HEADERS", attrib=response.headers) encoding = response.headers["Content-Type"].split( "charset=")[1].split(",")[0] async for data, _ in response.content.iter_chunks(): parser.feed(data.decode(encoding)) for _, element in events: yield element if clear: element.clear() def __add__(self, other: _Any) -> "Api": with contextlib.suppress(Exception): return type(self)(self, other) return NotImplemented def __bool__(self): return any(a in self for a in ("a", "nation", "region", "q", "wa")) def __contains__(self, key): return key in self.__dict def __dir__(self): return set(super().__dir__()).union(dir(self.__dict)) def __getattribute__(self, name): try: return super().__getattribute__(name) except AttributeError: with contextlib.suppress(AttributeError): return getattr(self.__dict, name) raise def __getitem__(self, key): return self.__dict[str(key).lower()] def __iter__(self): return iter(self.__dict) def __len__(self): return len(self.__dict) def __repr__(self) -> str: return "{}({})".format( type(self).__name__, ", ".join( "{}={!r}".format(k, v if isinstance(v, str) else " ".join(v)) for k, v in self.__dict.items()), ) def __str__(self) -> str: params = [(k, v if isinstance(v, str) else " ".join(v)) for k, v in self.items()] return urlunparse((*API_URL, None, urlencode(params), None)) def copy(self): return type(self)(**self.__dict) @classmethod def from_url(cls, url: str, *args, **kwargs): parsed_url = urlparse(str(url)) url = parsed_url[:len(API_URL)] if any(url) and url != API_URL: raise ValueError( "URL must be solely query parameters or an API url") return cls(*args, dict(parse_qsl(parsed_url.query)), kwargs)
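# Minimal async usage sketch (hypothetical nation/shard names; assumes the
# Api.agent and Api.session class attributes have been configured by _ApiMeta or
# by the caller). The same request object can be awaited for the final parsed
# element or async-iterated to stream elements as they arrive.
async def demo():
    root = await Api("fullname", nation="testlandia")      # last parsed element
    async for element in Api("flag", nation="testlandia"):
        print(element.tag)   # a synthetic "HEADERS" element first, then XML elements
# asyncio.run(demo())  # run under an event loop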
class Composition: """ Defines a composition of a compound. To create a composition, use the class methods: - :meth:`from_pure` - :meth:`from_formula` - :meth:`from_mass_fractions` - :meth:`from_atomic_fractions` Use the following attributes to access the composition values: - :attr:`mass_fractions`: :class:`dict` where the keys are atomic numbers and the values weight fractions. - :attr:`atomic_fractions`: :class:`dict` where the keys are atomic numbers and the values atomic fractions. - :attr:`formula`: chemical formula The composition object is immutable, i.e. it cannot be modified once created. Equality can be checked. It is hashable. It can be pickled or copied. """ _key = object() PRECISION = 0.000000001 # 1ppb def __init__(self, key, mass_fractions, atomic_fractions, formula): """ Private constructor. It should never be used. """ if key != Composition._key: raise TypeError('Composition cannot be created using constructor') if set(mass_fractions.keys()) != set(atomic_fractions.keys()): raise ValueError('Mass and atomic fractions must have the same elements') self.mass_fractions = MappingProxyType(mass_fractions) self.atomic_fractions = MappingProxyType(atomic_fractions) self._formula = formula @classmethod def from_pure(cls, z): """ Creates a pure composition. Args: z (int): atomic number """ return cls(cls._key, {z: 1.0}, {z: 1.0}, pyxray.element_symbol(z)) @classmethod def from_formula(cls, formula): """ Creates a composition from a chemical formula. Args: formula (str): chemical formula """ atomic_fractions = convert_formula_to_atomic_fractions(formula) return cls.from_atomic_fractions(atomic_fractions) @classmethod def from_mass_fractions(cls, mass_fractions, formula=None): """ Creates a composition from a mass fraction :class:`dict`. Args: mass_fractions (dict): mass fraction :class:`dict`. The keys are atomic numbers and the values weight fractions. Wildcard are accepted, e.g. ``{5: '?', 25: 0.4}`` where boron will get a mass fraction of 0.6. formula (str): optional chemical formula for the composition. If ``None``, a formula will be generated for the composition. """ mass_fractions = process_wildcard(mass_fractions) atomic_fractions = convert_mass_to_atomic_fractions(mass_fractions) if not formula: formula = generate_name(atomic_fractions) return cls(cls._key, mass_fractions, atomic_fractions, formula) @classmethod def from_atomic_fractions(cls, atomic_fractions, formula=None): """ Creates a composition from an atomic fraction :class:`dict`. Args: atomic_fractions (dict): atomic fraction :class:`dict`. The keys are atomic numbers and the values atomic fractions. Wildcard are accepted, e.g. ``{5: '?', 25: 0.4}`` where boron will get a atomic fraction of 0.6. formula (str): optional chemical formula for the composition. If ``None``, a formula will be generated for the composition. 
""" atomic_fractions = process_wildcard(atomic_fractions) mass_fractions = convert_atomic_to_mass_fractions(atomic_fractions) if not formula: formula = generate_name(atomic_fractions) return cls(cls._key, mass_fractions, atomic_fractions, formula) def __len__(self): return len(self.mass_fractions) def __contains__(self, z): return z in self.mass_fractions def __iter__(self): return iter(self.mass_fractions.keys()) def __repr__(self): return '<{}({})>'.format(self.__class__.__name__, self.inner_repr()) def __eq__(self, other): if not isinstance(other, self.__class__): return False if len(self) != len(other): return False for z in self.mass_fractions: if z not in other.mass_fractions: return False fraction = self.mass_fractions[z] other_fraction = other.mass_fractions[z] if not math.isclose(fraction, other_fraction, abs_tol=self.PRECISION): return False return True def __ne__(self, other): return not self == other def __hash__(self): out = [] for z in sorted(self.mass_fractions): out.append(z) out.append(int(self.mass_fractions[z] / self.PRECISION)) return hash(tuple(out)) def __getstate__(self): return {'mass_fractions': dict(self.mass_fractions), 'atomic_fractions': dict(self.atomic_fractions), 'formula': self.formula} def __setstate__(self, state): self.mass_fractions = MappingProxyType(state.get('mass_fractions', {})) self.atomic_fractions = MappingProxyType(state.get('atomic_fractions', {})) self._formula = state.get('formula', '') def is_normalized(self): return math.isclose(sum(self.mass_fractions.values()), 1.0, abs_tol=self.PRECISION) def inner_repr(self): return ', '.join('{}: {:.4f}'.format(pyxray.element_symbol(z), mass_fraction) for z, mass_fraction in self.mass_fractions.items()) @property def formula(self): return self._formula
))), ))) stat_value = { '+MA': 30000, '+AG': 40000, '+ST': 50000, '+AV': 30000, } ## REDUNDANT BUT USEFUL ######################################## deck_of_card = MappingProxyType( {card: deck for deck, cards in cards_of_deck.items() for card in cards} ) card_price = MappingProxyType( {card: deck_price[deck] for card, deck in deck_of_card.items()} ) skillcat_by_skill = MappingProxyType( {skill: skillcat for skillcat, skills in skill_by_skillcat.items() for skill in skills} )
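# Self-contained toy re-creation of the reverse-lookup idiom above (made-up deck
# and card names): the forward tables are frozen with MappingProxyType, then
# inverted into equally read-only per-card lookups.
from types import MappingProxyType

toy_deck_price = MappingProxyType({'starter': 10000, 'expansion': 25000})
toy_cards_of_deck = MappingProxyType({
    'starter': frozenset({'Blitz!', 'Dodge'}),
    'expansion': frozenset({'Foul Appearance'}),
})
toy_deck_of_card = MappingProxyType(
    {card: deck for deck, cards in toy_cards_of_deck.items() for card in cards}
)
toy_card_price = MappingProxyType(
    {card: toy_deck_price[deck] for card, deck in toy_deck_of_card.items()}
)
assert toy_card_price['Foul Appearance'] == 25000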
def __init__(self, raster, list_of_prod_fp, channel_ids, is_flat, dst_nodata, interpolation,
             max_queue_size, parent_uid, key_in_parent):
    # Mutable attributes **********************************************************************
    # Attributes that relate a query to a single optional computation phase
    self.cache_computation = None  # type: Union[None, CacheComputationInfos]

    # Immutable attributes ********************************************************************
    self.parent_uid = parent_uid
    self.key_in_parent = key_in_parent

    # The parameters given by user in invocation
    self.channel_ids = channel_ids  # type: Sequence[int]
    self.is_flat = is_flat  # type: bool
    self.unique_channel_ids = []
    for bi in channel_ids:
        if bi not in self.unique_channel_ids:
            self.unique_channel_ids.append(bi)
    self.unique_channel_ids = tuple(self.unique_channel_ids)

    self.dst_nodata = dst_nodata  # type: Union[int, float]
    self.interpolation = interpolation  # type: str

    # Output max queue size (Parameter given to queue.Queue)
    self.max_queue_size = max_queue_size  # type: int

    # How many arrays are requested
    self.produce_count = len(list_of_prod_fp)  # type: int

    # Build CacheProduceInfos objects **************************************
    to_zip = []

    # The list of Footprints requested
    list_of_prod_fp = list_of_prod_fp  # type: List[ProductionFootprint]
    to_zip.append(list_of_prod_fp)

    # Boolean attribute of each `prod_fp`
    # If `True` the resampling phase has to be performed on a Pool
    list_of_prod_same_grid = [
        fp.same_grid(raster.fp) for fp in list_of_prod_fp
    ]  # type: List[bool]
    to_zip.append(list_of_prod_same_grid)

    # Boolean attribute of each `prod_fp`
    # If `False` the queried footprint is outside of raster's footprint. It means that no
    # sampling is necessary and the output array will be full of `dst_nodata`
    list_of_prod_share_area = [
        fp.share_area(raster.fp) for fp in list_of_prod_fp
    ]  # type: List[bool]
    to_zip.append(list_of_prod_share_area)

    # The full Footprint that needs to be sampled for each `prod_fp`
    # Is `None` if `prod_fp` is fully outside of raster
    list_of_prod_sample_fp = []  # type: List[Union[None, SampleFootprint]]
    to_zip.append(list_of_prod_sample_fp)

    # The set of cache Footprints that are needed for each `prod_fp`
    list_of_prod_cache_fps = []  # type: List[FrozenSet[CacheFootprint]]
    to_zip.append(list_of_prod_cache_fps)

    # The list of resamplings to perform for each `prod_fp`
    # Always at least 1 resampling per `prod_fp`
    list_of_prod_resample_fps = []  # type: List[Tuple[ResampleFootprint, ...]]
    to_zip.append(list_of_prod_resample_fps)

    # The set of `cache_fp` necessary per `resample_fp` for each `prod_fp`
    list_of_prod_resample_cache_deps_fps = []  # type: List[Mapping[ResampleFootprint, FrozenSet[CacheFootprint]]]
    to_zip.append(list_of_prod_resample_cache_deps_fps)

    # The full Footprint that needs to be sampled per `resample_fp` for each `prod_fp`
    # Might be `None` if `prod_fp` is fully outside of raster
    list_of_prod_resample_sample_dep_fp = []  # type: List[Mapping[ResampleFootprint, Union[None, SampleFootprint]]]
    to_zip.append(list_of_prod_resample_sample_dep_fp)

    it = zip(list_of_prod_fp, list_of_prod_same_grid, list_of_prod_share_area)
    # TODO: Speed up that piece of code
    # - Code footprint with lower level code
    # - Spawn a ProcessPoolExecutor when >100 prod. (The same could be done for fp.tile).
    # - What about a global process pool executor in `buzz.env`?
for prod_fp, same_grid, share_area in it: if not share_area: # Resampling will be performed in one pass, on the scheduler list_of_prod_sample_fp.append(None) list_of_prod_cache_fps.append(frozenset()) resample_fp = cast(ResampleFootprint, prod_fp) list_of_prod_resample_fps.append((resample_fp, )) list_of_prod_resample_cache_deps_fps.append( MappingProxyType({resample_fp: frozenset()})) list_of_prod_resample_sample_dep_fp.append( MappingProxyType({resample_fp: None})) else: if same_grid: # Remapping will be performed in one pass, on the scheduler sample_fp = raster.fp & prod_fp resample_fps = [cast(ResampleFootprint, prod_fp)] sample_dep_fp = {resample_fps[0]: sample_fp} else: sample_fp = raster.build_sampling_footprint_to_remap_interpolate( prod_fp, interpolation) if raster.max_resampling_size is None: # Remapping will be performed in one pass, on a Pool resample_fps = [cast(ResampleFootprint, prod_fp)] sample_dep_fp = {resample_fps[0]: sample_fp} else: # Resampling will be performed in several passes, on a Pool rsize = np.maximum(prod_fp.rsize, sample_fp.rsize) countx, county = np.ceil( rsize / raster.max_resampling_size).astype(int) resample_fps = prod_fp.tile_count( countx, county, boundary_effect='shrink').flatten().tolist() sample_dep_fp = { resample_fp: (raster. build_sampling_footprint_to_remap_interpolate( resample_fp, interpolation) if resample_fp.share_area(raster.fp) else None) for resample_fp in resample_fps } resample_cache_deps_fps = MappingProxyType({ resample_fp: frozenset(raster.cache_fps_of_fp(sample_subfp)) for resample_fp in resample_fps for sample_subfp in [sample_dep_fp[resample_fp]] if sample_subfp is not None }) for s in resample_cache_deps_fps.items(): assert len(s) > 0 # The `intersection of the cache_fps with sample_fp` might not be the same as the # the `intersection of the cache_fps with resample_fps`! cache_fps = frozenset( itertools.chain.from_iterable( resample_cache_deps_fps.values())) assert len(cache_fps) > 0 list_of_prod_cache_fps.append(cache_fps) list_of_prod_sample_fp.append(sample_fp) list_of_prod_resample_fps.append(tuple(resample_fps)) list_of_prod_resample_cache_deps_fps.append( resample_cache_deps_fps) list_of_prod_resample_sample_dep_fp.append( MappingProxyType(sample_dep_fp)) self.prod = tuple([CacheProduceInfos(*args) for args in zip(*to_zip) ]) # type: Tuple[CacheProduceInfos, ...] 
# Misc ***************************************************************** # The list of all cache Footprints needed, ordered by priority self.list_of_cache_fp = [] # type: Sequence[CacheFootprint] seen = set() for fps in list_of_prod_cache_fps: for fp in fps: if fp not in seen: seen.add(fp) self.list_of_cache_fp.append(fp) self.list_of_cache_fp = tuple(self.list_of_cache_fp) del seen # The dict of cache Footprint to set of production idxs # For each `cache_fp`, the set of prod_idx that need this cache tile self.dict_of_prod_idxs_per_cache_fp = collections.defaultdict( set) # type: Mapping[CacheFootprint, AbstractSet[int]] for i, (prod_fp, cache_fps) in enumerate( zip(list_of_prod_fp, list_of_prod_cache_fps)): for cache_fp in cache_fps: self.dict_of_prod_idxs_per_cache_fp[cache_fp].add(i) for k, v in self.dict_of_prod_idxs_per_cache_fp.items(): self.dict_of_prod_idxs_per_cache_fp[k] = frozenset(v) self.dict_of_prod_idxs_per_cache_fp = MappingProxyType( self.dict_of_prod_idxs_per_cache_fp) # The dict of cache Footprint to production_idx # For each `cache_fp`, the minimum prod_idx that need this cache tile self.dict_of_min_prod_idx_per_cache_fp = { } # type: Mapping[CacheFootprint, int] for k, v in self.dict_of_prod_idxs_per_cache_fp.items(): self.dict_of_min_prod_idx_per_cache_fp[k] = min(v) self.dict_of_min_prod_idx_per_cache_fp = MappingProxyType( self.dict_of_min_prod_idx_per_cache_fp)
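# Standalone sketch of the freezing pattern used above (toy data): accumulate
# into a mutable defaultdict(set) while scanning, then publish frozenset values
# behind a MappingProxyType so the result can be shared between scheduler steps
# without risking accidental mutation.
import collections
from types import MappingProxyType

needed_by = collections.defaultdict(set)
for prod_idx, cache_keys in enumerate([{'a', 'b'}, {'b'}, {'c'}]):
    for key in cache_keys:
        needed_by[key].add(prod_idx)
prod_idxs_per_key = MappingProxyType(
    {k: frozenset(v) for k, v in needed_by.items()})
min_prod_idx_per_key = MappingProxyType(
    {k: min(v) for k, v in prod_idxs_per_key.items()})
assert min_prod_idx_per_key['b'] == 0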
class Loader:
    def __init__(self, filename):
        self.filename = filename
        self._categories = {'bot': BotCategory()}
        self.categories = MappingProxyType(self._categories)
        self._errors = []

    def add_category(self, name, category):
        if name in self._categories:
            raise SettingsKeyError(f'Defaults for category `{name}`'
                                   ' provided more than once.')
        if not (category and isinstance(category, Category)):
            raise SettingsTypeError(f'Incorrect Category type for `{name}`:'
                                    f' `{type(category)}`')
        self._categories[name] = category

    def remove_category(self, name):
        return self._categories.pop(name, None) is not None

    def load(self):
        self._errors.clear()
        toml_dict = toml.load(self.filename)
        for name, category in self.categories.items():
            self._load_category(category, name, toml_dict)
        if self._errors:
            raise SettingsError(self._errors)
        if self._has_leftover_values(toml_dict):
            logging.warning(f'Leftover settings remain: `{repr(toml_dict)}`')
        return Settings(self.categories)

    def _load_category(self, category, name, toml_dict, _prefix=''):
        settings = [s for s in dir(category) if not s.startswith('__')]
        for setting in settings:
            try:
                self._load_setting(setting, category, name, toml_dict, _prefix)
            except SettingsError as ex:
                self._errors.append(ex)

    def _load_setting(self, setting, category, name, toml_dict, prefix):
        default = getattr(category, setting)
        setting_type = type(default)
        required = False
        if isinstance(default, Required):
            setting_type = default.type
            required = True
        category_dict = toml_dict.get(name, {})
        if not isinstance(category_dict, dict):
            raise SettingsTypeError(f'Setting `{name}` must be a category'
                                    f' but was `{type(category_dict)}`'
                                    ' instead.')
        if isinstance(default, Category):
            return self._load_category(default, setting, category_dict,
                                       f'{name}.{prefix}')
        if setting in category_dict:
            # remove (pop) from dict so we know it was used
            value = category_dict.pop(setting)
            value_type = type(value)
            if value_type is int and setting_type is float:
                value = float(value)
            elif value_type is not setting_type:
                raise SettingsTypeError(f'Setting `{prefix}{name}:{setting}`'
                                        f' must be of type `{setting_type}`'
                                        f' but was `{value_type}` instead.')
            setattr(category, setting, value)
        elif required:
            raise SettingsKeyError(f'Setting `{prefix}{name}:{setting}`'
                                   ' required but not provided.')

    def _has_leftover_values(self, toml_dict):
        has_leftovers = False
        # copy to list since dict changes size while iterating
        for key, value in list(toml_dict.items()):
            if not isinstance(value, dict):
                has_leftovers = True
                continue
            if self._has_leftover_values(value):
                has_leftovers = True
            else:
                del toml_dict[key]  # remove empty subdicts
        return has_leftovers
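# Hypothetical usage sketch for Loader (the DatabaseCategory class and TOML
# layout below are made up; Category, Required, BotCategory and the Settings
# container come from this settings package, and Required is assumed to take
# the expected type as its argument).
class DatabaseCategory(Category):
    host = 'localhost'
    port = 5432
    password = Required(str)

loader = Loader('config.toml')
loader.add_category('database', DatabaseCategory())
settings = loader.load()                     # raises SettingsError listing every problem found
print(loader.categories['database'].port)    # categories is a read-only proxy over the defaults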
class Trainer: def __init__(self, model, dataset, criterion, optimizer, settings): super().__init__() context = deepcopy(settings) self.ctx = MappingProxyType(vars(context)) self.mode = ('train', 'val').index(context.cmd) self.logger = R['LOGGER'] self.gpc = R['GPC'] # Global Path Controller self.path = self.gpc.get_path self.batch_size = context.batch_size self.checkpoint = context.resume self.load_checkpoint = (len(self.checkpoint) > 0) self.num_epochs = context.num_epochs self.lr = float(context.lr) self.save = context.save_on or context.out_dir self.out_dir = context.out_dir self.trace_freq = int(context.trace_freq) self.device = torch.device(context.device) self.suffix_off = context.suffix_off for k, v in sorted(self.ctx.items()): self.logger.show("{}: {}".format(k, v)) self.model = model_factory(model, context) self.model.to(self.device) self.criterion = critn_factory(criterion, context) self.criterion.to(self.device) self.metrics = metric_factory(context.metrics, context) if self.is_training: self.train_loader = data_factory(dataset, 'train', context) self.val_loader = data_factory(dataset, 'val', context) self.optimizer = optim_factory(optimizer, self.model, context) else: self.val_loader = data_factory(dataset, 'val', context) self.start_epoch = 0 self._init_max_acc_and_epoch = (0.0, 0) @property def is_training(self): return self.mode == 0 def train_epoch(self, epoch): raise NotImplementedError def validate_epoch(self, epoch=0, store=False): raise NotImplementedError def _write_prompt(self): self.logger.dump(input("\nWrite some notes: ")) def run(self): if self.is_training: self._write_prompt() self.train() else: self.evaluate() def train(self): if self.load_checkpoint: self._resume_from_checkpoint() max_acc, best_epoch = self._init_max_acc_and_epoch for epoch in range(self.start_epoch, self.num_epochs): lr = self._adjust_learning_rate(epoch) self.logger.show_nl("Epoch: [{0}]\tlr {1:.06f}".format(epoch, lr)) # Train for one epoch self.train_epoch(epoch) # Clear the history of metric objects for m in self.metrics: m.reset() # Evaluate the model on validation set self.logger.show_nl("Validate") acc = self.validate_epoch(epoch=epoch, store=self.save) is_best = acc > max_acc if is_best: max_acc = acc best_epoch = epoch self.logger.show_nl( "Current: {:.6f} ({:03d})\tBest: {:.6f} ({:03d})\t".format( acc, epoch, max_acc, best_epoch)) # The checkpoint saves next epoch self._save_checkpoint(self.model.state_dict(), self.optimizer.state_dict(), (max_acc, best_epoch), epoch + 1, is_best) def evaluate(self): if self.checkpoint: if self._resume_from_checkpoint(): self.validate_epoch(self.ckp_epoch, self.save) else: self.logger.warning("Warning: no checkpoint assigned!") def _adjust_learning_rate(self, epoch): if self.ctx['lr_mode'] == 'step': lr = self.lr * (0.5**(epoch // self.ctx['step'])) elif self.ctx['lr_mode'] == 'poly': lr = self.lr * (1 - epoch / self.num_epochs)**1.1 elif self.ctx['lr_mode'] == 'const': lr = self.lr else: raise ValueError('unknown lr mode {}'.format(self.ctx['lr_mode'])) for param_group in self.optimizer.param_groups: param_group['lr'] = lr return lr def _resume_from_checkpoint(self): ## XXX: This could be slow! 
        if not os.path.isfile(self.checkpoint):
            self.logger.error("=> No checkpoint was found at '{}'.".format(
                self.checkpoint))
            return False

        self.logger.show("=> Loading checkpoint '{}'".format(self.checkpoint))
        checkpoint = torch.load(self.checkpoint, map_location=self.device)

        state_dict = self.model.state_dict()
        ckp_dict = checkpoint.get('state_dict', checkpoint)
        update_dict = {
            k: v
            for k, v in ckp_dict.items()
            if k in state_dict and state_dict[k].shape == v.shape
        }

        num_to_update = len(update_dict)
        if (num_to_update < len(state_dict)) or (len(state_dict) < len(ckp_dict)):
            if not self.is_training and (num_to_update < len(state_dict)):
                self.logger.error("=> Mismatched checkpoint for evaluation")
                return False
            self.logger.warning(
                "Warning: trying to load a mismatched checkpoint.")
            if num_to_update == 0:
                self.logger.error("=> No parameter is to be loaded.")
                return False
            else:
                self.logger.warning(
                    "=> {} params are to be loaded.".format(num_to_update))
        elif (not self.ctx['anew']) or not self.is_training:
            self.start_epoch = checkpoint.get('epoch', 0)
            max_acc_and_epoch = checkpoint.get('max_acc', (0.0, self.ckp_epoch))
            # For backward compatibility
            if isinstance(max_acc_and_epoch, (float, int)):
                self._init_max_acc_and_epoch = (max_acc_and_epoch, self.ckp_epoch)
            else:
                self._init_max_acc_and_epoch = max_acc_and_epoch
            if self.ctx['load_optim'] and self.is_training:
                try:
                    # Note that weight decay might be modified here
                    self.optimizer.load_state_dict(checkpoint['optimizer'])
                except KeyError:
                    self.logger.warning(
                        "Warning: failed to load optimizer parameters.")

        state_dict.update(update_dict)
        self.model.load_state_dict(state_dict)

        self.logger.show(
            "=> Loaded checkpoint '{}' (epoch {}, max_acc {:.4f} at epoch {})".
            format(self.checkpoint, self.ckp_epoch,
                   *self._init_max_acc_and_epoch))
        return True

    def _save_checkpoint(self, state_dict, optim_state, max_acc, epoch, is_best):
        state = {
            'epoch': epoch,
            'state_dict': state_dict,
            'optimizer': optim_state,
            'max_acc': max_acc
        }
        # Save history
        history_path = self.path('weight',
                                 constants.CKP_COUNTED.format(e=epoch),
                                 underline=True)
        if epoch % self.trace_freq == 0:
            torch.save(state, history_path)
        # Save latest
        latest_path = self.path('weight', constants.CKP_LATEST, underline=True)
        torch.save(state, latest_path)
        if is_best:
            shutil.copyfile(
                latest_path,
                self.path('weight', constants.CKP_BEST, underline=True))

    @property
    def ckp_epoch(self):
        # Get current epoch of the checkpoint
        # For a mismatched ckp or no ckp, set to 0
        return max(self.start_epoch - 1, 0)

    def save_image(self, file_name, image, epoch):
        file_path = os.path.join('epoch_{}/'.format(epoch), self.out_dir, file_name)
        out_path = self.path('out',
                             file_path,
                             suffix=not self.suffix_off,
                             auto_make=True,
                             underline=True)
        return io.imsave(out_path, image)
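# Skeleton of a concrete trainer (hypothetical: the factories, the R registry and
# the settings object are supplied elsewhere by this framework). Only the two
# abstract hooks need implementing; Trainer.train() then drives the epoch loop,
# metric reset, checkpointing and LR schedule defined above.
class ClassificationTrainer(Trainer):
    def train_epoch(self, epoch):
        self.model.train()
        for inputs, targets in self.train_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(inputs), targets)
            loss.backward()
            self.optimizer.step()

    def validate_epoch(self, epoch=0, store=False):
        self.model.eval()
        correct = total = 0
        with torch.no_grad():
            for inputs, targets in self.val_loader:
                preds = self.model(inputs.to(self.device)).argmax(dim=1).cpu()
                correct += (preds == targets).sum().item()
                total += targets.numel()
        return correct / max(total, 1)   # accuracy used for best-checkpoint selection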
def convert_mapping_proxy_type(value: MappingProxyType): return dict(value.items())
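# Tiny round-trip example: a MappingProxyType cannot be pickled or JSON-encoded
# directly, so converting back to a plain dict right before serialization is the
# typical use for this helper.
from types import MappingProxyType

frozen = MappingProxyType({'retries': 3, 'timeout': 1.5})
plain = convert_mapping_proxy_type(frozen)
assert plain == {'retries': 3, 'timeout': 1.5} and isinstance(plain, dict)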