コード例 #1
0
 def __init__(
     self,
     LearnerParams,
     PrivacyParams,
     EvaluatorParams,
     DatasetParams,
     querypool,
     available_actions,
 ):
     """Set up the environment state, its spaces and the simulated datasets.

     Args:
         LearnerParams: Stored as ``self.lp``; its ``observation_space``
             attribute sizes the observation space and the initial ``pool``.
         PrivacyParams: Stored as ``self.pp``.
         EvaluatorParams: Stored as ``self.ev``.
         DatasetParams: Stored as ``self.dd``; its ``dataset_size`` attribute
             is passed to ``create_simulated_dataset``.
         querypool: Stored unchanged as ``self.querypool``.
         available_actions: Stored as ``self.available_actions``; its length
             is the size of the discrete action space.
     """
     # Parameter bundles, kept under the short aliases used by the methods.
     self.lp = LearnerParams
     self.pp = PrivacyParams
     self.ev = EvaluatorParams
     self.dd = DatasetParams

     # Caller-supplied collections.
     self.querypool = querypool
     self.available_actions = available_actions

     # Per-run bookkeeping, all starting from an empty/zero state.
     self.state = 0
     self.query = None
     self.episode = 0
     self.reward = 0
     self.info = {}
     self.output = []
     self.state_query_pair_base = {}

     # One discrete action per available action.
     action_count = int(len(self.available_actions))
     self.action_space = spaces.Discrete(action_count)

     # Original note on this size: "number of query".
     observation_count = int(self.lp.observation_space)
     self.observation_space = spaces.Discrete(observation_count)

     simulated = create_simulated_dataset(self.dd.dataset_size, "dataset")
     self.df, self.metadata = simulated

     # Descending integers from observation_space - 2 down to 2 (inclusive).
     self.pool = list(range(self.lp.observation_space - 2, 1, -1))

     neighbors = generate_neighbors(self.df, self.metadata)
     (
         self.d1_dataset,
         self.d2_dataset,
         self.d1_metadata,
         self.d2_metadata,
     ) = neighbors
コード例 #2
0
 def learn(self, querypool, export_as_csv=False):
     """Evaluate every query in ``querypool`` and collect one result per query.

     For each query a new simulated dataset and a new pair of neighbouring
     datasets are generated, and the query is run through ``DPEvaluator``.

     Args:
         querypool: Iterable of queries to evaluate, in order.
         export_as_csv: When true, the results are written to ``Bandit.csv``
             through ``write_to_csv`` instead of being returned.

     Returns:
         A list of dicts with the keys ``"query"``, ``"dpresult"``,
         ``"jensen_shannon_divergence"`` and ``"error"``, or ``None`` when
         ``export_as_csv`` is true.
     """
     output = []
     for query in querypool:
         df, metadata = create_simulated_dataset(self.dd.dataset_size, "dataset")
         d1_dataset, d2_dataset, d1_metadata, d2_metadata = generate_neighbors(df, metadata)
         d1 = PandasReader(d1_metadata, d1_dataset)
         d2 = PandasReader(d2_metadata, d2_dataset)
         # Named `evaluator` rather than `eval` so the builtin is not shadowed.
         evaluator = DPEvaluator()
         algorithm = DPSingletonQuery()
         key_metrics = evaluator.evaluate(
             [d1_metadata, d1], [d2_metadata, d2], algorithm, query, self.pp, self.ev)

         overall = key_metrics['__key__']
         if overall.dp_res is None:
             # No DP verdict was produced: report the evaluator's error instead.
             output.append({
                 "query": query,
                 "dpresult": overall.dp_res,
                 "jensen_shannon_divergence": None,
                 "error": overall.error,
             })
             continue

         # Aggregate over every key: the DP verdict holds only if it holds for
         # all of them, and the reported divergence is the largest one seen.
         all_metrics = list(key_metrics.values())
         dp_res = np.all(np.array([m.dp_res for m in all_metrics]))
         js_res = np.array(
             [m.jensen_shannon_divergence for m in all_metrics]).max()
         output.append({
             "query": query,
             "dpresult": dp_res,
             "jensen_shannon_divergence": js_res,
             "error": None,
         })

     if export_as_csv:
         write_to_csv('Bandit.csv', output, flag='bandit')
         return None
     return output
コード例 #3
0
 def _compute_reward(self, query):
     """Run ``query`` through the DP evaluator and turn the outcome into a reward.

     A new pair of neighbouring datasets is generated from ``self.df`` and
     ``self.metadata`` on every call. The result is stored in ``self.reward``
     and classified as follows, based on the ``"__key__"`` metrics entry:

     * ``dp_res is None``: ``"DP_BUG"``, reward 1, message is the evaluator's
       error.
     * ``dp_res == False``: ``"DP_FAIL"``, reward 20, message
       ``"dp_res_False"``; ``self._game_ended`` is set.
     * ``dp_res == True`` with an infinite Jensen-Shannon divergence:
       ``"DP_BUG"``, reward 20, message ``"jsdistance_is_inf"``;
       ``self._game_ended`` is set.
     * otherwise the verdicts of all keys are combined: if every key passes,
       ``"DP_PASS"`` with the largest Jensen-Shannon divergence as reward;
       if any key fails, the outcome is the same as the ``"DP_FAIL"`` case.

     Args:
         query: The query to evaluate.

     Returns:
         Tuple ``(dpresult, reward, message, d1, d2)`` where ``d1`` and ``d2``
         are the ``PandasReader`` objects built for the two neighbours.
     """
     # The result is not used here; the call is kept in case observe() has
     # side effects that other code relies on.
     self.observe(query)

     d1_dataset, d2_dataset, d1_metadata, d2_metadata = generate_neighbors(
         self.df, self.metadata)
     d1 = PandasReader(d1_metadata, d1_dataset)
     d2 = PandasReader(d2_metadata, d2_dataset)
     # Named `evaluator` rather than `eval` so the builtin is not shadowed.
     evaluator = DPEvaluator()
     algorithm = DPSingletonQuery()
     key_metrics = evaluator.evaluate([d1_metadata, d1], [d2_metadata, d2],
                                      algorithm, query, self.pp, self.ev)

     overall = key_metrics["__key__"]
     message = None
     # The explicit `== False` / `== True` comparisons are kept on purpose so
     # that non-`bool` flags (e.g. NumPy booleans) match exactly as before.
     if overall.dp_res is None:
         dpresult = "DP_BUG"
         self.reward = 1
         message = overall.error
     elif overall.dp_res == False:
         self._game_ended = True
         dpresult = "DP_FAIL"
         self.reward = 20
         message = "dp_res_False"
     elif (overall.dp_res == True and
           overall.jensen_shannon_divergence == math.inf):
         self._game_ended = True
         dpresult = "DP_BUG"
         self.reward = 20
         message = "jsdistance_is_inf"
     else:
         all_metrics = list(key_metrics.values())
         dp_res = np.all(np.array([m.dp_res for m in all_metrics]))
         if dp_res == True:
             dpresult = "DP_PASS"
             self.reward = np.array(
                 [m.jensen_shannon_divergence for m in all_metrics]).max()
         else:
             # At least one key failed the DP test. Previously `dpresult` was
             # left unassigned here and the return raised UnboundLocalError;
             # report it like the "__key__" failure above.
             self._game_ended = True
             dpresult = "DP_FAIL"
             self.reward = 20
             message = "dp_res_False"
     return dpresult, self.reward, message, d1, d2