Exemple #1
0
class Generator(object):
    """Convert search space and search algorithm, sample a new model."""

    def __init__(self):
        self.step_name = General.step_name
        self.search_space = SearchSpace()
        self.search_alg = SearchAlgorithm(self.search_space)
        self.report = Report()
        self.record = ReportRecord()
        self.record.step_name = self.step_name
        if hasattr(self.search_alg.config, 'objective_keys'):
            self.record.objective_keys = self.search_alg.config.objective_keys
        self.quota = QuotaCompare('restrict')

    @property
    def is_completed(self):
        """Define a property to determine search algorithm is completed."""
        return self.search_alg.is_completed or self.quota.is_halted()

    def sample(self):
        """Sample a work id and model from search algorithm."""
        res = self.search_alg.search()
        if not res:
            return None
        if not isinstance(res, list):
            res = [res]
        if len(res) == 0:
            return None
        out = []
        for sample in res:
            desc = sample.get("desc") if isinstance(sample, dict) else sample[1]
            desc = self._decode_hps(desc)
            model_desc = deepcopy(desc)
            if "modules" in desc:
                PipeStepConfig.model.model_desc = deepcopy(desc)
            elif "network" in desc:
                origin_desc = PipeStepConfig.model.model_desc
                desc = update_dict(desc["network"], origin_desc)
                PipeStepConfig.model.model_desc = deepcopy(desc)
            if self.quota.is_filtered(desc):
                continue
            record = self.record.from_sample(sample, desc)
            Report().broadcast(record)
            out.append((record.worker_id, model_desc))
        return out

    def update(self, step_name, worker_id):
        """Update search algorithm accord to the worker path.

        :param step_name: step name
        :param worker_id: current worker id
        :return:
        """
        report = Report()
        record = report.receive(step_name, worker_id)
        logging.debug("Get Record=%s", str(record))
        self.search_alg.update(record.serialize())
        report.dump_report(record.step_name, record)
        self.dump()
        logging.info("Update Success. step_name=%s, worker_id=%s", step_name, worker_id)
        logging.info("Best values: %s", Report().print_best(step_name=General.step_name))

    @staticmethod
    def _decode_hps(hps):
        """Decode hps: `trainer.optim.lr : 0.1` to dict format.

        And convert to `zeus.common.config import Config` object
        This Config will be override in Trainer or Datasets class
        The override priority is: input hps > user configuration >  default configuration
        :param hps: hyper params
        :return: dict
        """
        hps_dict = {}
        if hps is None:
            return None
        if isinstance(hps, tuple):
            return hps
        for hp_name, value in hps.items():
            hp_dict = {}
            for key in list(reversed(hp_name.split('.'))):
                if hp_dict:
                    hp_dict = {key: hp_dict}
                else:
                    hp_dict = {key: value}
            # update cfg with hps
            hps_dict = update_dict(hps_dict, hp_dict, [])
        return Config(hps_dict)

    def dump(self):
        """Dump generator to file."""
        step_path = TaskOps().step_path
        _file = os.path.join(step_path, ".generator")
        with open(_file, "wb") as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def restore(cls):
        """Restore generator from file."""
        step_path = TaskOps().step_path
        _file = os.path.join(step_path, ".generator")
        if os.path.exists(_file):
            with open(_file, "rb") as f:
                return pickle.load(f)
        else:
            return None
Exemple #2
0
class Generator(object):
    """Convert search space and search algorithm, sample a new model."""
    def __init__(self):
        self.step_name = General.step_name
        self.search_space = SearchSpace()
        self.search_alg = SearchAlgorithm(self.search_space)
        if hasattr(self.search_alg.config, 'objective_keys'):
            self.objective_keys = self.search_alg.config.objective_keys
        self.quota = QuotaCompare('restrict')
        self.affinity = None if General.quota.affinity.type is None else QuotaAffinity(
            General.quota.affinity)

    @property
    def is_completed(self):
        """Define a property to determine search algorithm is completed."""
        return self.search_alg.is_completed or self.quota.is_halted()

    def sample(self):
        """Sample a work id and model from search algorithm."""
        for _ in range(10):
            res = self.search_alg.search()
            if not res:
                return None
            if not isinstance(res, list):
                res = [res]
            if len(res) == 0:
                return None
            out = []
            for sample in res:
                if isinstance(sample, dict):
                    id = sample["worker_id"]
                    desc = self._decode_hps(sample["encoded_desc"])
                    sample.pop("worker_id")
                    sample.pop("encoded_desc")
                    kwargs = sample
                    sample = _split_sample((id, desc))
                else:
                    kwargs = {}
                    sample = _split_sample(sample)
                if hasattr(self, "objective_keys") and self.objective_keys:
                    kwargs["objective_keys"] = self.objective_keys
                (id, desc, hps) = sample

                if "modules" in desc:
                    PipeStepConfig.model.model_desc = deepcopy(desc)
                elif "network" in desc:
                    origin_desc = PipeStepConfig.model.model_desc
                    model_desc = update_dict(desc["network"], origin_desc)
                    PipeStepConfig.model.model_desc = model_desc
                    desc.pop('network')
                    desc.update(model_desc)

                if self.quota.is_filtered(desc):
                    continue
                if self.affinity and not self.affinity.is_affinity(desc):
                    continue
                ReportClient().update(General.step_name,
                                      id,
                                      desc=desc,
                                      hps=hps,
                                      **kwargs)
                out.append((id, desc, hps))
            if out:
                break
        return out

    def update(self, step_name, worker_id):
        """Update search algorithm accord to the worker path.

        :param step_name: step name
        :param worker_id: current worker id
        :return:
        """
        record = ReportClient().get_record(step_name, worker_id)
        logging.debug("Get Record=%s", str(record))
        self.search_alg.update(record.serialize())
        try:
            self.dump()
        except TypeError:
            logging.warning(
                "The Generator contains object which can't be pickled.")
        logging.info(
            f"Update Success. step_name={step_name}, worker_id={worker_id}")
        logging.info("Best values: %s",
                     ReportServer().print_best(step_name=General.step_name))

    @staticmethod
    def _decode_hps(hps):
        """Decode hps: `trainer.optim.lr : 0.1` to dict format.

        And convert to `vega.common.config import Config` object
        This Config will be override in Trainer or Datasets class
        The override priority is: input hps > user configuration >  default configuration
        :param hps: hyper params
        :return: dict
        """
        hps_dict = {}
        if hps is None:
            return None
        if isinstance(hps, tuple):
            return hps
        for hp_name, value in hps.items():
            hp_dict = {}
            for key in list(reversed(hp_name.split('.'))):
                if hp_dict:
                    hp_dict = {key: hp_dict}
                else:
                    hp_dict = {key: value}
            # update cfg with hps
            hps_dict = update_dict(hps_dict, hp_dict, [])
        return Config(hps_dict)

    def dump(self):
        """Dump generator to file."""
        step_path = TaskOps().step_path
        _file = os.path.join(step_path, ".generator")
        with open(_file, "wb") as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def restore(cls):
        """Restore generator from file."""
        step_path = TaskOps().step_path
        _file = os.path.join(step_path, ".generator")
        if os.path.exists(_file):
            with open(_file, "rb") as f:
                return pickle.load(f)
        else:
            return None