Esempio n. 1
0
    def setUp(self):
        """Open a fresh session and build the two RSA intersect operators.

        ``self.rsa_operator`` is a ``RsaIntersectionGuest`` and ``self.rsa_op2``
        a ``RsaIntersect``; both load the same default ``IntersectParam``.
        """
        self.jobid = str(uuid.uuid1())
        session.init(self.jobid)

        # Imported lazily, after the session has been initialised.
        from federatedml.statistic.intersect import RsaIntersectionGuest, RsaIntersect

        shared_param = IntersectParam()
        # Build, store, then configure each operator in turn.
        for attr_name, operator_cls in (("rsa_operator", RsaIntersectionGuest),
                                        ("rsa_op2", RsaIntersect)):
            operator = operator_cls()
            setattr(self, attr_name, operator)
            operator.load_params(shared_param)
Esempio n. 2
0
    def __init_intersect_method(self):
        """Create ``self.intersection_obj`` for the configured method and role.

        Party ids are read from ``self.component_properties``. The RSA or raw
        host/guest implementation is chosen from
        ``self.model_param.intersect_method`` and ``self.role``, and the party
        ids are then attached to the created object.

        Raises:
            ValueError: if ``intersect_method`` is neither ``"rsa"`` nor
                ``"raw"``, or if ``self.role`` is neither host nor guest.
        """
        LOGGER.info("Using {} intersection, role is {}".format(self.model_param.intersect_method, self.role))
        self.host_party_id_list = self.component_properties.host_party_idlist
        self.guest_party_id = self.component_properties.guest_partyid
        if self.role == consts.HOST:
            # Only a host knows its own party id.
            self.host_party_id = self.component_properties.local_partyid

        if self.model_param.intersect_method == "rsa":
            if self.role == consts.HOST:
                self.intersection_obj = RsaIntersectionHost(self.model_param)
            elif self.role == consts.GUEST:
                self.intersection_obj = RsaIntersectionGuest(self.model_param)
            else:
                raise ValueError("role {} is not support".format(self.role))

        elif self.model_param.intersect_method == "raw":
            if self.role == consts.HOST:
                self.intersection_obj = RawIntersectionHost(self.model_param)
            elif self.role == consts.GUEST:
                self.intersection_obj = RawIntersectionGuest(self.model_param)
            else:
                raise ValueError("role {} is not support".format(self.role))
            # The raw method additionally needs the task id.
            self.intersection_obj.task_id = self.taskid
        else:
            raise ValueError("intersect_method {} is not support yet".format(self.model_param.intersect_method))

        # Party ids are attached in one place for every method instead of
        # being repeated inside the individual branches.
        if self.role == consts.HOST:
            self.intersection_obj.host_party_id = self.host_party_id
        self.intersection_obj.guest_party_id = self.guest_party_id
        self.intersection_obj.host_party_id_list = self.host_party_id_list
Esempio n. 3
0
 def init_intersect_obj(self):
     """Build and configure an RSA guest-side intersection object.

     The local party id is used as the guest id, and the host party id list
     comes from ``self.component_properties``. Parameters are then loaded
     from ``self.intersect_param``.

     Returns:
         RsaIntersectionGuest: the configured object (no intersection has
         been run on it here).
     """
     guest_obj = RsaIntersectionGuest()
     props = self.component_properties
     guest_obj.guest_party_id = props.local_partyid
     guest_obj.host_party_id_list = props.host_party_idlist
     guest_obj.load_params(self.intersect_param)
     # NOTE(review): despite its wording, this message is emitted right after
     # construction/configuration, before any intersection is performed.
     LOGGER.debug('intersect done')
     return guest_obj
