示例#1
0
class TestRsaIntersectHost(unittest.TestCase):
    """Unit tests for the host-side RSA and raw intersection operators."""

    def setUp(self):
        """Start a per-test session and build freshly configured host operators.

        ``session.stop`` is registered via ``addCleanup`` immediately after
        ``session.init``. unittest does not call ``tearDown`` when ``setUp``
        raises, but registered cleanups still run. The session is therefore
        released even if a later step here fails.
        """
        self.jobid = str(uuid.uuid1())
        session.init(self.jobid)
        self.addCleanup(session.stop)

        from federatedml.statistic.intersect import RsaIntersectionHost
        from federatedml.statistic.intersect import RawIntersectionHost
        intersect_param = IntersectParam()
        self.rsa_operator = RsaIntersectionHost()
        self.rsa_operator.load_params(intersect_param)
        self.raw_operator = RawIntersectionHost()
        self.raw_operator.load_params(intersect_param)

    def data_to_table(self, data):
        """Parallelize a list of ``(key, value)`` pairs into a 2-partition table."""
        return session.parallelize(data, include_key=True, partition=2)

    def test_func_generate_rsa_key(self):
        res = self.rsa_operator.generate_rsa_key(1024)
        # Expects the first element of the result to be 65537; presumably the
        # RSA public exponent e -- TODO confirm the returned tuple layout.
        self.assertEqual(65537, res[0])

    def test_get_common_intersection(self):
        # Key 4 is the only key present in all three tables.
        d1 = [(1, "a"), (2, "b"), (4, "c")]
        d2 = [(4, "a"), (5, "b"), (6, "c")]
        d3 = [(4, "a"), (5, "b"), (7, "c")]
        D1 = self.data_to_table(d1)
        D2 = self.data_to_table(d2)
        D3 = self.data_to_table(d3)

        res = self.raw_operator.get_common_intersection([D1, D2, D3])
        # The expected value "id" is what this test assumes
        # get_common_intersection emits for each common key.
        gt = [(4, "id")]
        self.assertListEqual(list(res.collect()), gt)
示例#2
0
    def setUp(self):
        """Initialise a uniquely named session and two configured host operators.

        ``self.rsa_operator`` and ``self.raw_operator`` are both loaded with
        the same single ``IntersectParam()`` instance.
        """
        job_id = str(uuid.uuid1())
        self.jobid = job_id
        session.init(job_id)

        # Operators are imported locally, after session.init has run.
        from federatedml.statistic.intersect import (
            RsaIntersectionHost,
            RawIntersectionHost,
        )

        shared_param = IntersectParam()

        rsa_host = self.rsa_operator = RsaIntersectionHost()
        rsa_host.load_params(shared_param)

        raw_host = self.raw_operator = RawIntersectionHost()
        raw_host.load_params(shared_param)
示例#3
0
    def __init_intersect_method(self):
        """Create the intersection operator for the configured method and role.

        Reads ``self.model_param.intersect_method`` ("rsa" or "raw") and
        ``self.role`` (host or guest). Instantiates the matching operator class
        with ``self.model_param`` into ``self.intersection_obj``, then copies
        the party ids read from ``self.component_properties`` onto it.

        Raises:
            ValueError: if the intersect method is not "rsa"/"raw", or the
                role is neither host nor guest.
        """
        LOGGER.info("Using {} intersection, role is {}".format(self.model_param.intersect_method, self.role))
        self.host_party_id_list = self.component_properties.host_party_idlist
        self.guest_party_id = self.component_properties.guest_partyid
        if self.role == consts.HOST:
            # local_partyid is recorded as host_party_id only when acting as host.
            self.host_party_id = self.component_properties.local_partyid

        if self.model_param.intersect_method == "rsa":
            if self.role == consts.HOST:
                self.intersection_obj = RsaIntersectionHost(self.model_param)
            elif self.role == consts.GUEST:
                self.intersection_obj = RsaIntersectionGuest(self.model_param)
            else:
                raise ValueError("role {} is not support".format(self.role))

        elif self.model_param.intersect_method == "raw":
            if self.role == consts.HOST:
                self.intersection_obj = RawIntersectionHost(self.model_param)
            elif self.role == consts.GUEST:
                self.intersection_obj = RawIntersectionGuest(self.model_param)
            else:
                raise ValueError("role {} is not support".format(self.role))
            # Only the raw operator receives the task id.
            self.intersection_obj.task_id = self.taskid
        else:
            raise ValueError("intersect_method {} is not support yet".format(self.model_param.intersect_method))

        # Party-id wiring shared by both methods; done once here for every branch.
        if self.role == consts.HOST:
            self.intersection_obj.host_party_id = self.host_party_id
        self.intersection_obj.guest_party_id = self.guest_party_id
        self.intersection_obj.host_party_id_list = self.host_party_id_list
