def build_flash_serve_model_component(model, serve_input, output, transform, transform_kwargs):
    """Wrap ``model`` in a locally defined ``ModelComponent`` exposing a ``predict`` endpoint.

    A single-use ``DataModule`` is built around ``serve_input`` only so that its predict
    dataloader's ``collate_fn`` and its ``on_after_batch_transfer`` hook can be reused by
    the serving component.

    Args:
        model: The task to serve. Its ``eval``, ``device``, ``transfer_batch_to_device``,
            ``predict_step`` and optional ``_output_transform`` are used.
        serve_input: Passed to the ``DataModule`` as ``predict_input`` and to
            ``_ServeInputProcessor`` for the exposed ``inputs``.
        output: Wrapped in ``FlashOutputs`` for the exposed ``outputs``.
        transform: Forwarded to the ``DataModule``.
        transform_kwargs: Forwarded to the ``DataModule``.

    Returns:
        A ``FlashServeModelComponent`` instance wrapping ``model``.
    """
    # TODO: Resolve this hack
    serve_datamodule = DataModule(
        predict_input=serve_input,
        batch_size=1,
        transform=transform,
        transform_kwargs=transform_kwargs,
    )

    class MockTrainer(Trainer):
        # Trainer stand-in that is always in the PREDICTING stage and reports ``model``
        # as its lightning module, so the data module can build a predict dataloader.
        def __init__(self):
            super().__init__()
            self.state.stage = RunningStage.PREDICTING

        @property
        def lightning_module(self):
            return model

    serve_datamodule.trainer = MockTrainer()
    predict_loader = serve_datamodule.predict_dataloader()
    predict_collate_fn = predict_loader.collate_fn

    class FlashServeModelComponent(ModelComponent):
        def __init__(self, model):
            self.model = model
            self.model.eval()
            self.serve_input = serve_input
            self.on_after_batch_transfer = serve_datamodule.on_after_batch_transfer
            self.output_transform = getattr(model, "_output_transform", None) or OutputTransform()
            # TODO (@tchaton) Remove this hack
            # Some ``transfer_batch_to_device`` implementations accept a third positional
            # argument (presumably a dataloader index); detect that from the signature.
            transfer_params = inspect.signature(self.model.transfer_batch_to_device).parameters
            self.extra_arguments = len(transfer_params) == 3
            self.device = self.model.device

        @expose(
            inputs={"inputs": FlashInputs(_ServeInputProcessor(serve_input, predict_collate_fn))},
            outputs={"outputs": FlashOutputs(output)},
        )
        def predict(self, inputs):
            # Only pass the trailing ``0`` when the transfer hook takes three parameters.
            trailing_args = (0,) if self.extra_arguments else ()
            with torch.no_grad():
                batch = self.model.transfer_batch_to_device(inputs, self.device, *trailing_args)
                batch = self.on_after_batch_transfer(batch, 0)
                predictions = self.model.predict_step(batch, 0)
                return self.output_transform(predictions)

    return FlashServeModelComponent(model)
def test_data_loaders_num_workers_to_0(tmpdir):
    """Check that ``num_workers`` is forced to ``0`` only for the internal iterator.

    The iterator from ``_reset_iterator`` (used for visualization) must be single-process,
    while the regular training dataloader keeps the configured worker count.
    """
    dm = DataModule(train_dataset=range(10), num_workers=3)

    internal_iter = dm._reset_iterator(RunningStage.TRAINING)
    assert isinstance(internal_iter, torch.utils.data.dataloader._SingleProcessDataLoaderIter)

    training_iter = iter(dm.train_dataloader())
    assert isinstance(training_iter, torch.utils.data.dataloader._MultiProcessingDataLoaderIter)

    # The configured value itself must be left untouched.
    assert dm.num_workers == 3
def test_available_data_sources():
    """A preprocess and a DataModule built from it report the same three data sources."""

    def _check_sources(owner):
        # ``owner`` is either the preprocess or the DataModule wrapping it.
        assert DefaultDataSources.TENSORS in owner.available_data_sources()
        assert "test" in owner.available_data_sources()
        assert len(owner.available_data_sources()) == 3

    preprocess = CustomPreprocess()
    _check_sources(preprocess)
    _check_sources(DataModule(preprocess=preprocess))
