def _transform(self) -> List[Field]:
    """Declare a ``get_<name>(schema)`` accessor for each collected field.

    Fields flagged with ``has_explicit_accessor`` are skipped.  Each
    generated accessor takes one positional ``schema`` argument typed as
    ``edb.schema.schema.Schema`` and returns the field's ``type``.

    Returns the list produced by ``self._collect_fields()``.
    """
    collected = self._collect_fields()
    schema_type = self._lookup_type('edb.schema.schema.Schema')

    without_accessor = (fld for fld in collected if not fld.has_explicit_accessor)
    for fld in without_accessor:
        schema_param = nodes.Argument(
            variable=nodes.Var(name='schema', type=schema_type),
            type_annotation=schema_type,
            initializer=None,
            kind=nodes.ARG_POS,
        )
        mypy_helpers.add_method(
            self._ctx,
            name=f'get_{fld.name}',
            args=[schema_param],
            return_type=fld.type,
        )
    return collected
def add_model_init_hook(ctx: ClassDefContext) -> None:
    """Add a dummy __init__() to a model and record it is generated.

    Instantiation will be checked more precisely when we inferred types
    (using get_function_hook and model_hook).

    If the class already defines ``__init__`` this hook returns immediately,
    so neither ``__init__`` nor the ``__table__`` attribute is added.
    Otherwise it adds ``__init__(self, **kwargs: Any) -> None``, sets
    ``metadata['sqlalchemy']['generated_init']``, and declares ``__table__``
    as ``sqlalchemy.sql.schema.Table`` (or ``Any`` if that lookup fails).
    """
    info = ctx.cls.info
    if '__init__' in info.names:
        # Don't override existing definition.
        return

    # Named ``any_type`` (not ``any``) so the builtin ``any`` is not shadowed.
    any_type = AnyType(TypeOfAny.special_form)
    kwargs_var = Var('kwargs', any_type)
    kw_arg = Argument(variable=kwargs_var, type_annotation=any_type,
                      initializer=None, kind=ARG_STAR2)
    add_method(ctx, '__init__', [kw_arg], NoneTyp())
    # NOTE(review): presumably read by the function/model hooks mentioned in
    # the docstring to recognise the synthetic __init__ — confirm.
    info.metadata.setdefault('sqlalchemy', {})['generated_init'] = True

    # Also add a selection of auto-generated attributes.
    sym = ctx.api.lookup_fully_qualified_or_none('sqlalchemy.sql.schema.Table')
    if sym:
        assert isinstance(sym.node, TypeInfo)
        table_type = Instance(sym.node, [])  # type: Type
    else:
        # Table could not be resolved; fall back to Any.
        table_type = AnyType(TypeOfAny.special_form)
    add_var_to_class('__table__', table_type, info)
def _synthesize_init(self, fields: List[Field]) -> None:
    """Add a plugin-generated ``__init__`` with one argument per field.

    Raises ``DeferException`` while the class's filled-in self type still
    contains placeholders.  Nothing is added when ``fields`` is empty or
    when the class already has an ``__init__`` that was not generated by a
    plugin.
    """
    ctx = self._ctx
    info = ctx.cls.info

    # A self type with placeholders (probably from type-var bounds) must not
    # be written into the symbol table yet, so ask to be run again later.
    # (Arguably plugins.common ought to help with this, but oh well.)
    filled_self = mypy_helpers.fill_typevars(info)
    if semanal.has_placeholder(filled_self):
        raise DeferException

    current_init = info.names.get('__init__')
    has_user_init = current_init is not None and not current_init.plugin_generated
    if has_user_init or not fields:
        return

    mypy_helpers.add_method(
        ctx,
        '__init__',
        self_type=filled_self,
        args=[fld.to_argument() for fld in fields],
        return_type=types.NoneType(),
    )
def adjust_class_def(class_def_context: ClassDefContext) -> None:
    """Declare the ordering methods that ``@total_ordering`` fills in at runtime.

    For each of ``__lt__``, ``__le__``, ``__gt__`` and ``__ge__`` that the class
    does not already have, add ``def <name>(self, other: object) -> bool``.
    """
    api = class_def_context.api
    object_type = api.named_type("__builtins__.object")
    bool_type = api.named_type("__builtins__.bool")

    # The same single-argument list is reused for every generated method.
    other_only = [
        Argument(
            variable=Var(name="other", type=object_type),
            type_annotation=object_type,
            initializer=None,
            kind=ARG_POS,
        )
    ]

    info: TypeInfo = class_def_context.cls.info
    still_missing = (
        name
        for name in ("__lt__", "__le__", "__gt__", "__ge__")
        if info.get(name) is None
    )
    for name in still_missing:
        add_method(
            ctx=class_def_context,
            name=name,
            args=other_only,
            return_type=bool_type,
        )
def _munge_dataclassy(
    ctx: ClassDefContext,
    classy: ClassyInfo,
) -> None:
    """Apply the dataclassy transformation described by ``classy`` to ``ctx.cls``.

    Stores ``classy.serialize()`` under ``_meta_key`` in the class metadata,
    adds an ``__init__`` built from each field's ``metharg`` when
    ``classy.args.init`` is set, and registers every field's ``Var`` as a
    property member of the class.
    """
    info = ctx.cls.info
    field_infos = list(classy.fields.values())

    # Kept in metadata so later passes can tell which classes are dataclassy ones.
    info.metadata[_meta_key] = classy.serialize()

    if classy.args.init:
        add_method(
            ctx,
            '__init__',
            args=[fi.metharg for fi in field_infos],
            return_type=NoneType(),
        )

    for fi in field_infos:
        member = fi.var
        member.info = info
        member.is_property = True
        member._fullname = f'{info.fullname}.{member.name}'
        info.names[fi.name] = SymbolTableNode(MDEF, member)
def copy_method_to_another_class(ctx: ClassDefContext, self_type: Instance,
                                 new_method_name: str, method_node: FuncDef) -> None:
    """Add ``new_method_name`` to the class in ``ctx`` with ``method_node``'s signature.

    The argument list and return type come from
    ``_prepare_new_method_arguments``; ``self_type`` is used for ``self``.
    """
    prepared_args, prepared_return = _prepare_new_method_arguments(method_node)
    add_method(
        ctx,
        new_method_name,
        args=prepared_args,
        return_type=prepared_return,
        self_type=self_type,
    )
