def test_resize(ray_start_2_cpus):  # noqa: F811
    """Checks that the trainer keeps training with fewer workers after a
    worker dies and its CPU is taken by another task."""
    if not dist.is_available():
        return

    def single_loader(config):
        return LinearDataset(2, 5, size=1000000)

    def step_with_fail(self, *args, **kwargs):
        # Kick off an epoch on every worker, then kill worker 0 on the
        # first attempt so the in-flight batch fails and forces recovery.
        worker_stats = [
            w.train_epoch.remote(*args, **kwargs) for w in self.workers
        ]
        if self._num_failures < 1:
            time.sleep(1)  # Give the batch time to be in flight before the kill.
            self.workers[0].__ray_kill__()
        success = check_for_failure(worker_stats)
        return success, worker_stats

    with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
        trainer1 = TorchTrainer(
            model_creator,
            single_loader,
            optimizer_creator,
            batch_size=100000,
            loss_creator=lambda config: nn.MSELoss(),
            num_replicas=2)

        # Occupy one CPU so the trainer cannot re-acquire the lost worker
        # and must resize down to a single replica.
        @ray.remote
        def try_test():
            import time
            time.sleep(100)

        try_test.remote()
        trainer1.train(max_retries=1)
        assert len(trainer1.workers) == 1
def test_metrics_nan(ray_start_2_cpus, num_workers):
    """NaN scores should propagate into the mean but not the last value."""
    data_size, val_size = 100, 100
    batch_size = 10
    num_train_steps = int(data_size / batch_size)
    num_val_steps = int(val_size / batch_size)

    # First score is NaN; all subsequent scores are zero.
    train_scores = [np.nan] + ([0] * num_train_steps)
    val_scores = [np.nan] + ([0] * num_val_steps)
    trainer = TorchTrainer(
        model_creator=model_creator,
        data_creator=data_creator,
        optimizer_creator=optimizer_creator,
        loss_creator=lambda config: nn.MSELoss(),
        num_workers=num_workers,
        config={
            "scores": train_scores,
            "val_scores": val_scores,
            "key": "score",
            "batch_size": batch_size,
            "data_size": data_size,
            "val_size": val_size
        },
        training_operator_cls=_TestMetricsOperator)

    stats = trainer.train(num_steps=num_train_steps)
    assert "mean_score" in stats
    assert stats["last_score"] == 0
    assert np.isnan(stats["mean_score"])

    stats = trainer.validate()
    assert "mean_score" in stats
    assert stats["last_score"] == 0
    assert np.isnan(stats["mean_score"])
