Пример #1
0
def load_model(config_data):
    """Ask every configured serving node to load the model described by ``config_data``.

    For each address in ``config_data['servings']`` a ``PublishRequest`` is
    built and sent with ``publishLoad`` over an insecure gRPC channel.
    Iteration stops at the first serving that answers with a non-zero status.

    :param config_data: job configuration dict. Keys read here:
        ``servings`` (iterable of serving addresses),
        ``role`` (role name -> iterable of party ids),
        ``model`` (role name -> party id -> dict with ``model_version`` and
        ``model_id``), ``local`` (dict with ``role`` and ``party_id``) and
        ``job_parameters`` (dict, optional ``load_type`` / ``file_path``).
    :return: ``(100, msg)`` when no serving is configured; the first failing
        ``(statusCode, 'message error')`` pair; otherwise ``(0, 'success')``.
    :raises KeyError: if ``job_parameters`` is missing, or a model entry lacks
        ``model_version`` / ``model_id``.
    """
    stat_logger.info(config_data)
    if not config_data.get('servings'):
        return 100, 'Please configure servings address'
    for serving in config_data.get('servings'):
        with grpc.insecure_channel(serving) as channel:
            stub = model_service_pb2_grpc.ModelServiceStub(channel)
            load_model_request = model_service_pb2.PublishRequest()
            # Party ids are stringified: the request's partyId fields / map keys are
            # assumed string-typed (FATE-Serving proto), while job configs may carry
            # integer ids, which protobuf would reject with TypeError.
            for role_name, role_partys in (config_data.get("role") or {}).items():
                for _party_id in role_partys:
                    load_model_request.role[role_name].partyId.append(str(_party_id))
            for role_name, role_model_config in (config_data.get("model") or {}).items():
                for _party_id, role_party_model_config in role_model_config.items():
                    model_info = load_model_request.model[role_name].roleModelInfo[str(_party_id)]
                    model_info.tableName = role_party_model_config['model_version']
                    model_info.namespace = role_party_model_config['model_id']
            stat_logger.info('request serving: {} load model'.format(serving))
            local_config = config_data.get('local') or {}
            load_model_request.local.role = local_config.get('role', '')
            load_model_request.local.partyId = str(local_config.get('party_id', ''))
            load_model_request.loadType = config_data['job_parameters'].get("load_type", "FATEFLOW")
            if not get_base_config('use_registry'):
                # No registry: point serving at this FATE Flow's model-transfer endpoint.
                load_model_request.filePath = f"http://{IP}:{HTTP_PORT}{FATE_FLOW_MODEL_TRANSFER_ENDPOINT}"
            else:
                load_model_request.filePath = config_data['job_parameters'].get("file_path", "")
            stat_logger.info(load_model_request)
            response = stub.publishLoad(load_model_request)
            stat_logger.info(
                '{} {} load model status: {}'.format(load_model_request.local.role, load_model_request.local.partyId,
                                                     response.statusCode))
            if response.statusCode != 0:
                return response.statusCode, '{} {}'.format(response.message, response.error)
    return 0, 'success'
