def test_correct_model_param_validation_test(self):
    """Check that a well-formed FTLModelParam config is accepted.

    A fresh ``FTLModelParam`` is filled from a config dict and handed to
    ``FTLModelParamChecker.check_param``. There are no explicit assertions:
    the test passes as long as that call does not raise.
    """
    model_settings = {
        "eps": 10e-3,
        "alpha": 100,
        "max_iter": 6,
        "is_encrypt": False,
    }
    config = {"FTLModelParam": model_settings}

    filled_param = get_filled_param(FTLModelParam(), config)
    FTLModelParamChecker.check_param(filled_param)
def _initialize_model(self, config):
    """Parse the FTL parameter objects out of ``config`` and initialize the model.

    Creates default model, local-model and data parameter objects, passes each
    through ``ParamExtract.parse_param_from_config`` together with ``config``,
    stores the parsed data parameters and a new ``HeteroFTLTransferVariable``
    on ``self``, and finally delegates to ``self._do_initialize_model``.

    :param config: configuration handed unchanged to
        ``ParamExtract.parse_param_from_config``.
    """
    LOGGER.debug("@ initialize model")

    # Build all default parameter objects first, then parse them in turn.
    model_param = FTLModelParam()
    local_model_param = FTLLocalModelParam()
    data_param = FTLDataParam()

    model_param = ParamExtract.parse_param_from_config(model_param, config)
    local_model_param = ParamExtract.parse_param_from_config(local_model_param, config)
    self.ftl_data_param = ParamExtract.parse_param_from_config(data_param, config)

    self.ftl_transfer_variable = HeteroFTLTransferVariable()

    # NOTE(review): unlike the two model params, the object forwarded here is
    # the local ``data_param`` rather than the parser's return value kept in
    # ``self.ftl_data_param``. The two are the same object only if
    # parse_param_from_config returns its argument -- confirm.
    self._do_initialize_model(model_param, local_model_param, data_param)
def test_hetero_plain_guest_prepare_table(self):
    """Run ``prepare_data`` on a plain FTL guest over a four-row table.

    Builds a ``PlainFTLGuestModel`` around a ``MockAutoencoder``, wraps it in
    ``TestHeteroFTLGuest``, turns four feature rows with labels into a table
    via ``create_table`` and prints the four values returned by
    ``prepare_data``.

    NOTE(review): the test only prints its results and makes no assertions,
    so it can fail only if one of the calls raises.
    """
    U_A = np.array([[1, 2, 3, 4, 5],
                    [4, 5, 6, 7, 8],
                    [7, 8, 9, 10, 11],
                    [4, 5, 6, 7, 8]])
    y = np.array([[1], [-1], [1], [-1]])

    feature_dim = U_A.shape[1]
    Wh = np.ones((5, feature_dim))
    bh = np.zeros(feature_dim)

    model_param = FTLModelParam(alpha=1, max_iteration=1)

    autoencoderA = MockAutoencoder(0)
    autoencoderA.build(feature_dim, Wh, bh)
    guest = PlainFTLGuestModel(autoencoderA, model_param)

    converge_func = MockDiffConverge(None)
    ftl_guest = TestHeteroFTLGuest(guest, model_param, HeteroFTLTransferVariable())
    ftl_guest.set_converge_function(converge_func)

    guest_sample_indexes = np.array([0, 1, 2, 3])

    # Seed is set before any instance is built, as in the original ordering.
    # Presumably code further down draws from NumPy's global RNG -- confirm.
    np.random.seed(100)

    # One Instance per (index, feature row, label); label rows have a single
    # element, hence label[0].
    instance_list = [
        Instance(inst_id=i, features=feature, label=label[0])
        for i, feature, label in zip(guest_sample_indexes, U_A, y)
    ]

    guest_x = create_table(instance_list, indexes=guest_sample_indexes)
    guest_x, overlap_indexes, non_overlap_indexes, guest_y = ftl_guest.prepare_data(guest_x)

    print("guest_x", guest_x)
    print("overlap_indexes", overlap_indexes)
    print("non_overlap_indexes", non_overlap_indexes)
    print("guest_y", guest_y)
def test_create_plain_ftl_guest(self):
    """With ``is_encrypt=False``, ``create_guest`` must return a ``HeteroPlainFTLGuest``."""
    ftl_model_param = FTLModelParam(is_encrypt=False)
    guest = self.create_guest(ftl_model_param)
    # assertIsInstance reports the actual type on failure, whereas
    # assertTrue(isinstance(...)) only says "False is not true".
    self.assertIsInstance(guest, HeteroPlainFTLGuest)
def test_create_enc_faster_ftl_guest(self):
    """With ``is_encrypt=True`` and ``enc_ftl="enc_ftl2"``, ``create_guest`` must
    return a ``FasterHeteroEncryptFTLGuest``."""
    ftl_model_param = FTLModelParam(is_encrypt=True, enc_ftl="enc_ftl2")
    guest = self.create_guest(ftl_model_param)
    # assertIsInstance reports the actual type on failure, whereas
    # assertTrue(isinstance(...)) only says "False is not true".
    self.assertIsInstance(guest, FasterHeteroEncryptFTLGuest)
def test_create_enc_ftl_host(self):
    """With ``is_encrypt=True`` and ``enc_ftl=None``, ``create_host`` must return a
    ``HeteroEncryptFTLHost``."""
    ftl_model_param = FTLModelParam(is_encrypt=True, enc_ftl=None)
    host = self.create_host(ftl_model_param)
    # assertIsInstance reports the actual type on failure, whereas
    # assertTrue(isinstance(...)) only says "False is not true".
    self.assertIsInstance(host, HeteroEncryptFTLHost)
def assertFTLModelParamValueError(self, param_json, param_to_validate):
    """Assert that validating ``param_json`` raises a matching ``ValueError``.

    A fresh ``FTLModelParam`` is filled from ``param_json`` and passed to
    ``FTLModelParamChecker.check_param``, which is expected to raise.

    :param param_json: config passed to ``get_filled_param`` to fill the param.
    :param param_to_validate: regular expression the ``ValueError`` message
        must match (it is forwarded to ``assertRaisesRegex``).
    """
    candidate = get_filled_param(FTLModelParam(), param_json)

    expect_failure = self.assertRaisesRegex(ValueError, param_to_validate)
    with expect_failure:
        FTLModelParamChecker.check_param(candidate)