class IntersectModelBase(ModelBase):
    """Role-agnostic driver for the sample-intersection component.

    The class connects the component parameters to an intersection backend
    (``self.intersection_obj``; it is ``None`` here and is dereferenced by
    ``fit`` without a check, so it is presumably assigned by a subclass
    before ``fit`` runs), optionally routes the input through a
    ``MatchIDIntersect`` processor, and publishes intersection statistics
    through the metric callbacks.
    """

    def __init__(self):
        super().__init__()
        self.intersection_obj = None
        self.proc_obj = None
        # -1 marks a statistic that has not been computed yet.
        self.intersect_num = -1
        self.intersect_rate = -1
        self.unmatched_num = -1
        self.unmatched_rate = -1
        self.intersect_ids = None
        self.metric_name = "intersection"
        self.metric_namespace = "train"
        self.metric_type = "INTERSECTION"
        self.model_param_name = "IntersectModelParam"
        self.model_meta_name = "IntersectModelMeta"
        self.model_param = IntersectParam()
        self.use_match_id_process = False
        self.role = None
        self.guest_party_id = None
        self.host_party_id = None
        self.host_party_id_list = None
        self.transfer_variable = IntersectionFuncTransferVariable()

    def _init_model(self, params):
        """Store the component parameters and their preprocess sub-parameters."""
        self.model_param = params
        self.intersect_preprocess_params = params.intersect_preprocess_params

    def init_intersect_method(self):
        """Read party ids from the component properties and validate the role.

        Raises:
            ValueError: if ``self.role`` is neither host nor guest.
        """
        LOGGER.info("Using {} intersection, role is {}".format(self.model_param.intersect_method, self.role))
        self.host_party_id_list = self.component_properties.host_party_idlist
        self.guest_party_id = self.component_properties.guest_partyid
        if self.role not in [consts.HOST, consts.GUEST]:
            raise ValueError("role {} is not support".format(self.role))

    def get_model_summary(self):
        """Return the summary dict published via ``set_summary``."""
        return {"intersect_num": self.intersect_num,
                "intersect_rate": self.intersect_rate,
                "cardinality_only": self.intersection_obj.cardinality_only}

    def _record_intersect_stats(self, data_count, intersect_num):
        """Store the match/unmatch counts and rates for ``data_count`` input rows.

        An empty input (``data_count == 0``) has no meaningful ratio; both
        rates are reported as 0.0 instead of raising ``ZeroDivisionError``.
        """
        self.intersect_num = intersect_num
        self.unmatched_num = data_count - intersect_num
        if data_count:
            self.intersect_rate = intersect_num / data_count
            self.unmatched_rate = 1 - self.intersect_rate
        else:
            self.intersect_rate = 0.0
            self.unmatched_rate = 0.0

    def _build_match_id_processor(self, data):
        """Validate the match-id settings and return a configured ``MatchIDIntersect``.

        Side effect: when the method is not raw and ``sync_intersect_ids`` is
        off, ``model_param.sample_id_generator`` is forced to guest.

        Raises:
            ValueError: on multi-host with a non-guest sample id generator, or
                on raw intersection whose ``join_role`` differs from the
                sample id generator.
        """
        param = self.model_param
        if len(self.host_party_id_list) > 1 and param.sample_id_generator != consts.GUEST:
            raise ValueError("While multi-host, sample_id_generator should be guest.")
        if param.intersect_method == consts.RAW:
            if param.sample_id_generator != self.intersection_obj.join_role:
                raise ValueError("When using raw intersect with match id process,"
                                 "'join_role' should be same role as 'sample_id_generator'")
        elif not param.sync_intersect_ids:
            if param.sample_id_generator != consts.GUEST:
                param.sample_id_generator = consts.GUEST
                LOGGER.warning("when not sync_intersect_ids with match id process,"
                               "sample_id_generator is set to Guest")
        proc_obj = MatchIDIntersect(sample_id_generator=param.sample_id_generator, role=self.role)
        proc_obj.new_sample_id = param.new_sample_id
        if data_overview.check_with_inst_id(data) or param.with_sample_id:
            proc_obj.use_sample_id()
        return proc_obj

    def _verify_cache_ids(self, cache_meta):
        """Exchange cache ids with the counterpart role(s) and require equality.

        Raises:
            ValueError: if a received cache id differs from the local one, or
                if ``self.role`` is neither host nor guest.
        """
        cache_id_channel = self.transfer_variable.cache_id
        if self.role == consts.HOST:
            cache_id = cache_meta[str(self.guest_party_id)].get("cache_id")
            cache_id_channel.remote(cache_id, role=consts.GUEST, idx=0)
            guest_cache_id = cache_id_channel.get(role=consts.GUEST, idx=0)
            if guest_cache_id != cache_id:
                raise ValueError("cache_id check failed. cache_id from host & guest must match.")
        elif self.role == consts.GUEST:
            for i, party_id in enumerate(self.host_party_id_list):
                cache_id = cache_meta[str(party_id)].get("cache_id")
                cache_id_channel.remote(cache_id, role=consts.HOST, idx=i)
                host_cache_id = cache_id_channel.get(role=consts.HOST, idx=i)
                if host_cache_id != cache_id:
                    raise ValueError("cache_id check failed. cache_id from host & guest must match.")
        else:
            raise ValueError(f"Role {self.role} cannot run intersection transform.")

    def __share_info(self, data):
        """Send or receive the share-info column for the intersected ids.

        The info owner joins ``data`` with ``self.intersect_ids``, keeps the
        ``consts.SHARE_INFO_COL_NAME`` column and sends it to the other role.
        The other role receives that table and stores it as
        ``self.intersect_ids``.

        Returns:
            ``self.intersect_ids``.

        Raises:
            ValueError: on the owner side when ``data`` has no header.
        """
        LOGGER.info("Start to share information with another role")
        if self.model_param.info_owner == consts.GUEST:
            info_share = self.transfer_variable.info_share_from_guest
        else:
            info_share = self.transfer_variable.info_share_from_host
        party_role = consts.GUEST if self.model_param.info_owner == consts.HOST else consts.HOST

        if self.role == self.model_param.info_owner:
            if data.schema.get('header') is not None:
                try:
                    share_info_col_idx = data.schema.get('header').index(consts.SHARE_INFO_COL_NAME)
                    one_data = data.first()
                    if isinstance(one_data[1], Instance):
                        share_data = data.join(self.intersect_ids, lambda d, i: [d.features[share_info_col_idx]])
                    else:
                        share_data = data.join(self.intersect_ids, lambda d, i: [d[share_info_col_idx]])
                    info_share.remote(share_data, role=party_role, idx=-1)
                    LOGGER.info("Remote share information to {}".format(party_role))
                except Exception as e:
                    # Deliberate best-effort: the peer is blocked waiting for a
                    # table, so a placeholder is sent rather than failing.
                    LOGGER.warning("Something unexpected:{}, share a empty information to {}".format(e, party_role))
                    share_data = self.intersect_ids.mapValues(lambda v: ['null'])
                    info_share.remote(share_data, role=party_role, idx=-1)
            else:
                raise ValueError(
                    "'allow_info_share' is true, and 'info_owner' is {}, but can not get header in data, information sharing not done".format(
                        self.model_param.info_owner))
        else:
            self.intersect_ids = info_share.get(idx=0)
            self.intersect_ids.schema['header'] = [consts.SHARE_INFO_COL_NAME]
            # Log the header itself; the table object was logged here before.
            LOGGER.info(
                "Get share information from {}, header:{}".format(self.model_param.info_owner,
                                                                  self.intersect_ids.schema['header']))
        return self.intersect_ids

    def __sync_join_id(self, data, intersect_data):
        """Build the left-join output and synchronise the non-intersecting ids.

        The sample id generator sends the ids found in ``data`` but not in
        ``intersect_data`` (fresh uuid hex ids when ``new_sample_id`` is set).
        The other role receives them and unions them onto ``intersect_data``.

        Returns:
            The joined result table for this role.
        """
        LOGGER.debug(f"data count: {data.count()}")
        LOGGER.debug(f"intersect_data count: {intersect_data.count()}")

        if self.model_param.sample_id_generator == consts.GUEST:
            sync_join_id = self.transfer_variable.join_id_from_guest
        else:
            sync_join_id = self.transfer_variable.join_id_from_host

        if self.role == self.model_param.sample_id_generator:
            join_data = data.subtractByKey(intersect_data)
            if self.model_param.new_sample_id:
                if self.model_param.only_output_key:
                    join_data = join_data.map(lambda k, v: (uuid.uuid4().hex, None))
                    join_id = join_data
                else:
                    join_data = join_data.map(lambda k, v: (uuid.uuid4().hex, v))
                    join_id = join_data.mapValues(lambda v: None)
                sync_join_id.remote(join_id)
                result_data = intersect_data.union(join_data)
            else:
                join_id = join_data.map(lambda k, v: (k, None))
                result_data = data
                if self.model_param.only_output_key:
                    if not self.use_match_id_process:
                        result_data = data.mapValues(lambda v: None)
                sync_join_id.remote(join_id)
        else:
            join_id = sync_join_id.get(idx=0)
            result_data = intersect_data.union(join_id)
        LOGGER.debug(f"result data count: {result_data.count()}")
        return result_data

    def callback(self):
        """Publish the intersection counts/rates and their metric meta."""
        meta_info = {"intersect_method": self.model_param.intersect_method,
                     "join_method": self.model_param.join_method}
        self.callback_metric(metric_name=self.metric_name,
                             metric_namespace=self.metric_namespace,
                             metric_data=[Metric("intersect_count", self.intersect_num),
                                          Metric("intersect_rate", self.intersect_rate),
                                          Metric("unmatched_count", self.unmatched_num),
                                          Metric("unmatched_rate", self.unmatched_rate)])
        self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
                                     metric_name=self.metric_name,
                                     metric_meta=MetricMeta(name=self.metric_name,
                                                            metric_type=self.metric_type,
                                                            extra_metas=meta_info))

    def callback_cache_meta(self, intersect_meta):
        """Publish ``intersect_meta`` as extra metas of the ``<metric_name>_cache_meta`` metric."""
        metric_name = f"{self.metric_name}_cache_meta"
        self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
                                     metric_name=metric_name,
                                     metric_meta=MetricMeta(name=f"{self.metric_name}_cache_meta",
                                                            metric_type=self.metric_type,
                                                            extra_metas=intersect_meta))

    def fit(self, data):
        """Run the intersection on ``data``.

        Returns:
            * the result of ``intersect_online_process`` when caches are supplied;
            * ``data`` unchanged when only a cache is generated or only the
              cardinality is computed;
            * otherwise the intersection result table (left-joined when
              ``join_method`` is left join).

        Raises:
            ValueError: on inconsistent match-id parameters or unsupported role.
        """
        if self.component_properties.caches:
            return self.intersect_online_process(data, self.component_properties.caches)
        self.init_intersect_method()
        if data_overview.check_with_inst_id(data):
            self.use_match_id_process = True
            LOGGER.info("use match_id_process")

        match_data = None
        if self.use_match_id_process:
            self.proc_obj = self._build_match_id_processor(data)
            match_data = self.proc_obj.recover(data=data)
        # The backend works on recovered match ids when that process is active.
        source_data = match_data if self.use_match_id_process else data

        if self.intersection_obj.run_cache:
            self.cache_output = self.intersection_obj.generate_cache(source_data)
            intersect_meta = self.intersection_obj.get_intersect_method_meta()
            self.callback_cache_meta(intersect_meta)
            return data

        if self.intersection_obj.cardinality_only:
            self.intersection_obj.run_cardinality(source_data)
            if self.intersection_obj.intersect_num is not None:
                self._record_intersect_stats(data.count(), self.intersection_obj.intersect_num)
            self.set_summary(self.get_model_summary())
            self.callback()
            return data

        intersect_data = source_data
        if self.model_param.run_preprocess:
            intersect_data = self.run_preprocess(source_data)
        self.intersect_ids = self.intersection_obj.run_intersect(intersect_data)

        if self.use_match_id_process:
            if self.model_param.sync_intersect_ids:
                self.intersect_ids = self.proc_obj.expand(self.intersect_ids, match_data=match_data)
            else:
                self.intersect_ids = self.proc_obj.expand(self.intersect_ids,
                                                          match_data=match_data,
                                                          owner_only=True)
            if self.model_param.only_output_key and self.intersect_ids:
                self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
                self.intersect_ids.schema = {"match_id_name": data.schema["match_id_name"],
                                             "sid_name": data.schema["sid_name"]}

        LOGGER.info("Finish intersection")

        if self.intersect_ids:
            self._record_intersect_stats(data.count(), self.intersect_ids.count())
        self.set_summary(self.get_model_summary())
        self.callback()

        result_data = self.intersect_ids
        if not self.use_match_id_process and not self.intersection_obj.only_output_key and result_data:
            result_data = self.intersection_obj.get_value_from_data(result_data, data)
            LOGGER.debug("not only_output_key, restore value called")
        if self.model_param.join_method == consts.LEFT_JOIN:
            result_data = self.__sync_join_id(data, self.intersect_ids)
            result_data.schema = self.intersect_ids.schema
        return result_data

    def check_consistency(self):
        """No-op hook."""
        pass

    def load_intersect_meta(self, intersect_meta):
        """Copy hashing settings from a cache record into the model parameters.

        Raises:
            ValueError: if the cached method differs from the configured one,
                or the configured method is neither RSA nor DH.
        """
        method = self.model_param.intersect_method
        if method == consts.RSA:
            if intersect_meta["intersect_method"] != consts.RSA:
                raise ValueError("Current intersect method must match to cache record.")
            rsa_params = self.model_param.rsa_params
            rsa_params.hash_method = intersect_meta["hash_method"]
            rsa_params.final_hash_method = intersect_meta["final_hash_method"]
            rsa_params.salt = intersect_meta["salt"]
            rsa_params.random_bit = intersect_meta["random_bit"]
        elif method == consts.DH:
            if intersect_meta["intersect_method"] != consts.DH:
                raise ValueError("Current intersect method must match to cache record.")
            dh_params = self.model_param.dh_params
            dh_params.hash_method = intersect_meta["hash_method"]
            dh_params.salt = intersect_meta["salt"]
        else:
            raise ValueError(f"{self.model_param.intersect_method} does not support cache.")

    def make_filter_process(self, data_instances, hash_operator):
        """Filter-owner side of preprocessing; not implemented in this base class."""
        raise NotImplementedError("This method should not be called here")

    def get_filter_process(self, data_instances, hash_operator):
        """Non-owner side of preprocessing; not implemented in this base class."""
        raise NotImplementedError("This method should not be called here")

    def run_preprocess(self, data_instances):
        """Run the filter preprocessing step and return the resulting data.

        The filter owner calls ``make_filter_process``; the other role calls
        ``get_filter_process`` and logs the data count before and after.
        """
        preprocess_params = self.model_param.intersect_preprocess_params
        preprocess_hash_operator = Hash(preprocess_params.preprocess_method, False)
        if self.role == self.model_param.intersect_preprocess_params.filter_owner:
            data = self.make_filter_process(data_instances, preprocess_hash_operator)
        else:
            LOGGER.debug(f"before preprocess, data count: {data_instances.count()}")
            data = self.get_filter_process(data_instances, preprocess_hash_operator)
            LOGGER.debug(f"after preprocess, data count: {data.count()}")
        return data

    def intersect_online_process(self, data_inst, caches):
        """Intersect ``data_inst`` against a previously generated cache.

        Args:
            data_inst: input table.
            caches: mapping whose first value is a ``(cache_data, cache_meta)`` pair.

        Returns:
            The intersection result table (left-joined when ``join_method``
            is left join).

        Raises:
            ValueError: on cache/method mismatch, cache id mismatch between
                parties, inconsistent match-id parameters, or unsupported role.
        """
        cache_data, cache_meta = list(caches.values())[0]
        intersect_meta = list(cache_meta.values())[0]["intersect_meta"]
        self.callback_cache_meta(intersect_meta)
        self.load_intersect_meta(intersect_meta)
        self.init_intersect_method()
        self.intersection_obj.load_intersect_key(cache_meta)

        if data_overview.check_with_inst_id(data_inst):
            self.use_match_id_process = True
            LOGGER.info("use match_id_process")

        intersect_data = data_inst
        proc_obj = None
        match_data = None
        if self.use_match_id_process:
            # Kept local on purpose: this path does not touch self.proc_obj.
            proc_obj = self._build_match_id_processor(data_inst)
            match_data = proc_obj.recover(data=data_inst)
            intersect_data = match_data

        self._verify_cache_ids(cache_meta)

        self.intersect_ids = self.intersection_obj.run_cache_intersect(intersect_data, cache_data)
        if self.use_match_id_process:
            if not self.model_param.sync_intersect_ids:
                self.intersect_ids = proc_obj.expand(self.intersect_ids,
                                                     match_data=match_data,
                                                     owner_only=True)
            else:
                self.intersect_ids = proc_obj.expand(self.intersect_ids, match_data=match_data)
            if self.intersect_ids and self.model_param.only_output_key:
                self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
                self.intersect_ids.schema = {"match_id_name": data_inst.schema["match_id_name"],
                                             "sid_name": data_inst.schema["sid_name"]}

        LOGGER.info("Finish intersection")

        if self.intersect_ids:
            self._record_intersect_stats(data_inst.count(), self.intersect_ids.count())
        self.set_summary(self.get_model_summary())
        self.callback()

        result_data = self.intersect_ids
        if not self.use_match_id_process:
            if not self.intersection_obj.only_output_key and result_data:
                result_data = self.intersection_obj.get_value_from_data(result_data, data_inst)
                self.intersect_ids.schema = result_data.schema
                LOGGER.debug("not only_output_key, restore value called")
            if self.intersection_obj.only_output_key and result_data:
                schema = {"sid_name": data_inst.schema["sid_name"]}
                result_data = result_data.mapValues(lambda v: 1)
                result_data.schema = schema
                self.intersect_ids.schema = schema
        if self.model_param.join_method == consts.LEFT_JOIN:
            result_data = self.__sync_join_id(data_inst, self.intersect_ids)
            result_data.schema = self.intersect_ids.schema
        return result_data
