def __init__(self, service):
    """Create the driver-status reporter thread for *service*.

    The thread is marked as a daemon so it never blocks interpreter
    shutdown, and starts in the not-running state.
    """
    super(DriverStatusThread, self).__init__()

    # how often (seconds) to emit a status summary; configurable, default 60
    self.summary_interval = float(config('summary_interval', '60'))

    self.service = service  # the search-driver service being monitored
    self.running = False    # flipped to True when the thread is started
    self.daemon = True      # daemon thread: do not keep the process alive
def serve(addr, search_id, spaces_dir, models_dir, on_next=None, on_report=None, on_summary=None):
    """Start a gRPC SearchDriver server at *addr* and begin the search.

    Builds a ``SearchDriverService`` over *spaces_dir*/*models_dir*, registers
    it with a thread-pool-backed gRPC server listening on an insecure port,
    then starts the search identified by *search_id* with the given callbacks.

    Returns:
        (server, service) — the running gRPC server and the driver service.
    """
    # imported lazily so the module can be used without grpc installed
    import grpc
    from concurrent import futures

    # gRPC worker thread count is configurable, default 10
    max_workers = int(config('grpc_worker_count', '10'))

    service = SearchDriverService(spaces_dir, models_dir)

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    spec_pb2_grpc.add_SearchDriverServicer_to_server(service, server)
    server.add_insecure_port(addr)
    server.start()

    service.start_search(search_id, on_next, on_report, on_summary)

    return server, service