def test_split_batch(ray_start_2_cpus, use_local):
    """The global batch size should be split evenly across two workers."""
    if not dist.is_available():
        return

    def data_creator(config):
        """Returns training dataloader, validation dataloader."""
        train_dataset = LinearDataset(2, 5, size=config["data_size"])
        return DataLoader(
            train_dataset,
            batch_size=config[BATCH_SIZE],
        )

    data_size = 600
    # 21 is odd on purpose: the trainer must round it down to 20 so that
    # each of the 2 workers gets an equal share of 10.
    batch_size = 21
    test_operator = TrainingOperator.from_creators(
        model_creator,
        optimizer_creator,
        data_creator,
        loss_creator=lambda config: nn.MSELoss())
    trainer = TorchTrainer(
        training_operator_cls=test_operator,
        num_workers=2,
        use_local=use_local,
        config={
            BATCH_SIZE: batch_size,
            "data_size": data_size,
        })
    stats = trainer.train()
    assert trainer.config[BATCH_SIZE] == (batch_size - 1)
    assert stats[NUM_SAMPLES] == 600
    assert stats[BATCH_COUNT] == (data_size // 20)
    trainer.shutdown()
def test_split_batch(ray_start_2_cpus):
    """The global batch size should be split evenly across two workers."""
    if not dist.is_available():
        return

    def data_creator(config):
        """Returns training dataloader, validation dataloader."""
        train_dataset = LinearDataset(2, 5, size=config["data_size"])
        return torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config[BATCH_SIZE],
        )

    data_size = 600
    # Odd batch size forces the trainer to round down to an even split.
    batch_size = 21
    trainer = TorchTrainer(
        model_creator=model_creator,
        data_creator=data_creator,
        optimizer_creator=optimizer_creator,
        loss_creator=lambda config: nn.MSELoss(),
        num_workers=2,
        config={
            BATCH_SIZE: batch_size,
            "data_size": data_size,
        })
    stats = trainer.train()
    assert trainer.config[BATCH_SIZE] == (batch_size - 1)
    assert stats[NUM_SAMPLES] == 600
    assert stats[BATCH_COUNT] == (data_size // 20)
def test_timeout(ray_4_node_1_cpu):
    """Tests that an error is thrown when placement group setup times out."""
    # 7 workers cannot fit on 4 single-CPU nodes, so placement group
    # creation is expected to time out inside the constructor.
    with pytest.raises(TimeoutError):
        trainer = TorchTrainer(
            training_operator_cls=Operator,
            num_workers=7,
            use_gpu=False)
        trainer.shutdown()
def test_shutdown(ray_8_node_2_cpu):
    """Tests if placement group is removed when worker group is shut down."""
    # Baseline: all 16 CPUs free and no placement groups registered.
    assert ray.available_resources()["CPU"] == 16
    placement_group_table = ray.state.state.placement_group_table()
    assert len(placement_group_table) == 0

    trainer = TorchTrainer(
        training_operator_cls=Operator,
        num_workers=7,
        use_gpu=False,
    )

    # 7 workers reserve 7 CPUs through a single SPREAD placement group.
    assert ray.available_resources()["CPU"] == 9
    placement_group_table = ray.state.state.placement_group_table()
    assert len(placement_group_table) == 1
    placement_group_id = list(placement_group_table)[0]
    placement_group = placement_group_table[placement_group_id]
    assert placement_group["strategy"] == "SPREAD"
    assert placement_group["state"] == "CREATED"

    trainer.shutdown()

    # After shutdown the CPUs are released; the table still records the
    # group, but in the REMOVED state.
    assert ray.available_resources()["CPU"] == 16
    placement_group_table = ray.state.state.placement_group_table()
    assert len(placement_group_table) == 1
    placement_group = placement_group_table[placement_group_id]
    assert placement_group["strategy"] == "SPREAD"
    assert placement_group["state"] == "REMOVED"
def test_fail_twice(ray_start_2_cpus):  # noqa: F811
    """Training should survive two consecutive worker failures with
    max_retries=2."""
    if not dist.is_available():
        return

    def single_loader(config):
        dataset = LinearDataset(2, 5, size=1000000)
        return torch.utils.data.DataLoader(
            dataset, batch_size=config.get("batch_size", 32))

    def step_with_fail(self, *args, **kwargs):
        # Kill worker 0 on each of the first two epochs to exercise the
        # retry path twice in a row.
        worker_stats = [
            w.train_epoch.remote(*args, **kwargs) for w in self.workers
        ]
        if self._num_failures < 2:
            time.sleep(1)
            self.workers[0].__ray_kill__()
        success = check_for_failure(worker_stats)
        return success, worker_stats

    with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
        trainer1 = TorchTrainer(
            model_creator=model_creator,
            data_creator=single_loader,
            optimizer_creator=optimizer_creator,
            config={"batch_size": 100000},
            loss_creator=lambda config: nn.MSELoss(),
            num_workers=2)
        trainer1.train(max_retries=2)
def test_single_step(ray_start_2_cpus):  # noqa: F811
    """A single train/validate step should process exactly one batch."""
    trainer = TorchTrainer(
        model_creator=model_creator,
        data_creator=data_creator,
        optimizer_creator=optimizer_creator,
        loss_creator=lambda config: nn.MSELoss(),
        num_workers=1)
    metrics = trainer.train(num_steps=1)
    assert metrics[BATCH_COUNT] == 1

    val_metrics = trainer.validate(num_steps=1)
    assert val_metrics[BATCH_COUNT] == 1
def test_apply_all_workers(ray_start_2_cpus, num_workers, use_local):
    """apply_all_workers should run the function on every worker and
    collect one result per worker."""

    def fn():
        return 1

    trainer = TorchTrainer(
        training_operator_cls=Operator,
        num_workers=num_workers,
        use_local=use_local,
        use_gpu=False)

    results = trainer.apply_all_workers(fn)
    assert all(x == 1 for x in results)
def test_multi_input_model(ray_start_2_cpus, use_local):
    """A model whose forward() takes two tensors should train normally
    when each dataset item unpacks into (x, y, target)."""

    def model_creator(config):
        class MultiInputModel(nn.Module):
            def __init__(self):
                super(MultiInputModel, self).__init__()
                self._fc1 = torch.nn.Linear(1, 1)
                self._fc2 = torch.nn.Linear(1, 1)

            def forward(self, x, y):
                return self._fc1(x) + self._fc2(y)

        return MultiInputModel()

    def data_creator(config):
        class LinearDataset(torch.utils.data.Dataset):
            def __init__(self, a, b, size=1000):
                x = np.random.randn(size)
                y = np.random.randn(size)
                self.x = torch.tensor(x, dtype=torch.float32)
                self.y = torch.tensor(y, dtype=torch.float32)
                self.z = torch.tensor(a * (x + y) + 2 * b, dtype=torch.float32)

            def __getitem__(self, index):
                # Two inputs plus the regression target.
                return (self.x[index, None], self.y[index, None],
                        self.z[index, None])

            def __len__(self):
                return len(self.x)

        train_dataset = LinearDataset(3, 4)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config.get("batch_size", 32),
        )
        return train_loader, None

    Operator = TrainingOperator.from_creators(
        model_creator,
        optimizer_creator,
        data_creator,
        loss_creator=lambda config: nn.MSELoss())

    trainer = TorchTrainer(
        training_operator_cls=Operator,
        num_workers=1,
        use_local=use_local)

    metrics = trainer.train(num_steps=1)
    assert metrics[BATCH_COUNT] == 1

    trainer.shutdown()
