Example #1
# Partial listing: only the __init__ of the DriverStatusThread helper is shown.
# It is assumed to subclass threading.Thread (inferred from the daemon flag and
# the super() call); `config` is the module's configuration helper.
class DriverStatusThread(Thread):
    def __init__(self, service):
        super(DriverStatusThread, self).__init__()

        self.service = service
        self.daemon = True
        self.running = False
        self.summary_interval = float(config('summary_interval', '60'))

def serve(addr, search_id, spaces_dir, models_dir, on_next=None, on_report=None, on_summary=None):
    import grpc
    from concurrent import futures

    worker_number = int(config('grpc_worker_count', '10'))
    service = SearchDriverService(spaces_dir, models_dir)
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=worker_number))
    spec_pb2_grpc.add_SearchDriverServicer_to_server(service, server)

    server.add_insecure_port(addr)
    server.start()

    service.start_search(search_id, on_next, on_report, on_summary)

    return server, service
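A minimal usage sketch for the serve() helper above; the address, search id and directories are illustrative placeholders, and shutdown relies on the standard grpc.Server API (wait_for_termination/stop, available in recent grpcio releases).

server, service = serve('127.0.0.1:8010', 'search-01',
                        spaces_dir='/tmp/spaces', models_dir='/tmp/models')
try:
    # block the main thread while the driver serves trial requests
    server.wait_for_termination()
except KeyboardInterrupt:
    server.stop(grace=1.0)  # give in-flight RPCs a short grace period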
Example #3
    def dispatch(self, hyper_model, X, y, X_val, y_val, cv, num_folds,
                 max_trials, dataset_id, trial_store, **fit_kwargs):
        assert not any(dask.is_dask_collection(i) for i in (X, y, X_val, y_val)), \
            f'{self.__class__.__name__} does not support running trials with dask collections.'

        queue_size = int(config('search_queue', '1'))
        worker_count = int(config('search_executors', '3'))
        retry_limit = int(config('search_retry', '1000'))

        failed_counter = Counter()
        success_counter = Counter()

        def on_trial_start(trial_item):
            trial_item.start_at = time.time()
            if logger.is_info_enabled():
                msg = f'Start trial {trial_item.trial_no}, space_id={trial_item.space_id}' \
                      + f',model_file={trial_item.model_file}'
                logger.info(msg)
            for callback in hyper_model.callbacks:
                # callback.on_build_estimator(hyper_model, space_sample, estimator, trial_no) #fixme
                callback.on_trial_begin(hyper_model, trial_item.space_sample,
                                        trial_item.trial_no)

        def on_trial_done(trial_item):
            trial_item.done_at = time.time()

            if trial_item.reward != 0 and not math.isnan(
                    trial_item.reward):  # success
                improved = hyper_model.history.append(trial_item)
                for callback in hyper_model.callbacks:
                    callback.on_trial_end(hyper_model, trial_item.space_sample,
                                          trial_item.trial_no,
                                          trial_item.reward, improved,
                                          trial_item.elapsed)
                success_counter()
            else:
                for callback in hyper_model.callbacks:
                    callback.on_trial_error(hyper_model,
                                            trial_item.space_sample,
                                            trial_item.trial_no)
                failed_counter()

            if logger.is_info_enabled():
                elapsed = '%.3f' % (trial_item.done_at - trial_item.start_at)
                msg = f'Trial {trial_item.trial_no} done with reward={trial_item.reward}, ' \
                      f'elapsed {elapsed} seconds\n'
                logger.info(msg)
            if trial_store is not None:
                trial_store.put(dataset_id, trial_item)

        pool = DaskExecutorPool(worker_count, queue_size, on_trial_start,
                                on_trial_done, hyper_model._run_trial, X, y,
                                X_val, y_val, fit_kwargs)
        pool.start()

        trial_no = 1
        retry_counter = 0

        while trial_no <= max_trials and pool.running:
            if pool.qsize >= queue_size:
                time.sleep(0.1)
                continue

            space_sample = hyper_model.searcher.sample()
            if hyper_model.history.is_existed(space_sample):
                if retry_counter >= retry_limit:
                    logger.info(
                        f'Unable to take valid sample and exceeded the retry limit {retry_limit}.'
                    )
                    break
                trial = hyper_model.history.get_trial(space_sample)
                for callback in hyper_model.callbacks:
                    callback.on_skip_trial(hyper_model, space_sample, trial_no,
                                           'trial_existed', trial.reward,
                                           False, trial.elapsed)
                retry_counter += 1
                continue

            try:
                if trial_store is not None:
                    trial = trial_store.get(dataset_id, space_sample)
                    if trial is not None:
                        reward = trial.reward
                        elapsed = trial.elapsed
                        trial = Trial(space_sample, trial_no, reward, elapsed)
                        improved = hyper_model.history.append(trial)
                        hyper_model.searcher.update_result(
                            space_sample, reward)
                        for callback in hyper_model.callbacks:
                            callback.on_skip_trial(hyper_model, space_sample,
                                                   trial_no, 'hit_trial_store',
                                                   reward, improved, elapsed)
                        trial_no += 1
                        continue

                model_file = '%s/%05d_%s.pkl' % (self.models_dir, trial_no,
                                                 space_sample.space_id)

                item = DaskTrialItem(space_sample,
                                     trial_no,
                                     model_file=model_file)
                pool.push(item)

                if logger.is_info_enabled():
                    logger.info(
                        f'Found trial {trial_no}, queue size: {pool.qsize}')
            except EarlyStoppingError:
                pool.stop()
                break
            except KeyboardInterrupt:
                pool.stop()
                pool.interrupted = True
                print('KeyboardInterrupt')
                break
            except Exception as e:
                import traceback
                msg = f'{">" * 20} Search trial {trial_no} failed! {"<" * 20}\n' \
                      + f'{e.__class__.__name__}: {e}\n' \
                      + traceback.format_exc() \
                      + '*' * 50
                logger.error(msg)
            finally:
                trial_no += 1
                retry_counter = 0

        # wait for the remaining trial tasks
        if pool.running:
            logger.info('Search done, waiting for trial tasks.')
        pool.push(None)  # mark end
        pool.join()

        if logger.is_info_enabled():
            logger.info(
                f'Search and all trials done, {success_counter.value} success, '
                f'{failed_counter.value} failed.')

        return trial_no
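In the example above, success_counter and failed_counter are used both as callables (to increment) and through a .value attribute. The sketch below is an assumption about what such a helper could look like, not the library's actual Counter implementation:

import threading

class Counter(object):
    """Minimal thread-safe counter: call the instance to increment, read .value."""
    def __init__(self):
        self._lock = threading.Lock()
        self._value = 0

    def __call__(self, step=1):
        with self._lock:
            self._value += step
            return self._value

    @property
    def value(self):
        return self._value
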
    def dispatch(self, hyper_model, X, y, X_eval, y_eval, max_trails,
                 dataset_id, trail_store, **fit_kwargs):
        def on_next_space(item):
            for cb in hyper_model.callbacks:
                # cb.on_build_estimator(hyper_model, space_sample, estimator, trail_no)
                cb.on_trail_begin(hyper_model, item.space_sample,
                                  item.trail_no)

        def on_report_space(item):
            if item.success:
                elapsed = item.report_at - item.start_at
                trail = Trail(item.space_sample, item.trail_no, item.reward,
                              elapsed)
                # print(f'trail result:{trail}')

                improved = hyper_model.history.append(trail)
                if improved and logger.is_info_enabled():
                    logger.info(
                        f'>>>improved: reward={item.reward}, trail_no={item.trail_no}, space_id={item.space_id}'
                    )
                hyper_model.searcher.update_result(item.space_sample,
                                                   item.reward)

                if trail_store is not None:
                    trail_store.put(dataset_id, trail)

                for cb in hyper_model.callbacks:
                    cb.on_trail_end(hyper_model, item.space_sample,
                                    item.trail_no, item.reward, improved,
                                    elapsed)
            else:
                for cb in hyper_model.callbacks:
                    cb.on_trail_error(hyper_model, item.space_sample,
                                      item.trail_no)

        def on_summary():
            t = hyper_model.get_best_trail()
            if t:
                detail = f'reward={t.reward}, trail_no={t.trail_no}, space_id={t.space_sample.space_id}'
                return f'best: {detail}'
            else:
                return None

        def do_clean():
            # shutdown grpc server
            search_service.status_thread.stop()
            search_service.status_thread.report_summary()
            # server.stop(grace=1.0)

        if 'search_id' in fit_kwargs:
            search_id = fit_kwargs.pop('search_id')
        else:
            global _search_counter
            _search_counter += 1
            search_id = 'search-%02d' % _search_counter

        if logger.is_info_enabled():
            logger.info(f'start driver server at {self.address}')
        server, search_service = get_or_serve(self.address,
                                              search_id,
                                              self.spaces_dir,
                                              self.models_dir,
                                              on_next=on_next_space,
                                              on_report=on_report_space,
                                              on_summary=on_summary)

        search_start_at = time.time()

        trail_no = 1
        retry_counter = 0
        queue_size = int(config('search_queue', '1'))

        while trail_no <= max_trails:
            space_sample = hyper_model.searcher.sample()
            if hyper_model.history.is_existed(space_sample):
                if retry_counter >= 1000:
                    if logger.is_info_enabled():
                        logger.info(
                            'Unable to take valid sample and exceeded the retry limit 1000.'
                        )
                    break
                trail = hyper_model.history.get_trail(space_sample)
                for callback in hyper_model.callbacks:
                    callback.on_skip_trail(hyper_model, space_sample, trail_no,
                                           'trail_existed', trail.reward,
                                           False, trail.elapsed)
                retry_counter += 1
                continue

            try:
                if trail_store is not None:
                    trail = trail_store.get(dataset_id, space_sample)
                    if trail is not None:
                        reward = trail.reward
                        elapsed = trail.elapsed
                        trail = Trail(space_sample, trail_no, reward, elapsed)
                        improved = hyper_model.history.append(trail)
                        hyper_model.searcher.update_result(
                            space_sample, reward)
                        for callback in hyper_model.callbacks:
                            callback.on_skip_trail(hyper_model, space_sample,
                                                   trail_no, 'hit_trail_store',
                                                   reward, improved, elapsed)
                        trail_no += 1
                        continue

                search_service.add(trail_no, space_sample)

                # wait until the queued trails drain below queue_size
                while search_service.queue_size() >= queue_size:
                    time.sleep(0.1)
            except EarlyStoppingError:
                break
                # TODO: early stopping
            except KeyboardInterrupt:
                do_clean()
                return trail_no
            except Exception as e:
                if logger.is_warning_enabled():
                    import sys
                    import traceback
                    msg = f'{e.__class__.__name__}: {e}'
                    logger.warning(f'{">" * 20} Trail failed! {"<" * 20}')
                    logger.warning(msg + '\n' + traceback.format_exc())
                    logger.warning('*' * 50)
            finally:
                trail_no += 1
                retry_counter = 0
        if logger.is_info_enabled():
            logger.info("-" * 20 +
                        'no more space to search, waiting trails ...')
        try:
            while search_service.running_size() > 0:
                # if logger.is_info_enabled():
                #    logger.info(f"wait ... {search_service.running_size()} samples found.")
                time.sleep(0.1)
        except KeyboardInterrupt:
            return trail_no
        finally:
            do_clean()

        if logger.is_info_enabled():
            logger.info('-' * 20 + ' all trails done ' + '-' * 20)

        return trail_no
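Both dispatch loops above apply the same back-pressure pattern: sample a new search space only while the trial queue has room, otherwise sleep briefly and retry. Stripped of the Hypernets types, the pattern reduces to roughly the following sketch (names are illustrative, not part of the library):

import queue
import time

def produce(sample_fn, work_queue, max_items, queue_size, poll=0.1):
    """Push up to max_items samples, sleeping whenever the queue is full."""
    pushed = 0
    while pushed < max_items:
        if work_queue.qsize() >= queue_size:
            time.sleep(poll)        # back off until a worker drains the queue
            continue
        work_queue.put(sample_fn())
        pushed += 1
    work_queue.put(None)            # sentinel: no more work to schedule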