Example #1
0
    def train_model(
        self,
        training_data: Optional[Dataset] = None,
        validation_data: Optional[Dataset] = None,
        num_workers: Optional[int] = None,
        num_prefetch: Optional[int] = None,
        shuffle_buffer_length: Optional[int] = None,
        cache_data: bool = False,
    ) -> TrainOutput:
        """Train the network and return it with its transformation and predictor.

        Args:
            training_data: dataset to train on. Required despite the ``None``
                default (kept for signature compatibility); ``None`` raises.
            validation_data: optional dataset used to build a validation loader.
            num_workers: forwarded to both data-loader factories.
            num_prefetch: forwarded to the training data-loader factory only.
            shuffle_buffer_length: forwarded to the training data-loader
                factory only.
            cache_data: if True, each transformed dataset is wrapped in
                ``Cached`` before being handed to its loader factory.

        Returns:
            TrainOutput holding the transformation, the trained network and
            a predictor built from them.

        Raises:
            ValueError: if ``training_data`` is None.
        """
        # Fail fast: with no training data the original code only failed
        # later, once the loader first tried to iterate ``None``.
        if training_data is None:
            raise ValueError("training_data must be provided to train a model")

        transformation = self.create_transformation()

        def _prepare(dataset: Dataset):
            # Apply the shared transformation lazily and optionally cache it.
            transformed = TransformedDataset(dataset, transformation)
            return Cached(transformed) if cache_data else transformed

        training_data_loader = self.create_training_data_loader(
            _prepare(training_data),
            num_workers=num_workers,
            num_prefetch=num_prefetch,
            shuffle_buffer_length=shuffle_buffer_length,
        )

        validation_data_loader = None
        if validation_data is not None:
            validation_data_loader = self.create_validation_data_loader(
                _prepare(validation_data),
                num_workers=num_workers,
            )

        training_network = self.create_training_network()

        self.trainer(
            net=training_network,
            train_iter=training_data_loader,
            validation_iter=validation_data_loader,
        )

        # Build the predictor inside the trainer's context.
        with self.trainer.ctx:
            predictor = self.create_predictor(transformation, training_network)

        return TrainOutput(
            transformation=transformation,
            trained_net=training_network,
            predictor=predictor,
        )
Example #2
0
 def create_training_data_loader(self, data: Dataset, network: nn.Module,
                                 **kwargs):
     """Return a ``TrainDataLoader`` that yields training batches from ``data``.

     ``data`` is always wrapped in ``Cached``. Each entry goes through
     ``self.create_transformation()`` followed by the ``"training"`` instance
     splitter, and batches are stacked with ``batchify``.

     Args:
         data: dataset to draw training batches from.
         network: accepted for interface compatibility; not used here.
         **kwargs: accepted for interface compatibility and ignored, so
             options such as ``num_workers`` or ``shuffle_buffer_length``
             passed by a caller have no effect.

     Returns:
         The loader object built by ``TrainDataLoader``.
     """
     training_pipeline = (
         self.create_transformation()
         + self._create_instance_splitter("training")
     )
     cached_data = Cached(data)
     return TrainDataLoader(
         cached_data,
         batch_size=self.batch_size,
         stack_fn=batchify,
         transform=training_pipeline,
         num_batches_per_epoch=self.num_batches_per_epoch,
     )
Example #3
0
    def __init__(
        self,
        dataset: Dataset,
        transform: Transformation,
        is_train: bool = True,
        shuffle_buffer_length: Optional[int] = None,
        cache_data: bool = False,
    ):
        """Wrap ``dataset`` as an endlessly repeating, transformed stream.

        Args:
            dataset: source dataset; it is wrapped in ``Cyclic`` so that
                iteration repeats over it.
            transform: transformation applied via ``TransformedDataset``.
            is_train: forwarded to ``TransformedDataset``.
            shuffle_buffer_length: stored on the instance; not used in this
                constructor.
            cache_data: if True, entries from the first pass over ``dataset``
                are cached and replayed on later passes.
        """
        super().__init__()
        self.shuffle_buffer_length = shuffle_buffer_length

        # Cache *inside* the cycle. ``Cached`` records entries while it
        # iterates its input and only replays them once that input is
        # exhausted, which a ``Cyclic`` stream never is. ``Cached(Cyclic(...))``
        # would therefore grow its buffer forever and never serve from it.
        source = Cyclic(Cached(dataset)) if cache_data else Cyclic(dataset)

        self.transformed_dataset = TransformedDataset(
            source,
            transform,
            is_train=is_train,
        )