Esempio n. 4
0
class TestRsaIntersectGuest(unittest.TestCase):
    """Unit tests for the guest side of the RSA intersection."""

    def setUp(self):
        """Open a fresh session and build the RSA intersect operators under test."""
        self.jobid = str(uuid.uuid1())
        session.init(self.jobid)

        # Imported lazily, after the session has been initialised.
        from federatedml.statistic.intersect import RsaIntersectionGuest
        from federatedml.statistic.intersect import RsaIntersect
        intersect_param = IntersectParam()
        self.rsa_operator = RsaIntersectionGuest()
        self.rsa_operator.load_params(intersect_param)
        self.rsa_op2 = RsaIntersect()
        self.rsa_op2.load_params(intersect_param)

    def data_to_table(self, data):
        """Turn a list of ``(key, value)`` pairs into a 2-partition session table."""
        return session.parallelize(data, include_key=True, partition=2)

    def test_func_map_raw_id_to_encrypt_id(self):
        raw_id_table = self.data_to_table([("a", 1), ("b", 2), ("c", 3)])
        encrypt_id_table = self.data_to_table([(4, "a"), (5, "b"), (6, "c")])

        res = self.rsa_operator.map_raw_id_to_encrypt_id(raw_id_table, encrypt_id_table)

        gt = [(4, "id"), (5, "id"), (6, "id")]
        # The table is spread over two partitions, so the order in which
        # collect() yields rows is not guaranteed; compare order-independently.
        self.assertListEqual(sorted(res.collect()), gt)

    def test_hash(self):
        hash_operator = Hash("sha256")
        res = str(self.rsa_op2.hash("1", hash_operator))
        # Known SHA-256 hex digest of the string "1".
        self.assertEqual(
            res,
            "6b86b273ff34fce19d6b804eff5a3f5747ada4eaa22f1d49c01e52ddb7875b4b")

    def tearDown(self):
        session.stop()
Esempio n. 5
0
    def init_intersect_method(self):
        """Create and configure the guest-side intersection object.

        Builds ``RsaIntersectionGuest`` or ``RawIntersectionGuest`` according to
        ``self.model_param.intersect_method``. For the raw method it also
        attaches the tracker and task version id. It then sets the party ids
        and loads ``self.model_param`` into the object.

        Raises:
            ValueError: if ``intersect_method`` is neither ``"rsa"`` nor ``"raw"``.
        """
        # NOTE(review): presumably the base implementation populates
        # self.guest_party_id / self.host_party_id_list used below - confirm.
        super().init_intersect_method()

        if self.model_param.intersect_method == "rsa":
            self.intersection_obj = RsaIntersectionGuest()

        elif self.model_param.intersect_method == "raw":
            self.intersection_obj = RawIntersectionGuest()
            self.intersection_obj.tracker = self.tracker
            self.intersection_obj.task_version_id = self.task_version_id
        else:
            raise ValueError("intersect_method {} is not support yet".format(
                self.model_param.intersect_method))

        # Party ids are attached once here for both methods.
        self.intersection_obj.guest_party_id = self.guest_party_id
        self.intersection_obj.host_party_id_list = self.host_party_id_list
        self.intersection_obj.load_params(self.model_param)
Esempio n. 6
0
 def __init_intersect_method(self):
     """Instantiate ``self.intersection_obj`` from the configured method and role.

     The host/guest implementation pair is first chosen by
     ``self.model_param.intersect_method`` (``"rsa"`` or ``"raw"``). The member
     matching ``self.role`` is then constructed with ``self.model_param``.

     Raises:
         ValueError: for an unknown intersect method (checked first) or for a
             role that is neither host nor guest.
     """
     method = self.model_param.intersect_method
     LOGGER.info("Using {} intersection, role is {}".format(
         method, self.role))

     if method == "rsa":
         host_cls, guest_cls = RsaIntersectionHost, RsaIntersectionGuest
     elif method == "raw":
         host_cls, guest_cls = RawIntersectionHost, RawIntersectionGuest
     else:
         raise ValueError("intersect_method {} is not support yet".format(
             method))

     if self.role == consts.HOST:
         chosen_cls = host_cls
     elif self.role == consts.GUEST:
         chosen_cls = guest_cls
     else:
         raise ValueError("role {} is not support".format(self.role))

     self.intersection_obj = chosen_cls(self.model_param)