class registry:
    """Namespace class holding the catalogue registries created under "spacy"."""

    # Each registry is created with entry_points=True.
    languages = catalogue.create("spacy", "languages", entry_points=True)
    architectures = catalogue.create("spacy", "architectures", entry_points=True)
    lookups = catalogue.create("spacy", "lookups", entry_points=True)
    factories = catalogue.create("spacy", "factories", entry_points=True)
    displacy_colors = catalogue.create("spacy", "displacy_colors", entry_points=True)
def test_create_single_namespace():
    """Register two functions in a one-part namespace and look them up again."""
    reg = catalogue.create("test")
    assert catalogue.REGISTRY == {}

    # "a" is registered through the decorator form.
    @reg.register("a")
    def a():
        pass

    # "b" is registered through the explicit ``func=`` form.
    def b():
        pass

    reg.register("b", func=b)

    expected = {"a": a, "b": b}
    registered = reg.get_all()
    assert len(registered) == 2
    for name, func in expected.items():
        assert registered[name] == func
    for name in expected:
        assert catalogue.check_exists("test", name)
    for name, func in expected.items():
        assert catalogue._get(("test", name)) == func

    with pytest.raises(TypeError):
        # The decorator only accepts one argument
        @reg.register("x", "y")
        def x():
            pass
def create(cls, registry_name: str, entry_points: bool = False) -> None:
    """Attach a new catalogue registry to ``cls`` under ``registry_name``.

    The registry is created in the ('xpersist', registry_name) namespace.

    Raises
    ------
    ValueError
        If ``cls`` already has an attribute called ``registry_name``.
    """
    name_taken = hasattr(cls, registry_name)
    if name_taken:
        raise ValueError(f"Registry '{registry_name}' already exists")
    setattr(
        cls,
        registry_name,
        catalogue.create('xpersist', registry_name, entry_points=entry_points),
    )
def test_create_multi_namespace():
    """Register one function in a two-part namespace and look it up again."""
    reg = catalogue.create("x", "y")

    @reg.register("z")
    def z():
        pass

    registered = reg.get_all()
    assert len(registered) == 1
    assert registered["z"] == z
    # The full lookup path is the namespace parts followed by the name.
    full_path = ("x", "y", "z")
    assert catalogue.check_exists(*full_path)
    assert catalogue._get(full_path) == z
def test_registry_find():
    """Check the metadata dict that ``find`` returns for a registered function."""
    reg = catalogue.create("test_registry_find")
    name = "a"

    @reg.register(name)
    def a():
        """This is a registered function."""

    found = reg.find(name)
    expected = {
        "module": "catalogue.tests.test_catalogue",
        "file": str(Path(__file__)),
        "docstring": "This is a registered function.",
    }
    for field in ("module", "file", "docstring"):
        assert found[field] == expected[field]
    # The line number only has to be truthy.
    assert found["line_no"]
def test_entry_points():
    """Check that a registry created with entry_points=True exposes entry points."""
    # Build an EntryPoint from setup.cfg-style text, advertising one of
    # catalogue's own util functions, and make it available to catalogue.
    cfg_text = "[options.entry_points]test_foo\n bar = catalogue:check_exists"
    catalogue.AVAILABLE_ENTRY_POINTS["test_foo"] = (
        catalogue.importlib_metadata.EntryPoint._from_text(cfg_text)
    )
    assert catalogue.REGISTRY == {}

    reg = catalogue.create("test", "foo", entry_points=True)
    target = catalogue.check_exists

    available = reg.get_entry_points()
    assert "bar" in available
    assert available["bar"] == target
    assert reg.get_entry_point("bar") == target
    # The global registry is expected to still be empty at this point.
    assert catalogue.REGISTRY == {}
    assert reg.get("bar") == target
    assert reg.get_all() == {"bar": target}
    assert "bar" in reg
class Registry:
    """
    Catalogue registry for types, preprocessors, logging configuration, and others

    Attributes:
        types:
            Types for field specs: registered functions that create the
            ValueSupplierInterface that will supply values for the given type

            >>> @datacraft.registry.types('special_sauce')
            ... def _handle_special_type(field_spec: dict, loader: datacraft.Loader) -> ValueSupplierInterface:
            ...     # return ValueSupplierInterface from spec config

        schemas:
            Schemas for field spec types, used to validate that the spec for a
            given type conforms to the schema for it

            >>> @datacraft.registry.schemas('special_sauce')
            ... def _special_sauce_schema() -> dict:
            ...     # return JSON schema validating specs with type: special_sauce

        preprocessors:
            Functions to modify specs before the data generation process. If
            there is a customization you want to do for every data spec, or an
            extension you added that requires modifications to the spec before
            it is run, this is where you would register that pre-processor.

            >>> @datacraft.registry.preprocessors('custom-preprocessing')
            ... def _preprocess_spec_to_some_end(raw_spec: dict, is_refs: bool) -> dict:
            ...     # return spec with any modification

        logging:
            Custom logging setup. Can override or modify the default logging
            behavior.

            >>> @datacraft.registry.logging('denoise')
            ... def _customize_logging(loglevel: str):
            ...     logging.getLogger('too.verbose.module').level = logging.ERROR

        formats:
            Registered formats for output, selected with --format <format name>.
            Unlike other registered functions, this one is called directly to
            perform the required formatting. The return value from the
            formatter is the new value that will be written to the configured
            output (default is console).

            >>> @datacraft.registry.formats('custom_format')
            ... def _format_custom(record: dict) -> str:
            ...     # write to database or some other custom output, return something to write out or print to console

        distribution:
            Different numeric distributions, normal, uniform, etc. These are
            used for more nuanced counts values. The built in distributions
            are uniform and normal.

            >>> @datacraft.registry.distribution('hyperbolic_inverse_haversine')
            ... def _hyperbolic_inverse_haversine(mean, stddev, **kwargs):
            ...     # return a datacraft.Distribution, args can be custom for the defined distribution

        defaults:
            Default values. Different types have different default values for
            some configs. This provides a mechanism to override or to register
            other custom defaults. Read a default from the registry with:
            ``datacraft.types.get_default('var_key')``. While
            ``datacraft.types.all_defaults()`` will give a mapping of all
            registered default keys and values.

            >>> @datacraft.registry.defaults('special_sauce_ingredient')
            ... def _default_special_sauce_ingredient():
            ...     # return the default value (i.e. onions)

        casters:
            Cast or alter values in simple ways. These are all the valid forms
            of altering generated values after they are created outside of the
            ValueSupplier types. Use ``datacraft.types.registered_casters()``
            to get a list of all the currently registered ones.

            >>> @datacraft.registry.casters('reverse')
            ... def _cast_reverse_strings():
            ...     # return a datacraft.CasterInterface
    """

    # Note that some namespace parts are singular ('type', 'preprocessor',
    # 'format') while the attribute names are plural.
    types = catalogue.create('datacraft', 'type')
    schemas = catalogue.create('datacraft', 'schemas')
    preprocessors = catalogue.create('datacraft', 'preprocessor')
    logging = catalogue.create('datacraft', 'logging')
    formats = catalogue.create('datacraft', 'format')
    distribution = catalogue.create('datacraft', 'distribution')
    defaults = catalogue.create('datacraft', 'defaults')
    casters = catalogue.create('datacraft', 'casters')