def copy_method_to_another_class(
    ctx: ClassDefContext, self_type: Instance, new_method_name: str, method_node: FuncDef
) -> None:
    """Add ``new_method_name`` to the class in ``ctx``, copying ``method_node``'s signature.

    The first argument of ``method_node`` (``self``) is dropped and replaced by
    ``self_type``.

    * Unannotated source method: defer while more passes remain; on the final
      iteration build the signature with ``build_unannotated_method_args``.
    * Non-callable signature: defer if possible, then give up.
    * Otherwise every argument type and the return type are re-analyzed with
      ``anal_type(..., allow_placeholder=True)``; if one comes back as an
      unresolved placeholder the method is not added on this pass.
    """
    # ``anal_type`` returns a *type*, so an unresolved reference shows up as a
    # ``PlaceholderType``.  The previous ``isinstance(..., PlaceholderNode)``
    # compared against a symbol-node class and could never be true.
    from mypy.types import PlaceholderType

    semanal_api = get_semanal_api(ctx)
    if method_node.type is None:
        if not semanal_api.final_iteration:
            semanal_api.defer()
            return

        arguments, return_type = build_unannotated_method_args(method_node)
        add_method(ctx, new_method_name, args=arguments, return_type=return_type, self_type=self_type)
        return

    method_type = method_node.type
    if not isinstance(method_type, CallableType):
        if not semanal_api.final_iteration:
            semanal_api.defer()
        return

    bound_return_type = semanal_api.anal_type(method_type.ret_type, allow_placeholder=True)
    assert bound_return_type is not None
    if isinstance(bound_return_type, PlaceholderType):
        # NOTE(review): presumably the analyzer has already requested a
        # deferral when it produced the placeholder — confirm.
        return

    try:
        original_arguments = method_node.arguments[1:]
    except AttributeError:
        original_arguments = []

    arguments = []
    # ``arg_types`` and ``arguments`` are both indexed from 1 to skip ``self``.
    for arg_type, original_argument in zip(method_type.arg_types[1:], original_arguments):
        bound_arg_type = semanal_api.anal_type(arg_type, allow_placeholder=True)
        if bound_arg_type is None and not semanal_api.final_iteration:
            semanal_api.defer()
            return

        assert bound_arg_type is not None
        if isinstance(bound_arg_type, PlaceholderType):
            return

        var = Var(name=original_argument.variable.name, type=arg_type)
        var.line = original_argument.variable.line
        var.column = original_argument.variable.column
        argument = Argument(
            variable=var,
            type_annotation=bound_arg_type,
            initializer=original_argument.initializer,
            kind=original_argument.kind,
        )
        argument.set_line(original_argument)
        arguments.append(argument)

    add_method(ctx, new_method_name, args=arguments, return_type=bound_return_type, self_type=self_type)
def analyze_stubs(ctx: ClassDefContext) -> None:
    """Declare an ``is_valid(self) -> bool`` method on the class in ``ctx``."""
    # ``builtin_type`` already yields the ``builtins.bool`` Instance; wrapping
    # its ``.type`` in a fresh ``Instance(..., [])`` was redundant.
    bool_type = ctx.api.builtin_type("builtins.bool")
    add_method(
        ctx,
        "is_valid",
        args=[],
        return_type=bool_type,
    )
def add_init_to_cls(ctx: ClassDefContext) -> None:
    """Give the class a permissive ``__init__(self, **kwargs)`` if it lacks one.

    ``set_declarative`` is then called on the class info in every case.
    """
    info = ctx.cls.info
    if "__init__" not in info.names:
        loose = AnyType(TypeOfAny.special_form)
        kwargs_param = Argument(
            variable=Var("kwargs", loose),
            type_annotation=loose,
            initializer=None,
            kind=ARG_STAR2,
        )
        add_method(ctx, "__init__", [kwargs_param], NoneTyp())
    set_declarative(info)
def add_method(self, method_name: str, args: List[Argument], ret_type: Type,
               self_type: Optional[Type] = None,
               tvd: Optional[TypeVarDef] = None) -> None:
    """Add ``def <method_name>(self, <args>) -> <ret_type>: ...`` to the class.

    self_type: type for the ``self`` argument; when None, ``self.self_type``
        is used instead.
    tvd: type variable definition to attach when the method is generic.
    """
    if self_type is None:
        self_type = self.self_type
    add_method(self.ctx, method_name, args, ret_type, self_type, tvd)
def add_model_replace(ctx) -> None:
    """Add ``replace(self, **change: Any) -> <model>`` to the model class."""
    loose = types.AnyType(types.TypeOfAny.special_form)
    change_kwargs = nodes.Argument(
        variable=nodes.Var('change', loose),
        type_annotation=loose,
        initializer=None,
        kind=nodes.ARG_STAR2,
    )
    model_instance = types.Instance(ctx.cls.info, [])
    common.add_method(ctx, 'replace', [change_kwargs], model_instance)
def add_model_astuple(ctx) -> None:
    """Add ``astuple(self, *, recurse: bool = True) -> tuple`` to the model class."""
    api = ctx.api
    bool_type = api.builtin_type('builtins.bool')
    tuple_type = api.builtin_type('builtins.tuple')
    recurse_param = nodes.Argument(
        variable=nodes.Var('recurse', bool_type),
        type_annotation=bool_type,
        # Default shown as the name ``True``; ARG_NAMED_OPT = optional keyword-only.
        initializer=nodes.NameExpr('True'),
        kind=nodes.ARG_NAMED_OPT,
    )
    common.add_method(ctx, 'astuple', [recurse_param], tuple_type)
def add_method(self, method_name: str, args: List[Argument], ret_type: Type,
               self_type: Optional[Type] = None,
               tvd: Optional[TypeVarDef] = None) -> None:
    """Add ``def <method_name>(self, <args>) -> <ret_type>: ...`` to the class.

    self_type: explicit type for ``self``; None means "use ``self.self_type``".
    tvd: type variable definition for a generic method, if any.
    """
    effective_self = self.self_type if self_type is None else self_type
    add_method(self.ctx, method_name, args, ret_type, effective_self, tvd)
