Example #1
from collections.abc import Mapping
from types import MappingProxyType


class ImmutableDict(Mapping):
    """
    Copies a dict and proxies it via types.MappingProxyType to make it immutable.
    https://stackoverflow.com/questions/9997176/immutable-dictionary-only-use-as-a-key-for-another-dictionary/39673094#39673094
    """
    def __init__(self, somedict):
        dictcopy = dict(somedict)  # make a copy
        self._dict = MappingProxyType(dictcopy)  # lock it
        self._hash = None

    def __getitem__(self, key):
        return self._dict[key]

    def __len__(self):
        return len(self._dict)

    def __iter__(self):
        return iter(self._dict)

    def __hash__(self):
        if self._hash is None:
            self._hash = hash(frozenset(self._dict.items()))
        return self._hash

    def __eq__(self, other):
        return self._dict == other._dict

    def __repr__(self):
        return str(self._dict)
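
A minimal usage sketch of the class above (the key/value names are made up for illustration): the wrapper is hashable, so it can be used as a dictionary key, while mutation attempts fail.

settings = ImmutableDict({'retries': 3, 'timeout': 30})
cache = {settings: 'connection-pool-a'}   # hashable, so usable as a key
print(settings['retries'])                # read access works as usual
try:
    settings['retries'] = 5               # Mapping defines no __setitem__
except TypeError as exc:
    print(exc)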
Example #2
def build_context(ctx: MappingProxyType = None,
                  session=None,
                  admin=None,
                  testing=None,
                  request_id=None):
    """

    :rtype:
    """
    new_fields = {
        CTX_SQL_SESSION: session,
        CTX_ADMIN: admin,
        CTX_TESTING: testing,
        CTX_REQUEST_ID: request_id,
    }
    if ctx is None:
        return MappingProxyType(new_fields)

    old_fields = {k: v
                  for k, v in ctx.items()
                  if v is not None}  # Remove None fields.

    merged = {
        **new_fields,
        **old_fields
    }  # Merge old and new context; existing non-None values take priority.

    return MappingProxyType(merged)
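
A hedged usage sketch of how two contexts merge (the CTX_* keys are constants defined elsewhere in the original module; the string values are placeholders):

base = build_context(session='db-session', request_id='req-1')
child = build_context(ctx=base, admin='alice')
assert child[CTX_REQUEST_ID] == 'req-1'   # existing non-None fields are kept
assert child[CTX_ADMIN] == 'alice'        # new fields fill in the gaps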
Example #3
    def helper(d, schema):
        d_type = type(d)
        schema_type = type(schema)

        # Handle lists
        if d_type == list:
            assert schema_type in JSON_ARRAY_TYPES, schema_type
            for element in d:
                helper(element, schema[0])
            return

        # Check types and set up safety guards
        assert d_type == dict, d_type
        assert schema_type in JSON_MAPPING_TYPES, schema_type
        if schema_type != MappingProxyType:
            schema = MappingProxyType(schema)

        # Fill in schema keys that are missing from d
        for schema_k, schema_v in schema.items():
            if schema_k not in d:
                schema_v_type = type(schema_v)
                if schema_v_type in JSON_ARRAY_TYPES:
                    d[schema_k] = []
                elif schema_v_type in JSON_MAPPING_TYPES:
                    d[schema_k] = {}
                    helper(d[schema_k], schema_v)
                else:
                    assert schema_v_type in JSON_LEAF_TYPES, schema_v_type
                    d[schema_k] = schema_v

        # Validate and coerce keys of d that appear in the schema
        for k, v in d.items():
            if k in schema:
                v_type = type(v)
                if v_type in JSON_CONTAINER_TYPES:
                    helper(v, schema[k])
                else:
                    assert v_type in JSON_LEAF_TYPES, v_type
                    schema_v = schema[k]
                    schema_v_type = type(schema_v)
                    assert schema_v_type in JSON_LEAF_TYPES, schema_v_type
                    if v is not None and schema_v is not None:
                        if not isinstance(v, type(schema_v)):
                            if v_type == int and schema_v_type in (float, str):
                                d[k] = schema_v_type(v)
                            elif v_type == float and schema_v_type in (str, ):
                                d[k] = schema_v_type(v)
                            else:
                                assert False, "k:v == ({}, {})\nschema = {}".format(
                                    k, v, schema)
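
A hedged sketch of how the helper above could be exercised. The JSON_*_TYPES constants are assumptions standing in for whatever the original module defines, and helper is assumed to be callable at module level.

from types import MappingProxyType

JSON_LEAF_TYPES = (str, int, float, bool, type(None))        # assumed
JSON_ARRAY_TYPES = (list, tuple)                              # assumed
JSON_MAPPING_TYPES = (dict, MappingProxyType)                 # assumed
JSON_CONTAINER_TYPES = JSON_ARRAY_TYPES + JSON_MAPPING_TYPES  # assumed

schema = MappingProxyType({'name': '', 'count': 0.0, 'tags': [''], 'meta': {'rev': 0}})
doc = {'name': 'example', 'count': 2, 'meta': {'rev': 1}}
helper(doc, schema)
# doc is now {'name': 'example', 'count': 2.0, 'meta': {'rev': 1}, 'tags': []}:
# the missing 'tags' key was filled with an empty list and the int 2 was
# coerced to the schema's float type.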
Example #4
import os
from types import MappingProxyType, TracebackType
from typing import Mapping, Optional, Type


class SetEnviron:
    """A reentrant, re-usable context manager for temporarily setting environment variables."""

    __slots__ = ('__weakref__', '_kwargs', '_kwargs_old')
    _kwargs: Mapping[str, str]
    _kwargs_old: Mapping[str, Optional[str]]

    def __init__(self, **kwargs: str) -> None:
        r"""Initialize the context manager.

        Parameters
        ----------
        \**kwargs : :class:`str`
            The to-be updated parameters.

        """
        self._kwargs = MappingProxyType(kwargs)
        self._kwargs_old = MappingProxyType({
            k: os.environ.get(k) for k in self._kwargs
        })

    def __enter__(self) -> None:
        """Enter the context manager."""
        os.environ.update(self._kwargs)

    def __exit__(
        self,
        __exc_type: Optional[Type[BaseException]],
        __exc_value: Optional[BaseException],
        __traceback: Optional[TracebackType],
    ) -> None:
        """Exit the context manager."""
        for k, v in self._kwargs_old.items():
            if v is None:
                # Use `pop` instead of `del` to ensure thread-safety
                os.environ.pop(k, None)
            else:
                os.environ[k] = v
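
A minimal usage sketch (MY_FLAG is a made-up variable assumed to be unset beforehand): the variable exists only inside the block and the previous state is restored on exit.

with SetEnviron(MY_FLAG='1'):
    assert os.environ['MY_FLAG'] == '1'
assert 'MY_FLAG' not in os.environ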
Example #5
	spec = [
		('array', nbtype[:]),
		('arraysize', nb.intp),
		('size', nb.intp),
		('upper', nb.boolean),
		('diag_val', nbtype)
	]

	return nb.jitclass(spec)(base)


# Jitclasses of NbTriLMatrixBase for each supported array data type
_LOWER_JITCLASS_BY_TYPE = MappingProxyType({
	dtype: _make_jitclass_for_type(NbTriLMatrixBase, nbtype)
	for dtype, nbtype in NUMBA_TYPES.items()
})

# Jitclasses of NbTriUMatrixBase for each supported array data type
_UPPER_JITCLASS_BY_TYPE = MappingProxyType({
	dtype: _make_jitclass_for_type(NbTriUMatrixBase, nbtype)
	for dtype, nbtype in NUMBA_TYPES.items()
})


def _get_jitclass_for_dtype(dtype, upper):
	"""
	Get the correct jitclass of NbTriMatrixBase for the data type and triangle.
	"""
	if upper:
		return _UPPER_JITCLASS_BY_TYPE[dtype]
Example #6
class Network:
    """Maintains connections over the raft network as well as queues for messages"""
    def __init__(self, peer_id):
        self.peer_id = peer_id
        self._active_connections = {}
        self.inbox = asyncio.Queue()
        self._active_ids = self._active_connections.keys()
        self._all_ids = frozenset(config.SERVERS) - {peer_id}
        # have a queue per peer in the network
        self.outbox = MappingProxyType(
            {i: asyncio.Queue()
             for i in self._all_ids})

    def send(self, dest: int, msg):
        self.outbox[dest].put_nowait(msg)

    async def _send_messages(self):
        suppresser = contextlib.suppress(KeyError)
        for peer, queue in itertools.cycle(self.outbox.items()):
            try:
                msg = queue.get_nowait()
            except asyncio.QueueEmpty:  # no message here, come back later
                await asyncio.sleep(.1)
            else:
                queue.task_done()
                with suppresser:
                    reader, writer = self._active_connections[peer]
                    logger.debug(f"Sending to {peer=} {msg=}")
                    await transport.send(writer, msg)

    @cached_property
    def peers(self) -> frozenset:
        return self._all_ids

    async def recv(self) -> bytes:
        return await self.inbox.get()

    async def _on_connection(self, reader, writer):
        logger.info(f"Handling client: {writer.get_extra_info('peername')}")
        connected_peer = await transport.recv(
            reader)  # first message is always peer_id
        logger.debug(f"Starting to receive from {connected_peer=}")
        with contextlib.suppress(asyncio.IncompleteReadError):
            while message := await transport.recv(reader):
                await self.inbox.put(message)

        logger.warning(
            f"Broken peer {connected_peer}. Closing connection and clearing queue."
        )
        writer.close()
        await writer.wait_closed()
        # remove all that was going to be sent
        peer_q = self.outbox[connected_peer]
        while peer_q.qsize():
            peer_q.get_nowait()
            peer_q.task_done()

        logger.warning(
            f"Queue {peer_q} empty. Resurrecting connection for {connected_peer}"
        )
        del self._active_connections[connected_peer]
        awaitable = self._connect(connected_peer)
        asyncio.create_task(
            awaitable)  # resurrect connection and hope for the best
Example #7
class Snapshot:
    """
    An immutable snapshot of contract data.
    """
    def __init__(self, creates: Mapping[ContractId, CreateEvent],
                 offset: Optional[str]):
        self._offset = offset
        self._creates = MappingProxyType(dict(creates))
        self._contracts = ContractDataView(self._creates)
        self._template_ids = {
            cid.value_type: matching_normalizations(cid.value_type)
            for cid in self._creates
        }

    @property
    def offset(self) -> Optional[str]:
        """
        The offset point in a stream that contains all of the creates without offsetting archives
        in this snapshot.

        This value is ``None`` when connecting to a ledger that is completely empty.
        """
        return self._offset

    def earliest_contract(
            self,
            template_id: Union[str, TypeConName],
            match: Optional[ContractMatch] = None) -> Optional[ContractData]:
        """
        Return the earliest contract in the Active Contract Set (in other words, _still_ active)
        that was created in the transaction stream that matches the specified filter, or ``None``
        if there are no matches.
        """
        ev = self.earliest_create(template_id, match)
        return ev.payload if ev is not None else None

    def matching_contracts(
        self,
        template_id: Union[str, TypeConName],
        match: Optional[ContractMatch] = None
    ) -> Mapping[ContractId, ContractData]:
        """"""
        return ContractDataView(self.matching_creates(template_id, match))

    def latest_contract(
            self,
            template_id: Union[str, TypeConName],
            match: Optional[ContractMatch] = None) -> Optional[ContractData]:
        """
        Return the contract that was created last in the transaction stream that matches the
        specified filter, or ``None`` if there are no matches.
        """
        ev = self.latest_create(template_id, match)
        return ev.payload if ev is not None else None

    def earliest_create(
            self,
            template_id: Union[str, TypeConName],
            match: Optional[ContractMatch] = None) -> Optional[CreateEvent]:
        """
        Return the earliest :class:`CreateEvent` in the Active Contract Set (in other words, the
        corresponding contract is _still_ active) that was created in the transaction stream that
        matches the specified filter, or ``None``
        if there are no matches.
        """
        wanted_template_ids = self._matching_template_ids(template_id)

        # in Python 3.8+, dict values can be passed to `reversed`; sadly we're not there yet
        for cev in reversed(list(self._creates.values())):
            if cev.contract_id.value_type in wanted_template_ids and is_match(
                    match, cev.payload):
                return cev
        return None

    def matching_creates(
        self,
        template_id: Union[str, TypeConName],
        match: Optional[ContractMatch] = None
    ) -> Mapping[ContractId, CreateEvent]:
        """
        Return the :class:`CreateEvent`s (indexed by :class:`ContractId`) whose contracts match the
        specified criteria.
        """
        wanted_template_ids = self._matching_template_ids(template_id)
        matches = {}
        for cid, cev in self._creates.items():
            if cev.contract_id.value_type in wanted_template_ids and is_match(
                    match, cev.payload):
                matches[cev.contract_id] = cev
        return MappingProxyType(matches)

    def latest_create(
            self,
            template_id: Union[str, TypeConName],
            match: Optional[ContractMatch] = None) -> Optional[CreateEvent]:
        """
        Return the latest :class:`CreateEvent` in the transaction stream that matches the
        specified filter, or ``None`` if there are no matches.
        """
        wanted_template_ids = self._matching_template_ids(template_id)
        for cev in self._creates.values():
            if cev.contract_id.value_type in wanted_template_ids and is_match(
                    match, cev.payload):
                return cev
        return None

    def _matching_template_ids(
            self, template_id: Union[str,
                                     TypeConName]) -> Collection[TypeConName]:
        t = normalize(template_id)
        return {
            tid
            for tid, matches in self._template_ids.items() if t in matches
        }

    @property
    def contracts(self) -> Mapping[ContractId, ContractData]:
        """
        A read-only map of contract IDs to contract data.
        """
        return self._contracts

    @property
    def creates(self) -> Mapping[ContractId, CreateEvent]:
        """
        A map of contract IDs to :class:`CreateEvent`s that represent the current set of contracts
        in this ACS.

        :class:`CreateEvent`'s expose additional information, including signatories and observers.
        """
        return self._creates

    def __bool__(self) -> bool:
        return bool(self._creates)

    def __len__(self) -> int:
        return len(self._creates)

    def __repr__(self) -> str:
        return f"Snapshot(len={len(self)})"
Example #8
class DRV(object):
    """
    A discrete random variable.

    A DRV has one or more :dfn:`possible values` (or just :dfn:`values`), which
    can be any type. Each possible value has an associated :dfn:`probability`,
    which is a real number between 0 and 1.

    It is strongly recommended that the probabilities add up to exactly 1. This
    might be difficult to achieve with :obj:`float` probabilities, and so this
    class does not enforce that restriction, and makes it possible to sample a
    variable even if the total is not 1. The exact distribution of the samples
    in that case is not specified, only that it will attempt to follow the
    probabilities given. Loosely: if the total is too low then one value's
    probability is rounded up. If the total is too high, then one probability
    is rounded down, and/or one or more values is ignored. These adjustments
    apply only to sampling: the original probabilities are still reported by
    :func:`to_dict()` etc.

    Because :code:`==` is overridden to return a DRV (not a boolean), DRV
    objects are not hashable and cannot be used in a set or as a dictionary
    key, even though the objects are immutable. This means you cannot have a
    DRV as a "possible value" of another DRV.

    DRV also resists being considered in boolean context, so for example you
    cannot in general test whether or not a DRV appears in a list::

      >>> from omnidice.dice import d3, d6
      >>> d3 in [d3, d6]
      True
      >>> d6 in [d3, d6]
      Traceback (most recent call last):
        File "<stdin>", line 1, in <module>
        File "omnidice/drv.py", line 452, in __bool__
          raise ValueError('The truth value of a random variable is ambiguous')
      ValueError: The truth value of a random variable is ambiguous

    This is the same solution used by (for example) :obj:`numpy.array`. If the
    object allowed standard boolean conversion then :code:`d4 in [d3, d6]`
    would be True, which is unacceptably surprising!

    :param distribution: Any value from which a dictionary can be constructed,
      that is a :obj:`Mapping` or :obj:`Iterable` of (value, probability)
      pairs.
    :param tree: The expression from which this object was defined. Currently
      this is used only for the string representation, but might in future
      help support lazily-evaluated DRVs.
    """
    def __init__(
        self,
        distribution: 'DictData',
        *,
        tree: ExpressionTree = None,
    ):
        self.__dist = MappingProxyType(dict(distribution))
        # Cumulative distribution. Defer calculating this, because we only
        # need it if the variable is actually sampled. Intermediate values in
        # building up a complex DRV won't ever be sampled, so save the work.
        self.__cdf = None
        self.__lcm = None
        self.__intvalued = None
        self.__expr_tree = tree
        # Computed probabilities can hit 0 due to float underflow, but maybe
        # we should strip out anything with probability 0.
        if not all(0 <= prob <= 1 for value, prob in self._items()):
            raise ValueError('Probability not in range')
    def __repr__(self):
        if self.__expr_tree is not None:
            return self.__expr_tree.bracketed()
        return f'DRV({self.__dist})'
    def is_same(self, other: 'DRV') -> bool:
        """
        Return True if `self` and `other` have the same discrete probability
        distribution. Possible values with 0 probability are excluded from the
        comparison.
        """
        values = set(value for value, prob in self._items() if prob != 0)
        othervalues = set(value for value, prob in other._items() if prob != 0)
        if values != othervalues:
            return False
        return all(self.__dist[val] == other.__dist[val] for val in values)
    def is_close(self, other: 'DRV', *, rel_tol=None, abs_tol=None) -> bool:
        """
        Return True if `self` and `other` have approximately the same discrete
        probability distribution, within the specified tolerances. Possible
        values with 0 probability are excluded from the comparison.

        `rel_tol` and `abs_tol` are applied only to the probabilities, not to
        the possible values. They are defined as for :func:`math.isclose`.
        """
        values = set(value for value, prob in self._items() if prob != 0)
        othervalues = set(value for value, prob in other._items() if prob != 0)
        if values != othervalues:
            return False
        kwargs = {}
        if rel_tol is not None:
            kwargs['rel_tol'] = rel_tol
        if abs_tol is not None:
            kwargs['abs_tol'] = abs_tol
        return all(
            isclose(self.__dist[val], other.__dist[val], **kwargs)
            for val in values
        )
    def to_dict(self) -> Dict[Any, 'Probability']:
        """
        Return a dictionary mapping all possible values to probabilities.
        """
        # dict(self.__dist) is type-correct, but about 3 times slower.
        # Unfortunately there's no way to parameterise MappingProxyType to
        # say what the type is of the underlying mapping that gets copied.
        return self.__dist.copy()  # type: ignore
    def to_pd(self):
        """
        Return a :class:`pandas.Series` mapping values to probabilities. The
        series is indexed by the possible values.

        :raises: :class:`ModuleNotFoundError` if pandas is not installed. Note
          that pandas is not a hard dependency of this package. You must
          install it to use this method.
        """
        try:
            import pandas as pd
        except ModuleNotFoundError:
            msg = 'You must install pandas for this optional feature'
            raise ModuleNotFoundError(msg)
        return pd.Series(self.__dist, name='probability')
    def to_table(self, as_float: bool = False) -> str:
        """
        Return a string containing the values and probabilities formatted as a
        table. This is intended only for manually checking small distributions.

        :param as_float: Display probabilities as floating-point. You might find
          floats easier to read by eye.
        """
        if not as_float:
            items = self._items()
        else:
            items = ((v, float(p)) for v, p in self._items())
        with contextlib.suppress(TypeError):
            items = sorted(items)
        return '\n'.join([
            'value\tprobability',
            *(f'{v}\t{p}' for v, p in items),
        ])
    def faster(self) -> 'DRV':
        """
        Return a new DRV, with all probabilities converted to float.
        """
        return DRV(
            {x: float(y) for x, y in self._items()},
            tree=self._combine_post('.faster()'),
        )
    def _items(self):
        return self.__dist.items()
    def replace_tree(self, tree: ExpressionTree) -> 'DRV':
        """
        Return a new DRV with the same distribution as this DRV, but defined
        from the specified expression.

        This is used for example when some optimisation has computed a DRV one
        way, but we want to represent it the original way.
        """
        return DRV(self.__dist, tree=tree)
    @property
    def cdf(self):
        if self.__cdf is None:
            def iter_totals():
                total = 0
                for value, probability in self._items():
                    total += probability
                    yield value, total
                # In case of rounding errors
                if total < 1:
                    yield value, 1
            self.__cdf_values, self.__cdf = map(tuple, zip(*iter_totals()))
        return self.__cdf
    @property
    def _lcm(self):
        def lcm(a, b):
            return (a * b) // gcd(a, b)
        if self.__lcm is None:
            result = 1
            for _, prob in self._items():
                if not isinstance(prob, Fraction):
                    result = 0
                    break
                result = lcm(prob.denominator, result)
            self.__lcm = result
        return self.__lcm
    def sample(self, random: Random = rng):
        """
        Sample this variable.

        :param random: Random number generator to use. The default is a single
          object shared by all instances of :class:`DRV`.
        :returns: One possible value of this variable.
        """
        sample: Probability
        if self._lcm == 0:
            sample = random.random()
        else:
            sample = Fraction(random.randrange(self._lcm) + 1, self._lcm)
        # The index of the first cumulative probability greater than or equal
        # to our random sample. If there's a repeated probability in the array,
        # that means there was a value with probability 0. So we don't want to
        # select that value even in the very unlikely case of our sample being
        # exactly equal to the repeated probability!
        idx = bisect_left(self.cdf, sample)
        return self.__cdf_values[idx]
    @property
    def _intvalued(self):
        if self.__intvalued is None:
            self.__intvalued = all(isinstance(x, int) for x in self.__dist)
        return self.__intvalued
    def __add__(self, right) -> 'DRV':
        """
        Handler for :code:`self + right`.

        Return a random variable which is the result of adding this variable to
        `right`. `right` can be either a constant or another DRV (in which case
        the result assumes that the two random variables are independent).

        As with :meth:`apply()`, probabilities are added up wherever addition
        is many-to-one (for constant numbers it is one-to-one provided overflow
        does not occur).
        """
        while CONVOLVE_OPTIMISATION:
            if np is None:
                break
            if not isinstance(right, DRV):
                break
            product_size = len(self.__dist) * len(right.__dist)
            if product_size <= CONVOLVE_SIZE_LIMIT:
                break
            if not self._intvalued or not right._intvalued:
                break
            def get_range(dist):
                return range(min(dist), max(dist) + 1)
            self_values = get_range(self.__dist)
            right_values = get_range(right.__dist)
            # Very sparse arrays aren't faster to convolve.
            if 100 * product_size <= len(self_values) * len(right_values):
                break
            final_probs = np.convolve(
                np.array(tuple(self.__dist.get(x, 0) for x in self_values)),
                np.array(tuple(right.__dist.get(x, 0) for x in right_values)),
            )
            values = range(
                min(self_values) + min(right_values),
                max(self_values) + max(right_values) + 1,
            )
            filtered = (final_probs > 0)
            values = np.array(values)[filtered].tolist()
            final_probs = final_probs[filtered]
            return DRV(
                zip(values, final_probs),
                tree=self._combine(self, right, '+'),
            )
        return self._apply2(operator.add, right, connective='+')
    def __sub__(self, right) -> 'DRV':
        """
        Handler for :code:`self - right`.

        Return a random variable which is the result of subtracting `right`
        from this variable. `right` can be either a constant or another DRV (in
        which case the result assumes that the two random variables are
        independent).

        As with :meth:`apply()`, probabilities are added up wherever
        subtraction is many-to-one (for constant numbers it is one-to-one
        provided overflow does not occur).
        """
        if isinstance(right, DRV):
            # So that we get the convolve optimisation
            tree = self._combine(self, right, '-')
            return (self + -right).replace_tree(tree)
        else:
            return self._apply2(operator.sub, right, connective='-')
    def __mul__(self, right):
        """
        Handler for :code:`self * right`.

        Return a random variable which is the result of multiplying this
        variable with `right`. `right` can be either a constant or another DRV
        (in which case the result assumes that the two random variables are
        independent).

        As with :meth:`apply()`, probabilities are added up in the case where
        multiplication is not one-to-one (for constant numbers other than zero
        it is one-to-one provided overflow and underflow do not occur).
        """
        return self._apply2(operator.mul, right, connective='*')
    def __rmatmul__(self, left: int) -> 'DRV':
        """
        Handler for :code:`left @ self`.

        Return a random variable which is the result of sampling this variable
        `left` times, and adding the results together.
        """
        if not isinstance(left, int):
            return NotImplemented
        if left <= 0:
            raise ValueError(left)
        # Exponentiation by squaring. This isn't massively faster, but does
        # help a bit for hundreds of dice.
        result = None
        so_far = self
        original = left
        while True:
            if left % 2 == 1:
                if result is None:
                    result = so_far
                else:
                    result += so_far
            left //= 2
            if left == 0:
                break
            so_far += so_far
        # left was non-zero, so result cannot still be None
        result = cast(DRV, result)
        return result.replace_tree(self._combine(original, self, '@'))
    def __matmul__(self, right: 'DRV') -> 'DRV':
        """
        Handler for :code:`self @ right`.

        Return a random variable which is the result of sampling this variable
        once, then adding together that many samples of `right`.

        All possible values of this variable must be of type :obj:`int`.
        """
        if not isinstance(right, DRV):
            return NotImplemented
        if not all(isinstance(value, int) for value in self.__dist):
            raise TypeError('require integers on LHS of @')
        def iter_drvs():
            so_far = min(self.__dist) @ right
            for num_dice in range(min(self.__dist), max(self.__dist) + 1):
                if num_dice in self.__dist:
                    yield so_far, self.__dist[num_dice]
                so_far += right
        return DRV.weighted_average(
            iter_drvs(),
            tree=self._combine(self, right, '@'),
        )
    def __truediv__(self, right) -> 'DRV':
        """
        Handler for :code:`self / right`.

        Return a random variable which is the result of dividing this
        variable by `right`. `right` can be either a constant or another DRV
        (in which case the result assumes that the two random variables are
        independent).

        As with :meth:`apply()`, probabilities are added up wherever division
        is many-to-one (for constant numbers other than zero it is one-to-one
        provided overflow and underflow do not occur).

        0 must not be a possible value of `right` (even with probability 0).
        """
        return self._apply2(operator.truediv, right, connective='/')
    def __floordiv__(self, right) -> 'DRV':
        """
        Handler for :code:`self // right`.

        Return a random variable which is the result of floor-dividing this
        variable by `right`. `right` can be either a constant or another DRV
        (in which case the result assumes that the two random variables are
        independent).

        As with :meth:`apply()`, probabilities are added up wherever floor
        division is many-to-one (for numbers it is mostly many-to-one, for
        example :code:`2 // 2 == 1 == 3 // 2`).

        0 must not be a possible value of `right` (even with probability 0).
        """
        return self._apply2(operator.floordiv, right, connective='//')
    def __neg__(self) -> 'DRV':
        """
        Handler for :code:`-self`.

        Return a random variable which is the result of negating the values of
        this variable.

        As with :meth:`apply()`, probabilities are added up wherever negation
        is many-to-one (for numbers it is one-to-one).
        """
        return self.apply(operator.neg, tree=self._combine(self, '-'))
    def __eq__(self, right) -> 'DRV':  # type: ignore[override]
        """
        Handler for :code:`self == right`.

        Return a random variable which takes value :obj:`True` where `self` is
        equal to `right`, and :obj:`False` otherwise. `right` can be either a
        constant or another DRV (in which case the result assumes that the two
        random variables are independent).

        If either :obj:`True` or :obj:`False` cannot happen then the result
        has only one possible value, with probability 1. There is no possible
        value with probability 0.
        """
        if isinstance(right, DRV):
            small, big = sorted([self, right], key=lambda x: len(x.__dist))
            prob = sum(
                prob * big.__dist.get(val, 0)
                for val, prob in small._items()
            )
        else:
            prob = self.__dist.get(right)
        if not prob:
            return DRV({False: 1})
        if prob >= 1.0:
            return DRV({True: 1})
        return DRV(
            {False: 1 - prob, True: prob},
            tree=self._combine(self, right, '=='),
        )
    def __ne__(self, right: 'DRV') -> 'DRV':  # type: ignore[override]
        """
        Handler for :code:`self != right`.

        Return a random variable which takes value :obj:`True` where `self` is
        not equal to `right`, and :obj:`False` otherwise. `right` can be either
        a constant or another DRV (in which case the result assumes that the
        two random variables are independent).

        If either :obj:`True` or :obj:`False` cannot happen then the result
        has only one possible value, with probability 1. There is no possible
        value with probability 0.
        """
        return (
            (self == right)
            .apply(operator.not_)
            .replace_tree(self._combine(self, right, '!='))
        )
    def __bool__(self):
        # Prevent DRVs being truthy, and hence "3 in [DRV({2: 1})]" is true.
        raise ValueError('The truth value of a random variable is ambiguous')
    def __le__(self, right) -> 'DRV':
        """
        Handler for :code:`self <= right`.

        Return a random variable which takes value :obj:`True` where `self` is
        less than or equal to `right`, and :obj:`False` otherwise. `right` can
        be either a constant or another DRV (in which case the result assumes
        that the two random variables are independent).

        If either :obj:`True` or :obj:`False` cannot happen then the result
        has only one possible value, with probability 1. There is no possible
        value with probability 0.
        """
        return self._apply2(operator.le, right, connective='<=')
    def __lt__(self, right) -> 'DRV':
        """
        Handler for :code:`self < right`.

        Return a random variable which takes value :obj:`True` where `self` is
        less than `right`, and :obj:`False` otherwise. `right` can be either a
        constant or another DRV (in which case the result assumes that the two
        random variables are independent).

        If either :obj:`True` or :obj:`False` cannot happen then the result
        has only one possible value, with probability 1. There is no possible
        value with probability 0.
        """
        return self._apply2(operator.lt, right, connective='<')
    def __ge__(self, right) -> 'DRV':
        """
        Handler for :code:`self >= right`.

        Return a random variable which takes value :obj:`True` where `self` is
        greater than or equal to `right`, and :obj:`False` otherwise. `right`
        can be either a constant or another DRV (in which case the result
        assumes that the two random variables are independent).

        If either :obj:`True` or :obj:`False` cannot happen then the result
        has only one possible value, with probability 1. There is no possible
        value with probability 0.
        """
        return self._apply2(operator.ge, right, connective='>=')
    def __gt__(self, right) -> 'DRV':
        """
        Handler for :code:`self > right`.

        Return a random variable which takes value :obj:`True` where `self` is
        greater than `right`, and :obj:`False` otherwise. `right` can be either
        a constant or another DRV (in which case the result assumes that the
        two random variables are independent).

        If either :obj:`True` or :obj:`False` cannot happen then the result
        has only one possible value, with probability 1. There is no possible
        value with probability 0.
        """
        return self._apply2(operator.gt, right, connective='>')
    def explode(self, rerolls: int = 50) -> 'DRV':
        """
        Return a new DRV distributed according to the rules of an "exploding
        die". This means, first roll the die (sample this DRV). If the result
        is not the maximum possible, then keep it. If it is the maximum, then
        roll again and add the new result to the original.

        Because DRV represents only finitely-many possible values, whereas the
        process of rerolling can (with minuscule probability) carry on
        indefinitely, this method imposes an arbitrary limit to the number of
        rerolls.

        :param rerolls: The maximum number of rerolls. Set this to 1 for a die
          that can only "explode" once, not indefinitely.
        """
        reroll_value = max(self.__dist.keys())
        reroll_prob = self.__dist[reroll_value]
        each_die = self.to_dict()
        each_die.pop(reroll_value)
        def iter_pairs():
            for idx in range(rerolls + 1):
                for value, prob in each_die.items():
                    value += reroll_value * idx
                    prob *= reroll_prob ** idx
                    yield (value, prob)
            yield (reroll_value * (idx + 1), reroll_prob ** (idx + 1))
        postfix = '.explode()' if rerolls == 50 else f'.explode({rerolls!r})'
        return self._reduced(iter_pairs(), tree=self._combine_post(postfix))
    def apply(
        self,
        func: Callable[[Any], Any],
        *,
        tree: ExpressionTree = None,
        allow_drv: bool = False,
    ) -> 'DRV':
        """
        Apply a unary function to the values produced by this DRV. If `func` is
        an injective (one-to-one) function, then the probabilities are
        unchanged. If `func` is many-to-one, then the probabilities are added
        together.

        :param func: Function to map the values. Each value `x` is replaced by
          `func(x)`.
        :param tree: the expression from which this object was defined. If
          ``None``, the result DRV is represented by listing out all the values
          and probabilities.
        :param allow_drv: If True, then when `func` returns a DRV, the possible
          values of that DRV are each included in the returned DRV. Recall that
          a DRV cannot be a possible value of the returned DRV, because it is
          not hashable. So, without this option `func` cannot return a DRV.

        .. versionchanged:: 1.1
            Added ``allow_drv`` option.
        """
        return DRV._reduced(self._items(), func, tree=tree, drv=allow_drv)
    def _apply2(self, func, right, connective=None) -> 'DRV':
        """Apply a binary function, with the values of this DRV on the left."""
        expr_tree = self._combine(self, right, connective)
        if isinstance(right, DRV):
            return self._cross_reduce(func, right, tree=expr_tree)
        return self.apply(lambda x: func(x, right), tree=expr_tree)
    def _cross_reduce(self, func, right, tree=None) -> 'DRV':
        """
        Take the cross product of self and right, then reduce by applying func.
        """
        return DRV._reduced(
            self._iter_cross(right),
            lambda value: func(*value),
            tree=tree,
        )
    def _iter_cross(self, right):
        """
        Take the cross product of self and right, with probabilities assuming
        that the two are independent variables.

        Note that the cross product of an object with itself represents the
        results of sampling it twice, *not* just the pairs (x, x) for each
        possible value!
        """
        for (lvalue, lprob) in self._items():
            for (rvalue, rprob) in right._items():
                yield ((lvalue, rvalue), lprob * rprob)
    @staticmethod
    def _reduced(iterable, func=lambda x: x, tree=None, drv=False) -> 'DRV':
        distribution: dict = collections.defaultdict(int)
        if not drv:
            # Optimisation does make a difference to e.g. test_convolve
            for value, prob in iterable:
                distribution[func(value)] += prob
        else:
            for value, prob in iterable:
                transformed = func(value)
                if isinstance(transformed, DRV):
                    for value2, prob2 in transformed._weighted_items(prob):
                        distribution[value2] += prob2
                else:
                    distribution[transformed] += prob
        return DRV(distribution, tree=tree)
    @staticmethod
    def weighted_average(
        iterable: Iterable[Tuple['DRV', 'Probability']],
        *,
        tree: ExpressionTree = None,
    ) -> 'DRV':
        """
        Compute a weighted average of DRVs, each with its own probability.

        This is for when you have a set of mutually-exclusive events which can
        happen, and then the final outcome occurs with a different known
        distribution according to which of those events occurs. For example,
        this function is used to implement the ``@`` operator when the
        left-hand-side is a DRV. The first roll determines what the second roll
        will be.

        The DRVs that are averaged together do not need to be disjoint (that
        is, they can have overlapping possible values). Whenever multiple
        events lead to the same final outcome, the probabilities are combined:

        https://en.wikipedia.org/wiki/Law_of_total_probability

        :param iterable: Pairs, each containing a DRV and the probability of
          that DRV being the one selected. The probabilities should add to 1,
          but this is not enforced.
        :param tree: the expression from which this object was defined. If
          ``None``, the result DRV is represented by listing out all the values
          and probabilities.

        .. versionadded:: 1.1
        """
        def iter_pairs():
            for drv, weight in iterable:
                yield from drv._weighted_items(weight)
        return DRV._reduced(iter_pairs(), tree=tree)
    def _weighted_items(self, weight, pred=lambda x: True):
        for value, prob in self.__dist.items():
            if pred(value):
                yield value, prob * weight
    def given(self, predicate: Callable[[Any], bool]) -> 'DRV':
        """
        Return the conditional probability distribution of this DRV, restricted
        to the possible values for which `predicate` is true.

        For example, :code:`drv.given(lambda x: True)` is the same distribution
        as :code:`drv`, and the following are equivalent to each other::

            d6.given(lambda x: bool(x % 2))
            DRV({1: Fraction(1, 3), 3: Fraction(1, 3), 5: Fraction(1, 3)})

        If `x` is a DRV, and `A` and `B` are predicates, then the conditional
        probability of `A` given `B`, written in probability theory as
        ``p(A(x) | B(x))``, can be computed as :code:`p(x.given(B).apply(A))`.

        :param predicate: Called with possible values of `self`, and must
          return :obj:`bool` (not just truthy).
        :raises ZeroDivisionError: if the probability of `predicate` being
          true is 0.

        .. versionadded:: 1.1
        """
        total = p(self.apply(predicate))
        if total == 0:
            # Would be raised anyway, but nicer error message
            raise ZeroDivisionError('predicate is True with probability 0')
        return DRV(self._weighted_items(1 / total, predicate))
    @staticmethod
    def _combine(*args):
        """
        Helper for combining two expressions into a combined expression.
        """
        for arg in args:
            if isinstance(arg, DRV) and arg.__expr_tree is None:
                return None
        def unpack(subexpr):
            if isinstance(subexpr, DRV):
                return subexpr.__expr_tree
            return Atom(repr(subexpr))
        if len(args) == 2:
            # Unary expression
            subexpr, connective = args
            return UnaryExpression(unpack(subexpr), connective)
        # Binary expression
        left, right, connective = args
        return BinaryExpression(unpack(left), unpack(right), connective)
    def _combine_post(self, postfix):
        if self.__expr_tree is None:
            return None
        return AttrExpression(self.__expr_tree, postfix)
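
A short usage sketch of the class above, assuming the omnidice package it comes from is installed:

from omnidice.dice import d6

total = d6 + d6                       # sum of two independent dice
print(total.to_dict()[7])             # P(total == 7), i.e. 1/6
print(total.to_table(as_float=True))  # human-readable distribution
print(total.sample())                 # one random roll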
Example #9
""" Postponed Evaluation of Annotations Becomes Default """

# int.bit_count()
num1 = 1
print("num1 = {}".format(num1))
print("num1.bit_count() => {}".format(num1.bit_count()))

num1 = 10
print("num1 = {}".format(num1))
print("num1.bit_count() => {}".format(num1.bit_count()))

num1 = 100
print("num1 = {}".format(num1))
print("num1.bit_count() => {}".format(num1.bit_count()))

from types import MappingProxyType

dict1 = dict()
mappingProxy = MappingProxyType(dict1)
print(type(dict1.items()) is type(mappingProxy.items()))
print(mappingProxy.items().mapping is dict1)

# zip(*iterables, strict=False)
list1 = list(zip(("a", "b", "c"), (1, 2, 3), strict=True))
print("list1 = list(zip(('a', 'b', 'c'), (1, 2, 3), strict=True))")
print("list1 => {}".format(list1))
Example #10
class Writer(object):
    def __init__(
        self,
        filename,
        workdir=None,
        encoding=UTF8,
        compression=DEFAULT_COMPRESSION,
        min_bin_size=512 * 1024,
        max_redirects=5,
        observer=None,
    ):
        self.filename = filename
        self.observer = observer
        if os.path.exists(self.filename):
            raise SystemExit('File %r already exists' % self.filename)

        # make sure we can write
        with fopen(self.filename, 'wb'):
            pass

        self.encoding = encoding

        if encodings.search_function(self.encoding) is None:
            raise UnknownEncoding(self.encoding)

        self.workdir = workdir

        self.tmpdir = tmpdir = tempfile.TemporaryDirectory(
            prefix='{0}-'.format(os.path.basename(filename)), dir=workdir)

        self.f_ref_positions = self._wbfopen('ref-positions')
        self.f_store_positions = self._wbfopen('store-positions')
        self.f_refs = self._wbfopen('refs')
        self.f_store = self._wbfopen('store')

        self.max_redirects = max_redirects
        if max_redirects:
            self.aliases_path = os.path.join(tmpdir.name, 'aliases')
            self.f_aliases = Writer(
                self.aliases_path,
                workdir=tmpdir.name,
                max_redirects=0,
                compression=None,
            )

        if compression is None:
            compression = ''
        if compression not in COMPRESSIONS:
            raise UnknownCompression(compression)
        else:
            self.compress = COMPRESSIONS[compression].compress

        self.compression = compression
        self.content_types = {}

        self.min_bin_size = min_bin_size

        self.current_bin = None

        self.blob_count = 0
        self.ref_count = 0
        self.bin_count = 0
        self._tags = {
            'version.python': sys.version.replace('\n', ' '),
            'version.pyicu': icu.VERSION,
            'version.icu': icu.ICU_VERSION,
            'created.at': datetime.now(timezone.utc).isoformat()
        }
        self.tags = MappingProxyType(self._tags)

    def _wbfopen(self, name):
        return StructWriter(fopen(os.path.join(self.tmpdir.name, name), 'wb'),
                            encoding=self.encoding)

    def tag(self, name, value=''):
        if len(name.encode(self.encoding)) > MAX_TINY_TEXT_LEN:
            self._fire_event('tag_name_too_long', (name, value))
            return

        if len(value.encode(self.encoding)) > MAX_TINY_TEXT_LEN:
            self._fire_event('tag_value_too_long', (name, value))
            value = ''

        self._tags[name] = value

    def _split_key(self, key):
        if isinstance(key, str):
            actual_key = key
            fragment = ''
        else:
            actual_key, fragment = key
        if len(actual_key) > MAX_TEXT_LEN or len(fragment) > MAX_TINY_TEXT_LEN:
            raise KeyTooLongException(key)
        return actual_key, fragment

    def add(self, blob, *keys, content_type=''):

        if len(blob) > MAX_LARGE_BYTE_STRING_LEN:
            self._fire_event('content_too_long', blob)
            return

        if len(content_type) > MAX_TEXT_LEN:
            self._fire_event('content_type_too_long', content_type)
            return

        actual_keys = []

        for key in keys:
            try:
                actual_key, fragment = self._split_key(key)
            except KeyTooLongException as e:
                self._fire_event('key_too_long', e.key)
            else:
                actual_keys.append((actual_key, fragment))

        if len(actual_keys) == 0:
            return

        if self.current_bin is None:
            self.current_bin = BinMemWriter()
            self.bin_count += 1

        if content_type not in self.content_types:
            self.content_types[content_type] = len(self.content_types)

        self.current_bin.add(self.content_types[content_type], blob)
        self.blob_count += 1
        bin_item_index = len(self.current_bin) - 1
        bin_index = self.bin_count - 1

        for actual_key, fragment in actual_keys:
            self._write_ref(actual_key, bin_index, bin_item_index, fragment)

        if (self.current_bin.current_offset > self.min_bin_size
                or len(self.current_bin) == MAX_BIN_ITEM_COUNT):
            self._write_current_bin()

    def add_alias(self, key, target_key):
        if self.max_redirects:
            try:
                self._split_key(key)
            except KeyTooLongException as e:
                self._fire_event('alias_too_long', e.key)
                return
            try:
                self._split_key(target_key)
            except KeyTooLongException as e:
                self._fire_event('alias_target_too_long', e.key)
                return
            self.f_aliases.add(pickle.dumps(target_key), key)
        else:
            raise NotImplementedError()

    def _fire_event(self, name, data=None):
        if self.observer:
            self.observer(WriterEvent(name, data))

    def _write_current_bin(self):
        self.f_store_positions.write_long(self.f_store.tell())
        self.current_bin.finalize(self.f_store, self.compress)
        self.current_bin = None

    def _write_ref(self, key, bin_index, item_index, fragment=''):
        self.f_ref_positions.write_long(self.f_refs.tell())
        self.f_refs.write_text(key)
        self.f_refs.write_int(bin_index)
        self.f_refs.write_short(item_index)
        self.f_refs.write_tiny_text(fragment)
        self.ref_count += 1

    def _sort(self):
        self._fire_event('begin_sort')
        f_ref_positions_sorted = self._wbfopen('ref-positions-sorted')
        self.f_refs.flush()
        self.f_ref_positions.close()
        with MultiFileReader(self.f_ref_positions.name, self.f_refs.name) as f:
            ref_list = RefList(f, self.encoding, count=self.ref_count)
            sortkey_func = sortkey(IDENTICAL)
            for i in sorted(range(len(ref_list)),
                            key=lambda j: sortkey_func(ref_list[j].key)):
                ref_pos = ref_list.pos(i)
                f_ref_positions_sorted.write_long(ref_pos)
        f_ref_positions_sorted.close()
        os.remove(self.f_ref_positions.name)
        os.rename(f_ref_positions_sorted.name, self.f_ref_positions.name)
        self.f_ref_positions = StructWriter(fopen(self.f_ref_positions.name,
                                                  'ab'),
                                            encoding=self.encoding)
        self._fire_event('end_sort')

    def _resolve_aliases(self):
        self._fire_event('begin_resolve_aliases')
        self.f_aliases.finalize()
        with MultiFileReader(self.f_ref_positions.name,
                             self.f_refs.name) as f_ref_list:
            ref_list = RefList(f_ref_list, self.encoding, count=self.ref_count)
            ref_dict = ref_list.as_dict()
            with Slob(self.aliases_path) as r:
                aliases = r.as_dict()
                path = os.path.join(self.tmpdir.name, 'resolved-aliases')
                with Writer(
                        path,
                        workdir=self.tmpdir.name,
                        max_redirects=0,
                        compression=None,
                ) as alias_writer:

                    def read_key_frag(item, default_fragment):
                        key_frag = pickle.loads(item.content)
                        if isinstance(key_frag, str):
                            return key_frag, default_fragment
                        else:
                            return key_frag

                    for item in r:
                        from_key = item.key
                        keys = set()
                        keys.add(from_key)
                        to_key, fragment = read_key_frag(item, item.fragment)
                        count = 0
                        while count <= self.max_redirects:
                            # is target key itself a redirect?
                            try:
                                orig_to_key = to_key
                                to_key, fragment = read_key_frag(
                                    next(aliases[to_key]), fragment)
                                count += 1
                                keys.add(orig_to_key)
                            except StopIteration:
                                break
                        if count > self.max_redirects:
                            self._fire_event('too_many_redirects', from_key)
                        try:
                            target_ref = next(ref_dict[to_key])
                        except StopIteration:
                            self._fire_event('alias_target_not_found', to_key)
                        else:
                            for key in keys:
                                ref = Ref(
                                    key=key,
                                    bin_index=target_ref.bin_index,
                                    item_index=target_ref.item_index,
                                    # last fragment in the chain wins
                                    fragment=target_ref.fragment or fragment,
                                )
                                alias_writer.add(pickle.dumps(ref), key)

        with Slob(path) as resolved_aliases_reader:
            previous_key = None
            for item in resolved_aliases_reader:
                ref = pickle.loads(item.content)
                if ref.key == previous_key:
                    continue
                self._write_ref(
                    ref.key,
                    ref.bin_index,
                    ref.item_index,
                    ref.fragment,
                )
                previous_key = ref.key
        self._sort()
        self._fire_event('end_resolve_aliases')

    def finalize(self):
        self._fire_event('begin_finalize')
        if self.current_bin is not None:
            self._write_current_bin()

        self._sort()
        if self.max_redirects:
            self._resolve_aliases()

        files = (
            self.f_ref_positions,
            self.f_refs,
            self.f_store_positions,
            self.f_store,
        )

        for f in files:
            f.close()

        buf_size = 10 * 1024 * 1024

        with fopen(self.filename, mode='wb') as output_file:
            out = StructWriter(output_file, self.encoding)
            out.write(MAGIC)
            out.write(uuid4().bytes)
            out.write_tiny_text(self.encoding, encoding=UTF8)
            out.write_tiny_text(self.compression)

            def write_tags(tags, f):
                f.write(pack(U_CHAR, len(tags)))
                for key, value in tags.items():
                    f.write_tiny_text(key)
                    f.write_tiny_text(value, editable=True)

            write_tags(self.tags, out)

            def write_content_types(content_types, f):
                count = len(content_types)
                f.write(pack(U_CHAR, count))
                types = sorted(content_types.items(), key=lambda x: x[1])
                for content_type, _ in types:
                    f.write_text(content_type)

            write_content_types(self.content_types, out)

            out.write_int(self.blob_count)
            store_offset = (
                out.tell() + U_LONG_LONG_SIZE +  # this value
                U_LONG_LONG_SIZE +  # file size value
                U_INT_SIZE +  # ref count value
                os.stat(self.f_ref_positions.name).st_size +
                os.stat(self.f_refs.name).st_size)
            out.write_long(store_offset)
            out.flush()

            file_size = (
                out.tell() +  # bytes written so far
                U_LONG_LONG_SIZE +  # file size value
                2 * U_INT_SIZE  # ref count and bin count
            )
            file_size += sum((os.stat(f.name).st_size for f in files))
            out.write_long(file_size)

            def mv(src, out):
                fname = src.name
                self._fire_event('begin_move', fname)
                with fopen(fname, mode='rb') as f:
                    while True:
                        data = f.read(buf_size)
                        if len(data) == 0:
                            break
                        out.write(data)
                        out.flush()
                os.remove(fname)
                self._fire_event('end_move', fname)

            out.write_int(self.ref_count)
            mv(self.f_ref_positions, out)
            mv(self.f_refs, out)

            out.write_int(self.bin_count)
            mv(self.f_store_positions, out)
            mv(self.f_store, out)

        self.tmpdir.cleanup()
        self._fire_event('end_finalize')

    def size_header(self):
        size = 0
        size += len(MAGIC)
        size += 16  # uuid bytes
        size += U_CHAR_SIZE + len(self.encoding.encode(UTF8))
        size += U_CHAR_SIZE + len(self.compression.encode(self.encoding))

        size += U_CHAR_SIZE  # tag length
        size += U_CHAR_SIZE  # content types count

        # tags and content types themselves counted elsewhere

        size += U_INT_SIZE  # blob count
        size += U_LONG_LONG_SIZE  # store offset
        size += U_LONG_LONG_SIZE  # file size
        size += U_INT_SIZE  # ref count
        size += U_INT_SIZE  # bin count

        return size

    def size_tags(self):
        size = 0
        for key, _ in self.tags.items():
            size += U_CHAR_SIZE + len(key.encode(self.encoding))
            size += 255
        return size

    def size_content_types(self):
        size = 0
        for content_type in self.content_types:
            size += U_CHAR_SIZE + len(content_type.encode(self.encoding))
        return size

    def size_data(self):
        files = (
            self.f_ref_positions,
            self.f_refs,
            self.f_store_positions,
            self.f_store,
        )
        return sum((os.stat(f.name).st_size for f in files))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.finalize()
        return False
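
# A minimal, self-contained sketch of the chunked "move" pattern used by mv() inside
# finalize() above: stream a temporary file into an already-open binary output in
# fixed-size chunks, then delete the source file. The name append_file_to is
# illustrative only and is not part of the writer's API.
import os


def append_file_to(src_name, out, buf_size=10 * 1024 * 1024):
    """Append the contents of src_name to the open binary file out, then remove it."""
    with open(src_name, mode='rb') as f:
        while True:
            data = f.read(buf_size)
            if not data:
                break
            out.write(data)
    os.remove(src_name)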
Ejemplo n.º 11
0
    STRING = 0
    PATH = 1
    LIST = 2


W_STR_RENDERING_TYPE_2_VALUE_TYPE = MappingProxyType(
    collections.OrderedDict([
        (WithStrRenderingType.STRING, ValueType.STRING),
        (WithStrRenderingType.PATH, ValueType.PATH),
        (WithStrRenderingType.LIST, ValueType.LIST),
    ]))

VALUE_TYPE_2_W_STR_RENDERING_TYPE = MappingProxyType(
    collections.OrderedDict([
        (item[1], item[0])
        for item in W_STR_RENDERING_TYPE_2_VALUE_TYPE.items()
    ]
    ))


def sorted_types(values: Iterable[ValueType]) -> Tuple[ValueType, ...]:
    return tuple(
        sorted(values,
               key=_value_type_sorting_key)
    )


def _value_type_sorting_key(x: ValueType) -> int:
    return x.value
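
# A standalone illustration of the two proxied mappings above: wrap a forward
# mapping in MappingProxyType, derive the inverse from it, and note that both views
# reject mutation. _Color is a stand-in for the WithStrRenderingType/ValueType enums
# defined elsewhere in the original module.
import enum
from types import MappingProxyType


class _Color(enum.Enum):
    RED = 0
    BLUE = 1


_NAME_2_COLOR = MappingProxyType({'red': _Color.RED, 'blue': _Color.BLUE})
_COLOR_2_NAME = MappingProxyType({v: k for k, v in _NAME_2_COLOR.items()})

assert _COLOR_2_NAME[_Color.RED] == 'red'
try:
    _NAME_2_COLOR['green'] = None  # mapping proxies reject item assignment
except TypeError:
    pass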

Ejemplo n.º 12
0
        '@type': '@id',
    },
    'rdfs:subClassOf': {
        '@type': '@id',
    },
    'octa:subjectOf': {
        '@type': '@id',
    },
    'title': 'octa:title',
    'position': 'octa:position',

    # The default namespaces list should be included in context
    # We have to convert URLs to strings though - to make them serializable.
    **{
        namespace_name: str(namespace_url)
        for namespace_name, namespace_url in DEFAULT_NAMESPACES.items() if namespace_name
    },
})


class Triple(NamedTuple):
    """RDF triple."""

    subject: rdflib.URIRef
    predicate: rdflib.URIRef
    object: Union[rdflib.URIRef, rdflib.Literal]  # noqa: WPS125

    def as_quad(self, graph: rdflib.URIRef) -> 'Quad':
        """Add graph to this triple and hence get a quad."""
        return Quad(self.subject, self.predicate, self.object, graph)
Ejemplo n.º 13
0
class Api(metaclass=ApiMeta):
    r"""
    Construct and send an NS API request.
    NS API docs can be found here: https://www.nationstates.net/pages/api.html

    This is a low-level API wrapper. Some attempts will be made to prevent bad requests,
    but it will not check shards against a verified list.
    Authentication may be provided for private nation shards.
    X-Pin headers are stored internally and globally for ease of use.

    Api objects may be awaited or asynchronously iterated.
    To perform operations from another thread, use the :attr:`threadsafe` property.
    The Api object itself supports all :class:`collections.abc.Mapping` methods.

    ================ ============================================================================
    Operation        Description
    ================ ============================================================================
    await x          Make a request to the NS API and return the root XML element.
    async for y in x Make a request to the NS API and return each shard element as it is parsed.
                     Useful for larger requests.
    x + y            Combines two :class:`Api` objects together into a new one.
                     Shard keywords that can't be combined will be overwritten with y's data.
    bool(x)          Naïvely check whether this :class:`Api` object has enough parameters to
                     avoid a 400 Bad Request. Truthy :class:`Api` objects may still result in
                     a 400 Bad Request. Use `len(x)` to check whether any parameters are set.
    str(x)           Return the URL this :class:`Api` object will make a request to.
    Other            All other :class:`collections.abc.Mapping` methods, except x == y,
                     are also supported.
    ================ ============================================================================

    Parameters
    ----------
    \*shards: str
        Shards to request from the API.
    password: str
        X-Password authentication for private nation shards.
    autologin: str
        X-Autologin authentication for private nation shards.
    \*\*parameters: str
        Query parameters to append to the request, e.g. nation, scale.

    Examples
    --------

    Usage::

        darc = await Api(nation="darcania")
        async for shard in Api(nation="testlandia"):
           print(pretty_string(shard))

        tnp = Api(region="the_north_pacific").threadsafe()
        for shard in Api(region="testregionia").threadsafe:
            print(pretty_string(shard))
    """

    __slots__ = ("__proxy", "_password", "_str", "_hash", "_last_response")

    def __new__(
        cls,
        *shards: _Union[str, _Iterable[str]],
        password: _Optional[str] = None,
        autologin: _Optional[str] = None,
        **parameters: str,
    ):
        if len(shards) == 1 and not parameters:
            if isinstance(shards[0], cls):
                return shards[0]
            with contextlib.suppress(Exception):
                return cls.from_url(shards[0])
        return super().__new__(cls)

    def __init__(
        self,
        *shards: _Union[str, _Iterable[str]],
        password: _Optional[str] = None,
        autologin: _Optional[str] = None,
        **parameters: str,
    ):
        has_nation = "nation" in parameters
        dicts = [parameters] if parameters else []
        for shard in filter(bool, shards):
            if isinstance(shard, Mapping):
                dicts.append(shard)
                if not has_nation and "nation" in shard:
                    has_nation = True
            else:
                dicts.append({"q": shard})
        if not has_nation and (password or autologin):
            raise ValueError(
                "Authentication may only be used with the Nation API.")
        self.__proxy = MappingProxyType(_normalize_dicts(*dicts))
        self._password = password
        self._last_response = None
        self._str = None
        self._hash = None

    async def __await(self):
        async for element in self.__aiter__(clear=False):
            pass
        return element

    def __await__(self):
        return self.__await().__await__()

    async def __aiter__(
        self,
        *,
        clear: bool = True
    ) -> _AsyncGenerator[objectify.ObjectifiedElement, None]:
        if not Api.agent:
            raise RuntimeError("The API's user agent is not yet set.")
        if "a" in self and self["a"].lower() == "sendtg":
            raise RuntimeError(
                "This API wrapper does not support API telegrams.")
        if not self:
            # Preempt the request to conserve ratelimit
            raise BadRequest()
        url = str(self)

        headers = {"User-Agent": Api.agent}
        if self._password:
            headers["X-Password"] = self._password
        autologin = self.autologin
        if autologin:
            headers["X-Autologin"] = autologin
        if self.get("nation") in PINS:
            headers["X-Pin"] = PINS[self["nation"]]

        async with Api.session.request("GET", url,
                                       headers=headers) as response:
            self._last_response = response
            if "X-Autologin" in response.headers:
                self._password = None
            if "X-Pin" in response.headers:
                PINS[self["nation"]] = response.headers["X-Pin"]
            response.raise_for_status()

            encoding = (response.headers["Content-Type"].split("charset=")
                        [1].split(",")[0])
            with contextlib.suppress(etree.XMLSyntaxError), contextlib.closing(
                    etree.XMLPullParser(["end"],
                                        base_url=url,
                                        remove_blank_text=True)) as parser:
                parser.set_element_class_lookup(
                    objectify.ObjectifyElementClassLookup())
                events = parser.read_events()

                async for data, _ in response.content.iter_chunks():
                    parser.feed(data.decode(encoding))
                    for _, element in events:
                        if clear and (element.getparent() is None
                                      or element.getparent().getparent()
                                      is not None):
                            continue
                        yield element
                        if clear:
                            element.clear(keep_tail=True)

    def __add__(self, other: _Any) -> "Api":
        if isinstance(other, str):
            with contextlib.suppress(Exception):
                other = type(self).from_url(other)
        with contextlib.suppress(Exception):
            return type(self)(self, other)
        return NotImplemented

    def __bool__(self):
        if any(a in self for a in ("nation", "region")):
            return True
        if "a" in self:
            if self["a"] == "verify" and all(a in self
                                             for a in ("nation", "checksum")):
                return True
            if self["a"] == "sendtg" and all(
                    a in self for a in ("client", "tgid", "key", "to")):
                return True
            return False
        return "q" in self

    def __contains__(self, key: str) -> bool:
        return key in self.__proxy

    def __dir__(self):
        return set(super().__dir__()).union(
            dir(self.__proxy),
            (a for a in dir(type(self)) if not hasattr(type, a)))

    __eq__ = None

    def __getattr__(self, name: str):
        with contextlib.suppress(AttributeError):
            return getattr(self.__proxy, name)
        raise AttributeError(
            f"{type(self).__name__!r} has no attribute {name!r}")

    def __getitem__(self, key):
        return self.__proxy[str(key).lower()]

    def __hash__(self):
        if self._hash is not None:
            return self._hash
        params = sorted((k, v if isinstance(v, str) else " ".join(sorted(v)))
                        for k, v in self.items())
        self._hash = hash(tuple(params))
        return self._hash

    def __iter__(self):
        return iter(self.__proxy)

    def __len__(self):
        return len(self.__proxy)

    def __repr__(self) -> str:
        return "{}.{}({})".format(
            type(self).__module__,
            type(self).__name__,
            ", ".join("{}={!r}".format(*t) for t in self.__proxy.items()),
        )

    def __str__(self) -> str:
        if self._str is not None:
            return self._str
        params = [(k, v if isinstance(v, str) else "+".join(v))
                  for k, v in self.items()]
        self._str = urlunparse((*API_URL, None, urlencode(params,
                                                          safe="+"), None))
        return self._str

    @property
    def autologin(self) -> _Optional[str]:
        """
        If a private nation shard was properly requested and returned,
        this property may be used to get the "X-Autologin" token.
        """
        if self._last_response:
            return self._last_response.headers.get("X-Autologin")
        return None

    @property
    def last_headers(self) -> _Optional[_Mapping[str, str]]:
        """
        Returns the headers returned from the last request this API object sent.
        """
        if self._last_response:
            return self._last_response.headers
        return None

    @property
    def last_response(self) -> _Optional[NSResponse]:
        """
        Returns the response object from the last request this API object sent.
        """
        return self._last_response

    @property
    def threadsafe(self) -> Threadsafe:
        """
        Returns a threadsafe wrapper around this object.

        The returned wrapper may be called, awaited, or iterated over.
        Both standard and async iteration are supported.
        """
        return Threadsafe(self)

    @classmethod
    def from_url(cls, url: str, *shards: _Iterable[str],
                 **parameters: _Iterable[str]) -> "Api":
        """
        Constructs an Api object from a provided URL.

        The Api object may be further modified with shards and parameters,
        as per the :class:`Api` constructor.
        """
        parsed_url = urlparse(str(url))
        url = parsed_url[:len(API_URL)]
        if any(url) and url != API_URL:
            raise ValueError(
                "URL must be solely query parameters or an API url")
        return cls(*shards, parse_qs(parsed_url.query), parameters)
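
# A rough, standalone illustration of the query-string round trip behind __str__()
# and from_url() above: the shards and parameters live in the query string, so
# parsing it back recovers an equivalent mapping. API_URL and the Api class are not
# needed for the idea; the parameter values below are made up.
from urllib.parse import parse_qs, urlencode

_params = {'nation': ['testlandia'], 'q': ['fullname population']}
_query = urlencode({k: ' '.join(v) for k, v in _params.items()})
assert parse_qs(_query) == _params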
Ejemplo n.º 14
0
class Environment(object):
    """Abstract base class for environments.

  Represents a type and configuration of environment.
  Each type of Environment should have a unique urn.

  For internal use only. No backwards compatibility guarantees.
  """

    _known_urns = {}  # type: Dict[str, Tuple[Optional[type], ConstructorFn]]
    _urn_to_env_cls = {}  # type: Dict[str, type]

    def __init__(
            self,
            capabilities=(),  # type: Iterable[str]
            artifacts=(
            ),  # type: Iterable[beam_runner_api_pb2.ArtifactInformation]
            resource_hints=None,  # type: Optional[Mapping[str, bytes]]
    ):
        # type: (...) -> None
        self._capabilities = capabilities
        self._artifacts = sorted(artifacts,
                                 key=lambda x: x.SerializeToString())
        # Hints on created environments should be immutable since pipeline context
        # stores environments in hash maps and we use hints to compute the hash.
        self._resource_hints = MappingProxyType(
            dict(resource_hints) if resource_hints else {})

    def __eq__(self, other):
        return (
            self.__class__ == other.__class__
            and self._artifacts == other._artifacts
            # Assuming that we don't have instances of the same Environment subclass
            # with different sets of capabilities.
            and self._resource_hints == other._resource_hints)

    def __hash__(self):
        # type: () -> int
        return hash((self.__class__, frozenset(self._resource_hints.items())))

    def artifacts(self):
        # type: () -> Iterable[beam_runner_api_pb2.ArtifactInformation]
        return self._artifacts

    def to_runner_api_parameter(self, context):
        # type: (PipelineContext) -> Tuple[str, Optional[Union[message.Message, bytes, str]]]
        raise NotImplementedError

    def capabilities(self):
        # type: () -> Iterable[str]
        return self._capabilities

    def resource_hints(self):
        # type: () -> Mapping[str, bytes]
        return self._resource_hints

    @classmethod
    @overload
    def register_urn(
            cls,
            urn,  # type: str
            parameter_type,  # type: Type[T]
    ):
        # type: (...) -> Callable[[Union[type, Callable[[T, Iterable[str], PipelineContext], Any]]], Callable[[T, Iterable[str], PipelineContext], Any]]
        pass

    @classmethod
    @overload
    def register_urn(
            cls,
            urn,  # type: str
            parameter_type,  # type: None
    ):
        # type: (...) -> Callable[[Union[type, Callable[[bytes, Iterable[str], Iterable[beam_runner_api_pb2.ArtifactInformation], PipelineContext], Any]]], Callable[[bytes, Iterable[str], PipelineContext], Any]]
        pass

    @classmethod
    @overload
    def register_urn(
        cls,
        urn,  # type: str
        parameter_type,  # type: Type[T]
        constructor  # type: Callable[[T, Iterable[str], Iterable[beam_runner_api_pb2.ArtifactInformation], PipelineContext], Any]
    ):
        # type: (...) -> None
        pass

    @classmethod
    @overload
    def register_urn(
        cls,
        urn,  # type: str
        parameter_type,  # type: None
        constructor  # type: Callable[[bytes, Iterable[str], Iterable[beam_runner_api_pb2.ArtifactInformation], PipelineContext], Any]
    ):
        # type: (...) -> None
        pass

    @classmethod
    def register_urn(cls, urn, parameter_type, constructor=None):
        def register(constructor):
            if isinstance(constructor, type):
                constructor.from_runner_api_parameter = register(
                    constructor.from_runner_api_parameter)
                # register environment urn to environment class
                cls._urn_to_env_cls[urn] = constructor
                return constructor

            else:
                cls._known_urns[urn] = parameter_type, constructor
                return staticmethod(constructor)

        if constructor:
            # Used as a statement.
            register(constructor)
        else:
            # Used as a decorator.
            return register

    @classmethod
    def get_env_cls_from_urn(cls, urn):
        # type: (str) -> Type[Environment]
        return cls._urn_to_env_cls[urn]

    def to_runner_api(self, context):
        # type: (PipelineContext) -> beam_runner_api_pb2.Environment
        urn, typed_param = self.to_runner_api_parameter(context)
        return beam_runner_api_pb2.Environment(
            urn=urn,
            payload=typed_param.SerializeToString() if isinstance(
                typed_param, message.Message) else typed_param if
            (isinstance(typed_param, bytes)
             or typed_param is None) else typed_param.encode('utf-8'),
            capabilities=self.capabilities(),
            dependencies=self.artifacts(),
            resource_hints=self.resource_hints())

    @classmethod
    def from_runner_api(
        cls,
        proto,  # type: Optional[beam_runner_api_pb2.Environment]
        context  # type: PipelineContext
    ):
        # type: (...) -> Optional[Environment]
        if proto is None or not proto.urn:
            return None
        parameter_type, constructor = cls._known_urns[proto.urn]

        return constructor(
            proto_utils.parse_Bytes(proto.payload,
                                    parameter_type), proto.capabilities,
            proto.dependencies, proto.resource_hints, context)

    @classmethod
    def from_options(cls, options):
        # type: (Type[EnvironmentT], PortableOptions) -> EnvironmentT
        """Creates an Environment object from PortableOptions.

    Args:
      options: The PortableOptions object.
    """
        raise NotImplementedError
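
# A hedged sketch of why resource hints are wrapped in MappingProxyType above: the
# environment is hashed over frozenset(hints.items()), so the hints must stay
# immutable after construction. _FakeEnv is illustrative only and not part of Beam.
from types import MappingProxyType


class _FakeEnv:
    def __init__(self, resource_hints=None):
        self._resource_hints = MappingProxyType(
            dict(resource_hints) if resource_hints else {})

    def __hash__(self):
        return hash((self.__class__, frozenset(self._resource_hints.items())))


assert hash(_FakeEnv({'min_ram': b'1G'})) == hash(_FakeEnv({'min_ram': b'1G'}))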
Ejemplo n.º 15
0
        skill['Two Heads'],
        skill['Very Long Legs'],
        ))),
    (skillcat['RACIAL CHARACTERISTICS'], frozenset((
        skill['Always Hungry'],
        skill['Big Guy'],
        skill['Blood Lust'],
        skill['Bone Head'],
        skill['Easily Confused'],
        skill['Hypnotic Gaze'],
        skill['Nurgle\'s Rot'],
        skill['Really Stupid'],
        skill['Regeneration'],
        skill['Right Stuff'],
        skill['Stunty'],
        skill['Take Root'],
        skill['Throw Team-Mate'],
        skill['Thrud\'s Fans'],
        skill['Wild Animal'],
        ))),
    )))


## REDUNDANT BUT USEFUL ########################################

skillcat_by_skill = MappingProxyType(
    {skill: skillcat
        for skillcat, skills in skill_by_skillcat.items()
        for skill in skills}
    )
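
# A self-contained version of the inversion above: turn a read-only
# category -> items mapping into a read-only item -> category lookup. The sample
# data is made up; the real module maps skill categories to skills.
from types import MappingProxyType

_items_by_category = MappingProxyType({
    'GENERAL': frozenset({'Block', 'Dodge'}),
    'MUTATION': frozenset({'Two Heads'}),
})

_category_by_item = MappingProxyType(
    {item: category
        for category, items in _items_by_category.items()
        for item in items})

assert _category_by_item['Dodge'] == 'GENERAL'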
Ejemplo n.º 16
0
class Api(metaclass=_ApiMeta):
    __slots__ = ("__dict", )

    def __init__(self, *shards: _Union[str, _Mapping[str, str]],
                 **kwargs: str):
        dicts = [kwargs] if kwargs else []
        for shard in shards:
            if isinstance(shard, Mapping) and shard:
                dicts.append(shard)
            elif shard:
                dicts.append({"q": shard})
        self.__dict = MappingProxyType(_normalize_dicts(*dicts))

    async def __await(self):
        # pylint: disable=E1133
        async for element in self.__aiter__(clear=False):
            pass
        return element

    def __await__(self):
        return self.__await().__await__()

    async def __aiter__(self, *, clear: bool = True):
        if not self:
            raise ValueError("Bad request")
        url = str(self)

        parser = etree.XMLPullParser(["end"],
                                     base_url=url,
                                     remove_blank_text=True)
        parser.set_element_class_lookup(
            etree.ElementDefaultClassLookup(element=_NSElement))
        events = parser.read_events()

        async with type(self).session.request(
                "GET", url, headers={"User-Agent":
                                     type(self).agent}) as response:
            yield parser.makeelement("HEADERS", attrib=response.headers)
            encoding = response.headers["Content-Type"].split(
                "charset=")[1].split(",")[0]
            async for data, _ in response.content.iter_chunks():
                parser.feed(data.decode(encoding))
                for _, element in events:
                    yield element
                    if clear:
                        element.clear()

    def __add__(self, other: _Any) -> "Api":
        with contextlib.suppress(Exception):
            return type(self)(self, other)
        return NotImplemented

    def __bool__(self):
        return any(a in self for a in ("a", "nation", "region", "q", "wa"))

    def __contains__(self, key):
        return key in self.__dict

    def __dir__(self):
        return set(super().__dir__()).union(dir(self.__dict))

    def __getattribute__(self, name):
        try:
            return super().__getattribute__(name)
        except AttributeError:
            with contextlib.suppress(AttributeError):
                return getattr(self.__dict, name)
            raise

    def __getitem__(self, key):
        return self.__dict[str(key).lower()]

    def __iter__(self):
        return iter(self.__dict)

    def __len__(self):
        return len(self.__dict)

    def __repr__(self) -> str:
        return "{}({})".format(
            type(self).__name__,
            ", ".join(
                "{}={!r}".format(k, v if isinstance(v, str) else " ".join(v))
                for k, v in self.__dict.items()),
        )

    def __str__(self) -> str:
        params = [(k, v if isinstance(v, str) else " ".join(v))
                  for k, v in self.items()]
        return urlunparse((*API_URL, None, urlencode(params), None))

    def copy(self):
        return type(self)(**self.__dict)

    @classmethod
    def from_url(cls, url: str, *args, **kwargs):
        parsed_url = urlparse(str(url))
        url = parsed_url[:len(API_URL)]
        if any(url) and url != API_URL:
            raise ValueError(
                "URL must be solely query parameters or an API url")
        return cls(*args, dict(parse_qsl(parsed_url.query)), kwargs)
Ejemplo n.º 17
0
class Composition:
    """
    Defines a composition of a compound.

    To create a composition, use the class methods:

        - :meth:`from_pure`
        - :meth:`from_formula`
        - :meth:`from_mass_fractions`
        - :meth:`from_atomic_fractions`

    Use the following attributes to access the composition values:

        - :attr:`mass_fractions`: :class:`dict` where the keys are atomic numbers and the values are weight fractions.
        - :attr:`atomic_fractions`: :class:`dict` where the keys are atomic numbers and the values are atomic fractions.
        - :attr:`formula`: chemical formula

    The composition object is immutable, i.e. it cannot be modified once created.
    Equality can be checked.
    It is hashable.
    It can be pickled or copied.
    """

    _key = object()
    PRECISION = 0.000000001 # 1ppb

    def __init__(self, key, mass_fractions, atomic_fractions, formula):
        """
        Private constructor. It should never be used.
        """
        if key != Composition._key:
            raise TypeError('Composition cannot be created using constructor')
        if set(mass_fractions.keys()) != set(atomic_fractions.keys()):
            raise ValueError('Mass and atomic fractions must have the same elements')

        self.mass_fractions = MappingProxyType(mass_fractions)
        self.atomic_fractions = MappingProxyType(atomic_fractions)
        self._formula = formula

    @classmethod
    def from_pure(cls, z):
        """
        Creates a pure composition.

        Args:
            z (int): atomic number
        """
        return cls(cls._key, {z: 1.0}, {z: 1.0}, pyxray.element_symbol(z))

    @classmethod
    def from_formula(cls, formula):
        """
        Creates a composition from a chemical formula.

        Args:
            formula (str): chemical formula
        """
        atomic_fractions = convert_formula_to_atomic_fractions(formula)
        return cls.from_atomic_fractions(atomic_fractions)

    @classmethod
    def from_mass_fractions(cls, mass_fractions, formula=None):
        """
        Creates a composition from a mass fraction :class:`dict`.

        Args:
            mass_fractions (dict): mass fraction :class:`dict`.
                The keys are atomic numbers and the values are weight fractions.
                Wildcards are accepted, e.g. ``{5: '?', 25: 0.4}`` where boron
                will get a mass fraction of 0.6.
            formula (str): optional chemical formula for the composition.
                If ``None``, a formula will be generated for the composition.
        """
        mass_fractions = process_wildcard(mass_fractions)
        atomic_fractions = convert_mass_to_atomic_fractions(mass_fractions)
        if not formula:
            formula = generate_name(atomic_fractions)
        return cls(cls._key, mass_fractions, atomic_fractions, formula)

    @classmethod
    def from_atomic_fractions(cls, atomic_fractions, formula=None):
        """
        Creates a composition from an atomic fraction :class:`dict`.

        Args:
            atomic_fractions (dict): atomic fraction :class:`dict`.
                The keys are atomic numbers and the values are atomic fractions.
                Wildcards are accepted, e.g. ``{5: '?', 25: 0.4}`` where boron
                will get an atomic fraction of 0.6.
            formula (str): optional chemical formula for the composition.
                If ``None``, a formula will be generated for the composition.
        """
        atomic_fractions = process_wildcard(atomic_fractions)
        mass_fractions = convert_atomic_to_mass_fractions(atomic_fractions)
        if not formula:
            formula = generate_name(atomic_fractions)
        return cls(cls._key, mass_fractions, atomic_fractions, formula)

    def __len__(self):
        return len(self.mass_fractions)

    def __contains__(self, z):
        return z in self.mass_fractions

    def __iter__(self):
        return iter(self.mass_fractions.keys())

    def __repr__(self):
        return '<{}({})>'.format(self.__class__.__name__, self.inner_repr())

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False

        if len(self) != len(other):
            return False

        for z in self.mass_fractions:
            if z not in other.mass_fractions:
                return False

            fraction = self.mass_fractions[z]
            other_fraction = other.mass_fractions[z]

            if not math.isclose(fraction, other_fraction, abs_tol=self.PRECISION):
                return False

        return True

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        out = []
        for z in sorted(self.mass_fractions):
            out.append(z)
            out.append(int(self.mass_fractions[z] / self.PRECISION))

        return hash(tuple(out))

    def __getstate__(self):
        return {'mass_fractions': dict(self.mass_fractions),
                'atomic_fractions': dict(self.atomic_fractions),
                'formula': self.formula}

    def __setstate__(self, state):
        self.mass_fractions = MappingProxyType(state.get('mass_fractions', {}))
        self.atomic_fractions = MappingProxyType(state.get('atomic_fractions', {}))
        self._formula = state.get('formula', '')

    def is_normalized(self):
        return math.isclose(sum(self.mass_fractions.values()), 1.0, abs_tol=self.PRECISION)

    def inner_repr(self):
        return ', '.join('{}: {:.4f}'.format(pyxray.element_symbol(z), mass_fraction) for z, mass_fraction in self.mass_fractions.items())

    @property
    def formula(self):
        return self._formula
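
# A minimal sketch of the immutability/hashing pattern used by Composition above:
# fraction dicts are exposed through MappingProxyType and hashed over
# (z, int(fraction / PRECISION)) pairs, mirroring __hash__ above. This illustrates
# the pattern only; it is not the class's public API.
from types import MappingProxyType

_PRECISION = 1e-9  # same 1 ppb tolerance as Composition.PRECISION


def _fractions_hash(mass_fractions):
    frozen = MappingProxyType(dict(mass_fractions))  # read-only snapshot
    out = []
    for z in sorted(frozen):
        out.append(z)
        out.append(int(frozen[z] / _PRECISION))
    return hash(tuple(out))


assert _fractions_hash({26: 0.7, 6: 0.3}) == _fractions_hash({6: 0.3, 26: 0.7})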
Ejemplo n.º 18
0
        ))),
    )))


stat_value = {
    '+MA': 30000,
    '+AG': 40000,
    '+ST': 50000,
    '+AV': 30000,
    }


## REDUNDANT BUT USEFUL ########################################

deck_of_card = MappingProxyType(
    {card: deck
        for deck, cards in cards_of_deck.items()
        for card in cards}
    )

card_price = MappingProxyType(
    {card: deck_price[deck]
        for card, deck in deck_of_card.items()}
    )

skillcat_by_skill = MappingProxyType(
    {skill: skillcat
        for skillcat, skills in skill_by_skillcat.items()
        for skill in skills}
    )
Ejemplo n.º 19
0
    def __init__(self, raster, list_of_prod_fp, channel_ids, is_flat,
                 dst_nodata, interpolation, max_queue_size, parent_uid,
                 key_in_parent):
        # Mutable attributes ******************************************************************** **
        # Attributes that relate a query to a single optional computation phase
        self.cache_computation = None  # type: Union[None, CacheComputationInfos]

        # Immutable attributes ****************************************************************** **
        self.parent_uid = parent_uid
        self.key_in_parent = key_in_parent

        # The parameters given by user in invocation
        self.channel_ids = channel_ids  # type: Sequence[int]
        self.is_flat = is_flat  # type: bool
        self.unique_channel_ids = []
        for bi in channel_ids:
            if bi not in self.unique_channel_ids:
                self.unique_channel_ids.append(bi)
        self.unique_channel_ids = tuple(self.unique_channel_ids)

        self.dst_nodata = dst_nodata  # type: Union[int, float]
        self.interpolation = interpolation  # type: str

        # Output max queue size (Parameter given to queue.Queue)
        self.max_queue_size = max_queue_size  # type: int

        # How many arrays are requested
        self.produce_count = len(list_of_prod_fp)  # type: int

        # Build CacheProduceInfos objects **************************************
        to_zip = []

        # The list of Footprints requested
        list_of_prod_fp = list_of_prod_fp  # type: List[ProductionFootprint]
        to_zip.append(list_of_prod_fp)

        # Boolean attribute of each `prod_fp`
        # If `True` the resampling phase has to be performed on a Pool
        list_of_prod_same_grid = [
            fp.same_grid(raster.fp) for fp in list_of_prod_fp
        ]  # type: List[bool]
        to_zip.append(list_of_prod_same_grid)

        # Boolean attribute of each `prod_fp`
        # If `False` the queried footprint is outside of the raster's footprint. It means that no
        # sampling is necessary and the output array will be full of `dst_nodata`
        list_of_prod_share_area = [
            fp.share_area(raster.fp) for fp in list_of_prod_fp
        ]  # type: List[bool]
        to_zip.append(list_of_prod_share_area)

        # The full Footprint that needs to be sampled for each `prod_fp`
        # Is `None` if `prod_fp` is fully outside of raster
        list_of_prod_sample_fp = []  # type: List[Union[None, SampleFootprint]]
        to_zip.append(list_of_prod_sample_fp)

        # The set of cache Footprints that are needed for each `prod_fp`
        list_of_prod_cache_fps = []  # type: List[FrozenSet[CacheFootprint]]
        to_zip.append(list_of_prod_cache_fps)

        # The list of resamplings to perform for each `prod_fp`
        # Always at least 1 resampling per `prod_fp`
        list_of_prod_resample_fps = [
        ]  # type: List[Tuple[ResampleFootprint, ...]]
        to_zip.append(list_of_prod_resample_fps)

        # The set of `cache_fp` necessary per `resample_fp` for each `prod_fp`
        list_of_prod_resample_cache_deps_fps = [
        ]  # type: List[Mapping[ResampleFootprint, FrozenSet[CacheFootprint]]]
        to_zip.append(list_of_prod_resample_cache_deps_fps)

        # The full Footprint that needs to be sampled per `resample_fp` for each `prod_fp`
        # Might be `None` if `prod_fp` is fully outside of raster
        list_of_prod_resample_sample_dep_fp = [
        ]  # type: List[Mapping[ResampleFootprint, Union[None, SampleFootprint]]]
        to_zip.append(list_of_prod_resample_sample_dep_fp)

        it = zip(list_of_prod_fp, list_of_prod_same_grid,
                 list_of_prod_share_area)
        # TODO: Speed up that piece of code
        # - Code footprint with lower level code
        # - Spawn a ProcessPoolExecutor when >100 prod. (The same could be done for fp.tile).
        # - What about a global process pool executor in `buzz.env`?
        for prod_fp, same_grid, share_area in it:
            if not share_area:
                # Resampling will be performed in one pass, on the scheduler
                list_of_prod_sample_fp.append(None)
                list_of_prod_cache_fps.append(frozenset())
                resample_fp = cast(ResampleFootprint, prod_fp)
                list_of_prod_resample_fps.append((resample_fp, ))
                list_of_prod_resample_cache_deps_fps.append(
                    MappingProxyType({resample_fp: frozenset()}))
                list_of_prod_resample_sample_dep_fp.append(
                    MappingProxyType({resample_fp: None}))
            else:
                if same_grid:
                    # Remapping will be performed in one pass, on the scheduler
                    sample_fp = raster.fp & prod_fp
                    resample_fps = [cast(ResampleFootprint, prod_fp)]
                    sample_dep_fp = {resample_fps[0]: sample_fp}
                else:
                    sample_fp = raster.build_sampling_footprint_to_remap_interpolate(
                        prod_fp, interpolation)

                    if raster.max_resampling_size is None:
                        # Remapping will be performed in one pass, on a Pool
                        resample_fps = [cast(ResampleFootprint, prod_fp)]
                        sample_dep_fp = {resample_fps[0]: sample_fp}
                    else:
                        # Resampling will be performed in several passes, on a Pool
                        rsize = np.maximum(prod_fp.rsize, sample_fp.rsize)
                        countx, county = np.ceil(
                            rsize / raster.max_resampling_size).astype(int)
                        resample_fps = prod_fp.tile_count(
                            countx, county,
                            boundary_effect='shrink').flatten().tolist()
                        sample_dep_fp = {
                            resample_fp:
                            (raster.
                             build_sampling_footprint_to_remap_interpolate(
                                 resample_fp, interpolation)
                             if resample_fp.share_area(raster.fp) else None)
                            for resample_fp in resample_fps
                        }

                resample_cache_deps_fps = MappingProxyType({
                    resample_fp:
                    frozenset(raster.cache_fps_of_fp(sample_subfp))
                    for resample_fp in resample_fps
                    for sample_subfp in [sample_dep_fp[resample_fp]]
                    if sample_subfp is not None
                })
                # Each resample_fp kept here must depend on at least one cache_fp
                for s in resample_cache_deps_fps.values():
                    assert len(s) > 0

                # The `intersection of the cache_fps with sample_fp` might not be the same as
                # the `intersection of the cache_fps with resample_fps`!
                cache_fps = frozenset(
                    itertools.chain.from_iterable(
                        resample_cache_deps_fps.values()))
                assert len(cache_fps) > 0

                list_of_prod_cache_fps.append(cache_fps)
                list_of_prod_sample_fp.append(sample_fp)
                list_of_prod_resample_fps.append(tuple(resample_fps))
                list_of_prod_resample_cache_deps_fps.append(
                    resample_cache_deps_fps)
                list_of_prod_resample_sample_dep_fp.append(
                    MappingProxyType(sample_dep_fp))

        self.prod = tuple([CacheProduceInfos(*args) for args in zip(*to_zip)
                           ])  # type: Tuple[CacheProduceInfos, ...]

        # Misc *****************************************************************
        # The list of all cache Footprints needed, ordered by priority
        self.list_of_cache_fp = []  # type: Sequence[CacheFootprint]
        seen = set()
        for fps in list_of_prod_cache_fps:
            for fp in fps:
                if fp not in seen:
                    seen.add(fp)
                    self.list_of_cache_fp.append(fp)
        self.list_of_cache_fp = tuple(self.list_of_cache_fp)
        del seen

        # The dict of cache Footprint to set of production idxs
        # For each `cache_fp`, the set of prod_idx that need this cache tile
        self.dict_of_prod_idxs_per_cache_fp = collections.defaultdict(
            set)  # type: Mapping[CacheFootprint, AbstractSet[int]]
        for i, (prod_fp, cache_fps) in enumerate(
                zip(list_of_prod_fp, list_of_prod_cache_fps)):
            for cache_fp in cache_fps:
                self.dict_of_prod_idxs_per_cache_fp[cache_fp].add(i)
        for k, v in self.dict_of_prod_idxs_per_cache_fp.items():
            self.dict_of_prod_idxs_per_cache_fp[k] = frozenset(v)
        self.dict_of_prod_idxs_per_cache_fp = MappingProxyType(
            self.dict_of_prod_idxs_per_cache_fp)

        # The dict of cache Footprint to production_idx
        # For each `cache_fp`, the minimum prod_idx that need this cache tile
        self.dict_of_min_prod_idx_per_cache_fp = {
        }  # type: Mapping[CacheFootprint, int]
        for k, v in self.dict_of_prod_idxs_per_cache_fp.items():
            self.dict_of_min_prod_idx_per_cache_fp[k] = min(v)
        self.dict_of_min_prod_idx_per_cache_fp = MappingProxyType(
            self.dict_of_min_prod_idx_per_cache_fp)
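
# A standalone sketch of the freezing pattern used just above: accumulate into a
# defaultdict(set), convert each value to a frozenset, then expose the result
# through MappingProxyType so later code cannot mutate it by accident. The sample
# keys and indices are made up.
import collections
from types import MappingProxyType

_prod_idxs_per_key = collections.defaultdict(set)
for _prod_idx, _keys in enumerate([('a', 'b'), ('b',)]):
    for _key in _keys:
        _prod_idxs_per_key[_key].add(_prod_idx)
for _key, _idxs in _prod_idxs_per_key.items():
    _prod_idxs_per_key[_key] = frozenset(_idxs)
_prod_idxs_per_key = MappingProxyType(dict(_prod_idxs_per_key))

assert _prod_idxs_per_key['b'] == frozenset({0, 1})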
Ejemplo n.º 20
0
class Loader:
    def __init__(self, filename):
        self.filename = filename
        self._categories = {'bot': BotCategory()}
        self.categories = MappingProxyType(self._categories)
        self._errors = []

    def add_category(self, name, category):
        if name in self._categories:
            raise SettingsKeyError(f'Defaults for category `{name}`'
                                   ' provided more than once.')

        if not (category and isinstance(category, Category)):
            raise SettingsTypeError(f'Incorrect Category type for `{name}`:'
                                    f' `{type(category)}`')

        self._categories[name] = category

    def remove_category(self, name):
        return self._categories.pop(name, None) is not None

    def load(self):
        self._errors.clear()
        toml_dict = toml.load(self.filename)

        for name, category in self.categories.items():
            self._load_category(category, name, toml_dict)

        if self._errors:
            raise SettingsError(self._errors)

        if self._has_leftover_values(toml_dict):
            logging.warning(f'Leftover settings remain: `{repr(toml_dict)}`')

        return Settings(self.categories)

    def _load_category(self, category, name, toml_dict, _prefix=''):
        settings = [s for s in dir(category) if not s.startswith('__')]
        for setting in settings:
            try:
                self._load_setting(setting, category, name, toml_dict, _prefix)
            except SettingsError as ex:
                self._errors.append(ex)

    def _load_setting(self, setting, category, name, toml_dict, prefix):
        default = getattr(category, setting)
        setting_type = type(default)
        required = False

        if isinstance(default, Required):
            setting_type = default.type
            required = True

        category_dict = toml_dict.get(name, {})

        if not isinstance(category_dict, dict):
            raise SettingsTypeError(f'Setting `{name}` must be a category'
                                    f' but was `{type(category_dict)}'
                                    ' instead.`')

        if isinstance(default, Category):
            return self._load_category(default, setting, category_dict,
                                       f'{name}.{prefix}')

        if setting in category_dict:
            # remove (pop) from dict so we know it was used
            value = category_dict.pop(setting)
            value_type = type(value)

            if value_type is int and setting_type is float:
                value = float(value)

            elif value_type is not setting_type:
                raise SettingsTypeError(f'Setting `{prefix}{name}:{setting}`'
                                        f' must be of type `{setting_type}`'
                                        f' but was `{value_type}` instead.')

            setattr(category, setting, value)

        elif required:
            raise SettingsKeyError(f'Setting `{prefix}{name}:{setting}`'
                                   ' required but not provided.')

    def _has_leftover_values(self, toml_dict):
        has_leftovers = False

        # copy to list since dict changes size while iterating
        for key, value in list(toml_dict.items()):
            if not isinstance(value, dict):
                has_leftovers = True
                continue

            if self._has_leftover_values(value):
                has_leftovers = True
            else:
                del toml_dict[key]  # remove empty subdicts

        return has_leftovers
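
# A self-contained sketch of the leftover-detection idea used by Loader above:
# known settings are consumed with pop() while loading, and whatever remains in the
# parsed dict afterwards is reported as unrecognised. The data below is made up.
_known_settings = {'token', 'prefix'}
_parsed = {'bot': {'token': 'abc', 'prefix': '!', 'colour': 'red'}}

for _setting in list(_parsed['bot']):
    if _setting in _known_settings:
        _parsed['bot'].pop(_setting)

_leftovers = {k: v for k, v in _parsed.items() if v}
assert _leftovers == {'bot': {'colour': 'red'}}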
Ejemplo n.º 21
0
class Trainer:
    def __init__(self, model, dataset, criterion, optimizer, settings):
        super().__init__()
        context = deepcopy(settings)
        self.ctx = MappingProxyType(vars(context))
        self.mode = ('train', 'val').index(context.cmd)

        self.logger = R['LOGGER']
        self.gpc = R['GPC']  # Global Path Controller
        self.path = self.gpc.get_path

        self.batch_size = context.batch_size
        self.checkpoint = context.resume
        self.load_checkpoint = (len(self.checkpoint) > 0)
        self.num_epochs = context.num_epochs
        self.lr = float(context.lr)
        self.save = context.save_on or context.out_dir
        self.out_dir = context.out_dir
        self.trace_freq = int(context.trace_freq)
        self.device = torch.device(context.device)
        self.suffix_off = context.suffix_off

        for k, v in sorted(self.ctx.items()):
            self.logger.show("{}: {}".format(k, v))

        self.model = model_factory(model, context)
        self.model.to(self.device)
        self.criterion = critn_factory(criterion, context)
        self.criterion.to(self.device)
        self.metrics = metric_factory(context.metrics, context)

        if self.is_training:
            self.train_loader = data_factory(dataset, 'train', context)
            self.val_loader = data_factory(dataset, 'val', context)
            self.optimizer = optim_factory(optimizer, self.model, context)
        else:
            self.val_loader = data_factory(dataset, 'val', context)

        self.start_epoch = 0
        self._init_max_acc_and_epoch = (0.0, 0)

    @property
    def is_training(self):
        return self.mode == 0

    def train_epoch(self, epoch):
        raise NotImplementedError

    def validate_epoch(self, epoch=0, store=False):
        raise NotImplementedError

    def _write_prompt(self):
        self.logger.dump(input("\nWrite some notes: "))

    def run(self):
        if self.is_training:
            self._write_prompt()
            self.train()
        else:
            self.evaluate()

    def train(self):
        if self.load_checkpoint:
            self._resume_from_checkpoint()

        max_acc, best_epoch = self._init_max_acc_and_epoch

        for epoch in range(self.start_epoch, self.num_epochs):
            lr = self._adjust_learning_rate(epoch)

            self.logger.show_nl("Epoch: [{0}]\tlr {1:.06f}".format(epoch, lr))

            # Train for one epoch
            self.train_epoch(epoch)

            # Clear the history of metric objects
            for m in self.metrics:
                m.reset()

            # Evaluate the model on validation set
            self.logger.show_nl("Validate")
            acc = self.validate_epoch(epoch=epoch, store=self.save)

            is_best = acc > max_acc
            if is_best:
                max_acc = acc
                best_epoch = epoch
            self.logger.show_nl(
                "Current: {:.6f} ({:03d})\tBest: {:.6f} ({:03d})\t".format(
                    acc, epoch, max_acc, best_epoch))

            # The checkpoint saves next epoch
            self._save_checkpoint(self.model.state_dict(),
                                  self.optimizer.state_dict(),
                                  (max_acc, best_epoch), epoch + 1, is_best)

    def evaluate(self):
        if self.checkpoint:
            if self._resume_from_checkpoint():
                self.validate_epoch(self.ckp_epoch, self.save)
        else:
            self.logger.warning("Warning: no checkpoint assigned!")

    def _adjust_learning_rate(self, epoch):
        if self.ctx['lr_mode'] == 'step':
            lr = self.lr * (0.5**(epoch // self.ctx['step']))
        elif self.ctx['lr_mode'] == 'poly':
            lr = self.lr * (1 - epoch / self.num_epochs)**1.1
        elif self.ctx['lr_mode'] == 'const':
            lr = self.lr
        else:
            raise ValueError('unknown lr mode {}'.format(self.ctx['lr_mode']))

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

    def _resume_from_checkpoint(self):
        ## XXX: This could be slow!
        if not os.path.isfile(self.checkpoint):
            self.logger.error("=> No checkpoint was found at '{}'.".format(
                self.checkpoint))
            return False

        self.logger.show("=> Loading checkpoint '{}'".format(self.checkpoint))
        checkpoint = torch.load(self.checkpoint, map_location=self.device)

        state_dict = self.model.state_dict()
        ckp_dict = checkpoint.get('state_dict', checkpoint)
        update_dict = {
            k: v
            for k, v in ckp_dict.items()
            if k in state_dict and state_dict[k].shape == v.shape
        }

        num_to_update = len(update_dict)
        if (num_to_update < len(state_dict)) or (len(state_dict) <
                                                 len(ckp_dict)):
            if not self.is_training and (num_to_update < len(state_dict)):
                self.logger.error("=> Mismatched checkpoint for evaluation")
                return False
            self.logger.warning(
                "Warning: trying to load an mismatched checkpoint.")
            if num_to_update == 0:
                self.logger.error("=> No parameter is to be loaded.")
                return False
            else:
                self.logger.warning(
                    "=> {} params are to be loaded.".format(num_to_update))
        elif (not self.ctx['anew']) or not self.is_training:
            self.start_epoch = checkpoint.get('epoch', 0)
            max_acc_and_epoch = checkpoint.get('max_acc',
                                               (0.0, self.ckp_epoch))
            # For backward compatibility
            if isinstance(max_acc_and_epoch, (float, int)):
                self._init_max_acc_and_epoch = (max_acc_and_epoch,
                                                self.ckp_epoch)
            else:
                self._init_max_acc_and_epoch = max_acc_and_epoch
            if self.ctx['load_optim'] and self.is_training:
                try:
                    # Note that weight decay might be modified here
                    self.optimizer.load_state_dict(checkpoint['optimizer'])
                except KeyError:
                    self.logger.warning(
                        "Warning: failed to load optimizer parameters.")

        state_dict.update(update_dict)
        self.model.load_state_dict(state_dict)

        self.logger.show(
            "=> Loaded checkpoint '{}' (epoch {}, max_acc {:.4f} at epoch {})".
            format(self.checkpoint, self.ckp_epoch,
                   *self._init_max_acc_and_epoch))
        return True

    def _save_checkpoint(self, state_dict, optim_state, max_acc, epoch,
                         is_best):
        state = {
            'epoch': epoch,
            'state_dict': state_dict,
            'optimizer': optim_state,
            'max_acc': max_acc
        }
        # Save history
        history_path = self.path('weight',
                                 constants.CKP_COUNTED.format(e=epoch),
                                 underline=True)
        if epoch % self.trace_freq == 0:
            torch.save(state, history_path)
        # Save latest
        latest_path = self.path('weight', constants.CKP_LATEST, underline=True)
        torch.save(state, latest_path)
        if is_best:
            shutil.copyfile(
                latest_path,
                self.path('weight', constants.CKP_BEST, underline=True))

    @property
    def ckp_epoch(self):
        # Get current epoch of the checkpoint
        # For a mismatched ckp or no ckp, this is 0
        return max(self.start_epoch - 1, 0)

    def save_image(self, file_name, image, epoch):
        file_path = os.path.join('epoch_{}/'.format(epoch), self.out_dir,
                                 file_name)
        out_path = self.path('out',
                             file_path,
                             suffix=not self.suffix_off,
                             auto_make=True,
                             underline=True)
        return io.imsave(out_path, image)
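
# A minimal sketch of the settings snapshot taken in Trainer.__init__ above:
# vars(obj) returns the instance __dict__, and wrapping it in MappingProxyType
# yields a read-only view of the configuration (Trainer deep-copies the settings
# first so later changes to the original object do not leak through the proxy).
# SimpleNamespace stands in for the real settings object here.
from types import MappingProxyType, SimpleNamespace

_settings = SimpleNamespace(batch_size=8, lr=0.01)
_ctx = MappingProxyType(vars(_settings))

assert _ctx['lr'] == 0.01
try:
    _ctx['lr'] = 0.1  # the proxy rejects item assignment
except TypeError:
    pass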
Ejemplo n.º 22
0
def convert_mapping_proxy_type(value: MappingProxyType):
    return dict(value.items())
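
# Example usage of the converter above: the result is an ordinary, mutable dict,
# while the proxy itself would reject item assignment.
from types import MappingProxyType

_proxy = MappingProxyType({'a': 1})
_plain = convert_mapping_proxy_type(_proxy)
_plain['b'] = 2  # fine: plain dict
assert dict(_proxy) == {'a': 1}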
Ejemplo n.º 23
0
class Composition:
    """
    Defines a composition of a compound.

    To create a composition, use the class methods:

        - :meth:`from_pure`
        - :meth:`from_formula`
        - :meth:`from_mass_fractions`
        - :meth:`from_atomic_fractions`

    Use the following attributes to access the composition values:

        - :attr:`mass_fractions`: :class:`dict` where the keys are atomic numbers and the values are weight fractions.
        - :attr:`atomic_fractions`: :class:`dict` where the keys are atomic numbers and the values are atomic fractions.
        - :attr:`formula`: chemical formula

    The composition object is immutable, i.e. it cannot be modified once created.
    Equality can be checked.
    It is hashable.
    It can be pickled or copied.
    """

    _key = object()
    PRECISION = 0.000000001  # 1ppb

    def __init__(self, key, mass_fractions, atomic_fractions, formula):
        """
        Private constructor. It should never be used.
        """
        if key != Composition._key:
            raise TypeError("Composition cannot be created using constructor")
        if set(mass_fractions.keys()) != set(atomic_fractions.keys()):
            raise ValueError(
                "Mass and atomic fractions must have the same elements")

        self.mass_fractions = MappingProxyType(mass_fractions)
        self.atomic_fractions = MappingProxyType(atomic_fractions)
        self._formula = formula

    @classmethod
    def from_pure(cls, z):
        """
        Creates a pure composition.

        Args:
            z (int): atomic number
        """
        return cls(cls._key, {z: 1.0}, {z: 1.0}, pyxray.element_symbol(z))

    @classmethod
    def from_formula(cls, formula):
        """
        Creates a composition from a chemical formula.

        Args:
            formula (str): chemical formula
        """
        atomic_fractions = convert_formula_to_atomic_fractions(formula)
        return cls.from_atomic_fractions(atomic_fractions)

    @classmethod
    def from_mass_fractions(cls, mass_fractions, formula=None):
        """
        Creates a composition from a mass fraction :class:`dict`.

        Args:
            mass_fractions (dict): mass fraction :class:`dict`.
                The keys are atomic numbers and the values are weight fractions.
                Wildcards are accepted, e.g. ``{5: '?', 25: 0.4}`` where boron
                will get a mass fraction of 0.6.
            formula (str): optional chemical formula for the composition.
                If ``None``, a formula will be generated for the composition.
        """
        mass_fractions = process_wildcard(mass_fractions)
        atomic_fractions = convert_mass_to_atomic_fractions(mass_fractions)
        if not formula:
            formula = generate_name(atomic_fractions)
        return cls(cls._key, mass_fractions, atomic_fractions, formula)

    @classmethod
    def from_atomic_fractions(cls, atomic_fractions, formula=None):
        """
        Creates a composition from an atomic fraction :class:`dict`.

        Args:
            atomic_fractions (dict): atomic fraction :class:`dict`.
                The keys are atomic numbers and the values are atomic fractions.
                Wildcards are accepted, e.g. ``{5: '?', 25: 0.4}`` where boron
                will get an atomic fraction of 0.6.
            formula (str): optional chemical formula for the composition.
                If ``None``, a formula will be generated for the composition.
        """
        atomic_fractions = process_wildcard(atomic_fractions)
        mass_fractions = convert_atomic_to_mass_fractions(atomic_fractions)
        if not formula:
            formula = generate_name(atomic_fractions)
        return cls(cls._key, mass_fractions, atomic_fractions, formula)

    def __len__(self):
        return len(self.mass_fractions)

    def __contains__(self, z):
        return z in self.mass_fractions

    def __iter__(self):
        return iter(self.mass_fractions.keys())

    def __repr__(self):
        return "<{}({})>".format(self.__class__.__name__, self.inner_repr())

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False

        if len(self) != len(other):
            return False

        for z in self.mass_fractions:
            if z not in other.mass_fractions:
                return False

            fraction = self.mass_fractions[z]
            other_fraction = other.mass_fractions[z]

            if not math.isclose(
                    fraction, other_fraction, abs_tol=self.PRECISION):
                return False

        return True

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        out = []
        for z in sorted(self.mass_fractions):
            out.append(z)
            out.append(int(self.mass_fractions[z] / self.PRECISION))

        return hash(tuple(out))

    def __getstate__(self):
        return {
            "mass_fractions": dict(self.mass_fractions),
            "atomic_fractions": dict(self.atomic_fractions),
            "formula": self.formula,
        }

    def __setstate__(self, state):
        self.mass_fractions = MappingProxyType(state.get("mass_fractions", {}))
        self.atomic_fractions = MappingProxyType(
            state.get("atomic_fractions", {}))
        self._formula = state.get("formula", "")

    def is_normalized(self):
        return math.isclose(sum(self.mass_fractions.values()),
                            1.0,
                            abs_tol=self.PRECISION)

    def inner_repr(self):
        return ", ".join(
            "{}: {:.4f}".format(pyxray.element_symbol(z), mass_fraction)
            for z, mass_fraction in self.mass_fractions.items())

    @property
    def formula(self):
        return self._formula