예제 #1
0
    def resolve_class_name(cls: Type[T],
                           name: str) -> Tuple[Type[T], Optional[str]]:
        """
        Returns the subclass that corresponds to the given `name`, along with the name of the
        method that was registered as a constructor for that `name`, if any.

        This method also allows `name` to be a fully-specified module name, instead of a name that
        was already added to the `Registry`.  In that case, you cannot use a separate function as
        a constructor (as you need to call `cls.register()` in order to tell us what separate
        function to use).

        Args:
            name: either a name registered for ``cls``, or a dotted path of the form
                ``"package.module.ClassName"``.

        Returns:
            A ``(subclass, constructor)`` pair. ``constructor`` is the value stored next to
            ``name`` in the registry, or ``None`` when ``name`` was resolved as a dotted path.

        Raises:
            ConfigurationError: if ``name`` is not registered and cannot be resolved as a
                dotted path (empty/relative module part, module not found, or attribute
                missing from the module). The underlying import/lookup error is chained.
        """
        registry = Registrable._registry[cls]
        if name in registry:
            subclass, constructor = registry[name]
            return subclass, constructor

        if "." not in name:
            # is not a qualified class name
            raise ConfigurationError(
                f"{name} is not a registered name for {cls.__name__}. "
                "You probably need to use the --include-package flag "
                "to load your custom code. Alternatively, you can specify your choices "
                """using fully-qualified paths, e.g. {"model": "my_module.models.MyModel"} """
                "in which case they will be automatically imported correctly.")

        # This might be a fully qualified class name, so we'll try importing its "module"
        # and finding it there.
        submodule, _, class_name = name.rpartition(".")

        # importlib.import_module raises ValueError for "" and TypeError for a relative
        # name without a package; report those as configuration errors too.
        if not submodule or submodule.startswith("."):
            raise ConfigurationError(
                f"tried to interpret {name} as a path to a class "
                f"but unable to import module {submodule}")

        try:
            module = importlib.import_module(submodule)
        except ModuleNotFoundError as err:
            raise ConfigurationError(
                f"tried to interpret {name} as a path to a class "
                f"but unable to import module {submodule}") from err

        try:
            subclass = getattr(module, class_name)
        except AttributeError as err:
            raise ConfigurationError(
                f"tried to interpret {name} as a path to a class "
                f"but unable to find class {class_name} in {submodule}") from err

        # A class resolved by path never has a separately registered constructor.
        return subclass, None
예제 #2
0
def get_optimizer(type: str, args: Optional[Dict]):
    """Build a ``torch.optim`` optimizer from its config name and keyword arguments.

    Args:
        type: optimizer name as written in the config file; one of ``'Adam'``
            or ``'Adagrad'`` (case-sensitive).
        args: keyword arguments forwarded to the optimizer constructor. The
            optimizer is built from ``args`` alone, so it must supply ``params``.
            ``None`` is treated as an empty dict.

    Returns:
        The constructed optimizer instance.

    Raises:
        ConfigurationError: if ``type`` is not a supported optimizer name.
    """
    optim_dict = {'Adam': torch.optim.Adam, 'Adagrad': torch.optim.Adagrad}

    # Look the class up on its own so that a KeyError raised inside the
    # optimizer constructor is not misreported as an unsupported type.
    try:
        optim_cls = optim_dict[type]
    except KeyError:
        raise ConfigurationError(
            f"Optimizer type {type} specified in config file not supported.") from None

    # `args` is Optional in the signature; `**None` would raise TypeError.
    return optim_cls(**(args or {}))
예제 #3
0
    def list_available(cls) -> List[str]:
        """Return every name registered for ``cls``, putting the default first.

        When ``cls.default_implementation`` is ``None`` the names come back in
        registry order. Otherwise the default is moved to the front and the
        remaining names keep their relative order.

        Raises:
            ConfigurationError: if a default is configured but is not one of
                the registered names.
        """
        registered_names = list(Registrable._registry[cls].keys())
        preferred = cls.default_implementation

        if preferred is None:
            return registered_names
        if preferred not in registered_names:
            raise ConfigurationError(
                f"Default implementation {preferred} is not registered")

        others = [candidate for candidate in registered_names if candidate != preferred]
        return [preferred, *others]
