def testTrainingStopElastic(self):
    """This should now stop training after one actor died."""
    os.environ["RXGB_ELASTIC_RESTART_DISABLED"] = "0"

    # The `train()` function raises a RuntimeError
    ft_manager = FaultToleranceManager.remote()

    ft_manager.schedule_kill.remote(rank=0, boost_round=3)
    ft_manager.schedule_kill.remote(rank=1, boost_round=6)
    ft_manager.delay_return.remote(
        rank=0, start_boost_round=4, end_boost_round=5)

    delay_callback = DelayedLoadingCallback(
        ft_manager, reload_data=True, sleep_time=0.1)
    die_callback = DieCallback(ft_manager, training_delay=0.25)

    with self.assertRaises(RuntimeError):
        train(
            self.params,
            RayDMatrix(self.x, self.y),
            callbacks=[die_callback],
            num_boost_round=20,
            ray_params=RayParams(
                elastic_training=True,
                max_failed_actors=1,
                max_actor_restarts=1,
                num_actors=2,
                distributed_callbacks=[delay_callback]))
def main(): print("Loading HIGGS data.") colnames = ["label"] + ["feature-%02d" % i for i in range(1, 29)] if args.smoke_test: data = pd.read_csv(SIMPLE_HIGGS_S3_URI, names=colnames) else: data = pd.read_csv(HIGGS_S3_URI, names=colnames) print("Loaded HIGGS data.") # partition on a column df_train = data[(data["feature-01"] < 0.4)] df_validation = data[(data["feature-01"] >= 0.4) & (data["feature-01"] < 0.8)] dtrain = RayDMatrix(df_train, label="label", columns=colnames) dvalidation = RayDMatrix(df_validation, label="label") evallist = [(dvalidation, "eval")] evals_result = {} config = {"tree_method": "hist", "eval_metric": ["logloss", "error"]} train(params=config, dtrain=dtrain, evals_result=evals_result, ray_params=RayParams(max_actor_restarts=1, num_actors=4, cpus_per_actor=2), num_boost_round=100, evals=evallist)
def test_communication_colocation(self):
    """Checks that Queue and Event actors are colocated with the driver."""
    with self.ray_start_cluster() as cluster:
        cluster.add_node(num_cpus=3)
        cluster.add_node(num_cpus=3)
        cluster.wait_for_nodes()
        ray.init(address=cluster.address)

        local_node = ray.state.current_node_id()

        # Note that these will have the same IP in the test cluster
        assert len(ray.state.node_ids()) == 2
        assert local_node in ray.state.node_ids()

        def _mock_train(*args, _training_state, **kwargs):
            assert ray.get(
                _training_state.queue.actor.get_node_id.remote()) == \
                ray.state.current_node_id()
            assert ray.get(
                _training_state.stop_event.actor.get_node_id.remote()) == \
                ray.state.current_node_id()
            return _train(*args, _training_state=_training_state, **kwargs)

        with patch("xgboost_ray.main._train") as mocked:
            mocked.side_effect = _mock_train
            train(
                self.params,
                RayDMatrix(self.x, self.y),
                num_boost_round=2,
                ray_params=RayParams(max_actor_restarts=1, num_actors=6))
