training_loop='sLCWA', negative_sampler='bernoulli', ) """ # noqa from typing import Set, Type from class_resolver import Resolver, get_subclasses from .basic_negative_sampler import BasicNegativeSampler from .bernoulli_negative_sampler import BernoulliNegativeSampler from .negative_sampler import NegativeSampler __all__ = [ 'NegativeSampler', 'BasicNegativeSampler', 'BernoulliNegativeSampler', # Utils 'negative_sampler_resolver', ] _NEGATIVE_SAMPLER_SUFFIX = 'NegativeSampler' _NEGATIVE_SAMPLERS: Set[Type[NegativeSampler]] = set( get_subclasses(NegativeSampler)) # type: ignore negative_sampler_resolver = Resolver( _NEGATIVE_SAMPLERS, base=NegativeSampler, # type: ignore default=BasicNegativeSampler, suffix=_NEGATIVE_SAMPLER_SUFFIX, )
... evaluation_kwargs=dict(batch_size=128), ... stopper='early', ... stopper_kwargs=dict(frequency=5, patience=2, relative_delta=0.002), ... ) """ from typing import Collection, Type from class_resolver import Resolver, get_subclasses from .early_stopping import EarlyStopper, StopperCallback # noqa: F401 from .stopper import NopStopper, Stopper __all__ = [ 'Stopper', 'NopStopper', 'EarlyStopper', # Utils 'stopper_resolver', ] _STOPPER_SUFFIX = 'Stopper' _STOPPERS: Collection[Type[Stopper]] = set( get_subclasses(Stopper)) # type: ignore stopper_resolver = Resolver( _STOPPERS, default=NopStopper, suffix=_STOPPER_SUFFIX, base=Stopper, # type: ignore )
def test_testing(self):
    """Check that there is a test for all subclasses."""
    # Classes that already have a dedicated test case (test classes expose
    # the class under test via a `cls` attribute).
    covered = {
        test_cls.cls
        for test_cls in get_subclasses(self.base_test)
        if hasattr(test_cls, "cls")
    }
    # Everything that should be tested, minus the explicit skip list and
    # minus what is already covered.
    missing = set(get_subclasses(self.base_cls)) - set(self.skip_cls) - covered
    assert not missing, missing
"TransE", "TransF", "TransH", "TransR", "TuckER", "UM", # Evaluation-only models "MarginalDistributionBaseline", # Utils "model_resolver", "make_model", "make_model_cls", ] model_resolver = Resolver.from_subclasses( base=Model, skip={ # Abstract Models _NewAbstractModel, # We might be able to relax this later ERModel, LiteralModel, # baseline models behave differently EvaluationOnlyModel, *get_subclasses(EvaluationOnlyModel), # Old style models should never be looked up _OldAbstractModel, EntityRelationEmbeddingModel, }, )
from .base import ConsoleResultTracker, ResultTracker
from .file import CSVResultTracker, FileResultTracker, JSONResultTracker
from .mlflow import MLFlowResultTracker
from .neptune import NeptuneResultTracker
from .wandb import WANDBResultTracker

__all__ = [
    # Base classes
    'ResultTracker',
    'FileResultTracker',
    # Concrete classes
    'MLFlowResultTracker',
    'NeptuneResultTracker',
    'WANDBResultTracker',
    'JSONResultTracker',
    'CSVResultTracker',
    'ConsoleResultTracker',
    # Utilities
    'tracker_resolver',
]

# Class-name suffix stripped when resolving trackers by string,
# e.g. 'mlflow' -> MLFlowResultTracker.
_RESULT_TRACKER_SUFFIX = 'ResultTracker'

# Every imported ResultTracker subclass, excluding the file-backed base
# (which is only a common parent for the CSV/JSON trackers).
_TRACKERS = [
    tracker
    for tracker in get_subclasses(ResultTracker)
    if tracker is not FileResultTracker
]

# NOTE(review): the default is the ResultTracker base class itself, which is
# not an element of _TRACKERS — confirm this is intentional rather than, say,
# ConsoleResultTracker.
tracker_resolver = Resolver(
    _TRACKERS,
    base=ResultTracker,
    default=ResultTracker,
    suffix=_RESULT_TRACKER_SUFFIX,
)
'HolE', 'KG2E', 'MuRE', 'NTN', 'PairRE', 'ProjE', 'QuatE', 'RESCAL', 'RGCN', 'RotatE', 'SimplE', 'StructuredEmbedding', 'TransD', 'TransE', 'TransH', 'TransR', 'TuckER', 'UnstructuredModel', # Utils 'model_resolver', 'make_model', 'make_model_cls', ] _MODELS: Set[Type[Model]] = { subcls for subcls in get_subclasses(Model) # type: ignore if not subcls._is_base_model } model_resolver = Resolver(classes=_MODELS, base=Model) # type: ignore