Esempio n. 1
0
class TestRsaIntersectHost(unittest.TestCase):
    """Exercise the host-side RSA and raw intersection operators."""

    def setUp(self):
        """Open a session under a fresh job id and build both operators.

        Both operators are loaded with the same default-constructed
        ``IntersectParam``.
        """
        job_id = str(uuid.uuid1())
        self.jobid = job_id
        session.init(job_id)

        from federatedml.statistic.intersect import (
            RsaIntersectionHost,
            RawIntersectionHost,
        )

        default_param = IntersectParam()

        rsa_host = RsaIntersectionHost()
        self.rsa_operator = rsa_host
        rsa_host.load_params(default_param)

        raw_host = RawIntersectionHost()
        self.raw_operator = raw_host
        raw_host.load_params(default_param)

    def tearDown(self):
        """Stop the session opened in ``setUp``."""
        session.stop()

    def data_to_table(self, data):
        """Turn a list of ``(key, value)`` pairs into a two-partition table."""
        return session.parallelize(data, include_key=True, partition=2)

    def test_func_generate_rsa_key(self):
        """The first element of the generated key must equal 65537."""
        # NOTE(review): 65537 is presumably the RSA public exponent -- confirm
        # against generate_rsa_key's return layout.
        first_component = self.rsa_operator.generate_rsa_key(1024)[0]
        self.assertEqual(65537, first_component)

    def test_get_common_intersection(self):
        """Three tables sharing only key 4 must intersect to ``[(4, "id")]``."""
        datasets = (
            [(1, "a"), (2, "b"), (4, "c")],
            [(4, "a"), (5, "b"), (6, "c")],
            [(4, "a"), (5, "b"), (7, "c")],
        )
        tables = [self.data_to_table(rows) for rows in datasets]

        common = self.raw_operator.get_common_intersection(tables)
        expected = [(4, "id")]
        self.assertListEqual(list(common.collect()), expected)
Esempio n. 2
0
    def setUp(self):
        """Start a session under a fresh job id and prepare both operators.

        An RSA and a raw host-side intersection operator are created and each
        is loaded with the same default-constructed ``IntersectParam``.
        """
        self.jobid = job_id = str(uuid.uuid1())
        session.init(job_id)

        from federatedml.statistic.intersect import (
            RsaIntersectionHost,
            RawIntersectionHost,
        )

        shared_param = IntersectParam()

        self.rsa_operator = rsa_host = RsaIntersectionHost()
        rsa_host.load_params(shared_param)

        self.raw_operator = raw_host = RawIntersectionHost()
        raw_host.load_params(shared_param)
Esempio n. 3
0
    def __init_intersect_method(self):
        """Create ``self.intersection_obj`` for the configured method and role.

        Reads the party ids from ``self.component_properties``, instantiates
        the RSA or raw intersection operator matching ``self.role`` (passing
        ``self.model_param`` to its constructor) and copies the party ids onto
        the operator. For the raw method the operator also receives
        ``self.taskid`` as ``task_id``.

        Raises:
            ValueError: if ``self.model_param.intersect_method`` is neither
                ``"rsa"`` nor ``"raw"``, or if ``self.role`` is neither
                ``consts.HOST`` nor ``consts.GUEST``.
        """
        method = self.model_param.intersect_method
        LOGGER.info("Using {} intersection, role is {}".format(method, self.role))

        self.host_party_id_list = self.component_properties.host_party_idlist
        self.guest_party_id = self.component_properties.guest_partyid
        is_host = self.role == consts.HOST
        if is_host:
            # Only a host has a local party id that is a host id.
            self.host_party_id = self.component_properties.local_partyid

        if method == "rsa":
            if is_host:
                self.intersection_obj = RsaIntersectionHost(self.model_param)
            elif self.role == consts.GUEST:
                self.intersection_obj = RsaIntersectionGuest(self.model_param)
            else:
                raise ValueError("role {} is not support".format(self.role))
        elif method == "raw":
            if is_host:
                self.intersection_obj = RawIntersectionHost(self.model_param)
            elif self.role == consts.GUEST:
                self.intersection_obj = RawIntersectionGuest(self.model_param)
            else:
                raise ValueError("role {} is not support".format(self.role))
            self.intersection_obj.task_id = self.taskid
        else:
            raise ValueError("intersect_method {} is not support yet".format(method))

        # Party ids are wired onto the operator exactly once, here, for every
        # method; the rsa branch previously also assigned them a second time.
        if is_host:
            self.intersection_obj.host_party_id = self.host_party_id
        self.intersection_obj.guest_party_id = self.guest_party_id
        self.intersection_obj.host_party_id_list = self.host_party_id_list
