コード例 #1
0
class IntersectModelBase(ModelBase):
    """Base component for private set intersection (PSI) between guest and host.

    Implements the role-independent workflow around a protocol object stored in
    ``self.intersection_obj`` (not created in this class; presumably set up by a
    subclass): match-id (``inst_id``) recovery/expansion, optional preprocessing
    filter, cache generation and cache-based ("online") intersection,
    cardinality-only runs, left-join completion and metric/summary reporting.
    """

    def __init__(self):
        super().__init__()
        self.intersection_obj = None
        self.proc_obj = None
        # intersection statistics; stay at -1 until they are computed
        self.intersect_num = -1
        self.intersect_rate = -1
        self.unmatched_num = -1
        self.unmatched_rate = -1
        self.intersect_ids = None
        self.metric_name = "intersection"
        self.metric_namespace = "train"
        self.metric_type = "INTERSECTION"
        self.model_param_name = "IntersectModelParam"
        self.model_meta_name = "IntersectModelMeta"
        self.model_param = IntersectParam()
        # switched on when the input data carries inst_id (match id) information
        self.use_match_id_process = False
        self.role = None

        self.guest_party_id = None
        self.host_party_id = None
        self.host_party_id_list = None

        self.transfer_variable = IntersectionFuncTransferVariable()

    def _init_model(self, params):
        self.model_param = params
        self.intersect_preprocess_params = params.intersect_preprocess_params

    def init_intersect_method(self):
        """Read party ids from the component properties and validate ``self.role``.

        Raises:
            ValueError: if the role is neither host nor guest.
        """
        LOGGER.info("Using {} intersection, role is {}".format(self.model_param.intersect_method, self.role))
        self.host_party_id_list = self.component_properties.host_party_idlist
        self.guest_party_id = self.component_properties.guest_partyid

        if self.role not in [consts.HOST, consts.GUEST]:
            raise ValueError("role {} is not support".format(self.role))

    def get_model_summary(self):
        return {"intersect_num": self.intersect_num, "intersect_rate": self.intersect_rate,
                "cardinality_only": self.intersection_obj.cardinality_only}

    def _update_intersect_stats(self, data_count, intersect_num):
        """Record intersect/unmatched counts and rates for ``data_count`` input rows.

        An empty input reports both rates as 0 instead of raising ZeroDivisionError.
        """
        self.intersect_num = intersect_num
        self.unmatched_num = data_count - intersect_num
        if data_count:
            self.intersect_rate = intersect_num / data_count
            self.unmatched_rate = 1 - self.intersect_rate
        else:
            self.intersect_rate = 0
            self.unmatched_rate = 0

    def _create_match_id_proc(self, data):
        """Validate match-id settings and build the ``MatchIDIntersect`` helper for ``data``.

        When intersect ids are not synced (and the method is not raw), a non-guest
        ``model_param.sample_id_generator`` is reset to guest with a warning.

        Raises:
            ValueError: on multi-host with a non-guest ``sample_id_generator``, or on raw
                intersection whose ``join_role`` differs from ``sample_id_generator``.
        """
        if len(self.host_party_id_list) > 1 and self.model_param.sample_id_generator != consts.GUEST:
            raise ValueError("While multi-host, sample_id_generator should be guest.")
        if self.model_param.intersect_method == consts.RAW:
            if self.model_param.sample_id_generator != self.intersection_obj.join_role:
                raise ValueError("When using raw intersect with match id process, "
                                 "'join_role' should be same role as 'sample_id_generator'")
        elif not self.model_param.sync_intersect_ids and self.model_param.sample_id_generator != consts.GUEST:
            self.model_param.sample_id_generator = consts.GUEST
            LOGGER.warning("when not sync_intersect_ids with match id process, "
                           "sample_id_generator is set to Guest")

        proc_obj = MatchIDIntersect(sample_id_generator=self.model_param.sample_id_generator, role=self.role)
        proc_obj.new_sample_id = self.model_param.new_sample_id
        if data_overview.check_with_inst_id(data) or self.model_param.with_sample_id:
            proc_obj.use_sample_id()
        return proc_obj

    def __share_info(self, data):
        """Share the ``SHARE_INFO_COL_NAME`` column of intersected rows with the other party.

        The ``info_owner`` role joins that column of ``data`` onto ``self.intersect_ids``
        and sends it (falling back to a ``'null'`` placeholder per id if that fails);
        the other role receives it and stores it in ``self.intersect_ids``.

        Returns:
            ``self.intersect_ids`` (replaced by the received table on the receiving side).

        Raises:
            ValueError: if the info owner's ``data`` has no header.
        """
        LOGGER.info("Start to share information with another role")
        info_share = self.transfer_variable.info_share_from_guest if self.model_param.info_owner == consts.GUEST else \
            self.transfer_variable.info_share_from_host
        party_role = consts.GUEST if self.model_param.info_owner == consts.HOST else consts.HOST

        if self.role == self.model_param.info_owner:
            if data.schema.get('header') is not None:
                try:
                    share_info_col_idx = data.schema.get('header').index(consts.SHARE_INFO_COL_NAME)

                    one_data = data.first()
                    # Instance rows keep the column in .features; other rows are indexed directly
                    if isinstance(one_data[1], Instance):
                        share_data = data.join(self.intersect_ids, lambda d, i: [d.features[share_info_col_idx]])
                    else:
                        share_data = data.join(self.intersect_ids, lambda d, i: [d[share_info_col_idx]])

                    info_share.remote(share_data,
                                      role=party_role,
                                      idx=-1)
                    LOGGER.info("Remote share information to {}".format(party_role))

                except Exception as e:
                    # best effort: the receiving side always blocks on info_share.get,
                    # so send placeholders rather than nothing
                    LOGGER.warning("Something unexpected:{}, share a empty information to {}".format(e, party_role))
                    share_data = self.intersect_ids.mapValues(lambda v: ['null'])
                    info_share.remote(share_data,
                                      role=party_role,
                                      idx=-1)
            else:
                raise ValueError(
                    "'allow_info_share' is true, and 'info_owner' is {}, but can not get header in data, information sharing not done".format(
                        self.model_param.info_owner))
        else:
            self.intersect_ids = info_share.get(idx=0)
            self.intersect_ids.schema['header'] = [consts.SHARE_INFO_COL_NAME]
            # log the header itself, not the whole received table
            LOGGER.info(
                "Get share information from {}, header:{}".format(self.model_param.info_owner,
                                                                  self.intersect_ids.schema['header']))

        return self.intersect_ids

    def __sync_join_id(self, data, intersect_data):
        """Complete a left join between the two parties.

        The ``sample_id_generator`` role sends the ids of its rows missing from
        ``intersect_data`` (re-keyed with fresh uuid4 hex ids when ``new_sample_id``)
        through ``sync_join_id.remote``; the other role receives them and unions them
        into its ``intersect_data``.

        Returns:
            The left-joined table for this party.
        """
        LOGGER.debug(f"data count: {data.count()}")
        LOGGER.debug(f"intersect_data count: {intersect_data.count()}")

        if self.model_param.sample_id_generator == consts.GUEST:
            sync_join_id = self.transfer_variable.join_id_from_guest
        else:
            sync_join_id = self.transfer_variable.join_id_from_host
        if self.role == self.model_param.sample_id_generator:
            # rows of this party that did not make it into the intersection
            join_data = data.subtractByKey(intersect_data)
            if self.model_param.new_sample_id:
                if self.model_param.only_output_key:
                    join_data = join_data.map(lambda k, v: (uuid.uuid4().hex, None))
                    join_id = join_data
                else:
                    join_data = join_data.map(lambda k, v: (uuid.uuid4().hex, v))
                    join_id = join_data.mapValues(lambda v: None)
                sync_join_id.remote(join_id)

                result_data = intersect_data.union(join_data)
            else:
                join_id = join_data.map(lambda k, v: (k, None))
                result_data = data
                if self.model_param.only_output_key:
                    if not self.use_match_id_process:
                        result_data = data.mapValues(lambda v: None)
                sync_join_id.remote(join_id)
        else:
            join_id = sync_join_id.get(idx=0)
            result_data = intersect_data.union(join_id)
        LOGGER.debug(f"result data count: {result_data.count()}")
        return result_data

    def callback(self):
        """Report intersect/unmatched count and rate metrics plus a metric meta that
        records the intersect and join methods."""
        meta_info = {"intersect_method": self.model_param.intersect_method,
                     "join_method": self.model_param.join_method}
        self.callback_metric(metric_name=self.metric_name,
                             metric_namespace=self.metric_namespace,
                             metric_data=[Metric("intersect_count", self.intersect_num),
                                          Metric("intersect_rate", self.intersect_rate),
                                          Metric("unmatched_count", self.unmatched_num),
                                          Metric("unmatched_rate", self.unmatched_rate)])
        self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
                                     metric_name=self.metric_name,
                                     metric_meta=MetricMeta(name=self.metric_name,
                                                            metric_type=self.metric_type,
                                                            extra_metas=meta_info)
                                     )

    def callback_cache_meta(self, intersect_meta):
        """Register ``intersect_meta`` as extra metas of the ``<metric_name>_cache_meta`` metric."""
        metric_name = f"{self.metric_name}_cache_meta"
        self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
                                     metric_name=metric_name,
                                     metric_meta=MetricMeta(name=metric_name,
                                                            metric_type=self.metric_type,
                                                            extra_metas=intersect_meta)
                                     )

    def fit(self, data):
        """Run the intersection on ``data``.

        If cache results are present in ``component_properties.caches``, delegates to
        ``intersect_online_process``. Otherwise performs one of:

        * cache generation (``run_cache``): returns ``data`` unchanged;
        * cardinality-only intersection: records stats and returns ``data``;
        * full intersection: returns the intersected ids, with values restored when
          keys-only output is off (non match-id data), left-joined when requested.
        """
        if self.component_properties.caches:
            return self.intersect_online_process(data, self.component_properties.caches)
        self.init_intersect_method()
        if data_overview.check_with_inst_id(data):
            self.use_match_id_process = True
            LOGGER.info("use match_id_process")

        match_data = None
        if self.use_match_id_process:
            self.proc_obj = self._create_match_id_proc(data)
            match_data = self.proc_obj.recover(data=data)
        # with match-id data the protocol runs on the recovered table
        input_data = match_data if self.use_match_id_process else data

        if self.intersection_obj.run_cache:
            self.cache_output = self.intersection_obj.generate_cache(input_data)
            intersect_meta = self.intersection_obj.get_intersect_method_meta()
            self.callback_cache_meta(intersect_meta)
            return data
        if self.intersection_obj.cardinality_only:
            self.intersection_obj.run_cardinality(input_data)
        else:
            intersect_data = input_data
            if self.model_param.run_preprocess:
                intersect_data = self.run_preprocess(input_data)
            self.intersect_ids = self.intersection_obj.run_intersect(intersect_data)

        if self.intersection_obj.cardinality_only:
            if self.intersection_obj.intersect_num is not None:
                self._update_intersect_stats(data.count(), self.intersection_obj.intersect_num)
            self.set_summary(self.get_model_summary())
            self.callback()
            return data

        if self.use_match_id_process:
            if self.model_param.sync_intersect_ids:
                self.intersect_ids = self.proc_obj.expand(self.intersect_ids, match_data=match_data)
            else:
                self.intersect_ids = self.proc_obj.expand(self.intersect_ids,
                                                          match_data=match_data,
                                                          owner_only=True)
            if self.model_param.only_output_key and self.intersect_ids:
                self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
                self.intersect_ids.schema = {"match_id_name": data.schema["match_id_name"],
                                             "sid_name": data.schema["sid_name"]}

        LOGGER.info("Finish intersection")

        if self.intersect_ids:
            self._update_intersect_stats(data.count(), self.intersect_ids.count())

        self.set_summary(self.get_model_summary())
        self.callback()

        result_data = self.intersect_ids
        if not self.use_match_id_process and not self.intersection_obj.only_output_key and result_data:
            result_data = self.intersection_obj.get_value_from_data(result_data, data)
            LOGGER.debug("not only_output_key, restore value called")

        if self.model_param.join_method == consts.LEFT_JOIN:
            result_data = self.__sync_join_id(data, self.intersect_ids)
            result_data.schema = self.intersect_ids.schema

        return result_data

    def check_consistency(self):
        pass

    def load_intersect_meta(self, intersect_meta):
        """Overwrite RSA/DH hashing parameters with those recorded in a cache's ``intersect_meta``.

        Raises:
            ValueError: if the cached method differs from the configured one, or the
                configured method does not support cache.
        """
        if self.model_param.intersect_method == consts.RSA:
            if intersect_meta["intersect_method"] != consts.RSA:
                raise ValueError("Current intersect method must match to cache record.")
            self.model_param.rsa_params.hash_method = intersect_meta["hash_method"]
            self.model_param.rsa_params.final_hash_method = intersect_meta["final_hash_method"]
            self.model_param.rsa_params.salt = intersect_meta["salt"]
            self.model_param.rsa_params.random_bit = intersect_meta["random_bit"]
        elif self.model_param.intersect_method == consts.DH:
            if intersect_meta["intersect_method"] != consts.DH:
                raise ValueError("Current intersect method must match to cache record.")
            self.model_param.dh_params.hash_method = intersect_meta["hash_method"]
            self.model_param.dh_params.salt = intersect_meta["salt"]
        else:
            raise ValueError(f"{self.model_param.intersect_method} does not support cache.")

    def make_filter_process(self, data_instances, hash_operator):
        """Hook for the preprocessing filter owner; must be overridden."""
        raise NotImplementedError("This method should not be called here")

    def get_filter_process(self, data_instances, hash_operator):
        """Hook for the non-owner side of preprocessing; must be overridden."""
        raise NotImplementedError("This method should not be called here")

    def run_preprocess(self, data_instances):
        """Apply the preprocessing filter to ``data_instances``.

        The ``filter_owner`` role calls ``make_filter_process``, the other role calls
        ``get_filter_process``; both receive a ``Hash`` built from ``preprocess_method``.
        """
        preprocess_hash_operator = Hash(self.model_param.intersect_preprocess_params.preprocess_method, False)
        if self.role == self.model_param.intersect_preprocess_params.filter_owner:
            data = self.make_filter_process(data_instances, preprocess_hash_operator)
        else:
            LOGGER.debug(f"before preprocess, data count: {data_instances.count()}")
            data = self.get_filter_process(data_instances, preprocess_hash_operator)
            LOGGER.debug(f"after preprocess, data count: {data.count()}")
        return data

    def intersect_online_process(self, data_inst, caches):
        """Intersect ``data_inst`` against a previously generated intersection cache.

        Uses the first value of ``caches`` as ``(cache_data, cache_meta)``: registers and
        loads its ``intersect_meta``, loads the cached keys, exchanges cache ids with the
        other party (raising if they differ), runs ``run_cache_intersect`` and
        post-processes the ids (match-id expansion, value restoration or key-only output,
        optional left join).

        Raises:
            ValueError: on invalid match-id settings, cache id mismatch or unsupported role.
        """
        cache_data, cache_meta = list(caches.values())[0]
        intersect_meta = list(cache_meta.values())[0]["intersect_meta"]
        self.callback_cache_meta(intersect_meta)
        self.load_intersect_meta(intersect_meta)
        self.init_intersect_method()
        self.intersection_obj.load_intersect_key(cache_meta)

        if data_overview.check_with_inst_id(data_inst):
            self.use_match_id_process = True
            LOGGER.info("use match_id_process")
        intersect_data = data_inst
        if self.use_match_id_process:
            proc_obj = self._create_match_id_proc(data_inst)
            match_data = proc_obj.recover(data=data_inst)
            intersect_data = match_data

        # both sides exchange cache ids and abort unless they refer to the same cache
        if self.role == consts.HOST:
            cache_id = cache_meta[str(self.guest_party_id)].get("cache_id")
            self.transfer_variable.cache_id.remote(cache_id, role=consts.GUEST, idx=0)
            guest_cache_id = self.transfer_variable.cache_id.get(role=consts.GUEST, idx=0)
            if guest_cache_id != cache_id:
                raise ValueError("cache_id check failed. cache_id from host & guest must match.")
        elif self.role == consts.GUEST:
            for i, party_id in enumerate(self.host_party_id_list):
                cache_id = cache_meta[str(party_id)].get("cache_id")
                self.transfer_variable.cache_id.remote(cache_id,
                                                       role=consts.HOST,
                                                       idx=i)
                host_cache_id = self.transfer_variable.cache_id.get(role=consts.HOST, idx=i)
                if host_cache_id != cache_id:
                    raise ValueError("cache_id check failed. cache_id from host & guest must match.")
        else:
            raise ValueError(f"Role {self.role} cannot run intersection transform.")

        self.intersect_ids = self.intersection_obj.run_cache_intersect(intersect_data, cache_data)
        if self.use_match_id_process:
            if not self.model_param.sync_intersect_ids:
                self.intersect_ids = proc_obj.expand(self.intersect_ids,
                                                     match_data=match_data,
                                                     owner_only=True)
            else:
                self.intersect_ids = proc_obj.expand(self.intersect_ids, match_data=match_data)
            if self.intersect_ids and self.model_param.only_output_key:
                self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
                self.intersect_ids.schema = {"match_id_name": data_inst.schema["match_id_name"],
                                             "sid_name": data_inst.schema["sid_name"]}

        LOGGER.info("Finish intersection")

        if self.intersect_ids:
            self._update_intersect_stats(data_inst.count(), self.intersect_ids.count())

        self.set_summary(self.get_model_summary())
        self.callback()

        result_data = self.intersect_ids
        if not self.use_match_id_process:
            if not self.intersection_obj.only_output_key and result_data:
                result_data = self.intersection_obj.get_value_from_data(result_data, data_inst)
                self.intersect_ids.schema = result_data.schema
                LOGGER.debug("not only_output_key, restore value called")
            if self.intersection_obj.only_output_key and result_data:
                schema = {"sid_name": data_inst.schema["sid_name"]}
                result_data = result_data.mapValues(lambda v: 1)
                result_data.schema = schema
                self.intersect_ids.schema = schema

        if self.model_param.join_method == consts.LEFT_JOIN:
            result_data = self.__sync_join_id(data_inst, self.intersect_ids)
            result_data.schema = self.intersect_ids.schema

        return result_data