def test_datapipeline_transformations(tmpdir):
    """Drive every stage through the ``TestPreprocessTransformations*`` classes.

    With ``TestPreprocessTransformations``, building a validation batch must raise a
    ``MisconfigurationException``. ``TestPreprocessTransformations2`` yields tensor
    validation batches and backs a fit/test/predict run, after which the hook flags on the
    preprocess and its "default" data source are asserted.
    """

    def _datamodule(preprocess):
        # Same "default" data source for all four stages; only the preprocess varies.
        return DataModule.from_data_source(
            "default", 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=preprocess
        )

    dm = _datamodule(TestPreprocessTransformations())

    assert dm.train_dataloader().dataset[0] == (0, 1, 2, 3)
    train_batch = next(iter(dm.train_dataloader()))
    assert torch.equal(train_batch, tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]]))

    assert dm.val_dataloader().dataset[0] == {'a': 0, 'b': 1}
    assert dm.val_dataloader().dataset[1] == {'a': 1, 'b': 2}
    with pytest.raises(MisconfigurationException, match="When ``to_tensor_transform``"):
        next(iter(dm.val_dataloader()))

    dm = _datamodule(TestPreprocessTransformations2())
    val_batch = next(iter(dm.val_dataloader()))
    assert torch.equal(val_batch["a"], tensor([0, 1]))
    assert torch.equal(val_batch["b"], tensor([1, 2]))

    model = CustomModel()
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=1,
        limit_test_batches=2,
        limit_predict_batches=2,
        num_sanity_val_steps=1,
    )
    trainer.fit(model, datamodule=dm)
    trainer.test(model)
    trainer.predict(model)

    preprocess = model._preprocess
    data_source = preprocess.data_source_of_name("default")

    # Hook flags that must have been set, grouped by stage.
    expected_flags = [
        (data_source, "train_load_data_called"),
        (preprocess, "train_pre_tensor_transform_called"),
        (preprocess, "train_collate_called"),
        (preprocess, "train_per_batch_transform_on_device_called"),
        (data_source, "val_load_data_called"),
        (data_source, "val_load_sample_called"),
        (preprocess, "val_to_tensor_transform_called"),
        (preprocess, "val_collate_called"),
        (preprocess, "val_per_batch_transform_on_device_called"),
        (data_source, "test_load_data_called"),
        (preprocess, "test_to_tensor_transform_called"),
        (preprocess, "test_post_tensor_transform_called"),
        (data_source, "predict_load_data_called"),
    ]
    for owner, flag in expected_flags:
        assert getattr(owner, flag), flag
def test_datapipeline_transformations_overridden_by_task():
    """Check that a task-level ``input_transform`` is what shapes the training batches.

    The DataModule is given ``ImageClassificationInputTransform`` (ToTensor plus an on-device
    Normalize). The task assigns ``OverrideInputTransform`` (ToTensor + Resize(128)). The
    steps assert 128x128 batches with values in ``[0, 1]``.
    """

    class ImageInput(Input):
        def load_data(self, folder):
            # ``folder`` is ignored: always report the same two file paths.
            return ["a.jpg", "b.jpg"]

        def load_sample(self, path):
            # ``path`` is ignored: return a random (64, 64, 3) array in [0, 1).
            return np.random.uniform(0, 1, (64, 64, 3))

    class ImageClassificationInputTransform(InputTransform):
        def per_sample_transform(self) -> Callable:
            return T.Compose([T.ToTensor()])

        def per_batch_transform_on_device(self) -> Callable:
            return T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    class OverrideInputTransform(InputTransform):
        def per_sample_transform(self) -> Callable:
            return T.Compose([T.ToTensor(), T.Resize(128)])

    def _check_batch(batch):
        # Resized to 128x128 and left un-normalized.
        assert batch.shape == torch.Size([2, 3, 128, 128])
        assert torch.max(batch) <= 1.0
        assert torch.min(batch) >= 0.0

    class CustomModel(Task):
        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
            # Task-level override: assign the resizing transform (as a class) directly.
            self.input_transform = OverrideInputTransform

        def training_step(self, batch, batch_idx):
            _check_batch(batch)

        def validation_step(self, batch, batch_idx):
            _check_batch(batch)

    datamodule_transform = ImageClassificationInputTransform()
    datamodule = DataModule(
        ImageInput(RunningStage.TRAINING, [1]),
        ImageInput(RunningStage.VALIDATING, [1]),
        transform=datamodule_transform,
        batch_size=2,
        num_workers=0,
    )

    task = CustomModel()
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=1,
        num_sanity_val_steps=1,
    )
    trainer.fit(task, datamodule=datamodule)
