示例#1
0
 def __init__(self, hps=None, **kwargs):
     """Build the argument set, transform pipeline and sampler.

     :param hps: optional hyperparameters; forwarded to ``_init_hps`` when
         not None
     :param kwargs: keyword overrides. ``mode`` picks the section of
         ``self.cfg`` to read (defaults to ``"train"``), ``transforms``
         replaces the ``__transform__`` attribute of the transform pipeline,
         and any key already present in the combined arguments overwrites
         that argument.
     """
     fallback_args = {
         'batch_size': 1,
         'num_workers': 0,
         'shuffle': False,
         'distributed': False,
         'imgs_per_gpu': 1,
         'pin_memory': True,
         'drop_last': True
     }
     self.mode = kwargs.get("mode", "train")
     if hps is not None:
         self._init_hps(hps)
     # Combine the optional 'common' section with the section of the active mode.
     has_common = 'common' in self.cfg.keys() and self.cfg['common'] is not None
     if has_common:
         shared_args = deepcopy(self.cfg['common'])
         mode_args = deepcopy(self.cfg[self.mode])
         self.args = update_dict(shared_args, mode_args)
     else:
         self.args = deepcopy(self.cfg[self.mode])
     self.args = update_dict(self.args, Config(fallback_args))
     # Only keyword arguments that name an existing argument are applied.
     for name, override in kwargs.items():
         if name in self.args:
             self.args[name] = override
     self.train = self.mode in ("train", "val")
     self._transforms = Transforms(self._init_transforms())
     if "transforms" in kwargs:
         self._transforms.__transform__ = kwargs["transforms"]
     self.sampler = self._init_sampler()
示例#2
0
 def _reference_trainer_settings(self):
     """Set reference Trainer."""
     ref = self.cfg.get('ref')
     if ref:
         ref_dict = ClassFactory.__configs__
         for key in ref.split('.'):
             ref_dict = ref_dict.get(key)
         update_dict(ref_dict, self.cfg)
示例#3
0
文件: mfkd.py 项目: zeyefkey/vega
 def _desc_from_choices(self, choices):
     """Assemble a description from a sequence of choices.

     :param choices: choices handed to ``_sub_config_choice`` together with
         the running position
     :return: result of ``update_dict`` applied to the per-module configs
         and a deep copy of the search space
     """
     cursor = 0
     module_descs = {}
     for module_name in self.search_space.modules:
         space = copy.deepcopy(self.search_space[module_name])
         # Each call returns the module config and the advanced position.
         module_descs[module_name], cursor = self._sub_config_choice(space, choices, cursor)
     return update_dict(module_descs, copy.deepcopy(self.search_space))
示例#4
0
    def decode(self, code):
        """Turn an encoded sample into a network description.

        :param code: input code
        :type code: list of int
        :return: result of ``update_dict`` applied to the backbone channel
            info and a deep copy of the search space
        """
        backbone_desc = {"backbone": self._code_to_chninfo(code)}
        search_space_copy = copy.deepcopy(self.search_space)
        return update_dict(backbone_desc, search_space_copy)
示例#5
0
文件: trainer.py 项目: zhwzhong/vega
    def _init_hps(self, hps):
        """Convert trainer values in hps to cfg.

        :param hps: hyperparameters
        :type hps: dict
        """
        if "hps_file" in self.cfg and self.cfg.hps_file is not None:
            hps_file = self.cfg.hps_file.replace("{local_base_path}",
                                                 self.local_base_path)
            hps = Config(hps_file)
        if hps is not None:
            self.cfg = Config(update_dict(hps.get('trainer'), self.cfg))
            self.hps = hps
示例#6
0
 def _dispatch_trainer(self, samples):
     for (id_ele, desc) in samples:
         cls_trainer = ClassFactory.get_cls('trainer')
         if "modules" in desc:
             PipeStepConfig.model.model_desc = deepcopy(desc)
         elif "network" in desc:
             origin_desc = PipeStepConfig.model.model_desc
             desc = update_dict(desc["network"], origin_desc)
         model_ele = NetworkDesc(desc).to_model()
         trainer = cls_trainer(model_ele, id_ele, hps=desc)
         logging.info("submit trainer, id={}".format(id_ele))
         self.master.run(trainer)
     if isinstance(samples, list) and len(samples) > 1:
         self.master.join()
示例#7
0
    def decode(self, code):
        """Split the code into two halves and build a description from them.

        The first half is stored under ``nbit_w_list`` and the remainder
        under ``nbit_a_list``.

        :param code: code of network
        :type code: list
        :return: result of ``update_dict`` applied to the two lists and a
            deep copy of the search space
        """
        half = len(code) // 2
        bit_desc = {"nbit_w_list": code[:half], "nbit_a_list": code[half:]}
        return update_dict(bit_desc, copy.deepcopy(self.search_space))
示例#8
0
    def _decode_best_hps(self):
        """Decode best hps: `trainer.optim.lr : 0.1` to dict format.

        :return: dict
        """
        hps = self.best_hps['configs']
        hps_dict = {}
        for hp_name, value in hps.items():
            hp_dict = {}
            for key in list(reversed(hp_name.split('.'))):
                if hp_dict:
                    hp_dict = {key: hp_dict}
                else:
                    hp_dict = {key: value}
            # update cfg with hps
            hps_dict = update_dict(hps_dict, hp_dict)
        return hps_dict
示例#9
0
    def _decode_hps(hps):
        """Expand hps from dotted names into a nested ``Config``.

        An entry such as `trainer.optim.lr : 0.1` becomes
        ``{'trainer': {'optim': {'lr': 0.1}}}``; the entries are combined
        through ``update_dict`` (called with an empty third argument) and the
        result is wrapped in ``Config``.

        :param hps: hyper params
        :return: ``Config`` built from the nested dict
        """
        merged = {}
        for dotted_name, value in hps.items():
            # Wrap the value from the innermost key outwards.
            nested = value
            for part in reversed(dotted_name.split('.')):
                nested = {part: nested}
            merged = update_dict(merged, nested, [])
        return Config(merged)
示例#10
0
 def _init_hps(self, hps):
     """Fold the dataset and trainer sections of hps into ``self.cfg``.

     :param hps: hyperparameters exposing ``get`` plus ``dataset`` and
         ``trainer`` attributes
     """
     if hps.get("dataset") is not None:
         merged_train = update_dict(hps.dataset, self.cfg.train)
         self.cfg.train = Config(merged_train)
     merged_cfg = update_dict(hps.trainer, self.cfg)
     self.cfg = Config(merged_cfg)
示例#11
0
文件: dataset.py 项目: zeyefkey/vega
 def _init_hps(self, hps):
     """Convert trainer values in hps to cfg."""
     if hps is not None:
         self.args = Config(update_dict(hps, self.args))