示例#4
0
    def init_intersect_method(self):
        """Set up the host-side intersection operator.

        Runs the parent initialisation and records
        ``component_properties.local_partyid`` as ``self.host_party_id``. It
        then builds an RSA or raw host operator according to
        ``model_param.intersect_method``, copies the party ids onto it and
        finally loads ``model_param`` into it.

        Raises:
            ValueError: for any intersect method other than "rsa" or "raw".
        """
        super().init_intersect_method()
        self.host_party_id = self.component_properties.local_partyid

        method = self.model_param.intersect_method
        if method not in ("rsa", "raw"):
            raise ValueError("intersect_method {} is not support yet".format(
                method))

        if method == "rsa":
            self.intersection_obj = RsaIntersectionHost()
        else:
            self.intersection_obj = RawIntersectionHost()
            # Only the raw operator is given the tracker and task version id.
            self.intersection_obj.tracker = self.tracker
            self.intersection_obj.task_version_id = self.task_version_id

        operator = self.intersection_obj
        operator.host_party_id = self.host_party_id
        operator.guest_party_id = self.guest_party_id
        operator.host_party_id_list = self.host_party_id_list
        operator.load_params(self.model_param)
示例#5
0
    def intersect(self, data_instance, intersect_flowid=''):
        """Restrict ``data_instance`` to the ids produced by a raw intersection.

        ``data_instance`` is returned unchanged in three cases: when it is
        None, when ``workflow_param.need_intersect`` is falsy, or when the role
        is arbiter (the intersect params are still parsed from
        ``self.config_path`` first). Otherwise a raw intersection operator for
        this party's role is run on the data. The resulting ids are joined
        back onto ``data_instance``, keeping the original values. The original
        schema header is then copied onto the result.

        Raises:
            ValueError: if ``self.role`` is not host, guest or arbiter.
        """
        if data_instance is None:
            return data_instance

        if not self.workflow_param.need_intersect:
            LOGGER.info("need_intersect: false!")
            return data_instance

        header = data_instance.schema.get('header')
        LOGGER.info("need_intersect: true!")
        self.intersect_params = ParamExtract.parse_param_from_config(
            IntersectParam(), self.config_path)

        LOGGER.info("Start intersection!")
        if self.role == consts.HOST:
            operator_cls = RawIntersectionHost
        elif self.role == consts.GUEST:
            operator_cls = RawIntersectionGuest
        elif self.role == consts.ARBITER:
            return data_instance
        else:
            raise ValueError("Unknown role of workflow")

        operator = operator_cls(self.intersect_params)
        operator.set_flowid(intersect_flowid)
        shared_ids = operator.run(data_instance)
        LOGGER.info("finish intersection!")

        # Keep the value from data_instance for every id in the intersection.
        result = shared_ids.join(data_instance, lambda _id_val, data_val: data_val)
        LOGGER.info("get intersect data_instance!")
        result.schema['header'] = header
        return result
示例#6
0
    def intersect(self, data_instance, intersect_flowid=''):
        """Run a raw intersection on ``data_instance`` for this party's role.

        ``data_instance`` is returned unchanged in three cases: when it is
        None, when ``workflow_param.need_intersect`` is falsy, or when the role
        is arbiter (the intersect params are still loaded first). Otherwise
        the method returns whatever the raw intersection operator's ``run``
        produces. That is presumably the intersected table; verify against
        the operator.

        Raises:
            ValueError: if ``self.role`` is not host, guest or arbiter.
        """
        if data_instance is None:
            return data_instance

        if not self.workflow_param.need_intersect:
            LOGGER.info("need_intersect: false!")
            return data_instance

        # Read before intersecting, as in the original ordering (value unused here).
        header = data_instance.schema.get('header')
        LOGGER.info("need_intersect: true!")
        self.intersect_params = self._load_param(IntersectParam())

        LOGGER.info("Start intersection!")
        if self.role == consts.HOST:
            operator_cls = RawIntersectionHost
        elif self.role == consts.GUEST:
            operator_cls = RawIntersectionGuest
        elif self.role == consts.ARBITER:
            return data_instance
        else:
            raise ValueError("Unknown role of workflow")

        operator = operator_cls(self.intersect_params)
        operator.set_flowid(intersect_flowid)
        shared_ids = operator.run(data_instance)
        LOGGER.info("finish intersection!")
        return shared_ids
示例#7
0
 def __init_intersect_method(self):
     """Instantiate ``self.intersection_obj`` for the configured method and role.

     ``self.model_param.intersect_method`` selects the RSA or raw family, and
     ``self.role`` selects the host or guest class within it. The chosen class
     is constructed with ``self.model_param``.

     Raises:
         ValueError: if the intersect method is neither "rsa" nor "raw"
             (checked first), or if the role is neither host nor guest.
     """
     method = self.model_param.intersect_method
     LOGGER.info("Using {} intersection, role is {}".format(
         method, self.role))

     if method not in ("rsa", "raw"):
         raise ValueError("intersect_method {} is not support yet".format(
             method))
     if self.role not in (consts.HOST, consts.GUEST):
         raise ValueError("role {} is not support".format(self.role))

     as_host = self.role == consts.HOST
     # Conditional expressions only look up the class actually chosen.
     if method == "rsa":
         operator_cls = RsaIntersectionHost if as_host else RsaIntersectionGuest
     else:
         operator_cls = RawIntersectionHost if as_host else RawIntersectionGuest
     self.intersection_obj = operator_cls(self.model_param)