def _jsonschema_type_mapping(self):
    """Return the JSON-schema fragment describing this initializer field.

    The schema is a ``oneOf`` with two alternatives: a plain string naming a
    registered initializer, or an object whose required ``type`` key names a
    registered initializer and which may carry arbitrary extra properties.

    NOTE(review): ``default`` and ``description`` are free variables here;
    they are expected to come from an enclosing scope not visible in this
    block - confirm against the surrounding factory.
    """
    registered = list(initializer_registry.keys())

    preconfigured_option = {
        "type": "string",
        "enum": registered,
        "default": default,
        "title": "initializer_preconfigured_option",
        "description": "Pick a preconfigured initializer.",
    }

    # The custom-dict alternative deliberately carries no default.
    custom_option = {
        "type": "object",
        "properties": {
            "type": {"type": "string", "enum": registered},
        },
        "required": ["type"],
        "title": "initializer_custom_option",
        "additionalProperties": True,
        "description": "Customize an existing initializer.",
    }

    return {
        "oneOf": [preconfigured_option, custom_option],
        "title": self.name,
        "default": default,
        "description": description,
    }
def _deserialize(self, value, attr, data, **kwargs):
    """Validate a raw initializer specification and return it unchanged.

    Accepted forms:
      * a string that is one of the keys of ``initializer_registry``;
      * a dict containing a ``"type"`` key whose value is one of the keys of
        ``initializer_registry`` (any other keys are left untouched).

    Args:
        value: The raw value to validate.
        attr: Unused here; part of the deserialization hook signature.
        data: Unused here; part of the deserialization hook signature.
        **kwargs: Unused here; accepted for signature compatibility.

    Returns:
        ``value`` itself when it passes validation.

    Raises:
        ValidationError: If ``value`` is neither a str nor a dict, if a str is
            not a registered initializer, or if a dict lacks ``"type"`` or has
            an unregistered ``"type"``.
    """
    # Snapshot the registry keys at call time; a list is used (rather than the
    # registry itself) so the error messages render the same list of names.
    initializers = list(initializer_registry.keys())

    if isinstance(value, str):
        if value not in initializers:
            raise ValidationError(
                f"Expected one of: {initializers}, found: {value}")
        return value

    if isinstance(value, dict):
        if "type" not in value:
            # Plain literal: the message has no placeholders, so no f-prefix.
            raise ValidationError("Dict must contain 'type'")
        if value["type"] not in initializers:
            raise ValidationError(
                f"Dict expected key 'type' to be one of: {initializers}, found: {value}"
            )
        return value

    raise ValidationError("Field should be str or dict")
def _jsonschema_type_mapping(self):
    """Return a ``oneOf`` JSON-schema fragment for an initializer value.

    Either a string drawn from the keys of ``initializer_registry``, or an
    object with a required ``type`` key drawn from the same keys and any
    additional properties allowed.
    """
    registered = list(initializer_registry.keys())

    as_string = {"type": "string", "enum": registered}
    as_object = {
        "type": "object",
        "properties": {
            "type": {"type": "string", "enum": registered},
        },
        "required": ["type"],
        "additionalProperties": True,
    }

    return {"oneOf": [as_string, as_object]}
def InitializerOptions(default=None):
    """Return a ``StringOptions`` field over the keys of ``initializer_registry``.

    Args:
        default: Forwarded to ``StringOptions`` as ``default``.

    Returns:
        Whatever ``StringOptions`` returns when called with the registry keys,
        the given default, and ``nullable=True``.
    """
    registered = list(initializer_registry.keys())
    return StringOptions(registered, default=default, nullable=True)
def InitializerOptions(default: str = "xavier_uniform", description=""):
    """Build a `StringOptions` field whose choices are the `initializer_registry` keys.

    Args:
        default: Forwarded to `StringOptions` as `default`.
        description: Forwarded to `StringOptions` as `description`.

    Returns:
        The result of `StringOptions`, called with `allow_none=False`.
    """
    registered = list(initializer_registry.keys())
    return StringOptions(
        registered,
        default=default,
        allow_none=False,
        description=description,
    )
def InitializerOrDict(default: str = "xavier_uniform", description: str = ""):
    """Create a dataclass field that accepts an initializer name or a custom dict.

    A string value must be one of the keys of `initializer_registry`. A dict
    value must contain a `type` key that is one of those keys; any other keys
    in the dict are not checked.

    Args:
        default: Default initializer name; must be a str present in
            `initializer_registry` at the time this function is called.
        description: Text attached to the field metadata and JSON schema.

    Returns:
        The result of `field(...)` carrying a `marshmallow_field` metadata entry.

    Raises:
        ValidationError: If `default` is not a str or not a registered name.
    """
    # Captured once here; `_deserialize` below validates against this snapshot.
    initializers = list(initializer_registry.keys())

    default_is_valid = isinstance(default, str) and default in initializers
    if not default_is_valid:
        raise ValidationError(f"Invalid default: `{default}`")

    class InitializerOptionsOrCustomDictField(fields.Field):
        def _deserialize(self, value, attr, data, **kwargs):
            # String form: must name a registered initializer.
            if isinstance(value, str):
                if value in initializers:
                    return value
                raise ValidationError(f"Expected one of: {initializers}, found: {value}")

            # Anything that is neither str nor dict is rejected outright.
            if not isinstance(value, dict):
                raise ValidationError("Field should be str or dict")

            # Dict form: `type` is mandatory and must be a registered name.
            if "type" not in value:
                raise ValidationError("Dict must contain 'type'")
            if value["type"] not in initializers:
                raise ValidationError(f"Dict expected key 'type' to be one of: {initializers}, found: {value}")
            return value

        def _jsonschema_type_mapping(self):
            # Re-read the registry at schema-generation time (shadows the
            # snapshot captured by the enclosing function).
            initializers = list(initializer_registry.keys())

            preconfigured_option = {
                "type": "string",
                "enum": initializers,
                "default": default,
                "title": "initializer_preconfigured_option",
                "description": "Pick a preconfigured initializer.",
            }
            # The custom-dict alternative deliberately carries no default.
            custom_option = {
                "type": "object",
                "properties": {
                    "type": {"type": "string", "enum": initializers},
                },
                "required": ["type"],
                "title": "initializer_custom_option",
                "additionalProperties": True,
                "description": "Customize an existing initializer.",
            }
            return {
                "oneOf": [preconfigured_option, custom_option],
                "title": self.name,
                "default": default,
                "description": description,
            }

    marshmallow_field = InitializerOptionsOrCustomDictField(
        allow_none=False,
        load_default=default,
        dump_default=default,
        metadata={"description": description},
    )
    return field(
        metadata={"marshmallow_field": marshmallow_field},
        default=default,
    )