def transform(self):
    """Generate ``get_<field>(schema)`` accessors for this class's fields.

    Progress is kept in ``ctx.cls.info.metadata[METADATA_KEY]``.  The
    ``'processed'`` flag makes later runs no-ops, and ``'fields'`` receives
    the serialized fields.  If collecting fields or looking up ``Schema``
    raises ``DeferException``, analysis of the class is deferred.

    A field is skipped when the class has no symbol for it, when that
    symbol's type is not known, or when ``get_<field>`` already exists.
    Fields with ``is_optional`` get an accessor returning ``<type> | None``.
    """
    ctx = self._ctx
    cls_info = ctx.cls.info

    metadata = cls_info.metadata.get(METADATA_KEY)
    if not metadata:
        cls_info.metadata[METADATA_KEY] = metadata = {}

    metadata['processing'] = True

    if metadata.get('processed'):
        return

    try:
        fields = self._collect_fields()
        schema_t = self._lookup_type('edb.schema.schema.Schema')
    except DeferException:
        ctx.api.defer()
        return None

    for f in fields:
        # ``cls_info.get`` returns None when no symbol is found; the previous
        # ``cls_info.get(f.name).type`` crashed with AttributeError then.
        sym = cls_info.get(f.name)
        ftype = sym.type if sym is not None else None
        if ftype is None or cls_info.get(f'get_{f.name}') is not None:
            # The class is already doing something funny with the
            # field or the accessor, so ignore it.
            continue

        if f.is_optional:
            ftype = types.UnionType.make_union(
                [ftype, types.NoneType()],
                line=ftype.line,
                column=ftype.column,
            )

        mypy_helpers.add_method(
            ctx,
            name=f'get_{f.name}',
            args=[
                nodes.Argument(
                    variable=nodes.Var(
                        name='schema',
                        type=schema_t,
                    ),
                    type_annotation=schema_t,
                    initializer=None,
                    kind=nodes.ARG_POS,
                ),
            ],
            return_type=ftype,
        )

    metadata['fields'] = {f.name: f.serialize() for f in fields}
    metadata['processed'] = True
def add_struc_and_unstruc_to_classdefcontext(cls_def_ctx: ClassDefContext):
    """Tell mypy that a Cat class has struc/try-struc static methods and unstruc.

    Adds, only when not already defined on the class:

    * ``STRUCTURE_NAME(d: Mapping[str, Any]) -> <cls>`` (static),
    * ``TRY_STRUCTURE_NAME(d: Optional[Mapping[str, Any]]) -> <cls>`` (static),
    * ``UNSTRUCTURE_NAME(self) -> dict``.

    Defers analysis if ``typing.Mapping`` cannot be resolved yet.
    """
    api = cls_def_ctx.api
    dict_type = api.named_type("__builtins__.dict")
    str_type = api.named_type("__builtins__.str")
    implicit_any = AnyType(TypeOfAny.special_form)

    mapping = api.lookup_fully_qualified_or_none("typing.Mapping")
    if not mapping or not mapping.node:
        api.defer()
        return

    str_any_mapping = Instance(mapping.node, [str_type, implicit_any])
    optional_str_any_mapping = make_optional(str_any_mapping)

    # NOTE(review): ``fullname`` is not defined in this function; presumably a
    # closure or module-level name from the enclosing hook factory — confirm.
    if fullname == CAT_PATH:
        # A Cat is also an attr.s class, so apply the attrs transformation too.
        attr_class_maker_callback(cls_def_ctx, True)

    info = cls_def_ctx.cls.info

    def single_d_param(param_type):
        # One positional parameter named ``d`` of the given type.
        return [Argument(Var("d", param_type), param_type, None, ARG_POS)]

    if STRUCTURE_NAME not in info.names:
        add_static_method(
            cls_def_ctx,
            STRUCTURE_NAME,
            single_d_param(str_any_mapping),
            fill_typevars(info),
        )
    if TRY_STRUCTURE_NAME not in info.names:
        add_static_method(
            cls_def_ctx,
            TRY_STRUCTURE_NAME,
            single_d_param(optional_str_any_mapping),
            fill_typevars(info),
        )
    if UNSTRUCTURE_NAME not in info.names:
        add_method(cls_def_ctx, UNSTRUCTURE_NAME, [], dict_type)
def add_model_init(ctx, var_types, var_fields) -> None:
    """Add ``__init__(self, <field>: <type>, ...) -> None`` to the model class.

    ``var_types`` and ``var_fields`` are paired positionally; each pair becomes
    one required positional argument (extra items in the longer one are ignored).
    """
    init_params = [
        nodes.Argument(
            variable=nodes.Var(field_name, field_type),
            type_annotation=field_type,
            initializer=None,
            kind=nodes.ARG_POS,
        )
        for field_type, field_name in zip(var_types, var_fields)
    ]
    common.add_method(ctx, '__init__', init_params, types.NoneTyp())
def add_get_set_attr_fallback_to_any(ctx: ClassDefContext) -> None:
    """Declare ``__getattr__``/``__setattr__`` typed with ``Any``.

    This makes unknown attribute reads and writes on the class type-check as
    ``Any``:

    * ``__getattr__(self, name: Any) -> Any``
    * ``__setattr__(self, name: Any, value: Any) -> Any``
    """
    # Named ``any_type`` (not ``any``) so the builtin ``any`` is not shadowed.
    any_type = AnyType(TypeOfAny.special_form)

    name_arg = Argument(variable=Var('name', any_type), type_annotation=any_type,
                        initializer=None, kind=ARG_POS)
    add_method(ctx, '__getattr__', [name_arg], any_type)

    value_arg = Argument(variable=Var('value', any_type), type_annotation=any_type,
                         initializer=None, kind=ARG_POS)
    # NOTE(review): runtime __setattr__ returns None; Any is kept here to
    # preserve the existing (permissive) signature.
    add_method(ctx, '__setattr__', [name_arg, value_arg], any_type)
