示例#1
0
class TestRsaIntersectGuest(unittest.TestCase):
    """Unit tests for the guest-side RSA intersection operators.

    Each test gets a fresh session, initialised with a random job id in
    ``setUp`` and stopped in ``tearDown``.
    """

    def setUp(self):
        """Start a session and build the two RSA operators under test."""
        self.jobid = str(uuid.uuid1())
        session.init(self.jobid)

        from federatedml.statistic.intersect.intersect_guest import RsaIntersectionGuest
        from federatedml.statistic.intersect.intersect import RsaIntersect

        intersect_param = IntersectParam()
        # NOTE(review): this class used to define ``setUp`` twice; the later
        # copy (constructors taking ``intersect_param``) silently replaced this
        # one. The no-arg constructor + ``load_params`` style is kept because it
        # matches the APIs the tests below call — confirm against the installed
        # federatedml version.
        self.rsa_operator = RsaIntersectionGuest()
        self.rsa_operator.load_params(intersect_param)
        self.rsa_op2 = RsaIntersect()
        self.rsa_op2.load_params(intersect_param)

    def data_to_table(self, data):
        """Parallelize ``(key, value)`` pairs into a session table with 2 partitions."""
        return session.parallelize(data, include_key=True, partition=2)

    def test_func_map_raw_id_to_encrypt_id(self):
        """Values of the second table become keys, each mapped to ``"id"``."""
        d1 = [("a", 1), ("b", 2), ("c", 3)]
        d2 = [(4, "a"), (5, "b"), (6, "c")]
        D1 = self.data_to_table(d1)
        D2 = self.data_to_table(d2)

        res = self.rsa_operator.map_raw_id_to_encrypt_id(D1, D2)

        gt = [(4, "id"), (5, "id"), (6, "id")]
        # The table is split over 2 partitions, so collect() order is not
        # guaranteed; compare order-independently.
        self.assertListEqual(sorted(res.collect()), sorted(gt))

    def test_hash(self):
        """``hash`` with a SHA-256 operator yields the hex digest of ``"1"``."""
        hash_operator = Hash("sha256")
        res = str(self.rsa_op2.hash("1", hash_operator))
        self.assertEqual(res, "6b86b273ff34fce19d6b804eff5a3f5747ada4eaa22f1d49c01e52ddb7875b4b")

    def tearDown(self):
        """Stop the session started in ``setUp``."""
        session.stop()
示例#3
0
    def intersect(self, data_instance):
        """Run the configured intersection on ``data_instance`` and save the result.

        The operator is chosen from ``self.intersect_param.intersect_method``:
        ``"rsa"`` builds ``RsaIntersectionGuest`` and ``"raw"`` builds
        ``RawIntersectionGuest``, each constructed with ``self.intersect_param``
        and stored on ``self.intersection``. The operator's ``run`` output is
        passed to ``self.save_intersect_result``.

        Raises:
            TypeError: if ``intersect_method`` is neither ``"rsa"`` nor ``"raw"``.
        """
        intersect_method = self.intersect_param.intersect_method
        if intersect_method == "rsa":
            LOGGER.info("Using rsa intersection")
            self.intersection = RsaIntersectionGuest(self.intersect_param)
        elif intersect_method == "raw":
            LOGGER.info("Using raw intersection")
            self.intersection = RawIntersectionGuest(self.intersect_param)
        else:
            # Report the value that was actually dispatched on. The message used
            # to read ``self.workflow_param``, which this method never touches
            # and which may not exist on this object.
            raise TypeError("intersect_method {} is not support yet".format(intersect_method))

        intersect_ids = self.intersection.run(data_instance)

        self.save_intersect_result(intersect_ids)
        LOGGER.info("Save intersect results")