Example #1
0
    def set_basic_conf(self):
        """Build a Dragonfly-backed search algorithm and its toy objective.

        Returns:
            A ``(search_alg, cost)`` pair: a ``DragonflySearch`` configured to
            minimise the ``loss`` metric over a 2-D float domain
            (``height`` in [-10, 10], ``width`` in [0, 20]), and the trainable
            ``cost`` that reports that loss.
        """
        from dragonfly.opt.gp_bandit import EuclideanGPBandit
        from dragonfly.exd.experiment_caller import EuclideanFunctionCaller
        from dragonfly import load_config

        def cost(space, reporter):
            # The sampled configuration is a 2-element sequence under "point".
            h, w = space["point"]
            reporter(loss=(h - 14)**2 - abs(w - 3))

        bounds = (("height", -10, 10), ("width", 0, 20))
        variables = [
            {"name": label, "type": "float", "min": low, "max": high}
            for label, low, high in bounds
        ]
        euclidean_domain = load_config(
            {"domain": variables}).domain.list_of_domains[0]

        # No objective is given to the caller: points are fed back via ask/tell.
        bandit = EuclideanGPBandit(
            EuclideanFunctionCaller(None, euclidean_domain),
            ask_tell_mode=True)
        search_alg = DragonflySearch(
            bandit,
            metric="loss",
            mode="min",
            max_concurrent=1000,  # Here to avoid breaking back-compat.
        )
        return search_alg, cost
Example #2
0
        "type": "float",
        "min": 0,
        "max": 7
    }, {
        "name": "Li2SO4_vol",
        "type": "float",
        "min": 0,
        "max": 7
    }, {
        "name": "NaClO4_vol",
        "type": "float",
        "min": 0,
        "max": 7
    }]

    domain_config = load_config({"domain": domain_vars})

    func_caller = EuclideanFunctionCaller(
        None, domain_config.domain.list_of_domains[0])
    optimizer = EuclideanGPBandit(func_caller, ask_tell_mode=True)
    algo = DragonflySearch(optimizer,
                           max_concurrent=4,
                           metric="objective",
                           mode="max")
    scheduler = AsyncHyperBandScheduler(metric="objective", mode="max")
    run(objective,
        name="dragonfly_search",
        search_alg=algo,
        scheduler=scheduler,
        **config)
Example #3
0
def maximise_function(func, domain, max_capital, config=None, options=None):
    """
    Maximises a function 'func' over the domain 'domain' by running
    gpb_from_func_caller with a single synthetic worker.

    Inputs:
      func: The function to be maximised.
      domain: The domain over which the function should be maximised; should be an
              instance of the Domain class in exd/domains.py.
              If domain is a list of the form [[l1, u1], [l2, u2], ...] where li < ui,
              then a Euclidean domain is created with lower bounds li and upper bounds
              ui along each dimension.
      max_capital: The maximum capital (budget) available for optimisation.
      config: Configuration parameters, as typically returned by
              exd.cp_domain_utils.load_config_file. config can be None only if domain
              is a EuclideanDomain object.
      options: Additional hyper-parameters for optimisation.
      * Alternatively, domain could be None if config is either a path_name to a
        configuration file or has configuration parameters.
    Returns:
      opt_val: The maximum value found during the optimisation procedure.
      opt_pt: The corresponding optimum point, mapped back to raw coordinates.
      history: A record of the optimisation procedure, including the points evaluated
               and the values at each time step. For a Euclidean domain with no
               config, curr_opt_points and query_points are rewritten in raw
               coordinates; otherwise they are left untouched and raw copies are
               stored in curr_opt_points_raw and query_points_raw.
    """
    original_func = func
    # Normalise domain/config; the objective may be wrapped during preprocessing.
    domain, processed_funcs, config, _ = _preprocess_arguments(
        domain, [func], config)
    func = processed_funcs[0]

    # Pick a function caller matching the domain type.
    if domain.get_type() != 'euclidean':
        func_caller = CPFunctionCaller(
            func,
            domain,
            raw_func=original_func,
            domain_orderings=config.domain_orderings)
    else:
        func_caller = EuclideanFunctionCaller(func, domain, vectorised=False)

    # Optimise with a single synthetic worker.
    worker_manager = SyntheticWorkerManager(num_workers=1)
    opt_val, opt_pt, history = gpb_from_func_caller(
        func_caller, worker_manager, max_capital, is_mf=False, options=options)

    # Map results from the optimiser's processed space back to raw coordinates.
    if domain.get_type() == 'euclidean' and config is None:
        to_raw = func_caller.get_raw_domain_coords
        opt_pt = to_raw(opt_pt)
        history.curr_opt_points = list(map(to_raw, history.curr_opt_points))
        history.query_points = list(map(to_raw, history.query_points))
        return opt_val, opt_pt, history

    def _to_raw(point):
        return get_raw_from_processed_via_config(point, config)

    opt_pt = _to_raw(opt_pt)
    history.curr_opt_points_raw = list(map(_to_raw, history.curr_opt_points))
    history.query_points_raw = list(map(_to_raw, history.query_points))
    return opt_val, opt_pt, history