def strawberry_pydantic_class_callback(ctx: ClassDefContext) -> None:
    """Declare the members a strawberry pydantic-backed type gets at runtime.

    Reports an error if the decorator's ``model`` argument cannot be found.
    Otherwise it adds:

    * ``__init__(self, **kwargs: Any) -> None``
    * ``to_pydantic(self) -> <model>``
    * static ``from_pydantic(instance: <model>) -> <cls>``
    """
    # In future we want a proper pydantic plugin; for now __init__ falls back
    # to **kwargs. Some resources:
    # https://github.com/samuelcolvin/pydantic/blob/master/pydantic/mypy.py
    # >>> model_index = ctx.cls.decorators[0].arg_names.index("model")
    # >>> model_name = ctx.cls.decorators[0].args[model_index].name
    # >>> model_type = ctx.api.named_type("UserModel")
    # >>> model_type = ctx.api.lookup(model_name, Context())
    from mypy.nodes import ARG_POS  # required positional argument kind

    model_expression = _get_argument(call=ctx.reason, name="model")  # type: ignore
    if model_expression is None:
        ctx.api.fail("model argument in decorator failed to be parsed", ctx.reason)
        return

    # Add __init__
    init_args = [
        Argument(Var("kwargs"), AnyType(TypeOfAny.explicit), None, ARG_STAR2)
    ]
    add_method(ctx, "__init__", init_args, NoneType())

    model_type = _get_type_for_expr(model_expression, ctx.api)

    # Add to_pydantic
    add_method(
        ctx,
        "to_pydantic",
        args=[],
        return_type=model_type,
    )

    # Add from_pydantic. ``instance`` is required when calling from_pydantic,
    # so declare it ARG_POS; ARG_OPT let ``from_pydantic()`` type-check.
    model_argument = Argument(
        variable=Var(name="instance", type=model_type),
        type_annotation=model_type,
        initializer=None,
        kind=ARG_POS,
    )
    add_static_method_to_class(
        ctx.api,
        ctx.cls,
        name="from_pydantic",
        args=[model_argument],
        return_type=fill_typevars(ctx.cls.info),
    )
def add_dummy_init_method(ctx: ClassDefContext) -> None:
    """Add a catch-all ``__init__(self, *args: Any, **kwargs: Any) -> None``.

    Also sets ``metadata['django']['generated_init'] = True`` on the class.
    """
    # Named ``any_type`` (not ``any``) so the builtin ``any`` is not shadowed.
    any_type = AnyType(TypeOfAny.special_form)
    pos_arg = Argument(variable=Var('args', any_type), type_annotation=any_type,
                       initializer=None, kind=ARG_STAR)
    kw_arg = Argument(variable=Var('kwargs', any_type), type_annotation=any_type,
                      initializer=None, kind=ARG_STAR2)

    add_method(ctx, '__init__', [pos_arg, kw_arg], NoneTyp())

    # Record that this class's __init__ was generated by the plugin.
    ctx.cls.info.metadata.setdefault('django', {})['generated_init'] = True
def add_model_init_hook(ctx: ClassDefContext) -> None:
    """Add a dummy __init__() to a model and record it is generated.

    Instantiation will be checked more precisely when we inferred types
    (using get_function_hook and model_hook).

    Does nothing if the class already defines ``__init__``; otherwise adds
    ``__init__(self, **kwargs: Any) -> None`` and sets
    ``metadata['olo']['generated_init'] = True``.
    """
    info = ctx.cls.info
    if '__init__' in info.names:
        # Don't override existing definition.
        return

    # Named ``any_type`` (not ``any``) so the builtin ``any`` is not shadowed.
    any_type = AnyType(TypeOfAny.special_form)
    kwargs_var = Var('kwargs', any_type)
    kw_arg = Argument(variable=kwargs_var, type_annotation=any_type,
                      initializer=None, kind=ARG_STAR2)
    add_method(ctx, '__init__', [kw_arg], NoneTyp())
    info.metadata.setdefault('olo', {})['generated_init'] = True
def transform(self):
    """Generate ``get_<field>(schema)`` accessors and record the class's fields.

    State is kept in the class metadata under ``self._get_metadata_key()``:
    ``'processing'`` is set on every call, ``'processed'`` short-circuits
    later calls, and ``'fields'`` receives the serialized fields.  If field
    collection or the ``Schema`` lookup raises ``DeferException``, analysis
    is deferred.  Fields with ``has_explicit_accessor`` get no accessor.
    """
    ctx = self._ctx
    info = ctx.cls.info
    key = self._get_metadata_key()

    metadata = info.metadata.get(key)
    if not metadata:
        metadata = {}
        info.metadata[key] = metadata
    metadata['processing'] = True

    if metadata.get('processed'):
        return

    try:
        fields = self._collect_fields()
        schema_t = self._lookup_type('edb.schema.schema.Schema')
    except DeferException:
        ctx.api.defer()
        return None

    for fld in fields:
        if not fld.has_explicit_accessor:
            schema_param = nodes.Argument(
                variable=nodes.Var(name='schema', type=schema_t),
                type_annotation=schema_t,
                initializer=None,
                kind=nodes.ARG_POS,
            )
            mypy_helpers.add_method(
                ctx,
                name=f'get_{fld.name}',
                args=[schema_param],
                return_type=fld.type,
            )

    metadata['fields'] = {fld.name: fld.serialize() for fld in fields}
    metadata['processed'] = True
def add_model_set(ctx) -> None:
    """Add a ``_set(self, name: str, value: Any) -> None`` method to the model class."""
    str_type = ctx.api.builtin_type('builtins.str')
    any_type = types.AnyType(types.TypeOfAny.special_form)
    params = [
        nodes.Argument(
            variable=nodes.Var(param_name, param_type),
            type_annotation=param_type,
            initializer=None,
            kind=nodes.ARG_POS,
        )
        for param_name, param_type in (('name', str_type), ('value', any_type))
    ]
    common.add_method(ctx, '_set', params, types.NoneTyp())
def adjust_class_def(class_def_context: ClassDefContext) -> None:
    """Declare missing ``__lt__``/``__le__``/``__gt__``/``__ge__`` methods.

    Each added method has the signature ``(self, other: object) -> bool``;
    methods the class already has are left untouched.
    """
    api = class_def_context.api
    object_type = api.named_type("__builtins__.object")
    bool_type = api.named_type("__builtins__.bool")
    other_param = Argument(
        variable=Var(name="other", type=object_type),
        type_annotation=object_type,
        initializer=None,
        kind=ARG_POS,
    )

    info: TypeInfo = class_def_context.cls.info
    for dunder in ("__lt__", "__le__", "__gt__", "__ge__"):
        if info.get(dunder) is not None:
            continue  # already present
        add_method(
            ctx=class_def_context,
            name=dunder,
            args=[other_param],
            return_type=bool_type,
        )