예제 #4
0
    def create(config: Config, dataset: DatasetProcessor):
        """Factory method for model creation.

        Looks up the class registered under the ``model.name`` config entry and
        instantiates it as ``cls(config, dataset)``.

        Raises:
            ConfigurationError: if ``model.name`` is not a registered
                ``BaseModel`` name.
        """
        model_type = config.get("model.name")

        if model_type not in BaseModel.list_available():
            # The two fragments are joined implicitly, so the first one must end
            # with a separator.
            raise ConfigurationError(
                f"{model_type} specified in configuration file is not supported. "
                "Implement your model class with `BaseModel.register(name)`.")

        # kwargs = config.get("model.args")  # TODO: needs to be changed to a key-based format
        return BaseModel.by_name(model_type)(config, dataset)
예제 #5
0
    def create(config: Config):
        """Factory method for dataset creation.

        Looks up the class registered under the ``dataset.name`` config entry
        and instantiates it as ``cls(config)``.

        Raises:
            ConfigurationError: if ``dataset.name`` is not a registered
                ``DatasetProcessor`` name.
        """
        ds_type = config.get("dataset.name")

        if ds_type not in DatasetProcessor.list_available():
            # The two fragments are joined implicitly, so the first one must end
            # with a separator.
            raise ConfigurationError(
                f"{ds_type} specified in configuration file is not supported. "
                "Implement your data class with `DatasetProcessor.register(name)`."
            )

        # The value was never passed on, so it is not read for now (matching the
        # model and negative-sampler factories).
        # kwargs = config.get("dataset.args")  # TODO: needs to be changed to a key-based format
        return DatasetProcessor.by_name(ds_type)(config)
예제 #6
0
    def create(config: Config, dataset: DatasetProcessor):
        """Factory method for negative sampler creation.

        Looks up the class registered under the ``negative_sampling.name``
        config entry and instantiates it as
        ``cls(config, dataset, as_matrix)``, where ``as_matrix`` is read from
        ``negative_sampling.as_matrix``.

        Raises:
            ConfigurationError: if ``negative_sampling.name`` is not a
                registered ``NegativeSampler`` name.
        """
        ns_type = config.get("negative_sampling.name")

        if ns_type not in NegativeSampler.list_available():
            # The two fragments are joined implicitly, so the first one must end
            # with a separator.
            raise ConfigurationError(
                f"{ns_type} specified in configuration file is not supported. "
                "Implement your negative sampling class with `NegativeSampler.register(name)`."
            )

        as_matrix = config.get("negative_sampling.as_matrix")
        # kwargs = config.get("model.args")  # TODO: needs to be changed to a key-based format
        # NOTE(review): "model.args" above looks copied from the model factory.
        return NegativeSampler.by_name(ns_type)(config, dataset, as_matrix)
예제 #7
0
        def register_subclass(subclass: Type[T]):
            """Store ``(subclass, constructor)`` in ``registry`` under ``name``.

            ``name``, ``registry``, ``overwrite``, ``constructor`` and ``cls`` come
            from the enclosing scope. Returns ``subclass`` unchanged, so this can
            be applied as a class decorator.

            Raises:
                ConfigurationError: if ``name`` is already registered and
                    ``overwrite`` is falsy.
            """
            if name in registry:
                previous = registry[name][0].__name__
                if not overwrite:
                    raise ConfigurationError(
                        f"Cannot register {name} as {cls.__name__}; "
                        f"name already in use for {previous}"
                    )
                # Name the flag that actually governs this (`overwrite`) and the
                # class that really replaces the existing entry (`subclass`).
                logger.info(
                    f"{name} has already been registered as {previous}, but "
                    f"overwrite=True, so overwriting with {subclass.__name__}")

            registry[name] = (subclass, constructor)
            return subclass
