Exemplo n.º 1
0
 def test_lookup_no_synonyms(self):
     """Test looking up classes without auto-synonym.

     With ``synonym_attribute=None`` the resolver matches only the
     normalized class name, so class-declared synonyms must not resolve.
     """
     resolver = Resolver([A], base=Base, synonym_attribute=None)
     # Lookup by normalized name is case-insensitive.
     self.assertEqual(A, resolver.lookup("a"))
     self.assertEqual(A, resolver.lookup("A"))
     # "a_synonym_1" is presumably declared on A via its synonym attribute;
     # since synonym collection is disabled, lookup must fail.
     with self.assertRaises(KeyError):
         self.assertEqual(A, resolver.lookup("a_synonym_1"))
Exemplo n.º 2
0
    def test_click_option_default(self):
        """Test generating an option with a default."""
        resolver = Resolver([A, B, C, E], base=Base, default=A)

        @click.command()
        @resolver.get_option("--opt", as_string=True)
        def cli(opt):
            """Run the test CLI."""
            # With ``as_string=True`` the option arrives as the string key,
            # not an already-resolved class.
            self.assertIsInstance(opt, str)
            # NOTE(review): this looks up via ``self.resolver`` (from setUp)
            # rather than the local ``resolver`` built above — confirm the
            # mismatch is intentional.
            click.echo(self.resolver.lookup(opt).__name__, nl=False)

        self._test_cli(cli)
Exemplo n.º 3
0
    def test_registration_synonym_failure(self):
        """Check that conflicting registrations raise the expected error."""
        resolver = Resolver([], base=Base)
        resolver.register(A, synonyms={"B"})

        # Registering B collides with the synonym "B" already claimed by A;
        # the conflict is reported on the "name" side.
        with self.assertRaises(RegistrationSynonymConflict) as ctx:
            resolver.register(B)
        self.assertEqual("name", ctx.exception.label)
        self.assertIn("name", str(ctx.exception))

        class F(Base):
            """Extra class for testing."""

        # A fresh class whose *synonym* collides with the existing synonym;
        # the conflict is reported on the "synonym" side.
        with self.assertRaises(RegistrationSynonymConflict) as ctx:
            resolver.register(F, synonyms={"B"})
        self.assertEqual("synonym", ctx.exception.label)
        self.assertIn("synonym", str(ctx.exception))
Exemplo n.º 4
0
.. note:: This table can be re-generated with ``pykeen ls trainers -f rst``
"""

from typing import Set, Type

from class_resolver import Resolver

from .lcwa import LCWATrainingLoop  # noqa: F401
from .slcwa import SLCWATrainingLoop  # noqa: F401
from .training_loop import NonFiniteLossError, TrainingLoop  # noqa: F401

__all__ = [
    'TrainingLoop',
    'SLCWATrainingLoop',
    'LCWATrainingLoop',
    'NonFiniteLossError',
    'training_loop_resolver',
]

#: Suffix stripped from class names when resolving by string key.
_TRAINING_LOOP_SUFFIX = 'TrainingLoop'
#: The concrete training loop implementations known to the resolver.
_TRAINING_LOOPS: Set[Type[TrainingLoop]] = {
    SLCWATrainingLoop,
    LCWATrainingLoop,
}
#: Resolve training loops by (suffix-stripped) name; sLCWA is the default.
training_loop_resolver = Resolver(
    _TRAINING_LOOPS,
    base=TrainingLoop,  # type: ignore
    suffix=_TRAINING_LOOP_SUFFIX,
    default=SLCWATrainingLoop,
)
Exemplo n.º 5
0
random  :class:`optuna.samplers.RandomSampler`
tpe     :class:`optuna.samplers.TPESampler`
======  ======================================

.. note:: This table can be re-generated with ``pykeen ls hpo-samplers -f rst``
"""

# TODO update docs with table and CLI with generator

from typing import Set, Type

from class_resolver import Resolver
from optuna.samplers import BaseSampler, GridSampler, RandomSampler, TPESampler

__all__ = [
    'sampler_resolver',
]