def dispatch(self, hyper_model, X, y, X_val, y_val, cv, num_folds, max_trials, dataset_id, trial_store,
             **fit_kwargs):
    """Run up to *max_trials* search trials on a local Dask executor pool.

    Samples spaces from ``hyper_model.searcher``, skips samples already in
    history or found in *trial_store*, and pushes the rest to a
    ``DaskExecutorPool`` whose workers run ``hyper_model._run_trial``.
    Trial start/done callbacks update the model history and fire the
    configured ``hyper_model.callbacks``.

    Returns:
        The next trial number after the loop ends (i.e. roughly the number
        of trials issued + 1).

    Raises:
        AssertionError: if any of X/y/X_val/y_val is a dask collection —
            this dispatcher runs trials on plain (non-dask) data only.
    """
    assert not any(dask.is_dask_collection(i) for i in (X, y, X_val, y_val)), \
        f'{self.__class__.__name__} does not support to run trial with dask collection.'

    # tunables, all read from config with defaults
    queue_size = int(config('search_queue', '1'))        # max queued (un-started) trials
    worker_count = int(config('search_executors', '3'))  # parallel trial workers
    retry_limit = int(config('search_retry', '1000'))    # max consecutive duplicate samples

    failed_counter = Counter()
    success_counter = Counter()

    def on_trial_start(trial_item):
        # record start time and notify callbacks that the trial began
        trial_item.start_at = time.time()
        if logger.is_info_enabled():
            msg = f'Start trial {trial_item.trial_no}, space_id={trial_item.space_id}' \
                  + f',model_file={trial_item.model_file}'
            logger.info(msg)
        for callback in hyper_model.callbacks:
            # callback.on_build_estimator(hyper_model, space_sample, estimator, trial_no) #fixme
            callback.on_trial_begin(hyper_model, trial_item.space_sample, trial_item.trial_no)

    def on_trial_done(trial_item):
        # record completion, update history on success, and fire callbacks;
        # a reward of 0 or NaN is treated as a failed trial
        trial_item.done_at = time.time()

        if trial_item.reward != 0 and not math.isnan(trial_item.reward):  # success
            improved = hyper_model.history.append(trial_item)
            for callback in hyper_model.callbacks:
                callback.on_trial_end(hyper_model, trial_item.space_sample,
                                      trial_item.trial_no, trial_item.reward,
                                      improved, trial_item.elapsed)
            success_counter()
        else:
            for callback in hyper_model.callbacks:
                callback.on_trial_error(hyper_model, trial_item.space_sample, trial_item.trial_no)
            failed_counter()

        if logger.is_info_enabled():
            elapsed = '%.3f' % (trial_item.done_at - trial_item.start_at)
            msg = f'Trial {trial_item.trial_no} done with reward={trial_item.reward}, ' \
                  f'elapsed {elapsed} seconds\n'
            logger.info(msg)

        if trial_store is not None:
            trial_store.put(dataset_id, trial_item)

    pool = DaskExecutorPool(worker_count, queue_size,
                            on_trial_start, on_trial_done,
                            hyper_model._run_trial,
                            X, y, X_val, y_val, fit_kwargs)
    pool.start()

    trial_no = 1
    retry_counter = 0

    while trial_no <= max_trials and pool.running:
        # throttle: don't sample faster than the workers consume
        if pool.qsize >= queue_size:
            time.sleep(0.1)
            continue

        space_sample = hyper_model.searcher.sample()
        if hyper_model.history.is_existed(space_sample):
            if retry_counter >= retry_limit:
                # FIX: report the configured limit instead of a hardcoded 1000
                logger.info(f'Unable to take valid sample and exceed the retry limit {retry_limit}.')
                break
            trial = hyper_model.history.get_trial(space_sample)
            for callback in hyper_model.callbacks:
                callback.on_skip_trial(hyper_model, space_sample, trial_no, 'trial_existed',
                                       trial.reward, False, trial.elapsed)
            retry_counter += 1
            continue

        try:
            if trial_store is not None:
                trial = trial_store.get(dataset_id, space_sample)
                if trial is not None:
                    # cache hit: replay the stored result without running the trial
                    reward = trial.reward
                    elapsed = trial.elapsed
                    trial = Trial(space_sample, trial_no, reward, elapsed)
                    improved = hyper_model.history.append(trial)
                    hyper_model.searcher.update_result(space_sample, reward)
                    for callback in hyper_model.callbacks:
                        callback.on_skip_trial(hyper_model, space_sample, trial_no, 'hit_trial_store',
                                               reward, improved, elapsed)
                    # FIX: do NOT increment trial_no here — `continue` still runs the
                    # `finally` block below, so the old `trial_no += 1; continue`
                    # double-incremented and skipped trial numbers.
                    continue

            model_file = '%s/%05d_%s.pkl' % (self.models_dir, trial_no, space_sample.space_id)
            item = DaskTrialItem(space_sample, trial_no, model_file=model_file)
            pool.push(item)

            if logger.is_info_enabled():
                logger.info(f'Found trial {trial_no}, queue size: {pool.qsize}')
        except EarlyStoppingError:
            pool.stop()
            break
        except KeyboardInterrupt:
            pool.stop()
            pool.interrupted = True
            print('KeyboardInterrupt')
            break
        except Exception as e:
            import traceback
            msg = f'{">" * 20} Search trial {trial_no} failed! {"<" * 20}\n' \
                  + f'{e.__class__.__name__}: {e}\n' \
                  + traceback.format_exc() \
                  + '*' * 50
            logger.error(msg)
        finally:
            # single authoritative increment per issued/attempted trial
            trial_no += 1
            retry_counter = 0

    # wait trials
    if pool.running:
        logger.info('Search done, wait trial tasks.')
        pool.push(None)  # mark end
        pool.join()

    if logger.is_info_enabled():
        logger.info(f'Search and all trials done, {success_counter.value} success, '
                    f'{failed_counter.value} failed.')

    return trial_no
