示例#1
0
def generate_model_info(config_data):
    """Resolve and record the model table name/namespace for every party.

    The local party's table info is resolved first through
    ``dtable_utils.get_table_info``; the resulting table name is then used as
    the fallback name for any party whose model config carries no
    ``table_name``. The resolved ``table_name`` and ``namespace`` are written
    back into ``config_data['model'][role][party_id]`` in place.

    Args:
        config_data: configuration dict; its ``role``, ``local`` and
            ``model`` entries are read, and ``model`` is updated in place.

    Returns:
        ``False`` when no table name or namespace could be resolved for the
        local party; otherwise ``None`` once ``config_data`` has been updated.
    """
    default_table_config = {
        'role': config_data.get('role'),
        'data_type': 'model',
        'gen_table_info': True,
    }
    local = config_data.get('local')
    table_config = copy.deepcopy(default_table_config)
    table_config['local'] = local
    local_model_config = config_data.get('model').get(
        local.get('role'), {}).get(local.get('party_id'))
    # The local party may have no model entry at all; dict.update(None)
    # raises TypeError, so a missing entry is treated as "no overrides".
    table_config.update(local_model_config or {})
    models_table_name, namespace = dtable_utils.get_table_info(
        config=table_config)
    if not models_table_name or not namespace:
        return False
    for role_name, role_model_config in config_data.get("model").items():
        for _party_id, role_party_model_config in role_model_config.items():
            table_config = copy.deepcopy(default_table_config)
            table_config['local'] = {'role': role_name, 'party_id': _party_id}
            table_config.update(role_party_model_config)
            # Parties without an explicit table name share the local one.
            if not table_config.get('table_name'):
                table_config['table_name'] = models_table_name
            table_name, namespace = dtable_utils.get_table_info(
                config=table_config)
            party_entry = config_data['model'][role_name][_party_id]
            party_entry['table_name'] = table_name
            party_entry['namespace'] = namespace
示例#2
0
文件: workflow.py 项目: 03040081/FATE
def fill_runtime_conf_table_info(runtime_conf, default_runtime_conf):
    """Fill generated table names/namespaces into ``runtime_conf['WorkFlowParam']``.

    Nothing happens unless ``runtime_conf`` has truthy ``scene_id`` and
    ``gen_table_info`` entries. For every data type in
    ``DEFAULT_WORKFLOW_DATA_TYPE``, the ``<data_type>_table`` and
    ``<data_type>_namespace`` parameters are replaced with values from
    ``dtable_utils.get_table_info`` when both are either empty or still equal
    to the corresponding default.

    Args:
        runtime_conf: runtime configuration dict; its ``WorkFlowParam`` entry
            is modified in place.
        default_runtime_conf: default configuration used to detect parameters
            that were left untouched.

    Returns:
        None.
    """
    if not (runtime_conf.get('scene_id')
            and runtime_conf.get('gen_table_info')):
        return

    table_config = copy.deepcopy(runtime_conf)
    workflow_param = runtime_conf.get('WorkFlowParam')
    default_workflow_param = default_runtime_conf.get('WorkFlowParam')

    def _unset_or_default(param_key):
        # True when the parameter is empty or still holds its default value.
        current = workflow_param.get(param_key)
        return not current or current == default_workflow_param.get(param_key)

    for data_type in DEFAULT_WORKFLOW_DATA_TYPE:
        name_param = '{}_table'.format(data_type)
        namespace_param = '{}_namespace'.format(data_type)
        table_config['data_type'] = data_type
        is_input = data_type.split('_')[-1] == 'input'
        if not (_unset_or_default(name_param)
                and _unset_or_default(namespace_param)):
            continue
        # Input tables are only looked up (empty name, no creation); every
        # other table is created and named after the job id.
        if is_input:
            table_config['table_name'] = ''
        else:
            table_config['table_name'] = runtime_conf.get(
                'JobParam', {}).get('job_id')
        table_name, namespace = dtable_utils.get_table_info(
            config=table_config, create=not is_input)
        workflow_param[name_param] = table_name
        workflow_param[namespace_param] = namespace