#: Suffix stripped from sampler class names for string lookup.
_SAMPLER_SUFFIX = 'Sampler'
#: Samplers from :mod:`optuna` made available through the resolver.
_SAMPLERS: Set[Type[BaseSampler]] = {
    GridSampler,
    RandomSampler,
    TPESampler,
}
#: Resolve samplers by (suffix-stripped) name; TPE is the default.
sampler_resolver = Resolver(
    _SAMPLERS,
    base=BaseSampler,
    suffix=_SAMPLER_SUFFIX,
    default=TPESampler,
)
Exemplo n.º 6
0
    'RankBasedMetricResults',
    'SklearnEvaluator',
    'SklearnMetricResults',
    'evaluator_resolver',
    'metric_resolver',
    'get_metric_list',
]

#: Suffix stripped from evaluator class names for string lookup.
_EVALUATOR_SUFFIX = 'Evaluator'
#: All concrete evaluator implementations.
_EVALUATORS: Set[Type[Evaluator]] = {
    SklearnEvaluator,
    RankBasedEvaluator,
}
#: Resolve evaluators by name; rank-based evaluation is the default.
evaluator_resolver = Resolver(
    _EVALUATORS,
    base=Evaluator,  # type: ignore
    default=RankBasedEvaluator,
    suffix=_EVALUATOR_SUFFIX,
)

#: Suffix stripped from metric-results class names for string lookup.
_METRICS_SUFFIX = 'MetricResults'
#: All metric-results containers.
_METRICS: Set[Type[MetricResults]] = {
    SklearnMetricResults,
    RankBasedMetricResults,
}
#: Resolve metric-results containers by name (no default).
metric_resolver = Resolver(
    _METRICS,
    base=MetricResults,
    suffix=_METRICS_SUFFIX,
)

Exemplo n.º 7
0
# -*- coding: utf-8 -*-
"""A wrapper for looking up pruners from :mod:`optuna`."""

from typing import Set, Type

from class_resolver import Resolver
from optuna.pruners import BasePruner, MedianPruner, NopPruner, PercentilePruner, SuccessiveHalvingPruner

__all__ = [
    "pruner_resolver",
]

