示例#1
0
 def _get_model_desc(self):
     """Resolve the model description for the current trainer.

     The trainer's own ``model_desc`` is returned when it is truthy.
     Otherwise the first configured ``ModelConfig`` source is used, in order:
     ``model_desc_file``, ``model_desc``, then the first ``desc_*.json``
     matched under ``models_folder``. ``{local_base_path}`` is substituted
     with the trainer's local base path in file/folder settings.

     :return: the resolved description, or the trainer's falsy
         ``model_desc`` when no source is configured
     :raises IndexError: if ``models_folder`` contains no ``desc_*.json``
     """
     desc = self.trainer.model_desc
     if desc:
         return desc
     if ModelConfig.model_desc_file is not None:
         path = ModelConfig.model_desc_file.replace(
             "{local_base_path}", self.trainer.local_base_path)
         if ":" not in path:
             path = os.path.abspath(path)
         if ":" in path:
             # Paths containing ":" (presumably remote URLs) are copied into
             # the trainer's local output directory and loaded from there.
             # NOTE(review): on Windows os.path.abspath yields a drive letter
             # with ":", so local files take this branch too - confirm.
             local_path = FileOps.join_path(
                 self.trainer.local_output_path, os.path.basename(path))
             FileOps.copy_file(path, local_path)
             path = local_path
         desc = Config(path)
         logger.info("net_desc:{}".format(desc))
     elif ModelConfig.model_desc is not None:
         desc = ModelConfig.model_desc
     elif ModelConfig.models_folder is not None:
         folder = ModelConfig.models_folder.replace(
             "{local_base_path}", self.trainer.local_base_path)
         matches = glob.glob(FileOps.join_path(folder, "desc_*.json"))
         desc = Config(matches[0])
     return desc
示例#2
0
class DartsNetworkTemplateConfig(ConfigSerializable):
    """Darts network template config.

    Holds the built-in DARTS network templates that ship next to this module.
    Each attribute is a ``Config`` loaded from a JSON file in this module's
    directory when the class body executes (i.e. when the module is imported).
    """

    # Built-in template loaded from darts_cifar10.json.
    cifar10 = Config(
        os.path.join(os.path.dirname(__file__), "darts_cifar10.json"))
    # Built-in template loaded from darts_cifar100.json.
    cifar100 = Config(
        os.path.join(os.path.dirname(__file__), "darts_cifar100.json"))
    # Built-in template loaded from darts_imagenet.json.
    imagenet = Config(
        os.path.join(os.path.dirname(__file__), "darts_imagenet.json"))
示例#3
0
    def _code_to_chninfo(self, code):
        """Transform code to channel info.

        ``code`` is read as consecutive segments (presumably 0/1 mask bits):
        one segment per entry of ``self.base_chn`` (width ``base_chn[i]``),
        then one of width ``search_space.backbone.base_channel``, then one per
        entry of ``self.base_chn_node``. Each segment's sum becomes the pruned
        channel count. An all-zero segment is replaced by a mask with a single
        randomly placed 1, so no layer ends up with zero channels.

        When ``code`` is None, the unpruned base channels are returned and no
        ``base_channel``/mask entries are set.

        :param code: input code
        :type code: list of int
        :return: channel info
        :rtype: Config
        """
        chn = copy.deepcopy(self.base_chn)
        chn_node = copy.deepcopy(self.base_chn_node)
        chninfo = Config()
        chninfo['base_chn'] = self.base_chn
        chninfo['base_chn_node'] = self.base_chn_node
        if code is None:
            chninfo['chn'] = chn
            chninfo['chn_node'] = chn_node
            chninfo['encoding'] = code
            return chninfo
        # Prepend the backbone base channel so it is decoded as the first node
        # segment; it is split off again as 'base_channel' at the end.
        chn_node = [self.search_space.backbone.base_channel] + chn_node
        chn_mask = []
        chn_node_mask = []
        start_id = 0
        # Raises IndexError if self.base_chn is empty.
        end_id = chn[0]
        for i in range(len(chn)):
            if sum(code[start_id:end_id]) == 0:
                # Keep exactly one random channel alive instead of none.
                # NOTE(review): an empty segment (code too short) makes
                # random.randint(0, -1) raise ValueError.
                len_mask = len(code[start_id:end_id])
                tmp_mask = [0] * len_mask
                tmp_mask[random.randint(0, len_mask - 1)] = 1
                chn_mask.append(tmp_mask)
            else:
                chn_mask.append(code[start_id:end_id])
            start_id = end_id
            # After the last chn segment, the next segment is the base channel.
            if i + 1 == len(chn):
                end_id += chn_node[0]
            else:
                end_id += chn[i + 1]
        chn = []
        for single_chn_mask in chn_mask:
            chn.append(sum(single_chn_mask))
        for i in range(len(chn_node)):
            if sum(code[start_id:end_id]) == 0:
                len_mask = len(code[start_id:end_id])
                tmp_mask = [0] * len_mask
                tmp_mask[random.randint(0, len_mask - 1)] = 1
                chn_node_mask.append(tmp_mask)
            else:
                chn_node_mask.append(code[start_id:end_id])
            start_id = end_id
            if i + 1 < len(chn_node):
                end_id += chn_node[i + 1]
        chn_node = []
        for single_chn_mask in chn_node_mask:
            chn_node.append(sum(single_chn_mask))
        chninfo['chn'] = chn
        # Index 0 is the prepended backbone base channel.
        chninfo['chn_node'] = chn_node[1:]
        chninfo['base_channel'] = chn_node[0]
        chninfo['chn_mask'] = chn_mask
        chninfo['chn_node_mask'] = chn_node_mask
        chninfo['encoding'] = code
        return chninfo
