コード例 #1
0
def add_and_sub_plaintext(job_id, idx, data_list):
    """Combine a shared fixed-point tensor with a plaintext operand.

    The caller with ``idx == 0`` builds the tensor "x" from its own
    ``data_list[0]``; every other caller builds it from ``all_parties[0]``
    (the first party returned by ``session_init``). The plaintext value
    ``data_list[1]`` is then added and subtracted in both operand orders.

    Args:
        job_id: identifier forwarded to ``session_init``.
        idx: index of the calling party, forwarded to ``session_init``.
        data_list: ``data_list[0]`` is the local data (read only when
            ``idx == 0``); ``data_list[1]`` is the plaintext operand.

    Returns:
        tuple: ``(x + y, y + x, x - y, y - x)``, each revealed via ``.get()``.
    """
    _, parties = session_init(job_id, idx)
    with SPDZ():
        # Only idx 0 supplies raw data; the others pass the first party in
        # place of data.
        source = data_list[0] if idx == 0 else parties[0]
        shared = FixedPointTensor.from_source("x", source)
        plain = data_list[1]
        return (
            (shared + plain).get(),
            (plain + shared).get(),
            (shared - plain).get(),
            (plain - shared).get(),
        )
コード例 #2
0
    def _test_spdz(self):
        """Benchmark SPDZ operations for the numpy and table tensor backends.

        The guest shares ``x`` (``self.int_data_x`` / ``self.float_data_x``).
        Every other role shares ``y`` (``self.int_data_y`` /
        ``self.float_data_y``). In both cases the counterpart operand is
        received via ``self.other_party``.

        For each backend ("numpy", then "table") and each op in
        ``self.op_test_list``, the operands are shared and
        ``self.calculate_ret`` is run ``self.test_round`` times. The elapsed
        time and the result of the last round are passed to
        ``self.output_table``. After all backends are done, the component
        summary is logged through ``self.tracker`` and each backend's
        PrettyTable is written to ``LOGGER``.

        Raises:
            ValueError: if ops are configured but ``self.test_round < 1``.
                In that case no round would run, so there would be no
                result to report.
        """
        if self.op_test_list and self.test_round < 1:
            raise ValueError(
                f"test_round must be >= 1 to benchmark SPDZ ops, "
                f"got {self.test_round}")

        is_guest = self.local_party.role == "guest"
        # Each party only reads its own side of the data.
        if is_guest:
            local_data = {"int": self.int_data_x,
                          "float": self.float_data_x}
        else:
            local_data = {"int": self.int_data_y,
                          "float": self.float_data_y}
        local_tables = {kind: self._spdz_to_table(data)
                        for kind, data in local_data.items()}

        table_list = []
        for tensor_type in ["numpy", "table"]:
            table = PrettyTable()
            table.set_style(ORGMODE)
            field_name = [
                "DataType", "One time consumption",
                f"{self.data_num} times consumption", "relative acc",
                "log2 acc", "operations per second"
            ]
            self._summary["field_name"] = field_name
            table.field_names = field_name

            use_table = tensor_type == "table"
            tensor_cls = TableTensor if use_table else NumpyTensor
            with SPDZ(local_party=self.local_party,
                      all_parties=self.parties) as spdz:
                for op_type in self.op_test_list:
                    kind = "int" if op_type.startswith("int") else "float"
                    local_source = (local_tables[kind] if use_table
                                    else local_data[kind])
                    # perf_counter is monotonic, unlike time.time(), so the
                    # measured interval cannot be skewed by clock changes.
                    start_time = time.perf_counter()
                    for epoch in range(self.test_round):
                        LOGGER.info(
                            f"test spdz, tensor_type: {tensor_type}, op_type: {op_type}, epoch: {epoch}"
                        )
                        tag = "_".join([tensor_type, op_type, str(epoch)])
                        spdz.set_flowid(tag)
                        fixed_point_x, fixed_point_y = \
                            self._spdz_share_operands(
                                tensor_cls, kind, tag, local_source,
                                is_guest)
                        ret = self.calculate_ret(op_type, tensor_type,
                                                 fixed_point_x, fixed_point_y)

                    total_time = time.perf_counter() - start_time
                    # Only the last round's result is reported.
                    self.output_table(op_type, table, tensor_type, total_time,
                                      ret)

            table_list.append(table)

        self.tracker.log_component_summary(self._summary)
        for table in table_list:
            LOGGER.info(table)

    def _spdz_to_table(self, data):
        """Parallelize ``data`` into a table whose values are 1-element arrays."""
        table = session.parallelize(data,
                                    include_key=False,
                                    partition=self.data_partition)
        return table.mapValues(lambda value: np.array([value]))

    def _spdz_share_operands(self, tensor_cls, kind, tag, local_source,
                             is_guest):
        """Share the local operand, receive the peer's, and return ``(x, y)``.

        The local tensor is created before the peer tensor on both roles,
        which keeps the original call order.
        """
        local_name, peer_name = ("x", "y") if is_guest else ("y", "x")
        local_tensor = tensor_cls.from_source(
            f"{kind}_{local_name}_" + tag, local_source)
        peer_tensor = tensor_cls.from_source(
            f"{kind}_{peer_name}_" + tag, self.other_party)
        if is_guest:
            return local_tensor, peer_tensor
        return peer_tensor, local_tensor