def resolve_class_name(cls: Type[T], name: str) -> Tuple[Type[T], Optional[str]]:
    """
    Returns the subclass that corresponds to the given `name`, along with the name of the
    method that was registered as a constructor for that `name`, if any.

    This method also allows `name` to be a fully-specified module name, instead of a name
    that was already added to the `Registry`. In that case, you cannot use a separate
    function as a constructor (as you need to call `cls.register()` in order to tell us
    what separate function to use), so the returned constructor is always ``None``.

    Raises:
        ConfigurationError: if `name` is neither registered for `cls` nor resolvable as an
            absolute ``package.module.ClassName`` path. When an import or attribute lookup
            failed, the original exception is chained as ``__cause__``.
    """
    registry = Registrable._registry[cls]
    if name in registry:
        subclass, constructor = registry[name]
        return subclass, constructor

    if "." not in name:
        # is not a qualified class name
        raise ConfigurationError(
            f"{name} is not a registered name for {cls.__name__}. "
            "You probably need to use the --include-package flag "
            "to load your custom code. Alternatively, you can specify your choices "
            """using fully-qualified paths, e.g. {"model": "my_module.models.MyModel"} """
            "in which case they will be automatically imported correctly."
        )

    # This might be a fully qualified class name, so we'll try importing its "module"
    # and finding it there.
    submodule, _, class_name = name.rpartition(".")
    import_failure = (
        f"tried to interpret {name} as a path to a class "
        f"but unable to import module {submodule}"
    )
    if not submodule or submodule.startswith("."):
        # importlib would raise ValueError (empty module name) or TypeError (relative
        # import without a package) here; report it as a configuration problem instead.
        raise ConfigurationError(import_failure)

    try:
        module = importlib.import_module(submodule)
    except ModuleNotFoundError as err:
        raise ConfigurationError(import_failure) from err

    try:
        subclass = getattr(module, class_name)
    except AttributeError as err:
        raise ConfigurationError(
            f"tried to interpret {name} as a path to a class "
            f"but unable to find class {class_name} in {submodule}"
        ) from err

    # Classes resolved by import path never have a registered constructor.
    return subclass, None
def get_optimizer(type: str, args: Optional[Dict]):
    """Build a ``torch.optim`` optimizer from its configured name and keyword arguments.

    Args:
        type: optimizer name; one of ``'Adam'`` or ``'Adagrad'``.
        args: keyword arguments forwarded to the optimizer constructor (presumably
            including ``params`` — confirm against callers). ``None`` means "no kwargs".

    Returns:
        The constructed optimizer instance.

    Raises:
        ConfigurationError: if ``type`` is not a supported optimizer name.
    """
    optim_dict = {'Adam': torch.optim.Adam, 'Adagrad': torch.optim.Adagrad}
    # Resolve the class on its own so a KeyError raised *inside* the optimizer
    # constructor is not misreported as an unsupported optimizer type.
    try:
        optimizer_cls = optim_dict[type]
    except KeyError:
        raise ConfigurationError(
            f"Optimizer type {type} specified in config file not supported.") from None
    return optimizer_cls(**(args or {}))
def list_available(cls) -> List[str]:
    """Return all names registered for ``cls``, placing the default implementation first.

    Raises:
        ConfigurationError: if ``cls.default_implementation`` is set but is not one of the
            registered names.
    """
    names = list(Registrable._registry[cls])
    fallback = cls.default_implementation
    if fallback is None:
        return names
    if fallback not in names:
        raise ConfigurationError(
            f"Default implementation {fallback} is not registered")
    # Default first, then every other name in registration order.
    return [fallback, *(n for n in names if n != fallback)]
def create(config: Config, dataset: DatasetProcessor):
    """Factory method for model creation.

    Reads ``model.name`` from ``config`` and instantiates the ``BaseModel`` subclass
    registered under that name with ``(config, dataset)``.

    Raises:
        ConfigurationError: if no model is registered under the configured name.
    """
    model_type = config.get("model.name")
    if model_type not in BaseModel.list_available():
        raise ConfigurationError(
            f"{model_type} specified in configuration file is not supported. "
            "Implement your model class with `BaseModel.register(name)`."
        )
    # TODO: pass constructor kwargs from "model.args" (needs a key-based format).
    return BaseModel.by_name(model_type)(config, dataset)
def create(config: Config):
    """Factory method for dataset creation.

    Reads ``dataset.name`` from ``config`` and instantiates the ``DatasetProcessor``
    subclass registered under that name with ``(config,)``.

    Raises:
        ConfigurationError: if no dataset processor is registered under the configured name.
    """
    ds_type = config.get("dataset.name")
    if ds_type not in DatasetProcessor.list_available():
        raise ConfigurationError(
            f"{ds_type} specified in configuration file is not supported. "
            "Implement your data class with `DatasetProcessor.register(name)`."
        )
    # TODO: pass constructor kwargs from "dataset.args" (needs a key-based format).
    return DatasetProcessor.by_name(ds_type)(config)