示例#4
0
 def __init__(self, metric_cfg=None):
     """Init Metrics.

     :param metric_cfg: optional metric config - a single metric item or a
         list of items, each with a ``type`` key and optional ``params``.
         When falsy, the class-level ``self.config`` is used instead.
     """
     from copy import deepcopy

     self.mdict = {}
     # Prefer an explicit metric_cfg over the class-level config; copy it
     # because pop('type') below mutates each item.
     metric_config = (self.config.to_dict() if not metric_cfg
                      else deepcopy(metric_cfg))
     if not isinstance(metric_config, list):
         metric_config = [metric_config]
     for metric_item in metric_config:
         # NOTE(review): result discarded; presumably validates that
         # config.type is registered - confirm it is still needed.
         ClassFactory.get_cls(ClassType.METRIC, self.config.type)
         metric_name = metric_item.pop('type')
         metric_class = ClassFactory.get_cls(ClassType.METRIC, metric_name)
         if isfunction(metric_class):
             # Plain functions are bound to their params, not instantiated.
             metric_class = partial(metric_class, **metric_item.get("params", {}))
         else:
             metric_class = metric_class(**metric_item.get("params", {}))
         self.mdict[metric_name] = metric_class
     self.mdict = Config(self.mdict)
     self.metric_results = dict()
示例#5
0
 def _init_lr_scheduler(self):
     """Init lr scheduler from timm according to type in config.

     The scheduler is built by timm's ``create_scheduler`` from the
     ``params`` of ``self.config.lr_scheduler()`` plus the configured
     ``epochs``. The epoch count it returns is written back to
     ``self.config.epochs``, and the scheduler is stepped once to
     ``start_epoch`` (default 0).

     :return: the created lr scheduler
     """
     params = self.config.lr_scheduler().to_dict()["params"]
     params['epochs'] = self.config.epochs
     created = create_scheduler(Config(params), self.trainer.optimizer)
     scheduler, self.config.epochs = created
     scheduler.step(params.get('start_epoch', 0))
     return scheduler
示例#6
0
 def __init__(self, types=None, max_steps=None):
     """Init runtime estimator with one entry per runtime type.

     :param types: a runtime type name or a list of names; defaults to
         ``['epoch', 'train']``
     :param max_steps: the max step for each type, as a scalar or a list
         matching ``types``; defaults to ``[0, 0]``
     :raises Exception: if ``types`` and ``max_steps`` differ in length
     """
     # None sentinels avoid shared mutable list defaults.
     if types is None:
         types = ['epoch', 'train']
     if max_steps is None:
         max_steps = [0, 0]
     self.estimator = Config()
     # Wrap each argument independently so a scalar on one side is never
     # paired with a whole list on the other.
     if not isinstance(types, list):
         types = [types]
     if not isinstance(max_steps, list):
         max_steps = [max_steps]
     if len(types) != len(max_steps):
         raise Exception('types length must equal to max_step')
     for est_type, max_step in zip(types, max_steps):
         self.add_runtime_est(est_type, max_step)
