def structural(p):
    """Exercise the structural dataclass helpers on *p*.

    *p* must support len() and be a dataclass instance; each helper is
    invoked for validity only and its result is discarded.
    """
    print(len(p))
    for helper in (
        dataclasses.fields,
        dataclasses.asdict,
        dataclasses.astuple,
        dataclasses.replace,
    ):
        helper(p)
Example #2
0
 def parse(cls, s):
     """Build an instance of *cls* from *s*, a ``cls._sep``-separated string.

     Each item is converted with the corresponding field's type callable;
     trailing fields may be omitted. Raises ValueError when *s* holds more
     items than *cls* declares fields.
     """
     declared = dataclasses.fields(cls)
     parts = s.split(cls._sep)
     if len(parts) > len(declared):
         raise ValueError(f'Too many items for {cls.__qualname__}: {s}')
     kwargs = {}
     for spec, raw in zip(declared, parts):
         kwargs[spec.name] = spec.type(raw)
     return cls(**kwargs)
Example #3
0
    def fields(self) -> Dict[str, _Field]:
        """Return a mapping of field name to ``_Field`` for ``self.type``.

        Built lazily on first access from ``dataclasses.fields(self.type)``
        and cached in ``self._fields``. Each ``_Field`` pairs an optional
        default factory with the field's resolved type hint.
        """
        def make_factory(value: object) -> Callable[[], Any]:
            # Capture `value` in a closure so every field gets its own
            # factory (avoids late binding of the loop variable below).
            return lambda: value

        if self._fields is None:
            hints = typing.get_type_hints(self.type)
            # For each field, store its default_factory if present; otherwise
            # create a factory returning its default if present; otherwise
            # None. make_factory() binds the default eagerly so the dict
            # comprehension's loop variable isn't shared across factories.
            fields: Dict[str, _Field] = {
                field.name: _Field(
                    field.default_factory  # type: ignore
                    if field.default_factory is not MISSING  # type: ignore
                    else (
                        (make_factory(field.default))
                        if field.default is not MISSING
                        else None
                    ),
                    hints[field.name],
                )
                for field in dataclasses.fields(self.type)
            }

            self._fields = fields

        return self._fields
Example #4
0
def message(cls: Type[T]) -> Type[T]:
    """
    Returns the same class as was passed in, with additional dunder attributes needed for
    serialization and deserialization.
    """

    type_hints = get_type_hints(cls)

    try:
        # Used to list all fields and locate fields by field number.
        field_pairs = [
            make_field(f.metadata['number'], f.name, type_hints[f.name])
            for f in dataclasses.fields(cls)
        ]
        cls.__protobuf_fields__: Dict[int, Field] = dict(field_pairs)
    except KeyError as e:
        # FIXME: catch `KeyError` in `make_field` and re-raise as `TypeError`.
        raise TypeError(f'type is not serializable: {e}') from e

    # noinspection PyUnresolvedReferences
    Message.register(cls)

    # Attach the serializer and the Message convenience methods in one place.
    for attr, value in (
        ('serializer', MessageSerializer(cls)),
        ('type_url', f'type.googleapis.com/{cls.__module__}.{cls.__name__}'),
        ('validate', Message.validate),
        ('dump', Message.dump),
        ('dumps', Message.dumps),
        ('merge_from', Message.merge_from),
        ('load', classmethod(load)),
        ('loads', classmethod(loads)),
    ):
        setattr(cls, attr, value)

    return cls
Example #5
0
    def unsafe_astuple(self) -> tuple:
        """Convert the fields into a tuple.

        Returns:
            A tuple of the fields of the dataclass.

        Use case:
            dataclasses.astuple returns a deepcopy of pig_acc. Inside PigJar pigs and
            accounts are registered by id in config.data and so this method will return the
            actual pig_acc.
        """
        return tuple([self.__dict__[field.name] for field in fields(self)])
Example #6
0
 def from_dict(cls, dict_data):
     """Construct *cls* from a plain dict, coercing values via field types.

     List-typed fields are rebuilt element-wise; other fields are passed
     through ``cls._constructor_from_field_type``. Keys missing from
     *dict_data* fall back to the field's default_factory/default.

     Bug fix: presence is now tested with ``is not None`` instead of
     truthiness, so explicit falsy values (0, False, '', []) supplied by
     the caller are preserved instead of being silently replaced by the
     field defaults.
     """
     data = {}
     for field in fields(cls):
         datum = dict_data.get(field.name)
         if datum is not None:
             if issubclass(field.type, List):
                 sub_type = field.type.__args__[0]
                 constructor = cls._constructor_from_field_type(sub_type)
                 data[field.name] = [constructor(sub_data) for sub_data in datum]
             else:
                 constructor = cls._constructor_from_field_type(field.type)
                 data[field.name] = constructor(datum)
         elif field.default_factory is not MISSING:
             data[field.name] = field.default_factory()
         elif field.default is not MISSING:
             data[field.name] = field.default
     return cls(**data)
