Пример #1
0
 def new_model(self):
     """Build a fresh model whose arch parameters mirror ``self.model``.

     The network is built from ``self.search_space`` and moved to the GPU via
     ``.cuda()``. Each of its architecture parameters is then overwritten in
     place with the paired parameter of ``self.model``. Pairing is positional
     and stops at the shorter of the two sequences (``zip`` semantics).

     :return: the newly built model.
     """
     fresh_model = NetworkDesc(self.search_space).to_model().cuda()
     target_params = fresh_model.arch_parameters()
     source_params = self.model.arch_parameters()
     for target, source in zip(target_params, source_params):
         # Copy through detached views so autograd does not record the copy.
         target.detach().copy_(source.detach())
     return fresh_model
Пример #2
0
    def _init_model(self):
        """Create the model for the full train step from the configured description.

        If ``self.cfg.model_desc`` is set (truthy), it is saved as JSON to
        ``model_desc_<worker_id>.json`` inside ``self.worker_path``. When
        ``self.cfg.distributed`` is true, ``hvd.join()`` is called after the
        file is written. The network is then built from the description.

        :return: the model built from ``self.cfg.model_desc``, or ``None`` when
            no description is configured.
        :rtype: model instance or None
        """
        logging.info('Initializing model')
        if not self.cfg.model_desc:
            return None
        desc = self.cfg.model_desc
        logging.debug("model_desc: {}".format(desc))
        desc_path = FileOps.join_path(
            self.worker_path, "model_desc_{}.json".format(self._worker_id))
        with open(desc_path, "w") as desc_file:
            json.dump(desc, desc_file)
        if self.cfg.distributed:
            hvd.join()
        return NetworkDesc(desc).to_model()
Пример #3
0
    def get_model(cls,
                  model_desc=None,
                  pretrained_model_file=None,
                  exclude_weight_prefix=None):
        """Get model from model zoo.

        Builds a network from ``model_desc`` (when given). If the result is not
        a ``Module``, it is wrapped with ``cls.to_module``. Optionally, the
        weights from ``pretrained_model_file`` are loaded. Finally, the model
        is passed through ``transform_architecture``.

        :param model_desc: description of the network to build; when ``None``
            no network is built here and ``cls.to_module`` receives ``None``.
        :type model_desc: dict or None.
        :param pretrained_model_file: path of pretrained weights to load.
        :type pretrained_model_file: str or None.
        :param exclude_weight_prefix: when truthy and a pretrained file is
            given, stored on the model as ``exclude_weight_prefix`` before the
            weights are loaded.
        :return: model.
        :rtype: model.
        :raises ValueError: if the final model is ``None``.
        """
        model = None
        if model_desc is not None:
            try:
                network = NetworkDesc(model_desc)
                model = network.to_model()
            except Exception as e:
                logging.error(
                    "Failed to get model, model_desc={}, msg={}".format(
                        model_desc, str(e)))
                # Bare raise re-raises the active exception unchanged.
                raise
        logging.info("Model was created.")
        if not isinstance(model, Module):
            model = cls.to_module(model)
        if pretrained_model_file is not None:
            if exclude_weight_prefix:
                # NOTE(review): raises AttributeError if ``model`` is still None
                # here (e.g. no model_desc and to_module returned None).
                model.exclude_weight_prefix = exclude_weight_prefix
            model = cls._load_pretrained_model(model, pretrained_model_file)
        model = transform_architecture(model, pretrained_model_file)
        if model is None:
            raise ValueError("Failed to get model, model is None.")
        return model
Пример #4
0
def _set_desc_value(desc, keys, value):
    """Set ``desc[keys[0]][keys[1]]...[keys[-1]] = value`` (``keys`` non-empty)."""
    node = desc
    for key in keys[:-1]:
        node = node[key]
    node[keys[-1]] = value


def transform_architecture(model, pretrained_model_file=None):
    """Transform architecture.

    Returns ``model`` unchanged when it has no ``_arch_params`` attribute,
    when that attribute is falsy, or when the configured pipe step type is
    ``"TrainPipeStep"``. Otherwise it calls ``model._apply_names()``, combines
    connection arch params, and then takes one of two paths.

    MindSpore backend: collects changed names and masks with
    ``decode_fn_ms`` for every module registered under
    ``model._arch_params_type``. It writes ``sum(mask)`` into the matching
    entry of ``model.desc``, which is mutated in place, and rebuilds the
    network from that description. If the rebuilt model has a ``pretrained``
    method, it loads the checkpoint file returned by it and then deletes
    that file.

    Other backends: decodes each registered module in place with
    ``decode_fn`` and registers ``fit_weights`` as a forward pre-hook and
    ``clear_module_arch_params`` as a forward hook.

    :param model: model to transform.
    :param pretrained_model_file: passed to ``model.pretrained`` on the
        MindSpore backend.
    :return: the transformed (or rebuilt) model.
    """
    if not hasattr(model, "_arch_params") or not model._arch_params or \
            PipeStepConfig.pipe_step.get("type") == "TrainPipeStep":
        return model
    model._apply_names()
    logging.info(
        "Start to transform architecture, model arch params type: {}".format(
            model._arch_params_type))
    ConnectionsArchParamsCombiner().combine(model)
    if vega.is_ms_backend():
        from mindspore.train.serialization import load_checkpoint
        changed_name_list = []
        mask_weight_list = []
        for _, module in model.named_modules():
            if not ClassFactory.is_exists(model._arch_params_type,
                                          module.model_name):
                continue
            changed_name_list, mask_weight_list = decode_fn_ms(
                module, changed_name_list, mask_weight_list)
        assert len(changed_name_list) == len(mask_weight_list)
        # change model and rebuild
        model_desc = model.desc
        # The first dotted component of each changed name is an index into
        # the top-level desc keys, excluding 'type' and '_arch_params'.
        # NOTE(review): relies on desc key order matching that index; confirm.
        root_name = [
            key for key in model_desc if key not in ('type', '_arch_params')
        ]
        for changed_name, mask in zip(changed_name_list, mask_weight_list):
            name = changed_name.split('.')
            name[0] = root_name[int(name[0])]
            assert len(name) <= 6
            # Single-component names are intentionally left untouched.
            if len(name) >= 2:
                _set_desc_value(model_desc, name, sum(mask))
        network = NetworkDesc(model_desc)
        model = network.to_model()
        model_desc.pop('_arch_params', None)
        model.desc = model_desc
        # change weight
        if hasattr(model, "pretrained"):
            pretrained_weight = model.pretrained(pretrained_model_file)
            load_checkpoint(pretrained_weight, net=model)
            os.remove(pretrained_weight)

    else:
        for _, module in model.named_modules():
            if not ClassFactory.is_exists(model._arch_params_type,
                                          module.model_name):
                continue
            arch_cls = ClassFactory.get_cls(model._arch_params_type,
                                            module.model_name)

            decode_fn(module, arch_cls)
            module.register_forward_pre_hook(arch_cls.fit_weights)
            module.register_forward_hook(module.clear_module_arch_params)
    return model