#: Suffix stripped from pruner class names for string lookup.
_PRUNER_SUFFIX = "Pruner"
#: Pruners from :mod:`optuna` made available through the resolver.
_PRUNERS: Set[Type[BasePruner]] = {
    MedianPruner,
    NopPruner,
    PercentilePruner,
    SuccessiveHalvingPruner,
}
#: Resolve pruners by (suffix-stripped) name; the median pruner is the default.
pruner_resolver = Resolver(
    _PRUNERS,
    base=BasePruner,
    suffix=_PRUNER_SUFFIX,
    default=MedianPruner,
)
Exemplo n.º 8
0
    dict(gamma=dict(type=float, low=0.8, high=1.0, step=0.025), ),
    LambdaLR:
    dict(lr_lambda=dict(
        type="categorical",
        choices=[lambda epoch: epoch // 30, lambda epoch: 0.95**epoch]), ),
    MultiplicativeLR:
    dict(lr_lambda=dict(
        type="categorical",
        choices=[lambda epoch: 0.85, lambda epoch: 0.9,
                 lambda epoch: 0.95]), ),
    MultiStepLR:
    dict(
        gamma=dict(type=float, low=0.1, high=0.9, step=0.1),
        milestones=dict(type="categorical", choices=[75, 130, 190, 240, 370]),
    ),
    OneCycleLR:
    dict(max_lr=dict(type=float, low=0.1, high=0.3, scale="log"), ),
    StepLR:
    dict(
        gamma=dict(type=float, low=0.1, high=0.9, step=0.1),
        step_size=dict(type=int, low=1, high=50, step=5),
    ),
}

#: A resolver for learning rate schedulers
lr_scheduler_resolver = Resolver(
    # The keys of the HPO-defaults mapping double as the class registry.
    classes=set(lr_schedulers_hpo_defaults),
    base=LRScheduler,
    default=ExponentialLR,
)
Exemplo n.º 9
0
        for r in self.regularizers:
            if isinstance(r, NoRegularizer):
                raise TypeError('Can not combine a no-op regularizer')
        self.register_buffer(
            name='normalization_factor',
            tensor=torch.as_tensor(sum(
                r.weight for r in self.regularizers), ).reciprocal())

    @property
    def normalize(self):  # noqa: D102
        # Normalization is requested as soon as any member regularizer asks for it.
        for regularizer in self.regularizers:
            if regularizer.normalize:
                return True
        return False

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:  # noqa: D102
        # Weighted sum of the member regularizers' penalties, rescaled by the
        # precomputed normalization factor (the reciprocal of the summed weights).
        return self.normalization_factor * sum(r.weight * r.forward(x)
                                               for r in self.regularizers)


#: All regularizer implementations exposed through the resolver.
_REGULARIZERS: Collection[Type[Regularizer]] = {
    CombinedRegularizer,
    LpRegularizer,
    NoRegularizer,  # type: ignore
    PowerSumRegularizer,
    TransHRegularizer,
}
#: Resolve regularizers by (suffix-stripped) name; no-op is the default.
regularizer_resolver = Resolver(
    _REGULARIZERS,
    base=Regularizer,  # type: ignore
    suffix=_REGULARIZER_SUFFIX,
    default=NoRegularizer,
)
Exemplo n.º 10
0
        # scale = scipy.stats.norm.ppf(1 - 1/d * 1/math.e) - mean
        # return scipy.stats.gumbel_r.mean(loc=mean, scale=scale)
    elif math.isfinite(p):
        exp_abs_norm_p = math.pow(2, p / 2) * math.gamma(
            (p + 1) / 2) / math.sqrt(math.pi)
        return math.pow(exp_abs_norm_p * d, 1 / p)
    else:
        raise TypeError(f"norm not implemented for {type(p)}: {p}")


#: Resolve activation modules from :mod:`torch.nn` by name; ReLU is the default.
activation_resolver = Resolver(
    base=nn.Module,  # type: ignore
    default=nn.ReLU,
    classes=(
        nn.LeakyReLU,
        nn.PReLU,
        nn.ReLU,
        nn.Softplus,
        nn.Sigmoid,
        nn.Tanh,
    ),
)


class Bias(nn.Module):
    """A module wrapper for adding a bias."""
    def __init__(self, dim: int):
        """Initialize the module.

        :param dim: >0
            The dimension of the input.
        """
Exemplo n.º 11
0
from .base import ConsoleResultTracker, ResultTracker
from .file import CSVResultTracker, FileResultTracker, JSONResultTracker
from .mlflow import MLFlowResultTracker
from .neptune import NeptuneResultTracker
from .wandb import WANDBResultTracker

__all__ = [
    # Base classes
    'ResultTracker',
    'FileResultTracker',
    # Concrete classes
    'MLFlowResultTracker',
    'NeptuneResultTracker',
    'WANDBResultTracker',
    'JSONResultTracker',
    'CSVResultTracker',
    'ConsoleResultTracker',
    # Utilities
    'tracker_resolver',
]

#: Suffix stripped from tracker class names for string lookup.
_RESULT_TRACKER_SUFFIX = 'ResultTracker'
#: All tracker subclasses except FileResultTracker, which is excluded
#: from resolution.
_TRACKERS = [
    tracker
    for tracker in get_subclasses(ResultTracker)
    if tracker is not FileResultTracker
]
# NOTE(review): the default here is the ResultTracker base class itself —
# confirm that is intended rather than a concrete tracker.
tracker_resolver = Resolver(
    _TRACKERS,
    base=ResultTracker,
    default=ResultTracker,
    suffix=_RESULT_TRACKER_SUFFIX,
)
Exemplo n.º 12
0
    'HolE',
    'KG2E',
    'MuRE',
    'NTN',
    'PairRE',
    'ProjE',
    'QuatE',
    'RESCAL',
    'RGCN',
    'RotatE',
    'SimplE',
    'StructuredEmbedding',
    'TransD',
    'TransE',
    'TransH',
    'TransR',
    'TuckER',
    'UnstructuredModel',
    # Utils
    'model_resolver',
    'make_model',
    'make_model_cls',
]

#: All non-base model classes, discovered by subclass traversal.
_MODELS: Set[Type[Model]] = {
    cls
    for cls in get_subclasses(Model)  # type: ignore
    if not cls._is_base_model
}
#: Resolve model classes by normalized name.
model_resolver = Resolver(base=Model, classes=_MODELS)  # type: ignore
Exemplo n.º 13
0
 def setUp(self) -> None:
     """Set up the resolver class."""
     # Shared resolver over the test class hierarchy rooted at Base.
     self.resolver = Resolver([A, B, C, E], base=Base)
Exemplo n.º 14
0
from torch.optim.adamax import Adamax
from torch.optim.adamw import AdamW
from torch.optim.optimizer import Optimizer
from torch.optim.sgd import SGD

__all__ = [
    'Optimizer',
    'optimizers_hpo_defaults',
    'optimizer_resolver',
]

#: Optimizers made available through the resolver.
_OPTIMIZER_LIST: Set[Type[Optimizer]] = {
    Adadelta,
    Adagrad,
    Adam,
    Adamax,
    AdamW,
    SGD,
}


def _lr_hpo_strategy():
    """Return a fresh default HPO strategy for the learning rate."""
    return dict(lr=dict(type=float, low=0.001, high=0.1, scale='log'), )


#: The default strategy for optimizing the optimizers' hyper-parameters (yo dawg)
#: NOTE(review): Adadelta has no entry here despite being resolvable —
#: confirm that is intentional.
optimizers_hpo_defaults: Mapping[Type[Optimizer], Mapping[str, Any]] = {
    Adagrad: _lr_hpo_strategy(),
    Adam: _lr_hpo_strategy(),
    Adamax: _lr_hpo_strategy(),
    AdamW: _lr_hpo_strategy(),
    SGD: _lr_hpo_strategy(),
}

#: Resolve optimizer classes by name; Adam is the default.
optimizer_resolver = Resolver(_OPTIMIZER_LIST, base=Optimizer, default=Adam)
Exemplo n.º 15
0
    "RankBasedMetricResults",
    "ClassificationEvaluator",
    "ClassificationMetricResults",
    "evaluator_resolver",
    "metric_resolver",
    "get_metric_list",
]