def test_scheduler_freq(ray_start_2_cpus, scheduler_freq):  # noqa: F811
    """scheduler_step_freq must be forwarded to the training loop, and a
    missing value must be rejected with ValueError."""

    def train_epoch(self, iterator, info):
        # The worker-side loop should see exactly the frequency we passed.
        assert info[SCHEDULER_STEP] == scheduler_freq
        return {"done": 1}

    def scheduler_creator(optimizer, config):
        return torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=30, gamma=0.1)

    if scheduler_freq is None:
        with pytest.raises(ValueError):
            trainer = TorchTrainer(
                model_creator=model_creator,
                data_creator=data_creator,
                optimizer_creator=optimizer_creator,
                loss_creator=lambda config: nn.MSELoss(),
                scheduler_creator=scheduler_creator,
                scheduler_step_freq=scheduler_freq)
    else:
        trainer = TorchTrainer(
            model_creator=model_creator,
            data_creator=data_creator,
            optimizer_creator=optimizer_creator,
            loss_creator=lambda config: nn.MSELoss(),
            config={"custom_func": train_epoch},
            training_operator_cls=_TestingOperator,
            scheduler_creator=scheduler_creator,
            scheduler_step_freq=scheduler_freq)

        for i in range(3):
            trainer.train()
        trainer.shutdown()
def train_example(num_workers=1,
                  num_epochs=5,
                  use_gpu=False,
                  use_fp16=False,
                  test_mode=False):
    """Train ResNet18 on CIFAR with a tqdm progress bar, reporting
    validation accuracy after every epoch."""
    trainer1 = TorchTrainer(
        model_creator=ResNet18,
        data_creator=cifar_creator,
        optimizer_creator=optimizer_creator,
        loss_creator=nn.CrossEntropyLoss,
        scheduler_creator=scheduler_creator,
        initialization_hook=initialization_hook,
        num_workers=num_workers,
        config={
            "lr": 0.1,
            "test_mode": test_mode,  # user-defined param to subset the data
            BATCH_SIZE: 128 * num_workers  # this will be split across workers.
        },
        use_gpu=use_gpu,
        scheduler_step_freq="epoch",
        use_fp16=use_fp16,
        use_tqdm=True)
    pbar = trange(num_epochs, unit="epoch")
    for i in pbar:
        info = {"num_steps": 1} if test_mode else {}
        info["epoch_idx"] = i
        info["num_epochs"] = num_epochs
        # Increase `max_retries` to turn on fault tolerance.
        trainer1.train(max_retries=1, info=info)
        val_stats = trainer1.validate()
        pbar.set_postfix(dict(acc=val_stats["val_accuracy"]))

    print(trainer1.validate())
    trainer1.shutdown()
    print("success!")