Example #7
0
    def _create_header(self, columns: FieldLike) -> None:
        """Build this table's Record dataclass from *columns*.

        Accepts column descriptors as plain names, (name, type) pairs or
        (name, type, Field-or-dict) triples, creates the dataclass with the
        table's hook methods in its namespace, and resets store/diff/index.
        """
        # Check types of fieldlike column descriptors and convert them to field
        # descriptors, that are accepted by dataclasses.make_dataclass()
        fields: list = []
        for each in columns:
            if isinstance(each, str):
                fields.append(each)
                continue
            check.has_type(f"field {each}", each, tuple)
            check.has_size(f"field {each}", each, min_size=2, max_size=3)
            check.has_type("first arg", each[0], str)
            check.has_type("second arg", each[1], type)
            if len(each) == 2:
                fields.append(each)
                continue
            check.has_type("third arg", each[2], (Field, dict))
            if isinstance(each[2], Field):
                fields.append(each)
                continue
            # A dict third element holds keyword args for dataclasses.field()
            field = dataclasses.field(**each[2])
            fields.append(each[:2] + (field,))

        # Create record namespace with table hooks
        namespace = {
            '_create_row_id': self._create_row_id,
            '_delete_hook': self._remove_row_id,
            '_restore_hook': self._append_row_id,
            '_update_hook': self._update_row_diff,
            '_revoke_hook': self._remove_row_diff}

        # Create Record dataclass and constructor
        self._Record = dataclasses.make_dataclass(
            'Row', fields, bases=(Record, ), namespace=namespace)

        # Create slots
        # NOTE(review): assigning __slots__ after class creation does not
        # remove the instance __dict__; presumably informational — confirm.
        self._Record.__slots__ = ['id', 'state'] + [
            field.name for field in dataclasses.fields(self._Record)]

        # Reset store, diff and index
        self._store = []
        self._diff = []
        self._index = []
def unknown(p):
    """Call the read-only dataclass helpers on *p*; results are discarded."""
    for helper in (dataclasses.fields, dataclasses.asdict, dataclasses.astuple):
        helper(p)
import dataclasses
from typing import Type, Union


@dataclasses.dataclass
class Base:
    # Field-less dataclass used as the common ancestor for A and B below.
    pass


class A(Base):
    # NOT decorated with @dataclass itself; it only inherits the machinery
    # generated on Base (e.g. __dataclass_fields__).
    pass


# A merely inherits from a dataclass and is not decorated itself, yet the
# helpers below still resolve via the inherited __dataclass_fields__.
# (Fix: removed embedded IDE inspection markup that made the asdict line
# invalid Python.)
dataclasses.fields(A)
dataclasses.fields(A())

dataclasses.asdict(A())
dataclasses.astuple(A())
dataclasses.replace(A())


@dataclasses.dataclass
class B(Base):
    # Decorated subclass: a dataclass in its own right (still field-less).
    pass


# fields() accepts the class or an instance; asdict/astuple need an instance.
dataclasses.fields(B)
dataclasses.fields(B())

dataclasses.asdict(B())
dataclasses.astuple(B())
Example #10
0
@dataclass
class Position:
    """A named point; lon/lat record their unit ('degrees') in field metadata."""
    name: str
    lon: float = field(metadata={'unit': 'degrees'})
    lat: float = field(metadata={'unit': 'degrees'})

@dataclass
class Capital(Position):
    """A Position that is a capital city; appends the country field."""
    country: str


# Positional args follow field declaration order: name, lon, lat, country.
x = Capital('Oslo', 10.8, 59.9, 'Norway')
print(x.name, x.lon, x.lat, '\n')

# fields() works on instances and classes; index 2 is 'lat', whose
# metadata dict set above survives on the Field object.
print(fields(x))
print(fields(Position)[2].name)
print(fields(Position)[2].metadata['unit'], '\n')


# ==============================================================
# demonstrates the post_init; which runs after the auto __init__
# shows use of default factory function
# and an example of ordering

RANKS = '2 3 4 5 6 7 8 9 10 J Q K A'.split()
SUITS = '♣ ♢ ♡ ♠'.split()

@dataclass(order=True)  # we can make comparisons between objects
class PlayingCard:
    # init=False means the field isnt initialised; it's done by the post_init below
Example #11
0
    def get_schema_required(cls):
        """Return the names of fields a schema must mark as required.

        A field is required when it has neither a ``default`` nor a
        ``default_factory`` (either one lets the dataclass fill the value).

        Fixes: the original compared ``field.default == dataclasses.MISSING``
        (the MISSING sentinel should be compared by identity; ``==`` also
        misbehaves for defaults with a broad ``__eq__``), implicitly returned
        None instead of False for optional fields, and ignored
        ``default_factory`` entirely.
        """
        def is_required(field: dataclasses.Field) -> bool:
            # MISSING is a sentinel object: compare by identity, not equality.
            return (field.default is dataclasses.MISSING
                    and field.default_factory is dataclasses.MISSING)

        return [field.name for field in dataclasses.fields(cls) if is_required(field)]