示例#3
0
def dtable(table_func):
    """Handle a data-table request; only ``'table_info'`` is implemented.

    For ``'table_info'`` the table name and namespace are resolved from the
    JSON request body. Unless the body asks for creation (``create`` truthy),
    the table is fetched from the session and its record count and partitions
    are reported. A table that cannot be found, or a creation request,
    reports a count of 0 and no partitions.

    Args:
        table_func: name of the requested table function.

    Returns:
        The value of ``get_json_result``; it carries a data payload only for
        ``'table_info'``.
    """
    config = request.json
    if table_func != 'table_info':
        return get_json_result()

    table_name, namespace = get_table_info(
        config=config, create=config.get('create', False))

    # Defaults cover both "create requested" and "table not found".
    table_key_count = 0
    table_partition = None
    if not config.get('create', False):
        table = session.get_data_table(name=table_name, namespace=namespace)
        if table:
            table_key_count = table.count()
            table_partition = table.get_partitions()

    payload = {
        'table_name': table_name,
        'namespace': namespace,
        'count': table_key_count,
        'partition': table_partition
    }
    return get_json_result(data=payload)
示例#4
0
 def run(self, component_parameters=None, args=None):
     """Write every record of the configured data table to a text file.

     The download parameters come from ``component_parameters["DownloadParam"]``
     and are extended with the ``role`` and ``local`` entries. Each record is
     written as ``key`` alone when its value is falsy, otherwise as ``key``,
     the ``delimitor`` parameter (default ``","``) and ``str(value)``.
     Progress is logged every 2000 lines.

     Args:
         component_parameters: dict holding ``DownloadParam``, ``role`` and
             ``local``.
         args: not used by this method.
     """
     self.parameters = component_parameters["DownloadParam"]
     self.parameters["role"] = component_parameters["role"]
     self.parameters["local"] = component_parameters["local"]
     table_name, namespace = dtable_utils.get_table_info(
         create=False, config=self.parameters)
     # The session id is built from the first two "_"-separated taskid fields.
     taskid_head = self.taskid.split("_")[:2]
     session.init("_".join(taskid_head), self.parameters["work_mode"])
     target_path = os.path.abspath(self.parameters["output_path"])
     with open(target_path, "w") as fout:
         data_table = session.get_data_table(namespace=namespace,
                                             name=table_name)
         LOGGER.info('===== begin to export data =====')
         separator = self.parameters.get("delimitor", ",")
         lines = 0
         for lines, (key, value) in enumerate(data_table.collect(), start=1):
             if value:
                 record = key + separator + str(value) + "\n"
             else:
                 record = key + "\n"
             fout.write(record)
             if lines % 2000 == 0:
                 LOGGER.info("===== export {} lines =====".format(lines))
         LOGGER.info("===== export {} lines totally =====".format(lines))
         LOGGER.info('===== export data finish =====')
         LOGGER.info('===== export data file path:{} ====='.format(target_path))