def test_iterable_model(ray_start_2_cpus, num_workers, use_local):  # noqa: F811
    """An optimizer that is itself iterable must not confuse the trainer's
    registration logic; loss should still decrease over epochs."""

    class IterableOptimizer(torch.optim.SGD):
        def __iter__(self):
            return self.param_groups

    class Operator(TrainingOperator):
        def setup(self, config):
            model = nn.Sequential(nn.Linear(1, config.get("hidden_size", 1)))
            optimizer = IterableOptimizer(
                model.parameters(), lr=config.get("lr", 1e-2))
            criterion = nn.MSELoss()

            self.model, self.optimizer, self.criterion = self.register(
                models=model, optimizers=optimizer, criterion=criterion)
            train_ld, val_ld = data_creator(config)
            self.register_data(train_loader=train_ld, validation_loader=val_ld)

    trainer = TorchTrainer(
        training_operator_cls=Operator,
        num_workers=num_workers,
        use_local=use_local,
        use_gpu=False)

    # Train two rounds of three epochs each and check that both the train
    # and validation losses are non-increasing between rounds.
    for i in range(3):
        train_loss1 = trainer.train()["train_loss"]
    validation_loss1 = trainer.validate()["val_loss"]

    for i in range(3):
        train_loss2 = trainer.train()["train_loss"]
    validation_loss2 = trainer.validate()["val_loss"]

    assert train_loss2 <= train_loss1, (train_loss2, train_loss1)
    assert validation_loss2 <= validation_loss1, (validation_loss2,
                                                  validation_loss1)
    trainer.shutdown()
def train_cifar(test_mode=False,
                num_workers=1,
                use_gpu=False,
                num_epochs=5,
                fp16=False):
    """Run CIFAR training via a TrainingOperator, printing final
    validation stats."""
    trainer1 = TorchTrainer(
        training_operator_cls=CifarTrainingOperator,
        initialization_hook=initialization_hook,
        num_workers=num_workers,
        config={
            "lr": 0.1,
            "test_mode": test_mode,  # subset the data
            # this will be split across workers.
            BATCH_SIZE: 128 * num_workers,
        },
        use_gpu=use_gpu,
        scheduler_step_freq="epoch",
        use_fp16=fp16,
        use_tqdm=False,
    )
    pbar = trange(num_epochs, unit="epoch")
    for i in pbar:
        info = {"num_steps": 1} if test_mode else {}
        info["epoch_idx"] = i
        info["num_epochs"] = num_epochs
        # Increase `max_retries` to turn on fault tolerance.
        trainer1.train(max_retries=1, info=info)
        val_stats = trainer1.validate()
        pbar.set_postfix(dict(acc=val_stats["val_accuracy"]))

    print(trainer1.validate())
    trainer1.shutdown()
    print("success!")
def test_dead_trainer(ray_start_2_cpus):  # noqa: F811
    """Calling train() after shutdown() must raise RuntimeError."""
    TestOperator = get_test_operator(Operator)
    trainer = TorchTrainer(
        training_operator_cls=TestOperator,
        num_workers=2)
    trainer.train(num_steps=1)
    trainer.shutdown()
    with pytest.raises(RuntimeError):
        trainer.train()