def testCheckpointContinuationValidity(self):
    """Test that checkpoints are stored and loaded correctly"""
    # Train once, get checkpoint via callback returns
    res_1 = {}
    bst_1 = train(
        self.params,
        RayDMatrix(self.x, self.y),
        callbacks=[
            _checkpoint_callback(frequency=1, before_iteration_=False)
        ],
        num_boost_round=2,
        ray_params=RayParams(num_actors=2),
        additional_results=res_1)
    last_checkpoint_1 = res_1["callback_returns"][0][-1]
    last_checkpoint_other_rank_1 = res_1["callback_returns"][1][-1]

    # Sanity check
    lc1 = xgb.Booster()
    lc1.load_model(last_checkpoint_1)
    self.assertEqual(last_checkpoint_1, last_checkpoint_other_rank_1)
    self.assertEqual(last_checkpoint_1, lc1.save_raw())
    self.assertEqual(bst_1.save_raw(), lc1.save_raw())

    # Start new training run, starting from existing model
    res_2 = {}
    bst_2 = train(
        self.params,
        RayDMatrix(self.x, self.y),
        callbacks=[
            _checkpoint_callback(frequency=1, before_iteration_=True),
            _checkpoint_callback(frequency=1, before_iteration_=False)
        ],
        num_boost_round=4,
        ray_params=RayParams(num_actors=2),
        additional_results=res_2,
        xgb_model=lc1)

    first_checkpoint_2 = res_2["callback_returns"][0][0]
    first_checkpoint_other_actor_2 = res_2["callback_returns"][1][0]
    last_checkpoint_2 = res_2["callback_returns"][0][-1]
    last_checkpoint_other_actor_2 = res_2["callback_returns"][1][-1]

    fcp_bst = xgb.Booster()
    fcp_bst.load_model(first_checkpoint_2)

    lcp_bst = xgb.Booster()
    lcp_bst.load_model(last_checkpoint_2)

    # Sanity check
    self.assertEqual(first_checkpoint_2, first_checkpoint_other_actor_2)
    self.assertEqual(last_checkpoint_2, last_checkpoint_other_actor_2)
    self.assertEqual(bst_2.save_raw(), lcp_bst.save_raw())

    # Training should not have proceeded for the first checkpoint,
    # so trees should be equal
    self.assertEqual(last_checkpoint_1, fcp_bst.save_raw())

    # Training should have proceeded for the last checkpoint,
    # so trees should not be equal
    self.assertNotEqual(fcp_bst.save_raw(), lcp_bst.save_raw())
def inner_func(config):
    with patch("xgboost_ray.main._train", _mock_train):
        train(
            params,
            RayDMatrix(x, y),
            num_boost_round=4,
            ray_params=ray_params)
def main():
    ray.client("anyscale://").connect()

    print("Loading HIGGS data.")

    dask.config.set(scheduler=ray_dask_get)
    colnames = ["label"] + ["feature-%02d" % i for i in range(1, 29)]

    data = dd.read_csv(FILE_URL, names=colnames)
    if args.smoke_test:
        data = data.head(n=1000)

    print("Loaded HIGGS data.")

    # partition on a column
    df_train = data[(data["feature-01"] < 0.4)]
    df_train = df_train.persist()
    df_validation = data[(data["feature-01"] >= 0.4)
                         & (data["feature-01"] < 0.8)]
    df_validation = df_validation.persist()

    dtrain = RayDMatrix(df_train, label="label", columns=colnames)
    dvalidation = RayDMatrix(df_validation, label="label")

    evallist = [(dvalidation, "eval")]
    evals_result = {}
    config = {"tree_method": "hist", "eval_metric": ["logloss", "error"]}

    train(
        params=config,
        dtrain=dtrain,
        evals_result=evals_result,
        ray_params=RayParams(
            max_actor_restarts=1, num_actors=4, cpus_per_actor=2),
        num_boost_round=100,
        evals=evallist)
def testFailPrintErrors(self):
    """Test that XGBoost training errors are propagated"""
    x = np.random.uniform(0, 1, size=(100, 4))
    y = np.random.randint(0, 2, size=100)

    train_set = RayDMatrix(x, y)

    try:
        train(
            {
                "objective": "multi:softmax",
                "num_class": 2,
                "eval_metric": ["logloss", "error"]
            },  # This will error
            train_set,
            evals=[(train_set, "train")],
            ray_params=RayParams(num_actors=1, max_actor_restarts=0))
    except RuntimeError as exc:
        self.assertTrue(exc.__cause__)
        self.assertTrue(isinstance(exc.__cause__, RayActorError))

        self.assertTrue(exc.__cause__.__cause__)
        self.assertTrue(isinstance(exc.__cause__.__cause__, RayTaskError))

        self.assertTrue(exc.__cause__.__cause__.cause)
        self.assertTrue(
            isinstance(exc.__cause__.__cause__.cause,
                       RayXGBoostTrainingError))

        self.assertIn("label and prediction size not match",
                      str(exc.__cause__.__cause__))