Example #12
0
def merge_old_version(version, new, old):
    """Record character-property differences between two UCD versions.

    Compares ``old.table`` and ``new.table`` entry by entry and appends a
    change record for *version* to ``new.changed``: per-character arrays of
    bidi/category/decimal/mirrored/east-asian-width/numeric changes plus a
    list of (codepoint, value) normalization pairs.
    """
    # Changes to exclusion file not implemented yet
    if old.exclusions != new.exclusions:
        raise NotImplementedError("exclusions differ")

    # In these change records, 0xFF means "no change"
    bidir_changes = [0xFF] * 0x110000
    category_changes = [0xFF] * 0x110000
    decimal_changes = [0xFF] * 0x110000
    mirrored_changes = [0xFF] * 0x110000
    east_asian_width_changes = [0xFF] * 0x110000
    # In numeric data, 0 means "no change",
    # -1 means "did not have a numeric value
    numeric_changes = [0] * 0x110000
    # normalization_changes is a list of key-value pairs
    normalization_changes = []
    for i in range(0x110000):
        if new.table[i] is None:
            # Characters unassigned in the new version ought to
            # be unassigned in the old one
            assert old.table[i] is None
            continue
        # check characters unassigned in the old version
        if old.table[i] is None:
            # category 0 is "unassigned"
            category_changes[i] = 0
            continue
        # check characters that differ
        if old.table[i] != new.table[i]:
            # k is the positional index of the UcdRecord field; the branch
            # chain below is keyed to that declaration order.
            for k, field in enumerate(dataclasses.fields(UcdRecord)):
                value = getattr(old.table[i], field.name)
                new_value = getattr(new.table[i], field.name)
                if value != new_value:
                    if k == 1 and i in PUA_15:
                        # the name is not set in the old.table, but in the
                        # new.table we are using it for aliases and named seq
                        assert value == ''
                    elif k == 2:
                        category_changes[i] = CATEGORY_NAMES.index(value)
                    elif k == 4:
                        bidir_changes[i] = BIDIRECTIONAL_NAMES.index(value)
                    elif k == 5:
                        # We assume that all normalization changes are in 1:1 mappings
                        assert " " not in value
                        normalization_changes.append((i, value))
                    elif k == 6:
                        # we only support changes where the old value is a single digit
                        assert value in "0123456789"
                        decimal_changes[i] = int(value)
                    elif k == 8:
                        # Since 0 encodes "no change", the old value is better not 0
                        if not value:
                            numeric_changes[i] = -1
                        else:
                            numeric_changes[i] = float(value)
                            assert numeric_changes[i] not in (0, -1)
                    elif k == 9:
                        if value == 'Y':
                            mirrored_changes[i] = '1'
                        else:
                            mirrored_changes[i] = '0'
                    elif k == 11:
                        # change to ISO comment, ignore
                        pass
                    elif k == 12:
                        # change to simple uppercase mapping; ignore
                        pass
                    elif k == 13:
                        # change to simple lowercase mapping; ignore
                        pass
                    elif k == 14:
                        # change to simple titlecase mapping; ignore
                        pass
                    elif k == 15:
                        # change to east asian width
                        east_asian_width_changes[
                            i] = EASTASIANWIDTH_NAMES.index(value)
                    elif k == 16:
                        # derived property changes; not yet
                        pass
                    elif k == 17:
                        # normalization quickchecks are not performed
                        # for older versions
                        pass
                    else:
                        # Any other differing field is unexpected; surface it
                        # loudly with the codepoint and both records.
                        class Difference(Exception):
                            pass

                        raise Difference(hex(i), k, old.table[i], new.table[i])
    new.changed.append(
        (version,
         list(
             zip(bidir_changes, category_changes, decimal_changes,
                 mirrored_changes, east_asian_width_changes,
                 numeric_changes)), normalization_changes))
Example #13
0
 def keys(self) -> Tuple[str, ...]:
     """Return the field names with underscores rendered as spaces."""
     return tuple(spec.name.replace("_", " ") for spec in fields(self))
Example #14
0
def mqtt_on_connect(client: mqtt.Client, userdata, flags, rc):
  """On connect, subscribe to one data topic per property field, plus the
  broker's subscription log."""
  topics = [(_mqtt_topics['sub'].format(prop.name), 0)
            for prop in fields(_data.properties)]
  client.subscribe(topics)
  # Subscribe to subscription updates.
  client.subscribe('$SYS/broker/log/M/subscribe/#')
Example #15
0
 def state_data(self):
     """Return a dict mapping every field name to its current value."""
     return dict((spec.name, getattr(self, spec.name)) for spec in fields(self))
Example #16
0
 def args(self) -> List[Any]:
     """Return the field values in declaration order."""
     return [getattr(self, spec.name) for spec in fields(self)]
    def shape_resolved_transaction(self, transaction_id, raw: dict):
        """Normalize a raw transaction dict into a resolved-transaction object.

        Parses the amount into Money, resolves the pipeline section, coerces
        string-ish values to the dataclass field types of the section's
        transaction class, and instantiates it. Returns
        (pipeline_section_id, resolved_transaction).

        Raises TransactionShapingError for malformed amounts, unknown pipeline
        sections, uncoercible values, and missing/unknown fields.

        Fixes: removed a second ``except ValueError`` handler that was
        unreachable (ValueError is already caught by the first handler on the
        same try), and narrowed a bare ``except:`` to ``except Exception:`` so
        KeyboardInterrupt/SystemExit are no longer swallowed.
        """
        try:
            raw['amount'] = Money(
                amount=Decimal(raw['amount']), currency=raw['currency']
            )
            del raw['currency']
        except (DecimalException, ValueError, KeyError):
            raise TransactionShapingError(
                "Transaction amount must be specified in decimal \'amount\' and "
                "\'currency\' fields."
            )
        try:
            pipeline_section_id = int(raw['pipeline_section_id'])
            del raw['pipeline_section_id']
        except KeyError:
            # Single-section pipelines may omit the section id.
            if len(self.pipeline_spec) == 1:
                pipeline_section_id = 0
            else:
                raise TransactionShapingError(
                    'pipeline_section_id is required on all transactions'
                )
        try:
            rt_class, preparator = self.pipeline_spec[pipeline_section_id]
        except IndexError:
            raise TransactionShapingError(
                'Invalid pipeline section \'%d\'.', pipeline_section_id
            )

        # attempt to reprocess fields in dict
        # we allow coercions of string values, for easier interoperability
        #  with html attributes
        for f in dataclasses.fields(rt_class):
            try:
                raw_field = raw[f.name]
            except KeyError:
                # will be dealt with later, if necessary
                continue
            if f.name == 'amount':  # we already dealt with this
                continue
            if f.type is datetime:
                try:
                    ts = datetime.fromisoformat(raw_field)
                except (KeyError, ValueError, TypeError):
                    raise TransactionShapingError(
                        'Could not parse ISO datetime \'%s\'.',
                        raw_field
                    )
                if ts.tzinfo is None:
                    # naive datetime - treat as UTC
                    raw[f.name] = pytz.utc.localize(ts)
                else:
                    # replace timezone by UTC
                    raw[f.name] = ts.astimezone(pytz.utc)
            elif f.type is bool:
                if isinstance(raw_field, bool):
                    continue
                elif isinstance(raw_field, str):
                    if raw_field.casefold() == 'true':
                        raw[f.name] = True
                    elif raw_field.casefold() == 'false':
                        raw[f.name] = False
                    else:
                        raise TransactionShapingError(
                            'Invalid boolean string \'%s\'.', raw_field
                        )
                else:
                    raise TransactionShapingError(
                        'Boolean fields must be represented as booleans '
                        'or \'true\'/\'false\' strings.'
                    )
            else:
                if type(raw_field) is f.type:
                    continue
                try:
                    raw[f.name] = f.type(raw_field)
                except Exception:
                    # Narrowed from a bare except: coercion failures of any
                    # kind become a shaping error, but control-flow exceptions
                    # (KeyboardInterrupt/SystemExit) now propagate.
                    raise TransactionShapingError(
                        'Failed to coerce \'%(value)s\' of type \'%(value_type)s\' '
                        'to value of type \'%(field_type)s\' in field '
                        '\'%(field_name)s\'.', {
                            'value': raw_field, 'value_type': type(raw_field),
                            'field_type': f.type, 'field_name': f.name
                        }
                    )

        raw.setdefault('do_not_skip', False)

        try:
            resolved_transaction = rt_class(
                **raw,
                message_context=APIErrorContext(transaction_id=transaction_id)
            )
        except (TypeError, ValueError):
            # Work out a precise error: which supplied keys are unknown, and
            # which required fields are missing.
            keys_specified = frozenset(raw.keys())
            all_fields = {
                f.name for f in dataclasses.fields(rt_class)
                if f.name != 'message_context'
            }
            over_defined = keys_specified - all_fields
            required_fields = {
                f.name for f in dataclasses.fields(rt_class)
                if f.default_factory is dataclasses.MISSING
                    and f.default is dataclasses.MISSING
                    and f.name != 'message_context'
            }
            under_defined = required_fields - keys_specified
            if over_defined and under_defined:
                raise TransactionShapingError(
                    'The fields \'%s\' are required, and '
                    'the fields \'%s\' are undefined.' % (
                        ', '.join(under_defined),
                        ', '.join(over_defined)
                    )
                )
            elif over_defined:
                raise TransactionShapingError(
                    'The fields \'%s\' are undefined.' % ', '.join(over_defined)
                )
            elif under_defined:
                raise TransactionShapingError(
                    'The fields \'%s\' are required.' % ', '.join(under_defined)
                )
            else:  # pragma: no cover
                raise TransactionShapingError(
                    'Failed to instantiate \'resolved_transaction\'.'
                )
        return pipeline_section_id, resolved_transaction