def create(config: Config, dataset: DatasetProcessor):
    """Factory method for negative sampler creation.

    Reads ``negative_sampling.name`` and ``negative_sampling.as_matrix`` from ``config``
    and instantiates the ``NegativeSampler`` subclass registered under that name with
    ``(config, dataset, as_matrix)``.

    Raises:
        ConfigurationError: if no negative sampler is registered under the configured name.
    """
    ns_type = config.get("negative_sampling.name")
    if ns_type not in NegativeSampler.list_available():
        raise ConfigurationError(
            f"{ns_type} specified in configuration file is not supported. "
            "Implement your negative sampling class with `NegativeSampler.register(name)`."
        )
    as_matrix = config.get("negative_sampling.as_matrix")
    # TODO: pass further constructor kwargs from the config (needs a key-based format).
    return NegativeSampler.by_name(ns_type)(config, dataset, as_matrix)
def register_subclass(subclass: Type[T]):
    """Store ``subclass`` (with ``constructor``) in ``registry`` under ``name`` and return it.

    ``name``, ``registry``, ``overwrite``, ``constructor`` and ``logger`` are free variables
    taken from the enclosing scope, which is not visible here.

    Raises:
        ConfigurationError: if ``name`` is already registered and ``overwrite`` is falsy.
    """
    if name in registry:
        existing = registry[name][0].__name__
        if overwrite:
            logger.info(
                f"{name} has already been registered as {existing}, but "
                f"overwrite=True, so overwriting with {subclass.__name__}")
        else:
            raise ConfigurationError(
                f"Cannot register {name} as {subclass.__name__}; "
                f"name already in use for {existing}"
            )
    registry[name] = (subclass, constructor)
    # Return the class unchanged so this works as a class decorator.
    return subclass
def get_scheduler(optimizer: torch.optim.optimizer.Optimizer, type: str, args: Optional[Dict]): scheduler_dict = { 'MultiStepLR': torch.optim.lr_scheduler.MultiStepLR, 'StepLR': torch.optim.lr_scheduler.StepLR, 'ExponentialLR': torch.optim.lr_scheduler.ExponentialLR, 'CosineAnnealingLR': torch.optim.lr_scheduler.CosineAnnealingLR, 'ReduceLROnPlateau': torch.optim.lr_scheduler.ReduceLROnPlateau, 'LambdaLR': torch.optim.lr_scheduler.LambdaLR } try: scheduler = scheduler_dict[type](optimizer, **args) except KeyError: raise ConfigurationError( f"Lr scheduler type {type} specified in config file not supported." ) return scheduler
def set(
    self, key: str, value, create=False, overwrite=Overwrite.Yes, log=False
) -> Any:
    """Set the value of ``key`` and return the value that ends up stored.

    Nested dictionary values can be accessed via "." (e.g., "job.type"). Missing
    intermediate dictionaries are created when creation is permitted; an existing
    sub-dictionary containing a "+++" entry also permits creation beneath it.

    NOTE(review): ``create`` is unconditionally overridden to ``True`` at the start of
    the body, so the argument currently has no effect and missing keys are always
    created. Removing the override would restore the documented behaviour but may
    break callers that rely on implicit creation.

    String values are coerced to ``int``/``float`` when they parse as such (for a new
    key), or when they match the numeric type of the existing value.

    Args:
        key: dotted path of the option to set.
        value: new value.
        create: whether missing keys may be created (currently ignored, see note).
        overwrite: ``Config.Overwrite.No`` keeps an existing value and returns it;
            ``Config.Overwrite.Error`` raises if an existing value would change.
        log: if true, record the assignment via ``self.log``.

    Raises:
        ConfigurationError: if the path cannot be traversed or created, the new value's
            type differs from the existing value's type, or ``overwrite`` is
            ``Config.Overwrite.Error`` and the value would change.
    """
    # NOTE(review): forces creation regardless of the ``create`` argument (see docstring).
    create = True
    from tkge.common.misc import is_number

    splits = key.split(".")
    data = self.options

    # Walk (creating where permitted) the nested dicts leading to the final key.
    for i, part in enumerate(splits[:-1]):
        if part in data:
            create = create or "+++" in data[part]
        elif create:
            data[part] = dict()
        else:
            raise ConfigurationError(
                "{} cannot be set because creation of {} is not permitted".format(
                    key, ".".join(splits[: (i + 1)])
                )
            )
        data = data[part]

    last = splits[-1]
    # check correctness of value
    try:
        current_value = data.get(last)
    except Exception as err:
        # e.g. ``data`` is not a dict because an intermediate path element is a leaf value.
        raise ConfigurationError(
            "These config entries {} {} caused an error.".format(data, last)
        ) from err

    if current_value is None:
        if not create:
            raise ConfigurationError("key {} not present and `create` is disabled".format(key))
        if isinstance(value, str) and is_number(value, int):
            value = int(value)
        elif isinstance(value, str) and is_number(value, float):
            value = float(value)
    else:
        if (
            isinstance(value, str)
            and isinstance(current_value, float)
            and is_number(value, float)
        ):
            value = float(value)
        elif (
            isinstance(value, str)
            and isinstance(current_value, int)
            and is_number(value, int)
        ):
            value = int(value)
        # Exact type match is required (deliberately stricter than isinstance, e.g. bool vs int).
        if type(value) != type(current_value):
            raise ConfigurationError(
                "key {} has incorrect type (expected {}, found {})".format(
                    key, type(current_value), type(value)
                )
            )
        if overwrite == Config.Overwrite.No:
            return current_value
        if overwrite == Config.Overwrite.Error and value != current_value:
            raise ConfigurationError("key {} cannot be overwritten".format(key))

    # all fine, set value
    data[last] = value
    if log:
        self.log("Set {}={}".format(key, value))
    return value