def test_no_tune_spread(self):
    """Tests whether workers are spread when not using Tune."""
    with self.ray_start_cluster() as cluster:
        cluster.add_node(num_cpus=2)
        cluster.add_node(num_cpus=2)
        cluster.wait_for_nodes()
        ray.init(address=cluster.address)

        ray_params = RayParams(
            max_actor_restarts=1, num_actors=2, cpus_per_actor=2)

        def _mock_train(*args, _training_state, **kwargs):
            try:
                results = _train(
                    *args, _training_state=_training_state, **kwargs)
                return results
            except Exception:
                raise
            finally:
                assert len(_training_state.actors) == 2
                if not any(a is None for a in _training_state.actors):
                    actor_infos = ray.state.actors()
                    actor_nodes = []
                    for a in _training_state.actors:
                        actor_info = actor_infos.get(a._actor_id.hex())
                        actor_node = actor_info["Address"]["NodeID"]
                        actor_nodes.append(actor_node)
                    assert actor_nodes[0] != actor_nodes[1]

        with patch("xgboost_ray.main._train", _mock_train):
            train(
                self.params,
                RayDMatrix(self.x, self.y),
                num_boost_round=4,
                ray_params=ray_params)
def train_func(config):
    # Train a single xgboost_ray model with the sampled config.
    # `x` and `y` come from the enclosing scope.
    train_set = RayDMatrix(x, y)
    train(
        config["xgb"],
        dtrain=train_set,
        cpus_per_actor=1,
        num_actors=1,
        num_boost_round=config["num_boost_round"])
def testRanking(self):
    Xrow = np.array([1, 2, 6, 8, 11, 14, 16, 17])
    Xcol = np.array([0, 0, 1, 1, 2, 2, 3, 3])
    X = csr_matrix((np.ones(shape=8), (Xrow, Xcol)),
                   shape=(20, 4)).toarray()
    y = np.array([
        0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0,
        0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0
    ])
    qid = np.array([0] * 5 + [1] * 5 + [2] * 5 + [3] * 5)
    dtrain = RayDMatrix(X, label=y, qid=qid)

    params = {
        "eta": 1,
        "objective": "rank:pairwise",
        "eval_metric": ["auc", "aucpr"],
        "max_depth": 1
    }
    evals_result = {}
    train(
        params,
        dtrain,
        10,
        evals=[(dtrain, "train")],
        evals_result=evals_result,
        ray_params=RayParams(num_actors=2, max_actor_restarts=0))
    auc_rec = evals_result["train"]["auc"]
    self.assertTrue(all(p <= q for p, q in zip(auc_rec, auc_rec[1:])))
    auc_rec = evals_result["train"]["aucpr"]
    self.assertTrue(all(p <= q for p, q in zip(auc_rec, auc_rec[1:])))
def testTrainingStop(self):
    """This should now stop training after one actor died."""
    # The `train()` function raises a RuntimeError
    with self.assertRaises(RuntimeError):
        train(
            self.params,
            RayDMatrix(self.x, self.y),
            callbacks=[_kill_callback(self.die_lock_file)],
            num_boost_round=20,
            ray_params=RayParams(max_actor_restarts=0, num_actors=2))
def _inner_train(config, checkpoint_dir):
    # Function-API trainable: `config` holds the sampled hyperparameters,
    # `checkpoint_dir` is supplied by the caller (e.g. Ray Tune).
    train_set = RayDMatrix(x, y)
    train(
        config["xgb"],
        dtrain=train_set,
        ray_params=ray_params,
        num_boost_round=config["num_boost_round"],
        evals=[(train_set, "train")],
        callbacks=callbacks,
        **kwargs)
def testTrainingStop(self):
    """This should now stop training after one actor died."""
    # The `train()` function raises a RuntimeError
    with self.assertRaises(RuntimeError):
        train(
            self.params,
            RayDMatrix(self.x, self.y),
            callbacks=[self._fail_callback(self.die_lock_file)],
            num_boost_round=20,
            max_actor_restarts=0,
            num_actors=2,
            checkpoint_path=self.tmpdir)