示例#5
0
    def run(self, component_parameters=None, args=None):
        """Load a local file into a data table.

        The upload parameters come from ``component_parameters["UploadParam"]``
        and are extended with the ``role`` and ``local`` entries. A relative
        ``file`` path is resolved against the project base directory. The
        table name and namespace come from ``dtable_utils.get_table_info``,
        falling back to ``self.generate_table_name`` for whichever is ``None``.

        Args:
            component_parameters: dict holding ``UploadParam``, ``role`` and
                ``local``.
            args: not used by this method.

        Raises:
            Exception: if the file does not exist, ``head`` is neither 0 nor
                1, or ``partition`` is outside the open interval
                ``(0, self.MAX_PARTITION_NUM)``.
        """
        self.parameters = component_parameters["UploadParam"]
        self.parameters["role"] = component_parameters["role"]
        self.parameters["local"] = component_parameters["local"]
        # NOTE(review): job_id is computed but not read again in this method.
        job_id = self.taskid.split("_")[0]

        file_is_absolute = os.path.isabs(self.parameters.get("file", ""))
        if not file_is_absolute:
            self.parameters["file"] = os.path.join(
                file_utils.get_project_base_directory(),
                self.parameters["file"])
        if not os.path.exists(self.parameters["file"]):
            raise Exception("%s is not exist, please check the configure" % (self.parameters["file"]))

        table_name, namespace = dtable_utils.get_table_info(
            create=True, config=self.parameters)
        fallback_namespace, fallback_table_name = self.generate_table_name(
            self.parameters["file"])
        namespace = fallback_namespace if namespace is None else namespace
        table_name = fallback_table_name if table_name is None else table_name

        read_head = self.parameters['head']
        if read_head not in (0, 1):
            raise Exception("'head' in conf.json should be 0 or 1")
        head = read_head == 1

        partition = self.parameters["partition"]
        if partition <= 0 or partition >= self.MAX_PARTITION_NUM:
            raise Exception("Error number of partition, it should between %d and %d" % (0, self.MAX_PARTITION_NUM))

        session.init(mode=self.parameters['work_mode'])
        in_version = self.parameters.get('in_version', False)
        data_table_count = self.save_data_table(
            table_name, namespace, head, in_version)
        LOGGER.info("------------load data finish!-----------------")
        LOGGER.info("file: {}".format(self.parameters["file"]))
        LOGGER.info("total data_count: {}".format(data_table_count))
        LOGGER.info("table name: {}, table namespace: {}".format(
            table_name, namespace))
