Пример #1
0
    def _search_space_wrapper(self,
                              space,
                              space_name,
                              fixed=None,
                              fallback=None,
                              preset=None):
        """Return a configuration suggested by the optimizer for ``space``.

        This function pretends to be a ConfigScope for a named_config, but
        under the hood it is getting a suggestion from the optimizer.

        Parameters
        ----------
        space : callable
            Search space definition; converted via ``build_search_space``.
        space_name : str
            Name of the search space, remembered in
            ``self.current_search_space_name``.
        fixed : dict, optional
            Values that override the suggested configuration.
        fallback : optional
            Only accepted to fit the ConfigScope interface; must be falsy.
        preset : dict, optional
            Initial values; the suggestion and ``fixed`` are applied on top.

        Returns
        -------
        dict
            ``preset`` updated with the suggested configuration, then with
            ``fixed``.

        Raises
        ------
        ValueError
            If no search space definition could be established.
        """
        # The fallback parameter is needed to fit the interface of a
        # ConfigScope, but here it is not supported. Reject it before any
        # database side effects happen.
        assert not fallback, "{}".format(fallback)

        self.current_search_space_name = space_name
        sp = build_search_space(space)

        # Establish connection to database
        if self.db is None:
            self._init_db()

        # Look the search space up in the database (or store it there)
        self._verify_and_init_search_space(sp)
        # ensure we have a search space definition before building anything
        # on top of it
        if self.current_search_space is None:
            raise ValueError("LabAssistant search_space_wrapper called but "
                             "there is no search space definition")

        # Create the optimizer. Random search is used when no optimizer class
        # was given, or when there is no database the optimizer could learn
        # from. (Compare with None: pymongo Database objects are not meant
        # to be truth-tested.)
        if self.optimizer_class is None:
            self.optimizer = RandomSearch(self.current_search_space)
        elif self.db is None:
            import warnings
            warnings.warn('No database. Falling back to random search')
            self.optimizer = RandomSearch(self.current_search_space)
        else:
            self.optimizer = self.optimizer_class(self.current_search_space)

        # Get a hyperparameter configuration from the optimizer
        values = self.get_suggestion()

        # Suggestions are keyed by parameter uid; expand them into a config
        config = fill_in_values(self.current_search_space.search_space,
                                values,
                                fill_by='uid')

        final_config = dict(preset or {})
        final_config.update(config)
        final_config.update(fixed or {})
        return final_config
