def from_module(module):
    """Convert a module into its registered network counterpart, if one exists.

    The module's class name is looked up in the ``ClassFactory`` registry under
    ``ClassType.NETWORK``. When a registered class is found and it defines
    ``from_module``, conversion is delegated to that class and its result is
    returned; in every other case ``module`` is handed back unchanged.

    :param module: the module instance to convert.
    :return: the converted module, or ``module`` itself.
    """
    cls_name = module.__class__.__name__
    # Guard clause: nothing registered under this name -> no conversion.
    if not ClassFactory.is_exists(ClassType.NETWORK, cls_name):
        return module
    network_cls = ClassFactory.get_cls(ClassType.NETWORK, cls_name)
    if not hasattr(network_cls, "from_module"):
        return module
    return network_cls.from_module(module)
def transform_architecture(model, pretrained_model_file=None):
    """Transform a model according to its architecture parameters.

    The model is returned unchanged when it has no (or empty) ``_arch_params``,
    or when the configured pipe step type is ``"TrainPipeStep"``.

    On the MindSpore backend, ``decode_fn_ms`` is run on every module whose
    ``model_name`` is registered under ``model._arch_params_type``. It collects
    dotted changed names and masks. The sum of each mask is written into
    ``model.desc`` at the corresponding path, and a new model is built from that
    description via ``NetworkDesc``. If the rebuilt model has a ``pretrained``
    method, the file it returns is loaded with ``load_checkpoint`` and then
    removed, even if loading fails.

    On other backends, ``decode_fn`` is applied to each matching module. Then
    ``arch_cls.fit_weights`` is registered as a forward pre-hook and
    ``module.clear_module_arch_params`` as a forward hook.

    :param model: the model to transform.
    :param pretrained_model_file: forwarded to ``model.pretrained`` on the
        MindSpore backend.
    :return: the transformed model (a freshly built instance on MindSpore).
    """
    if not hasattr(model, "_arch_params") or not model._arch_params or \
            PipeStepConfig.pipe_step.get("type") == "TrainPipeStep":
        return model
    model._apply_names()
    logging.info("Start to transform architecture, model arch params type: %s",
                 model._arch_params_type)
    ConnectionsArchParamsCombiner().combine(model)
    if vega.is_ms_backend():
        from mindspore.train.serialization import load_checkpoint
        changed_name_list = []
        mask_weight_list = []
        for _, module in model.named_modules():
            if not ClassFactory.is_exists(model._arch_params_type, module.model_name):
                continue
            changed_name_list, mask_weight_list = decode_fn_ms(
                module, changed_name_list, mask_weight_list)
        # decode_fn_ms is expected to append names and masks pairwise.
        assert len(changed_name_list) == len(mask_weight_list)
        # Rewrite the description with the decoded values, then rebuild the model.
        model_desc = model.desc
        # The leading component of each changed name is an index into this list
        # of top-level description keys (excluding 'type' and '_arch_params').
        root_name = [key for key in model_desc.keys()
                     if key not in ('type', '_arch_params')]
        for changed_name, mask in zip(changed_name_list, mask_weight_list):
            path = changed_name.split('.')
            path[0] = root_name[int(path[0])]
            # A single-component path would replace a whole root entry;
            # nothing is written for it.
            if len(path) < 2:
                continue
            # Walk down the nested description to the parent of the target key.
            node = model_desc
            for key in path[:-1]:
                node = node[key]
            node[path[-1]] = sum(mask)
        network = NetworkDesc(model_desc)
        model = network.to_model()
        model_desc.pop('_arch_params', None)
        model.desc = model_desc
        # Load weights into the rebuilt model; always remove the weight file
        # returned by `pretrained`, even if loading fails.
        if hasattr(model, "pretrained"):
            pretrained_weight = model.pretrained(pretrained_model_file)
            try:
                load_checkpoint(pretrained_weight, net=model)
            finally:
                os.remove(pretrained_weight)
    else:
        for _, module in model.named_modules():
            if not ClassFactory.is_exists(model._arch_params_type, module.model_name):
                continue
            arch_cls = ClassFactory.get_cls(model._arch_params_type, module.model_name)
            decode_fn(module, arch_cls)
            module.register_forward_pre_hook(arch_cls.fit_weights)
            module.register_forward_hook(module.clear_module_arch_params)
    return model