示例#6
0
 def run(self, component_parameters=None, args=None):
     """Export a data table to a local text file while reporting progress.

     The download parameters come from ``component_parameters["DownloadParam"]``
     and are extended with the ``role`` and ``local`` entries. Each record is
     written as ``key`` alone when its value is falsy, otherwise as ``key``,
     the ``delimitor`` parameter (default ``","``) and ``str(value)``.
     Progress is logged every 2000 lines and pushed to the job status every
     10000 lines; the record count is finally published as a metric.

     Args:
         component_parameters: dict holding ``DownloadParam``, ``role`` and
             ``local``.
         args: not used by this method.
     """
     self.parameters = component_parameters["DownloadParam"]
     self.parameters["role"] = component_parameters["role"]
     self.parameters["local"] = component_parameters["local"]
     table_name, namespace = dtable_utils.get_table_info(config=self.parameters,
                                                         create=False)
     job_id = self.taskid.split("_")[0]
     session.init(job_id, self.parameters["work_mode"])
     local = self.parameters["local"]
     output_path = os.path.abspath(self.parameters["output_path"])
     with open(output_path, "w") as fout:
         data_table = session.get_data_table(name=table_name, namespace=namespace)
         count = data_table.count()
         delimiter = self.parameters.get("delimitor", ",")
         LOGGER.info('===== begin to export data =====')
         lines = 0
         for key, value in data_table.collect():
             if not value:
                 fout.write(key + "\n")
             else:
                 # str() keeps non-string values from raising TypeError on
                 # concatenation.
                 fout.write(key + delimiter + str(value) + "\n")
             lines += 1
             if lines % 2000 == 0:
                 LOGGER.info("===== export {} lines =====".format(lines))
             if lines % 10000 == 0:
                 # Progress as a whole-number percentage of the table size.
                 job_info = {'f_progress': lines/count*100//1}
                 self.update_job_status(local['role'], local['party_id'],
                                        job_info)
         self.update_job_status(local['role'], local['party_id'],
                                {'f_progress': 100})
         # Reuse the count taken before the export instead of querying again.
         self.callback_metric(metric_name='data_access',
                              metric_namespace='download',
                              metric_data=[Metric("count", count)])
         LOGGER.info("===== export {} lines totally =====".format(lines))
         LOGGER.info('===== export data finish =====')
         LOGGER.info('===== export data file path:{} ====='.format(output_path))
示例#7
0
def get_table():
    """Resolve a table name and namespace from the JSON request body.

    The body's ``config`` and ``create`` entries are forwarded to
    ``get_table_info``.

    Returns:
        The value of ``get_json_result`` with retcode 0, message
        ``'success'`` and the resolved ``table_name`` / ``namespace``.
    """
    body = request.json
    table_name, namespace = get_table_info(config=body.get('config'),
                                           create=body.get('create'))
    table_info = {'table_name': table_name, 'namespace': namespace}
    return get_json_result(retcode=0, retmsg='success', data=table_info)
示例#8
0
def store_cache(dtable, guest_party_id, host_party_id, version, id_type, encrypt_type, tag='Za', namespace=None):
    """Save ``dtable`` under a cache table resolved through ``get_table_info``.

    Args:
        dtable: the data to store; passed straight to ``save_data``.
        guest_party_id: guest party id, used to generate the namespace.
        host_party_id: host party id, used to generate the namespace.
        version: requested table name for the cache.
        id_type: id type, used to generate the namespace.
        encrypt_type: encrypt type, used to generate the namespace.
        tag: tag used to generate the namespace.
        namespace: explicit namespace; generated with ``gen_cache_namespace``
            when ``None``.

    Returns:
        Whatever ``save_data`` returns.
    """
    cache_namespace = namespace
    if cache_namespace is None:
        cache_namespace = gen_cache_namespace(
            id_type, encrypt_type, tag, host_party_id,
            guest_party_id=guest_party_id)
    table_config = {
        'gen_table_info': True,
        'namespace': cache_namespace,
        'table_name': version,
    }
    LOGGER.info(table_config)
    resolved_name, resolved_namespace = get_table_info(create=True,
                                                       config=table_config)
    return save_data(dtable, resolved_namespace, resolved_name)
示例#9
0
def dtable(table_func):
    """Handle a data-table request; only ``'tableInfo'`` is implemented.

    For ``'tableInfo'`` the table name and namespace are resolved from the
    JSON request body and the table's record count is reported (0 when the
    table lookup yields a falsy result).

    Args:
        table_func: name of the requested table function.

    Returns:
        The value of ``get_json_result``; it carries a data payload only for
        ``'tableInfo'``.
    """
    config = request.json
    if table_func != 'tableInfo':
        return get_json_result()

    table_name, namespace = get_table_info(
        config=config, create=config.get('create', False))
    data_table = storage.get_data_table(name=table_name, namespace=namespace)
    table_key_count = data_table.count() if data_table else 0
    payload = {
        'table_name': table_name,
        'namespace': namespace,
        'count': table_key_count,
    }
    return get_json_result(data=payload)
示例#10
0
def download_upload(data_func):
    """Start an upload or download job in a subprocess.

    A job id and job directory are generated, the table name and namespace
    are resolved from the JSON request body, the enriched configuration is
    dumped to a runtime conf file, and the module configured in
    ``JOB_MODULE_CONF`` is launched with ``python3``.

    Args:
        data_func: ``"upload"`` or ``"download"``; selects the module to run.

    Returns:
        The value of ``get_json_result``: the job id with the subprocess pid
        and table info on success, status 102 when no table name/namespace
        could be resolved, or status -104 when any step raised.
    """
    request_config = request.json
    _job_id = generate_job_id()
    logger.info('generated job_id {}, body {}'.format(_job_id, request_config))
    _job_dir = get_job_directory(_job_id)
    os.makedirs(_job_dir, exist_ok=True)
    module = data_func
    try:
        # Resolve a relative upload path inside the try block so that a
        # missing "file" entry is reported as a failed job (below) instead of
        # escaping as an unhandled KeyError.
        if module == "upload" and not os.path.isabs(request_config.get("file", "")):
            request_config["file"] = os.path.join(file_utils.get_project_base_directory(), request_config["file"])
        request_config["work_mode"] = request_config.get('work_mode', WORK_MODE)
        table_name, namespace = dtable_utils.get_table_info(config=request_config, create=(module == 'upload'))
        if not table_name or not namespace:
            return get_json_result(status=102, msg='no table name and namespace')
        request_config['table_name'] = table_name
        request_config['namespace'] = namespace
        conf_file_path = new_runtime_conf(job_dir=_job_dir, method=data_func, module=module,
                                          role=request_config.get('local', {}).get("role"),
                                          party_id=request_config.get('local', {}).get("party_id", PARTY_ID))
        file_utils.dump_json_conf(request_config, conf_file_path)
        progs = ["python3",
                 os.path.join(file_utils.get_project_base_directory(), JOB_MODULE_CONF[module]["module_path"])]
        # Only the download module is given the job id on its command line.
        if module == "download":
            progs += ["-j", _job_id]
        progs += ["-c", conf_file_path]
        p = run_subprocess(job_dir=_job_dir, job_role=data_func, progs=progs)
        return get_json_result(job_id=_job_id, data={'pid': p.pid, 'table_name': request_config['table_name'], 'namespace': request_config['namespace']})
    except Exception as e:
        logger.exception(e)
        return get_json_result(status=-104, msg="failed", job_id=_job_id)
示例#11
0
def publish_online(config_data):
    """Send a publish-online request for the local party's model to every serving.

    For each address in ``config_data['servings']`` an insecure gRPC channel
    is opened and a ``PublishRequest`` is built from all roles and their
    party ids, the local party's model table name/namespace (resolved with
    ``dtable_utils.get_table_info``), and the local role and party id. The
    request is sent with ``stub.publishOnline``; both the request and the
    response are logged.

    Args:
        config_data: configuration dict; its ``local``, ``servings``,
            ``role``, ``model`` and ``scene_id`` entries are read.
    """
    local = config_data.get('local')
    _role = local.get('role')
    _party_id = local.get('party_id')
    for serving in config_data.get('servings'):
        with grpc.insecure_channel(serving) as channel:
            stub = model_service_pb2_grpc.ModelServiceStub(channel)
            publish_model_request = model_service_pb2.PublishRequest()
            for role_name, role_party in config_data.get("role").items():
                publish_model_request.role[role_name].partyId.extend(
                    role_party)

            for role_name, role_model_config in config_data.get(
                    "model").items():
                # Only the local role's model entry is published.
                if role_name != _role:
                    continue
                party_model_config = role_model_config.get(str(_party_id))
                if not party_model_config:
                    continue
                table_config = copy.deepcopy(party_model_config)
                table_config.update({
                    'scene_id': config_data.get('scene_id'),
                    'local': {
                        'role': _role,
                        'party_id': _party_id
                    },
                    'role': config_data.get('role'),
                    'data_type': 'model',
                    'gen_table_info': True,
                })
                table_name, namespace = dtable_utils.get_table_info(
                    config=table_config)
                model_info = publish_model_request.model[_role].roleModelInfo[
                    _party_id]
                model_info.tableName = table_name
                model_info.namespace = namespace
            publish_model_request.local.role = _role
            publish_model_request.local.partyId = _party_id
            logger.info(publish_model_request)
            response = stub.publishOnline(publish_model_request)
            logger.info(response)
示例#12
0
def get_table_info_without_create(table_config):
    """Resolve table info with ``create=False`` and return it as a dict.

    Args:
        table_config: configuration passed to ``get_table_info``.

    Returns:
        A dict with the keys ``table_name`` and ``namespace``.
    """
    resolved_name, resolved_namespace = get_table_info(create=False,
                                                       config=table_config)
    return dict(table_name=resolved_name, namespace=resolved_namespace)
示例#13
0
    def run(self, component_parameters=None, args=None):
        """Load a local file into a data table and exchange its table info.

        The upload parameters come from
        ``component_parameters["UploadSyncParam"]`` and are extended with the
        ``role`` and ``local`` entries. After the upload, a temporary upload
        directory belonging to this job is removed on a best-effort basis,
        ``self.table_info`` is filled in, and the info is sent (GUEST) or
        received (HOST).

        Args:
            component_parameters: dict holding ``UploadSyncParam``, ``role``
                and ``local``.
            args: not used by this method.

        Raises:
            Exception: if the file does not exist or is empty, ``head`` is
                neither 0 nor 1, or ``partition`` is outside the open
                interval ``(0, self.MAX_PARTITION_NUM)``.
            ValueError: if the local role is neither GUEST nor HOST.
        """
        self.parameters = component_parameters["UploadSyncParam"]
        self.parameters["role"] = component_parameters["role"]
        self.parameters["local"] = component_parameters["local"]
        job_id = self.taskid.split("_")[0]
        if not os.path.isabs(self.parameters.get("file", "")):
            self.parameters["file"] = os.path.join(
                file_utils.get_project_base_directory(),
                self.parameters["file"])
        if not os.path.exists(self.parameters["file"]):
            raise Exception("%s is not exist, please check the configure" %
                            (self.parameters["file"]))
        if not os.path.getsize(self.parameters["file"]):
            raise Exception("%s is an empty file" % (self.parameters["file"]))
        table_name, namespace = dtable_utils.get_table_info(
            config=self.parameters, create=True)
        _namespace, _table_name = self.generate_table_name(
            self.parameters["file"])
        if namespace is None:
            namespace = _namespace
        if table_name is None:
            table_name = _table_name
        read_head = self.parameters['head']
        if read_head == 0:
            head = False
        elif read_head == 1:
            head = True
        else:
            raise Exception("'head' in conf.json should be 0 or 1")
        partition = self.parameters["partition"]
        if partition <= 0 or partition >= self.MAX_PARTITION_NUM:
            raise Exception(
                "Error number of partition, it should between %d and %d" %
                (0, self.MAX_PARTITION_NUM))

        # Upload the data.
        session.init(mode=self.parameters['work_mode'])
        data_table_count = self.save_data_table(
            table_name, namespace, head,
            self.parameters.get('in_version', False))
        LOGGER.info("------------load data finish!-----------------")

        # Best-effort removal of this job's temporary upload directory; a
        # failure is logged with its cause but must not fail the upload.
        try:
            if '{}/fate_upload_tmp'.format(job_id) in self.parameters['file']:
                LOGGER.info("remove tmp upload file")
                # NOTE(review): the directory is derived by cutting the path
                # at the first "tmp" substring - confirm this cannot match an
                # unrelated parent directory.
                shutil.rmtree(
                    os.path.join(self.parameters["file"].split('tmp')[0],
                                 'tmp'))
        except Exception as e:
            LOGGER.info("remove tmp file failed: {}".format(e))

        LOGGER.info("file: {}".format(self.parameters["file"]))
        LOGGER.info("total data_count: {}".format(data_table_count))
        LOGGER.info("table name: {}, table namespace: {}".format(
            table_name, namespace))

        # Build the table info.
        task_role = component_parameters["local"]["role"]
        LOGGER.info(
            f'component_parameters["local"]={component_parameters["local"]}')
        LOGGER.info(f"task_role={task_role}")

        # NOTE(review): "tabel_name" and "payty_id" look like typos of
        # "table_name" / "party_id"; they are left unchanged because the
        # consumers of these keys are not visible here.
        self.table_info["tabel_name"] = table_name
        self.table_info["namespace"] = namespace
        self.table_info["file"] = self.parameters["file"]
        self.table_info["payty_id"] = component_parameters["local"].get(
            "party_id")

        if task_role == consts.GUEST:
            # Send the info assembled above; there is no local "table_info"
            # in this branch, so referencing one would raise NameError.
            self.sync_table_info(self.table_info)
            LOGGER.info(f"GUEST: Send -> table_info={self.table_info}")
        elif task_role == consts.HOST:
            table_info = self.recv_table_info()
            LOGGER.info(f"HOST: Get -> table_info={table_info}")
        else:
            raise ValueError("{} role not support yet".format(task_role))
示例#14
0
def load_model(config_data):
    """Ask every serving address to load the models described in ``config_data``.

    For each address in ``config_data['servings']`` an insecure gRPC channel
    is opened and a ``PublishRequest`` is assembled in three passes:

    1. every role's party ids are appended to the request, and the local
       party's (``PARTY_ID``) model table name/namespace is resolved;
    2. the other parties' model table info is resolved, falling back to the
       local model table name when a party has no ``table_name`` of its own;
    3. for every non-arbiter role that contains ``PARTY_ID``, the request's
       local role/party id is set and ``stub.publishLoad`` is called.

    Args:
        config_data: configuration dict; its ``scene_id``, ``role``,
            ``servings`` and ``model`` entries are read.

    Returns:
        None. The status code of each response is logged.
    """
    # Settings shared by every table lookup below.
    default_table_config = dict()
    default_table_config['scene_id'] = config_data.get('scene_id')
    default_table_config['role'] = config_data.get('role')
    default_table_config['data_type'] = 'model'
    default_table_config['gen_table_info'] = True
    logger.info(config_data)
    for serving in config_data.get('servings'):
        with grpc.insecure_channel(serving) as channel:
            stub = model_service_pb2_grpc.ModelServiceStub(channel)
            load_model_request = model_service_pb2.PublishRequest()
            # Stays empty when the local party has no model config; the
            # second pass then skips every other party.
            model_table_name = ''
            # Pass 1: register all parties and resolve the local model table.
            # NOTE(review): party ids from config_data['role'] are compared
            # to PARTY_ID directly, while the 'model' keys are looked up via
            # str(...) - presumably ints vs. strings; confirm.
            for role_name, role_party in config_data.get("role").items():
                for _party_id in role_party:
                    load_model_request.role[role_name].partyId.append(
                        _party_id)
                    if _party_id == PARTY_ID:
                        # get model table name
                        # the model table names automatically generated by all parties are the same
                        local_party_model_config = config_data.get(
                            'model').get(role_name, {}).get(str(_party_id))
                        if local_party_model_config:
                            table_config = copy.deepcopy(default_table_config)
                            table_config.update(local_party_model_config)
                            table_config['local'] = {
                                'role': role_name,
                                'party_id': PARTY_ID
                            }
                            table_name, namespace = dtable_utils.get_table_info(
                                config=table_config)
                            model_table_name = table_name
                            load_model_request.model[role_name].roleModelInfo[
                                int(_party_id)].tableName = table_name
                            load_model_request.model[role_name].roleModelInfo[
                                int(_party_id)].namespace = namespace
            logger.info('load another party model')
            # Pass 2: resolve table info for every party except the local one.
            for role_name, role_model_config in config_data.get(
                    "model").items():
                for _party_id, role_party_model_config in role_model_config.items(
                ):
                    if _party_id == str(PARTY_ID) or not model_table_name:
                        continue
                    table_config = copy.deepcopy(default_table_config)
                    table_config['local'] = {
                        'role': role_name,
                        'party_id': _party_id
                    }
                    table_config.update(role_party_model_config)
                    # A party without its own table name reuses the local one.
                    table_config['table_name'] = table_config[
                        'table_name'] if table_config.get(
                            'table_name') else model_table_name
                    table_name, namespace = dtable_utils.get_table_info(
                        config=table_config)
                    load_model_request.model[role_name].roleModelInfo[int(
                        _party_id)].tableName = table_name
                    load_model_request.model[role_name].roleModelInfo[int(
                        _party_id)].namespace = namespace
            logger.info('request serving: {} load model'.format(serving))
            # Pass 3: send the request once per non-arbiter role that
            # contains the local party.
            for role_name, role_party in config_data.get("role").items():
                if role_name == 'arbiter':
                    continue
                for _party_id in role_party:
                    if _party_id == PARTY_ID:
                        load_model_request.local.role = role_name
                        load_model_request.local.partyId = _party_id
                        logger.info(load_model_request)
                        response = stub.publishLoad(load_model_request)
                        logger.info('{} {} load model status: {}'.format(
                            role_name, _party_id, response.statusCode))