コード例 #1
0
ファイル: start_stepwise.py プロジェクト: zpskt/FATE
def run(model, train_data, validate_data=None):
    """Run stepwise selection for ``model`` and return predictions on ``train_data``.

    Args:
        model: Model object; this function reads ``model.need_run`` and
            ``model.mode`` and forwards the model to the stepwise helpers.
        train_data: Data passed to the stepwise run and then to prediction.
        validate_data: Optional data forwarded to the stepwise run.

    Returns:
        ``train_data`` itself when ``model.need_run`` is falsy; otherwise the
        value returned by ``HeteroStepwise.predict(train_data, model)``.

    Raises:
        ValueError: if ``model.mode`` is not ``consts.HETERO``.
    """
    # A model flagged as not needing a run simply passes its input through.
    if not model.need_run:
        return train_data

    # Only the Hetero variant is implemented here.
    if model.mode != consts.HETERO:
        raise ValueError("stepwise currently only support Hetero mode.")

    stepwise = HeteroStepwise()
    stepwise.run(_get_stepwise_param(model), train_data, validate_data, model)

    predictions = HeteroStepwise.predict(train_data, model)
    LOGGER.info("Finish running Stepwise")
    return predictions
コード例 #2
0
 def test_add_one(self):
     """Check that ``HeteroStepwise.add_one(self.mask)`` yields exactly ``real_masks``, in order.

     The test fails if the generator yields a mask that differs from the
     expected one at the same position, yields extra masks, or yields too few.
     """
     real_masks = [np.array([1, 1, 1, 1, 0], dtype=bool), np.array([1, 0, 1, 1, 1], dtype=bool)]
     mask_generator = HeteroStepwise.add_one(self.mask)
     n_yielded = 0
     for i, mask in enumerate(mask_generator):
         # Report an over-long generator clearly instead of raising IndexError below.
         self.assertLess(i, len(real_masks),
                         f"In stepwise_test add one: unexpected extra mask {mask}")
         np.testing.assert_array_equal(mask, real_masks[i],
                                       f"In stepwise_test add one: mask{mask} not equal to expected {real_masks[i]}")
         n_yielded += 1
     # Without this check an empty or truncated generator would pass vacuously.
     self.assertEqual(n_yielded, len(real_masks),
                      f"In stepwise_test add one: expected {len(real_masks)} masks, got {n_yielded}")
コード例 #3
0
    def setUp(self):
        """Create a computing session, a guest-role ``HeteroStepwise`` model and test fixtures.

        Sets on ``self``: ``job_id``, ``session``, ``model``, ``header``,
        ``str_mask``, ``mask`` and ``table``.
        """
        self.job_id = str(uuid.uuid1())
        self.session = Session.create(0, 0).init_computing(self.job_id).computing

        model = HeteroStepwise()
        # Plain attribute assignment is the idiomatic form of model.__setattr__(name, value).
        model.role = consts.GUEST
        model.fit_intercept = True
        self.model = model

        data_num = 100
        self.header = ["x1", "x2", "x3", "x4", "x5"]
        # Tie the feature count to the header so the two cannot drift apart.
        feature_num = len(self.header)
        bool_list = [True, False, True, True, False]
        # String form of bool_list ("1" for True, "0" for False); keep the two in sync.
        self.str_mask = "10110"
        self.mask = self.prepare_mask(bool_list)
        self.table = self.prepare_data(data_num, feature_num, self.header, "id", "y")
コード例 #4
0
 def test_get_dfe(self):
     """``HeteroStepwise.get_dfe`` on the fixture model and ``self.str_mask`` should return 4."""
     expected_dfe = 4
     self.assertEqual(HeteroStepwise.get_dfe(self.model, self.str_mask), expected_dfe)
コード例 #5
0
 def test_string2mask(self):
     """``HeteroStepwise.string2mask`` should decode ``self.str_mask`` into the matching boolean array."""
     expected = np.array([True, False, True, True, False])
     np.testing.assert_array_equal(HeteroStepwise.string2mask(self.str_mask), expected)
コード例 #6
0
 def test_mask2string(self):
     """``HeteroStepwise.mask2string(self.mask, self.mask)`` should return "1011010110".

     The expected value is the 5-character fixture mask string written twice.
     """
     real_str_mask = "1011010110"
     str_mask = HeteroStepwise.mask2string(self.mask, self.mask)
     # assertEqual shows both strings on failure; assertTrue(a == b) only reports "False is not true".
     self.assertEqual(str_mask, real_str_mask)