def test_db_latest_all():
    """Loading a key returns the newest record; ``get_all`` returns every record in save order."""
    logging.info("test db load w/ multiple results ...")
    inp1, res1 = get_sample_records(1)[0]

    # Clone the sample result three times, varying only the trailing
    # timestamp field so "latest" is unambiguous.
    stamps = (0.0, 1.1, 9999.9999)
    variants = []
    for stamp in stamps:
        fields = list(res1)
        fields[-1] = stamp  # timestamp is the last MeasureResult field
        variants.append(MeasureResult(*fields))
    res1, res2, res3 = variants

    _db = database.DummyDatabase()
    _db.flush()

    # After each extend-save, a plain load must surface the newest timestamp.
    for stamp, res in zip(stamps, variants):
        _db.save(inp1, res, extend=True)
        assert _db.load(inp1).timestamp == stamp

    # get_all=True returns the full history in insertion order.
    history = _db.load(inp1, get_all=True)
    assert encode(inp1, history[0]) == encode(inp1, res1)
    assert encode(inp1, history[1]) == encode(inp1, res2)
    assert encode(inp1, history[2]) == encode(inp1, res3)
def test_update():
    """Feeding N measurements to XGBTuner.update populates xs/ys/visited with N entries."""
    task, target = get_sample_task()
    tuner = autotvm.tuner.XGBTuner(task)

    n_records = 5
    records = get_sample_records(n=n_records)
    inputs, results = zip(*records)
    tuner.update(list(inputs), list(results))

    assert len(tuner.xs) == n_records
    assert len(tuner.ys) == n_records
    assert len(tuner.visited) == n_records
    # Every config index the tuner trained on must be marked visited.
    assert all(x in tuner.visited for x in tuner.xs)
def test_db_filter():
    """Database.filter keeps only the entries whose result group matches the predicate."""
    logging.info("test db filter ...")
    sample = get_sample_records(5)

    _db = database.DummyDatabase()
    _db.flush()
    for inp, result in sample:
        _db.save(inp, result)

    # Keep records where any measured cost is at most 2; the sample set
    # is constructed so exactly two entries qualify.
    cheap = _db.filter(lambda inp, ress: any(r.costs[0] <= 2 for r in ress))
    assert len(cheap) == 2
def test_tuner():
    """load_history trains a base cost model only when min_seed_records is satisfied."""
    task, target = get_sample_task()
    records = get_sample_records(n=10)

    # Exactly enough records: a base model must be trained.
    tuner = autotvm.tuner.XGBTuner(task)
    tuner.load_history(records, min_seed_records=10)
    assert tuner.cost_model.base_model is not None

    # One record short of the threshold: no base model is loaded.
    tuner = autotvm.tuner.XGBTuner(task)
    tuner.load_history(records, min_seed_records=11)
    assert tuner.cost_model.base_model is None
def test_fit():
    """Fit an XGBoost cost model from logged records, then transfer-learn on top of it."""
    # NOTE(review): a later definition of test_fit in this file shadows
    # this one at import time — consider renaming or removing one of them.
    task, target = get_sample_task()
    records = get_sample_records(n=500)

    base_model = XGBoostCostModel(task, feature_type="itervar", loss_type="rank")
    base_model.fit_log(records, plan_size=32)

    upper_model = XGBoostCostModel(task, feature_type="itervar", loss_type="rank")
    upper_model.load_basemodel(base_model)

    xs = np.arange(10)
    ys = np.arange(10)
    upper_model.fit(xs, ys, plan_size=32)
def test_save_load():
    """Round-trip save/load on DummyDatabase; a missing key loads as None."""
    logging.info("test basic db load/save ...")
    (inp1, res1), (inp2, res2), (inp3, _) = get_sample_records(3)

    _db = database.DummyDatabase()
    _db.flush()
    _db.save(inp1, res1)
    _db.save(inp2, res2)
    # inp3 is deliberately never saved.

    load1 = _db.load(inp1)
    load2 = _db.load(inp2)
    load3 = _db.load(inp3)

    assert load1 == res1
    assert load2 == res2
    assert load3 is None
    # Distinct inputs must not alias each other's results.
    assert load1 != load2
def test_db_hash():
    """Inputs whose configs differ only in code_hash are stored as separate db entries."""
    logging.info("test db hash check ...")
    inp1, res1 = get_sample_records(1)[0]
    inp2 = copy.deepcopy(inp1)
    inp1.config.code_hash = "cafecafe"
    inp2.config.code_hash = "dbffdbff"

    # Give the second record a sentinel timestamp so the two results
    # are distinguishable after loading.
    fields = list(res1)
    fields[-1] = -1  # timestamp is the last MeasureResult field
    res2 = MeasureResult(*fields)

    _db = database.DummyDatabase()
    _db.flush()
    _db.save(inp1, res1, extend=True)
    _db.save(inp2, res2, extend=True)

    load1 = _db.load(inp1)
    load2 = _db.load(inp2)
    assert load1 != load2
    assert load1.timestamp != -1
    assert load2.timestamp == -1
def test_fit():
    """Fit and transfer-learn an XGBoost cost model, then predict on varying feature lengths."""
    # NOTE(review): this redefines test_fit, shadowing the earlier
    # definition in this file — consider renaming one of them.
    task, target = get_sample_task()
    records = get_sample_records(n=500)

    base_model = XGBoostCostModel(task, feature_type="itervar", loss_type="rank")
    base_model.fit_log(records, plan_size=32)

    upper_model = XGBoostCostModel(task, feature_type="itervar", loss_type="rank")
    upper_model.load_basemodel(base_model)

    xs = np.arange(10)
    ys = np.arange(10)
    upper_model.fit(xs, ys, plan_size=32)

    # Feature lengths are not guaranteed to always be the same,
    # so prediction must tolerate both longer and shorter vectors.
    upper_model.predict(np.ones(12))
    upper_model.predict(np.ones(8))