Example #4
0
def maximise_multifidelity_function(func,
                                    fidel_space,
                                    domain,
                                    fidel_to_opt,
                                    fidel_cost_func,
                                    max_capital,
                                    config=None,
                                    options=None):
    """
    Maximises a multi-fidelity function 'func' over the domain 'domain' and fidelity
    space 'fidel_space' by running gpb_from_func_caller (is_mf=True) with a single
    synthetic worker.

    Inputs:
      func: The function to be maximised. Takes two arguments func(z, x) where z is a
            member of the fidelity space and x is a member of the domain.
      fidel_space: The fidelity space from which the approximations are obtained.
                   Should be an instance of the Domain class in exd/domains.py.
                   If of the form [[l1, u1], [l2, u2], ...] where li < ui, then a
                   Euclidean domain is created with lower bounds li and upper bounds
                   ui along each dimension.
      domain: The domain over which the function should be maximised; should be an
              instance of the Domain class in exd/domains.py.
              If domain is a list of the form [[l1, u1], [l2, u2], ...] where li < ui,
              then a Euclidean domain is created with lower bounds li and upper bounds
              ui along each dimension.
      fidel_to_opt: The point in the fidelity space at which we wish to maximise func.
      fidel_cost_func: Cost function associated with the fidelity space; it is
                       preprocessed and handed to the function caller.
      max_capital: The maximum capital (budget) available for optimisation.
      config: Configuration parameters, as typically returned by
              exd.cp_domain_utils.load_config_file. config can be None only if domain
              is a EuclideanDomain object.
      options: Additional hyper-parameters for optimisation.
      * Alternatively, domain and fidelity space could be None if config is either a
        path_name to a configuration file or has configuration parameters.
    Returns:
      opt_val: The maximum value found during the optimisation procedure.
      opt_pt: The corresponding optimum point (domain part only), in raw coordinates.
      history: A record of the optimisation procedure, including the points evaluated
               and the values at each time step. For a Euclidean domain with no
               config, curr_opt_points and query_points are rewritten in raw
               coordinates. Otherwise query_fidels and query_points are overwritten
               with raw values (None where either part was None) and the raw optimum
               trajectory is stored in curr_opt_points_raw.
    """
    original_func = func
    # Normalise all spaces, cost function and target fidelity in one go.
    (fidel_space, domain, processed_funcs, fidel_cost_func, fidel_to_opt,
     config, _) = _preprocess_multifidelity_arguments(
         fidel_space, domain, [func], fidel_cost_func, fidel_to_opt, config)
    func = processed_funcs[0]

    # A Euclidean caller is only possible when both spaces are Euclidean.
    both_euclidean = (fidel_space.get_type() == 'euclidean'
                      and domain.get_type() == 'euclidean')
    if both_euclidean:
        func_caller = EuclideanFunctionCaller(
            func,
            domain,
            vectorised=False,
            raw_fidel_space=fidel_space,
            fidel_cost_func=fidel_cost_func,
            raw_fidel_to_opt=fidel_to_opt)
    else:
        func_caller = CPFunctionCaller(
            func,
            domain,
            '',
            raw_func=original_func,
            domain_orderings=config.domain_orderings,
            fidel_space=fidel_space,
            fidel_cost_func=fidel_cost_func,
            fidel_to_opt=fidel_to_opt,
            fidel_space_orderings=config.fidel_space_orderings)

    opt_val, opt_pt, history = gpb_from_func_caller(
        func_caller,
        SyntheticWorkerManager(num_workers=1),
        max_capital,
        is_mf=True,
        options=options)

    # Euclidean domain without a config: only domain coordinates are mapped back.
    if domain.get_type() == 'euclidean' and config is None:
        to_raw = func_caller.get_raw_domain_coords
        opt_pt = to_raw(opt_pt)
        history.curr_opt_points = [to_raw(pt) for pt in history.curr_opt_points]
        history.query_points = [to_raw(pt) for pt in history.query_points]
        return opt_val, opt_pt, history

    def _raw_fidel_and_point(fidel, pt):
        """ Maps a processed (fidelity, point) pair to raw coordinates; in the
        multi-fidelity setting either part may be None, giving (None, None). """
        if fidel is None or pt is None:
            return None, None
        return get_raw_from_processed_via_config((fidel, pt), config)

    opt_pt = _raw_fidel_and_point(fidel_to_opt, opt_pt)[1]
    history.curr_opt_points_raw = [
        _raw_fidel_and_point(fidel_to_opt, pt)[1]
        for pt in history.curr_opt_points
    ]
    raw_pairs = [
        _raw_fidel_and_point(fidel, pt)
        for fidel, pt in zip(history.query_fidels, history.query_points)
    ]
    history.query_fidels = [pair[0] for pair in raw_pairs]
    history.query_points = [pair[1] for pair in raw_pairs]
    return opt_val, opt_pt, history