def test_tune_train(ray_start_4_cpus, num_workers, use_local):  # noqa: F811
    """Runs the trainer as a Tune Trainable and checks that losses
    decrease across training iterations for every trial."""
    TorchTrainable = TorchTrainer.as_trainable(
        **{
            "training_operator_cls": Operator,
            "num_workers": num_workers,
            "use_gpu": False,
            "backend": "gloo",
            "use_local": use_local,
            "config": {
                "batch_size": 512,
                "lr": 0.001
            }
        })

    analysis = tune.run(
        TorchTrainable,
        num_samples=2,
        stop={"training_iteration": 2},
        verbose=1)

    # checks loss decreasing for every trials
    for path, df in analysis.trial_dataframes.items():
        mean_train_loss1 = df.loc[0, "train_loss"]
        mean_train_loss2 = df.loc[1, "train_loss"]
        mean_val_loss1 = df.loc[0, "val_loss"]
        mean_val_loss2 = df.loc[1, "val_loss"]

        assert mean_train_loss2 <= mean_train_loss1
        assert mean_val_loss2 <= mean_val_loss1
def test_tune_train_pack(ray_4_node_8_cpu, num_workers):
    """Tests if workers are colocated when running Tune."""

    def custom_train_func(trainer, info):
        train_stats = trainer.train(profile=True)
        val_stats = trainer.validate(profile=True)
        stats = merge_dicts(train_stats, val_stats)

        # One Trainable actor plus one actor per worker.
        actors = ray.state.actors().values()
        assert len(actors) == num_workers + 1

        # With PACK placement, workers fill one 8-CPU node before
        # spilling onto the next.
        node_id_set = set()
        for actor_info in actors:
            node_id = actor_info["Address"]["NodeID"]
            node_id_set.add(node_id)
        assert len(node_id_set) == 1 + num_workers // 8

        return stats

    TorchTrainable = TorchTrainer.as_trainable(
        override_tune_step=custom_train_func,
        **{
            "training_operator_cls": Operator,
            "num_workers": num_workers,
            "use_gpu": False,
            "backend": "gloo",
            "config": {
                "batch_size": 512,
                "lr": 0.001
            }
        })

    tune.run(
        TorchTrainable,
        num_samples=1,
        stop={"training_iteration": 2},
        verbose=1)
def test_tune_custom_train(ray_start_4_cpus, num_workers, use_local):  # noqa: F811
    """A custom override_tune_step function should drive Tune training and
    still produce decreasing losses for every trial."""

    def custom_train_func(trainer, info):
        train_stats = trainer.train(profile=True)
        val_stats = trainer.validate(profile=True)
        stats = merge_dicts(train_stats, val_stats)
        return stats

    TorchTrainable = TorchTrainer.as_trainable(
        **{
            "override_tune_step": custom_train_func,
            "training_operator_cls": Operator,
            "num_workers": num_workers,
            "use_gpu": False,
            "backend": "gloo",
            "use_local": use_local,
            "config": {"batch_size": 512, "lr": 0.001},
        }
    )

    analysis = tune.run(
        TorchTrainable, num_samples=2, stop={"training_iteration": 2}, verbose=1
    )

    # checks loss decreasing for every trials
    for path, df in analysis.trial_dataframes.items():
        mean_train_loss1 = df.loc[0, "train_loss"]
        mean_train_loss2 = df.loc[1, "train_loss"]
        mean_val_loss1 = df.loc[0, "val_loss"]
        mean_val_loss2 = df.loc[1, "val_loss"]

        assert mean_train_loss2 <= mean_train_loss1
        assert mean_val_loss2 <= mean_val_loss1
def tune_example_manual(operator_cls, num_workers=1, use_gpu=False):
    """Run a small Tune sweep with a manually-stepped LR scheduler and
    return the best config by validation loss."""

    def step(trainer, info: dict):
        """Define a custom training loop for tune.
        This is needed because we want to manually update our scheduler.
        """
        train_stats = trainer.train(profile=True)
        validation_stats = trainer.validate(profile=True)
        # Manually update our scheduler with the given metric.
        trainer.update_scheduler(metric=validation_stats["val_loss"])
        all_stats = merge_dicts(train_stats, validation_stats)
        return all_stats

    TorchTrainable = TorchTrainer.as_trainable(
        override_tune_step=step,
        training_operator_cls=operator_cls,
        num_workers=num_workers,
        use_gpu=use_gpu,
        scheduler_step_freq="manual",
        config={BATCH_SIZE: 128}
    )

    analysis = tune.run(
        TorchTrainable,
        num_samples=3,
        config={"lr": tune.grid_search([1e-4, 1e-3])},
        stop={"training_iteration": 2},
        verbose=1)

    return analysis.get_best_config(metric="val_loss", mode="min")
