Example #1
0
File: ops.py  Project: huawei-noah/vega
def from_module(module):
    """Map a raw module to its registered NETWORK class when possible.

    If the module's class name is registered in the NETWORK class registry
    and the registered class offers a ``from_module`` constructor, delegate
    to it; otherwise hand back the module unchanged.
    """
    cls_name = module.__class__.__name__
    # Guard clause: unknown class names pass through untouched.
    if not ClassFactory.is_exists(ClassType.NETWORK, cls_name):
        return module
    registered_cls = ClassFactory.get_cls(ClassType.NETWORK, cls_name)
    if hasattr(registered_cls, "from_module"):
        return registered_cls.from_module(module)
    return module
Example #2
0
def transform_architecture(model, pretrained_model_file=None):
    """Transform architecture."""
    if not hasattr(model, "_arch_params") or not model._arch_params or \
            PipeStepConfig.pipe_step.get("type") == "TrainPipeStep":
        return model
    model._apply_names()
    logging.info(
        "Start to transform architecture, model arch params type: {}".format(
            model._arch_params_type))
    ConnectionsArchParamsCombiner().combine(model)
    if vega.is_ms_backend():
        from mindspore.train.serialization import load_checkpoint
        changed_name_list = []
        mask_weight_list = []
        for name, module in model.named_modules():
            if not ClassFactory.is_exists(model._arch_params_type,
                                          module.model_name):
                continue
            changed_name_list, mask_weight_list = decode_fn_ms(
                module, changed_name_list, mask_weight_list)
        assert len(changed_name_list) == len(mask_weight_list)
        # change model and rebuild
        model_desc = model.desc
        root_name = [
            name for name in list(model_desc.keys())
            if name not in ('type', '_arch_params')
        ]
        for changed_name, mask in zip(changed_name_list, mask_weight_list):
            name = changed_name.split('.')
            name[0] = root_name[int(name[0])]
            assert len(name) <= 6
            if len(name) == 6:
                model_desc[name[0]][name[1]][name[2]][name[3]][name[4]][
                    name[5]] = sum(mask)
            if len(name) == 5:
                model_desc[name[0]][name[1]][name[2]][name[3]][name[4]] = sum(
                    mask)
            if len(name) == 4:
                model_desc[name[0]][name[1]][name[2]][name[3]] = sum(mask)
            if len(name) == 3:
                model_desc[name[0]][name[1]][name[2]] = sum(mask)
            if len(name) == 2:
                model_desc[name[0]][name[1]] = sum(mask)
        network = NetworkDesc(model_desc)
        model = network.to_model()
        model_desc.pop(
            '_arch_params') if '_arch_params' in model_desc else model_desc
        model.desc = model_desc
        # change weight
        if hasattr(model, "pretrained"):
            pretrained_weight = model.pretrained(pretrained_model_file)
            load_checkpoint(pretrained_weight, net=model)
            os.remove(pretrained_weight)

    else:
        for name, module in model.named_modules():
            if not ClassFactory.is_exists(model._arch_params_type,
                                          module.model_name):
                continue
            arch_cls = ClassFactory.get_cls(model._arch_params_type,
                                            module.model_name)

            decode_fn(module, arch_cls)
            module.register_forward_pre_hook(arch_cls.fit_weights)
            module.register_forward_hook(module.clear_module_arch_params)
    return model