示例#1
0
from olympus.utils.factory import fetch_factories

from olympus.utils import set_seeds as init_seed, warning, HyperParameters

from torch.nn import Module
from torch.random import fork_rng

# Registry of initialization factories (name -> factory), populated by
# `fetch_factories` for the 'olympus.models.inits' package and mutated in
# place by `register_initialization`.
# NOTE(review): passing `__file__` presumably lets fetch_factories discover
# modules next to this file — confirm in olympus.utils.factory.
registered_initialization = fetch_factories('olympus.models.inits', __file__)


def known_initialization():
    """List the names of every registered initialization.

    Returns ``registered_initialization.keys()`` as-is. For a plain dict this
    is a live view that also reflects later ``register_initialization`` calls.
    """
    init_registry = registered_initialization
    return init_registry.keys()


def register_initialization(name, factory, override=False):
    """Store ``factory`` in ``registered_initialization`` under ``name``.

    Args:
        name: key under which the factory is registered.
        factory: value stored in the registry for ``name``.
        override: if ``name`` is already registered, replace the existing
            entry instead of keeping it. A warning is emitted either way.

    The registry is only mutated through item assignment (never rebound),
    so no ``global`` declaration is required.
    """
    if name in registered_initialization:
        warning(f'{name} was already registered, use override=True to ignore')

        if not override:
            # Keep the existing registration untouched.
            return

    registered_initialization[name] = factory


class RegisteredInitNotFound(Exception):
    """Presumably raised when a requested initialization name is not registered."""


def get_initializers_space():
示例#2
0
from olympus.utils.factory import fetch_factories
from olympus.utils import MissingArgument, warning
from olympus.hpo.fidelity import Fidelity
from olympus.hpo.parallel import ParallelHPO

# Registry of hyper-parameter optimizer (HPO) factories (name -> factory),
# populated by `fetch_factories` for the 'olympus.hpo' package and mutated in
# place by `register_hpo`.
# NOTE(review): passing `__file__` presumably lets fetch_factories discover
# modules next to this file — confirm in olympus.utils.factory.
registered_optimizer = fetch_factories('olympus.hpo', __file__)


def known_hpo():
    """List the names of every registered HPO backend.

    Returns ``registered_optimizer.keys()`` as-is. For a plain dict this is a
    live view that also reflects later ``register_hpo`` calls.
    """
    hpo_registry = registered_optimizer
    return hpo_registry.keys()


def register_hpo(name, factory, override=False):
    """Store an HPO ``factory`` in ``registered_optimizer`` under ``name``.

    Args:
        name: key under which the factory is registered.
        factory: value stored in the registry for ``name``.
        override: if ``name`` is already registered, replace the existing
            entry instead of keeping it. A warning is emitted either way.

    The registry is only mutated through item assignment (never rebound),
    so no ``global`` declaration is required.
    """
    if name in registered_optimizer:
        warning(f'{name} was already registered, use override=True to ignore')

        if not override:
            # Keep the existing registration untouched.
            return

    registered_optimizer[name] = factory


class RegisteredHPONotFound(Exception):
    """Presumably raised when a requested HPO name is not registered."""