示例#7
0
    def _get_model_desc(self):
        """Resolve the model description for this worker.

        Always sets ``self.saved_folder`` to this worker's local path first.
        ``self.model_desc`` is returned when truthy; otherwise sources are
        tried in order: ``desc_<worker_id>.json`` in ``saved_folder``,
        ``ModelConfig.model_desc_file``, ``ModelConfig.model_desc``, the first
        ``desc_*.json`` in ``ModelConfig.models_folder``, and finally
        ``desc_<worker_id>.json`` in the pipe step's ``models_folder``.

        :return: the resolved description (falsy if nothing is configured)
        """
        desc = self.model_desc
        self.saved_folder = self.get_local_worker_path(self.step_name,
                                                       self.worker_id)
        if desc:
            return desc
        worker_desc_name = 'desc_{}.json'.format(self.worker_id)
        saved_desc_file = FileOps.join_path(self.saved_folder, worker_desc_name)
        if os.path.exists(saved_desc_file):
            desc = Config(saved_desc_file)
            # Without "type"/"modules" the saved file is presumably not a
            # full network description; fall back to the configured one.
            if "type" not in desc and "modules" not in desc:
                desc = ModelConfig.model_desc
        elif ModelConfig.model_desc_file is not None:
            desc_file = ModelConfig.model_desc_file.replace(
                "{local_base_path}", self.local_base_path)
            if ":" not in desc_file:
                desc_file = os.path.abspath(desc_file)
            if ":" in desc_file:
                # Presumably a remote URL: copy it locally before loading.
                local_desc_file = FileOps.join_path(
                    self.local_output_path, os.path.basename(desc_file))
                FileOps.copy_file(desc_file, local_desc_file)
                desc_file = local_desc_file
            desc = Config(desc_file)
            logger.info("net_desc:{}".format(desc))
        elif ModelConfig.model_desc is not None:
            desc = ModelConfig.model_desc
        elif ModelConfig.models_folder is not None:
            folder = ModelConfig.models_folder.replace(
                "{local_base_path}", self.local_base_path)
            matches = glob.glob(FileOps.join_path(folder, "desc_*.json"))
            desc = Config(matches[0])
        elif PipeStepConfig.pipe_step.get("models_folder") is not None:
            folder = PipeStepConfig.pipe_step.get("models_folder").replace(
                "{local_base_path}", self.local_base_path)
            desc = Config(FileOps.join_path(folder, worker_desc_name))
            logger.info("Load model from model folder {}.".format(folder))
        return desc
示例#8
0
 def __init__(self, search_space=None, **kwargs):
     """Init DartsCodec.

     Keeps a deep copy of ``search_space`` and records the genotypes of its
     super network's normal and reduce cells, plus the normal cell's
     ``steps``.
     """
     super(DartsCodec, self).__init__(search_space, **kwargs)
     self.darts_cfg = copy.deepcopy(search_space)
     cells = self.darts_cfg.super_network.cells
     self.super_net = Config({
         'cells.normal': cells.normal.genotype,
         'cells.reduce': cells.reduce.genotype,
     })
     self.steps = cells.normal.steps
示例#9
0
 def from_dict(cls, data, skip_check=True):
     """Restore config from a dictionary or a file.

     After the base restore, the ``{local_base_path}`` placeholder is
     replaced by ``<TaskConfig.local_base_path>/<TaskConfig.task_id>`` in
     ``models_folder``, ``model_desc_file`` and ``pretrained_model_file``.
     When ``model_desc`` is not given, it is loaded from the first
     ``desc_*.json`` in ``models_folder``, or else from ``model_desc_file``.

     :param data: config data (a dict or dict-like Config)
     :param skip_check: forwarded to the base ``from_dict``
     :return: the restored config class
     :raises IndexError: if ``models_folder`` contains no ``desc_*.json``
     """
     t_cls = super(ModelConfig, cls).from_dict(data, skip_check)

     def _resolve(path):
         # Evaluated lazily so TaskConfig is only consulted when needed.
         return path.replace(
             "{local_base_path}",
             os.path.join(TaskConfig.local_base_path, TaskConfig.task_id))

     if data.get("models_folder") and not data.get('model_desc'):
         # Use .get() like the other keys: plain dicts have no attribute access.
         folder = _resolve(data.get("models_folder"))
         pattern = FileOps.join_path(folder, "desc_*.json")
         desc_file = glob.glob(pattern)[0]
         t_cls.model_desc = Config(desc_file)
     elif data.get("model_desc_file") and not data.get('model_desc'):
         t_cls.model_desc = Config(_resolve(data.get("model_desc_file")))
     if data.get("pretrained_model_file"):
         t_cls.pretrained_model_file = _resolve(
             data.get("pretrained_model_file"))
     return t_cls