def test_transformations(tmpdir):
    """Check that per-stage ``InputTransform`` hooks are applied and flagged.

    ``TestInputTransform`` is used to inspect raw and collated train/val data.
    ``TestInputTransform2`` is passed as a class rather than an instance. It backs a
    fit/test run, after which the per-stage ``*_called`` flags on
    ``datamodule.input_transform`` are asserted.
    """

    def _make_datamodule(transform):
        # Same train/val/test inputs each time; only the transform varies.
        return DataModule(
            TestInput(RunningStage.TRAINING, [1]),
            TestInput(RunningStage.VALIDATING, [1]),
            TestInput(RunningStage.TESTING, [1]),
            transform=transform,
            batch_size=2,
            num_workers=0,
        )

    datamodule = _make_datamodule(TestInputTransform())

    assert datamodule.train_dataloader().dataset[0] == (0, 1, 2, 3)
    batch = next(iter(datamodule.train_dataloader()))
    assert torch.equal(batch, torch.tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]]))

    assert datamodule.val_dataloader().dataset[0] == {"a": 0, "b": 1}
    assert datamodule.val_dataloader().dataset[1] == {"a": 1, "b": 2}
    # Only checks that a validation batch can be built with this transform.
    next(iter(datamodule.val_dataloader()))

    datamodule = _make_datamodule(TestInputTransform2)
    batch = next(iter(datamodule.val_dataloader()))
    assert torch.equal(batch["a"], torch.tensor([0, 1]))
    assert torch.equal(batch["b"], torch.tensor([1, 2]))

    model = CustomModel()
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=1,
        limit_test_batches=2,
        limit_predict_batches=2,
        num_sanity_val_steps=1,
    )
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)

    input_transform = datamodule.input_transform
    assert input_transform.train_per_sample_transform_called
    assert input_transform.train_collate_called
    assert input_transform.train_per_batch_transform_on_device_called
    # The validation per-sample hook gets its own check (it mirrors the train checks above).
    assert input_transform.val_per_sample_transform_called
    assert input_transform.val_collate_called
    assert input_transform.val_per_batch_transform_on_device_called
    assert input_transform.test_per_sample_transform_called
def test_val_split():
    """``val_split=0.2`` moves 20 of the 100 training samples into a validation set."""
    dm = DataModule(
        Input(RunningStage.TRAINING, [1] * 100),
        batch_size=2,
        num_workers=0,
        val_split=0.2,
    )
    split_sizes = (len(dm.train_dataset), len(dm.val_dataset))
    assert split_sizes == (80, 20)
def test_dataloaders_with_sampler(mock_dataloader, sampler, callable):
    """Check that the DataModule ``sampler`` is applied to the train dataloader only.

    ``mock_dataloader`` is a mock whose ``call_args`` hold the kwargs of the most recent
    dataloader construction (presumably patched over the DataLoader class by the test's
    parametrization/fixtures). When ``callable`` is true, ``sampler`` is called with the
    train input, and its return value is expected as the dataloader's sampler.
    """
    train_input = TestInput(RunningStage.TRAINING, [1])
    datamodule = DataModule(
        train_input,
        TestInput(RunningStage.VALIDATING, [1]),
        TestInput(RunningStage.TESTING, [1]),
        batch_size=2,
        num_workers=0,
        sampler=sampler,
    )
    assert datamodule.sampler is sampler

    datamodule.train_dataloader()
    if callable:
        sampler.assert_called_once_with(train_input)
    kwargs = mock_dataloader.call_args[1]
    assert "sampler" in kwargs
    assert kwargs["sampler"] is (sampler.return_value if callable else sampler)

    # ``call_args`` only reflects the most recent call, so build each loader right before
    # inspecting it; building both up front would check the test loader twice.
    for build_dataloader in (datamodule.val_dataloader, datamodule.test_dataloader):
        build_dataloader()
        kwargs = mock_dataloader.call_args[1]
        assert "sampler" not in kwargs