コード例 #2
0
    def intersect_online_process(self, data_inst, caches):
        """Intersect ``data_inst`` against a previously generated intersection cache.

        Uses the first value of ``caches`` as ``(cache_data, cache_meta)``: registers and
        loads its ``intersect_meta``, loads the cached keys, exchanges cache ids with the
        other party (raising if they differ), runs ``run_cache_intersect`` and
        post-processes the ids (match-id expansion, value restoration or key-only output,
        optional left join).

        Args:
            data_inst: input table of this party.
            caches: mapping whose first value is a ``(cache_data, cache_meta)`` pair;
                ``cache_meta`` is keyed by party id (as str) with dicts holding
                ``"cache_id"`` (its first value also supplies ``"intersect_meta"``).

        Returns:
            The intersection result table.

        Raises:
            ValueError: on invalid match-id settings, cache id mismatch or unsupported role.
        """
        cache_data, cache_meta = list(caches.values())[0]
        intersect_meta = list(cache_meta.values())[0]["intersect_meta"]
        self.callback_cache_meta(intersect_meta)
        self.load_intersect_meta(intersect_meta)
        self.init_intersect_method()
        self.intersection_obj.load_intersect_key(cache_meta)

        if data_overview.check_with_inst_id(data_inst):
            self.use_match_id_process = True
            LOGGER.info("use match_id_process")
        intersect_data = data_inst
        if self.use_match_id_process:
            if len(self.host_party_id_list) > 1 and self.model_param.sample_id_generator != consts.GUEST:
                raise ValueError("While multi-host, sample_id_generator should be guest.")
            if self.model_param.intersect_method == consts.RAW:
                if self.model_param.sample_id_generator != self.intersection_obj.join_role:
                    # adjacent literals are concatenated verbatim, so the separator space is explicit
                    raise ValueError("When using raw intersect with match id process, "
                                     "'join_role' should be same role as 'sample_id_generator'")
            elif not self.model_param.sync_intersect_ids and self.model_param.sample_id_generator != consts.GUEST:
                self.model_param.sample_id_generator = consts.GUEST
                LOGGER.warning("when not sync_intersect_ids with match id process, "
                               "sample_id_generator is set to Guest")

            proc_obj = MatchIDIntersect(sample_id_generator=self.model_param.sample_id_generator, role=self.role)
            proc_obj.new_sample_id = self.model_param.new_sample_id
            if data_overview.check_with_inst_id(data_inst) or self.model_param.with_sample_id:
                proc_obj.use_sample_id()
            match_data = proc_obj.recover(data=data_inst)
            intersect_data = match_data

        # both sides exchange cache ids and abort unless they refer to the same cache
        if self.role == consts.HOST:
            cache_id = cache_meta[str(self.guest_party_id)].get("cache_id")
            self.transfer_variable.cache_id.remote(cache_id, role=consts.GUEST, idx=0)
            guest_cache_id = self.transfer_variable.cache_id.get(role=consts.GUEST, idx=0)
            if guest_cache_id != cache_id:
                raise ValueError("cache_id check failed. cache_id from host & guest must match.")
        elif self.role == consts.GUEST:
            for i, party_id in enumerate(self.host_party_id_list):
                cache_id = cache_meta[str(party_id)].get("cache_id")
                self.transfer_variable.cache_id.remote(cache_id,
                                                       role=consts.HOST,
                                                       idx=i)
                host_cache_id = self.transfer_variable.cache_id.get(role=consts.HOST, idx=i)
                if host_cache_id != cache_id:
                    raise ValueError("cache_id check failed. cache_id from host & guest must match.")
        else:
            raise ValueError(f"Role {self.role} cannot run intersection transform.")

        self.intersect_ids = self.intersection_obj.run_cache_intersect(intersect_data, cache_data)
        if self.use_match_id_process:
            if not self.model_param.sync_intersect_ids:
                self.intersect_ids = proc_obj.expand(self.intersect_ids,
                                                     match_data=match_data,
                                                     owner_only=True)
            else:
                self.intersect_ids = proc_obj.expand(self.intersect_ids, match_data=match_data)
            if self.intersect_ids and self.model_param.only_output_key:
                self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
                self.intersect_ids.schema = {"match_id_name": data_inst.schema["match_id_name"],
                                             "sid_name": data_inst.schema["sid_name"]}

        LOGGER.info("Finish intersection")

        if self.intersect_ids:
            data_count = data_inst.count()
            self.intersect_num = self.intersect_ids.count()
            self.unmatched_num = data_count - self.intersect_num
            if data_count:
                self.intersect_rate = self.intersect_num / data_count
                self.unmatched_rate = 1 - self.intersect_rate
            else:
                # empty input: rates are undefined; report 0 instead of raising ZeroDivisionError
                self.intersect_rate = 0
                self.unmatched_rate = 0

        self.set_summary(self.get_model_summary())
        self.callback()

        result_data = self.intersect_ids
        if not self.use_match_id_process:
            if not self.intersection_obj.only_output_key and result_data:
                result_data = self.intersection_obj.get_value_from_data(result_data, data_inst)
                self.intersect_ids.schema = result_data.schema
                LOGGER.debug("not only_output_key, restore value called")
            if self.intersection_obj.only_output_key and result_data:
                schema = {"sid_name": data_inst.schema["sid_name"]}
                result_data = result_data.mapValues(lambda v: 1)
                result_data.schema = schema
                self.intersect_ids.schema = schema

        if self.model_param.join_method == consts.LEFT_JOIN:
            result_data = self.__sync_join_id(data_inst, self.intersect_ids)
            result_data.schema = self.intersect_ids.schema

        return result_data