示例#10
0
def init_cluster_args():
    """Initialize local_cluster.

    With a configured ``General.cluster.master_ip``, the world size is the
    number of slaves plus one (the master). Without one, a single-process
    setup on ``127.0.0.1`` is used and that IP is written back to
    ``General.cluster.master_ip``. The result is stored in ``General.env``.

    :return: Config with ``init_method``, ``world_size``, ``rank`` and, when
        a master IP was configured, ``slaves``
    """
    cluster = General.cluster
    if cluster.master_ip:
        slaves = cluster.slaves
        settings = {
            "init_method": "tcp://{}:{}".format(
                cluster.master_ip, cluster.listen_port),
            "world_size": len(slaves) + 1 if slaves else 1,
            "rank": 0,
            "slaves": slaves,
        }
    else:
        local_ip = '127.0.0.1'
        cluster.master_ip = local_ip
        settings = {
            "init_method": "tcp://{}:{}".format(local_ip, cluster.listen_port),
            "world_size": 1,
            "rank": 0,
        }
    env = Config(settings)
    General.env = env
    return env
示例#11
0
    def _code_to_chninfo(self, code):
        """Transform code to channel info.

        ``self.cfgs`` is a list of per-layer configs whose entries ``[1]``
        (hidden_dim) and ``[2]`` (output_channel) give the width of that
        layer's segments in ``code``. Segments are consumed in order, layer
        by layer (hidden_dim first, then output_channel). Each segment's sum
        replaces the matching value in a deep copy of ``self.cfgs``. An
        all-zero segment is replaced by a mask with one randomly placed 1, so
        no dimension is pruned to zero.

        :param code: input code
        :type code: list of int
        :return: channel info, or ``self.cfgs`` itself when ``code`` is empty
        :rtype: Config
        :raises ValueError: if ``code`` is too short, so that an empty
            segment reaches ``random.randint``
        """
        chn_info = Config()
        if not code:
            return self.cfgs
        chn = copy.deepcopy(self.cfgs)
        chn_mask = []
        start_id = 0
        for idx, layer in enumerate(self.cfgs):
            # cfg index 1 is the hidden_dim, index 2 the output_channel.
            for cfg_idx in (1, 2):
                end_id = start_id + int(layer[cfg_idx])
                segment = code[start_id:end_id]
                if sum(segment) == 0:
                    # Keep exactly one random channel alive instead of none.
                    segment = [0] * len(segment)
                    segment[random.randint(0, len(segment) - 1)] = 1
                chn_mask.append(segment)
                chn[idx][cfg_idx] = int(sum(segment))
                start_id = end_id

        chn_info['cfgs'] = chn
        chn_info['base_cfgs'] = self.cfgs
        chn_info['chn_mask'] = chn_mask
        chn_info['encoding'] = code
        return chn_info
示例#12
0
 def _get_hps(self, hps):
     """Return hyper-parameters, loading them from file when not given.

     ``hps`` is returned unchanged when it is not None. Otherwise it is
     loaded from ``config.hps_file`` or, failing that, from the first
     ``hps_*.json`` in ``config.hps_folder`` (``{local_base_path}`` is
     substituted in both). Loaded values have ``trainer.epochs`` and
     ``trainer.checkpoint_path`` removed, presumably so they do not override
     the current trainer settings.

     :return: the hyper-parameters, or None if no source is configured
     """
     if hps is not None:
         return hps
     if self.config.hps_file is not None:
         hps_path = self.config.hps_file.replace("{local_base_path}",
                                                 self.local_base_path)
     elif self.config.hps_folder is not None:
         folder = self.config.hps_folder.replace("{local_base_path}",
                                                 self.local_base_path)
         hps_path = glob.glob(os.path.join(folder, "hps_*.json"))[0]
     else:
         return hps
     hps = Config(hps_path)
     if "trainer" in hps:
         for key in ("epochs", "checkpoint_path"):
             if key in hps["trainer"]:
                 hps["trainer"].pop(key)
     return hps