def test_timeout(self):
    """Checks that an error occurs when placement group setup times out."""
    os.environ["RXGB_PLACEMENT_GROUP_TIMEOUT_S"] = "5"

    with self.ray_start_cluster() as cluster:
        ray.init(address=cluster.address)

        with self.assertRaises(TimeoutError):
            train(
                self.params,
                RayDMatrix(self.x, self.y),
                num_boost_round=2,
                ray_params=RayParams(
                    max_actor_restarts=1,
                    num_actors=2,
                    resources_per_actor={"invalid": 1}))
def train_model(data):
    logfile = open("/tmp/ray/session_latest/custom.log", "w")

    def write(msg):
        logfile.write(f"{msg}\n")
        logfile.flush()

    dtrain, dvalidation = data
    evallist = [(dvalidation, "eval")]
    evals_result = {}
    config = {
        "tree_method": "hist",
        "eval_metric": ["logloss", "error"],
    }

    write("Start training")
    bst = xgb.train(
        params=config,
        dtrain=dtrain,
        evals_result=evals_result,
        ray_params=xgb.RayParams(
            max_actor_restarts=1, num_actors=2, cpus_per_actor=2),
        num_boost_round=100,
        evals=evallist,
    )
    write("Finish training")
    return bst
def train_breast_cancer(config, cpus_per_actor=1, num_actors=1):
    # Load dataset
    data, labels = datasets.load_breast_cancer(return_X_y=True)
    # Split into train and test set
    train_x, test_x, train_y, test_y = train_test_split(
        data, labels, test_size=0.25)

    train_set = RayDMatrix(train_x, train_y)
    test_set = RayDMatrix(test_x, test_y)

    evals_result = {}
    bst = train(
        params=config,
        dtrain=train_set,
        evals=[(test_set, "eval")],
        evals_result=evals_result,
        max_actor_restarts=1,
        checkpoint_path="/tmp/checkpoint/",
        gpus_per_actor=0,
        cpus_per_actor=cpus_per_actor,
        num_actors=num_actors,
        verbose_eval=False,
        num_boost_round=10)

    model_path = "simple.xgb"
    bst.save_model(model_path)
    print("Final validation error: {:.4f}".format(
        evals_result["eval"]["error"][-1]))
def testCustomObjectiveFunction(self):
    """Ensure that custom objective functions work.

    Runs a custom objective function with pure XGBoost and
    XGBoost on Ray and compares the prediction outputs."""
    self._init_ray()

    params = self.params.copy()
    params.pop("objective", None)

    bst_xgb = xgb.train(
        params, xgb.DMatrix(self.x, self.y), obj=squared_log)

    bst_ray = train(
        params,
        RayDMatrix(self.x, self.y),
        ray_params=RayParams(num_actors=2),
        obj=squared_log,
        **self.kwargs)

    x_mat = xgb.DMatrix(self.x)
    pred_y_xgb = np.round(bst_xgb.predict(x_mat))
    pred_y_ray = np.round(bst_ray.predict(x_mat))

    self.assertSequenceEqual(list(pred_y_xgb), list(pred_y_ray))
    self.assertSequenceEqual(list(self.y), list(pred_y_ray))
def train_breast_cancer(config, ray_params):
    # Load dataset
    data, labels = datasets.load_breast_cancer(return_X_y=True)
    # Split into train and test set
    train_x, test_x, train_y, test_y = train_test_split(
        data, labels, test_size=0.25)

    train_set = RayDMatrix(train_x, train_y)
    test_set = RayDMatrix(test_x, test_y)

    evals_result = {}
    bst = train(
        params=config,
        dtrain=train_set,
        evals=[(test_set, "eval")],
        evals_result=evals_result,
        ray_params=ray_params,
        verbose_eval=False,
        num_boost_round=10)

    model_path = "tuned.xgb"
    bst.save_model(model_path)
    print("Final validation error: {:.4f}".format(
        evals_result["eval"]["error"][-1]))