Example #18
0
 def __mul__(self, other):
     """Scale every dimension by *other*, returning a new instance."""
     scaled = [getattr(self, dim.name) * other for dim in fields(self)]
     return self.__class__(*scaled)
Example #19
0
 def __sub__(self, other):
     """Dimension-wise difference of two instances of this class."""
     deltas = (getattr(self, dim.name) - getattr(other, dim.name)
               for dim in fields(self))
     return self.__class__(*deltas)
def union1(p: Union[A, B]):
    """Run every dataclass helper on *p*, whichever union member it is."""
    for helper in (
        dataclasses.fields,
        dataclasses.asdict,
        dataclasses.astuple,
        dataclasses.replace,
    ):
        helper(p)
Example #21
0
 def pretty(self, indent=''):
     """Return a multi-line, right-aligned rendering of the repr fields.

     Each line is '<indent><name>: <value>'; names are padded to the longest
     *displayed* field name. Fields declared with repr=False are skipped.

     Fixes: the padding width was previously computed over ALL fields, so a
     long repr=False name over-padded the output; max() also raised
     ValueError on a dataclass with no fields (now defaults to 0).
     """
     shown = [field for field in fields(self) if field.repr]
     max_fieldlen = max((len(field.name) for field in shown), default=0)
     return '\n'.join(
         f'{indent}{field.name.rjust(max_fieldlen)}: {getattr(self, field.name)}'
         for field in shown)
Example #22
0
 def get_field(self, name):
     """For `name` = "field name", retrieve a field "field_name", if any."""
     candidates = (name, name.replace(" ", "_"))
     for f in fields(self):
         if f.name in candidates:
             return f
     return None
Example #23
0
 def from_group(cls, file_path: Path, group: Group) -> 'Matrix':
     """Build a Matrix from *group*, looking up every field after the first
     two (path and name) by field name in the group."""
     extra = [group[spec.name] for spec in fields(cls)[2:]]
     return cls(file_path, group.name, *extra)
Example #24
0
 def pack(self):
     """Serialize the init-able field values into bytes per self.format()."""
     values = [getattr(self, spec.name)
               for spec in fields(type(self)) if spec.init]
     return struct.pack(self.format(), *values)
Example #25
0
 def variables(self):
     """Return the names of all model variables, in declaration order."""
     return [spec.name for spec in dataclasses.fields(self)]