Example #4
0
        range(20),
        constant_dataset()[1],
    ],
)
def test_pseudo_shuffled(data: Iterable) -> None:
    """Check PseudoShuffled keeps the element count and every input element."""
    original = list(data)
    shuffled = list(
        PseudoShuffled(iter(original), shuffle_buffer_length=5)
    )

    # Same number of elements out as in ...
    assert len(shuffled) == len(original)
    # ... and each input element appears somewhere in the output.
    for element in original:
        assert element in shuffled


@pytest.mark.parametrize(
    "data, expected_elements_per_iteration",
    [
        (Cached(range(4)), (list(range(4)), ) * 5),
        (batcher(range(10), 3), ([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]], [])),
        (IterableSlice(range(10), 3), ([0, 1, 2], ) * 5),
        (
            IterableSlice(iter(range(10)), 3),
            ([0, 1, 2], [3, 4, 5], [6, 7, 8], [9], []),
        ),
        (
            IterableSlice(iter(Cyclic(range(5))), 3),
            ([0, 1, 2], [3, 4, 0], [1, 2, 3], [4, 0, 1]),
        ),
    ],
)
def test_iterate_multiple_times(data: Iterable,
                                expected_elements_per_iteration: Tuple[List]):
    for expected_elements in expected_elements_per_iteration:
Example #5
0
    def train_model(
        self,
        training_data: Dataset,
        validation_data: Optional[Dataset] = None,
        num_workers: int = 0,
        shuffle_buffer_length: Optional[int] = None,
        cache_data: bool = False,
        **kwargs,
    ) -> TrainOutput:
        """Fit the Lightning module and return the best network with a predictor.

        Args:
            training_data: dataset to train on.
            validation_data: optional dataset. When given, checkpoints are
                ranked by ``"val_loss"``; otherwise by ``"train_loss"``.
            num_workers: forwarded to the training data-loader factory.
            shuffle_buffer_length: forwarded to the training data-loader
                factory.
            cache_data: if True, each transformed dataset is wrapped in
                ``Cached``.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            TrainOutput with the transformation, the network loaded from the
            best checkpoint, or the final in-memory network if no checkpoint
            was written, plus the Lightning trainer and a predictor built
            from that network.
        """
        transformation = self.create_transformation()

        def _prepare(dataset: Dataset):
            # Validation data is also transformed with ``is_train=True``, as
            # in the original code; presumably so the validation loader sees
            # training-style instances with targets. TODO confirm.
            transformed = transformation.apply(dataset, is_train=True)
            return Cached(transformed) if cache_data else transformed

        transformed_training_data = _prepare(training_data)

        training_network = self.create_lightning_module()

        training_data_loader = self.create_training_data_loader(
            transformed_training_data,
            training_network,
            num_workers=num_workers,
            shuffle_buffer_length=shuffle_buffer_length,
        )

        validation_data_loader = None
        if validation_data is not None:
            validation_data_loader = self.create_validation_data_loader(
                _prepare(validation_data),
                training_network,
            )

        monitor = "train_loss" if validation_data is None else "val_loss"
        checkpoint = pl.callbacks.ModelCheckpoint(
            monitor=monitor, mode="min", verbose=True
        )

        # Accept ``callbacks=None``, a single callback, or any iterable of
        # callbacks. ``[checkpoint] + ...`` crashed on the first two and on tuples.
        user_callbacks = self.trainer_kwargs.get("callbacks") or []
        if isinstance(user_callbacks, pl.callbacks.Callback):
            user_callbacks = [user_callbacks]
        trainer_kwargs = {
            **self.trainer_kwargs,
            "callbacks": [checkpoint, *user_callbacks],
        }
        trainer = pl.Trainer(**trainer_kwargs)

        trainer.fit(
            model=training_network,
            train_dataloaders=training_data_loader,
            val_dataloaders=validation_data_loader,
        )

        best_model_path = checkpoint.best_model_path
        if best_model_path:
            logger.info(f"Loading best model from {best_model_path}")
            # ``load_from_checkpoint`` is a classmethod. Call it on the class,
            # because recent Lightning versions refuse calls on an instance.
            best_model = type(training_network).load_from_checkpoint(
                best_model_path
            )
        else:
            # No checkpoint was saved (empty path), so there is nothing to
            # load. Fall back to the network as it stands after ``fit``.
            logger.warning(
                "No checkpoint was saved; using the final trained network"
            )
            best_model = training_network

        return TrainOutput(
            transformation=transformation,
            trained_net=best_model,
            trainer=trainer,
            predictor=self.create_predictor(transformation, best_model),
        )