def _add_implicit_field(self, node, cls_locals, key, typ):
    """Bind an implicit field named `key` of type `typ` into `cls_locals`.

    If `key` is already present in `cls_locals`, an invalid-annotation error
    is reported through `self.vm.errorlog` first. The entry is then written
    regardless, so any existing binding for `key` is replaced.

    Args:
      node: The current node. It is passed to `typ.to_variable` and to the new
        `abstract_utils.Local`.
      cls_locals: Mapping from field name to `abstract_utils.Local`. It is
        mutated in place.
      key: Name of the implicit field.
      typ: The field's type. `typ.to_variable(node)` supplies the field's
        default value.
    """
    if key not in cls_locals:
        clashes = False
    else:
        clashes = True
    if clashes:
        # The user declared a name that flax.linen.Module also defines
        # implicitly; report it, then still install the implicit version below.
        self.vm.errorlog.invalid_annotation(
            self.vm.frames,
            None,
            name=key,
            details=f"flax.linen.Module defines field '{key}' implicitly",
        )
    # Compute the default before building the Local, matching original order.
    implicit_default = typ.to_variable(node)
    # NOTE(review): the second Local argument is passed as None here; its
    # meaning is defined in abstract_utils, which is not visible from this chunk.
    cls_locals[key] = abstract_utils.Local(
        node, None, typ, implicit_default, self.vm)
def _get_class_locals(self, node, cls_name, cls_dict):
    """Return the (name, Local) items describing the locals of class `cls_name`.

    If `cls_name` appears in `self.vm.local_ops`, the result is
    `classgen.get_class_locals(...)` called with
    `classgen.Ordering.LAST_ASSIGN`. Otherwise the class was probably created
    through the functional API. In that case every entry of `cls_dict` is
    wrapped in an `abstract_utils.Local` with `None` in the type slot.

    Args:
      node: The current node. It is used only by the fallback path.
      cls_name: Name of the class being analyzed.
      cls_dict: Mapping of member name to value. It is used only by the
        fallback path.

    Returns:
      An items view over a mapping from member name to `abstract_utils.Local`.
    """
    if cls_name not in self.vm.local_ops:
        # Functional-API fallback: no recorded local ops for this class, so
        # build the locals directly from cls_dict in its iteration order.
        members = {}
        for member_name, member_value in cls_dict.items():
            members[member_name] = abstract_utils.Local(
                node, None, None, member_value, self.vm)
        return members.items()
    ordered_locals = classgen.get_class_locals(
        cls_name, False, classgen.Ordering.LAST_ASSIGN, self.vm)
    return ordered_locals.items()