Example No. 1
 def _init_dataloader(self, mode, loader=None, transforms=None):
     """Init dataloader."""
     if loader is not None:
         return loader
     if mode == "train" and self.hps is not None and self.hps.get(
             "dataset") is not None:
         if self.hps.get("dataset") and self.hps.get("dataset").get('type'):
             dataset_cls = ClassFactory.get_cls(
                 ClassType.DATASET,
                 self.hps.get("dataset").get('type'))
         else:
             dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
         dataset = dataset_cls(mode=mode, hps=self.hps.get("dataset"))
     elif self.hps:
         if self.hps.get("dataset") and self.hps.get("dataset").get('type'):
             dataset_cls = ClassFactory.get_cls(
                 ClassType.DATASET,
                 self.hps.get("dataset").get('type'))
             dataset = dataset_cls(mode=mode, hps=self.hps.get("dataset"))
         else:
             dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
             dataset = dataset_cls(mode=mode)
     else:
         dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
         dataset = dataset_cls(mode=mode)
     if transforms is not None:
         dataset.transforms = transforms
     if self.distributed and mode == "train":
         dataset.set_distributed(self._world_size, self._rank_id)
     # adapt the dataset to specific backend
     dataloader = Adapter(dataset).loader
     return dataloader
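
For context, _init_dataloader can only resolve hps["dataset"]["type"] if a dataset class has been registered under that name. Below is a minimal sketch of such a registration, assuming the usual decorator-based ClassFactory.register API; the MyDataset class and the import path are illustrative assumptions, not part of the snippet above.

from vega.common import ClassFactory, ClassType   # import path assumed

@ClassFactory.register(ClassType.DATASET)         # registered under the class name "MyDataset"
class MyDataset:
    def __init__(self, mode="train", hps=None):
        self.mode = mode          # "train", "val" or "test"
        self.hps = hps or {}      # the "dataset" section of the hyperparameters
        self.transforms = None    # may be overwritten by _init_dataloader

With this registration in place, a configuration such as {"dataset": {"type": "MyDataset"}} resolves through ClassFactory.get_cls(ClassType.DATASET, "MyDataset").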
Example No. 2
 def _train_single_model(self,
                         model_desc=None,
                         model_id=None,
                         weights_file=None):
     """Train a single model with the configured trainer."""
     cls_trainer = ClassFactory.get_cls(ClassType.TRAINER,
                                        PipeStepConfig.trainer.type)
     step_name = self.task.step_name
     if model_desc is not None:
         sample = dict(worker_id=model_id,
                       desc=model_desc,
                       step_name=step_name)
         record = ReportRecord().load_dict(sample)
         logging.debug("update record=%s", str(record))
         trainer = cls_trainer(model_desc=model_desc,
                               id=model_id,
                               pretrained_model_file=weights_file)
     else:
         trainer = cls_trainer(None, 0)
         record = ReportRecord(trainer.step_name,
                               trainer.worker_id,
                               desc=trainer.model_desc)
     ReportClient().update(**record.to_dict())
     # resume training
     if vega.is_torch_backend() and General._resume:
         trainer.load_checkpoint = True
         trainer._resume_training = True
     if self._distributed_training:
         self._do_distributed_fully_train(trainer)
     else:
         self._do_single_fully_train(trainer)
Example No. 3
    def _do_hccl_fully_train(self, trainer):
        """Run fully-train on all devices using HCCL distributed training."""
        origin_worker_id = trainer.worker_id
        model_desc = trainer.model_desc
        del trainer

        os.environ['RANK_SIZE'] = os.environ['ORIGIN_RANK_SIZE']
        os.environ['RANK_TABLE_FILE'] = os.environ['ORIGIN_RANK_TABLE_FILE']
        origin_parallel_fully_train = General.parallel_fully_train
        origin_parallel = General._parallel
        General.parallel_fully_train = True
        General.dft = True
        General._parallel = True

        cls_trainer = ClassFactory.get_cls(ClassType.TRAINER,
                                           PipeStepConfig.trainer.type)
        self.master = create_master()
        workers_num = int(os.environ['RANK_SIZE'])
        for i in range(workers_num):
            worker_id = "{}-{}".format(origin_worker_id, i)
            trainer = cls_trainer(model_desc, id=worker_id)
            evaluator = self._get_evaluator(
                worker_id) if os.environ['DEVICE_ID'] == "0" else None
            self.master.run(trainer, evaluator)

        self.master.join()
        self.master.close()
        General.parallel_fully_train = origin_parallel_fully_train
        General.dft = False
        General._parallel = origin_parallel
Example No. 4
def from_module(module):
    """From Model."""
    name = module.__class__.__name__
    if ClassFactory.is_exists(ClassType.NETWORK, name):
        module_cls = ClassFactory.get_cls(ClassType.NETWORK, name)
        if hasattr(module_cls, "from_module"):
            return module_cls.from_module(module)
    return module
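
A hedged usage sketch of from_module: when the argument's class name is also registered as a NETWORK, the registered class rebuilds it via its own from_module; otherwise the object is returned unchanged. ResBlock is a hypothetical registered name used only for illustration.

plain = ResBlock()              # hypothetical class also registered under ClassType.NETWORK
net = from_module(plain)        # rebuilt through the registered class's from_module()
other = from_module(object())   # "object" is not registered, so it is returned as-is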
Example No. 5
 def _get_evaluator(self, worker_id):
     """Create an evaluator for this worker, or return None if evaluation is disabled."""
     if not PipeStepConfig.evaluator_enable:
         return None
     cls_evaluator = ClassFactory.get_cls('evaluator', "Evaluator")
     evaluator = cls_evaluator({
         "step_name": self.task.step_name,
         "worker_id": worker_id
     })
     return evaluator
Example No. 6
 def _train_single_model(self, model_desc, model_id, hps, multi_task):
     """Fully train a single model described by model_desc."""
     cls_trainer = ClassFactory.get_cls(ClassType.TRAINER, PipeStepConfig.trainer.type)
     step_name = self.task.step_name
     sample = dict(worker_id=model_id, desc=model_desc, step_name=step_name)
     record = ReportRecord().load_dict(sample)
     logging.debug("update record=%s", str(record))
     trainer = cls_trainer(model_desc=model_desc, id=model_id, hps=hps, multi_task=multi_task)
     ReportClient().update(**record.to_dict())
     if self._distributed_training:
         self._do_distributed_fully_train(trainer)
     else:
         self._do_single_fully_train(trainer)
Example No. 7
 def to_model(self):
     """Transform a NetworkDesc to a special model."""
     logging.debug("Start to Create a Network.")
     module = ClassFactory.get_cls(ClassType.NETWORK, "Module")
     model = module.from_desc(self._desc)
     if not model:
         raise Exception("Failed to create model, model desc={}".format(
             self._desc))
     model.desc = self._desc
     if hasattr(model, '_apply_names'):
         model._apply_names()
     return model
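
As Example No. 11 below shows, this method is typically reached through NetworkDesc(model_desc).to_model(). A minimal sketch of that call follows; the structure of the description dict is an assumption for illustration only.

desc = {"type": "Sequential", "backbone": {"type": "ResNet", "depth": 18}}  # assumed layout
model = NetworkDesc(desc).to_model()   # raises if the description cannot be built
print(model.desc)                      # the original description is attached to the model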
Example No. 8
 def __new__(cls,
             model=None,
             id=None,
             hps=None,
             load_ckpt_flag=False,
             model_desc=None,
             **kwargs):
     """Create Trainer clss."""
     if vega.is_torch_backend():
         trainer_cls = ClassFactory.get_cls(ClassType.TRAINER,
                                            "TrainerTorch")
     elif vega.is_tf_backend():
         trainer_cls = ClassFactory.get_cls(ClassType.TRAINER, "TrainerTf")
     else:
         trainer_cls = ClassFactory.get_cls(ClassType.TRAINER, "TrainerMs")
     return trainer_cls(model=model,
                        id=id,
                        hps=hps,
                        load_ckpt_flag=load_ckpt_flag,
                        model_desc=model_desc,
                        **kwargs)
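
This __new__ turns the trainer into a thin factory: the call site stays backend-agnostic, and the concrete class (TrainerTorch, TrainerTf or TrainerMs) is looked up in the registry at instantiation time. A hedged sketch, assuming the class defining __new__ is named Trainer and that a backend has already been selected (for example via vega.set_backend):

import vega
vega.set_backend("pytorch")                   # assumed backend-selection API

trainer = Trainer(model_desc=my_desc, id=0)   # "Trainer" and "my_desc" are placeholders
print(type(trainer).__name__)                 # e.g. "TrainerTorch" when the torch backend is active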
Example No. 9
 def __init__(self,
              search_space,
              num_samples,
              max_epochs,
              repeat_times,
              min_epochs=1,
              eta=3,
              multi_obj=False,
              random_samples=None,
              prob_crossover=0.6,
              prob_mutatation=0.2,
              tuner="GP"):
     """Init BOHB."""
     super().__init__(search_space, num_samples, max_epochs, min_epochs,
                      eta)
     # init all the configs
     self.repeat_times = repeat_times
     self.max_epochs = max_epochs
     self.iter_list, self.min_epoch_list = self._get_total_iters(
         num_samples, max_epochs, self.repeat_times, min_epochs, eta)
     self.additional_samples = self._get_additional_samples(eta)
     if random_samples is not None:
         self.random_samples = random_samples
     else:
         self.random_samples = max(self.iter_list[0][0], 2)
     self.tuner_name = "GA" if multi_obj else tuner
     logger.info(
         f"bohb info, iter list: {self.iter_list}, min epoch list: {self.min_epoch_list}, "
         f"addition samples: {self.additional_samples}, tuner: {self.tuner_name}, "
         f"random samples: {self.random_samples}")
     # create sha list
     self.sha_list = self._create_sha_list(search_space, self.iter_list,
                                           self.min_epoch_list,
                                           self.repeat_times)
     # create tuner
     if multi_obj:
         self.tuner = GeneticAlgorithm(search_space,
                                       random_samples=self.random_samples,
                                       prob_crossover=prob_crossover,
                                       prob_mutatation=prob_mutatation)
     elif self.tuner_name == "hebo":
         self.tuner = ClassFactory.get_cls(ClassType.SEARCH_ALGORITHM,
                                           "HeboAdaptor")(search_space)
     else:
         self.tuner = TunerBuilder(search_space, tuner=tuner)
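
For intuition on the schedule built in _get_total_iters: with the default eta=3, each successive-halving rung keeps roughly the top 1/eta of the configurations and multiplies their epoch budget by eta. The sketch below illustrates that arithmetic only; it is not vega's implementation.

def sha_rungs(num_samples, min_epochs, max_epochs, eta=3):
    """Illustrative successive-halving rungs as (configs, epochs) pairs."""
    rungs, n, epochs = [], num_samples, min_epochs
    while epochs <= max_epochs and n >= 1:
        rungs.append((n, epochs))
        n, epochs = max(1, n // eta), epochs * eta
    return rungs

print(sha_rungs(27, 1, 27))   # [(27, 1), (9, 3), (3, 9), (1, 27)]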
Example No. 10
 def __init__(self, desc=None):
     """Init SearchSpace."""
     super(SearchSpace, self).__init__()
     if desc is None:
         desc = SearchSpaceConfig().to_dict()
         if desc.type is not None:
             desc = ClassFactory.get_cls(ClassType.SEARCHSPACE,
                                         desc.type).get_space(desc)
     for name, item in desc.items():
         self.__setattr__(name, item)
         self.__setitem__(name, item)
     self._params = OrderedDict()
     self._condition_dict = OrderedDict()
     self._forbidden_list = []
     self._hp_count = 0
     self._dag = DAG()
     if desc is not None:
         self.form_desc(desc)
Example No. 11
def transform_architecture(model, pretrained_model_file=None):
    """Transform architecture."""
    if not hasattr(model, "_arch_params") or not model._arch_params or \
            PipeStepConfig.pipe_step.get("type") == "TrainPipeStep":
        return model
    model._apply_names()
    logging.info(
        "Start to transform architecture, model arch params type: {}".format(
            model._arch_params_type))
    ConnectionsArchParamsCombiner().combine(model)
    if vega.is_ms_backend():
        from mindspore.train.serialization import load_checkpoint
        changed_name_list = []
        mask_weight_list = []
        for name, module in model.named_modules():
            if not ClassFactory.is_exists(model._arch_params_type,
                                          module.model_name):
                continue
            changed_name_list, mask_weight_list = decode_fn_ms(
                module, changed_name_list, mask_weight_list)
        assert len(changed_name_list) == len(mask_weight_list)
        # change model and rebuild
        model_desc = model.desc
        root_name = [
            name for name in list(model_desc.keys())
            if name not in ('type', '_arch_params')
        ]
        for changed_name, mask in zip(changed_name_list, mask_weight_list):
            name = changed_name.split('.')
            name[0] = root_name[int(name[0])]
            assert len(name) <= 6
            if len(name) == 6:
                model_desc[name[0]][name[1]][name[2]][name[3]][name[4]][
                    name[5]] = sum(mask)
            if len(name) == 5:
                model_desc[name[0]][name[1]][name[2]][name[3]][name[4]] = sum(
                    mask)
            if len(name) == 4:
                model_desc[name[0]][name[1]][name[2]][name[3]] = sum(mask)
            if len(name) == 3:
                model_desc[name[0]][name[1]][name[2]] = sum(mask)
            if len(name) == 2:
                model_desc[name[0]][name[1]] = sum(mask)
        network = NetworkDesc(model_desc)
        model = network.to_model()
        if '_arch_params' in model_desc:
            model_desc.pop('_arch_params')
        model.desc = model_desc
        # change weight
        if hasattr(model, "pretrained"):
            pretrained_weight = model.pretrained(pretrained_model_file)
            load_checkpoint(pretrained_weight, net=model)
            os.remove(pretrained_weight)

    else:
        for name, module in model.named_modules():
            if not ClassFactory.is_exists(model._arch_params_type,
                                          module.model_name):
                continue
            arch_cls = ClassFactory.get_cls(model._arch_params_type,
                                            module.model_name)

            decode_fn(module, arch_cls)
            module.register_forward_pre_hook(arch_cls.fit_weights)
            module.register_forward_hook(module.clear_module_arch_params)
    return model