def test_predict_dataset(tmpdir):
    """Generate embeddings for the KKI ``TUDataset`` with a ``GraphEmbedder``.

    The embedder wraps the backbone of a ``GraphClassifier`` sized from the dataset, and
    the first prediction of the first batch must be a tensor.
    """
    kki = datasets.TUDataset(root=tmpdir, name="KKI")
    classifier = GraphClassifier(num_features=kki.num_features, num_classes=kki.num_classes)
    embedder = GraphEmbedder(classifier.backbone)

    predict_input = GraphClassificationDatasetInput(RunningStage.PREDICTING, kki)
    datamodule = DataModule(
        predict_input=predict_input,
        transform=GraphClassificationInputTransform,
        batch_size=4,
    )

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    predictions = trainer.predict(embedder, datamodule=datamodule)
    first_embedding = predictions[0][0]
    assert isinstance(first_embedding, torch.Tensor)
def test_split_dataset(tmpdir):
    """Check ``DataModule._split_train_val`` and ``SplitDataset`` index validation/passthrough.

    Out-of-range ``indices`` must raise a ``MisconfigurationException`` whose message
    contains the valid index range.
    """
    import re

    train_ds, val_ds = DataModule._split_train_val(range(100), val_split=0.1)
    assert len(train_ds) == 90
    assert len(val_ds) == 10
    assert len(np.unique(train_ds.indices)) == len(train_ds.indices)

    # ``match`` is applied as a regular expression. Unescaped, "[0, 99]" is a character
    # class matching any message containing a digit/comma/space, so escape the bound.
    # Assumes the message reports ``len(dataset) - 1`` as the upper bound, as the first two
    # cases do; the concatenated datasets have 100 elements, hence "[0, 99]".
    duplicated = list(range(50)) + list(range(50))
    out_of_range_cases = [
        (range(100), {"indices": [100]}, "[0, 99]"),
        (range(50), {"indices": [-1]}, "[0, 49]"),
        (duplicated, {"indices": [-1]}, "[0, 99]"),
        (duplicated, {"indices": [-1], "use_duplicated_indices": True}, "[0, 99]"),
    ]
    for data, kwargs, bound in out_of_range_cases:
        with pytest.raises(MisconfigurationException, match=re.escape(bound)):
            SplitDataset(data, **kwargs)

    class Dataset:
        def __init__(self):
            self.data = [0, 1, 2]
            self.name = "something"

        def __getitem__(self, index):
            return self.data[index]

        def __len__(self):
            return len(self.data)

    split_dataset = SplitDataset(Dataset(), indices=[0])
    assert split_dataset.name == "something"
    assert split_dataset._INTERNAL_KEYS == ("dataset", "indices", "data")

    # Setting a non-internal attribute on the split is expected to reach the wrapped dataset.
    split_dataset.is_passed_down = True
    assert split_dataset.dataset.is_passed_down
def test_split_dataset():
    """Check ``_split_train_val`` sizes and that ``SplitDataset`` exposes but does not push attributes.

    Reads of ``name`` are delegated to the wrapped dataset. Setting ``is_passed_down`` on the
    split must leave the wrapped dataset's own value untouched.
    """
    train_part, val_part = DataModule._split_train_val(range(100), val_split=0.1)
    assert (len(train_part), len(val_part)) == (90, 10)
    # The training indices must not contain duplicates.
    assert len(np.unique(train_part.indices)) == len(train_part.indices)

    class Dataset:
        def __init__(self):
            self.data = [0, 1, 2]
            self.name = "something"
            self.is_passed_down = False

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    wrapped = SplitDataset(Dataset(), indices=[0])
    assert wrapped.name == "something"

    wrapped.is_passed_down = True
    assert not wrapped.dataset.is_passed_down