def dispatch(self, hyper_model, X, y, X_eval, y_eval, max_trails, dataset_id, trail_store, **fit_kwargs):
    """Drive up to *max_trails* search trials through the gRPC driver service.

    Starts (or reuses) a SearchDriver gRPC server at ``self.address``, then
    samples spaces from ``hyper_model.searcher`` and queues them on the
    service. Results come back asynchronously via the ``on_next`` /
    ``on_report`` / ``on_summary`` callbacks registered with the service.

    Returns:
        The next trail number after dispatching finishes or is interrupted.
    """

    def on_next_space(item):
        # a worker picked up a queued space: notify trail-begin callbacks
        for cb in hyper_model.callbacks:
            # cb.on_build_estimator(hyper_model, space_sample, estimator, trail_no)
            cb.on_trail_begin(hyper_model, item.space_sample, item.trail_no)

    def on_report_space(item):
        # a worker reported a finished trail: update history/searcher/store
        if item.success:
            elapsed = item.report_at - item.start_at
            trail = Trail(item.space_sample, item.trail_no, item.reward, elapsed)
            # print(f'trail result:{trail}')
            improved = hyper_model.history.append(trail)
            if improved and logger.is_info_enabled():
                logger.info(
                    f'>>>improved: reward={item.reward}, trail_no={item.trail_no}, space_id={item.space_id}'
                )
            hyper_model.searcher.update_result(item.space_sample, item.reward)
            if trail_store is not None:
                trail_store.put(dataset_id, trail)
            for cb in hyper_model.callbacks:
                cb.on_trail_end(hyper_model, item.space_sample, item.trail_no,
                                item.reward, improved, elapsed)
        else:
            for cb in hyper_model.callbacks:
                # FIX: report the failed item's own sample/number; the old code
                # referenced the outer loop's `space_sample`/`trail_no` closure
                # variables, which belong to a different (later) trail.
                cb.on_trail_error(hyper_model, item.space_sample, item.trail_no)

    def on_summary():
        # one-line summary of the best trail so far, or None if none yet
        t = hyper_model.get_best_trail()
        if t:
            detail = f'reward={t.reward}, trail_no={t.trail_no}, space_id={t.space_sample.space_id}'
            return f'best: {detail}'
        else:
            return None

    def do_clean():
        # stop the status thread and flush a final summary
        # shutdown grpc server
        search_service.status_thread.stop()
        search_service.status_thread.report_summary()
        # server.stop(grace=1.0)

    if 'search_id' in fit_kwargs:
        search_id = fit_kwargs.pop('search_id')
    else:
        # auto-generate a sequential search id
        global _search_counter
        _search_counter += 1
        search_id = 'search-%02d' % _search_counter

    if logger.is_info_enabled():
        logger.info(f'start driver server at {self.address}')
    server, search_service = get_or_serve(self.address, search_id,
                                          self.spaces_dir, self.models_dir,
                                          on_next=on_next_space,
                                          on_report=on_report_space,
                                          on_summary=on_summary)

    trail_no = 1
    retry_counter = 0
    queue_size = int(config('search_queue', '1'))  # max queued trails before throttling

    while trail_no <= max_trails:
        space_sample = hyper_model.searcher.sample()
        if hyper_model.history.is_existed(space_sample):
            if retry_counter >= 1000:
                if logger.is_info_enabled():
                    logger.info(
                        f'Unable to take valid sample and exceed the retry limit 1000.'
                    )
                break
            trail = hyper_model.history.get_trail(space_sample)
            for callback in hyper_model.callbacks:
                # NOTE(review): reason string 'trail_exsited' looks like a typo of
                # 'trail_existed' — kept as-is since callbacks may match on it.
                callback.on_skip_trail(hyper_model, space_sample, trail_no, 'trail_exsited',
                                       trail.reward, False, trail.elapsed)
            retry_counter += 1
            continue

        try:
            if trail_store is not None:
                trail = trail_store.get(dataset_id, space_sample)
                if trail is not None:
                    # cache hit: replay stored result without dispatching
                    reward = trail.reward
                    elapsed = trail.elapsed
                    trail = Trail(space_sample, trail_no, reward, elapsed)
                    improved = hyper_model.history.append(trail)
                    hyper_model.searcher.update_result(space_sample, reward)
                    for callback in hyper_model.callbacks:
                        callback.on_skip_trail(hyper_model, space_sample, trail_no, 'hit_trail_store',
                                               reward, improved, elapsed)
                    # FIX: do NOT increment trail_no here — `continue` still runs
                    # the `finally` block, so the old `trail_no += 1; continue`
                    # double-incremented and skipped trail numbers.
                    continue

            search_service.add(trail_no, space_sample)

            # wait for queued trail
            while search_service.queue_size() >= queue_size:
                time.sleep(0.1)
        except EarlyStoppingError:
            break
            # TODO: early stopping
        except KeyboardInterrupt:
            do_clean()
            return trail_no
        except Exception as e:
            if logger.is_warning_enabled():
                import traceback
                msg = f'{e.__class__.__name__}: {e}'
                logger.warning(f'{">" * 20} Trail failed! {"<" * 20}')
                logger.warning(msg + '\n' + traceback.format_exc())
                logger.warning('*' * 50)
        finally:
            # single authoritative increment per issued/attempted trail
            trail_no += 1
            retry_counter = 0

    if logger.is_info_enabled():
        logger.info("-" * 20 + 'no more space to search, waiting trails ...')
    try:
        # drain: wait for in-flight trails to report back
        while search_service.running_size() > 0:
            # if logger.is_info_enabled():
            #     logger.info(f"wait ... {search_service.running_size()} samples found.")
            time.sleep(0.1)
    except KeyboardInterrupt:
        return trail_no
    finally:
        do_clean()

    if logger.is_info_enabled():
        logger.info('-' * 20 + ' all trails done ' + '-' * 20)

    return trail_no