Exemplo n.º 1
0
 def test_load_pkl(self):
     '''Test whether prediction is correct.

     Precondition: the runner sets CUDA_VISIBLE_DEVICES to '-1'. Loads the
     pickled model at ``model_path`` and checks that predicting on the
     features returned by ``build_dataset()`` yields 10 results (presumably
     the dataset has 10 rows -- confirm in ``build_dataset``).
     '''
     # .get() turns a missing variable into a normal assertion failure that
     # reports the actual value, instead of an opaque KeyError.
     assert os.environ.get('CUDA_VISIBLE_DEVICES') == '-1'
     bst = load_pickle(model_path)
     x, _ = build_dataset()  # labels are not needed for prediction
     res = bst.predict(xgb.DMatrix(x))
     assert len(res) == 10
Exemplo n.º 2
0
 def test_predictor_type_is_gpu(self):
     '''When CUDA_VISIBLE_DEVICES is not specified, keep using
     `gpu_predictor`.

     Loads the pickled model at ``model_path`` and checks that the gbtree
     training parameters in its saved JSON configuration still name
     'gpu_predictor' as the predictor.
     '''
     # os.environ is a mapping; membership already tests its keys.
     assert 'CUDA_VISIBLE_DEVICES' not in os.environ
     bst = load_pickle(model_path)
     config = json.loads(bst.save_config())
     train_param = config['learner']['gradient_booster']['gbtree_train_param']
     assert train_param['predictor'] == 'gpu_predictor'
Exemplo n.º 3
0
 def test_predictor_type_is_auto(self):
     '''Under invalid CUDA_VISIBLE_DEVICES, predictor should be set to
     auto.

     Precondition: the runner sets CUDA_VISIBLE_DEVICES to '-1'. Loads the
     pickled model at ``model_path`` and checks that the gbtree training
     parameters in its saved JSON configuration name 'auto' as the
     predictor.
     '''
     # .get() turns a missing variable into a normal assertion failure that
     # reports the actual value, instead of an opaque KeyError.
     assert os.environ.get('CUDA_VISIBLE_DEVICES') == '-1'
     bst = load_pickle(model_path)
     config = json.loads(bst.save_config())
     train_param = config['learner']['gradient_booster']['gbtree_train_param']
     assert train_param['predictor'] == 'auto'
Exemplo n.º 4
0
    def test_wrap_gpu_id(self):
        '''Check gpu_id and prediction with CUDA_VISIBLE_DEVICES='0'.

        Loads the pickled model at ``model_path``, checks that the saved
        JSON configuration reports gpu_id '0', then checks that predicting
        on the features from ``build_dataset()`` yields 10 results.
        '''
        # .get() turns a missing variable into a normal assertion failure
        # that reports the actual value, instead of an opaque KeyError.
        assert os.environ.get('CUDA_VISIBLE_DEVICES') == '0'
        bst = load_pickle(model_path)
        config = json.loads(bst.save_config())
        # The JSON config stores parameter values as strings, hence '0'.
        assert config['learner']['generic_param']['gpu_id'] == '0'

        x, _ = build_dataset()  # labels are not needed for prediction
        res = bst.predict(xgb.DMatrix(x))
        assert len(res) == 10
Exemplo n.º 5
0
    def test_load_pkl(self):
        '''Test whether prediction is correct.

        Precondition: the runner sets CUDA_VISIBLE_DEVICES to '-1'. The
        pickled object may be a raw ``xgb.Booster`` or some other model
        type (presumably an sklearn-style wrapper -- confirm against how the
        pickle is produced). Both paths must yield 10 predictions. The
        wrapper path additionally re-predicts after ``set_params``.
        '''
        # .get() turns a missing variable into a normal assertion failure
        # that reports the actual value, instead of an opaque KeyError.
        assert os.environ.get('CUDA_VISIBLE_DEVICES') == '-1'
        bst = load_pickle(model_path)
        x, _ = build_dataset()  # labels are not needed for prediction
        if isinstance(bst, xgb.Booster):
            # A raw Booster is fed a DMatrix built from the features.
            res = bst.predict(xgb.DMatrix(x))
        else:
            # Non-Booster models are given the raw features directly.
            res = bst.predict(x)
            assert len(res) == 10
            bst.set_params(n_jobs=1)  # triggers a re-configuration
            res = bst.predict(x)

        assert len(res) == 10