def test_not_trainable(tmpdir):
    """``GraphEmbedder`` must raise ``NotImplementedError`` from fit, validate and test."""
    kki = datasets.TUDataset(root=tmpdir, name="KKI")
    embedder = GraphEmbedder(GraphClassifier(num_features=1, num_classes=1).backbone)

    stages = (RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING)
    datamodule = DataModule(
        *(GraphClassificationDatasetInput(stage, kki) for stage in stages),
        transform=GraphClassificationInputTransform,
        batch_size=4,
    )
    trainer = Trainer(default_root_dir=tmpdir, num_sanity_val_steps=0)

    unsupported_runs = (
        (trainer.fit, "Training a `GraphEmbedder` is not supported."),
        (trainer.validate, "Validating a `GraphEmbedder` is not supported."),
        (trainer.test, "Testing a `GraphEmbedder` is not supported."),
    )
    for run, message in unsupported_runs:
        with pytest.raises(NotImplementedError, match=message):
            run(embedder, datamodule=datamodule)
T.Resize(self.image_size), T.ToTensor(), T.RandomRotation(self.rotation) ] return T.Compose(transforms) def input_per_sample_transform(self) -> Callable: # this will be used to transform only the input value associated with # the `input` key within each sample. transforms = [T.Resize(self.image_size), T.ToTensor()] return T.Compose(transforms) # Register your transform within the InputTransform registry of the Flash DataModule # Note: Registries can be shared by multiple dataset. DataModule.register_input_transform("base", BaseImageInputTransform) DataModule.register_input_transform("random_rotation", ImageRandomRotationInputTransform) DataModule.register_input_transform( "random_90_def_rotation", partial(ImageRandomRotationInputTransform, rotation=90)) ############################################################################################# # Step 3 / 3: Create a DataModule (Part 1) # # # # The `DataModule` class is a collection of `Input` for various stages and the # # `InputTransform` and you can pass them directly to its init function. # # # ############################################################################################# datamodule = DataModule(
def test_flash_callback(_, __, tmpdir):
    """Test the data-fetcher hook calls for one batch fetch and for ``fit``.

    A single ``MagicMock`` is used as ``data_fetcher``, so its ``method_calls`` accumulate:
    the second expectation starts with the calls recorded by the first fetch.
    """
    callback_mock = mock.MagicMock()
    inputs = [(torch.rand(1), torch.rand(1))]

    def _build_datamodule():
        # A fresh transform and DataModule, wired to the shared ``callback_mock`` fetcher.
        transform = InputTransform()
        return DataModule(
            DatasetInput(RunningStage.TRAINING, inputs),
            DatasetInput(RunningStage.VALIDATING, inputs),
            DatasetInput(RunningStage.TESTING, inputs),
            transform=transform,
            batch_size=1,
            num_workers=0,
            data_fetcher=callback_mock,
        )

    def _stage_calls(stage, on_device):
        # Hook sequence for one batch of ``stage``; ``on_device`` adds the device hook.
        calls = [
            mock.call.on_load_sample(mock.ANY, stage),
            mock.call.on_per_sample_transform(mock.ANY, stage),
            mock.call.on_collate(mock.ANY, stage),
            mock.call.on_per_batch_transform(mock.ANY, stage),
        ]
        if on_device:
            calls.append(mock.call.on_per_batch_transform_on_device(mock.ANY, stage))
        return calls

    dm = _build_datamodule()
    _ = next(iter(dm.train_dataloader()))
    assert callback_mock.method_calls == _stage_calls(RunningStage.TRAINING, on_device=False)

    def _as_pair(batch):
        return (batch[DataKeys.INPUT], batch[DataKeys.TARGET])

    class CustomModel(Task):
        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())

        def training_step(self, batch, batch_idx):
            return super().training_step(_as_pair(batch), batch_idx)

        def validation_step(self, batch, batch_idx):
            return super().validation_step(_as_pair(batch), batch_idx)

        def test_step(self, batch, batch_idx):
            return super().test_step(_as_pair(batch), batch_idx)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=1,
        limit_train_batches=1,
        progress_bar_refresh_rate=0,
    )
    dm = _build_datamodule()
    trainer.fit(CustomModel(), datamodule=dm)

    # Earlier fetch, then sanity validation, one training batch and one validation batch.
    assert callback_mock.method_calls == (
        _stage_calls(RunningStage.TRAINING, on_device=False)
        + _stage_calls(RunningStage.VALIDATING, on_device=True)
        + _stage_calls(RunningStage.TRAINING, on_device=True)
        + _stage_calls(RunningStage.VALIDATING, on_device=True)
    )