def copy_method_to_another_class(ctx: ClassDefContext, self_type: Instance,
                                 new_method_name: str, method_node: FuncDef) -> None:
    """Add ``new_method_name`` to the class in ``ctx`` using ``method_node``'s signature.

    Each argument annotation and the return type are passed through the
    semantic analyzer's ``anal_type`` (placeholders allowed) before the
    method is added with ``self_type`` as the type of ``self``.
    """
    new_args, new_return = _prepare_new_method_arguments(method_node)
    semanal_api = get_semanal_api(ctx)

    for new_arg in new_args:
        annotation = new_arg.type_annotation
        if annotation is None:
            continue
        new_arg.type_annotation = semanal_api.anal_type(annotation, allow_placeholder=True)

    if new_return is not None:
        analyzed_return = semanal_api.anal_type(new_return, allow_placeholder=True)
        assert analyzed_return is not None
        new_return = analyzed_return

    add_method(ctx, new_method_name, args=new_args, return_type=new_return,
               self_type=self_type)
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
    """Declare Django's per-field convenience methods on the model class.

    * ``get_FOO_display() -> str`` for every field with ``choices``;
    * ``get_next_by_FOO(**kwargs)`` / ``get_previous_by_FOO(**kwargs)``
      returning the model, for every non-nullable ``DateField`` /
      ``DateTimeField``.

    Django names these methods after ``Field.name``.  The previous code used
    ``Field.attname``, which differs for foreign keys (``author_id`` vs
    ``author``).
    """
    model_fields = list(self.django_context.get_model_fields(model_cls))

    def any_kwargs() -> Argument:
        # A fresh ``**kwargs: Any`` argument for each generated method.
        any_type = AnyType(TypeOfAny.explicit)
        return Argument(Var('kwargs', any_type), any_type,
                        initializer=None, kind=ARG_STAR2)

    # get_FOO_display for choices
    for field in model_fields:
        if field.choices:
            info = self.lookup_typeinfo_or_incomplete_defn_error('builtins.str')
            return_type = Instance(info, [])
            common.add_method(self.ctx,
                              name='get_{}_display'.format(field.name),
                              args=[],
                              return_type=return_type)

    # get_next_by, get_previous_by for Date, DateTime
    for field in model_fields:
        if isinstance(field, (DateField, DateTimeField)) and not field.null:
            return_type = Instance(self.model_classdef.info, [])
            for prefix in ('get_next_by_', 'get_previous_by_'):
                common.add_method(self.ctx,
                                  name=prefix + field.name,
                                  args=[any_kwargs()],
                                  return_type=return_type)
def analyze_template(ctx: ClassDefContext) -> None:
    """Declare ``render(self, request: HttpRequest) -> TemplateResponse``.

    Defers while ``TemplateResponse`` or ``HttpRequest`` cannot be resolved
    to a class.  If either is still unresolved on the final iteration, no
    method is added.  Previously that case fell through and dereferenced
    ``None``.
    """
    from mypy.nodes import TypeInfo

    api = ctx.api

    template_response = api.lookup_fully_qualified_or_none(
        "django.template.response.TemplateResponse")
    if template_response is None or not isinstance(template_response.node, TypeInfo):
        if not api.final_iteration:
            api.defer()
        return

    http_request = api.lookup_fully_qualified_or_none("django.http.HttpRequest")
    if http_request is None or not isinstance(http_request.node, TypeInfo):
        if not api.final_iteration:
            api.defer()
        return

    request_type = Instance(http_request.node, [])
    request_arg = Argument(
        variable=Var("request", request_type),
        type_annotation=request_type,
        initializer=None,
        kind=ARG_POS,
    )
    add_method(
        ctx,
        "render",
        args=[request_arg],
        return_type=Instance(template_response.node, []),
    )
def hook(ctx):
    """Declare on ``ctx.cls`` the methods listed in its runtime ``__constructed_methods``.

    The runtime class is imported with ``locate``.  The hook does nothing if
    the class cannot be imported or its ``Meta`` has ``abstract`` set.  For
    each recorded method, the return annotation and parameter annotations
    (the first parameter, ``self``, is skipped) are resolved by their
    ``__name__`` via ``ctx.api.named_type``.
    """
    # ``kind=0`` relied on the integer value of ARG_POS; newer mypy versions
    # use an ``ArgKind`` enum, so use the named constant.
    from mypy.nodes import ARG_POS

    # Limitation: we can't have closures around our classes
    constructed_cls = locate(ctx.cls.fullname)
    if constructed_cls is None:
        # The class could not be imported at type-check time.
        return
    if getattr(getattr(constructed_cls, 'Meta', None), 'abstract', False):
        return

    for method_name, method in constructed_cls.__constructed_methods.items():
        # NOTE(review): named_type() is given bare ``__name__`` values, which
        # presumably only resolves builtins — confirm for project types.
        return_type = ctx.api.named_type(
            method.__annotations__['return'].__name__)
        self_type = ctx.api.named_type(ctx.cls.name)

        spec = inspect.getfullargspec(method)
        arguments = []
        for param_name in spec.args[1:]:  # skip ``self``
            param_type = ctx.api.named_type(spec.annotations[param_name].__name__)
            arguments.append(Argument(variable=Var(param_name, param_type),
                                      type_annotation=param_type,
                                      initializer=None,
                                      kind=ARG_POS))

        add_method(
            ctx,
            method_name,
            args=arguments,
            return_type=return_type,
            self_type=self_type,
        )
def layout_class_callback(ctx: ClassDefContext) -> None:
    """Type a layout class's members as ``pathlib.Path`` and add its ``__init__``.

    Every statement in the class body must be a single-target name
    assignment; each target is re-typed as ``pathlib.Path``.  Then
    ``__init__(self, path: Path, config: tts.config.Config) -> None`` is
    added.  If ``tts.config.Config`` cannot be resolved yet, analysis is
    deferred.  On the final iteration ``__init__`` is skipped rather than
    declaring an untyped ``config`` argument.
    """
    path_type = ctx.api.builtin_type("pathlib.Path")
    config_type = ctx.api.named_type_or_none(
        "tts.config.Config")  # type: ignore

    # Change the types of class members when type checking the body to `pathlib.Path`.
    # `_unpacked_layout` expects the values to be paths.
    for stmt in ctx.cls.defs.body:
        assert isinstance(stmt, AssignmentStmt)
        assert len(stmt.lvalues) == 1
        lvalue = stmt.lvalues[0]
        assert isinstance(lvalue, NameExpr)
        lvalue.node = _make_name_lvalue_var(lvalue, path_type, ctx)

    if config_type is None:
        # Config not resolvable yet: retry later instead of passing a None
        # annotation for the ``config`` argument.
        if not ctx.api.final_iteration:
            ctx.api.defer()
        return

    add_method(
        ctx,
        "__init__",
        [
            Argument(Var("path"), path_type, None, ARG_POS),
            Argument(Var("config"), config_type, None, ARG_POS),
        ],
        NoneType(),
    )