Example #5
0
def _ask_tell_single_fidelity(opt, objective, max_capital):
    """ Drives `opt` through `max_capital` single-fidelity ask/tell rounds.

    Each round asks `opt` for one point, evaluates `objective` on it, tells the
    result back and prints it; the best (value, point) pair seen is printed at
    the end.
    """
    # User continually asks for the next point to evaluate, then tells the optimizer the
    # new result to perform Bayesian optimisation.
    best_x, best_y = None, float('-inf')
    for _ in range(max_capital):
        # Optionally, you can add an integer argument `n_points` to ask to have it return
        # `n_points` number of points. These points will be returned as a list.
        # No argument for `n_points` returns a single point from ask.
        x = opt.ask()
        y = objective(x)
        opt.tell([(x, y)])
        print('x: %s, y: %s' % (x, y))
        if y > best_y:
            best_x, best_y = x, y
    print("Optimal Value: %s, Optimal Point: %s" % (best_y, best_x))


def main():
    """ Runs the ask-tell demo for PROBLEM and compares it with the one-shot API.

    Loads the configuration selected by the module-level PROBLEM constant,
    optimises the objective for `max_capital` rounds through the ask/tell
    interface, then runs the equivalent maximise_function /
    maximise_multifidelity_function call and prints its result.

    Raises:
      ValueError: if `opt_method` is not supported for the selected problem.
    """
    # Load configuration file
    objective, config_file, mf_cost = _CHOOSER_DICT[PROBLEM]
    config = load_config_file(config_file)

    # Specify optimisation method -----------------------------------------------------
    opt_method = 'bo'
    # opt_method = 'ga'
    # opt_method = 'rand'

    # Optimise
    max_capital = 60
    domain, domain_orderings = config.domain, config.domain_orderings
    if PROBLEM in ['3d', '5d']:
        # Create function caller.
        # Note there is no function passed in to the Function Caller object.
        func_caller = CPFunctionCaller(None,
                                       domain,
                                       domain_orderings=domain_orderings)

        if opt_method == 'bo':
            opt = gp_bandit.CPGPBandit(func_caller, ask_tell_mode=True)
        elif opt_method == 'ga':
            opt = cp_ga_optimiser.CPGAOptimiser(func_caller,
                                                ask_tell_mode=True)
        elif opt_method == 'rand':
            opt = random_optimiser.CPRandomOptimiser(func_caller,
                                                     ask_tell_mode=True)
        else:
            # Fail loudly instead of leaving `opt` unbound.
            raise ValueError("Invalid opt_method %s" % (opt_method))
        opt.initialise()

        # Optimize using the ask-tell interface
        _ask_tell_single_fidelity(opt, objective, max_capital)

        # Compare results with the maximise_function API
        print("-------------")
        print("Compare with maximise_function API:")
        opt_val, opt_pt, history = maximise_function(objective,
                                                     config.domain,
                                                     max_capital,
                                                     opt_method=opt_method,
                                                     config=config)

    elif PROBLEM == '3d_euc':
        # Create function caller.
        # Note there is no function passed in to the Function Caller object.
        domain = domain.list_of_domains[0]
        func_caller = EuclideanFunctionCaller(None, domain)

        # Only BO is wired up for the Euclidean domain here; reject every other
        # method ('ga', 'rand', ...) up front so `opt` is never left unbound.
        if opt_method != 'bo':
            raise ValueError("Invalid opt_method %s" % (opt_method))
        opt = gp_bandit.EuclideanGPBandit(func_caller, ask_tell_mode=True)
        opt.initialise()

        # Optimize using the ask-tell interface
        _ask_tell_single_fidelity(opt, objective, max_capital)

        # Compare results with the maximise_function API
        print("-------------")
        print("Compare with maximise_function API:")
        opt_val, opt_pt, history = maximise_function(objective,
                                                     config.domain,
                                                     max_capital,
                                                     opt_method=opt_method,
                                                     config=config)

    else:
        # Create function caller.
        # Note there is no function passed in to the Function Caller object.
        (ask_tell_fidel_space, ask_tell_domain, _, ask_tell_mf_cost,
         ask_tell_fidel_to_opt, ask_tell_config, _) = preprocess_multifidelity_arguments(
             config.fidel_space, domain, [objective], mf_cost,
             config.fidel_to_opt, config)
        func_caller = CPFunctionCaller(
            None,
            ask_tell_domain,
            domain_orderings=domain_orderings,
            fidel_space=ask_tell_fidel_space,
            fidel_cost_func=ask_tell_mf_cost,
            fidel_to_opt=ask_tell_fidel_to_opt,
            fidel_space_orderings=config.fidel_space_orderings,
            config=ask_tell_config)
        if opt_method == 'bo':
            opt = gp_bandit.CPGPBandit(func_caller,
                                       is_mf=True,
                                       ask_tell_mode=True)
        else:
            raise ValueError("Invalid opt_method %s" % (opt_method))
        opt.initialise()

        # Optimize using the ask-tell interface
        # User continually asks for the next point to evaluate, then tells the optimizer the
        # new result to perform Bayesian optimisation.
        # Multi-fidelity ask() returns a (fidelity, point) pair and tell() takes
        # (fidelity, point, value) triples.
        best_z, best_x, best_y = None, None, float('-inf')
        for _ in range(max_capital):
            point = opt.ask()
            z, x = point[0], point[1]
            y = objective(z, x)
            opt.tell([(z, x, y)])
            print('z: %s, x: %s, y: %s' % (z, x, y))
            if y > best_y:
                best_z, best_x, best_y = z, x, y
        print("Optimal Value: %s, Optimal Point: %s" % (best_y, best_x))

        # Compare results with the maximise_multifidelity_function API
        print("-------------")
        print("Compare with maximise_multifidelity_function API:")
        opt_val, opt_pt, history = maximise_multifidelity_function(
            objective,
            config.fidel_space,
            config.domain,
            config.fidel_to_opt,
            mf_cost,
            max_capital,
            opt_method=opt_method,
            config=config)

    print('opt_pt: %s' % (str(opt_pt)))
    print('opt_val: %s' % (str(opt_val)))