def intersect_online_process(self, data_inst, caches):
    """Intersect ``data_inst`` against a previously generated cache.

    Args:
        data_inst: input table.
        caches: mapping whose first value is a ``(cache_data, cache_meta)`` pair.

    Returns:
        The intersection result table (left-joined when ``join_method`` is
        left join).

    Raises:
        ValueError: on inconsistent match-id parameters, cache id mismatch
            between parties, or a role that is neither host nor guest.
    """
    cache_data, cache_meta = list(caches.values())[0]
    intersect_meta = list(cache_meta.values())[0]["intersect_meta"]
    self.callback_cache_meta(intersect_meta)
    self.load_intersect_meta(intersect_meta)
    self.init_intersect_method()
    self.intersection_obj.load_intersect_key(cache_meta)

    if data_overview.check_with_inst_id(data_inst):
        self.use_match_id_process = True
        LOGGER.info("use match_id_process")

    intersect_data = data_inst
    if self.use_match_id_process:
        if len(self.host_party_id_list) > 1 and self.model_param.sample_id_generator != consts.GUEST:
            raise ValueError("While multi-host, sample_id_generator should be guest.")
        if self.model_param.intersect_method == consts.RAW:
            if self.model_param.sample_id_generator != self.intersection_obj.join_role:
                raise ValueError("When using raw intersect with match id process,"
                                 "'join_role' should be same role as 'sample_id_generator'")
        elif not self.model_param.sync_intersect_ids:
            if self.model_param.sample_id_generator != consts.GUEST:
                self.model_param.sample_id_generator = consts.GUEST
                LOGGER.warning("when not sync_intersect_ids with match id process,"
                               "sample_id_generator is set to Guest")
        proc_obj = MatchIDIntersect(sample_id_generator=self.model_param.sample_id_generator, role=self.role)
        proc_obj.new_sample_id = self.model_param.new_sample_id
        if data_overview.check_with_inst_id(data_inst) or self.model_param.with_sample_id:
            proc_obj.use_sample_id()
        match_data = proc_obj.recover(data=data_inst)
        intersect_data = match_data

    # Both sides exchange cache ids and refuse to continue on a mismatch.
    if self.role == consts.HOST:
        cache_id = cache_meta[str(self.guest_party_id)].get("cache_id")
        self.transfer_variable.cache_id.remote(cache_id, role=consts.GUEST, idx=0)
        guest_cache_id = self.transfer_variable.cache_id.get(role=consts.GUEST, idx=0)
        if guest_cache_id != cache_id:
            raise ValueError("cache_id check failed. cache_id from host & guest must match.")
    elif self.role == consts.GUEST:
        for i, party_id in enumerate(self.host_party_id_list):
            cache_id = cache_meta[str(party_id)].get("cache_id")
            self.transfer_variable.cache_id.remote(cache_id, role=consts.HOST, idx=i)
            host_cache_id = self.transfer_variable.cache_id.get(role=consts.HOST, idx=i)
            if host_cache_id != cache_id:
                raise ValueError("cache_id check failed. cache_id from host & guest must match.")
    else:
        raise ValueError(f"Role {self.role} cannot run intersection transform.")

    self.intersect_ids = self.intersection_obj.run_cache_intersect(intersect_data, cache_data)
    if self.use_match_id_process:
        if not self.model_param.sync_intersect_ids:
            self.intersect_ids = proc_obj.expand(self.intersect_ids,
                                                 match_data=match_data,
                                                 owner_only=True)
        else:
            self.intersect_ids = proc_obj.expand(self.intersect_ids, match_data=match_data)
        if self.intersect_ids and self.model_param.only_output_key:
            self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
            self.intersect_ids.schema = {"match_id_name": data_inst.schema["match_id_name"],
                                         "sid_name": data_inst.schema["sid_name"]}

    LOGGER.info("Finish intersection")

    if self.intersect_ids:
        data_count = data_inst.count()
        self.intersect_num = self.intersect_ids.count()
        self.unmatched_num = data_count - self.intersect_num
        if data_count:
            self.intersect_rate = self.intersect_num / data_count
            self.unmatched_rate = 1 - self.intersect_rate
        else:
            # Empty input: no ratio exists, report 0.0 instead of dividing by zero.
            self.intersect_rate = 0.0
            self.unmatched_rate = 0.0

    self.set_summary(self.get_model_summary())
    self.callback()

    result_data = self.intersect_ids
    if not self.use_match_id_process:
        if not self.intersection_obj.only_output_key and result_data:
            result_data = self.intersection_obj.get_value_from_data(result_data, data_inst)
            self.intersect_ids.schema = result_data.schema
            LOGGER.debug("not only_output_key, restore value called")
        if self.intersection_obj.only_output_key and result_data:
            schema = {"sid_name": data_inst.schema["sid_name"]}
            result_data = result_data.mapValues(lambda v: 1)
            result_data.schema = schema
            self.intersect_ids.schema = schema
    if self.model_param.join_method == consts.LEFT_JOIN:
        result_data = self.__sync_join_id(data_inst, self.intersect_ids)
        result_data.schema = self.intersect_ids.schema
    return result_data