示例#13
0
    def add_runtime_est(self, type, max_step):
        """Register a new runtime estimator entry.

        If ``type`` is already registered, a warning is logged and nothing
        changes. A new entry starts with no start/current time and with
        start and current step 0.

        :param type: runtime type
        :type type: str
        :param max_step: max step of new type
        :type max_step: int
        """
        if type in self.estimator:
            logging.warning('type %s has already in estimator', type)
            return
        entry = Config()
        entry.start_time = None
        entry.current_time = None
        entry.start_step = 0
        entry.current_step = 0
        entry.max_step = max_step
        self.estimator[type] = entry
示例#14
0
 def _save_descript(self):
     """Save result descript.

     Genotypes are computed from the current architecture weights, then
     combined with the selected DARTS template into a model description that
     is stored in ``self.trainer.config.codec``. The
     ``{default_darts_*_template}`` placeholders select a built-in template.
     Any other value is treated as a file path: the file is copied into the
     local worker path and loaded from there.
     """
     template_file = self.config.darts_template_file
     genotypes = self.search_alg.codec.calc_genotype(self._get_arch_weights())
     builtin_names = {
         "{default_darts_cifar10_template}": "cifar10",
         "{default_darts_cifar100_template}": "cifar100",
         "{default_darts_imagenet_template}": "imagenet",
     }
     if template_file in builtin_names:
         template = getattr(DartsNetworkTemplateConfig,
                            builtin_names[template_file])
     else:
         dst = FileOps.join_path(self.trainer.get_local_worker_path(),
                                 os.path.basename(template_file))
         FileOps.copy_file(template_file, dst)
         template = Config(dst)
     self.trainer.config.codec = self._gen_model_desc(genotypes, template)
示例#15
0
    def decode(self, sample):
        """Decode backbone to description.

        ``sample['code']`` has the form ``<backbone_code>+<ffm_code>``. The
        first character of the backbone code selects the codec (``x`` for
        ResNeXtVariantDetCodec, ``r`` for ResNetVariantDetCodec). When
        ``ffm_code`` is not ``'-'``, the head's base channel is widened by
        128 + 128. ``sample`` must also carry ``method``.

        :param sample: input sample to decode.
        :type sample: dict
        :return: return a decoded sample desc.
        :rtype: dict
        :raises ValueError: if ``sample`` has no ``code``
        :raises NotImplementedError: for an unknown backbone prefix
        """
        if 'code' not in sample:
            raise ValueError('No code to decode in sample:{}'.format(sample))
        backbone_code, ffm_code = sample['code'].split('+')

        decoder_map = {'x': ResNeXtVariantDetCodec, 'r': ResNetVariantDetCodec}
        codec_cls = decoder_map.get(backbone_code[0])
        if codec_cls is None:
            raise NotImplementedError(f'Only {decoder_map} is support in auto_lane algorithm')
        generator = codec_cls(**codec_cls.arch_decoder(backbone_code))
        backbone_desc = str2dict(generator.config)
        out_channels = backbone_desc['out_channels']

        head_channel = out_channels[2]
        if ffm_code != '-':
            head_channel = 128 + 128 + head_channel

        detector = {
            'modules': ['backbone', 'neck', 'head'],
            'num_class': 2,
            'method': sample['method'],
            'code': sample['code'],
            'backbone': backbone_desc,
            'neck': {
                'arch_code': ffm_code,
                'type': 'FeatureFusionModule',
                'in_channels': out_channels,
            },
            'head': {
                'base_channel': head_channel,
                'num_classes': 2,
                'up_points': 73,
                'down_points': 72,
                'type': 'AutoLaneHead',
            },
        }
        return Config({'modules': ['detector'], 'detector': {'type': 'AutoLaneDetector', 'desc': detector}})