def transform(self) -> None:
    """Apply all the necessary transformations to the underlying
    dataclass so as to ensure it is fully type checked according
    to the rules in PEP 557.

    Returns early (leaving a deferral in place) while attributes or their
    types are not ready.  Otherwise, depending on the decorator arguments
    and existing members, it may add:

    * ``replace(self: T, **changes) -> T`` (when no ``replace`` exists);
    * ``__init__`` built from the attributes that are in ``__init__``;
    * the ``SELF_TVAR_NAME`` type-variable symbol and the ordering methods.

    It then freezes attributes (``frozen=True``), resets init-only vars, and
    stores the serialized attributes in ``info.metadata['dataclass']``.
    """
    ctx = self._ctx
    info = self._ctx.cls.info
    attributes = self.collect_attributes()
    if attributes is None:
        # Some definitions are not ready, defer() should be already called.
        return
    for attr in attributes:
        if attr.type is None:
            ctx.api.defer()
            return
    decorator_arguments = {
        'init': _get_decorator_bool_argument(self._ctx, 'init', True),
        'eq': _get_decorator_bool_argument(self._ctx, 'eq', True),
        'order': _get_decorator_bool_argument(self._ctx, 'order', False),
        'frozen': _get_decorator_bool_argument(self._ctx, 'frozen', False),
    }

    # Add a generic ``replace`` whose self type and return type are the same
    # type variable, so calling it preserves the receiver's (sub)type.
    if info.get('replace') is None:
        obj_type = ctx.api.named_type('__builtins__.object')
        # NOTE(review): the TypeVarExpr below is bounded by ``object`` while
        # the TypeVarDef used in the signature is bounded by the class itself
        # (``fill_typevars(info)``) — confirm this mismatch is intended.
        self_tvar_expr = TypeVarExpr(SELF_UVAR_NAME, info.fullname + '.' + SELF_UVAR_NAME,
                                     [], obj_type)
        info.names[SELF_UVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)
        replace_tvar_def = TypeVarDef(SELF_UVAR_NAME, info.fullname + '.' + SELF_UVAR_NAME,
                                      -1, [], fill_typevars(info))
        replace_other_type = TypeVarType(replace_tvar_def)

        add_method(ctx, 'replace',
                   args=[
                       Argument(
                           Var('changes', AnyType(TypeOfAny.explicit)),
                           AnyType(TypeOfAny.explicit),
                           None,
                           ARG_STAR2)
                   ],
                   return_type=replace_other_type,
                   self_type=replace_other_type,
                   tvar_def=replace_tvar_def)

    # If there are no attributes, it may be that the semantic analyzer has not
    # processed them yet. In order to work around this, we can simply skip generating
    # __init__ if there are no attributes, because if the user truly did not define any,
    # then the object default __init__ with an empty signature will be present anyway.
    if (decorator_arguments['init'] and
            ('__init__' not in info.names or info.names['__init__'].plugin_generated) and
            attributes):
        add_method(
            ctx,
            '__init__',
            args=[attr.to_argument() for attr in attributes if attr.is_in_init],
            return_type=NoneType(),
        )

    if (decorator_arguments['eq'] and info.get('__eq__') is None or
            decorator_arguments['order']):
        # Type variable for self types in generated methods.
        obj_type = ctx.api.named_type('__builtins__.object')
        self_tvar_expr = TypeVarExpr(SELF_TVAR_NAME, info.fullname + '.' + SELF_TVAR_NAME,
                                     [], obj_type)
        info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)

    # Add <, >, <=, >=, but only if the class has an eq method.
    # NOTE(review): the methods are added whenever order=True; with eq=False
    # only the error below is reported.
    if decorator_arguments['order']:
        if not decorator_arguments['eq']:
            ctx.api.fail('eq must be True if order is True', ctx.cls)

        for method_name in ['__lt__', '__gt__', '__le__', '__ge__']:
            # Like for __eq__ and __ne__, we want "other" to match
            # the self type.
            obj_type = ctx.api.named_type('__builtins__.object')
            order_tvar_def = TypeVarDef(SELF_TVAR_NAME, info.fullname + '.' + SELF_TVAR_NAME,
                                        -1, [], obj_type)
            order_other_type = TypeVarType(order_tvar_def)
            order_return_type = ctx.api.named_type('__builtins__.bool')
            order_args = [
                Argument(Var('other', order_other_type), order_other_type, None, ARG_POS)
            ]

            # A user-written ordering method conflicts with order=True; one
            # this plugin generated on an earlier pass is simply replaced.
            existing_method = info.get(method_name)
            if existing_method is not None and not existing_method.plugin_generated:
                assert existing_method.node
                ctx.api.fail(
                    'You may not have a custom %s method when order=True' % method_name,
                    existing_method.node,
                )

            add_method(
                ctx,
                method_name,
                args=order_args,
                return_type=order_return_type,
                self_type=order_other_type,
                tvar_def=order_tvar_def,
            )

    if decorator_arguments['frozen']:
        self._freeze(attributes)

    self.reset_init_only_vars(info, attributes)

    info.metadata['dataclass'] = {
        'attributes': [attr.serialize() for attr in attributes],
        'frozen': decorator_arguments['frozen'],
    }