def _recover_match_id(self, data_instance):
    """Create a guest-generated ``MatchIDIntersect`` and recover ``data_instance`` through it.

    The processor is stored on ``self.proc_obj`` and takes its role from
    ``self.intersection_obj.role``.

    Returns:
        Whatever ``MatchIDIntersect.recover`` returns for ``data_instance``.
    """
    matcher = MatchIDIntersect(sample_id_generator=consts.GUEST, role=self.intersection_obj.role)
    self.proc_obj = matcher
    # NOTE(review): other callers in this file configure `new_sample_id`;
    # confirm `new_join_id` is still an attribute MatchIDIntersect reads.
    matcher.new_join_id = False
    matcher.use_sample_id()
    return matcher.recover(data=data_instance)
def fit(self, data):
    """Run the intersection on ``data``.

    Returns:
        * the result of ``intersect_online_process`` when caches are supplied;
        * ``data`` unchanged when only a cache is generated or only the
          cardinality is computed;
        * otherwise the intersection result table (left-joined when
          ``join_method`` is left join).

    Raises:
        ValueError: on inconsistent match-id parameters or unsupported role.
    """
    if self.component_properties.caches:
        return self.intersect_online_process(data, self.component_properties.caches)
    self.init_intersect_method()
    if data_overview.check_with_inst_id(data):
        self.use_match_id_process = True
        LOGGER.info("use match_id_process")

    if self.use_match_id_process:
        if len(self.host_party_id_list) > 1 and self.model_param.sample_id_generator != consts.GUEST:
            raise ValueError("While multi-host, sample_id_generator should be guest.")
        if self.model_param.intersect_method == consts.RAW:
            if self.model_param.sample_id_generator != self.intersection_obj.join_role:
                raise ValueError("When using raw intersect with match id process,"
                                 "'join_role' should be same role as 'sample_id_generator'")
        elif not self.model_param.sync_intersect_ids:
            if self.model_param.sample_id_generator != consts.GUEST:
                self.model_param.sample_id_generator = consts.GUEST
                LOGGER.warning("when not sync_intersect_ids with match id process,"
                               "sample_id_generator is set to Guest")
        self.proc_obj = MatchIDIntersect(sample_id_generator=self.model_param.sample_id_generator,
                                         role=self.role)
        self.proc_obj.new_sample_id = self.model_param.new_sample_id
        if data_overview.check_with_inst_id(data) or self.model_param.with_sample_id:
            self.proc_obj.use_sample_id()
        match_data = self.proc_obj.recover(data=data)
        if self.intersection_obj.run_cache:
            self.cache_output = self.intersection_obj.generate_cache(match_data)
            intersect_meta = self.intersection_obj.get_intersect_method_meta()
            self.callback_cache_meta(intersect_meta)
            return data
        if self.intersection_obj.cardinality_only:
            self.intersection_obj.run_cardinality(match_data)
        else:
            intersect_data = match_data
            if self.model_param.run_preprocess:
                intersect_data = self.run_preprocess(match_data)
            self.intersect_ids = self.intersection_obj.run_intersect(intersect_data)
    else:
        if self.intersection_obj.run_cache:
            self.cache_output = self.intersection_obj.generate_cache(data)
            intersect_meta = self.intersection_obj.get_intersect_method_meta()
            self.callback_cache_meta(intersect_meta)
            return data
        if self.intersection_obj.cardinality_only:
            self.intersection_obj.run_cardinality(data)
        else:
            intersect_data = data
            if self.model_param.run_preprocess:
                intersect_data = self.run_preprocess(data)
            self.intersect_ids = self.intersection_obj.run_intersect(intersect_data)

    if self.intersection_obj.cardinality_only:
        if self.intersection_obj.intersect_num is not None:
            data_count = data.count()
            self.intersect_num = self.intersection_obj.intersect_num
            self.unmatched_num = data_count - self.intersect_num
            if data_count:
                self.intersect_rate = self.intersect_num / data_count
                self.unmatched_rate = 1 - self.intersect_rate
            else:
                # Empty input: no ratio exists, report 0.0 instead of dividing by zero.
                self.intersect_rate = 0.0
                self.unmatched_rate = 0.0
        self.set_summary(self.get_model_summary())
        self.callback()
        return data

    if self.use_match_id_process:
        if self.model_param.sync_intersect_ids:
            self.intersect_ids = self.proc_obj.expand(self.intersect_ids, match_data=match_data)
        else:
            self.intersect_ids = self.proc_obj.expand(self.intersect_ids,
                                                      match_data=match_data,
                                                      owner_only=True)
        if self.model_param.only_output_key and self.intersect_ids:
            self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
            self.intersect_ids.schema = {"match_id_name": data.schema["match_id_name"],
                                         "sid_name": data.schema["sid_name"]}

    LOGGER.info("Finish intersection")

    if self.intersect_ids:
        data_count = data.count()
        self.intersect_num = self.intersect_ids.count()
        self.unmatched_num = data_count - self.intersect_num
        if data_count:
            self.intersect_rate = self.intersect_num / data_count
            self.unmatched_rate = 1 - self.intersect_rate
        else:
            # Empty input: no ratio exists, report 0.0 instead of dividing by zero.
            self.intersect_rate = 0.0
            self.unmatched_rate = 0.0

    self.set_summary(self.get_model_summary())
    self.callback()

    result_data = self.intersect_ids
    if not self.use_match_id_process and not self.intersection_obj.only_output_key and result_data:
        result_data = self.intersection_obj.get_value_from_data(result_data, data)
        LOGGER.debug("not only_output_key, restore value called")
    if self.model_param.join_method == consts.LEFT_JOIN:
        result_data = self.__sync_join_id(data, self.intersect_ids)
        result_data.schema = self.intersect_ids.schema
    return result_data
