Exemplo n.º 1
0
        training_loop='sLCWA',
        negative_sampler='bernoulli',
    )
"""  # noqa

from typing import Set, Type

from class_resolver import Resolver, get_subclasses

from .basic_negative_sampler import BasicNegativeSampler
from .bernoulli_negative_sampler import BernoulliNegativeSampler
from .negative_sampler import NegativeSampler

__all__ = [
    'NegativeSampler',
    'BasicNegativeSampler',
    'BernoulliNegativeSampler',
    # Utils
    'negative_sampler_resolver',
]

# Class-name suffix shared by all negative samplers; handed to the resolver
# below as its ``suffix`` argument.
_NEGATIVE_SAMPLER_SUFFIX = 'NegativeSampler'

# Subclasses of NegativeSampler that exist when this module executes,
# collected into a set so each class is registered once.
_NEGATIVE_SAMPLERS: Set[Type[NegativeSampler]] = set(
    get_subclasses(NegativeSampler))  # type: ignore

# Name-based lookup for negative samplers; BasicNegativeSampler is configured
# as the resolver's default.
negative_sampler_resolver = Resolver(
    _NEGATIVE_SAMPLERS,
    suffix=_NEGATIVE_SAMPLER_SUFFIX,
    default=BasicNegativeSampler,
    base=NegativeSampler,  # type: ignore
)
Exemplo n.º 2
0
...     evaluation_kwargs=dict(batch_size=128),
...     stopper='early',
...     stopper_kwargs=dict(frequency=5, patience=2, relative_delta=0.002),
... )
"""

from typing import Collection, Type

from class_resolver import Resolver, get_subclasses

from .early_stopping import EarlyStopper, StopperCallback  # noqa: F401
from .stopper import NopStopper, Stopper

__all__ = [
    'Stopper',
    'NopStopper',
    'EarlyStopper',
    # Utils
    'stopper_resolver',
]

# Class-name suffix shared by all stoppers; handed to the resolver below as
# its ``suffix`` argument.
_STOPPER_SUFFIX = 'Stopper'

# Subclasses of Stopper that exist when this module executes, collected into
# a set so each class is registered once.
_STOPPERS: Collection[Type[Stopper]] = set(
    get_subclasses(Stopper))  # type: ignore

# Name-based lookup for stoppers; NopStopper is configured as the resolver's
# default.
stopper_resolver = Resolver(
    _STOPPERS,
    base=Stopper,  # type: ignore
    suffix=_STOPPER_SUFFIX,
    default=NopStopper,
)
Exemplo n.º 3
0
 def test_testing(self):
     """Ensure every subclass of ``base_cls`` not listed in ``skip_cls`` has a test class.

     A class counts as covered when some subclass of ``base_test`` exposes it
     through a ``cls`` attribute. On failure, the assertion message is the set
     of uncovered classes.
     """
     expected = set(get_subclasses(self.base_cls)).difference(self.skip_cls)
     covered = {
         test_case.cls
         for test_case in get_subclasses(self.base_test)
         if hasattr(test_case, "cls")
     }
     missing = expected - covered
     assert not missing, missing
Exemplo n.º 4
0
    "TransE",
    "TransF",
    "TransH",
    "TransR",
    "TuckER",
    "UM",
    # Evaluation-only models
    "MarginalDistributionBaseline",
    # Utils
    "model_resolver",
    "make_model",
    "make_model_cls",
]

model_resolver = Resolver.from_subclasses(
    base=Model,
    # Classes that must not be resolvable by name.
    skip={
        # Old-style models should never be looked up.
        _OldAbstractModel,
        EntityRelationEmbeddingModel,
        # New-style abstract base.
        _NewAbstractModel,
        # Excluded for now; this restriction might be relaxed later.
        ERModel,
        LiteralModel,
        # Evaluation-only baselines behave differently from the other models,
        # so the base class and every subclass of it are excluded.
        EvaluationOnlyModel,
        *get_subclasses(EvaluationOnlyModel),
    },
)
Exemplo n.º 5
0
from .base import ConsoleResultTracker, ResultTracker
from .file import CSVResultTracker, FileResultTracker, JSONResultTracker
from .mlflow import MLFlowResultTracker
from .neptune import NeptuneResultTracker
from .wandb import WANDBResultTracker

__all__ = [
    # Base classes
    'ResultTracker',
    'FileResultTracker',
    # Concrete classes
    'MLFlowResultTracker',
    'NeptuneResultTracker',
    'WANDBResultTracker',
    'JSONResultTracker',
    'CSVResultTracker',
    'ConsoleResultTracker',
    # Utilities
    'tracker_resolver',
]

# Class-name suffix shared by all result trackers; handed to the resolver below
# as its ``suffix`` argument.
_RESULT_TRACKER_SUFFIX = 'ResultTracker'

# Candidate trackers: every subclass of ResultTracker except FileResultTracker,
# which ``__all__`` lists as a base class rather than a concrete tracker.
# ``get_subclasses`` walks ``__subclasses__`` recursively, so a class reachable
# through more than one parent (multiple inheritance) can be yielded repeatedly.
# ``dict.fromkeys`` drops such repeats while keeping discovery order and the
# list type, so each class is passed to the resolver exactly once.
_TRACKERS = [
    tracker
    for tracker in dict.fromkeys(get_subclasses(ResultTracker))
    if tracker not in {FileResultTracker}
]

# Name-based lookup for result trackers; the ResultTracker base class itself is
# configured as the resolver's default.
tracker_resolver = Resolver(
    _TRACKERS,
    base=ResultTracker,
    default=ResultTracker,
    suffix=_RESULT_TRACKER_SUFFIX,
)
Exemplo n.º 6
0
    'HolE',
    'KG2E',
    'MuRE',
    'NTN',
    'PairRE',
    'ProjE',
    'QuatE',
    'RESCAL',
    'RGCN',
    'RotatE',
    'SimplE',
    'StructuredEmbedding',
    'TransD',
    'TransE',
    'TransH',
    'TransR',
    'TuckER',
    'UnstructuredModel',
    # Utils
    'model_resolver',
    'make_model',
    'make_model_cls',
]

# Every subclass of Model whose ``_is_base_model`` class attribute is falsy;
# classes that flag themselves as base models are left out.
_MODELS: Set[Type[Model]] = {
    model_cls
    for model_cls in get_subclasses(Model)  # type: ignore
    if not model_cls._is_base_model
}

# Name-based lookup over the models collected above.
model_resolver = Resolver(base=Model, classes=_MODELS)  # type: ignore