Esempio n. 4
0
 def init_intersect_obj(self):
     """Build and return a configured ``RsaIntersectionHost``.

     The operator receives the local party id and the host party id list
     from ``self.component_properties`` and is then loaded with
     ``self.intersect_param``.

     Returns:
         The configured ``RsaIntersectionHost`` instance.
     """
     intersect_obj = RsaIntersectionHost()
     intersect_obj.host_party_id = self.component_properties.local_partyid
     intersect_obj.host_party_id_list = self.component_properties.host_party_idlist
     intersect_obj.load_params(self.intersect_param)
     # Emitted only after setup has succeeded, so the "done" message is not
     # logged when construction or parameter loading fails.
     LOGGER.debug('creating intersect obj done')
     return intersect_obj
Esempio n. 5
0
    def init_intersect_method(self):
        """Create and configure the host-side intersection operator.

        Runs the parent initialisation first, records the local party id as
        ``self.host_party_id``, then builds the operator selected by
        ``self.model_param.intersect_method``. The raw operator additionally
        receives ``self.tracker`` and ``self.task_version_id``. Finally the
        party ids are copied onto the operator and ``self.model_param`` is
        loaded into it.

        Raises:
            ValueError: if the configured method is neither ``"rsa"`` nor
                ``"raw"``.
        """
        super().init_intersect_method()
        self.host_party_id = self.component_properties.local_partyid

        method = self.model_param.intersect_method
        if method not in ("rsa", "raw"):
            raise ValueError(
                "intersect_method {} is not support yet".format(method))

        if method == "rsa":
            self.intersection_obj = operator = RsaIntersectionHost()
        else:
            self.intersection_obj = operator = RawIntersectionHost()
            operator.tracker = self.tracker
            operator.task_version_id = self.task_version_id

        operator.host_party_id = self.host_party_id
        operator.guest_party_id = self.guest_party_id
        operator.host_party_id_list = self.host_party_id_list
        operator.load_params(self.model_param)
Esempio n. 6
0
 def __init_intersect_method(self):
     """Instantiate ``self.intersection_obj`` for the configured method and role.

     The operator class is chosen from ``self.model_param.intersect_method``
     (``"rsa"`` or ``"raw"``) and ``self.role`` (``consts.HOST`` or
     ``consts.GUEST``) and is constructed with ``self.model_param``.

     Raises:
         ValueError: if the method is not ``"rsa"``/``"raw"``, or if the
             role is neither host nor guest.
     """
     method = self.model_param.intersect_method
     role = self.role
     LOGGER.info("Using {} intersection, role is {}".format(method, role))

     # The method is validated before the role, so an unsupported method is
     # reported even when the role is unsupported as well.
     if method not in ("rsa", "raw"):
         raise ValueError(
             "intersect_method {} is not support yet".format(method))

     if not (role == consts.HOST or role == consts.GUEST):
         raise ValueError("role {} is not support".format(role))

     acts_as_host = role == consts.HOST
     if method == "rsa":
         operator_cls = (
             RsaIntersectionHost if acts_as_host else RsaIntersectionGuest)
     else:
         operator_cls = (
             RawIntersectionHost if acts_as_host else RawIntersectionGuest)

     self.intersection_obj = operator_cls(self.model_param)