Esempio n. 1
0
 def _desc_from_choices(self, choices):
     """Build a description from the given choices.

     For every key in ``self.search_space.modules`` a deep copy of that
     key's entry in the search space is handed to
     ``self._sub_config_choice`` together with ``choices`` and the current
     position; the helper returns the module config and the next position.

     :param choices: choices forwarded unchanged to ``_sub_config_choice``
     :return: result of ``update_dict`` applied to the per-module configs
         and a deep copy of the whole search space
     """
     module_descs = {}
     cursor = 0
     for name in self.search_space.modules:
         # Work on a copy so the helper cannot alter the shared search space.
         space = copy.deepcopy(self.search_space[name])
         module_descs[name], cursor = self._sub_config_choice(
             space, choices, cursor)
     return update_dict(module_descs, copy.deepcopy(self.search_space))
Esempio n. 2
0
    def decode(self, code):
        """Decode the code into a network description.

        The code is converted with ``self._code_to_chninfo`` and stored under
        the ``'backbone'`` key; that dict is then combined with a deep copy of
        ``self.search_space`` through ``update_dict``.

        :param code: input code
        :type code: list of int
        :return: network desc (whatever ``update_dict`` returns)
        """
        return update_dict(
            {'backbone': self._code_to_chninfo(code)},
            copy.deepcopy(self.search_space),
        )
Esempio n. 3
0
    def decode(self, code):
        """Decode the code into a network description.

        The code is cut at ``len(code) // 2``: the leading part is stored
        under ``"nbit_w_list"`` and the remainder under ``"nbit_a_list"``
        (for an odd length the remainder is one element longer). The result
        is combined with a deep copy of ``self.search_space`` through
        ``update_dict``.

        :param code: code of network
        :type code: list
        :return: network desc (whatever ``update_dict`` returns)
        """
        split_at = len(code) // 2
        halves = {
            "nbit_w_list": code[:split_at],
            "nbit_a_list": code[split_at:],
        }
        return update_dict(halves, copy.deepcopy(self.search_space))
Esempio n. 4
0
    def sample(self):
        """Sample worker ids and model descriptions from the search algorithm.

        The search algorithm is queried up to 10 times. Each query result is
        normalised to a list of samples; every sample is split into
        ``(worker_id, desc, hps)``, checked against the quota and affinity
        filters, reported through ``ReportClient`` and collected.

        :return: ``None`` as soon as the search algorithm returns a falsy
            result; otherwise a list of ``(worker_id, desc, hps)`` tuples,
            which is empty when every sample of all 10 attempts was filtered
        """
        for _ in range(10):
            found = self.search_alg.search()
            if not found:
                return None
            candidates = found if isinstance(found, list) else [found]
            out = []
            for candidate in candidates:
                if isinstance(candidate, dict):
                    worker_id = candidate["worker_id"]
                    desc = self._decode_hps(candidate["encoded_desc"])
                    # The remaining entries of the dict are forwarded to the
                    # report call as extra keyword arguments.
                    for consumed in ("worker_id", "encoded_desc"):
                        candidate.pop(consumed)
                    extra = candidate
                    triple = _split_sample((worker_id, desc))
                else:
                    extra = {}
                    triple = _split_sample(candidate)
                objective_keys = getattr(self, "objective_keys", None)
                if objective_keys:
                    extra["objective_keys"] = self.objective_keys
                worker_id, desc, hps = triple

                if "modules" in desc:
                    PipeStepConfig.model.model_desc = deepcopy(desc)
                elif "network" in desc:
                    # Combine the sampled network with the configured model
                    # desc, then replace the "network" entry of desc by the
                    # combined content.
                    base_desc = PipeStepConfig.model.model_desc
                    combined = update_dict(desc["network"], base_desc)
                    PipeStepConfig.model.model_desc = combined
                    desc.pop('network')
                    desc.update(combined)

                rejected = self.quota.is_filtered(desc) or (
                    self.affinity and not self.affinity.is_affinity(desc))
                if rejected:
                    continue
                ReportClient().update(General.step_name,
                                      worker_id,
                                      desc=desc,
                                      hps=hps,
                                      **extra)
                out.append((worker_id, desc, hps))
            if out:
                return out
        return out
Esempio n. 5
0
    def decode(self, code):
        """Decode the code into a network description.

        The code is converted with ``self._code_to_chninfo``. The converted
        info becomes the ``"backbone"`` entry, and the last element of its
        ``'chn_node'`` entry becomes the head's ``"base_channel"``. The result
        is combined with a deep copy of ``self.search_space`` through
        ``update_dict``.

        :param code: input code
        :type code: list of int
        :return: network desc (whatever ``update_dict`` returns)
        """
        channels = self._code_to_chninfo(code)
        head_cfg = {"base_channel": channels['chn_node'][-1]}
        return update_dict(
            {"backbone": channels, "head": head_cfg},
            copy.deepcopy(self.search_space),
        )
Esempio n. 6
0
    def _decode_hps(hps):
        """Decode hps: `trainer.optim.lr : 0.1` to dict format.

        And convert to `vega.common.config import Config` object
        This Config will be override in Trainer or Datasets class
        The override priority is: input hps > user configuration >  default configuration
        :param hps: hyper params
        :return: dict
        """
        hps_dict = {}
        if hps is None:
            return None
        if isinstance(hps, tuple):
            return hps
        for hp_name, value in hps.items():
            hp_dict = {}
            for key in list(reversed(hp_name.split('.'))):
                if hp_dict:
                    hp_dict = {key: hp_dict}
                else:
                    hp_dict = {key: value}
            # update cfg with hps
            hps_dict = update_dict(hps_dict, hp_dict, [])
        return Config(hps_dict)