Example #6
0
def hparams(algorithm, scheduler, num_samples, tensorboard, bare):
    """Run a Ray Tune hyperparameter search over TPM key-exchange settings.

    Searches over the update rule and the K, N, L sizes. Each trial reports
    ``avg_training_steps``, ``avg_eve_score`` and ``mean_loss``; the search
    minimises ``mean_loss``.

    Args:
        algorithm: Search algorithm name: 'hyperopt', 'bayesopt',
            'nevergrad', 'skopt', 'dragonfly', 'bohb' or 'zoopt'.
        scheduler: Trial scheduler name: 'fifo', 'pbt', 'ahb'/'asha', 'hb',
            'bohb' or 'msr'. If ``algorithm`` is 'bohb' and ``scheduler`` is
            none of fifo/pbt/ahb/asha/hb, HyperBandForBOHB is used.
        num_samples: Number of trials, passed to ``tune.run``.
        tensorboard: Whether trials write TensorBoard summaries (also stored
            in the MLENCRYPT_TB environment variable).
        bare: Bare-mode flag, stored in the MLENCRYPT_BARE environment
            variable.

    Raises:
        ValueError: if TensorBoard logging is requested in bare mode, or if
            ``algorithm`` or ``scheduler`` is not recognised (checked before
            connecting to Ray).
    """
    from glob import glob

    import tensorflow.summary
    from tensorflow import random as tfrandom, int64 as tfint64
    from ray import init as init_ray, shutdown as shutdown_ray
    from ray import tune
    from wandb.ray import WandbLogger
    from wandb import sweep as wandbsweep
    from wandb.apis import CommError as wandbCommError

    # less summaries are logged if MLENCRYPT_TB is TRUE (for efficiency)
    # TODO: use tf.summary.record_if?
    environ["MLENCRYPT_TB"] = str(tensorboard).upper()
    environ["MLENCRYPT_BARE"] = str(bare).upper()
    if getenv('MLENCRYPT_TB', 'FALSE') == 'TRUE' and \
            getenv('MLENCRYPT_BARE', 'FALSE') == 'TRUE':
        raise ValueError('TensorBoard logging cannot be enabled in bare mode.')

    logdir = f'logs/hparams/{datetime.now()}'

    # "These results show that K = 3 is the optimal choice for the
    # cryptographic application of neural synchronization. K = 1 and K = 2 are
    # too insecure in regard to the geometric attack. And for K > 3 the effort
    # of A and B grows exponentially with increasing L, while the simple attack
    # is quite successful in the limit K -> infinity. Consequently, one should
    # only use Tree Parity Machines with three hidden units for the neural
    # key-exchange protocol." (Ruttor, 2006)
    # https://arxiv.org/pdf/0711.2411.pdf#page=59

    update_rules = [
        'random-same',
        # 'random-different-A-B-E', 'random-different-A-B',
        'hebbian',
        'anti_hebbian',
        'random_walk'
    ]
    K_bounds = {'min': 4, 'max': 8}
    N_bounds = {'min': 4, 'max': 8}
    L_bounds = {'min': 4, 'max': 8}

    def as_range(bounds):
        """Return the ``(min, max)`` pair of a *_bounds dict, in that order."""
        return bounds['min'], bounds['max']

    def get_session_num(logdir):
        """Return the next free ``run-<n>`` session number under *logdir*.

        glob() order is filesystem-dependent and 'run-10' sorts before
        'run-9', so the largest existing number is used rather than the last
        path returned.
        """
        session_nums = []
        for run_path in glob(join(logdir, "run-*")):
            try:
                session_nums.append(int(run_path.rsplit('-', 1)[-1]))
            except ValueError:
                continue  # not a run-<int> entry
        return max(session_nums) + 1 if session_nums else 0

    def trainable(config, reporter):
        """
        Args:
            config (dict): Parameters provided from the search algorithm
                or variant generation. 'update_rule' is either a rule name or
                a (possibly float) index into update_rules.
            reporter: Callback used to report the trial's final metrics.
        """
        update_rule = config['update_rule']
        if not isinstance(update_rule, str):
            # Numeric searchers (e.g. bayesopt over (0, len(update_rules)))
            # can sample the upper bound itself; clamp to the last valid index
            # instead of raising IndexError.
            update_rule = update_rules[min(int(update_rule),
                                           len(update_rules) - 1)]
        K, N, L = int(config['K']), int(config['N']), int(config['L'])

        run_name = f"run-{get_session_num(logdir)}"
        run_logdir = join(logdir, run_name)
        # for each attack, the TPMs should start with the same weights
        initial_weights_tensors = get_initial_weights(K, N, L)
        training_steps_ls = {}
        eve_scores_ls = {}
        losses_ls = {}
        # for each attack, the TPMs should use the same inputs
        seed = tfrandom.uniform([],
                                minval=0,
                                maxval=tfint64.max,
                                dtype=tfint64).numpy()
        for attack in ['none', 'geometric']:
            initial_weights = {
                tpm: weights_tensor_to_variable(weights, tpm)
                for tpm, weights in initial_weights_tensors.items()
            }
            tfrandom.set_seed(seed)

            if tensorboard:
                attack_logdir = join(run_logdir, attack)
                attack_writer = tensorflow.summary.create_file_writer(
                    attack_logdir)
                with attack_writer.as_default():
                    training_steps, sync_scores, loss = run(
                        update_rule, K, N, L, attack, initial_weights)
            else:
                training_steps, sync_scores, loss = run(
                    update_rule, K, N, L, attack, initial_weights)
            training_steps_ls[attack] = training_steps
            eve_scores_ls[attack] = sync_scores
            losses_ls[attack] = loss
        avg_training_steps = tensorflow.math.reduce_mean(
            list(training_steps_ls.values()))
        avg_eve_score = tensorflow.math.reduce_mean(
            list(eve_scores_ls.values()))
        mean_loss = tensorflow.math.reduce_mean(list(losses_ls.values()))
        reporter(
            avg_training_steps=avg_training_steps.numpy(),
            avg_eve_score=avg_eve_score.numpy(),
            mean_loss=mean_loss.numpy(),
            done=True,
        )

    if algorithm == 'hyperopt':
        from hyperopt import hp as hyperopt
        from hyperopt.pyll.base import scope
        from ray.tune.suggest.hyperopt import HyperOptSearch

        space = {
            'update_rule': hyperopt.choice(
                'update_rule',
                update_rules,
            ),
            'K': scope.int(hyperopt.quniform('K', *as_range(K_bounds), q=1)),
            'N': scope.int(hyperopt.quniform('N', *as_range(N_bounds), q=1)),
            'L': scope.int(hyperopt.quniform('L', *as_range(L_bounds), q=1)),
        }
        algo = HyperOptSearch(
            space,
            metric='mean_loss',
            mode='min',
            # NOTE(review): several of these points lie outside the bounds
            # above (K=3 < K_bounds['min']; N=16 and L=128 exceed their
            # 'max') -- confirm this is intended.
            points_to_evaluate=[
                {
                    'update_rule': 0,
                    'K': 3,
                    'N': 16,
                    'L': 8
                },
                {
                    'update_rule': 0,
                    'K': 8,
                    'N': 16,
                    'L': 8
                },
                {
                    'update_rule': 0,
                    'K': 8,
                    'N': 16,
                    'L': 128
                },
            ],
        )
    elif algorithm == 'bayesopt':
        from ray.tune.suggest.bayesopt import BayesOptSearch

        space = {
            # continuous index; trainable() clamps it into range
            'update_rule': (0, len(update_rules)),
            'K': as_range(K_bounds),
            'N': as_range(N_bounds),
            'L': as_range(L_bounds),
        }
        algo = BayesOptSearch(
            space,
            metric="mean_loss",
            mode="min",
            # TODO: what is utility_kwargs for and why is it needed?
            utility_kwargs={
                "kind": "ucb",
                "kappa": 2.5,
                "xi": 0.0
            })
    elif algorithm == 'nevergrad':
        from ray.tune.suggest.nevergrad import NevergradSearch
        from nevergrad import optimizers
        from nevergrad import p as ngp

        algo = NevergradSearch(
            optimizers.TwoPointsDE(
                ngp.Instrumentation(
                    update_rule=ngp.Choice(update_rules),
                    K=ngp.Scalar(lower=K_bounds['min'],
                                 upper=K_bounds['max']).set_integer_casting(),
                    N=ngp.Scalar(lower=N_bounds['min'],
                                 upper=N_bounds['max']).set_integer_casting(),
                    L=ngp.Scalar(lower=L_bounds['min'],
                                 upper=L_bounds['max']).set_integer_casting(),
                )),
            None,  # since the optimizer is already instrumented with kwargs
            metric="mean_loss",
            mode="min")
    elif algorithm == 'skopt':
        from skopt import Optimizer
        from ray.tune.suggest.skopt import SkOptSearch

        optimizer = Optimizer([
            update_rules,
            as_range(K_bounds),
            as_range(N_bounds),
            as_range(L_bounds)
        ])
        algo = SkOptSearch(
            optimizer,
            ["update_rule", "K", "N", "L"],
            metric="mean_loss",
            mode="min",
            points_to_evaluate=[
                ['random-same', 3, 16, 8],
                ['random-same', 8, 16, 8],
                ['random-same', 8, 16, 128],
            ],
        )
    elif algorithm == 'dragonfly':
        # TODO: doesn't work
        # NOTE(review): the domain below mixes a discrete variable with ints,
        # yet only list_of_domains[0] is passed to a *Euclidean* caller and
        # optimiser; the commented-out CP classes are probably required --
        # confirm against the Dragonfly docs.
        from ray.tune.suggest.dragonfly import DragonflySearch
        from dragonfly.exd.experiment_caller import EuclideanFunctionCaller
        from dragonfly.opt.gp_bandit import EuclideanGPBandit
        # from dragonfly.exd.experiment_caller import CPFunctionCaller
        # from dragonfly.opt.gp_bandit import CPGPBandit
        from dragonfly import load_config

        domain_config = load_config({
            "domain": [
                {
                    "name": "update_rule",
                    "type": "discrete",
                    "dim": 1,
                    "items": update_rules
                },
                {
                    "name": "K",
                    "type": "int",
                    "min": K_bounds['min'],
                    "max": K_bounds['max'],
                    # "dim": 1
                },
                {
                    "name": "N",
                    "type": "int",
                    "min": N_bounds['min'],
                    "max": N_bounds['max'],
                    # "dim": 1
                },
                {
                    "name": "L",
                    "type": "int",
                    "min": L_bounds['min'],
                    "max": L_bounds['max'],
                    # "dim": 1
                }
            ]
        })
        func_caller = EuclideanFunctionCaller(
            None, domain_config.domain.list_of_domains[0])
        optimizer = EuclideanGPBandit(func_caller, ask_tell_mode=True)
        algo = DragonflySearch(
            optimizer,
            metric="mean_loss",
            mode="min",
            points_to_evaluate=[
                ['random-same', 3, 16, 8],
                ['random-same', 8, 16, 8],
                ['random-same', 8, 16, 128],
            ],
        )
    elif algorithm == 'bohb':
        from ConfigSpace import ConfigurationSpace
        from ConfigSpace import hyperparameters as CSH
        from ray.tune.suggest.bohb import TuneBOHB

        config_space = ConfigurationSpace()
        config_space.add_hyperparameter(
            CSH.CategoricalHyperparameter("update_rule", choices=update_rules))
        config_space.add_hyperparameter(
            CSH.UniformIntegerHyperparameter(name='K',
                                             lower=K_bounds['min'],
                                             upper=K_bounds['max']))
        config_space.add_hyperparameter(
            CSH.UniformIntegerHyperparameter(name='N',
                                             lower=N_bounds['min'],
                                             upper=N_bounds['max']))
        config_space.add_hyperparameter(
            CSH.UniformIntegerHyperparameter(name='L',
                                             lower=L_bounds['min'],
                                             upper=L_bounds['max']))
        algo = TuneBOHB(config_space, metric="mean_loss", mode="min")
    elif algorithm == 'zoopt':
        from ray.tune.suggest.zoopt import ZOOptSearch
        from zoopt import ValueType

        space = {
            "update_rule":
            (ValueType.DISCRETE, range(0, len(update_rules)), False),
            "K": (ValueType.DISCRETE,
                  range(K_bounds['min'], K_bounds['max'] + 1), True),
            "N": (ValueType.DISCRETE,
                  range(N_bounds['min'], N_bounds['max'] + 1), True),
            "L": (ValueType.DISCRETE,
                  range(L_bounds['min'], L_bounds['max'] + 1), True),
        }
        # TODO: change budget to a large value
        algo = ZOOptSearch(budget=10,
                           dim_dict=space,
                           metric="mean_loss",
                           mode="min")
    else:
        # Previously `algo` stayed unbound and tune.run failed only after
        # connecting to Ray.
        raise ValueError(f"Unknown search algorithm: {algorithm!r}")

    # TODO: use more appropriate arguments for schedulers:
    # https://docs.ray.io/en/master/tune/api_docs/schedulers.html
    if scheduler == 'fifo':
        sched = None  # Tune defaults to FIFO
    elif scheduler == 'pbt':
        from ray.tune.schedulers import PopulationBasedTraining
        from random import randint
        sched = PopulationBasedTraining(
            metric="mean_loss",
            mode="min",
            hyperparam_mutations={
                "update_rule": update_rules,
                "K": lambda: randint(K_bounds['min'], K_bounds['max']),
                "N": lambda: randint(N_bounds['min'], N_bounds['max']),
                "L": lambda: randint(L_bounds['min'], L_bounds['max']),
            })
    elif scheduler in ('ahb', 'asha'):
        # https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#asha-tune-schedulers-ashascheduler
        from ray.tune.schedulers import AsyncHyperBandScheduler
        sched = AsyncHyperBandScheduler(metric="mean_loss", mode="min")
    elif scheduler == 'hb':
        from ray.tune.schedulers import HyperBandScheduler
        sched = HyperBandScheduler(metric="mean_loss", mode="min")
    elif algorithm == 'bohb' or scheduler == 'bohb':
        from ray.tune.schedulers import HyperBandForBOHB
        sched = HyperBandForBOHB(metric="mean_loss", mode="min")
    elif scheduler == 'msr':
        from ray.tune.schedulers import MedianStoppingRule
        sched = MedianStoppingRule(metric="mean_loss", mode="min")
    else:
        raise ValueError(f"Unknown scheduler: {scheduler!r}")

    init_ray(
        address=getenv("ip_head"),
        redis_password=getenv('redis_password'),
    )
    try:
        analysis = tune.run(
            trainable,
            name='mlencrypt_research',
            config={
                "monitor": True,
                "env_config": {
                    "wandb": {
                        "project": "mlencrypt-research",
                        "sync_tensorboard": True,
                    },
                },
            },
            # resources_per_trial={"cpu": 1, "gpu": 3},
            local_dir='./ray_results',
            export_formats=['csv'],  # TODO: add other formats?
            num_samples=num_samples,
            loggers=[
                tune.logger.JsonLogger, tune.logger.CSVLogger,
                tune.logger.TBXLogger, WandbLogger
            ],
            search_alg=algo,
            scheduler=sched,
            queue_trials=True,
        )
        try:
            wandbsweep(analysis)
        except wandbCommError:
            # Deliberate best-effort; see
            # https://docs.wandb.com/sweeps/ray-tune#feature-compatibility
            pass
        best_config = analysis.get_best_config(metric='mean_loss', mode='min')
        print(f"Best config: {best_config}")
    finally:
        # Always release the Ray connection, even if tuning fails.
        shutdown_ray()