class BaseSecureInformationRetrieval(ModelBase): """ """ def __init__(self): super(BaseSecureInformationRetrieval, self).__init__() self.model_param = SecureInformationRetrievalParam() self.security_level = None self.commutative_cipher = None self.transfer_variable = None self.block_num = None # N in 1-N OT self.coverage = None # the percentage of transactions whose values are successfully retrieved self.dh_params = None self.intersection_obj = None self.proc_obj = None self.with_inst_id = None self.need_label = False self.target_cols = None # For callback self.metric_name = "sir" self.metric_namespace = "train" self.metric_type = "SIR" def _init_base_model(self, param: SecureInformationRetrievalParam): self.transfer_variable = SecureInformationRetrievalTransferVariable() self._init_transfer_variable() self.model_param = param self.security_level = self.model_param.security_level self.dh_params = self.model_param.dh_params self.target_cols = self.model_param.target_cols def _init_transfer_variable(self): self.transfer_variable.natural_indexation.disable_auto_clean() self.transfer_variable.id_blocks_ciphertext.disable_auto_clean() @staticmethod def _abnormal_detection(data_instances): """ Make sure input data_instances is valid. 
""" abnormal_detection.empty_table_detection(data_instances) abnormal_detection.empty_feature_detection(data_instances) """ @staticmethod def record_original_id(k, v): if isinstance(k, str): restored_id = conversion.int_to_str(conversion.str_to_int(k)) else: restored_id = k return (restored_id, k) """ def _check_need_label(self): return len(self.target_cols) == 0 def _recover_match_id(self, data_instance): self.proc_obj = MatchIDIntersect(sample_id_generator=consts.GUEST, role=self.intersection_obj.role) self.proc_obj.new_join_id = False self.proc_obj.use_sample_id() match_data = self.proc_obj.recover(data=data_instance) return match_data def _restore_sample_id(self, data_instances): restore_data = self.proc_obj.expand(data_instances, owner_only=True) return restore_data def _raw_information_retrieval(self, data_instance): """ If security_level == 0, then perform raw information retrieval :param data_instance: :return: """ pass def _parse_security_level(self, data_instance): """ Cooperatively parse the security level index :param data_instance: :return: """ pass def _sync_natural_index(self, id_list_arr): """ guest -> host :param id_list_arr: :return: """ def _sync_natural_indexation(self, id_list, time): """ guest -> host :param id_list: :param time :return: """ def _sync_block_num(self): """ guest -> host :param :return: """ def _transmit_value_ciphertext(self, id_block, time): """ host -> guest :param id_block: :param time: int :return: """ def _check_oblivious_transfer_condition(self): """ 1-N OT with N no smaller than 2 is supported :return: """ return self.block_num >= 2 def _failure_response(self): """ If even 1-2 OT cannot be performed, make failure response :return: """ raise ValueError("Cannot perform even 1-2 OT, recommend use raw retrieval") def _sync_coverage(self, data_instance): """ guest -> host :param data_instance: :return: """ pass def _sync_nonce_list(self, nonce, time): """ host -> guest :param nonce: :return: """ pass def export_model(self): 
if self.model_output is not None: return self.model_output meta_obj = self._get_meta() param_obj = self._get_param() result = { MODEL_META_NAME: meta_obj, MODEL_PARAM_NAME: param_obj } self.model_output = result return result def _get_meta(self): return sir_meta_pb2.SecureInformationRetrievalMeta( security_level=self.security_level, oblivious_transfer_protocol=self.model_param.oblivious_transfer_protocol, commutative_encryption=self.model_param.commutative_encryption, non_committing_encryption=self.model_param.non_committing_encryption, key_size=self.model_param.key_size, raw_retrieval=self.model_param.raw_retrieval ) def _get_param(self): return sir_param_pb2.SecureInformationRetrievalParam( coverage=self.coverage, block_num=self.block_num ) def _display_result(self, block_num=None): if block_num is None: self.callback_metric(metric_name=self.metric_name, metric_namespace=self.metric_namespace, metric_data=[Metric("Coverage", self.coverage), Metric("Block number", self.block_num)]) self.tracker.set_metric_meta(metric_namespace=self.metric_namespace, metric_name=self.metric_name, metric_meta=MetricMeta(self.metric_name, metric_type="INTERSECTION")) else: self.callback_metric(metric_name=self.metric_name, metric_namespace=self.metric_namespace, metric_data=[Metric("Coverage", self.coverage), Metric("Block number", block_num)]) self.tracker.set_metric_meta(metric_namespace=self.metric_namespace, metric_name=self.metric_name, metric_meta=MetricMeta(self.metric_name, metric_type="INTERSECTION")) """