Пример #2
0
class LabAssistant(object):

    """
    The main class for Labwatch. It runs an experiment with a configuration
    suggested by a hyperparameter optimizer.

    The hyperparameter optimizer uses the information about the experiment
    that is stored in the database to suggest a new configuration.
    """

    def __init__(self,
                 experiment,
                 database_name=None,
                 url="localhost",
                 optimizer=None,
                 prefix='runs',
                 always_inject_observer=False):

        """
        Create a new LabAssistant for a sacred experiment.

        No database connection is opened here; it is established lazily by
        ``_init_db`` the first time a search space is used.

        Parameters
        ----------
        experiment : sacred.Experiment
            The (sacred) experiment that is going to be optimized.
        database_name : str, optional
            The name of the database where all information about the runs
            is saved. If None, a MongoObserver given on the command line or
            attached to the experiment is used instead.
        url : str, optional
            Address of the MongoDB server; only used together with
            ``database_name``.
        optimizer : class, optional
            Optimizer class that is instantiated with the search space and
            suggests new hyperparameter configurations. Random search is
            used if None.
        prefix : str, optional
            Additional prefix for the database (the collection name of the
            runs).
        always_inject_observer : bool, optional
            Intended to add a MongoObserver to the experiment.
            NOTE(review): the flag is only stored here; confirm where it is
            consumed.
        """

        self.ex = experiment
        # lets us pick up a MongoDB observer requested via command line options
        self.ex.option_hook(self._option_hook)

        self.db_name = database_name
        self.url = url
        # populated lazily by _init_db()
        self.db = None
        self.runs = None
        self.db_search_space = None

        self.ex.logger = create_basic_stream_logger()
        self.logger = self.ex.logger.getChild('LabAssistant')
        self.prefix = prefix
        self.version_policy = 'newer'
        self.always_inject_observer = always_inject_observer
        self.optimizer_class = optimizer
        # set once a search space is used (see _search_space_wrapper)
        self.optimizer = None
        self.current_search_space_name = None
        # Kept for backward compatibility; _dequeue_run honours the timeout
        # that is passed to it instead.
        self.block_time = 1000  # TODO: what value should this be?
        # maps an experiment to the MongoObserver installed for it
        # (consulted and cleaned up in run_from_queue)
        self.observer_mapping = dict()
        # ids of the runs already passed to the optimizer
        self.known_jobs = set()
        # None means we have never looked for finished runs
        self.last_checked = None
        self.current_search_space = None
        self.mongo_observer = None

    def _option_hook(self, options):
        """Remember the MongoObserver requested on the command line, if any.

        The sacred MongoDbOption is applied to a FakeRun so that the observer
        is built without starting a real run.
        """
        mongo_opt = options.get(MongoDbOption.get_flag())
        if mongo_opt is not None:
            fake_run = FakeRun()
            MongoDbOption.apply(mongo_opt, fake_run)
            self.mongo_observer = fake_run.observers[0]
        else:
            self.mongo_observer = None

    def _init_db(self):
        """Connect to MongoDB through a MongoObserver.

        The observer is created from ``database_name``/``url`` when a
        database name was given, and attached to the experiment. Otherwise,
        the observer from the command line option is used. If there is none,
        the highest-priority MongoObserver already attached to the
        experiment is used.

        Raises
        ------
        RuntimeError
            If no database name was given and no MongoObserver is available.
        """
        if self.db_name is None:
            if self.mongo_observer is None:
                mongo_observers = sorted([mo for mo in self.ex.observers
                                          if isinstance(mo, MongoObserver)],
                                         key=lambda x: x.priority)
                if not mongo_observers:
                    raise RuntimeError('No mongo observer found!')
                self.mongo_observer = mongo_observers[-1]
        else:
            self.mongo_observer = MongoObserver.create(db_name=self.db_name,
                                                       collection=self.prefix,
                                                       url=self.url)
            self._inject_observer()
        self.runs = self.mongo_observer.runs
        self.db = self.runs.database
        self.db_search_space = self.db.search_space

        for manipulator in SON_MANIPULATORS:
            self.db.add_son_manipulator(manipulator)

    def _verify_and_init_search_space(self, space_from_ex):
        """Make ``space_from_ex`` the current search space.

        If an equal search space is already stored in the database, that
        stored entry becomes ``self.current_search_space``. Otherwise, the
        space is inserted and the stored document is used.

        Returns
        -------
        The current search space (as loaded from the database).
        """
        # We don't have any id yet, so compare against every stored entry.
        # Iterating an empty collection yields nothing, so no separate
        # count() round trip is needed.
        for stored in self.db_search_space.find():
            if stored == space_from_ex:
                self.current_search_space = stored
                break
        else:
            sp_id = self.db_search_space.insert(space_from_ex.to_json())
            self.current_search_space = self.db_search_space.find_one(
                {"_id": sp_id})

        return self.current_search_space

    def _clean_config(self, config):
        """Extract the values of the search space parameters from ``config``."""
        values = get_values_from_config(config,
                                        self.current_search_space.parameters)
        return values

    def _search_space_wrapper(self, space, space_name, fixed=None,
                              fallback=None, preset=None):
        """Return a configuration suggested by the optimizer for ``space``.

        This function pretends to be a ConfigScope for a named_config, but
        under the hood it is getting a suggestion from the optimizer.

        Parameters
        ----------
        space : callable
            Search space definition; converted via ``build_search_space``.
        space_name : str
            Name of the search space. ``update_optimizer`` uses it to select
            the runs that were started with it.
        fixed : dict, optional
            Values that override the suggested configuration.
        fallback : optional
            Only accepted to fit the ConfigScope interface; must be falsy.
        preset : dict, optional
            Initial values; the suggestion and ``fixed`` are applied on top.

        Raises
        ------
        ValueError
            If no search space definition could be established.
        """
        # The fallback parameter is needed to fit the interface of a
        # ConfigScope, but here it is not supported. Reject it before any
        # database side effects happen.
        assert not fallback, "{}".format(fallback)

        self.current_search_space_name = space_name
        sp = build_search_space(space)

        # Establish connection to database
        if self.db is None:
            self._init_db()

        # Look the search space up in the database (or store it there)
        self._verify_and_init_search_space(sp)
        # ensure we have a search space definition before building on it
        if self.current_search_space is None:
            raise ValueError("LabAssistant search_space_wrapper called but "
                             "there is no search space definition")

        # Create the optimizer. Random search is used when no optimizer class
        # was given, or when there is no database to learn from. (Compare
        # with None: pymongo Database objects are not meant to be
        # truth-tested.)
        if self.optimizer_class is None:
            self.optimizer = RandomSearch(self.current_search_space)
        elif self.db is None:
            import warnings
            warnings.warn('No database. Falling back to random search')
            self.optimizer = RandomSearch(self.current_search_space)
        else:
            self.optimizer = self.optimizer_class(self.current_search_space)

        # Get a hyperparameter configuration from the optimizer
        values = self.get_suggestion()

        # Suggestions are keyed by parameter uid; expand them into a config
        config = fill_in_values(self.current_search_space.search_space, values,
                                fill_by='uid')

        final_config = dict(preset or {})
        final_config.update(config)
        final_config.update(fixed or {})
        return final_config

    def _inject_observer(self):
        """Attach ``self.mongo_observer`` to the experiment if not attached yet.

        Raises
        ------
        ValueError
            If there is no MongoObserver.
        """
        if self.mongo_observer is None:
            raise ValueError("LabAssistant has no database "
                             "but you called inject_observer")
        if self.mongo_observer not in self.ex.observers:
            self.ex.observers.append(self.mongo_observer)

    def _dequeue_run(self, remaining_time, sleep_time):
        """Claim the next queued run from the database.

        Polls ``self.runs`` for a run with status ``QUEUED``. A found run is
        verified against this experiment (via check_names, check_sources and
        check_dependencies). It is then claimed by setting its status to
        ``INITIALIZING``, using a replace that only matches while its status
        is still unchanged.

        Parameters
        ----------
        remaining_time : float
            Maximum number of seconds to keep polling.
        sleep_time : float
            Seconds to sleep after finding the queue empty.

        Returns
        -------
        dict or None
            The claimed run document, or None if none could be claimed
            before the timeout.
        """
        criterion = {'status': 'QUEUED'}
        ex_info = self.ex.get_experiment_info()
        run = None
        # The timeout is measured from the caller's budget (previously it
        # was silently replaced by self.block_time after the first sleep).
        deadline = time.time() + remaining_time
        while remaining_time > 0.:
            run = self.runs.find_one(criterion)
            if run is None:
                self.logger.warning('Could not find run from queue waiting for '
                                    'max another {} s'.format(remaining_time))
                time.sleep(sleep_time)
            else:
                # verify the run
                check_names(ex_info['name'], run['experiment']['name'])
                check_sources(ex_info['sources'], run['experiment']['sources'])
                check_dependencies(ex_info['dependencies'],
                                   run['experiment']['dependencies'],
                                   self.version_policy)

                # set status to INITIALIZING to prevent others from
                # running the same Run.
                old_status = run['status']
                run['status'] = 'INITIALIZING'
                replace_summary = self.runs.replace_one(
                    {'_id': run['_id'], 'status': old_status},
                    replacement=run)
                if replace_summary.modified_count == 1 or \
                   replace_summary.raw_result['updatedExisting']:
                    # the second part above is necessary in case we are
                    # working with an older mongodb server (version < 2.6)
                    # which will not return the modified_count flag
                    break  # we've successfully acquired a run
                # someone else claimed this run first; don't return it
                run = None
            remaining_time = deadline - time.time()
        return run

    # ########################## exported functions ###########################

    def set_database(self, database):
        """Reconnect to the database and re-verify the current search space.

        NOTE(review): ``_init_db`` reassigns ``self.db`` from the
        MongoObserver, so the ``database`` argument is effectively
        overwritten. Confirm the intended semantics.
        """
        self.db = database
        self._init_db()
        # we need to verify the search space again
        self._verify_and_init_search_space(self.current_search_space)

    def update_optimizer(self):
        """Feed newly completed runs of the current search space to the optimizer.

        Only completed runs are considered. They must have a heartbeat since
        the last check and have been started with the current search space
        as named config (``meta.options.UPDATE``). Runs already seen (tracked
        in ``self.known_jobs``) are skipped. If the optimizer returns
        modified jobs, their ``info`` is written back to the database.
        """
        if self.db is None:
            self.logger.warning("Cannot update optimizer, reason: no database!")
            return
        # First check database for all configurations
        if self.last_checked is None:
            # if we never checked the database we have to check
            # everything that happened since the definition of time ;)
            self.last_checked = datetime.datetime.min
        # Take the timestamp *before* querying so runs that finish meanwhile
        # are not skipped next time. sacred stores heartbeats as naive UTC
        # datetimes, so compare in UTC (local time would skip runs in UTC+
        # zones). Seeing a run twice is harmless thanks to self.known_jobs.
        check_time = datetime.datetime.now(
            datetime.timezone.utc).replace(tzinfo=None)

        # Take all jobs that are finished and were run with a config from
        # this search space
        completed_jobs = self.runs.find(
            {
                'heartbeat': {'$gte': self.last_checked},
                'status': 'COMPLETED',
                'meta.options.UPDATE': self.current_search_space_name
            }
        )
        self.last_checked = check_time
        # collect all configs and their results
        info = [(self._clean_config(job["config"]),
                 convert_result(job["result"]),
                 job)
                for job in completed_jobs if job["_id"] not in self.known_jobs]
        if info:
            configs, results, jobs = (list(x) for x in zip(*info))
            self.known_jobs |= {job['_id'] for job in jobs}
            modifications = self.optimizer.update(configs, results, jobs)
            # the optimizer might modify the additional info of jobs
            if modifications is not None:
                for job in modifications:
                    # NOTE(review): mixes attribute (job.info) and item
                    # (job["_id"]) access; confirm the type the optimizer
                    # returns.
                    new_info = job.info
                    self.runs.update_one(
                        {'_id': job["_id"]},
                        {'$set': {'info': new_info}},
                        upsert=False)

    def get_suggestion(self):
        """Return the optimizer's next suggestion, keyed by parameter uid.

        The optimizer is first updated with newly completed runs.

        Raises
        ------
        ValueError
            If no search space is defined.
        """
        if self.current_search_space is None:
            raise ValueError("LabAssistant sample_suggestion called "
                             "without a defined search space")
        self.update_optimizer()

        suggestion = self.optimizer.suggest_configuration()
        params = self.current_search_space.parameters
        # keep only known parameters and re-key them by their uid
        return {params[name]['uid']: value
                for name, value in suggestion.items() if name in params}

    def get_current_best(self, return_job_info=False):
        """Return the configuration and result of the best completed run.

        "Best" is the lowest ``result``. NOTE(review): all completed runs
        in the collection are considered, not only those of the current
        search space.

        Returns
        -------
        tuple or None
            ``(config, result)``, or ``(config, result, job)`` if
            ``return_job_info``; the entries are None if no run completed yet.
            None if there is no database.
        """
        if self.db is None:
            self.logger.warning("Cannot get current best, reason: no database!")
            return
        # ("result", 1) sorts by result in ascending order -> lowest first
        best_job = self.runs.find_one({'status': 'COMPLETED'},
                                      sort=[("result", 1)])
        if best_job is None:
            best_result = None
            best_config = None
        else:
            best_result = best_job["result"]
            best_config = self._clean_config(best_job["config"])
        if return_job_info:
            return best_config, best_result, best_job
        else:
            return best_config, best_result

    def run_suggestion(self, command=None):
        """Run the experiment with a configuration suggested by the optimizer."""
        values = self.get_suggestion()
        config = fill_in_values(self.current_search_space.search_space, values,
                                fill_by='uid')

        return self.run_config(config, command)

    def run_random(self, command=None):
        """Run the experiment with a random configuration from the optimizer."""
        return self.run_config(self.optimizer.get_random_config(), command)

    def run_default(self, command=None):
        """Run the experiment with the optimizer's default configuration."""
        return self.run_config(self.optimizer.get_default_config(), command)

    def run_config(self, config, command=None):
        """Run the experiment (or ``command``) with ``config`` as config updates.

        Raises
        ------
        RuntimeError
            If ``config`` is None.
        """
        if config is None:
            raise RuntimeError("None is not an acceptable config!")
        self._inject_observer()
        if command is None:
            res = self.ex.run(config_updates=config)
        else:
            res = self.ex.run_command(command, config_updates=config)
        return res

    def enqueue_suggestion(self, command='main'):
        """Queue a run of ``command`` with a configuration from the optimizer.

        Returns the result of ``run_command`` with the ``--queue`` option.

        Raises
        ------
        RuntimeError
            If no configuration could be built.
        """
        # The suggestion is keyed by parameter uid; expand it into a config
        # exactly like run_suggestion does.
        values = self.get_suggestion()
        config = fill_in_values(self.current_search_space.search_space, values,
                                fill_by='uid')
        if config is None:
            raise RuntimeError("Optimizer did not return a config!")
        self._inject_observer()
        return self.ex.run_command(command,
                                   config_updates=config,
                                   args={"--queue": QueueOption()})

    def run_from_queue(self, wait_time_in_s=10 * 60, sleep_time=5):
        """Claim a queued run and execute it, overwriting the queued entry.

        Parameters
        ----------
        wait_time_in_s : float
            Maximum number of seconds to wait for a queued run.
        sleep_time : float
            Seconds between polls of an empty queue.

        Returns
        -------
        The result of ``run_command``, or None if no run was found.
        """
        run = self._dequeue_run(wait_time_in_s, sleep_time)
        if run is None:
            self.logger.warning("No run found in queue for {} s -> terminating"
                                .format(wait_time_in_s))
            return None

        # remove MongoObserver if we have one for that experiment
        had_matching_observer = False
        if self.ex in self.observer_mapping:
            had_matching_observer = True
            matching = None
            for i, observer in enumerate(self.ex.observers):
                if observer == self.observer_mapping[self.ex]:
                    matching = i
            if matching is None:
                self.logger.warning("Could not remove observer in run_from_queue")
            else:
                del self.ex.observers[matching]
                del self.observer_mapping[self.ex]
        # add a matching MongoObserver to the experiment and tell it to
        # overwrite the run
        fs = gridfs.GridFS(self.db, collection=self.prefix)
        self.ex.observers.append(MongoObserver(self.runs, fs,
                                               overwrite=run))
        try:
            # run the experiment
            res = self.ex.run_command(run['command'],
                                      config_updates=run['config'])
        finally:
            # remove the extra observer even if the run failed
            self.ex.observers.pop()
            # and inject the default one
            if had_matching_observer:
                self._inject_observer()
        return res

    # ############################## Decorators ###############################

    def search_space(self, function):
        """Decorator for creating a searchspace definition from a function.

        Registers a named config (named after ``function``) on the
        experiment. When that named config is used, it asks the optimizer for
        a configuration from this search space. The decorated function is
        returned unchanged, so its name stays bound to it.
        """
        # Get a configuration from the optimizer and add it as a named config
        search_space_wrapper = functools.partial(self._search_space_wrapper,
                                                 space=function,
                                                 space_name=function.__name__)
        self.ex._add_named_config(function.__name__, search_space_wrapper)
        return function