from olympus.utils.factory import fetch_factories
from olympus.utils import set_seeds as init_seed, warning, HyperParameters
from torch.nn import Module
from torch.random import fork_rng

# Registry of initialization factories, built by fetch_factories from the
# 'olympus.models.inits' package (presumably one entry per discovered module — verify).
registered_initialization = fetch_factories('olympus.models.inits', __file__)


def known_initialization():
    """Return the registered initialization names (keys of ``registered_initialization``)."""
    return registered_initialization.keys()


def register_initialization(name, factory, override=False):
    """Register ``factory`` under ``name`` in ``registered_initialization``.

    Args:
        name: key to store the factory under.
        factory: object stored as the registry value.
        override: when ``name`` is already registered, replace the existing
            entry only if this is true; otherwise the call is a no-op.

    A warning is emitted whenever ``name`` is already registered, including
    when ``override`` is true and the entry does get replaced.
    """
    # NOTE(review): `global` is not needed; the registry is only mutated, never rebound.
    global registered_initialization

    if name in registered_initialization:
        warning(f'{name} was already registered, use override=True to ignore')

        # keep the existing entry unless the caller explicitly asked to replace it
        if not override:
            return

    registered_initialization[name] = factory


class RegisteredInitNotFound(Exception):
    # presumably raised when a requested initialization name is not registered — TODO confirm raise site
    pass


def get_initializers_space():
from olympus.utils.factory import fetch_factories
from olympus.utils import MissingArgument, warning
from olympus.hpo.fidelity import Fidelity
from olympus.hpo.parallel import ParallelHPO

# Registry of hyper-parameter optimizer factories, built by fetch_factories from the
# 'olympus.hpo' package (presumably one entry per discovered module — verify).
registered_optimizer = fetch_factories('olympus.hpo', __file__)


def known_hpo():
    """Return the registered HPO names (keys of ``registered_optimizer``)."""
    return registered_optimizer.keys()


def register_hpo(name, factory, override=False):
    """Register ``factory`` under ``name`` in ``registered_optimizer``.

    Args:
        name: key to store the factory under.
        factory: object stored as the registry value.
        override: when ``name`` is already registered, replace the existing
            entry only if this is true; otherwise the call is a no-op.

    A warning is emitted whenever ``name`` is already registered, including
    when ``override`` is true and the entry does get replaced.
    """
    # NOTE(review): `global` is not needed; the registry is only mutated, never rebound.
    global registered_optimizer

    if name in registered_optimizer:
        warning(f'{name} was already registered, use override=True to ignore')

        # keep the existing entry unless the caller explicitly asked to replace it
        if not override:
            return

    registered_optimizer[name] = factory


class RegisteredHPONotFound(Exception):
    # presumably raised when a requested HPO name is not registered — TODO confirm raise site
    pass