Example #26
0
    def completedOrder(self, fields):
        """Decode a completedOrder server message.

        *fields* is the flat list of wire values. Star-unpacking consumes the
        fixed-size prefix, and the remainder is re-bound to ``fields``
        repeatedly as count-prefixed variable-length sections (combo legs,
        params, conditions, ...) are parsed. The decoded Contract, Order and
        OrderState are parsed and handed to ``self.wrapper.completedOrder``.
        The unpack order must match the server's wire order exactly.
        """
        o = Order()
        c = Contract()
        st = OrderState()

        (_, c.conId, c.symbol, c.secType, c.lastTradeDateOrContractMonth,
         c.strike, c.right, c.multiplier, c.exchange, c.currency,
         c.localSymbol, c.tradingClass, o.action, o.totalQuantity, o.orderType,
         o.lmtPrice, o.auxPrice, o.tif, o.ocaGroup, o.account, o.openClose,
         o.origin, o.orderRef, o.permId, o.outsideRth, o.hidden,
         o.discretionaryAmt, o.goodAfterTime, o.faGroup, o.faMethod,
         o.faPercentage, o.faProfile, o.modelCode, o.goodTillDate, o.rule80A,
         o.percentOffset, o.settlingFirm, o.shortSaleSlot,
         o.designatedLocation, o.exemptCode, o.startingPrice, o.stockRefPrice,
         o.delta, o.stockRangeLower, o.stockRangeUpper, o.displaySize,
         o.sweepToFill, o.allOrNone, o.minQty, o.ocaType, o.triggerMethod,
         o.volatility, o.volatilityType, o.deltaNeutralOrderType,
         o.deltaNeutralAuxPrice, *fields) = fields

        if o.deltaNeutralOrderType:
            (o.deltaNeutralConId, o.deltaNeutralShortSale,
             o.deltaNeutralShortSaleSlot, o.deltaNeutralDesignatedLocation,
             *fields) = fields
        (o.continuousUpdate, o.referencePriceType, o.trailStopPrice,
         o.trailingPercent, c.comboLegsDescrip, *fields) = fields

        # Variable-length sections: each is prefixed by its element count.
        numLegs = int(fields.pop(0))
        c.comboLegs = []
        for _ in range(numLegs):
            leg = ComboLeg()
            (leg.conId, leg.ratio, leg.action, leg.exchange, leg.openClose,
             leg.shortSaleSlot, leg.designatedLocation, leg.exemptCode,
             *fields) = fields
            self.parse(leg)
            c.comboLegs.append(leg)

        numOrderLegs = int(fields.pop(0))
        o.orderComboLegs = []
        for _ in range(numOrderLegs):
            leg = OrderComboLeg()
            leg.price = fields.pop(0)
            self.parse(leg)
            o.orderComboLegs.append(leg)

        numParams = int(fields.pop(0))
        if numParams > 0:
            o.smartComboRoutingParams = []
            for _ in range(numParams):
                tag, value, *fields = fields
                o.smartComboRoutingParams.append(TagValue(tag, value))
        (o.scaleInitLevelSize, o.scaleSubsLevelSize, increment,
         *fields) = fields

        o.scalePriceIncrement = float(increment or UNSET_DOUBLE)
        if 0 < o.scalePriceIncrement < UNSET_DOUBLE:
            (o.scalePriceAdjustValue, o.scalePriceAdjustInterval,
             o.scaleProfitOffset, o.scaleAutoReset, o.scaleInitPosition,
             o.scaleInitFillQty, o.scaleRandomPercent, *fields) = fields

        o.hedgeType = fields.pop(0)
        if o.hedgeType:
            o.hedgeParam = fields.pop(0)

        (o.clearingAccount, o.clearingIntent, o.notHeld, dncPresent,
         *fields) = fields

        if int(dncPresent):
            conId, delta, price, *fields = fields
            c.deltaNeutralContract = DeltaNeutralContract(
                int(conId or 0), float(delta or 0), float(price or 0))

        o.algoStrategy = fields.pop(0)
        if o.algoStrategy:
            numParams = int(fields.pop(0))
            if numParams > 0:
                o.algoParams = []
                for _ in range(numParams):
                    tag, value, *fields = fields
                    o.algoParams.append(TagValue(tag, value))
        (o.solicited, st.status, o.randomizeSize, o.randomizePrice,
         *fields) = fields

        if o.orderType == 'PEG BENCH':
            (o.referenceContractId, o.isPeggedChangeAmountDecrease,
             o.peggedChangeAmount, o.referenceChangeAmount,
             o.referenceExchangeId, *fields) = fields

        numConditions = int(fields.pop(0))
        if numConditions > 0:
            for _ in range(numConditions):
                condType = int(fields.pop(0))
                condCls = OrderCondition.createClass(condType)
                # Condition classes consume one wire value per dataclass
                # field except the leading type discriminator.
                n = len(dataclasses.fields(condCls)) - 1
                cond = condCls(condType, *fields[:n])
                self.parse(cond)
                o.conditions.append(cond)
                fields = fields[n:]
            (o.conditionsIgnoreRth, o.conditionsCancelOrder, *fields) = fields

        (o.trailStopPrice, o.lmtPriceOffset, o.cashQty, *fields) = fields

        # Trailing fields only exist on newer server protocol versions.
        if self.serverVersion >= 141:
            o.dontUseAutoPriceForHedge = fields.pop(0)
        if self.serverVersion >= 145:
            o.isOmsContainer = fields.pop(0)

        (o.autoCancelDate, o.filledQuantity, o.refFuturesConId,
         o.autoCancelParent, o.shareholder, o.imbalanceOnly,
         o.routeMarketableToBbo, o.parentPermId, st.completedTime,
         st.completedStatus) = fields

        self.parse(c)
        self.parse(o)
        self.parse(st)
        self.wrapper.completedOrder(c, o, st)
Example #27
0
 def normalized(self):
     """Return a copy with every field value coerced through its declared type."""
     coerced = {spec.name: spec.type(getattr(self, spec.name))
                for spec in fields(self)}
     return self.__class__(**coerced)