def transform(self) -> None:
    """Apply all the necessary transformations to the underlying
    dataclass so as to ensure it is fully type checked according
    to the rules in PEP 557.

    Under the new semantic analyzer this returns early (after ``defer()``)
    while an attribute's type is not ready.  Otherwise it may add
    ``__init__``, ``__eq__``/``__ne__`` and the ordering methods, freezes
    attributes when ``frozen=True``, removes init-only vars from the class,
    and stores the serialized attributes in ``info.metadata['dataclass']``.
    """
    ctx = self._ctx
    info = self._ctx.cls.info
    attributes = self.collect_attributes()
    if ctx.api.options.new_semantic_analyzer:
        # Check if attribute types are ready.  ``info.get`` (not ``info[...]``)
        # because inherited InitVars were removed from their base class and
        # can no longer be looked up; the old code raised KeyError there.
        for attr in attributes:
            node = info.get(attr.name)
            if node is not None and node.type is None:
                ctx.api.defer()
                return
    decorator_arguments = {
        'init': _get_decorator_bool_argument(self._ctx, 'init', True),
        'eq': _get_decorator_bool_argument(self._ctx, 'eq', True),
        'order': _get_decorator_bool_argument(self._ctx, 'order', False),
        'frozen': _get_decorator_bool_argument(self._ctx, 'frozen', False),
    }

    # A user-written __init__ wins over the generated one (the dataclass
    # ``init`` flag is ignored when the class defines __init__); one this
    # plugin generated on an earlier pass may be regenerated.
    existing_init = info.names.get('__init__')
    if (decorator_arguments['init'] and
            (existing_init is None or existing_init.plugin_generated)):
        add_method(
            ctx,
            '__init__',
            args=[attr.to_argument(info) for attr in attributes if attr.is_in_init],
            return_type=NoneTyp(),
        )

    if (decorator_arguments['eq'] and info.get('__eq__') is None or
            decorator_arguments['order']):
        # Type variable for self types in generated methods.
        obj_type = ctx.api.named_type('__builtins__.object')
        self_tvar_expr = TypeVarExpr(SELF_TVAR_NAME, info.fullname() + '.' + SELF_TVAR_NAME,
                                     [], obj_type)
        info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)

    # Add an eq method, but only if the class doesn't already have one.
    if decorator_arguments['eq'] and info.get('__eq__') is None:
        for method_name in ['__eq__', '__ne__']:
            # The TVar is used to enforce that "other" must have
            # the same type as self (covariant).  Note the
            # "self_type" parameter to add_method.
            obj_type = ctx.api.named_type('__builtins__.object')
            cmp_tvar_def = TypeVarDef(SELF_TVAR_NAME, info.fullname() + '.' + SELF_TVAR_NAME,
                                      -1, [], obj_type)
            cmp_other_type = TypeVarType(cmp_tvar_def)
            cmp_return_type = ctx.api.named_type('__builtins__.bool')

            add_method(
                ctx,
                method_name,
                args=[Argument(Var('other', cmp_other_type), cmp_other_type, None, ARG_POS)],
                return_type=cmp_return_type,
                self_type=cmp_other_type,
                tvar_def=cmp_tvar_def,
            )

    # Add <, >, <=, >=, but only if the class has an eq method.
    if decorator_arguments['order']:
        if not decorator_arguments['eq']:
            ctx.api.fail('eq must be True if order is True', ctx.cls)

        for method_name in ['__lt__', '__gt__', '__le__', '__ge__']:
            # Like for __eq__ and __ne__, we want "other" to match
            # the self type.
            obj_type = ctx.api.named_type('__builtins__.object')
            order_tvar_def = TypeVarDef(SELF_TVAR_NAME, info.fullname() + '.' + SELF_TVAR_NAME,
                                        -1, [], obj_type)
            order_other_type = TypeVarType(order_tvar_def)
            order_return_type = ctx.api.named_type('__builtins__.bool')
            order_args = [
                Argument(Var('other', order_other_type), order_other_type, None, ARG_POS)
            ]

            # Only a user-written method is an error; a method generated by this
            # plugin on an earlier (deferred) pass must not trigger it.
            existing_method = info.get(method_name)
            if existing_method is not None and not existing_method.plugin_generated:
                assert existing_method.node
                ctx.api.fail(
                    'You may not have a custom %s method when order=True' % method_name,
                    existing_method.node,
                )

            add_method(
                ctx,
                method_name,
                args=order_args,
                return_type=order_return_type,
                self_type=order_other_type,
                tvar_def=order_tvar_def,
            )

    if decorator_arguments['frozen']:
        self._freeze(attributes)

    # Remove init-only vars from the class.  Inherited InitVars live (or
    # lived) in a base class's symbol table, not this one, so they may be
    # absent here; deleting them unconditionally raised KeyError.
    for attr in attributes:
        if attr.is_init_var and attr.name in info.names:
            del info.names[attr.name]

    info.metadata['dataclass'] = {
        'attributes': OrderedDict((attr.name, attr.serialize()) for attr in attributes),
        'frozen': decorator_arguments['frozen'],
    }
def transform(self) -> None:
    """Apply all the necessary transformations to the underlying
    dataclass so as to ensure it is fully type checked according
    to the rules in PEP 557.

    Returns early while attributes or their types are not ready.
    Otherwise, depending on the decorator arguments (``init``, ``eq``,
    ``order``, ``frozen``, ``slots``, ``match_args``) and existing members,
    it:

    * adds ``__init__`` and the ordering methods;
    * freezes or propertizes attributes and adds ``__slots__``;
    * resets init-only vars and adds ``__match_args__`` and the dataclass
      fields magic attribute;
    * records the serialized attributes in ``info.metadata['dataclass']``.
    """
    ctx = self._ctx
    info = self._ctx.cls.info
    attributes = self.collect_attributes()
    if attributes is None:
        # Some definitions are not ready, defer() should be already called.
        return
    for attr in attributes:
        if attr.type is None:
            ctx.api.defer()
            return
    decorator_arguments = {
        'init': _get_decorator_bool_argument(self._ctx, 'init', True),
        'eq': _get_decorator_bool_argument(self._ctx, 'eq', True),
        'order': _get_decorator_bool_argument(self._ctx, 'order', False),
        'frozen': _get_decorator_bool_argument(self._ctx, 'frozen', False),
        'slots': _get_decorator_bool_argument(self._ctx, 'slots', False),
        'match_args': _get_decorator_bool_argument(self._ctx, 'match_args', True),
    }
    py_version = self._ctx.api.options.python_version

    # If there are no attributes, it may be that the semantic analyzer has not
    # processed them yet. In order to work around this, we can simply skip generating
    # __init__ if there are no attributes, because if the user truly did not define any,
    # then the object default __init__ with an empty signature will be present anyway.
    if (decorator_arguments['init'] and
            ('__init__' not in info.names or info.names['__init__'].plugin_generated) and
            attributes):

        # The KW_ONLY sentinel pseudo-field is not a real __init__ argument.
        args = [attr.to_argument() for attr in attributes
                if attr.is_in_init and not self._is_kw_only_type(attr.type)]

        if info.fallback_to_any:
            # Make positional args optional since we don't know their order.
            # This will at least allow us to typecheck them if they are called
            # as kwargs
            for arg in args:
                if arg.kind == ARG_POS:
                    arg.kind = ARG_OPT

            nameless_var = Var('')
            args = [Argument(nameless_var, AnyType(TypeOfAny.explicit), None, ARG_STAR),
                    *args,
                    Argument(nameless_var, AnyType(TypeOfAny.explicit), None, ARG_STAR2),
                    ]

        add_method(
            ctx,
            '__init__',
            args=args,
            return_type=NoneType(),
        )

    if (decorator_arguments['eq'] and info.get('__eq__') is None or
            decorator_arguments['order']):
        # Type variable for self types in generated methods.
        obj_type = ctx.api.named_type('builtins.object')
        self_tvar_expr = TypeVarExpr(SELF_TVAR_NAME, info.fullname + '.' + SELF_TVAR_NAME,
                                     [], obj_type)
        info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)

    # Add <, >, <=, >=, but only if the class has an eq method.
    if decorator_arguments['order']:
        if not decorator_arguments['eq']:
            ctx.api.fail('eq must be True if order is True', ctx.cls)

        for method_name in ['__lt__', '__gt__', '__le__', '__ge__']:
            # Like for __eq__ and __ne__, we want "other" to match
            # the self type.
            obj_type = ctx.api.named_type('builtins.object')
            order_tvar_def = TypeVarType(SELF_TVAR_NAME, info.fullname + '.' + SELF_TVAR_NAME,
                                         -1, [], obj_type)
            order_return_type = ctx.api.named_type('builtins.bool')
            order_args = [
                Argument(Var('other', order_tvar_def), order_tvar_def, None, ARG_POS)
            ]

            existing_method = info.get(method_name)
            if existing_method is not None and not existing_method.plugin_generated:
                assert existing_method.node
                ctx.api.fail(
                    'You may not have a custom %s method when order=True' % method_name,
                    existing_method.node,
                )

            add_method(
                ctx,
                method_name,
                args=order_args,
                return_type=order_return_type,
                self_type=order_tvar_def,
                tvar_def=order_tvar_def,
            )

    if decorator_arguments['frozen']:
        self._freeze(attributes)
    else:
        self._propertize_callables(attributes)

    if decorator_arguments['slots']:
        self.add_slots(info, attributes, correct_version=py_version >= (3, 10))

    self.reset_init_only_vars(info, attributes)

    if (decorator_arguments['match_args'] and
            ('__match_args__' not in info.names or
             info.names['__match_args__'].plugin_generated) and
            attributes):
        str_type = ctx.api.named_type("builtins.str")
        # __match_args__ lists only the parameters __init__ accepts
        # positionally: keyword-only fields (including the KW_ONLY sentinel
        # and fields after it) are excluded, as in the runtime dataclass.
        literals: List[Type] = [LiteralType(attr.name, str_type)
                                for attr in attributes
                                if attr.is_in_init and not attr.kw_only]
        match_args_type = TupleType(literals, ctx.api.named_type("builtins.tuple"))
        add_attribute_to_class(ctx.api, ctx.cls, "__match_args__", match_args_type, final=True)

    self._add_dataclass_fields_magic_attribute()

    info.metadata['dataclass'] = {
        'attributes': [attr.serialize() for attr in attributes],
        'frozen': decorator_arguments['frozen'],
    }