def test_dead_trainer(ray_start_2_cpus):  # noqa: F811
    """Calling train() after shutdown() must raise RuntimeError."""
    trainer = TorchTrainer(
        model_creator=model_creator,
        data_creator=data_creator,
        optimizer_creator=optimizer_creator,
        loss_creator=lambda config: nn.MSELoss(),
        num_workers=2)
    trainer.train(num_steps=1)
    trainer.shutdown()
    with pytest.raises(RuntimeError):
        trainer.train()
def test_dataset(ray_start_4_cpus, use_local):
    """
    This test tries training the mlp_identity example. We check the accuracy of
    the model as an all inclusive way of ensuring that we are properly sharding
    and iterating over the entire dataset (instead of repeating the first set
    of points for example).
    """
    model_creator = mlp_identity.model_creator
    optimizer_creator = mlp_identity.optimizer_creator
    dataset_creator = mlp_identity.dataset_creator

    DatasetOperator = TrainingOperator.from_creators(
        model_creator=model_creator,
        optimizer_creator=optimizer_creator,
        loss_creator=nn.MSELoss)

    trainer = TorchTrainer(
        training_operator_cls=DatasetOperator,
        use_local=use_local,
        num_workers=2,
    )

    dataset = dataset_creator()
    for i in range(5):
        trainer.train(dataset=dataset, num_steps=100)

    # The identity model should map 0.5 close to itself after training.
    x = mlp_identity.to_mat(0.5)
    prediction = float(trainer.get_model()(x)[0][0])
    assert 0.4 <= prediction <= 0.6
    trainer.shutdown()
def test_train_spread(ray_8_node_2_cpu):
    """Tests if workers are spread across nodes."""
    assert ray.available_resources()["CPU"] == 16
    trainer = TorchTrainer(
        training_operator_cls=Operator,
        num_workers=7,
        use_gpu=False,
    )
    assert ray.available_resources()["CPU"] == 9

    # With SPREAD placement, 7 workers should land on 7 distinct nodes.
    node_id_set = set()
    for actor_info in ray.state.actors().values():
        node_id = actor_info["Address"]["NodeID"]
        node_id_set.add(node_id)
    assert len(node_id_set) == 7

    trainer.shutdown()
    assert ray.available_resources()["CPU"] == 16
def test_non_serialized_data(ray_start_2_cpus):  # noqa: F811
    """With serialize_data_creation=False, the workers should run the slow
    data creator concurrently instead of one after another."""
    duration = 10

    def slow_data(func):
        def slowed_func(*args, **kwargs):
            time.sleep(duration)
            return func(*args, **kwargs)

        return slowed_func

    start = time.time()
    trainer = TorchTrainer(
        model_creator=model_creator,
        data_creator=slow_data(data_creator),
        optimizer_creator=optimizer_creator,
        serialize_data_creation=False,
        loss_creator=lambda config: nn.MSELoss(),
        num_workers=2)
    elapsed = time.time() - start
    # Two serialized creations would take at least 2 * duration.
    assert elapsed < duration * 2

    trainer.shutdown()
