Example #1
    def test_lookup_no_synonyms(self):
        """Test looking up classes without auto-synonyms."""
        resolver = Resolver([A], base=Base, synonym_attribute=None)
        self.assertEqual(A, resolver.lookup("a"))
        self.assertEqual(A, resolver.lookup("A"))
        with self.assertRaises(KeyError):
            resolver.lookup("a_synonym_1")
Example #2
    def test_no_arguments(self):
        """Check that the unexpected keyword error is raised properly."""
        resolver = Resolver.from_subclasses(AltBase)
        with self.assertRaises(UnexpectedKeywordError) as e:
            resolver.make("A", nope="nopppeeee")
        self.assertEqual("AAltBase did not expect any keyword arguments",
                         str(e.exception))
Example #3
    def test_make_safe(self):
        """Test the make_safe function, which always returns none on none input."""
        self.assertIsNone(self.resolver.make_safe(None))
        self.assertIsNone(
            Resolver.from_subclasses(Base, default=A).make_safe(None))

        name = "charlie"
        # Test instantiating with positional dict into kwargs
        self.assertEqual(A(name=name),
                         self.resolver.make_safe("a", {"name": name}))
        # Test instantiating with kwargs
        self.assertEqual(A(name=name), self.resolver.make_safe("a", name=name))
Example #4
    def test_click_option_default(self):
        """Test generating an option with a default."""
        resolver = Resolver([A, B, C, E], base=Base, default=A)

        @click.command()
        @resolver.get_option("--opt", as_string=True)
        def cli(opt):
            """Run the test CLI."""
            self.assertIsInstance(opt, str)
            click.echo(self.resolver.lookup(opt).__name__, nl=False)

        self._test_cli(cli)
Example #5
    def test_registration_synonym_failure(self):
        """Test failure of registration."""
        resolver = Resolver([], base=Base)
        resolver.register(A, synonyms={"B"})
        with self.assertRaises(RegistrationSynonymConflict) as e:
            resolver.register(B)
        self.assertEqual("name", e.exception.label)
        self.assertIn("name", str(e.exception))

        class F(Base):
            """Extra class for testing."""

        with self.assertRaises(RegistrationSynonymConflict) as e:
            resolver.register(F, synonyms={"B"})
        self.assertEqual("synonym", e.exception.label)
        self.assertIn("synonym", str(e.exception))
Example #6
    def test_base_suffix(self):
        """Check that the unexpected keyword error is thrown properly."""
        resolver = Resolver.from_subclasses(AltBase,
                                            suffix=None,
                                            base_as_suffix=True)
        self.assertEqual(AAltBase, resolver.lookup("AAltBase"))
        self.assertEqual(AAltBase, resolver.lookup("A"))

        resolver = Resolver.from_subclasses(AltBase,
                                            suffix="nope",
                                            base_as_suffix=True)
        self.assertEqual(AAltBase, resolver.lookup("AAltBase"))
        with self.assertRaises(KeyError):
            resolver.lookup("A")

        resolver = Resolver.from_subclasses(AltBase, suffix="")
        self.assertEqual(AAltBase, resolver.lookup("AAltBase"))
        with self.assertRaises(KeyError):
            resolver.lookup("A")

        resolver = Resolver.from_subclasses(AltBase, base_as_suffix=False)
        self.assertEqual(AAltBase, resolver.lookup("AAltBase"))
        with self.assertRaises(KeyError):
            resolver.lookup("A")
Example #7
            x = x * 0x846ca68b
            x = x ^ (x >> 16)
            yield x % self.bit_array.shape[0]

    def add(self, triples: MappedTriples) -> None:
        """Add triples to the Bloom filter."""
        for i in self.probe(batch=triples):
            self.bit_array[i] = True

    def contains(self, batch: MappedTriples) -> torch.BoolTensor:
        """
        Check whether each triple in the batch is contained.

        :param batch: shape (batch_size, 3)
            The batch of triples.

        :return: shape: (batch_size,)
            The result. False guarantees that the element was not contained in the indexed triples; True may be
            a false positive.
        """
        result = batch.new_ones(batch.shape[:-1], dtype=torch.bool)
        for i in self.probe(batch):
            result &= self.bit_array[i]
        return result


filterer_resolver = Resolver.from_subclasses(
    base=Filterer,
    default=BloomFilterer,
)
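
# A minimal usage sketch, assuming class_resolver's standard lookup semantics:
# ``None`` falls back to the configured default, and full class names are
# matched case-insensitively.
assert filterer_resolver.lookup(None) is BloomFilterer
assert filterer_resolver.lookup("BloomFilterer") is BloomFilterer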
Example #8
    "CKG",
    "CSKG",
    "DBpedia50",
    "DB100K",
    "Countries",
    "WD50KT",
    "Wikidata5M",
    # Utilities
    "dataset_resolver",
    "get_dataset",
    "has_dataset",
]

logger = logging.getLogger(__name__)

dataset_resolver = Resolver.from_entrypoint(group="pykeen.datasets",
                                            base=Dataset)
if not dataset_resolver.lookup_dict:
    raise RuntimeError(
        dedent("""\
    Datasets have been loaded with entrypoints since PyKEEN v1.0.5, which is now a
    very old version of PyKEEN.

    If you simply use `python3 -m pip install --upgrade pykeen`, the entrypoints will
    not be reloaded. Instead, please reinstall PyKEEN using the following commands:

    $ python3 -m pip uninstall pykeen
    $ python3 -m pip install pykeen

    If you are on Kaggle or Google Colab, please follow these instructions:
    https://pykeen.readthedocs.io/en/stable/installation.html#google-colab-and-kaggle-users
Example #9
        for r in self.regularizers:
            if isinstance(r, NoRegularizer):
                raise TypeError('Cannot combine a no-op regularizer')
        self.register_buffer(
            name='normalization_factor',
            tensor=torch.as_tensor(sum(
                r.weight for r in self.regularizers), ).reciprocal())

    @property
    def normalize(self):  # noqa: D102
        return any(r.normalize for r in self.regularizers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:  # noqa: D102
        return self.normalization_factor * sum(r.weight * r.forward(x)
                                               for r in self.regularizers)


_REGULARIZERS: Collection[Type[Regularizer]] = {
    NoRegularizer,  # type: ignore
    LpRegularizer,
    PowerSumRegularizer,
    CombinedRegularizer,
    TransHRegularizer,
}
regularizer_resolver = Resolver(
    _REGULARIZERS,
    base=Regularizer,  # type: ignore
    default=NoRegularizer,
    suffix=_REGULARIZER_SUFFIX,
)
Example #10
from torch.optim.adamax import Adamax
from torch.optim.adamw import AdamW
from torch.optim.optimizer import Optimizer
from torch.optim.sgd import SGD

__all__ = [
    'Optimizer',
    'optimizers_hpo_defaults',
    'optimizer_resolver',
]

_OPTIMIZER_LIST: Set[Type[Optimizer]] = {
    Adadelta,
    Adagrad,
    Adam,
    Adamax,
    AdamW,
    SGD,
}

#: The default strategy for optimizing the optimizers' hyper-parameters (yo dawg)
optimizers_hpo_defaults: Mapping[Type[Optimizer], Mapping[str, Any]] = {
    Adagrad: dict(lr=dict(type=float, low=0.001, high=0.1, scale='log'), ),
    Adam: dict(lr=dict(type=float, low=0.001, high=0.1, scale='log'), ),
    Adamax: dict(lr=dict(type=float, low=0.001, high=0.1, scale='log'), ),
    AdamW: dict(lr=dict(type=float, low=0.001, high=0.1, scale='log'), ),
    SGD: dict(lr=dict(type=float, low=0.001, high=0.1, scale='log'), ),
}

optimizer_resolver = Resolver(_OPTIMIZER_LIST, base=Optimizer, default=Adam)
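
# A brief usage sketch (the linear module below is a stand-in): ``make``
# resolves the class by a case-insensitive name and forwards keyword arguments
# to its constructor, so this is equivalent to ``Adam(params=..., lr=0.01)``.
import torch

_model = torch.nn.Linear(3, 1)  # stand-in module providing parameters
_optimizer = optimizer_resolver.make("adam", params=_model.parameters(), lr=0.01)
assert isinstance(_optimizer, Adam)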
Example #11
random  :class:`optuna.samplers.RandomSampler`
tpe     :class:`optuna.samplers.TPESampler`
======  ======================================

.. note:: This table can be re-generated with ``pykeen ls hpo-samplers -f rst``
"""

# TODO update docs with table and CLI with generator

from typing import Set, Type

from class_resolver import Resolver
from optuna.samplers import BaseSampler, GridSampler, RandomSampler, TPESampler

__all__ = [
    'sampler_resolver',
]

_SAMPLER_SUFFIX = 'Sampler'
_SAMPLERS: Set[Type[BaseSampler]] = {
    RandomSampler,
    TPESampler,
    GridSampler,
}
sampler_resolver = Resolver(
    _SAMPLERS,
    base=BaseSampler,
    default=TPESampler,
    suffix=_SAMPLER_SUFFIX,
)
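
# A small sketch of the suffix handling, consistent with the table in the
# module docstring above: with ``suffix='Sampler'``, short names resolve to
# the corresponding classes.
assert sampler_resolver.lookup("tpe") is TPESampler
assert sampler_resolver.lookup("random") is RandomSampler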
Example #12
    'RankBasedMetricResults',
    'SklearnEvaluator',
    'SklearnMetricResults',
    'evaluator_resolver',
    'metric_resolver',
    'get_metric_list',
]

_EVALUATOR_SUFFIX = 'Evaluator'
_EVALUATORS: Set[Type[Evaluator]] = {
    RankBasedEvaluator,
    SklearnEvaluator,
}
evaluator_resolver = Resolver(
    _EVALUATORS,
    base=Evaluator,  # type: ignore
    suffix=_EVALUATOR_SUFFIX,
    default=RankBasedEvaluator,
)

_METRICS_SUFFIX = 'MetricResults'
_METRICS: Set[Type[MetricResults]] = {
    RankBasedMetricResults,
    SklearnMetricResults,
}
metric_resolver = Resolver(
    _METRICS,
    suffix=_METRICS_SUFFIX,
    base=MetricResults,
)

Example #13
        # scale = scipy.stats.norm.ppf(1 - 1/d * 1/math.e) - mean
        # return scipy.stats.gumbel_r.mean(loc=mean, scale=scale)
    elif math.isfinite(p):
        exp_abs_norm_p = math.pow(2, p / 2) * math.gamma(
            (p + 1) / 2) / math.sqrt(math.pi)
        return math.pow(exp_abs_norm_p * d, 1 / p)
    else:
        raise TypeError(f"norm not implemented for {type(p)}: {p}")


activation_resolver = Resolver(
    classes=(
        nn.LeakyReLU,
        nn.PReLU,
        nn.ReLU,
        nn.Softplus,
        nn.Sigmoid,
        nn.Tanh,
    ),
    base=nn.Module,  # type: ignore
    default=nn.ReLU,
)
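
# A brief sketch of the resolver's behavior: lookups are case-insensitive on
# the class name, and ``make`` returns a fresh instance of the resolved class.
assert activation_resolver.lookup("leakyrelu") is nn.LeakyReLU
_activation = activation_resolver.make("relu")  # equivalent to nn.ReLU()
assert isinstance(_activation, nn.ReLU)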


class Bias(nn.Module):
    """A module wrapper for adding a bias."""
    def __init__(self, dim: int):
        """Initialize the module.

        :param dim: >0
            The dimension of the input.
        """
Example #14
from .base import ConsoleResultTracker, ResultTracker
from .file import CSVResultTracker, FileResultTracker, JSONResultTracker
from .mlflow import MLFlowResultTracker
from .neptune import NeptuneResultTracker
from .wandb import WANDBResultTracker

__all__ = [
    # Base classes
    'ResultTracker',
    'FileResultTracker',
    # Concrete classes
    'MLFlowResultTracker',
    'NeptuneResultTracker',
    'WANDBResultTracker',
    'JSONResultTracker',
    'CSVResultTracker',
    'ConsoleResultTracker',
    # Utilities
    'tracker_resolver',
]

_RESULT_TRACKER_SUFFIX = 'ResultTracker'
_TRACKERS = [
    tracker for tracker in get_subclasses(ResultTracker)
    if tracker not in {FileResultTracker}
]
tracker_resolver = Resolver(_TRACKERS,
                            base=ResultTracker,
                            default=ResultTracker,
                            suffix=_RESULT_TRACKER_SUFFIX)
Example #15
    'HolE',
    'KG2E',
    'MuRE',
    'NTN',
    'PairRE',
    'ProjE',
    'QuatE',
    'RESCAL',
    'RGCN',
    'RotatE',
    'SimplE',
    'StructuredEmbedding',
    'TransD',
    'TransE',
    'TransH',
    'TransR',
    'TuckER',
    'UnstructuredModel',
    # Utils
    'model_resolver',
    'make_model',
    'make_model_cls',
]

_MODELS: Set[Type[Model]] = {
    subcls
    for subcls in get_subclasses(Model)  # type: ignore
    if not subcls._is_base_model
}
model_resolver = Resolver(classes=_MODELS, base=Model)  # type: ignore
Example #16
    "TransE",
    "TransF",
    "TransH",
    "TransR",
    "TuckER",
    "UM",
    # Evaluation-only models
    "MarginalDistributionBaseline",
    # Utils
    "model_resolver",
    "make_model",
    "make_model_cls",
]

model_resolver = Resolver.from_subclasses(
    base=Model,
    skip={
        # Abstract Models
        _NewAbstractModel,
        # We might be able to relax this later
        ERModel,
        LiteralModel,
        # baseline models behave differently
        EvaluationOnlyModel,
        *get_subclasses(EvaluationOnlyModel),
        # Old style models should never be looked up
        _OldAbstractModel,
        EntityRelationEmbeddingModel,
    },
)
Example #17
                                            dtype=torch.get_default_dtype())

        # scale. We model this as log(scale) to ensure scale > 0, and thus monotonicity
        self.log_scale = nn.Parameter(torch.empty(size=tuple()),
                                      requires_grad=trainable_scale)
        self.initial_log_scale = torch.as_tensor(
            data=[math.log(initial_scale)], dtype=torch.get_default_dtype())

    def reset_parameters(self):  # noqa: D102
        self.bias.data = self.initial_bias.to(device=self.bias.device)
        self.log_scale.data = self.initial_log_scale.to(
            device=self.bias.device)

    def forward(
        self,
        h: HeadRepresentation,
        r: RelationRepresentation,
        t: TailRepresentation,
    ) -> torch.FloatTensor:  # noqa: D102
        return self.log_scale.exp() * self.base(h=h, r=r, t=t) + self.bias


interaction_resolver = Resolver.from_subclasses(
    Interaction,  # type: ignore
    skip={
        TranslationalInteraction, FunctionalInteraction,
        MonotonicAffineTransformationInteraction
    },
    suffix=Interaction.__name__,
)
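
# A sketch of the suffix handling above: since ``suffix=Interaction.__name__``,
# the "Interaction" suffix is stripped during normalization, so a short model
# name resolves to its interaction class (assuming ``TransEInteraction``, one of
# the subclasses defined in this module, is registered).
assert interaction_resolver.lookup("TransE") is TransEInteraction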
Example #18
__all__ = [
    "evaluate",
    "Evaluator",
    "MetricResults",
    "RankBasedEvaluator",
    "RankBasedMetricResults",
    "ClassificationEvaluator",
    "ClassificationMetricResults",
    "evaluator_resolver",
    "metric_resolver",
    "get_metric_list",
]

evaluator_resolver = Resolver.from_subclasses(
    base=Evaluator,  # type: ignore
    default=RankBasedEvaluator,
)

_METRICS_SUFFIX = "MetricResults"
_METRICS: Set[Type[MetricResults]] = {
    RankBasedMetricResults,
    ClassificationMetricResults,
}
metric_resolver = Resolver(
    _METRICS,
    suffix=_METRICS_SUFFIX,
    base=MetricResults,
)


def get_metric_list():
Example #19
    """Calculate inverse relative frequency weighting."""
    # Calculate in-degree, i.e. number of incoming edges
    inv, cnt = torch.unique(idx, return_counts=True, return_inverse=True)[1:]
    return cnt[inv].float().reciprocal()


class InverseInDegreeEdgeWeighting(EdgeWeighting):
    """Normalize messages by inverse in-degree."""
    def forward(self, source: torch.LongTensor,
                target: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        return _inverse_frequency_weighting(idx=target)


class InverseOutDegreeEdgeWeighting(EdgeWeighting):
    """Normalize messages by inverse out-degree."""
    def forward(self, source: torch.LongTensor,
                target: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        return _inverse_frequency_weighting(idx=source)


class SymmetricEdgeWeighting(EdgeWeighting):
    """Normalize messages by product of inverse sqrt of in-degree and out-degree."""
    def forward(self, source: torch.LongTensor,
                target: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        return (_inverse_frequency_weighting(idx=source) *
                _inverse_frequency_weighting(idx=target)).sqrt()


edge_weight_resolver = Resolver.from_subclasses(base=EdgeWeighting,
                                                default=SymmetricEdgeWeighting)
Example #20
    def forward(self, a: torch.FloatTensor,
                b: torch.FloatTensor) -> torch.FloatTensor:  # noqa: D102
        return self.__class__.func(a, b)


class SubtractionCompositionModule(FunctionalCompositionModule):
    """Composition by element-wise subtraction."""

    func = torch.sub


class MultiplicationCompositionModule(FunctionalCompositionModule):
    """Composition by element-wise multiplication."""

    func = torch.mul


class CircularCorrelationCompositionModule(FunctionalCompositionModule):
    """Composition by circular correlation via :func:`pykeen.nn.functional.circular_correlation`."""

    func = circular_correlation


composition_resolver = Resolver.from_subclasses(
    CompositionModule,
    default=MultiplicationCompositionModule,
    skip={
        FunctionalCompositionModule,
    },
)
Example #21
.. note:: This table can be re-generated with ``pykeen ls trainers -f rst``
"""

from typing import Set, Type

from class_resolver import Resolver

from .lcwa import LCWATrainingLoop  # noqa: F401
from .slcwa import SLCWATrainingLoop  # noqa: F401
from .training_loop import NonFiniteLossError, TrainingLoop  # noqa: F401

__all__ = [
    'TrainingLoop',
    'SLCWATrainingLoop',
    'LCWATrainingLoop',
    'NonFiniteLossError',
    'training_loop_resolver',
]

_TRAINING_LOOP_SUFFIX = 'TrainingLoop'
_TRAINING_LOOPS: Set[Type[TrainingLoop]] = {
    LCWATrainingLoop,
    SLCWATrainingLoop,
}
training_loop_resolver = Resolver(
    _TRAINING_LOOPS,
    base=TrainingLoop,  # type: ignore
    default=SLCWATrainingLoop,
    suffix=_TRAINING_LOOP_SUFFIX,
)
Example #22
    "NeptuneResultTracker",
    "WANDBResultTracker",
    "JSONResultTracker",
    "CSVResultTracker",
    "PythonResultTracker",
    "TensorBoardResultTracker",
    "ConsoleResultTracker",
    # Utilities
    "tracker_resolver",
    "TrackerHint",
    "resolve_result_trackers",
]

tracker_resolver = Resolver.from_subclasses(
    base=ResultTracker,
    default=ResultTracker,
    skip={FileResultTracker, MultiResultTracker},
)


def resolve_result_trackers(
    result_tracker: Optional[OneOrSequence[HintType[ResultTracker]]] = None,
    result_tracker_kwargs: Optional[OneOrSequence[Optional[Mapping[
        str, Any]]]] = None,
) -> MultiResultTracker:
    """Resolve and compose result trackers.

    :param result_tracker: Either none (will result in a Python result tracker),
        a single tracker (as either a class, instance, or string for class name), or a list
        of trackers (as either a class, instance, or string for class name)
    :param result_tracker_kwargs: Either none (will use all defaults), a single dictionary
Example #23
...     evaluation_kwargs=dict(batch_size=128),
...     stopper='early',
...     stopper_kwargs=dict(frequency=5, patience=2, relative_delta=0.002),
... )
"""

from typing import Collection, Type

from class_resolver import Resolver, get_subclasses

from .early_stopping import EarlyStopper, StopperCallback  # noqa: F401
from .stopper import NopStopper, Stopper

__all__ = [
    'Stopper',
    'NopStopper',
    'EarlyStopper',
    # Utils
    'stopper_resolver',
]

_STOPPER_SUFFIX = 'Stopper'
_STOPPERS: Collection[Type[Stopper]] = set(
    get_subclasses(Stopper))  # type: ignore
stopper_resolver = Resolver(
    _STOPPERS,
    default=NopStopper,
    suffix=_STOPPER_SUFFIX,
    base=Stopper,  # type: ignore
)
Example #24
_LOSSES: Set[Type[Loss]] = {
    MarginRankingLoss,
    BCEWithLogitsLoss,
    SoftplusLoss,
    BCEAfterSigmoidLoss,
    CrossEntropyLoss,
    MSELoss,
    NSSALoss,
}
losses_synonyms: Mapping[str, Type[Loss]] = {
    normalize_string(synonym, suffix=_LOSS_SUFFIX): cls
    for cls in _LOSSES if cls.synonyms is not None for synonym in cls.synonyms
}
loss_resolver = Resolver(
    _LOSSES,
    base=Loss,
    default=MarginRankingLoss,
    suffix=_LOSS_SUFFIX,
    synonyms=losses_synonyms,
)
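
# A sketch of lookup with suffix and synonyms: class names resolve with or
# without the "Loss" suffix, case-insensitively, and any entry in a class's
# ``synonyms`` attribute resolves to that class as well.
assert loss_resolver.lookup("MarginRanking") is MarginRankingLoss
assert loss_resolver.lookup("marginrankingloss") is MarginRankingLoss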


def has_mr_loss(model) -> bool:
    """Check if the model has a marging ranking loss."""
    return isinstance(model.loss, MarginRankingLoss)


def has_nssa_loss(model) -> bool:
    """Check if the model has a NSSA loss."""
    return isinstance(model.loss, NSSALoss)
Example #25
There are two other major considerations when randomly sampling negative triples: the random sampling
strategy and the filtering of positive triples. A full guide on negative sampling with the SLCWA can be
found in :mod:`pykeen.sampling`. The following chart from [ali2020a]_ demonstrates the different potential
triples considered in LCWA vs. sLCWA based on the given true triples (in red):

.. image:: ../img/training_approaches.png
  :alt: Troubleshooting Image 2
"""  # noqa:E501

from class_resolver import Resolver

from .callbacks import TrainingCallback  # noqa: F401
from .lcwa import LCWATrainingLoop  # noqa: F401
from .slcwa import SLCWATrainingLoop  # noqa: F401
from .training_loop import NonFiniteLossError, TrainingLoop  # noqa: F401

__all__ = [
    "TrainingLoop",
    "SLCWATrainingLoop",
    "LCWATrainingLoop",
    "NonFiniteLossError",
    "training_loop_resolver",
    "TrainingCallback",
]

training_loop_resolver = Resolver.from_subclasses(
    base=TrainingLoop,  # type: ignore
    default=SLCWATrainingLoop,
)
Example #26
    dict(gamma=dict(type=float, low=0.8, high=1.0, step=0.025), ),
    LambdaLR:
    dict(lr_lambda=dict(
        type="categorical",
        choices=[lambda epoch: epoch // 30, lambda epoch: 0.95**epoch]), ),
    MultiplicativeLR:
    dict(lr_lambda=dict(
        type="categorical",
        choices=[lambda epoch: 0.85, lambda epoch: 0.9,
                 lambda epoch: 0.95]), ),
    MultiStepLR:
    dict(
        gamma=dict(type=float, low=0.1, high=0.9, step=0.1),
        milestones=dict(type="categorical", choices=[75, 130, 190, 240, 370]),
    ),
    OneCycleLR:
    dict(max_lr=dict(type=float, low=0.1, high=0.3, scale="log"), ),
    StepLR:
    dict(
        gamma=dict(type=float, low=0.1, high=0.9, step=0.1),
        step_size=dict(type=int, low=1, high=50, step=5),
    ),
}

#: A resolver for learning rate schedulers
lr_scheduler_resolver = Resolver(
    base=LRScheduler,
    default=ExponentialLR,
    classes=set(lr_schedulers_hpo_defaults),
)
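
# A brief usage sketch (the parameter below is a stand-in): LR schedulers wrap
# an optimizer, so one is passed through ``make`` alongside the scheduler's own
# keyword arguments.
import torch

_param = torch.nn.Parameter(torch.zeros(1))  # stand-in parameter
_optimizer = torch.optim.Adam([_param], lr=0.01)
_scheduler = lr_scheduler_resolver.make("ExponentialLR", optimizer=_optimizer, gamma=0.99)
assert isinstance(_scheduler, ExponentialLR)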
Example #27
        # other relations
        for r in range(self.num_relations):
            source_r, target_r, weights_r = _reduce_relation_specific(
                relation=r,
                source=source,
                target=target,
                edge_type=edge_type,
                edge_weights=edge_weights,
            )

            # skip relations without edges
            if source_r is None:
                continue

            # compute message, shape: (num_edges_of_type, num_blocks, block_size)
            uniq_source_r, inv_source_r = source_r.unique(return_inverse=True)
            w_r = self.blocks[r]
            m = torch.einsum('nbi,bij->nbj', x[uniq_source_r], w_r).index_select(dim=0, index=inv_source_r)

            # optional message weighting
            if weights_r is not None:
                m = m * weights_r.unsqueeze(dim=1).unsqueeze(dim=2)

            # message aggregation
            out.index_add_(dim=0, index=target_r, source=m)

        return out.reshape(-1, self.output_dim)


decomposition_resolver = Resolver.from_subclasses(base=Decomposition, default=BasesDecomposition)
Example #28
    def __init__(
        self,
        regularizers: Iterable[Regularizer],
        total_weight: float = 1.0,
        apply_only_once: bool = False,
    ):
        super().__init__(weight=total_weight, apply_only_once=apply_only_once)
        self.regularizers = nn.ModuleList(regularizers)
        for r in self.regularizers:
            if isinstance(r, NoRegularizer):
                raise TypeError('Cannot combine a no-op regularizer')
        self.register_buffer(
            name='normalization_factor',
            tensor=torch.as_tensor(sum(
                r.weight for r in self.regularizers), ).reciprocal())

    @property
    def normalize(self):  # noqa: D102
        return any(r.normalize for r in self.regularizers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:  # noqa: D102
        return self.normalization_factor * sum(r.weight * r.forward(x)
                                               for r in self.regularizers)


regularizer_resolver = Resolver.from_subclasses(
    base=Regularizer,
    default=NoRegularizer,
)
Example #29
        training_loop='sLCWA',
        negative_sampler='bernoulli',
    )
"""  # noqa

from typing import Set, Type

from class_resolver import Resolver, get_subclasses

from .basic_negative_sampler import BasicNegativeSampler
from .bernoulli_negative_sampler import BernoulliNegativeSampler
from .negative_sampler import NegativeSampler

__all__ = [
    'NegativeSampler',
    'BasicNegativeSampler',
    'BernoulliNegativeSampler',
    # Utils
    'negative_sampler_resolver',
]

_NEGATIVE_SAMPLER_SUFFIX = 'NegativeSampler'
_NEGATIVE_SAMPLERS: Set[Type[NegativeSampler]] = set(
    get_subclasses(NegativeSampler))  # type: ignore
negative_sampler_resolver = Resolver(
    _NEGATIVE_SAMPLERS,
    base=NegativeSampler,  # type: ignore
    default=BasicNegativeSampler,
    suffix=_NEGATIVE_SAMPLER_SUFFIX,
)
Example #30
# -*- coding: utf-8 -*-
"""A wrapper for looking up pruners from :mod:`optuna`."""

from typing import Set, Type

from class_resolver import Resolver
from optuna.pruners import BasePruner, MedianPruner, NopPruner, PercentilePruner, SuccessiveHalvingPruner

__all__ = [
    "pruner_resolver",
]

_PRUNER_SUFFIX = "Pruner"
_PRUNERS: Set[Type[BasePruner]] = {
    MedianPruner,
    NopPruner,
    PercentilePruner,
    SuccessiveHalvingPruner,
}
pruner_resolver = Resolver(
    _PRUNERS,
    default=MedianPruner,
    suffix=_PRUNER_SUFFIX,
    base=BasePruner,
)
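
# A brief sketch of the suffix handling: with ``suffix="Pruner"``, short names
# resolve case-insensitively to the corresponding optuna pruner classes.
assert pruner_resolver.lookup("median") is MedianPruner
assert pruner_resolver.lookup("nop") is NopPruner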