def test_deepcopy():
    """A training split produced by ``_split_train_val`` (a ``SplitDataset``) can be deep-copied."""
    values = list(range(100))
    train_split, _val_split = DataModule._split_train_val(values, val_split=0.1)
    deepcopy(train_split)
def test_dataset_data_source():
    """``from_datasets`` exposes a training ``sample`` keyed by ``DefaultDataKeys.INPUT``."""
    datamodule = DataModule.from_datasets(range(10), range(10))
    expected = {DefaultDataKeys.INPUT: 0}
    assert datamodule.train_dataset.sample == expected
def test_data_module():
    """End-to-end checks for ``DataModule`` with ``InputTransform`` hooks.

    Covers stage-specific ``*_per_batch_transform_on_device`` hooks as seen by a ``Task``,
    running a plain ``BoringModel`` on the same data module, preservation of the given
    transform type, and per-sample transforms with a validation-specific override.
    """
    seed_everything(42)

    # Stage-specific on-device offsets; the Task steps below check their effect.
    def train_fn(data):
        return data - 100

    def val_fn(data):
        return data + 100

    def test_fn(data):
        return data - 1000

    def predict_fn(data):
        return data + 1000

    @dataclass
    class TestTransform(InputTransform):
        def per_sample_transform(self):
            def fn(x):
                return x

            return fn

        def train_per_batch_transform_on_device(self) -> Callable:
            return train_fn

        def val_per_batch_transform_on_device(self) -> Callable:
            return val_fn

        def test_per_batch_transform_on_device(self) -> Callable:
            return test_fn

        def predict_per_batch_transform_on_device(self) -> Callable:
            return predict_fn

    transform = TestTransform()
    assert transform._transform is not None

    stage_inputs = {}
    for stage in (
        RunningStage.TRAINING,
        RunningStage.VALIDATING,
        RunningStage.TESTING,
        RunningStage.PREDICTING,
    ):
        stage_input = Input(stage, np.arange(10, dtype=np.float32))
        assert stage_input.running_stage == stage
        stage_inputs[stage] = stage_input

    dm = DataModule(
        train_input=stage_inputs[RunningStage.TRAINING],
        val_input=stage_inputs[RunningStage.VALIDATING],
        test_input=stage_inputs[RunningStage.TESTING],
        predict_input=stage_inputs[RunningStage.PREDICTING],
        transform=transform,
        batch_size=2,
    )

    # Off-device batches are still the raw 0..9 values in pairs.
    for make_loader in (dm.train_dataloader, dm.val_dataloader):
        assert len(make_loader()) == 5
        batch = next(iter(make_loader()))
        assert batch.shape == torch.Size([2])
        assert batch.min() >= 0 and batch.max() < 10

    class TestModel(Task):
        def training_step(self, batch, batch_idx):
            assert sum(batch < 0) == 2

        def validation_step(self, batch, batch_idx):
            assert sum(batch > 0) == 2

        def test_step(self, batch, batch_idx):
            assert sum(batch < 500) == 2

        def predict_step(self, batch, *args, **kwargs):
            assert sum(batch > 500) == 2
            assert torch.equal(batch, torch.tensor([1000.0, 1001.0]))

        # No-op overrides of the Task dataloader/end hooks.
        def on_train_dataloader(self) -> None:
            pass

        def on_val_dataloader(self) -> None:
            pass

        def on_test_dataloader(self, *_) -> None:
            pass

        def on_predict_dataloader(self) -> None:
            pass

        def on_predict_end(self) -> None:
            pass

        def on_fit_end(self) -> None:
            pass

    def _run_all_stages(module, datamodule):
        # One fast_dev_run trainer drives fit, validate, test and predict.
        trainer = Trainer(fast_dev_run=True)
        trainer.fit(module, datamodule=datamodule)
        trainer.validate(module, datamodule=datamodule)
        trainer.test(module, datamodule=datamodule)
        trainer.predict(module, datamodule=datamodule)

    _run_all_stages(TestModel(torch.nn.Linear(1, 1)), dm)

    # Test that plain lightning module works with FlashDataModule
    class SampleBoringModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(2, 1)

    _run_all_stages(SampleBoringModel(), dm)

    # The DataModule keeps the concrete transform type it was given.
    transform = TestTransform()
    empty_train_input = Input(RunningStage.TRAINING)
    dm = DataModule(train_input=empty_train_input, batch_size=1, transform=transform)
    assert isinstance(dm.input_transform, TestTransform)

    class RandomDataset(Dataset):
        def __init__(self, size: int, length: int):
            self.len = length
            self.data = torch.ones(length, size)

        def __getitem__(self, index):
            return self.data[index]

        def __len__(self):
            return self.len

    def _add_hundred(x):
        if isinstance(x, Dict):
            x["input"] += 100
        else:
            x += 100
        return x

    class TrainInputTransform(InputTransform):
        def _add_one(self, x):
            if isinstance(x, Dict):
                x["input"] += 1
            else:
                x += 1
            return x

        def per_sample_transform(self) -> Callable:
            return self._add_one

        def val_per_sample_transform(self) -> Callable:
            return _add_hundred

    dm = DataModule(
        train_input=DatasetInput(RunningStage.TRAINING, RandomDataset(64, 32)),
        val_input=DatasetInput(RunningStage.VALIDATING, RandomDataset(64, 32)),
        test_input=DatasetInput(RunningStage.TESTING, RandomDataset(64, 32)),
        batch_size=3,
        transform=TrainInputTransform(),
    )

    # Samples start as ones: train/test use ``per_sample_transform`` (+1 -> 2), while
    # validation uses the ``val_per_sample_transform`` override (+100 -> 101).
    for make_loader, expected_value in (
        (dm.train_dataloader, 2),
        (dm.val_dataloader, 101),
        (dm.test_dataloader, 2),
    ):
        batch = next(iter(make_loader()))
        assert batch["input"][0][0] == expected_value