class HPOptimizer:
    """Olympus standardized HPO interface
from torch.utils.data import Dataset as TorchDataset, Subset
from olympus.utils import new_seed
from olympus.utils.factory import fetch_factories

# Registry of split methods; fetch_factories is asked for each module's `split`
# function (function_name='split') — presumably keyed by module name, verify.
sampling_methods = fetch_factories('olympus.datasets.split', __file__, function_name='split')


def known_split_methods():
    """Return the registered split-method names as a list."""
    return list(sampling_methods.keys())


class RegisteredSplitMethodNotFound(Exception):
    # raised by _generate_splits when `split_method` is not in `sampling_methods`
    pass


def _generate_splits(datasets, split_method, seed, ratio, index, data_size, balanced):
    """Split ``datasets`` with the registered ``split_method`` (partially documented).

    Args:
        datasets: sized collection to split; only ``len(datasets)`` is used
            in the size computation.
        split_method: key looked up in ``sampling_methods``.
        seed, ratio, index, balanced: presumably forwarded to the selected
            split function — TODO confirm.
        data_size: optional fraction of ``len(datasets)`` to use (it is
            multiplied by the length, so it must be <= 1); ``None`` means
            use the whole dataset.

    Raises:
        RegisteredSplitMethodNotFound: if ``split_method`` is not registered.
    """
    if data_size is not None:
        # convert the fraction into an absolute number of samples
        data_size = int(data_size * len(datasets))
        # NOTE(review): `assert` is stripped under `python -O`; an explicit raise would keep this check.
        assert data_size <= len(datasets)
    else:
        data_size = len(datasets)

    split = sampling_methods.get(split_method)
    if split is None:
        raise RegisteredSplitMethodNotFound(
            f'Split method `{split_method}` was not found use {known_split_methods()}'
import copy
from collections import defaultdict

import torch
from torch.utils.data import DataLoader as TorchDataLoader, Dataset as TorchDataset

from olympus.utils import warning, option, MissingArgument
from olympus.utils.factory import fetch_factories
from olympus.datasets.transformed import TransformedSubset
from olympus.datasets.split import SplitDataset
from olympus.datasets.sampling import RandomSampler, SequentialSampler

# Registry of dataset factories, built by fetch_factories from the
# 'olympus.datasets' package (presumably one entry per discovered module — verify).
registered_datasets = fetch_factories('olympus.datasets', __file__)


def known_datasets(*category_filters, include_unknown=False):
    """List registered dataset names, optionally filtered by category.

    Args:
        *category_filters: categories to match. A dataset is selected when
            its factory's ``categories()`` contains ANY of the given filters.
        include_unknown: also select datasets whose factory has no
            ``categories`` method (their categories cannot be checked).

    Returns:
        ``registered_datasets.keys()`` when no filter is given; otherwise a
        list of the selected names, each listed once, in first-match order
        (filters in the given order, registry order within each filter).
    """
    if not category_filters:
        return registered_datasets.keys()

    # dict used as an insertion-ordered set: a dataset matching several
    # filters (or an unknown one with include_unknown) is listed only once
    matching = {}
    for category in category_filters:
        for name, factory in registered_datasets.items():
            if hasattr(factory, 'categories'):
                if category in factory.categories():
                    matching[name] = None

            # we don't know if it matches because it does not have the categories method
            elif include_unknown:
                matching[name] = None

    return list(matching)
from olympus.utils import warning
from olympus.utils.factory import fetch_factories

# Registry of environment factories, built by fetch_factories from the
# 'olympus.reinforcement' package (presumably one entry per discovered module — verify).
registered_environment = fetch_factories('olympus.reinforcement', __file__)


def known_environments(*category_filters, include_unknown=False):
    """List known environments, optionally filtered by category.

    Args:
        *category_filters: categories to match; an environment is selected
            when its factory's ``categories()`` contains any of them.
        include_unknown: also select environments whose factory has no
            ``categories`` method.

    Returns:
        ``registered_environment.keys()`` when no filter is given, otherwise
        a list of the selected names.
    """
    if not category_filters:
        return registered_environment.keys()

    matching = []
    # NOTE(review): a name is appended once per matching filter, so with several
    # filters (or include_unknown=True) the result can contain duplicates.
    # NOTE(review): the loop variable `filter` shadows the builtin of the same name.
    for filter in category_filters:
        for name, factory in registered_environment.items():
            if hasattr(factory, 'categories'):
                if filter in factory.categories():
                    matching.append(name)

            # we don't know if it matches because it does not have the categories method
            elif include_unknown:
                matching.append(name)

    return matching


def register_environment(name, factory, override=False):
    """Register a new environment backend"""
    global registered_environment

    if name in registered_environment:
from typing import Dict

import torch
from torch.optim.optimizer import Optimizer as TorchOptimizer

from olympus.utils import MissingArgument, warning, HyperParameters
from olympus.utils.factory import fetch_factories
from olympus.optimizers.schedules import LRSchedule, known_schedule

# Registry of optimizer factories, built by fetch_factories from the
# 'olympus.optimizers' package (presumably one entry per discovered module — verify).
registered_optimizers = fetch_factories('olympus.optimizers', __file__)


def known_optimizers():
    """Return the registered optimizer names (keys of ``registered_optimizers``)."""
    return registered_optimizers.keys()


class RegisteredOptimizerNotFound(Exception):
    # presumably raised when a requested optimizer name is not registered — TODO confirm raise site
    pass


class UninitializedOptimizer(Exception):
    # presumably raised when an optimizer is used before being initialized — TODO confirm raise site
    pass


def register_optimizer(name, factory, override=False):
    # NOTE(review): `global` is not needed if the registry is only mutated, never rebound.
    global registered_optimizers

    if name in registered_optimizers:
        warning(f'{name} was already registered, use override=True to ignore')

        if not override:
from olympus.utils import MissingArgument, warning, HyperParameters
from olympus.utils.factory import fetch_factories

# Name -> factory registry for learning-rate schedules, populated by
# fetch_factories from the 'olympus.optimizers.schedules' package.
registered_schedules = fetch_factories('olympus.optimizers.schedules', __file__)


class RegisteredLRSchedulerNotFound(Exception):
    """Presumably signals a learning-rate schedule name missing from the registry."""


class UninitializedLRScheduler(Exception):
    """Presumably signals use of a learning-rate scheduler before it is set up."""


def known_schedule():
    """Return the registered schedule names (keys of ``registered_schedules``)."""
    return registered_schedules.keys()


def register_schedule(name, factory, override=False):
    """Store ``factory`` in ``registered_schedules`` under ``name``.

    Args:
        name: registry key.
        factory: value to store.
        override: when ``name`` already exists, replace the stored value only
            if this is true; otherwise leave the registry untouched.

    A warning is emitted every time ``name`` already exists, whether or not
    the entry ends up replaced.
    """
    already_registered = name in registered_schedules

    if already_registered:
        warning(f'{name} was already registered, use override=True to ignore')

    # an existing entry is kept unless the caller opted into overriding it
    if already_registered and not override:
        return

    registered_schedules[name] = factory
import torch
import torch.nn as nn

from olympus.models.module import Module
from olympus.models.inits import Initializer, known_initialization, get_initializers_space
from olympus.utils import MissingArgument, warning, LazyCall, HyperParameters
from olympus.utils.factory import fetch_factories
from olympus.utils.fp16 import network_to_half

# Name -> factory registry for models, populated by fetch_factories from the
# 'olympus.models' package.
registered_models = fetch_factories('olympus.models', __file__)


def known_models():
    """Return the registered model names (keys of ``registered_models``)."""
    return registered_models.keys()


def register_model(name, factory, override=False):
    """Store ``factory`` in ``registered_models`` under ``name``.

    Args:
        name: registry key.
        factory: value to store.
        override: when ``name`` already exists, replace the stored value only
            if this is true; otherwise leave the registry untouched.

    A warning is emitted every time ``name`` already exists, whether or not
    the entry ends up replaced.
    """
    # fast path: a brand-new name is simply added
    if name not in registered_models:
        registered_models[name] = factory
        return

    warning(f'{name} was already registered, use override=True to ignore')
    if override:
        registered_models[name] = factory


class RegisteredModelNotFound(Exception):
    """Presumably signals a model name missing from the registry."""
from olympus.utils import warning
from olympus.utils.factory import fetch_factories

from .gradient_ascent import GradientAscentAdversary
from .fast_gradient import FastGradientAdversary

# Registry of adversary factories, built by fetch_factories from the
# 'olympus.adversary' package (presumably one entry per discovered module — verify).
registered_adversary = fetch_factories('olympus.adversary', __file__)


def known_adversary():
    """Return the registered adversary names (keys of ``registered_adversary``)."""
    return registered_adversary.keys()


def register_adversary(name, factory, override=False):
    """Register ``factory`` under ``name`` in ``registered_adversary``.

    Args:
        name: key to store the factory under.
        factory: object stored as the registry value.
        override: when ``name`` is already registered, replace the existing
            entry only if this is true; otherwise the call is a no-op.

    A warning is emitted whenever ``name`` is already registered, including
    when ``override`` is true and the entry does get replaced.
    """
    # NOTE(review): `global` is not needed; the registry is only mutated, never rebound.
    global registered_adversary

    if name in registered_adversary:
        warning(f'{name} was already registered, use override=True to ignore')

        # keep the existing entry unless the caller explicitly asked to replace it
        if not override:
            return

    registered_adversary[name] = factory


class RegisteredAdversaryNotFound(Exception):
    # presumably raised when a requested adversary name is not registered — TODO confirm raise site
    pass


class Adversary:
    def __init__(self, name, model, min_confidence=0.90, max_iter=10):