import catalogue

# Registry created in the single-part "ml-datasets" namespace, with
# entry_points=True passed to catalogue.create.
register_loader = catalogue.create("ml-datasets", entry_points=True)
class registry:
    """Namespace class holding the catalogue registries created under "dstl"."""

    # Created with entry_points=True.
    translators = catalogue.create("dstl", "translators", entry_points=True)
class registry:
    """xpersist's global registry entrypoint.

    This is used to register serializers and other components that are used
    by xpersist.
    """

    serializers: Decorator = catalogue.create('xpersist', 'serializers', entry_points=True)
    metadata_store: Decorator = catalogue.create('xpersist', 'metadata_store', entry_points=True)

    @classmethod
    def create(cls, registry_name: str, entry_points: bool = False) -> None:
        """Attach a new custom registry to this class.

        Parameters
        ----------
        registry_name : str
            Attribute name for the new registry; also used as the second part
            of its catalogue namespace.
        entry_points : bool
            Forwarded to ``catalogue.create``.

        Raises
        ------
        ValueError
            If the class already has an attribute called ``registry_name``.
        """
        name_taken = hasattr(cls, registry_name)
        if name_taken:
            raise ValueError(f"Registry '{registry_name}' already exists")
        new_registry: Decorator = catalogue.create(
            'xpersist', registry_name, entry_points=entry_points
        )
        setattr(cls, registry_name, new_registry)

    @classmethod
    def has(cls, registry_name: str, func_name: str) -> bool:
        """Check whether a function is available in a registry.

        Parameters
        ----------
        registry_name : str
            The name of the registry to check.
        func_name : str
            The name of the function to check.

        Returns
        -------
        bool
            Whether the function is available in the registry. ``False`` when
            the registry itself does not exist.
        """
        return hasattr(cls, registry_name) and func_name in getattr(cls, registry_name)

    @classmethod
    def get(cls, registry_name: str, func_name: str) -> typing.Callable:
        """Get a registered function from a given registry.

        Parameters
        ----------
        registry_name : str
            The name of the registry to get the function from.
        func_name : str
            The name of the function to get.

        Returns
        -------
        func : typing.Callable
            The function from the registry.

        Raises
        ------
        ValueError
            If the registry does not exist, or the lookup yields ``None``.
        """
        if not hasattr(cls, registry_name):
            raise ValueError(f"Unknown registry: '{registry_name}'")
        found = getattr(cls, registry_name).get(func_name)
        if found is not None:
            return found
        raise ValueError(
            f"Could not find '{func_name}' in '{registry_name}'")