class HPOptimizer:
    """Olympus standardized HPO interface
示例#3
0
from torch.utils.data import Dataset as TorchDataset, Subset

from olympus.utils import new_seed
from olympus.utils.factory import fetch_factories

# Registry of dataset split methods (name -> callable), built from the
# 'olympus.datasets.split' package. Looked up by name in `_generate_splits`.
# NOTE(review): `function_name='split'` presumably makes fetch_factories pick
# each module's `split` function as the registered value — confirm in
# olympus.utils.factory.
sampling_methods = fetch_factories('olympus.datasets.split',
                                   __file__,
                                   function_name='split')


def known_split_methods():
    """Return a new list with the name of every registered split method."""
    return [*sampling_methods.keys()]


class RegisteredSplitMethodNotFound(Exception):
    """Raised when a requested split method is not a key of ``sampling_methods``."""


def _generate_splits(datasets, split_method, seed, ratio, index, data_size,
                     balanced):
    if data_size is not None:
        data_size = int(data_size * len(datasets))
        assert data_size <= len(datasets)
    else:
        data_size = len(datasets)

    split = sampling_methods.get(split_method)

    if split is None:
        raise RegisteredSplitMethodNotFound(
            f'Split method `{split_method}` was not found use {known_split_methods()}'
示例#4
0
import copy
from collections import defaultdict
import torch
from torch.utils.data import DataLoader as TorchDataLoader, Dataset as TorchDataset

from olympus.utils import warning, option, MissingArgument
from olympus.utils.factory import fetch_factories
from olympus.datasets.transformed import TransformedSubset
from olympus.datasets.split import SplitDataset
from olympus.datasets.sampling import RandomSampler, SequentialSampler

# Registry of dataset factories (name -> factory) for the 'olympus.datasets'
# package. Factories may optionally expose a `categories()` method, which
# `known_datasets` uses for filtering.
registered_datasets = fetch_factories('olympus.datasets', __file__)


def known_datasets(*category_filters, include_unknown=False):
    """Return the names of registered datasets, optionally filtered by category.

    Args:
        *category_filters: category labels to match against each factory's
            ``categories()``. With no filters, every registered name is
            returned.
        include_unknown: also include datasets whose factory has no
            ``categories`` method, since they cannot be ruled out.

    Returns:
        ``registered_datasets.keys()`` when no filter is given. Otherwise, a
        list of matching names in first-match order. Each name appears at
        most once, even if it matches several filters (or is unknown and
        several filters are given).
    """
    if not category_filters:
        return registered_datasets.keys()

    matching = []
    seen = set()  # names already added; prevents duplicates across filters
    for category in category_filters:
        for name, factory in registered_datasets.items():
            if name in seen:
                continue

            if hasattr(factory, 'categories'):
                is_match = category in factory.categories()
            else:
                # we don't know if it matches because it does not have the categories method
                is_match = include_unknown

            if is_match:
                seen.add(name)
                matching.append(name)

    return matching
示例#5
0
from olympus.utils import warning
from olympus.utils.factory import fetch_factories

# Registry of reinforcement-learning environment factories (name -> factory)
# for the 'olympus.reinforcement' package. Factories may optionally expose a
# `categories()` method, which `known_environments` uses for filtering.
registered_environment = fetch_factories('olympus.reinforcement', __file__)


def known_environments(*category_filters, include_unknown=False):
    """List known environments, optionally filtered by category.

    Args:
        *category_filters: category labels to match against each factory's
            ``categories()``. With no filters, every registered name is
            returned.
        include_unknown: also include environments whose factory has no
            ``categories`` method, since they cannot be ruled out.

    Returns:
        ``registered_environment.keys()`` when no filter is given. Otherwise,
        a list of matching names in first-match order. Each name appears at
        most once, even if it matches several filters (or is unknown and
        several filters are given).
    """
    if not category_filters:
        return registered_environment.keys()

    matching = []
    seen = set()  # names already added; prevents duplicates across filters
    for category in category_filters:
        for name, factory in registered_environment.items():
            if name in seen:
                continue

            if hasattr(factory, 'categories'):
                is_match = category in factory.categories()
            else:
                # we don't know if it matches because it does not have the categories method
                is_match = include_unknown

            if is_match:
                seen.add(name)
                matching.append(name)

    return matching


def register_environment(name, factory, override=False):
    """Register a new environment backend"""
    global registered_environment

    if name in registered_environment:
示例#6
0
from typing import Dict

import torch
from torch.optim.optimizer import Optimizer as TorchOptimizer

from olympus.utils import MissingArgument, warning, HyperParameters
from olympus.utils.factory import fetch_factories
from olympus.optimizers.schedules import LRSchedule, known_schedule

# Registry of optimizer factories (name -> factory) for the
# 'olympus.optimizers' package; queried by `known_optimizers`.
registered_optimizers = fetch_factories('olympus.optimizers', __file__)


def known_optimizers():
    """List the names of every registered optimizer.

    Returns ``registered_optimizers.keys()`` as-is. For a plain dict this is a
    live view that also reflects later registrations.
    """
    optimizer_registry = registered_optimizers
    return optimizer_registry.keys()


class RegisteredOptimizerNotFound(Exception):
    """Presumably raised when a requested optimizer name is not registered."""


class UninitializedOptimizer(Exception):
    """Presumably raised when an optimizer is used before being initialized."""


def register_optimizer(name, factory, override=False):
    global registered_optimizers

    if name in registered_optimizers:
        warning(f'{name} was already registered, use override=True to ignore')

        if not override:
示例#7
0
from olympus.utils import MissingArgument, warning, HyperParameters
from olympus.utils.factory import fetch_factories


# Registry of learning-rate schedule factories (name -> factory) for the
# 'olympus.optimizers.schedules' package. Mutated in place by
# `register_schedule`.
registered_schedules = fetch_factories('olympus.optimizers.schedules', __file__)


class RegisteredLRSchedulerNotFound(Exception):
    """Presumably raised when a requested LR schedule name is not registered."""


class UninitializedLRScheduler(Exception):
    """Presumably raised when an LR scheduler is used before being initialized."""


def known_schedule():
    """List the names of every registered learning-rate schedule.

    Returns ``registered_schedules.keys()`` as-is. For a plain dict this is a
    live view that also reflects later ``register_schedule`` calls.
    """
    schedule_registry = registered_schedules
    return schedule_registry.keys()


def register_schedule(name, factory, override=False):
    """Store an LR-schedule ``factory`` in ``registered_schedules`` under ``name``.

    Args:
        name: key under which the factory is registered.
        factory: value stored in the registry for ``name``.
        override: if ``name`` is already registered, replace the existing
            entry instead of keeping it. A warning is emitted either way.

    The registry is only mutated through item assignment (never rebound),
    so no ``global`` declaration is required.
    """
    if name in registered_schedules:
        warning(f'{name} was already registered, use override=True to ignore')

        if not override:
            # Keep the existing registration untouched.
            return

    registered_schedules[name] = factory

示例#8
0
import torch
import torch.nn as nn

from olympus.models.module import Module
from olympus.models.inits import Initializer, known_initialization, get_initializers_space

from olympus.utils import MissingArgument, warning, LazyCall, HyperParameters
from olympus.utils.factory import fetch_factories
from olympus.utils.fp16 import network_to_half

# Registry of model factories (name -> factory) for the 'olympus.models'
# package. Mutated in place by `register_model`.
registered_models = fetch_factories('olympus.models', __file__)


def known_models():
    """List the names of every registered model.

    Returns ``registered_models.keys()`` as-is. For a plain dict this is a
    live view that also reflects later ``register_model`` calls.
    """
    model_registry = registered_models
    return model_registry.keys()


def register_model(name, factory, override=False):
    """Store a model ``factory`` in ``registered_models`` under ``name``.

    Args:
        name: key under which the factory is registered.
        factory: value stored in the registry for ``name``.
        override: if ``name`` is already registered, replace the existing
            entry instead of keeping it. A warning is emitted either way.

    The registry is only mutated through item assignment (never rebound),
    so no ``global`` declaration is required.
    """
    if name in registered_models:
        warning(f'{name} was already registered, use override=True to ignore')

        if not override:
            # Keep the existing registration untouched.
            return

    registered_models[name] = factory


class RegisteredModelNotFound(Exception):
    """Presumably raised when a requested model name is not registered."""
示例#9
0
from olympus.utils import warning
from olympus.utils.factory import fetch_factories

from .gradient_ascent import GradientAscentAdversary
from .fast_gradient import FastGradientAdversary

# Registry of adversary factories (name -> factory) for the
# 'olympus.adversary' package. Mutated in place by `register_adversary`.
registered_adversary = fetch_factories('olympus.adversary', __file__)


def known_adversary():
    """List the names of every registered adversary.

    Returns ``registered_adversary.keys()`` as-is. For a plain dict this is a
    live view that also reflects later ``register_adversary`` calls.
    """
    adversary_registry = registered_adversary
    return adversary_registry.keys()


def register_adversary(name, factory, override=False):
    """Store an adversary ``factory`` in ``registered_adversary`` under ``name``.

    Args:
        name: key under which the factory is registered.
        factory: value stored in the registry for ``name``.
        override: if ``name`` is already registered, replace the existing
            entry instead of keeping it. A warning is emitted either way.

    The registry is only mutated through item assignment (never rebound),
    so no ``global`` declaration is required.
    """
    if name in registered_adversary:
        warning(f'{name} was already registered, use override=True to ignore')

        if not override:
            # Keep the existing registration untouched.
            return

    registered_adversary[name] = factory


class RegisteredAdversaryNotFound(Exception):
    """Presumably raised when a requested adversary name is not registered."""


class Adversary:
    def __init__(self, name, model, min_confidence=0.90, max_iter=10):