#: Resolve evaluators from all Evaluator subclasses; rank-based is the default.
evaluator_resolver = Resolver.from_subclasses(
    default=RankBasedEvaluator,
    base=Evaluator,  # type: ignore
)

#: Suffix stripped from metric-results class names for string lookup.
_METRICS_SUFFIX = "MetricResults"
#: All concrete metric-results containers.
_METRICS: Set[Type[MetricResults]] = {
    ClassificationMetricResults,
    RankBasedMetricResults,
}
#: Resolve metric-results containers by name (no default).
metric_resolver = Resolver(
    _METRICS,
    base=MetricResults,
    suffix=_METRICS_SUFFIX,
)


def get_metric_list():
    """Get info about all metrics across all evaluators.

    :return: a list of ``(field, name, results_class)`` triples, one per
        dataclass field of each registered metric-results container.
    """
    result = []
    for name, results_cls in metric_resolver.lookup_dict.items():
        for field in dataclasses.fields(results_cls):
            result.append((field, name, results_cls))
    return result
Exemplo n.º 16
0
...     evaluation_kwargs=dict(batch_size=128),
...     stopper='early',
...     stopper_kwargs=dict(frequency=5, patience=2, relative_delta=0.002),
... )
"""

from typing import Collection, Type

from class_resolver import Resolver, get_subclasses

from .early_stopping import EarlyStopper, StopperCallback  # noqa: F401
from .stopper import NopStopper, Stopper

__all__ = [
    'Stopper',
    'NopStopper',
    'EarlyStopper',
    # Utils
    'stopper_resolver',
]

#: Suffix stripped from stopper class names for string lookup.
_STOPPER_SUFFIX = 'Stopper'
#: All stopper implementations, discovered by subclass traversal.
_STOPPERS: Collection[Type[Stopper]] = set(
    get_subclasses(Stopper))  # type: ignore
#: Resolve stoppers by name; no-op stopping is the default.
stopper_resolver = Resolver(
    _STOPPERS,
    base=Stopper,  # type: ignore
    suffix=_STOPPER_SUFFIX,
    default=NopStopper,
)
Exemplo n.º 17
0
#: Loss classes exposed through the resolver.
_LOSSES: Set[Type[Loss]] = {
    BCEAfterSigmoidLoss,
    BCEWithLogitsLoss,
    CrossEntropyLoss,
    MarginRankingLoss,
    MSELoss,
    NSSALoss,
    SoftplusLoss,
}
#: Normalized synonym -> loss class, collected from each class's declared
#: ``synonyms`` attribute (classes with ``synonyms is None`` are skipped).
losses_synonyms: Mapping[str, Type[Loss]] = {
    normalize_string(synonym, suffix=_LOSS_SUFFIX): cls
    for cls in _LOSSES
    if cls.synonyms is not None
    for synonym in cls.synonyms
}
#: Resolve losses by name or synonym; margin ranking is the default.
loss_resolver = Resolver(
    _LOSSES,
    base=Loss,
    suffix=_LOSS_SUFFIX,
    default=MarginRankingLoss,
    synonyms=losses_synonyms,
)


def has_mr_loss(model) -> bool:
    """Check if the model has a margin ranking loss."""
    return isinstance(model.loss, MarginRankingLoss)


def has_nssa_loss(model) -> bool:
    """Check whether the given model was configured with an NSSA loss."""
    loss = model.loss
    return isinstance(loss, NSSALoss)
Exemplo n.º 18
0
        training_loop='sLCWA',
        negative_sampler='bernoulli',
    )
"""  # noqa