コード例 #3
0
 def _recover_match_id(self, data_instance):
     """Run ``MatchIDIntersect.recover`` on ``data_instance`` with guest as sample-id generator.

     The helper is created with the role of ``self.intersection_obj``, configured to use
     sample ids, stored on ``self.proc_obj`` and its ``recover`` result is returned.
     """
     # NOTE(review): the other intersection code configures MatchIDIntersect through
     # ``new_sample_id``; confirm that ``new_join_id`` is the attribute this version of
     # MatchIDIntersect reads, otherwise the assignment below has no effect.
     proc = MatchIDIntersect(sample_id_generator=consts.GUEST, role=self.intersection_obj.role)
     self.proc_obj = proc
     proc.new_join_id = False
     proc.use_sample_id()
     return proc.recover(data=data_instance)
コード例 #4
0
    def fit(self, data):
        """Run the intersection on ``data``.

        If cache results are present in ``component_properties.caches``, delegates to
        ``intersect_online_process``. Otherwise performs one of:

        * cache generation (``run_cache``): returns ``data`` unchanged;
        * cardinality-only intersection: records stats and returns ``data``;
        * full intersection: returns the intersected ids, with values restored when
          keys-only output is off (non match-id data), left-joined when requested.

        Raises:
            ValueError: on invalid match-id settings.
        """
        if self.component_properties.caches:
            return self.intersect_online_process(data, self.component_properties.caches)
        self.init_intersect_method()
        if data_overview.check_with_inst_id(data):
            self.use_match_id_process = True
            LOGGER.info("use match_id_process")

        match_data = None
        if self.use_match_id_process:
            if len(self.host_party_id_list) > 1 and self.model_param.sample_id_generator != consts.GUEST:
                raise ValueError("While multi-host, sample_id_generator should be guest.")
            if self.model_param.intersect_method == consts.RAW:
                if self.model_param.sample_id_generator != self.intersection_obj.join_role:
                    # adjacent literals are concatenated verbatim, so the separator space is explicit
                    raise ValueError("When using raw intersect with match id process, "
                                     "'join_role' should be same role as 'sample_id_generator'")
            elif not self.model_param.sync_intersect_ids and self.model_param.sample_id_generator != consts.GUEST:
                self.model_param.sample_id_generator = consts.GUEST
                LOGGER.warning("when not sync_intersect_ids with match id process, "
                               "sample_id_generator is set to Guest")

            self.proc_obj = MatchIDIntersect(sample_id_generator=self.model_param.sample_id_generator, role=self.role)
            self.proc_obj.new_sample_id = self.model_param.new_sample_id
            if data_overview.check_with_inst_id(data) or self.model_param.with_sample_id:
                self.proc_obj.use_sample_id()
            match_data = self.proc_obj.recover(data=data)
        # with match-id data the protocol runs on the recovered table
        input_data = match_data if self.use_match_id_process else data

        if self.intersection_obj.run_cache:
            self.cache_output = self.intersection_obj.generate_cache(input_data)
            intersect_meta = self.intersection_obj.get_intersect_method_meta()
            self.callback_cache_meta(intersect_meta)
            return data
        if self.intersection_obj.cardinality_only:
            self.intersection_obj.run_cardinality(input_data)
        else:
            intersect_data = input_data
            if self.model_param.run_preprocess:
                intersect_data = self.run_preprocess(input_data)
            self.intersect_ids = self.intersection_obj.run_intersect(intersect_data)

        if self.intersection_obj.cardinality_only:
            if self.intersection_obj.intersect_num is not None:
                data_count = data.count()
                self.intersect_num = self.intersection_obj.intersect_num
                self.unmatched_num = data_count - self.intersect_num
                if data_count:
                    self.intersect_rate = self.intersect_num / data_count
                    self.unmatched_rate = 1 - self.intersect_rate
                else:
                    # empty input: rates are undefined; report 0 instead of raising ZeroDivisionError
                    self.intersect_rate = 0
                    self.unmatched_rate = 0
            self.set_summary(self.get_model_summary())
            self.callback()
            return data

        if self.use_match_id_process:
            if self.model_param.sync_intersect_ids:
                self.intersect_ids = self.proc_obj.expand(self.intersect_ids, match_data=match_data)
            else:
                self.intersect_ids = self.proc_obj.expand(self.intersect_ids,
                                                          match_data=match_data,
                                                          owner_only=True)
            if self.model_param.only_output_key and self.intersect_ids:
                self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
                self.intersect_ids.schema = {"match_id_name": data.schema["match_id_name"],
                                             "sid_name": data.schema["sid_name"]}

        LOGGER.info("Finish intersection")

        if self.intersect_ids:
            data_count = data.count()
            self.intersect_num = self.intersect_ids.count()
            self.unmatched_num = data_count - self.intersect_num
            if data_count:
                self.intersect_rate = self.intersect_num / data_count
                self.unmatched_rate = 1 - self.intersect_rate
            else:
                self.intersect_rate = 0
                self.unmatched_rate = 0

        self.set_summary(self.get_model_summary())
        self.callback()

        result_data = self.intersect_ids
        if not self.use_match_id_process and not self.intersection_obj.only_output_key and result_data:
            result_data = self.intersection_obj.get_value_from_data(result_data, data)
            LOGGER.debug("not only_output_key, restore value called")

        if self.model_param.join_method == consts.LEFT_JOIN:
            result_data = self.__sync_join_id(data, self.intersect_ids)
            result_data.schema = self.intersect_ids.schema

        return result_data