示例#16
0
    def genotypes_to_json(self, genotypes):
        """Transfer genotypes to json.

        The ``{default_darts_*_template}`` placeholders in
        ``trainer.config.darts_template_file`` select a built-in template.
        Any other value is passed to ``Config`` as is.

        :param genotypes: Genotype for models
        :type genotypes: iterable of namedtuple Genotype
        :return: one description per genotype, each built from the template
            with the normal and reduce cell genotypes filled in
        :rtype: list of Config
        """
        template_file = self.trainer.config.darts_template_file
        builtin_names = {
            "{default_darts_cifar10_template}": "cifar10",
            "{default_darts_cifar100_template}": "cifar100",
            "{default_darts_imagenet_template}": "imagenet",
        }
        if template_file in builtin_names:
            template = getattr(DartsNetworkTemplateConfig,
                               builtin_names[template_file])
        else:
            template = template_file
        desc_list = []
        for genotype in genotypes:
            # A new Config per genotype, presumably a copy of the template,
            # so descriptions do not overwrite each other - confirm.
            template_cfg = Config(template)
            template_cfg.super_network.cells.normal.genotype = genotype.normal
            template_cfg.super_network.cells.reduce.genotype = genotype.reduce
            desc_list.append(template_cfg)
        return desc_list
示例#17
0
 def _load_single_model_records(self):
     """Build a single report record from the pipe step's model config.

     ``model_desc_file`` (with ``{local_base_path}`` substituted) takes
     precedence over ``model_desc``. An error is logged and an empty list is
     returned when no description is available, no pretrained model file is
     configured, or that file does not exist.

     :return: a one-element list of ReportRecord (worker_id "1"), or ``[]``
     """
     model_cfg = PipeStepConfig.model
     desc = model_cfg.model_desc
     desc_file = model_cfg.model_desc_file
     if desc_file:
         desc = Config(desc_file.replace("{local_base_path}",
                                         TaskOps().local_base_path))
     if not desc:
         logger.error("Model desc or Model desc file is None.")
         return []
     weights = model_cfg.pretrained_model_file
     if not weights:
         logger.error("Model file is None.")
         return []
     if not os.path.exists(weights):
         logger.error("Model file is not existed.")
         return []
     record = ReportRecord().load_dict(
         dict(worker_id="1", desc=desc, weights_file=weights))
     return [record]
示例#18
0
 def __init__(self, desc):
     """Init NetworkDesc.

     Stores a deep copy of ``desc`` wrapped in a ``Config``, so later changes
     to the caller's object do not leak into this description.
     """
     desc_copy = deepcopy(desc)
     self._desc = Config(desc_copy)
示例#19
0
class Metrics(object):
    """Metrics class of all metrics defined in cfg.

    Each configured metric is instantiated (or, for plain functions, bound to
    its params with ``functools.partial``) and stored in ``mdict`` keyed by
    its ``type`` name.

    :param metric_cfg: metric part of config; when falsy, the class-level
        ``config`` is used
    :type metric_cfg: dict or Config
    """

    config = MetricsConfig()

    def __init__(self, metric_cfg=None):
        """Init Metrics."""
        self.mdict = {}
        # Prefer an explicit metric_cfg over the class-level config; copy it
        # because pop('type') below mutates each item.
        metric_config = self.config.to_dict() if not metric_cfg else deepcopy(
            metric_cfg)
        if not isinstance(metric_config, list):
            metric_config = [metric_config]
        for metric_item in metric_config:
            # NOTE(review): result discarded; presumably validates that
            # config.type is registered - confirm it is still needed.
            ClassFactory.get_cls(ClassType.METRIC, self.config.type)
            metric_name = metric_item.pop('type')
            metric_class = ClassFactory.get_cls(ClassType.METRIC, metric_name)
            if isfunction(metric_class):
                metric_class = partial(metric_class,
                                       **metric_item.get("params", {}))
            else:
                metric_class = metric_class(**metric_item.get("params", {}))
            self.mdict[metric_name] = metric_class
        self.mdict = Config(self.mdict)
        self.metric_results = dict()

    def __call__(self, output=None, target=None, *args, **kwargs):
        """Calculate all supported metrics by using output and target.

        :param output: predicted output by networks
        :type output: torch tensor
        :param target: target label data
        :type target: torch tensor
        :return: performance of metrics, merged from every metric's result
        :rtype: dict
        """
        pfms = {}
        for key in self.mdict:
            metric = self.mdict[key]
            pfms.update(metric(output, target, *args, **kwargs))
        # NOTE(review): keys are only registered here (value None); the
        # actual values are presumably filled in later via update() - confirm.
        for key in pfms:
            self.metric_results[key] = None
        return pfms

    def reset(self):
        """Reset states for new evaluation after each epoch."""
        self.metric_results = dict()

    @property
    def results(self):
        """Return a deep copy of the metrics results."""
        return deepcopy(self.metric_results)

    @property
    def objectives(self):
        """Return each metric's ``objective``, keyed by metric name."""
        return {name: self.mdict.get(name).objective for name in self.mdict}

    def update(self, metrics):
        """Update the metrics results.

        Every key in ``metrics`` is stored, whether or not a previous
        ``__call__`` registered it.

        :param metrics: outside metrics
        :type metrics: dict
        """
        for key in metrics:
            self.metric_results[key] = metrics[key]
