from federatedml.protobuf.model_migrate.converter.tree_model_converter import HeteroSBTConverter
from federatedml.protobuf.model_migrate.model_migrate import model_migration
import copy

# Manual smoke check: run model_migration on a parameter object whose
# anonymous header embeds the old host party id, then print the result.

# Party ids before and after the migration, per role.
host_old = [10000, 9999]
host_new = [114, 514]
guest_old = [10000]
guest_new = [1919]

# NOTE(review): FeatureBinningParam is not imported anywhere in the visible
# source -- confirm where it is expected to come from before running this.
param = FeatureBinningParam()
old_header = ['host_10000_0', 'host_10000_1', 'host_10000_2', 'host_10000_3']
param.header_anonymous = old_header

# Model contents handed to the migration: one param buffer and an empty meta.
model_contents = {'HelloParam': param, 'HelloMeta': {}}

rs = model_migration(
    model_contents,
    'HeteroSecureBoost',
    old_guest_list=guest_old,
    new_guest_list=guest_new,
    old_host_list=host_old,
    new_host_list=host_new,
)
print(rs)
def migration(config_data: dict):
    """Build a migrated copy of a locally cached model and package it.

    The source model is copied to a new model directory, its pipeline
    runtime configuration is rewritten for the new roles/initiator and the
    unified model version, every non-pipeline component buffer is passed
    through ``model_migration``, and the result is packaged. The working
    directory of the migrated model is removed after packaging.

    Args:
        config_data: Migration request. Keys read here: ``model_id``,
            ``model_version``, ``unify_model_version``, ``migrate_initiator``,
            ``local`` (``role``, ``party_id``, ``migrate_party_id``),
            ``role`` and ``migrate_role`` (``guest``, ``host`` and
            optionally ``arbiter``).

    Returns:
        A ``(retcode, message, data)`` tuple. On success ``retcode`` is 0 and
        ``data`` holds ``model_id``, ``model_version`` and the absolute
        ``path`` of the packaged archive. On any exception ``retcode`` is 100,
        ``message`` is ``str(exception)`` and ``data`` is ``{}``.
    """
    try:
        party_model_id = model_utils.gen_party_model_id(
            model_id=config_data["model_id"],
            role=config_data["local"]["role"],
            party_id=config_data["local"]["party_id"])
        model = pipelined_model.PipelinedModel(
            model_id=party_model_id,
            model_version=config_data["model_version"])
        if not model.exists():
            raise Exception("Can not found {} {} model local cache".format(
                config_data["model_id"], config_data["model_version"]))

        # Refuse to reuse a model version that is already registered.
        with DB.connection_context():
            if MLModel.get_or_none(
                    MLModel.f_model_version == config_data["unify_model_version"]):
                raise Exception(
                    "Unify model version {} has been occupied in database. "
                    "Please choose another unify model version and try again.".format(
                        config_data["unify_model_version"]))

        model_data = model.collect_models(in_bytes=True)
        if "pipeline.pipeline:Pipeline" not in model_data:
            raise Exception("Can not found pipeline file in model.")

        migrate_model = pipelined_model.PipelinedModel(
            model_id=model_utils.gen_party_model_id(
                model_id=model_utils.gen_model_id(config_data["migrate_role"]),
                role=config_data["local"]["role"],
                party_id=config_data["local"]["migrate_party_id"]),
            model_version=config_data["unify_model_version"])

        # Start from a full copy of the source model; copytree creates the
        # destination directory itself.
        shutil.copytree(src=model.model_path, dst=migrate_model.model_path)

        pipeline = migrate_model.read_component_model('pipeline', 'pipeline')['Pipeline']

        # Utilize Pipeline_model collect model data, and modify related inner
        # information of model.
        train_runtime_conf = json_loads(pipeline.train_runtime_conf)
        train_runtime_conf["role"] = config_data["migrate_role"]
        train_runtime_conf["initiator"] = config_data["migrate_initiator"]

        adapter = JobRuntimeConfigAdapter(train_runtime_conf)
        train_runtime_conf = adapter.update_model_id_version(
            model_id=model_utils.gen_model_id(train_runtime_conf["role"]),
            model_version=migrate_model.model_version)

        # Read the updated common parameters once; the adapter is not
        # modified again below, so the same values serve the pipeline fields
        # and the success message.
        common_parameters = adapter.get_common_parameters().to_dict()
        new_model_id = common_parameters.get("model_id")
        new_model_version = common_parameters.get("model_version")

        # update pipeline.pb file
        pipeline.train_runtime_conf = json_dumps(train_runtime_conf, byte=True)
        pipeline.model_id = bytes(new_model_id, "utf-8")
        pipeline.model_version = bytes(new_model_version, "utf-8")

        # save updated pipeline.pb file, then overwrite the copy stored under
        # variables/data with it
        migrate_model.save_pipeline(pipeline)
        shutil.copyfile(
            os.path.join(migrate_model.model_path, "pipeline.pb"),
            os.path.join(migrate_model.model_path, "variables", "data",
                         "pipeline", "pipeline", "Pipeline"))

        # modify proto
        with open(
                os.path.join(migrate_model.model_path, 'define', 'define_meta.yaml'),
                'r') as fin:
            define_yaml = yaml.safe_load(fin)

        for component_name, model_protos in define_yaml['model_proto'].items():
            # The pipeline itself was already rewritten above.
            if component_name == 'pipeline':
                continue
            for model_alias in model_protos:
                buffer_obj = migrate_model.read_component_model(
                    component_name, model_alias)
                module_name = define_yaml['component_define'].get(
                    component_name, {}).get('module_name')
                modified_buffer = model_migration(
                    model_contents=buffer_obj,
                    module_name=module_name,
                    old_guest_list=config_data['role']['guest'],
                    new_guest_list=config_data['migrate_role']['guest'],
                    old_host_list=config_data['role']['host'],
                    new_host_list=config_data['migrate_role']['host'],
                    old_arbiter_list=config_data.get('role', {}).get('arbiter', None),
                    new_arbiter_list=config_data.get('migrate_role', {}).get('arbiter', None))
                migrate_model.save_component_model(
                    component_name=component_name,
                    component_module_name=module_name,
                    model_alias=model_alias,
                    model_buffers=modified_buffer)

        # Package the migrated model, then drop its working directory.
        archive_path = migrate_model.packaging_model()
        shutil.rmtree(os.path.abspath(migrate_model.model_path))
        archive_abspath = os.path.abspath(archive_path)

        message = (
            "Migrating model successfully. "
            "The configuration of model has been modified automatically. "
            "New model id is: {}, model version is: {}. "
            "Model files can be found at '{}'."
        ).format(new_model_id, migrate_model.model_version, archive_abspath)
        return (0, message,
                {"model_id": migrate_model.model_id,
                 "model_version": migrate_model.model_version,
                 "path": archive_abspath})

    except Exception as e:
        # Failures are reported through the return tuple rather than raised.
        return 100, str(e), {}