def test_resize(ray_start_2_cpus, use_local):  # noqa: F811
    """After losing a worker with its CPU held elsewhere, the trainer
    should shrink to one worker, then scale back up once the CPU frees."""
    if not dist.is_available():
        return

    def single_loader(config):
        dataset = LinearDataset(2, 5, size=1000000)
        return DataLoader(dataset, batch_size=config.get("batch_size", 32))

    start_with_fail = gen_start_with_fail(1)
    TestOperator = TrainingOperator.from_creators(
        model_creator,
        optimizer_creator,
        single_loader,
        loss_creator=lambda config: nn.MSELoss(),
    )
    with patch.object(TorchTrainer, "_start_workers", start_with_fail):
        trainer1 = TorchTrainer(
            training_operator_cls=TestOperator,
            config={"batch_size": 100000},
            use_local=use_local,
            num_workers=2,
        )

        # we use placement_group to occupy resources
        bundle = {
            "CPU": 1,
        }
        bundles = [bundle]
        dummy_pg = ray.util.placement_group(bundles, strategy="SPREAD")

        trainer1.train(max_retries=1)
        assert trainer1.worker_group.num_workers == 1
        assert trainer1._num_failures == 1

        ray.util.remove_placement_group(dummy_pg)

        def is_placement_group_removed():
            table = ray.util.placement_group_table(dummy_pg)
            if "state" not in table:
                return False
            return table["state"] == "REMOVED"

        # wait for free resource
        wait_for_condition(is_placement_group_removed)

        # trigger scale up
        trainer1.train()
        assert trainer1.worker_group.num_workers == 2

    trainer1.shutdown(force=True)