示例#20
0
class Metrics(object):
    """Metrics class of all metrics defined in cfg.

    Each configured metric is instantiated (or, for plain functions, bound to
    its params with ``functools.partial``) and stored in ``mdict`` keyed by
    its ``type`` name. Metrics are also reachable as attributes by that name.

    :param metric_cfg: metric part of config; when falsy, the class-level
        ``config`` is used
    :type metric_cfg: dict or Config
    """

    config = MetricsConfig()

    def __init__(self, metric_cfg=None):
        """Init Metrics."""
        self.mdict = {}
        # Copy the caller's config: pop('type') below mutates each item.
        metric_config = self.config.to_dict() if not metric_cfg else deepcopy(
            metric_cfg)
        if not isinstance(metric_config, list):
            metric_config = [metric_config]
        for metric_item in metric_config:
            # NOTE(review): result discarded; presumably validates that
            # config.type is registered - confirm it is still needed.
            ClassFactory.get_cls(ClassType.METRIC, self.config.type)
            metric_name = metric_item.pop('type')
            metric_class = ClassFactory.get_cls(ClassType.METRIC, metric_name)
            if isfunction(metric_class):
                metric_class = partial(metric_class,
                                       **metric_item.get("params", {}))
            else:
                metric_class = metric_class(**metric_item.get("params", {}))
            self.mdict[metric_name] = metric_class
        self.mdict = Config(self.mdict)

    def __call__(self, output=None, target=None, *args, **kwargs):
        """Calculate all supported metrics by using output and target.

        :param output: predicted output by networks
        :type output: torch tensor
        :param target: target label data
        :type target: torch tensor
        :return: performance of metrics, one entry per metric in mdict order
        :rtype: list
        """
        pfms = []
        for key in self.mdict:
            metric = self.mdict[key]
            pfms.append(metric(output, target, *args, **kwargs))
        return pfms

    def reset(self):
        """Reset states for new evaluation after each epoch."""
        for val in self.mdict.values():
            val.reset()

    @property
    def results(self):
        """Return the merged ``result`` dicts of all metrics."""
        res = {}
        for metric in self.mdict.values():
            res.update(metric.result)
        return res

    @property
    def objectives(self):
        """Return objectives results.

        A metric whose ``objective`` is a dict contributes all of its entries;
        otherwise its objective is stored under the metric's name.
        """
        _objs = {}
        for name in self.mdict:
            objective = self.mdict.get(name).objective
            if isinstance(objective, dict):
                _objs = dict(_objs, **objective)
            else:
                _objs[name] = objective
        return _objs

    def __getattr__(self, key):
        """Get a metric by key name.

        Only called when normal attribute lookup fails. Unknown names raise
        ``AttributeError`` (not ``KeyError``), so ``hasattr``, ``getattr``
        with a default, and ``copy.deepcopy`` behave correctly.

        :param key: metric name
        :type key: str
        :raises AttributeError: if ``key`` is not a registered metric
        """
        # Read mdict through __dict__: when it is not set yet (e.g. on an
        # instance created without __init__), ``self.mdict`` would re-enter
        # __getattr__ and recurse forever.
        mdict = self.__dict__.get('mdict')
        if mdict is None:
            raise AttributeError(key)
        try:
            return mdict[key]
        except KeyError:
            raise AttributeError(
                "'{}' object has no attribute '{}'".format(
                    type(self).__name__, key)) from None