from typing import Set, Type

from class_resolver import Resolver, get_subclasses

from .basic_negative_sampler import BasicNegativeSampler
from .bernoulli_negative_sampler import BernoulliNegativeSampler
from .negative_sampler import NegativeSampler

__all__ = [
    'NegativeSampler',
    'BasicNegativeSampler',
    'BernoulliNegativeSampler',
    # Utils
    'negative_sampler_resolver',
]

#: Suffix stripped from sampler class names for string lookup.
_NEGATIVE_SAMPLER_SUFFIX = 'NegativeSampler'
#: All negative samplers, discovered by subclass traversal.
_NEGATIVE_SAMPLERS: Set[Type[NegativeSampler]] = set(
    get_subclasses(NegativeSampler))  # type: ignore
#: Resolve negative samplers by name; basic (uniform) sampling is the default.
negative_sampler_resolver = Resolver(
    _NEGATIVE_SAMPLERS,
    base=NegativeSampler,  # type: ignore
    suffix=_NEGATIVE_SAMPLER_SUFFIX,
    default=BasicNegativeSampler,
)
Exemplo n.º 19
0
    def forward(
        self,
        scores: torch.FloatTensor,
        labels: torch.FloatTensor,
    ) -> torch.FloatTensor:  # noqa: D102
        # NOTE(review): ``assert`` is stripped under ``python -O``, so this
        # label validation silently disappears there — confirm acceptable.
        assert self.validate_labels(labels=labels)
        # Reduction (mean/sum/none) follows the configured self.reduction.
        return functional.mse_loss(scores, labels, reduction=self.reduction)


#: Resolve the activation used inside margin-based losses. The synonyms
#: "hard" and "soft" map to ReLU and Softplus respectively.
margin_activation_resolver = Resolver(
    base=nn.Module,  # type: ignore
    classes={
        nn.ReLU,
        nn.Softplus,
    },
    synonyms=dict(
        hard=nn.ReLU,
        soft=nn.Softplus,
    ),
)


class MarginRankingLoss(PairwiseLoss):
    r"""A module for the margin ranking loss.

    .. math ::
        L(score^+, score^-) = activation(score^- - score^+ + margin)

    .. seealso:: :class:`torch.nn.MarginRankingLoss`
    """