def testTrainingContinuationKilled(self):
    """This should continue after one actor died."""
    additional_results = {}
    keep_actors = {}

    def keep(actors, *args, **kwargs):
        keep_actors["actors"] = actors.copy()
        return DEFAULT

    with patch("xgboost_ray.main._shutdown") as mocked:
        mocked.side_effect = keep
        bst = train(
            self.params,
            RayDMatrix(self.x, self.y),
            callbacks=[_kill_callback(self.die_lock_file)],
            num_boost_round=20,
            ray_params=RayParams(max_actor_restarts=1, num_actors=2),
            additional_results=additional_results)

    self.assertEqual(20, get_num_trees(bst))

    x_mat = xgb.DMatrix(self.x)
    pred_y = bst.predict(x_mat)
    self.assertSequenceEqual(list(self.y), list(pred_y))
    print(f"Got correct predictions: {pred_y}")

    actors = keep_actors["actors"]
    # End with two working actors
    self.assertTrue(actors[0])
    self.assertTrue(actors[1])

    # Two workers finished, so N=32
    self.assertEqual(additional_results["total_n"], 32)
def main():
    # Example adapted from this blog post:
    # https://medium.com/rapids-ai/a-new-official-dask-api-for-xgboost-e8b10f3d1eb7
    # This uses the HIGGS dataset. Download here:
    # https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz
    fname = "HIGGS.csv"
    colnames = ["label"] + ["feature-%02d" % i for i in range(1, 29)]

    dtrain = RayDMatrix(
        os.path.abspath(fname), label="label", names=colnames)

    config = {
        "tree_method": "hist",
        "eval_metric": ["logloss", "error"],
    }

    evals_result = {}

    start = time.time()
    bst = train(
        config,
        dtrain,
        evals_result=evals_result,
        max_actor_restarts=1,
        num_boost_round=100,
        evals=[(dtrain, "train")])
    taken = time.time() - start
    print(f"TRAIN TIME TAKEN: {taken:.2f} seconds")

    bst.save_model("higgs.xgb")
    print("Final training error: {:.4f}".format(
        evals_result["train"]["error"][-1]))
def main(fname, num_actors=2):
    dtrain = RayDMatrix(
        os.path.abspath(fname), label="labels", ignore=["partition"])

    config = {
        "tree_method": "hist",
        "eval_metric": ["logloss", "error"],
    }

    evals_result = {}

    start = time.time()
    bst = train(
        config,
        dtrain,
        evals_result=evals_result,
        ray_params=RayParams(
            max_actor_restarts=1, num_actors=num_actors),
        num_boost_round=10,
        evals=[(dtrain, "train")])
    taken = time.time() - start
    print(f"TRAIN TIME TAKEN: {taken:.2f} seconds")

    bst.save_model("test_data.xgb")
    print("Final training error: {:.4f}".format(
        evals_result["train"]["error"][-1]))
def train_xgboost(config, train_df, test_df, target_column, ray_params):
    train_set = RayDMatrix(train_df, target_column)
    test_set = RayDMatrix(test_df, target_column)

    evals_result = {}

    train_start_time = time.time()

    # Train the classifier
    bst = train(
        params=config,
        dtrain=train_set,
        evals=[(test_set, "eval")],
        evals_result=evals_result,
        verbose_eval=False,
        num_boost_round=100,
        ray_params=ray_params,
    )

    train_end_time = time.time()
    train_duration = train_end_time - train_start_time
    print(f"Total time taken: {train_duration} seconds.")

    model_path = "model.xgb"
    bst.save_model(model_path)
    print("Final validation error: {:.4f}".format(
        evals_result["eval"]["error"][-1]))

    return bst, evals_result
def testTrainingContinuationElasticMultiKilled(self):
    """This should still show 20 boost rounds after two failures."""
    logging.getLogger().setLevel(10)

    additional_results = {}

    bst = train(
        self.params,
        RayDMatrix(self.x, self.y),
        callbacks=[
            _kill_callback(
                self.die_lock_file, fail_iteration=6, actor_rank=0),
            _kill_callback(
                self.die_lock_file_2, fail_iteration=14, actor_rank=1),
        ],
        num_boost_round=20,
        ray_params=RayParams(
            max_actor_restarts=2,
            num_actors=2,
            elastic_training=True,
            max_failed_actors=2),
        additional_results=additional_results)

    self.assertEqual(20, get_num_trees(bst))

    x_mat = xgb.DMatrix(self.x)
    pred_y = bst.predict(x_mat)
    self.assertSequenceEqual(list(self.y), list(pred_y))
    print(f"Got correct predictions: {pred_y}")