Example #28
0
    def __post_init__(self):
        """Validation logic that runs after an object has been instantiated.

        Checks each field's value against its declared type hint, handling
        self-referential ForwardRefs, Optional/Union, typed lists (including
        instantiating nested objects from dicts held in the list), and
        nested-dataclass dicts. Raises TypeError on any mismatch.

        Based heavily on:
        https://stackoverflow.com/questions/50563546/validating-detailed-types-in-python-dataclasses
        """
        for field_def in fields(self):
            field_name = field_def.name
            field_value = getattr(self, field_name)
            actual_type = type(field_value)

            if hasattr(field_def.type, '__origin__'):
                # If a type hint uses typing.List, we need to check the origin
                # in order to see that it's a list
                expected_type = field_def.type.__origin__
            else:
                expected_type = field_def.type

            # Lists are a special case, because we have to get the list element
            # type in a different way
            if field_value is not None:
                class_name = self.__class__.__name__

                # A ForwardRef will appear to just be a str
                # Check that the expected type is a str instead of an actual
                # type definition, and check that the name of the current class
                # matches the string in the ForwardRef.
                if (class_name == expected_type
                        and isinstance(expected_type, str)):
                    # Double check that the type itself and the current class
                    # are the same
                    if actual_type != self.__class__:
                        raise TypeError((f'{class_name}.{field_name} was '
                                         'defined as a <class '
                                         f"'{expected_type}'>, "
                                         f'but we found a {actual_type} '
                                         'instead'))
                else:
                    # Optionals are technically just Union[T, None]
                    if expected_type == typing.Union:
                        possible_types = field_def.type.__args__
                        matches = (isinstance(field_value, possible_type)
                                   for possible_type in possible_types)
                        if not any(matches):
                            raise TypeError((f'{class_name}.{field_name} was '
                                             'defined to be any of: '
                                             f'{possible_types} but was found '
                                             f'to be {actual_type} instead'))

                    elif (isinstance(field_value, expected_type)
                          and isinstance(field_value, list)):
                        # Typed list: require typing.List[...] with a concrete
                        # (non-TypeVar, non-native-collection) element type.
                        if not hasattr(field_def.type, '__args__'):
                            raise TypeError((f'{class_name}.{field_name} was '
                                             f'defined as a {actual_type}, '
                                             'but you must use '
                                             'typing.List[type] '
                                             'instead'))

                        expected_element_type = field_def.type.__args__[0]
                        if isinstance(expected_element_type, typing.TypeVar):
                            raise TypeError((f'{class_name}.{field_name} was '
                                             f'defined as a {actual_type}, '
                                             'but is missing information '
                                             'about the'
                                             ' type of the elements inside '
                                             'it'))

                        if not self._ensure_no_native_collections(
                                expected_element_type):
                            raise TypeError(((f'{class_name}.{field_name} was '
                                              'detected to use a native '
                                              'Python '
                                              'collection in its type '
                                              'definition. '
                                              'We should only use '
                                              'typing.List[] '
                                              'for these')))

                        for i, element in enumerate(field_value):
                            if isinstance(element, dict):
                                if not element:
                                    raise TypeError(((f'{class_name}.'
                                                      f'{field_name} '
                                                      'was found to have an '
                                                      'empty dictionary. An '
                                                      'empty '
                                                      'dictionary will not '
                                                      'properly instantiate a '
                                                      'nested object')))

                                # Set reference of the specific list index.
                                # Kind of a hack, to get around the fact that
                                # __setattr__ can only seem to take field
                                # names, but not indices
                                getattr(self,
                                        field_name)[i] = expected_element_type(
                                            **element)

                        if not self._validate_list_types(
                                field_value, field_def.type):
                            raise TypeError((f'{class_name}.{field_name} is '
                                             f'{field_value} which does not '
                                             'match '
                                             f'{field_def.type}. '
                                             'Unfortunately, '
                                             'we are unable to infer the '
                                             'explicit '
                                             f'type of {class_name}.'
                                             f'{field_name}'))

                    elif not isinstance(field_value, expected_type):
                        # Mismatch: a dict may still instantiate a nested
                        # dataclass; anything else is a hard error.
                        if isinstance(field_value, dict):
                            if not self._ensure_no_native_collections(
                                    expected_type):
                                raise TypeError((f'{class_name}.{field_name} '
                                                 'was '
                                                 'detected to use a native '
                                                 'Python '
                                                 'dict in its type '
                                                 'definition. '
                                                 'We should only use custom '
                                                 'objects for these'))
                            try:
                                setattr(self, field_name,
                                        expected_type(**field_value))
                            except TypeError:
                                raise TypeError(f'{class_name}.{field_name} '
                                                'is '
                                                'expected to be '
                                                f'{expected_type}, but value '
                                                f'{field_value} is a dict '
                                                'with unexpected keys')
                        else:
                            raise TypeError(f'{class_name}.{field_name} is '
                                            'expected to be '
                                            f'{expected_type}, but value '
                                            f'{field_value} with '
                                            f'type {actual_type} was found '
                                            'instead')
Example #29
0
    def serialize_mappings() -> SerializeMappings:
        """Map each int-typed Config field name to a converter that wraps
        the value in HexInt (None values pass through via map_option)."""
        def to_hex(value: Optional[int]) -> Optional[HexInt]:
            return map_option(value, HexInt)

        mappings = {}
        for fld in fields(Config):
            if fld.type is int:
                mappings[fld.name] = to_hex
        return mappings
Example #30
0
    def serialize(self):
        """Pack every dataclass field value of this instance with struct
        format 'L' (unsigned long) and concatenate into one bytes object."""
        packed_parts = []
        for fld in fields(self):
            packed_parts.append(pack('L', getattr(self, fld.name)))
        return b''.join(packed_parts)