コード例 #5
0
class BaseSecureInformationRetrieval(ModelBase):
    """

    """

    def __init__(self):
        """Initialise the base SIR component with unconfigured defaults.

        ``_init_base_model`` later overwrites ``model_param``,
        ``security_level``, ``dh_params`` and ``target_cols`` from the
        supplied parameter object; the remaining attributes start as
        ``None``/``False`` here.
        """
        super().__init__()
        self.model_param = SecureInformationRetrievalParam()

        # protocol configuration / state
        self.security_level = None
        self.commutative_cipher = None
        self.transfer_variable = None
        self.block_num = None  # N in 1-N oblivious transfer
        self.coverage = None  # share of transactions whose values were retrieved

        # key-exchange params and match-id / intersection helpers
        self.dh_params = None
        self.intersection_obj = None
        self.proc_obj = None
        self.with_inst_id = None
        self.need_label = False
        self.target_cols = None

        # metric identifiers (metric_name / metric_namespace are reported by _display_result)
        self.metric_name = "sir"
        self.metric_namespace = "train"
        self.metric_type = "SIR"

    def _init_base_model(self, param: SecureInformationRetrievalParam):
        """Configure this component from ``param``.

        Creates a fresh SIR transfer variable (with auto-clean disabled on
        selected variables via ``_init_transfer_variable``), stores ``param``
        as ``self.model_param`` and copies the fields this base class reads
        directly.

        :param param: the component's SecureInformationRetrievalParam
        """
        self.transfer_variable = SecureInformationRetrievalTransferVariable()
        self._init_transfer_variable()

        self.model_param = param
        # cache frequently used settings on the instance
        self.security_level = param.security_level
        self.dh_params = param.dh_params
        self.target_cols = param.target_cols

    def _init_transfer_variable(self):
        """Disable auto-clean on the natural-indexation and ciphertext-block variables.

        NOTE(review): presumably because these variables are exchanged more
        than once per run — confirm against the guest/host subclasses.
        """
        for variable_name in ("natural_indexation", "id_blocks_ciphertext"):
            getattr(self.transfer_variable, variable_name).disable_auto_clean()

    @staticmethod
    def _abnormal_detection(data_instances):
        """Validate ``data_instances`` before it is used.

        Runs the ``abnormal_detection`` empty-table check followed by the
        empty-feature check; what each check does on failure is defined in
        that module (presumably it raises).

        :param data_instances: input data to validate
        """
        for check in (abnormal_detection.empty_table_detection,
                      abnormal_detection.empty_feature_detection):
            check(data_instances)

    # Disabled code, kept for reference. It used to live in a bare string
    # literal, which the class body evaluated and discarded; comments make
    # its inert status explicit.
    #
    # @staticmethod
    # def record_original_id(k, v):
    #     if isinstance(k, str):
    #         restored_id = conversion.int_to_str(conversion.str_to_int(k))
    #     else:
    #         restored_id = k
    #     return (restored_id, k)

    def _check_need_label(self):
        """Return True when no target columns were configured.

        NOTE(review): the name suggests the label is retrieved instead when
        no target columns are given — confirm with the callers.

        :return: bool
        """
        # ``not`` treats both an empty list and ``None`` (the value set in
        # ``__init__``) as "no target columns"; ``len(None)`` would raise
        # TypeError.
        return not self.target_cols

    def _recover_match_id(self, data_instance):
        """Recover match ids from ``data_instance`` with a new MatchIDIntersect.

        The processor uses GUEST as sample-id generator and the role of
        ``self.intersection_obj``, has ``new_join_id`` switched off and
        sample ids enabled. It is stored on ``self.proc_obj`` so that
        ``_restore_sample_id`` can reuse it.

        :param data_instance: data whose match ids should be recovered
        :return: result of ``MatchIDIntersect.recover``
        """
        self.proc_obj = processor = MatchIDIntersect(sample_id_generator=consts.GUEST,
                                                     role=self.intersection_obj.role)
        processor.new_join_id = False
        processor.use_sample_id()
        return processor.recover(data=data_instance)

    def _restore_sample_id(self, data_instances):
        """Expand ``data_instances`` through ``self.proc_obj`` in owner-only mode.

        ``self.proc_obj`` must already be set (e.g. by ``_recover_match_id``).

        :param data_instances: data to expand
        :return: result of ``proc_obj.expand(..., owner_only=True)``
        """
        return self.proc_obj.expand(data_instances, owner_only=True)

    def _raw_information_retrieval(self, data_instance):
        """Hook for raw (non-oblivious) retrieval, meant for security_level == 0.

        No-op in this base class; presumably overridden by role subclasses.

        :param data_instance: input data
        :return: None
        """
        return None

    def _parse_security_level(self, data_instance):
        """Hook for the parties to jointly derive the security-level index.

        No-op in this base class; presumably overridden by role subclasses.

        :param data_instance: input data
        :return: None
        """
        return None

    def _sync_natural_index(self, id_list_arr):
        """Hook: send the natural index, direction guest -> host.

        No-op in this base class.

        :param id_list_arr: id lists to synchronise
        :return: None
        """
        pass

    def _sync_natural_indexation(self, id_list, time):
        """Hook: send a natural indexation, direction guest -> host.

        No-op in this base class.

        :param id_list: ids to synchronise
        :param time: round / suffix identifier for this transfer
        :return: None
        """
        pass

    def _sync_block_num(self):
        """Hook: share the block number, direction guest -> host.

        No-op in this base class.

        :return: None
        """
        pass

    def _transmit_value_ciphertext(self, id_block, time):
        """Hook: send value ciphertexts, direction host -> guest.

        No-op in this base class.

        :param id_block: block of ids whose values are transmitted
        :param time: int, round / suffix identifier for this transfer
        :return: None
        """
        pass

    def _check_oblivious_transfer_condition(self):
        """Tell whether 1-N oblivious transfer can run with the current block number.

        Only N >= 2 is supported (1-2 OT is the smallest case).

        :return: bool, ``self.block_num >= 2``
        """
        minimum_blocks = 2
        return self.block_num >= minimum_blocks

    def _failure_response(self):
        """Abort because not even 1-2 oblivious transfer can be performed.

        :raises ValueError: always, suggesting raw retrieval instead.
        """
        message = "Cannot perform even 1-2 OT, recommend use raw retrieval"
        raise ValueError(message)

    def _sync_coverage(self, data_instance):
        """Hook: share the retrieval coverage, direction guest -> host.

        No-op in this base class; presumably overridden by role subclasses.

        :param data_instance: input data
        :return: None
        """
        return None

    def _sync_nonce_list(self, nonce, time):
        """Hook: send the nonce list, direction host -> guest.

        No-op in this base class; presumably overridden by role subclasses.

        :param nonce: nonce list to send
        :param time: round / suffix identifier for this transfer
        :return: None
        """
        return None

    def export_model(self):
        """Return the exported model, building and caching it on the first call.

        :return: dict mapping ``MODEL_META_NAME`` to the meta protobuf from
            ``_get_meta`` and ``MODEL_PARAM_NAME`` to the param protobuf from
            ``_get_param``. The same dict object is returned on later calls
            via ``self.model_output``.
        """
        if self.model_output is None:
            # dict literal evaluates _get_meta() before _get_param(), as before
            self.model_output = {
                MODEL_META_NAME: self._get_meta(),
                MODEL_PARAM_NAME: self._get_param(),
            }
        return self.model_output

    def _get_meta(self):
        """Build the SecureInformationRetrievalMeta protobuf.

        ``security_level`` comes from the instance; every other field is read
        from ``self.model_param``.
        """
        param = self.model_param
        return sir_meta_pb2.SecureInformationRetrievalMeta(
            security_level=self.security_level,
            oblivious_transfer_protocol=param.oblivious_transfer_protocol,
            commutative_encryption=param.commutative_encryption,
            non_committing_encryption=param.non_committing_encryption,
            key_size=param.key_size,
            raw_retrieval=param.raw_retrieval,
        )

    def _get_param(self):
        """Build the SecureInformationRetrievalParam protobuf from coverage and block number."""
        fields = dict(coverage=self.coverage, block_num=self.block_num)
        return sir_param_pb2.SecureInformationRetrievalParam(**fields)

    def _display_result(self, block_num=None):
        """Report coverage and block number as a metric, then set the metric meta.

        :param block_num: block number to report; when ``None``,
            ``self.block_num`` is reported instead.
        """
        # The original had two branches that differed only in this value.
        reported_block_num = self.block_num if block_num is None else block_num
        self.callback_metric(metric_name=self.metric_name,
                             metric_namespace=self.metric_namespace,
                             metric_data=[Metric("Coverage", self.coverage),
                                          Metric("Block number", reported_block_num)])
        # NOTE(review): metric_type is "INTERSECTION" here while
        # self.metric_type is "SIR"; kept unchanged — confirm which is intended.
        self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
                                     metric_name=self.metric_name,
                                     metric_meta=MetricMeta(self.metric_name, metric_type="INTERSECTION"))

    """