def train_model(
    self,
    training_data: Optional[Dataset] = None,
    validation_data: Optional[Dataset] = None,
    num_workers: Optional[int] = None,
    num_prefetch: Optional[int] = None,
    shuffle_buffer_length: Optional[int] = None,
    cache_data: bool = False,
) -> TrainOutput:
    transformation = self.create_transformation()

    transformed_training_data = TransformedDataset(
        training_data, transformation
    )

    # Optionally cache the transformed entries so the transformation chain
    # is applied once rather than on every pass over the data.
    training_data_loader = self.create_training_data_loader(
        transformed_training_data
        if not cache_data
        else Cached(transformed_training_data),
        num_workers=num_workers,
        num_prefetch=num_prefetch,
        shuffle_buffer_length=shuffle_buffer_length,
    )

    validation_data_loader = None

    if validation_data is not None:
        transformed_validation_data = TransformedDataset(
            validation_data, transformation
        )

        validation_data_loader = self.create_validation_data_loader(
            transformed_validation_data
            if not cache_data
            else Cached(transformed_validation_data),
            num_workers=num_workers,
        )

    training_network = self.create_training_network()

    self.trainer(
        net=training_network,
        train_iter=training_data_loader,
        validation_iter=validation_data_loader,
    )

    # Build the predictor in the trainer's context (e.g. on the same device).
    with self.trainer.ctx:
        predictor = self.create_predictor(transformation, training_network)

    return TrainOutput(
        transformation=transformation,
        trained_net=training_network,
        predictor=predictor,
    )
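# What cache_data=True buys above, as a minimal stdlib-only sketch (an
# illustration of the idea, not the gluonts Cached class): the first pass
# stores items as it yields them, later passes replay the stored items, so an
# expensive transformation runs once instead of once per epoch.
class CachedSketch:
    def __init__(self, iterable):
        self.iterable = iterable
        self.cache = []

    def __iter__(self):
        if self.cache:
            yield from self.cache  # replay materialized items
        else:
            for item in self.iterable:
                self.cache.append(item)
                yield item


calls = []

def expensive_transform(x):
    calls.append(x)
    return x * 2

data = CachedSketch(expensive_transform(x) for x in range(3))
assert list(data) == [0, 2, 4]  # first epoch: transform runs
assert list(data) == [0, 2, 4]  # second epoch: served from cache
assert len(calls) == 3          # transform ran once per item in total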
def create_training_data_loader(
    self, data: Dataset, network: nn.Module, **kwargs
):
    data_loader = TrainDataLoader(
        Cached(data),
        batch_size=self.batch_size,
        stack_fn=batchify,
        transform=self.create_transformation()
        + self._create_instance_splitter("training"),
        num_batches_per_epoch=self.num_batches_per_epoch,
    )
    return data_loader
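# Sketch of the num_batches_per_epoch idea used by TrainDataLoader above
# (stdlib only, an illustration rather than the gluonts loader): an "epoch"
# is a fixed-length slice of an endless batch stream, so epoch boundaries are
# decoupled from dataset size.
import itertools

def cyclic(iterable):
    while True:
        yield from iterable

def batch_stream(stream, batch_size):
    while True:
        yield list(itertools.islice(stream, batch_size))

stream = cyclic(range(5))
epoch = list(itertools.islice(batch_stream(stream, batch_size=2), 3))
assert epoch == [[0, 1], [2, 3], [4, 0]]  # 3 batches per "epoch"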
def __init__(
    self,
    dataset: Dataset,
    transform: Transformation,
    is_train: bool = True,
    shuffle_buffer_length: Optional[int] = None,
    cache_data: bool = False,
):
    super().__init__()
    self.shuffle_buffer_length = shuffle_buffer_length
    # Cache *inside* the cycle so the finite dataset is materialized once
    # and the endless cycle replays the cache; caching outside an endless
    # cycle would grow the cache without bound.
    self.transformed_dataset = TransformedDataset(
        Cyclic(dataset) if not cache_data else Cyclic(Cached(dataset)),
        transform,
        is_train=is_train,
    )
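# Stdlib sketch of the Cyclic wrapper used above (illustrative, not the
# gluonts class): it restarts the underlying iterable forever, which is why a
# fixed number of batches per epoch can be drawn from a finite dataset.
import itertools

class CyclicSketch:
    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        while True:
            yield from self.iterable

stream = iter(CyclicSketch([10, 20, 30]))
assert list(itertools.islice(stream, 7)) == [10, 20, 30, 10, 20, 30, 10]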
# Imports assumed for a self-contained run (paths per gluonts.itertools and
# gluonts.dataset.artificial).
from typing import Iterable, List, Tuple

import pytest

from gluonts.dataset.artificial import constant_dataset
from gluonts.itertools import (
    Cached,
    Cyclic,
    IterableSlice,
    PseudoShuffled,
    batcher,
)


@pytest.mark.parametrize(
    "data",
    [
        range(20),
        constant_dataset()[1],
    ],
)
def test_pseudo_shuffled(data: Iterable) -> None:
    list_data = list(data)
    shuffled_iter = PseudoShuffled(iter(list_data), shuffle_buffer_length=5)
    shuffled_data = list(shuffled_iter)
    # Shuffling must be a permutation: same length, same elements.
    assert len(shuffled_data) == len(list_data)
    assert all(d in shuffled_data for d in list_data)


@pytest.mark.parametrize(
    "data, expected_elements_per_iteration",
    [
        (Cached(range(4)), (list(range(4)),) * 5),
        (batcher(range(10), 3), ([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]], [])),
        (IterableSlice(range(10), 3), ([0, 1, 2],) * 5),
        (
            IterableSlice(iter(range(10)), 3),
            ([0, 1, 2], [3, 4, 5], [6, 7, 8], [9], []),
        ),
        (
            IterableSlice(iter(Cyclic(range(5))), 3),
            ([0, 1, 2], [3, 4, 0], [1, 2, 3], [4, 0, 1]),
        ),
    ],
)
def test_iterate_multiple_times(
    data: Iterable, expected_elements_per_iteration: Tuple[List]
):
    for expected_elements in expected_elements_per_iteration:
        assert list(data) == expected_elements
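# Stdlib sketch of the buffer-based pseudo-shuffling the first test checks
# (an illustration of the scheme, not gluonts' PseudoShuffled): keep a
# bounded buffer, emit a random element from it as new items arrive, then
# drain the buffer. The output is a permutation of the input without ever
# materializing the whole stream.
import random

def pseudo_shuffled(iterator, shuffle_buffer_length):
    buffer = []
    for item in iterator:
        buffer.append(item)
        if len(buffer) >= shuffle_buffer_length:
            yield buffer.pop(random.randrange(len(buffer)))
    while buffer:
        yield buffer.pop(random.randrange(len(buffer)))

out = list(pseudo_shuffled(iter(range(20)), shuffle_buffer_length=5))
assert sorted(out) == list(range(20))  # a permutation of the input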
def train_model(
    self,
    training_data: Dataset,
    validation_data: Optional[Dataset] = None,
    num_workers: int = 0,
    shuffle_buffer_length: Optional[int] = None,
    cache_data: bool = False,
    **kwargs,
) -> TrainOutput:
    transformation = self.create_transformation()

    transformed_training_data = transformation.apply(
        training_data, is_train=True
    )

    training_network = self.create_lightning_module()

    training_data_loader = self.create_training_data_loader(
        transformed_training_data
        if not cache_data
        else Cached(transformed_training_data),
        training_network,
        num_workers=num_workers,
        shuffle_buffer_length=shuffle_buffer_length,
    )

    validation_data_loader = None

    if validation_data is not None:
        transformed_validation_data = transformation.apply(
            validation_data, is_train=True
        )

        validation_data_loader = self.create_validation_data_loader(
            transformed_validation_data
            if not cache_data
            else Cached(transformed_validation_data),
            training_network,
        )

    # Checkpoint on validation loss when a validation set is given,
    # otherwise on training loss.
    monitor = "train_loss" if validation_data is None else "val_loss"
    checkpoint = pl.callbacks.ModelCheckpoint(
        monitor=monitor, mode="min", verbose=True
    )

    # Prepend the checkpoint callback to any user-supplied callbacks before
    # constructing the Lightning trainer.
    custom_callbacks = self.trainer_kwargs.get("callbacks", [])
    callbacks = [checkpoint] + custom_callbacks
    trainer_kwargs = {**self.trainer_kwargs, "callbacks": callbacks}
    trainer = pl.Trainer(**trainer_kwargs)

    trainer.fit(
        model=training_network,
        train_dataloaders=training_data_loader,
        val_dataloaders=validation_data_loader,
    )

    # Restore the best checkpoint seen during training.
    logger.info(f"Loading best model from {checkpoint.best_model_path}")
    best_model = training_network.load_from_checkpoint(
        checkpoint.best_model_path
    )

    return TrainOutput(
        transformation=transformation,
        trained_net=best_model,
        trainer=trainer,
        predictor=self.create_predictor(transformation, best_model),
    )
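# Sketch of the callback-merging step above (stdlib only, with placeholder
# strings standing in for the ModelCheckpoint and user callbacks): the
# checkpoint callback is prepended to any user-supplied callbacks, and the
# merged kwargs are what actually reach pl.Trainer.
user_trainer_kwargs = {"max_epochs": 10, "callbacks": ["user_cb"]}
checkpoint_cb = "checkpoint_cb"
callbacks = [checkpoint_cb] + user_trainer_kwargs.get("callbacks", [])
merged = {**user_trainer_kwargs, "callbacks": callbacks}
assert merged["callbacks"] == ["checkpoint_cb", "user_cb"]
assert merged["max_epochs"] == 10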