Example #31
0
 def to_str_pretty(self) -> str:
     """Render the non-None Config field values as 'name=value' pairs,
     comma-separated; int-typed fields are shown in hex.  The field.type
     check (rather than isinstance) is deliberate: bool is a subclass of
     int and must not be hex-formatted."""
     parts = []
     for fld in fields(Config):
         value = getattr(self, fld.name)
         if value is None:
             continue
         rendered = hex(value) if fld.type is int else str(value)
         parts.append(f"{fld.name}={rendered}")
     return ", ".join(parts)
    def __post_init__(self, log=_log, model=None, experiment=None):
        """Finish initialization: fill in CASE_ROOT_DIR, map CLI synonyms
        (model/experiment) onto DRS attributes, validate field values
        against the CMIP6 controlled vocabulary, infer missing DRS
        attributes from known ones, and set CATALOG_DIR as deep in the
        DRS hierarchy as the known attributes allow.
        """
        super(CMIP6DataSourceAttributes, self).__post_init__(log=log)
        config = core.ConfigManager()
        cv = cmip6.CMIP6_CVs()

        # Set attribute `dest` by a controlled-vocabulary lookup keyed on
        # attribute `source`; on any failure, clear `dest` to "".
        def _init_x_from_y(source, dest):
            if not getattr(self, dest, ""):
                try:
                    source_val = getattr(self, source, "")
                    if not source_val:
                        # empty source: treat same as a failed CV lookup
                        raise KeyError()
                    dest_val = cv.lookup_single(source_val, source, dest)
                    log.debug("Set %s='%s' based on %s='%s'.", dest, dest_val,
                              source, source_val)
                    setattr(self, dest, dest_val)
                except KeyError:
                    log.debug("Couldn't set %s from %s='%s'.", dest, source,
                              source_val)
                    setattr(self, dest, "")

        # Fall back to the globally-configured data root when unset.
        if not self.CASE_ROOT_DIR and config.CASE_ROOT_DIR:
            log.debug("Using global CASE_ROOT_DIR = '%s'.",
                      config.CASE_ROOT_DIR)
            self.CASE_ROOT_DIR = config.CASE_ROOT_DIR
        # verify case root dir exists
        if not os.path.isdir(self.CASE_ROOT_DIR):
            log.critical("Data directory CASE_ROOT_DIR = '%s' not found.",
                         self.CASE_ROOT_DIR)
            util.exit_handler(code=1)

        # should really fix this at the level of CLI flag synonyms
        if model and not self.source_id:
            self.source_id = model
        if experiment and not self.experiment_id:
            self.experiment_id = experiment

        # validate non-empty field values
        for field in dataclasses.fields(self):
            val = getattr(self, field.name, "")
            if not val:
                continue
            try:
                if not cv.is_in_cv(field.name, val):
                    # Log but continue: a bad value degrades queries rather
                    # than being fatal here.
                    log.error((
                        "Supplied value '%s' for '%s' is not recognized by "
                        "the CMIP6 CV. Continuing, but queries will probably fail."
                    ), val, field.name)
            except KeyError:
                # raised if not a valid CMIP6 CV category
                continue
        # currently no inter-field consistency checks: happens implicitly, since
        # set_experiment will find zero experiments.

        # Attempt to determine first few fields of DRS, to avoid having to crawl
        # entire DRS structure
        _init_x_from_y('experiment_id', 'activity_id')
        _init_x_from_y('source_id', 'institution_id')
        _init_x_from_y('institution_id', 'source_id')
        # TODO: multi-column lookups
        # set CATALOG_DIR to be further down the hierarchy if possible, to
        # avoid having to crawl entire DRS strcture; CASE_ROOT_DIR remains the
        # root of the DRS hierarchy
        new_root = self.CASE_ROOT_DIR
        # Descend the DRS path one component per known attribute; stop at the
        # first attribute that is still unknown.
        for drs_attr in ("activity_id", "institution_id", "source_id",
                         "experiment_id"):
            drs_val = getattr(self, drs_attr, "")
            if not drs_val:
                break
            new_root = os.path.join(new_root, drs_val)
        if not os.path.isdir(new_root):
            log.error("Data directory '%s' not found; starting crawl at '%s'.",
                      new_root, self.CASE_ROOT_DIR)
            self.CATALOG_DIR = self.CASE_ROOT_DIR
        else:
            self.CATALOG_DIR = new_root
Example #33
0
 def from_dict(cls, dic):
     """Instantiate cls from dic, silently dropping any keys that are not
     dataclass fields of cls."""
     known_names = {fld.name for fld in dataclasses.fields(cls)}
     kwargs = {key: value for key, value in dic.items() if key in known_names}
     return cls(**kwargs)
Example #34
0
 def _get_fields(self) -> FieldTuple:
     """Return the dataclass fields of the record type backing this object."""
     record_type = self._Record
     return dataclasses.fields(record_type)
Example #35
0
def shallow_asdict(x: Any) -> Dict[str, Any]:
    """Return a one-level dict of field name -> value for dataclass *x*.

    Unlike dataclasses.asdict, this does NOT recurse into nested
    dataclasses or containers; values are returned as-is.
    """
    assert dataclasses.is_dataclass(x)
    result: Dict[str, Any] = {}
    for fld in dataclasses.fields(x):
        result[fld.name] = getattr(x, fld.name)
    return result
# NOTE(review): this span appears to be an IDE inspection test fixture —
# the <warning descr="..."> tags are expected-highlighting markup, not
# Python syntax, so this section will not execute as-is.  Left byte-identical.
import dataclasses
from typing import Type, Union


# Plain (non-dataclass) class: dataclasses.* calls on it should be flagged.
class A:
    pass


dataclasses.fields(<warning descr="'dataclasses.fields' method should be called on dataclass instances or types">A</warning>)
dataclasses.fields(<warning descr="'dataclasses.fields' method should be called on dataclass instances or types">A()</warning>)

dataclasses.asdict(<warning descr="'dataclasses.asdict' method should be called on dataclass instances">A()</warning>)
dataclasses.astuple(<warning descr="'dataclasses.astuple' method should be called on dataclass instances">A()</warning>)
dataclasses.replace(<warning descr="'dataclasses.replace' method should be called on dataclass instances">A()</warning>)


# Real dataclass: fields() accepts class or instance; asdict/astuple/replace
# accept instances only, so the class-argument calls below are flagged.
@dataclasses.dataclass
class B:
    pass


dataclasses.fields(B)
dataclasses.fields(B())

dataclasses.asdict(B())
dataclasses.astuple(B())
dataclasses.replace(B())