import prodigy from prodigy.components.db import connect from prodigy.util import log, split_string, set_hashes, TASK_HASH_ATTR, INPUT_HASH_ATTR import murmurhash from sense2vec import Sense2Vec import srsly import spacy import random from wasabi import msg from collections import defaultdict, Counter import copy import catalogue # fmt: off eval_strategies = catalogue.create("prodigy", "sense2vec.eval") EVAL_EXCLUDE_SENSES = ("SYM", "MONEY", "ORDINAL", "CARDINAL", "DATE", "TIME", "PERCENT", "QUANTITY", "NUM", "X", "PUNCT") # fmt: on @prodigy.recipe( "sense2vec.teach", dataset=("Dataset to save annotations to", "positional", None, str), vectors_path=("Path to pretrained sense2vec vectors", "positional", None, str), seeds=("One or more comma-separated seed phrases", "option", "se", split_string), threshold=("Similarity threshold for sense2vec", "option", "t", float), n_similar=("Number of similar items to get at once", "option", "n", int), batch_size=("Batch size for submitting annotations", "option", "bs", int), case_sensitive=("Show the same terms with different casing", "flag", "CS",
class registry(object):
    """Registries of functions under the "thinc" namespace, plus classmethods
    that resolve, fill and validate config dicts referencing those functions
    through keys starting with "@".
    """

    # fmt: off
    optimizers: Decorator = catalogue.create("thinc", "optimizers", entry_points=True)
    schedules: Decorator = catalogue.create("thinc", "schedules", entry_points=True)
    layers: Decorator = catalogue.create("thinc", "layers", entry_points=True)
    losses: Decorator = catalogue.create("thinc", "losses", entry_points=True)
    initializers: Decorator = catalogue.create("thinc", "initializers", entry_points=True)
    datasets: Decorator = catalogue.create("thinc", "datasets", entry_points=True)
    # fmt: on

    @classmethod
    def create(cls, registry_name: str, entry_points: bool = False) -> None:
        """Create a new custom registry and attach it to the class as an
        attribute named `registry_name`. Raises a ValueError if the class
        already has an attribute with that name.
        """
        if hasattr(cls, registry_name):
            raise ValueError(f"Registry '{registry_name}' already exists")
        reg: Decorator = catalogue.create("thinc", registry_name, entry_points=entry_points)
        setattr(cls, registry_name, reg)

    @classmethod
    def get(cls, registry_name: str, func_name: str) -> Callable:
        """Get a registered function from a given registry. Raises a
        ValueError if the registry doesn't exist or if the lookup returns None.
        """
        if not hasattr(cls, registry_name):
            raise ValueError(f"Unknown registry: '{registry_name}'")
        reg = getattr(cls, registry_name)
        func = reg.get(func_name)
        if func is None:
            raise ValueError(
                f"Could not find '{func_name}' in '{registry_name}'")
        return func

    # NOTE(review): `overrides` defaults to a shared `{}` here and in the
    # methods below. It is only read in the code visible in this class, but
    # confirm no caller relies on mutating it.
    @classmethod
    def resolve(
        cls,
        config: Union[Config, Dict[str, Dict[str, Any]]],
        *,
        schema: Type[BaseModel] = EmptySchema,
        overrides: Dict[str, Any] = {},
        validate: bool = True,
    ) -> Tuple[Dict[str, Any], Config]:
        """Unpack a config dictionary and create two versions of the config:
        a resolved version with objects from the registry created recursively,
        and a filled version with all references to registry functions left
        intact, but filled with all values and defaults based on the type
        annotations. If validate=True, the config will be validated against the
        type annotations of the registered functions referenced in the config
        (if available) and/or the schema (if available).

        Returns a tuple of (resolved dict, filled Config). Raises a
        ConfigValidationError if the top-level config is itself a promise.
        """
        # Valid: {"optimizer": {"@optimizers": "my_cool_optimizer", "rate": 1.0}}
        # Invalid: {"@optimizers": "my_cool_optimizer", "rate": 1.0}
        if cls.is_promise(config):
            err_msg = "The top-level config object can't be a reference to a registered function."
            raise ConfigValidationError(config, [{"msg": err_msg}])
        # If a Config was loaded with interpolate=False, we assume it needs to
        # be interpolated first, otherwise we take it at face value
        is_interpolated = not isinstance(config, Config) or config.is_interpolated
        section_order = config.section_order if isinstance(config, Config) else None
        orig_config = config
        if not is_interpolated:
            config = Config(orig_config).interpolate()
        filled, _, resolved = cls._fill(config, schema, validate=validate, overrides=overrides)
        filled = Config(filled, section_order=section_order)
        # Check that overrides didn't include invalid properties not in config
        if validate:
            cls._validate_overrides(filled, overrides)
        # Merge the original config back to preserve variables if we started
        # with a config that wasn't interpolated. Here, we prefer variables to
        # allow auto-filling a non-interpolated config without destroying
        # variable references.
        if not is_interpolated:
            filled = filled.merge(Config(orig_config, is_interpolated=False), remove_extra=True)
        return dict(resolved), filled

    @classmethod
    def make_from_config(
        cls,
        config: Union[Config, Dict[str, Dict[str, Any]]],
        *,
        schema: Type[BaseModel] = EmptySchema,
        overrides: Dict[str, Any] = {},
        validate: bool = True,
    ) -> Dict[str, Any]:
        """Unpack a config dictionary, creating objects from the registry
        recursively. If validate=True, the config will be validated against the
        type annotations of the registered functions referenced in the config
        (if available) and/or the schema (if available).

        Returns only the resolved half of the `resolve` result.
        """
        # Valid: {"optimizer": {"@optimizers": "my_cool_optimizer", "rate": 1.0}}
        # Invalid: {"@optimizers": "my_cool_optimizer", "rate": 1.0}
        resolved, _ = cls.resolve(config, schema=schema, overrides=overrides, validate=validate)
        return resolved

    @classmethod
    def fill_config(
        cls,
        config: Union[Config, Dict[str, Dict[str, Any]]],
        *,
        schema: Type[BaseModel] = EmptySchema,
        overrides: Dict[str, Any] = {},
        validate: bool = True,
    ) -> Config:
        """Unpack a config dictionary, leave all references to registry
        functions intact and don't resolve them, but fill in all values and
        defaults based on the type annotations. If validate=True, the config
        will be validated against the type annotations of the registered
        functions referenced in the config (if available) and/or the schema
        (if available).

        Returns only the filled half of the `resolve` result, so the
        registered functions are still called along the way.
        """
        _, filled = cls.resolve(config, schema=schema, overrides=overrides, validate=validate)
        return filled

    @classmethod
    def _fill(
        cls,
        config: Union[Config, Dict[str, Dict[str, Any]]],
        schema: Type[BaseModel] = EmptySchema,
        *,
        validate: bool = True,
        parent: str = "",
        overrides: Dict[str, Dict[str, Any]] = {},
    ) -> Tuple[Union[Dict[str, Any], Config], Union[Dict[str, Any], Config], Dict[str, Any]]:
        """Build three representations of the config:
        1. All promises are preserved (just like config user would provide).
        2. Promises are replaced by their return values. This is the validation
           copy and will be parsed by pydantic. It lets us include hacks to
           work around problems (e.g. handling of generators).
        3. Final copy with promises replaced by their return values. This is
           what registry.make_from_config returns.

        `parent` is the dotted path of the section being processed; it is used
        to look up `overrides` and to locate validation errors. Note that an
        applied override is also written back into `config`.
        """
        filled: Dict[str, Any] = {}
        validation: Dict[str, Any] = {}
        final: Dict[str, Any] = {}
        for key, value in config.items():
            # If the field name is reserved, we use its alias for validation
            v_key = RESERVED_FIELDS.get(key, key)
            # Dotted path of this key, e.g. "section.subsection.key"
            key_parent = f"{parent}.{key}".strip(".")
            if key_parent in overrides:
                value = overrides[key_parent]
                config[key] = value
            if cls.is_promise(value):
                promise_schema = cls.make_promise_schema(value)
                # Fill the promise's own arguments first (they may contain
                # nested promises), validating against the function signature
                filled[key], validation[v_key], final[key] = cls._fill(
                    value,
                    promise_schema,
                    validate=validate,
                    parent=key_parent,
                    overrides=overrides,
                )
                # Call the function and populate the field value. We can't just
                # create an instance of the type here, since this wouldn't work
                # for generics / more complex custom types
                getter = cls.get_constructor(final[key])
                args, kwargs = cls.parse_args(final[key])
                try:
                    getter_result = getter(*args, **kwargs)
                except Exception as err:
                    err_msg = "Can't construct config: calling registry function failed"
                    raise ConfigValidationError({key: value}, [{
                        "msg": err,
                        "loc": [getter.__name__]
                    }], err_msg) from err
                validation[v_key] = getter_result
                final[key] = getter_result
                if isinstance(validation[v_key], dict):
                    # The registered function returned a dict, prevent it from
                    # being validated as a config section
                    validation[v_key] = {}
                if isinstance(validation[v_key], GeneratorType):
                    # If value is a generator we can't validate type without
                    # consuming it (which doesn't work if it's infinite – see
                    # schedule for examples). So we skip it.
                    validation[v_key] = []
            elif hasattr(value, "items"):
                # Plain nested section: recurse with the sub-schema if the
                # parent schema declares one for this key
                field_type = EmptySchema
                if key in schema.__fields__:
                    field = schema.__fields__[key]
                    field_type = field.type_
                    if not isinstance(field.type_, ModelMetaclass):
                        # If we don't have a pydantic schema and just a type
                        field_type = EmptySchema
                filled[key], validation[v_key], final[key] = cls._fill(
                    value,
                    field_type,
                    validate=validate,
                    parent=key_parent,
                    overrides=overrides,
                )
                if key == ARGS_FIELD and isinstance(validation[v_key], dict):
                    # If the value of variable positional args is a dict (e.g.
                    # created via config blocks), only use its values
                    validation[v_key] = list(validation[v_key].values())
                    final[key] = list(final[key].values())
            else:
                filled[key] = value
                # Prevent pydantic from consuming generator if part of a union
                validation[v_key] = (
                    value if not isinstance(value, GeneratorType) else [])
                final[key] = value
        # Now that we've filled in all of the promises, update with defaults
        # from schema, and validate if validation is enabled
        exclude = []
        if validate:
            try:
                result = schema.parse_obj(validation)
            except ValidationError as e:
                raise ConfigValidationError(config, e.errors(), element=parent) from None
        else:
            # Same as parse_obj, but without validation
            result = schema.construct(**validation)
            # If our schema doesn't allow extra values, we need to filter them
            # manually because .construct doesn't parse anything
            if schema.Config.extra in (Extra.forbid, Extra.ignore):
                fields = schema.__fields__.keys()
                exclude = [k for k in result.__fields_set__ if k not in fields]
        # Copy parsed values and defaults back, leaving out the positional
        # args alias and the reserved field names
        exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()])
        validation.update(result.dict(exclude=exclude_validation))
        filled, final = cls._update_from_parsed(validation, filled, final)
        if exclude:
            filled = {k: v for k, v in filled.items() if k not in exclude}
            validation = {
                k: v
                for k, v in validation.items() if k not in exclude
            }
            final = {k: v for k, v in final.items() if k not in exclude}
        return filled, validation, final

    @classmethod
    def _update_from_parsed(cls, validation: Dict[str, Any],
                            filled: Dict[str, Any], final: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Update the final result with the parsed config like converted
        values recursively.

        Keys missing from `filled` / `final` are added from `validation`, and
        both dicts are modified in place as well as returned.
        """
        for key, value in validation.items():
            if key in RESERVED_FIELDS.values():
                continue  # skip aliases for reserved fields
            if key not in filled:
                filled[key] = value
            if key not in final:
                final[key] = value
            if isinstance(value, dict):
                filled[key], final[key] = cls._update_from_parsed(
                    value, filled[key], final[key])
            # Update final config with parsed value if they're not equal (in
            # value and in type) but not if it's a generator because we had to
            # replace that to validate it correctly
            elif key == ARGS_FIELD:
                continue  # don't substitute if list of positional args
            elif isinstance(value, numpy.ndarray):  # check numpy first, just in case
                final[key] = value
            # NOTE(review): `isinstance(type(value), type(final[key]))` tests
            # the *class* of value against the class of final[key];
            # presumably `isinstance(value, type(final[key]))` was intended.
            # Confirm before changing.
            elif (value != final[key] or
                  not isinstance(type(value), type(final[key]))
                  ) and not isinstance(final[key], GeneratorType):
                final[key] = value
        return filled, final

    @classmethod
    def _validate_overrides(cls, filled: Config, overrides: Dict[str, Any]) -> None:
        """Validate overrides against a filled config to make sure there are
        no references to properties that don't exist and weren't used.

        Raises a ConfigValidationError listing every override key that is not
        present in the filled config.
        """
        error_msg = "Invalid override: config value doesn't exist"
        errors = []
        for override_key in overrides.keys():
            if not cls._is_in_config(override_key, filled):
                errors.append({"msg": error_msg, "loc": [override_key]})
        if errors:
            raise ConfigValidationError(filled, errors)

    @classmethod
    def _is_in_config(cls, prop: str, config: Union[Dict[str, Any], Config]) -> bool:
        """Check whether a nested config property like "section.subsection.key"
        is in a given config."""
        tree = prop.split(".")
        obj = dict(config)
        # Walk down one path segment at a time; fail as soon as a segment is
        # missing or the current level is not a dict
        while tree:
            key = tree.pop(0)
            if isinstance(obj, dict) and key in obj:
                obj = obj[key]
            else:
                return False
        return True

    @classmethod
    def is_promise(cls, obj: Any) -> bool:
        """Check whether an object is a "promise", i.e. contains a reference
        to a registered function (via a key starting with `"@"`).
        """
        if not hasattr(obj, "keys"):
            return False
        id_keys = [k for k in obj.keys() if k.startswith("@")]
        if len(id_keys):
            return True
        return False

    @classmethod
    def get_constructor(cls, obj: Dict[str, Any]) -> Callable:
        """Look up the registered function referenced by a promise dict.

        The single key starting with "@" names the registry (without the "@")
        and its value names the function. Raises a ConfigValidationError if
        the dict doesn't contain exactly one such key.
        """
        id_keys = [k for k in obj.keys() if k.startswith("@")]
        if len(id_keys) != 1:
            err_msg = f"A block can only contain one function registry reference. Got: {id_keys}"
            raise ConfigValidationError(obj, [{"msg": err_msg}])
        else:
            key = id_keys[0]
            value = obj[key]
            return cls.get(key[1:], value)

    @classmethod
    def parse_args(cls, obj: Dict[str, Any]) -> Tuple[List[Any], Dict[str, Any]]:
        """Split a promise dict into positional and keyword arguments.

        Keys starting with "@" and keys that are aliases of reserved fields
        are skipped, the value under ARGS_FIELD becomes the positional
        arguments, and everything else becomes a keyword argument.
        """
        args = []
        kwargs = {}
        for key, value in obj.items():
            if not key.startswith("@"):
                if key == ARGS_FIELD:
                    args = value
                elif key in RESERVED_FIELDS.values():
                    continue
                else:
                    kwargs[key] = value
        return args, kwargs

    @classmethod
    def make_promise_schema(cls, obj: Dict[str, Any]) -> Type[BaseModel]:
        """Create a schema for a promise dict (referencing a registry function)
        by inspecting the function signature.
        """
        func = cls.get_constructor(obj)
        # Read the argument annotations and defaults from the function signature
        id_keys = [k for k in obj.keys() if k.startswith("@")]
        # The registry reference itself is a required string field
        sig_args: Dict[str, Any] = {id_keys[0]: (str, ...)}
        for param in inspect.signature(func).parameters.values():
            # If no annotation is specified assume it's anything
            annotation = param.annotation if param.annotation != param.empty else Any
            # If no default value is specified assume that it's required
            default = param.default if param.default != param.empty else ...
            # Handle spread arguments and use their annotation as Sequence[whatever]
            if param.kind == param.VAR_POSITIONAL:
                spread_annot = Sequence[annotation]  # type: ignore
                sig_args[ARGS_FIELD_ALIAS] = (spread_annot, default)
            else:
                # Reserved parameter names are stored under their alias
                name = RESERVED_FIELDS.get(param.name, param.name)
                sig_args[name] = (annotation, default)
        sig_args["__config__"] = _PromiseSchemaConfig
        return create_model("ArgModel", **sig_args)
import inspect import catalogue from typing import Type, Optional, Union, TYPE_CHECKING if TYPE_CHECKING: from scrubadub.post_processors import PostProcessor post_processor_catalogue = catalogue.create('scrubadub', 'post_processors', entry_points=True) def register_post_processor(post_processor: Type['PostProcessor'], autoload: Optional[bool] = None, index: Optional[int] = None) -> None: """Register a PostProcessor for use with the ``Scrubber`` class. You can use ``register_post_processor(NewPostProcessor)`` after your post-processor definition to automatically register it with the ``Scrubber`` class so that it can be used to process Filth. The argument ``autoload`` sets if a new ``Scrubber()`` instance should load this ``PostProcessor`` by default. :param post_processor: The ``PostProcessor`` to register with the scrubadub post-processor configuration. :type post_processor: PostProcessor class :param autoload: Whether to automatically load this ``Detector`` on ``Scrubber`` initialisation. :type autoload: bool :param index: The location/index in which this ``PostProcessor`` should be added. :type index: int """ if not inspect.isclass(post_processor): raise ValueError("post_processor should be a class, not an instance.") if autoload is not None:
class registry:
    """Namespace class holding the "operations" registry created under "recon"."""

    # Created with entry_points=True.
    operations = catalogue.create("recon", "operations", entry_points=True)
import inspect import catalogue from typing import Type, Optional, Union, TYPE_CHECKING if TYPE_CHECKING: from scrubadub.detectors import Detector detector_catalogue = catalogue.create('scrubadub', 'detectors', entry_points=True) def register_detector(detector: Type['Detector'], *, autoload: Optional[bool] = None) -> Type['Detector']: """Register a detector for use with the ``Scrubber`` class. You can use ``register_detector(NewDetector, autoload=True)`` after your detector definition to automatically register it with the ``Scrubber`` class so that it can be used to remove Filth. The argument ``autoload``decides whether a new ``Scrubber()`` instance should load this ``detector`` by default. .. code:: pycon >>> import scrubadub >>> class NewDetector(scrubadub.detectors.Detector): ... pass >>> scrubadub.detectors.register_detector(NewDetector, autoload=False) <class 'scrubadub.detectors.catalogue.NewDetector'>
def test_registry_get_set():
    """A missing name raises RegistryError; after registering it is contained."""
    reg = catalogue.create("test")

    # Nothing is registered under "foo" yet.
    with pytest.raises(catalogue.RegistryError):
        reg.get("foo")

    def identity(x):
        return x

    reg.register("foo", func=identity)
    assert "foo" in reg
import functools import catalogue from ._version import version from .exceptions import * from ._packer import Packer as _Packer from ._unpacker import unpackb as _unpackb from ._unpacker import unpack as _unpack from ._unpacker import Unpacker as _Unpacker from ._ext_type import ExtType from ._msgpack_numpy import encode_numpy as _encode_numpy from ._msgpack_numpy import decode_numpy as _decode_numpy msgpack_encoders = catalogue.create("srsly", "msgpack_encoders", entry_points=True) msgpack_decoders = catalogue.create("srsly", "msgpack_decoders", entry_points=True) msgpack_encoders.register("numpy", func=_encode_numpy) msgpack_decoders.register("numpy", func=_decode_numpy) # msgpack_numpy extensions class Packer(_Packer): def __init__(self, *args, **kwargs): default = kwargs.get("default") for encoder in msgpack_encoders.get_all().values(): default = functools.partial(encoder, chain=default)
class registry:
    """Namespace class holding the "preprocessors" registry created under "recon"."""

    # Created with entry_points=True.
    preprocessors = catalogue.create("recon", "preprocessors", entry_points=True)
class my_registry(thinc.config.registry):
    """Subclass of ``thinc.config.registry`` that adds a ``cats`` registry."""

    # Three-part namespace, created with entry_points=False.
    cats = catalogue.create("thinc", "tests", "cats", entry_points=False)
class MetadataDetector(Detector): @abstractmethod def detect(self, column: CatColumn) -> Optional[PiiType]: """Scan the text and return an array of PiiTypes that are found""" class DatumDetector(Detector): @abstractmethod def detect(self, column: CatColumn, datum: str) -> Optional[PiiType]: """Scan the text and return an array of PiiTypes that are found""" detector_registry = catalogue.create("piicatcher", "detectors", entry_points=True) def register_detector(detector: Type["Detector"]) -> Type["Detector"]: """Register a detector for use. You can use ``register_detector(NewDetector)`` after your detector definition to automatically register it. .. code:: pycon >>> import piicatcher >>> class NewDetector(piicatcher.detectors.Detector): ... pass >>> piicatcher.detectors.register_detector(NewDetector) <class 'piicatcher.detectors.catalogue.NewDetector'>
from pathlib import Path import random from transformers import AutoConfig, AutoModel, AutoTokenizer from transformers.tokenization_utils import BatchEncoding from transformers.tokenization_utils_fast import PreTrainedTokenizerFast import catalogue from spacy.util import registry from thinc.api import get_current_ops, CupyOps import torch.cuda import tempfile import shutil import contextlib # fmt: off registry.span_getters = catalogue.create("spacy", "span_getters", entry_points=True) registry.annotation_setters = catalogue.create("spacy", "annotation_setters", entry_points=True) # fmt: on def huggingface_from_pretrained(source: Union[Path, str], tok_config: Dict, trf_config: Dict): """Create a Huggingface transformer model from pretrained weights. Will download the model if it is not already downloaded. source (Union[str, Path]): The name of the model or a path to it, such as 'bert-base-cased'. tok_config (dict): Settings to pass to the tokenizer.
class registry(object):
    """Function registries under the "thinc" namespace, with helpers that
    look up registered functions and build objects from config dictionaries.
    """

    optimizers = catalogue.create("thinc", "optimizers", entry_points=True)
    schedules = catalogue.create("thinc", "schedules", entry_points=True)
    layers = catalogue.create("thinc", "layers", entry_points=True)

    @classmethod
    def get(cls, name, key):
        """Return the function registered as `key` in the registry `name`.

        Raises a ValueError if the class has no registry attribute `name`, or
        if the lookup returns None.
        """
        if not hasattr(cls, name):
            raise ValueError("Unknown registry: %s" % name)
        reg = getattr(cls, name)
        func = reg.get(key)
        if func is None:
            # The function key goes first, the registry name second (the
            # arguments used to be swapped).
            raise ValueError("Could not find %s in %s" % (key, name))
        return func

    # The first three factories below were declared without `cls`, so every
    # call through the classmethod failed with a TypeError.

    @classmethod
    def make_optimizer(cls, name, args, kwargs):
        """Call the optimizer registered as `name` with *args and **kwargs."""
        func = cls.optimizers.get(name)
        return func(*args, **kwargs)

    @classmethod
    def make_schedule(cls, name, args, kwargs):
        """Call the schedule registered as `name` with *args and **kwargs."""
        func = cls.schedules.get(name)
        return func(*args, **kwargs)

    # NOTE(review): `initializers`, `combinators` and `transforms` are not
    # defined in this class body, so the three factories that use them raise
    # AttributeError unless those registries are attached elsewhere. Confirm.

    @classmethod
    def make_initializer(cls, name, args, kwargs):
        """Call the initializer registered as `name` with *args and **kwargs."""
        func = cls.initializers.get(name)
        return func(*args, **kwargs)

    @classmethod
    def make_layer(cls, name, args, kwargs):
        """Call the layer registered as `name` with *args and **kwargs."""
        func = cls.layers.get(name)
        return func(*args, **kwargs)

    @classmethod
    def make_combinator(cls, name, args, kwargs):
        """Call the combinator registered as `name` with *args and **kwargs."""
        func = cls.combinators.get(name)
        return func(*args, **kwargs)

    @classmethod
    def make_transform(cls, name, args, kwargs):
        """Call the transform registered as `name` with *args and **kwargs."""
        func = cls.transforms.get(name)
        return func(*args, **kwargs)

    @classmethod
    def make_from_config(cls, config, id_start="@"):
        """Unpack a config dictionary, creating objects from the registry
        recursively.

        A dict with exactly one key starting with `id_start` is a registry
        reference: the rest of that key names the registry and its value names
        the function. Integer or digit-string keys become positional arguments
        (ordered by their number); the remaining keys become keyword
        arguments. A dict without such a key is returned as a new dict with
        its nested dicts resolved.

        Raises a ValueError if a dict contains more than one registry key.
        """
        # Integer keys are legal (positional arguments), so only string keys
        # can be registry references.
        id_keys = [
            key for key in config.keys()
            if isinstance(key, str) and key.startswith(id_start)
        ]
        if len(id_keys) >= 2:
            raise ValueError("Multiple registry keys in config: %s" % id_keys)
        if not id_keys:
            # Recurse over subdictionaries, filling in values.
            return {
                key: (cls.make_from_config(value, id_start=id_start)
                      if isinstance(value, dict) else value)
                for key, value in config.items()
            }
        id_key = id_keys[0]
        # Strip only the leading marker, not every occurrence of it.
        getter = cls.get(id_key[len(id_start):], config[id_key])
        positional = []
        kwargs = {}
        for key, value in config.items():
            if isinstance(value, dict):
                value = cls.make_from_config(value, id_start=id_start)
            if isinstance(key, int) or key.isdigit():
                positional.append((int(key), value))
            elif not key.startswith(id_start):
                kwargs[key] = value
        # Order by position only, so the values themselves are never compared.
        positional.sort(key=lambda item: item[0])
        args = [value for _, value in positional]
        return getter(*args, **kwargs)
class registry:
    """Namespace class holding the registries created under "spacy_ke"."""

    # No entry_points argument is passed for this registry.
    candidate_selection = catalogue.create("spacy_ke", "candidate_selection")
def test_registry_call():
    """Calling the registry object directly registers the given function."""
    reg = catalogue.create("test")

    def identity(x):
        return x

    reg("foo", func=identity)
    assert "foo" in reg
try: import cupy except ImportError: cupy = None try: import tensorflow as tf except ImportError: # pragma: no cover pass try: import h5py except ImportError: # pragma: no cover pass keras_model_fns = catalogue.create("thinc", "keras", entry_points=True) def maybe_handshake_model(keras_model): """Call the required predict/compile/build APIs to initialize a model if it is a subclass of tf.keras.Model. This is required to be able to call set_weights on subclassed layers.""" try: keras_model.get_config() return keras_model except (AttributeError, NotImplementedError): # Subclassed models don't implement get_config pass for prop_name in ["catalogue_name", "eg_x", "eg_y", "eg_shape"]: if not hasattr(keras_model, prop_name):
class registry:
    """Namespace class holding the registries created under "sense2vec"."""

    # No entry_points argument is passed for any of these registries.
    make_key = catalogue.create("sense2vec", "make_key")
    split_key = catalogue.create("sense2vec", "split_key")
    make_spacy_key = catalogue.create("sense2vec", "make_spacy_key")
    get_phrases = catalogue.create("sense2vec", "get_phrases")
    merge_phrases = catalogue.create("sense2vec", "merge_phrases")
class registry(object):
    """Namespace of thinc function registries plus config resolution helpers.

    Each class attribute is a catalogue registry created under the "thinc"
    namespace. A config block that contains a key starting with "@" is a
    "promise": the key names a registry and its value names a function in
    that registry. The classmethods resolve such blocks recursively and
    validate them against pydantic schemas.
    """

    # fmt: off
    optimizers: Decorator = catalogue.create("thinc", "optimizers", entry_points=True)
    schedules: Decorator = catalogue.create("thinc", "schedules", entry_points=True)
    layers: Decorator = catalogue.create("thinc", "layers", entry_points=True)
    losses: Decorator = catalogue.create("thinc", "losses", entry_points=True)
    initializers: Decorator = catalogue.create("thinc", "initializers", entry_points=True)
    datasets: Decorator = catalogue.create("thinc", "datasets", entry_points=True)
    # fmt: on

    @classmethod
    def create(cls, registry_name: str, entry_points: bool = False) -> None:
        """Create a new custom registry.

        registry_name: Attribute name under which the registry is stored on
            this class, and second namespace component passed to catalogue.
        entry_points: Forwarded to `catalogue.create`.
        Raises ValueError if the class already has an attribute of that name.
        """
        if hasattr(cls, registry_name):
            raise ValueError(f"Registry '{registry_name}' already exists")
        reg: Decorator = catalogue.create(
            "thinc", registry_name, entry_points=entry_points
        )
        setattr(cls, registry_name, reg)

    @classmethod
    def get(cls, registry_name: str, func_name: str) -> Callable:
        """Get a registered function from a given registry.

        Raises ValueError if the class has no attribute `registry_name`, or
        if the registry lookup returns None.
        """
        if not hasattr(cls, registry_name):
            raise ValueError(f"Unknown registry: '{registry_name}'")
        reg = getattr(cls, registry_name)
        func = reg.get(func_name)
        if func is None:
            raise ValueError(
                f"Could not find '{func_name}' in '{registry_name}'")
        return func

    @classmethod
    def make_from_config(
        cls,
        config: Union[Config, Dict[str, Dict[str, Any]]],
        *,
        schema: Type[BaseModel] = EmptySchema,
        validate: bool = True,
    ) -> Config:
        """Unpack a config dictionary, creating objects from the registry
        recursively. If validate=True, the config will be validated against the
        type annotations of the registered functions referenced in the config
        (if available) and/or the schema (if available).

        Returns the third copy built by `_fill`, in which promises are
        replaced by their return values.
        Raises ConfigValidationError if `config` itself is a promise.
        """
        # Valid: {"optimizer": {"@optimizers": "my_cool_optimizer", "rate": 1.0}}
        # Invalid: {"@optimizers": "my_cool_optimizer", "rate": 1.0}
        if cls.is_promise(config):
            err_msg = "The top-level config object can't be a reference to a registered function."
            raise ConfigValidationError(config, [{"msg": err_msg}])
        _, _, resolved = cls._fill(config, schema, validate)
        return resolved

    @classmethod
    def fill_config(
        cls,
        config: Union[Config, Dict[str, Dict[str, Any]]],
        *,
        schema: Type[BaseModel] = EmptySchema,
        validate: bool = True,
    ) -> Config:
        """Unpack a config dictionary, leave all references to registry
        functions intact and don't resolve them, but fill in all values and
        defaults based on the type annotations. If validate=True, the config
        will be validated against the type annotations of the registered
        functions referenced in the config (if available) and/or the schema
        (if available).

        Returns the first copy built by `_fill`, in which promises are kept.
        Raises ConfigValidationError if `config` itself is a promise.
        """
        # Valid: {"optimizer": {"@optimizers": "my_cool_optimizer", "rate": 1.0}}
        # Invalid: {"@optimizers": "my_cool_optimizer", "rate": 1.0}
        if cls.is_promise(config):
            err_msg = "The top-level config object can't be a reference to a registered function."
            raise ConfigValidationError(config, [{"msg": err_msg}])
        filled, _, _ = cls._fill(config, schema, validate)
        return filled

    @classmethod
    def _fill(
        cls,
        config: Union[Config, Dict[str, Dict[str, Any]]],
        schema: Type[BaseModel] = EmptySchema,
        validate: bool = True,
        parent: str = "",
    ) -> Tuple[Config, Config, Config]:
        """Build three representations of the config:
        1. All promises are preserved (just like config user would provide).
        2. Promises are replaced by their return values. This is the validation
           copy and will be parsed by pydantic. It lets us include hacks to
           work around problems (e.g. handling of generators).
        3. Final copy with promises replaced by their return values. This is
           what registry.make_from_config returns.

        parent: Dotted path of the block being processed, passed along as
            `element` when a validation error is raised.
        Raises ConfigValidationError if a registered function raises while
        being called, or if pydantic validation fails.
        """
        filled: Dict[str, Any] = {}
        validation: Dict[str, Any] = {}
        final: Dict[str, Any] = {}
        for key, value in config.items():
            key_parent = f"{parent}.{key}".strip(".")
            if cls.is_promise(value):
                promise_schema = cls.make_promise_schema(value)
                filled[key], validation[key], final[key] = cls._fill(
                    value, promise_schema, validate, parent=key_parent)
                # Call the function and populate the field value. We can't just
                # create an instance of the type here, since this wouldn't work
                # for generics / more complex custom types
                getter = cls.get_constructor(final[key])
                args, kwargs = cls.parse_args(final[key])
                try:
                    getter_result = getter(*args, **kwargs)
                except Exception as err:
                    err_msg = "Can't construct config: calling registry function failed"
                    # Not every callable has __name__ (e.g. functools.partial);
                    # fall back to its repr so the real error is not masked.
                    getter_name = getattr(getter, "__name__", repr(getter))
                    raise ConfigValidationError({key: value}, [{
                        "msg": err,
                        "loc": [getter_name]
                    }], err_msg) from err
                validation[key] = getter_result
                final[key] = getter_result
                if isinstance(validation[key], GeneratorType):
                    # If value is a generator we can't validate type without
                    # consuming it (which doesn't work if it's infinite – see
                    # schedule for examples). So we skip it.
                    validation[key] = []
            elif hasattr(value, "items"):
                field_type = EmptySchema
                if key in schema.__fields__:
                    field = schema.__fields__[key]
                    field_type = field.type_
                    if not isinstance(field.type_, ModelMetaclass):
                        # If we don't have a pydantic schema and just a type
                        field_type = EmptySchema
                filled[key], validation[key], final[key] = cls._fill(
                    value, field_type, validate, parent=key_parent)
                if key == ARGS_FIELD and isinstance(validation[key], dict):
                    # If the value of variable positional args is a dict (e.g.
                    # created via config blocks), only use its values
                    validation[key] = list(validation[key].values())
                    final[key] = list(final[key].values())
            else:
                filled[key] = value
                # Prevent pydantic from consuming generator if part of a union
                validation[key] = value if not isinstance(
                    value, GeneratorType) else []
                final[key] = value
        # Now that we've filled in all of the promises, update with defaults
        # from schema, and validate if validation is enabled
        if validate:
            try:
                result = schema.parse_obj(validation)
            except ValidationError as e:
                raise ConfigValidationError(config, e.errors(), element=parent)
        else:
            # Same as parse_obj, but without validation
            result = schema.construct(**validation)
        validation.update(result.dict(exclude={ARGS_FIELD_ALIAS}))
        filled, final = cls._update_from_parsed(validation, filled, final)
        return Config(filled), Config(validation), Config(final)

    @classmethod
    def _update_from_parsed(
        cls,
        validation: Dict[str, Any],
        filled: Dict[str, Any],
        final: Dict[str, Any],
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Update the final result with the parsed config like converted
        values recursively.

        Keys present in `validation` but missing from `filled` or `final`
        are copied over. Both dicts are modified in place and also returned.
        """
        for key, value in validation.items():
            if key not in filled:
                filled[key] = value
            if key not in final:
                final[key] = value
            if isinstance(value, dict):
                filled[key], final[key] = cls._update_from_parsed(
                    value, filled[key], final[key])
            # Update final config with parsed value if they're not equal (in
            # value and in type) but not if it's a generator because we had to
            # replace that to validate it correctly
            elif key == ARGS_FIELD:
                continue  # don't substitute if list of positional args
            elif isinstance(value, numpy.ndarray):  # check numpy first, just in case
                final[key] = value
            # NOTE(review): isinstance(type(value), type(final[key])) tests a
            # class against a class, so it only holds when type(final[key]) is
            # `type`, `object` or a metaclass. A type-equality check was
            # presumably intended; left as is because changing it would alter
            # which values are overwritten. Confirm before touching.
            elif (value != final[key] or not isinstance(type(value), type(final[key]))
                  ) and not isinstance(final[key], GeneratorType):
                final[key] = value
        return filled, final

    @classmethod
    def is_promise(cls, obj: Any) -> bool:
        """Check whether an object is a "promise", i.e. contains a reference
        to a registered function (via a key starting with `"@"`).

        Objects without a `keys` attribute are never promises. Non-string
        keys are ignored instead of raising.
        """
        if not hasattr(obj, "keys"):
            return False
        return any(isinstance(k, str) and k.startswith("@") for k in obj.keys())

    @classmethod
    def get_constructor(cls, obj: Dict[str, Any]) -> Callable:
        """Return the registered function referenced by a promise block.

        The text after "@" in the reference key is the registry name and the
        key's value is the function name.
        Raises ConfigValidationError unless exactly one key starts with "@".
        """
        id_keys = [k for k in obj.keys() if k.startswith("@")]
        if len(id_keys) != 1:
            err_msg = f"A block can only contain one function registry reference. Got: {id_keys}"
            raise ConfigValidationError(obj, [{"msg": err_msg}])
        key = id_keys[0]
        return cls.get(key[1:], obj[key])

    @classmethod
    def parse_args(cls, obj: Dict[str, Any]) -> Tuple[List[Any], Dict[str, Any]]:
        """Split a promise block into positional and keyword arguments.

        Keys starting with "@" are skipped, the value stored under ARGS_FIELD
        becomes the positional arguments, and every other entry becomes a
        keyword argument.
        """
        args: List[Any] = []
        kwargs: Dict[str, Any] = {}
        for key, value in obj.items():
            if key.startswith("@"):
                continue
            if key == ARGS_FIELD:
                args = value
            else:
                kwargs[key] = value
        return args, kwargs

    @classmethod
    def make_promise_schema(cls, obj: Dict[str, Any]) -> Type[BaseModel]:
        """Create a schema for a promise dict (referencing a registry function)
        by inspecting the function signature.
        """
        func = cls.get_constructor(obj)
        # Read the argument annotations and defaults from the function signature
        id_keys = [k for k in obj.keys() if k.startswith("@")]
        sig_args: Dict[str, Any] = {id_keys[0]: (str, ...)}
        for param in inspect.signature(func).parameters.values():
            # `param.empty` is a sentinel, so compare by identity: `!=` would
            # call the annotation's / default's own __ne__, which can return a
            # non-boolean (e.g. for array defaults) or raise.
            # If no annotation is specified assume it's anything
            annotation = param.annotation if param.annotation is not param.empty else Any
            # If no default value is specified assume that it's required
            default = param.default if param.default is not param.empty else ...
            # Handle spread arguments and use their annotation as Sequence[whatever]
            if param.kind == param.VAR_POSITIONAL:
                spread_annot = Sequence[annotation]  # type: ignore
                sig_args[ARGS_FIELD_ALIAS] = (spread_annot, default)
            else:
                sig_args[param.name] = (annotation, default)
        sig_args["__config__"] = _PromiseSchemaConfig
        return create_model("ArgModel", **sig_args)