예제 #8
0
def get_scheduler(optimizer: torch.optim.optimizer.Optimizer, type: str,
                  args: Optional[Dict]):
    scheduler_dict = {
        'MultiStepLR': torch.optim.lr_scheduler.MultiStepLR,
        'StepLR': torch.optim.lr_scheduler.StepLR,
        'ExponentialLR': torch.optim.lr_scheduler.ExponentialLR,
        'CosineAnnealingLR': torch.optim.lr_scheduler.CosineAnnealingLR,
        'ReduceLROnPlateau': torch.optim.lr_scheduler.ReduceLROnPlateau,
        'LambdaLR': torch.optim.lr_scheduler.LambdaLR
    }

    try:
        scheduler = scheduler_dict[type](optimizer, **args)
    except KeyError:
        raise ConfigurationError(
            f"Lr scheduler type {type} specified in config file not supported."
        )

    return scheduler
예제 #9
0
    def set(
            self, key: str, value, create=False, overwrite=Overwrite.Yes, log=False
    ) -> Any:
        """Set value of specified key.

        Nested dictionary values can be accessed via "." (e.g., "job.type").
        Missing intermediate dictionaries are created along the way.

        If the key has no current value (absent, or present with value ``None``),
        a string ``value`` is converted to ``int``/``float`` when ``is_number``
        accepts it. If the key already has a value, a string ``value`` is
        converted to that value's numeric type when possible. After that, the new
        value must have exactly the same type as the current one.

        Args:
            key: dotted path of the entry to set.
            value: the new value.
            create: NOTE(review): currently has no effect. It is unconditionally
                forced to ``True`` below, so missing keys are always created.
            overwrite: how to treat an existing value. ``Config.Overwrite.No``
                keeps and returns it. ``Config.Overwrite.Error`` raises if the
                value would change. Any other value replaces it.
            log: if true, record the assignment via ``self.log``.

        Returns:
            The value now stored under ``key``. This is the pre-existing value
            when ``overwrite`` is ``Config.Overwrite.No`` and the key already
            had one.

        Raises:
            ConfigurationError: if the entry that should hold the final key is
                not a dictionary, if the value's type does not match the existing
                value, or if overwriting is forbidden.
        """
        # NOTE(review): this overrides the `create` argument, which makes the
        # "not permitted" / "`create` is disabled" errors below unreachable. Kept
        # as-is because callers may rely on keys always being creatable.
        create = True
        from tkge.common.misc import is_number

        splits = key.split(".")
        data = self.options

        # Walk down to the dictionary that holds the final key, creating missing
        # intermediate dictionaries when permitted.
        for i, part in enumerate(splits[:-1]):
            if part in data:
                # An existing entry containing "+++" also permits creation below it.
                create = create or "+++" in data[part]
            elif create:
                data[part] = dict()
            else:
                raise ConfigurationError(
                    f"{key} cannot be set because creation of "
                    f"{'.'.join(splits[: i + 1])} is not permitted"
                )
            data = data[part]

        # check correctness of value
        last = splits[-1]
        try:
            current_value = data.get(last)
        except AttributeError as err:
            # The parent entry is not a dictionary (e.g. a scalar or a list).
            raise ConfigurationError(
                f"These config entries {data} {last} caused an error."
            ) from err

        if current_value is None:
            # New key (a key present with value None is treated the same way).
            if not create:
                raise ConfigurationError(f"key {key} not present and `create` is disabled")

            if isinstance(value, str) and is_number(value, int):
                value = int(value)
            elif isinstance(value, str) and is_number(value, float):
                value = float(value)
        else:
            # Coerce numeric strings to the numeric type of the value being replaced.
            if (
                    isinstance(value, str)
                    and isinstance(current_value, float)
                    and is_number(value, float)
            ):
                value = float(value)
            elif (
                    isinstance(value, str)
                    and isinstance(current_value, int)
                    and is_number(value, int)
            ):
                value = int(value)
            # Exact type comparison (not isinstance), so e.g. bool and int are
            # not interchangeable.
            if type(value) != type(current_value):
                raise ConfigurationError(
                    f"key {key} has incorrect type "
                    f"(expected {type(current_value)}, found {type(value)})"
                )
            if overwrite == Config.Overwrite.No:
                return current_value
            if overwrite == Config.Overwrite.Error and value != current_value:
                raise ConfigurationError(f"key {key} cannot be overwritten")

        # all fine, set value
        data[last] = value
        if log:
            self.log(f"Set {key}={value}")
        return value