Пример #2
0
def bind_model_service(config_data):
    """Bind the initiator's model to ``service_id`` on every configured serving node.

    For each address in ``config_data['servings']`` a ``PublishRequest`` is
    built (service id, participating roles, and the initiator's model
    version/namespace) and sent with ``publishBind`` over an insecure gRPC
    channel. Iteration stops at the first serving reporting a non-zero status.

    :param config_data: job configuration dict. Keys read here:
        ``service_id``, ``initiator`` (``role``, ``party_id``),
        ``job_parameters`` (``model_id``, ``model_version``),
        ``role`` (role name -> iterable of party ids) and ``servings``.
    :return: ``(100, msg)`` when no serving is configured; the first failing
        ``(statusCode, message)`` pair; otherwise ``(0, None)``.
    :raises KeyError: if ``initiator`` or ``job_parameters`` (or their
        required sub-keys) are missing.
    """
    service_id = config_data.get('service_id')
    initiator_role = config_data['initiator']['role']
    initiator_party_id = config_data['initiator']['party_id']
    model_id = config_data['job_parameters']['model_id']
    model_version = config_data['job_parameters']['model_version']
    if not config_data.get('servings'):
        return 100, 'Please configure servings address'
    # Party ids are stringified for the request: partyId fields and roleModelInfo
    # keys are assumed string-typed (FATE-Serving proto); integer ids from the job
    # config would otherwise be rejected by protobuf with TypeError.
    initiator_party_key = str(initiator_party_id)
    for serving in config_data.get('servings'):
        with grpc.insecure_channel(serving) as channel:
            stub = model_service_pb2_grpc.ModelServiceStub(channel)
            publish_model_request = model_service_pb2.PublishRequest()
            publish_model_request.serviceId = service_id
            for role_name, role_party in (config_data.get("role") or {}).items():
                publish_model_request.role[role_name].partyId.extend(str(party_id) for party_id in role_party)

            model_info = publish_model_request.model[initiator_role].roleModelInfo[initiator_party_key]
            model_info.tableName = model_version
            model_info.namespace = model_utils.gen_party_model_id(model_id, initiator_role, initiator_party_id)
            publish_model_request.local.role = initiator_role
            publish_model_request.local.partyId = initiator_party_key
            stat_logger.info(publish_model_request)
            response = stub.publishBind(publish_model_request)
            stat_logger.info(response)
            if response.statusCode != 0:
                return response.statusCode, response.message
    return 0, None
Пример #3
0
def load_model(config_data):
    """Send a ``publishLoad`` request to each serving address in ``config_data``.

    A ``PublishRequest`` is assembled per serving from the ``role``,
    ``model``, ``local`` and ``job_parameters`` sections, with party ids
    converted to ``str``. The loop stops at the first serving whose response
    carries a non-zero ``statusCode``.

    :param config_data: job configuration dict; ``servings`` lists serving
        addresses and ``job_parameters`` must be present.
    :return: ``(100, msg)`` if no serving is configured, the first failing
        ``(statusCode, 'message error')`` pair, or ``(0, 'success')``.
    :raises KeyError: if ``job_parameters`` is missing, or a model entry lacks
        ``model_version`` / ``model_id``.
    """
    stat_logger.info(config_data)
    if not config_data.get('servings'):
        return 100, 'Please configure servings address'

    for address in config_data['servings']:
        with grpc.insecure_channel(address) as grpc_channel:
            service_stub = model_service_pb2_grpc.ModelServiceStub(grpc_channel)
            request = model_service_pb2.PublishRequest()

            # Participants: append each party id (as str) under its role.
            for role, parties in config_data.get("role", {}).items():
                for party in parties:
                    request.role[role].partyId.append(str(party))

            # Per (role, party) model location: version -> tableName, id -> namespace.
            for role, per_party in config_data.get("model", {}).items():
                for party, party_model in per_party.items():
                    version = party_model['model_version']
                    identifier = party_model['model_id']
                    role_model = request.model[role].roleModelInfo[str(party)]
                    role_model.tableName = version
                    role_model.namespace = identifier

            stat_logger.info('request serving: {} load model'.format(address))

            local_section = config_data.get('local', {})
            request.local.role = local_section.get('role', '')
            request.local.partyId = str(local_section.get('party_id', ''))

            job_parameters = config_data['job_parameters']
            request.loadType = job_parameters.get("load_type", "FATEFLOW")
            # With a registry, or when serving is told to use its own
            # 'model.transfer.url', filePath comes from the job parameters;
            # otherwise it points at this FATE Flow's model-transfer endpoint.
            transfer_on_serving = job_parameters.get(
                'use_transfer_url_on_serving', False)
            if USE_REGISTRY or transfer_on_serving:
                request.filePath = job_parameters.get("file_path", "")
            else:
                request.filePath = f"http://{HOST}:{HTTP_PORT}{FATE_FLOW_MODEL_TRANSFER_ENDPOINT}"

            stat_logger.info(request)
            response = service_stub.publishLoad(request)
            stat_logger.info('{} {} load model status: {}'.format(
                request.local.role, request.local.partyId, response.statusCode))
            if response.statusCode != 0:
                return response.statusCode, '{} {}'.format(
                    response.message, response.error)
    return 0, 'success'