def testTrainPredict(self, init=True, remote=None, **ray_param_dict):
    """Train with evaluation and predict"""
    if init:
        ray.init(num_cpus=2, num_gpus=0)

    dtrain = RayDMatrix(self.x, self.y)

    evals_result = {}
    bst = train(
        self.params,
        dtrain,
        num_boost_round=38,
        ray_params=RayParams(num_actors=2, **ray_param_dict),
        evals=[(dtrain, "dtrain")],
        evals_result=evals_result,
        _remote=remote)

    self.assertEqual(get_num_trees(bst), 38)
    self.assertTrue("dtrain" in evals_result)

    x_mat = RayDMatrix(self.x)
    pred_y = predict(
        bst,
        x_mat,
        ray_params=RayParams(num_actors=2, **ray_param_dict),
        _remote=remote)
    self.assertSequenceEqual(list(self.y), list(pred_y))
def main():
    # Load dataset
    data, labels = datasets.load_breast_cancer(return_X_y=True)
    # Split into train and test set
    train_x, test_x, train_y, test_y = train_test_split(
        data, labels, test_size=0.25)

    train_set = RayDMatrix(train_x, train_y)
    test_set = RayDMatrix(test_x, test_y)

    # Set config
    config = {
        "tree_method": "approx",
        "objective": "binary:logistic",
        "eval_metric": ["logloss", "error"],
        "max_depth": 3,
    }

    evals_result = {}

    # Train the classifier
    bst = train(
        config,
        train_set,
        evals=[(test_set, "eval")],
        evals_result=evals_result,
        max_actor_restarts=1,
        verbose_eval=False)

    bst.save_model("simple.xgb")
    print("Final validation error: {:.4f}".format(
        evals_result["eval"]["error"][-1]))
def main():
    # Run `create_test_data.py` first to create fake data.
    fname = "parted.parquet"
    dtrain = RayDMatrix(
        os.path.abspath(fname), label="labels", ignore=["partition"])

    config = {
        "tree_method": "hist",
        "eval_metric": ["logloss", "error"],
    }

    evals_result = {}

    start = time.time()
    bst = train(
        config,
        dtrain,
        evals_result=evals_result,
        max_actor_restarts=1,
        num_boost_round=100,
        evals=[(dtrain, "train")])
    taken = time.time() - start
    print(f"TRAIN TIME TAKEN: {taken:.2f} seconds")

    bst.save_model("test_data.xgb")
    print("Final training error: {:.4f}".format(
        evals_result["train"]["error"][-1]))
def testKwargsValidation(self):
    x = np.random.uniform(0, 1, size=(100, 4))
    y = np.random.randint(0, 1, size=100)

    train_set = RayDMatrix(x, y)

    with self.assertRaisesRegex(TypeError, "totally_invalid_kwarg"):
        train(
            {
                "objective": "multi:softmax",
                "num_class": 2,
                "eval_metric": ["logloss", "error"]
            },
            train_set,
            evals=[(train_set, "train")],
            ray_params=RayParams(num_actors=1, max_actor_restarts=0),
            totally_invalid_kwarg="")
def train_ray(num_workers, num_boost_rounds, num_files=0, use_gpu=False):
    path = "/data/parted.parquet"

    if num_files:
        files = sorted(glob.glob(f"{path}/**/*.parquet"))
        while num_files > len(files):
            files = files + files
        path = files[0:num_files]

    use_device_matrix = False
    if use_gpu:
        try:
            import cupy  # noqa: F401
            use_device_matrix = True
        except ImportError:
            use_device_matrix = False

    if use_device_matrix:
        dtrain = RayDeviceQuantileDMatrix(
            path,
            num_actors=num_workers,
            label="labels",
            ignore=["partition"],
            filetype=RayFileType.PARQUET)
    else:
        dtrain = RayDMatrix(
            path,
            num_actors=num_workers,
            label="labels",
            ignore=["partition"],
            filetype=RayFileType.PARQUET)

    config = {
        "tree_method": "hist" if not use_gpu else "gpu_hist",
        "eval_metric": ["logloss", "error"],
    }

    start = time.time()
    evals_result = {}
    bst = train(
        config,
        dtrain,
        evals_result=evals_result,
        max_actor_restarts=2,
        num_boost_round=num_boost_rounds,
        num_actors=num_workers,
        cpus_per_actor=4,
        checkpoint_path="/tmp/checkpoint/",
        gpus_per_actor=0 if not use_gpu else 1,
        resources_per_actor={
            "actor_cpus": 4,
            "actor_gpus": 0 if not use_gpu else 1
        },
        evals=[(dtrain, "train")])
    taken = time.time() - start
    print(f"TRAIN TIME TAKEN: {taken:.2f} seconds")

    bst.save_model("benchmark_{}.xgb".format("cpu" if not use_gpu else "gpu"))
    print("Final training error: {:.4f}".format(
        evals_result["train"]["error"][-1]))
    return taken