def transform(self) -> None:
    """Apply all the necessary transformations to the underlying
    dataclass so as to ensure it is fully type checked according
    to the rules in PEP 557.

    Returns early (with a deferral in place) while attributes or their types
    are not ready.  Otherwise it may add ``__init__``, the ``SELF_TVAR_NAME``
    type-variable symbol and the ordering methods.  It then freezes
    attributes when ``frozen=True``, resets init-only vars, and stores the
    serialized attributes in ``info.metadata["dataclass"]``.
    """
    ctx = self._ctx
    info = self._ctx.cls.info
    attributes = self.collect_attributes()
    if attributes is None:
        # Some definitions are not ready, defer() should be already called.
        return
    for attr in attributes:
        if attr.type is None:
            ctx.api.defer()
            return
    decorator_arguments = {
        "init": _get_decorator_bool_argument(self._ctx, "init", True),
        "eq": _get_decorator_bool_argument(self._ctx, "eq", True),
        "order": _get_decorator_bool_argument(self._ctx, "order", False),
        "frozen": _get_decorator_bool_argument(self._ctx, "frozen", False),
    }

    # If there are no attributes, it may be that the semantic analyzer has not
    # processed them yet. In order to work around this, we can simply skip
    # generating __init__ if there are no attributes, because if the user
    # truly did not define any, then the object default __init__ with an
    # empty signature will be present anyway.
    # A user-written __init__ is kept; a plugin-generated one is regenerated.
    if (decorator_arguments["init"] and
            ("__init__" not in info.names or info.names["__init__"].plugin_generated) and
            attributes):
        add_method(
            ctx,
            "__init__",
            args=[
                attr.to_argument() for attr in attributes if attr.is_in_init
            ],
            return_type=NoneType(),
        )

    if (decorator_arguments["eq"] and info.get("__eq__") is None or
            decorator_arguments["order"]):
        # Type variable for self types in generated methods.
        obj_type = ctx.api.named_type("__builtins__.object")
        self_tvar_expr = TypeVarExpr(SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME,
                                     [], obj_type)
        info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)

    # Add <, >, <=, >=, but only if the class has an eq method.
    # NOTE(review): the methods are added whenever order=True; with eq=False
    # only the error below is reported.
    if decorator_arguments["order"]:
        if not decorator_arguments["eq"]:
            ctx.api.fail("eq must be True if order is True", ctx.cls)

        for method_name in ["__lt__", "__gt__", "__le__", "__ge__"]:
            # Like for __eq__ and __ne__, we want "other" to match
            # the self type.
            obj_type = ctx.api.named_type("__builtins__.object")
            order_tvar_def = TypeVarDef(
                SELF_TVAR_NAME,
                info.fullname + "." + SELF_TVAR_NAME,
                -1,
                [],
                obj_type,
            )
            order_other_type = TypeVarType(order_tvar_def)
            order_return_type = ctx.api.named_type("__builtins__.bool")
            order_args = [
                Argument(Var("other", order_other_type), order_other_type, None, ARG_POS)
            ]

            # A user-written ordering method conflicts with order=True; one
            # this plugin generated on an earlier pass is simply replaced.
            existing_method = info.get(method_name)
            if existing_method is not None and not existing_method.plugin_generated:
                assert existing_method.node
                ctx.api.fail(
                    "You may not have a custom %s method when order=True" % method_name,
                    existing_method.node,
                )

            add_method(
                ctx,
                method_name,
                args=order_args,
                return_type=order_return_type,
                self_type=order_other_type,
                tvar_def=order_tvar_def,
            )

    if decorator_arguments["frozen"]:
        self._freeze(attributes)

    self.reset_init_only_vars(info, attributes)

    info.metadata["dataclass"] = {
        "attributes": [attr.serialize() for attr in attributes],
        "frozen": decorator_arguments["frozen"],
    }