dataclasses.asdict(<warning descr="'dataclasses.asdict' method should be called on dataclass instances">B</warning>)
dataclasses.astuple(<warning descr="'dataclasses.astuple' method should be called on dataclass instances">B</warning>)
dataclasses.replace(<warning descr="'dataclasses.replace' method should be called on dataclass instances">B</warning>)
Example #37
0
 def from_dict(cls, dict_):
     """Build a ColumnFormat from dict_, ignoring keys that do not
     correspond to dataclass fields of cls."""
     valid_names = frozenset(fld.name for fld in fields(cls))
     kwargs = {key: val for key, val in dict_.items() if key in valid_names}
     return ColumnFormat(**kwargs)
Example #38
0
 def get_tag_names() -> tuple[str, ...]:
     """Return the names of all dataclass fields on Taggable, in
     declaration order."""
     names = [fld.name for fld in fields(Taggable)]
     return tuple(names)
Example #39
0
def _fields_from_dataclass(obj: Any) -> Dict[str, Field]:
    """Map each dataclass field name of *obj* to its converted Field."""
    converted: Dict[str, Field] = {}
    for dc_field in dataclasses.fields(obj):
        converted[dc_field.name] = _field_from_dataclass_field(dc_field)
    return converted
# NOTE(review): inspection test fixture — the <warning> tags below are
# expected-highlighting markup for an IDE inspection, not valid Python.
# fields() accepts a class or instance, so a union of types is accepted;
# asdict/astuple/replace require instances, so they are flagged here.
def union2(p: Union[Type[A], Type[B]]):
    dataclasses.fields(p)

    dataclasses.asdict(<warning descr="'dataclasses.asdict' method should be called on dataclass instances">p</warning>)
    dataclasses.astuple(<warning descr="'dataclasses.astuple' method should be called on dataclass instances">p</warning>)
    dataclasses.replace(<warning descr="'dataclasses.replace' method should be called on dataclass instances">p</warning>)
Example #41
0
 def format(kls):
     """Concatenate the type annotations of kls's init-enabled dataclass
     fields.  Assumes the annotations are strings (e.g. struct format
     codes like 'I') — TODO confirm with callers."""
     codes = [fld.type for fld in fields(kls) if fld.init]
     return ''.join(codes)
Example #42
0
    def __call__(self, container):
        """Construct and return an instance of the target dataclass, with
        each field's value resolved from the DI container.

        Per-field resolution order: (1) ServiceContainer / Context
        sentinel types, (2) ``injected(...)`` field metadata, (3) a
        container lookup keyed by the field's type, (4) the field's
        dataclass default.  NOTE(review): fields with init=False are
        expected to be assigned by a __post_init__ on the target.
        """
        target = self.target
        context = container.context

        # Make the args dict that we will construct dataclass with
        args = {}

        # Iterate through the dataclass fields
        # Because fields() gives a string for the type, instead of the
        # actual type, let's get a mapping of field name -> field type
        fields_mapping = {f.name: f for f in fields(target)}

        # Iterate through the dataclass fields
        for field_name, field_type in get_type_hints(target).items():

            # Doing this style of bailing out quickly for performance
            # reasons. Don't want to keep doing "if", though it
            # means some repetitions.
            if field_type is ServiceContainer:
                args[field_name] = container
                continue

            if field_type == Context:
                args[field_name] = context
                continue

            # See if this field is using the injectable field, e.g.
            # url: str = injected(Url, attr='value')
            full_field: Field = fields_mapping[field_name]
            if full_field.metadata.get('injected', False):
                injected_info = full_field.metadata['injected']
                injected_attr = injected_info.get('attr')
                injected_type = injected_info['type_']
                injected_name = injected_info['name']

                # Another special case: if asked to inject Context or
                # ServiceContainer, consider it like a sentinel and return it.
                if injected_type is Context:
                    injected_target = context
                elif injected_type is ServiceContainer:
                    injected_target = container
                else:
                    # Ask the registry for one of these
                    injected_target = container.get(injected_type,
                                                    name=injected_name)

                # If attr is used, get specified attribute off that instance
                if injected_attr:
                    field_value = getattr(injected_target, injected_attr)
                else:
                    field_value = injected_target
                args[field_name] = field_value
                continue

            # Now the general case, something like url: Url
            try:
                field_value = container.get(field_type)
                args[field_name] = field_value
            except TypeError:
                # Seems that wired, when looking up str, gives:
                #   TypeError: can't set attributes of bui...sion type 'str'
                # We will use that to our advantage to look for a dataclass
                # field default value.
                field_default = getattr(full_field, 'default', None)
                if field_default is not MISSING:
                    args[field_name] = field_default
                    continue
                elif full_field.init is False:
                    # Expect a __post_init__ that assigns this value
                    if not hasattr(target, '__post_init__'):
                        m = 'has init=False but no __post_init__'
                        msg = f'Field "{field_name}" {m}'
                        raise LookupError(msg)
                    continue
                else:
                    msg = f'No default value on field {field_name}'
                    raise LookupError(msg)
            except LookupError:
                # Give up and work around ``wired`` unhelpful exception
                # by adding some context information.

                # Note that a dataclass with ``__post_init__`` might still
                # do some construction. Only do this next part if there's
                # no __post_init__
                if not hasattr(target, '__post_init__'):
                    m = 'Injector failed for'
                    msg = f'{m} {field_name} on {target.__name__}'
                    raise LookupError(msg)

        # Now construct an instance of the target dataclass
        return target(**args)
Example #43
0
 def __getattr__(self, name):
     """Proxy attribute access: names matching a Weights dataclass field
     are fetched from self.value; everything else falls through to the
     superclass's __getattr__."""
     is_weight_attr = any(fld.name == name for fld in fields(Weights))
     if is_weight_attr:
         return object.__getattribute__(self.value, name)
     return super().__getattr__(name)