コード例 #1
0
from federatedml.protobuf.model_migrate.converter.tree_model_converter import HeteroSBTConverter
from federatedml.protobuf.model_migrate.model_migrate import model_migration
import copy

# Demo: migrate a model-content dict whose param carries anonymous headers
# from the old guest/host party ids to the new ones, then print the result.
# FeatureBinningParam was referenced without being imported; bring it in here.
from federatedml.protobuf.generated.feature_binning_param_pb2 import FeatureBinningParam

host_old = [10000, 9999]
host_new = [
    114,
    514,
]

guest_old = [10000]
guest_new = [1919]

param = FeatureBinningParam()

old_header = ['host_10000_0', 'host_10000_1', 'host_10000_2', 'host_10000_3']
# `header_anonymous` holds a list, i.e. it is a protobuf repeated field:
# assigning a list to it raises AttributeError, so populate it in place.
param.header_anonymous.extend(old_header)

# NOTE(review): a FeatureBinningParam is passed under the 'HeteroSecureBoost'
# module name; confirm this is the converter intended for a binning param.
rs = model_migration(
    {
        'HelloParam': param,
        'HelloMeta': {}
    },
    'HeteroSecureBoost',
    old_guest_list=guest_old,
    new_guest_list=guest_new,
    old_host_list=host_old,
    new_host_list=host_new,
)
print(rs)
コード例 #2
0
ファイル: migrate_model.py プロジェクト: zhlj98/FATE
def migration(config_data: dict):
    """Migrate a locally cached pipelined model to a new set of party ids.

    The source model identified by ``config_data["model_id"]`` and
    ``config_data["model_version"]`` for the local role/party is copied into a
    new model directory derived from ``config_data["migrate_role"]``,
    ``config_data["local"]["migrate_party_id"]`` and
    ``config_data["unify_model_version"]``. In the copy, the pipeline's train
    runtime conf is rewritten for the new roles/initiator. Every non-pipeline
    component model is passed through ``model_migration`` to map the old
    guest/host/arbiter ids in ``config_data["role"]`` onto those in
    ``config_data["migrate_role"]``. The copy is then packaged into an archive
    and the unpacked directory is removed.

    Args:
        config_data: migration request. Keys read here: ``model_id``,
            ``model_version``, ``unify_model_version``, ``local`` (``role``,
            ``party_id``, ``migrate_party_id``), ``role``, ``migrate_role`` and
            ``migrate_initiator``.

    Returns:
        A ``(retcode, message, data)`` tuple. On success ``retcode`` is 0 and
        ``data`` holds the new ``model_id``, ``model_version`` and archive
        ``path``. On any failure ``retcode`` is 100, ``message`` is the
        exception text and ``data`` is ``{}``. Exceptions are never propagated.
    """
    # Set only after copytree has created the migrated model directory, so the
    # failure path never deletes a directory this call did not create.
    migrate_model_path = None
    try:
        party_model_id = model_utils.gen_party_model_id(
            model_id=config_data["model_id"],
            role=config_data["local"]["role"],
            party_id=config_data["local"]["party_id"])
        model = pipelined_model.PipelinedModel(
            model_id=party_model_id,
            model_version=config_data["model_version"])
        if not model.exists():
            raise Exception("Can not found {} {} model local cache".format(
                config_data["model_id"], config_data["model_version"]))
        with DB.connection_context():
            if MLModel.get_or_none(MLModel.f_model_version ==
                                   config_data["unify_model_version"]):
                raise Exception(
                    "Unify model version {} has been occupied in database. "
                    "Please choose another unify model version and try again.".
                    format(config_data["unify_model_version"]))

        model_data = model.collect_models(in_bytes=True)
        if "pipeline.pipeline:Pipeline" not in model_data:
            raise Exception("Can not found pipeline file in model.")

        migrate_model = pipelined_model.PipelinedModel(
            model_id=model_utils.gen_party_model_id(
                model_id=model_utils.gen_model_id(config_data["migrate_role"]),
                role=config_data["local"]["role"],
                party_id=config_data["local"]["migrate_party_id"]),
            model_version=config_data["unify_model_version"])

        # Copy the whole source model; everything below edits the copy in place.
        shutil.copytree(src=model.model_path, dst=migrate_model.model_path)
        migrate_model_path = os.path.abspath(migrate_model.model_path)

        pipeline = migrate_model.read_component_model('pipeline',
                                                      'pipeline')['Pipeline']

        # Point the train runtime conf at the new roles and initiator, then
        # regenerate model id / version from the new roles.
        train_runtime_conf = json_loads(pipeline.train_runtime_conf)
        train_runtime_conf["role"] = config_data["migrate_role"]
        train_runtime_conf["initiator"] = config_data["migrate_initiator"]

        adapter = JobRuntimeConfigAdapter(train_runtime_conf)
        train_runtime_conf = adapter.update_model_id_version(
            model_id=model_utils.gen_model_id(train_runtime_conf["role"]),
            model_version=migrate_model.model_version)

        # Read the updated common parameters once. `to_dict` is a method and
        # must be called; the parameters object itself is not indexed directly.
        common_params = adapter.get_common_parameters().to_dict()
        new_model_id = common_params.get("model_id")

        # Update the pipeline.pb contents.
        pipeline.train_runtime_conf = json_dumps(train_runtime_conf, byte=True)
        pipeline.model_id = bytes(new_model_id, "utf-8")
        pipeline.model_version = bytes(common_params.get("model_version"),
                                       "utf-8")

        # Save the updated pipeline and also overwrite the copy stored under
        # variables/data/pipeline/pipeline/Pipeline.
        migrate_model.save_pipeline(pipeline)
        shutil.copyfile(
            os.path.join(migrate_model.model_path, "pipeline.pb"),
            os.path.join(migrate_model.model_path, "variables", "data",
                         "pipeline", "pipeline", "Pipeline"))

        # Rewrite party ids inside every non-pipeline component model listed
        # in define/define_meta.yaml.
        with open(
                os.path.join(migrate_model.model_path, 'define',
                             'define_meta.yaml'), 'r') as fin:
            define_yaml = yaml.safe_load(fin)

        for component_name, aliases in define_yaml['model_proto'].items():
            if component_name == 'pipeline':
                continue
            for model_alias in aliases.keys():
                buffer_obj = migrate_model.read_component_model(
                    component_name, model_alias)
                module_name = define_yaml['component_define'].get(
                    component_name, {}).get('module_name')
                modified_buffer = model_migration(
                    model_contents=buffer_obj,
                    module_name=module_name,
                    old_guest_list=config_data['role']['guest'],
                    new_guest_list=config_data['migrate_role']['guest'],
                    old_host_list=config_data['role']['host'],
                    new_host_list=config_data['migrate_role']['host'],
                    old_arbiter_list=config_data.get('role',
                                                     {}).get('arbiter', None),
                    new_arbiter_list=config_data.get('migrate_role',
                                                     {}).get('arbiter', None))
                migrate_model.save_component_model(
                    component_name=component_name,
                    component_module_name=module_name,
                    model_alias=model_alias,
                    model_buffers=modified_buffer)

        archive_path = os.path.abspath(migrate_model.packaging_model())
        shutil.rmtree(migrate_model_path)
        # The unpacked directory is gone; nothing left for the failure path.
        migrate_model_path = None

        message = ("Migrating model successfully. "
                   "The configuration of model has been modified automatically. "
                   "New model id is: {}, model version is: {}. "
                   "Model files can be found at '{}'.".format(
                       new_model_id, migrate_model.model_version, archive_path))
        return (0, message,
                {"model_id": migrate_model.model_id,
                 "model_version": migrate_model.model_version,
                 "path": archive_path})

    except Exception as e:
        # Best-effort removal of a half-built migrated model directory, so a
        # retry with the same unify_model_version is not blocked by copytree
        # finding an existing destination.
        if migrate_model_path is not None:
            shutil.rmtree(migrate_model_path, ignore_errors=True)
        return 100, str(e), {}