def test_flash_callback(_, __, tmpdir):
    """Test the preprocess callback hook calls for one batch fetch and for ``fit``.

    The same ``MagicMock`` is appended to both data modules' preprocess callbacks, so its
    ``method_calls`` accumulate across the two checks.
    """
    callback_mock = MagicMock()
    inputs = [[torch.rand(1), torch.rand(1)]]

    def _datamodule_with_callback():
        # Fresh DataModule/DefaultPreprocess with ``callback_mock`` registered as a callback.
        datamodule = DataModule.from_data_source(
            "default",
            inputs,
            inputs,
            inputs,
            None,
            preprocess=DefaultPreprocess(),
            batch_size=1,
            num_workers=0,
        )
        datamodule.preprocess.callbacks += [callback_mock]
        return datamodule

    def _stage_calls(stage, on_device):
        # Hook sequence for one batch of ``stage``; ``on_device`` adds the device hook.
        calls = [
            call.on_load_sample(ANY, stage),
            call.on_pre_tensor_transform(ANY, stage),
            call.on_to_tensor_transform(ANY, stage),
            call.on_post_tensor_transform(ANY, stage),
            call.on_collate(ANY, stage),
            call.on_per_batch_transform(ANY, stage),
        ]
        if on_device:
            calls.append(call.on_per_batch_transform_on_device(ANY, stage))
        return calls

    dm = _datamodule_with_callback()
    _ = next(iter(dm.train_dataloader()))
    assert callback_mock.method_calls == _stage_calls(RunningStage.TRAINING, on_device=False)

    class CustomModel(Task):
        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=1,
        limit_train_batches=1,
        progress_bar_refresh_rate=0,
    )
    dm = _datamodule_with_callback()
    trainer.fit(CustomModel(), datamodule=dm)

    # Earlier fetch, then sanity validation, one training batch and one validation batch.
    assert callback_mock.method_calls == (
        _stage_calls(RunningStage.TRAINING, on_device=False)
        + _stage_calls(RunningStage.VALIDATING, on_device=True)
        + _stage_calls(RunningStage.TRAINING, on_device=True)
        + _stage_calls(RunningStage.VALIDATING, on_device=True)
    )