def testSameResultWithAndWithoutError(self):
    """Get the same model with and without errors during training."""
    # Run training
    bst_noerror = train(
        self.params,
        RayDMatrix(self.x, self.y),
        num_boost_round=10,
        ray_params=RayParams(max_actor_restarts=0, num_actors=2))

    bst_2part_1 = train(
        self.params,
        RayDMatrix(self.x, self.y),
        num_boost_round=5,
        ray_params=RayParams(max_actor_restarts=0, num_actors=2))

    bst_2part_2 = train(
        self.params,
        RayDMatrix(self.x, self.y),
        num_boost_round=5,
        ray_params=RayParams(max_actor_restarts=0, num_actors=2),
        xgb_model=bst_2part_1)

    res_error = {}
    bst_error = train(
        self.params,
        RayDMatrix(self.x, self.y),
        callbacks=[_fail_callback(self.die_lock_file, fail_iteration=7)],
        num_boost_round=10,
        ray_params=RayParams(
            max_actor_restarts=1, num_actors=2, checkpoint_frequency=5),
        additional_results=res_error)

    flat_noerror = flatten_obj({"tree": tree_obj(bst_noerror)})
    flat_error = flatten_obj({"tree": tree_obj(bst_error)})
    flat_2part = flatten_obj({"tree": tree_obj(bst_2part_2)})

    for key in flat_noerror:
        self.assertAlmostEqual(flat_noerror[key], flat_error[key])
        self.assertAlmostEqual(flat_noerror[key], flat_2part[key])

    # We fail at iteration 7, but checkpoints are saved at iteration 5
    # Thus we have two additional returns here.
    print("Callback returns:", res_error["callback_returns"][0])
    self.assertEqual(len(res_error["callback_returns"][0]), 10 + 2)
def testFaultToleranceManager(self):
    ft_manager = FaultToleranceManager.remote()

    ft_manager.schedule_kill.remote(rank=1, boost_round=16)
    ft_manager.delay_return.remote(
        rank=1, start_boost_round=14, end_boost_round=68)

    delay_callback = DelayedLoadingCallback(
        ft_manager, reload_data=True, sleep_time=0.1)
    die_callback = DieCallback(ft_manager, training_delay=0.25)

    res_1 = {}
    train(
        self.params,
        RayDMatrix(self.x, self.y),
        callbacks=[die_callback],
        num_boost_round=100,
        ray_params=RayParams(
            num_actors=2,
            checkpoint_frequency=1,
            elastic_training=True,
            max_failed_actors=1,
            max_actor_restarts=1,
            distributed_callbacks=[delay_callback]),
        additional_results=res_1)

    logs = ray.get(ft_manager.get_logs.remote())
    print(logs)

    self.assertSequenceEqual([g for g, _ in logs[0][0:99]], range(99))

    # Which steps exactly are executed is stochastic. The rank 1 actor
    # will die at iteration 16, so at least 15 will be logged (though 16
    # might be logged as well). Iterations 17 to 67 should never be
    # logged. It comes back some time after iteration 68, so this might
    # be iter 68, 69, or later. We just make sure it comes back at all
    # (iter 70+).
    global_steps = [g for g, _ in logs[1]]
    self.assertTrue(global_steps)
    self.assertIn(15, global_steps)
    self.assertNotIn(17, global_steps)
    self.assertNotIn(67, global_steps)
    self.assertIn(70, global_steps)