def register_hpo(client, namespace, function, config, defaults):
    """Push one HPO work item onto ``WORK_QUEUE`` under ``namespace``.

    The message bundles a remote call that builds an ``HPOptimizer`` from
    ``config``, an empty optimizer state, and a remote call to ``function``
    with ``defaults`` as keyword arguments.

    Returns whatever ``client.push`` returns.
    """
    optimizer_call = make_remote_call(HPOptimizer, **config)
    work_call = make_remote_call(function, **defaults)
    message = dict(
        hpo=optimizer_call,
        hpo_state=None,  # no state yet: the optimizer starts fresh
        work=work_call,
        experiment=namespace,
    )
    return client.push(WORK_QUEUE, namespace, message=message, mtype=HPO_ITEM)
def test_hpo_serializable(model_type):
    """Check that a remotely-resumed HPO yields the same trials as a local run.

    The HPO is first executed through a ``TrialWorker``. There the optimizer
    must be serialized and resumed between each branin call. The same HPO
    is then run locally in a single pass. Both runs must end with identical
    trials.
    """
    namespace = 'test_hpo_serializable'
    n_init = 2
    count = 10

    # First run using a remote worker where serialization is necessary
    # and for which the HPO is resumed between each branin call.
    remote_hpo = build_robo(model_type, n_init=n_init, count=count)
    hpo_item = {
        'hpo': make_remote_call(HPOptimizer, **remote_hpo.kwargs),
        'hpo_state': None,
        'work': make_remote_call(branin),
        'experiment': namespace,
    }

    client = new_client(URI, DATABASE)
    client.push(WORK_QUEUE, namespace, message=hpo_item, mtype=HPO_ITEM)

    worker = TrialWorker(URI, DATABASE, 0, None)
    worker.max_retry = 0
    worker.timeout = 1
    worker.run()

    messages = client.monitor().unread_messages(RESULT_QUEUE, namespace)
    # Take the first HPO_ITEM result. Using next() with a default turns an
    # empty result queue into a clean assertion failure instead of a
    # NameError on an unbound loop variable.
    result = next((m for m in messages if m.mtype == HPO_ITEM), None)
    assert result is not None, 'HPO not completed'

    worker_hpo = build_robo(model_type)
    worker_hpo.load_state_dict(result.message['hpo_state'])
    assert len(worker_hpo.trials) == count

    # Then run locally, where BO is not resumed.
    local_hpo = build_robo(model_type, n_init=n_init, count=count)
    i = 0
    while local_hpo.remaining() and i < local_hpo.hpo.count:
        for sample in local_hpo.suggest():
            local_hpo.observe(sample['uid'], branin(**sample))
        i += 1

    assert i == local_hpo.hpo.count

    # Although the remote worker was resumed many times, it should give the
    # same results as the local one, which was executed in a single run.
    assert worker_hpo.trials == local_hpo.trials
def run_hpo(self, hpo, fun, *args, **kwargs):
    """Launch a master HPO.

    Wrap ``fun(*args, **kwargs)`` as a remote call and tag it with this
    object's experiment name. Then hand ``hpo``, the client and that state
    to ``HPOManager.sync_hpo``.

    Returns the object produced by ``HPOManager.sync_hpo``.
    """
    work_call = make_remote_call(fun, *args, **kwargs)
    state = dict(work=work_call, experiment=self.experiment)
    return HPOManager.sync_hpo(hpo, self.client, state)
def queue_work(self, fun, *args, namespace=None, **kwargs):
    """Queue a single ``fun(*args, **kwargs)`` work item for the workers.

    ``namespace`` defaults to ``self.experiment`` when not given.
    Returns whatever ``self.client.push`` returns.
    """
    work_call = make_remote_call(fun, *args, **kwargs)
    target = self.experiment if namespace is None else namespace
    return self.client.push(WORK_QUEUE, target, mtype=WORK_ITEM, message=work_call)
def register_hpo_replicates(client, function, hpo_namespace, configs):
    """Queue replicate trials, skipping uids already known in their namespace.

    ``configs`` maps a replicate type to an iterable of config dicts, each
    carrying a ``'uid'`` key. Every type is pushed to
    ``env(hpo_namespace, replicate_type)``.

    Returns the set of uids that were newly pushed.
    """
    newly_pushed = set()
    for kind, kind_configs in configs.items():
        target_ns = env(hpo_namespace, kind)
        known_uids = set(fetch_all_trial_info(client, target_ns).keys())
        for cfg in kind_configs:
            uid = cfg['uid']
            if uid not in known_uids:
                client.push(WORK_QUEUE, target_ns,
                            make_remote_call(function, **cfg),
                            mtype=WORK_ITEM)
                newly_pushed.add(uid)
                # Also guards against duplicate uids within kind_configs.
                known_uids.add(uid)
            else:
                print(f'trial {cfg["uid"]} already registered')
    return newly_pushed
def register_tests(client, namespace, function, configs):
    """Queue test trials for every HPO run, skipping already-registered uids.

    Parameters
    ----------
    client:
        Queue client used for ``push`` and for ``fetch_registered`` lookups.
    namespace:
        Not used by this function; kept for interface compatibility.
    function:
        Callable wrapped with ``make_remote_call(function, **config)``.
    configs:
        Mapping ``hpo -> iterable of (hpo_namespace, test_configs)`` pairs.
        Each ``test_configs`` is an iterable of config dicts with a
        ``'uid'`` key.

    Returns
    -------
    Nested mapping ``hpo -> hpo_namespace -> set of newly-pushed uids``.
    """
    new_registered = defaultdict(lambda: defaultdict(set))
    for hpo, hpo_runs in configs.items():
        for hpo_namespace, test_configs in hpo_runs:
            test_namespace = env(hpo_namespace, 'test')
            # Check the namespace the tests are pushed to (as register and
            # register_hpo_replicates do), not the HPO namespace itself.
            registered = fetch_registered(client, [test_namespace])
            # Iterate this run's test configs; the outer `configs` mapping
            # only holds hpo names as keys.
            for config in test_configs:
                if config['uid'] in registered:
                    print(f'trial {config["uid"]} already registered')
                    continue

                client.push(WORK_QUEUE, test_namespace,
                            make_remote_call(function, **config),
                            mtype=WORK_ITEM)
                new_registered[hpo][hpo_namespace].add(config['uid'])
                # Prevent duplicate uids inside test_configs from being
                # pushed twice.
                registered.add(config['uid'])

    return new_registered
def register(client, function, namespace, variables):
    """Queue one trial per config, skipping uids that are already registered.

    ``variables`` maps a variable name to an iterable of config dicts, each
    with a ``'uid'`` key. Configs of a variable are pushed to
    ``env(namespace, variable)``. Already-registered uids are fetched once,
    up front, across all of those namespaces.

    Returns the set of uids that were newly pushed.
    """
    target_namespaces = [env(namespace, name) for name in variables.keys()]
    seen = fetch_registered(client, target_namespaces)

    pushed = set()
    for name, name_configs in variables.items():
        for cfg in name_configs:
            uid = cfg['uid']
            if uid not in seen:
                client.push(WORK_QUEUE, env(namespace, name),
                            make_remote_call(function, **cfg),
                            mtype=WORK_ITEM)
                pushed.add(uid)
                seen.add(uid)
            else:
                print(f'trial {cfg["uid"]} already registered')
    return pushed