def test_reduce_result(ray_start_2_cpus):
    """With reduce_results=False, train() should return one unreduced
    stats dict per worker, each covering the full dataset."""
    if not dist.is_available():
        return

    def data_creator(config):
        """Returns training dataloader, validation dataloader."""
        train_dataset = LinearDataset(2, 5, size=config["data_size"])
        return torch.utils.data.DataLoader(train_dataset, batch_size=1)

    data_size = 600

    trainer = TorchTrainer(
        model_creator=model_creator,
        data_creator=data_creator,
        optimizer_creator=optimizer_creator,
        loss_creator=lambda config: nn.MSELoss(),
        num_workers=2,
        config={"data_size": data_size})
    list_stats = trainer.train(reduce_results=False, profile=True)
    assert len(list_stats) == 2
    # BUG FIX: `assert [expr for ...]` always passes because a non-empty
    # list is truthy; use all() so each worker's stats are checked.
    assert all(stats[NUM_SAMPLES] == data_size for stats in list_stats)
    assert all(stats[BATCH_COUNT] == (data_size // 2) for stats in list_stats)
def test_fail_with_recover(ray_start_2_cpus, use_local):  # noqa: F811
    """When startup keeps failing more times than max_retries allows,
    train() must surface a RuntimeError."""
    print(locals())
    if not dist.is_available():
        return

    def single_loader(config):
        dataset = LinearDataset(2, 5, size=1000000)
        return DataLoader(dataset, batch_size=config.get("batch_size", 32))

    TestOperator = TrainingOperator.from_creators(
        model_creator,
        optimizer_creator,
        single_loader,
        loss_creator=lambda config: nn.MSELoss())

    # Three startup failures exceed the single retry below.
    start_with_fail = gen_start_with_fail(3)
    with patch.object(TorchTrainer, "_start_workers", start_with_fail):
        trainer1 = TorchTrainer(
            training_operator_cls=TestOperator,
            config={"batch_size": 100000},
            timeout_s=5,
            use_local=use_local,
            num_workers=2)

        with pytest.raises(RuntimeError):
            trainer1.train(max_retries=1)

        trainer1.shutdown(force=True)
def test_fail_twice(ray_start_2_cpus, use_local):  # noqa: F811
    """Two startup failures should be absorbed by the default retry
    policy, leaving a healthy two-worker group."""
    if not dist.is_available():
        return

    def single_loader(config):
        dataset = LinearDataset(2, 5, size=1000000)
        return DataLoader(dataset, batch_size=config.get("batch_size", 32))

    TestOperator = TrainingOperator.from_creators(
        model_creator,
        optimizer_creator,
        single_loader,
        loss_creator=lambda config: nn.MSELoss())

    start_with_fail = gen_start_with_fail(2)
    with patch.object(TorchTrainer, "_start_workers", start_with_fail):
        trainer1 = TorchTrainer(
            training_operator_cls=TestOperator,
            config={"batch_size": 100000},
            use_local=use_local,
            num_workers=2)

        # MAX RETRIES SHOULD BE ON BY DEFAULT
        trainer1.train()
        assert trainer1._num_failures == 2
        assert trainer1.worker_group.num_workers == 2

    trainer1.shutdown(force=True)
def train_example(num_workers=1,
                  num_epochs=5,
                  use_gpu=False,
                  use_fp16=False,
                  test_mode=False):
    """Train ResNet18 on CIFAR using NCCL on GPU (gloo on CPU), printing
    the mean train loss per epoch and final validation stats."""
    trainer1 = TorchTrainer(
        model_creator=ResNet18,
        data_creator=cifar_creator,
        optimizer_creator=optimizer_creator,
        loss_creator=nn.CrossEntropyLoss,
        scheduler_creator=scheduler_creator,
        initialization_hook=initialization_hook,
        num_workers=num_workers,
        config={
            "lr": 0.01,
            "test_mode": test_mode,
            BATCH_SIZE: 128,
        },
        use_gpu=use_gpu,
        backend="nccl" if use_gpu else "gloo",
        scheduler_step_freq="epoch",
        use_fp16=use_fp16,
        use_tqdm=True)
    pbar = trange(num_epochs, unit="epoch")
    for i in pbar:
        info = {"num_steps": 1} if test_mode else {}
        info["epoch_idx"] = i
        info["num_epochs"] = num_epochs
        # Increase `max_retries` to turn on fault tolerance.
        stats = trainer1.train(max_retries=1, info=info)
        pbar.set_postfix(dict(loss=stats["mean_train_loss"]))

    print(trainer1.validate())
    trainer1.shutdown()
    print("success!")
def test_metrics_nan(ray_start_2_cpus, num_workers, use_local):
    """A NaN in the score stream should make the aggregated score NaN
    while the last-score metric stays at the final non-NaN value."""
    data_size, val_size = 100, 100
    batch_size = 10
    num_train_steps = int(data_size / batch_size)
    num_val_steps = int(val_size / batch_size)

    # First score is NaN; the rest are zeros.
    train_scores = [np.nan] + ([0] * num_train_steps)
    val_scores = [np.nan] + ([0] * num_val_steps)
    TestOperator = get_test_metrics_operator(Operator)
    trainer = TorchTrainer(
        training_operator_cls=TestOperator,
        num_workers=num_workers,
        use_local=use_local,
        config={
            "scores": train_scores,
            "val_scores": val_scores,
            "key": "score",
            "batch_size": batch_size,
            "data_size": data_size,
            "val_size": val_size
        })

    stats = trainer.train(num_steps=num_train_steps)
    assert "score" in stats
    assert stats["last_score"] == 0
    assert np.isnan(stats["score"])

    stats = trainer.validate()
    assert "score" in stats
    assert stats["last_score"] == 0
    assert np.isnan(stats["score"])
    trainer.shutdown()
def test_reduce_result(ray_start_2_cpus, use_local):
    """With reduce_results=False, train() and validate() should each
    return one unreduced stats dict per worker."""
    if not dist.is_available():
        return

    def data_creator(config):
        """Returns training dataloader, validation dataloader."""
        train_dataset = LinearDataset(2, 5, size=config["data_size"])
        test_dataset = LinearDataset(2, 5, size=config["data_size"])
        return DataLoader(
            train_dataset, batch_size=1), DataLoader(
                test_dataset, batch_size=1)

    data_size = 600

    TestOperator = TrainingOperator.from_creators(
        model_creator,
        optimizer_creator,
        data_creator,
        loss_creator=lambda config: nn.MSELoss())
    trainer = TorchTrainer(
        training_operator_cls=TestOperator,
        num_workers=2,
        use_local=use_local,
        config={"data_size": data_size})

    list_stats = trainer.train(reduce_results=False, profile=True)
    assert len(list_stats) == 2
    # BUG FIX: `assert [expr for ...]` always passes because a non-empty
    # list is truthy; use all() so each worker's stats are checked.
    assert all(stats[NUM_SAMPLES] == data_size for stats in list_stats)
    assert all(stats[BATCH_COUNT] == (data_size // 2) for stats in list_stats)

    list_stats = trainer.validate(reduce_results=False, profile=True)
    assert len(list_stats) == 2
    assert all(stats[NUM_SAMPLES] == data_size for stats in list_stats)
    assert all(stats[BATCH_COUNT] == (data_size // 2) for stats in list_stats)
    trainer.shutdown()